1
0
Fork 0

Restructure Python code as package

This commit is contained in:
Joris van Rantwijk 2024-07-07 10:30:21 +02:00
parent f2e8ca1357
commit 147640329f
9 changed files with 158 additions and 158 deletions

View File

@ -0,0 +1,11 @@
"""
Algorithm for finding a maximum weight matching in general graphs.
"""
from .algorithm import (maximum_weight_matching,
adjust_weights_for_maximum_cardinality_matching,
MatchingError)
__all__ = ["maximum_weight_matching",
"adjust_weights_for_maximum_cardinality_matching",
"MatchingError"]

View File

@ -10,7 +10,7 @@ import math
from collections.abc import Sequence from collections.abc import Sequence
from typing import NamedTuple, Optional from typing import NamedTuple, Optional
from datastruct import UnionFindQueue, PriorityQueue from .datastruct import UnionFindQueue, PriorityQueue
def maximum_weight_matching( def maximum_weight_matching(
@ -66,10 +66,10 @@ def maximum_weight_matching(
return [] return []
# Initialize graph representation. # Initialize graph representation.
graph = _GraphInfo(edges) graph = GraphInfo(edges)
# Initialize the matching algorithm. # Initialize the matching algorithm.
ctx = _MatchingContext(graph) ctx = MatchingContext(graph)
ctx.start() ctx.start()
# Improve the solution until no further improvement is possible. # Improve the solution until no further improvement is possible.
@ -92,7 +92,7 @@ def maximum_weight_matching(
# there is a bug in the matching algorithm. # there is a bug in the matching algorithm.
# Verification only works reliably for integer weights. # Verification only works reliably for integer weights.
if graph.integer_weights: if graph.integer_weights:
_verify_optimum(ctx) verify_optimum(ctx)
return pairs return pairs
@ -277,7 +277,7 @@ def _remove_negative_weight_edges(
return edges return edges
class _GraphInfo: class GraphInfo:
"""Representation of the input graph. """Representation of the input graph.
These data remain unchanged while the algorithm runs. These data remain unchanged while the algorithm runs.
@ -328,12 +328,12 @@ class _GraphInfo:
# Each vertex may be labeled "S" (outer) or "T" (inner) or be unlabeled. # Each vertex may be labeled "S" (outer) or "T" (inner) or be unlabeled.
_LABEL_NONE = 0 LABEL_NONE = 0
_LABEL_S = 1 LABEL_S = 1
_LABEL_T = 2 LABEL_T = 2
class _Blossom: class Blossom:
"""Represents a blossom in a partially matched graph. """Represents a blossom in a partially matched graph.
A blossom is an odd-length alternating cycle over sub-blossoms. A blossom is an odd-length alternating cycle over sub-blossoms.
@ -361,7 +361,7 @@ class _Blossom:
# #
# If this is a top-level blossom, # If this is a top-level blossom,
# "parent = None". # "parent = None".
self.parent: Optional[_NonTrivialBlossom] = None self.parent: Optional[NonTrivialBlossom] = None
# "base_vertex" is the vertex index of the base of the blossom. # "base_vertex" is the vertex index of the base of the blossom.
# This is the unique vertex which is contained in the blossom # This is the unique vertex which is contained in the blossom
@ -374,7 +374,7 @@ class _Blossom:
# A top-level blossom that is part of an alternating tree, # A top-level blossom that is part of an alternating tree,
# has label S or T. An unlabeled top-level blossom is not part # has label S or T. An unlabeled top-level blossom is not part
# of any alternating tree. # of any alternating tree.
self.label: int = _LABEL_NONE self.label: int = LABEL_NONE
# A labeled top-level blossoms keeps track of the edge through which # A labeled top-level blossoms keeps track of the edge through which
# it is attached to the alternating tree. # it is attached to the alternating tree.
@ -389,11 +389,11 @@ class _Blossom:
# "tree_blossoms" is the set of all top-level blossoms that belong # "tree_blossoms" is the set of all top-level blossoms that belong
# to the same alternating tree. The same set instance is shared by # to the same alternating tree. The same set instance is shared by
# all top-level blossoms in the tree. # all top-level blossoms in the tree.
self.tree_blossoms: "Optional[set[_Blossom]]" = None self.tree_blossoms: "Optional[set[Blossom]]" = None
# Each top-level blossom maintains a union-find datastructure # Each top-level blossom maintains a union-find datastructure
# containing all vertices in the blossom. # containing all vertices in the blossom.
self.vertex_set: "UnionFindQueue[_Blossom, int]" self.vertex_set: "UnionFindQueue[Blossom, int]"
self.vertex_set = UnionFindQueue(self) self.vertex_set = UnionFindQueue(self)
# If this is a top-level unlabeled blossom with an edge to an # If this is a top-level unlabeled blossom with an edge to an
@ -415,7 +415,7 @@ class _Blossom:
return [self.base_vertex] return [self.base_vertex]
class _NonTrivialBlossom(_Blossom): class NonTrivialBlossom(Blossom):
"""Represents a non-trivial blossom in a partially matched graph. """Represents a non-trivial blossom in a partially matched graph.
A non-trivial blossom is a blossom that contains multiple sub-blossoms A non-trivial blossom is a blossom that contains multiple sub-blossoms
@ -436,7 +436,7 @@ class _NonTrivialBlossom(_Blossom):
def __init__( def __init__(
self, self,
subblossoms: list[_Blossom], subblossoms: list[Blossom],
edges: list[tuple[int, int]] edges: list[tuple[int, int]]
) -> None: ) -> None:
"""Initialize a new blossom.""" """Initialize a new blossom."""
@ -454,7 +454,7 @@ class _NonTrivialBlossom(_Blossom):
# #
# "subblossoms[0]" is the start and end of the alternating cycle. # "subblossoms[0]" is the start and end of the alternating cycle.
# "subblossoms[0]" contains the base vertex of the blossom. # "subblossoms[0]" contains the base vertex of the blossom.
self.subblossoms: list[_Blossom] = subblossoms self.subblossoms: list[Blossom] = subblossoms
# "edges" is a list of edges linking the sub-blossoms. # "edges" is a list of edges linking the sub-blossoms.
# Each edge is represented as an ordered pair "(x, y)" where "x" # Each edge is represented as an ordered pair "(x, y)" where "x"
@ -491,13 +491,13 @@ class _NonTrivialBlossom(_Blossom):
"""Return a list of vertex indices contained in the blossom.""" """Return a list of vertex indices contained in the blossom."""
# Use an explicit stack to avoid deep recursion. # Use an explicit stack to avoid deep recursion.
stack: list[_NonTrivialBlossom] = [self] stack: list[NonTrivialBlossom] = [self]
nodes: list[int] = [] nodes: list[int] = []
while stack: while stack:
b = stack.pop() b = stack.pop()
for sub in b.subblossoms: for sub in b.subblossoms:
if isinstance(sub, _NonTrivialBlossom): if isinstance(sub, NonTrivialBlossom):
stack.append(sub) stack.append(sub)
else: else:
nodes.append(sub.base_vertex) nodes.append(sub.base_vertex)
@ -505,21 +505,21 @@ class _NonTrivialBlossom(_Blossom):
return nodes return nodes
class _AlternatingPath(NamedTuple): class AlternatingPath(NamedTuple):
"""Represents a list of edges forming an alternating path or an """Represents a list of edges forming an alternating path or an
alternating cycle.""" alternating cycle."""
edges: list[tuple[int, int]] edges: list[tuple[int, int]]
is_cycle: bool is_cycle: bool
class _MatchingContext: class MatchingContext:
"""Holds all data used by the matching algorithm. """Holds all data used by the matching algorithm.
It contains a partial solution of the matching problem and several It contains a partial solution of the matching problem and several
auxiliary data structures. auxiliary data structures.
""" """
def __init__(self, graph: _GraphInfo) -> None: def __init__(self, graph: GraphInfo) -> None:
"""Set up the initial state of the matching algorithm.""" """Set up the initial state of the matching algorithm."""
num_vertex = graph.num_vertex num_vertex = graph.num_vertex
@ -545,14 +545,14 @@ class _MatchingContext:
# #
# "trivial_blossom[x]" is the trivial blossom that contains only # "trivial_blossom[x]" is the trivial blossom that contains only
# vertex "x". # vertex "x".
self.trivial_blossom: list[_Blossom] = [_Blossom(x) self.trivial_blossom: list[Blossom] = [Blossom(x)
for x in range(num_vertex)] for x in range(num_vertex)]
# Non-trivial blossoms may be created and destroyed during # Non-trivial blossoms may be created and destroyed during
# the course of the algorithm. # the course of the algorithm.
# #
# Initially there are no non-trivial blossoms. # Initially there are no non-trivial blossoms.
self.nontrivial_blossom: set[_NonTrivialBlossom] = set() self.nontrivial_blossom: set[NonTrivialBlossom] = set()
# "vertex_set_node[x]" represents the vertex "x" inside the # "vertex_set_node[x]" represents the vertex "x" inside the
# union-find datastructure of its top-level blossom. # union-find datastructure of its top-level blossom.
@ -593,7 +593,7 @@ class _MatchingContext:
# Queue containing unlabeled top-level blossoms that have an edge to # Queue containing unlabeled top-level blossoms that have an edge to
# an S-blossom. The priority of a blossom is 2 times its least slack # an S-blossom. The priority of a blossom is 2 times its least slack
# to an S blossom, plus 2 times the running sum of delta steps. # to an S blossom, plus 2 times the running sum of delta steps.
self.delta2_queue: PriorityQueue[_Blossom] = PriorityQueue() self.delta2_queue: PriorityQueue[Blossom] = PriorityQueue()
# Queue containing edges between S-vertices in different top-level # Queue containing edges between S-vertices in different top-level
# blossoms. The priority of an edge is its slack plus 2 times the # blossoms. The priority of an edge is its slack plus 2 times the
@ -605,7 +605,7 @@ class _MatchingContext:
# Queue containing top-level non-trivial T-blossoms. # Queue containing top-level non-trivial T-blossoms.
# The priority of a blossom is its dual plus 2 times the running # The priority of a blossom is its dual plus 2 times the running
# sum of delta steps. # sum of delta steps.
self.delta4_queue: PriorityQueue[_NonTrivialBlossom] = PriorityQueue() self.delta4_queue: PriorityQueue[NonTrivialBlossom] = PriorityQueue()
# For each T-vertex or unlabeled vertex "x", # For each T-vertex or unlabeled vertex "x",
# "vertex_sedge_queue[x]" is a queue of edges between "x" and any # "vertex_sedge_queue[x]" is a queue of edges between "x" and any
@ -648,7 +648,7 @@ class _MatchingContext:
(x, y, w) = self.graph.edges[e] (x, y, w) = self.graph.edges[e]
return self.vertex_dual_2x[x] + self.vertex_dual_2x[y] - 2 * w return self.vertex_dual_2x[x] + self.vertex_dual_2x[y] - 2 * w
def delta2_add_edge(self, e: int, y: int, by: _Blossom) -> None: def delta2_add_edge(self, e: int, y: int, by: Blossom) -> None:
"""Add edge "e" for delta2 tracking. """Add edge "e" for delta2 tracking.
Edge "e" connects an S-vertex to a T-vertex or unlabeled vertex "y". Edge "e" connects an S-vertex to a T-vertex or unlabeled vertex "y".
@ -674,14 +674,14 @@ class _MatchingContext:
# If the blossom is unlabeled and the new edge becomes its least-slack # If the blossom is unlabeled and the new edge becomes its least-slack
# S-edge, insert or update the blossom in the global delta2 queue. # S-edge, insert or update the blossom in the global delta2 queue.
if by.label == _LABEL_NONE: if by.label == LABEL_NONE:
prio += by.vertex_dual_offset prio += by.vertex_dual_offset
if by.delta2_node is None: if by.delta2_node is None:
by.delta2_node = self.delta2_queue.insert(prio, by) by.delta2_node = self.delta2_queue.insert(prio, by)
elif prio < by.delta2_node.prio: elif prio < by.delta2_node.prio:
self.delta2_queue.decrease_prio(by.delta2_node, prio) self.delta2_queue.decrease_prio(by.delta2_node, prio)
def delta2_remove_edge(self, e: int, y: int, by: _Blossom) -> None: def delta2_remove_edge(self, e: int, y: int, by: Blossom) -> None:
"""Remove edge "e" from delta2 tracking. """Remove edge "e" from delta2 tracking.
This function is called if an S-vertex becomes unlabeled, This function is called if an S-vertex becomes unlabeled,
@ -705,7 +705,7 @@ class _MatchingContext:
# If necessary, update the priority of "y" in its UnionFindQueue. # If necessary, update the priority of "y" in its UnionFindQueue.
if prio > self.vertex_set_node[y].prio: if prio > self.vertex_set_node[y].prio:
self.vertex_set_node[y].set_prio(prio) self.vertex_set_node[y].set_prio(prio)
if by.label == _LABEL_NONE: if by.label == LABEL_NONE:
# Update or delete the blossom in the global delta2 queue. # Update or delete the blossom in the global delta2 queue.
assert by.delta2_node is not None assert by.delta2_node is not None
prio = by.vertex_set.min_prio() prio = by.vertex_set.min_prio()
@ -718,7 +718,7 @@ class _MatchingContext:
self.delta2_queue.delete(by.delta2_node) self.delta2_queue.delete(by.delta2_node)
by.delta2_node = None by.delta2_node = None
def delta2_enable_blossom(self, blossom: _Blossom) -> None: def delta2_enable_blossom(self, blossom: Blossom) -> None:
"""Enable delta2 tracking for "blossom". """Enable delta2 tracking for "blossom".
This function is called when a blossom becomes an unlabeled top-level This function is called when a blossom becomes an unlabeled top-level
@ -733,7 +733,7 @@ class _MatchingContext:
prio += blossom.vertex_dual_offset prio += blossom.vertex_dual_offset
blossom.delta2_node = self.delta2_queue.insert(prio, blossom) blossom.delta2_node = self.delta2_queue.insert(prio, blossom)
def delta2_disable_blossom(self, blossom: _Blossom) -> None: def delta2_disable_blossom(self, blossom: Blossom) -> None:
"""Disable delta2 tracking for "blossom". """Disable delta2 tracking for "blossom".
The blossom will be removed from the global delta2 queue. The blossom will be removed from the global delta2 queue.
@ -780,7 +780,7 @@ class _MatchingContext:
prio = delta2_node.prio prio = delta2_node.prio
slack_2x = prio - self.delta_sum_2x slack_2x = prio - self.delta_sum_2x
assert blossom.parent is None assert blossom.parent is None
assert blossom.label == _LABEL_NONE assert blossom.label == LABEL_NONE
x = blossom.vertex_set.min_elem() x = blossom.vertex_set.min_elem()
e = self.vertex_sedge_queue[x].find_min().data e = self.vertex_sedge_queue[x].find_min().data
@ -840,7 +840,7 @@ class _MatchingContext:
(x, y, _w) = self.graph.edges[e] (x, y, _w) = self.graph.edges[e]
bx = self.vertex_set_node[x].find() bx = self.vertex_set_node[x].find()
by = self.vertex_set_node[y].find() by = self.vertex_set_node[y].find()
assert (bx.label == _LABEL_S) and (by.label == _LABEL_S) assert (bx.label == LABEL_S) and (by.label == LABEL_S)
if bx is not by: if bx is not by:
slack = delta3_node.prio - self.delta_sum_2x slack = delta3_node.prio - self.delta_sum_2x
return (e, slack) return (e, slack)
@ -859,7 +859,7 @@ class _MatchingContext:
# Managing blossom labels: # Managing blossom labels:
# #
def assign_blossom_label_s(self, blossom: _Blossom) -> None: def assign_blossom_label_s(self, blossom: Blossom) -> None:
"""Change an unlabeled top-level blossom into an S-blossom. """Change an unlabeled top-level blossom into an S-blossom.
For a blossom with "j" vertices and "k" incident edges, For a blossom with "j" vertices and "k" incident edges,
@ -870,8 +870,8 @@ class _MatchingContext:
""" """
assert blossom.parent is None assert blossom.parent is None
assert blossom.label == _LABEL_NONE assert blossom.label == LABEL_NONE
blossom.label = _LABEL_S blossom.label = LABEL_S
# Labeled blossoms must not be in the delta2 queue. # Labeled blossoms must not be in the delta2 queue.
self.delta2_disable_blossom(blossom) self.delta2_disable_blossom(blossom)
@ -887,7 +887,7 @@ class _MatchingContext:
# The value of blossom.dual_var must be adjusted accordingly # The value of blossom.dual_var must be adjusted accordingly
# when the blossom changes from unlabeled to S-blossom. # when the blossom changes from unlabeled to S-blossom.
# #
if isinstance(blossom, _NonTrivialBlossom): if isinstance(blossom, NonTrivialBlossom):
blossom.dual_var -= self.delta_sum_2x blossom.dual_var -= self.delta_sum_2x
# Apply pending updates to vertex dual variables and prepare # Apply pending updates to vertex dual variables and prepare
@ -916,20 +916,20 @@ class _MatchingContext:
# Add the new S-vertices to the scan queue. # Add the new S-vertices to the scan queue.
self.scan_queue.extend(vertices) self.scan_queue.extend(vertices)
def assign_blossom_label_t(self, blossom: _Blossom) -> None: def assign_blossom_label_t(self, blossom: Blossom) -> None:
"""Change an unlabeled top-level blossom into a T-blossom. """Change an unlabeled top-level blossom into a T-blossom.
This function takes time O(log(n)). This function takes time O(log(n)).
""" """
assert blossom.parent is None assert blossom.parent is None
assert blossom.label == _LABEL_NONE assert blossom.label == LABEL_NONE
blossom.label = _LABEL_T blossom.label = LABEL_T
# Labeled blossoms must not be in the delta2 queue. # Labeled blossoms must not be in the delta2 queue.
self.delta2_disable_blossom(blossom) self.delta2_disable_blossom(blossom)
if isinstance(blossom, _NonTrivialBlossom): if isinstance(blossom, NonTrivialBlossom):
# Adjust for lazy updating of T-blossom dual variables. # Adjust for lazy updating of T-blossom dual variables.
# #
@ -962,7 +962,7 @@ class _MatchingContext:
# #
blossom.vertex_dual_offset -= self.delta_sum_2x blossom.vertex_dual_offset -= self.delta_sum_2x
def remove_blossom_label_s(self, blossom: _Blossom) -> None: def remove_blossom_label_s(self, blossom: Blossom) -> None:
"""Change a top-level S-blossom into an unlabeled blossom. """Change a top-level S-blossom into an unlabeled blossom.
For a blossom with "j" vertices and "k" incident edges, For a blossom with "j" vertices and "k" incident edges,
@ -973,11 +973,11 @@ class _MatchingContext:
""" """
assert blossom.parent is None assert blossom.parent is None
assert blossom.label == _LABEL_S assert blossom.label == LABEL_S
blossom.label = _LABEL_NONE blossom.label = LABEL_NONE
# Unwind lazy delta updates to the S-blossom dual variable. # Unwind lazy delta updates to the S-blossom dual variable.
if isinstance(blossom, _NonTrivialBlossom): if isinstance(blossom, NonTrivialBlossom):
blossom.dual_var += self.delta_sum_2x blossom.dual_var += self.delta_sum_2x
assert blossom.vertex_dual_offset == 0 assert blossom.vertex_dual_offset == 0
@ -1002,7 +1002,7 @@ class _MatchingContext:
self.delta3_remove_edge(e) self.delta3_remove_edge(e)
by = self.vertex_set_node[y].find() by = self.vertex_set_node[y].find()
if by.label == _LABEL_S: if by.label == LABEL_S:
# Edge "e" connects unlabeled vertex "x" to S-vertex "y". # Edge "e" connects unlabeled vertex "x" to S-vertex "y".
# It must be tracked for delta2 via vertex "x". # It must be tracked for delta2 via vertex "x".
self.delta2_add_edge(e, x, blossom) self.delta2_add_edge(e, x, blossom)
@ -1013,17 +1013,17 @@ class _MatchingContext:
# removed now. # removed now.
self.delta2_remove_edge(e, y, by) self.delta2_remove_edge(e, y, by)
def remove_blossom_label_t(self, blossom: _Blossom) -> None: def remove_blossom_label_t(self, blossom: Blossom) -> None:
"""Change a top-level T-blossom into an unlabeled blossom. """Change a top-level T-blossom into an unlabeled blossom.
This function takes time O(log(n)). This function takes time O(log(n)).
""" """
assert blossom.parent is None assert blossom.parent is None
assert blossom.label == _LABEL_T assert blossom.label == LABEL_T
blossom.label = _LABEL_NONE blossom.label = LABEL_NONE
if isinstance(blossom, _NonTrivialBlossom): if isinstance(blossom, NonTrivialBlossom):
# Unlabeled blossoms are not tracked in the delta4 queue. # Unlabeled blossoms are not tracked in the delta4 queue.
assert blossom.delta4_node is not None assert blossom.delta4_node is not None
@ -1039,31 +1039,31 @@ class _MatchingContext:
# Enable unlabeled top-level blossom for delta2 tracking. # Enable unlabeled top-level blossom for delta2 tracking.
self.delta2_enable_blossom(blossom) self.delta2_enable_blossom(blossom)
def change_s_blossom_to_subblossom(self, blossom: _Blossom) -> None: def change_s_blossom_to_subblossom(self, blossom: Blossom) -> None:
"""Change a top-level S-blossom into an S-subblossom. """Change a top-level S-blossom into an S-subblossom.
This function takes time O(1). This function takes time O(1).
""" """
assert blossom.parent is None assert blossom.parent is None
assert blossom.label == _LABEL_S assert blossom.label == LABEL_S
blossom.label = _LABEL_NONE blossom.label = LABEL_NONE
# Unwind lazy delta updates to the S-blossom dual variable. # Unwind lazy delta updates to the S-blossom dual variable.
if isinstance(blossom, _NonTrivialBlossom): if isinstance(blossom, NonTrivialBlossom):
blossom.dual_var += self.delta_sum_2x blossom.dual_var += self.delta_sum_2x
# #
# General support routines: # General support routines:
# #
def reset_blossom_label(self, blossom: _Blossom) -> None: def reset_blossom_label(self, blossom: Blossom) -> None:
"""Remove blossom label.""" """Remove blossom label."""
assert blossom.parent is None assert blossom.parent is None
assert blossom.label != _LABEL_NONE assert blossom.label != LABEL_NONE
if blossom.label == _LABEL_S: if blossom.label == LABEL_S:
self.remove_blossom_label_s(blossom) self.remove_blossom_label_s(blossom)
else: else:
self.remove_blossom_label_t(blossom) self.remove_blossom_label_t(blossom)
@ -1072,7 +1072,7 @@ class _MatchingContext:
"""TODO -- remove this function, only for debugging""" """TODO -- remove this function, only for debugging"""
for blossom in itertools.chain(self.trivial_blossom, for blossom in itertools.chain(self.trivial_blossom,
self.nontrivial_blossom): self.nontrivial_blossom):
if (blossom.parent is None) and (blossom.label != _LABEL_NONE): if (blossom.parent is None) and (blossom.label != LABEL_NONE):
assert blossom.tree_blossoms is not None assert blossom.tree_blossoms is not None
assert blossom in blossom.tree_blossoms assert blossom in blossom.tree_blossoms
if blossom.tree_edge is not None: if blossom.tree_edge is not None:
@ -1084,7 +1084,7 @@ class _MatchingContext:
assert blossom.tree_edge is None assert blossom.tree_edge is None
assert blossom.tree_blossoms is None assert blossom.tree_blossoms is None
def remove_alternating_tree(self, tree_blossoms: set[_Blossom]) -> None: def remove_alternating_tree(self, tree_blossoms: set[Blossom]) -> None:
"""Reset the alternating tree consisting of the specified blossoms. """Reset the alternating tree consisting of the specified blossoms.
Marks the blossoms as unlabeled. Marks the blossoms as unlabeled.
@ -1093,13 +1093,13 @@ class _MatchingContext:
This function takes time O((n + m) * log(n)). This function takes time O((n + m) * log(n)).
""" """
for blossom in tree_blossoms: for blossom in tree_blossoms:
assert blossom.label != _LABEL_NONE assert blossom.label != LABEL_NONE
assert blossom.tree_blossoms is tree_blossoms assert blossom.tree_blossoms is tree_blossoms
self.reset_blossom_label(blossom) self.reset_blossom_label(blossom)
blossom.tree_edge = None blossom.tree_edge = None
blossom.tree_blossoms = None blossom.tree_blossoms = None
def trace_alternating_paths(self, x: int, y: int) -> _AlternatingPath: def trace_alternating_paths(self, x: int, y: int) -> AlternatingPath:
"""Trace back through the alternating trees from vertices "x" and "y". """Trace back through the alternating trees from vertices "x" and "y".
If both vertices are part of the same alternating tree, this function If both vertices are part of the same alternating tree, this function
@ -1119,7 +1119,7 @@ class _MatchingContext:
blossoms. blossoms.
""" """
marked_blossoms: list[_Blossom] = [] marked_blossoms: list[Blossom] = []
# "xedges" is a list of edges used while tracing from "x". # "xedges" is a list of edges used while tracing from "x".
# "yedges" is a list of edges used while tracing from "y". # "yedges" is a list of edges used while tracing from "y".
@ -1129,7 +1129,7 @@ class _MatchingContext:
# "first_common" is the first common ancestor of "x" and "y" # "first_common" is the first common ancestor of "x" and "y"
# in the alternating tree, or None if there is no common ancestor. # in the alternating tree, or None if there is no common ancestor.
first_common: Optional[_Blossom] = None first_common: Optional[Blossom] = None
# Alternate between tracing the path from "x" and the path from "y". # Alternate between tracing the path from "x" and the path from "y".
# This ensures that the search time is bounded by the size of the # This ensures that the search time is bounded by the size of the
@ -1179,14 +1179,14 @@ class _MatchingContext:
# Any S-to-S alternating path must have odd length. # Any S-to-S alternating path must have odd length.
assert len(path_edges) % 2 == 1 assert len(path_edges) % 2 == 1
return _AlternatingPath(edges=path_edges, return AlternatingPath(edges=path_edges,
is_cycle=(first_common is not None)) is_cycle=(first_common is not None))
# #
# Merge and expand blossoms: # Merge and expand blossoms:
# #
def make_blossom(self, path: _AlternatingPath) -> None: def make_blossom(self, path: AlternatingPath) -> None:
"""Create a new blossom from an alternating cycle. """Create a new blossom from an alternating cycle.
Assign label S to the new blossom. Assign label S to the new blossom.
@ -1214,21 +1214,21 @@ class _MatchingContext:
assert subblossoms[1:] == subblossoms_next[:-1] assert subblossoms[1:] == subblossoms_next[:-1]
# Blossom must start and end with an S-sub-blossom. # Blossom must start and end with an S-sub-blossom.
assert subblossoms[0].label == _LABEL_S assert subblossoms[0].label == LABEL_S
# Remove blossom labels. # Remove blossom labels.
# Mark vertices inside former T-blossoms as S-vertices. # Mark vertices inside former T-blossoms as S-vertices.
for sub in subblossoms: for sub in subblossoms:
if sub.label == _LABEL_T: if sub.label == LABEL_T:
self.remove_blossom_label_t(sub) self.remove_blossom_label_t(sub)
self.assign_blossom_label_s(sub) self.assign_blossom_label_s(sub)
self.change_s_blossom_to_subblossom(sub) self.change_s_blossom_to_subblossom(sub)
# Create the new blossom object. # Create the new blossom object.
blossom = _NonTrivialBlossom(subblossoms, path.edges) blossom = NonTrivialBlossom(subblossoms, path.edges)
# Assign label S to the new blossom. # Assign label S to the new blossom.
blossom.label = _LABEL_S blossom.label = LABEL_S
# Prepare for lazy updating of S-blossom dual variable. # Prepare for lazy updating of S-blossom dual variable.
blossom.dual_var = -self.delta_sum_2x blossom.dual_var = -self.delta_sum_2x
@ -1257,9 +1257,9 @@ class _MatchingContext:
@staticmethod @staticmethod
def find_path_through_blossom( def find_path_through_blossom(
blossom: _NonTrivialBlossom, blossom: NonTrivialBlossom,
sub: _Blossom sub: Blossom
) -> tuple[list[_Blossom], list[tuple[int, int]]]: ) -> tuple[list[Blossom], list[tuple[int, int]]]:
"""Construct a path with an even number of edges through the """Construct a path with an even number of edges through the
specified blossom, from sub-blossom "sub" to the base of "blossom". specified blossom, from sub-blossom "sub" to the base of "blossom".
@ -1285,14 +1285,14 @@ class _MatchingContext:
return (nodes, edges) return (nodes, edges)
def expand_unlabeled_blossom(self, blossom: _NonTrivialBlossom) -> None: def expand_unlabeled_blossom(self, blossom: NonTrivialBlossom) -> None:
"""Expand the specified unlabeled blossom. """Expand the specified unlabeled blossom.
This function takes total time O(n * log(n)) per stage. This function takes total time O(n * log(n)) per stage.
""" """
assert blossom.parent is None assert blossom.parent is None
assert blossom.label == _LABEL_NONE assert blossom.label == LABEL_NONE
# Remove blossom from the delta2 queue. # Remove blossom from the delta2 queue.
self.delta2_disable_blossom(blossom) self.delta2_disable_blossom(blossom)
@ -1306,7 +1306,7 @@ class _MatchingContext:
# Convert sub-blossoms into top-level blossoms. # Convert sub-blossoms into top-level blossoms.
for sub in blossom.subblossoms: for sub in blossom.subblossoms:
assert sub.label == _LABEL_NONE assert sub.label == LABEL_NONE
sub.parent = None sub.parent = None
assert sub.vertex_dual_offset == 0 assert sub.vertex_dual_offset == 0
@ -1320,14 +1320,14 @@ class _MatchingContext:
# Delete the expanded blossom. # Delete the expanded blossom.
self.nontrivial_blossom.remove(blossom) self.nontrivial_blossom.remove(blossom)
def expand_t_blossom(self, blossom: _NonTrivialBlossom) -> None: def expand_t_blossom(self, blossom: NonTrivialBlossom) -> None:
"""Expand the specified T-blossom. """Expand the specified T-blossom.
This function takes total time O(n * log(n) + m) per stage. This function takes total time O(n * log(n) + m) per stage.
""" """
assert blossom.parent is None assert blossom.parent is None
assert blossom.label == _LABEL_T assert blossom.label == LABEL_T
assert blossom.delta2_node is None assert blossom.delta2_node is None
# Remove blossom from its alternating tree. # Remove blossom from its alternating tree.
@ -1391,9 +1391,9 @@ class _MatchingContext:
def augment_blossom_rec( def augment_blossom_rec(
self, self,
blossom: _NonTrivialBlossom, blossom: NonTrivialBlossom,
sub: _Blossom, sub: Blossom,
stack: list[tuple[_NonTrivialBlossom, _Blossom]] stack: list[tuple[NonTrivialBlossom, Blossom]]
) -> None: ) -> None:
"""Augment along an alternating path through the specified blossom, """Augment along an alternating path through the specified blossom,
from sub-blossom "sub" to the base vertex of the blossom. from sub-blossom "sub" to the base vertex of the blossom.
@ -1430,11 +1430,11 @@ class _MatchingContext:
# Augment through the subblossoms touching the edge (x, y). # Augment through the subblossoms touching the edge (x, y).
# Nothing needs to be done for trivial subblossoms. # Nothing needs to be done for trivial subblossoms.
bx = path_nodes[p+1] bx = path_nodes[p+1]
if isinstance(bx, _NonTrivialBlossom): if isinstance(bx, NonTrivialBlossom):
stack.append((bx, self.trivial_blossom[x])) stack.append((bx, self.trivial_blossom[x]))
by = path_nodes[p+2] by = path_nodes[p+2]
if isinstance(by, _NonTrivialBlossom): if isinstance(by, NonTrivialBlossom):
stack.append((by, self.trivial_blossom[y])) stack.append((by, self.trivial_blossom[y]))
# Rotate the subblossom list so the new base ends up in position 0. # Rotate the subblossom list so the new base ends up in position 0.
@ -1450,8 +1450,8 @@ class _MatchingContext:
def augment_blossom( def augment_blossom(
self, self,
blossom: _NonTrivialBlossom, blossom: NonTrivialBlossom,
sub: _Blossom sub: Blossom
) -> None: ) -> None:
"""Augment along an alternating path through the specified blossom, """Augment along an alternating path through the specified blossom,
from sub-blossom "sub" to the base vertex of the blossom. from sub-blossom "sub" to the base vertex of the blossom.
@ -1486,7 +1486,7 @@ class _MatchingContext:
# Augment "blossom" from "sub" to the base vertex. # Augment "blossom" from "sub" to the base vertex.
self.augment_blossom_rec(blossom, sub, stack) self.augment_blossom_rec(blossom, sub, stack)
def augment_matching(self, path: _AlternatingPath) -> None: def augment_matching(self, path: AlternatingPath) -> None:
"""Augment the matching through the specified augmenting path. """Augment the matching through the specified augmenting path.
This function takes time O(n). This function takes time O(n).
@ -1515,11 +1515,11 @@ class _MatchingContext:
# Augment the non-trivial blossoms on either side of this edge. # Augment the non-trivial blossoms on either side of this edge.
# No action is necessary for trivial blossoms. # No action is necessary for trivial blossoms.
bx = self.vertex_set_node[x].find() bx = self.vertex_set_node[x].find()
if isinstance(bx, _NonTrivialBlossom): if isinstance(bx, NonTrivialBlossom):
self.augment_blossom(bx, self.trivial_blossom[x]) self.augment_blossom(bx, self.trivial_blossom[x])
by = self.vertex_set_node[y].find() by = self.vertex_set_node[y].find()
if isinstance(by, _NonTrivialBlossom): if isinstance(by, NonTrivialBlossom):
self.augment_blossom(by, self.trivial_blossom[y]) self.augment_blossom(by, self.trivial_blossom[y])
# Pull the edge into the matching. # Pull the edge into the matching.
@ -1551,7 +1551,7 @@ class _MatchingContext:
assert y != -1 assert y != -1
by = self.vertex_set_node[y].find() by = self.vertex_set_node[y].find()
assert by.label == _LABEL_T assert by.label == LABEL_T
assert by.tree_blossoms is not None assert by.tree_blossoms is not None
# Attach the blossom that contains "x" to the alternating tree. # Attach the blossom that contains "x" to the alternating tree.
@ -1575,10 +1575,10 @@ class _MatchingContext:
bx = self.vertex_set_node[x].find() bx = self.vertex_set_node[x].find()
by = self.vertex_set_node[y].find() by = self.vertex_set_node[y].find()
assert bx.label == _LABEL_S assert bx.label == LABEL_S
# Expand zero-dual blossoms before assigning label T. # Expand zero-dual blossoms before assigning label T.
while isinstance(by, _NonTrivialBlossom) and (by.dual_var == 0): while isinstance(by, NonTrivialBlossom) and (by.dual_var == 0):
self.expand_unlabeled_blossom(by) self.expand_unlabeled_blossom(by)
by = self.vertex_set_node[y].find() by = self.vertex_set_node[y].find()
@ -1617,8 +1617,8 @@ class _MatchingContext:
bx = self.vertex_set_node[x].find() bx = self.vertex_set_node[x].find()
by = self.vertex_set_node[y].find() by = self.vertex_set_node[y].find()
assert bx.label == _LABEL_S assert bx.label == LABEL_S
assert by.label == _LABEL_S assert by.label == LABEL_S
assert bx is not by assert bx is not by
# Trace back through the alternating trees from "x" and "y". # Trace back through the alternating trees from "x" and "y".
@ -1680,7 +1680,7 @@ class _MatchingContext:
# Double-check that "x" is an S-vertex. # Double-check that "x" is an S-vertex.
bx = self.vertex_set_node[x].find() bx = self.vertex_set_node[x].find()
assert bx.label == _LABEL_S assert bx.label == LABEL_S
# Scan the edges that are incident on "x". # Scan the edges that are incident on "x".
# This loop runs through O(m) iterations per stage. # This loop runs through O(m) iterations per stage.
@ -1703,7 +1703,7 @@ class _MatchingContext:
if bx is by: if bx is by:
continue continue
if by.label == _LABEL_S: if by.label == LABEL_S:
self.delta3_add_edge(e) self.delta3_add_edge(e)
else: else:
self.delta2_add_edge(e, y, by) self.delta2_add_edge(e, y, by)
@ -1716,7 +1716,7 @@ class _MatchingContext:
def calc_dual_delta_step( def calc_dual_delta_step(
self self
) -> tuple[int, float, int, Optional[_NonTrivialBlossom]]: ) -> tuple[int, float, int, Optional[NonTrivialBlossom]]:
"""Calculate a delta step in the dual LPP problem. """Calculate a delta step in the dual LPP problem.
This function returns the minimum of the 4 types of delta values, This function returns the minimum of the 4 types of delta values,
@ -1740,7 +1740,7 @@ class _MatchingContext:
Tuple (delta_type, delta_2x, delta_edge, delta_blossom). Tuple (delta_type, delta_2x, delta_edge, delta_blossom).
""" """
delta_edge = -1 delta_edge = -1
delta_blossom: Optional[_NonTrivialBlossom] = None delta_blossom: Optional[NonTrivialBlossom] = None
# Compute delta1: minimum dual variable of any S-vertex. # Compute delta1: minimum dual variable of any S-vertex.
# All unmatched vertices have the same dual value, and this is # All unmatched vertices have the same dual value, and this is
@ -1770,7 +1770,7 @@ class _MatchingContext:
# This takes time O(log(n)). # This takes time O(log(n)).
if not self.delta4_queue.empty(): if not self.delta4_queue.empty():
blossom = self.delta4_queue.find_min().data blossom = self.delta4_queue.find_min().data
assert blossom.label == _LABEL_T assert blossom.label == LABEL_T
assert blossom.parent is None assert blossom.parent is None
blossom_dual = blossom.dual_var - self.delta_sum_2x blossom_dual = blossom.dual_var - self.delta_sum_2x
if blossom_dual <= delta_2x: if blossom_dual <= delta_2x:
@ -1847,7 +1847,7 @@ class _MatchingContext:
# Use the edge from S-vertex to unlabeled vertex that got # Use the edge from S-vertex to unlabeled vertex that got
# unlocked through the delta update. # unlocked through the delta update.
(x, y, _w) = self.graph.edges[delta_edge] (x, y, _w) = self.graph.edges[delta_edge]
if self.vertex_set_node[x].find().label != _LABEL_S: if self.vertex_set_node[x].find().label != LABEL_S:
(x, y) = (y, x) (x, y) = (y, x)
self.extend_tree_s_to_t(x, y) self.extend_tree_s_to_t(x, y)
@ -1886,9 +1886,9 @@ class _MatchingContext:
self.nontrivial_blossom): self.nontrivial_blossom):
# Remove blossom label. # Remove blossom label.
if (blossom.parent is None) and (blossom.label != _LABEL_NONE): if (blossom.parent is None) and (blossom.label != LABEL_NONE):
self.reset_blossom_label(blossom) self.reset_blossom_label(blossom)
assert blossom.label == _LABEL_NONE assert blossom.label == LABEL_NONE
# Remove blossom from alternating tree. # Remove blossom from alternating tree.
blossom.tree_edge = None blossom.tree_edge = None
@ -1906,8 +1906,8 @@ class _MatchingContext:
def _verify_blossom_edges( def _verify_blossom_edges(
ctx: _MatchingContext, ctx: MatchingContext,
blossom: _NonTrivialBlossom, blossom: NonTrivialBlossom,
edge_slack_2x: list[float] edge_slack_2x: list[float]
) -> None: ) -> None:
"""Descend down the blossom tree to find edges that are contained """Descend down the blossom tree to find edges that are contained
@ -1943,7 +1943,7 @@ def _verify_blossom_edges(
path_num_matched: list[int] = [0] path_num_matched: list[int] = [0]
# Use an explicit stack to avoid deep recursion. # Use an explicit stack to avoid deep recursion.
stack: list[tuple[_NonTrivialBlossom, int]] = [(blossom, -1)] stack: list[tuple[NonTrivialBlossom, int]] = [(blossom, -1)]
while stack: while stack:
(blossom, p) = stack[-1] (blossom, p) = stack[-1]
@ -1969,7 +1969,7 @@ def _verify_blossom_edges(
# Examine the next sub-blossom at the current level. # Examine the next sub-blossom at the current level.
sub = blossom.subblossoms[p] sub = blossom.subblossoms[p]
if isinstance(sub, _NonTrivialBlossom): if isinstance(sub, NonTrivialBlossom):
# Prepare to descent into the selected sub-blossom and # Prepare to descent into the selected sub-blossom and
# scan it recursively. # scan it recursively.
stack.append((sub, -1)) stack.append((sub, -1))
@ -2031,7 +2031,7 @@ def _verify_blossom_edges(
stack.pop() stack.pop()
def _verify_optimum(ctx: _MatchingContext) -> None: def verify_optimum(ctx: MatchingContext) -> None:
"""Verify that the optimum solution has been found. """Verify that the optimum solution has been found.
This function takes time O(n**2). This function takes time O(n**2).

View File

0
python/tests/__init__.py Normal file
View File

View File

@ -4,10 +4,12 @@ import math
import unittest import unittest
from unittest.mock import Mock from unittest.mock import Mock
import mwmatching
from mwmatching import ( from mwmatching import (
maximum_weight_matching as mwm, maximum_weight_matching as mwm,
adjust_weights_for_maximum_cardinality_matching as adj) adjust_weights_for_maximum_cardinality_matching as adj)
from mwmatching.algorithm import (
MatchingError, GraphInfo, Blossom, NonTrivialBlossom, MatchingContext,
verify_optimum)
class TestMaximumWeightMatching(unittest.TestCase): class TestMaximumWeightMatching(unittest.TestCase):
@ -431,10 +433,10 @@ class TestMaximumCardinalityMatching(unittest.TestCase):
class TestGraphInfo(unittest.TestCase): class TestGraphInfo(unittest.TestCase):
"""Test _GraphInfo helper class.""" """Test GraphInfo helper class."""
def test_empty(self): def test_empty(self):
graph = mwmatching._GraphInfo([]) graph = GraphInfo([])
self.assertEqual(graph.num_vertex, 0) self.assertEqual(graph.num_vertex, 0)
self.assertEqual(graph.edges, []) self.assertEqual(graph.edges, [])
self.assertEqual(graph.adjacent_edges, []) self.assertEqual(graph.adjacent_edges, [])
@ -449,8 +451,8 @@ class TestVerificationFail(unittest.TestCase):
vertex_mate, vertex_mate,
vertex_dual_2x, vertex_dual_2x,
nontrivial_blossom): nontrivial_blossom):
ctx = Mock(spec=mwmatching._MatchingContext) ctx = Mock(spec=MatchingContext)
ctx.graph = mwmatching._GraphInfo(edges) ctx.graph = GraphInfo(edges)
ctx.vertex_mate = vertex_mate ctx.vertex_mate = vertex_mate
ctx.vertex_dual_2x = vertex_dual_2x ctx.vertex_dual_2x = vertex_dual_2x
ctx.nontrivial_blossom = nontrivial_blossom ctx.nontrivial_blossom = nontrivial_blossom
@ -463,7 +465,7 @@ class TestVerificationFail(unittest.TestCase):
vertex_mate=[-1, 2, 1], vertex_mate=[-1, 2, 1],
vertex_dual_2x=[0, 20, 2], vertex_dual_2x=[0, 20, 2],
nontrivial_blossom=[]) nontrivial_blossom=[])
mwmatching._verify_optimum(ctx) verify_optimum(ctx)
def test_asymmetric_matching(self): def test_asymmetric_matching(self):
edges = [(0,1,10), (1,2,11)] edges = [(0,1,10), (1,2,11)]
@ -472,8 +474,8 @@ class TestVerificationFail(unittest.TestCase):
vertex_mate=[-1, 2, 0], vertex_mate=[-1, 2, 0],
vertex_dual_2x=[0, 20, 2], vertex_dual_2x=[0, 20, 2],
nontrivial_blossom=[]) nontrivial_blossom=[])
with self.assertRaises(mwmatching.MatchingError): with self.assertRaises(MatchingError):
mwmatching._verify_optimum(ctx) verify_optimum(ctx)
def test_nonexistent_matched_edge(self): def test_nonexistent_matched_edge(self):
edges = [(0,1,10), (1,2,11)] edges = [(0,1,10), (1,2,11)]
@ -482,8 +484,8 @@ class TestVerificationFail(unittest.TestCase):
vertex_mate=[2, -1, 0], vertex_mate=[2, -1, 0],
vertex_dual_2x=[11, 11, 11], vertex_dual_2x=[11, 11, 11],
nontrivial_blossom=[]) nontrivial_blossom=[])
with self.assertRaises(mwmatching.MatchingError): with self.assertRaises(MatchingError):
mwmatching._verify_optimum(ctx) verify_optimum(ctx)
def test_negative_vertex_dual(self): def test_negative_vertex_dual(self):
edges = [(0,1,10), (1,2,11)] edges = [(0,1,10), (1,2,11)]
@ -492,8 +494,8 @@ class TestVerificationFail(unittest.TestCase):
vertex_mate=[-1, 2, 1], vertex_mate=[-1, 2, 1],
vertex_dual_2x=[-2, 22, 0], vertex_dual_2x=[-2, 22, 0],
nontrivial_blossom=[]) nontrivial_blossom=[])
with self.assertRaises(mwmatching.MatchingError): with self.assertRaises(MatchingError):
mwmatching._verify_optimum(ctx) verify_optimum(ctx)
def test_unmatched_nonzero_dual(self): def test_unmatched_nonzero_dual(self):
edges = [(0,1,10), (1,2,11)] edges = [(0,1,10), (1,2,11)]
@ -502,8 +504,8 @@ class TestVerificationFail(unittest.TestCase):
vertex_mate=[-1, 2, 1], vertex_mate=[-1, 2, 1],
vertex_dual_2x=[9, 11, 11], vertex_dual_2x=[9, 11, 11],
nontrivial_blossom=[]) nontrivial_blossom=[])
with self.assertRaises(mwmatching.MatchingError): with self.assertRaises(MatchingError):
mwmatching._verify_optimum(ctx) verify_optimum(ctx)
def test_negative_edge_slack(self): def test_negative_edge_slack(self):
edges = [(0,1,10), (1,2,11)] edges = [(0,1,10), (1,2,11)]
@ -512,8 +514,8 @@ class TestVerificationFail(unittest.TestCase):
vertex_mate=[-1, 2, 1], vertex_mate=[-1, 2, 1],
vertex_dual_2x=[0, 11, 11], vertex_dual_2x=[0, 11, 11],
nontrivial_blossom=[]) nontrivial_blossom=[])
with self.assertRaises(mwmatching.MatchingError): with self.assertRaises(MatchingError):
mwmatching._verify_optimum(ctx) verify_optimum(ctx)
def test_matched_edge_slack(self): def test_matched_edge_slack(self):
edges = [(0,1,10), (1,2,11)] edges = [(0,1,10), (1,2,11)]
@ -522,8 +524,8 @@ class TestVerificationFail(unittest.TestCase):
vertex_mate=[-1, 2, 1], vertex_mate=[-1, 2, 1],
vertex_dual_2x=[0, 20, 11], vertex_dual_2x=[0, 20, 11],
nontrivial_blossom=[]) nontrivial_blossom=[])
with self.assertRaises(mwmatching.MatchingError): with self.assertRaises(MatchingError):
mwmatching._verify_optimum(ctx) verify_optimum(ctx)
def test_negative_blossom_dual(self): def test_negative_blossom_dual(self):
# #
@ -532,11 +534,8 @@ class TestVerificationFail(unittest.TestCase):
# \----8-----/ # \----8-----/
# #
edges = [(0,1,7), (0,2,8), (1,2,9), (2,3,6)] edges = [(0,1,7), (0,2,8), (1,2,9), (2,3,6)]
blossom = mwmatching._NonTrivialBlossom( blossom = NonTrivialBlossom(
subblossoms=[ subblossoms=[Blossom(0), Blossom(1), Blossom(2)],
mwmatching._Blossom(0),
mwmatching._Blossom(1),
mwmatching._Blossom(2)],
edges=[0,2,1]) edges=[0,2,1])
for sub in blossom.subblossoms: for sub in blossom.subblossoms:
sub.parent = blossom sub.parent = blossom
@ -546,8 +545,8 @@ class TestVerificationFail(unittest.TestCase):
vertex_mate=[1, 0, 3, 2], vertex_mate=[1, 0, 3, 2],
vertex_dual_2x=[4, 6, 8, 4], vertex_dual_2x=[4, 6, 8, 4],
nontrivial_blossom=[blossom]) nontrivial_blossom=[blossom])
with self.assertRaises(mwmatching.MatchingError): with self.assertRaises(MatchingError):
mwmatching._verify_optimum(ctx) verify_optimum(ctx)
def test_blossom_not_full(self): def test_blossom_not_full(self):
# #
@ -560,11 +559,8 @@ class TestVerificationFail(unittest.TestCase):
# \----2-----/ # \----2-----/
# #
edges = [(0,1,7), (0,2,2), (1,2,5), (0,3,8), (1,4,8)] edges = [(0,1,7), (0,2,2), (1,2,5), (0,3,8), (1,4,8)]
blossom = mwmatching._NonTrivialBlossom( blossom = NonTrivialBlossom(
subblossoms=[ subblossoms=[Blossom(0), Blossom(1), Blossom(2)],
mwmatching._Blossom(0),
mwmatching._Blossom(1),
mwmatching._Blossom(2)],
edges=[0,2,1]) edges=[0,2,1])
for sub in blossom.subblossoms: for sub in blossom.subblossoms:
sub.parent = blossom sub.parent = blossom
@ -574,8 +570,8 @@ class TestVerificationFail(unittest.TestCase):
vertex_mate=[3, 4, -1, 0, 1], vertex_mate=[3, 4, -1, 0, 1],
vertex_dual_2x=[4, 10, 0, 12, 6], vertex_dual_2x=[4, 10, 0, 12, 6],
nontrivial_blossom=[blossom]) nontrivial_blossom=[blossom])
with self.assertRaises(mwmatching.MatchingError): with self.assertRaises(MatchingError):
mwmatching._verify_optimum(ctx) verify_optimum(ctx)
if __name__ == "__main__": if __name__ == "__main__":

View File

@ -3,7 +3,7 @@
import random import random
import unittest import unittest
from datastruct import UnionFindQueue, PriorityQueue from mwmatching.datastruct import UnionFindQueue, PriorityQueue
class TestUnionFindQueue(unittest.TestCase): class TestUnionFindQueue(unittest.TestCase):

View File

@ -4,7 +4,7 @@ set -e
echo echo
echo "Running pycodestyle" echo "Running pycodestyle"
pycodestyle python/mwmatching.py python/datastruct.py tests pycodestyle python/mwmatching python/run_matching.py tests
echo echo
echo "Running mypy" echo "Running mypy"
@ -12,20 +12,15 @@ mypy --disallow-incomplete-defs python tests
echo echo
echo "Running pylint" echo "Running pylint"
pylint --ignore=test_mwmatching.py python tests || [ $(($? & 3)) -eq 0 ] pylint --ignore=test_algorithm.py python tests/*.py tests/generate/*.py || [ $(($? & 3)) -eq 0 ]
echo echo
echo "Running test_mwmatching.py" echo "Running unit tests"
python3 python/test_mwmatching.py python3 -m unittest discover -t python -s python/tests
echo
echo "Running test_datastruct.py"
python3 python/test_datastruct.py
echo echo
echo "Checking test coverage" echo "Checking test coverage"
coverage erase coverage erase
coverage run --branch python/test_datastruct.py coverage run --branch -m unittest discover -t python -s python/tests
coverage run -a --branch python/test_mwmatching.py
coverage report -m coverage report -m

View File

@ -19,13 +19,12 @@ count_delta_step = [0]
def patch_matching_code() -> None: def patch_matching_code() -> None:
"""Patch the matching code to count events.""" """Patch the matching code to count events."""
# pylint: disable=import-outside-toplevel,protected-access # pylint: disable=import-outside-toplevel
import mwmatching from mwmatching.algorithm import MatchingContext
orig_make_blossom = mwmatching._MatchingContext.make_blossom orig_make_blossom = MatchingContext.make_blossom
orig_calc_dual_delta_step = ( orig_calc_dual_delta_step = MatchingContext.calc_dual_delta_step
mwmatching._MatchingContext.calc_dual_delta_step)
def stub_make_blossom(*args: Any, **kwargs: Any) -> Any: def stub_make_blossom(*args: Any, **kwargs: Any) -> Any:
count_make_blossom[0] += 1 count_make_blossom[0] += 1
@ -36,9 +35,8 @@ def patch_matching_code() -> None:
ret = orig_calc_dual_delta_step(*args, **kwargs) ret = orig_calc_dual_delta_step(*args, **kwargs)
return ret return ret
mwmatching._MatchingContext.make_blossom = ( # type: ignore MatchingContext.make_blossom = stub_make_blossom # type: ignore
stub_make_blossom) MatchingContext.calc_dual_delta_step = ( # type: ignore
mwmatching._MatchingContext.calc_dual_delta_step = ( # type: ignore
stub_calc_dual_delta_step) stub_calc_dual_delta_step)