KM 算法求解二分图最小权完美匹配

问题描述

给定一个二分图,其左侧有 $X$个节点,右侧有 $Y$个节点,左侧第 $i$个和右侧第 $j$个节点之间有一条权值为 $w_{ij}$的边。

现求这个二分图权值和最大的一个完美匹配。

算法复杂度

最优为 $O(n^3)$

先决条件

了解增广路有关定理、性质

熟练掌握二分图的最大匹配、一般图的最大流算法

指导思想

习近平新时代中国特色社会主义思想

KM 算法的精髓在于,它为每个节点设置了一个叫做顶标的量 $l_i$,并且对于每条边 $w_{ij}$,无时无刻不满足 $l_i+l_j\geq w_{ij}$。这个需要满足的条件我们姑且叫做“ 先决条件”。最佳匹配的边一定满足 $w_{ij}=l_i+l_j$。因此如果通过调整顶标,使得所有节点的顶标和尽可能的小,且仍然满足先决条件,最佳匹配的权值一定就是最后所有节点的顶标之和。

先给出一个重要的定义:

  • 相等子图:所有 $w_{ij}=l_i+l_j$的边和其端点构成的子图

我们调整顶标的目的就是扩大相等子图。如果相等子图足够大,以至于其最大匹配就是原图的最大匹配,那么我们就找到了答案,因为相等子图的任意一个最大匹配都是原图的最佳匹配。

如果现在的相等子图不够大,我们就需要调整顶标。直到相等子图足够大。

算法流程

对于每个 $X$部的节点,我们依次为它寻找相等子图中的一个匹配点。也就是说,当算法执行到 $X$部第 $i$个点,则前 $i-1$个点一定都在相等子图中,且现在的相等子图存在一个包含前 $i-1$个点的最大匹配。

从第 $i$个点出发,在相等子图中寻找一条增广路。

  • 如果找到了增广路,那么算法继续处理下一个点
  • 如果没有找到增广路,则通过调整顶标扩大相等子图,然后再尝试寻找第 $i$个点出发的增广路,没找到再调整顶标,不停循环。

如何调整顶标呢?

由于第 $i$个点出发没有任何一条” 增广路”,我们肯定找到了许多” 交错路”,而将这些交错路连起来可以构成一棵交错树 (类似匈牙利算法,而且是相等子图的一个生成树)。如果将交错树中 $X$部的所有点的 $l$减去 $d$,$Y$部的加上 $d$,并且这个 $d$在满足先决条件的前提下尽可能大,则相等子图内的任意一条边都还在相等子图内,而且会有新的边加入到相等子图内。因为:

  • 一条边 $(u,v)$的两个端点都在相等子图中:$(l_u+d)+(l_v-d)=w_{uv}$仍然成立
  • 一条边 $(u,v)$只有 $u$在相等子图中:$(l_u+d)+l_v\geq w_{uv}$仍然成立
  • 一条边 $(u,v)$只有 $v$在相等子图中:$l_u+(l_v-d)\geq w_{uv}$需要成立,则 $d\leq l_u+l_v-w_{uv}$,并且如果 $d=l_u+l_v-w_{uv}$,这条边会加入到相等子图内。
  • 一条边 $(u,v)$两个端点都不在相等子图中:$l_u+l_v\geq w_{uv}$仍然成立

我们可以用以修改顶标的最大的 $d$就是上面第三类边所确定的最小的 $d$。

至此,完全按照上面的指导来实现,算法的复杂度是 $O(n^4)$的。如果用 $bfs$等方法去实现,算法的复杂度可以降至 $O(n^3)$

一个更高的角度

在学习 KM 算法的时候,我遇到了一个疑问:为什么我们在算法过程中始终在最小化顶标,但是最后得到的却是一个最大化的答案呢?

这个形式的最优化问题与线性规划中的强对偶定理有着异曲同工之妙。不难猜测,二分图最大权完美匹配也可以规约到某个线性规划问题上。

  • 强对偶定理:$min\{c^Tx|Ax\geq b\}=max\{b^Ty|A^Ty\leq c\}$

现在我们给矩阵 $A$,向量 $b,c$赋予二分图匹配中的某些内涵,便可以将这两个问题等价。

$A$是一个 $m\times n$的矩阵,其中 $A_{ij}=1$当且仅当 $j$为第 $i$条边的一个端点

$b$是一个 $m$维向量,其中 $b_{i}$为第 $i$条边的边权

$c$是一个 $n$维向量,其中 $c_i$为 $1$

$x$是一个需要求的 $n$维向量,其中 $x_i$为点 $x$的顶标

$y$是一个需要求的 $m$维向量,其中 $x_i$代表第 $i$条边是否选择

那么 $Ax\geq b$描述的是顶标之和大于等于边权,$A^Ty\leq c$描述的是每个点只能与一条边相连。$c^Tx$意为最小化顶标之和,$b^Ty$意为最大化边权之和。这样岂不妙哉?

代码实现

按照本文的指导,我们可以实现 $O(n^4)$的 dfs 做法。在实现中,引入一个新的数组 slack,代表右侧每一条不属于联通子图的边所决定的最小的 $d$

int n, m, mat[MX][MX];      //两边的点数,边矩阵
int lft[MX], rgt[MX];       //左右的顶标
int slk[MX];                //slack 数组
int mch[MX], mlf[MX];       //右左侧节点的匹配点
int vlf[MX], vrt[MX];       //左右侧节点是否访问过
int cl, cr;

bool dfs(int x)
{
    vlf[x] = 1;
    for(int y=1; y<=m; y++)
    {
        if(!vrt[y])
        {
            int t = lft[x] + rgt[y] - mat[x][y];
            if(t == 0)
            {
                vrt[y] = 1;
                if(!mch[y] || dfs(mch[y]))
                {
                    mch[y] = x;
                    mlf[x] = y;
                    return 1;
                }
            }
            else slk[y] = min(slk[y], t);
        }
    }
    return 0;
}

void KM()
{
    for(int x=1; x<=cl; x++)
        for(int y=1; y<=cr; y++)
            lft[x] = max(lft[x], mat[x][y]);
    for(int x=1; x<=cl; x++)
    {
        memset(slk, 0x3f, sizeof(slk));
        while(1)
        {
            memset(vlf, 0, sizeof(vlf));
            memset(vrt, 0, sizeof(vrt));
            if(dfs(x)) break;
            int d = 123123123;
            for(int y=1; y<=cr; y++)
                if(!vrt[y])
                    d = min(d, slk[y]);
            for(int y=1; y<=cl; y++)
                if(vlf[y]) lft[y] -= d;
            for(int y=1; y<=cr; y++)
                if(vrt[y]) rgt[y] += d;
                else slk[y] -= d;
        }
    }
}

将这种做法改为 bfs,即可做到 $O(n^3)$。bfs 算法的原理是以 bfs 的形式构建一棵交错路树。每次选择不在树上 (即相等子图) 的限制的 $d$最小的一个节点加入交错路树。找到一条增广路后在交错路树上不停跳二级祖先翻转匹配边和非匹配边。

int vis[MX];            //右侧节点是否访问过
int slk[MX];            //slack 数组
int mch[MX];            //右侧、左侧节点的匹配节点
int pre[MX];            //交错路树上的二级祖先
int lbx[MX], lby[MX];   //左右侧顶标
int mat[MX][MX];        //边矩阵
int n1, n2, n;          //两边的点数,两边点数的最大值

void KM()
{
    for(int i=1; i<=n; i++)
        for(int j=1; j<=n; j++)
            lbx[i] = max(lbx[i], mat[i][j]);
    for(int i=1; i<=n; i++)
    {
        int py, p, d, x;
        for(int j=1; j<=n; j++) slk[j] = +oo, vis[j] = 0;
        for(mch[py=0]=i; mch[py]; py=p)
        {
            vis[py] = 1; d = +oo; x = mch[py];
            for(int y=1; y<=n; y++)
            {
                if(!vis[y])
                {
                    if(lbx[x]+lby[y]-mat[x][y] < slk[y]) slk[y] = lbx[x]+lby[y]-mat[x][y], pre[y] = py;
                    if(slk[y] < d) d = slk[y], p = y;
                }
            }
            for(int y=0; y<=n; y++)
                if(vis[y]) lbx[mch[y]] -= d, lby[y] += d;
                else slk[y] -= d;
        }
        for(; py; py=pre[py]) mch[py] = mch[pre[py]];
    }
}
分类: 文章

0 条评论

发表评论

电子邮件地址不会被公开。 必填项已用*标注