cdq分治学习笔记
文章目录
【注意】最后更新于 January 10, 2021,文中内容可能已过时,请谨慎使用。
cdq分治也是咕了好久了..最近总算把它学了。
cdq分治是一种离线算法,可以代替一些复杂的数据结构,降低代码难度,减小常数。废话大家都知道。
本文未完待续(cdq分治的其它应用,如维护凸壳,待填坑)。
简介
感觉cdq分治不如叫“ex归并排序”,就是以操作的时间作为初始顺序,在递归处理的过程中按位置归并排序。
更一般地说,对于一个二维偏序 $P(i,j)=P_1(a_i,a_j)\land P_2(b_i,b_j)$,位置 $i$ 的修改对位置 $j$ 的询问(询问为类前缀和形式,区间询问需拆成两个前缀询问)有影响当且仅当 $P(i,j)=true$,cdq分治就是以其中一维为初始顺序,对另一维进行归并排序的过程中计算左区间里修改的总和,将左区间修改的影响应用到右区间。
学会了之后就会发现,cdq分治的确就是这样,已经描述的很清楚了,然而在没学会的时候估计是看不懂上面这段话的..所以结合具体题目来看一看吧。
例题
【模板】树状数组 1
树状数组裸题!冷静,我们来用ex归并排序做..(嗯,我决定就这么叫它了)
按照我们上面说的,我们把操作存下来,询问拆成两个前缀和相减,初始值视作修改,需要存的信息有操作的种类(修改、询问的左端点减一、询问的右端点),操作的位置($p$、$l-1$、$r$)以及修改加上的值/询问的编号。如果写法正常的话你已经以操作的时间作为初始顺序了..
然后,写个归并排序,按操作的位置排序,同一个位置的修改要放在询问的前面。然后,在归并排序的过程中,遇到左区间里的修改就更新左区间修改的总和,遇到右区间里的询问就用记录的“左区间修改的总和”更新这个询问的答案。
具体见代码
#include <iostream>
#include <cstdio>
using namespace std;
typedef long long ll;
const int N=500010;
struct Node
{
int type,p,val; //type为2表示修改,type为-1表示左端点减一,type为1表示右端点
bool operator<(const Node& b) const { return p==b.p?type>b.type:p<b.p; }
} q[N<<2],tmp[N<<2];
void solve(int l,int r);
int n,m,tot,qtot;
ll ans[N];
int main()
{
int i,op,x,y;
scanf("%d%d",&n,&m);
for (i=1;i<=n;++i) //初始值视作修改
{
scanf("%d",&x);
q[++tot].type=2;
q[tot].p=i;
q[tot].val=x;
}
for (i=1;i<=m;++i)
{
scanf("%d%d%d",&op,&x,&y);
if (op==1)
{
q[++tot].type=2;
q[tot].p=x;
q[tot].val=y;
}
else //询问拆成两个前缀和相减
{
q[++tot].type=-1;
q[tot].p=x-1;
q[tot].val=++qtot;
q[++tot].type=1;
q[tot].p=y;
q[tot].val=qtot;
}
}
solve(1,tot+1);
for (i=1;i<=qtot;++i) printf("%lld\n",ans[i]);
return 0;
}
void solve(int l,int r)
{
if (l==r-1) return;
int i,j,k,mid;
ll sum=0;
i=k=l;
j=mid=(l+r)>>1;
solve(l,mid);
solve(mid,r);
while (i<mid&&j<r)
{
if (q[i]<q[j])
{
if (q[i].type==2) sum+=q[i].val; //记录左区间里的修改之和
tmp[k++]=q[i++];
}
else
{
if (q[j].type!=2) ans[q[j].val]+=q[j].type*sum; //将左区间里的修改应用到右区间里的询问
tmp[k++]=q[j++];
}
}
while (i<mid) tmp[k++]=q[i++];
while (j<r)
{
if (q[j].type!=2) ans[q[j].val]+=q[j].type*sum;
tmp[k++]=q[j++];
}
for (i=l;i<r;++i) q[i]=tmp[i];
}
之前说过ex归并排序本质上是一个二维偏序限制了修改对询问的影响,所以也可以先按位置排序再按时间排序。只不过..这样写很奇怪,很麻烦,常数又大。然而为了理解ex归并排序的本质,我还是写了份这个做法..
一种奇怪的写法
#include <iostream>
#include <cstdio>
#include <algorithm>
using namespace std;
typedef long long ll;
const int N=500010;
struct Node
{
int type,tim,p,val;
Node(int _type=0,int _tim=0,int _p=0,int _val=0):type(_type),tim(_tim),p(_p),val(_val){}
} q[N<<2],tmp[N<<2];
void solve(int l,int r);
int n,m,tot,qtot;
ll ans[N];
int main()
{
int i,j,op,x,y;
scanf("%d%d",&n,&m);
for (i=1;i<=n;++i)
{
scanf("%d",&x);
q[++tot]=Node(2,0,i,x);
}
for (i=1;i<=m;++i)
{
scanf("%d%d%d",&op,&x,&y);
if (op==1) q[++tot]=Node(2,i,x,y);
else
{
q[++tot]=Node(-1,i,x-1,++qtot);
q[++tot]=Node(1,i,y,qtot);
}
}
sort(q+1,q+tot+1,[](const Node& x,const Node& y){return x.p==y.p?x.type>y.type:x.p<y.p;});
solve(1,tot+1);
for (i=1;i<=qtot;++i) printf("%lld\n",ans[i]);
return 0;
}
void solve(int l,int r)
{
if (l==r-1) return;
int i,j,k,mid;
ll sum=0;
i=k=l;
j=mid=(l+r)>>1;
solve(l,mid);
solve(mid,r);
while (i<mid&&j<r)
{
if (q[i].tim<q[j].tim)
{
if (q[i].type==2) sum+=q[i].val;
tmp[k++]=q[i++];
}
else
{
if (q[j].type!=2) ans[q[j].val]+=q[j].type*sum;
tmp[k++]=q[j++];
}
}
while (i<mid) tmp[k++]=q[i++];
while (j<r)
{
if (q[j].type!=2) ans[q[j].val]+=q[j].type*sum;
tmp[k++]=q[j++];
}
for (i=l;i<r;++i) q[i]=tmp[i];
}
【模板】三维偏序(陌上花开)
有两种做法,一种是cdq分治套树状数组,需要注意的有两点,一是清空树状数组可以用时间戳,二是 $a$, $b$, $c$ 都相等的元素要合并。
参考代码
#include <iostream>
#include <cstdio>
#include <algorithm>
using namespace std;
const int N=100010;
const int K=200010;
struct Node
{
int a,b,c,w,f;
} a[N],b[N];
void solve(int l,int r);
void add(int p,int x);
int query(int p);
int n,k,d[N],BIT[K],vis[K],tim,tot;
int main()
{
int i;
scanf("%d%d",&n,&k);
for (i=1;i<=n;++i)
{
scanf("%d%d%d",&b[i].a,&b[i].b,&b[i].c);
b[i].w=1;
}
sort(b+1,b+n+1,[](const Node& x,const Node& y){return x.a==y.a?(x.b==y.b?x.c<y.c:x.b<y.b):x.a<y.a;});
for (i=1;i<=n;++i)
{
if (b[i].a!=b[i+1].a||b[i].b!=b[i+1].b||b[i].c!=b[i+1].c) a[++tot]=b[i];
else b[i+1].w+=b[i].w;
}
solve(1,tot+1);
for (i=1;i<=tot;++i) d[a[i].f+a[i].w]+=a[i].w;
for (i=1;i<=n;++i) printf("%d\n",d[i]);
return 0;
}
void solve(int l,int r)
{
if (l==r-1) return;
int i,j,k,mid;
i=k=l;
j=mid=(l+r)>>1;
solve(l,mid);
solve(mid,r);
++tim;
while (i<mid&&j<r)
{
if (a[i].b<=a[j].b)
{
add(a[i].c,a[i].w);
b[k++]=a[i++];
}
else
{
a[j].f+=query(a[j].c);
b[k++]=a[j++];
}
}
while (i<mid) b[k++]=a[i++];
while (j<r)
{
a[j].f+=query(a[j].c);
b[k++]=a[j++];
}
for (i=l;i<r;++i) a[i]=b[i];
}
void add(int p,int x)
{
for (;p<=k;p+=(p&-p))
{
if (vis[p]!=tim)
{
BIT[p]=0;
vis[p]=tim;
}
BIT[p]+=x;
}
}
int query(int p)
{
int out=0;
for (;p;p-=(p&-p)) if (vis[p]==tim) out+=BIT[p];
return out;
}
还有一种做法是cdq分治套cdq分治:
参考代码
#include <iostream>
#include <cstdio>
#include <algorithm>
using namespace std;
const int N=100010;
struct Node
{
int a,b,c,d,w,id;
} a[N],b[N],c[N];
void solve(int l,int r);
void solve2(int l,int r);
int n,k,d[N],tot,ans[N];
int main()
{
int i;
scanf("%d%d",&n,&k);
for (i=1;i<=n;++i)
{
scanf("%d%d%d",&b[i].a,&b[i].b,&b[i].c);
b[i].w=1;
b[i].id=i;
}
sort(b+1,b+n+1,[](const Node& x,const Node& y){return x.a==y.a?(x.b==y.b?x.c<y.c:x.b<y.b):x.a<y.a;});
for (i=1;i<=n;++i)
{
if (b[i].a!=b[i+1].a||b[i].b!=b[i+1].b||b[i].c!=b[i+1].c) a[++tot]=b[i];
else b[i+1].w+=b[i].w;
}
solve(1,tot+1);
for (i=1;i<=tot;++i) d[ans[a[i].id]+a[i].w]+=a[i].w;
for (i=1;i<=n;++i) printf("%d\n",d[i]);
return 0;
}
void solve(int l,int r)
{
if (l==r-1) return;
int i,j,k,mid;
i=k=l;
j=mid=(l+r)>>1;
solve(l,mid);
solve(mid,r);
while (i<mid&&j<r)
{
if (a[i].b<=a[j].b)
{
a[i].d=a[i].w;
b[k++]=a[i++];
}
else
{
a[j].d=0;
b[k++]=a[j++];
}
}
while (i<mid)
{
a[i].d=a[i].w;
b[k++]=a[i++];
}
while (j<r)
{
a[j].d=0;
b[k++]=a[j++];
}
for (i=l;i<r;++i) a[i]=b[i];
solve2(l,r);
}
void solve2(int l,int r)
{
if (l==r-1) return;
int i,j,k,mid,sum=0;
i=k=l;
j=mid=(l+r)>>1;
solve2(l,mid);
solve2(mid,r);
while (i<mid&&j<r)
{
if (b[i].c<=b[j].c)
{
sum+=b[i].d;
c[k++]=b[i++];
}
else
{
if (!b[j].d) ans[b[j].id]+=sum;
c[k++]=b[j++];
}
}
while (i<mid) c[k++]=b[i++];
while (j<r)
{
if (!b[j].d) ans[b[j].id]+=sum;
c[k++]=b[j++];
}
for (i=l;i<r;++i) b[i]=c[i];
}
cdq分治求偏序对的本质
(下文中“偏序问题”即求满足偏序关系的数对个数。而”高维偏序“实际上是多个严格弱序的并。非严格偏序与之类似,主要是在代码上有些细节改变。)
大家知道,二维偏序可以先按一维排序后用普通的归并排序解决,那为什么“三维偏序”不可以呢?
首先,按其中一维排序相当于降了一维,问题就变成了“为什么一维偏序可以用普通的归并排序解决,而二维偏序不可以”。
原因就在于,两个偏序关系的并,不一定具有不可比性的传递性。(Strict Weak Ordering 相关内容参见我的另一篇博客)
可以证明,两个严格弱序的并一定是一个严格偏序,但不一定是一个严格弱序。而 cdq 分治可以将多个严格弱序的并进行降维,每一次 cdq 分治都标记出哪些位置会对其它位置有贡献,并按某一维排序。
上面说的有点乱..简单概括一下。排序可以降维,只有严格弱序能排序,高维偏序不一定是严格弱序,cdq 分治在排序的过程中标记了元素之间如何贡献答案。所以 cdq 分治就可以解决高维偏序问题了..
所以,我们可以写出一份 cdq 分治求 $k$ 维偏序对的代码:
题意简述:第一行 $n$, $k$,后 $n$ 行每行 $k$ 个数 $a_{i,1..k}$,对每个 $i$ 求 $\forall d\in[1,k],a_{i,d}<a_{j,d}$ 的 $j$ 个数。
当然,$n$ 要足够大,否则会被暴力艹。只不过理论上来说,如果维数是常数,复杂度就比暴力更优…
代码
#include <iostream>
#include <cstdio>
#include <vector>
#include <algorithm>
using namespace std;
const int N=10005;
const int K=15;
void solve(int l,int r,int d);
int n,k,ans[N];
struct Node
{
int w[K],type,id;
bool operator<(const Node& y) const
{
if (w[0]!=y.w[0]) return w[0]<y.w[0];
for (int i=1;i<k;++i) if (w[i]!=y.w[i]) return w[i]>y.w[i]; //如果是非严格偏序都应该顺着排,严格偏序除了第一维都应该倒着排。这是由于相等元素可以/不可以转移。
return false;
}
} a[K][N],tmp[N];
int main()
{
int i,j;
scanf("%d%d",&n,&k);
for (i=1;i<=n;++i)
{
a[0][i].id=i;
for (j=0;j<k;++j) scanf("%d",a[0][i].w+j);
}
sort(a[0]+1,a[0]+n+1);
solve(1,n+1,1);
for (i=1;i<=n;++i) printf("%d\n",ans[i]);
return 0;
}
void solve(int l,int r,int d)
{
if (l==r-1) return;
int i,j,p,mid,sum=0;
i=p=l;
j=mid=(l+r)>>1;
solve(l,mid,d);
solve(mid,r,d);
while (i<mid&&j<r)
{
if (a[d-1][i].w[d]<a[d-1][j].w[d])
{
a[d][p]=tmp[p]=a[d-1][i++];
if (d>1&&a[d][p].type!=1) a[d][p].type=0;
else
{
a[d][p].type=1;
if (d==k-1) ++sum;
}
++p;
}
else
{
a[d][p]=tmp[p]=a[d-1][j++];
if (d>1&&a[d][p].type!=2) a[d][p].type=0;
else
{
a[d][p].type=2;
if (d==k-1) ans[a[d][p].id]+=sum;
}
++p;
}
}
while (i<mid)
{
a[d][p]=tmp[p]=a[d-1][i++];
if (d>1&&a[d][p].type!=1) a[d][p].type=0;
else a[d][p].type=1;
++p;
}
while (j<r)
{
a[d][p]=tmp[p]=a[d-1][j++];
if (d>1&&a[d][p].type!=2) a[d][p].type=0;
else
{
a[d][p].type=2;
if (d==k-1) ans[a[d][p].id]+=sum;
}
++p;
}
for (i=l;i<r;++i) a[d-1][i]=tmp[i];
if (d<k-1) solve(l,r,d+1);
}
评论正在加载中...如果评论较长时间无法加载,你可以 搜索对应的 issue 或者 新建一个 issue 。