//===- LoopFusion.cpp - Code to perform loop fusion -----------------------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
//
// This file implements loop fusion.
//
//===----------------------------------------------------------------------===//

#include "PassDetail.h"
#include "mlir/Dialect/Affine/Analysis/AffineAnalysis.h"
#include "mlir/Dialect/Affine/Analysis/AffineStructures.h"
#include "mlir/Dialect/Affine/Analysis/LoopAnalysis.h"
#include "mlir/Dialect/Affine/Analysis/Utils.h"
#include "mlir/Dialect/Affine/IR/AffineOps.h"
#include "mlir/Dialect/Affine/LoopFusionUtils.h"
#include "mlir/Dialect/Affine/LoopUtils.h"
#include "mlir/Dialect/Affine/Utils.h"
#include "mlir/Dialect/MemRef/IR/MemRef.h"
#include "mlir/IR/AffineExpr.h"
#include "mlir/IR/AffineMap.h"
#include "mlir/IR/Builders.h"
#include "mlir/Transforms/Passes.h"
#include "llvm/ADT/DenseMap.h"
#include "llvm/ADT/DenseSet.h"
#include "llvm/ADT/SetVector.h"
#include "llvm/Support/CommandLine.h"
#include "llvm/Support/Debug.h"
#include "llvm/Support/raw_ostream.h"
#include <iomanip>
#include <sstream>
#define DEBUG_TYPE "affine-loop-fusion"

using namespace mlir;

namespace {
/// Loop fusion pass. This pass currently supports a greedy fusion policy,
/// which fuses loop nests with single-writer/single-reader memref dependences
/// with the goal of improving locality.

// TODO: Support fusion of source loop nests which write to multiple
// memrefs, where each memref can have multiple users (if profitable).
// TODO: Extend this pass to check for fusion preventing dependences,
// and add support for more general loop fusion algorithms.

struct LoopFusion : public AffineLoopFusionBase<LoopFusion> {
  LoopFusion() = default;
  LoopFusion(unsigned fastMemorySpace, uint64_t localBufSizeThresholdBytes,
             bool maximalFusion, enum FusionMode affineFusionMode) {
    this->fastMemorySpace = fastMemorySpace;
    this->localBufSizeThreshold = localBufSizeThresholdBytes / 1024;
    this->maximalFusion = maximalFusion;
    this->affineFusionMode = affineFusionMode;
  }

  void runOnOperation() override;
};

} // namespace

std::unique_ptr<OperationPass<FuncOp>>
mlir::createLoopFusionPass(unsigned fastMemorySpace,
                           uint64_t localBufSizeThreshold, bool maximalFusion,
                           enum FusionMode affineFusionMode) {
  return std::make_unique<LoopFusion>(fastMemorySpace, localBufSizeThreshold,
                                      maximalFusion, affineFusionMode);
}
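
// As a rough illustration of the single-writer/single-reader pattern the
// greedy policy above targets (hypothetical IR, not taken from the test
// suite): given a producer nest storing to %m and a consumer nest loading
// from it,
//
//   affine.for %i = 0 to 10 {
//     affine.store %v, %m[%i] : memref<10xf32>
//   }
//   affine.for %i = 0 to 10 {
//     %x = affine.load %m[%i] : memref<10xf32>
//   }
//
// fusion slices the producer into the consumer so the value is produced and
// consumed in the same iteration:
//
//   affine.for %i = 0 to 10 {
//     affine.store %v, %m[%i] : memref<10xf32>
//     %x = affine.load %m[%i] : memref<10xf32>
//   }
//
// If %m has no other users, the store can then target a small private buffer
// (see createPrivateMemRef below).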
namespace {

// LoopNestStateCollector walks loop nests and collects load and store
// operations, and whether or not a region holding op other than ForOp and
// IfOp was encountered in the loop nest.
struct LoopNestStateCollector {
  SmallVector<AffineForOp, 4> forOps;
  SmallVector<Operation *, 4> loadOpInsts;
  SmallVector<Operation *, 4> storeOpInsts;
  bool hasNonAffineRegionOp = false;

  void collect(Operation *opToWalk) {
    opToWalk->walk([&](Operation *op) {
      if (isa<AffineForOp>(op))
        forOps.push_back(cast<AffineForOp>(op));
      else if (op->getNumRegions() != 0 && !isa<AffineIfOp>(op))
        hasNonAffineRegionOp = true;
      else if (isa<AffineReadOpInterface>(op))
        loadOpInsts.push_back(op);
      else if (isa<AffineWriteOpInterface>(op))
        storeOpInsts.push_back(op);
    });
  }
};

// MemRefDependenceGraph is a graph data structure where graph nodes are
// top-level operations in a FuncOp which contain load/store ops, and edges
// are memref dependences between the nodes.
// TODO: Add a more flexible dependence graph representation.
// TODO: Add a depth parameter to dependence graph construction.
struct MemRefDependenceGraph {
public:
  // Node represents a node in the graph. A Node is either an entire loop nest
  // rooted at the top level which contains loads/stores, or a top level
  // load/store.
  struct Node {
    // The unique identifier of this node in the graph.
    unsigned id;
    // The top-level statement which is (or contains) a load/store.
    Operation *op;
    // List of load operations.
    SmallVector<Operation *, 4> loads;
    // List of store operations.
    SmallVector<Operation *, 4> stores;
    Node(unsigned id, Operation *op) : id(id), op(op) {}

    // Returns the load op count for 'memref'.
    unsigned getLoadOpCount(Value memref) {
      unsigned loadOpCount = 0;
      for (auto *loadOpInst : loads) {
        if (memref == cast<AffineReadOpInterface>(loadOpInst).getMemRef())
          ++loadOpCount;
      }
      return loadOpCount;
    }

    // Returns the store op count for 'memref'.
    unsigned getStoreOpCount(Value memref) {
      unsigned storeOpCount = 0;
      for (auto *storeOpInst : stores) {
        if (memref == cast<AffineWriteOpInterface>(storeOpInst).getMemRef())
          ++storeOpCount;
      }
      return storeOpCount;
    }

    // Collects in 'storeOps' all store ops of this node which access 'memref'.
    void getStoreOpsForMemref(Value memref,
                              SmallVectorImpl<Operation *> *storeOps) {
      for (auto *storeOpInst : stores) {
        if (memref == cast<AffineWriteOpInterface>(storeOpInst).getMemRef())
          storeOps->push_back(storeOpInst);
      }
    }

    // Collects in 'loadOps' all load ops of this node which access 'memref'.
    void getLoadOpsForMemref(Value memref,
                             SmallVectorImpl<Operation *> *loadOps) {
      for (auto *loadOpInst : loads) {
        if (memref == cast<AffineReadOpInterface>(loadOpInst).getMemRef())
          loadOps->push_back(loadOpInst);
      }
    }

    // Collects in 'loadAndStoreMemrefSet' all memrefs for which this node
    // has at least one load and one store operation.
    void getLoadAndStoreMemrefSet(DenseSet<Value> *loadAndStoreMemrefSet) {
      llvm::SmallDenseSet<Value, 2> loadMemrefs;
      for (auto *loadOpInst : loads) {
        loadMemrefs.insert(cast<AffineReadOpInterface>(loadOpInst).getMemRef());
      }
      for (auto *storeOpInst : stores) {
        auto memref = cast<AffineWriteOpInterface>(storeOpInst).getMemRef();
        if (loadMemrefs.count(memref) > 0)
          loadAndStoreMemrefSet->insert(memref);
      }
    }
  };

  // Edge represents a data dependence between nodes in the graph.
  struct Edge {
    // The id of the node at the other end of the edge.
    // If this edge is stored in Edge = Node.inEdges[i], then
    // 'Node.inEdges[i].id' is the identifier of the source node of the edge.
    // If this edge is stored in Edge = Node.outEdges[i], then
    // 'Node.outEdges[i].id' is the identifier of the dest node of the edge.
    unsigned id;
    // The SSA value on which this edge represents a dependence.
    // If the value is a memref, then the dependence is between graph nodes
    // which contain accesses to the same memref 'value'. If the value is a
    // non-memref value, then the dependence is between a graph node which
    // defines an SSA value and another graph node which uses the SSA value
    // (e.g. a constant or load operation defining a value which is used inside
    // a loop nest).
    Value value;
  };

  // Map from node id to Node.
  DenseMap<unsigned, Node> nodes;
  // Map from node id to list of input edges.
  DenseMap<unsigned, SmallVector<Edge, 2>> inEdges;
  // Map from node id to list of output edges.
  DenseMap<unsigned, SmallVector<Edge, 2>> outEdges;
  // Map from memref to a count on the dependence edges associated with that
  // memref.
  DenseMap<Value, unsigned> memrefEdgeCount;
  // The next unique identifier to use for newly created graph nodes.
  unsigned nextNodeId = 0;

  MemRefDependenceGraph() = default;

  // Initializes the dependence graph based on operations in 'f'.
  // Returns true on success, false otherwise.
  bool init(FuncOp f);

  // Returns the graph node for 'id'.
  Node *getNode(unsigned id) {
    auto it = nodes.find(id);
    assert(it != nodes.end());
    return &it->second;
  }

  // Returns the graph node for 'forOp'.
  Node *getForOpNode(AffineForOp forOp) {
    for (auto &idAndNode : nodes)
      if (idAndNode.second.op == forOp.getOperation())
        return &idAndNode.second;
    return nullptr;
  }

  // Adds a node with 'op' to the graph and returns its unique identifier.
  unsigned addNode(Operation *op) {
    Node node(nextNodeId++, op);
    nodes.insert({node.id, node});
    return node.id;
  }

  // Remove node 'id' (and its associated edges) from graph.
  void removeNode(unsigned id) {
    // Remove each edge in 'inEdges[id]'.
    if (inEdges.count(id) > 0) {
      SmallVector<Edge, 2> oldInEdges = inEdges[id];
      for (auto &inEdge : oldInEdges) {
        removeEdge(inEdge.id, id, inEdge.value);
      }
    }
    // Remove each edge in 'outEdges[id]'.
    if (outEdges.count(id) > 0) {
      SmallVector<Edge, 2> oldOutEdges = outEdges[id];
      for (auto &outEdge : oldOutEdges) {
        removeEdge(id, outEdge.id, outEdge.value);
      }
    }
    // Erase remaining node state.
    inEdges.erase(id);
    outEdges.erase(id);
    nodes.erase(id);
  }

  // Returns true if node 'id' writes to any memref which escapes (or is an
  // argument to) the function/block. Returns false otherwise.
  bool writesToLiveInOrEscapingMemrefs(unsigned id) {
    Node *node = getNode(id);
    for (auto *storeOpInst : node->stores) {
      auto memref = cast<AffineWriteOpInterface>(storeOpInst).getMemRef();
      auto *op = memref.getDefiningOp();
      // Return true if 'memref' is a block argument.
      if (!op)
        return true;
      // Return true if any use of 'memref' escapes the function.
      for (auto *user : memref.getUsers())
        if (!isa<AffineMapAccessInterface>(*user))
          return true;
    }
    return false;
  }

  // Returns true iff there is an edge from node 'srcId' to node 'dstId' which
  // is for 'value' if non-null, or for any value otherwise.
  bool hasEdge(unsigned srcId, unsigned dstId, Value value = nullptr) {
    if (outEdges.count(srcId) == 0 || inEdges.count(dstId) == 0) {
      return false;
    }
    bool hasOutEdge = llvm::any_of(outEdges[srcId], [=](Edge &edge) {
      return edge.id == dstId && (!value || edge.value == value);
    });
    bool hasInEdge = llvm::any_of(inEdges[dstId], [=](Edge &edge) {
      return edge.id == srcId && (!value || edge.value == value);
    });
    return hasOutEdge && hasInEdge;
  }

  // Adds an edge from node 'srcId' to node 'dstId' for 'value'.
  void addEdge(unsigned srcId, unsigned dstId, Value value) {
    if (!hasEdge(srcId, dstId, value)) {
      outEdges[srcId].push_back({dstId, value});
      inEdges[dstId].push_back({srcId, value});
      if (value.getType().isa<MemRefType>())
        memrefEdgeCount[value]++;
    }
  }

  // Removes an edge from node 'srcId' to node 'dstId' for 'value'.
  void removeEdge(unsigned srcId, unsigned dstId, Value value) {
    assert(inEdges.count(dstId) > 0);
    assert(outEdges.count(srcId) > 0);
    if (value.getType().isa<MemRefType>()) {
      assert(memrefEdgeCount.count(value) > 0);
      memrefEdgeCount[value]--;
    }
    // Remove 'srcId' from 'inEdges[dstId]'.
    for (auto *it = inEdges[dstId].begin(); it != inEdges[dstId].end(); ++it) {
      if ((*it).id == srcId && (*it).value == value) {
        inEdges[dstId].erase(it);
        break;
      }
    }
    // Remove 'dstId' from 'outEdges[srcId]'.
    for (auto *it = outEdges[srcId].begin(); it != outEdges[srcId].end();
         ++it) {
      if ((*it).id == dstId && (*it).value == value) {
        outEdges[srcId].erase(it);
        break;
      }
    }
  }

  // Returns true if there is a path in the dependence graph from node 'srcId'
  // to node 'dstId'. Returns false otherwise.
  bool hasDependencePath(unsigned srcId, unsigned dstId) {
    // Worklist state is: <node-id, next-output-edge-index-to-visit>
    SmallVector<std::pair<unsigned, unsigned>, 4> worklist;
    worklist.push_back({srcId, 0});
    // Run DFS traversal to see if 'dstId' is reachable from 'srcId'.
    while (!worklist.empty()) {
      auto &idAndIndex = worklist.back();
      // Return true if we have reached 'dstId'.
      if (idAndIndex.first == dstId)
        return true;
      // Pop and continue if node has no out edges, or if all out edges have
      // already been visited.
      if (outEdges.count(idAndIndex.first) == 0 ||
          idAndIndex.second == outEdges[idAndIndex.first].size()) {
        worklist.pop_back();
        continue;
      }
      // Get graph edge to traverse.
      Edge edge = outEdges[idAndIndex.first][idAndIndex.second];
      // Increment next output edge index for 'idAndIndex'.
      ++idAndIndex.second;
      // Add node at 'edge.id' to worklist.
      worklist.push_back({edge.id, 0});
    }
    return false;
  }

  // Returns the input edge count for node 'id' and 'memref' from src nodes
  // which access 'memref' with a store operation.
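  // For example (hypothetical): if node 3 has in edges on 'memref' from
  // nodes 1 and 2, but only node 1 contains a store to 'memref', the
  // returned count is 1.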
  unsigned getIncomingMemRefAccesses(unsigned id, Value memref) {
    unsigned inEdgeCount = 0;
    if (inEdges.count(id) > 0)
      for (auto &inEdge : inEdges[id])
        if (inEdge.value == memref) {
          Node *srcNode = getNode(inEdge.id);
          // Only count in edges from 'srcNode' if 'srcNode' accesses 'memref'
          // with a store.
          if (srcNode->getStoreOpCount(memref) > 0)
            ++inEdgeCount;
        }
    return inEdgeCount;
  }

  // Returns the output edge count for node 'id' and 'memref' (if non-null),
  // otherwise returns the total output edge count from node 'id'.
  unsigned getOutEdgeCount(unsigned id, Value memref = nullptr) {
    unsigned outEdgeCount = 0;
    if (outEdges.count(id) > 0)
      for (auto &outEdge : outEdges[id])
        if (!memref || outEdge.value == memref)
          ++outEdgeCount;
    return outEdgeCount;
  }

  /// Return all nodes which define SSA values used in node 'id'.
  void gatherDefiningNodes(unsigned id, DenseSet<unsigned> &definingNodes) {
    for (MemRefDependenceGraph::Edge edge : inEdges[id])
      // By definition of edge, if the edge value is a non-memref value,
      // then the dependence is between a graph node which defines an SSA value
      // and another graph node which uses the SSA value.
      if (!edge.value.getType().isa<MemRefType>())
        definingNodes.insert(edge.id);
  }

  // Computes and returns an insertion point operation, before which the
  // fused <srcId, dstId> loop nest can be inserted while preserving
  // dependences. Returns nullptr if no such insertion point is found.
  Operation *getFusedLoopNestInsertionPoint(unsigned srcId, unsigned dstId) {
    if (outEdges.count(srcId) == 0)
      return getNode(dstId)->op;

    // Skip if there is any defining node of 'dstId' that depends on 'srcId'.
    DenseSet<unsigned> definingNodes;
    gatherDefiningNodes(dstId, definingNodes);
    if (llvm::any_of(definingNodes, [&](unsigned id) {
          return hasDependencePath(srcId, id);
        })) {
      LLVM_DEBUG(llvm::dbgs()
                 << "Can't fuse: a defining op with a user in the dst "
                    "loop has dependence from the src loop\n");
      return nullptr;
    }

    // Build set of insts in range (srcId, dstId) which depend on 'srcId'.
    SmallPtrSet<Operation *, 2> srcDepInsts;
    for (auto &outEdge : outEdges[srcId])
      if (outEdge.id != dstId)
        srcDepInsts.insert(getNode(outEdge.id)->op);

    // Build set of insts in range (srcId, dstId) on which 'dstId' depends.
    SmallPtrSet<Operation *, 2> dstDepInsts;
    for (auto &inEdge : inEdges[dstId])
      if (inEdge.id != srcId)
        dstDepInsts.insert(getNode(inEdge.id)->op);

    Operation *srcNodeInst = getNode(srcId)->op;
    Operation *dstNodeInst = getNode(dstId)->op;

    // Computing insertion point:
    // *) Walk all operation positions in Block operation list in the
    //    range (src, dst). For each operation 'op' visited in this search:
    //    *) Store in 'firstSrcDepPos' the first position where 'op' has a
    //       dependence edge from 'srcNode'.
    //    *) Store in 'lastDstDepPos' the last position where 'op' has a
    //       dependence edge to 'dstNode'.
    // *) Compare 'firstSrcDepPos' and 'lastDstDepPos' to determine the
    //    operation insertion point (or return null pointer if no such
    //    insertion point exists: 'firstSrcDepPos' <= 'lastDstDepPos').
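    // For example (hypothetical): with block order src, op0, op1, dst, where
    // op0 depends on 'srcNode' (firstSrcDepPos = 0) and 'dstNode' depends on
    // op1 (lastDstDepPos = 1), firstSrcDepPos <= lastDstDepPos, so no valid
    // insertion point exists. If instead op1 carried the dependence from
    // 'srcNode' and op0 the dependence to 'dstNode', the fused nest would be
    // inserted before op1.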
    SmallVector<Operation *, 2> depInsts;
    Optional<unsigned> firstSrcDepPos;
    Optional<unsigned> lastDstDepPos;
    unsigned pos = 0;
    for (Block::iterator it = std::next(Block::iterator(srcNodeInst));
         it != Block::iterator(dstNodeInst); ++it) {
      Operation *op = &(*it);
      if (srcDepInsts.count(op) > 0 && firstSrcDepPos == None)
        firstSrcDepPos = pos;
      if (dstDepInsts.count(op) > 0)
        lastDstDepPos = pos;
      depInsts.push_back(op);
      ++pos;
    }

    if (firstSrcDepPos.hasValue()) {
      if (lastDstDepPos.hasValue()) {
        if (firstSrcDepPos.getValue() <= lastDstDepPos.getValue()) {
          // No valid insertion point exists which preserves dependences.
          return nullptr;
        }
      }
      // Return the insertion point at 'firstSrcDepPos'.
      return depInsts[firstSrcDepPos.getValue()];
    }
    // No dependence targets in range (or only dst deps in range), return
    // 'dstNodeInst' insertion point.
    return dstNodeInst;
  }

  // Updates edge mappings from node 'srcId' to node 'dstId' after fusing them,
  // taking into account that:
  // *) if 'removeSrcId' is true, 'srcId' will be removed after fusion,
  // *) memrefs in 'privateMemRefs' have been replaced in node at 'dstId' by
  //    private memrefs.
  void updateEdges(unsigned srcId, unsigned dstId,
                   const DenseSet<Value> &privateMemRefs, bool removeSrcId) {
    // For each edge in 'inEdges[srcId]': add new edge remapping to 'dstId'.
    if (inEdges.count(srcId) > 0) {
      SmallVector<Edge, 2> oldInEdges = inEdges[srcId];
      for (auto &inEdge : oldInEdges) {
        // Add edge from 'inEdge.id' to 'dstId' if it's not a private memref.
        if (privateMemRefs.count(inEdge.value) == 0)
          addEdge(inEdge.id, dstId, inEdge.value);
      }
    }
    // For each edge in 'outEdges[srcId]': remove edge from 'srcId' to 'dstId'.
    // If 'srcId' is going to be removed, remap all the out edges to 'dstId'.
    if (outEdges.count(srcId) > 0) {
      SmallVector<Edge, 2> oldOutEdges = outEdges[srcId];
      for (auto &outEdge : oldOutEdges) {
        // Remove any out edges from 'srcId' to 'dstId' across memrefs.
        if (outEdge.id == dstId)
          removeEdge(srcId, outEdge.id, outEdge.value);
        else if (removeSrcId) {
          addEdge(dstId, outEdge.id, outEdge.value);
          removeEdge(srcId, outEdge.id, outEdge.value);
        }
      }
    }
    // Remove any edges in 'inEdges[dstId]' on a memref that is being replaced
    // by a private memref. These edges could come from nodes other than
    // 'srcId' which were removed in the previous step.
    if (inEdges.count(dstId) > 0 && !privateMemRefs.empty()) {
      SmallVector<Edge, 2> oldInEdges = inEdges[dstId];
      for (auto &inEdge : oldInEdges)
        if (privateMemRefs.count(inEdge.value) > 0)
          removeEdge(inEdge.id, dstId, inEdge.value);
    }
  }

  // Update edge mappings for nodes 'sibId' and 'dstId' to reflect fusion
  // of sibling node 'sibId' into node 'dstId'.
  void updateEdges(unsigned sibId, unsigned dstId) {
    // For each edge in 'inEdges[sibId]':
    // *) Add new edge from source node 'inEdge.id' to 'dstNode'.
    // *) Remove edge from source node 'inEdge.id' to 'sibNode'.
    if (inEdges.count(sibId) > 0) {
      SmallVector<Edge, 2> oldInEdges = inEdges[sibId];
      for (auto &inEdge : oldInEdges) {
        addEdge(inEdge.id, dstId, inEdge.value);
        removeEdge(inEdge.id, sibId, inEdge.value);
      }
    }

    // For each edge in 'outEdges[sibId]':
    // *) Add new edge from 'dstId' to 'outEdge.id'.
    // *) Remove edge from 'sibId' to 'outEdge.id'.
    if (outEdges.count(sibId) > 0) {
      SmallVector<Edge, 2> oldOutEdges = outEdges[sibId];
      for (auto &outEdge : oldOutEdges) {
        addEdge(dstId, outEdge.id, outEdge.value);
        removeEdge(sibId, outEdge.id, outEdge.value);
      }
    }
  }

  // Adds ops in 'loads' and 'stores' to node at 'id'.
  void addToNode(unsigned id, const SmallVectorImpl<Operation *> &loads,
                 const SmallVectorImpl<Operation *> &stores) {
    Node *node = getNode(id);
    for (auto *loadOpInst : loads)
      node->loads.push_back(loadOpInst);
    for (auto *storeOpInst : stores)
      node->stores.push_back(storeOpInst);
  }

  void clearNodeLoadAndStores(unsigned id) {
    Node *node = getNode(id);
    node->loads.clear();
    node->stores.clear();
  }

  // Calls 'callback' for each input edge incident to node 'id' which carries a
  // memref dependence.
  void forEachMemRefInputEdge(unsigned id,
                              const std::function<void(Edge)> &callback) {
    if (inEdges.count(id) > 0)
      forEachMemRefEdge(inEdges[id], callback);
  }

  // Calls 'callback' for each output edge from node 'id' which carries a
  // memref dependence.
  void forEachMemRefOutputEdge(unsigned id,
                               const std::function<void(Edge)> &callback) {
    if (outEdges.count(id) > 0)
      forEachMemRefEdge(outEdges[id], callback);
  }

  // Calls 'callback' for each edge in 'edges' which carries a memref
  // dependence.
  void forEachMemRefEdge(ArrayRef<Edge> edges,
                         const std::function<void(Edge)> &callback) {
    for (const auto &edge : edges) {
      // Skip if 'edge' is not a memref dependence edge.
      if (!edge.value.getType().isa<MemRefType>())
        continue;
      assert(nodes.count(edge.id) > 0);
      // Skip if 'edge.id' is not a loop nest.
      if (!isa<AffineForOp>(getNode(edge.id)->op))
        continue;
      // Visit current input edge 'edge'.
      callback(edge);
    }
  }

  void print(raw_ostream &os) const {
    os << "\nMemRefDependenceGraph\n";
    os << "\nNodes:\n";
    for (const auto &idAndNode : nodes) {
      os << "Node: " << idAndNode.first << "\n";
      auto it = inEdges.find(idAndNode.first);
      if (it != inEdges.end()) {
        for (const auto &e : it->second)
          os << "  InEdge: " << e.id << " " << e.value << "\n";
      }
      it = outEdges.find(idAndNode.first);
      if (it != outEdges.end()) {
        for (const auto &e : it->second)
          os << "  OutEdge: " << e.id << " " << e.value << "\n";
      }
    }
  }
  void dump() const { print(llvm::errs()); }
};

/// Returns true if node 'srcId' can be removed after fusing it with node
/// 'dstId'. The node can be removed if any of the following conditions are
/// met:
///   1. 'srcId' has no output dependences after fusion and no escaping
///      memrefs.
///   2. 'srcId' has no output dependences after fusion, has escaping memrefs
///      and the fusion slice is maximal.
///   3. 'srcId' has output dependences after fusion, the fusion slice is
///      maximal and the fusion insertion point dominates all the dependences.
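/// For example (hypothetical): if 'srcId' writes to an escaping memref and
/// the computed slice covers all of 'srcId''s iteration space (i.e., it is
/// maximal), condition 2 permits removal; with a non-maximal slice the source
/// nest must be kept so the escaping memref still receives every write.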
static bool canRemoveSrcNodeAfterFusion(
    unsigned srcId, unsigned dstId, const ComputationSliceState &fusionSlice,
    Operation *fusedLoopInsPoint, const DenseSet<Value> &escapingMemRefs,
    MemRefDependenceGraph *mdg) {

  Operation *dstNodeOp = mdg->getNode(dstId)->op;
  bool hasOutDepsAfterFusion = false;

  for (auto &outEdge : mdg->outEdges[srcId]) {
    Operation *depNodeOp = mdg->getNode(outEdge.id)->op;
    // Skip dependence with dstOp since it will be removed after fusion.
    if (depNodeOp == dstNodeOp)
      continue;

    // Only fusion within the same block is supported. Use domination analysis
    // when needed.
    if (depNodeOp->getBlock() != dstNodeOp->getBlock())
      return false;

    // Check if the insertion point of the fused loop dominates the dependence.
    // Otherwise, the src loop can't be removed.
    if (fusedLoopInsPoint != depNodeOp &&
        !fusedLoopInsPoint->isBeforeInBlock(depNodeOp)) {
      LLVM_DEBUG(llvm::dbgs() << "Src loop can't be removed: dst loop doesn't "
                                 "dominate dependence\n");
      return false;
    }

    hasOutDepsAfterFusion = true;
  }

  // If src loop has dependences after fusion or it writes to a live-out or
  // escaping memref, we can only remove it if the fusion slice is maximal so
  // that all the dependences are preserved.
  if (hasOutDepsAfterFusion || !escapingMemRefs.empty()) {
    Optional<bool> isMaximal = fusionSlice.isMaximal();
    if (!isMaximal.hasValue()) {
      LLVM_DEBUG(llvm::dbgs() << "Src loop can't be removed: can't determine "
                                 "if fusion is maximal\n");
      return false;
    }

    if (!isMaximal.getValue()) {
      LLVM_DEBUG(llvm::dbgs()
                 << "Src loop can't be removed: fusion is not maximal\n");
      return false;
    }
  }

  return true;
}

/// Returns in 'srcIdCandidates' the producer fusion candidates for consumer
/// 'dstId'. Candidates are sorted by node id order. This order corresponds to
/// the program order when the 'mdg' is created. However, program order is not
/// guaranteed and must not be required by the client. Program order won't hold
/// if the 'mdg' is reused from a previous fusion step or if the node creation
/// order changes in the future to support more advanced cases.
// TODO: Move this to a loop fusion utility once 'mdg' is also moved.
static void getProducerCandidates(unsigned dstId, MemRefDependenceGraph *mdg,
                                  SmallVectorImpl<unsigned> &srcIdCandidates) {
  // Skip if no input edges along which to fuse.
  if (mdg->inEdges.count(dstId) == 0)
    return;

  // Gather memrefs from loads in 'dstId'.
  auto *dstNode = mdg->getNode(dstId);
  DenseSet<Value> consumedMemrefs;
  for (Operation *load : dstNode->loads)
    consumedMemrefs.insert(cast<AffineReadOpInterface>(load).getMemRef());

  // Traverse 'dstId' incoming edges and gather the nodes that contain a store
  // to one of the consumed memrefs.
  for (auto &srcEdge : mdg->inEdges[dstId]) {
    auto *srcNode = mdg->getNode(srcEdge.id);
    // Skip if 'srcNode' is not a loop nest.
    if (!isa<AffineForOp>(srcNode->op))
      continue;

    if (any_of(srcNode->stores, [&](Operation *op) {
          auto storeOp = cast<AffineWriteOpInterface>(op);
          return consumedMemrefs.count(storeOp.getMemRef()) > 0;
        }))
      srcIdCandidates.push_back(srcNode->id);
  }

  std::sort(srcIdCandidates.begin(), srcIdCandidates.end());
  srcIdCandidates.erase(
      std::unique(srcIdCandidates.begin(), srcIdCandidates.end()),
      srcIdCandidates.end());
}

/// Returns in 'producerConsumerMemrefs' the memrefs involved in a
/// producer-consumer dependence between 'srcId' and 'dstId'.
static void
gatherProducerConsumerMemrefs(unsigned srcId, unsigned dstId,
                              MemRefDependenceGraph *mdg,
                              DenseSet<Value> &producerConsumerMemrefs) {
  auto *dstNode = mdg->getNode(dstId);
  auto *srcNode = mdg->getNode(srcId);
  gatherProducerConsumerMemrefs(srcNode->stores, dstNode->loads,
                                producerConsumerMemrefs);
}

/// Returns in 'escapingMemRefs' the memrefs from affine store ops in node 'id'
/// that escape the function. A memref escapes the function if either:
///   1. It's a function argument, or
///   2. It's used by a non-affine op (e.g., std load/store, std call, etc.)
void gatherEscapingMemrefs(unsigned id, MemRefDependenceGraph *mdg,
                           DenseSet<Value> &escapingMemRefs) {
  auto *node = mdg->getNode(id);
  for (auto *storeOpInst : node->stores) {
    auto memref = cast<AffineWriteOpInterface>(storeOpInst).getMemRef();
    if (escapingMemRefs.count(memref))
      continue;
    // Check if 'memref' escapes because it's a block argument.
    if (memref.isa<BlockArgument>()) {
      escapingMemRefs.insert(memref);
      continue;
    }
    // Check if 'memref' escapes through a non-affine op (e.g., std
    // load/store, call op, etc.).
    for (Operation *user : memref.getUsers())
      if (!isa<AffineMapAccessInterface>(*user))
        escapingMemRefs.insert(memref);
  }
}

} // namespace

// Initializes the data dependence graph by walking operations in 'f'.
// Assigns each node in the graph a node id based on program order in 'f'.
// TODO: Add support for taking a Block arg to construct the
// dependence graph at a different depth.
bool MemRefDependenceGraph::init(FuncOp f) {
  LLVM_DEBUG(llvm::dbgs() << "--- Initializing MDG ---\n");
  DenseMap<Value, SetVector<unsigned>> memrefAccesses;

  // TODO: support multi-block functions.
  if (!llvm::hasSingleElement(f))
    return false;

  DenseMap<Operation *, unsigned> forToNodeMap;
  for (auto &op : f.front()) {
    if (auto forOp = dyn_cast<AffineForOp>(op)) {
      // Create graph node 'id' to represent top-level 'forOp' and record
      // all load and store accesses it contains.
      LoopNestStateCollector collector;
      collector.collect(&op);
      // Return false if a region holding op other than 'affine.for' and
      // 'affine.if' was found (not currently supported).
      if (collector.hasNonAffineRegionOp)
        return false;
      Node node(nextNodeId++, &op);
      for (auto *opInst : collector.loadOpInsts) {
        node.loads.push_back(opInst);
        auto memref = cast<AffineReadOpInterface>(opInst).getMemRef();
        memrefAccesses[memref].insert(node.id);
      }
      for (auto *opInst : collector.storeOpInsts) {
        node.stores.push_back(opInst);
        auto memref = cast<AffineWriteOpInterface>(opInst).getMemRef();
        memrefAccesses[memref].insert(node.id);
      }
      forToNodeMap[&op] = node.id;
      nodes.insert({node.id, node});
    } else if (auto loadOp = dyn_cast<AffineReadOpInterface>(op)) {
      // Create graph node for top-level load op.
      Node node(nextNodeId++, &op);
      node.loads.push_back(&op);
      auto memref = cast<AffineReadOpInterface>(op).getMemRef();
      memrefAccesses[memref].insert(node.id);
      nodes.insert({node.id, node});
    } else if (auto storeOp = dyn_cast<AffineWriteOpInterface>(op)) {
      // Create graph node for top-level store op.
      Node node(nextNodeId++, &op);
      node.stores.push_back(&op);
      auto memref = cast<AffineWriteOpInterface>(op).getMemRef();
      memrefAccesses[memref].insert(node.id);
      nodes.insert({node.id, node});
    } else if (op.getNumRegions() != 0) {
      // Return false if another region is found (not currently supported).
      return false;
    } else if (op.getNumResults() > 0 && !op.use_empty()) {
      // Create graph node for top-level producer of SSA values, which
      // could be used by loop nest nodes.
      Node node(nextNodeId++, &op);
      nodes.insert({node.id, node});
    } else if (isa<CallOpInterface>(op)) {
      // Create graph node for top-level Call Op that takes any argument of
      // memref type. Call Op that returns one or more memref type results
      // is already taken care of by the previous conditions.
      if (llvm::any_of(op.getOperandTypes(),
                       [&](Type t) { return t.isa<MemRefType>(); })) {
        Node node(nextNodeId++, &op);
        nodes.insert({node.id, node});
      }
    } else if (auto effectInterface = dyn_cast<MemoryEffectOpInterface>(op)) {
      // Create graph node for top-level op, which could have a memory write
      // side effect.
      SmallVector<MemoryEffects::EffectInstance, 1> effects;
      effectInterface.getEffects(effects);
      if (llvm::any_of(effects, [](const MemoryEffects::EffectInstance &it) {
            return isa<MemoryEffects::Write, MemoryEffects::Free>(
                it.getEffect());
          })) {
        Node node(nextNodeId++, &op);
        nodes.insert({node.id, node});
      }
    }
  }

  for (auto &idAndNode : nodes) {
    LLVM_DEBUG(llvm::dbgs() << "Create node " << idAndNode.first << " for:\n"
                            << *(idAndNode.second.op) << "\n");
    (void)idAndNode;
  }

  // Add dependence edges between nodes which produce SSA values and their
  // users. Load ops can be considered as the ones producing SSA values.
  for (auto &idAndNode : nodes) {
    const Node &node = idAndNode.second;
    // Stores don't define SSA values, skip them.
    if (!node.stores.empty())
      continue;
    auto *opInst = node.op;
    for (auto value : opInst->getResults()) {
      for (auto *user : value.getUsers()) {
        SmallVector<AffineForOp, 4> loops;
        getLoopIVs(*user, &loops);
        if (loops.empty())
          continue;
        assert(forToNodeMap.count(loops[0].getOperation()) > 0);
        unsigned userLoopNestId = forToNodeMap[loops[0].getOperation()];
        addEdge(node.id, userLoopNestId, value);
      }
    }
  }

  // Walk memref access lists and add graph edges between dependent nodes.
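  // For example (hypothetical): if a memref %m is accessed by nodes
  // {0 (store), 2 (load), 3 (load)} in program order, edges 0 -> 2 and
  // 0 -> 3 are added below, but no edge is added between the two loads
  // 2 and 3.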
  for (auto &memrefAndList : memrefAccesses) {
    unsigned n = memrefAndList.second.size();
    for (unsigned i = 0; i < n; ++i) {
      unsigned srcId = memrefAndList.second[i];
      bool srcHasStore =
          getNode(srcId)->getStoreOpCount(memrefAndList.first) > 0;
      for (unsigned j = i + 1; j < n; ++j) {
        unsigned dstId = memrefAndList.second[j];
        bool dstHasStore =
            getNode(dstId)->getStoreOpCount(memrefAndList.first) > 0;
        if (srcHasStore || dstHasStore)
          addEdge(srcId, dstId, memrefAndList.first);
      }
    }
  }
  return true;
}

// Sinks all sequential loops to the innermost levels (while preserving
// relative order among them) and moves all parallel loops to the
// outermost (while again preserving relative order among them).
// This can increase the loop depth at which we can fuse a slice, since we are
// pushing loop carried dependence to a greater depth in the loop nest.
static void sinkSequentialLoops(MemRefDependenceGraph::Node *node) {
  assert(isa<AffineForOp>(node->op));
  AffineForOp newRootForOp = sinkSequentialLoops(cast<AffineForOp>(node->op));
  node->op = newRootForOp.getOperation();
}

// TODO: improve/complete this when we have target data.
static unsigned getMemRefEltSizeInBytes(MemRefType memRefType) {
  auto elementType = memRefType.getElementType();

  unsigned sizeInBits;
  if (elementType.isIntOrFloat()) {
    sizeInBits = elementType.getIntOrFloatBitWidth();
  } else {
    auto vectorType = elementType.cast<VectorType>();
    sizeInBits =
        vectorType.getElementTypeBitWidth() * vectorType.getNumElements();
  }
  return llvm::divideCeil(sizeInBits, 8);
}

// Creates and returns a private (single-user) memref for fused loop rooted
// at 'forOp', with (potentially reduced) memref size based on the
// MemRefRegion written to by 'srcStoreOpInst' at depth 'dstLoopDepth'.
// TODO: consider refactoring the common code from generateDma and
// this one.
static Value createPrivateMemRef(AffineForOp forOp, Operation *srcStoreOpInst,
                                 unsigned dstLoopDepth,
                                 Optional<unsigned> fastMemorySpace,
                                 uint64_t localBufSizeThreshold) {
  auto *forInst = forOp.getOperation();

  // Create builder to insert alloc op just before 'forOp'.
  OpBuilder b(forInst);
  // Builder to create constants at the top level.
  OpBuilder top(forInst->getParentOfType<FuncOp>().getBody());
  // Create new memref type based on slice bounds.
  auto oldMemRef = cast<AffineWriteOpInterface>(srcStoreOpInst).getMemRef();
  auto oldMemRefType = oldMemRef.getType().cast<MemRefType>();
  unsigned rank = oldMemRefType.getRank();

  // Compute MemRefRegion for 'srcStoreOpInst' at depth 'dstLoopDepth'.
  MemRefRegion region(srcStoreOpInst->getLoc());
  bool validRegion = succeeded(region.compute(srcStoreOpInst, dstLoopDepth));
  (void)validRegion;
  assert(validRegion && "unexpected memref region failure");
  SmallVector<int64_t, 4> newShape;
  std::vector<SmallVector<int64_t, 4>> lbs;
  SmallVector<int64_t, 8> lbDivisors;
  lbs.reserve(rank);
  // Query 'region' for 'newShape' and lower bounds of MemRefRegion accessed
  // by 'srcStoreOpInst' at depth 'dstLoopDepth'.
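  // For example (hypothetical): if the source nest writes the window
  // %m[32 * %j .. 32 * %j + 31] for an outer IV %j that remains a symbol at
  // 'dstLoopDepth', the region has constant size 32, so 'newShape' becomes
  // {32} and the lower bound 32 * %j becomes the offset subtracted from the
  // access functions further below.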
  Optional<int64_t> numElements =
      region.getConstantBoundingSizeAndShape(&newShape, &lbs, &lbDivisors);
  assert(numElements.hasValue() &&
         "non-constant number of elts in local buffer");

  const FlatAffineValueConstraints *cst = region.getConstraints();
  // 'outerIVs' holds the values that this memory region is symbolic/parametric
  // on; this would correspond to loop IVs surrounding the level at which the
  // slice is being materialized.
  SmallVector<Value, 8> outerIVs;
  cst->getValues(rank, cst->getNumIds(), &outerIVs);

  // Build 'rank' AffineExprs from MemRefRegion 'lbs'.
  SmallVector<AffineExpr, 4> offsets;
  offsets.reserve(rank);
  for (unsigned d = 0; d < rank; ++d) {
    assert(lbs[d].size() == cst->getNumCols() - rank && "incorrect bound size");

    AffineExpr offset = top.getAffineConstantExpr(0);
    for (unsigned j = 0, e = cst->getNumCols() - rank - 1; j < e; j++) {
      offset = offset + lbs[d][j] * top.getAffineDimExpr(j);
    }
    assert(lbDivisors[d] > 0);
    offset =
        (offset + lbs[d][cst->getNumCols() - 1 - rank]).floorDiv(lbDivisors[d]);
    offsets.push_back(offset);
  }

  // Create 'newMemRefType' using 'newShape' from MemRefRegion accessed
  // by 'srcStoreOpInst'.
  uint64_t bufSize =
      getMemRefEltSizeInBytes(oldMemRefType) * numElements.getValue();
  unsigned newMemSpace;
  if (bufSize <= localBufSizeThreshold && fastMemorySpace.hasValue()) {
    newMemSpace = fastMemorySpace.getValue();
  } else {
    newMemSpace = oldMemRefType.getMemorySpaceAsInt();
  }
  auto newMemRefType = MemRefType::get(newShape, oldMemRefType.getElementType(),
                                       {}, newMemSpace);

  // Create new private memref for fused loop 'forOp'. 'newShape' is always
  // a constant shape.
  // TODO: Create/move alloc ops for private memrefs closer to their
  // consumer loop nests to reduce their live range. Currently they are added
  // at the beginning of the function, because loop nests can be reordered
  // during the fusion pass.
  Value newMemRef = top.create<memref::AllocOp>(forOp.getLoc(), newMemRefType);

  // Build an AffineMap to remap access functions based on lower bound offsets.
  SmallVector<AffineExpr, 4> remapExprs;
  remapExprs.reserve(rank);
  for (unsigned i = 0; i < rank; i++) {
    auto dimExpr = b.getAffineDimExpr(outerIVs.size() + i);

    auto remapExpr =
        simplifyAffineExpr(dimExpr - offsets[i], outerIVs.size() + rank, 0);
    remapExprs.push_back(remapExpr);
  }

  auto indexRemap =
      AffineMap::get(outerIVs.size() + rank, 0, remapExprs, forOp.getContext());

  // Replace all users of 'oldMemRef' with 'newMemRef'.
  LogicalResult res =
      replaceAllMemRefUsesWith(oldMemRef, newMemRef, {}, indexRemap,
                               /*extraOperands=*/outerIVs,
                               /*symbolOperands=*/{},
                               /*domOpFilter=*/&*forOp.getBody()->begin());
  assert(succeeded(res) &&
         "replaceAllMemrefUsesWith should always succeed here");
  (void)res;
  return newMemRef;
}

/// Returns true if there is any non-affine operation accessing 'memref'
/// strictly between node 'srcId' and node 'dstId' (exclusive of 'srcId' and
/// 'dstId'). Returns false otherwise.
static bool hasNonAffineUsersOnThePath(unsigned srcId, unsigned dstId,
                                       Value memref,
                                       MemRefDependenceGraph *mdg) {
  auto *srcNode = mdg->getNode(srcId);
  auto *dstNode = mdg->getNode(dstId);
  Value::user_range users = memref.getUsers();
  // For each MemRefDependenceGraph's node that is between 'srcNode' and
  // 'dstNode' (exclusive of 'srcNode' and 'dstNode'), check whether any
  // non-affine operation in the node accesses the 'memref'.
  for (auto &idAndNode : mdg->nodes) {
    Operation *op = idAndNode.second.op;
    // Take care of operations between 'srcNode' and 'dstNode'.
    if (srcNode->op->isBeforeInBlock(op) && op->isBeforeInBlock(dstNode->op)) {
      // Walk inside the operation to find any use of the memref.
      // Interrupt the walk if found.
      auto walkResult = op->walk([&](Operation *user) {
        // Skip affine ops.
        if (isa<AffineMapAccessInterface>(*user))
          return WalkResult::advance();
        // Find a non-affine op that uses the memref.
        if (llvm::is_contained(users, user))
          return WalkResult::interrupt();
        return WalkResult::advance();
      });
      if (walkResult.wasInterrupted())
        return true;
    }
  }
  return false;
}

/// Check whether any memref value used in node 'srcId' has a non-affine user
/// that is between node 'srcId' and node 'dstId' (exclusive of 'srcNode' and
/// 'dstNode').
static bool hasNonAffineUsersOnThePath(unsigned srcId, unsigned dstId,
                                       MemRefDependenceGraph *mdg) {
  // Collect memref values in node 'srcId'.
  auto *srcNode = mdg->getNode(srcId);
  llvm::SmallDenseSet<Value, 2> memRefValues;
  srcNode->op->walk([&](Operation *op) {
    // Skip affine ops.
    if (isa<AffineForOp>(op))
      return WalkResult::advance();
    for (Value v : op->getOperands())
      // Collect memref values only.
      if (v.getType().isa<MemRefType>())
        memRefValues.insert(v);
    return WalkResult::advance();
  });
  // Looking for users between node 'srcId' and node 'dstId'.
  for (Value memref : memRefValues)
    if (hasNonAffineUsersOnThePath(srcId, dstId, memref, mdg))
      return true;
  return false;
}

// Checks the profitability of fusing a backwards slice of the loop nest
// surrounding 'srcOpInst' into the loop nest surrounding 'dstLoadOpInsts'.
// The argument 'srcStoreOpInst' is used to calculate the storage reduction on
// the memref being produced and consumed, which is an input to the cost model.
// For producer-consumer fusion, 'srcStoreOpInst' will be the same as
// 'srcOpInst', as we are slicing w.r.t. that producer. For input-reuse
// fusion, 'srcOpInst' will be the src loop nest LoadOp which reads from the
// same memref as dst loop nest load ops, and 'srcStoreOpInst' will be the
// unique store op in the src node, which will be used to check that the write
// region is the same after input-reuse fusion. Computation slices are provided
// in 'depthSliceUnions' for each legal fusion depth. The maximal depth at
// which fusion is legal is provided in 'maxLegalFusionDepth'. Returns true if
// it is profitable to fuse the candidate loop nests. Returns false otherwise.
// `dstLoopDepth` is set to the most profitable depth at which to materialize
// the source loop nest slice.
// The profitability model executes the following steps:
// *) Computes the backward computation slice at 'srcOpInst'. This
//    computation slice of the loop nest surrounding 'srcOpInst' is
//    represented by modified src loop bounds in 'sliceState', which are
//    functions of loop IVs in the loop nest surrounding 'srcOpInst'.
// *) Computes the cost of unfused src/dst loop nests (currently the cost of a
//    loop nest is the total number of dynamic operation instances in the loop
//    nest).
// *) Computes the cost of fusing a slice of the src loop nest into the dst
//    loop nest at various values of dst loop depth, attempting to fuse
//    the largest computation slice at the maximal dst loop depth (closest to
//    the load) to minimize reuse distance and potentially enable subsequent
//    load/store forwarding.
//    NOTE: 'dstLoopDepth' refers to the loop depth within the destination loop
//    nest, at which the src computation slice is inserted/fused.
//    NOTE: We attempt to maximize the dst loop depth, but there are cases
//    where a particular setting for 'dstLoopDepth' might fuse an unsliced
//    loop (within the src computation slice) at a depth which results in
//    excessive recomputation (see unit tests for examples).
// *) Compares the total cost of the unfused loop nests to the min cost fused
//    loop nest computed in the previous step, and returns true if the latter
//    is lower.
// TODO: Extend profitability analysis to support scenarios with multiple
// stores.
static bool isFusionProfitable(Operation *srcOpInst, Operation *srcStoreOpInst,
                               AffineForOp dstForOp,
                               ArrayRef<ComputationSliceState> depthSliceUnions,
                               unsigned maxLegalFusionDepth,
                               unsigned *dstLoopDepth,
                               double computeToleranceThreshold) {
  LLVM_DEBUG({
    llvm::dbgs() << "Checking whether fusion is profitable between src op:\n";
    llvm::dbgs() << ' ' << *srcOpInst << " and destination loop:\n";
    llvm::dbgs() << dstForOp << "\n";
  });

  if (maxLegalFusionDepth == 0) {
    LLVM_DEBUG(llvm::dbgs() << "Can't fuse: maxLegalFusionDepth == 0.\n");
    return false;
  }

  // Compute cost of sliced and unsliced src loop nest.
  SmallVector<AffineForOp, 4> srcLoopIVs;
  getLoopIVs(*srcOpInst, &srcLoopIVs);

  // Walk src loop nest and collect stats.
  LoopNestStats srcLoopNestStats;
  if (!getLoopNestStats(srcLoopIVs[0], &srcLoopNestStats))
    return false;

  // Compute cost of dst loop nest.
  LoopNestStats dstLoopNestStats;
  if (!getLoopNestStats(dstForOp, &dstLoopNestStats))
    return false;

  // Search for the min cost value for 'dstLoopDepth'. At each value of
  // 'dstLoopDepth' from 'maxLegalFusionDepth' to '1', use the computation
  // slice union in 'depthSliceUnions' to calculate the cost of the slice and
  // the cost of the slice inserted into the dst loop nest at 'dstLoopDepth'.
  uint64_t minFusedLoopNestComputeCost = std::numeric_limits<uint64_t>::max();
  double maxStorageReduction = 0.0;
  Optional<uint64_t> sliceMemEstimate = None;

  // The best loop depth at which to materialize the slice.
  Optional<unsigned> bestDstLoopDepth = None;

  // Compute op instance count for the src loop nest without iteration slicing.
  uint64_t srcLoopNestCost = getComputeCost(srcLoopIVs[0], srcLoopNestStats);

  // Compute src loop nest write region size.
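  // For example (hypothetical): a source nest storing to all of
  // memref<100xf32> has a write region size of 400 bytes; the slice write
  // region computed per depth below is compared against this to make sure
  // input-reuse fusion preserves the write region.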
  MemRefRegion srcWriteRegion(srcStoreOpInst->getLoc());
  if (failed(srcWriteRegion.compute(srcStoreOpInst, /*loopDepth=*/0))) {
    LLVM_DEBUG(llvm::dbgs()
               << "Unable to compute MemRefRegion for source operation.\n");
    return false;
  }

  Optional<int64_t> maybeSrcWriteRegionSizeBytes =
      srcWriteRegion.getRegionSize();
  if (!maybeSrcWriteRegionSizeBytes.hasValue())
    return false;
  int64_t srcWriteRegionSizeBytes = maybeSrcWriteRegionSizeBytes.getValue();

  // Compute op instance count for the dst loop nest.
  uint64_t dstLoopNestCost = getComputeCost(dstForOp, dstLoopNestStats);

  // Evaluate all depth choices for materializing the slice in the destination
  // loop nest.
  for (unsigned i = maxLegalFusionDepth; i >= 1; --i) {
    const ComputationSliceState &slice = depthSliceUnions[i - 1];
    // Skip slice union if it wasn't computed for this depth.
    if (slice.isEmpty())
      continue;

    int64_t fusedLoopNestComputeCost;
    if (!getFusionComputeCost(srcLoopIVs[0], srcLoopNestStats, dstForOp,
                              dstLoopNestStats, slice,
                              &fusedLoopNestComputeCost)) {
      LLVM_DEBUG(llvm::dbgs() << "Unable to compute fusion compute cost.\n");
      continue;
    }

    double additionalComputeFraction =
        fusedLoopNestComputeCost /
            (static_cast<double>(srcLoopNestCost) + dstLoopNestCost) -
        1;

    // Determine what the slice write MemRefRegion would be, if the src loop
    // nest slice 'slice' were to be inserted into the dst loop nest at loop
    // depth 'i'.
    MemRefRegion sliceWriteRegion(srcStoreOpInst->getLoc());
    if (failed(sliceWriteRegion.compute(srcStoreOpInst, /*loopDepth=*/0,
                                        &slice))) {
      LLVM_DEBUG(llvm::dbgs()
                 << "Failed to compute slice write region at loopDepth: " << i
                 << "\n");
      continue;
    }

    Optional<int64_t> maybeSliceWriteRegionSizeBytes =
        sliceWriteRegion.getRegionSize();
    if (!maybeSliceWriteRegionSizeBytes.hasValue() ||
        maybeSliceWriteRegionSizeBytes.getValue() == 0) {
      LLVM_DEBUG(llvm::dbgs()
                 << "Failed to get slice write region size at loopDepth: " << i
                 << "\n");
      continue;
    }
    int64_t sliceWriteRegionSizeBytes =
        maybeSliceWriteRegionSizeBytes.getValue();

    // If we are fusing for reuse, check that write regions remain the same.
    // TODO: Write region check should check sizes and offsets in
    // each dimension, so that we are sure they are covering the same memref
    // region. Also, move this out to a isMemRefRegionSuperSet helper function.
    if (srcOpInst != srcStoreOpInst &&
        sliceWriteRegionSizeBytes != srcWriteRegionSizeBytes)
      continue;

    double storageReduction = static_cast<double>(srcWriteRegionSizeBytes) /
                              static_cast<double>(sliceWriteRegionSizeBytes);

    LLVM_DEBUG({
      std::stringstream msg;
      msg << "  evaluating fusion profitability at depth : " << i << "\n"
          << std::fixed << std::setprecision(2)
          << "   additional compute fraction: "
          << 100.0 * additionalComputeFraction << "%\n"
          << "   storage reduction factor: " << storageReduction << "x\n"
          << "   fused nest cost: " << fusedLoopNestComputeCost << "\n"
          << "   src write region size: " << srcWriteRegionSizeBytes << "\n"
          << "   slice write region size: " << sliceWriteRegionSizeBytes
          << "\n";
      llvm::dbgs() << msg.str();
    });

    // TODO: This is a placeholder cost model.
    // Among all choices that add an acceptable amount of redundant computation
    // (as per computeToleranceThreshold), we will simply pick the one that
    // reduces the intermediary size the most.
    if ((storageReduction > maxStorageReduction) &&
        (additionalComputeFraction < computeToleranceThreshold)) {
      maxStorageReduction = storageReduction;
      bestDstLoopDepth = i;
      minFusedLoopNestComputeCost = fusedLoopNestComputeCost;
      sliceMemEstimate = sliceWriteRegionSizeBytes;
    }
  }

  // A simple cost model: fuse if it reduces the memory footprint.

  if (!bestDstLoopDepth.hasValue()) {
    LLVM_DEBUG(
        llvm::dbgs()
        << "All fusion choices involve more than the threshold amount of "
           "redundant computation; NOT fusing.\n");
    return false;
  }

  // Set dstLoopDepth based on best values from search.
  *dstLoopDepth = bestDstLoopDepth.getValue();

  LLVM_DEBUG(
      llvm::dbgs() << " LoopFusion fusion stats:"
                   << "\n  best loop depth: " << bestDstLoopDepth
                   << "\n  src loop nest compute cost: " << srcLoopNestCost
                   << "\n  dst loop nest compute cost: " << dstLoopNestCost
                   << "\n  fused loop nest compute cost: "
                   << minFusedLoopNestComputeCost << "\n");

  auto dstMemSize = getMemoryFootprintBytes(dstForOp);
  auto srcMemSize = getMemoryFootprintBytes(srcLoopIVs[0]);

  Optional<double> storageReduction = None;

  if (!dstMemSize.hasValue() || !srcMemSize.hasValue()) {
    LLVM_DEBUG(llvm::dbgs()
               << "  fusion memory benefit cannot be evaluated; NOT fusing.\n");
    return false;
  }

  auto srcMemSizeVal = srcMemSize.getValue();
  auto dstMemSizeVal = dstMemSize.getValue();

  assert(sliceMemEstimate.hasValue() && "expected value");
  auto fusedMem = dstMemSizeVal + sliceMemEstimate.getValue();

  LLVM_DEBUG(llvm::dbgs() << "  src mem: " << srcMemSizeVal << "\n"
                          << "  dst mem: " << dstMemSizeVal << "\n"
                          << "  fused mem: " << fusedMem << "\n"
                          << "  slice mem: " << sliceMemEstimate << "\n");

  if (static_cast<long>(fusedMem) > srcMemSizeVal + dstMemSizeVal) {
    LLVM_DEBUG(llvm::dbgs() << "Fusion is not profitable; NOT fusing.\n");
    return false;
  }
  storageReduction =
      100.0 *
      (1.0 - fusedMem / (static_cast<double>(srcMemSizeVal) + dstMemSizeVal));

  double additionalComputeFraction =
      100.0 * (minFusedLoopNestComputeCost /
                   (static_cast<double>(srcLoopNestCost) + dstLoopNestCost) -
               1);
  (void)additionalComputeFraction;
  LLVM_DEBUG({
    std::stringstream msg;
    msg << " fusion is most profitable at depth " << *dstLoopDepth << " with "
        << std::setprecision(2) << additionalComputeFraction
        << "% redundant computation and a ";
    msg << (storageReduction.hasValue()
                ? std::to_string(storageReduction.getValue())
                : "<unknown>");
    msg << "% storage reduction.\n";
    llvm::dbgs() << msg.str();
  });

  return true;
}

namespace {

// GreedyFusion greedily fuses loop nests which have a producer/consumer or
// input-reuse relationship on a memref, with the goal of improving locality.
//
// The steps of the producer-consumer fusion algorithm are as follows:
//
// *) A worklist is initialized with node ids from the dependence graph.
// *) For each node id in the worklist:
//    *) Pop an AffineForOp off the worklist. This 'dstAffineForOp' will be a
//       candidate destination AffineForOp into which fusion will be attempted.
//    *) Add each LoadOp currently in 'dstAffineForOp' into list 'dstLoadOps'.
//    *) For each LoadOp in 'dstLoadOps' do:
//       *) Look up dependent loop nests which have a single store op to the
//          same memref.
//       *) Check if dependences would be violated by the fusion.
//       *) Get a computation slice of 'srcLoopNest', which adjusts its loop
//          bounds to be functions of 'dstLoopNest' IVs and symbols.
//       *) Fuse the 'srcLoopNest' computation slice into the 'dstLoopNest',
//          at a loop depth determined by the cost model in
//          'isFusionProfitable'.
//       *) Add the newly fused load/store operations to the state,
//          and also add newly fused load ops to 'dstLoadOps' to be considered
//          as fusion dst load ops in another iteration.
//       *) Remove old src loop nest and its associated state.
//
// The steps of the input-reuse fusion algorithm are as follows:
//
// *) Initialize 'worklist' with node ids from the dependence graph.
// *) For each 'dstNode' in the worklist:
//    *) Find a candidate sibling node 'sibNode' to fuse with 'dstNode' which
//       loads from the same memref, but which has no dependence paths to/from
//       'dstNode'.
//    *) Get a computation slice of 'sibLoopNest', which adjusts its loop
//       bounds to be functions of 'dstLoopNest' IVs and symbols.
//    *) Fuse the 'sibLoopNest' computation slice into the 'dstLoopNest',
//       at a loop depth determined by the cost model in 'isFusionProfitable'.
//       This function also checks that the memref write region of
//       'sibLoopNest' is preserved in the fused loop nest.
//    *) Update graph state to reflect the fusion of 'sibNode' into 'dstNode'.
//
// Given a graph where top-level operations are vertices in the set 'V' and
// edges in the set 'E' are dependences between vertices, this algorithm
// takes O(V) time for initialization, and has runtime O(V + E).
//
// This greedy algorithm is not 'maximal' due to the current restriction of
// fusing along single producer-consumer edges, but there is a TODO to fix
// this.
//
// TODO: Experiment with other fusion policies.
struct GreedyFusion {
public:
  // The data dependence graph to traverse during fusion.
  MemRefDependenceGraph *mdg;
  // Worklist of graph nodes visited during the fusion pass.
  SmallVector<unsigned, 8> worklist;
  // Parameter for local buffer size threshold.
  unsigned localBufSizeThreshold;
  // Parameter for fast memory space.
  Optional<unsigned> fastMemorySpace;
  // If true, ignore any additional (redundant) computation tolerance threshold
  // that would have prevented fusion.
  bool maximalFusion;
  // The amount of additional computation that is tolerated while fusing
  // pair-wise as a fraction of the total computation.
  double computeToleranceThreshold;

  using Node = MemRefDependenceGraph::Node;

  GreedyFusion(MemRefDependenceGraph *mdg, unsigned localBufSizeThreshold,
               Optional<unsigned> fastMemorySpace, bool maximalFusion,
               double computeToleranceThreshold)
      : mdg(mdg), localBufSizeThreshold(localBufSizeThreshold),
        fastMemorySpace(fastMemorySpace), maximalFusion(maximalFusion),
        computeToleranceThreshold(computeToleranceThreshold) {}

  /// Initializes 'worklist' with nodes from 'mdg'.
  void init() {
    // TODO: Add a priority queue for prioritizing nodes by different
    // metrics (e.g. arithmetic intensity/flops-to-bytes ratio).
    worklist.clear();
    for (auto &idAndNode : mdg->nodes) {
      const Node &node = idAndNode.second;
      worklist.push_back(node.id);
    }
  }

  /// Run only sibling fusion on the `mdg`.
  void runSiblingFusionOnly() {
    fuseSiblingNodes();
    eraseUnusedMemRefAllocations();
  }

  /// Run only producer/consumer fusion on the `mdg`.
  void runProducerConsumerFusionOnly() {
    fuseProducerConsumerNodes(
        /*maxSrcUserCount=*/std::numeric_limits<unsigned>::max());
    eraseUnusedMemRefAllocations();
  }

  // Run the GreedyFusion pass.
  // *) First pass through the nodes fuses single-use producer nodes into their
  //    unique consumer.
  // *) Second pass fuses sibling nodes which share no dependence edges.
  // *) Third pass fuses any remaining producer nodes into their users.
  void runGreedyFusion() {
    // TODO: Run this repeatedly until a fixed-point is reached.
    fuseProducerConsumerNodes(/*maxSrcUserCount=*/1);
    fuseSiblingNodes();
    fuseProducerConsumerNodes(
        /*maxSrcUserCount=*/std::numeric_limits<unsigned>::max());
    eraseUnusedMemRefAllocations();
  }

  void fuseProducerConsumerNodes(unsigned maxSrcUserCount) {
    LLVM_DEBUG(llvm::dbgs() << "--- Producer/Consumer Fusion ---\n");
    init();
    while (!worklist.empty()) {
      unsigned dstId = worklist.back();
      worklist.pop_back();

      // Skip if this node was removed (fused into another node).
      if (mdg->nodes.count(dstId) == 0)
        continue;
      // Get 'dstNode' into which to attempt fusion.
      auto *dstNode = mdg->getNode(dstId);
      // Skip if 'dstNode' is not a loop nest.
      if (!isa<AffineForOp>(dstNode->op))
        continue;
      // Skip if 'dstNode' is a loop nest returning values.
      // TODO: support loop nests that return values.
      if (dstNode->op->getNumResults() > 0)
        continue;

      LLVM_DEBUG(llvm::dbgs() << "Evaluating dst loop " << dstId << "\n");

      // Sink sequential loops in 'dstNode' (and thus raise parallel loops)
      // while preserving relative order. This can increase the maximum loop
      // depth at which we can fuse a slice of a producer loop nest into a
      // consumer loop nest.
      sinkSequentialLoops(dstNode);
      auto dstAffineForOp = cast<AffineForOp>(dstNode->op);

      // Try to fuse 'dstNode' with candidate producer loops until a fixed
      // point is reached. Fusing two loops may expose new fusion
      // opportunities.
      bool dstNodeChanged;
      do {
        // Gather src loop candidates for 'dstNode' and visit them in "quasi"
        // reverse program order to minimize the number of iterations needed to
        // reach the fixed point. Note that this is a best effort approach
        // since 'getProducerCandidates' does not always guarantee program
        // order in 'srcIdCandidates'.
        dstNodeChanged = false;
        SmallVector<unsigned, 16> srcIdCandidates;
        getProducerCandidates(dstId, mdg, srcIdCandidates);

        for (unsigned srcId : llvm::reverse(srcIdCandidates)) {
          // Get 'srcNode' from which to attempt fusion into 'dstNode'.
          auto *srcNode = mdg->getNode(srcId);
          auto srcAffineForOp = cast<AffineForOp>(srcNode->op);
          LLVM_DEBUG(llvm::dbgs() << "Evaluating src loop " << srcId
                                  << " for dst loop " << dstId << "\n");

          // Skip if 'srcNode' is a loop nest returning values.
          // TODO: support loop nests that return values.
          if (isa<AffineForOp>(srcNode->op) && srcNode->op->getNumResults() > 0)
            continue;

          DenseSet<Value> producerConsumerMemrefs;
          gatherProducerConsumerMemrefs(srcId, dstId, mdg,
                                        producerConsumerMemrefs);

          // Skip if the 'srcNode' out edge count on any memref is greater
          // than 'maxSrcUserCount'.
          if (any_of(producerConsumerMemrefs, [&](Value memref) {
                return mdg->getOutEdgeCount(srcNode->id, memref) >
                       maxSrcUserCount;
              }))
            continue;

          // Gather memrefs in 'srcNode' that are written and escape to the
          // function (e.g., memref function arguments, returned memrefs,
          // memrefs passed to function calls, etc.).
          DenseSet<Value> srcEscapingMemRefs;
          gatherEscapingMemrefs(srcNode->id, mdg, srcEscapingMemRefs);

          // Skip if there are non-affine operations in between 'srcNode' and
          // 'dstNode' using their memrefs. If so, we wouldn't be able to
          // compute a legal insertion point for now. 'srcNode' and 'dstNode'
          // memrefs with non-affine operation users would be considered
          // escaping memrefs, so we can limit this check to scenarios with
          // escaping memrefs.
          if (!srcEscapingMemRefs.empty() &&
              hasNonAffineUsersOnThePath(srcId, dstId, mdg)) {
            LLVM_DEBUG(llvm::dbgs() << "Can't fuse: non-affine users in "
                                       "between the loops\n");
            continue;
          }

          // Compute an operation list insertion point for the fused loop
          // nest which preserves dependences.
          Operation *fusedLoopInsPoint =
              mdg->getFusedLoopNestInsertionPoint(srcNode->id, dstNode->id);
          if (fusedLoopInsPoint == nullptr)
            continue;

          // Compute the innermost common loop depth for dstNode
          // producer-consumer loads/stores.
          SmallVector<Operation *, 2> dstMemrefOps;
          for (Operation *op : dstNode->loads)
            if (producerConsumerMemrefs.count(
                    cast<AffineReadOpInterface>(op).getMemRef()) > 0)
              dstMemrefOps.push_back(op);
          for (Operation *op : dstNode->stores)
            if (producerConsumerMemrefs.count(
                    cast<AffineWriteOpInterface>(op).getMemRef()))
              dstMemrefOps.push_back(op);
          unsigned dstLoopDepthTest = getInnermostCommonLoopDepth(dstMemrefOps);

          // Check the feasibility of fusing the src loop nest into the dst
          // loop nest at loop depths in range [1, dstLoopDepthTest].
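          // For instance (an illustrative sketch, not a specific test case),
          // with a producer storing to %A in a 1-d loop and the consumer
          // loading %A inside a 2-d nest:
          //
          //   affine.for %i = 0 to 100 {
          //     affine.store %v, %A[%i] : memref<100xf32>
          //   }
          //   affine.for %i = 0 to 100 {
          //     affine.for %j = 0 to 50 {
          //       %x = affine.load %A[%i] : memref<100xf32>
          //     }
          //   }
          //
          // the consumer load sits at depth 2, so 'dstLoopDepthTest' is 2,
          // and both depth 1 (before the %j loop) and depth 2 (inside it)
          // are checked for a legal producer slice insertion.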
          unsigned maxLegalFusionDepth = 0;
          SmallVector<ComputationSliceState, 8> depthSliceUnions;
          depthSliceUnions.resize(dstLoopDepthTest);
          FusionStrategy strategy(FusionStrategy::ProducerConsumer);
          for (unsigned i = 1; i <= dstLoopDepthTest; ++i) {
            FusionResult result =
                mlir::canFuseLoops(srcAffineForOp, dstAffineForOp,
                                   /*dstLoopDepth=*/i,
                                   &depthSliceUnions[i - 1], strategy);
            if (result.value == FusionResult::Success)
              maxLegalFusionDepth = i;
          }

          if (maxLegalFusionDepth == 0) {
            LLVM_DEBUG(llvm::dbgs()
                       << "Can't fuse: fusion is not legal at any depth\n");
            continue;
          }

          // Check if fusion would be profitable. We skip profitability
          // analysis for maximal fusion since we already know the maximal
          // legal depth to fuse.
          unsigned bestDstLoopDepth = maxLegalFusionDepth;
          if (!maximalFusion) {
            // Retrieve producer stores from the src loop.
            SmallVector<Operation *, 2> producerStores;
            for (Operation *op : srcNode->stores)
              if (producerConsumerMemrefs.count(
                      cast<AffineWriteOpInterface>(op).getMemRef()))
                producerStores.push_back(op);

            // TODO: Support multiple producer stores in profitability
            // analysis. We limit profitability analysis to scenarios with a
            // single producer store for now. Note that some multi-store
            // producer scenarios will still go through profitability
            // analysis if only one of the stores is involved in the
            // producer-consumer relationship of the candidate loops.
            assert(!producerStores.empty() && "Expected producer store");
            if (producerStores.size() > 1)
              LLVM_DEBUG(llvm::dbgs() << "Skipping profitability analysis. "
                                         "Not supported for this case\n");
            else if (!isFusionProfitable(producerStores[0], producerStores[0],
                                         dstAffineForOp, depthSliceUnions,
                                         maxLegalFusionDepth, &bestDstLoopDepth,
                                         computeToleranceThreshold))
              continue;
          }

          assert(bestDstLoopDepth > 0 && "Unexpected loop fusion depth");
          ComputationSliceState &bestSlice =
              depthSliceUnions[bestDstLoopDepth - 1];
          assert(!bestSlice.isEmpty() && "Missing slice union for depth");

          // Determine if 'srcId' can be removed after fusion, taking into
          // account remaining dependences, escaping memrefs and the fusion
          // insertion point.
          bool removeSrcNode = canRemoveSrcNodeAfterFusion(
              srcId, dstId, bestSlice, fusedLoopInsPoint, srcEscapingMemRefs,
              mdg);

          DenseSet<Value> privateMemrefs;
          for (Value memref : producerConsumerMemrefs) {
            // If `memref` is an escaping one, do not create a private memref
            // for the below scenarios, since doing so will leave the escaping
            // memref unmodified as all the writes originally meant for the
            // escaping memref would be performed on the private memref:
            // 1. The source is to be removed after fusion,
            //    OR
            // 2. The destination writes to `memref`.
            if (srcEscapingMemRefs.count(memref) > 0 &&
                (removeSrcNode || dstNode->getStoreOpCount(memref) > 0))
              continue;

            // Don't create a private memref if 'srcNode' has in edges on
            // 'memref' or 'dstNode' has out edges on 'memref'.
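            // For example, if a third loop nest writes 'memref' before
            // 'srcNode' accesses it (an in edge), or reads or writes 'memref'
            // after 'dstNode' (an out edge), those dependences must continue
            // to flow through the original buffer, so privatizing it here
            // would not be safe.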
            if (mdg->getIncomingMemRefAccesses(srcId, memref) > 0 ||
                mdg->getOutEdgeCount(dstId, memref) > 0)
              continue;

            // If 'srcNode' will be removed but it has out edges on 'memref'
            // to nodes other than 'dstNode', we have to preserve dependences
            // and cannot create a private memref.
            if (removeSrcNode &&
                any_of(mdg->outEdges[srcId], [&](const auto &edge) {
                  return edge.value == memref && edge.id != dstId;
                }))
              continue;

            // Create a private version of this memref.
            privateMemrefs.insert(memref);
          }

          // Fuse the computation slice of 'srcLoopNest' into 'dstLoopNest'.
          fuseLoops(srcAffineForOp, dstAffineForOp, bestSlice);
          dstNodeChanged = true;

          LLVM_DEBUG(llvm::dbgs()
                     << "Fused src loop " << srcId << " into dst loop " << dstId
                     << " at depth " << bestDstLoopDepth << ":\n"
                     << dstAffineForOp << "\n");

          // Move 'dstAffineForOp' before 'fusedLoopInsPoint' if needed.
          if (fusedLoopInsPoint != dstAffineForOp.getOperation())
            dstAffineForOp.getOperation()->moveBefore(fusedLoopInsPoint);

          // Update edges between 'srcNode' and 'dstNode'.
          mdg->updateEdges(srcNode->id, dstNode->id, privateMemrefs,
                           removeSrcNode);

          // Create private memrefs.
          if (!privateMemrefs.empty()) {
            // Gather stores for all the private-to-be memrefs.
            DenseMap<Value, SmallVector<Operation *, 4>> privateMemRefToStores;
            dstAffineForOp.walk([&](AffineWriteOpInterface storeOp) {
              Value storeMemRef = storeOp.getMemRef();
              if (privateMemrefs.count(storeMemRef) > 0)
                privateMemRefToStores[storeMemRef].push_back(
                    storeOp.getOperation());
            });

            // Replace original memrefs with private memrefs. Note that all
            // the loads and stores on these memrefs will be replaced with new
            // loads and stores. Any reference to the original ones becomes
            // invalid after this point.
            for (auto &memrefToStoresPair : privateMemRefToStores) {
              // TODO: Use union of memref write regions to compute
              // private memref footprint.
              SmallVector<Operation *, 4> &storesForMemref =
                  memrefToStoresPair.second;
              Value newMemRef = createPrivateMemRef(
                  dstAffineForOp, storesForMemref[0], bestDstLoopDepth,
                  fastMemorySpace, localBufSizeThreshold);
              // Create a new node in the dependence graph for the 'newMemRef'
              // alloc op.
              unsigned newMemRefNodeId =
                  mdg->addNode(newMemRef.getDefiningOp());
              // Add an edge from the 'newMemRef' node to 'dstNode'.
              mdg->addEdge(newMemRefNodeId, dstId, newMemRef);
            }
            // One or more entries for the 'newMemRef' alloc op are inserted
            // into the DenseMap 'mdg->nodes'. Since an insertion may cause
            // the DenseMap to reallocate, update 'dstNode'.
            dstNode = mdg->getNode(dstId);
          }

          // Collect dst loop stats after the memref privatization
          // transformation.
          LoopNestStateCollector dstLoopCollector;
          dstLoopCollector.collect(dstAffineForOp.getOperation());

          // Clear and add back loads and stores.
          mdg->clearNodeLoadAndStores(dstNode->id);
          mdg->addToNode(dstId, dstLoopCollector.loadOpInsts,
                         dstLoopCollector.storeOpInsts);

          if (removeSrcNode) {
            LLVM_DEBUG(llvm::dbgs()
                       << "Removing src loop " << srcId << " after fusion\n");
            // 'srcNode' is no longer valid after it is removed from the mdg.
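            // Removal is safe here: 'canRemoveSrcNodeAfterFusion' already
            // established that no remaining dependences or escaping memrefs
            // require the original src loop nest.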
            srcAffineForOp.erase();
            mdg->removeNode(srcId);
            srcNode = nullptr;
          }
        }
      } while (dstNodeChanged);
    }
  }

  // Visits each node in the graph and, for each node, attempts to fuse it
  // with its sibling nodes (nodes which share a parent, but no dependence
  // edges).
  void fuseSiblingNodes() {
    LLVM_DEBUG(llvm::dbgs() << "--- Sibling Fusion ---\n");
    init();
    while (!worklist.empty()) {
      unsigned dstId = worklist.back();
      worklist.pop_back();

      // Skip if this node was removed (fused into another node).
      if (mdg->nodes.count(dstId) == 0)
        continue;
      // Get 'dstNode' into which to attempt fusion.
      auto *dstNode = mdg->getNode(dstId);
      // Skip if 'dstNode' is not a loop nest.
      if (!isa<AffineForOp>(dstNode->op))
        continue;
      // Attempt to fuse 'dstNode' with its sibling nodes in the graph.
      fuseWithSiblingNodes(dstNode);
    }
  }

  // Attempts to fuse 'dstNode' with sibling nodes in the graph.
  void fuseWithSiblingNodes(Node *dstNode) {
    DenseSet<unsigned> visitedSibNodeIds;
    std::pair<unsigned, Value> idAndMemref;
    auto dstAffineForOp = cast<AffineForOp>(dstNode->op);

    while (findSiblingNodeToFuse(dstNode, &visitedSibNodeIds, &idAndMemref)) {
      unsigned sibId = idAndMemref.first;
      Value memref = idAndMemref.second;
      // TODO: Check that 'sibStoreOpInst' post-dominates all other
      // stores to the same memref in the 'sibNode' loop nest.
      auto *sibNode = mdg->getNode(sibId);
      // Compute an operation list insertion point for the fused loop
      // nest which preserves dependences.
      assert(sibNode->op->getBlock() == dstNode->op->getBlock());
      Operation *insertPointInst =
          sibNode->op->isBeforeInBlock(dstNode->op)
              ? mdg->getFusedLoopNestInsertionPoint(sibNode->id, dstNode->id)
              : mdg->getFusedLoopNestInsertionPoint(dstNode->id, sibNode->id);
      if (insertPointInst == nullptr)
        continue;

      // Check if fusion would be profitable and at what depth.

      // Get the unique 'sibNode' load op on 'memref'.
      SmallVector<Operation *, 2> sibLoadOpInsts;
      sibNode->getLoadOpsForMemref(memref, &sibLoadOpInsts);
      // Currently findSiblingNodeToFuse searches for siblings with one load.
      assert(sibLoadOpInsts.size() == 1);
      Operation *sibLoadOpInst = sibLoadOpInsts[0];
      assert(!sibNode->stores.empty());
      // TODO: Choose the store which postdominates all other stores.
      auto *sibStoreOpInst = sibNode->stores.back();

      // Gather 'dstNode' load ops on 'memref'.
      SmallVector<Operation *, 2> dstLoadOpInsts;
      dstNode->getLoadOpsForMemref(memref, &dstLoadOpInsts);

      SmallVector<AffineForOp, 4> dstLoopIVs;
      getLoopIVs(*dstLoadOpInsts[0], &dstLoopIVs);
      unsigned dstLoopDepthTest = dstLoopIVs.size();
      auto sibAffineForOp = cast<AffineForOp>(sibNode->op);

      // Compute the loop depth and slice union for fusion.
      SmallVector<ComputationSliceState, 8> depthSliceUnions;
      depthSliceUnions.resize(dstLoopDepthTest);
      unsigned maxLegalFusionDepth = 0;
      FusionStrategy strategy(memref);
      for (unsigned i = 1; i <= dstLoopDepthTest; ++i) {
        FusionResult result =
            mlir::canFuseLoops(sibAffineForOp, dstAffineForOp,
                               /*dstLoopDepth=*/i, &depthSliceUnions[i - 1],
                               strategy);
        if (result.value == FusionResult::Success)
          maxLegalFusionDepth = i;
      }
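
      // Note that legality above is evaluated under the sibling fusion
      // strategy keyed on the shared 'memref': sibling nodes share no
      // dependence edges, so fusion only needs to preserve each nest's own
      // semantics while enabling reuse of loads from 'memref'.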
      // Skip if fusion is not feasible at any loop depth.
      if (maxLegalFusionDepth == 0)
        continue;

      unsigned bestDstLoopDepth = maxLegalFusionDepth;
      if (!maximalFusion) {
        // Check if fusion would be profitable.
        if (!isFusionProfitable(sibLoadOpInst, sibStoreOpInst, dstAffineForOp,
                                depthSliceUnions, maxLegalFusionDepth,
                                &bestDstLoopDepth, computeToleranceThreshold))
          continue;
      }

      assert(bestDstLoopDepth > 0 && "Unexpected loop fusion depth");
      assert(!depthSliceUnions[bestDstLoopDepth - 1].isEmpty() &&
             "Fusion depth has no computed slice union");
      // Check if the source loop is being inserted at the innermost
      // destination loop depth. Based on this, the fused loop may be
      // optimized further inside `fuseLoops`.
      bool isInnermostInsertion = (bestDstLoopDepth == dstLoopDepthTest);
      // Fuse the computation slice of 'sibLoopNest' into 'dstLoopNest'.
      mlir::fuseLoops(sibAffineForOp, dstAffineForOp,
                      depthSliceUnions[bestDstLoopDepth - 1],
                      isInnermostInsertion);

      auto dstForInst = cast<AffineForOp>(dstNode->op);
      // Update the operation position of the fused loop nest (if needed).
      if (insertPointInst != dstForInst.getOperation())
        dstForInst->moveBefore(insertPointInst);

      // Update data dependence graph state post fusion.
      updateStateAfterSiblingFusion(sibNode, dstNode);
    }
  }

  // Searches function argument uses and the graph from 'dstNode' looking for
  // a fusion candidate sibling node which shares no dependences with
  // 'dstNode' but which loads from the same memref. Returns true and sets
  // 'idAndMemrefToFuse' on success. Returns false otherwise.
  bool findSiblingNodeToFuse(Node *dstNode,
                             DenseSet<unsigned> *visitedSibNodeIds,
                             std::pair<unsigned, Value> *idAndMemrefToFuse) {
    // Returns true if 'sibNode' can be fused with 'dstNode' for input reuse
    // on 'memref'.
    auto canFuseWithSibNode = [&](Node *sibNode, Value memref) {
      // Skip if 'sibNode' has more than one load on 'memref'.
      // TODO: Remove the single load op restriction.
      if (sibNode->getLoadOpCount(memref) != 1)
        return false;
      // Skip if there exists a path of dependent edges between
      // 'sibNode' and 'dstNode'.
      if (mdg->hasDependencePath(sibNode->id, dstNode->id) ||
          mdg->hasDependencePath(dstNode->id, sibNode->id))
        return false;
      // Skip 'sibNode' if it loads from (and stores to) a memref on which
      // it also has an input dependence edge.
      DenseSet<Value> loadAndStoreMemrefSet;
      sibNode->getLoadAndStoreMemrefSet(&loadAndStoreMemrefSet);
      if (llvm::any_of(loadAndStoreMemrefSet, [=](Value memref) {
            return mdg->getIncomingMemRefAccesses(sibNode->id, memref) > 0;
          }))
        return false;

      // Check that all stores are to the same memref.
      DenseSet<Value> storeMemrefs;
      for (auto *storeOpInst : sibNode->stores) {
        storeMemrefs.insert(
            cast<AffineWriteOpInterface>(storeOpInst).getMemRef());
      }
      if (storeMemrefs.size() != 1)
        return false;

      // Skip if a memref value in one node is used by a non-affine memref
      // access that lies between 'dstNode' and 'sibNode'.
      if (hasNonAffineUsersOnThePath(dstNode->id, sibNode->id, mdg) ||
          hasNonAffineUsersOnThePath(sibNode->id, dstNode->id, mdg))
        return false;
      return true;
    };

    // Search for siblings which load the same memref function argument.
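    // An illustrative candidate pair (names are made up): both nests load
    // function argument %arg0 and no dependence path connects them, so they
    // qualify for input-reuse fusion on %arg0:
    //
    //   affine.for %i = 0 to 100 {
    //     %0 = affine.load %arg0[%i] : memref<100xf32>
    //     affine.store %0, %B[%i] : memref<100xf32>
    //   }
    //   affine.for %i = 0 to 100 {
    //     %1 = affine.load %arg0[%i] : memref<100xf32>
    //     affine.store %1, %C[%i] : memref<100xf32>
    //   }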
    auto fn = dstNode->op->getParentOfType<FuncOp>();
    for (unsigned i = 0, e = fn.getNumArguments(); i != e; ++i) {
      for (auto *user : fn.getArgument(i).getUsers()) {
        if (auto loadOp = dyn_cast<AffineReadOpInterface>(user)) {
          // Gather loops surrounding 'use'.
          SmallVector<AffineForOp, 4> loops;
          getLoopIVs(*user, &loops);
          // Skip 'use' if it is not within a loop nest.
          if (loops.empty())
            continue;
          Node *sibNode = mdg->getForOpNode(loops[0]);
          assert(sibNode != nullptr);
          // Skip 'use' if it is not a sibling of 'dstNode'.
          if (sibNode->id == dstNode->id)
            continue;
          // Skip 'use' if it has been visited.
          if (visitedSibNodeIds->count(sibNode->id) > 0)
            continue;
          // Skip 'use' if it does not load from the same memref as 'dstNode'.
          auto memref = loadOp.getMemRef();
          if (dstNode->getLoadOpCount(memref) == 0)
            continue;
          // Check if 'sibNode/dstNode' can be input-reuse fused on 'memref'.
          if (canFuseWithSibNode(sibNode, memref)) {
            visitedSibNodeIds->insert(sibNode->id);
            idAndMemrefToFuse->first = sibNode->id;
            idAndMemrefToFuse->second = memref;
            return true;
          }
        }
      }
    }

    // Search for siblings by following edges through an intermediate src
    // node. Collect candidate 'dstNode' input edges in 'inEdges'.
    SmallVector<MemRefDependenceGraph::Edge, 2> inEdges;
    mdg->forEachMemRefInputEdge(
        dstNode->id, [&](MemRefDependenceGraph::Edge inEdge) {
          // Add 'inEdge' if it is a read-after-write dependence.
          if (dstNode->getLoadOpCount(inEdge.value) > 0 &&
              mdg->getNode(inEdge.id)->getStoreOpCount(inEdge.value) > 0)
            inEdges.push_back(inEdge);
        });

    // Search for sibling nodes to fuse by visiting output edges from each
    // input edge in 'inEdges'.
    for (auto &inEdge : inEdges) {
      // Collect candidate output edges from each node 'inEdge.id' in
      // 'inEdges'.
      SmallVector<MemRefDependenceGraph::Edge, 2> outEdges;
      mdg->forEachMemRefOutputEdge(
          inEdge.id, [&](MemRefDependenceGraph::Edge outEdge) {
            unsigned sibNodeId = outEdge.id;
            if (visitedSibNodeIds->count(sibNodeId) > 0)
              return;
            // Skip the output edge if it is not a sibling using the same
            // memref.
            if (outEdge.id == dstNode->id || outEdge.value != inEdge.value)
              return;
            auto *sibNode = mdg->getNode(sibNodeId);
            if (!isa<AffineForOp>(sibNode->op))
              return;
            // Check if 'sibNode/dstNode' can be input-reuse fused on
            // 'memref'.
            if (canFuseWithSibNode(sibNode, outEdge.value)) {
              // Add candidate 'outEdge' to the sibling node.
              outEdges.push_back(outEdge);
            }
          });

      // Add the first candidate if any were returned.
      if (!outEdges.empty()) {
        visitedSibNodeIds->insert(outEdges[0].id);
        idAndMemrefToFuse->first = outEdges[0].id;
        idAndMemrefToFuse->second = outEdges[0].value;
        return true;
      }
    }
    return false;
  }

  /// Updates data dependence graph state to reflect sibling fusion of
  /// 'sibNode' into 'dstNode'.
  void updateStateAfterSiblingFusion(Node *sibNode, Node *dstNode) {
    // Update 'sibNode' and 'dstNode' input/output edges to reflect fusion.
    mdg->updateEdges(sibNode->id, dstNode->id);

    // Collect dst loop stats after fusion.
    auto dstForInst = cast<AffineForOp>(dstNode->op);
    LoopNestStateCollector dstLoopCollector;
    dstLoopCollector.collect(dstForInst.getOperation());
    // Clear and add back loads and stores.
    mdg->clearNodeLoadAndStores(dstNode->id);
    mdg->addToNode(dstNode->id, dstLoopCollector.loadOpInsts,
                   dstLoopCollector.storeOpInsts);
    // Remove the old sibling loop nest if it no longer has outgoing
    // dependence edges and it does not write to a memref which escapes the
    // function.
    if (mdg->getOutEdgeCount(sibNode->id) == 0) {
      mdg->removeNode(sibNode->id);
      sibNode->op->erase();
    }
  }

  // Clean up any allocs with no users.
  void eraseUnusedMemRefAllocations() {
    for (auto &pair : mdg->memrefEdgeCount) {
      if (pair.second > 0)
        continue;
      auto memref = pair.first;
      // Skip if there exist other uses (return operation or function calls).
      if (!memref.use_empty())
        continue;
      // The use list is expected to match the dep graph info.
      auto *op = memref.getDefiningOp();
      if (isa_and_nonnull<memref::AllocOp>(op))
        op->erase();
    }
  }
};

} // namespace

void LoopFusion::runOnOperation() {
  MemRefDependenceGraph g;
  if (!g.init(getOperation()))
    return;

  Optional<unsigned> fastMemorySpaceOpt;
  if (fastMemorySpace.hasValue())
    fastMemorySpaceOpt = fastMemorySpace;
  unsigned localBufSizeThresholdBytes = localBufSizeThreshold * 1024;
  GreedyFusion fusion(&g, localBufSizeThresholdBytes, fastMemorySpaceOpt,
                      maximalFusion, computeToleranceThreshold);

  if (affineFusionMode == FusionMode::ProducerConsumer)
    fusion.runProducerConsumerFusionOnly();
  else if (affineFusionMode == FusionMode::Sibling)
    fusion.runSiblingFusionOnly();
  else
    fusion.runGreedyFusion();
}
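
// Usage sketch (flag and option names are assumed from this pass's
// registration at the time of writing; verify against `mlir-opt --help`
// for your build):
//
//   mlir-opt input.mlir -affine-loop-fusion
//   mlir-opt input.mlir -affine-loop-fusion="fusion-maximal=true"
//
// The pass's fusion mode option mirrors the FusionMode dispatch in
// runOnOperation() above: greedy (the default), producer_consumer, or
// sibling.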