알고리즘 문제 풀이/백준 문제 풀이

백준 13511 - 트리와 쿼리 2

rejo 2024. 8. 16. 16:02

문제 링크 : https://www.acmicpc.net/problem/13511

 

문제

N개의 정점으로 이루어진 트리(무방향 사이클이 없는 연결 그래프)가 있다. 정점은 1번부터 N번까지 번호가 매겨져 있고, 간선은 1번부터 N-1번까지 번호가 매겨져 있다.

아래의 두 쿼리를 수행하는 프로그램을 작성하시오.

  • 1 u v: u에서 v로 가는 경로의 비용을 출력한다.
  • 2 u v k: u에서 v로 가는 경로에 존재하는 정점 중에서 k번째 정점을 출력한다. k는 u에서 v로 가는 경로에 포함된 정점의 수보다 작거나 같다.

입력

첫째 줄에 N (2 ≤ N ≤ 100,000)이 주어진다.

둘째 줄부터 N-1개의 줄에는 i번 간선이 연결하는 두 정점 번호 u와 v와 간선의 비용 w가 주어진다.

다음 줄에는 쿼리의 개수 M (1 ≤ M ≤ 100,000)이 주어진다.

다음 M개의 줄에는 쿼리가 한 줄에 하나씩 주어진다.

간선의 비용은 항상 1,000,000보다 작거나 같은 자연수이다.

출력

각각의 쿼리의 결과를 순서대로 한 줄에 하나씩 출력한다.

풀이 과정

https://readytojoin.tistory.com/entry/LCA-%EC%B5%9C%EC%86%8C-%EA%B3%B5%ED%86%B5-%EC%A1%B0%EC%83%81-%EC%B0%BE%EA%B8%B0?category=1193467

 

LCA - 최소 공통 조상 찾기

최소 공통 조상 (LCA, Lowest Common Ancestor) 알고리즘은 트리 구조 안에서 특정 정점 두 개의 공통 조상 정점 중 가장 가까운 공통 조상 정점을 찾는 알고리즘이다. 단순히 두 정점의 깊이를 맞춘 뒤

readytojoin.tistory.com

단순 DFS 보다는 LCA + 희소 배열을 활용해 빠르게 쿼리에 대한 답을 구한다.

 

1번 쿼리는 누적 비용을 희소 배열에 저장하면 되니 간단히 구할 수 있다. 1761번에서 다루었다.

https://readytojoin.tistory.com/entry/%EB%B0%B1%EC%A4%80-1761-%EC%A0%95%EC%A0%90%EB%93%A4%EC%9D%98-%EA%B1%B0%EB%A6%AC

 

백준 1761 - 정점들의 거리

문제 링크 : https://www.acmicpc.net/problem/1761 문제N(2 ≤ N ≤ 40,000)개의 정점으로 이루어진 트리가 주어지고 M(1 ≤ M ≤ 10,000)개의 두 노드 쌍을 입력받을 때 두 노드 사이의 거리를 출력하라.입력첫째

readytojoin.tistory.com

 

그러나 2번 쿼리는 까다로운 처리가 필요하다. u 에서 최소 공통 조상까지의 거리, v 에서 최소 공통 조상까지의 거리를 따로 구해 활용했다.

 

- 만약 u 에서 LCA 까지의 거리보다 k가 작다면, u -> LCA 경로 사이에 k 번째 정점이 있다는 것이므로, u 노드에서 k 번 부모 방향으로 이동했을 때의 정점을 찾는다.
- 만약 u 에서 LCA 까지의 거리보다 k가 크다면, u -> LCA 를 넘어 LCA -> v 경로 사이에 k 번째 정점이 있다는 것이므로, v 노드에서 (v 에서 LCA 까지의 거리) - (k - (u 에서 LCA 까지의 거리)) 번 부모 방향으로 이동했을 때의 정점을 찾는다.
- 만약 u 에서 LCA 까지의 거리가 k와 같다면 최소 공통 조상이라는 뜻이므로 LCA 노드 번호를 출력한다.

 

# 13511. 트리와 쿼리 2 ( PLATINUM 3 )
import sys
input = sys.stdin.readline
sys.setrecursionlimit(110000)

n = int(input().rstrip())
graph = [[] for _ in range(n+1)]
for _ in range(n-1):
  a, b, d = map(int, input().rstrip().split())
  graph[a].append([b, d])
  graph[b].append([a, d])

depth = [0 for _ in range(n+1)]
visited = [0 for _ in range(n+1)]
parent = [[-1, 0] for _ in range(n+1)]

def dfs(now, p, dep, pre_dist):
  global depth
  global visited
  global parent

  depth[now] = dep

  for next, dist in graph[now]:
    if visited[next] == 0:
      visited[next] = 1
      dfs(next, now, dep + 1, dist)

  parent[now] = [p, pre_dist]

for i in range(1, n+1):
  if visited[i] == 0:
    visited[i] = 1
    dfs(i, -1, 0, 0)

table = [[[0, 0] for _ in range(n+1)] for _ in range(21)]
for i in range(1, n+1): table[0][i] = [parent[i][0], parent[i][1]]

for i in range(1, 21):
  for j in range(1, n+1):
    table[i][j] = [table[i-1][table[i-1][j][0]][0], table[i-1][table[i-1][j][0]][1] + table[i-1][j][1]]

def lca(a, b, mode, k):
  if mode == 1:
    if depth[a] < depth[b]:
      a, b = b, a
      
    res = 0
    for i in range(20, -1, -1):
      if depth[a] - depth[b] >= (1 << i):
        a, cost = table[i][a]
        res += cost
  
    if a == b: return res
    else:
      for i in range(20, -1, -1):
        if table[i][a][0] != table[i][b][0]:
          a, cost1 = table[i][a]
          b, cost2 = table[i][b]
          res += cost1 + cost2
  
      res += table[0][a][1] + table[0][b][1]
      return res
  else:
    # a -> LCA 개수와 b -> LCA 개수 구하기
    
    # 만약 a -> LCA 개수보다 k가 크다면 b에서 (b -> LCA) - (k - (a -> LCA)) 만큼 간 노드
    # 만약 a -> LCA 개수가 k와 같다면 LCA 노드 출력
    # 만약 a -> LCA 개수보다 k가 작다면 a에서 k 만큼 간 노드 출력

    original_a = int(a)
    original_b = int(b)
    
    a_cnt = 0
    b_cnt = 0

    if depth[a] > depth[b]:
      for i in range(20, -1, -1):
        if depth[a] - depth[b] >= (1 << i):
          a, _ = table[i][a]
          a_cnt += 1 << i
    elif depth[a] < depth[b]:
      for i in range(20, -1, -1):
        if depth[b] - depth[a] >= (1 << i):
          b, _ = table[i][b]
          b_cnt += 1 << i

    if a != b:
      for i in range(20, -1, -1):
        if table[i][a][0] != table[i][b][0]:
          a, _ = table[i][a]
          b, _ = table[i][b]
          a_cnt += 1 << i
          b_cnt += 1 << i

      a, _ = table[0][a]
      b, _ = table[0][b]
      a_cnt += 1
      b_cnt += 1 
    #print('--', a_cnt, b_cnt, 'LCA', a)
    k -= 1
    if a_cnt == k: return a
    elif a_cnt < k:
      left = b_cnt - (k - a_cnt)
      b = int(original_b)
      for i in range(20, -1, -1):
        if left & (1 << i):
          b, _ = table[i][b]
      return b
    else:
      left = k
      a = int(original_a)
      #print('original', a, '->', k)
      for i in range(20, -1, -1):
        if left & (1 << i):
          #print(left, i, a, '->', table[i][a][0])
          a, _ = table[i][a]
      return a

    
m = int(input().rstrip())
for _ in range(m):
  order = list(map(int, input().rstrip().split()))
  if order[0] == 1:
    print(lca(order[1], order[2], 1, -1))
  else:
    print(lca(order[1], order[2], 2, order[3]))