[ARC061F] Card Game for Three 题解
题目链接:[ARC061F] Card Game for Three
题意:
有三堆牌, 分别有 $n_1, n_2, n_3$ 张。牌上写着数字 $1,2,3$ 中的一个。
先从牌堆 $1$ 中抽一张,接下来,牌上写着几就从几号牌堆抽取。
求在所有可能的 $3^{n_1+n_2+n_3}$ 种方案中,先把牌堆 $1$ 抽空的方案数。
答案对 $10^9+7$ 取模。 $n_1, n_2, n_3 \leq 3 \times 10^5$,时限 $3 \mathrm{~s}$ 。
计数好题,是我怎么看都不可能做出来的题。不过至少我现在会了。
条件计数的题目有两种解决办法:要么容斥,要么寻找更简洁的充要条件。
这道题看上去就不像容斥,因此考虑寻找充要条件,将原题的计数转化为更简单的题目的计数。
把抽出来的牌排成一个序列,显然每种放置方式都恰好对应一个序列。【构造映射】
注意到牌可能拿不完,因此一个序列可能对应多种方案。
具体地,一个长度为 $m$ 的序列对应 $3^{n_1 + n_2 + n_3 - m}$ 种方案。【检查反映射】
可以发现,该序列仅有的约束,要求率先将堆 $1$ 拿空。【检查充要条件】
于是,问题就变成了:对每个长度,求先将堆 $1$ 拿空的序列的个数。
因为操作序列中一定恰有 $n_1$ 个 $1$ ,且最后一个必须是 $1$ ,我们枚举抽出的非 $1$ 牌个数 $k$ ,则方案数为
其中,$\binom{k + n_1 - 1}{k}$ 表示 $n_1-1$ 个自由的 $1$ 与非 $1$ 数混合的方案数,后面的求和是瓜分非 $1$ 牌的方案数。
然而式子中的后半部分并不能快速求解,不妨考虑递推求解
将组合数裂开(非常重要的一步):
注意式子中所有不合法的组合数均须定义为 $0$ 。
预处理出所有 $S(k)$ 后,答案即为下式:
时间复杂度 $\mathcal{O}(N)$
代码:
#include <bits/stdc++.h>
using namespace std;
#define int long long
#define INF 0x3f3f3f3f3f3f3f3f
const int mod = 1e9 + 7;
void up(int &x,int y) { x < y ? x = y : 0; }
void down(int &x,int y) { x > y ? x = y : 0; }
void add(int &x,int y) { (x += y) >= mod ? x -= mod : 0; }
#define N ((int)(1e6 + 15))
int fac[N],inv[N];
int qpow(int a,int b)
{
int r = 1;
while(b) {
if(b & 1) r = r * a % mod;
b >>= 1; a = a * a % mod;
}
return r;
}
int C(int n,int m)
{
if(m < 0 || n < m) return 0;
return fac[n] * inv[m] % mod * inv[n - m] % mod;
}
void init(int n)
{
fac[0] = 1;
for(int i = 1; i <= n; i++) fac[i] = fac[i - 1] * i % mod;
inv[n] = qpow(fac[n], mod - 2);
for(int i = n - 1; ~i; i--) inv[i] = inv[i + 1] * (i + 1) % mod;
}
void initS(int n2,int n3,int *S)
{
S[0] = 1;
for(int k = 1; k <= n2 + n3; k++)
S[k] = ((2 * S[k - 1] - C(k - 1, k - 1 - n3) - C(k - 1, n2)) % mod + mod) % mod;
}
int n1,n2,n3,S[N];
signed main()
{
ios::sync_with_stdio(0);
cin.tie(0);cout.tie(0);
// freopen("check.in","r",stdin);
// freopen("check.out","w",stdout);
cin >> n1 >> n2 >> n3; init(n1 + n2 + n3); initS(n2, n3, S);
int res = 0, x = qpow(3, n2 + n3), y = qpow(3, mod - 2);
for(int k = 0; k <= n2 + n3; x = x * y % mod, k++)
add(res, x * C(n1 + k - 1, k) % mod * S[k] % mod);
cout << res << '\n';
return 0;
}
参考文献:
[1] https://www.luogu.com.cn/blog/command-block/solution-at2070