嘘~ 正在从服务器偷取页面 . . .

[AGC023F] 01 on Tree 题解


[AGC023F] 01 on Tree 题解

题目链接:[AGC023F] 01 on Tree

题意

给出一棵 $n$ 个节点的树,以及一个空序列。每个节点上有一个取值在 $\{0, 1\}$ 中的数。

每次你可以选择没有父亲节点的点删除,并且将这个节点上的数字放在当前数列末尾。

请你求出这个数列可能得到的最小逆序对数。

$1 \le n \leq 2 \times 10^5$ 。

听说是 Exchange Argument 的经典题,之前没见过。

首先这个删父亲再倒序放实在是有些奇怪,我们可以把它看作一棵内向树的某种拓扑序。

显然,如果可选的节点中存在 $0$ ,那么优先选这些节点一定不劣。

考虑每次选择一个节点后,将其与父亲合并。

这里考虑并查集即可,然后维护当前节点有多少个 $0,1$ 。

对于两个节点 $a,b$ ,记 $a_0,a_1$ 和 $b_0,b_1$ 是他们的 $0,1$ 个数

由于 $a,b$ 的子树已经按最优方案选择了,所以我们只需要考虑最小化因他们的顺序产生的逆序对即可

显然当 $a_1 \times b_0 < b_1 \times a_0$ 时,先选 $a$ 更优,也就是 $\frac{a_1}{a_0}<\frac{b_1}{b_0}$ 的情况下会优先选 $a$ 。

那么我们搞一个小根堆,每次取出这个权值最小的点,将其与父亲合并,再把父亲扔到堆里就可以了

时间复杂度 $\mathcal{O}(n \log n)$

代码:

#include <bits/stdc++.h>
using namespace std;
#define int long long
#define INF 0x3f3f3f3f3f3f3f3f
void up(int &x, int y) { x < y ? x = y : 0; }
void down(int &x, int y) { x > y ? x = y : 0; }
#define rep(i, a, b) for(int i = (a), i##END = (b); i <= i##END; i++)
#define Rep(i, a, b) for(int i = (a), i##END = (b); i >= i##END; i--)
#define N ((int)(2e5 + 15))

namespace DSU
{
    int f[N], cnt[N][2];
    void init(int n) { rep(i, 1, n) f[i] = i; }
    int find(int x) { return f[x] == x ? x : f[x] = find(f[x]); }
    int merge(int u, int v)
    {
        int x = find(u), y = find(v);
        int res = cnt[y][1] * cnt[x][0];
        f[x] = y; cnt[y][1] += cnt[x][1]; cnt[y][0] += cnt[x][0];
        return res;
    }
}using namespace DSU;
int n, fa[N], a[N]; bool vis[N];
typedef pair<double, int> pdi;
priority_queue<pdi, vector<pdi>, greater<pdi>> q;
signed main()
{
    ios::sync_with_stdio(0);
    cin.tie(0); cout.tie(0);
    // freopen("check.in","r",stdin);
    // freopen("check.out","w",stdout);
    cin >> n; init(n); rep(i, 2, n) cin >> fa[i];
    rep(i, 1, n)
    {
        cin >> a[i]; ++cnt[i][a[i]];
        if(!cnt[i][0]) q.push({1e100, i});
        else q.push({(double)cnt[i][1] / cnt[i][0], i});
    }
    int res = 0;
    while(!q.empty())
    {
        int u = q.top().second; q.pop();
        if(vis[u]) continue;
        vis[u] = true; res += merge(u, fa[u]);
        int v = find(fa[u]); if(!v) continue;
        if(!cnt[v][0]) q.push({1e100, v});
        else q.push({(double)cnt[v][1] / cnt[v][0], v});
    }
    cout << res << '\n';
    return 0;
}

参考文献

[1] https://www.luogu.com.cn/article/cxok8ft8


文章作者: q779
版权声明: 本博客所有文章除特别声明外,均采用 CC BY-NC-ND 4.0 许可协议。转载请注明来源 q779 !
评论
  目录