//===- Loops.cpp - conversion from Linalg named and generic ops to loops --===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//

#include "PassDetail.h"
#include "mlir/Dialect/Affine/EDSC/Intrinsics.h"
#include "mlir/Dialect/Linalg/EDSC/FoldedIntrinsics.h"
#include "mlir/Dialect/Linalg/IR/LinalgOps.h"
#include "mlir/Dialect/Linalg/IR/LinalgTypes.h"
#include "mlir/Dialect/Linalg/Passes.h"
#include "mlir/Dialect/Linalg/Transforms/Transforms.h"
#include "mlir/Dialect/Linalg/Utils/Utils.h"
#include "mlir/Dialect/SCF/EDSC/Builders.h"
#include "mlir/Dialect/StandardOps/EDSC/Intrinsics.h"
#include "mlir/IR/AffineExpr.h"
#include "mlir/IR/AffineMap.h"
#include "mlir/IR/BlockAndValueMapping.h"
#include "mlir/Support/LLVM.h"
#include "mlir/Transforms/DialectConversion.h"
#include "mlir/Transforms/FoldUtils.h"

using namespace mlir;
using namespace mlir::edsc;
using namespace mlir::edsc::intrinsics;
using namespace mlir::linalg;

using edsc::op::operator+;

/// Emits one `affine.apply` per result expression of `map`, applied to `vals`.
/// Each single-result sub-map is canonicalized together with its operands
/// before emission. Returns the applied values in map-result order; an empty
/// map yields no values. `map` must be symbol-free and take exactly
/// `vals.size()` inputs.
static SmallVector<Value, 8> makeCanonicalAffineApplies(OpBuilder &b,
                                                        Location loc,
                                                        AffineMap map,
                                                        ArrayRef<Value> vals) {
  if (map.isEmpty())
    return {};
  assert(map.getNumSymbols() == 0);
  assert(map.getNumInputs() == vals.size());
  SmallVector<Value, 8> res;
  res.reserve(map.getNumResults());
  auto dims = map.getNumDims();
  for (auto e : map.getResults()) {
    // Split `map` into one single-result map per expression so each one can
    // be canonicalized (and its unused operands dropped) independently.
    auto exprMap = AffineMap::get(dims, 0, e);
    SmallVector<Value, 4> operands(vals.begin(), vals.end());
    canonicalizeMapAndOperands(&exprMap, &operands);
    res.push_back(affine_apply(exprMap, operands));
  }
  return res;
}

/// Applies `permutation` to `ivs` when a permutation map is provided;
/// otherwise returns `ivs` unchanged. Used to honor the optional input/output
/// permutations of linalg.copy.
static SmallVector<Value, 4> permuteIvs(ArrayRef<Value> ivs,
                                        Optional<AffineMap> permutation) {
  return permutation ? applyMapToValues(ScopedContext::getBuilderRef(),
                                        ScopedContext::getLocation(),
                                        permutation.getValue(), ivs)
                     : SmallVector<Value, 4>(ivs.begin(), ivs.end());
}

// Creates a number of ranges equal to the number of results in `map`.
// The returned ranges correspond to the loop ranges, in the proper order, for
// which new loops will be created.
static SmallVector<SubViewOp::Range, 4>
emitLoopRanges(OpBuilder &b, Location loc, AffineMap map,
               ArrayRef<Value> allViewSizes) {
  // Apply `map` to get view sizes in loop order.
  auto sizes = applyMapToValues(b, loc, map, allViewSizes);
  // Create a new range with the applied tile sizes.
  ScopedContext scope(b, loc);
  SmallVector<SubViewOp::Range, 4> res;
  for (unsigned idx = 0, e = map.getNumResults(); idx < e; ++idx) {
    // Each loop runs from 0 to the mapped view size with unit step.
    res.push_back(SubViewOp::Range{std_constant_index(0), sizes[idx],
                                   std_constant_index(1)});
  }
  return res;
}

/// Clones the single-block region of `op` at the current insertion point,
/// substituting the block arguments with `indexedValues`, then emits one store
/// per yielded value into `outputBuffers[i]` at indices `indexing[i]`.
/// `IndexedValueType` (std or affine indexed value) selects which load/store
/// ops are built; its `operator=` emits the store.
template <typename IndexedValueType, typename OpType>
static void inlineRegionAndEmitStore(OpType op, ArrayRef<Value> indexedValues,
                                     ArrayRef<SmallVector<Value, 8>> indexing,
                                     ArrayRef<Value> outputBuffers) {
  assert(op.getOperation()->getNumRegions() == 1 &&
         "Expected single region op");
  auto &b = ScopedContext::getBuilderRef();
  auto &block = op.region().front();
  BlockAndValueMapping map;
  map.map(block.getArguments(), indexedValues);
  for (auto &op : block.without_terminator()) {
    // Op-by-op cloning with a value mapping only works on a flat region.
    assert(op.getNumRegions() == 0 && "expected a non-nested region");
    auto *newOp = b.clone(op, map);
    map.map(op.getResults(), newOp->getResults());
  }

  Operation &terminator = block.back();
  assert(isa<YieldOp>(terminator) &&
         "expected a yield op in the end of the region");
  for (unsigned i = 0, e = terminator.getNumOperands(); i < e; ++i) {
    // `IndexedValueType::operator=` emits the proper store operation.
    IndexedValueType O(outputBuffers[i]);
    O(indexing[i]) = map.lookupOrDefault(terminator.getOperand(i));
  }
}

// Returns a pair that contains input indices and output indices of a
// SingleInputPoolingOp `op`.
struct InputAndOutputIndices {
  SmallVector<Value, 8> inputs;
  SmallVector<Value, 8> outputs;
};
template <typename SingleInputPoolingOp>
static InputAndOutputIndices getInputAndOutputIndices(ArrayRef<Value> allIvs,
                                                      SingleInputPoolingOp op) {
  auto &b = ScopedContext::getBuilderRef();
  auto loc = ScopedContext::getLocation();
  auto mapsRange = op.indexing_maps().template getAsRange<AffineMapAttr>();
  auto maps = llvm::to_vector<8>(
      llvm::map_range(mapsRange, [](AffineMapAttr a) { return a.getValue(); }));
  // Pooling ops carry three indexing maps (input, window, output); the window
  // map (maps[1]) contributes no memory indices here.
  return InputAndOutputIndices{
      makeCanonicalAffineApplies(b, loc, maps[0], allIvs),
      makeCanonicalAffineApplies(b, loc, maps[2], allIvs)};
}

namespace {

/// Emits the MLIR for the scalar part of the generic op by:
///   1. Emitting load ops for each input and output view in order. This is
///      achieved by applying the appropriate input or output map to the
///      enclosing induction variables.
///   2. Emitting a call to `op.fun()` that takes as arguments the scalars
///      from point 1. above.
///   3. Emitting store ops to store the results of 2. to the output
///      views.
///
/// An example output may resemble:
///
/// ```
///    scf.for %i = %c0 to %0 step %c1 {
///      scf.for %j = %c0 to %1 step %c1 {
///        scf.for %k = %c0 to %4 step %c1 {
///          %11 = load %arg0[%i, %j] :
///            memref<?x?xf32, stride_specification>
///          %12 = load %arg1[%i, %j, %k] :
///            memref<?x?x?xf32, stride_specification>
///          %13 = load %arg2[%i, %k, %j] :
///            memref<?x?x?xf32, stride_specification>
///          %14:2 = call @foo(%11, %12, %13) : (f32, f32, f32) -> (f32, f32)
///          store %14#0, %arg1[%i, %j, %k] :
///            memref<?x?x?xf32, stride_specification>
///          store %14#1, %arg2[%i, %k, %j] :
///            memref<?x?x?xf32, stride_specification>
///        }
///      }
///    }
/// ```
// TODO: need a LinalgStructuredOpInterface.
template <typename IndexedValueType, typename LinalgStructuredOpType>
void emitScalarImplementation(ArrayRef<Value> allIvs,
                              LinalgStructuredOpType linalgOp) {
  assert(linalgOp.hasBufferSemantics() &&
         "expected linalg op with buffer semantics");
  auto &b = ScopedContext::getBuilderRef();
  auto loc = ScopedContext::getLocation();
  unsigned nInputs = linalgOp.getNumInputs();
  unsigned nOutputs = linalgOp.getNumOutputs();
  SmallVector<Value, 4> indexedValues;
  indexedValues.reserve(nInputs + nOutputs);

  // TODO(mravishankar): Avoid the loads if the corresponding argument of the
  // region has no uses.
  // 1.a. Emit load from input views.
  for (unsigned i = 0; i < nInputs; ++i) {
    auto indexing = makeCanonicalAffineApplies(
        b, loc, linalgOp.getInputIndexingMap(i), allIvs);
    // Passing through IndexedValueType emits the proper load operation.
    indexedValues.push_back(IndexedValueType(linalgOp.getInput(i))(indexing));
  }
  // 1.b. Emit load from output views.
  for (unsigned i = 0; i < nOutputs; ++i) {
    auto indexing = makeCanonicalAffineApplies(
        b, loc, linalgOp.getOutputIndexingMap(i), allIvs);
    // Passing through IndexedValueType emits the proper load operation.
    indexedValues.push_back(
        IndexedValueType(linalgOp.getOutputBuffer(i))(indexing));
  }

  // TODO(ntv): When a region inliner exists, use it.
  // 2. Inline region, currently only works for a single basic block.
  // 3. Emit store.
  SmallVector<SmallVector<Value, 8>, 8> indexing;
  SmallVector<Value, 8> outputBuffers;
  for (unsigned i = 0; i < nOutputs; ++i) {
    indexing.push_back(makeCanonicalAffineApplies(
        b, loc, linalgOp.getOutputIndexingMap(i), allIvs));
    outputBuffers.push_back(linalgOp.getOutputBuffer(i));
  }
  inlineRegionAndEmitStore<IndexedValueType>(linalgOp, indexedValues, indexing,
                                             outputBuffers);
}

/// Emits the scalar body of a linalg.copy: one load from the (optionally
/// permuted) input position and one store to the (optionally permuted) output
/// position.
template <typename IndexedValueType>
void emitScalarImplementation(ArrayRef<Value> allIvs, CopyOp copyOp) {
  assert(copyOp.hasBufferSemantics() &&
         "expected linalg op with buffer semantics");
  auto nPar = copyOp.getNumParallelLoops();
  assert(nPar == allIvs.size());
  auto inputIvs =
      permuteIvs(allIvs.take_front(nPar), copyOp.inputPermutation());
  auto outputIvs =
      permuteIvs(allIvs.take_front(nPar), copyOp.outputPermutation());
  SmallVector<Value, 8> iivs(inputIvs.begin(), inputIvs.end());
  SmallVector<Value, 8> oivs(outputIvs.begin(), outputIvs.end());
  IndexedValueType O(copyOp.getOutputBuffer(0)), I(copyOp.getInput(0));
  // Emit the proper scalar assignment, whether we are dealing with a 0-D or
  // an n-D loop nest; with or without permutations.
  // clang-format off
  nPar > 0 ? O(oivs) = I(iivs) :
             O() = I();
  // clang-format on
}

/// Emits the scalar body of a linalg.fill: a store of the fill value into the
/// output buffer at the current induction variables.
template <typename IndexedValueType>
void emitScalarImplementation(ArrayRef<Value> allIvs, FillOp fillOp) {
  assert(fillOp.hasBufferSemantics() &&
         "expected linalg op with buffer semantics");
  auto nPar = fillOp.getNumParallelLoops();
  assert(nPar == allIvs.size());
  auto ivs = SmallVector<Value, 4>(allIvs.begin(), allIvs.begin() + nPar);
  IndexedValueType O(fillOp.getOutputBuffer(0));
  // Emit the proper scalar assignment, whether we are dealing with a 0-D or
  // an n-D loop nest; with or without permutations.
  nPar > 0 ? O(ivs) = fillOp.value() : O() = fillOp.value();
}

/// Emits the scalar body of a linalg.dot: C() += A(r_i) * B(r_i) for the
/// single reduction induction variable r_i.
template <typename IndexedValueType>
void emitScalarImplementation(ArrayRef<Value> allIvs, DotOp dotOp) {
  assert(dotOp.hasBufferSemantics() &&
         "expected linalg op with buffer semantics");
  assert(allIvs.size() == 1);
  Value r_i(allIvs[0]);
  IndexedValueType A(dotOp.getInput(0)), B(dotOp.getInput(1)),
      C(dotOp.getOutputBuffer(0));
  // Emit scalar form.
  C() = C() + A(r_i) * B(r_i);
}

/// Emits the value read from the (possibly padded) conv input at `imIdx`.
/// Without padding this is a plain load. With padding, each window dimension
/// index is clamped to >= 0 via affine.max and an out-of-bounds predicate is
/// accumulated; the result selects between the loaded value and a typed zero.
template <typename IndexedValueType>
Value getConvOpInput(ConvOp convOp, StdIndexedValue im,
                     MutableArrayRef<Value> imIdx) {
  // TODO(ntv): add a level of indirection to linalg.generic.
  if (!convOp.padding())
    return im(imIdx);

  auto *context = ScopedContext::getContext();
  Value zeroIndex = std_constant_index(0);
  SmallVector<Value, 8> conds;
  SmallVector<Value, 8> clampedImIdx;
  for (auto iter : llvm::enumerate(imIdx)) {
    int idx = iter.index();
    auto dim = iter.value();
    // Only need to iterate over the window dimensions: the first (batch) and
    // last (channel) dimensions are never padded.
    if (idx == 0 || idx == static_cast<int>(imIdx.size()) - 1) {
      clampedImIdx.push_back(dim);
      continue;
    }

    using edsc::op::operator<;
    using edsc::op::operator>=;
    using edsc::op::operator||;
    // Accumulate "any index out of bounds so far" into conds.back().
    Value leftOutOfBound = dim < zeroIndex;
    if (conds.empty())
      conds.push_back(leftOutOfBound);
    else
      conds.push_back(conds.back() || leftOutOfBound);
    Value rightBound = std_dim(convOp.input(), idx);
    conds.push_back(conds.back() || (dim >= rightBound));

    // When padding is involved, the indices will only be shifted to negative,
    // so having a max op is enough.
    auto maxMap = AffineMap::get(/*dimCount=*/1, 0,
                                 {getAffineDimExpr(/*position=*/0, context),
                                  getAffineConstantExpr(0, context)},
                                 context);
    clampedImIdx.push_back(affine_max(dim.getType(), maxMap, ValueRange{dim}));
  }

  auto &b = ScopedContext::getBuilderRef();
  Type type = convOp.input().getType().cast<MemRefType>().getElementType();
  Value zero = std_constant(type, b.getZeroAttr(type));
  Value readInput = im(clampedImIdx);
  // Out-of-bounds accesses read as zero (the padding value).
  return conds.empty() ? readInput
                       : (Value)std_select(conds.back(), zero, readInput);
}

/// Returns true if `convOp` has non-zero padding.
295 static bool hasPadding(ConvOp convOp) { 296 for (unsigned i = 0, e = convOp.getNumSpatialDimensions(); i < e; ++i) { 297 if (convOp.getLowPad(i) > 0 || convOp.getHighPad(i) > 0) 298 return true; 299 } 300 return false; 301 } 302 303 template <typename IndexedValueType> 304 static void emitScalarImplementation(ArrayRef<Value> allIvs, ConvOp convOp) { 305 assert(convOp.hasBufferSemantics() && 306 "expected linalg op with buffer semantics"); 307 auto &b = ScopedContext::getBuilderRef(); 308 auto loc = ScopedContext::getLocation(); 309 auto mapsRange = convOp.indexing_maps().getAsRange<AffineMapAttr>(); 310 auto maps = llvm::to_vector<8>( 311 llvm::map_range(mapsRange, [](AffineMapAttr a) { return a.getValue(); })); 312 SmallVector<Value, 8> fIdx( 313 makeCanonicalAffineApplies(b, loc, maps[0], allIvs)); 314 SmallVector<Value, 8> imIdx( 315 makeCanonicalAffineApplies(b, loc, maps[1], allIvs)); 316 SmallVector<Value, 8> oIdx( 317 makeCanonicalAffineApplies(b, loc, maps[2], allIvs)); 318 319 IndexedValueType F(convOp.filter()), O(convOp.output()); 320 321 // Emit scalar form. Padded conv involves an affine.max in the memory access 322 // which is not allowed by affine.load. Override to use an StdIndexedValue 323 // when there is non-zero padding. 324 if (hasPadding(convOp)) { 325 StdIndexedValue I(convOp.input()); 326 Value paddedInput = getConvOpInput<IndexedValueType>(convOp, I, imIdx); 327 O(oIdx) += F(fIdx) * paddedInput; 328 } else { 329 IndexedValueType I(convOp.input()); 330 O(oIdx) += F(fIdx) * I(imIdx); 331 } 332 } 333 334 template <typename IndexedValueType> 335 void emitScalarImplementation(ArrayRef<Value> allIvs, PoolingMaxOp op) { 336 auto indices = getInputAndOutputIndices(allIvs, op); 337 // Emit scalar form. 
338 Value lhs = std_load(op.output(), indices.outputs); 339 Value rhs = std_load(op.input(), indices.inputs); 340 using edsc::op::operator>; 341 Value maxValue = std_select(lhs > rhs, lhs, rhs); 342 std_store(maxValue, op.output(), indices.outputs); 343 } 344 template <typename IndexedValueType> 345 void emitScalarImplementation(ArrayRef<Value> allIvs, PoolingMinOp op) { 346 auto indices = getInputAndOutputIndices(allIvs, op); 347 // Emit scalar form. 348 Value lhs = std_load(op.output(), indices.outputs); 349 Value rhs = std_load(op.input(), indices.inputs); 350 using edsc::op::operator<; 351 Value minValue = std_select(lhs < rhs, lhs, rhs); 352 std_store(minValue, op.output(), indices.outputs); 353 } 354 template <typename IndexedValueType> 355 void emitScalarImplementation(ArrayRef<Value> allIvs, PoolingSumOp op) { 356 auto indices = getInputAndOutputIndices(allIvs, op); 357 IndexedValueType input(op.input()), output(op.output()); 358 359 // Emit scalar form. 360 output(indices.outputs) += input(indices.inputs); 361 } 362 /// Emits the MLIR for the scalar part of the indexed generic op by: 363 /// 1. Emitting load ops for each input and output view in order. This is 364 /// achieved by applying the appropriate input or output map to the 365 /// enclosing induction variables. 366 /// 2. Emitting a call to `op.fun()` that takes as arguments the induction 367 /// variables and the scalars from point 1. above. 368 /// 3. Emitting store ops to store the results of 2. to the output views. 
///
/// An example output may resemble:
///
/// ```
///    scf.for %i = %c0 to %0 step %c1 {
///      scf.for %j = %c0 to %1 step %c1 {
///        scf.for %k = %c0 to %4 step %c1 {
///          %11 = load %arg0[%i, %j] :
///            memref<?x?xf32, stride_specification>
///          %12 = load %arg1[%i, %j, %k] :
///            memref<?x?x?xf32, stride_specification>
///          %13 = load %arg2[%i, %k, %j] :
///            memref<?x?x?xf32, stride_specification>
///          %14:2 = call @foo(%i, %j, %k, %11, %12, %13) :
///            (index, index, index, f32, f32, f32) -> (f32, f32)
///          store %14#0, %arg1[%i, %j, %k] :
///            memref<?x?x?xf32, stride_specification>
///          store %14#1, %arg2[%i, %k, %j] :
///            memref<?x?x?xf32, stride_specification>
///        }
///      }
///    }
/// ```
template <typename IndexedValueType>
static void emitScalarImplementation(ArrayRef<Value> allIvs,
                                     IndexedGenericOp indexedGenericOp) {
  assert(indexedGenericOp.hasBufferSemantics() &&
         "expected linalg op with buffer semantics");
  auto &b = ScopedContext::getBuilderRef();
  auto loc = ScopedContext::getLocation();
  unsigned nInputs = indexedGenericOp.getNumInputs();
  unsigned nOutputs = indexedGenericOp.getNumOutputs();
  unsigned nLoops = allIvs.size();
  SmallVector<Value, 4> indexedValues;
  indexedValues.reserve(nLoops + nInputs + nOutputs);
  // Unlike GenericOp, the region also receives the induction variables as
  // leading arguments.
  for (unsigned i = 0; i < nLoops; ++i)
    indexedValues.push_back(allIvs[i]);

  // TODO(mravishankar): Avoid the loads if the corresponding argument of the
  // region has no uses.
  // 1.a. Emit load from input views.
  for (unsigned i = 0; i < nInputs; ++i) {
    auto indexing = makeCanonicalAffineApplies(
        b, loc, indexedGenericOp.getInputIndexingMap(i), allIvs);
    // Pass input i through IndexedValueType emits the proper load operation.
    indexedValues.push_back(
        IndexedValueType(indexedGenericOp.getInput(i))(indexing));
  }
  // 1.b. Emit load from output views.
  for (unsigned i = 0; i < nOutputs; ++i) {
    auto indexing = makeCanonicalAffineApplies(
        b, loc, indexedGenericOp.getOutputIndexingMap(i), allIvs);
    // Pass output i through IndexedValueType emits the proper load operation.
    indexedValues.push_back(
        IndexedValueType(indexedGenericOp.getOutputBuffer(i))(indexing));
  }

  // TODO(ntv): When a region inliner exists, use it.
  // 2. Inline region, currently only works for a single basic block.
  // 3. Emit store.
  SmallVector<SmallVector<Value, 8>, 8> indexing;
  SmallVector<Value, 8> outputBuffers;
  for (unsigned i = 0; i < nOutputs; ++i) {
    indexing.push_back(makeCanonicalAffineApplies(
        b, loc, indexedGenericOp.getOutputIndexingMap(i), allIvs));
    outputBuffers.push_back(indexedGenericOp.getOutputBuffer(i));
  }
  inlineRegionAndEmitStore<IndexedValueType>(indexedGenericOp, indexedValues,
                                             indexing, outputBuffers);
}

/// Generic loop-nest emission for a linalg op with buffer semantics: derives
/// the loop ranges by inverting the concatenated indexing maps applied to the
/// view sizes, generates a `LoopTy` loop nest, and emits the scalar body
/// inside it. Returns llvm::None when the maps are not invertible or an
/// induction variable cannot be traced back to a loop op; returns an empty
/// LinalgLoops for a 0-D (loop-free) op; otherwise returns the created loops.
template <typename LoopTy, typename ConcreteOpTy>
Optional<LinalgLoops> linalgOpToLoopsImpl(Operation *op, OpBuilder &builder) {
  using IndexedValueTy = typename GenerateLoopNest<LoopTy>::IndexedValueTy;

  ScopedContext scope(builder, op->getLoc());

  // The flattened loopToOperandRangesMaps is expected to be an invertible
  // permutation map (which is asserted in the inverse calculation).
  auto linalgOp = cast<ConcreteOpTy>(op);
  assert(linalgOp.hasBufferSemantics() &&
         "expected linalg op with buffer semantics");
  auto mapsRange =
      linalgOp.indexing_maps().template getAsRange<AffineMapAttr>();
  auto maps = llvm::to_vector<8>(
      llvm::map_range(mapsRange, [](AffineMapAttr a) { return a.getValue(); }));
  AffineMap invertedMap = inversePermutation(concatAffineMaps(maps));
  if (!invertedMap)
    return {};
  if (invertedMap.isEmpty()) {
    // 0-D op: no loops to create, emit the scalar body directly.
    emitScalarImplementation<IndexedValueTy>({}, linalgOp);
    return LinalgLoops();
  }

  SmallVector<Value, 4> allIvs;
  auto loopRanges =
      emitLoopRanges(scope.getBuilderRef(), scope.getLocation(), invertedMap,
                     getViewSizes(builder, linalgOp));
  GenerateLoopNest<LoopTy>::doit(
      loopRanges, linalgOp.iterator_types().getValue(), [&](ValueRange ivs) {
        allIvs.append(ivs.begin(), ivs.end());
        emitScalarImplementation<IndexedValueTy>(allIvs, linalgOp);
      });
  // Number of loop ops might be different from the number of ivs since some
  // loops like affine.parallel and scf.parallel have multiple ivs.
  llvm::SetVector<Operation *> loopSet;
  for (Value iv : allIvs) {
    if (!iv)
      return {};
    // The induction variable is a block argument of the entry block of the
    // loop operation.
    BlockArgument ivVal = iv.dyn_cast<BlockArgument>();
    if (!ivVal)
      return {};
    loopSet.insert(ivVal.getOwner()->getParentOp());
  }
  LinalgLoops loops(loopSet.begin(), loopSet.end());
  return loops;
}

/// Pattern rewriting `ConcreteOp` into a `LoopType` loop nest and erasing the
/// original op on success.
template <typename LoopType, typename ConcreteOp>
class LinalgRewritePattern : public RewritePattern {
public:
  explicit LinalgRewritePattern(MLIRContext *context)
      : RewritePattern(ConcreteOp::getOperationName(), 1, context) {}

  LogicalResult matchAndRewrite(Operation *op,
                                PatternRewriter &rewriter) const override {
    if (!linalgOpToLoopsImpl<LoopType, ConcreteOp>(op, rewriter))
      return failure();
    rewriter.eraseOp(op);
    return success();
  }
};

/// Inserts a single LinalgRewritePattern<LoopType, ConcreteOp> into
/// `patterns`.
template <typename LoopType, typename ConcreteOp>
void insertOnePattern(OwningRewritePatternList &patterns, MLIRContext *ctx) {
  patterns.insert<LinalgRewritePattern<LoopType, ConcreteOp>>(ctx);
}

/// Inserts one pattern per op in the `Args` pack (C++14-style pack expansion
/// via an initializer list).
template <typename LoopType, typename... Args>
void insertPatterns(OwningRewritePatternList &patterns, MLIRContext *ctx) {
  (void)std::initializer_list<int>{
      0, (insertOnePattern<LoopType, Args>(patterns, ctx), 0)...};
}

/// Local folding pattern for AffineApplyOp that we can apply greedily.
/// This replaces AffineApplyOp by the proper value in cases where the
/// associated map is trivial.
/// A trivial map here is defined as a map with a single result and either:
///   1. Zero operand + returns a single AffineConstantExpr
///   2. One operand + returns a single AffineDimExpr
///   3. One operand + returns a single AffineSymbolExpr
///
/// In the first case, the AffineApplyOp is replaced by a new constant. In the
/// other cases, it is replaced by its unique operand.
struct FoldAffineOp : public RewritePattern {
  FoldAffineOp(MLIRContext *context)
      : RewritePattern(AffineApplyOp::getOperationName(), 0, context) {}

  LogicalResult matchAndRewrite(Operation *op,
                                PatternRewriter &rewriter) const override {
    AffineApplyOp affineApplyOp = cast<AffineApplyOp>(op);
    auto map = affineApplyOp.getAffineMap();
    // Only single-result maps with at most one input are candidates.
    if (map.getNumResults() != 1 || map.getNumInputs() > 1)
      return failure();

    AffineExpr expr = map.getResult(0);
    if (map.getNumInputs() == 0) {
      // Case 1: a zero-input constant map folds to a constant index.
      if (auto val = expr.dyn_cast<AffineConstantExpr>()) {
        rewriter.replaceOpWithNewOp<ConstantIndexOp>(op, val.getValue());
        return success();
      }
      return failure();
    }
    // Cases 2 and 3: an identity dim/symbol map folds to its unique operand.
    if (expr.dyn_cast<AffineDimExpr>() || expr.dyn_cast<AffineSymbolExpr>()) {
      rewriter.replaceOp(op, op->getOperand(0));
      return success();
    }
    return failure();
  }
};
} // namespace

/// Populates a pattern per Linalg structured op (via the generated op list)
/// lowering it to `LoopType` loops, plus cleanup patterns, and applies them
/// greedily to `funcOp`.
template <typename LoopType>
static void lowerLinalgToLoopsImpl(FuncOp funcOp, MLIRContext *context) {
  OwningRewritePatternList patterns;
  // Canonicalization and folding patterns applied greedily allow cleaning up
  // the emitted IR on the fly.
  // TODO(ntv) fold view and subview ops?
  insertPatterns<LoopType,
#define GET_OP_LIST
#include "mlir/Dialect/Linalg/IR/LinalgStructuredOps.cpp.inc"
                 >(patterns, context);

  DimOp::getCanonicalizationPatterns(patterns, context);
  AffineApplyOp::getCanonicalizationPatterns(patterns, context);
  patterns.insert<FoldAffineOp>(context);
  // Just apply the patterns greedily.
  applyPatternsAndFoldGreedily(funcOp, patterns);
}

namespace {
/// Pass lowering Linalg ops to affine.for loops.
struct LowerToAffineLoops
    : public LinalgLowerToAffineLoopsBase<LowerToAffineLoops> {
  void runOnFunction() override {
    lowerLinalgToLoopsImpl<AffineForOp>(getFunction(), &getContext());
  }
};
/// Pass lowering Linalg ops to scf.for loops.
struct LowerToLoops : public LinalgLowerToLoopsBase<LowerToLoops> {
  void runOnFunction() override {
    lowerLinalgToLoopsImpl<scf::ForOp>(getFunction(), &getContext());
  }
};
/// Pass lowering Linalg ops to scf.parallel loops.
struct LowerToParallelLoops
    : public LinalgLowerToParallelLoopsBase<LowerToParallelLoops> {
  void runOnFunction() override {
    lowerLinalgToLoopsImpl<scf::ParallelOp>(getFunction(), &getContext());
  }
};
} // namespace

std::unique_ptr<OperationPass<FuncOp>> mlir::createConvertLinalgToLoopsPass() {
  return std::make_unique<LowerToLoops>();
}

std::unique_ptr<OperationPass<FuncOp>>
mlir::createConvertLinalgToParallelLoopsPass() {
  return std::make_unique<LowerToParallelLoops>();
}

std::unique_ptr<OperationPass<FuncOp>>
mlir::createConvertLinalgToAffineLoopsPass() {
  return std::make_unique<LowerToAffineLoops>();
}

// TODO: gradually remove this layer as more ops become "named".
/// Dispatches to the `linalgOpToLoopsImpl` instantiation matching the concrete
/// type of `op`. Aborts on an unhandled op kind.
template <typename LoopTy>
Optional<LinalgLoops> linalgOpToLoopsImplSwitch(Operation *op,
                                                OpBuilder &builder) {
  assert(isa<LinalgOp>(op) && "LinalgOp expected");
  if (isa<CopyOp>(op))
    return linalgOpToLoopsImpl<LoopTy, CopyOp>(op, builder);
  if (isa<FillOp>(op))
    return linalgOpToLoopsImpl<LoopTy, FillOp>(op, builder);
  if (isa<DotOp>(op))
    return linalgOpToLoopsImpl<LoopTy, DotOp>(op, builder);
  if (isa<ConvOp>(op))
    return linalgOpToLoopsImpl<LoopTy, ConvOp>(op, builder);
  if (isa<PoolingMaxOp>(op))
    return linalgOpToLoopsImpl<LoopTy, PoolingMaxOp>(op, builder);
  if (isa<PoolingMinOp>(op))
    return linalgOpToLoopsImpl<LoopTy, PoolingMinOp>(op, builder);
  if (isa<PoolingSumOp>(op))
    return linalgOpToLoopsImpl<LoopTy, PoolingSumOp>(op, builder);
  if (isa<IndexedGenericOp>(op))
    return linalgOpToLoopsImpl<LoopTy, IndexedGenericOp>(op, builder);

  // TODO: Cases below are generic and need a LinalgStructuredOpInterface.
  if (isa<GenericOp>(op))
    return linalgOpToLoopsImpl<LoopTy, GenericOp>(op, builder);
  if (isa<MatmulOp>(op))
    return linalgOpToLoopsImpl<LoopTy, MatmulOp>(op, builder);
  if (isa<MatvecOp>(op))
    return linalgOpToLoopsImpl<LoopTy, MatvecOp>(op, builder);
  if (isa<BatchMatmulOp>(op))
    return linalgOpToLoopsImpl<LoopTy, BatchMatmulOp>(op, builder);
  llvm_unreachable("Unexpected op in linalgOpToLoopsImpl");
}

/// Emits a loop nest with the proper body for `op`.
template <typename LoopTy>
Optional<LinalgLoops> mlir::linalg::linalgLowerOpToLoops(OpBuilder &builder,
                                                         Operation *op) {
  return linalgOpToLoopsImplSwitch<LoopTy>(op, builder);
}

// Explicit template instantiations for the three supported loop types.
template Optional<LinalgLoops>
mlir::linalg::linalgLowerOpToLoops<AffineForOp>(OpBuilder &builder,
                                                Operation *op);
template Optional<LinalgLoops>
mlir::linalg::linalgLowerOpToLoops<scf::ForOp>(OpBuilder &builder,
                                               Operation *op);
template Optional<LinalgLoops>
mlir::linalg::linalgLowerOpToLoops<scf::ParallelOp>(OpBuilder &builder,
                                                    Operation *op);

/// Emits a loop nest of `affine.for` with the proper body for `op`.
LogicalResult mlir::linalg::linalgOpToAffineLoops(OpBuilder &builder,
                                                  Operation *op) {
  Optional<LinalgLoops> loops = linalgLowerOpToLoops<AffineForOp>(builder, op);
  return loops ? success() : failure();
}

/// Emits a loop nest of `scf.for` with the proper body for `op`.
LogicalResult mlir::linalg::linalgOpToLoops(OpBuilder &builder, Operation *op) {
  Optional<LinalgLoops> loops = linalgLowerOpToLoops<scf::ForOp>(builder, op);
  return loops ? success() : failure();
}

/// Emits a loop nest of `scf.parallel` with the proper body for `op`.
LogicalResult mlir::linalg::linalgOpToParallelLoops(OpBuilder &builder,
                                                    Operation *op) {
  Optional<LinalgLoops> loops =
      linalgLowerOpToLoops<scf::ParallelOp>(builder, op);
  return loops ? success() : failure();
}