树状数组——从背模板到树套树

文章目录

【注意】最后更新于 February 12, 2020,文中内容可能已过时,请谨慎使用。

这是一篇披着PJ组数据结构外衣的树套树教程。

大约会(尝试着)较为本质地简介一下树状数组?

基础树状数组

树状数组,英文名 BIT(Binary Indexed Tree)(不是TreeArray)。

原理的话..看图大约是一目了然的:

BIT

其中,黑色的矩形(包括红色的正方形)代表这一部分的和,而红色的正方形代表这部分和在树状数组中的下标。如果把这些区间连边,就像是一棵二叉树,所以叫树状数组。

举几个栗子,$BIT[3]$ 表示 $A[3]$,$BIT[6]$ 表示 $A[5]+A[6]$,$BIT[12]$ 表示 $A[9]+A[10]+A[11]+A[12]$。

我们把每个下标用二进制表示,可以发现,二进制表示的末尾有 $k​$ 个 $0​$,在树状数组里它就代表一段长为 $2^k​$ 的区间的和。由于树状数组和下标的二进制联系紧密,所以英文叫 Binary Indexed Tree。

可以定义 $lowbit(x)$ 为 $x$ 的二进制表示中最低位的 $1$ 表示的数。如 $lowbit(101_{(2)})=1$,$lowbit(110100_{(2)}=4)$,这样的话,树状数组中下标为 $x$ 的元素就表示了一段长为 $lowbit(x)$ 的区间的和。

由于计算机中存储带符号整数的方式,$lowbit(x)=$x&-x,具体原因可以自行搜索“补码”。

考虑如何更新树状数组:如果我们要更新第 $p$ 位,先更新 $BIT[p]$,再更新 $BIT[p+lowbit(p)]$,再更新 $BIT[p+lowbit(p)+lowbit(p+lowbit(p))]$……一直更新到原数列的长度。

考虑如何查询某个前缀和:如果我们要查询前 $p$ 位的前缀和,结果就是 $BIT[p]+BIT[p-lowbit(p)]+BIT[p-lowbit(p)-lowbit(p-lowbit(p))]$……一直查询到 $lowbit$ 为 $1$ 的节点。

把树状数组看成二叉树,深度不超过 $\log(n)$,所以单次操作复杂度是 $\mathcal O(\log n)$。

大概就是这样,代码比较简短:

void add(int p,int x)
{
    for (;p<=n;p+=(p&-p)) BIT[p]+=x;
}

int query(int p)
{
    int out=0;
    for (;p;p-=(p&-p)) out+=BIT[p];
    return out;
}

稍进阶一点点的树状数组

由于本篇教程是“从背模板到树套树”而不是“摆脱线段树与平衡树”,所以不会提及那方面的高级用法。

维护前缀积

把+改成*。

维护前缀异或和

把+改成^。

维护前缀矩阵积

把+改成矩阵乘法。

诶,等等,怎么全WA了?

因为矩阵乘法不具有交换律..

比如说,两个矩阵 $A$ 和 $B$,树状数组里存的是 $A$ 和 $A\times B$,把 $A$ 乘上 $C$ 后树状数组里第二项我们期望它是 $A\times C\times B$,而实际上它是 $A\times B\times C$..

所以树状数组到底在维护什么?

警告:本人其实没怎么学过群论..下文群论相关可能有口胡成分。

在维护一个阿贝尔群..

等等,群是什么?群号多少?

..就是一堆元素,定义了一种运算,它满足结合律、交换律,有单位元(谁和它运算都得到本身)、逆元(每个元素都存在一个元素运算后得到单位元)。如果只是前缀信息按理来说是不需要逆元的..然而一般都是要维护区间信息,而不只是前缀信息,所以需要逆元..

树状数组套动态开点线段树

简介

终于到正题了。

我们来定义一个阿贝尔群:

它的元素是一些同构的动态开点线段树,运算是把对应节点的信息相加,要求节点维护的信息是阿贝尔群。

一般来说,主席树可以解决的静态问题带修就要用树套树了..

修改就是把树状数组里的+换成动态开点线段树的修改操作,询问就是把+换成merge。直接 merge 复杂度好像不太对..(其实我不太会证线段树合并复杂度..)所以可以开个数组,把需要询问的节点存下来,然后在询问函数里合并信息。如果是询问区间,就把两个端点在树状数组里对应的节点存下来。

例题

Dynamic Rankings

代码
#include <iostream>
#include <cstdio>
#include <cctype>
#include <algorithm>

using namespace std;

int read()
{
    int out=0;
    char c;
    while (!isdigit(c=getchar()));
    for (;isdigit(c);c=getchar()) out=out*10+c-'0';
    return out;
}

const int N=100001;

struct Node
{
    int val,ls,rs;
} t[N<<9];

int modify(int x,int l,int r,int p,int type);
int merge(int x,int y);
int query(int l,int r,int k);
void change(int p,int x,int y);

int n,m,tot,a[N],BIT[N],lsh[N<<1],cnt,tp[N],l[N],r[N],xx[N],totx,toty,x[N],y[N];
char op[10];

int main()
{
    int i,j;

    n=read();
    m=read();

    for (i=1;i<=n;++i) lsh[++cnt]=a[i]=read();

    for (i=1;i<=m;++i)
    {
        scanf("%s",op);
        if (op[0]=='Q')
        {
            tp[i]=0;
            l[i]=read();
            r[i]=read();
            xx[i]=read();
        }
        else
        {
            tp[i]=1;
            l[i]=read();
            lsh[++cnt]=xx[i]=read();
        }
    }

    sort(lsh+1,lsh+cnt+1);
    cnt=unique(lsh+1,lsh+cnt+1)-lsh;

    for (i=1;i<=n;++i)
    {
        a[i]=lower_bound(lsh+1,lsh+cnt,a[i])-lsh;
        change(i,a[i],1);
    }

    for (i=1;i<=m;++i)
    {
        if (tp[i])
        {
            change(l[i],a[l[i]],-1);
            change(l[i],a[l[i]]=xx[i]=lower_bound(lsh+1,lsh+cnt,xx[i])-lsh,1);
        }
        else
        {
            totx=toty=0;
            for (j=l[i]-1;j;j-=(j&-j)) x[++totx]=BIT[j];
            for (j=r[i];j;j-=(j&-j)) y[++toty]=BIT[j];
            printf("%d\n",query(1,cnt,xx[i]));
        }
    }

    return 0;
}

void change(int p,int x,int y)
{
    for (;p<=n;p+=(p&-p)) BIT[p]=modify(BIT[p],1,cnt,x,y);
}

int modify(int x,int l,int r,int p,int type)
{
    int u=++tot;
    t[u]=t[x];
    t[u].val+=type;
    if (l==r-1) return u;
    int mid=l+r>>1;
    if (p<mid) t[u].ls=modify(t[u].ls,l,mid,p,type);
    else t[u].rs=modify(t[u].rs,mid,r,p,type);
    return u;
}

int query(int l,int r,int k)
{
    if (l==r-1) return lsh[l];
    int i,sum=0;
    for (i=1;i<=totx;++i) sum-=t[t[x[i]].ls].val;
    for (i=1;i<=toty;++i) sum+=t[t[y[i]].ls].val;
    if (sum>=k)
    {
        for (i=1;i<=totx;++i) x[i]=t[x[i]].ls;
        for (i=1;i<=toty;++i) y[i]=t[y[i]].ls;
        return query(l,l+r>>1,k);
    }
    else
    {
        for (i=1;i<=totx;++i) x[i]=t[x[i]].rs;
        for (i=1;i<=toty;++i) y[i]=t[y[i]].rs;
        return query(l+r>>1,r,k-sum);
    }
}

【模板】二逼平衡树(树套树)

代码
#include <iostream>
#include <cstdio>
#include <cctype>
#include <algorithm>

using namespace std;

const int N=50010;
const int INF=0x7fffffff;

int read()
{
    int out=0;
    char c;
    while (!isdigit(c=getchar()));
    for (;isdigit(c);c=getchar()) out=out*10+c-'0';
    return out;
}

struct Node
{
    int val,ls,rs;
} t[N<<8];

int insert(int x,int l,int r,int p,int type);
int qsum(int l,int r,int L,int R,int d);
int kth(int l,int r,int k);
void modify(int p,int x,int y);

int n,m,tot,a[N],BIT[N],lsh[N<<1],cnt,op[N],l[N],r[N],k[N],totx,toty,X[20][N],Y[20][N];

int main()
{
    int i,p;

    n=read();
    m=read();

    for (i=1;i<=n;++i) a[i]=lsh[++cnt]=read();

    for (i=1;i<=m;++i)
    {
        op[i]=read();
        if (op[i]==3)
        {
            l[i]=read();
            k[i]=lsh[++cnt]=read();
        }
        else
        {
            l[i]=read();
            r[i]=read();
            k[i]=read();
            if (op[i]!=2) lsh[++cnt]=k[i];
        }
    }

    sort(lsh+1,lsh+cnt+1);
    cnt=unique(lsh+1,lsh+cnt+1)-lsh;

    for (i=1;i<=n;++i)
    {
        a[i]=lower_bound(lsh+1,lsh+cnt,a[i])-lsh;
        modify(i,a[i],1);
    }

    for (i=1;i<=m;++i)
    {
        if (op[i]==3)
        {
            k[i]=lower_bound(lsh+1,lsh+cnt,k[i])-lsh;
            modify(l[i],a[l[i]],-1);
            modify(l[i],a[l[i]]=k[i],1);
        }
        else
        {
            totx=toty=0;
            for (p=l[i]-1;p;p-=(p&-p)) X[0][++totx]=BIT[p];
            for (p=r[i];p;p-=(p&-p)) Y[0][++toty]=BIT[p];
            if (op[i]==2) printf("%d\n",kth(1,cnt,k[i]));
            else
            {
                k[i]=lower_bound(lsh+1,lsh+cnt,k[i])-lsh;
                if (op[i]==1) printf("%d\n",qsum(1,cnt,1,k[i],0)+1);
                else if (op[i]==4)
                {
                    int rk=qsum(1,cnt,1,k[i],0);
                    if (rk) printf("%d\n",kth(1,cnt,rk));
                    else printf("%d\n",-INF);
                }
                else
                {
                    int rk=qsum(1,cnt,1,k[i]+1,0);
                    if (rk<=r[i]-l[i]) printf("%d\n",kth(1,cnt,rk+1));
                    else printf("%d\n",INF);
                }
            }
        }
    }

    return 0;
}

void modify(int p,int x,int y)
{
    for (;p<=n;p+=(p&-p)) BIT[p]=insert(BIT[p],1,cnt,x,y);
}

int insert(int x,int l,int r,int p,int type)
{
    int u=++tot;
    t[u]=t[x];
    t[u].val+=type;
    if (l==r-1) return u;
    int mid=l+r>>1;
    if (p<mid) t[u].ls=insert(t[u].ls,l,mid,p,type);
    else t[u].rs=insert(t[u].rs,mid,r,p,type);
    return u;
}

int qsum(int l,int r,int L,int R,int d)
{
    if (l>=R||r<=L) return 0;
    int i,sum=0;
    if (L<=l&&R>=r)
    {
        for (i=1;i<=totx;++i) sum-=t[X[d][i]].val;
        for (i=1;i<=toty;++i) sum+=t[Y[d][i]].val;
        return sum;
    }
    for (i=1;i<=totx;++i) X[d+1][i]=t[X[d][i]].ls;
    for (i=1;i<=toty;++i) Y[d+1][i]=t[Y[d][i]].ls;
    sum=qsum(l,l+r>>1,L,R,d+1);
    for (i=1;i<=totx;++i) X[d+1][i]=t[X[d][i]].rs;
    for (i=1;i<=toty;++i) Y[d+1][i]=t[Y[d][i]].rs;
    return sum+qsum(l+r>>1,r,L,R,d+1);
}

int kth(int l,int r,int k)
{
    if (l==r-1) return lsh[l];
    int i,sum=0;
    for (i=1;i<=totx;++i) sum-=t[t[X[0][i]].ls].val;
    for (i=1;i<=toty;++i) sum+=t[t[Y[0][i]].ls].val;
    if (sum>=k)
    {
        for (i=1;i<=totx;++i) X[0][i]=t[X[0][i]].ls;
        for (i=1;i<=toty;++i) Y[0][i]=t[Y[0][i]].ls;
        return kth(l,l+r>>1,k);
    }
    for (i=1;i<=totx;++i) X[0][i]=t[X[0][i]].rs;
    for (i=1;i<=toty;++i) Y[0][i]=t[Y[0][i]].rs;
    return kth(l+r>>1,r,k-sum);
}

[CQOI2011]动态逆序对

这题用树套树做有点卡空间..需要把带返回值的动态开点改成直接修改。

代码
#include <iostream>
#include <cstdio>
#include <cctype>
#include <cstring>

using namespace std;

int read()
{
    int out=0;
    char c;
    while (!isdigit(c=getchar()));
    for (;isdigit(c);c=getchar()) out=out*10+c-'0';
    return out;
}

const int N=100010;

struct Node
{
    int val,ls,rs;
} t[N*90];

void change(int& u,int l,int r,int p);
int query(int l,int r,int L,int R,int d);

int n,m,a[N],p[N],del[N],BIT[N],x[20][20],y[20][20],xtot,ytot,tot;
bool deleted[N];
long long ans,out[N];

int main()
{
    int i,j;

    n=read();
    m=read();

    for (i=1;i<=n;++i)
    {
        a[i]=read();
        p[a[i]]=i;
    }

    for (i=1;i<=m;++i)
    {
        del[i]=p[read()];
        deleted[del[i]]=true;
    }

    for (i=n;i>=1;--i)
    {
        if (!deleted[i])
        {
            for (j=a[i];j;j-=(j&-j)) ans+=BIT[j];
            for (j=a[i];j<=n;j+=(j&-j)) ++BIT[j];
        }
    }

    memset(BIT,0,sizeof(BIT));

    for (i=1;i<=n;++i)
    {
        if (!deleted[i])
        {
            for (j=i;j<=n;j+=(j&-j))
            {
                change(BIT[j],1,n+1,a[i]);
            }
        }
    }

    for (i=m;i>=1;--i)
    {
        xtot=ytot=0;
        for (j=del[i];j;j-=(j&-j)) y[0][++ytot]=BIT[j];
        ans+=query(1,n+1,a[del[i]]+1,n+1,0);
        xtot=ytot=0;
        for (j=del[i];j;j-=(j&-j)) x[0][++xtot]=BIT[j];
        for (j=n;j;j-=(j&-j)) y[0][++ytot]=BIT[j];
        ans+=query(1,n+1,1,a[del[i]],0);
        for (j=del[i];j<=n;j+=(j&-j)) change(BIT[j],1,n+1,a[del[i]]);
        out[i]=ans;
    }

    for (i=1;i<=m;++i) printf("%lld\n",out[i]);

    return 0;
}

void change(int& u,int l,int r,int p)
{
    if (!u) u=++tot;
    ++t[u].val;
    if (l==r-1) return;
    int mid=l+r>>1;
    if (p<mid) change(t[u].ls,l,mid,p);
    else change(t[u].rs,mid,r,p);
}

int query(int l,int r,int L,int R,int d)
{
    if (l>=R||r<=L) return 0;
    int i,ret=0,mid=l+r>>1;
    if (L<=l&&R>=r)
    {
        for (i=1;i<=xtot;++i) ret-=t[x[d][i]].val;
        for (i=1;i<=ytot;++i) ret+=t[y[d][i]].val;
        return ret;
    }
    for (i=1;i<=xtot;++i) x[d+1][i]=t[x[d][i]].ls;
    for (i=1;i<=ytot;++i) y[d+1][i]=t[y[d][i]].ls;
    ret+=query(l,mid,L,R,d+1);
    for (i=1;i<=xtot;++i) x[d+1][i]=t[x[d][i]].rs;
    for (i=1;i<=ytot;++i) y[d+1][i]=t[y[d][i]].rs;
    ret+=query(mid,r,L,R,d+1);
    return ret;
}

评论正在加载中...如果评论较长时间无法加载,你可以 搜索对应的 issue 或者 新建一个 issue