洛谷P3294 [SCOI2016]背单词 题解
题目链接:P3294 [SCOI2016]背单词
题意:给定 $n$ 个不同的字符串,求一个最优顺序使得花费最小
设当前字符串为 $a$ ,位于排列中的 $x$ 位置
- 如果 $a$ 存在后缀且 $a$ 的后缀在 $a$ 之后,花费增加 $n^2$
- 如果 $a$ 不存在后缀,则花费增加 $x$
- 设 $y$ 为 $a$ 之前与其相距最小的,且是 $a$ 的后缀的字符串的位置(前提是 $a$ 存在后缀),则花费增加 $x-y$
根据贪心,前两个条件基本上没有用
同时为了方便处理,把所有字符串反转,考虑其前缀。
原问题转化为:
设某个字符串为 $a_i$ ,位于排列中的 $x_i$ 位置
$a_i$ 所有前缀必须在 $a_i$ 之前
设 $y_i$ 为 $a_i$ 之前与其相距最小的,且是 $a_i$ 的前缀的字符串的位置
如果 $a$ 不存在前缀,则 $y_i=0$ ,记 $v_i=x_i-y_i$
给定 $n$ 个字符串,求一种排列顺序,使得 $\sum v_i$ 最小,求出这个最小值。
考虑建 $\tt{trie}$ 树以存储字符串
然后重构整棵树,只保留「是字符串结尾的结点」,注意根节点也要保留
然后对于每个结点,将其子结点按子树大小从小到大排序
然后跑一遍dfs就好了
重构是为了方便排序子结点。
排序的原因:
设当前结点为 $u$ ,则 $u$ 的所有子结点都有相同的前缀
这个前缀就是根节点到 $u$ 形成的字符串。
对于排序后在后面的子结点,想在排列中距离 $u$ 尽可能小
就需要前面的子结点所在子树大小尽可能小,这样他们之间夹着的字符串就少
根据贪心,不难发现这样可以最小化答案
时间复杂度 $O(L \log L)$ ,其中 $L=\sum |s_i|$
代码:
#include <iostream>
#include <string>
#include <vector>
#include <algorithm>
#include <cstdio>
#include <cstdlib>
#include <cstring>
#include <cmath>
#include <iomanip>
#include <random>
using namespace std;
#define ll long long
// #define int long long
// #define INF 0x3f3f3f3f3f3f3f3f
#define L (int)(5.1e5+15)
bool cmp(int a,int b);
struct Trie
{
vector<int> vec[L]; ll res;
int tot,trie[L][26],ed[L],last[L],sz[L],cnt;
void insert(string s)
{
int u=0;
for(int i=0; i<s.size(); i++)
{
int c=s[i]-'a';
if(!trie[u][c])trie[u][c]=++tot;
u=trie[u][c];
}
ed[u]=1;
}
void rebd(int u)
{
if(ed[u]&&u)
{
vec[last[u]].push_back(u);
last[u]=u;
}
for(int i=0; i<26; i++)
if(trie[u][i])
{
int v=trie[u][i];
last[v]=last[u]; rebd(v);
}
}
void dfs(int u)
{
sz[u]=1;
for(int i=0; i<vec[u].size(); i++)
{
dfs(vec[u][i]);
sz[u]+=sz[vec[u][i]];
}
sort(vec[u].begin(),vec[u].end(),cmp);
}
void solve(int u)
{
int dfn=cnt++;
for(int i=0; i<vec[u].size(); i++)
{
res+=cnt-dfn;
solve(vec[u][i]);
}
}
void main()
{
ed[0]=1; rebd(0);
dfs(0); solve(0);
cout << res << '\n';
}
}tr;
bool cmp(int a,int b){return tr.sz[a]<tr.sz[b];}
signed main()
{
ios::sync_with_stdio(0);
cin.tie(0);cout.tie(0);
// freopen("check.in","r",stdin);
// freopen("check.out","w",stdout);
int n; string s;
cin >> n;
for(int i=1; i<=n; i++)
{
cin >> s;
reverse(s.begin(),s.end());
tr.insert(s);
}
tr.main();
return 0;
}