1
0
Fork 0
maximum-weight-matching/tests/run_test.py

867 lines
26 KiB
Python
Raw Normal View History

2023-02-19 21:59:03 +01:00
#!/usr/bin/env python3
"""
Run matching solvers on test graphs.
"""
from __future__ import annotations
import sys
import argparse
import enum
import io
import math
import os.path
import random
import shlex
import subprocess
import time
from typing import NamedTuple, Optional, Sequence, TextIO
class InvalidMatchingError(Exception):
"""Raised when the solver output is not a valid matching in the graph."""
class SolverError(Exception):
2023-03-13 21:56:46 +01:00
"""Raised if the solver returns an error or outputs an invalid format."""
2023-02-19 21:59:03 +01:00
class Graph(NamedTuple):
"""Represents a graph. Vertex indices start from 0."""
2024-06-22 20:04:49 +02:00
edges: list[tuple[int, int, float]]
2023-02-19 21:59:03 +01:00
def num_vertex(self) -> int:
"""Count number of vertices."""
if self.edges:
return 1 + max(max(x, y) for (x, y, _w) in self.edges)
else:
return 0
class Matching(NamedTuple):
"""Represents a matching."""
pairs: list[tuple[int, int]]
class RunStatus(enum.IntEnum):
2023-03-11 23:21:10 +01:00
"""Result categories for running a solver."""
2023-02-19 21:59:03 +01:00
OK = 0
FAILED = 1
WRONG_ANSWER = 2
TIMEOUT = 3
class RunResult(NamedTuple):
"""Represent the result of running a solver on a graph."""
status: RunStatus = RunStatus.OK
2024-06-22 20:04:49 +02:00
weight: float = 0
2023-02-19 21:59:03 +01:00
run_time: Sequence[float] = ()
2024-06-22 20:04:49 +02:00
def parse_int_or_float(s: str) -> float:
2023-02-19 21:59:03 +01:00
"""Convert a string to integer or float value."""
try:
return int(s)
except ValueError:
pass
return float(s)
def read_dimacs_graph(f: TextIO) -> Graph:
"""Read a graph in DIMACS edge list format."""
edges: list[tuple[int, int, float]] = []
for line in f:
s = line.strip()
words = s.split()
if not words[0]:
# Skip empty line.
pass
elif words[0].startswith("c"):
# Skip comment line.
pass
elif words[0] == "p":
# Handle "problem" line.
if len(words) != 4:
raise ValueError(
f"Expecting DIMACS edge format but got {s!r}")
if words[1] != "edge":
raise ValueError(
f"Expecting DIMACS edge format but got {words[1]!r}")
elif words[0] == "e":
# Handle "edge" line.
if len(words) != 4:
raise ValueError(f"Expecting edge but got {s!r}")
x = int(words[1])
y = int(words[2])
w = parse_int_or_float(words[3])
if x < 1:
raise ValueError("Invalid vertex index {x}")
if y < 1:
raise ValueError("Invalid vertex index {y}")
edges.append((x - 1, y - 1, w))
else:
raise ValueError(f"Unknown line type {words[0]!r}")
return Graph(edges)
def write_dimacs_graph(f: TextIO, graph: Graph) -> None:
"""Write a graph in DIMACS edge list format."""
num_vertex = graph.num_vertex()
num_edge = len(graph.edges)
print(f"p edge {num_vertex} {num_edge}", file=f)
integer_weights = all(isinstance(w, int) for (_x, _y, w) in graph.edges)
for (x, y, w) in graph.edges:
if integer_weights:
print(f"e {x+1} {y+1} {w}", file=f)
else:
print(f"e {x+1} {y+1} {w:.12g}", file=f)
2024-06-22 20:04:49 +02:00
def read_dimacs_matching(f: TextIO) -> tuple[float, Matching]:
2023-02-19 21:59:03 +01:00
"""Read a matching solution in DIMACS format."""
have_weight = False
2024-06-22 20:04:49 +02:00
weight: float = 0
2023-02-19 21:59:03 +01:00
pairs: list[tuple[int, int]] = []
for line in f:
s = line.strip()
words = s.split()
if not words[0]:
# Skip empty line.
continue
if words[0].startswith("c"):
# Skip comment line.
pass
elif words[0] == "s":
# Handle "solution" line.
if len(words) != 2:
raise ValueError(
f"Expecting solution line but got {s!r}")
if have_weight:
raise ValueError("Duplicate solution line")
have_weight = True
weight = parse_int_or_float(words[1])
elif words[0] == "m":
# Handle "matching" line.
if len(words) != 3:
raise ValueError(
f"Expecting matched edge but got {s!r}")
x = int(words[1])
y = int(words[2])
if x < 1:
raise ValueError("Invalid vertex index {x}")
if y < 1:
raise ValueError("Invalid vertex index {y}")
pairs.append((x - 1, y - 1))
else:
raise ValueError(f"Unknown line type {words[0]!r}")
if not have_weight:
raise ValueError("Missing solution line")
return (weight, Matching(pairs))
def make_random_graph(
n: int,
m: int,
max_weight: float,
float_weights: bool,
rng: random.Random
) -> Graph:
"""Generate a random graph with random edge weights."""
edge_set: set[tuple[int, int]] = set()
if 3 * m < n * (n - 2) // 2:
# Simply add random edges until we have enough.
while len(edge_set) < m:
x = rng.randint(0, n - 2)
y = rng.randint(x + 1, n - 1)
edge_set.add((x, y))
else:
# We need a very dense graph.
# Generate all edge candidates and choose a random subset.
edge_candidates = [(x, y)
for x in range(n - 1)
for y in range(x + 1, n)]
rng.shuffle(edge_candidates)
edge_set.update(edge_candidates[:m])
2024-06-22 20:04:49 +02:00
edges: list[tuple[int, int, float]] = []
2023-02-19 21:59:03 +01:00
for (x, y) in sorted(edge_set):
2024-06-22 20:04:49 +02:00
w: float
2023-02-19 21:59:03 +01:00
if float_weights:
w = rng.uniform(1.0e-8, max_weight)
else:
w = rng.randint(1, int(max_weight))
edges.append((x, y, w))
return Graph(edges)
2024-06-22 20:04:49 +02:00
def check_matching(graph: Graph, matching: Matching) -> float:
2023-02-19 21:59:03 +01:00
"""Verify that the matching is valid and calculate its weight."""
2024-06-22 20:04:49 +02:00
edge_map: dict[tuple[int, int], float] = {}
2023-02-19 21:59:03 +01:00
for (x, y, w) in graph.edges:
edge_map[(min(x, y), max(x, y))] = w
2024-06-22 20:04:49 +02:00
weight: float = 0
2023-02-19 21:59:03 +01:00
nodes_used: set[int] = set()
for pair in matching.pairs:
x = min(pair)
y = max(pair)
if x in nodes_used:
raise InvalidMatchingError(f"Matching uses vertex {x} twice")
if y in nodes_used:
raise InvalidMatchingError(f"Matching uses vertex {y} twice")
if (x, y) not in edge_map:
raise InvalidMatchingError(
f"Matching contains non-existing edge ({x}, {y})")
weight += edge_map[(x, y)]
nodes_used.add(x)
nodes_used.add(y)
return weight
2024-06-22 20:04:49 +02:00
def compare_weight(weight1: float, weight2: float) -> int:
2023-02-19 21:59:03 +01:00
"""Compare weights of matchings.
Returns:
0 if weights are equal,
-1 if weight1 < weight2,
+1 if weight1 > weight2.
"""
if isinstance(weight1, int) and isinstance(weight2, int):
if weight1 == weight2:
return 0
else:
if math.isclose(weight1, weight2, rel_tol=1e-9, abs_tol=1e-9):
return 0
if weight1 < weight2:
return -1
if weight1 > weight2:
return 1
return 0
class Solver:
"""Abstract base class for interaction with solvers."""
def __init__(self, name: str, timeout: Optional[float]) -> None:
self.name = name
self.timeout = timeout
def __str__(self) -> str:
return self.name
def prepare_input(self, graph: Graph) -> str:
"""Serialize a graph instance to input bytes for the solver."""
raise NotImplementedError
def parse_output(self, output_data: str) -> Matching:
"""Parse output bytes from the solver to a set of matched edges."""
raise NotImplementedError
def solver_command(self) -> list[str]:
"""Return the command and arguments to invoke the solver."""
raise NotImplementedError
def run(self, graph: Graph) -> tuple[Matching, float]:
"""Run the solver on a graph instance.
Returns:
Tuple (matching, elapsed_time).
Raises:
SolverError: If the solver returns an error status,
or produces incorrectly formatted output.
TimeoutEpxired: If the solver exceeds the timeout.
"""
input_str = self.prepare_input(graph)
input_bytes = input_str.encode("ascii")
command = self.solver_command()
t0 = time.monotonic()
proc = subprocess.run(
command,
stdout=subprocess.PIPE,
input=input_bytes,
timeout=self.timeout,
check=False)
t1 = time.monotonic()
elapsed = t1 - t0
if proc.returncode < 0:
raise SolverError(f"Solver aborted on signal {-proc.returncode}")
if proc.returncode > 0:
raise SolverError(f"Solver exit status {proc.returncode}")
try:
output_str = proc.stdout.decode("ascii")
except ValueError:
raise SolverError("Non-ASCII data in solver output") from None
pairs = self.parse_output(output_str)
return (pairs, elapsed)
class DimacsSolver(Solver):
"""Interaction with a solver that uses DIMACS input/output."""
def __init__(
self,
name: str,
solver_command: str,
timeout: Optional[float]
) -> None:
"""Initialize solver wrapper.
The solver must read a graph in DIMACS format from stdin.
The solver must write a matching in DIMACS format to stdout.
Parameters:
name: Short label to refer to the solver.
solver_command: Command to invoke the solver.
Arguments may be provided, separated by spaces.
Shell-style quoting may be used.
timeout: Optional run time limit.
"""
super().__init__(name, timeout)
self._solver_command = shlex.split(solver_command)
def solver_command(self) -> list[str]:
return self._solver_command
def prepare_input(self, graph: Graph) -> str:
buf = io.StringIO()
write_dimacs_graph(buf, graph)
return buf.getvalue()
def parse_output(self, output_data: str) -> Matching:
if not output_data:
raise SolverError("Empty output from solver")
buf = io.StringIO(output_data)
try:
(_weight, matching) = read_dimacs_matching(buf)
except ValueError as exc:
raise SolverError(exc) from None
return matching
class WmatchSolver(Solver):
"""Interaction with the Wmatch solver."""
def __init__(
self,
name: str,
solver_command: str,
timeout: Optional[float]
) -> None:
"""Initialize solver wrapper.
Parameters:
name: Short label to refer to the solver.
solver_command: Command to invoke the solver.
It should be a single binary without arguments.
timeout: Optional run time limit.
"""
super().__init__(name, timeout)
self._solver_command = [solver_command, "/dev/stdin"]
def solver_command(self) -> list[str]:
return self._solver_command
def prepare_input(self, graph: Graph) -> str:
num_vertex = graph.num_vertex()
num_edge = len(graph.edges)
all_integer = True
2024-06-22 20:04:49 +02:00
adjacent: list[list[tuple[int, float]]] = [
2023-02-19 21:59:03 +01:00
[] for i in range(num_vertex)]
for (x, y, w) in graph.edges:
adjacent[x].append((y, w))
adjacent[y].append((x, w))
all_integer = all_integer and isinstance(w, int)
lines: list[str] = []
lines.append(f"{num_vertex} {num_edge} U")
for x in range(num_vertex):
degree = len(adjacent[x])
lines.append(f"{degree} {x+1} 0 0")
for (y, w) in adjacent[x]:
if all_integer:
lines.append(f"{y+1} {w}")
else:
lines.append(f"{y+1} {w:.12g}")
return "\n".join(lines) + "\n"
def parse_output(self, output_data: str) -> Matching:
pairs: list[tuple[int, int]] = []
for line in output_data.splitlines():
words = line.split()
if len(words) != 2:
raise SolverError("Invalid format in solver output")
try:
x = int(words[0])
y = int(words[1])
except ValueError:
raise SolverError("Invalid format in solver output") from None
2023-03-11 23:21:10 +01:00
if 0 < x < y:
2023-02-19 21:59:03 +01:00
pairs.append((x - 1, y - 1))
return Matching(pairs)
def test_solver_on_graph(
solver: Solver,
graph: Graph,
graph_desc: str,
2024-06-22 20:04:49 +02:00
gold_weight: Optional[float],
2023-02-19 21:59:03 +01:00
num_run: int
) -> RunResult:
"""Test the specified solver with the specified graph."""
solver_run_time: list[float] = []
2024-06-22 20:04:49 +02:00
solver_weight: Optional[float] = None
2023-02-19 21:59:03 +01:00
for i in range(num_run):
print(f"Running {solver} on {graph_desc}, run {i+1}/{num_run} ...")
sys.stdout.flush()
try:
(matching, elapsed) = solver.run(graph)
weight = check_matching(graph, matching)
except SolverError as exc:
print("FAILED:", exc)
return RunResult(status=RunStatus.FAILED)
except InvalidMatchingError as exc:
print("WRONG:", exc)
return RunResult(status=RunStatus.WRONG_ANSWER)
except subprocess.TimeoutExpired as exc:
print("TIMEOUT:", exc)
return RunResult(status=RunStatus.TIMEOUT)
if gold_weight is not None:
weight_diff = compare_weight(weight, gold_weight)
if weight_diff > 0:
raise ValueError(f"IMPOSSIBLE: Solver found weight {weight}"
f" but reference answer is {gold_weight}")
if weight_diff < 0:
print(f"WRONG: matching has weight {weight}"
f" but expected {gold_weight}")
return RunResult(status=RunStatus.WRONG_ANSWER)
elif solver_weight is not None:
weight_diff = compare_weight(weight, solver_weight)
if weight_diff != 0:
print(f"WRONG: matching has weight {weight}"
f" but previous run had weight {solver_weight}")
return RunResult(status=RunStatus.WRONG_ANSWER)
solver_weight = weight
solver_run_time.append(elapsed)
sys.stdout.flush()
assert solver_weight is not None
return RunResult(weight=solver_weight, run_time=solver_run_time)
def test_graph(
solvers: list[Solver],
graph: Graph,
graph_desc: str,
2024-06-22 20:04:49 +02:00
gold_weight: Optional[float],
2023-02-19 21:59:03 +01:00
num_run: int
) -> list[RunResult]:
"""Test all solvers with the specified graph."""
results: list[RunResult] = []
for solver in solvers:
result = test_solver_on_graph(
solver,
graph,
graph_desc,
gold_weight,
num_run)
results.append(result)
if gold_weight is None:
2024-06-22 20:04:49 +02:00
best_weight: Optional[float] = None
2023-02-19 21:59:03 +01:00
for result in results:
if result.status == RunStatus.OK:
if (best_weight is None) or (result.weight > best_weight):
best_weight = result.weight
if best_weight is not None:
for (i, solver) in enumerate(solvers):
result = results[i]
if result.status == RunStatus.OK:
if compare_weight(result.weight, best_weight) != 0:
print(f"WRONG: Solver {solver} found weight"
f" {result.weight} but other solver found"
f" {best_weight}")
results[i] = RunResult(status=RunStatus.WRONG_ANSWER)
return results
def test_random(
solvers: list[Solver],
num_vertex: int,
num_edge: int,
max_weight: int,
float_weights: bool,
num_run: int,
seed: Optional[int]
) -> int:
"""Run a test with randomly generated graphs."""
# Choose a random seed if the user does not provide a seed.
if seed is None:
seed = random.getrandbits(48)
solver_num_fail: list[int] = [0 for solver in solvers]
solver_num_wrong: list[int] = [0 for solver in solvers]
solver_num_timeout: list[int] = [0 for solver in solvers]
solver_run_time: list[list[float]] = [[] for solver in solvers]
solver_errors: list[list[tuple[str, RunStatus]]] = [
[] for solver in solvers]
num_errors = 0
for i in range(num_run):
rng = random.Random(seed + i)
graph_desc = f"graph {i+1}/{num_run} seed={seed+i}"
graph = make_random_graph(
num_vertex,
num_edge,
max_weight,
float_weights,
rng)
results = test_graph(
solvers=solvers,
graph=graph,
graph_desc=graph_desc,
gold_weight=None,
num_run=1)
assert len(results) == len(solvers)
for (j, result) in enumerate(results):
if result.status == RunStatus.FAILED:
solver_num_fail[j] += 1
elif result.status == RunStatus.WRONG_ANSWER:
solver_num_wrong[j] += 1
elif result.status == RunStatus.TIMEOUT:
solver_num_timeout[j] += 1
else:
solver_run_time[j] += result.run_time
if result.status != RunStatus.OK:
num_errors += 1
solver_errors[j].append((graph_desc, result.status))
print("Test finished.")
print()
if num_errors > 0:
print("Failed tests")
print("------------")
for (j, solver) in enumerate(solvers):
for (graph_desc, status) in solver_errors[j]:
print(f"solver: {solver}, {graph_desc},"
f" status: {status.name}")
print()
print("Results")
print("-------")
print()
print("N =", num_vertex)
print("M =", num_edge)
print("float" if float_weights else "integer", "weights;",
"max_weight =", max_weight)
print("runs =", num_run)
print()
for (i, solver) in enumerate(solvers):
print(f" {str(solver):32s} ", end="")
num_fail = solver_num_fail[i]
num_wrong = solver_num_wrong[i]
num_timeout = solver_num_timeout[i]
if (num_fail > 0) or (num_wrong > 0) or (num_timeout > 0):
print(f"nFAIL={num_fail} nWRONG={num_wrong}"
f" nTIMEOUT={num_timeout}")
else:
run_time = solver_run_time[i]
tmin = min(run_time)
tmax = max(run_time)
tavg = sum(run_time) / len(run_time)
print(f"min={tmin:8.2f} max={tmax:8.2f} avg={tavg:8.2f}")
print()
if num_errors == 0:
print("All tests passed")
else:
print(num_errors, "tests failed")
return 0 if (num_errors == 0) else 2
def test_input(
solvers: list[Solver],
files: list[str],
verify: bool,
num_run: int
) -> int:
"""Run a test with graphs from input files."""
result_table: list[list[RunResult]] = []
for filename in files:
try:
2023-03-11 23:21:10 +01:00
with open(filename, "r", encoding="ascii") as f:
2023-02-19 21:59:03 +01:00
graph = read_dimacs_graph(f)
except (OSError, ValueError) as exc:
print(f"ERROR: Can not read graph {filename!r} ({exc})",
file=sys.stderr)
return 1
2024-06-22 20:04:49 +02:00
gold_weight: Optional[float] = None
2023-02-19 21:59:03 +01:00
if verify:
reffile = os.path.splitext(filename)[0] + ".out"
try:
2023-03-11 23:21:10 +01:00
with open(reffile, "r", encoding="ascii") as f:
2023-02-19 21:59:03 +01:00
(gold_weight, _matching) = read_dimacs_matching(f)
except (OSError, ValueError) as exc:
print(f"ERROR: Can not read matching {reffile!r} ({exc})",
file=sys.stderr)
return 1
results = test_graph(solvers, graph, filename, gold_weight, num_run)
result_table.append(results)
print("Test finished.")
print()
print("Results")
print("-------")
print()
errors = 0
for (i, filename) in enumerate(files):
print("Graph:", filename)
for (j, solver) in enumerate(solvers):
result = result_table[i][j]
print(f" {str(solver):32s} ", end="")
if result.status == RunStatus.FAILED:
print("FAILED")
errors += 1
elif result.status == RunStatus.WRONG_ANSWER:
print("WRONG ANSWER")
errors += 1
elif result.status == RunStatus.TIMEOUT:
print("TIMEOUT")
errors += 1
else:
tmin = min(result.run_time)
tmax = max(result.run_time)
tavg = sum(result.run_time) / len(result.run_time)
print(f"min={tmin:8.2f} max={tmax:8.2f} avg={tavg:8.2f}")
print()
if errors == 0:
print("All tests passed")
else:
print(errors, "tests failed")
return 0 if (errors == 0) else 2
def main() -> int:
"""Main program."""
parser = argparse.ArgumentParser()
parser.description = "Run matching solvers on test graphs."
parser.add_argument("--solver",
action="append",
type=str,
metavar="CMD",
help="add a solver with DIMACS input/output"
" (CMD is the command to run the solver)")
parser.add_argument("--solver-wmatch",
action="append",
type=str,
metavar="CMD",
help="add a WMATCH solver"
" (CMD is the solver executable)")
2023-02-19 21:59:03 +01:00
parser.add_argument("--timeout",
action="store",
type=float,
2024-06-22 20:04:49 +02:00
help="abort when solver runs longer than TIMEOUT"
" seconds")
2023-02-19 21:59:03 +01:00
parser.add_argument("--runs",
action="store",
type=int,
metavar="R",
default=1,
help="run R random graphs or run input files R times")
parser.add_argument("--random",
action="store_true",
help="generate random test graphs")
parser.add_argument("--seed",
action="store",
type=int,
help="starting random seed")
parser.add_argument("-n", "--n",
action="store",
type=int,
help="number of vertices for random graphs")
parser.add_argument("-m", "--m",
action="store",
type=int,
help="number of edges for random graphs")
parser.add_argument("--maxweight",
action="store",
type=int,
default=1000000,
metavar="W",
help="maximum edge weight for random graphs")
parser.add_argument("--float",
action="store_true",
help="use floating point weights in random graphs")
parser.add_argument("--verify",
action="store_true",
help="compare answers to existing output files")
parser.add_argument("input",
nargs="*",
help="read test graphs from input files")
args = parser.parse_args()
if (not args.input) and (not args.random):
print("ERROR: Specify at least one input file or --random",
file=sys.stderr)
print(file=sys.stderr)
parser.print_help(sys.stderr)
return 1
if (not args.solver) and (not args.solver_wmatch):
2023-02-19 21:59:03 +01:00
print("ERROR: Specify at least one solver", file=sys.stderr)
return 1
solvers: list[Solver] = []
if args.solver:
for cmd in args.solver:
name = shlex.split(cmd)[0]
if name.endswith("python") or name.endswith("python3"):
name = shlex.split(cmd)[1]
2023-02-19 21:59:03 +01:00
solvers.append(DimacsSolver(name, cmd, args.timeout))
if args.solver_wmatch:
for cmd in args.solver_wmatch:
name = cmd
2023-02-19 21:59:03 +01:00
solvers.append(WmatchSolver(name, cmd, args.timeout))
if args.runs < 1:
print("ERROR: Number of runs must be >= 1", file=sys.stderr)
return 1
if args.random:
if args.input:
print("ERROR: Specify either --random or input files, not both",
file=sys.stderr)
return 1
if args.verify:
print("ERROR: --verify not supported with --random",
file=sys.stderr)
return 1
if not args.n:
print("ERROR: Missing required option --n", file=sys.stderr)
return 1
if not args.m:
print("ERROR: Missing required option --m", file=sys.stderr)
return 1
if args.n < 2:
print("ERROR: Number of vertices must be >= 2", file=sys.stderr)
return 1
if args.m < 1:
print("ERROR: Number of edges must be >= 1", file=sys.stderr)
return 1
if args.m > args.n * (args.n - 1) // 2:
print("ERROR: Too many edges", file=sys.stderr)
return 1
if args.maxweight < 1:
print("ERROR: Maximum weight must be >= 1", file=sys.stderr)
return 1
return test_random(
solvers=solvers,
num_vertex=args.n,
num_edge=args.m,
max_weight=args.maxweight,
float_weights=args.float,
num_run=args.runs,
seed=args.seed)
else:
return test_input(
solvers=solvers,
files=args.input,
verify=args.verify,
num_run=args.runs)
if __name__ == "__main__":
sys.exit(main())