洛谷P2336 [SCOI2012] 喵星球上的点名 题解
题意:
cxy 幸运地被选做了地球到喵星球的留学生。他发现喵星人在上课前的点名现象非常有趣。
假设课堂上有 $n$ 个喵星人,每个喵星人的名字由姓和名构成。喵星球上的老师会选择 $m$ 个串来点名,每次读出一个串的时候,如果这个串是一个喵星人的姓或名的子串,那么这个喵星人就必须答到。
然而,由于喵星人的字码如此古怪,以至于不能用 ASCII 码来表示。为了方便描述,cxy 决定用数串来表示喵星人的名字。
现在你能帮助 cxy 统计每次点名的时候有多少喵星人答到,以及 $m$ 次点名结束后每个喵星人答到多少次吗?
输入格式:
首先定义喵星球上的字符串给定方法:
先给出一个正整数 $l$,表示字符串的长度,接下来 $l$ 个整数,第 $i$ 个整数 $a_i$ 表示字符串的第 $i$ 个字符。
输入的第一行有两个整数,分别表示喵星人的个数 $n$ 个点名次数 $m$。
接下来 $n$ 行,每行两个喵星球上的字符串,按照定义的方法给出,依次表示第 $i$ 只喵的姓和名。
接下来 $m$ 行,每行一个喵星球上的字符串,表示一个老师点名的串。
输出格式:
对于每个老师点名的串,输出一行一个整数表示有多少只喵答到。
然后在最后一行输出 $n$ 个用空格隔开的整数,第 $i$ 个整数表示第 $i$ 个喵星人被点到的次数。
数据范围:
保证 $1 \leq n\le 5 \times 10^4$,$1 \leq m \le 10^5$,喵星人的名字总长和点名串的总长分别不超过 $10^5$
保证喵星人的字符串中作为字符存在的数不超过 $10^4$ 。
首先把姓和名用分隔符拼起来,再把这些姓名用分隔符拼起来,接着把询问串也用分隔符拼上去
对于每个询问串,向左右找到合法区间,那么问题就变成了
- 查询区间内有多少种颜色。
- 查询每种颜色被多少区间包含。
第一问就是经典老题:HH 的项链。
考虑用树状数组,修改的话每次加上贡献后减去 pre 的贡献即可。
提前把询问离线后按右端点排序,这样答案就是树状数组查区间和了。
第二问可以和第一问一起做。
对于询问 $(l,r)$ ,它的贡献就是一个区间加法,可以用树状数组维护差分数组实现。
不过由于同一种颜色会出现在 $(l,r)$ 中,好在我们已经维护了 pre ,那么答案就是 qry(i) - qry(pre)
。
时间复杂度 $\mathcal{O}(N \log N)$ ,其中 $N = \sum|s_i|$ 。
代码:
#include <bits/stdc++.h>
using namespace std;
// #define int long long
#define INF 0x3f3f3f3f
typedef long long ll;
void up(int &x, int y) { x < y ? x = y : 0; }
void down(int &x, int y) { x > y ? x = y : 0; }
#define N ((int)(4e5 + 55))
#define mem(a) memset(a, 0, sizeof(a))
#define check(i, w) (tmp[sa[i]] == tmp[sa[i - 1]] && tmp[sa[i] + w] == tmp[sa[i - 1] + w])
int trM, s[N], len[N], hd[N], lg[N], pre[N], lp[N], st[19][N], tr1[N], tr2[N];
int ans1[N], ans2[N], sa[N], rk[N * 2], col[N], tmp[N * 2], height[N], cnt[N];
int lowbit(int x) { return x & (-x); }
void add(int *tr, int x, int v) { for(int i = x; i && i <= trM; i += lowbit(i)) tr[i] += v; }
int qry(int *tr, int x, int r = 0) { for(int i = x; i; i -= lowbit(i)) r += tr[i]; return r; }
struct node { int l, r, id; } a[N];
void sort(const int n, const int m, int w)
{
memset(cnt, 0, sizeof(int) * (m + 5));
for(int i = 1; i <= n; i++) tmp[i] = sa[i];
for(int i = 1; i <= n; i++) ++cnt[rk[tmp[i] + w]];
for(int i = 1; i <= m; i++) cnt[i] += cnt[i - 1];
for(int i = n; i; i--) sa[cnt[rk[tmp[i] + w]]--] = tmp[i];
}
void getlcp(const int n) // 求 height 数组
{
int k = 0;
for(int i = 1; i <= n; height[rk[i]] = k, i++)
{
const int j = sa[rk[i] - 1]; if(k) --k;
while(i + k <= n && j + k <= n && s[i + k] == s[j + k]) ++k;
}
}
void init(const int n) // SA 板子
{
const int m = max(n, 66666);
for(int i = 1; i <= n; i++) { sa[i] = i, rk[i] = s[i]; }
for(int w = 1; w < n; w *= 2)
{
sort(n, m, w); sort(n, m, 0);
for(int i = 1; i <= n; i++) tmp[i] = rk[i];
for(int i = 1, p = 0; i <= n; i++)
if(check(i, w)) rk[sa[i]] = p; else rk[sa[i]] = ++p;
}
}
int query(int l, int r)
{
if(l == r) return INF;
int k = lg[r - l];
return min(st[k][l + 1], st[k][r - (1 << k) + 1]); // 查询 [l + 1, r]
}
signed main()
{
ios::sync_with_stdio(0);
cin.tie(0); cout.tie(0);
// freopen("check.in","r",stdin);
// freopen("check.out","w",stdout);
int tot, q, n = 0; cin >> tot >> q; int _ = 10005; // 分隔符要取大于值域的
for(int i = 1, m, x; i <= tot; s[++n] = ++_, i++)
{
cin >> m;
for(int j = 1; j <= m; j++) { cin >> x; s[++n] = x, col[n] = i; }
s[++n] = ++_; cin >> m;
for(int j = 1; j <= m; j++) { cin >> x; s[++n] = x, col[n] = i; }
}
for(int i = 1; i <= q; s[++n] = ++_, i++)
{
cin >> len[n + 1]; hd[n + 1] = i;
for(int j = len[n + 1], x; j; j--) { cin >> x; s[++n] = x; col[n] = -i; }
}
// for(int i = 1; i <= n; i++) cout << s[i] << ' ';
init(n); getlcp(n); mem(tmp); trM = n;
for(int i = 1; i <= n; i++) st[0][i] = height[i];
for(int i = 1; i < 19; i++) // ST 表
{
int *f = st[i], *g = st[i - 1];
for(int j = 1; j + (1 << i) - 1 <= n; j++)
f[j] = min(g[j], g[j + (1 << (i - 1))]);
}
for(int i = 2; i <= n; i++) lg[i] = lg[i >> 1] + 1;
for(int i = 1; i <= n; i++)
{
if(col[sa[i]] > 0) { pre[i] = tmp[col[sa[i]]]; tmp[col[sa[i]]] = i; }
if(hd[i])
{
a[hd[i]].id = hd[i]; int l = 1, r = rk[i];
while(l < r)
{
int mid = (l + r) >> 1;
if(query(mid, rk[i]) >= len[i]) r = mid; else l = mid + 1;
}
a[hd[i]].l = lp[hd[i]] = l; l = rk[i], r = n;
while(l < r)
{
int mid = (l + r + 1) >> 1;
if(query(rk[i], mid) >= len[i]) l = mid; else r = mid - 1;
}
a[hd[i]].r = l; // 我这个二分写法最后 l = r
}
}
// for(int i = 1; i <= q; i++) cout << a[i].l << ' ' << a[i].r << '\n';
sort(a + 1, a + 1 + q, [](node a, node b) { return a.r < b.r; });
sort(lp + 1, lp + q + 1);
for(int i = 1, j = 1, k = 1; i <= n; i++)
{
for(; j <= q && lp[j] == i; ++j) add(tr2, i, 1);
if(col[sa[i]] > 0)
{
ans2[col[sa[i]]] += qry(tr2, i) - qry(tr2, pre[i]);
add(tr1, i, 1); add(tr1, pre[i], -1);
}
for(; k <= q && a[k].r == i; ++k)
{
ans1[a[k].id] = qry(tr1, a[k].r) - qry(tr1, a[k].l - 1);
add(tr2, a[k].l, -1); // 用了一些奇♂妙的技巧少写一个函数(大雾)
}
}
for(int i = 1; i <= q; i++) cout << ans1[i] << '\n';
for(int i = 1; i <= tot; i++) cout << ans2[i] << " \n"[i == tot];
return 0;
}
参考文献:
[1] https://www.luogu.com.cn/article/5fczau5p
题外话:
这个代码是真长啊。