洛谷P2486 [SDOI2011]染色 题解
题目链接:P2486 [SDOI2011]染色
题意:
给定一棵 $n$ 个节点的无根树,共有 $m$ 个操作,操作分为两种:
- 将节点 $a$ 到节点 $b$ 的路径上的所有点(包括 $a$ 和 $b$)都染成颜色 $c$。
- 询问节点 $a$ 到节点 $b$ 的路径上的颜色段数量。
颜色段的定义是极长的连续相同颜色被认为是一段。例如
112221
由三段组成:11
、222
、1
。输入格式:
输入的第一行是用空格隔开的两个整数,分别代表树的节点个数 $n$ 和操作个数 $m$。
第二行有 $n$ 个用空格隔开的整数,第 $i$ 个整数 $w_i$ 代表结点 $i$ 的初始颜色。
第 $3$ 到第 $(n + 1)$ 行,每行两个用空格隔开的整数 $u, v$,代表树上存在一条连结节点 $u$ 和节点 $v$ 的边。
第 $(n + 2)$ 到第 $(n + m + 1)$ 行,每行描述一个操作,其格式为:
每行首先有一个字符 $op$,代表本次操作的类型。
- 若 $op$ 为
C
,则代表本次操作是一次染色操作,在一个空格后有三个用空格隔开的整数 $a, b, c$,代表将 $a$ 到 $b$ 的路径上所有点都染成颜色 $c$。- 若 $op$ 为
Q
,则代表本次操作是一次查询操作,在一个空格后有两个用空格隔开的整数 $a, b$,表示查询 $a$ 到 $b$ 路径上的颜色段数量。输出格式:
对于每次查询操作,输出一行一个整数代表答案。
数据范围:
对于 $100\%$ 的数据,$1 \leq n, m \leq 10^5$,$1 \leq w_i, c \leq 10^9$,$1 \leq a, b, u, v \leq n$,$op$ 一定为
C
或Q
,保证给出的图是一棵树。除原数据外,还存在一组不计分的 hack 数据。
一眼树剖,鉴定为:毒瘤。
貌似 ttys000 巨佬调了很久(( 然后我就半路跑过去写了这题
就是朴素的树剖,考虑线段树的结点维护什么
因为它求的是极大子段数,不是颜色数
所以我们只需要知道每个区间左端点、右端点的颜色,已经它有几个段即可
左右子树合并的时候要注意如果 $a.r=b.l$ ,要把答案减一,因为这俩合并了以后是 $1$ 个段
然后树剖的区间查询要注意合并顺序,建议自己画一画图
时间复杂度 $O\left((n+m) \log^2 n\right)$
代码:
#include <iostream>
#include <string>
#include <vector>
#include <algorithm>
#include <cstdio>
#include <cstdlib>
#include <cstring>
#include <cstdarg>
#include <cmath>
#include <iomanip>
#include <random>
using namespace std;
#define DEBUG (cout << "1145141919810 ... " << '\n')
#define int long long
#define INF 0x3f3f3f3f3f3f3f3f
#define N ((int)(1e5+15))
int n,Q,idx,pos=1,head[N];
int a[N],t[N],id[N],fa[N],top[N],dep[N],son[N],sz[N];
struct Edge{int u,v,next;} e[N*2];
struct node{int lc,rc,sum,tag;} tr[N*4];
#define ls(at) (at << 1)
#define rs(at) (at << 1 | 1)
void addEdge(int u,int v)
{
e[++pos]={u,v,head[u]};
head[u]=pos;
}
void dfs1(int u,int f)
{
fa[u] = f; dep[u] = dep[f] + 1; sz[u]=1;
int mx=-1;
for(int i=head[u]; i; i=e[i].next)
{
int v=e[i].v; if(v==f) continue;
dfs1(v,u); sz[u] += sz[v];
if(mx < sz[v]) mx=sz[v], son[u] = v;
}
}
void dfs2(int u,int ftop)
{
id[u] = ++idx; t[idx] = a[u]; top[u] = ftop;
if(!son[u]) return; dfs2(son[u], ftop);
for(int i=head[u]; i; i=e[i].next)
{
int v = e[i].v;
if(v != fa[u] && v != son[u]) dfs2(v,v);
}
}
node merge(node a,node b)
{
if(a.lc==-1) return b;
if(b.rc==-1) return a;
return {a.lc, b.rc, a.sum + b.sum - (a.rc == b.lc)};
}
void push_up(int at)
{
node tmp=merge(tr[ls(at)], tr[rs(at)]);
tr[at] = tmp; tr[at].tag=0;
}
void proc(int k,int at)
{
tr[at]={k,k,1,k};
}
void push_down(int at)
{
if(tr[at].tag)
{
proc(tr[at].tag,ls(at));
proc(tr[at].tag,rs(at));
tr[at].tag=0;
}
}
void build(int l,int r,int at)
{
if(l==r)
return tr[at]={t[l],t[l],1,0}, void(0);
int mid = (l+r) >> 1;
build(l,mid,ls(at)); build(mid+1,r,rs(at));
push_up(at);
}
void update(int nl,int nr,int l,int r,int k,int at)
{
if(nl <= l && r <= nr) return tr[at]={k,k,1,k},void(0);
push_down(at);
int mid = (l+r) >> 1;
if(nl <= mid) update(nl,nr,l,mid,k,ls(at));
if(nr > mid) update(nl,nr,mid+1,r,k,rs(at));
push_up(at);
}
node query(int nl,int nr,int l,int r,int at)
{
if(nl <= l && r <= nr) return tr[at];
node res={-1,-1,0,0};
push_down(at);
int mid = (l+r) >> 1;
if(nl <= mid)
{
res=query(nl,nr,l,mid,ls(at));
if(nr > mid) res = merge(res, query(nl,nr,mid+1,r,rs(at)));
}
else if(nr > mid) res = query(nl,nr,mid+1,r,rs(at));
return res;
}
void upRange(int x,int y,int k)
{
while(top[x] != top[y])
{
if(dep[top[x]] < dep[top[y]]) swap(x,y);
update(id[top[x]],id[x],1,n,k,1);
x = fa[top[x]];
}
if(dep[x] > dep[y]) swap(x,y);
update(id[x],id[y],1,n,k,1);
}
node qRange(int x,int y)
{
node res1={-1,-1,0,0},res2={-1,-1,0,0};
while(top[x] != top[y])
{
if(dep[top[x]] < dep[top[y]])
{
res2=merge(query(id[top[y]],id[y],1,n,1),res2);
y = fa[top[y]];
}else
{
res1=merge(query(id[top[x]],id[x],1,n,1),res1);
x = fa[top[x]];
}
}
if(dep[x] > dep[y])
{
swap(res2.lc,res2.rc);
res2=merge(res2,query(id[y],id[x],1,n,1));
return merge(res2,res1);
}else
{
swap(res1.lc,res1.rc);
res1=merge(res1,query(id[x],id[y],1,n,1));
return merge(res1,res2);
}
}
signed main()
{
ios::sync_with_stdio(0);
cin.tie(0);cout.tie(0);
// freopen("check.in","r",stdin);
// freopen("check.out","w",stdout);
cin >> n >> Q;
for(int i=1; i<=n; i++) cin >> a[i];
for(int i=1,u,v; i<n; i++)
{
cin >> u >> v;
addEdge(u,v); addEdge(v,u);
}
// DEBUG;
dfs1(1,1); dfs2(1,1); build(1,n,1); char op;
// DEBUG;
for(int x,y,z; Q--; )
{
cin >> op >> x >> y;
if(op == 'C') cin >> z, upRange(x,y,z);
else cout << qRange(x,y).sum << '\n';
}
return 0;
}