洛谷P5405 [CTS2019] 氪金手游 题解
题目链接:P5405 [CTS2019] 氪金手游
题意:
小刘同学是一个喜欢氪金手游的男孩子。
他最近迷上了一个新游戏,游戏的内容就是不断地抽卡。现在已知:
- 卡池里总共有 $N$ 种卡,第 $i$ 种卡有一个权值 $W_i$,小刘同学不知道 $W_i$ 具体的值是什么。但是他通过和网友交流,他了解到 $W_i$ 服从一个分布。
- 具体地,对每个 $i$,小刘了解到三个参数 $p_{i,1},p_{i,2},p_{i,3}$,$W_i$ 将会以 $p_{i,j}$ 的概率取值为 $j$,保证 $p_{i,1}+p_{i,2}+p_{i,3}=1$。
小刘开始玩游戏了,他每次会氪一元钱来抽一张卡,其中抽到卡 $i$ 的概率为:$\frac{W_i}{\sum_j W_j}$
小刘会不停地抽卡,直到他手里集齐了全部 $N$ 种卡。
抽卡结束之后,服务器记录下来了小刘第一次得到每张卡的时间 $T_i$。游戏公司在这里设置了一个彩蛋:公司准备了 $N-1$ 个二元组 $(u_i,v_i)$,如果对任意的 $i$,成立 $T_{u_i}<T_{v_i}$,那么游戏公司就会认为小刘是极其幸运的,从而送给他一个橱柜的手办作为幸运大奖。
游戏公司为了降低获奖概率,它准备的这些 $(u_i,v_i)$ 满足这样一个性质:
对于任意的 $\varnothing\ne S\subsetneq\{1,2,\ldots,N\}$,总能找到 $(u_i,v_i)$ 满足:$u_i\in S,v_i\notin S$ 或者 $u_i\notin S,v_i\in S$。
请你求出小刘同学能够得到幸运大奖的概率,可以保证结果是一个有理数,请输出它对 $998244353$ 取模的结果。
输入格式:
第一行一个整数 $N$,表示卡的种类数。
接下来 $N$ 行,每行三个整数 $a_{i,1},a_{i,2},a_{i,3}$,而题目给出的 $p_{i,j}=\frac{a_{i,j}}{a_{i,1}+a_{i,2}+a_{i,3}}$。
接下来 $N-1$ 行,每行两个整数 $u_i,v_i$,描述一个二元组(意义见题目描述)。
输出格式:
输出一行一个整数,表示所求概率对 $998244353$ 取模的结果。
数据范围:
对于全部的测试数据,保证 $N\le 1000$,$a_{i,j}\le 10^6$。
题目给的奇怪性质,就是说 $(u_i,v_i)$ 的基图(转成无向边后的图)是一棵树。
首先考虑最简单的情况,即这张图是一个根为 $1$ 的外向树。
考虑 $u$ 所在子树,我们要求 $T_u$ 小于其子树中任何的 $T$ 。
假设 $W$ 已经确定,我们枚举 $t=T_u$ ,即可得到 $u$ 合法的概率
其中 $\mathrm{T}(u)$ 表示 $u$ 所在子树。因为结果只和 $u$ 的子树有关,那么就可以 dp 了
设 $f(u,s)$ 表示 $u$ 所在子树 $\sum_{i \in \mathrm{T}(u)} W_i=s$ 时的中奖概率。
这是一个树上背包,复杂度 $\mathcal{O}(n^2)$ 。然后不要忘记乘上 $W$ 每种情况的贡献。
那么原题中的内向边怎么办呢?考虑容斥,用不考虑当前边的情况减去外向边的情况。
怎么算不考虑当前边呢?那就是转移的时候任何合法的情况都能算上
时间复杂度 $\mathcal{O}(n^2)$
代码:
#include <bits/stdc++.h>
using namespace std;
#define int long long
#define INF 0x3f3f3f3f3f3f3f3f
const int mod = 998244353;
void up(int &x,int y) { x < y ? x = y : 0; }
void down(int &x,int y) { x > y ? x = y : 0; }
void add(int &x, int y) { (x += y) >= mod ? x -= mod : 0; }
#define N ((int)(1e3 + 15))
vector< pair<int, bool> > vec[N];
int n, a[N][3], inv[N * 3], f[N][N * 3], g[N * 3], s[N];
int qpow(int a, int b)
{
int r = 1;
while(b)
{
if(b & 1) r = r * a % mod;
b >>= 1; a = a * a % mod;
}
return r;
}
void dfs(int u, int fa)
{
f[u][0] = 1;
for(auto [v, t] : vec[u]) if(v != fa)
{
dfs(v, u);
for(int i = 0; i <= s[u] * 3; i++)
for(int j = 0; j <= s[v] * 3; j++)
{
int p = f[u][i] * f[v][j] % mod;
if(t) { add(g[i], p); add(g[i + j], mod - p); }
else { add(g[i + j], p); }
}
s[u] += s[v];
copy_n(g, s[u] * 3 + 1, f[u]);
fill_n(g, s[u] * 3 + 1, 0ll);
}
for(int i = s[u] * 3; ~i; f[u][i] = 0, i--)
for(int j = 1; j <= 3; j++)
add(f[u][i + j], f[u][i] * a[u][j - 1] % mod * j * inv[i + j] % mod);
++s[u];
}
signed main()
{
ios::sync_with_stdio(0);
cin.tie(0);cout.tie(0);
// freopen("check.in","r",stdin);
// freopen("check.out","w",stdout);
int n; cin >> n;
for(int i = 1; i <= n; i++)
{
cin >> a[i][0] >> a[i][1] >> a[i][2];
int _inv = qpow(a[i][0] + a[i][1] + a[i][2], mod - 2);
a[i][0] = a[i][0] * _inv % mod;
a[i][1] = a[i][1] * _inv % mod;
a[i][2] = a[i][2] * _inv % mod;
}
inv[1] = 1;
for(int i = 2; i <= n * 3; i++)
inv[i] = (mod - mod / i) * inv[mod % i] % mod;
for(int i = 1, u, v; i < n; i++)
{
cin >> u >> v;
vec[u].push_back({v, 0});
vec[v].push_back({u, 1});
}
dfs(1, 0); int res = 0;
for(int i = 1; i <= n * 3; i++) add(res, f[1][i]);
cout << res << '\n';
return 0;
}
参考文献:
[1] https://www.luogu.com.cn/article/kek3z48i
[2] https://www.luogu.com.cn/article/5h1eol46
题外话:
放张好看的图