CF1202E You Are Given Some Strings... 题解
题目链接:CF1202E You Are Given Some Strings...
题意:
给定文本串 \(t\) 和 \(n\) 个模式串 \(s_i\) ,求 \[ \sum_{i=1}^n\sum_{j=1}^nf(s_i + s_j) \] \(f(s)\) 定义为 \(s\) 在 \(t\) 中出现的次数,\(s_i+s_j\) 定义为拼接字符串 \(s_i\) 和 \(s_j\) 。
输入格式:
一行 \(t\) ,一行 \(n\) ,接下来 \(n\) 行 \(s_i\) 。
输出格式:
一行,答案。
数据范围:
\(1\le |t|,\sum |s_i| \le 2\times 10^5\) 。
这种题肯定是考虑某个串的贡献啦。
这里我们考虑枚举划分点 \(x\) 。记该点前为 \(s_i\) ,该点后为 \(s_j\)
统计出每个 \(x\) 的 \(f(x),g(x+1)\) 分别表示前后的个数。
怎么统计呢?注意到 \(f(x)\) 其实就是 \(t_{1 \cdots x}\) 的后缀中有几个已知子串,\(g(x)\) 就是反串的 \(f(x)\)
这个东西可以用 AC 自动机来搞。就是直接在模式串上建 AC 自动机然后跑一遍 \(t\) 。
考虑 fail 指针指向的是根到当前串的最长后缀,则所求等价于为模式串的后缀的数量。
根据 AC 自动机的板子可以知道,这个只要在 fail 树上求个和就好了。
时间复杂度 \(\mathcal{O}(\sum |s_i| \times |\Sigma|)\)
代码:
#include <bits/stdc++.h>
using namespace std;
#define int long long
#define INF 0x3f3f3f3f3f3f3f3f
void up(int &x,int y) { x < y ? x = y : 0; }
void down(int &x,int y) { x > y ? x = y : 0; }
#define N ((int)(4e5+15))
struct Trie
{
queue<int> q;
int tot,trie[N][26],ed[N],fail[N];
void insert(char *s)
{
int u = 0;
for(int i=1; s[i]; i++)
{
int c = s[i] - 'a';
if(!trie[u][c]) trie[u][c] = ++tot;
u = trie[u][c];
}
++ed[u];
}
void build()
{
for(int i=0; i<26; i++) if(trie[0][i]) q.push(trie[0][i]);
for(int u; !q.empty(); )
{
u = q.front(); q.pop();
for(int i=0; i<26; i++)
{
if(trie[u][i])
{
fail[trie[u][i]] = trie[fail[u]][i];
ed[trie[u][i]] += ed[fail[trie[u][i]]]; q.push(trie[u][i]);
}else trie[u][i] = trie[fail[u]][i];
}
}
}
void query(char *s,int *f)
{
int u = 0;
for(int i=1; s[i]; i++)
{ u = trie[u][s[i] - 'a']; f[i] = ed[u]; }
}
}tr1,tr2;
char s[N],t[N]; int f[N],g[N];
signed main()
{
ios::sync_with_stdio(0);
cin.tie(0);cout.tie(0);
// freopen("check.in","r",stdin);
// freopen("check.out","w",stdout);
int n,l; cin >> (t + 1) >> n; l = strlen(t + 1);
for(int i=1; i<=n; i++)
{
cin >> (s + 1); tr1.insert(s);
reverse(s + 1, s + 1 + strlen(s + 1)); tr2.insert(s);
}
tr1.build(); tr2.build(); tr1.query(t,f);
reverse(t + 1, t + 1 + l); tr2.query(t,g);
// for(int i=1; i<=l; i++) cout << f1[i] << " \n"[i==l];
// for(int i=1; i<=l; i++) cout << f2[i] << " \n"[i==l];
int res = 0;
for(int i=1; i<=l; i++) res += f[i] * g[l - i];
cout << res << '\n';
return 0;
}
参考文献:
[1] https://www.luogu.com.cn/blog/c2522943959/solution-cf1202e