""" Union-find data structure. """ from networkx.utils import groups class UnionFind: """Union-find data structure. Each unionFind instance X maintains a family of disjoint sets of hashable objects, supporting the following two methods: - X[item] returns a name for the set containing the given item. Each set is named by an arbitrarily-chosen one of its members; as long as the set remains unchanged it will keep the same name. If the item is not yet part of a set in X, a new singleton set is created for it. - X.union(item1, item2, ...) merges the sets containing each item into a single larger set. If any item is not yet part of a set in X, it is added to X as one of the members of the merged set. Union-find data structure. Based on Josiah Carlson's code, http://aspn.activestate.com/ASPN/Cookbook/Python/Recipe/215912 with significant additional changes by D. Eppstein. http://www.ics.uci.edu/~eppstein/PADS/UnionFind.py """ def __init__(self, elements=None): """Create a new empty union-find structure. If *elements* is an iterable, this structure will be initialized with the discrete partition on the given set of elements. """ if elements is None: elements = () self.parents = {} self.weights = {} for x in elements: self.weights[x] = 1 self.parents[x] = x def __getitem__(self, object): """Find and return the name of the set containing the object.""" # check for previously unknown object if object not in self.parents: self.parents[object] = object self.weights[object] = 1 return object # find path of objects leading to the root path = [object] root = self.parents[object] while root != path[-1]: path.append(root) root = self.parents[root] # compress the path and return for ancestor in path: self.parents[ancestor] = root return root def __iter__(self): """Iterate through all items ever found or unioned by this structure. """ return iter(self.parents) def to_sets(self): """Iterates over the sets stored in this structure. For example:: >>> partition = UnionFind("xyz") >>> sorted(map(sorted, partition.to_sets())) [['x'], ['y'], ['z']] >>> partition.union("x", "y") >>> sorted(map(sorted, partition.to_sets())) [['x', 'y'], ['z']] """ # Ensure fully pruned paths for x in self.parents.keys(): _ = self[x] # Evaluated for side-effect only yield from groups(self.parents).values() def union(self, *objects): """Find the sets containing the objects and merge them all.""" # Find the heaviest root according to its weight. roots = iter(sorted({self[x] for x in objects}, key=lambda r: self.weights[r])) try: root = next(roots) except StopIteration: return for r in roots: self.weights[root] += self.weights[r] self.parents[r] = root