浅谈舞蹈链(DLX)
一、舞蹈链
舞蹈链 (Dancing links),也叫 DLX ,是由 Donald Knuth 提出的数据结构,目的是快速实现他提出的X算法。X算法是一种递归算法,时间复杂度不确定,深度优先,通过回溯寻找精确覆盖问题所有可能的解
(以上摘自维基百科)
舞蹈链的主要思想来源于双向链表
我们设 $l[x]$ 表示元素 $x$ 的左指针, $r[x]$ 表示元素 $x$ 的右指针
显然,如果想要删除元素 $x$ ,我们可以做以下操作
r[l[x]]=r[x]; // x左侧的元素的右指针指向x右侧的元素
l[r[x]]=l[x]; // x右侧的元素的左指针指向x左侧的元素
那恢复元素 $x$ 呢? 我们可以发现删除 $x$ 的时候, $x$ 的左右指针并没有改变,即 $l[x]$ 和 $r[x]$ 并没有改变,于是我们可以做以下操作
r[l[x]]=x; // x左侧的元素的右指针重新指向x
l[r[x]]=x; // x右侧的元素的左指针重新指向x
这样如果 $x$ 左右两侧没有改变,我们就可以恢复 $x$ 所在的位置
那么精确覆盖问题又是什么呢?
给定矩阵,要求选出一个由若干行组成的集合,使得每一列上都有且仅有一个 $1$
例如该矩阵选出的行为 $1,4,5$ 行
我们来模拟一下朴素X算法求解的过程
以下过程用红色表示选择了这一行,绿色表示存在冲突的元素,灰色表示删除的行
1.选择第一行
2.标记所有和第一行冲突的元素
3.删除存在冲突的行
4.接着选择第二行
5.标记与第二行冲突的元素
6.删除存在冲突的行
7.发现没有可以选择的行了,而已选的不满足要求,回溯,选择第四行
8.接下来的同理,不断执行,直到找到答案
我们会发现, $X$ 算法花了大量的时间在找 $1$ ,而且删改很不方便
为了解决这个问题,舞蹈链就产生了
模板题 $\to$ P4929 【模板】舞蹈链(DLX)
(注:为了方便讲述,以下引用这篇博客中的图片(感谢图片的作者!))
舞蹈链的结构即交叉十字循环双向链,本文中以数组形式实现链表
int n,m; // 行、列数
int u[MAXN],d[MAXN],l[MAXN],r[MAXN],h[MAXN];
// 每个结点的上下左右指针;每一行的头指针
int row[MAXN],col[MAXN],s[MAXN],ansk[MAXN],pos;
// 每个结点原先所在的行、列;每一列的结点个数;ansk记录搜索信息;结点总数
如下图所示
别急!我们一步一步来实现这个复杂的数据结构
首先初始化上方的列头结点
我们可以称列头结点为限制,行头结点为决策 (注:这个做题的时候有用)
void init()
{
for(R int i=0; i<=m; i++)
{
l[i]=i-1;
r[i]=i+1;
u[i]=d[i]=i;
}
l[0]=m;r[m]=0; //循环链表
memset(h,-1,sizeof(h)); //每一行的头结点都为空
memset(s,0,sizeof(s)); //每一列的结点数都为0
pos=m+1; //已经搭建好了m个列头结点,下一个加入的结点从m+1开始编号
}
接下来,我们来把插入结点的功能完成(注:这里比较复杂,可以感性理解一下)
void link(R int x,R int y)
{
s[y]++; //所在的列结点数加1
row[pos]=x;col[pos]=y; //记录编号为pos的结点(即新加入的结点)的行和列
u[pos]=y; //pos结点的上指针指向插入的列y
d[pos]=d[y]; //pos结点的下指针指向插入位置下方的元素
u[d[y]]=pos; //插入位置下方的元素的上指针指向pos结点
d[y]=pos; //插入位置上方的下指针指向pos结点
if(h[x]<0)h[x]=l[pos]=r[pos]=pos;//如果pos结点所在的这一行没有头结点,就自己当
else //不然就插入头结点一侧 (下面的就不注释了,和上面插入的方法类似)
{
l[pos]=l[h[x]];
r[pos]=h[x];
r[l[h[x]]]=pos;
l[h[x]]=pos;
}
pos++; //下一个结点不要标错号了
}
现在我们来完成删除和恢复操作(差不多的)
inline void rm(R int y)
{
l[r[y]]=l[y];r[l[y]]=r[y];//删除y列的结点(这个位置已经填满了)
for(R int i=d[y]; i!=y; i=d[i])
for(R int j=r[i]; j!=i; j=r[j])//删除这一行(冲突的行)
{
d[u[j]]=d[j];
u[d[j]]=u[j];
s[col[j]]--;//别忘了减1
}
}
inline void rv(R int y)
{
for(R int i=u[y]; i!=y; i=u[i])
for(R int j=l[i]; j!=i; j=l[j])//恢复也是一样的
{
d[u[j]]=j;
u[d[j]]=j;
s[col[j]]++;
}
r[l[y]]=y;l[r[y]]=y;
}
主体部分,就是深度优先搜索
bool dance(R int dep)
{
if(!r[0])//所有的列头结点都被选了,说明成功了
{
for(R int i=0; i<dep; i++)
printf("%lld%c",ansk[i]," \n"[i==dep-1]);
return 1;
}
R int y=r[0];
for(R int i=r[0]; i; i=r[i])
if(s[i]<s[y])y=i;//这里是一个剪枝,每次选择结点最少的列能在一定情况下提高性能
rm(y);//删掉这一列
for(R int i=d[y]; i!=y; i=d[i])//每次选择这列中的一行
{
ansk[dep]=row[i];
for(R int j=r[i]; j!=i; j=r[j])rm(col[j]);//删掉这一行中所有结点(这一行冲突)
if(dance(dep+1))return 1;//成功就返回
for(R int j=l[i]; j!=i; j=l[j])rv(col[j]);//恢复
}
rv(y);
return 0;
}
如果您不太理解的话,可以看看下面的动图(注:其实和之前模拟的有些相似)
算法执行过程 (注:图片是这篇博客的)
最终的答案即下图 (选择 $1,4,5$ )
最后贴上完整代码
#include <bits/stdc++.h>
using namespace std;
#define int long long
#define R register
#define MAXN 250505
template<typename T>inline void read(R T &k)
{
R char ch=getchar(); R T x=0,f=1;
while(!isdigit(ch)){if(ch=='-')f=-1;ch=getchar();}
while(isdigit(ch)){x=(x<<3)+(x<<1)+(ch^48);ch=getchar();}
k=x*f;
}
int n,m;
int u[MAXN],d[MAXN],l[MAXN],r[MAXN],h[MAXN];
int row[MAXN],col[MAXN],s[MAXN],ansk[MAXN],pos;
void init()
{
for(R int i=0; i<=m; i++)
{
l[i]=i-1;
r[i]=i+1;
u[i]=d[i]=i;
}l[0]=m;r[m]=0;
memset(h,-1,sizeof(h));
memset(s,0,sizeof(s));
pos=m+1;
}
void link(R int x,R int y)
{
s[y]++;
row[pos]=x;col[pos]=y;
u[pos]=y;
d[pos]=d[y];
u[d[y]]=pos;
d[y]=pos;
if(h[x]<0)h[x]=l[pos]=r[pos]=pos;
else
{
l[pos]=l[h[x]];
r[pos]=h[x];
r[l[h[x]]]=pos;
l[h[x]]=pos;
}
pos++;
}
inline void rm(R int y)
{
l[r[y]]=l[y];r[l[y]]=r[y];
for(R int i=d[y]; i!=y; i=d[i])
for(R int j=r[i]; j!=i; j=r[j])
{
d[u[j]]=d[j];
u[d[j]]=u[j];
s[col[j]]--;
}
}
inline void rv(R int y)
{
for(R int i=u[y]; i!=y; i=u[i])
for(R int j=l[i]; j!=i; j=l[j])
{
d[u[j]]=j;
u[d[j]]=j;
s[col[j]]++;
}
r[l[y]]=y;l[r[y]]=y;
}
bool dance(R int dep)
{
if(!r[0])
{
for(R int i=0; i<dep; i++)
printf("%lld%c",ansk[i]," \n"[i==dep-1]);
return 1;
}
R int y=r[0];
for(R int i=r[0]; i; i=r[i])
if(s[i]<s[y])y=i;
rm(y);
for(R int i=d[y]; i!=y; i=d[i])
{
ansk[dep]=row[i];
for(R int j=r[i]; j!=i; j=r[j])rm(col[j]);
if(dance(dep+1))return 1;
for(R int j=l[i]; j!=i; j=l[j])rv(col[j]);
}
rv(y);
return 0;
}
signed main()
{
read(n);read(m);
init();
for(R int i=1; i<=n; i++)
for(R int j=1,t; j<=m; j++)
{
read(t);
if(t)link(i,j);
}
if(!dance(0))puts("No Solution!");
return 0;
}
二、例题
如果您看到这里,还是很明白的话,那么我们来讲个例题
题目链接:SP13980 SUDOGOB - Sudoku goblin
题意:给定一个 $9 \times 9$ 的数独,输出可填的方案数,多组数据
选择这个例题当然不是让你写暴搜的
首先考虑决策
每个格子上填数字,至多有 $9\times 9\times 9 = 729$ 种决策
再考虑限制
- 每个点只能填一个数
- 每行一个数只能填一次
- 每列一个数只能填一次
- 每个九宫格一个数只能填一次
限制数为 $9\times 9 \times 4 = 324$
对于精准覆盖问题,我们本质上是在选择若干决策,使其恰好满足所有限制条件
因此这题可以用 DLX 求解
那么 MAXN
只要开到 729*324 就行了(注:不过开大点保险)
这题唯一的细节是插入结点的行、列,还是很简单的题目
代码如下
#include <bits/stdc++.h>
using namespace std;
#define int long long
#define R register
#define MAXN (729*324+6666)
int Q;
bool suc;
int a[25][25];
int ans[25][25],Ans;
int h[MAXN],l[MAXN],r[MAXN],u[MAXN],d[MAXN];
int row[MAXN],col[MAXN],pos,s[MAXN],ansk[MAXN];
template<typename T>inline void read(R T &k)
{
R char ch=getchar();R T x=0,f=1;
while(!isdigit(ch)){if(ch=='-')f=-1;ch=getchar();}
while(isdigit(ch)){x=(x<<3)+(x<<1)+(ch^48);ch=getchar();}
k=x*f;
}
inline void init()
{
int m=324;Ans=0;suc=0;
for(R int i=0; i<=m; i++)
{
l[i]=i-1;r[i]=i+1;
u[i]=d[i]=i;
}
l[0]=m;r[m]=0;
memset(s,0,sizeof(s));
memset(h,-1,sizeof(h));
pos=m+1;
}
inline void link(R int x,R int y)
{
s[y]++;
row[pos]=x;col[pos]=y;
u[pos]=y;
d[pos]=d[y];
u[d[y]]=pos;
d[y]=pos;
if(h[x]<0)h[x]=l[pos]=r[pos]=pos;
else
{
l[pos]=l[h[x]];
r[pos]=h[x];
r[l[h[x]]]=pos;
l[h[x]]=pos;
}
++pos;
}
inline void rm(R int y)
{
r[l[y]]=r[y];l[r[y]]=l[y];
for(R int i=d[y]; i!=y; i=d[i])
for(R int j=r[i]; j!=i; j=r[j])
{
u[d[j]]=u[j];
d[u[j]]=d[j];
s[col[j]]--;
}
}
inline void rv(R int y)
{
for(R int i=u[y]; i!=y; i=u[i])
for(R int j=l[i]; j!=i; j=l[j])
{
u[d[j]]=j;
d[u[j]]=j;
s[col[j]]++;
}
l[r[y]]=y;r[l[y]]=y;
}
void dance(R int dep)
{
if(!r[0])
{
Ans++;
if(Ans>1)suc=1;
for(R int i=0; i<dep&&!suc; i++)
{
R int x=(ansk[i]-1)/9/9+1;
R int y=(ansk[i]-1)/9%9+1;
R int z=(ansk[i]-1)%9+1;
ans[x][y]=z;
}
return;
}
R int y=r[0];
for(R int i=r[0]; i!=0; i=r[i])
if(s[i]<s[y])y=i;
rm(y);
for(R int i=d[y]; i!=y; i=d[i])
{
ansk[dep]=row[i];
for(R int j=r[i]; j!=i; j=r[j])rm(col[j]);
dance(dep+1);
for(R int j=l[i]; j!=i; j=l[j])rv(col[j]);
}
rv(y);
}
signed main()
{
read(Q);
while(Q--)
{
init();
for(R int i=1; i<=9; i++)
for(R int j=1; j<=9; j++)
{
read(a[i][j]);
R int &t=a[i][j];
for(R int k=1; k<=9; k++)
{
if(t!=k&&t!=0)continue;
R int o=(i-1)*9*9+(j-1)*9+k;
R int c1=81*0 + (i-1)*9+(j-1)+1;
R int c2=81*1 + (i-1)*9+k;
R int c3=81*2 + (j-1)*9+k;
R int c4=81*3 + ((i-1)/3*3+(j-1)/3)*9+k;
link(o,c1);link(o,c2);link(o,c3);link(o,c4);
}
}
dance(0);
if(!Ans){puts("0");continue;}
printf("%lld\n",Ans);
for(R int i=1; i<=9&&!suc; i++)
for(R int j=1; j<=9; j++)
printf("%lld%c",ans[i][j]," \n"[j==9]);
}
return 0;
}
当然,如果您有兴趣的话,可以去做下这道题
这道题就是上一题的加强版,其实没什么区别,如果您理解了例题的话,这题就很简单了