[백준] (1197) 최소 스패닝 트리 [Python]
🔗 문제 링크
https://www.acmicpc.net/problem/1197
문제
그래프가 주어졌을 때, 그 그래프의 최소 스패닝 트리를 구하는 프로그램을 작성하시오.
최소 스패닝 트리는, 주어진 그래프의 모든 정점들을 연결하는 부분 그래프 중에서 그 가중치의 합이 최소인 트리를 말한다.
입력
첫째 줄에 정점의 개수 V(1 ≤ V ≤ 10,000)와 간선의 개수 E(1 ≤ E ≤ 100,000)가 주어진다. 다음 E개의 줄에는 각 간선에 대한 정보를 나타내는 세 정수 A, B, C가 주어진다. 이는 A번 정점과 B번 정점이 가중치 C인 간선으로 연결되어 있다는 의미이다. C는 음수일 수도 있으며, 절댓값이 1,000,000을 넘지 않는다.
그래프의 정점은 1번부터 V번까지 번호가 매겨져 있고, 임의의 두 정점 사이에 경로가 있다. 최소 스패닝 트리의 가중치가 -2,147,483,648보다 크거나 같고, 2,147,483,647보다 작거나 같은 데이터만 입력으로 주어진다.
출력
첫째 줄에 최소 스패닝 트리의 가중치를 출력한다.
테스트 케이스
| 예제 입력 | 예제 출력 |
| 3 3 1 2 1 2 3 2 1 3 3 |
3 |
💻 나의 코드
1st Code:
import sys
input = sys.stdin.readline
sys.setrecursionlimit(10**6)
v, e = map(int, input().split())
adj = [[float("inf") for _ in range(v+1)] for _ in range(v+1)]
for _ in range(e):
a, b, c = map(int, input().split())
adj[a][b] = c
adj[b][a] = c
visited = [False] * (v+1)
answer = 0
def mts(i):
global answer
if visited[i]:
return
visited[i] = True
min_val = float("inf")
idx = 0
for j in range(1, v+1):
if adj[i][j] < min_val and not visited[j]:
min_val = adj[i][j]
idx = j
if idx > 0:
answer += adj[i][idx]
mts(idx)
for i in range(1, v+1):
if not visited[i]:
mts(i)
print(answer)
🔎 코드의 문제점
- 정확한 MST 알고리즘이 아님
- 이 코드는 단순히 현재 노드에서 가장 가까운 방문 안 한 노드로 이동하는 그리디 방식이다.
- 그러나 Prim 알고리즘이나 Kruskal 알고리즘은 사이클을 방지하고, 전체 그래프의 간선 정보를 전역적으로 고려해야 한다.
- 지금 코드는 로컬 최솟값만 보고 이동해서 사이클을 만들거나, 최적해를 보장하지 못한다.
- MST의 핵심 조건
- 최소 간선들만으로 모든 노드를 연결해야 하고, 사이클이 없어야 한다.
- 이 코드는 visited를 통해 이미 방문한 노드는 다시 방문하지 않지만, 간선 선택 과정 자체에 사이클 검증이 없다.
- 입력 방식
- adj를 인접 행렬로 만들었는데, 이 과정에서 불필요하게 메모리를 많이 사용하고 있다.
- 정점 개수가 많으면 메모리 초과 가능성도 있다.
실제로 코드 제출 결과 메모리 초과가 발생했다.😵
2nd Code:
코드의 문제점을 참고하여, MST 탐색 방식을 Greedy에서 Prim 알고리즘으로 변경했다.
Prim 알고리즘의 핵심은 임의의 시작점에서 시작 → 가장 비용이 적은 간선으로 확장 → 모든 정점을 연결하는 것이다.
import sys
import heapq
input = sys.stdin.readline
v, e = map(int, input().split())
adj = [[] for _ in range(v+1)]
for _ in range(e):
a, b, c = map(int, input().split())
adj[a].append((c, b))
adj[b].append((c, a))
# Prim 알고리즘
visited = [False] * (v+1)
min_heap = [(0, 1)] # (가중치, 정점)
answer = 0
while min_heap:
c, i = heapq.heappop(min_heap)
if visited[i]:
continue
visited[i] = True
answer += c
for nc, j in adj[i]:
if not visited[j]:
heapq.heappush(min_heap, (nc, j))
print(answer)
MST를 찾기 위해 Prim 알고리즘은 최소 힙을 사용한다.
- min_heap에는 (가중치, 정점) 튜플이 들어간다.
- heappop을 했을 때 나온 정점 i의 visited가 False일 때, 그 가중치 c를 answer에 더해준다.
- i와 인접한 노드들 중 아직 방문하지 않은 노드들을 min_heap에 heappush한다.
Another Code:
Kruskal 알고리즘으로도 구현이 가능하다.
Kruskal 알고리즘의 핵심은 간선을 가중치 순으로 정렬 → 가장 짧은 간선부터 사이클 여부를 확인 → 사이클이 아니면 연결이다.
import sys
input = sys.stdin.readline
sys.setrecursionlimit(10**6)
# 유니온 파인드 (Disjoint Set) 함수
def find(x):
if parent[x] != x:
parent[x] = find(parent[x])
return parent[x]
def union(x, y):
x_root = find(x)
y_root = find(y)
if x_root != y_root:
parent[y_root] = x_root
return True
return False
v, e = map(int, input().split())
edges = []
for _ in range(e):
a, b, c = map(int, input().split())
edges.append((c, a, b))
edges.sort()
parent = [i for i in range(v+1)]
answer = 0
for c, a, b in edges:
if union(a, b):
answer += c
print(answer)
- edges 리스트를 가중치 기준으로 정렬한다.
- 작은 간선부터 하나씩 선택해서 사이클이 발생하지 않으면 연결한다.
- Union-Find로 사이클 여부를 판별한다.
- 최종적으로 MST의 가중치 합이 answer에 저장된다.
유니온 파인드 알고리즘이란?
유니온-파인드(Union-Find) 알고리즘은 상호 배타적 집합, Disjoin-set(서로소 집합) 이라고도 부른다. 여러 노드가 존재할 때 어떤 두 개의 노드를 같은 집합으로 묶어 주고, 어떤 두 노드가 같은 집합에 있는지 확인하는 알고리즘이다.
- 대표적인 그래프 알고리즘으로, 두 노드가 같은 집합에 속하는지 판별하는 알고리즘이다.
- 합집합 찾기 알고리즘이라고도 부르며, 반대로 서로 연결되어 있지 않은 노드를 판별할 수도 있기 때문에 서로소 집합 (Disjoint-set)이라고도 부른다.
- 노드를 합치는 Union 연산과 노드의 루트 노드를 찾는 Find 연산으로 이루어진다.
- Union : 서로 다른 두 개의 집합을 하나의 집합으로 병합하는 연산을 말한다. 일반적으로 합집한 연산과 같다.
- Find : 하나읜 원소가 어떤 집합에 속해있는지를 판단한다.
[출처]
- [알고리즘] Union-Find Algorithm (유니온 파인드 알고리즘)