//===- LoopFusion.cpp - Code to perform loop fusion -----------------------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
//
// This file implements loop fusion.
//
//===----------------------------------------------------------------------===//

#include "PassDetail.h"
#include "mlir/Dialect/Affine/Analysis/AffineAnalysis.h"
#include "mlir/Dialect/Affine/Analysis/AffineStructures.h"
#include "mlir/Dialect/Affine/Analysis/LoopAnalysis.h"
#include "mlir/Dialect/Affine/Analysis/Utils.h"
#include "mlir/Dialect/Affine/IR/AffineOps.h"
#include "mlir/Dialect/Affine/LoopFusionUtils.h"
#include "mlir/Dialect/Affine/LoopUtils.h"
#include "mlir/Dialect/Affine/Utils.h"
#include "mlir/Dialect/MemRef/IR/MemRef.h"
#include "mlir/IR/AffineExpr.h"
#include "mlir/IR/AffineMap.h"
#include "mlir/IR/Builders.h"
#include "mlir/Transforms/Passes.h"
#include "llvm/ADT/DenseMap.h"
#include "llvm/ADT/DenseSet.h"
#include "llvm/ADT/STLExtras.h"
#include "llvm/ADT/SetVector.h"
#include "llvm/Support/CommandLine.h"
#include "llvm/Support/Debug.h"
#include "llvm/Support/raw_ostream.h"
#include <iomanip>
#include <sstream>

#define DEBUG_TYPE "affine-loop-fusion"

using namespace mlir;

namespace {
/// Loop fusion pass. This pass currently supports a greedy fusion policy,
/// which fuses loop nests with single-writer/single-reader memref dependences
/// with the goal of improving locality.

// TODO: Support fusion of source loop nests which write to multiple
// memrefs, where each memref can have multiple users (if profitable).
// TODO: Extend this pass to check for fusion preventing dependences,
// and add support for more general loop fusion algorithms.

struct LoopFusion : public AffineLoopFusionBase<LoopFusion> {
  LoopFusion() = default;
  LoopFusion(unsigned fastMemorySpace, uint64_t localBufSizeThresholdBytes,
             bool maximalFusion, enum FusionMode affineFusionMode) {
    this->fastMemorySpace = fastMemorySpace;
    this->localBufSizeThreshold = localBufSizeThresholdBytes / 1024;
    this->maximalFusion = maximalFusion;
    this->affineFusionMode = affineFusionMode;
  }

  void runOnOperation() override;
};

} // namespace

std::unique_ptr<OperationPass<func::FuncOp>>
mlir::createLoopFusionPass(unsigned fastMemorySpace,
                           uint64_t localBufSizeThreshold, bool maximalFusion,
                           enum FusionMode affineFusionMode) {
  return std::make_unique<LoopFusion>(fastMemorySpace, localBufSizeThreshold,
                                      maximalFusion, affineFusionMode);
}
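// A minimal usage sketch (illustrative; only createLoopFusionPass above is
// defined in this file, the pass manager setup is standard MLIR):
//
//   PassManager pm(module.getContext());
//   pm.addNestedPass<func::FuncOp>(mlir::createLoopFusionPass(
//       /*fastMemorySpace=*/0, /*localBufSizeThreshold=*/0,
//       /*maximalFusion=*/false, FusionMode::Greedy));
//   if (failed(pm.run(module)))
//     module.emitError("loop fusion failed");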
namespace {

// LoopNestStateCollector walks loop nests and collects load and store
// operations, and whether or not a region holding op other than ForOp and IfOp
// was encountered in the loop nest.
struct LoopNestStateCollector {
  SmallVector<AffineForOp, 4> forOps;
  SmallVector<Operation *, 4> loadOpInsts;
  SmallVector<Operation *, 4> storeOpInsts;
  bool hasNonAffineRegionOp = false;

  void collect(Operation *opToWalk) {
    opToWalk->walk([&](Operation *op) {
      if (isa<AffineForOp>(op))
        forOps.push_back(cast<AffineForOp>(op));
      else if (op->getNumRegions() != 0 && !isa<AffineIfOp>(op))
        hasNonAffineRegionOp = true;
      else if (isa<AffineReadOpInterface>(op))
        loadOpInsts.push_back(op);
      else if (isa<AffineWriteOpInterface>(op))
        storeOpInsts.push_back(op);
    });
  }
};

// MemRefDependenceGraph is a graph data structure where graph nodes are
// top-level operations in a FuncOp which contain load/store ops, and edges
// are memref dependences between the nodes.
// TODO: Add a more flexible dependence graph representation.
// TODO: Add a depth parameter to dependence graph construction.
struct MemRefDependenceGraph {
public:
  // Node represents a node in the graph. A Node is either an entire loop nest
  // rooted at the top level which contains loads/stores, or a top level
  // load/store.
  struct Node {
    // The unique identifier of this node in the graph.
    unsigned id;
    // The top-level statement which is (or contains) a load/store.
    Operation *op;
    // List of load operations.
    SmallVector<Operation *, 4> loads;
    // List of store op insts.
    SmallVector<Operation *, 4> stores;
    Node(unsigned id, Operation *op) : id(id), op(op) {}

    // Returns the load op count for 'memref'.
    unsigned getLoadOpCount(Value memref) {
      unsigned loadOpCount = 0;
      for (auto *loadOpInst : loads) {
        if (memref == cast<AffineReadOpInterface>(loadOpInst).getMemRef())
          ++loadOpCount;
      }
      return loadOpCount;
    }

    // Returns the store op count for 'memref'.
    unsigned getStoreOpCount(Value memref) {
      unsigned storeOpCount = 0;
      for (auto *storeOpInst : stores) {
        if (memref == cast<AffineWriteOpInterface>(storeOpInst).getMemRef())
          ++storeOpCount;
      }
      return storeOpCount;
    }

    // Returns all store ops in 'storeOps' which access 'memref'.
    void getStoreOpsForMemref(Value memref,
                              SmallVectorImpl<Operation *> *storeOps) {
      for (auto *storeOpInst : stores) {
        if (memref == cast<AffineWriteOpInterface>(storeOpInst).getMemRef())
          storeOps->push_back(storeOpInst);
      }
    }

    // Returns all load ops in 'loadOps' which access 'memref'.
    void getLoadOpsForMemref(Value memref,
                             SmallVectorImpl<Operation *> *loadOps) {
      for (auto *loadOpInst : loads) {
        if (memref == cast<AffineReadOpInterface>(loadOpInst).getMemRef())
          loadOps->push_back(loadOpInst);
      }
    }

    // Returns all memrefs in 'loadAndStoreMemrefSet' for which this node
    // has at least one load and store operation.
    void getLoadAndStoreMemrefSet(DenseSet<Value> *loadAndStoreMemrefSet) {
      llvm::SmallDenseSet<Value, 2> loadMemrefs;
      for (auto *loadOpInst : loads) {
        loadMemrefs.insert(cast<AffineReadOpInterface>(loadOpInst).getMemRef());
      }
      for (auto *storeOpInst : stores) {
        auto memref = cast<AffineWriteOpInterface>(storeOpInst).getMemRef();
        if (loadMemrefs.count(memref) > 0)
          loadAndStoreMemrefSet->insert(memref);
      }
    }
  };
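  // For example (illustrative IR), each top-level 'affine.for' below becomes
  // one Node: node 0 holds the store to %m, node 1 holds the load from %m,
  // and a memref dependence edge on %m connects node 0 to node 1:
  //
  //   %m = memref.alloc() : memref<10xf32>
  //   affine.for %i = 0 to 10 {
  //     affine.store %cst, %m[%i] : memref<10xf32>
  //   }
  //   affine.for %i = 0 to 10 {
  //     %v = affine.load %m[%i] : memref<10xf32>
  //   }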
  // Edge represents a data dependence between nodes in the graph.
  struct Edge {
    // The id of the node at the other end of the edge.
    // If this edge is stored in Edge = Node.inEdges[i], then
    // 'Node.inEdges[i].id' is the identifier of the source node of the edge.
    // If this edge is stored in Edge = Node.outEdges[i], then
    // 'Node.outEdges[i].id' is the identifier of the dest node of the edge.
    unsigned id;
    // The SSA value on which this edge represents a dependence.
    // If the value is a memref, then the dependence is between graph nodes
    // which contain accesses to the same memref 'value'. If the value is a
    // non-memref value, then the dependence is between a graph node which
    // defines an SSA value and another graph node which uses the SSA value
    // (e.g. a constant or load operation defining a value which is used inside
    // a loop nest).
    Value value;
  };

  // Map from node id to Node.
  DenseMap<unsigned, Node> nodes;
  // Map from node id to list of input edges.
  DenseMap<unsigned, SmallVector<Edge, 2>> inEdges;
  // Map from node id to list of output edges.
  DenseMap<unsigned, SmallVector<Edge, 2>> outEdges;
  // Map from memref to a count on the dependence edges associated with that
  // memref.
  DenseMap<Value, unsigned> memrefEdgeCount;
  // The next unique identifier to use for newly created graph nodes.
  unsigned nextNodeId = 0;

  MemRefDependenceGraph() = default;

  // Initializes the dependence graph based on operations in 'f'.
  // Returns true on success, false otherwise.
  bool init(func::FuncOp f);

  // Returns the graph node for 'id'.
  Node *getNode(unsigned id) {
    auto it = nodes.find(id);
    assert(it != nodes.end());
    return &it->second;
  }

  // Returns the graph node for 'forOp'.
  Node *getForOpNode(AffineForOp forOp) {
    for (auto &idAndNode : nodes)
      if (idAndNode.second.op == forOp.getOperation())
        return &idAndNode.second;
    return nullptr;
  }

  // Adds a node with 'op' to the graph and returns its unique identifier.
  unsigned addNode(Operation *op) {
    Node node(nextNodeId++, op);
    nodes.insert({node.id, node});
    return node.id;
  }

  // Remove node 'id' (and its associated edges) from graph.
  void removeNode(unsigned id) {
    // Remove each edge in 'inEdges[id]'.
    if (inEdges.count(id) > 0) {
      SmallVector<Edge, 2> oldInEdges = inEdges[id];
      for (auto &inEdge : oldInEdges) {
        removeEdge(inEdge.id, id, inEdge.value);
      }
    }
    // Remove each edge in 'outEdges[id]'.
    if (outEdges.count(id) > 0) {
      SmallVector<Edge, 2> oldOutEdges = outEdges[id];
      for (auto &outEdge : oldOutEdges) {
        removeEdge(id, outEdge.id, outEdge.value);
      }
    }
    // Erase remaining node state.
    inEdges.erase(id);
    outEdges.erase(id);
    nodes.erase(id);
  }
  // Returns true if node 'id' writes to any memref which escapes (or is an
  // argument to) the function/block. Returns false otherwise.
  bool writesToLiveInOrEscapingMemrefs(unsigned id) {
    Node *node = getNode(id);
    for (auto *storeOpInst : node->stores) {
      auto memref = cast<AffineWriteOpInterface>(storeOpInst).getMemRef();
      auto *op = memref.getDefiningOp();
      // Return true if 'memref' is a block argument.
      if (!op)
        return true;
      // Return true if any use of 'memref' escapes the function.
      for (auto *user : memref.getUsers())
        if (!isa<AffineMapAccessInterface>(*user))
          return true;
    }
    return false;
  }

  // Returns true iff there is an edge from node 'srcId' to node 'dstId' which
  // is for 'value' if non-null, or for any value otherwise. Returns false
  // otherwise.
  bool hasEdge(unsigned srcId, unsigned dstId, Value value = nullptr) {
    if (outEdges.count(srcId) == 0 || inEdges.count(dstId) == 0) {
      return false;
    }
    bool hasOutEdge = llvm::any_of(outEdges[srcId], [=](Edge &edge) {
      return edge.id == dstId && (!value || edge.value == value);
    });
    bool hasInEdge = llvm::any_of(inEdges[dstId], [=](Edge &edge) {
      return edge.id == srcId && (!value || edge.value == value);
    });
    return hasOutEdge && hasInEdge;
  }

  // Adds an edge from node 'srcId' to node 'dstId' for 'value'.
  void addEdge(unsigned srcId, unsigned dstId, Value value) {
    if (!hasEdge(srcId, dstId, value)) {
      outEdges[srcId].push_back({dstId, value});
      inEdges[dstId].push_back({srcId, value});
      if (value.getType().isa<MemRefType>())
        memrefEdgeCount[value]++;
    }
  }

  // Removes an edge from node 'srcId' to node 'dstId' for 'value'.
  void removeEdge(unsigned srcId, unsigned dstId, Value value) {
    assert(inEdges.count(dstId) > 0);
    assert(outEdges.count(srcId) > 0);
    if (value.getType().isa<MemRefType>()) {
      assert(memrefEdgeCount.count(value) > 0);
      memrefEdgeCount[value]--;
    }
    // Remove 'srcId' from 'inEdges[dstId]'.
    for (auto *it = inEdges[dstId].begin(); it != inEdges[dstId].end(); ++it) {
      if ((*it).id == srcId && (*it).value == value) {
        inEdges[dstId].erase(it);
        break;
      }
    }
    // Remove 'dstId' from 'outEdges[srcId]'.
    for (auto *it = outEdges[srcId].begin(); it != outEdges[srcId].end();
         ++it) {
      if ((*it).id == dstId && (*it).value == value) {
        outEdges[srcId].erase(it);
        break;
      }
    }
  }

  // Returns true if there is a path in the dependence graph from node 'srcId'
  // to node 'dstId'. Returns false otherwise.
  bool hasDependencePath(unsigned srcId, unsigned dstId) {
    // Worklist state is: <node-id, next-output-edge-index-to-visit>
    SmallVector<std::pair<unsigned, unsigned>, 4> worklist;
    worklist.push_back({srcId, 0});
    // Run DFS traversal to see if 'dstId' is reachable from 'srcId'.
    while (!worklist.empty()) {
      auto &idAndIndex = worklist.back();
      // Return true if we have reached 'dstId'.
      if (idAndIndex.first == dstId)
        return true;
      // Pop and continue if node has no out edges, or if all out edges have
      // already been visited.
      if (outEdges.count(idAndIndex.first) == 0 ||
          idAndIndex.second == outEdges[idAndIndex.first].size()) {
        worklist.pop_back();
        continue;
      }
      // Get graph edge to traverse.
      Edge edge = outEdges[idAndIndex.first][idAndIndex.second];
      // Increment next output edge index for 'idAndIndex'.
      ++idAndIndex.second;
      // Add node at 'edge.id' to worklist.
      worklist.push_back({edge.id, 0});
    }
    return false;
  }
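  // For example (illustrative): given edges 0 -> 1 and 1 -> 2,
  // hasDependencePath(0, 2) returns true via the transitive path through
  // node 1, while hasDependencePath(2, 0) returns false.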
  // Returns the input edge count for node 'id' and 'memref' from src nodes
  // which access 'memref' with a store operation.
  unsigned getIncomingMemRefAccesses(unsigned id, Value memref) {
    unsigned inEdgeCount = 0;
    if (inEdges.count(id) > 0)
      for (auto &inEdge : inEdges[id])
        if (inEdge.value == memref) {
          Node *srcNode = getNode(inEdge.id);
          // Only count in edges from 'srcNode' if 'srcNode' accesses 'memref'
          // with a store operation.
          if (srcNode->getStoreOpCount(memref) > 0)
            ++inEdgeCount;
        }
    return inEdgeCount;
  }

  // Returns the output edge count for node 'id' and 'memref' (if non-null),
  // otherwise returns the total output edge count from node 'id'.
  unsigned getOutEdgeCount(unsigned id, Value memref = nullptr) {
    unsigned outEdgeCount = 0;
    if (outEdges.count(id) > 0)
      for (auto &outEdge : outEdges[id])
        if (!memref || outEdge.value == memref)
          ++outEdgeCount;
    return outEdgeCount;
  }

  /// Returns all nodes which define SSA values used in node 'id'.
  void gatherDefiningNodes(unsigned id, DenseSet<unsigned> &definingNodes) {
    for (MemRefDependenceGraph::Edge edge : inEdges[id])
      // By definition of edge, if the edge value is a non-memref value,
      // then the dependence is between a graph node which defines an SSA value
      // and another graph node which uses the SSA value.
      if (!edge.value.getType().isa<MemRefType>())
        definingNodes.insert(edge.id);
  }

  // Computes and returns an insertion point operation, before which the
  // fused <srcId, dstId> loop nest can be inserted while preserving
  // dependences. Returns nullptr if no such insertion point is found.
  Operation *getFusedLoopNestInsertionPoint(unsigned srcId, unsigned dstId) {
    if (outEdges.count(srcId) == 0)
      return getNode(dstId)->op;

    // Skip if there is any defining node of 'dstId' that depends on 'srcId'.
    DenseSet<unsigned> definingNodes;
    gatherDefiningNodes(dstId, definingNodes);
    if (llvm::any_of(definingNodes, [&](unsigned id) {
          return hasDependencePath(srcId, id);
        })) {
      LLVM_DEBUG(llvm::dbgs()
                 << "Can't fuse: a defining op with a user in the dst "
                    "loop has dependence from the src loop\n");
      return nullptr;
    }

    // Build set of insts in range (srcId, dstId) which depend on 'srcId'.
    SmallPtrSet<Operation *, 2> srcDepInsts;
    for (auto &outEdge : outEdges[srcId])
      if (outEdge.id != dstId)
        srcDepInsts.insert(getNode(outEdge.id)->op);

    // Build set of insts in range (srcId, dstId) on which 'dstId' depends.
    SmallPtrSet<Operation *, 2> dstDepInsts;
    for (auto &inEdge : inEdges[dstId])
      if (inEdge.id != srcId)
        dstDepInsts.insert(getNode(inEdge.id)->op);

    Operation *srcNodeInst = getNode(srcId)->op;
    Operation *dstNodeInst = getNode(dstId)->op;

    // Computing insertion point:
    // *) Walk all operation positions in Block operation list in the
    //    range (src, dst). For each operation 'op' visited in this search:
    //   *) Store in 'firstSrcDepPos' the first position where 'op' has a
    //      dependence edge from 'srcNode'.
    //   *) Store in 'lastDstDepPos' the last position where 'op' has a
    //      dependence edge to 'dstNode'.
    // *) Compare 'firstSrcDepPos' and 'lastDstDepPos' to determine the
    //    operation insertion point (or return null pointer if no such
    //    insertion point exists: 'firstSrcDepPos' <= 'lastDstDepPos').
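    // For example (illustrative), if the block order is: src, A, B, dst,
    // where B depends on src and dst depends on A, then firstSrcDepPos = 1
    // (B) and lastDstDepPos = 0 (A); since firstSrcDepPos > lastDstDepPos,
    // the fused nest can be inserted before B. If instead A depended on src
    // and dst depended on B, then firstSrcDepPos (0) <= lastDstDepPos (1)
    // and no valid insertion point would exist.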
    SmallVector<Operation *, 2> depInsts;
    Optional<unsigned> firstSrcDepPos;
    Optional<unsigned> lastDstDepPos;
    unsigned pos = 0;
    for (Block::iterator it = std::next(Block::iterator(srcNodeInst));
         it != Block::iterator(dstNodeInst); ++it) {
      Operation *op = &(*it);
      if (srcDepInsts.count(op) > 0 && firstSrcDepPos == None)
        firstSrcDepPos = pos;
      if (dstDepInsts.count(op) > 0)
        lastDstDepPos = pos;
      depInsts.push_back(op);
      ++pos;
    }

    if (firstSrcDepPos.has_value()) {
      if (lastDstDepPos.has_value()) {
        if (firstSrcDepPos.value() <= lastDstDepPos.value()) {
          // No valid insertion point exists which preserves dependences.
          return nullptr;
        }
      }
      // Return the insertion point at 'firstSrcDepPos'.
      return depInsts[firstSrcDepPos.value()];
    }
    // No dependence targets in range (or only dst deps in range), return
    // 'dstNodeInst' insertion point.
    return dstNodeInst;
  }

  // Updates edge mappings from node 'srcId' to node 'dstId' after fusing them,
  // taking into account that:
  // *) if 'removeSrcId' is true, 'srcId' will be removed after fusion,
  // *) memrefs in 'privateMemRefs' have been replaced in node at 'dstId' by a
  //    private memref.
  void updateEdges(unsigned srcId, unsigned dstId,
                   const DenseSet<Value> &privateMemRefs, bool removeSrcId) {
    // For each edge in 'inEdges[srcId]': add new edge remapping to 'dstId'.
    if (inEdges.count(srcId) > 0) {
      SmallVector<Edge, 2> oldInEdges = inEdges[srcId];
      for (auto &inEdge : oldInEdges) {
        // Add edge from 'inEdge.id' to 'dstId' if it's not a private memref.
        if (privateMemRefs.count(inEdge.value) == 0)
          addEdge(inEdge.id, dstId, inEdge.value);
      }
    }
    // For each edge in 'outEdges[srcId]': remove edge from 'srcId' to 'dstId'.
    // If 'srcId' is going to be removed, remap all the out edges to 'dstId'.
    if (outEdges.count(srcId) > 0) {
      SmallVector<Edge, 2> oldOutEdges = outEdges[srcId];
      for (auto &outEdge : oldOutEdges) {
        // Remove any out edges from 'srcId' to 'dstId' across memrefs.
        if (outEdge.id == dstId)
          removeEdge(srcId, outEdge.id, outEdge.value);
        else if (removeSrcId) {
          addEdge(dstId, outEdge.id, outEdge.value);
          removeEdge(srcId, outEdge.id, outEdge.value);
        }
      }
    }
    // Remove any edges in 'inEdges[dstId]' on 'oldMemRef' (which is being
    // replaced by a private memref). These edges could come from nodes
    // other than 'srcId' which were removed in the previous step.
    if (inEdges.count(dstId) > 0 && !privateMemRefs.empty()) {
      SmallVector<Edge, 2> oldInEdges = inEdges[dstId];
      for (auto &inEdge : oldInEdges)
        if (privateMemRefs.count(inEdge.value) > 0)
          removeEdge(inEdge.id, dstId, inEdge.value);
    }
  }
  // Updates edge mappings for nodes 'sibId' and 'dstId' to reflect fusion
  // of sibling node 'sibId' into node 'dstId'.
  void updateEdges(unsigned sibId, unsigned dstId) {
    // For each edge in 'inEdges[sibId]':
    // *) Add new edge from source node 'inEdge.id' to 'dstNode'.
    // *) Remove edge from source node 'inEdge.id' to 'sibNode'.
    if (inEdges.count(sibId) > 0) {
      SmallVector<Edge, 2> oldInEdges = inEdges[sibId];
      for (auto &inEdge : oldInEdges) {
        addEdge(inEdge.id, dstId, inEdge.value);
        removeEdge(inEdge.id, sibId, inEdge.value);
      }
    }

    // For each edge in 'outEdges[sibId]':
    // *) Add new edge from 'dstId' to 'outEdge.id'.
    // *) Remove edge from 'sibId' to 'outEdge.id'.
    if (outEdges.count(sibId) > 0) {
      SmallVector<Edge, 2> oldOutEdges = outEdges[sibId];
      for (auto &outEdge : oldOutEdges) {
        addEdge(dstId, outEdge.id, outEdge.value);
        removeEdge(sibId, outEdge.id, outEdge.value);
      }
    }
  }

  // Adds ops in 'loads' and 'stores' to node at 'id'.
  void addToNode(unsigned id, const SmallVectorImpl<Operation *> &loads,
                 const SmallVectorImpl<Operation *> &stores) {
    Node *node = getNode(id);
    llvm::append_range(node->loads, loads);
    llvm::append_range(node->stores, stores);
  }

  void clearNodeLoadAndStores(unsigned id) {
    Node *node = getNode(id);
    node->loads.clear();
    node->stores.clear();
  }

  // Calls 'callback' for each input edge incident to node 'id' which carries a
  // memref dependence.
  void forEachMemRefInputEdge(unsigned id,
                              const std::function<void(Edge)> &callback) {
    if (inEdges.count(id) > 0)
      forEachMemRefEdge(inEdges[id], callback);
  }

  // Calls 'callback' for each output edge from node 'id' which carries a
  // memref dependence.
  void forEachMemRefOutputEdge(unsigned id,
                               const std::function<void(Edge)> &callback) {
    if (outEdges.count(id) > 0)
      forEachMemRefEdge(outEdges[id], callback);
  }

  // Calls 'callback' for each edge in 'edges' which carries a memref
  // dependence.
  void forEachMemRefEdge(ArrayRef<Edge> edges,
                         const std::function<void(Edge)> &callback) {
    for (const auto &edge : edges) {
      // Skip if 'edge' is not a memref dependence edge.
      if (!edge.value.getType().isa<MemRefType>())
        continue;
      assert(nodes.count(edge.id) > 0);
      // Skip if 'edge.id' is not a loop nest.
      if (!isa<AffineForOp>(getNode(edge.id)->op))
        continue;
      // Visit current input edge 'edge'.
      callback(edge);
    }
  }

  void print(raw_ostream &os) const {
    os << "\nMemRefDependenceGraph\n";
    os << "\nNodes:\n";
    for (const auto &idAndNode : nodes) {
      os << "Node: " << idAndNode.first << "\n";
      auto it = inEdges.find(idAndNode.first);
      if (it != inEdges.end()) {
        for (const auto &e : it->second)
          os << "  InEdge: " << e.id << " " << e.value << "\n";
      }
      it = outEdges.find(idAndNode.first);
      if (it != outEdges.end()) {
        for (const auto &e : it->second)
          os << "  OutEdge: " << e.id << " " << e.value << "\n";
      }
    }
  }
  void dump() const { print(llvm::errs()); }
};
/// Returns true if node 'srcId' can be removed after fusing it with node
/// 'dstId'. The node can be removed if any of the following conditions are
/// met:
///   1. 'srcId' has no output dependences after fusion and no escaping
///      memrefs.
///   2. 'srcId' has no output dependences after fusion, has escaping memrefs
///      and the fusion slice is maximal.
///   3. 'srcId' has output dependences after fusion, the fusion slice is
///      maximal and the fusion insertion point dominates all the dependences.
static bool canRemoveSrcNodeAfterFusion(
    unsigned srcId, unsigned dstId, const ComputationSliceState &fusionSlice,
    Operation *fusedLoopInsPoint, const DenseSet<Value> &escapingMemRefs,
    MemRefDependenceGraph *mdg) {

  Operation *dstNodeOp = mdg->getNode(dstId)->op;
  bool hasOutDepsAfterFusion = false;

  for (auto &outEdge : mdg->outEdges[srcId]) {
    Operation *depNodeOp = mdg->getNode(outEdge.id)->op;
    // Skip dependence with dstOp since it will be removed after fusion.
    if (depNodeOp == dstNodeOp)
      continue;

    // Only fusion within the same block is supported. Use domination analysis
    // when needed.
    if (depNodeOp->getBlock() != dstNodeOp->getBlock())
      return false;

    // Check if the insertion point of the fused loop dominates the dependence.
    // Otherwise, the src loop can't be removed.
    if (fusedLoopInsPoint != depNodeOp &&
        !fusedLoopInsPoint->isBeforeInBlock(depNodeOp)) {
      LLVM_DEBUG(llvm::dbgs() << "Src loop can't be removed: dst loop doesn't "
                                 "dominate dependence\n");
      return false;
    }

    hasOutDepsAfterFusion = true;
  }

  // If the src loop has dependences after fusion or it writes to a live-out
  // or escaping memref, we can only remove it if the fusion slice is maximal
  // so that all the dependences are preserved.
  if (hasOutDepsAfterFusion || !escapingMemRefs.empty()) {
    Optional<bool> isMaximal = fusionSlice.isMaximal();
    if (!isMaximal) {
      LLVM_DEBUG(llvm::dbgs() << "Src loop can't be removed: can't determine "
                                 "if fusion is maximal\n");
      return false;
    }

    if (!*isMaximal) {
      LLVM_DEBUG(llvm::dbgs()
                 << "Src loop can't be removed: fusion is not maximal\n");
      return false;
    }
  }

  return true;
}

/// Returns in 'srcIdCandidates' the producer fusion candidates for consumer
/// 'dstId'. Candidates are sorted by node id order. This order corresponds to
/// the program order when the 'mdg' is created. However, program order is not
/// guaranteed and must not be required by the client. Program order won't
/// hold if the 'mdg' is reused from a previous fusion step or if the node
/// creation order changes in the future to support more advanced cases.
// TODO: Move this to a loop fusion utility once 'mdg' is also moved.
static void getProducerCandidates(unsigned dstId, MemRefDependenceGraph *mdg,
                                  SmallVectorImpl<unsigned> &srcIdCandidates) {
  // Skip if no input edges along which to fuse.
  if (mdg->inEdges.count(dstId) == 0)
    return;

  // Gather memrefs from loads in 'dstId'.
  auto *dstNode = mdg->getNode(dstId);
  DenseSet<Value> consumedMemrefs;
  for (Operation *load : dstNode->loads)
    consumedMemrefs.insert(cast<AffineReadOpInterface>(load).getMemRef());

  // Traverse 'dstId' incoming edges and gather the nodes that contain a store
  // to one of the consumed memrefs.
  for (auto &srcEdge : mdg->inEdges[dstId]) {
    auto *srcNode = mdg->getNode(srcEdge.id);
    // Skip if 'srcNode' is not a loop nest.
    if (!isa<AffineForOp>(srcNode->op))
      continue;

    if (any_of(srcNode->stores, [&](Operation *op) {
          auto storeOp = cast<AffineWriteOpInterface>(op);
          return consumedMemrefs.count(storeOp.getMemRef()) > 0;
        }))
      srcIdCandidates.push_back(srcNode->id);
  }

  llvm::sort(srcIdCandidates);
  srcIdCandidates.erase(
      std::unique(srcIdCandidates.begin(), srcIdCandidates.end()),
      srcIdCandidates.end());
}
/// Returns in 'producerConsumerMemrefs' the memrefs involved in a
/// producer-consumer dependence between 'srcId' and 'dstId'.
static void
gatherProducerConsumerMemrefs(unsigned srcId, unsigned dstId,
                              MemRefDependenceGraph *mdg,
                              DenseSet<Value> &producerConsumerMemrefs) {
  auto *dstNode = mdg->getNode(dstId);
  auto *srcNode = mdg->getNode(srcId);
  gatherProducerConsumerMemrefs(srcNode->stores, dstNode->loads,
                                producerConsumerMemrefs);
}

/// Returns in 'escapingMemRefs' the memrefs from affine store ops in node 'id'
/// that escape the function. A memref escapes the function if either:
///   1. It's a function argument, or
///   2. It's used by a non-affine op (e.g., std load/store, std call, etc.)
void gatherEscapingMemrefs(unsigned id, MemRefDependenceGraph *mdg,
                           DenseSet<Value> &escapingMemRefs) {
  auto *node = mdg->getNode(id);
  for (auto *storeOpInst : node->stores) {
    auto memref = cast<AffineWriteOpInterface>(storeOpInst).getMemRef();
    if (escapingMemRefs.count(memref))
      continue;
    // Check if 'memref' escapes because it's a block argument.
    if (memref.isa<BlockArgument>()) {
      escapingMemRefs.insert(memref);
      continue;
    }
    // Check if 'memref' escapes through a non-affine op (e.g., std
    // load/store, call op, etc.).
    for (Operation *user : memref.getUsers())
      if (!isa<AffineMapAccessInterface>(*user))
        escapingMemRefs.insert(memref);
  }
}

} // namespace
// Initializes the data dependence graph by walking operations in 'f'.
// Assigns each node in the graph a node id based on program order in 'f'.
// TODO: Add support for taking a Block arg to construct the
// dependence graph at a different depth.
bool MemRefDependenceGraph::init(func::FuncOp f) {
  LLVM_DEBUG(llvm::dbgs() << "--- Initializing MDG ---\n");
  DenseMap<Value, SetVector<unsigned>> memrefAccesses;

  // TODO: support multi-block functions.
  if (!llvm::hasSingleElement(f))
    return false;

  DenseMap<Operation *, unsigned> forToNodeMap;
  for (auto &op : f.front()) {
    if (auto forOp = dyn_cast<AffineForOp>(op)) {
      // Create graph node 'id' to represent top-level 'forOp' and record
      // all loads and store accesses it contains.
      LoopNestStateCollector collector;
      collector.collect(&op);
      // Return false if a region holding op other than 'affine.for' and
      // 'affine.if' was found (not currently supported).
      if (collector.hasNonAffineRegionOp)
        return false;
      Node node(nextNodeId++, &op);
      for (auto *opInst : collector.loadOpInsts) {
        node.loads.push_back(opInst);
        auto memref = cast<AffineReadOpInterface>(opInst).getMemRef();
        memrefAccesses[memref].insert(node.id);
      }
      for (auto *opInst : collector.storeOpInsts) {
        node.stores.push_back(opInst);
        auto memref = cast<AffineWriteOpInterface>(opInst).getMemRef();
        memrefAccesses[memref].insert(node.id);
      }
      forToNodeMap[&op] = node.id;
      nodes.insert({node.id, node});
    } else if (auto loadOp = dyn_cast<AffineReadOpInterface>(op)) {
      // Create graph node for top-level load op.
      Node node(nextNodeId++, &op);
      node.loads.push_back(&op);
      auto memref = cast<AffineReadOpInterface>(op).getMemRef();
      memrefAccesses[memref].insert(node.id);
      nodes.insert({node.id, node});
    } else if (auto storeOp = dyn_cast<AffineWriteOpInterface>(op)) {
      // Create graph node for top-level store op.
      Node node(nextNodeId++, &op);
      node.stores.push_back(&op);
      auto memref = cast<AffineWriteOpInterface>(op).getMemRef();
      memrefAccesses[memref].insert(node.id);
      nodes.insert({node.id, node});
    } else if (op.getNumRegions() != 0) {
      // Return false if another region is found (not currently supported).
      return false;
    } else if (op.getNumResults() > 0 && !op.use_empty()) {
      // Create graph node for top-level producer of SSA values, which
      // could be used by loop nest nodes.
      Node node(nextNodeId++, &op);
      nodes.insert({node.id, node});
    } else if (isa<CallOpInterface>(op)) {
      // Create graph node for top-level Call Op that takes any argument of
      // memref type. Call Op that returns one or more memref type results
      // is already taken care of, by the previous conditions.
      if (llvm::any_of(op.getOperandTypes(),
                       [&](Type t) { return t.isa<MemRefType>(); })) {
        Node node(nextNodeId++, &op);
        nodes.insert({node.id, node});
      }
    } else if (auto effectInterface = dyn_cast<MemoryEffectOpInterface>(op)) {
      // Create graph node for top-level op, which could have a memory write
      // side effect.
      SmallVector<MemoryEffects::EffectInstance, 1> effects;
      effectInterface.getEffects(effects);
      if (llvm::any_of(effects, [](const MemoryEffects::EffectInstance &it) {
            return isa<MemoryEffects::Write, MemoryEffects::Free>(
                it.getEffect());
          })) {
        Node node(nextNodeId++, &op);
        nodes.insert({node.id, node});
      }
    }
  }

  for (auto &idAndNode : nodes) {
    LLVM_DEBUG(llvm::dbgs() << "Create node " << idAndNode.first << " for:\n"
                            << *(idAndNode.second.op) << "\n");
    (void)idAndNode;
  }

  // Add dependence edges between nodes which produce SSA values and their
  // users. Load ops can be considered as the ones producing SSA values.
  for (auto &idAndNode : nodes) {
    const Node &node = idAndNode.second;
    // Stores don't define SSA values, skip them.
    if (!node.stores.empty())
      continue;
    auto *opInst = node.op;
    for (auto value : opInst->getResults()) {
      for (auto *user : value.getUsers()) {
        SmallVector<AffineForOp, 4> loops;
        getLoopIVs(*user, &loops);
        if (loops.empty())
          continue;
        assert(forToNodeMap.count(loops[0].getOperation()) > 0);
        unsigned userLoopNestId = forToNodeMap[loops[0].getOperation()];
        addEdge(node.id, userLoopNestId, value);
      }
    }
  }

  // Walk memref access lists and add graph edges between dependent nodes.
  for (auto &memrefAndList : memrefAccesses) {
    unsigned n = memrefAndList.second.size();
    for (unsigned i = 0; i < n; ++i) {
      unsigned srcId = memrefAndList.second[i];
      bool srcHasStore =
          getNode(srcId)->getStoreOpCount(memrefAndList.first) > 0;
      for (unsigned j = i + 1; j < n; ++j) {
        unsigned dstId = memrefAndList.second[j];
        bool dstHasStore =
            getNode(dstId)->getStoreOpCount(memrefAndList.first) > 0;
        if (srcHasStore || dstHasStore)
          addEdge(srcId, dstId, memrefAndList.first);
      }
    }
  }
  return true;
}

// Sinks all sequential loops to the innermost levels (while preserving
// relative order among them) and moves all parallel loops to the
// outermost (while again preserving relative order among them).
// This can increase the loop depth at which we can fuse a slice, since we are
// pushing loop carried dependence to a greater depth in the loop nest.
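// For example (illustrative): in a 2-d nest where only the outer loop carries
// a dependence and the inner loop is parallel, interchanging them so that the
// parallel loop becomes outermost allows a consumer slice to be fused at
// depth 2 rather than depth 1.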
static void sinkSequentialLoops(MemRefDependenceGraph::Node *node) {
  assert(isa<AffineForOp>(node->op));
  AffineForOp newRootForOp = sinkSequentialLoops(cast<AffineForOp>(node->op));
  node->op = newRootForOp.getOperation();
}

// TODO: improve/complete this when we have target data.
static unsigned getMemRefEltSizeInBytes(MemRefType memRefType) {
  auto elementType = memRefType.getElementType();

  unsigned sizeInBits;
  if (elementType.isIntOrFloat()) {
    sizeInBits = elementType.getIntOrFloatBitWidth();
  } else {
    auto vectorType = elementType.cast<VectorType>();
    sizeInBits =
        vectorType.getElementTypeBitWidth() * vectorType.getNumElements();
  }
  return llvm::divideCeil(sizeInBits, 8);
}

// Creates and returns a private (single-user) memref for fused loop rooted
// at 'forOp', with (potentially reduced) memref size based on the
// MemRefRegion written to by 'srcStoreOpInst' at depth 'dstLoopDepth'.
// TODO: consider refactoring the common code from generateDma and
// this one.
static Value createPrivateMemRef(AffineForOp forOp, Operation *srcStoreOpInst,
                                 unsigned dstLoopDepth,
                                 Optional<unsigned> fastMemorySpace,
                                 uint64_t localBufSizeThreshold) {
  auto *forInst = forOp.getOperation();

  // Create builder to insert alloc op just before 'forOp'.
  OpBuilder b(forInst);
  // Builder to create constants at the top level.
  OpBuilder top(forInst->getParentOfType<func::FuncOp>().getBody());
  // Create new memref type based on slice bounds.
  auto oldMemRef = cast<AffineWriteOpInterface>(srcStoreOpInst).getMemRef();
  auto oldMemRefType = oldMemRef.getType().cast<MemRefType>();
  unsigned rank = oldMemRefType.getRank();

  // Compute MemRefRegion for 'srcStoreOpInst' at depth 'dstLoopDepth'.
  MemRefRegion region(srcStoreOpInst->getLoc());
  bool validRegion = succeeded(region.compute(srcStoreOpInst, dstLoopDepth));
  (void)validRegion;
  assert(validRegion && "unexpected memref region failure");
  SmallVector<int64_t, 4> newShape;
  std::vector<SmallVector<int64_t, 4>> lbs;
  SmallVector<int64_t, 8> lbDivisors;
  lbs.reserve(rank);
  // Query 'region' for 'newShape' and lower bounds of MemRefRegion accessed
  // by 'srcStoreOpInst' at depth 'dstLoopDepth'.
  Optional<int64_t> numElements =
      region.getConstantBoundingSizeAndShape(&newShape, &lbs, &lbDivisors);
  assert(numElements && "non-constant number of elts in local buffer");

  const FlatAffineValueConstraints *cst = region.getConstraints();
  // 'outerIVs' holds the values that this memory region is symbolic/parametric
  // on; this would correspond to loop IVs surrounding the level at which the
  // slice is being materialized.
  SmallVector<Value, 8> outerIVs;
  cst->getValues(rank, cst->getNumVars(), &outerIVs);

  // Build 'rank' AffineExprs from MemRefRegion 'lbs'.
  SmallVector<AffineExpr, 4> offsets;
  offsets.reserve(rank);
  for (unsigned d = 0; d < rank; ++d) {
    assert(lbs[d].size() == cst->getNumCols() - rank && "incorrect bound size");

    AffineExpr offset = top.getAffineConstantExpr(0);
    for (unsigned j = 0, e = cst->getNumCols() - rank - 1; j < e; j++) {
      offset = offset + lbs[d][j] * top.getAffineDimExpr(j);
    }
    assert(lbDivisors[d] > 0);
    offset =
        (offset + lbs[d][cst->getNumCols() - 1 - rank]).floorDiv(lbDivisors[d]);
    offsets.push_back(offset);
  }

  // Create 'newMemRefType' using 'newShape' from MemRefRegion accessed
  // by 'srcStoreOpInst'.
  uint64_t bufSize =
      getMemRefEltSizeInBytes(oldMemRefType) * numElements.value();
  unsigned newMemSpace;
  if (bufSize <= localBufSizeThreshold && fastMemorySpace.has_value()) {
    newMemSpace = fastMemorySpace.value();
  } else {
    newMemSpace = oldMemRefType.getMemorySpaceAsInt();
  }
  auto newMemRefType = MemRefType::get(newShape, oldMemRefType.getElementType(),
                                       {}, newMemSpace);

  // Create new private memref for fused loop 'forOp'. 'newShape' is always
  // a constant shape.
  // TODO: Create/move alloc ops for private memrefs closer to their
  // consumer loop nests to reduce their live range. Currently they are added
  // at the beginning of the function, because loop nests can be reordered
  // during the fusion pass.
  Value newMemRef = top.create<memref::AllocOp>(forOp.getLoc(), newMemRefType);

  // Build an AffineMap to remap access functions based on lower bound offsets.
  SmallVector<AffineExpr, 4> remapExprs;
  remapExprs.reserve(rank);
  for (unsigned i = 0; i < rank; i++) {
    auto dimExpr = b.getAffineDimExpr(outerIVs.size() + i);

    auto remapExpr =
        simplifyAffineExpr(dimExpr - offsets[i], outerIVs.size() + rank, 0);
    remapExprs.push_back(remapExpr);
  }

  auto indexRemap =
      AffineMap::get(outerIVs.size() + rank, 0, remapExprs, forOp.getContext());

  // Replace all users of 'oldMemRef' with 'newMemRef'.
  LogicalResult res =
      replaceAllMemRefUsesWith(oldMemRef, newMemRef, {}, indexRemap,
                               /*extraOperands=*/outerIVs,
                               /*symbolOperands=*/{},
                               /*domOpFilter=*/&*forOp.getBody()->begin());
  assert(succeeded(res) &&
         "replaceAllMemrefUsesWith should always succeed here");
  (void)res;
  return newMemRef;
}
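// For example (illustrative): if the producer writes all of
// memref<100x100xf32> but the slice fused at 'dstLoopDepth' only writes a
// 1x100 window per destination iteration, the private buffer shrinks to
// memref<1x100xf32>, and 'indexRemap' subtracts the window's lower-bound
// offsets from the original access functions.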
/// Walking from node 'srcId' to node 'dstId' (exclusive of 'srcId' and
/// 'dstId'), if there is any non-affine operation accessing 'memref', return
/// true. Otherwise, return false.
static bool hasNonAffineUsersOnThePath(unsigned srcId, unsigned dstId,
                                       Value memref,
                                       MemRefDependenceGraph *mdg) {
  auto *srcNode = mdg->getNode(srcId);
  auto *dstNode = mdg->getNode(dstId);
  Value::user_range users = memref.getUsers();
  // For each MemRefDependenceGraph's node that is between 'srcNode' and
  // 'dstNode' (exclusive of 'srcNode' and 'dstNode'), check whether any
  // non-affine operation in the node accesses the 'memref'.
  for (auto &idAndNode : mdg->nodes) {
    Operation *op = idAndNode.second.op;
    // Take care of operations between 'srcNode' and 'dstNode'.
    if (srcNode->op->isBeforeInBlock(op) && op->isBeforeInBlock(dstNode->op)) {
      // Walk inside the operation to find any use of the memref.
      // Interrupt the walk if found.
      auto walkResult = op->walk([&](Operation *user) {
        // Skip affine ops.
        if (isa<AffineMapAccessInterface>(*user))
          return WalkResult::advance();
        // Find a non-affine op that uses the memref.
        if (llvm::is_contained(users, user))
          return WalkResult::interrupt();
        return WalkResult::advance();
      });
      if (walkResult.wasInterrupted())
        return true;
    }
  }
  return false;
}

/// Checks whether a memref value in node 'srcId' has a non-affine user that
/// is between node 'srcId' and node 'dstId' (exclusive of 'srcNode' and
/// 'dstNode').
static bool hasNonAffineUsersOnThePath(unsigned srcId, unsigned dstId,
                                       MemRefDependenceGraph *mdg) {
  // Collect memref values in node 'srcId'.
  auto *srcNode = mdg->getNode(srcId);
  llvm::SmallDenseSet<Value, 2> memRefValues;
  srcNode->op->walk([&](Operation *op) {
    // Skip affine ops.
    if (isa<AffineForOp>(op))
      return WalkResult::advance();
    for (Value v : op->getOperands())
      // Collect memref values only.
      if (v.getType().isa<MemRefType>())
        memRefValues.insert(v);
    return WalkResult::advance();
  });
  // Looking for users between node 'srcId' and node 'dstId'.
  for (Value memref : memRefValues)
    if (hasNonAffineUsersOnThePath(srcId, dstId, memref, mdg))
      return true;
  return false;
}
// Checks the profitability of fusing a backwards slice of the loop nest
// surrounding 'srcOpInst' into the loop nest surrounding 'dstLoadOpInsts'.
// The argument 'srcStoreOpInst' is used to calculate the storage reduction on
// the memref being produced and consumed, which is an input to the cost model.
// For producer-consumer fusion, 'srcStoreOpInst' will be the same as
// 'srcOpInst', as we are slicing w.r.t. that producer. For input-reuse
// fusion, 'srcOpInst' will be the src loop nest LoadOp which reads from the
// same memref as dst loop nest load ops, and 'srcStoreOpInst' will be the
// unique store op in the src node, which will be used to check that the write
// region is the same after input-reuse fusion. Computation slices are provided
// in 'depthSliceUnions' for each legal fusion depth. The maximal depth at
// which fusion is legal is provided in 'maxLegalFusionDepth'. Returns true if
// it is profitable to fuse the candidate loop nests. Returns false otherwise.
// `dstLoopDepth` is set to the most profitable depth at which to materialize
// the source loop nest slice.
// The profitability model executes the following steps:
// *) Computes the backward computation slice at 'srcOpInst'. This
//    computation slice of the loop nest surrounding 'srcOpInst' is
//    represented by modified src loop bounds in 'sliceState', which are
//    functions of loop IVs in the loop nest surrounding 'srcOpInst'.
// *) Computes the cost of unfused src/dst loop nests (currently the cost of a
//    loop nest is the total number of dynamic operation instances in the loop
//    nest).
// *) Computes the cost of fusing a slice of the src loop nest into the dst
//    loop nest at various values of dst loop depth, attempting to fuse
//    the largest computation slice at the maximal dst loop depth (closest to
//    the load) to minimize reuse distance and potentially enable subsequent
//    load/store forwarding.
//    NOTE: 'dstLoopDepth' refers to the loop depth within the destination
//    loop nest, at which the src computation slice is inserted/fused.
//    NOTE: We attempt to maximize the dst loop depth, but there are cases
//    where a particular setting for 'dstLoopDepth' might fuse an unsliced
//    loop (within the src computation slice) at a depth which results in
//    excessive recomputation (see unit tests for examples).
// *) Compares the total cost of the unfused loop nests to the min cost fused
//    loop nest computed in the previous step, and returns true if the latter
//    is lower.
// TODO: Extend profitability analysis to support scenarios with multiple
// stores.
static bool isFusionProfitable(Operation *srcOpInst, Operation *srcStoreOpInst,
                               AffineForOp dstForOp,
                               ArrayRef<ComputationSliceState> depthSliceUnions,
                               unsigned maxLegalFusionDepth,
                               unsigned *dstLoopDepth,
                               double computeToleranceThreshold) {
  LLVM_DEBUG({
    llvm::dbgs() << "Checking whether fusion is profitable between src op:\n";
    llvm::dbgs() << ' ' << *srcOpInst << " and destination loop:\n";
    llvm::dbgs() << dstForOp << "\n";
  });

  if (maxLegalFusionDepth == 0) {
    LLVM_DEBUG(llvm::dbgs() << "Can't fuse: maxLegalFusionDepth == 0.\n");
    return false;
  }

  // Compute cost of sliced and unsliced src loop nest.
  SmallVector<AffineForOp, 4> srcLoopIVs;
  getLoopIVs(*srcOpInst, &srcLoopIVs);

  // Walk src loop nest and collect stats.
  LoopNestStats srcLoopNestStats;
  if (!getLoopNestStats(srcLoopIVs[0], &srcLoopNestStats))
    return false;

  // Compute cost of dst loop nest.
  LoopNestStats dstLoopNestStats;
  if (!getLoopNestStats(dstForOp, &dstLoopNestStats))
    return false;

  // Search for min cost value for 'dstLoopDepth'. At each value of
  // 'dstLoopDepth' from 'maxLegalLoopDepth' to '1', compute computation slice
  // bounds between 'srcOpInst' and each op in 'dstOpInsts' (taking the union
  // of these bounds). Next the union slice bounds are used to calculate
  // the cost of the slice and the cost of the slice inserted into the dst
  // loop nest at 'dstLoopDepth'.
  uint64_t minFusedLoopNestComputeCost = std::numeric_limits<uint64_t>::max();
  double maxStorageReduction = 0.0;
  Optional<uint64_t> sliceMemEstimate = None;

  // The best loop depth at which to materialize the slice.
  Optional<unsigned> bestDstLoopDepth = None;

  // Compute op instance count for the src loop nest without iteration slicing.
  uint64_t srcLoopNestCost = getComputeCost(srcLoopIVs[0], srcLoopNestStats);

  // Compute src loop nest write region size.
  MemRefRegion srcWriteRegion(srcStoreOpInst->getLoc());
  if (failed(srcWriteRegion.compute(srcStoreOpInst, /*loopDepth=*/0))) {
    LLVM_DEBUG(llvm::dbgs()
               << "Unable to compute MemRefRegion for source operation.\n");
    return false;
  }

  Optional<int64_t> maybeSrcWriteRegionSizeBytes =
      srcWriteRegion.getRegionSize();
  if (!maybeSrcWriteRegionSizeBytes.has_value())
    return false;
  int64_t srcWriteRegionSizeBytes = maybeSrcWriteRegionSizeBytes.value();

  // Compute op instance count for the dst loop nest.
  uint64_t dstLoopNestCost = getComputeCost(dstForOp, dstLoopNestStats);
  // Evaluate all depth choices for materializing the slice in the destination
  // loop nest.
  for (unsigned i = maxLegalFusionDepth; i >= 1; --i) {
    const ComputationSliceState &slice = depthSliceUnions[i - 1];
    // Skip slice union if it wasn't computed for this depth.
    if (slice.isEmpty())
      continue;

    int64_t fusedLoopNestComputeCost;
    if (!getFusionComputeCost(srcLoopIVs[0], srcLoopNestStats, dstForOp,
                              dstLoopNestStats, slice,
                              &fusedLoopNestComputeCost)) {
      LLVM_DEBUG(llvm::dbgs() << "Unable to compute fusion compute cost.\n");
      continue;
    }

    double additionalComputeFraction =
        fusedLoopNestComputeCost /
            (static_cast<double>(srcLoopNestCost) + dstLoopNestCost) -
        1;

    // Determine what the slice write MemRefRegion would be, if the src loop
    // nest slice 'slice' were to be inserted into the dst loop nest at loop
    // depth 'i'.
    MemRefRegion sliceWriteRegion(srcStoreOpInst->getLoc());
    if (failed(sliceWriteRegion.compute(srcStoreOpInst, /*loopDepth=*/0,
                                        &slice))) {
      LLVM_DEBUG(llvm::dbgs()
                 << "Failed to compute slice write region at loopDepth: " << i
                 << "\n");
      continue;
    }

    Optional<int64_t> maybeSliceWriteRegionSizeBytes =
        sliceWriteRegion.getRegionSize();
    if (!maybeSliceWriteRegionSizeBytes.has_value() ||
        maybeSliceWriteRegionSizeBytes.value() == 0) {
      LLVM_DEBUG(llvm::dbgs()
                 << "Failed to get slice write region size at loopDepth: " << i
                 << "\n");
      continue;
    }
    int64_t sliceWriteRegionSizeBytes = maybeSliceWriteRegionSizeBytes.value();

    // If we are fusing for reuse, check that write regions remain the same.
    // TODO: Write region check should check sizes and offsets in
    // each dimension, so that we are sure they are covering the same memref
    // region. Also, move this out to a isMemRefRegionSuperSet helper function.
    if (srcOpInst != srcStoreOpInst &&
        sliceWriteRegionSizeBytes != srcWriteRegionSizeBytes)
      continue;

    double storageReduction = static_cast<double>(srcWriteRegionSizeBytes) /
                              static_cast<double>(sliceWriteRegionSizeBytes);

    LLVM_DEBUG({
      std::stringstream msg;
      msg << "  evaluating fusion profitability at depth : " << i << "\n"
          << std::fixed << std::setprecision(2)
          << "   additional compute fraction: "
          << 100.0 * additionalComputeFraction << "%\n"
          << "   storage reduction factor: " << storageReduction << "x\n"
          << "   fused nest cost: " << fusedLoopNestComputeCost << "\n"
          << "   src write region size: " << srcWriteRegionSizeBytes << "\n"
          << "   slice write region size: " << sliceWriteRegionSizeBytes
          << "\n";
      llvm::dbgs() << msg.str();
    });

    // TODO: This is a placeholder cost model.
    // Among all choices that add an acceptable amount of redundant computation
    // (as per computeToleranceThreshold), we will simply pick the one that
    // reduces the intermediary size the most.
    if ((storageReduction > maxStorageReduction) &&
        (additionalComputeFraction < computeToleranceThreshold)) {
      maxStorageReduction = storageReduction;
      bestDstLoopDepth = i;
      minFusedLoopNestComputeCost = fusedLoopNestComputeCost;
      sliceMemEstimate = sliceWriteRegionSizeBytes;
    }
  }

  // A simple cost model: fuse if it reduces the memory footprint.
  if (!bestDstLoopDepth) {
    LLVM_DEBUG(
        llvm::dbgs()
        << "All fusion choices involve more than the threshold amount of "
           "redundant computation; NOT fusing.\n");
    return false;
  }

  // Set dstLoopDepth based on best values from search.
  *dstLoopDepth = *bestDstLoopDepth;

  LLVM_DEBUG(
      llvm::dbgs() << " LoopFusion fusion stats:"
                   << "\n  best loop depth: " << bestDstLoopDepth
                   << "\n  src loop nest compute cost: " << srcLoopNestCost
                   << "\n  dst loop nest compute cost: " << dstLoopNestCost
                   << "\n  fused loop nest compute cost: "
                   << minFusedLoopNestComputeCost << "\n");

  auto dstMemSize = getMemoryFootprintBytes(dstForOp);
  auto srcMemSize = getMemoryFootprintBytes(srcLoopIVs[0]);

  Optional<double> storageReduction = None;

  if (!dstMemSize || !srcMemSize) {
    LLVM_DEBUG(llvm::dbgs()
               << "  fusion memory benefit cannot be evaluated; NOT fusing.\n");
    return false;
  }

  auto srcMemSizeVal = srcMemSize.value();
  auto dstMemSizeVal = dstMemSize.value();

  assert(sliceMemEstimate && "expected value");
  auto fusedMem = dstMemSizeVal + sliceMemEstimate.value();

  LLVM_DEBUG(llvm::dbgs() << "  src mem: " << srcMemSizeVal << "\n"
                          << "  dst mem: " << dstMemSizeVal << "\n"
                          << "  fused mem: " << fusedMem << "\n"
                          << "  slice mem: " << sliceMemEstimate << "\n");

  if (static_cast<long>(fusedMem) > srcMemSizeVal + dstMemSizeVal) {
    LLVM_DEBUG(llvm::dbgs() << "Fusion is not profitable; NOT fusing.\n");
    return false;
  }
  storageReduction =
      100.0 *
      (1.0 - fusedMem / (static_cast<double>(srcMemSizeVal) + dstMemSizeVal));

  double additionalComputeFraction =
      100.0 * (minFusedLoopNestComputeCost /
                   (static_cast<double>(srcLoopNestCost) + dstLoopNestCost) -
               1);
  (void)additionalComputeFraction;
  LLVM_DEBUG({
    std::stringstream msg;
    msg << " fusion is most profitable at depth " << *dstLoopDepth << " with "
        << std::setprecision(2) << additionalComputeFraction
        << "% redundant computation and a ";
    msg << (storageReduction ? std::to_string(*storageReduction) : "<unknown>");
    msg << "% storage reduction.\n";
    llvm::dbgs() << msg.str();
  });

  return true;
}
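// A worked example of the compute fraction above (illustrative numbers): with
// srcLoopNestCost = 100, dstLoopNestCost = 100, and a fused cost of 210,
// additionalComputeFraction = 210 / (100 + 100) - 1 = 0.05, i.e. 5% redundant
// computation, which is accepted when computeToleranceThreshold >= 0.05.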
namespace {

// GreedyFusion greedily fuses loop nests which have a producer/consumer or
// input-reuse relationship on a memref, with the goal of improving locality.
//
// The steps of the producer-consumer fusion algorithm are as follows:
//
// *) A worklist is initialized with node ids from the dependence graph.
// *) For each node id in the worklist:
//   *) Pop an AffineForOp off the worklist. This 'dstAffineForOp' will be a
//      candidate destination AffineForOp into which fusion will be attempted.
//   *) Add each LoadOp currently in 'dstAffineForOp' into list 'dstLoadOps'.
//   *) For each LoadOp in 'dstLoadOps' do:
//      *) Look up dependent loop nests which have a single store op to the
//         same memref.
//      *) Check if dependences would be violated by the fusion.
//      *) Get a computation slice of 'srcLoopNest', which adjusts its loop
//         bounds to be functions of 'dstLoopNest' IVs and symbols.
//      *) Fuse the 'srcLoopNest' computation slice into the 'dstLoopNest', at
//         a loop depth determined by the cost model in 'isFusionProfitable'.
//      *) Add the newly fused load/store operations to the state,
//         and also add newly fused load ops to 'dstLoopOps' to be considered
//         as fusion dst load ops in another iteration.
//      *) Remove old src loop nest and its associated state.
//
// The steps of the input-reuse fusion algorithm are as follows:
//
// *) Initialize 'worklist' with node ids from the dependence graph.
// *) For each 'dstNode' in the worklist:
//   *) Find a candidate sibling node 'sibNode' to fuse with 'dstNode' which
//      loads from the same memref, but which has no dependence paths to/from
//      'dstNode'.
//   *) Get a computation slice of 'sibLoopNest', which adjusts its loop
//      bounds to be functions of 'dstLoopNest' IVs and symbols.
//   *) Fuse the 'sibLoopNest' computation slice into the 'dstLoopNest', at
//      a loop depth determined by the cost model in 'isFusionProfitable'.
//      This function also checks that the memref write region of
//      'sibLoopNest' is preserved in the fused loop nest.
//   *) Update graph state to reflect the fusion of 'sibNode' into 'dstNode'.
//
// Given a graph where top-level operations are vertices in the set 'V' and
// edges in the set 'E' are dependences between vertices, this algorithm
// takes O(V) time for initialization, and has runtime O(V + E).
//
// This greedy algorithm is not 'maximal' due to the current restriction of
// fusing along single producer consumer edges, but there is a TODO to fix
// this.
//
// TODO: Experiment with other fusion policies.
struct GreedyFusion {
public:
  // The data dependence graph to traverse during fusion.
  MemRefDependenceGraph *mdg;
  // Worklist of graph nodes visited during the fusion pass.
  SmallVector<unsigned, 8> worklist;
  // Parameter for local buffer size threshold.
  unsigned localBufSizeThreshold;
  // Parameter for fast memory space.
  Optional<unsigned> fastMemorySpace;
  // If true, ignore any additional (redundant) computation tolerance threshold
  // that would have prevented fusion.
  bool maximalFusion;
  // The amount of additional computation that is tolerated while fusing
  // pair-wise as a fraction of the total computation.
  double computeToleranceThreshold;

  using Node = MemRefDependenceGraph::Node;

  GreedyFusion(MemRefDependenceGraph *mdg, unsigned localBufSizeThreshold,
               Optional<unsigned> fastMemorySpace, bool maximalFusion,
               double computeToleranceThreshold)
      : mdg(mdg), localBufSizeThreshold(localBufSizeThreshold),
        fastMemorySpace(fastMemorySpace), maximalFusion(maximalFusion),
        computeToleranceThreshold(computeToleranceThreshold) {}

  /// Initializes 'worklist' with nodes from 'mdg'.
  void init() {
    // TODO: Add a priority queue for prioritizing nodes by different
    // metrics (e.g. arithmetic intensity/flops-to-bytes ratio).
    worklist.clear();
    for (auto &idAndNode : mdg->nodes) {
      const Node &node = idAndNode.second;
      worklist.push_back(node.id);
    }
  }

  /// Run only sibling fusion on the `mdg`.
  void runSiblingFusionOnly() {
    fuseSiblingNodes();
    eraseUnusedMemRefAllocations();
  }

  /// Run only producer/consumer fusion on the `mdg`.
  void runProducerConsumerFusionOnly() {
    fuseProducerConsumerNodes(
        /*maxSrcUserCount=*/std::numeric_limits<unsigned>::max());
    eraseUnusedMemRefAllocations();
  }

  // Run the GreedyFusion pass.
  // *) First pass through the nodes fuses single-use producer nodes into their
  //    unique consumer.
  // *) Second pass fuses sibling nodes which share no dependence edges.
  // *) Third pass fuses any remaining producer nodes into their users.
  void runGreedyFusion() {
    // TODO: Run this repeatedly until a fixed-point is reached.
    fuseProducerConsumerNodes(/*maxSrcUserCount=*/1);
    fuseSiblingNodes();
    fuseProducerConsumerNodes(
        /*maxSrcUserCount=*/std::numeric_limits<unsigned>::max());
    eraseUnusedMemRefAllocations();
  }

  void fuseProducerConsumerNodes(unsigned maxSrcUserCount) {
    LLVM_DEBUG(llvm::dbgs() << "--- Producer/Consumer Fusion ---\n");
    init();
    while (!worklist.empty()) {
      unsigned dstId = worklist.back();
      worklist.pop_back();

      // Skip if this node was removed (fused into another node).
      if (mdg->nodes.count(dstId) == 0)
        continue;
      // Get 'dstNode' into which to attempt fusion.
      auto *dstNode = mdg->getNode(dstId);
      // Skip if 'dstNode' is not a loop nest.
      if (!isa<AffineForOp>(dstNode->op))
        continue;
      // Skip if 'dstNode' is a loop nest returning values.
      // TODO: support loop nests that return values.
      if (dstNode->op->getNumResults() > 0)
        continue;

      LLVM_DEBUG(llvm::dbgs() << "Evaluating dst loop " << dstId << "\n");

      // Sink sequential loops in 'dstNode' (and thus raise parallel loops)
      // while preserving relative order. This can increase the maximum loop
      // depth at which we can fuse a slice of a producer loop nest into a
      // consumer loop nest.
      sinkSequentialLoops(dstNode);
      auto dstAffineForOp = cast<AffineForOp>(dstNode->op);

      // Try to fuse 'dstNode' with candidate producer loops until a fixed
      // point is reached. Fusing two loops may expose new fusion
      // opportunities.
      bool dstNodeChanged;
      do {
        // Gather src loop candidates for 'dstNode' and visit them in "quasi"
        // reverse program order to minimize the number of iterations needed
        // to reach the fixed point. Note that this is a best-effort approach
        // since 'getProducerCandidates' does not always guarantee program
        // order in 'srcIdCandidates'.
        dstNodeChanged = false;
        SmallVector<unsigned, 16> srcIdCandidates;
        getProducerCandidates(dstId, mdg, srcIdCandidates);

        for (unsigned srcId : llvm::reverse(srcIdCandidates)) {
          // Get 'srcNode' from which to attempt fusion into 'dstNode'.
          auto *srcNode = mdg->getNode(srcId);
          auto srcAffineForOp = cast<AffineForOp>(srcNode->op);
          LLVM_DEBUG(llvm::dbgs() << "Evaluating src loop " << srcId
                                  << " for dst loop " << dstId << "\n");

          // Skip if 'srcNode' is a loop nest returning values.
          // TODO: support loop nests that return values.
          if (isa<AffineForOp>(srcNode->op) && srcNode->op->getNumResults() > 0)
            continue;

          DenseSet<Value> producerConsumerMemrefs;
          gatherProducerConsumerMemrefs(srcId, dstId, mdg,
                                        producerConsumerMemrefs);

          // Skip if 'srcNode' out edge count on any memref is greater than
          // 'maxSrcUserCount'.
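          // For example (hypothetical scenario): if 'srcNode' stores to %A
          // and two distinct consumer nests load from %A, then 'srcNode' has
          // two out edges on %A and is skipped when 'maxSrcUserCount' is 1
          // (the first pass of runGreedyFusion); the final pass runs with an
          // unbounded 'maxSrcUserCount' and reconsiders it.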
          if (any_of(producerConsumerMemrefs, [&](Value memref) {
                return mdg->getOutEdgeCount(srcNode->id, memref) >
                       maxSrcUserCount;
              }))
            continue;

          // Gather memrefs in 'srcNode' that are written and escape to the
          // function (e.g., memref function arguments, returned memrefs,
          // memrefs passed to function calls, etc.).
          DenseSet<Value> srcEscapingMemRefs;
          gatherEscapingMemrefs(srcNode->id, mdg, srcEscapingMemRefs);

          // Skip if there are non-affine operations in between the 'srcNode'
          // and 'dstNode' using their memrefs. If so, we wouldn't be able to
          // compute a legal insertion point for now. 'srcNode' and 'dstNode'
          // memrefs with non-affine operation users would be considered
          // escaping memrefs so we can limit this check to only scenarios with
          // escaping memrefs.
          if (!srcEscapingMemRefs.empty() &&
              hasNonAffineUsersOnThePath(srcId, dstId, mdg)) {
            LLVM_DEBUG(llvm::dbgs()
                       << "Can't fuse: non-affine users in between the "
                          "loops.\n");
            continue;
          }

          // Compute an operation list insertion point for the fused loop
          // nest which preserves dependences.
          Operation *fusedLoopInsPoint =
              mdg->getFusedLoopNestInsertionPoint(srcNode->id, dstNode->id);
          if (fusedLoopInsPoint == nullptr)
            continue;

          // Compute the innermost common loop depth for dstNode
          // producer-consumer loads/stores.
          SmallVector<Operation *, 2> dstMemrefOps;
          for (Operation *op : dstNode->loads)
            if (producerConsumerMemrefs.count(
                    cast<AffineReadOpInterface>(op).getMemRef()) > 0)
              dstMemrefOps.push_back(op);
          for (Operation *op : dstNode->stores)
            if (producerConsumerMemrefs.count(
                    cast<AffineWriteOpInterface>(op).getMemRef()))
              dstMemrefOps.push_back(op);
          unsigned dstLoopDepthTest = getInnermostCommonLoopDepth(dstMemrefOps);

          // Check the feasibility of fusing src loop nest into dst loop nest
          // at loop depths in range [1, dstLoopDepthTest].
          unsigned maxLegalFusionDepth = 0;
          SmallVector<ComputationSliceState, 8> depthSliceUnions;
          depthSliceUnions.resize(dstLoopDepthTest);
          FusionStrategy strategy(FusionStrategy::ProducerConsumer);
          for (unsigned i = 1; i <= dstLoopDepthTest; ++i) {
            FusionResult result = mlir::canFuseLoops(
                srcAffineForOp, dstAffineForOp,
                /*dstLoopDepth=*/i, &depthSliceUnions[i - 1], strategy);

            if (result.value == FusionResult::Success)
              maxLegalFusionDepth = i;
          }

          if (maxLegalFusionDepth == 0) {
            LLVM_DEBUG(llvm::dbgs()
                       << "Can't fuse: fusion is not legal at any depth\n");
            continue;
          }

          // Check if fusion would be profitable. We skip profitability
          // analysis for maximal fusion since we already know the maximal
          // legal depth to fuse.
          unsigned bestDstLoopDepth = maxLegalFusionDepth;
          if (!maximalFusion) {
            // Retrieve producer stores from the src loop.
            SmallVector<Operation *, 2> producerStores;
            for (Operation *op : srcNode->stores)
              if (producerConsumerMemrefs.count(
                      cast<AffineWriteOpInterface>(op).getMemRef()))
                producerStores.push_back(op);

            // TODO: Support multiple producer stores in profitability
            // analysis. We limit profitability analysis to only scenarios
            // with a single producer store for now.
            // Note that some multi-store producer scenarios will still go
            // through profitability analysis if only one of the stores is
            // involved in the producer-consumer relationship of the candidate
            // loops.
            assert(!producerStores.empty() && "Expected producer store");
            if (producerStores.size() > 1)
              LLVM_DEBUG(llvm::dbgs()
                         << "Skipping profitability analysis. Not "
                            "supported for this case\n");
            else if (!isFusionProfitable(producerStores[0], producerStores[0],
                                         dstAffineForOp, depthSliceUnions,
                                         maxLegalFusionDepth, &bestDstLoopDepth,
                                         computeToleranceThreshold))
              continue;
          }

          assert(bestDstLoopDepth > 0 && "Unexpected loop fusion depth");
          ComputationSliceState &bestSlice =
              depthSliceUnions[bestDstLoopDepth - 1];
          assert(!bestSlice.isEmpty() && "Missing slice union for depth");

          // Determine if 'srcId' can be removed after fusion, taking into
          // account remaining dependences, escaping memrefs and the fusion
          // insertion point.
          bool removeSrcNode = canRemoveSrcNodeAfterFusion(
              srcId, dstId, bestSlice, fusedLoopInsPoint, srcEscapingMemRefs,
              mdg);

          DenseSet<Value> privateMemrefs;
          for (Value memref : producerConsumerMemrefs) {
            // If `memref` is an escaping one, do not create a private memref
            // for the below scenarios, since doing so will leave the escaping
            // memref unmodified as all the writes originally meant for the
            // escaping memref would be performed on the private memref:
            // 1. The source is to be removed after fusion,
            // OR
            // 2. The destination writes to `memref`.
            if (srcEscapingMemRefs.count(memref) > 0 &&
                (removeSrcNode || dstNode->getStoreOpCount(memref) > 0))
              continue;

            // Don't create a private memref if 'srcNode' has in edges on
            // 'memref' or 'dstNode' has out edges on 'memref'.
            if (mdg->getIncomingMemRefAccesses(srcId, memref) > 0 ||
                mdg->getOutEdgeCount(dstId, memref) > 0)
              continue;

            // If 'srcNode' will be removed but it has out edges on 'memref'
            // to nodes other than 'dstNode', we have to preserve dependences
            // and cannot create a private memref.
            if (removeSrcNode &&
                any_of(mdg->outEdges[srcId], [&](const auto &edge) {
                  return edge.value == memref && edge.id != dstId;
                }))
              continue;

            // Create a private version of this memref.
            privateMemrefs.insert(memref);
          }

          // Fuse computation slice of 'srcLoopNest' into 'dstLoopNest'.
          fuseLoops(srcAffineForOp, dstAffineForOp, bestSlice);
          dstNodeChanged = true;

          LLVM_DEBUG(llvm::dbgs()
                     << "Fused src loop " << srcId << " into dst loop " << dstId
                     << " at depth " << bestDstLoopDepth << ":\n"
                     << dstAffineForOp << "\n");

          // Move 'dstAffineForOp' before 'insertPointInst' if needed.
          if (fusedLoopInsPoint != dstAffineForOp.getOperation())
            dstAffineForOp.getOperation()->moveBefore(fusedLoopInsPoint);

          // Update edges between 'srcNode' and 'dstNode'.
          mdg->updateEdges(srcNode->id, dstNode->id, privateMemrefs,
                           removeSrcNode);

          // Create private memrefs.
          if (!privateMemrefs.empty()) {
            // Gather stores for all the private-to-be memrefs.
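            // Illustrative sketch (hypothetical IR and types): if the fused
            // nest both writes and reads %A element-wise, e.g.:
            //   affine.for %i = 0 to 100 {
            //     affine.store %v, %A[%i] : memref<100xf32>
            //     %u = affine.load %A[%i] : memref<100xf32>
            //   }
            // then 'createPrivateMemRef' below can replace %A within the
            // nest by a buffer sized to the slice footprint (here a single
            // element, i.e. memref<1xf32>).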
            DenseMap<Value, SmallVector<Operation *, 4>> privateMemRefToStores;
            dstAffineForOp.walk([&](AffineWriteOpInterface storeOp) {
              Value storeMemRef = storeOp.getMemRef();
              if (privateMemrefs.count(storeMemRef) > 0)
                privateMemRefToStores[storeMemRef].push_back(
                    storeOp.getOperation());
            });

            // Replace original memrefs with private memrefs. Note that all
            // the loads and stores on these memrefs will be replaced with new
            // loads and stores. Any reference to the original ones becomes
            // invalid after this point.
            for (auto &memrefToStoresPair : privateMemRefToStores) {
              // TODO: Use union of memref write regions to compute
              // private memref footprint.
              SmallVector<Operation *, 4> &storesForMemref =
                  memrefToStoresPair.second;
              Value newMemRef = createPrivateMemRef(
                  dstAffineForOp, storesForMemref[0], bestDstLoopDepth,
                  fastMemorySpace, localBufSizeThreshold);
              // Create new node in dependence graph for 'newMemRef' alloc op.
              unsigned newMemRefNodeId =
                  mdg->addNode(newMemRef.getDefiningOp());
              // Add edge from 'newMemRef' node to dstNode.
              mdg->addEdge(newMemRefNodeId, dstId, newMemRef);
            }
            // One or more entries for 'newMemRef' alloc op are inserted into
            // the DenseMap mdg->nodes. Since an insertion may cause the
            // DenseMap to reallocate, update dstNode.
            dstNode = mdg->getNode(dstId);
          }

          // Collect dst loop stats after memref privatization transformation.
          LoopNestStateCollector dstLoopCollector;
          dstLoopCollector.collect(dstAffineForOp.getOperation());

          // Clear and add back loads and stores.
          mdg->clearNodeLoadAndStores(dstNode->id);
          mdg->addToNode(dstId, dstLoopCollector.loadOpInsts,
                         dstLoopCollector.storeOpInsts);

          if (removeSrcNode) {
            LLVM_DEBUG(llvm::dbgs()
                       << "Removing src loop " << srcId << " after fusion\n");
            // srcNode is no longer valid after it is removed from mdg.
            srcAffineForOp.erase();
            mdg->removeNode(srcId);
            srcNode = nullptr;
          }
        }
      } while (dstNodeChanged);
    }
  }

  // Visits each node in the graph, and for each node, attempts to fuse it
  // with its sibling nodes (nodes which share a parent, but no dependence
  // edges).
  void fuseSiblingNodes() {
    LLVM_DEBUG(llvm::dbgs() << "--- Sibling Fusion ---\n");
    init();
    while (!worklist.empty()) {
      unsigned dstId = worklist.back();
      worklist.pop_back();

      // Skip if this node was removed (fused into another node).
      if (mdg->nodes.count(dstId) == 0)
        continue;
      // Get 'dstNode' into which to attempt fusion.
      auto *dstNode = mdg->getNode(dstId);
      // Skip if 'dstNode' is not a loop nest.
      if (!isa<AffineForOp>(dstNode->op))
        continue;
      // Attempt to fuse 'dstNode' with its sibling nodes in the graph.
      fuseWithSiblingNodes(dstNode);
    }
  }

  // Attempt to fuse 'dstNode' with sibling nodes in the graph.
  void fuseWithSiblingNodes(Node *dstNode) {
    DenseSet<unsigned> visitedSibNodeIds;
    std::pair<unsigned, Value> idAndMemref;
    auto dstAffineForOp = cast<AffineForOp>(dstNode->op);

    while (findSiblingNodeToFuse(dstNode, &visitedSibNodeIds, &idAndMemref)) {
      unsigned sibId = idAndMemref.first;
      Value memref = idAndMemref.second;
      // TODO: Check that 'sibStoreOpInst' post-dominates all other
      // stores to the same memref in 'sibNode' loop nest.
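      // Illustrative input-reuse example (hypothetical IR): two sibling
      // nests with no dependence path between them that both read %A:
      //   affine.for %i = 0 to 100 {
      //     %0 = affine.load %A[%i] : memref<100xf32>
      //     affine.store %0, %B[%i] : memref<100xf32>
      //   }
      //   affine.for %j = 0 to 100 {
      //     %1 = affine.load %A[%j] : memref<100xf32>
      //     affine.store %1, %C[%j] : memref<100xf32>
      //   }
      // Fusing the second nest into the first yields one nest that reads
      // each element of %A once per iteration for both uses.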
      auto *sibNode = mdg->getNode(sibId);
      // Compute an operation list insertion point for the fused loop
      // nest which preserves dependences.
      assert(sibNode->op->getBlock() == dstNode->op->getBlock());
      Operation *insertPointInst =
          sibNode->op->isBeforeInBlock(dstNode->op)
              ? mdg->getFusedLoopNestInsertionPoint(sibNode->id, dstNode->id)
              : mdg->getFusedLoopNestInsertionPoint(dstNode->id, sibNode->id);
      if (insertPointInst == nullptr)
        continue;

      // Check if fusion would be profitable and at what depth.

      // Get unique 'sibNode' load op to 'memref'.
      SmallVector<Operation *, 2> sibLoadOpInsts;
      sibNode->getLoadOpsForMemref(memref, &sibLoadOpInsts);
      // Currently findSiblingNodeToFuse searches for siblings with one load.
      assert(sibLoadOpInsts.size() == 1);
      Operation *sibLoadOpInst = sibLoadOpInsts[0];
      assert(!sibNode->stores.empty());
      // TODO: Choose the store which postdominates all other stores.
      auto *sibStoreOpInst = sibNode->stores.back();

      // Gather 'dstNode' load ops to 'memref'.
      SmallVector<Operation *, 2> dstLoadOpInsts;
      dstNode->getLoadOpsForMemref(memref, &dstLoadOpInsts);

      SmallVector<AffineForOp, 4> dstLoopIVs;
      getLoopIVs(*dstLoadOpInsts[0], &dstLoopIVs);
      unsigned dstLoopDepthTest = dstLoopIVs.size();
      auto sibAffineForOp = cast<AffineForOp>(sibNode->op);

      // Compute loop depth and slice union for fusion.
      SmallVector<ComputationSliceState, 8> depthSliceUnions;
      depthSliceUnions.resize(dstLoopDepthTest);
      unsigned maxLegalFusionDepth = 0;
      FusionStrategy strategy(memref);
      for (unsigned i = 1; i <= dstLoopDepthTest; ++i) {
        FusionResult result = mlir::canFuseLoops(
            sibAffineForOp, dstAffineForOp,
            /*dstLoopDepth=*/i, &depthSliceUnions[i - 1], strategy);

        if (result.value == FusionResult::Success)
          maxLegalFusionDepth = i;
      }

      // Skip if fusion is not feasible at any loop depth.
      if (maxLegalFusionDepth == 0)
        continue;

      unsigned bestDstLoopDepth = maxLegalFusionDepth;
      if (!maximalFusion) {
        // Check if fusion would be profitable.
        if (!isFusionProfitable(sibLoadOpInst, sibStoreOpInst, dstAffineForOp,
                                depthSliceUnions, maxLegalFusionDepth,
                                &bestDstLoopDepth, computeToleranceThreshold))
          continue;
      }

      assert(bestDstLoopDepth > 0 && "Unexpected loop fusion depth");
      assert(!depthSliceUnions[bestDstLoopDepth - 1].isEmpty() &&
             "Fusion depth has no computed slice union");
      // Check if source loop is being inserted in the innermost
      // destination loop. Based on this, the fused loop may be optimized
      // further inside `fuseLoops`.
      bool isInnermostInsertion = (bestDstLoopDepth == dstLoopDepthTest);
      // Fuse computation slice of 'sibLoopNest' into 'dstLoopNest'.
      mlir::fuseLoops(sibAffineForOp, dstAffineForOp,
                      depthSliceUnions[bestDstLoopDepth - 1],
                      isInnermostInsertion);

      auto dstForInst = cast<AffineForOp>(dstNode->op);
      // Update operation position of fused loop nest (if needed).
      if (insertPointInst != dstForInst.getOperation()) {
        dstForInst->moveBefore(insertPointInst);
      }
      // Update data dependence graph state post fusion.
      updateStateAfterSiblingFusion(sibNode, dstNode);
    }
  }

  // Searches function argument uses and the graph from 'dstNode' looking for
  // a fusion candidate sibling node which shares no dependences with
  // 'dstNode' but which loads from the same memref. Returns true and sets
  // 'idAndMemrefToFuse' on success. Returns false otherwise.
  bool findSiblingNodeToFuse(Node *dstNode,
                             DenseSet<unsigned> *visitedSibNodeIds,
                             std::pair<unsigned, Value> *idAndMemrefToFuse) {
    // Returns true if 'sibNode' can be fused with 'dstNode' for input reuse
    // on 'memref'.
    auto canFuseWithSibNode = [&](Node *sibNode, Value memref) {
      // Skip if 'sibNode' has more than one load on 'memref'.
      // TODO: Remove this single load op restriction.
      if (sibNode->getLoadOpCount(memref) != 1)
        return false;
      // Skip if there exists a path of dependent edges between
      // 'sibNode' and 'dstNode'.
      if (mdg->hasDependencePath(sibNode->id, dstNode->id) ||
          mdg->hasDependencePath(dstNode->id, sibNode->id))
        return false;
      // Skip sib node if it loads from (and stores to) the same memref on
      // which it also has an input dependence edge.
      DenseSet<Value> loadAndStoreMemrefSet;
      sibNode->getLoadAndStoreMemrefSet(&loadAndStoreMemrefSet);
      if (llvm::any_of(loadAndStoreMemrefSet, [=](Value memref) {
            return mdg->getIncomingMemRefAccesses(sibNode->id, memref) > 0;
          }))
        return false;

      // Check that all stores are to the same memref.
      DenseSet<Value> storeMemrefs;
      for (auto *storeOpInst : sibNode->stores) {
        storeMemrefs.insert(
            cast<AffineWriteOpInterface>(storeOpInst).getMemRef());
      }
      if (storeMemrefs.size() != 1)
        return false;

      // Skip if a memref value in one node is used by a non-affine memref
      // access that lies between 'dstNode' and 'sibNode'.
      if (hasNonAffineUsersOnThePath(dstNode->id, sibNode->id, mdg) ||
          hasNonAffineUsersOnThePath(sibNode->id, dstNode->id, mdg))
        return false;
      return true;
    };

    // Search for siblings which load the same memref function argument.
    auto fn = dstNode->op->getParentOfType<func::FuncOp>();
    for (unsigned i = 0, e = fn.getNumArguments(); i != e; ++i) {
      for (auto *user : fn.getArgument(i).getUsers()) {
        if (auto loadOp = dyn_cast<AffineReadOpInterface>(user)) {
          // Gather loops surrounding 'use'.
          SmallVector<AffineForOp, 4> loops;
          getLoopIVs(*user, &loops);
          // Skip 'use' if it is not within a loop nest.
          if (loops.empty())
            continue;
          Node *sibNode = mdg->getForOpNode(loops[0]);
          assert(sibNode != nullptr);
          // Skip 'use' if it is not a sibling to 'dstNode'.
          if (sibNode->id == dstNode->id)
            continue;
          // Skip 'use' if it has been visited.
          if (visitedSibNodeIds->count(sibNode->id) > 0)
            continue;
          // Skip 'use' if it does not load from the same memref as 'dstNode'.
          auto memref = loadOp.getMemRef();
          if (dstNode->getLoadOpCount(memref) == 0)
            continue;
          // Check if 'sibNode/dstNode' can be input-reuse fused on 'memref'.
          if (canFuseWithSibNode(sibNode, memref)) {
            visitedSibNodeIds->insert(sibNode->id);
            idAndMemrefToFuse->first = sibNode->id;
            idAndMemrefToFuse->second = memref;
            return true;
          }
        }
      }
    }

    // Search for siblings by following edges through an intermediate src
    // node. Collect candidate 'dstNode' input edges in 'inEdges'.
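    // For example (hypothetical): if an intermediate node 'n' stores to %A,
    // and both 'dstNode' and another loop nest load from %A, that nest is
    // discovered by walking the input edge (n -> dstNode on %A) back to 'n'
    // and then one of n's output edges (n -> sibNode on %A).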
    SmallVector<MemRefDependenceGraph::Edge, 2> inEdges;
    mdg->forEachMemRefInputEdge(
        dstNode->id, [&](MemRefDependenceGraph::Edge inEdge) {
          // Add 'inEdge' if it is a read-after-write dependence.
          if (dstNode->getLoadOpCount(inEdge.value) > 0 &&
              mdg->getNode(inEdge.id)->getStoreOpCount(inEdge.value) > 0)
            inEdges.push_back(inEdge);
        });

    // Search for sibling nodes to fuse by visiting output edges from each
    // input edge in 'inEdges'.
    for (auto &inEdge : inEdges) {
      // Collect candidate output edges from each node 'inEdge.id' in
      // 'inEdges'.
      SmallVector<MemRefDependenceGraph::Edge, 2> outEdges;
      mdg->forEachMemRefOutputEdge(
          inEdge.id, [&](MemRefDependenceGraph::Edge outEdge) {
            unsigned sibNodeId = outEdge.id;
            if (visitedSibNodeIds->count(sibNodeId) > 0)
              return;
            // Skip output edge if not a sibling using the same memref.
            if (outEdge.id == dstNode->id || outEdge.value != inEdge.value)
              return;
            auto *sibNode = mdg->getNode(sibNodeId);
            if (!isa<AffineForOp>(sibNode->op))
              return;
            // Check if 'sibNode/dstNode' can be input-reuse fused on
            // 'memref'.
            if (canFuseWithSibNode(sibNode, outEdge.value)) {
              // Add candidate 'outEdge' to sibling node.
              outEdges.push_back(outEdge);
            }
          });

      // Add first candidate if any were returned.
      if (!outEdges.empty()) {
        visitedSibNodeIds->insert(outEdges[0].id);
        idAndMemrefToFuse->first = outEdges[0].id;
        idAndMemrefToFuse->second = outEdges[0].value;
        return true;
      }
    }
    return false;
  }

  /// Update data dependence graph state to reflect sibling fusion of
  /// 'sibNode' into 'dstNode'.
  void updateStateAfterSiblingFusion(Node *sibNode, Node *dstNode) {
    // Update 'sibNode' and 'dstNode' input/output edges to reflect fusion.
    mdg->updateEdges(sibNode->id, dstNode->id);

    // Collect dst loop stats after memref privatization transformation.
    auto dstForInst = cast<AffineForOp>(dstNode->op);
    LoopNestStateCollector dstLoopCollector;
    dstLoopCollector.collect(dstForInst.getOperation());
    // Clear and add back loads and stores.
    mdg->clearNodeLoadAndStores(dstNode->id);
    mdg->addToNode(dstNode->id, dstLoopCollector.loadOpInsts,
                   dstLoopCollector.storeOpInsts);
    // Remove old sibling loop nest if it no longer has outgoing dependence
    // edges, and it does not write to a memref which escapes the
    // function.
    if (mdg->getOutEdgeCount(sibNode->id) == 0) {
      Operation *op = sibNode->op;
      mdg->removeNode(sibNode->id);
      op->erase();
    }
  }

  // Clean up any allocs with no users.
  void eraseUnusedMemRefAllocations() {
    for (auto &pair : mdg->memrefEdgeCount) {
      if (pair.second > 0)
        continue;
      auto memref = pair.first;
      // Skip if there exist other uses (return operation or function calls).
      if (!memref.use_empty())
        continue;
      // Use list expected to match the dep graph info.
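      // For example, a producer's 'memref.alloc' whose accesses were all
      // redirected to a private memref, and whose defining nest was erased
      // after fusion, is left with zero uses and is deleted here.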
      auto *op = memref.getDefiningOp();
      if (isa_and_nonnull<memref::AllocOp>(op))
        op->erase();
    }
  }
};

} // namespace

void LoopFusion::runOnOperation() {
  MemRefDependenceGraph g;
  if (!g.init(getOperation()))
    return;

  Optional<unsigned> fastMemorySpaceOpt;
  if (fastMemorySpace.hasValue())
    fastMemorySpaceOpt = fastMemorySpace;
  unsigned localBufSizeThresholdBytes = localBufSizeThreshold * 1024;
  GreedyFusion fusion(&g, localBufSizeThresholdBytes, fastMemorySpaceOpt,
                      maximalFusion, computeToleranceThreshold);

  if (affineFusionMode == FusionMode::ProducerConsumer)
    fusion.runProducerConsumerFusionOnly();
  else if (affineFusionMode == FusionMode::Sibling)
    fusion.runSiblingFusionOnly();
  else
    fusion.runGreedyFusion();
}
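// A minimal usage sketch (assumptions: the standard pass registration is in
// effect and FusionMode::Greedy is the default mode; see Passes.td for the
// authoritative pass and option names). From the command line:
//
//   mlir-opt input.mlir -affine-loop-fusion
//
// Or programmatically, on a module-level pass manager:
//
//   pm.addNestedPass<func::FuncOp>(mlir::createLoopFusionPass(
//       /*fastMemorySpace=*/0, /*localBufSizeThreshold=*/0,
//       /*maximalFusion=*/false, FusionMode::Greedy));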