/*
 * Data structures for matching.
 *
 * Recovered from the patch "C++ datastructures for O(n*m*log(n))"
 * (commit efb238ff8e72f8a6413a3d26801c933c35dc3d9d), which adds this
 * header as cpp/datastruct.h together with a "test_datastruct" Makefile
 * target that builds cpp/test_datastruct.cpp against it.
 */

#include <algorithm>
#include <cassert>
#include <cstddef>
#include <limits>
#include <utility>
#include <vector>


/**
 * Element in a disjoint set.
 *
 * Every instance of DisjointSetNode is a member of exactly one disjoint set.
 * Some instances of DisjointSetNode also represent a set.
 * Every set has a "label".
 *
 * The following operations can be done efficiently:
 *  - Find the label of the set that contains a given node.
 *  - Merge two sets.
 *  - Undo a previous merge step.
 *
 * Internally, a set is represented by a tree of DisjointSetNode instances
 * that are linked through parent pointers. The shape of the tree is
 * controlled such that its depth is at most O(log(n)), where "n" is
 * the number of elements in the set. No path compression is applied,
 * because "detach" relies on the parent links staying exactly as "merge"
 * created them.
 *
 * See also https://en.wikipedia.org/wiki/Disjoint-set_data_structure
 *
 * @tparam LabelType  Type of the set label. Must be copyable and
 *                    value-initializable ("merge" resets labels to
 *                    LabelType{}).
 */
template <typename LabelType>
class DisjointSetNode
{
public:
    /**
     * Create a new element and a new set containing only the new element.
     *
     * The new instance represents the new element as well as the new set.
     *
     * This function takes time O(1).
     */
    explicit DisjointSetNode(const LabelType& label)
      : m_label(label),
        m_parent(nullptr),
        m_size(1)
    { }

    // Prevent copying: other nodes hold raw pointers to this instance.
    DisjointSetNode(const DisjointSetNode&) = delete;
    DisjointSetNode& operator=(const DisjointSetNode&) = delete;

    /**
     * Return the label of the set that contains the element represented
     * by this instance.
     *
     * This function takes time O(log(n)).
     */
    LabelType find() const
    {
        // Follow parent links up to the node that represents the set.
        const DisjointSetNode* p = this;
        while (p->m_parent != nullptr) {
            p = p->m_parent;
        }
        return p->m_label;
    }

    /**
     * Change the label of the set represented by this instance.
     *
     * Only sets have labels. Elements do not have labels.
     * Some instances of DisjointSetNode do not represent a set,
     * and it is not allowed to assign labels to such instances.
     *
     * This function takes time O(1).
     */
    void set_label(const LabelType& label)
    {
        assert(m_parent == nullptr);
        m_label = label;
    }

    /**
     * Merge the sets represented by this instance and by "other".
     *
     * Both instances must represent a set (not merely an element),
     * and they must be different instances.
     * The merged set will initially have no label (both labels are reset
     * to LabelType{}).
     *
     * This function takes time O(1).
     *
     * @param other  Representative of the set to merge with; not null.
     * @return Instance representing the merged set.
     */
    DisjointSetNode* merge(DisjointSetNode* other)
    {
        assert(other != nullptr);
        assert(m_parent == nullptr);
        assert(other->m_parent == nullptr);
        assert(other != this);

        m_label = LabelType{};
        other->m_label = LabelType{};

        // Union by size: hang the smaller tree below the larger one,
        // which keeps the tree depth logarithmic.
        if (m_size < other->m_size) {
            this->m_parent = other;
            other->m_size += this->m_size;
            return other;
        } else {
            other->m_parent = this;
            this->m_size += other->m_size;
            return this;
        }
    }

    /**
     * Undo a previous "merge" step.
     *
     * This function detaches a subset from a set into which it was
     * previously merged, and re-establishes the subset as an independent
     * set.
     *
     * The following conditions must hold:
     *  - At some point in the past, there was a group of K instances
     *    of DisjointSetNode (s_1, s_2, ..., s_k) that each represented
     *    a set.
     *  - These sets were later merged through (K-1) merge steps.
     *  - Since that point in time, no other merges into the same set
     *    occurred, disregarding any merges that have already been undone.
     *
     * The group of (K-1) merge steps can be undone by calling "detach"
     * on each of the K instances. These K calls may occur in any order.
     * Afterwards, each of these instances will again represent the same set
     * that it represented originally. The original labels of these sets
     * are lost. Labels must therefore be re-assigned through parameters
     * to the "detach" calls.
     *
     * This function takes time O(log(n)).
     *
     * @param label  Label to assign to the re-established set.
     */
    void detach(const LabelType& label)
    {
        // Remove the size of this subtree from all of its ancestors.
        DisjointSetNode* p = m_parent;
        while (p != nullptr) {
            p->m_size -= m_size;
            p = p->m_parent;
        }
        m_parent = nullptr;
        m_label = label;
    }

private:
    /** Label of the set; only meaningful while this node has no parent. */
    LabelType m_label;

    /** Parent in the set tree, or nullptr if this node represents a set. */
    DisjointSetNode* m_parent;

    /** Number of elements in the subtree rooted at this node. */
    std::size_t m_size;
};
/**
 * Priority queue that can be split into separate queues.
 *
 * Each element in the queue has an index, a priority and attached data.
 * The following operations can be done efficiently:
 *  - Find the element with lowest priority.
 *  - Insert an element with given index, priority and data.
 *  - Update the priority of an existing element.
 *  - Split the queue such that all elements with index below a specified
 *    threshold end up in one queue and all remaining elements in
 *    a separate queue.
 *
 * The implementation is essentially an AVL tree ordered by index, with
 * minimum-priority tracking added to it: every node caches the best
 * (lowest) priority found in its subtree. The implementation follows
 * the outlines of https://en.wikipedia.org/wiki/Avl_tree
 *
 * The queue does not own its nodes. The caller provides Node instances and
 * must keep them alive while they are contained in a queue.
 *
 * @tparam PrioType  Priority type; must be default-constructible, copyable
 *                   and comparable with "<" and ">=".
 * @tparam DataType  Attached data type; must be default-constructible
 *                   and copyable.
 */
template <typename PrioType, typename DataType>
class SplittablePriorityQueue
{
public:
    typedef unsigned int IndexType;

    /**
     * Element in a queue.
     *
     * A node with height 0 is not contained in any queue.
     */
    struct Node
    {
    public:
        Node()
          : index(0), prio(), data(), best_prio(), best_data(),
            parent(nullptr), left_child(nullptr), right_child(nullptr),
            height(0)
        { }

        ~Node()
        {
            // Check that the node is not contained in any queue.
            assert(height == 0);
        }

        // Prevent copying: the tree holds raw pointers to this instance.
        Node(const Node&) = delete;
        Node& operator=(const Node&) = delete;

    private:
        IndexType index;        // search key of the tree
        PrioType prio;          // priority of this element
        DataType data;          // data attached to this element
        PrioType best_prio;     // lowest priority in the subtree of this node
        DataType best_data;     // data of the element that has "best_prio"
        Node* parent;
        Node* left_child;
        Node* right_child;
        unsigned int height;    // subtree height; 0 = not in any queue

        friend class SplittablePriorityQueue;
    };

    /** Construct an empty queue. */
    SplittablePriorityQueue()
      : m_root(nullptr)
    { }

    ~SplittablePriorityQueue()
    {
        // Check that the queue is empty.
        assert(! m_root);
    }

    // Prevent copying.
    SplittablePriorityQueue(const SplittablePriorityQueue&) = delete;
    SplittablePriorityQueue& operator=(const SplittablePriorityQueue&) = delete;

    /** Construct a queue and move elements from "other" to the new queue. */
    SplittablePriorityQueue(SplittablePriorityQueue&& other) noexcept
      : m_root(other.m_root)
    {
        other.m_root = nullptr;
    }

    /**
     * Move all elements from "other" to this queue.
     *
     * This queue must be empty before the assignment.
     */
    SplittablePriorityQueue& operator=(SplittablePriorityQueue&& other) noexcept
    {
        assert(! m_root);
        std::swap(m_root, other.m_root);
        return *this;
    }

    /**
     * Remove all elements from the queue.
     *
     * All removed nodes are marked as unused (height 0), so they can be
     * destroyed or inserted again.
     *
     * This function takes time O(n).
     */
    void clear()
    {
        Node* node = m_root;
        m_root = nullptr;

        // Walk the tree without recursion: descend into a child while
        // cutting the link to it; once a node has no children left,
        // reset it and step back to its parent.
        while (node) {
            Node* prev = node;
            if (node->left_child) {
                node = node->left_child;
                prev->left_child = nullptr;
            } else if (node->right_child) {
                node = node->right_child;
                prev->right_child = nullptr;
            } else {
                node = node->parent;
                prev->parent = nullptr;
                prev->height = 0;
            }
        }
    }

    /** Return true if the queue is empty. */
    bool empty() const
    {
        return (! m_root);
    }

    /**
     * Return the minimum-priority element as tuple (prio, data).
     *
     * The queue must be non-empty.
     *
     * This function takes time O(1).
     */
    std::pair<PrioType, DataType> find_min() const
    {
        assert(m_root);
        return std::make_pair(m_root->best_prio, m_root->best_data);
    }

    /**
     * Initialize the specified Node instance and insert it into the queue.
     *
     * The specified Node instance must currently be unused; not contained
     * in any queue. The index must differ from the index of every element
     * that is already in the queue.
     *
     * The queue stores a pointer to the Node instance. The caller must
     * ensure that the instance remains valid for as long as it is contained
     * in any queue.
     *
     * This function takes time O(log(n)).
     */
    void insert(Node* node,
                IndexType index,
                PrioType prio,
                const DataType& data)
    {
        assert(node->height == 0);

        node->index = index;
        node->prio = prio;
        node->data = data;
        node->best_prio = prio;
        node->best_data = data;
        node->parent = nullptr;
        node->left_child = nullptr;
        node->right_child = nullptr;
        node->height = 1;

        if (! m_root) {
            m_root = node;
            return;
        }

        // Ordinary binary-search-tree descent to the insertion point.
        Node* p = m_root;
        while (true) {
            assert(index != p->index);
            if (index < p->index) {
                if (! p->left_child) {
                    p->left_child = node;
                    break;
                }
                p = p->left_child;
            } else {
                if (! p->right_child) {
                    p->right_child = node;
                    break;
                }
                p = p->right_child;
            }
        }

        node->parent = p;
        rebalance_up(node);
    }

    /**
     * Decrease priority and update data of an existing element.
     *
     * Nothing is done if the specified priority is greater or equal
     * to the current priority of the existing element.
     *
     * This function takes time O(log(n)).
     */
    void update(Node* node, PrioType prio, const DataType& data)
    {
        assert(node->height != 0);

        // Do nothing unless the new priority is strictly lower.
        if (prio >= node->prio) {
            return;
        }

        // Update the node itself.
        node->prio = prio;
        node->data = data;

        // Update the best-prio information in the node and its ancestors.
        while (true) {
            repair_node_best_prio(node);
            if (! node->parent) {
                break;
            }
            node = node->parent;
        }

        // Check that the node belongs to this queue.
        assert(node == m_root);
    }

    /**
     * Split the queue by index.
     *
     * All elements with index less than "threshold" remain in the queue.
     * All elements with index greater or equal to "threshold" are moved
     * to a new queue. That new queue is returned.
     *
     * This function takes time O(log(n)).
     */
    SplittablePriorityQueue split(IndexType threshold)
    {
        SplittablePriorityQueue new_queue;

        // Special case for an empty queue.
        if (! m_root) {
            return new_queue;
        }

        // Descend down the tree along a path as close as possible to the
        // threshold. Stop when we reach a subtree that is entirely on
        // one side of the threshold.
        Node* node = m_root;
        while (true) {
            if (node->index < threshold) {
                /*
                 * This node is below the threshold. Descend right.
                 *
                 *        (<T)
                 *       /    \
                 *      /      \
                 *   (<T)      (?)  <-- go here
                 */
                if (! node->right_child) {
                    break;
                }
                node = node->right_child;
            } else if (node->index > threshold) {
                /*
                 * This node is above the threshold. Descend left.
                 *
                 *               (>T)
                 *              /    \
                 *             /      \
                 * go here --> (?)    (>T)
                 */
                if (! node->left_child) {
                    break;
                }
                node = node->left_child;
            } else {
                /*
                 * This node is exactly at the threshold.
                 * Descend left if possible, then stop.
                 *
                 *        (=T)
                 *       /    \
                 *      /      \
                 *   (<T)      (>T)
                 */
                if (node->left_child) {
                    node = node->left_child;
                }
                break;
            }
        }

        // Build "left_tree" with nodes < threshold,
        // and "right_tree" with nodes >= threshold.
        Node* left_tree = nullptr;
        Node* right_tree = nullptr;

        // Initially, either the left or right tree consists of only
        // the subtree where the descending path ended, and the other
        // tree is empty.
        if (node->index < threshold) {
            left_tree = node;
        } else {
            right_tree = node;
        }

        // Detach node from its parent.
        Node* parent = node->parent;
        detach_node(node);

        // Retrace up the tree.
        // On the way up, join each node to the left or right subtree.
        while (parent) {

            // Move up to parent.
            node = parent;
            parent = node->parent;

            detach_node(node);

            if (node->index < threshold) {
                /*
                 * Join node (with its descendants) to the left tree.
                 *
                 *        N    <--- new left_tree
                 *       /  .
                 *      X    .
                 *     / \    (left_tree)
                 */
                assert(! node->right_child);
                left_tree = join(node->left_child, node, left_tree);
            } else {
                /*
                 * Join node (with its descendants) to the right tree.
                 *
                 *                 N    <--- new right_tree
                 *               .  \
                 *              .    X
                 *   (right_tree)   / \
                 */
                assert(! node->left_child);
                right_tree = join(right_tree, node, node->right_child);
            }
        }

        // Keep the left tree in this instance.
        m_root = left_tree;

        // Return the right tree as a new queue.
        new_queue.m_root = right_tree;
        return new_queue;
    }

private:
    /** Return node height, or 0 if node == nullptr. */
    static unsigned int get_node_height(const Node* node)
    {
        return node ? node->height : 0;
    }

    /** Detach a node from its parent. */
    static void detach_node(Node* node)
    {
        Node* parent = node->parent;
        if (parent) {
            node->parent = nullptr;
            if (parent->left_child == node) {
                parent->left_child = nullptr;
            } else if (parent->right_child == node) {
                parent->right_child = nullptr;
            }
        }
    }

    /**
     * Repair best-priority information in the specified node.
     *
     * The best-priority information of both children must be valid.
     * On equal priorities the node itself wins over its children, and
     * the left child wins over the right child.
     */
    static void repair_node_best_prio(Node* node)
    {
        PrioType best_prio = node->prio;
        DataType best_data = node->data;

        Node* lchild = node->left_child;
        if (lchild && (lchild->best_prio < best_prio)) {
            best_prio = lchild->best_prio;
            best_data = lchild->best_data;
        }

        Node* rchild = node->right_child;
        if (rchild && (rchild->best_prio < best_prio)) {
            best_prio = rchild->best_prio;
            best_data = rchild->best_data;
        }

        node->best_prio = best_prio;
        node->best_data = best_data;
    }

    /**
     * Repair the height and best-priority information of a node
     * after modifying its children.
     *
     * After repairing a node, it is typically necessary to also repair
     * its ancestors.
     */
    static void repair_node(Node* node)
    {
        // Repair node height.
        node->height = 1 + std::max(get_node_height(node->left_child),
                                    get_node_height(node->right_child));

        // Repair best-priority.
        repair_node_best_prio(node);
    }

    /**
     * Rotate the subtree to the left and return the new root of the subtree.
     *
     * The node must have a right child. The link from the parent (if any)
     * is redirected to the new subtree root; ancestors are not repaired.
     */
    static Node* rotate_left(Node* node)
    {
        /*
         *     N                C
         *   /   \            /   \
         *  A     C   --->   N     D
         *       / \        / \
         *      B   D      A   B
         */
        Node* parent = node->parent;
        Node* new_top = node->right_child;
        assert(new_top);

        Node* nb = new_top->left_child;
        node->right_child = nb;
        if (nb) {
            nb->parent = node;
        }

        new_top->left_child = node;
        node->parent = new_top;

        new_top->parent = parent;

        if (parent) {
            if (parent->left_child == node) {
                parent->left_child = new_top;
            } else if (parent->right_child == node) {
                parent->right_child = new_top;
            } else {
                assert(false);
            }
        }

        // Repair bottom-up: "node" is now a child of "new_top".
        repair_node(node);
        repair_node(new_top);

        return new_top;
    }

    /**
     * Rotate the subtree to the right and return the new root of the subtree.
     *
     * The node must have a left child. The link from the parent (if any)
     * is redirected to the new subtree root; ancestors are not repaired.
     */
    static Node* rotate_right(Node* node)
    {
        /*
         *       N            B
         *     /   \        /   \
         *    B     D ---> A     N
         *   / \                / \
         *  A   C              C   D
         */
        Node* parent = node->parent;
        Node* new_top = node->left_child;
        assert(new_top);

        Node* nc = new_top->right_child;
        node->left_child = nc;
        if (nc) {
            nc->parent = node;
        }

        new_top->right_child = node;
        node->parent = new_top;

        new_top->parent = parent;

        if (parent) {
            if (parent->left_child == node) {
                parent->left_child = new_top;
            } else if (parent->right_child == node) {
                parent->right_child = new_top;
            } else {
                assert(false);
            }
        }

        // Repair bottom-up: "node" is now a child of "new_top".
        repair_node(node);
        repair_node(new_top);

        return new_top;
    }

    /**
     * Repair and rebalance the specified node and its ancestors.
     *
     * Updates "m_root" to the root that is reached at the end of the walk.
     */
    void rebalance_up(Node* node)
    {
        while (true) {
            unsigned int lh = get_node_height(node->left_child);
            unsigned int rh = get_node_height(node->right_child);

            if (lh > rh + 1) {
                /*
                 * This node is left-heavy. Rotate right to rebalance.
                 *
                 *          N                L
                 *         / \              / \
                 *        L   \            /   N
                 *       / \   \   --->   /   / \
                 *      A   B   \        A   B   \
                 *               \                \
                 *                R                R
                 */
                Node* lchild = node->left_child;
                unsigned int llh = get_node_height(lchild->left_child);
                unsigned int lrh = get_node_height(lchild->right_child);
                if (llh < lrh) {
                    // Double rotation.
                    rotate_left(lchild);
                }
                node = rotate_right(node);
            } else if (lh + 1 < rh) {
                /*
                 * This node is right-heavy. Rotate left to rebalance.
                 *
                 *        N                    R
                 *       / \                  / \
                 *      /   R                N   \
                 *     /   / \     --->     / \   \
                 *    /   A   B            /   A   B
                 *   /                    /
                 *  L                    L
                 */
                Node* rchild = node->right_child;
                unsigned int rlh = get_node_height(rchild->left_child);
                unsigned int rrh = get_node_height(rchild->right_child);
                if (rlh > rrh) {
                    // Double rotation.
                    rotate_right(rchild);
                }
                node = rotate_left(node);
            } else {
                // No rotation, but must still repair the node.
                repair_node(node);
            }

            if (! node->parent) {
                // Reached root.
                m_root = node;
                break;
            }

            // Continue rebalancing at the parent.
            node = node->parent;
        }
    }

    /**
     * Join a left subtree, middle node and right subtree together.
     *
     * The left subtree is higher than the right subtree
     * (by more than one level).
     *
     * @return Root node of the joined tree.
     */
    static Node* join_right(Node* ltree, Node* node, Node* rtree)
    {
        assert(ltree);
        const unsigned int rh = get_node_height(rtree);
        assert(ltree->height > rh + 1);

        /*
         * Descend down the right spine of "ltree".
         * Stop at a node with compatible height, then insert "node"
         * and attach "rtree".
         *
         *    ltree
         *     /  \
         *         X
         *        / \
         *           X   <-- cur
         *          / \
         *             node
         *             /  \
         *            X    rtree
         */

        // Descend to a point with compatible height.
        Node* cur = ltree;
        while (cur->right_child && (cur->right_child->height > rh + 1)) {
            cur = cur->right_child;
        }

        // Insert "node" and "rtree".
        node->left_child = cur->right_child;
        node->right_child = rtree;
        if (node->left_child) {
            node->left_child->parent = node;
        }
        if (rtree) {
            rtree->parent = node;
        }
        cur->right_child = node;
        node->parent = cur;

        // A double rotation may be necessary.
        if ((! cur->left_child) || (cur->left_child->height <= rh)) {
            rotate_right(node);
            cur = rotate_left(cur);
        } else {
            repair_node(node);
            repair_node(cur);
        }

        // Ascend from "cur" to the root of the tree; repair and rebalance.
        while (cur->parent) {
            cur = cur->parent;
            assert(cur->left_child);
            assert(cur->right_child);

            if (cur->left_child->height + 1 < cur->right_child->height) {
                // Continue from the new subtree root, as in "join_left".
                cur = rotate_left(cur);
            } else {
                repair_node(cur);
            }
        }

        return cur;
    }

    /**
     * Join a left subtree, middle node and right subtree together.
     *
     * The right subtree is higher than the left subtree
     * (by more than one level).
     *
     * @return Root node of the joined tree.
     */
    static Node* join_left(Node* ltree, Node* node, Node* rtree)
    {
        assert(rtree);
        const unsigned int lh = get_node_height(ltree);
        assert(lh + 1 < rtree->height);

        /*
         * Descend down the left spine of "rtree".
         * Stop at a node with compatible height, then insert "node"
         * and attach "ltree".
         *
         *                rtree
         *                /  \
         *               X
         *              / \
         *    cur -->  X
         *            / \
         *        node
         *        /  \
         *   ltree    X
         */

        // Descend to a point with compatible height.
        Node* cur = rtree;
        while (cur->left_child && (cur->left_child->height > lh + 1)) {
            cur = cur->left_child;
        }

        // Insert "node" and "ltree".
        node->left_child = ltree;
        node->right_child = cur->left_child;
        if (ltree) {
            ltree->parent = node;
        }
        if (node->right_child) {
            node->right_child->parent = node;
        }
        cur->left_child = node;
        node->parent = cur;

        // A double rotation may be necessary.
        if ((! cur->right_child) || (cur->right_child->height <= lh)) {
            rotate_left(node);
            cur = rotate_right(cur);
        } else {
            repair_node(node);
            repair_node(cur);
        }

        // Ascend from "cur" to the root of the tree; repair and rebalance.
        while (cur->parent) {
            cur = cur->parent;
            assert(cur->left_child);
            assert(cur->right_child);

            if (cur->left_child->height > cur->right_child->height + 1) {
                cur = rotate_right(cur);
            } else {
                repair_node(cur);
            }
        }

        return cur;
    }

    /**
     * Join a left subtree, middle node and right subtree together.
     *
     * The index of all nodes in the left subtree must be less than
     * the index of the middle node. The index of all nodes in
     * the right subtree must be greater than the middle node.
     *
     * The left or right subtree may initially be a child of the middle
     * node; such links will be broken as needed.
     *
     * The left and right subtrees must be consistent, semi-balanced trees.
     * Height and priority of the middle node may initially be inconsistent;
     * this function will repair it.
     *
     * @return Root node of the joined tree.
     */
    static Node* join(Node* ltree, Node* node, Node* rtree)
    {
        unsigned int lh = get_node_height(ltree);
        unsigned int rh = get_node_height(rtree);

        if (lh > rh + 1) {
            assert(ltree);
            ltree->parent = nullptr;
            return join_right(ltree, node, rtree);
        } else if (lh + 1 < rh) {
            assert(rtree);
            rtree->parent = nullptr;
            return join_left(ltree, node, rtree);
        } else {
            /*
             * Subtree heights are compatible. Just join them.
             *
             *          node
             *         /    \
             *    ltree      rtree
             *     /  \      /  \
             */
            node->left_child = ltree;
            if (ltree) {
                ltree->parent = node;
            }
            node->right_child = rtree;
            if (rtree) {
                rtree->parent = node;
            }
            repair_node(node);
            return node;
        }
    }

    /** Root node of the queue, or "nullptr" if the queue is empty. */
    Node* m_root;
};
/**
 * Normal min-priority queue based on a binary heap.
 *
 * The heap is stored as an array of pointers to caller-owned Node
 * instances. Each node remembers its own position in the array, which
 * makes it possible to update or remove an arbitrary element in
 * O(log(n)) time.
 *
 * @tparam PrioType  Priority type; must be default-constructible, copyable
 *                   and comparable with "<", "<=", ">" and ">=".
 * @tparam DataType  Attached data type; must be default-constructible
 *                   and copyable.
 */
template <typename PrioType, typename DataType>
class PriorityQueue
{
public:
    typedef unsigned int IndexType;

    /** Index value that marks a node which is not contained in any queue. */
    static constexpr IndexType INVALID_INDEX =
        std::numeric_limits<IndexType>::max();

    /** Element in a PriorityQueue. */
    struct Node
    {
    public:
        Node()
          : index(INVALID_INDEX), prio(), data()
        { }

        ~Node()
        {
            // Check that the node is not contained in any queue.
            assert(index == INVALID_INDEX);
        }

        // Prevent copying: the heap holds raw pointers to this instance.
        Node(const Node&) = delete;
        Node& operator=(const Node&) = delete;

    private:
        IndexType index;    // position in the heap array, or INVALID_INDEX
        PrioType prio;
        DataType data;

        friend class PriorityQueue;
    };

    /** Construct an empty queue. */
    PriorityQueue()
    { }

    ~PriorityQueue()
    {
        // Check that the queue is empty.
        assert(m_heap.empty());
    }

    // Prevent copying.
    PriorityQueue(const PriorityQueue&) = delete;
    PriorityQueue& operator=(const PriorityQueue&) = delete;

    /** Construct a queue and move elements from "other" to the new queue. */
    PriorityQueue(PriorityQueue&& other) noexcept
      : m_heap(std::move(other.m_heap))
    { }

    /**
     * Remove all elements from the queue.
     *
     * All removed nodes are marked as unused, so they can be destroyed
     * or inserted again.
     *
     * This function takes time O(n).
     */
    void clear()
    {
        for (Node* node : m_heap) {
            node->index = INVALID_INDEX;
        }
        m_heap.clear();
    }

    /** Return true if the queue is empty. */
    bool empty() const
    {
        return m_heap.empty();
    }

    /**
     * Return the minimum-priority element as tuple (prio, data).
     *
     * The queue must be non-empty.
     *
     * This function takes time O(1).
     */
    std::pair<PrioType, DataType> find_min() const
    {
        assert(! m_heap.empty());
        const Node* top = m_heap.front();
        return std::make_pair(top->prio, top->data);
    }

    /**
     * Insert a new element into the queue.
     *
     * The specified Node instance must currently be unused; not contained
     * in any queue. The queue stores a pointer to the instance; the caller
     * must keep it valid for as long as it is contained in the queue.
     *
     * This function takes time O(log(n)).
     */
    void insert(Node* node, PrioType prio, const DataType& data)
    {
        assert(node->index == INVALID_INDEX);

        node->index = static_cast<IndexType>(m_heap.size());
        node->prio = prio;
        node->data = data;

        m_heap.push_back(node);
        sift_up(node->index);
    }

    /**
     * Update priority and/or data of an existing node.
     *
     * The priority may be increased as well as decreased.
     *
     * This function takes time O(log(n)).
     */
    void update(Node* node, PrioType prio, const DataType& data)
    {
        IndexType index = node->index;
        assert(index != INVALID_INDEX);
        assert(m_heap[index] == node);

        PrioType prev_prio = node->prio;
        node->prio = prio;
        node->data = data;

        if (prio < prev_prio) {
            sift_up(index);
        } else if (prio > prev_prio) {
            sift_down(index);
        }
    }

    /**
     * Remove the specified element from the queue.
     *
     * This function takes time O(log(n)).
     */
    void remove(Node* node)
    {
        IndexType index = node->index;
        assert(index != INVALID_INDEX);
        assert(m_heap[index] == node);

        node->index = INVALID_INDEX;

        // Fill the hole with the last element of the array.
        Node* move_node = m_heap.back();
        m_heap.pop_back();

        // Nothing to repair if the removed node was the last element.
        if (index < m_heap.size()) {
            m_heap[index] = move_node;
            move_node->index = index;
            if (move_node->prio < node->prio) {
                sift_up(index);
            } else if (move_node->prio > node->prio) {
                sift_down(index);
            }
        }
    }

private:
    /** Repair the heap along an ascending path to the root. */
    void sift_up(IndexType index)
    {
        Node* node = m_heap[index];
        PrioType prio = node->prio;

        // Shift parents down until the right spot for "node" is found.
        while (index > 0) {
            IndexType next_index = (index - 1) / 2;
            Node* next_node = m_heap[next_index];
            if (next_node->prio <= prio) {
                break;
            }
            m_heap[index] = next_node;
            next_node->index = index;
            index = next_index;
        }

        node->index = index;
        m_heap[index] = node;
    }

    /** Repair the heap along a descending path. */
    void sift_down(IndexType index)
    {
        Node* node = m_heap[index];
        PrioType prio = node->prio;

        const IndexType num_elem = static_cast<IndexType>(m_heap.size());

        // Index "i" has a left child iff (2 * i + 1 < num_elem),
        // which is equivalent to (i < num_elem / 2). Using
        // ((num_elem - 1) / 2) here would skip the last parent
        // whenever the number of elements is even.
        const IndexType first_leaf = num_elem / 2;

        while (index < first_leaf) {
            // Select the child with the lowest priority.
            IndexType next_index = 2 * index + 1;
            Node* next_node = m_heap[next_index];

            if (next_index + 1 < num_elem) {
                Node* tmp_node = m_heap[next_index + 1];
                if (tmp_node->prio <= next_node->prio) {
                    ++next_index;
                    next_node = tmp_node;
                }
            }

            if (next_node->prio >= prio) {
                break;
            }

            // Shift the child up and continue one level lower.
            m_heap[index] = next_node;
            next_node->index = index;

            index = next_index;
        }

        m_heap[index] = node;
        node->index = index;
    }

    /** Array of nodes, arranged as a binary min-heap. */
    std::vector<Node*> m_heap;
};
+{ + using Node = DisjointSetNode; + + std::unique_ptr nodes[27]; + for (int i = 0; i < 27; ++i) { + nodes[i].reset(new Node(i)); + } + + std::vector level1; + for (int i = 0; i < 9; ++i) { + Node* m = nodes[3*i].get(); + for (int k = 1; k < 3; ++k) { + m = m->merge(nodes[3*i+k].get()); + } + m->set_label(100 + i); + level1.push_back(m); + } + + std::vector level2; + for (int i = 0; i < 3; ++i) { + Node* m = level1[3*i]; + for (int k = 1; k < 3; ++k) { + m = m->merge(level1[3*i+k]); + } + m->set_label(200 + i); + level2.push_back(m); + } + + Node* m = level2[0]; + for (int k = 1; k < 3; ++k) { + m = m->merge(level2[k]); + } + m->set_label(300); + + for (int i = 0; i < 27; ++i) { + BOOST_TEST(nodes[i]->find() == 300); + } + + for (int i = 0; i < 3; ++i) { + level2[i]->detach(200 + i); + } + + for (int i = 0; i < 27; ++i) { + BOOST_TEST(nodes[i]->find() == 200 + i / 9); + } + + for (int i = 0; i < 6; ++i) { + level1[i]->detach(100 + i); + } + + for (int i = 0; i < 18; ++i) { + BOOST_TEST(nodes[i]->find() == 100 + i / 3); + } + for (int i = 18; i < 27; ++i) { + BOOST_TEST(nodes[i]->find() == 202); + } + + for (int i = 6; i < 9; ++i) { + level1[i]->detach(100 + i); + } + + for (int i = 0; i < 27; ++i) { + BOOST_TEST(nodes[i]->find() == 100 + i / 3); + } + + for (int i = 6; i < 27; ++i) { + nodes[i]->detach(i); + } + + for (int i = 0; i < 6; ++i) { + BOOST_TEST(nodes[i]->find() == 100 + i / 3); + } + for (int i = 6; i < 27; ++i) { + BOOST_TEST(nodes[i]->find() == i); + } + + for (int i = 0; i < 6; i++) { + nodes[i]->detach(i); + } + + for (int i = 0; i < 27; ++i) { + BOOST_TEST(nodes[i]->find() == i); + } +} + +BOOST_AUTO_TEST_CASE(test_random) +{ + using Node = DisjointSetNode; + + std::mt19937 rng(12345); + + const int num_nodes = 1000; + std::unique_ptr nodes[num_nodes]; + for (int i = 0; i < num_nodes; ++i) { + nodes[i].reset(new Node(i)); + } + + std::unordered_map blossoms; + std::unordered_map> sub_blossoms; + std::vector top_blossoms; + + for (int i = 0; i < 
num_nodes; ++i) { + blossoms[i] = nodes[i].get(); + top_blossoms.push_back(i); + } + + int next_blossom = num_nodes; + + auto make_blossom = [&]() { + int b = next_blossom; + ++next_blossom; + + int nsub = 1 + 2 * std::uniform_int_distribution<>(1, 4)(rng); + std::vector subs(nsub); + + for (int i = 0; i < nsub; ++i) { + int p = std::uniform_int_distribution<>(0, top_blossoms.size() - 1)(rng); + subs[i] = top_blossoms[p]; + top_blossoms.erase(top_blossoms.begin() + p); + } + + Node *m = blossoms[subs[0]]; + for (int i = 1; i < nsub; ++i) { + m = m->merge(blossoms[subs[i]]); + } + m->set_label(b); + + blossoms[b] = m; + sub_blossoms[b] = std::move(subs); + top_blossoms.push_back(b); + }; + + auto expand_blossom = [&](int b) { + top_blossoms.erase( + std::find(top_blossoms.begin(), top_blossoms.end(), b)); + for (int t : sub_blossoms[b]) { + blossoms[t]->detach(t); + top_blossoms.push_back(t); + } + blossoms.erase(b); + sub_blossoms.erase(b); + }; + + auto check_membership = [&](int b, int label) { + std::vector q; + q.push_back(b); + while (! 
q.empty()) {
+            b = q.back();
+            q.pop_back();
+            if (b < num_nodes) {
+                BOOST_TEST(nodes[b]->find() == label);
+            } else {
+                for (int t : sub_blossoms[b]) {
+                    q.push_back(t);
+                }
+            }
+        }
+    };
+
+    for (int k = 0; k < 100; ++k) {
+        make_blossom();
+    }
+
+    for (int b : top_blossoms) {
+        check_membership(b, b);
+    }
+
+    std::vector top_groups;
+    for (int b : top_blossoms) {
+        if (b >= num_nodes) {
+            top_groups.push_back(b);
+        }
+    }
+    std::shuffle(top_groups.begin(), top_groups.end(), rng);
+    for (int k = 0; k < 50; ++k) {
+        expand_blossom(top_groups[k]);
+    }
+    top_groups.clear();
+
+    for (int b : top_blossoms) {
+        check_membership(b, b);
+    }
+
+    for (int k = 0; k < 50; ++k) {
+        make_blossom();
+    }
+
+    for (int b : top_blossoms) {
+        check_membership(b, b);
+    }
+
+    for (int b : top_blossoms) {
+        if (b >= num_nodes) {
+            top_groups.push_back(b);
+        }
+    }
+    std::shuffle(top_groups.begin(), top_groups.end(), rng);
+    for (int b : top_groups) {
+        expand_blossom(b);
+    }
+
+    for (int b : top_blossoms) {
+        check_membership(b, b);
+    }
+}
+
+BOOST_AUTO_TEST_SUITE_END()
+
+
+/* ********** Test SplittablePriorityQueue ********** */
+
+// Helper: assert that a (priority, data) pair, as returned by find_min()
+// in the tests below, has the expected priority in .first and the expected
+// data in .second.
+// NOTE(review): several template argument lists in this part of the patch
+// look truncated (e.g. "template" without a parameter list, "std::pair&",
+// "SplittablePriorityQueue;") -- verify against the upstream file.
+template
+static void check_min_elem(const std::pair& pair,
+                           PrioType prio,
+                           const DataType& data)
+{
+    BOOST_TEST(pair.first == prio);
+    BOOST_TEST(pair.second == data);
+}
+
+
+BOOST_AUTO_TEST_SUITE(test_splittable_queue)
+
+// Queue with a single element: checks empty() before and after insertion,
+// and that update() only changes the reported minimum when the new priority
+// is lower than the current one.
+BOOST_AUTO_TEST_CASE(test_single)
+{
+    using Queue = SplittablePriorityQueue;
+    Queue q;
+
+    BOOST_TEST(q.empty() == true);
+
+    // Arguments after the node pointer: index 3, priority 4, data 101.
+    Queue::Node n;
+    q.insert(&n, 3, 4, 101);
+
+    BOOST_TEST(q.empty() == false);
+    check_min_elem(q.find_min(), 4, 101);
+
+    // Higher priority value: the test expects (4, 101) to be kept.
+    q.update(&n, 5, 102);
+    check_min_elem(q.find_min(), 4, 101);
+
+    // Lower priority value: the test expects (3, 103) to take over.
+    q.update(&n, 3, 103);
+    check_min_elem(q.find_min(), 3, 103);
+
+    q.clear();
+    BOOST_TEST(q.empty() == true);
+}
+
+// Five elements with string data: checks find_min() after inserts and
+// updates, then after split(3), and after inserting a new element into the
+// remaining first queue.
+BOOST_AUTO_TEST_CASE(test_simple)
+{
+    using Queue = SplittablePriorityQueue;
+    Queue q;
+    Queue::Node n1, n2, n3, n4, n5, nx;
+
+    q.insert(&n1, 1, 5, "a");
+    q.insert(&n2, 2, 6, "b");
+    q.insert(&n3, 3, 7, "c");
+    q.insert(&n4, 4, 4, "d");
+    q.insert(&n5, 5, 3, "e");
+    check_min_elem(q.find_min(), 3, std::string("e"));
+
+    // n1 drops from priority 5 to 4; the minimum is still n5 (3, "e").
+    q.update(&n1, 4, "f");
+    check_min_elem(q.find_min(), 3, std::string("e"));
+
+    // n3 drops from priority 7 to 2 and becomes the minimum.
+    q.update(&n3, 2, "h");
+    check_min_elem(q.find_min(), 2, std::string("h"));
+
+    // Expected after split(3): q reports n1 (index 1) as its minimum and
+    // q2 reports n3 (index 3), i.e. index 3 ends up in the returned queue.
+    Queue q2 = q.split(3);
+    check_min_elem(q.find_min(), 4, std::string("f"));
+    check_min_elem(q2.find_min(), 2, std::string("h"));
+
+    // Inserting into q must not affect the minimum reported by q2.
+    q.insert(&nx, 3, 1, "x");
+    check_min_elem(q.find_min(), 1, std::string("x"));
+    check_min_elem(q2.find_min(), 2, std::string("h"));
+
+    q.clear();
+    q2.clear();
+}
+
+// Splitting an empty queue is expected to yield two empty queues.
+BOOST_AUTO_TEST_CASE(test_split_empty)
+{
+    using Queue = SplittablePriorityQueue;
+    Queue q;
+    Queue q2 = q.split(10);
+    BOOST_TEST(q.empty() == true);
+    BOOST_TEST(q2.empty() == true);
+}
+
+// Splits where all elements end up on one side: a threshold above every
+// index (7) leaves everything in q, a threshold equal to the lowest
+// index (4) moves everything to the returned queue.
+BOOST_AUTO_TEST_CASE(test_split_oneway)
+{
+    using Queue = SplittablePriorityQueue;
+    Queue q;
+    Queue::Node n4, n5, n6;
+    q.insert(&n4, 4, 3, "a");
+    q.insert(&n5, 5, 4, "b");
+    q.insert(&n6, 6, 2, "c");
+    Queue q2 = q.split(7);
+    BOOST_TEST(q.empty() == false);
+    BOOST_TEST(q2.empty() == true);
+    check_min_elem(q.find_min(), 2, std::string("c"));
+    q.clear();
+
+    // Same nodes are inserted again after clear(), then split at index 4.
+    q.insert(&n4, 4, 3, "a");
+    q.insert(&n5, 5, 4, "b");
+    q.insert(&n6, 6, 2, "c");
+    q2 = q.split(4);
+    BOOST_TEST(q.empty() == true);
+    BOOST_TEST(q2.empty() == false);
+    check_min_elem(q2.find_min(), 2, std::string("c"));
+    q2.clear();
+}
+
+// Fifteen elements with indices 0..14, inserted in non-sorted index order,
+// then split twice (at index 10 and at index 5) into three queues.
+BOOST_AUTO_TEST_CASE(test_larger)
+{
+    using Queue = SplittablePriorityQueue;
+    Queue q;
+    Queue::Node nodes[15];
+
+    q.insert(&nodes[7], 7, 5, "h");
+    q.insert(&nodes[6], 6, 4, "g");
+    q.insert(&nodes[8], 8, 2, "i");
+    q.insert(&nodes[5], 5, 4, "f");
+    q.insert(&nodes[9], 9, 6, "j");
+    q.insert(&nodes[4], 4, 8, "e");
+    q.insert(&nodes[10], 10, 4, "k");
+    q.insert(&nodes[3], 3, 5, "d");
+    q.insert(&nodes[11], 11, 6, "l");
+    q.insert(&nodes[2], 2, 7, "c");
+    q.insert(&nodes[12], 12, 8, "m");
+    q.insert(&nodes[1], 1, 3, "b");
+    q.insert(&nodes[13], 13, 1, "n");
+    q.insert(&nodes[0], 0, 9, "a");
+    q.insert(&nodes[14], 14, 7, "o");
+
+    check_min_elem(q.find_min(), 1, std::string("n"));
+
+    // Expected: q reports the minimum of indices 0..9 (index 8),
+    // q2 the minimum of indices 10..14 (index 13).
+    Queue q2 = q.split(10);
+    check_min_elem(q.find_min(), 2, std::string("i"));
+    check_min_elem(q2.find_min(), 1, std::string("n"));
+
+    // Expected: q reports the minimum of indices 0..4 (index 1),
+    // q1 the minimum of indices 5..9 (index 8); q2 is unchanged.
+    Queue q1 = q.split(5);
+    check_min_elem(q.find_min(), 3, std::string("b"));
+    check_min_elem(q1.find_min(), 2, std::string("i"));
+    check_min_elem(q2.find_min(), 1, std::string("n"));
+
+    q.clear();
+    q1.clear();
+    q2.clear();
+}
+
+// Randomized test with a fixed seed: 200 rounds of 1000 insert/update
+// operations followed by a split at a random index. The map "elems" is the
+// reference model, keyed by index and holding (node, priority, data).
+BOOST_AUTO_TEST_CASE(test_random)
+{
+    using Queue = SplittablePriorityQueue;
+
+    std::mt19937 rng(23456);
+    std::uniform_int_distribution<> index_distribution(0, 1000000);
+    std::uniform_int_distribution<> prio_distribution(0, 1000000);
+
+    Queue q;
+    std::map, int, int>> elems;
+    int next_data = 1;
+
+    for (int i = 0; i < 200; ++i) {
+
+        // Insert stuff into the queue.
+        for (int k = 0; k < 1000; ++k) {
+            int idx = index_distribution(rng);
+            int prio = prio_distribution(rng);
+            int data = next_data;
+            ++next_data;
+            auto it = elems.find(idx);
+            if (it != elems.end()) {
+                // Index already present: update the existing node. The
+                // model only takes the new (prio, data) when the new
+                // priority is lower than the stored one.
+                auto& v = it->second;
+                Queue::Node* nptr = std::get<0>(v).get();
+                int pprio = std::get<1>(v);
+                q.update(nptr, prio, data);
+                if (prio < pprio) {
+                    std::get<1>(v) = prio;
+                    std::get<2>(v) = data;
+                }
+            } else {
+                // New index: the model takes ownership of the node.
+                std::unique_ptr nptr(new Queue::Node);
+                q.insert(nptr.get(), idx, prio, data);
+                elems[idx] = std::make_tuple(std::move(nptr), prio, data);
+            }
+        }
+
+        // Check min element.
+        int min_prio = INT_MAX;
+        int min_data = 0;
+        for (const auto& v : elems) {
+            int prio = std::get<1>(v.second);
+            int data = std::get<2>(v.second);
+            if (prio < min_prio) {
+                min_prio = prio;
+                min_data = data;
+            }
+        }
+
+        check_min_elem(q.find_min(), min_prio, min_data);
+
+        // Split the queue.
+        int threshold = index_distribution(rng);
+        Queue q2 = q.split(threshold);
+
+        // Keep one queue and discard the other.
+        // The queue is cleared before the model erases the matching
+        // entries (which destroys the nodes owned by the model).
+        if (rng() % 2 == 0) {
+            // Keep q2: the model drops all indices below the threshold.
+            q.clear();
+            q = std::move(q2);
+            elems.erase(elems.begin(), elems.lower_bound(threshold));
+        } else {
+            // Keep q: the model drops all indices >= threshold.
+            q2.clear();
+            elems.erase(elems.lower_bound(threshold), elems.end());
+        }
+
+        // Check min element.
+        min_prio = INT_MAX;
+        min_data = 0;
+        for (const auto& v : elems) {
+            int prio = std::get<1>(v.second);
+            int data = std::get<2>(v.second);
+            if (prio < min_prio) {
+                min_prio = prio;
+                min_data = data;
+            }
+        }
+
+        // min_prio is still INT_MAX when the model is empty; find_min()
+        // is not called in that case.
+        if (min_prio < INT_MAX) {
+            check_min_elem(q.find_min(), min_prio, min_data);
+        }
+    }
+
+    q.clear();
+}
+
+BOOST_AUTO_TEST_SUITE_END()
+
+
+/* ********** Test PriorityQueue ********** */
+
+BOOST_AUTO_TEST_SUITE(test_priority_queue)
+
+// A default-constructed queue is expected to be empty.
+BOOST_AUTO_TEST_CASE(test_empty)
+{
+    using Queue = PriorityQueue;
+    Queue q;
+    BOOST_TEST(q.empty() == true);
+}
+
+// Single element: insert, lower its priority through update(), remove it.
+BOOST_AUTO_TEST_CASE(test_single)
+{
+    using Queue = PriorityQueue;
+    Queue q;
+
+    // Arguments after the node pointer: priority 5, data "a".
+    Queue::Node n1;
+    q.insert(&n1, 5, "a");
+
+    BOOST_TEST(q.empty() == false);
+    check_min_elem(q.find_min(), 5, std::string("a"));
+
+    q.update(&n1, 3, "a");
+    check_min_elem(q.find_min(), 3, std::string("a"));
+
+    q.remove(&n1);
+    BOOST_TEST(q.empty() == true);
+}
+
+// Ten elements with char data: checks find_min() after a fixed sequence of
+// inserts, priority-lowering updates and removals, then clear().
+BOOST_AUTO_TEST_CASE(test_simple)
+{
+    using Queue = PriorityQueue;
+    Queue q;
+    Queue::Node nodes[10];
+
+    q.insert(&nodes[0], 9, 'a');
+    check_min_elem(q.find_min(), 9, 'a');
+
+    q.insert(&nodes[1], 4, 'b');
+    check_min_elem(q.find_min(), 4, 'b');
+
+    q.insert(&nodes[2], 7, 'c');
+    check_min_elem(q.find_min(), 4, 'b');
+
+    q.insert(&nodes[3], 5, 'd');
+    check_min_elem(q.find_min(), 4, 'b');
+
+    q.insert(&nodes[4], 8, 'e');
+    check_min_elem(q.find_min(), 4, 'b');
+
+    q.insert(&nodes[5], 6, 'f');
+    check_min_elem(q.find_min(), 4, 'b');
+
+    // 'g' ties with 'b' at priority 4; no check is made until 'i'
+    // (priority 2) becomes the unique minimum.
+    q.insert(&nodes[6], 4, 'g');
+    q.insert(&nodes[7], 5, 'h');
+    q.insert(&nodes[8], 2, 'i');
+    check_min_elem(q.find_min(), 2, 'i');
+
+    q.insert(&nodes[9], 6, 'j');
+    check_min_elem(q.find_min(), 2, 'i');
+
+    q.update(&nodes[2], 1, 'c');
+    check_min_elem(q.find_min(), 1, 'c');
+
+    q.update(&nodes[4], 3, 'e');
+    check_min_elem(q.find_min(), 1, 'c');
+
+    // Remove elements and check the next minimum each time.
+    q.remove(&nodes[2]);
+    check_min_elem(q.find_min(), 2, 'i');
+
+    q.remove(&nodes[8]);
+    check_min_elem(q.find_min(), 3, 'e');
+
+    q.remove(&nodes[4]);
+    q.remove(&nodes[1]);
+    check_min_elem(q.find_min(), 4, 'g');
+
+    q.remove(&nodes[3]);
+    q.remove(&nodes[9]);
+    check_min_elem(q.find_min(), 4, 'g');
+
+    q.remove(&nodes[6]);
+    check_min_elem(q.find_min(), 5, 'h');
+
+    BOOST_TEST(q.empty() == false);
+    q.clear();
+    BOOST_TEST(q.empty() == true);
+}
+
+// Randomized test with a fixed seed: fills the queue with num_elem
+// elements, runs 10000 rounds of update / remove / insert, then removes
+// everything. The vector "elems" is the reference model holding
+// (node, priority, data) for every element currently in the queue.
+BOOST_AUTO_TEST_CASE(test_random)
+{
+    using Queue = PriorityQueue;
+    Queue q;
+
+    const int num_elem = 1000;
+    std::vector, int, int>> elems;
+    int next_data = 0;
+
+    std::mt19937 rng(34567);
+
+    // Verify find_min() against the model: the reported (prio, data) pair
+    // must exist in elems and its priority must equal the lowest priority
+    // in elems. Any element with the lowest priority is accepted.
+    auto check = [&q,&elems]() {
+        int min_prio, min_data;
+        std::tie(min_prio, min_data) = q.find_min();
+        int best_prio = INT_MAX;
+        bool found = false;
+        for (const auto& v : elems) {
+            int this_prio = std::get<1>(v);
+            int this_data = std::get<2>(v);
+            best_prio = std::min(best_prio, this_prio);
+            if ((this_prio == min_prio) && (this_data == min_data)) {
+                found = true;
+            }
+        }
+        BOOST_TEST(found == true);
+        BOOST_TEST(min_prio == best_prio);
+    };
+
+    // Initial fill; each element gets a unique data value.
+    for (int i = 0; i < num_elem; ++i) {
+        ++next_data;
+        int prio = std::uniform_int_distribution<>(0, 1000000)(rng);
+        std::unique_ptr nptr(new Queue::Node);
+        q.insert(nptr.get(), prio, next_data);
+        elems.push_back(std::make_tuple(std::move(nptr), prio, next_data));
+        check();
+    }
+
+    for (int i = 0; i < 10000; ++i) {
+        // Lower (or keep) the priority of a random element.
+        int p = std::uniform_int_distribution<>(0, num_elem - 1)(rng);
+        Queue::Node* node = std::get<0>(elems[p]).get();
+        int prio = std::get<1>(elems[p]);
+        int data = std::get<2>(elems[p]);
+        prio = std::uniform_int_distribution<>(0, prio)(rng);
+        q.update(node, prio, data);
+        std::get<1>(elems[p]) = prio;
+        check();
+
+        // Remove a random element from the queue, then from the model.
+        p = std::uniform_int_distribution<>(0, num_elem - 1)(rng);
+        node = std::get<0>(elems[p]).get();
+        q.remove(node);
+        elems.erase(elems.begin() + p);
+        check();
+
+        // Insert a new element so the size is num_elem again.
+        ++next_data;
+        prio = std::uniform_int_distribution<>(0, 1000000)(rng);
+        std::unique_ptr nptr(new Queue::Node);
+        q.insert(nptr.get(), prio, next_data);
+        elems.push_back(std::make_tuple(std::move(nptr), prio, next_data));
+        check();
+    }
+
+    // Remove all elements in random order; check() calls find_min(), so it
+    // is skipped once the model is empty.
+    for (int i = 0; i < num_elem; ++i) {
+        int p = std::uniform_int_distribution<>(0, num_elem - 1 - i)(rng);
+        Queue::Node* node = std::get<0>(elems[p]).get();
+        q.remove(node);
+        elems.erase(elems.begin() + p);
+        if (! elems.empty()) {
+            check();
+        }
+    }
+
+    BOOST_TEST(q.empty() == true);
+}
+
+BOOST_AUTO_TEST_SUITE_END()