Keep alternating trees between stages
Delete only the trees that are involved in an augmenting path. Keep the other trees and reuse them in the next stage. This gives a big speedup on many cases such as random graphs. The code is a mess, needs to be cleaned up.
This commit is contained in:
parent
73641d7b70
commit
61524990d7
|
@ -69,6 +69,7 @@ def maximum_weight_matching(
|
|||
|
||||
# Initialize the matching algorithm.
|
||||
ctx = _MatchingContext(graph)
|
||||
ctx.start()
|
||||
|
||||
# Improve the solution until no further improvement is possible.
|
||||
#
|
||||
|
@ -81,6 +82,7 @@ def maximum_weight_matching(
|
|||
pass
|
||||
|
||||
# Extract the final solution.
|
||||
ctx.cleanup()
|
||||
pairs: list[tuple[int, int]] = [
|
||||
(x, y) for (x, y, _w) in edges if ctx.vertex_mate[x] == y]
|
||||
|
||||
|
@ -592,7 +594,8 @@ class _MatchingContext:
|
|||
# blossoms. The priority of an edge is its slack plus 2 times the
|
||||
# running sum of delta steps.
|
||||
self.delta3_queue: PriorityQueue[int] = PriorityQueue()
|
||||
self.delta3_set: set[int] = set()
|
||||
self.delta3_node: list[Optional[PriorityQueue.Node]]
|
||||
self.delta3_node = [None for _e in graph.edges]
|
||||
|
||||
# Queue containing top-level non-trivial T-blossoms.
|
||||
# The priority of a blossom is its dual plus 2 times the running
|
||||
|
@ -600,9 +603,12 @@ class _MatchingContext:
|
|||
self.delta4_queue: PriorityQueue[_NonTrivialBlossom] = PriorityQueue()
|
||||
|
||||
# For each T-vertex or unlabeled vertex "x",
|
||||
# "vertex_best_edge[x]" is the edge index of the least-slack edge
|
||||
# between "x" and any S-vertex, or -1 if no such edge has been found.
|
||||
self.vertex_best_edge: list[int] = num_vertex * [-1]
|
||||
# "vertex_sedge_queue[x]" is a queue of edges between "x" and any
|
||||
# S-vertex. The priority of an edge is 2 times its pseudo-slack.
|
||||
self.vertex_sedge_queue: list[PriorityQueue]
|
||||
self.vertex_sedge_queue = [PriorityQueue() for _x in range(num_vertex)]
|
||||
self.vertex_sedge_node: list[Optional[PriorityQueue.Node]]
|
||||
self.vertex_sedge_node = [None for _e in graph.edges]
|
||||
|
||||
# Queue of S-vertices to be scanned.
|
||||
self.scan_queue: list[int] = []
|
||||
|
@ -659,22 +665,7 @@ class _MatchingContext:
|
|||
# after the delta step.
|
||||
#
|
||||
|
||||
def lset_reset(self) -> None:
|
||||
"""Reset least-slack edge tracking.
|
||||
|
||||
This function takes time O(n * log(n)).
|
||||
"""
|
||||
num_vertex = self.graph.num_vertex
|
||||
|
||||
for x in range(num_vertex):
|
||||
self.vertex_best_edge[x] = -1
|
||||
self.vertex_set_node[x].set_prio(math.inf)
|
||||
|
||||
self.delta2_queue.clear()
|
||||
|
||||
for blossom in self.trivial_blossom + self.nontrivial_blossom:
|
||||
blossom.delta2_node = None
|
||||
|
||||
# TODO -- rename function, maybe refactor
|
||||
def lset_add_vertex_edge(self, y: int, by: _Blossom, e: int) -> None:
|
||||
"""Add edge "e" from an S-vertex to unlabeled vertex or T-vertex "y".
|
||||
|
||||
|
@ -682,13 +673,14 @@ class _MatchingContext:
|
|||
"""
|
||||
prio = self.edge_pseudo_slack_2x(e)
|
||||
|
||||
best_edge = self.vertex_best_edge[y]
|
||||
if best_edge != -1:
|
||||
best_prio = self.edge_pseudo_slack_2x(best_edge)
|
||||
if prio >= best_prio:
|
||||
return
|
||||
improved = (self.vertex_sedge_queue[y].empty()
|
||||
or (self.vertex_sedge_queue[y].find_min().prio > prio))
|
||||
|
||||
self.vertex_best_edge[y] = e
|
||||
assert self.vertex_sedge_node[e] is None
|
||||
self.vertex_sedge_node[e] = self.vertex_sedge_queue[y].insert(prio, e)
|
||||
|
||||
if not improved:
|
||||
return
|
||||
|
||||
prev_min = by.vertex_set.min_prio()
|
||||
self.vertex_set_node[y].set_prio(prio)
|
||||
|
@ -700,6 +692,7 @@ class _MatchingContext:
|
|||
elif prio < by.delta2_node.prio:
|
||||
self.delta2_queue.decrease_prio(by.delta2_node, prio)
|
||||
|
||||
# TODO -- rename function, maybe refactor
|
||||
def lset_get_best_vertex_edge(self) -> tuple[int, float]:
|
||||
"""Return the index and slack of the least-slack edge between
|
||||
any S-vertex and unlabeled vertex.
|
||||
|
@ -722,8 +715,7 @@ class _MatchingContext:
|
|||
assert blossom.label == _LABEL_NONE
|
||||
|
||||
x = blossom.vertex_set.min_elem()
|
||||
e = self.vertex_best_edge[x]
|
||||
assert e >= 0
|
||||
e = self.vertex_sedge_queue[x].find_min().data
|
||||
|
||||
return (e, slack_2x)
|
||||
|
||||
|
@ -731,6 +723,14 @@ class _MatchingContext:
|
|||
# General support routines:
|
||||
#
|
||||
|
||||
# TODO -- Although this code is correct, there is a conceptual problem
|
||||
# that the division of responsibilities is asymmetric.
|
||||
# For example, assign_blossom_label_s is not the exact
|
||||
# opposite of remove_blossom_label_s, and remove_blossom_label_s
|
||||
# has different responsibilities from remove_blossom_label_t.
|
||||
# This must be fixed by shifting responsibilities or renaming
|
||||
# functions, otherwise this stuff is impossible to understand.
|
||||
|
||||
def assign_blossom_label_s(self, blossom: _Blossom) -> None:
|
||||
"""Assign label S to an unlabeled top-level blossom."""
|
||||
assert blossom.parent is None
|
||||
|
@ -769,6 +769,18 @@ class _MatchingContext:
|
|||
for x in vertices:
|
||||
self.vertex_dual_2x[x] += vertex_dual_fixup
|
||||
|
||||
# Clean up tracking of edges from vertex "x" to S-vertices.
|
||||
# We maintain that tracking only for unlabeled vertices and
|
||||
# T-vertices.
|
||||
self.vertex_sedge_queue[x].clear()
|
||||
for e in self.graph.adjacent_edges[x]:
|
||||
# TODO -- Postpone this cleanup step to the scanning of
|
||||
# S-vertex incident edges.
|
||||
# It will be more efficient, and also simpler to
|
||||
# reason about final time complexity.
|
||||
self.vertex_sedge_node[e] = None
|
||||
self.vertex_set_node[x].set_prio(math.inf)
|
||||
|
||||
def assign_blossom_label_t(self, blossom: _Blossom) -> None:
|
||||
"""Assign label T to an unlabeled top-level blossom."""
|
||||
|
||||
|
@ -814,10 +826,72 @@ class _MatchingContext:
|
|||
# Unwind lazy updates of T-vertex dual variables.
|
||||
blossom.vertex_dual_offset += self.delta_sum_2x
|
||||
|
||||
def remove_vertex_label_s(self, x: int, bx: _Blossom) -> None:
|
||||
"""Adjust delta tracking for S-vertex losings its label.
|
||||
|
||||
This function is called when vertex "x" was an S-vertex but
|
||||
has just lost its label. This requires adjustments in the tracking
|
||||
of delta2 and delta3.
|
||||
|
||||
This function is takes time O(q * log(n)),
|
||||
where q is the number of incident edges of vertex "x".
|
||||
This function is called at most once per vertex per stage.
|
||||
This function therefore takes ammortized time O(m * log(n)) per stage.
|
||||
"""
|
||||
|
||||
# Scan the edges that are incident on "x".
|
||||
edges = self.graph.edges
|
||||
for e in self.graph.adjacent_edges[x]:
|
||||
(p, q, _w) = edges[e]
|
||||
y = p if p != x else q
|
||||
|
||||
# If this edge was in the delta3_queue, remove it since
|
||||
# this is no longer an edge between S-vertices.
|
||||
delta3_node = self.delta3_node[e]
|
||||
if delta3_node is not None:
|
||||
self.delta3_queue.delete(delta3_node)
|
||||
self.delta3_node[e] = None
|
||||
|
||||
by = self.vertex_set_node[y].find()
|
||||
if by.label == _LABEL_S:
|
||||
# This is an edge between "x" and an S-vertex.
|
||||
# Add this edge to "vertex_sedge_queue[x]".
|
||||
# Update delta2 tracking accordingly.
|
||||
self.lset_add_vertex_edge(x, bx, e)
|
||||
|
||||
else:
|
||||
# This is no longer an edge between "y" and an S-vertex.
|
||||
# Remove this edge from "vertex_sedge_queue[y]".
|
||||
# Update delta2 tracking accordingly.
|
||||
# TODO -- untangle this mess
|
||||
vertex_sedge_node = self.vertex_sedge_node[e]
|
||||
if vertex_sedge_node is not None:
|
||||
vertex_sedge_queue = self.vertex_sedge_queue[y]
|
||||
vertex_sedge_queue.delete(vertex_sedge_node)
|
||||
self.vertex_sedge_node[e] = None
|
||||
if vertex_sedge_queue.empty():
|
||||
prio = math.inf
|
||||
else:
|
||||
prio = vertex_sedge_queue.find_min().prio
|
||||
if prio > self.vertex_set_node[y].prio:
|
||||
self.vertex_set_node[y].set_prio(prio)
|
||||
if by.label == _LABEL_NONE:
|
||||
assert by.delta2_node is not None
|
||||
prio = by.vertex_set.min_prio()
|
||||
if prio < math.inf:
|
||||
prio += by.vertex_dual_offset
|
||||
if prio > by.delta2_node.prio:
|
||||
self.delta2_queue.increase_prio(
|
||||
by.delta2_node, prio)
|
||||
else:
|
||||
self.delta2_queue.delete(by.delta2_node)
|
||||
by.delta2_node = None
|
||||
|
||||
def reset_blossom_label(self, blossom: _Blossom) -> None:
|
||||
"""Remove blossom label and calculate true dual variables."""
|
||||
"""Remove blossom label."""
|
||||
|
||||
assert blossom.parent is None
|
||||
assert blossom.label != _LABEL_NONE
|
||||
|
||||
if blossom.label == _LABEL_S:
|
||||
|
||||
|
@ -834,41 +908,23 @@ class _MatchingContext:
|
|||
for x in blossom.vertices():
|
||||
self.vertex_dual_2x[x] += vertex_dual_fixup
|
||||
|
||||
# Adjust delta tracking for S-vertex losing its label.
|
||||
self.remove_vertex_label_s(x, blossom)
|
||||
|
||||
elif blossom.label == _LABEL_T:
|
||||
|
||||
# Remove label.
|
||||
blossom.label = _LABEL_NONE
|
||||
self.remove_blossom_label_t(blossom)
|
||||
|
||||
# Unwind lazy delta updates to T-blossom dual variable.
|
||||
if isinstance(blossom, _NonTrivialBlossom):
|
||||
blossom.dual_var -= self.delta_sum_2x
|
||||
# Since the blossom is now unlabeled, insert it in delta2_queue
|
||||
# if it has at least one edge to an S-vertex.
|
||||
assert blossom.delta2_node is None
|
||||
prio = blossom.vertex_set.min_prio()
|
||||
if prio < math.inf:
|
||||
prio += blossom.vertex_dual_offset
|
||||
blossom.delta2_node = self.delta2_queue.insert(prio, blossom)
|
||||
|
||||
# Unwind lazy delta updates to T-vertex dual variables.
|
||||
vertex_dual_fixup = self.delta_sum_2x + blossom.vertex_dual_offset
|
||||
blossom.vertex_dual_offset = 0
|
||||
for x in blossom.vertices():
|
||||
self.vertex_dual_2x[x] += vertex_dual_fixup
|
||||
|
||||
else:
|
||||
|
||||
# Unwind lazy delta updates to vertex dual variables.
|
||||
vertex_dual_fixup = blossom.vertex_dual_offset
|
||||
blossom.vertex_dual_offset = 0
|
||||
for x in blossom.vertices():
|
||||
self.vertex_dual_2x[x] += vertex_dual_fixup
|
||||
|
||||
def reset_stage(self) -> None:
|
||||
"""Reset data which are only valid during a stage.
|
||||
|
||||
Marks all blossoms as unlabeled, clears the queue,
|
||||
and resets tracking of least-slack edges.
|
||||
|
||||
This function takes time O(n * log(n)).
|
||||
"""
|
||||
|
||||
assert not self.scan_queue
|
||||
|
||||
# Check consistency of alternating tree.
|
||||
def _check_alternating_tree_consistency(self) -> None:
|
||||
"""TODO -- remove this function, only for debugging"""
|
||||
for blossom in self.trivial_blossom + self.nontrivial_blossom:
|
||||
if (blossom.parent is None) and (blossom.label != _LABEL_NONE):
|
||||
assert blossom.tree_blossoms is not None
|
||||
|
@ -879,25 +935,58 @@ class _MatchingContext:
|
|||
assert bx.tree_blossoms is blossom.tree_blossoms
|
||||
assert by.tree_blossoms is blossom.tree_blossoms
|
||||
else:
|
||||
assert blossom.tree_edge is None
|
||||
assert blossom.tree_blossoms is None
|
||||
|
||||
# Remove blossom labels and unwind lazy dual updates.
|
||||
def reset_stage(self) -> None:
|
||||
"""Reset data which are only valid during a stage.
|
||||
|
||||
Marks all blossoms as unlabeled, applies delayed delta updates,
|
||||
and resets tracking of least-slack edges.
|
||||
|
||||
This function takes time O((n + m) * log(n)).
|
||||
"""
|
||||
|
||||
assert not self.scan_queue
|
||||
|
||||
for blossom in self.trivial_blossom + self.nontrivial_blossom:
|
||||
if blossom.parent is None:
|
||||
|
||||
# Remove blossom label.
|
||||
if (blossom.parent is None) and (blossom.label != _LABEL_NONE):
|
||||
self.reset_blossom_label(blossom)
|
||||
if isinstance(blossom, _NonTrivialBlossom):
|
||||
blossom.delta4_node = None
|
||||
assert blossom.label == _LABEL_NONE
|
||||
|
||||
# Remove blossom from alternating tree.
|
||||
blossom.tree_edge = None
|
||||
blossom.tree_blossoms = None
|
||||
|
||||
# Reset least-slack edge tracking.
|
||||
self.lset_reset()
|
||||
# Unwind lazy delta updates to vertex dual variables.
|
||||
if blossom.vertex_dual_offset != 0:
|
||||
for x in blossom.vertices():
|
||||
self.vertex_dual_2x[x] += blossom.vertex_dual_offset
|
||||
blossom.vertex_dual_offset = 0
|
||||
|
||||
# Reset delta queues.
|
||||
self.delta3_queue.clear()
|
||||
self.delta3_set.clear()
|
||||
self.delta4_queue.clear()
|
||||
assert self.delta2_queue.empty()
|
||||
assert self.delta3_queue.empty()
|
||||
assert self.delta4_queue.empty()
|
||||
|
||||
def remove_alternating_tree(
|
||||
self,
|
||||
tree_blossoms: set[_Blossom]
|
||||
) -> None:
|
||||
"""Reset the alternating tree consisting of the specified blossoms.
|
||||
|
||||
Marks the blossoms as unlabeled.
|
||||
Updates delta tracking accordingly.
|
||||
|
||||
This function takes time O((n+m) * log(n)).
|
||||
"""
|
||||
for blossom in tree_blossoms:
|
||||
assert blossom.label != _LABEL_NONE
|
||||
assert blossom.tree_blossoms is tree_blossoms
|
||||
self.reset_blossom_label(blossom)
|
||||
blossom.tree_edge = None
|
||||
blossom.tree_blossoms = None
|
||||
|
||||
def trace_alternating_paths(self, x: int, y: int) -> _AlternatingPath:
|
||||
"""Trace back through the alternating trees from vertices "x" and "y".
|
||||
|
@ -1014,22 +1103,21 @@ class _MatchingContext:
|
|||
|
||||
# Remove blossom labels.
|
||||
# Mark vertices inside former T-blossoms as S-vertices.
|
||||
tree_blossoms = subblossoms[0].tree_blossoms
|
||||
assert tree_blossoms is not None
|
||||
for sub in subblossoms:
|
||||
if sub.label == _LABEL_S:
|
||||
self.remove_blossom_label_s(sub)
|
||||
elif sub.label == _LABEL_T:
|
||||
self.remove_blossom_label_t(sub)
|
||||
self.assign_vertex_label_s(sub)
|
||||
sub.tree_blossoms = None
|
||||
tree_blossoms.remove(sub)
|
||||
|
||||
# Create the new blossom object.
|
||||
blossom = _NonTrivialBlossom(subblossoms, path.edges)
|
||||
|
||||
# Assign label S to the new blossom and link it to the tree.
|
||||
self.assign_blossom_label_s(blossom)
|
||||
|
||||
tree_blossoms = subblossoms[0].tree_blossoms
|
||||
assert tree_blossoms is not None
|
||||
blossom.tree_edge = subblossoms[0].tree_edge
|
||||
blossom.tree_blossoms = tree_blossoms
|
||||
tree_blossoms.add(blossom)
|
||||
|
@ -1041,6 +1129,11 @@ class _MatchingContext:
|
|||
for sub in subblossoms:
|
||||
sub.parent = blossom
|
||||
|
||||
# Remove subblossom from the alternating tree.
|
||||
sub.tree_edge = None
|
||||
sub.tree_blossoms = None
|
||||
tree_blossoms.remove(sub)
|
||||
|
||||
# Merge union-find structures.
|
||||
blossom.vertex_set.merge([sub.vertex_set for sub in subblossoms])
|
||||
|
||||
|
@ -1496,7 +1589,7 @@ class _MatchingContext:
|
|||
# Update tracking of least-slack edges between S-blossoms.
|
||||
# Priority is edge slack plus 2 times the running sum of
|
||||
# delta steps.
|
||||
if e not in self.delta3_set:
|
||||
if self.delta3_node[e] is None:
|
||||
prio_2x = self.edge_pseudo_slack_2x(e)
|
||||
if self.graph.integer_weights:
|
||||
# If all edge weights are integers, the slack of
|
||||
|
@ -1505,8 +1598,7 @@ class _MatchingContext:
|
|||
prio = prio_2x // 2
|
||||
else:
|
||||
prio = prio_2x / 2
|
||||
self.delta3_set.add(e)
|
||||
self.delta3_queue.insert(prio, e)
|
||||
self.delta3_node[e] = self.delta3_queue.insert(prio, e)
|
||||
else:
|
||||
# Update tracking of least-slack edges from vertex "y" to
|
||||
# any S-vertex. We do this for T-vertices and unlabeled
|
||||
|
@ -1584,7 +1676,7 @@ class _MatchingContext:
|
|||
# existing edges in the queue may become intra-blossom when
|
||||
# a new blossom is formed.
|
||||
self.delta3_queue.delete(delta3_node)
|
||||
self.delta3_set.remove(e)
|
||||
self.delta3_node[e] = None
|
||||
|
||||
# Compute delta4: half minimum dual variable of a top-level T-blossom.
|
||||
# This takes time O(log(n)).
|
||||
|
@ -1601,9 +1693,30 @@ class _MatchingContext:
|
|||
return (delta_type, delta_2x, delta_edge, delta_blossom)
|
||||
|
||||
#
|
||||
# Main stage function:
|
||||
# Main algorithm:
|
||||
#
|
||||
|
||||
def start(self) -> None:
|
||||
"""Mark each vertex as the node of an alternating tree.
|
||||
|
||||
Assign label S to all vertices and add them to the scan queue.
|
||||
|
||||
This function takes time O(n).
|
||||
It is called once, at the beginning of the algorithm.
|
||||
"""
|
||||
for x in range(self.graph.num_vertex):
|
||||
assert self.vertex_mate[x] == -1
|
||||
bx = self.vertex_set_node[x].find()
|
||||
assert bx.base_vertex == x
|
||||
|
||||
# Assign label S.
|
||||
self.assign_blossom_label_s(bx)
|
||||
self.assign_vertex_label_s(bx)
|
||||
|
||||
# Mark blossom as the root of an alternating tree.
|
||||
bx.tree_edge = None
|
||||
bx.tree_blossoms = {bx}
|
||||
|
||||
def run_stage(self) -> bool:
|
||||
"""Run one stage of the matching algorithm.
|
||||
|
||||
|
@ -1619,17 +1732,10 @@ class _MatchingContext:
|
|||
False if no further improvement is possible.
|
||||
"""
|
||||
|
||||
num_vertex = self.graph.num_vertex
|
||||
|
||||
# Assign label S to all unmatched vertices and put them in the queue.
|
||||
for x in range(num_vertex):
|
||||
if self.vertex_mate[x] == -1:
|
||||
self.extend_tree_s(x)
|
||||
|
||||
# Stop if all vertices are matched.
|
||||
# No further improvement is possible in that case.
|
||||
# This avoids messy calculations of delta steps without any S-vertex.
|
||||
if not self.scan_queue:
|
||||
if all(y >= 0 for y in self.vertex_mate):
|
||||
return False
|
||||
|
||||
# Each pass through the following loop is a "substage".
|
||||
|
@ -1639,9 +1745,10 @@ class _MatchingContext:
|
|||
# next substage, or stop if no further improvement is possible.
|
||||
#
|
||||
# This loop runs through at most O(n) iterations per stage.
|
||||
augmenting_path = None
|
||||
while True:
|
||||
|
||||
# self._check_alternating_tree_consistency() # TODO -- remove this
|
||||
|
||||
# Consider the incident edges of newly labeled S-vertices.
|
||||
self.substage_scan()
|
||||
|
||||
|
@ -1665,11 +1772,22 @@ class _MatchingContext:
|
|||
|
||||
elif delta_type == 3:
|
||||
# Use the S-to-S edge that got unlocked by the delta update.
|
||||
# This may reveal an augmenting path.
|
||||
# This reveals either a new blossom or an augmenting path.
|
||||
(x, y, _w) = self.graph.edges[delta_edge]
|
||||
augmenting_path = self.add_s_to_s_edge(x, y)
|
||||
if augmenting_path is not None:
|
||||
break
|
||||
# Found augmenting path.
|
||||
# Delete the two alternating trees on the augmenting path.
|
||||
bx = self.vertex_set_node[x].find()
|
||||
by = self.vertex_set_node[y].find()
|
||||
assert bx.tree_blossoms is not None
|
||||
assert by.tree_blossoms is not None
|
||||
self.remove_alternating_tree(bx.tree_blossoms)
|
||||
self.remove_alternating_tree(by.tree_blossoms)
|
||||
# Augment the matching.
|
||||
self.augment_matching(augmenting_path)
|
||||
# End the stage.
|
||||
return True
|
||||
|
||||
elif delta_type == 4:
|
||||
# Expand the T-blossom that reached dual value 0 through
|
||||
|
@ -1678,20 +1796,22 @@ class _MatchingContext:
|
|||
self.expand_t_blossom(delta_blossom)
|
||||
|
||||
else:
|
||||
# No further improvement possible. End the stage.
|
||||
# No further improvement possible. End the algorithm.
|
||||
assert delta_type == 1
|
||||
break
|
||||
return False
|
||||
|
||||
# Augment the matching if an augmenting path was found.
|
||||
if augmenting_path is not None:
|
||||
self.augment_matching(augmenting_path)
|
||||
def cleanup(self) -> None:
|
||||
"""Remove all alternating trees and mark all blossoms as unlabeled.
|
||||
|
||||
# Remove all labels, clear queue.
|
||||
Also applies delayed updates to dual variables.
|
||||
Also resets tracking of least-slack edges.
|
||||
|
||||
This function takes time O(n * log(n)).
|
||||
It is called only once, at the end of the algorithm.
|
||||
"""
|
||||
# TODO -- move that function in here
|
||||
self.reset_stage()
|
||||
|
||||
# Return True if the matching was augmented.
|
||||
return (augmenting_path is not None)
|
||||
|
||||
|
||||
def _verify_blossom_edges(
|
||||
ctx: _MatchingContext,
|
||||
|
|
Loading…
Reference in New Issue