模拟赛题讲解[24]
来自 s_r_f 2023-08-02 noi.ac #3192
题目描述:
小 B 有一个数字 $x$,一开始是 $0$。
他会 $n$ 种加法魔法和 $m$ 种乘法魔法:第 $i$ 种加法魔法会让自己的数字 $x$ 增加 $a_i$ ,第 $i$ 种乘法魔法会让自己的数字 $x$ 变为原来的 $b_i$ 倍。所有的 $a_i$ 和 $b_i$ 都非负。
他想知道,每种魔法必须恰好使用一次的情况下,数字 $x$ 的最大值是多少,以及有多少种不同的顺序能让 $x$ 达到这个最大值。
不同编号的魔法,即使效果一样也被看做不同的。
因为答案可能非常大,所以你只需要输出答案对 $998244353$ 取模的结果。
提示: 在模 $998244353$ 意义下,对于非零 $x$,有 $x^{-1} \equiv x^{998244351} \pmod {998244353}$
输入格式:
第一行,两个整数 $n$ $m$。
第二行,$n$ 个整数 $a_1, a_2,\cdots, a_n$
第三行,$m$ 个整数 $b_1, b_2,\cdots, b_m$
输出格式:
输出两行。
第一行一个非负整数 $\mathrm{Max}$ ,表示 $x$ 的最大值 对 $998244353$ 取模的结果。
第二行一个非负整数 $\mathrm{Ans}$ ,表示顺序数量 对 $998244353$ 取模的结果。
样例输入1:
2 2
2 3
4 5
样例输出1:
100
4
样例输入2:
2 2
0 1
4 5
样例输出2:
20
8
样例输入3:
2 2
0 1
0 30
样例输出3:
30
4
数据范围:
子任务 1(50pts) : $n + m \leq 10,0 \leq a_i,b_i \leq 2$
子任务 2(30pts) $:$ $a_i \leq 1$
子任务 3(20pts) : $0 \leq n,m \leq 100000,0\leq a_i,b_i \leq 10^3$
题解:
可以看出我的组合数有多烂。
第一问的话,可以证明最终答案一定形如 $(x\times 0+a_i+\cdots)\times b_i\times \cdots$ ,注意乘 $0$ 要在一开始就乘掉。
考虑第二问。如果不存在 $a_i > 0$ ,则所有运算可以任意排序,答案为 $(n + m)!$ 。
如果存在 $a_i = 0$ ,则
- 只考虑 $b_i=0,~b_i > 1$ 和 $a_i > 0$ ,此时的顺序不能乱搞,因此此部分为他们各自的阶乘的乘积。
- 考虑加入 $b_i = 1$ 和 $a_i = 0$ 的运算,这些运算就可以放在任何位置。
因此答案为
其中 $c$ 分别代表 $b_i=1$ 和 $a_i=0$ 、$b_i=0$ 、$b_i > 1$ 、$a_i > 0$ 。
时间复杂度 $\mathcal{O}(n+m)$
代码:
#include <bits/stdc++.h>
using namespace std;
#define int long long
#define INF 0x3f3f3f3f3f3f3f3f
const int mod = 998244353;
void up(int &x,int y) { x < y ? x = y : 0; }
void down(int &x,int y) { x > y ? x = y : 0; }
void add(int &x, int y) { (x += y) >= mod ? x -= mod : 0; }
#define N ((int)(4e5 + 15))
int n,m,mx,cnt1,cnt2,p1,p2,a1[N],b1[N],a[N],b[N],fac[N << 2],inv[N << 2];
int qpow(int a,int b)
{
int c = 1;
for(; b; b >>= 1) {
if(b & 1) { c = c * a % mod; }
a = a * a % mod;
}
return c;
}
int C(int n,int m)
{
if(n < m) return 0;
return fac[n] * inv[m] % mod * inv[n - m] % mod;
}
signed main()
{
ios::sync_with_stdio(0);
cin.tie(0);cout.tie(0);
// freopen("check.in","r",stdin);
// freopen("check.out","w",stdout);
fac[0] = inv[0] = 1;
for(int i = 1; i <= N - 10; i++)
{
fac[i] = fac[i - 1] * i % mod;
inv[i] = qpow(fac[i], mod - 2);
}
cin >> n >> m;
for(int i = 1; i <= n; i++)
{
cin >> a[i]; add(mx, a[i]);
if(a[i] == 0) ++cnt1; else a1[++p1] = a[i];
}
for(int i = 1; i <= m; i++)
{
cin >> b[i];
if(b[i] == 0) ++cnt2;
else {
if(b[i] == 1) ++cnt1; else b1[++p2] = b[i];
mx = mx * b[i] % mod;
}
}
cout << mx << '\n';
if(!mx) { cout << fac[n + m] << '\n'; } else {
cout << fac[p1] * fac[p2] % mod * fac[cnt1] % mod * fac[cnt2] % mod * C(p1 + p2 + cnt1 + cnt2, cnt1) % mod << '\n';
}
return 0;
}