세그먼트 트리(Sement Tree)
특정 구간 내 데이터에 대한 연산을 빠르게 구할 수 있는 트리
(연산: 합, 최댓값, 최솟값 등)

세그먼트 트리의 잎 노드: 배열 원소와 같은 값
세그먼트 트리의 내부 노드: 구간(부분 배열)에 대한 값
세그먼트 트리의 높이: Ceil(O(log2N) (N은 배열의 크기)
세그먼트 트리의 크기: 2ceil(logn)+1-1 (log의 밑은 2)
(실제 코딩 시 index가 0인 원소는 사용하지 않으므로 트리 크기+1만큼 공간을 사용한다)
알고리즘
– 업데이트 (배열 원소 및 세그먼트 트리 수정)
– 구간 연산 (합, 최댓값, 최솟값 등)
시간복잡도
– 세그먼트 트리 생성: O(N)
– 업데이트: O(logN)
– 구간 연산: O(logN)
-> 데이터를 M번 업데이트 및 구간 연산: O(MlogN)

세그먼트 트리 (구간 합 구하기: Range Sum Query)
#include <stdio.h>
#include <stdlib.h>
#include <math.h>
int size;
int* s_tree;
int init(int* arr, int node, int s, int e) {
if (s == e)
return s_tree[node] = arr[s];
int m = (s + e) / 2;
// 재귀적으로 두 구간으로 나눈 뒤 그 합을 자기 자신으로 함
// ex) [node,s,e]=[1,0,4]=[2,0,2]+[3,3,4]
return s_tree[node] = init(arr, node * 2, s, m) + init(arr, node * 2 + 1, m + 1, e);
}
// s,e: 배열의 시작,끝 인덱스 (탐색하는 구간)
// l,r: 부분 배열의 시작,끝 인덱스 (합을 구하려는 구간)
int sum(int node, int s, int e, int l, int r) {
// case1) 탐색하는 구간이 부분 배열 구간과 무관 -> 0 리턴
// ex) s,e=(0,1), l,r=(2,3)
if (l > e || r < s)
return 0;
// case2) 탐색하는 구간이 부분 배열 구간에 속함 -> 노드의 값 리턴
// ex) s,e=(2,2), l,r=(2,3)
if (l <= s && e <= r)
return s_tree[node];
// 재귀적으로 두 구간으로 나눈 뒤 case2를 만족한 구간들을 합침
// ex) (l=2,r=3) -> (s=2,e=2)+(s=3,e=3)
int m = (s + e) / 2;
return sum(node * 2, s, m, l, r) + sum(node * 2 + 1, m + 1, e, l, r);
}
// s,e: 배열의 시작,끝 인덱스 (탐색하는 구간)
// idx: 배열의 특정 인덱스 (바꾸려는 원소)
// diff: idx번째 원소를 val로 변경할 때 기존 값과의 차이
void _update(int node, int s, int e, int idx, int diff) {
// 탐색하는 구간이 바꾸려는 원소를 지님
if (s <= idx && idx <= e) {
// 탐색하는 구간이 바꾸려는 원소와 일치 -> 종료
if (s == e) {
s_tree[node] += diff;
return;
}
else {
s_tree[node] += diff;
int m = (s + e) / 2;
// 재귀적으로 두 구간으로 나눈 뒤 idx와 관련있는 구간이면 업데이트
_update(node * 2, s, m, idx, diff);
_update(node * 2 + 1, m + 1, e, idx, diff);
}
}
// 탐색하는 구간이 바꾸려는 원소와 무관
return ;
}
void update(int node, int s, int e, int* arr, int idx, int val) {
int diff = val - arr[idx];
arr[idx] = val;
_update(node, s, e, idx, diff);
}
// 배열 출력
void aprint(int* arr, int size) {
printf("배열: ");
for (int i = 0; i < size; i++) {
printf("[%d] %d ", i, arr[i]);
}
printf("\n");
}
// 세그먼트 트리 출력
void sprint(int size) {
printf("세그먼트 트리: ");
for (int i = 1; i < size; i++) {
printf("[%d] %d ", i, s_tree[i]);
}
printf("\n");
}
int main(void)
{
// 배열 선언 및 배열 크기(size) 초기화
int arr[] = { 0,1,2,3,4,5 };
size = 6;
// 세그먼트 트리 선언
int h = (int)ceil(log2(size));
int tree_size = (int)pow(2, h + 1);
s_tree = (int*)calloc(tree_size, sizeof(int));
// 세그먼트 트리 초기화 (arr 배열 이용)
init(arr, 1, 0, size - 1);
// 배열 및 세그먼트 트리 출력
aprint(arr, size);
sprint(tree_size);
// 구간 합 구하기 (배열의 l~r 구간에 대한 합을 세그먼트 트리로 구하기)
printf("[2,4]의 합:%d\n\n", sum(1, 0, size - 1, 2, 4)); // 특정 구간
/* ------------------------------------------------------------------ */
/* ------------------------------------------------------------------ */
// 업데이트 (배열의 idx번째 원소를 val로 바꿀 때 세그먼트 트리 업데이트)
int idx = 4;
int val = 5;
printf("업데이트: arr[%d]=%d -> %d\n", idx, arr[idx], val);
update(1, 0, size - 1, arr, idx, val); // arr[4]=4가 5가 되었을 때 세그먼트 트리 업데이트
// 배열 및 세그먼트 트리 출력
aprint(arr, size);
sprint(tree_size);
// 구간 합 구하기
printf("[2,4]의 합:%d\n", sum(1, 0, size - 1, 2, 4)); // 특정 구간
return 0;
}

세그먼트 트리 (구간 최댓값 구하기: Range Max Query)
#include <stdio.h>
#include <stdlib.h>
#include <math.h>
#define MAX(a,b) (((a) >= (b)) ? (a):(b))
int size;
int* s_tree;
int init(int* arr, int node, int s, int e) {
if (s == e)
return s_tree[node] = arr[s];
int m = (s + e) / 2;
return s_tree[node] = MAX(init(arr, node * 2, s, m), init(arr, node * 2 + 1, m + 1, e));
}
// s,e: 배열의 시작,끝 인덱스 (탐색하는 구간)
// l,r: 부분 배열의 시작,끝 인덱스 (MAX을 구하려는 구간)
int ret(int node, int s, int e, int l, int r) {
if (l > e || r < s)
return 0;
if (l <= s && e <= r)
return s_tree[node];
int m = (s + e) / 2;
return MAX(ret(node * 2, s, m, l, r), ret(node * 2 + 1, m + 1, e, l, r));
}
// s,e: 배열의 시작,끝 인덱스 (탐색하는 구간)
// idx: 배열의 특정 인덱스 (바꾸려는 원소)
// val: 바꿀 숫자
int _update(int node, int s, int e, int idx, int val) {
if (s <= idx && idx <= e) {
if (s == e) {
s_tree[node] = val;
}
else {
int m = (s + e) / 2;
s_tree[node] = MAX(_update(node * 2, s, m, idx, val), _update(node * 2 + 1, m + 1, e, idx, val));
}
}
return s_tree[node];
}
void update(int node, int s, int e, int* arr, int idx, int val) {
arr[idx] = val;
_update(node, s, e, idx, val);
}
// 배열 출력
void aprint(int* arr, int size) {
printf("배열: ");
for (int i = 0; i < size; i++) {
printf("[%d] %d ", i, arr[i]);
}
printf("\n");
}
// 세그먼트 트리 출력
void sprint(int size) {
printf("세그먼트 트리: ");
for (int i = 1; i < size; i++) {
printf("[%d] %d ", i, s_tree[i]);
}
printf("\n");
}
int main(void)
{
// 배열 선언 및 배열 크기(size) 초기화
int arr[] = { 0,1,2,3,4,5 };
size = 6;
// 세그먼트 트리 선언
int h = (int)ceil(log2(size));
int tree_size = (int)pow(2, h + 1);
s_tree = (int*)calloc(tree_size, sizeof(int));
// 세그먼트 트리 초기화 (arr 배열 이용)
init(arr, 1, 0, size - 1);
// 배열 및 세그먼트 트리 출력
aprint(arr, size);
sprint(tree_size);
// 구간 최댓값 구하기 (배열의 l~r 구간에 대한 최댓값을 세그먼트 트리로 구하기)
printf("[2,4]의 최댓값:%d\n\n", ret(1, 0, size - 1, 2, 4)); // 특정 구간
/* ------------------------------------------------------------------ */
/* ------------------------------------------------------------------ */
// 업데이트 (배열의 idx번째 원소를 val로 바꿀 때 세그먼트 트리 업데이트)
int idx = 4;
int val = 5;
printf("업데이트: arr[%d]=%d -> %d\n", idx, arr[idx], val);
update(1, 0, size - 1, arr, idx, val); // arr[4]=4가 5가 되었을 때 세그먼트 트리 업데이트
// 배열 및 세그먼트 트리 출력
aprint(arr, size);
sprint(tree_size);
// 구간 최댓값 구하기 (배열의 l~r 구간에 대한 최댓값을 세그먼트 트리로 구하기)
printf("[2,4]의 최댓값:%d\n\n", ret(1, 0, size - 1, 2, 4)); // 특정 구간
return 0;
}
