// Source: maximum-weight-matching/cpp/datastruct.h (C++, 1120 lines, 31 KiB)

/*
* Data structures for matching.
*/
#include <algorithm>
#include <cassert>
#include <cstddef>
#include <limits>
#include <utility>
#include <vector>
/**
* Element in a disjoint set.
*
* Every instance of DisjointSetNode is a member of exactly one disjoint set.
* Some instances of DisjointSetNode also represent a set.
* Every set has a "label".
*
* The following operations can be done efficiently:
* - Find the label of the set that contains a given node.
* - Merge two sets.
* - Undo a previous merge step.
*
* Internally, a set is represented by a tree of DisjointSetNode instances
* that are linked through parent pointers. The shape of the tree is
* controlled such that its depth is at most O(log(n)), where "n" is
* the number of elements in the set.
*
* See also https://en.wikipedia.org/wiki/Disjoint-set_data_structure
*/
template <typename LabelType>
class DisjointSetNode
{
public:
    /**
     * Construct a fresh element that forms a singleton set.
     *
     * The new instance plays both roles: it is the element, and it is
     * the representative of the one-element set, carrying "label".
     *
     * This function takes time O(1).
     */
    explicit DisjointSetNode(const LabelType& label)
      : m_label(label),
        m_parent(nullptr),
        m_size(1)
    { }

    // Copying is disabled.
    DisjointSetNode(const DisjointSetNode&) = delete;
    DisjointSetNode& operator=(const DisjointSetNode&) = delete;

    /**
     * Look up the label of the set that this element belongs to.
     *
     * The parent chain is followed up to the instance without a parent;
     * the label stored there is returned.
     *
     * This function takes time O(log(n)).
     */
    LabelType find() const
    {
        const DisjointSetNode* root = this;
        for (const DisjointSetNode* up = m_parent;
             up != nullptr;
             up = up->m_parent) {
            root = up;
        }
        return root->m_label;
    }

    /**
     * Assign a new label to the set represented by this instance.
     *
     * Labels belong to sets, not to elements. This instance must
     * currently represent a set (it must not have a parent); assigning
     * a label to an instance that is merely an element is not allowed.
     *
     * This function takes time O(1).
     */
    void set_label(const LabelType& label)
    {
        assert(m_parent == nullptr);
        m_label = label;
    }

    /**
     * Unite the set represented by this instance with the set
     * represented by "other".
     *
     * Both instances must represent a set (not merely an element) and
     * must be distinct. The labels of both instances are reset to
     * a value-initialized LabelType; the merged set has no meaningful
     * label until "set_label" is called on the returned instance.
     *
     * This function takes time O(1).
     *
     * @return Instance representing the merged set.
     */
    DisjointSetNode* merge(DisjointSetNode* other)
    {
        assert(m_parent == nullptr);
        assert(other->m_parent == nullptr);
        assert(other != this);

        m_label = LabelType{};
        other->m_label = LabelType{};

        // Union by size: the strictly smaller set goes below the other;
        // on a tie, this instance stays on top.
        const bool other_is_larger = (m_size < other->m_size);
        DisjointSetNode* const top = other_is_larger ? other : this;
        DisjointSetNode* const below = other_is_larger ? this : other;

        below->m_parent = top;
        top->m_size += below->m_size;
        return top;
    }

    /**
     * Revert an earlier "merge" step.
     *
     * The subset rooted at this instance is cut loose from the set it
     * was merged into, and becomes an independent set again.
     *
     * The following conditions must hold:
     *  - At some point in the past, there was a group of K instances
     *    of DisjointSetNode (s_1, s_2, ..., s_k) that each represented
     *    a set.
     *  - These sets were later merged through (K-1) merge steps.
     *  - Since that point in time, no other merges into the same set
     *    occurred, disregarding any merges that have already been undone.
     *
     * The group of (K-1) merge steps is undone by calling "detach" on
     * each of the K instances, in any order. Afterwards every one of
     * these instances again represents the set it represented originally.
     * The original labels are lost, which is why each "detach" call
     * takes the label to re-assign.
     *
     * This function takes time O(log(n)).
     */
    void detach(const LabelType& label)
    {
        // Every ancestor loses the elements of the departing subset.
        for (DisjointSetNode* ancestor = m_parent;
             ancestor != nullptr;
             ancestor = ancestor->m_parent) {
            ancestor->m_size -= m_size;
        }

        m_parent = nullptr;
        m_label = label;
    }

private:
    /** Label of the set; only meaningful while this instance has no parent. */
    LabelType m_label;

    /** Parent in the set tree, or "nullptr" if this instance represents a set. */
    DisjointSetNode* m_parent;

    /** Number of elements in the subtree rooted at this instance. */
    std::size_t m_size;
};
/**
* Priority queue that can be split into separate queues.
*
* Each element in the queue has an index, a priority and attached data.
* The following operations can be done efficiently:
* - Find the element with lowest priority.
* - Insert an element with given index, priority and data.
* - Update the priority of an existing element.
* - Split the queue such that all elements with index below a specified
* threshold end up in one queue and all remaining elements in
* a separate queue.
*
* The implementation is essentially an AVL tree, with minimum-priority
* tracking added to it. The implementation follows the outlines of
* https://en.wikipedia.org/wiki/Avl_tree
*/
template <typename PrioType, typename DataType>
class SplittablePriorityQueue
{
public:
    typedef unsigned int IndexType;

    /**
     * Element in a queue.
     *
     * Node instances are owned by the caller; the queue only links them
     * together. A node with height 0 is not contained in any queue.
     */
    struct Node
    {
    public:
        Node()
          : parent(nullptr), left_child(nullptr), right_child(nullptr),
            height(0)
        { }

        ~Node()
        {
            // Check that the node is not contained in any queue.
            assert(height == 0);
        }

        // Prevent copying.
        Node(const Node&) = delete;
        Node& operator=(const Node&) = delete;

    private:
        /** Search key of this node within the tree. */
        IndexType index;

        /** Priority and data of this element itself. */
        PrioType prio;
        DataType data;

        /** Lowest priority (and its data) within the subtree rooted here. */
        PrioType best_prio;
        DataType best_data;

        Node* parent;
        Node* left_child;
        Node* right_child;

        /** Height of the subtree rooted here; 0 if not in any queue. */
        unsigned int height;

        friend class SplittablePriorityQueue;
    };

    /** Construct an empty queue. */
    SplittablePriorityQueue()
      : m_root(nullptr)
    { }

    ~SplittablePriorityQueue()
    {
        // Check that the queue is empty.
        assert(! m_root);
    }

    // Prevent copying.
    SplittablePriorityQueue(const SplittablePriorityQueue&) = delete;
    SplittablePriorityQueue& operator=(const SplittablePriorityQueue&) = delete;

    /** Construct a queue and move elements from "other" to the new queue. */
    SplittablePriorityQueue(SplittablePriorityQueue&& other) noexcept
      : m_root(other.m_root)
    {
        other.m_root = nullptr;
    }

    /**
     * Move all elements from "other" to this queue.
     *
     * This queue must be empty before the assignment.
     */
    SplittablePriorityQueue& operator=(SplittablePriorityQueue&& other) noexcept
    {
        assert(! m_root);
        std::swap(m_root, other.m_root);
        return *this;
    }

    /**
     * Remove all elements from the queue.
     *
     * Every node is unlinked and its height is reset to 0, which marks
     * it as unused.
     *
     * This function takes time O(n).
     */
    void clear()
    {
        Node* node = m_root;
        m_root = nullptr;

        // Walk the tree without recursion: descend into a child while
        // cutting the link to it; once a node has no children left,
        // release it and step back to its parent.
        while (node) {
            Node* prev = node;
            if (node->left_child) {
                node = node->left_child;
                prev->left_child = nullptr;
            } else if (node->right_child) {
                node = node->right_child;
                prev->right_child = nullptr;
            } else {
                node = node->parent;
                prev->parent = nullptr;
                prev->height = 0;
            }
        }
    }

    /** Return true if the queue is empty. */
    bool empty() const
    {
        return (! m_root);
    }

    /**
     * Return the minimum-priority element as tuple (prio, data).
     *
     * The queue must be non-empty.
     *
     * This function takes time O(1).
     */
    std::pair<PrioType, DataType> find_min() const
    {
        assert(m_root);
        return std::make_pair(m_root->best_prio, m_root->best_data);
    }

    /**
     * Initialize the specified Node instance and insert it into the queue.
     *
     * The specified Node instance must currently be unused; not contained
     * in any queue. The index must differ from the index of every element
     * already in the queue.
     *
     * The queue stores a pointer to the Node instance. The caller must
     * ensure that the instance remains valid for as long as it is contained
     * in any queue.
     *
     * This function takes time O(log(n)).
     */
    void insert(Node* node,
                IndexType index,
                PrioType prio,
                const DataType& data)
    {
        assert(node->height == 0);

        node->index = index;
        node->prio = prio;
        node->data = data;
        node->best_prio = prio;
        node->best_data = data;
        node->parent = nullptr;
        node->left_child = nullptr;
        node->right_child = nullptr;
        node->height = 1;

        if (! m_root) {
            m_root = node;
            return;
        }

        // Ordinary binary-search-tree descent to the free leaf position.
        Node* p = m_root;
        while (true) {
            assert(index != p->index);
            if (index < p->index) {
                if (! p->left_child) {
                    p->left_child = node;
                    break;
                }
                p = p->left_child;
            } else {
                if (! p->right_child) {
                    p->right_child = node;
                    break;
                }
                p = p->right_child;
            }
        }

        node->parent = p;
        rebalance_up(node);
    }

    /**
     * Decrease priority and update data of an existing element.
     *
     * Nothing is done if the specified priority is greater or equal
     * to the current priority of the existing element.
     *
     * The node must be contained in this queue.
     *
     * This function takes time O(log(n)).
     */
    void update(Node* node, PrioType prio, const DataType& data)
    {
        assert(node->height != 0);

        // Do nothing unless the new priority is strictly lower.
        if (prio >= node->prio) {
            return;
        }

        // Update the node itself.
        node->prio = prio;
        node->data = data;

        // Update the best-prio information in the node and its ancestors.
        while (true) {
            repair_node_best_prio(node);
            if (! node->parent) {
                break;
            }
            node = node->parent;
        }

        // Check that the node belongs to this queue.
        assert(node == m_root);
    }

    /**
     * Split the queue by index.
     *
     * All elements with index less than "threshold" remain in the queue.
     * All elements with index greater or equal to "threshold" are moved
     * to a new queue. That new queue is returned.
     *
     * This function takes time O(log(n)).
     */
    SplittablePriorityQueue split(IndexType threshold)
    {
        SplittablePriorityQueue new_queue;

        // Special case for an empty queue.
        if (! m_root) {
            return new_queue;
        }

        // Descend down the tree along a path as close as possible to the
        // threshold. Stop when we reach a subtree that is entirely on
        // one side of the threshold.
        Node* node = m_root;
        while (true) {
            if (node->index < threshold) {
                /*
                 * This node is below the threshold. Descend right.
                 *
                 *         (<T)
                 *         /  \
                 *        /    \
                 *     (<T)    (?)  <--- go here
                 */
                if (! node->right_child) {
                    break;
                }
                node = node->right_child;
            } else if (node->index > threshold) {
                /*
                 * This node is above the threshold. Descend left.
                 *
                 *                (>T)
                 *                /  \
                 *               /    \
                 * go here --> (?)    (>T)
                 */
                if (! node->left_child) {
                    break;
                }
                node = node->left_child;
            } else {
                /*
                 * This node is exactly at the threshold.
                 * Descend left if possible, then stop.
                 *
                 *         (=T)
                 *         /  \
                 *        /    \
                 *     (<T)    (>T)
                 */
                if (node->left_child) {
                    node = node->left_child;
                }
                break;
            }
        }

        // Build "left_tree" with nodes < threshold,
        // and "right_tree" with nodes >= threshold.
        Node* left_tree = nullptr;
        Node* right_tree = nullptr;

        // Initially, either the left or right tree consists of only
        // the subtree where the descending path ended, and the other
        // tree is empty.
        if (node->index < threshold) {
            left_tree = node;
        } else {
            right_tree = node;
        }

        // Detach node from its parent.
        Node* parent = node->parent;
        detach_node(node);

        // Retrace up the tree.
        // On the way up, join each node to the left or right subtree.
        while (parent) {

            // Move up to parent.
            node = parent;
            parent = node->parent;
            detach_node(node);

            if (node->index < threshold) {
                /*
                 * Join node (with its descendants) to the left tree.
                 * The child on the descent path (the right child) has
                 * already been detached.
                 *
                 *        N      <--- new left_tree
                 *       / .
                 *      X   .
                 *     / \  (left_tree)
                 */
                assert(! node->right_child);
                left_tree = join(node->left_child, node, left_tree);
            } else {
                /*
                 * Join node (with its descendants) to the right tree.
                 * The child on the descent path (the left child) has
                 * already been detached.
                 *
                 *              N      <--- new right_tree
                 *             . \
                 *            .   X
                 *  (right_tree) / \
                 */
                assert(! node->left_child);
                right_tree = join(right_tree, node, node->right_child);
            }
        }

        // Keep the left tree in this instance.
        m_root = left_tree;

        // Return the right tree as a new queue.
        new_queue.m_root = right_tree;
        return new_queue;
    }

private:
    /** Return node height, or 0 if node == nullptr. */
    static unsigned int get_node_height(const Node* node)
    {
        return node ? node->height : 0;
    }

    /** Detach a node from its parent. Nothing is done for a parentless node. */
    static void detach_node(Node* node)
    {
        Node* parent = node->parent;
        if (parent) {
            node->parent = nullptr;
            if (parent->left_child == node) {
                parent->left_child = nullptr;
            } else if (parent->right_child == node) {
                parent->right_child = nullptr;
            }
        }
    }

    /**
     * Repair best-priority information in the specified node.
     *
     * The best-priority information of both children must be consistent.
     * On equal priorities the node itself is preferred over its children,
     * and the left child over the right child.
     */
    static void repair_node_best_prio(Node* node)
    {
        PrioType best_prio = node->prio;
        DataType best_data = node->data;

        Node* lchild = node->left_child;
        if (lchild && (lchild->best_prio < best_prio)) {
            best_prio = lchild->best_prio;
            best_data = lchild->best_data;
        }

        Node* rchild = node->right_child;
        if (rchild && (rchild->best_prio < best_prio)) {
            best_prio = rchild->best_prio;
            best_data = rchild->best_data;
        }

        node->best_prio = best_prio;
        node->best_data = best_data;
    }

    /**
     * Repair the height and best-priority information of a node
     * after modifying its children.
     *
     * After repairing a node, it is typically necessary to also repair
     * its ancestors.
     */
    static void repair_node(Node* node)
    {
        Node* lchild = node->left_child;
        Node* rchild = node->right_child;

        // Repair node height.
        node->height = 1 + std::max(get_node_height(lchild),
                                    get_node_height(rchild));

        // Repair best-priority.
        repair_node_best_prio(node);
    }

    /**
     * Rotate the subtree to the left and return the new root of the subtree.
     *
     * The node must have a right child. The link from the parent of the
     * subtree (if any) is redirected to the new subtree root. Height and
     * best-priority are repaired for the two nodes that changed position,
     * but not for their ancestors.
     */
    static Node* rotate_left(Node* node)
    {
        /*
         *     N                C
         *    / \              / \
         *   A   C    --->    N   D
         *      / \          / \
         *     B   D        A   B
         */
        Node* parent = node->parent;
        Node* new_top = node->right_child;
        assert(new_top);

        Node* nb = new_top->left_child;
        node->right_child = nb;
        if (nb) {
            nb->parent = node;
        }

        new_top->left_child = node;
        node->parent = new_top;

        new_top->parent = parent;
        if (parent) {
            if (parent->left_child == node) {
                parent->left_child = new_top;
            } else if (parent->right_child == node) {
                parent->right_child = new_top;
            } else {
                assert(false);
            }
        }

        // Repair bottom-up: "node" is now a child of "new_top".
        repair_node(node);
        repair_node(new_top);

        return new_top;
    }

    /**
     * Rotate the subtree to the right and return the new root of the subtree.
     *
     * Mirror image of "rotate_left"; the node must have a left child.
     */
    static Node* rotate_right(Node* node)
    {
        /*
         *       N            B
         *      / \          / \
         *     B   D  --->  A   N
         *    / \              / \
         *   A   C            C   D
         */
        Node* parent = node->parent;
        Node* new_top = node->left_child;
        assert(new_top);

        Node* nc = new_top->right_child;
        node->left_child = nc;
        if (nc) {
            nc->parent = node;
        }

        new_top->right_child = node;
        node->parent = new_top;

        new_top->parent = parent;
        if (parent) {
            if (parent->left_child == node) {
                parent->left_child = new_top;
            } else if (parent->right_child == node) {
                parent->right_child = new_top;
            } else {
                assert(false);
            }
        }

        // Repair bottom-up: "node" is now a child of "new_top".
        repair_node(node);
        repair_node(new_top);

        return new_top;
    }

    /**
     * Repair and rebalance the specified node and its ancestors.
     *
     * The walk continues up to the root, which is stored in "m_root".
     */
    void rebalance_up(Node* node)
    {
        while (true) {
            unsigned int lh = get_node_height(node->left_child);
            unsigned int rh = get_node_height(node->right_child);

            if (lh > rh + 1) {
                /*
                 * This node is left-heavy. Rotate right to rebalance.
                 *
                 *          N                L
                 *         / \              / \
                 *        L   \            /   N
                 *       / \   \   --->   /   / \
                 *      A   B   \        A   B   \
                 *               \                \
                 *                R                R
                 */
                Node* lchild = node->left_child;
                unsigned int llh = get_node_height(lchild->left_child);
                unsigned int lrh = get_node_height(lchild->right_child);
                if (llh < lrh) {
                    // Double rotation.
                    rotate_left(lchild);
                }
                node = rotate_right(node);
            } else if (lh + 1 < rh) {
                /*
                 * This node is right-heavy. Rotate left to rebalance.
                 *
                 *          N                    R
                 *         / \                  / \
                 *        /   R                N   \
                 *       /   / \   --->       / \   \
                 *      /   A   B            /   A   B
                 *     /                    /
                 *    L                    L
                 */
                Node* rchild = node->right_child;
                unsigned int rlh = get_node_height(rchild->left_child);
                unsigned int rrh = get_node_height(rchild->right_child);
                if (rlh > rrh) {
                    // Double rotation.
                    rotate_right(rchild);
                }
                node = rotate_left(node);
            } else {
                // No rotation, but must still repair the node.
                repair_node(node);
            }

            if (! node->parent) {
                // Reached root.
                m_root = node;
                break;
            }

            // Continue rebalancing at the parent.
            node = node->parent;
        }
    }

    /**
     * Join a left subtree, middle node and right subtree together.
     *
     * The left subtree is higher than the right subtree
     * (by more than 1 level). "ltree" must be a root (no parent).
     *
     * @return Root node of the joined tree.
     */
    Node* join_right(Node* ltree, Node* node, Node* rtree)
    {
        assert(ltree);
        unsigned int rh = get_node_height(rtree);
        assert(ltree->height > rh + 1);

        /*
         * Descend down the right spine of "ltree".
         * Stop at a node with compatible height, then insert "node"
         * and attach "rtree".
         *
         *    ltree
         *     / \
         *        X
         *       / \
         *          X   <-- cur
         *         / \
         *           node
         *           /  \
         *          X   rtree
         */

        // Descend to a point with compatible height.
        Node* cur = ltree;
        while (cur->right_child && (cur->right_child->height > rh + 1)) {
            cur = cur->right_child;
        }

        // Insert "node" and "rtree".
        node->left_child = cur->right_child;
        node->right_child = rtree;
        if (node->left_child) {
            node->left_child->parent = node;
        }
        if (rtree) {
            rtree->parent = node;
        }
        cur->right_child = node;
        node->parent = cur;

        // A double rotation may be necessary.
        if ((! cur->left_child) || (cur->left_child->height <= rh)) {
            node = rotate_right(node);
            cur = rotate_left(cur);
        } else {
            repair_node(node);
            repair_node(cur);
        }

        // Ascend from "cur" to the root of the tree; repair and rebalance.
        while (cur->parent) {
            cur = cur->parent;
            assert(cur->left_child);
            assert(cur->right_child);

            if (cur->left_child->height + 1 < cur->right_child->height) {
                // Continue from the new subtree root, so that the node
                // that was just rotated down is not visited again.
                cur = rotate_left(cur);
            } else {
                repair_node(cur);
            }
        }

        return cur;
    }

    /**
     * Join a left subtree, middle node and right subtree together.
     *
     * The right subtree is higher than the left subtree
     * (by more than 1 level). "rtree" must be a root (no parent).
     *
     * @return Root node of the joined tree.
     */
    Node* join_left(Node* ltree, Node* node, Node* rtree)
    {
        assert(rtree);
        unsigned int lh = get_node_height(ltree);
        assert(lh + 1 < rtree->height);

        /*
         * Descend down the left spine of "rtree".
         * Stop at a node with compatible height, then insert "node"
         * and attach "ltree".
         *
         *            rtree
         *             / \
         *            X
         *           / \
         *  cur --> X
         *         / \
         *      node
         *      /  \
         *  ltree   X
         */

        // Descend to a point with compatible height.
        Node* cur = rtree;
        while (cur->left_child && (cur->left_child->height > lh + 1)) {
            cur = cur->left_child;
        }

        // Insert "node" and "ltree".
        node->left_child = ltree;
        node->right_child = cur->left_child;
        if (ltree) {
            ltree->parent = node;
        }
        if (node->right_child) {
            node->right_child->parent = node;
        }
        cur->left_child = node;
        node->parent = cur;

        // A double rotation may be necessary.
        if ((! cur->right_child) || (cur->right_child->height <= lh)) {
            node = rotate_left(node);
            cur = rotate_right(cur);
        } else {
            repair_node(node);
            repair_node(cur);
        }

        // Ascend from "cur" to the root of the tree; repair and rebalance.
        while (cur->parent) {
            cur = cur->parent;
            assert(cur->left_child);
            assert(cur->right_child);

            if (cur->left_child->height > cur->right_child->height + 1) {
                cur = rotate_right(cur);
            } else {
                repair_node(cur);
            }
        }

        return cur;
    }

    /**
     * Join a left subtree, middle node and right subtree together.
     *
     * The index of all nodes in the left subtree must be less than
     * the index of the middle node. The index of all nodes in
     * the right subtree must be greater than the middle node.
     *
     * The left or right subtree may initially be a child of the middle
     * node; such links will be broken as needed.
     *
     * The left and right subtrees must be consistent, semi-balanced trees.
     * Height and priority of the middle node may initially be inconsistent;
     * this function will repair it.
     *
     * @return Root node of the joined tree.
     */
    Node* join(Node* ltree, Node* node, Node* rtree)
    {
        unsigned int lh = get_node_height(ltree);
        unsigned int rh = get_node_height(rtree);

        if (lh > rh + 1) {
            // "ltree" becomes the root of the result.
            assert(ltree);
            ltree->parent = nullptr;
            return join_right(ltree, node, rtree);
        } else if (lh + 1 < rh) {
            // "rtree" becomes the root of the result.
            assert(rtree);
            rtree->parent = nullptr;
            return join_left(ltree, node, rtree);
        } else {
            /*
             * Subtree heights are compatible. Just join them.
             *
             *            node
             *           /    \
             *      ltree      rtree
             *       /  \      /  \
             */
            node->left_child = ltree;
            if (ltree) {
                ltree->parent = node;
            }
            node->right_child = rtree;
            if (rtree) {
                rtree->parent = node;
            }
            repair_node(node);
            return node;
        }
    }

    /** Root node of the queue, or "nullptr" if the queue is empty. */
    Node* m_root;
};
/**
* Normal min-priority queue based on a binary heap.
*/
template <typename PrioType, typename DataType>
class PriorityQueue
{
public:
    typedef unsigned int IndexType;

    /** Marker stored in "Node::index" while the node is not in any queue. */
    static constexpr IndexType INVALID_INDEX = std::numeric_limits<IndexType>::max();

    /**
     * Element in a PriorityQueue.
     *
     * Node instances are owned by the caller; the queue only stores
     * pointers to them.
     */
    struct Node
    {
    public:
        Node()
          : index(INVALID_INDEX)
        { }

        ~Node()
        {
            // Check that the node is not contained in any queue.
            assert(index == INVALID_INDEX);
        }

        // Prevent copying.
        Node(const Node&) = delete;
        Node& operator=(const Node&) = delete;

    private:
        /** Position of this node in the heap array, or INVALID_INDEX. */
        IndexType index;
        PrioType prio;
        DataType data;

        friend class PriorityQueue;
    };

    /** Construct an empty queue. */
    PriorityQueue()
    { }

    ~PriorityQueue()
    {
        // Check that the queue is empty.
        assert(m_heap.empty());
    }

    // Prevent copying.
    PriorityQueue(const PriorityQueue&) = delete;
    PriorityQueue& operator=(const PriorityQueue&) = delete;

    /** Construct a queue and move elements from "other" to the new queue. */
    PriorityQueue(PriorityQueue&& other) noexcept
      : m_heap(std::move(other.m_heap))
    { }

    /**
     * Remove all elements from the queue.
     *
     * Every node is marked as unused.
     *
     * This function takes time O(n).
     */
    void clear()
    {
        for (Node* node : m_heap) {
            node->index = INVALID_INDEX;
        }
        m_heap.clear();
    }

    /** Return true if the queue is empty. */
    bool empty() const
    {
        return m_heap.empty();
    }

    /**
     * Return the minimum-priority element as tuple (prio, data).
     *
     * The queue must be non-empty.
     *
     * This function takes time O(1).
     */
    std::pair<PrioType, DataType> find_min() const
    {
        assert(! m_heap.empty());
        Node* top = m_heap.front();
        return std::make_pair(top->prio, top->data);
    }

    /**
     * Insert a new element into the queue.
     *
     * The specified Node instance must currently be unused. The queue
     * stores a pointer to it; the caller must keep the instance valid
     * for as long as it is contained in the queue.
     *
     * This function takes time O(log(n)).
     */
    void insert(Node* node, PrioType prio, const DataType& data)
    {
        assert(node->index == INVALID_INDEX);

        // Append at the end of the array, then restore the heap order.
        node->index = static_cast<IndexType>(m_heap.size());
        node->prio = prio;
        node->data = data;

        m_heap.push_back(node);
        sift_up(node->index);
    }

    /**
     * Update priority and/or data of an existing node.
     *
     * The priority may be raised as well as lowered. The node must be
     * contained in this queue.
     *
     * This function takes time O(log(n)).
     */
    void update(Node* node, PrioType prio, const DataType& data)
    {
        IndexType index = node->index;
        assert(index != INVALID_INDEX);
        assert(m_heap[index] == node);

        PrioType prev_prio = node->prio;
        node->prio = prio;
        node->data = data;

        if (prio < prev_prio) {
            sift_up(index);
        } else if (prio > prev_prio) {
            sift_down(index);
        }
    }

    /**
     * Remove the specified element from the queue.
     *
     * The node must be contained in this queue. Afterwards the node
     * is marked as unused.
     *
     * This function takes time O(log(n)).
     */
    void remove(Node* node)
    {
        IndexType index = node->index;
        assert(index != INVALID_INDEX);
        assert(m_heap[index] == node);

        node->index = INVALID_INDEX;

        // Fill the hole with the last element of the array, unless
        // the removed node was itself the last element.
        Node* move_node = m_heap.back();
        m_heap.pop_back();

        if (index < m_heap.size()) {
            m_heap[index] = move_node;
            move_node->index = index;

            // The replacement may be out of order in either direction,
            // depending on how it compares to the element it replaces.
            if (move_node->prio < node->prio) {
                sift_up(index);
            } else if (move_node->prio > node->prio) {
                sift_down(index);
            }
        }
    }

private:
    /** Repair the heap along an ascending path to the root. */
    void sift_up(IndexType index)
    {
        Node* node = m_heap[index];
        PrioType prio = node->prio;

        // Shift parents with higher priority down into the hole;
        // the moving node is written once at its final position.
        while (index > 0) {
            IndexType next_index = (index - 1) / 2;
            Node* next_node = m_heap[next_index];
            if (next_node->prio <= prio) {
                break;
            }
            m_heap[index] = next_node;
            next_node->index = index;
            index = next_index;
        }

        node->index = index;
        m_heap[index] = node;
    }

    /** Repair the heap along a descending path. */
    void sift_down(IndexType index)
    {
        Node* node = m_heap[index];
        PrioType prio = node->prio;

        const IndexType num_elem = static_cast<IndexType>(m_heap.size());

        // A position has at least one child exactly when
        // (2 * index + 1 < num_elem), which is (index < num_elem / 2).
        // A bound of ((num_elem - 1) / 2) would skip the last parent
        // whenever the number of elements is even.
        while (index < num_elem / 2) {
            // Select the child with the lowest priority.
            IndexType next_index = 2 * index + 1;
            Node* next_node = m_heap[next_index];

            if (next_index + 1 < num_elem) {
                Node* tmp_node = m_heap[next_index + 1];
                if (tmp_node->prio <= next_node->prio) {
                    ++next_index;
                    next_node = tmp_node;
                }
            }

            if (next_node->prio >= prio) {
                break;
            }

            // Shift the child up into the hole and continue below it.
            m_heap[index] = next_node;
            next_node->index = index;
            index = next_index;
        }

        m_heap[index] = node;
        node->index = index;
    }

    /** Array of nodes, arranged as a binary min-heap. */
    std::vector<Node*> m_heap;
};