//===- Loops.cpp - conversion from Linalg named and generic ops to loops --===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//

#include "PassDetail.h"
#include "mlir/Dialect/Affine/EDSC/Intrinsics.h"
#include "mlir/Dialect/Linalg/EDSC/FoldedIntrinsics.h"
#include "mlir/Dialect/Linalg/IR/LinalgOps.h"
#include "mlir/Dialect/Linalg/IR/LinalgTypes.h"
#include "mlir/Dialect/Linalg/Passes.h"
#include "mlir/Dialect/Linalg/Transforms/Transforms.h"
#include "mlir/Dialect/Linalg/Utils/Utils.h"
#include "mlir/Dialect/MemRef/EDSC/Intrinsics.h"
#include "mlir/Dialect/SCF/EDSC/Builders.h"
#include "mlir/Dialect/StandardOps/EDSC/Intrinsics.h"
#include "mlir/IR/AffineExpr.h"
#include "mlir/IR/AffineMap.h"
#include "mlir/IR/BlockAndValueMapping.h"
#include "mlir/Support/LLVM.h"
#include "mlir/Transforms/DialectConversion.h"
#include "mlir/Transforms/FoldUtils.h"
#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
#include "llvm/ADT/TypeSwitch.h"

using namespace mlir;
using namespace mlir::edsc;
using namespace mlir::edsc::intrinsics;
using namespace mlir::linalg;

using edsc::op::operator+;

static SmallVector<Value, 8> makeCanonicalAffineApplies(OpBuilder &b,
                                                        Location loc,
                                                        AffineMap map,
                                                        ArrayRef<Value> vals) {
  if (map.isEmpty())
    return {};

  assert(map.getNumInputs() == vals.size());
  SmallVector<Value, 8> res;
  res.reserve(map.getNumResults());
  auto dims = map.getNumDims();
  for (auto e : map.getResults()) {
    auto exprMap = AffineMap::get(dims, map.getNumSymbols(), e);
    SmallVector<Value, 4> operands(vals.begin(), vals.end());
    canonicalizeMapAndOperands(&exprMap, &operands);
    res.push_back(affine_apply(exprMap, operands));
  }
  return res;
}

template <typename IndexedValueType, typename OpType>
static void inlineRegionAndEmitStore(OpType op, ArrayRef<Value> indexedValues,
                                     ArrayRef<SmallVector<Value, 8>> indexing,
                                     ArrayRef<Value> outputBuffers) {
  assert(op->getNumRegions() == 1 && "Expected single region op");
  auto &b = ScopedContext::getBuilderRef();
  auto &block = op->getRegion(0).front();
  BlockAndValueMapping map;
  map.map(block.getArguments(), indexedValues);
  for (auto &op : block.without_terminator()) {
    auto *newOp = b.clone(op, map);
    map.map(op.getResults(), newOp->getResults());
  }

  Operation &terminator = block.back();
  assert(isa<linalg::YieldOp>(terminator) &&
         "expected a yield op at the end of the region");
  for (unsigned i = 0, e = terminator.getNumOperands(); i < e; ++i) {
    IndexedValueType O(outputBuffers[i]);
    O(indexing[i]) = map.lookupOrDefault(terminator.getOperand(i));
  }
}
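// Illustration only (hypothetical map and values, not part of the lowering):
// for `map = affine_map<(d0, d1) -> (d0 + d1, d1)>` and `vals = {%i, %j}`,
// makeCanonicalAffineApplies emits one affine.apply per map result, roughly
//
//   %0 = affine.apply affine_map<(d0, d1) -> (d0 + d1)>(%i, %j)
//   %1 = affine.apply affine_map<(d0, d1) -> (d1)>(%i, %j)
//
// and returns {%0, %1}; canonicalization may replace the second apply by %j
// directly.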
// Returns a struct that contains the input indices and output indices of a
// SingleInputPoolingOp `op`.
struct InputAndOutputIndices {
  SmallVector<Value, 8> inputs;
  SmallVector<Value, 8> outputs;
};
template <typename SingleInputPoolingOp>
static InputAndOutputIndices getInputAndOutputIndices(ArrayRef<Value> allIvs,
                                                      SingleInputPoolingOp op) {
  auto &b = ScopedContext::getBuilderRef();
  auto loc = ScopedContext::getLocation();
  auto mapsRange = op.indexing_maps().template getAsRange<AffineMapAttr>();
  auto maps = llvm::to_vector<8>(
      llvm::map_range(mapsRange, [](AffineMapAttr a) { return a.getValue(); }));
  return InputAndOutputIndices{
      makeCanonicalAffineApplies(b, loc, maps[0], allIvs),
      makeCanonicalAffineApplies(b, loc, maps[2], allIvs)};
}

/// Emits the MLIR for the scalar part of the generic op by:
///   1. Emitting load ops for each input and output view in order. This is
///      achieved by applying the appropriate input or output map to the
///      enclosing induction variables.
///   2. Emitting a call to `op.fun()` that takes as arguments the scalars
///      from point 1. above.
///   3. Emitting store ops to store the results of 2. to the output
///      views.
///
/// An example output may resemble:
///
/// ```
///    scf.for %i = %c0 to %0 step %c1 {
///      scf.for %j = %c0 to %1 step %c1 {
///        scf.for %k = %c0 to %4 step %c1 {
///          %11 = load %arg0[%i, %j] :
///            memref<?x?xf32, stride_specification>
///          %12 = load %arg1[%i, %j, %k] :
///            memref<?x?x?xf32, stride_specification>
///          %13 = load %arg2[%i, %k, %j] :
///            memref<?x?x?xf32, stride_specification>
///          %14:2 = call @foo(%11, %12, %13) : (f32, f32, f32) -> (f32, f32)
///          store %14#0, %arg1[%i, %j, %k] :
///            memref<?x?x?xf32, stride_specification>
///          store %14#1, %arg2[%i, %k, %j] :
///            memref<?x?x?xf32, stride_specification>
///        }
///      }
///    }
/// ```
template <typename IndexedValueType>
static void emitScalarImplementation(ArrayRef<Value> allIvs,
                                     LinalgOp linalgOp) {
  assert(linalgOp.hasBufferSemantics() &&
         "expected linalg op with buffer semantics");
  auto &b = ScopedContext::getBuilderRef();
  auto loc = ScopedContext::getLocation();
  unsigned nInputs = linalgOp.getNumInputs();
  unsigned nOutputs = linalgOp.getNumOutputs();
  SmallVector<Value, 4> indexedValues;
  indexedValues.reserve(nInputs + nOutputs);

  auto allIvsPlusDims = SmallVector<Value, 4>(allIvs.begin(), allIvs.end());

  // TODO: Avoid the loads if the corresponding argument of the
  // region has no uses.
  // 1.a. Emit load from input views.
  for (unsigned i = 0; i < nInputs; ++i) {
    auto indexing = makeCanonicalAffineApplies(
        b, loc, linalgOp.getInputIndexingMap(i), allIvsPlusDims);
    // Passing through IndexedValueType emits the proper load operation.
    indexedValues.push_back(IndexedValueType(linalgOp.getInput(i))(indexing));
  }
  // 1.b. Emit load from output views.
  for (unsigned i = 0; i < nOutputs; ++i) {
    auto indexing = makeCanonicalAffineApplies(
        b, loc, linalgOp.getOutputIndexingMap(i), allIvsPlusDims);
    // Passing through IndexedValueType emits the proper load operation.
    indexedValues.push_back(
        IndexedValueType(linalgOp.getOutputBuffer(i))(indexing));
  }
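  // Illustration only: for a hypothetical linalg.generic whose region is
  //   ^bb0(%a: f32, %b: f32):
  //     %c = addf %a, %b : f32
  //     linalg.yield %c : f32
  // the loads emitted in step 1 above become the block arguments %a and %b,
  // step 2 clones the `addf` on those scalars, and step 3 stores the yielded
  // value back through the output indexing map.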
  // TODO: When a region inliner exists, use it.
  // 2. Inline region, currently only works for a single basic block.
  // 3. Emit store.
  SmallVector<SmallVector<Value, 8>, 8> indexing;
  SmallVector<Value, 8> outputBuffers;
  for (unsigned i = 0; i < nOutputs; ++i) {
    indexing.push_back(makeCanonicalAffineApplies(
        b, loc, linalgOp.getOutputIndexingMap(i), allIvsPlusDims));
    outputBuffers.push_back(linalgOp.getOutputBuffer(i));
  }
  inlineRegionAndEmitStore<IndexedValueType>(linalgOp, indexedValues, indexing,
                                             outputBuffers);
}

// Creates a padded view into the given `input` tensor using the `indices`
// to access the tensor. `skipPadding` lists the dimensions for which no
// padding is needed, e.g. the non-spatial dimensions for convolutions.
template <typename IndexedValueType>
Value getPaddedInput(Value input, ArrayRef<Value> indices,
                     ArrayRef<int> skipPadding, Value padValue) {
  // TODO: add a level of indirection to linalg.generic.

  IndexedValueType indexedInput(input);

  auto *context = ScopedContext::getContext();
  Value zeroIndex = std_constant_index(0);
  SmallVector<Value, 8> conds;
  SmallVector<Value, 8> clampedImIdx;
  for (auto iter : llvm::enumerate(indices)) {
    int idx = iter.index();
    auto dim = iter.value();
    if (is_contained(skipPadding, idx)) {
      clampedImIdx.push_back(dim);
      continue;
    }

    using edsc::op::sge;
    using edsc::op::slt;
    using edsc::op::operator||;
    Value leftOutOfBound = slt(dim, zeroIndex);
    if (conds.empty())
      conds.push_back(leftOutOfBound);
    else
      conds.push_back(conds.back() || leftOutOfBound);
    Value rightBound = memref_dim(input, idx);
    conds.push_back(conds.back() || (sge(dim, rightBound)));

    // When padding is involved, the indices can only be shifted toward
    // negative values, so a max op is enough to clamp them.
    auto maxMap = AffineMap::get(/*dimCount=*/1, 0,
                                 {getAffineDimExpr(/*position=*/0, context),
                                  getAffineConstantExpr(0, context)},
                                 context);
    clampedImIdx.push_back(affine_max(dim.getType(), maxMap, ValueRange{dim}));
  }

  Value readInput = indexedInput(clampedImIdx);
  return conds.empty() ? readInput
                       : (Value)std_select(conds.back(), padValue, readInput);
}

namespace {

/// The padding value for a given Op depends on the semantics of the Op.
/// The identity value for ConvOp and PoolingSumOp is 0, for PoolingMaxOp it is
/// -inf or minInt, and for PoolingMinOp it is inf or maxInt.
template <typename OpType>
Attribute getPadValueAttr(Type type) {
  llvm_unreachable("Unexpected op type for getPadValueAttr");
  return {};
}
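// Illustration only: for a single padded dimension with index %d and input
// size %n, getPaddedInput above conceptually produces
//
//   %oob     = (%d < 0) || (%d >= %n)
//   %clamped = max(%d, 0)
//   %v       = select %oob, %padValue, input[%clamped]
//
// where `padValue` is the reduction identity chosen by the specializations
// below.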
template <>
Attribute getPadValueAttr<PoolingMaxOp>(Type type) {
  auto &b = ScopedContext::getBuilderRef();
  if (auto floatType = type.dyn_cast<FloatType>()) {
    return b.getFloatAttr(
        floatType,
        APFloat::getInf(floatType.getFloatSemantics(), /*Negative=*/true));
  }
  if (auto intType = type.dyn_cast<IntegerType>()) {
    unsigned width = intType.getWidth();
    // The select instruction used to lower the PoolingMax uses a signed
    // comparison, use a signed constant irrespective of the signedness of the
    // integer type.
    return b.getIntegerAttr(intType, APInt::getSignedMinValue(width));
  }
  llvm_unreachable("Unsupported data type for PoolingMaxOp");
  return {};
}

template <>
Attribute getPadValueAttr<PoolingMinOp>(Type type) {
  auto &b = ScopedContext::getBuilderRef();
  if (auto floatType = type.dyn_cast<FloatType>()) {
    return b.getFloatAttr(floatType,
                          APFloat::getInf(floatType.getFloatSemantics()));
  }
  if (auto intType = type.dyn_cast<IntegerType>()) {
    unsigned width = intType.getWidth();
    // The select instruction used to lower the PoolingMin uses a signed
    // comparison, use a signed constant irrespective of the signedness of the
    // integer type.
    return b.getIntegerAttr(intType, APInt::getSignedMaxValue(width));
  }
  llvm_unreachable("Unsupported data type for PoolingMinOp");
  return {};
}

template <>
Attribute getPadValueAttr<PoolingSumOp>(Type type) {
  auto &b = ScopedContext::getBuilderRef();
  return b.getZeroAttr(type);
}

template <>
Attribute getPadValueAttr<ConvOp>(Type type) {
  auto &b = ScopedContext::getBuilderRef();
  return b.getZeroAttr(type);
}

} // namespace

/// Returns true if `convOp` has non-zero padding.
static bool hasPadding(ConvOp convOp) {
  for (unsigned i = 0, e = convOp.getNumSpatialDimensions(); i < e; ++i) {
    if (convOp.getLowPad(i) > 0 || convOp.getHighPad(i) > 0)
      return true;
  }
  return false;
}
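// Illustration only: for a 1-D convolution, the scalar form emitted below is,
// conceptually,
//
//   O[%o] += F[%f] * I[%im]
//
// where the filter, input, and output indices (fIdx, imIdx, oIdx) come from
// applying the op's three indexing maps to the enclosing induction variables.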
template <typename IndexedValueType>
static void emitScalarImplementation(ArrayRef<Value> allIvs, ConvOp convOp) {
  assert(convOp.hasBufferSemantics() &&
         "expected linalg op with buffer semantics");
  auto &b = ScopedContext::getBuilderRef();
  auto loc = ScopedContext::getLocation();
  auto mapsRange = convOp.indexing_maps().getAsRange<AffineMapAttr>();
  auto maps = llvm::to_vector<8>(
      llvm::map_range(mapsRange, [](AffineMapAttr a) { return a.getValue(); }));
  SmallVector<Value, 8> fIdx(
      makeCanonicalAffineApplies(b, loc, maps[0], allIvs));
  SmallVector<Value, 8> imIdx(
      makeCanonicalAffineApplies(b, loc, maps[1], allIvs));
  SmallVector<Value, 8> oIdx(
      makeCanonicalAffineApplies(b, loc, maps[2], allIvs));

  IndexedValueType F(convOp.filter()), O(convOp.output());

  // Emit scalar form. Padded conv involves an affine.max in the memory access
  // which is not allowed by affine.load. Override to use a MemRefIndexedValue
  // when there is non-zero padding.
  if (hasPadding(convOp)) {
    Type type = convOp.input().getType().cast<MemRefType>().getElementType();
    Value padValue = std_constant(type, getPadValueAttr<ConvOp>(type));
    Value paddedInput = getPaddedInput<MemRefIndexedValue>(
        convOp.input(), imIdx,
        /* Only need to pad the window dimensions */
        {0, static_cast<int>(imIdx.size()) - 1}, padValue);
    O(oIdx) += F(fIdx) * paddedInput;
  } else {
    IndexedValueType I(convOp.input());
    O(oIdx) += F(fIdx) * I(imIdx);
  }
}

template <typename PoolingOp>
static bool hasPadding(PoolingOp poolingOp) {
  for (unsigned i = 0, e = poolingOp.getNumWindowLoops(); i < e; ++i) {
    if (poolingOp.getLowPad(i) > 0 || poolingOp.getHighPad(i) > 0)
      return true;
  }
  return false;
}

template <typename IndexedValueType, typename PoolingOp>
static Value getPoolingInput(PoolingOp op, ArrayRef<Value> inputIndices) {
  if (hasPadding(op)) {
    Type type =
        op.input().getType().template cast<MemRefType>().getElementType();
    Value padValue = std_constant(type, getPadValueAttr<PoolingOp>(type));
    return getPaddedInput<MemRefIndexedValue>(op.input(), inputIndices,
                                              /*Pad every dimension*/ {},
                                              padValue);
  }
  IndexedValueType input(op.input());
  return input(inputIndices);
}

template <typename IndexedValueType, typename OpType>
void emitPoolingMinMaxScalarImplementation(ArrayRef<Value> allIvs, OpType op) {
  InputAndOutputIndices indices = getInputAndOutputIndices(allIvs, op);
  // Emit scalar form.
  IndexedValueType output(op.output());
  Value lhs = output(indices.outputs);
  Value rhs = getPoolingInput<IndexedValueType>(op, indices.inputs);
  using edsc::op::sgt;
  using edsc::op::slt;
  Value value = std::is_same<OpType, PoolingMinOp>()
                    ? std_select(slt(lhs, rhs), lhs, rhs)
                    : std_select(sgt(lhs, rhs), lhs, rhs);
  output(indices.outputs) = value;
}

template <typename IndexedValueType>
static void emitScalarImplementation(ArrayRef<Value> allIvs, PoolingMaxOp op) {
  emitPoolingMinMaxScalarImplementation<IndexedValueType, PoolingMaxOp>(allIvs,
                                                                        op);
}

template <typename IndexedValueType>
static void emitScalarImplementation(ArrayRef<Value> allIvs, PoolingMinOp op) {
  emitPoolingMinMaxScalarImplementation<IndexedValueType, PoolingMinOp>(allIvs,
                                                                        op);
}

template <typename IndexedValueType>
static void emitScalarImplementation(ArrayRef<Value> allIvs, PoolingSumOp op) {
  auto indices = getInputAndOutputIndices(allIvs, op);
  IndexedValueType output(op.output());

  // Emit scalar form.
  output(indices.outputs) +=
      getPoolingInput<IndexedValueType>(op, indices.inputs);
}
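// Illustration only: the pooling lowerings above reduce each output element
// over the window described by the op's indexing maps. Conceptually:
//
//   PoolingSumOp: O[outs] += I[ins]
//   PoolingMaxOp: O[outs]  = select(O[outs] > I[ins], O[outs], I[ins])
//   PoolingMinOp: O[outs]  = select(O[outs] < I[ins], O[outs], I[ins])
//
// with I[ins] read through getPoolingInput, so padded accesses yield the
// identity value of the reduction.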
/// Emits the MLIR for the scalar part of the indexed generic op by:
///   1. Emitting load ops for each input and output view in order. This is
///      achieved by applying the appropriate input or output map to the
///      enclosing induction variables.
///   2. Emitting a call to `op.fun()` that takes as arguments the induction
///      variables and the scalars from point 1. above.
///   3. Emitting store ops to store the results of 2. to the output views.
///
/// An example output may resemble:
///
/// ```
///    scf.for %i = %c0 to %0 step %c1 {
///      scf.for %j = %c0 to %1 step %c1 {
///        scf.for %k = %c0 to %4 step %c1 {
///          %11 = load %arg0[%i, %j] :
///            memref<?x?xf32, stride_specification>
///          %12 = load %arg1[%i, %j, %k] :
///            memref<?x?x?xf32, stride_specification>
///          %13 = load %arg2[%i, %k, %j] :
///            memref<?x?x?xf32, stride_specification>
///          %14:2 = call @foo(%i, %j, %k, %11, %12, %13) :
///            (index, index, index, f32, f32, f32) -> (f32, f32)
///          store %14#0, %arg1[%i, %j, %k] :
///            memref<?x?x?xf32, stride_specification>
///          store %14#1, %arg2[%i, %k, %j] :
///            memref<?x?x?xf32, stride_specification>
///        }
///      }
///    }
/// ```
template <typename IndexedValueType>
static void emitScalarImplementation(ArrayRef<Value> allIvs,
                                     IndexedGenericOp indexedGenericOp) {
  assert(indexedGenericOp.hasBufferSemantics() &&
         "expected linalg op with buffer semantics");
  auto &b = ScopedContext::getBuilderRef();
  auto loc = ScopedContext::getLocation();
  unsigned nInputs = indexedGenericOp.getNumInputs();
  unsigned nOutputs = indexedGenericOp.getNumOutputs();
  unsigned nLoops = allIvs.size();
  SmallVector<Value, 4> indexedValues;
  indexedValues.reserve(nLoops + nInputs + nOutputs);
  for (unsigned i = 0; i < nLoops; ++i)
    indexedValues.push_back(allIvs[i]);

  // TODO: Avoid the loads if the corresponding argument of the
  // region has no uses.
  // 1.a. Emit load from input views.
  for (unsigned i = 0; i < nInputs; ++i) {
    auto indexing = makeCanonicalAffineApplies(
        b, loc, indexedGenericOp.getInputIndexingMap(i), allIvs);
    // Passing input i through IndexedValueType emits the proper load
    // operation.
    indexedValues.push_back(
        IndexedValueType(indexedGenericOp.getInput(i))(indexing));
  }
  // 1.b. Emit load from output views.
  for (unsigned i = 0; i < nOutputs; ++i) {
    auto indexing = makeCanonicalAffineApplies(
        b, loc, indexedGenericOp.getOutputIndexingMap(i), allIvs);
    // Passing output i through IndexedValueType emits the proper load
    // operation.
    indexedValues.push_back(
        IndexedValueType(indexedGenericOp.getOutputBuffer(i))(indexing));
  }

  // TODO: When a region inliner exists, use it.
  // 2. Inline region, currently only works for a single basic block.
  // 3. Emit store.
  SmallVector<SmallVector<Value, 8>, 8> indexing;
  SmallVector<Value, 8> outputBuffers;
  for (unsigned i = 0; i < nOutputs; ++i) {
    indexing.push_back(makeCanonicalAffineApplies(
        b, loc, indexedGenericOp.getOutputIndexingMap(i), allIvs));
    outputBuffers.push_back(indexedGenericOp.getOutputBuffer(i));
  }
  inlineRegionAndEmitStore<IndexedValueType>(indexedGenericOp, indexedValues,
                                             indexing, outputBuffers);
}
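// Illustration only: with a hypothetical `interchangeVector = {1, 0}` and two
// loops over (%i, %j), the applyPermutationToVector calls below permute the
// loop ranges and iterator types so the generated nest iterates as (%j, %i).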
template <typename LoopTy>
static Optional<LinalgLoops>
linalgOpToLoopsImpl(Operation *op, OpBuilder &builder,
                    ArrayRef<unsigned> interchangeVector) {
  using IndexedValueTy = typename GenerateLoopNest<LoopTy>::IndexedValueTy;
  ScopedContext scope(builder, op->getLoc());

  // The flattened loopToOperandRangesMaps is expected to be an invertible
  // permutation map (which is asserted in the inverse calculation).
  auto linalgOp = cast<LinalgOp>(op);
  assert(linalgOp.hasBufferSemantics() &&
         "expected linalg op with buffer semantics");

  auto loopRanges = linalgOp.createLoopRanges(builder, op->getLoc());
  auto iteratorTypes = llvm::to_vector<4>(linalgOp.iterator_types().getValue());

  if (!interchangeVector.empty()) {
    assert(interchangeVector.size() == loopRanges.size());
    assert(interchangeVector.size() == iteratorTypes.size());
    applyPermutationToVector(loopRanges, interchangeVector);
    applyPermutationToVector(iteratorTypes, interchangeVector);
  }

  SmallVector<Value, 4> allIvs;
  GenerateLoopNest<LoopTy>::doit(
      loopRanges, /*iterInitArgs=*/{}, iteratorTypes,
      [&](ValueRange ivs, ValueRange iterArgs) -> scf::ValueVector {
        assert(iterArgs.empty() && "unexpected iterArgs");
        allIvs.append(ivs.begin(), ivs.end());
        llvm::TypeSwitch<Operation *>(op)
            .Case<ConvOp, PoolingMaxOp, PoolingMinOp, PoolingSumOp,
                  IndexedGenericOp, LinalgOp>([&](auto op) {
              emitScalarImplementation<IndexedValueTy>(allIvs, op);
            })
            .Default([&](Operation *op) { assert(false && "unexpected op"); });
        return scf::ValueVector{};
      });
  // Number of loop ops might be different from the number of ivs since some
  // loops like affine.parallel and scf.parallel have multiple ivs.
  llvm::SetVector<Operation *> loopSet;
  for (Value iv : allIvs) {
    if (!iv)
      return {};
    // The induction variable is a block argument of the entry block of the
    // loop operation.
    BlockArgument ivVal = iv.dyn_cast<BlockArgument>();
    if (!ivVal)
      return {};
    loopSet.insert(ivVal.getOwner()->getParentOp());
  }
  LinalgLoops loops(loopSet.begin(), loopSet.end());
  return loops;
}

namespace {
template <typename LoopType>
class LinalgRewritePattern : public RewritePattern {
public:
  LinalgRewritePattern(MLIRContext *context,
                       ArrayRef<unsigned> interchangeVector)
      : RewritePattern(MatchAnyOpTypeTag(), /*benefit=*/1, context),
        interchangeVector(interchangeVector.begin(), interchangeVector.end()) {}

  LogicalResult matchAndRewrite(Operation *op,
                                PatternRewriter &rewriter) const override {
    if (!isa<LinalgOp>(op))
      return failure();
    if (!linalgOpToLoopsImpl<LoopType>(op, rewriter, interchangeVector))
      return failure();
    rewriter.eraseOp(op);
    return success();
  }

private:
  SmallVector<unsigned, 4> interchangeVector;
};

struct FoldAffineOp;
} // namespace

template <typename LoopType>
static void lowerLinalgToLoopsImpl(FuncOp funcOp,
                                   ArrayRef<unsigned> interchangeVector) {
  MLIRContext *context = funcOp.getContext();
  RewritePatternSet patterns(context);
  patterns.add<LinalgRewritePattern<LoopType>>(context, interchangeVector);
  memref::DimOp::getCanonicalizationPatterns(patterns, context);
  AffineApplyOp::getCanonicalizationPatterns(patterns, context);
  patterns.add<FoldAffineOp>(context);
  // Just apply the patterns greedily.
  (void)applyPatternsAndFoldGreedily(funcOp, std::move(patterns));
}
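// Illustration only: FoldAffineOp (defined below) rewrites the trivial
// affine.apply ops that the canonicalization patterns above may leave behind:
//
//   %0 = affine.apply affine_map<() -> (42)>()       // -> constant 42 : index
//   %1 = affine.apply affine_map<(d0) -> (d0)>(%i)   // -> replaced by %i
//   %2 = affine.apply affine_map<()[s0] -> (s0)>()[%s]  // -> replaced by %s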
namespace {
/// Local folding pattern for AffineApplyOp that we can apply greedily.
/// This replaces AffineApplyOp by the proper value in cases where the
/// associated map is trivial.
/// A trivial map here is defined as a map with a single result and either:
///   1. Zero operands and returns a single AffineConstantExpr.
///   2. One operand and returns a single AffineDimExpr.
///   3. One operand and returns a single AffineSymbolExpr.
///
/// In the first case, the AffineApplyOp is replaced by a new constant. In the
/// other cases, it is replaced by its unique operand.
struct FoldAffineOp : public RewritePattern {
  FoldAffineOp(MLIRContext *context)
      : RewritePattern(AffineApplyOp::getOperationName(), 0, context) {}

  LogicalResult matchAndRewrite(Operation *op,
                                PatternRewriter &rewriter) const override {
    AffineApplyOp affineApplyOp = cast<AffineApplyOp>(op);
    auto map = affineApplyOp.getAffineMap();
    if (map.getNumResults() != 1 || map.getNumInputs() > 1)
      return failure();

    AffineExpr expr = map.getResult(0);
    if (map.getNumInputs() == 0) {
      if (auto val = expr.dyn_cast<AffineConstantExpr>()) {
        rewriter.replaceOpWithNewOp<ConstantIndexOp>(op, val.getValue());
        return success();
      }
      return failure();
    }
    if (expr.dyn_cast<AffineDimExpr>() || expr.dyn_cast<AffineSymbolExpr>()) {
      rewriter.replaceOp(op, op->getOperand(0));
      return success();
    }
    return failure();
  }
};

struct LowerToAffineLoops
    : public LinalgLowerToAffineLoopsBase<LowerToAffineLoops> {
  void getDependentDialects(DialectRegistry &registry) const override {
    registry.insert<memref::MemRefDialect>();
  }
  void runOnFunction() override {
    lowerLinalgToLoopsImpl<AffineForOp>(getFunction(), interchangeVector);
  }
};

struct LowerToLoops : public LinalgLowerToLoopsBase<LowerToLoops> {
  void getDependentDialects(DialectRegistry &registry) const override {
    registry.insert<memref::MemRefDialect, scf::SCFDialect>();
  }
  void runOnFunction() override {
    lowerLinalgToLoopsImpl<scf::ForOp>(getFunction(), interchangeVector);
  }
};

struct LowerToParallelLoops
    : public LinalgLowerToParallelLoopsBase<LowerToParallelLoops> {
  void runOnFunction() override {
    lowerLinalgToLoopsImpl<scf::ParallelOp>(getFunction(), interchangeVector);
  }
};
} // namespace

std::unique_ptr<OperationPass<FuncOp>> mlir::createConvertLinalgToLoopsPass() {
  return std::make_unique<LowerToLoops>();
}

std::unique_ptr<OperationPass<FuncOp>>
mlir::createConvertLinalgToParallelLoopsPass() {
  return std::make_unique<LowerToParallelLoops>();
}

std::unique_ptr<OperationPass<FuncOp>>
mlir::createConvertLinalgToAffineLoopsPass() {
  return std::make_unique<LowerToAffineLoops>();
}

/// Emits a loop nest with the proper body for `op`.
template <typename LoopTy>
Optional<LinalgLoops>
mlir::linalg::linalgLowerOpToLoops(OpBuilder &builder, Operation *op,
                                   ArrayRef<unsigned> interchangeVector) {
  return linalgOpToLoopsImpl<LoopTy>(op, builder, interchangeVector);
}

template Optional<LinalgLoops> mlir::linalg::linalgLowerOpToLoops<AffineForOp>(
    OpBuilder &builder, Operation *op, ArrayRef<unsigned> interchangeVector);
template Optional<LinalgLoops> mlir::linalg::linalgLowerOpToLoops<scf::ForOp>(
    OpBuilder &builder, Operation *op, ArrayRef<unsigned> interchangeVector);
template Optional<LinalgLoops>
mlir::linalg::linalgLowerOpToLoops<scf::ParallelOp>(
    OpBuilder &builder, Operation *op, ArrayRef<unsigned> interchangeVector);
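// Illustration only (hypothetical usage; assumes `op` is a LinalgOp with
// buffer semantics and the insertion point of `builder` is where the loop
// nest should be emitted):
//
//   OpBuilder builder(op);
//   Optional<LinalgLoops> loops = mlir::linalg::linalgLowerOpToLoops<
//       scf::ForOp>(builder, op, /*interchangeVector=*/{});
//   if (!loops)
//     /* lowering failed; `op` is left untouched */;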
/// Emits a loop nest of `affine.for` with the proper body for `op`.
LogicalResult
mlir::linalg::linalgOpToAffineLoops(OpBuilder &builder, Operation *op,
                                    ArrayRef<unsigned> interchangeVector) {
  Optional<LinalgLoops> loops =
      linalgLowerOpToLoops<AffineForOp>(builder, op, interchangeVector);
  return loops ? success() : failure();
}

/// Emits a loop nest of `scf.for` with the proper body for `op`.
LogicalResult
mlir::linalg::linalgOpToLoops(OpBuilder &builder, Operation *op,
                              ArrayRef<unsigned> interchangeVector) {
  Optional<LinalgLoops> loops =
      linalgLowerOpToLoops<scf::ForOp>(builder, op, interchangeVector);
  return loops ? success() : failure();
}

/// Emits a loop nest of `scf.parallel` with the proper body for `op`.
LogicalResult
mlir::linalg::linalgOpToParallelLoops(OpBuilder &builder, Operation *op,
                                      ArrayRef<unsigned> interchangeVector) {
  Optional<LinalgLoops> loops =
      linalgLowerOpToLoops<scf::ParallelOp>(builder, op, interchangeVector);
  return loops ? success() : failure();
}