diff --git a/python/mwmatching.py b/python/mwmatching.py index 0116dc8..1e7d9c0 100644 --- a/python/mwmatching.py +++ b/python/mwmatching.py @@ -69,6 +69,7 @@ def maximum_weight_matching( # Initialize the matching algorithm. ctx = _MatchingContext(graph) + ctx.start() # Improve the solution until no further improvement is possible. # @@ -81,6 +82,7 @@ def maximum_weight_matching( pass # Extract the final solution. + ctx.cleanup() pairs: list[tuple[int, int]] = [ (x, y) for (x, y, _w) in edges if ctx.vertex_mate[x] == y] @@ -592,7 +594,8 @@ class _MatchingContext: # blossoms. The priority of an edge is its slack plus 2 times the # running sum of delta steps. self.delta3_queue: PriorityQueue[int] = PriorityQueue() - self.delta3_set: set[int] = set() + self.delta3_node: list[Optional[PriorityQueue.Node]] + self.delta3_node = [None for _e in graph.edges] # Queue containing top-level non-trivial T-blossoms. # The priority of a blossom is its dual plus 2 times the running @@ -600,9 +603,12 @@ class _MatchingContext: self.delta4_queue: PriorityQueue[_NonTrivialBlossom] = PriorityQueue() # For each T-vertex or unlabeled vertex "x", - # "vertex_best_edge[x]" is the edge index of the least-slack edge - # between "x" and any S-vertex, or -1 if no such edge has been found. - self.vertex_best_edge: list[int] = num_vertex * [-1] + # "vertex_sedge_queue[x]" is a queue of edges between "x" and any + # S-vertex. The priority of an edge is 2 times its pseudo-slack. + self.vertex_sedge_queue: list[PriorityQueue] + self.vertex_sedge_queue = [PriorityQueue() for _x in range(num_vertex)] + self.vertex_sedge_node: list[Optional[PriorityQueue.Node]] + self.vertex_sedge_node = [None for _e in graph.edges] # Queue of S-vertices to be scanned. self.scan_queue: list[int] = [] @@ -659,22 +665,7 @@ class _MatchingContext: # after the delta step. # - def lset_reset(self) -> None: - """Reset least-slack edge tracking. - - This function takes time O(n * log(n)). - """ - num_vertex = self.graph.num_vertex - - for x in range(num_vertex): - self.vertex_best_edge[x] = -1 - self.vertex_set_node[x].set_prio(math.inf) - - self.delta2_queue.clear() - - for blossom in self.trivial_blossom + self.nontrivial_blossom: - blossom.delta2_node = None - + # TODO -- rename function, maybe refactor def lset_add_vertex_edge(self, y: int, by: _Blossom, e: int) -> None: """Add edge "e" from an S-vertex to unlabeled vertex or T-vertex "y". @@ -682,13 +673,14 @@ class _MatchingContext: """ prio = self.edge_pseudo_slack_2x(e) - best_edge = self.vertex_best_edge[y] - if best_edge != -1: - best_prio = self.edge_pseudo_slack_2x(best_edge) - if prio >= best_prio: - return + improved = (self.vertex_sedge_queue[y].empty() + or (self.vertex_sedge_queue[y].find_min().prio > prio)) - self.vertex_best_edge[y] = e + assert self.vertex_sedge_node[e] is None + self.vertex_sedge_node[e] = self.vertex_sedge_queue[y].insert(prio, e) + + if not improved: + return prev_min = by.vertex_set.min_prio() self.vertex_set_node[y].set_prio(prio) @@ -700,6 +692,7 @@ class _MatchingContext: elif prio < by.delta2_node.prio: self.delta2_queue.decrease_prio(by.delta2_node, prio) + # TODO -- rename function, maybe refactor def lset_get_best_vertex_edge(self) -> tuple[int, float]: """Return the index and slack of the least-slack edge between any S-vertex and unlabeled vertex. @@ -722,8 +715,7 @@ class _MatchingContext: assert blossom.label == _LABEL_NONE x = blossom.vertex_set.min_elem() - e = self.vertex_best_edge[x] - assert e >= 0 + e = self.vertex_sedge_queue[x].find_min().data return (e, slack_2x) @@ -731,6 +723,14 @@ class _MatchingContext: # General support routines: # + # TODO -- Although this code is correct, there is a conceptual problem + # that the division of responsibilities is asymmetric. + # For example, assign_blossom_label_s is not the exact + # opposite of remove_blossom_label_s, and remove_blossom_label_s + # has different responsibilities from remove_blossom_label_t. + # This must be fixed by shifting responsibilities or renaming + # functions, otherwise this stuff is impossible to understand. + def assign_blossom_label_s(self, blossom: _Blossom) -> None: """Assign label S to an unlabeled top-level blossom.""" assert blossom.parent is None @@ -769,6 +769,18 @@ class _MatchingContext: for x in vertices: self.vertex_dual_2x[x] += vertex_dual_fixup + # Clean up tracking of edges from vertex "x" to S-vertices. + # We maintain that tracking only for unlabeled vertices and + # T-vertices. + self.vertex_sedge_queue[x].clear() + for e in self.graph.adjacent_edges[x]: + # TODO -- Postpone this cleanup step to the scanning of + # S-vertex incident edges. + # It will be more efficient, and also simpler to + # reason about final time complexity. + self.vertex_sedge_node[e] = None + self.vertex_set_node[x].set_prio(math.inf) + def assign_blossom_label_t(self, blossom: _Blossom) -> None: """Assign label T to an unlabeled top-level blossom.""" @@ -814,10 +826,72 @@ class _MatchingContext: # Unwind lazy updates of T-vertex dual variables. blossom.vertex_dual_offset += self.delta_sum_2x + def remove_vertex_label_s(self, x: int, bx: _Blossom) -> None: + """Adjust delta tracking for S-vertex losings its label. + + This function is called when vertex "x" was an S-vertex but + has just lost its label. This requires adjustments in the tracking + of delta2 and delta3. + + This function is takes time O(q * log(n)), + where q is the number of incident edges of vertex "x". + This function is called at most once per vertex per stage. + This function therefore takes ammortized time O(m * log(n)) per stage. + """ + + # Scan the edges that are incident on "x". + edges = self.graph.edges + for e in self.graph.adjacent_edges[x]: + (p, q, _w) = edges[e] + y = p if p != x else q + + # If this edge was in the delta3_queue, remove it since + # this is no longer an edge between S-vertices. + delta3_node = self.delta3_node[e] + if delta3_node is not None: + self.delta3_queue.delete(delta3_node) + self.delta3_node[e] = None + + by = self.vertex_set_node[y].find() + if by.label == _LABEL_S: + # This is an edge between "x" and an S-vertex. + # Add this edge to "vertex_sedge_queue[x]". + # Update delta2 tracking accordingly. + self.lset_add_vertex_edge(x, bx, e) + + else: + # This is no longer an edge between "y" and an S-vertex. + # Remove this edge from "vertex_sedge_queue[y]". + # Update delta2 tracking accordingly. + # TODO -- untangle this mess + vertex_sedge_node = self.vertex_sedge_node[e] + if vertex_sedge_node is not None: + vertex_sedge_queue = self.vertex_sedge_queue[y] + vertex_sedge_queue.delete(vertex_sedge_node) + self.vertex_sedge_node[e] = None + if vertex_sedge_queue.empty(): + prio = math.inf + else: + prio = vertex_sedge_queue.find_min().prio + if prio > self.vertex_set_node[y].prio: + self.vertex_set_node[y].set_prio(prio) + if by.label == _LABEL_NONE: + assert by.delta2_node is not None + prio = by.vertex_set.min_prio() + if prio < math.inf: + prio += by.vertex_dual_offset + if prio > by.delta2_node.prio: + self.delta2_queue.increase_prio( + by.delta2_node, prio) + else: + self.delta2_queue.delete(by.delta2_node) + by.delta2_node = None + def reset_blossom_label(self, blossom: _Blossom) -> None: - """Remove blossom label and calculate true dual variables.""" + """Remove blossom label.""" assert blossom.parent is None + assert blossom.label != _LABEL_NONE if blossom.label == _LABEL_S: @@ -834,41 +908,23 @@ class _MatchingContext: for x in blossom.vertices(): self.vertex_dual_2x[x] += vertex_dual_fixup + # Adjust delta tracking for S-vertex losing its label. + self.remove_vertex_label_s(x, blossom) + elif blossom.label == _LABEL_T: - # Remove label. - blossom.label = _LABEL_NONE + self.remove_blossom_label_t(blossom) - # Unwind lazy delta updates to T-blossom dual variable. - if isinstance(blossom, _NonTrivialBlossom): - blossom.dual_var -= self.delta_sum_2x + # Since the blossom is now unlabeled, insert it in delta2_queue + # if it has at least one edge to an S-vertex. + assert blossom.delta2_node is None + prio = blossom.vertex_set.min_prio() + if prio < math.inf: + prio += blossom.vertex_dual_offset + blossom.delta2_node = self.delta2_queue.insert(prio, blossom) - # Unwind lazy delta updates to T-vertex dual variables. - vertex_dual_fixup = self.delta_sum_2x + blossom.vertex_dual_offset - blossom.vertex_dual_offset = 0 - for x in blossom.vertices(): - self.vertex_dual_2x[x] += vertex_dual_fixup - - else: - - # Unwind lazy delta updates to vertex dual variables. - vertex_dual_fixup = blossom.vertex_dual_offset - blossom.vertex_dual_offset = 0 - for x in blossom.vertices(): - self.vertex_dual_2x[x] += vertex_dual_fixup - - def reset_stage(self) -> None: - """Reset data which are only valid during a stage. - - Marks all blossoms as unlabeled, clears the queue, - and resets tracking of least-slack edges. - - This function takes time O(n * log(n)). - """ - - assert not self.scan_queue - - # Check consistency of alternating tree. + def _check_alternating_tree_consistency(self) -> None: + """TODO -- remove this function, only for debugging""" for blossom in self.trivial_blossom + self.nontrivial_blossom: if (blossom.parent is None) and (blossom.label != _LABEL_NONE): assert blossom.tree_blossoms is not None @@ -879,25 +935,58 @@ class _MatchingContext: assert bx.tree_blossoms is blossom.tree_blossoms assert by.tree_blossoms is blossom.tree_blossoms else: + assert blossom.tree_edge is None assert blossom.tree_blossoms is None - # Remove blossom labels and unwind lazy dual updates. + def reset_stage(self) -> None: + """Reset data which are only valid during a stage. + + Marks all blossoms as unlabeled, applies delayed delta updates, + and resets tracking of least-slack edges. + + This function takes time O((n + m) * log(n)). + """ + + assert not self.scan_queue + for blossom in self.trivial_blossom + self.nontrivial_blossom: - if blossom.parent is None: + + # Remove blossom label. + if (blossom.parent is None) and (blossom.label != _LABEL_NONE): self.reset_blossom_label(blossom) - if isinstance(blossom, _NonTrivialBlossom): - blossom.delta4_node = None assert blossom.label == _LABEL_NONE + + # Remove blossom from alternating tree. blossom.tree_edge = None blossom.tree_blossoms = None - # Reset least-slack edge tracking. - self.lset_reset() + # Unwind lazy delta updates to vertex dual variables. + if blossom.vertex_dual_offset != 0: + for x in blossom.vertices(): + self.vertex_dual_2x[x] += blossom.vertex_dual_offset + blossom.vertex_dual_offset = 0 - # Reset delta queues. - self.delta3_queue.clear() - self.delta3_set.clear() - self.delta4_queue.clear() + assert self.delta2_queue.empty() + assert self.delta3_queue.empty() + assert self.delta4_queue.empty() + + def remove_alternating_tree( + self, + tree_blossoms: set[_Blossom] + ) -> None: + """Reset the alternating tree consisting of the specified blossoms. + + Marks the blossoms as unlabeled. + Updates delta tracking accordingly. + + This function takes time O((n+m) * log(n)). + """ + for blossom in tree_blossoms: + assert blossom.label != _LABEL_NONE + assert blossom.tree_blossoms is tree_blossoms + self.reset_blossom_label(blossom) + blossom.tree_edge = None + blossom.tree_blossoms = None def trace_alternating_paths(self, x: int, y: int) -> _AlternatingPath: """Trace back through the alternating trees from vertices "x" and "y". @@ -1014,22 +1103,21 @@ class _MatchingContext: # Remove blossom labels. # Mark vertices inside former T-blossoms as S-vertices. - tree_blossoms = subblossoms[0].tree_blossoms - assert tree_blossoms is not None for sub in subblossoms: if sub.label == _LABEL_S: self.remove_blossom_label_s(sub) elif sub.label == _LABEL_T: self.remove_blossom_label_t(sub) self.assign_vertex_label_s(sub) - sub.tree_blossoms = None - tree_blossoms.remove(sub) # Create the new blossom object. blossom = _NonTrivialBlossom(subblossoms, path.edges) # Assign label S to the new blossom and link it to the tree. self.assign_blossom_label_s(blossom) + + tree_blossoms = subblossoms[0].tree_blossoms + assert tree_blossoms is not None blossom.tree_edge = subblossoms[0].tree_edge blossom.tree_blossoms = tree_blossoms tree_blossoms.add(blossom) @@ -1041,6 +1129,11 @@ class _MatchingContext: for sub in subblossoms: sub.parent = blossom + # Remove subblossom from the alternating tree. + sub.tree_edge = None + sub.tree_blossoms = None + tree_blossoms.remove(sub) + # Merge union-find structures. blossom.vertex_set.merge([sub.vertex_set for sub in subblossoms]) @@ -1496,7 +1589,7 @@ class _MatchingContext: # Update tracking of least-slack edges between S-blossoms. # Priority is edge slack plus 2 times the running sum of # delta steps. - if e not in self.delta3_set: + if self.delta3_node[e] is None: prio_2x = self.edge_pseudo_slack_2x(e) if self.graph.integer_weights: # If all edge weights are integers, the slack of @@ -1505,8 +1598,7 @@ class _MatchingContext: prio = prio_2x // 2 else: prio = prio_2x / 2 - self.delta3_set.add(e) - self.delta3_queue.insert(prio, e) + self.delta3_node[e] = self.delta3_queue.insert(prio, e) else: # Update tracking of least-slack edges from vertex "y" to # any S-vertex. We do this for T-vertices and unlabeled @@ -1584,7 +1676,7 @@ class _MatchingContext: # existing edges in the queue may become intra-blossom when # a new blossom is formed. self.delta3_queue.delete(delta3_node) - self.delta3_set.remove(e) + self.delta3_node[e] = None # Compute delta4: half minimum dual variable of a top-level T-blossom. # This takes time O(log(n)). @@ -1601,9 +1693,30 @@ class _MatchingContext: return (delta_type, delta_2x, delta_edge, delta_blossom) # - # Main stage function: + # Main algorithm: # + def start(self) -> None: + """Mark each vertex as the node of an alternating tree. + + Assign label S to all vertices and add them to the scan queue. + + This function takes time O(n). + It is called once, at the beginning of the algorithm. + """ + for x in range(self.graph.num_vertex): + assert self.vertex_mate[x] == -1 + bx = self.vertex_set_node[x].find() + assert bx.base_vertex == x + + # Assign label S. + self.assign_blossom_label_s(bx) + self.assign_vertex_label_s(bx) + + # Mark blossom as the root of an alternating tree. + bx.tree_edge = None + bx.tree_blossoms = {bx} + def run_stage(self) -> bool: """Run one stage of the matching algorithm. @@ -1619,17 +1732,10 @@ class _MatchingContext: False if no further improvement is possible. """ - num_vertex = self.graph.num_vertex - - # Assign label S to all unmatched vertices and put them in the queue. - for x in range(num_vertex): - if self.vertex_mate[x] == -1: - self.extend_tree_s(x) - # Stop if all vertices are matched. # No further improvement is possible in that case. # This avoids messy calculations of delta steps without any S-vertex. - if not self.scan_queue: + if all(y >= 0 for y in self.vertex_mate): return False # Each pass through the following loop is a "substage". @@ -1639,9 +1745,10 @@ class _MatchingContext: # next substage, or stop if no further improvement is possible. # # This loop runs through at most O(n) iterations per stage. - augmenting_path = None while True: + # self._check_alternating_tree_consistency() # TODO -- remove this + # Consider the incident edges of newly labeled S-vertices. self.substage_scan() @@ -1665,11 +1772,22 @@ class _MatchingContext: elif delta_type == 3: # Use the S-to-S edge that got unlocked by the delta update. - # This may reveal an augmenting path. + # This reveals either a new blossom or an augmenting path. (x, y, _w) = self.graph.edges[delta_edge] augmenting_path = self.add_s_to_s_edge(x, y) if augmenting_path is not None: - break + # Found augmenting path. + # Delete the two alternating trees on the augmenting path. + bx = self.vertex_set_node[x].find() + by = self.vertex_set_node[y].find() + assert bx.tree_blossoms is not None + assert by.tree_blossoms is not None + self.remove_alternating_tree(bx.tree_blossoms) + self.remove_alternating_tree(by.tree_blossoms) + # Augment the matching. + self.augment_matching(augmenting_path) + # End the stage. + return True elif delta_type == 4: # Expand the T-blossom that reached dual value 0 through @@ -1678,20 +1796,22 @@ class _MatchingContext: self.expand_t_blossom(delta_blossom) else: - # No further improvement possible. End the stage. + # No further improvement possible. End the algorithm. assert delta_type == 1 - break + return False - # Augment the matching if an augmenting path was found. - if augmenting_path is not None: - self.augment_matching(augmenting_path) + def cleanup(self) -> None: + """Remove all alternating trees and mark all blossoms as unlabeled. - # Remove all labels, clear queue. + Also applies delayed updates to dual variables. + Also resets tracking of least-slack edges. + + This function takes time O(n * log(n)). + It is called only once, at the end of the algorithm. + """ + # TODO -- move that function in here self.reset_stage() - # Return True if the matching was augmented. - return (augmenting_path is not None) - def _verify_blossom_edges( ctx: _MatchingContext,