문제
어떤 N개의 수가 주어져 있다. 그런데 중간에 수의 변경이 빈번히 일어나고 그 중간에 어떤 부분의 합을 구하려 한다. 만약에 1,2,3,4,5 라는 수가 있고, 3번째 수를 6으로 바꾸고 2번째부터 5번째까지 합을 구하라고 한다면 17을 출력하면 되는 것이다. 그리고 그 상태에서 다섯 번째 수를 2로 바꾸고 3번째부터 5번째까지 합을 구하라고 한다면 12가 될 것이다.
입력
첫째 줄에 수의 개수 N(1 ≤ N ≤ 1,000,000)과 M(1 ≤ M ≤ 10,000), K(1 ≤ K ≤ 10,000) 가 주어진다. M은 수의 변경이 일어나는 횟수이고, K는 구간의 합을 구하는 횟수이다. 그리고 둘째 줄부터 N+1번째 줄까지 N개의 수가 주어진다. 그리고 N+2번째 줄부터 N+M+K+1번째 줄까지 세 개의 정수 a, b, c가 주어지는데, a가 1인 경우 b(1 ≤ b ≤ N)번째 수를 c로 바꾸고 a가 2인 경우에는 b(1 ≤ b ≤ N)번째 수부터 c(b ≤ c ≤ N)번째 수까지의 합을 구하여 출력하면 된다.
입력으로 주어지는 모든 수는 -2^63보다 크거나 같고, 2^63-1보다 작거나 같은 정수이다.
출력
첫째 줄부터 K줄에 걸쳐 구한 구간의 합을 출력한다. 단, 정답은 -2^63보다 크거나 같고, 2^63-1보다 작거나 같은 정수이다.
풀이 과정
처음 세그먼트 트리에 대해서 배우게 되었다. 여러 데이터 중 특정 부분의 최댓값, 최솟값, 합, 곱 등을 구하는 데에 최적화된 알고리즘으로, 단순 탐색보다 O(logN) 으로 빠르게 값을 구할 수 있다는 장점이 있다.
세그먼트 트리 공부에 도움이 되었던 블로그 : https://velog.io/@kimdukbae/%EC%9E%90%EB%A3%8C%EA%B5%AC%EC%A1%B0-%EC%84%B8%EA%B7%B8%EB%A8%BC%ED%8A%B8-%ED%8A%B8%EB%A6%AC-Segment-Tree
특정 부분의 구간합을 구할 때, 주어진 부분에 포함된다면 그 노드의 값만 가져오면 되는 편리한 탐색을 가지고 있어서 신기했고, 특정 데이터를 업데이트해야할 때 연관된 부분에만 값 변동을 주기 때문에 편리하고 재밌는 알고리즘이었다.
import sys
input = sys.stdin.readline
n, m, k = map(int, input().split())
arr = []
tree = [0] * (n * 4)
for _ in range(n): arr.append(int(input()))
# Tree initiation
def init(start, end, index):
if start == end:
tree[index] = arr[start]
return tree[index]
mid = (start + end) // 2
tree[index] = init(start, mid, index * 2) + init(mid + 1, end, index * 2 + 1) # Prefix-Sum
return tree[index]
init(0, n - 1, 1)
def interval_sum(start, end, index, left, right):
if end < left or right < start: return 0
if left <= start and end <= right: return tree[index]
mid = (start + end) // 2
return interval_sum(start, mid, index * 2, left, right) + interval_sum(mid + 1, end, index * 2 + 1, left, right)
def update(start, end, index, what, value):
if what < start or end < what: return
tree[index] += value
if start == end: return
mid = (start + end) // 2
update(start, mid, index * 2, what, value)
update(mid + 1, end, index * 2 + 1, what, value)
for _ in range(m+k):
n1, n2, n3 = map(int, input().split())
if n1 == 1:
update(0, n - 1, 1, n2-1, n3 - arr[n2 - 1])
arr[n2-1] = n3
else:
print(interval_sum(0, n - 1, 1, n2 - 1, n3 - 1))
자주 사용되는 세그먼트 트리 알고리즘을 배우게 되어 좋았다. 앞으로 자주 사용하면서 연습할 것 같다.
'-- 예전 기록 > BOJ' 카테고리의 다른 글
[ BOJ ] 2955 : 스도쿠 풀기 ( GOLD 2 ) / Python (0) | 2023.03.11 |
---|---|
[ BOJ ] 16932 : 모양 만들기 ( GOLD 3 ) / Python (0) | 2023.03.11 |
[ BOJ ] 12837 : 가계부 (Hard) ( GOLD 1 ) / C (0) | 2023.03.10 |
[ BOJ ] 5373 : 큐빙 ( PLATINUM 5 ) / Python (2) | 2023.03.09 |
[ BOJ ] 10800 : 컬러볼 ( GOLD 3 ) / Python (0) | 2023.03.03 |