//===- Loops.cpp - conversion from Linalg named and generic ops to loops --===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//

#include "PassDetail.h"
#include "mlir/Dialect/Linalg/IR/LinalgOps.h"
#include "mlir/Dialect/Linalg/IR/LinalgTypes.h"
#include "mlir/Dialect/Linalg/Passes.h"
#include "mlir/Dialect/Linalg/Transforms/Transforms.h"
#include "mlir/Dialect/Linalg/Utils/Utils.h"
#include "mlir/Dialect/SCF/Transforms.h"
#include "mlir/Dialect/StandardOps/Utils/Utils.h"
#include "mlir/IR/AffineExpr.h"
#include "mlir/IR/AffineMap.h"
#include "mlir/IR/BlockAndValueMapping.h"
#include "mlir/Support/LLVM.h"
#include "mlir/Transforms/DialectConversion.h"
#include "mlir/Transforms/FoldUtils.h"
#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
#include "llvm/ADT/TypeSwitch.h"

using namespace mlir;
using namespace mlir::linalg;

/// Emits one `affine.apply` per result expression of `map`, applied to `vals`,
/// and returns the resulting index values. Each single-result map is
/// canonicalized together with its operands before the op is created.
static SmallVector<Value> makeCanonicalAffineApplies(OpBuilder &b, Location loc,
                                                     AffineMap map,
                                                     ArrayRef<Value> vals) {
  // An empty map yields no indices (e.g. a zero-dimensional access).
  if (map.isEmpty())
    return {};

  assert(map.getNumInputs() == vals.size());
  SmallVector<Value> res;
  res.reserve(map.getNumResults());
  auto dims = map.getNumDims();
  for (auto e : map.getResults()) {
    // Split into one single-result map per expression so each affine.apply
    // can be canonicalized independently of the others.
    auto exprMap = AffineMap::get(dims, map.getNumSymbols(), e);
    SmallVector<Value> operands(vals.begin(), vals.end());
    canonicalizeMapAndOperands(&exprMap, &operands);
    res.push_back(b.create<AffineApplyOp>(loc, exprMap, operands));
  }
  return res;
}

/// Clones the single-block region of `op` at the current insertion point,
/// remapping the block arguments to `indexedValues`. Then, for each operand of
/// the region terminator, emits a store of the cloned value: terminator
/// operand i is written to `outputBuffers[i]` at indices `indexing[i]`.
template <typename LoadOpTy, typename StoreOpTy, typename OpType>
static void inlineRegionAndEmitStore(OpBuilder &b, Location loc, OpType op,
                                     ArrayRef<Value> indexedValues,
                                     ArrayRef<SmallVector<Value>> indexing,
                                     ArrayRef<Value> outputBuffers) {
  auto &block = op->getRegion(0).front();
  BlockAndValueMapping map;
  map.map(block.getArguments(), indexedValues);
  for (auto &op : block.without_terminator()) {
    auto *newOp = b.clone(op, map);
    // Record the cloned results so later ops in the block see the clones.
    map.map(op.getResults(), newOp->getResults());
  }

  Operation *terminator = block.getTerminator();
  for (OpOperand &operand : terminator->getOpOperands()) {
    Value toStore = map.lookupOrDefault(operand.get());
    b.create<StoreOpTy>(loc, toStore, outputBuffers[operand.getOperandNumber()],
                        indexing[operand.getOperandNumber()]);
  }
}

// Returns a pair that contains input indices and output indices of a
// SingleInputPoolingOp `op`.
struct InputAndOutputIndices {
  SmallVector<Value> inputs;
  SmallVector<Value> outputs;
};
template <typename SingleInputPoolingOp>
static InputAndOutputIndices
getInputAndOutputIndices(OpBuilder &b, Location loc, ArrayRef<Value> allIvs,
                         SingleInputPoolingOp op) {
  auto mapsRange = op.indexing_maps().template getAsRange<AffineMapAttr>();
  auto maps = llvm::to_vector<8>(
      llvm::map_range(mapsRange, [](AffineMapAttr a) { return a.getValue(); }));
  // maps[0] indexes the input and maps[2] the output; the middle map is not
  // needed here.
  return InputAndOutputIndices{
      makeCanonicalAffineApplies(b, loc, maps[0], allIvs),
      makeCanonicalAffineApplies(b, loc, maps[2], allIvs)};
}

/// Emits the MLIR for the scalar part of the generic op by:
///   1. Emitting load ops for each input and output view in order. This is
///      achieved by applying the appropriate input or output map to the
///      enclosing induction variables.
///   2. Emitting a call to `op.fun()` that takes as arguments the scalars
///      from point 1. above.
///   3. Emitting store ops to store the results of 2. to the output
///      views.
///
/// An example output may resemble:
///
/// ```
///    scf.for %i = %c0 to %0 step %c1 {
///      scf.for %j = %c0 to %1 step %c1 {
///        scf.for %k = %c0 to %4 step %c1 {
///          %11 = load %arg0[%i, %j] :
///            memref<?x?xf32, stride_specification>
///          %12 = load %arg1[%i, %j, %k] :
///            memref<?x?x?xf32, stride_specification>
///          %13 = load %arg2[%i, %k, %j] :
///            memref<?x?x?xf32, stride_specification>
///          %14:2 = call @foo(%11, %12, %13) : (f32, f32, f32) -> (f32, f32)
///          store %14#0, %arg1[%i, %j, %k] :
///            memref<?x?x?Xf32, stride_specification>
///          store %14#1, %arg2[%i, %k, %j] :
///            memref<?x?x?Xf32, stride_specification>
///       }
///      }
///    }
/// ```
template <typename LoadOpTy, typename StoreOpTy>
static void emitScalarImplementation(OpBuilder &b, Location loc,
                                     ArrayRef<Value> allIvs,
                                     LinalgOp linalgOp) {
  assert(linalgOp.hasBufferSemantics() &&
         "expected linalg op with buffer semantics");
  SmallVector<Value> indexedValues;
  indexedValues.reserve(linalgOp.getNumInputsAndOutputs());

  auto allIvsPlusDims = SmallVector<Value>(allIvs.begin(), allIvs.end());

  // TODO: Avoid the loads if the corresponding argument of the
  // region has no uses.
  // 1.a. Emit load from input operand or for scalars access the operand itself.
  for (OpOperand *inputOperand : linalgOp.getInputOperands()) {
    if (linalgOp.isScalar(inputOperand)) {
      // Scalar operands are passed through directly; there is nothing to load.
      indexedValues.push_back(inputOperand->get());
      continue;
    }
    auto indexing = makeCanonicalAffineApplies(
        b, loc, linalgOp.getTiedIndexingMap(inputOperand), allIvsPlusDims);
    indexedValues.push_back(
        b.create<LoadOpTy>(loc, inputOperand->get(), indexing));
  }
  // 1.b. Emit load from output views.
  for (OpOperand *outputOperand : linalgOp.getOutputOperands()) {
    SmallVector<Value> indexing = makeCanonicalAffineApplies(
        b, loc, linalgOp.getTiedIndexingMap(outputOperand), allIvsPlusDims);
    indexedValues.push_back(
        b.create<LoadOpTy>(loc, outputOperand->get(), indexing));
  }

  // TODO: When a region inliner exists, use it.
  // 2. Inline region, currently only works for a single basic block.
  // 3. Emit store.
  SmallVector<SmallVector<Value>, 8> indexing;
  SmallVector<Value> outputBuffers;
  for (OpOperand *outputOperand : linalgOp.getOutputBufferOperands()) {
    indexing.push_back(makeCanonicalAffineApplies(
        b, loc, linalgOp.getTiedIndexingMap(outputOperand), allIvsPlusDims));
    outputBuffers.push_back(outputOperand->get());
  }
  inlineRegionAndEmitStore<LoadOpTy, StoreOpTy>(b, loc, linalgOp, indexedValues,
                                                indexing, outputBuffers);
}

// Create a padded view into the given `input` tensor using the 'indices'
// to access the tensor. `skipPadding` lists the dimensions for which no padding
// is needed e.g. the non-spatial dimensions for convolutions.
Value getPaddedInput(OpBuilder &b, Location loc, Value input,
                     ArrayRef<Value> indices, ArrayRef<int> skipPadding,
                     Value padValue) {
  Value zeroIndex = b.create<ConstantIndexOp>(loc, 0);
  // `conds.back()` accumulates the running "any index out of bounds" flag.
  SmallVector<Value> conds;
  // Indices clamped into the valid range, used for the actual load.
  SmallVector<Value> clampedImIdx;
  for (auto iter : llvm::enumerate(indices)) {
    int idx = iter.index();
    auto dim = iter.value();
    // Dimensions in `skipPadding` are accessed as-is, without bounds checks.
    if (is_contained(skipPadding, idx)) {
      clampedImIdx.push_back(dim);
      continue;
    }

    Value leftOutOfBound =
        b.create<CmpIOp>(loc, CmpIPredicate::slt, dim, zeroIndex);
    if (conds.empty())
      conds.push_back(leftOutOfBound);
    else
      conds.push_back(b.create<OrOp>(loc, conds.back(), leftOutOfBound));
    Value rightBound = createOrFoldDimOp(b, loc, input, idx);
    Value rightOutOfBound =
        b.create<CmpIOp>(loc, CmpIPredicate::sge, dim, rightBound);
    conds.push_back(b.create<OrOp>(loc, conds.back(), rightOutOfBound));

    // When padding is involved, the indices will only be shifted to negative,
    // so having a max op is enough.
    MLIRContext *ctx = input.getContext();
    AffineExpr m = getAffineDimExpr(/*position=*/0, ctx),
               zero = getAffineConstantExpr(0, ctx);
    AffineMap maxMap =
        AffineMap::inferFromExprList(ArrayRef<ArrayRef<AffineExpr>>{{m, zero}})
            .front();
    clampedImIdx.push_back(b.create<AffineMaxOp>(loc, maxMap, ValueRange{dim}));
  }

  // Unconditionally load at the clamped indices; the pad value is selected
  // afterwards if any original index was out of bounds.
  Value readInput = b.create<memref::LoadOp>(loc, input, clampedImIdx);
  if (conds.empty())
    return readInput;

  return b.create<SelectOp>(loc, conds.back(), padValue, readInput);
}

namespace {

/// The padding value for a given Op depends on the semantics of the Op.
/// The identity value for ConvOp and PoolingSumOp is 0, for PoolingMaxOp is
/// -inf or minInt and for PoolingMinOp is inf or maxInt.
template <typename OpType> Attribute getPadValueAttr(Type type) {
  llvm_unreachable("Unexpected op type for getPadValueAttr");
  return {};
}

template <> Attribute getPadValueAttr<PoolingMaxOp>(Type type) {
  if (auto floatType = type.dyn_cast<FloatType>()) {
    // -inf is the identity of the max reduction.
    return OpBuilder(type.getContext())
        .getFloatAttr(floatType, APFloat::getInf(floatType.getFloatSemantics(),
                                                 /*Negative=*/true));
  }
  if (auto intType = type.dyn_cast<IntegerType>()) {
    unsigned width = intType.getWidth();
    // The select instruction used to lower the PoolingMax uses a signed
    // comparison, use a signed constant irrespective of the signedness of the
    // integer type.
    return OpBuilder(type.getContext())
        .getIntegerAttr(intType, APInt::getSignedMinValue(width));
  }
  llvm_unreachable("Unsupported data type for PoolingMaxOp");
  return {};
}

template <> Attribute getPadValueAttr<PoolingMinOp>(Type type) {
  if (auto floatType = type.dyn_cast<FloatType>()) {
    // +inf is the identity of the min reduction.
    return OpBuilder(type.getContext())
        .getFloatAttr(floatType,
                      APFloat::getInf(floatType.getFloatSemantics()));
  }
  if (auto intType = type.dyn_cast<IntegerType>()) {
    unsigned width = intType.getWidth();
    // The select instruction used to lower the PoolingMin uses a signed
    // comparison, use a signed constant irrespective of the signedness of the
    // integer type.
    return OpBuilder(type.getContext())
        .getIntegerAttr(intType, APInt::getSignedMaxValue(width));
  }
  llvm_unreachable("Unsupported data type for PoolingMinOp");
  return {};
}

template <> Attribute getPadValueAttr<PoolingSumOp>(Type type) {
  // 0 is the identity of the sum reduction.
  return OpBuilder(type.getContext()).getZeroAttr(type);
}

template <> Attribute getPadValueAttr<ConvOp>(Type type) {
  // 0 is the identity of the multiply-accumulate performed by a convolution.
  return OpBuilder(type.getContext()).getZeroAttr(type);
}

} // namespace

/// Returns true if `convOp` has a non-zero padding.
static bool hasPadding(ConvOp convOp) {
  for (unsigned i = 0, e = convOp.getNumSpatialDimensions(); i < e; ++i) {
    if (convOp.getLowPad(i) > 0 || convOp.getHighPad(i) > 0)
      return true;
  }
  return false;
}

/// Emits the scalar body of a ConvOp: output[o] += filter[f] * input[im],
/// where the index vectors are obtained from the op's indexing maps applied
/// to the enclosing induction variables.
template <typename LoadOpTy, typename StoreOpTy>
static void emitScalarImplementation(OpBuilder &b, Location loc,
                                     ArrayRef<Value> allIvs, ConvOp convOp) {
  assert(convOp.hasBufferSemantics() &&
         "expected linalg op with buffer semantics");
  auto mapsRange = convOp.indexing_maps().getAsRange<AffineMapAttr>();
  auto maps = llvm::to_vector<8>(
      llvm::map_range(mapsRange, [](AffineMapAttr a) { return a.getValue(); }));
  // maps[0] indexes the filter, maps[1] the input image, maps[2] the output.
  SmallVector<Value> fIdx(makeCanonicalAffineApplies(b, loc, maps[0], allIvs));
  SmallVector<Value> imIdx(makeCanonicalAffineApplies(b, loc, maps[1], allIvs));
  SmallVector<Value> oIdx(makeCanonicalAffineApplies(b, loc, maps[2], allIvs));

  Value filter = convOp.filter(), output = convOp.output();

  // Emit scalar form. Padded conv involves an affine.max in the memory access
  // which is not allowed by affine.load. Override to use an MemRefIndexedValue
  // when there is non-zero padding.
  if (hasPadding(convOp)) {
    Type type = convOp.input().getType().cast<MemRefType>().getElementType();
    Value padValue =
        b.create<ConstantOp>(loc, type, getPadValueAttr<ConvOp>(type));
    Value paddedInput =
        getPaddedInput(b, loc, convOp.input(), imIdx,
                       /* Only need to pad the window dimensions */
                       {0, static_cast<int>(imIdx.size()) - 1}, padValue);
    Value filterVal = b.create<LoadOpTy>(loc, filter, fIdx);
    Value mulVal = ArithBuilder(b, loc).mul(filterVal, paddedInput);
    Value outputVal = b.create<LoadOpTy>(loc, output, oIdx);
    Value addVal = ArithBuilder(b, loc).add(mulVal, outputVal);
    b.create<StoreOpTy>(loc, addVal, output, oIdx);
  } else {
    // No padding: the input can be loaded directly with LoadOpTy.
    Value inputVal = b.create<LoadOpTy>(loc, convOp.input(), imIdx);
    Value filterVal = b.create<LoadOpTy>(loc, filter, fIdx);
    Value mulVal = ArithBuilder(b, loc).mul(filterVal, inputVal);
    Value outputVal = b.create<LoadOpTy>(loc, output, oIdx);
    Value addVal = ArithBuilder(b, loc).add(mulVal, outputVal);
    b.create<StoreOpTy>(loc, addVal, output, oIdx);
  }
}

/// Returns true if `poolingOp` has a non-zero padding on any window loop.
template <typename PoolingOp> static bool hasPadding(PoolingOp poolingOp) {
  for (unsigned i = 0, e = poolingOp.getNumWindowLoops(); i < e; ++i) {
    if (poolingOp.getLowPad(i) > 0 || poolingOp.getHighPad(i) > 0)
      return true;
  }
  return false;
}

/// Loads the pooling input value at `inputIndices`, going through
/// `getPaddedInput` (with the op-specific pad value) when the op has non-zero
/// padding, and a plain LoadOpTy otherwise.
template <typename LoadOpTy, typename StoreOpTy, typename PoolingOp>
static Value getPoolingInput(OpBuilder &b, Location loc, PoolingOp op,
                             ArrayRef<Value> inputIndices) {
  if (hasPadding(op)) {
    Type type =
        op.input().getType().template cast<MemRefType>().getElementType();
    Value padValue =
        b.create<ConstantOp>(loc, type, getPadValueAttr<PoolingOp>(type));
    return getPaddedInput(b, loc, op.input(), inputIndices,
                          /*Pad every dimension*/ {}, padValue);
  }
  return b.create<LoadOpTy>(loc, op.input(), inputIndices);
}

/// Shared scalar body for PoolingMinOp / PoolingMaxOp:
/// output[o] = min/max(output[o], input[i]).
template <typename LoadOpTy, typename StoreOpTy, typename OpType>
void emitPoolingMinMaxScalarImplementation(OpBuilder &b, Location loc,
                                           ArrayRef<Value> allIvs, OpType op) {
  InputAndOutputIndices indices = getInputAndOutputIndices(b, loc, allIvs, op);
  Value lhs = b.create<LoadOpTy>(loc, op.output(), indices.outputs);
  Value rhs = getPoolingInput<LoadOpTy, StoreOpTy>(b, loc, op, indices.inputs);
  // Signed compare + select implements the min/max reduction.
  Value value = llvm::TypeSwitch<Operation *, Value>(op)
                    .Case([&](PoolingMinOp poolingOp) {
                      return ArithBuilder(b, loc).select(
                          ArithBuilder(b, loc).slt(lhs, rhs), lhs, rhs);
                    })
                    .Case([&](PoolingMaxOp poolingOp) {
                      return ArithBuilder(b, loc).select(
                          ArithBuilder(b, loc).sgt(lhs, rhs), lhs, rhs);
                    })
                    .Default([&](auto) { return Value(); });
  b.create<StoreOpTy>(loc, value, op.output(), indices.outputs);
}

template <typename LoadOpTy, typename StoreOpTy>
static void emitScalarImplementation(OpBuilder &b, Location loc,
                                     ArrayRef<Value> allIvs, PoolingMaxOp op) {
  emitPoolingMinMaxScalarImplementation<LoadOpTy, StoreOpTy, PoolingMaxOp>(
      b, loc, allIvs, op);
}

template <typename LoadOpTy, typename StoreOpTy>
static void emitScalarImplementation(OpBuilder &b, Location loc,
                                     ArrayRef<Value> allIvs, PoolingMinOp op) {
  emitPoolingMinMaxScalarImplementation<LoadOpTy, StoreOpTy, PoolingMinOp>(
      b, loc, allIvs, op);
}

/// Scalar body of PoolingSumOp: output[o] += input[i].
template <typename LoadOpTy, typename StoreOpTy>
static void emitScalarImplementation(OpBuilder &b, Location loc,
                                     ArrayRef<Value> allIvs, PoolingSumOp op) {
  auto indices = getInputAndOutputIndices(b, loc, allIvs, op);
  Value inputVal =
      getPoolingInput<LoadOpTy, StoreOpTy>(b, loc, op, indices.inputs);
  Value outputVal = b.create<LoadOpTy>(loc, op.output(), indices.outputs);
  Value added = ArithBuilder(b, loc).add(outputVal, inputVal);
  b.create<StoreOpTy>(loc, added, op.output(), indices.outputs);
}

/// Replace the index operations in the body of the loop nest by the matching
381 /// induction variables. 382 static void replaceIndexOpsByInductionVariables(LinalgOp linalgOp, 383 PatternRewriter &rewriter, 384 ArrayRef<Operation *> loopOps) { 385 // Extract the induction variables of the loop nest from outer to inner. 386 SmallVector<Value> allIvs; 387 for (Operation *loopOp : loopOps) { 388 llvm::TypeSwitch<Operation *>(loopOp) 389 .Case([&](scf::ParallelOp parallelOp) { 390 allIvs.append(parallelOp.getInductionVars().begin(), 391 parallelOp.getInductionVars().end()); 392 }) 393 .Case([&](scf::ForOp forOp) { 394 allIvs.push_back(forOp.getInductionVar()); 395 }) 396 .Case([&](AffineForOp affineForOp) { 397 allIvs.push_back(affineForOp.getInductionVar()); 398 }) 399 .Default([&](Operation *op) { assert(false && "unexpected op"); }); 400 } 401 assert(linalgOp.getNumLoops() == allIvs.size() && 402 "expected the number of loops and induction variables to match"); 403 // Replace the index operations in the body of the innermost loop op. 404 if (!loopOps.empty()) { 405 LoopLikeOpInterface loopOp = loopOps.back(); 406 for (IndexOp indexOp : 407 llvm::make_early_inc_range(loopOp.getLoopBody().getOps<IndexOp>())) 408 rewriter.replaceOp(indexOp, allIvs[indexOp.dim()]); 409 } 410 } 411 412 template <typename LoopTy> 413 static Optional<LinalgLoops> linalgOpToLoopsImpl(PatternRewriter &rewriter, 414 LinalgOp linalgOp) { 415 using LoadOpTy = 416 typename std::conditional<std::is_same<LoopTy, AffineForOp>::value, 417 AffineLoadOp, memref::LoadOp>::type; 418 using StoreOpTy = 419 typename std::conditional<std::is_same<LoopTy, AffineForOp>::value, 420 AffineStoreOp, memref::StoreOp>::type; 421 422 // The flattened loopToOperandRangesMaps is expected to be an invertible 423 // permutation map (which is asserted in the inverse calculation). 
424 assert(linalgOp.hasBufferSemantics() && 425 "expected linalg op with buffer semantics"); 426 427 auto loopRanges = linalgOp.createLoopRanges(rewriter, linalgOp.getLoc()); 428 auto iteratorTypes = llvm::to_vector<4>(linalgOp.iterator_types().getValue()); 429 430 SmallVector<Value> allIvs; 431 GenerateLoopNest<LoopTy>::doit( 432 rewriter, linalgOp.getLoc(), loopRanges, linalgOp, iteratorTypes, 433 [&](OpBuilder &b, Location loc, ValueRange ivs, 434 ValueRange operandValuesToUse) -> scf::ValueVector { 435 assert(operandValuesToUse == linalgOp->getOperands() && 436 "expect operands are captured and not passed by loop argument"); 437 allIvs.append(ivs.begin(), ivs.end()); 438 llvm::TypeSwitch<Operation *>(linalgOp) 439 .Case<ConvOp, PoolingMaxOp, PoolingMinOp, PoolingSumOp, LinalgOp>( 440 [&](auto op) { 441 emitScalarImplementation<LoadOpTy, StoreOpTy>(b, loc, allIvs, 442 op); 443 }) 444 .Default([&](Operation *op) { assert(false && "unexpected op"); }); 445 return scf::ValueVector{}; 446 }); 447 // Number of loop ops might be different from the number of ivs since some 448 // loops like affine.parallel and scf.parallel have multiple ivs. 449 SetVector<Operation *> loopSet; 450 for (Value iv : allIvs) { 451 if (!iv) 452 return {}; 453 // The induction variable is a block argument of the entry block of the 454 // loop operation. 455 BlockArgument ivVal = iv.dyn_cast<BlockArgument>(); 456 if (!ivVal) 457 return {}; 458 loopSet.insert(ivVal.getOwner()->getParentOp()); 459 } 460 LinalgLoops loops(loopSet.begin(), loopSet.end()); 461 // Replace all index operations in the loop body. 
462 replaceIndexOpsByInductionVariables(linalgOp, rewriter, loops); 463 return loops; 464 } 465 466 namespace { 467 template <typename LoopType> 468 class LinalgRewritePattern : public RewritePattern { 469 public: 470 LinalgRewritePattern(MLIRContext *context) 471 : RewritePattern(MatchAnyOpTypeTag(), /*benefit=*/1, context) {} 472 473 LogicalResult matchAndRewrite(Operation *op, 474 PatternRewriter &rewriter) const override { 475 auto linalgOp = dyn_cast<LinalgOp>(op); 476 if (!isa<LinalgOp>(op)) 477 return failure(); 478 if (!linalgOpToLoopsImpl<LoopType>(rewriter, linalgOp)) 479 return failure(); 480 rewriter.eraseOp(op); 481 return success(); 482 } 483 }; 484 485 /// Converts tiled_loop to SCF loop nests. All parallel dimensions are collected 486 /// into an scf.parallel loop and all sequential dimensions will result in the 487 /// nested scf.for loop nest. The pattern assumes that a tiled loop with 488 /// iterator_types ["reduction", "parallel", "reduction"] can be reordered. It 489 /// is true for the tiling that is currently suppported by Linalg. 490 struct TiledLoopToSCFPattern : public OpRewritePattern<TiledLoopOp> { 491 using OpRewritePattern<TiledLoopOp>::OpRewritePattern; 492 493 LogicalResult matchAndRewrite(TiledLoopOp tiledLoop, 494 PatternRewriter &rewriter) const override { 495 // Fail conversion if the `tiled_loop` has not been bufferized. 496 if (!tiledLoop.hasBufferSemantics()) 497 return failure(); 498 499 // Collect loop control parameters for parallel and sequential dimensions. 
500 SmallVector<Value, 3> seqLBs, seqUBs, seqSteps, seqIVs; 501 SmallVector<Value, 3> parLBs, parUBs, parSteps, parIVs; 502 for (auto en : llvm::enumerate( 503 llvm::zip(tiledLoop.lowerBound(), tiledLoop.upperBound(), 504 tiledLoop.step(), tiledLoop.getInductionVars()))) { 505 Value lb, ub, step, iv; 506 std::tie(lb, ub, step, iv) = en.value(); 507 if (tiledLoop.isParallelDimension(en.index())) { 508 parLBs.push_back(lb); 509 parUBs.push_back(ub); 510 parSteps.push_back(step); 511 parIVs.push_back(iv); 512 } else { 513 seqLBs.push_back(lb); 514 seqUBs.push_back(ub); 515 seqSteps.push_back(step); 516 seqIVs.push_back(iv); 517 } 518 } 519 520 Location loc = tiledLoop.getLoc(); 521 auto generateForLoopNestAndCloneBody = [&](OpBuilder &builder, Location loc, 522 ValueRange ivs) { 523 BlockAndValueMapping bvm; 524 bvm.map(parIVs, ivs); 525 bvm.map(tiledLoop.getRegionInputArgs(), tiledLoop.inputs()); 526 bvm.map(tiledLoop.getRegionOutputArgs(), tiledLoop.outputs()); 527 528 // If not all dimensions of the tiled loop are parallel, an scf.for loop 529 // nest is generated. 530 if (!seqIVs.empty()) { 531 scf::LoopNest nest = 532 scf::buildLoopNest(builder, loc, seqLBs, seqUBs, seqSteps, 533 [&](OpBuilder &builder, Location loc, 534 ValueRange ivs) { bvm.map(seqIVs, ivs); }); 535 builder.setInsertionPointToStart(nest.loops.back().getBody()); 536 } 537 for (auto &op : tiledLoop.getBody()->without_terminator()) 538 builder.clone(op, bvm); 539 }; 540 541 if (parIVs.empty()) 542 generateForLoopNestAndCloneBody(rewriter, loc, llvm::None); 543 else 544 rewriter.create<scf::ParallelOp>(loc, parLBs, parUBs, parSteps, 545 generateForLoopNestAndCloneBody); 546 rewriter.eraseOp(tiledLoop); 547 return success(); 548 } 549 }; 550 551 /// Local folding pattern for AffineApplyOp that we can apply greedily. 552 /// This replaces AffineApplyOp by the proper value in cases where the 553 /// associated map is trivial. 
/// A trivial map here is defined as a map with a single result and either:
///   1. Zero operand + returns a single AffineConstantExpr
///   2. One operand + returns a single AffineDimExpr
///   3. One operand + returns a single AffineSymbolExpr
///
/// In the first case, the AffineApplyOp is replaced by a new constant. In the
/// other cases, it is replaced by its unique operand.
struct FoldAffineOp : public RewritePattern {
  FoldAffineOp(MLIRContext *context)
      : RewritePattern(AffineApplyOp::getOperationName(), 0, context) {}

  LogicalResult matchAndRewrite(Operation *op,
                                PatternRewriter &rewriter) const override {
    AffineApplyOp affineApplyOp = cast<AffineApplyOp>(op);
    auto map = affineApplyOp.getAffineMap();
    // Only single-result maps with at most one input are candidates.
    if (map.getNumResults() != 1 || map.getNumInputs() > 1)
      return failure();

    AffineExpr expr = map.getResult(0);
    if (map.getNumInputs() == 0) {
      // Case 1: constant expression, no operands -> constant index.
      if (auto val = expr.dyn_cast<AffineConstantExpr>()) {
        rewriter.replaceOpWithNewOp<ConstantIndexOp>(op, val.getValue());
        return success();
      }
      return failure();
    }
    // Cases 2 and 3: identity over the unique operand.
    if (expr.dyn_cast<AffineDimExpr>() || expr.dyn_cast<AffineSymbolExpr>()) {
      rewriter.replaceOp(op, op->getOperand(0));
      return success();
    }
    return failure();
  }
};

/// Applies the Linalg-to-loops rewrite (with `LoopType` loops) greedily to
/// `funcOp`, together with cleanup patterns for dim ops and affine.apply.
template <typename LoopType>
static void lowerLinalgToLoopsImpl(FuncOp funcOp) {
  MLIRContext *context = funcOp.getContext();
  RewritePatternSet patterns(context);
  patterns.add<LinalgRewritePattern<LoopType>>(context);
  memref::DimOp::getCanonicalizationPatterns(patterns, context);
  tensor::DimOp::getCanonicalizationPatterns(patterns, context);
  AffineApplyOp::getCanonicalizationPatterns(patterns, context);
  patterns.add<FoldAffineOp>(context);
  // Just apply the patterns greedily.
  (void)applyPatternsAndFoldGreedily(funcOp, std::move(patterns));
}

/// Pass that lowers Linalg ops to affine.for loop nests.
struct LowerToAffineLoops
    : public LinalgLowerToAffineLoopsBase<LowerToAffineLoops> {
  void getDependentDialects(DialectRegistry &registry) const override {
    registry.insert<memref::MemRefDialect>();
  }
  void runOnFunction() override {
    lowerLinalgToLoopsImpl<AffineForOp>(getFunction());
  }
};

/// Pass that lowers Linalg ops to scf.for loop nests.
struct LowerToLoops : public LinalgLowerToLoopsBase<LowerToLoops> {
  void getDependentDialects(DialectRegistry &registry) const override {
    registry.insert<memref::MemRefDialect, scf::SCFDialect>();
  }
  void runOnFunction() override {
    lowerLinalgToLoopsImpl<scf::ForOp>(getFunction());
  }
};

/// Pass that lowers Linalg ops to scf.parallel loop nests.
struct LowerToParallelLoops
    : public LinalgLowerToParallelLoopsBase<LowerToParallelLoops> {
  void runOnFunction() override {
    lowerLinalgToLoopsImpl<scf::ParallelOp>(getFunction());
  }
};

/// Pass that lowers linalg.tiled_loop ops to SCF loop nests.
struct LowerTiledLoopsToSCF
    : public LinalgLowerTiledLoopsToSCFBase<LowerTiledLoopsToSCF> {
  void runOnFunction() override {
    MLIRContext *context = &getContext();
    RewritePatternSet patterns(context);
    populateTiledLoopToSCFPattern(patterns);
    (void)applyPatternsAndFoldGreedily(getFunction(), std::move(patterns));
  }
};
} // namespace

/// Rewrite a TiledLoopOp with bounds/step that potentially do not divide evenly
/// into two TiledLoopOps: One where the step divides the iteration space
/// evenly, followed another one for the last (partial) iteration (if any). This
/// function only rewrites the `idx`-th loop of the loop nest represented by
/// the TiledLoopOp. To peel the entire loop nest, this function must be called
/// multiple times.
///
/// This function rewrites the given TiledLoopOp in-place and creates a new
/// TiledLoopOp for the last iteration. It replaces all uses of the original
/// TiledLoopOp with the results of the newly generated one.
///
/// The newly generated TiledLoopOp is returned via `result`. The boundary
/// at which the loop is split (new upper bound) is returned via `splitBound`.
/// The return value indicates whether the TiledLoopOp was rewritten or not.
static LogicalResult peelTiledLoop(RewriterBase &b, TiledLoopOp loopOp,
                                   int64_t idx, TiledLoopOp &result,
                                   Value &splitBound) {
  Value lb = loopOp.lowerBound()[idx], ub = loopOp.upperBound()[idx],
        step = loopOp.step()[idx];
  auto ubInt = getConstantIntValue(ub);

  auto loc = loopOp.getLoc();
  AffineExpr exprLb, exprUb, exprStep;
  bindSymbols(b.getContext(), exprLb, exprUb, exprStep);
  // New upper bound: %ub - (%ub - %lb) mod %step
  auto modMap = AffineMap::get(0, 3, {exprUb - ((exprUb - exprLb) % exprStep)});
  SmallVector<Value> operands{lb, ub, step};
  mlir::canonicalizeMapAndOperands(&modMap, &operands);
  modMap = mlir::simplifyAffineMap(modMap);
  RewriterBase::InsertionGuard guard(b);
  b.setInsertionPoint(loopOp);
  splitBound = b.createOrFold<AffineApplyOp>(loc, modMap, operands);
  // No specialization necessary if step already divides upper bound evenly.
  // Compare both the SSA value and (when available) the constant value.
  if (splitBound == ub || (ubInt && ubInt == getConstantIntValue(splitBound)))
    return failure();

  // Create remainder loop.
  b.setInsertionPointAfter(loopOp);
  auto remainderLoop = cast<TiledLoopOp>(b.clone(*loopOp.getOperation()));
  loopOp.replaceAllUsesWith(remainderLoop->getResults());
  // Outputs: Take tensors from main loop's results. Take memrefs from main
  // loop's outputs.
  SmallVector<Value> remainderOutputs;
  for (unsigned o = 0, t = 0; o < loopOp.getNumOutputs(); ++o) {
    // `t` only advances for tensor outputs, which map 1:1 to loop results.
    remainderOutputs.push_back(loopOp.outputs()[o].getType().isa<MemRefType>()
                                   ? loopOp.outputs()[o]
                                   : loopOp->getResult(t++));
  }
  remainderLoop.outputsMutable().assign(remainderOutputs);

  // Set new loop bounds.
  b.updateRootInPlace(loopOp, [&]() {
    SmallVector<Value> ubs = loopOp.upperBound();
    ubs[idx] = splitBound;
    loopOp.upperBoundMutable().assign(ubs);
  });
  SmallVector<Value> lbs = remainderLoop.lowerBound();
  lbs[idx] = splitBound;
  remainderLoop.lowerBoundMutable().assign(lbs);

  result = remainderLoop;
  return success();
}

/// Rewrites the affine.min (IsMin=true) or affine.max ops of type `OpTy` in
/// both the main and the remainder loop produced by peeling, using the known
/// relationship between iv, ub and step inside/outside the peeled range.
template <typename OpTy, bool IsMin>
static void
rewriteAffineOpAfterPeeling(RewriterBase &rewriter, TiledLoopOp mainLoop,
                            TiledLoopOp remainderLoop, Value mainIv,
                            Value remainderIv, Value ub, Value step) {
  mainLoop.walk([&](OpTy affineOp) {
    AffineMap map = affineOp.getAffineMap();
    (void)scf::rewritePeeledMinMaxOp(rewriter, affineOp, map,
                                     affineOp.operands(), IsMin, mainIv, ub,
                                     step, /*insideLoop=*/true);
  });
  remainderLoop.walk([&](OpTy affineOp) {
    AffineMap map = affineOp.getAffineMap();
    (void)scf::rewritePeeledMinMaxOp(rewriter, affineOp, map,
                                     affineOp.operands(), IsMin, remainderIv,
                                     ub, step, /*insideLoop=*/false);
  });
}

LogicalResult mlir::linalg::peelAndCanonicalizeTiledLoop(RewriterBase &rewriter,
                                                         TiledLoopOp loopOp,
                                                         int64_t idx,
                                                         TiledLoopOp &result) {
  int64_t numLoops = loopOp.iterator_types().size();
  if (idx < 0 || numLoops <= idx)
    return failure();
  // Only parallel iterator supported.
  if (!isParallelIterator(loopOp.iterator_types()[idx]))
    return failure();

  Value ub = loopOp.upperBound()[idx];
  TiledLoopOp remainderLoop;
  Value splitBound;
  if (failed(peelTiledLoop(rewriter, loopOp, idx, remainderLoop, splitBound)))
    return failure();

  // Rewrite affine.min and affine.max ops.
  Value mainIv = loopOp.getInductionVars()[idx], step = loopOp.step()[idx],
        remainderIv = remainderLoop.getInductionVars()[idx];

  rewriteAffineOpAfterPeeling<AffineMinOp, /*IsMin=*/true>(
      rewriter, loopOp, remainderLoop, mainIv, remainderIv, ub, step);
  rewriteAffineOpAfterPeeling<AffineMaxOp, /*IsMin=*/false>(
      rewriter, loopOp, remainderLoop, mainIv, remainderIv, ub, step);

  result = remainderLoop;
  return success();
}

void mlir::linalg::populateTiledLoopToSCFPattern(RewritePatternSet &patterns) {
  patterns.add<TiledLoopToSCFPattern>(patterns.getContext());
}

std::unique_ptr<OperationPass<FuncOp>>
mlir::createConvertLinalgTiledLoopsToSCFPass() {
  return std::make_unique<LowerTiledLoopsToSCF>();
}

std::unique_ptr<OperationPass<FuncOp>> mlir::createConvertLinalgToLoopsPass() {
  return std::make_unique<LowerToLoops>();
}

std::unique_ptr<OperationPass<FuncOp>>
mlir::createConvertLinalgToParallelLoopsPass() {
  return std::make_unique<LowerToParallelLoops>();
}

std::unique_ptr<OperationPass<FuncOp>>
mlir::createConvertLinalgToAffineLoopsPass() {
  return std::make_unique<LowerToAffineLoops>();
}

/// Emits a loop nest of `affine.for` with the proper body for `linalgOp`.
Optional<LinalgLoops>
mlir::linalg::linalgOpToAffineLoops(PatternRewriter &rewriter,
                                    LinalgOp linalgOp) {
  return linalgOpToLoopsImpl<AffineForOp>(rewriter, linalgOp);
}

/// Emits a loop nest of `scf.for` with the proper body for `linalgOp`.
Optional<LinalgLoops> mlir::linalg::linalgOpToLoops(PatternRewriter &rewriter,
                                                    LinalgOp linalgOp) {
  return linalgOpToLoopsImpl<scf::ForOp>(rewriter, linalgOp);
}

/// Emits a loop nest of `scf.parallel` with the proper body for `linalgOp`.
Optional<LinalgLoops>
mlir::linalg::linalgOpToParallelLoops(PatternRewriter &rewriter,
                                      LinalgOp linalgOp) {
  return linalgOpToLoopsImpl<scf::ParallelOp>(rewriter, linalgOp);
}