1
0
Fork 0

Keep alternating trees between stages

Delete only the trees that are involved in an augmenting path.
Keep the other trees and reuse them in the next stage.

This gives a big speedup on many cases such as random graphs.
The code is a mess, needs to be cleaned up.
This commit is contained in:
Joris van Rantwijk 2024-06-23 19:50:27 +02:00
parent 73641d7b70
commit 61524990d7
1 changed files with 219 additions and 99 deletions

View File

@ -69,6 +69,7 @@ def maximum_weight_matching(
# Initialize the matching algorithm. # Initialize the matching algorithm.
ctx = _MatchingContext(graph) ctx = _MatchingContext(graph)
ctx.start()
# Improve the solution until no further improvement is possible. # Improve the solution until no further improvement is possible.
# #
@ -81,6 +82,7 @@ def maximum_weight_matching(
pass pass
# Extract the final solution. # Extract the final solution.
ctx.cleanup()
pairs: list[tuple[int, int]] = [ pairs: list[tuple[int, int]] = [
(x, y) for (x, y, _w) in edges if ctx.vertex_mate[x] == y] (x, y) for (x, y, _w) in edges if ctx.vertex_mate[x] == y]
@ -592,7 +594,8 @@ class _MatchingContext:
# blossoms. The priority of an edge is its slack plus 2 times the # blossoms. The priority of an edge is its slack plus 2 times the
# running sum of delta steps. # running sum of delta steps.
self.delta3_queue: PriorityQueue[int] = PriorityQueue() self.delta3_queue: PriorityQueue[int] = PriorityQueue()
self.delta3_set: set[int] = set() self.delta3_node: list[Optional[PriorityQueue.Node]]
self.delta3_node = [None for _e in graph.edges]
# Queue containing top-level non-trivial T-blossoms. # Queue containing top-level non-trivial T-blossoms.
# The priority of a blossom is its dual plus 2 times the running # The priority of a blossom is its dual plus 2 times the running
@ -600,9 +603,12 @@ class _MatchingContext:
self.delta4_queue: PriorityQueue[_NonTrivialBlossom] = PriorityQueue() self.delta4_queue: PriorityQueue[_NonTrivialBlossom] = PriorityQueue()
# For each T-vertex or unlabeled vertex "x", # For each T-vertex or unlabeled vertex "x",
# "vertex_best_edge[x]" is the edge index of the least-slack edge # "vertex_sedge_queue[x]" is a queue of edges between "x" and any
# between "x" and any S-vertex, or -1 if no such edge has been found. # S-vertex. The priority of an edge is 2 times its pseudo-slack.
self.vertex_best_edge: list[int] = num_vertex * [-1] self.vertex_sedge_queue: list[PriorityQueue]
self.vertex_sedge_queue = [PriorityQueue() for _x in range(num_vertex)]
self.vertex_sedge_node: list[Optional[PriorityQueue.Node]]
self.vertex_sedge_node = [None for _e in graph.edges]
# Queue of S-vertices to be scanned. # Queue of S-vertices to be scanned.
self.scan_queue: list[int] = [] self.scan_queue: list[int] = []
@ -659,22 +665,7 @@ class _MatchingContext:
# after the delta step. # after the delta step.
# #
def lset_reset(self) -> None: # TODO -- rename function, maybe refactor
"""Reset least-slack edge tracking.
This function takes time O(n * log(n)).
"""
num_vertex = self.graph.num_vertex
for x in range(num_vertex):
self.vertex_best_edge[x] = -1
self.vertex_set_node[x].set_prio(math.inf)
self.delta2_queue.clear()
for blossom in self.trivial_blossom + self.nontrivial_blossom:
blossom.delta2_node = None
def lset_add_vertex_edge(self, y: int, by: _Blossom, e: int) -> None: def lset_add_vertex_edge(self, y: int, by: _Blossom, e: int) -> None:
"""Add edge "e" from an S-vertex to unlabeled vertex or T-vertex "y". """Add edge "e" from an S-vertex to unlabeled vertex or T-vertex "y".
@ -682,13 +673,14 @@ class _MatchingContext:
""" """
prio = self.edge_pseudo_slack_2x(e) prio = self.edge_pseudo_slack_2x(e)
best_edge = self.vertex_best_edge[y] improved = (self.vertex_sedge_queue[y].empty()
if best_edge != -1: or (self.vertex_sedge_queue[y].find_min().prio > prio))
best_prio = self.edge_pseudo_slack_2x(best_edge)
if prio >= best_prio:
return
self.vertex_best_edge[y] = e assert self.vertex_sedge_node[e] is None
self.vertex_sedge_node[e] = self.vertex_sedge_queue[y].insert(prio, e)
if not improved:
return
prev_min = by.vertex_set.min_prio() prev_min = by.vertex_set.min_prio()
self.vertex_set_node[y].set_prio(prio) self.vertex_set_node[y].set_prio(prio)
@ -700,6 +692,7 @@ class _MatchingContext:
elif prio < by.delta2_node.prio: elif prio < by.delta2_node.prio:
self.delta2_queue.decrease_prio(by.delta2_node, prio) self.delta2_queue.decrease_prio(by.delta2_node, prio)
# TODO -- rename function, maybe refactor
def lset_get_best_vertex_edge(self) -> tuple[int, float]: def lset_get_best_vertex_edge(self) -> tuple[int, float]:
"""Return the index and slack of the least-slack edge between """Return the index and slack of the least-slack edge between
any S-vertex and unlabeled vertex. any S-vertex and unlabeled vertex.
@ -722,8 +715,7 @@ class _MatchingContext:
assert blossom.label == _LABEL_NONE assert blossom.label == _LABEL_NONE
x = blossom.vertex_set.min_elem() x = blossom.vertex_set.min_elem()
e = self.vertex_best_edge[x] e = self.vertex_sedge_queue[x].find_min().data
assert e >= 0
return (e, slack_2x) return (e, slack_2x)
@ -731,6 +723,14 @@ class _MatchingContext:
# General support routines: # General support routines:
# #
# TODO -- Although this code is correct, there is a conceptual problem
# that the division of responsibilities is asymmetric.
# For example, assign_blossom_label_s is not the exact
# opposite of remove_blossom_label_s, and remove_blossom_label_s
# has different responsibilities from remove_blossom_label_t.
# This must be fixed by shifting responsibilities or renaming
# functions, otherwise this stuff is impossible to understand.
def assign_blossom_label_s(self, blossom: _Blossom) -> None: def assign_blossom_label_s(self, blossom: _Blossom) -> None:
"""Assign label S to an unlabeled top-level blossom.""" """Assign label S to an unlabeled top-level blossom."""
assert blossom.parent is None assert blossom.parent is None
@ -769,6 +769,18 @@ class _MatchingContext:
for x in vertices: for x in vertices:
self.vertex_dual_2x[x] += vertex_dual_fixup self.vertex_dual_2x[x] += vertex_dual_fixup
# Clean up tracking of edges from vertex "x" to S-vertices.
# We maintain that tracking only for unlabeled vertices and
# T-vertices.
self.vertex_sedge_queue[x].clear()
for e in self.graph.adjacent_edges[x]:
# TODO -- Postpone this cleanup step to the scanning of
# S-vertex incident edges.
# It will be more efficient, and also simpler to
# reason about final time complexity.
self.vertex_sedge_node[e] = None
self.vertex_set_node[x].set_prio(math.inf)
def assign_blossom_label_t(self, blossom: _Blossom) -> None: def assign_blossom_label_t(self, blossom: _Blossom) -> None:
"""Assign label T to an unlabeled top-level blossom.""" """Assign label T to an unlabeled top-level blossom."""
@ -814,10 +826,72 @@ class _MatchingContext:
# Unwind lazy updates of T-vertex dual variables. # Unwind lazy updates of T-vertex dual variables.
blossom.vertex_dual_offset += self.delta_sum_2x blossom.vertex_dual_offset += self.delta_sum_2x
def remove_vertex_label_s(self, x: int, bx: _Blossom) -> None:
"""Adjust delta tracking for S-vertex losings its label.
This function is called when vertex "x" was an S-vertex but
has just lost its label. This requires adjustments in the tracking
of delta2 and delta3.
This function is takes time O(q * log(n)),
where q is the number of incident edges of vertex "x".
This function is called at most once per vertex per stage.
This function therefore takes ammortized time O(m * log(n)) per stage.
"""
# Scan the edges that are incident on "x".
edges = self.graph.edges
for e in self.graph.adjacent_edges[x]:
(p, q, _w) = edges[e]
y = p if p != x else q
# If this edge was in the delta3_queue, remove it since
# this is no longer an edge between S-vertices.
delta3_node = self.delta3_node[e]
if delta3_node is not None:
self.delta3_queue.delete(delta3_node)
self.delta3_node[e] = None
by = self.vertex_set_node[y].find()
if by.label == _LABEL_S:
# This is an edge between "x" and an S-vertex.
# Add this edge to "vertex_sedge_queue[x]".
# Update delta2 tracking accordingly.
self.lset_add_vertex_edge(x, bx, e)
else:
# This is no longer an edge between "y" and an S-vertex.
# Remove this edge from "vertex_sedge_queue[y]".
# Update delta2 tracking accordingly.
# TODO -- untangle this mess
vertex_sedge_node = self.vertex_sedge_node[e]
if vertex_sedge_node is not None:
vertex_sedge_queue = self.vertex_sedge_queue[y]
vertex_sedge_queue.delete(vertex_sedge_node)
self.vertex_sedge_node[e] = None
if vertex_sedge_queue.empty():
prio = math.inf
else:
prio = vertex_sedge_queue.find_min().prio
if prio > self.vertex_set_node[y].prio:
self.vertex_set_node[y].set_prio(prio)
if by.label == _LABEL_NONE:
assert by.delta2_node is not None
prio = by.vertex_set.min_prio()
if prio < math.inf:
prio += by.vertex_dual_offset
if prio > by.delta2_node.prio:
self.delta2_queue.increase_prio(
by.delta2_node, prio)
else:
self.delta2_queue.delete(by.delta2_node)
by.delta2_node = None
def reset_blossom_label(self, blossom: _Blossom) -> None: def reset_blossom_label(self, blossom: _Blossom) -> None:
"""Remove blossom label and calculate true dual variables.""" """Remove blossom label."""
assert blossom.parent is None assert blossom.parent is None
assert blossom.label != _LABEL_NONE
if blossom.label == _LABEL_S: if blossom.label == _LABEL_S:
@ -834,41 +908,23 @@ class _MatchingContext:
for x in blossom.vertices(): for x in blossom.vertices():
self.vertex_dual_2x[x] += vertex_dual_fixup self.vertex_dual_2x[x] += vertex_dual_fixup
# Adjust delta tracking for S-vertex losing its label.
self.remove_vertex_label_s(x, blossom)
elif blossom.label == _LABEL_T: elif blossom.label == _LABEL_T:
# Remove label. self.remove_blossom_label_t(blossom)
blossom.label = _LABEL_NONE
# Unwind lazy delta updates to T-blossom dual variable. # Since the blossom is now unlabeled, insert it in delta2_queue
if isinstance(blossom, _NonTrivialBlossom): # if it has at least one edge to an S-vertex.
blossom.dual_var -= self.delta_sum_2x assert blossom.delta2_node is None
prio = blossom.vertex_set.min_prio()
if prio < math.inf:
prio += blossom.vertex_dual_offset
blossom.delta2_node = self.delta2_queue.insert(prio, blossom)
# Unwind lazy delta updates to T-vertex dual variables. def _check_alternating_tree_consistency(self) -> None:
vertex_dual_fixup = self.delta_sum_2x + blossom.vertex_dual_offset """TODO -- remove this function, only for debugging"""
blossom.vertex_dual_offset = 0
for x in blossom.vertices():
self.vertex_dual_2x[x] += vertex_dual_fixup
else:
# Unwind lazy delta updates to vertex dual variables.
vertex_dual_fixup = blossom.vertex_dual_offset
blossom.vertex_dual_offset = 0
for x in blossom.vertices():
self.vertex_dual_2x[x] += vertex_dual_fixup
def reset_stage(self) -> None:
"""Reset data which are only valid during a stage.
Marks all blossoms as unlabeled, clears the queue,
and resets tracking of least-slack edges.
This function takes time O(n * log(n)).
"""
assert not self.scan_queue
# Check consistency of alternating tree.
for blossom in self.trivial_blossom + self.nontrivial_blossom: for blossom in self.trivial_blossom + self.nontrivial_blossom:
if (blossom.parent is None) and (blossom.label != _LABEL_NONE): if (blossom.parent is None) and (blossom.label != _LABEL_NONE):
assert blossom.tree_blossoms is not None assert blossom.tree_blossoms is not None
@ -879,25 +935,58 @@ class _MatchingContext:
assert bx.tree_blossoms is blossom.tree_blossoms assert bx.tree_blossoms is blossom.tree_blossoms
assert by.tree_blossoms is blossom.tree_blossoms assert by.tree_blossoms is blossom.tree_blossoms
else: else:
assert blossom.tree_edge is None
assert blossom.tree_blossoms is None assert blossom.tree_blossoms is None
# Remove blossom labels and unwind lazy dual updates. def reset_stage(self) -> None:
"""Reset data which are only valid during a stage.
Marks all blossoms as unlabeled, applies delayed delta updates,
and resets tracking of least-slack edges.
This function takes time O((n + m) * log(n)).
"""
assert not self.scan_queue
for blossom in self.trivial_blossom + self.nontrivial_blossom: for blossom in self.trivial_blossom + self.nontrivial_blossom:
if blossom.parent is None:
# Remove blossom label.
if (blossom.parent is None) and (blossom.label != _LABEL_NONE):
self.reset_blossom_label(blossom) self.reset_blossom_label(blossom)
if isinstance(blossom, _NonTrivialBlossom):
blossom.delta4_node = None
assert blossom.label == _LABEL_NONE assert blossom.label == _LABEL_NONE
# Remove blossom from alternating tree.
blossom.tree_edge = None blossom.tree_edge = None
blossom.tree_blossoms = None blossom.tree_blossoms = None
# Reset least-slack edge tracking. # Unwind lazy delta updates to vertex dual variables.
self.lset_reset() if blossom.vertex_dual_offset != 0:
for x in blossom.vertices():
self.vertex_dual_2x[x] += blossom.vertex_dual_offset
blossom.vertex_dual_offset = 0
# Reset delta queues. assert self.delta2_queue.empty()
self.delta3_queue.clear() assert self.delta3_queue.empty()
self.delta3_set.clear() assert self.delta4_queue.empty()
self.delta4_queue.clear()
def remove_alternating_tree(
self,
tree_blossoms: set[_Blossom]
) -> None:
"""Reset the alternating tree consisting of the specified blossoms.
Marks the blossoms as unlabeled.
Updates delta tracking accordingly.
This function takes time O((n+m) * log(n)).
"""
for blossom in tree_blossoms:
assert blossom.label != _LABEL_NONE
assert blossom.tree_blossoms is tree_blossoms
self.reset_blossom_label(blossom)
blossom.tree_edge = None
blossom.tree_blossoms = None
def trace_alternating_paths(self, x: int, y: int) -> _AlternatingPath: def trace_alternating_paths(self, x: int, y: int) -> _AlternatingPath:
"""Trace back through the alternating trees from vertices "x" and "y". """Trace back through the alternating trees from vertices "x" and "y".
@ -1014,22 +1103,21 @@ class _MatchingContext:
# Remove blossom labels. # Remove blossom labels.
# Mark vertices inside former T-blossoms as S-vertices. # Mark vertices inside former T-blossoms as S-vertices.
tree_blossoms = subblossoms[0].tree_blossoms
assert tree_blossoms is not None
for sub in subblossoms: for sub in subblossoms:
if sub.label == _LABEL_S: if sub.label == _LABEL_S:
self.remove_blossom_label_s(sub) self.remove_blossom_label_s(sub)
elif sub.label == _LABEL_T: elif sub.label == _LABEL_T:
self.remove_blossom_label_t(sub) self.remove_blossom_label_t(sub)
self.assign_vertex_label_s(sub) self.assign_vertex_label_s(sub)
sub.tree_blossoms = None
tree_blossoms.remove(sub)
# Create the new blossom object. # Create the new blossom object.
blossom = _NonTrivialBlossom(subblossoms, path.edges) blossom = _NonTrivialBlossom(subblossoms, path.edges)
# Assign label S to the new blossom and link it to the tree. # Assign label S to the new blossom and link it to the tree.
self.assign_blossom_label_s(blossom) self.assign_blossom_label_s(blossom)
tree_blossoms = subblossoms[0].tree_blossoms
assert tree_blossoms is not None
blossom.tree_edge = subblossoms[0].tree_edge blossom.tree_edge = subblossoms[0].tree_edge
blossom.tree_blossoms = tree_blossoms blossom.tree_blossoms = tree_blossoms
tree_blossoms.add(blossom) tree_blossoms.add(blossom)
@ -1041,6 +1129,11 @@ class _MatchingContext:
for sub in subblossoms: for sub in subblossoms:
sub.parent = blossom sub.parent = blossom
# Remove subblossom from the alternating tree.
sub.tree_edge = None
sub.tree_blossoms = None
tree_blossoms.remove(sub)
# Merge union-find structures. # Merge union-find structures.
blossom.vertex_set.merge([sub.vertex_set for sub in subblossoms]) blossom.vertex_set.merge([sub.vertex_set for sub in subblossoms])
@ -1496,7 +1589,7 @@ class _MatchingContext:
# Update tracking of least-slack edges between S-blossoms. # Update tracking of least-slack edges between S-blossoms.
# Priority is edge slack plus 2 times the running sum of # Priority is edge slack plus 2 times the running sum of
# delta steps. # delta steps.
if e not in self.delta3_set: if self.delta3_node[e] is None:
prio_2x = self.edge_pseudo_slack_2x(e) prio_2x = self.edge_pseudo_slack_2x(e)
if self.graph.integer_weights: if self.graph.integer_weights:
# If all edge weights are integers, the slack of # If all edge weights are integers, the slack of
@ -1505,8 +1598,7 @@ class _MatchingContext:
prio = prio_2x // 2 prio = prio_2x // 2
else: else:
prio = prio_2x / 2 prio = prio_2x / 2
self.delta3_set.add(e) self.delta3_node[e] = self.delta3_queue.insert(prio, e)
self.delta3_queue.insert(prio, e)
else: else:
# Update tracking of least-slack edges from vertex "y" to # Update tracking of least-slack edges from vertex "y" to
# any S-vertex. We do this for T-vertices and unlabeled # any S-vertex. We do this for T-vertices and unlabeled
@ -1584,7 +1676,7 @@ class _MatchingContext:
# existing edges in the queue may become intra-blossom when # existing edges in the queue may become intra-blossom when
# a new blossom is formed. # a new blossom is formed.
self.delta3_queue.delete(delta3_node) self.delta3_queue.delete(delta3_node)
self.delta3_set.remove(e) self.delta3_node[e] = None
# Compute delta4: half minimum dual variable of a top-level T-blossom. # Compute delta4: half minimum dual variable of a top-level T-blossom.
# This takes time O(log(n)). # This takes time O(log(n)).
@ -1601,9 +1693,30 @@ class _MatchingContext:
return (delta_type, delta_2x, delta_edge, delta_blossom) return (delta_type, delta_2x, delta_edge, delta_blossom)
# #
# Main stage function: # Main algorithm:
# #
def start(self) -> None:
"""Mark each vertex as the node of an alternating tree.
Assign label S to all vertices and add them to the scan queue.
This function takes time O(n).
It is called once, at the beginning of the algorithm.
"""
for x in range(self.graph.num_vertex):
assert self.vertex_mate[x] == -1
bx = self.vertex_set_node[x].find()
assert bx.base_vertex == x
# Assign label S.
self.assign_blossom_label_s(bx)
self.assign_vertex_label_s(bx)
# Mark blossom as the root of an alternating tree.
bx.tree_edge = None
bx.tree_blossoms = {bx}
def run_stage(self) -> bool: def run_stage(self) -> bool:
"""Run one stage of the matching algorithm. """Run one stage of the matching algorithm.
@ -1619,17 +1732,10 @@ class _MatchingContext:
False if no further improvement is possible. False if no further improvement is possible.
""" """
num_vertex = self.graph.num_vertex
# Assign label S to all unmatched vertices and put them in the queue.
for x in range(num_vertex):
if self.vertex_mate[x] == -1:
self.extend_tree_s(x)
# Stop if all vertices are matched. # Stop if all vertices are matched.
# No further improvement is possible in that case. # No further improvement is possible in that case.
# This avoids messy calculations of delta steps without any S-vertex. # This avoids messy calculations of delta steps without any S-vertex.
if not self.scan_queue: if all(y >= 0 for y in self.vertex_mate):
return False return False
# Each pass through the following loop is a "substage". # Each pass through the following loop is a "substage".
@ -1639,9 +1745,10 @@ class _MatchingContext:
# next substage, or stop if no further improvement is possible. # next substage, or stop if no further improvement is possible.
# #
# This loop runs through at most O(n) iterations per stage. # This loop runs through at most O(n) iterations per stage.
augmenting_path = None
while True: while True:
# self._check_alternating_tree_consistency() # TODO -- remove this
# Consider the incident edges of newly labeled S-vertices. # Consider the incident edges of newly labeled S-vertices.
self.substage_scan() self.substage_scan()
@ -1665,11 +1772,22 @@ class _MatchingContext:
elif delta_type == 3: elif delta_type == 3:
# Use the S-to-S edge that got unlocked by the delta update. # Use the S-to-S edge that got unlocked by the delta update.
# This may reveal an augmenting path. # This reveals either a new blossom or an augmenting path.
(x, y, _w) = self.graph.edges[delta_edge] (x, y, _w) = self.graph.edges[delta_edge]
augmenting_path = self.add_s_to_s_edge(x, y) augmenting_path = self.add_s_to_s_edge(x, y)
if augmenting_path is not None: if augmenting_path is not None:
break # Found augmenting path.
# Delete the two alternating trees on the augmenting path.
bx = self.vertex_set_node[x].find()
by = self.vertex_set_node[y].find()
assert bx.tree_blossoms is not None
assert by.tree_blossoms is not None
self.remove_alternating_tree(bx.tree_blossoms)
self.remove_alternating_tree(by.tree_blossoms)
# Augment the matching.
self.augment_matching(augmenting_path)
# End the stage.
return True
elif delta_type == 4: elif delta_type == 4:
# Expand the T-blossom that reached dual value 0 through # Expand the T-blossom that reached dual value 0 through
@ -1678,20 +1796,22 @@ class _MatchingContext:
self.expand_t_blossom(delta_blossom) self.expand_t_blossom(delta_blossom)
else: else:
# No further improvement possible. End the stage. # No further improvement possible. End the algorithm.
assert delta_type == 1 assert delta_type == 1
break return False
# Augment the matching if an augmenting path was found. def cleanup(self) -> None:
if augmenting_path is not None: """Remove all alternating trees and mark all blossoms as unlabeled.
self.augment_matching(augmenting_path)
# Remove all labels, clear queue. Also applies delayed updates to dual variables.
Also resets tracking of least-slack edges.
This function takes time O(n * log(n)).
It is called only once, at the end of the algorithm.
"""
# TODO -- move that function in here
self.reset_stage() self.reset_stage()
# Return True if the matching was augmented.
return (augmenting_path is not None)
def _verify_blossom_edges( def _verify_blossom_edges(
ctx: _MatchingContext, ctx: _MatchingContext,