Data Structure – Segment Tree (Range Sum & Range Max)

세그먼트 트리(Sement Tree)

특정 구간 내 데이터에 대한 연산을 빠르게 구할 수 있는 트리
(연산: 합, 최댓값, 최솟값 등)

Segment trees (for beginners) | HackerEarth

세그먼트 트리의 잎 노드: 배열 원소와 같은 값
세그먼트 트리의 내부 노드: 구간(부분 배열)에 대한 값
세그먼트 트리의 높이: Ceil(O(log2N) (N은 배열의 크기)
세그먼트 트리의 크기: 2ceil(logn)+1-1 (log의 밑은 2)
(실제 코딩 시 index가 0인 원소는 사용하지 않으므로 트리 크기+1만큼 공간을 사용한다)

알고리즘

업데이트 (배열 원소 및 세그먼트 트리 수정)
구간 연산 (합, 최댓값, 최솟값 등)

시간복잡도

세그먼트 트리 생성: O(N)
업데이트: O(logN)
구간 연산: O(logN)
-> 데이터를 M번 업데이트 및 구간 연산: O(MlogN)

세그먼트 트리 (구간 합 구하기: Range Sum Query)

#include <stdio.h>
#include <stdlib.h>
#include <math.h>

int size;
int* s_tree;

int init(int* arr, int node, int s, int e) {
	if (s == e)
		return s_tree[node] = arr[s];
	int m = (s + e) / 2;

	// 재귀적으로 두 구간으로 나눈 뒤 그 합을 자기 자신으로 함 
	// ex) [node,s,e]=[1,0,4]=[2,0,2]+[3,3,4]
	return s_tree[node] = init(arr, node * 2, s, m) + init(arr, node * 2 + 1, m + 1, e);
}

// s,e: 배열의 시작,끝 인덱스 (탐색하는 구간)
// l,r: 부분 배열의 시작,끝 인덱스 (합을 구하려는 구간)
int sum(int node, int s, int e, int l, int r) {
	// case1) 탐색하는 구간이 부분 배열 구간과 무관 -> 0 리턴
	// ex) s,e=(0,1), l,r=(2,3)
	if (l > e || r < s)
		return 0;

	// case2) 탐색하는 구간이 부분 배열 구간에 속함 -> 노드의 값 리턴 
	// ex) s,e=(2,2), l,r=(2,3)
	if (l <= s && e <= r)
		return s_tree[node];

	// 재귀적으로 두 구간으로 나눈 뒤 case2를 만족한 구간들을 합침 
	// ex) (l=2,r=3) -> (s=2,e=2)+(s=3,e=3)
	int m = (s + e) / 2;
	return sum(node * 2, s, m, l, r) + sum(node * 2 + 1, m + 1, e, l, r);
}

// s,e: 배열의 시작,끝 인덱스 (탐색하는 구간)
// idx: 배열의 특정 인덱스 (바꾸려는 원소)
// diff: idx번째 원소를 val로 변경할 때 기존 값과의 차이
void _update(int node, int s, int e, int idx, int diff) {
	// 탐색하는 구간이 바꾸려는 원소를 지님
	if (s <= idx && idx <= e) {
		// 탐색하는 구간이 바꾸려는 원소와 일치 -> 종료 
		if (s == e) {
			s_tree[node] += diff;
			return;
		}
		else {
			s_tree[node] += diff;
			int m = (s + e) / 2;
			// 재귀적으로 두 구간으로 나눈 뒤 idx와 관련있는 구간이면 업데이트 
			_update(node * 2, s, m, idx, diff);
			_update(node * 2 + 1, m + 1, e, idx, diff);
		}
	}
	// 탐색하는 구간이 바꾸려는 원소와 무관
	return ;
}

void update(int node, int s, int e, int* arr, int idx, int val) {
	int diff = val - arr[idx];
	arr[idx] = val;
	_update(node, s, e, idx, diff);
}

// 배열 출력
void aprint(int* arr, int size) {
	printf("배열: ");
	for (int i = 0; i < size; i++) {
		printf("[%d] %d ", i, arr[i]);
	}
	printf("\n");
}

// 세그먼트 트리 출력 
void sprint(int size) {
	printf("세그먼트 트리: ");
	for (int i = 1; i < size; i++) {
		printf("[%d] %d ", i, s_tree[i]);
	}
	printf("\n");
}

int main(void)
{
	// 배열 선언 및 배열 크기(size) 초기화 
	int arr[] = { 0,1,2,3,4,5 };
	size = 6;

	// 세그먼트 트리 선언
	int h = (int)ceil(log2(size));
	int tree_size = (int)pow(2, h + 1);
	s_tree = (int*)calloc(tree_size, sizeof(int));

	// 세그먼트 트리 초기화 (arr 배열 이용) 
	init(arr, 1, 0, size - 1);

	// 배열 및 세그먼트 트리 출력
	aprint(arr, size);
	sprint(tree_size);

	// 구간 합 구하기 (배열의 l~r 구간에 대한 합을 세그먼트 트리로 구하기)
	printf("[2,4]의 합:%d\n\n", sum(1, 0, size - 1, 2, 4)); // 특정 구간

	/* ------------------------------------------------------------------  */
	/* ------------------------------------------------------------------  */

	// 업데이트 (배열의 idx번째 원소를 val로 바꿀 때 세그먼트 트리 업데이트)
	int idx = 4;
	int val = 5;
	printf("업데이트: arr[%d]=%d -> %d\n", idx, arr[idx], val);
	update(1, 0, size - 1, arr, idx, val); // arr[4]=4가 5가 되었을 때 세그먼트 트리 업데이트 

	// 배열 및 세그먼트 트리 출력
	aprint(arr, size);
	sprint(tree_size);

	// 구간 합 구하기 
	printf("[2,4]의 합:%d\n", sum(1, 0, size - 1, 2, 4)); // 특정 구간

	return 0;
}

세그먼트 트리 (구간 최댓값 구하기: Range Max Query)

#include <stdio.h>
#include <stdlib.h>
#include <math.h>
#define MAX(a,b) (((a) >= (b)) ? (a):(b))

int size;
int* s_tree;

int init(int* arr, int node, int s, int e) {
	if (s == e)
		return s_tree[node] = arr[s];
	int m = (s + e) / 2;
	return s_tree[node] = MAX(init(arr, node * 2, s, m), init(arr, node * 2 + 1, m + 1, e));
}

// s,e: 배열의 시작,끝 인덱스 (탐색하는 구간)
// l,r: 부분 배열의 시작,끝 인덱스 (MAX을 구하려는 구간)
int ret(int node, int s, int e, int l, int r) {
	if (l > e || r < s)
		return 0;

	if (l <= s && e <= r)
		return s_tree[node];

	int m = (s + e) / 2;
	return MAX(ret(node * 2, s, m, l, r), ret(node * 2 + 1, m + 1, e, l, r));
}

// s,e: 배열의 시작,끝 인덱스 (탐색하는 구간)
// idx: 배열의 특정 인덱스 (바꾸려는 원소)
// val: 바꿀 숫자 
int _update(int node, int s, int e, int idx, int val) {
	if (s <= idx && idx <= e) {
		if (s == e) {
			s_tree[node] = val;
		}
		else {
			int m = (s + e) / 2;
			s_tree[node] = MAX(_update(node * 2, s, m, idx, val), _update(node * 2 + 1, m + 1, e, idx, val));
		}
	}

	return s_tree[node];
}

void update(int node, int s, int e, int* arr, int idx, int val) {
	arr[idx] = val;
	_update(node, s, e, idx, val);
}

// 배열 출력
void aprint(int* arr, int size) {
	printf("배열: ");
	for (int i = 0; i < size; i++) {
		printf("[%d] %d ", i, arr[i]);
	}
	printf("\n");
}

// 세그먼트 트리 출력 
void sprint(int size) {
	printf("세그먼트 트리: ");
	for (int i = 1; i < size; i++) {
		printf("[%d] %d ", i, s_tree[i]);
	}
	printf("\n");
}


int main(void)
{
	// 배열 선언 및 배열 크기(size) 초기화 
	int arr[] = { 0,1,2,3,4,5 };
	size = 6;

	// 세그먼트 트리 선언
	int h = (int)ceil(log2(size));
	int tree_size = (int)pow(2, h + 1);
	s_tree = (int*)calloc(tree_size, sizeof(int));

	// 세그먼트 트리 초기화 (arr 배열 이용) 
	init(arr, 1, 0, size - 1);

	// 배열 및 세그먼트 트리 출력
	aprint(arr, size);
	sprint(tree_size);

	// 구간 최댓값 구하기 (배열의 l~r 구간에 대한 최댓값을 세그먼트 트리로 구하기)
	printf("[2,4]의 최댓값:%d\n\n", ret(1, 0, size - 1, 2, 4)); // 특정 구간

	/* ------------------------------------------------------------------  */
	/* ------------------------------------------------------------------  */

	// 업데이트 (배열의 idx번째 원소를 val로 바꿀 때 세그먼트 트리 업데이트)
	int idx = 4;
	int val = 5;
	printf("업데이트: arr[%d]=%d -> %d\n", idx, arr[idx], val);
	update(1, 0, size - 1, arr, idx, val); // arr[4]=4가 5가 되었을 때 세그먼트 트리 업데이트 

	// 배열 및 세그먼트 트리 출력
	aprint(arr, size);
	sprint(tree_size);

	// 구간 최댓값 구하기 (배열의 l~r 구간에 대한 최댓값을 세그먼트 트리로 구하기)
	printf("[2,4]의 최댓값:%d\n\n", ret(1, 0, size - 1, 2, 4)); // 특정 구간

	return 0;
}

Leave a Reply

Your email address will not be published. Required fields are marked *