diff --git a/python/mwmatching/algorithm.py b/python/mwmatching/algorithm.py index f908d30..3b9007e 100644 --- a/python/mwmatching/algorithm.py +++ b/python/mwmatching/algorithm.py @@ -10,7 +10,7 @@ import math from collections.abc import Sequence from typing import NamedTuple, Optional -from .datastruct import UnionFindQueue, PriorityQueue +from .datastruct import ConcatenableQueue, PriorityQueue def maximum_weight_matching( @@ -391,9 +391,10 @@ class Blossom: # all top-level blossoms in the tree. self.tree_blossoms: Optional[set[Blossom]] = None - # Each top-level blossom maintains a union-find datastructure - # containing all vertices in the blossom. - self.vertex_set: UnionFindQueue[Blossom, int] = UnionFindQueue(self) + # Each top-level blossom maintains a concatenable queue containing + # all vertices in the blossom. + self.vertex_set: ConcatenableQueue[Blossom, int] + self.vertex_set = ConcatenableQueue(self) # If this is a top-level unlabeled blossom with an edge to an # S-blossom, "delta2_node" is the corresponding node in the delta2 @@ -554,7 +555,7 @@ class MatchingContext: self.nontrivial_blossom: set[NonTrivialBlossom] = set() # "vertex_set_node[x]" represents the vertex "x" inside the - # union-find datastructure of its top-level blossom. + # concatenable queue of its top-level blossom. # # Initially, each vertex belongs to its own trivial top-level blossom. self.vertex_set_node = [b.vertex_set.insert(i, math.inf) @@ -668,7 +669,7 @@ class MatchingContext: if not improved: return - # Update the priority of "y" in its UnionFindQueue. + # Update the priority of "y" in its ConcatenableQueue. self.vertex_set_node[y].set_prio(prio) # If the blossom is unlabeled and the new edge becomes its least-slack @@ -701,7 +702,7 @@ class MatchingContext: else: prio = vertex_sedge_queue.find_min().prio - # If necessary, update the priority of "y" in its UnionFindQueue. + # If necessary, update priority of "y" in its ConcatenableQueue. if prio > self.vertex_set_node[y].prio: self.vertex_set_node[y].set_prio(prio) if by.label == LABEL_NONE: @@ -1235,7 +1236,7 @@ class MatchingContext: sub.tree_blossoms = None tree_blossoms.remove(sub) - # Merge union-find structures. + # Merge concatenable queues. blossom.vertex_set.merge([sub.vertex_set for sub in subblossoms]) @staticmethod @@ -1280,7 +1281,7 @@ class MatchingContext: # Remove blossom from the delta2 queue. self.delta2_disable_blossom(blossom) - # Split union-find structure. + # Split concatenable queue. blossom.vertex_set.split() # Prepare to push lazy delta updates down to the sub-blossoms. diff --git a/python/mwmatching/datastruct.py b/python/mwmatching/datastruct.py index 00066f5..5996769 100644 --- a/python/mwmatching/datastruct.py +++ b/python/mwmatching/datastruct.py @@ -11,11 +11,11 @@ _ElemT = TypeVar("_ElemT") _ElemT2 = TypeVar("_ElemT2") -class UnionFindQueue(Generic[_NameT, _ElemT]): - """Combination of disjoint set and priority queue. +class ConcatenableQueue(Generic[_NameT, _ElemT]): + """Priority queue supporting efficient merge and split operations. + This is a combination of a disjoint set and a priority queue. A queue has a "name", which can be any Python object. - Each element has associated "data", which can be any Python object. Each element has a priority. @@ -27,68 +27,70 @@ class UnionFindQueue(Generic[_NameT, _ElemT]): - Merge two or more queues. - Undo a previous merge step. - The implementation is essentially an AVL tree, with minimum-priority + This data structure is implemented as a 2-3 tree with minimum-priority tracking added to it. """ __slots__ = ("name", "tree", "first_node", "sub_queues") - class Node(Generic[_NameT2, _ElemT2]): - """Node in a UnionFindQueue.""" + class BaseNode(Generic[_NameT2, _ElemT2]): + """Node in the 2-3 tree.""" - __slots__ = ("owner", "data", "prio", "min_node", "height", - "parent", "left", "right") + __slots__ = ("owner", "min_node", "height", "parent", "childs") def __init__(self, - owner: UnionFindQueue[_NameT2, _ElemT2], - data: _ElemT2, - prio: float + min_node: ConcatenableQueue.Node[_NameT2, _ElemT2], + height: int ) -> None: - """Initialize a new element. + """Initialize a new node.""" + self.owner: Optional[ConcatenableQueue[_NameT2, _ElemT2]] = None + self.min_node = min_node + self.height = height + self.parent: Optional[ConcatenableQueue.BaseNode[_NameT2, + _ElemT2]] + self.parent = None + self.childs: list[ConcatenableQueue.BaseNode[_NameT2, _ElemT2]] + self.childs = [] + + class Node(BaseNode[_NameT2, _ElemT2]): + """Leaf node in the 2-3 tree, representing an element in the queue.""" + + __slots__ = ("data", "prio") + + def __init__(self, data: _ElemT2, prio: float) -> None: + """Initialize a new leaf node. This method should not be called directly. - Instead, call UnionFindQueue.insert(). + Instead, call ConcatenableQueue.insert(). """ - self.owner: Optional[UnionFindQueue[_NameT2, _ElemT2]] = owner + super().__init__(min_node=self, height=0) self.data = data self.prio = prio - self.min_node = self - self.height = 1 - self.parent: Optional[UnionFindQueue.Node[_NameT2, _ElemT2]] - self.left: Optional[UnionFindQueue.Node[_NameT2, _ElemT2]] - self.right: Optional[UnionFindQueue.Node[_NameT2, _ElemT2]] - self.parent = None - self.left = None - self.right = None def find(self) -> _NameT2: """Return the name of the queue that contains this element. This function takes time O(log(n)). """ - node = self + node: ConcatenableQueue.BaseNode[_NameT2, _ElemT2] = self while node.parent is not None: node = node.parent assert node.owner is not None return node.owner.name def set_prio(self, prio: float) -> None: - """Change the priority of this element.""" + """Change the priority of this element. + + This function takes time O(log(n)). + """ self.prio = prio - node = self - while True: - min_node = node - if node.left is not None: - left_min_node = node.left.min_node - if left_min_node.prio < min_node.prio: - min_node = left_min_node - if node.right is not None: - right_min_node = node.right.min_node - if right_min_node.prio < min_node.prio: - min_node = right_min_node + node = self.parent + while node is not None: + min_node = node.childs[0].min_node + for child in node.childs[1:]: + if child.min_node.prio < min_node.prio: + min_node = child.min_node node.min_node = min_node - if node.parent is None: - break node = node.parent def __init__(self, name: _NameT) -> None: @@ -100,9 +102,10 @@ class UnionFindQueue(Generic[_NameT, _ElemT]): name: Name to assign to the new queue. """ self.name = name - self.tree: Optional[UnionFindQueue.Node[_NameT, _ElemT]] = None - self.first_node: Optional[UnionFindQueue.Node[_NameT, _ElemT]] = None - self.sub_queues: list[UnionFindQueue[_NameT, _ElemT]] = [] + self.tree: Optional[ConcatenableQueue.BaseNode[_NameT, _ElemT]] = None + self.first_node: Optional[ConcatenableQueue.Node[_NameT, _ElemT]] + self.first_node = None + self.sub_queues: list[ConcatenableQueue[_NameT, _ElemT]] = [] def clear(self) -> None: """Remove all elements from the queue. @@ -120,12 +123,8 @@ class UnionFindQueue(Generic[_NameT, _ElemT]): while node is not None: node.min_node = None # type: ignore prev_node = node - if node.left is not None: - node = node.left - prev_node.left = None - elif node.right is not None: - node = node.right - prev_node.right = None + if node.childs: + node = node.childs.pop() else: node = node.parent prev_node.parent = None @@ -143,10 +142,10 @@ class UnionFindQueue(Generic[_NameT, _ElemT]): prio: Initial priority of the new element. """ assert self.tree is None - node = UnionFindQueue.Node(self, elem, prio) - self.tree = node - self.first_node = node - return node + self.tree = ConcatenableQueue.Node(elem, prio) + self.tree.owner = self + self.first_node = self.tree + return self.tree def min_prio(self) -> float: """Return the minimum priority of any element in the queue. @@ -169,7 +168,7 @@ class UnionFindQueue(Generic[_NameT, _ElemT]): return node.min_node.data def merge(self, - sub_queues: list[UnionFindQueue[_NameT, _ElemT]] + sub_queues: list[ConcatenableQueue[_NameT, _ElemT]] ) -> None: """Merge the specified queues. @@ -210,7 +209,7 @@ class UnionFindQueue(Generic[_NameT, _ElemT]): subtree.owner = None # Merge our current tree with the tree from the sub-queue. - self.tree = self._merge_tree(self.tree, subtree) + self.tree = self._join(self.tree, subtree) # Put the owner pointer in the root node. self.tree.owner = self @@ -252,417 +251,180 @@ class UnionFindQueue(Generic[_NameT, _ElemT]): self.sub_queues = [] @staticmethod - def _repair_node(node: Node[_NameT, _ElemT]) -> None: - """Recalculate the height and min-priority information of the - specified node. - """ - - # Repair node height. - lh = 0 if node.left is None else node.left.height - rh = 0 if node.right is None else node.right.height - node.height = 1 + max(lh, rh) - - # Repair min-priority. - min_node = node - if node.left is not None: - left_min_node = node.left.min_node - if left_min_node.prio < min_node.prio: - min_node = left_min_node - if node.right is not None: - right_min_node = node.right.min_node - if right_min_node.prio < min_node.prio: - min_node = right_min_node + def _repair_node(node: BaseNode[_NameT, _ElemT]) -> None: + """Repair min_prio attribute of an internal node.""" + min_node = node.childs[0].min_node + for child in node.childs[1:]: + if child.min_node.prio < min_node.prio: + min_node = child.min_node node.min_node = min_node - def _rotate_left(self, - node: Node[_NameT, _ElemT] - ) -> Node[_NameT, _ElemT]: - """Rotate the specified subtree to the left. - - Return the new root node of the subtree. - """ - # - # N C - # / \ / \ - # A C ---> N D - # / \ / \ - # B D A B - # - parent = node.parent - new_top = node.right - assert new_top is not None - - node.right = new_top.left - if node.right is not None: - node.right.parent = node - - new_top.left = node - new_top.parent = parent - node.parent = new_top - - if parent is not None: - if parent.left is node: - parent.left = new_top - elif parent.right is node: - parent.right = new_top - - self._repair_node(node) - self._repair_node(new_top) - - return new_top - - def _rotate_right(self, - node: Node[_NameT, _ElemT] - ) -> Node[_NameT, _ElemT]: - """Rotate the specified node to the right. - - Return the new root node of the subtree. - """ - # - # N A - # / \ / \ - # A D ---> B N - # / \ / \ - # B C C D - # - parent = node.parent - new_top = node.left - assert new_top is not None - - node.left = new_top.right - if node.left is not None: - node.left.parent = node - - new_top.right = node - new_top.parent = parent - node.parent = new_top - - if parent is not None: - if parent.left is node: - parent.left = new_top - elif parent.right is node: - parent.right = new_top - - self._repair_node(node) - self._repair_node(new_top) - - return new_top - - def _rebalance_up(self, - node: Node[_NameT, _ElemT] - ) -> Node[_NameT, _ElemT]: - """Repair and rebalance the specified node and its ancestors. - - Return the root node of the rebalanced tree. - """ - - # Walk up to the root of the tree. - while True: - - lh = 0 if node.left is None else node.left.height - rh = 0 if node.right is None else node.right.height - - if lh > rh + 1: - # This node is left-heavy. Rotate right to rebalance. - # - # N L - # / \ / \ - # L \ / N - # / \ \ ---> / / \ - # A B \ A B \ - # \ \ - # R R - # - lchild = node.left - assert lchild is not None - if ((lchild.right is not None) - and ((lchild.left is None) - or (lchild.right.height > lchild.left.height))): - # Double rotation. - lchild = self._rotate_left(lchild) - - node = self._rotate_right(node) - - elif lh + 1 < rh: - # This node is right-heavy. Rotate left to rebalance. - # - # N R - # / \ / \ - # / R N \ - # / / \ ---> / \ \ - # / A B / A B - # / / - # L L - # - rchild = node.right - assert rchild is not None - if ((rchild.left is not None) - and ((rchild.right is None) - or (rchild.left.height > rchild.right.height))): - # Double rotation. - rchild = self._rotate_right(rchild) - - node = self._rotate_left(node) - - else: - # No rotation. Must still repair node though. - self._repair_node(node) - - if node.parent is None: - break - - # Continue rebalancing at the parent. - node = node.parent - - # Return new root node. + @staticmethod + def _new_internal_node(ltree: BaseNode[_NameT, _ElemT], + rtree: BaseNode[_NameT, _ElemT] + ) -> BaseNode[_NameT, _ElemT]: + """Create a new internal node with 2 child nodes.""" + assert ltree.height == rtree.height + height = ltree.height + 1 + if ltree.min_node.prio <= rtree.min_node.prio: + min_node = ltree.min_node + else: + min_node = rtree.min_node + node = ConcatenableQueue.BaseNode(min_node, height) + node.childs = [ltree, rtree] + ltree.parent = node + rtree.parent = node return node def _join_right(self, - ltree: Node[_NameT, _ElemT], - node: Node[_NameT, _ElemT], - rtree: Optional[Node[_NameT, _ElemT]] - ) -> Node[_NameT, _ElemT]: - """Join a left subtree, middle node and right subtree together. + ltree: BaseNode[_NameT, _ElemT], + rtree: BaseNode[_NameT, _ElemT] + ) -> BaseNode[_NameT, _ElemT]: + """Join two trees together. - The left subtree must be higher than the right subtree. + The initial left subtree must be higher than the right subtree. Return the root node of the joined tree. """ - lh = ltree.height - rh = 0 if rtree is None else rtree.height - assert lh > rh + 1 - # Descend down the right spine of "ltree". - # Stop at a node with compatible height, then insert "node" - # and attach "rtree". - # - # ltree - # / \ - # X - # / \ - # X <-- cur - # / \ - # node - # / \ - # X rtree - # + # Descend down the right spine of the left tree until we + # reach a node just above the right tree. + node = ltree + while node.height > rtree.height + 1: + node = node.childs[-1] - # Descend to a point with compatible height. - cur = ltree - while (cur.right is not None) and (cur.right.height > rh + 1): - cur = cur.right + assert node.height == rtree.height + 1 - # Insert "node" and "rtree". - node.left = cur.right - node.right = rtree - if node.left is not None: - node.left.parent = node - if rtree is not None: - rtree.parent = node - cur.right = node - node.parent = cur - - # A double rotation may be necessary. - if (cur.left is None) or (cur.left.height <= rh): - node = self._rotate_right(node) - cur = self._rotate_left(cur) - else: + # Find a node in the left tree to insert the right tree as child. + while len(node.childs) == 3: + # This node already has 3 childs so we can not add the right tree. + # Rearrange into 2 nodes with 2 childs each, then solve it + # at the parent level. + # + # N N R' + # / | \ / \ / \ + # / | \ ---> / \ / \ + # A B C R A B C R + # + child = node.childs.pop() self._repair_node(node) - self._repair_node(cur) + rtree = self._new_internal_node(child, rtree) + if node.parent is None: + # Create a new root node. + return self._new_internal_node(node, rtree) + node = node.parent - # Ascend from "cur" to the root of the tree. - # Repair and/or rotate as needed. - while cur.parent is not None: - cur = cur.parent - assert cur.left is not None - assert cur.right is not None + # Insert the right tree as child of this node. + assert len(node.childs) < 3 + node.childs.append(rtree) + rtree.parent = node - if cur.left.height + 1 < cur.right.height: - cur = self._rotate_left(cur) - else: - self._repair_node(cur) + # Repair min-prio pointers of ancestors. + while True: + self._repair_node(node) + if node.parent is None: + break + node = node.parent - return cur + return node def _join_left(self, - ltree: Optional[Node[_NameT, _ElemT]], - node: Node[_NameT, _ElemT], - rtree: Node[_NameT, _ElemT] - ) -> Node[_NameT, _ElemT]: - """Join a left subtree, middle node and right subtree together. + ltree: BaseNode[_NameT, _ElemT], + rtree: BaseNode[_NameT, _ElemT] + ) -> BaseNode[_NameT, _ElemT]: + """Join two trees together. - The right subtree must be higher than the left subtree. + The initial left subtree must be lower than the right subtree. Return the root node of the joined tree. """ - lh = 0 if ltree is None else ltree.height - rh = rtree.height - assert lh + 1 < rh - # Descend down the left spine of "rtree". - # Stop at a node with compatible height, then insert "node" - # and attach "ltree". - # - # rtree - # / \ - # X - # / \ - # cur --> X - # / \ - # node - # / \ - # ltree X - # + # Descend down the left spine of the right tree until we + # reach a node just above the left tree. + node = rtree + while node.height > ltree.height + 1: + node = node.childs[0] - # Descend to a point with compatible height. - cur = rtree - while (cur.left is not None) and (cur.left.height > lh + 1): - cur = cur.left + assert node.height == ltree.height + 1 - # Insert "node" and "ltree". - node.left = ltree - node.right = cur.left - if ltree is not None: - ltree.parent = node - if node.right is not None: - node.right.parent = node - cur.left = node - node.parent = cur - - # A double rotation may be necessary. - if (cur.right is None) or (cur.right.height <= lh): - node = self._rotate_left(node) - cur = self._rotate_right(cur) - else: + # Find a node in the right tree to insert the left tree as child. + while len(node.childs) == 3: + # This node already has 3 childs so we can not add the left tree. + # Rearrange into 2 nodes with 2 childs each, then solve it + # at the parent level. + # + # N L' N + # / | \ / \ / \ + # / | \ ---> / \ / \ + # L A B C L A B C + # + child = node.childs.pop(0) self._repair_node(node) - self._repair_node(cur) + ltree = self._new_internal_node(ltree, child) + if node.parent is None: + # Create a new root node. + return self._new_internal_node(ltree, node) + node = node.parent - # Ascend from "cur" to the root of the tree. - # Repair and/or rotate as needed. - while cur.parent is not None: - cur = cur.parent - assert cur.left is not None - assert cur.right is not None + # Insert the left tree as child of this node. + assert len(node.childs) < 3 + node.childs.insert(0, ltree) + ltree.parent = node - if cur.left.height > cur.right.height + 1: - cur = self._rotate_right(cur) - else: - self._repair_node(cur) + # Repair min-prio pointers of ancestors. + while True: + self._repair_node(node) + if node.parent is None: + break + node = node.parent - return cur + return node def _join(self, - ltree: Optional[Node[_NameT, _ElemT]], - node: Node[_NameT, _ElemT], - rtree: Optional[Node[_NameT, _ElemT]] - ) -> Node[_NameT, _ElemT]: - """Join a left subtree, middle node and right subtree together. + ltree: BaseNode[_NameT, _ElemT], + rtree: BaseNode[_NameT, _ElemT] + ) -> BaseNode[_NameT, _ElemT]: + """Join two trees together. - The left or right subtree may initially be a child of the middle - node; such links will be broken as needed. - - The left and right subtrees must be consistent, AVL-balanced trees. - Parent pointers of the subtrees are ignored. - - The middle node is considered as a single node. - Its parent and child pointers are ignored. + The left and right subtree must be consistent 2-3 trees. + Initial parent pointers of these subtrees are ignored. Return the root node of the joined tree. """ - lh = 0 if ltree is None else ltree.height - rh = 0 if rtree is None else rtree.height - - if lh > rh + 1: - assert ltree is not None - ltree.parent = None - return self._join_right(ltree, node, rtree) - elif lh + 1 < rh: - assert rtree is not None - rtree.parent = None - return self._join_left(ltree, node, rtree) + if ltree.height > rtree.height: + return self._join_right(ltree, rtree) + elif ltree.height < rtree.height: + return self._join_left(ltree, rtree) else: - # Subtree heights are compatible. Just join them. - # - # node - # / \ - # ltree rtree - # / \ / \ - # - node.parent = None - node.left = ltree - if ltree is not None: - ltree.parent = node - node.right = rtree - if rtree is not None: - rtree.parent = node - self._repair_node(node) - return node - - def _merge_tree(self, - ltree: Node[_NameT, _ElemT], - rtree: Node[_NameT, _ElemT] - ) -> Node[_NameT, _ElemT]: - """Merge two trees. - - Return the root node of the merged tree. - """ - - # Find the left-most node of the right tree. - split_node = rtree - while split_node.left is not None: - split_node = split_node.left - - # Delete the split_node from its tree. - parent = split_node.parent - if split_node.right is not None: - split_node.right.parent = parent - if parent is None: - rtree_new = split_node.right - else: - # Repair and rebalance the ancestors of split_node. - parent.left = split_node.right - rtree_new = self._rebalance_up(parent) - - # Join the two trees via the split_node. - return self._join(ltree, split_node, rtree_new) + return self._new_internal_node(ltree, rtree) def _split_tree(self, - split_node: Node[_NameT, _ElemT] - ) -> tuple[Node[_NameT, _ElemT], Node[_NameT, _ElemT]]: + split_node: BaseNode[_NameT, _ElemT] + ) -> tuple[BaseNode[_NameT, _ElemT], + BaseNode[_NameT, _ElemT]]: """Split a tree on a specified node. Two new trees will be constructed. - All nodes to the left of "split_node" will go to the left tree. - All nodes to the right of "split_node", and "split_node" itself, + Leaf nodes to the left of "split_node" will go to the left tree. + Leaf nodes to the right of "split_node", and "split_node" itself, will go to the right tree. - Return tuple (ltree, rtree), - where ltree contains all nodes left of the split-node, - rtree contains the split-nodes and all nodes to its right. + Return tuple (ltree, rtree). """ - # Assign the descendants of "split_node" to the appropriate trees - # and detach them from "split_node". - ltree = split_node.left - rtree = split_node.right - - split_node.left = None - split_node.right = None - if ltree is not None: - ltree.parent = None - if rtree is not None: - rtree.parent = None - - # Detach "split_node" from its parent (if any). + # Detach "split_node" from its parent. + # Assign it to the right tree. parent = split_node.parent split_node.parent = None - # Assign "split_node" to the right tree. - rtree = self._join(None, split_node, rtree) + # The left tree is initially empty. + # The right tree initially contains only "split_node". + ltree: Optional[ConcatenableQueue.BaseNode[_NameT, _ElemT]] = None + rtree = split_node # Walk up to the root of the tree. - # On the way up, detach each node from its parent and join it, - # and its descendants, to the appropriate tree. + # On the way up, detach each node from its parent and join its + # child nodes to the appropriate tree. node = split_node while parent is not None: @@ -674,16 +436,53 @@ class UnionFindQueue(Generic[_NameT, _ElemT]): # Detach "node" from its parent. node.parent = None - if node.left is child: - # "split_node" was located in the left subtree of "node". - # This implies that "node" must be joined to the right tree. - rtree = self._join(rtree, node, node.right) + if len(node.childs) == 3: + if node.childs[0] is child: + # "node" has 3 child nodes. + # Its left subtree has already been split. + # Turn it into a 2-node and join it to the right tree. + node.childs.pop(0) + self._repair_node(node) + rtree = self._join(rtree, node) + elif node.childs[2] is child: + # "node" has 3 child nodes. + # Its right subtree has already been split. + # Turn it into a 2-node and join it to the left tree. + node.childs.pop() + self._repair_node(node) + if ltree is None: + ltree = node + else: + ltree = self._join(node, ltree) + else: + # "node has 3 child nodes. + # Its middle subtree has already been split. + # Join its left child to the left tree, and its right + # child to the right tree, then delete "node". + node.childs[0].parent = None + node.childs[2].parent = None + if ltree is None: + ltree = node.childs[0] + else: + ltree = self._join(node.childs[0], ltree) + rtree = self._join(rtree, node.childs[2]) + + elif node.childs[0] is child: + # "node" has 2 child nodes. + # Its left subtree has already been split. + # Join its right child to the right tree, then delete "node". + node.childs[1].parent = None + rtree = self._join(rtree, node.childs[1]) else: - # "split_node" was located in the right subtree of "node". - # This implies that "node" must be joined to the right tree. - assert node.right is child - ltree = self._join(node.left, node, ltree) + # "node" has 2 child nodes. + # Its right subtree has already been split. + # Join its left child to the left tree, then delete "node". + node.childs[0].parent = None + if ltree is None: + ltree = node.childs[0] + else: + ltree = self._join(node.childs[0], ltree) assert ltree is not None return (ltree, rtree) diff --git a/python/tests/test_datastruct.py b/python/tests/test_datastruct.py index 60dd5ea..0dc84aa 100644 --- a/python/tests/test_datastruct.py +++ b/python/tests/test_datastruct.py @@ -3,11 +3,11 @@ import random import unittest -from mwmatching.datastruct import UnionFindQueue, PriorityQueue +from mwmatching.datastruct import ConcatenableQueue, PriorityQueue -class TestUnionFindQueue(unittest.TestCase): - """Test UnionFindQueue.""" +class TestConcatenableQueue(unittest.TestCase): + """Test ConcatenableQueue.""" def _check_tree(self, queue): """Check tree balancing rules and priority info.""" @@ -20,40 +20,33 @@ class TestUnionFindQueue(unittest.TestCase): node = nodes.pop() - if node.left is not None: - self.assertIs(node.left.parent, node) - nodes.append(node.left) - - if node.right is not None: - self.assertIs(node.right.parent, node) - nodes.append(node.right) - if node is not queue.tree: self.assertIsNone(node.owner) - lh = 0 if node.left is None else node.left.height - rh = 0 if node.right is None else node.right.height - self.assertEqual(node.height, 1 + max(lh, rh)) - - self.assertLessEqual(lh, rh + 1) - self.assertLessEqual(rh, lh + 1) - - best_node = {node} - best_prio = node.prio - for child in (node.left, node.right): - if child is not None: - if child.min_node.prio < best_prio: - best_prio = child.min_node.prio + if node.height == 0: + self.assertEqual(len(node.childs), 0) + self.assertIs(node.min_node, node) + else: + self.assertIn(len(node.childs), (2, 3)) + best_node = set() + best_prio = None + for child in node.childs: + self.assertIs(child.parent, node) + self.assertEqual(child.height, node.height - 1) + nodes.append(child) + if ((best_prio is None) + or (child.min_node.prio < best_prio)): best_node = {child.min_node} + best_prio = child.min_node.prio elif child.min_node.prio == best_prio: best_node.add(child.min_node) - self.assertEqual(node.min_node.prio, best_prio) - self.assertIn(node.min_node, best_node) + self.assertEqual(node.min_node.prio, best_prio) + self.assertIn(node.min_node, best_node) def test_single(self): """Single element.""" - q = UnionFindQueue("Q") + q = ConcatenableQueue("Q") with self.assertRaises(Exception): q.min_prio() @@ -62,7 +55,7 @@ class TestUnionFindQueue(unittest.TestCase): q.min_elem() n = q.insert("a", 4) - self.assertIsInstance(n, UnionFindQueue.Node) + self.assertIsInstance(n, ConcatenableQueue.Node) self._check_tree(q) @@ -84,22 +77,22 @@ class TestUnionFindQueue(unittest.TestCase): def test_simple(self): """Simple test, 5 elements.""" - q1 = UnionFindQueue("A") + q1 = ConcatenableQueue("A") n1 = q1.insert("a", 5) - q2 = UnionFindQueue("B") + q2 = ConcatenableQueue("B") n2 = q2.insert("b", 6) - q3 = UnionFindQueue("C") + q3 = ConcatenableQueue("C") n3 = q3.insert("c", 7) - q4 = UnionFindQueue("D") + q4 = ConcatenableQueue("D") n4 = q4.insert("d", 4) - q5 = UnionFindQueue("E") + q5 = ConcatenableQueue("E") n5 = q5.insert("e", 3) - q345 = UnionFindQueue("P") + q345 = ConcatenableQueue("P") q345.merge([q3, q4, q5]) self._check_tree(q345) @@ -120,7 +113,7 @@ class TestUnionFindQueue(unittest.TestCase): self.assertEqual(q345.min_prio(), 4) self.assertEqual(q345.min_elem(), "d") - q12 = UnionFindQueue("Q") + q12 = ConcatenableQueue("Q") q12.merge([q1, q2]) self._check_tree(q12) @@ -129,7 +122,7 @@ class TestUnionFindQueue(unittest.TestCase): self.assertEqual(q12.min_prio(), 5) self.assertEqual(q12.min_elem(), "a") - q12345 = UnionFindQueue("R") + q12345 = ConcatenableQueue("R") q12345.merge([q12, q345]) self._check_tree(q12345) @@ -201,30 +194,30 @@ class TestUnionFindQueue(unittest.TestCase): queues = [] nodes = [] for i in range(14): - q = UnionFindQueue(chr(ord("A") + i)) + q = ConcatenableQueue(chr(ord("A") + i)) n = q.insert(chr(ord("a") + i), prios[i]) queues.append(q) nodes.append(n) - q = UnionFindQueue("AB") + q = ConcatenableQueue("AB") q.merge(queues[0:2]) queues.append(q) self._check_tree(q) self.assertEqual(q.min_prio(), min(prios[0:2])) - q = UnionFindQueue("CDE") + q = ConcatenableQueue("CDE") q.merge(queues[2:5]) queues.append(q) self._check_tree(q) self.assertEqual(q.min_prio(), min(prios[2:5])) - q = UnionFindQueue("FGHI") + q = ConcatenableQueue("FGHI") q.merge(queues[5:9]) queues.append(q) self._check_tree(q) self.assertEqual(q.min_prio(), min(prios[5:9])) - q = UnionFindQueue("JKLMN") + q = ConcatenableQueue("JKLMN") q.merge(queues[9:14]) queues.append(q) self._check_tree(q) @@ -239,7 +232,7 @@ class TestUnionFindQueue(unittest.TestCase): for i in range(9, 14): self.assertEqual(nodes[i].find(), "JKLMN") - q = UnionFindQueue("ALL") + q = ConcatenableQueue("ALL") q.merge(queues[14:18]) queues.append(q) self._check_tree(q) @@ -298,7 +291,7 @@ class TestUnionFindQueue(unittest.TestCase): for i in range(4000): name = f"q{i}" - q = UnionFindQueue(name) + q = ConcatenableQueue(name) p = rng.random() n = q.insert(f"n{i}", p) nodes.append(n) @@ -326,7 +319,7 @@ class TestUnionFindQueue(unittest.TestCase): subs = rng.sample(sorted(live_queues), k) name = f"Q{i}" - q = UnionFindQueue(name) + q = ConcatenableQueue(name) q.merge([queues[nn] for nn in subs]) self._check_tree(q) queues[name] = q