Implement heap-based tracking for delta2
This commit is contained in:
parent
7cc1666cf2
commit
225311dae0
|
@ -112,6 +112,8 @@ class UnionFindQueue(Generic[_NameT, _ElemT]):
|
|||
self.split_nodes.clear()
|
||||
|
||||
# Wipe pointers to enable refcounted garbage collection.
|
||||
if node is not None:
|
||||
node.owner = None
|
||||
while node is not None:
|
||||
prev_node = node
|
||||
if node.left is not None:
|
||||
|
|
|
@ -383,6 +383,19 @@ class _Blossom:
|
|||
# "tree_edge = None" if the blossom is the root of an alternating tree.
|
||||
self.tree_edge: Optional[tuple[int, int]] = None
|
||||
|
||||
# Each top-level blossom maintains a union-find datastructure
|
||||
# containing all vertices in the blossom.
|
||||
self.vertex_set: "UnionFindQueue[_Blossom, int]"
|
||||
self.vertex_set = UnionFindQueue(self)
|
||||
|
||||
# If this is a top-level unlabeled blossom with an edge to an
|
||||
# S-blossom, "delta2_node" is the corresponding node in the delta2
|
||||
# queue.
|
||||
self.delta2_node: Optional[PriorityQueue.Node] = None
|
||||
|
||||
# Support variable for lazy updating of vertex dual variables.
|
||||
self.vertex_dual_offset: float = 0
|
||||
|
||||
# "marker" is a temporary variable used to discover common
|
||||
# ancestors in the blossom tree. It is normally False, except
|
||||
# when used by "trace_alternating_paths()".
|
||||
|
@ -460,9 +473,6 @@ class _NonTrivialBlossom(_Blossom):
|
|||
# Note that "dual_var" is invariant under delta steps.
|
||||
self.dual_var: float = 0
|
||||
|
||||
# Support variable for lazy updating of vertex dual variables.
|
||||
self.vertex_dual_offset: float = 0
|
||||
|
||||
# If this is a top-level T-blossom,
|
||||
# "delta4_node" is the corresponding node in the delta4 queue.
|
||||
# Otherwise "delta4_node" is None.
|
||||
|
@ -543,6 +553,11 @@ class _MatchingContext:
|
|||
# Initially all vertices are trivial top-level blossoms.
|
||||
self.vertex_top_blossom: list[_Blossom] = self.trivial_blossom.copy()
|
||||
|
||||
# "vertex_set_node[x]" represents the vertex "x" inside the
|
||||
# union-find datastructure of its top-level blossom.
|
||||
self.vertex_set_node = [b.vertex_set.insert(i, math.inf)
|
||||
for (i, b) in enumerate(self.trivial_blossom)]
|
||||
|
||||
# All vertex duals are initialized to half the maximum edge weight.
|
||||
#
|
||||
# "start_vertex_dual_2x" is 2 times the initial vertex dual value.
|
||||
|
@ -572,14 +587,19 @@ class _MatchingContext:
|
|||
# Running sum of applied delta steps times 2.
|
||||
self.delta_sum_2x: float = 0
|
||||
|
||||
# Queue containing unlabeled top-level blossoms that have an edge to
|
||||
# an S-blossom. The priority of a blossom is 2 times the least slack
|
||||
# to an S blossom, plus 2 times the running sum of delta steps.
|
||||
self.delta2_queue: PriorityQueue[_Blossom] = PriorityQueue()
|
||||
|
||||
# Queue containing edges between S-vertices in different top-level
|
||||
# blossoms. The priority of an edge is its slack plus two times the
|
||||
# blossoms. The priority of an edge is its slack plus 2 times the
|
||||
# running sum of delta steps.
|
||||
self.delta3_queue: PriorityQueue[int] = PriorityQueue()
|
||||
self.delta3_set: set[int] = set()
|
||||
|
||||
# Queue containing top-level non-trivial T-blossoms.
|
||||
# The priority of a blossom is its dual plus two times the running
|
||||
# The priority of a blossom is its dual plus 2 times the running
|
||||
# sum of delta steps.
|
||||
self.delta4_queue: PriorityQueue[_NonTrivialBlossom] = PriorityQueue()
|
||||
|
||||
|
@ -591,7 +611,18 @@ class _MatchingContext:
|
|||
# Queue of S-vertices to be scanned.
|
||||
self.scan_queue: collections.deque[int] = collections.deque()
|
||||
|
||||
def edge_slack_2x(
|
||||
def __del__(self) -> None:
|
||||
"""Delete reference cycles during cleanup of the matching context."""
|
||||
|
||||
for blossom in self.trivial_blossom:
|
||||
blossom.vertex_set.clear()
|
||||
del blossom.vertex_set
|
||||
|
||||
for blossom in self.nontrivial_blossom:
|
||||
blossom.vertex_set.clear()
|
||||
del blossom.vertex_set
|
||||
|
||||
def edge_slack_2x_info(
|
||||
self,
|
||||
x: int,
|
||||
y: int,
|
||||
|
@ -619,12 +650,17 @@ class _MatchingContext:
|
|||
dual_2x += self.delta_sum_2x
|
||||
if by.label == _LABEL_T:
|
||||
dual_2x += self.delta_sum_2x
|
||||
if isinstance(bx, _NonTrivialBlossom):
|
||||
dual_2x += bx.vertex_dual_offset
|
||||
if isinstance(by, _NonTrivialBlossom):
|
||||
dual_2x += by.vertex_dual_offset
|
||||
dual_2x += bx.vertex_dual_offset
|
||||
dual_2x += by.vertex_dual_offset
|
||||
return dual_2x - 2 * w
|
||||
|
||||
def edge_slack_2x(self, e: int) -> float:
|
||||
"""Return 2 times the slack of the edge with index "e"."""
|
||||
(x, y, w) = self.graph.edges[e]
|
||||
bx = self.vertex_top_blossom[x]
|
||||
by = self.vertex_top_blossom[y]
|
||||
return self.edge_slack_2x_info(x, y, bx, by, w)
|
||||
|
||||
#
|
||||
# Least-slack edge tracking:
|
||||
#
|
||||
|
@ -649,57 +685,80 @@ class _MatchingContext:
|
|||
def lset_reset(self) -> None:
|
||||
"""Reset least-slack edge tracking.
|
||||
|
||||
This function takes time O(n).
|
||||
This function takes time O(n * log(n)).
|
||||
"""
|
||||
num_vertex = self.graph.num_vertex
|
||||
|
||||
for x in range(num_vertex):
|
||||
self.vertex_best_edge[x] = -1
|
||||
self.vertex_set_node[x].set_prio(math.inf)
|
||||
|
||||
def lset_add_vertex_edge(self, y: int, e: int, slack: float) -> None:
|
||||
self.delta2_queue.clear()
|
||||
|
||||
for blossom in self.trivial_blossom + self.nontrivial_blossom:
|
||||
blossom.delta2_node = None
|
||||
|
||||
def lset_add_vertex_edge(
|
||||
self,
|
||||
y: int,
|
||||
by: _Blossom,
|
||||
e: int,
|
||||
slack: float
|
||||
) -> None:
|
||||
"""Add edge "e" from an S-vertex to unlabeled vertex or T-vertex "y".
|
||||
|
||||
This function takes time O(1) per call.
|
||||
This function is called O(m) times per stage.
|
||||
This function takes time O(log(n)).
|
||||
"""
|
||||
|
||||
# TODO -- Simplify: We don't need to know the true slack of the edge,
|
||||
# only the pseudo-slack based on raw vertex duals and weight.
|
||||
|
||||
best_edge = self.vertex_best_edge[y]
|
||||
if best_edge == -1:
|
||||
self.vertex_best_edge[y] = e
|
||||
else:
|
||||
(xx, yy, w) = self.graph.edges[best_edge]
|
||||
bx = self.vertex_top_blossom[xx]
|
||||
by = self.vertex_top_blossom[yy]
|
||||
best_slack = self.edge_slack_2x(xx, yy, bx, by, w)
|
||||
if slack < best_slack:
|
||||
self.vertex_best_edge[y] = e
|
||||
if best_edge != -1:
|
||||
best_slack = self.edge_slack_2x(best_edge)
|
||||
if slack >= best_slack:
|
||||
return
|
||||
|
||||
self.vertex_best_edge[y] = e
|
||||
|
||||
(p, q, w) = self.graph.edges[e]
|
||||
prio = self.vertex_dual_2x[p] + self.vertex_dual_2x[q] - 2 * w
|
||||
prev_min = by.vertex_set.min_prio()
|
||||
self.vertex_set_node[y].set_prio(prio)
|
||||
|
||||
if (by.label == _LABEL_NONE) and (prio < prev_min):
|
||||
prio = slack + self.delta_sum_2x
|
||||
if by.delta2_node is None:
|
||||
by.delta2_node = self.delta2_queue.insert(prio, by)
|
||||
elif prio < by.delta2_node.prio:
|
||||
self.delta2_queue.decrease_prio(by.delta2_node, prio)
|
||||
|
||||
def lset_get_best_vertex_edge(self) -> tuple[int, float]:
|
||||
"""Return the index and slack of the least-slack edge between
|
||||
any S-vertex and unlabeled vertex.
|
||||
|
||||
This function takes time O(n) per call.
|
||||
This function takes total time O(n**2) per stage.
|
||||
This function takes time O(log(n)).
|
||||
|
||||
Returns:
|
||||
Tuple (edge_index, slack_2x) if there is a least-slack edge,
|
||||
or (-1, 0) if there is no suitable edge.
|
||||
"""
|
||||
best_index = -1
|
||||
best_slack: float = 0
|
||||
|
||||
for x in range(self.graph.num_vertex):
|
||||
if self.vertex_top_blossom[x].label == _LABEL_NONE:
|
||||
e = self.vertex_best_edge[x]
|
||||
if e != -1:
|
||||
(x, y, w) = self.graph.edges[e]
|
||||
bx = self.vertex_top_blossom[x]
|
||||
by = self.vertex_top_blossom[y]
|
||||
slack = self.edge_slack_2x(x, y, bx, by, w)
|
||||
if (best_index == -1) or (slack < best_slack):
|
||||
best_index = e
|
||||
best_slack = slack
|
||||
if self.delta2_queue.empty():
|
||||
return (-1, 0)
|
||||
|
||||
return (best_index, best_slack)
|
||||
delta2_node = self.delta2_queue.find_min()
|
||||
blossom = delta2_node.data
|
||||
prio = delta2_node.prio
|
||||
slack_2x = prio - self.delta_sum_2x
|
||||
assert blossom.parent is None
|
||||
assert blossom.label == _LABEL_NONE
|
||||
|
||||
x = blossom.vertex_set.min_elem()
|
||||
e = self.vertex_best_edge[x]
|
||||
assert e >= 0
|
||||
|
||||
return (e, slack_2x)
|
||||
|
||||
#
|
||||
# General support routines:
|
||||
|
@ -711,6 +770,11 @@ class _MatchingContext:
|
|||
assert blossom.label == _LABEL_NONE
|
||||
blossom.label = _LABEL_S
|
||||
|
||||
# Delete blossom from delta2 queue.
|
||||
if blossom.delta2_node is not None:
|
||||
self.delta2_queue.delete(blossom.delta2_node)
|
||||
blossom.delta2_node = None
|
||||
|
||||
# Prepare for lazy updating of S-blossom dual variable.
|
||||
if isinstance(blossom, _NonTrivialBlossom):
|
||||
blossom.dual_var -= self.delta_sum_2x
|
||||
|
@ -733,11 +797,8 @@ class _MatchingContext:
|
|||
self.scan_queue.extend(vertices)
|
||||
|
||||
# Prepare for lazy updating of S-vertex dual variables.
|
||||
vertex_dual_fixup = self.delta_sum_2x
|
||||
if isinstance(blossom, _NonTrivialBlossom):
|
||||
vertex_dual_fixup += blossom.vertex_dual_offset
|
||||
blossom.vertex_dual_offset = 0
|
||||
|
||||
vertex_dual_fixup = self.delta_sum_2x + blossom.vertex_dual_offset
|
||||
blossom.vertex_dual_offset = 0
|
||||
for x in vertices:
|
||||
self.vertex_dual_2x[x] += vertex_dual_fixup
|
||||
|
||||
|
@ -748,23 +809,23 @@ class _MatchingContext:
|
|||
assert blossom.label == _LABEL_NONE
|
||||
blossom.label = _LABEL_T
|
||||
|
||||
# Delete blossom from delta2 queue.
|
||||
if blossom.delta2_node is not None:
|
||||
self.delta2_queue.delete(blossom.delta2_node)
|
||||
blossom.delta2_node = None
|
||||
|
||||
if isinstance(blossom, _NonTrivialBlossom):
|
||||
|
||||
# Prepare for lazy updating of T-blossom dual variables.
|
||||
blossom.dual_var += self.delta_sum_2x
|
||||
|
||||
# Prepare for lazy updating of T-vertex dual variables.
|
||||
blossom.vertex_dual_offset -= self.delta_sum_2x
|
||||
|
||||
# Insert blossom into the delta4 queue.
|
||||
assert blossom.delta4_node is None
|
||||
blossom.delta4_node = self.delta4_queue.insert(blossom.dual_var,
|
||||
blossom)
|
||||
|
||||
else:
|
||||
# Prepare for lazy updating of T-vertex dual variables.
|
||||
self.vertex_dual_2x[blossom.base_vertex] -= self.delta_sum_2x
|
||||
|
||||
# Prepare for lazy updating of T-vertex dual variables.
|
||||
blossom.vertex_dual_offset -= self.delta_sum_2x
|
||||
|
||||
def remove_blossom_label_t(self, blossom: _Blossom) -> None:
|
||||
"""Remove label T from a top-level T-blossom."""
|
||||
|
@ -783,12 +844,8 @@ class _MatchingContext:
|
|||
# Unwind lazy updates to T-blossom dual variable.
|
||||
blossom.dual_var -= self.delta_sum_2x
|
||||
|
||||
# Unwind lazy updates of T-vertex dual variables.
|
||||
blossom.vertex_dual_offset += self.delta_sum_2x
|
||||
|
||||
else:
|
||||
# Unwind lazy updates of T-vertex dual variables.
|
||||
self.vertex_dual_2x[blossom.base_vertex] += self.delta_sum_2x
|
||||
# Unwind lazy updates of T-vertex dual variables.
|
||||
blossom.vertex_dual_offset += self.delta_sum_2x
|
||||
|
||||
def reset_blossom_label(self, blossom: _Blossom) -> None:
|
||||
"""Remove blossom label and calculate true dual variables."""
|
||||
|
@ -803,9 +860,9 @@ class _MatchingContext:
|
|||
# Unwind lazy delta updates to S-blossom dual variable.
|
||||
if isinstance(blossom, _NonTrivialBlossom):
|
||||
blossom.dual_var += self.delta_sum_2x
|
||||
assert blossom.vertex_dual_offset == 0
|
||||
|
||||
# Unwind lazy delta updates to S-vertex dual variables.
|
||||
assert blossom.vertex_dual_offset == 0
|
||||
vertex_dual_fixup = -self.delta_sum_2x
|
||||
for x in blossom.vertices():
|
||||
self.vertex_dual_2x[x] += vertex_dual_fixup
|
||||
|
@ -815,30 +872,23 @@ class _MatchingContext:
|
|||
# Remove label.
|
||||
blossom.label = _LABEL_NONE
|
||||
|
||||
# Prepare to unwind lazy delta updates to vertices.
|
||||
vertex_dual_fixup = self.delta_sum_2x
|
||||
|
||||
# Unwind lazy delta updates to T-blossom dual variable.
|
||||
if isinstance(blossom, _NonTrivialBlossom):
|
||||
# Unwind lazy delta updates to T-blossom dual variable.
|
||||
blossom.dual_var -= self.delta_sum_2x
|
||||
|
||||
# Prepare to unwind lazy delta updates to vertices.
|
||||
vertex_dual_fixup += blossom.vertex_dual_offset
|
||||
blossom.vertex_dual_offset = 0
|
||||
|
||||
# Unwind lazy delta updates to T-vertex dual variables.
|
||||
vertex_dual_fixup = self.delta_sum_2x + blossom.vertex_dual_offset
|
||||
blossom.vertex_dual_offset = 0
|
||||
for x in blossom.vertices():
|
||||
self.vertex_dual_2x[x] += vertex_dual_fixup
|
||||
|
||||
else:
|
||||
|
||||
# Unwind lazy delta updates to vertex dual variables.
|
||||
if isinstance(blossom, _NonTrivialBlossom):
|
||||
vertex_dual_fixup = blossom.vertex_dual_offset
|
||||
blossom.vertex_dual_offset = 0
|
||||
|
||||
for x in blossom.vertices():
|
||||
self.vertex_dual_2x[x] += vertex_dual_fixup
|
||||
vertex_dual_fixup = blossom.vertex_dual_offset
|
||||
blossom.vertex_dual_offset = 0
|
||||
for x in blossom.vertices():
|
||||
self.vertex_dual_2x[x] += vertex_dual_fixup
|
||||
|
||||
def reset_stage(self) -> None:
|
||||
"""Reset data which are only valid during a stage.
|
||||
|
@ -846,7 +896,7 @@ class _MatchingContext:
|
|||
Marks all blossoms as unlabeled, clears the queue,
|
||||
and resets tracking of least-slack edges.
|
||||
|
||||
This function takes time O(n).
|
||||
This function takes time O(n * log(n)).
|
||||
"""
|
||||
|
||||
# Remove blossom labels and unwind lazy dual updates.
|
||||
|
@ -1007,6 +1057,9 @@ class _MatchingContext:
|
|||
for x in blossom.vertices():
|
||||
self.vertex_top_blossom[x] = blossom
|
||||
|
||||
# Merge union-find structures.
|
||||
blossom.vertex_set.merge([sub.vertex_set for sub in subblossoms])
|
||||
|
||||
@staticmethod
|
||||
def find_path_through_blossom(
|
||||
blossom: _NonTrivialBlossom,
|
||||
|
@ -1048,6 +1101,9 @@ class _MatchingContext:
|
|||
self.delta4_queue.delete(blossom.delta4_node)
|
||||
blossom.delta4_node = None
|
||||
|
||||
# Split union-find structure.
|
||||
blossom.vertex_set.split()
|
||||
|
||||
# Prepare to push lazy delta updates down to the sub-blossoms.
|
||||
vertex_dual_fixup = self.delta_sum_2x + blossom.vertex_dual_offset
|
||||
blossom.vertex_dual_offset = 0
|
||||
|
@ -1056,14 +1112,22 @@ class _MatchingContext:
|
|||
for sub in blossom.subblossoms:
|
||||
assert sub.label == _LABEL_NONE
|
||||
sub.parent = None
|
||||
|
||||
assert sub.vertex_dual_offset == 0
|
||||
sub.vertex_dual_offset = vertex_dual_fixup
|
||||
|
||||
if isinstance(sub, _NonTrivialBlossom):
|
||||
sub.vertex_dual_offset = vertex_dual_fixup
|
||||
for x in sub.vertices():
|
||||
self.vertex_top_blossom[x] = sub
|
||||
else:
|
||||
x = sub.base_vertex
|
||||
self.vertex_dual_2x[x] += vertex_dual_fixup
|
||||
self.vertex_top_blossom[x] = sub
|
||||
self.vertex_top_blossom[sub.base_vertex] = sub
|
||||
|
||||
# Insert blossom in delta2_queue if necessary.
|
||||
prio = sub.vertex_set.min_prio()
|
||||
if prio < math.inf:
|
||||
assert sub.delta2_node is None
|
||||
prio += sub.vertex_dual_offset
|
||||
sub.delta2_node = self.delta2_queue.insert(prio, sub)
|
||||
|
||||
# The expanding blossom was part of an alternating tree, linked to
|
||||
# a parent node in the tree via one of its subblossoms, and linked to
|
||||
|
@ -1117,6 +1181,9 @@ class _MatchingContext:
|
|||
assert blossom.parent is None
|
||||
assert blossom.label == _LABEL_NONE
|
||||
|
||||
# Split union-find structure.
|
||||
blossom.vertex_set.split()
|
||||
|
||||
# Prepare to push lazy delta updates down to the sub-blossoms.
|
||||
vertex_dual_offset = blossom.vertex_dual_offset
|
||||
blossom.vertex_dual_offset = 0
|
||||
|
@ -1126,15 +1193,21 @@ class _MatchingContext:
|
|||
assert sub.label == _LABEL_NONE
|
||||
sub.parent = None
|
||||
|
||||
assert sub.vertex_dual_offset == 0
|
||||
sub.vertex_dual_offset = vertex_dual_offset
|
||||
|
||||
if isinstance(sub, _NonTrivialBlossom):
|
||||
assert sub.vertex_dual_offset == 0
|
||||
sub.vertex_dual_offset = vertex_dual_offset
|
||||
for x in sub.vertices():
|
||||
self.vertex_top_blossom[x] = sub
|
||||
else:
|
||||
x = sub.base_vertex
|
||||
self.vertex_dual_2x[x] += vertex_dual_offset
|
||||
self.vertex_top_blossom[x] = sub
|
||||
self.vertex_top_blossom[sub.base_vertex] = sub
|
||||
|
||||
# Insert blossom in delta2_queue if necessary.
|
||||
prio = sub.vertex_set.min_prio()
|
||||
if prio < math.inf:
|
||||
assert sub.delta2_node is None
|
||||
prio += sub.vertex_dual_offset
|
||||
sub.delta2_node = self.delta2_queue.insert(prio, sub)
|
||||
|
||||
# Delete the expanded blossom.
|
||||
self.nontrivial_blossom.remove(blossom)
|
||||
|
@ -1427,7 +1500,7 @@ class _MatchingContext:
|
|||
|
||||
# Check whether this edge is tight (has zero slack).
|
||||
# Only tight edges may be part of an alternating tree.
|
||||
slack = self.edge_slack_2x(x, y, bx, by, w)
|
||||
slack = self.edge_slack_2x_info(x, y, bx, by, w)
|
||||
if slack <= 0:
|
||||
if ylabel == _LABEL_NONE:
|
||||
# Assign label T to the blossom that contains "y".
|
||||
|
@ -1460,7 +1533,7 @@ class _MatchingContext:
|
|||
# any S-vertex. We do this for T-vertices and unlabeled
|
||||
# vertices. Edges which already have zero slack are still
|
||||
# tracked.
|
||||
self.lset_add_vertex_edge(y, e, slack)
|
||||
self.lset_add_vertex_edge(y, by, e, slack)
|
||||
|
||||
# No further S vertices to scan, and no augmenting path found.
|
||||
return None
|
||||
|
@ -1483,7 +1556,8 @@ class _MatchingContext:
|
|||
weights are integers.
|
||||
|
||||
This function assumes that there is at least one S-vertex.
|
||||
This function takes time O(n).
|
||||
This function takes total time O(m * log(n)) for all calls
|
||||
within a stage.
|
||||
|
||||
Returns:
|
||||
Tuple (delta_type, delta_2x, delta_edge, delta_blossom).
|
||||
|
@ -1499,6 +1573,7 @@ class _MatchingContext:
|
|||
|
||||
# Compute delta2: minimum slack of any edge between an S-vertex and
|
||||
# an unlabeled vertex.
|
||||
# This takes time O(log(n)).
|
||||
(e, slack) = self.lset_get_best_vertex_edge()
|
||||
if (e != -1) and (slack <= delta_2x):
|
||||
delta_type = 2
|
||||
|
@ -1507,6 +1582,9 @@ class _MatchingContext:
|
|||
|
||||
# Compute delta3: half minimum slack of any edge between two top-level
|
||||
# S-blossoms.
|
||||
#
|
||||
# This loop iterates O(m) times per stage.
|
||||
# Each iteration takes time O(log(n)).
|
||||
while not self.delta3_queue.empty():
|
||||
delta3_node = self.delta3_queue.find_min()
|
||||
e = delta3_node.data
|
||||
|
@ -1531,6 +1609,7 @@ class _MatchingContext:
|
|||
self.delta3_set.remove(e)
|
||||
|
||||
# Compute delta4: half minimum dual variable of a top-level T-blossom.
|
||||
# This takes time O(log(n)).
|
||||
if not self.delta4_queue.empty():
|
||||
blossom = self.delta4_queue.find_min().data
|
||||
assert blossom.label == _LABEL_T
|
||||
|
|
Loading…
Reference in New Issue