CDQ分治基础版

发布时间 2023-07-17 22:49:49作者: 铃狐sama

CDQ分治学习笔记——基础分治

(后面会有更复杂的优化dpCDQ) awa

我绝对不会承认因为我还不会CDQ优化dp所以才不写进阶分治的 QAQ


CDQ分治,怎么说呢,主要是为了优化时间复杂度用的,常用于多维偏序(找点对数量)

偏序:

比如对于一个变量(结构体)而言,有三个量a,b,c;而偏序就是可能两个结构体只要满足a,b,c的某种关系ans就++(关系一般是大于小于)

直接说可能很难理解,但是看例题就明白了(只看题,后面要说思想)https://www.luogu.com.cn/problem/P3810

CDQ分治思想:

我先随便按a/b/c中一个给这个序列排序,例如我要满足x.a<y.a(b,c这里暂且不管),那么我就从小到大排序
用处:此时我选择一个l,r,令r>l,此时就会出现三种情况
l,r都在mid左边,此时递归即可;l,r都在mid右边,此时还是递归即可。这两种不再赘述
l在mid左侧,r在mid右侧,此时就要考虑贡献了。
这剩下的就很像一个题目:给你一个坐标系,然后一些点,问你一个点左下角的点有多少个,这种题可用树状数组解决

但可能还不知道怎么做,这里详细说一下,假如说要满足x.b<y.b且x.c<y.c。那我mid左侧的按照从小到大b排序,右侧同样操作
然后我mid右边端点不动(设此时位置为i),左边端点(设位置为j)只要bj小于此时bi,j就可以继续增大。当要超过时,i就统计答案然后i++继续
注意,1<=j<=mid;mid+1<=i<=R
然后i怎么统计答案呢,我现在已经知道对于右侧端点,b满足他的所有左侧端点了,现在只需要满足c就可以了
那么每一次j++前把cj加入树状数组里,在i++前查询<ci的有多少个就好了!

但是这其中有两个很重要的点:
1.写cmp的时候尽可能详细,也就是说,要把x.a=y.a考虑,x.b=y.b考虑。一个也不能少,不然很有可能wa但是样例又过
2.记得加入cj,计算完这一组CDQ(l,r)后,要清空树状数组,但不用memset(TLE警告)
3.一定要注意顺序,实现CDQ后面的,再sort计算y<=mid<=x的情况,而dp又和这个不一样

例题一

CDQ 模板:https://www.luogu.com.cn/problem/P3810

代码:



#include<bits/stdc++.h>
using namespace std;
#define int long long 
/*
调了半天,错在cmpa我只判断a的大小去了,没考虑b,cwa 
然后换行打成“”wa 
然后树状数组开小了wa on 10 
*/
struct node{
	int a;
	int b;
	int c;
	int cnt;//重复点的个数 
	int ans;//满足条件的个数 
}dot[200005],a[200005];
int as[200005];

bool cmpa(node x,node y){
	if(x.a==y.a)
	{
		if(x.b==y.b)return x.c<y.c;
		else return x.b<y.b;
	}
	else return x.a<y.a;
}//第一维排序
bool cmpb(node x,node y){
	if(x.b!=y.b)return x.b<y.b;
	else return x.c<y.c;
}

//--------------------------------------
int n,k;
int tree[400005];
int lowbit(int x){
	return x&(-x);
}
void add(int pos,int val){
	for(int i=pos;i<=k;i+=lowbit(i)){
		tree[i]+=val;
	}
}
int query(int pos){
	int ret=0;
	for(int i=pos;i>=1;i-=lowbit(i)){
		ret+=tree[i];
	}
	return ret;
}
//-------------------
void cdq(int l,int r){
	if(l==r){
		return ;
	}
	int mid=(l+r)/2;
//	cout<<"test"<<mid<<endl ;
	cdq(l,mid);
	cdq(mid+1,r);
	sort(a+l,a+mid+1,cmpb);
	sort(a+mid+1,a+r+1,cmpb);
	int i,j=l;
	for(i=mid+1;i<=r;i++){
		while(a[i].b>=a[j].b&&j<=mid){
			add(a[j].c,a[j].cnt);
			j++;
		}
		a[i].ans+=query(a[i].c);
//		cout<<"test"<<a[i].ans<<endl;
	}
	for(i=l;i<=j-1;i++){
		add(a[i].c,-a[i].cnt);
	}
	
	
}

signed main(){
	ios::sync_with_stdio(false);
	
	cin >> n >> k;
	for(int i=1;i<=n;i++){
		cin >> dot[i].a>>dot[i].b>>dot[i].c;
	}
	sort(dot+1,dot+1+n,cmpa);
	int cnt=0;
	int top=0;
	for(int i=1;i<=n;i++){
		top++;
		if(dot[i].a!=dot[i+1].a||dot[i].b!=dot[i+1].b||dot[i].c!=dot[i+1].c){
			cnt++;
			a[cnt].a=dot[i].a;
			a[cnt].b=dot[i].b;
			a[cnt].c=dot[i].c;
			a[cnt].cnt=top;
			a[cnt].ans=0;
			top=0;
		}
	}
	cdq(1,cnt);
	for(int i=1;i<=cnt;i++){
		as[a[i].ans+a[i].cnt-1]+=a[i].cnt;
	}
	for(int i=0;i<=n-1;i++){
		cout<<as[i]<<endl;
	}
//	for(int i=1;i<=k;i++){
//		cout<<a[i].ans<<" test ";
//	}
	
}


然后还有不是那么明显的三维偏序(这一维可能是范围、时间等等,请看下面两题)

例题二

https://www.luogu.com.cn/problem/P8575 这道题隐藏第三维是dfn

#include<bits/stdc++.h>
using namespace std;
#define int long long
/*
点对满足
blue[to]<blue[x]
red[to]<red[x];
to在x子树内

而最后一个条件可以视为dfn[x]+1<=dfn[to]<=dfn[x]+sz[x]-1 
*/
vector<int>mp[200005];
int n;
struct star{
	int blue;
	int red;
	int sz;
	int dfn;
	int id;
}st[200005];
int ans[200005];
int t=0;
void predfs(int x,int fa){
	st[x].sz=1;
	t++;
	st[x].dfn=t;
	for(int i=0;i<mp[x].size();i++){
		int to=mp[x][i];
		if(to==fa){
			continue;
		}
		predfs(to,x);
		st[x].sz+=st[to].sz;
	}
}
bool cmp1(star x,star y){
	if(x.blue!=y.blue)return x.blue<y.blue;
	else if(x.red!=y.red)return x.red<y.red;
	else return x.dfn>y.dfn;
}
bool cmp2(star x,star y){
	if(x.red!=y.red)return x.red<y.red;
	else return x.blue<y.blue;
}
int tree[200005];
//-----------------------------------------
int lowbit(int x){
	return x&(-x);
}
void add(int pos,int val){
	for(int i=pos;i<=n;i+=lowbit(i)){
		tree[i]+=val;
	}
}
int query(int pos){
	int ret=0;
	for(int i=pos;i>=1;i-=lowbit(i)){
		ret+=tree[i];
	}
	return ret;
}
//---------------------------------------
void cdq(int l,int r){
	if(l==r){
		return ;
	} 
	int mid=(l+r)/2;
	cdq(l,mid);
	cdq(mid+1,r);
	sort(st+l,st+mid+1,cmp2);
	sort(st+mid+1,st+1+r,cmp2);
	int i;
	int j=l;
	for(i=mid+1;i<=r;i++){
		while(st[i].red>=st[j].red&&j<=mid){
			add(st[j].dfn,1);
			j++;
		}
		ans[st[i].id]+=query(min(st[i].dfn+st[i].sz-1,n))-query(st[i].dfn);
//		cout<<"test"<<a[i].ans<<endl;
	}
	for(i=l;i<=j-1;i++){
		add(st[i].dfn,-1);
	}

}
signed main(){
	ios::sync_with_stdio(false);
	cin >> n;
	for(int i=1;i<=n-1;i++){
		int u,v;
		cin >> u >> v;
		mp[u].push_back(v);
		mp[v].push_back(u);
	}
	for(int i=1;i<=n;i++){
		cin >> st[i].blue >> st[i].red;
		st[i].id=i;
	}
	predfs(1,1);
	sort(st+1,st+1+n,cmp1);
	cdq(1,n);
	for(int i=1;i<=n;i++){
		if(ans[i]){
			cout<<ans[i]<<endl;
		}
	}
}

例题三

https://www.luogu.com.cn/problem/P3157#submit
隐藏第三维是时间



#include<bits/stdc++.h>
using namespace std;
#define int long long
/*
发现只要求出初始逆序对,然后每次求出降低的逆序对即可
对于每一个被删的元素,消失的逆序对等于
在它前面,权值比他大,且删去时间比他晚的点个数
       				+
在它后面,权值比他小,且删去时间比他晚的点个数
*/
struct node{
	int m;
	int v;
	int d;
	int id;
	int t;
}e[400025];
bool cmp1(node x,node y){
	return x.d<y.d;
}
int a[400005];
int pos[400005],ans[400005];
//-----------------------------------
int n,k,m;
int tree[400005];
int lowbit(int x){
	return x&(-x);
}
void add(int poss,int val){
	for(int i=poss;i<=n;i+=lowbit(i)){
		tree[i]+=val;
	}
}
int query(int poss){
	int ret=0;
	for(int i=poss;i>=1;i-=lowbit(i)){
		ret+=tree[i];
	}
	return ret;
}
//-------------------------------------
void cdq(int l,int r){
    if (l==r){
    	return;
	} 
    int mid=(l+r)>>1;
	int j=l;
    cdq(l,mid),cdq(mid+1,r);
    sort(e+l,e+mid+1,cmp1);
    sort(e+mid+1,e+r+1,cmp1);
    for(int i=mid+1;i<=r;i++){
        while(j<=mid&&e[j].d<=e[i].d){
        	add(e[j].v,e[j].m);
			j++;
		}
        ans[e[i].id]+=e[i].m*(query(n)-query(e[i].v));
    }
    for (int i=l;i<j;i++){
    	add(e[i].v,-e[i].m);	
	}
    j=mid;
    for (int i=r;i>mid;i--){
        while (j>=l&&e[j].d>=e[i].d){
        	add(e[j].v,e[j].m);
			j--;	
		}
        ans[e[i].id]+=e[i].m*query(e[i].v-1);
    }
    for (int i=mid;i>j;i--){
    	add(e[i].v,-e[i].m);
	} 
}
int tot;
signed main(){
	ios::sync_with_stdio(false);
    cin >> n >> m;
    for (int i=1;i<=n;i++){
    	cin >> a[i];
		pos[a[i]]=i;
		e[++tot]=(node){1,a[i],i,0,tot};
	}
    for (int i=1;i<=m;i++){
    	int x;
		cin >> x;
		e[++tot]=(node){-1,x,pos[x],i,tot};	
	}
    cdq(1,tot);
    for(int i=1;i<=m;++i){
    	ans[i]+=ans[i-1];
	} 
    for(int i=0;i<m;++i){
    	cout<<ans[i]<<endl;	
	}

	 
}