//===- Loops.cpp - conversion from Linalg named and generic ops to loops --===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//

#include "PassDetail.h"
#include "mlir/Dialect/Affine/EDSC/Intrinsics.h"
#include "mlir/Dialect/Linalg/EDSC/FoldedIntrinsics.h"
#include "mlir/Dialect/Linalg/IR/LinalgOps.h"
#include "mlir/Dialect/Linalg/IR/LinalgTypes.h"
#include "mlir/Dialect/Linalg/Passes.h"
#include "mlir/Dialect/Linalg/Transforms/Transforms.h"
#include "mlir/Dialect/Linalg/Utils/Utils.h"
#include "mlir/Dialect/MemRef/EDSC/Intrinsics.h"
#include "mlir/Dialect/SCF/EDSC/Builders.h"
#include "mlir/Dialect/StandardOps/EDSC/Intrinsics.h"
#include "mlir/IR/AffineExpr.h"
#include "mlir/IR/AffineMap.h"
#include "mlir/IR/BlockAndValueMapping.h"
#include "mlir/Support/LLVM.h"
#include "mlir/Transforms/DialectConversion.h"
#include "mlir/Transforms/FoldUtils.h"
#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
#include "llvm/ADT/TypeSwitch.h"

using namespace mlir;
using namespace mlir::edsc;
using namespace mlir::edsc::intrinsics;
using namespace mlir::linalg;

using edsc::op::operator+;

/// Applies each result expression of `map` to `vals` and returns the resulting
/// values, one `affine.apply` per map result. Each single-result sub-map is
/// canonicalized together with its operands before emission so that trivial
/// applies can later be folded away.
static SmallVector<Value, 8> makeCanonicalAffineApplies(OpBuilder &b,
                                                        Location loc,
                                                        AffineMap map,
                                                        ArrayRef<Value> vals) {
  if (map.isEmpty())
    return {};

  assert(map.getNumInputs() == vals.size());
  SmallVector<Value, 8> res;
  res.reserve(map.getNumResults());
  auto dims = map.getNumDims();
  for (auto e : map.getResults()) {
    // Wrap the single result expression in its own map so it can be
    // canonicalized independently of the other results.
    auto exprMap = AffineMap::get(dims, map.getNumSymbols(), e);
    SmallVector<Value, 4> operands(vals.begin(), vals.end());
    canonicalizeMapAndOperands(&exprMap, &operands);
    res.push_back(affine_apply(exprMap, operands));
  }
  return res;
}

/// Inlines the single-block region of `op` with `indexedValues` substituted
/// for the block arguments, then emits one store per terminator operand:
/// yielded value i is stored into `outputBuffers[i]` at `indexing[i]`.
template <typename IndexedValueType, typename OpType>
static void inlineRegionAndEmitStore(OpType op, ArrayRef<Value> indexedValues,
                                     ArrayRef<SmallVector<Value, 8>> indexing,
                                     ArrayRef<Value> outputBuffers) {
  assert(op->getNumRegions() == 1 && "Expected single region op");
  auto &b = ScopedContext::getBuilderRef();
  auto &block = op->getRegion(0).front();
  BlockAndValueMapping map;
  map.map(block.getArguments(), indexedValues);
  for (auto &op : block.without_terminator()) {
    auto *newOp = b.clone(op, map);
    map.map(op.getResults(), newOp->getResults());
  }

  Operation &terminator = block.back();
  assert(isa<linalg::YieldOp>(terminator) &&
         "expected a yield op in the end of the region");
  for (unsigned i = 0, e = terminator.getNumOperands(); i < e; ++i) {
    // Assigning through IndexedValueType emits the proper store operation.
    IndexedValueType O(outputBuffers[i]);
    O(indexing[i]) = map.lookupOrDefault(terminator.getOperand(i));
  }
}

// Returns a pair that contains input indices and output indices of a
// SingleInputPoolingOp `op`.
struct InputAndOutputIndices {
  SmallVector<Value, 8> inputs;
  SmallVector<Value, 8> outputs;
};
template <typename SingleInputPoolingOp>
static InputAndOutputIndices getInputAndOutputIndices(ArrayRef<Value> allIvs,
                                                      SingleInputPoolingOp op) {
  auto &b = ScopedContext::getBuilderRef();
  auto loc = ScopedContext::getLocation();
  auto mapsRange = op.indexing_maps().template getAsRange<AffineMapAttr>();
  auto maps = llvm::to_vector<8>(
      llvm::map_range(mapsRange, [](AffineMapAttr a) { return a.getValue(); }));
  // maps[0] indexes the input, maps[1] the window, maps[2] the output.
  return InputAndOutputIndices{
      makeCanonicalAffineApplies(b, loc, maps[0], allIvs),
      makeCanonicalAffineApplies(b, loc, maps[2], allIvs)};
}

/// Emits the MLIR for the scalar part of the generic op by:
///   1. Emitting load ops for each input and output view in order. This is
///      achieved by applying the appropriate input or output map to the
///      enclosing induction variables.
///   2. Emitting a call to `op.fun()` that takes as arguments the scalars
///      from point 1. above.
///   3. Emitting store ops to store the results of 2. to the output
///      views.
///
/// An example output may resemble:
///
/// ```
///    scf.for %i = %c0 to %0 step %c1 {
///      scf.for %j = %c0 to %1 step %c1 {
///        scf.for %k = %c0 to %4 step %c1 {
///          %11 = load %arg0[%i, %j] :
///            memref<?x?xf32, stride_specification>
///          %12 = load %arg1[%i, %j, %k] :
///            memref<?x?x?xf32, stride_specification>
///          %13 = load %arg2[%i, %k, %j] :
///            memref<?x?x?xf32, stride_specification>
///          %14:2 = call @foo(%11, %12, %13) : (f32, f32, f32) -> (f32, f32)
///          store %14#0, %arg1[%i, %j, %k] :
///            memref<?x?x?Xf32, stride_specification>
///          store %14#1, %arg2[%i, %k, %j] :
///            memref<?x?x?Xf32, stride_specification>
///        }
///      }
///    }
/// ```
template <typename IndexedValueType>
static void emitScalarImplementation(ArrayRef<Value> allIvs,
                                     LinalgOp linalgOp) {
  assert(linalgOp.hasBufferSemantics() &&
         "expected linalg op with buffer semantics");
  auto &b = ScopedContext::getBuilderRef();
  auto loc = ScopedContext::getLocation();
  unsigned nInputs = linalgOp.getNumInputs();
  unsigned nOutputs = linalgOp.getNumOutputs();
  SmallVector<Value, 4> indexedValues;
  indexedValues.reserve(nInputs + nOutputs);

  auto allIvsPlusDims = SmallVector<Value, 4>(allIvs.begin(), allIvs.end());

  // TODO: Avoid the loads if the corresponding argument of the
  // region has no uses.
  // 1.a. Emit load from input views.
  for (unsigned i = 0; i < nInputs; ++i) {
    auto indexing = makeCanonicalAffineApplies(
        b, loc, linalgOp.getInputIndexingMap(i), allIvsPlusDims);
    // Passing through IndexedValueType emits the proper load operation.
    indexedValues.push_back(IndexedValueType(linalgOp.getInput(i))(indexing));
  }
  // 1.b. Emit load from output views.
  for (unsigned i = 0; i < nOutputs; ++i) {
    auto indexing = makeCanonicalAffineApplies(
        b, loc, linalgOp.getOutputIndexingMap(i), allIvsPlusDims);
    // Passing through IndexedValueType emits the proper load operation.
    indexedValues.push_back(
        IndexedValueType(linalgOp.getOutputBuffer(i))(indexing));
  }

  // TODO: When a region inliner exists, use it.
  // 2. Inline region, currently only works for a single basic block.
  // 3. Emit store.
  SmallVector<SmallVector<Value, 8>, 8> indexing;
  SmallVector<Value, 8> outputBuffers;
  for (unsigned i = 0; i < nOutputs; ++i) {
    indexing.push_back(makeCanonicalAffineApplies(
        b, loc, linalgOp.getOutputIndexingMap(i), allIvsPlusDims));
    outputBuffers.push_back(linalgOp.getOutputBuffer(i));
  }
  inlineRegionAndEmitStore<IndexedValueType>(linalgOp, indexedValues, indexing,
                                             outputBuffers);
}

// Create a padded view into the given `input` tensor using the 'indices'
// to access the tensor. `skipPadding` lists the dimensions for which no padding
// is needed e.g. the non-spatial dimensions for convolutions.
template <typename IndexedValueType>
Value getPaddedInput(Value input, ArrayRef<Value> indices,
                     ArrayRef<int> skipPadding, Value padValue) {
  // TODO: add a level of indirection to linalg.generic.

  IndexedValueType indexedInput(input);

  auto *context = ScopedContext::getContext();
  Value zeroIndex = std_constant_index(0);
  // `conds` accumulates a running out-of-bounds disjunction; its back() always
  // holds the condition over all dimensions processed so far.
  SmallVector<Value, 8> conds;
  SmallVector<Value, 8> clampedImIdx;
  for (auto iter : llvm::enumerate(indices)) {
    int idx = iter.index();
    auto dim = iter.value();
    if (is_contained(skipPadding, idx)) {
      clampedImIdx.push_back(dim);
      continue;
    }

    using edsc::op::sge;
    using edsc::op::slt;
    using edsc::op::operator||;
    Value leftOutOfBound = slt(dim, zeroIndex);
    if (conds.empty())
      conds.push_back(leftOutOfBound);
    else
      conds.push_back(conds.back() || leftOutOfBound);
    Value rightBound = memref_dim(input, idx);
    conds.push_back(conds.back() || (sge(dim, rightBound)));

    // When padding is involved, the indices will only be shifted to negative,
    // so having a max op is enough.
    auto maxMap = AffineMap::get(/*dimCount=*/1, 0,
                                 {getAffineDimExpr(/*position=*/0, context),
                                  getAffineConstantExpr(0, context)},
                                 context);
    clampedImIdx.push_back(affine_max(dim.getType(), maxMap, ValueRange{dim}));
  }

  // Load from the clamped index and select the pad value whenever any
  // original index was out of bounds.
  Value readInput = indexedInput(clampedImIdx);
  return conds.empty() ? readInput
                       : (Value)std_select(conds.back(), padValue, readInput);
}

namespace {

/// The padding value for a given Op depends on the semantics of the Op.
/// The identity value for ConvOp and PoolingSumOp is 0, for PoolingMaxOp is
/// -inf or minInt and for PoolingMinOp is inf or maxInt.
template <typename OpType> Attribute getPadValueAttr(Type type) {
  llvm_unreachable("Unexpected op type for getPadValueAttr");
  return {};
}

template <> Attribute getPadValueAttr<PoolingMaxOp>(Type type) {
  auto &b = ScopedContext::getBuilderRef();
  if (auto floatType = type.dyn_cast<FloatType>()) {
    // -inf is the identity of max over floats.
    return b.getFloatAttr(
        floatType,
        APFloat::getInf(floatType.getFloatSemantics(), /*Negative*/ true));
  }
  if (auto intType = type.dyn_cast<IntegerType>()) {
    unsigned width = intType.getWidth();
    // The select instruction used to lower the PoolingMax uses a signed
    // comparison, use a signed constant irrespective of the signedness of the
    // integer type.
    return b.getIntegerAttr(intType, APInt::getSignedMinValue(width));
  }
  llvm_unreachable("Unsupported data type for PoolingMaxOp");
  return {};
}

template <> Attribute getPadValueAttr<PoolingMinOp>(Type type) {
  auto &b = ScopedContext::getBuilderRef();
  if (auto floatType = type.dyn_cast<FloatType>()) {
    // +inf is the identity of min over floats.
    return b.getFloatAttr(floatType,
                          APFloat::getInf(floatType.getFloatSemantics()));
  }
  if (auto intType = type.dyn_cast<IntegerType>()) {
    unsigned width = intType.getWidth();
    // The select instruction used to lower the PoolingMin uses a signed
    // comparison, use a signed constant irrespective of the signedness of the
    // integer type.
    return b.getIntegerAttr(intType, APInt::getSignedMaxValue(width));
  }
  llvm_unreachable("Unsupported data type for PoolingMinOp");
  return {};
}

template <> Attribute getPadValueAttr<PoolingSumOp>(Type type) {
  // 0 is the identity of sum.
  auto &b = ScopedContext::getBuilderRef();
  return b.getZeroAttr(type);
}

template <> Attribute getPadValueAttr<ConvOp>(Type type) {
  // 0 is the identity of the multiply-accumulate performed by a convolution.
  auto &b = ScopedContext::getBuilderRef();
  return b.getZeroAttr(type);
}

} // namespace

/// Returns true if `convOp` has a non-zero padding.
279 static bool hasPadding(ConvOp convOp) { 280 for (unsigned i = 0, e = convOp.getNumSpatialDimensions(); i < e; ++i) { 281 if (convOp.getLowPad(i) > 0 || convOp.getHighPad(i) > 0) 282 return true; 283 } 284 return false; 285 } 286 287 template <typename IndexedValueType> 288 static void emitScalarImplementation(ArrayRef<Value> allIvs, ConvOp convOp) { 289 assert(convOp.hasBufferSemantics() && 290 "expected linalg op with buffer semantics"); 291 auto &b = ScopedContext::getBuilderRef(); 292 auto loc = ScopedContext::getLocation(); 293 auto mapsRange = convOp.indexing_maps().getAsRange<AffineMapAttr>(); 294 auto maps = llvm::to_vector<8>( 295 llvm::map_range(mapsRange, [](AffineMapAttr a) { return a.getValue(); })); 296 SmallVector<Value, 8> fIdx( 297 makeCanonicalAffineApplies(b, loc, maps[0], allIvs)); 298 SmallVector<Value, 8> imIdx( 299 makeCanonicalAffineApplies(b, loc, maps[1], allIvs)); 300 SmallVector<Value, 8> oIdx( 301 makeCanonicalAffineApplies(b, loc, maps[2], allIvs)); 302 303 IndexedValueType F(convOp.filter()), O(convOp.output()); 304 305 // Emit scalar form. Padded conv involves an affine.max in the memory access 306 // which is not allowed by affine.load. Override to use an MemRefIndexedValue 307 // when there is non-zero padding. 
308 if (hasPadding(convOp)) { 309 Type type = convOp.input().getType().cast<MemRefType>().getElementType(); 310 Value padValue = std_constant(type, getPadValueAttr<ConvOp>(type)); 311 Value paddedInput = getPaddedInput<MemRefIndexedValue>( 312 convOp.input(), imIdx, 313 /* Only need to pad the window dimensions */ 314 {0, static_cast<int>(imIdx.size()) - 1}, padValue); 315 O(oIdx) += F(fIdx) * paddedInput; 316 } else { 317 IndexedValueType I(convOp.input()); 318 O(oIdx) += F(fIdx) * I(imIdx); 319 } 320 } 321 322 template <typename PoolingOp> static bool hasPadding(PoolingOp poolingOp) { 323 for (unsigned i = 0, e = poolingOp.getNumWindowLoops(); i < e; ++i) { 324 if (poolingOp.getLowPad(i) > 0 || poolingOp.getHighPad(i) > 0) 325 return true; 326 } 327 return false; 328 } 329 330 template <typename IndexedValueType, typename PoolingOp> 331 static Value getPoolingInput(PoolingOp op, ArrayRef<Value> inputIndices) { 332 if (hasPadding(op)) { 333 Type type = 334 op.input().getType().template cast<MemRefType>().getElementType(); 335 Value padValue = std_constant(type, getPadValueAttr<PoolingOp>(type)); 336 return getPaddedInput<MemRefIndexedValue>(op.input(), inputIndices, 337 /*Pad every dimension*/ {}, 338 padValue); 339 } 340 IndexedValueType input(op.input()); 341 return input(inputIndices); 342 } 343 344 template <typename IndexedValueType, typename OpType> 345 void emitPoolingMinMaxScalarImplementation(ArrayRef<Value> allIvs, OpType op) { 346 InputAndOutputIndices indices = getInputAndOutputIndices(allIvs, op); 347 // Emit scalar form. 348 IndexedValueType output(op.output()); 349 Value lhs = output(indices.outputs); 350 Value rhs = getPoolingInput<IndexedValueType>(op, indices.inputs); 351 using edsc::op::sgt; 352 using edsc::op::slt; 353 Value value = std::is_same<OpType, PoolingMinOp>() 354 ? 
std_select(slt(lhs, rhs), lhs, rhs) 355 : std_select(sgt(lhs, rhs), lhs, rhs); 356 output(indices.outputs) = value; 357 } 358 359 template <typename IndexedValueType> 360 static void emitScalarImplementation(ArrayRef<Value> allIvs, PoolingMaxOp op) { 361 emitPoolingMinMaxScalarImplementation<IndexedValueType, PoolingMaxOp>(allIvs, 362 op); 363 } 364 365 template <typename IndexedValueType> 366 static void emitScalarImplementation(ArrayRef<Value> allIvs, PoolingMinOp op) { 367 emitPoolingMinMaxScalarImplementation<IndexedValueType, PoolingMinOp>(allIvs, 368 op); 369 } 370 371 template <typename IndexedValueType> 372 static void emitScalarImplementation(ArrayRef<Value> allIvs, PoolingSumOp op) { 373 auto indices = getInputAndOutputIndices(allIvs, op); 374 IndexedValueType output(op.output()); 375 376 // Emit scalar form. 377 output(indices.outputs) += 378 getPoolingInput<IndexedValueType>(op, indices.inputs); 379 } 380 381 template <typename LoopTy> 382 static Optional<LinalgLoops> linalgOpToLoopsImpl(LinalgOp linalgOp, 383 OpBuilder &builder) { 384 using IndexedValueTy = typename GenerateLoopNest<LoopTy>::IndexedValueTy; 385 ScopedContext scope(builder, linalgOp.getLoc()); 386 387 // The flattened loopToOperandRangesMaps is expected to be an invertible 388 // permutation map (which is asserted in the inverse calculation). 
389 assert(linalgOp.hasBufferSemantics() && 390 "expected linalg op with buffer semantics"); 391 392 auto loopRanges = linalgOp.createLoopRanges(builder, linalgOp.getLoc()); 393 auto iteratorTypes = llvm::to_vector<4>(linalgOp.iterator_types().getValue()); 394 395 SmallVector<Value, 4> allIvs; 396 GenerateLoopNest<LoopTy>::doit( 397 loopRanges, linalgOp, iteratorTypes, 398 [&](ValueRange ivs, ValueRange iterArgs) -> scf::ValueVector { 399 assert(iterArgs.empty() && "unexpected iterArgs"); 400 allIvs.append(ivs.begin(), ivs.end()); 401 llvm::TypeSwitch<Operation *>(linalgOp) 402 .Case<ConvOp, PoolingMaxOp, PoolingMinOp, PoolingSumOp, LinalgOp>( 403 [&](auto op) { 404 emitScalarImplementation<IndexedValueTy>(allIvs, op); 405 }) 406 .Default([&](Operation *op) { assert(false && "unexpected op"); }); 407 return scf::ValueVector{}; 408 }); 409 // Number of loop ops might be different from the number of ivs since some 410 // loops like affine.parallel and scf.parallel have multiple ivs. 411 SetVector<Operation *> loopSet; 412 for (Value iv : allIvs) { 413 if (!iv) 414 return {}; 415 // The induction variable is a block argument of the entry block of the 416 // loop operation. 417 BlockArgument ivVal = iv.dyn_cast<BlockArgument>(); 418 if (!ivVal) 419 return {}; 420 loopSet.insert(ivVal.getOwner()->getParentOp()); 421 } 422 LinalgLoops loops(loopSet.begin(), loopSet.end()); 423 return loops; 424 } 425 426 /// Replace the index operations in the body of the loop nest by the matching 427 /// induction variables. 428 static void replaceIndexOpsByInductionVariables(LinalgOp linalgOp, 429 PatternRewriter &rewriter, 430 ArrayRef<Operation *> loopOps) { 431 // Extract the induction variables of the loop nest from outer to inner. 
432 SmallVector<Value> allIvs; 433 for (Operation *loopOp : loopOps) { 434 llvm::TypeSwitch<Operation *>(loopOp) 435 .Case([&](scf::ParallelOp parallelOp) { 436 allIvs.append(parallelOp.getInductionVars().begin(), 437 parallelOp.getInductionVars().end()); 438 }) 439 .Case([&](scf::ForOp forOp) { 440 allIvs.push_back(forOp.getInductionVar()); 441 }) 442 .Case([&](AffineForOp affineForOp) { 443 allIvs.push_back(affineForOp.getInductionVar()); 444 }) 445 .Default([&](Operation *op) { assert(false && "unexpected op"); }); 446 } 447 assert(linalgOp.getNumLoops() == allIvs.size() && 448 "expected the number of loops and induction variables to match"); 449 // Replace the index operations in the body of the innermost loop op. 450 if (!loopOps.empty()) { 451 LoopLikeOpInterface loopOp = loopOps.back(); 452 for (IndexOp indexOp : 453 llvm::make_early_inc_range(loopOp.getLoopBody().getOps<IndexOp>())) 454 rewriter.replaceOp(indexOp, allIvs[indexOp.dim()]); 455 } 456 } 457 458 namespace { 459 template <typename LoopType> 460 class LinalgRewritePattern : public RewritePattern { 461 public: 462 LinalgRewritePattern(MLIRContext *context) 463 : RewritePattern(MatchAnyOpTypeTag(), /*benefit=*/1, context) {} 464 465 LogicalResult matchAndRewrite(Operation *op, 466 PatternRewriter &rewriter) const override { 467 auto linalgOp = dyn_cast<LinalgOp>(op); 468 if (!isa<LinalgOp>(op)) 469 return failure(); 470 if (!linalgLowerOpToLoops<LoopType>(rewriter, linalgOp)) 471 return failure(); 472 rewriter.eraseOp(op); 473 return success(); 474 } 475 }; 476 477 struct TiledLoopToSCFPattern : public OpRewritePattern<TiledLoopOp> { 478 using OpRewritePattern<TiledLoopOp>::OpRewritePattern; 479 480 LogicalResult matchAndRewrite(TiledLoopOp tiledLoop, 481 PatternRewriter &rewriter) const override { 482 Location loc = tiledLoop.getLoc(); 483 484 // Fail conversion if the `tiled_loop` has not been bufferized. 
485 if (!llvm::all_of(tiledLoop.outputs(), [&](Value arg) { 486 return arg.getType().isa<MemRefType>(); 487 })) 488 return failure(); 489 490 // TODO: Build loop nest with `scf.for` and `scf.parallel` depending on the 491 // iterator type. 492 scf::buildLoopNest(rewriter, loc, tiledLoop.lowerBound(), 493 tiledLoop.upperBound(), tiledLoop.step(), 494 [&](OpBuilder &builder, Location loc, ValueRange ivs) { 495 // Move body without its terminator. 496 SmallVector<Value, 16> newBlockArgs; 497 newBlockArgs.append(ivs.begin(), ivs.end()); 498 newBlockArgs.append(tiledLoop.inputs().begin(), 499 tiledLoop.inputs().end()); 500 newBlockArgs.append(tiledLoop.outputs().begin(), 501 tiledLoop.outputs().end()); 502 Block *newBody = rewriter.getInsertionBlock(); 503 rewriter.mergeBlocks(tiledLoop.getBody(), newBody, 504 newBlockArgs); 505 rewriter.eraseOp(newBody->getTerminator()); 506 }); 507 rewriter.eraseOp(tiledLoop); 508 return success(); 509 } 510 }; 511 512 /// Local folding pattern for AffineApplyOp that we can apply greedily. 513 /// This replaces AffineApplyOp by the proper value in cases where the 514 /// associated map is trivial. 515 /// A trivial map here is defined as a map with a single result and either: 516 /// 1. Zero operand + returns a single AffineConstantExpr 517 /// 2. One operand + returns a single AffineDimExpr 518 /// 3. One operand + returns a single AffineSymbolExpr 519 // 520 /// In the first case, the AffineApplyOp is replaced by a new constant. In the 521 /// other cases, it is replaced by its unique operand. 
522 struct FoldAffineOp : public RewritePattern { 523 FoldAffineOp(MLIRContext *context) 524 : RewritePattern(AffineApplyOp::getOperationName(), 0, context) {} 525 526 LogicalResult matchAndRewrite(Operation *op, 527 PatternRewriter &rewriter) const override { 528 AffineApplyOp affineApplyOp = cast<AffineApplyOp>(op); 529 auto map = affineApplyOp.getAffineMap(); 530 if (map.getNumResults() != 1 || map.getNumInputs() > 1) 531 return failure(); 532 533 AffineExpr expr = map.getResult(0); 534 if (map.getNumInputs() == 0) { 535 if (auto val = expr.dyn_cast<AffineConstantExpr>()) { 536 rewriter.replaceOpWithNewOp<ConstantIndexOp>(op, val.getValue()); 537 return success(); 538 } 539 return failure(); 540 } 541 if (expr.dyn_cast<AffineDimExpr>() || expr.dyn_cast<AffineSymbolExpr>()) { 542 rewriter.replaceOp(op, op->getOperand(0)); 543 return success(); 544 } 545 return failure(); 546 } 547 }; 548 549 template <typename LoopType> 550 static void lowerLinalgToLoopsImpl(FuncOp funcOp) { 551 MLIRContext *context = funcOp.getContext(); 552 RewritePatternSet patterns(context); 553 patterns.add<LinalgRewritePattern<LoopType>>(context); 554 memref::DimOp::getCanonicalizationPatterns(patterns, context); 555 AffineApplyOp::getCanonicalizationPatterns(patterns, context); 556 patterns.add<FoldAffineOp>(context); 557 // Just apply the patterns greedily. 
558 (void)applyPatternsAndFoldGreedily(funcOp, std::move(patterns)); 559 } 560 561 struct LowerToAffineLoops 562 : public LinalgLowerToAffineLoopsBase<LowerToAffineLoops> { 563 void getDependentDialects(DialectRegistry ®istry) const override { 564 registry.insert<memref::MemRefDialect>(); 565 } 566 void runOnFunction() override { 567 lowerLinalgToLoopsImpl<AffineForOp>(getFunction()); 568 } 569 }; 570 571 struct LowerToLoops : public LinalgLowerToLoopsBase<LowerToLoops> { 572 void getDependentDialects(DialectRegistry ®istry) const override { 573 registry.insert<memref::MemRefDialect, scf::SCFDialect>(); 574 } 575 void runOnFunction() override { 576 lowerLinalgToLoopsImpl<scf::ForOp>(getFunction()); 577 } 578 }; 579 580 struct LowerToParallelLoops 581 : public LinalgLowerToParallelLoopsBase<LowerToParallelLoops> { 582 void runOnFunction() override { 583 lowerLinalgToLoopsImpl<scf::ParallelOp>(getFunction()); 584 } 585 }; 586 587 struct LowerTiledLoopsToSCF 588 : public LinalgLowerTiledLoopsToSCFBase<LowerTiledLoopsToSCF> { 589 void runOnFunction() override { 590 MLIRContext *context = &getContext(); 591 RewritePatternSet patterns(context); 592 patterns.add<TiledLoopToSCFPattern>(context); 593 (void)applyPatternsAndFoldGreedily(getFunction(), std::move(patterns)); 594 } 595 }; 596 } // namespace 597 598 std::unique_ptr<OperationPass<FuncOp>> 599 mlir::createConvertLinalgTiledLoopsToSCFPass() { 600 return std::make_unique<LowerTiledLoopsToSCF>(); 601 } 602 603 std::unique_ptr<OperationPass<FuncOp>> mlir::createConvertLinalgToLoopsPass() { 604 return std::make_unique<LowerToLoops>(); 605 } 606 607 std::unique_ptr<OperationPass<FuncOp>> 608 mlir::createConvertLinalgToParallelLoopsPass() { 609 return std::make_unique<LowerToParallelLoops>(); 610 } 611 612 std::unique_ptr<OperationPass<FuncOp>> 613 mlir::createConvertLinalgToAffineLoopsPass() { 614 return std::make_unique<LowerToAffineLoops>(); 615 } 616 617 /// Emits a loop nest with the proper body for `linalgOp`. 
618 template <typename LoopTy> 619 Optional<LinalgLoops> 620 mlir::linalg::linalgLowerOpToLoops(PatternRewriter &rewriter, 621 LinalgOp linalgOp) { 622 // Convert indexed_generic ops to generic ops before lowering them to loops. 623 if (isa<IndexedGenericOp>(linalgOp)) 624 return llvm::None; 625 626 Optional<LinalgLoops> loopOps = 627 linalgOpToLoopsImpl<LoopTy>(linalgOp.getOperation(), rewriter); 628 if (loopOps.hasValue()) 629 replaceIndexOpsByInductionVariables(linalgOp, rewriter, loopOps.getValue()); 630 return loopOps; 631 } 632 633 template Optional<LinalgLoops> 634 mlir::linalg::linalgLowerOpToLoops<AffineForOp>(PatternRewriter &rewriter, 635 LinalgOp linalgOp); 636 template Optional<LinalgLoops> 637 mlir::linalg::linalgLowerOpToLoops<scf::ForOp>(PatternRewriter &rewriter, 638 LinalgOp linalgOp); 639 template Optional<LinalgLoops> 640 mlir::linalg::linalgLowerOpToLoops<scf::ParallelOp>(PatternRewriter &rewriter, 641 LinalgOp linalgOp); 642 643 /// Emits a loop nest of `affine.for` with the proper body for `linalgOp`. 644 LogicalResult mlir::linalg::linalgOpToAffineLoops(PatternRewriter &rewriter, 645 LinalgOp linalgOp) { 646 Optional<LinalgLoops> loops = 647 linalgLowerOpToLoops<AffineForOp>(rewriter, linalgOp); 648 return loops ? success() : failure(); 649 } 650 651 /// Emits a loop nest of `scf.for` with the proper body for `linalgOp`. 652 LogicalResult mlir::linalg::linalgOpToLoops(PatternRewriter &rewriter, 653 LinalgOp linalgOp) { 654 Optional<LinalgLoops> loops = 655 linalgLowerOpToLoops<scf::ForOp>(rewriter, linalgOp); 656 return loops ? success() : failure(); 657 } 658 659 /// Emits a loop nest of `scf.parallel` with the proper body for `linalgOp`. 660 LogicalResult mlir::linalg::linalgOpToParallelLoops(PatternRewriter &rewriter, 661 LinalgOp linalgOp) { 662 Optional<LinalgLoops> loops = 663 linalgLowerOpToLoops<scf::ParallelOp>(rewriter, linalgOp); 664 return loops ? success() : failure(); 665 } 666