1
0
Fork 0

Clean up verification code

This commit is contained in:
Joris van Rantwijk 2023-05-16 22:19:55 +02:00
parent 250fd4ea94
commit 731b202af3
1 changed files with 253 additions and 151 deletions

View File

@ -18,7 +18,6 @@
#include <stack>
#include <stdexcept>
#include <tuple>
#include <unordered_map>
#include <unordered_set>
#include <utility>
#include <vector>
@ -28,6 +27,9 @@
namespace mwmatching {
/* **************************************************
* ** public definitions **
* ************************************************** */
/** Type representing the unique ID of a vertex. */
using VertexId = unsigned int;
@ -64,6 +66,9 @@ struct Edge
namespace impl {
/* **************************************************
* ** private definitions **
* ************************************************** */
/** Value used to mark an invalid or undefined vertex. */
constexpr VertexId NO_VERTEX = std::numeric_limits<VertexId>::max();
@ -73,6 +78,10 @@ constexpr VertexId NO_VERTEX = std::numeric_limits<VertexId>::max();
enum BlossomLabel { LABEL_NONE = 0, LABEL_S = 1, LABEL_T = 2 };
/* **************************************************
* ** private helper functions **
* ************************************************** */
/** Return a pair of vertices in flipped order. */
inline VertexPair flip_vertex_pair(const VertexPair& vt)
{
@ -166,6 +175,10 @@ std::vector<Edge<WeightType>> remove_negative_weight_edges(
}
/* **************************************************
* ** struct Blossom **
* ************************************************** */
// Forward declaration.
template <typename WeightType> struct NonTrivialBlossom;
@ -212,7 +225,6 @@ struct Blossom
* that links to a different S-blossom, or "nullptr" if no such edge
* has been found.
*/
// TODO : consider storing a copy of the edge instead of pointer
const Edge<WeightType>* best_edge;
protected:
@ -241,9 +253,20 @@ public:
static_cast<NonTrivialBlossom<WeightType>*>(this)
: nullptr);
}
const NonTrivialBlossom<WeightType>* nontrivial() const
{
return (is_nontrivial_blossom ?
static_cast<const NonTrivialBlossom<WeightType>*>(this)
: nullptr);
}
};
/* **************************************************
* ** struct NonTrivialBlossom **
* ************************************************** */
/**
* Represents a non-trivial blossom.
*
@ -291,7 +314,6 @@ struct NonTrivialBlossom : public Blossom<WeightType>
* In case of a top-level S-blossom, "best_edge_set" is a list of
* least-slack edges between this blossom and other S-blossoms.
*/
// TODO : consider storing a copy of the edge instead of pointer
std::list<const Edge<WeightType>*> best_edge_set;
/** Initialize a non-trivial blossom. */
@ -333,6 +355,38 @@ struct NonTrivialBlossom : public Blossom<WeightType>
};
/** Call a function for every vertex inside the specified blossom. */
template <typename WeightType, typename Func>
inline void for_vertices_in_blossom(const Blossom<WeightType>* blossom, Func func)
{
const NonTrivialBlossom<WeightType>* ntb = blossom->nontrivial();
if (ntb) {
// Visit all vertices in the non-trivial blossom.
// Use an explicit stack to avoid deep call chains.
std::vector<const NonTrivialBlossom<WeightType>*> stack;
stack.push_back(ntb);
while (! stack.empty()) {
const NonTrivialBlossom<WeightType>* b = stack.back();
stack.pop_back();
for (const auto& sub : b->subblossoms) {
ntb = sub.blossom->nontrivial();
if (ntb) {
stack.push_back(ntb);
} else {
func(sub.blossom->base_vertex);
}
}
}
} else {
// A trivial blossom contains just one vertex.
func(blossom->base_vertex);
}
}
/**
* Represents a list of edges forming an alternating path or
* an alternating cycle.
@ -343,11 +397,14 @@ struct NonTrivialBlossom : public Blossom<WeightType>
*/
struct AlternatingPath
{
// TODO : consider avoiding dynamic allocations for the alternating path
std::deque<VertexPair> edges;
};
/* **************************************************
* ** struct MatchingContext **
* ************************************************** */
/**
* This class holds all data used by the matching algorithm.
*
@ -388,7 +445,6 @@ struct MatchingContext
const VertexId num_vertex;
/** For each vertex, a vector of pointers to its incident edges. */
// TODO : consider copying the Edge instead of pointers
const std::vector<std::vector<const EdgeT*>> adjacent_edges;
/**
@ -397,7 +453,6 @@ struct MatchingContext
* vertex_mate[x] == y if vertex "x" is matched to vertex "y".
* vertex_mate[x] == NO_VERTEX if vertex "x" is unmatched.
*/
// TODO : consider merging all per-vertex info into a single "struct Vertex"
std::vector<VertexId> vertex_mate;
/**
@ -436,7 +491,6 @@ struct MatchingContext
* "vertex_best_edge[x]" is the least-slack edge to any S-vertex,
* or "nullptr" if no such edge has been found.
*/
// TODO - consider storing a copy of the edge instead of a pointer
std::vector<const EdgeT*> vertex_best_edge;
/** Queue of S-vertices to be scanned. */
@ -558,37 +612,6 @@ struct MatchingContext
return vertex_dual[x] + vertex_dual[y] - weight_factor * edge.weight;
}
/** Call a function for every vertex inside the specified blossom. */
template <typename Func>
void for_vertices_in_blossom(BlossomT* blossom, Func func)
{
NonTrivialBlossomT* ntb = blossom->nontrivial();
if (ntb) {
// Visit all vertices in the non-trivial blossom.
// Use an explicit stack to avoid deep call chains.
std::vector<NonTrivialBlossomT*> stack;
stack.push_back(ntb);
while (! stack.empty()) {
NonTrivialBlossomT* b = stack.back();
stack.pop_back();
for (const auto& sub : b->subblossoms) {
ntb = sub.blossom->nontrivial();
if (ntb) {
stack.push_back(ntb);
} else {
func(sub.blossom->base_vertex);
}
}
}
} else {
// A trivial blossom contains just one vertex.
func(blossom->base_vertex);
}
}
/*
* Least-slack edge tracking:
*
@ -820,7 +843,6 @@ struct MatchingContext
const EdgeT* best_edge = nullptr;
WeightType best_slack = 0;
// TODO - clean this up when we have a linked list of top-level blossoms
auto consider_blossom =
[this,&best_edge,&best_slack](BlossomT* blossom)
{
@ -1317,7 +1339,7 @@ struct MatchingContext
{
// Assign label S to the blossom that contains vertex "x".
BlossomT* bx = vertex_top_blossom[x];
assert(bx->label == LABEL_NONE); // TODO - this assertion sometimes fails
assert(bx->label == LABEL_NONE);
bx->label = LABEL_S;
VertexId y = vertex_mate[x];
@ -1718,8 +1740,103 @@ struct MatchingContext
// Each iteration takes time O(n**2).
while (run_stage()) ;
}
};
/* ********** Verify optimal solution: ********** */
/* **************************************************
* ** struct MatchingVerifier **
* ************************************************** */
/** Helper class to verify that an optimal solution has been found. */
template <typename WeightType>
struct MatchingVerifier
{
using EdgeT = Edge<WeightType>;
using BlossomT = Blossom<WeightType>;
using NonTrivialBlossomT = NonTrivialBlossom<WeightType>;
// Reference to the MatchingContext instance.
const MatchingContext<WeightType>& ctx;
// For each edge, the sum of duals of its incident vertices
// and duals of all blossoms that contain the edge.
std::vector<WeightType> edge_duals;
MatchingVerifier(const MatchingContext<WeightType>& ctx)
: ctx(ctx),
edge_duals(ctx.edges.size())
{ }
static bool checked_add(WeightType& result, WeightType a, WeightType b)
{
if (a > std::numeric_limits<WeightType>::max() - b) {
return true;
} else {
result = a + b;
return false;
}
}
/** Convert edge pointer to its index in the vector "edges". */
std::size_t edge_index(const EdgeT* edge)
{
return edge - ctx.edges.data();
}
/** Check that the array "vertex_mate" is consistent. */
bool verify_vertex_mate()
{
// Count matched vertices and check symmetry of "vertex_mate".
VertexId num_matched_vertex = 0;
for (VertexId x = 0; x < ctx.num_vertex; ++x) {
VertexId y = ctx.vertex_mate[x];
if (y != NO_VERTEX) {
++num_matched_vertex;
if (ctx.vertex_mate[y] != x) {
return false;
}
}
}
// Count matched edges.
VertexId num_matched_edge = 0;
for (const EdgeT& edge : ctx.edges) {
if (ctx.vertex_mate[edge.vt.first] == edge.vt.second) {
++num_matched_edge;
}
}
// Check that all matched vertices correspond to matched edges.
return (num_matched_vertex == 2 * num_matched_edge);
}
/**
* Check that vertex dual variables are non-negative,
* and all unmatched vertices have zero dual.
*/
bool verify_vertex_duals()
{
for (VertexId x = 0; x < ctx.num_vertex; ++x) {
if (ctx.vertex_dual[x] < 0) {
return false;
}
if ((ctx.vertex_mate[x] == NO_VERTEX) && (ctx.vertex_dual[x] != 0)) {
return false;
}
}
return true;
}
/** Check that blossom dual variables are non-negative. */
bool verify_blossom_duals()
{
for (const NonTrivialBlossomT& blossom : ctx.nontrivial_blossom) {
if (blossom.dual_var < 0) {
return false;
}
}
return true;
}
/**
* Helper function for verifying the solution.
@ -1727,31 +1844,24 @@ struct MatchingContext
* Descend down the blossom tree to find edges that are contained
* in blossoms.
*
* Adjust the slack of all contained edges to account for the dual
* variables of its containing blossoms.
*
* On the way down, keep track of the sum of dual variables of
* the containing blossoms.
* containing blossoms. Add blossom duals to edges that are contained
* inside blossoms.
*
* On the way up, keep track of the total number of matched edges
* in the subblossoms. Then check that all blossoms with non-zero
* dual variable are "full".
* in subblossoms. Check that all blossoms with non-zero dual variables
* are "full".
*
* @return True if successful;
* false if a blossom with non-zero dual is not full.
* false if a blossom with non-zero dual is not full;
* false if blossom dual calculations cause numeric overflow.
*/
bool verify_blossom_edges(NonTrivialBlossomT* blossom,
std::unordered_map<VertexPair, WeightType, boost::hash<VertexPair>>& edge_duals)
bool check_blossom(const NonTrivialBlossomT* blossom)
{
// TODO : fix line length
// TODO : simplify code
// TODO : optimize
// TODO : try to think of a simpler way
// For each vertex "x",
// "vertex_depth[x]" is the depth of the smallest blossom on
// the current descent path that contains "x".
std::vector<VertexId> vertex_depth(num_vertex);
std::vector<VertexId> vertex_depth(ctx.num_vertex);
// At each depth, keep track of the sum of blossom duals
// along the current descent path.
@ -1762,7 +1872,9 @@ struct MatchingContext
std::vector<VertexId> path_num_matched = {0};
// Use an explicit stack to avoid deep recursion.
std::stack<std::pair<NonTrivialBlossomT*, typename std::list<typename NonTrivialBlossomT::SubBlossom>::iterator>> stack;
using SubBlossomList = std::list<typename NonTrivialBlossomT::SubBlossom>;
std::stack<std::pair<const NonTrivialBlossomT*,
typename SubBlossomList::const_iterator>> stack;
stack.emplace(blossom, blossom->subblossoms.begin());
while (! stack.empty()) {
@ -1781,8 +1893,12 @@ struct MatchingContext
});
// Calculate the sum of blossom duals at the new depth.
path_sum_dual.push_back(path_sum_dual.back()
+ blossom->dual_var);
path_sum_dual.push_back(path_sum_dual.back());
if (checked_add(path_sum_dual.back(),
path_sum_dual.back(),
blossom->dual_var)) {
return false;
}
// Initialize the number of matched edges at the new depth.
path_num_matched.push_back(0);
@ -1808,18 +1924,22 @@ struct MatchingContext
// For each incident edge, find the smallest blossom
// that contains it.
VertexId x = sub->base_vertex;
for (const EdgeT* edge : adjacent_edges[x]) {
for (const EdgeT* edge : ctx.adjacent_edges[x]) {
// Only consider edges pointing out from "x".
if (edge->vt.first == x) {
VertexId y = edge->vt.second;
VertexId edge_depth = vertex_depth[y];
if (edge_depth > 0) {
// This edge is contained in a blossom.
// Update its slack.
edge_duals[edge->vt] += path_sum_dual[edge_depth];
// Found the smallest blossom that contains this edge.
// Add the duals of the containing blossoms.
if (checked_add(edge_duals[edge_index(edge)],
edge_duals[edge_index(edge)],
path_sum_dual[edge_depth])) {
return false;
}
// Update the number of matched edges in the blossom.
if (vertex_mate[x] == y) {
if (ctx.vertex_mate[x] == y) {
path_num_matched[edge_depth] += 1;
}
}
@ -1866,6 +1986,62 @@ struct MatchingContext
return true;
}
/**
* Check that all blossoms are full.
* Also calculate the sum of dual variables for every edge.
*/
bool verify_blossoms_and_calc_edge_duals()
{
// For each edge, calculate the sum of its vertex duals.
for (const EdgeT& edge : ctx.edges) {
if (checked_add(edge_duals[edge_index(&edge)],
ctx.vertex_dual[edge.vt.first],
ctx.vertex_dual[edge.vt.second])) {
return false;
}
}
// Descend down each top-level blossom.
// Check that blossoms are full.
// Add blossom duals to the edges contained inside the blossoms.
// This takes total time O(n**2).
for (const NonTrivialBlossomT& blossom : ctx.nontrivial_blossom) {
if (blossom.parent == nullptr) {
if (!check_blossom(&blossom)) {
return false;
}
}
}
return true;
}
/**
* Check that all edges have non-negative slack,
* and check that all matched edges have zero slack.
*
* @pre Edge duals must be calculated before calling this function.
*/
bool verify_edge_slack()
{
for (const EdgeT& edge : ctx.edges) {
WeightType duals = edge_duals[edge_index(&edge)];
WeightType weight = ctx.weight_factor * edge.weight;
if (weight > duals) {
return false;
}
WeightType slack = duals - weight;
if (ctx.vertex_mate[edge.vt.first] == edge.vt.second) {
if (slack != 0) {
return false;
}
}
}
return true;
}
/**
* Verify that the optimum solution has been found.
*
@ -1873,91 +2049,13 @@ struct MatchingContext
*
* @return True if the solution is optimal; otherwise false.
*/
bool verify_optimum()
bool verify()
{
// Count matched vertices and check symmetry of "vertex_mate".
VertexId num_matched_vertex = 0;
for (VertexId x = 0; x < num_vertex; ++x) {
VertexId y = vertex_mate[x];
if (y != NO_VERTEX) {
++num_matched_vertex;
if (vertex_mate[y] != x) {
return false;
}
}
}
// Check that all matched edges exist in the graph.
VertexId num_matched_edge = 0;
for (const EdgeT& edge : edges) {
if (vertex_mate[edge.vt.first] == edge.vt.second) {
++num_matched_edge;
}
}
if (num_matched_vertex != 2 * num_matched_edge) {
return false;
}
// Check that all dual variables are non-negative.
for (VertexId x = 0; x < num_vertex; ++x) {
if (vertex_dual[x] < 0) {
return false;
}
}
for (NonTrivialBlossomT& blossom : nontrivial_blossom) {
if (blossom.dual_var < 0) {
return false;
}
}
// Check that all unmatched vertices have zero dual.
for (VertexId x = 0; x < num_vertex; ++x) {
if ((vertex_mate[x] == NO_VERTEX) && (vertex_dual[x] != 0)) {
return false;
}
}
// Calculate the slack of each edge.
// A correction will be needed for edges inside blossoms.
std::unordered_map<VertexPair,
WeightType,
boost::hash<VertexPair>> edge_duals;
for (const EdgeT& edge : edges) {
// TODO : consider potential numeric overflow
edge_duals[edge.vt] = vertex_dual[edge.vt.first]
+ vertex_dual[edge.vt.second];
}
// Descend down each top-level blossom.
// Adjust edge slacks to account for blossom dual.
// Also check that all blossoms are full.
// This takes total time O(n**2).
for (NonTrivialBlossomT& blossom : nontrivial_blossom) {
if (blossom.parent == nullptr) {
verify_blossom_edges(&blossom, edge_duals);
}
}
// Check that all edges have non-negative slack.
// Check that all matched edges have zero slack.
// TODO : fix potential numeric overflow - WeightType may be unsigned
for (const EdgeT& edge : edges) {
WeightType slack = edge_duals[edge.vt]
- weight_factor * edge.weight;
if (slack < 0) {
return false;
}
if (vertex_mate[edge.vt.first] == edge.vt.second) {
if (slack != 0) {
return false;
}
}
}
// Optimum solution confirmed.
return true;
return (verify_vertex_mate()
&& verify_vertex_duals()
&& verify_blossom_duals()
&& verify_blossoms_and_calc_edge_duals()
&& verify_edge_slack());
}
};
@ -1965,6 +2063,10 @@ struct MatchingContext
} // namespace impl
/* **************************************************
* ** public functions **
* ************************************************** */
/**
* Compute a maximum-weighted matching in the general undirected weighted
* graph given by "edges".
@ -1989,7 +2091,7 @@ struct MatchingContext
*
* @tparam WeightType Type used to represent edge weights.
* For example "long" or "double".
* @param edges List of weighted edges defining the graph.
* @param edges Graph defined as a vector of weighted edges.
*
* @return Vector of pairs of matched vertex indices.
* This is a subset of the edges in the graph.
@ -2009,7 +2111,7 @@ std::vector<VertexPair> maximum_weight_matching(
// Verify that the solution is optimal (works only for integer weights).
if (std::numeric_limits<WeightType>::is_integer) {
assert(matching.verify_optimum());
assert(impl::MatchingVerifier<WeightType>(matching).verify());
}
// Extract the matched edges.
@ -2050,7 +2152,7 @@ std::vector<VertexPair> maximum_weight_matching(
* This function takes time O(m), where "m" is the number of edges.
*
* @tparam WeightType Type used to represent edge weights.
* @param edges Vector of weighted edges defining the graph.
* @param edges Graph defined as a vector of weighted edges.
*
* @return Vector of edges with adjusted weights.
* If no adjustments are necessary, this will be a copy of the