1
0
Fork 0

Code style cleanups

This commit is contained in:
Joris van Rantwijk 2024-06-22 20:04:49 +02:00
parent de30ac3c5e
commit 0675230692
6 changed files with 46 additions and 31 deletions

View File

@ -162,7 +162,9 @@ class UnionFindQueue(Generic[_NameT, _ElemT]):
assert node is not None assert node is not None
return node.min_node.data return node.min_node.data
def merge(self, sub_queues: "list[UnionFindQueue[_NameT, _ElemT]]") -> None: def merge(self,
sub_queues: "list[UnionFindQueue[_NameT, _ElemT]]"
) -> None:
"""Merge the specified queues. """Merge the specified queues.
This queue must inititially be empty. This queue must inititially be empty.
@ -171,7 +173,8 @@ class UnionFindQueue(Generic[_NameT, _ElemT]):
This function removes all elements from the specified sub-queues This function removes all elements from the specified sub-queues
and adds them to this queue. and adds them to this queue.
After merging, this queue retains a reference to the list of sub-queues. After merging, this queue retains a reference to the list of
sub-queues.
This function takes time O(len(sub_queues) * log(n)). This function takes time O(len(sub_queues) * log(n)).
""" """
@ -268,7 +271,9 @@ class UnionFindQueue(Generic[_NameT, _ElemT]):
min_node = right_min_node min_node = right_min_node
node.min_node = min_node node.min_node = min_node
def _rotate_left(self, node: Node[_NameT, _ElemT]) -> Node[_NameT, _ElemT]: def _rotate_left(self,
node: Node[_NameT, _ElemT]
) -> Node[_NameT, _ElemT]:
"""Rotate the specified subtree to the left. """Rotate the specified subtree to the left.
Return the new root node of the subtree. Return the new root node of the subtree.
@ -303,7 +308,9 @@ class UnionFindQueue(Generic[_NameT, _ElemT]):
return new_top return new_top
def _rotate_right(self, node: Node[_NameT, _ElemT]) -> Node[_NameT, _ElemT]: def _rotate_right(self,
node: Node[_NameT, _ElemT]
) -> Node[_NameT, _ElemT]:
"""Rotate the specified node to the right. """Rotate the specified node to the right.
Return the new root node of the subtree. Return the new root node of the subtree.
@ -338,7 +345,9 @@ class UnionFindQueue(Generic[_NameT, _ElemT]):
return new_top return new_top
def _rebalance_up(self, node: Node[_NameT, _ElemT]) -> Node[_NameT, _ElemT]: def _rebalance_up(self,
node: Node[_NameT, _ElemT]
) -> Node[_NameT, _ElemT]:
"""Repair and rebalance the specified node and its ancestors. """Repair and rebalance the specified node and its ancestors.
Return the root node of the rebalanced tree. Return the root node of the rebalanced tree.

View File

@ -1,14 +1,18 @@
#!/bin/sh #!/bin/bash
set -e set -e
echo
echo "Running pycodestyle"
pycodestyle python/mwmatching.py python/datastruct.py tests
echo echo
echo "Running mypy" echo "Running mypy"
mypy --disallow-incomplete-defs python tests mypy --disallow-incomplete-defs python tests
echo echo
echo "Running pylint" echo "Running pylint"
pylint --ignore=test_mwmatching.py python tests pylint --ignore=test_mwmatching.py python tests || [ $(($? & 3)) -eq 0 ]
echo echo
echo "Running test_mwmatching.py" echo "Running test_mwmatching.py"

View File

@ -14,7 +14,7 @@ from typing import TextIO
def write_dimacs_graph( def write_dimacs_graph(
f: TextIO, f: TextIO,
edges: list[tuple[int, int, int|float]] edges: list[tuple[int, int, float]]
) -> None: ) -> None:
"""Write a graph in DIMACS edge list format.""" """Write a graph in DIMACS edge list format."""
@ -38,7 +38,7 @@ def make_random_graph(
max_weight: float, max_weight: float,
float_weights: bool, float_weights: bool,
rng: random.Random rng: random.Random
) -> list[tuple[int, int, int|float]]: ) -> list[tuple[int, int, float]]:
"""Generate a random graph with random edge weights.""" """Generate a random graph with random edge weights."""
edge_set: set[tuple[int, int]] = set() edge_set: set[tuple[int, int]] = set()
@ -59,9 +59,9 @@ def make_random_graph(
rng.shuffle(edge_candidates) rng.shuffle(edge_candidates)
edge_set.update(edge_candidates[:m]) edge_set.update(edge_candidates[:m])
edges: list[tuple[int, int, int|float]] = [] edges: list[tuple[int, int, float]] = []
for (x, y) in sorted(edge_set): for (x, y) in sorted(edge_set):
w: int|float w: float
if float_weights: if float_weights:
w = rng.uniform(1.0e-8, max_weight) w = rng.uniform(1.0e-8, max_weight)
else: else:

View File

@ -36,7 +36,8 @@ def patch_matching_code() -> None:
ret = orig_substage_calc_dual_delta(*args, **kwargs) ret = orig_substage_calc_dual_delta(*args, **kwargs)
return ret return ret
mwmatching._MatchingContext.make_blossom = stub_make_blossom # type: ignore mwmatching._MatchingContext.make_blossom = ( # type: ignore
stub_make_blossom)
mwmatching._MatchingContext.substage_calc_dual_delta = ( # type: ignore mwmatching._MatchingContext.substage_calc_dual_delta = ( # type: ignore
stub_substage_calc_dual_delta) stub_substage_calc_dual_delta)

View File

@ -29,7 +29,7 @@ class SolverError(Exception):
class Graph(NamedTuple): class Graph(NamedTuple):
"""Represents a graph. Vertex indices start from 0.""" """Represents a graph. Vertex indices start from 0."""
edges: list[tuple[int, int, int|float]] edges: list[tuple[int, int, float]]
def num_vertex(self) -> int: def num_vertex(self) -> int:
"""Count number of vertices.""" """Count number of vertices."""
@ -55,11 +55,11 @@ class RunStatus(enum.IntEnum):
class RunResult(NamedTuple): class RunResult(NamedTuple):
"""Represent the result of running a solver on a graph.""" """Represent the result of running a solver on a graph."""
status: RunStatus = RunStatus.OK status: RunStatus = RunStatus.OK
weight: int|float = 0 weight: float = 0
run_time: Sequence[float] = () run_time: Sequence[float] = ()
def parse_int_or_float(s: str) -> int|float: def parse_int_or_float(s: str) -> float:
"""Convert a string to integer or float value.""" """Convert a string to integer or float value."""
try: try:
return int(s) return int(s)
@ -130,11 +130,11 @@ def write_dimacs_graph(f: TextIO, graph: Graph) -> None:
print(f"e {x+1} {y+1} {w:.12g}", file=f) print(f"e {x+1} {y+1} {w:.12g}", file=f)
def read_dimacs_matching(f: TextIO) -> tuple[int|float, Matching]: def read_dimacs_matching(f: TextIO) -> tuple[float, Matching]:
"""Read a matching solution in DIMACS format.""" """Read a matching solution in DIMACS format."""
have_weight = False have_weight = False
weight: int|float = 0 weight: float = 0
pairs: list[tuple[int, int]] = [] pairs: list[tuple[int, int]] = []
for line in f: for line in f:
@ -208,9 +208,9 @@ def make_random_graph(
rng.shuffle(edge_candidates) rng.shuffle(edge_candidates)
edge_set.update(edge_candidates[:m]) edge_set.update(edge_candidates[:m])
edges: list[tuple[int, int, int|float]] = [] edges: list[tuple[int, int, float]] = []
for (x, y) in sorted(edge_set): for (x, y) in sorted(edge_set):
w: int|float w: float
if float_weights: if float_weights:
w = rng.uniform(1.0e-8, max_weight) w = rng.uniform(1.0e-8, max_weight)
else: else:
@ -220,14 +220,14 @@ def make_random_graph(
return Graph(edges) return Graph(edges)
def check_matching(graph: Graph, matching: Matching) -> int|float: def check_matching(graph: Graph, matching: Matching) -> float:
"""Verify that the matching is valid and calculate its weight.""" """Verify that the matching is valid and calculate its weight."""
edge_map: dict[tuple[int, int], int|float] = {} edge_map: dict[tuple[int, int], float] = {}
for (x, y, w) in graph.edges: for (x, y, w) in graph.edges:
edge_map[(min(x, y), max(x, y))] = w edge_map[(min(x, y), max(x, y))] = w
weight: int|float = 0 weight: float = 0
nodes_used: set[int] = set() nodes_used: set[int] = set()
for pair in matching.pairs: for pair in matching.pairs:
@ -247,7 +247,7 @@ def check_matching(graph: Graph, matching: Matching) -> int|float:
return weight return weight
def compare_weight(weight1: int|float, weight2: int|float) -> int: def compare_weight(weight1: float, weight2: float) -> int:
"""Compare weights of matchings. """Compare weights of matchings.
Returns: Returns:
@ -406,7 +406,7 @@ class WmatchSolver(Solver):
num_edge = len(graph.edges) num_edge = len(graph.edges)
all_integer = True all_integer = True
adjacent: list[list[tuple[int, int|float]]] = [ adjacent: list[list[tuple[int, float]]] = [
[] for i in range(num_vertex)] [] for i in range(num_vertex)]
for (x, y, w) in graph.edges: for (x, y, w) in graph.edges:
@ -451,13 +451,13 @@ def test_solver_on_graph(
solver: Solver, solver: Solver,
graph: Graph, graph: Graph,
graph_desc: str, graph_desc: str,
gold_weight: Optional[int|float], gold_weight: Optional[float],
num_run: int num_run: int
) -> RunResult: ) -> RunResult:
"""Test the specified solver with the specified graph.""" """Test the specified solver with the specified graph."""
solver_run_time: list[float] = [] solver_run_time: list[float] = []
solver_weight: Optional[int|float] = None solver_weight: Optional[float] = None
for i in range(num_run): for i in range(num_run):
@ -507,7 +507,7 @@ def test_graph(
solvers: list[Solver], solvers: list[Solver],
graph: Graph, graph: Graph,
graph_desc: str, graph_desc: str,
gold_weight: Optional[int|float], gold_weight: Optional[float],
num_run: int num_run: int
) -> list[RunResult]: ) -> list[RunResult]:
"""Test all solvers with the specified graph.""" """Test all solvers with the specified graph."""
@ -525,7 +525,7 @@ def test_graph(
if gold_weight is None: if gold_weight is None:
best_weight: Optional[int|float] = None best_weight: Optional[float] = None
for result in results: for result in results:
if result.status == RunStatus.OK: if result.status == RunStatus.OK:
if (best_weight is None) or (result.weight > best_weight): if (best_weight is None) or (result.weight > best_weight):
@ -668,7 +668,7 @@ def test_input(
file=sys.stderr) file=sys.stderr)
return 1 return 1
gold_weight: Optional[int|float] = None gold_weight: Optional[float] = None
if verify: if verify:
reffile = os.path.splitext(filename)[0] + ".out" reffile = os.path.splitext(filename)[0] + ".out"
try: try:
@ -740,7 +740,8 @@ def main() -> int:
parser.add_argument("--timeout", parser.add_argument("--timeout",
action="store", action="store",
type=float, type=float,
help="abort when solver runs longer than TIMEOUT seconds") help="abort when solver runs longer than TIMEOUT"
" seconds")
parser.add_argument("--runs", parser.add_argument("--runs",
action="store", action="store",
type=int, type=int,