快速沃尔什变换 FWT
模板题链接:P4717 【模板】快速莫比乌斯/沃尔什变换 (FMT/FWT)
题意:
给定长度为 $2^n$ 两个序列 $A,B$,设
分别当 $\operatorname\circ$ 是 or, and, xor 时求出 $C$。
输入格式:
第一行,一个整数 $n$。
第二行,$2^n$ 个数 $A_0, A_1, \ldots, A_{2^n-1}$。
第三行,$2^n$ 个数 $B_0, B_1, \ldots, B_{2^n-1}$。
输出格式:
三行,每行 $2^n$ 个数,分别代表 $\operatorname\circ$ 是 or, and, xor 时 $C_0, C_1, \ldots, C_{2^n-1}$ 的值$\bmod 998244353$。
数据范围:
$1 \le n \le 17$。
$\mathcal{Part}\ 0$ 简介
在 OI 中,FWT 是用于解决对下标进行位运算卷积问题的方法。
具体地,对于已知序列 $a,b$ ,FWT 可以在 $\mathcal{O}(n \log n)$ 的复杂度内计算出
其中 $\operatorname{\circ}$ 为二元位运算的一种。在本题中即为 $\lor,\,\land,\,\oplus$ (或/与/异或)
FWT 的灵感来源于 FFT ,具体思路详见参考文献2。本文仅介绍基本的原理。
$\mathcal{Part}\ 1$ 或运算
计算
由于 $(j \lor k) = i,~(k \lor i) = i$ 可推出 $((j \lor k) \lor i) = i$ 。
考虑构造 $f_a(i) = \sum_{(j \lor i) = i}a_j$ ,那么
考虑如何求 $f_a$ ,令 $a_0$ 为 $a$ 中下标最高位为 $0$ 的那部分,$a_1$ 表示 $a$ 中下标最高位为 $1$ 的那部分,则
分治即可。反正,我们要转换回 $a$ ,根据上式则有
提示:上述公式中 $\cup$ 表示序列的拼接,$+$ 和 $-$ 表示对应位置相加减。
这部分的代码:
void OR(int *f, int x = 1)
{
for(int o = 2, k = 1; o <= n; o <<= 1, k <<= 1)
for(int i = 0; i < n; i += o)
for(int j = 0; j < k; j++) add(f[i + j + k], f[i + j] * x % mod);
} // a -> f(a) 则 x = 1, f(a) -> a 则 x = - 1
$\mathcal{Part}\ 2$ 与运算
和或运算同理,直接贴代码了。
void AND(int *f, int x = 1)
{
for(int o = 2, k = 1; o <= n; o <<= 1, k <<= 1)
for(int i = 0; i < n; i += o)
for(int j = 0; j < k; j++) add(f[i + j], f[i + j + k] * x % mod);
} // a -> f(a) 则 x = 1, f(a) -> a 则 x = - 1
$\mathcal{Part}\ 3$ 异或运算
定义 $x\otimes y = \mathrm{popcount}(x\land y) \bmod 2$ ,即 $x,y$ 的与在二进制下 $1$ 的个数
那么 $(i \otimes j) \oplus (i \otimes k) = i \otimes (j \oplus k)$
考虑构造
那么
因此
提示:这里 $\frac{1}{2}$ 也是对应位置乘上 $\frac{1}{2}$ 的意思。
这部分的代码:
void XOR(int *f, int x = 1)
{
for(int o = 2, k = 1; o <= n; o <<= 1, k <<= 1)
for(int i = 0; i < n; i += o)
for(int j = 0; j < k; j++)
{
f[i + j] = (f[i + j] + f[i + j + k]) % mod;
f[i + j + k] = (f[i + j] + mod - 2 * f[i + j + k] % mod) % mod;
f[i + j] = f[i + j] * x % mod; f[i + j + k] = f[i + j + k] * x % mod;
}
} // a -> f(a) 则 x = 1, f(a) -> a 则 x = 1 / 2
注意这里是 $\frac{1}{2}$ 。
$\mathcal{Part}\ 4$ 完整实现
时间复杂度 $\mathcal{O}(n \log n)$ ,其中 $n$ 是序列的长度
代码:
#include <bits/stdc++.h>
using namespace std;
#define int long long
#define INF 0x3f3f3f3f3f3f3f3f
const int mod = 998244353;
void up(int &x, int y) { x < y ? x = y : 0; }
void down(int &x, int y) { x > y ? x = y : 0; }
void add(int &x, int y) { (x += y) >= mod ? x -= mod : 0; }
#define rep(i, a, b) for(int i = (a), i##END = (b); i <= i##END; i++)
#define Rep(i, a, b) for(int i = (a), i##END = (b); i >= i##END; i--)
#define N ((int)(1 << 17) + 15)
const int inv2 = 499122177;
int n, A[N], B[N], a[N], b[N];
void OR(int *f, int x = 1)
{
for(int o = 2, k = 1; o <= n; o <<= 1, k <<= 1)
for(int i = 0; i < n; i += o)
for(int j = 0; j < k; j++) add(f[i + j + k], f[i + j] * x % mod);
}
void AND(int *f, int x = 1)
{
for(int o = 2, k = 1; o <= n; o <<= 1, k <<= 1)
for(int i = 0; i < n; i += o)
for(int j = 0; j < k; j++) add(f[i + j], f[i + j + k] * x % mod);
}
void XOR(int *f, int x = 1)
{
for(int o = 2, k = 1; o <= n; o <<= 1, k <<= 1)
for(int i = 0; i < n; i += o)
for(int j = 0; j < k; j++)
{
f[i + j] = (f[i + j] + f[i + j + k]) % mod;
f[i + j + k] = (f[i + j] + mod - 2 * f[i + j + k] % mod) % mod;
f[i + j] = f[i + j] * x % mod; f[i + j + k] = f[i + j + k] * x % mod;
}
}
void init() { rep(i, 0, n - 1) { a[i] = A[i]; b[i] = B[i]; } }
void print(int *f) { rep(i, 0, n - 1) cout << f[i] << " \n"[i == n - 1]; }
signed main()
{
ios::sync_with_stdio(0);
cin.tie(0); cout.tie(0);
// freopen("check.in","r",stdin);
// freopen("check.out","w",stdout);
int m; cin >> m; n = 1 << m;
rep(i, 0, n - 1) { cin >> A[i], A[i] %= mod; }
rep(i, 0, n - 1) { cin >> B[i], B[i] %= mod; }
init(); OR(a); OR(b);
rep(i, 0, n - 1) a[i] = a[i] * b[i] % mod;
OR(a, mod - 1); print(a);
init(); AND(a); AND(b);
rep(i, 0, n - 1) a[i] = a[i] * b[i] % mod;
AND(a, mod - 1); print(a);
init(); XOR(a); XOR(b);
rep(i, 0, n - 1) a[i] = a[i] * b[i] % mod;
XOR(a, inv2); print(a);
return 0;
}
参考文献: