题目地址:
https://www.acwing.com/problem/content/254/
给定一个有
N
N
N个点(编号
0
,
1
,
…
,
N
?
1
0,1,…,N?1
0,1,…,N?1)的树,每条边都有一个权值(不超过
1000
1000
1000)。树上两个节点
x
x
x与
y
y
y之间的路径长度就是路径上各条边的权值之和。求长度不超过
K
K
K的路径有多少条。
输入格式: 输入包含多组测试用例。 每组测试用例的第一行包含两个整数
N
N
N和
K
K
K。 接下来
N
?
1
N?1
N?1行,每行包含三个整数
u
,
v
,
l
u,v,l
u,v,l,表示节点
u
u
u与
v
v
v之间存在一条边,且边的权值为
l
l
l。 当输入用例
N
=
0
,
K
=
0
N=0,K=0
N=0,K=0时,表示输入终止,且该用例无需处理。
输出格式: 每个测试用例输出一个结果。每个结果占一行。
数据范围:
1
≤
N
≤
1
0
4
1≤N≤10^4
1≤N≤104
1
≤
K
≤
5
×
1
0
6
1≤K≤5×10^6
1≤K≤5×106
0
≤
l
≤
1
0
3
0≤l≤10^3
0≤l≤103
思路是分治。假设一开始我们随便选择一个顶点
u
u
u为树根,那么所有的路径可以分为三类: 1、完全在
u
u
u的某一棵子树里,这个可以递归求解; 2、路径的两个端点一个在子树里,另一个是树根,这个可以通过DFS求解每个点到树根的路径来解决; 3、路径的两个端点在不同子树里,这个可以先DFS求解每个点到树根的距离,然后在这个距离数组里求解和小于等于
K
K
K的数对有多少个。但是这个会将两个顶点在同一棵子树里的非法情况包括进去,我们可以对每个子树再求一次和小于等于
K
K
K的数对,然后将这个情况扣除即可。在一个数组里求和小于等于
K
K
K的数对个数,可以先排序,然后用双指针。
分治的每一层的时间大概是
O
(
n
log
?
n
)
O(n\log n)
O(nlogn),如果递归层数过多,算法会退化到
O
(
n
2
log
?
n
)
O(n^2\log n)
O(n2logn),所以我们为了递归层数尽量少,需要使得每次选树根的时候,子树大小尽量均衡,于是我们想到可以每次选重心作为树根(当然我们只需要该点的最大子树节点个数小于等于总点数一半即可,这样每次递归下去一层,最大子树节点个数就减半,从而达到递归
O
(
log
?
n
)
O(\log n)
O(logn)层的效果,不一定非要选择重心。不妨就将满足条件的点统称为“重心”)。求解树根,可以先DFS求树的节点个数,然后再DFS暴力枚举解决。
代码如下:
#include <iostream>
#include <cstring>
#include <algorithm>
using namespace std;
const int N = 1e4 + 10, M = N << 1;
int n, m;
int h[N], e[M], w[M], ne[M], idx;
int sz[N];
bool vis[N];
int p[N];
void add(int a, int b, int c) {
e[idx] = b, ne[idx] = h[a], w[idx] = c, h[a] = idx++;
}
int get_sz(int u, int from) {
if (vis[u]) return 0;
sz[u] = 1;
for (int i = h[u]; ~i; i = ne[i]) {
int v = e[i];
if (v != from) sz[u] += get_sz(v, u);
}
return sz[u];
}
bool get_wc(int u, int from, int tot, int &wc) {
int ms = 0;
for (int i = h[u]; ~i; i = ne[i]) {
int v = e[i];
if (v == from || vis[v]) continue;
if (get_wc(v, u, tot, wc)) return true;
ms = max(ms, sz[v]);
}
ms = max(ms, tot - sz[u]);
if (ms <= tot / 2) {
wc = u;
return true;
}
return false;
}
int get_dist(int u, int from, int dist, int &pt) {
if (vis[u]) return 0;
int cnt = 1;
p[pt++] = dist;
for (int i = h[u]; ~i; i = ne[i]) {
int v = e[i];
if (v != from) cnt += get_dist(v, u, dist + w[i], pt);
}
return cnt;
}
int calc(int l, int r) {
sort(p + l, p + r + 1);
int res = 0;
for (int i = l, j = r; i < j;)
if (p[i] + p[j] <= m) res += j - i, i++;
else j--;
return res;
}
int dfs(int u) {
if (vis[u]) return 0;
get_wc(u, -1, get_sz(u, -1), u);
vis[u] = true;
int res = 0, pt = 0;
for (int i = h[u]; ~i; i = ne[i]) {
int v = e[i];
int cnt = get_dist(v, -1, w[i], pt);
if (!cnt) continue;
res -= calc(pt - cnt, pt - 1);
int l = pt - cnt, r = pt - 1;
while (l < r) {
int mid = l + (r - l + 1 >> 1);
if (p[mid] <= m) l = mid;
else r = mid - 1;
}
if (p[l] <= m) res += l - (pt - cnt) + 1;
}
res += calc(0, pt - 1);
for (int i = h[u]; ~i; i = ne[i]) res += dfs(e[i]);
return res;
}
int main() {
while (scanf("%d%d", &n, &m), n || m) {
memset(h, -1, sizeof h);
memset(vis, 0, sizeof vis);
idx = 0;
for (int i = 1; i <= n - 1; i++) {
int a, b, c;
scanf("%d%d%d", &a, &b, &c);
a++, b++;
add(a, b, c), add(b, a, c);
}
printf("%d\n", dfs(1));
}
}
每组数据时间复杂度
O
(
n
log
?
2
n
)
O(n\log^2n)
O(nlog2n),空间
O
(
n
)
O(n)
O(n)。
每次选择重心的话,一共递归
O
(
log
?
n
)
O(\log n)
O(logn)层,每层时间
O
(
n
log
?
(
n
/
2
k
)
)
O(n\log (n/2^k))
O(nlog(n/2k)),
k
k
k是层编号,总共
O
(
n
log
?
2
n
)
O(n\log^2n)
O(nlog2n)。
|