diff --git a/python/mwmatching.py b/python/mwmatching.py index f51b3e1..5ec7880 100644 --- a/python/mwmatching.py +++ b/python/mwmatching.py @@ -1071,83 +1071,28 @@ class _MatchingContext: # Delete the expanded blossom. self.nontrivial_blossom.remove(blossom) - def expand_blossom_rec( - self, - blossom: _NonTrivialBlossom, - stack: list[_NonTrivialBlossom] - ) -> None: - """Expand the specified blossom and recursively expand any - sub-blossoms that have dual variable zero. - - Use the stack object instead of making direct recursive calls. - """ - - assert blossom.parent is None - - # Examine sub-blossoms. - for sub in blossom.subblossoms: - - # Mark the sub-blossom as a top-level blossom. - sub.parent = None - - if isinstance(sub, _NonTrivialBlossom): - # Non-trivial sub-blossom. - # If its dual variable is zero, expand it recursively. - if sub.dual_var == 0: - stack.append(sub) - else: - # This sub-blossom will not be expanded; - # it now becomes top-level. Update its vertices - # to point to this sub-blossom. - for x in sub.vertices(): - self.vertex_top_blossom[x] = sub - else: - # Trivial sub-blossom. Mark it as top-level vertex. - self.vertex_top_blossom[sub.base_vertex] = sub - - # Deletion of the expanded blossom will be handled in - # the function "expand_zero_dual_blossoms()". - - def expand_zero_dual_blossoms(self) -> None: - """Expand all blossoms with zero dual variable (recursively). - - Note that this function runs at the end of a stage. - Blossoms are not labeled. Least-slack edges are not tracked. + def expand_unlabeled_blossom(self, blossom: _NonTrivialBlossom) -> None: + """Expand the specified unlabeled blossom. This function takes time O(n). """ - # Use an explicit stack to avoid deep recursion. - # The stack contains a list of blossoms to be expanded. - stack: list[_NonTrivialBlossom] = [] + assert blossom.parent is None + assert blossom.label == _LABEL_NONE - # Find top-level blossoms with zero slack. - for blossom in self.nontrivial_blossom: - if blossom.parent is None: - # We typically expand only S-blossoms that were created after - # the most recent delta step. Those blossoms have _exactly_ - # zero dual. So this comparison is reliable, even in case - # of floating point edge weights. - if blossom.dual_var == 0: - stack.append(blossom) + # Convert sub-blossoms into top-level blossoms. + for sub in blossom.subblossoms: + assert sub.label == _LABEL_NONE + sub.parent = None - # Skip the rest of this function if there are no blossoms to delete. - if not stack: - return + if isinstance(sub, _NonTrivialBlossom): + for x in sub.vertices(): + self.vertex_top_blossom[x] = sub + else: + self.vertex_top_blossom[sub.base_vertex] = sub - # Expand blossoms. - while stack: - blossom = stack.pop() - self.expand_blossom_rec(blossom, stack) - - # Mark the blossom for deletion. - blossom.marker = True - - # Delete the expanded blossoms. - # We do this in one pass over the array to ensure O(n) time. - self.nontrivial_blossom = [blossom - for blossom in self.nontrivial_blossom - if not blossom.marker] + # Delete the expanded blossom. + self.nontrivial_blossom.remove(blossom) # # Augmenting: @@ -1350,8 +1295,14 @@ class _MatchingContext: """ assert self.vertex_top_blossom[x].label == _LABEL_S - # Assign label T to the unlabeled blossom. by = self.vertex_top_blossom[y] + + # Expand zero-dual blossoms before assigning label T. + while isinstance(by, _NonTrivialBlossom) and (by.dual_var == 0): + self.expand_unlabeled_blossom(by) + by = self.vertex_top_blossom[y] + + # Assign label T to the unlabeled blossom. assert by.label == _LABEL_NONE by.label = _LABEL_T by.tree_edge = (x, y) @@ -1646,11 +1597,6 @@ class _MatchingContext: if augmenting_path is not None: self.augment_matching(augmenting_path) - # Expand all blossoms with dual variable zero. - # These are typically S-blossoms, since T-blossoms normally - # get expanded as soon as their dual variable hits zero. - self.expand_zero_dual_blossoms() - # Remove all labels, clear queue. self.reset_stage() diff --git a/python/test_mwmatching.py b/python/test_mwmatching.py index 06e7504..07b50d7 100644 --- a/python/test_mwmatching.py +++ b/python/test_mwmatching.py @@ -191,6 +191,32 @@ class TestMaximumWeightMatching(unittest.TestCase): edges = [(0,2,4), (0,3,4), (0,4,1), (1,2,8), (1,5,3), (2,3,9), (3,4,7), (4,5,2)] self.assertEqual(mwm(edges), [(1,2), (3,4)]) + def test46_expand_unlabeled_blossom(self): + """expand blossom before assigning label T""" + # + # 5--[2]--3--[4] + # / | + # [0]--5--[1] 5 + # \ | + # 5--[3]--3--[5] + # + self.assertEqual( + mwm([(0,1,5), (1,2,5), (1,3,5), (2,3,5), (2,4,3), (3,5,3)]), + [(0,1), (2,4), (3,5)]) + + def test47_expand_unlabeled_outer(self): + """expand outer blossom before assigning label T""" + # + # [3]--10--[1]--15--[2]--12--[5] + # _/ \_ | | + # 11 16_ 8 15 + # / \ | | + # [4] [6]---7--[7] + # + self.assertEqual( + mwm([(1,2,15), (1,3,10), (1,4,11), (1,6,17), (2,5,12), (2,6,8), (5,7,15), (6,7,7)]), + [(1,4), (2,6), (5,7)]) + def test_fail_bad_input(self): """bad input values""" with self.assertRaises(TypeError):