533 lines
15 KiB
Python
533 lines
15 KiB
Python
"""Unit tests for data structures."""
|
|
|
|
import random
|
|
import unittest
|
|
|
|
from mwmatching.datastruct import UnionFindQueue, PriorityQueue
|
|
|
|
|
|
class TestUnionFindQueue(unittest.TestCase):
    """Test UnionFindQueue."""

    @staticmethod
    def _expected_min(prios, members):
        """Return ``(index, prio)`` of the lowest-priority node of a queue.

        Parameters:
            prios: List of current priorities, indexed by node number.
            members: Non-empty collection of node numbers in the queue.

        The index is chosen among ``members`` only. Looking the minimum
        priority up with ``prios.index()`` would search the nodes of every
        queue and could return a node that is not in this queue at all
        if two nodes happen to share a priority.
        """
        best = min(members, key=prios.__getitem__)
        return (best, prios[best])

    def _check_tree(self, queue):
        """Check tree balancing rules and priority info.

        Walks every node reachable from ``queue.tree`` and asserts that:

        - the root has no parent and is owned by ``queue``;
        - each child points back to its parent;
        - only the root carries an owner;
        - ``height`` equals one plus the height of the tallest child;
        - the heights of the two subtrees differ by at most one;
        - ``min_node`` refers to a node with the lowest priority among the
          node itself and the ``min_node`` of each child.
        """

        self.assertIsNone(queue.tree.parent)
        self.assertIs(queue.tree.owner, queue)

        # Iterative depth-first walk using an explicit stack.
        nodes = [queue.tree]
        while nodes:

            node = nodes.pop()

            if node.left is not None:
                self.assertIs(node.left.parent, node)
                nodes.append(node.left)

            if node.right is not None:
                self.assertIs(node.right.parent, node)
                nodes.append(node.right)

            if node is not queue.tree:
                self.assertIsNone(node.owner)

            # A missing child counts as a subtree of height 0.
            lh = 0 if node.left is None else node.left.height
            rh = 0 if node.right is None else node.right.height
            self.assertEqual(node.height, 1 + max(lh, rh))

            self.assertLessEqual(lh, rh + 1)
            self.assertLessEqual(rh, lh + 1)

            # Collect every node that is an acceptable "min_node" here.
            # Ties are allowed, hence a set of candidates.
            best_node = {node}
            best_prio = node.prio
            for child in (node.left, node.right):
                if child is not None:
                    if child.min_node.prio < best_prio:
                        best_prio = child.min_node.prio
                        best_node = {child.min_node}
                    elif child.min_node.prio == best_prio:
                        best_node.add(child.min_node)

            self.assertEqual(node.min_node.prio, best_prio)
            self.assertIn(node.min_node, best_node)

    def test_single(self):
        """Single element."""
        q = UnionFindQueue("Q")

        # Queries on a queue without elements must fail.
        with self.assertRaises(Exception):
            q.min_prio()

        with self.assertRaises(Exception):
            q.min_elem()

        n = q.insert("a", 4)
        self.assertIsInstance(n, UnionFindQueue.Node)

        self._check_tree(q)

        self.assertEqual(n.find(), "Q")
        self.assertEqual(q.min_prio(), 4)
        self.assertEqual(q.min_elem(), "a")

        # Inserting into a queue that already holds an element must fail.
        with self.assertRaises(Exception):
            q.insert("x", 1)

        n.set_prio(8)
        self._check_tree(q)

        self.assertEqual(n.find(), "Q")
        self.assertEqual(q.min_prio(), 8)
        self.assertEqual(q.min_elem(), "a")

        q.clear()

    def test_simple(self):
        """Simple test, 5 elements."""
        q1 = UnionFindQueue("A")
        n1 = q1.insert("a", 5)

        q2 = UnionFindQueue("B")
        n2 = q2.insert("b", 6)

        q3 = UnionFindQueue("C")
        n3 = q3.insert("c", 7)

        q4 = UnionFindQueue("D")
        n4 = q4.insert("d", 4)

        q5 = UnionFindQueue("E")
        n5 = q5.insert("e", 3)

        # Merge C, D, E into P.
        q345 = UnionFindQueue("P")
        q345.merge([q3, q4, q5])
        self._check_tree(q345)

        self.assertEqual(n1.find(), "A")
        self.assertEqual(n3.find(), "P")
        self.assertEqual(n4.find(), "P")
        self.assertEqual(n5.find(), "P")
        self.assertEqual(q345.min_prio(), 3)
        self.assertEqual(q345.min_elem(), "e")

        # A queue that was merged into another one can not be queried.
        with self.assertRaises(Exception):
            q3.min_prio()

        self._check_tree(q345)
        n5.set_prio(6)
        self._check_tree(q345)

        self.assertEqual(q345.min_prio(), 4)
        self.assertEqual(q345.min_elem(), "d")

        # Merge A, B into Q.
        q12 = UnionFindQueue("Q")
        q12.merge([q1, q2])
        self._check_tree(q12)

        self.assertEqual(n1.find(), "Q")
        self.assertEqual(n2.find(), "Q")
        self.assertEqual(q12.min_prio(), 5)
        self.assertEqual(q12.min_elem(), "a")

        # Merge the two merged queues Q and P into R.
        q12345 = UnionFindQueue("R")
        q12345.merge([q12, q345])
        self._check_tree(q12345)

        self.assertEqual(n1.find(), "R")
        self.assertEqual(n2.find(), "R")
        self.assertEqual(n3.find(), "R")
        self.assertEqual(n4.find(), "R")
        self.assertEqual(n5.find(), "R")
        self.assertEqual(q12345.min_prio(), 4)
        self.assertEqual(q12345.min_elem(), "d")

        n4.set_prio(8)
        self._check_tree(q12345)

        self.assertEqual(q12345.min_prio(), 5)
        self.assertEqual(q12345.min_elem(), "a")

        n3.set_prio(2)
        self._check_tree(q12345)

        self.assertEqual(q12345.min_prio(), 2)
        self.assertEqual(q12345.min_elem(), "c")

        # Splitting R must bring back Q and P, including the priority
        # changes that were made while the nodes were part of R.
        q12345.split()
        self._check_tree(q12)
        self._check_tree(q345)

        self.assertEqual(n1.find(), "Q")
        self.assertEqual(n2.find(), "Q")
        self.assertEqual(n3.find(), "P")
        self.assertEqual(n4.find(), "P")
        self.assertEqual(n5.find(), "P")
        self.assertEqual(q12.min_prio(), 5)
        self.assertEqual(q12.min_elem(), "a")
        self.assertEqual(q345.min_prio(), 2)
        self.assertEqual(q345.min_elem(), "c")

        q12.split()
        self._check_tree(q1)
        self._check_tree(q2)

        q345.split()
        self._check_tree(q3)
        self._check_tree(q4)
        self._check_tree(q5)

        self.assertEqual(n1.find(), "A")
        self.assertEqual(n2.find(), "B")
        self.assertEqual(n3.find(), "C")
        self.assertEqual(n4.find(), "D")
        self.assertEqual(n5.find(), "E")
        self.assertEqual(q3.min_prio(), 2)
        self.assertEqual(q3.min_elem(), "c")

        q1.clear()
        q2.clear()
        q3.clear()
        q4.clear()
        q5.clear()
        q12.clear()
        q345.clear()
        q12345.clear()

    def test_medium(self):
        """Medium test, 14 elements."""

        prios = [3, 8, 6, 2, 9, 4, 6, 8, 1, 5, 9, 4, 7, 8]

        # queues[0:14] are the single-element queues "A" .. "N" and
        # nodes[i] is the node "a" .. "n" that was inserted into queues[i].
        queues = []
        nodes = []
        for i in range(14):
            q = UnionFindQueue(chr(ord("A") + i))
            n = q.insert(chr(ord("a") + i), prios[i])
            queues.append(q)
            nodes.append(n)

        # First-level merges: (queue name, first index, end index).
        # The merged queues are appended as queues[14:18] in this order.
        groups = [("AB", 0, 2), ("CDE", 2, 5), ("FGHI", 5, 9), ("JKLMN", 9, 14)]

        for (name, lo, hi) in groups:
            q = UnionFindQueue(name)
            q.merge(queues[lo:hi])
            queues.append(q)
            self._check_tree(q)
            self.assertEqual(q.min_prio(), min(prios[lo:hi]))

        for (name, lo, hi) in groups:
            for i in range(lo, hi):
                self.assertEqual(nodes[i].find(), name)

        # Second-level merge of the four merged queues.
        q = UnionFindQueue("ALL")
        q.merge(queues[14:18])
        queues.append(q)
        self._check_tree(q)
        self.assertEqual(q.min_prio(), 1)
        self.assertEqual(q.min_elem(), "i")

        for i in range(14):
            self.assertEqual(nodes[i].find(), "ALL")

        # Raise the priority of the current minimum ("i"); the overall
        # minimum then moves to "d" with priority 2.
        prios[8] = 5
        nodes[8].set_prio(prios[8])
        self.assertEqual(q.min_prio(), 2)
        self.assertEqual(q.min_elem(), "d")

        q.split()

        for (name, lo, hi) in groups:
            for i in range(lo, hi):
                self.assertEqual(nodes[i].find(), name)

        # The changed priority must be visible in the restored sub-queues.
        for (k, (name, lo, hi)) in enumerate(groups):
            self.assertEqual(queues[14 + k].min_prio(), min(prios[lo:hi]))

        for q in queues[14:18]:
            self._check_tree(q)
            q.split()

        for i in range(14):
            self._check_tree(queues[i])
            self.assertEqual(nodes[i].find(), chr(ord("A") + i))
            self.assertEqual(queues[i].min_prio(), prios[i])
            self.assertEqual(queues[i].min_elem(), chr(ord("a") + i))

        for q in queues:
            q.clear()

    def test_random(self):
        """Pseudo-random test."""

        # Fixed seed keeps the sequence of operations reproducible.
        rng = random.Random(23456)

        nodes = []                  # nodes[i] is the node holding "n{i}"
        prios = []                  # prios[i] is the current prio of nodes[i]
        queues = {}                 # queue name -> UnionFindQueue
        queue_nodes = {}            # queue name -> set of node indices
        queue_subs = {}             # merged queue name -> names of its parts
        live_queues = set()         # names of queues not merged into another
        live_merged_queues = set()  # live queues that were built by merge()

        for i in range(4000):
            name = f"q{i}"
            q = UnionFindQueue(name)
            p = rng.random()
            n = q.insert(f"n{i}", p)
            nodes.append(n)
            prios.append(p)
            queues[name] = q
            queue_nodes[name] = {i}
            live_queues.add(name)

        for i in range(2000):

            # Change the priority of a few random nodes and verify the
            # minimum of the queue that contains each of them.
            for _ in range(10):
                t = rng.randint(0, len(nodes) - 1)
                name = nodes[t].find()
                self.assertIn(name, live_queues)
                self.assertIn(t, queue_nodes[name])
                p = rng.random()
                prios[t] = p
                nodes[t].set_prio(p)
                (tt, pp) = self._expected_min(prios, queue_nodes[name])
                self.assertEqual(queues[name].min_prio(), pp)
                self.assertEqual(queues[name].min_elem(), f"n{tt}")

            # Merge a random selection of live queues into a new queue.
            # sorted() makes the sampling independent of set ordering.
            k = rng.randint(2, max(2, len(live_queues) // 2 - 400))
            subs = rng.sample(sorted(live_queues), k)

            name = f"Q{i}"
            q = UnionFindQueue(name)
            q.merge([queues[nn] for nn in subs])
            self._check_tree(q)
            queues[name] = q
            queue_nodes[name] = set().union(*(queue_nodes[nn] for nn in subs))
            queue_subs[name] = set(subs)
            live_queues.difference_update(subs)
            live_merged_queues.difference_update(subs)
            live_queues.add(name)
            live_merged_queues.add(name)

            (tt, pp) = self._expected_min(prios, queue_nodes[name])
            self.assertEqual(q.min_prio(), pp)
            self.assertEqual(q.min_elem(), f"n{tt}")

            # Once enough merged queues are alive, split a random one
            # and verify each of the sub-queues that come back to life.
            if len(live_merged_queues) >= 100:
                name = rng.choice(sorted(live_merged_queues))
                queues[name].split()

                for nn in queue_subs[name]:
                    self._check_tree(queues[nn])
                    (tt, pp) = self._expected_min(prios, queue_nodes[nn])
                    self.assertEqual(queues[nn].min_prio(), pp)
                    self.assertEqual(queues[nn].min_elem(), f"n{tt}")
                    live_queues.add(nn)
                    if nn in queue_subs:
                        live_merged_queues.add(nn)

                live_merged_queues.remove(name)
                live_queues.remove(name)

                del queues[name]
                del queue_nodes[name]
                del queue_subs[name]

        for q in queues.values():
            q.clear()
|
|
|
|
|
|
class TestPriorityQueue(unittest.TestCase):
    """Test PriorityQueue."""

    def test_empty(self):
        """A new queue is empty and find_min() raises IndexError."""
        heap = PriorityQueue()
        self.assertTrue(heap.empty())
        with self.assertRaises(IndexError):
            heap.find_min()

    def test_single(self):
        """Insert, lower and delete a single element."""
        heap = PriorityQueue()

        node = heap.insert(5, "a")
        self.assertEqual(node.prio, 5)
        self.assertEqual(node.data, "a")
        self.assertFalse(heap.empty())
        self.assertIs(heap.find_min(), node)

        heap.decrease_prio(node, 3)
        self.assertEqual(node.prio, 3)
        self.assertIs(heap.find_min(), node)

        heap.delete(node)
        self.assertTrue(heap.empty())

    def test_simple(self):
        """Ten elements: lower two priorities, then delete one by one."""
        labels = "abcdefghij"
        prios = [9, 4, 7, 5, 8, 6, 4, 5, 2, 6]

        heap = PriorityQueue()

        # Insert everything first, then inspect the returned nodes.
        handles = []
        for (prio, label) in zip(prios, labels):
            handles.append(heap.insert(prio, label))

        for (idx, handle) in enumerate(handles):
            self.assertEqual(handle.prio, prios[idx])
            self.assertEqual(handle.data, labels[idx])

        self.assertIs(heap.find_min(), handles[8])

        # Lowering "c" to 1 makes it the minimum; lowering "e" to 3
        # afterwards does not change that.
        heap.decrease_prio(handles[2], 1)
        self.assertIs(heap.find_min(), handles[2])

        heap.decrease_prio(handles[4], 3)
        self.assertIs(heap.find_min(), handles[2])

        # Each step: (indices to delete, index of the expected minimum).
        removal_steps = [
            ((2,), 8),
            ((8,), 4),
            ((4, 1), 6),
            ((3, 9), 6),
            ((6,), 7),
            ((7,), 5),
        ]
        for (doomed, expected) in removal_steps:
            for idx in doomed:
                heap.delete(handles[idx])
            self.assertIs(heap.find_min(), handles[expected])

        self.assertFalse(heap.empty())
        heap.clear()
        self.assertTrue(heap.empty())

    def test_increase_prio(self):
        """Raise the priority of elements that are already queued."""

        # Case 1: the queue holds just one element.
        heap = PriorityQueue()

        only = heap.insert(5, "a")
        heap.increase_prio(only, 8)
        self.assertEqual(only.prio, 8)
        self.assertIs(heap.find_min(), only)

        # Case 2: four elements.
        heap = PriorityQueue()
        node_a = heap.insert(9, "a")
        node_b = heap.insert(4, "b")
        node_c = heap.insert(7, "c")
        node_d = heap.insert(5, "d")
        self.assertIs(heap.find_min(), node_b)

        # After either increase, "d" (prio 5) is the minimum.
        for (target, raised) in ((node_b, 8), (node_c, 10)):
            heap.increase_prio(target, raised)
            self.assertEqual(target.prio, raised)
            self.assertIs(heap.find_min(), node_d)

        # Each step: (node to delete, expected minimum afterwards).
        for (removed, new_min) in (
                (node_d, node_b),
                (node_b, node_a),
                (node_a, node_c)):
            heap.delete(removed)
            self.assertIs(heap.find_min(), new_min)
        self.assertEqual(node_c.prio, 10)

        heap.delete(node_c)
        self.assertTrue(heap.empty())

    def test_random(self):
        """Pseudo-random test."""
        # Fixed seed keeps the sequence of operations reproducible.
        rng = random.Random(34567)

        num_elem = 1000

        seq = 0          # running counter, used as element payload
        entries = []     # shadow list of (node, prio, data) tuples
        heap = PriorityQueue()

        def verify_min():
            """Compare find_min() against the shadow list."""
            lowest = min(entry[1] for entry in entries)
            found = heap.find_min()
            self.assertIn((found, found.prio, found.data), entries)
            self.assertEqual(found.prio, lowest)

        def add_entry():
            """Insert a new element with a random priority and verify."""
            nonlocal seq
            seq += 1
            prio = rng.randint(0, 1000000)
            entries.append((heap.insert(prio, seq), prio, seq))
            verify_min()

        def drop_entry(last_index):
            """Delete the element at a random position in 0..last_index."""
            pos = rng.randint(0, last_index)
            heap.delete(entries[pos][0])
            del entries[pos]

        # Phase 1: fill the queue.
        for _ in range(num_elem):
            add_entry()

        # Phase 2: change one priority, delete one element and insert
        # one element per round, so the size stays at num_elem.
        for _ in range(10000):
            pos = rng.randint(0, num_elem - 1)
            new_prio = rng.randint(0, 1000000)
            (node, old_prio, tag) = entries[pos]
            if new_prio > old_prio:
                heap.increase_prio(node, new_prio)
            else:
                heap.decrease_prio(node, new_prio)
            entries[pos] = (node, new_prio, tag)
            verify_min()

            drop_entry(num_elem - 1)
            verify_min()

            add_entry()

        # Phase 3: drain the queue in random order.
        for removed in range(num_elem):
            drop_entry(num_elem - 1 - removed)
            if entries:
                verify_min()

        self.assertTrue(heap.empty())
|
|
|
|
|
|
# Run all tests in this module when the file is executed as a script.
if __name__ == "__main__":
    unittest.main()
|