diff --git a/python/datastruct.py b/python/datastruct.py index 9c71d72..f543988 100644 --- a/python/datastruct.py +++ b/python/datastruct.py @@ -112,6 +112,8 @@ class UnionFindQueue(Generic[_NameT, _ElemT]): self.split_nodes.clear() # Wipe pointers to enable refcounted garbage collection. + if node is not None: + node.owner = None while node is not None: prev_node = node if node.left is not None: diff --git a/python/mwmatching.py b/python/mwmatching.py index 73d5077..569c398 100644 --- a/python/mwmatching.py +++ b/python/mwmatching.py @@ -383,6 +383,19 @@ class _Blossom: # "tree_edge = None" if the blossom is the root of an alternating tree. self.tree_edge: Optional[tuple[int, int]] = None + # Each top-level blossom maintains a union-find datastructure + # containing all vertices in the blossom. + self.vertex_set: "UnionFindQueue[_Blossom, int]" + self.vertex_set = UnionFindQueue(self) + + # If this is a top-level unlabeled blossom with an edge to an + # S-blossom, "delta2_node" is the corresponding node in the delta2 + # queue. + self.delta2_node: Optional[PriorityQueue.Node] = None + + # Support variable for lazy updating of vertex dual variables. + self.vertex_dual_offset: float = 0 + # "marker" is a temporary variable used to discover common # ancestors in the blossom tree. It is normally False, except # when used by "trace_alternating_paths()". @@ -460,9 +473,6 @@ class _NonTrivialBlossom(_Blossom): # Note that "dual_var" is invariant under delta steps. self.dual_var: float = 0 - # Support variable for lazy updating of vertex dual variables. - self.vertex_dual_offset: float = 0 - # If this is a top-level T-blossom, # "delta4_node" is the corresponding node in the delta4 queue. # Otherwise "delta4_node" is None. @@ -543,6 +553,11 @@ class _MatchingContext: # Initially all vertices are trivial top-level blossoms. self.vertex_top_blossom: list[_Blossom] = self.trivial_blossom.copy() + # "vertex_set_node[x]" represents the vertex "x" inside the + # union-find datastructure of its top-level blossom. + self.vertex_set_node = [b.vertex_set.insert(i, math.inf) + for (i, b) in enumerate(self.trivial_blossom)] + # All vertex duals are initialized to half the maximum edge weight. # # "start_vertex_dual_2x" is 2 times the initial vertex dual value. @@ -572,14 +587,19 @@ class _MatchingContext: # Running sum of applied delta steps times 2. self.delta_sum_2x: float = 0 + # Queue containing unlabeled top-level blossoms that have an edge to + # an S-blossom. The priority of a blossom is 2 times the least slack + # to an S blossom, plus 2 times the running sum of delta steps. + self.delta2_queue: PriorityQueue[_Blossom] = PriorityQueue() + # Queue containing edges between S-vertices in different top-level - # blossoms. The priority of an edge is its slack plus two times the + # blossoms. The priority of an edge is its slack plus 2 times the # running sum of delta steps. self.delta3_queue: PriorityQueue[int] = PriorityQueue() self.delta3_set: set[int] = set() # Queue containing top-level non-trivial T-blossoms. - # The priority of a blossom is its dual plus two times the running + # The priority of a blossom is its dual plus 2 times the running # sum of delta steps. self.delta4_queue: PriorityQueue[_NonTrivialBlossom] = PriorityQueue() @@ -591,7 +611,18 @@ class _MatchingContext: # Queue of S-vertices to be scanned. self.scan_queue: collections.deque[int] = collections.deque() - def edge_slack_2x( + def __del__(self) -> None: + """Delete reference cycles during cleanup of the matching context.""" + + for blossom in self.trivial_blossom: + blossom.vertex_set.clear() + del blossom.vertex_set + + for blossom in self.nontrivial_blossom: + blossom.vertex_set.clear() + del blossom.vertex_set + + def edge_slack_2x_info( self, x: int, y: int, @@ -619,12 +650,17 @@ class _MatchingContext: dual_2x += self.delta_sum_2x if by.label == _LABEL_T: dual_2x += self.delta_sum_2x - if isinstance(bx, _NonTrivialBlossom): - dual_2x += bx.vertex_dual_offset - if isinstance(by, _NonTrivialBlossom): - dual_2x += by.vertex_dual_offset + dual_2x += bx.vertex_dual_offset + dual_2x += by.vertex_dual_offset return dual_2x - 2 * w + def edge_slack_2x(self, e: int) -> float: + """Return 2 times the slack of the edge with index "e".""" + (x, y, w) = self.graph.edges[e] + bx = self.vertex_top_blossom[x] + by = self.vertex_top_blossom[y] + return self.edge_slack_2x_info(x, y, bx, by, w) + # # Least-slack edge tracking: # @@ -649,57 +685,80 @@ class _MatchingContext: def lset_reset(self) -> None: """Reset least-slack edge tracking. - This function takes time O(n). + This function takes time O(n * log(n)). """ num_vertex = self.graph.num_vertex for x in range(num_vertex): self.vertex_best_edge[x] = -1 + self.vertex_set_node[x].set_prio(math.inf) - def lset_add_vertex_edge(self, y: int, e: int, slack: float) -> None: + self.delta2_queue.clear() + + for blossom in self.trivial_blossom + self.nontrivial_blossom: + blossom.delta2_node = None + + def lset_add_vertex_edge( + self, + y: int, + by: _Blossom, + e: int, + slack: float + ) -> None: """Add edge "e" from an S-vertex to unlabeled vertex or T-vertex "y". - This function takes time O(1) per call. - This function is called O(m) times per stage. + This function takes time O(log(n)). """ + + # TODO -- Simplify: We don't need to know the true slack of the edge, + # only the pseudo-slack based on raw vertex duals and weight. + best_edge = self.vertex_best_edge[y] - if best_edge == -1: - self.vertex_best_edge[y] = e - else: - (xx, yy, w) = self.graph.edges[best_edge] - bx = self.vertex_top_blossom[xx] - by = self.vertex_top_blossom[yy] - best_slack = self.edge_slack_2x(xx, yy, bx, by, w) - if slack < best_slack: - self.vertex_best_edge[y] = e + if best_edge != -1: + best_slack = self.edge_slack_2x(best_edge) + if slack >= best_slack: + return + + self.vertex_best_edge[y] = e + + (p, q, w) = self.graph.edges[e] + prio = self.vertex_dual_2x[p] + self.vertex_dual_2x[q] - 2 * w + prev_min = by.vertex_set.min_prio() + self.vertex_set_node[y].set_prio(prio) + + if (by.label == _LABEL_NONE) and (prio < prev_min): + prio = slack + self.delta_sum_2x + if by.delta2_node is None: + by.delta2_node = self.delta2_queue.insert(prio, by) + elif prio < by.delta2_node.prio: + self.delta2_queue.decrease_prio(by.delta2_node, prio) def lset_get_best_vertex_edge(self) -> tuple[int, float]: """Return the index and slack of the least-slack edge between any S-vertex and unlabeled vertex. - This function takes time O(n) per call. - This function takes total time O(n**2) per stage. + This function takes time O(log(n)). Returns: Tuple (edge_index, slack_2x) if there is a least-slack edge, or (-1, 0) if there is no suitable edge. """ - best_index = -1 - best_slack: float = 0 - for x in range(self.graph.num_vertex): - if self.vertex_top_blossom[x].label == _LABEL_NONE: - e = self.vertex_best_edge[x] - if e != -1: - (x, y, w) = self.graph.edges[e] - bx = self.vertex_top_blossom[x] - by = self.vertex_top_blossom[y] - slack = self.edge_slack_2x(x, y, bx, by, w) - if (best_index == -1) or (slack < best_slack): - best_index = e - best_slack = slack + if self.delta2_queue.empty(): + return (-1, 0) - return (best_index, best_slack) + delta2_node = self.delta2_queue.find_min() + blossom = delta2_node.data + prio = delta2_node.prio + slack_2x = prio - self.delta_sum_2x + assert blossom.parent is None + assert blossom.label == _LABEL_NONE + + x = blossom.vertex_set.min_elem() + e = self.vertex_best_edge[x] + assert e >= 0 + + return (e, slack_2x) # # General support routines: @@ -711,6 +770,11 @@ class _MatchingContext: assert blossom.label == _LABEL_NONE blossom.label = _LABEL_S + # Delete blossom from delta2 queue. + if blossom.delta2_node is not None: + self.delta2_queue.delete(blossom.delta2_node) + blossom.delta2_node = None + # Prepare for lazy updating of S-blossom dual variable. if isinstance(blossom, _NonTrivialBlossom): blossom.dual_var -= self.delta_sum_2x @@ -733,11 +797,8 @@ class _MatchingContext: self.scan_queue.extend(vertices) # Prepare for lazy updating of S-vertex dual variables. - vertex_dual_fixup = self.delta_sum_2x - if isinstance(blossom, _NonTrivialBlossom): - vertex_dual_fixup += blossom.vertex_dual_offset - blossom.vertex_dual_offset = 0 - + vertex_dual_fixup = self.delta_sum_2x + blossom.vertex_dual_offset + blossom.vertex_dual_offset = 0 for x in vertices: self.vertex_dual_2x[x] += vertex_dual_fixup @@ -748,23 +809,23 @@ class _MatchingContext: assert blossom.label == _LABEL_NONE blossom.label = _LABEL_T + # Delete blossom from delta2 queue. + if blossom.delta2_node is not None: + self.delta2_queue.delete(blossom.delta2_node) + blossom.delta2_node = None + if isinstance(blossom, _NonTrivialBlossom): # Prepare for lazy updating of T-blossom dual variables. blossom.dual_var += self.delta_sum_2x - # Prepare for lazy updating of T-vertex dual variables. - blossom.vertex_dual_offset -= self.delta_sum_2x - # Insert blossom into the delta4 queue. assert blossom.delta4_node is None blossom.delta4_node = self.delta4_queue.insert(blossom.dual_var, blossom) - else: - # Prepare for lazy updating of T-vertex dual variables. - self.vertex_dual_2x[blossom.base_vertex] -= self.delta_sum_2x - + # Prepare for lazy updating of T-vertex dual variables. + blossom.vertex_dual_offset -= self.delta_sum_2x def remove_blossom_label_t(self, blossom: _Blossom) -> None: """Remove label T from a top-level T-blossom.""" @@ -783,12 +844,8 @@ class _MatchingContext: # Unwind lazy updates to T-blossom dual variable. blossom.dual_var -= self.delta_sum_2x - # Unwind lazy updates of T-vertex dual variables. - blossom.vertex_dual_offset += self.delta_sum_2x - - else: - # Unwind lazy updates of T-vertex dual variables. - self.vertex_dual_2x[blossom.base_vertex] += self.delta_sum_2x + # Unwind lazy updates of T-vertex dual variables. + blossom.vertex_dual_offset += self.delta_sum_2x def reset_blossom_label(self, blossom: _Blossom) -> None: """Remove blossom label and calculate true dual variables.""" @@ -803,9 +860,9 @@ class _MatchingContext: # Unwind lazy delta updates to S-blossom dual variable. if isinstance(blossom, _NonTrivialBlossom): blossom.dual_var += self.delta_sum_2x - assert blossom.vertex_dual_offset == 0 # Unwind lazy delta updates to S-vertex dual variables. + assert blossom.vertex_dual_offset == 0 vertex_dual_fixup = -self.delta_sum_2x for x in blossom.vertices(): self.vertex_dual_2x[x] += vertex_dual_fixup @@ -815,30 +872,23 @@ class _MatchingContext: # Remove label. blossom.label = _LABEL_NONE - # Prepare to unwind lazy delta updates to vertices. - vertex_dual_fixup = self.delta_sum_2x - + # Unwind lazy delta updates to T-blossom dual variable. if isinstance(blossom, _NonTrivialBlossom): - # Unwind lazy delta updates to T-blossom dual variable. blossom.dual_var -= self.delta_sum_2x - # Prepare to unwind lazy delta updates to vertices. - vertex_dual_fixup += blossom.vertex_dual_offset - blossom.vertex_dual_offset = 0 - # Unwind lazy delta updates to T-vertex dual variables. + vertex_dual_fixup = self.delta_sum_2x + blossom.vertex_dual_offset + blossom.vertex_dual_offset = 0 for x in blossom.vertices(): self.vertex_dual_2x[x] += vertex_dual_fixup else: # Unwind lazy delta updates to vertex dual variables. - if isinstance(blossom, _NonTrivialBlossom): - vertex_dual_fixup = blossom.vertex_dual_offset - blossom.vertex_dual_offset = 0 - - for x in blossom.vertices(): - self.vertex_dual_2x[x] += vertex_dual_fixup + vertex_dual_fixup = blossom.vertex_dual_offset + blossom.vertex_dual_offset = 0 + for x in blossom.vertices(): + self.vertex_dual_2x[x] += vertex_dual_fixup def reset_stage(self) -> None: """Reset data which are only valid during a stage. @@ -846,7 +896,7 @@ class _MatchingContext: Marks all blossoms as unlabeled, clears the queue, and resets tracking of least-slack edges. - This function takes time O(n). + This function takes time O(n * log(n)). """ # Remove blossom labels and unwind lazy dual updates. @@ -1007,6 +1057,9 @@ class _MatchingContext: for x in blossom.vertices(): self.vertex_top_blossom[x] = blossom + # Merge union-find structures. + blossom.vertex_set.merge([sub.vertex_set for sub in subblossoms]) + @staticmethod def find_path_through_blossom( blossom: _NonTrivialBlossom, @@ -1048,6 +1101,9 @@ class _MatchingContext: self.delta4_queue.delete(blossom.delta4_node) blossom.delta4_node = None + # Split union-find structure. + blossom.vertex_set.split() + # Prepare to push lazy delta updates down to the sub-blossoms. vertex_dual_fixup = self.delta_sum_2x + blossom.vertex_dual_offset blossom.vertex_dual_offset = 0 @@ -1056,14 +1112,22 @@ class _MatchingContext: for sub in blossom.subblossoms: assert sub.label == _LABEL_NONE sub.parent = None + + assert sub.vertex_dual_offset == 0 + sub.vertex_dual_offset = vertex_dual_fixup + if isinstance(sub, _NonTrivialBlossom): - sub.vertex_dual_offset = vertex_dual_fixup for x in sub.vertices(): self.vertex_top_blossom[x] = sub else: - x = sub.base_vertex - self.vertex_dual_2x[x] += vertex_dual_fixup - self.vertex_top_blossom[x] = sub + self.vertex_top_blossom[sub.base_vertex] = sub + + # Insert blossom in delta2_queue if necessary. + prio = sub.vertex_set.min_prio() + if prio < math.inf: + assert sub.delta2_node is None + prio += sub.vertex_dual_offset + sub.delta2_node = self.delta2_queue.insert(prio, sub) # The expanding blossom was part of an alternating tree, linked to # a parent node in the tree via one of its subblossoms, and linked to @@ -1117,6 +1181,9 @@ class _MatchingContext: assert blossom.parent is None assert blossom.label == _LABEL_NONE + # Split union-find structure. + blossom.vertex_set.split() + # Prepare to push lazy delta updates down to the sub-blossoms. vertex_dual_offset = blossom.vertex_dual_offset blossom.vertex_dual_offset = 0 @@ -1126,15 +1193,21 @@ class _MatchingContext: assert sub.label == _LABEL_NONE sub.parent = None + assert sub.vertex_dual_offset == 0 + sub.vertex_dual_offset = vertex_dual_offset + if isinstance(sub, _NonTrivialBlossom): - assert sub.vertex_dual_offset == 0 - sub.vertex_dual_offset = vertex_dual_offset for x in sub.vertices(): self.vertex_top_blossom[x] = sub else: - x = sub.base_vertex - self.vertex_dual_2x[x] += vertex_dual_offset - self.vertex_top_blossom[x] = sub + self.vertex_top_blossom[sub.base_vertex] = sub + + # Insert blossom in delta2_queue if necessary. + prio = sub.vertex_set.min_prio() + if prio < math.inf: + assert sub.delta2_node is None + prio += sub.vertex_dual_offset + sub.delta2_node = self.delta2_queue.insert(prio, sub) # Delete the expanded blossom. self.nontrivial_blossom.remove(blossom) @@ -1427,7 +1500,7 @@ class _MatchingContext: # Check whether this edge is tight (has zero slack). # Only tight edges may be part of an alternating tree. - slack = self.edge_slack_2x(x, y, bx, by, w) + slack = self.edge_slack_2x_info(x, y, bx, by, w) if slack <= 0: if ylabel == _LABEL_NONE: # Assign label T to the blossom that contains "y". @@ -1460,7 +1533,7 @@ class _MatchingContext: # any S-vertex. We do this for T-vertices and unlabeled # vertices. Edges which already have zero slack are still # tracked. - self.lset_add_vertex_edge(y, e, slack) + self.lset_add_vertex_edge(y, by, e, slack) # No further S vertices to scan, and no augmenting path found. return None @@ -1483,7 +1556,8 @@ class _MatchingContext: weights are integers. This function assumes that there is at least one S-vertex. - This function takes time O(n). + This function takes total time O(m * log(n)) for all calls + within a stage. Returns: Tuple (delta_type, delta_2x, delta_edge, delta_blossom). @@ -1499,6 +1573,7 @@ class _MatchingContext: # Compute delta2: minimum slack of any edge between an S-vertex and # an unlabeled vertex. + # This takes time O(log(n)). (e, slack) = self.lset_get_best_vertex_edge() if (e != -1) and (slack <= delta_2x): delta_type = 2 @@ -1507,6 +1582,9 @@ class _MatchingContext: # Compute delta3: half minimum slack of any edge between two top-level # S-blossoms. + # + # This loop iterates O(m) times per stage. + # Each iteration takes time O(log(n)). while not self.delta3_queue.empty(): delta3_node = self.delta3_queue.find_min() e = delta3_node.data @@ -1531,6 +1609,7 @@ class _MatchingContext: self.delta3_set.remove(e) # Compute delta4: half minimum dual variable of a top-level T-blossom. + # This takes time O(log(n)). if not self.delta4_queue.empty(): blossom = self.delta4_queue.find_min().data assert blossom.label == _LABEL_T