1 //===- Loops.cpp - conversion from Linalg named and generic ops to loops --===// 2 // 3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. 4 // See https://llvm.org/LICENSE.txt for license information. 5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception 6 // 7 //===----------------------------------------------------------------------===// 8 9 #include "PassDetail.h" 10 #include "mlir/Dialect/Affine/EDSC/Intrinsics.h" 11 #include "mlir/Dialect/Linalg/EDSC/FoldedIntrinsics.h" 12 #include "mlir/Dialect/Linalg/IR/LinalgOps.h" 13 #include "mlir/Dialect/Linalg/IR/LinalgTypes.h" 14 #include "mlir/Dialect/Linalg/Passes.h" 15 #include "mlir/Dialect/Linalg/Transforms/Transforms.h" 16 #include "mlir/Dialect/Linalg/Utils/Utils.h" 17 #include "mlir/Dialect/SCF/EDSC/Builders.h" 18 #include "mlir/Dialect/StandardOps/EDSC/Intrinsics.h" 19 #include "mlir/IR/AffineExpr.h" 20 #include "mlir/IR/AffineMap.h" 21 #include "mlir/IR/BlockAndValueMapping.h" 22 #include "mlir/Support/LLVM.h" 23 #include "mlir/Transforms/DialectConversion.h" 24 #include "mlir/Transforms/FoldUtils.h" 25 26 using namespace mlir; 27 using namespace mlir::edsc; 28 using namespace mlir::edsc::intrinsics; 29 using namespace mlir::linalg; 30 31 using edsc::op::operator+; 32 33 static SmallVector<Value, 8> makeCanonicalAffineApplies(OpBuilder &b, 34 Location loc, 35 AffineMap map, 36 ArrayRef<Value> vals) { 37 if (map.isEmpty()) 38 return {}; 39 assert(map.getNumSymbols() == 0); 40 assert(map.getNumInputs() == vals.size()); 41 SmallVector<Value, 8> res; 42 res.reserve(map.getNumResults()); 43 auto dims = map.getNumDims(); 44 for (auto e : map.getResults()) { 45 auto exprMap = AffineMap::get(dims, 0, e); 46 SmallVector<Value, 4> operands(vals.begin(), vals.end()); 47 canonicalizeMapAndOperands(&exprMap, &operands); 48 res.push_back(affine_apply(exprMap, operands)); 49 } 50 return res; 51 } 52 53 static SmallVector<Value, 4> permuteIvs(ArrayRef<Value> ivs, 54 
Optional<AffineMap> permutation) { 55 return permutation ? applyMapToValues(ScopedContext::getBuilderRef(), 56 ScopedContext::getLocation(), 57 permutation.getValue(), ivs) 58 : SmallVector<Value, 4>(ivs.begin(), ivs.end()); 59 } 60 61 // Creates a number of ranges equal to the number of results in `map`. 62 // The returned ranges correspond to the loop ranges, in the proper order, for 63 // which new loops will be created. 64 static SmallVector<SubViewOp::Range, 4> 65 emitLoopRanges(OpBuilder &b, Location loc, AffineMap map, 66 ArrayRef<Value> allViewSizes) { 67 // Apply `map` to get view sizes in loop order. 68 auto sizes = applyMapToValues(b, loc, map, allViewSizes); 69 // Create a new range with the applied tile sizes. 70 ScopedContext scope(b, loc); 71 SmallVector<SubViewOp::Range, 4> res; 72 for (unsigned idx = 0, e = map.getNumResults(); idx < e; ++idx) { 73 res.push_back(SubViewOp::Range{std_constant_index(0), sizes[idx], 74 std_constant_index(1)}); 75 } 76 return res; 77 } 78 79 template <typename IndexedValueType, typename OpType> 80 static void inlineRegionAndEmitStore(OpType op, ArrayRef<Value> indexedValues, 81 ArrayRef<SmallVector<Value, 8>> indexing, 82 ArrayRef<Value> outputBuffers) { 83 assert(op.getOperation()->getNumRegions() == 1 && 84 "Expected single region op"); 85 auto &b = ScopedContext::getBuilderRef(); 86 auto &block = op.region().front(); 87 BlockAndValueMapping map; 88 map.map(block.getArguments(), indexedValues); 89 for (auto &op : block.without_terminator()) { 90 assert(op.getNumRegions() == 0 && "expected a non-nested region"); 91 auto *newOp = b.clone(op, map); 92 map.map(op.getResults(), newOp->getResults()); 93 } 94 95 Operation &terminator = block.back(); 96 assert(isa<YieldOp>(terminator) && 97 "expected a yield op in the end of the region"); 98 for (unsigned i = 0, e = terminator.getNumOperands(); i < e; ++i) { 99 IndexedValueType O(outputBuffers[i]); 100 O(indexing[i]) = map.lookupOrDefault(terminator.getOperand(i)); 101 } 102 
} 103 104 // Returns a pair that contains input indices and output indices of a 105 // SingleInputPoolingOp `op`. 106 struct InputAndOutputIndices { 107 SmallVector<Value, 8> inputs; 108 SmallVector<Value, 8> outputs; 109 }; 110 template <typename SingleInputPoolingOp> 111 static InputAndOutputIndices getInputAndOutputIndices(ArrayRef<Value> allIvs, 112 SingleInputPoolingOp op) { 113 auto &b = ScopedContext::getBuilderRef(); 114 auto loc = ScopedContext::getLocation(); 115 auto mapsRange = op.indexing_maps().template getAsRange<AffineMapAttr>(); 116 auto maps = llvm::to_vector<8>( 117 llvm::map_range(mapsRange, [](AffineMapAttr a) { return a.getValue(); })); 118 return InputAndOutputIndices{ 119 makeCanonicalAffineApplies(b, loc, maps[0], allIvs), 120 makeCanonicalAffineApplies(b, loc, maps[2], allIvs)}; 121 } 122 123 namespace { 124 125 /// Emits the MLIR for the scalar part of the generic op by: 126 /// 1. Emitting load ops for each input and output view in order. This is 127 /// achieved by applying the appropriate input or output map to the 128 /// enclosing induction variables. 129 /// 2. Emitting a call to `op.fun()` that takes as arguments the scalars 130 /// from point 1. above. 131 /// 3. Emitting store ops to store the results of 2. to the output 132 /// views. 
///
/// An example output may resemble:
///
/// ```
///    scf.for %i = %c0 to %0 step %c1 {
///      scf.for %j = %c0 to %1 step %c1 {
///        scf.for %k = %c0 to %4 step %c1 {
///          %11 = load %arg0[%i, %j] :
///            memref<?x?xf32, stride_specification>
///          %12 = load %arg1[%i, %j, %k] :
///            memref<?x?x?xf32, stride_specification>
///          %13 = load %arg2[%i, %k, %j] :
///            memref<?x?x?xf32, stride_specification>
///          %14:2 = call @foo(%11, %12, %13) : (f32, f32, f32) -> (f32, f32)
///          store %14#0, %arg1[%i, %j, %k] :
///            memref<?x?x?xf32, stride_specification>
///          store %14#1, %arg2[%i, %k, %j] :
///            memref<?x?x?xf32, stride_specification>
///       }
///      }
///    }
/// ```
///
/// Note: the implementation below does not emit a call; it clones the body of
/// the op's region in place through `inlineRegionAndEmitStore`.
// TODO: need a LinalgStructuredOpInterface.
template <typename IndexedValueType, typename LinalgStructuredOpType>
void emitScalarImplementation(ArrayRef<Value> allIvs,
                              LinalgStructuredOpType linalgOp) {
  assert(linalgOp.hasBufferSemantics() &&
         "expected linalg op with buffer semantics");
  auto &b = ScopedContext::getBuilderRef();
  auto loc = ScopedContext::getLocation();
  unsigned nInputs = linalgOp.getNumInputs();
  unsigned nOutputs = linalgOp.getNumOutputs();
  // Values bound to the region block arguments: all inputs first, then all
  // outputs.
  SmallVector<Value, 4> indexedValues;
  indexedValues.reserve(nInputs + nOutputs);

  // TODO: Avoid the loads if the corresponding argument of the
  // region has no uses.
  // 1.a. Emit load from input views.
  for (unsigned i = 0; i < nInputs; ++i) {
    auto indexing = makeCanonicalAffineApplies(
        b, loc, linalgOp.getInputIndexingMap(i), allIvs);
    // Passing through IndexedValueType emits the proper load operation.
    indexedValues.push_back(IndexedValueType(linalgOp.getInput(i))(indexing));
  }
  // 1.b. Emit load from output views.
  for (unsigned i = 0; i < nOutputs; ++i) {
    auto indexing = makeCanonicalAffineApplies(
        b, loc, linalgOp.getOutputIndexingMap(i), allIvs);
    // Passing through IndexedValueType emits the proper load operation.
    indexedValues.push_back(
        IndexedValueType(linalgOp.getOutputBuffer(i))(indexing));
  }

  // TODO: When a region inliner exists, use it.
  // 2. Inline region, currently only works for a single basic block.
  // 3. Emit store.
  // The output indices are recomputed here (a second set of affine.apply ops)
  // and handed, together with the output buffers, to the store emission.
  SmallVector<SmallVector<Value, 8>, 8> indexing;
  SmallVector<Value, 8> outputBuffers;
  for (unsigned i = 0; i < nOutputs; ++i) {
    indexing.push_back(makeCanonicalAffineApplies(
        b, loc, linalgOp.getOutputIndexingMap(i), allIvs));
    outputBuffers.push_back(linalgOp.getOutputBuffer(i));
  }
  inlineRegionAndEmitStore<IndexedValueType>(linalgOp, indexedValues, indexing,
                                             outputBuffers);
}

/// Emits the scalar body of a `linalg.copy`: a load from the input at the
/// (optionally permuted) induction variables followed by a store to the output
/// at the (optionally permuted) induction variables. All loops of the copy are
/// expected to be parallel, i.e. `allIvs` holds exactly one value per parallel
/// loop.
template <typename IndexedValueType>
void emitScalarImplementation(ArrayRef<Value> allIvs, CopyOp copyOp) {
  assert(copyOp.hasBufferSemantics() &&
         "expected linalg op with buffer semantics");
  auto nPar = copyOp.getNumParallelLoops();
  assert(nPar == allIvs.size());
  auto inputIvs =
      permuteIvs(allIvs.take_front(nPar), copyOp.inputPermutation());
  auto outputIvs =
      permuteIvs(allIvs.take_front(nPar), copyOp.outputPermutation());
  SmallVector<Value, 8> iivs(inputIvs.begin(), inputIvs.end());
  SmallVector<Value, 8> oivs(outputIvs.begin(), outputIvs.end());
  IndexedValueType O(copyOp.getOutputBuffer(0)), I(copyOp.getInput(0));
  // Emit the proper scalar assignment, whether we are dealing with a 0-D or
  // an n-D loop nest; with or without permutations.
  // The conditional expression is used as a statement: exactly one of the two
  // assignments is evaluated, and evaluating it emits the load and the store.
  // clang-format off
  nPar > 0 ? O(oivs) = I(iivs) :
             O()     = I();
  // clang-format on
}

/// Emits the scalar body of a `linalg.fill`: a store of the fill value to the
/// output at the induction variables. `allIvs` is expected to hold exactly one
/// value per parallel loop of the op.
template <typename IndexedValueType>
void emitScalarImplementation(ArrayRef<Value> allIvs, FillOp fillOp) {
  assert(fillOp.hasBufferSemantics() &&
         "expected linalg op with buffer semantics");
  auto nPar = fillOp.getNumParallelLoops();
  assert(nPar == allIvs.size());
  auto ivs = SmallVector<Value, 4>(allIvs.begin(), allIvs.begin() + nPar);
  IndexedValueType O(fillOp.getOutputBuffer(0));
  // Emit the proper scalar assignment, whether we are dealing with a 0-D or
  // an n-D loop nest. No permutation is applied to the induction variables.
  nPar > 0 ? O(ivs) = fillOp.value() : O() = fillOp.value();
}

/// Emits the scalar body of a `linalg.dot`: accumulates the product of the two
/// inputs, both indexed by the single induction variable, into the 0-D output.
template <typename IndexedValueType>
void emitScalarImplementation(ArrayRef<Value> allIvs, DotOp dotOp) {
  assert(dotOp.hasBufferSemantics() &&
         "expected linalg op with buffer semantics");
  assert(allIvs.size() == 1);
  Value r_i(allIvs[0]);
  IndexedValueType A(dotOp.getInput(0)), B(dotOp.getInput(1)),
      C(dotOp.getOutputBuffer(0));
  // Emit scalar form.
  C() = C() + A(r_i) * B(r_i);
}

/// Returns the value read from the convolution input `im` at indices `imIdx`,
/// taking padding into account.
///
/// When `convOp` has no `padding` attribute, this is a plain load. Otherwise,
/// for every index except the first and the last one:
///   - an "out of bounds" condition (`index < 0` or `index >= dim`) is OR-ed
///     into a running condition, and
///   - the index used for the load is clamped from below with
///     `affine.max(index, 0)`.
/// The result is then a select between a zero constant of the element type
/// (when the running condition holds) and the loaded value.
///
/// The `IndexedValueType` template parameter is not used by the body; the load
/// always goes through `StdIndexedValue`.
template <typename IndexedValueType>
Value getConvOpInput(ConvOp convOp, StdIndexedValue im,
                     MutableArrayRef<Value> imIdx) {
  // TODO: add a level of indirection to linalg.generic.
  if (!convOp.padding())
    return im(imIdx);

  auto *context = ScopedContext::getContext();
  Value zeroIndex = std_constant_index(0);
  // `conds.back()` always holds the disjunction of all the out-of-bounds
  // checks emitted so far.
  SmallVector<Value, 8> conds;
  SmallVector<Value, 8> clampedImIdx;
  for (auto iter : llvm::enumerate(imIdx)) {
    int idx = iter.index();
    auto dim = iter.value();
    // Only need to iterate over the window dimensions.
    if (idx == 0 || idx == static_cast<int>(imIdx.size()) - 1) {
      clampedImIdx.push_back(dim);
      continue;
    }

    using edsc::op::sge;
    using edsc::op::slt;
    using edsc::op::operator||;
    Value leftOutOfBound = slt(dim, zeroIndex);
    if (conds.empty())
      conds.push_back(leftOutOfBound);
    else
      conds.push_back(conds.back() || leftOutOfBound);
    Value rightBound = std_dim(convOp.input(), idx);
    conds.push_back(conds.back() || (sge(dim, rightBound)));

    // When padding is involved, the indices will only be shifted to negative,
    // so having a max op is enough.
    auto maxMap = AffineMap::get(/*dimCount=*/1, 0,
                                 {getAffineDimExpr(/*position=*/0, context),
                                  getAffineConstantExpr(0, context)},
                                 context);
    clampedImIdx.push_back(affine_max(dim.getType(), maxMap, ValueRange{dim}));
  }

  auto &b = ScopedContext::getBuilderRef();
  Type type = convOp.input().getType().cast<MemRefType>().getElementType();
  Value zero = std_constant(type, b.getZeroAttr(type));
  Value readInput = im(clampedImIdx);
  // Without any window dimension there is nothing to guard against.
  return conds.empty() ? readInput
                       : (Value)std_select(conds.back(), zero, readInput);
}

/// Returns true if `convOp` has a non-zero padding.
295 static bool hasPadding(ConvOp convOp) { 296 for (unsigned i = 0, e = convOp.getNumSpatialDimensions(); i < e; ++i) { 297 if (convOp.getLowPad(i) > 0 || convOp.getHighPad(i) > 0) 298 return true; 299 } 300 return false; 301 } 302 303 template <typename IndexedValueType> 304 static void emitScalarImplementation(ArrayRef<Value> allIvs, ConvOp convOp) { 305 assert(convOp.hasBufferSemantics() && 306 "expected linalg op with buffer semantics"); 307 auto &b = ScopedContext::getBuilderRef(); 308 auto loc = ScopedContext::getLocation(); 309 auto mapsRange = convOp.indexing_maps().getAsRange<AffineMapAttr>(); 310 auto maps = llvm::to_vector<8>( 311 llvm::map_range(mapsRange, [](AffineMapAttr a) { return a.getValue(); })); 312 SmallVector<Value, 8> fIdx( 313 makeCanonicalAffineApplies(b, loc, maps[0], allIvs)); 314 SmallVector<Value, 8> imIdx( 315 makeCanonicalAffineApplies(b, loc, maps[1], allIvs)); 316 SmallVector<Value, 8> oIdx( 317 makeCanonicalAffineApplies(b, loc, maps[2], allIvs)); 318 319 IndexedValueType F(convOp.filter()), O(convOp.output()); 320 321 // Emit scalar form. Padded conv involves an affine.max in the memory access 322 // which is not allowed by affine.load. Override to use an StdIndexedValue 323 // when there is non-zero padding. 324 if (hasPadding(convOp)) { 325 StdIndexedValue I(convOp.input()); 326 Value paddedInput = getConvOpInput<IndexedValueType>(convOp, I, imIdx); 327 O(oIdx) += F(fIdx) * paddedInput; 328 } else { 329 IndexedValueType I(convOp.input()); 330 O(oIdx) += F(fIdx) * I(imIdx); 331 } 332 } 333 334 template <typename IndexedValueType> 335 void emitScalarImplementation(ArrayRef<Value> allIvs, PoolingMaxOp op) { 336 InputAndOutputIndices indices = getInputAndOutputIndices(allIvs, op); 337 // Emit scalar form. 
338 IndexedValueType output(op.output()); 339 IndexedValueType input(op.input()); 340 Value lhs = output(indices.outputs); 341 Value rhs = input(indices.inputs); 342 using edsc::op::sgt; 343 Value maxValue = std_select(sgt(lhs, rhs), lhs, rhs); 344 output(indices.outputs) = maxValue; 345 } 346 347 template <typename IndexedValueType> 348 void emitScalarImplementation(ArrayRef<Value> allIvs, PoolingMinOp op) { 349 InputAndOutputIndices indices = getInputAndOutputIndices(allIvs, op); 350 // Emit scalar form. 351 IndexedValueType output(op.output()); 352 IndexedValueType input(op.input()); 353 Value lhs = output(indices.outputs); 354 Value rhs = input(indices.inputs); 355 using edsc::op::slt; 356 Value minValue = std_select(slt(lhs, rhs), lhs, rhs); 357 output(indices.outputs) = minValue; 358 } 359 template <typename IndexedValueType> 360 void emitScalarImplementation(ArrayRef<Value> allIvs, PoolingSumOp op) { 361 auto indices = getInputAndOutputIndices(allIvs, op); 362 IndexedValueType input(op.input()), output(op.output()); 363 364 // Emit scalar form. 365 output(indices.outputs) += input(indices.inputs); 366 } 367 /// Emits the MLIR for the scalar part of the indexed generic op by: 368 /// 1. Emitting load ops for each input and output view in order. This is 369 /// achieved by applying the appropriate input or output map to the 370 /// enclosing induction variables. 371 /// 2. Emitting a call to `op.fun()` that takes as arguments the induction 372 /// variables and the scalars from point 1. above. 373 /// 3. Emitting store ops to store the results of 2. to the output views. 
///
/// An example output may resemble:
///
/// ```
///    scf.for %i = %c0 to %0 step %c1 {
///      scf.for %j = %c0 to %1 step %c1 {
///        scf.for %k = %c0 to %4 step %c1 {
///          %11 = load %arg0[%i, %j] :
///            memref<?x?xf32, stride_specification>
///          %12 = load %arg1[%i, %j, %k] :
///            memref<?x?x?xf32, stride_specification>
///          %13 = load %arg2[%i, %k, %j] :
///            memref<?x?x?xf32, stride_specification>
///          %14:2 = call @foo(%i, %j, %k, %11, %12, %13) :
///            (index, index, index, f32, f32, f32) -> (f32, f32)
///          store %14#0, %arg1[%i, %j, %k] :
///            memref<?x?x?xf32, stride_specification>
///          store %14#1, %arg2[%i, %k, %j] :
///            memref<?x?x?xf32, stride_specification>
///       }
///      }
///    }
/// ```
///
/// Note: the implementation below does not emit a call; it clones the body of
/// the op's region in place through `inlineRegionAndEmitStore`.
template <typename IndexedValueType>
static void emitScalarImplementation(ArrayRef<Value> allIvs,
                                     IndexedGenericOp indexedGenericOp) {
  assert(indexedGenericOp.hasBufferSemantics() &&
         "expected linalg op with buffer semantics");
  auto &b = ScopedContext::getBuilderRef();
  auto loc = ScopedContext::getLocation();
  unsigned nInputs = indexedGenericOp.getNumInputs();
  unsigned nOutputs = indexedGenericOp.getNumOutputs();
  unsigned nLoops = allIvs.size();
  // Values bound to the region block arguments: the induction variables
  // first, then all inputs, then all outputs.
  SmallVector<Value, 4> indexedValues;
  indexedValues.reserve(nLoops + nInputs + nOutputs);
  for (unsigned i = 0; i < nLoops; ++i)
    indexedValues.push_back(allIvs[i]);

  // TODO: Avoid the loads if the corresponding argument of the
  // region has no uses.
  // 1.a. Emit load from input views.
  for (unsigned i = 0; i < nInputs; ++i) {
    auto indexing = makeCanonicalAffineApplies(
        b, loc, indexedGenericOp.getInputIndexingMap(i), allIvs);
    // Passing input i through IndexedValueType emits the proper load
    // operation.
    indexedValues.push_back(
        IndexedValueType(indexedGenericOp.getInput(i))(indexing));
  }
  // 1.b. Emit load from output views.
  for (unsigned i = 0; i < nOutputs; ++i) {
    auto indexing = makeCanonicalAffineApplies(
        b, loc, indexedGenericOp.getOutputIndexingMap(i), allIvs);
    // Passing output i through IndexedValueType emits the proper load
    // operation.
    indexedValues.push_back(
        IndexedValueType(indexedGenericOp.getOutputBuffer(i))(indexing));
  }

  // TODO: When a region inliner exists, use it.
  // 2. Inline region, currently only works for a single basic block.
  // 3. Emit store.
  // The output indices are recomputed here (a second set of affine.apply ops)
  // and handed, together with the output buffers, to the store emission.
  SmallVector<SmallVector<Value, 8>, 8> indexing;
  SmallVector<Value, 8> outputBuffers;
  for (unsigned i = 0; i < nOutputs; ++i) {
    indexing.push_back(makeCanonicalAffineApplies(
        b, loc, indexedGenericOp.getOutputIndexingMap(i), allIvs));
    outputBuffers.push_back(indexedGenericOp.getOutputBuffer(i));
  }
  inlineRegionAndEmitStore<IndexedValueType>(indexedGenericOp, indexedValues,
                                             indexing, outputBuffers);
}

/// Emits, with `builder` and at the location of `op`, the loop nest of type
/// `LoopTy` that implements `op`, which must be a `ConcreteOpTy` with buffer
/// semantics. `op` itself is left untouched.
///
/// Returns:
///   - llvm::None when the concatenated indexing maps of the op cannot be
///     inverted, or when one of the collected induction variables is null or
///     is not a block argument (in the latter cases the loop nest has already
///     been emitted);
///   - an empty LinalgLoops when the inverted map is empty; the scalar body is
///     then emitted directly, without any enclosing loop;
///   - otherwise the set of operations owning the induction variables, in
///     order of first appearance.
template <typename LoopTy, typename ConcreteOpTy>
Optional<LinalgLoops> linalgOpToLoopsImpl(Operation *op, OpBuilder &builder) {
  using IndexedValueTy = typename GenerateLoopNest<LoopTy>::IndexedValueTy;

  // Everything emitted through EDSC intrinsics below uses this builder and
  // location.
  ScopedContext scope(builder, op->getLoc());

  // The flattened loopToOperandRangesMaps is expected to be an invertible
  // permutation map (which is asserted in the inverse calculation).
  auto linalgOp = cast<ConcreteOpTy>(op);
  assert(linalgOp.hasBufferSemantics() &&
         "expected linalg op with buffer semantics");
  auto mapsRange =
      linalgOp.indexing_maps().template getAsRange<AffineMapAttr>();
  auto maps = llvm::to_vector<8>(
      llvm::map_range(mapsRange, [](AffineMapAttr a) { return a.getValue(); }));
  AffineMap invertedMap = inversePermutation(concatAffineMaps(maps));
  if (!invertedMap)
    return {};
  if (invertedMap.isEmpty()) {
    // No loop to create: emit the scalar body with no induction variable.
    emitScalarImplementation<IndexedValueTy>({}, linalgOp);
    return LinalgLoops();
  }

  SmallVector<Value, 4> allIvs;
  auto loopRanges =
      emitLoopRanges(scope.getBuilderRef(), scope.getLocation(), invertedMap,
                     getViewSizes(builder, linalgOp));
  // The callback records the induction variables it is handed and emits the
  // scalar body.
  GenerateLoopNest<LoopTy>::doit(
      loopRanges, linalgOp.iterator_types().getValue(), [&](ValueRange ivs) {
        allIvs.append(ivs.begin(), ivs.end());
        emitScalarImplementation<IndexedValueTy>(allIvs, linalgOp);
      });
  // Number of loop ops might be different from the number of ivs since some
  // loops like affine.parallel and scf.parallel have multiple ivs.
  llvm::SetVector<Operation *> loopSet;
  for (Value iv : allIvs) {
    if (!iv)
      return {};
    // The induction variable is a block argument of the entry block of the
    // loop operation.
    BlockArgument ivVal = iv.dyn_cast<BlockArgument>();
    if (!ivVal)
      return {};
    loopSet.insert(ivVal.getOwner()->getParentOp());
  }
  LinalgLoops loops(loopSet.begin(), loopSet.end());
  return loops;
}

/// Rewrite pattern (benefit 1) rooted at `ConcreteOp` that lowers the matched
/// op to a loop nest of type `LoopType` and then erases the op. The pattern
/// fails, without erasing the op, when `linalgOpToLoopsImpl` returns
/// llvm::None.
template <typename LoopType, typename ConcreteOp>
class LinalgRewritePattern : public RewritePattern {
public:
  explicit LinalgRewritePattern(MLIRContext *context)
      : RewritePattern(ConcreteOp::getOperationName(), 1, context) {}

  LogicalResult matchAndRewrite(Operation *op,
                                PatternRewriter &rewriter) const override {
    if (!linalgOpToLoopsImpl<LoopType, ConcreteOp>(op, rewriter))
      return failure();
    rewriter.eraseOp(op);
    return success();
  }
};

/// Adds the lowering pattern of `ConcreteOp` to loops of type `LoopType` to
/// `patterns`.
template <typename LoopType, typename ConcreteOp>
void insertOnePattern(OwningRewritePatternList &patterns, MLIRContext *ctx) {
  patterns.insert<LinalgRewritePattern<LoopType, ConcreteOp>>(ctx);
}

/// Adds one lowering pattern per op type in `Args` to `patterns`.
template <typename LoopType, typename... Args>
void insertPatterns(OwningRewritePatternList &patterns, MLIRContext *ctx) {
  // Expanding the pack inside a braced initializer list calls
  // `insertOnePattern` once per op type, in pack order; the leading 0 keeps
  // the list non-empty when `Args` is empty.
  (void)std::initializer_list<int>{
      0, (insertOnePattern<LoopType, Args>(patterns, ctx), 0)...};
}

/// Local folding pattern for AffineApplyOp that we can apply greedily.
/// This replaces AffineApplyOp by the proper value in cases where the
/// associated map is trivial.
/// A trivial map here is defined as a map with a single result and either:
///   1. Zero operand + returns a single AffineConstantExpr
///   2. One operand + returns a single AffineDimExpr
///   3. One operand + returns a single AffineSymbolExpr
///
/// In the first case, the AffineApplyOp is replaced by a new constant. In the
/// other cases, it is replaced by its unique operand.
530 struct FoldAffineOp : public RewritePattern { 531 FoldAffineOp(MLIRContext *context) 532 : RewritePattern(AffineApplyOp::getOperationName(), 0, context) {} 533 534 LogicalResult matchAndRewrite(Operation *op, 535 PatternRewriter &rewriter) const override { 536 AffineApplyOp affineApplyOp = cast<AffineApplyOp>(op); 537 auto map = affineApplyOp.getAffineMap(); 538 if (map.getNumResults() != 1 || map.getNumInputs() > 1) 539 return failure(); 540 541 AffineExpr expr = map.getResult(0); 542 if (map.getNumInputs() == 0) { 543 if (auto val = expr.dyn_cast<AffineConstantExpr>()) { 544 rewriter.replaceOpWithNewOp<ConstantIndexOp>(op, val.getValue()); 545 return success(); 546 } 547 return failure(); 548 } 549 if (expr.dyn_cast<AffineDimExpr>() || expr.dyn_cast<AffineSymbolExpr>()) { 550 rewriter.replaceOp(op, op->getOperand(0)); 551 return success(); 552 } 553 return failure(); 554 } 555 }; 556 } // namespace 557 558 template <typename LoopType> 559 static void lowerLinalgToLoopsImpl(FuncOp funcOp, MLIRContext *context) { 560 OwningRewritePatternList patterns; 561 // Canonicalization and folding patterns applied greedily allow cleaning up 562 // the emitted IR on the fly. 563 // TODO: fold view and subview ops? 564 insertPatterns<LoopType, 565 #define GET_OP_LIST 566 #include "mlir/Dialect/Linalg/IR/LinalgStructuredOps.cpp.inc" 567 >(patterns, context); 568 569 DimOp::getCanonicalizationPatterns(patterns, context); 570 AffineApplyOp::getCanonicalizationPatterns(patterns, context); 571 patterns.insert<FoldAffineOp>(context); 572 // Just apply the patterns greedily. 
573 applyPatternsAndFoldGreedily(funcOp, patterns); 574 } 575 576 namespace { 577 struct LowerToAffineLoops 578 : public LinalgLowerToAffineLoopsBase<LowerToAffineLoops> { 579 void runOnFunction() override { 580 lowerLinalgToLoopsImpl<AffineForOp>(getFunction(), &getContext()); 581 } 582 }; 583 struct LowerToLoops : public LinalgLowerToLoopsBase<LowerToLoops> { 584 void runOnFunction() override { 585 lowerLinalgToLoopsImpl<scf::ForOp>(getFunction(), &getContext()); 586 } 587 }; 588 struct LowerToParallelLoops 589 : public LinalgLowerToParallelLoopsBase<LowerToParallelLoops> { 590 void runOnFunction() override { 591 lowerLinalgToLoopsImpl<scf::ParallelOp>(getFunction(), &getContext()); 592 } 593 }; 594 } // namespace 595 596 std::unique_ptr<OperationPass<FuncOp>> mlir::createConvertLinalgToLoopsPass() { 597 return std::make_unique<LowerToLoops>(); 598 } 599 600 std::unique_ptr<OperationPass<FuncOp>> 601 mlir::createConvertLinalgToParallelLoopsPass() { 602 return std::make_unique<LowerToParallelLoops>(); 603 } 604 605 std::unique_ptr<OperationPass<FuncOp>> 606 mlir::createConvertLinalgToAffineLoopsPass() { 607 return std::make_unique<LowerToAffineLoops>(); 608 } 609 610 // TODO: gradually remove this layer as more ops become "named". 
611 template <typename LoopTy> 612 Optional<LinalgLoops> linalgOpToLoopsImplSwitch(Operation *op, 613 OpBuilder &builder) { 614 assert(isa<LinalgOp>(op) && "LinalgOp expected"); 615 if (isa<CopyOp>(op)) 616 return linalgOpToLoopsImpl<LoopTy, CopyOp>(op, builder); 617 if (isa<FillOp>(op)) 618 return linalgOpToLoopsImpl<LoopTy, FillOp>(op, builder); 619 if (isa<DotOp>(op)) 620 return linalgOpToLoopsImpl<LoopTy, DotOp>(op, builder); 621 if (isa<ConvOp>(op)) 622 return linalgOpToLoopsImpl<LoopTy, ConvOp>(op, builder); 623 if (isa<PoolingMaxOp>(op)) 624 return linalgOpToLoopsImpl<LoopTy, PoolingMaxOp>(op, builder); 625 if (isa<PoolingMinOp>(op)) 626 return linalgOpToLoopsImpl<LoopTy, PoolingMinOp>(op, builder); 627 if (isa<PoolingSumOp>(op)) 628 return linalgOpToLoopsImpl<LoopTy, PoolingSumOp>(op, builder); 629 if (isa<IndexedGenericOp>(op)) 630 return linalgOpToLoopsImpl<LoopTy, IndexedGenericOp>(op, builder); 631 632 // TODO: Cases below are generic and need a LinalgStructuredOpInterface. 633 if (isa<GenericOp>(op)) 634 return linalgOpToLoopsImpl<LoopTy, GenericOp>(op, builder); 635 if (isa<MatmulOp>(op)) 636 return linalgOpToLoopsImpl<LoopTy, MatmulOp>(op, builder); 637 if (isa<MatvecOp>(op)) 638 return linalgOpToLoopsImpl<LoopTy, MatvecOp>(op, builder); 639 if (isa<BatchMatmulOp>(op)) 640 return linalgOpToLoopsImpl<LoopTy, BatchMatmulOp>(op, builder); 641 llvm_unreachable("Unexpected op in linalgOpToLoopsImpl"); 642 } 643 644 /// Emits a loop nest with the proper body for `op`. 
645 template <typename LoopTy> 646 Optional<LinalgLoops> mlir::linalg::linalgLowerOpToLoops(OpBuilder &builder, 647 Operation *op) { 648 return linalgOpToLoopsImplSwitch<LoopTy>(op, builder); 649 } 650 651 template Optional<LinalgLoops> 652 mlir::linalg::linalgLowerOpToLoops<AffineForOp>(OpBuilder &builder, 653 Operation *op); 654 template Optional<LinalgLoops> 655 mlir::linalg::linalgLowerOpToLoops<scf::ForOp>(OpBuilder &builder, 656 Operation *op); 657 template Optional<LinalgLoops> 658 mlir::linalg::linalgLowerOpToLoops<scf::ParallelOp>(OpBuilder &builder, 659 Operation *op); 660 661 /// Emits a loop nest of `affine.for` with the proper body for `op`. 662 LogicalResult mlir::linalg::linalgOpToAffineLoops(OpBuilder &builder, 663 Operation *op) { 664 Optional<LinalgLoops> loops = linalgLowerOpToLoops<AffineForOp>(builder, op); 665 return loops ? success() : failure(); 666 } 667 668 /// Emits a loop nest of `scf.for` with the proper body for `op`. 669 LogicalResult mlir::linalg::linalgOpToLoops(OpBuilder &builder, Operation *op) { 670 Optional<LinalgLoops> loops = linalgLowerOpToLoops<scf::ForOp>(builder, op); 671 return loops ? success() : failure(); 672 } 673 674 /// Emits a loop nest of `scf.parallel` with the proper body for `op`. 675 LogicalResult mlir::linalg::linalgOpToParallelLoops(OpBuilder &builder, 676 Operation *op) { 677 Optional<LinalgLoops> loops = 678 linalgLowerOpToLoops<scf::ParallelOp>(builder, op); 679 return loops ? success() : failure(); 680 } 681