嘘~ 正在从服务器偷取页面 . . .

洛谷P3352 [ZJOI2016]线段树 题解


洛谷P3352 [ZJOI2016]线段树 题解

题目链接:P3352 [ZJOI2016]线段树

题意

cxy 遇到了一个题目:有一个序列 \(a_1,a_2,\ldots,a_n\)\(m\) 次操作。每次操作把一个区间内的数改成区间内的最大值,问最后每个数是多少。 cxy 很快地就使用了线段树解决了这个问题。

于是充满智慧的 cxy 想,如果操作是随机的,即在这 \(m\) 次操作中每次等概率随机地选择一个区间 \([l,r]\)\(1 \leq l \leq r \leq n\)),然后将这个区间内的数改成区间内最大值(注意这样的区间共有 \(\frac{n(n+1)}{2}\) 个),最后每个数的期望大小是多少呢?

cxy 非常热爱随机,所以她给出的输入序列也是随机的(随机方式见数据规模和约定)。

对于每个数,输出它的期望乘 \(\left(\frac{n(n+1)}{2} \right)^m\) 再对 \(10^9+7\) 取模的值。

输入格式

第一行包含两个正整数 \(n,m\),表示序列里数的个数和操作的个数。

接下来一行,包含 \(n\) 个非负整数 \(a_1,a_2,\ldots,a_n\)

输出格式

输出共一行,包含 \(n\) 个整数,表示每个数的答案。

数据范围

\(n \le 400,m\le 400\)

对于所有的测试数据,保证序列中数的大小不超过 \(10^9\),并且每个数是 \(0\)\(10^9\) 之间的随机整数

中秋节晚上颓颓颓,Ravenfield装了十几个mod

期望dp+组合数神题。

因为答案乘以了 \(\left(\frac{n(n+1)}{2} \right)^m\) ,所以最后就是求所有方案中每个数的和。

首先考虑 \(01\) 序列怎么做。

我们可以统计每一个位置最终为 \(0\) 的方案数

  • 如果 \(a_i\) 初始为 \(1\) ,则无论如何最终都不会变成 \(0\)

  • 如果 \(a_i\) 初始为 \(0\) ,则设它左边最近的 \(1\)\(l\) ,右边最近的 \(1\)\(r\)

    显然每次操作只会使 \(l,r\) 缩小(也可能不变)。

因此我们对每个初始为 \(0\) 的极长子段做一遍dp,然后把总贡献减去它。

也就是把 \(\left(\frac{n(n+1)}{2} \right)^m\) 减去「最终为 \(0\) 的方案数」(因为是 \(01\) 序列,所以贡献等于方案数)

\(f_{x,l,r}\) 表示 \(x\) 次操作后,区间缩到 \((l,r)\) 的方案数,则 \[ f_{x,l,r} = f_{x-1,l,r} \times g_{l,r} + \sum_{0\le t < l}f_{x-1,t,r} + \sum_{r < t\le n+1}f_{x-1,l,t} \]

注意这里是开区间,即 \[ \forall i \in(l,r),~h^x(a_i) = 0,h^x(a_l) = h^x(a_r)=1 \] 其中 \(h^x(a_i)\) 表示第 \(x\) 操作后 \(a_i\) 的值。

其中 \(g_{l,r}\) 表示无用操作的数量,有 \[ \begin{aligned} g_{l,r} &= \left[\binom{l+1}{2} + \binom{r-l}{2} + \binom{n-r+2}{2} \right] \\&= \frac{1}{2}\left[l(l+1)+(n-r+1)(n-r+2)+(r-l-1)(r-l)\right] \end{aligned} \]

注:这里涉及了可重组合,即 \(n\) 个不同元素选 \(m\) 个可重元素组合的方案数为 \(\dbinom{n+m-1}{m}\)

不难发现上面的 \(\sum\limits_{1\le t < l}f_{x-1,t,r} + \sum\limits_{r < t\le n}f_{x-1,l,t}\) 可以前缀和优化。

因此 \(01\) 序列时间复杂度 \(\mathcal{O}(n^2m)\)

然后考虑原题怎么做。不难发现,最终序列的某个数 \(x\) 可以拆分为(记 \(M=\max\{a_i\}\)\[ x \Leftrightarrow M - \sum_{i=0}^{M} [i > x] \]

注:(来自参考文献[2])

不拆分为 \(\sum_{i=0}^{M} [i < x]\) 的原因是统计贡献时与 \(l,r\) 有关,很麻烦,不利于下文中的整体dp。

我们枚举这个 \(i\) ,然后把所有 \(\ge i\) 的位置标为 \(1\) ,所有 \(< i\) 的标记为 \(0\) ,这样就可以求出每个位置 \(<i\) 的方案数。(相当于做 \(M\)\(01\) 的情况)

但是这个枚举的过程不会爆炸吗?注意到 \(i\)\(0\)\(M\) 扫的过程中, \(01\) 序列其实只会变化 \(\mathcal{O}(n)\) 次,因此实际不用做 \(M\) 次。

因此复杂度为 \(\mathcal{O}(n^3m)\) ,但因为随机数据可以卡过。

不过其实是存在稳定 \(\mathcal{O}(n^2m)\) 的算法的。

注意到每一层的dp方程是一样的,根据期望的线性性质,

我们可以一次性把初值全加进去,然后跑一次整体的dp就可以算出来了,是不是很神奇?

实现的时候把所有数从大到小依次添加贡献,每次加的方案数就是与前一次的差值(具体可以看代码)

时间复杂度 \(\mathcal{O}(n^2m)\) ,空间的话滚一下数组就好了。

代码:

#include <bits/stdc++.h>
using namespace std;
#define int long long
#define INF 0x3f3f3f3f3f3f3f3f
#define N ((int)(555))
const int mod = 1e9+7;

bool vis[N];
struct node{int id,w;} a[N];
int n,m,ans[N],f[N][N],s1[N][N],s2[N][N];
bool cmp(node x,node y){return x.w < y.w; }
void Add(int &x,int y) { (x += y) >= mod ? x-=mod : 0;}
int add(int x,int y) { return (x += y) < mod ? x : x-mod;}
int qpow(int a,int b)
{int ans=1,base=a%mod;while(b){if(b&1)ans=ans*base%mod;base=base*base%mod;b>>=1;}return ans;}
int g(int l,int r) {return ( l*(l+1)+ (n-r+1)*(n-r+2) + (r-l-1)*(r-l) ) / 2;}
signed main()
{
    ios::sync_with_stdio(0);
    cin.tie(0);cout.tie(0);
    // freopen("check.in","r",stdin);
    // freopen("check.out","w",stdout);
    cin >> n >> m; vis[n+1]=1;
    for(int i=1; i<=n; i++) { cin >> a[i].w; a[i].id = i; } sort(a+1, a+1+n, cmp);
    for(int i=n,last; i; i--)
    {
        last = 0; vis[a[i].id]=1;
        for(int j=1; j<=n+1; j++) if(vis[j]) { f[last][j] += a[i].w-a[i-1].w; last = j; }
    }
    ans[1] = a[n].w * qpow(n*(n+1)/2,m) % mod; for(int i=2; i<=n; i++) ans[i] = ans[1];
    for(int i=1; i<=m; i++)
    {
        for(int j=0; j<=n; j++) for(int k=n+1; k>j+1; --k)
        {
            s1[j][k] = add(j ? s1[j-1][k] : 0, f[j][k] * j % mod);
            s2[j][k] = add(k <= n ? s2[j][k+1] : 0, f[j][k] * (n-k+1) % mod);
        }
        for(int j=0; j<=n; j++) for(int k=j+2; k<=n+1; k++)
        {
            f[j][k] = f[j][k] * g(j,k) % mod;
            if(j) Add(f[j][k], s1[j-1][k]); 
            if(k<=n) Add(f[j][k], s2[j][k+1]);
        }
    }
    for(int i=0; i<=n; i++) for(int j=i+2; j<=n+1; j++)
        for(int k=i+1; k<j; k++) Add(ans[k], (mod-f[i][j])%mod);
    for(int i=1; i<=n; i++) cout << ans[i] << " \n"[i==n];
    return 0;
}

呼,终于做完这题了,修错误就修了半天,很难想象我到底有多少还有笔误没修

参考文献

[1] https://www.luogu.com.cn/blog/119621/solution-p3352

[2] https://www.luogu.com.cn/blog/i207M/gai-shuai-dp-hao-ti-zjoi2016-xian-duan-shu-xie-ti-bao-gao

[3] https://yhx-12243.github.io/OI-transit/records/lydsy4574%3Blg3352%3Buoj196%3Bloj2093.html


文章作者: q779
版权声明: 本博客所有文章除特别声明外,均采用 CC BY-NC-ND 4.0 许可协议。转载请注明来源 q779 !
评论
  目录