diff --git a/python/datastruct.py b/python/datastruct.py index 8530c04..330d2fe 100644 --- a/python/datastruct.py +++ b/python/datastruct.py @@ -820,3 +820,13 @@ class PriorityQueue(Generic[_ElemT]): assert prio <= elem.prio elem.prio = prio self._sift_up(elem.index) + + def increase_prio(self, elem: Node[_ElemT], prio: float) -> None: + """Increase the priority of an existing element in the queue. + + This function takes time O(log(n)). + """ + assert self.heap[elem.index] is elem + assert prio >= elem.prio + elem.prio = prio + self._sift_down(elem.index) diff --git a/python/test_datastruct.py b/python/test_datastruct.py index 8b1f00e..60a09d5 100644 --- a/python/test_datastruct.py +++ b/python/test_datastruct.py @@ -438,6 +438,44 @@ class TestPriorityQueue(unittest.TestCase): q.clear() self.assertTrue(q.empty()) + def test_increase_prio(self): + """Increase priority of existing element.""" + + q = PriorityQueue() + + n1 = q.insert(5, "a") + q.increase_prio(n1, 8) + self.assertEqual(n1.prio, 8) + self.assertIs(q.find_min(), n1) + + q = PriorityQueue() + n1 = q.insert(9, "a") + n2 = q.insert(4, "b") + n3 = q.insert(7, "c") + n4 = q.insert(5, "d") + self.assertIs(q.find_min(), n2) + + q.increase_prio(n2, 8) + self.assertEqual(n2.prio, 8) + self.assertIs(q.find_min(), n4) + + q.increase_prio(n3, 10) + self.assertEqual(n3.prio, 10) + self.assertIs(q.find_min(), n4) + + q.delete(n4) + self.assertIs(q.find_min(), n2) + + q.delete(n2) + self.assertIs(q.find_min(), n1) + + q.delete(n1) + self.assertIs(q.find_min(), n3) + self.assertEqual(n3.prio, 10) + + q.delete(n3) + self.assertTrue(q.empty()) + def test_random(self): """Pseudo-random test.""" rng = random.Random(34567) @@ -462,8 +500,11 @@ class TestPriorityQueue(unittest.TestCase): for i in range(10000): p = rng.randint(0, num_elem - 1) - prio = rng.randint(0, elems[p][1]) - q.decrease_prio(elems[p][0], prio) + prio = rng.randint(0, 1000000) + if prio <= elems[p][1]: + q.decrease_prio(elems[p][0], prio) + else: + q.increase_prio(elems[p][0], prio) elems[p] = (elems[p][0], prio, elems[p][2]) check()