//===- Loops.cpp - conversion from Linalg named and generic ops to loops --===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//

#include "PassDetail.h"
#include "mlir/Dialect/Affine/EDSC/Intrinsics.h"
#include "mlir/Dialect/Linalg/EDSC/FoldedIntrinsics.h"
#include "mlir/Dialect/Linalg/IR/LinalgOps.h"
#include "mlir/Dialect/Linalg/IR/LinalgTypes.h"
#include "mlir/Dialect/Linalg/Passes.h"
#include "mlir/Dialect/Linalg/Transforms/Transforms.h"
#include "mlir/Dialect/Linalg/Utils/Utils.h"
#include "mlir/Dialect/SCF/EDSC/Builders.h"
#include "mlir/Dialect/StandardOps/EDSC/Intrinsics.h"
#include "mlir/IR/AffineExpr.h"
#include "mlir/IR/AffineMap.h"
#include "mlir/IR/BlockAndValueMapping.h"
#include "mlir/Support/LLVM.h"
#include "mlir/Transforms/DialectConversion.h"
#include "mlir/Transforms/FoldUtils.h"

using namespace mlir;
using namespace mlir::edsc;
using namespace mlir::edsc::intrinsics;
using namespace mlir::linalg;

using edsc::op::operator+;

static SmallVector<Value, 8> makeCanonicalAffineApplies(OpBuilder &b,
                                                        Location loc,
                                                        AffineMap map,
                                                        ArrayRef<Value> vals) {
  if (map.isEmpty())
    return {};
  assert(map.getNumSymbols() == 0);
  assert(map.getNumInputs() == vals.size());
  SmallVector<Value, 8> res;
  res.reserve(map.getNumResults());
  auto dims = map.getNumDims();
  for (auto e : map.getResults()) {
    auto exprMap = AffineMap::get(dims, 0, e);
    SmallVector<Value, 4> operands(vals.begin(), vals.end());
    canonicalizeMapAndOperands(&exprMap, &operands);
    res.push_back(affine_apply(exprMap, operands));
  }
  return res;
}

static SmallVector<Value, 4> permuteIvs(ArrayRef<Value> ivs,
                                        Optional<AffineMap> permutation) {
  return permutation ? applyMapToValues(ScopedContext::getBuilderRef(),
                                        ScopedContext::getLocation(),
                                        permutation.getValue(), ivs)
                     : SmallVector<Value, 4>(ivs.begin(), ivs.end());
}

// Creates a number of ranges equal to the number of results in `map`.
// The returned ranges correspond to the loop ranges, in the proper order, for
// which new loops will be created.
static SmallVector<SubViewOp::Range, 4>
emitLoopRanges(OpBuilder &b, Location loc, AffineMap map,
               ArrayRef<Value> allViewSizes) {
  // Apply `map` to get view sizes in loop order.
  auto sizes = applyMapToValues(b, loc, map, allViewSizes);
  // Create a new range with the applied sizes.
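  // Each returned range is a unit-step range [0, size), i.e. one loop per
  // result of `map`.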
  ScopedContext scope(b, loc);
  SmallVector<SubViewOp::Range, 4> res;
  for (unsigned idx = 0, e = map.getNumResults(); idx < e; ++idx) {
    res.push_back(SubViewOp::Range{std_constant_index(0), sizes[idx],
                                   std_constant_index(1)});
  }
  return res;
}

template <typename IndexedValueType, typename OpType>
static void inlineRegionAndEmitStore(OpType op, ArrayRef<Value> indexedValues,
                                     ArrayRef<SmallVector<Value, 8>> indexing,
                                     ArrayRef<Value> outputBuffers) {
  auto &b = ScopedContext::getBuilderRef();
  auto &block = op.region().front();
  BlockAndValueMapping map;
  map.map(block.getArguments(), indexedValues);
  for (auto &op : block.without_terminator()) {
    assert(op.getNumRegions() == 0 && "expected a non-nested region");
    auto *newOp = b.clone(op, map);
    map.map(op.getResults(), newOp->getResults());
  }

  Operation &terminator = block.back();
  assert(isa<YieldOp>(terminator) &&
         "expected a yield op at the end of the region");
  for (unsigned i = 0, e = terminator.getNumOperands(); i < e; ++i) {
    IndexedValueType O(outputBuffers[i]);
    O(indexing[i]) = map.lookupOrDefault(terminator.getOperand(i));
  }
}

// Input and output indices of a SingleInputPoolingOp, computed from the
// enclosing loop induction variables by getInputAndOutputIndices below.
struct InputAndOutputIndices {
  SmallVector<Value, 8> inputs;
  SmallVector<Value, 8> outputs;
};
template <typename SingleInputPoolingOp>
static InputAndOutputIndices getInputAndOutputIndices(ArrayRef<Value> allIvs,
                                                      SingleInputPoolingOp op) {
  auto &b = ScopedContext::getBuilderRef();
  auto loc = ScopedContext::getLocation();
  auto mapsRange = op.indexing_maps().template getAsRange<AffineMapAttr>();
  auto maps = llvm::to_vector<8>(
      llvm::map_range(mapsRange, [](AffineMapAttr a) { return a.getValue(); }));
  return InputAndOutputIndices{
      makeCanonicalAffineApplies(b, loc, maps[0], allIvs),
      makeCanonicalAffineApplies(b, loc, maps[2], allIvs)};
}

namespace {

/// Emits the MLIR for the scalar part of the generic op by:
///   1. Emitting load ops for each input and output view in order. This is
///      achieved by applying the appropriate input or output map to the
///      enclosing induction variables.
///   2. Emitting a call to `op.fun()` that takes as arguments the scalars
///      from point 1. above.
///   3. Emitting store ops to store the results of 2. to the output
///      views.
///
/// An example output may resemble:
///
/// ```
///    scf.for %i = %c0 to %0 step %c1 {
///      scf.for %j = %c0 to %1 step %c1 {
///        scf.for %k = %c0 to %4 step %c1 {
///          %11 = load %arg0[%i, %j] :
///            memref<?x?xf32, stride_specification>
///          %12 = load %arg1[%i, %j, %k] :
///            memref<?x?x?xf32, stride_specification>
///          %13 = load %arg2[%i, %k, %j] :
///            memref<?x?x?xf32, stride_specification>
///          %14:2 = call @foo(%11, %12, %13) : (f32, f32, f32) -> (f32, f32)
///          store %14#0, %arg1[%i, %j, %k] :
///            memref<?x?x?xf32, stride_specification>
///          store %14#1, %arg2[%i, %k, %j] :
///            memref<?x?x?xf32, stride_specification>
///        }
///      }
///    }
/// ```
template <typename IndexedValueType, typename LinalgOpType>
class LinalgScopedEmitter {
public:
  static void emitScalarImplementation(ArrayRef<Value> allIvs,
                                       LinalgOpType linalgOp) {
    assert(linalgOp.hasBufferSemantics() &&
           "expected linalg op with buffer semantics");
    auto &b = ScopedContext::getBuilderRef();
    auto loc = ScopedContext::getLocation();
    unsigned nInputs = linalgOp.getNumInputs();
    unsigned nOutputs = linalgOp.getNumOutputs();
    SmallVector<Value, 4> indexedValues;
    indexedValues.reserve(nInputs + nOutputs);

    // TODO(mravishankar): Avoid the loads if the corresponding argument of the
    // region has no uses.
    // 1.a. Emit load from input views.
    for (unsigned i = 0; i < nInputs; ++i) {
      auto indexing = makeCanonicalAffineApplies(
          b, loc, linalgOp.getInputIndexingMap(i), allIvs);
      // Passing through IndexedValueType emits the proper load operation.
      indexedValues.push_back(IndexedValueType(linalgOp.getInput(i))(indexing));
    }
    // 1.b. Emit load from output views.
    for (unsigned i = 0; i < nOutputs; ++i) {
      auto indexing = makeCanonicalAffineApplies(
          b, loc, linalgOp.getOutputIndexingMap(i), allIvs);
      // Passing through IndexedValueType emits the proper load operation.
      indexedValues.push_back(
          IndexedValueType(linalgOp.getOutputBuffer(i))(indexing));
    }

    // TODO(ntv): When a region inliner exists, use it.
    // 2. Inline region, currently only works for a single basic block.
    // 3. Emit store.
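    // Gather the output indexings and buffers; the region's yielded values are
    // stored back through them by inlineRegionAndEmitStore below.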
    SmallVector<SmallVector<Value, 8>, 8> indexing;
    SmallVector<Value, 8> outputBuffers;
    for (unsigned i = 0; i < nOutputs; ++i) {
      indexing.push_back(makeCanonicalAffineApplies(
          b, loc, linalgOp.getOutputIndexingMap(i), allIvs));
      outputBuffers.push_back(linalgOp.getOutputBuffer(i));
    }
    inlineRegionAndEmitStore<IndexedValueType>(linalgOp, indexedValues,
                                               indexing, outputBuffers);
  }
};

template <typename IndexedValueType>
class LinalgScopedEmitter<IndexedValueType, CopyOp> {
public:
  static void emitScalarImplementation(ArrayRef<Value> allIvs, CopyOp copyOp) {
    assert(copyOp.hasBufferSemantics() &&
           "expected linalg op with buffer semantics");
    auto nPar = copyOp.getNumParallelLoops();
    assert(nPar == allIvs.size());
    auto inputIvs =
        permuteIvs(allIvs.take_front(nPar), copyOp.inputPermutation());
    auto outputIvs =
        permuteIvs(allIvs.take_front(nPar), copyOp.outputPermutation());
    SmallVector<Value, 8> iivs(inputIvs.begin(), inputIvs.end());
    SmallVector<Value, 8> oivs(outputIvs.begin(), outputIvs.end());
    IndexedValueType O(copyOp.getOutputBuffer(0)), I(copyOp.getInput(0));
    // Emit the proper scalar assignment, whether we are dealing with a 0-D or
    // an n-D loop nest; with or without permutations.
    // clang-format off
    nPar > 0 ? O(oivs) = I(iivs) :
               O() = I();
    // clang-format on
  }
};

template <typename IndexedValueType>
class LinalgScopedEmitter<IndexedValueType, FillOp> {
public:
  static void emitScalarImplementation(ArrayRef<Value> allIvs, FillOp fillOp) {
    assert(fillOp.hasBufferSemantics() &&
           "expected linalg op with buffer semantics");
    auto nPar = fillOp.getNumParallelLoops();
    assert(nPar == allIvs.size());
    auto ivs = SmallVector<Value, 4>(allIvs.begin(), allIvs.begin() + nPar);
    IndexedValueType O(fillOp.getOutputBuffer(0));
    // Emit the proper scalar assignment, whether we are dealing with a 0-D or
    // an n-D loop nest.
    nPar > 0 ? O(ivs) = fillOp.value() : O() = fillOp.value();
  }
};

template <typename IndexedValueType>
class LinalgScopedEmitter<IndexedValueType, DotOp> {
public:
  static void emitScalarImplementation(ArrayRef<Value> allIvs, DotOp dotOp) {
    assert(dotOp.hasBufferSemantics() &&
           "expected linalg op with buffer semantics");
    assert(allIvs.size() == 1);
    Value r_i(allIvs[0]);
    IndexedValueType A(dotOp.getInput(0)), B(dotOp.getInput(1)),
        C(dotOp.getOutputBuffer(0));
    // Emit scalar form.
    C() = C() + A(r_i) * B(r_i);
  }
};

template <typename IndexedValueType>
class LinalgScopedEmitter<IndexedValueType, MatvecOp> {
public:
  static void emitScalarImplementation(ArrayRef<Value> allIvs,
                                       MatvecOp matvecOp) {
    assert(matvecOp.hasBufferSemantics() &&
           "expected linalg op with buffer semantics");
    assert(allIvs.size() == 2);
    Value i(allIvs[0]), r_j(allIvs[1]);
    IndexedValueType A(matvecOp.getInput(0)), B(matvecOp.getInput(1)),
        C(matvecOp.getOutputBuffer(0));
    // Emit scalar form.
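    // By convention in this file, `r_`-prefixed induction variables denote
    // reduction dimensions; the update below accumulates into C(i) over r_j.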
    C(i) = C(i) + A(i, r_j) * B(r_j);
  }
};

template <typename IndexedValueType>
class LinalgScopedEmitter<IndexedValueType, MatmulOp> {
public:
  static void emitScalarImplementation(ArrayRef<Value> allIvs,
                                       MatmulOp matmulOp) {
    assert(matmulOp.hasBufferSemantics() &&
           "expected linalg op with buffer semantics");
    assert(allIvs.size() == 3);
    Value i(allIvs[0]), j(allIvs[1]), r_k(allIvs[2]);
    IndexedValueType A(matmulOp.getInput(0)), B(matmulOp.getInput(1)),
        C(matmulOp.getOutputBuffer(0));
    // Emit scalar form.
    C(i, j) = C(i, j) + A(i, r_k) * B(r_k, j);
  }
};

template <typename IndexedValueType>
class LinalgScopedEmitter<IndexedValueType, ConvOp> {
public:
  /// Returns the input value of convOp. If the indices in `imIdx` are out of
  /// bounds, returns 0 instead.
  static Value getConvOpInput(ConvOp convOp, StdIndexedValue im,
                              MutableArrayRef<Value> imIdx) {
    // TODO(ntv): add a level of indirection to linalg.generic.
    if (!convOp.padding())
      return im(imIdx);

    auto *context = ScopedContext::getContext();
    Value zeroIndex = std_constant_index(0);
    SmallVector<Value, 8> conds;
    SmallVector<Value, 8> clampedImIdx;
    for (auto iter : llvm::enumerate(imIdx)) {
      int idx = iter.index();
      auto dim = iter.value();
      // Only need to iterate over the window dimensions.
      if (idx == 0 || idx == static_cast<int>(imIdx.size()) - 1) {
        clampedImIdx.push_back(dim);
        continue;
      }

      using edsc::op::operator<;
      using edsc::op::operator>=;
      using edsc::op::operator||;
      Value leftOutOfBound = dim < zeroIndex;
      if (conds.empty())
        conds.push_back(leftOutOfBound);
      else
        conds.push_back(conds.back() || leftOutOfBound);
      Value rightBound = std_dim(convOp.input(), idx);
      conds.push_back(conds.back() || (dim >= rightBound));

      // When padding is involved, the indices will only be shifted toward
      // negative values, so a max op is enough.
      auto maxMap = AffineMap::get(/*dimCount=*/1, 0,
                                   {getAffineDimExpr(/*position=*/0, context),
                                    getAffineConstantExpr(0, context)},
                                   context);
      clampedImIdx.push_back(
          affine_max(dim.getType(), maxMap, ValueRange{dim}));
    }

    auto &b = ScopedContext::getBuilderRef();
    Type type = convOp.input().getType().cast<MemRefType>().getElementType();
    Value zero = std_constant(type, b.getZeroAttr(type));
    Value readInput = im(clampedImIdx);
    return conds.empty() ? readInput
                         : (Value)std_select(conds.back(), zero, readInput);
  }

  static void emitScalarImplementation(ArrayRef<Value> allIvs, ConvOp convOp) {
    assert(convOp.hasBufferSemantics() &&
           "expected linalg op with buffer semantics");
    auto &b = ScopedContext::getBuilderRef();
    auto loc = ScopedContext::getLocation();
    auto mapsRange = convOp.indexing_maps().getAsRange<AffineMapAttr>();
    auto maps = llvm::to_vector<8>(llvm::map_range(
        mapsRange, [](AffineMapAttr a) { return a.getValue(); }));
    SmallVector<Value, 8> fIdx(
        makeCanonicalAffineApplies(b, loc, maps[0], allIvs));
    SmallVector<Value, 8> imIdx(
        makeCanonicalAffineApplies(b, loc, maps[1], allIvs));
    SmallVector<Value, 8> oIdx(
        makeCanonicalAffineApplies(b, loc, maps[2], allIvs));

    // Padded conv involves an affine.max in the memory access which is not
    // allowed by affine.load. Override to always use a StdIndexedValue.
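    // The input value is read through getConvOpInput so that, when padding is
    // present, out-of-bounds window accesses yield the zero padding value.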
    StdIndexedValue I(convOp.input());
    IndexedValueType F(convOp.filter()), O(convOp.output());

    // Emit scalar form.
    Value paddedInput = getConvOpInput(convOp, I, imIdx);
    O(oIdx) += F(fIdx) * paddedInput;
  }
};

template <typename IndexedValueType>
class LinalgScopedEmitter<IndexedValueType, PoolingMaxOp> {
public:
  static void emitScalarImplementation(ArrayRef<Value> allIvs,
                                       PoolingMaxOp op) {
    auto indices = getInputAndOutputIndices(allIvs, op);
    // Emit scalar form.
    Value lhs = std_load(op.output(), indices.outputs);
    Value rhs = std_load(op.input(), indices.inputs);
    using edsc::op::operator>;
    Value maxValue = std_select(lhs > rhs, lhs, rhs);
    std_store(maxValue, op.output(), indices.outputs);
  }
};

template <typename IndexedValueType>
class LinalgScopedEmitter<IndexedValueType, PoolingMinOp> {
public:
  static void emitScalarImplementation(ArrayRef<Value> allIvs,
                                       PoolingMinOp op) {
    auto indices = getInputAndOutputIndices(allIvs, op);
    // Emit scalar form.
    Value lhs = std_load(op.output(), indices.outputs);
    Value rhs = std_load(op.input(), indices.inputs);
    using edsc::op::operator<;
    Value minValue = std_select(lhs < rhs, lhs, rhs);
    std_store(minValue, op.output(), indices.outputs);
  }
};

template <typename IndexedValueType>
class LinalgScopedEmitter<IndexedValueType, PoolingSumOp> {
public:
  static void emitScalarImplementation(ArrayRef<Value> allIvs,
                                       PoolingSumOp op) {
    auto indices = getInputAndOutputIndices(allIvs, op);
    IndexedValueType input(op.input()), output(op.output());

    // Emit scalar form.
    output(indices.outputs) += input(indices.inputs);
  }
};

/// Emits the MLIR for the scalar part of the indexed generic op by:
///   1. Emitting load ops for each input and output view in order. This is
///      achieved by applying the appropriate input or output map to the
///      enclosing induction variables.
///   2. Emitting a call to `op.fun()` that takes as arguments the induction
///      variables and the scalars from point 1. above.
///   3. Emitting store ops to store the results of 2. to the output views.
///
/// An example output may resemble:
///
/// ```
///    scf.for %i = %c0 to %0 step %c1 {
///      scf.for %j = %c0 to %1 step %c1 {
///        scf.for %k = %c0 to %4 step %c1 {
///          %11 = load %arg0[%i, %j] :
///            memref<?x?xf32, stride_specification>
///          %12 = load %arg1[%i, %j, %k] :
///            memref<?x?x?xf32, stride_specification>
///          %13 = load %arg2[%i, %k, %j] :
///            memref<?x?x?xf32, stride_specification>
///          %14:2 = call @foo(%i, %j, %k, %11, %12, %13) :
///            (index, index, index, f32, f32, f32) -> (f32, f32)
///          store %14#0, %arg1[%i, %j, %k] :
///            memref<?x?x?xf32, stride_specification>
///          store %14#1, %arg2[%i, %k, %j] :
///            memref<?x?x?xf32, stride_specification>
///        }
///      }
///    }
/// ```
template <typename IndexedValueType>
class LinalgScopedEmitter<IndexedValueType, IndexedGenericOp> {
public:
  static void emitScalarImplementation(ArrayRef<Value> allIvs,
                                       IndexedGenericOp indexedGenericOp) {
    assert(indexedGenericOp.hasBufferSemantics() &&
           "expected linalg op with buffer semantics");
    auto &b = ScopedContext::getBuilderRef();
    auto loc = ScopedContext::getLocation();
    unsigned nInputs = indexedGenericOp.getNumInputs();
    unsigned nOutputs = indexedGenericOp.getNumOutputs();
    unsigned nLoops = allIvs.size();
    SmallVector<Value, 4> indexedValues;
    indexedValues.reserve(nLoops + nInputs + nOutputs);
    for (unsigned i = 0; i < nLoops; ++i)
      indexedValues.push_back(allIvs[i]);

    // TODO(mravishankar): Avoid the loads if the corresponding argument of the
    // region has no uses.
    // 1.a. Emit load from input views.
    for (unsigned i = 0; i < nInputs; ++i) {
      auto indexing = makeCanonicalAffineApplies(
          b, loc, indexedGenericOp.getInputIndexingMap(i), allIvs);
      // Passing input i through IndexedValueType emits the proper load
      // operation.
      indexedValues.push_back(
          IndexedValueType(indexedGenericOp.getInput(i))(indexing));
    }
    // 1.b. Emit load from output views.
    for (unsigned i = 0; i < nOutputs; ++i) {
      auto indexing = makeCanonicalAffineApplies(
          b, loc, indexedGenericOp.getOutputIndexingMap(i), allIvs);
      // Passing output i through IndexedValueType emits the proper load
      // operation.
      indexedValues.push_back(
          IndexedValueType(indexedGenericOp.getOutputBuffer(i))(indexing));
    }

    // TODO(ntv): When a region inliner exists, use it.
    // 2. Inline region, currently only works for a single basic block.
    // 3. Emit store.
    SmallVector<SmallVector<Value, 8>, 8> indexing;
    SmallVector<Value, 8> outputBuffers;
    for (unsigned i = 0; i < nOutputs; ++i) {
      indexing.push_back(makeCanonicalAffineApplies(
          b, loc, indexedGenericOp.getOutputIndexingMap(i), allIvs));
      outputBuffers.push_back(indexedGenericOp.getOutputBuffer(i));
    }
    inlineRegionAndEmitStore<IndexedValueType>(indexedGenericOp, indexedValues,
                                               indexing, outputBuffers);
  }
};

template <typename LoopTy, typename ConcreteOpTy>
Optional<LinalgLoops> linalgOpToLoopsImpl(Operation *op, OpBuilder &builder) {
  using IndexedValueTy = typename GenerateLoopNest<LoopTy>::IndexedValueTy;

  ScopedContext scope(builder, op->getLoc());

  // The flattened loopToOperandRangesMaps is expected to be an invertible
  // permutation map (which is asserted in the inverse calculation).
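  // If no inverse exists, the op cannot be lowered to loops here; if the
  // inverse is empty, the op is zero-dimensional and its scalar body is
  // emitted directly, without enclosing loops.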
  auto linalgOp = cast<ConcreteOpTy>(op);
  assert(linalgOp.hasBufferSemantics() &&
         "expected linalg op with buffer semantics");
  auto nPar = linalgOp.getNumParallelLoops();
  auto nRed = linalgOp.getNumReductionLoops();
  auto nWin = linalgOp.getNumWindowLoops();
  auto nLoops = nPar + nRed + nWin;
  auto mapsRange =
      linalgOp.indexing_maps().template getAsRange<AffineMapAttr>();
  auto maps = llvm::to_vector<8>(
      llvm::map_range(mapsRange, [](AffineMapAttr a) { return a.getValue(); }));
  AffineMap invertedMap = inversePermutation(concatAffineMaps(maps));
  if (!invertedMap)
    return {};
  if (invertedMap.isEmpty()) {
    LinalgScopedEmitter<IndexedValueTy, ConcreteOpTy>::emitScalarImplementation(
        {}, linalgOp);
    return LinalgLoops();
  }

  SmallVector<Value, 4> allIvs(nLoops);
  auto loopRanges =
      emitLoopRanges(scope.getBuilderRef(), scope.getLocation(), invertedMap,
                     getViewSizes(builder, linalgOp));
  assert(loopRanges.size() == allIvs.size());
  GenerateLoopNest<LoopTy>::doit(
      allIvs, loopRanges, linalgOp.iterator_types().getValue(), [&] {
        SmallVector<Value, 4> allIvValues(allIvs.begin(), allIvs.end());
        LinalgScopedEmitter<IndexedValueTy,
                            ConcreteOpTy>::emitScalarImplementation(allIvValues,
                                                                    linalgOp);
      });
  // Number of loop ops might be different from the number of ivs since some
  // loops like affine.parallel and scf.parallel have multiple ivs.
  llvm::SetVector<Operation *> loopSet;
  for (Value iv : allIvs) {
    if (!iv)
      return {};
    // The induction variable is a block argument of the entry block of the
    // loop operation.
    BlockArgument ivVal = iv.dyn_cast<BlockArgument>();
    if (!ivVal)
      return {};
    loopSet.insert(ivVal.getOwner()->getParentOp());
  }
  LinalgLoops loops(loopSet.begin(), loopSet.end());
  return loops;
}

template <typename LoopType, typename ConcreteOp>
class LinalgRewritePattern : public RewritePattern {
public:
  explicit LinalgRewritePattern(MLIRContext *context)
      : RewritePattern(ConcreteOp::getOperationName(), 1, context) {}

  LogicalResult matchAndRewrite(Operation *op,
                                PatternRewriter &rewriter) const override {
    if (!linalgOpToLoopsImpl<LoopType, ConcreteOp>(op, rewriter))
      return failure();
    rewriter.eraseOp(op);
    return success();
  }
};

/// Helper classes for type list expansion.
template <typename LoopType, typename... LinalgOps>
class RewritePatternList;

template <typename LoopType>
class RewritePatternList<LoopType> {
public:
  static void build(OwningRewritePatternList &patterns, MLIRContext *ctx) {}
};

template <typename LoopType, typename ConcreteOp, typename... LinalgOps>
class RewritePatternList<LoopType, ConcreteOp, LinalgOps...> {
public:
  static void build(OwningRewritePatternList &patterns, MLIRContext *ctx) {
    patterns.insert<LinalgRewritePattern<LoopType, ConcreteOp>>(ctx);
    RewritePatternList<LoopType, LinalgOps...>::build(patterns, ctx);
  }
};

/// Populate the given list with patterns that convert from Linalg to loops.
template <typename LoopType>
void FillRewritePatterns(OwningRewritePatternList &patterns, MLIRContext *ctx) {
  RewritePatternList<LoopType,
#define GET_OP_LIST
#include "mlir/Dialect/Linalg/IR/LinalgStructuredOps.cpp.inc"
                     >::build(patterns, ctx);
}

/// Local folding pattern for AffineApplyOp that we can apply greedily.
/// This replaces AffineApplyOp by the proper value in cases where the
/// associated map is trivial.
/// A trivial map here is defined as a map with a single result and either:
///   1. Zero operands + returns a single AffineConstantExpr
///   2. One operand + returns a single AffineDimExpr
///   3. One operand + returns a single AffineSymbolExpr
///
/// In the first case, the AffineApplyOp is replaced by a new constant. In the
/// other cases, it is replaced by its unique operand.
struct FoldAffineOp : public RewritePattern {
  FoldAffineOp(MLIRContext *context)
      : RewritePattern(AffineApplyOp::getOperationName(), 0, context) {}

  LogicalResult matchAndRewrite(Operation *op,
                                PatternRewriter &rewriter) const override {
    AffineApplyOp affineApplyOp = cast<AffineApplyOp>(op);
    auto map = affineApplyOp.getAffineMap();
    if (map.getNumResults() != 1 || map.getNumInputs() > 1)
      return failure();

    AffineExpr expr = map.getResult(0);
    if (map.getNumInputs() == 0) {
      if (auto val = expr.dyn_cast<AffineConstantExpr>()) {
        rewriter.replaceOpWithNewOp<ConstantIndexOp>(op, val.getValue());
        return success();
      }
      return failure();
    }
    if (expr.dyn_cast<AffineDimExpr>() || expr.dyn_cast<AffineSymbolExpr>()) {
      rewriter.replaceOp(op, op->getOperand(0));
      return success();
    }
    return failure();
  }
};
} // namespace

template <typename LoopType>
static void lowerLinalgToLoopsImpl(Operation *op, MLIRContext *context) {
  OwningRewritePatternList patterns;
  // Canonicalization and folding patterns applied greedily allow cleaning up
  // the emitted IR on the fly.
  // TODO(ntv) fold view and subview ops?
  FillRewritePatterns<LoopType>(patterns, context);
  DimOp::getCanonicalizationPatterns(patterns, context);
  AffineApplyOp::getCanonicalizationPatterns(patterns, context);
  patterns.insert<FoldAffineOp>(context);
  // Just apply the patterns greedily.
  applyPatternsAndFoldGreedily(op, patterns);
}

namespace {
struct LowerToAffineLoops
    : public LinalgLowerToAffineLoopsBase<LowerToAffineLoops> {
  void runOnFunction() override {
    lowerLinalgToLoopsImpl<AffineForOp>(getFunction(), &getContext());
  }
};
struct LowerToLoops : public LinalgLowerToLoopsBase<LowerToLoops> {
  void runOnFunction() override {
    lowerLinalgToLoopsImpl<scf::ForOp>(getFunction(), &getContext());
  }
};
struct LowerToParallelLoops
    : public LinalgLowerToParallelLoopsBase<LowerToParallelLoops> {
  void runOnFunction() override {
    lowerLinalgToLoopsImpl<scf::ParallelOp>(getFunction(), &getContext());
  }
};
} // namespace

std::unique_ptr<OperationPass<FuncOp>> mlir::createConvertLinalgToLoopsPass() {
  return std::make_unique<LowerToLoops>();
}

std::unique_ptr<OperationPass<FuncOp>>
mlir::createConvertLinalgToParallelLoopsPass() {
  return std::make_unique<LowerToParallelLoops>();
}

std::unique_ptr<OperationPass<FuncOp>>
mlir::createConvertLinalgToAffineLoopsPass() {
  return std::make_unique<LowerToAffineLoops>();
}

/// Emits a loop nest with the proper body for `op`.
template <typename LoopTy, typename ConcreteOp>
Optional<LinalgLoops> mlir::linalg::linalgLowerOpToLoops(OpBuilder &builder,
                                                         Operation *op) {
  return linalgOpToLoopsImpl<LoopTy, ConcreteOp>(op, builder);
}

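// The convenience entry points below each pick a specific loop type; they are
// explicitly instantiated at the end of this file for the currently supported
// ops.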
/// Emits a loop nest of `scf.for` with the proper body for `op`.
template <typename ConcreteOp>
LogicalResult mlir::linalg::linalgOpToLoops(OpBuilder &builder, Operation *op) {
  Optional<LinalgLoops> loops =
      linalgLowerOpToLoops<scf::ForOp, ConcreteOp>(builder, op);
  return loops ? success() : failure();
}

/// Emits a loop nest of `affine.for` with the proper body for `op`.
template <typename ConcreteOp>
LogicalResult mlir::linalg::linalgOpToAffineLoops(OpBuilder &builder,
                                                  Operation *op) {
  Optional<LinalgLoops> loops =
      linalgLowerOpToLoops<AffineForOp, ConcreteOp>(builder, op);
  return loops ? success() : failure();
}

/// Emits a loop nest of `scf.parallel` with the proper body for `op`.
template <typename ConcreteOp>
LogicalResult mlir::linalg::linalgOpToParallelLoops(OpBuilder &builder,
                                                    Operation *op) {
  Optional<LinalgLoops> loops =
      linalgLowerOpToLoops<scf::ParallelOp, ConcreteOp>(builder, op);
  return loops ? success() : failure();
}

// TODO: Make these instantiations more future-proof to avoid the need to
// update them as soon as we add new ops.
#define INSTANTIATE_LINALG_OP_TO_LOOPS(OP_TYPE)                                \
  template LogicalResult mlir::linalg::linalgOpToLoops<OP_TYPE>(              \
      OpBuilder & builder, Operation * op);                                    \
  template LogicalResult mlir::linalg::linalgOpToAffineLoops<OP_TYPE>(         \
      OpBuilder & builder, Operation * op);                                    \
  template LogicalResult mlir::linalg::linalgOpToParallelLoops<OP_TYPE>(       \
      OpBuilder & builder, Operation * op);                                    \
  template Optional<LinalgLoops>                                               \
  mlir::linalg::linalgLowerOpToLoops<scf::ParallelOp, OP_TYPE>(                \
      OpBuilder & builder, Operation * op);

INSTANTIATE_LINALG_OP_TO_LOOPS(CopyOp)
INSTANTIATE_LINALG_OP_TO_LOOPS(FillOp)
INSTANTIATE_LINALG_OP_TO_LOOPS(DotOp)
INSTANTIATE_LINALG_OP_TO_LOOPS(MatvecOp)
INSTANTIATE_LINALG_OP_TO_LOOPS(MatmulOp)
INSTANTIATE_LINALG_OP_TO_LOOPS(ConvOp)
INSTANTIATE_LINALG_OP_TO_LOOPS(PoolingMaxOp)
INSTANTIATE_LINALG_OP_TO_LOOPS(PoolingMinOp)
INSTANTIATE_LINALG_OP_TO_LOOPS(PoolingSumOp)
INSTANTIATE_LINALG_OP_TO_LOOPS(GenericOp)
INSTANTIATE_LINALG_OP_TO_LOOPS(IndexedGenericOp)
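
// Illustrative usage sketch (not part of this file's API; `b` and `op` are
// assumed to be an OpBuilder and a registered linalg.matmul operation with
// buffer semantics):
//   if (failed(mlir::linalg::linalgOpToLoops<MatmulOp>(b, op)))
//     return failure();
// Alternatively, the whole-function lowerings above can be run as passes, e.g.
// by adding createConvertLinalgToLoopsPass() to a pass pipeline over FuncOp.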