diff --git a/python/mwmatching.py b/python/mwmatching.py index af986ac..e16773b 100644 --- a/python/mwmatching.py +++ b/python/mwmatching.py @@ -448,9 +448,15 @@ class _NonTrivialBlossom(_Blossom): # "dual_var" is the value of the dual variable minus 2 times # the running sum of delta steps. # + # If this is a top-level T-blossom, + # "dual_var" is the value of the dual variable plus 2 times + # the running sum of delta steps. + # # In all other cases, # "dual_var" is the current value of the dual variable. # + # Note that "dual_var" is invariant under delta steps. + # # New blossoms start with dual variable 0. self.dual_var: float = 0 @@ -688,6 +694,7 @@ class _MatchingContext: assert blossom.label == _LABEL_NONE blossom.label = _LABEL_S + # Prepare for lazy updating of S-blossom dual variable. if isinstance(blossom, _NonTrivialBlossom): blossom.dual_var -= self.delta_sum_2x @@ -697,13 +704,14 @@ class _MatchingContext: assert blossom.label == _LABEL_S blossom.label = _LABEL_NONE + # Catch up with lazy updates to S-blossom dual variable. if isinstance(blossom, _NonTrivialBlossom): blossom.dual_var += self.delta_sum_2x def assign_vertex_label_s(self, blossom: _Blossom) -> None: """Adjust after assigning label S to previously unlabeled vertices.""" - # Adjust the vertex dual variables of the new S-vertices. + # Prepare for lazy updating of S-vertex dual variables. vertices = blossom.vertices() for x in vertices: self.vertex_dual_2x[x] += self.delta_sum_2x @@ -714,7 +722,7 @@ class _MatchingContext: def remove_vertex_label_s(self, blossom: _Blossom) -> None: """Adjust after removing labels from S-vertices.""" - # Adjust the vertex dual variables of the former S-vertices. + # Catch up with lazy updates to S-vertex dual variables. for x in blossom.vertices(): self.vertex_dual_2x[x] -= self.delta_sum_2x @@ -723,31 +731,35 @@ class _MatchingContext: assert blossom.parent is None assert blossom.label == _LABEL_NONE - - # Assign label T. blossom.label = _LABEL_T - # Insert blossom into the delta4 queue. if isinstance(blossom, _NonTrivialBlossom): + + # Prepare for lazy updates to T-blossom dual variables. + blossom.dual_var += self.delta_sum_2x + + # Insert blossom into the delta4 queue. assert blossom.delta4_node is None - prio = blossom.dual_var + self.delta_sum_2x - blossom.delta4_node = self.delta4_queue.insert(prio, blossom) + blossom.delta4_node = self.delta4_queue.insert(blossom.dual_var, + blossom) def remove_blossom_label_t(self, blossom: _Blossom) -> None: """Remove label T from a top-level T-blossom.""" assert blossom.parent is None assert blossom.label == _LABEL_T - - # Remove label. blossom.label = _LABEL_NONE - # Remove blossom from delta4 queue. if isinstance(blossom, _NonTrivialBlossom): + + # Remove blossom from delta4 queue. assert blossom.delta4_node is not None self.delta4_queue.delete(blossom.delta4_node) blossom.delta4_node = None + # Catch up with lazy updates to T-blossom dual variable. + blossom.dual_var -= self.delta_sum_2x + def reset_stage(self) -> None: """Reset data which are only valid during a stage. @@ -1428,9 +1440,10 @@ class _MatchingContext: blossom = self.delta4_queue.find_min().data assert blossom.label == _LABEL_T assert blossom.parent is None - if blossom.dual_var <= delta_2x: + blossom_dual = blossom.dual_var - self.delta_sum_2x + if blossom_dual <= delta_2x: delta_type = 4 - delta_2x = blossom.dual_var + delta_2x = blossom_dual delta_blossom = blossom return (delta_type, delta_2x, delta_edge, delta_blossom) @@ -1449,14 +1462,6 @@ class _MatchingContext: # T-vertex: add delta to dual variable. self.vertex_dual_2x[x] += delta_2x - # Apply delta to dual variables of top-level non-trivial blossoms. - for blossom in self.nontrivial_blossom: - if blossom.parent is None: - blabel = blossom.label - if blabel == _LABEL_T: - # T-blossom: subtract 2*delta from dual variable. - blossom.dual_var -= delta_2x - # # Main stage function: #