"""Data structures for matching."""

from __future__ import annotations

from typing import Generic, Optional, TypeVar


_NameT = TypeVar("_NameT")
_NameT2 = TypeVar("_NameT2")
_ElemT = TypeVar("_ElemT")
_ElemT2 = TypeVar("_ElemT2")


class UnionFindQueue(Generic[_NameT, _ElemT]):
    """Combination of disjoint set and priority queue.

    A queue has a "name", which can be any Python object.

    Each element has associated "data", which can be any Python object.
    Each element has a priority.

    The following operations can be done efficiently:
     - Create a new queue containing one new element.
     - Find the name of the queue that contains a given element.
     - Change the priority of a given element.
     - Find the element with lowest priority in a given queue.
     - Merge two or more queues.
     - Undo a previous merge step.

    The implementation is essentially an AVL tree, with minimum-priority
    tracking added to it.
    """

    __slots__ = ("name", "tree", "sub_queues", "split_nodes")

    class Node(Generic[_NameT2, _ElemT2]):
        """Node in a UnionFindQueue."""

        __slots__ = ("owner", "data", "prio", "min_node", "height",
                     "parent", "left", "right")

        def __init__(self,
                     owner: UnionFindQueue[_NameT2, _ElemT2],
                     data: _ElemT2,
                     prio: float
                     ) -> None:
            """Initialize a new element.

            This method should not be called directly.
            Instead, call UnionFindQueue.insert().
            """
            # Only the root node of a tree carries an owner pointer;
            # it refers back to the queue that currently holds the tree.
            self.owner: Optional[UnionFindQueue[_NameT2, _ElemT2]] = owner
            self.data = data
            self.prio = prio
            # Node with minimum priority within the subtree rooted here.
            self.min_node = self
            # Height of the subtree rooted here (a single node has height 1).
            self.height = 1
            self.parent: Optional[UnionFindQueue.Node[_NameT2, _ElemT2]]
            self.left: Optional[UnionFindQueue.Node[_NameT2, _ElemT2]]
            self.right: Optional[UnionFindQueue.Node[_NameT2, _ElemT2]]
            self.parent = None
            self.left = None
            self.right = None

        def find(self) -> _NameT2:
            """Return the name of the queue that contains this element.

            This function takes time O(log(n)).
            """
            node = self
            while node.parent is not None:
                node = node.parent
            assert node.owner is not None
            return node.owner.name

        def set_prio(self, prio: float) -> None:
            """Change the priority of this element.

            The min-priority information is recomputed on the path from
            this node up to the root of its tree.
            """
            self.prio = prio
            node = self
            while True:
                min_node = node
                if node.left is not None:
                    left_min_node = node.left.min_node
                    if left_min_node.prio < min_node.prio:
                        min_node = left_min_node
                if node.right is not None:
                    right_min_node = node.right.min_node
                    if right_min_node.prio < min_node.prio:
                        min_node = right_min_node
                node.min_node = min_node
                if node.parent is None:
                    break
                node = node.parent

    def __init__(self, name: _NameT) -> None:
        """Initialize an empty queue.

        This function takes time O(1).

        Parameters:
            name: Name to assign to the new queue.
        """
        self.name = name
        self.tree: Optional[UnionFindQueue.Node[_NameT, _ElemT]] = None
        self.sub_queues: list[UnionFindQueue[_NameT, _ElemT]] = []
        self.split_nodes: list[UnionFindQueue.Node[_NameT, _ElemT]] = []

    def clear(self) -> None:
        """Remove all elements from the queue.

        This function takes time O(n).
        """
        node = self.tree
        self.tree = None
        self.sub_queues = []
        self.split_nodes.clear()

        # Wipe pointers to enable refcounted garbage collection.
        if node is not None:
            node.owner = None
        while node is not None:
            node.min_node = None  # type: ignore
            prev_node = node
            if node.left is not None:
                node = node.left
                prev_node.left = None
            elif node.right is not None:
                node = node.right
                prev_node.right = None
            else:
                node = node.parent
                prev_node.parent = None

    def insert(self, elem: _ElemT, prio: float) -> Node[_NameT, _ElemT]:
        """Insert an element into the empty queue.

        This function can only be used if the queue is empty.
        Non-empty queues can grow only by merging.

        This function takes time O(1).

        Parameters:
            elem: Element to insert.
            prio: Initial priority of the new element.

        Returns:
            The node that represents the new element.
        """
        assert self.tree is None
        self.tree = UnionFindQueue.Node(self, elem, prio)
        return self.tree

    def min_prio(self) -> float:
        """Return the minimum priority of any element in the queue.

        The queue must be non-empty.
        This function takes time O(1).
        """
        node = self.tree
        assert node is not None
        return node.min_node.prio

    def min_elem(self) -> _ElemT:
        """Return the element with minimum priority.

        The queue must be non-empty.
        This function takes time O(1).
        """
        node = self.tree
        assert node is not None
        return node.min_node.data

    def merge(self, sub_queues: list[UnionFindQueue[_NameT, _ElemT]]) -> None:
        """Merge the specified queues.

        This queue must initially be empty.
        All specified sub-queues must initially be non-empty.

        This function removes all elements from the specified sub-queues
        and adds them to this queue.

        After merging, this queue retains a reference to the list of
        sub-queues.

        This function takes time O(len(sub_queues) * log(n)).
        """
        assert self.tree is None
        assert not self.sub_queues
        assert not self.split_nodes
        assert sub_queues

        # Keep the list of sub-queues.
        self.sub_queues = sub_queues

        # Move the root node from the first sub-queue to this queue.
        # Clear its owner pointer.
        self.tree = sub_queues[0].tree
        assert self.tree is not None
        assert self.tree.owner is sub_queues[0]
        sub_queues[0].tree = None
        self.tree.owner = None

        # Merge remaining sub-queues.
        for sub in sub_queues[1:]:

            # Pull the root node from the sub-queue.
            # Clear its owner pointer.
            subtree = sub.tree
            assert subtree is not None
            assert subtree.owner is sub
            subtree.owner = None

            # Merge our current tree with the tree from the sub-queue.
            (self.tree, split_node) = self._merge_tree(self.tree, subtree)

            # Keep track of the left-most node from the sub-queue;
            # split() uses it to cut the tree apart again.
            self.split_nodes.append(split_node)

        # Put the owner pointer in the root node.
        self.tree.owner = self

    def split(self) -> None:
        """Undo the merge step that filled this queue.

        Remove all elements from this queue and put them back in
        the sub-queues from which they came.

        After splitting, this queue will be empty.

        This function takes time O(k * log(n)),
        where k is the number of sub-queues.
        """
        assert self.tree is not None
        assert self.sub_queues

        # Clear the owner pointer from the root node.
        assert self.tree.owner is self
        self.tree.owner = None

        # Split the tree to reconstruct each sub-queue.
        # Sub-queues are peeled off from the right end, last one first.
        # With a single sub-queue no split is needed: the whole tree
        # goes back to it, so start from the full tree.
        tree = self.tree
        for (sub, split_node) in zip(self.sub_queues[:0:-1],
                                     self.split_nodes[::-1]):

            (tree, rtree) = self._split_tree(split_node)

            # Assign the right tree to the sub-queue.
            sub.tree = rtree
            rtree.owner = sub

        # Put the remaining tree in the first sub-queue.
        self.sub_queues[0].tree = tree
        tree.owner = self.sub_queues[0]

        # Make this queue empty.
        self.tree = None
        self.sub_queues = []
        self.split_nodes.clear()

    @staticmethod
    def _repair_node(node: Node[_NameT, _ElemT]) -> None:
        """Recalculate the height and min-priority information of the
        specified node.
        """

        # Repair node height.
        lh = 0 if node.left is None else node.left.height
        rh = 0 if node.right is None else node.right.height
        node.height = 1 + max(lh, rh)

        # Repair min-priority.
        min_node = node
        if node.left is not None:
            left_min_node = node.left.min_node
            if left_min_node.prio < min_node.prio:
                min_node = left_min_node
        if node.right is not None:
            right_min_node = node.right.min_node
            if right_min_node.prio < min_node.prio:
                min_node = right_min_node
        node.min_node = min_node

    def _rotate_left(self,
                     node: Node[_NameT, _ElemT]
                     ) -> Node[_NameT, _ElemT]:
        """Rotate the specified subtree to the left.

        Return the new root node of the subtree.
        """
        #
        #        N                C
        #       / \              / \
        #      A   C     --->   N   D
        #         / \          / \
        #        B   D        A   B
        #
        parent = node.parent
        new_top = node.right
        assert new_top is not None

        node.right = new_top.left
        if node.right is not None:
            node.right.parent = node

        new_top.left = node
        new_top.parent = parent
        node.parent = new_top

        if parent is not None:
            if parent.left is node:
                parent.left = new_top
            elif parent.right is node:
                parent.right = new_top

        self._repair_node(node)
        self._repair_node(new_top)

        return new_top

    def _rotate_right(self,
                      node: Node[_NameT, _ElemT]
                      ) -> Node[_NameT, _ElemT]:
        """Rotate the specified node to the right.

        Return the new root node of the subtree.
        """
        #
        #        N                A
        #       / \              / \
        #      A   D     --->   B   N
        #     / \                  / \
        #    B   C                C   D
        #
        parent = node.parent
        new_top = node.left
        assert new_top is not None

        node.left = new_top.right
        if node.left is not None:
            node.left.parent = node

        new_top.right = node
        new_top.parent = parent
        node.parent = new_top

        if parent is not None:
            if parent.left is node:
                parent.left = new_top
            elif parent.right is node:
                parent.right = new_top

        self._repair_node(node)
        self._repair_node(new_top)

        return new_top

    def _rebalance_up(self,
                      node: Node[_NameT, _ElemT]
                      ) -> Node[_NameT, _ElemT]:
        """Repair and rebalance the specified node and its ancestors.

        Return the root node of the rebalanced tree.
        """

        # Walk up to the root of the tree.
        while True:

            lh = 0 if node.left is None else node.left.height
            rh = 0 if node.right is None else node.right.height

            if lh > rh + 1:

                # This node is left-heavy. Rotate right to rebalance.
                #
                #             N                  L
                #            / \                / \
                #           L   \              /   N
                #          / \   \    --->    /   / \
                #         A   B   \          A   B   \
                #                  \                  \
                #                   R                  R
                #
                lchild = node.left
                assert lchild is not None
                if ((lchild.right is not None)
                        and ((lchild.left is None)
                             or (lchild.right.height > lchild.left.height))):
                    # Double rotation.
                    lchild = self._rotate_left(lchild)

                node = self._rotate_right(node)

            elif lh + 1 < rh:

                # This node is right-heavy. Rotate left to rebalance.
                #
                #         N                      R
                #        / \                    / \
                #       /   R                  N   \
                #      /   / \      --->      / \   \
                #     /   A   B              /   A   B
                #    /                      /
                #   L                      L
                #
                rchild = node.right
                assert rchild is not None
                if ((rchild.left is not None)
                        and ((rchild.right is None)
                             or (rchild.left.height > rchild.right.height))):
                    # Double rotation.
                    rchild = self._rotate_right(rchild)

                node = self._rotate_left(node)

            else:
                # No rotation. Must still repair node though.
                self._repair_node(node)

            if node.parent is None:
                break

            # Continue rebalancing at the parent.
            node = node.parent

        # Return new root node.
        return node

    def _join_right(self,
                    ltree: Node[_NameT, _ElemT],
                    node: Node[_NameT, _ElemT],
                    rtree: Optional[Node[_NameT, _ElemT]]
                    ) -> Node[_NameT, _ElemT]:
        """Join a left subtree, middle node and right subtree together.

        The left subtree must be higher than the right subtree.

        Return the root node of the joined tree.
        """
        lh = ltree.height
        rh = 0 if rtree is None else rtree.height
        assert lh > rh + 1

        # Descend down the right spine of "ltree".
        # Stop at a node with compatible height, then insert "node"
        # and attach "rtree".
        #
        #          ltree
        #         /     \
        #                X
        #               / \
        #                  X  <-- cur
        #                 / \
        #                    node
        #                   /    \
        #                  X      rtree
        #

        # Descend to a point with compatible height.
        cur = ltree
        while (cur.right is not None) and (cur.right.height > rh + 1):
            cur = cur.right

        # Insert "node" and "rtree".
        node.left = cur.right
        node.right = rtree
        if node.left is not None:
            node.left.parent = node
        if rtree is not None:
            rtree.parent = node
        cur.right = node
        node.parent = cur

        # A double rotation may be necessary.
        if (cur.left is None) or (cur.left.height <= rh):
            node = self._rotate_right(node)
            cur = self._rotate_left(cur)
        else:
            self._repair_node(node)
            self._repair_node(cur)

        # Ascend from "cur" to the root of the tree.
        # Repair and/or rotate as needed.
        while cur.parent is not None:
            cur = cur.parent
            assert cur.left is not None
            assert cur.right is not None

            if cur.left.height + 1 < cur.right.height:
                cur = self._rotate_left(cur)
            else:
                self._repair_node(cur)

        return cur

    def _join_left(self,
                   ltree: Optional[Node[_NameT, _ElemT]],
                   node: Node[_NameT, _ElemT],
                   rtree: Node[_NameT, _ElemT]
                   ) -> Node[_NameT, _ElemT]:
        """Join a left subtree, middle node and right subtree together.

        The right subtree must be higher than the left subtree.

        Return the root node of the joined tree.
        """
        lh = 0 if ltree is None else ltree.height
        rh = rtree.height
        assert lh + 1 < rh

        # Descend down the left spine of "rtree".
        # Stop at a node with compatible height, then insert "node"
        # and attach "ltree".
        #
        #          rtree
        #         /     \
        #        X
        #       / \
        #  cur --> X
        #         / \
        #       node
        #      /    \
        #   ltree    X
        #

        # Descend to a point with compatible height.
        cur = rtree
        while (cur.left is not None) and (cur.left.height > lh + 1):
            cur = cur.left

        # Insert "node" and "ltree".
        node.left = ltree
        node.right = cur.left
        if ltree is not None:
            ltree.parent = node
        if node.right is not None:
            node.right.parent = node
        cur.left = node
        node.parent = cur

        # A double rotation may be necessary.
        if (cur.right is None) or (cur.right.height <= lh):
            node = self._rotate_left(node)
            cur = self._rotate_right(cur)
        else:
            self._repair_node(node)
            self._repair_node(cur)

        # Ascend from "cur" to the root of the tree.
        # Repair and/or rotate as needed.
        while cur.parent is not None:
            cur = cur.parent
            assert cur.left is not None
            assert cur.right is not None

            if cur.left.height > cur.right.height + 1:
                cur = self._rotate_right(cur)
            else:
                self._repair_node(cur)

        return cur

    def _join(self,
              ltree: Optional[Node[_NameT, _ElemT]],
              node: Node[_NameT, _ElemT],
              rtree: Optional[Node[_NameT, _ElemT]]
              ) -> Node[_NameT, _ElemT]:
        """Join a left subtree, middle node and right subtree together.

        The left or right subtree may initially be a child of the middle
        node; such links will be broken as needed.

        The left and right subtrees must be consistent, AVL-balanced trees.
        Parent pointers of the subtrees are ignored.

        The middle node is considered as a single node.
        Its parent and child pointers are ignored.

        Return the root node of the joined tree.
        """
        lh = 0 if ltree is None else ltree.height
        rh = 0 if rtree is None else rtree.height

        if lh > rh + 1:
            assert ltree is not None
            ltree.parent = None
            return self._join_right(ltree, node, rtree)
        elif lh + 1 < rh:
            assert rtree is not None
            rtree.parent = None
            return self._join_left(ltree, node, rtree)
        else:
            # Subtree heights are compatible. Just join them.
            #
            #       node
            #      /    \
            #   ltree   rtree
            #   /   \   /   \
            #
            node.parent = None
            node.left = ltree
            if ltree is not None:
                ltree.parent = node
            node.right = rtree
            if rtree is not None:
                rtree.parent = node
            self._repair_node(node)
            return node

    def _merge_tree(self,
                    ltree: Node[_NameT, _ElemT],
                    rtree: Node[_NameT, _ElemT]
                    ) -> tuple[Node[_NameT, _ElemT], Node[_NameT, _ElemT]]:
        """Merge two trees.

        All nodes of "ltree" end up to the left of all nodes of "rtree".

        Return a tuple (merged_tree, split_node), where "split_node" is
        the left-most node of the original right tree.
        """

        # Find the left-most node of the right tree.
        split_node = rtree
        while split_node.left is not None:
            split_node = split_node.left

        # Delete the split_node from its tree.
        # Being left-most, it has no left child; its right child (if any)
        # takes its place.
        parent = split_node.parent
        if split_node.right is not None:
            split_node.right.parent = parent
        if parent is None:
            rtree_new = split_node.right
        else:
            # Repair and rebalance the ancestors of split_node.
            parent.left = split_node.right
            rtree_new = self._rebalance_up(parent)

        # Join the two trees via the split_node.
        merged_tree = self._join(ltree, split_node, rtree_new)

        return (merged_tree, split_node)

    def _split_tree(self,
                    split_node: Node[_NameT, _ElemT]
                    ) -> tuple[Node[_NameT, _ElemT], Node[_NameT, _ElemT]]:
        """Split a tree on a specified node.

        Two new trees will be constructed.
        All nodes to the left of "split_node" will go to the left tree.
        All nodes to the right of "split_node", and "split_node" itself,
        will go to the right tree.

        Return tuple (ltree, rtree),
        where ltree contains all nodes left of the split-node,
        rtree contains the split-node and all nodes to its right.
        """

        # Assign the descendants of "split_node" to the appropriate trees
        # and detach them from "split_node".
        ltree = split_node.left
        rtree = split_node.right

        split_node.left = None
        split_node.right = None
        if ltree is not None:
            ltree.parent = None
        if rtree is not None:
            rtree.parent = None

        # Detach "split_node" from its parent (if any).
        parent = split_node.parent
        split_node.parent = None

        # Assign "split_node" to the right tree.
        rtree = self._join(None, split_node, rtree)

        # Walk up to the root of the tree.
        # On the way up, detach each node from its parent and join it,
        # and its descendants, to the appropriate tree.
        node = split_node
        while parent is not None:

            # Ascend to the parent node.
            child = node
            node = parent
            parent = node.parent

            # Detach "node" from its parent.
            node.parent = None

            if node.left is child:
                # "split_node" was located in the left subtree of "node".
                # This implies that "node" must be joined to the right tree.
                rtree = self._join(rtree, node, node.right)
            else:
                # "split_node" was located in the right subtree of "node".
                # This implies that "node" must be joined to the left tree.
                assert node.right is child
                ltree = self._join(node.left, node, ltree)

        assert ltree is not None
        return (ltree, rtree)


class PriorityQueue(Generic[_ElemT]):
    """Priority queue based on a binary heap."""

    __slots__ = ("heap", )

    class Node(Generic[_ElemT2]):
        """Node in the priority queue."""

        __slots__ = ("index", "prio", "data")

        def __init__(
                self,
                index: int,
                prio: float,
                data: _ElemT2
                ) -> None:
            """Initialize a heap node.

            This method should not be called directly.
            Instead, call PriorityQueue.insert().

            Parameters:
                index: Current position of this node in the heap list;
                    kept up to date by the queue as nodes move.
                prio: Priority of the element.
                data: Element stored in this node.
            """
            self.index = index
            self.prio = prio
            self.data = data

    def __init__(self) -> None:
        """Initialize an empty queue."""
        self.heap: list[PriorityQueue.Node[_ElemT]] = []

    def clear(self) -> None:
        """Remove all elements from the queue.

        This function takes time O(n).
        """
        self.heap.clear()

    def empty(self) -> bool:
        """Return True if the queue is empty."""
        return (not self.heap)

    def find_min(self) -> Node[_ElemT]:
        """Return the minimum-priority node.

        This function takes time O(1).

        Raises:
            IndexError: If the queue is empty.
        """
        if not self.heap:
            raise IndexError("Queue is empty")
        return self.heap[0]

    def _sift_up(self, index: int) -> None:
        """Repair the heap along an ascending path to the root."""
        node = self.heap[index]
        prio = node.prio

        # Shift larger ancestors down one level until the slot for
        # "node" is found, then place "node" there once.
        pos = index
        while pos > 0:
            tpos = (pos - 1) // 2
            tnode = self.heap[tpos]
            if tnode.prio <= prio:
                break
            tnode.index = pos
            self.heap[pos] = tnode
            pos = tpos

        if pos != index:
            node.index = pos
            self.heap[pos] = node

    def _sift_down(self, index: int) -> None:
        """Repair the heap along a descending path."""
        num_elem = len(self.heap)
        node = self.heap[index]
        prio = node.prio

        # Shift smaller children up one level until the slot for
        # "node" is found, then place "node" there once.
        pos = index
        while True:
            tpos = 2 * pos + 1
            if tpos >= num_elem:
                break
            tnode = self.heap[tpos]

            qpos = tpos + 1
            if qpos < num_elem:
                qnode = self.heap[qpos]
                if qnode.prio <= tnode.prio:
                    tpos = qpos
                    tnode = qnode

            if tnode.prio >= prio:
                break

            tnode.index = pos
            self.heap[pos] = tnode
            pos = tpos

        if pos != index:
            node.index = pos
            self.heap[pos] = node

    def insert(self, prio: float, data: _ElemT) -> Node[_ElemT]:
        """Insert a new element into the queue.

        This function takes time O(log(n)).

        Returns:
            Node that represents the new element.
        """
        new_index = len(self.heap)
        node = self.Node(new_index, prio, data)
        self.heap.append(node)
        self._sift_up(new_index)
        return node

    def delete(self, elem: Node[_ElemT]) -> None:
        """Delete the specified element from the queue.

        This function takes time O(log(n)).
        """
        index = elem.index
        assert self.heap[index] is elem

        # Move the last node into the vacated slot (unless "elem" was
        # the last node), then restore the heap property in whichever
        # direction the priority change requires.
        node = self.heap.pop()
        if index < len(self.heap):
            node.index = index
            self.heap[index] = node
            if node.prio < elem.prio:
                self._sift_up(index)
            elif node.prio > elem.prio:
                self._sift_down(index)

    def decrease_prio(self, elem: Node[_ElemT], prio: float) -> None:
        """Decrease the priority of an existing element in the queue.

        This function takes time O(log(n)).
        """
        assert self.heap[elem.index] is elem
        assert prio <= elem.prio
        elem.prio = prio
        self._sift_up(elem.index)

    def increase_prio(self, elem: Node[_ElemT], prio: float) -> None:
        """Increase the priority of an existing element in the queue.

        This function takes time O(log(n)).
        """
        assert self.heap[elem.index] is elem
        assert prio >= elem.prio
        elem.prio = prio
        self._sift_down(elem.index)