Tarjan算法求LCA
这是求一种 LCA 的离线算法,用的不多但是思想很有趣,并且有着美妙的 $\mathcal{O}(n+q)$ 复杂度。
其实这个算法很简单,不像网上那些瞎七搭八的文章讲的那么复杂
首先考虑将所有问题离线,每个节点记录通过询问相关联的节点,然后从根开始 dfs 。
对于每个节点,遍历它的每个子节点,并把子节点与其合并。这个合并操作我们可以用并查集来维护。
然后我们遍历所有通过询问相关联的节点。如果那个节点被访问过 $1$ 次了,那么答案就是 $\mathrm{find}(v)$ 对应的节点。
结合代码来感受一下:
void Tarjan(int u)
{
vis[u] = 1; pre[u] = u; // pre[u] 表示 u 合并到哪里了
for(int i = head[u]; i; i = e[i].next)
{ int v = e[i].v; if(!vis[v]) { Tarjan(v); merge(u,v); pre[find(u)] = u; } }
for(int i : vec[u])
{
int v = (q[i].v == u ? q[i].u : q[i].v); // 询问的另一个节点
if(vis[v]) ans[i] = pre[find(v)]; // pre维护当前(并查集)块对应的节点
}
}
为什么是对的呢?考虑 $u$ 的左子树已经被遍历的情况,此时应当递归遍历右子树。
因为左子树已经并到 $u$ 上了,所以任意 $v \in$ 左子树有「 $\mathrm{find}(v)$ 对应的节点为 $u$ 」。
则当遍历到节点 $x \in$ 右子树时,$u$ 是 $x,v$ 的 LCA 显然符合 LCA 本身的定义。
代码:
#include <bits/stdc++.h>
using namespace std;
#define int long long
#define INF 0x3f3f3f3f3f3f3f3f
void up(int &x,int y) { x < y ? x = y : 0; }
void down(int &x,int y) { x > y ? x = y : 0; }
#define N ((int)(5e5 + 15))
vector<int> vec[N]; bitset<N> vis;
int n,m,rt,pos = 1,ans[N],head[N],f[N],pre[N],sz[N];
struct Edge { int u,v,next; } e[N * 2], q[N];
void addEdge(int u,int v) { e[++pos] = {u,v,head[u]}; head[u] = pos; }
void init(int len) { for(int i = 1; i <= len; i++) f[i] = i, sz[i] = 1; }
int find(int x) { return f[x] == x ? x : f[x] = find(f[x]); }
void merge(int x,int y)
{
x = find(x); y = find(y);
if(sz[x] > sz[y]) swap(x,y);
f[x] = y; sz[y] += sz[x];
}
void Tarjan(int u)
{
vis[u] = 1; pre[u] = u;
for(int i = head[u]; i; i = e[i].next)
{ int v = e[i].v; if(!vis[v]) { Tarjan(v); merge(u,v); pre[find(u)] = u; } }
for(int i : vec[u])
{
int v = (q[i].v == u ? q[i].u : q[i].v);
if(vis[v]) ans[i] = pre[find(v)];
}
}
signed main()
{
ios::sync_with_stdio(0);
cin.tie(0);cout.tie(0);
// freopen("check.in","r",stdin);
// freopen("check.out","w",stdout);
cin >> n >> m >> rt; init(n);
for(int i = 1, u,v; i < n; i++) { cin >> u >> v; addEdge(u,v); addEdge(v,u); }
for(int i = 1, u,v; i <= m; i++) {
cin >> u >> v; q[i] = {u,v};
vec[u].push_back(i); vec[v].push_back(i);
}
Tarjan(rt); for(int i = 1; i <= m; i++) cout << ans[i] << '\n';
return 0;
}
参考文献: