模拟赛题讲解[22]
来自 yukuai26 2022-10-07 noi.ac #2851
题目描述:
小明有一个 $3\times n$ 的方格, 每一格可能是 o
或者 x
小明想要把所有 x
变成 o
他每次会选择一个至少满足以下条件之之一的 x
- 与他水平相邻的格子中有 $2$ 个是
o
- 与他垂直相邻的格子中有 $2$ 个是
o
之后他会把这个 x
变成 o
他想要知道把所有 x
变成 o
的方案数(如果变化 x
的位置的顺序不同,则是不同的方案)
输入格式:
第一行一个正整数 $n$
接下来 $3$ 行,每行一个长度为 $n$ 的字符串
输出格式:
一个整数表示答案。答案对 $10^9+7$ 取模
样例输入 1:
3
oxo
xxo
oxo
样例输出 1:
14
样例输入 2:
10
ooxooxoxoo
xooxxxoxxx
oxoxoooooo
样例输出 2:
149022720
样例输入 3:
20
oxooxoxooxoxooxoxoxo
oxxxoxoxxxooxxxxxoox
oxooxoxooxooxooxoxoo
样例输出 3:
228518545
数据范围:
数据保证没有无解的情况。
对于 $100\%$ 的数据,满足 $1 \leq n \leq 2000$ 。
题解:
毒瘤dp题。
下文使用 $a \uparrow b+c$ 表示 a += b+c
。
设 $f_{i,j,k}$ 表示在前 $i$ 列的所有格子中,$(2,i)$ 是第 $j$ 个被删掉的,
且 $(2,i)$ 是否在 $(2,i+1)$ 之前被删除( $k=0/1$ 表示 $(2,i)$ 在前/后被删除)时的方案数。
考虑转移。我们用刷表法来递推,则前 $i$ 列所有 x
点删除顺序的序列我们已经知道了。
我们钦定 $(2,i+1)$ 是第 $j$ 个被删除的,然后在原本的删除序列中插入第 $i+1$ 列的 $3$ 个元素。
然后讨论 $(2,i+1)$ 是垂直删除的还是水平删除的。记前 $i$ 列有 $s$ 个 x
点,$w_i$ 表示 $(1,i)$ 和 $(3,i)$ 有几个 x
。
但是可能存在「可以同时被两者之一删除」的情况,为了防止重复计数,不妨钦定垂直删除优先。
为了方便转移,设 $g_{x,k}$ 表示将 $k$ 个数插入原来长为 $x$ 的序列的方案数,则有
注:考虑最后的序列为 $x+k$ 个数,其中 $x$ 个是不能改变顺序的,因此 $\binom{x+k}{x}$ 。
然后那 $k$ 个数是可以改变顺序放入剩下 $k$ 个空位的,即 $k!$ 。乘起来就是这个。
如果是垂直删除,那么它上下两个格子都要在它前面删除,则
如果是水平删除,那么它左右两个格子要在它前面删除。
同时,上下两个格子至少要有一个在它后面被删除,否则我们会优先垂直删除它。
对于这种情况,显然 $f_{i+1,j,0}$ 是没有任何增量的。
特别地,我们不考虑 $f_{i,t,1}$ 到 $f_{i+1,j,1}$ 的转移。
因为 $k=1$ 就是为了能够让 $i+1$ 被水平删除才设置的
正常情况下我们应当只转移 $\mathtt{ord}(i) < \mathtt{ord}(i+1) > \mathtt{ord}(i+2)$ 的到 $f_{i+1,j,1}$ 即它被水平删除
但是这种 $1$ 到 $1$ 的转移,会出现 $\mathtt{ord}(i) > \mathtt{ord}(i+1)> \mathtt{ord}(i+2)$ ,这种情况是无法水平删除的
总结一下就是说,$\mathtt{ord}(i)>\mathtt{ord}(i+1)$ 的不能水平删除,但是会被错误地算进去(
我们对于每个 $(2,l)$ 到 $(2,r)$ 均为 x
的子段做一遍dp。
注意 $r=n$ 的时候,不能把 $k=1$ 的答案算进去。因为我们强制要求 $n$ 比 $n+1$ 先删除(显然)
对于相邻的两个子段,考虑合并,设左侧的答案为 $r_1$ ,x
的个数为 $s_1$ ;右侧的答案为 $r_2$ ,x
的1个数为 $s_2$ ,则
后面那个系数是因为段与段之间相互独立,用类似于 $g_{x,k}$ 的方法将左边序列不改变其顺序地放入空位。
这样做的时间复杂度是 $\mathcal{O}(n^3)$ 的。注意到转移方程中的 $\sum f$ 长的就很前缀和优化,随便搞一搞就好了。
但是这只是dp部分的答案,我们只dp了 $(2,i)$ 为 x
的列,其他的列的答案也要合并起来。
时间复杂度 $\mathcal{O}(n^2)$
代码:
#include <bits/stdc++.h>
using namespace std;
#define int long long
#define INF 0x3f3f3f3f3f3f3f3f
#define DDD cout << "This is Line " << __LINE__
const int mod = 1e9 + 7;
void up(int &x,int y) { x < y ? x = y : 0; }
void down(int &x,int y) { x > y ? x = y : 0; }
void add(int &x,int y) { x = (x + (y % mod)) % mod; }
typedef pair<int,int> pii;
#define Fi first
#define Se second
#define now (pre ^ 1)
#define N ((int)(6e3+65))
char a[4][N];
bitset<N> is;
int n,tmp,w[N],fac[N],invf[N],f[2][N][2];
int mul(int cnt, ...)
{
va_list ptr; va_start(ptr,cnt); int res = 1;
for(int i=0; i<cnt; i++) res = res * va_arg(ptr,int) % mod;
va_end(ptr); return res;
}
int qpow(int a,int b)
{
int ans = 1, base = a % mod;
for(; b; b >>= 1)
{if(b & 1) ans = ans * base % mod; base = base * base % mod; }
return ans;
}
void init(int n)
{
fac[0]=1;
for(int i=1; i<=n; i++) fac[i] = fac[i-1] * i % mod;
invf[n] = qpow(fac[n], mod-2);
for(int i=n; i; i--) invf[i-1] = invf[i] * i % mod;
}
int C(int n,int m)
{
if(n <= m) return n==m;
return mul(3, fac[n], invf[m], invf[n-m]);
}
pii merge(pii a, pii b)
{ return {mul(3, a.Fi, b.Fi, C(a.Se + b.Se, a.Se)), a.Se + b.Se}; }
int put(int x,int k)
{
int res = 1;
for(int i=1; i<=k; i++) res = res * (x + i) % mod;
return res;
}
pii dp(int l,int r)
{
memset(f,0,sizeof(f)); f[0][0][0] = 1;
int pre = 0, sum = 0;
for(int i=l,rp; i<=r; i++)
{
memset(f[now], 0, sizeof(f[now]));
rp = 0;
for(int j=1; j <= sum+1; j++)
{
add(rp, f[pre][j-1][0]);
for(int k=0; k < w[i]; k++)
{
tmp=mul(4, rp, put(j-1,k), put(sum+1-j, w[i]-k), C(w[i],k));
add(f[now][j + k][1], tmp);
}
}
rp = 0;
for(int j=sum; j; j--)
{
add(rp, f[pre][j][1]);
tmp = mul(2, rp, put(j-1, w[i]));
add(f[now][j + w[i]][0], tmp);
}
rp = 0;
for(int j=0; j<=sum; j++) add(rp, f[pre][j][0]);
for(int j=1; j <= sum + 1; j++)
{
tmp = mul(2, rp, put(j-1, w[i]));
add(f[now][j + w[i]][0], tmp);
}
pre ^= 1; sum += w[i] + 1;
}
int res = 0;
for(int i=1; i<=sum; i++) add(res, f[pre][i][0]);
if(r != n) for(int i=1; i<=sum; i++) add(res, f[pre][i][1]);
return {res, sum};
}
int getAns()
{
pii ans = {1,0};
for(int i=1,j; i<=n; i++)
if(is[i])
{
for(j=i; j<n && is[j+1]; ++j);
ans = merge(ans, dp(i,j)); i = j;
// cout << ans.Fi << '\n';
}
return ans.Fi;
}
signed main()
{
ios::sync_with_stdio(0);
cin.tie(0);cout.tie(0);
// freopen("check.in","r",stdin);
// freopen("check.out","w",stdout);
init(6010); cin >> n;
for(int i=1; i<=3; i++) cin >> (a[i]+1);
bool ok=1;
for(int i=1; i<=3; i+=2)
{
if(a[i][1] == 'x' || a[i][n] == 'x') ok = 0;
for(int j=2; j<=n; j++)
if(a[i][j] == 'x' && a[i][j-1] == 'x') ok = 0;
}
if(!ok) return cout << "0\n",0;
for(int i=1; i<=n; i++)
{
is[i] = (a[2][i] == 'x');
w[i] = ((a[1][i] == 'x') + (a[3][i] == 'x'));
}
int dt = 0, ext = 0;
for(int i=1; i<=n; i++) is[i] ? (dt += w[i]+1) : (ext += w[i]);
int res = getAns();
for(int i=1; i<=ext; i++) res = res * (dt + i) % mod;
cout << res << '\n';
return 0;
}
感谢 yukuai26 老师的耐心指导 Orz