알고리즘 문제 풀이/백준 문제 풀이

백준 1761 - 정점들의 거리

rejo 2024. 8. 16. 15:53

문제 링크 : https://www.acmicpc.net/problem/1761

 

문제

N(2 ≤ N ≤ 40,000)개의 정점으로 이루어진 트리가 주어지고 M(1 ≤ M ≤ 10,000)개의 두 노드 쌍을 입력받을 때 두 노드 사이의 거리를 출력하라.

입력

첫째 줄에 노드의 개수 N이 입력되고 다음 N-1개의 줄에 트리 상에 연결된 두 점과 거리를 입력받는다. 그 다음 줄에 M이 주어지고, 다음 M개의 줄에 거리를 알고 싶은 노드 쌍이 한 줄에 한 쌍씩 입력된다. 두 점 사이의 거리는 10,000보다 작거나 같은 자연수이다.

정점은 1번부터 N번까지 번호가 매겨져 있다.

출력

M개의 줄에 차례대로 입력받은 두 노드 사이의 거리를 출력한다.

풀이 과정

https://readytojoin.tistory.com/entry/LCA-%EC%B5%9C%EC%86%8C-%EA%B3%B5%ED%86%B5-%EC%A1%B0%EC%83%81-%EC%B0%BE%EA%B8%B0?category=1193467

 

LCA - 최소 공통 조상 찾기

최소 공통 조상 (LCA, Lowest Common Ancestor) 알고리즘은 트리 구조 안에서 특정 정점 두 개의 공통 조상 정점 중 가장 가까운 공통 조상 정점을 찾는 알고리즘이다. 단순히 두 정점의 깊이를 맞춘 뒤

readytojoin.tistory.com

LCA 알고리즘을 통해 최소 공통 조상을 찾아가면서 건너간 간선들의 비용을 센다. 문제에서 루트가 따로 주어지지 않아 루트를 구했는데 임의의 루트로 두었어도 괜찮았을 것 같다. 희소 배열을 이용한 LCA 를 사용한다면. 희소 배열에 정점과 함께 2^i 만큼 이동했을 때의 비용도 누적 합산하여 저장한다.

# 1761. 정점들의 거리 ( PLATINUM 5 )
import sys
input = sys.stdin.readline
sys.setrecursionlimit(50000)

n = int(input().rstrip())
graph = [[] for _ in range(n+1)]
for _ in range(n-1):
  a, b, d = map(int, input().rstrip().split())
  graph[a].append([b, d])
  graph[b].append([a, d])

depth = [0 for _ in range(n+1)]
visited = [0 for _ in range(n+1)]
parent = [[-1, 0] for _ in range(n+1)]

def dfs(now, p, dep, pre_dist):
  global depth
  global visited
  global parent

  depth[now] = dep
  
  for next, dist in graph[now]:
    if visited[next] == 0:
      visited[next] = 1
      dfs(next, now, dep + 1, dist)

  parent[now] = [p, pre_dist]

for i in range(1, n+1):
  if visited[i] == 0:
    visited[i] = 1
    dfs(i, -1, 0, 0)

root = -1
for i in range(1, n+1):
  if parent[i][0] == -1:
    root = i
    break

table = [[[0, 0] for _ in range(n+1)] for _ in range(21)]
for i in range(1, n+1): table[0][i] = [parent[i][0], parent[i][1]]

for i in range(1, 21):
  for j in range(1, n+1):
    table[i][j] = [table[i-1][table[i-1][j][0]][0], table[i-1][table[i-1][j][0]][1] + table[i-1][j][1]]

def lca(a, b):
  if depth[a] < depth[b]:
    a, b = b, a

  res = 0
  for i in range(20, -1, -1):
    if depth[a] - depth[b] >= (1 << i):
      a, cost = table[i][a]
      res += cost

  if a == b:
    return res
  else:
    for i in range(20, -1, -1):
      if table[i][a][0] != table[i][b][0]:
        a, cost1 = table[i][a]
        b, cost2 = table[i][b]
        res += cost1 + cost2

    res += table[0][a][1] + table[0][b][1]
    return res

m = int(input().rstrip())
for _ in range(m):
  a, b = map(int, input().rstrip().split())
  print(lca(a, b))