嘘~ 正在从服务器偷取页面 . . .

快速沃尔什变换 FWT


快速沃尔什变换 FWT

模板题链接:P4717 【模板】快速莫比乌斯/沃尔什变换 (FMT/FWT)

题意

给定长度为 $2^n$ 两个序列 $A,B$,设

分别当 $\operatorname\circ$ 是 or, and, xor 时求出 $C$。

输入格式

第一行,一个整数 $n$。

第二行,$2^n$ 个数 $A_0, A_1, \ldots, A_{2^n-1}$。

第三行,$2^n$ 个数 $B_0, B_1, \ldots, B_{2^n-1}$。

输出格式

三行,每行 $2^n$ 个数,分别代表 $\operatorname\circ$ 是 or, and, xor 时 $C_0, C_1, \ldots, C_{2^n-1}$ 的值$\bmod 998244353$。

数据范围

$1 \le n \le 17$。


$\mathcal{Part}\ 0$ 简介

在 OI 中,FWT 是用于解决对下标进行位运算卷积问题的方法。

具体地,对于已知序列 $a,b$ ,FWT 可以在 $\mathcal{O}(n \log n)$ 的复杂度内计算出

其中 $\operatorname{\circ}$ 为二元位运算的一种。在本题中即为 $\lor,\,\land,\,\oplus$ (或/与/异或)

FWT 的灵感来源于 FFT ,具体思路详见参考文献2。本文仅介绍基本的原理。


$\mathcal{Part}\ 1$ 或运算

计算

由于 $(j \lor k) = i,~(k \lor i) = i$ 可推出 $((j \lor k) \lor i) = i$ 。

考虑构造 $f_a(i) = \sum_{(j \lor i) = i}a_j$ ,那么

考虑如何求 $f_a$ ,令 $a_0$ 为 $a$ 中下标最高位为 $0$ 的那部分,$a_1$ 表示 $a$ 中下标最高位为 $1$ 的那部分,则

分治即可。反正,我们要转换回 $a$ ,根据上式则有

提示:上述公式中 $\cup$ 表示序列的拼接,$+$ 和 $-$ 表示对应位置相加减。

这部分的代码:

void OR(int *f, int x = 1)
{
    for(int o = 2, k = 1; o <= n; o <<= 1, k <<= 1)
        for(int i = 0; i < n; i += o)
            for(int j = 0; j < k; j++) add(f[i + j + k], f[i + j] * x % mod);
} // a -> f(a) 则 x = 1, f(a) -> a 则 x = - 1

$\mathcal{Part}\ 2$ 与运算

和或运算同理,直接贴代码了。

void AND(int *f, int x = 1)
{
    for(int o = 2, k = 1; o <= n; o <<= 1, k <<= 1)
        for(int i = 0; i < n; i += o)
            for(int j = 0; j < k; j++) add(f[i + j], f[i + j + k] * x % mod);
} // a -> f(a) 则 x = 1, f(a) -> a 则 x = - 1

$\mathcal{Part}\ 3$ 异或运算

定义 $x\otimes y = \mathrm{popcount}(x\land y) \bmod 2$ ,即 $x,y$ 的与在二进制下 $1$ 的个数

那么 $(i \otimes j) \oplus (i \otimes k) = i \otimes (j \oplus k)$

考虑构造

那么

因此

提示:这里 $\frac{1}{2}$ 也是对应位置乘上 $\frac{1}{2}$ 的意思。

这部分的代码:

void XOR(int *f, int x = 1)
{
    for(int o = 2, k = 1; o <= n; o <<= 1, k <<= 1)
        for(int i = 0; i < n; i += o)
            for(int j = 0; j < k; j++)
            {
                f[i + j] = (f[i + j] + f[i + j + k]) % mod;
                f[i + j + k] = (f[i + j] + mod - 2 * f[i + j + k] % mod) % mod;
                f[i + j] = f[i + j] * x % mod; f[i + j + k] = f[i + j + k] * x % mod;
            }
} // a -> f(a) 则 x = 1, f(a) -> a 则 x = 1 / 2

注意这里是 $\frac{1}{2}$ 。


$\mathcal{Part}\ 4$ 完整实现

时间复杂度 $\mathcal{O}(n \log n)$ ,其中 $n$ 是序列的长度

代码:

#include <bits/stdc++.h>
using namespace std;
#define int long long
#define INF 0x3f3f3f3f3f3f3f3f
const int mod = 998244353;
void up(int &x, int y) { x < y ? x = y : 0; }
void down(int &x, int y) { x > y ? x = y : 0; }
void add(int &x, int y) { (x += y) >= mod ? x -= mod : 0; }
#define rep(i, a, b) for(int i = (a), i##END = (b); i <= i##END; i++)
#define Rep(i, a, b) for(int i = (a), i##END = (b); i >= i##END; i--)
#define N ((int)(1 << 17) + 15)

const int inv2 = 499122177;
int n, A[N], B[N], a[N], b[N];
void OR(int *f, int x = 1)
{
    for(int o = 2, k = 1; o <= n; o <<= 1, k <<= 1)
        for(int i = 0; i < n; i += o)
            for(int j = 0; j < k; j++) add(f[i + j + k], f[i + j] * x % mod);
}
void AND(int *f, int x = 1)
{
    for(int o = 2, k = 1; o <= n; o <<= 1, k <<= 1)
        for(int i = 0; i < n; i += o)
            for(int j = 0; j < k; j++) add(f[i + j], f[i + j + k] * x % mod);
}
void XOR(int *f, int x = 1)
{
    for(int o = 2, k = 1; o <= n; o <<= 1, k <<= 1)
        for(int i = 0; i < n; i += o)
            for(int j = 0; j < k; j++)
            {
                f[i + j] = (f[i + j] + f[i + j + k]) % mod;
                f[i + j + k] = (f[i + j] + mod - 2 * f[i + j + k] % mod) % mod;
                f[i + j] = f[i + j] * x % mod; f[i + j + k] = f[i + j + k] * x % mod;
            }
}
void init() { rep(i, 0, n - 1) { a[i] = A[i]; b[i] = B[i]; } }
void print(int *f) { rep(i, 0, n - 1) cout << f[i] << " \n"[i == n - 1]; }
signed main()
{
    ios::sync_with_stdio(0);
    cin.tie(0); cout.tie(0);
    // freopen("check.in","r",stdin);
    // freopen("check.out","w",stdout);
    int m; cin >> m; n = 1 << m;
    rep(i, 0, n - 1) { cin >> A[i], A[i] %= mod; }
    rep(i, 0, n - 1) { cin >> B[i], B[i] %= mod; }
    
    init(); OR(a); OR(b);
    rep(i, 0, n - 1) a[i] = a[i] * b[i] % mod;
    OR(a, mod - 1); print(a);

    
    init(); AND(a); AND(b);
    rep(i, 0, n - 1) a[i] = a[i] * b[i] % mod;
    AND(a, mod - 1); print(a);

    init(); XOR(a); XOR(b);
    rep(i, 0, n - 1) a[i] = a[i] * b[i] % mod;
    XOR(a, inv2); print(a);
    return 0;
}

参考文献

[1] https://www.luogu.com.cn/article/2pavj2pd

[2] https://www.luogu.com.cn/article/crmftzuv


文章作者: q779
版权声明: 本博客所有文章除特别声明外,均采用 CC BY-NC-ND 4.0 许可协议。转载请注明来源 q779 !
评论
  目录