树状数组——从背模板到树套树

这是一篇披着PJ组数据结构外衣的树套树教程。

大约会(尝试着)较为本质地简介一下树状数组?

基础树状数组

树状数组,英文名 BIT(Binary Indexed Tree)(不是TreeArray)。

原理的话..看图大约是一目了然的:

其中,黑色的矩形(包括红色的正方形)代表这一部分的和,而红色的正方形代表这部分和在树状数组中的下标。如果把这些区间连边,就像是一棵二叉树,所以叫树状数组。

举几个栗子,$BIT[3]$ 表示 $A[3]$,$BIT[6]$ 表示 $A[5]+A[6]$,$BIT[12]$ 表示 $A[9]+A[10]+A[11]+A[12]$。

我们把每个下标用二进制表示,可以发现,二进制表示的末尾有 $k​$ 个 $0​$,在树状数组里它就代表一段长为 $2^k​$ 的区间的和。由于树状数组和下标的二进制联系紧密,所以英文叫 Binary Indexed Tree。

可以定义 $lowbit(x)$ 为 $x$ 的二进制表示中最低位的 $1$ 表示的数。如 $lowbit(101_{(2)})=1$,$lowbit(110100_{(2)}=4)$,这样的话,树状数组中下标为 $x$ 的元素就表示了一段长为 $lowbit(x)$ 的区间的和。

由于计算机中存储带符号整数的方式,$lowbit(x)=$x&-x,具体原因可以自行搜索“补码”。

考虑如何更新树状数组:如果我们要更新第 $p$ 位,先更新 $BIT[p]$,再更新 $BIT[p+lowbit(p)]$,再更新 $BIT[p+lowbit(p)+lowbit(p+lowbit(p))]$……一直更新到原数列的长度。

考虑如何查询某个前缀和:如果我们要查询前 $p$ 位的前缀和,结果就是 $BIT[p]+BIT[p-lowbit(p)]+BIT[p-lowbit(p)-lowbit(p-lowbit(p))]$……一直查询到 $lowbit$ 为 $1$ 的节点。

把树状数组看成二叉树,深度不超过 $\log(n)$,所以单次操作复杂度是 $\mathcal O(\log n)$。

大概就是这样,代码比较简短:

1
2
3
4
5
6
7
8
9
10
11
void add(int p,int x)
{
for (;p<=n;p+=(p&-p)) BIT[p]+=x;
}

int query(int p)
{
int out=0;
for (;p;p-=(p&-p)) out+=BIT[p];
return out;
}

稍进阶一点点的树状数组

由于本篇教程是“从背模板到树套树”而不是“摆脱线段树与平衡树”,所以不会提及那方面的高级用法。

维护前缀积

把+改成*。

维护前缀异或和

把+改成^。

维护前缀矩阵积

把+改成矩阵乘法。

诶,等等,怎么全WA了?

因为矩阵乘法不具有交换律..

比如说,两个矩阵 $A$ 和 $B$,树状数组里存的是 $A$ 和 $A\times B$,把 $A$ 乘上 $C$ 后树状数组里第二项我们期望它是 $A\times C\times B$,而实际上它是 $A\times B\times C$..

所以树状数组到底在维护什么?

警告:本人其实没怎么学过群论..下文群论相关可能有口胡成分。

在维护一个阿贝尔群..

等等,群是什么?群号多少?

..就是一堆元素,定义了一种运算,它满足结合律、交换律,有单位元(谁和它运算都得到本身)、逆元(每个元素都存在一个元素运算后得到单位元)。如果只是前缀信息按理来说是不需要逆元的..然而一般都是要维护区间信息,而不只是前缀信息,所以需要逆元..

树状数组套动态开点线段树

简介

终于到正题了。

我们来定义一个阿贝尔群:

它的元素是一些同构的动态开点线段树,运算是把对应节点的信息相加,要求节点维护的信息是阿贝尔群。

一般来说,主席树可以解决的静态问题带修就要用树套树了..

修改就是把树状数组里的+换成动态开点线段树的修改操作,询问就是把+换成merge。直接 merge 复杂度好像不太对..(其实我不太会证线段树合并复杂度..)所以可以开个数组,把需要询问的节点存下来,然后在询问函数里合并信息。如果是询问区间,就把两个端点在树状数组里对应的节点存下来。

例题

P2617 Dynamic Rankings

代码
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
#include <iostream>
#include <cstdio>
#include <cctype>
#include <algorithm>

using namespace std;

int read()
{
int out=0;
char c;
while (!isdigit(c=getchar()));
for (;isdigit(c);c=getchar()) out=out*10+c-'0';
return out;
}

const int N=100001;

struct Node
{
int val,ls,rs;
} t[N<<9];

int modify(int x,int l,int r,int p,int type);
int merge(int x,int y);
int query(int l,int r,int k);
void change(int p,int x,int y);

int n,m,tot,a[N],BIT[N],lsh[N<<1],cnt,tp[N],l[N],r[N],xx[N],totx,toty,x[N],y[N];
char op[10];

int main()
{
int i,j;

n=read();
m=read();

for (i=1;i<=n;++i) lsh[++cnt]=a[i]=read();

for (i=1;i<=m;++i)
{
scanf("%s",op);
if (op[0]=='Q')
{
tp[i]=0;
l[i]=read();
r[i]=read();
xx[i]=read();
}
else
{
tp[i]=1;
l[i]=read();
lsh[++cnt]=xx[i]=read();
}
}

sort(lsh+1,lsh+cnt+1);
cnt=unique(lsh+1,lsh+cnt+1)-lsh;

for (i=1;i<=n;++i)
{
a[i]=lower_bound(lsh+1,lsh+cnt,a[i])-lsh;
change(i,a[i],1);
}

for (i=1;i<=m;++i)
{
if (tp[i])
{
change(l[i],a[l[i]],-1);
change(l[i],a[l[i]]=xx[i]=lower_bound(lsh+1,lsh+cnt,xx[i])-lsh,1);
}
else
{
totx=toty=0;
for (j=l[i]-1;j;j-=(j&-j)) x[++totx]=BIT[j];
for (j=r[i];j;j-=(j&-j)) y[++toty]=BIT[j];
printf("%d\n",query(1,cnt,xx[i]));
}
}

return 0;
}

void change(int p,int x,int y)
{
for (;p<=n;p+=(p&-p)) BIT[p]=modify(BIT[p],1,cnt,x,y);
}

int modify(int x,int l,int r,int p,int type)
{
int u=++tot;
t[u]=t[x];
t[u].val+=type;
if (l==r-1) return u;
int mid=l+r>>1;
if (p<mid) t[u].ls=modify(t[u].ls,l,mid,p,type);
else t[u].rs=modify(t[u].rs,mid,r,p,type);
return u;
}

int query(int l,int r,int k)
{
if (l==r-1) return lsh[l];
int i,sum=0;
for (i=1;i<=totx;++i) sum-=t[t[x[i]].ls].val;
for (i=1;i<=toty;++i) sum+=t[t[y[i]].ls].val;
if (sum>=k)
{
for (i=1;i<=totx;++i) x[i]=t[x[i]].ls;
for (i=1;i<=toty;++i) y[i]=t[y[i]].ls;
return query(l,l+r>>1,k);
}
else
{
for (i=1;i<=totx;++i) x[i]=t[x[i]].rs;
for (i=1;i<=toty;++i) y[i]=t[y[i]].rs;
return query(l+r>>1,r,k-sum);
}
}

P3380 【模板】二逼平衡树(树套树)

代码
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
#include <iostream>
#include <cstdio>
#include <cctype>
#include <algorithm>

using namespace std;

const int N=50010;
const int INF=0x7fffffff;

int read()
{
int out=0;
char c;
while (!isdigit(c=getchar()));
for (;isdigit(c);c=getchar()) out=out*10+c-'0';
return out;
}

struct Node
{
int val,ls,rs;
} t[N<<8];

int insert(int x,int l,int r,int p,int type);
int qsum(int l,int r,int L,int R,int d);
int kth(int l,int r,int k);
void modify(int p,int x,int y);

int n,m,tot,a[N],BIT[N],lsh[N<<1],cnt,op[N],l[N],r[N],k[N],totx,toty,X[20][N],Y[20][N];

int main()
{
int i,p;

n=read();
m=read();

for (i=1;i<=n;++i) a[i]=lsh[++cnt]=read();

for (i=1;i<=m;++i)
{
op[i]=read();
if (op[i]==3)
{
l[i]=read();
k[i]=lsh[++cnt]=read();
}
else
{
l[i]=read();
r[i]=read();
k[i]=read();
if (op[i]!=2) lsh[++cnt]=k[i];
}
}

sort(lsh+1,lsh+cnt+1);
cnt=unique(lsh+1,lsh+cnt+1)-lsh;

for (i=1;i<=n;++i)
{
a[i]=lower_bound(lsh+1,lsh+cnt,a[i])-lsh;
modify(i,a[i],1);
}

for (i=1;i<=m;++i)
{
if (op[i]==3)
{
k[i]=lower_bound(lsh+1,lsh+cnt,k[i])-lsh;
modify(l[i],a[l[i]],-1);
modify(l[i],a[l[i]]=k[i],1);
}
else
{
totx=toty=0;
for (p=l[i]-1;p;p-=(p&-p)) X[0][++totx]=BIT[p];
for (p=r[i];p;p-=(p&-p)) Y[0][++toty]=BIT[p];
if (op[i]==2) printf("%d\n",kth(1,cnt,k[i]));
else
{
k[i]=lower_bound(lsh+1,lsh+cnt,k[i])-lsh;
if (op[i]==1) printf("%d\n",qsum(1,cnt,1,k[i],0)+1);
else if (op[i]==4)
{
int rk=qsum(1,cnt,1,k[i],0);
if (rk) printf("%d\n",kth(1,cnt,rk));
else printf("%d\n",-INF);
}
else
{
int rk=qsum(1,cnt,1,k[i]+1,0);
if (rk<=r[i]-l[i]) printf("%d\n",kth(1,cnt,rk+1));
else printf("%d\n",INF);
}
}
}
}

return 0;
}

void modify(int p,int x,int y)
{
for (;p<=n;p+=(p&-p)) BIT[p]=insert(BIT[p],1,cnt,x,y);
}

int insert(int x,int l,int r,int p,int type)
{
int u=++tot;
t[u]=t[x];
t[u].val+=type;
if (l==r-1) return u;
int mid=l+r>>1;
if (p<mid) t[u].ls=insert(t[u].ls,l,mid,p,type);
else t[u].rs=insert(t[u].rs,mid,r,p,type);
return u;
}

int qsum(int l,int r,int L,int R,int d)
{
if (l>=R||r<=L) return 0;
int i,sum=0;
if (L<=l&&R>=r)
{
for (i=1;i<=totx;++i) sum-=t[X[d][i]].val;
for (i=1;i<=toty;++i) sum+=t[Y[d][i]].val;
return sum;
}
for (i=1;i<=totx;++i) X[d+1][i]=t[X[d][i]].ls;
for (i=1;i<=toty;++i) Y[d+1][i]=t[Y[d][i]].ls;
sum=qsum(l,l+r>>1,L,R,d+1);
for (i=1;i<=totx;++i) X[d+1][i]=t[X[d][i]].rs;
for (i=1;i<=toty;++i) Y[d+1][i]=t[Y[d][i]].rs;
return sum+qsum(l+r>>1,r,L,R,d+1);
}

int kth(int l,int r,int k)
{
if (l==r-1) return lsh[l];
int i,sum=0;
for (i=1;i<=totx;++i) sum-=t[t[X[0][i]].ls].val;
for (i=1;i<=toty;++i) sum+=t[t[Y[0][i]].ls].val;
if (sum>=k)
{
for (i=1;i<=totx;++i) X[0][i]=t[X[0][i]].ls;
for (i=1;i<=toty;++i) Y[0][i]=t[Y[0][i]].ls;
return kth(l,l+r>>1,k);
}
for (i=1;i<=totx;++i) X[0][i]=t[X[0][i]].rs;
for (i=1;i<=toty;++i) Y[0][i]=t[Y[0][i]].rs;
return kth(l+r>>1,r,k-sum);
}

[CQOI2011]动态逆序对

这题用树套树做有点卡空间..需要把带返回值的动态开点改成直接修改。

代码
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
#include <iostream>
#include <cstdio>
#include <cctype>
#include <cstring>

using namespace std;

int read()
{
int out=0;
char c;
while (!isdigit(c=getchar()));
for (;isdigit(c);c=getchar()) out=out*10+c-'0';
return out;
}

const int N=100010;

struct Node
{
int val,ls,rs;
} t[N*90];

void change(int& u,int l,int r,int p);
int query(int l,int r,int L,int R,int d);

int n,m,a[N],p[N],del[N],BIT[N],x[20][20],y[20][20],xtot,ytot,tot;
bool deleted[N];
long long ans,out[N];

int main()
{
int i,j;

n=read();
m=read();

for (i=1;i<=n;++i)
{
a[i]=read();
p[a[i]]=i;
}

for (i=1;i<=m;++i)
{
del[i]=p[read()];
deleted[del[i]]=true;
}

for (i=n;i>=1;--i)
{
if (!deleted[i])
{
for (j=a[i];j;j-=(j&-j)) ans+=BIT[j];
for (j=a[i];j<=n;j+=(j&-j)) ++BIT[j];
}
}

memset(BIT,0,sizeof(BIT));

for (i=1;i<=n;++i)
{
if (!deleted[i])
{
for (j=i;j<=n;j+=(j&-j))
{
change(BIT[j],1,n+1,a[i]);
}
}
}

for (i=m;i>=1;--i)
{
xtot=ytot=0;
for (j=del[i];j;j-=(j&-j)) y[0][++ytot]=BIT[j];
ans+=query(1,n+1,a[del[i]]+1,n+1,0);
xtot=ytot=0;
for (j=del[i];j;j-=(j&-j)) x[0][++xtot]=BIT[j];
for (j=n;j;j-=(j&-j)) y[0][++ytot]=BIT[j];
ans+=query(1,n+1,1,a[del[i]],0);
for (j=del[i];j<=n;j+=(j&-j)) change(BIT[j],1,n+1,a[del[i]]);
out[i]=ans;
}

for (i=1;i<=m;++i) printf("%lld\n",out[i]);

return 0;
}

void change(int& u,int l,int r,int p)
{
if (!u) u=++tot;
++t[u].val;
if (l==r-1) return;
int mid=l+r>>1;
if (p<mid) change(t[u].ls,l,mid,p);
else change(t[u].rs,mid,r,p);
}

int query(int l,int r,int L,int R,int d)
{
if (l>=R||r<=L) return 0;
int i,ret=0,mid=l+r>>1;
if (L<=l&&R>=r)
{
for (i=1;i<=xtot;++i) ret-=t[x[d][i]].val;
for (i=1;i<=ytot;++i) ret+=t[y[d][i]].val;
return ret;
}
for (i=1;i<=xtot;++i) x[d+1][i]=t[x[d][i]].ls;
for (i=1;i<=ytot;++i) y[d+1][i]=t[y[d][i]].ls;
ret+=query(l,mid,L,R,d+1);
for (i=1;i<=xtot;++i) x[d+1][i]=t[x[d][i]].rs;
for (i=1;i<=ytot;++i) y[d+1][i]=t[y[d][i]].rs;
ret+=query(mid,r,L,R,d+1);
return ret;
}