//===- Loops.cpp - conversion from Linalg named and generic ops to loops --===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//

#include "PassDetail.h"
#include "mlir/Dialect/Arithmetic/IR/Arithmetic.h"
#include "mlir/Dialect/Arithmetic/Utils/Utils.h"
#include "mlir/Dialect/Linalg/IR/Linalg.h"
#include "mlir/Dialect/Linalg/Passes.h"
#include "mlir/Dialect/Linalg/Transforms/Transforms.h"
#include "mlir/Dialect/Linalg/Utils/Utils.h"
#include "mlir/Dialect/SCF/Transforms.h"
#include "mlir/Dialect/SCF/Utils/AffineCanonicalizationUtils.h"
#include "mlir/IR/AffineExpr.h"
#include "mlir/IR/AffineMap.h"
#include "mlir/IR/BlockAndValueMapping.h"
#include "mlir/Support/LLVM.h"
#include "mlir/Transforms/DialectConversion.h"
#include "mlir/Transforms/FoldUtils.h"
#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
#include "llvm/ADT/TypeSwitch.h"

using namespace mlir;
using namespace mlir::linalg;

/// Emits one canonicalized `affine.apply` per result of `map`, applied to
/// `vals`, and returns the resulting values. Returns an empty vector when
/// `map` is empty.
static SmallVector<Value> makeCanonicalAffineApplies(OpBuilder &b, Location loc,
                                                     AffineMap map,
                                                     ArrayRef<Value> vals) {
  if (map.isEmpty())
    return {};

  assert(map.getNumInputs() == vals.size());
  SmallVector<Value> res;
  res.reserve(map.getNumResults());
  auto dims = map.getNumDims();
  for (auto e : map.getResults()) {
    auto exprMap = AffineMap::get(dims, map.getNumSymbols(), e);
    SmallVector<Value> operands(vals.begin(), vals.end());
    canonicalizeMapAndOperands(&exprMap, &operands);
    res.push_back(b.create<AffineApplyOp>(loc, exprMap, operands));
  }
  return res;
}

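/// Inlines the single-block region of `op` at the insertion point of `b`,
/// mapping the region's block arguments to `indexedValues`. For each operand
/// of the region's terminator, emits a `StoreOpTy` that writes the yielded
/// value to the matching buffer in `outputBuffers` at the indices given by
/// `indexing`.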
template <typename LoadOpTy, typename StoreOpTy, typename OpType>
static void inlineRegionAndEmitStore(OpBuilder &b, Location loc, OpType op,
                                     ArrayRef<Value> indexedValues,
                                     ArrayRef<SmallVector<Value>> indexing,
                                     ArrayRef<Value> outputBuffers) {
  auto &block = op->getRegion(0).front();
  BlockAndValueMapping map;
  map.map(block.getArguments(), indexedValues);
  for (auto &op : block.without_terminator()) {
    auto *newOp = b.clone(op, map);
    map.map(op.getResults(), newOp->getResults());
  }

  Operation *terminator = block.getTerminator();
  for (OpOperand &operand : terminator->getOpOperands()) {
    Value toStore = map.lookupOrDefault(operand.get());
    b.create<StoreOpTy>(loc, toStore, outputBuffers[operand.getOperandNumber()],
                        indexing[operand.getOperandNumber()]);
  }
}

// Returns a pair that contains input indices and output indices of a
// SingleInputPoolingOp `op`.
struct InputAndOutputIndices {
  SmallVector<Value> inputs;
  SmallVector<Value> outputs;
};
template <typename SingleInputPoolingOp>
static InputAndOutputIndices
getInputAndOutputIndices(OpBuilder &b, Location loc, ArrayRef<Value> allIvs,
                         SingleInputPoolingOp op) {
  auto mapsRange = op.indexing_maps().template getAsRange<AffineMapAttr>();
  auto maps = llvm::to_vector<8>(
      llvm::map_range(mapsRange, [](AffineMapAttr a) { return a.getValue(); }));
  return InputAndOutputIndices{
      makeCanonicalAffineApplies(b, loc, maps[0], allIvs),
      makeCanonicalAffineApplies(b, loc, maps[2], allIvs)};
}

/// Emits the MLIR for the scalar part of the generic op by:
///   1. Emitting load ops for each input and output view in order. This is
///      achieved by applying the appropriate input or output map to the
///      enclosing induction variables.
///   2. Emitting a call to `op.fun()` that takes as arguments the scalars
///      from point 1. above.
///   3. Emitting store ops to store the results of 2. to the output
///      views.
///
/// An example output may resemble:
///
/// ```
///    scf.for %i = %c0 to %0 step %c1 {
///      scf.for %j = %c0 to %1 step %c1 {
///        scf.for %k = %c0 to %4 step %c1 {
///          %11 = load %arg0[%i, %j] :
///            memref<?x?xf32, stride_specification>
///          %12 = load %arg1[%i, %j, %k] :
///            memref<?x?x?xf32, stride_specification>
///          %13 = load %arg2[%i, %k, %j] :
///            memref<?x?x?xf32, stride_specification>
///          %14:2 = call @foo(%11, %12, %13) : (f32, f32, f32) -> (f32, f32)
///          store %14#0, %arg1[%i, %j, %k] :
///            memref<?x?x?xf32, stride_specification>
///          store %14#1, %arg2[%i, %k, %j] :
///            memref<?x?x?xf32, stride_specification>
///        }
///      }
///    }
/// ```
template <typename LoadOpTy, typename StoreOpTy>
static void emitScalarImplementation(OpBuilder &b, Location loc,
                                     ArrayRef<Value> allIvs,
                                     LinalgOp linalgOp) {
  assert(linalgOp.hasBufferSemantics() &&
         "expected linalg op with buffer semantics");
  SmallVector<Value> indexedValues;
  indexedValues.reserve(linalgOp.getNumInputsAndOutputs());

  auto allIvsPlusDims = SmallVector<Value>(allIvs.begin(), allIvs.end());

  // TODO: Avoid the loads if the corresponding argument of the
  // region has no uses.
  // 1.a. Emit load from input operand or, for scalars, access the operand
  // itself.
  for (OpOperand *inputOperand : linalgOp.getInputOperands()) {
    if (linalgOp.isScalar(inputOperand)) {
      indexedValues.push_back(inputOperand->get());
      continue;
    }
    auto indexing = makeCanonicalAffineApplies(
        b, loc, linalgOp.getTiedIndexingMap(inputOperand), allIvsPlusDims);
    indexedValues.push_back(
        b.create<LoadOpTy>(loc, inputOperand->get(), indexing));
  }
  // 1.b. Emit load from output views.
  for (OpOperand *outputOperand : linalgOp.getOutputOperands()) {
    SmallVector<Value> indexing = makeCanonicalAffineApplies(
        b, loc, linalgOp.getTiedIndexingMap(outputOperand), allIvsPlusDims);
    indexedValues.push_back(
        b.create<LoadOpTy>(loc, outputOperand->get(), indexing));
  }

  // TODO: When a region inliner exists, use it.
  // 2. Inline region, currently only works for a single basic block.
  // 3. Emit store.
  SmallVector<SmallVector<Value>, 8> indexing;
  SmallVector<Value> outputBuffers;
  for (OpOperand *outputOperand : linalgOp.getOutputBufferOperands()) {
    indexing.push_back(makeCanonicalAffineApplies(
        b, loc, linalgOp.getTiedIndexingMap(outputOperand), allIvsPlusDims));
    outputBuffers.push_back(outputOperand->get());
  }
  inlineRegionAndEmitStore<LoadOpTy, StoreOpTy>(b, loc, linalgOp, indexedValues,
                                                indexing, outputBuffers);
}

/// Replace the index operations in the body of the loop nest by the matching
/// induction variables.
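/// For example (illustrative nest), given `scf.parallel (%i, %j)` wrapping an
/// inner `scf.for %k`, the induction variables are collected outer-to-inner as
/// [%i, %j, %k], and a `linalg.index 2` op in the innermost body is replaced
/// by %k.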
static void replaceIndexOpsByInductionVariables(LinalgOp linalgOp,
                                                PatternRewriter &rewriter,
                                                ArrayRef<Operation *> loopOps) {
  // Extract the induction variables of the loop nest from outer to inner.
  SmallVector<Value> allIvs;
  for (Operation *loopOp : loopOps) {
    llvm::TypeSwitch<Operation *>(loopOp)
        .Case([&](scf::ParallelOp parallelOp) {
          allIvs.append(parallelOp.getInductionVars().begin(),
                        parallelOp.getInductionVars().end());
        })
        .Case([&](scf::ForOp forOp) {
          allIvs.push_back(forOp.getInductionVar());
        })
        .Case([&](AffineForOp affineForOp) {
          allIvs.push_back(affineForOp.getInductionVar());
        })
        .Default([&](Operation *op) { assert(false && "unexpected op"); });
  }
  assert(linalgOp.getNumLoops() == allIvs.size() &&
         "expected the number of loops and induction variables to match");
  // Replace the index operations in the body of the innermost loop op.
  if (!loopOps.empty()) {
    LoopLikeOpInterface loopOp = loopOps.back();
    for (IndexOp indexOp :
         llvm::make_early_inc_range(loopOp.getLoopBody().getOps<IndexOp>()))
      rewriter.replaceOp(indexOp, allIvs[indexOp.dim()]);
  }
}

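/// Generates a loop nest of type `LoopTy` (`scf.for`, `scf.parallel`, or
/// `affine.for`) with a scalarized body for `linalgOp`, using affine
/// loads/stores for affine loops and memref loads/stores otherwise, then
/// rewrites the `linalg.index` ops to the new induction variables. Returns the
/// generated loop operations, or failure if the induction variables cannot be
/// traced back to their loop operations.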
template <typename LoopTy>
static FailureOr<LinalgLoops> linalgOpToLoopsImpl(PatternRewriter &rewriter,
                                                  LinalgOp linalgOp) {
  using LoadOpTy =
      typename std::conditional<std::is_same<LoopTy, AffineForOp>::value,
                                AffineLoadOp, memref::LoadOp>::type;
  using StoreOpTy =
      typename std::conditional<std::is_same<LoopTy, AffineForOp>::value,
                                AffineStoreOp, memref::StoreOp>::type;

  // The flattened loopToOperandRangesMaps is expected to be an invertible
  // permutation map (which is asserted in the inverse calculation).
  assert(linalgOp.hasBufferSemantics() &&
         "expected linalg op with buffer semantics");

  auto loopRanges = linalgOp.createLoopRanges(rewriter, linalgOp.getLoc());
  auto iteratorTypes = llvm::to_vector<4>(linalgOp.iterator_types().getValue());

  SmallVector<Value> allIvs;
  GenerateLoopNest<LoopTy>::doit(
      rewriter, linalgOp.getLoc(), loopRanges, linalgOp, iteratorTypes,
      [&](OpBuilder &b, Location loc, ValueRange ivs,
          ValueRange operandValuesToUse) -> scf::ValueVector {
        assert(operandValuesToUse == linalgOp->getOperands() &&
               "expect operands are captured and not passed by loop argument");
        allIvs.append(ivs.begin(), ivs.end());
        emitScalarImplementation<LoadOpTy, StoreOpTy>(b, loc, allIvs, linalgOp);
        return scf::ValueVector{};
      });
  // Number of loop ops might be different from the number of ivs since some
  // loops like affine.parallel and scf.parallel have multiple ivs.
  SetVector<Operation *> loopSet;
  for (Value iv : allIvs) {
    if (!iv)
      return failure();
    // The induction variable is a block argument of the entry block of the
    // loop operation.
    BlockArgument ivVal = iv.dyn_cast<BlockArgument>();
    if (!ivVal)
      return failure();
    loopSet.insert(ivVal.getOwner()->getParentOp());
  }
  LinalgLoops loops(loopSet.begin(), loopSet.end());
  // Replace all index operations in the loop body.
  replaceIndexOpsByInductionVariables(linalgOp, rewriter, loops);
  return loops;
}

namespace {
template <typename LoopType>
class LinalgRewritePattern : public RewritePattern {
public:
  LinalgRewritePattern(MLIRContext *context)
      : RewritePattern(MatchAnyOpTypeTag(), /*benefit=*/1, context) {}

  LogicalResult matchAndRewrite(Operation *op,
                                PatternRewriter &rewriter) const override {
    auto linalgOp = dyn_cast<LinalgOp>(op);
    if (!linalgOp)
      return failure();
    if (failed(linalgOpToLoopsImpl<LoopType>(rewriter, linalgOp)))
      return failure();
    rewriter.eraseOp(op);
    return success();
  }
};

/// Local folding pattern for AffineApplyOp that we can apply greedily.
/// This replaces AffineApplyOp by the proper value in cases where the
/// associated map is trivial.
/// A trivial map here is defined as a map with a single result and either:
///   1. Zero operands and a single AffineConstantExpr result
///   2. One operand and a single AffineDimExpr result
///   3. One operand and a single AffineSymbolExpr result
///
/// In the first case, the AffineApplyOp is replaced by a new constant. In the
/// other cases, it is replaced by its unique operand.
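///
/// For example (illustrative IR): `affine.apply affine_map<() -> (42)>()` is
/// replaced by `arith.constant 42 : index`, and
/// `affine.apply affine_map<(d0) -> (d0)>(%i)` is replaced by `%i`.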
struct FoldAffineOp : public RewritePattern {
  FoldAffineOp(MLIRContext *context)
      : RewritePattern(AffineApplyOp::getOperationName(), 0, context) {}

  LogicalResult matchAndRewrite(Operation *op,
                                PatternRewriter &rewriter) const override {
    AffineApplyOp affineApplyOp = cast<AffineApplyOp>(op);
    auto map = affineApplyOp.getAffineMap();
    if (map.getNumResults() != 1 || map.getNumInputs() > 1)
      return failure();

    AffineExpr expr = map.getResult(0);
    if (map.getNumInputs() == 0) {
      if (auto val = expr.dyn_cast<AffineConstantExpr>()) {
        rewriter.replaceOpWithNewOp<arith::ConstantIndexOp>(op, val.getValue());
        return success();
      }
      return failure();
    }
    if (expr.dyn_cast<AffineDimExpr>() || expr.dyn_cast<AffineSymbolExpr>()) {
      rewriter.replaceOp(op, op->getOperand(0));
      return success();
    }
    return failure();
  }
};

template <typename LoopType>
static void lowerLinalgToLoopsImpl(func::FuncOp funcOp) {
  MLIRContext *context = funcOp.getContext();
  RewritePatternSet patterns(context);
  patterns.add<LinalgRewritePattern<LoopType>>(context);
  memref::DimOp::getCanonicalizationPatterns(patterns, context);
  tensor::DimOp::getCanonicalizationPatterns(patterns, context);
  AffineApplyOp::getCanonicalizationPatterns(patterns, context);
  patterns.add<FoldAffineOp>(context);
  // Just apply the patterns greedily.
  (void)applyPatternsAndFoldGreedily(funcOp, std::move(patterns));
}

struct LowerToAffineLoops
    : public LinalgLowerToAffineLoopsBase<LowerToAffineLoops> {
  void getDependentDialects(DialectRegistry &registry) const override {
    registry.insert<memref::MemRefDialect>();
  }
  void runOnOperation() override {
    lowerLinalgToLoopsImpl<AffineForOp>(getOperation());
  }
};

struct LowerToLoops : public LinalgLowerToLoopsBase<LowerToLoops> {
  void getDependentDialects(DialectRegistry &registry) const override {
    registry.insert<memref::MemRefDialect, scf::SCFDialect>();
  }
  void runOnOperation() override {
    lowerLinalgToLoopsImpl<scf::ForOp>(getOperation());
  }
};

struct LowerToParallelLoops
    : public LinalgLowerToParallelLoopsBase<LowerToParallelLoops> {
  void runOnOperation() override {
    lowerLinalgToLoopsImpl<scf::ParallelOp>(getOperation());
  }
};

} // namespace

std::unique_ptr<OperationPass<func::FuncOp>>
mlir::createConvertLinalgToLoopsPass() {
  return std::make_unique<LowerToLoops>();
}

std::unique_ptr<OperationPass<func::FuncOp>>
mlir::createConvertLinalgToParallelLoopsPass() {
  return std::make_unique<LowerToParallelLoops>();
}

std::unique_ptr<OperationPass<func::FuncOp>>
mlir::createConvertLinalgToAffineLoopsPass() {
  return std::make_unique<LowerToAffineLoops>();
}

/// Emits a loop nest of `affine.for` with the proper body for `linalgOp`.
FailureOr<LinalgLoops>
mlir::linalg::linalgOpToAffineLoops(PatternRewriter &rewriter,
                                    LinalgOp linalgOp) {
  return linalgOpToLoopsImpl<AffineForOp>(rewriter, linalgOp);
}

/// Emits a loop nest of `scf.for` with the proper body for `linalgOp`.
FailureOr<LinalgLoops> mlir::linalg::linalgOpToLoops(PatternRewriter &rewriter,
                                                     LinalgOp linalgOp) {
  return linalgOpToLoopsImpl<scf::ForOp>(rewriter, linalgOp);
}

/// Emits a loop nest of `scf.parallel` with the proper body for `linalgOp`.
FailureOr<LinalgLoops>
mlir::linalg::linalgOpToParallelLoops(PatternRewriter &rewriter,
                                      LinalgOp linalgOp) {
  return linalgOpToLoopsImpl<scf::ParallelOp>(rewriter, linalgOp);
}