CS/Python

파이썬으로 이진 탐색 트리 구현하기

munsik22 2025. 3. 27. 16:03

개요

이진 탐색 트리(Binary Search Tree, BST)는 모든 노드가 다음 조건을 만족해야 한다.

  • 왼쪽 서브트리 노드의 key값은 자신의 노드 key값보다 작아야 한다.
  • 오른쪽 서브트리 노드의 key값은 자신의 노드 key값보다 커야 한다.

 

위의 BST를 중위 순회로 스캔하면 다음과 같이 노드의 key값을 오름차순으로 얻을 수 있다.

1 → 4 → 5 → 6 → 7 → 9 → 11 → 12 → 13 → 14 → 15 → 18

BST의 특징

  • 구조가 단순하다.
  • 중위 순회의 DFS를 통해 노드값을 오름차순으로 얻을 수 있다.
  • 이진 탐색과 비슷한 방식으로 매우 빠른 검색이 가능하다.
  • 노드를 삽입하기가 쉽다.

코드 구현

코드 전문은 Do it! 자료구조와 함께 배우는 알고리즘 입문 파이썬 편에서 확인할 수 있다.

class Node:
    def __init__(self, key, value, left, right):
        self.key = key
        self.value = value
        self.left = left
        self.right = right

class BinarySearchTree:
    def __init__(self):
        self.root = None

    def search(self, key):
        p = self.root
        while True:
            if p is None:
                return None
            if key == p.key:
                return p.value
            elif key < p.key:
                p = p.left
            else:
                p = p.right

    def add(self, key, value):
        def add_node(node, key, value):
            if key == node.key:
                return False
            elif key < node.key:
                if node.left is None:
                    node.left = Node(key, value, None, None)
                else:
                    add_node(node.left, key, value)
            else:
                if node.right is None:
                    node.right = Node(key, value, None, None)
                else:
                    add_node(node.right, key, value)
            return True

        if self.root is None:
            self.root = Node(key, value, None, None)
            return True
        else:
            return add_node(self.root, key, value)

    def remove(self, key):
        p = self.root
        parent = None
        is_left_child = True

        while True:
            if p is None:
                return False

            if key == p.key:
                break
            else:
                parent = p
                if key < p.key:
                    is_left_child = True
                    p = p.left
                else:
                    is_left_child = False
                    p = p.right

        if p.left is None:
            if p is self.root:
                self.root = p.right
            elif is_left_child:
                parent.left = p.right
            else:
                parent.right = p.right
        elif p.right is None:
            if p is self.root:
                self.root = p.left
            elif is_left_child:
                parent.left = p.left
            else:
                parent.right = p.left
        else:
            parent = p
            left = p.left
            is_left_child = True
            while left.right is not None:
                parent = left
                left = left.right
                is_left_child = False

            p.key = left.key
            p.value = left.value
            if is_left_child:
                parent.left = left.left
            else:
                parent.right = left.left

        return True

    def dump(self, reverse = False):
        def print_subtree(node):
            if node is not None:
                print_subtree(node.left)
                print(f"{node.key} {node.value}")
                print_subtree(node.right)

        def print_subtree_rev(node):
            if node is not None:
                print_subtree(node.right)
                print(f"{node.key} {node.value}")
                print_subtree(node.left)

        if reverse:
            print_subtree_rev(self.root)
        else:
            print_subtree(self.root)

    def min_key(self):
        if self.root is None:
            return None
        p = self.root
        while p.left is not None:
            p = p.left
        return p.key

    def max_key(self):
        if self.root is None:
            return None
        p = self.root
        while p.right is not None:
            p = p.right
        return p.key