CF915F Imbalance Value of a Tree 题解
题目链接:CF915F Imbalance Value of a Tree
题意:
给定一棵 \(n\) 个点的树,每个点有点权 \(a_i\) 。
定义 \(I(x,y)\) 为 \(x\) 到 \(y\) 的简单路径上的 \(\max\{a_i\}-\min\{a_i\}\)
求 \[ \sum_{i=1}^{n}\sum_{j=1}^{n} I(i,j) \] 数据范围:
\(1 \le n \le 10^6\)
这个 \(I(i,j)\) 是真的难看。考虑化简。 \[ \sum_{i=1}^{n}\sum_{j=1}^{n} f(i,j)-\sum_{i=1}^{n}\sum_{j=1}^{n} g(i,j) \] \(f(i,j)\) 表示 \(i\) 到 \(j\) 的简单路径上的 \(\max\{a_i\}\)
\(g(i,j)\) 表示 \(i\) 到 \(j\) 的简单路径上的 \(\min\{a_i\}\)
套路地,由于我们只需要知道总答案,而并非每一条路径的答案
于是考虑每个点的贡献。但是这个点就很不好处理。
我们可以先看看最熟悉的边权怎么处理
不失一般性,考虑边权最大值的贡献。
如果是边权的话,我们可以把所有的边从小到大排序
然后一条条边的加入,计算贡献后合并边两端的结点
对于当前的边 \((u,v)\) ,它的贡献就是 \[ w_{\text{mx}}(u,v) \times \mathrm{size}(u) \times \mathrm{size}(v) \] 其中 \(\mathrm{size}(u)\) 表示当前 \(u\) 所在连通块的大小。
然后我们考虑点权怎么处理
考虑一条路径 \(a \rightarrow b \rightarrow c\)
它的点权最大值就是(懒得写 \(a_a,a_b,a_c\) ,看得懂就好) \[
\max\{a,b,c\}
\] 根据 \(\max\)
多次嵌套不变的性质,可知 \[
\max\{a,b,c\}=\max\left\{\max\{a,b\}, ~\max\{b,c\}\right\}
\] 注意到这构成了 \((a,b),(b,c)\)
的配对,于是我们就可以把点权化为边权了。 \[
w_{\text{mx}}(u,v) \leftarrow \max\{a_u,a_v\}
\] 最小值也是一样的,就是从大到小排序而已。
时间复杂度 \(O(n)\)
树上的题目,如果 \(n\le 10^6\) ,一般都是 \(O(n)\) 的
代码:
#include <iostream>
#include <string>
#include <vector>
#include <algorithm>
#include <cstdio>
#include <cstdlib>
#include <cstring>
#include <cstdarg>
#include <cmath>
#include <iomanip>
#include <random>
using namespace std;
#define int long long
#define INF 0x3f3f3f3f3f3f3f3f
#define N ((int)(1e6+15))
int n,res,f[N],sz[N],val[N];
struct Edge{int x,y,mx,mn;}e[N];
void add(int &x,int y) {x += y;}
void init(int n){for(int i=1; i<=n; i++) f[i]=i, sz[i]=1;}
int find(int x){return f[x] == x ? x : f[x] = find(f[x]);}
void merge(int u,int v)
{
int x=find(u),y=find(v);
if(sz[x] > sz[y]) swap(x,y);
f[x]=y; sz[y] += sz[x];
}
signed main()
{
ios::sync_with_stdio(0);
cin.tie(0);cout.tie(0);
// freopen("check.in","r",stdin);
// freopen("check.out","w",stdout);
cin >> n;
for(int i=1; i<=n; i++) cin >> val[i];
for(int i=1; i<n; i++)
{
cin >> e[i].x >> e[i].y;
e[i].mx = max(val[e[i].x],val[e[i].y]);
e[i].mn = min(val[e[i].x],val[e[i].y]);
}
init(n); sort(e+1,e+n,[](Edge a,Edge b){return a.mn>b.mn;});
for(int i=1,x,y; i<n; i++)
{
x=find(e[i].x); y=find(e[i].y);
add(res, -e[i].mn * sz[x] * sz[y]); merge(x,y);
}
init(n); sort(e+1,e+n,[](Edge a,Edge b){return a.mx<b.mx;});
for(int i=1,x,y; i<n; i++)
{
x=find(e[i].x); y=find(e[i].y);
add(res, e[i].mx * sz[x] * sz[y]); merge(x,y);
}
cout << res << '\n';
return 0;
}
参考文献: