635 lines
21 KiB
Python
635 lines
21 KiB
Python
"""Data structures for matching."""
|
|
|
|
from __future__ import annotations
|
|
|
|
from typing import Generic, Optional, TypeVar
|
|
|
|
|
|
# Type variables for the generic containers in this module.
#
# _NameT / _ElemT parameterize the outer container classes
# (ConcatenableQueue[_NameT, _ElemT] and PriorityQueue[_ElemT]).
# _NameT2 / _ElemT2 are separate variables used by the nested node
# classes, which are generic classes in their own right.
_NameT = TypeVar("_NameT")
_NameT2 = TypeVar("_NameT2")
_ElemT = TypeVar("_ElemT")
_ElemT2 = TypeVar("_ElemT2")
|
|
|
|
|
|
class ConcatenableQueue(Generic[_NameT, _ElemT]):
    """Priority queue supporting efficient merge and split operations.

    This is a combination of a disjoint set and a priority queue.
    A queue has a "name", which can be any Python object.
    Each element has associated "data", which can be any Python object.
    Each element has a priority.

    The following operations can be done efficiently:
     - Create a new queue containing one new element.
     - Find the name of the queue that contains a given element.
     - Change the priority of a given element.
     - Find the element with lowest priority in a given queue.
     - Merge two or more queues.
     - Undo a previous merge step.

    This data structure is implemented as a 2-3 tree with minimum-priority
    tracking added to it.

    Attributes:
        name: Name of this queue, as passed to the constructor.
        tree: Root node of the 2-3 tree, or None if the queue is empty.
        first_node: Left-most leaf node of the tree, or None if empty.
        sub_queues: The list of queues that were merged into this queue,
            or an empty list if this queue was not filled by merging.
    """

    __slots__ = ("name", "tree", "first_node", "sub_queues")

    class BaseNode(Generic[_NameT2, _ElemT2]):
        """Node in the 2-3 tree.

        Attributes:
            owner: The queue that owns the tree; only set on the root node.
            min_node: Leaf node with minimum priority in this subtree.
            height: Height of this node (0 for leaf nodes).
            parent: Parent node, or None for a root or detached node.
            childs: Child nodes, ordered from left to right.
        """

        __slots__ = ("owner", "min_node", "height", "parent", "childs")

        def __init__(self,
                     min_node: ConcatenableQueue.Node[_NameT2, _ElemT2],
                     height: int
                     ) -> None:
            """Initialize a new node.

            Parameters:
                min_node: Leaf node with minimum priority in the subtree.
                height: Height of the new node.
            """
            self.owner: Optional[ConcatenableQueue[_NameT2, _ElemT2]] = None
            self.min_node = min_node
            self.height = height
            self.parent: Optional[ConcatenableQueue.BaseNode[_NameT2,
                                                             _ElemT2]]
            self.parent = None
            self.childs: list[ConcatenableQueue.BaseNode[_NameT2, _ElemT2]]
            self.childs = []

    class Node(BaseNode[_NameT2, _ElemT2]):
        """Leaf node in the 2-3 tree, representing an element in the queue."""

        __slots__ = ("data", "prio")

        def __init__(self, data: _ElemT2, prio: float) -> None:
            """Initialize a new leaf node.

            This method should not be called directly.
            Instead, call ConcatenableQueue.insert().

            Parameters:
                data: Data associated with the element.
                prio: Initial priority of the element.
            """
            # A leaf is the minimum-priority node of its own subtree.
            super().__init__(min_node=self, height=0)
            self.data = data
            self.prio = prio

        def find(self) -> _NameT2:
            """Return the name of the queue that contains this element.

            This function takes time O(log(n)).
            """
            # Walk up to the root; only the root carries the owner pointer.
            node: ConcatenableQueue.BaseNode[_NameT2, _ElemT2] = self
            while node.parent is not None:
                node = node.parent
            assert node.owner is not None
            return node.owner.name

        def set_prio(self, prio: float) -> None:
            """Change the priority of this element.

            This function takes time O(log(n)).

            Parameters:
                prio: New priority of the element.
            """
            self.prio = prio

            # Recompute the minimum-priority pointer of every ancestor.
            node = self.parent
            while node is not None:
                min_node = node.childs[0].min_node
                for child in node.childs[1:]:
                    if child.min_node.prio < min_node.prio:
                        min_node = child.min_node
                node.min_node = min_node
                node = node.parent

    def __init__(self, name: _NameT) -> None:
        """Initialize an empty queue.

        This function takes time O(1).

        Parameters:
            name: Name to assign to the new queue.
        """
        self.name = name
        self.tree: Optional[ConcatenableQueue.BaseNode[_NameT, _ElemT]] = None
        self.first_node: Optional[ConcatenableQueue.Node[_NameT, _ElemT]]
        self.first_node = None
        self.sub_queues: list[ConcatenableQueue[_NameT, _ElemT]] = []

    def clear(self) -> None:
        """Remove all elements from the queue.

        This function takes time O(n).
        """
        node = self.tree
        self.tree = None
        self.first_node = None
        self.sub_queues = []

        # Wipe pointers to enable refcounted garbage collection.
        if node is not None:
            node.owner = None
        while node is not None:
            node.min_node = None  # type: ignore
            prev_node = node
            if node.childs:
                # Descend into (and detach) the right-most remaining child.
                node = node.childs.pop()
            else:
                # Subtree fully wiped; go back up and cut the parent link.
                node = node.parent
                prev_node.parent = None

    def insert(self, elem: _ElemT, prio: float) -> Node[_NameT, _ElemT]:
        """Insert an element into the empty queue.

        This function can only be used if the queue is empty.
        Non-empty queues can grow only by merging.

        This function takes time O(1).

        Parameters:
            elem: Element to insert.
            prio: Initial priority of the new element.

        Returns:
            Leaf node that represents the new element.
        """
        assert self.tree is None
        self.tree = ConcatenableQueue.Node(elem, prio)
        self.tree.owner = self
        self.first_node = self.tree
        return self.tree

    def min_prio(self) -> float:
        """Return the minimum priority of any element in the queue.

        The queue must be non-empty.
        This function takes time O(1).
        """
        node = self.tree
        assert node is not None
        return node.min_node.prio

    def min_elem(self) -> _ElemT:
        """Return the element with minimum priority.

        The queue must be non-empty.
        This function takes time O(1).
        """
        node = self.tree
        assert node is not None
        return node.min_node.data

    def merge(self,
              sub_queues: list[ConcatenableQueue[_NameT, _ElemT]]
              ) -> None:
        """Merge the specified queues.

        This queue must initially be empty.
        All specified sub-queues must initially be non-empty.

        This function removes all elements from the specified sub-queues
        and adds them to this queue.

        After merging, this queue retains a reference to the list of
        sub-queues.

        This function takes time O(len(sub_queues) * log(n)).

        Parameters:
            sub_queues: Non-empty list of non-empty queues to merge,
                in left-to-right order.
        """
        assert self.tree is None
        assert not self.sub_queues
        assert sub_queues

        # Keep the list of sub-queues.
        self.sub_queues = sub_queues

        # Move the root node from the first sub-queue to this queue.
        # Clear its owner pointer.
        self.tree = sub_queues[0].tree
        self.first_node = sub_queues[0].first_node
        assert self.tree is not None
        sub_queues[0].tree = None
        self.tree.owner = None

        # Merge remaining sub-queues.
        for sub in sub_queues[1:]:

            # Pull the root node from the sub-queue.
            # Clear its owner pointer.
            subtree = sub.tree
            assert subtree is not None
            assert subtree.owner is sub
            # Detach the tree from the sub-queue, just like for the first
            # sub-queue, so the sub-queue does not keep a stale pointer to
            # a node that is now part of the merged tree.
            sub.tree = None
            subtree.owner = None

            # Merge our current tree with the tree from the sub-queue.
            self.tree = self._join(self.tree, subtree)

        # Put the owner pointer in the root node.
        self.tree.owner = self

    def split(self) -> None:
        """Undo the merge step that filled this queue.

        Remove all elements from this queue and put them back in
        the sub-queues from which they came.

        After splitting, this queue will be empty.

        This function takes time O(k * log(n)).
        """
        assert self.tree is not None
        assert self.sub_queues

        # Clear the owner pointer from the root node.
        assert self.tree.owner is self
        self.tree.owner = None

        # Start with the complete tree as the "remaining" tree.
        # This also covers a queue that was merged from a single sub-queue,
        # in which case the loop below does not run at all.
        tree = self.tree

        # Split the tree to reconstruct each sub-queue.
        # Work from the last sub-queue back to the second one, each time
        # cutting off the part that starts at the sub-queue's first node.
        for sub in self.sub_queues[:0:-1]:

            assert sub.first_node is not None
            (tree, rtree) = self._split_tree(sub.first_node)

            # Assign the right tree to the sub-queue.
            sub.tree = rtree
            rtree.owner = sub

        # Put the remaining tree in the first sub-queue.
        self.sub_queues[0].tree = tree
        tree.owner = self.sub_queues[0]

        # Make this queue empty.
        self.tree = None
        self.first_node = None
        self.sub_queues = []

    @staticmethod
    def _repair_node(node: BaseNode[_NameT, _ElemT]) -> None:
        """Repair the min_node attribute of an internal node.

        The attribute is recomputed from the min_node attributes of
        the direct children of the node.
        """
        min_node = node.childs[0].min_node
        for child in node.childs[1:]:
            if child.min_node.prio < min_node.prio:
                min_node = child.min_node
        node.min_node = min_node

    @staticmethod
    def _new_internal_node(ltree: BaseNode[_NameT, _ElemT],
                           rtree: BaseNode[_NameT, _ElemT]
                           ) -> BaseNode[_NameT, _ElemT]:
        """Create a new internal node with 2 child nodes.

        Both subtrees must have the same height.
        The parent pointers of both subtrees are set to the new node.
        """
        assert ltree.height == rtree.height
        height = ltree.height + 1
        if ltree.min_node.prio <= rtree.min_node.prio:
            min_node = ltree.min_node
        else:
            min_node = rtree.min_node
        node = ConcatenableQueue.BaseNode(min_node, height)
        node.childs = [ltree, rtree]
        ltree.parent = node
        rtree.parent = node
        return node

    def _join_right(self,
                    ltree: BaseNode[_NameT, _ElemT],
                    rtree: BaseNode[_NameT, _ElemT]
                    ) -> BaseNode[_NameT, _ElemT]:
        """Join two trees together.

        The initial left subtree must be higher than the right subtree.

        Return the root node of the joined tree.
        """

        # Descend down the right spine of the left tree until we
        # reach a node just above the right tree.
        node = ltree
        while node.height > rtree.height + 1:
            node = node.childs[-1]

        assert node.height == rtree.height + 1

        # Find a node in the left tree to insert the right tree as child.
        while len(node.childs) == 3:
            # This node already has 3 childs so we can not add the right tree.
            # Rearrange into 2 nodes with 2 childs each, then solve it
            # at the parent level.
            #
            #         N                 N       R'
            #       / | \              / \     / \
            #      /  |  \    --->    /   \   /   \
            #     A   B   C   R      A     B C     R
            #
            child = node.childs.pop()
            self._repair_node(node)
            rtree = self._new_internal_node(child, rtree)
            if node.parent is None:
                # Create a new root node.
                return self._new_internal_node(node, rtree)
            node = node.parent

        # Insert the right tree as child of this node.
        assert len(node.childs) < 3
        node.childs.append(rtree)
        rtree.parent = node

        # Repair min-prio pointers of ancestors.
        while True:
            self._repair_node(node)
            if node.parent is None:
                break
            node = node.parent

        return node

    def _join_left(self,
                   ltree: BaseNode[_NameT, _ElemT],
                   rtree: BaseNode[_NameT, _ElemT]
                   ) -> BaseNode[_NameT, _ElemT]:
        """Join two trees together.

        The initial left subtree must be lower than the right subtree.

        Return the root node of the joined tree.
        """

        # Descend down the left spine of the right tree until we
        # reach a node just above the left tree.
        node = rtree
        while node.height > ltree.height + 1:
            node = node.childs[0]

        assert node.height == ltree.height + 1

        # Find a node in the right tree to insert the left tree as child.
        while len(node.childs) == 3:
            # This node already has 3 childs so we can not add the left tree.
            # Rearrange into 2 nodes with 2 childs each, then solve it
            # at the parent level.
            #
            #         N             L'      N
            #       / | \          / \     / \
            #      /  |  \  --->  /   \   /   \
            #  L  A   B   C      L     A B     C
            #
            child = node.childs.pop(0)
            self._repair_node(node)
            ltree = self._new_internal_node(ltree, child)
            if node.parent is None:
                # Create a new root node.
                return self._new_internal_node(ltree, node)
            node = node.parent

        # Insert the left tree as child of this node.
        assert len(node.childs) < 3
        node.childs.insert(0, ltree)
        ltree.parent = node

        # Repair min-prio pointers of ancestors.
        while True:
            self._repair_node(node)
            if node.parent is None:
                break
            node = node.parent

        return node

    def _join(self,
              ltree: BaseNode[_NameT, _ElemT],
              rtree: BaseNode[_NameT, _ElemT]
              ) -> BaseNode[_NameT, _ElemT]:
        """Join two trees together.

        The left and right subtree must be consistent 2-3 trees.
        Initial parent pointers of these subtrees are ignored.

        Return the root node of the joined tree.
        """
        if ltree.height > rtree.height:
            return self._join_right(ltree, rtree)
        elif ltree.height < rtree.height:
            return self._join_left(ltree, rtree)
        else:
            return self._new_internal_node(ltree, rtree)

    def _split_tree(self,
                    split_node: BaseNode[_NameT, _ElemT]
                    ) -> tuple[BaseNode[_NameT, _ElemT],
                               BaseNode[_NameT, _ElemT]]:
        """Split a tree on a specified node.

        Two new trees will be constructed.
        Leaf nodes to the left of "split_node" will go to the left tree.
        Leaf nodes to the right of "split_node", and "split_node" itself,
        will go to the right tree.

        There must be at least one leaf node to the left of "split_node".

        Return tuple (ltree, rtree).
        """

        # Detach "split_node" from its parent.
        # Assign it to the right tree.
        parent = split_node.parent
        split_node.parent = None

        # The left tree is initially empty.
        # The right tree initially contains only "split_node".
        ltree: Optional[ConcatenableQueue.BaseNode[_NameT, _ElemT]] = None
        rtree = split_node

        # Walk up to the root of the tree.
        # On the way up, detach each node from its parent and join its
        # child nodes to the appropriate tree.
        node = split_node
        while parent is not None:

            # Ascend to the parent node.
            child = node
            node = parent
            parent = node.parent

            # Detach "node" from its parent.
            node.parent = None

            if len(node.childs) == 3:
                if node.childs[0] is child:
                    # "node" has 3 child nodes.
                    # Its left subtree has already been split.
                    # Turn it into a 2-node and join it to the right tree.
                    node.childs.pop(0)
                    self._repair_node(node)
                    rtree = self._join(rtree, node)
                elif node.childs[2] is child:
                    # "node" has 3 child nodes.
                    # Its right subtree has already been split.
                    # Turn it into a 2-node and join it to the left tree.
                    node.childs.pop()
                    self._repair_node(node)
                    if ltree is None:
                        ltree = node
                    else:
                        ltree = self._join(node, ltree)
                else:
                    # "node" has 3 child nodes.
                    # Its middle subtree has already been split.
                    # Join its left child to the left tree, and its right
                    # child to the right tree, then delete "node".
                    node.childs[0].parent = None
                    node.childs[2].parent = None
                    if ltree is None:
                        ltree = node.childs[0]
                    else:
                        ltree = self._join(node.childs[0], ltree)
                    rtree = self._join(rtree, node.childs[2])

            elif node.childs[0] is child:
                # "node" has 2 child nodes.
                # Its left subtree has already been split.
                # Join its right child to the right tree, then delete "node".
                node.childs[1].parent = None
                rtree = self._join(rtree, node.childs[1])

            else:
                # "node" has 2 child nodes.
                # Its right subtree has already been split.
                # Join its left child to the left tree, then delete "node".
                node.childs[0].parent = None
                if ltree is None:
                    ltree = node.childs[0]
                else:
                    ltree = self._join(node.childs[0], ltree)

        assert ltree is not None
        return (ltree, rtree)
|
|
|
|
|
|
class PriorityQueue(Generic[_ElemT]):
    """Priority queue based on a binary heap.

    Attributes:
        heap: List of nodes in binary heap order. The node at heap[0]
            has minimum priority, and each node stores its own position
            in this list in its "index" attribute.
    """

    __slots__ = ("heap", )

    class Node(Generic[_ElemT2]):
        """Node in the priority queue.

        Attributes:
            index: Current position of the node in the heap list.
            prio: Priority of the element.
            data: Data associated with the element.
        """

        __slots__ = ("index", "prio", "data")

        def __init__(
                self,
                index: int,
                prio: float,
                data: _ElemT2
                ) -> None:
            """Initialize a new node.

            Parameters:
                index: Position of the node in the heap list.
                prio: Priority of the element.
                data: Data associated with the element.
            """
            self.index = index
            self.prio = prio
            self.data = data

    def __init__(self) -> None:
        """Initialize an empty queue."""
        self.heap: list[PriorityQueue.Node[_ElemT]] = []

    def clear(self) -> None:
        """Remove all elements from the queue.

        This function takes time O(n).
        """
        self.heap.clear()

    def empty(self) -> bool:
        """Return True if the queue is empty."""
        return not self.heap

    def find_min(self) -> Node[_ElemT]:
        """Return the minimum-priority node.

        This function takes time O(1).

        Raises:
            IndexError: If the queue is empty.
        """
        if not self.heap:
            raise IndexError("Queue is empty")
        return self.heap[0]

    def _sift_up(self, index: int) -> None:
        """Repair the heap along an ascending path to the root.

        Parameters:
            index: Position of the node that may be smaller than its parent.
        """
        node = self.heap[index]
        prio = node.prio

        # Move parents with a higher priority down one level, until the
        # correct position for "node" is found.
        pos = index
        while pos > 0:
            tpos = (pos - 1) // 2
            tnode = self.heap[tpos]
            if tnode.prio <= prio:
                break
            tnode.index = pos
            self.heap[pos] = tnode
            pos = tpos

        # Put "node" in its final position, if it moved at all.
        if pos != index:
            node.index = pos
            self.heap[pos] = node

    def _sift_down(self, index: int) -> None:
        """Repair the heap along a descending path.

        Parameters:
            index: Position of the node that may be larger than its childs.
        """
        num_elem = len(self.heap)
        node = self.heap[index]
        prio = node.prio

        pos = index
        while True:
            # Locate the left child; stop if "pos" has no childs.
            tpos = 2 * pos + 1
            if tpos >= num_elem:
                break
            tnode = self.heap[tpos]

            # Select the right child instead if its priority is not higher.
            qpos = tpos + 1
            if qpos < num_elem:
                qnode = self.heap[qpos]
                if qnode.prio <= tnode.prio:
                    tpos = qpos
                    tnode = qnode

            # Stop when neither child has a lower priority than "node".
            if tnode.prio >= prio:
                break

            # Move the selected child up one level.
            tnode.index = pos
            self.heap[pos] = tnode
            pos = tpos

        # Put "node" in its final position, if it moved at all.
        if pos != index:
            node.index = pos
            self.heap[pos] = node

    def insert(self, prio: float, data: _ElemT) -> Node[_ElemT]:
        """Insert a new element into the queue.

        This function takes time O(log(n)).

        Parameters:
            prio: Priority of the new element.
            data: Data associated with the new element.

        Returns:
            Node that represents the new element.
        """
        new_index = len(self.heap)
        node = self.Node(new_index, prio, data)
        self.heap.append(node)
        self._sift_up(new_index)
        return node

    def delete(self, elem: Node[_ElemT]) -> None:
        """Delete the specified element from the queue.

        This function takes time O(log(n)).

        Parameters:
            elem: Node to delete; it must currently be in this queue.
        """
        index = elem.index
        assert self.heap[index] is elem

        # Remove the last node and, unless that was the deleted element
        # itself, move it into the vacated position and repair the heap
        # in whichever direction the priority change requires.
        node = self.heap.pop()
        if index < len(self.heap):
            node.index = index
            self.heap[index] = node
            if node.prio < elem.prio:
                self._sift_up(index)
            elif node.prio > elem.prio:
                self._sift_down(index)

    def decrease_prio(self, elem: Node[_ElemT], prio: float) -> None:
        """Decrease the priority of an existing element in the queue.

        This function takes time O(log(n)).

        Parameters:
            elem: Node to update; it must currently be in this queue.
            prio: New priority; must not be higher than the current one.
        """
        assert self.heap[elem.index] is elem
        assert prio <= elem.prio
        elem.prio = prio
        self._sift_up(elem.index)

    def increase_prio(self, elem: Node[_ElemT], prio: float) -> None:
        """Increase the priority of an existing element in the queue.

        This function takes time O(log(n)).

        Parameters:
            elem: Node to update; it must currently be in this queue.
            prio: New priority; must not be lower than the current one.
        """
        assert self.heap[elem.index] is elem
        assert prio >= elem.prio
        elem.prio = prio
        self._sift_down(elem.index)
|