论如何使用 K-D tree 卡过”HNOI2015 接水果”

思路:

首先,我们会比较自然地想到将盘子和水果的关系抽象成整棵树的 DFS 序上的点对。
如果一条树上路径[a,b] 是[c,d] 的子路径,那么 c 和 d 一定会在与 a,b 有关的特定的区间内。

因此求多少已知路径包含给定路径将非常容易,因为这个问题可以转化为有多少点对 (x,y) 满足 $x\in [l_1,r_1],y\in [l_2,r_2]$。这个问题可以由 K-D 树或者扫描线或者什么诡异的方法解决,如用 k-d 树,我们只需要实现区间加和单点修改。

接着我们在此基础上,考虑如何满足 “第 k 大” 这个限制。

如果我们二分每这个第 k 大值,该问题又可以转化为判定性问题:求有多少小于该值的满足条件的点对。

所以我们对每个水果,同时二分答案,按二分的第 k 大值排序。盘子也按权值排序。将水果和盘子按权值从小到达,把盘子插入到 k-d 树中,用水果询问。这样我们就知道每个水果路径包含了多少比他小的盘子路径,利用这个数量继续二分即可。

这是一种类似整体二分的方法,但相比整体二分,这种方法更像二分答案。

时间复杂度 $O(n\sqrt{n}log{n})$

卡·常数 的诅咒

然而非常激动地想到这种做法,非常激动地码完,非常激动地过了样例,却只有 20 分,剩下全部 TLE。开 O2 后还是只有 70 分。
找到本地数据测试后,发现有一些精心构造的数据需要 5 秒才能通过,于是理所当然地在洛谷上 TLE 了。

于是我开始优化代码

  • 问题 1:同时二分时有些的 l,r 区间已经重合,没有必要处理这些询问。
  • 解决 1:特判掉这些没用的询问-0.07s
  • 问题 2:二分的区间太大,没有必要二分 0 到 1e9 的所有数。
  • 解决 2:将所有可能作为答案的数离散化,二分这些数的编号-0.5s
  • 问题 3:gprof 后发现各种运算符重载和 min,max 等比较耗时
  • 解决 3:运算符用 const&重载,手写 min,max-0.3s
  • 问题 4:读入也许比较耗时
  • 解决 4:读入优化-0.04s
  • 问题 5:最影响整体复杂度的是 kd 树的区间加
  • 解决 5:对 kd 树实行标记永久化,榨干区间加的常数-0.6s
  • 问题 6:二分答案时会多次询问同样的矩形,每次在 kd 树中递归非常浪费时间
  • 解决 6:预处理所有可能用到的矩形在 kd 树中用到了哪些节点,以及对这些节点实行了什么操作,真正区间加时直接访问这些节点-0.8s

至此,我将 5s 的代码优化成了 2.7s,于是我就仰天长啸,拍案而起,但最终还是
觉得对得起浪费掉的这两个晚上的。

(由于各种原因,代码非常长)

#pragma GCC optimize("Ofast")
#pragma GCC target("sse3","sse2","sse")
#pragma GCC target("avx","sse4","sse4.1","sse4.2","ssse3")
#pragma GCC target("f16c")
#pragma GCC optimize("inline","fast-math","unroll-loops","no-stack-protector")
#pragma GCC diagnostic error "-fwhole-program"
#pragma GCC diagnostic error "-fcse-skip-blocks"
#pragma GCC diagnostic error "-funsafe-loop-optimizations"
#pragma GCC diagnostic error "-std=c++14"
#include <algorithm>
#include <iostream>
#include <cstring>
#include <cstdio>
#include <ctime>
#define MX 40004

int D;
int real[MX],rnum;
int gfake(int a)
{
    return std::lower_bound(real+1,real+rnum+1,a)-real;
}
void lsh()
{
    std::sort(real+1,real+rnum+1);
    rnum=std::unique(real+1,real+rnum+1)-real-1;
}
struct node
{
    int p[2],mn[2],mx[2],s[2],cnt,laz;
    bool operator < (const node& t)const{return p[D]!=t.p[D]?p[D]<t.p[D]:p[D^1]<t.p[D^1];}
}tseq[MX*2];
struct edge
{
    int u,v,w;
}e[MX*2],plt[MX],qur[MX];
struct bina
{
    int t,u,v,k,l,r,m,a,id;     //type; u; v; kth; l; r; mid; ans; id;
    bool operator < (const bina& tmp)const{return real[m]==real[tmp.m]?t<tmp.t:real[m]<real[tmp.m];}
}bin[MX*2];
int n,p,q;
int fst[MX],nxt[MX*2],lnum;
int dfn[MX],rak[MX],mst[MX][2],dcnt;
int out[MX],rout[MX][1000],rtyp[MX][1000],rsiz[MX];

int mmin(const int &a,const int &b){return a>b?b:a;}
int mmax(const int &a,const int &b){return a<b?b:a;}

struct KDT
{
    node tre[MX*2];
    int cnt,rot;
    void init(){for(int i=1;i<=cnt;++i)tre[i].cnt=tre[i].laz=0;}
    void upd(int a)
    {
        tre[a].mn[0]=tre[a].mx[0]=tre[a].p[0];
        if(tre[a].s[0])
        {
            tre[a].mn[0]=mmin(tre[a].mn[0],tre[tre[a].s[0]].mn[0]),
            tre[a].mx[0]=mmax(tre[a].mx[0],tre[tre[a].s[0]].mx[0]);
        }
        if(tre[a].s[1])
        {
            tre[a].mn[0]=mmin(tre[a].mn[0],tre[tre[a].s[1]].mn[0]),
            tre[a].mx[0]=mmax(tre[a].mx[0],tre[tre[a].s[1]].mx[0]);
        }
        tre[a].mn[1]=tre[a].mx[1]=tre[a].p[1];
        if(tre[a].s[0])
        {
            tre[a].mn[1]=mmin(tre[a].mn[1],tre[tre[a].s[0]].mn[1]),
            tre[a].mx[1]=mmax(tre[a].mx[1],tre[tre[a].s[0]].mx[1]);
        }
        if(tre[a].s[1])
        {
            tre[a].mn[1]=mmin(tre[a].mn[1],tre[tre[a].s[1]].mn[1]),
            tre[a].mx[1]=mmax(tre[a].mx[1],tre[tre[a].s[1]].mx[1]);
        }
    }
    void psd(int a)
    {
        if(!tre[a].laz)return;
        tre[tre[a].s[0]].cnt+=tre[a].laz,
        tre[tre[a].s[1]].cnt+=tre[a].laz,
        tre[tre[a].s[0]].laz+=tre[a].laz,
        tre[tre[a].s[1]].laz+=tre[a].laz,
        tre[a].laz=0;
    }
    bool itsct(const node& a,const node& b)
    {
        return (*a.mn<=*b.mx&&*b.mn<=*a.mx&&
                *(a.mn+1)<=*(b.mx+1)&&*(b.mn+1)<=*(a.mx+1));
    }
    void build(int &a,int l,int r,int d)
    {
        if(l>r)return;
        D=d;
        int mid=(l+r)>>1;
        std::nth_element(tseq+l,tseq+mid,tseq+r+1);
        tre[a=++cnt]=tseq[mid];
        build(tre[a].s[0],l,mid-1,d^1);
        build(tre[a].s[1],mid+1,r,d^1);
        upd(a);
    }
    void pre_add(int a,node &t,int tar)
    {
        ++rsiz[tar];
        rout[tar][rsiz[tar]]=a;
        if( *t.mn<=*tre[a].mn&&*tre[a].mx<=*t.mx&&
            *(t.mn+1)<=*(tre[a].mn+1)&&*(tre[a].mx+1)<=*(t.mx+1))
            {rtyp[tar][rsiz[tar]]|=1;return;}
        if( *t.mn<=*tre[a].p&&*tre[a].p<=*t.mx&&
            *(t.mn+1)<=*(tre[a].p+1)&&*(tre[a].p+1)<=*(t.mx+1))
            {rtyp[tar][rsiz[tar]]|=2;}
        int ls=*tre[a].s,rs=*(tre[a].s+1);
        if( ls&&itsct(*(tre+ls),t))pre_add(ls,t,tar);
        if( rs&&itsct(*(tre+rs),t))pre_add(rs,t,tar);
    }
    void add(int id)
    {
        for(int i=1;i<=rsiz[id];++i)
        {
            if(rtyp[id][i]&1)++tre[rout[id][i]].laz;
            if(rtyp[id][i]&2)++tre[rout[id][i]].cnt;
        }
    }
    int query(int a,node &t,int d)
    {
        if(!a)return 0;
        D=d;
        if(t.p[0]==tre[a].p[0]&&t.p[1]==tre[a].p[1])return tre[a].cnt+tre[a].laz;
        else if(t<tre[a])return query(tre[a].s[0],t,d^1)+tre[a].laz;
        else return query(tre[a].s[1],t,d^1)+tre[a].laz;
    }
}ktr;

struct LCA
{
    int fa[17][MX],dep[MX];
    void init()
    {
        for(int i=1;i<=16;++i)
            for(int j=1;j<=n;++j)
                fa[i][j]=fa[i-1][fa[i-1][j]];
    }
    int lca(int a,int b)
    {
        if(dep[a]<dep[b])a^=b,b^=a,a^=b;
        for(int i=16;i>=0;--i)
            if(dep[fa[i][a]]>=dep[b])
                a=fa[i][a];
        if(a==b)return a;
        for(int i=16;i>=0;--i)
            if(fa[i][a]!=fa[i][b])
                a=fa[i][a],b=fa[i][b];
        return fa[0][a];
    }
    int son(int a,int b)
    {
        for(int i=16;i>=0;--i)
            if(dep[fa[i][b]]>dep[a])
                b=fa[i][b];
        return b;
    }
}lca;

void addeg(int nu,int nv)
{
    nxt[++lnum]=fst[nu];
    fst[nu]=lnum;
    e[lnum]=(edge){nu,nv};
}

int read()
{
    int x=0;char ch=getchar();
    while(!isdigit(ch))ch=getchar();
    while(isdigit(ch))x=x*10+ch-'0',ch=getchar();
    return x;
}

void input()
{
    int a,b;
    memset(fst,0xff,sizeof(fst)),lnum=-1;
    n=read(),p=read(),q=read();
    for(int i=1;i<n;++i)
    {
        a=read(),b=read();
        addeg(a,b);
        addeg(b,a);
    }
    for(int i=1;i<=p;++i)plt[i].u=read(),plt[i].v=read(),plt[i].w=read();
    for(int i=1;i<=q;++i)qur[i].u=read(),qur[i].v=read(),qur[i].w=read();
}

void dfs(int x,int f,int d)
{
    lca.fa[0][x]=f;
    lca.dep[x]=d;
    dfn[x]=++dcnt;
    rak[dcnt]=x;
    mst[x][0]=dcnt;
    for(int i=fst[x];i!=-1;i=nxt[i])
        if(e[i].v!=f)
            dfs(e[i].v,x,d+1);
    mst[x][1]=dcnt;
}

void work()
{
    node t;int tmp=0,f;
    dfs(1,0,1);
    lca.init();
    for(int i=1;i<=p;++i)real[++rnum]=plt[i].w;
    lsh();
    for(int i=1;i<=p;++i)
    {
        int u=plt[i].u,v=plt[i].v;
        if(dfn[u]>dfn[v])u^=v,v^=u,u^=v;
        bin[i]=(bina){0,u,v,0,0,0,gfake(plt[i].w),0,i};
    }
    for(int i=1;i<=q;++i)
    {
        int u=qur[i].u,v=qur[i].v;
        if(dfn[u]>dfn[v])u^=v,v^=u,u^=v;
        bin[p+i]=(bina){1,u,v,qur[i].w,1,rnum,(rnum+1)>>1,0,i};
    }
    for(int i=1;i<=p+q;++i)
        if(bin[i].t)
        {
            int du=dfn[bin[i].u],dv=dfn[bin[i].v];
            tseq[++tmp]=(node){du,dv,du,dv,du,dv,0,0,0,0};
        }
    ktr.build(ktr.rot,1,tmp,0);
    for(int j=1;j<=p+q;j++)
        if(!bin[j].t)
        {
            int u=bin[j].u,v=bin[j].v,g=lca.lca(u,v);
            if(g!=u&&g!=v)
            {
                t=(node){0,0,mst[u][0],mst[v][0],mst[u][1],mst[v][1],0,0,0,0};
                if(t.mn[0]<=t.mx[0]&&t.mn[1]<=t.mx[1])ktr.pre_add(ktr.rot,t,bin[j].id);
            }
            else
            {
                int x=lca.son(u,v);
                t=(node){0,0,1,mst[v][0],mst[x][0]-1,mst[v][1],0,0,0,0};
                if(t.mn[0]<=t.mx[0]&&t.mn[1]<=t.mx[1])ktr.pre_add(ktr.rot,t,bin[j].id);
                t=(node){0,0,mst[v][0],mst[x][1]+1,mst[v][1],n,0,0,0,0};
                if(t.mn[0]<=t.mx[0]&&t.mn[1]<=t.mx[1])ktr.pre_add(ktr.rot,t,bin[j].id);
            }
        }
    for(int i=1;i<=17;++i)
    {
        f=0;
        ktr.init();
        std::sort(bin+1,bin+p+q+1);
        for(int j=1;j<=p+q;++j)
        {
            if(!bin[j].t)
            {
                int u=bin[j].u,v=bin[j].v,g=lca.lca(u,v);
                if(g!=u&&g!=v)ktr.add(bin[j].id);
                else ktr.add(bin[j].id);
            }
            else if(bin[j].l<bin[j].r)
            {
                f=1;
                t=(node){dfn[bin[j].u],dfn[bin[j].v],0,0,0,0,0,0,0,0};
                bin[j].a=ktr.query(ktr.rot,t,0);
                if(bin[j].a>=bin[j].k)bin[j].r=bin[j].m;
                else bin[j].l=bin[j].m+1;
                bin[j].m=(bin[j].l+bin[j].r)>>1;
            }
        }
        if(!f)break;
    }
    for(int i=1;i<=p+q;++i)if(bin[i].t)out[bin[i].id]=bin[i].m;
    for(int i=1;i<=q;++i)printf("%d\n",real[out[i]]);
}

int main()
{

    input();
    work();
    return 0;
}
分类: 文章

0 条评论

发表回复

Avatar placeholder

您的电子邮箱地址不会被公开。 必填项已用*标注