//===- Loops.cpp - conversion from Linalg named and generic ops to loops --===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//

#include "PassDetail.h"
#include "mlir/Dialect/Linalg/IR/LinalgOps.h"
#include "mlir/Dialect/Linalg/IR/LinalgTypes.h"
#include "mlir/Dialect/Linalg/Passes.h"
#include "mlir/Dialect/Linalg/Transforms/Transforms.h"
#include "mlir/Dialect/Linalg/Utils/Utils.h"
#include "mlir/Dialect/SCF/Transforms.h"
#include "mlir/Dialect/StandardOps/Utils/Utils.h"
#include "mlir/IR/AffineExpr.h"
#include "mlir/IR/AffineMap.h"
#include "mlir/IR/BlockAndValueMapping.h"
#include "mlir/Support/LLVM.h"
#include "mlir/Transforms/DialectConversion.h"
#include "mlir/Transforms/FoldUtils.h"
#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
#include "llvm/ADT/TypeSwitch.h"

using namespace mlir;
using namespace mlir::linalg;

static SmallVector<Value> makeCanonicalAffineApplies(OpBuilder &b, Location loc,
                                                     AffineMap map,
                                                     ArrayRef<Value> vals) {
  if (map.isEmpty())
    return {};

  assert(map.getNumInputs() == vals.size());
  SmallVector<Value> res;
  res.reserve(map.getNumResults());
  auto dims = map.getNumDims();
  for (auto e : map.getResults()) {
    auto exprMap = AffineMap::get(dims, map.getNumSymbols(), e);
    SmallVector<Value> operands(vals.begin(), vals.end());
    canonicalizeMapAndOperands(&exprMap, &operands);
    res.push_back(b.create<AffineApplyOp>(loc, exprMap, operands));
  }
  return res;
}

template <typename LoadOpTy, typename StoreOpTy, typename OpType>
static void inlineRegionAndEmitStore(OpBuilder &b, Location loc, OpType op,
                                     ArrayRef<Value> indexedValues,
                                     ArrayRef<SmallVector<Value>> indexing,
                                     ArrayRef<Value> outputBuffers) {
  auto &block = op->getRegion(0).front();
  BlockAndValueMapping map;
  map.map(block.getArguments(), indexedValues);
  for (auto &op : block.without_terminator()) {
    auto *newOp = b.clone(op, map);
    map.map(op.getResults(), newOp->getResults());
  }

  Operation *terminator = block.getTerminator();
  for (OpOperand &operand : terminator->getOpOperands()) {
    Value toStore = map.lookupOrDefault(operand.get());
    b.create<StoreOpTy>(loc, toStore, outputBuffers[operand.getOperandNumber()],
                        indexing[operand.getOperandNumber()]);
  }
}

// Returns a pair that contains input indices and output indices of a
// SingleInputPoolingOp `op`.
struct InputAndOutputIndices {
  SmallVector<Value> inputs;
  SmallVector<Value> outputs;
};
template <typename SingleInputPoolingOp>
static InputAndOutputIndices
getInputAndOutputIndices(OpBuilder &b, Location loc, ArrayRef<Value> allIvs,
                         SingleInputPoolingOp op) {
  auto mapsRange = op.indexing_maps().template getAsRange<AffineMapAttr>();
  auto maps = llvm::to_vector<8>(
      llvm::map_range(mapsRange, [](AffineMapAttr a) { return a.getValue(); }));
  return InputAndOutputIndices{
      makeCanonicalAffineApplies(b, loc, maps[0], allIvs),
      makeCanonicalAffineApplies(b, loc, maps[2], allIvs)};
}
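
// For illustration (assumed maps, not taken from a specific test): for a
// pooling op whose indexing maps are, in order, (input, window, output), e.g.
//   (d0, d1, d2, d3) -> (d0 + d2, d1 + d3)   // input
//   (d0, d1, d2, d3) -> (d2, d3)             // window
//   (d0, d1, d2, d3) -> (d0, d1)             // output
// the helper above applies maps[0] and maps[2] to the enclosing induction
// variables, emitting one affine.apply per map result via
// makeCanonicalAffineApplies.
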
/// Emits the MLIR for the scalar part of the generic op by:
///   1. Emitting load ops for each input and output view in order. This is
///      achieved by applying the appropriate input or output map to the
///      enclosing induction variables.
///   2. Emitting a call to `op.fun()` that takes as arguments the scalars
///      from point 1. above.
///   3. Emitting store ops to store the results of 2. to the output
///      views.
///
/// An example output may resemble:
///
/// ```
///    scf.for %i = %c0 to %0 step %c1 {
///      scf.for %j = %c0 to %1 step %c1 {
///        scf.for %k = %c0 to %4 step %c1 {
///          %11 = load %arg0[%i, %j] :
///            memref<?x?xf32, stride_specification>
///          %12 = load %arg1[%i, %j, %k] :
///            memref<?x?x?xf32, stride_specification>
///          %13 = load %arg2[%i, %k, %j] :
///            memref<?x?x?xf32, stride_specification>
///          %14:2 = call @foo(%11, %12, %13) : (f32, f32, f32) -> (f32, f32)
///          store %14#0, %arg1[%i, %j, %k] :
///            memref<?x?x?xf32, stride_specification>
///          store %14#1, %arg2[%i, %k, %j] :
///            memref<?x?x?xf32, stride_specification>
///        }
///      }
///    }
/// ```
template <typename LoadOpTy, typename StoreOpTy>
static void emitScalarImplementation(OpBuilder &b, Location loc,
                                     ArrayRef<Value> allIvs,
                                     LinalgOp linalgOp) {
  assert(linalgOp.hasBufferSemantics() &&
         "expected linalg op with buffer semantics");
  SmallVector<Value> indexedValues;
  indexedValues.reserve(linalgOp.getNumInputsAndOutputs());

  auto allIvsPlusDims = SmallVector<Value>(allIvs.begin(), allIvs.end());

  // TODO: Avoid the loads if the corresponding argument of the
  // region has no uses.
  // 1.a. Emit load from input operand or for scalars access the operand
  // itself.
  for (OpOperand *inputOperand : linalgOp.getInputOperands()) {
    if (linalgOp.isScalar(inputOperand)) {
      indexedValues.push_back(inputOperand->get());
      continue;
    }
    auto indexing = makeCanonicalAffineApplies(
        b, loc, linalgOp.getTiedIndexingMap(inputOperand), allIvsPlusDims);
    indexedValues.push_back(
        b.create<LoadOpTy>(loc, inputOperand->get(), indexing));
  }
  // 1.b. Emit load from output views.
  for (OpOperand *outputOperand : linalgOp.getOutputOperands()) {
    SmallVector<Value> indexing = makeCanonicalAffineApplies(
        b, loc, linalgOp.getTiedIndexingMap(outputOperand), allIvsPlusDims);
    indexedValues.push_back(
        b.create<LoadOpTy>(loc, outputOperand->get(), indexing));
  }

  // TODO: When a region inliner exists, use it.
  // 2. Inline region, currently only works for a single basic block.
  // 3. Emit store.
  SmallVector<SmallVector<Value>, 8> indexing;
  SmallVector<Value> outputBuffers;
  for (OpOperand *outputOperand : linalgOp.getOutputBufferOperands()) {
    indexing.push_back(makeCanonicalAffineApplies(
        b, loc, linalgOp.getTiedIndexingMap(outputOperand), allIvsPlusDims));
    outputBuffers.push_back(outputOperand->get());
  }
  inlineRegionAndEmitStore<LoadOpTy, StoreOpTy>(b, loc, linalgOp, indexedValues,
                                                indexing, outputBuffers);
}

// Create a padded view into the given `input` tensor using the `indices`
// to access the tensor. `skipPadding` lists the dimensions for which no padding
// is needed, e.g. the non-spatial dimensions for convolutions.
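// For illustration (an assumed sketch, not from a test): for each padded
// spatial index %d the function emits
//   - a `cmpi slt`/`cmpi sge` pair comparing %d against 0 and the input size,
//     or-ed into a running out-of-bounds condition, and
//   - an `affine.max` clamping %d to be non-negative,
// and the loaded value is then `select`-ed against `padValue` using the
// accumulated condition.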
static Value getPaddedInput(OpBuilder &b, Location loc, Value input,
                            ArrayRef<Value> indices, ArrayRef<int> skipPadding,
                            Value padValue) {
  Value zeroIndex = b.create<ConstantIndexOp>(loc, 0);
  SmallVector<Value> conds;
  SmallVector<Value> clampedImIdx;
  for (auto iter : llvm::enumerate(indices)) {
    int idx = iter.index();
    auto dim = iter.value();
    if (llvm::is_contained(skipPadding, idx)) {
      clampedImIdx.push_back(dim);
      continue;
    }

    Value leftOutOfBound =
        b.create<CmpIOp>(loc, CmpIPredicate::slt, dim, zeroIndex);
    if (conds.empty())
      conds.push_back(leftOutOfBound);
    else
      conds.push_back(b.create<OrOp>(loc, conds.back(), leftOutOfBound));
    Value rightBound = createOrFoldDimOp(b, loc, input, idx);
    Value rightOutOfBound =
        b.create<CmpIOp>(loc, CmpIPredicate::sge, dim, rightBound);
    conds.push_back(b.create<OrOp>(loc, conds.back(), rightOutOfBound));

    // When padding is involved, the indices will only be shifted to negative,
    // so having a max op is enough.
    MLIRContext *ctx = input.getContext();
    AffineExpr m = getAffineDimExpr(/*position=*/0, ctx),
               zero = getAffineConstantExpr(0, ctx);
    AffineMap maxMap =
        AffineMap::inferFromExprList(ArrayRef<ArrayRef<AffineExpr>>{{m, zero}})
            .front();
    clampedImIdx.push_back(b.create<AffineMaxOp>(loc, maxMap, ValueRange{dim}));
  }

  Value readInput = b.create<memref::LoadOp>(loc, input, clampedImIdx);
  if (conds.empty())
    return readInput;

  return b.create<SelectOp>(loc, conds.back(), padValue, readInput);
}

namespace {

/// The padding value for a given Op depends on the semantics of the Op.
/// The identity value for ConvOp is 0.
template <typename OpType> Attribute getPadValueAttr(Type type) {
  llvm_unreachable("Unexpected op type for getPadValueAttr");
  return {};
}

template <> Attribute getPadValueAttr<ConvOp>(Type type) {
  return OpBuilder(type.getContext()).getZeroAttr(type);
}

} // namespace

/// Returns true if `convOp` has a non-zero padding.
static bool hasPadding(ConvOp convOp) {
  for (unsigned i = 0, e = convOp.getNumSpatialDimensions(); i < e; ++i) {
    if (convOp.getLowPad(i) > 0 || convOp.getHighPad(i) > 0)
      return true;
  }
  return false;
}

template <typename LoadOpTy, typename StoreOpTy>
static void emitScalarImplementation(OpBuilder &b, Location loc,
                                     ArrayRef<Value> allIvs, ConvOp convOp) {
  assert(convOp.hasBufferSemantics() &&
         "expected linalg op with buffer semantics");
  auto mapsRange = convOp.indexing_maps().getAsRange<AffineMapAttr>();
  auto maps = llvm::to_vector<8>(
      llvm::map_range(mapsRange, [](AffineMapAttr a) { return a.getValue(); }));
  SmallVector<Value> fIdx(makeCanonicalAffineApplies(b, loc, maps[0], allIvs));
  SmallVector<Value> imIdx(makeCanonicalAffineApplies(b, loc, maps[1], allIvs));
  SmallVector<Value> oIdx(makeCanonicalAffineApplies(b, loc, maps[2], allIvs));

  Value filter = convOp.filter(), output = convOp.output();

  // Emit scalar form. A padded conv involves an affine.max in the memory
  // access, which is not allowed by affine.load, so the padded input access
  // falls back to memref.load when there is non-zero padding.
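  // In both branches the generated scalar update is the usual convolution
  // accumulation (sketch with assumed names, following the example style used
  // earlier in this file):
  //   %f = load %filter[%fIdx...]
  //   %i = <padded or plain> load of %input[%imIdx...]
  //   %o = load %output[%oIdx...]
  //   store %o + %f * %i, %output[%oIdx...]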
  if (hasPadding(convOp)) {
    Type type = convOp.input().getType().cast<MemRefType>().getElementType();
    Value padValue =
        b.create<ConstantOp>(loc, type, getPadValueAttr<ConvOp>(type));
    Value paddedInput =
        getPaddedInput(b, loc, convOp.input(), imIdx,
                       /* Only need to pad the window dimensions */
                       {0, static_cast<int>(imIdx.size()) - 1}, padValue);
    Value filterVal = b.create<LoadOpTy>(loc, filter, fIdx);
    Value mulVal = ArithBuilder(b, loc).mul(filterVal, paddedInput);
    Value outputVal = b.create<LoadOpTy>(loc, output, oIdx);
    Value addVal = ArithBuilder(b, loc).add(mulVal, outputVal);
    b.create<StoreOpTy>(loc, addVal, output, oIdx);
  } else {
    Value inputVal = b.create<LoadOpTy>(loc, convOp.input(), imIdx);
    Value filterVal = b.create<LoadOpTy>(loc, filter, fIdx);
    Value mulVal = ArithBuilder(b, loc).mul(filterVal, inputVal);
    Value outputVal = b.create<LoadOpTy>(loc, output, oIdx);
    Value addVal = ArithBuilder(b, loc).add(mulVal, outputVal);
    b.create<StoreOpTy>(loc, addVal, output, oIdx);
  }
}

/// Replace the index operations in the body of the loop nest by the matching
/// induction variables.
static void replaceIndexOpsByInductionVariables(LinalgOp linalgOp,
                                                PatternRewriter &rewriter,
                                                ArrayRef<Operation *> loopOps) {
  // Extract the induction variables of the loop nest from outer to inner.
  SmallVector<Value> allIvs;
  for (Operation *loopOp : loopOps) {
    llvm::TypeSwitch<Operation *>(loopOp)
        .Case([&](scf::ParallelOp parallelOp) {
          allIvs.append(parallelOp.getInductionVars().begin(),
                        parallelOp.getInductionVars().end());
        })
        .Case([&](scf::ForOp forOp) {
          allIvs.push_back(forOp.getInductionVar());
        })
        .Case([&](AffineForOp affineForOp) {
          allIvs.push_back(affineForOp.getInductionVar());
        })
        .Default([&](Operation *op) { assert(false && "unexpected op"); });
  }
  assert(linalgOp.getNumLoops() == allIvs.size() &&
         "expected the number of loops and induction variables to match");
  // Replace the index operations in the body of the innermost loop op.
  if (!loopOps.empty()) {
    LoopLikeOpInterface loopOp = loopOps.back();
    for (IndexOp indexOp :
         llvm::make_early_inc_range(loopOp.getLoopBody().getOps<IndexOp>()))
      rewriter.replaceOp(indexOp, allIvs[indexOp.dim()]);
  }
}

template <typename LoopTy>
static Optional<LinalgLoops> linalgOpToLoopsImpl(PatternRewriter &rewriter,
                                                 LinalgOp linalgOp) {
  using LoadOpTy =
      typename std::conditional<std::is_same<LoopTy, AffineForOp>::value,
                                AffineLoadOp, memref::LoadOp>::type;
  using StoreOpTy =
      typename std::conditional<std::is_same<LoopTy, AffineForOp>::value,
                                AffineStoreOp, memref::StoreOp>::type;

  // The flattened loopToOperandRangesMaps is expected to be an invertible
  // permutation map (which is asserted in the inverse calculation).
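  // Below, GenerateLoopNest creates the loop nest and the body builder emits
  // the scalar implementation for every iteration; the induction variables of
  // all generated loops are collected in `allIvs` (a parallel loop contributes
  // several IVs at once).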
  assert(linalgOp.hasBufferSemantics() &&
         "expected linalg op with buffer semantics");

  auto loopRanges = linalgOp.createLoopRanges(rewriter, linalgOp.getLoc());
  auto iteratorTypes = llvm::to_vector<4>(linalgOp.iterator_types().getValue());

  SmallVector<Value> allIvs;
  GenerateLoopNest<LoopTy>::doit(
      rewriter, linalgOp.getLoc(), loopRanges, linalgOp, iteratorTypes,
      [&](OpBuilder &b, Location loc, ValueRange ivs,
          ValueRange operandValuesToUse) -> scf::ValueVector {
        assert(operandValuesToUse == linalgOp->getOperands() &&
               "expect operands are captured and not passed by loop argument");
        allIvs.append(ivs.begin(), ivs.end());
        llvm::TypeSwitch<Operation *>(linalgOp)
            .Case<ConvOp, LinalgOp>([&](auto op) {
              emitScalarImplementation<LoadOpTy, StoreOpTy>(b, loc, allIvs, op);
            })
            .Default([&](Operation *op) { assert(false && "unexpected op"); });
        return scf::ValueVector{};
      });
  // Number of loop ops might be different from the number of ivs since some
  // loops like affine.parallel and scf.parallel have multiple ivs.
  SetVector<Operation *> loopSet;
  for (Value iv : allIvs) {
    if (!iv)
      return {};
    // The induction variable is a block argument of the entry block of the
    // loop operation.
    BlockArgument ivVal = iv.dyn_cast<BlockArgument>();
    if (!ivVal)
      return {};
    loopSet.insert(ivVal.getOwner()->getParentOp());
  }
  LinalgLoops loops(loopSet.begin(), loopSet.end());
  // Replace all index operations in the loop body.
  replaceIndexOpsByInductionVariables(linalgOp, rewriter, loops);
  return loops;
}

namespace {
template <typename LoopType>
class LinalgRewritePattern : public RewritePattern {
public:
  LinalgRewritePattern(MLIRContext *context)
      : RewritePattern(MatchAnyOpTypeTag(), /*benefit=*/1, context) {}

  LogicalResult matchAndRewrite(Operation *op,
                                PatternRewriter &rewriter) const override {
    auto linalgOp = dyn_cast<LinalgOp>(op);
    if (!linalgOp)
      return failure();
    if (!linalgOpToLoopsImpl<LoopType>(rewriter, linalgOp))
      return failure();
    rewriter.eraseOp(op);
    return success();
  }
};

/// Converts tiled_loop to SCF loop nests. All parallel dimensions are collected
/// into an scf.parallel loop and all sequential dimensions result in a nested
/// scf.for loop nest. The pattern assumes that a tiled loop with iterator_types
/// ["reduction", "parallel", "reduction"] can be reordered; this holds for the
/// tiling that is currently supported by Linalg.
struct TiledLoopToSCFPattern : public OpRewritePattern<TiledLoopOp> {
  using OpRewritePattern<TiledLoopOp>::OpRewritePattern;

  LogicalResult matchAndRewrite(TiledLoopOp tiledLoop,
                                PatternRewriter &rewriter) const override {
    // Fail conversion if the `tiled_loop` has not been bufferized.
    if (!tiledLoop.hasBufferSemantics())
      return failure();

    // Collect loop control parameters for parallel and sequential dimensions.
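    // For illustration (an assumed example): with iterator_types
    // ["reduction", "parallel", "reduction"], dimensions 0 and 2 end up in the
    // seq* vectors (lowered to scf.for) and dimension 1 ends up in the par*
    // vectors (lowered to a single scf.parallel).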
    SmallVector<Value, 3> seqLBs, seqUBs, seqSteps, seqIVs;
    SmallVector<Value, 3> parLBs, parUBs, parSteps, parIVs;
    for (auto en : llvm::enumerate(
             llvm::zip(tiledLoop.lowerBound(), tiledLoop.upperBound(),
                       tiledLoop.step(), tiledLoop.getInductionVars()))) {
      Value lb, ub, step, iv;
      std::tie(lb, ub, step, iv) = en.value();
      if (tiledLoop.isParallelDimension(en.index())) {
        parLBs.push_back(lb);
        parUBs.push_back(ub);
        parSteps.push_back(step);
        parIVs.push_back(iv);
      } else {
        seqLBs.push_back(lb);
        seqUBs.push_back(ub);
        seqSteps.push_back(step);
        seqIVs.push_back(iv);
      }
    }

    Location loc = tiledLoop.getLoc();
    auto generateForLoopNestAndCloneBody = [&](OpBuilder &builder, Location loc,
                                               ValueRange ivs) {
      BlockAndValueMapping bvm;
      bvm.map(parIVs, ivs);
      bvm.map(tiledLoop.getRegionInputArgs(), tiledLoop.inputs());
      bvm.map(tiledLoop.getRegionOutputArgs(), tiledLoop.outputs());

      // If not all dimensions of the tiled loop are parallel, an scf.for loop
      // nest is generated.
      if (!seqIVs.empty()) {
        scf::LoopNest nest =
            scf::buildLoopNest(builder, loc, seqLBs, seqUBs, seqSteps,
                               [&](OpBuilder &builder, Location loc,
                                   ValueRange ivs) { bvm.map(seqIVs, ivs); });
        builder.setInsertionPointToStart(nest.loops.back().getBody());
      }
      for (auto &op : tiledLoop.getBody()->without_terminator())
        builder.clone(op, bvm);
    };

    if (parIVs.empty())
      generateForLoopNestAndCloneBody(rewriter, loc, llvm::None);
    else
      rewriter.create<scf::ParallelOp>(loc, parLBs, parUBs, parSteps,
                                       generateForLoopNestAndCloneBody);
    rewriter.eraseOp(tiledLoop);
    return success();
  }
};

/// Local folding pattern for AffineApplyOp that we can apply greedily.
/// This replaces AffineApplyOp by the proper value in cases where the
/// associated map is trivial.
/// A trivial map here is defined as a map with a single result and either:
///   1. Zero operands + returns a single AffineConstantExpr
///   2. One operand + returns a single AffineDimExpr
///   3. One operand + returns a single AffineSymbolExpr
///
/// In the first case, the AffineApplyOp is replaced by a new constant. In the
/// other cases, it is replaced by its unique operand.
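///
/// For illustration (assumed snippets):
///   %c = affine.apply affine_map<() -> (42)>()          // -> constant 42
///   %x = affine.apply affine_map<(d0) -> (d0)>(%i)      // -> replaced by %i
///   %y = affine.apply affine_map<()[s0] -> (s0)>()[%n]  // -> replaced by %n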
struct FoldAffineOp : public RewritePattern {
  FoldAffineOp(MLIRContext *context)
      : RewritePattern(AffineApplyOp::getOperationName(), 0, context) {}

  LogicalResult matchAndRewrite(Operation *op,
                                PatternRewriter &rewriter) const override {
    AffineApplyOp affineApplyOp = cast<AffineApplyOp>(op);
    auto map = affineApplyOp.getAffineMap();
    if (map.getNumResults() != 1 || map.getNumInputs() > 1)
      return failure();

    AffineExpr expr = map.getResult(0);
    if (map.getNumInputs() == 0) {
      if (auto val = expr.dyn_cast<AffineConstantExpr>()) {
        rewriter.replaceOpWithNewOp<ConstantIndexOp>(op, val.getValue());
        return success();
      }
      return failure();
    }
    if (expr.dyn_cast<AffineDimExpr>() || expr.dyn_cast<AffineSymbolExpr>()) {
      rewriter.replaceOp(op, op->getOperand(0));
      return success();
    }
    return failure();
  }
};

template <typename LoopType>
static void lowerLinalgToLoopsImpl(FuncOp funcOp) {
  MLIRContext *context = funcOp.getContext();
  RewritePatternSet patterns(context);
  patterns.add<LinalgRewritePattern<LoopType>>(context);
  memref::DimOp::getCanonicalizationPatterns(patterns, context);
  tensor::DimOp::getCanonicalizationPatterns(patterns, context);
  AffineApplyOp::getCanonicalizationPatterns(patterns, context);
  patterns.add<FoldAffineOp>(context);
  // Just apply the patterns greedily.
  (void)applyPatternsAndFoldGreedily(funcOp, std::move(patterns));
}

struct LowerToAffineLoops
    : public LinalgLowerToAffineLoopsBase<LowerToAffineLoops> {
  void getDependentDialects(DialectRegistry &registry) const override {
    registry.insert<memref::MemRefDialect>();
  }
  void runOnFunction() override {
    lowerLinalgToLoopsImpl<AffineForOp>(getFunction());
  }
};

struct LowerToLoops : public LinalgLowerToLoopsBase<LowerToLoops> {
  void getDependentDialects(DialectRegistry &registry) const override {
    registry.insert<memref::MemRefDialect, scf::SCFDialect>();
  }
  void runOnFunction() override {
    lowerLinalgToLoopsImpl<scf::ForOp>(getFunction());
  }
};

struct LowerToParallelLoops
    : public LinalgLowerToParallelLoopsBase<LowerToParallelLoops> {
  void runOnFunction() override {
    lowerLinalgToLoopsImpl<scf::ParallelOp>(getFunction());
  }
};

struct LowerTiledLoopsToSCF
    : public LinalgLowerTiledLoopsToSCFBase<LowerTiledLoopsToSCF> {
  void runOnFunction() override {
    MLIRContext *context = &getContext();
    RewritePatternSet patterns(context);
    populateTiledLoopToSCFPattern(patterns);
    (void)applyPatternsAndFoldGreedily(getFunction(), std::move(patterns));
  }
};
} // namespace

/// Rewrite a TiledLoopOp with bounds/step that potentially do not divide evenly
/// into two TiledLoopOps: one where the step divides the iteration space
/// evenly, followed by another one for the last (partial) iteration (if any).
/// This function only rewrites the `idx`-th loop of the loop nest represented
/// by the TiledLoopOp. To peel the entire loop nest, this function must be
/// called multiple times.
///
/// This function rewrites the given TiledLoopOp in-place and creates a new
/// TiledLoopOp for the last iteration. It replaces all uses of the original
/// TiledLoopOp with the results of the newly generated one.
///
/// The newly generated TiledLoopOp is returned via `result`. The boundary
/// at which the loop is split (new upper bound) is returned via `splitBound`.
/// The return value indicates whether the TiledLoopOp was rewritten or not.
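///
/// Worked example (assumed values, not from a test): with %lb = 0, %ub = 10
/// and %step = 4, the split bound is 10 - (10 - 0) mod 4 = 8, so the main loop
/// iterates over [0, 8) with full steps and the remainder loop over [8, 10).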
static LogicalResult peelTiledLoop(RewriterBase &b, TiledLoopOp loopOp,
                                   int64_t idx, TiledLoopOp &result,
                                   Value &splitBound) {
  Value lb = loopOp.lowerBound()[idx], ub = loopOp.upperBound()[idx],
        step = loopOp.step()[idx];
  auto ubInt = getConstantIntValue(ub);

  auto loc = loopOp.getLoc();
  AffineExpr exprLb, exprUb, exprStep;
  bindSymbols(b.getContext(), exprLb, exprUb, exprStep);
  // New upper bound: %ub - (%ub - %lb) mod %step
  auto modMap = AffineMap::get(0, 3, {exprUb - ((exprUb - exprLb) % exprStep)});
  SmallVector<Value> operands{lb, ub, step};
  mlir::canonicalizeMapAndOperands(&modMap, &operands);
  modMap = mlir::simplifyAffineMap(modMap);
  RewriterBase::InsertionGuard guard(b);
  b.setInsertionPoint(loopOp);
  splitBound = b.createOrFold<AffineApplyOp>(loc, modMap, operands);
  // No specialization necessary if step already divides upper bound evenly.
  if (splitBound == ub || (ubInt && ubInt == getConstantIntValue(splitBound)))
    return failure();

  // Create remainder loop.
  b.setInsertionPointAfter(loopOp);
  auto remainderLoop = cast<TiledLoopOp>(b.clone(*loopOp.getOperation()));
  loopOp.replaceAllUsesWith(remainderLoop->getResults());
  // Outputs: Take tensors from main loop's results. Take memrefs from main
  // loop's outputs.
  SmallVector<Value> remainderOutputs;
  for (unsigned o = 0, t = 0; o < loopOp.getNumOutputs(); ++o) {
    remainderOutputs.push_back(loopOp.outputs()[o].getType().isa<MemRefType>()
                                   ? loopOp.outputs()[o]
                                   : loopOp->getResult(t++));
  }
  remainderLoop.outputsMutable().assign(remainderOutputs);

  // Set new loop bounds.
  b.updateRootInPlace(loopOp, [&]() {
    SmallVector<Value> ubs = loopOp.upperBound();
    ubs[idx] = splitBound;
    loopOp.upperBoundMutable().assign(ubs);
  });
  SmallVector<Value> lbs = remainderLoop.lowerBound();
  lbs[idx] = splitBound;
  remainderLoop.lowerBoundMutable().assign(lbs);

  result = remainderLoop;
  return success();
}

template <typename OpTy, bool IsMin>
static void
rewriteAffineOpAfterPeeling(RewriterBase &rewriter, TiledLoopOp mainLoop,
                            TiledLoopOp remainderLoop, Value mainIv,
                            Value remainderIv, Value ub, Value step) {
  mainLoop.walk([&](OpTy affineOp) {
    AffineMap map = affineOp.getAffineMap();
    (void)scf::rewritePeeledMinMaxOp(rewriter, affineOp, map,
                                     affineOp.operands(), IsMin, mainIv, ub,
                                     step, /*insideLoop=*/true);
  });
  remainderLoop.walk([&](OpTy affineOp) {
    AffineMap map = affineOp.getAffineMap();
    (void)scf::rewritePeeledMinMaxOp(rewriter, affineOp, map,
                                     affineOp.operands(), IsMin, remainderIv,
                                     ub, step, /*insideLoop=*/false);
  });
}

LogicalResult mlir::linalg::peelAndCanonicalizeTiledLoop(RewriterBase &rewriter,
                                                         TiledLoopOp loopOp,
                                                         int64_t idx,
                                                         TiledLoopOp &result) {
  int64_t numLoops = loopOp.iterator_types().size();
  if (idx < 0 || numLoops <= idx)
    return failure();

  Value ub = loopOp.upperBound()[idx];
  TiledLoopOp remainderLoop;
  Value splitBound;
  if (failed(peelTiledLoop(rewriter, loopOp, idx, remainderLoop, splitBound)))
    return failure();

  // Rewrite affine.min and affine.max ops.
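  // For illustration (an assumed example): a tile-size computation such as
  // `affine.min(%step, %ub - %iv)` inside the peeled main loop can be folded
  // to %step, because the main loop now executes only full steps; the
  // corresponding op in the remainder loop is rewritten in terms of the
  // leftover iteration count.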
  Value mainIv = loopOp.getInductionVars()[idx], step = loopOp.step()[idx],
        remainderIv = remainderLoop.getInductionVars()[idx];

  rewriteAffineOpAfterPeeling<AffineMinOp, /*IsMin=*/true>(
      rewriter, loopOp, remainderLoop, mainIv, remainderIv, ub, step);
  rewriteAffineOpAfterPeeling<AffineMaxOp, /*IsMin=*/false>(
      rewriter, loopOp, remainderLoop, mainIv, remainderIv, ub, step);

  result = remainderLoop;
  return success();
}

void mlir::linalg::populateTiledLoopToSCFPattern(RewritePatternSet &patterns) {
  patterns.add<TiledLoopToSCFPattern>(patterns.getContext());
}

std::unique_ptr<OperationPass<FuncOp>>
mlir::createConvertLinalgTiledLoopsToSCFPass() {
  return std::make_unique<LowerTiledLoopsToSCF>();
}

std::unique_ptr<OperationPass<FuncOp>> mlir::createConvertLinalgToLoopsPass() {
  return std::make_unique<LowerToLoops>();
}

std::unique_ptr<OperationPass<FuncOp>>
mlir::createConvertLinalgToParallelLoopsPass() {
  return std::make_unique<LowerToParallelLoops>();
}

std::unique_ptr<OperationPass<FuncOp>>
mlir::createConvertLinalgToAffineLoopsPass() {
  return std::make_unique<LowerToAffineLoops>();
}

/// Emits a loop nest of `affine.for` with the proper body for `linalgOp`.
Optional<LinalgLoops>
mlir::linalg::linalgOpToAffineLoops(PatternRewriter &rewriter,
                                    LinalgOp linalgOp) {
  return linalgOpToLoopsImpl<AffineForOp>(rewriter, linalgOp);
}

/// Emits a loop nest of `scf.for` with the proper body for `linalgOp`.
Optional<LinalgLoops> mlir::linalg::linalgOpToLoops(PatternRewriter &rewriter,
                                                    LinalgOp linalgOp) {
  return linalgOpToLoopsImpl<scf::ForOp>(rewriter, linalgOp);
}

/// Emits a loop nest of `scf.parallel` with the proper body for `linalgOp`.
Optional<LinalgLoops>
mlir::linalg::linalgOpToParallelLoops(PatternRewriter &rewriter,
                                      LinalgOp linalgOp) {
  return linalgOpToLoopsImpl<scf::ParallelOp>(rewriter, linalgOp);
}