본문 바로가기
Algorithm/Concepts

Union-Find

by 컴돈AI 2024. 3. 20.

목차

    Union-Find

    • Union-Find 알고리즘은 두 요소가 같은 그룹에 속하는지를 판단하거나, 두 요소를 같은 그룹으로 합치는 연산을 효율적으로 수행하는 데 사용됩니다.
    • 이 알고리즘은 그래프의 연결성을 확인하거나, 최소 스패닝 트리를 찾는 크루스칼 알고리즘과 같은 다양한 알고리즘에서 중요한 역할을 합니다.
    • Union-Find 알고리즘은 다음 두 가지 주요 연산으로 구성됩니다.
      • Find
        • 어떤 요소가 속한 그룹(또는 집합)의 대표를 찾는 연산.
        • 이 연산을 통해 두 요소가 같은 그룹에 속하는지 확인 가능
        • 두 요소의 Find 연산 결과가 같다면, 두 요소는 같은 그룹에 속한다고 판단 가능
        • parent[x]를 찾는 과정
      • Union
        • 두 그룹을 하나의 그룹으로 합치는 연산.
        • 두 요소가 속한 그룹을 Union 연산을 통해 합칠 경우, 이후 이 두 요소는 Find 연산을 했을 때 같은 결과를 반환하게 됩니다.
    • 코드
      • def find(x):
            if parent[x]==x:
                return x
            parent[x]=find(parent[x])
            return parent[x]
        
        def union(a,b):
            a = find(a)
            b = find(b)
        
            if a==b:
                return False
            
            elif a>b: # 작은값이 부모가 되도록 설정
                parent[a]=b 
                return True
            
            elif a<b:
                parent[b]=a
                return True
                
         parent = [i for i in range(n)] # 처음은 각자 자신이 최상단 노드

    예시

    • [백준 16562] 친구비
      • 모든 친구관계 중 최소비용을 기록하면서, 친구관계를 집합으로 묶어줍니다. 친구 집합들에 대해 최소비용의 합을 구해주면 됩니다.
      • 중요한 것은 마지막에 전체 노드에 대해서 find() 연산을 한번 진행해줘야 합니다. 나중에 연결된 노드에 대해서는 이전에 연결된 집합 원소들에 대해서는 부모노드가 업데이트되지 않았을 수도 있기 때문입니다.
      • 코드
        • import sys
          
          input = sys.stdin.readline
          
          n, m, k = map(int, input().split())
          
          pay_arr = list(map(int, input().split()))
          
          
          def find(x):
              if parent[x] == x:
                  return x
              parent[x] = find(parent[x])
              return parent[x]
          
          
          def union(a, b):
              a = find(a)
              b = find(b)
              if a > b:
                  parent[a] = b
          
              else:
                  parent[b] = a
          
              pay_arr[a], pay_arr[b] = min(pay_arr[a], pay_arr[b]), min(pay_arr[a], pay_arr[b])
          
          
          parent = [i for i in range(n)]
          
          for _ in range(m):
              v, w = map(int, input().split())
              union(v - 1, w - 1)
          
          for i in range(n):
              find(i)
          
          parent_set = set(parent)
          
          sum_value = sum([pay_arr[idx] for idx in parent_set])
          if sum_value <= k:
              print(sum_value)
          else:
              print("Oh no")