简介
主席树是一种数据结构,其主要应用是区间第 $k$ 大问题
权值线段树
传统的线段树用于维护一段区间的值,可以方便地查询区间信息;而如果将线段树转化为权值线段树,每个叶子节点存储某个元素出现次数,一段区间的总和表示区间内所有数出现次数的总和,这样可以方便地求出整体第 $k$ 大:
从根节点向下走,如果 $k$ 小于等于左子树大小,说明第 $k$ 大在左子树的区间中,在左子树中继续查找即可;否则,说明第 $k$ 大在右子树的区间中,此时将 $k$ 减去左子树大小,并在右子树中继续查找
前缀和
权值线段树可以用来处理整体的第 $k$ 大,而我们可以对于一个长度为 $n$ 的序列建立 $n$ 棵上述的权值线段树,第 $i$ 棵表示 $a_1$ ~ $a_i$ 的所有数组成的权值线段树;如果要查询 $[l,r]$ 中的第 $k$ 大,可以使用第 $r$ 棵线段树减去第 $l-1$ 棵线段树,得到整个区间组成的权值线段树,并进行上述算法得到区间中的第 $k$ 大
但是这个算法存在两个问题:
1.每个线段树要占用 $O(nlogn)$ 的空间,算法的空间复杂度为 $O(n^2logn)$ ,占用空间过多
2.建立每棵线段树至少要用 $O(nlogn)$ 的时间,每次查询又要用 $O(nlogn)$ 的时间构建区间的权值线段树,总时间复杂度 $O((m+n)nlogn)$
看上去还比不上直接提取区间的暴力算法
主席树
仔细思考,发现上述算法的 $n$ 棵线段树中,相邻的两棵线段树仅有 $O(logn)$ 个节点不同,因此本质不同的节点只有 $O(nlogn)$ 个;我们可以充分利用这一特点,每次只重新创建与上次所不同的节点,相同的节点直接使用前一棵的即可
为了节省空间,可以将第 $0$ 棵线段树置为空,每次插入一个新叶子节点时接入一条长度为 $O(logn)$ 的链;总空间、时间复杂度仍为 $O(nlogn)$
查询时构造整棵线段树,需要构造 $O(nlogn)$ 个节点,但每次查询只会用到 $O(logn)$ 个节点,直接动态构造这些节点即可;为了方便,可以不显式构造这些节点,而是直接用两棵线段树上的值相减
构造
现在,假设 $n$ 为 $6$ ,序列为 $1,3,2,3,6,1$
那么每棵树都如下图所示:
因为是同一个问题,所以线段树都是一样的,只是节点的权值不一样,下图为第 $4$ 棵线段树,序列为 $1,3,2,3$ (红色数字为节点权值)
和线段树对比:1
2
3
4
5
6
7
8
9
10
11
12
13struct Node //线段树,4倍空间
{
int l,r;
int sum;
}t[maxn<<2];
struct Node //主席树,32倍空间
{
int ls,rs;
int sum;
}N[maxn<<5];
int RT[maxn<<5],Ncnt; //记录每棵树的根节点编号和总的节点数量
现有长度为 $6$ 的序列 $4,3,2,3,6,1$ ,重现一下前 $3$ 棵树的构造过程
一般来说,我们先建一棵空树1
2
3
4
5
6
7
8
9
10T.build(T.RT[0],1,len);
void build(int &rt,int l,int r)
{
rt=++Ncnt;
if(l==r) return;
int mid=(l+r)>>1;
build(N[rt].ls,l,mid);
build(N[rt].rs,mid+1,r);
}
然后再依次建树,这里 $b$ 数组是排序并去重后的 $a$ 数组,因为数可能会很大,所以用 $unique$ 去重,然后离散化1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16for(int i=1;i<=n;i++)
{
int p=lower_bound(b+1,b+len+1,a[i])-b;
T.RT[i]=T.update(T.RT[i-1],1,len,p);
}
int update(int rt,int l,int r,int p)
{
int nrt=++Ncnt;
N[nrt].ls=N[rt].ls,N[nrt].rs=N[rt].rs,N[nrt].sum=N[rt].sum+1;
if(l==r) return nrt;
int mid=(l+r)>>1;
if(p<=mid) N[nrt].ls=update(N[nrt].ls,l,mid,p);
else N[nrt].rs=update(N[nrt].rs,mid+1,r,p);
return nrt;
}
查询
根据前缀和思想,我们可以得到 $[l,r]$ 的权值线段树,然后就能很轻松地找到区间第 $k$ 大1
2
3
4
5
6
7
8
9
10int p=T.query(T.RT[l-1],T.RT[r],1,len,k);
cout<<b[p]<<'\n';
int query(int u,int v,int l,int r,int k)
{
int mid=(l+r)>>1,x=N[N[v].ls].sum-N[N[u].ls].sum;
if(l==r) return l;
if(x>=k) return query(N[u].ls,N[v].ls,l,mid,k);
else return query(N[u].rs,N[v].rs,mid+1,r,k-x);
}
题目
题目链接
题目大意
长度为 $n$ 的序列, $m$ 个询问,对于每个询问,输出指定区间的第 $k$ 大值
完整代码
1 |
|