1
0
Fork 0

Keep alternating trees between stages

Delete only the trees that are involved in an augmenting path.
Keep the other trees and reuse them in the next stage.

This gives a big speedup in many cases, such as random graphs.
The code is still a mess and needs to be cleaned up.
This commit is contained in:
Joris van Rantwijk 2024-06-23 19:50:27 +02:00
parent 73641d7b70
commit 61524990d7
1 changed files with 219 additions and 99 deletions

View File

@ -69,6 +69,7 @@ def maximum_weight_matching(
# Initialize the matching algorithm.
ctx = _MatchingContext(graph)
ctx.start()
# Improve the solution until no further improvement is possible.
#
@ -81,6 +82,7 @@ def maximum_weight_matching(
pass
# Extract the final solution.
ctx.cleanup()
pairs: list[tuple[int, int]] = [
(x, y) for (x, y, _w) in edges if ctx.vertex_mate[x] == y]
@ -592,7 +594,8 @@ class _MatchingContext:
# blossoms. The priority of an edge is its slack plus 2 times the
# running sum of delta steps.
self.delta3_queue: PriorityQueue[int] = PriorityQueue()
self.delta3_set: set[int] = set()
self.delta3_node: list[Optional[PriorityQueue.Node]]
self.delta3_node = [None for _e in graph.edges]
# Queue containing top-level non-trivial T-blossoms.
# The priority of a blossom is its dual plus 2 times the running
@ -600,9 +603,12 @@ class _MatchingContext:
self.delta4_queue: PriorityQueue[_NonTrivialBlossom] = PriorityQueue()
# For each T-vertex or unlabeled vertex "x",
# "vertex_best_edge[x]" is the edge index of the least-slack edge
# between "x" and any S-vertex, or -1 if no such edge has been found.
self.vertex_best_edge: list[int] = num_vertex * [-1]
# "vertex_sedge_queue[x]" is a queue of edges between "x" and any
# S-vertex. The priority of an edge is 2 times its pseudo-slack.
self.vertex_sedge_queue: list[PriorityQueue]
self.vertex_sedge_queue = [PriorityQueue() for _x in range(num_vertex)]
self.vertex_sedge_node: list[Optional[PriorityQueue.Node]]
self.vertex_sedge_node = [None for _e in graph.edges]
# Queue of S-vertices to be scanned.
self.scan_queue: list[int] = []
@ -659,22 +665,7 @@ class _MatchingContext:
# after the delta step.
#
def lset_reset(self) -> None:
"""Reset least-slack edge tracking.
Clears the per-vertex best edge, resets every vertex set-node priority
to infinity, empties "delta2_queue" and drops each blossom's handle
into that queue.
This function takes time O(n * log(n)).
"""
num_vertex = self.graph.num_vertex
# Forget the best S-edge of every vertex (-1 means "no edge found")
# and mark its priority as infinite.
for x in range(num_vertex):
self.vertex_best_edge[x] = -1
self.vertex_set_node[x].set_prio(math.inf)
# Empty the delta2 queue, then invalidate the node handles that the
# blossoms still hold into it.
self.delta2_queue.clear()
for blossom in self.trivial_blossom + self.nontrivial_blossom:
blossom.delta2_node = None
# TODO -- rename function, maybe refactor
def lset_add_vertex_edge(self, y: int, by: _Blossom, e: int) -> None:
"""Add edge "e" from an S-vertex to unlabeled vertex or T-vertex "y".
@ -682,13 +673,14 @@ class _MatchingContext:
"""
prio = self.edge_pseudo_slack_2x(e)
best_edge = self.vertex_best_edge[y]
if best_edge != -1:
best_prio = self.edge_pseudo_slack_2x(best_edge)
if prio >= best_prio:
return
improved = (self.vertex_sedge_queue[y].empty()
or (self.vertex_sedge_queue[y].find_min().prio > prio))
self.vertex_best_edge[y] = e
assert self.vertex_sedge_node[e] is None
self.vertex_sedge_node[e] = self.vertex_sedge_queue[y].insert(prio, e)
if not improved:
return
prev_min = by.vertex_set.min_prio()
self.vertex_set_node[y].set_prio(prio)
@ -700,6 +692,7 @@ class _MatchingContext:
elif prio < by.delta2_node.prio:
self.delta2_queue.decrease_prio(by.delta2_node, prio)
# TODO -- rename function, maybe refactor
def lset_get_best_vertex_edge(self) -> tuple[int, float]:
"""Return the index and slack of the least-slack edge between
any S-vertex and unlabeled vertex.
@ -722,8 +715,7 @@ class _MatchingContext:
assert blossom.label == _LABEL_NONE
x = blossom.vertex_set.min_elem()
e = self.vertex_best_edge[x]
assert e >= 0
e = self.vertex_sedge_queue[x].find_min().data
return (e, slack_2x)
@ -731,6 +723,14 @@ class _MatchingContext:
# General support routines:
#
# TODO -- Although this code is correct, there is a conceptual problem
# that the division of responsibilities is asymmetric.
# For example, assign_blossom_label_s is not the exact
# opposite of remove_blossom_label_s, and remove_blossom_label_s
# has different responsibilities from remove_blossom_label_t.
# This must be fixed by shifting responsibilities or renaming
# functions, otherwise this stuff is impossible to understand.
def assign_blossom_label_s(self, blossom: _Blossom) -> None:
"""Assign label S to an unlabeled top-level blossom."""
assert blossom.parent is None
@ -769,6 +769,18 @@ class _MatchingContext:
for x in vertices:
self.vertex_dual_2x[x] += vertex_dual_fixup
# Clean up tracking of edges from vertex "x" to S-vertices.
# We maintain that tracking only for unlabeled vertices and
# T-vertices.
self.vertex_sedge_queue[x].clear()
for e in self.graph.adjacent_edges[x]:
# TODO -- Postpone this cleanup step to the scanning of
# S-vertex incident edges.
# It will be more efficient, and also simpler to
# reason about final time complexity.
self.vertex_sedge_node[e] = None
self.vertex_set_node[x].set_prio(math.inf)
def assign_blossom_label_t(self, blossom: _Blossom) -> None:
"""Assign label T to an unlabeled top-level blossom."""
@ -814,10 +826,72 @@ class _MatchingContext:
# Unwind lazy updates of T-vertex dual variables.
blossom.vertex_dual_offset += self.delta_sum_2x
def remove_vertex_label_s(self, x: int, bx: _Blossom) -> None:
"""Adjust delta tracking for an S-vertex losing its label.
This function is called when vertex "x" was an S-vertex but
has just lost its label. This requires adjustments in the tracking
of delta2 and delta3.
Parameters:
x: Index of the vertex that has just lost label S.
bx: Blossom passed on to "lset_add_vertex_edge" as the blossom
of "x" (at the visible call site this is the top-level blossom
whose vertices are being unlabeled).
This function takes time O(q * log(n)),
where q is the number of incident edges of vertex "x".
This function is called at most once per vertex per stage.
This function therefore takes amortized time O(m * log(n)) per stage.
"""
# NOTE(review): indentation was lost in this rendering; the nesting of
# the statements below must be confirmed against the repository.
# Scan the edges that are incident on "x".
edges = self.graph.edges
for e in self.graph.adjacent_edges[x]:
(p, q, _w) = edges[e]
# "y" is the endpoint of edge "e" other than "x".
y = p if p != x else q
# If this edge was in the delta3_queue, remove it since
# this is no longer an edge between S-vertices.
delta3_node = self.delta3_node[e]
if delta3_node is not None:
self.delta3_queue.delete(delta3_node)
self.delta3_node[e] = None
# Find the blossom that currently owns "y" in the union-find structure
# (presumably its top-level blossom -- confirm).
by = self.vertex_set_node[y].find()
if by.label == _LABEL_S:
# This is an edge between "x" and an S-vertex.
# Add this edge to "vertex_sedge_queue[x]".
# Update delta2 tracking accordingly.
self.lset_add_vertex_edge(x, bx, e)
else:
# This is no longer an edge between "y" and an S-vertex.
# Remove this edge from "vertex_sedge_queue[y]".
# Update delta2 tracking accordingly.
# TODO -- untangle this mess
vertex_sedge_node = self.vertex_sedge_node[e]
if vertex_sedge_node is not None:
vertex_sedge_queue = self.vertex_sedge_queue[y]
vertex_sedge_queue.delete(vertex_sedge_node)
self.vertex_sedge_node[e] = None
# The new priority of "y" is the smallest priority left in its
# S-edge queue, or infinity if that queue is now empty.
if vertex_sedge_queue.empty():
prio = math.inf
else:
prio = vertex_sedge_queue.find_min().prio
# Only touch the set node when the minimum actually went up.
if prio > self.vertex_set_node[y].prio:
self.vertex_set_node[y].set_prio(prio)
if by.label == _LABEL_NONE:
assert by.delta2_node is not None
# Recompute the delta2 priority of the unlabeled blossom from the
# minimum over its vertices; raise it in the queue, or drop the
# blossom from the queue when no finite priority remains.
prio = by.vertex_set.min_prio()
if prio < math.inf:
prio += by.vertex_dual_offset
if prio > by.delta2_node.prio:
self.delta2_queue.increase_prio(
by.delta2_node, prio)
else:
self.delta2_queue.delete(by.delta2_node)
by.delta2_node = None
def reset_blossom_label(self, blossom: _Blossom) -> None:
"""Remove blossom label and calculate true dual variables."""
"""Remove blossom label."""
assert blossom.parent is None
assert blossom.label != _LABEL_NONE
if blossom.label == _LABEL_S:
@ -834,41 +908,23 @@ class _MatchingContext:
for x in blossom.vertices():
self.vertex_dual_2x[x] += vertex_dual_fixup
# Adjust delta tracking for S-vertex losing its label.
self.remove_vertex_label_s(x, blossom)
elif blossom.label == _LABEL_T:
# Remove label.
blossom.label = _LABEL_NONE
self.remove_blossom_label_t(blossom)
# Unwind lazy delta updates to T-blossom dual variable.
if isinstance(blossom, _NonTrivialBlossom):
blossom.dual_var -= self.delta_sum_2x
# Since the blossom is now unlabeled, insert it in delta2_queue
# if it has at least one edge to an S-vertex.
assert blossom.delta2_node is None
prio = blossom.vertex_set.min_prio()
if prio < math.inf:
prio += blossom.vertex_dual_offset
blossom.delta2_node = self.delta2_queue.insert(prio, blossom)
# Unwind lazy delta updates to T-vertex dual variables.
vertex_dual_fixup = self.delta_sum_2x + blossom.vertex_dual_offset
blossom.vertex_dual_offset = 0
for x in blossom.vertices():
self.vertex_dual_2x[x] += vertex_dual_fixup
else:
# Unwind lazy delta updates to vertex dual variables.
vertex_dual_fixup = blossom.vertex_dual_offset
blossom.vertex_dual_offset = 0
for x in blossom.vertices():
self.vertex_dual_2x[x] += vertex_dual_fixup
def reset_stage(self) -> None:
"""Reset data which are only valid during a stage.
Marks all blossoms as unlabeled, clears the queue,
and resets tracking of least-slack edges.
This function takes time O(n * log(n)).
"""
assert not self.scan_queue
# Check consistency of alternating tree.
def _check_alternating_tree_consistency(self) -> None:
"""TODO -- remove this function, only for debugging"""
for blossom in self.trivial_blossom + self.nontrivial_blossom:
if (blossom.parent is None) and (blossom.label != _LABEL_NONE):
assert blossom.tree_blossoms is not None
@ -879,25 +935,58 @@ class _MatchingContext:
assert bx.tree_blossoms is blossom.tree_blossoms
assert by.tree_blossoms is blossom.tree_blossoms
else:
assert blossom.tree_edge is None
assert blossom.tree_blossoms is None
# Remove blossom labels and unwind lazy dual updates.
def reset_stage(self) -> None:
"""Reset data which are only valid during a stage.
Marks all blossoms as unlabeled, applies delayed delta updates,
and resets tracking of least-slack edges.
This function takes time O((n + m) * log(n)).
"""
assert not self.scan_queue
for blossom in self.trivial_blossom + self.nontrivial_blossom:
if blossom.parent is None:
# Remove blossom label.
if (blossom.parent is None) and (blossom.label != _LABEL_NONE):
self.reset_blossom_label(blossom)
if isinstance(blossom, _NonTrivialBlossom):
blossom.delta4_node = None
assert blossom.label == _LABEL_NONE
# Remove blossom from alternating tree.
blossom.tree_edge = None
blossom.tree_blossoms = None
# Reset least-slack edge tracking.
self.lset_reset()
# Unwind lazy delta updates to vertex dual variables.
if blossom.vertex_dual_offset != 0:
for x in blossom.vertices():
self.vertex_dual_2x[x] += blossom.vertex_dual_offset
blossom.vertex_dual_offset = 0
# Reset delta queues.
self.delta3_queue.clear()
self.delta3_set.clear()
self.delta4_queue.clear()
assert self.delta2_queue.empty()
assert self.delta3_queue.empty()
assert self.delta4_queue.empty()
def remove_alternating_tree(
self,
tree_blossoms: set[_Blossom]
) -> None:
"""Reset the alternating tree consisting of the specified blossoms.
Marks the blossoms as unlabeled.
Updates delta tracking accordingly.
Parameters:
tree_blossoms: The set of blossoms that make up one alternating
tree. Every member must be labeled and must refer to this same
set object through its "tree_blossoms" attribute.
This function takes time O((n+m) * log(n)).
"""
for blossom in tree_blossoms:
# Sanity checks: the blossom is labeled and belongs to this tree.
assert blossom.label != _LABEL_NONE
assert blossom.tree_blossoms is tree_blossoms
# Remove the label through "reset_blossom_label".
self.reset_blossom_label(blossom)
# Detach the blossom from the alternating tree.
blossom.tree_edge = None
blossom.tree_blossoms = None
def trace_alternating_paths(self, x: int, y: int) -> _AlternatingPath:
"""Trace back through the alternating trees from vertices "x" and "y".
@ -1014,22 +1103,21 @@ class _MatchingContext:
# Remove blossom labels.
# Mark vertices inside former T-blossoms as S-vertices.
tree_blossoms = subblossoms[0].tree_blossoms
assert tree_blossoms is not None
for sub in subblossoms:
if sub.label == _LABEL_S:
self.remove_blossom_label_s(sub)
elif sub.label == _LABEL_T:
self.remove_blossom_label_t(sub)
self.assign_vertex_label_s(sub)
sub.tree_blossoms = None
tree_blossoms.remove(sub)
# Create the new blossom object.
blossom = _NonTrivialBlossom(subblossoms, path.edges)
# Assign label S to the new blossom and link it to the tree.
self.assign_blossom_label_s(blossom)
tree_blossoms = subblossoms[0].tree_blossoms
assert tree_blossoms is not None
blossom.tree_edge = subblossoms[0].tree_edge
blossom.tree_blossoms = tree_blossoms
tree_blossoms.add(blossom)
@ -1041,6 +1129,11 @@ class _MatchingContext:
for sub in subblossoms:
sub.parent = blossom
# Remove subblossom from the alternating tree.
sub.tree_edge = None
sub.tree_blossoms = None
tree_blossoms.remove(sub)
# Merge union-find structures.
blossom.vertex_set.merge([sub.vertex_set for sub in subblossoms])
@ -1496,7 +1589,7 @@ class _MatchingContext:
# Update tracking of least-slack edges between S-blossoms.
# Priority is edge slack plus 2 times the running sum of
# delta steps.
if e not in self.delta3_set:
if self.delta3_node[e] is None:
prio_2x = self.edge_pseudo_slack_2x(e)
if self.graph.integer_weights:
# If all edge weights are integers, the slack of
@ -1505,8 +1598,7 @@ class _MatchingContext:
prio = prio_2x // 2
else:
prio = prio_2x / 2
self.delta3_set.add(e)
self.delta3_queue.insert(prio, e)
self.delta3_node[e] = self.delta3_queue.insert(prio, e)
else:
# Update tracking of least-slack edges from vertex "y" to
# any S-vertex. We do this for T-vertices and unlabeled
@ -1584,7 +1676,7 @@ class _MatchingContext:
# existing edges in the queue may become intra-blossom when
# a new blossom is formed.
self.delta3_queue.delete(delta3_node)
self.delta3_set.remove(e)
self.delta3_node[e] = None
# Compute delta4: half minimum dual variable of a top-level T-blossom.
# This takes time O(log(n)).
@ -1601,9 +1693,30 @@ class _MatchingContext:
return (delta_type, delta_2x, delta_edge, delta_blossom)
#
# Main stage function:
# Main algorithm:
#
def start(self) -> None:
"""Mark each vertex as the node of an alternating tree.
Assign label S to all vertices and add them to the scan queue.
Every vertex must still be unmatched and must be the base vertex of
the blossom returned by its union-find node; both are asserted.
This function takes time O(n).
It is called once, at the beginning of the algorithm.
"""
for x in range(self.graph.num_vertex):
# No vertex may be matched yet.
assert self.vertex_mate[x] == -1
bx = self.vertex_set_node[x].find()
# The blossom found for "x" must have "x" as its base vertex.
assert bx.base_vertex == x
# Assign label S.
self.assign_blossom_label_s(bx)
self.assign_vertex_label_s(bx)
# Mark blossom as the root of an alternating tree.
# A root has no tree edge, and its tree initially contains only itself.
bx.tree_edge = None
bx.tree_blossoms = {bx}
def run_stage(self) -> bool:
"""Run one stage of the matching algorithm.
@ -1619,17 +1732,10 @@ class _MatchingContext:
False if no further improvement is possible.
"""
num_vertex = self.graph.num_vertex
# Assign label S to all unmatched vertices and put them in the queue.
for x in range(num_vertex):
if self.vertex_mate[x] == -1:
self.extend_tree_s(x)
# Stop if all vertices are matched.
# No further improvement is possible in that case.
# This avoids messy calculations of delta steps without any S-vertex.
if not self.scan_queue:
if all(y >= 0 for y in self.vertex_mate):
return False
# Each pass through the following loop is a "substage".
@ -1639,9 +1745,10 @@ class _MatchingContext:
# next substage, or stop if no further improvement is possible.
#
# This loop runs through at most O(n) iterations per stage.
augmenting_path = None
while True:
# self._check_alternating_tree_consistency() # TODO -- remove this
# Consider the incident edges of newly labeled S-vertices.
self.substage_scan()
@ -1665,11 +1772,22 @@ class _MatchingContext:
elif delta_type == 3:
# Use the S-to-S edge that got unlocked by the delta update.
# This may reveal an augmenting path.
# This reveals either a new blossom or an augmenting path.
(x, y, _w) = self.graph.edges[delta_edge]
augmenting_path = self.add_s_to_s_edge(x, y)
if augmenting_path is not None:
break
# Found augmenting path.
# Delete the two alternating trees on the augmenting path.
bx = self.vertex_set_node[x].find()
by = self.vertex_set_node[y].find()
assert bx.tree_blossoms is not None
assert by.tree_blossoms is not None
self.remove_alternating_tree(bx.tree_blossoms)
self.remove_alternating_tree(by.tree_blossoms)
# Augment the matching.
self.augment_matching(augmenting_path)
# End the stage.
return True
elif delta_type == 4:
# Expand the T-blossom that reached dual value 0 through
@ -1678,20 +1796,22 @@ class _MatchingContext:
self.expand_t_blossom(delta_blossom)
else:
# No further improvement possible. End the stage.
# No further improvement possible. End the algorithm.
assert delta_type == 1
break
return False
# Augment the matching if an augmenting path was found.
if augmenting_path is not None:
self.augment_matching(augmenting_path)
def cleanup(self) -> None:
"""Remove all alternating trees and mark all blossoms as unlabeled.
# Remove all labels, clear queue.
Also applies delayed updates to dual variables.
Also resets tracking of least-slack edges.
This function takes time O((n + m) * log(n)).
It is called only once, at the end of the algorithm.
"""
# TODO -- move that function in here
self.reset_stage()
# Return True if the matching was augmented.
return (augmenting_path is not None)
def _verify_blossom_edges(
ctx: _MatchingContext,