快速沃尔什变换 FWT
模板题链接:P4717 【模板】快速莫比乌斯/沃尔什变换 (FMT/FWT)
题意:
给定长度为 \(2^n\) 两个序列 \(A,B\),设 \[ C_i=\sum_{(j\operatorname\circ k) = i}A_j \times B_k \]
分别当 \(\operatorname\circ\) 是 or, and, xor 时求出 \(C\)。
输入格式:
第一行,一个整数 \(n\)。
第二行,\(2^n\) 个数 \(A_0, A_1, \ldots, A_{2^n-1}\)。
第三行,\(2^n\) 个数 \(B_0, B_1, \ldots, B_{2^n-1}\)。
输出格式:
三行,每行 \(2^n\) 个数,分别代表 \(\operatorname\circ\) 是 or, and, xor 时 \(C_0, C_1, \ldots, C_{2^n-1}\) 的值\(\bmod 998244353\)。
数据范围:
\(1 \le n \le 17\)。
\(\mathcal{Part}\ 0\) 简介
在 OI 中,FWT 是用于解决对下标进行位运算卷积问题的方法。
具体地,对于已知序列 \(a,b\) ,FWT 可以在 \(\mathcal{O}(n \log n)\) 的复杂度内计算出 \[ c_i=\sum_{i=(j \operatorname{\circ} k)} a_j b_k \] 其中 \(\operatorname{\circ}\) 为二元位运算的一种。在本题中即为 \(\lor,\,\land,\,\oplus\) (或/与/异或)
FWT 的灵感来源于 FFT ,具体思路详见参考文献2。本文仅介绍基本的原理。
\(\mathcal{Part}\ 1\) 或运算
计算 \[ c_i=\sum_{i=(j \operatorname\lor k)} a_j b_k \] 由于 \((j \lor k) = i,~(k \lor i) = i\) 可推出 \(((j \lor k) \lor i) = i\) 。
考虑构造 \(f_a(i) = \sum_{(j \lor i) = i}a_j\) ,那么 \[ \begin{aligned} f_a \times f_b & =\left(\sum_{(j \lor i)=i} a_j\right)\left(\sum_{(k \lor i)=i} b_k\right) \\[6pt]& =\sum_{(j \lor i) = i}\sum_{(k \lor i)=i} a_j b_k \\[6pt]& =\sum_{((j \lor k) \lor i) = i} a_j b_k \\[6pt]& =f_c \end{aligned} \] 考虑如何求 \(f_a\) ,令 \(a_0\) 为 \(a\) 中下标最高位为 \(0\) 的那部分,\(a_1\) 表示 \(a\) 中下标最高位为 \(1\) 的那部分,则 \[ f_a = f_{a_0} \cup (f_{a_0} + f_{a_1}) \] 分治即可。反正,我们要转换回 \(a\) ,根据上式则有 \[ a = a_0 \cup (a_1 - a_0) \]
提示:上述公式中 \(\cup\) 表示序列的拼接,\(+\) 和 \(-\) 表示对应位置相加减。
这部分的代码:
void OR(int *f, int x = 1)
{
for(int o = 2, k = 1; o <= n; o <<= 1, k <<= 1)
for(int i = 0; i < n; i += o)
for(int j = 0; j < k; j++) add(f[i + j + k], f[i + j] * x % mod);
} // a -> f(a) 则 x = 1, f(a) -> a 则 x = - 1
\(\mathcal{Part}\ 2\) 与运算
和或运算同理,直接贴代码了。
void AND(int *f, int x = 1)
{
for(int o = 2, k = 1; o <= n; o <<= 1, k <<= 1)
for(int i = 0; i < n; i += o)
for(int j = 0; j < k; j++) add(f[i + j], f[i + j + k] * x % mod);
} // a -> f(a) 则 x = 1, f(a) -> a 则 x = - 1
\(\mathcal{Part}\ 3\) 异或运算
定义 \(x\otimes y = \mathrm{popcount}(x\land y) \bmod 2\) ,即 \(x,y\) 的与在二进制下 \(1\) 的个数
那么 \((i \otimes j) \oplus (i \otimes k) = i \otimes (j \oplus k)\)
考虑构造 \[ f_a(i) = \sum_{(i\otimes j) = 0} a_j - \sum_{(i\otimes j) = 1}a_j \] 那么 \[ \small \begin{aligned} f_a\times f_b & =\left(\sum_{(i \otimes j)=0} a_j-\sum_{(i \otimes j)=1} a_j\right)\left(\sum_{(i \otimes k)=0} b_k-\sum_{(i \otimes k)=1} b_k\right) \\[6pt]& =\left(\sum_{(i \otimes j)=0} a_j\right)\left(\sum_{(i \otimes k)=0} b_k\right)-\left(\sum_{(i \otimes j)=0} a_j\right)\left(\sum_{(i \otimes k)=1} b_k\right) -\left(\sum_{(i \otimes j)=1} a_j\right)\left(\sum_{(i \otimes k)=0} b_k\right)+\left(\sum_{(i \otimes j)=1} a_j\right)\left(\sum_{(i \otimes k)=1} b_k\right) \\[6pt]& =\sum_{(i \otimes(j \oplus k))=0} a_j b_k-\sum_{(i \otimes(j \oplus k))=1} a_j b_k \\[6pt]& =f_c \end{aligned} \] 因此 \[ f_a = (f_{a_0} + f_{a_1})\cup(f_{a_0} - f_{a_1}) \\[6pt] a = \frac{a_0 + a_1}{2} \cup \frac{a_0 - a_1}{2} \]
提示:这里 \(\frac{1}{2}\) 也是对应位置乘上 \(\frac{1}{2}\) 的意思。
这部分的代码:
void XOR(int *f, int x = 1)
{
for(int o = 2, k = 1; o <= n; o <<= 1, k <<= 1)
for(int i = 0; i < n; i += o)
for(int j = 0; j < k; j++)
{
f[i + j] = (f[i + j] + f[i + j + k]) % mod;
f[i + j + k] = (f[i + j] + mod - 2 * f[i + j + k] % mod) % mod;
f[i + j] = f[i + j] * x % mod; f[i + j + k] = f[i + j + k] * x % mod;
}
} // a -> f(a) 则 x = 1, f(a) -> a 则 x = 1 / 2
注意这里是 \(\frac{1}{2}\) 。
\(\mathcal{Part}\ 4\) 完整实现
时间复杂度 \(\mathcal{O}(n \log n)\) ,其中 \(n\) 是序列的长度
代码:
#include <bits/stdc++.h>
using namespace std;
#define int long long
#define INF 0x3f3f3f3f3f3f3f3f
const int mod = 998244353;
void up(int &x, int y) { x < y ? x = y : 0; }
void down(int &x, int y) { x > y ? x = y : 0; }
void add(int &x, int y) { (x += y) >= mod ? x -= mod : 0; }
#define rep(i, a, b) for(int i = (a), i##END = (b); i <= i##END; i++)
#define Rep(i, a, b) for(int i = (a), i##END = (b); i >= i##END; i--)
#define N ((int)(1 << 17) + 15)
const int inv2 = 499122177;
int n, A[N], B[N], a[N], b[N];
void OR(int *f, int x = 1)
{
for(int o = 2, k = 1; o <= n; o <<= 1, k <<= 1)
for(int i = 0; i < n; i += o)
for(int j = 0; j < k; j++) add(f[i + j + k], f[i + j] * x % mod);
}
void AND(int *f, int x = 1)
{
for(int o = 2, k = 1; o <= n; o <<= 1, k <<= 1)
for(int i = 0; i < n; i += o)
for(int j = 0; j < k; j++) add(f[i + j], f[i + j + k] * x % mod);
}
void XOR(int *f, int x = 1)
{
for(int o = 2, k = 1; o <= n; o <<= 1, k <<= 1)
for(int i = 0; i < n; i += o)
for(int j = 0; j < k; j++)
{
f[i + j] = (f[i + j] + f[i + j + k]) % mod;
f[i + j + k] = (f[i + j] + mod - 2 * f[i + j + k] % mod) % mod;
f[i + j] = f[i + j] * x % mod; f[i + j + k] = f[i + j + k] * x % mod;
}
}
void init() { rep(i, 0, n - 1) { a[i] = A[i]; b[i] = B[i]; } }
void print(int *f) { rep(i, 0, n - 1) cout << f[i] << " \n"[i == n - 1]; }
signed main()
{
ios::sync_with_stdio(0);
cin.tie(0); cout.tie(0);
// freopen("check.in","r",stdin);
// freopen("check.out","w",stdout);
int m; cin >> m; n = 1 << m;
rep(i, 0, n - 1) { cin >> A[i], A[i] %= mod; }
rep(i, 0, n - 1) { cin >> B[i], B[i] %= mod; }
init(); OR(a); OR(b);
rep(i, 0, n - 1) a[i] = a[i] * b[i] % mod;
OR(a, mod - 1); print(a);
init(); AND(a); AND(b);
rep(i, 0, n - 1) a[i] = a[i] * b[i] % mod;
AND(a, mod - 1); print(a);
init(); XOR(a); XOR(b);
rep(i, 0, n - 1) a[i] = a[i] * b[i] % mod;
XOR(a, inv2); print(a);
return 0;
}
参考文献: