Union Find

I have a (possibly huge) bunch of edges describing a graph made of several separate components, and I'd like to know how many components it actually has. This problem has a simple solution: we feed the edges into a union-find data structure, and then just ask it for that piece of information.

Besides the number of components, our structure keeps track of the id associated with each node and the size of each component. Here is the constructor of my Python implementation:
class UnionFind:
    def __init__(self, n):  # 1
        self.count = n
        self.id = list(range(n))  # 2
        self.sz = [1] * n  # 3
1. If the union-find is created for n nodes, the number of components, named count, is initially n itself.
2. All the nodes in a component share the same id. Initially, the id of each node is simply its own index.
3. In the beginning, each node is a component on its own, so every size is initialized to one.

This simple ADT has two operations, union and find, hence its name. The first takes an edge as input and, if its two nodes are in different components, joins them. The second returns the id of the component the passed node belongs to.

Besides, the client code can check the count data member to see how many components there are. In Python, this exposure of internal state is not perceived as horrifying. A more conservative implementation would mediate the access through a getter.
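For the record, the more conservative approach could look like the sketch below. It is not part of my class: the name CountedUnionFind and the private _count field are made up for illustration, and only the constructor is shown.

```python
class CountedUnionFind:
    """Hypothetical variant: same fields as UnionFind, but count is read-only."""

    def __init__(self, n):
        self._count = n           # number of components, private by convention
        self.id = list(range(n))
        self.sz = [1] * n

    @property
    def count(self):
        # Read-only view: there is no setter, so assigning to
        # uf.count raises AttributeError.
        return self._count


uf = CountedUnionFind(5)
print(uf.count)  # 5
```

The client code still reads uf.count as before, but it can no longer overwrite it by accident.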

Moreover, a utility method is provided to check whether two nodes are connected. It is not strictly necessary, but it makes the user code more readable:
def connected(self, p, q):
    return self.find(p) == self.find(q)
The meaning of this method is crystal clear: two nodes are connected if and only if find() returns the same id for both.

In this implementation, we connect two nodes by making them share the same id. So, when we call union() on p and q, we change the id of the representative of one component so that it refers to the representative of the other. Given this approach, we implement find() in this way:
def find(self, p):
    while p != self.id[p]:
        p = self.id[p]
    return p
If the id stored for p differs from p itself, we move to the node that id refers to, and we repeat until we reach a node whose id is still its original value. That value is the id of the component.
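To see the loop at work, here is a small trace. It is a standalone sketch: find() is rewritten as a free function that also records the visited nodes, and the ids list is a hand-built state I made up for illustration, where node 3 refers to 4 and 4 refers to 9.

```python
def find(ids, p):
    """Follow the id links from p until reaching a node that is its own id."""
    path = [p]
    while p != ids[p]:
        p = ids[p]
        path.append(p)
    return p, path


# Hypothetical state: 3 -> 4 -> 9, and 9 still has its original value.
ids = [0, 1, 2, 4, 9, 5, 6, 7, 8, 9]
root, path = find(ids, 3)
print(root, path)  # 9 [3, 4, 9]
```

The number of hops is the depth of the node in its tree, which is why keeping the trees shallow matters.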

We could implement union() by picking at random which of the two ids to keep, but we want to keep the operational costs low. So we always attach the smaller component to the larger one, which keeps the tree representing the nodes in a component shallow and leads to O(log n) complexity for find().
def union(self, p, q):
    i = self.find(p)
    j = self.find(q)
    if i != j:  # 1
        self.count -= 1  # 2
        if self.sz[i] < self.sz[j]:  # 3
            self.id[i] = j
            self.sz[j] += self.sz[i]
        else:
            self.id[j] = i
            self.sz[i] += self.sz[j]
1. If the two nodes are already in the same component, there is nothing more to do.
2. We are joining two components, so their total number in the union-find decreases by one.
3. This is the smart trick that keeps the cost of find() low. We decide which id to keep as representative of the merged component according to the sizes of the two components being merged: the smaller one is attached to the larger one.
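To get a feeling for what the size check buys us, the sketch below compares it with a naive union that always attaches the first root under the second. The helper functions and the sequence of unions are my own assumptions, chosen to show the naive approach at its worst; they are not part of the class.

```python
def find(ids, p):
    while p != ids[p]:
        p = ids[p]
    return p


def depth(ids, p):
    """Number of links between p and the root of its tree."""
    d = 0
    while p != ids[p]:
        p = ids[p]
        d += 1
    return d


def naive_union(ids, p, q):
    i, j = find(ids, p), find(ids, q)
    if i != j:
        ids[i] = j  # always attach the first root under the second


def weighted_union(ids, sz, p, q):
    i, j = find(ids, p), find(ids, q)
    if i == j:
        return
    if sz[i] < sz[j]:  # attach the smaller tree under the larger one
        ids[i] = j
        sz[j] += sz[i]
    else:
        ids[j] = i
        sz[i] += sz[j]


n = 8
naive = list(range(n))
weighted, sz = list(range(n)), [1] * n
for q in range(1, n):  # union(0, 1), union(0, 2), ..., union(0, 7)
    naive_union(naive, 0, q)
    weighted_union(weighted, sz, 0, q)

naive_height = max(depth(naive, p) for p in range(n))
weighted_height = max(depth(weighted, p) for p in range(n))
print(naive_height, weighted_height)  # 7 1
```

With this sequence the naive version degenerates into a chain of n - 1 links, while the weighted one keeps every node one step away from the root.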

As an example, consider this:
uf = UnionFind(10)
uf.union(4, 3)
uf.union(3, 8)
uf.union(6, 5)
uf.union(9, 4)
uf.union(2, 1)
uf.union(5, 0)
uf.union(7, 2)
uf.union(6, 1)
I created a union-find for the nodes in [0..9] and specified eight edges among them, from (4, 3) to (6, 1).
As a result, I expect two components and, for instance, to see that nodes 2 and 6 are connected, whilst 4 and 5 are not.
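These expectations can be checked with a short script. The sketch below reassembles the snippets of this post into a single class and feeds it the same eight edges; the way the check is laid out is only an illustration, not necessarily the test case I wrote.

```python
class UnionFind:
    def __init__(self, n):
        self.count = n            # number of components
        self.id = list(range(n))  # a node that refers to itself is a root
        self.sz = [1] * n         # component size, meaningful at the roots

    def find(self, p):
        while p != self.id[p]:
            p = self.id[p]
        return p

    def connected(self, p, q):
        return self.find(p) == self.find(q)

    def union(self, p, q):
        i, j = self.find(p), self.find(q)
        if i == j:
            return
        self.count -= 1
        if self.sz[i] < self.sz[j]:
            self.id[i] = j
            self.sz[j] += self.sz[i]
        else:
            self.id[j] = i
            self.sz[i] += self.sz[j]


edges = [(4, 3), (3, 8), (6, 5), (9, 4), (2, 1), (5, 0), (7, 2), (6, 1)]
uf = UnionFind(10)
for p, q in edges:
    uf.union(p, q)

print(uf.count)            # 2
print(uf.connected(2, 6))  # True
print(uf.connected(4, 5))  # False
```

The two components are {3, 4, 8, 9} and {0, 1, 2, 5, 6, 7}.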

I based my Python code on the Java implementation provided by Robert Sedgewick and Kevin Wayne in their excellent Algorithms, 4th Edition, following the weighted quick-union variant. Check it out for a better description of the algorithm, too.

I pushed the full code for the Python class to GitHub, together with a test case for the example described above.
