1
0
Fork 0

Minor simplification in UnionFindQueue

This commit is contained in:
Joris van Rantwijk 2024-07-20 15:13:36 +02:00
parent e2f5b63a01
commit dc8cdae225
1 changed files with 16 additions and 19 deletions

View File

@ -31,7 +31,7 @@ class UnionFindQueue(Generic[_NameT, _ElemT]):
tracking added to it. tracking added to it.
""" """
__slots__ = ("name", "tree", "sub_queues", "split_nodes") __slots__ = ("name", "tree", "first_node", "sub_queues")
class Node(Generic[_NameT2, _ElemT2]): class Node(Generic[_NameT2, _ElemT2]):
"""Node in a UnionFindQueue.""" """Node in a UnionFindQueue."""
@ -101,8 +101,8 @@ class UnionFindQueue(Generic[_NameT, _ElemT]):
""" """
self.name = name self.name = name
self.tree: Optional[UnionFindQueue.Node[_NameT, _ElemT]] = None self.tree: Optional[UnionFindQueue.Node[_NameT, _ElemT]] = None
self.first_node: Optional[UnionFindQueue.Node[_NameT, _ElemT]] = None
self.sub_queues: list[UnionFindQueue[_NameT, _ElemT]] = [] self.sub_queues: list[UnionFindQueue[_NameT, _ElemT]] = []
self.split_nodes: list[UnionFindQueue.Node[_NameT, _ElemT]] = []
def clear(self) -> None: def clear(self) -> None:
"""Remove all elements from the queue. """Remove all elements from the queue.
@ -111,8 +111,8 @@ class UnionFindQueue(Generic[_NameT, _ElemT]):
""" """
node = self.tree node = self.tree
self.tree = None self.tree = None
self.first_node = None
self.sub_queues = [] self.sub_queues = []
self.split_nodes.clear()
# Wipe pointers to enable refcounted garbage collection. # Wipe pointers to enable refcounted garbage collection.
if node is not None: if node is not None:
@ -143,8 +143,10 @@ class UnionFindQueue(Generic[_NameT, _ElemT]):
prio: Initial priority of the new element. prio: Initial priority of the new element.
""" """
assert self.tree is None assert self.tree is None
self.tree = UnionFindQueue.Node(self, elem, prio) node = UnionFindQueue.Node(self, elem, prio)
return self.tree self.tree = node
self.first_node = node
return node
def min_prio(self) -> float: def min_prio(self) -> float:
"""Return the minimum priority of any element in the queue. """Return the minimum priority of any element in the queue.
@ -184,7 +186,6 @@ class UnionFindQueue(Generic[_NameT, _ElemT]):
""" """
assert self.tree is None assert self.tree is None
assert not self.sub_queues assert not self.sub_queues
assert not self.split_nodes
assert sub_queues assert sub_queues
# Keep the list of sub-queues. # Keep the list of sub-queues.
@ -193,6 +194,7 @@ class UnionFindQueue(Generic[_NameT, _ElemT]):
# Move the root node from the first sub-queue to this queue. # Move the root node from the first sub-queue to this queue.
# Clear its owner pointer. # Clear its owner pointer.
self.tree = sub_queues[0].tree self.tree = sub_queues[0].tree
self.first_node = sub_queues[0].first_node
assert self.tree is not None assert self.tree is not None
sub_queues[0].tree = None sub_queues[0].tree = None
self.tree.owner = None self.tree.owner = None
@ -208,10 +210,7 @@ class UnionFindQueue(Generic[_NameT, _ElemT]):
subtree.owner = None subtree.owner = None
# Merge our current tree with the tree from the sub-queue. # Merge our current tree with the tree from the sub-queue.
(self.tree, split_node) = self._merge_tree(self.tree, subtree) self.tree = self._merge_tree(self.tree, subtree)
# Keep track of the left-most node from the sub-queue.
self.split_nodes.append(split_node)
# Put the owner pointer in the root node. # Put the owner pointer in the root node.
self.tree.owner = self self.tree.owner = self
@ -234,10 +233,10 @@ class UnionFindQueue(Generic[_NameT, _ElemT]):
self.tree.owner = None self.tree.owner = None
# Split the tree to reconstruct each sub-queue. # Split the tree to reconstruct each sub-queue.
for (sub, split_node) in zip(self.sub_queues[:0:-1], for sub in self.sub_queues[:0:-1]:
self.split_nodes[::-1]):
(tree, rtree) = self._split_tree(split_node) assert sub.first_node is not None
(tree, rtree) = self._split_tree(sub.first_node)
# Assign the right tree to the sub-queue. # Assign the right tree to the sub-queue.
sub.tree = rtree sub.tree = rtree
@ -249,8 +248,8 @@ class UnionFindQueue(Generic[_NameT, _ElemT]):
# Make this queue empty. # Make this queue empty.
self.tree = None self.tree = None
self.first_node = None
self.sub_queues = [] self.sub_queues = []
self.split_nodes.clear()
@staticmethod @staticmethod
def _repair_node(node: Node[_NameT, _ElemT]) -> None: def _repair_node(node: Node[_NameT, _ElemT]) -> None:
@ -602,10 +601,10 @@ class UnionFindQueue(Generic[_NameT, _ElemT]):
def _merge_tree(self, def _merge_tree(self,
ltree: Node[_NameT, _ElemT], ltree: Node[_NameT, _ElemT],
rtree: Node[_NameT, _ElemT] rtree: Node[_NameT, _ElemT]
) -> tuple[Node[_NameT, _ElemT], Node[_NameT, _ElemT]]: ) -> Node[_NameT, _ElemT]:
"""Merge two trees. """Merge two trees.
Return a tuple (split_node, merged_tree). Return the root node of the merged tree.
""" """
# Find the left-most node of the right tree. # Find the left-most node of the right tree.
@ -625,9 +624,7 @@ class UnionFindQueue(Generic[_NameT, _ElemT]):
rtree_new = self._rebalance_up(parent) rtree_new = self._rebalance_up(parent)
# Join the two trees via the split_node. # Join the two trees via the split_node.
merged_tree = self._join(ltree, split_node, rtree_new) return self._join(ltree, split_node, rtree_new)
return (merged_tree, split_node)
def _split_tree(self, def _split_tree(self,
split_node: Node[_NameT, _ElemT] split_node: Node[_NameT, _ElemT]