1
0
Fork 0

Lazy delta updates of T-blossom duals

This commit is contained in:
Joris van Rantwijk 2024-05-26 11:32:02 +02:00
parent b2e055b357
commit 6318de3b1f
1 changed files with 25 additions and 20 deletions

View File

@ -448,9 +448,15 @@ class _NonTrivialBlossom(_Blossom):
# "dual_var" is the value of the dual variable minus 2 times # "dual_var" is the value of the dual variable minus 2 times
# the running sum of delta steps. # the running sum of delta steps.
# #
# If this is a top-level T-blossom,
# "dual_var" is the value of the dual variable plus 2 times
# the running sum of delta steps.
#
# In all other cases, # In all other cases,
# "dual_var" is the current value of the dual variable. # "dual_var" is the current value of the dual variable.
# #
# Note that "dual_var" is invariant under delta steps.
#
# New blossoms start with dual variable 0. # New blossoms start with dual variable 0.
self.dual_var: float = 0 self.dual_var: float = 0
@ -688,6 +694,7 @@ class _MatchingContext:
assert blossom.label == _LABEL_NONE assert blossom.label == _LABEL_NONE
blossom.label = _LABEL_S blossom.label = _LABEL_S
# Prepare for lazy updating of S-blossom dual variable.
if isinstance(blossom, _NonTrivialBlossom): if isinstance(blossom, _NonTrivialBlossom):
blossom.dual_var -= self.delta_sum_2x blossom.dual_var -= self.delta_sum_2x
@ -697,13 +704,14 @@ class _MatchingContext:
assert blossom.label == _LABEL_S assert blossom.label == _LABEL_S
blossom.label = _LABEL_NONE blossom.label = _LABEL_NONE
# Catch up with lazy updates to S-blossom dual variable.
if isinstance(blossom, _NonTrivialBlossom): if isinstance(blossom, _NonTrivialBlossom):
blossom.dual_var += self.delta_sum_2x blossom.dual_var += self.delta_sum_2x
def assign_vertex_label_s(self, blossom: _Blossom) -> None: def assign_vertex_label_s(self, blossom: _Blossom) -> None:
"""Adjust after assigning label S to previously unlabeled vertices.""" """Adjust after assigning label S to previously unlabeled vertices."""
# Adjust the vertex dual variables of the new S-vertices. # Prepare for lazy updating of S-vertex dual variables.
vertices = blossom.vertices() vertices = blossom.vertices()
for x in vertices: for x in vertices:
self.vertex_dual_2x[x] += self.delta_sum_2x self.vertex_dual_2x[x] += self.delta_sum_2x
@ -714,7 +722,7 @@ class _MatchingContext:
def remove_vertex_label_s(self, blossom: _Blossom) -> None: def remove_vertex_label_s(self, blossom: _Blossom) -> None:
"""Adjust after removing labels from S-vertices.""" """Adjust after removing labels from S-vertices."""
# Adjust the vertex dual variables of the former S-vertices. # Catch up with lazy updates to S-vertex dual variables.
for x in blossom.vertices(): for x in blossom.vertices():
self.vertex_dual_2x[x] -= self.delta_sum_2x self.vertex_dual_2x[x] -= self.delta_sum_2x
@ -723,31 +731,35 @@ class _MatchingContext:
assert blossom.parent is None assert blossom.parent is None
assert blossom.label == _LABEL_NONE assert blossom.label == _LABEL_NONE
# Assign label T.
blossom.label = _LABEL_T blossom.label = _LABEL_T
# Insert blossom into the delta4 queue.
if isinstance(blossom, _NonTrivialBlossom): if isinstance(blossom, _NonTrivialBlossom):
# Prepare for lazy updates to T-blossom dual variables.
blossom.dual_var += self.delta_sum_2x
# Insert blossom into the delta4 queue.
assert blossom.delta4_node is None assert blossom.delta4_node is None
prio = blossom.dual_var + self.delta_sum_2x blossom.delta4_node = self.delta4_queue.insert(blossom.dual_var,
blossom.delta4_node = self.delta4_queue.insert(prio, blossom) blossom)
def remove_blossom_label_t(self, blossom: _Blossom) -> None: def remove_blossom_label_t(self, blossom: _Blossom) -> None:
"""Remove label T from a top-level T-blossom.""" """Remove label T from a top-level T-blossom."""
assert blossom.parent is None assert blossom.parent is None
assert blossom.label == _LABEL_T assert blossom.label == _LABEL_T
# Remove label.
blossom.label = _LABEL_NONE blossom.label = _LABEL_NONE
# Remove blossom from delta4 queue.
if isinstance(blossom, _NonTrivialBlossom): if isinstance(blossom, _NonTrivialBlossom):
# Remove blossom from delta4 queue.
assert blossom.delta4_node is not None assert blossom.delta4_node is not None
self.delta4_queue.delete(blossom.delta4_node) self.delta4_queue.delete(blossom.delta4_node)
blossom.delta4_node = None blossom.delta4_node = None
# Catch up with lazy updates to T-blossom dual variable.
blossom.dual_var -= self.delta_sum_2x
def reset_stage(self) -> None: def reset_stage(self) -> None:
"""Reset data which are only valid during a stage. """Reset data which are only valid during a stage.
@ -1428,9 +1440,10 @@ class _MatchingContext:
blossom = self.delta4_queue.find_min().data blossom = self.delta4_queue.find_min().data
assert blossom.label == _LABEL_T assert blossom.label == _LABEL_T
assert blossom.parent is None assert blossom.parent is None
if blossom.dual_var <= delta_2x: blossom_dual = blossom.dual_var - self.delta_sum_2x
if blossom_dual <= delta_2x:
delta_type = 4 delta_type = 4
delta_2x = blossom.dual_var delta_2x = blossom_dual
delta_blossom = blossom delta_blossom = blossom
return (delta_type, delta_2x, delta_edge, delta_blossom) return (delta_type, delta_2x, delta_edge, delta_blossom)
@ -1449,14 +1462,6 @@ class _MatchingContext:
# T-vertex: add delta to dual variable. # T-vertex: add delta to dual variable.
self.vertex_dual_2x[x] += delta_2x self.vertex_dual_2x[x] += delta_2x
# Apply delta to dual variables of top-level non-trivial blossoms.
for blossom in self.nontrivial_blossom:
if blossom.parent is None:
blabel = blossom.label
if blabel == _LABEL_T:
# T-blossom: subtract 2*delta from dual variable.
blossom.dual_var -= delta_2x
# #
# Main stage function: # Main stage function:
# #