洛谷P5904 [POI2014] HOT-Hotels 加强版 题解
题目链接:P5904 [POI2014] HOT-Hotels 加强版
题意:
给出一棵有 $n$ 个点的树,求有多少组点 $(i,j,k)$ 满足 $i,j,k$ 两两之间的距离都相等。
$(i,j,k)$ 与 $(i,k,j)$ 算作同一组。
输入格式:
第一行一个整数 $n$。
接下来 $n-1$ 行,每行两个整数 $a,b$,表示在 $a,b$ 之间有一条边。
输出格式:
一行一个整数,表示所有合法的点的组数。
数据范围:
对于 $100\%$ 的数据, $1\le n\le10^5, 1\le a\le b\le n$。
设 $f(i,j)$ 为 $i$ 所在子树中 $\mathrm{dis}(x,i)=j$ 的 $x$ 的个数。
设 $g(i,j)$ 为 $i$ 所在子树中满足
的无序二元组 $(x,y)$ 的个数。并记 $\mathrm{Ans}$ 为答案,那么
这里 $\uparrow$ 均表示增量。直接转移的话,时间复杂度为 $\mathcal{O}(n^2)$ 。
注意到这里的第二维为深度,而每个节点所在子树的深度其实很多时候并没有 $\mathcal{O}(n)$ 。
考虑利用长链剖分优化转移,并利用指针直接 $\mathcal{O}(1)$ 继承重儿子的信息。
因为每个节点只属于一条长链,所以长链的总长度为 $\mathcal{O}(n)$ 。
对于每条轻链只会向重链合并一次,这相当于删去了轻链,故均摊也是 $\mathcal{O}(n)$ 的。
故总时间复杂度为 $\mathcal{O}(n)$
代码:
#include <bits/stdc++.h>
using namespace std;
// #define int long long
// #define INF 0x3f3f3f3f3f3f3f3f
typedef long long ll;
void up(int &x, int y) { x < y ? x = y : 0; }
void down(int &x, int y) { x > y ? x = y : 0; }
#define rep(i, a, b) for(int i = (a), i##END = (b); i <= i##END; i++)
#define Rep(i, a, b) for(int i = (a), i##END = (b); i >= i##END; i--)
#define pb push_back
#define N ((int)(1e5 + 15))
vector<int> vec[N];
int son[N], mxdep[N]; ll res, *f[N], *g[N], p[N * 4], *o = p;
void dfs(int u, int fa)
{
for(auto v : vec[u]) if(v != fa)
if(dfs(v, u), mxdep[v] > mxdep[son[u]]) son[u] = v;
mxdep[u] = mxdep[son[u]] + 1;
}
void DP(int u, int fa)
{
if(son[u]) { f[son[u]] = f[u] + 1; g[son[u]] = g[u] - 1; DP(son[u], u); }
f[u][0] = 1; res += g[u][0];
for(auto v : vec[u]) if(v != fa && v != son[u])
{
f[v] = o; o += mxdep[v] * 2; g[v] = o; o += mxdep[v] * 2;
DP(v, u);
rep(i, 0, mxdep[v] - 1)
{
if(i) res += f[u][i - 1] * g[v][i];
res += g[u][i + 1] * f[v][i];
}
rep(i, 0, mxdep[v] - 1)
{
g[u][i + 1] += f[u][i + 1] * f[v][i];
if(i) g[u][i - 1] += g[v][i];
f[u][i + 1] += f[v][i];
}
}
}
signed main()
{
ios::sync_with_stdio(0);
cin.tie(0); cout.tie(0);
// freopen("check.in","r",stdin);
// freopen("check.out","w",stdout);
int n; cin >> n;
for(int i = 1, u, v; i < n; i++)
{
cin >> u >> v;
vec[u].pb(v); vec[v].pb(u);
}
dfs(1, 0);
f[1] = o; o += mxdep[1] * 2;
g[1] = o; o += mxdep[1] * 2;
DP(1, 0); cout << res << '\n';
return 0;
}
参考文献: