목차
구간 트리(Segment Tree)
- 세그먼트 트리는 구간 쿼리와 갱신(Update) 작업을 효율적으로 처리할 수 있는 자료 구조입니다.
- 배열의 특정 구간에 대한 정보(예를 들면 합, 최댓값, 최솟값)를 빠르게 계산하고, 배열의 원소가 변경될 때 이 정보를 쉽게 업데이트할 수 있도록 설계되었습니다.
- 리스트를 슬라이싱 해서 해당 구간의 연산을 자주 하는 문제가 있을 경우 세그먼트 트리를 이용하여 시간 단축이 가능합니다.
- 전체 배열의 개수가 N이고 (특정 리스트 슬라이싱 구간 연산) 쿼리의 개수가 Q라고 할 경우, 시간복잡도는 리스트 슬라이싱으로 O(N), 쿼리연산으로 O(Q)이므로 결국 O(QN)의 시간복잡도가 걸릴 것입니다.
- 반면, 세그먼트 트리를 사용하면, 초기 배열에 대한 트리를 구성하는 데 O(NlogN) 시간이 소요되고(실제로는 더 최적화된 O 시간으로 구성할 수 있습니다), 각 쿼리 처리 시 의 시간이 소요됩니다. 따라서 쿼리 개를 처리하는 데는 총 의 시간이 걸립니다.
- 배열의 원소가 변경될 때, 즉 업데이트 작업도 의 시간으로 효율적으로 처리할 수 있습니다.
- 세그먼트 트리는 주로 트리 형태로 구현되며, 리프 노드에는 배열의 원소, 내부 노드에는 자식 노드의 정보를 합한 값이 저장됩니다.
- 세그먼트 트리의 연산
- build : 주어진 배열을 기반으로 세그먼트 트리를 구축하는 과정입니다. 일반적으로 O(Nlog(N) 시간복잡도를 가집니다.
- query : 특정 구간에 대한 정보( 예 : 합, 최댓값, 최솟값 등)를 조회하는 과정입니다. O(logN) 시간복잡도를 가집니다.
- update : 배열의 특정 원소가 변경될 때 세그먼트 트리를 업데이트하는 과정입니다. O(logN)의 시간복잡도를 가집니다.
- 예시의 [백준 2042] 구간합 구하기를 통해 개념 정리가 가능합니다.
예시
- [백준 2042] 구간 합 구하기
- N 은 최대 1000000 , M과 K는 최대 10000입니다. 만약 단순히 N개의 수에다가 수변경과 구간합을 구할 경우 (M+K)N 만큼의 시간이 소요될 것입니다.
- 즉 20000* 1000000 은 시간 초과가 발생하게 될 것 입니다.
- 이를 이제 세그먼트 트리로 구성해 놓을 경우 생성 시에는 Nlog(N), 수의 변경 시에는 Mlog(N) , 부분합 구할 시에는 Klog(N) 만큼의 시간이 소요될 것입니다.
- log1000000는 약 20으로 훨씬 적은 시간으로 해결이 가능합니다.
- 세그먼트 트리 생성하기(build)
- 세그먼트 트리는 N개의 리스트가 있다면, 2 * 2^(math.ceil(logN)) 크기의 배열이 됩니다. (아래 예시를 보면 왜 이러한 수식이 나오는지 이해할 수 있습니다.)
- 세그먼트 트리는 완전 이진 트리 형태로 구성됩니다. 리프 노드에 기존 배열의 값들이 들어가게 됩니다. 각 리프노드의 부모들은 그 두 구간의 연산(예를 들면 합, 최댓값, 최솟값)을 의미합니다. 이런 식으로, 최상단까지 올라가게 된다면, 부모 노드는 결국 모든 구간에 대한 연산(예를 들면 합, 최댓값, 최솟값)이 됩니다.
- 예를 들어 1, 2, 3 ,4 ,5의 배열에 대해 구간합을 구하는 세그먼트 트리를 생성해 보겠습니다.
- 전체 세그먼트 트리 배열의 길이는 2*2^(math.ceil(math.log2(5)))가 됩니다.
- math.ceil(math.log2(5)) = 3
- 2^3 = 8
- 2*2^3 = 16
- 리프노드는 2^(math.ceil( math.log2(N) )) 인덱스부터 시작됩니다.
- 부모노드로 이동하는 것은 index//2 , 자식노드로 이동하는 것은 index*2, index*2+1로 표현가능합니다.
- 부모노드는 자식노드 구간에 대한 연산을 의미하므로 1,2,3,4,5의 배열에 대한 구간합을 구하는 세그먼트 트리를 생성해 보면 다음과 같습니다.
- (인덱스 0 은 무시합니다. 인덱스 1이 루트 노드가 됩니다.)
- [0,15,10,5,3,7,5,0,1,2,3,4,5,0,0,0]
- 가장 먼저 2^(math.ceil( math.log2(N)) ) =2^3 = 8부터 리스트 배열 넣어주기
- [0,0,0,0,0,0,0,0,1,2,3,4,5,0,0,0]
- 그 후에 부모 노드로 하나씩 올라가기
- 인덱스 8 자식노드 와 인덱스 9 자식노드의 부모 노드 -> 인덱스 4에 인덱스 8 과인덱스 9의 값 더해주기
- [0,0,0,0,3,0,0,0,1,2,3,4,5,0,0,0]
- 마찬가지로 인덱스 5, 6 ,7 채워주기
- [0,0,0,0,3,7,5,0,1,2,3,4,5,0,0,0]
- 인덱스 8 자식노드 와 인덱스 9 자식노드의 부모 노드 -> 인덱스 4에 인덱스 8 과인덱스 9의 값 더해주기
- 그다음 층 부모 노드도 마찬가지로 생성해 줍니다.
- 인덱스 4,5,6,7의 부모 노드 -> 2,3이 됩니다.
- [0,0,10,5 ,3,7,5,0,1,2,3,4,5,0,0,0]
- 인덱스 4,5,6,7의 부모 노드 -> 2,3이 됩니다.
- 최종 루트 노드 생성
- 인덱스 2,3의 부모 노드 -> 1
- [0,15,10,5 ,3,7,5,0,1,2,3,4,5,0,0,0]
- 인덱스 2,3의 부모 노드 -> 1
- 전체 세그먼트 트리 배열의 길이는 2*2^(math.ceil(math.log2(5)))가 됩니다.
- 코드
-
max_n = 2 ** (math.ceil(math.log2(n))) seg_tree = [0 for _ in range(max_n * 2)] # n은 인덱스 기준 def update(n, value): idx = max_n + n seg_tree[idx] = value while idx > 1: idx = idx // 2 seg_tree[idx] = seg_tree[idx * 2] + seg_tree[idx * 2 + 1] # 초기화 for i in range(n): update(i, int(input()))
-
- 부분합 구하기(query)
- 원하는 구간이 나올 때까지 전체 구간부터 탐색해 나아갑니다.
- 1,2,3,4,5의 세그먼트 트리에 대해서 2번째 수부터 5번째 수까지의 값을 구하는 예시
- 코드
-
# (n_1,n_2,0,max_n-1,1) 부터 시작 # n_1, n_2, s, e 은 인덱스 기준 def find(n_1, n_2, s, e, idx): if n_1 <= s and e <= n_2: return seg_tree[idx] elif n_1 > e or n_2 < s: return 0 else: left = find(n_1, n_2, s, (s + e) // 2, idx * 2) right = find(n_1, n_2, (s + e) // 2 + 1, e, idx * 2 + 1) return left + right
-
- 수의 변경(update)
- 수의 변경은 해당 인덱스의 값이 변경되면 그의 부모노드값들을 부모노드 인덱스가 1이 될 때까지 하나씩 업데이트해주면 됩니다.
- 코드
-
# n은 인덱스 기준 def update(n, value): idx = max_n + n seg_tree[idx] = value while idx > 1: idx = idx // 2 seg_tree[idx] = seg_tree[idx * 2] + seg_tree[idx * 2 + 1]
-
- 코드
-
import sys import math input = sys.stdin.readline n, m, k = map(int, input().split()) max_n = 2 ** (math.ceil(math.log2(n))) seg_tree = [0 for _ in range(max_n * 2)] # n은 인덱스 기준 def update(n, value): idx = max_n + n seg_tree[idx] = value while idx > 1: idx = idx // 2 seg_tree[idx] = seg_tree[idx * 2] + seg_tree[idx * 2 + 1] # 초기화 for i in range(n): update(i, int(input())) # (n_1,n_2,0,max_n-1,1) 부터 시작 # n_1, n_2, s, e 은 인덱스 기준 def find(n_1, n_2, s, e, idx): if n_1 <= s and e <= n_2: return seg_tree[idx] elif n_1 > e or n_2 < s: return 0 else: left = find(n_1, n_2, s, (s + e) // 2, idx * 2) right = find(n_1, n_2, (s + e) // 2 + 1, e, idx * 2 + 1) return left + right for _ in range(m + k): a, b, c = map(int, input().split()) # a가 1인 경우 -> b번째수 c로 변경 # a가 2인 경우 -> b번째부터 c번째 수까지의 합구하기 if a == 1: update(b - 1, c) elif a == 2: print(find(b - 1, c - 1, 0, max_n - 1, 1))
-
- N 은 최대 1000000 , M과 K는 최대 10000입니다. 만약 단순히 N개의 수에다가 수변경과 구간합을 구할 경우 (M+K)N 만큼의 시간이 소요될 것입니다.
- [백준 10868] 최솟값
- 구간 합 구하기와 동일하지만, 특정 구간에서 최솟값을 구하는 문제입니다.
- 코드
-
import sys import math input = sys.stdin.readline n, m = map(int, input().split()) max_n = 2 ** (math.ceil(math.log2(n))) seg_tree = [int(1e9) for _ in range(2 * max_n)] def update(n, value): idx = max_n + n seg_tree[idx] = value while idx > 1: idx = idx // 2 seg_tree[idx] = min(seg_tree[idx * 2], seg_tree[idx * 2 + 1]) for i in range(n): update(i, int(input())) def find(n_1, n_2, s, e, idx): if n_1 <= s and e <= n_2: return seg_tree[idx] elif n_1 > e or n_2 < s: return int(1e9) else: left = find(n_1, n_2, s, (s + e) // 2, idx * 2) right = find(n_1, n_2, (s + e) // 2 + 1, e, idx * 2 + 1) return min(left, right) for _ in range(m): a, b = map(int, input().split()) print(find(a - 1, b - 1, 0, max_n - 1, 1))
-
'Algorithm > Concepts' 카테고리의 다른 글
비트마스킹(Bit Masking) (0) | 2024.03.21 |
---|---|
느리게 갱신되는 세그먼트 트리(Lazy Segment Tree) (0) | 2024.03.20 |
최소 스패닝 트리(Minimum Spanning Tree, MST) (0) | 2024.03.20 |
Union-Find (0) | 2024.03.20 |
백트래킹 (Backtracking) (0) | 2024.03.19 |