diff --git a/python/datastruct.py b/python/datastruct.py new file mode 100644 index 0000000..9c71d72 --- /dev/null +++ b/python/datastruct.py @@ -0,0 +1,811 @@ +""" +Data structures for matching. + +Experimental. +""" + +from typing import Generic, Optional, TypeVar + + +_NameT = TypeVar("_NameT") +_NameT2 = TypeVar("_NameT2") +_ElemT = TypeVar("_ElemT") +_ElemT2 = TypeVar("_ElemT2") + + +class UnionFindQueue(Generic[_NameT, _ElemT]): + """Combination of disjoint set and priority queue. + + A queue has a "name", which can be any Python object. + + Each element has associated "data", which can be any Python object. + Each element has a priority. + + The following operations can be done efficiently: + - Create a new queue containing one new element. + - Find the name of the queue that contains a given element. + - Change the priority of a given element. + - Find the element with lowest priority in a given queue. + - Merge two or more queues. + - Undo a previous merge step. + + The implementation is essentially an AVL tree, with minimum-priority + tracking added to it. + """ + + class Node(Generic[_NameT2, _ElemT2]): + """Node in a UnionFindQueue.""" + + def __init__(self, + owner: "UnionFindQueue[_NameT2, _ElemT2]", + data: _ElemT2, + prio: float + ) -> None: + """Initialize a new element. + + This method should not be called directly. + Instead, call UnionFindQueue.insert(). + """ + self.owner: "Optional[UnionFindQueue[_NameT2, _ElemT2]]" = owner + self.data = data + self.prio = prio + self.min_node = self + self.height = 1 + self.parent: "Optional[UnionFindQueue.Node[_NameT2, _ElemT2]]" + self.left: "Optional[UnionFindQueue.Node[_NameT2, _ElemT2]]" + self.right: "Optional[UnionFindQueue.Node[_NameT2, _ElemT2]]" + self.parent = None + self.left = None + self.right = None + + def find(self) -> _NameT2: + """Return the name of the queue that contains this element. + + This function takes time O(log(n)). + """ + node = self + while node.parent is not None: + node = node.parent + assert node.owner is not None + return node.owner.name + + def set_prio(self, prio: float) -> None: + """Change the priority of this element.""" + self.prio = prio + node = self + while True: + min_node = node + if node.left is not None: + left_min_node = node.left.min_node + if left_min_node.prio < min_node.prio: + min_node = left_min_node + if node.right is not None: + right_min_node = node.right.min_node + if right_min_node.prio < min_node.prio: + min_node = right_min_node + node.min_node = min_node + if node.parent is None: + break + node = node.parent + + def __init__(self, name: _NameT) -> None: + """Initialize an empty queue. + + This function takes time O(1). + + Parameters: + name: Name to assign to the new queue. + """ + self.name = name + self.tree: "Optional[UnionFindQueue.Node[_NameT, _ElemT]]" = None + self.sub_queues: "list[UnionFindQueue[_NameT, _ElemT]]" = [] + self.split_nodes: "list[UnionFindQueue.Node[_NameT, _ElemT]]" = [] + + def clear(self) -> None: + """Remove all elements from the queue. + + This function takes time O(n). + """ + node = self.tree + self.tree = None + self.sub_queues = [] + self.split_nodes.clear() + + # Wipe pointers to enable refcounted garbage collection. + while node is not None: + prev_node = node + if node.left is not None: + node = node.left + prev_node.left = None + elif node.right is not None: + node = node.right + prev_node.right = None + else: + node = node.parent + prev_node.parent = None + + def insert(self, elem: _ElemT, prio: float) -> Node[_NameT, _ElemT]: + """Insert an element into the empty queue. + + This function can only be used if the queue is empty. + Non-empty queues can grow only by merging. + + This function takes time O(1). + + Parameters: + elem: Element to insert. + prio: Initial priority of the new element. + """ + assert self.tree is None + self.tree = UnionFindQueue.Node(self, elem, prio) + return self.tree + + def min_prio(self) -> float: + """Return the minimum priority of any element in the queue. + + The queue must be non-empty. + This function takes time O(1). + """ + node = self.tree + assert node is not None + return node.min_node.prio + + def min_elem(self) -> _ElemT: + """Return the element with minimum priority. + + The queue must be non-empty. + This function takes time O(1). + """ + node = self.tree + assert node is not None + return node.min_node.data + + def merge(self, sub_queues: "list[UnionFindQueue[_NameT, _ElemT]]") -> None: + """Merge the specified queues. + + This queue must inititially be empty. + All specified sub-queues must initially be non-empty. + + This function removes all elements from the specified sub-queues + and adds them to this queue. + + After merging, this queue retains a reference to the list of sub-queues. + + This function takes time O(len(sub_queues) * log(n)). + """ + assert self.tree is None + assert not self.sub_queues + assert not self.split_nodes + assert sub_queues + + # Keep the list of sub-queues. + self.sub_queues = sub_queues + + # Move the root node from the first sub-queue to this queue. + # Clear its owner pointer. + self.tree = sub_queues[0].tree + assert self.tree is not None + sub_queues[0].tree = None + self.tree.owner = None + + # Merge remaining sub-queues. + for sub in sub_queues[1:]: + + # Pull the root node from the sub-queue. + # Clear its owner pointer. + subtree = sub.tree + assert subtree is not None + assert subtree.owner is sub + subtree.owner = None + + # Merge our current tree with the tree from the sub-queue. + (self.tree, split_node) = self._merge_tree(self.tree, subtree) + + # Keep track of the left-most node from the sub-queue. + self.split_nodes.append(split_node) + + # Put the owner pointer in the root node. + self.tree.owner = self + + def split(self) -> None: + """Undo the merge step that filled this queue. + + Remove all elements from this queue and put them back in + the sub-queues from which they came. + + After splitting, this queue will be empty. + + This function takes time O(k * log(n)). + """ + assert self.tree is not None + assert self.sub_queues + + # Clear the owner pointer from the root node. + assert self.tree.owner is self + self.tree.owner = None + + # Split the tree to reconstruct each sub-queue. + for (sub, split_node) in zip(self.sub_queues[:0:-1], + self.split_nodes[::-1]): + + (tree, rtree) = self._split_tree(split_node) + + # Assign the right tree to the sub-queue. + sub.tree = rtree + rtree.owner = sub + + # Put the remaining tree in the first sub-queue. + self.sub_queues[0].tree = tree + tree.owner = self.sub_queues[0] + + # Make this queue empty. + self.tree = None + self.sub_queues = [] + self.split_nodes.clear() + + @staticmethod + def _repair_node(node: Node[_NameT, _ElemT]) -> None: + """Recalculate the height and min-priority information of the + specified node. + """ + + # Repair node height. + lh = 0 if node.left is None else node.left.height + rh = 0 if node.right is None else node.right.height + node.height = 1 + max(lh, rh) + + # Repair min-priority. + min_node = node + if node.left is not None: + left_min_node = node.left.min_node + if left_min_node.prio < min_node.prio: + min_node = left_min_node + if node.right is not None: + right_min_node = node.right.min_node + if right_min_node.prio < min_node.prio: + min_node = right_min_node + node.min_node = min_node + + def _rotate_left(self, node: Node[_NameT, _ElemT]) -> Node[_NameT, _ElemT]: + """Rotate the specified subtree to the left. + + Return the new root node of the subtree. + """ + # + # N C + # / \ / \ + # A C ---> N D + # / \ / \ + # B D A B + # + parent = node.parent + new_top = node.right + assert new_top is not None + + node.right = new_top.left + if node.right is not None: + node.right.parent = node + + new_top.left = node + new_top.parent = parent + node.parent = new_top + + if parent is not None: + if parent.left is node: + parent.left = new_top + elif parent.right is node: + parent.right = new_top + + self._repair_node(node) + self._repair_node(new_top) + + return new_top + + def _rotate_right(self, node: Node[_NameT, _ElemT]) -> Node[_NameT, _ElemT]: + """Rotate the specified node to the right. + + Return the new root node of the subtree. + """ + # + # N A + # / \ / \ + # A D ---> B N + # / \ / \ + # B C C D + # + parent = node.parent + new_top = node.left + assert new_top is not None + + node.left = new_top.right + if node.left is not None: + node.left.parent = node + + new_top.right = node + new_top.parent = parent + node.parent = new_top + + if parent is not None: + if parent.left is node: + parent.left = new_top + elif parent.right is node: + parent.right = new_top + + self._repair_node(node) + self._repair_node(new_top) + + return new_top + + def _rebalance_up(self, node: Node[_NameT, _ElemT]) -> Node[_NameT, _ElemT]: + """Repair and rebalance the specified node and its ancestors. + + Return the root node of the rebalanced tree. + """ + + # Walk up to the root of the tree. + while True: + + lh = 0 if node.left is None else node.left.height + rh = 0 if node.right is None else node.right.height + + if lh > rh + 1: + # This node is left-heavy. Rotate right to rebalance. + # + # N L + # / \ / \ + # L \ / N + # / \ \ ---> / / \ + # A B \ A B \ + # \ \ + # R R + # + lchild = node.left + assert lchild is not None + if ((lchild.right is not None) + and ((lchild.left is None) + or (lchild.right.height > lchild.left.height))): + # Double rotation. + lchild = self._rotate_left(lchild) + + node = self._rotate_right(node) + + elif lh + 1 < rh: + # This node is right-heavy. Rotate left to rebalance. + # + # N R + # / \ / \ + # / R N \ + # / / \ ---> / \ \ + # / A B / A B + # / / + # L L + # + rchild = node.right + assert rchild is not None + if ((rchild.left is not None) + and ((rchild.right is None) + or (rchild.left.height > rchild.right.height))): + # Double rotation. + rchild = self._rotate_right(rchild) + + node = self._rotate_left(node) + + else: + # No rotation. Must still repair node though. + self._repair_node(node) + + if node.parent is None: + break + + # Continue rebalancing at the parent. + node = node.parent + + # Return new root node. + return node + + def _join_right(self, + ltree: Node[_NameT, _ElemT], + node: Node[_NameT, _ElemT], + rtree: Optional[Node[_NameT, _ElemT]] + ) -> Node[_NameT, _ElemT]: + """Join a left subtree, middle node and right subtree together. + + The left subtree must be higher than the right subtree. + + Return the root node of the joined tree. + """ + lh = ltree.height + rh = 0 if rtree is None else rtree.height + assert lh > rh + 1 + + # Descend down the right spine of "ltree". + # Stop at a node with compatible height, then insert "node" + # and attach "rtree". + # + # ltree + # / \ + # X + # / \ + # X <-- cur + # / \ + # node + # / \ + # X rtree + # + + # Descend to a point with compatible height. + cur = ltree + while (cur.right is not None) and (cur.right.height > rh + 1): + cur = cur.right + + # Insert "node" and "rtree". + node.left = cur.right + node.right = rtree + if node.left is not None: + node.left.parent = node + if rtree is not None: + rtree.parent = node + cur.right = node + node.parent = cur + + # A double rotation may be necessary. + if (cur.left is None) or (cur.left.height <= rh): + node = self._rotate_right(node) + cur = self._rotate_left(cur) + else: + self._repair_node(node) + self._repair_node(cur) + + # Ascend from "cur" to the root of the tree. + # Repair and/or rotate as needed. + while cur.parent is not None: + cur = cur.parent + assert cur.left is not None + assert cur.right is not None + + if cur.left.height + 1 < cur.right.height: + cur = self._rotate_left(cur) + else: + self._repair_node(cur) + + return cur + + def _join_left(self, + ltree: Optional[Node[_NameT, _ElemT]], + node: Node[_NameT, _ElemT], + rtree: Node[_NameT, _ElemT] + ) -> Node[_NameT, _ElemT]: + """Join a left subtree, middle node and right subtree together. + + The right subtree must be higher than the left subtree. + + Return the root node of the joined tree. + """ + lh = 0 if ltree is None else ltree.height + rh = rtree.height + assert lh + 1 < rh + + # Descend down the left spine of "rtree". + # Stop at a node with compatible height, then insert "node" + # and attach "ltree". + # + # rtree + # / \ + # X + # / \ + # cur --> X + # / \ + # node + # / \ + # ltree X + # + + # Descend to a point with compatible height. + cur = rtree + while (cur.left is not None) and (cur.left.height > lh + 1): + cur = cur.left + + # Insert "node" and "ltree". + node.left = ltree + node.right = cur.left + if ltree is not None: + ltree.parent = node + if node.right is not None: + node.right.parent = node + cur.left = node + node.parent = cur + + # A double rotation may be necessary. + if (cur.right is None) or (cur.right.height <= lh): + node = self._rotate_left(node) + cur = self._rotate_right(cur) + else: + self._repair_node(node) + self._repair_node(cur) + + # Ascend from "cur" to the root of the tree. + # Repair and/or rotate as needed. + while cur.parent is not None: + cur = cur.parent + assert cur.left is not None + assert cur.right is not None + + if cur.left.height > cur.right.height + 1: + cur = self._rotate_right(cur) + else: + self._repair_node(cur) + + return cur + + def _join(self, + ltree: Optional[Node[_NameT, _ElemT]], + node: Node[_NameT, _ElemT], + rtree: Optional[Node[_NameT, _ElemT]] + ) -> Node[_NameT, _ElemT]: + """Join a left subtree, middle node and right subtree together. + + The left or right subtree may initially be a child of the middle + node; such links will be broken as needed. + + The left and right subtrees must be consistent, AVL-balanced trees. + Parent pointers of the subtrees are ignored. + + The middle node is considered as a single node. + Its parent and child pointers are ignored. + + Return the root node of the joined tree. + """ + lh = 0 if ltree is None else ltree.height + rh = 0 if rtree is None else rtree.height + + if lh > rh + 1: + assert ltree is not None + ltree.parent = None + return self._join_right(ltree, node, rtree) + elif lh + 1 < rh: + assert rtree is not None + rtree.parent = None + return self._join_left(ltree, node, rtree) + else: + # Subtree heights are compatible. Just join them. + # + # node + # / \ + # ltree rtree + # / \ / \ + # + node.parent = None + node.left = ltree + if ltree is not None: + ltree.parent = node + node.right = rtree + if rtree is not None: + rtree.parent = node + self._repair_node(node) + return node + + def _merge_tree(self, + ltree: Node[_NameT, _ElemT], + rtree: Node[_NameT, _ElemT] + ) -> tuple[Node[_NameT, _ElemT], Node[_NameT, _ElemT]]: + """Merge two trees. + + Return a tuple (split_node, merged_tree). + """ + + # Find the left-most node of the right tree. + split_node = rtree + while split_node.left is not None: + split_node = split_node.left + + # Delete the split_node from its tree. + parent = split_node.parent + if split_node.right is not None: + split_node.right.parent = parent + if parent is None: + rtree_new = split_node.right + else: + # Repair and rebalance the ancestors of split_node. + parent.left = split_node.right + rtree_new = self._rebalance_up(parent) + + # Join the two trees via the split_node. + merged_tree = self._join(ltree, split_node, rtree_new) + + return (merged_tree, split_node) + + def _split_tree(self, + split_node: Node[_NameT, _ElemT] + ) -> tuple[Node[_NameT, _ElemT], Node[_NameT, _ElemT]]: + """Split a tree on a specified node. + + Two new trees will be constructed. + All nodes to the left of "split_node" will go to the left tree. + All nodes to the right of "split_node", and "split_node" itself, + will go to the right tree. + + Return tuple (ltree, rtree), + where ltree contains all nodes left of the split-node, + rtree contains the split-nodes and all nodes to its right. + """ + + # Assign the descendants of "split_node" to the appropriate trees + # and detach them from "split_node". + ltree = split_node.left + rtree = split_node.right + + split_node.left = None + split_node.right = None + if ltree is not None: + ltree.parent = None + if rtree is not None: + rtree.parent = None + + # Detach "split_node" from its parent (if any). + parent = split_node.parent + split_node.parent = None + + # Assign "split_node" to the right tree. + rtree = self._join(None, split_node, rtree) + + # Walk up to the root of the tree. + # On the way up, detach each node from its parent and join it, + # and its descendants, to the appropriate tree. + node = split_node + while parent is not None: + + # Ascend to the parent node. + child = node + node = parent + parent = node.parent + + # Detach "node" from its parent. + node.parent = None + + if node.left is child: + # "split_node" was located in the left subtree of "node". + # This implies that "node" must be joined to the right tree. + rtree = self._join(rtree, node, node.right) + + else: + # "split_node" was located in the right subtree of "node". + # This implies that "node" must be joined to the right tree. + assert node.right is child + ltree = self._join(node.left, node, ltree) + + assert ltree is not None + return (ltree, rtree) + + +class PriorityQueue(Generic[_ElemT]): + """Priority queue based on a binary heap.""" + + class Node(Generic[_ElemT2]): + """Node in the priority queue.""" + + __slots__ = ("index", "prio", "data") + + def __init__( + self, + index: int, + prio: float, + data: _ElemT2 + ) -> None: + self.index = index + self.prio = prio + self.data = data + + def __init__(self) -> None: + """Initialize an empty queue.""" + self.heap: "list[PriorityQueue.Node[_ElemT]]" = [] + + def clear(self) -> None: + """Remove all elements from the queue. + + This function takes time O(n). + """ + self.heap.clear() + + def empty(self) -> bool: + """Return True if the queue is empty.""" + return (not self.heap) + + def find_min(self) -> Node[_ElemT]: + """Return the minimum-priority node. + + This function takes time O(1). + """ + if not self.heap: + raise IndexError("Queue is empty") + return self.heap[0] + + def _sift_up(self, index: int) -> None: + """Repair the heap along an ascending path to the root.""" + node = self.heap[index] + prio = node.prio + + pos = index + while pos > 0: + tpos = (pos - 1) // 2 + tnode = self.heap[tpos] + if tnode.prio <= prio: + break + tnode.index = pos + self.heap[pos] = tnode + pos = tpos + + if pos != index: + node.index = pos + self.heap[pos] = node + + def _sift_down(self, index: int) -> None: + """Repair the heap along a descending path.""" + num_elem = len(self.heap) + node = self.heap[index] + prio = node.prio + + pos = index + while True: + tpos = 2 * pos + 1 + if tpos >= num_elem: + break + tnode = self.heap[tpos] + + qpos = tpos + 1 + if qpos < num_elem: + qnode = self.heap[qpos] + if qnode.prio <= tnode.prio: + tpos = qpos + tnode = qnode + + if tnode.prio >= prio: + break + + tnode.index = pos + self.heap[pos] = tnode + pos = tpos + + if pos != index: + node.index = pos + self.heap[pos] = node + + def insert(self, prio: float, data: _ElemT) -> Node: + """Insert a new element into the queue. + + This function takes time O(log(n)). + + Returns: + Node that represents the new element. + """ + new_index = len(self.heap) + node = self.Node(new_index, prio, data) + self.heap.append(node) + self._sift_up(new_index) + return node + + def delete(self, elem: Node[_ElemT]) -> None: + """Delete the specified element from the queue. + + This function takes time O(log(n)). + """ + index = elem.index + assert self.heap[index] is elem + + node = self.heap.pop() + if index < len(self.heap): + node.index = index + self.heap[index] = node + if node.prio < elem.prio: + self._sift_up(index) + elif node.prio > elem.prio: + self._sift_down(index) + + def decrease_prio(self, elem: Node[_ElemT], prio: float) -> None: + """Decrease the priority of an existing element in the queue. + + This function takes time O(log(n)). + """ + assert self.heap[elem.index] is elem + assert prio <= elem.prio + elem.prio = prio + self._sift_up(elem.index) diff --git a/python/test_datastruct.py b/python/test_datastruct.py new file mode 100644 index 0000000..8b1f00e --- /dev/null +++ b/python/test_datastruct.py @@ -0,0 +1,491 @@ +"""Unit tests for data structures.""" + +import random +import unittest + +from datastruct import UnionFindQueue, PriorityQueue + + +class TestUnionFindQueue(unittest.TestCase): + """Test UnionFindQueue.""" + + def _check_tree(self, queue): + """Check tree balancing rules and priority info.""" + + self.assertIsNone(queue.tree.parent) + self.assertIs(queue.tree.owner, queue) + + nodes = [queue.tree] + while nodes: + + node = nodes.pop() + + if node.left is not None: + self.assertIs(node.left.parent, node) + nodes.append(node.left) + + if node.right is not None: + self.assertIs(node.right.parent, node) + nodes.append(node.right) + + if node is not queue.tree: + self.assertIsNone(node.owner) + + lh = 0 if node.left is None else node.left.height + rh = 0 if node.right is None else node.right.height + self.assertEqual(node.height, 1 + max(lh, rh)) + + self.assertLessEqual(lh, rh + 1) + self.assertLessEqual(rh, lh + 1) + + best_node = {node} + best_prio = node.prio + for child in (node.left, node.right): + if child is not None: + if child.min_node.prio < best_prio: + best_prio = child.min_node.prio + best_node = {child.min_node} + elif child.min_node.prio == best_prio: + best_node.add(child.min_node) + + self.assertEqual(node.min_node.prio, best_prio) + self.assertIn(node.min_node, best_node) + + def test_single(self): + """Single element.""" + q = UnionFindQueue("Q") + + with self.assertRaises(Exception): + q.min_prio() + + with self.assertRaises(Exception): + q.min_elem() + + n = q.insert("a", 4) + self.assertIsInstance(n, UnionFindQueue.Node) + + self._check_tree(q) + + self.assertEqual(n.find(), "Q") + self.assertEqual(q.min_prio(), 4) + self.assertEqual(q.min_elem(), "a") + + with self.assertRaises(Exception): + q.insert("x", 1) + + n.set_prio(8) + self._check_tree(q) + + self.assertEqual(n.find(), "Q") + self.assertEqual(q.min_prio(), 8) + self.assertEqual(q.min_elem(), "a") + + q.clear() + + def test_simple(self): + """Simple test, 5 elements.""" + q1 = UnionFindQueue("A") + n1 = q1.insert("a", 5) + + q2 = UnionFindQueue("B") + n2 = q2.insert("b", 6) + + q3 = UnionFindQueue("C") + n3 = q3.insert("c", 7) + + q4 = UnionFindQueue("D") + n4 = q4.insert("d", 4) + + q5 = UnionFindQueue("E") + n5 = q5.insert("e", 3) + + q345 = UnionFindQueue("P") + q345.merge([q3, q4, q5]) + self._check_tree(q345) + + self.assertEqual(n1.find(), "A") + self.assertEqual(n3.find(), "P") + self.assertEqual(n4.find(), "P") + self.assertEqual(n5.find(), "P") + self.assertEqual(q345.min_prio(), 3) + self.assertEqual(q345.min_elem(), "e") + + with self.assertRaises(Exception): + q3.min_prio() + + self._check_tree(q345) + n5.set_prio(6) + self._check_tree(q345) + + self.assertEqual(q345.min_prio(), 4) + self.assertEqual(q345.min_elem(), "d") + + q12 = UnionFindQueue("Q") + q12.merge([q1, q2]) + self._check_tree(q12) + + self.assertEqual(n1.find(), "Q") + self.assertEqual(n2.find(), "Q") + self.assertEqual(q12.min_prio(), 5) + self.assertEqual(q12.min_elem(), "a") + + q12345 = UnionFindQueue("R") + q12345.merge([q12, q345]) + self._check_tree(q12345) + + self.assertEqual(n1.find(), "R") + self.assertEqual(n2.find(), "R") + self.assertEqual(n3.find(), "R") + self.assertEqual(n4.find(), "R") + self.assertEqual(n5.find(), "R") + self.assertEqual(q12345.min_prio(), 4) + self.assertEqual(q12345.min_elem(), "d") + + n4.set_prio(8) + self._check_tree(q12345) + + self.assertEqual(q12345.min_prio(), 5) + self.assertEqual(q12345.min_elem(), "a") + + n3.set_prio(2) + self._check_tree(q12345) + + self.assertEqual(q12345.min_prio(), 2) + self.assertEqual(q12345.min_elem(), "c") + + q12345.split() + self._check_tree(q12) + self._check_tree(q345) + + self.assertEqual(n1.find(), "Q") + self.assertEqual(n2.find(), "Q") + self.assertEqual(n3.find(), "P") + self.assertEqual(n4.find(), "P") + self.assertEqual(n5.find(), "P") + self.assertEqual(q12.min_prio(), 5) + self.assertEqual(q12.min_elem(), "a") + self.assertEqual(q345.min_prio(), 2) + self.assertEqual(q345.min_elem(), "c") + + q12.split() + self._check_tree(q1) + self._check_tree(q2) + + q345.split() + self._check_tree(q3) + self._check_tree(q4) + self._check_tree(q5) + + self.assertEqual(n1.find(), "A") + self.assertEqual(n2.find(), "B") + self.assertEqual(n3.find(), "C") + self.assertEqual(n4.find(), "D") + self.assertEqual(n5.find(), "E") + self.assertEqual(q3.min_prio(), 2) + self.assertEqual(q3.min_elem(), "c") + + q1.clear() + q2.clear() + q3.clear() + q4.clear() + q5.clear() + q12.clear() + q345.clear() + q12345.clear() + + def test_medium(self): + """Medium test, 14 elements.""" + + prios = [3, 8, 6, 2, 9, 4, 6, 8, 1, 5, 9, 4, 7, 8] + + queues = [] + nodes = [] + for i in range(14): + q = UnionFindQueue(chr(ord("A") + i)) + n = q.insert(chr(ord("a") + i), prios[i]) + queues.append(q) + nodes.append(n) + + q = UnionFindQueue("AB") + q.merge(queues[0:2]) + queues.append(q) + self._check_tree(q) + self.assertEqual(q.min_prio(), min(prios[0:2])) + + q = UnionFindQueue("CDE") + q.merge(queues[2:5]) + queues.append(q) + self._check_tree(q) + self.assertEqual(q.min_prio(), min(prios[2:5])) + + q = UnionFindQueue("FGHI") + q.merge(queues[5:9]) + queues.append(q) + self._check_tree(q) + self.assertEqual(q.min_prio(), min(prios[5:9])) + + q = UnionFindQueue("JKLMN") + q.merge(queues[9:14]) + queues.append(q) + self._check_tree(q) + self.assertEqual(q.min_prio(), min(prios[9:14])) + + for i in range(0, 2): + self.assertEqual(nodes[i].find(), "AB") + for i in range(2, 5): + self.assertEqual(nodes[i].find(), "CDE") + for i in range(5, 9): + self.assertEqual(nodes[i].find(), "FGHI") + for i in range(9, 14): + self.assertEqual(nodes[i].find(), "JKLMN") + + q = UnionFindQueue("ALL") + q.merge(queues[14:18]) + queues.append(q) + self._check_tree(q) + self.assertEqual(q.min_prio(), 1) + self.assertEqual(q.min_elem(), "i") + + for i in range(14): + self.assertEqual(nodes[i].find(), "ALL") + + prios[8] = 5 + nodes[8].set_prio(prios[8]) + self.assertEqual(q.min_prio(), 2) + self.assertEqual(q.min_elem(), "d") + + q.split() + + for i in range(0, 2): + self.assertEqual(nodes[i].find(), "AB") + for i in range(2, 5): + self.assertEqual(nodes[i].find(), "CDE") + for i in range(5, 9): + self.assertEqual(nodes[i].find(), "FGHI") + for i in range(9, 14): + self.assertEqual(nodes[i].find(), "JKLMN") + + self.assertEqual(queues[14].min_prio(), min(prios[0:2])) + self.assertEqual(queues[15].min_prio(), min(prios[2:5])) + self.assertEqual(queues[16].min_prio(), min(prios[5:9])) + self.assertEqual(queues[17].min_prio(), min(prios[9:14])) + + for q in queues[14:18]: + self._check_tree(q) + q.split() + + for i in range(14): + self._check_tree(queues[i]) + self.assertEqual(nodes[i].find(), chr(ord("A") + i)) + self.assertEqual(queues[i].min_prio(), prios[i]) + self.assertEqual(queues[i].min_elem(), chr(ord("a") + i)) + + for q in queues: + q.clear() + + def test_random(self): + """Pseudo-random test.""" + + rng = random.Random(23456) + + nodes = [] + prios = [] + queues = {} + queue_nodes = {} + queue_subs = {} + live_queues = set() + live_merged_queues = set() + + for i in range(4000): + name = f"q{i}" + q = UnionFindQueue(name) + p = rng.random() + n = q.insert(f"n{i}", p) + nodes.append(n) + prios.append(p) + queues[name] = q + queue_nodes[name] = {i} + live_queues.add(name) + + for i in range(2000): + + for k in range(10): + t = rng.randint(0, len(nodes) - 1) + name = nodes[t].find() + self.assertIn(name, live_queues) + self.assertIn(t, queue_nodes[name]) + p = rng.random() + prios[t] = p + nodes[t].set_prio(p) + pp = min(prios[tt] for tt in queue_nodes[name]) + tt = prios.index(pp) + self.assertEqual(queues[name].min_prio(), pp) + self.assertEqual(queues[name].min_elem(), f"n{tt}") + + k = rng.randint(2, max(2, len(live_queues) // 2 - 400)) + subs = rng.sample(sorted(live_queues), k) + + name = f"Q{i}" + q = UnionFindQueue(name) + q.merge([queues[nn] for nn in subs]) + self._check_tree(q) + queues[name] = q + queue_nodes[name] = set().union(*(queue_nodes[nn] for nn in subs)) + queue_subs[name] = set(subs) + live_queues.difference_update(subs) + live_merged_queues.difference_update(subs) + live_queues.add(name) + live_merged_queues.add(name) + + pp = min(prios[tt] for tt in queue_nodes[name]) + tt = prios.index(pp) + self.assertEqual(q.min_prio(), pp) + self.assertEqual(q.min_elem(), f"n{tt}") + + if len(live_merged_queues) >= 100: + name = rng.choice(sorted(live_merged_queues)) + queues[name].split() + + for nn in queue_subs[name]: + self._check_tree(queues[nn]) + pp = min(prios[tt] for tt in queue_nodes[nn]) + tt = prios.index(pp) + self.assertEqual(queues[nn].min_prio(), pp) + self.assertEqual(queues[nn].min_elem(), f"n{tt}") + live_queues.add(nn) + if nn in queue_subs: + live_merged_queues.add(nn) + + live_merged_queues.remove(name) + live_queues.remove(name) + + del queues[name] + del queue_nodes[name] + del queue_subs[name] + + for q in queues.values(): + q.clear() + + +class TestPriorityQueue(unittest.TestCase): + """Test PriorityQueue.""" + + def test_empty(self): + """Empty queue.""" + q = PriorityQueue() + self.assertTrue(q.empty()) + with self.assertRaises(IndexError): + q.find_min() + + def test_single(self): + """Single element.""" + q = PriorityQueue() + + n1 = q.insert(5, "a") + self.assertEqual(n1.prio, 5) + self.assertEqual(n1.data, "a") + self.assertFalse(q.empty()) + self.assertIs(q.find_min(), n1) + + q.decrease_prio(n1, 3) + self.assertEqual(n1.prio, 3) + self.assertIs(q.find_min(), n1) + + q.delete(n1) + self.assertTrue(q.empty()) + + def test_simple(self): + """A few elements.""" + prios = [9, 4, 7, 5, 8, 6, 4, 5, 2, 6] + labels = "abcdefghij" + + q = PriorityQueue() + + elems = [q.insert(prio, data) for (prio, data) in zip(prios, labels)] + for (n, prio, data) in zip(elems, prios, labels): + self.assertEqual(n.prio, prio) + self.assertEqual(n.data, data) + + self.assertIs(q.find_min(), elems[8]) + + q.decrease_prio(elems[2], 1) + self.assertIs(q.find_min(), elems[2]) + + q.decrease_prio(elems[4], 3) + self.assertIs(q.find_min(), elems[2]) + + q.delete(elems[2]) + self.assertIs(q.find_min(), elems[8]) + + q.delete(elems[8]) + self.assertIs(q.find_min(), elems[4]) + + q.delete(elems[4]) + q.delete(elems[1]) + self.assertIs(q.find_min(), elems[6]) + + q.delete(elems[3]) + q.delete(elems[9]) + self.assertIs(q.find_min(), elems[6]) + + q.delete(elems[6]) + self.assertIs(q.find_min(), elems[7]) + + q.delete(elems[7]) + self.assertIs(q.find_min(), elems[5]) + + self.assertFalse(q.empty()) + q.clear() + self.assertTrue(q.empty()) + + def test_random(self): + """Pseudo-random test.""" + rng = random.Random(34567) + + num_elem = 1000 + + seq = 0 + elems = [] + q = PriorityQueue() + + def check(): + min_prio = min(prio for (n, prio, data) in elems) + m = q.find_min() + self.assertIn((m, m.prio, m.data), elems) + self.assertEqual(m.prio, min_prio) + + for i in range(num_elem): + seq += 1 + prio = rng.randint(0, 1000000) + elems.append((q.insert(prio, seq), prio, seq)) + check() + + for i in range(10000): + p = rng.randint(0, num_elem - 1) + prio = rng.randint(0, elems[p][1]) + q.decrease_prio(elems[p][0], prio) + elems[p] = (elems[p][0], prio, elems[p][2]) + check() + + p = rng.randint(0, num_elem - 1) + q.delete(elems[p][0]) + elems.pop(p) + check() + + seq += 1 + prio = rng.randint(0, 1000000) + elems.append((q.insert(prio, seq), prio, seq)) + check() + + for i in range(num_elem): + p = rng.randint(0, num_elem - 1 - i) + q.delete(elems[p][0]) + elems.pop(p) + if elems: + check() + + self.assertTrue(q.empty()) + + +if __name__ == "__main__": + unittest.main()