//===- Loops.cpp - conversion from Linalg named and generic ops to loops -===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//

#include "PassDetail.h"
#include "mlir/Dialect/Linalg/IR/LinalgOps.h"
#include "mlir/Dialect/Linalg/IR/LinalgTypes.h"
#include "mlir/Dialect/Linalg/Passes.h"
#include "mlir/Dialect/Linalg/Transforms/Transforms.h"
#include "mlir/Dialect/Linalg/Utils/Utils.h"
#include "mlir/Dialect/StandardOps/Utils/Utils.h"
#include "mlir/IR/AffineExpr.h"
#include "mlir/IR/AffineMap.h"
#include "mlir/IR/BlockAndValueMapping.h"
#include "mlir/Support/LLVM.h"
#include "mlir/Transforms/DialectConversion.h"
#include "mlir/Transforms/FoldUtils.h"
#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
#include "llvm/ADT/TypeSwitch.h"

using namespace mlir;
using namespace mlir::linalg;

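/// Emits one `affine.apply` per result expression of `map`, using `vals` as
/// operands, and returns the resulting index values. Each single-result submap
/// is canonicalized together with its operands before the apply is created.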
static SmallVector<Value> makeCanonicalAffineApplies(OpBuilder &b, Location loc,
                                                     AffineMap map,
                                                     ArrayRef<Value> vals) {
  if (map.isEmpty())
    return {};

  assert(map.getNumInputs() == vals.size());
  SmallVector<Value> res;
  res.reserve(map.getNumResults());
  auto dims = map.getNumDims();
  for (auto e : map.getResults()) {
    auto exprMap = AffineMap::get(dims, map.getNumSymbols(), e);
    SmallVector<Value> operands(vals.begin(), vals.end());
    canonicalizeMapAndOperands(&exprMap, &operands);
    res.push_back(b.create<AffineApplyOp>(loc, exprMap, operands));
  }
  return res;
}

/// Inlines the (single-block) region of `op` at the current insertion point,
/// mapping its block arguments to `indexedValues`, and emits one store per
/// terminator operand into the corresponding output buffer.
template <typename LoadOpTy, typename StoreOpTy, typename OpType>
static void inlineRegionAndEmitStore(OpBuilder &b, Location loc, OpType op,
                                     ArrayRef<Value> indexedValues,
                                     ArrayRef<SmallVector<Value>> indexing,
                                     ArrayRef<Value> outputBuffers) {
  auto &block = op->getRegion(0).front();
  BlockAndValueMapping map;
  map.map(block.getArguments(), indexedValues);
  for (auto &op : block.without_terminator()) {
    auto *newOp = b.clone(op, map);
    map.map(op.getResults(), newOp->getResults());
  }

  Operation *terminator = block.getTerminator();
  for (OpOperand &operand : terminator->getOpOperands()) {
    Value toStore = map.lookupOrDefault(operand.get());
    b.create<StoreOpTy>(loc, toStore, outputBuffers[operand.getOperandNumber()],
                        indexing[operand.getOperandNumber()]);
  }
}

// Returns a struct that contains the input indices and output indices of a
// SingleInputPoolingOp `op`.
struct InputAndOutputIndices {
  SmallVector<Value> inputs;
  SmallVector<Value> outputs;
};
template <typename SingleInputPoolingOp>
static InputAndOutputIndices
getInputAndOutputIndices(OpBuilder &b, Location loc, ArrayRef<Value> allIvs,
                         SingleInputPoolingOp op) {
  auto mapsRange = op.indexing_maps().template getAsRange<AffineMapAttr>();
  auto maps = llvm::to_vector<8>(
      llvm::map_range(mapsRange, [](AffineMapAttr a) { return a.getValue(); }));
  return InputAndOutputIndices{
      makeCanonicalAffineApplies(b, loc, maps[0], allIvs),
      makeCanonicalAffineApplies(b, loc, maps[2], allIvs)};
}

/// Emits the MLIR for the scalar part of the generic op by:
///   1. Emitting load ops for each input and output view in order. This is
///      achieved by applying the appropriate input or output map to the
///      enclosing induction variables.
///   2. Emitting a call to `op.fun()` that takes as arguments the scalars
///      from point 1. above.
///   3. Emitting store ops to store the results of 2. to the output views.
///
/// An example output may resemble:
///
/// ```
///    scf.for %i = %c0 to %0 step %c1 {
///      scf.for %j = %c0 to %1 step %c1 {
///        scf.for %k = %c0 to %4 step %c1 {
///          %11 = load %arg0[%i, %j] :
///            memref<?x?xf32, stride_specification>
///          %12 = load %arg1[%i, %j, %k] :
///            memref<?x?x?xf32, stride_specification>
///          %13 = load %arg2[%i, %k, %j] :
///            memref<?x?x?xf32, stride_specification>
///          %14:2 = call @foo(%11, %12, %13) : (f32, f32, f32) -> (f32, f32)
///          store %14#0, %arg1[%i, %j, %k] :
///            memref<?x?x?xf32, stride_specification>
///          store %14#1, %arg2[%i, %k, %j] :
///            memref<?x?x?xf32, stride_specification>
///        }
///      }
///    }
/// ```
template <typename LoadOpTy, typename StoreOpTy>
static void emitScalarImplementation(OpBuilder &b, Location loc,
                                     ArrayRef<Value> allIvs,
                                     LinalgOp linalgOp) {
  assert(linalgOp.hasBufferSemantics() &&
         "expected linalg op with buffer semantics");
  unsigned nInputs = linalgOp.getNumInputs();
  unsigned nOutputs = linalgOp.getNumOutputs();
  SmallVector<Value> indexedValues;
  indexedValues.reserve(nInputs + nOutputs);

  auto allIvsPlusDims = SmallVector<Value>(allIvs.begin(), allIvs.end());

  // TODO: Avoid the loads if the corresponding argument of the
  // region has no uses.
  // 1.a. Emit load from input views.
  for (unsigned i = 0; i < nInputs; ++i) {
    auto indexing = makeCanonicalAffineApplies(
        b, loc, linalgOp.getInputIndexingMap(i), allIvsPlusDims);
    indexedValues.push_back(
        b.create<LoadOpTy>(loc, linalgOp.getInput(i), indexing));
  }
  // 1.b. Emit load from output views.
  for (unsigned i = 0; i < nOutputs; ++i) {
    auto indexing = makeCanonicalAffineApplies(
        b, loc, linalgOp.getOutputIndexingMap(i), allIvsPlusDims);
    indexedValues.push_back(
        b.create<LoadOpTy>(loc, linalgOp.getOutputBuffer(i), indexing));
  }

  // TODO: When a region inliner exists, use it.
  // 2. Inline region, currently only works for a single basic block.
  // 3. Emit store.
  SmallVector<SmallVector<Value>, 8> indexing;
  SmallVector<Value> outputBuffers;
  for (unsigned i = 0; i < nOutputs; ++i) {
    indexing.push_back(makeCanonicalAffineApplies(
        b, loc, linalgOp.getOutputIndexingMap(i), allIvsPlusDims));
    outputBuffers.push_back(linalgOp.getOutputBuffer(i));
  }
  inlineRegionAndEmitStore<LoadOpTy, StoreOpTy>(b, loc, linalgOp, indexedValues,
                                                indexing, outputBuffers);
}

// Creates a padded access into the given `input` tensor using `indices` to
// access the tensor. `skipPadding` lists the dimensions for which no padding
// is needed, e.g. the non-spatial dimensions for convolutions.
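// For a padded spatial index %i, the emitted access is roughly the following
// (a sketch of the generated IR, not verbatim output; SSA names and types are
// illustrative):
//   %lt  = cmpi slt, %i, %c0
//   %d   = memref.dim %input, %c0
//   %ge  = cmpi sge, %i, %d
//   %oob = or %lt, %ge
//   %pos = affine.max affine_map<(d0) -> (d0, 0)>(%i)
//   %v   = memref.load %input[..., %pos, ...]
//   %r   = select %oob, %padValue, %v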
static Value getPaddedInput(OpBuilder &b, Location loc, Value input,
                            ArrayRef<Value> indices, ArrayRef<int> skipPadding,
                            Value padValue) {
  Value zeroIndex = b.create<ConstantIndexOp>(loc, 0);
  SmallVector<Value> conds;
  SmallVector<Value> clampedImIdx;
  for (auto iter : llvm::enumerate(indices)) {
    int idx = iter.index();
    auto dim = iter.value();
    if (llvm::is_contained(skipPadding, idx)) {
      clampedImIdx.push_back(dim);
      continue;
    }

    Value leftOutOfBound =
        b.create<CmpIOp>(loc, CmpIPredicate::slt, dim, zeroIndex);
    if (conds.empty())
      conds.push_back(leftOutOfBound);
    else
      conds.push_back(b.create<OrOp>(loc, conds.back(), leftOutOfBound));
    Value rightBound = b.create<memref::DimOp>(loc, input, idx);
    Value rightOutOfBound =
        b.create<CmpIOp>(loc, CmpIPredicate::sge, dim, rightBound);
    conds.push_back(b.create<OrOp>(loc, conds.back(), rightOutOfBound));

    // When padding is involved, the indices will only be shifted to negative,
    // so having a max op is enough.
    MLIRContext *ctx = input.getContext();
    AffineExpr m = getAffineDimExpr(/*position=*/0, ctx),
               zero = getAffineConstantExpr(0, ctx);
    AffineMap maxMap =
        AffineMap::inferFromExprList(ArrayRef<ArrayRef<AffineExpr>>{{m, zero}})
            .front();
    clampedImIdx.push_back(b.create<AffineMaxOp>(loc, maxMap, ValueRange{dim}));
  }

  Value readInput = b.create<memref::LoadOp>(loc, input, clampedImIdx);
  if (conds.empty())
    return readInput;

  return b.create<SelectOp>(loc, conds.back(), padValue, readInput);
}

namespace {

/// The padding value for a given Op depends on the semantics of the Op.
/// The identity value for ConvOp and PoolingSumOp is 0, for PoolingMaxOp it is
/// -inf or minInt, and for PoolingMinOp it is inf or maxInt.
template <typename OpType> Attribute getPadValueAttr(Type type) {
  llvm_unreachable("Unexpected op type for getPadValueAttr");
  return {};
}

template <> Attribute getPadValueAttr<PoolingMaxOp>(Type type) {
  if (auto floatType = type.dyn_cast<FloatType>()) {
    return OpBuilder(type.getContext())
        .getFloatAttr(floatType, APFloat::getInf(floatType.getFloatSemantics(),
                                                 /*Negative=*/true));
  }
  if (auto intType = type.dyn_cast<IntegerType>()) {
    unsigned width = intType.getWidth();
    // The select instruction used to lower PoolingMax uses a signed
    // comparison; use a signed constant irrespective of the signedness of the
    // integer type.
    return OpBuilder(type.getContext())
        .getIntegerAttr(intType, APInt::getSignedMinValue(width));
  }
  llvm_unreachable("Unsupported data type for PoolingMaxOp");
  return {};
}

template <> Attribute getPadValueAttr<PoolingMinOp>(Type type) {
  if (auto floatType = type.dyn_cast<FloatType>()) {
    return OpBuilder(type.getContext())
        .getFloatAttr(floatType,
                      APFloat::getInf(floatType.getFloatSemantics()));
  }
  if (auto intType = type.dyn_cast<IntegerType>()) {
    unsigned width = intType.getWidth();
    // The select instruction used to lower PoolingMin uses a signed
    // comparison; use a signed constant irrespective of the signedness of the
    // integer type.
    return OpBuilder(type.getContext())
        .getIntegerAttr(intType, APInt::getSignedMaxValue(width));
  }
  llvm_unreachable("Unsupported data type for PoolingMinOp");
  return {};
}

template <> Attribute getPadValueAttr<PoolingSumOp>(Type type) {
  return OpBuilder(type.getContext()).getZeroAttr(type);
}

template <> Attribute getPadValueAttr<ConvOp>(Type type) {
  return OpBuilder(type.getContext()).getZeroAttr(type);
}

} // namespace

/// Returns true if `convOp` has non-zero padding.
static bool hasPadding(ConvOp convOp) {
  for (unsigned i = 0, e = convOp.getNumSpatialDimensions(); i < e; ++i) {
    if (convOp.getLowPad(i) > 0 || convOp.getHighPad(i) > 0)
      return true;
  }
  return false;
}

template <typename LoadOpTy, typename StoreOpTy>
static void emitScalarImplementation(OpBuilder &b, Location loc,
                                     ArrayRef<Value> allIvs, ConvOp convOp) {
  assert(convOp.hasBufferSemantics() &&
         "expected linalg op with buffer semantics");
  auto mapsRange = convOp.indexing_maps().getAsRange<AffineMapAttr>();
  auto maps = llvm::to_vector<8>(
      llvm::map_range(mapsRange, [](AffineMapAttr a) { return a.getValue(); }));
  SmallVector<Value> fIdx(makeCanonicalAffineApplies(b, loc, maps[0], allIvs));
  SmallVector<Value> imIdx(makeCanonicalAffineApplies(b, loc, maps[1], allIvs));
  SmallVector<Value> oIdx(makeCanonicalAffineApplies(b, loc, maps[2], allIvs));

  Value filter = convOp.filter(), output = convOp.output();

  // Emit scalar form. A padded conv involves an affine.max in the memory
  // access, which is not allowed by affine.load, so fall back to the
  // memref.load-based padded access when there is non-zero padding.
  if (hasPadding(convOp)) {
    Type type = convOp.input().getType().cast<MemRefType>().getElementType();
    Value padValue =
        b.create<ConstantOp>(loc, type, getPadValueAttr<ConvOp>(type));
    Value paddedInput =
        getPaddedInput(b, loc, convOp.input(), imIdx,
                       /* Only need to pad the window dimensions */
                       {0, static_cast<int>(imIdx.size()) - 1}, padValue);
    Value filterVal = b.create<LoadOpTy>(loc, filter, fIdx);
    Value mulVal = ArithBuilder(b, loc).mul(filterVal, paddedInput);
    Value outputVal = b.create<LoadOpTy>(loc, output, oIdx);
    Value addVal = ArithBuilder(b, loc).add(mulVal, outputVal);
    b.create<StoreOpTy>(loc, addVal, output, oIdx);
  } else {
    Value inputVal = b.create<LoadOpTy>(loc, convOp.input(), imIdx);
    Value filterVal = b.create<LoadOpTy>(loc, filter, fIdx);
    Value mulVal = ArithBuilder(b, loc).mul(filterVal, inputVal);
    Value outputVal = b.create<LoadOpTy>(loc, output, oIdx);
    Value addVal = ArithBuilder(b, loc).add(mulVal, outputVal);
    b.create<StoreOpTy>(loc, addVal, output, oIdx);
  }
}

/// Returns true if `poolingOp` has non-zero padding.
template <typename PoolingOp> static bool hasPadding(PoolingOp poolingOp) {
  for (unsigned i = 0, e = poolingOp.getNumWindowLoops(); i < e; ++i) {
    if (poolingOp.getLowPad(i) > 0 || poolingOp.getHighPad(i) > 0)
      return true;
  }
  return false;
}

/// Loads the pooling input value at `inputIndices`, inserting padding handling
/// when the pooling op has non-zero padding.
template <typename LoadOpTy, typename StoreOpTy, typename PoolingOp>
static Value getPoolingInput(OpBuilder &b, Location loc, PoolingOp op,
                             ArrayRef<Value> inputIndices) {
  if (hasPadding(op)) {
    Type type =
        op.input().getType().template cast<MemRefType>().getElementType();
    Value padValue =
        b.create<ConstantOp>(loc, type, getPadValueAttr<PoolingOp>(type));
    return getPaddedInput(b, loc, op.input(), inputIndices,
                          /*Pad every dimension*/ {}, padValue);
  }
  return b.create<LoadOpTy>(loc, op.input(), inputIndices);
}

/// Emits the scalar implementation for a min/max pooling op: loads the current
/// output value and the (possibly padded) input value, and stores the smaller
/// (PoolingMinOp) or larger (PoolingMaxOp) of the two.
template <typename LoadOpTy, typename StoreOpTy, typename OpType>
void emitPoolingMinMaxScalarImplementation(OpBuilder &b, Location loc,
                                           ArrayRef<Value> allIvs, OpType op) {
  InputAndOutputIndices indices = getInputAndOutputIndices(b, loc, allIvs, op);
  Value lhs = b.create<LoadOpTy>(loc, op.output(), indices.outputs);
  Value rhs = getPoolingInput<LoadOpTy, StoreOpTy>(b, loc, op, indices.inputs);
  Value value = llvm::TypeSwitch<Operation *, Value>(op)
                    .Case([&](PoolingMinOp poolingOp) {
                      return ArithBuilder(b, loc).select(
                          ArithBuilder(b, loc).slt(lhs, rhs), lhs, rhs);
                    })
                    .Case([&](PoolingMaxOp poolingOp) {
                      return ArithBuilder(b, loc).select(
                          ArithBuilder(b, loc).sgt(lhs, rhs), lhs, rhs);
                    })
                    .Default([&](auto) { return Value(); });
  b.create<StoreOpTy>(loc, value, op.output(), indices.outputs);
}

template <typename LoadOpTy, typename StoreOpTy>
static void emitScalarImplementation(OpBuilder &b, Location loc,
                                     ArrayRef<Value> allIvs, PoolingMaxOp op) {
  emitPoolingMinMaxScalarImplementation<LoadOpTy, StoreOpTy, PoolingMaxOp>(
      b, loc, allIvs, op);
}

template <typename LoadOpTy, typename StoreOpTy>
static void emitScalarImplementation(OpBuilder &b, Location loc,
                                     ArrayRef<Value> allIvs, PoolingMinOp op) {
  emitPoolingMinMaxScalarImplementation<LoadOpTy, StoreOpTy, PoolingMinOp>(
      b, loc, allIvs, op);
}

template <typename LoadOpTy, typename StoreOpTy>
static void emitScalarImplementation(OpBuilder &b, Location loc,
                                     ArrayRef<Value> allIvs, PoolingSumOp op) {
  auto indices = getInputAndOutputIndices(b, loc, allIvs, op);
  Value inputVal =
      getPoolingInput<LoadOpTy, StoreOpTy>(b, loc, op, indices.inputs);
  Value outputVal = b.create<LoadOpTy>(loc, op.output(), indices.outputs);
  Value added = ArithBuilder(b, loc).add(outputVal, inputVal);
  b.create<StoreOpTy>(loc, added, op.output(), indices.outputs);
}

/// Replaces the index operations in the body of the loop nest by the matching
/// induction variables.
static void replaceIndexOpsByInductionVariables(LinalgOp linalgOp,
                                                PatternRewriter &rewriter,
                                                ArrayRef<Operation *> loopOps) {
  // Extract the induction variables of the loop nest from outer to inner.
  SmallVector<Value> allIvs;
  for (Operation *loopOp : loopOps) {
    llvm::TypeSwitch<Operation *>(loopOp)
        .Case([&](scf::ParallelOp parallelOp) {
          allIvs.append(parallelOp.getInductionVars().begin(),
                        parallelOp.getInductionVars().end());
        })
        .Case([&](scf::ForOp forOp) {
          allIvs.push_back(forOp.getInductionVar());
        })
        .Case([&](AffineForOp affineForOp) {
          allIvs.push_back(affineForOp.getInductionVar());
        })
        .Default([&](Operation *op) { assert(false && "unexpected op"); });
  }
  assert(linalgOp.getNumLoops() == allIvs.size() &&
         "expected the number of loops and induction variables to match");
  // Replace the index operations in the body of the innermost loop op; that is
  // where the linalg op region has been inlined.
  if (!loopOps.empty()) {
    LoopLikeOpInterface loopOp = loopOps.back();
    for (IndexOp indexOp :
         llvm::make_early_inc_range(loopOp.getLoopBody().getOps<IndexOp>()))
      rewriter.replaceOp(indexOp, allIvs[indexOp.dim()]);
  }
}

template <typename LoopTy>
static Optional<LinalgLoops> linalgOpToLoopsImpl(PatternRewriter &rewriter,
                                                 LinalgOp linalgOp) {
  using LoadOpTy =
      typename std::conditional<std::is_same<LoopTy, AffineForOp>::value,
                                AffineLoadOp, memref::LoadOp>::type;
  using StoreOpTy =
      typename std::conditional<std::is_same<LoopTy, AffineForOp>::value,
                                AffineStoreOp, memref::StoreOp>::type;

  // Indexed_generic ops are not lowered directly; they are expected to be
  // canonicalized before being lowered to loops.
  if (isa<IndexedGenericOp>(linalgOp))
    return llvm::None;

  // The flattened loopToOperandRangesMaps is expected to be an invertible
  // permutation map (which is asserted in the inverse calculation).
  assert(linalgOp.hasBufferSemantics() &&
         "expected linalg op with buffer semantics");

  auto loopRanges = linalgOp.createLoopRanges(rewriter, linalgOp.getLoc());
  auto iteratorTypes = llvm::to_vector<4>(linalgOp.iterator_types().getValue());

  SmallVector<Value> allIvs;
  GenerateLoopNest<LoopTy>::doit(
      rewriter, linalgOp.getLoc(), loopRanges, linalgOp, iteratorTypes,
      [&](OpBuilder &b, Location loc, ValueRange ivs,
          ValueRange iterArgs) -> scf::ValueVector {
        assert(iterArgs.empty() && "unexpected iterArgs");
        allIvs.append(ivs.begin(), ivs.end());
        llvm::TypeSwitch<Operation *>(linalgOp)
            .Case<ConvOp, PoolingMaxOp, PoolingMinOp, PoolingSumOp, LinalgOp>(
                [&](auto op) {
                  emitScalarImplementation<LoadOpTy, StoreOpTy>(b, loc, allIvs,
                                                                op);
                })
            .Default([&](Operation *op) { assert(false && "unexpected op"); });
        return scf::ValueVector{};
      });
  // The number of loop ops might differ from the number of ivs since some
  // loops like affine.parallel and scf.parallel have multiple ivs.
  SetVector<Operation *> loopSet;
  for (Value iv : allIvs) {
    if (!iv)
      return {};
    // The induction variable is a block argument of the entry block of the
    // loop operation.
    BlockArgument ivVal = iv.dyn_cast<BlockArgument>();
    if (!ivVal)
      return {};
    loopSet.insert(ivVal.getOwner()->getParentOp());
  }
  LinalgLoops loops(loopSet.begin(), loopSet.end());
  // Replace all index operations in the loop body.
  replaceIndexOpsByInductionVariables(linalgOp, rewriter, loops);
  return loops;
}

namespace {

/// Generic rewrite pattern that lowers any LinalgOp with buffer semantics to a
/// loop nest of type `LoopType`.
template <typename LoopType>
class LinalgRewritePattern : public RewritePattern {
public:
  LinalgRewritePattern(MLIRContext *context)
      : RewritePattern(MatchAnyOpTypeTag(), /*benefit=*/1, context) {}

  LogicalResult matchAndRewrite(Operation *op,
                                PatternRewriter &rewriter) const override {
    auto linalgOp = dyn_cast<LinalgOp>(op);
    if (!linalgOp)
      return failure();
    if (!linalgOpToLoopsImpl<LoopType>(rewriter, linalgOp))
      return failure();
    rewriter.eraseOp(op);
    return success();
  }
};

/// Lowers a bufferized `linalg.tiled_loop` to an `scf.for` loop nest by
/// merging its body into the innermost generated loop.
struct TiledLoopToSCFPattern : public OpRewritePattern<TiledLoopOp> {
  using OpRewritePattern<TiledLoopOp>::OpRewritePattern;

  LogicalResult matchAndRewrite(TiledLoopOp tiledLoop,
                                PatternRewriter &rewriter) const override {
    Location loc = tiledLoop.getLoc();

    // Fail the conversion if the `tiled_loop` has not been bufferized.
    if (!llvm::all_of(tiledLoop.outputs(), [&](Value arg) {
          return arg.getType().isa<MemRefType>();
        }))
      return failure();

    // TODO: Build the loop nest with `scf.for` and `scf.parallel` depending on
    // the iterator type.
    scf::buildLoopNest(rewriter, loc, tiledLoop.lowerBound(),
                       tiledLoop.upperBound(), tiledLoop.step(),
                       [&](OpBuilder &builder, Location loc, ValueRange ivs) {
                         // Move the body without its terminator.
                         SmallVector<Value> newBlockArgs;
                         newBlockArgs.append(ivs.begin(), ivs.end());
                         newBlockArgs.append(tiledLoop.inputs().begin(),
                                             tiledLoop.inputs().end());
                         newBlockArgs.append(tiledLoop.outputs().begin(),
                                             tiledLoop.outputs().end());
                         Block *newBody = rewriter.getInsertionBlock();
                         rewriter.mergeBlocks(tiledLoop.getBody(), newBody,
                                              newBlockArgs);
                         rewriter.eraseOp(newBody->getTerminator());
                       });
    rewriter.eraseOp(tiledLoop);
    return success();
  }
};

/// Local folding pattern for AffineApplyOp that we can apply greedily.
/// This replaces AffineApplyOp by the proper value in cases where the
/// associated map is trivial.
/// A trivial map here is defined as a map with a single result and either:
///   1. zero operands, returning a single AffineConstantExpr;
///   2. one operand, returning a single AffineDimExpr;
///   3. one operand, returning a single AffineSymbolExpr.
///
/// In the first case, the AffineApplyOp is replaced by a new constant. In the
/// other cases, it is replaced by its unique operand.
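/// For example (illustrative IR, not verbatim compiler output; names are made
/// up):
///   %0 = affine.apply affine_map<() -> (42)>()       // folds to constant 42 : index
///   %1 = affine.apply affine_map<(d0) -> (d0)>(%iv)  // folds to %iv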
struct FoldAffineOp : public RewritePattern {
  FoldAffineOp(MLIRContext *context)
      : RewritePattern(AffineApplyOp::getOperationName(), 0, context) {}

  LogicalResult matchAndRewrite(Operation *op,
                                PatternRewriter &rewriter) const override {
    AffineApplyOp affineApplyOp = cast<AffineApplyOp>(op);
    auto map = affineApplyOp.getAffineMap();
    if (map.getNumResults() != 1 || map.getNumInputs() > 1)
      return failure();

    AffineExpr expr = map.getResult(0);
    if (map.getNumInputs() == 0) {
      if (auto val = expr.dyn_cast<AffineConstantExpr>()) {
        rewriter.replaceOpWithNewOp<ConstantIndexOp>(op, val.getValue());
        return success();
      }
      return failure();
    }
    if (expr.isa<AffineDimExpr>() || expr.isa<AffineSymbolExpr>()) {
      rewriter.replaceOp(op, op->getOperand(0));
      return success();
    }
    return failure();
  }
};

/// Lowers all Linalg ops in `funcOp` to loops of type `LoopType` and folds
/// away trivial affine.apply and memref.dim ops along the way.
template <typename LoopType>
static void lowerLinalgToLoopsImpl(FuncOp funcOp) {
  MLIRContext *context = funcOp.getContext();
  RewritePatternSet patterns(context);
  patterns.add<LinalgRewritePattern<LoopType>>(context);
  memref::DimOp::getCanonicalizationPatterns(patterns, context);
  AffineApplyOp::getCanonicalizationPatterns(patterns, context);
  patterns.add<FoldAffineOp>(context);
  // Just apply the patterns greedily.
  (void)applyPatternsAndFoldGreedily(funcOp, std::move(patterns));
}

struct LowerToAffineLoops
    : public LinalgLowerToAffineLoopsBase<LowerToAffineLoops> {
  void getDependentDialects(DialectRegistry &registry) const override {
    registry.insert<memref::MemRefDialect>();
  }
  void runOnFunction() override {
    lowerLinalgToLoopsImpl<AffineForOp>(getFunction());
  }
};

struct LowerToLoops : public LinalgLowerToLoopsBase<LowerToLoops> {
  void getDependentDialects(DialectRegistry &registry) const override {
    registry.insert<memref::MemRefDialect, scf::SCFDialect>();
  }
  void runOnFunction() override {
    lowerLinalgToLoopsImpl<scf::ForOp>(getFunction());
  }
};

struct LowerToParallelLoops
    : public LinalgLowerToParallelLoopsBase<LowerToParallelLoops> {
  void runOnFunction() override {
    lowerLinalgToLoopsImpl<scf::ParallelOp>(getFunction());
  }
};

struct LowerTiledLoopsToSCF
    : public LinalgLowerTiledLoopsToSCFBase<LowerTiledLoopsToSCF> {
  void runOnFunction() override {
    MLIRContext *context = &getContext();
    RewritePatternSet patterns(context);
    populateTiledLoopToSCFPattern(patterns);
    (void)applyPatternsAndFoldGreedily(getFunction(), std::move(patterns));
  }
};
} // namespace

void mlir::linalg::populateTiledLoopToSCFPattern(RewritePatternSet &patterns) {
  patterns.add<TiledLoopToSCFPattern>(patterns.getContext());
}

std::unique_ptr<OperationPass<FuncOp>>
mlir::createConvertLinalgTiledLoopsToSCFPass() {
  return std::make_unique<LowerTiledLoopsToSCF>();
}

std::unique_ptr<OperationPass<FuncOp>> mlir::createConvertLinalgToLoopsPass() {
  return std::make_unique<LowerToLoops>();
}

std::unique_ptr<OperationPass<FuncOp>>
mlir::createConvertLinalgToParallelLoopsPass() {
  return std::make_unique<LowerToParallelLoops>();
}

std::unique_ptr<OperationPass<FuncOp>>
mlir::createConvertLinalgToAffineLoopsPass() {
  return std::make_unique<LowerToAffineLoops>();
}

/// Emits a loop nest of `affine.for` with the proper body for `linalgOp`.
Optional<LinalgLoops>
mlir::linalg::linalgOpToAffineLoops(PatternRewriter &rewriter,
                                    LinalgOp linalgOp) {
  return linalgOpToLoopsImpl<AffineForOp>(rewriter, linalgOp);
}

/// Emits a loop nest of `scf.for` with the proper body for `linalgOp`.
Optional<LinalgLoops> mlir::linalg::linalgOpToLoops(PatternRewriter &rewriter,
                                                    LinalgOp linalgOp) {
  return linalgOpToLoopsImpl<scf::ForOp>(rewriter, linalgOp);
}

/// Emits a loop nest of `scf.parallel` with the proper body for `linalgOp`.
Optional<LinalgLoops>
mlir::linalg::linalgOpToParallelLoops(PatternRewriter &rewriter,
                                      LinalgOp linalgOp) {
  return linalgOpToLoopsImpl<scf::ParallelOp>(rewriter, linalgOp);
}
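// These passes are typically exercised from the command line, for example
// (pass flag names follow the definitions in the Linalg Passes.td; shown here
// for illustration only):
//   mlir-opt foo.mlir -convert-linalg-to-loops
//   mlir-opt foo.mlir -convert-linalg-to-parallel-loops
//   mlir-opt foo.mlir -convert-linalg-to-affine-loops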