嘘~ 正在从服务器偷取页面 . . .

CF17E Palisection 题解


CF17E Palisection 题解

题目链接:Palisection

题意

给定一个长度为 $n$ 的小写字母串。

问有多少对相交的回文子串(包含也算相交) 。

输入格式

第一行是字符串长度 $n$,第二行字符串。

输出格式

相交的回文子串个数 $\bmod 51123987$​ 。

数据范围

$1\le n\le 2\times10^6$

相交的不太好算,那就算 回文串的数量 减去 不相交的回文串数量

容易发现,以 $i$ 结尾的回文串的不相交回文串,就是以 $j(j > i)$ 为开头的回文串的总数

考虑把原字符串翻转一下,建个回文自动机,然后用后缀和优化一下

再把串复原,跑一遍回文自动机,这样就可以用乘法原理算出答案了。

不过这道题会卡空间,所以我们需要用邻接表或者链式前向星来写,每次暴力找儿子

时间复杂度 $\mathcal{O}(26n)$ ,空间复杂度 $\mathcal{O}(n)$ 。

代码:

#include <bits/stdc++.h>
using namespace std;
// #define int long long
// #define INF 0x3f3f3f3f3f3f3f3f
typedef long long ll;
const ll mod = 51123987;
void up(int &x, int y) { x < y ? x = y : 0; }
void down(int &x, int y) { x > y ? x = y : 0; }
void add(ll &x, ll y) { (x += y) >= mod ? x -= mod : 0; }
#define N ((int)(2e6 + 15))

char s[N];
ll res, cnt[N], sum[N];
struct Edge{ int v, w, next; } e[N];
int n, tot, pos = 1, head[N], fail[N], len[N];
int getson(int p, int c)
{
    for(int i = head[p]; i; i = e[i].next)
        if(e[i].w == c) return e[i].v;
    return -1;
}
int getfail(int p, int i)
{
    while(s[i - len[p] - 1] != s[i]) p = fail[p];
    return p;
}
void insert(const int i, const int c, const int type)
{
    static int p; p = i > 1 ? getfail(p, i) : 0;
    int to = getson(p, c);
    if(to == -1)
    {
        to = ++tot; len[tot] = len[p] + 2; 
        e[++pos] = {to, c, head[p]}; head[p] = pos;
        fail[tot] = p > 0 ? getson(getfail(fail[p], i), c) : 1;
        cnt[tot] = cnt[fail[tot]] + 1;
    }
    p = to;
    if(!type) add(sum[n - i + 1], cnt[p]);
    else add(res, cnt[p] * sum[i + 1] % mod);
}
signed main()
{
    ios::sync_with_stdio(0);
    cin.tie(0); cout.tie(0);
    // freopen("check.in","r",stdin);
    // freopen("check.out","w",stdout);
    cin >> n >> (s + 1); sum[n + 1] = 0;
    fail[tot = 1] = 0; len[0] = -1; reverse(s + 1, s + 1 + n);
    for(int i = 1; i <= n; i++) insert(i, s[i] - 'a', 0);
    for(int i = n; i; i--) add(sum[i], sum[i + 1]);
    pos = 1; for(int i = 0; i <= tot; i++) head[i] = 0;
    fail[tot = 1] = 0; len[0] = -1; reverse(s + 1, s + 1 + n);
    for(int i = 1; i <= n; i++) insert(i, s[i] - 'a', 1);
    cout << ((sum[1] * (sum[1] - 1) / 2) % mod - res + mod) % mod << '\n';
    return 0;
}

参考文献

[1] https://www.luogu.com.cn/article/9dx9nxzp


文章作者: q779
版权声明: 本博客所有文章除特别声明外,均采用 CC BY-NC-ND 4.0 许可协议。转载请注明来源 q779 !
评论
  目录