学习笔记_ 主席树 [线段树, 主席树, 数据结构]

主席树可以解决一些在序列上查找区间第k大的问题

扯淡的 主席树不详解

静态区间主席树(不资瓷修改)

主席树其实可以说是一类可以合并的树的前缀和

最常用的应该就是权值线段树的前缀和, 可以用来求区间内的第k大
我们在求整个区间的第k大的时候用到了权值线段树的左右儿子查询, 那么我们怎么查询一个给定区间而不是整个区间的第k大呢?

如果我们给这个给定的区间建权值线段树, 我们就可以找到这个区间的第k大了, 但显然我们是没有办法建$n^2$棵权值线段树的, 考虑平常我们做这种区间问题一般用什么…? 前缀和! 我们只要知道$[1 , le - 1]$和$[1, ri]$的权值线段树, 让每个节点相减($[1, ri]$的减去$[1, le - 1]$的), 在得到的新的权值线段树上找全体第k小不就行了?

然而…并没有足够的空间和时间让你建n棵权值线段树(时间空间都是$n^2$), 然后我们再回想一下前缀和的处理,

1
f[i] = f[i - 1] + a[i];

我们要最大限度的利用当前位置和上一个位置的相同点(不同的只有一个), 考虑权值线段树新加入一个点, 只会对一条链产生影响, 那么我们只需要新建这一条链, 其他的还采用上一个权值线段树即可

每次不改变值, 只是新建一条链…是不是想到了可持久化线段树? 所以主席树就是一棵可持久化线段树, 只不过每一个历史版本$x$代表$[1, x]$上的权值线段树

这样我们每次插入一个叶子节点x, 就看x是在当前节点的左子区间还是右子区间, 在哪个区间就递归建下去, 另一个区间直接赋成上一个前缀线段树的区间即可

主席树的建树和查询

不要看着麻烦…其实写起来超休闲的~

动态区间主席树(资瓷单点修改)

考虑权值线段树的修改(x -> y), 其实就是让x节点(和他上面的链)值--, 然后把另一个节点y(和上面链)的节点值++, 然后我们就发现…这得修改$n\log_2 n$个节点(前缀和修改一个数就要把后面的全部修改掉, 修改一次是$\log_2 n$)肯定承受不了…那么想一想怎么快速的修改单点然后维护前缀和呢…树状数组!

这里有很多说法说是树状数组套主席树…其实我觉得似乎并不准确…一开始把我直接搞蒙了…还在研究怎么把这玩意套进去…
其实说成树状数组套权值线段树也许更好理解…或者说主席树不再是只从他的前一个版本新添加链了, 而是维护了一个更短的区间$[i - lowbit(i), i]$(原先是$[1, i]$)

大概就是这样…因为维护的不再是一整个区间, 所以求权值和要多一个$log$, 因为维护的不再是一整个区间, 所以修改也变成了$log$, 总的复杂度$nlog_2^2n$

看!主席动起来了!

具体实现挖个坑…

沙茶的 超大常数丑陋主席树代码

简单模板(Luogu3834)

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
// luogu-judger-enable-o2
#include <iostream>
#include <cstdio>
#include <cstring>
#include <cstdlib>
#include <algorithm>
#define pii pair<int, int>
#define MAXN (200000 + 5)
#define INF (0x7ffffff)
using namespace std;
struct node
{
int ls, rs, le, ri, zh; // 代表: 当前节点的左, 右儿子(因为线段树可持久化了, 所以不能再愉快的((dq) << 1) 和 (((dq) << 1) | 1) 了); 维护的区间左, 右边界; 当前节点的区间和(即在这个范围的数据个数)
}b[MAXN << 5];
int n, root[MAXN], cntnode, soa[MAXN], a[MAXN]; // root: 每一个前缀和的根节点编号
bool cmp(int, int);
void init();
void js(int&, int, int, int, int);
int cx(int, int, int);
int main(int argc, char const *argv[])
{
int q;
scanf("%d%d", &n, &q);
for (int i = 1; i <= n; i++)
{
scanf("%d", &soa[i]);
a[i] = soa[i];
}
init();
for (int i = 1; i <= n; i++)
js(root[i], root[i - 1], 1, n, a[i]);
for (int i = 1; i <= q; i++)
{
int srl, srr, srk;
scanf("%d%d%d", &srl, &srr, &srk);
printf("%d\n", cx(root[srl - 1], root[srr], srk));
}
return 0;
}
void init() // 离散化, 这就是以后的标准操作了...代码短且可靠
{
sort(soa + 1, soa + n + 1, cmp);
int dqn = unique(soa + 1, soa + n + 1) - soa; // 注意一定要去重, 否则同一个数据就不能在线段树的一个节点上了
for (int i = 1; i <= n; i++)
{
a[i] = lower_bound(soa + 1, soa + dqn + 1, a[i]) - soa;
}
}
void js(int& dq, int pre, int le, int ri, int zh) // 注意这里传引用非常的妙, 可以在不知道当前节点儿子编号的情况下继续递归; pre就是上一个版本的线段树
{
dq = ++cntnode;
b[dq].le = le, b[dq].ri = ri;
if (le == ri)
{
b[dq].zh = b[pre].zh + 1; // 在上一个版本(前一个前缀和)的基础上添了一个
return ;
}
int mi = (le + ri) >> 1;
if (zh <= mi)
{
js(b[dq].ls, b[pre].ls, le, mi, zh);
b[dq].rs = b[pre].rs; // 插入的值不在这一个区间, 这个区间没有变化, 直接把边拉过去就行
}
else
{
js(b[dq].rs, b[pre].rs, mi + 1, ri, zh);
b[dq].ls = b[pre].ls;
}
b[dq].zh = b[b[dq].ls].zh + b[b[dq].rs].zh;
}
int cx(int le, int ri, int k)
{
if (b[ri].le == b[ri].ri)
return soa[b[ri].le]; // 注意这个地方应该返回ri的区间, 因为le所在的版本可能还没有这个数就只能返回0了...
int mi = (b[le].le + b[le].ri) >> 1, c = b[b[ri].ls].zh - b[b[le].ls].zh; // c: 利用前缀和求区间的权值线段树
if (k <= c) // 判断左子树是否够k个数
return cx(b[le].ls, b[ri].ls, k);
else
return cx(b[le].rs, b[ri].rs, k - c);
}
bool cmp(int x, int y)
{ return x < y; }

简单模板(在树上)(Luogu2633)

不知道为啥脑子抽了求LCA的时候写错了…就是在枚举d数组的时候是这么写的:

1
2
3
4
5
6
7
for (int i = lgn; i; i--)
{
if (deep[x] == deep[y])
break;
if (deep[x] <= deep[d[y][i]])
y = d[y][i];
}

然后这个i就到不了0了….然后就GG了

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
// 辣鸡BZOJ卡我常数削我空间还让我PE????
#include <iostream>
#include <cstdio>
#include <cstring>
#include <cstdlib>
#include <algorithm>
#define MAXN (300000 + 5)
#define MAXM (100000 + 5)
#define MAXL (25)
#define swap(a, b) { int t = a; a = b; b = t; }
using namespace std;
struct edg
{
int from, to, next;
edg() {}
edg(int a, int b, int c): from(a), to(b), next(c) {}
} tb[MAXM << 1];
int n, cost[MAXN], cntbt, tg[MAXN], d[MAXN][MAXL], deep[MAXN], fa[MAXN], lgn, scost[MAXN];

struct node
{
int ls, rs, zh;
} b[MAXN << 4];
int root[MAXN], cntnode;

void js(int&, int, int, int, int);
int cx(int, int, int, int, int, int, int);
void init(int, int);
inline int lg(int);
void adn(int, int);
int lca(int, int);
inline bool cmp(int, int);
int main(int argc, char const *argv[])
{
int q;
scanf("%d%d", &n, &q);
lgn = lg(n);
for (int i = 1; i <= n; i++)
{
scanf("%d", &scost[i]);
cost[i] = scost[i];
}

{
sort(scost + 1, scost + n + 1, cmp);
int tn = unique(scost + 1, scost + n + 1) - scost;
for (int i = 1; i <= n; i++)
cost[i] = lower_bound(scost + 1, scost + tn, cost[i]) - scost;
}

for (int i = 1; i < n; i++)
{
int srx, sry;
scanf("%d%d", &srx, &sry);
adn(srx, sry);
adn(sry, srx);
}
fa[1] = 0;
init(1, 1);
int lastans = 0;
for (int i = 1; i <= q; i++)
{
int srx, sry, srk, lcaxy;
scanf("%d%d%d", &srx, &sry, &srk);
#ifndef DEBUG
srx ^= lastans;
#endif
lcaxy = lca(srx, sry);
printf("%d", lastans = cx(root[lcaxy], root[fa[lcaxy]], root[srx], root[sry], srk, 1, n));
if (i < q)
puts("");
}
return 0;
}
void adn(int from, int to)
{
tb[++cntbt] = edg(from, to, tg[from]);
tg[from] = cntbt;
}
void init(int dq, int de)
{
js(root[dq], root[fa[dq]], 1, n, cost[dq]);
deep[dq] = de;
for (int i = 1; i <= lgn; i++)
{
d[dq][i] = d[d[dq][i - 1]][i - 1];
}
for (int i = tg[dq]; i; i = tb[i].next)
{
if (tb[i].to != fa[dq])
{
d[tb[i].to][0] = fa[tb[i].to] = dq;
init(tb[i].to, de + 1);
}
}
}
int lca(int x, int y)
{
if (deep[y] < deep[x])
swap(x, y);
for (int i = lgn; i >= 0; i--)
{
if (deep[x] == deep[y])
break;
if (deep[x] <= deep[d[y][i]])
y = d[y][i];
}
if (x == y)
return x;
for (int i = lgn; i >= 0; i--)
if (d[x][i] != d[y][i])
x = d[x][i], y = d[y][i];
return fa[x];
}
inline int lg(int x)
{
int re = 0;
for (; (1 << re) <= x; re++)
{}
return re;
}
// HJTREE
void js(int& dq, int pre, int le, int ri, int zh)
{
dq = ++cntnode;
// b[dq].le = le, b[dq].ri = ri;
if (le == ri)
{
b[dq].zh = b[pre].zh + 1;
return ;
}
int mi = (le + ri) >> 1;
if (zh <= mi)
{
js(b[dq].ls, b[pre].ls, le, mi, zh);
b[dq].rs = b[pre].rs;
}
else
{
js(b[dq].rs, b[pre].rs, mi + 1, ri, zh);
b[dq].ls = b[pre].ls;
}
b[dq].zh = b[b[dq].ls].zh + b[b[dq].rs].zh;
}
int cx(int l1, int l2, int r1, int r2, int k, int le, int ri)
{
if (le == ri)
return scost[le];
int mi = (le + ri) >> 1, c = (b[b[r2].ls].zh + b[b[r1].ls].zh - b[b[l1].ls].zh - b[b[l2].ls].zh);
if (k <= c)
return cx(b[l1].ls, b[l2].ls, b[r1].ls, b[r2].ls, k, le, mi);
else
return cx(b[l1].rs, b[l2].rs, b[r1].rs, b[r2].rs, k - c, mi + 1, ri);
}
inline bool cmp(int x, int y)
{
return x < y;
}

/*
8 5
105 2 9 3 8 5 7 7
1 2
1 3
1 4
3 5
3 6
3 7
4 8
2 5 1
0 5 2
10 5 3
11 5 4
110 8 2
*/

坑…待填…

By 期末考试炸飞的 Cansult