문제 링크 : https://www.acmicpc.net/problem/13511
문제
N개의 정점으로 이루어진 트리(무방향 사이클이 없는 연결 그래프)가 있다. 정점은 1번부터 N번까지 번호가 매겨져 있고, 간선은 1번부터 N-1번까지 번호가 매겨져 있다.
아래의 두 쿼리를 수행하는 프로그램을 작성하시오.
- 1 u v: u에서 v로 가는 경로의 비용을 출력한다.
- 2 u v k: u에서 v로 가는 경로에 존재하는 정점 중에서 k번째 정점을 출력한다. k는 u에서 v로 가는 경로에 포함된 정점의 수보다 작거나 같다.
입력
첫째 줄에 N (2 ≤ N ≤ 100,000)이 주어진다.
둘째 줄부터 N-1개의 줄에는 i번 간선이 연결하는 두 정점 번호 u와 v와 간선의 비용 w가 주어진다.
다음 줄에는 쿼리의 개수 M (1 ≤ M ≤ 100,000)이 주어진다.
다음 M개의 줄에는 쿼리가 한 줄에 하나씩 주어진다.
간선의 비용은 항상 1,000,000보다 작거나 같은 자연수이다.
출력
각각의 쿼리의 결과를 순서대로 한 줄에 하나씩 출력한다.
풀이 과정
단순 DFS 보다는 LCA + 희소 배열을 활용해 빠르게 쿼리에 대한 답을 구한다.
1번 쿼리는 누적 비용을 희소 배열에 저장하면 되니 간단히 구할 수 있다. 1761번에서 다루었다.
그러나 2번 쿼리는 까다로운 처리가 필요하다. u 에서 최소 공통 조상까지의 거리, v 에서 최소 공통 조상까지의 거리를 따로 구해 활용했다.
- 만약 u 에서 LCA 까지의 거리보다 k가 작다면, u -> LCA 경로 사이에 k 번째 정점이 있다는 것이므로, u 노드에서 k 번 부모 방향으로 이동했을 때의 정점을 찾는다.
- 만약 u 에서 LCA 까지의 거리보다 k가 크다면, u -> LCA 를 넘어 LCA -> v 경로 사이에 k 번째 정점이 있다는 것이므로, v 노드에서 (v 에서 LCA 까지의 거리) - (k - (u 에서 LCA 까지의 거리)) 번 부모 방향으로 이동했을 때의 정점을 찾는다.
- 만약 u 에서 LCA 까지의 거리가 k와 같다면 최소 공통 조상이라는 뜻이므로 LCA 노드 번호를 출력한다.
# 13511. 트리와 쿼리 2 ( PLATINUM 3 )
import sys
input = sys.stdin.readline
sys.setrecursionlimit(110000)
n = int(input().rstrip())
graph = [[] for _ in range(n+1)]
for _ in range(n-1):
a, b, d = map(int, input().rstrip().split())
graph[a].append([b, d])
graph[b].append([a, d])
depth = [0 for _ in range(n+1)]
visited = [0 for _ in range(n+1)]
parent = [[-1, 0] for _ in range(n+1)]
def dfs(now, p, dep, pre_dist):
global depth
global visited
global parent
depth[now] = dep
for next, dist in graph[now]:
if visited[next] == 0:
visited[next] = 1
dfs(next, now, dep + 1, dist)
parent[now] = [p, pre_dist]
for i in range(1, n+1):
if visited[i] == 0:
visited[i] = 1
dfs(i, -1, 0, 0)
table = [[[0, 0] for _ in range(n+1)] for _ in range(21)]
for i in range(1, n+1): table[0][i] = [parent[i][0], parent[i][1]]
for i in range(1, 21):
for j in range(1, n+1):
table[i][j] = [table[i-1][table[i-1][j][0]][0], table[i-1][table[i-1][j][0]][1] + table[i-1][j][1]]
def lca(a, b, mode, k):
if mode == 1:
if depth[a] < depth[b]:
a, b = b, a
res = 0
for i in range(20, -1, -1):
if depth[a] - depth[b] >= (1 << i):
a, cost = table[i][a]
res += cost
if a == b: return res
else:
for i in range(20, -1, -1):
if table[i][a][0] != table[i][b][0]:
a, cost1 = table[i][a]
b, cost2 = table[i][b]
res += cost1 + cost2
res += table[0][a][1] + table[0][b][1]
return res
else:
# a -> LCA 개수와 b -> LCA 개수 구하기
# 만약 a -> LCA 개수보다 k가 크다면 b에서 (b -> LCA) - (k - (a -> LCA)) 만큼 간 노드
# 만약 a -> LCA 개수가 k와 같다면 LCA 노드 출력
# 만약 a -> LCA 개수보다 k가 작다면 a에서 k 만큼 간 노드 출력
original_a = int(a)
original_b = int(b)
a_cnt = 0
b_cnt = 0
if depth[a] > depth[b]:
for i in range(20, -1, -1):
if depth[a] - depth[b] >= (1 << i):
a, _ = table[i][a]
a_cnt += 1 << i
elif depth[a] < depth[b]:
for i in range(20, -1, -1):
if depth[b] - depth[a] >= (1 << i):
b, _ = table[i][b]
b_cnt += 1 << i
if a != b:
for i in range(20, -1, -1):
if table[i][a][0] != table[i][b][0]:
a, _ = table[i][a]
b, _ = table[i][b]
a_cnt += 1 << i
b_cnt += 1 << i
a, _ = table[0][a]
b, _ = table[0][b]
a_cnt += 1
b_cnt += 1
#print('--', a_cnt, b_cnt, 'LCA', a)
k -= 1
if a_cnt == k: return a
elif a_cnt < k:
left = b_cnt - (k - a_cnt)
b = int(original_b)
for i in range(20, -1, -1):
if left & (1 << i):
b, _ = table[i][b]
return b
else:
left = k
a = int(original_a)
#print('original', a, '->', k)
for i in range(20, -1, -1):
if left & (1 << i):
#print(left, i, a, '->', table[i][a][0])
a, _ = table[i][a]
return a
m = int(input().rstrip())
for _ in range(m):
order = list(map(int, input().rstrip().split()))
if order[0] == 1:
print(lca(order[1], order[2], 1, -1))
else:
print(lca(order[1], order[2], 2, order[3]))
'알고리즘 문제 풀이 > 백준 문제 풀이' 카테고리의 다른 글
백준 5214 - 환승 (0) | 2024.11.08 |
---|---|
백준 24520 - Meet In The Middle (0) | 2024.08.16 |
백준 3176 - 도로 네트워크 (0) | 2024.08.16 |
백준 1761 - 정점들의 거리 (0) | 2024.08.16 |
백준 17435 - 합성함수와 쿼리 (0) | 2024.08.16 |