/*
 * Data structures for matching.
 */

#include <algorithm>
#include <cassert>
#include <cstddef>
#include <limits>
#include <utility>
#include <vector>


/**
 * Element in a disjoint set.
 *
 * Every instance of DisjointSetNode is a member of exactly one disjoint set.
 * Some instances of DisjointSetNode also represent a set.
 * Every set has a "label".
 *
 * The following operations can be done efficiently:
 *  - Find the label of the set that contains a given node.
 *  - Merge two sets.
 *  - Undo a previous merge step.
 *
 * Internally, a set is represented by a tree of DisjointSetNode instances
 * that are linked through parent pointers. The smaller tree is always
 * attached below the root of the larger tree, which keeps the depth of
 * the tree at most O(log(n)), where "n" is the number of elements in the set.
 *
 * See also https://en.wikipedia.org/wiki/Disjoint-set_data_structure
 */
template <typename LabelType>
class DisjointSetNode
{
public:
    /**
     * Create a new element and a new set containing only the new element.
     *
     * The new instance represents the new element as well as the new set.
     *
     * This function takes time O(1).
     *
     * @param label Label of the new set.
     */
    explicit DisjointSetNode(const LabelType& label)
      : m_label(label), m_parent(nullptr), m_size(1)
    { }

    // Prevent copying; other nodes hold pointers to this instance.
    DisjointSetNode(const DisjointSetNode&) = delete;
    DisjointSetNode& operator=(const DisjointSetNode&) = delete;

    /**
     * Return the label of the set that contains the element represented
     * by this instance.
     *
     * This function takes time O(log(n)).
     */
    LabelType find() const
    {
        // Follow parent pointers up to the root; the root holds the label.
        const DisjointSetNode* p = this;
        while (p->m_parent != nullptr) {
            p = p->m_parent;
        }
        return p->m_label;
    }

    /**
     * Change the label of the set represented by this instance.
     *
     * Only sets have labels. Elements do not have labels.
     * Some instances of DisjointSetNode do not represent a set,
     * and it is not allowed to assign labels to such instances.
     *
     * This function takes time O(1).
     */
    void set_label(const LabelType& label)
    {
        assert(m_parent == nullptr);
        m_label = label;
    }

    /**
     * Merge the sets represented by this instance and by "other".
     *
     * Both instances must represent a set (not merely an element).
     * The merged set will initially have no label (both labels are reset
     * to a value-initialized LabelType).
     *
     * This function takes time O(1).
     *
     * @return Instance representing the merged set.
     */
    DisjointSetNode* merge(DisjointSetNode* other)
    {
        assert(m_parent == nullptr);
        assert(other->m_parent == nullptr);
        assert(other != this);

        m_label = LabelType{};
        other->m_label = LabelType{};

        // Attach the smaller tree below the larger one.
        if (m_size < other->m_size) {
            this->m_parent = other;
            other->m_size += this->m_size;
            return other;
        } else {
            other->m_parent = this;
            this->m_size += other->m_size;
            return this;
        }
    }

    /**
     * Undo a previous "merge" step.
     *
     * This function detaches a subset from a set into which it was
     * previously merged, and re-establishes the subset as an independent
     * set.
     *
     * The following conditions must hold:
     *  - At some point in the past, there was a group of K instances
     *    of DisjointSetNode (s_1, s_2, ..., s_k) that each represented
     *    a set.
     *  - These sets were later merged through (K-1) merge steps.
     *  - Since that point in time, no other merges into the same set
     *    occurred, disregarding any merges that have already been undone.
     *
     * The group of (K-1) merge steps can be undone by calling "detach"
     * on each of the K instances. These K calls may occur in any order.
     * Afterwards, each of these instances will again represent the same set
     * that it represented originally. The original labels of these sets
     * are lost. Labels must therefore be re-assigned through parameters
     * to the "detach" calls.
     *
     * This function takes time O(log(n)).
     *
     * @param label New label for the re-established set.
     */
    void detach(const LabelType& label)
    {
        // Remove the size of this subtree from all ancestors.
        DisjointSetNode* p = m_parent;
        while (p != nullptr) {
            p->m_size -= m_size;
            p = p->m_parent;
        }

        m_parent = nullptr;
        m_label = label;
    }

private:
    /** Label of the set; only meaningful in the root of a tree. */
    LabelType m_label;

    /** Parent in the tree, or "nullptr" if this instance is a root. */
    DisjointSetNode* m_parent;

    /** Number of elements in the subtree rooted at this instance. */
    std::size_t m_size;
};


/**
 * Priority queue that can be split into separate queues.
 *
 * Each element in the queue has an index, a priority and attached data.
 * The following operations can be done efficiently:
 *  - Find the element with lowest priority.
 *  - Insert an element with given index, priority and data.
 *  - Update the priority of an existing element.
 *  - Split the queue such that all elements with index below a specified
 *    threshold end up in one queue and all remaining elements in
 *    a separate queue.
 *
 * The implementation is essentially an AVL tree, with minimum-priority
 * tracking added to it. The implementation follows the outlines of
 * https://en.wikipedia.org/wiki/Avl_tree
 */
template <typename PrioType, typename DataType>
class SplittablePriorityQueue
{
public:
    typedef unsigned int IndexType;

    /**
     * Element in a queue.
     *
     * Storage for the element is owned by the caller; the queue only
     * links Node instances together through pointers.
     */
    struct Node
    {
    public:
        Node()
          : parent(nullptr), left_child(nullptr), right_child(nullptr),
            height(0)
        { }

        ~Node()
        {
            // Check that the node is not contained in any queue.
            assert(height == 0);
        }

        // Prevent copying.
        Node(const Node&) = delete;
        Node& operator=(const Node&) = delete;

    private:
        IndexType index;        // search key of this element
        PrioType prio;          // priority of this element
        DataType data;          // data attached to this element
        PrioType best_prio;     // lowest priority within this subtree
        DataType best_data;     // data of the element with "best_prio"
        Node* parent;
        Node* left_child;
        Node* right_child;
        unsigned int height;    // subtree height; 0 = not in any queue

        friend class SplittablePriorityQueue;
    };

    /** Construct an empty queue. */
    SplittablePriorityQueue()
      : m_root(nullptr)
    { }

    ~SplittablePriorityQueue()
    {
        // Check that the queue is empty.
        assert(! m_root);
    }

    // Prevent copying.
    SplittablePriorityQueue(const SplittablePriorityQueue&) = delete;
    SplittablePriorityQueue& operator=(const SplittablePriorityQueue&) = delete;

    /** Construct a queue and move elements from "other" to the new queue. */
    SplittablePriorityQueue(SplittablePriorityQueue&& other) noexcept
      : m_root(other.m_root)
    {
        other.m_root = nullptr;
    }

    /**
     * Move all elements from "other" to this queue.
     *
     * This queue must be empty before the assignment.
     */
    SplittablePriorityQueue& operator=(SplittablePriorityQueue&& other) noexcept
    {
        assert(! m_root);
        std::swap(m_root, other.m_root);
        return *this;
    }

    /**
     * Remove all elements from the queue.
     *
     * Every removed node is reset to the "unused" state (height 0).
     *
     * This function takes time O(n).
     */
    void clear()
    {
        Node* node = m_root;
        m_root = nullptr;

        // Walk the tree without recursion: descend into a child while
        // cutting the link to it; when a node has no children left,
        // release it and step back to its parent.
        while (node) {
            Node* prev = node;
            if (node->left_child) {
                node = node->left_child;
                prev->left_child = nullptr;
            } else if (node->right_child) {
                node = node->right_child;
                prev->right_child = nullptr;
            } else {
                node = node->parent;
                prev->parent = nullptr;
                prev->height = 0;
            }
        }
    }

    /** Return true if the queue is empty. */
    bool empty() const
    {
        return (! m_root);
    }

    /**
     * Return the minimum-priority element as tuple (prio, data).
     *
     * The queue must be non-empty.
     *
     * This function takes time O(1).
     */
    std::pair<PrioType, DataType> find_min() const
    {
        assert(m_root);
        return std::make_pair(m_root->best_prio, m_root->best_data);
    }

    /**
     * Initialize the specified Node instance and insert it into the queue.
     *
     * The specified Node instance must currently be unused; not contained
     * in any queue. The index must differ from the index of every element
     * already in the queue.
     *
     * The queue stores a pointer to the Node instance. The caller must
     * ensure that the instance remains valid for as long as it is contained
     * in any queue.
     *
     * This function takes time O(log(n)).
     */
    void insert(Node* node,
                IndexType index,
                PrioType prio,
                const DataType& data)
    {
        assert(node->height == 0);

        node->index = index;
        node->prio = prio;
        node->data = data;
        node->best_prio = prio;
        node->best_data = data;
        node->parent = nullptr;
        node->left_child = nullptr;
        node->right_child = nullptr;
        node->height = 1;

        if (! m_root) {
            m_root = node;
            return;
        }

        // Descend to the leaf position where the new index belongs.
        Node* p = m_root;
        while (true) {
            assert(index != p->index);
            if (index < p->index) {
                if (! p->left_child) {
                    p->left_child = node;
                    break;
                }
                p = p->left_child;
            } else {
                if (! p->right_child) {
                    p->right_child = node;
                    break;
                }
                p = p->right_child;
            }
        }
        node->parent = p;

        // Restore heights, best-priority info and AVL balance up to the root.
        rebalance_up(node);
    }

    /**
     * Decrease priority and update data of an existing element.
     *
     * Nothing is done if the specified priority is greater or equal
     * to the current priority of the existing element.
     *
     * The node must be contained in this queue.
     *
     * This function takes time O(log(n)).
     */
    void update(Node* node, PrioType prio, const DataType& data)
    {
        assert(node->height != 0);

        // Do nothing unless the new priority is strictly lower.
        if (prio >= node->prio) {
            return;
        }

        // Update the node itself.
        node->prio = prio;
        node->data = data;

        // Update the best-prio information in the node and its ancestors.
        while (true) {
            repair_node_best_prio(node);
            if (! node->parent) {
                break;
            }
            node = node->parent;
        }

        // Check that the node belongs to this queue.
        assert(node == m_root);
    }

    /**
     * Split the queue by index.
     *
     * All elements with index less than "threshold" remain in the queue.
     * All elements with index greater or equal to "threshold" are moved
     * to a new queue. That new queue is returned.
     *
     * This function takes time O(log(n)).
     */
    SplittablePriorityQueue split(IndexType threshold)
    {
        SplittablePriorityQueue new_queue;

        // Special case for an empty queue.
        if (! m_root) {
            return new_queue;
        }

        // Descend down the tree along a path as close as possible to the
        // threshold. Stop when we reach a subtree that is entirely on
        // one side of the threshold.
        Node* node = m_root;
        while (true) {
            if (node->index < threshold) {
                /*
                 * This node is below the threshold. Descend right.
                 *
                 *             (<T)
                 *            /    \
                 *           /      \
                 *       (<T)        (?)  <--- go here
                 */
                if (! node->right_child) {
                    break;
                }
                node = node->right_child;
            } else if (node->index > threshold) {
                /*
                 * This node is above the threshold. Descend left.
                 *
                 *                  (>T)
                 *                 /    \
                 *                /      \
                 *  go here --> (?)       (>T)
                 */
                if (! node->left_child) {
                    break;
                }
                node = node->left_child;
            } else {
                /*
                 * This node is exactly at the threshold.
                 * Descend left if possible, then stop.
                 *
                 *             (=T)
                 *            /    \
                 *           /      \
                 *       (<T)        (>T)
                 */
                if (node->left_child) {
                    node = node->left_child;
                }
                break;
            }
        }

        // Build "left_tree" with nodes < threshold,
        // and "right_tree" with nodes >= threshold.
        Node* left_tree = nullptr;
        Node* right_tree = nullptr;

        // Initially, either the left or right tree consists of only
        // the subtree where the descending path ended, and the other
        // tree is empty.
        if (node->index < threshold) {
            left_tree = node;
        } else {
            right_tree = node;
        }

        // Detach node from its parent.
        Node* parent = node->parent;
        detach_node(node);

        // Retrace up the tree.
        // On the way up, join each node to the left or right subtree.
        while (parent) {

            // Move up to parent.
            node = parent;
            parent = node->parent;
            detach_node(node);

            if (node->index < threshold) {
                /*
                 * Join node (with its descendants) to the left tree.
                 *
                 *        N   <--- new left_tree
                 *       / .
                 *      X   .
                 *     / \   (left_tree)
                 */
                assert(! node->right_child);
                left_tree = join(node->left_child, node, left_tree);
            } else {
                /*
                 * Join node (with its descendants) to the right tree.
                 *
                 *                N   <--- new right_tree
                 *               . \
                 *              .   X
                 *   (right_tree)  / \
                 */
                assert(! node->left_child);
                right_tree = join(right_tree, node, node->right_child);
            }
        }

        // Keep the left tree in this instance.
        m_root = left_tree;

        // Return the right tree as a new queue.
        new_queue.m_root = right_tree;
        return new_queue;
    }

private:
    /** Return node height, or 0 if node == nullptr. */
    static unsigned int get_node_height(const Node* node)
    {
        return node ? node->height : 0;
    }

    /** Detach a node from its parent (no-op if the node has no parent). */
    static void detach_node(Node* node)
    {
        Node* parent = node->parent;
        if (parent) {
            node->parent = nullptr;
            if (parent->left_child == node) {
                parent->left_child = nullptr;
            } else if (parent->right_child == node) {
                parent->right_child = nullptr;
            }
        }
    }

    /**
     * Repair best-priority information in the specified node.
     *
     * The best priority is the minimum over the node's own priority and
     * the best priorities of its children. Ties are resolved in favour of
     * the node itself, then the left child.
     */
    static void repair_node_best_prio(Node* node)
    {
        PrioType best_prio = node->prio;
        DataType best_data = node->data;

        Node* lchild = node->left_child;
        if (lchild && (lchild->best_prio < best_prio)) {
            best_prio = lchild->best_prio;
            best_data = lchild->best_data;
        }

        Node* rchild = node->right_child;
        if (rchild && (rchild->best_prio < best_prio)) {
            best_prio = rchild->best_prio;
            best_data = rchild->best_data;
        }

        node->best_prio = best_prio;
        node->best_data = best_data;
    }

    /**
     * Repair the height and best-priority information of a node
     * after modifying its children.
     *
     * After repairing a node, it is typically necessary to also repair
     * its ancestors.
     */
    static void repair_node(Node* node)
    {
        Node* lchild = node->left_child;
        Node* rchild = node->right_child;

        // Repair node height.
        node->height =
            1 + std::max(get_node_height(lchild), get_node_height(rchild));

        // Repair best-priority.
        repair_node_best_prio(node);
    }

    /** Rotate the subtree to the left and return the new root of the subtree. */
    static Node* rotate_left(Node* node)
    {
        /*
         *     N                C
         *    / \              / \
         *   A   C    --->    N   D
         *      / \          / \
         *     B   D        A   B
         */
        Node* parent = node->parent;
        Node* new_top = node->right_child;
        assert(new_top);

        Node* nb = new_top->left_child;
        node->right_child = nb;
        if (nb) {
            nb->parent = node;
        }

        new_top->left_child = node;
        node->parent = new_top;

        // Hook the new subtree root into the former parent of "node".
        new_top->parent = parent;
        if (parent) {
            if (parent->left_child == node) {
                parent->left_child = new_top;
            } else if (parent->right_child == node) {
                parent->right_child = new_top;
            } else {
                assert(false);
            }
        }

        // Repair bottom-up: "node" is now a child of "new_top".
        repair_node(node);
        repair_node(new_top);

        return new_top;
    }

    /** Rotate the subtree to the right and return the new root of the subtree. */
    static Node* rotate_right(Node* node)
    {
        /*
         *       N            B
         *      / \          / \
         *     B   D  --->  A   N
         *    / \              / \
         *   A   C            C   D
         */
        Node* parent = node->parent;
        Node* new_top = node->left_child;
        assert(new_top);

        Node* nc = new_top->right_child;
        node->left_child = nc;
        if (nc) {
            nc->parent = node;
        }

        new_top->right_child = node;
        node->parent = new_top;

        // Hook the new subtree root into the former parent of "node".
        new_top->parent = parent;
        if (parent) {
            if (parent->left_child == node) {
                parent->left_child = new_top;
            } else if (parent->right_child == node) {
                parent->right_child = new_top;
            } else {
                assert(false);
            }
        }

        // Repair bottom-up: "node" is now a child of "new_top".
        repair_node(node);
        repair_node(new_top);

        return new_top;
    }

    /**
     * Repair and rebalance the specified node and its ancestors.
     *
     * Updates "m_root" to the root that is reached at the end of the walk.
     */
    void rebalance_up(Node* node)
    {
        while (true) {
            unsigned int lh = get_node_height(node->left_child);
            unsigned int rh = get_node_height(node->right_child);
            if (lh > rh + 1) {
                /*
                 * This node is left-heavy. Rotate right to rebalance.
                 *
                 *           N                L
                 *          / \              / \
                 *         L   \            /   N
                 *        / \   \    --->  /   / \
                 *       A   B   \        A   B   \
                 *                \                \
                 *                 R                R
                 */
                Node* lchild = node->left_child;
                unsigned int llh = get_node_height(lchild->left_child);
                unsigned int lrh = get_node_height(lchild->right_child);
                if (llh < lrh) {
                    // Double rotation.
                    rotate_left(lchild);
                }
                node = rotate_right(node);
            } else if (lh + 1 < rh) {
                /*
                 * This node is right-heavy. Rotate left to rebalance.
                 *
                 *          N                    R
                 *         / \                  / \
                 *        /   R                N   \
                 *       /   / \     --->     / \   \
                 *      /   A   B            /   A   B
                 *     /                    /
                 *    L                    L
                 */
                Node* rchild = node->right_child;
                unsigned int rlh = get_node_height(rchild->left_child);
                unsigned int rrh = get_node_height(rchild->right_child);
                if (rlh > rrh) {
                    // Double rotation.
                    rotate_right(rchild);
                }
                node = rotate_left(node);
            } else {
                // No rotation, but must still repair the node.
                repair_node(node);
            }

            if (! node->parent) {
                // Reached root.
                m_root = node;
                break;
            }

            // Continue rebalancing at the parent.
            node = node->parent;
        }
    }

    /**
     * Join a left subtree, middle node and right subtree together.
     *
     * The left subtree is higher than the right subtree
     * (by more than one level).
     *
     * @return Root node of the joined tree.
     */
    Node* join_right(Node* ltree, Node* node, Node* rtree)
    {
        assert(ltree);
        const unsigned int rh = get_node_height(rtree);
        assert(ltree->height > rh + 1);

        /*
         * Descend down the right spine of "ltree".
         * Stop at a node with compatible height, then insert "node"
         * and attach "rtree".
         *
         *    ltree
         *    /   \
         *         X
         *        / \
         *           X   <-- cur
         *          / \
         *             node
         *            /    \
         *           X      rtree
         */

        // Descend to a point with compatible height.
        Node* cur = ltree;
        while (cur->right_child && (cur->right_child->height > rh + 1)) {
            cur = cur->right_child;
        }

        // Insert "node" and "rtree".
        node->left_child = cur->right_child;
        node->right_child = rtree;
        if (node->left_child) {
            node->left_child->parent = node;
        }
        if (rtree) {
            rtree->parent = node;
        }
        cur->right_child = node;
        node->parent = cur;

        // A double rotation may be necessary.
        if ((! cur->left_child) || (cur->left_child->height <= rh)) {
            rotate_right(node);
            cur = rotate_left(cur);
        } else {
            repair_node(node);
            repair_node(cur);
        }

        // Ascend from "cur" to the root of the tree; repair and rebalance.
        while (cur->parent) {
            cur = cur->parent;
            assert(cur->left_child);
            assert(cur->right_child);

            if (cur->left_child->height + 1 < cur->right_child->height) {
                // Continue from the new subtree root, as join_left() does.
                cur = rotate_left(cur);
            } else {
                repair_node(cur);
            }
        }

        return cur;
    }

    /**
     * Join a left subtree, middle node and right subtree together.
     *
     * The right subtree is higher than the left subtree
     * (by more than one level).
     *
     * @return Root node of the joined tree.
     */
    Node* join_left(Node* ltree, Node* node, Node* rtree)
    {
        assert(rtree);
        const unsigned int lh = get_node_height(ltree);
        assert(lh + 1 < rtree->height);

        /*
         * Descend down the left spine of "rtree".
         * Stop at a node with compatible height, then insert "node"
         * and attach "ltree".
         *
         *                rtree
         *                /   \
         *               X
         *              / \
         *   cur -->   X
         *            / \
         *        node
         *       /    \
         *  ltree      X
         */

        // Descend to a point with compatible height.
        Node* cur = rtree;
        while (cur->left_child && (cur->left_child->height > lh + 1)) {
            cur = cur->left_child;
        }

        // Insert "node" and "ltree".
        node->left_child = ltree;
        node->right_child = cur->left_child;
        if (ltree) {
            ltree->parent = node;
        }
        if (node->right_child) {
            node->right_child->parent = node;
        }
        cur->left_child = node;
        node->parent = cur;

        // A double rotation may be necessary.
        if ((! cur->right_child) || (cur->right_child->height <= lh)) {
            rotate_left(node);
            cur = rotate_right(cur);
        } else {
            repair_node(node);
            repair_node(cur);
        }

        // Ascend from "cur" to the root of the tree; repair and rebalance.
        while (cur->parent) {
            cur = cur->parent;
            assert(cur->left_child);
            assert(cur->right_child);

            if (cur->left_child->height > cur->right_child->height + 1) {
                cur = rotate_right(cur);
            } else {
                repair_node(cur);
            }
        }

        return cur;
    }

    /**
     * Join a left subtree, middle node and right subtree together.
     *
     * The index of all nodes in the left subtree must be less than
     * the index of the middle node. The index of all nodes in
     * the right subtree must be greater than the middle node.
     *
     * The left or right subtree may initially be a child of the middle
     * node; such links will be broken as needed.
     *
     * The left and right subtrees must be consistent, semi-balanced trees.
     * Height and priority of the middle node may initially be inconsistent;
     * this function will repair it.
     *
     * @return Root node of the joined tree.
     */
    Node* join(Node* ltree, Node* node, Node* rtree)
    {
        unsigned int lh = get_node_height(ltree);
        unsigned int rh = get_node_height(rtree);

        if (lh > rh + 1) {
            assert(ltree);
            ltree->parent = nullptr;
            return join_right(ltree, node, rtree);
        } else if (lh + 1 < rh) {
            assert(rtree);
            rtree->parent = nullptr;
            return join_left(ltree, node, rtree);
        } else {
            /*
             * Subtree heights are compatible. Just join them.
             *
             *           node
             *          /    \
             *     ltree      rtree
             *     /   \      /   \
             */
            node->left_child = ltree;
            if (ltree) {
                ltree->parent = node;
            }
            node->right_child = rtree;
            if (rtree) {
                rtree->parent = node;
            }
            repair_node(node);
            return node;
        }
    }

    /** Root node of the queue, or "nullptr" if the queue is empty. */
    Node* m_root;
};


/**
 * Normal min-priority queue based on a binary heap.
 */
template <typename PrioType, typename DataType>
class PriorityQueue
{
public:
    typedef unsigned int IndexType;

    /** Index value of a node that is not contained in any queue. */
    static constexpr IndexType INVALID_INDEX =
        std::numeric_limits<IndexType>::max();

    /**
     * Element in a PriorityQueue.
     *
     * Storage for the element is owned by the caller; the queue only
     * stores pointers to Node instances.
     */
    struct Node
    {
    public:
        Node()
          : index(INVALID_INDEX)
        { }

        ~Node()
        {
            // Check that the node is not contained in any queue.
            assert(index == INVALID_INDEX);
        }

        // Prevent copying.
        Node(const Node&) = delete;
        Node& operator=(const Node&) = delete;

    private:
        IndexType index;    // position in the heap array, or INVALID_INDEX
        PrioType prio;      // priority of this element
        DataType data;      // data attached to this element

        friend class PriorityQueue;
    };

    /** Construct an empty queue. */
    PriorityQueue()
    { }

    ~PriorityQueue()
    {
        // Check that the queue is empty.
        assert(m_heap.empty());
    }

    // Prevent copying.
    PriorityQueue(const PriorityQueue&) = delete;
    PriorityQueue& operator=(const PriorityQueue&) = delete;

    /** Construct a queue and move elements from "other" to the new queue. */
    PriorityQueue(PriorityQueue&& other) noexcept
      : m_heap(std::move(other.m_heap))
    { }

    /**
     * Remove all elements from the queue.
     *
     * Every removed node is reset to the "unused" state.
     *
     * This function takes time O(n).
     */
    void clear()
    {
        for (Node* node : m_heap) {
            node->index = INVALID_INDEX;
        }
        m_heap.clear();
    }

    /** Return true if the queue is empty. */
    bool empty() const
    {
        return m_heap.empty();
    }

    /**
     * Return the minimum-priority element as tuple (prio, data).
     *
     * The queue must be non-empty.
     *
     * This function takes time O(1).
     */
    std::pair<PrioType, DataType> find_min() const
    {
        assert(! m_heap.empty());
        Node* top = m_heap.front();
        return std::make_pair(top->prio, top->data);
    }

    /**
     * Insert a new element into the queue.
     *
     * The specified Node instance must currently be unused. The queue
     * stores a pointer to it; the caller must keep the instance valid
     * for as long as it is contained in the queue.
     *
     * This function takes time O(log(n)).
     */
    void insert(Node* node, PrioType prio, const DataType& data)
    {
        assert(node->index == INVALID_INDEX);
        assert(m_heap.size() < INVALID_INDEX);

        node->index = static_cast<IndexType>(m_heap.size());
        node->prio = prio;
        node->data = data;

        m_heap.push_back(node);
        sift_up(node->index);
    }

    /**
     * Update priority and/or data of an existing node.
     *
     * The node must be contained in this queue. The priority may be
     * either decreased or increased.
     *
     * This function takes time O(log(n)).
     */
    void update(Node* node, PrioType prio, const DataType& data)
    {
        IndexType index = node->index;
        assert(index != INVALID_INDEX);
        assert(m_heap[index] == node);

        PrioType prev_prio = node->prio;
        node->prio = prio;
        node->data = data;

        if (prio < prev_prio) {
            sift_up(index);
        } else if (prio > prev_prio) {
            sift_down(index);
        }
    }

    /**
     * Remove the specified element from the queue.
     *
     * The node must be contained in this queue.
     *
     * This function takes time O(log(n)).
     */
    void remove(Node* node)
    {
        IndexType index = node->index;
        assert(index != INVALID_INDEX);
        assert(m_heap[index] == node);

        node->index = INVALID_INDEX;

        // Take the last element out of the array and, unless it is the
        // removed element itself, put it in the vacated position.
        Node* move_node = m_heap.back();
        m_heap.pop_back();

        if (index < m_heap.size()) {
            m_heap[index] = move_node;
            move_node->index = index;
            if (move_node->prio < node->prio) {
                sift_up(index);
            } else if (move_node->prio > node->prio) {
                sift_down(index);
            }
        }
    }

private:
    /** Repair the heap along an ascending path to the root. */
    void sift_up(IndexType index)
    {
        Node* node = m_heap[index];
        PrioType prio = node->prio;

        // Shift parents with higher priority value down by one level.
        while (index > 0) {
            IndexType next_index = (index - 1) / 2;
            Node* next_node = m_heap[next_index];
            if (next_node->prio <= prio) {
                break;
            }
            m_heap[index] = next_node;
            next_node->index = index;
            index = next_index;
        }

        node->index = index;
        m_heap[index] = node;
    }

    /** Repair the heap along a descending path. */
    void sift_down(IndexType index)
    {
        Node* node = m_heap[index];
        PrioType prio = node->prio;

        const IndexType num_elem = static_cast<IndexType>(m_heap.size());

        // Position "index" has at least one child if and only if
        // 2 * index + 1 < num_elem, which is equivalent to
        // index < num_elem / 2. Using (num_elem - 1) / 2 as the bound
        // would skip the last internal node whenever num_elem is even.
        while (index < num_elem / 2) {
            // Select the child with the lowest priority value.
            IndexType next_index = 2 * index + 1;
            Node* next_node = m_heap[next_index];

            if (next_index + 1 < num_elem) {
                Node* tmp_node = m_heap[next_index + 1];
                if (tmp_node->prio <= next_node->prio) {
                    ++next_index;
                    next_node = tmp_node;
                }
            }

            if (next_node->prio >= prio) {
                break;
            }

            // Shift the child up by one level.
            m_heap[index] = next_node;
            next_node->index = index;
            index = next_index;
        }

        m_heap[index] = node;
        node->index = index;
    }

    /** Array of nodes, arranged as a binary min-heap. */
    std::vector<Node*> m_heap;
};