1
0
Fork 0
maximum-weight-matching/python/mwmatching/datastruct.py

635 lines
21 KiB
Python
Raw Normal View History

2024-07-06 22:32:19 +02:00
"""Data structures for matching."""
from __future__ import annotations
from typing import Generic, Optional, TypeVar
# Type variables for the generic containers in this module.
# _NameT / _ElemT parametrize the outer queue classes (queue name type and
# element data type).
_NameT = TypeVar("_NameT")
# _NameT2 / _ElemT2 are the corresponding type variables used by the node
# classes that are nested inside the queue classes.
_NameT2 = TypeVar("_NameT2")
_ElemT = TypeVar("_ElemT")
_ElemT2 = TypeVar("_ElemT2")
class ConcatenableQueue(Generic[_NameT, _ElemT]):
    """Priority queue supporting efficient merge and split operations.

    This is a combination of a disjoint set and a priority queue.
    A queue has a "name", which can be any Python object.
    Each element has associated "data", which can be any Python object.
    Each element has a priority.

    The following operations can be done efficiently:
     - Create a new queue containing one new element.
     - Find the name of the queue that contains a given element.
     - Change the priority of a given element.
     - Find the element with lowest priority in a given queue.
     - Merge two or more queues.
     - Undo a previous merge step.

    This data structure is implemented as a 2-3 tree with minimum-priority
    tracking added to it.
    """

    # name:       name of this queue, as returned by Node.find().
    # tree:       root node of the 2-3 tree, or None if the queue is empty.
    # first_node: leftmost leaf of this queue; split() uses it to find the
    #             point where the merged tree must be cut.
    # sub_queues: the list of queues passed to the merge() that filled this
    #             queue (empty if this queue was not filled by merge()).
    __slots__ = ("name", "tree", "first_node", "sub_queues")

    class BaseNode(Generic[_NameT2, _ElemT2]):
        """Node in the 2-3 tree."""

        # owner:    queue that owns the tree; only set on the root node.
        # min_node: leaf with minimum priority within this subtree.
        # height:   0 for a leaf, child height + 1 for an internal node.
        # parent:   parent node, or None for a (detached) root.
        # childs:   child nodes in left-to-right order (empty for a leaf).
        __slots__ = ("owner", "min_node", "height", "parent", "childs")

        def __init__(self,
                     min_node: ConcatenableQueue.Node[_NameT2, _ElemT2],
                     height: int
                     ) -> None:
            """Initialize a new node without owner, parent or children.

            Parameters:
                min_node: Leaf with minimum priority in the new subtree.
                height: Height of the new node.
            """
            self.owner: Optional[ConcatenableQueue[_NameT2, _ElemT2]] = None
            self.min_node = min_node
            self.height = height
            self.parent: Optional[ConcatenableQueue.BaseNode[_NameT2,
                                                             _ElemT2]]
            self.parent = None
            self.childs: list[ConcatenableQueue.BaseNode[_NameT2, _ElemT2]]
            self.childs = []

    class Node(BaseNode[_NameT2, _ElemT2]):
        """Leaf node in the 2-3 tree, representing an element in the queue."""

        __slots__ = ("data", "prio")

        def __init__(self, data: _ElemT2, prio: float) -> None:
            """Initialize a new leaf node.

            This method should not be called directly.
            Instead, call ConcatenableQueue.insert().

            Parameters:
                data: Data associated with the element.
                prio: Initial priority of the element.
            """
            # A leaf is its own minimum-priority node and has height 0.
            super().__init__(min_node=self, height=0)
            self.data = data
            self.prio = prio

        def find(self) -> _NameT2:
            """Return the name of the queue that contains this element.

            This function takes time O(log(n)).
            """
            # Walk up to the root; only the root carries the owner pointer.
            node: ConcatenableQueue.BaseNode[_NameT2, _ElemT2] = self
            while node.parent is not None:
                node = node.parent
            assert node.owner is not None
            return node.owner.name

        def set_prio(self, prio: float) -> None:
            """Change the priority of this element.

            This function takes time O(log(n)).

            Parameters:
                prio: New priority of the element.
            """
            self.prio = prio

            # Recompute the minimum-priority leaf of every ancestor.
            node = self.parent
            while node is not None:
                min_node = node.childs[0].min_node
                for child in node.childs[1:]:
                    if child.min_node.prio < min_node.prio:
                        min_node = child.min_node
                node.min_node = min_node
                node = node.parent

    def __init__(self, name: _NameT) -> None:
        """Initialize an empty queue.

        This function takes time O(1).

        Parameters:
            name: Name to assign to the new queue.
        """
        self.name = name
        self.tree: Optional[ConcatenableQueue.BaseNode[_NameT, _ElemT]] = None
        self.first_node: Optional[ConcatenableQueue.Node[_NameT, _ElemT]]
        self.first_node = None
        self.sub_queues: list[ConcatenableQueue[_NameT, _ElemT]] = []

    def clear(self) -> None:
        """Remove all elements from the queue.

        This function takes time O(n).
        """
        node = self.tree
        self.tree = None
        self.first_node = None
        self.sub_queues = []

        # Wipe pointers to enable refcounted garbage collection.
        if node is not None:
            node.owner = None
        while node is not None:
            node.min_node = None  # type: ignore
            prev_node = node
            if node.childs:
                # Descend into (and detach) the last remaining child.
                node = node.childs.pop()
            else:
                # All children are gone: go back up and cut the parent
                # link of the node we just finished.
                node = node.parent
                prev_node.parent = None

    def insert(self, elem: _ElemT, prio: float) -> Node[_NameT, _ElemT]:
        """Insert an element into the empty queue.

        This function can only be used if the queue is empty.
        Non-empty queues can grow only by merging.

        This function takes time O(1).

        Parameters:
            elem: Element to insert.
            prio: Initial priority of the new element.

        Returns:
            Leaf node that represents the new element.
        """
        assert self.tree is None
        self.tree = ConcatenableQueue.Node(elem, prio)
        self.tree.owner = self
        self.first_node = self.tree
        return self.tree

    def min_prio(self) -> float:
        """Return the minimum priority of any element in the queue.

        The queue must be non-empty.
        This function takes time O(1).
        """
        node = self.tree
        assert node is not None
        return node.min_node.prio

    def min_elem(self) -> _ElemT:
        """Return the element with minimum priority.

        The queue must be non-empty.
        This function takes time O(1).
        """
        node = self.tree
        assert node is not None
        return node.min_node.data

    def merge(self,
              sub_queues: list[ConcatenableQueue[_NameT, _ElemT]]
              ) -> None:
        """Merge the specified queues.

        This queue must initially be empty.
        All specified sub-queues must initially be non-empty.

        This function removes all elements from the specified sub-queues
        and adds them to this queue. Each sub-queue keeps its "first_node"
        pointer so that split() can later give the elements back.

        After merging, this queue retains a reference to the list of
        sub-queues.

        This function takes time O(len(sub_queues) * log(n)).

        Parameters:
            sub_queues: Non-empty list of non-empty queues to merge.
        """
        assert self.tree is None
        assert not self.sub_queues
        assert sub_queues

        # Keep the list of sub-queues.
        self.sub_queues = sub_queues

        # Move the root node from the first sub-queue to this queue.
        # Clear its owner pointer.
        self.tree = sub_queues[0].tree
        self.first_node = sub_queues[0].first_node
        assert self.tree is not None
        sub_queues[0].tree = None
        self.tree.owner = None

        # Merge remaining sub-queues.
        for sub in sub_queues[1:]:

            # Pull the root node from the sub-queue.
            # Clear its owner pointer.
            subtree = sub.tree
            assert subtree is not None
            assert subtree.owner is sub
            # Empty the sub-queue, exactly like the first one, so it does
            # not keep a stale pointer into the merged tree.
            sub.tree = None
            subtree.owner = None

            # Merge our current tree with the tree from the sub-queue.
            self.tree = self._join(self.tree, subtree)

        # Put the owner pointer in the root node.
        self.tree.owner = self

    def split(self) -> None:
        """Undo the merge step that filled this queue.

        Remove all elements from this queue and put them back in
        the sub-queues from which they came.

        After splitting, this queue will be empty.

        This function takes time O(k * log(n)),
        where k is the number of sub-queues.
        """
        assert self.tree is not None
        assert self.sub_queues

        # Clear the owner pointer from the root node.
        assert self.tree.owner is self
        self.tree.owner = None

        # Start with the complete tree. If this queue was merged from a
        # single sub-queue, the loop below does not run and the complete
        # tree goes back to that sub-queue.
        tree = self.tree

        # Split the tree to reconstruct each sub-queue.
        # Work from the last sub-queue towards the second one; each step
        # cuts off everything from the sub-queue's first leaf to the right.
        for sub in self.sub_queues[:0:-1]:
            assert sub.first_node is not None
            (tree, rtree) = self._split_tree(sub.first_node)

            # Assign the right tree to the sub-queue.
            sub.tree = rtree
            rtree.owner = sub

        # Put the remaining tree in the first sub-queue.
        self.sub_queues[0].tree = tree
        tree.owner = self.sub_queues[0]

        # Make this queue empty.
        self.tree = None
        self.first_node = None
        self.sub_queues = []

    @staticmethod
    def _repair_node(node: BaseNode[_NameT, _ElemT]) -> None:
        """Repair the "min_node" attribute of an internal node.

        The node must have at least one child, and the "min_node"
        attributes of its children must be correct.
        """
        min_node = node.childs[0].min_node
        for child in node.childs[1:]:
            if child.min_node.prio < min_node.prio:
                min_node = child.min_node
        node.min_node = min_node

    @staticmethod
    def _new_internal_node(ltree: BaseNode[_NameT, _ElemT],
                           rtree: BaseNode[_NameT, _ElemT]
                           ) -> BaseNode[_NameT, _ElemT]:
        """Create a new internal node with 2 child nodes.

        Both subtrees must have the same height. Their parent pointers
        are overwritten to point to the new node.

        Returns:
            The new internal node (without owner or parent).
        """
        assert ltree.height == rtree.height
        height = ltree.height + 1
        # On equal priority, prefer the left subtree.
        if ltree.min_node.prio <= rtree.min_node.prio:
            min_node = ltree.min_node
        else:
            min_node = rtree.min_node
        node = ConcatenableQueue.BaseNode(min_node, height)
        node.childs = [ltree, rtree]
        ltree.parent = node
        rtree.parent = node
        return node

    def _join_right(self,
                    ltree: BaseNode[_NameT, _ElemT],
                    rtree: BaseNode[_NameT, _ElemT]
                    ) -> BaseNode[_NameT, _ElemT]:
        """Join two trees together.

        The initial left subtree must be higher than the right subtree.

        Return the root node of the joined tree.
        """

        # Descend down the right spine of the left tree until we
        # reach a node just above the right tree.
        node = ltree
        while node.height > rtree.height + 1:
            node = node.childs[-1]

        assert node.height == rtree.height + 1

        # Find a node in the left tree to insert the right tree as child.
        while len(node.childs) == 3:
            # This node already has 3 childs so we can not add the right tree.
            # Rearrange into 2 nodes with 2 childs each, then solve it
            # at the parent level.
            #
            #       N                     N       R'
            #     / | \                  / \     / \
            #    /  |  \       --->     /   \   /   \
            #   A   B   C   R          A     B C     R
            #
            child = node.childs.pop()
            self._repair_node(node)
            rtree = self._new_internal_node(child, rtree)
            if node.parent is None:
                # Create a new root node.
                return self._new_internal_node(node, rtree)
            node = node.parent

        # Insert the right tree as child of this node.
        assert len(node.childs) < 3
        node.childs.append(rtree)
        rtree.parent = node

        # Repair min-prio pointers of ancestors.
        while True:
            self._repair_node(node)
            if node.parent is None:
                break
            node = node.parent

        return node

    def _join_left(self,
                   ltree: BaseNode[_NameT, _ElemT],
                   rtree: BaseNode[_NameT, _ElemT]
                   ) -> BaseNode[_NameT, _ElemT]:
        """Join two trees together.

        The initial left subtree must be lower than the right subtree.

        Return the root node of the joined tree.
        """

        # Descend down the left spine of the right tree until we
        # reach a node just above the left tree.
        node = rtree
        while node.height > ltree.height + 1:
            node = node.childs[0]

        assert node.height == ltree.height + 1

        # Find a node in the right tree to insert the left tree as child.
        while len(node.childs) == 3:
            # This node already has 3 childs so we can not add the left tree.
            # Rearrange into 2 nodes with 2 childs each, then solve it
            # at the parent level.
            #
            #         N                  L'      N
            #       / | \               / \     / \
            #      /  |  \     --->    /   \   /   \
            #  L  A   B   C           L     A B     C
            #
            child = node.childs.pop(0)
            self._repair_node(node)
            ltree = self._new_internal_node(ltree, child)
            if node.parent is None:
                # Create a new root node.
                return self._new_internal_node(ltree, node)
            node = node.parent

        # Insert the left tree as child of this node.
        assert len(node.childs) < 3
        node.childs.insert(0, ltree)
        ltree.parent = node

        # Repair min-prio pointers of ancestors.
        while True:
            self._repair_node(node)
            if node.parent is None:
                break
            node = node.parent

        return node

    def _join(self,
              ltree: BaseNode[_NameT, _ElemT],
              rtree: BaseNode[_NameT, _ElemT]
              ) -> BaseNode[_NameT, _ElemT]:
        """Join two trees together.

        The left and right subtree must be consistent 2-3 trees.
        The parent pointer of the lower (or equally high) tree is
        overwritten. The root of the higher tree must be detached
        (parent None), because the joined root is found by following
        parent pointers upward until None.

        Return the root node of the joined tree.
        """
        if ltree.height > rtree.height:
            return self._join_right(ltree, rtree)
        elif ltree.height < rtree.height:
            return self._join_left(ltree, rtree)
        else:
            return self._new_internal_node(ltree, rtree)

    def _split_tree(self,
                    split_node: BaseNode[_NameT, _ElemT]
                    ) -> tuple[BaseNode[_NameT, _ElemT],
                               BaseNode[_NameT, _ElemT]]:
        """Split a tree on a specified node.

        Two new trees will be constructed.
        Leaf nodes to the left of "split_node" will go to the left tree.
        Leaf nodes to the right of "split_node", and "split_node" itself,
        will go to the right tree.

        There must be at least one leaf to the left of "split_node",
        otherwise the left tree would be empty.

        Return tuple (ltree, rtree).
        """

        # Detach "split_node" from its parent.
        # Assign it to the right tree.
        parent = split_node.parent
        split_node.parent = None

        # The left tree is initially empty.
        # The right tree initially contains only "split_node".
        ltree: Optional[ConcatenableQueue.BaseNode[_NameT, _ElemT]] = None
        rtree = split_node

        # Walk up to the root of the tree.
        # On the way up, detach each node from its parent and join its
        # child nodes to the appropriate tree.
        node = split_node
        while parent is not None:

            # Ascend to the parent node.
            child = node
            node = parent
            parent = node.parent

            # Detach "node" from its parent.
            node.parent = None

            if len(node.childs) == 3:
                if node.childs[0] is child:
                    # "node" has 3 child nodes.
                    # Its left subtree has already been split.
                    # Turn it into a 2-node and join it to the right tree.
                    node.childs.pop(0)
                    self._repair_node(node)
                    rtree = self._join(rtree, node)
                elif node.childs[2] is child:
                    # "node" has 3 child nodes.
                    # Its right subtree has already been split.
                    # Turn it into a 2-node and join it to the left tree.
                    node.childs.pop()
                    self._repair_node(node)
                    if ltree is None:
                        ltree = node
                    else:
                        ltree = self._join(node, ltree)
                else:
                    # "node" has 3 child nodes.
                    # Its middle subtree has already been split.
                    # Join its left child to the left tree, and its right
                    # child to the right tree, then delete "node".
                    node.childs[0].parent = None
                    node.childs[2].parent = None
                    if ltree is None:
                        ltree = node.childs[0]
                    else:
                        ltree = self._join(node.childs[0], ltree)
                    rtree = self._join(rtree, node.childs[2])

            elif node.childs[0] is child:
                # "node" has 2 child nodes.
                # Its left subtree has already been split.
                # Join its right child to the right tree, then delete "node".
                node.childs[1].parent = None
                rtree = self._join(rtree, node.childs[1])

            else:
                # "node" has 2 child nodes.
                # Its right subtree has already been split.
                # Join its left child to the left tree, then delete "node".
                node.childs[0].parent = None
                if ltree is None:
                    ltree = node.childs[0]
                else:
                    ltree = self._join(node.childs[0], ltree)

        assert ltree is not None
        return (ltree, rtree)
class PriorityQueue(Generic[_ElemT]):
    """Min-priority queue implemented as an array-backed binary heap.

    Every queued element is represented by a Node object which records
    its own position in the heap array. This allows an element to be
    deleted or re-prioritized without searching for it.
    """

    # heap: list of nodes in binary-heap order; heap[0] has minimum priority.
    __slots__ = ("heap", )

    class Node(Generic[_ElemT2]):
        """Handle for one element stored in the priority queue."""

        # index: current position of this node in the heap array.
        # prio:  priority of the element.
        # data:  data associated with the element.
        __slots__ = ("index", "prio", "data")

        def __init__(
                self,
                index: int,
                prio: float,
                data: _ElemT2
                ) -> None:
            """Store the heap position, priority and data of an element."""
            self.index = index
            self.prio = prio
            self.data = data

    def __init__(self) -> None:
        """Create a queue without elements."""
        self.heap: list[PriorityQueue.Node[_ElemT]] = []

    def clear(self) -> None:
        """Discard every element in the queue.

        This function takes time O(n).
        """
        self.heap.clear()

    def empty(self) -> bool:
        """Return True if the queue contains no elements."""
        return len(self.heap) == 0

    def find_min(self) -> Node[_ElemT]:
        """Return the node with minimum priority without removing it.

        This function takes time O(1).

        Raises:
            IndexError: If the queue is empty.
        """
        if self.heap:
            return self.heap[0]
        raise IndexError("Queue is empty")

    def _sift_up(self, index: int) -> None:
        """Move the node at "index" towards the root until heap order holds."""
        heap = self.heap
        moving = heap[index]
        hole = index

        # Shift ancestors with a strictly higher priority one level down,
        # letting the hole travel upward.
        while hole > 0:
            above = (hole - 1) // 2
            ancestor = heap[above]
            if ancestor.prio <= moving.prio:
                break
            ancestor.index = hole
            heap[hole] = ancestor
            hole = above

        # Drop the moving node into the hole, unless it never moved.
        if hole != index:
            moving.index = hole
            heap[hole] = moving

    def _sift_down(self, index: int) -> None:
        """Move the node at "index" towards the leaves until heap order holds."""
        heap = self.heap
        size = len(heap)
        moving = heap[index]
        hole = index

        left = 2 * hole + 1
        while left < size:
            # Pick the child with the lowest priority;
            # the right child wins a tie.
            pick = left
            candidate = heap[left]
            right = left + 1
            if right < size:
                sibling = heap[right]
                if sibling.prio <= candidate.prio:
                    pick = right
                    candidate = sibling

            if candidate.prio >= moving.prio:
                break

            # Pull the child up and continue below it.
            candidate.index = hole
            heap[hole] = candidate
            hole = pick
            left = 2 * hole + 1

        # Drop the moving node into the hole, unless it never moved.
        if hole != index:
            moving.index = hole
            heap[hole] = moving

    def insert(self, prio: float, data: _ElemT) -> Node:
        """Add a new element to the queue.

        This function takes time O(log(n)).

        Parameters:
            prio: Priority of the new element.
            data: Data associated with the new element.

        Returns:
            Node that represents the new element.
        """
        slot = len(self.heap)
        fresh = self.Node(slot, prio, data)
        self.heap.append(fresh)
        self._sift_up(slot)
        return fresh

    def delete(self, elem: Node[_ElemT]) -> None:
        """Remove the specified element from the queue.

        This function takes time O(log(n)).

        Parameters:
            elem: Node of an element that is currently in this queue.
        """
        slot = elem.index
        assert self.heap[slot] is elem

        # Take the last node out of the array. If "elem" itself occupied
        # the last slot, there is nothing left to repair.
        last = self.heap.pop()
        if slot >= len(self.heap):
            return

        # Put the former last node in the vacated slot and restore
        # heap order in whichever direction is needed.
        last.index = slot
        self.heap[slot] = last
        if last.prio < elem.prio:
            self._sift_up(slot)
        elif last.prio > elem.prio:
            self._sift_down(slot)

    def decrease_prio(self, elem: Node[_ElemT], prio: float) -> None:
        """Lower the priority of an element that is in the queue.

        This function takes time O(log(n)).

        Parameters:
            elem: Node of an element that is currently in this queue.
            prio: New priority; must not exceed the current priority.
        """
        assert self.heap[elem.index] is elem
        assert prio <= elem.prio
        elem.prio = prio
        self._sift_up(elem.index)

    def increase_prio(self, elem: Node[_ElemT], prio: float) -> None:
        """Raise the priority of an element that is in the queue.

        This function takes time O(log(n)).

        Parameters:
            elem: Node of an element that is currently in this queue.
            prio: New priority; must not be below the current priority.
        """
        assert self.heap[elem.index] is elem
        assert prio >= elem.prio
        elem.prio = prio
        self._sift_down(elem.index)