1
0
Fork 0

Implement heap-based tracking for delta2

This commit is contained in:
Joris van Rantwijk 2024-05-28 21:29:57 +02:00
parent 7cc1666cf2
commit 225311dae0
2 changed files with 166 additions and 85 deletions

View File

@ -112,6 +112,8 @@ class UnionFindQueue(Generic[_NameT, _ElemT]):
self.split_nodes.clear() self.split_nodes.clear()
# Wipe pointers to enable refcounted garbage collection. # Wipe pointers to enable refcounted garbage collection.
if node is not None:
node.owner = None
while node is not None: while node is not None:
prev_node = node prev_node = node
if node.left is not None: if node.left is not None:

View File

@ -383,6 +383,19 @@ class _Blossom:
# "tree_edge = None" if the blossom is the root of an alternating tree. # "tree_edge = None" if the blossom is the root of an alternating tree.
self.tree_edge: Optional[tuple[int, int]] = None self.tree_edge: Optional[tuple[int, int]] = None
# Each top-level blossom maintains a union-find datastructure
# containing all vertices in the blossom.
self.vertex_set: "UnionFindQueue[_Blossom, int]"
self.vertex_set = UnionFindQueue(self)
# If this is a top-level unlabeled blossom with an edge to an
# S-blossom, "delta2_node" is the corresponding node in the delta2
# queue.
self.delta2_node: Optional[PriorityQueue.Node] = None
# Support variable for lazy updating of vertex dual variables.
self.vertex_dual_offset: float = 0
# "marker" is a temporary variable used to discover common # "marker" is a temporary variable used to discover common
# ancestors in the blossom tree. It is normally False, except # ancestors in the blossom tree. It is normally False, except
# when used by "trace_alternating_paths()". # when used by "trace_alternating_paths()".
@ -460,9 +473,6 @@ class _NonTrivialBlossom(_Blossom):
# Note that "dual_var" is invariant under delta steps. # Note that "dual_var" is invariant under delta steps.
self.dual_var: float = 0 self.dual_var: float = 0
# Support variable for lazy updating of vertex dual variables.
self.vertex_dual_offset: float = 0
# If this is a top-level T-blossom, # If this is a top-level T-blossom,
# "delta4_node" is the corresponding node in the delta4 queue. # "delta4_node" is the corresponding node in the delta4 queue.
# Otherwise "delta4_node" is None. # Otherwise "delta4_node" is None.
@ -543,6 +553,11 @@ class _MatchingContext:
# Initially all vertices are trivial top-level blossoms. # Initially all vertices are trivial top-level blossoms.
self.vertex_top_blossom: list[_Blossom] = self.trivial_blossom.copy() self.vertex_top_blossom: list[_Blossom] = self.trivial_blossom.copy()
# "vertex_set_node[x]" represents the vertex "x" inside the
# union-find datastructure of its top-level blossom.
self.vertex_set_node = [b.vertex_set.insert(i, math.inf)
for (i, b) in enumerate(self.trivial_blossom)]
# All vertex duals are initialized to half the maximum edge weight. # All vertex duals are initialized to half the maximum edge weight.
# #
# "start_vertex_dual_2x" is 2 times the initial vertex dual value. # "start_vertex_dual_2x" is 2 times the initial vertex dual value.
@ -572,14 +587,19 @@ class _MatchingContext:
# Running sum of applied delta steps times 2. # Running sum of applied delta steps times 2.
self.delta_sum_2x: float = 0 self.delta_sum_2x: float = 0
# Queue containing unlabeled top-level blossoms that have an edge to
# an S-blossom. The priority of a blossom is 2 times the least slack
# to an S blossom, plus 2 times the running sum of delta steps.
self.delta2_queue: PriorityQueue[_Blossom] = PriorityQueue()
# Queue containing edges between S-vertices in different top-level # Queue containing edges between S-vertices in different top-level
# blossoms. The priority of an edge is its slack plus two times the # blossoms. The priority of an edge is its slack plus 2 times the
# running sum of delta steps. # running sum of delta steps.
self.delta3_queue: PriorityQueue[int] = PriorityQueue() self.delta3_queue: PriorityQueue[int] = PriorityQueue()
self.delta3_set: set[int] = set() self.delta3_set: set[int] = set()
# Queue containing top-level non-trivial T-blossoms. # Queue containing top-level non-trivial T-blossoms.
# The priority of a blossom is its dual plus two times the running # The priority of a blossom is its dual plus 2 times the running
# sum of delta steps. # sum of delta steps.
self.delta4_queue: PriorityQueue[_NonTrivialBlossom] = PriorityQueue() self.delta4_queue: PriorityQueue[_NonTrivialBlossom] = PriorityQueue()
@ -591,7 +611,18 @@ class _MatchingContext:
# Queue of S-vertices to be scanned. # Queue of S-vertices to be scanned.
self.scan_queue: collections.deque[int] = collections.deque() self.scan_queue: collections.deque[int] = collections.deque()
def edge_slack_2x( def __del__(self) -> None:
"""Delete reference cycles during cleanup of the matching context."""
for blossom in self.trivial_blossom:
blossom.vertex_set.clear()
del blossom.vertex_set
for blossom in self.nontrivial_blossom:
blossom.vertex_set.clear()
del blossom.vertex_set
def edge_slack_2x_info(
self, self,
x: int, x: int,
y: int, y: int,
@ -619,12 +650,17 @@ class _MatchingContext:
dual_2x += self.delta_sum_2x dual_2x += self.delta_sum_2x
if by.label == _LABEL_T: if by.label == _LABEL_T:
dual_2x += self.delta_sum_2x dual_2x += self.delta_sum_2x
if isinstance(bx, _NonTrivialBlossom):
dual_2x += bx.vertex_dual_offset dual_2x += bx.vertex_dual_offset
if isinstance(by, _NonTrivialBlossom):
dual_2x += by.vertex_dual_offset dual_2x += by.vertex_dual_offset
return dual_2x - 2 * w return dual_2x - 2 * w
def edge_slack_2x(self, e: int) -> float:
"""Return 2 times the slack of the edge with index "e"."""
(x, y, w) = self.graph.edges[e]
bx = self.vertex_top_blossom[x]
by = self.vertex_top_blossom[y]
return self.edge_slack_2x_info(x, y, bx, by, w)
# #
# Least-slack edge tracking: # Least-slack edge tracking:
# #
@ -649,57 +685,80 @@ class _MatchingContext:
def lset_reset(self) -> None: def lset_reset(self) -> None:
"""Reset least-slack edge tracking. """Reset least-slack edge tracking.
This function takes time O(n). This function takes time O(n * log(n)).
""" """
num_vertex = self.graph.num_vertex num_vertex = self.graph.num_vertex
for x in range(num_vertex): for x in range(num_vertex):
self.vertex_best_edge[x] = -1 self.vertex_best_edge[x] = -1
self.vertex_set_node[x].set_prio(math.inf)
def lset_add_vertex_edge(self, y: int, e: int, slack: float) -> None: self.delta2_queue.clear()
for blossom in self.trivial_blossom + self.nontrivial_blossom:
blossom.delta2_node = None
def lset_add_vertex_edge(
self,
y: int,
by: _Blossom,
e: int,
slack: float
) -> None:
"""Add edge "e" from an S-vertex to unlabeled vertex or T-vertex "y". """Add edge "e" from an S-vertex to unlabeled vertex or T-vertex "y".
This function takes time O(1) per call. This function takes time O(log(n)).
This function is called O(m) times per stage.
""" """
# TODO -- Simplify: We don't need to know the true slack of the edge,
# only the pseudo-slack based on raw vertex duals and weight.
best_edge = self.vertex_best_edge[y] best_edge = self.vertex_best_edge[y]
if best_edge == -1: if best_edge != -1:
self.vertex_best_edge[y] = e best_slack = self.edge_slack_2x(best_edge)
else: if slack >= best_slack:
(xx, yy, w) = self.graph.edges[best_edge] return
bx = self.vertex_top_blossom[xx]
by = self.vertex_top_blossom[yy]
best_slack = self.edge_slack_2x(xx, yy, bx, by, w)
if slack < best_slack:
self.vertex_best_edge[y] = e self.vertex_best_edge[y] = e
(p, q, w) = self.graph.edges[e]
prio = self.vertex_dual_2x[p] + self.vertex_dual_2x[q] - 2 * w
prev_min = by.vertex_set.min_prio()
self.vertex_set_node[y].set_prio(prio)
if (by.label == _LABEL_NONE) and (prio < prev_min):
prio = slack + self.delta_sum_2x
if by.delta2_node is None:
by.delta2_node = self.delta2_queue.insert(prio, by)
elif prio < by.delta2_node.prio:
self.delta2_queue.decrease_prio(by.delta2_node, prio)
def lset_get_best_vertex_edge(self) -> tuple[int, float]: def lset_get_best_vertex_edge(self) -> tuple[int, float]:
"""Return the index and slack of the least-slack edge between """Return the index and slack of the least-slack edge between
any S-vertex and unlabeled vertex. any S-vertex and unlabeled vertex.
This function takes time O(n) per call. This function takes time O(log(n)).
This function takes total time O(n**2) per stage.
Returns: Returns:
Tuple (edge_index, slack_2x) if there is a least-slack edge, Tuple (edge_index, slack_2x) if there is a least-slack edge,
or (-1, 0) if there is no suitable edge. or (-1, 0) if there is no suitable edge.
""" """
best_index = -1
best_slack: float = 0
for x in range(self.graph.num_vertex): if self.delta2_queue.empty():
if self.vertex_top_blossom[x].label == _LABEL_NONE: return (-1, 0)
delta2_node = self.delta2_queue.find_min()
blossom = delta2_node.data
prio = delta2_node.prio
slack_2x = prio - self.delta_sum_2x
assert blossom.parent is None
assert blossom.label == _LABEL_NONE
x = blossom.vertex_set.min_elem()
e = self.vertex_best_edge[x] e = self.vertex_best_edge[x]
if e != -1: assert e >= 0
(x, y, w) = self.graph.edges[e]
bx = self.vertex_top_blossom[x]
by = self.vertex_top_blossom[y]
slack = self.edge_slack_2x(x, y, bx, by, w)
if (best_index == -1) or (slack < best_slack):
best_index = e
best_slack = slack
return (best_index, best_slack) return (e, slack_2x)
# #
# General support routines: # General support routines:
@ -711,6 +770,11 @@ class _MatchingContext:
assert blossom.label == _LABEL_NONE assert blossom.label == _LABEL_NONE
blossom.label = _LABEL_S blossom.label = _LABEL_S
# Delete blossom from delta2 queue.
if blossom.delta2_node is not None:
self.delta2_queue.delete(blossom.delta2_node)
blossom.delta2_node = None
# Prepare for lazy updating of S-blossom dual variable. # Prepare for lazy updating of S-blossom dual variable.
if isinstance(blossom, _NonTrivialBlossom): if isinstance(blossom, _NonTrivialBlossom):
blossom.dual_var -= self.delta_sum_2x blossom.dual_var -= self.delta_sum_2x
@ -733,11 +797,8 @@ class _MatchingContext:
self.scan_queue.extend(vertices) self.scan_queue.extend(vertices)
# Prepare for lazy updating of S-vertex dual variables. # Prepare for lazy updating of S-vertex dual variables.
vertex_dual_fixup = self.delta_sum_2x vertex_dual_fixup = self.delta_sum_2x + blossom.vertex_dual_offset
if isinstance(blossom, _NonTrivialBlossom):
vertex_dual_fixup += blossom.vertex_dual_offset
blossom.vertex_dual_offset = 0 blossom.vertex_dual_offset = 0
for x in vertices: for x in vertices:
self.vertex_dual_2x[x] += vertex_dual_fixup self.vertex_dual_2x[x] += vertex_dual_fixup
@ -748,23 +809,23 @@ class _MatchingContext:
assert blossom.label == _LABEL_NONE assert blossom.label == _LABEL_NONE
blossom.label = _LABEL_T blossom.label = _LABEL_T
# Delete blossom from delta2 queue.
if blossom.delta2_node is not None:
self.delta2_queue.delete(blossom.delta2_node)
blossom.delta2_node = None
if isinstance(blossom, _NonTrivialBlossom): if isinstance(blossom, _NonTrivialBlossom):
# Prepare for lazy updating of T-blossom dual variables. # Prepare for lazy updating of T-blossom dual variables.
blossom.dual_var += self.delta_sum_2x blossom.dual_var += self.delta_sum_2x
# Prepare for lazy updating of T-vertex dual variables.
blossom.vertex_dual_offset -= self.delta_sum_2x
# Insert blossom into the delta4 queue. # Insert blossom into the delta4 queue.
assert blossom.delta4_node is None assert blossom.delta4_node is None
blossom.delta4_node = self.delta4_queue.insert(blossom.dual_var, blossom.delta4_node = self.delta4_queue.insert(blossom.dual_var,
blossom) blossom)
else:
# Prepare for lazy updating of T-vertex dual variables. # Prepare for lazy updating of T-vertex dual variables.
self.vertex_dual_2x[blossom.base_vertex] -= self.delta_sum_2x blossom.vertex_dual_offset -= self.delta_sum_2x
def remove_blossom_label_t(self, blossom: _Blossom) -> None: def remove_blossom_label_t(self, blossom: _Blossom) -> None:
"""Remove label T from a top-level T-blossom.""" """Remove label T from a top-level T-blossom."""
@ -786,10 +847,6 @@ class _MatchingContext:
# Unwind lazy updates of T-vertex dual variables. # Unwind lazy updates of T-vertex dual variables.
blossom.vertex_dual_offset += self.delta_sum_2x blossom.vertex_dual_offset += self.delta_sum_2x
else:
# Unwind lazy updates of T-vertex dual variables.
self.vertex_dual_2x[blossom.base_vertex] += self.delta_sum_2x
def reset_blossom_label(self, blossom: _Blossom) -> None: def reset_blossom_label(self, blossom: _Blossom) -> None:
"""Remove blossom label and calculate true dual variables.""" """Remove blossom label and calculate true dual variables."""
@ -803,9 +860,9 @@ class _MatchingContext:
# Unwind lazy delta updates to S-blossom dual variable. # Unwind lazy delta updates to S-blossom dual variable.
if isinstance(blossom, _NonTrivialBlossom): if isinstance(blossom, _NonTrivialBlossom):
blossom.dual_var += self.delta_sum_2x blossom.dual_var += self.delta_sum_2x
assert blossom.vertex_dual_offset == 0
# Unwind lazy delta updates to S-vertex dual variables. # Unwind lazy delta updates to S-vertex dual variables.
assert blossom.vertex_dual_offset == 0
vertex_dual_fixup = -self.delta_sum_2x vertex_dual_fixup = -self.delta_sum_2x
for x in blossom.vertices(): for x in blossom.vertices():
self.vertex_dual_2x[x] += vertex_dual_fixup self.vertex_dual_2x[x] += vertex_dual_fixup
@ -815,28 +872,21 @@ class _MatchingContext:
# Remove label. # Remove label.
blossom.label = _LABEL_NONE blossom.label = _LABEL_NONE
# Prepare to unwind lazy delta updates to vertices.
vertex_dual_fixup = self.delta_sum_2x
if isinstance(blossom, _NonTrivialBlossom):
# Unwind lazy delta updates to T-blossom dual variable. # Unwind lazy delta updates to T-blossom dual variable.
if isinstance(blossom, _NonTrivialBlossom):
blossom.dual_var -= self.delta_sum_2x blossom.dual_var -= self.delta_sum_2x
# Prepare to unwind lazy delta updates to vertices.
vertex_dual_fixup += blossom.vertex_dual_offset
blossom.vertex_dual_offset = 0
# Unwind lazy delta updates to T-vertex dual variables. # Unwind lazy delta updates to T-vertex dual variables.
vertex_dual_fixup = self.delta_sum_2x + blossom.vertex_dual_offset
blossom.vertex_dual_offset = 0
for x in blossom.vertices(): for x in blossom.vertices():
self.vertex_dual_2x[x] += vertex_dual_fixup self.vertex_dual_2x[x] += vertex_dual_fixup
else: else:
# Unwind lazy delta updates to vertex dual variables. # Unwind lazy delta updates to vertex dual variables.
if isinstance(blossom, _NonTrivialBlossom):
vertex_dual_fixup = blossom.vertex_dual_offset vertex_dual_fixup = blossom.vertex_dual_offset
blossom.vertex_dual_offset = 0 blossom.vertex_dual_offset = 0
for x in blossom.vertices(): for x in blossom.vertices():
self.vertex_dual_2x[x] += vertex_dual_fixup self.vertex_dual_2x[x] += vertex_dual_fixup
@ -846,7 +896,7 @@ class _MatchingContext:
Marks all blossoms as unlabeled, clears the queue, Marks all blossoms as unlabeled, clears the queue,
and resets tracking of least-slack edges. and resets tracking of least-slack edges.
This function takes time O(n). This function takes time O(n * log(n)).
""" """
# Remove blossom labels and unwind lazy dual updates. # Remove blossom labels and unwind lazy dual updates.
@ -1007,6 +1057,9 @@ class _MatchingContext:
for x in blossom.vertices(): for x in blossom.vertices():
self.vertex_top_blossom[x] = blossom self.vertex_top_blossom[x] = blossom
# Merge union-find structures.
blossom.vertex_set.merge([sub.vertex_set for sub in subblossoms])
@staticmethod @staticmethod
def find_path_through_blossom( def find_path_through_blossom(
blossom: _NonTrivialBlossom, blossom: _NonTrivialBlossom,
@ -1048,6 +1101,9 @@ class _MatchingContext:
self.delta4_queue.delete(blossom.delta4_node) self.delta4_queue.delete(blossom.delta4_node)
blossom.delta4_node = None blossom.delta4_node = None
# Split union-find structure.
blossom.vertex_set.split()
# Prepare to push lazy delta updates down to the sub-blossoms. # Prepare to push lazy delta updates down to the sub-blossoms.
vertex_dual_fixup = self.delta_sum_2x + blossom.vertex_dual_offset vertex_dual_fixup = self.delta_sum_2x + blossom.vertex_dual_offset
blossom.vertex_dual_offset = 0 blossom.vertex_dual_offset = 0
@ -1056,14 +1112,22 @@ class _MatchingContext:
for sub in blossom.subblossoms: for sub in blossom.subblossoms:
assert sub.label == _LABEL_NONE assert sub.label == _LABEL_NONE
sub.parent = None sub.parent = None
if isinstance(sub, _NonTrivialBlossom):
assert sub.vertex_dual_offset == 0
sub.vertex_dual_offset = vertex_dual_fixup sub.vertex_dual_offset = vertex_dual_fixup
if isinstance(sub, _NonTrivialBlossom):
for x in sub.vertices(): for x in sub.vertices():
self.vertex_top_blossom[x] = sub self.vertex_top_blossom[x] = sub
else: else:
x = sub.base_vertex self.vertex_top_blossom[sub.base_vertex] = sub
self.vertex_dual_2x[x] += vertex_dual_fixup
self.vertex_top_blossom[x] = sub # Insert blossom in delta2_queue if necessary.
prio = sub.vertex_set.min_prio()
if prio < math.inf:
assert sub.delta2_node is None
prio += sub.vertex_dual_offset
sub.delta2_node = self.delta2_queue.insert(prio, sub)
# The expanding blossom was part of an alternating tree, linked to # The expanding blossom was part of an alternating tree, linked to
# a parent node in the tree via one of its subblossoms, and linked to # a parent node in the tree via one of its subblossoms, and linked to
@ -1117,6 +1181,9 @@ class _MatchingContext:
assert blossom.parent is None assert blossom.parent is None
assert blossom.label == _LABEL_NONE assert blossom.label == _LABEL_NONE
# Split union-find structure.
blossom.vertex_set.split()
# Prepare to push lazy delta updates down to the sub-blossoms. # Prepare to push lazy delta updates down to the sub-blossoms.
vertex_dual_offset = blossom.vertex_dual_offset vertex_dual_offset = blossom.vertex_dual_offset
blossom.vertex_dual_offset = 0 blossom.vertex_dual_offset = 0
@ -1126,15 +1193,21 @@ class _MatchingContext:
assert sub.label == _LABEL_NONE assert sub.label == _LABEL_NONE
sub.parent = None sub.parent = None
if isinstance(sub, _NonTrivialBlossom):
assert sub.vertex_dual_offset == 0 assert sub.vertex_dual_offset == 0
sub.vertex_dual_offset = vertex_dual_offset sub.vertex_dual_offset = vertex_dual_offset
if isinstance(sub, _NonTrivialBlossom):
for x in sub.vertices(): for x in sub.vertices():
self.vertex_top_blossom[x] = sub self.vertex_top_blossom[x] = sub
else: else:
x = sub.base_vertex self.vertex_top_blossom[sub.base_vertex] = sub
self.vertex_dual_2x[x] += vertex_dual_offset
self.vertex_top_blossom[x] = sub # Insert blossom in delta2_queue if necessary.
prio = sub.vertex_set.min_prio()
if prio < math.inf:
assert sub.delta2_node is None
prio += sub.vertex_dual_offset
sub.delta2_node = self.delta2_queue.insert(prio, sub)
# Delete the expanded blossom. # Delete the expanded blossom.
self.nontrivial_blossom.remove(blossom) self.nontrivial_blossom.remove(blossom)
@ -1427,7 +1500,7 @@ class _MatchingContext:
# Check whether this edge is tight (has zero slack). # Check whether this edge is tight (has zero slack).
# Only tight edges may be part of an alternating tree. # Only tight edges may be part of an alternating tree.
slack = self.edge_slack_2x(x, y, bx, by, w) slack = self.edge_slack_2x_info(x, y, bx, by, w)
if slack <= 0: if slack <= 0:
if ylabel == _LABEL_NONE: if ylabel == _LABEL_NONE:
# Assign label T to the blossom that contains "y". # Assign label T to the blossom that contains "y".
@ -1460,7 +1533,7 @@ class _MatchingContext:
# any S-vertex. We do this for T-vertices and unlabeled # any S-vertex. We do this for T-vertices and unlabeled
# vertices. Edges which already have zero slack are still # vertices. Edges which already have zero slack are still
# tracked. # tracked.
self.lset_add_vertex_edge(y, e, slack) self.lset_add_vertex_edge(y, by, e, slack)
# No further S vertices to scan, and no augmenting path found. # No further S vertices to scan, and no augmenting path found.
return None return None
@ -1483,7 +1556,8 @@ class _MatchingContext:
weights are integers. weights are integers.
This function assumes that there is at least one S-vertex. This function assumes that there is at least one S-vertex.
This function takes time O(n). This function takes total time O(m * log(n)) for all calls
within a stage.
Returns: Returns:
Tuple (delta_type, delta_2x, delta_edge, delta_blossom). Tuple (delta_type, delta_2x, delta_edge, delta_blossom).
@ -1499,6 +1573,7 @@ class _MatchingContext:
# Compute delta2: minimum slack of any edge between an S-vertex and # Compute delta2: minimum slack of any edge between an S-vertex and
# an unlabeled vertex. # an unlabeled vertex.
# This takes time O(log(n)).
(e, slack) = self.lset_get_best_vertex_edge() (e, slack) = self.lset_get_best_vertex_edge()
if (e != -1) and (slack <= delta_2x): if (e != -1) and (slack <= delta_2x):
delta_type = 2 delta_type = 2
@ -1507,6 +1582,9 @@ class _MatchingContext:
# Compute delta3: half minimum slack of any edge between two top-level # Compute delta3: half minimum slack of any edge between two top-level
# S-blossoms. # S-blossoms.
#
# This loop iterates O(m) times per stage.
# Each iteration takes time O(log(n)).
while not self.delta3_queue.empty(): while not self.delta3_queue.empty():
delta3_node = self.delta3_queue.find_min() delta3_node = self.delta3_queue.find_min()
e = delta3_node.data e = delta3_node.data
@ -1531,6 +1609,7 @@ class _MatchingContext:
self.delta3_set.remove(e) self.delta3_set.remove(e)
# Compute delta4: half minimum dual variable of a top-level T-blossom. # Compute delta4: half minimum dual variable of a top-level T-blossom.
# This takes time O(log(n)).
if not self.delta4_queue.empty(): if not self.delta4_queue.empty():
blossom = self.delta4_queue.find_min().data blossom = self.delta4_queue.find_min().data
assert blossom.label == _LABEL_T assert blossom.label == _LABEL_T