PS/백준

[백준] (1197) 최소 스패닝 트리 [Python]

munsik22 2025. 3. 28. 16:21

🔗 문제 링크

https://www.acmicpc.net/problem/1197

문제

그래프가 주어졌을 때, 그 그래프의 최소 스패닝 트리를 구하는 프로그램을 작성하시오.

최소 스패닝 트리는, 주어진 그래프의 모든 정점들을 연결하는 부분 그래프 중에서 그 가중치의 합이 최소인 트리를 말한다.

입력

첫째 줄에 정점의 개수 V(1 ≤ V ≤ 10,000)와 간선의 개수 E(1 ≤ E ≤ 100,000)가 주어진다. 다음 E개의 줄에는 각 간선에 대한 정보를 나타내는 세 정수 A, B, C가 주어진다. 이는 A번 정점과 B번 정점이 가중치 C인 간선으로 연결되어 있다는 의미이다. C는 음수일 수도 있으며, 절댓값이 1,000,000을 넘지 않는다.

그래프의 정점은 1번부터 V번까지 번호가 매겨져 있고, 임의의 두 정점 사이에 경로가 있다. 최소 스패닝 트리의 가중치가 -2,147,483,648보다 크거나 같고, 2,147,483,647보다 작거나 같은 데이터만 입력으로 주어진다.

출력

첫째 줄에 최소 스패닝 트리의 가중치를 출력한다.

테스트 케이스

예제 입력 예제 출력
3 3
1 2 1
2 3 2
1 3 3
3




💻 나의 코드

1st Code:

import sys
input = sys.stdin.readline
sys.setrecursionlimit(10**6)

v, e = map(int, input().split())
adj = [[float("inf") for _ in range(v+1)] for _ in range(v+1)]
for _ in range(e):
    a, b, c = map(int, input().split())
    adj[a][b] = c
    adj[b][a] = c

visited = [False] * (v+1)

answer = 0
def mts(i):
    global answer
    
    if visited[i]:
        return

    visited[i] = True
    
    min_val = float("inf")
    idx = 0
    for j in range(1, v+1):
        if adj[i][j] < min_val and not visited[j]:
            min_val = adj[i][j]
            idx = j
    
    if idx > 0:
        answer += adj[i][idx]
        mts(idx)

for i in range(1, v+1):
    if not visited[i]:
        mts(i)
        
print(answer)

🔎 코드의 문제점

  1. 정확한 MST 알고리즘이 아님
    • 이 코드는 단순히 현재 노드에서 가장 가까운 방문 안 한 노드로 이동하는 그리디 방식이다.
    • 그러나 Prim 알고리즘이나 Kruskal 알고리즘사이클을 방지하고, 전체 그래프의 간선 정보를 전역적으로 고려해야 한다.
    • 지금 코드는 로컬 최솟값만 보고 이동해서 사이클을 만들거나, 최적해를 보장하지 못한다.
  2. MST의 핵심 조건
    • 최소 간선들만으로 모든 노드를 연결해야 하고, 사이클이 없어야 한다.
    • 이 코드는 visited를 통해 이미 방문한 노드는 다시 방문하지 않지만, 간선 선택 과정 자체에 사이클 검증이 없다.
  3. 입력 방식
    • adj를 인접 행렬로 만들었는데, 이 과정에서 불필요하게 메모리를 많이 사용하고 있다.
    • 정점 개수가 많으면 메모리 초과 가능성도 있다.

실제로 코드 제출 결과 메모리 초과가 발생했다.😵

2nd Code:

코드의 문제점을 참고하여, MST 탐색 방식을 Greedy에서 Prim 알고리즘으로 변경했다.

Prim 알고리즘의 핵심은 임의의 시작점에서 시작 → 가장 비용이 적은 간선으로 확장 → 모든 정점을 연결하는 것이다.

import sys
import heapq
input = sys.stdin.readline

v, e = map(int, input().split())
adj = [[] for _ in range(v+1)]

for _ in range(e):
    a, b, c = map(int, input().split())
    adj[a].append((c, b))
    adj[b].append((c, a))

# Prim 알고리즘
visited = [False] * (v+1)
min_heap = [(0, 1)]  # (가중치, 정점)
answer = 0

while min_heap:
    c, i = heapq.heappop(min_heap)
    
    if visited[i]:
        continue
    
    visited[i] = True
    answer += c
    
    for nc, j in adj[i]:
        if not visited[j]:
            heapq.heappush(min_heap, (nc, j))
            
print(answer)

 

MST를 찾기 위해 Prim 알고리즘은 최소 힙을 사용한다.

  • min_heap에는 (가중치, 정점) 튜플이 들어간다.
  • heappop을 했을 때 나온 정점 i의 visited가 False일 때, 그 가중치 c를 answer에 더해준다.
  • i와 인접한 노드들 중 아직 방문하지 않은 노드들을 min_heap에 heappush한다.

Another Code:

Kruskal 알고리즘으로도 구현이 가능하다.

Kruskal 알고리즘의 핵심은 간선을 가중치 순으로 정렬가장 짧은 간선부터 사이클 여부를 확인사이클이 아니면 연결이다.

import sys
input = sys.stdin.readline
sys.setrecursionlimit(10**6)

# 유니온 파인드 (Disjoint Set) 함수
def find(x):
    if parent[x] != x:
        parent[x] = find(parent[x])
    return parent[x]

def union(x, y):
    x_root = find(x)
    y_root = find(y)
    if x_root != y_root:
        parent[y_root] = x_root
        return True
    return False

v, e = map(int, input().split())
edges = []

for _ in range(e):
    a, b, c = map(int, input().split())
    edges.append((c, a, b))

edges.sort()

parent = [i for i in range(v+1)]

answer = 0
for c, a, b in edges:
    if union(a, b):
        answer += c

print(answer)

 

  • edges 리스트를 가중치 기준으로 정렬한다.
  • 작은 간선부터 하나씩 선택해서 사이클이 발생하지 않으면 연결한다.
  • Union-Find로 사이클 여부를 판별한다.
  • 최종적으로 MST의 가중치 합이 answer에 저장된다.

유니온 파인드 알고리즘이란?

유니온-파인드(Union-Find) 알고리즘은 상호 배타적 집합, Disjoin-set(서로소 집합) 이라고도 부른다. 여러 노드가 존재할 때 어떤 두 개의 노드를 같은 집합으로 묶어 주고, 어떤 두 노드가 같은 집합에 있는지 확인하는 알고리즘이다.
  • 대표적인 그래프 알고리즘으로, 두 노드가 같은 집합에 속하는지 판별하는 알고리즘이다.
  • 합집합 찾기 알고리즘이라고도 부르며, 반대로 서로 연결되어 있지 않은 노드를 판별할 수도 있기 때문에 서로소 집합 (Disjoint-set)이라고도 부른다.
  • 노드를 합치는 Union 연산과 노드의 루트 노드를 찾는 Find 연산으로 이루어진다.
    • Union : 서로 다른 두 개의 집합을 하나의 집합으로 병합하는 연산을 말한다. 일반적으로 합집한 연산과 같다.
    • Find : 하나읜 원소가 어떤 집합에 속해있는지를 판단한다.

[출처]

 - [알고리즘] Union-Find Algorithm (유니온 파인드 알고리즘)

 - [알고리즘] 유니온 파인드 (Union-Find)