简介

主席树是一种数据结构,其主要应用是区间第 $k$ 大问题

权值线段树

传统的线段树用于维护一段区间的值,可以方便地查询区间信息;而如果将线段树转化为权值线段树,每个叶子节点存储某个元素出现次数,一段区间的总和表示区间内所有数出现次数的总和,这样可以方便地求出整体第 $k$ 大:
从根节点向下走,如果 $k$ 小于等于左子树大小,说明第 $k$ 大在左子树的区间中,在左子树中继续查找即可;否则,说明第 $k$ 大在右子树的区间中,此时将 $k$ 减去左子树大小,并在右子树中继续查找

前缀和

权值线段树可以用来处理整体的第 $k$ 大,而我们可以对于一个长度为 $n$ 的序列建立 $n$ 棵上述的权值线段树,第 $i$ 棵表示 $a_1$ ~ $a_i$ 的所有数组成的权值线段树;如果要查询 $[l,r]$ 中的第 $k$ 大,可以使用第 $r$ 棵线段树减去第 $l-1$ 棵线段树,得到整个区间组成的权值线段树,并进行上述算法得到区间中的第 $k$ 大
但是这个算法存在两个问题:
1.每个线段树要占用 $O(nlogn)$ 的空间,算法的空间复杂度为 $O(n^2logn)$ ,占用空间过多
2.建立每棵线段树至少要用 $O(nlogn)$ 的时间,每次查询又要用 $O(nlogn)$ 的时间构建区间的权值线段树,总时间复杂度 $O((m+n)nlogn)$
看上去还比不上直接提取区间的暴力算法

主席树

仔细思考,发现上述算法的 $n$ 棵线段树中,相邻的两棵线段树仅有 $O(logn)$ 个节点不同,因此本质不同的节点只有 $O(nlogn)$ 个;我们可以充分利用这一特点,每次只重新创建与上次所不同的节点,相同的节点直接使用前一棵的即可
为了节省空间,可以将第 $0$ 棵线段树置为空,每次插入一个新叶子节点时接入一条长度为 $O(logn)$ 的链;总空间、时间复杂度仍为 $O(nlogn)$
查询时构造整棵线段树,需要构造 $O(nlogn)$ 个节点,但每次查询只会用到 $O(logn)$ 个节点,直接动态构造这些节点即可;为了方便,可以不显式构造这些节点,而是直接用两棵线段树上的值相减

构造

现在,假设 $n$ 为 $6$ ,序列为 $1,3,2,3,6,1$
那么每棵树都如下图所示:
1
因为是同一个问题,所以线段树都是一样的,只是节点的权值不一样,下图为第 $4$ 棵线段树,序列为 $1,3,2,3$ (红色数字为节点权值)
2
和线段树对比:

1
2
3
4
5
6
7
8
9
10
11
12
13
struct Node    //线段树,4倍空间
{
int l,r;
int sum;
}t[maxn<<2];

struct Node //主席树,32倍空间
{
int ls,rs;
int sum;
}N[maxn<<5];

int RT[maxn<<5],Ncnt; //记录每棵树的根节点编号和总的节点数量

现有长度为 $6$ 的序列 $4,3,2,3,6,1$ ,重现一下前 $3$ 棵树的构造过程
3
4
5

一般来说,我们先建一棵空树

1
2
3
4
5
6
7
8
9
10
T.build(T.RT[0],1,len);

void build(int &rt,int l,int r)
{
rt=++Ncnt;
if(l==r) return;
int mid=(l+r)>>1;
build(N[rt].ls,l,mid);
build(N[rt].rs,mid+1,r);
}

然后再依次建树,这里 $b$ 数组是排序并去重后的 $a$ 数组,因为数可能会很大,所以用 $unique$ 去重,然后离散化

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
for(int i=1;i<=n;i++)
{
int p=lower_bound(b+1,b+len+1,a[i])-b;
T.RT[i]=T.update(T.RT[i-1],1,len,p);
}

int update(int rt,int l,int r,int p)
{
int nrt=++Ncnt;
N[nrt].ls=N[rt].ls,N[nrt].rs=N[rt].rs,N[nrt].sum=N[rt].sum+1;
if(l==r) return nrt;
int mid=(l+r)>>1;
if(p<=mid) N[nrt].ls=update(N[nrt].ls,l,mid,p);
else N[nrt].rs=update(N[nrt].rs,mid+1,r,p);
return nrt;
}

查询

根据前缀和思想,我们可以得到 $[l,r]$ 的权值线段树,然后就能很轻松地找到区间第 $k$ 大

1
2
3
4
5
6
7
8
9
10
int p=T.query(T.RT[l-1],T.RT[r],1,len,k);
cout<<b[p]<<'\n';

int query(int u,int v,int l,int r,int k)
{
int mid=(l+r)>>1,x=N[N[v].ls].sum-N[N[u].ls].sum;
if(l==r) return l;
if(x>=k) return query(N[u].ls,N[v].ls,l,mid,k);
else return query(N[u].rs,N[v].rs,mid+1,r,k-x);
}

题目

题目链接

「洛谷」P3834

题目大意

长度为 $n$ 的序列, $m$ 个询问,对于每个询问,输出指定区间的第 $k$ 大值

完整代码

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
#include <bits/stdc++.h>
using namespace std;

const int maxn=2e5+5;

struct CTree
{
struct Node
{
int ls,rs;
int sum;
}N[maxn<<5];

int RT[maxn<<5],Ncnt;

void build(int &rt,int l,int r)
{
rt=++Ncnt;
if(l==r) return;
int mid=(l+r)>>1;
build(N[rt].ls,l,mid);
build(N[rt].rs,mid+1,r);
}

int update(int rt,int l,int r,int p)
{
int nrt=++Ncnt;
N[nrt].ls=N[rt].ls,N[nrt].rs=N[rt].rs,N[nrt].sum=N[rt].sum+1;
if(l==r) return nrt;
int mid=(l+r)>>1;
if(p<=mid) N[nrt].ls=update(N[nrt].ls,l,mid,p);
else N[nrt].rs=update(N[nrt].rs,mid+1,r,p);
return nrt;
}

int query(int u,int v,int l,int r,int k)
{
int mid=(l+r)>>1,x=N[N[v].ls].sum-N[N[u].ls].sum;
if(l==r) return l;
if(x>=k) return query(N[u].ls,N[v].ls,l,mid,k);
else return query(N[u].rs,N[v].rs,mid+1,r,k-x);
}
}T;

int a[maxn],b[maxn];

int main()
{
ios::sync_with_stdio(false);
int n,m;
cin>>n>>m;
for(int i=1;i<=n;i++) cin>>a[i],b[i]=a[i];
sort(b+1,b+n+1);
int len=unique(b+1,b+n+1)-b-1;

T.build(T.RT[0],1,len);
for(int i=1;i<=n;i++)
{
int p=lower_bound(b+1,b+len+1,a[i])-b;
T.RT[i]=T.update(T.RT[i-1],1,len,p);
}
for(int i=1;i<=m;i++)
{
int l,r,k;
cin>>l>>r>>k;
int p=T.query(T.RT[l-1],T.RT[r],1,len,k);
cout<<b[p]<<'\n';
}
return 0;
}