CF1401D Maximum Distributed Tree 题解
题目链接:CF1401D Maximum Distributed Tree
题意:
给定一棵 $n$ 个节点,$n-1$ 条边的树。你可以在每一条树上的边标上边权,使得:
- 每个边权都为 正整数;
- 这 $n-1$ 个边权的 乘积 等于 $k$;
- 边权为 $1$ 的边的数量最少。
定义 $f(u,v)$ 表示节点 $u$ 到节点 $v$ 的简单路径经过的边权总和。你的任务是让 $\sum\limits_{i=1}^{n-1}\sum\limits_{j=i+1}^{n} f(i,j)$ 最大。
最终答案可能很大,对 $10^9+7$ 取模即可。
$k$ 有可能很大,输入数据中包含了 $m$ 个质数 $p_i$,那么 $k$ 为这些质数的乘积。
输入格式:
第一行,一个整数 $t$ $(1\leq t\leq 100)$,表示多组测试数据个数。对于每一个测试数据:
第一行,一个整数 $n$ $(2\leq n\leq 10^5)$,表示树上节点数;
第 $2$ 至 $n$ 行,每行两个整数 $u_i$ 和 $v_i$ $(1\leq u_i,v_i \leq n,u_i\neq v_i)$,描述了一条无向边;
第 $n+1$ 行,一个整数 $m$ $(1\leq m\leq 6\times 10^4)$,表示 $k$ 分解成质因子的个数;
第 $n+2$ 行,$m$ 个 质数 $p_i$ $(2\leq p_i< 6\times 10^4)$,有 $k=\prod_{i=1}^m p_i$。
数据保证所有的 $n$ 总和不超过 $10^5$,所有的 $m$ 总和不超过 $6\times 10^4$。数据给出的边保证能够形成一棵树。
输出格式:
一行,一个整数,表示最大的答案对 $10^9+7$ 取模后的值。
考虑每条边的贡献。
一条边被经过的次数显然为 $\mathtt{size}(u) \times (n-\mathtt{size}(u))$ ,$u$ 为边上的一个端点。
因此我们按这个次数排序从大到小,然后贪心地给予较大的边权。
注意如果 $m > n-1$ ,我们可以把多出来的边权全部给最前面的那条边。
时间复杂度 $\mathcal{O}(n \log n)$
代码:
#include <bits/stdc++.h>
using namespace std;
#define int long long
#define INF 0x3f3f3f3f3f3f3f3f
const int mod = 1e9+7;
void up(int &x,int y) { x < y ? x = y : 0; }
void down(int &x,int y) { x > y ? x = y : 0; }
#define N ((int)(1e5+15))
vector<int> val, pri;
int n,m,pos=1,head[N],sz[N];
struct Edge { int u,v,next; } e[N * 2];
void addEdge(int u,int v) { e[++pos] = {u,v,head[u]}; head[u] = pos; }
void dfs(int u,int fa)
{
sz[u] = 1;
for(int i=head[u]; i; i=e[i].next)
{
int v = e[i].v; if(v == fa) continue;
dfs(v,u); sz[u] += sz[v];
}
if(u != 1) val.push_back(sz[u] * (n - sz[u]));
}
void clear()
{ pos = 1; for(int i=1; i<=n; i++) head[i]=0; val.clear(); pri.clear(); }
void work()
{
cin >> n; clear();
for(int i=1,u,v; i<n; i++) { cin >> u >> v; addEdge(u,v); addEdge(v,u); }
dfs(1,0); sort(val.begin(), val.end(), greater<int>());
cin >> m;
for(int i=1,x; i<=m; i++) { cin >> x; pri.push_back(x); }
sort(pri.begin(), pri.end(), greater<int>());
if(m <= n-1) for(int i=0; i<m; i++) val[i] = val[i] * pri[i] % mod;
else
{
reverse(pri.begin(), pri.end());
while(pri.size() >= n)
{
val[0] = val[0] * pri.back() % mod;
pri.pop_back();
}
reverse(pri.begin(), pri.end());
for(int i=0; i<pri.size(); i++) val[i] = val[i] * pri[i] % mod;
}
int res = 0;
for(int v : val) res = (res + v) % mod;
cout << res << '\n';
}
signed main()
{
ios::sync_with_stdio(0);
cin.tie(0);cout.tie(0);
// freopen("check.in","r",stdin);
// freopen("check.out","w",stdout);
int _Q; cin >> _Q; while(_Q--) work();
return 0;
}