diff --git a/cpp/mwmatching.h b/cpp/mwmatching.h index dcfd45b..8d05e18 100644 --- a/cpp/mwmatching.h +++ b/cpp/mwmatching.h @@ -18,7 +18,6 @@ #include #include #include -#include #include #include #include @@ -28,6 +27,9 @@ namespace mwmatching { +/* ************************************************** + * ** public definitions ** + * ************************************************** */ /** Type representing the unique ID of a vertex. */ using VertexId = unsigned int; @@ -64,6 +66,9 @@ struct Edge namespace impl { +/* ************************************************** + * ** private definitions ** + * ************************************************** */ /** Value used to mark an invalid or undefined vertex. */ constexpr VertexId NO_VERTEX = std::numeric_limits::max(); @@ -73,6 +78,10 @@ constexpr VertexId NO_VERTEX = std::numeric_limits::max(); enum BlossomLabel { LABEL_NONE = 0, LABEL_S = 1, LABEL_T = 2 }; +/* ************************************************** + * ** private helper functions ** + * ************************************************** */ + /** Return a pair of vertices in flipped order. */ inline VertexPair flip_vertex_pair(const VertexPair& vt) { @@ -166,6 +175,10 @@ std::vector> remove_negative_weight_edges( } +/* ************************************************** + * ** struct Blossom ** + * ************************************************** */ + // Forward declaration. template struct NonTrivialBlossom; @@ -212,7 +225,6 @@ struct Blossom * that links to a different S-blossom, or "nullptr" if no such edge * has been found. */ -// TODO : consider storing a copy of the edge instead of pointer const Edge* best_edge; protected: @@ -241,9 +253,20 @@ public: static_cast*>(this) : nullptr); } + + const NonTrivialBlossom* nontrivial() const + { + return (is_nontrivial_blossom ? + static_cast*>(this) + : nullptr); + } }; +/* ************************************************** + * ** struct NonTrivialBlossom ** + * ************************************************** */ + /** * Represents a non-trivial blossom. * @@ -291,7 +314,6 @@ struct NonTrivialBlossom : public Blossom * In case of a top-level S-blossom, "best_edge_set" is a list of * least-slack edges between this blossom and other S-blossoms. */ -// TODO : consider storing a copy of the edge instead of pointer std::list*> best_edge_set; /** Initialize a non-trivial blossom. */ @@ -333,6 +355,38 @@ struct NonTrivialBlossom : public Blossom }; +/** Call a function for every vertex inside the specified blossom. */ +template +inline void for_vertices_in_blossom(const Blossom* blossom, Func func) +{ + const NonTrivialBlossom* ntb = blossom->nontrivial(); + if (ntb) { + // Visit all vertices in the non-trivial blossom. + // Use an explicit stack to avoid deep call chains. + std::vector*> stack; + stack.push_back(ntb); + + while (! stack.empty()) { + const NonTrivialBlossom* b = stack.back(); + stack.pop_back(); + + for (const auto& sub : b->subblossoms) { + ntb = sub.blossom->nontrivial(); + if (ntb) { + stack.push_back(ntb); + } else { + func(sub.blossom->base_vertex); + } + } + } + + } else { + // A trivial blossom contains just one vertex. + func(blossom->base_vertex); + } +} + + /** * Represents a list of edges forming an alternating path or * an alternating cycle. @@ -343,11 +397,14 @@ struct NonTrivialBlossom : public Blossom */ struct AlternatingPath { -// TODO : consider avoiding dynamic allocations for the alternating path std::deque edges; }; +/* ************************************************** + * ** struct MatchingContext ** + * ************************************************** */ + /** * This class holds all data used by the matching algorithm. * @@ -388,7 +445,6 @@ struct MatchingContext const VertexId num_vertex; /** For each vertex, a vector of pointers to its incident edges. */ -// TODO : consider copying the Edge instead of pointers const std::vector> adjacent_edges; /** @@ -397,7 +453,6 @@ struct MatchingContext * vertex_mate[x] == y if vertex "x" is matched to vertex "y". * vertex_mate[x] == NO_VERTEX if vertex "x" is unmatched. */ -// TODO : consider merging all per-vertex info into a single "struct Vertex" std::vector vertex_mate; /** @@ -436,7 +491,6 @@ struct MatchingContext * "vertex_best_edge[x]" is the least-slack edge to any S-vertex, * or "nullptr" if no such edge has been found. */ -// TODO - consider storing a copy of the edge instead of a pointer std::vector vertex_best_edge; /** Queue of S-vertices to be scanned. */ @@ -558,37 +612,6 @@ struct MatchingContext return vertex_dual[x] + vertex_dual[y] - weight_factor * edge.weight; } - /** Call a function for every vertex inside the specified blossom. */ - template - void for_vertices_in_blossom(BlossomT* blossom, Func func) - { - NonTrivialBlossomT* ntb = blossom->nontrivial(); - if (ntb) { - // Visit all vertices in the non-trivial blossom. - // Use an explicit stack to avoid deep call chains. - std::vector stack; - stack.push_back(ntb); - - while (! stack.empty()) { - NonTrivialBlossomT* b = stack.back(); - stack.pop_back(); - - for (const auto& sub : b->subblossoms) { - ntb = sub.blossom->nontrivial(); - if (ntb) { - stack.push_back(ntb); - } else { - func(sub.blossom->base_vertex); - } - } - } - - } else { - // A trivial blossom contains just one vertex. - func(blossom->base_vertex); - } - } - /* * Least-slack edge tracking: * @@ -820,7 +843,6 @@ struct MatchingContext const EdgeT* best_edge = nullptr; WeightType best_slack = 0; -// TODO - clean this up when we have a linked list of top-level blossoms auto consider_blossom = [this,&best_edge,&best_slack](BlossomT* blossom) { @@ -1317,7 +1339,7 @@ struct MatchingContext { // Assign label S to the blossom that contains vertex "x". BlossomT* bx = vertex_top_blossom[x]; - assert(bx->label == LABEL_NONE); // TODO - this assertion sometimes fails + assert(bx->label == LABEL_NONE); bx->label = LABEL_S; VertexId y = vertex_mate[x]; @@ -1718,8 +1740,103 @@ struct MatchingContext // Each iteration takes time O(n**2). while (run_stage()) ; } +}; - /* ********** Verify optimal solution: ********** */ + +/* ************************************************** + * ** struct MatchingVerifier ** + * ************************************************** */ + +/** Helper class to verify that an optimal solution has been found. */ +template +struct MatchingVerifier +{ + using EdgeT = Edge; + using BlossomT = Blossom; + using NonTrivialBlossomT = NonTrivialBlossom; + + // Reference to the MatchingContext instance. + const MatchingContext& ctx; + + // For each edge, the sum of duals of its incident vertices + // and duals of all blossoms that contain the edge. + std::vector edge_duals; + + MatchingVerifier(const MatchingContext& ctx) + : ctx(ctx), + edge_duals(ctx.edges.size()) + { } + + static bool checked_add(WeightType& result, WeightType a, WeightType b) + { + if (a > std::numeric_limits::max() - b) { + return true; + } else { + result = a + b; + return false; + } + } + + /** Convert edge pointer to its index in the vector "edges". */ + std::size_t edge_index(const EdgeT* edge) + { + return edge - ctx.edges.data(); + } + + /** Check that the array "vertex_mate" is consistent. */ + bool verify_vertex_mate() + { + // Count matched vertices and check symmetry of "vertex_mate". + VertexId num_matched_vertex = 0; + for (VertexId x = 0; x < ctx.num_vertex; ++x) { + VertexId y = ctx.vertex_mate[x]; + if (y != NO_VERTEX) { + ++num_matched_vertex; + if (ctx.vertex_mate[y] != x) { + return false; + } + } + } + + // Count matched edges. + VertexId num_matched_edge = 0; + for (const EdgeT& edge : ctx.edges) { + if (ctx.vertex_mate[edge.vt.first] == edge.vt.second) { + ++num_matched_edge; + } + } + + // Check that all matched vertices correspond to matched edges. + return (num_matched_vertex == 2 * num_matched_edge); + } + + /** + * Check that vertex dual variables are non-negative, + * and all unmatched vertices have zero dual. + */ + bool verify_vertex_duals() + { + for (VertexId x = 0; x < ctx.num_vertex; ++x) { + if (ctx.vertex_dual[x] < 0) { + return false; + } + if ((ctx.vertex_mate[x] == NO_VERTEX) && (ctx.vertex_dual[x] != 0)) { + return false; + } + } + return true; + } + + /** Check that blossom dual variables are non-negative. */ + bool verify_blossom_duals() + { + for (const NonTrivialBlossomT& blossom : ctx.nontrivial_blossom) { + if (blossom.dual_var < 0) { + return false; + } + } + return true; + } /** * Helper function for verifying the solution. @@ -1727,31 +1844,24 @@ struct MatchingContext * Descend down the blossom tree to find edges that are contained * in blossoms. * - * Adjust the slack of all contained edges to account for the dual - * variables of its containing blossoms. - * * On the way down, keep track of the sum of dual variables of - * the containing blossoms. + * containing blossoms. Add blossom duals to edges that are contained + * inside blossoms. * * On the way up, keep track of the total number of matched edges - * in the subblossoms. Then check that all blossoms with non-zero - * dual variable are "full". + * in subblossoms. Check that all blossoms with non-zero dual variables + * are "full". * * @return True if successful; - * false if a blossom with non-zero dual is not full. + * false if a blossom with non-zero dual is not full; + * false if blossom dual calculations cause numeric overflow. */ - bool verify_blossom_edges(NonTrivialBlossomT* blossom, - std::unordered_map>& edge_duals) + bool check_blossom(const NonTrivialBlossomT* blossom) { -// TODO : fix line length -// TODO : simplify code -// TODO : optimize -// TODO : try to think of a simpler way - // For each vertex "x", // "vertex_depth[x]" is the depth of the smallest blossom on // the current descent path that contains "x". - std::vector vertex_depth(num_vertex); + std::vector vertex_depth(ctx.num_vertex); // At each depth, keep track of the sum of blossom duals // along the current descent path. @@ -1762,7 +1872,9 @@ struct MatchingContext std::vector path_num_matched = {0}; // Use an explicit stack to avoid deep recursion. - std::stack::iterator>> stack; + using SubBlossomList = std::list; + std::stack> stack; stack.emplace(blossom, blossom->subblossoms.begin()); while (! stack.empty()) { @@ -1781,8 +1893,12 @@ struct MatchingContext }); // Calculate the sum of blossom duals at the new depth. - path_sum_dual.push_back(path_sum_dual.back() - + blossom->dual_var); + path_sum_dual.push_back(path_sum_dual.back()); + if (checked_add(path_sum_dual.back(), + path_sum_dual.back(), + blossom->dual_var)) { + return false; + } // Initialize the number of matched edges at the new depth. path_num_matched.push_back(0); @@ -1808,18 +1924,22 @@ struct MatchingContext // For each incident edge, find the smallest blossom // that contains it. VertexId x = sub->base_vertex; - for (const EdgeT* edge : adjacent_edges[x]) { + for (const EdgeT* edge : ctx.adjacent_edges[x]) { // Only consider edges pointing out from "x". if (edge->vt.first == x) { VertexId y = edge->vt.second; VertexId edge_depth = vertex_depth[y]; if (edge_depth > 0) { - // This edge is contained in a blossom. - // Update its slack. - edge_duals[edge->vt] += path_sum_dual[edge_depth]; + // Found the smallest blossom that contains this edge. + // Add the duals of the containing blossoms. + if (checked_add(edge_duals[edge_index(edge)], + edge_duals[edge_index(edge)], + path_sum_dual[edge_depth])) { + return false; + } // Update the number of matched edges in the blossom. - if (vertex_mate[x] == y) { + if (ctx.vertex_mate[x] == y) { path_num_matched[edge_depth] += 1; } } @@ -1866,6 +1986,62 @@ struct MatchingContext return true; } + /** + * Check that all blossoms are full. + * Also calculate the sum of dual variables for every edge. + */ + bool verify_blossoms_and_calc_edge_duals() + { + // For each edge, calculate the sum of its vertex duals. + for (const EdgeT& edge : ctx.edges) { + if (checked_add(edge_duals[edge_index(&edge)], + ctx.vertex_dual[edge.vt.first], + ctx.vertex_dual[edge.vt.second])) { + return false; + } + } + + // Descend down each top-level blossom. + // Check that blossoms are full. + // Add blossom duals to the edges contained inside the blossoms. + // This takes total time O(n**2). + for (const NonTrivialBlossomT& blossom : ctx.nontrivial_blossom) { + if (blossom.parent == nullptr) { + if (!check_blossom(&blossom)) { + return false; + } + } + } + + return true; + } + + /** + * Check that all edges have non-negative slack, + * and check that all matched edges have zero slack. + * + * @pre Edge duals must be calculated before calling this function. + */ + bool verify_edge_slack() + { + for (const EdgeT& edge : ctx.edges) { + WeightType duals = edge_duals[edge_index(&edge)]; + WeightType weight = ctx.weight_factor * edge.weight; + + if (weight > duals) { + return false; + } + WeightType slack = duals - weight; + + if (ctx.vertex_mate[edge.vt.first] == edge.vt.second) { + if (slack != 0) { + return false; + } + } + } + return true; + } + /** * Verify that the optimum solution has been found. * @@ -1873,91 +2049,13 @@ struct MatchingContext * * @return True if the solution is optimal; otherwise false. */ - bool verify_optimum() + bool verify() { - // Count matched vertices and check symmetry of "vertex_mate". - VertexId num_matched_vertex = 0; - for (VertexId x = 0; x < num_vertex; ++x) { - VertexId y = vertex_mate[x]; - if (y != NO_VERTEX) { - ++num_matched_vertex; - if (vertex_mate[y] != x) { - return false; - } - } - } - - // Check that all matched edges exist in the graph. - VertexId num_matched_edge = 0; - for (const EdgeT& edge : edges) { - if (vertex_mate[edge.vt.first] == edge.vt.second) { - ++num_matched_edge; - } - } - - if (num_matched_vertex != 2 * num_matched_edge) { - return false; - } - - // Check that all dual variables are non-negative. - for (VertexId x = 0; x < num_vertex; ++x) { - if (vertex_dual[x] < 0) { - return false; - } - } - - for (NonTrivialBlossomT& blossom : nontrivial_blossom) { - if (blossom.dual_var < 0) { - return false; - } - } - - // Check that all unmatched vertices have zero dual. - for (VertexId x = 0; x < num_vertex; ++x) { - if ((vertex_mate[x] == NO_VERTEX) && (vertex_dual[x] != 0)) { - return false; - } - } - - // Calculate the slack of each edge. - // A correction will be needed for edges inside blossoms. - std::unordered_map> edge_duals; - for (const EdgeT& edge : edges) { -// TODO : consider potential numeric overflow - edge_duals[edge.vt] = vertex_dual[edge.vt.first] - + vertex_dual[edge.vt.second]; - } - - // Descend down each top-level blossom. - // Adjust edge slacks to account for blossom dual. - // Also check that all blossoms are full. - // This takes total time O(n**2). - for (NonTrivialBlossomT& blossom : nontrivial_blossom) { - if (blossom.parent == nullptr) { - verify_blossom_edges(&blossom, edge_duals); - } - } - - // Check that all edges have non-negative slack. - // Check that all matched edges have zero slack. -// TODO : fix potential numeric overflow - WeightType may be unsigned - for (const EdgeT& edge : edges) { - WeightType slack = edge_duals[edge.vt] - - weight_factor * edge.weight; - if (slack < 0) { - return false; - } - if (vertex_mate[edge.vt.first] == edge.vt.second) { - if (slack != 0) { - return false; - } - } - } - - // Optimum solution confirmed. - return true; + return (verify_vertex_mate() + && verify_vertex_duals() + && verify_blossom_duals() + && verify_blossoms_and_calc_edge_duals() + && verify_edge_slack()); } }; @@ -1965,6 +2063,10 @@ struct MatchingContext } // namespace impl +/* ************************************************** + * ** public functions ** + * ************************************************** */ + /** * Compute a maximum-weighted matching in the general undirected weighted * graph given by "edges". @@ -1989,7 +2091,7 @@ struct MatchingContext * * @tparam WeightType Type used to represent edge weights. * For example "long" or "double". - * @param edges List of weighted edges defining the graph. + * @param edges Graph defined as a vector of weighted edges. * * @return Vector of pairs of matched vertex indices. * This is a subset of the edges in the graph. @@ -2009,7 +2111,7 @@ std::vector maximum_weight_matching( // Verify that the solution is optimal (works only for integer weights). if (std::numeric_limits::is_integer) { - assert(matching.verify_optimum()); + assert(impl::MatchingVerifier(matching).verify()); } // Extract the matched edges. @@ -2050,7 +2152,7 @@ std::vector maximum_weight_matching( * This function takes time O(m), where "m" is the number of edges. * * @tparam WeightType Type used to represent edge weights. - * @param edges Vector of weighted edges defining the graph. + * @param edges Graph defined as a vector of weighted edges. * * @return Vector of edges with adjusted weights. * If no adjustments are necessary, this will be a copy of the