//===- Loops.cpp - conversion from Linalg named and generic ops to loops --===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//

#include "PassDetail.h"
#include "mlir/Dialect/Affine/EDSC/Intrinsics.h"
#include "mlir/Dialect/Linalg/EDSC/FoldedIntrinsics.h"
#include "mlir/Dialect/Linalg/IR/LinalgOps.h"
#include "mlir/Dialect/Linalg/IR/LinalgTypes.h"
#include "mlir/Dialect/Linalg/Passes.h"
#include "mlir/Dialect/Linalg/Transforms/Transforms.h"
#include "mlir/Dialect/Linalg/Utils/Utils.h"
#include "mlir/Dialect/SCF/EDSC/Builders.h"
#include "mlir/Dialect/StandardOps/EDSC/Intrinsics.h"
#include "mlir/IR/AffineExpr.h"
#include "mlir/IR/AffineMap.h"
#include "mlir/IR/BlockAndValueMapping.h"
#include "mlir/Support/LLVM.h"
#include "mlir/Transforms/DialectConversion.h"
#include "mlir/Transforms/FoldUtils.h"
#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
#include "llvm/ADT/TypeSwitch.h"

using namespace mlir;
using namespace mlir::edsc;
using namespace mlir::edsc::intrinsics;
using namespace mlir::linalg;

using edsc::op::operator+;

static SmallVector<Value, 8> makeCanonicalAffineApplies(OpBuilder &b,
                                                        Location loc,
                                                        AffineMap map,
                                                        ArrayRef<Value> vals) {
  if (map.isEmpty())
    return {};

  assert(map.getNumInputs() == vals.size());
  SmallVector<Value, 8> res;
  res.reserve(map.getNumResults());
  auto dims = map.getNumDims();
  for (auto e : map.getResults()) {
    auto exprMap = AffineMap::get(dims, map.getNumSymbols(), e);
    SmallVector<Value, 4> operands(vals.begin(), vals.end());
    canonicalizeMapAndOperands(&exprMap, &operands);
    res.push_back(affine_apply(exprMap, operands));
  }
  return res;
}

template <typename IndexedValueType, typename OpType>
static void inlineRegionAndEmitStore(OpType op, ArrayRef<Value> indexedValues,
                                     ArrayRef<SmallVector<Value, 8>> indexing,
                                     ArrayRef<Value> outputBuffers) {
  assert(op->getNumRegions() == 1 && "Expected single region op");
  auto &b = ScopedContext::getBuilderRef();
  auto &block = op->getRegion(0).front();
  BlockAndValueMapping map;
  map.map(block.getArguments(), indexedValues);
  for (auto &op : block.without_terminator()) {
    auto *newOp = b.clone(op, map);
    map.map(op.getResults(), newOp->getResults());
  }

  Operation &terminator = block.back();
  assert(isa<linalg::YieldOp>(terminator) &&
         "expected a yield op at the end of the region");
  for (unsigned i = 0, e = terminator.getNumOperands(); i < e; ++i) {
    IndexedValueType O(outputBuffers[i]);
    O(indexing[i]) = map.lookupOrDefault(terminator.getOperand(i));
  }
}

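// Illustrative example (not part of the lowering itself, SSA names are
// hypothetical): for a linalg.generic whose region is
//   ^bb0(%a: f32, %b: f32, %c: f32):
//     %0 = addf %a, %b : f32
//     linalg.yield %0 : f32
// inlineRegionAndEmitStore clones the `addf` with %a/%b/%c mapped to the
// scalars in `indexedValues`, then stores the yielded value into
// `outputBuffers[0]` at the position given by `indexing[0]`.
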
// Returns a struct containing the input indices and output indices of a
// SingleInputPoolingOp `op`.
struct InputAndOutputIndices {
  SmallVector<Value, 8> inputs;
  SmallVector<Value, 8> outputs;
};
template <typename SingleInputPoolingOp>
static InputAndOutputIndices getInputAndOutputIndices(ArrayRef<Value> allIvs,
                                                      SingleInputPoolingOp op) {
  auto &b = ScopedContext::getBuilderRef();
  auto loc = ScopedContext::getLocation();
  auto mapsRange = op.indexing_maps().template getAsRange<AffineMapAttr>();
  auto maps = llvm::to_vector<8>(
      llvm::map_range(mapsRange, [](AffineMapAttr a) { return a.getValue(); }));
  return InputAndOutputIndices{
      makeCanonicalAffineApplies(b, loc, maps[0], allIvs),
      makeCanonicalAffineApplies(b, loc, maps[2], allIvs)};
}

/// Emits the MLIR for the scalar part of the generic op by:
///   1. Emitting load ops for each input and output view in order. This is
///      achieved by applying the appropriate input or output map to the
///      enclosing induction variables.
///   2. Emitting a call to `op.fun()` that takes as arguments the scalars
///      from point 1. above.
///   3. Emitting store ops to store the results of 2. to the output
///      views.
///
/// An example output may resemble:
///
/// ```
///    scf.for %i = %c0 to %0 step %c1 {
///      scf.for %j = %c0 to %1 step %c1 {
///        scf.for %k = %c0 to %4 step %c1 {
///          %11 = load %arg0[%i, %j] :
///            memref<?x?xf32, stride_specification>
///          %12 = load %arg1[%i, %j, %k] :
///            memref<?x?x?xf32, stride_specification>
///          %13 = load %arg2[%i, %k, %j] :
///            memref<?x?x?xf32, stride_specification>
///          %14:2 = call @foo(%11, %12, %13) : (f32, f32, f32) -> (f32, f32)
///          store %14#0, %arg1[%i, %j, %k] :
///            memref<?x?x?xf32, stride_specification>
///          store %14#1, %arg2[%i, %k, %j] :
///            memref<?x?x?xf32, stride_specification>
///        }
///      }
///    }
/// ```
template <typename IndexedValueType>
static void emitScalarImplementation(ArrayRef<Value> allIvs,
                                     LinalgOp linalgOp) {
  assert(linalgOp.hasBufferSemantics() &&
         "expected linalg op with buffer semantics");
  auto &b = ScopedContext::getBuilderRef();
  auto loc = ScopedContext::getLocation();
  unsigned nInputs = linalgOp.getNumInputs();
  unsigned nOutputs = linalgOp.getNumOutputs();
  SmallVector<Value, 4> indexedValues;
  indexedValues.reserve(nInputs + nOutputs);

  auto allIvsPlusDims = SmallVector<Value, 4>(allIvs.begin(), allIvs.end());

  // TODO: Avoid the loads if the corresponding argument of the
  // region has no uses.
  // 1.a. Emit load from input views.
  for (unsigned i = 0; i < nInputs; ++i) {
    auto indexing = makeCanonicalAffineApplies(
        b, loc, linalgOp.getInputIndexingMap(i), allIvsPlusDims);
    // Passing through IndexedValueType emits the proper load operation.
    indexedValues.push_back(IndexedValueType(linalgOp.getInput(i))(indexing));
  }
  // 1.b. Emit load from output views.
  for (unsigned i = 0; i < nOutputs; ++i) {
    auto indexing = makeCanonicalAffineApplies(
        b, loc, linalgOp.getOutputIndexingMap(i), allIvsPlusDims);
    // Passing through IndexedValueType emits the proper load operation.
    indexedValues.push_back(
        IndexedValueType(linalgOp.getOutputBuffer(i))(indexing));
  }

  // TODO: When a region inliner exists, use it.
  // 2. Inline region, currently only works for a single basic block.
  // 3. Emit store.
  SmallVector<SmallVector<Value, 8>, 8> indexing;
  SmallVector<Value, 8> outputBuffers;
  for (unsigned i = 0; i < nOutputs; ++i) {
    indexing.push_back(makeCanonicalAffineApplies(
        b, loc, linalgOp.getOutputIndexingMap(i), allIvsPlusDims));
    outputBuffers.push_back(linalgOp.getOutputBuffer(i));
  }
  inlineRegionAndEmitStore<IndexedValueType>(linalgOp, indexedValues, indexing,
                                             outputBuffers);
}

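// Note (illustrative, assuming the EDSC indexed-value helpers): the
// IndexedValueType template parameter selects the flavor of memory ops the
// expressions above expand to. For instance, AffineIndexedValue emits
//   %v = affine.load %arg0[%i, %j] : memref<?x?xf32>
// while StdIndexedValue emits
//   %v = load %arg0[%i, %j] : memref<?x?xf32>
// and analogously for stores.
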
// Emits the value of the `input` tensor at position `indices`, substituting
// `padValue` when the access is out of bounds. `skipPadding` lists the
// dimensions for which no padding is needed, e.g. the non-spatial dimensions
// for convolutions.
template <typename IndexedValueType>
Value getPaddedInput(Value input, ArrayRef<Value> indices,
                     ArrayRef<int> skipPadding, Value padValue) {
  // TODO: add a level of indirection to linalg.generic.

  IndexedValueType indexedInput(input);

  auto *context = ScopedContext::getContext();
  Value zeroIndex = std_constant_index(0);
  SmallVector<Value, 8> conds;
  SmallVector<Value, 8> clampedImIdx;
  for (auto iter : llvm::enumerate(indices)) {
    int idx = iter.index();
    auto dim = iter.value();
    if (is_contained(skipPadding, idx)) {
      clampedImIdx.push_back(dim);
      continue;
    }

    using edsc::op::sge;
    using edsc::op::slt;
    using edsc::op::operator||;
    Value leftOutOfBound = slt(dim, zeroIndex);
    if (conds.empty())
      conds.push_back(leftOutOfBound);
    else
      conds.push_back(conds.back() || leftOutOfBound);
    Value rightBound = std_dim(input, idx);
    conds.push_back(conds.back() || (sge(dim, rightBound)));

    // When padding is involved, the indices can only be shifted towards
    // negative values, so clamping with a max against zero is enough.
    auto maxMap = AffineMap::get(/*dimCount=*/1, 0,
                                 {getAffineDimExpr(/*position=*/0, context),
                                  getAffineConstantExpr(0, context)},
                                 context);
    clampedImIdx.push_back(affine_max(dim.getType(), maxMap, ValueRange{dim}));
  }

  Value readInput = indexedInput(clampedImIdx);
  return conds.empty() ? readInput
                       : (Value)std_select(conds.back(), padValue, readInput);
}

namespace {

/// The padding value for a given Op depends on the semantics of the Op.
/// The identity value for ConvOp and PoolingSumOp is 0, for PoolingMaxOp is
/// -inf or minInt and for PoolingMinOp is inf or maxInt.
template <typename OpType>
Attribute getPadValueAttr(Type type) {
  llvm_unreachable("Unexpected op type for getPadValueAttr");
  return {};
}

template <>
Attribute getPadValueAttr<PoolingMaxOp>(Type type) {
  auto &b = ScopedContext::getBuilderRef();
  if (auto floatType = type.dyn_cast<FloatType>()) {
    return b.getFloatAttr(
        floatType,
        APFloat::getInf(floatType.getFloatSemantics(), /*Negative*/ true));
  }
  if (auto intType = type.dyn_cast<IntegerType>()) {
    unsigned width = intType.getWidth();
    // The select instruction used to lower PoolingMax uses a signed
    // comparison; use a signed constant irrespective of the signedness of
    // the integer type.
    return b.getIntegerAttr(intType, APInt::getSignedMinValue(width));
  }
  llvm_unreachable("Unsupported data type for PoolingMaxOp");
  return {};
}

template <>
Attribute getPadValueAttr<PoolingMinOp>(Type type) {
  auto &b = ScopedContext::getBuilderRef();
  if (auto floatType = type.dyn_cast<FloatType>()) {
    return b.getFloatAttr(floatType,
                          APFloat::getInf(floatType.getFloatSemantics()));
  }
  if (auto intType = type.dyn_cast<IntegerType>()) {
    unsigned width = intType.getWidth();
    // The select instruction used to lower PoolingMin uses a signed
    // comparison; use a signed constant irrespective of the signedness of
    // the integer type.
    return b.getIntegerAttr(intType, APInt::getSignedMaxValue(width));
  }
  llvm_unreachable("Unsupported data type for PoolingMinOp");
  return {};
}

template <>
Attribute getPadValueAttr<PoolingSumOp>(Type type) {
  auto &b = ScopedContext::getBuilderRef();
  return b.getZeroAttr(type);
}

template <>
Attribute getPadValueAttr<ConvOp>(Type type) {
  auto &b = ScopedContext::getBuilderRef();
  return b.getZeroAttr(type);
}

} // namespace

/// Returns true if `convOp` has non-zero padding.
static bool hasPadding(ConvOp convOp) {
  for (unsigned i = 0, e = convOp.getNumSpatialDimensions(); i < e; ++i) {
    if (convOp.getLowPad(i) > 0 || convOp.getHighPad(i) > 0)
      return true;
  }
  return false;
}

template <typename IndexedValueType>
static void emitScalarImplementation(ArrayRef<Value> allIvs, ConvOp convOp) {
  assert(convOp.hasBufferSemantics() &&
         "expected linalg op with buffer semantics");
  auto &b = ScopedContext::getBuilderRef();
  auto loc = ScopedContext::getLocation();
  auto mapsRange = convOp.indexing_maps().getAsRange<AffineMapAttr>();
  auto maps = llvm::to_vector<8>(
      llvm::map_range(mapsRange, [](AffineMapAttr a) { return a.getValue(); }));
  SmallVector<Value, 8> fIdx(
      makeCanonicalAffineApplies(b, loc, maps[0], allIvs));
  SmallVector<Value, 8> imIdx(
      makeCanonicalAffineApplies(b, loc, maps[1], allIvs));
  SmallVector<Value, 8> oIdx(
      makeCanonicalAffineApplies(b, loc, maps[2], allIvs));

  IndexedValueType F(convOp.filter()), O(convOp.output());

  // Emit scalar form. A padded conv involves an affine.max in the memory
  // access, which is not allowed by affine.load. Override to use a
  // StdIndexedValue when there is non-zero padding.
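  // Illustrative sketch (hypothetical SSA names) of what the padded access
  // below expands to via getPaddedInput:
  //   %clamped = affine.max affine_map<(d0) -> (d0, 0)>(%idx)
  //   %ld = load %input[..., %clamped, ...] : memref<...>
  //   %v = select %out_of_bounds, %padValue, %ld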
  if (hasPadding(convOp)) {
    Type type = convOp.input().getType().cast<MemRefType>().getElementType();
    Value padValue = std_constant(type, getPadValueAttr<ConvOp>(type));
    Value paddedInput = getPaddedInput<StdIndexedValue>(
        convOp.input(), imIdx,
        /* Only the window (spatial) dimensions need padding; the first and
           last dimensions are skipped. */
        {0, static_cast<int>(imIdx.size()) - 1}, padValue);
    O(oIdx) += F(fIdx) * paddedInput;
  } else {
    IndexedValueType I(convOp.input());
    O(oIdx) += F(fIdx) * I(imIdx);
  }
}

template <typename PoolingOp>
static bool hasPadding(PoolingOp poolingOp) {
  for (unsigned i = 0, e = poolingOp.getNumWindowLoops(); i < e; ++i) {
    if (poolingOp.getLowPad(i) > 0 || poolingOp.getHighPad(i) > 0)
      return true;
  }
  return false;
}

template <typename IndexedValueType, typename PoolingOp>
static Value getPoolingInput(PoolingOp op, ArrayRef<Value> inputIndices) {
  if (hasPadding(op)) {
    Type type =
        op.input().getType().template cast<MemRefType>().getElementType();
    Value padValue = std_constant(type, getPadValueAttr<PoolingOp>(type));
    return getPaddedInput<StdIndexedValue>(op.input(), inputIndices,
                                           /*Pad every dimension*/ {},
                                           padValue);
  }
  IndexedValueType input(op.input());
  return input(inputIndices);
}

template <typename IndexedValueType, typename OpType>
void emitPoolingMinMaxScalarImplementation(ArrayRef<Value> allIvs, OpType op) {
  InputAndOutputIndices indices = getInputAndOutputIndices(allIvs, op);
  // Emit scalar form.
  IndexedValueType output(op.output());
  Value lhs = output(indices.outputs);
  Value rhs = getPoolingInput<IndexedValueType>(op, indices.inputs);
  using edsc::op::sgt;
  using edsc::op::slt;
  Value value = std::is_same<OpType, PoolingMinOp>()
                    ? std_select(slt(lhs, rhs), lhs, rhs)
                    : std_select(sgt(lhs, rhs), lhs, rhs);
  output(indices.outputs) = value;
}

template <typename IndexedValueType>
static void emitScalarImplementation(ArrayRef<Value> allIvs, PoolingMaxOp op) {
  emitPoolingMinMaxScalarImplementation<IndexedValueType, PoolingMaxOp>(allIvs,
                                                                        op);
}

template <typename IndexedValueType>
static void emitScalarImplementation(ArrayRef<Value> allIvs, PoolingMinOp op) {
  emitPoolingMinMaxScalarImplementation<IndexedValueType, PoolingMinOp>(allIvs,
                                                                        op);
}

template <typename IndexedValueType>
static void emitScalarImplementation(ArrayRef<Value> allIvs, PoolingSumOp op) {
  auto indices = getInputAndOutputIndices(allIvs, op);
  IndexedValueType output(op.output());

  // Emit scalar form.
  output(indices.outputs) +=
      getPoolingInput<IndexedValueType>(op, indices.inputs);
}

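// Illustrative sketch (hypothetical SSA names, f32 element type) of the
// scalar body emitted for a pooling-sum op by the `+=` expression above:
//   %lhs = load %output[%o0, %o1]
//   %rhs = load %input[%i0, %i1]   ; or the padded value when out of bounds
//   %sum = addf %lhs, %rhs : f32
//   store %sum, %output[%o0, %o1]
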
/// Emits the MLIR for the scalar part of the indexed generic op by:
///   1. Emitting load ops for each input and output view in order. This is
///      achieved by applying the appropriate input or output map to the
///      enclosing induction variables.
///   2. Emitting a call to `op.fun()` that takes as arguments the induction
///      variables and the scalars from point 1. above.
///   3. Emitting store ops to store the results of 2. to the output views.
///
/// An example output may resemble:
///
/// ```
///    scf.for %i = %c0 to %0 step %c1 {
///      scf.for %j = %c0 to %1 step %c1 {
///        scf.for %k = %c0 to %4 step %c1 {
///          %11 = load %arg0[%i, %j] :
///            memref<?x?xf32, stride_specification>
///          %12 = load %arg1[%i, %j, %k] :
///            memref<?x?x?xf32, stride_specification>
///          %13 = load %arg2[%i, %k, %j] :
///            memref<?x?x?xf32, stride_specification>
///          %14:2 = call @foo(%i, %j, %k, %11, %12, %13) :
///            (index, index, index, f32, f32, f32) -> (f32, f32)
///          store %14#0, %arg1[%i, %j, %k] :
///            memref<?x?x?xf32, stride_specification>
///          store %14#1, %arg2[%i, %k, %j] :
///            memref<?x?x?xf32, stride_specification>
///        }
///      }
///    }
/// ```
template <typename IndexedValueType>
static void emitScalarImplementation(ArrayRef<Value> allIvs,
                                     IndexedGenericOp indexedGenericOp) {
  assert(indexedGenericOp.hasBufferSemantics() &&
         "expected linalg op with buffer semantics");
  auto &b = ScopedContext::getBuilderRef();
  auto loc = ScopedContext::getLocation();
  unsigned nInputs = indexedGenericOp.getNumInputs();
  unsigned nOutputs = indexedGenericOp.getNumOutputs();
  unsigned nLoops = allIvs.size();
  SmallVector<Value, 4> indexedValues;
  indexedValues.reserve(nLoops + nInputs + nOutputs);
  for (unsigned i = 0; i < nLoops; ++i)
    indexedValues.push_back(allIvs[i]);

  // TODO: Avoid the loads if the corresponding argument of the
  // region has no uses.
  // 1.a. Emit load from input views.
  for (unsigned i = 0; i < nInputs; ++i) {
    auto indexing = makeCanonicalAffineApplies(
        b, loc, indexedGenericOp.getInputIndexingMap(i), allIvs);
    // Passing input i through IndexedValueType emits the proper load
    // operation.
    indexedValues.push_back(
        IndexedValueType(indexedGenericOp.getInput(i))(indexing));
  }
  // 1.b. Emit load from output views.
  for (unsigned i = 0; i < nOutputs; ++i) {
    auto indexing = makeCanonicalAffineApplies(
        b, loc, indexedGenericOp.getOutputIndexingMap(i), allIvs);
    // Passing output i through IndexedValueType emits the proper load
    // operation.
    indexedValues.push_back(
        IndexedValueType(indexedGenericOp.getOutputBuffer(i))(indexing));
  }

  // TODO: When a region inliner exists, use it.
  // 2. Inline region, currently only works for a single basic block.
  // 3. Emit store.
  SmallVector<SmallVector<Value, 8>, 8> indexing;
  SmallVector<Value, 8> outputBuffers;
  for (unsigned i = 0; i < nOutputs; ++i) {
    indexing.push_back(makeCanonicalAffineApplies(
        b, loc, indexedGenericOp.getOutputIndexingMap(i), allIvs));
    outputBuffers.push_back(indexedGenericOp.getOutputBuffer(i));
  }
  inlineRegionAndEmitStore<IndexedValueType>(indexedGenericOp, indexedValues,
                                             indexing, outputBuffers);
}

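// Note (illustrative, assuming the GenerateLoopNest specializations in the
// Linalg utilities): the LoopTy parameter below selects both the generated
// loop construct and the IndexedValueTy used to emit the scalar body:
//   AffineForOp     -> a nest of affine.for, accesses via affine.load/store
//   scf::ForOp      -> a nest of scf.for, accesses via load/store
//   scf::ParallelOp -> scf.parallel for the parallel dimensions (scf.for for
//                      the others), accesses via load/store
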
template <typename LoopTy>
static Optional<LinalgLoops>
linalgOpToLoopsImpl(Operation *op, OpBuilder &builder,
                    ArrayRef<unsigned> interchangeVector) {
  using IndexedValueTy = typename GenerateLoopNest<LoopTy>::IndexedValueTy;
  ScopedContext scope(builder, op->getLoc());

  // The flattened loopToOperandRangesMaps is expected to be an invertible
  // permutation map (which is asserted in the inverse calculation).
  auto linalgOp = cast<LinalgOp>(op);
  assert(linalgOp.hasBufferSemantics() &&
         "expected linalg op with buffer semantics");

  auto loopRanges = linalgOp.createLoopRanges(builder, op->getLoc());
  auto iteratorTypes = llvm::to_vector<4>(linalgOp.iterator_types().getValue());

  if (!interchangeVector.empty()) {
    assert(interchangeVector.size() == loopRanges.size());
    assert(interchangeVector.size() == iteratorTypes.size());
    applyPermutationToVector(loopRanges, interchangeVector);
    applyPermutationToVector(iteratorTypes, interchangeVector);
  }

  SmallVector<Value, 4> allIvs;
  GenerateLoopNest<LoopTy>::doit(
      loopRanges, /*iterInitArgs=*/{}, iteratorTypes,
      [&](ValueRange ivs, ValueRange iterArgs) -> scf::ValueVector {
        assert(iterArgs.empty() && "unexpected iterArgs");
        allIvs.append(ivs.begin(), ivs.end());
        llvm::TypeSwitch<Operation *>(op)
            .Case<ConvOp, PoolingMaxOp, PoolingMinOp, PoolingSumOp,
                  IndexedGenericOp, LinalgOp>([&](auto op) {
              emitScalarImplementation<IndexedValueTy>(allIvs, op);
            })
            .Default([&](Operation *op) { assert(false && "unexpected op"); });
        return scf::ValueVector{};
      });
  // The number of loop ops might differ from the number of ivs since some
  // loops, such as affine.parallel and scf.parallel, have multiple ivs.
  llvm::SetVector<Operation *> loopSet;
  for (Value iv : allIvs) {
    if (!iv)
      return {};
    // The induction variable is a block argument of the entry block of the
    // loop operation.
    BlockArgument ivVal = iv.dyn_cast<BlockArgument>();
    if (!ivVal)
      return {};
    loopSet.insert(ivVal.getOwner()->getParentOp());
  }
  LinalgLoops loops(loopSet.begin(), loopSet.end());
  return loops;
}

namespace {
template <typename LoopType>
class LinalgRewritePattern : public RewritePattern {
public:
  LinalgRewritePattern(ArrayRef<unsigned> interchangeVector)
      : RewritePattern(/*benefit=*/1, MatchAnyOpTypeTag()),
        interchangeVector(interchangeVector.begin(), interchangeVector.end()) {}

  LogicalResult matchAndRewrite(Operation *op,
                                PatternRewriter &rewriter) const override {
    if (!isa<LinalgOp>(op))
      return failure();
    if (!linalgOpToLoopsImpl<LoopType>(op, rewriter, interchangeVector))
      return failure();
    rewriter.eraseOp(op);
    return success();
  }

private:
  SmallVector<unsigned, 4> interchangeVector;
};

struct FoldAffineOp;
} // namespace

template <typename LoopType>
static void lowerLinalgToLoopsImpl(FuncOp funcOp,
                                   ArrayRef<unsigned> interchangeVector) {
  MLIRContext *context = funcOp.getContext();
  OwningRewritePatternList patterns;
  patterns.insert<LinalgRewritePattern<LoopType>>(interchangeVector);
  DimOp::getCanonicalizationPatterns(patterns, context);
  AffineApplyOp::getCanonicalizationPatterns(patterns, context);
  patterns.insert<FoldAffineOp>(context);
  // Just apply the patterns greedily.
  (void)applyPatternsAndFoldGreedily(funcOp, std::move(patterns));
}

namespace {
/// Local folding pattern for AffineApplyOp that we can apply greedily.
/// This replaces AffineApplyOp by the proper value in cases where the
/// associated map is trivial.
/// A trivial map here is defined as a map with a single result and either:
///   1. Zero operands and a single AffineConstantExpr result, or
///   2. One operand and a single AffineDimExpr result, or
///   3. One operand and a single AffineSymbolExpr result.
///
/// In the first case, the AffineApplyOp is replaced by a new constant. In the
/// other cases, it is replaced by its unique operand.
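///
/// For example (illustrative, hypothetical SSA names):
///   %0 = affine.apply affine_map<() -> (42)>()          // folds to `constant 42 : index`
///   %1 = affine.apply affine_map<(d0) -> (d0)>(%x)      // folds to %x
///   %2 = affine.apply affine_map<()[s0] -> (s0)>()[%s]  // folds to %s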
struct FoldAffineOp : public RewritePattern {
  FoldAffineOp(MLIRContext *context)
      : RewritePattern(AffineApplyOp::getOperationName(), 0, context) {}

  LogicalResult matchAndRewrite(Operation *op,
                                PatternRewriter &rewriter) const override {
    AffineApplyOp affineApplyOp = cast<AffineApplyOp>(op);
    auto map = affineApplyOp.getAffineMap();
    if (map.getNumResults() != 1 || map.getNumInputs() > 1)
      return failure();

    AffineExpr expr = map.getResult(0);
    if (map.getNumInputs() == 0) {
      if (auto val = expr.dyn_cast<AffineConstantExpr>()) {
        rewriter.replaceOpWithNewOp<ConstantIndexOp>(op, val.getValue());
        return success();
      }
      return failure();
    }
    if (expr.dyn_cast<AffineDimExpr>() || expr.dyn_cast<AffineSymbolExpr>()) {
      rewriter.replaceOp(op, op->getOperand(0));
      return success();
    }
    return failure();
  }
};

struct LowerToAffineLoops
    : public LinalgLowerToAffineLoopsBase<LowerToAffineLoops> {
  void runOnFunction() override {
    lowerLinalgToLoopsImpl<AffineForOp>(getFunction(), interchangeVector);
  }
};

struct LowerToLoops : public LinalgLowerToLoopsBase<LowerToLoops> {
  void runOnFunction() override {
    lowerLinalgToLoopsImpl<scf::ForOp>(getFunction(), interchangeVector);
  }
};

struct LowerToParallelLoops
    : public LinalgLowerToParallelLoopsBase<LowerToParallelLoops> {
  void runOnFunction() override {
    lowerLinalgToLoopsImpl<scf::ParallelOp>(getFunction(), interchangeVector);
  }
};
} // namespace

std::unique_ptr<OperationPass<FuncOp>> mlir::createConvertLinalgToLoopsPass() {
  return std::make_unique<LowerToLoops>();
}

std::unique_ptr<OperationPass<FuncOp>>
mlir::createConvertLinalgToParallelLoopsPass() {
  return std::make_unique<LowerToParallelLoops>();
}

std::unique_ptr<OperationPass<FuncOp>>
mlir::createConvertLinalgToAffineLoopsPass() {
  return std::make_unique<LowerToAffineLoops>();
}

/// Emits a loop nest with the proper body for `op`.
template <typename LoopTy>
Optional<LinalgLoops>
mlir::linalg::linalgLowerOpToLoops(OpBuilder &builder, Operation *op,
                                   ArrayRef<unsigned> interchangeVector) {
  return linalgOpToLoopsImpl<LoopTy>(op, builder, interchangeVector);
}

template Optional<LinalgLoops> mlir::linalg::linalgLowerOpToLoops<AffineForOp>(
    OpBuilder &builder, Operation *op, ArrayRef<unsigned> interchangeVector);
template Optional<LinalgLoops> mlir::linalg::linalgLowerOpToLoops<scf::ForOp>(
    OpBuilder &builder, Operation *op, ArrayRef<unsigned> interchangeVector);
template Optional<LinalgLoops>
mlir::linalg::linalgLowerOpToLoops<scf::ParallelOp>(
    OpBuilder &builder, Operation *op, ArrayRef<unsigned> interchangeVector);

/// Emits a loop nest of `affine.for` with the proper body for `op`.
LogicalResult
mlir::linalg::linalgOpToAffineLoops(OpBuilder &builder, Operation *op,
                                    ArrayRef<unsigned> interchangeVector) {
  Optional<LinalgLoops> loops =
      linalgLowerOpToLoops<AffineForOp>(builder, op, interchangeVector);
  return loops ? success() : failure();
}

/// Emits a loop nest of `scf.for` with the proper body for `op`.
LogicalResult
mlir::linalg::linalgOpToLoops(OpBuilder &builder, Operation *op,
                              ArrayRef<unsigned> interchangeVector) {
  Optional<LinalgLoops> loops =
      linalgLowerOpToLoops<scf::ForOp>(builder, op, interchangeVector);
  return loops ? success() : failure();
}

/// Emits a loop nest of `scf.parallel` with the proper body for `op`.
LogicalResult
mlir::linalg::linalgOpToParallelLoops(OpBuilder &builder, Operation *op,
                                      ArrayRef<unsigned> interchangeVector) {
  Optional<LinalgLoops> loops =
      linalgLowerOpToLoops<scf::ParallelOp>(builder, op, interchangeVector);
  return loops ? success() : failure();
}
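
// Example usage (illustrative sketch, from inside a hypothetical rewrite
// pattern; the explicit empty interchange vector keeps the original loop
// order):
//   LogicalResult matchAndRewrite(linalg::MatmulOp op,
//                                 PatternRewriter &rewriter) const override {
//     if (failed(linalg::linalgOpToLoops(rewriter, op.getOperation(),
//                                        /*interchangeVector=*/{})))
//       return failure();
//     rewriter.eraseOp(op);
//     return success();
//   }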