洛谷P4178 Tree 题解
题目链接:P4178 Tree
题意:
给定一棵 $n$ 个节点的树,每条边有边权,求出树上两点距离小于等于 $k$ 的点对数量。
输入格式:
第一行输入一个整数 $n$,表示节点个数。
第二行到第 $n$ 行每行输入三个整数 $u,v,w$ ,表示 $u$ 与 $v$ 有一条边,边权是 $w$。
第 $n+1$ 行一个整数 $k$ 。
输出格式:
一行一个整数,表示答案。
数据范围:
$1\leq n\leq 4\times 10^4,~1\leq u,v\leq n,~0\leq w\leq 10^3,~0\leq k\leq 2\times 10^4$。
这个题和 P3806 【模板】点分治 1 的差别是,询问只有一个,且问的是小于等于 $k$ 的。
是不是直接把已经知道的距离的值排个序,然后用两个指针跑一遍,就能得到答案了呢?
并不是,因为我们会统计到一些不合法的路径,比如下图
那么解决办法其实很简单,我们在枚举重心的儿子时,把重复的那一段,即 $2\times w(u,v)$ 加上
然后重新算一遍 $v$ 子树的距离,把答案减去统计到的结果即可。可能有点抽象,直接看代码吧。
时间复杂度 $\mathcal{O}(n \log^2 n)$
代码:
#include <bits/stdc++.h>
using namespace std;
#define int long long
#define INF 0x3f3f3f3f3f3f3f3f
void up(int &x,int y) { x < y ? x = y : 0; }
void down(int &x,int y) { x > y ? x = y : 0; }
#define N ((int)(4e4 + 15))
int n, m, rt, rt_fa, res, sum, K, pos = 1, tot;
int head[N], sz[N], mx[N], q[N], val[N], dis[N]; char vis[N];
struct Edge { int u, v, w, next; } e[N * 2];
void addEdge(int u, int v, int w) {
e[++pos] = {u, v, w, head[u]}; head[u] = pos;
}
void getroot(int u, int fa)
{
mx[u] = 0; sz[u] = 1;
for(int i = head[u]; i; i = e[i].next)
{
int v = e[i].v; if(v == fa || vis[v]) continue;
getroot(v, u); sz[u] += sz[v]; up(mx[u], sz[v]);
}
up(mx[u], sum - sz[u]); if(mx[u] < mx[rt]) { rt = u, rt_fa = fa; }
}
void getdis(int u, int fa)
{
val[++tot] = dis[u];
for(int i = head[u]; i; i = e[i].next)
{
int v = e[i].v; if(v == fa || vis[v]) continue;
dis[v] = dis[u] + e[i].w; getdis(v, u);
}
}
int calc(int u, int w)
{
tot = 0; dis[u] = w; getdis(u, 0);
sort(val + 1, val + 1 + tot);
int l = 1, r = tot, num = 0;
while(l <= r) if(val[l] + val[r] <= K) { num += r - l, ++l; } else --r;
return num;
}
void solve(int u)
{
vis[u] = true; res += calc(u, 0);
for(int i = head[u]; i; i = e[i].next)
{
int v = e[i].v; if(vis[v]) continue;
res -= calc(v, e[i].w);
sum = v == rt_fa ? n - sz[u] : sz[v];
mx[rt = rt_fa = 0] = n; getroot(v, u); solve(rt);
}
}
signed main()
{
ios::sync_with_stdio(0);
cin.tie(0);cout.tie(0);
// freopen("check.in","r",stdin);
// freopen("check.out","w",stdout);
cin >> n;
for(int i = 1, u, v, w; i < n; i++)
{
cin >> u >> v >> w;
addEdge(u, v, w); addEdge(v, u, w);
}
cin >> K;
mx[rt = rt_fa = 0] = sum = n;
getroot(1, 0); solve(rt); cout << res << '\n';
return 0;
}
参考文献: