1 //===- Fusion.cpp - Implementation of linalg Fusion -----------------------===// 2 // 3 // Part of the MLIR Project, under the Apache License v2.0 with LLVM Exceptions. 4 // See https://llvm.org/LICENSE.txt for license information. 5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception 6 // 7 //===----------------------------------------------------------------------===// 8 // 9 // This file implements the linalg dialect Fusion pass. 10 // 11 //===----------------------------------------------------------------------===// 12 13 #include "mlir/Analysis/Dominance.h" 14 #include "mlir/Dialect/Linalg/Analysis/DependenceAnalysis.h" 15 #include "mlir/Dialect/Linalg/IR/LinalgOps.h" 16 #include "mlir/Dialect/Linalg/IR/LinalgTypes.h" 17 #include "mlir/Dialect/Linalg/Passes.h" 18 #include "mlir/Dialect/Linalg/Utils/Intrinsics.h" 19 #include "mlir/Dialect/Linalg/Utils/Utils.h" 20 #include "mlir/EDSC/Helpers.h" 21 #include "mlir/IR/AffineExpr.h" 22 #include "mlir/IR/AffineMap.h" 23 #include "mlir/IR/OpImplementation.h" 24 #include "mlir/Pass/Pass.h" 25 #include "mlir/Support/LLVM.h" 26 #include "mlir/Support/STLExtras.h" 27 #include "mlir/Transforms/FoldUtils.h" 28 #include "llvm/ADT/SetVector.h" 29 #include "llvm/Support/CommandLine.h" 30 #include "llvm/Support/Debug.h" 31 32 #define DEBUG_TYPE "linalg-fusion" 33 34 using namespace mlir; 35 using namespace mlir::edsc; 36 using namespace mlir::edsc::intrinsics; 37 using namespace mlir::linalg; 38 using namespace mlir::linalg::intrinsics; 39 40 using llvm::dbgs; 41 42 /// Implements a simple high-level fusion pass of linalg library operations. 43 /// 44 /// In each block, linalg ops are processed in reverse textual order. 45 /// Given a linalg op `O`, fusion occurs by: 46 /// 1. inspecting the linalg ops that write into the views read by `O`. This 47 /// uses the SSA value of the views and a simple subview/slice analysis to 48 /// determine producer-consumer dependences; 49 /// 2. greedily fuse the linalg ops that produce subview 50 /// 3. inspect the fused ops and determine whether they have other remaining 51 /// LinalgOp uses. If not, then erase the original producing linalg op. 52 /// 53 /// More advanced use cases, analyses as well as profitability heuristics are 54 /// left for future work. 55 56 static llvm::cl::OptionCategory clOptionsCategory(DEBUG_TYPE " options"); 57 static llvm::cl::list<unsigned> clTileSizes( 58 "linalg-fusion-tile-sizes", 59 llvm::cl::desc( 60 "Tile sizes by which to tile linalg operations during linalg fusion"), 61 llvm::cl::ZeroOrMore, llvm::cl::MiscFlags::CommaSeparated, 62 llvm::cl::cat(clOptionsCategory)); 63 64 // Return a cloned version of `op` that operates on `loopRanges`, assumed to be 65 // a subset of the original loop ranges of `op`. 66 // This is achieved by applying the `loopToOperandRangesMaps` permutation maps 67 // to the `loopRanges` in order to obtain view ranges. 68 static LinalgOp cloneWithLoopRanges(OpBuilder &b, Location loc, LinalgOp op, 69 ArrayRef<SubViewOp::Range> loopRanges) { 70 assert(op.hasBufferSemantics() && "expected linalg op with buffer semantics"); 71 auto maps = loopToOperandRangesMaps(op); 72 SmallVector<Value, 8> clonedViews; 73 clonedViews.reserve(op.getNumInputsAndOutputs()); 74 // Iterate over the inputs and outputs in order. 75 // Extract the subranges from the linearized ranges. 76 SmallVector<Value, 8> ios(op.getInputsAndOutputBuffers()); 77 for (auto en : llvm::enumerate(ios)) { 78 unsigned idx = en.index(); 79 auto map = maps[idx]; 80 LLVM_DEBUG(dbgs() << "map: " << map << "\n"); 81 Value view = en.value(); 82 SmallVector<SubViewOp::Range, 4> viewRanges(map.getNumResults()); 83 for (auto en2 : llvm::enumerate(map.getResults())) { 84 unsigned d = en2.index(); 85 // loopToOperandRangesMaps are permutations-only. 86 unsigned loopPos = en2.value().cast<AffineDimExpr>().getPosition(); 87 viewRanges[d] = loopRanges[loopPos]; 88 LLVM_DEBUG(dbgs() << "\ni,j: " << en.index() << ", " << en2.index() 89 << "\t" 90 << "loopPos: " << loopPos << "\t" << viewRanges[d]); 91 } 92 // Construct a new subview for the tile. 93 unsigned rank = viewRanges.size(); 94 SmallVector<Value, 4> offsets, sizes, strides; 95 offsets.reserve(rank); 96 sizes.reserve(rank); 97 strides.reserve(rank); 98 for (auto r : viewRanges) { 99 offsets.push_back(r.offset); 100 sizes.push_back(r.size); 101 strides.push_back(r.stride); 102 } 103 clonedViews.push_back( 104 b.create<SubViewOp>(loc, view, offsets, sizes, strides)); 105 } 106 auto operands = getAssumedNonViewOperands(op); 107 clonedViews.append(operands.begin(), operands.end()); 108 return op.clone(b, loc, clonedViews); 109 } 110 111 struct ViewDimension { 112 Value view; 113 unsigned dimension; 114 }; 115 116 // Given an `op`, returns the first (`view`, `dimension`) pair that identifies 117 // the loop range at `loopDepth`. The semantics of the loopToOperandRangesMaps 118 // guarantees at least one such dimension is found. If multiple candidates exist 119 // they must agree by construction (i.e. have the same size) and we just return 120 // the first one. 121 static ViewDimension getViewDefiningLoopRange(LinalgOp op, unsigned loopDepth) { 122 assert(op.hasBufferSemantics() && "expected linalg op with buffer semantics"); 123 auto maps = loopToOperandRangesMaps(op); 124 // Iterate over the inputs and outputs in order. 125 // Extract the subranges from the linearized ranges. 126 SmallVector<Value, 8> ios(op.getInputsAndOutputBuffers()); 127 for (auto en : llvm::enumerate(ios)) { 128 unsigned idx = en.index(); 129 auto map = maps[idx]; 130 LLVM_DEBUG(dbgs() << "getViewDefiningLoopRange I/O idx: " << idx << "\n"); 131 LLVM_DEBUG(dbgs() << "getViewDefiningLoopRange map: " << map << "\n"); 132 Value view = en.value(); 133 SmallVector<Value, 8> viewRanges(map.getNumResults(), nullptr); 134 for (auto en2 : llvm::enumerate(map.getResults())) { 135 if (loopDepth == en2.value().cast<AffineDimExpr>().getPosition()) { 136 LLVM_DEBUG(dbgs() << "getViewDefiningLoopRange loopDepth: " << loopDepth 137 << "\n"); 138 LLVM_DEBUG(dbgs() << "getViewDefiningLoopRange view: " << view << "\n"); 139 return ViewDimension{view, static_cast<unsigned>(en2.index())}; 140 } 141 } 142 } 143 llvm_unreachable("Expect to be able to extract a view defining loop range"); 144 } 145 146 static LinalgOp fuse(Value producedView, LinalgOp producer, LinalgOp consumer, 147 unsigned consumerIdx, unsigned producerIdx, 148 OperationFolder *folder) { 149 assert(producer.hasBufferSemantics() && 150 "expected linalg op with buffer semantics"); 151 assert(consumer.hasBufferSemantics() && 152 "expected linalg op with buffer semantics"); 153 auto subView = dyn_cast_or_null<SubViewOp>( 154 consumer.getInput(consumerIdx).getDefiningOp()); 155 auto slice = 156 dyn_cast_or_null<SliceOp>(consumer.getInput(consumerIdx).getDefiningOp()); 157 assert(subView || slice); 158 (void)subView; 159 (void)slice; 160 161 // loopToOperandRangesMaps are permutations-only by construction: 162 // we can always identify a data dimension with a (at least one) loop 163 // dimension. 164 AffineMap producerMap = 165 loopToOperandRangesMaps(producer)[producer.getNumInputs() + producerIdx]; 166 LLVM_DEBUG(dbgs() << "Producer Idx: " << producerIdx 167 << ", producer map: " << producerMap << "\n"); 168 169 unsigned nPar = producer.getNumParallelLoops(); 170 unsigned nRed = producer.getNumReductionLoops(); 171 unsigned nWin = producer.getNumWindowLoops(); 172 SmallVector<SubViewOp::Range, 8> loopRanges(nPar + nRed + nWin); 173 174 // Iterate over dimensions identified by the producer map for `producerIdx`. 175 // This defines a subset of the loop ranges that we need to complete later. 176 for (auto en : llvm::enumerate(producerMap.getResults())) { 177 unsigned posInProducerLoop = en.value().cast<AffineDimExpr>().getPosition(); 178 loopRanges[posInProducerLoop] = subView.getRanges()[en.index()]; 179 } 180 181 OpBuilder b(consumer.getOperation()); 182 auto loc = consumer.getLoc(); 183 // Iterate over all dimensions. For the dimensions not identified by the 184 // producer map for `producerIdx`, we need to explicitly compute the view that 185 // defines the loop ranges using the `producer`. 186 for (unsigned i = 0, nLoops = loopRanges.size(); i < nLoops; ++i) { 187 if (loopRanges[i].offset) 188 LLVM_DEBUG(llvm::dbgs() 189 << "existing LoopRange: " << loopRanges[i] << "\n"); 190 else { 191 auto viewDim = getViewDefiningLoopRange(producer, i); 192 loopRanges[i] = SubViewOp::Range{constant_index(folder, 0), 193 dim(viewDim.view, viewDim.dimension), 194 constant_index(folder, 1)}; 195 LLVM_DEBUG(llvm::dbgs() << "new LoopRange: " << loopRanges[i] << "\n"); 196 } 197 } 198 199 return cloneWithLoopRanges(b, loc, producer, loopRanges); 200 } 201 202 // Encode structural fusion safety preconditions. 203 // Some of these will be lifted in the future with better analysis. 204 static bool isStructurallyFusableProducer(LinalgOp producer, Value consumedView, 205 LinalgOp consumer) { 206 assert(producer.hasBufferSemantics() && 207 "expected linalg op with buffer semantics"); 208 assert(consumer.hasBufferSemantics() && 209 "expected linalg op with buffer semantics"); 210 if (producer.getNumOutputs() != 1) { 211 LLVM_DEBUG(dbgs() << "\nNot structurally fusable (multi-output)"); 212 return false; 213 } 214 // Only fuse when the producer block dominates. 215 DominanceInfo dom(producer.getOperation()); 216 if (!dom.dominates(producer.getOperation()->getBlock(), 217 consumer.getOperation()->getBlock())) { 218 LLVM_DEBUG( 219 dbgs() 220 << "\nNot structurally fusable (producer block does not dominate)"); 221 return false; 222 } 223 return true; 224 } 225 226 bool mlir::linalg::isProducerLastWriteOfView(const LinalgDependenceGraph &graph, 227 LinalgOp consumer, 228 Value consumedView, 229 LinalgOp producer) { 230 assert(producer.hasBufferSemantics() && 231 "expected linalg op with buffer semantics"); 232 assert(consumer.hasBufferSemantics() && 233 "expected linalg op with buffer semantics"); 234 // Make some simple structural checks that alleviate the need for more 235 // complex analyses. 236 if (!isStructurallyFusableProducer(producer, consumedView, consumer)) { 237 LLVM_DEBUG(dbgs() << "\n***Not static last write due to structure:\t" 238 << *producer.getOperation()); 239 return false; 240 } 241 // Check for any interleaved write to consumedView. 242 if (!graph.findCoveringWrites(producer, consumer, consumedView).empty()) { 243 LLVM_DEBUG(dbgs() << "\n***Not fusable due to interleaved write:\t" 244 << *producer.getOperation()); 245 return false; 246 } 247 return true; 248 } 249 250 bool mlir::linalg::isFusableInto(const LinalgDependenceGraph &graph, 251 LinalgOp consumer, Value consumedView, 252 LinalgOp producer) { 253 assert(producer.hasBufferSemantics() && 254 "expected linalg op with buffer semantics"); 255 assert(consumer.hasBufferSemantics() && 256 "expected linalg op with buffer semantics"); 257 if (!isProducerLastWriteOfView(graph, consumer, consumedView, producer)) 258 return false; 259 // Check for any fusion-preventing dependence to any view read/written that 260 // would violate dependences. 261 if (!graph.findCoveringDependences(producer, consumer).empty()) { 262 LLVM_DEBUG(dbgs() << "\n***Not fusable due to an interleaved dependence:\t" 263 << *producer.getOperation()); 264 return false; 265 } 266 return true; 267 } 268 269 // Only consider RAW atm. 270 Optional<FusionInfo> mlir::linalg::fuseProducerOf( 271 OpBuilder &b, LinalgOp consumer, unsigned consumerIdx, 272 const LinalgDependenceGraph &graph, OperationFolder *folder) { 273 assert(consumer.hasBufferSemantics() && 274 "expected linalg op with buffer semantics"); 275 LLVM_DEBUG(dbgs() << "\nStart examining consumer: " 276 << *consumer.getOperation()); 277 for (auto dependence : graph.getDependencesInto( 278 consumer, LinalgDependenceGraph::DependenceType::RAW)) { 279 LLVM_DEBUG(dbgs() << "\n***Consider producer:\t" 280 << *dependence.dependentOpView.op << "\n"); 281 auto producer = cast<LinalgOp>(dependence.dependentOpView.op); 282 283 // Check that the dependence is indeed on the input `consumerIdx` view. 284 auto consumedView = dependence.indexingView; 285 if (consumer.getInput(consumerIdx) != consumedView) 286 continue; 287 288 // Consumer consumes this view, `isStructurallyFusableProducer` also checks 289 // whether it is a strict subview of the producer view. 290 auto producedView = dependence.dependentOpView.view; 291 auto producerIdx = producer.getIndexOfOutputBuffer(producedView).getValue(); 292 // `consumerIdx` and `producerIdx` exist by construction. 293 LLVM_DEBUG(dbgs() << "\nRAW producer: " << *producer.getOperation() 294 << " view: " << producedView 295 << " output index: " << producerIdx); 296 297 // Must be a subview or a slice to guarantee there are loops we can fuse 298 // into. 299 auto subView = dyn_cast_or_null<SubViewOp>(consumedView.getDefiningOp()); 300 auto slice = dyn_cast_or_null<SliceOp>(consumedView.getDefiningOp()); 301 if (!subView && !slice) { 302 LLVM_DEBUG(dbgs() << "\nNot fusable (not a subview or slice)"); 303 continue; 304 } 305 306 // Simple fusability checks. 307 if (!isFusableInto(graph, consumer, consumedView, producer)) 308 continue; 309 310 // Fuse `producer` just before `consumer`. 311 OpBuilder::InsertionGuard g(b); 312 b.setInsertionPoint(consumer.getOperation()); 313 ScopedContext scope(b, consumer.getLoc()); 314 LLVM_DEBUG(dbgs() << "Fuse into consumer: " << *consumer << "\n"); 315 auto fusedProducer = fuse(producedView, producer, consumer, consumerIdx, 316 producerIdx, folder); 317 318 return FusionInfo{producer, fusedProducer}; 319 } 320 return llvm::None; 321 } 322 323 static void fuseLinalgOpsGreedily(FuncOp f) { 324 LLVM_DEBUG(f.print(dbgs() << "\nBefore linalg-fusion: \n")); 325 326 OpBuilder b(f); 327 OperationFolder folder(f.getContext()); 328 DenseSet<Operation *> eraseSet; 329 330 // Save original Linalg ops, we only want to make a pass over those. 331 SmallVector<Operation *, 8> linalgOps; 332 f.walk([&](LinalgOp op) { 333 if (op.hasBufferSemantics()) 334 linalgOps.push_back(op); 335 }); 336 337 Aliases aliases; 338 LinalgDependenceGraph G(aliases, linalgOps); 339 for (auto *op : llvm::reverse(linalgOps)) { 340 for (unsigned consumerIdx = 0, e = LinalgOp(op).getNumInputs(); 341 consumerIdx < e; ++consumerIdx) { 342 if (auto fusionInfo = fuseProducerOf(b, op, consumerIdx, G, &folder)) 343 eraseSet.insert(fusionInfo->originalProducer.getOperation()); 344 } 345 } 346 347 // The `fuseProducerOf` function performs structural checks and in particular 348 // that no covering read or write exist between the consumer and the producer. 349 // As a consequence, the only fusions that may occur preserve subsequent 350 // dependences and are guaranteed by construction to produce the whole view. 351 // We may thus erase the producer once it is fused. 352 for (auto *e : eraseSet) 353 e->erase(); 354 LLVM_DEBUG(f.print(dbgs() << "\nAfter linalg-fusion: \n")); 355 } 356 357 namespace { 358 struct LinalgFusionPass : public FunctionPass<LinalgFusionPass> { 359 void runOnFunction() override { fuseLinalgOpsGreedily(getFunction()); } 360 }; 361 } // namespace 362 363 std::unique_ptr<OpPassBase<FuncOp>> mlir::linalg::createLinalgFusionPass() { 364 return std::make_unique<LinalgFusionPass>(); 365 } 366 367 static PassRegistration<LinalgFusionPass> 368 pass("linalg-fusion", "Fuse operations in the linalg dialect"); 369