模拟赛题讲解[9]
来自 Roundgod 2022-07-29 noi.ac #2693
原题来自 ABC172E NEQ
问题描述:
给定 $n,m$,你需要计算满足以下条件的数组对 $A=[a_1,a_2,\dots,a_n]$ 和 $B=[b_1,b_2,\dots,b_n]$ 的个数:
- 对于所有 $1\leq i\leq n$ ,都有 $1\leq a_i\leq m$ 以及 $1\leq b_i\leq m$.
- 对于所有 $1\leq i\leq n$ ,都有 $a_i\neq b_i$.
- 对于所有的 $1\leq i\lt j\leq n$ ,都有 $a_i\neq a_j$ 且 $b_i\neq b_j$.
由于答案可能过大,你需要输出答案对 $10^9+7$取模后的值。
输入格式:
输入第一行包含一个整数 $t(1\leq t\leq 10)$ ,表示数据的组数。 对于每组测试数据,输入为一行,包含两个整数 $n,m$。
输出格式:
对于每组测试数据,在一行中输出一个整数,表示答案。
输入1:
2 2
2 3
1 1
10 10
114514 1919810
输出1
2
18
0
306442892
145678131
样例说明:
对于样例的第一组测试数据, 合法数组对有以下 $2$ 种:
数据范围:
对于 $20\%$ 的数据,$1\leq n\leq m\leq 10$
对于 $100\%$ 的数据,$1\leq n\leq m\leq 2\times 10^6$
题解:
显然的容斥题。设性质 $p_i:a_i=b_i,S \subseteq [m]$ (这里 $[m]=\{1,2,\dots,m\}$)
则我们可以用这个经典公式来计算题目要求的 $a_i \ne b_i$ (注意公式中是 $|S|$ 而不是 $|S|+1$ )
考虑这里的 $N(S)$ 如何计算。不难发现
中括号前面的指对 $S$ 指定的这 $i$ 个 $a_j=b_j$ 染色,后者就是对 $a_j \ne b_j$ 的染色
因为 $a,b$ 此时染色相互独立,所以直接乘法原理(就是那个平方)
你敢信我模拟赛这么水的题没想出来,主要还是 组合数 学的不扎实
所以答案就是
这里的 $i$ 就是在枚举 $|S|$ , $\binom{n}{i}$ 就是大小等于 $i$ 的 $S$ 的个数
然后预处理一下阶乘啊逆元啊什么的就好啦
时间复杂度 $\mathcal{O}(Qn)$
代码:
#include <iostream>
#include <string>
#include <vector>
#include <algorithm>
#include <cstdio>
#include <cstdlib>
#include <cstring>
#include <cmath>
#include <iomanip>
#include <random>
using namespace std;
#define int long long
#define INF 0x3f3f3f3f3f3f3f3f
#define N (int)(2e6+15)
const int p=1e9+7;
int n,m,fac[N],invf[N];
int qpow(int a,int b)
{
int ans=1,base=a%p;
while(b)
{
if(b&1) ans=ans*base%p;
base=base*base%p;
b>>=1;
}
return ans;
}
int inv(int x){return qpow(x,p-2);}
void init()
{
invf[0]=fac[0]=1;
for(int i=1; i<=N-5; i++)
fac[i]=fac[i-1]*i%p;
invf[N-5]=inv(fac[N-5]);
for(int i=N-5-1; i>=1; i--)
invf[i]=invf[i+1]*(i+1)%p;
}
int A(int n,int k)
{
if(n<k) return 0;
return fac[n]*invf[n-k]%p;
}
int C(int n,int k)
{
if(n<k) return 0;
return fac[n]*invf[k]%p*invf[n-k]%p;
}
void add(int &x,int y){x+=y;if(x>=p)x-=p;}
void dec(int &x,int y){x-=y;if(x<0)x+=p;}
signed main()
{
ios::sync_with_stdio(0);
cin.tie(0);cout.tie(0);
// freopen("check.in","r",stdin);
// freopen("check.out","w",stdout);
init(); int Q; cin >> Q;
while(Q--)
{
cin >> n >> m;
int res=0;
for(int i=0; i<=n; i++)
{
int tmp=C(n,i)%p*A(m,i)%p*A(m-i,n-i)%p*A(m-i,n-i)%p;
if(i&1) dec(res,tmp); else add(res,tmp);
}
cout << res << '\n';
}
return 0;
}