嘘~ 正在从服务器偷取页面 . . .

模拟赛题讲解[22]


模拟赛题讲解[22]

来自 yukuai26 2022-10-07 noi.ac #2851

题目描述

小明有一个 $3\times n$ 的方格, 每一格可能是 o 或者 x

小明想要把所有 x 变成 o

他每次会选择一个至少满足以下条件之之一的 x

  • 与他水平相邻的格子中有 $2$ 个是 o
  • 与他垂直相邻的格子中有 $2$ 个是 o

之后他会把这个 x 变成 o

他想要知道把所有 x 变成 o 的方案数(如果变化 x 的位置的顺序不同,则是不同的方案)

输入格式

第一行一个正整数 $n$

接下来 $3$ 行,每行一个长度为 $n$ 的字符串

输出格式

一个整数表示答案。答案对 $10^9+7$ 取模

样例输入 1

3
oxo
xxo
oxo

样例输出 1

14

样例输入 2

10
ooxooxoxoo
xooxxxoxxx
oxoxoooooo

样例输出 2

149022720

样例输入 3

20
oxooxoxooxoxooxoxoxo
oxxxoxoxxxooxxxxxoox
oxooxoxooxooxooxoxoo

样例输出 3

228518545

数据范围

数据保证没有无解的情况。

对于 $100\%$ 的数据,满足 $1 \leq n \leq 2000$ 。


题解

毒瘤dp题。

下文使用 $a \uparrow b+c$ 表示 a += b+c

设 $f_{i,j,k}$ 表示在前 $i$ 列的所有格子中,$(2,i)$ 是第 $j$ 个被删掉的,

且 $(2,i)$ 是否在 $(2,i+1)$ 之前被删除( $k=0/1$ 表示 $(2,i)$ 在前/后被删除)时的方案数。

考虑转移。我们用刷表法来递推,则前 $i$ 列所有 x 点删除顺序的序列我们已经知道了。

我们钦定 $(2,i+1)$ 是第 $j$ 个被删除的,然后在原本的删除序列中插入第 $i+1$ 列的 $3$ 个元素。

然后讨论 $(2,i+1)$ 是垂直删除的还是水平删除的。记前 $i$ 列有 $s$ 个 x 点,$w_i$ 表示 $(1,i)$ 和 $(3,i)$ 有几个 x

但是可能存在「可以同时被两者之一删除」的情况,为了防止重复计数,不妨钦定垂直删除优先。

为了方便转移,设 $g_{x,k}$ 表示将 $k$ 个数插入原来长为 $x$ 的序列的方案数,则有

注:考虑最后的序列为 $x+k$ 个数,其中 $x$ 个是不能改变顺序的,因此 $\binom{x+k}{x}$ 。

然后那 $k$ 个数是可以改变顺序放入剩下 $k$ 个空位的,即 $k!$ 。乘起来就是这个。

  • 如果是垂直删除,那么它上下两个格子都要在它前面删除,则

  • 如果是水平删除,那么它左右两个格子要在它前面删除。

    同时,上下两个格子至少要有一个在它后面被删除,否则我们会优先垂直删除它。

    对于这种情况,显然 $f_{i+1,j,0}$ 是没有任何增量的。

    特别地,我们不考虑 $f_{i,t,1}$ 到 $f_{i+1,j,1}$ 的转移。

    因为 $k=1$ 就是为了能够让 $i+1$ 被水平删除才设置的

    正常情况下我们应当只转移 $\mathtt{ord}(i) < \mathtt{ord}(i+1) > \mathtt{ord}(i+2)$ 的到 $f_{i+1,j,1}$ 即它被水平删除

    但是这种 $1$ 到 $1$ 的转移,会出现 $\mathtt{ord}(i) > \mathtt{ord}(i+1)> \mathtt{ord}(i+2)$ ,这种情况是无法水平删除的

    总结一下就是说,$\mathtt{ord}(i)>\mathtt{ord}(i+1)$ 的不能水平删除,但是会被错误地算进去(

我们对于每个 $(2,l)$ 到 $(2,r)$ 均为 x 的子段做一遍dp。

注意 $r=n$ 的时候,不能把 $k=1$ 的答案算进去。因为我们强制要求 $n$ 比 $n+1$ 先删除(显然)

对于相邻的两个子段,考虑合并,设左侧的答案为 $r_1$ ,x 的个数为 $s_1$ ;右侧的答案为 $r_2$ ,x 的1个数为 $s_2$ ,则

后面那个系数是因为段与段之间相互独立,用类似于 $g_{x,k}$ 的方法将左边序列不改变其顺序地放入空位。

这样做的时间复杂度是 $\mathcal{O}(n^3)$ 的。注意到转移方程中的 $\sum f$ 长的就很前缀和优化,随便搞一搞就好了。

但是这只是dp部分的答案,我们只dp了 $(2,i)$ 为 x 的列,其他的列的答案也要合并起来。

时间复杂度 $\mathcal{O}(n^2)$

代码:

#include <bits/stdc++.h>
using namespace std;
#define int long long
#define INF 0x3f3f3f3f3f3f3f3f
#define DDD cout << "This is Line " << __LINE__
const int mod = 1e9 + 7;
void up(int &x,int y) { x < y ? x = y : 0; }
void down(int &x,int y) { x > y ? x = y : 0; }
void add(int &x,int y) { x = (x + (y % mod)) % mod;  }
typedef pair<int,int> pii;
#define Fi first
#define Se second
#define now (pre ^ 1)
#define N ((int)(6e3+65))

char a[4][N];
bitset<N> is;
int n,tmp,w[N],fac[N],invf[N],f[2][N][2];
int mul(int cnt, ...)
{
    va_list ptr; va_start(ptr,cnt); int res = 1;
    for(int i=0; i<cnt; i++) res = res * va_arg(ptr,int) % mod;
    va_end(ptr); return res;
}
int qpow(int a,int b)
{
    int ans = 1, base = a % mod;
    for(; b; b >>= 1)
    {if(b & 1) ans = ans * base % mod; base = base * base % mod; }
    return ans;
}
void init(int n)
{
    fac[0]=1;
    for(int i=1; i<=n; i++) fac[i] = fac[i-1] * i % mod;
    invf[n] = qpow(fac[n], mod-2);
    for(int i=n; i; i--) invf[i-1] = invf[i] * i % mod;
}
int C(int n,int m)
{
    if(n <= m) return n==m;
    return mul(3, fac[n], invf[m], invf[n-m]);
}
pii merge(pii a, pii b)
{ return {mul(3, a.Fi, b.Fi, C(a.Se + b.Se, a.Se)), a.Se + b.Se}; }
int put(int x,int k)
{
    int res = 1;
    for(int i=1; i<=k; i++) res = res * (x + i) % mod;
    return res;
}
pii dp(int l,int r)
{
    memset(f,0,sizeof(f)); f[0][0][0] = 1;
    int pre = 0, sum = 0;
    for(int i=l,rp; i<=r; i++)
    {
        memset(f[now], 0, sizeof(f[now]));
        rp = 0;
        for(int j=1; j <= sum+1; j++)
        {
            add(rp, f[pre][j-1][0]);
            for(int k=0; k < w[i]; k++)
            {
                tmp=mul(4, rp, put(j-1,k), put(sum+1-j, w[i]-k), C(w[i],k));
                add(f[now][j + k][1], tmp);
            }
        }
        rp = 0;
        for(int j=sum; j; j--)
        {
            add(rp, f[pre][j][1]);
            tmp = mul(2, rp, put(j-1, w[i]));
            add(f[now][j + w[i]][0], tmp);
        }
        rp = 0;
        for(int j=0; j<=sum; j++) add(rp, f[pre][j][0]);
        for(int j=1; j <= sum + 1; j++)
        {
            tmp = mul(2, rp, put(j-1, w[i]));
            add(f[now][j + w[i]][0], tmp);
        }
        pre ^= 1; sum += w[i] + 1;
    }
    int res = 0;
    for(int i=1; i<=sum; i++) add(res, f[pre][i][0]);
    if(r != n) for(int i=1; i<=sum; i++) add(res, f[pre][i][1]);
    return {res, sum};
}
int getAns()
{
    pii ans = {1,0};
    for(int i=1,j; i<=n; i++)
        if(is[i])
        {
            for(j=i; j<n && is[j+1]; ++j);
            ans = merge(ans, dp(i,j)); i = j;
            // cout << ans.Fi << '\n';
        }
    return ans.Fi;
}
signed main()
{
    ios::sync_with_stdio(0);
    cin.tie(0);cout.tie(0);
    // freopen("check.in","r",stdin);
    // freopen("check.out","w",stdout);
    init(6010); cin >> n;
    for(int i=1; i<=3; i++) cin >> (a[i]+1);
    bool ok=1;
    for(int i=1; i<=3; i+=2)
    {
        if(a[i][1] == 'x' || a[i][n] == 'x') ok = 0;
        for(int j=2; j<=n; j++)
            if(a[i][j] == 'x' && a[i][j-1] == 'x') ok = 0;
    }
    if(!ok) return cout << "0\n",0;
    for(int i=1; i<=n; i++)
    {
        is[i] = (a[2][i] == 'x');
        w[i] = ((a[1][i] == 'x') + (a[3][i] == 'x'));
    }
    int dt = 0, ext = 0;
    for(int i=1; i<=n; i++) is[i] ? (dt += w[i]+1) : (ext += w[i]);
    int res = getAns();
    for(int i=1; i<=ext; i++) res = res * (dt + i) % mod;
    cout << res << '\n';
    return 0;
}

感谢 yukuai26 老师的耐心指导 Orz


文章作者: q779
版权声明: 本博客所有文章除特别声明外,均采用 CC BY-NC-ND 4.0 许可协议。转载请注明来源 q779 !
评论
  目录