洛谷P3181 [HAOI2016] 找相同字符 题解
题意:
给定两个字符串,求出在两个字符串中各取出一个子串使得这两个子串相同的方案数。
两个方案不同当且仅当这两个子串中有一个位置不同。
输入格式:
两行,两个字符串 $S,T$,长度分别为 $n,m$。
输出格式:
输出一个整数表示答案。
数据范围:
$1\le n,m\le 2\times 10^5$,字符串中只有小写字母。
建议先去看一下 洛谷P4248 [AHOI2013] 差异 题解
主要是要知道下面这个东西以及怎么算。
那么考虑把两个串用一个分隔符拼在一起,这样答案就是
后面减掉的是因为多算了在同一个串的两个后缀的答案。
不过前面那个东西在两个串全是 $\tt{aaaa}$ 的时候会爆 long long ,要开 __int128 。
时间复杂度 $\mathcal{O}(n \log n)$ ,复杂度瓶颈是建后缀数组。
代码:
#include <bits/stdc++.h>
using namespace std;
// #define int long long
// #define INF 0x3f3f3f3f3f3f3f3f
typedef long long ll;
void up(int &x, int y) { x < y ? x = y : 0; }
void down(int &x, int y) { x > y ? x = y : 0; }
#define N ((int)(4e5 + 55))
#define mem(a) memset(a, 0, sizeof(a))
#define check(i, w) (tmp[sa[i]] == tmp[sa[i - 1]] && tmp[sa[i] + w] == tmp[sa[i - 1] + w])
char *s, a[N], b[N], c[N];
int top, stk[N], sa[N], rk[N * 2], tmp[N * 2], height[N], cnt[N], l[N], r[N];
void sort(const int n, const int m, int w)
{
memset(cnt, 0, sizeof(int) * (m + 5));
for(int i = 1; i <= n; i++) tmp[i] = sa[i];
for(int i = 1; i <= n; i++) ++cnt[rk[tmp[i] + w]];
for(int i = 1; i <= m; i++) cnt[i] += cnt[i - 1];
for(int i = n; i; i--) sa[cnt[rk[tmp[i] + w]]--] = tmp[i];
}
void getlcp(const int n)
{
int k = 0;
for(int i = 1; i <= n; height[rk[i]] = k, i++)
{
const int j = sa[rk[i] - 1]; if(k) --k;
while(i + k <= n && j + k <= n && s[i + k] == s[j + k]) ++k;
}
}
void init(const int n)
{
const int m = max(n, 233);
for(int i = 1; i <= n; i++) { sa[i] = i, rk[i] = s[i]; }
for(int w = 1; w < n; w *= 2)
{
sort(n, m, w); sort(n, m, 0);
for(int i = 1; i <= n; i++) tmp[i] = rk[i];
for(int i = 1, p = 0; i <= n; i++)
if(check(i, w)) rk[sa[i]] = p; else rk[sa[i]] = ++p;
}
}
ll solve(const int n)
{
mem(rk); mem(sa); mem(tmp); // tmp 记得要清空
init(n); getlcp(n); stk[top = 1] = 1;
for(int i = 2; i <= n; l[i] = stk[top], stk[++top] = i, i++)
while(top && height[stk[top]] > height[i]) { r[stk[top]] = i, --top; }
while(top) r[stk[top]] = n + 1, --top;
ll res = 0;
for(int i = 2; i <= n; i++)
res += (ll)(r[i] - i) * (i - l[i]) * height[i];
return res;
}
signed main()
{
ios::sync_with_stdio(0);
cin.tie(0); cout.tie(0);
// freopen("check.in","r",stdin);
// freopen("check.out","w",stdout);
cin >> (a + 1) >> (b + 1);
int n = strlen(a + 1), m = strlen(b + 1);
for(int i = 1; i <= n; i++) c[i] = a[i];
for(int i = 1; i <= m; i++) c[i + n + 1] = b[i];
c[n + 1] = 1; // 不要用 0 , 否则会导致 rk 跑出来是 0
s = c; const ll _1 = solve(n + m + 1);
s = a; const ll _2 = solve(n);
s = b; const ll _3 = solve(m);
cout << _1 - _2 - _3 << '\n';
return 0;
}
题外话:
放图片。
