1 //===- Loops.cpp - conversion from Linalg named and generic ops to loops --===// 2 // 3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. 4 // See https://llvm.org/LICENSE.txt for license information. 5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception 6 // 7 //===----------------------------------------------------------------------===// 8 9 #include "PassDetail.h" 10 #include "mlir/Dialect/Affine/EDSC/Intrinsics.h" 11 #include "mlir/Dialect/Linalg/EDSC/FoldedIntrinsics.h" 12 #include "mlir/Dialect/Linalg/IR/LinalgOps.h" 13 #include "mlir/Dialect/Linalg/IR/LinalgTypes.h" 14 #include "mlir/Dialect/Linalg/Passes.h" 15 #include "mlir/Dialect/Linalg/Transforms/Transforms.h" 16 #include "mlir/Dialect/Linalg/Utils/Utils.h" 17 #include "mlir/Dialect/MemRef/EDSC/Intrinsics.h" 18 #include "mlir/Dialect/SCF/EDSC/Builders.h" 19 #include "mlir/Dialect/StandardOps/EDSC/Intrinsics.h" 20 #include "mlir/IR/AffineExpr.h" 21 #include "mlir/IR/AffineMap.h" 22 #include "mlir/IR/BlockAndValueMapping.h" 23 #include "mlir/Support/LLVM.h" 24 #include "mlir/Transforms/DialectConversion.h" 25 #include "mlir/Transforms/FoldUtils.h" 26 #include "mlir/Transforms/GreedyPatternRewriteDriver.h" 27 #include "llvm/ADT/TypeSwitch.h" 28 29 using namespace mlir; 30 using namespace mlir::edsc; 31 using namespace mlir::edsc::intrinsics; 32 using namespace mlir::linalg; 33 34 using edsc::op::operator+; 35 36 static SmallVector<Value, 8> makeCanonicalAffineApplies(OpBuilder &b, 37 Location loc, 38 AffineMap map, 39 ArrayRef<Value> vals) { 40 if (map.isEmpty()) 41 return {}; 42 43 assert(map.getNumInputs() == vals.size()); 44 SmallVector<Value, 8> res; 45 res.reserve(map.getNumResults()); 46 auto dims = map.getNumDims(); 47 for (auto e : map.getResults()) { 48 auto exprMap = AffineMap::get(dims, map.getNumSymbols(), e); 49 SmallVector<Value, 4> operands(vals.begin(), vals.end()); 50 canonicalizeMapAndOperands(&exprMap, &operands); 51 res.push_back(affine_apply(exprMap, operands)); 52 } 53 return res; 54 } 55 56 template <typename IndexedValueType, typename OpType> 57 static void inlineRegionAndEmitStore(OpType op, ArrayRef<Value> indexedValues, 58 ArrayRef<SmallVector<Value, 8>> indexing, 59 ArrayRef<Value> outputBuffers) { 60 assert(op->getNumRegions() == 1 && "Expected single region op"); 61 auto &b = ScopedContext::getBuilderRef(); 62 auto &block = op->getRegion(0).front(); 63 BlockAndValueMapping map; 64 map.map(block.getArguments(), indexedValues); 65 for (auto &op : block.without_terminator()) { 66 auto *newOp = b.clone(op, map); 67 map.map(op.getResults(), newOp->getResults()); 68 } 69 70 Operation &terminator = block.back(); 71 assert(isa<linalg::YieldOp>(terminator) && 72 "expected a yield op in the end of the region"); 73 for (unsigned i = 0, e = terminator.getNumOperands(); i < e; ++i) { 74 IndexedValueType O(outputBuffers[i]); 75 O(indexing[i]) = map.lookupOrDefault(terminator.getOperand(i)); 76 } 77 } 78 79 // Returns a pair that contains input indices and output indices of a 80 // SingleInputPoolingOp `op`. 81 struct InputAndOutputIndices { 82 SmallVector<Value, 8> inputs; 83 SmallVector<Value, 8> outputs; 84 }; 85 template <typename SingleInputPoolingOp> 86 static InputAndOutputIndices getInputAndOutputIndices(ArrayRef<Value> allIvs, 87 SingleInputPoolingOp op) { 88 auto &b = ScopedContext::getBuilderRef(); 89 auto loc = ScopedContext::getLocation(); 90 auto mapsRange = op.indexing_maps().template getAsRange<AffineMapAttr>(); 91 auto maps = llvm::to_vector<8>( 92 llvm::map_range(mapsRange, [](AffineMapAttr a) { return a.getValue(); })); 93 return InputAndOutputIndices{ 94 makeCanonicalAffineApplies(b, loc, maps[0], allIvs), 95 makeCanonicalAffineApplies(b, loc, maps[2], allIvs)}; 96 } 97 98 /// Emits the MLIR for the scalar part of the generic op by: 99 /// 1. Emitting load ops for each input and output view in order. This is 100 /// achieved by applying the appropriate input or output map to the 101 /// enclosing induction variables. 102 /// 2. Emitting a call to `op.fun()` that takes as arguments the scalars 103 /// from point 1. above. 104 /// 3. Emitting store ops to store the results of 2. to the output 105 /// views. 106 /// 107 /// An example output may resemble: 108 /// 109 /// ``` 110 /// scf.for %i = %c0 to %0 step %c1 { 111 /// scf.for %j = %c0 to %1 step %c1 { 112 /// scf.for %k = %c0 to %4 step %c1 { 113 /// %11 = load %arg0[%i, %j] : 114 /// memref<?x?xf32, stride_specification> 115 /// %12 = load %arg1[%i, %j, %k] : 116 /// memref<?x?x?xf32, stride_specification> 117 /// %13 = load %arg2[%i, %k, %j] : 118 /// memref<?x?x?xf32, stride_specification> 119 /// %14:2 = call @foo(%11, %12, %13) : (f32, f32, f32) -> (f32, f32) 120 /// store %14#0, %arg1[%i, %j, %k] : 121 /// memref<?x?x?Xf32, stride_specification> 122 /// store %14#1, %arg2[%i, %k, %j] : 123 /// memref<?x?x?Xf32, stride_specification> 124 /// } 125 /// } 126 /// } 127 /// ``` 128 template <typename IndexedValueType> 129 static void emitScalarImplementation(ArrayRef<Value> allIvs, 130 LinalgOp linalgOp) { 131 assert(linalgOp.hasBufferSemantics() && 132 "expected linalg op with buffer semantics"); 133 auto &b = ScopedContext::getBuilderRef(); 134 auto loc = ScopedContext::getLocation(); 135 unsigned nInputs = linalgOp.getNumInputs(); 136 unsigned nOutputs = linalgOp.getNumOutputs(); 137 SmallVector<Value, 4> indexedValues; 138 indexedValues.reserve(nInputs + nOutputs); 139 140 auto allIvsPlusDims = SmallVector<Value, 4>(allIvs.begin(), allIvs.end()); 141 142 // TODO: Avoid the loads if the corresponding argument of the 143 // region has no uses. 144 // 1.a. Emit load from input views. 145 for (unsigned i = 0; i < nInputs; ++i) { 146 auto indexing = makeCanonicalAffineApplies( 147 b, loc, linalgOp.getInputIndexingMap(i), allIvsPlusDims); 148 // Passing through IndexedValueType emits the proper load operation. 149 indexedValues.push_back(IndexedValueType(linalgOp.getInput(i))(indexing)); 150 } 151 // 1.b. Emit load from output views. 152 for (unsigned i = 0; i < nOutputs; ++i) { 153 auto indexing = makeCanonicalAffineApplies( 154 b, loc, linalgOp.getOutputIndexingMap(i), allIvsPlusDims); 155 // Passing through IndexedValueType emits the proper load operation. 156 indexedValues.push_back( 157 IndexedValueType(linalgOp.getOutputBuffer(i))(indexing)); 158 } 159 160 // TODO: When a region inliner exists, use it. 161 // 2. Inline region, currently only works for a single basic block. 162 // 3. Emit store. 163 SmallVector<SmallVector<Value, 8>, 8> indexing; 164 SmallVector<Value, 8> outputBuffers; 165 for (unsigned i = 0; i < nOutputs; ++i) { 166 indexing.push_back(makeCanonicalAffineApplies( 167 b, loc, linalgOp.getOutputIndexingMap(i), allIvsPlusDims)); 168 outputBuffers.push_back(linalgOp.getOutputBuffer(i)); 169 } 170 inlineRegionAndEmitStore<IndexedValueType>(linalgOp, indexedValues, indexing, 171 outputBuffers); 172 } 173 174 // Create a padded view into the given `input` tensor using the 'indices' 175 // to access the tensor. `skipPadding` lists the dimensions for which no padding 176 // is needed e.g. the non-spatial dimensions for convolutions. 177 template <typename IndexedValueType> 178 Value getPaddedInput(Value input, ArrayRef<Value> indices, 179 ArrayRef<int> skipPadding, Value padValue) { 180 // TODO: add a level of indirection to linalg.generic. 181 182 IndexedValueType indexedInput(input); 183 184 auto *context = ScopedContext::getContext(); 185 Value zeroIndex = std_constant_index(0); 186 SmallVector<Value, 8> conds; 187 SmallVector<Value, 8> clampedImIdx; 188 for (auto iter : llvm::enumerate(indices)) { 189 int idx = iter.index(); 190 auto dim = iter.value(); 191 if (is_contained(skipPadding, idx)) { 192 clampedImIdx.push_back(dim); 193 continue; 194 } 195 196 using edsc::op::sge; 197 using edsc::op::slt; 198 using edsc::op::operator||; 199 Value leftOutOfBound = slt(dim, zeroIndex); 200 if (conds.empty()) 201 conds.push_back(leftOutOfBound); 202 else 203 conds.push_back(conds.back() || leftOutOfBound); 204 Value rightBound = memref_dim(input, idx); 205 conds.push_back(conds.back() || (sge(dim, rightBound))); 206 207 // When padding is involved, the indices will only be shifted to negative, 208 // so having a max op is enough. 209 auto maxMap = AffineMap::get(/*dimCount=*/1, 0, 210 {getAffineDimExpr(/*position=*/0, context), 211 getAffineConstantExpr(0, context)}, 212 context); 213 clampedImIdx.push_back(affine_max(dim.getType(), maxMap, ValueRange{dim})); 214 } 215 216 Value readInput = indexedInput(clampedImIdx); 217 return conds.empty() ? readInput 218 : (Value)std_select(conds.back(), padValue, readInput); 219 } 220 221 namespace { 222 223 /// The padding value for a given Op depends on the semantics of the Op. 224 /// The identity value for ConvOp and PoolingSumOp is 0, for PoolingMaxOp is 225 /// -inf or minInt and for PoolingMinOp is inf or maxInt. 226 template <typename OpType> Attribute getPadValueAttr(Type type) { 227 llvm_unreachable("Unexpected op type for getPadValueAttr"); 228 return {}; 229 } 230 231 template <> Attribute getPadValueAttr<PoolingMaxOp>(Type type) { 232 auto &b = ScopedContext::getBuilderRef(); 233 if (auto floatType = type.dyn_cast<FloatType>()) { 234 return b.getFloatAttr( 235 floatType, 236 APFloat::getInf(floatType.getFloatSemantics(), /*Negative*/ true)); 237 } 238 if (auto intType = type.dyn_cast<IntegerType>()) { 239 unsigned width = intType.getWidth(); 240 // The select instruction used to lower the PoolingMin uses a signed 241 // comparison, use a signed constant irrespective of the signedness of the 242 // integer type. 243 return b.getIntegerAttr(intType, APInt::getSignedMinValue(width)); 244 } 245 llvm_unreachable("Unsupported data type for PoolingMaxOp"); 246 return {}; 247 } 248 249 template <> Attribute getPadValueAttr<PoolingMinOp>(Type type) { 250 auto &b = ScopedContext::getBuilderRef(); 251 if (auto floatType = type.dyn_cast<FloatType>()) { 252 return b.getFloatAttr(floatType, 253 APFloat::getInf(floatType.getFloatSemantics())); 254 } 255 if (auto intType = type.dyn_cast<IntegerType>()) { 256 unsigned width = intType.getWidth(); 257 // The select instruction used to lower the PoolingMin uses a signed 258 // comparison, use a signed constant irrespective of the signedness of the 259 // integer type. 260 return b.getIntegerAttr(intType, APInt::getSignedMaxValue(width)); 261 } 262 llvm_unreachable("Unsupported data type for PoolingMinOp"); 263 return {}; 264 } 265 266 template <> Attribute getPadValueAttr<PoolingSumOp>(Type type) { 267 auto &b = ScopedContext::getBuilderRef(); 268 return b.getZeroAttr(type); 269 } 270 271 template <> Attribute getPadValueAttr<ConvOp>(Type type) { 272 auto &b = ScopedContext::getBuilderRef(); 273 return b.getZeroAttr(type); 274 } 275 276 } // namespace 277 278 /// Returns true is `convOp` has a non-zero padding. 279 static bool hasPadding(ConvOp convOp) { 280 for (unsigned i = 0, e = convOp.getNumSpatialDimensions(); i < e; ++i) { 281 if (convOp.getLowPad(i) > 0 || convOp.getHighPad(i) > 0) 282 return true; 283 } 284 return false; 285 } 286 287 template <typename IndexedValueType> 288 static void emitScalarImplementation(ArrayRef<Value> allIvs, ConvOp convOp) { 289 assert(convOp.hasBufferSemantics() && 290 "expected linalg op with buffer semantics"); 291 auto &b = ScopedContext::getBuilderRef(); 292 auto loc = ScopedContext::getLocation(); 293 auto mapsRange = convOp.indexing_maps().getAsRange<AffineMapAttr>(); 294 auto maps = llvm::to_vector<8>( 295 llvm::map_range(mapsRange, [](AffineMapAttr a) { return a.getValue(); })); 296 SmallVector<Value, 8> fIdx( 297 makeCanonicalAffineApplies(b, loc, maps[0], allIvs)); 298 SmallVector<Value, 8> imIdx( 299 makeCanonicalAffineApplies(b, loc, maps[1], allIvs)); 300 SmallVector<Value, 8> oIdx( 301 makeCanonicalAffineApplies(b, loc, maps[2], allIvs)); 302 303 IndexedValueType F(convOp.filter()), O(convOp.output()); 304 305 // Emit scalar form. Padded conv involves an affine.max in the memory access 306 // which is not allowed by affine.load. Override to use an MemRefIndexedValue 307 // when there is non-zero padding. 308 if (hasPadding(convOp)) { 309 Type type = convOp.input().getType().cast<MemRefType>().getElementType(); 310 Value padValue = std_constant(type, getPadValueAttr<ConvOp>(type)); 311 Value paddedInput = getPaddedInput<MemRefIndexedValue>( 312 convOp.input(), imIdx, 313 /* Only need to pad the window dimensions */ 314 {0, static_cast<int>(imIdx.size()) - 1}, padValue); 315 O(oIdx) += F(fIdx) * paddedInput; 316 } else { 317 IndexedValueType I(convOp.input()); 318 O(oIdx) += F(fIdx) * I(imIdx); 319 } 320 } 321 322 template <typename PoolingOp> static bool hasPadding(PoolingOp poolingOp) { 323 for (unsigned i = 0, e = poolingOp.getNumWindowLoops(); i < e; ++i) { 324 if (poolingOp.getLowPad(i) > 0 || poolingOp.getHighPad(i) > 0) 325 return true; 326 } 327 return false; 328 } 329 330 template <typename IndexedValueType, typename PoolingOp> 331 static Value getPoolingInput(PoolingOp op, ArrayRef<Value> inputIndices) { 332 if (hasPadding(op)) { 333 Type type = 334 op.input().getType().template cast<MemRefType>().getElementType(); 335 Value padValue = std_constant(type, getPadValueAttr<PoolingOp>(type)); 336 return getPaddedInput<MemRefIndexedValue>(op.input(), inputIndices, 337 /*Pad every dimension*/ {}, 338 padValue); 339 } 340 IndexedValueType input(op.input()); 341 return input(inputIndices); 342 } 343 344 template <typename IndexedValueType, typename OpType> 345 void emitPoolingMinMaxScalarImplementation(ArrayRef<Value> allIvs, OpType op) { 346 InputAndOutputIndices indices = getInputAndOutputIndices(allIvs, op); 347 // Emit scalar form. 348 IndexedValueType output(op.output()); 349 Value lhs = output(indices.outputs); 350 Value rhs = getPoolingInput<IndexedValueType>(op, indices.inputs); 351 using edsc::op::sgt; 352 using edsc::op::slt; 353 Value value = std::is_same<OpType, PoolingMinOp>() 354 ? std_select(slt(lhs, rhs), lhs, rhs) 355 : std_select(sgt(lhs, rhs), lhs, rhs); 356 output(indices.outputs) = value; 357 } 358 359 template <typename IndexedValueType> 360 static void emitScalarImplementation(ArrayRef<Value> allIvs, PoolingMaxOp op) { 361 emitPoolingMinMaxScalarImplementation<IndexedValueType, PoolingMaxOp>(allIvs, 362 op); 363 } 364 365 template <typename IndexedValueType> 366 static void emitScalarImplementation(ArrayRef<Value> allIvs, PoolingMinOp op) { 367 emitPoolingMinMaxScalarImplementation<IndexedValueType, PoolingMinOp>(allIvs, 368 op); 369 } 370 371 template <typename IndexedValueType> 372 static void emitScalarImplementation(ArrayRef<Value> allIvs, PoolingSumOp op) { 373 auto indices = getInputAndOutputIndices(allIvs, op); 374 IndexedValueType output(op.output()); 375 376 // Emit scalar form. 377 output(indices.outputs) += 378 getPoolingInput<IndexedValueType>(op, indices.inputs); 379 } 380 381 /// Emits the MLIR for the scalar part of the indexed generic op by: 382 /// 1. Emitting load ops for each input and output view in order. This is 383 /// achieved by applying the appropriate input or output map to the 384 /// enclosing induction variables. 385 /// 2. Emitting a call to `op.fun()` that takes as arguments the induction 386 /// variables and the scalars from point 1. above. 387 /// 3. Emitting store ops to store the results of 2. to the output views. 388 /// 389 /// An example output may resemble: 390 /// 391 /// ``` 392 /// scf.for %i = %c0 to %0 step %c1 { 393 /// scf.for %j = %c0 to %1 step %c1 { 394 /// scf.for %k = %c0 to %4 step %c1 { 395 /// %11 = load %arg0[%i, %j] : 396 /// memref<?x?xf32, stride_specification> 397 /// %12 = load %arg1[%i, %j, %k] : 398 /// memref<?x?x?xf32, stride_specification> 399 /// %13 = load %arg2[%i, %k, %j] : 400 /// memref<?x?x?xf32, stride_specification> 401 /// %14:2 = call @foo(%i, %j, %k, %11, %12, %13) : 402 /// (index, index, index, f32, f32, f32) -> (f32, f32) 403 /// store %14#0, %arg1[%i, %j, %k] : 404 /// memref<?x?x?Xf32, stride_specification> 405 /// store %14#1, %arg2[%i, %k, %j] : 406 /// memref<?x?x?Xf32, stride_specification> 407 /// } 408 /// } 409 /// } 410 /// ``` 411 template <typename IndexedValueType> 412 static void emitScalarImplementation(ArrayRef<Value> allIvs, 413 IndexedGenericOp indexedGenericOp) { 414 assert(indexedGenericOp.hasBufferSemantics() && 415 "expected linalg op with buffer semantics"); 416 auto &b = ScopedContext::getBuilderRef(); 417 auto loc = ScopedContext::getLocation(); 418 unsigned nInputs = indexedGenericOp.getNumInputs(); 419 unsigned nOutputs = indexedGenericOp.getNumOutputs(); 420 unsigned nLoops = allIvs.size(); 421 SmallVector<Value, 4> indexedValues; 422 indexedValues.reserve(nLoops + nInputs + nOutputs); 423 for (unsigned i = 0; i < nLoops; ++i) 424 indexedValues.push_back(allIvs[i]); 425 426 // TODO: Avoid the loads if the corresponding argument of the 427 // region has no uses. 428 // 1.a. Emit load from input views. 429 for (unsigned i = 0; i < nInputs; ++i) { 430 auto indexing = makeCanonicalAffineApplies( 431 b, loc, indexedGenericOp.getInputIndexingMap(i), allIvs); 432 // Pass input i through IndexedValueType emits the proper load operation. 433 indexedValues.push_back( 434 IndexedValueType(indexedGenericOp.getInput(i))(indexing)); 435 } 436 // 1.b. Emit load from output views. 437 for (unsigned i = 0; i < nOutputs; ++i) { 438 auto indexing = makeCanonicalAffineApplies( 439 b, loc, indexedGenericOp.getOutputIndexingMap(i), allIvs); 440 // Pass output i through IndexedValueType emits the proper load operation. 441 indexedValues.push_back( 442 IndexedValueType(indexedGenericOp.getOutputBuffer(i))(indexing)); 443 } 444 445 // TODO: When a region inliner exists, use it. 446 // 2. Inline region, currently only works for a single basic block. 447 // 3. Emit store. 448 SmallVector<SmallVector<Value, 8>, 8> indexing; 449 SmallVector<Value, 8> outputBuffers; 450 for (unsigned i = 0; i < nOutputs; ++i) { 451 indexing.push_back(makeCanonicalAffineApplies( 452 b, loc, indexedGenericOp.getOutputIndexingMap(i), allIvs)); 453 outputBuffers.push_back(indexedGenericOp.getOutputBuffer(i)); 454 } 455 inlineRegionAndEmitStore<IndexedValueType>(indexedGenericOp, indexedValues, 456 indexing, outputBuffers); 457 } 458 459 template <typename LoopTy> 460 static Optional<LinalgLoops> linalgOpToLoopsImpl(Operation *op, 461 OpBuilder &builder) { 462 using IndexedValueTy = typename GenerateLoopNest<LoopTy>::IndexedValueTy; 463 ScopedContext scope(builder, op->getLoc()); 464 465 // The flattened loopToOperandRangesMaps is expected to be an invertible 466 // permutation map (which is asserted in the inverse calculation). 467 auto linalgOp = cast<LinalgOp>(op); 468 assert(linalgOp.hasBufferSemantics() && 469 "expected linalg op with buffer semantics"); 470 471 auto loopRanges = linalgOp.createLoopRanges(builder, op->getLoc()); 472 auto iteratorTypes = llvm::to_vector<4>(linalgOp.iterator_types().getValue()); 473 474 SmallVector<Value, 4> allIvs; 475 GenerateLoopNest<LoopTy>::doit( 476 loopRanges, linalgOp, iteratorTypes, 477 [&](ValueRange ivs, ValueRange iterArgs) -> scf::ValueVector { 478 assert(iterArgs.empty() && "unexpected iterArgs"); 479 allIvs.append(ivs.begin(), ivs.end()); 480 llvm::TypeSwitch<Operation *>(op) 481 .Case<ConvOp, PoolingMaxOp, PoolingMinOp, PoolingSumOp, 482 IndexedGenericOp, LinalgOp>([&](auto op) { 483 emitScalarImplementation<IndexedValueTy>(allIvs, op); 484 }) 485 .Default([&](Operation *op) { assert(false && "unexpected op"); }); 486 return scf::ValueVector{}; 487 }); 488 // Number of loop ops might be different from the number of ivs since some 489 // loops like affine.parallel and scf.parallel have multiple ivs. 490 SetVector<Operation *> loopSet; 491 for (Value iv : allIvs) { 492 if (!iv) 493 return {}; 494 // The induction variable is a block argument of the entry block of the 495 // loop operation. 496 BlockArgument ivVal = iv.dyn_cast<BlockArgument>(); 497 if (!ivVal) 498 return {}; 499 loopSet.insert(ivVal.getOwner()->getParentOp()); 500 } 501 LinalgLoops loops(loopSet.begin(), loopSet.end()); 502 return loops; 503 } 504 505 /// Replace the index operations in the body of the loop nest by the matching 506 /// induction variables. 507 static void replaceIndexOpsByInductionVariables(LinalgOp linalgOp, 508 PatternRewriter &rewriter, 509 ArrayRef<Operation *> loopOps) { 510 // Extract the induction variables of the loop nest from outer to inner. 511 SmallVector<Value> allIvs; 512 for (Operation *loopOp : loopOps) { 513 llvm::TypeSwitch<Operation *>(loopOp) 514 .Case([&](scf::ParallelOp parallelOp) { 515 allIvs.append(parallelOp.getInductionVars().begin(), 516 parallelOp.getInductionVars().end()); 517 }) 518 .Case([&](scf::ForOp forOp) { 519 allIvs.push_back(forOp.getInductionVar()); 520 }) 521 .Case([&](AffineForOp affineForOp) { 522 allIvs.push_back(affineForOp.getInductionVar()); 523 }) 524 .Default([&](Operation *op) { assert(false && "unexpected op"); }); 525 } 526 assert(linalgOp.getNumLoops() == allIvs.size() && 527 "expected the number of loops and induction variables to match"); 528 // Replace the index operations in the body of the innermost loop op. 529 if (!loopOps.empty()) { 530 LoopLikeOpInterface loopOp = loopOps.back(); 531 for (IndexOp indexOp : 532 llvm::make_early_inc_range(loopOp.getLoopBody().getOps<IndexOp>())) 533 rewriter.replaceOp(indexOp, allIvs[indexOp.dim()]); 534 } 535 } 536 537 namespace { 538 template <typename LoopType> 539 class LinalgRewritePattern : public RewritePattern { 540 public: 541 LinalgRewritePattern(MLIRContext *context) 542 : RewritePattern(MatchAnyOpTypeTag(), /*benefit=*/1, context) {} 543 544 LogicalResult matchAndRewrite(Operation *op, 545 PatternRewriter &rewriter) const override { 546 auto linalgOp = dyn_cast<LinalgOp>(op); 547 if (!isa<LinalgOp>(op)) 548 return failure(); 549 Optional<LinalgLoops> loopOps = linalgOpToLoopsImpl<LoopType>(op, rewriter); 550 if (!loopOps.hasValue()) 551 return failure(); 552 replaceIndexOpsByInductionVariables(linalgOp, rewriter, loopOps.getValue()); 553 rewriter.eraseOp(op); 554 return success(); 555 } 556 }; 557 558 struct FoldAffineOp; 559 } // namespace 560 561 template <typename LoopType> 562 static void lowerLinalgToLoopsImpl(FuncOp funcOp) { 563 MLIRContext *context = funcOp.getContext(); 564 RewritePatternSet patterns(context); 565 patterns.add<LinalgRewritePattern<LoopType>>(context); 566 memref::DimOp::getCanonicalizationPatterns(patterns, context); 567 AffineApplyOp::getCanonicalizationPatterns(patterns, context); 568 patterns.add<FoldAffineOp>(context); 569 // Just apply the patterns greedily. 570 (void)applyPatternsAndFoldGreedily(funcOp, std::move(patterns)); 571 } 572 573 namespace { 574 /// Local folding pattern for AffineApplyOp that we can apply greedily. 575 /// This replaces AffineApplyOp by the proper value in cases where the 576 /// associated map is trivial. 577 /// A trivial map here is defined as a map with a single result and either: 578 /// 1. Zero operand + returns a single AffineConstantExpr 579 /// 2. One operand + returns a single AffineDimExpr 580 /// 3. One operand + returns a single AffineSymbolExpr 581 // 582 /// In the first case, the AffineApplyOp is replaced by a new constant. In the 583 /// other cases, it is replaced by its unique operand. 584 struct FoldAffineOp : public RewritePattern { 585 FoldAffineOp(MLIRContext *context) 586 : RewritePattern(AffineApplyOp::getOperationName(), 0, context) {} 587 588 LogicalResult matchAndRewrite(Operation *op, 589 PatternRewriter &rewriter) const override { 590 AffineApplyOp affineApplyOp = cast<AffineApplyOp>(op); 591 auto map = affineApplyOp.getAffineMap(); 592 if (map.getNumResults() != 1 || map.getNumInputs() > 1) 593 return failure(); 594 595 AffineExpr expr = map.getResult(0); 596 if (map.getNumInputs() == 0) { 597 if (auto val = expr.dyn_cast<AffineConstantExpr>()) { 598 rewriter.replaceOpWithNewOp<ConstantIndexOp>(op, val.getValue()); 599 return success(); 600 } 601 return failure(); 602 } 603 if (expr.dyn_cast<AffineDimExpr>() || expr.dyn_cast<AffineSymbolExpr>()) { 604 rewriter.replaceOp(op, op->getOperand(0)); 605 return success(); 606 } 607 return failure(); 608 } 609 }; 610 611 struct LowerToAffineLoops 612 : public LinalgLowerToAffineLoopsBase<LowerToAffineLoops> { 613 void getDependentDialects(DialectRegistry ®istry) const override { 614 registry.insert<memref::MemRefDialect>(); 615 } 616 void runOnFunction() override { 617 lowerLinalgToLoopsImpl<AffineForOp>(getFunction()); 618 } 619 }; 620 621 struct LowerToLoops : public LinalgLowerToLoopsBase<LowerToLoops> { 622 void getDependentDialects(DialectRegistry ®istry) const override { 623 registry.insert<memref::MemRefDialect, scf::SCFDialect>(); 624 } 625 void runOnFunction() override { 626 lowerLinalgToLoopsImpl<scf::ForOp>(getFunction()); 627 } 628 }; 629 630 struct LowerToParallelLoops 631 : public LinalgLowerToParallelLoopsBase<LowerToParallelLoops> { 632 void runOnFunction() override { 633 lowerLinalgToLoopsImpl<scf::ParallelOp>(getFunction()); 634 } 635 }; 636 } // namespace 637 638 std::unique_ptr<OperationPass<FuncOp>> mlir::createConvertLinalgToLoopsPass() { 639 return std::make_unique<LowerToLoops>(); 640 } 641 642 std::unique_ptr<OperationPass<FuncOp>> 643 mlir::createConvertLinalgToParallelLoopsPass() { 644 return std::make_unique<LowerToParallelLoops>(); 645 } 646 647 std::unique_ptr<OperationPass<FuncOp>> 648 mlir::createConvertLinalgToAffineLoopsPass() { 649 return std::make_unique<LowerToAffineLoops>(); 650 } 651 652 /// Emits a loop nest with the proper body for `op`. 653 template <typename LoopTy> 654 Optional<LinalgLoops> mlir::linalg::linalgLowerOpToLoops(OpBuilder &builder, 655 Operation *op) { 656 return linalgOpToLoopsImpl<LoopTy>(op, builder); 657 } 658 659 template Optional<LinalgLoops> 660 mlir::linalg::linalgLowerOpToLoops<AffineForOp>(OpBuilder &builder, 661 Operation *op); 662 template Optional<LinalgLoops> 663 mlir::linalg::linalgLowerOpToLoops<scf::ForOp>(OpBuilder &builder, 664 Operation *op); 665 template Optional<LinalgLoops> 666 mlir::linalg::linalgLowerOpToLoops<scf::ParallelOp>(OpBuilder &builder, 667 Operation *op); 668 669 /// Emits a loop nest of `affine.for` with the proper body for `op`. 670 LogicalResult mlir::linalg::linalgOpToAffineLoops(OpBuilder &builder, 671 Operation *op) { 672 Optional<LinalgLoops> loops = linalgLowerOpToLoops<AffineForOp>(builder, op); 673 return loops ? success() : failure(); 674 } 675 676 /// Emits a loop nest of `scf.for` with the proper body for `op`. 677 LogicalResult mlir::linalg::linalgOpToLoops(OpBuilder &builder, Operation *op) { 678 Optional<LinalgLoops> loops = linalgLowerOpToLoops<scf::ForOp>(builder, op); 679 return loops ? success() : failure(); 680 } 681 682 /// Emits a loop nest of `scf.parallel` with the proper body for `op`. 683 LogicalResult mlir::linalg::linalgOpToParallelLoops(OpBuilder &builder, 684 Operation *op) { 685 Optional<LinalgLoops> loops = 686 linalgLowerOpToLoops<scf::ParallelOp>(builder, op); 687 return loops ? success() : failure(); 688 } 689