嘘~ 正在从服务器偷取页面 . . .

洛谷P3047 [USACO12FEB]Nearby Cows G 题解


洛谷P3047 [USACO12FEB]Nearby Cows G 题解

题目链接:P3047 [USACO12FEB]Nearby Cows G

题意

给你一棵 \(n\) 个点的树,点带权,对于每个节点求出距离它不超过 \(k\) 的所有节点权值和 \(m_i\)

「数据范围」 对于 \(100\%\) 的数据:\(1 \le n \le 10^5\)\(1 \le k \le 20\)\(0 \le c_i \le 1000\)

换根dp,也叫二次扫描法

这题随便找一个根当做树根,然后扫两遍才能出答案

首先设 \(f[u][j]\) 表示 \(u\) 所在子树,与 \(u\) 相距恰好 \(j\) 的结点个数,则 \[ f[u][j]=\sum_{v \in \text{son}[u]} f[v][j-1] \] 这是第一遍dfs,可以发现我们没有从非 \(u\) 所在子树获取答案

考虑第二遍dfs,设 \(g[u][j]\) 表示在整棵树中与 \(u\) 相距恰好 \(j\) 的结点个数

这个答案一定是从父节点的 \(g\) 转移而来

但是这里会有一个问题,\(g[fa][j-1]\) 包含了从 \(f[u][j-2]\) 转移来的答案

直接加的话会导致重复,考虑容斥

不懂的话建议画个图,别像我一样一开始干瞪着 qwq\[ g[u][j]=f[u][j]+g[fa][j-1]-f[u][j-2] \] 然后 \(g\) 可以直接在 \(f\) 上搞,节约空间

时间复杂度 \(O(n)\)

代码:

#include <iostream>
#include <string>
#include <vector>
#include <algorithm>
#include <cstdio>
#include <cstdlib>
#include <cstring>
#include <cmath>
#include <iomanip>
using namespace std;
#define int long long
#define INF 0x3f3f3f3f3f3f3f3f
#define N (int)(1e5+5)

struct Edge
{
    int u,v,next;
}e[N<<1];
int n,k,num;
int f[N][21],pos=1,head[N],dep[N];
void addEdge(int u,int v)
{
    e[++pos]={u,v,head[u]};
    head[u]=pos;
}
void dfs1(int u)
{
    for(int i=head[u]; i; i=e[i].next)
    {
        int v=e[i].v;if(dep[v])continue;
        dep[v]=dep[u]+1;dfs1(v);
        for(int j=1; j<=k; j++)
            f[u][j]+=f[v][j-1];
    }
}
void dfs2(int u)
{
    for(int i=head[u]; i; i=e[i].next)
    {
        int v=e[i].v;
        if(dep[v]<dep[u])continue;
        for(int j=k; j>=2; j--)
            f[v][j]-=f[v][j-2];
        for(int j=1; j<=k; j++)
            f[v][j]+=f[u][j-1];
        dfs2(v);
    }
}
signed main()
{
    ios::sync_with_stdio(0);
    cin.tie(0);cout.tie(0);
    // freopen("check.in","r",stdin);
    // freopen("check.out","w",stdout);
    cin >> n >> k;
    for(int i=1,u,v; i<n; i++)
    {
        cin >> u >> v;
        addEdge(u,v);addEdge(v,u);
    }
    for(int i=1; i<=n; i++)
        cin >> f[i][0];
    dep[1]=1;dfs1(1);dfs2(1);
    for(int i=1; i<=n; i++)
    {
        int res=0;
        for(int j=0; j<=k; j++)
            res+=f[i][j];
        cout << res << '\n';
    }
    return 0;
}

文章作者: q779
版权声明: 本博客所有文章除特别声明外,均采用 CC BY-NC-ND 4.0 许可协议。转载请注明来源 q779 !
评论
  目录