본문 바로가기
Algorithm/Concepts

구간 트리(Segment Tree)

by 컴돈AI 2024. 3. 20.

목차

    구간 트리(Segment Tree)

    • 세그먼트 트리는 구간 쿼리와 갱신(Update) 작업을 효율적으로 처리할 수 있는 자료 구조입니다.
      • 배열의 특정 구간에 대한 정보(예를 들면 합, 최댓값, 최솟값)를 빠르게 계산하고, 배열의 원소가 변경될 때 이 정보를 쉽게 업데이트할 수 있도록 설계되었습니다.
      • 리스트를 슬라이싱 해서 해당 구간의 연산을 자주 하는 문제가 있을 경우 세그먼트 트리를 이용하여 시간 단축이 가능합니다.
      • 전체 배열의 개수가 N이고 (특정 리스트 슬라이싱 구간 연산) 쿼리의 개수가 Q라고 할 경우, 시간복잡도는 리스트 슬라이싱으로 O(N), 쿼리연산으로 O(Q)이므로 결국 O(QN)의 시간복잡도가 걸릴 것입니다. 
      • 반면, 세그먼트 트리를 사용하면, 초기 배열에 대한 트리를 구성하는 데 O(NlogN) 시간이 소요되고(실제로는 더 최적화된 O 시간으로 구성할 수 있습니다), 각 쿼리 처리 시 의 시간이 소요됩니다. 따라서 쿼리 개를 처리하는 데는 총 의 시간이 걸립니다.
        • 배열의 원소가 변경될 때, 즉 업데이트 작업도 의 시간으로 효율적으로 처리할 수 있습니다.
    • 세그먼트 트리는 주로 트리 형태로 구현되며, 리프 노드에는 배열의 원소, 내부 노드에는 자식 노드의 정보를 합한 값이 저장됩니다.
    • 세그먼트 트리의 연산
      • build : 주어진 배열을 기반으로 세그먼트 트리를 구축하는 과정입니다. 일반적으로 O(Nlog(N) 시간복잡도를 가집니다.
      • query : 특정 구간에 대한 정보( 예 : 합, 최댓값, 최솟값 등)를 조회하는 과정입니다. O(logN) 시간복잡도를 가집니다.
      • update : 배열의 특정 원소가 변경될 때 세그먼트 트리를 업데이트하는 과정입니다. O(logN)의 시간복잡도를 가집니다.
    • 예시의 [백준 2042] 구간합 구하기를 통해 개념 정리가 가능합니다.

    예시

    • [백준 2042] 구간 합 구하기
      • N 은 최대 1000000 , M과 K는 최대 10000입니다. 만약 단순히 N개의 수에다가 수변경과 구간합을 구할 경우 (M+K)N 만큼의 시간이 소요될 것입니다. 
        • 즉 20000* 1000000 은 시간 초과가 발생하게 될 것 입니다. 
      • 이를 이제 세그먼트 트리로 구성해 놓을 경우 생성 시에는 Nlog(N), 수의 변경 시에는 Mlog(N) , 부분합 구할 시에는 Klog(N) 만큼의 시간이 소요될 것입니다.
        • log1000000는 약 20으로 훨씬 적은 시간으로 해결이 가능합니다.
      • 세그먼트 트리 생성하기(build)
        • 세그먼트 트리는 N개의 리스트가 있다면, 2 * 2^(math.ceil(logN)) 크기의 배열이 됩니다. (아래 예시를 보면 왜 이러한 수식이 나오는지 이해할 수 있습니다.)
        • 세그먼트 트리는 완전 이진 트리 형태로 구성됩니다. 리프 노드에 기존 배열의 값들이 들어가게 됩니다. 각 리프노드의 부모들은 그 두 구간의 연산(예를 들면 합, 최댓값, 최솟값)을 의미합니다. 이런 식으로, 최상단까지 올라가게 된다면, 부모 노드는 결국 모든 구간에 대한 연산(예를 들면 합, 최댓값, 최솟값)이 됩니다.
        • 예를 들어 1, 2, 3 ,4 ,5의 배열에 대해 구간합을 구하는 세그먼트 트리를 생성해 보겠습니다.
          • 전체 세그먼트 트리 배열의 길이는 2*2^(math.ceil(math.log2(5)))가 됩니다.
            • math.ceil(math.log2(5))  = 3
            • 2^3 = 8
            • 2*2^3 = 16
          • 리프노드는 2^(math.ceil( math.log2(N) )) 인덱스부터 시작됩니다.
          • 부모노드로 이동하는 것은 index//2 , 자식노드로 이동하는 것은 index*2, index*2+1로 표현가능합니다.
          • 부모노드는 자식노드 구간에 대한 연산을 의미하므로 1,2,3,4,5의 배열에 대한 구간합을 구하는 세그먼트 트리를 생성해 보면 다음과 같습니다.
            • (인덱스 0 은 무시합니다. 인덱스 1이 루트 노드가 됩니다.)
            • [0,15,10,5,3,7,5,0,1,2,3,4,5,0,0,0]
            • 가장 먼저 2^(math.ceil( math.log2(N)) ) =2^3 = 8부터 리스트 배열 넣어주기
              • [0,0,0,0,0,0,0,0,1,2,3,4,5,0,0,0]
            • 그 후에 부모 노드로 하나씩 올라가기
              • 인덱스 8 자식노드 와 인덱스 9 자식노드의 부모 노드 -> 인덱스 4에 인덱스 8 과인덱스 9의 값 더해주기
                • [0,0,0,0,3,0,0,0,1,2,3,4,5,0,0,0]
              • 마찬가지로 인덱스 5, 6 ,7 채워주기
                • [0,0,0,0,3,7,5,0,1,2,3,4,5,0,0,0]
            • 그다음 층 부모 노드도 마찬가지로 생성해 줍니다. 
              • 인덱스 4,5,6,7의 부모 노드 -> 2,3이 됩니다. 
                • [0,0,10,5 ,3,7,5,0,1,2,3,4,5,0,0,0]
            • 최종 루트 노드 생성
              • 인덱스 2,3의 부모 노드 -> 1
                • [0,15,10,5 ,3,7,5,0,1,2,3,4,5,0,0,0]
        • 코드
          • max_n = 2 ** (math.ceil(math.log2(n)))
            seg_tree = [0 for _ in range(max_n * 2)]
            
            
            # n은 인덱스 기준
            def update(n, value):
                idx = max_n + n
                seg_tree[idx] = value
            
                while idx > 1:
                    idx = idx // 2
                    seg_tree[idx] = seg_tree[idx * 2] + seg_tree[idx * 2 + 1]
            
            
            # 초기화
            for i in range(n):
                update(i, int(input()))
      • 부분합 구하기(query)
        • 원하는 구간이 나올 때까지 전체 구간부터 탐색해 나아갑니다.
        • 1,2,3,4,5의 세그먼트 트리에 대해서 2번째 수부터 5번째 수까지의 값을 구하는 예시
        • 코드
          • # (n_1,n_2,0,max_n-1,1) 부터 시작
            # n_1, n_2, s, e 은 인덱스 기준
            def find(n_1, n_2, s, e, idx):
                if n_1 <= s and e <= n_2:
                    return seg_tree[idx]
                elif n_1 > e or n_2 < s:
                    return 0
                else:
                    left = find(n_1, n_2, s, (s + e) // 2, idx * 2)
                    right = find(n_1, n_2, (s + e) // 2 + 1, e, idx * 2 + 1)
                    return left + right
      • 수의 변경(update)
        • 수의 변경은 해당 인덱스의 값이 변경되면 그의 부모노드값들을 부모노드 인덱스가 1이 될 때까지 하나씩 업데이트해주면 됩니다.
        • 코드
          • # n은 인덱스 기준
            def update(n, value):
                idx = max_n + n
                seg_tree[idx] = value
            
                while idx > 1:
                    idx = idx // 2
                    seg_tree[idx] = seg_tree[idx * 2] + seg_tree[idx * 2 + 1]
      • 코드
        • import sys
          import math
          
          input = sys.stdin.readline
          
          n, m, k = map(int, input().split())
          
          max_n = 2 ** (math.ceil(math.log2(n)))
          seg_tree = [0 for _ in range(max_n * 2)]
          
          
          # n은 인덱스 기준
          def update(n, value):
              idx = max_n + n
              seg_tree[idx] = value
          
              while idx > 1:
                  idx = idx // 2
                  seg_tree[idx] = seg_tree[idx * 2] + seg_tree[idx * 2 + 1]
          
          
          # 초기화
          for i in range(n):
              update(i, int(input()))
          
          
          # (n_1,n_2,0,max_n-1,1) 부터 시작
          # n_1, n_2, s, e 은 인덱스 기준
          def find(n_1, n_2, s, e, idx):
              if n_1 <= s and e <= n_2:
                  return seg_tree[idx]
              elif n_1 > e or n_2 < s:
                  return 0
              else:
                  left = find(n_1, n_2, s, (s + e) // 2, idx * 2)
                  right = find(n_1, n_2, (s + e) // 2 + 1, e, idx * 2 + 1)
                  return left + right
          
          
          for _ in range(m + k):
              a, b, c = map(int, input().split())
              # a가 1인 경우 -> b번째수 c로 변경
              # a가 2인 경우 -> b번째부터 c번째 수까지의 합구하기
              if a == 1:
                  update(b - 1, c)
              elif a == 2:
                  print(find(b - 1, c - 1, 0, max_n - 1, 1))
    • [백준 10868] 최솟값
      • 구간 합 구하기와 동일하지만, 특정 구간에서 최솟값을 구하는 문제입니다.
      • 코드
        • import sys
          import math
          
          input = sys.stdin.readline
          
          n, m = map(int, input().split())
          
          max_n = 2 ** (math.ceil(math.log2(n)))
          
          seg_tree = [int(1e9) for _ in range(2 * max_n)]
          
          
          def update(n, value):
              idx = max_n + n
              seg_tree[idx] = value
          
              while idx > 1:
                  idx = idx // 2
                  seg_tree[idx] = min(seg_tree[idx * 2], seg_tree[idx * 2 + 1])
          
          
          for i in range(n):
              update(i, int(input()))
          
          
          def find(n_1, n_2, s, e, idx):
              if n_1 <= s and e <= n_2:
                  return seg_tree[idx]
              elif n_1 > e or n_2 < s:
                  return int(1e9)
              else:
                  left = find(n_1, n_2, s, (s + e) // 2, idx * 2)
                  right = find(n_1, n_2, (s + e) // 2 + 1, e, idx * 2 + 1)
                  return min(left, right)
          
          
          for _ in range(m):
              a, b = map(int, input().split())
              print(find(a - 1, b - 1, 0, max_n - 1, 1))