此生此世做过的第一恶心的树型 DP


题意:

很简单:就是要求一棵树上相距距离不超过 k 的点有几对。多组数据,n<=10000

分析:

看到不能 $O(n^2)$赶紧想有没有带 $log$的算法。没想到。于是百度。网上说是树的点分治,大概看懂了。

算法思想:

对于一棵树,其中任何一条路径都有:要么经过根,要么不经过根。于是我们把这棵树的根节点揪出来,寻找所有经过根的合法路径,然后把这个根删掉,得到很多子树,对子树再进行上述计算。由于新的子树上的路径绝对不经过旧的大树的根,所以路径不会有重复。又由于很显然地每一条路径都只能属于一棵子树也必然属于一棵树,所以这种方法是正确的。经过合理的选择根节点,我们只计算了 $logn$层,所以这种方法的复杂度为 $O(logn)*P(n)$,$P(n)$的复杂度取决于你寻找路径的算法的好坏。

下面我们讨论如何寻找经过一个树的根的路径总数。
首先,我们可以获得这棵树中所有节点到根的距离。那么所有和小于 k 的点对有可能构成合法路径的两端。为什么只是有可能呢?因为两个点有可能出现在这棵树的同一个子树上,他们构成的路径不经过根。
设 $A$为满足 $dis[x]+dis[y]\leq k$的 $(x,y)$的数量,$B$为满足 $dis[x]+dis[y]\leq k$且 $x,y$所在子树相同的 $(x,y)$数量,那么这棵树中经过根节点的路径条数就是 $A-B$。
我们可以 $O(n)$求出 $dis[i]$,$O(nlogn)$将 $dis[]$排序,再 $O(n)$利用单调性找出每一个 x 所对应的 $dis$最大的 $y$(当 $dis[x]$增加时,$dis[y]$不会增加, 呈单调递减),也就是 $A$。对于这棵树的每一棵子树,我们又可以 $O(n)$求出所有 $B$。然后 $A-B$就是答案。这样的理想时间复杂度为 $O(nlognlogn)$

还需要注意几点:

1. 单纯的将子树分治是不可行的,因为出题人会`专门把树扯成一条链让你 $O(n^2logn)$地吃屎。于是我们需要专门用一个 DP 找出树的重心不断进行分治,而复杂度为 $O(nlognlogn)$。
2. 这一道题很无耻的卡 memset(), 删掉 memset 你就奇迹般地从 TLE 变成了 547ms。此题还会与分段式桶排发生反应,我的比 sort 在 1e7 下快 10 倍的桶排竟然会 TLE, 而 sort 奇迹般 AC。考试时如果遇到这种让我做一下午一晚上的题,我一定会说: 打暴力。如果非要给这个暴力加上一个期限,我希望是+1s。(+1s 就可以用 memset)

#include <cstdio>
#include <iostream>
#include <algorithm>
#include <cstring>
#include <vector>

#define MX 10010

using namespace std;

int fst[MX],nxt[MX*2],v[MX*2],w[MX*2],lnum;
int vis[MX];
int n,k;
int mx[MX],sum[MX];

inline void addeg(int nu,int nv,int nw)
{
    lnum++;
    nxt[lnum]=fst[nu];
    fst[nu]=lnum;
    v[lnum]=nv;
    w[lnum]=nw;
}

inline void init()
{
    lnum=-1;
    memset(fst,-1,sizeof(fst));
}

inline void input()
{
    int a,b,c;
    for(int i=1;i<n;i++)
    {
        scanf("%d%d%d",&a,&b,&c);
        addeg(a,b,c);
        addeg(b,a,c);
    }
}

int root,sz;
void _getroot(int x,int fa)
{
    sum[x]=mx[x]=0;
    for(int i=fst[x];i!=-1;i=nxt[i])
    {
        if(v[i]==fa||vis[v[i]])continue;
        _getroot(v[i],x);
        sum[x]+=sum[v[i]]+1;
        mx[x]=max(mx[x],sum[v[i]]+1);
    }
    mx[x]=max(mx[x],sz-sum[x]-1);
    if(mx[x]<mx[root])root=x;
}
inline void getroot(int x,int fa)
{
    root=0;
    mx[root]=99999999;
    _getroot(x,fa);
}

int q[MX],dp[MX];
int dis[MX];
inline void getdep(int pred,int x,int fa)
{
    int h=1,t=1,now;
    memset(dp,0xff,sizeof(dp));
    dis[0]=0;
    dp[x]=pred;
    dis[++dis[0]]=pred;
    q[h]=x;
    while(h>=t)
    {
        now=q[t++];
        for(int i=fst[now];i!=-1;i=nxt[i])
        {
            if(v[i]==fa||vis[v[i]]||dp[v[i]]!=-1)continue;
            dp[v[i]]=dp[now]+w[i];
            dis[++dis[0]]=dp[v[i]];
            q[++h]=v[i];
        }
    }
}

int tdis[MX];
int sch(int x,int fa)
{
    int a=0,b=0;
    vis[x]=1;
    tdis[0]=0;
    for(int i=fst[x];i!=-1;i=nxt[i])
    {
        if(v[i]==fa||vis[v[i]])continue;
        getdep(w[i],v[i],x);
        sort(dis+1,dis+dis[0]+1);
        for(int j=1;j<=dis[0];j++)tdis[++tdis[0]]=dis[j];
        for(int j=1,c=dis[0];j<=dis[0];j++)
        {
            while(dis[c]+dis[j]>k&&c>=1)c--;
            if(c<=j)break;
            b+=c-j;
        }
    }
    sort(tdis+1,tdis+tdis[0]+1);
    for(int j=tdis[0];j>=1;j--)if(tdis[j]<=k){a+=j;break;}
    for(int i=1,j=tdis[0];i<=tdis[0];i++)
    {
        while(tdis[j]+tdis[i]>k&&j>=1)j--;
        if(j<=i)break;
        a+=j-i;
    }
    a-=b;
    for(int i=fst[x];i!=-1;i=nxt[i])
    {
        if(v[i]==fa||vis[v[i]])continue;
        sz=sum[v[i]]+1;
        getroot(v[i],x);
        a+=sch(root,x);
    }
    return a;
}

int main()
{
    while(~scanf("%d%d",&n,&k))
    {
        memset(vis,0,sizeof(vis));
        if(n==0&&k==0)break;
        init();
        input();
        sz=n;
        getroot(1,0);
        printf("%d\n",sch(root,0));
    }
    return 0;
}

分类: 文章

0 条评论

发表评论

电子邮件地址不会被公开。 必填项已用*标注