Union-Find (Disjoint Set Union)

Khái niệm

Union-Find (hay DSU — Disjoint Set Union) là cấu trúc dữ liệu quản lý tập hợp các phần tử được phân vào các nhóm rời nhau (không giao nhau). Hỗ trợ hai thao tác chính:

  • find(x): Tìm đại diện (root) của nhóm chứa phần tử x
  • union(x, y): Hợp nhất nhóm chứa x với nhóm chứa y

Ứng dụng phổ biến:

  • Kiểm tra hai đỉnh có cùng thành phần liên thông không
  • Thuật toán Kruskal’s MST (thêm cạnh mà không tạo chu trình)
  • Đếm số thành phần liên thông
  • Bài toán offline dynamic connectivity
graph TD; A0[Bước đầu:\n5 đỉnh riêng lẻ\n0 1 2 3 4]; A1[union 0,1:\n0←1 2 3 4]; A2[union 2,3:\n0←1 2←3 4]; A3[union 1,3:\n0←1←3←2 4]; A0 --> A1 --> A2 --> A3;

Cài đặt cơ bản với tối ưu hoá

Hai tối ưu quan trọng giúp đạt độ phức tạp gần như O(1) mỗi thao tác:

  1. Nén đường (Path Compression): Khi find(x), trỏ thẳng x về root
  2. Union theo rank: Gắn cây thấp hơn vào cây cao hơn
class DSU:
    def __init__(self, n):
        """Khởi tạo n phần tử, mỗi phần tử là một nhóm riêng."""
        self.parent = list(range(n))   # parent[i] = i ban đầu
        self.rank   = [0] * n          # rank ~ chiều cao cây

    def find(self, x):
        """Tìm root của nhóm chứa x, với nén đường."""
        if self.parent[x] != x:
            self.parent[x] = self.find(self.parent[x])   # Nén đường
        return self.parent[x]

    def union(self, x, y):
        """Hợp nhất nhóm x và nhóm y. Trả về False nếu đã cùng nhóm."""
        rx, ry = self.find(x), self.find(y)
        if rx == ry:
            return False          # Đã cùng nhóm → sẽ tạo chu trình
        # Gắn cây thấp hơn vào cây cao hơn
        if self.rank[rx] < self.rank[ry]:
            rx, ry = ry, rx
        self.parent[ry] = rx
        if self.rank[rx] == self.rank[ry]:
            self.rank[rx] += 1
        return True

    def connected(self, x, y):
        """Kiểm tra x và y có cùng nhóm không."""
        return self.find(x) == self.find(y)

    def count_components(self):
        """Đếm số nhóm (thành phần liên thông)."""
        return sum(1 for i in range(len(self.parent)) if self.find(i) == i)


# Demo với 5 đỉnh
dsu = DSU(5)
print("Ban đầu:", dsu.count_components(), "nhóm")  # 5

dsu.union(0, 1)
dsu.union(2, 3)
print("Sau 2 union:", dsu.count_components(), "nhóm")  # 3

dsu.union(1, 3)
print("Sau 3 union:", dsu.count_components(), "nhóm")  # 2

print("0 và 2 cùng nhóm?", dsu.connected(0, 2))  # True
print("0 và 4 cùng nhóm?", dsu.connected(0, 4))  # False

Ví dụ: Đếm thành phần liên thông

def dem_thanh_phan_lien_thong(n, canh):
    """
    Đếm số thành phần liên thông trong đồ thị vô hướng.
    n: số đỉnh, canh: [(u, v)]
    """
    dsu = DSU(n)
    for u, v in canh:
        dsu.union(u, v)
    return dsu.count_components()

# Đồ thị với 7 đỉnh, 3 thành phần liên thông
canh = [(0,1),(1,2),(3,4),(5,6)]
print(dem_thanh_phan_lien_thong(7, canh))   # 3
# {0,1,2}, {3,4}, {5,6} — nhóm 1, và đỉnh 7 chưa có cạnh... ở đây n=7 → 3 thành phần

Ví dụ: Phát hiện chu trình trong đồ thị vô hướng

def co_chu_trinh(n, canh):
    """
    Kiểm tra đồ thị vô hướng có chu trình không.
    Nếu thêm cạnh (u,v) mà u và v đã cùng nhóm → chu trình.
    """
    dsu = DSU(n)
    for u, v in canh:
        if not dsu.union(u, v):   # union trả False → đã cùng nhóm
            return True
    return False

print(co_chu_trinh(4, [(0,1),(1,2),(2,3)]))          # False (cây)
print(co_chu_trinh(4, [(0,1),(1,2),(2,3),(3,0)]))    # True (chu trình 0-1-2-3-0)
print(co_chu_trinh(4, [(0,1),(1,2),(0,2)]))          # True (chu trình 0-1-2-0)

Bài tập luyện tập

Dễ — Số đảo trong lưới

def dem_dao(luoi):
    """
    Đếm số đảo (thành phần liên thông) trong lưới 0/1.
    '1' là đất, '0' là nước. Kết nối 4 hướng (trên/dưới/trái/phải).
    """
    if not luoi:
        return 0
    rows, cols = len(luoi), len(luoi[0])
    dsu = DSU(rows * cols)

    def idx(r, c):
        return r * cols + c

    for r in range(rows):
        for c in range(cols):
            if luoi[r][c] == '1':
                for dr, dc in [(0,1),(1,0)]:   # Chỉ cần kiểm tra phải và dưới
                    nr, nc = r + dr, c + dc
                    if 0 <= nr < rows and 0 <= nc < cols and luoi[nr][nc] == '1':
                        dsu.union(idx(r,c), idx(nr,nc))

    # Đếm root của các ô đất '1'
    return sum(
        1 for r in range(rows) for c in range(cols)
        if luoi[r][c] == '1' and dsu.find(idx(r,c)) == idx(r,c)
    )

luoi = [
    ["1","1","0","0","0"],
    ["1","1","0","0","0"],
    ["0","0","1","0","0"],
    ["0","0","0","1","1"]
]
print(dem_dao(luoi))   # 3

Trung bình — Thêm cạnh và trả lời truy vấn liên thông

def xu_ly_tran_van(n, lenh):
    """
    Xử lý hai loại lệnh:
    - ("union", u, v): thêm cạnh u-v
    - ("query", u, v): in "YES" nếu u và v liên thông, "NO" nếu không
    """
    dsu = DSU(n)
    ket_qua = []
    for loai, u, v in lenh:
        if loai == "union":
            dsu.union(u, v)
        else:
            ket_qua.append("YES" if dsu.connected(u, v) else "NO")
    return ket_qua

lenh = [
    ("union", 0, 1),
    ("union", 2, 3),
    ("query", 0, 1),   # YES
    ("query", 0, 2),   # NO
    ("union", 1, 2),
    ("query", 0, 3),   # YES
]
print(xu_ly_tran_van(5, lenh))   # ['YES', 'NO', 'YES']

Khó — Tìm cạnh dư thừa (Redundant Connection)

def tim_canh_du_thua(n, canh):
    """
    Cho đồ thị vô hướng n đỉnh được xây dựng bằng cách thêm từng cạnh.
    Khi nào thêm cạnh mà tạo ra chu trình → đó là cạnh dư thừa.
    Trả về cạnh dư thừa cuối cùng.
    """
    dsu = DSU(n + 1)   # Đỉnh 1-indexed
    for u, v in canh:
        if not dsu.union(u, v):
            return (u, v)
    return None

# Ví dụ: thêm 4 cạnh vào đồ thị 4 đỉnh
canh = [(1,2),(1,3),(2,3),(3,4)]
print(tim_canh_du_thua(4, canh))   # (2, 3) — cạnh này tạo chu trình 1-2-3-1

Kết luận

Union-Find với path compression và union theo rank đạt độ phức tạp O(α(n)) mỗi thao tác — gần như O(1) trong thực tế. Đây là cấu trúc dữ liệu không thể thiếu khi cài đặt Kruskal’s MST. Tiếp theo hãy học Dijkstra — thuật toán tìm đường ngắn nhất phổ biến nhất.

Bình luận