-- 예전 기록/BOJ

[ BOJ ] 16978 : 수열과 쿼리 22 ( PLATINUM 4 ) / C

rejo 2023. 4. 10. 14:42

문제

길이가 N인 수열 A1, A2, ..., AN이 주어진다. 이때, 다음 쿼리를 수행하는 프로그램을 작성하시오.

  • 1 i v: Ai = v로 변경한다.
  • 2 k i j: k번째 1번 쿼리까지 적용되었을 때, Ai, Ai+1, ..., Aj의 합을 출력한다.

입력

첫째 줄에 수열의 크기 N (1 ≤ N ≤ 100,000)이 주어진다.

둘째 줄에는 A1, A2, ..., AN이 주어진다. (1 ≤ Ai ≤ 1,000,000)

셋째 줄에는 쿼리의 개수 M (1 ≤ M ≤ 100,000)이 주어진다.

넷째 줄부터 M개의 줄에는 쿼리가 한 줄에 하나씩 주어진다. 1번 쿼리의 경우 1 ≤ i ≤ N, 1 ≤ v ≤ 1,000,000 이고, 2번 쿼리의 경우 1 ≤ i ≤ j ≤ N이고, 0 ≤ k ≤ (쿼리가 주어진 시점까지 있었던 1번 쿼리의 수)이다.

입력으로 주어지는 모든 수는 정수이다.

출력

모든 2번 쿼리마다 합을 출력한다.

풀이 과정

2번 쿼리를 k번째 1번 쿼리 기준으로 정렬하고, 순차적으로 실행한 뒤 다시 입력된 순서대로 정렬한다. 오프라인 쿼리 기법을 이용해 유리하게 쿼리를 정렬한다.

#include <stdio.h>
#define SIZE 100001
typedef long long LL;

int arr[SIZE];
LL tree[SIZE * 4];

int query1[SIZE][2];
int q1s = 0;
int query2[SIZE][4];
int q2s = 0;
int mtmp[SIZE][4];

LL result[SIZE][2];
LL mres[SIZE][2];

void init(int start, int end, int idx) {
	if (start == end) {
		tree[idx] = arr[start];
		return;
	}

	int mid = (start + end) / 2;
	init(start, mid, idx * 2);
	init(mid + 1, end, idx * 2 + 1);
	tree[idx] = tree[idx * 2] + tree[idx * 2 + 1];
}

void merge_k(int start, int mid, int end) {
	int leftidx = start;
	int rightidx = mid + 1;
	int allidx = start;

	while (leftidx <= mid && rightidx <= end) {
		if (query2[leftidx][0] < query2[rightidx][0]) {
			for (int i = 0; i < 4; i++) mtmp[allidx][i] = query2[leftidx][i];
			allidx += 1;
			leftidx += 1;
		}
		else if (query2[leftidx][0] > query2[rightidx][0]) {
			for (int i = 0; i < 4; i++) mtmp[allidx][i] = query2[rightidx][i];
			allidx += 1;
			rightidx += 1;
		}
		else {
			if (query2[leftidx][3] < query2[rightidx][3]) {
				for (int i = 0; i < 4; i++) mtmp[allidx][i] = query2[leftidx][i];
				allidx += 1;
				leftidx += 1;
			}
			else if (query2[leftidx][3] > query2[rightidx][3]) {
				for (int i = 0; i < 4; i++) mtmp[allidx][i] = query2[rightidx][i];
				allidx += 1;
				rightidx += 1;
			}
		}
	}

	while (leftidx <= mid) {
		for (int i = 0; i < 4; i++) mtmp[allidx][i] = query2[leftidx][i];
		allidx += 1;
		leftidx += 1;
	}

	while (rightidx <= end) {
		for (int i = 0; i < 4; i++) mtmp[allidx][i] = query2[rightidx][i];
		allidx += 1;
		rightidx += 1;
	}

	for (int i = start; i <= end; i++) {
		for (int j = 0; j < 4; j++) query2[i][j] = mtmp[i][j];
	}
}

void merge_sort_k(int start, int end) {
	if (start < end) {
		int mid = (start + end) / 2;
		merge_sort_k(start, mid);
		merge_sort_k(mid + 1, end);
		merge_k(start, mid, end);
	}
}

void update(int start, int end, int idx, int what, int value) {
	if (what < start || end < what) return;
	if (start == end) {
		tree[idx] = value;
		return;
	}

	int mid = (start + end) / 2;
	update(start, mid, idx * 2, what, value);
	update(mid + 1, end, idx * 2 + 1, what, value);
	tree[idx] = tree[idx * 2] + tree[idx * 2 + 1];
}

LL interval_sum(int start, int end, int idx, int left, int right) {
	if (end < left || right < start) return 0;
	if (left <= start && end <= right) return tree[idx];

	int mid = (start + end) / 2;
	return interval_sum(start, mid, idx * 2, left, right) + interval_sum(mid + 1, end, idx * 2 + 1, left, right);
}

void merge_result(int start, int mid, int end) {
	int leftidx = start;
	int rightidx = mid + 1;
	int allidx = start;
	
	while (leftidx <= mid && rightidx <= end) {
		if (result[leftidx][1] < result[rightidx][1]) {
			mres[allidx][0] = result[leftidx][0];
			mres[allidx][1] = result[leftidx][1];
			allidx += 1;
			leftidx += 1;
		}
		else if (result[leftidx][1] > result[rightidx][1]) {
			mres[allidx][0] = result[rightidx][0];
			mres[allidx][1] = result[rightidx][1];
			allidx += 1;
			rightidx += 1;
		}
	}

	while (leftidx <= mid) {
		mres[allidx][0] = result[leftidx][0];
		mres[allidx][1] = result[leftidx][1];
		allidx += 1;
		leftidx += 1;
	}
	
	while (rightidx <= end) {
		mres[allidx][0] = result[rightidx][0];
		mres[allidx][1] = result[rightidx][1];
		allidx += 1;
		rightidx += 1;
	}

	for (int i = start; i <= end; i++) {
		result[i][0] = mres[i][0];
		result[i][1] = mres[i][1];
	}
}

void merge_sort_result(int start, int end) {
	if (start < end) {
		int mid = (start + end) / 2;
		merge_sort_result(start, mid);
		merge_sort_result(mid + 1, end);
		merge_result(start, mid, end);
	}
}

int main(void) {
	int n;
	scanf("%d", &n);

	for (int i = 0; i < n; i++) scanf("%d", &arr[i]);

	init(0, n - 1, 1);

	int m;
	scanf("%d", &m);
	for (int i = 0; i < m; i++) {
		int mode;
		scanf("%d", &mode);

		if (mode == 1) {
			scanf("%d %d", &query1[q1s][0], &query1[q1s][1]);
			q1s += 1;
		}
		else {
			scanf("%d %d %d", &query2[q2s][0], &query2[q2s][1], &query2[q2s][2]);
			query2[q2s][3] = q2s;
			q2s += 1;
		}
	}

	// 1. 쿼리 k 번째로 정렬
	merge_sort_k(0, q2s - 1);

	// 2. 쿼리 수행 -> 결과값 저장
	int now = 0;
	for (int i = 0; i < q2s; i++) {
		// printf("[%d %d %d %d] \n", query2[i][0], query2[i][1], query2[i][2], query2[i][3]);
		while (now != query2[i][0]) {
			update(0, n - 1, 1, query1[now][0] - 1, query1[now][1]);
			now += 1;
		}
		result[i][0] = interval_sum(0, n - 1, 1, query2[i][1] - 1, query2[i][2] - 1);
		result[i][1] = query2[i][3];
	}
	
	// 3. 결과값 정렬 및 출력
	merge_sort_result(0, q2s - 1);

	for (int i = 0; i < q2s; i++) {
		// printf("[%d %d] \n", result[i][0], result[i][1]);
		printf("%lld\n", result[i][0]);
	}

	return 0;
}