嘘~ 正在从服务器偷取页面 . . .

洛谷P5357 【模板】AC 自动机(二次加强版)题解


洛谷P5357 【模板】AC 自动机(二次加强版)题解

题目链接:P5357 【模板】AC 自动机(二次加强版)

题意:$n$ 个模式串 $s_i$(不保证互异),要求输出这些模式串在文本串 $S$ 中出现的次数

建议大家先去做下加强版的,题解在此

我们已经在加强版初步解决了次数统计的问题

可以发现本题的数据范围 $n\le2\times10^5,\sum |s_i|\le2\times10^5,|S|\le2\times10^5$

而原来算法的时间复杂度是 $O\left(|S|\left|\max\{s_i\}\right|\right)$,T飞了

那么考虑怎么优化暴力跳fail的问题

注意到所有fail连出的有向边构成了一个DAG(有向无环图)

证明很简单,最长后缀一定是单调递减的

我们把这个DAG看作一棵树

那么所有的儿子结点一定会跳到父亲结点,并使父亲结点权值增加1

解法一:直接树形dp统计答案

这个我没写代码 qwq

解法二:拓扑排序

我们只要在拓扑排序的过程中统计答案即可

这样我们就可以把时间复杂度压到 $O\left(\sum|s_i|+|S|\right)$ 了!

其他注意点:

由于可能存在相同的模式串,显然它们的出现次数相同

那我们原来的e[u]=id就不可用了

咋办?并查集啊!

而在本题中较为特殊,合并产生的图一定是个菊花图

所以不用并查集,直接用数组也可(这样常数小一点)

但是我一开始写的并查集懒地改,就这样吧 qwq 反正影响很小

代码如下:

#include <bits/stdc++.h>
using namespace std;
#define int long long
#define N (int)(2e5+5)
#define L (int)(2e6+5)
char t[L],s[N];
int n,ans[N],e[N],val[N],in[N];
int trie[N][32],tot,fail[N],f[N];
void init(){for(int i=1; i<=n; i++)f[i]=i;}
int find(int x){return f[x]==x?x:f[x]=find(f[x]);}
void merge(int u,int v){f[find(u)]=find(v);}
void insert(int l,char *s,int id)
{
	int u=0;
	for(int i=1; i<=l; i++)
	{
		int c=s[i]-'a';
		if(!trie[u][c])trie[u][c]=++tot;
		u=trie[u][c];
	}
	if(!e[u])e[u]=id;
	else merge(id,e[u]);
}
queue<int>q;
void build()
{
	for(int i=0; i<26; i++)
		if(trie[0][i])q.push(trie[0][i]);
	while(!q.empty())
	{
		int u=q.front();q.pop();
		for(int i=0; i<26; i++)
		{
			if(trie[u][i])
			{
				fail[trie[u][i]]=trie[fail[u]][i];
				++in[trie[fail[u]][i]];
				q.push(trie[u][i]);
			}else trie[u][i]=trie[fail[u]][i];
		}
	}
}
void AC(int l,char *t)
{
	int u=0;
	for(int i=1; i<=l; i++)
	{
		u=trie[u][t[i]-'a'];
		++val[u];
	}
	for(int i=1; i<=tot; i++)
		if(!in[i])q.push(i);
	while(!q.empty())
	{
		int u=q.front();q.pop();
		if(e[u])ans[e[u]]=val[u];
		val[fail[u]]+=val[u];
		if(!--in[fail[u]])q.push(fail[u]);
	}
}
signed main()
{
	scanf("%lld",&n); init();
	for(int i=1; i<=n; i++)
	{
		scanf("%s\n",s+1);
		insert(strlen(s+1),s,i);
	}
	scanf("%s\n",t+1);
	build();
	AC(strlen(t+1),t);
	for(int i=1; i<=n; i++)
		printf("%lld\n",ans[find(i)]);
	return 0;
}

文章作者: q779
版权声明: 本博客所有文章除特别声明外,均采用 CC BY-NC-ND 4.0 许可协议。转载请注明来源 q779 !
评论
  目录