学习笔记_ 点分治 [点分治, 树, 分治]

又学了一种什么题都是模板的算法, 深感药丸

沙茶的 点分治不详解

点分治, 可以用来找满足某种约束条件的路径的条数$or$经过最多$or$最少边的路径数等等

具体思想和实现起来就是计算和利用子树的计算来去重

我们考虑以当前节点为根的子树对答案的贡献:

  • 以当前节点为起点的路径组合的贡献
  • 在当前节点子树中的路径的贡献 (不经过当前节点的路径的贡献, 即忽略的路径 需要加上)
  • 以某棵以$x$节点为根的子树中不符合要求, 而重复经过路径$x \rightarrow dq \,\, \&\& \,\, dq \rightarrow x$后”符合要求”的路径 (即需要去重的路径)

这样就可以直接到子树中递归分治了

然而我们发现, 如果这样搞的话, 一些奇怪的数据可能把我们卡成$n ^ 2$, 然而你不是AH-mhr, 你没有高超的暴力技巧, 于是我们每次应该找树的重心, 每次删除重心并递归重心的子树, 这样一次递归毕然会使节点数量减少一半, 复杂度为$\mathrm O(n\lg n \cdot$每次递归操作的复杂度$)$

这玩意还是看代码比较得劲…

沙茶一不小心水过的 超SB模板题

poj1741 Tree

点分模板题, 没啥说的

  • 漏了i < jWA了好几发…
  • 不能只在init_size()之前memset一次maxside数组…耗时又大又会出问题…而是应该直接在每次递归到当前节点的时候就赋值为0

沙茶的 代码

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
#include <iostream>
#include <cstdio>
#include <cstring>
#include <cstdlib>
#include <algorithm>
#define MAXN (10000 + 5)
#define INF (0x7fffffff)
#define rev(a) ((((a) - 1) ^ 1) + 1)
#define rint register int
#define cint const int
using namespace std;
struct edg
{
int from, to, next, cost;
} b[MAXN << 1];
int g[MAXN], cntb, n, maxside[MAXN], size[MAXN], k, dis[MAXN];
bool vis[MAXN];
void adn(cint from, cint to, cint cost)
{
b[++cntb].next = g[from];
b[cntb].from = from;
b[cntb].to = to;
b[cntb].cost = cost;
g[from] = cntb;
}
void init();
int dfs(cint); // to solve the problem
int cal(cint, cint); // to calculate the number of road in the tree
void init_size(cint, cint); // to init the size of each node's son-tree
int dfs_root(cint, cint, cint); // to init the maxside of each node and return the root of the tree
int getroot(cint); // to work those faction
void dfs_dis(cint); // to init the distance from node to root
int main()
{
while (true)
{
scanf("%d%d", &n, &k);
if (!n && !k)
break;
init();
printf("%d\n", dfs(1));
}
return 0;
}
void init()
{
cntb = 0;
memset(g, 0, sizeof(g));
memset(vis, false, sizeof(vis));
for (rint i = 1; i < n; i++)
{
int srx, sry, srz;
scanf("%d%d%d", &srx, &sry, &srz);
adn(srx, sry, srz);
adn(sry, srx, srz);
}
}
int dfs(cint dq)
{
cint root = getroot(dq);
vis[root] = true;
int re = cal(root, 0);
for (rint i = g[root]; i; i = b[i].next)
if (!vis[b[i].to])
{
re -= cal(b[i].to, b[i].cost);
re += dfs(b[i].to);
}
return re;
}
inline int getroot(cint dq)
{
init_size(dq, dq);
return dfs_root(dq, dq, dq);
}
void init_size(cint dq, cint fa)
{
size[dq] = 1;
maxside[dq] = 0; // 这个地方直接赋为0就好了
for (rint i = g[dq]; i; i = b[i].next)
if (b[i].to != fa && !vis[b[i].to])
{
init_size(b[i].to, dq);
maxside[dq] = max(maxside[dq], size[b[i].to]);
size[dq] += size[b[i].to];
}
}
int dfs_root(cint dq, cint fa, cint root)
{
int minsonside = n, re = 0;
for (rint i = g[dq]; i; i = b[i].next)
if (b[i].to != fa && !vis[b[i].to])
{
rint t = dfs_root(b[i].to, dq, root);
if (maxside[t] < minsonside)
minsonside = maxside[t], re = t;
}
maxside[dq] = max(maxside[dq], size[root] - size[dq]);
if (minsonside > maxside[dq])
re = dq;
return re;
}
void dfs_dis(cint dq, cint fa, cint dqdis)
{
dis[++dis[0]] = dqdis;
for (rint i = g[dq]; i; i = b[i].next)
if (!vis[b[i].to] && b[i].to != fa)
dfs_dis(b[i].to, dq, dqdis + b[i].cost);
}
inline int cal(cint dq, cint fir)
{
int re = 0;
dis[0] = 0;
dfs_dis(dq, dq, fir);
sort(dis + 1, dis + dis[0] + 1);
rint j = dis[0];
for (rint i = 1; i < j; i++)
{
while (dis[j] + dis[i] > k && i < j) --j; // i < j 不能漏了...
re += j - i;
}
return re;
}

/*
5 4
1 2 3
1 3 1
1 4 2
3 5 1
0 0
*/

聪聪可可

也是裸题吧…据说有dp做法就跟着写了一发, 我一开始胡的dp和点分治差不多, 然后一直WA…也不知道为啥…后来看远古神犇blog, 改成po姐的写法…过了…我感觉其实差不多啊, 有没有人愿意来看看啊qwqwqwqwq

细节就是一开始写点分治的时候忘了 如果两条膜3为0的路径加起来路径和膜3也是0…调了好半天才过的样例…

还有就是一条路径可以正反选两遍非常恶心…

沙茶的 代码

点分治写法

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
#include <iostream>
#include <cstdio>
#include <cstring>
#include <cstdlib>
#include <algorithm>
#define MAXN (20000 + 5)
#define LL long long
using namespace std;
struct edg
{
int from, to, next, cost;
}b[MAXN << 1];
int g[MAXN], cntb, dis[MAXN], maxside[MAXN], size[MAXN], n;
bool vis[MAXN];
void adn(int from, int to, int cost)
{
b[++cntb].next = g[from];
b[cntb].from = from;
b[cntb].to = to;
b[cntb].cost = cost;
g[from] = cntb;
}
int gcd(int a, int b)
{ return (b ? gcd(b, a % b) : a); }
void init_size(int, int);
int dfs_root(int, int, int);
int getroot(int);
void dfs_dis(int, int, int);
int dfs(int);
int cal(int, int);
int main()
{
// freopen("in.in", "r", stdin);
memset(vis, false, sizeof(vis));
scanf("%d", &n);
for (int i = 1; i < n; i++)
{
int srx, sry, srz;
scanf("%d%d%d", &srx, &sry, &srz);
adn(srx, sry, srz);
adn(sry, srx, srz);
}
int upside = dfs(1), downside = n * n, ta = gcd(upside, downside);
printf("%d/%d", upside / ta, downside / ta);
return 0;
}
int dfs(int dq)
{
int root = getroot(dq);
vis[root] = true;
int re = cal(root, 0);
for (int i = g[root]; i; i = b[i].next)
if (!vis[b[i].to])
{
re -= cal(b[i].to, b[i].cost);
re += dfs(b[i].to);
}
return re;
}
int getroot(int dq)
{
init_size(dq, dq);
return dfs_root(dq, dq, dq);
}
void init_size(int dq, int fa)
{
size[dq] = 1, maxside[dq] = 0;
for (int i = g[dq]; i; i = b[i].next)
if (b[i].to != fa && !vis[b[i].to])
{
init_size(b[i].to, dq);
size[dq] += size[b[i].to];
maxside[dq] = max(maxside[dq], size[b[i].to]);
}
}
int dfs_root(int dq, int fa, int root)
{
int re, minsonside = n;
for (int i = g[dq]; i; i = b[i].next)
if (b[i].to != fa && !vis[b[i].to])
{
int t = dfs_root(b[i].to, dq, root);
if (maxside[t] < minsonside)
re = t, minsonside = maxside[t];
}
maxside[dq] = max(maxside[dq], size[root] - size[dq]);
if (maxside[dq] < minsonside)
re = dq;
return re;
}
void dfs_dis(int dq, int fa, int dqdis)
{
dis[++dis[0]] = dqdis;
for (int i = g[dq]; i; i = b[i].next)
if (b[i].to != fa && !vis[b[i].to])
dfs_dis(b[i].to, dq, dqdis + b[i].cost);
}
int cal(int dq, int fir)
{
dis[0] = 0;
dfs_dis(dq, dq, fir);
int num[3] = {0, 0, 0};
for (int i = 1; i <= dis[0]; i++)
{
dis[i] %= 3;
++num[dis[i]];
}
return ((num[1] * num[2]) << 1) + num[0] * num[0]; // 别忘了膜3为0的路径也可以组成合法路径
}

/*
5
1 2 1
1 3 2
1 4 1
2 5 3
*/

dp写法

AC

用dp写果然短了不少

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
#include <iostream>
#include <cstdio>
#include <cstring>
#include <cstdlib>
#include <algorithm>
#define MAXN (20000 + 5)
using namespace std;
struct edg
{
int from, to, next, cost;
}b[MAXN << 1];
int g[MAXN], cntb, dis[MAXN][3], fa[MAXN], n, ans; // f[i][x]: 当前在节点i, 到根的距离 % 3为x的方案数, f[i][1] = dis[dq][1] * f[son][0] - f[son][1 - cost] emmmm
void adn(int from, int to, int cost)
{
b[++cntb].next = g[from];
b[cntb].from = from;
b[cntb].to = to;
b[cntb].cost = cost;
g[from] = cntb;
}
// f[i][1] = dis[dq][1] * dis[dq][0] - f[son][1 - cost] + f[son][cost]
void dfs(int dq)
{
// dis[dq][i]: 在以dq节点为根的子树中, 有dis[dq][i]条以dq为起点的路径长度为i
dis[dq][1] = dis[dq][2] = 0;
dis[dq][0] = 1;
for (int i = g[dq]; i; i = b[i].next)
if (fa[dq] != b[i].to)
{
fa[b[i].to] = dq;
dfs(b[i].to);
for (int j = 0; j < 3; j++)
{ // j + x + b[i].cost = 0 -> x =
ans += dis[dq][j] * dis[b[i].to][(-((b[i].cost) % 3) - j + 6) % 3];
dis[dq][j] += dis[b[i].to][(j + 3 - (b[i].cost % 3)) % 3];
}
}
}
int gcd(int a, int b)
{ return (b ? gcd(b, a % b) : a); }
int main()
{
// freopen("in.in", "r", stdin);
scanf("%d", &n);
for (int i = 1; i < n; i++)
{
int srx, sry, srz;
scanf("%d%d%d", &srx, &sry, &srz);
adn(srx, sry, srz);
adn(sry, srx, srz);
}
fa[1] = 1;
dfs(1);
ans = ans * 2 + n;
int ta = gcd(ans, n * n);
printf("%d/%d", ans / ta, n * n / ta);
return 0;
}

/*
5
1 2 1
1 3 2
1 4 1
2 5 3
*/
WA

蜜汁WA了…qwqwqwqwqwq有人愿意看看嘛 /kel

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
#include <iostream>
#include <cstdio>
#include <cstring>
#include <cstdlib>
#include <algorithm>
#define MAXN (20000 + 5)
using namespace std;
struct edg
{
int from, to, next, cost;
}b[MAXN << 1];
int g[MAXN], cntb, f[MAXN][3], dis[MAXN][3], fa[MAXN], n; // f[i][x]: 当前在节点i, 到根的距离 % 3为x的方案数, f[i][1] = dis[dq][1] * f[son][0] - f[son][1 - cost] emmmm
void adn(int from, int to, int cost)
{
b[++cntb].next = g[from];
b[cntb].from = from;
b[cntb].to = to;
b[cntb].cost = cost;
g[from] = cntb;
}
// f[i][1] = dis[dq][1] * dis[dq][0] - f[son][1 - cost] + f[son][cost]
void dfs(int dq)
{ // f[dq][i]: 在以dq节点为根的子树中, 有f[dq][i]条路径长度为i
// dis[dq][i]: 在以dq节点为根的子树中, 有dis[dq][i]条以dq为起点的路径长度为i
dis[dq][1] = dis[dq][2] = 0;
dis[dq][0] = 1;
for (int i = g[dq]; i; i = b[i].next)
if (fa[dq] != b[i].to)
{
fa[b[i].to] = dq;
dfs(b[i].to);
for (int j = 0; j < 3; j++)
dis[dq][j] += dis[b[i].to][(j + 3 - (b[i].cost % 3)) % 3];
}
f[dq][1] = dis[dq][1] * dis[dq][0] + dis[dq][2] * (dis[dq][2] - 1) / 2;
f[dq][2] = dis[dq][2] * dis[dq][0] + dis[dq][1] * (dis[dq][1] - 1) / 2;
f[dq][0] = dis[dq][0] * (dis[dq][0] - 1) / 2 + dis[dq][1] * dis[dq][2];
for (int i = g[dq]; i; i = b[i].next)
if (fa[dq] != b[i].to)
for (int j = 0; j < 3; j++)
{
f[dq][j] -= f[b[i].to][(j + 3 - (b[i].cost % 3)) % 3];
f[dq][j] += f[b[i].to][j];
}
}
int gcd(int a, int b)
{ return (b ? gcd(b, a % b) : a); }
int main()
{
// freopen("in.in", "r", stdin);
scanf("%d", &n);
for (int i = 1; i < n; i++)
{
int srx, sry, srz;
scanf("%d%d%d", &srx, &sry, &srz);
adn(srx, sry, srz);
adn(sry, srx, srz);
}
fa[1] = 1;
dfs(1);
f[1][0] = f[1][0] * 2 + n;
int ta = gcd(f[1][0], n * n);
printf("%d/%d", f[1][0] / ta, n * n / ta);
return 0;
}

/*
5
1 2 1
1 3 2
1 4 1
2 5 3
*/

Race

沙茶的 代码

犯了一些sb错误…比如

  • 递归找根没有记录返回的根(?????)直接把b[i].to当成子树的根了(????)…
  • 计算maxside数组居然用size[root] - maxside[dq], 应该是size[root] - size[dq]
  • 计算cnt的时候把ij全写成了i
  • 在子树中计算去重的时候没有考虑到初始的时候来回已经经过了两条边, 也就是要从2开始计算边数

然后你就有95分辣! IOI的题95分是不是很激动?

一开始以为是哪里的细节出了问题, 静态查了半天 + 胡乱搞反例都没啥用, 最后还是看了po爷爷的blog…

还有一个问题…其实你的算法整个都是有问题的…

就是…可能会出现dis数组相等的问题, 所以不能直接移动j指针; 而是应该对于每一个i, 都备份一个j并移动, 对于下一个i从头开始判断, 复杂度…不是很懂…不过大概不是$\mathrm O(n)$ 吧…

最后….卡常让人怀疑人生…最后bzoj还是没过…不过某谷还是能过的…

又: 某谷上这个题…是黑的…被我用三个小号…刷成紫题了…Emmm… 你们加油把他搞成红的啊

卡常前

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
#include <iostream>
#include <cstdio>
#include <cstring>
#include <cstdlib>
#include <algorithm>
#define MAXN (2000000 + 5)
#define INF (0x7fffffff)
#define cint const int
#define rint register int
#define pii pair<int, int>
using namespace std;
struct edg
{
int from, to, next, cost;
}b[MAXN << 1];
int g[MAXN], cntb, n, k, ans = -1, totdis, maxside[MAXN], size[MAXN], cnt[MAXN], dis0;
pii dis[MAXN];
bool vis[MAXN];
inline void adn(int from, int to, int cost)
{
b[++cntb].next = g[from];
b[cntb].from = from;
b[cntb].to = to;
b[cntb].cost = cost;
g[from] = cntb;
}
inline int read()
{
rint re = 0, f = 1;
char x = 0;
while (x < '0' || x > '9')
if ((x = getchar()) == '-')
f = -1;
while (x <= '9' && x >= '0')
{
re = (re << 1) + (re << 3) + x - '0';
x = getchar();
}
return re * f;
}
inline void cal(cint, cint, cint, cint);
void dfs_dis(cint, cint, cint, cint);
inline int get_root(cint);
int dfs_root(cint, cint, cint);
void init_size(cint, cint);
void dfs(cint);
int main()
{
// freopen("ioi2011-race.out", "w", stdout);
// freopen("in.in", "r", stdin);
memset(vis, false, sizeof(vis));
// scanf("%d%d", &n, &k);
n = read(), k = read();
for (int i = 1; i < n; i++)
{
// int srx, sry, srz;
int srx = read(), sry = read(), srz = read();
// scanf("%d%d%d", &srx, &sry, &srz);
++srx, ++sry;
adn(srx, sry, srz);
adn(sry, srx, srz);
}
dfs(1);
for (int i = 0; i < n && ans == -1; i++)
if (cnt[i] > 0)
ans = i;
printf("%d", ans);
return 0;
}
void dfs(int dq)
{
int root = get_root(dq);
vis[root] = true;
cal(root, 0, 1, 0);
for (int i = g[root]; i; i = b[i].next)
if (!vis[b[i].to])
{
cal(b[i].to, b[i].cost, -1, 2); // 如果是以b[i].to为根的链...是不是要改wa为1啊... 不是
dfs(b[i].to);
}
}
inline int get_root(cint dq)
{
init_size(dq, dq);
return dfs_root(dq, dq, dq);
}
void init_size(cint dq, cint fa)
{
maxside[dq] = 0;
size[dq] = 1;
for (int i = g[dq]; i; i = b[i].next)
if (!vis[b[i].to] && b[i].to != fa)
{
init_size(b[i].to, dq);
maxside[dq] = max(maxside[dq], size[b[i].to]);
size[dq] += size[b[i].to];
}
}
int dfs_root(cint dq, cint fa, cint root)
{
int re, minmaxside = n;
for (int i = g[dq]; i; i = b[i].next)
if (!vis[b[i].to] && b[i].to != fa)
{
int t = dfs_root(b[i].to, dq, root); // 犯了奇奇怪怪的错误...这里没有记录t...gg
if (maxside[t] < minmaxside)
re = t, minmaxside = maxside[t];
}
maxside[dq] = max(maxside[dq], size[root] - size[dq]); // 这里减的是size...不是maxside...日益zz
if (maxside[dq] < minmaxside)
re = dq;
return re;
}
void dfs_dis(cint dq, cint fa, cint dqdis, cint way)
{
dis[++dis0] = make_pair(dqdis, way);
for (int i = g[dq]; i; i = b[i].next)
if (b[i].to != fa && !vis[b[i].to])
dfs_dis(b[i].to, dq, dqdis + b[i].cost, way + 1);
}
inline void cal(cint dq, cint fir, cint zh, cint wa) // wa: 记录初始有几条边
{
dis0 = 0;
dfs_dis(dq, dq, fir, 0);
sort(dis + 1, dis + dis0 + 1);
int j = dis0;
for (int i = 1; i < j; i++)
{
while (dis[i].first + dis[j].first > k && i < j) --j;
int last = j;
while (dis[i].first + dis[last].first == k && i < last)
{
cnt[dis[i].second + dis[last].second + wa] += zh; // 这个地方把i写成j了...
--last;
}
}
}

/*
4 3
0 1 1
1 2 2
1 3 4
*/

卡常后

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
#include <iostream>
#include <cstdio>
#include <cstring>
#include <cstdlib>
#include <algorithm>
#define MAXN (200000 + 5)
#define INF (0x7fffffff)
#define cint const int
#define rint register int
#define pii pair<int, int>
#define max(a, b) ((a) > (b) ? (a) : (b))
#pragma GCC optimize(3)
//#define getchar() (S == T && (T = (S = BB) + fread(BB, 1, MAXN << 5, stdin), S == T) ? EOF : *S++)
//char BB[MAXN * 20], *S = BB, *T = BB;
using namespace std;
int to[MAXN << 1], next[MAXN << 1], cost[MAXN << 1];
int g[MAXN], cntb, n, k, ans = -1, totdis, maxside[MAXN], size[MAXN], cnt[MAXN], dis0;
pii dis[MAXN];
bool vis[MAXN];
inline void adn(cint fro, cint t, cint co)
{
next[++cntb] = g[fro];
to[cntb] = t;
cost[cntb] = co;
g[fro] = cntb;
}
inline char gc(){
static char BUFF[1000000],*S=BUFF,*T=BUFF;
return S==T&&(T=(S=BUFF)+fread(BUFF,1,1000000,stdin),S==T)?EOF:*S++;
}
inline int read()
{
int x=0,c=1;
char ch=' ';
while((ch>'9'||ch<'0')&&ch!='-')ch=gc();
while(ch=='-')c*=-1,ch=gc();
while(ch<='9'&&ch>='0')x=x*10+ch-'0',ch=gc();
return x*c;
}
inline void cal(cint, cint, cint, cint);
void dfs_dis(cint, cint, cint, cint);
inline int get_root(cint);
int dfs_root(cint, cint, cint);
void init_size(cint, cint);
void dfs(cint);
int main()
{
// freopen("in.in", "r", stdin);
// scanf("%d%d", &n, &k);
n = read(), k = read();
for (rint i = 1; i < n; ++i)
{
// int srx, sry, srz;
rint srx = read(), sry = read(), srz = read();
// scanf("%d%d%d", &srx, &sry, &srz);
++srx, ++sry;
adn(srx, sry, srz);
adn(sry, srx, srz);
}
dfs(1);
for (rint i = 0; i < n && ans == -1; ++i)
if (cnt[i] > 0)
ans = i;
printf("%d", ans);
return 0;
}
void dfs(cint dq)
{
cint root = get_root(dq);
vis[root] = true;
cal(root, 0, 1, 0);
for (rint i = g[root]; i; i = next[i])
if (!vis[to[i]])
{
cal(to[i], cost[i], -1, 2); // 如果是以b[i].to为根的链...是不是要改wa为1啊... 然而不是
dfs(to[i]);
}
}
inline int get_root(cint dq)
{
init_size(dq, dq);
return dfs_root(dq, dq, dq);
}
void init_size(cint dq, cint fa)
{
maxside[dq] = 0;
size[dq] = 1;
for (rint i = g[dq]; i; i = next[i])
if (!vis[to[i]] && to[i] != fa)
{
init_size(to[i], dq);
maxside[dq] = max(maxside[dq], size[to[i]]);
size[dq] += size[to[i]];
}
}
int dfs_root(cint dq, cint fa, cint root)
{
rint re, minmaxside = n;
for (rint i = g[dq]; i; i = next[i])
if (!vis[to[i]] && to[i] != fa)
{
cint t = dfs_root(to[i], dq, root); // 犯了奇奇怪怪的错误...这里没有记录t...gg
if (maxside[t] < minmaxside)
re = t, minmaxside = maxside[t];
}
maxside[dq] = max(maxside[dq], size[root] - size[dq]); // 这里减的是size...不是maxside...日益zz
if (maxside[dq] < minmaxside)
re = dq;
return re;
}
void dfs_dis(cint dq, cint fa, cint dqdis, cint way)
{
dis[++dis0] = make_pair(dqdis, way);
for (rint i = g[dq]; i; i = next[i])
if (to[i] != fa && !vis[to[i]])
dfs_dis(to[i], dq, dqdis + cost[i], way + 1);
}
inline void cal(cint dq, cint fir, cint zh, cint wa) // wa: 记录初始有几条边
{
dis0 = 0;
dfs_dis(dq, dq, fir, 0);
sort(dis + 1, dis + dis0 + 1);
rint j = dis0;
for (rint i = 1; i < j; i++)
{
while (dis[i].first + dis[j].first > k && i < j) --j;
rint last = j; // 注意这个地方的lasst非常得劲, 你可能会出现两个数相等的情况草草草
while (dis[i].first + dis[last].first == k && i < last)
{
cnt[dis[i].second + dis[last].second + wa] += zh; // 这个地方把i写成j了...
--last;
}
}
}

/*
4 3
0 1 1
1 2 2
1 3 4
*/

感觉文化课彻底凉了…

By 吃枣药丸的 Cansult