1
0
Fork 0
maximum-weight-matching/python/mwmatching/datastruct.py

839 lines
26 KiB
Python
Raw Normal View History

2024-07-06 22:32:19 +02:00
"""Data structures for matching."""
from __future__ import annotations
from typing import Generic, Optional, TypeVar
# Type variables used by the generic containers in this module.
# _NameT / _ElemT parameterize the outer queue classes (queue name type
# and element data type); _NameT2 / _ElemT2 parameterize the nested
# Node classes, which are declared as separate generics.
_NameT = TypeVar("_NameT")
_NameT2 = TypeVar("_NameT2")
_ElemT = TypeVar("_ElemT")
_ElemT2 = TypeVar("_ElemT2")
class UnionFindQueue(Generic[_NameT, _ElemT]):
    """Combination of disjoint set and priority queue.

    A queue has a "name", which can be any Python object.
    Each element has associated "data", which can be any Python object.
    Each element has a priority.

    The following operations can be done efficiently:
     - Create a new queue containing one new element.
     - Find the name of the queue that contains a given element.
     - Change the priority of a given element.
     - Find the element with lowest priority in a given queue.
     - Merge two or more queues.
     - Undo a previous merge step.

    The implementation is essentially an AVL tree, with minimum-priority
    tracking added to it.
    """

    # name:        name of this queue, as reported by Node.find()
    # tree:        root node of the tree, or None if the queue is empty
    # sub_queues:  queues that were merged into this queue (see merge())
    # split_nodes: for each sub-queue except the first, the node on which
    #              the tree must be split to undo the merge (see split())
    __slots__ = ("name", "tree", "sub_queues", "split_nodes")

    class Node(Generic[_NameT2, _ElemT2]):
        """Node in a UnionFindQueue."""

        # owner:    queue that owns this node; only set on the root node
        # data:     element data
        # prio:     priority of this element
        # min_node: node with minimum priority within this subtree
        # height:   height of the subtree rooted at this node (leaf = 1)
        # parent, left, right: tree links
        __slots__ = ("owner", "data", "prio", "min_node", "height",
                     "parent", "left", "right")

        def __init__(self,
                     owner: UnionFindQueue[_NameT2, _ElemT2],
                     data: _ElemT2,
                     prio: float
                     ) -> None:
            """Initialize a new element.

            This method should not be called directly.
            Instead, call UnionFindQueue.insert().
            """
            self.owner: Optional[UnionFindQueue[_NameT2, _ElemT2]] = owner
            self.data = data
            self.prio = prio
            self.min_node = self
            self.height = 1
            self.parent: Optional[UnionFindQueue.Node[_NameT2, _ElemT2]]
            self.left: Optional[UnionFindQueue.Node[_NameT2, _ElemT2]]
            self.right: Optional[UnionFindQueue.Node[_NameT2, _ElemT2]]
            self.parent = None
            self.left = None
            self.right = None

        def find(self) -> _NameT2:
            """Return the name of the queue that contains this element.

            This function takes time O(log(n)).
            """
            # Walk up to the root; only the root carries the owner pointer.
            node = self
            while node.parent is not None:
                node = node.parent
            assert node.owner is not None
            return node.owner.name

        def set_prio(self, prio: float) -> None:
            """Change the priority of this element.

            The minimum-priority bookkeeping is recomputed for this node
            and for each of its ancestors up to the root.
            """
            self.prio = prio
            node = self
            while True:
                # Recompute the minimum over the node and its two subtrees.
                min_node = node
                if node.left is not None:
                    left_min_node = node.left.min_node
                    if left_min_node.prio < min_node.prio:
                        min_node = left_min_node
                if node.right is not None:
                    right_min_node = node.right.min_node
                    if right_min_node.prio < min_node.prio:
                        min_node = right_min_node
                node.min_node = min_node
                if node.parent is None:
                    break
                node = node.parent

    def __init__(self, name: _NameT) -> None:
        """Initialize an empty queue.

        This function takes time O(1).

        Parameters:
            name: Name to assign to the new queue.
        """
        self.name = name
        self.tree: Optional[UnionFindQueue.Node[_NameT, _ElemT]] = None
        self.sub_queues: list[UnionFindQueue[_NameT, _ElemT]] = []
        self.split_nodes: list[UnionFindQueue.Node[_NameT, _ElemT]] = []

    def clear(self) -> None:
        """Remove all elements from the queue.

        This function takes time O(n).
        """
        node = self.tree
        self.tree = None
        self.sub_queues = []
        self.split_nodes.clear()

        # Wipe pointers to enable refcounted garbage collection.
        if node is not None:
            node.owner = None
        while node is not None:
            node.min_node = None  # type: ignore
            prev_node = node
            if node.left is not None:
                # Descend left, cutting the link behind us.
                node = node.left
                prev_node.left = None
            elif node.right is not None:
                # Descend right, cutting the link behind us.
                node = node.right
                prev_node.right = None
            else:
                # Both children are gone; go back up and cut the parent link.
                node = node.parent
                prev_node.parent = None

    def insert(self, elem: _ElemT, prio: float) -> Node[_NameT, _ElemT]:
        """Insert an element into the empty queue.

        This function can only be used if the queue is empty.
        Non-empty queues can grow only by merging.

        This function takes time O(1).

        Parameters:
            elem: Element to insert.
            prio: Initial priority of the new element.

        Returns:
            Node that represents the new element.
        """
        assert self.tree is None
        self.tree = UnionFindQueue.Node(self, elem, prio)
        return self.tree

    def min_prio(self) -> float:
        """Return the minimum priority of any element in the queue.

        The queue must be non-empty.
        This function takes time O(1).
        """
        node = self.tree
        assert node is not None
        return node.min_node.prio

    def min_elem(self) -> _ElemT:
        """Return the element with minimum priority.

        The queue must be non-empty.
        This function takes time O(1).
        """
        node = self.tree
        assert node is not None
        return node.min_node.data

    def merge(self,
              sub_queues: list[UnionFindQueue[_NameT, _ElemT]]
              ) -> None:
        """Merge the specified queues.

        This queue must initially be empty.
        All specified sub-queues must initially be non-empty.

        This function removes all elements from the specified sub-queues
        and adds them to this queue. Every sub-queue is left empty
        (its "tree" attribute is None) until split() is called.

        After merging, this queue retains a reference to the list of
        sub-queues.

        This function takes time O(len(sub_queues) * log(n)).
        """
        assert self.tree is None
        assert not self.sub_queues
        assert not self.split_nodes
        assert sub_queues

        # Keep the list of sub-queues.
        self.sub_queues = sub_queues

        # Move the root node from the first sub-queue to this queue.
        # Clear its owner pointer.
        self.tree = sub_queues[0].tree
        assert self.tree is not None
        sub_queues[0].tree = None
        self.tree.owner = None

        # Merge remaining sub-queues.
        for sub in sub_queues[1:]:

            # Pull the root node from the sub-queue.
            # Clear its owner pointer.
            subtree = sub.tree
            assert subtree is not None
            assert subtree.owner is sub
            # Empty the sub-queue, just like the first one, so it does not
            # keep a stale pointer into the merged tree.
            sub.tree = None
            subtree.owner = None

            # Merge our current tree with the tree from the sub-queue.
            (self.tree, split_node) = self._merge_tree(self.tree, subtree)

            # Keep track of the left-most node from the sub-queue.
            self.split_nodes.append(split_node)

        # Put the owner pointer in the root node.
        self.tree.owner = self

    def split(self) -> None:
        """Undo the merge step that filled this queue.

        Remove all elements from this queue and put them back in
        the sub-queues from which they came.

        After splitting, this queue will be empty.

        This function takes time O(k * log(n)).
        """
        assert self.tree is not None
        assert self.sub_queues

        # Clear the owner pointer from the root node.
        assert self.tree.owner is self
        self.tree.owner = None

        # Start from the whole tree. If the merge involved only a single
        # sub-queue, there are no split nodes and the whole tree goes back
        # to that sub-queue unchanged.
        tree = self.tree

        # Split the tree to reconstruct each sub-queue.
        # Sub-queues are peeled off from right to left.
        for (sub, split_node) in zip(self.sub_queues[:0:-1],
                                     self.split_nodes[::-1]):

            (tree, rtree) = self._split_tree(split_node)

            # Assign the right tree to the sub-queue.
            sub.tree = rtree
            rtree.owner = sub

        # Put the remaining tree in the first sub-queue.
        self.sub_queues[0].tree = tree
        tree.owner = self.sub_queues[0]

        # Make this queue empty.
        self.tree = None
        self.sub_queues = []
        self.split_nodes.clear()

    @staticmethod
    def _repair_node(node: Node[_NameT, _ElemT]) -> None:
        """Recalculate the height and min-priority information of the
        specified node.

        The children of the node must already hold correct information.
        """

        # Repair node height.
        lh = 0 if node.left is None else node.left.height
        rh = 0 if node.right is None else node.right.height
        node.height = 1 + max(lh, rh)

        # Repair min-priority.
        min_node = node
        if node.left is not None:
            left_min_node = node.left.min_node
            if left_min_node.prio < min_node.prio:
                min_node = left_min_node
        if node.right is not None:
            right_min_node = node.right.min_node
            if right_min_node.prio < min_node.prio:
                min_node = right_min_node
        node.min_node = min_node

    def _rotate_left(self,
                     node: Node[_NameT, _ElemT]
                     ) -> Node[_NameT, _ElemT]:
        """Rotate the specified subtree to the left.

        The node must have a right child.

        Return the new root node of the subtree.
        """
        #
        #     N                C
        #   /   \            /   \
        #  A     C   --->   N     D
        #       / \        / \
        #      B   D      A   B
        #
        parent = node.parent
        new_top = node.right
        assert new_top is not None

        node.right = new_top.left
        if node.right is not None:
            node.right.parent = node

        new_top.left = node
        new_top.parent = parent
        node.parent = new_top

        # Re-attach the rotated subtree to the former parent, if any.
        if parent is not None:
            if parent.left is node:
                parent.left = new_top
            elif parent.right is node:
                parent.right = new_top

        # Repair bottom-up: "node" is now a child of "new_top".
        self._repair_node(node)
        self._repair_node(new_top)

        return new_top

    def _rotate_right(self,
                      node: Node[_NameT, _ElemT]
                      ) -> Node[_NameT, _ElemT]:
        """Rotate the specified node to the right.

        The node must have a left child.

        Return the new root node of the subtree.
        """
        #
        #       N            A
        #     /   \        /   \
        #    A     D ---> B     N
        #   / \                / \
        #  B   C              C   D
        #
        parent = node.parent
        new_top = node.left
        assert new_top is not None

        node.left = new_top.right
        if node.left is not None:
            node.left.parent = node

        new_top.right = node
        new_top.parent = parent
        node.parent = new_top

        # Re-attach the rotated subtree to the former parent, if any.
        if parent is not None:
            if parent.left is node:
                parent.left = new_top
            elif parent.right is node:
                parent.right = new_top

        # Repair bottom-up: "node" is now a child of "new_top".
        self._repair_node(node)
        self._repair_node(new_top)

        return new_top

    def _rebalance_up(self,
                      node: Node[_NameT, _ElemT]
                      ) -> Node[_NameT, _ElemT]:
        """Repair and rebalance the specified node and its ancestors.

        Return the root node of the rebalanced tree.
        """

        # Walk up to the root of the tree.
        while True:

            lh = 0 if node.left is None else node.left.height
            rh = 0 if node.right is None else node.right.height

            if lh > rh + 1:
                # This node is left-heavy. Rotate right to rebalance.
                #
                #       N                L
                #     /   \            /   \
                #    L     R   --->   A     N
                #   / \                    / \
                #  A   B                  B   R
                #
                lchild = node.left
                assert lchild is not None
                if ((lchild.right is not None)
                        and ((lchild.left is None)
                             or (lchild.right.height > lchild.left.height))):
                    # Double rotation.
                    lchild = self._rotate_left(lchild)

                node = self._rotate_right(node)

            elif lh + 1 < rh:
                # This node is right-heavy. Rotate left to rebalance.
                #
                #     N                R
                #   /   \            /   \
                #  L     R   --->   N     B
                #       / \        / \
                #      A   B      L   A
                #
                rchild = node.right
                assert rchild is not None
                if ((rchild.left is not None)
                        and ((rchild.right is None)
                             or (rchild.left.height > rchild.right.height))):
                    # Double rotation.
                    rchild = self._rotate_right(rchild)

                node = self._rotate_left(node)

            else:
                # No rotation. Must still repair node though.
                self._repair_node(node)

            if node.parent is None:
                break

            # Continue rebalancing at the parent.
            node = node.parent

        # Return new root node.
        return node

    def _join_right(self,
                    ltree: Node[_NameT, _ElemT],
                    node: Node[_NameT, _ElemT],
                    rtree: Optional[Node[_NameT, _ElemT]]
                    ) -> Node[_NameT, _ElemT]:
        """Join a left subtree, middle node and right subtree together.

        The left subtree must be higher than the right subtree.

        Return the root node of the joined tree.
        """
        lh = ltree.height
        rh = 0 if rtree is None else rtree.height
        assert lh > rh + 1

        # Descend down the right spine of "ltree".
        # Stop at a node with compatible height, then insert "node"
        # and attach "rtree".
        #
        #    ltree
        #    /   \
        #         X
        #        / \
        #           X  <-- cur
        #          / \
        #             node
        #            /    \
        #           X      rtree
        #

        # Descend to a point with compatible height.
        cur = ltree
        while (cur.right is not None) and (cur.right.height > rh + 1):
            cur = cur.right

        # Insert "node" and "rtree".
        node.left = cur.right
        node.right = rtree
        if node.left is not None:
            node.left.parent = node
        if rtree is not None:
            rtree.parent = node
        cur.right = node
        node.parent = cur

        # A double rotation may be necessary.
        if (cur.left is None) or (cur.left.height <= rh):
            node = self._rotate_right(node)
            cur = self._rotate_left(cur)
        else:
            self._repair_node(node)
            self._repair_node(cur)

        # Ascend from "cur" to the root of the tree.
        # Repair and/or rotate as needed.
        while cur.parent is not None:
            cur = cur.parent
            assert cur.left is not None
            assert cur.right is not None

            if cur.left.height + 1 < cur.right.height:
                cur = self._rotate_left(cur)
            else:
                self._repair_node(cur)

        return cur

    def _join_left(self,
                   ltree: Optional[Node[_NameT, _ElemT]],
                   node: Node[_NameT, _ElemT],
                   rtree: Node[_NameT, _ElemT]
                   ) -> Node[_NameT, _ElemT]:
        """Join a left subtree, middle node and right subtree together.

        The right subtree must be higher than the left subtree.

        Return the root node of the joined tree.
        """
        lh = 0 if ltree is None else ltree.height
        rh = rtree.height
        assert lh + 1 < rh

        # Descend down the left spine of "rtree".
        # Stop at a node with compatible height, then insert "node"
        # and attach "ltree".
        #
        #                rtree
        #                /   \
        #               X
        #              / \
        #     cur --> X
        #            / \
        #        node
        #       /    \
        #  ltree      X
        #

        # Descend to a point with compatible height.
        cur = rtree
        while (cur.left is not None) and (cur.left.height > lh + 1):
            cur = cur.left

        # Insert "node" and "ltree".
        node.left = ltree
        node.right = cur.left
        if ltree is not None:
            ltree.parent = node
        if node.right is not None:
            node.right.parent = node
        cur.left = node
        node.parent = cur

        # A double rotation may be necessary.
        if (cur.right is None) or (cur.right.height <= lh):
            node = self._rotate_left(node)
            cur = self._rotate_right(cur)
        else:
            self._repair_node(node)
            self._repair_node(cur)

        # Ascend from "cur" to the root of the tree.
        # Repair and/or rotate as needed.
        while cur.parent is not None:
            cur = cur.parent
            assert cur.left is not None
            assert cur.right is not None

            if cur.left.height > cur.right.height + 1:
                cur = self._rotate_right(cur)
            else:
                self._repair_node(cur)

        return cur

    def _join(self,
              ltree: Optional[Node[_NameT, _ElemT]],
              node: Node[_NameT, _ElemT],
              rtree: Optional[Node[_NameT, _ElemT]]
              ) -> Node[_NameT, _ElemT]:
        """Join a left subtree, middle node and right subtree together.

        The left or right subtree may initially be a child of the middle
        node; such links will be broken as needed.

        The left and right subtrees must be consistent, AVL-balanced trees.
        Parent pointers of the subtrees are ignored.

        The middle node is considered as a single node.
        Its parent and child pointers are ignored.

        Return the root node of the joined tree.
        """
        lh = 0 if ltree is None else ltree.height
        rh = 0 if rtree is None else rtree.height

        if lh > rh + 1:
            assert ltree is not None
            ltree.parent = None
            return self._join_right(ltree, node, rtree)
        elif lh + 1 < rh:
            assert rtree is not None
            rtree.parent = None
            return self._join_left(ltree, node, rtree)
        else:
            # Subtree heights are compatible. Just join them.
            #
            #       node
            #      /    \
            #  ltree    rtree
            #   /  \     /  \
            #
            node.parent = None
            node.left = ltree
            if ltree is not None:
                ltree.parent = node
            node.right = rtree
            if rtree is not None:
                rtree.parent = node
            self._repair_node(node)
            return node

    def _merge_tree(self,
                    ltree: Node[_NameT, _ElemT],
                    rtree: Node[_NameT, _ElemT]
                    ) -> tuple[Node[_NameT, _ElemT], Node[_NameT, _ElemT]]:
        """Merge two trees.

        All nodes of "rtree" end up to the right of all nodes of "ltree".

        Return a tuple (merged_tree, split_node),
        where merged_tree is the root of the merged tree and
        split_node is the left-most node of the original "rtree".
        """

        # Find the left-most node of the right tree.
        split_node = rtree
        while split_node.left is not None:
            split_node = split_node.left

        # Delete the split_node from its tree.
        parent = split_node.parent
        if split_node.right is not None:
            split_node.right.parent = parent
        if parent is None:
            # split_node was the root; its right subtree is what remains.
            rtree_new = split_node.right
        else:
            # Repair and rebalance the ancestors of split_node.
            parent.left = split_node.right
            rtree_new = self._rebalance_up(parent)

        # Join the two trees via the split_node.
        merged_tree = self._join(ltree, split_node, rtree_new)

        return (merged_tree, split_node)

    def _split_tree(self,
                    split_node: Node[_NameT, _ElemT]
                    ) -> tuple[Node[_NameT, _ElemT], Node[_NameT, _ElemT]]:
        """Split a tree on a specified node.

        Two new trees will be constructed.
        All nodes to the left of "split_node" will go to the left tree.
        All nodes to the right of "split_node", and "split_node" itself,
        will go to the right tree.

        There must be at least one node to the left of "split_node".

        Return tuple (ltree, rtree),
        where ltree contains all nodes left of the split-node,
        rtree contains the split-node and all nodes to its right.
        """

        # Assign the descendants of "split_node" to the appropriate trees
        # and detach them from "split_node".
        ltree = split_node.left
        rtree = split_node.right

        split_node.left = None
        split_node.right = None
        if ltree is not None:
            ltree.parent = None
        if rtree is not None:
            rtree.parent = None

        # Detach "split_node" from its parent (if any).
        parent = split_node.parent
        split_node.parent = None

        # Assign "split_node" to the right tree.
        rtree = self._join(None, split_node, rtree)

        # Walk up to the root of the tree.
        # On the way up, detach each node from its parent and join it,
        # and its descendants, to the appropriate tree.
        node = split_node
        while parent is not None:

            # Ascend to the parent node.
            child = node
            node = parent
            parent = node.parent

            # Detach "node" from its parent.
            node.parent = None

            if node.left is child:
                # "split_node" was located in the left subtree of "node".
                # This implies that "node" must be joined to the right tree.
                rtree = self._join(rtree, node, node.right)
            else:
                # "split_node" was located in the right subtree of "node".
                # This implies that "node" must be joined to the left tree.
                assert node.right is child
                ltree = self._join(node.left, node, ltree)

        assert ltree is not None
        return (ltree, rtree)
class PriorityQueue(Generic[_ElemT]):
    """Min-priority queue implemented as an array-backed binary heap.

    Each stored node remembers its own position in the heap array, which
    allows deleting or re-prioritizing an arbitrary node in O(log(n)).
    """

    __slots__ = ("heap", )

    class Node(Generic[_ElemT2]):
        """Entry of the priority queue.

        "index" is the current position of the entry in the heap array,
        "prio" its priority and "data" the associated payload.
        """

        __slots__ = ("index", "prio", "data")

        def __init__(
                self,
                index: int,
                prio: float,
                data: _ElemT2
                ) -> None:
            self.index = index
            self.prio = prio
            self.data = data

    def __init__(self) -> None:
        """Create a queue without any elements."""
        self.heap: list[PriorityQueue.Node[_ElemT]] = []

    def clear(self) -> None:
        """Drop every element from the queue.

        This function takes time O(n).
        """
        self.heap.clear()

    def empty(self) -> bool:
        """Tell whether the queue holds no elements."""
        return len(self.heap) == 0

    def find_min(self) -> Node[_ElemT]:
        """Return the node that currently has the lowest priority.

        Raises IndexError when the queue is empty.

        This function takes time O(1).
        """
        if self.heap:
            return self.heap[0]
        raise IndexError("Queue is empty")

    def _sift_up(self, index: int) -> None:
        """Restore the heap property by moving the node at "index"
        towards the root as far as needed."""
        heap = self.heap
        moving = heap[index]
        moving_prio = moving.prio
        hole = index

        # Shift parents down into the hole until the moving node fits.
        while hole > 0:
            parent_pos = (hole - 1) // 2
            parent = heap[parent_pos]
            if parent.prio <= moving_prio:
                break
            parent.index = hole
            heap[hole] = parent
            hole = parent_pos

        if hole != index:
            moving.index = hole
            heap[hole] = moving

    def _sift_down(self, index: int) -> None:
        """Restore the heap property by moving the node at "index"
        towards the leaves as far as needed."""
        heap = self.heap
        size = len(heap)
        moving = heap[index]
        moving_prio = moving.prio
        hole = index

        child_pos = 2 * hole + 1
        while child_pos < size:
            child = heap[child_pos]

            # Prefer the right child when it is not worse than the left.
            sibling_pos = child_pos + 1
            if sibling_pos < size:
                sibling = heap[sibling_pos]
                if sibling.prio <= child.prio:
                    child_pos = sibling_pos
                    child = sibling

            if child.prio >= moving_prio:
                break

            # Pull the smaller child up into the hole.
            child.index = hole
            heap[hole] = child
            hole = child_pos
            child_pos = 2 * hole + 1

        if hole != index:
            moving.index = hole
            heap[hole] = moving

    def insert(self, prio: float, data: _ElemT) -> Node:
        """Add a new element to the queue.

        This function takes time O(log(n)).

        Returns:
            Node that represents the new element.
        """
        heap = self.heap
        slot = len(heap)
        entry = self.Node(slot, prio, data)
        heap.append(entry)
        self._sift_up(slot)
        return entry

    def delete(self, elem: Node[_ElemT]) -> None:
        """Remove the given node from the queue.

        This function takes time O(log(n)).
        """
        heap = self.heap
        slot = elem.index
        assert heap[slot] is elem

        # Take the last node out of the array; it fills the vacated slot
        # unless the deleted node itself was the last one.
        last = heap.pop()
        if slot >= len(heap):
            return

        last.index = slot
        heap[slot] = last
        if last.prio < elem.prio:
            self._sift_up(slot)
        elif last.prio > elem.prio:
            self._sift_down(slot)

    def decrease_prio(self, elem: Node[_ElemT], prio: float) -> None:
        """Lower the priority of a node that is already in the queue.

        This function takes time O(log(n)).
        """
        slot = elem.index
        assert self.heap[slot] is elem
        assert prio <= elem.prio
        elem.prio = prio
        self._sift_up(slot)

    def increase_prio(self, elem: Node[_ElemT], prio: float) -> None:
        """Raise the priority of a node that is already in the queue.

        This function takes time O(log(n)).
        """
        slot = elem.index
        assert self.heap[slot] is elem
        assert prio >= elem.prio
        elem.prio = prio
        self._sift_down(slot)