//===- Loops.cpp - conversion from Linalg named and generic ops to loops --===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//

#include "PassDetail.h"
#include "mlir/Dialect/Affine/EDSC/Intrinsics.h"
#include "mlir/Dialect/Linalg/EDSC/FoldedIntrinsics.h"
#include "mlir/Dialect/Linalg/IR/LinalgOps.h"
#include "mlir/Dialect/Linalg/IR/LinalgTypes.h"
#include "mlir/Dialect/Linalg/Passes.h"
#include "mlir/Dialect/Linalg/Transforms/Transforms.h"
#include "mlir/Dialect/Linalg/Utils/Utils.h"
#include "mlir/Dialect/MemRef/EDSC/Intrinsics.h"
#include "mlir/Dialect/SCF/EDSC/Builders.h"
#include "mlir/Dialect/StandardOps/EDSC/Intrinsics.h"
#include "mlir/IR/AffineExpr.h"
#include "mlir/IR/AffineMap.h"
#include "mlir/IR/BlockAndValueMapping.h"
#include "mlir/Support/LLVM.h"
#include "mlir/Transforms/DialectConversion.h"
#include "mlir/Transforms/FoldUtils.h"
#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
#include "llvm/ADT/TypeSwitch.h"

using namespace mlir;
using namespace mlir::edsc;
using namespace mlir::edsc::intrinsics;
using namespace mlir::linalg;

using edsc::op::operator+;

static SmallVector<Value, 8> makeCanonicalAffineApplies(OpBuilder &b,
                                                        Location loc,
                                                        AffineMap map,
                                                        ArrayRef<Value> vals) {
  if (map.isEmpty())
    return {};

  assert(map.getNumInputs() == vals.size());
  SmallVector<Value, 8> res;
  res.reserve(map.getNumResults());
  auto dims = map.getNumDims();
  for (auto e : map.getResults()) {
    auto exprMap = AffineMap::get(dims, map.getNumSymbols(), e);
    SmallVector<Value, 4> operands(vals.begin(), vals.end());
    canonicalizeMapAndOperands(&exprMap, &operands);
    res.push_back(affine_apply(exprMap, operands));
  }
  return res;
}

template <typename IndexedValueType, typename OpType>
static void inlineRegionAndEmitStore(OpType op, ArrayRef<Value> indexedValues,
                                     ArrayRef<SmallVector<Value, 8>> indexing,
                                     ArrayRef<Value> outputBuffers) {
  assert(op->getNumRegions() == 1 && "Expected single region op");
  auto &b = ScopedContext::getBuilderRef();
  auto &block = op->getRegion(0).front();
  BlockAndValueMapping map;
  map.map(block.getArguments(), indexedValues);
  for (auto &op : block.without_terminator()) {
    auto *newOp = b.clone(op, map);
    map.map(op.getResults(), newOp->getResults());
  }

  Operation &terminator = block.back();
  assert(isa<linalg::YieldOp>(terminator) &&
         "expected a yield op at the end of the region");
  for (unsigned i = 0, e = terminator.getNumOperands(); i < e; ++i) {
    IndexedValueType O(outputBuffers[i]);
    O(indexing[i]) = map.lookupOrDefault(terminator.getOperand(i));
  }
}

// Returns a struct containing the input indices and output indices of a
// SingleInputPoolingOp `op`.
struct InputAndOutputIndices {
  SmallVector<Value, 8> inputs;
  SmallVector<Value, 8> outputs;
};
template <typename SingleInputPoolingOp>
static InputAndOutputIndices getInputAndOutputIndices(ArrayRef<Value> allIvs,
                                                      SingleInputPoolingOp op) {
  auto &b = ScopedContext::getBuilderRef();
  auto loc = ScopedContext::getLocation();
  auto mapsRange = op.indexing_maps().template getAsRange<AffineMapAttr>();
  auto maps = llvm::to_vector<8>(
      llvm::map_range(mapsRange, [](AffineMapAttr a) { return a.getValue(); }));
  return InputAndOutputIndices{
      makeCanonicalAffineApplies(b, loc, maps[0], allIvs),
      makeCanonicalAffineApplies(b, loc, maps[2], allIvs)};
}

/// Emits the MLIR for the scalar part of the generic op by:
///   1. Emitting load ops for each input and output view in order. This is
///      achieved by applying the appropriate input or output map to the
///      enclosing induction variables.
///   2. Emitting a call to `op.fun()` that takes as arguments the scalars
///      from point 1. above.
///   3. Emitting store ops to store the results of 2. to the output
///      views.
///
/// An example output may resemble:
///
/// ```
///    scf.for %i = %c0 to %0 step %c1 {
///      scf.for %j = %c0 to %1 step %c1 {
///        scf.for %k = %c0 to %4 step %c1 {
///          %11 = load %arg0[%i, %j] :
///            memref<?x?xf32, stride_specification>
///          %12 = load %arg1[%i, %j, %k] :
///            memref<?x?x?xf32, stride_specification>
///          %13 = load %arg2[%i, %k, %j] :
///            memref<?x?x?xf32, stride_specification>
///          %14:2 = call @foo(%11, %12, %13) : (f32, f32, f32) -> (f32, f32)
///          store %14#0, %arg1[%i, %j, %k] :
///            memref<?x?x?xf32, stride_specification>
///          store %14#1, %arg2[%i, %k, %j] :
///            memref<?x?x?xf32, stride_specification>
///        }
///      }
///    }
/// ```
template <typename IndexedValueType>
static void emitScalarImplementation(ArrayRef<Value> allIvs,
                                     LinalgOp linalgOp) {
  assert(linalgOp.hasBufferSemantics() &&
         "expected linalg op with buffer semantics");
  auto &b = ScopedContext::getBuilderRef();
  auto loc = ScopedContext::getLocation();
  unsigned nInputs = linalgOp.getNumInputs();
  unsigned nOutputs = linalgOp.getNumOutputs();
  SmallVector<Value, 4> indexedValues;
  indexedValues.reserve(nInputs + nOutputs);

  auto allIvsPlusDims = SmallVector<Value, 4>(allIvs.begin(), allIvs.end());

  // TODO: Avoid the loads if the corresponding argument of the
  // region has no uses.
  // 1.a. Emit load from input views.
  for (unsigned i = 0; i < nInputs; ++i) {
    auto indexing = makeCanonicalAffineApplies(
        b, loc, linalgOp.getInputIndexingMap(i), allIvsPlusDims);
    // Passing through IndexedValueType emits the proper load operation.
    indexedValues.push_back(IndexedValueType(linalgOp.getInput(i))(indexing));
  }
  // 1.b. Emit load from output views.
  for (unsigned i = 0; i < nOutputs; ++i) {
    auto indexing = makeCanonicalAffineApplies(
        b, loc, linalgOp.getOutputIndexingMap(i), allIvsPlusDims);
    // Passing through IndexedValueType emits the proper load operation.
    indexedValues.push_back(
        IndexedValueType(linalgOp.getOutputBuffer(i))(indexing));
  }

  // TODO: When a region inliner exists, use it.
  // 2. Inline region, currently only works for a single basic block.
  // 3. Emit store.
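  // Illustration (assumed single output with identity indexing): the store
  // step then amounts to roughly
  //   store %yieldedValue, %outputBuffer[%i, %j] : memref<?x?xf32, ...>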
  SmallVector<SmallVector<Value, 8>, 8> indexing;
  SmallVector<Value, 8> outputBuffers;
  for (unsigned i = 0; i < nOutputs; ++i) {
    indexing.push_back(makeCanonicalAffineApplies(
        b, loc, linalgOp.getOutputIndexingMap(i), allIvsPlusDims));
    outputBuffers.push_back(linalgOp.getOutputBuffer(i));
  }
  inlineRegionAndEmitStore<IndexedValueType>(linalgOp, indexedValues, indexing,
                                             outputBuffers);
}

// Creates a padded view into the given `input` tensor using the `indices`
// to access the tensor. `skipPadding` lists the dimensions for which no
// padding is needed, e.g. the non-spatial dimensions for convolutions.
template <typename IndexedValueType>
Value getPaddedInput(Value input, ArrayRef<Value> indices,
                     ArrayRef<int> skipPadding, Value padValue) {
  // TODO: add a level of indirection to linalg.generic.

  IndexedValueType indexedInput(input);

  auto *context = ScopedContext::getContext();
  Value zeroIndex = std_constant_index(0);
  SmallVector<Value, 8> conds;
  SmallVector<Value, 8> clampedImIdx;
  for (auto iter : llvm::enumerate(indices)) {
    int idx = iter.index();
    auto dim = iter.value();
    if (is_contained(skipPadding, idx)) {
      clampedImIdx.push_back(dim);
      continue;
    }

    using edsc::op::sge;
    using edsc::op::slt;
    using edsc::op::operator||;
    Value leftOutOfBound = slt(dim, zeroIndex);
    if (conds.empty())
      conds.push_back(leftOutOfBound);
    else
      conds.push_back(conds.back() || leftOutOfBound);
    Value rightBound = memref_dim(input, idx);
    conds.push_back(conds.back() || (sge(dim, rightBound)));

    // When padding is involved, the indices can only be shifted toward
    // negative values, so clamping with a max op is enough.
    auto maxMap = AffineMap::get(/*dimCount=*/1, 0,
                                 {getAffineDimExpr(/*position=*/0, context),
                                  getAffineConstantExpr(0, context)},
                                 context);
    clampedImIdx.push_back(affine_max(dim.getType(), maxMap, ValueRange{dim}));
  }

  Value readInput = indexedInput(clampedImIdx);
  return conds.empty() ? readInput
                       : (Value)std_select(conds.back(), padValue, readInput);
}

namespace {

/// The padding value for a given Op depends on the semantics of the Op.
/// The identity value for ConvOp and PoolingSumOp is 0, for PoolingMaxOp is
/// -inf or minInt and for PoolingMinOp is inf or maxInt.
template <typename OpType> Attribute getPadValueAttr(Type type) {
  llvm_unreachable("Unexpected op type for getPadValueAttr");
  return {};
}

template <> Attribute getPadValueAttr<PoolingMaxOp>(Type type) {
  auto &b = ScopedContext::getBuilderRef();
  if (auto floatType = type.dyn_cast<FloatType>()) {
    return b.getFloatAttr(
        floatType,
        APFloat::getInf(floatType.getFloatSemantics(), /*Negative*/ true));
  }
  if (auto intType = type.dyn_cast<IntegerType>()) {
    unsigned width = intType.getWidth();
    // The select instruction used to lower PoolingMax uses a signed
    // comparison, so use a signed constant irrespective of the signedness of
    // the integer type.
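    // For example, for an i8 element type this yields -128, the signed
    // minimum value.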
    return b.getIntegerAttr(intType, APInt::getSignedMinValue(width));
  }
  llvm_unreachable("Unsupported data type for PoolingMaxOp");
  return {};
}

template <> Attribute getPadValueAttr<PoolingMinOp>(Type type) {
  auto &b = ScopedContext::getBuilderRef();
  if (auto floatType = type.dyn_cast<FloatType>()) {
    return b.getFloatAttr(floatType,
                          APFloat::getInf(floatType.getFloatSemantics()));
  }
  if (auto intType = type.dyn_cast<IntegerType>()) {
    unsigned width = intType.getWidth();
    // The select instruction used to lower PoolingMin uses a signed
    // comparison, so use a signed constant irrespective of the signedness of
    // the integer type.
    return b.getIntegerAttr(intType, APInt::getSignedMaxValue(width));
  }
  llvm_unreachable("Unsupported data type for PoolingMinOp");
  return {};
}

template <> Attribute getPadValueAttr<PoolingSumOp>(Type type) {
  auto &b = ScopedContext::getBuilderRef();
  return b.getZeroAttr(type);
}

template <> Attribute getPadValueAttr<ConvOp>(Type type) {
  auto &b = ScopedContext::getBuilderRef();
  return b.getZeroAttr(type);
}

} // namespace

/// Returns true if `convOp` has non-zero padding.
static bool hasPadding(ConvOp convOp) {
  for (unsigned i = 0, e = convOp.getNumSpatialDimensions(); i < e; ++i) {
    if (convOp.getLowPad(i) > 0 || convOp.getHighPad(i) > 0)
      return true;
  }
  return false;
}

template <typename IndexedValueType>
static void emitScalarImplementation(ArrayRef<Value> allIvs, ConvOp convOp) {
  assert(convOp.hasBufferSemantics() &&
         "expected linalg op with buffer semantics");
  auto &b = ScopedContext::getBuilderRef();
  auto loc = ScopedContext::getLocation();
  auto mapsRange = convOp.indexing_maps().getAsRange<AffineMapAttr>();
  auto maps = llvm::to_vector<8>(
      llvm::map_range(mapsRange, [](AffineMapAttr a) { return a.getValue(); }));
  SmallVector<Value, 8> fIdx(
      makeCanonicalAffineApplies(b, loc, maps[0], allIvs));
  SmallVector<Value, 8> imIdx(
      makeCanonicalAffineApplies(b, loc, maps[1], allIvs));
  SmallVector<Value, 8> oIdx(
      makeCanonicalAffineApplies(b, loc, maps[2], allIvs));

  IndexedValueType F(convOp.filter()), O(convOp.output());

  // Emit scalar form. Padded conv involves an affine.max in the memory access
  // which is not allowed by affine.load. Override to use a MemRefIndexedValue
  // when there is non-zero padding.
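  // Illustration only (names assumed, not emitted verbatim): for one padded
  // window dimension, the padded access conceptually lowers to
  //   %clamped = affine.max affine_map<(d0) -> (d0, 0)>(%idx)
  //   %loaded  = load from the input at the clamped indices
  //   %value   = select %outOfBounds, %padValue, %loaded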
  if (hasPadding(convOp)) {
    Type type = convOp.input().getType().cast<MemRefType>().getElementType();
    Value padValue = std_constant(type, getPadValueAttr<ConvOp>(type));
    Value paddedInput = getPaddedInput<MemRefIndexedValue>(
        convOp.input(), imIdx,
        /* Only need to pad the window dimensions */
        {0, static_cast<int>(imIdx.size()) - 1}, padValue);
    O(oIdx) += F(fIdx) * paddedInput;
  } else {
    IndexedValueType I(convOp.input());
    O(oIdx) += F(fIdx) * I(imIdx);
  }
}

template <typename PoolingOp> static bool hasPadding(PoolingOp poolingOp) {
  for (unsigned i = 0, e = poolingOp.getNumWindowLoops(); i < e; ++i) {
    if (poolingOp.getLowPad(i) > 0 || poolingOp.getHighPad(i) > 0)
      return true;
  }
  return false;
}

template <typename IndexedValueType, typename PoolingOp>
static Value getPoolingInput(PoolingOp op, ArrayRef<Value> inputIndices) {
  if (hasPadding(op)) {
    Type type =
        op.input().getType().template cast<MemRefType>().getElementType();
    Value padValue = std_constant(type, getPadValueAttr<PoolingOp>(type));
    return getPaddedInput<MemRefIndexedValue>(op.input(), inputIndices,
                                              /*Pad every dimension*/ {},
                                              padValue);
  }
  IndexedValueType input(op.input());
  return input(inputIndices);
}

template <typename IndexedValueType, typename OpType>
void emitPoolingMinMaxScalarImplementation(ArrayRef<Value> allIvs, OpType op) {
  InputAndOutputIndices indices = getInputAndOutputIndices(allIvs, op);
  // Emit scalar form.
  IndexedValueType output(op.output());
  Value lhs = output(indices.outputs);
  Value rhs = getPoolingInput<IndexedValueType>(op, indices.inputs);
  using edsc::op::sgt;
  using edsc::op::slt;
  Value value = std::is_same<OpType, PoolingMinOp>()
                    ? std_select(slt(lhs, rhs), lhs, rhs)
                    : std_select(sgt(lhs, rhs), lhs, rhs);
  output(indices.outputs) = value;
}

template <typename IndexedValueType>
static void emitScalarImplementation(ArrayRef<Value> allIvs, PoolingMaxOp op) {
  emitPoolingMinMaxScalarImplementation<IndexedValueType, PoolingMaxOp>(allIvs,
                                                                        op);
}

template <typename IndexedValueType>
static void emitScalarImplementation(ArrayRef<Value> allIvs, PoolingMinOp op) {
  emitPoolingMinMaxScalarImplementation<IndexedValueType, PoolingMinOp>(allIvs,
                                                                        op);
}

template <typename IndexedValueType>
static void emitScalarImplementation(ArrayRef<Value> allIvs, PoolingSumOp op) {
  auto indices = getInputAndOutputIndices(allIvs, op);
  IndexedValueType output(op.output());

  // Emit scalar form.
  output(indices.outputs) +=
      getPoolingInput<IndexedValueType>(op, indices.inputs);
}

/// Emits the MLIR for the scalar part of the indexed generic op by:
///   1. Emitting load ops for each input and output view in order. This is
///      achieved by applying the appropriate input or output map to the
///      enclosing induction variables.
///   2. Emitting a call to `op.fun()` that takes as arguments the induction
///      variables and the scalars from point 1. above.
///   3. Emitting store ops to store the results of 2. to the output views.
///
/// An example output may resemble:
///
/// ```
///    scf.for %i = %c0 to %0 step %c1 {
///      scf.for %j = %c0 to %1 step %c1 {
///        scf.for %k = %c0 to %4 step %c1 {
///          %11 = load %arg0[%i, %j] :
///            memref<?x?xf32, stride_specification>
///          %12 = load %arg1[%i, %j, %k] :
///            memref<?x?x?xf32, stride_specification>
///          %13 = load %arg2[%i, %k, %j] :
///            memref<?x?x?xf32, stride_specification>
///          %14:2 = call @foo(%i, %j, %k, %11, %12, %13) :
///            (index, index, index, f32, f32, f32) -> (f32, f32)
///          store %14#0, %arg1[%i, %j, %k] :
///            memref<?x?x?xf32, stride_specification>
///          store %14#1, %arg2[%i, %k, %j] :
///            memref<?x?x?xf32, stride_specification>
///        }
///      }
///    }
/// ```
template <typename IndexedValueType>
static void emitScalarImplementation(ArrayRef<Value> allIvs,
                                     IndexedGenericOp indexedGenericOp) {
  assert(indexedGenericOp.hasBufferSemantics() &&
         "expected linalg op with buffer semantics");
  auto &b = ScopedContext::getBuilderRef();
  auto loc = ScopedContext::getLocation();
  unsigned nInputs = indexedGenericOp.getNumInputs();
  unsigned nOutputs = indexedGenericOp.getNumOutputs();
  unsigned nLoops = allIvs.size();
  SmallVector<Value, 4> indexedValues;
  indexedValues.reserve(nLoops + nInputs + nOutputs);
  for (unsigned i = 0; i < nLoops; ++i)
    indexedValues.push_back(allIvs[i]);

  // TODO: Avoid the loads if the corresponding argument of the
  // region has no uses.
  // 1.a. Emit load from input views.
  for (unsigned i = 0; i < nInputs; ++i) {
    auto indexing = makeCanonicalAffineApplies(
        b, loc, indexedGenericOp.getInputIndexingMap(i), allIvs);
    // Passing input i through IndexedValueType emits the proper load
    // operation.
    indexedValues.push_back(
        IndexedValueType(indexedGenericOp.getInput(i))(indexing));
  }
  // 1.b. Emit load from output views.
  for (unsigned i = 0; i < nOutputs; ++i) {
    auto indexing = makeCanonicalAffineApplies(
        b, loc, indexedGenericOp.getOutputIndexingMap(i), allIvs);
    // Passing output i through IndexedValueType emits the proper load
    // operation.
    indexedValues.push_back(
        IndexedValueType(indexedGenericOp.getOutputBuffer(i))(indexing));
  }

  // TODO: When a region inliner exists, use it.
  // 2. Inline region, currently only works for a single basic block.
  // 3. Emit store.
  SmallVector<SmallVector<Value, 8>, 8> indexing;
  SmallVector<Value, 8> outputBuffers;
  for (unsigned i = 0; i < nOutputs; ++i) {
    indexing.push_back(makeCanonicalAffineApplies(
        b, loc, indexedGenericOp.getOutputIndexingMap(i), allIvs));
    outputBuffers.push_back(indexedGenericOp.getOutputBuffer(i));
  }
  inlineRegionAndEmitStore<IndexedValueType>(indexedGenericOp, indexedValues,
                                             indexing, outputBuffers);
}

template <typename LoopTy>
static Optional<LinalgLoops> linalgOpToLoopsImpl(LinalgOp linalgOp,
                                                 OpBuilder &builder) {
  using IndexedValueTy = typename GenerateLoopNest<LoopTy>::IndexedValueTy;
  ScopedContext scope(builder, linalgOp.getLoc());

  // The flattened loopToOperandRangesMaps is expected to be an invertible
  // permutation map (which is asserted in the inverse calculation).
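  // For example, affine_map<(d0, d1, d2) -> (d2, d0, d1)> is an invertible
  // permutation, whereas affine_map<(d0, d1) -> (d0)> is not.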
  assert(linalgOp.hasBufferSemantics() &&
         "expected linalg op with buffer semantics");

  auto loopRanges = linalgOp.createLoopRanges(builder, linalgOp.getLoc());
  auto iteratorTypes = llvm::to_vector<4>(linalgOp.iterator_types().getValue());

  SmallVector<Value, 4> allIvs;
  GenerateLoopNest<LoopTy>::doit(
      loopRanges, linalgOp, iteratorTypes,
      [&](ValueRange ivs, ValueRange iterArgs) -> scf::ValueVector {
        assert(iterArgs.empty() && "unexpected iterArgs");
        allIvs.append(ivs.begin(), ivs.end());
        llvm::TypeSwitch<Operation *>(linalgOp)
            .Case<ConvOp, PoolingMaxOp, PoolingMinOp, PoolingSumOp,
                  IndexedGenericOp, LinalgOp>([&](auto op) {
              emitScalarImplementation<IndexedValueTy>(allIvs, op);
            })
            .Default([&](Operation *op) { assert(false && "unexpected op"); });
        return scf::ValueVector{};
      });
  // The number of loop ops might differ from the number of ivs since some
  // loops, such as affine.parallel and scf.parallel, have multiple ivs.
  SetVector<Operation *> loopSet;
  for (Value iv : allIvs) {
    if (!iv)
      return {};
    // The induction variable is a block argument of the entry block of the
    // loop operation.
    BlockArgument ivVal = iv.dyn_cast<BlockArgument>();
    if (!ivVal)
      return {};
    loopSet.insert(ivVal.getOwner()->getParentOp());
  }
  LinalgLoops loops(loopSet.begin(), loopSet.end());
  return loops;
}

/// Replaces the index operations in the body of the loop nest by the matching
/// induction variables.
static void replaceIndexOpsByInductionVariables(LinalgOp linalgOp,
                                                PatternRewriter &rewriter,
                                                ArrayRef<Operation *> loopOps) {
  // Extract the induction variables of the loop nest from outer to inner.
  SmallVector<Value> allIvs;
  for (Operation *loopOp : loopOps) {
    llvm::TypeSwitch<Operation *>(loopOp)
        .Case([&](scf::ParallelOp parallelOp) {
          allIvs.append(parallelOp.getInductionVars().begin(),
                        parallelOp.getInductionVars().end());
        })
        .Case([&](scf::ForOp forOp) {
          allIvs.push_back(forOp.getInductionVar());
        })
        .Case([&](AffineForOp affineForOp) {
          allIvs.push_back(affineForOp.getInductionVar());
        })
        .Default([&](Operation *op) { assert(false && "unexpected op"); });
  }
  assert(linalgOp.getNumLoops() == allIvs.size() &&
         "expected the number of loops and induction variables to match");
  // Replace the index operations in the body of the innermost loop op.
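  // Illustration (assumed 2-d case): after lowering to an scf.for nest with
  // induction variables %i and %j, a `%0 = linalg.index 0 : index` in the
  // innermost body is replaced by %i and `linalg.index 1` by %j.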
  if (!loopOps.empty()) {
    LoopLikeOpInterface loopOp = loopOps.back();
    for (IndexOp indexOp :
         llvm::make_early_inc_range(loopOp.getLoopBody().getOps<IndexOp>()))
      rewriter.replaceOp(indexOp, allIvs[indexOp.dim()]);
  }
}

namespace {
template <typename LoopType>
class LinalgRewritePattern : public RewritePattern {
public:
  LinalgRewritePattern(MLIRContext *context)
      : RewritePattern(MatchAnyOpTypeTag(), /*benefit=*/1, context) {}

  LogicalResult matchAndRewrite(Operation *op,
                                PatternRewriter &rewriter) const override {
    auto linalgOp = dyn_cast<LinalgOp>(op);
    if (!linalgOp)
      return failure();
    if (!linalgLowerOpToLoops<LoopType>(rewriter, linalgOp))
      return failure();
    rewriter.eraseOp(op);
    return success();
  }
};

struct TiledLoopToSCFPattern : public OpRewritePattern<TiledLoopOp> {
  using OpRewritePattern<TiledLoopOp>::OpRewritePattern;

  LogicalResult matchAndRewrite(TiledLoopOp tiledLoop,
                                PatternRewriter &rewriter) const override {
    Location loc = tiledLoop.getLoc();

    // Fail conversion if the `tiled_loop` has not been bufferized.
    if (!llvm::all_of(tiledLoop.outputs(), [&](Value arg) {
          return arg.getType().isa<MemRefType>();
        }))
      return failure();

    // TODO: Build loop nest with `scf.for` and `scf.parallel` depending on the
    // iterator type.
    scf::buildLoopNest(rewriter, loc, tiledLoop.lowerBound(),
                       tiledLoop.upperBound(), tiledLoop.step(),
                       [&](OpBuilder &builder, Location loc, ValueRange ivs) {
                         // Move the body without its terminator.
                         SmallVector<Value, 16> newBlockArgs;
                         newBlockArgs.append(ivs.begin(), ivs.end());
                         newBlockArgs.append(tiledLoop.inputs().begin(),
                                             tiledLoop.inputs().end());
                         newBlockArgs.append(tiledLoop.outputs().begin(),
                                             tiledLoop.outputs().end());
                         Block *newBody = rewriter.getInsertionBlock();
                         rewriter.mergeBlocks(tiledLoop.getBody(), newBody,
                                              newBlockArgs);
                         rewriter.eraseOp(newBody->getTerminator());
                       });
    rewriter.eraseOp(tiledLoop);
    return success();
  }
};

/// Local folding pattern for AffineApplyOp that we can apply greedily.
/// This replaces AffineApplyOp by the proper value in cases where the
/// associated map is trivial.
/// A trivial map here is defined as a map with a single result and either:
///   1. Zero operands + returns a single AffineConstantExpr
///   2. One operand + returns a single AffineDimExpr
///   3. One operand + returns a single AffineSymbolExpr
///
/// In the first case, the AffineApplyOp is replaced by a new constant. In the
/// other cases, it is replaced by its unique operand.
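///
/// Illustrative examples (assumed, not exhaustive):
///   - `affine_map<() -> (42)>` applied to no operands folds to
///     `constant 42 : index`.
///   - `affine_map<(d0) -> (d0)>` applied to %v folds to %v.
///   - `affine_map<()[s0] -> (s0)>` applied to %s folds to %s.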
struct FoldAffineOp : public RewritePattern {
  FoldAffineOp(MLIRContext *context)
      : RewritePattern(AffineApplyOp::getOperationName(), 0, context) {}

  LogicalResult matchAndRewrite(Operation *op,
                                PatternRewriter &rewriter) const override {
    AffineApplyOp affineApplyOp = cast<AffineApplyOp>(op);
    auto map = affineApplyOp.getAffineMap();
    if (map.getNumResults() != 1 || map.getNumInputs() > 1)
      return failure();

    AffineExpr expr = map.getResult(0);
    if (map.getNumInputs() == 0) {
      if (auto val = expr.dyn_cast<AffineConstantExpr>()) {
        rewriter.replaceOpWithNewOp<ConstantIndexOp>(op, val.getValue());
        return success();
      }
      return failure();
    }
    if (expr.dyn_cast<AffineDimExpr>() || expr.dyn_cast<AffineSymbolExpr>()) {
      rewriter.replaceOp(op, op->getOperand(0));
      return success();
    }
    return failure();
  }
};

template <typename LoopType>
static void lowerLinalgToLoopsImpl(FuncOp funcOp) {
  MLIRContext *context = funcOp.getContext();
  RewritePatternSet patterns(context);
  patterns.add<LinalgRewritePattern<LoopType>>(context);
  memref::DimOp::getCanonicalizationPatterns(patterns, context);
  AffineApplyOp::getCanonicalizationPatterns(patterns, context);
  patterns.add<FoldAffineOp>(context);
  // Just apply the patterns greedily.
  (void)applyPatternsAndFoldGreedily(funcOp, std::move(patterns));
}

struct LowerToAffineLoops
    : public LinalgLowerToAffineLoopsBase<LowerToAffineLoops> {
  void getDependentDialects(DialectRegistry &registry) const override {
    registry.insert<memref::MemRefDialect>();
  }
  void runOnFunction() override {
    lowerLinalgToLoopsImpl<AffineForOp>(getFunction());
  }
};

struct LowerToLoops : public LinalgLowerToLoopsBase<LowerToLoops> {
  void getDependentDialects(DialectRegistry &registry) const override {
    registry.insert<memref::MemRefDialect, scf::SCFDialect>();
  }
  void runOnFunction() override {
    lowerLinalgToLoopsImpl<scf::ForOp>(getFunction());
  }
};

struct LowerToParallelLoops
    : public LinalgLowerToParallelLoopsBase<LowerToParallelLoops> {
  void runOnFunction() override {
    lowerLinalgToLoopsImpl<scf::ParallelOp>(getFunction());
  }
};

struct LowerTiledLoopsToSCF
    : public LinalgLowerTiledLoopsToSCFBase<LowerTiledLoopsToSCF> {
  void runOnFunction() override {
    MLIRContext *context = &getContext();
    RewritePatternSet patterns(context);
    patterns.add<TiledLoopToSCFPattern>(context);
    (void)applyPatternsAndFoldGreedily(getFunction(), std::move(patterns));
  }
};
} // namespace

std::unique_ptr<OperationPass<FuncOp>>
mlir::createConvertLinalgTiledLoopsToSCFPass() {
  return std::make_unique<LowerTiledLoopsToSCF>();
}

std::unique_ptr<OperationPass<FuncOp>> mlir::createConvertLinalgToLoopsPass() {
  return std::make_unique<LowerToLoops>();
}

std::unique_ptr<OperationPass<FuncOp>>
mlir::createConvertLinalgToParallelLoopsPass() {
  return std::make_unique<LowerToParallelLoops>();
}

std::unique_ptr<OperationPass<FuncOp>>
mlir::createConvertLinalgToAffineLoopsPass() {
  return std::make_unique<LowerToAffineLoops>();
}

/// Emits a loop nest with the proper body for `linalgOp`.
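///
/// A minimal usage sketch (hypothetical caller, mirroring the
/// LinalgRewritePattern above): inside a rewrite pattern one might write
/// ```
///   if (linalgLowerOpToLoops<scf::ForOp>(rewriter, linalgOp))
///     rewriter.eraseOp(linalgOp);
/// ```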
template <typename LoopTy>
Optional<LinalgLoops>
mlir::linalg::linalgLowerOpToLoops(PatternRewriter &rewriter,
                                   LinalgOp linalgOp) {
  Optional<LinalgLoops> loopOps =
      linalgOpToLoopsImpl<LoopTy>(linalgOp.getOperation(), rewriter);
  if (loopOps.hasValue())
    replaceIndexOpsByInductionVariables(linalgOp, rewriter, loopOps.getValue());
  return loopOps;
}

template Optional<LinalgLoops>
mlir::linalg::linalgLowerOpToLoops<AffineForOp>(PatternRewriter &rewriter,
                                                LinalgOp linalgOp);
template Optional<LinalgLoops>
mlir::linalg::linalgLowerOpToLoops<scf::ForOp>(PatternRewriter &rewriter,
                                               LinalgOp linalgOp);
template Optional<LinalgLoops>
mlir::linalg::linalgLowerOpToLoops<scf::ParallelOp>(PatternRewriter &rewriter,
                                                    LinalgOp linalgOp);

/// Emits a loop nest of `affine.for` with the proper body for `linalgOp`.
LogicalResult mlir::linalg::linalgOpToAffineLoops(PatternRewriter &rewriter,
                                                  LinalgOp linalgOp) {
  Optional<LinalgLoops> loops =
      linalgLowerOpToLoops<AffineForOp>(rewriter, linalgOp);
  return loops ? success() : failure();
}

/// Emits a loop nest of `scf.for` with the proper body for `linalgOp`.
LogicalResult mlir::linalg::linalgOpToLoops(PatternRewriter &rewriter,
                                            LinalgOp linalgOp) {
  Optional<LinalgLoops> loops =
      linalgLowerOpToLoops<scf::ForOp>(rewriter, linalgOp);
  return loops ? success() : failure();
}

/// Emits a loop nest of `scf.parallel` with the proper body for `linalgOp`.
LogicalResult mlir::linalg::linalgOpToParallelLoops(PatternRewriter &rewriter,
                                                    LinalgOp linalgOp) {
  Optional<LinalgLoops> loops =
      linalgLowerOpToLoops<scf::ParallelOp>(rewriter, linalgOp);
  return loops ? success() : failure();
}