1 //===- RootOrdering.cpp - Optimal root ordering ---------------------------===// 2 // 3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. 4 // See https://llvm.org/LICENSE.txt for license information. 5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception 6 // 7 //===----------------------------------------------------------------------===// 8 // 9 // An implementation of Edmonds' optimal branching algorithm. This is a 10 // directed analogue of the minimum spanning tree problem for a given root. 11 // 12 //===----------------------------------------------------------------------===// 13 14 #include "RootOrdering.h" 15 16 #include "llvm/ADT/DenseMap.h" 17 #include "llvm/ADT/DenseSet.h" 18 #include "llvm/ADT/SmallVector.h" 19 #include <queue> 20 #include <utility> 21 22 using namespace mlir; 23 using namespace mlir::pdl_to_pdl_interp; 24 25 /// Returns the cycle implied by the specified parent relation, starting at the 26 /// given node. 27 static SmallVector<Value> getCycle(const DenseMap<Value, Value> &parents, 28 Value rep) { 29 SmallVector<Value> cycle; 30 Value node = rep; 31 do { 32 cycle.push_back(node); 33 node = parents.lookup(node); 34 assert(node && "got an empty value in the cycle"); 35 } while (node != rep); 36 return cycle; 37 } 38 39 /// Contracts the specified cycle in the given graph in-place. 40 /// The parentsCost map specifies, for each node in the cycle, the lowest cost 41 /// among the edges entering that node. Then, the nodes in the cycle C are 42 /// replaced with a single node v_C (the first node in the cycle). All edges 43 /// (u, v) entering the cycle, v \in C, are replaced with a single edge 44 /// (u, v_C) with an appropriately chosen cost, and the selected node v is 45 /// marked in the output map actualTarget[u]. All edges (u, v) leaving the 46 /// cycle, u \in C, are replaced with a single edge (v_C, v), and the selected 47 /// node u is marked in the ouptut map actualSource[v]. 48 static void contract(RootOrderingGraph &graph, ArrayRef<Value> cycle, 49 const DenseMap<Value, unsigned> &parentDepths, 50 DenseMap<Value, Value> &actualSource, 51 DenseMap<Value, Value> &actualTarget) { 52 Value rep = cycle.front(); 53 DenseSet<Value> cycleSet(cycle.begin(), cycle.end()); 54 55 // Now, contract the cycle, marking the actual sources and targets. 56 DenseMap<Value, RootOrderingEntry> repEntries; 57 for (auto outer = graph.begin(), e = graph.end(); outer != e; ++outer) { 58 Value target = outer->first; 59 if (cycleSet.contains(target)) { 60 // Target in the cycle => edges incoming to the cycle or within the cycle. 61 unsigned parentDepth = parentDepths.lookup(target); 62 for (const auto &inner : outer->second) { 63 Value source = inner.first; 64 // Ignore edges within the cycle. 65 if (cycleSet.contains(source)) 66 continue; 67 68 // Edge incoming to the cycle. 69 std::pair<unsigned, unsigned> cost = inner.second.cost; 70 assert(parentDepth <= cost.first && "invalid parent depth"); 71 72 // Subtract the cost of the parent within the cycle from the cost of 73 // the edge incoming to the cycle. This update ensures that the cost 74 // of the minimum-weight spanning arborescence of the entire graph is 75 // the cost of arborescence for the contracted graph plus the cost of 76 // the cycle, no matter which edge in the cycle we choose to drop. 77 cost.first -= parentDepth; 78 auto it = repEntries.find(source); 79 if (it == repEntries.end() || it->second.cost > cost) { 80 actualTarget[source] = target; 81 // Do not bother populating the connector (the connector is only 82 // relevant for the final traversal, not for the optimal branching). 83 repEntries[source].cost = cost; 84 } 85 } 86 // Erase the node in the cycle. 87 graph.erase(outer); 88 } else { 89 // Target not in cycle => edges going away from or unrelated to the cycle. 90 DenseMap<Value, RootOrderingEntry> &entries = outer->second; 91 Value bestSource; 92 std::pair<unsigned, unsigned> bestCost; 93 auto inner = entries.begin(), innerE = entries.end(); 94 while (inner != innerE) { 95 Value source = inner->first; 96 if (cycleSet.contains(source)) { 97 // Going-away edge => get its cost and erase it. 98 if (!bestSource || bestCost > inner->second.cost) { 99 bestSource = source; 100 bestCost = inner->second.cost; 101 } 102 entries.erase(inner++); 103 } else { 104 ++inner; 105 } 106 } 107 108 // There were going-away edges, contract them. 109 if (bestSource) { 110 entries[rep].cost = bestCost; 111 actualSource[target] = bestSource; 112 } 113 } 114 } 115 116 // Store the edges to the representative. 117 graph[rep] = std::move(repEntries); 118 } 119 120 OptimalBranching::OptimalBranching(RootOrderingGraph graph, Value root) 121 : graph(std::move(graph)), root(root) {} 122 123 unsigned OptimalBranching::solve() { 124 // Initialize the parents and total cost. 125 parents.clear(); 126 parents[root] = Value(); 127 unsigned totalCost = 0; 128 129 // A map that stores the cost of the optimal local choice for each node 130 // in a directed cycle. This map is cleared every time we seed the search. 131 DenseMap<Value, unsigned> parentDepths; 132 parentDepths.reserve(graph.size()); 133 134 // Determine if the optimal local choice results in an acyclic graph. This is 135 // done by computing the optimal local choice and traversing up the computed 136 // parents. On success, `parents` will contain the parent of each node. 137 for (const auto &outer : graph) { 138 Value node = outer.first; 139 if (parents.count(node)) // already visited 140 continue; 141 142 // Follow the trail of best sources until we reach an already visited node. 143 // The code will assert if we cannot reach an already visited node, i.e., 144 // the graph is not strongly connected. 145 parentDepths.clear(); 146 do { 147 auto it = graph.find(node); 148 assert(it != graph.end() && "the graph is not strongly connected"); 149 150 // Find the best local parent, taking into account both the depth and the 151 // tie breaking rules. 152 Value &bestSource = parents[node]; 153 std::pair<unsigned, unsigned> bestCost; 154 for (const auto &inner : it->second) { 155 const RootOrderingEntry &entry = inner.second; 156 if (!bestSource /* initial */ || bestCost > entry.cost) { 157 bestSource = inner.first; 158 bestCost = entry.cost; 159 } 160 } 161 assert(bestSource && "the graph is not strongly connected"); 162 parentDepths[node] = bestCost.first; 163 node = bestSource; 164 totalCost += bestCost.first; 165 } while (!parents.count(node)); 166 167 // If we reached a non-root node, we have a cycle. 168 if (parentDepths.count(node)) { 169 // Determine the cycle starting at the representative node. 170 SmallVector<Value> cycle = getCycle(parents, node); 171 172 // The following maps disambiguate the source / target of the edges 173 // going out of / into the cycle. 174 DenseMap<Value, Value> actualSource, actualTarget; 175 176 // Contract the cycle and recurse. 177 contract(graph, cycle, parentDepths, actualSource, actualTarget); 178 totalCost = solve(); 179 180 // Redirect the going-away edges. 181 for (auto &p : parents) 182 if (p.second == node) 183 // The parent is the node representating the cycle; replace it 184 // with the actual (best) source in the cycle. 185 p.second = actualSource.lookup(p.first); 186 187 // Redirect the unique incoming edge and copy the cycle. 188 Value parent = parents.lookup(node); 189 Value entry = actualTarget.lookup(parent); 190 cycle.push_back(node); // complete the cycle 191 for (size_t i = 0, e = cycle.size() - 1; i < e; ++i) { 192 totalCost += parentDepths.lookup(cycle[i]); 193 if (cycle[i] == entry) 194 parents[cycle[i]] = parent; // break the cycle 195 else 196 parents[cycle[i]] = cycle[i + 1]; 197 } 198 199 // `parents` has a complete solution. 200 break; 201 } 202 } 203 204 return totalCost; 205 } 206 207 OptimalBranching::EdgeList 208 OptimalBranching::preOrderTraversal(ArrayRef<Value> nodes) const { 209 // Invert the parent mapping. 210 DenseMap<Value, std::vector<Value>> children; 211 for (Value node : nodes) { 212 if (node != root) { 213 Value parent = parents.lookup(node); 214 assert(parent && "invalid parent"); 215 children[parent].push_back(node); 216 } 217 } 218 219 // The result which simultaneously acts as a queue. 220 EdgeList result; 221 result.reserve(nodes.size()); 222 result.emplace_back(root, Value()); 223 224 // Perform a BFS, pushing into the queue. 225 for (size_t i = 0; i < result.size(); ++i) { 226 Value node = result[i].first; 227 for (Value child : children[node]) 228 result.emplace_back(child, node); 229 } 230 231 return result; 232 } 233