树形DP
1. 树形DP
定义
2. AcWing上的树形DP题目
AcWing 285. 没有上司的舞会
问题描述
分析
代码
#include <iostream>
#include <cstring>
using namespace std;
const int N = 6010;
int n;
int h[N], e[N], ne[N], idx;
int happy[N];
int f[N][2];
bool has_fa[N];
void add(int a, int b) {
e[idx] = b, ne[idx] = h[a], h[a] = idx++;
}
void dfs(int u) {
f[u][1] = happy[u];
for (int i = h[u]; ~i; i = ne[i]) {
int j = e[i];
dfs(j);
f[u][1] += f[j][0];
f[u][0] += max(f[j][0], f[j][1]);
}
}
int main() {
scanf("%d", &n);
for (int i = 1; i <= n; i++) scanf("%d", &happy[i]);
memset(h, -1, sizeof h);
for (int i = 0; i < n - 1; i++) {
int a, b;
scanf("%d%d", &a, &b);
add(b, a);
has_fa[a] = true;
}
int root = 1;
while (has_fa[root]) root++;
dfs(root);
printf("%d\n", max(f[root][0], f[root][1]));
return 0;
}
AcWing 1072. 树的最长路径
问题描述
分析
- 在递归求解的过程中记录每个点所有子树的路径和最大值和次大值(可能相等)即可,最长路径就是两者之和。
代码
#include <iostream>
#include <cstring>
using namespace std;
const int N = 10010, M = N * 2;
int n;
int h[N], e[M], w[M], ne[M], idx;
int ans;
void add(int a, int b, int c) {
e[idx] = b, w[idx] = c, ne[idx] = h[a], h[a] = idx++;
}
int dfs(int u, int father) {
int d1 = 0, d2 = 0;
for (int i = h[u]; ~i; i = ne[i]) {
int j = e[i];
if (j == father) continue;
int d = dfs(j, u) + w[i];
if (d >= d1) d2 = d1, d1 = d;
else if (d > d2) d2 = d;
}
ans = max(ans, d1 + d2);
return d1;
}
int main() {
cin >> n;
memset(h, -1, sizeof h);
for (int i = 0; i < n - 1; i++) {
int a, b, c;
cin >> a >> b >> c;
add(a, b, c), add(b, a, c);
}
dfs(1, -1);
cout << ans << endl;
return 0;
}
AcWing 1073. 树的中心
问题描述
分析
-
本题记录每个点到子节点的路径的最大值d1 和次大值d2 ,p1 记录每个d1 对应的路径,p2 记录d2 对应的路径。 -
up 记录从某个点向其父节点可以到达的最长路径,对于节点x ,up[x] 存在两种情况: (1)向父节点u 走取得最大值; (2)达到父节点u 后向下折返取得最大值(要求折返路线不能原路返回,即经过x );
- 正是因为不能原路返回,所以要存储每个节点到子节点的路径的次大值
d2 ,注意这里的次大值可能和最大值相同,因为向下可能存在多条路径长度相同的最大值。
代码
#include <iostream>
#include <cstring>
using namespace std;
const int N = 10010, M = N * 2;
int n;
int h[N], e[M], w[M], ne[M], idx;
int d1[N], d2[N], p1[N], p2[N], up[N];
void add(int a, int b, int c) {
e[idx] = b, w[idx] = c, ne[idx] = h[a], h[a] = idx++;
}
void dfs_d(int u, int father) {
for (int i = h[u]; ~i; i = ne[i]) {
int j = e[i];
if (j == father) continue;
dfs_d(j, u);
int d = d1[j] + w[i];
if (d >= d1[u]) {
d2[u] = d1[u], d1[u] = d;
p1[u] = j;
} else if (d > d2[u]) d2[u] = d;
}
}
void dfs_u(int u, int father) {
for (int i = h[u]; i != -1; i = ne[i]) {
int j = e[i];
if (j == father) continue;
if (p1[u] == j) up[j] = max(up[u], d2[u]) + w[i];
else up[j] = max(up[u], d1[u]) + w[i];
dfs_u(j, u);
}
}
int main() {
cin >> n;
memset(h, -1, sizeof h);
for (int i = 0; i < n - 1; i++) {
int a, b, c;
cin >> a >> b >> c;
add(a, b, c), add(b, a, c);
}
dfs_d(1, -1);
dfs_u(1, -1);
int res = 1e9;
for (int i = 1; i <= n; i++) res = min(res, max(d1[i], up[i]));
cout << res << endl;
return 0;
}
AcWing 1075. 数字转换
问题描述
分析
-
我们使用sum 记录每个数的约数之和,按照题意,如果sum[i]<i ,可以在sum[i] 和i 之间建立一条边,因为对于每个i ,最多建立一条边,且连接更小的数,因此我们可以得到一个森林。我们求出森林中每棵树的直径即得到本题的答案。 -
树的直径可以使用AcWing 1072. 树的最长路径的做法。 -
另外就是求解每个数的约数之和,常规做法是对于给定的数k ,按照试除法求出其所有约数,这样的时间复杂度是
O
(
n
×
n
)
O(n \times \sqrt n)
O(n×n
?)的。我们可以反过来考虑,考虑每个因子i 的倍数,这样的时间复杂度是
O
(
n
×
l
o
g
(
n
)
)
O(n \times log(n))
O(n×log(n))的(调和级数)。
代码
#include <iostream>
#include <cstring>
using namespace std;
const int N = 50010;
int n;
int h[N], e[N], ne[N], idx;
int sum[N];
bool st[N];
int ans;
void add(int a, int b) {
e[idx] = b, ne[idx] = h[a], h[a] = idx++;
}
int dfs(int u) {
st[u] = true;
int d1 = 0, d2 = 0;
for (int i = h[u]; ~i; i = ne[i]) {
int j = e[i];
if (!st[j]) {
int d = dfs(j) + 1;
if (d >= d1) d2 = d1, d1 = d;
else if (d > d2) d2 = d;
}
}
ans = max(ans, d1 + d2);
return d1;
}
int main() {
cin >> n;
for (int i = 1; i <= n; i++)
for (int j = 2; j <= n / i; j++)
sum[i * j] += i;
memset(h, -1, sizeof h);
for (int i = 1; i <= n; i++)
if (sum[i] < i)
add(sum[i], i);
for (int i = 1; i <= n; i++)
if (!st[i])
dfs(i);
cout << ans << endl;
return 0;
}
AcWing 1074. 二叉苹果树
问题描述
分析
-
本题是AcWing 10. 有依赖的背包问题问题的一个简化版,区别有两点:(1)本题子树要么没有,要么是两个;(2)本题是保留边,AcWing10 是保留点。 -
使用f[i][j] 表示:以i 为根且保留j 条边最多保留苹果数目。 -
dfs 的过程中,对于某个节点来说,如果其有两个子节点a、b ,则物品组有四种组合方式,我们可以依次考虑每棵子树需要保留多少条边,这样就不要枚举a、b 有几棵子树了(这样枚举是指数级别的)。 -
本题解法和AcWing10 是一样的,可以参考:背包问题(背包九讲)。
代码
#include <iostream>
#include <cstring>
using namespace std;
const int N = 110, M = N * 2;
int n, m;
int h[N], e[M], w[M], ne[M], idx;
int f[N][N];
void add(int a, int b, int c) {
e[idx] = b, w[idx] = c, ne[idx] = h[a], h[a] = idx++;
}
void dfs(int u, int father) {
for (int i = h[u]; ~i; i = ne[i]) {
if (e[i] == father) continue;
dfs(e[i], u);
for (int j = m; j; j--)
for (int k = 0; k + 1 <= j; k++)
f[u][j] = max(f[u][j], f[u][j - k - 1] + f[e[i]][k] + w[i]);
}
}
int main() {
cin >> n >> m;
memset(h, -1, sizeof h);
for (int i = 0; i < n - 1; i++) {
int a, b, c;
scanf("%d%d%d", &a, &b, &c);
add(a, b, c), add(b, a, c);
}
dfs(1, -1);
printf("%d\n", f[1][m]);
return 0;
}
AcWing 323. 战略游戏
问题描述
分析
-
本题相当于问每条边上至少选择一个点,最小的权值和是多少(这里点的权值都为1)?相当于是AcWing 285. 没有上司的舞会的一个对偶问题。对比如下: (1)没有上司的舞会:每条边上最多选一个点,让选出的点的权值最大。 (2)战略游戏:每条边上最少选择一个点,让选出的点的权值最小。 -
分析如下:
代码
#include <iostream>
#include <cstring>
using namespace std;
const int N = 1510, M = N * 2;
int n;
int h[N], e[M], ne[M], idx;
int f[N][2];
void add(int a, int b) {
e[idx] = b, ne[idx] = h[a], h[a] = idx++;
}
void dfs(int u, int father) {
f[u][0] = 0, f[u][1] = 1;
for (int i = h[u]; ~i; i = ne[i]) {
int j = e[i];
if (j == father) continue;
dfs(j, u);
f[u][0] += f[j][1];
f[u][1] += min(f[j][0], f[j][1]);
}
}
int main() {
while (cin >> n) {
memset(h, -1, sizeof h);
idx = 0;
for (int i = 0; i < n; i++) {
int id, cnt;
scanf("%d:(%d)", &id, &cnt);
while (cnt--) {
int ver;
scanf("%d", &ver);
add(id, ver), add(ver, id);
}
}
dfs(0, -1);
printf("%d\n", min(f[0][0], f[0][1]));
}
return 0;
}
AcWing 1077. 皇宫看守
问题描述
分析
f(i, 0): 点i被父节点看到的所有集合对应的最小花费
f(i, 1): 点i被子节点看到的所有集合对应的最小花费
f(i, 2): 在点i上放置警卫的所有摆放方案的最小花费
f
(
i
,
0
)
=
∑
m
i
n
(
f
(
j
,
1
)
,
f
(
j
,
2
)
)
f
(
i
,
2
)
=
∑
m
i
n
(
f
(
j
,
0
)
,
f
(
j
,
1
)
,
f
(
j
,
2
)
)
f
(
i
,
1
)
=
m
i
n
k
(
f
(
k
,
2
)
+
∑
j
≠
k
(
f
(
j
,
1
)
,
f
(
j
,
2
)
)
)
f(i, 0) = \sum min(f(j, 1), f(j, 2)) \\ f(i, 2) = \sum min(f(j, 0), f(j, 1), f(j, 2)) \\ f(i, 1) = \underset {k}{min} \Bigl(f(k, 2) + \underset {j \neq k}{\sum}(f(j, 1), f(j, 2)) \Bigr)
f(i,0)=∑min(f(j,1),f(j,2))f(i,2)=∑min(f(j,0),f(j,1),f(j,2))f(i,1)=kmin?(f(k,2)+j?=k∑?(f(j,1),f(j,2)))
代码
#include <iostream>
#include <cstring>
using namespace std;
const int N = 1510, M = N * 2;
int n;
int h[N], e[M], w[N], ne[M], idx;
int f[N][3];
void add(int a, int b) {
e[idx] = b, ne[idx] = h[a], h[a] = idx++;
}
void dfs(int u, int father) {
f[u][2] = w[u];
int sum = 0;
for (int i = h[u]; i != -1; i = ne[i]) {
int j = e[i];
if (j == father) continue;
dfs(j, u);
f[u][0] += min(f[j][1], f[j][2]);
f[u][2] += min(min(f[j][0], f[j][1]), f[j][2]);
sum += min(f[j][1], f[j][2]);
}
f[u][1] = 1e9;
for (int i = h[u]; ~i; i = ne[i]) {
int j = e[i];
f[u][1] = min(f[u][1], f[j][2] + sum - min(f[j][1], f[j][2]));
}
}
int main() {
cin >> n;
memset(h, -1, sizeof h);
for (int i = 0; i < n; i++) {
int id, cost, cnt;
cin >> id >> cost >> cnt;
w[id] = cost;
while (cnt--) {
int ver;
cin >> ver;
add(id, ver), add(ver, id);
}
}
dfs(1, -1);
cout << min(f[1][1], f[1][2]) << endl;
return 0;
}
3 LeetCode上一些树形DP问题
Leetcode 0124 二叉树中的最大路径和
题目描述:Leetcode 0124 二叉树中的最大路径和
分析
- 得到上述结果后,若令
left=max(0, f(u->left)) ,right=max(0, f(u->right)) ,则经过u 的所有路径中的最大的一个对应的路径和为:u->val+left+right 。
代码
class Solution {
public:
int ans;
int maxPathSum(TreeNode* root) {
ans = INT_MIN;
dfs(root);
return ans;
}
int dfs(TreeNode* u) {
if (!u) return 0;
int left = max(0, dfs(u->left)), right = max(0, dfs(u->right));
ans = max(ans, u->val + left + right);
return u->val + max(left, right);
}
};
class Solution {
int ans = Integer.MIN_VALUE;
public int maxPathSum(TreeNode root) {
dfs(root);
return ans;
}
private int dfs(TreeNode u) {
if (u == null) return 0;
int left = Math.max(0, dfs(u.left)), right = Math.max(0, dfs(u.right));
ans = Math.max(ans, u.val + left + right);
return u.val + Math.max(left, right);
}
}
时空复杂度分析
Leetcode 0310 最小高度树
题目描述:Leetcode 0310 最小高度树
分析
-
本题的考点:动态规划、树形DP。 -
本题和AcWing 1073. 树的中心一样。 -
本题记录每个点到子节点的路径的最大值d1 和次大值d2 ,p1 记录每个d1 对应的路径,p2 记录d2 对应的路径。 -
up 记录从某个点向其父节点可以到达的最长路径,对于节点x ,up[x] 存在两种情况: (1)向父节点u 走取得最大值; (2)达到父节点u 后向下折返取得最大值(要求折返路线不能原路返回,即经过x );
- 正是因为不能原路返回,所以要存储每个节点到子节点的路径的次大值
d2 ,注意这里的次大值可能和最大值相同,因为向下可能存在多条路径长度相同的最大值。
代码
class Solution {
public:
vector<vector<int>> g;
vector<int> d1, d2, p1, p2, up;
vector<int> findMinHeightTrees(int n, vector<vector<int>> &edges) {
g.resize(n);
d1 = d2 = p1 = p2 = up = vector<int>(n);
for (auto &e : edges) {
int a = e[0], b = e[1];
g[a].push_back(b), g[b].push_back(a);
}
dfs1(0, -1);
dfs2(0, -1);
int mind = n + 1;
for (int i = 0; i < n; i++) mind = min(mind, max(up[i], d1[i]));
vector<int> res;
for (int i = 0; i < n; i++)
if (max(up[i], d1[i]) == mind)
res.push_back(i);
return res;
}
void dfs1(int u, int father) {
for (int x : g[u]) {
if (x == father) continue;
dfs1(x, u);
int d = d1[x] + 1;
if (d >= d1[u]) {
d2[u] = d1[u], d1[u] = d;
p2[u] = p1[u], p1[u] = x;
} else if (d > d2[u]) {
d2[u] = d;
p2[u] = x;
}
}
}
void dfs2(int u, int father) {
for (int x : g[u]) {
if (x == father) continue;
if (p1[u] == x) up[x] = max(up[u], d2[u]) + 1;
else up[x] = max(up[u], d1[u]) + 1;
dfs2(x, u);
}
}
};
class Solution {
List<List<Integer>> g = new ArrayList<>();
int[] d1, d2, p1, p2, up;
public List<Integer> findMinHeightTrees(int n, int[][] edges) {
d1 = new int[n]; d2 = new int[n]; p1 = new int[n]; p2 = new int[n]; up = new int[n];
for (int i = 0; i < n; i++) g.add(new ArrayList<>());
for (int[] e : edges) {
int a = e[0], b = e[1];
g.get(a).add(b); g.get(b).add(a);
}
dfs1(0, -1);
dfs2(0, -1);
int mind = n + 1;
for (int i = 0; i < n; i++) mind = Math.min(mind, Math.max(d1[i], up[i]));
List<Integer> res = new ArrayList<>();
for (int i = 0; i < n; i++)
if (Math.max(d1[i], up[i]) == mind)
res.add(i);
return res;
}
private void dfs1(int u, int father) {
for (int x : g.get(u)) {
if (x == father) continue;
dfs1(x, u);
int d = d1[x] + 1;
if (d >= d1[u]) {
d2[u] = d1[u]; d1[u] = d;
p2[u] = p1[u]; p1[u] = x;
} else if (d > d2[u]) {
d2[u] = d;
p2[u] = x;
}
}
}
private void dfs2(int u, int father) {
for (int x : g.get(u)) {
if (x == father) continue;
if (p1[u] == x) up[x] = Math.max(up[u], d2[u]) + 1;
else up[x] = Math.max(up[u], d1[u]) + 1;
dfs2(x, u);
}
}
}
时空复杂度分析
Leetcode 0337 打家劫舍 III
题目描述:Leetcode 0337 打家劫舍 III
分析
代码
class Solution {
public:
int rob(TreeNode *root) {
auto f = dfs(root);
return max(f[0], f[1]);
}
vector<int> dfs(TreeNode *u) {
if (!u) return {0, 0};
auto x = dfs(u->left), y = dfs(u->right);
return {max(x[0], x[1]) + max(y[0], y[1]), x[0] + y[0] + u->val};
}
};
class Solution {
public int rob(TreeNode root) {
int[] f = dfs(root);
return Math.max(f[0], f[1]);
}
private int[] dfs(TreeNode root) {
if (root == null) return new int[]{0, 0};
int[] x = dfs(root.left), y = dfs(root.right);
return new int[]{Math.max(x[0], x[1]) + Math.max(y[0], y[1]), x[0] + y[0] + root.val};
}
}
时空复杂度分析
Leetcode 0968 监控二叉树
题目描述:Leetcode 0968 监控二叉树
分析
f(i, 0): 点i被父节点看到的所有集合对应的最小花费
f(i, 1): 点i被子节点看到的所有集合对应的最小花费
f(i, 2): 在点i上放置警卫的所有摆放方案的最小花费
f
(
i
,
0
)
=
∑
m
i
n
(
f
(
j
,
1
)
,
f
(
j
,
2
)
)
f
(
i
,
2
)
=
∑
m
i
n
(
f
(
j
,
0
)
,
f
(
j
,
1
)
,
f
(
j
,
2
)
)
f
(
i
,
1
)
=
m
i
n
k
(
f
(
k
,
2
)
+
∑
j
≠
k
(
f
(
j
,
1
)
,
f
(
j
,
2
)
)
)
f(i, 0) = \sum min(f(j, 1), f(j, 2)) \\ f(i, 2) = \sum min(f(j, 0), f(j, 1), f(j, 2)) \\ f(i, 1) = \underset {k}{min} \Bigl(f(k, 2) + \underset {j \neq k}{\sum}(f(j, 1), f(j, 2)) \Bigr)
f(i,0)=∑min(f(j,1),f(j,2))f(i,2)=∑min(f(j,0),f(j,1),f(j,2))f(i,1)=kmin?(f(k,2)+j?=k∑?(f(j,1),f(j,2)))
代码
class Solution {
public:
const int INF = 1e8;
vector<int> dfs(TreeNode* root) {
if (!root) return {0, 0, INF};
auto l = dfs(root->left), r = dfs(root->right);
return {
min(l[1], l[2]) + min(r[1], r[2]),
min(l[2] + min(r[1], r[2]), r[2] + min(l[1], l[2])),
min(l[0], min(l[1], l[2])) + min(r[0], min(r[1], r[2])) + 1,
};
}
int minCameraCover(TreeNode* root) {
auto f = dfs(root);
return min(f[1], f[2]);
}
};
class Solution {
static int INF = (int) (1e8);
public int minCameraCover(TreeNode root) {
int[] f = dfs(root);
return Math.min(f[1], f[2]);
}
int[] dfs(TreeNode root) {
if (root == null) return new int[]{0, 0, INF};
int[] l = dfs(root.left), r = dfs(root.right);
return new int[] {
Math.min(l[1], l[2]) + Math.min(r[1], r[2]),
Math.min(l[2] + Math.min(r[1], r[2]), r[2] + Math.min(l[1], l[2])),
Math.min(l[0], Math.min(l[1], l[2])) + Math.min(r[0], Math.min(r[1], r[2])) + 1,
};
}
}
class Solution:
def minCameraCover(self, root: TreeNode) -> int:
f = self.dfs(root)
return min(f[1], f[2])
def dfs(self, root):
if root is None:
return (0, 0, int(1e8))
l, r = self.dfs(root.left), self.dfs(root.right)
return (
min(l[1], l[2]) + min(r[1], r[2]),
min(l[2] + min(r[1], r[2]), r[2] + min(l[1], l[2])),
min(l[0], min(l[1], l[2])) + min(r[0], min(r[1], r[2])) + 1,
)
时空复杂度分析
|