原题链接
思路
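预处理各节点到根（如节点 1）路径上权值的异或值 pre[u]。由于根到两点最近公共祖先（LCA）的公共部分在异或时相互抵消，u、v 之间的路径异或值即为 pre[u] ^ pre[v]（注意：权值在节点上时，LCA 自身的权值也会被抵消，需按题意判断是否要再异或上它），从而转化为数组中两数异或最大值问题。使用字典树（01-Trie）逐位确定最大异或结果：按二进制从高位到低位插入各数，查询时每一位优先走与当前位相反的分支。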
std
class TrieNode:
    """One node of a binary (0/1) trie over the bits of an integer.

    ``children[b]`` is the subtree for the next bit being ``b`` (0 or 1), or
    ``None`` when no stored value continues that way.
    """

    # Every inserted 32-bit value can add up to 32 nodes, so a trie over 1e5
    # values holds millions of them; dropping the per-instance __dict__
    # noticeably reduces memory per node.
    __slots__ = ("children",)

    def __init__(self):
        self.children = [None, None]
def insert(root, num):
    """Store the low 32 bits of *num* in the binary trie rooted at *root*.

    Bits are consumed from bit 31 down to bit 0. Missing nodes along the path
    are created on demand, and existing ones are reused, so inserting the
    same value twice leaves the trie unchanged.
    """
    cur = root
    for shift in reversed(range(32)):
        b = (num >> shift) & 1
        nxt = cur.children[b]
        if nxt is None:
            nxt = cur.children[b] = TrieNode()
        cur = nxt
def query(root, num):
    """Return the largest XOR of *num* with a value stored under *root*.

    Only the low 32 bits are considered. From bit 31 down to bit 0, the walk
    greedily takes the child whose bit differs from *num*'s bit (setting that
    bit of the result); otherwise it follows the matching-bit child. If neither
    child exists the walk stays where it is, so an empty trie yields 0.
    """
    result = 0
    cur = root
    for shift in reversed(range(32)):
        want = ((num >> shift) & 1) ^ 1
        opposite = cur.children[want]
        if opposite is not None:
            result |= 1 << shift
            cur = opposite
            continue
        same = cur.children[want ^ 1]
        if same is not None:
            cur = same
    return result
from collections import defaultdict

# Input: n, then the n node weights, then n-1 undirected edges "u v".
n = int(input())
w = list(map(int, input().split()))
tree = defaultdict(list)
for _ in range(n - 1):
    a, b = map(int, input().split())
    tree[a].append(b)
    tree[b].append(a)

# xor[u] = XOR of the weights of every node on the path from node 1 to u,
# both ends included. Filled by an explicit-stack DFS so deep trees cannot
# hit Python's recursion limit.
# NOTE(review): for two nodes u, v, xor[u] ^ xor[v] cancels everything from
# node 1 down to their lowest common ancestor, *including* that ancestor's own
# weight. Confirm the problem's path XOR is meant to exclude it (or that the
# weights really belong to edges); otherwise it must be XORed back in.
xor = [0] * (n + 1)
seen = [False] * (n + 1)
pending = [(1, 0)]  # (node, XOR of the path strictly above it)
while pending:
    cur, above = pending.pop()
    if seen[cur]:
        continue
    seen[cur] = True
    xor[cur] = above ^ w[cur - 1]
    pending.extend((nxt, xor[cur]) for nxt in tree[cur] if not seen[nxt])

# Max pairwise XOR: insert each prefix, then query it against everything
# inserted so far (itself included, which only contributes 0).
root = TrieNode()
max_val = 0
for prefix in xor[1:]:
    insert(root, prefix)
    best_here = query(root, prefix)
    if best_here > max_val:
        max_val = best_here
print(max_val)
【AI 提供】数据生成代码
import random

# Emit one random test input on stdout: n, the n node weights (31-bit), and
# the n-1 edges of a random tree on nodes 1..n.
n = 10**5
print(n)
w = [random.randint(0, 2**31 - 1) for _ in range(n)]
print(' '.join(map(str, w)))
# Build the tree structure: node i (2..n) hangs off a uniformly random earlier
# node. Such random recursive trees are shallow (expected depth ~ ln n).
edges = []
for i in range(2, n + 1):
    p = random.randint(1, i - 1)
    # Randomise endpoint order so solutions cannot rely on "the first
    # endpoint is the parent" (previously it always was).
    edges.append((p, i) if random.random() < 0.5 else (i, p))
random.shuffle(edges)
# There are exactly n-1 edges, so every one is printed.
for u, v in edges:
    print(u, v)
【实际使用】数据生成代码
import random
from collections import defaultdict
class TrieNode:
    """One node of a binary (0/1) trie over the bits of an integer.

    ``children[b]`` is the subtree for the next bit being ``b`` (0 or 1), or
    ``None`` when no stored value continues that way.
    """

    # Every inserted 32-bit value can add up to 32 nodes, so a trie over 1e5
    # values holds millions of them; dropping the per-instance __dict__
    # noticeably reduces memory per node.
    __slots__ = ("children",)

    def __init__(self):
        self.children = [None, None]
def insert(root, num):
    """Store the low 32 bits of *num* in the binary trie rooted at *root*.

    Bits are consumed from bit 31 down to bit 0. Missing nodes along the path
    are created on demand, and existing ones are reused, so inserting the
    same value twice leaves the trie unchanged.
    """
    cur = root
    for shift in reversed(range(32)):
        b = (num >> shift) & 1
        nxt = cur.children[b]
        if nxt is None:
            nxt = cur.children[b] = TrieNode()
        cur = nxt
def query(root, num):
    """Return the largest XOR of *num* with a value stored under *root*.

    Only the low 32 bits are considered. From bit 31 down to bit 0, the walk
    greedily takes the child whose bit differs from *num*'s bit (setting that
    bit of the result); otherwise it follows the matching-bit child. If neither
    child exists the walk stays where it is, so an empty trie yields 0.
    """
    result = 0
    cur = root
    for shift in reversed(range(32)):
        want = ((num >> shift) & 1) ^ 1
        opposite = cur.children[want]
        if opposite is not None:
            result |= 1 << shift
            cur = opposite
            continue
        same = cur.children[want ^ 1]
        if same is not None:
            cur = same
    return result
def _random_tree_edges(n):
    """Return the n-1 edges of a random tree on nodes 1..n in shuffled order.

    Node i (2..n) is attached to a uniformly random earlier node, and each
    edge's two endpoints are emitted in random order.
    NOTE: such random recursive trees are shallow (expected depth ~ ln n);
    add chain-shaped cases if deep-recursion solutions must be rejected.
    """
    edges = []
    for i in range(2, n + 1):
        p = random.randint(1, i - 1)
        # Randomise endpoint order so solutions cannot assume "u is the parent of v".
        edges.append((p, i) if random.random() < 0.5 else (i, p))
    random.shuffle(edges)
    return edges


def _reference_answer(n, w, edges):
    """Compute the expected output for one case with the std algorithm.

    xor[u] is the XOR of node weights on the path from node 1 to u (both ends
    included); the answer is the maximum of xor[u] ^ xor[v] over all node
    pairs, found with the binary trie (TrieNode / insert / query).
    """
    tree = defaultdict(list)
    for u, v in edges:
        tree[u].append(v)
        tree[v].append(u)
    xor = [0] * (n + 1)
    visited = [False] * (n + 1)
    stack = [(1, 0)]  # explicit-stack DFS: no recursion-depth limit
    while stack:
        u, val = stack.pop()
        if visited[u]:
            continue
        visited[u] = True
        xor[u] = val ^ w[u - 1]
        for v in tree[u]:
            if not visited[v]:
                stack.append((v, xor[u]))
    root = TrieNode()
    max_val = 0
    for num in xor[1:]:
        insert(root, num)
        max_val = max(max_val, query(root, num))
    return max_val


for case in range(1, 11):
    n = random.randint(1, 10**5)
    w = [random.randint(0, 2**31 - 1) for _ in range(n)]
    edges = _random_tree_edges(n)
    # Compute the answer before opening any file, so a crash while solving
    # cannot leave a complete c{case}.in next to an empty c{case}.out.
    answer = _reference_answer(n, w, edges)
    with open(f'c{case}.in', 'w') as f:
        print(n, file=f)
        print(' '.join(map(str, w)), file=f)
        for u, v in edges:
            print(u, v, file=f)
    with open(f'c{case}.out', 'w') as f:
        print(answer, file=f)