From dc8cdae225858ff288e3648e3ad554d6daf587fa Mon Sep 17 00:00:00 2001 From: Joris van Rantwijk Date: Sat, 20 Jul 2024 15:13:36 +0200 Subject: [PATCH] Minor simplification in UnionFindQueue --- python/mwmatching/datastruct.py | 35 +++++++++++++++------------------ 1 file changed, 16 insertions(+), 19 deletions(-) diff --git a/python/mwmatching/datastruct.py b/python/mwmatching/datastruct.py index c887669..00066f5 100644 --- a/python/mwmatching/datastruct.py +++ b/python/mwmatching/datastruct.py @@ -31,7 +31,7 @@ class UnionFindQueue(Generic[_NameT, _ElemT]): tracking added to it. """ - __slots__ = ("name", "tree", "sub_queues", "split_nodes") + __slots__ = ("name", "tree", "first_node", "sub_queues") class Node(Generic[_NameT2, _ElemT2]): """Node in a UnionFindQueue.""" @@ -101,8 +101,8 @@ class UnionFindQueue(Generic[_NameT, _ElemT]): """ self.name = name self.tree: Optional[UnionFindQueue.Node[_NameT, _ElemT]] = None + self.first_node: Optional[UnionFindQueue.Node[_NameT, _ElemT]] = None self.sub_queues: list[UnionFindQueue[_NameT, _ElemT]] = [] - self.split_nodes: list[UnionFindQueue.Node[_NameT, _ElemT]] = [] def clear(self) -> None: """Remove all elements from the queue. @@ -111,8 +111,8 @@ class UnionFindQueue(Generic[_NameT, _ElemT]): """ node = self.tree self.tree = None + self.first_node = None self.sub_queues = [] - self.split_nodes.clear() # Wipe pointers to enable refcounted garbage collection. if node is not None: @@ -143,8 +143,10 @@ class UnionFindQueue(Generic[_NameT, _ElemT]): prio: Initial priority of the new element. """ assert self.tree is None - self.tree = UnionFindQueue.Node(self, elem, prio) - return self.tree + node = UnionFindQueue.Node(self, elem, prio) + self.tree = node + self.first_node = node + return node def min_prio(self) -> float: """Return the minimum priority of any element in the queue. @@ -184,7 +186,6 @@ class UnionFindQueue(Generic[_NameT, _ElemT]): """ assert self.tree is None assert not self.sub_queues - assert not self.split_nodes assert sub_queues # Keep the list of sub-queues. @@ -193,6 +194,7 @@ class UnionFindQueue(Generic[_NameT, _ElemT]): # Move the root node from the first sub-queue to this queue. # Clear its owner pointer. self.tree = sub_queues[0].tree + self.first_node = sub_queues[0].first_node assert self.tree is not None sub_queues[0].tree = None self.tree.owner = None @@ -208,10 +210,7 @@ class UnionFindQueue(Generic[_NameT, _ElemT]): subtree.owner = None # Merge our current tree with the tree from the sub-queue. - (self.tree, split_node) = self._merge_tree(self.tree, subtree) - - # Keep track of the left-most node from the sub-queue. - self.split_nodes.append(split_node) + self.tree = self._merge_tree(self.tree, subtree) # Put the owner pointer in the root node. self.tree.owner = self @@ -234,10 +233,10 @@ class UnionFindQueue(Generic[_NameT, _ElemT]): self.tree.owner = None # Split the tree to reconstruct each sub-queue. - for (sub, split_node) in zip(self.sub_queues[:0:-1], - self.split_nodes[::-1]): + for sub in self.sub_queues[:0:-1]: - (tree, rtree) = self._split_tree(split_node) + assert sub.first_node is not None + (tree, rtree) = self._split_tree(sub.first_node) # Assign the right tree to the sub-queue. sub.tree = rtree @@ -249,8 +248,8 @@ class UnionFindQueue(Generic[_NameT, _ElemT]): # Make this queue empty. self.tree = None + self.first_node = None self.sub_queues = [] - self.split_nodes.clear() @staticmethod def _repair_node(node: Node[_NameT, _ElemT]) -> None: @@ -602,10 +601,10 @@ class UnionFindQueue(Generic[_NameT, _ElemT]): def _merge_tree(self, ltree: Node[_NameT, _ElemT], rtree: Node[_NameT, _ElemT] - ) -> tuple[Node[_NameT, _ElemT], Node[_NameT, _ElemT]]: + ) -> Node[_NameT, _ElemT]: """Merge two trees. - Return a tuple (split_node, merged_tree). + Return the root node of the merged tree. """ # Find the left-most node of the right tree. @@ -625,9 +624,7 @@ class UnionFindQueue(Generic[_NameT, _ElemT]): rtree_new = self._rebalance_up(parent) # Join the two trees via the split_node. - merged_tree = self._join(ltree, split_node, rtree_new) - - return (merged_tree, split_node) + return self._join(ltree, split_node, rtree_new) def _split_tree(self, split_node: Node[_NameT, _ElemT]