洛谷P3352 [ZJOI2016]线段树 题解
题目链接:P3352 [ZJOI2016]线段树
题意:
cxy 遇到了一个题目:有一个序列 $a_1,a_2,\ldots,a_n$,$m$ 次操作。每次操作把一个区间内的数改成区间内的最大值,问最后每个数是多少。 cxy 很快地就使用了线段树解决了这个问题。
于是充满智慧的 cxy 想,如果操作是随机的,即在这 $m$ 次操作中每次等概率随机地选择一个区间 $[l,r]$($1 \leq l \leq r \leq n$),然后将这个区间内的数改成区间内最大值(注意这样的区间共有 $\frac{n(n+1)}{2}$ 个),最后每个数的期望大小是多少呢?
cxy 非常热爱随机,所以她给出的输入序列也是随机的(随机方式见数据规模和约定)。
对于每个数,输出它的期望乘 $\left(\frac{n(n+1)}{2} \right)^m$ 再对 $10^9+7$ 取模的值。
输入格式:
第一行包含两个正整数 $n,m$,表示序列里数的个数和操作的个数。
接下来一行,包含 $n$ 个非负整数 $a_1,a_2,\ldots,a_n$。
输出格式:
输出共一行,包含 $n$ 个整数,表示每个数的答案。
数据范围:
$n \le 400,m\le 400$ 。
对于所有的测试数据,保证序列中数的大小不超过 $10^9$,并且每个数是 $0$ 到 $10^9$ 之间的随机整数。
中秋节晚上颓颓颓,Ravenfield装了十几个mod
期望dp+组合数神题。
因为答案乘以了 $\left(\frac{n(n+1)}{2} \right)^m$ ,所以最后就是求所有方案中每个数的和。
首先考虑 $01$ 序列怎么做。
我们可以统计每一个位置最终为 $0$ 的方案数
如果 $a_i$ 初始为 $1$ ,则无论如何最终都不会变成 $0$ 。
如果 $a_i$ 初始为 $0$ ,则设它左边最近的 $1$ 在 $l$ ,右边最近的 $1$ 在 $r$ 。
显然每次操作只会使 $l,r$ 缩小(也可能不变)。
因此我们对每个初始为 $0$ 的极长子段做一遍dp,然后把总贡献减去它。
也就是把 $\left(\frac{n(n+1)}{2} \right)^m$ 减去「最终为 $0$ 的方案数」(因为是 $01$ 序列,所以贡献等于方案数)
设 $f_{x,l,r}$ 表示 $x$ 次操作后,区间缩到 $(l,r)$ 的方案数,则
注意这里是开区间,即
其中 $h^x(a_i)$ 表示第 $x$ 操作后 $a_i$ 的值。
其中 $g_{l,r}$ 表示无用操作的数量,有
注:这里涉及了可重组合,即 $n$ 个不同元素选 $m$ 个可重元素组合的方案数为 $\dbinom{n+m-1}{m}$ 。
不难发现上面的 $\sum\limits_{1\le t < l}f_{x-1,t,r} + \sum\limits_{r < t\le n}f_{x-1,l,t}$ 可以前缀和优化。
因此 $01$ 序列时间复杂度 $\mathcal{O}(n^2m)$ 。
然后考虑原题怎么做。不难发现,最终序列的某个数 $x$ 可以拆分为(记 $M=\max\{a_i\}$ )
注:(来自参考文献[2])
不拆分为 $\sum_{i=0}^{M} [i < x]$ 的原因是统计贡献时与 $l,r$ 有关,很麻烦,不利于下文中的整体dp。
我们枚举这个 $i$ ,然后把所有 $\ge i$ 的位置标为 $1$ ,所有 $< i$ 的标记为 $0$ ,这样就可以求出每个位置 $<i$ 的方案数。(相当于做 $M$ 次 $01$ 的情况)
但是这个枚举的过程不会爆炸吗?注意到 $i$ 从 $0$ 到 $M$ 扫的过程中, $01$ 序列其实只会变化 $\mathcal{O}(n)$ 次,因此实际不用做 $M$ 次。
因此复杂度为 $\mathcal{O}(n^3m)$ ,但因为随机数据可以卡过。
不过其实是存在稳定 $\mathcal{O}(n^2m)$ 的算法的。
注意到每一层的dp方程是一样的,根据期望的线性性质,
我们可以一次性把初值全加进去,然后跑一次整体的dp就可以算出来了,是不是很神奇?
实现的时候把所有数从大到小依次添加贡献,每次加的方案数就是与前一次的差值(具体可以看代码)
时间复杂度 $\mathcal{O}(n^2m)$ ,空间的话滚一下数组就好了。
代码:
#include <bits/stdc++.h>
using namespace std;
#define int long long
#define INF 0x3f3f3f3f3f3f3f3f
#define N ((int)(555))
const int mod = 1e9+7;
bool vis[N];
struct node{int id,w;} a[N];
int n,m,ans[N],f[N][N],s1[N][N],s2[N][N];
bool cmp(node x,node y){return x.w < y.w; }
void Add(int &x,int y) { (x += y) >= mod ? x-=mod : 0;}
int add(int x,int y) { return (x += y) < mod ? x : x-mod;}
int qpow(int a,int b)
{int ans=1,base=a%mod;while(b){if(b&1)ans=ans*base%mod;base=base*base%mod;b>>=1;}return ans;}
int g(int l,int r) {return ( l*(l+1)+ (n-r+1)*(n-r+2) + (r-l-1)*(r-l) ) / 2;}
signed main()
{
ios::sync_with_stdio(0);
cin.tie(0);cout.tie(0);
// freopen("check.in","r",stdin);
// freopen("check.out","w",stdout);
cin >> n >> m; vis[n+1]=1;
for(int i=1; i<=n; i++) { cin >> a[i].w; a[i].id = i; } sort(a+1, a+1+n, cmp);
for(int i=n,last; i; i--)
{
last = 0; vis[a[i].id]=1;
for(int j=1; j<=n+1; j++) if(vis[j]) { f[last][j] += a[i].w-a[i-1].w; last = j; }
}
ans[1] = a[n].w * qpow(n*(n+1)/2,m) % mod; for(int i=2; i<=n; i++) ans[i] = ans[1];
for(int i=1; i<=m; i++)
{
for(int j=0; j<=n; j++) for(int k=n+1; k>j+1; --k)
{
s1[j][k] = add(j ? s1[j-1][k] : 0, f[j][k] * j % mod);
s2[j][k] = add(k <= n ? s2[j][k+1] : 0, f[j][k] * (n-k+1) % mod);
}
for(int j=0; j<=n; j++) for(int k=j+2; k<=n+1; k++)
{
f[j][k] = f[j][k] * g(j,k) % mod;
if(j) Add(f[j][k], s1[j-1][k]);
if(k<=n) Add(f[j][k], s2[j][k+1]);
}
}
for(int i=0; i<=n; i++) for(int j=i+2; j<=n+1; j++)
for(int k=i+1; k<j; k++) Add(ans[k], (mod-f[i][j])%mod);
for(int i=1; i<=n; i++) cout << ans[i] << " \n"[i==n];
return 0;
}
呼,终于做完这题了,修错误就修了半天,很难想象我到底有多少还有笔误没修
参考文献:
[1] https://www.luogu.com.cn/blog/119621/solution-p3352
[2] https://www.luogu.com.cn/blog/i207M/gai-shuai-dp-hao-ti-zjoi2016-xian-duan-shu-xie-ti-bao-gao
[3] https://yhx-12243.github.io/OI-transit/records/lydsy4574%3Blg3352%3Buoj196%3Bloj2093.html