嘘~ 正在从服务器偷取页面 . . .

洛谷P3181 [HAOI2016] 找相同字符 题解


洛谷P3181 [HAOI2016] 找相同字符 题解

题目链接:P3181 [HAOI2016] 找相同字符

题意

给定两个字符串,求出在两个字符串中各取出一个子串使得这两个子串相同的方案数。

两个方案不同当且仅当这两个子串中有一个位置不同。

输入格式

两行,两个字符串 $S,T$,长度分别为 $n,m$。

输出格式

输出一个整数表示答案。

数据范围

$1\le n,m\le 2\times 10^5$,字符串中只有小写字母。

建议先去看一下 洛谷P4248 [AHOI2013] 差异 题解

主要是要知道下面这个东西以及怎么算。

那么考虑把两个串用一个分隔符拼在一起,这样答案就是

后面减掉的是因为多算了在同一个串的两个后缀的答案。

不过前面那个东西在两个串全是 $\tt{aaaa}$ 的时候会爆 long long ,要开 __int128

时间复杂度 $\mathcal{O}(n \log n)$ ,复杂度瓶颈是建后缀数组。

代码:

#include <bits/stdc++.h>
using namespace std;
// #define int long long
// #define INF 0x3f3f3f3f3f3f3f3f
typedef long long ll;
void up(int &x, int y) { x < y ? x = y : 0; }
void down(int &x, int y) { x > y ? x = y : 0; }
#define N ((int)(4e5 + 55))
#define mem(a) memset(a, 0, sizeof(a))
#define check(i, w) (tmp[sa[i]] == tmp[sa[i - 1]] && tmp[sa[i] + w] == tmp[sa[i - 1] + w])

char *s, a[N], b[N], c[N];
int top, stk[N], sa[N], rk[N * 2], tmp[N * 2], height[N], cnt[N], l[N], r[N];
void sort(const int n, const int m, int w)
{
    memset(cnt, 0, sizeof(int) * (m + 5));
    for(int i = 1; i <= n; i++) tmp[i] = sa[i];
    for(int i = 1; i <= n; i++) ++cnt[rk[tmp[i] + w]];
    for(int i = 1; i <= m; i++) cnt[i] += cnt[i - 1];
    for(int i = n; i; i--) sa[cnt[rk[tmp[i] + w]]--] = tmp[i];
}
void getlcp(const int n)
{
    int k = 0;
    for(int i = 1; i <= n; height[rk[i]] = k, i++)
    {
        const int j = sa[rk[i] - 1]; if(k) --k;
        while(i + k <= n && j + k <= n && s[i + k] == s[j + k]) ++k;
    }
}
void init(const int n)
{
    const int m = max(n, 233);
    for(int i = 1; i <= n; i++) { sa[i] = i, rk[i] = s[i]; }
    for(int w = 1; w < n; w *= 2)
    {
        sort(n, m, w); sort(n, m, 0);
        for(int i = 1; i <= n; i++) tmp[i] = rk[i];
        for(int i = 1, p = 0; i <= n; i++)
            if(check(i, w)) rk[sa[i]] = p; else rk[sa[i]] = ++p;
    }
}
ll solve(const int n)
{
    mem(rk); mem(sa); mem(tmp); // tmp 记得要清空
    init(n); getlcp(n); stk[top = 1] = 1;
    for(int i = 2; i <= n; l[i] = stk[top], stk[++top] = i, i++)
        while(top && height[stk[top]] > height[i]) { r[stk[top]] = i, --top; }
    while(top) r[stk[top]] = n + 1, --top;
    ll res = 0;
    for(int i = 2; i <= n; i++)
        res += (ll)(r[i] - i) * (i - l[i]) * height[i];
    return res;
}
signed main()
{
    ios::sync_with_stdio(0);
    cin.tie(0); cout.tie(0);
    // freopen("check.in","r",stdin);
    // freopen("check.out","w",stdout);
    cin >> (a + 1) >> (b + 1);
    int n = strlen(a + 1), m = strlen(b + 1);
    for(int i = 1; i <= n; i++) c[i] = a[i];
    for(int i = 1; i <= m; i++) c[i + n + 1] = b[i];
    c[n + 1] = 1; // 不要用 0 , 否则会导致 rk 跑出来是 0
    s = c; const ll _1 = solve(n + m + 1);
    s = a; const ll _2 = solve(n);
    s = b; const ll _3 = solve(m);
    cout << _1 - _2 - _3 << '\n';
    return 0;
}

题外话

放图片。


文章作者: q779
版权声明: 本博客所有文章除特别声明外,均采用 CC BY-NC-ND 4.0 许可协议。转载请注明来源 q779 !
评论
  目录