1
0
Fork 0

Restructure Python code as package

This commit is contained in:
Joris van Rantwijk 2024-07-07 10:30:21 +02:00
parent f2e8ca1357
commit 147640329f
9 changed files with 158 additions and 158 deletions

View File

@ -0,0 +1,11 @@
"""
Algorithm for finding a maximum weight matching in general graphs.
"""
from .algorithm import (maximum_weight_matching,
adjust_weights_for_maximum_cardinality_matching,
MatchingError)
__all__ = ["maximum_weight_matching",
"adjust_weights_for_maximum_cardinality_matching",
"MatchingError"]

View File

@ -10,7 +10,7 @@ import math
from collections.abc import Sequence
from typing import NamedTuple, Optional
from datastruct import UnionFindQueue, PriorityQueue
from .datastruct import UnionFindQueue, PriorityQueue
def maximum_weight_matching(
@ -66,10 +66,10 @@ def maximum_weight_matching(
return []
# Initialize graph representation.
graph = _GraphInfo(edges)
graph = GraphInfo(edges)
# Initialize the matching algorithm.
ctx = _MatchingContext(graph)
ctx = MatchingContext(graph)
ctx.start()
# Improve the solution until no further improvement is possible.
@ -92,7 +92,7 @@ def maximum_weight_matching(
# there is a bug in the matching algorithm.
# Verification only works reliably for integer weights.
if graph.integer_weights:
_verify_optimum(ctx)
verify_optimum(ctx)
return pairs
@ -277,7 +277,7 @@ def _remove_negative_weight_edges(
return edges
class _GraphInfo:
class GraphInfo:
"""Representation of the input graph.
These data remain unchanged while the algorithm runs.
@ -328,12 +328,12 @@ class _GraphInfo:
# Each vertex may be labeled "S" (outer) or "T" (inner) or be unlabeled.
_LABEL_NONE = 0
_LABEL_S = 1
_LABEL_T = 2
LABEL_NONE = 0
LABEL_S = 1
LABEL_T = 2
class _Blossom:
class Blossom:
"""Represents a blossom in a partially matched graph.
A blossom is an odd-length alternating cycle over sub-blossoms.
@ -361,7 +361,7 @@ class _Blossom:
#
# If this is a top-level blossom,
# "parent = None".
self.parent: Optional[_NonTrivialBlossom] = None
self.parent: Optional[NonTrivialBlossom] = None
# "base_vertex" is the vertex index of the base of the blossom.
# This is the unique vertex which is contained in the blossom
@ -374,7 +374,7 @@ class _Blossom:
# A top-level blossom that is part of an alternating tree,
# has label S or T. An unlabeled top-level blossom is not part
# of any alternating tree.
self.label: int = _LABEL_NONE
self.label: int = LABEL_NONE
# A labeled top-level blossoms keeps track of the edge through which
# it is attached to the alternating tree.
@ -389,11 +389,11 @@ class _Blossom:
# "tree_blossoms" is the set of all top-level blossoms that belong
# to the same alternating tree. The same set instance is shared by
# all top-level blossoms in the tree.
self.tree_blossoms: "Optional[set[_Blossom]]" = None
self.tree_blossoms: "Optional[set[Blossom]]" = None
# Each top-level blossom maintains a union-find datastructure
# containing all vertices in the blossom.
self.vertex_set: "UnionFindQueue[_Blossom, int]"
self.vertex_set: "UnionFindQueue[Blossom, int]"
self.vertex_set = UnionFindQueue(self)
# If this is a top-level unlabeled blossom with an edge to an
@ -415,7 +415,7 @@ class _Blossom:
return [self.base_vertex]
class _NonTrivialBlossom(_Blossom):
class NonTrivialBlossom(Blossom):
"""Represents a non-trivial blossom in a partially matched graph.
A non-trivial blossom is a blossom that contains multiple sub-blossoms
@ -436,7 +436,7 @@ class _NonTrivialBlossom(_Blossom):
def __init__(
self,
subblossoms: list[_Blossom],
subblossoms: list[Blossom],
edges: list[tuple[int, int]]
) -> None:
"""Initialize a new blossom."""
@ -454,7 +454,7 @@ class _NonTrivialBlossom(_Blossom):
#
# "subblossoms[0]" is the start and end of the alternating cycle.
# "subblossoms[0]" contains the base vertex of the blossom.
self.subblossoms: list[_Blossom] = subblossoms
self.subblossoms: list[Blossom] = subblossoms
# "edges" is a list of edges linking the sub-blossoms.
# Each edge is represented as an ordered pair "(x, y)" where "x"
@ -491,13 +491,13 @@ class _NonTrivialBlossom(_Blossom):
"""Return a list of vertex indices contained in the blossom."""
# Use an explicit stack to avoid deep recursion.
stack: list[_NonTrivialBlossom] = [self]
stack: list[NonTrivialBlossom] = [self]
nodes: list[int] = []
while stack:
b = stack.pop()
for sub in b.subblossoms:
if isinstance(sub, _NonTrivialBlossom):
if isinstance(sub, NonTrivialBlossom):
stack.append(sub)
else:
nodes.append(sub.base_vertex)
@ -505,21 +505,21 @@ class _NonTrivialBlossom(_Blossom):
return nodes
class _AlternatingPath(NamedTuple):
class AlternatingPath(NamedTuple):
"""Represents a list of edges forming an alternating path or an
alternating cycle."""
edges: list[tuple[int, int]]
is_cycle: bool
class _MatchingContext:
class MatchingContext:
"""Holds all data used by the matching algorithm.
It contains a partial solution of the matching problem and several
auxiliary data structures.
"""
def __init__(self, graph: _GraphInfo) -> None:
def __init__(self, graph: GraphInfo) -> None:
"""Set up the initial state of the matching algorithm."""
num_vertex = graph.num_vertex
@ -545,14 +545,14 @@ class _MatchingContext:
#
# "trivial_blossom[x]" is the trivial blossom that contains only
# vertex "x".
self.trivial_blossom: list[_Blossom] = [_Blossom(x)
for x in range(num_vertex)]
self.trivial_blossom: list[Blossom] = [Blossom(x)
for x in range(num_vertex)]
# Non-trivial blossoms may be created and destroyed during
# the course of the algorithm.
#
# Initially there are no non-trivial blossoms.
self.nontrivial_blossom: set[_NonTrivialBlossom] = set()
self.nontrivial_blossom: set[NonTrivialBlossom] = set()
# "vertex_set_node[x]" represents the vertex "x" inside the
# union-find datastructure of its top-level blossom.
@ -593,7 +593,7 @@ class _MatchingContext:
# Queue containing unlabeled top-level blossoms that have an edge to
# an S-blossom. The priority of a blossom is 2 times its least slack
# to an S blossom, plus 2 times the running sum of delta steps.
self.delta2_queue: PriorityQueue[_Blossom] = PriorityQueue()
self.delta2_queue: PriorityQueue[Blossom] = PriorityQueue()
# Queue containing edges between S-vertices in different top-level
# blossoms. The priority of an edge is its slack plus 2 times the
@ -605,7 +605,7 @@ class _MatchingContext:
# Queue containing top-level non-trivial T-blossoms.
# The priority of a blossom is its dual plus 2 times the running
# sum of delta steps.
self.delta4_queue: PriorityQueue[_NonTrivialBlossom] = PriorityQueue()
self.delta4_queue: PriorityQueue[NonTrivialBlossom] = PriorityQueue()
# For each T-vertex or unlabeled vertex "x",
# "vertex_sedge_queue[x]" is a queue of edges between "x" and any
@ -648,7 +648,7 @@ class _MatchingContext:
(x, y, w) = self.graph.edges[e]
return self.vertex_dual_2x[x] + self.vertex_dual_2x[y] - 2 * w
def delta2_add_edge(self, e: int, y: int, by: _Blossom) -> None:
def delta2_add_edge(self, e: int, y: int, by: Blossom) -> None:
"""Add edge "e" for delta2 tracking.
Edge "e" connects an S-vertex to a T-vertex or unlabeled vertex "y".
@ -674,14 +674,14 @@ class _MatchingContext:
# If the blossom is unlabeled and the new edge becomes its least-slack
# S-edge, insert or update the blossom in the global delta2 queue.
if by.label == _LABEL_NONE:
if by.label == LABEL_NONE:
prio += by.vertex_dual_offset
if by.delta2_node is None:
by.delta2_node = self.delta2_queue.insert(prio, by)
elif prio < by.delta2_node.prio:
self.delta2_queue.decrease_prio(by.delta2_node, prio)
def delta2_remove_edge(self, e: int, y: int, by: _Blossom) -> None:
def delta2_remove_edge(self, e: int, y: int, by: Blossom) -> None:
"""Remove edge "e" from delta2 tracking.
This function is called if an S-vertex becomes unlabeled,
@ -705,7 +705,7 @@ class _MatchingContext:
# If necessary, update the priority of "y" in its UnionFindQueue.
if prio > self.vertex_set_node[y].prio:
self.vertex_set_node[y].set_prio(prio)
if by.label == _LABEL_NONE:
if by.label == LABEL_NONE:
# Update or delete the blossom in the global delta2 queue.
assert by.delta2_node is not None
prio = by.vertex_set.min_prio()
@ -718,7 +718,7 @@ class _MatchingContext:
self.delta2_queue.delete(by.delta2_node)
by.delta2_node = None
def delta2_enable_blossom(self, blossom: _Blossom) -> None:
def delta2_enable_blossom(self, blossom: Blossom) -> None:
"""Enable delta2 tracking for "blossom".
This function is called when a blossom becomes an unlabeled top-level
@ -733,7 +733,7 @@ class _MatchingContext:
prio += blossom.vertex_dual_offset
blossom.delta2_node = self.delta2_queue.insert(prio, blossom)
def delta2_disable_blossom(self, blossom: _Blossom) -> None:
def delta2_disable_blossom(self, blossom: Blossom) -> None:
"""Disable delta2 tracking for "blossom".
The blossom will be removed from the global delta2 queue.
@ -780,7 +780,7 @@ class _MatchingContext:
prio = delta2_node.prio
slack_2x = prio - self.delta_sum_2x
assert blossom.parent is None
assert blossom.label == _LABEL_NONE
assert blossom.label == LABEL_NONE
x = blossom.vertex_set.min_elem()
e = self.vertex_sedge_queue[x].find_min().data
@ -840,7 +840,7 @@ class _MatchingContext:
(x, y, _w) = self.graph.edges[e]
bx = self.vertex_set_node[x].find()
by = self.vertex_set_node[y].find()
assert (bx.label == _LABEL_S) and (by.label == _LABEL_S)
assert (bx.label == LABEL_S) and (by.label == LABEL_S)
if bx is not by:
slack = delta3_node.prio - self.delta_sum_2x
return (e, slack)
@ -859,7 +859,7 @@ class _MatchingContext:
# Managing blossom labels:
#
def assign_blossom_label_s(self, blossom: _Blossom) -> None:
def assign_blossom_label_s(self, blossom: Blossom) -> None:
"""Change an unlabeled top-level blossom into an S-blossom.
For a blossom with "j" vertices and "k" incident edges,
@ -870,8 +870,8 @@ class _MatchingContext:
"""
assert blossom.parent is None
assert blossom.label == _LABEL_NONE
blossom.label = _LABEL_S
assert blossom.label == LABEL_NONE
blossom.label = LABEL_S
# Labeled blossoms must not be in the delta2 queue.
self.delta2_disable_blossom(blossom)
@ -887,7 +887,7 @@ class _MatchingContext:
# The value of blossom.dual_var must be adjusted accordingly
# when the blossom changes from unlabeled to S-blossom.
#
if isinstance(blossom, _NonTrivialBlossom):
if isinstance(blossom, NonTrivialBlossom):
blossom.dual_var -= self.delta_sum_2x
# Apply pending updates to vertex dual variables and prepare
@ -916,20 +916,20 @@ class _MatchingContext:
# Add the new S-vertices to the scan queue.
self.scan_queue.extend(vertices)
def assign_blossom_label_t(self, blossom: _Blossom) -> None:
def assign_blossom_label_t(self, blossom: Blossom) -> None:
"""Change an unlabeled top-level blossom into a T-blossom.
This function takes time O(log(n)).
"""
assert blossom.parent is None
assert blossom.label == _LABEL_NONE
blossom.label = _LABEL_T
assert blossom.label == LABEL_NONE
blossom.label = LABEL_T
# Labeled blossoms must not be in the delta2 queue.
self.delta2_disable_blossom(blossom)
if isinstance(blossom, _NonTrivialBlossom):
if isinstance(blossom, NonTrivialBlossom):
# Adjust for lazy updating of T-blossom dual variables.
#
@ -962,7 +962,7 @@ class _MatchingContext:
#
blossom.vertex_dual_offset -= self.delta_sum_2x
def remove_blossom_label_s(self, blossom: _Blossom) -> None:
def remove_blossom_label_s(self, blossom: Blossom) -> None:
"""Change a top-level S-blossom into an unlabeled blossom.
For a blossom with "j" vertices and "k" incident edges,
@ -973,11 +973,11 @@ class _MatchingContext:
"""
assert blossom.parent is None
assert blossom.label == _LABEL_S
blossom.label = _LABEL_NONE
assert blossom.label == LABEL_S
blossom.label = LABEL_NONE
# Unwind lazy delta updates to the S-blossom dual variable.
if isinstance(blossom, _NonTrivialBlossom):
if isinstance(blossom, NonTrivialBlossom):
blossom.dual_var += self.delta_sum_2x
assert blossom.vertex_dual_offset == 0
@ -1002,7 +1002,7 @@ class _MatchingContext:
self.delta3_remove_edge(e)
by = self.vertex_set_node[y].find()
if by.label == _LABEL_S:
if by.label == LABEL_S:
# Edge "e" connects unlabeled vertex "x" to S-vertex "y".
# It must be tracked for delta2 via vertex "x".
self.delta2_add_edge(e, x, blossom)
@ -1013,17 +1013,17 @@ class _MatchingContext:
# removed now.
self.delta2_remove_edge(e, y, by)
def remove_blossom_label_t(self, blossom: _Blossom) -> None:
def remove_blossom_label_t(self, blossom: Blossom) -> None:
"""Change a top-level T-blossom into an unlabeled blossom.
This function takes time O(log(n)).
"""
assert blossom.parent is None
assert blossom.label == _LABEL_T
blossom.label = _LABEL_NONE
assert blossom.label == LABEL_T
blossom.label = LABEL_NONE
if isinstance(blossom, _NonTrivialBlossom):
if isinstance(blossom, NonTrivialBlossom):
# Unlabeled blossoms are not tracked in the delta4 queue.
assert blossom.delta4_node is not None
@ -1039,31 +1039,31 @@ class _MatchingContext:
# Enable unlabeled top-level blossom for delta2 tracking.
self.delta2_enable_blossom(blossom)
def change_s_blossom_to_subblossom(self, blossom: _Blossom) -> None:
def change_s_blossom_to_subblossom(self, blossom: Blossom) -> None:
"""Change a top-level S-blossom into an S-subblossom.
This function takes time O(1).
"""
assert blossom.parent is None
assert blossom.label == _LABEL_S
blossom.label = _LABEL_NONE
assert blossom.label == LABEL_S
blossom.label = LABEL_NONE
# Unwind lazy delta updates to the S-blossom dual variable.
if isinstance(blossom, _NonTrivialBlossom):
if isinstance(blossom, NonTrivialBlossom):
blossom.dual_var += self.delta_sum_2x
#
# General support routines:
#
def reset_blossom_label(self, blossom: _Blossom) -> None:
def reset_blossom_label(self, blossom: Blossom) -> None:
"""Remove blossom label."""
assert blossom.parent is None
assert blossom.label != _LABEL_NONE
assert blossom.label != LABEL_NONE
if blossom.label == _LABEL_S:
if blossom.label == LABEL_S:
self.remove_blossom_label_s(blossom)
else:
self.remove_blossom_label_t(blossom)
@ -1072,7 +1072,7 @@ class _MatchingContext:
"""TODO -- remove this function, only for debugging"""
for blossom in itertools.chain(self.trivial_blossom,
self.nontrivial_blossom):
if (blossom.parent is None) and (blossom.label != _LABEL_NONE):
if (blossom.parent is None) and (blossom.label != LABEL_NONE):
assert blossom.tree_blossoms is not None
assert blossom in blossom.tree_blossoms
if blossom.tree_edge is not None:
@ -1084,7 +1084,7 @@ class _MatchingContext:
assert blossom.tree_edge is None
assert blossom.tree_blossoms is None
def remove_alternating_tree(self, tree_blossoms: set[_Blossom]) -> None:
def remove_alternating_tree(self, tree_blossoms: set[Blossom]) -> None:
"""Reset the alternating tree consisting of the specified blossoms.
Marks the blossoms as unlabeled.
@ -1093,13 +1093,13 @@ class _MatchingContext:
This function takes time O((n + m) * log(n)).
"""
for blossom in tree_blossoms:
assert blossom.label != _LABEL_NONE
assert blossom.label != LABEL_NONE
assert blossom.tree_blossoms is tree_blossoms
self.reset_blossom_label(blossom)
blossom.tree_edge = None
blossom.tree_blossoms = None
def trace_alternating_paths(self, x: int, y: int) -> _AlternatingPath:
def trace_alternating_paths(self, x: int, y: int) -> AlternatingPath:
"""Trace back through the alternating trees from vertices "x" and "y".
If both vertices are part of the same alternating tree, this function
@ -1119,7 +1119,7 @@ class _MatchingContext:
blossoms.
"""
marked_blossoms: list[_Blossom] = []
marked_blossoms: list[Blossom] = []
# "xedges" is a list of edges used while tracing from "x".
# "yedges" is a list of edges used while tracing from "y".
@ -1129,7 +1129,7 @@ class _MatchingContext:
# "first_common" is the first common ancestor of "x" and "y"
# in the alternating tree, or None if there is no common ancestor.
first_common: Optional[_Blossom] = None
first_common: Optional[Blossom] = None
# Alternate between tracing the path from "x" and the path from "y".
# This ensures that the search time is bounded by the size of the
@ -1179,14 +1179,14 @@ class _MatchingContext:
# Any S-to-S alternating path must have odd length.
assert len(path_edges) % 2 == 1
return _AlternatingPath(edges=path_edges,
is_cycle=(first_common is not None))
return AlternatingPath(edges=path_edges,
is_cycle=(first_common is not None))
#
# Merge and expand blossoms:
#
def make_blossom(self, path: _AlternatingPath) -> None:
def make_blossom(self, path: AlternatingPath) -> None:
"""Create a new blossom from an alternating cycle.
Assign label S to the new blossom.
@ -1214,21 +1214,21 @@ class _MatchingContext:
assert subblossoms[1:] == subblossoms_next[:-1]
# Blossom must start and end with an S-sub-blossom.
assert subblossoms[0].label == _LABEL_S
assert subblossoms[0].label == LABEL_S
# Remove blossom labels.
# Mark vertices inside former T-blossoms as S-vertices.
for sub in subblossoms:
if sub.label == _LABEL_T:
if sub.label == LABEL_T:
self.remove_blossom_label_t(sub)
self.assign_blossom_label_s(sub)
self.change_s_blossom_to_subblossom(sub)
# Create the new blossom object.
blossom = _NonTrivialBlossom(subblossoms, path.edges)
blossom = NonTrivialBlossom(subblossoms, path.edges)
# Assign label S to the new blossom.
blossom.label = _LABEL_S
blossom.label = LABEL_S
# Prepare for lazy updating of S-blossom dual variable.
blossom.dual_var = -self.delta_sum_2x
@ -1257,9 +1257,9 @@ class _MatchingContext:
@staticmethod
def find_path_through_blossom(
blossom: _NonTrivialBlossom,
sub: _Blossom
) -> tuple[list[_Blossom], list[tuple[int, int]]]:
blossom: NonTrivialBlossom,
sub: Blossom
) -> tuple[list[Blossom], list[tuple[int, int]]]:
"""Construct a path with an even number of edges through the
specified blossom, from sub-blossom "sub" to the base of "blossom".
@ -1285,14 +1285,14 @@ class _MatchingContext:
return (nodes, edges)
def expand_unlabeled_blossom(self, blossom: _NonTrivialBlossom) -> None:
def expand_unlabeled_blossom(self, blossom: NonTrivialBlossom) -> None:
"""Expand the specified unlabeled blossom.
This function takes total time O(n * log(n)) per stage.
"""
assert blossom.parent is None
assert blossom.label == _LABEL_NONE
assert blossom.label == LABEL_NONE
# Remove blossom from the delta2 queue.
self.delta2_disable_blossom(blossom)
@ -1306,7 +1306,7 @@ class _MatchingContext:
# Convert sub-blossoms into top-level blossoms.
for sub in blossom.subblossoms:
assert sub.label == _LABEL_NONE
assert sub.label == LABEL_NONE
sub.parent = None
assert sub.vertex_dual_offset == 0
@ -1320,14 +1320,14 @@ class _MatchingContext:
# Delete the expanded blossom.
self.nontrivial_blossom.remove(blossom)
def expand_t_blossom(self, blossom: _NonTrivialBlossom) -> None:
def expand_t_blossom(self, blossom: NonTrivialBlossom) -> None:
"""Expand the specified T-blossom.
This function takes total time O(n * log(n) + m) per stage.
"""
assert blossom.parent is None
assert blossom.label == _LABEL_T
assert blossom.label == LABEL_T
assert blossom.delta2_node is None
# Remove blossom from its alternating tree.
@ -1391,9 +1391,9 @@ class _MatchingContext:
def augment_blossom_rec(
self,
blossom: _NonTrivialBlossom,
sub: _Blossom,
stack: list[tuple[_NonTrivialBlossom, _Blossom]]
blossom: NonTrivialBlossom,
sub: Blossom,
stack: list[tuple[NonTrivialBlossom, Blossom]]
) -> None:
"""Augment along an alternating path through the specified blossom,
from sub-blossom "sub" to the base vertex of the blossom.
@ -1430,11 +1430,11 @@ class _MatchingContext:
# Augment through the subblossoms touching the edge (x, y).
# Nothing needs to be done for trivial subblossoms.
bx = path_nodes[p+1]
if isinstance(bx, _NonTrivialBlossom):
if isinstance(bx, NonTrivialBlossom):
stack.append((bx, self.trivial_blossom[x]))
by = path_nodes[p+2]
if isinstance(by, _NonTrivialBlossom):
if isinstance(by, NonTrivialBlossom):
stack.append((by, self.trivial_blossom[y]))
# Rotate the subblossom list so the new base ends up in position 0.
@ -1450,8 +1450,8 @@ class _MatchingContext:
def augment_blossom(
self,
blossom: _NonTrivialBlossom,
sub: _Blossom
blossom: NonTrivialBlossom,
sub: Blossom
) -> None:
"""Augment along an alternating path through the specified blossom,
from sub-blossom "sub" to the base vertex of the blossom.
@ -1486,7 +1486,7 @@ class _MatchingContext:
# Augment "blossom" from "sub" to the base vertex.
self.augment_blossom_rec(blossom, sub, stack)
def augment_matching(self, path: _AlternatingPath) -> None:
def augment_matching(self, path: AlternatingPath) -> None:
"""Augment the matching through the specified augmenting path.
This function takes time O(n).
@ -1515,11 +1515,11 @@ class _MatchingContext:
# Augment the non-trivial blossoms on either side of this edge.
# No action is necessary for trivial blossoms.
bx = self.vertex_set_node[x].find()
if isinstance(bx, _NonTrivialBlossom):
if isinstance(bx, NonTrivialBlossom):
self.augment_blossom(bx, self.trivial_blossom[x])
by = self.vertex_set_node[y].find()
if isinstance(by, _NonTrivialBlossom):
if isinstance(by, NonTrivialBlossom):
self.augment_blossom(by, self.trivial_blossom[y])
# Pull the edge into the matching.
@ -1551,7 +1551,7 @@ class _MatchingContext:
assert y != -1
by = self.vertex_set_node[y].find()
assert by.label == _LABEL_T
assert by.label == LABEL_T
assert by.tree_blossoms is not None
# Attach the blossom that contains "x" to the alternating tree.
@ -1575,10 +1575,10 @@ class _MatchingContext:
bx = self.vertex_set_node[x].find()
by = self.vertex_set_node[y].find()
assert bx.label == _LABEL_S
assert bx.label == LABEL_S
# Expand zero-dual blossoms before assigning label T.
while isinstance(by, _NonTrivialBlossom) and (by.dual_var == 0):
while isinstance(by, NonTrivialBlossom) and (by.dual_var == 0):
self.expand_unlabeled_blossom(by)
by = self.vertex_set_node[y].find()
@ -1617,8 +1617,8 @@ class _MatchingContext:
bx = self.vertex_set_node[x].find()
by = self.vertex_set_node[y].find()
assert bx.label == _LABEL_S
assert by.label == _LABEL_S
assert bx.label == LABEL_S
assert by.label == LABEL_S
assert bx is not by
# Trace back through the alternating trees from "x" and "y".
@ -1680,7 +1680,7 @@ class _MatchingContext:
# Double-check that "x" is an S-vertex.
bx = self.vertex_set_node[x].find()
assert bx.label == _LABEL_S
assert bx.label == LABEL_S
# Scan the edges that are incident on "x".
# This loop runs through O(m) iterations per stage.
@ -1703,7 +1703,7 @@ class _MatchingContext:
if bx is by:
continue
if by.label == _LABEL_S:
if by.label == LABEL_S:
self.delta3_add_edge(e)
else:
self.delta2_add_edge(e, y, by)
@ -1716,7 +1716,7 @@ class _MatchingContext:
def calc_dual_delta_step(
self
) -> tuple[int, float, int, Optional[_NonTrivialBlossom]]:
) -> tuple[int, float, int, Optional[NonTrivialBlossom]]:
"""Calculate a delta step in the dual LPP problem.
This function returns the minimum of the 4 types of delta values,
@ -1740,7 +1740,7 @@ class _MatchingContext:
Tuple (delta_type, delta_2x, delta_edge, delta_blossom).
"""
delta_edge = -1
delta_blossom: Optional[_NonTrivialBlossom] = None
delta_blossom: Optional[NonTrivialBlossom] = None
# Compute delta1: minimum dual variable of any S-vertex.
# All unmatched vertices have the same dual value, and this is
@ -1770,7 +1770,7 @@ class _MatchingContext:
# This takes time O(log(n)).
if not self.delta4_queue.empty():
blossom = self.delta4_queue.find_min().data
assert blossom.label == _LABEL_T
assert blossom.label == LABEL_T
assert blossom.parent is None
blossom_dual = blossom.dual_var - self.delta_sum_2x
if blossom_dual <= delta_2x:
@ -1847,7 +1847,7 @@ class _MatchingContext:
# Use the edge from S-vertex to unlabeled vertex that got
# unlocked through the delta update.
(x, y, _w) = self.graph.edges[delta_edge]
if self.vertex_set_node[x].find().label != _LABEL_S:
if self.vertex_set_node[x].find().label != LABEL_S:
(x, y) = (y, x)
self.extend_tree_s_to_t(x, y)
@ -1886,9 +1886,9 @@ class _MatchingContext:
self.nontrivial_blossom):
# Remove blossom label.
if (blossom.parent is None) and (blossom.label != _LABEL_NONE):
if (blossom.parent is None) and (blossom.label != LABEL_NONE):
self.reset_blossom_label(blossom)
assert blossom.label == _LABEL_NONE
assert blossom.label == LABEL_NONE
# Remove blossom from alternating tree.
blossom.tree_edge = None
@ -1906,8 +1906,8 @@ class _MatchingContext:
def _verify_blossom_edges(
ctx: _MatchingContext,
blossom: _NonTrivialBlossom,
ctx: MatchingContext,
blossom: NonTrivialBlossom,
edge_slack_2x: list[float]
) -> None:
"""Descend down the blossom tree to find edges that are contained
@ -1943,7 +1943,7 @@ def _verify_blossom_edges(
path_num_matched: list[int] = [0]
# Use an explicit stack to avoid deep recursion.
stack: list[tuple[_NonTrivialBlossom, int]] = [(blossom, -1)]
stack: list[tuple[NonTrivialBlossom, int]] = [(blossom, -1)]
while stack:
(blossom, p) = stack[-1]
@ -1969,7 +1969,7 @@ def _verify_blossom_edges(
# Examine the next sub-blossom at the current level.
sub = blossom.subblossoms[p]
if isinstance(sub, _NonTrivialBlossom):
if isinstance(sub, NonTrivialBlossom):
# Prepare to descent into the selected sub-blossom and
# scan it recursively.
stack.append((sub, -1))
@ -2031,7 +2031,7 @@ def _verify_blossom_edges(
stack.pop()
def _verify_optimum(ctx: _MatchingContext) -> None:
def verify_optimum(ctx: MatchingContext) -> None:
"""Verify that the optimum solution has been found.
This function takes time O(n**2).

View File

0
python/tests/__init__.py Normal file
View File

View File

@ -4,10 +4,12 @@ import math
import unittest
from unittest.mock import Mock
import mwmatching
from mwmatching import (
maximum_weight_matching as mwm,
adjust_weights_for_maximum_cardinality_matching as adj)
from mwmatching.algorithm import (
MatchingError, GraphInfo, Blossom, NonTrivialBlossom, MatchingContext,
verify_optimum)
class TestMaximumWeightMatching(unittest.TestCase):
@ -431,10 +433,10 @@ class TestMaximumCardinalityMatching(unittest.TestCase):
class TestGraphInfo(unittest.TestCase):
"""Test _GraphInfo helper class."""
"""Test GraphInfo helper class."""
def test_empty(self):
graph = mwmatching._GraphInfo([])
graph = GraphInfo([])
self.assertEqual(graph.num_vertex, 0)
self.assertEqual(graph.edges, [])
self.assertEqual(graph.adjacent_edges, [])
@ -449,8 +451,8 @@ class TestVerificationFail(unittest.TestCase):
vertex_mate,
vertex_dual_2x,
nontrivial_blossom):
ctx = Mock(spec=mwmatching._MatchingContext)
ctx.graph = mwmatching._GraphInfo(edges)
ctx = Mock(spec=MatchingContext)
ctx.graph = GraphInfo(edges)
ctx.vertex_mate = vertex_mate
ctx.vertex_dual_2x = vertex_dual_2x
ctx.nontrivial_blossom = nontrivial_blossom
@ -463,7 +465,7 @@ class TestVerificationFail(unittest.TestCase):
vertex_mate=[-1, 2, 1],
vertex_dual_2x=[0, 20, 2],
nontrivial_blossom=[])
mwmatching._verify_optimum(ctx)
verify_optimum(ctx)
def test_asymmetric_matching(self):
edges = [(0,1,10), (1,2,11)]
@ -472,8 +474,8 @@ class TestVerificationFail(unittest.TestCase):
vertex_mate=[-1, 2, 0],
vertex_dual_2x=[0, 20, 2],
nontrivial_blossom=[])
with self.assertRaises(mwmatching.MatchingError):
mwmatching._verify_optimum(ctx)
with self.assertRaises(MatchingError):
verify_optimum(ctx)
def test_nonexistent_matched_edge(self):
edges = [(0,1,10), (1,2,11)]
@ -482,8 +484,8 @@ class TestVerificationFail(unittest.TestCase):
vertex_mate=[2, -1, 0],
vertex_dual_2x=[11, 11, 11],
nontrivial_blossom=[])
with self.assertRaises(mwmatching.MatchingError):
mwmatching._verify_optimum(ctx)
with self.assertRaises(MatchingError):
verify_optimum(ctx)
def test_negative_vertex_dual(self):
edges = [(0,1,10), (1,2,11)]
@ -492,8 +494,8 @@ class TestVerificationFail(unittest.TestCase):
vertex_mate=[-1, 2, 1],
vertex_dual_2x=[-2, 22, 0],
nontrivial_blossom=[])
with self.assertRaises(mwmatching.MatchingError):
mwmatching._verify_optimum(ctx)
with self.assertRaises(MatchingError):
verify_optimum(ctx)
def test_unmatched_nonzero_dual(self):
edges = [(0,1,10), (1,2,11)]
@ -502,8 +504,8 @@ class TestVerificationFail(unittest.TestCase):
vertex_mate=[-1, 2, 1],
vertex_dual_2x=[9, 11, 11],
nontrivial_blossom=[])
with self.assertRaises(mwmatching.MatchingError):
mwmatching._verify_optimum(ctx)
with self.assertRaises(MatchingError):
verify_optimum(ctx)
def test_negative_edge_slack(self):
edges = [(0,1,10), (1,2,11)]
@ -512,8 +514,8 @@ class TestVerificationFail(unittest.TestCase):
vertex_mate=[-1, 2, 1],
vertex_dual_2x=[0, 11, 11],
nontrivial_blossom=[])
with self.assertRaises(mwmatching.MatchingError):
mwmatching._verify_optimum(ctx)
with self.assertRaises(MatchingError):
verify_optimum(ctx)
def test_matched_edge_slack(self):
edges = [(0,1,10), (1,2,11)]
@ -522,8 +524,8 @@ class TestVerificationFail(unittest.TestCase):
vertex_mate=[-1, 2, 1],
vertex_dual_2x=[0, 20, 11],
nontrivial_blossom=[])
with self.assertRaises(mwmatching.MatchingError):
mwmatching._verify_optimum(ctx)
with self.assertRaises(MatchingError):
verify_optimum(ctx)
def test_negative_blossom_dual(self):
#
@ -532,11 +534,8 @@ class TestVerificationFail(unittest.TestCase):
# \----8-----/
#
edges = [(0,1,7), (0,2,8), (1,2,9), (2,3,6)]
blossom = mwmatching._NonTrivialBlossom(
subblossoms=[
mwmatching._Blossom(0),
mwmatching._Blossom(1),
mwmatching._Blossom(2)],
blossom = NonTrivialBlossom(
subblossoms=[Blossom(0), Blossom(1), Blossom(2)],
edges=[0,2,1])
for sub in blossom.subblossoms:
sub.parent = blossom
@ -546,8 +545,8 @@ class TestVerificationFail(unittest.TestCase):
vertex_mate=[1, 0, 3, 2],
vertex_dual_2x=[4, 6, 8, 4],
nontrivial_blossom=[blossom])
with self.assertRaises(mwmatching.MatchingError):
mwmatching._verify_optimum(ctx)
with self.assertRaises(MatchingError):
verify_optimum(ctx)
def test_blossom_not_full(self):
#
@ -560,11 +559,8 @@ class TestVerificationFail(unittest.TestCase):
# \----2-----/
#
edges = [(0,1,7), (0,2,2), (1,2,5), (0,3,8), (1,4,8)]
blossom = mwmatching._NonTrivialBlossom(
subblossoms=[
mwmatching._Blossom(0),
mwmatching._Blossom(1),
mwmatching._Blossom(2)],
blossom = NonTrivialBlossom(
subblossoms=[Blossom(0), Blossom(1), Blossom(2)],
edges=[0,2,1])
for sub in blossom.subblossoms:
sub.parent = blossom
@ -574,8 +570,8 @@ class TestVerificationFail(unittest.TestCase):
vertex_mate=[3, 4, -1, 0, 1],
vertex_dual_2x=[4, 10, 0, 12, 6],
nontrivial_blossom=[blossom])
with self.assertRaises(mwmatching.MatchingError):
mwmatching._verify_optimum(ctx)
with self.assertRaises(MatchingError):
verify_optimum(ctx)
if __name__ == "__main__":

View File

@ -3,7 +3,7 @@
import random
import unittest
from datastruct import UnionFindQueue, PriorityQueue
from mwmatching.datastruct import UnionFindQueue, PriorityQueue
class TestUnionFindQueue(unittest.TestCase):

View File

@ -4,7 +4,7 @@ set -e
echo
echo "Running pycodestyle"
pycodestyle python/mwmatching.py python/datastruct.py tests
pycodestyle python/mwmatching python/run_matching.py tests
echo
echo "Running mypy"
@ -12,20 +12,15 @@ mypy --disallow-incomplete-defs python tests
echo
echo "Running pylint"
pylint --ignore=test_mwmatching.py python tests || [ $(($? & 3)) -eq 0 ]
pylint --ignore=test_algorithm.py python tests/*.py tests/generate/*.py || [ $(($? & 3)) -eq 0 ]
echo
echo "Running test_mwmatching.py"
python3 python/test_mwmatching.py
echo
echo "Running test_datastruct.py"
python3 python/test_datastruct.py
echo "Running unit tests"
python3 -m unittest discover -t python -s python/tests
echo
echo "Checking test coverage"
coverage erase
coverage run --branch python/test_datastruct.py
coverage run -a --branch python/test_mwmatching.py
coverage run --branch -m unittest discover -t python -s python/tests
coverage report -m

View File

@ -19,13 +19,12 @@ count_delta_step = [0]
def patch_matching_code() -> None:
"""Patch the matching code to count events."""
# pylint: disable=import-outside-toplevel,protected-access
# pylint: disable=import-outside-toplevel
import mwmatching
from mwmatching.algorithm import MatchingContext
orig_make_blossom = mwmatching._MatchingContext.make_blossom
orig_calc_dual_delta_step = (
mwmatching._MatchingContext.calc_dual_delta_step)
orig_make_blossom = MatchingContext.make_blossom
orig_calc_dual_delta_step = MatchingContext.calc_dual_delta_step
def stub_make_blossom(*args: Any, **kwargs: Any) -> Any:
count_make_blossom[0] += 1
@ -36,9 +35,8 @@ def patch_matching_code() -> None:
ret = orig_calc_dual_delta_step(*args, **kwargs)
return ret
mwmatching._MatchingContext.make_blossom = ( # type: ignore
stub_make_blossom)
mwmatching._MatchingContext.calc_dual_delta_step = ( # type: ignore
MatchingContext.make_blossom = stub_make_blossom # type: ignore
MatchingContext.calc_dual_delta_step = ( # type: ignore
stub_calc_dual_delta_step)