코딩 테스트 (Python)/문법

[문법] 최소 신장 트리 개념 및 구현(Python)

hihyuk 2024. 1. 13. 10:32

용어 정리

정점:  node

간선 : edge

  • 두 node 를 잇는 요소
  • 방향과 가중치를 가질 수 있음

그래프:  정점과 간선의 집합

서브 그래프: 기존 그래프의 정점 일부와, 간선 일부로 구성된 그래프

 

최소 신장 트리(Minimum Spanning Tree, MST)란?

  • 최소 스패닝 트리 라고도 많이 불린다.
  • 주어진 그래프의 모든 정점을 지나면서, 간선의 가중치 합이 최소가 되는 트리
    • 그래프의 모든 정점을 포함하고
    • 그래프의 일부 간선을 포함하며
    • 사이클이 없는 (트리의 특징)
    • 위 특징을 가진 서브 그래프 중 간선의 가중치 합이 최소인 트리
  • 원본 그래프 간선에 가중치가 부여되어 있음
  • 간선의 방향 여부에 따라 구하는 방법이 달라짐
    • 코딩테스트에서는 간선의 방향이 없는, 무방향 그래프에 대해서만 나옴
  • 정점이 n 개인 그래프의 최소 신장 트리는, n - 1 개의 간선을 가져야 함(트리의 특징)

 

프림 (Prim’s Algorithm)

작동 순서

  1. 임의의 정점을 시작으로 선택한다.
  2. 해당 정점과 연결된 간선을 탐색 대상에 추가한다.
  3. 탐색 대상에서 가장 가중치가 낮은 간선을 뽑는다. 간선의 끝 점이 아직 방문하지 않았다면, 그 간선을 추가한다. 방문했다면, 방문하지 않은 간선이 나올 때까지 간선을 뽑는다.
  4. 새로 추가된 정점과 연결된 간선을 탐색 대상에 추가한다.
  5. 3 ~ 4번 과정을 (전체 정점 - 1) 개의 간선이 트리에 추가될 때까지 반복한다.
def prim(n, edges) -> int:
    """
    n: 정점의 개수
    edges: (정점1, 정점2, 가중치)의 리스트

    최소 스패닝 트리의 가중치를 반환
    """
    import heapq

    graph = [[] for _ in range(n + 1)]
    for idx, adj, cost in edges:
        graph[idx].append((cost, adj))
        graph[adj].append((cost, idx))

    # 임의의 정점을 시작으로 선택한다.
    visited = [False] * (n + 1)
    visited[1] = True
    heap = []
    for cost, adj in graph[1]:
        heapq.heappush(heap, (cost, adj))

    result = 0
    used_edges = 0
    while used_edges < n - 1:
        cost, idx = heapq.heappop(heap) # 가중치 낮은 간선을 선택한다.
        if visited[idx]: # 이미 방문한 정점이라면 패스
            continue
        visited[idx] = True
        result += cost
        used_edges += 1
        for adj_cost, adj in graph[idx]: # 선택한 정점과 연결된 간선들을 우선순위 큐에 넣는다.
            if not visited[adj]:
                heapq.heappush(heap, (adj_cost, adj))

    return result

 

크루스칼 (Kruskal’s Algorithm)

작동 순서

  1. 모든 간선들을 가중치에 대해 오름차순으로 정렬한다.
  2. 앞에서부터, 간선을 트리에 추가했을 때, 사이클이 생기지 않는다면 그 간선을 트리에 추가한다.
    • 사이클이 생긴다면 그대로 넘어간다 (간선을 트리에 추가하지 않는다)
  3. 2번 과정을 (전체 정점 - 1) 개의 간선이 트리에 추가될 때까지 반복한다.

매 간선을 추가할 때마다 그래프 탐색을 하는 건 비효율적이라, 새로운 알고리즘을 사용해서 사이클을 찾는다.

 

분리 집합 (Union-Find / Disjoint-Set)

그래프가 유동적으로 변할 때(간선이 추가/삭제 될 때), 점들 간의 연결 여부를 빠르게 확인할 수 있어, 크루스칼 알고리즘 구현 외에도 많이 사용된다.

부모 정점을 찾는 find 함수와,

서로 다른 부모 정점을 가지는 집합을 합치는 Union 함수로 이루어져 있습니다.

class DisjointSet:

    def __init__(self, n):
        self.parent = list(range(n + 1))

    def find(self, x):
        if self.parent[x] == x:
            return x
        self.parent[x] = self.find(self.parent[x])
        return self.parent[x]
    
    def union(self, x, y):
        root_x = self.find(x)
        root_y = self.find(y)
        if root_x == root_y:
            return
        self.parent[root_x] = root_y


disjoint_set = DisjointSet(10)
edges = [(1, 3), (2, 5), (3, 5), (4, 6), (7, 10)]
for idx, adj in edges:
    disjoint_set.union(idx, adj)
for i in range(1, 11):
    print(f"{i}의 부모: {disjoint_set.find(i)}")

 

구현

class DisjointSet:

    def __init__(self, n):
        self.parent = list(range(n + 1))

    def find(self, x):
        if self.parent[x] == x:
            return x
        self.parent[x] = self.find(self.parent[x])
        return self.parent[x]
    
    def union(self, x, y):
        root_x = self.find(x)
        root_y = self.find(y)
        if root_x == root_y:
            return
        self.parent[root_x] = root_y


def kruskal(n, edges) -> int:
    """
    n: 정점의 개수
    edges: (정점1, 정점2, 가중치)의 리스트
    """

    # 간선을 가중치 순으로 정렬
    edges.sort(key=lambda x: x[2])
    disjoint_set = DisjointSet(n)
    result = 0
    used_edges = 0

    # 가중치가 낮은 간선부터 선택
    for idx, adj, cost in edges:
        # 각 노드의 부모 노드 탐색
        # 사이클이 생기지 않는다면 간선을 선택
        # 부모 노드가 같다 = 이 간선을 선택하면 사이클이 생긴다!
        if disjoint_set.find(idx) != disjoint_set.find(adj):
            disjoint_set.union(idx, adj)
            result += cost
            used_edges += 1
            # 간선의 개수가 n - 1개가 되면 탐색 종료!
            if used_edges == n - 1:
                break

    return result