1
0
Fork 0

Add a few tests for maximum cardinality

This commit is contained in:
Joris van Rantwijk 2023-02-06 13:58:38 +01:00
parent 7617e68d59
commit d4bfb712d2
1 changed files with 44 additions and 1 deletions

View File

@ -2,10 +2,13 @@
import unittest import unittest
from max_weight_matching import maximum_weight_matching as mwm from max_weight_matching import (
maximum_weight_matching as mwm,
adjust_weights_for_maximum_cardinality_matching as adj)
class TestMaximumWeightMatching(unittest.TestCase): class TestMaximumWeightMatching(unittest.TestCase):
"""Test maximum_weight_matching() function."""
def test10_empty(self): def test10_empty(self):
"""empty input graph""" """empty input graph"""
@ -126,6 +129,46 @@ class TestMaximumWeightMatching(unittest.TestCase):
[(1,2), (3,5), (7,6), (8,10), (4,9)]) [(1,2), (3,5), (7,6), (8,10), (4,9)])
class TestAdjustWeightForMaxCardinality(unittest.TestCase):
"""Test adjust_weights_for_maximum_cardinality_matching() function."""
def test_chain(self):
self.assertEqual(
adj([(0,1,2), (1,2,8), (2,3,3), (3,4,9), (4,5,1), (5,6,7), (6,7,4)]),
[(0,1,65), (1,2,71), (2,3,66), (3,4,72), (4,5,64), (5,6,70), (6,7,67)])
def test_chain_preadjusted(self):
self.assertEqual(
adj([(0,1,65), (1,2,71), (2,3,66), (3,4,72), (4,5,64), (5,6,70), (6,7,67)]),
[(0,1,65), (1,2,71), (2,3,66), (3,4,72), (4,5,64), (5,6,70), (6,7,67)])
def test14_maxcard(self):
self.assertEqual(
adj([(1,2,5), (2,3,11), (3,4,5)]),
[(1,2,30), (2,3,36), (3,4,30)])
def test16_negative(self):
self.assertEqual(
adj([(1,2,2), (1,3,-2), (2,3,1), (2,4,-1), (3,4,-6)]),
[(1,2,48), (1,3,44), (2,3,47), (2,4,45), (3,4,40)])
class TestMaximumCardinalityMatching(unittest.TestCase):
"""Test maximum cardinality matching."""
def test14_maxcard(self):
"""maximum cardinality"""
self.assertEqual(
mwm(adj([(1,2,5), (2,3,11), (3,4,5)])),
[(1,2), (3,4)])
def test16_negative(self):
"""negative weights"""
self.assertEqual(
mwm(adj([(1,2,2), (1,3,-2), (2,3,1), (2,4,-1), (3,4,-6)])),
[(1,3), (2,4)])
if __name__ == "__main__": if __name__ == "__main__":
unittest.main() unittest.main()