1
0
Fork 0

Clean up verification code

This commit is contained in:
Joris van Rantwijk 2023-05-16 22:19:55 +02:00
parent 250fd4ea94
commit 731b202af3
1 changed files with 253 additions and 151 deletions

View File

@ -18,7 +18,6 @@
#include <stack> #include <stack>
#include <stdexcept> #include <stdexcept>
#include <tuple> #include <tuple>
#include <unordered_map>
#include <unordered_set> #include <unordered_set>
#include <utility> #include <utility>
#include <vector> #include <vector>
@ -28,6 +27,9 @@
namespace mwmatching { namespace mwmatching {
/* **************************************************
* ** public definitions **
* ************************************************** */
/** Type representing the unique ID of a vertex. */ /** Type representing the unique ID of a vertex. */
using VertexId = unsigned int; using VertexId = unsigned int;
@ -64,6 +66,9 @@ struct Edge
namespace impl { namespace impl {
/* **************************************************
* ** private definitions **
* ************************************************** */
/** Value used to mark an invalid or undefined vertex. */ /** Value used to mark an invalid or undefined vertex. */
constexpr VertexId NO_VERTEX = std::numeric_limits<VertexId>::max(); constexpr VertexId NO_VERTEX = std::numeric_limits<VertexId>::max();
@ -73,6 +78,10 @@ constexpr VertexId NO_VERTEX = std::numeric_limits<VertexId>::max();
enum BlossomLabel { LABEL_NONE = 0, LABEL_S = 1, LABEL_T = 2 }; enum BlossomLabel { LABEL_NONE = 0, LABEL_S = 1, LABEL_T = 2 };
/* **************************************************
* ** private helper functions **
* ************************************************** */
/** Return a pair of vertices in flipped order. */ /** Return a pair of vertices in flipped order. */
inline VertexPair flip_vertex_pair(const VertexPair& vt) inline VertexPair flip_vertex_pair(const VertexPair& vt)
{ {
@ -166,6 +175,10 @@ std::vector<Edge<WeightType>> remove_negative_weight_edges(
} }
/* **************************************************
* ** struct Blossom **
* ************************************************** */
// Forward declaration. // Forward declaration.
template <typename WeightType> struct NonTrivialBlossom; template <typename WeightType> struct NonTrivialBlossom;
@ -212,7 +225,6 @@ struct Blossom
* that links to a different S-blossom, or "nullptr" if no such edge * that links to a different S-blossom, or "nullptr" if no such edge
* has been found. * has been found.
*/ */
// TODO : consider storing a copy of the edge instead of pointer
const Edge<WeightType>* best_edge; const Edge<WeightType>* best_edge;
protected: protected:
@ -241,9 +253,20 @@ public:
static_cast<NonTrivialBlossom<WeightType>*>(this) static_cast<NonTrivialBlossom<WeightType>*>(this)
: nullptr); : nullptr);
} }
const NonTrivialBlossom<WeightType>* nontrivial() const
{
return (is_nontrivial_blossom ?
static_cast<const NonTrivialBlossom<WeightType>*>(this)
: nullptr);
}
}; };
/* **************************************************
* ** struct NonTrivialBlossom **
* ************************************************** */
/** /**
* Represents a non-trivial blossom. * Represents a non-trivial blossom.
* *
@ -291,7 +314,6 @@ struct NonTrivialBlossom : public Blossom<WeightType>
* In case of a top-level S-blossom, "best_edge_set" is a list of * In case of a top-level S-blossom, "best_edge_set" is a list of
* least-slack edges between this blossom and other S-blossoms. * least-slack edges between this blossom and other S-blossoms.
*/ */
// TODO : consider storing a copy of the edge instead of pointer
std::list<const Edge<WeightType>*> best_edge_set; std::list<const Edge<WeightType>*> best_edge_set;
/** Initialize a non-trivial blossom. */ /** Initialize a non-trivial blossom. */
@ -333,6 +355,38 @@ struct NonTrivialBlossom : public Blossom<WeightType>
}; };
/** Call a function for every vertex inside the specified blossom. */
template <typename WeightType, typename Func>
inline void for_vertices_in_blossom(const Blossom<WeightType>* blossom, Func func)
{
const NonTrivialBlossom<WeightType>* ntb = blossom->nontrivial();
if (ntb) {
// Visit all vertices in the non-trivial blossom.
// Use an explicit stack to avoid deep call chains.
std::vector<const NonTrivialBlossom<WeightType>*> stack;
stack.push_back(ntb);
while (! stack.empty()) {
const NonTrivialBlossom<WeightType>* b = stack.back();
stack.pop_back();
for (const auto& sub : b->subblossoms) {
ntb = sub.blossom->nontrivial();
if (ntb) {
stack.push_back(ntb);
} else {
func(sub.blossom->base_vertex);
}
}
}
} else {
// A trivial blossom contains just one vertex.
func(blossom->base_vertex);
}
}
/** /**
* Represents a list of edges forming an alternating path or * Represents a list of edges forming an alternating path or
* an alternating cycle. * an alternating cycle.
@ -343,11 +397,14 @@ struct NonTrivialBlossom : public Blossom<WeightType>
*/ */
struct AlternatingPath struct AlternatingPath
{ {
// TODO : consider avoiding dynamic allocations for the alternating path
std::deque<VertexPair> edges; std::deque<VertexPair> edges;
}; };
/* **************************************************
* ** struct MatchingContext **
* ************************************************** */
/** /**
* This class holds all data used by the matching algorithm. * This class holds all data used by the matching algorithm.
* *
@ -388,7 +445,6 @@ struct MatchingContext
const VertexId num_vertex; const VertexId num_vertex;
/** For each vertex, a vector of pointers to its incident edges. */ /** For each vertex, a vector of pointers to its incident edges. */
// TODO : consider copying the Edge instead of pointers
const std::vector<std::vector<const EdgeT*>> adjacent_edges; const std::vector<std::vector<const EdgeT*>> adjacent_edges;
/** /**
@ -397,7 +453,6 @@ struct MatchingContext
* vertex_mate[x] == y if vertex "x" is matched to vertex "y". * vertex_mate[x] == y if vertex "x" is matched to vertex "y".
* vertex_mate[x] == NO_VERTEX if vertex "x" is unmatched. * vertex_mate[x] == NO_VERTEX if vertex "x" is unmatched.
*/ */
// TODO : consider merging all per-vertex info into a single "struct Vertex"
std::vector<VertexId> vertex_mate; std::vector<VertexId> vertex_mate;
/** /**
@ -436,7 +491,6 @@ struct MatchingContext
* "vertex_best_edge[x]" is the least-slack edge to any S-vertex, * "vertex_best_edge[x]" is the least-slack edge to any S-vertex,
* or "nullptr" if no such edge has been found. * or "nullptr" if no such edge has been found.
*/ */
// TODO - consider storing a copy of the edge instead of a pointer
std::vector<const EdgeT*> vertex_best_edge; std::vector<const EdgeT*> vertex_best_edge;
/** Queue of S-vertices to be scanned. */ /** Queue of S-vertices to be scanned. */
@ -558,37 +612,6 @@ struct MatchingContext
return vertex_dual[x] + vertex_dual[y] - weight_factor * edge.weight; return vertex_dual[x] + vertex_dual[y] - weight_factor * edge.weight;
} }
/** Call a function for every vertex inside the specified blossom. */
template <typename Func>
void for_vertices_in_blossom(BlossomT* blossom, Func func)
{
NonTrivialBlossomT* ntb = blossom->nontrivial();
if (ntb) {
// Visit all vertices in the non-trivial blossom.
// Use an explicit stack to avoid deep call chains.
std::vector<NonTrivialBlossomT*> stack;
stack.push_back(ntb);
while (! stack.empty()) {
NonTrivialBlossomT* b = stack.back();
stack.pop_back();
for (const auto& sub : b->subblossoms) {
ntb = sub.blossom->nontrivial();
if (ntb) {
stack.push_back(ntb);
} else {
func(sub.blossom->base_vertex);
}
}
}
} else {
// A trivial blossom contains just one vertex.
func(blossom->base_vertex);
}
}
/* /*
* Least-slack edge tracking: * Least-slack edge tracking:
* *
@ -820,7 +843,6 @@ struct MatchingContext
const EdgeT* best_edge = nullptr; const EdgeT* best_edge = nullptr;
WeightType best_slack = 0; WeightType best_slack = 0;
// TODO - clean this up when we have a linked list of top-level blossoms
auto consider_blossom = auto consider_blossom =
[this,&best_edge,&best_slack](BlossomT* blossom) [this,&best_edge,&best_slack](BlossomT* blossom)
{ {
@ -1317,7 +1339,7 @@ struct MatchingContext
{ {
// Assign label S to the blossom that contains vertex "x". // Assign label S to the blossom that contains vertex "x".
BlossomT* bx = vertex_top_blossom[x]; BlossomT* bx = vertex_top_blossom[x];
assert(bx->label == LABEL_NONE); // TODO - this assertion sometimes fails assert(bx->label == LABEL_NONE);
bx->label = LABEL_S; bx->label = LABEL_S;
VertexId y = vertex_mate[x]; VertexId y = vertex_mate[x];
@ -1718,8 +1740,103 @@ struct MatchingContext
// Each iteration takes time O(n**2). // Each iteration takes time O(n**2).
while (run_stage()) ; while (run_stage()) ;
} }
};
/* ********** Verify optimal solution: ********** */
/* **************************************************
* ** struct MatchingVerifier **
* ************************************************** */
/** Helper class to verify that an optimal solution has been found. */
template <typename WeightType>
struct MatchingVerifier
{
using EdgeT = Edge<WeightType>;
using BlossomT = Blossom<WeightType>;
using NonTrivialBlossomT = NonTrivialBlossom<WeightType>;
// Reference to the MatchingContext instance.
const MatchingContext<WeightType>& ctx;
// For each edge, the sum of duals of its incident vertices
// and duals of all blossoms that contain the edge.
std::vector<WeightType> edge_duals;
MatchingVerifier(const MatchingContext<WeightType>& ctx)
: ctx(ctx),
edge_duals(ctx.edges.size())
{ }
static bool checked_add(WeightType& result, WeightType a, WeightType b)
{
if (a > std::numeric_limits<WeightType>::max() - b) {
return true;
} else {
result = a + b;
return false;
}
}
/** Convert edge pointer to its index in the vector "edges". */
std::size_t edge_index(const EdgeT* edge)
{
return edge - ctx.edges.data();
}
/** Check that the array "vertex_mate" is consistent. */
bool verify_vertex_mate()
{
// Count matched vertices and check symmetry of "vertex_mate".
VertexId num_matched_vertex = 0;
for (VertexId x = 0; x < ctx.num_vertex; ++x) {
VertexId y = ctx.vertex_mate[x];
if (y != NO_VERTEX) {
++num_matched_vertex;
if (ctx.vertex_mate[y] != x) {
return false;
}
}
}
// Count matched edges.
VertexId num_matched_edge = 0;
for (const EdgeT& edge : ctx.edges) {
if (ctx.vertex_mate[edge.vt.first] == edge.vt.second) {
++num_matched_edge;
}
}
// Check that all matched vertices correspond to matched edges.
return (num_matched_vertex == 2 * num_matched_edge);
}
/**
* Check that vertex dual variables are non-negative,
* and all unmatched vertices have zero dual.
*/
bool verify_vertex_duals()
{
for (VertexId x = 0; x < ctx.num_vertex; ++x) {
if (ctx.vertex_dual[x] < 0) {
return false;
}
if ((ctx.vertex_mate[x] == NO_VERTEX) && (ctx.vertex_dual[x] != 0)) {
return false;
}
}
return true;
}
/** Check that blossom dual variables are non-negative. */
bool verify_blossom_duals()
{
for (const NonTrivialBlossomT& blossom : ctx.nontrivial_blossom) {
if (blossom.dual_var < 0) {
return false;
}
}
return true;
}
/** /**
* Helper function for verifying the solution. * Helper function for verifying the solution.
@ -1727,31 +1844,24 @@ struct MatchingContext
* Descend down the blossom tree to find edges that are contained * Descend down the blossom tree to find edges that are contained
* in blossoms. * in blossoms.
* *
* Adjust the slack of all contained edges to account for the dual
* variables of its containing blossoms.
*
* On the way down, keep track of the sum of dual variables of * On the way down, keep track of the sum of dual variables of
* the containing blossoms. * containing blossoms. Add blossom duals to edges that are contained
* inside blossoms.
* *
* On the way up, keep track of the total number of matched edges * On the way up, keep track of the total number of matched edges
* in the subblossoms. Then check that all blossoms with non-zero * in subblossoms. Check that all blossoms with non-zero dual variables
* dual variable are "full". * are "full".
* *
* @return True if successful; * @return True if successful;
* false if a blossom with non-zero dual is not full. * false if a blossom with non-zero dual is not full;
* false if blossom dual calculations cause numeric overflow.
*/ */
bool verify_blossom_edges(NonTrivialBlossomT* blossom, bool check_blossom(const NonTrivialBlossomT* blossom)
std::unordered_map<VertexPair, WeightType, boost::hash<VertexPair>>& edge_duals)
{ {
// TODO : fix line length
// TODO : simplify code
// TODO : optimize
// TODO : try to think of a simpler way
// For each vertex "x", // For each vertex "x",
// "vertex_depth[x]" is the depth of the smallest blossom on // "vertex_depth[x]" is the depth of the smallest blossom on
// the current descent path that contains "x". // the current descent path that contains "x".
std::vector<VertexId> vertex_depth(num_vertex); std::vector<VertexId> vertex_depth(ctx.num_vertex);
// At each depth, keep track of the sum of blossom duals // At each depth, keep track of the sum of blossom duals
// along the current descent path. // along the current descent path.
@ -1762,7 +1872,9 @@ struct MatchingContext
std::vector<VertexId> path_num_matched = {0}; std::vector<VertexId> path_num_matched = {0};
// Use an explicit stack to avoid deep recursion. // Use an explicit stack to avoid deep recursion.
std::stack<std::pair<NonTrivialBlossomT*, typename std::list<typename NonTrivialBlossomT::SubBlossom>::iterator>> stack; using SubBlossomList = std::list<typename NonTrivialBlossomT::SubBlossom>;
std::stack<std::pair<const NonTrivialBlossomT*,
typename SubBlossomList::const_iterator>> stack;
stack.emplace(blossom, blossom->subblossoms.begin()); stack.emplace(blossom, blossom->subblossoms.begin());
while (! stack.empty()) { while (! stack.empty()) {
@ -1781,8 +1893,12 @@ struct MatchingContext
}); });
// Calculate the sum of blossom duals at the new depth. // Calculate the sum of blossom duals at the new depth.
path_sum_dual.push_back(path_sum_dual.back() path_sum_dual.push_back(path_sum_dual.back());
+ blossom->dual_var); if (checked_add(path_sum_dual.back(),
path_sum_dual.back(),
blossom->dual_var)) {
return false;
}
// Initialize the number of matched edges at the new depth. // Initialize the number of matched edges at the new depth.
path_num_matched.push_back(0); path_num_matched.push_back(0);
@ -1808,18 +1924,22 @@ struct MatchingContext
// For each incident edge, find the smallest blossom // For each incident edge, find the smallest blossom
// that contains it. // that contains it.
VertexId x = sub->base_vertex; VertexId x = sub->base_vertex;
for (const EdgeT* edge : adjacent_edges[x]) { for (const EdgeT* edge : ctx.adjacent_edges[x]) {
// Only consider edges pointing out from "x". // Only consider edges pointing out from "x".
if (edge->vt.first == x) { if (edge->vt.first == x) {
VertexId y = edge->vt.second; VertexId y = edge->vt.second;
VertexId edge_depth = vertex_depth[y]; VertexId edge_depth = vertex_depth[y];
if (edge_depth > 0) { if (edge_depth > 0) {
// This edge is contained in a blossom. // Found the smallest blossom that contains this edge.
// Update its slack. // Add the duals of the containing blossoms.
edge_duals[edge->vt] += path_sum_dual[edge_depth]; if (checked_add(edge_duals[edge_index(edge)],
edge_duals[edge_index(edge)],
path_sum_dual[edge_depth])) {
return false;
}
// Update the number of matched edges in the blossom. // Update the number of matched edges in the blossom.
if (vertex_mate[x] == y) { if (ctx.vertex_mate[x] == y) {
path_num_matched[edge_depth] += 1; path_num_matched[edge_depth] += 1;
} }
} }
@ -1866,6 +1986,62 @@ struct MatchingContext
return true; return true;
} }
/**
* Check that all blossoms are full.
* Also calculate the sum of dual variables for every edge.
*/
bool verify_blossoms_and_calc_edge_duals()
{
// For each edge, calculate the sum of its vertex duals.
for (const EdgeT& edge : ctx.edges) {
if (checked_add(edge_duals[edge_index(&edge)],
ctx.vertex_dual[edge.vt.first],
ctx.vertex_dual[edge.vt.second])) {
return false;
}
}
// Descend down each top-level blossom.
// Check that blossoms are full.
// Add blossom duals to the edges contained inside the blossoms.
// This takes total time O(n**2).
for (const NonTrivialBlossomT& blossom : ctx.nontrivial_blossom) {
if (blossom.parent == nullptr) {
if (!check_blossom(&blossom)) {
return false;
}
}
}
return true;
}
/**
* Check that all edges have non-negative slack,
* and check that all matched edges have zero slack.
*
* @pre Edge duals must be calculated before calling this function.
*/
bool verify_edge_slack()
{
for (const EdgeT& edge : ctx.edges) {
WeightType duals = edge_duals[edge_index(&edge)];
WeightType weight = ctx.weight_factor * edge.weight;
if (weight > duals) {
return false;
}
WeightType slack = duals - weight;
if (ctx.vertex_mate[edge.vt.first] == edge.vt.second) {
if (slack != 0) {
return false;
}
}
}
return true;
}
/** /**
* Verify that the optimum solution has been found. * Verify that the optimum solution has been found.
* *
@ -1873,91 +2049,13 @@ struct MatchingContext
* *
* @return True if the solution is optimal; otherwise false. * @return True if the solution is optimal; otherwise false.
*/ */
bool verify_optimum() bool verify()
{ {
// Count matched vertices and check symmetry of "vertex_mate". return (verify_vertex_mate()
VertexId num_matched_vertex = 0; && verify_vertex_duals()
for (VertexId x = 0; x < num_vertex; ++x) { && verify_blossom_duals()
VertexId y = vertex_mate[x]; && verify_blossoms_and_calc_edge_duals()
if (y != NO_VERTEX) { && verify_edge_slack());
++num_matched_vertex;
if (vertex_mate[y] != x) {
return false;
}
}
}
// Check that all matched edges exist in the graph.
VertexId num_matched_edge = 0;
for (const EdgeT& edge : edges) {
if (vertex_mate[edge.vt.first] == edge.vt.second) {
++num_matched_edge;
}
}
if (num_matched_vertex != 2 * num_matched_edge) {
return false;
}
// Check that all dual variables are non-negative.
for (VertexId x = 0; x < num_vertex; ++x) {
if (vertex_dual[x] < 0) {
return false;
}
}
for (NonTrivialBlossomT& blossom : nontrivial_blossom) {
if (blossom.dual_var < 0) {
return false;
}
}
// Check that all unmatched vertices have zero dual.
for (VertexId x = 0; x < num_vertex; ++x) {
if ((vertex_mate[x] == NO_VERTEX) && (vertex_dual[x] != 0)) {
return false;
}
}
// Calculate the slack of each edge.
// A correction will be needed for edges inside blossoms.
std::unordered_map<VertexPair,
WeightType,
boost::hash<VertexPair>> edge_duals;
for (const EdgeT& edge : edges) {
// TODO : consider potential numeric overflow
edge_duals[edge.vt] = vertex_dual[edge.vt.first]
+ vertex_dual[edge.vt.second];
}
// Descend down each top-level blossom.
// Adjust edge slacks to account for blossom dual.
// Also check that all blossoms are full.
// This takes total time O(n**2).
for (NonTrivialBlossomT& blossom : nontrivial_blossom) {
if (blossom.parent == nullptr) {
verify_blossom_edges(&blossom, edge_duals);
}
}
// Check that all edges have non-negative slack.
// Check that all matched edges have zero slack.
// TODO : fix potential numeric overflow - WeightType may be unsigned
for (const EdgeT& edge : edges) {
WeightType slack = edge_duals[edge.vt]
- weight_factor * edge.weight;
if (slack < 0) {
return false;
}
if (vertex_mate[edge.vt.first] == edge.vt.second) {
if (slack != 0) {
return false;
}
}
}
// Optimum solution confirmed.
return true;
} }
}; };
@ -1965,6 +2063,10 @@ struct MatchingContext
} // namespace impl } // namespace impl
/* **************************************************
* ** public functions **
* ************************************************** */
/** /**
* Compute a maximum-weighted matching in the general undirected weighted * Compute a maximum-weighted matching in the general undirected weighted
* graph given by "edges". * graph given by "edges".
@ -1989,7 +2091,7 @@ struct MatchingContext
* *
* @tparam WeightType Type used to represent edge weights. * @tparam WeightType Type used to represent edge weights.
* For example "long" or "double". * For example "long" or "double".
* @param edges List of weighted edges defining the graph. * @param edges Graph defined as a vector of weighted edges.
* *
* @return Vector of pairs of matched vertex indices. * @return Vector of pairs of matched vertex indices.
* This is a subset of the edges in the graph. * This is a subset of the edges in the graph.
@ -2009,7 +2111,7 @@ std::vector<VertexPair> maximum_weight_matching(
// Verify that the solution is optimal (works only for integer weights). // Verify that the solution is optimal (works only for integer weights).
if (std::numeric_limits<WeightType>::is_integer) { if (std::numeric_limits<WeightType>::is_integer) {
assert(matching.verify_optimum()); assert(impl::MatchingVerifier<WeightType>(matching).verify());
} }
// Extract the matched edges. // Extract the matched edges.
@ -2050,7 +2152,7 @@ std::vector<VertexPair> maximum_weight_matching(
* This function takes time O(m), where "m" is the number of edges. * This function takes time O(m), where "m" is the number of edges.
* *
* @tparam WeightType Type used to represent edge weights. * @tparam WeightType Type used to represent edge weights.
* @param edges Vector of weighted edges defining the graph. * @param edges Graph defined as a vector of weighted edges.
* *
* @return Vector of edges with adjusted weights. * @return Vector of edges with adjusted weights.
* If no adjustments are necessary, this will be a copy of the * If no adjustments are necessary, this will be a copy of the