diff --git a/python/mwmatching.py b/python/mwmatching.py index e16773b..73d5077 100644 --- a/python/mwmatching.py +++ b/python/mwmatching.py @@ -443,23 +443,26 @@ class _NonTrivialBlossom(_Blossom): self.edges: list[tuple[int, int]] = edges # Every non-trivial blossom has a variable in the dual LPP. + # New blossoms start with dual variable 0. # - # If this is a top-level S-blossom, - # "dual_var" is the value of the dual variable minus 2 times - # the running sum of delta steps. + # The value of the dual variable changes through delta steps, + # but these changes are implemented as lazy updates. # - # If this is a top-level T-blossom, - # "dual_var" is the value of the dual variable plus 2 times - # the running sum of delta steps. + # The true dual value of a top-level S-blossom is + # blossom.dual_var + ctx.delta_sum_2x # - # In all other cases, - # "dual_var" is the current value of the dual variable. + # The true dual value of a top-level T-blossom is + # blossom.dual_var - ctx.delta_sum_2x + # + # The true dual value of any other type of blossom is simply + # blossom.dual_var # # Note that "dual_var" is invariant under delta steps. - # - # New blossoms start with dual variable 0. self.dual_var: float = 0 + # Support variable for lazy updating of vertex dual variables. + self.vertex_dual_offset: float = 0 + # If this is a top-level T-blossom, # "delta4_node" is the corresponding node in the delta4 queue. # Otherwise "delta4_node" is None. @@ -540,23 +543,29 @@ class _MatchingContext: # Initially all vertices are trivial top-level blossoms. self.vertex_top_blossom: list[_Blossom] = self.trivial_blossom.copy() - # Initial dual value of all vertices times 2. - # Multiplication by 2 ensures that the values are integers - # if all edge weights are integers. + # All vertex duals are initialized to half the maximum edge weight. # - # Vertex duals are initialized to half the maximum edge weight. + # "start_vertex_dual_2x" is 2 times the initial vertex dual value. + # + # Pre-multiplication by 2 ensures that the values are integers + # if all edge weights are integers. self.start_vertex_dual_2x = max(w for (_x, _y, w) in graph.edges) # Every vertex has a variable in the dual LPP. # - # For an unlabeled vertex "x", - # "vertex_dual_2x[x]" is 2 times the dual variable of vertex "x". + # The value of the dual variable changes through delta steps, + # but these changes are implemented as lazy updates. # - # For an S-vertex "x", - # "vertex_dual_2x[x]" is 2 times the dual variable of vertex "x" - # plus two times the running sum of delta steps. + # The true dual value of an S-vertex is + # (vertex_dual_2x[x] - delta_sum_2x) / 2 # - # For a T-vertex "x", ... TODO + # The true dual value of a T-vertex is + # (vertex_dual_2x[x] + delta_sum_2x + B(x).vertex_dual_offset) / 2 + # + # The true dual value of an unlabeled vertex is + # (vertex_dual_2x[x] + B(x).vertex_dual_offset) / 2 + # + # Note that "vertex_dual_2x" is invariant under delta steps. self.vertex_dual_2x: list[float] self.vertex_dual_2x = num_vertex * [self.start_vertex_dual_2x] @@ -606,6 +615,14 @@ class _MatchingContext: dual_2x -= self.delta_sum_2x if by.label == _LABEL_S: dual_2x -= self.delta_sum_2x + if bx.label == _LABEL_T: + dual_2x += self.delta_sum_2x + if by.label == _LABEL_T: + dual_2x += self.delta_sum_2x + if isinstance(bx, _NonTrivialBlossom): + dual_2x += bx.vertex_dual_offset + if isinstance(by, _NonTrivialBlossom): + dual_2x += by.vertex_dual_offset return dual_2x - 2 * w # @@ -711,20 +728,18 @@ class _MatchingContext: def assign_vertex_label_s(self, blossom: _Blossom) -> None: """Adjust after assigning label S to previously unlabeled vertices.""" - # Prepare for lazy updating of S-vertex dual variables. - vertices = blossom.vertices() - for x in vertices: - self.vertex_dual_2x[x] += self.delta_sum_2x - # Add the new S-vertices to the scan queue. + vertices = blossom.vertices() self.scan_queue.extend(vertices) - def remove_vertex_label_s(self, blossom: _Blossom) -> None: - """Adjust after removing labels from S-vertices.""" + # Prepare for lazy updating of S-vertex dual variables. + vertex_dual_fixup = self.delta_sum_2x + if isinstance(blossom, _NonTrivialBlossom): + vertex_dual_fixup += blossom.vertex_dual_offset + blossom.vertex_dual_offset = 0 - # Catch up with lazy updates to S-vertex dual variables. - for x in blossom.vertices(): - self.vertex_dual_2x[x] -= self.delta_sum_2x + for x in vertices: + self.vertex_dual_2x[x] += vertex_dual_fixup def assign_blossom_label_t(self, blossom: _Blossom) -> None: """Assign label T to an unlabeled top-level blossom.""" @@ -735,14 +750,22 @@ class _MatchingContext: if isinstance(blossom, _NonTrivialBlossom): - # Prepare for lazy updates to T-blossom dual variables. + # Prepare for lazy updating of T-blossom dual variables. blossom.dual_var += self.delta_sum_2x + # Prepare for lazy updating of T-vertex dual variables. + blossom.vertex_dual_offset -= self.delta_sum_2x + # Insert blossom into the delta4 queue. assert blossom.delta4_node is None blossom.delta4_node = self.delta4_queue.insert(blossom.dual_var, blossom) + else: + # Prepare for lazy updating of T-vertex dual variables. + self.vertex_dual_2x[blossom.base_vertex] -= self.delta_sum_2x + + def remove_blossom_label_t(self, blossom: _Blossom) -> None: """Remove label T from a top-level T-blossom.""" @@ -757,9 +780,66 @@ class _MatchingContext: self.delta4_queue.delete(blossom.delta4_node) blossom.delta4_node = None - # Catch up with lazy updates to T-blossom dual variable. + # Unwind lazy updates to T-blossom dual variable. blossom.dual_var -= self.delta_sum_2x + # Unwind lazy updates of T-vertex dual variables. + blossom.vertex_dual_offset += self.delta_sum_2x + + else: + # Unwind lazy updates of T-vertex dual variables. + self.vertex_dual_2x[blossom.base_vertex] += self.delta_sum_2x + + def reset_blossom_label(self, blossom: _Blossom) -> None: + """Remove blossom label and calculate true dual variables.""" + + assert blossom.parent is None + + if blossom.label == _LABEL_S: + + # Remove label. + blossom.label = _LABEL_NONE + + # Unwind lazy delta updates to S-blossom dual variable. + if isinstance(blossom, _NonTrivialBlossom): + blossom.dual_var += self.delta_sum_2x + assert blossom.vertex_dual_offset == 0 + + # Unwind lazy delta updates to S-vertex dual variables. + vertex_dual_fixup = -self.delta_sum_2x + for x in blossom.vertices(): + self.vertex_dual_2x[x] += vertex_dual_fixup + + elif blossom.label == _LABEL_T: + + # Remove label. + blossom.label = _LABEL_NONE + + # Prepare to unwind lazy delta updates to vertices. + vertex_dual_fixup = self.delta_sum_2x + + if isinstance(blossom, _NonTrivialBlossom): + # Unwind lazy delta updates to T-blossom dual variable. + blossom.dual_var -= self.delta_sum_2x + + # Prepare to unwind lazy delta updates to vertices. + vertex_dual_fixup += blossom.vertex_dual_offset + blossom.vertex_dual_offset = 0 + + # Unwind lazy delta updates to T-vertex dual variables. + for x in blossom.vertices(): + self.vertex_dual_2x[x] += vertex_dual_fixup + + else: + + # Unwind lazy delta updates to vertex dual variables. + if isinstance(blossom, _NonTrivialBlossom): + vertex_dual_fixup = blossom.vertex_dual_offset + blossom.vertex_dual_offset = 0 + + for x in blossom.vertices(): + self.vertex_dual_2x[x] += vertex_dual_fixup + def reset_stage(self) -> None: """Reset data which are only valid during a stage. @@ -769,13 +849,13 @@ class _MatchingContext: This function takes time O(n). """ - # Remove blossom labels. + # Remove blossom labels and unwind lazy dual updates. for blossom in self.trivial_blossom + self.nontrivial_blossom: - if blossom.label == _LABEL_S: - self.remove_blossom_label_s(blossom) - self.remove_vertex_label_s(blossom) - elif blossom.label == _LABEL_T: - self.remove_blossom_label_t(blossom) + if blossom.parent is None: + self.reset_blossom_label(blossom) + if isinstance(blossom, _NonTrivialBlossom): + blossom.delta4_node = None + assert blossom.label == _LABEL_NONE blossom.tree_edge = None # Clear the scan queue. @@ -787,7 +867,7 @@ class _MatchingContext: # Reset delta queues. self.delta3_queue.clear() self.delta3_set.clear() - assert self.delta4_queue.empty() + self.delta4_queue.clear() def trace_alternating_paths(self, x: int, y: int) -> _AlternatingPath: """Trace back through the alternating trees from vertices "x" and "y". @@ -968,15 +1048,22 @@ class _MatchingContext: self.delta4_queue.delete(blossom.delta4_node) blossom.delta4_node = None + # Prepare to push lazy delta updates down to the sub-blossoms. + vertex_dual_fixup = self.delta_sum_2x + blossom.vertex_dual_offset + blossom.vertex_dual_offset = 0 + # Convert sub-blossoms into top-level blossoms. for sub in blossom.subblossoms: assert sub.label == _LABEL_NONE sub.parent = None if isinstance(sub, _NonTrivialBlossom): + sub.vertex_dual_offset = vertex_dual_fixup for x in sub.vertices(): self.vertex_top_blossom[x] = sub else: - self.vertex_top_blossom[sub.base_vertex] = sub + x = sub.base_vertex + self.vertex_dual_2x[x] += vertex_dual_fixup + self.vertex_top_blossom[x] = sub # The expanding blossom was part of an alternating tree, linked to # a parent node in the tree via one of its subblossoms, and linked to @@ -1030,16 +1117,24 @@ class _MatchingContext: assert blossom.parent is None assert blossom.label == _LABEL_NONE + # Prepare to push lazy delta updates down to the sub-blossoms. + vertex_dual_offset = blossom.vertex_dual_offset + blossom.vertex_dual_offset = 0 + # Convert sub-blossoms into top-level blossoms. for sub in blossom.subblossoms: assert sub.label == _LABEL_NONE sub.parent = None if isinstance(sub, _NonTrivialBlossom): + assert sub.vertex_dual_offset == 0 + sub.vertex_dual_offset = vertex_dual_offset for x in sub.vertices(): self.vertex_top_blossom[x] = sub else: - self.vertex_top_blossom[sub.base_vertex] = sub + x = sub.base_vertex + self.vertex_dual_2x[x] += vertex_dual_offset + self.vertex_top_blossom[x] = sub # Delete the expanded blossom. self.nontrivial_blossom.remove(blossom) @@ -1448,20 +1543,6 @@ class _MatchingContext: return (delta_type, delta_2x, delta_edge, delta_blossom) - def substage_apply_delta_step(self, delta_2x: float) -> None: - """Apply a delta step to the dual LPP variables.""" - - num_vertex = self.graph.num_vertex - - self.delta_sum_2x += delta_2x - - # Apply delta to dual variables of all vertices. - for x in range(num_vertex): - xlabel = self.vertex_top_blossom[x].label - if xlabel == _LABEL_T: - # T-vertex: add delta to dual variable. - self.vertex_dual_2x[x] += delta_2x - # # Main stage function: # @@ -1514,8 +1595,11 @@ class _MatchingContext: (delta_type, delta_2x, delta_edge, delta_blossom ) = self.substage_calc_dual_delta() - # Apply the delta step to the dual variables. - self.substage_apply_delta_step(delta_2x) + # Update the running sum of delta steps. + # This implicitly updates the dual variables as needed, because + # the running delta sum is taken into account when calculating + # dual values. + self.delta_sum_2x += delta_2x if delta_type == 2: # Use the edge from S-vertex to unlabeled vertex that got