CF708E Student's Camp 题解
题意:
有一个 \((n+2) \times m\) 的网格。
给定 \(n,m,a,b,k\) ,记 \(p=\frac{a}{b}\)
除了第 \(0\) 行和第 \(n+1\) 行,其他每一行
- 白天最左边的格子有 \(p\) 的概率被吹走。
- 夜晚最右边的格子有 \(p\) 的概率被吹走。
如果网格(包括第 \(0\) 行和第 \(n+1\) 行)被分成了两个及以上连通分量,则称网格不连通。
求 \(k\) 天后,网格始终保持连通的概率。
\(n,m \le 1.5 \times 10^3,~k \le 10^5,~a < b \le 10^9\),答案对 \(10^9+7\) 取模。
特别🐮🍺的一道题,考虑概率dp。
下文中出现的
奇怪变量名(因为参考自 link),先给出对应的代码形式$_i$
pl[i]
,\(\mathtt{pr}_i \to\)pr[i]
\(F_{i,j} \to\)
F[i][j]
, \(\mathtt{sl}_{i,j} \to\)sl[i][j]
, \(\mathtt{sr}_{i,j} \to\)sr[i][j]
注意到每一行是相互独立的,故
设 \(\mathtt{pl}_i\) 表示某一行在 \(k\) 个白天后,最左边的砖块为第 \(i\) 块的概率
设 \(\mathtt{pr}_i\) 表示某一行在 \(k\) 个夜晚后,最右边的砖块为第 \(i\) 块的概率
根据小学数学,如果把这里这个 \(i\)
看作一个离散随机变量,则它一定服从二项分布
记 \(q=1-p\) ,即 \[ \mathtt{pl}_{i+1} = \dbinom{k}{i} q^{i} p^{k-i} \] 根据二项分布的对称性可得 \(\mathtt{pr}_{m-i} = \mathtt{pl}_{i+1}\)
根据砖块吹走的性质,最后每一行一定是连续的一段,
故设 \(f_{i,l,r}\) 表示 \(k\) 天后,从下往上数第 \(i\) 行仅剩下砖块 \(l,l+1,\cdots,r\) 的概率
如果建筑没有倒塌,那么上下两行一定是有交集的。
这个交集并不是很好枚举,但是没交集的部分还是比较好枚举的
设 \(F_{i,r}\) 表示 \(k\) 天后,从下往上数第 \(i\) 行最右边的砖块为第 \(r\) 块的概率,则 \[ F_{i,r} = \sum_{l=1}^{r} f_{i,l,r} \] 设 \(\mathtt{sr}_{i,j}\) 表示 \(k\) 天后,从下往上数第 \(i\) 行最右边的砖块为第 \(j\) 块或更靠左的概率,则 \[ \mathtt{sr}_{i,j} = \sum_{r=1}^{j}F_{i,r} = \sum_{r=1}^{j}\sum_{l=1}^{r} f_{i,l,r} = \sum_{1 \le l\le r \le j} f_{i,l,r} \] 设 \(\mathtt{sl}_{i,j}\) 表示 \(k\) 天后,从下往上数第 \(i\) 行最左边的砖块为第 \(j\) 块或更靠右的概率,则 \[ \mathtt{sl}_{i,j} = \sum_{j \le l \le r \le m} f_{i,l,r} \] 这个可以由对称性 \(\mathtt{sr}_{i,m-j+1}\) 的值推得。
故转移方程为 \[ \begin{aligned} f_{i,l,r} &= \mathtt{pl}_l \times \mathtt{pr}_{r} \times \sum_{[j,k] \cap [l,r] \ne \varnothing} f_{i-1,j,k} \\\\&=\mathtt{pl}_l\times \mathtt{pr}_r \times \left(\sum_{1 \le j \le k \le m}f_{i-1,j,k} - \sum_{1 \le j \le k < l} f_{i-1,j,k}-\sum_{r<j\le k \le m} f_{i-1,j,k} \right) \\\\&=\mathtt{pl}_l\times \mathtt{pr}_r \times (\mathtt{sr}_{i-1,m} - \mathtt{sr}_{i-1,l-1} -\mathtt{sr}_{i-1,r+1}) \end{aligned} \] 于是 \[ \begin{aligned} F_{i,r} &= \sum_{l=1}^{r} f_{i,l,r} \\&=\sum_{l=1}^{r}{\left(\mathtt{pl}_l\times \mathtt{pr}_r \times (\mathtt{sr}_{i-1,m} - \mathtt{sr}_{i-1,l-1} -\mathtt{sr}_{i-1,r+1})\right)} \\&= \mathtt{pr}_r \times \left(\left(\sum_{l=1}^{r} \mathtt{pl}_l\right) \times (\mathtt{sr}_{i-1,m}-\mathtt{sl}_{i-1,r+1}) - \sum_{l=1}^{r} (\mathtt{pl}_l \times \mathtt{sr}_{i-1,l-1})\right) \end{aligned} \] 设 \(\mathtt{s1}_i = \sum_{k=1}^{i} \mathtt{pl}_k,~\mathtt{s2}_i = \sum_{k=1}^{i} \mathtt{pl}_k \times \mathtt{sr}_{i-1,k-1}\) ,则 \[ F_{i,r} = \mathtt{pr}_r \times (\mathtt{s1}_r \times (\mathtt{sr}_{i-1,m} - \mathtt{sl}_{i-1,r+1})-\mathtt{s2}_r) \] 然后我们就能 \(O(nm)\) 求解啦!
最终的答案就是 \(\mathtt{sr}_{n,m}\)
呼,总算写完了,累死了(
代码:
#include <iostream>
#include <string>
#include <vector>
#include <algorithm>
#include <cstdio>
#include <cstdlib>
#include <cstring>
#include <cstdarg>
#include <cmath>
#include <iomanip>
#include <random>
using namespace std;
#define int long long
#define INF 0x3f3f3f3f3f3f3f3f
#define N (int)(2e3+15)
#define K (int)(1e5+15)
const int mod=1e9+7;
int n,m,i,j,k,p,q;
int P[K],Q[K],C[N],inv[N]; // p^i, q^i, C(k,i), i^-1
int pl[N],pr[N],s1[N];
int F[N][N],sl[N][N],sr[N][N],s2[N];
int qpow(int a,int b)
{
int ans=1,base=a%mod;
for(; b; b>>=1)
{
if(b&1) ans=ans*base%mod;
base=base*base%mod;
}
return ans;
}
int mul(int cnt, ...)
{
va_list ptr; va_start(ptr,cnt);
int res=1;
for(int i=0; i<cnt; i++)
res=res*va_arg(ptr,int)%mod;
va_end(ptr);
return res;
}
signed main()
{
ios::sync_with_stdio(0);
cin.tie(0);cout.tie(0);
// freopen("check.in","r",stdin);
// freopen("check.out","w",stdout);
cin >> n >> m >> i >> j >> k;
q=i*qpow(j,mod-2)%mod;
p=((1-q+mod)%mod+mod)%mod;
inv[1]=C[0]=P[0]=Q[0]=1; C[1]=k;
for(i=1; i<=k; i++) {P[i]=P[i-1]*p%mod; Q[i]=Q[i-1]*q%mod;}
for(i=2; i<=m; i++)
{
inv[i]=(mod-mod/i) * inv[mod%i] % mod;
C[i]=mul(3, C[i-1], (k-i+1), inv[i]);
}
for(i=0; i<=k && i<m; i++) pr[m-i]=pl[i+1]=mul(3, C[i], Q[i], P[k-i]);
for(i=1; i<=m; i++) s1[i]=(s1[i-1]+pl[i]) % mod;
sr[0][m]=1;
for(i=1; i<=n; i++)
{
for(j=1; j<=m; j++)
s2[j] = (s2[j-1] + pl[j] * sr[i-1][j-1]) % mod;
for(j=1; j<=m; j++)
{
F[i][j] = pr[j] * ((sr[i-1][m] - sl[i-1][j+1]) * s1[j] %mod - s2[j]) % mod;
F[i][j] < 0 ? F[i][j] += mod : 0;
}
for(j=1; j<=m; j++)
{
sr[i][j] = (sr[i][j-1]+F[i][j]) % mod;
sl[i][m-j+1]=sr[i][j];
}
}
cout << sr[n][m] << '\n';
return 0;
}