嘘~ 正在从服务器偷取页面 . . .

快速沃尔什变换 FWT


快速沃尔什变换 FWT

模板题链接:P4717 【模板】快速莫比乌斯/沃尔什变换 (FMT/FWT)

题意

给定长度为 \(2^n\) 两个序列 \(A,B\),设 \[ C_i=\sum_{(j\operatorname\circ k) = i}A_j \times B_k \]

分别当 \(\operatorname\circ\) 是 or, and, xor 时求出 \(C\)

输入格式

第一行,一个整数 \(n\)

第二行,\(2^n\) 个数 \(A_0, A_1, \ldots, A_{2^n-1}\)

第三行,\(2^n\) 个数 \(B_0, B_1, \ldots, B_{2^n-1}\)

输出格式

三行,每行 \(2^n\) 个数,分别代表 \(\operatorname\circ\) 是 or, and, xor 时 \(C_0, C_1, \ldots, C_{2^n-1}\) 的值\(\bmod 998244353\)

数据范围

\(1 \le n \le 17\)


\(\mathcal{Part}\ 0\) 简介

在 OI 中,FWT 是用于解决对下标进行位运算卷积问题的方法。

具体地,对于已知序列 \(a,b\) ,FWT 可以在 \(\mathcal{O}(n \log n)\) 的复杂度内计算出 \[ c_i=\sum_{i=(j \operatorname{\circ} k)} a_j b_k \] 其中 \(\operatorname{\circ}\) 为二元位运算的一种。在本题中即为 \(\lor,\,\land,\,\oplus\) (或/与/异或)

FWT 的灵感来源于 FFT ,具体思路详见参考文献2。本文仅介绍基本的原理。


\(\mathcal{Part}\ 1\) 或运算

计算 \[ c_i=\sum_{i=(j \operatorname\lor k)} a_j b_k \] 由于 \((j \lor k) = i,~(k \lor i) = i\) 可推出 \(((j \lor k) \lor i) = i\)

考虑构造 \(f_a(i) = \sum_{(j \lor i) = i}a_j\) ,那么 \[ \begin{aligned} f_a \times f_b & =\left(\sum_{(j \lor i)=i} a_j\right)\left(\sum_{(k \lor i)=i} b_k\right) \\[6pt]& =\sum_{(j \lor i) = i}\sum_{(k \lor i)=i} a_j b_k \\[6pt]& =\sum_{((j \lor k) \lor i) = i} a_j b_k \\[6pt]& =f_c \end{aligned} \] 考虑如何求 \(f_a\) ,令 \(a_0\)\(a\) 中下标最高位为 \(0\) 的那部分,\(a_1\) 表示 \(a\) 中下标最高位为 \(1\) 的那部分,则 \[ f_a = f_{a_0} \cup (f_{a_0} + f_{a_1}) \] 分治即可。反正,我们要转换回 \(a\) ,根据上式则有 \[ a = a_0 \cup (a_1 - a_0) \]

提示:上述公式中 \(\cup\) 表示序列的拼接,\(+\)\(-\) 表示对应位置相加减。

这部分的代码:

void OR(int *f, int x = 1)
{
    for(int o = 2, k = 1; o <= n; o <<= 1, k <<= 1)
        for(int i = 0; i < n; i += o)
            for(int j = 0; j < k; j++) add(f[i + j + k], f[i + j] * x % mod);
} // a -> f(a) 则 x = 1, f(a) -> a 则 x = - 1

\(\mathcal{Part}\ 2\) 与运算

和或运算同理,直接贴代码了。

void AND(int *f, int x = 1)
{
    for(int o = 2, k = 1; o <= n; o <<= 1, k <<= 1)
        for(int i = 0; i < n; i += o)
            for(int j = 0; j < k; j++) add(f[i + j], f[i + j + k] * x % mod);
} // a -> f(a) 则 x = 1, f(a) -> a 则 x = - 1

\(\mathcal{Part}\ 3\) 异或运算

定义 \(x\otimes y = \mathrm{popcount}(x\land y) \bmod 2\) ,即 \(x,y\) 的与在二进制下 \(1\) 的个数

那么 \((i \otimes j) \oplus (i \otimes k) = i \otimes (j \oplus k)\)

考虑构造 \[ f_a(i) = \sum_{(i\otimes j) = 0} a_j - \sum_{(i\otimes j) = 1}a_j \] 那么 \[ \small \begin{aligned} f_a\times f_b & =\left(\sum_{(i \otimes j)=0} a_j-\sum_{(i \otimes j)=1} a_j\right)\left(\sum_{(i \otimes k)=0} b_k-\sum_{(i \otimes k)=1} b_k\right) \\[6pt]& =\left(\sum_{(i \otimes j)=0} a_j\right)\left(\sum_{(i \otimes k)=0} b_k\right)-\left(\sum_{(i \otimes j)=0} a_j\right)\left(\sum_{(i \otimes k)=1} b_k\right) -\left(\sum_{(i \otimes j)=1} a_j\right)\left(\sum_{(i \otimes k)=0} b_k\right)+\left(\sum_{(i \otimes j)=1} a_j\right)\left(\sum_{(i \otimes k)=1} b_k\right) \\[6pt]& =\sum_{(i \otimes(j \oplus k))=0} a_j b_k-\sum_{(i \otimes(j \oplus k))=1} a_j b_k \\[6pt]& =f_c \end{aligned} \] 因此 \[ f_a = (f_{a_0} + f_{a_1})\cup(f_{a_0} - f_{a_1}) \\[6pt] a = \frac{a_0 + a_1}{2} \cup \frac{a_0 - a_1}{2} \]

提示:这里 \(\frac{1}{2}\) 也是对应位置乘上 \(\frac{1}{2}\) 的意思。

这部分的代码:

void XOR(int *f, int x = 1)
{
    for(int o = 2, k = 1; o <= n; o <<= 1, k <<= 1)
        for(int i = 0; i < n; i += o)
            for(int j = 0; j < k; j++)
            {
                f[i + j] = (f[i + j] + f[i + j + k]) % mod;
                f[i + j + k] = (f[i + j] + mod - 2 * f[i + j + k] % mod) % mod;
                f[i + j] = f[i + j] * x % mod; f[i + j + k] = f[i + j + k] * x % mod;
            }
} // a -> f(a) 则 x = 1, f(a) -> a 则 x = 1 / 2

注意这里是 \(\frac{1}{2}\)


\(\mathcal{Part}\ 4\) 完整实现

时间复杂度 \(\mathcal{O}(n \log n)\) ,其中 \(n\) 是序列的长度

代码:

#include <bits/stdc++.h>
using namespace std;
#define int long long
#define INF 0x3f3f3f3f3f3f3f3f
const int mod = 998244353;
void up(int &x, int y) { x < y ? x = y : 0; }
void down(int &x, int y) { x > y ? x = y : 0; }
void add(int &x, int y) { (x += y) >= mod ? x -= mod : 0; }
#define rep(i, a, b) for(int i = (a), i##END = (b); i <= i##END; i++)
#define Rep(i, a, b) for(int i = (a), i##END = (b); i >= i##END; i--)
#define N ((int)(1 << 17) + 15)

const int inv2 = 499122177;
int n, A[N], B[N], a[N], b[N];
void OR(int *f, int x = 1)
{
    for(int o = 2, k = 1; o <= n; o <<= 1, k <<= 1)
        for(int i = 0; i < n; i += o)
            for(int j = 0; j < k; j++) add(f[i + j + k], f[i + j] * x % mod);
}
void AND(int *f, int x = 1)
{
    for(int o = 2, k = 1; o <= n; o <<= 1, k <<= 1)
        for(int i = 0; i < n; i += o)
            for(int j = 0; j < k; j++) add(f[i + j], f[i + j + k] * x % mod);
}
void XOR(int *f, int x = 1)
{
    for(int o = 2, k = 1; o <= n; o <<= 1, k <<= 1)
        for(int i = 0; i < n; i += o)
            for(int j = 0; j < k; j++)
            {
                f[i + j] = (f[i + j] + f[i + j + k]) % mod;
                f[i + j + k] = (f[i + j] + mod - 2 * f[i + j + k] % mod) % mod;
                f[i + j] = f[i + j] * x % mod; f[i + j + k] = f[i + j + k] * x % mod;
            }
}
void init() { rep(i, 0, n - 1) { a[i] = A[i]; b[i] = B[i]; } }
void print(int *f) { rep(i, 0, n - 1) cout << f[i] << " \n"[i == n - 1]; }
signed main()
{
    ios::sync_with_stdio(0);
    cin.tie(0); cout.tie(0);
    // freopen("check.in","r",stdin);
    // freopen("check.out","w",stdout);
    int m; cin >> m; n = 1 << m;
    rep(i, 0, n - 1) { cin >> A[i], A[i] %= mod; }
    rep(i, 0, n - 1) { cin >> B[i], B[i] %= mod; }
    
    init(); OR(a); OR(b);
    rep(i, 0, n - 1) a[i] = a[i] * b[i] % mod;
    OR(a, mod - 1); print(a);

    
    init(); AND(a); AND(b);
    rep(i, 0, n - 1) a[i] = a[i] * b[i] % mod;
    AND(a, mod - 1); print(a);

    init(); XOR(a); XOR(b);
    rep(i, 0, n - 1) a[i] = a[i] * b[i] % mod;
    XOR(a, inv2); print(a);
    return 0;
}

参考文献

[1] https://www.luogu.com.cn/article/2pavj2pd

[2] https://www.luogu.com.cn/article/crmftzuv


文章作者: q779
版权声明: 本博客所有文章除特别声明外,均采用 CC BY-NC-ND 4.0 许可协议。转载请注明来源 q779 !
评论
  目录