动态区间第 k 大的一种 $O(nlog^2n)$的树套树解法

题意:

给定一个序列,有 “将一个数修改为另一个” 的操作,和询问 “[l,r] 区间内的第 k 小值是几” 的询问。要求在 1s 内对一个长 $10^4$的序列完成 $10^4$组修改和询问。

思路:

1.
我最先想到的树套树的思路是:线段树套 Splay。

线段树的每一个区间 [l,r] 定义为原序列的 [l,r] 中的数字组成的排序 Splay。这样,对于每个修改操作,我们对 $logn$棵包含了这个位置的 Splay 删除一个数,添加一个数即可。对于每个询问操作 [l,r],我们需要二分答案 x,再在 logn 棵恰好组成 [l,r] 区间的 Splay 中查询 x 的排名。

由于线段树有 logn 层,每层的所有 Splay 合并后恰为一个完整的原序列,所以空间复杂度为 O(nlogn) 的。

经过上面的分析可知:修改复杂度为 $O(nlog^2n)$,查询为 $O(nlog^3n)$。实践证明可以通过所有测试点。洛谷最慢测试点 340ms.
2.
第二种思路,即本文终点介绍的思路,两种操作都是 $O(nlog^2n)$的,只是还需要对所有出现过的数字离散化。

刚才,我们的线段树对应的是区间,Splay 对应的是值。如果互换一下呢?

现在线段树是建立在离散化后的实数域上的。线段树区间 [l,r] 定义为原序列中值属于 [l,r] 的所有值的下标的排序 Splay。即:将值属于 [l,r] 的所有位置提取为一个新的序列,保存在这个 Splay 里。

对于修改操作,我们依旧修改 $logn$棵 Splay。如果我们将 a[i] 修改为 b,则将所有的线段树节点 $[l,r](l\leq a[i]\leq r)$中的 Splay 中删除 i。同时向所有的线段树节点 $[l,r](l\leq b\leq r)$中的 Splay 添加 i。故一次修改操作的复杂度为 $O(log^2n)$。

对于查询操作,注意到这棵线段树是支持前缀和的。即实数区间 [l,r] 内的数 x$(a\leq x \leq b)$的个数等于 [1,r] 内的个数减 [1,l-1] 内的个数(这里的 x 就是原序列的下标)。所以我们只需要在树上二分即可。如果现在我们确定了 k 小值一定在实数区间 [a,b] 内,那么如果 [a,(a+b)/2] 内的 Splay 中满足上述条件的节点小于 k,则 k 小值一定在 [(a+b)/2+1,b] 内,反之同理。这样的查询操作只需要对 logn 个线段树节点进行查询,故一次询问的时间复杂度为 $O(log^2 n)$。

实践证明可以通过所有测试点,洛谷最慢测试点 92ms.

细节问题

这种方法虽然省去了一个 log,但是其代码量却比之前多了一个 log。

以下是一些需要注意的地方:

  • 离散化不但要离散原序列的值,也要离散修改出的值。
  • 在最初插入节点时最好使用类似线段树建树一样的归并方式,这样可以降低常数。实践证明不这样洛谷最慢点为 296ms.
  • 最好不要用这种方法因为我打了 200 行。
/*
A data structure used to maintain interval kth number
With splays in a segment tree
*/
#include <algorithm>
#include <iostream>
#include <cstring>
#include <cstdio>
#include <vector>
#include <map>

#define MX 61005
#define mid ((l+r)>>1)
#define ls (a<<1)
#define rs (a<<1|1)

using namespace std;

typedef struct splnode
{
    int x,f,siz,s[2];
}node;
typedef struct tqeury
{
    int l,r,x,t;
}query;
query qur[MX];
node t[MX*18];
int seq[MX],real[MX];
vector<int>ord[MX];
map<int,int>mp;
map<int,int>::iterator itr;
int tnum;
int n,m;
int bar[MX],bnum;

typedef struct trenode
{
    int root,l,r;
    inline int pos(int a){return t[t[a].f].s[1]==a;}
    inline void upd(int a){t[a].siz=t[t[a].s[0]].siz+t[t[a].s[1]].siz+1;}
    inline void rot(int a)
    {
        int f=t[a].f,g=t[f].f,p=pos(a),q=pos(f);
        t[f].s[p]=t[a].s[!p],t[a].s[!p]=f,t[f].f=a;
        if(t[f].s[p])t[t[f].s[p]].f=f;
        if(t[a].f=g)t[g].s[q]=a;
        upd(f),upd(a);
    }
    inline void spl(int tar,int a)
    {
        while(t[a].f!=tar)
            if(t[t[a].f].f==tar)rot(a);
            else if(pos(a)==pos(t[a].f))rot(t[a].f),rot(a);
            else rot(a),rot(a);
        if(!tar)root=a;
    }
    int merg(int f,int l,int r)
    {
        if(l>r)return 0;
        int a=++tnum;
        t[a]=(node){bar[mid],f,1,0,0};
        t[a].s[0]=merg(a,l,mid-1);
        t[a].s[1]=merg(a,mid+1,r);
        upd(a);
        return a;
    }
    void insrt(int &a,int f,int x)
    {
        if(!a)t[a=++tnum]=(node){x,f,1,0,0},spl(0,tnum);
        else if(x<t[a].x)insrt(t[a].s[0],a,x);
        else insrt(t[a].s[1],a,x);
    }
    int findn(int a,int x)
    {
        if(!a)return 0;
        else if(t[a].x==x)return a;
        else if(x<t[a].x)return findn(t[a].s[0],x);
        else return findn(t[a].s[1],x);
    }
    void del(int x)
    {
        int a=findn(root,x);
        spl(0,a);
        int la=t[a].s[0],ra=t[a].s[1];
        while(t[la].s[1])la=t[la].s[1];
        spl(a,la);
        t[la].s[1]=ra,t[ra].f=la,t[root=la].f=0;
        spl(0,ra);
    }
    int rank(int a,int x)
    {
        if(!a)return 0;
        else if(x>=t[a].x)return rank(t[a].s[1],x)+t[t[a].s[0]].siz+1;
        else return rank(t[a].s[0],x);
    }
}segt;
segt tre[MX*4];

void build(int a,int l,int r)
{
    tre[a].l=l,tre[a].r=r;
    if(l<r)build(ls,l,mid),build(rs,mid+1,r);
    bar[1]=-MX;
    bnum=1;
    for(int p=l;p<=r;p++)
        for(int i=0;i<ord[p].size();i++)
            bar[++bnum]=ord[p][i];
    bar[++bnum]=MX;
    sort(bar+1,bar+bnum+1);
    tre[a].root=tre[a].merg(0,1,bnum);
}

void del(int a,int p,int x)
{
    int l=tre[a].l,r=tre[a].r;
    tre[a].del(x);
    if(l==r)return;
    else if(p<=mid)del(ls,p,x);
    else del(rs,p,x);
}

void ins(int a,int p,int x)
{
    int l=tre[a].l,r=tre[a].r;
    tre[a].insrt(tre[a].root,0,x);
    if(l==r)return;
    else if(p<=mid)ins(ls,p,x);
    else ins(rs,p,x);
}

int kth(int a,int ql,int qr,int k)
{
    int dlt=tre[ls].rank(tre[ls].root,qr)-tre[ls].rank(tre[ls].root,ql);
    if(tre[a].l==tre[a].r)return tre[a].r;
    else if(dlt<k)return kth(rs,ql,qr,k-dlt);
    else return kth(ls,ql,qr,k);
}

void lsh()
{
    int x;
    for(x=1;x<=n;x++)mp[seq[x]]=1;
    for(x=1;x<=m;x++)if(qur[x].t==0)mp[qur[x].x]=1;
    for(x=1,itr=mp.begin();itr!=mp.end();itr++,x++)itr->second=x;
    for(x=1,itr=mp.begin();itr!=mp.end();itr++,x++)real[itr->second]=itr->first;
    for(x=1;x<=n;x++)seq[x]=mp[seq[x]];
    for(x=1;x<=m;x++)if(qur[x].t==0)qur[x].x=mp[qur[x].x];
}

void inpt()
{
    char str[10];
    scanf("%d%d",&n,&m);
    for(int i=1;i<=n;i++)scanf("%d",&seq[i]);
    for(int i=1;i<=m;i++)
    {
        scanf("%s",str);
        if(str[0]=='C')qur[i].t=0,scanf("%d%d",&qur[i].l,&qur[i].x);
        else qur[i].t=1,scanf("%d%d%d",&qur[i].l,&qur[i].r,&qur[i].x);
    }
    lsh();
    for(int i=1;i<=n;i++)ord[seq[i]].push_back(i);
    n=mp.size();
    build(1,1,n);
}

void work()
{
    for(int i=1;i<=m;i++)
    {
        if(qur[i].t==0)
        {
            del(1,seq[qur[i].l],qur[i].l);
            ins(1,qur[i].x,qur[i].l);
            seq[qur[i].l]=qur[i].x;
        }
        else printf("%d\n",real[kth(1,qur[i].l-1,qur[i].r,qur[i].x)]);
    }
}

void init()
{
    tnum=0;
    mp.clear();
    for(int i=1;i<=n;i++)ord[i].clear();
}

int main()
{
    int T;
    scanf("%d",&T);
    for(int i=1;i<=T;i++)
    {
        init();
        inpt();
        work();
    }
    return 0;
}
分类: 文章

1 条评论

konnyakuxzy · 2018年1月9日 8:35 下午

我去这代码确实挺长的 QvQ
您码力太强了 Orz
不过确实奇怪网上居然没有这种权值线段树の题解

发表评论

电子邮件地址不会被公开。 必填项已用*标注