1
0
Fork 0

Implement ConcatenableQueue as 2-3 tree

This commit is contained in:
Joris van Rantwijk 2024-07-20 21:47:13 +02:00
parent dc8cdae225
commit e8490010d6
3 changed files with 267 additions and 474 deletions

View File

@ -10,7 +10,7 @@ import math
from collections.abc import Sequence from collections.abc import Sequence
from typing import NamedTuple, Optional from typing import NamedTuple, Optional
from .datastruct import UnionFindQueue, PriorityQueue from .datastruct import ConcatenableQueue, PriorityQueue
def maximum_weight_matching( def maximum_weight_matching(
@ -391,9 +391,10 @@ class Blossom:
# all top-level blossoms in the tree. # all top-level blossoms in the tree.
self.tree_blossoms: Optional[set[Blossom]] = None self.tree_blossoms: Optional[set[Blossom]] = None
# Each top-level blossom maintains a union-find datastructure # Each top-level blossom maintains a concatenable queue containing
# containing all vertices in the blossom. # all vertices in the blossom.
self.vertex_set: UnionFindQueue[Blossom, int] = UnionFindQueue(self) self.vertex_set: ConcatenableQueue[Blossom, int]
self.vertex_set = ConcatenableQueue(self)
# If this is a top-level unlabeled blossom with an edge to an # If this is a top-level unlabeled blossom with an edge to an
# S-blossom, "delta2_node" is the corresponding node in the delta2 # S-blossom, "delta2_node" is the corresponding node in the delta2
@ -554,7 +555,7 @@ class MatchingContext:
self.nontrivial_blossom: set[NonTrivialBlossom] = set() self.nontrivial_blossom: set[NonTrivialBlossom] = set()
# "vertex_set_node[x]" represents the vertex "x" inside the # "vertex_set_node[x]" represents the vertex "x" inside the
# union-find datastructure of its top-level blossom. # concatenable queue of its top-level blossom.
# #
# Initially, each vertex belongs to its own trivial top-level blossom. # Initially, each vertex belongs to its own trivial top-level blossom.
self.vertex_set_node = [b.vertex_set.insert(i, math.inf) self.vertex_set_node = [b.vertex_set.insert(i, math.inf)
@ -668,7 +669,7 @@ class MatchingContext:
if not improved: if not improved:
return return
# Update the priority of "y" in its UnionFindQueue. # Update the priority of "y" in its ConcatenableQueue.
self.vertex_set_node[y].set_prio(prio) self.vertex_set_node[y].set_prio(prio)
# If the blossom is unlabeled and the new edge becomes its least-slack # If the blossom is unlabeled and the new edge becomes its least-slack
@ -701,7 +702,7 @@ class MatchingContext:
else: else:
prio = vertex_sedge_queue.find_min().prio prio = vertex_sedge_queue.find_min().prio
# If necessary, update the priority of "y" in its UnionFindQueue. # If necessary, update priority of "y" in its ConcatenableQueue.
if prio > self.vertex_set_node[y].prio: if prio > self.vertex_set_node[y].prio:
self.vertex_set_node[y].set_prio(prio) self.vertex_set_node[y].set_prio(prio)
if by.label == LABEL_NONE: if by.label == LABEL_NONE:
@ -1235,7 +1236,7 @@ class MatchingContext:
sub.tree_blossoms = None sub.tree_blossoms = None
tree_blossoms.remove(sub) tree_blossoms.remove(sub)
# Merge union-find structures. # Merge concatenable queues.
blossom.vertex_set.merge([sub.vertex_set for sub in subblossoms]) blossom.vertex_set.merge([sub.vertex_set for sub in subblossoms])
@staticmethod @staticmethod
@ -1280,7 +1281,7 @@ class MatchingContext:
# Remove blossom from the delta2 queue. # Remove blossom from the delta2 queue.
self.delta2_disable_blossom(blossom) self.delta2_disable_blossom(blossom)
# Split union-find structure. # Split concatenable queue.
blossom.vertex_set.split() blossom.vertex_set.split()
# Prepare to push lazy delta updates down to the sub-blossoms. # Prepare to push lazy delta updates down to the sub-blossoms.

View File

@ -11,11 +11,11 @@ _ElemT = TypeVar("_ElemT")
_ElemT2 = TypeVar("_ElemT2") _ElemT2 = TypeVar("_ElemT2")
class UnionFindQueue(Generic[_NameT, _ElemT]): class ConcatenableQueue(Generic[_NameT, _ElemT]):
"""Combination of disjoint set and priority queue. """Priority queue supporting efficient merge and split operations.
This is a combination of a disjoint set and a priority queue.
A queue has a "name", which can be any Python object. A queue has a "name", which can be any Python object.
Each element has associated "data", which can be any Python object. Each element has associated "data", which can be any Python object.
Each element has a priority. Each element has a priority.
@ -27,68 +27,70 @@ class UnionFindQueue(Generic[_NameT, _ElemT]):
- Merge two or more queues. - Merge two or more queues.
- Undo a previous merge step. - Undo a previous merge step.
The implementation is essentially an AVL tree, with minimum-priority This data structure is implemented as a 2-3 tree with minimum-priority
tracking added to it. tracking added to it.
""" """
__slots__ = ("name", "tree", "first_node", "sub_queues") __slots__ = ("name", "tree", "first_node", "sub_queues")
class Node(Generic[_NameT2, _ElemT2]): class BaseNode(Generic[_NameT2, _ElemT2]):
"""Node in a UnionFindQueue.""" """Node in the 2-3 tree."""
__slots__ = ("owner", "data", "prio", "min_node", "height", __slots__ = ("owner", "min_node", "height", "parent", "childs")
"parent", "left", "right")
def __init__(self, def __init__(self,
owner: UnionFindQueue[_NameT2, _ElemT2], min_node: ConcatenableQueue.Node[_NameT2, _ElemT2],
data: _ElemT2, height: int
prio: float
) -> None: ) -> None:
"""Initialize a new element. """Initialize a new node."""
self.owner: Optional[ConcatenableQueue[_NameT2, _ElemT2]] = None
self.min_node = min_node
self.height = height
self.parent: Optional[ConcatenableQueue.BaseNode[_NameT2,
_ElemT2]]
self.parent = None
self.childs: list[ConcatenableQueue.BaseNode[_NameT2, _ElemT2]]
self.childs = []
class Node(BaseNode[_NameT2, _ElemT2]):
"""Leaf node in the 2-3 tree, representing an element in the queue."""
__slots__ = ("data", "prio")
def __init__(self, data: _ElemT2, prio: float) -> None:
"""Initialize a new leaf node.
This method should not be called directly. This method should not be called directly.
Instead, call UnionFindQueue.insert(). Instead, call ConcatenableQueue.insert().
""" """
self.owner: Optional[UnionFindQueue[_NameT2, _ElemT2]] = owner super().__init__(min_node=self, height=0)
self.data = data self.data = data
self.prio = prio self.prio = prio
self.min_node = self
self.height = 1
self.parent: Optional[UnionFindQueue.Node[_NameT2, _ElemT2]]
self.left: Optional[UnionFindQueue.Node[_NameT2, _ElemT2]]
self.right: Optional[UnionFindQueue.Node[_NameT2, _ElemT2]]
self.parent = None
self.left = None
self.right = None
def find(self) -> _NameT2: def find(self) -> _NameT2:
"""Return the name of the queue that contains this element. """Return the name of the queue that contains this element.
This function takes time O(log(n)). This function takes time O(log(n)).
""" """
node = self node: ConcatenableQueue.BaseNode[_NameT2, _ElemT2] = self
while node.parent is not None: while node.parent is not None:
node = node.parent node = node.parent
assert node.owner is not None assert node.owner is not None
return node.owner.name return node.owner.name
def set_prio(self, prio: float) -> None: def set_prio(self, prio: float) -> None:
"""Change the priority of this element.""" """Change the priority of this element.
This function takes time O(log(n)).
"""
self.prio = prio self.prio = prio
node = self node = self.parent
while True: while node is not None:
min_node = node min_node = node.childs[0].min_node
if node.left is not None: for child in node.childs[1:]:
left_min_node = node.left.min_node if child.min_node.prio < min_node.prio:
if left_min_node.prio < min_node.prio: min_node = child.min_node
min_node = left_min_node
if node.right is not None:
right_min_node = node.right.min_node
if right_min_node.prio < min_node.prio:
min_node = right_min_node
node.min_node = min_node node.min_node = min_node
if node.parent is None:
break
node = node.parent node = node.parent
def __init__(self, name: _NameT) -> None: def __init__(self, name: _NameT) -> None:
@ -100,9 +102,10 @@ class UnionFindQueue(Generic[_NameT, _ElemT]):
name: Name to assign to the new queue. name: Name to assign to the new queue.
""" """
self.name = name self.name = name
self.tree: Optional[UnionFindQueue.Node[_NameT, _ElemT]] = None self.tree: Optional[ConcatenableQueue.BaseNode[_NameT, _ElemT]] = None
self.first_node: Optional[UnionFindQueue.Node[_NameT, _ElemT]] = None self.first_node: Optional[ConcatenableQueue.Node[_NameT, _ElemT]]
self.sub_queues: list[UnionFindQueue[_NameT, _ElemT]] = [] self.first_node = None
self.sub_queues: list[ConcatenableQueue[_NameT, _ElemT]] = []
def clear(self) -> None: def clear(self) -> None:
"""Remove all elements from the queue. """Remove all elements from the queue.
@ -120,12 +123,8 @@ class UnionFindQueue(Generic[_NameT, _ElemT]):
while node is not None: while node is not None:
node.min_node = None # type: ignore node.min_node = None # type: ignore
prev_node = node prev_node = node
if node.left is not None: if node.childs:
node = node.left node = node.childs.pop()
prev_node.left = None
elif node.right is not None:
node = node.right
prev_node.right = None
else: else:
node = node.parent node = node.parent
prev_node.parent = None prev_node.parent = None
@ -143,10 +142,10 @@ class UnionFindQueue(Generic[_NameT, _ElemT]):
prio: Initial priority of the new element. prio: Initial priority of the new element.
""" """
assert self.tree is None assert self.tree is None
node = UnionFindQueue.Node(self, elem, prio) self.tree = ConcatenableQueue.Node(elem, prio)
self.tree = node self.tree.owner = self
self.first_node = node self.first_node = self.tree
return node return self.tree
def min_prio(self) -> float: def min_prio(self) -> float:
"""Return the minimum priority of any element in the queue. """Return the minimum priority of any element in the queue.
@ -169,7 +168,7 @@ class UnionFindQueue(Generic[_NameT, _ElemT]):
return node.min_node.data return node.min_node.data
def merge(self, def merge(self,
sub_queues: list[UnionFindQueue[_NameT, _ElemT]] sub_queues: list[ConcatenableQueue[_NameT, _ElemT]]
) -> None: ) -> None:
"""Merge the specified queues. """Merge the specified queues.
@ -210,7 +209,7 @@ class UnionFindQueue(Generic[_NameT, _ElemT]):
subtree.owner = None subtree.owner = None
# Merge our current tree with the tree from the sub-queue. # Merge our current tree with the tree from the sub-queue.
self.tree = self._merge_tree(self.tree, subtree) self.tree = self._join(self.tree, subtree)
# Put the owner pointer in the root node. # Put the owner pointer in the root node.
self.tree.owner = self self.tree.owner = self
@ -252,417 +251,180 @@ class UnionFindQueue(Generic[_NameT, _ElemT]):
self.sub_queues = [] self.sub_queues = []
@staticmethod @staticmethod
def _repair_node(node: Node[_NameT, _ElemT]) -> None: def _repair_node(node: BaseNode[_NameT, _ElemT]) -> None:
"""Recalculate the height and min-priority information of the """Repair min_prio attribute of an internal node."""
specified node. min_node = node.childs[0].min_node
""" for child in node.childs[1:]:
if child.min_node.prio < min_node.prio:
# Repair node height. min_node = child.min_node
lh = 0 if node.left is None else node.left.height
rh = 0 if node.right is None else node.right.height
node.height = 1 + max(lh, rh)
# Repair min-priority.
min_node = node
if node.left is not None:
left_min_node = node.left.min_node
if left_min_node.prio < min_node.prio:
min_node = left_min_node
if node.right is not None:
right_min_node = node.right.min_node
if right_min_node.prio < min_node.prio:
min_node = right_min_node
node.min_node = min_node node.min_node = min_node
def _rotate_left(self, @staticmethod
node: Node[_NameT, _ElemT] def _new_internal_node(ltree: BaseNode[_NameT, _ElemT],
) -> Node[_NameT, _ElemT]: rtree: BaseNode[_NameT, _ElemT]
"""Rotate the specified subtree to the left. ) -> BaseNode[_NameT, _ElemT]:
"""Create a new internal node with 2 child nodes."""
Return the new root node of the subtree. assert ltree.height == rtree.height
""" height = ltree.height + 1
# if ltree.min_node.prio <= rtree.min_node.prio:
# N C min_node = ltree.min_node
# / \ / \ else:
# A C ---> N D min_node = rtree.min_node
# / \ / \ node = ConcatenableQueue.BaseNode(min_node, height)
# B D A B node.childs = [ltree, rtree]
# ltree.parent = node
parent = node.parent rtree.parent = node
new_top = node.right
assert new_top is not None
node.right = new_top.left
if node.right is not None:
node.right.parent = node
new_top.left = node
new_top.parent = parent
node.parent = new_top
if parent is not None:
if parent.left is node:
parent.left = new_top
elif parent.right is node:
parent.right = new_top
self._repair_node(node)
self._repair_node(new_top)
return new_top
def _rotate_right(self,
node: Node[_NameT, _ElemT]
) -> Node[_NameT, _ElemT]:
"""Rotate the specified node to the right.
Return the new root node of the subtree.
"""
#
# N A
# / \ / \
# A D ---> B N
# / \ / \
# B C C D
#
parent = node.parent
new_top = node.left
assert new_top is not None
node.left = new_top.right
if node.left is not None:
node.left.parent = node
new_top.right = node
new_top.parent = parent
node.parent = new_top
if parent is not None:
if parent.left is node:
parent.left = new_top
elif parent.right is node:
parent.right = new_top
self._repair_node(node)
self._repair_node(new_top)
return new_top
def _rebalance_up(self,
node: Node[_NameT, _ElemT]
) -> Node[_NameT, _ElemT]:
"""Repair and rebalance the specified node and its ancestors.
Return the root node of the rebalanced tree.
"""
# Walk up to the root of the tree.
while True:
lh = 0 if node.left is None else node.left.height
rh = 0 if node.right is None else node.right.height
if lh > rh + 1:
# This node is left-heavy. Rotate right to rebalance.
#
# N L
# / \ / \
# L \ / N
# / \ \ ---> / / \
# A B \ A B \
# \ \
# R R
#
lchild = node.left
assert lchild is not None
if ((lchild.right is not None)
and ((lchild.left is None)
or (lchild.right.height > lchild.left.height))):
# Double rotation.
lchild = self._rotate_left(lchild)
node = self._rotate_right(node)
elif lh + 1 < rh:
# This node is right-heavy. Rotate left to rebalance.
#
# N R
# / \ / \
# / R N \
# / / \ ---> / \ \
# / A B / A B
# / /
# L L
#
rchild = node.right
assert rchild is not None
if ((rchild.left is not None)
and ((rchild.right is None)
or (rchild.left.height > rchild.right.height))):
# Double rotation.
rchild = self._rotate_right(rchild)
node = self._rotate_left(node)
else:
# No rotation. Must still repair node though.
self._repair_node(node)
if node.parent is None:
break
# Continue rebalancing at the parent.
node = node.parent
# Return new root node.
return node return node
def _join_right(self, def _join_right(self,
ltree: Node[_NameT, _ElemT], ltree: BaseNode[_NameT, _ElemT],
node: Node[_NameT, _ElemT], rtree: BaseNode[_NameT, _ElemT]
rtree: Optional[Node[_NameT, _ElemT]] ) -> BaseNode[_NameT, _ElemT]:
) -> Node[_NameT, _ElemT]: """Join two trees together.
"""Join a left subtree, middle node and right subtree together.
The left subtree must be higher than the right subtree. The initial left subtree must be higher than the right subtree.
Return the root node of the joined tree. Return the root node of the joined tree.
""" """
lh = ltree.height
rh = 0 if rtree is None else rtree.height
assert lh > rh + 1
# Descend down the right spine of "ltree". # Descend down the right spine of the left tree until we
# Stop at a node with compatible height, then insert "node" # reach a node just above the right tree.
# and attach "rtree". node = ltree
# while node.height > rtree.height + 1:
# ltree node = node.childs[-1]
# / \
# X
# / \
# X <-- cur
# / \
# node
# / \
# X rtree
#
# Descend to a point with compatible height. assert node.height == rtree.height + 1
cur = ltree
while (cur.right is not None) and (cur.right.height > rh + 1):
cur = cur.right
# Insert "node" and "rtree". # Find a node in the left tree to insert the right tree as child.
node.left = cur.right while len(node.childs) == 3:
node.right = rtree # This node already has 3 childs so we can not add the right tree.
if node.left is not None: # Rearrange into 2 nodes with 2 childs each, then solve it
node.left.parent = node # at the parent level.
if rtree is not None: #
rtree.parent = node # N N R'
cur.right = node # / | \ / \ / \
node.parent = cur # / | \ ---> / \ / \
# A B C R A B C R
# A double rotation may be necessary. #
if (cur.left is None) or (cur.left.height <= rh): child = node.childs.pop()
node = self._rotate_right(node)
cur = self._rotate_left(cur)
else:
self._repair_node(node) self._repair_node(node)
self._repair_node(cur) rtree = self._new_internal_node(child, rtree)
if node.parent is None:
# Create a new root node.
return self._new_internal_node(node, rtree)
node = node.parent
# Ascend from "cur" to the root of the tree. # Insert the right tree as child of this node.
# Repair and/or rotate as needed. assert len(node.childs) < 3
while cur.parent is not None: node.childs.append(rtree)
cur = cur.parent rtree.parent = node
assert cur.left is not None
assert cur.right is not None
if cur.left.height + 1 < cur.right.height: # Repair min-prio pointers of ancestors.
cur = self._rotate_left(cur) while True:
else: self._repair_node(node)
self._repair_node(cur) if node.parent is None:
break
node = node.parent
return cur return node
def _join_left(self, def _join_left(self,
ltree: Optional[Node[_NameT, _ElemT]], ltree: BaseNode[_NameT, _ElemT],
node: Node[_NameT, _ElemT], rtree: BaseNode[_NameT, _ElemT]
rtree: Node[_NameT, _ElemT] ) -> BaseNode[_NameT, _ElemT]:
) -> Node[_NameT, _ElemT]: """Join two trees together.
"""Join a left subtree, middle node and right subtree together.
The right subtree must be higher than the left subtree. The initial left subtree must be lower than the right subtree.
Return the root node of the joined tree. Return the root node of the joined tree.
""" """
lh = 0 if ltree is None else ltree.height
rh = rtree.height
assert lh + 1 < rh
# Descend down the left spine of "rtree". # Descend down the left spine of the right tree until we
# Stop at a node with compatible height, then insert "node" # reach a node just above the left tree.
# and attach "ltree". node = rtree
# while node.height > ltree.height + 1:
# rtree node = node.childs[0]
# / \
# X
# / \
# cur --> X
# / \
# node
# / \
# ltree X
#
# Descend to a point with compatible height. assert node.height == ltree.height + 1
cur = rtree
while (cur.left is not None) and (cur.left.height > lh + 1):
cur = cur.left
# Insert "node" and "ltree". # Find a node in the right tree to insert the left tree as child.
node.left = ltree while len(node.childs) == 3:
node.right = cur.left # This node already has 3 childs so we can not add the left tree.
if ltree is not None: # Rearrange into 2 nodes with 2 childs each, then solve it
ltree.parent = node # at the parent level.
if node.right is not None: #
node.right.parent = node # N L' N
cur.left = node # / | \ / \ / \
node.parent = cur # / | \ ---> / \ / \
# L A B C L A B C
# A double rotation may be necessary. #
if (cur.right is None) or (cur.right.height <= lh): child = node.childs.pop(0)
node = self._rotate_left(node)
cur = self._rotate_right(cur)
else:
self._repair_node(node) self._repair_node(node)
self._repair_node(cur) ltree = self._new_internal_node(ltree, child)
if node.parent is None:
# Create a new root node.
return self._new_internal_node(ltree, node)
node = node.parent
# Ascend from "cur" to the root of the tree. # Insert the left tree as child of this node.
# Repair and/or rotate as needed. assert len(node.childs) < 3
while cur.parent is not None: node.childs.insert(0, ltree)
cur = cur.parent ltree.parent = node
assert cur.left is not None
assert cur.right is not None
if cur.left.height > cur.right.height + 1: # Repair min-prio pointers of ancestors.
cur = self._rotate_right(cur) while True:
else: self._repair_node(node)
self._repair_node(cur) if node.parent is None:
break
node = node.parent
return cur return node
def _join(self, def _join(self,
ltree: Optional[Node[_NameT, _ElemT]], ltree: BaseNode[_NameT, _ElemT],
node: Node[_NameT, _ElemT], rtree: BaseNode[_NameT, _ElemT]
rtree: Optional[Node[_NameT, _ElemT]] ) -> BaseNode[_NameT, _ElemT]:
) -> Node[_NameT, _ElemT]: """Join two trees together.
"""Join a left subtree, middle node and right subtree together.
The left or right subtree may initially be a child of the middle The left and right subtree must be consistent 2-3 trees.
node; such links will be broken as needed. Initial parent pointers of these subtrees are ignored.
The left and right subtrees must be consistent, AVL-balanced trees.
Parent pointers of the subtrees are ignored.
The middle node is considered as a single node.
Its parent and child pointers are ignored.
Return the root node of the joined tree. Return the root node of the joined tree.
""" """
lh = 0 if ltree is None else ltree.height if ltree.height > rtree.height:
rh = 0 if rtree is None else rtree.height return self._join_right(ltree, rtree)
elif ltree.height < rtree.height:
if lh > rh + 1: return self._join_left(ltree, rtree)
assert ltree is not None
ltree.parent = None
return self._join_right(ltree, node, rtree)
elif lh + 1 < rh:
assert rtree is not None
rtree.parent = None
return self._join_left(ltree, node, rtree)
else: else:
# Subtree heights are compatible. Just join them. return self._new_internal_node(ltree, rtree)
#
# node
# / \
# ltree rtree
# / \ / \
#
node.parent = None
node.left = ltree
if ltree is not None:
ltree.parent = node
node.right = rtree
if rtree is not None:
rtree.parent = node
self._repair_node(node)
return node
def _merge_tree(self,
ltree: Node[_NameT, _ElemT],
rtree: Node[_NameT, _ElemT]
) -> Node[_NameT, _ElemT]:
"""Merge two trees.
Return the root node of the merged tree.
"""
# Find the left-most node of the right tree.
split_node = rtree
while split_node.left is not None:
split_node = split_node.left
# Delete the split_node from its tree.
parent = split_node.parent
if split_node.right is not None:
split_node.right.parent = parent
if parent is None:
rtree_new = split_node.right
else:
# Repair and rebalance the ancestors of split_node.
parent.left = split_node.right
rtree_new = self._rebalance_up(parent)
# Join the two trees via the split_node.
return self._join(ltree, split_node, rtree_new)
def _split_tree(self, def _split_tree(self,
split_node: Node[_NameT, _ElemT] split_node: BaseNode[_NameT, _ElemT]
) -> tuple[Node[_NameT, _ElemT], Node[_NameT, _ElemT]]: ) -> tuple[BaseNode[_NameT, _ElemT],
BaseNode[_NameT, _ElemT]]:
"""Split a tree on a specified node. """Split a tree on a specified node.
Two new trees will be constructed. Two new trees will be constructed.
All nodes to the left of "split_node" will go to the left tree. Leaf nodes to the left of "split_node" will go to the left tree.
All nodes to the right of "split_node", and "split_node" itself, Leaf nodes to the right of "split_node", and "split_node" itself,
will go to the right tree. will go to the right tree.
Return tuple (ltree, rtree), Return tuple (ltree, rtree).
where ltree contains all nodes left of the split-node,
rtree contains the split-nodes and all nodes to its right.
""" """
# Assign the descendants of "split_node" to the appropriate trees # Detach "split_node" from its parent.
# and detach them from "split_node". # Assign it to the right tree.
ltree = split_node.left
rtree = split_node.right
split_node.left = None
split_node.right = None
if ltree is not None:
ltree.parent = None
if rtree is not None:
rtree.parent = None
# Detach "split_node" from its parent (if any).
parent = split_node.parent parent = split_node.parent
split_node.parent = None split_node.parent = None
# Assign "split_node" to the right tree. # The left tree is initially empty.
rtree = self._join(None, split_node, rtree) # The right tree initially contains only "split_node".
ltree: Optional[ConcatenableQueue.BaseNode[_NameT, _ElemT]] = None
rtree = split_node
# Walk up to the root of the tree. # Walk up to the root of the tree.
# On the way up, detach each node from its parent and join it, # On the way up, detach each node from its parent and join its
# and its descendants, to the appropriate tree. # child nodes to the appropriate tree.
node = split_node node = split_node
while parent is not None: while parent is not None:
@ -674,16 +436,53 @@ class UnionFindQueue(Generic[_NameT, _ElemT]):
# Detach "node" from its parent. # Detach "node" from its parent.
node.parent = None node.parent = None
if node.left is child: if len(node.childs) == 3:
# "split_node" was located in the left subtree of "node". if node.childs[0] is child:
# This implies that "node" must be joined to the right tree. # "node" has 3 child nodes.
rtree = self._join(rtree, node, node.right) # Its left subtree has already been split.
# Turn it into a 2-node and join it to the right tree.
node.childs.pop(0)
self._repair_node(node)
rtree = self._join(rtree, node)
elif node.childs[2] is child:
# "node" has 3 child nodes.
# Its right subtree has already been split.
# Turn it into a 2-node and join it to the left tree.
node.childs.pop()
self._repair_node(node)
if ltree is None:
ltree = node
else:
ltree = self._join(node, ltree)
else:
# "node has 3 child nodes.
# Its middle subtree has already been split.
# Join its left child to the left tree, and its right
# child to the right tree, then delete "node".
node.childs[0].parent = None
node.childs[2].parent = None
if ltree is None:
ltree = node.childs[0]
else:
ltree = self._join(node.childs[0], ltree)
rtree = self._join(rtree, node.childs[2])
elif node.childs[0] is child:
# "node" has 2 child nodes.
# Its left subtree has already been split.
# Join its right child to the right tree, then delete "node".
node.childs[1].parent = None
rtree = self._join(rtree, node.childs[1])
else: else:
# "split_node" was located in the right subtree of "node". # "node" has 2 child nodes.
# This implies that "node" must be joined to the right tree. # Its right subtree has already been split.
assert node.right is child # Join its left child to the left tree, then delete "node".
ltree = self._join(node.left, node, ltree) node.childs[0].parent = None
if ltree is None:
ltree = node.childs[0]
else:
ltree = self._join(node.childs[0], ltree)
assert ltree is not None assert ltree is not None
return (ltree, rtree) return (ltree, rtree)

View File

@ -3,11 +3,11 @@
import random import random
import unittest import unittest
from mwmatching.datastruct import UnionFindQueue, PriorityQueue from mwmatching.datastruct import ConcatenableQueue, PriorityQueue
class TestUnionFindQueue(unittest.TestCase): class TestConcatenableQueue(unittest.TestCase):
"""Test UnionFindQueue.""" """Test ConcatenableQueue."""
def _check_tree(self, queue): def _check_tree(self, queue):
"""Check tree balancing rules and priority info.""" """Check tree balancing rules and priority info."""
@ -20,40 +20,33 @@ class TestUnionFindQueue(unittest.TestCase):
node = nodes.pop() node = nodes.pop()
if node.left is not None:
self.assertIs(node.left.parent, node)
nodes.append(node.left)
if node.right is not None:
self.assertIs(node.right.parent, node)
nodes.append(node.right)
if node is not queue.tree: if node is not queue.tree:
self.assertIsNone(node.owner) self.assertIsNone(node.owner)
lh = 0 if node.left is None else node.left.height if node.height == 0:
rh = 0 if node.right is None else node.right.height self.assertEqual(len(node.childs), 0)
self.assertEqual(node.height, 1 + max(lh, rh)) self.assertIs(node.min_node, node)
else:
self.assertLessEqual(lh, rh + 1) self.assertIn(len(node.childs), (2, 3))
self.assertLessEqual(rh, lh + 1) best_node = set()
best_prio = None
best_node = {node} for child in node.childs:
best_prio = node.prio self.assertIs(child.parent, node)
for child in (node.left, node.right): self.assertEqual(child.height, node.height - 1)
if child is not None: nodes.append(child)
if child.min_node.prio < best_prio: if ((best_prio is None)
best_prio = child.min_node.prio or (child.min_node.prio < best_prio)):
best_node = {child.min_node} best_node = {child.min_node}
best_prio = child.min_node.prio
elif child.min_node.prio == best_prio: elif child.min_node.prio == best_prio:
best_node.add(child.min_node) best_node.add(child.min_node)
self.assertEqual(node.min_node.prio, best_prio) self.assertEqual(node.min_node.prio, best_prio)
self.assertIn(node.min_node, best_node) self.assertIn(node.min_node, best_node)
def test_single(self): def test_single(self):
"""Single element.""" """Single element."""
q = UnionFindQueue("Q") q = ConcatenableQueue("Q")
with self.assertRaises(Exception): with self.assertRaises(Exception):
q.min_prio() q.min_prio()
@ -62,7 +55,7 @@ class TestUnionFindQueue(unittest.TestCase):
q.min_elem() q.min_elem()
n = q.insert("a", 4) n = q.insert("a", 4)
self.assertIsInstance(n, UnionFindQueue.Node) self.assertIsInstance(n, ConcatenableQueue.Node)
self._check_tree(q) self._check_tree(q)
@ -84,22 +77,22 @@ class TestUnionFindQueue(unittest.TestCase):
def test_simple(self): def test_simple(self):
"""Simple test, 5 elements.""" """Simple test, 5 elements."""
q1 = UnionFindQueue("A") q1 = ConcatenableQueue("A")
n1 = q1.insert("a", 5) n1 = q1.insert("a", 5)
q2 = UnionFindQueue("B") q2 = ConcatenableQueue("B")
n2 = q2.insert("b", 6) n2 = q2.insert("b", 6)
q3 = UnionFindQueue("C") q3 = ConcatenableQueue("C")
n3 = q3.insert("c", 7) n3 = q3.insert("c", 7)
q4 = UnionFindQueue("D") q4 = ConcatenableQueue("D")
n4 = q4.insert("d", 4) n4 = q4.insert("d", 4)
q5 = UnionFindQueue("E") q5 = ConcatenableQueue("E")
n5 = q5.insert("e", 3) n5 = q5.insert("e", 3)
q345 = UnionFindQueue("P") q345 = ConcatenableQueue("P")
q345.merge([q3, q4, q5]) q345.merge([q3, q4, q5])
self._check_tree(q345) self._check_tree(q345)
@ -120,7 +113,7 @@ class TestUnionFindQueue(unittest.TestCase):
self.assertEqual(q345.min_prio(), 4) self.assertEqual(q345.min_prio(), 4)
self.assertEqual(q345.min_elem(), "d") self.assertEqual(q345.min_elem(), "d")
q12 = UnionFindQueue("Q") q12 = ConcatenableQueue("Q")
q12.merge([q1, q2]) q12.merge([q1, q2])
self._check_tree(q12) self._check_tree(q12)
@ -129,7 +122,7 @@ class TestUnionFindQueue(unittest.TestCase):
self.assertEqual(q12.min_prio(), 5) self.assertEqual(q12.min_prio(), 5)
self.assertEqual(q12.min_elem(), "a") self.assertEqual(q12.min_elem(), "a")
q12345 = UnionFindQueue("R") q12345 = ConcatenableQueue("R")
q12345.merge([q12, q345]) q12345.merge([q12, q345])
self._check_tree(q12345) self._check_tree(q12345)
@ -201,30 +194,30 @@ class TestUnionFindQueue(unittest.TestCase):
queues = [] queues = []
nodes = [] nodes = []
for i in range(14): for i in range(14):
q = UnionFindQueue(chr(ord("A") + i)) q = ConcatenableQueue(chr(ord("A") + i))
n = q.insert(chr(ord("a") + i), prios[i]) n = q.insert(chr(ord("a") + i), prios[i])
queues.append(q) queues.append(q)
nodes.append(n) nodes.append(n)
q = UnionFindQueue("AB") q = ConcatenableQueue("AB")
q.merge(queues[0:2]) q.merge(queues[0:2])
queues.append(q) queues.append(q)
self._check_tree(q) self._check_tree(q)
self.assertEqual(q.min_prio(), min(prios[0:2])) self.assertEqual(q.min_prio(), min(prios[0:2]))
q = UnionFindQueue("CDE") q = ConcatenableQueue("CDE")
q.merge(queues[2:5]) q.merge(queues[2:5])
queues.append(q) queues.append(q)
self._check_tree(q) self._check_tree(q)
self.assertEqual(q.min_prio(), min(prios[2:5])) self.assertEqual(q.min_prio(), min(prios[2:5]))
q = UnionFindQueue("FGHI") q = ConcatenableQueue("FGHI")
q.merge(queues[5:9]) q.merge(queues[5:9])
queues.append(q) queues.append(q)
self._check_tree(q) self._check_tree(q)
self.assertEqual(q.min_prio(), min(prios[5:9])) self.assertEqual(q.min_prio(), min(prios[5:9]))
q = UnionFindQueue("JKLMN") q = ConcatenableQueue("JKLMN")
q.merge(queues[9:14]) q.merge(queues[9:14])
queues.append(q) queues.append(q)
self._check_tree(q) self._check_tree(q)
@ -239,7 +232,7 @@ class TestUnionFindQueue(unittest.TestCase):
for i in range(9, 14): for i in range(9, 14):
self.assertEqual(nodes[i].find(), "JKLMN") self.assertEqual(nodes[i].find(), "JKLMN")
q = UnionFindQueue("ALL") q = ConcatenableQueue("ALL")
q.merge(queues[14:18]) q.merge(queues[14:18])
queues.append(q) queues.append(q)
self._check_tree(q) self._check_tree(q)
@ -298,7 +291,7 @@ class TestUnionFindQueue(unittest.TestCase):
for i in range(4000): for i in range(4000):
name = f"q{i}" name = f"q{i}"
q = UnionFindQueue(name) q = ConcatenableQueue(name)
p = rng.random() p = rng.random()
n = q.insert(f"n{i}", p) n = q.insert(f"n{i}", p)
nodes.append(n) nodes.append(n)
@ -326,7 +319,7 @@ class TestUnionFindQueue(unittest.TestCase):
subs = rng.sample(sorted(live_queues), k) subs = rng.sample(sorted(live_queues), k)
name = f"Q{i}" name = f"Q{i}"
q = UnionFindQueue(name) q = ConcatenableQueue(name)
q.merge([queues[nn] for nn in subs]) q.merge([queues[nn] for nn in subs])
self._check_tree(q) self._check_tree(q)
queues[name] = q queues[name] = q