洛谷原题链接 || Hydro 域链接。
算法分析
使用树形动态规划。对每个节点 u,维护两个状态:
dp[u][0]
:选中 u 时,子树的最大权值和。
dp[u][d]
(1 \leq d <k):未选中 u 时,子树中最近选中节点距 u 为 d 的最大权值和。
dp[u][k]
:未选中 u 且子树中无选中节点或最近选中节点距 u ≥k。
状态转移
- 选中 u:所有子节点 v 的最近选中节点距离必须 ≥k-1。
dp[u][0] = a_u + \sum_v \max_{d \geq k-1} dp[v][d]
- 未选中 u:对每个可能的父节点传递的距离 d_p,更新子节点状态:
- 若 d_p+1 <k,则 d_{child} = d_p+1。
- 否则,d_{child} =k。
转移时,合并所有子节点在不同距离下的最大值。
复杂度
时间复杂度 O(nk),空间复杂度 O(nk)。
代码实现
题外话:DeepSeek 在写代码的时候第十行末尾多打了一个右括号。以下代码没有这个问题。
import sys
from collections import defaultdict
sys.setrecursionlimit(1 << 25)
def main():
n, k = map(int, sys.stdin.readline().split())
a = list(map(int, sys.stdin.readline().split()))
tree = [[] for _ in range(n)]
for _ in range(n-1):
u, v = map(int, sys.stdin.readline().split())
tree[u-1].append(v-1)
tree[v-1].append(u-1)
# dp[u][d]: 在u的子树中,最近的选中节点距离u为d时的最大和
dp = [defaultdict(int) for _ in range(n)]
def dfs(u, parent):
# 初始化,选中自己
select = a[u]
not_select = defaultdict(int)
not_select[k] = 0 # 初始时,未选中且无选中节点
for v in tree[u]:
if v == parent:
continue
dfs(v, u)
# 处理选中u的情况:子节点的最近选中距离 >=k-1
max_sub = 0
for d in dp[v]:
if d >= k-1:
if dp[v][d] > max_sub:
max_sub = dp[v][d]
select += max_sub
# 处理未选中u的情况:合并子节点的状态
new_not_select = defaultdict(int)
for d_parent in not_select:
current_sum = not_select[d_parent]
max_v = 0
for d_v in dp[v]:
new_d = d_parent +1 if d_parent +1 <k else k
if d_v == new_d:
candidate = dp[v][d_v]
if candidate > max_v:
max_v = candidate
if max_v ==0:
# 子节点无法满足条件时,取所有可能的最大值
max_v = max(dp[v].values()) if dp[v] else 0
current_sum += max_v
key = min(d_parent +1, k)
if key not in new_not_select or current_sum > new_not_select[key]:
new_not_select[key] = current_sum
not_select = new_not_select
# 更新dp[u]
dp[u][0] = select
for d in not_select:
dp[u][d] = not_select[d]
dfs(0, -1)
# 根节点可以选中或未选中,取最大值
max_val = 0
for d in dp[0]:
if d ==0 or d >=k:
if dp[0][d] > max_val:
max_val = dp[0][d]
print(max_val)
if __name__ == "__main__":
main()
【AI 提供】数据生成代码
import sys
import random
from random import randint
def generate_tree(n):
edges = []
for i in range(1, n):
u = i
v = randint(0, i-1)
edges.append((u+1, v+1)) # 转换为1-based
return edges
def main():
n = randint(3, 10)
k = randint(1, 3)
a = [randint(1, 100) for _ in range(n)]
edges = generate_tree(n)
print(n, k)
print(' '.join(map(str, a)))
for u, v in edges:
print(u, v)
if __name__ == "__main__":
main()
【实际上】数据生成代码
import sys
from collections import defaultdict
import random
from random import randint
sys.setrecursionlimit(1 << 25)
def main(infile, outfile):
n, k = map(int, infile.readline().split())
a = list(map(int, infile.readline().split()))
tree = [[] for _ in range(n)]
for _ in range(n-1):
u, v = map(int, infile.readline().split())
tree[u-1].append(v-1)
tree[v-1].append(u-1)
# dp[u][d]: 在u的子树中,最近的选中节点距离u为d时的最大和
dp = [defaultdict(int) for _ in range(n)]
def dfs(u, parent):
# 初始化,选中自己
select = a[u]
not_select = defaultdict(int)
not_select[k] = 0 # 初始时,未选中且无选中节点
for v in tree[u]:
if v == parent:
continue
dfs(v, u)
# 处理选中u的情况:子节点的最近选中距离 >=k-1
max_sub = 0
for d in dp[v]:
if d >= k-1:
if dp[v][d] > max_sub:
max_sub = dp[v][d]
select += max_sub
# 处理未选中u的情况:合并子节点的状态
new_not_select = defaultdict(int)
for d_parent in not_select:
current_sum = not_select[d_parent]
max_v = 0
for d_v in dp[v]:
new_d = d_parent +1 if d_parent +1 <k else k
if d_v == new_d:
candidate = dp[v][d_v]
if candidate > max_v:
max_v = candidate
if max_v ==0:
# 子节点无法满足条件时,取所有可能的最大值
max_v = max(dp[v].values()) if dp[v] else 0
current_sum += max_v
key = min(d_parent +1, k)
if key not in new_not_select or current_sum > new_not_select[key]:
new_not_select[key] = current_sum
not_select = new_not_select
# 更新dp[u]
dp[u][0] = select
for d in not_select:
dp[u][d] = not_select[d]
dfs(0, -1)
# 根节点可以选中或未选中,取最大值
max_val = 0
for d in dp[0]:
if d ==0 or d >=k:
if dp[0][d] > max_val:
max_val = dp[0][d]
print(max_val, file=outfile)
infile.close()
outfile.close()
def generate_tree(n):
edges = []
for i in range(1, n):
u = i
v = randint(0, i-1)
edges.append((u+1, v+1)) # 转换为1-based
return edges
def gen(case):
i = open(f'd{case}.in', 'a')
k = randint(1, 3 if case == 1 else 10**3)
n = randint(k, 10 if case == 1 else 10**5)
a = [randint(1, 100) for _ in range(n)]
edges = generate_tree(n)
print(n, k, file=i)
print(' '.join(map(str, a)), file=i)
for u, v in edges:
print(u, v, file=i)
i.close()
main(open(f'd{case}.in', 'r'), open(f'd{case}.out', 'a'))
if __name__ == "__main__":
for case in range(1, 11):
gen(case)