洛谷P4302 [SCOI2003]字符串折叠 题解
题意:
折叠的定义如下:
- 一个字符串可以看成它自身的折叠,记作 $S = S$ 。
- $X(S)$ 是 $X(X>1)$ 个 $S$ 连接在一起的串的折叠,记作 $X(S) = SSSS\dots S(x \,\texttt{个}\,S)$。
- 如果 $A = A^{\prime}, B = B^{\prime}$ ,则 $AB = A^{\prime}B^{\prime}$ 。例如,因为 $\mathtt{3(A)} = \mathtt{AAA}, \mathtt{2(B)} = \mathtt{BB}$,所以 $\mathtt{3(A)C2(B)} = \mathtt{AAACBB}$ ,而 $\mathtt{2(3(A)C)2(B)} = \mathtt{AAACAAACBB}$ 。
给一个字符串,求它的最短折叠。例如 $\mathtt{AAAAAAAAAABABABCCD}$ 的最短折叠为:$\mathtt{9(A)3(AB)CCD}$。
输入格式:
仅一行,即字符串 $S$ ,长度保证不超过 $100$ 。
输出格式:
仅一行,即最短的折叠长度。
区间dp + kmp 神仙题。
设 $f_{i,j}$ 表示区间 $[i,j]$ 的最短折叠长度。
显然边界为 $f_{i,i} = 1$ ,最后答案为 $f_{1,n}$ 。
下文中使用 $a \downarrow b+c$ 表示 $a \leftarrow \min\{a,b+c\}$
考虑转移。第一种,直接两个串合并
第二种,折叠。一个原始的想法是,按套路把两个区间合并,计算他们的贡献
但我们无法得知两个区间中的某个是否已经被折叠过了,以及它们各被折叠了几次
于是尝试直接暴力枚举折叠后的子串长度,然后转移。
记 $d=j-i+1$ ,并且 $[i,j]$ 由 $X$ 个长度为 $k$ 的串拼接而成( $d=X\cdot k$ ),则
然后稍微优化一下,比如枚举 $d$ 的因数,时间复杂度 $O(n^3 \log n)$ (貌似可以水过?)
仔细思考可以发现,很多枚举是不必要的。
根据 border 定理,一个串有长度为 $k$ 的周期当且仅当它有长为 $d-k$ 的 border 。
因此我们可以预处理出每个子串的最长 border (显然跑 $n$ 遍 $\mathtt{kmp}$ 即可)
其实就是预处理每个子串的最小周期 $k_0$
然后枚举 $k=t\cdot k_0$ 为周期,这样就可以在 $O(\log n)$ 的时间内转移第二部分了
结果复杂度瓶颈就变成了区间dp的转移(悲
时间复杂度 $O(n^3)$
代码:
#include <iostream>
#include <string>
#include <vector>
#include <algorithm>
#include <cstdio>
#include <cstdlib>
#include <cstring>
#include <cstdarg>
#include <cmath>
#include <iomanip>
#include <random>
using namespace std;
#define int long long
#define INF 0x3f3f3f3f3f3f3f3f
#define N ((int)(125))
char str[N];
int n,val[N],f[N][N],g[N][N];
void down(int &x,int y) { x > y ? x = y : 0;}
void init()
{
int x=0;
for(; x<10; ++x) val[x] = 3;
for(; x<100; ++x) val[x] = 4;
for(; x<N-5; ++x) val[x] = 5;
}
void KMP(int t)
{
int *fail = g[t]+t-1; char *s = str+t-1;
// cout << (s+1) << '\n';
fail[1]=0;
for(int i=2,j=0; i<=n-t+1; i++)
{
while(j && s[j+1] != s[i]) j=fail[j];
if(s[j+1] == s[i]) ++j; fail[i]=j;
}
}
signed main()
{
ios::sync_with_stdio(0);
cin.tie(0);cout.tie(0);
// freopen("check.in","r",stdin);
// freopen("check.out","w",stdout);
cin >> (str+1); n=strlen(str+1); init();
memset(f,0x3f,sizeof(f));
for(int i=1; i<=n; i++) {KMP(i); f[i][i] = 1;}
// for(int i=1; i<=n; i++)
// for(int j=1; j<=n; j++)
// cout << g[i][j] << " \n"[j==n];
for(int len=2,L; len<=n; len++)
for(int i=1,j=i+len-1; j<=n; i++,j++)
{
for(int k=i; k<j; k++)
down(f[i][j], f[i][k]+f[k+1][j]);
L=len-g[i][j];
if(L <= g[i][j])
for(int k=L; k<=len; k+=L)
if(len%k == 0)
down(f[i][j], f[i][i+k-1] + val[len/k]);
}
cout << f[1][n] << '\n';
return 0;
}
参考文献:
[1] https://yhx-12243.github.io/OI-transit/records/lydsy1090.html