//===- Loops.cpp - conversion from Linalg named and generic ops to loops --===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//

#include "PassDetail.h"
#include "mlir/Dialect/Linalg/IR/LinalgOps.h"
#include "mlir/Dialect/Linalg/IR/LinalgTypes.h"
#include "mlir/Dialect/Linalg/Passes.h"
#include "mlir/Dialect/Linalg/Transforms/Transforms.h"
#include "mlir/Dialect/Linalg/Utils/Utils.h"
#include "mlir/Dialect/StandardOps/Utils/Utils.h"
#include "mlir/IR/AffineExpr.h"
#include "mlir/IR/AffineMap.h"
#include "mlir/IR/BlockAndValueMapping.h"
#include "mlir/Support/LLVM.h"
#include "mlir/Transforms/DialectConversion.h"
#include "mlir/Transforms/FoldUtils.h"
#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
#include "llvm/ADT/TypeSwitch.h"

using namespace mlir;
using namespace mlir::linalg;

static SmallVector<Value> makeCanonicalAffineApplies(OpBuilder &b, Location loc,
                                                     AffineMap map,
                                                     ArrayRef<Value> vals) {
  if (map.isEmpty())
    return {};

  assert(map.getNumInputs() == vals.size());
  SmallVector<Value> res;
  res.reserve(map.getNumResults());
  auto dims = map.getNumDims();
  for (auto e : map.getResults()) {
    auto exprMap = AffineMap::get(dims, map.getNumSymbols(), e);
    SmallVector<Value> operands(vals.begin(), vals.end());
    canonicalizeMapAndOperands(&exprMap, &operands);
    res.push_back(b.create<AffineApplyOp>(loc, exprMap, operands));
  }
  return res;
}

template <typename LoadOpTy, typename StoreOpTy, typename OpType>
static void inlineRegionAndEmitStore(OpBuilder &b, Location loc, OpType op,
                                     ArrayRef<Value> indexedValues,
                                     ArrayRef<SmallVector<Value>> indexing,
                                     ArrayRef<Value> outputBuffers) {
  auto &block = op->getRegion(0).front();
  BlockAndValueMapping map;
  map.map(block.getArguments(), indexedValues);
  for (auto &op : block.without_terminator()) {
    auto *newOp = b.clone(op, map);
    map.map(op.getResults(), newOp->getResults());
  }

  Operation *terminator = block.getTerminator();
  for (OpOperand &operand : terminator->getOpOperands()) {
    Value toStore = map.lookupOrDefault(operand.get());
    b.create<StoreOpTy>(loc, toStore, outputBuffers[operand.getOperandNumber()],
                        indexing[operand.getOperandNumber()]);
  }
}

// Returns a struct containing the input indices and output indices of a
// SingleInputPoolingOp `op`.
struct InputAndOutputIndices {
  SmallVector<Value> inputs;
  SmallVector<Value> outputs;
};
template <typename SingleInputPoolingOp>
static InputAndOutputIndices
getInputAndOutputIndices(OpBuilder &b, Location loc, ArrayRef<Value> allIvs,
                         SingleInputPoolingOp op) {
  auto mapsRange = op.indexing_maps().template getAsRange<AffineMapAttr>();
  auto maps = llvm::to_vector<8>(
      llvm::map_range(mapsRange, [](AffineMapAttr a) { return a.getValue(); }));
  return InputAndOutputIndices{
      makeCanonicalAffineApplies(b, loc, maps[0], allIvs),
      makeCanonicalAffineApplies(b, loc, maps[2], allIvs)};
}
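
// Note: as an illustration of what the helpers above produce (the map, SSA
// names, and shapes below are hypothetical, not taken from a specific test),
// applying a two-result indexing map such as (d0, d1, d2) -> (d0 + d2, d1) to
// induction variables %i, %j, %k with makeCanonicalAffineApplies yields
// roughly:
//
//   %0 = affine.apply affine_map<(d0, d1) -> (d0 + d1)>(%i, %k)
//   %1 = affine.apply affine_map<(d0) -> (d0)>(%j)
//
// i.e. one canonicalized affine.apply per result expression; trivial applies
// such as %1 can later be folded by the FoldAffineOp pattern below.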

/// Emits the MLIR for the scalar part of the generic op by:
///   1. Emitting load ops for each input and output view in order. This is
///      achieved by applying the appropriate input or output map to the
///      enclosing induction variables.
///   2. Emitting a call to `op.fun()` that takes as arguments the scalars
///      from point 1. above.
///   3. Emitting store ops to store the results of 2. to the output
///      views.
///
/// An example output may resemble:
///
/// ```
///    scf.for %i = %c0 to %0 step %c1 {
///      scf.for %j = %c0 to %1 step %c1 {
///        scf.for %k = %c0 to %4 step %c1 {
///          %11 = load %arg0[%i, %j] :
///            memref<?x?xf32, stride_specification>
///          %12 = load %arg1[%i, %j, %k] :
///            memref<?x?x?xf32, stride_specification>
///          %13 = load %arg2[%i, %k, %j] :
///            memref<?x?x?xf32, stride_specification>
///          %14:2 = call @foo(%11, %12, %13) : (f32, f32, f32) -> (f32, f32)
///          store %14#0, %arg1[%i, %j, %k] :
///            memref<?x?x?xf32, stride_specification>
///          store %14#1, %arg2[%i, %k, %j] :
///            memref<?x?x?xf32, stride_specification>
///        }
///      }
///    }
/// ```
template <typename LoadOpTy, typename StoreOpTy>
static void emitScalarImplementation(OpBuilder &b, Location loc,
                                     ArrayRef<Value> allIvs,
                                     LinalgOp linalgOp) {
  assert(linalgOp.hasBufferSemantics() &&
         "expected linalg op with buffer semantics");
  SmallVector<Value> indexedValues;
  indexedValues.reserve(linalgOp.getNumInputsAndOutputs());

  auto allIvsPlusDims = SmallVector<Value>(allIvs.begin(), allIvs.end());

  // TODO: Avoid the loads if the corresponding argument of the
  // region has no uses.
  // 1.a. Emit load from input operand or, for scalars, access the operand
  // itself.
  for (OpOperand *inputOperand : linalgOp.getInputOperands()) {
    if (linalgOp.isScalar(inputOperand)) {
      indexedValues.push_back(inputOperand->get());
      continue;
    }
    auto indexing = makeCanonicalAffineApplies(
        b, loc, linalgOp.getTiedIndexingMap(inputOperand), allIvsPlusDims);
    indexedValues.push_back(
        b.create<LoadOpTy>(loc, inputOperand->get(), indexing));
  }
  // 1.b. Emit load from output views.
  for (OpOperand *outputOperand : linalgOp.getOutputOperands()) {
    SmallVector<Value> indexing = makeCanonicalAffineApplies(
        b, loc, linalgOp.getTiedIndexingMap(outputOperand), allIvsPlusDims);
    indexedValues.push_back(
        b.create<LoadOpTy>(loc, outputOperand->get(), indexing));
  }

  // TODO: When a region inliner exists, use it.
  // 2. Inline region, currently only works for a single basic block.
  // 3. Emit store.
  SmallVector<SmallVector<Value>, 8> indexing;
  SmallVector<Value> outputBuffers;
  for (OpOperand *outputOperand : linalgOp.getOutputBufferOperands()) {
    indexing.push_back(makeCanonicalAffineApplies(
        b, loc, linalgOp.getTiedIndexingMap(outputOperand), allIvsPlusDims));
    outputBuffers.push_back(outputOperand->get());
  }
  inlineRegionAndEmitStore<LoadOpTy, StoreOpTy>(b, loc, linalgOp, indexedValues,
                                                indexing, outputBuffers);
}

// Create a padded view into the given `input` tensor using the 'indices'
// to access the tensor. `skipPadding` lists the dimensions for which no padding
// is needed e.g. the non-spatial dimensions for convolutions.
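//
// As a rough illustration (the SSA names, element type, and 1-D shape below
// are made up for this sketch), padding a single dimension %i of %input with
// %pad produces IR along the lines of:
//
//   %lt = cmpi slt, %i, %c0 : index
//   %d = memref.dim %input, %c0 : memref<?xf32>
//   %ge = cmpi sge, %i, %d : index
//   %oob = or %lt, %ge : i1
//   %clamped = affine.max affine_map<(d0) -> (d0, 0)>(%i)
//   %v = memref.load %input[%clamped] : memref<?xf32>
//   %r = select %oob, %pad, %v : f32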
Value getPaddedInput(OpBuilder &b, Location loc, Value input,
                     ArrayRef<Value> indices, ArrayRef<int> skipPadding,
                     Value padValue) {
  Value zeroIndex = b.create<ConstantIndexOp>(loc, 0);
  SmallVector<Value> conds;
  SmallVector<Value> clampedImIdx;
  for (auto iter : llvm::enumerate(indices)) {
    int idx = iter.index();
    auto dim = iter.value();
    if (is_contained(skipPadding, idx)) {
      clampedImIdx.push_back(dim);
      continue;
    }

    Value leftOutOfBound =
        b.create<CmpIOp>(loc, CmpIPredicate::slt, dim, zeroIndex);
    if (conds.empty())
      conds.push_back(leftOutOfBound);
    else
      conds.push_back(b.create<OrOp>(loc, conds.back(), leftOutOfBound));
    Value rightBound = createOrFoldDimOp(b, loc, input, idx);
    Value rightOutOfBound =
        b.create<CmpIOp>(loc, CmpIPredicate::sge, dim, rightBound);
    conds.push_back(b.create<OrOp>(loc, conds.back(), rightOutOfBound));

    // When padding is involved, the indices can only be shifted toward
    // negative values, so clamping with a max against zero is enough.
    MLIRContext *ctx = input.getContext();
    AffineExpr m = getAffineDimExpr(/*position=*/0, ctx),
               zero = getAffineConstantExpr(0, ctx);
    AffineMap maxMap =
        AffineMap::inferFromExprList(ArrayRef<ArrayRef<AffineExpr>>{{m, zero}})
            .front();
    clampedImIdx.push_back(b.create<AffineMaxOp>(loc, maxMap, ValueRange{dim}));
  }

  Value readInput = b.create<memref::LoadOp>(loc, input, clampedImIdx);
  if (conds.empty())
    return readInput;

  return b.create<SelectOp>(loc, conds.back(), padValue, readInput);
}

namespace {

/// The padding value for a given Op depends on the semantics of the Op.
/// The identity value for ConvOp and PoolingSumOp is 0, for PoolingMaxOp it is
/// -inf or minInt, and for PoolingMinOp it is inf or maxInt.
template <typename OpType> Attribute getPadValueAttr(Type type) {
  llvm_unreachable("Unexpected op type for getPadValueAttr");
  return {};
}

template <> Attribute getPadValueAttr<PoolingMaxOp>(Type type) {
  if (auto floatType = type.dyn_cast<FloatType>()) {
    return OpBuilder(type.getContext())
        .getFloatAttr(floatType, APFloat::getInf(floatType.getFloatSemantics(),
                                                 /*Negative*/ true));
  }
  if (auto intType = type.dyn_cast<IntegerType>()) {
    unsigned width = intType.getWidth();
    // The select instruction used to lower the PoolingMax uses a signed
    // comparison; use a signed constant irrespective of the signedness of the
    // integer type.
    return OpBuilder(type.getContext())
        .getIntegerAttr(intType, APInt::getSignedMinValue(width));
  }
  llvm_unreachable("Unsupported data type for PoolingMaxOp");
  return {};
}
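
// Illustrative values (hypothetical element types, not taken from a test):
// for f32 the specialization above yields -inf and the PoolingMinOp
// specialization below yields +inf; for i8 they yield -128
// (APInt::getSignedMinValue(8)) and 127 (APInt::getSignedMaxValue(8)), since
// the lowering always compares signed.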

template <> Attribute getPadValueAttr<PoolingMinOp>(Type type) {
  if (auto floatType = type.dyn_cast<FloatType>()) {
    return OpBuilder(type.getContext())
        .getFloatAttr(floatType,
                      APFloat::getInf(floatType.getFloatSemantics()));
  }
  if (auto intType = type.dyn_cast<IntegerType>()) {
    unsigned width = intType.getWidth();
    // The select instruction used to lower the PoolingMin uses a signed
    // comparison; use a signed constant irrespective of the signedness of the
    // integer type.
    return OpBuilder(type.getContext())
        .getIntegerAttr(intType, APInt::getSignedMaxValue(width));
  }
  llvm_unreachable("Unsupported data type for PoolingMinOp");
  return {};
}

template <> Attribute getPadValueAttr<PoolingSumOp>(Type type) {
  return OpBuilder(type.getContext()).getZeroAttr(type);
}

template <> Attribute getPadValueAttr<ConvOp>(Type type) {
  return OpBuilder(type.getContext()).getZeroAttr(type);
}

} // namespace

/// Returns true if `convOp` has non-zero padding.
static bool hasPadding(ConvOp convOp) {
  for (unsigned i = 0, e = convOp.getNumSpatialDimensions(); i < e; ++i) {
    if (convOp.getLowPad(i) > 0 || convOp.getHighPad(i) > 0)
      return true;
  }
  return false;
}

template <typename LoadOpTy, typename StoreOpTy>
static void emitScalarImplementation(OpBuilder &b, Location loc,
                                     ArrayRef<Value> allIvs, ConvOp convOp) {
  assert(convOp.hasBufferSemantics() &&
         "expected linalg op with buffer semantics");
  auto mapsRange = convOp.indexing_maps().getAsRange<AffineMapAttr>();
  auto maps = llvm::to_vector<8>(
      llvm::map_range(mapsRange, [](AffineMapAttr a) { return a.getValue(); }));
  SmallVector<Value> fIdx(makeCanonicalAffineApplies(b, loc, maps[0], allIvs));
  SmallVector<Value> imIdx(makeCanonicalAffineApplies(b, loc, maps[1], allIvs));
  SmallVector<Value> oIdx(makeCanonicalAffineApplies(b, loc, maps[2], allIvs));

  Value filter = convOp.filter(), output = convOp.output();

  // Emit scalar form. Padded conv involves an affine.max in the memory access
  // which is not allowed by affine.load. Fall back to a plain memref load when
  // there is non-zero padding.
  if (hasPadding(convOp)) {
    Type type = convOp.input().getType().cast<MemRefType>().getElementType();
    Value padValue =
        b.create<ConstantOp>(loc, type, getPadValueAttr<ConvOp>(type));
    Value paddedInput =
        getPaddedInput(b, loc, convOp.input(), imIdx,
                       /* Only need to pad the window dimensions */
                       {0, static_cast<int>(imIdx.size()) - 1}, padValue);
    Value filterVal = b.create<LoadOpTy>(loc, filter, fIdx);
    Value mulVal = ArithBuilder(b, loc).mul(filterVal, paddedInput);
    Value outputVal = b.create<LoadOpTy>(loc, output, oIdx);
    Value addVal = ArithBuilder(b, loc).add(mulVal, outputVal);
    b.create<StoreOpTy>(loc, addVal, output, oIdx);
  } else {
    Value inputVal = b.create<LoadOpTy>(loc, convOp.input(), imIdx);
    Value filterVal = b.create<LoadOpTy>(loc, filter, fIdx);
    Value mulVal = ArithBuilder(b, loc).mul(filterVal, inputVal);
    Value outputVal = b.create<LoadOpTy>(loc, output, oIdx);
    Value addVal = ArithBuilder(b, loc).add(mulVal, outputVal);
    b.create<StoreOpTy>(loc, addVal, output, oIdx);
  }
}
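
// Note: for the unpadded case with a floating-point element type, the emitter
// above produces, schematically (SSA names, shapes, and indices below are
// illustrative only):
//
//   %f = memref.load %filter[%kh, %kw] : memref<?x?xf32>
//   %x = memref.load %input[%ih, %iw] : memref<?x?xf32>
//   %p = mulf %f, %x : f32
//   %o = memref.load %output[%oh, %ow] : memref<?x?xf32>
//   %s = addf %p, %o : f32
//   memref.store %s, %output[%oh, %ow] : memref<?x?xf32>
//
// where the subscripts are the affine.apply results computed from the
// convolution's indexing maps (AffineLoadOp/AffineStoreOp are used instead on
// the affine lowering path).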

template <typename PoolingOp> static bool hasPadding(PoolingOp poolingOp) {
  for (unsigned i = 0, e = poolingOp.getNumWindowLoops(); i < e; ++i) {
    if (poolingOp.getLowPad(i) > 0 || poolingOp.getHighPad(i) > 0)
      return true;
  }
  return false;
}

template <typename LoadOpTy, typename StoreOpTy, typename PoolingOp>
static Value getPoolingInput(OpBuilder &b, Location loc, PoolingOp op,
                             ArrayRef<Value> inputIndices) {
  if (hasPadding(op)) {
    Type type =
        op.input().getType().template cast<MemRefType>().getElementType();
    Value padValue =
        b.create<ConstantOp>(loc, type, getPadValueAttr<PoolingOp>(type));
    return getPaddedInput(b, loc, op.input(), inputIndices,
                          /*Pad every dimension*/ {}, padValue);
  }
  return b.create<LoadOpTy>(loc, op.input(), inputIndices);
}

template <typename LoadOpTy, typename StoreOpTy, typename OpType>
void emitPoolingMinMaxScalarImplementation(OpBuilder &b, Location loc,
                                           ArrayRef<Value> allIvs, OpType op) {
  InputAndOutputIndices indices = getInputAndOutputIndices(b, loc, allIvs, op);
  Value lhs = b.create<LoadOpTy>(loc, op.output(), indices.outputs);
  Value rhs = getPoolingInput<LoadOpTy, StoreOpTy>(b, loc, op, indices.inputs);
  Value value = llvm::TypeSwitch<Operation *, Value>(op)
                    .Case([&](PoolingMinOp poolingOp) {
                      return ArithBuilder(b, loc).select(
                          ArithBuilder(b, loc).slt(lhs, rhs), lhs, rhs);
                    })
                    .Case([&](PoolingMaxOp poolingOp) {
                      return ArithBuilder(b, loc).select(
                          ArithBuilder(b, loc).sgt(lhs, rhs), lhs, rhs);
                    })
                    .Default([&](auto) { return Value(); });
  b.create<StoreOpTy>(loc, value, op.output(), indices.outputs);
}

template <typename LoadOpTy, typename StoreOpTy>
static void emitScalarImplementation(OpBuilder &b, Location loc,
                                     ArrayRef<Value> allIvs, PoolingMaxOp op) {
  emitPoolingMinMaxScalarImplementation<LoadOpTy, StoreOpTy, PoolingMaxOp>(
      b, loc, allIvs, op);
}

template <typename LoadOpTy, typename StoreOpTy>
static void emitScalarImplementation(OpBuilder &b, Location loc,
                                     ArrayRef<Value> allIvs, PoolingMinOp op) {
  emitPoolingMinMaxScalarImplementation<LoadOpTy, StoreOpTy, PoolingMinOp>(
      b, loc, allIvs, op);
}

template <typename LoadOpTy, typename StoreOpTy>
static void emitScalarImplementation(OpBuilder &b, Location loc,
                                     ArrayRef<Value> allIvs, PoolingSumOp op) {
  auto indices = getInputAndOutputIndices(b, loc, allIvs, op);
  Value inputVal =
      getPoolingInput<LoadOpTy, StoreOpTy>(b, loc, op, indices.inputs);
  Value outputVal = b.create<LoadOpTy>(loc, op.output(), indices.outputs);
  Value added = ArithBuilder(b, loc).add(outputVal, inputVal);
  b.create<StoreOpTy>(loc, added, op.output(), indices.outputs);
}
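
// Note: for an integer element type, the min/max emitters above reduce to a
// compare + select; e.g. a PoolingMaxOp over i32 values produces roughly
// (illustrative SSA names):
//
//   %c = cmpi sgt, %acc, %in : i32
//   %m = select %c, %acc, %in : i32
//
// followed by a store of %m back into the output buffer.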

/// Replace the index operations in the body of the loop nest by the matching
/// induction variables.
static void replaceIndexOpsByInductionVariables(LinalgOp linalgOp,
                                                PatternRewriter &rewriter,
                                                ArrayRef<Operation *> loopOps) {
  // Extract the induction variables of the loop nest from outer to inner.
  SmallVector<Value> allIvs;
  for (Operation *loopOp : loopOps) {
    llvm::TypeSwitch<Operation *>(loopOp)
        .Case([&](scf::ParallelOp parallelOp) {
          allIvs.append(parallelOp.getInductionVars().begin(),
                        parallelOp.getInductionVars().end());
        })
        .Case([&](scf::ForOp forOp) {
          allIvs.push_back(forOp.getInductionVar());
        })
        .Case([&](AffineForOp affineForOp) {
          allIvs.push_back(affineForOp.getInductionVar());
        })
        .Default([&](Operation *op) { assert(false && "unexpected op"); });
  }
  assert(linalgOp.getNumLoops() == allIvs.size() &&
         "expected the number of loops and induction variables to match");
  // Replace the index operations in the body of the innermost loop op.
  if (!loopOps.empty()) {
    LoopLikeOpInterface loopOp = loopOps.back();
    for (IndexOp indexOp :
         llvm::make_early_inc_range(loopOp.getLoopBody().getOps<IndexOp>()))
      rewriter.replaceOp(indexOp, allIvs[indexOp.dim()]);
  }
}

template <typename LoopTy>
static Optional<LinalgLoops> linalgOpToLoopsImpl(PatternRewriter &rewriter,
                                                 LinalgOp linalgOp) {
  using LoadOpTy =
      typename std::conditional<std::is_same<LoopTy, AffineForOp>::value,
                                AffineLoadOp, memref::LoadOp>::type;
  using StoreOpTy =
      typename std::conditional<std::is_same<LoopTy, AffineForOp>::value,
                                AffineStoreOp, memref::StoreOp>::type;

  // The flattened loopToOperandRangesMaps is expected to be an invertible
  // permutation map (which is asserted in the inverse calculation).
  assert(linalgOp.hasBufferSemantics() &&
         "expected linalg op with buffer semantics");

  auto loopRanges = linalgOp.createLoopRanges(rewriter, linalgOp.getLoc());
  auto iteratorTypes = llvm::to_vector<4>(linalgOp.iterator_types().getValue());

  SmallVector<Value> allIvs;
  GenerateLoopNest<LoopTy>::doit(
      rewriter, linalgOp.getLoc(), loopRanges, linalgOp, iteratorTypes,
      [&](OpBuilder &b, Location loc, ValueRange ivs,
          ValueRange iterArgs) -> scf::ValueVector {
        assert(iterArgs.empty() && "unexpected iterArgs");
        allIvs.append(ivs.begin(), ivs.end());
        llvm::TypeSwitch<Operation *>(linalgOp)
            .Case<ConvOp, PoolingMaxOp, PoolingMinOp, PoolingSumOp, LinalgOp>(
                [&](auto op) {
                  emitScalarImplementation<LoadOpTy, StoreOpTy>(b, loc, allIvs,
                                                                op);
                })
            .Default([&](Operation *op) { assert(false && "unexpected op"); });
        return scf::ValueVector{};
      });
  // Number of loop ops might be different from the number of ivs since some
  // loops like affine.parallel and scf.parallel have multiple ivs.
  SetVector<Operation *> loopSet;
  for (Value iv : allIvs) {
    if (!iv)
      return {};
    // The induction variable is a block argument of the entry block of the
    // loop operation.
    BlockArgument ivVal = iv.dyn_cast<BlockArgument>();
    if (!ivVal)
      return {};
    loopSet.insert(ivVal.getOwner()->getParentOp());
  }
  LinalgLoops loops(loopSet.begin(), loopSet.end());
  // Replace all index operations in the loop body.
  replaceIndexOpsByInductionVariables(linalgOp, rewriter, loops);
  return loops;
}

namespace {
template <typename LoopType>
class LinalgRewritePattern : public RewritePattern {
public:
  LinalgRewritePattern(MLIRContext *context)
      : RewritePattern(MatchAnyOpTypeTag(), /*benefit=*/1, context) {}

  LogicalResult matchAndRewrite(Operation *op,
                                PatternRewriter &rewriter) const override {
    auto linalgOp = dyn_cast<LinalgOp>(op);
    if (!linalgOp)
      return failure();
    if (!linalgOpToLoopsImpl<LoopType>(rewriter, linalgOp))
      return failure();
    rewriter.eraseOp(op);
    return success();
  }
};

/// Converts tiled_loop to SCF loop nests. All parallel dimensions are collected
/// into an scf.parallel loop and all sequential dimensions result in a nested
/// scf.for loop nest. The pattern assumes that a tiled loop with
/// iterator_types ["reduction", "parallel", "reduction"] can be reordered. It
/// is true for the tiling that is currently supported by Linalg.
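///
/// For instance (schematic, with bounds, steps, and operands elided), a
/// `linalg.tiled_loop` with iterator_types ["parallel", "reduction"] lowers to
/// roughly:
///
/// ```
///   scf.parallel (%i) = (%lb0) to (%ub0) step (%s0) {
///     scf.for %j = %lb1 to %ub1 step %s1 {
///       // ... cloned tiled_loop body ...
///     }
///   }
/// ```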
struct TiledLoopToSCFPattern : public OpRewritePattern<TiledLoopOp> {
  using OpRewritePattern<TiledLoopOp>::OpRewritePattern;

  LogicalResult matchAndRewrite(TiledLoopOp tiledLoop,
                                PatternRewriter &rewriter) const override {
    // Fail conversion if the `tiled_loop` has not been bufferized.
    if (!tiledLoop.hasBufferSemantics())
      return failure();

    // Collect loop control parameters for parallel and sequential dimensions.
    SmallVector<Value, 3> seqLBs, seqUBs, seqSteps, seqIVs;
    SmallVector<Value, 3> parLBs, parUBs, parSteps, parIVs;
    for (auto en : llvm::enumerate(
             llvm::zip(tiledLoop.lowerBound(), tiledLoop.upperBound(),
                       tiledLoop.step(), tiledLoop.getInductionVars()))) {
      Value lb, ub, step, iv;
      std::tie(lb, ub, step, iv) = en.value();
      if (tiledLoop.isParallelDimension(en.index())) {
        parLBs.push_back(lb);
        parUBs.push_back(ub);
        parSteps.push_back(step);
        parIVs.push_back(iv);
      } else {
        seqLBs.push_back(lb);
        seqUBs.push_back(ub);
        seqSteps.push_back(step);
        seqIVs.push_back(iv);
      }
    }

    Location loc = tiledLoop.getLoc();
    auto generateForLoopNestAndCloneBody = [&](OpBuilder &builder, Location loc,
                                               ValueRange ivs) {
      BlockAndValueMapping bvm;
      bvm.map(parIVs, ivs);
      bvm.map(tiledLoop.getRegionInputArgs(), tiledLoop.inputs());
      bvm.map(tiledLoop.getRegionOutputArgs(), tiledLoop.outputs());

      // If not all dimensions of the tiled loop are parallel, an scf.for loop
      // nest is generated.
      if (!seqIVs.empty()) {
        scf::LoopNest nest =
            scf::buildLoopNest(builder, loc, seqLBs, seqUBs, seqSteps,
                               [&](OpBuilder &builder, Location loc,
                                   ValueRange ivs) { bvm.map(seqIVs, ivs); });
        builder.setInsertionPointToStart(nest.loops.back().getBody());
      }
      for (auto &op : tiledLoop.getBody()->without_terminator())
        builder.clone(op, bvm);
    };

    if (parIVs.empty())
      generateForLoopNestAndCloneBody(rewriter, loc, llvm::None);
    else
      rewriter.create<scf::ParallelOp>(loc, parLBs, parUBs, parSteps,
                                       generateForLoopNestAndCloneBody);
    rewriter.eraseOp(tiledLoop);
    return success();
  }
};

/// Local folding pattern for AffineApplyOp that we can apply greedily.
/// This replaces AffineApplyOp by the proper value in cases where the
/// associated map is trivial.
/// A trivial map here is defined as a map with a single result and either:
///   1. Zero operand + returns a single AffineConstantExpr
///   2. One operand + returns a single AffineDimExpr
///   3. One operand + returns a single AffineSymbolExpr
///
/// In the first case, the AffineApplyOp is replaced by a new constant. In the
/// other cases, it is replaced by its unique operand.
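///
/// For example (illustrative IR, not from a specific test):
///   %0 = affine.apply affine_map<() -> (42)>()       becomes a constant
///                                                    42 : index
///   %1 = affine.apply affine_map<(d0) -> (d0)>(%iv)  is replaced by %iv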
struct FoldAffineOp : public RewritePattern {
  FoldAffineOp(MLIRContext *context)
      : RewritePattern(AffineApplyOp::getOperationName(), 0, context) {}

  LogicalResult matchAndRewrite(Operation *op,
                                PatternRewriter &rewriter) const override {
    AffineApplyOp affineApplyOp = cast<AffineApplyOp>(op);
    auto map = affineApplyOp.getAffineMap();
    if (map.getNumResults() != 1 || map.getNumInputs() > 1)
      return failure();

    AffineExpr expr = map.getResult(0);
    if (map.getNumInputs() == 0) {
      if (auto val = expr.dyn_cast<AffineConstantExpr>()) {
        rewriter.replaceOpWithNewOp<ConstantIndexOp>(op, val.getValue());
        return success();
      }
      return failure();
    }
    if (expr.dyn_cast<AffineDimExpr>() || expr.dyn_cast<AffineSymbolExpr>()) {
      rewriter.replaceOp(op, op->getOperand(0));
      return success();
    }
    return failure();
  }
};

template <typename LoopType>
static void lowerLinalgToLoopsImpl(FuncOp funcOp) {
  MLIRContext *context = funcOp.getContext();
  RewritePatternSet patterns(context);
  patterns.add<LinalgRewritePattern<LoopType>>(context);
  memref::DimOp::getCanonicalizationPatterns(patterns, context);
  tensor::DimOp::getCanonicalizationPatterns(patterns, context);
  AffineApplyOp::getCanonicalizationPatterns(patterns, context);
  patterns.add<FoldAffineOp>(context);
  // Just apply the patterns greedily.
  (void)applyPatternsAndFoldGreedily(funcOp, std::move(patterns));
}

struct LowerToAffineLoops
    : public LinalgLowerToAffineLoopsBase<LowerToAffineLoops> {
  void getDependentDialects(DialectRegistry &registry) const override {
    registry.insert<memref::MemRefDialect>();
  }
  void runOnFunction() override {
    lowerLinalgToLoopsImpl<AffineForOp>(getFunction());
  }
};

struct LowerToLoops : public LinalgLowerToLoopsBase<LowerToLoops> {
  void getDependentDialects(DialectRegistry &registry) const override {
    registry.insert<memref::MemRefDialect, scf::SCFDialect>();
  }
  void runOnFunction() override {
    lowerLinalgToLoopsImpl<scf::ForOp>(getFunction());
  }
};

struct LowerToParallelLoops
    : public LinalgLowerToParallelLoopsBase<LowerToParallelLoops> {
  void runOnFunction() override {
    lowerLinalgToLoopsImpl<scf::ParallelOp>(getFunction());
  }
};

struct LowerTiledLoopsToSCF
    : public LinalgLowerTiledLoopsToSCFBase<LowerTiledLoopsToSCF> {
  void runOnFunction() override {
    MLIRContext *context = &getContext();
    RewritePatternSet patterns(context);
    populateTiledLoopToSCFPattern(patterns);
    (void)applyPatternsAndFoldGreedily(getFunction(), std::move(patterns));
  }
};
} // namespace

void mlir::linalg::populateTiledLoopToSCFPattern(RewritePatternSet &patterns) {
  patterns.add<TiledLoopToSCFPattern>(patterns.getContext());
}

std::unique_ptr<OperationPass<FuncOp>>
mlir::createConvertLinalgTiledLoopsToSCFPass() {
  return std::make_unique<LowerTiledLoopsToSCF>();
}

std::unique_ptr<OperationPass<FuncOp>> mlir::createConvertLinalgToLoopsPass() {
  return std::make_unique<LowerToLoops>();
}

std::unique_ptr<OperationPass<FuncOp>>
mlir::createConvertLinalgToParallelLoopsPass() {
  return std::make_unique<LowerToParallelLoops>();
}

std::unique_ptr<OperationPass<FuncOp>>
mlir::createConvertLinalgToAffineLoopsPass() {
  return std::make_unique<LowerToAffineLoops>();
}
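
// Illustrative usage (a hypothetical pipeline sketch, not part of this file's
// API surface): the factory functions above are typically scheduled on a pass
// manager, e.g.
//
//   PassManager pm(module.getContext());
//   pm.addNestedPass<FuncOp>(createConvertLinalgToLoopsPass());
//   if (failed(pm.run(module)))
//     signalFailure();
//
// or invoked from the command line via -convert-linalg-to-loops.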

/// Emits a loop nest of `affine.for` with the proper body for `linalgOp`.
Optional<LinalgLoops>
mlir::linalg::linalgOpToAffineLoops(PatternRewriter &rewriter,
                                    LinalgOp linalgOp) {
  return linalgOpToLoopsImpl<AffineForOp>(rewriter, linalgOp);
}

/// Emits a loop nest of `scf.for` with the proper body for `linalgOp`.
Optional<LinalgLoops> mlir::linalg::linalgOpToLoops(PatternRewriter &rewriter,
                                                    LinalgOp linalgOp) {
  return linalgOpToLoopsImpl<scf::ForOp>(rewriter, linalgOp);
}

/// Emits a loop nest of `scf.parallel` with the proper body for `linalgOp`.
Optional<LinalgLoops>
mlir::linalg::linalgOpToParallelLoops(PatternRewriter &rewriter,
                                      LinalgOp linalgOp) {
  return linalgOpToLoopsImpl<scf::ParallelOp>(rewriter, linalgOp);
}