树链剖分入门练手题。

然鹅我居然调了一个上午+半个上午!

最后居然是少写了一个 f 导致 WA 了一片 [滑稽]。

总的来说,树链剖分其实并不是那么的难,只需要两个 dfs+线段树+lca 思想即可 AC

两遍 dfs 很好理解:

const int N=3e4+2;
struct Node{
    int fa,dep,size,son,top,seg;
    #define fa(x) tree[x].fa
    #define d(x) tree[x].dep
    #define s(x) tree[x].size
    #define son(x) tree[x].son
    #define top(x) tree[x].top
    #define seg(x) tree[x].seg
}tree[N];
struct Edge{
    int nxt,to;
    #define nxt(x) edge[x].nxt
    #define to(x) edge[x].to
}edge[N<<2];
int num[N<<2],rev[N<<2],sum[N<<2],Max[N<<2],head[N<<2];
int n,m,cnt,Ans_sum,Ans_max;
//以上是需要定义的东西
//声明一下这些东西的含义:
/*
fa[x]:x 在树中的父亲
dep[x]:x 在树中的深度
size[x]:x 的子树结点数(子树大小)
son[x]:x 的重儿子,即 u→son[u] 为重边
top[x]:x 所在重路径的顶部结点(深度最小)
seg[x]:x 在线段树中的位置(下标)
rev[x]: 线段树中第 x 个位置对应的树中结点编号,即 rev[seg[x]]=x
//上为树链剖分的定义所需
//第二个结构体为图
num 为权值,Max、sum 为线段树维护的最大值与和
head 为链式前向星的必备数组
*/
//第一遍 dfs 处理树链剖分七个值的前四个
inline void dfs1(int u,int f){
    s(u)=1,fa(u)=f,d(u)=d(f)+1;
    for(register int i=head[u];i;i=nxt(i)){
        int v=to(i);if(v==f)continue;
        dfs1(v,u);s(u)+=s(v);
        if(s(v)>s(son(u)))son(u)=v;
    }return;
}
//第二遍处理后三个
inline void dfs2(int u,int f){
    if(son(u)){
        seg(son(u))=++seg(0),top(son(u))=top(u);
        rev[seg(0)]=son(u),dfs2(son(u),u);
    }for(register int i=head[u];i;i=nxt(i)){
        int v=to(i);if(!top(v)&&v!=f){
            seg(v)=++seg(0),top(v)=v;
            rev[seg(0)]=v;dfs2(v,u);
        }
    }return;
}

部分效果 (start 即为我们的 top(重链顶端)):


线段树单点修改即可:

inline void pushup(int x){
    sum[x]=sum[x<<1]+sum[(x<<1)+1];
    Max[x]=max(Max[x<<1],Max[(x<<1)+1]);
}
inline void build(int k,int l,int r){
    int mid=(l+r)>>1;if(l==r)
    {sum[k]=Max[k]=num[rev[l]];return;}
    build(k<<1,l,mid);build((k<<1)+1,mid+1,r);pushup(k);
}
inline void change(int k,int l,int r,int val,int x){
    if(x>r||x<l)return;int mid=(l+r)>>1;
    if(l==r&&r==x){sum[k]=Max[k]=val;return;}
    if(mid>=x)change(k<<1,l,mid,val,x);
    if(mid<x)change((k<<1)+1,mid+1,r,val,x);pushup(k);
}
inline void query(int k,int l,int r,int L,int R){
    if(L>r||R<l)return;
    if(L<=l&&r<=R){
        Ans_max=max(Ans_max,Max[k]);
        Ans_sum+=sum[k];return;
    }int mid=(l+r)>>1;
    if(mid>=L)query(k<<1,l,mid,L,R);
    if(mid<r)query((k<<1)+1,mid+1,r,L,R);
}

怎么询问呢?

先看一下我们的重链:

由于 dfs 的顺序,同一条重链上的节点在线段树中位置是连续的。

所以每次对于一个节点 x,我们只需要询问 x 到 top(x)之间的路径即可 (连在一起的线段树可以直接区间询问),然后 x 再跳到 top(x) 的爸爸(x->top(x) 已经询问完了),就这样一直往上跳,跳到最后 x 和另一个节点 y 都在同一条重链上即可 (即 top(x)==top(y))。

到了一条重链上,那么就可以直接查询了。

最后综合几次查询的结果,即为 x 到 y 直接的结果了。

就像下图一样:

注意:轻儿子的 top 当然是自己,即只记录自己的答案,重儿子可以跟着重链顶端一起询问 (不理解画个图,演示一下 dfs 的过程就好多了)

询问代码:

inline void ask(int x,int y){
    int fx=top(x),fy=top(y);
    while(fx!=fy){
        if(d(fx)<d(fy))swap(fx,fy),swap(x,y);
        query(1,1,seg(0),seg(fx),seg(x));
        x=fa(fx),fx=top(x);
    }if(d(x)>d(y))swap(x,y);
    query(1,1,seg(0),seg(x),seg(y));return;
}

所以只需要再加一个建图就好了。

注意建双向边,dfs 的时候判断一下是不是 fa。

AC 代码:

//树链剖分模板
//总结:两遍 dfs+线段树+lca 思想,线段树很重要 
#include<bits/stdc++.h>
#define ll long long
#define inf 0x3f3f3f3f
#define A printf("A")
#define ld long double
#define RI register int
#define max(x,y) (x)>(y)?(x):(y)
#define min(x,y) (x)<(y)?(x):(y)
#define Match
using namespace std;
const int N=3e4+2;
struct Node{
    int fa,dep,size,son,top,seg;
    #define fa(x) tree[x].fa
    #define d(x) tree[x].dep
    #define s(x) tree[x].size
    #define son(x) tree[x].son
    #define top(x) tree[x].top
    #define seg(x) tree[x].seg
}tree[N];
struct Edge{
    int nxt,to;
    #define nxt(x) edge[x].nxt
    #define to(x) edge[x].to
}edge[N<<2];
int num[N<<2],rev[N<<2],sum[N<<2],Max[N<<2],head[N<<2];
int n,m,cnt,Ans_sum,Ans_max;
template <typename Tp> inline void IN(Tp &x){
    int f=1;x=0;char ch=getchar();
    while(ch<'0'||ch>'9')if(ch=='-')f=-1,ch=getchar();
    while(ch>='0'&&ch<='9')x=x*10+ch-'0',ch=getchar();x*=f;
}
inline void pushup(int x){
    sum[x]=sum[x<<1]+sum[(x<<1)+1];
    Max[x]=max(Max[x<<1],Max[(x<<1)+1]);
}
inline void build(int k,int l,int r){
    int mid=(l+r)>>1;if(l==r)
    {sum[k]=Max[k]=num[rev[l]];return;}
    build(k<<1,l,mid);build((k<<1)+1,mid+1,r);pushup(k);
}
inline void change(int k,int l,int r,int val,int x){
    if(x>r||x<l)return;int mid=(l+r)>>1;
    if(l==r&&r==x){sum[k]=Max[k]=val;return;}
    if(mid>=x)change(k<<1,l,mid,val,x);
    if(mid<x)change((k<<1)+1,mid+1,r,val,x);pushup(k);
}
inline void query(int k,int l,int r,int L,int R){
    if(L>r||R<l)return;
    if(L<=l&&r<=R){
        Ans_max=max(Ans_max,Max[k]);
        Ans_sum+=sum[k];return;
    }int mid=(l+r)>>1;
    if(mid>=L)query(k<<1,l,mid,L,R);
    if(mid<r)query((k<<1)+1,mid+1,r,L,R);
}
inline void dfs1(int u,int f){
    s(u)=1,fa(u)=f,d(u)=d(f)+1;
    for(register int i=head[u];i;i=nxt(i)){
        int v=to(i);if(v==f)continue;
        dfs1(v,u);s(u)+=s(v);
        if(s(v)>s(son(u)))son(u)=v;
    }return;
}
inline void dfs2(int u,int f){
    if(son(u)){
        seg(son(u))=++seg(0),top(son(u))=top(u);
        rev[seg(0)]=son(u),dfs2(son(u),u);
    }for(register int i=head[u];i;i=nxt(i)){
        int v=to(i);if(!top(v)&&v!=f){
            seg(v)=++seg(0),top(v)=v;
            rev[seg(0)]=v;dfs2(v,u);
        }
    }return;
}
inline void add(int x,int y){
    nxt(++cnt)=head[x],head[x]=cnt,to(cnt)=y;
    nxt(++cnt)=head[y],head[y]=cnt,to(cnt)=x;
}
inline void ask(int x,int y){
    int fx=top(x),fy=top(y);
    while(fx!=fy){
        if(d(fx)<d(fy))swap(fx,fy),swap(x,y);
        query(1,1,seg(0),seg(fx),seg(x));
        x=fa(fx),fx=top(x);
    }if(d(x)>d(y))swap(x,y);
    query(1,1,seg(0),seg(x),seg(y));return;
}char op[10];
int main(){
//    freopen("1036.in","r",stdin);
//    freopen("1036.out","w",stdout);
    scanf("%d",&n);
    for(register int x,y,i=1;i<n;++i)
    {scanf("%d%d",&x,&y);add(x,y);}
    for(register int i=1;i<=n;++i)scanf("%d",&num[i]);
    dfs1(1,0);seg(0)=seg(1)=top(1)=rev[1]=1;
    dfs2(1,0);build(1,1,seg(0));scanf("%d",&m);
    for(register int x,y,i=1;i<=m;++i){
        scanf("%s",op);scanf("%d%d",&x,&y);
        if(op[0]=='C')change(1,1,seg(0),y,seg(x));
        else{
            Ans_sum=0,Ans_max=-(inf<<1);ask(x,y);
            if(op[1]=='S')printf("%d\n",Ans_sum);
            else printf("%d\n",Ans_max);
        }
    }return 0;
}
我居然因为少了 f 调了三个多小时…..
分类: 文章

Qiuly

QAQ

0 条评论

发表回复

Avatar placeholder

您的电子邮箱地址不会被公开。 必填项已用 * 标注