洛谷P4721 【模板】分治 FFT 题解
题目链接:P4721 【模板】分治 FFT
题意:
给定序列 $g_{1\dots n - 1}$,求序列 $f_{0\dots n - 1}$。
其中
边界为 $f_0=1$。
答案对 $998244353$ 取模。
输入格式:
第一行一个整数 $n$ 。
第二行 $n-1$ 个整数 $g_{1\dots n - 1}$。
输出格式:
一行 $n$ 个整数,表示 $f_{0\dots n - 1}$ 对 $998244353$ 取模后的值。
数据范围:
$2\leq n\leq 10^5$,$0\leq g_i<998244353$。
本题可以通过多项式求逆解决,不过这就体现不出「分治 FFT 模板题」的意义了。
分治 FFT/分治NTT 主要利用的是 CDQ 分治的思想。
先求出 $[l,\,\mathrm{mid}-1)$ 的答案(即 $f$ ),可以发现这部分对于 $[\mathrm{mid},r)$ 的贡献为
其中 $*$ 表示卷积,$f_{[l,r]}$ 表示由 $f_l,f_{l+1},\cdots,f_r$ 构成的序列,$g$ 同理。
那么我们只需要在分治的过程中计算左侧对右侧的贡献即可
时间复杂度 $\mathcal{O}(n \log^2 n)$ ,显然 $998244353$ 是 NTT 模数,可以直接 NTT 。
代码:
#include <bits/stdc++.h>
using namespace std;
#define int long long
#define INF 0x3f3f3f3f3f3f3f3f
void up(int &x, int y) { x < y ? x = y : 0; }
void down(int &x, int y) { x > y ? x = y : 0; }
#define rep(i, a, b) for(int i = (a), i##END = (b); i <= i##END; i++)
#define Rep(i, a, b) for(int i = (a), i##END = (b); i >= i##END; i--)
#define N ((int)(2e5 + 15))
namespace NTT
{
#define NTT_N ((N + N) * 2)
const int P = 998244353;
const int G = 3, Gi = 332748118;
void add(int &x, int y) { (x += y) >= P ? x -= P : 0; }
int qpow(int a, int b)
{
int r = 1;
for(a %= P; b; b >>= 1, a = 1ll * a * a % P)
if(b & 1) r = 1ll * r * a % P;
return r;
}
int l, limit, r[NTT_N], a[NTT_N], b[NTT_N];
void init(int len) // size(a) plus size(b)
{
for(limit = 1, l = 0; limit <= len; limit *= 2, ++l);
for(int i = 0; i < limit; i++)
r[i] = (r[i >> 1] >> 1) | ((i & 1) << (l - 1));
}
void NTT(int *A, int type)
{
for(int i = 0; i < limit; i++) if(i < r[i]) swap(A[i], A[r[i]]);
for(int mid = 1; mid < limit; mid *= 2)
{
int Wn = qpow((type == 1 ? G : Gi), (P - 1) / (mid * 2));
for(int j = 0; j < limit; j += (mid * 2))
{
int w = 1;
for(int k = 0; k < mid; k++, w = 1ll * w * Wn % P)
{
int x = A[j + k], y = 1ll * w * A[j + k + mid] % P;
A[j + k] = (x + y) % P; A[j + k + mid] = (x - y + P) % P;
}
}
}
if(type != 1)
{
const int Inv = qpow(limit, P - 2);
for(int i = 0; i < limit; i++) A[i] = 1ll * A[i] * Inv % P;
}
}
int* convolution(int *A, int n, int *B, int m)
{
rep(i, 0, limit - 1) a[i] = b[i] = 0;
rep(i, 0, n) a[i] = A[i]; rep(i, 0, m) b[i] = B[i];
init(n + m + 1); NTT(a, 1); NTT(b, 1);
for(int i = 0; i < limit; i++) a[i] = 1ll * a[i] * b[i] % P;
NTT(a, -1); return a;
}
}
using NTT::qpow, NTT::P, NTT::add;
#define inv(x) (qpow(x, P - 2))
int n, a[N], b[N], f[N], g[N];
void solve(int l, int r)
{
if(r - l < 2) return;
int mid = (l + r) >> 1; solve(l, mid);
memset(a + (r - l) / 2, 0, (r - l) / 2 * sizeof(int));
memcpy(a, f + l, (r - l) / 2 * sizeof(int));
memcpy(b, g, (r - l) * sizeof(int));
int *c = NTT::convolution(a, r - l, b, r - l);
rep(i, mid, r - 1) add(f[i], c[i - l]);
solve(mid, r);
}
signed main()
{
ios::sync_with_stdio(0);
cin.tie(0); cout.tie(0);
// freopen("check.in","r",stdin);
// freopen("check.out","w",stdout);
cin >> n; g[0] = 0; f[0] = 1;
rep(i, 1, n - 1) cin >> g[i];
solve(0, n);
rep(i, 0, n - 1) cout << f[i] << " \n"[i == n - 1];
return 0;
}
参考文献: