From 147640329f7f47d98dda950d344127b593086813 Mon Sep 17 00:00:00 2001 From: Joris van Rantwijk Date: Sun, 7 Jul 2024 10:30:21 +0200 Subject: [PATCH] Restructure Python code as package --- python/mwmatching/__init__.py | 11 + .../algorithm.py} | 214 +++++++++--------- python/{ => mwmatching}/datastruct.py | 0 python/mwmatching/py.typed | 0 python/tests/__init__.py | 0 .../test_algorithm.py} | 60 +++-- python/{ => tests}/test_datastruct.py | 2 +- run_checks.sh | 15 +- tests/generate/make_slow_graph.py | 14 +- 9 files changed, 158 insertions(+), 158 deletions(-) create mode 100644 python/mwmatching/__init__.py rename python/{mwmatching.py => mwmatching/algorithm.py} (93%) rename python/{ => mwmatching}/datastruct.py (100%) create mode 100644 python/mwmatching/py.typed create mode 100644 python/tests/__init__.py rename python/{test_mwmatching.py => tests/test_algorithm.py} (93%) rename python/{ => tests}/test_datastruct.py (99%) diff --git a/python/mwmatching/__init__.py b/python/mwmatching/__init__.py new file mode 100644 index 0000000..bce2581 --- /dev/null +++ b/python/mwmatching/__init__.py @@ -0,0 +1,11 @@ +""" +Algorithm for finding a maximum weight matching in general graphs. +""" + +from .algorithm import (maximum_weight_matching, + adjust_weights_for_maximum_cardinality_matching, + MatchingError) + +__all__ = ["maximum_weight_matching", + "adjust_weights_for_maximum_cardinality_matching", + "MatchingError"] diff --git a/python/mwmatching.py b/python/mwmatching/algorithm.py similarity index 93% rename from python/mwmatching.py rename to python/mwmatching/algorithm.py index 83ca277..ddfbe3c 100644 --- a/python/mwmatching.py +++ b/python/mwmatching/algorithm.py @@ -10,7 +10,7 @@ import math from collections.abc import Sequence from typing import NamedTuple, Optional -from datastruct import UnionFindQueue, PriorityQueue +from .datastruct import UnionFindQueue, PriorityQueue def maximum_weight_matching( @@ -66,10 +66,10 @@ def maximum_weight_matching( return [] # Initialize graph representation. - graph = _GraphInfo(edges) + graph = GraphInfo(edges) # Initialize the matching algorithm. - ctx = _MatchingContext(graph) + ctx = MatchingContext(graph) ctx.start() # Improve the solution until no further improvement is possible. @@ -92,7 +92,7 @@ def maximum_weight_matching( # there is a bug in the matching algorithm. # Verification only works reliably for integer weights. if graph.integer_weights: - _verify_optimum(ctx) + verify_optimum(ctx) return pairs @@ -277,7 +277,7 @@ def _remove_negative_weight_edges( return edges -class _GraphInfo: +class GraphInfo: """Representation of the input graph. These data remain unchanged while the algorithm runs. @@ -328,12 +328,12 @@ class _GraphInfo: # Each vertex may be labeled "S" (outer) or "T" (inner) or be unlabeled. -_LABEL_NONE = 0 -_LABEL_S = 1 -_LABEL_T = 2 +LABEL_NONE = 0 +LABEL_S = 1 +LABEL_T = 2 -class _Blossom: +class Blossom: """Represents a blossom in a partially matched graph. A blossom is an odd-length alternating cycle over sub-blossoms. @@ -361,7 +361,7 @@ class _Blossom: # # If this is a top-level blossom, # "parent = None". - self.parent: Optional[_NonTrivialBlossom] = None + self.parent: Optional[NonTrivialBlossom] = None # "base_vertex" is the vertex index of the base of the blossom. # This is the unique vertex which is contained in the blossom @@ -374,7 +374,7 @@ class _Blossom: # A top-level blossom that is part of an alternating tree, # has label S or T. An unlabeled top-level blossom is not part # of any alternating tree. - self.label: int = _LABEL_NONE + self.label: int = LABEL_NONE # A labeled top-level blossoms keeps track of the edge through which # it is attached to the alternating tree. @@ -389,11 +389,11 @@ class _Blossom: # "tree_blossoms" is the set of all top-level blossoms that belong # to the same alternating tree. The same set instance is shared by # all top-level blossoms in the tree. - self.tree_blossoms: "Optional[set[_Blossom]]" = None + self.tree_blossoms: "Optional[set[Blossom]]" = None # Each top-level blossom maintains a union-find datastructure # containing all vertices in the blossom. - self.vertex_set: "UnionFindQueue[_Blossom, int]" + self.vertex_set: "UnionFindQueue[Blossom, int]" self.vertex_set = UnionFindQueue(self) # If this is a top-level unlabeled blossom with an edge to an @@ -415,7 +415,7 @@ class _Blossom: return [self.base_vertex] -class _NonTrivialBlossom(_Blossom): +class NonTrivialBlossom(Blossom): """Represents a non-trivial blossom in a partially matched graph. A non-trivial blossom is a blossom that contains multiple sub-blossoms @@ -436,7 +436,7 @@ class _NonTrivialBlossom(_Blossom): def __init__( self, - subblossoms: list[_Blossom], + subblossoms: list[Blossom], edges: list[tuple[int, int]] ) -> None: """Initialize a new blossom.""" @@ -454,7 +454,7 @@ class _NonTrivialBlossom(_Blossom): # # "subblossoms[0]" is the start and end of the alternating cycle. # "subblossoms[0]" contains the base vertex of the blossom. - self.subblossoms: list[_Blossom] = subblossoms + self.subblossoms: list[Blossom] = subblossoms # "edges" is a list of edges linking the sub-blossoms. # Each edge is represented as an ordered pair "(x, y)" where "x" @@ -491,13 +491,13 @@ class _NonTrivialBlossom(_Blossom): """Return a list of vertex indices contained in the blossom.""" # Use an explicit stack to avoid deep recursion. - stack: list[_NonTrivialBlossom] = [self] + stack: list[NonTrivialBlossom] = [self] nodes: list[int] = [] while stack: b = stack.pop() for sub in b.subblossoms: - if isinstance(sub, _NonTrivialBlossom): + if isinstance(sub, NonTrivialBlossom): stack.append(sub) else: nodes.append(sub.base_vertex) @@ -505,21 +505,21 @@ class _NonTrivialBlossom(_Blossom): return nodes -class _AlternatingPath(NamedTuple): +class AlternatingPath(NamedTuple): """Represents a list of edges forming an alternating path or an alternating cycle.""" edges: list[tuple[int, int]] is_cycle: bool -class _MatchingContext: +class MatchingContext: """Holds all data used by the matching algorithm. It contains a partial solution of the matching problem and several auxiliary data structures. """ - def __init__(self, graph: _GraphInfo) -> None: + def __init__(self, graph: GraphInfo) -> None: """Set up the initial state of the matching algorithm.""" num_vertex = graph.num_vertex @@ -545,14 +545,14 @@ class _MatchingContext: # # "trivial_blossom[x]" is the trivial blossom that contains only # vertex "x". - self.trivial_blossom: list[_Blossom] = [_Blossom(x) - for x in range(num_vertex)] + self.trivial_blossom: list[Blossom] = [Blossom(x) + for x in range(num_vertex)] # Non-trivial blossoms may be created and destroyed during # the course of the algorithm. # # Initially there are no non-trivial blossoms. - self.nontrivial_blossom: set[_NonTrivialBlossom] = set() + self.nontrivial_blossom: set[NonTrivialBlossom] = set() # "vertex_set_node[x]" represents the vertex "x" inside the # union-find datastructure of its top-level blossom. @@ -593,7 +593,7 @@ class _MatchingContext: # Queue containing unlabeled top-level blossoms that have an edge to # an S-blossom. The priority of a blossom is 2 times its least slack # to an S blossom, plus 2 times the running sum of delta steps. - self.delta2_queue: PriorityQueue[_Blossom] = PriorityQueue() + self.delta2_queue: PriorityQueue[Blossom] = PriorityQueue() # Queue containing edges between S-vertices in different top-level # blossoms. The priority of an edge is its slack plus 2 times the @@ -605,7 +605,7 @@ class _MatchingContext: # Queue containing top-level non-trivial T-blossoms. # The priority of a blossom is its dual plus 2 times the running # sum of delta steps. - self.delta4_queue: PriorityQueue[_NonTrivialBlossom] = PriorityQueue() + self.delta4_queue: PriorityQueue[NonTrivialBlossom] = PriorityQueue() # For each T-vertex or unlabeled vertex "x", # "vertex_sedge_queue[x]" is a queue of edges between "x" and any @@ -648,7 +648,7 @@ class _MatchingContext: (x, y, w) = self.graph.edges[e] return self.vertex_dual_2x[x] + self.vertex_dual_2x[y] - 2 * w - def delta2_add_edge(self, e: int, y: int, by: _Blossom) -> None: + def delta2_add_edge(self, e: int, y: int, by: Blossom) -> None: """Add edge "e" for delta2 tracking. Edge "e" connects an S-vertex to a T-vertex or unlabeled vertex "y". @@ -674,14 +674,14 @@ class _MatchingContext: # If the blossom is unlabeled and the new edge becomes its least-slack # S-edge, insert or update the blossom in the global delta2 queue. - if by.label == _LABEL_NONE: + if by.label == LABEL_NONE: prio += by.vertex_dual_offset if by.delta2_node is None: by.delta2_node = self.delta2_queue.insert(prio, by) elif prio < by.delta2_node.prio: self.delta2_queue.decrease_prio(by.delta2_node, prio) - def delta2_remove_edge(self, e: int, y: int, by: _Blossom) -> None: + def delta2_remove_edge(self, e: int, y: int, by: Blossom) -> None: """Remove edge "e" from delta2 tracking. This function is called if an S-vertex becomes unlabeled, @@ -705,7 +705,7 @@ class _MatchingContext: # If necessary, update the priority of "y" in its UnionFindQueue. if prio > self.vertex_set_node[y].prio: self.vertex_set_node[y].set_prio(prio) - if by.label == _LABEL_NONE: + if by.label == LABEL_NONE: # Update or delete the blossom in the global delta2 queue. assert by.delta2_node is not None prio = by.vertex_set.min_prio() @@ -718,7 +718,7 @@ class _MatchingContext: self.delta2_queue.delete(by.delta2_node) by.delta2_node = None - def delta2_enable_blossom(self, blossom: _Blossom) -> None: + def delta2_enable_blossom(self, blossom: Blossom) -> None: """Enable delta2 tracking for "blossom". This function is called when a blossom becomes an unlabeled top-level @@ -733,7 +733,7 @@ class _MatchingContext: prio += blossom.vertex_dual_offset blossom.delta2_node = self.delta2_queue.insert(prio, blossom) - def delta2_disable_blossom(self, blossom: _Blossom) -> None: + def delta2_disable_blossom(self, blossom: Blossom) -> None: """Disable delta2 tracking for "blossom". The blossom will be removed from the global delta2 queue. @@ -780,7 +780,7 @@ class _MatchingContext: prio = delta2_node.prio slack_2x = prio - self.delta_sum_2x assert blossom.parent is None - assert blossom.label == _LABEL_NONE + assert blossom.label == LABEL_NONE x = blossom.vertex_set.min_elem() e = self.vertex_sedge_queue[x].find_min().data @@ -840,7 +840,7 @@ class _MatchingContext: (x, y, _w) = self.graph.edges[e] bx = self.vertex_set_node[x].find() by = self.vertex_set_node[y].find() - assert (bx.label == _LABEL_S) and (by.label == _LABEL_S) + assert (bx.label == LABEL_S) and (by.label == LABEL_S) if bx is not by: slack = delta3_node.prio - self.delta_sum_2x return (e, slack) @@ -859,7 +859,7 @@ class _MatchingContext: # Managing blossom labels: # - def assign_blossom_label_s(self, blossom: _Blossom) -> None: + def assign_blossom_label_s(self, blossom: Blossom) -> None: """Change an unlabeled top-level blossom into an S-blossom. For a blossom with "j" vertices and "k" incident edges, @@ -870,8 +870,8 @@ class _MatchingContext: """ assert blossom.parent is None - assert blossom.label == _LABEL_NONE - blossom.label = _LABEL_S + assert blossom.label == LABEL_NONE + blossom.label = LABEL_S # Labeled blossoms must not be in the delta2 queue. self.delta2_disable_blossom(blossom) @@ -887,7 +887,7 @@ class _MatchingContext: # The value of blossom.dual_var must be adjusted accordingly # when the blossom changes from unlabeled to S-blossom. # - if isinstance(blossom, _NonTrivialBlossom): + if isinstance(blossom, NonTrivialBlossom): blossom.dual_var -= self.delta_sum_2x # Apply pending updates to vertex dual variables and prepare @@ -916,20 +916,20 @@ class _MatchingContext: # Add the new S-vertices to the scan queue. self.scan_queue.extend(vertices) - def assign_blossom_label_t(self, blossom: _Blossom) -> None: + def assign_blossom_label_t(self, blossom: Blossom) -> None: """Change an unlabeled top-level blossom into a T-blossom. This function takes time O(log(n)). """ assert blossom.parent is None - assert blossom.label == _LABEL_NONE - blossom.label = _LABEL_T + assert blossom.label == LABEL_NONE + blossom.label = LABEL_T # Labeled blossoms must not be in the delta2 queue. self.delta2_disable_blossom(blossom) - if isinstance(blossom, _NonTrivialBlossom): + if isinstance(blossom, NonTrivialBlossom): # Adjust for lazy updating of T-blossom dual variables. # @@ -962,7 +962,7 @@ class _MatchingContext: # blossom.vertex_dual_offset -= self.delta_sum_2x - def remove_blossom_label_s(self, blossom: _Blossom) -> None: + def remove_blossom_label_s(self, blossom: Blossom) -> None: """Change a top-level S-blossom into an unlabeled blossom. For a blossom with "j" vertices and "k" incident edges, @@ -973,11 +973,11 @@ class _MatchingContext: """ assert blossom.parent is None - assert blossom.label == _LABEL_S - blossom.label = _LABEL_NONE + assert blossom.label == LABEL_S + blossom.label = LABEL_NONE # Unwind lazy delta updates to the S-blossom dual variable. - if isinstance(blossom, _NonTrivialBlossom): + if isinstance(blossom, NonTrivialBlossom): blossom.dual_var += self.delta_sum_2x assert blossom.vertex_dual_offset == 0 @@ -1002,7 +1002,7 @@ class _MatchingContext: self.delta3_remove_edge(e) by = self.vertex_set_node[y].find() - if by.label == _LABEL_S: + if by.label == LABEL_S: # Edge "e" connects unlabeled vertex "x" to S-vertex "y". # It must be tracked for delta2 via vertex "x". self.delta2_add_edge(e, x, blossom) @@ -1013,17 +1013,17 @@ class _MatchingContext: # removed now. self.delta2_remove_edge(e, y, by) - def remove_blossom_label_t(self, blossom: _Blossom) -> None: + def remove_blossom_label_t(self, blossom: Blossom) -> None: """Change a top-level T-blossom into an unlabeled blossom. This function takes time O(log(n)). """ assert blossom.parent is None - assert blossom.label == _LABEL_T - blossom.label = _LABEL_NONE + assert blossom.label == LABEL_T + blossom.label = LABEL_NONE - if isinstance(blossom, _NonTrivialBlossom): + if isinstance(blossom, NonTrivialBlossom): # Unlabeled blossoms are not tracked in the delta4 queue. assert blossom.delta4_node is not None @@ -1039,31 +1039,31 @@ class _MatchingContext: # Enable unlabeled top-level blossom for delta2 tracking. self.delta2_enable_blossom(blossom) - def change_s_blossom_to_subblossom(self, blossom: _Blossom) -> None: + def change_s_blossom_to_subblossom(self, blossom: Blossom) -> None: """Change a top-level S-blossom into an S-subblossom. This function takes time O(1). """ assert blossom.parent is None - assert blossom.label == _LABEL_S - blossom.label = _LABEL_NONE + assert blossom.label == LABEL_S + blossom.label = LABEL_NONE # Unwind lazy delta updates to the S-blossom dual variable. - if isinstance(blossom, _NonTrivialBlossom): + if isinstance(blossom, NonTrivialBlossom): blossom.dual_var += self.delta_sum_2x # # General support routines: # - def reset_blossom_label(self, blossom: _Blossom) -> None: + def reset_blossom_label(self, blossom: Blossom) -> None: """Remove blossom label.""" assert blossom.parent is None - assert blossom.label != _LABEL_NONE + assert blossom.label != LABEL_NONE - if blossom.label == _LABEL_S: + if blossom.label == LABEL_S: self.remove_blossom_label_s(blossom) else: self.remove_blossom_label_t(blossom) @@ -1072,7 +1072,7 @@ class _MatchingContext: """TODO -- remove this function, only for debugging""" for blossom in itertools.chain(self.trivial_blossom, self.nontrivial_blossom): - if (blossom.parent is None) and (blossom.label != _LABEL_NONE): + if (blossom.parent is None) and (blossom.label != LABEL_NONE): assert blossom.tree_blossoms is not None assert blossom in blossom.tree_blossoms if blossom.tree_edge is not None: @@ -1084,7 +1084,7 @@ class _MatchingContext: assert blossom.tree_edge is None assert blossom.tree_blossoms is None - def remove_alternating_tree(self, tree_blossoms: set[_Blossom]) -> None: + def remove_alternating_tree(self, tree_blossoms: set[Blossom]) -> None: """Reset the alternating tree consisting of the specified blossoms. Marks the blossoms as unlabeled. @@ -1093,13 +1093,13 @@ class _MatchingContext: This function takes time O((n + m) * log(n)). """ for blossom in tree_blossoms: - assert blossom.label != _LABEL_NONE + assert blossom.label != LABEL_NONE assert blossom.tree_blossoms is tree_blossoms self.reset_blossom_label(blossom) blossom.tree_edge = None blossom.tree_blossoms = None - def trace_alternating_paths(self, x: int, y: int) -> _AlternatingPath: + def trace_alternating_paths(self, x: int, y: int) -> AlternatingPath: """Trace back through the alternating trees from vertices "x" and "y". If both vertices are part of the same alternating tree, this function @@ -1119,7 +1119,7 @@ class _MatchingContext: blossoms. """ - marked_blossoms: list[_Blossom] = [] + marked_blossoms: list[Blossom] = [] # "xedges" is a list of edges used while tracing from "x". # "yedges" is a list of edges used while tracing from "y". @@ -1129,7 +1129,7 @@ class _MatchingContext: # "first_common" is the first common ancestor of "x" and "y" # in the alternating tree, or None if there is no common ancestor. - first_common: Optional[_Blossom] = None + first_common: Optional[Blossom] = None # Alternate between tracing the path from "x" and the path from "y". # This ensures that the search time is bounded by the size of the @@ -1179,14 +1179,14 @@ class _MatchingContext: # Any S-to-S alternating path must have odd length. assert len(path_edges) % 2 == 1 - return _AlternatingPath(edges=path_edges, - is_cycle=(first_common is not None)) + return AlternatingPath(edges=path_edges, + is_cycle=(first_common is not None)) # # Merge and expand blossoms: # - def make_blossom(self, path: _AlternatingPath) -> None: + def make_blossom(self, path: AlternatingPath) -> None: """Create a new blossom from an alternating cycle. Assign label S to the new blossom. @@ -1214,21 +1214,21 @@ class _MatchingContext: assert subblossoms[1:] == subblossoms_next[:-1] # Blossom must start and end with an S-sub-blossom. - assert subblossoms[0].label == _LABEL_S + assert subblossoms[0].label == LABEL_S # Remove blossom labels. # Mark vertices inside former T-blossoms as S-vertices. for sub in subblossoms: - if sub.label == _LABEL_T: + if sub.label == LABEL_T: self.remove_blossom_label_t(sub) self.assign_blossom_label_s(sub) self.change_s_blossom_to_subblossom(sub) # Create the new blossom object. - blossom = _NonTrivialBlossom(subblossoms, path.edges) + blossom = NonTrivialBlossom(subblossoms, path.edges) # Assign label S to the new blossom. - blossom.label = _LABEL_S + blossom.label = LABEL_S # Prepare for lazy updating of S-blossom dual variable. blossom.dual_var = -self.delta_sum_2x @@ -1257,9 +1257,9 @@ class _MatchingContext: @staticmethod def find_path_through_blossom( - blossom: _NonTrivialBlossom, - sub: _Blossom - ) -> tuple[list[_Blossom], list[tuple[int, int]]]: + blossom: NonTrivialBlossom, + sub: Blossom + ) -> tuple[list[Blossom], list[tuple[int, int]]]: """Construct a path with an even number of edges through the specified blossom, from sub-blossom "sub" to the base of "blossom". @@ -1285,14 +1285,14 @@ class _MatchingContext: return (nodes, edges) - def expand_unlabeled_blossom(self, blossom: _NonTrivialBlossom) -> None: + def expand_unlabeled_blossom(self, blossom: NonTrivialBlossom) -> None: """Expand the specified unlabeled blossom. This function takes total time O(n * log(n)) per stage. """ assert blossom.parent is None - assert blossom.label == _LABEL_NONE + assert blossom.label == LABEL_NONE # Remove blossom from the delta2 queue. self.delta2_disable_blossom(blossom) @@ -1306,7 +1306,7 @@ class _MatchingContext: # Convert sub-blossoms into top-level blossoms. for sub in blossom.subblossoms: - assert sub.label == _LABEL_NONE + assert sub.label == LABEL_NONE sub.parent = None assert sub.vertex_dual_offset == 0 @@ -1320,14 +1320,14 @@ class _MatchingContext: # Delete the expanded blossom. self.nontrivial_blossom.remove(blossom) - def expand_t_blossom(self, blossom: _NonTrivialBlossom) -> None: + def expand_t_blossom(self, blossom: NonTrivialBlossom) -> None: """Expand the specified T-blossom. This function takes total time O(n * log(n) + m) per stage. """ assert blossom.parent is None - assert blossom.label == _LABEL_T + assert blossom.label == LABEL_T assert blossom.delta2_node is None # Remove blossom from its alternating tree. @@ -1391,9 +1391,9 @@ class _MatchingContext: def augment_blossom_rec( self, - blossom: _NonTrivialBlossom, - sub: _Blossom, - stack: list[tuple[_NonTrivialBlossom, _Blossom]] + blossom: NonTrivialBlossom, + sub: Blossom, + stack: list[tuple[NonTrivialBlossom, Blossom]] ) -> None: """Augment along an alternating path through the specified blossom, from sub-blossom "sub" to the base vertex of the blossom. @@ -1430,11 +1430,11 @@ class _MatchingContext: # Augment through the subblossoms touching the edge (x, y). # Nothing needs to be done for trivial subblossoms. bx = path_nodes[p+1] - if isinstance(bx, _NonTrivialBlossom): + if isinstance(bx, NonTrivialBlossom): stack.append((bx, self.trivial_blossom[x])) by = path_nodes[p+2] - if isinstance(by, _NonTrivialBlossom): + if isinstance(by, NonTrivialBlossom): stack.append((by, self.trivial_blossom[y])) # Rotate the subblossom list so the new base ends up in position 0. @@ -1450,8 +1450,8 @@ class _MatchingContext: def augment_blossom( self, - blossom: _NonTrivialBlossom, - sub: _Blossom + blossom: NonTrivialBlossom, + sub: Blossom ) -> None: """Augment along an alternating path through the specified blossom, from sub-blossom "sub" to the base vertex of the blossom. @@ -1486,7 +1486,7 @@ class _MatchingContext: # Augment "blossom" from "sub" to the base vertex. self.augment_blossom_rec(blossom, sub, stack) - def augment_matching(self, path: _AlternatingPath) -> None: + def augment_matching(self, path: AlternatingPath) -> None: """Augment the matching through the specified augmenting path. This function takes time O(n). @@ -1515,11 +1515,11 @@ class _MatchingContext: # Augment the non-trivial blossoms on either side of this edge. # No action is necessary for trivial blossoms. bx = self.vertex_set_node[x].find() - if isinstance(bx, _NonTrivialBlossom): + if isinstance(bx, NonTrivialBlossom): self.augment_blossom(bx, self.trivial_blossom[x]) by = self.vertex_set_node[y].find() - if isinstance(by, _NonTrivialBlossom): + if isinstance(by, NonTrivialBlossom): self.augment_blossom(by, self.trivial_blossom[y]) # Pull the edge into the matching. @@ -1551,7 +1551,7 @@ class _MatchingContext: assert y != -1 by = self.vertex_set_node[y].find() - assert by.label == _LABEL_T + assert by.label == LABEL_T assert by.tree_blossoms is not None # Attach the blossom that contains "x" to the alternating tree. @@ -1575,10 +1575,10 @@ class _MatchingContext: bx = self.vertex_set_node[x].find() by = self.vertex_set_node[y].find() - assert bx.label == _LABEL_S + assert bx.label == LABEL_S # Expand zero-dual blossoms before assigning label T. - while isinstance(by, _NonTrivialBlossom) and (by.dual_var == 0): + while isinstance(by, NonTrivialBlossom) and (by.dual_var == 0): self.expand_unlabeled_blossom(by) by = self.vertex_set_node[y].find() @@ -1617,8 +1617,8 @@ class _MatchingContext: bx = self.vertex_set_node[x].find() by = self.vertex_set_node[y].find() - assert bx.label == _LABEL_S - assert by.label == _LABEL_S + assert bx.label == LABEL_S + assert by.label == LABEL_S assert bx is not by # Trace back through the alternating trees from "x" and "y". @@ -1680,7 +1680,7 @@ class _MatchingContext: # Double-check that "x" is an S-vertex. bx = self.vertex_set_node[x].find() - assert bx.label == _LABEL_S + assert bx.label == LABEL_S # Scan the edges that are incident on "x". # This loop runs through O(m) iterations per stage. @@ -1703,7 +1703,7 @@ class _MatchingContext: if bx is by: continue - if by.label == _LABEL_S: + if by.label == LABEL_S: self.delta3_add_edge(e) else: self.delta2_add_edge(e, y, by) @@ -1716,7 +1716,7 @@ class _MatchingContext: def calc_dual_delta_step( self - ) -> tuple[int, float, int, Optional[_NonTrivialBlossom]]: + ) -> tuple[int, float, int, Optional[NonTrivialBlossom]]: """Calculate a delta step in the dual LPP problem. This function returns the minimum of the 4 types of delta values, @@ -1740,7 +1740,7 @@ class _MatchingContext: Tuple (delta_type, delta_2x, delta_edge, delta_blossom). """ delta_edge = -1 - delta_blossom: Optional[_NonTrivialBlossom] = None + delta_blossom: Optional[NonTrivialBlossom] = None # Compute delta1: minimum dual variable of any S-vertex. # All unmatched vertices have the same dual value, and this is @@ -1770,7 +1770,7 @@ class _MatchingContext: # This takes time O(log(n)). if not self.delta4_queue.empty(): blossom = self.delta4_queue.find_min().data - assert blossom.label == _LABEL_T + assert blossom.label == LABEL_T assert blossom.parent is None blossom_dual = blossom.dual_var - self.delta_sum_2x if blossom_dual <= delta_2x: @@ -1847,7 +1847,7 @@ class _MatchingContext: # Use the edge from S-vertex to unlabeled vertex that got # unlocked through the delta update. (x, y, _w) = self.graph.edges[delta_edge] - if self.vertex_set_node[x].find().label != _LABEL_S: + if self.vertex_set_node[x].find().label != LABEL_S: (x, y) = (y, x) self.extend_tree_s_to_t(x, y) @@ -1886,9 +1886,9 @@ class _MatchingContext: self.nontrivial_blossom): # Remove blossom label. - if (blossom.parent is None) and (blossom.label != _LABEL_NONE): + if (blossom.parent is None) and (blossom.label != LABEL_NONE): self.reset_blossom_label(blossom) - assert blossom.label == _LABEL_NONE + assert blossom.label == LABEL_NONE # Remove blossom from alternating tree. blossom.tree_edge = None @@ -1906,8 +1906,8 @@ class _MatchingContext: def _verify_blossom_edges( - ctx: _MatchingContext, - blossom: _NonTrivialBlossom, + ctx: MatchingContext, + blossom: NonTrivialBlossom, edge_slack_2x: list[float] ) -> None: """Descend down the blossom tree to find edges that are contained @@ -1943,7 +1943,7 @@ def _verify_blossom_edges( path_num_matched: list[int] = [0] # Use an explicit stack to avoid deep recursion. - stack: list[tuple[_NonTrivialBlossom, int]] = [(blossom, -1)] + stack: list[tuple[NonTrivialBlossom, int]] = [(blossom, -1)] while stack: (blossom, p) = stack[-1] @@ -1969,7 +1969,7 @@ def _verify_blossom_edges( # Examine the next sub-blossom at the current level. sub = blossom.subblossoms[p] - if isinstance(sub, _NonTrivialBlossom): + if isinstance(sub, NonTrivialBlossom): # Prepare to descent into the selected sub-blossom and # scan it recursively. stack.append((sub, -1)) @@ -2031,7 +2031,7 @@ def _verify_blossom_edges( stack.pop() -def _verify_optimum(ctx: _MatchingContext) -> None: +def verify_optimum(ctx: MatchingContext) -> None: """Verify that the optimum solution has been found. This function takes time O(n**2). diff --git a/python/datastruct.py b/python/mwmatching/datastruct.py similarity index 100% rename from python/datastruct.py rename to python/mwmatching/datastruct.py diff --git a/python/mwmatching/py.typed b/python/mwmatching/py.typed new file mode 100644 index 0000000..e69de29 diff --git a/python/tests/__init__.py b/python/tests/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/python/test_mwmatching.py b/python/tests/test_algorithm.py similarity index 93% rename from python/test_mwmatching.py rename to python/tests/test_algorithm.py index e36efc1..1d0bd4a 100644 --- a/python/test_mwmatching.py +++ b/python/tests/test_algorithm.py @@ -4,10 +4,12 @@ import math import unittest from unittest.mock import Mock -import mwmatching from mwmatching import ( maximum_weight_matching as mwm, adjust_weights_for_maximum_cardinality_matching as adj) +from mwmatching.algorithm import ( + MatchingError, GraphInfo, Blossom, NonTrivialBlossom, MatchingContext, + verify_optimum) class TestMaximumWeightMatching(unittest.TestCase): @@ -431,10 +433,10 @@ class TestMaximumCardinalityMatching(unittest.TestCase): class TestGraphInfo(unittest.TestCase): - """Test _GraphInfo helper class.""" + """Test GraphInfo helper class.""" def test_empty(self): - graph = mwmatching._GraphInfo([]) + graph = GraphInfo([]) self.assertEqual(graph.num_vertex, 0) self.assertEqual(graph.edges, []) self.assertEqual(graph.adjacent_edges, []) @@ -449,8 +451,8 @@ class TestVerificationFail(unittest.TestCase): vertex_mate, vertex_dual_2x, nontrivial_blossom): - ctx = Mock(spec=mwmatching._MatchingContext) - ctx.graph = mwmatching._GraphInfo(edges) + ctx = Mock(spec=MatchingContext) + ctx.graph = GraphInfo(edges) ctx.vertex_mate = vertex_mate ctx.vertex_dual_2x = vertex_dual_2x ctx.nontrivial_blossom = nontrivial_blossom @@ -463,7 +465,7 @@ class TestVerificationFail(unittest.TestCase): vertex_mate=[-1, 2, 1], vertex_dual_2x=[0, 20, 2], nontrivial_blossom=[]) - mwmatching._verify_optimum(ctx) + verify_optimum(ctx) def test_asymmetric_matching(self): edges = [(0,1,10), (1,2,11)] @@ -472,8 +474,8 @@ class TestVerificationFail(unittest.TestCase): vertex_mate=[-1, 2, 0], vertex_dual_2x=[0, 20, 2], nontrivial_blossom=[]) - with self.assertRaises(mwmatching.MatchingError): - mwmatching._verify_optimum(ctx) + with self.assertRaises(MatchingError): + verify_optimum(ctx) def test_nonexistent_matched_edge(self): edges = [(0,1,10), (1,2,11)] @@ -482,8 +484,8 @@ class TestVerificationFail(unittest.TestCase): vertex_mate=[2, -1, 0], vertex_dual_2x=[11, 11, 11], nontrivial_blossom=[]) - with self.assertRaises(mwmatching.MatchingError): - mwmatching._verify_optimum(ctx) + with self.assertRaises(MatchingError): + verify_optimum(ctx) def test_negative_vertex_dual(self): edges = [(0,1,10), (1,2,11)] @@ -492,8 +494,8 @@ class TestVerificationFail(unittest.TestCase): vertex_mate=[-1, 2, 1], vertex_dual_2x=[-2, 22, 0], nontrivial_blossom=[]) - with self.assertRaises(mwmatching.MatchingError): - mwmatching._verify_optimum(ctx) + with self.assertRaises(MatchingError): + verify_optimum(ctx) def test_unmatched_nonzero_dual(self): edges = [(0,1,10), (1,2,11)] @@ -502,8 +504,8 @@ class TestVerificationFail(unittest.TestCase): vertex_mate=[-1, 2, 1], vertex_dual_2x=[9, 11, 11], nontrivial_blossom=[]) - with self.assertRaises(mwmatching.MatchingError): - mwmatching._verify_optimum(ctx) + with self.assertRaises(MatchingError): + verify_optimum(ctx) def test_negative_edge_slack(self): edges = [(0,1,10), (1,2,11)] @@ -512,8 +514,8 @@ class TestVerificationFail(unittest.TestCase): vertex_mate=[-1, 2, 1], vertex_dual_2x=[0, 11, 11], nontrivial_blossom=[]) - with self.assertRaises(mwmatching.MatchingError): - mwmatching._verify_optimum(ctx) + with self.assertRaises(MatchingError): + verify_optimum(ctx) def test_matched_edge_slack(self): edges = [(0,1,10), (1,2,11)] @@ -522,8 +524,8 @@ class TestVerificationFail(unittest.TestCase): vertex_mate=[-1, 2, 1], vertex_dual_2x=[0, 20, 11], nontrivial_blossom=[]) - with self.assertRaises(mwmatching.MatchingError): - mwmatching._verify_optimum(ctx) + with self.assertRaises(MatchingError): + verify_optimum(ctx) def test_negative_blossom_dual(self): # @@ -532,11 +534,8 @@ class TestVerificationFail(unittest.TestCase): # \----8-----/ # edges = [(0,1,7), (0,2,8), (1,2,9), (2,3,6)] - blossom = mwmatching._NonTrivialBlossom( - subblossoms=[ - mwmatching._Blossom(0), - mwmatching._Blossom(1), - mwmatching._Blossom(2)], + blossom = NonTrivialBlossom( + subblossoms=[Blossom(0), Blossom(1), Blossom(2)], edges=[0,2,1]) for sub in blossom.subblossoms: sub.parent = blossom @@ -546,8 +545,8 @@ class TestVerificationFail(unittest.TestCase): vertex_mate=[1, 0, 3, 2], vertex_dual_2x=[4, 6, 8, 4], nontrivial_blossom=[blossom]) - with self.assertRaises(mwmatching.MatchingError): - mwmatching._verify_optimum(ctx) + with self.assertRaises(MatchingError): + verify_optimum(ctx) def test_blossom_not_full(self): # @@ -560,11 +559,8 @@ class TestVerificationFail(unittest.TestCase): # \----2-----/ # edges = [(0,1,7), (0,2,2), (1,2,5), (0,3,8), (1,4,8)] - blossom = mwmatching._NonTrivialBlossom( - subblossoms=[ - mwmatching._Blossom(0), - mwmatching._Blossom(1), - mwmatching._Blossom(2)], + blossom = NonTrivialBlossom( + subblossoms=[Blossom(0), Blossom(1), Blossom(2)], edges=[0,2,1]) for sub in blossom.subblossoms: sub.parent = blossom @@ -574,8 +570,8 @@ class TestVerificationFail(unittest.TestCase): vertex_mate=[3, 4, -1, 0, 1], vertex_dual_2x=[4, 10, 0, 12, 6], nontrivial_blossom=[blossom]) - with self.assertRaises(mwmatching.MatchingError): - mwmatching._verify_optimum(ctx) + with self.assertRaises(MatchingError): + verify_optimum(ctx) if __name__ == "__main__": diff --git a/python/test_datastruct.py b/python/tests/test_datastruct.py similarity index 99% rename from python/test_datastruct.py rename to python/tests/test_datastruct.py index 60a09d5..60dd5ea 100644 --- a/python/test_datastruct.py +++ b/python/tests/test_datastruct.py @@ -3,7 +3,7 @@ import random import unittest -from datastruct import UnionFindQueue, PriorityQueue +from mwmatching.datastruct import UnionFindQueue, PriorityQueue class TestUnionFindQueue(unittest.TestCase): diff --git a/run_checks.sh b/run_checks.sh index 6dcf7f2..cf68c62 100755 --- a/run_checks.sh +++ b/run_checks.sh @@ -4,7 +4,7 @@ set -e echo echo "Running pycodestyle" -pycodestyle python/mwmatching.py python/datastruct.py tests +pycodestyle python/mwmatching python/run_matching.py tests echo echo "Running mypy" @@ -12,20 +12,15 @@ mypy --disallow-incomplete-defs python tests echo echo "Running pylint" -pylint --ignore=test_mwmatching.py python tests || [ $(($? & 3)) -eq 0 ] +pylint --ignore=test_algorithm.py python tests/*.py tests/generate/*.py || [ $(($? & 3)) -eq 0 ] echo -echo "Running test_mwmatching.py" -python3 python/test_mwmatching.py - -echo -echo "Running test_datastruct.py" -python3 python/test_datastruct.py +echo "Running unit tests" +python3 -m unittest discover -t python -s python/tests echo echo "Checking test coverage" coverage erase -coverage run --branch python/test_datastruct.py -coverage run -a --branch python/test_mwmatching.py +coverage run --branch -m unittest discover -t python -s python/tests coverage report -m diff --git a/tests/generate/make_slow_graph.py b/tests/generate/make_slow_graph.py index e507643..57c5571 100644 --- a/tests/generate/make_slow_graph.py +++ b/tests/generate/make_slow_graph.py @@ -19,13 +19,12 @@ count_delta_step = [0] def patch_matching_code() -> None: """Patch the matching code to count events.""" - # pylint: disable=import-outside-toplevel,protected-access + # pylint: disable=import-outside-toplevel - import mwmatching + from mwmatching.algorithm import MatchingContext - orig_make_blossom = mwmatching._MatchingContext.make_blossom - orig_calc_dual_delta_step = ( - mwmatching._MatchingContext.calc_dual_delta_step) + orig_make_blossom = MatchingContext.make_blossom + orig_calc_dual_delta_step = MatchingContext.calc_dual_delta_step def stub_make_blossom(*args: Any, **kwargs: Any) -> Any: count_make_blossom[0] += 1 @@ -36,9 +35,8 @@ def patch_matching_code() -> None: ret = orig_calc_dual_delta_step(*args, **kwargs) return ret - mwmatching._MatchingContext.make_blossom = ( # type: ignore - stub_make_blossom) - mwmatching._MatchingContext.calc_dual_delta_step = ( # type: ignore + MatchingContext.make_blossom = stub_make_blossom # type: ignore + MatchingContext.calc_dual_delta_step = ( # type: ignore stub_calc_dual_delta_step)