1 //===- Loops.cpp - conversion from Linalg named and generic ops to loops --===// 2 // 3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. 4 // See https://llvm.org/LICENSE.txt for license information. 5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception 6 // 7 //===----------------------------------------------------------------------===// 8 9 #include "PassDetail.h" 10 #include "mlir/Dialect/Affine/EDSC/Intrinsics.h" 11 #include "mlir/Dialect/Linalg/EDSC/FoldedIntrinsics.h" 12 #include "mlir/Dialect/Linalg/IR/LinalgOps.h" 13 #include "mlir/Dialect/Linalg/IR/LinalgTypes.h" 14 #include "mlir/Dialect/Linalg/Passes.h" 15 #include "mlir/Dialect/Linalg/Transforms/Transforms.h" 16 #include "mlir/Dialect/Linalg/Utils/Utils.h" 17 #include "mlir/Dialect/SCF/EDSC/Builders.h" 18 #include "mlir/Dialect/StandardOps/EDSC/Intrinsics.h" 19 #include "mlir/IR/AffineExpr.h" 20 #include "mlir/IR/AffineMap.h" 21 #include "mlir/IR/BlockAndValueMapping.h" 22 #include "mlir/Support/LLVM.h" 23 #include "mlir/Transforms/DialectConversion.h" 24 #include "mlir/Transforms/FoldUtils.h" 25 26 using namespace mlir; 27 using namespace mlir::edsc; 28 using namespace mlir::edsc::intrinsics; 29 using namespace mlir::linalg; 30 31 using edsc::op::operator+; 32 33 static SmallVector<Value, 8> makeCanonicalAffineApplies(OpBuilder &b, 34 Location loc, 35 AffineMap map, 36 ArrayRef<Value> vals) { 37 if (map.isEmpty()) 38 return {}; 39 40 assert(map.getNumInputs() == vals.size()); 41 SmallVector<Value, 8> res; 42 res.reserve(map.getNumResults()); 43 auto dims = map.getNumDims(); 44 for (auto e : map.getResults()) { 45 auto exprMap = AffineMap::get(dims, map.getNumSymbols(), e); 46 SmallVector<Value, 4> operands(vals.begin(), vals.end()); 47 canonicalizeMapAndOperands(&exprMap, &operands); 48 res.push_back(affine_apply(exprMap, operands)); 49 } 50 return res; 51 } 52 53 static SmallVector<Value, 4> permuteIvs(ArrayRef<Value> ivs, 54 Optional<AffineMap> permutation) { 55 return permutation ? applyMapToValues(ScopedContext::getBuilderRef(), 56 ScopedContext::getLocation(), 57 permutation.getValue(), ivs) 58 : SmallVector<Value, 4>(ivs.begin(), ivs.end()); 59 } 60 61 /// Creates a number of ranges equal to the number of dimensions in the `map`. 62 /// The returned ranges correspond to the loop ranges, in the proper order, for 63 /// which new loops will be created. 64 /// The function supports only maps that are invertible and have results of type 65 /// DimExpr or (DimExpr + DimExpr - SymbolExpr floordiv ConstExpr). 66 /// It expects a non-inverted, concatenated map and last values in 67 /// allViewSizes will be applied to the symbols in the map if it contains any. 68 static SmallVector<Range, 4> emitLoopRanges(OpBuilder &b, Location loc, 69 AffineMap map, 70 ValueRange viewSizes) { 71 unsigned numDims = map.getNumDims(), numRes = map.getNumResults(); 72 unsigned numSym = map.getNumSymbols(); 73 assert(viewSizes.size() == numRes + numSym && 74 "viewSizes must contain sizes of all views and values for symbols"); 75 SmallVector<Range, 4> res(numDims); 76 for (unsigned idx = 0; idx < numRes; ++idx) { 77 auto result = map.getResult(idx); 78 if (auto d = result.dyn_cast<AffineDimExpr>()) { 79 if (res[d.getPosition()].offset) 80 continue; 81 res[d.getPosition()] = 82 Range{std_constant_index(0), viewSizes[idx], std_constant_index(1)}; 83 } 84 85 // If the access pattern is of form (m, n)[s] -> (m + n - s floordiv 2), 86 // then the bounds are: 87 // (s floordiv 2) <= m <= (size(m) + s floordiv 2 - s + 1). 88 // where size(n) is applied to the symbol s. 89 // This is done statically now. 90 if (auto binOp = result.dyn_cast<AffineBinaryOpExpr>()) { 91 auto lhs = binOp.getLHS().dyn_cast<AffineBinaryOpExpr>(); 92 auto rhs = binOp.getRHS().dyn_cast<AffineBinaryOpExpr>(); 93 if (!lhs || !rhs || binOp.getKind() != AffineExprKind::Add || 94 lhs.getKind() != AffineExprKind::Add || 95 rhs.getKind() != mlir::AffineExprKind::Mul) 96 continue; 97 98 auto m = lhs.getLHS().dyn_cast<AffineDimExpr>(); 99 auto n = lhs.getRHS().dyn_cast<AffineDimExpr>(); 100 auto fDiv = rhs.getLHS().dyn_cast<AffineBinaryOpExpr>(); 101 auto minusOne = rhs.getRHS().dyn_cast<AffineConstantExpr>(); 102 if (!m || !n || !fDiv || !minusOne || 103 fDiv.getKind() != AffineExprKind::FloorDiv || 104 fDiv.getLHS().getKind() != AffineExprKind::SymbolId || 105 fDiv.getRHS().getKind() != AffineExprKind::Constant) 106 continue; 107 108 auto s = fDiv.getLHS().dyn_cast<AffineSymbolExpr>(); 109 if (minusOne.getValue() != -1) 110 continue; 111 112 int mPos = m.getPosition(); 113 AffineExpr one = getAffineConstantExpr(1, s.getContext()); 114 AffineExpr sizeOfM = getAffineSymbolExpr(numSym, s.getContext()); 115 // Construction of upper bound (size(m) + s floordiv 2 - s + 1). 116 AffineExpr upperOffsetExpr = sizeOfM + fDiv + one - s; 117 AffineMap fromMap = AffineMap::get(numDims, numSym + 1, fDiv); 118 AffineMap toMap = AffineMap::get(numDims, numSym + 1, upperOffsetExpr); 119 SmallVector<Value, 8> values(viewSizes.begin(), 120 viewSizes.begin() + numDims); 121 values.insert(values.end(), viewSizes.begin() + numRes, viewSizes.end()); 122 values.push_back(viewSizes[mPos]); 123 // Construction of the lower bound (s floordiv 2). 124 Value from = applyMapToValues(b, loc, fromMap, values).front(); 125 Value to = applyMapToValues(b, loc, toMap, values).front(); 126 res[mPos] = Range{from, to, std_constant_index(1)}; 127 } 128 } 129 return res; 130 } 131 132 template <typename IndexedValueType, typename OpType> 133 static void inlineRegionAndEmitStore(OpType op, ArrayRef<Value> indexedValues, 134 ArrayRef<SmallVector<Value, 8>> indexing, 135 ArrayRef<Value> outputBuffers) { 136 assert(op.getOperation()->getNumRegions() == 1 && 137 "Expected single region op"); 138 auto &b = ScopedContext::getBuilderRef(); 139 auto &block = op.region().front(); 140 BlockAndValueMapping map; 141 map.map(block.getArguments(), indexedValues); 142 for (auto &op : block.without_terminator()) { 143 assert(op.getNumRegions() == 0 && "expected a non-nested region"); 144 auto *newOp = b.clone(op, map); 145 map.map(op.getResults(), newOp->getResults()); 146 } 147 148 Operation &terminator = block.back(); 149 assert(isa<linalg::YieldOp>(terminator) && 150 "expected a yield op in the end of the region"); 151 for (unsigned i = 0, e = terminator.getNumOperands(); i < e; ++i) { 152 IndexedValueType O(outputBuffers[i]); 153 O(indexing[i]) = map.lookupOrDefault(terminator.getOperand(i)); 154 } 155 } 156 157 // Returns a pair that contains input indices and output indices of a 158 // SingleInputPoolingOp `op`. 159 struct InputAndOutputIndices { 160 SmallVector<Value, 8> inputs; 161 SmallVector<Value, 8> outputs; 162 }; 163 template <typename SingleInputPoolingOp> 164 static InputAndOutputIndices getInputAndOutputIndices(ArrayRef<Value> allIvs, 165 SingleInputPoolingOp op) { 166 auto &b = ScopedContext::getBuilderRef(); 167 auto loc = ScopedContext::getLocation(); 168 auto mapsRange = op.indexing_maps().template getAsRange<AffineMapAttr>(); 169 auto maps = llvm::to_vector<8>( 170 llvm::map_range(mapsRange, [](AffineMapAttr a) { return a.getValue(); })); 171 return InputAndOutputIndices{ 172 makeCanonicalAffineApplies(b, loc, maps[0], allIvs), 173 makeCanonicalAffineApplies(b, loc, maps[2], allIvs)}; 174 } 175 176 namespace { 177 178 /// Emits the MLIR for the scalar part of the generic op by: 179 /// 1. Emitting load ops for each input and output view in order. This is 180 /// achieved by applying the appropriate input or output map to the 181 /// enclosing induction variables. 182 /// 2. Emitting a call to `op.fun()` that takes as arguments the scalars 183 /// from point 1. above. 184 /// 3. Emitting store ops to store the results of 2. to the output 185 /// views. 186 /// 187 /// An example output may resemble: 188 /// 189 /// ``` 190 /// scf.for %i = %c0 to %0 step %c1 { 191 /// scf.for %j = %c0 to %1 step %c1 { 192 /// scf.for %k = %c0 to %4 step %c1 { 193 /// %11 = load %arg0[%i, %j] : 194 /// memref<?x?xf32, stride_specification> 195 /// %12 = load %arg1[%i, %j, %k] : 196 /// memref<?x?x?xf32, stride_specification> 197 /// %13 = load %arg2[%i, %k, %j] : 198 /// memref<?x?x?xf32, stride_specification> 199 /// %14:2 = call @foo(%11, %12, %13) : (f32, f32, f32) -> (f32, f32) 200 /// store %14#0, %arg1[%i, %j, %k] : 201 /// memref<?x?x?Xf32, stride_specification> 202 /// store %14#1, %arg2[%i, %k, %j] : 203 /// memref<?x?x?Xf32, stride_specification> 204 /// } 205 /// } 206 /// } 207 /// ``` 208 // TODO: need a LinalgStructuredOpInterface. 209 template <typename IndexedValueType, typename LinalgStructuredOpType> 210 void emitScalarImplementation(ArrayRef<Value> allIvs, 211 LinalgStructuredOpType linalgOp) { 212 assert(linalgOp.hasBufferSemantics() && 213 "expected linalg op with buffer semantics"); 214 auto &b = ScopedContext::getBuilderRef(); 215 auto loc = ScopedContext::getLocation(); 216 unsigned nInputs = linalgOp.getNumInputs(); 217 unsigned nOutputs = linalgOp.getNumOutputs(); 218 SmallVector<Value, 4> indexedValues; 219 indexedValues.reserve(nInputs + nOutputs); 220 221 auto attr = linalgOp.template getAttrOfType<IntegerAttr>("symbol_source"); 222 auto allIvsPlusDims = SmallVector<Value, 4>(allIvs.begin(), allIvs.end()); 223 if (attr) { 224 auto operand = linalgOp.getOperand(attr.getInt()); 225 auto shapedType = operand.getType().template cast<ShapedType>(); 226 allIvsPlusDims.reserve(allIvs.size() + shapedType.getRank()); 227 for (unsigned idx = 0, e = shapedType.getRank(); idx < e; ++idx) 228 allIvsPlusDims.push_back(b.create<DimOp>(loc, operand, idx)); 229 } 230 231 // TODO: Avoid the loads if the corresponding argument of the 232 // region has no uses. 233 // 1.a. Emit load from input views. 234 for (unsigned i = 0; i < nInputs; ++i) { 235 auto indexing = makeCanonicalAffineApplies( 236 b, loc, linalgOp.getInputIndexingMap(i), allIvsPlusDims); 237 // Passing through IndexedValueType emits the proper load operation. 238 indexedValues.push_back(IndexedValueType(linalgOp.getInput(i))(indexing)); 239 } 240 // 1.b. Emit load from output views. 241 for (unsigned i = 0; i < nOutputs; ++i) { 242 auto indexing = makeCanonicalAffineApplies( 243 b, loc, linalgOp.getOutputIndexingMap(i), allIvsPlusDims); 244 // Passing through IndexedValueType emits the proper load operation. 245 indexedValues.push_back( 246 IndexedValueType(linalgOp.getOutputBuffer(i))(indexing)); 247 } 248 249 // TODO: When a region inliner exists, use it. 250 // 2. Inline region, currently only works for a single basic block. 251 // 3. Emit store. 252 SmallVector<SmallVector<Value, 8>, 8> indexing; 253 SmallVector<Value, 8> outputBuffers; 254 for (unsigned i = 0; i < nOutputs; ++i) { 255 indexing.push_back(makeCanonicalAffineApplies( 256 b, loc, linalgOp.getOutputIndexingMap(i), allIvsPlusDims)); 257 outputBuffers.push_back(linalgOp.getOutputBuffer(i)); 258 } 259 inlineRegionAndEmitStore<IndexedValueType>(linalgOp, indexedValues, indexing, 260 outputBuffers); 261 } 262 263 template <typename IndexedValueType> 264 void emitScalarImplementation(ArrayRef<Value> allIvs, CopyOp copyOp) { 265 assert(copyOp.hasBufferSemantics() && 266 "expected linalg op with buffer semantics"); 267 auto nPar = copyOp.getNumParallelLoops(); 268 assert(nPar == allIvs.size()); 269 auto inputIvs = 270 permuteIvs(allIvs.take_front(nPar), copyOp.inputPermutation()); 271 auto outputIvs = 272 permuteIvs(allIvs.take_front(nPar), copyOp.outputPermutation()); 273 SmallVector<Value, 8> iivs(inputIvs.begin(), inputIvs.end()); 274 SmallVector<Value, 8> oivs(outputIvs.begin(), outputIvs.end()); 275 IndexedValueType O(copyOp.getOutputBuffer(0)), I(copyOp.getInput(0)); 276 // Emit the proper scalar assignment, whether we are dealing with a 0-D or 277 // an n-D loop nest; with or without permutations. 278 // clang-format off 279 nPar > 0 ? O(oivs) = I(iivs) : 280 O() = I(); 281 // clang-format on 282 } 283 284 template <typename IndexedValueType> 285 void emitScalarImplementation(ArrayRef<Value> allIvs, FillOp fillOp) { 286 assert(fillOp.hasBufferSemantics() && 287 "expected linalg op with buffer semantics"); 288 auto nPar = fillOp.getNumParallelLoops(); 289 assert(nPar == allIvs.size()); 290 auto ivs = SmallVector<Value, 4>(allIvs.begin(), allIvs.begin() + nPar); 291 IndexedValueType O(fillOp.getOutputBuffer(0)); 292 // Emit the proper scalar assignment, whether we are dealing with a 0-D or 293 // an n-D loop nest; with or without permutations. 294 nPar > 0 ? O(ivs) = fillOp.value() : O() = fillOp.value(); 295 } 296 297 template <typename IndexedValueType> 298 Value getConvOpInput(ConvOp convOp, StdIndexedValue im, 299 MutableArrayRef<Value> imIdx) { 300 // TODO: add a level of indirection to linalg.generic. 301 if (!convOp.padding()) 302 return im(imIdx); 303 304 auto *context = ScopedContext::getContext(); 305 Value zeroIndex = std_constant_index(0); 306 SmallVector<Value, 8> conds; 307 SmallVector<Value, 8> clampedImIdx; 308 for (auto iter : llvm::enumerate(imIdx)) { 309 int idx = iter.index(); 310 auto dim = iter.value(); 311 // Only need to iterate over the window dimensions. 312 if (idx == 0 || idx == static_cast<int>(imIdx.size()) - 1) { 313 clampedImIdx.push_back(dim); 314 continue; 315 } 316 317 using edsc::op::sge; 318 using edsc::op::slt; 319 using edsc::op::operator||; 320 Value leftOutOfBound = slt(dim, zeroIndex); 321 if (conds.empty()) 322 conds.push_back(leftOutOfBound); 323 else 324 conds.push_back(conds.back() || leftOutOfBound); 325 Value rightBound = std_dim(convOp.input(), idx); 326 conds.push_back(conds.back() || (sge(dim, rightBound))); 327 328 // When padding is involved, the indices will only be shifted to negative, 329 // so having a max op is enough. 330 auto maxMap = AffineMap::get(/*dimCount=*/1, 0, 331 {getAffineDimExpr(/*position=*/0, context), 332 getAffineConstantExpr(0, context)}, 333 context); 334 clampedImIdx.push_back(affine_max(dim.getType(), maxMap, ValueRange{dim})); 335 } 336 337 auto &b = ScopedContext::getBuilderRef(); 338 Type type = convOp.input().getType().cast<MemRefType>().getElementType(); 339 Value zero = std_constant(type, b.getZeroAttr(type)); 340 Value readInput = im(clampedImIdx); 341 return conds.empty() ? readInput 342 : (Value)std_select(conds.back(), zero, readInput); 343 } 344 345 /// Returns true is `convOp` has a non-zero padding. 346 static bool hasPadding(ConvOp convOp) { 347 for (unsigned i = 0, e = convOp.getNumSpatialDimensions(); i < e; ++i) { 348 if (convOp.getLowPad(i) > 0 || convOp.getHighPad(i) > 0) 349 return true; 350 } 351 return false; 352 } 353 354 template <typename IndexedValueType> 355 static void emitScalarImplementation(ArrayRef<Value> allIvs, ConvOp convOp) { 356 assert(convOp.hasBufferSemantics() && 357 "expected linalg op with buffer semantics"); 358 auto &b = ScopedContext::getBuilderRef(); 359 auto loc = ScopedContext::getLocation(); 360 auto mapsRange = convOp.indexing_maps().getAsRange<AffineMapAttr>(); 361 auto maps = llvm::to_vector<8>( 362 llvm::map_range(mapsRange, [](AffineMapAttr a) { return a.getValue(); })); 363 SmallVector<Value, 8> fIdx( 364 makeCanonicalAffineApplies(b, loc, maps[0], allIvs)); 365 SmallVector<Value, 8> imIdx( 366 makeCanonicalAffineApplies(b, loc, maps[1], allIvs)); 367 SmallVector<Value, 8> oIdx( 368 makeCanonicalAffineApplies(b, loc, maps[2], allIvs)); 369 370 IndexedValueType F(convOp.filter()), O(convOp.output()); 371 372 // Emit scalar form. Padded conv involves an affine.max in the memory access 373 // which is not allowed by affine.load. Override to use an StdIndexedValue 374 // when there is non-zero padding. 375 if (hasPadding(convOp)) { 376 StdIndexedValue I(convOp.input()); 377 Value paddedInput = getConvOpInput<IndexedValueType>(convOp, I, imIdx); 378 O(oIdx) += F(fIdx) * paddedInput; 379 } else { 380 IndexedValueType I(convOp.input()); 381 O(oIdx) += F(fIdx) * I(imIdx); 382 } 383 } 384 385 template <typename IndexedValueType> 386 void emitScalarImplementation(ArrayRef<Value> allIvs, PoolingMaxOp op) { 387 InputAndOutputIndices indices = getInputAndOutputIndices(allIvs, op); 388 // Emit scalar form. 389 IndexedValueType output(op.output()); 390 IndexedValueType input(op.input()); 391 Value lhs = output(indices.outputs); 392 Value rhs = input(indices.inputs); 393 using edsc::op::sgt; 394 Value maxValue = std_select(sgt(lhs, rhs), lhs, rhs); 395 output(indices.outputs) = maxValue; 396 } 397 398 template <typename IndexedValueType> 399 void emitScalarImplementation(ArrayRef<Value> allIvs, PoolingMinOp op) { 400 InputAndOutputIndices indices = getInputAndOutputIndices(allIvs, op); 401 // Emit scalar form. 402 IndexedValueType output(op.output()); 403 IndexedValueType input(op.input()); 404 Value lhs = output(indices.outputs); 405 Value rhs = input(indices.inputs); 406 using edsc::op::slt; 407 Value minValue = std_select(slt(lhs, rhs), lhs, rhs); 408 output(indices.outputs) = minValue; 409 } 410 template <typename IndexedValueType> 411 void emitScalarImplementation(ArrayRef<Value> allIvs, PoolingSumOp op) { 412 auto indices = getInputAndOutputIndices(allIvs, op); 413 IndexedValueType input(op.input()), output(op.output()); 414 415 // Emit scalar form. 416 output(indices.outputs) += input(indices.inputs); 417 } 418 /// Emits the MLIR for the scalar part of the indexed generic op by: 419 /// 1. Emitting load ops for each input and output view in order. This is 420 /// achieved by applying the appropriate input or output map to the 421 /// enclosing induction variables. 422 /// 2. Emitting a call to `op.fun()` that takes as arguments the induction 423 /// variables and the scalars from point 1. above. 424 /// 3. Emitting store ops to store the results of 2. to the output views. 425 /// 426 /// An example output may resemble: 427 /// 428 /// ``` 429 /// scf.for %i = %c0 to %0 step %c1 { 430 /// scf.for %j = %c0 to %1 step %c1 { 431 /// scf.for %k = %c0 to %4 step %c1 { 432 /// %11 = load %arg0[%i, %j] : 433 /// memref<?x?xf32, stride_specification> 434 /// %12 = load %arg1[%i, %j, %k] : 435 /// memref<?x?x?xf32, stride_specification> 436 /// %13 = load %arg2[%i, %k, %j] : 437 /// memref<?x?x?xf32, stride_specification> 438 /// %14:2 = call @foo(%i, %j, %k, %11, %12, %13) : 439 /// (index, index, index, f32, f32, f32) -> (f32, f32) 440 /// store %14#0, %arg1[%i, %j, %k] : 441 /// memref<?x?x?Xf32, stride_specification> 442 /// store %14#1, %arg2[%i, %k, %j] : 443 /// memref<?x?x?Xf32, stride_specification> 444 /// } 445 /// } 446 /// } 447 /// ``` 448 template <typename IndexedValueType> 449 static void emitScalarImplementation(ArrayRef<Value> allIvs, 450 IndexedGenericOp indexedGenericOp) { 451 assert(indexedGenericOp.hasBufferSemantics() && 452 "expected linalg op with buffer semantics"); 453 auto &b = ScopedContext::getBuilderRef(); 454 auto loc = ScopedContext::getLocation(); 455 unsigned nInputs = indexedGenericOp.getNumInputs(); 456 unsigned nOutputs = indexedGenericOp.getNumOutputs(); 457 unsigned nLoops = allIvs.size(); 458 SmallVector<Value, 4> indexedValues; 459 indexedValues.reserve(nLoops + nInputs + nOutputs); 460 for (unsigned i = 0; i < nLoops; ++i) 461 indexedValues.push_back(allIvs[i]); 462 463 // TODO: Avoid the loads if the corresponding argument of the 464 // region has no uses. 465 // 1.a. Emit load from input views. 466 for (unsigned i = 0; i < nInputs; ++i) { 467 auto indexing = makeCanonicalAffineApplies( 468 b, loc, indexedGenericOp.getInputIndexingMap(i), allIvs); 469 // Pass input i through IndexedValueType emits the proper load operation. 470 indexedValues.push_back( 471 IndexedValueType(indexedGenericOp.getInput(i))(indexing)); 472 } 473 // 1.b. Emit load from output views. 474 for (unsigned i = 0; i < nOutputs; ++i) { 475 auto indexing = makeCanonicalAffineApplies( 476 b, loc, indexedGenericOp.getOutputIndexingMap(i), allIvs); 477 // Pass output i through IndexedValueType emits the proper load operation. 478 indexedValues.push_back( 479 IndexedValueType(indexedGenericOp.getOutputBuffer(i))(indexing)); 480 } 481 482 // TODO: When a region inliner exists, use it. 483 // 2. Inline region, currently only works for a single basic block. 484 // 3. Emit store. 485 SmallVector<SmallVector<Value, 8>, 8> indexing; 486 SmallVector<Value, 8> outputBuffers; 487 for (unsigned i = 0; i < nOutputs; ++i) { 488 indexing.push_back(makeCanonicalAffineApplies( 489 b, loc, indexedGenericOp.getOutputIndexingMap(i), allIvs)); 490 outputBuffers.push_back(indexedGenericOp.getOutputBuffer(i)); 491 } 492 inlineRegionAndEmitStore<IndexedValueType>(indexedGenericOp, indexedValues, 493 indexing, outputBuffers); 494 } 495 496 template <typename LoopTy, typename ConcreteOpTy> 497 Optional<LinalgLoops> linalgOpToLoopsImpl(Operation *op, OpBuilder &builder) { 498 using IndexedValueTy = typename GenerateLoopNest<LoopTy>::IndexedValueTy; 499 500 ScopedContext scope(builder, op->getLoc()); 501 502 // The flattened loopToOperandRangesMaps is expected to be an invertible 503 // permutation map (which is asserted in the inverse calculation). 504 auto linalgOp = cast<ConcreteOpTy>(op); 505 assert(linalgOp.hasBufferSemantics() && 506 "expected linalg op with buffer semantics"); 507 auto mapsRange = 508 linalgOp.indexing_maps().template getAsRange<AffineMapAttr>(); 509 auto maps = llvm::to_vector<8>( 510 llvm::map_range(mapsRange, [](AffineMapAttr a) { return a.getValue(); })); 511 SmallVector<Value, 8> sizes = getViewSizes(builder, linalgOp); 512 AffineMap map = concatAffineMaps(maps); 513 auto loopRanges = emitLoopRanges(scope.getBuilderRef(), scope.getLocation(), 514 map, getViewSizes(builder, linalgOp)); 515 SmallVector<Value, 4> allIvs; 516 GenerateLoopNest<LoopTy>::doit( 517 loopRanges, /*iterInitArgs*/ {}, linalgOp.iterator_types().getValue(), 518 [&](ValueRange ivs, ValueRange iterArgs) -> scf::ValueVector { 519 assert(iterArgs.empty() && "unexpected iterArgs"); 520 allIvs.append(ivs.begin(), ivs.end()); 521 emitScalarImplementation<IndexedValueTy>(allIvs, linalgOp); 522 return scf::ValueVector{}; 523 }); 524 // Number of loop ops might be different from the number of ivs since some 525 // loops like affine.parallel and scf.parallel have multiple ivs. 526 llvm::SetVector<Operation *> loopSet; 527 for (Value iv : allIvs) { 528 if (!iv) 529 return {}; 530 // The induction variable is a block argument of the entry block of the 531 // loop operation. 532 BlockArgument ivVal = iv.dyn_cast<BlockArgument>(); 533 if (!ivVal) 534 return {}; 535 loopSet.insert(ivVal.getOwner()->getParentOp()); 536 } 537 LinalgLoops loops(loopSet.begin(), loopSet.end()); 538 return loops; 539 } 540 541 template <typename LoopType, typename ConcreteOp> 542 class LinalgRewritePattern : public RewritePattern { 543 public: 544 explicit LinalgRewritePattern(MLIRContext *context) 545 : RewritePattern(ConcreteOp::getOperationName(), 1, context) {} 546 547 LogicalResult matchAndRewrite(Operation *op, 548 PatternRewriter &rewriter) const override { 549 if (!linalgOpToLoopsImpl<LoopType, ConcreteOp>(op, rewriter)) 550 return failure(); 551 rewriter.eraseOp(op); 552 return success(); 553 } 554 }; 555 556 template <typename LoopType, typename ConcreteOp> 557 void insertOnePattern(OwningRewritePatternList &patterns, MLIRContext *ctx) { 558 patterns.insert<LinalgRewritePattern<LoopType, ConcreteOp>>(ctx); 559 } 560 561 template <typename LoopType, typename... Args> 562 void insertPatterns(OwningRewritePatternList &patterns, MLIRContext *ctx) { 563 (void)std::initializer_list<int>{ 564 0, (insertOnePattern<LoopType, Args>(patterns, ctx), 0)...}; 565 } 566 567 /// Local folding pattern for AffineApplyOp that we can apply greedily. 568 /// This replaces AffineApplyOp by the proper value in cases where the 569 /// associated map is trivial. 570 /// A trivial map here is defined as a map with a single result and either: 571 /// 1. Zero operand + returns a single AffineConstantExpr 572 /// 2. One operand + returns a single AffineDimExpr 573 /// 3. One operand + returns a single AffineSymbolExpr 574 // 575 /// In the first case, the AffineApplyOp is replaced by a new constant. In the 576 /// other cases, it is replaced by its unique operand. 577 struct FoldAffineOp : public RewritePattern { 578 FoldAffineOp(MLIRContext *context) 579 : RewritePattern(AffineApplyOp::getOperationName(), 0, context) {} 580 581 LogicalResult matchAndRewrite(Operation *op, 582 PatternRewriter &rewriter) const override { 583 AffineApplyOp affineApplyOp = cast<AffineApplyOp>(op); 584 auto map = affineApplyOp.getAffineMap(); 585 if (map.getNumResults() != 1 || map.getNumInputs() > 1) 586 return failure(); 587 588 AffineExpr expr = map.getResult(0); 589 if (map.getNumInputs() == 0) { 590 if (auto val = expr.dyn_cast<AffineConstantExpr>()) { 591 rewriter.replaceOpWithNewOp<ConstantIndexOp>(op, val.getValue()); 592 return success(); 593 } 594 return failure(); 595 } 596 if (expr.dyn_cast<AffineDimExpr>() || expr.dyn_cast<AffineSymbolExpr>()) { 597 rewriter.replaceOp(op, op->getOperand(0)); 598 return success(); 599 } 600 return failure(); 601 } 602 }; 603 } // namespace 604 605 template <typename LoopType> 606 static void lowerLinalgToLoopsImpl(FuncOp funcOp, MLIRContext *context) { 607 OwningRewritePatternList patterns; 608 // Canonicalization and folding patterns applied greedily allow cleaning up 609 // the emitted IR on the fly. 610 // TODO: fold view and subview ops? 611 insertPatterns<LoopType, 612 #define GET_OP_LIST 613 #include "mlir/Dialect/Linalg/IR/LinalgStructuredOps.cpp.inc" 614 >(patterns, context); 615 616 DimOp::getCanonicalizationPatterns(patterns, context); 617 AffineApplyOp::getCanonicalizationPatterns(patterns, context); 618 patterns.insert<FoldAffineOp>(context); 619 // Just apply the patterns greedily. 620 applyPatternsAndFoldGreedily(funcOp, patterns); 621 } 622 623 namespace { 624 struct LowerToAffineLoops 625 : public LinalgLowerToAffineLoopsBase<LowerToAffineLoops> { 626 void runOnFunction() override { 627 lowerLinalgToLoopsImpl<AffineForOp>(getFunction(), &getContext()); 628 } 629 }; 630 struct LowerToLoops : public LinalgLowerToLoopsBase<LowerToLoops> { 631 void runOnFunction() override { 632 lowerLinalgToLoopsImpl<scf::ForOp>(getFunction(), &getContext()); 633 } 634 }; 635 struct LowerToParallelLoops 636 : public LinalgLowerToParallelLoopsBase<LowerToParallelLoops> { 637 void runOnFunction() override { 638 lowerLinalgToLoopsImpl<scf::ParallelOp>(getFunction(), &getContext()); 639 } 640 }; 641 } // namespace 642 643 std::unique_ptr<OperationPass<FuncOp>> mlir::createConvertLinalgToLoopsPass() { 644 return std::make_unique<LowerToLoops>(); 645 } 646 647 std::unique_ptr<OperationPass<FuncOp>> 648 mlir::createConvertLinalgToParallelLoopsPass() { 649 return std::make_unique<LowerToParallelLoops>(); 650 } 651 652 std::unique_ptr<OperationPass<FuncOp>> 653 mlir::createConvertLinalgToAffineLoopsPass() { 654 return std::make_unique<LowerToAffineLoops>(); 655 } 656 657 // TODO: gradually remove this layer as more ops become "named". 658 template <typename LoopTy> 659 static Optional<LinalgLoops> linalgOpToLoopsImplSwitch(Operation *op, 660 OpBuilder &builder) { 661 assert(isa<LinalgOp>(op) && "LinalgOp expected"); 662 if (isa<CopyOp>(op)) 663 return linalgOpToLoopsImpl<LoopTy, CopyOp>(op, builder); 664 if (isa<FillOp>(op)) 665 return linalgOpToLoopsImpl<LoopTy, FillOp>(op, builder); 666 if (isa<ConvOp>(op)) 667 return linalgOpToLoopsImpl<LoopTy, ConvOp>(op, builder); 668 if (isa<PoolingMaxOp>(op)) 669 return linalgOpToLoopsImpl<LoopTy, PoolingMaxOp>(op, builder); 670 if (isa<PoolingMinOp>(op)) 671 return linalgOpToLoopsImpl<LoopTy, PoolingMinOp>(op, builder); 672 if (isa<PoolingSumOp>(op)) 673 return linalgOpToLoopsImpl<LoopTy, PoolingSumOp>(op, builder); 674 if (isa<IndexedGenericOp>(op)) 675 return linalgOpToLoopsImpl<LoopTy, IndexedGenericOp>(op, builder); 676 677 // TODO: Cases below are generic and need a LinalgStructuredOpInterface. 678 if (isa<GenericOp>(op)) 679 return linalgOpToLoopsImpl<LoopTy, GenericOp>(op, builder); 680 if (isa<MatmulOp>(op)) 681 return linalgOpToLoopsImpl<LoopTy, MatmulOp>(op, builder); 682 if (isa<MatvecOp>(op)) 683 return linalgOpToLoopsImpl<LoopTy, MatvecOp>(op, builder); 684 if (isa<VecmatOp>(op)) 685 return linalgOpToLoopsImpl<LoopTy, VecmatOp>(op, builder); 686 if (isa<DotOp>(op)) 687 return linalgOpToLoopsImpl<LoopTy, DotOp>(op, builder); 688 if (isa<BatchMatmulOp>(op)) 689 return linalgOpToLoopsImpl<LoopTy, BatchMatmulOp>(op, builder); 690 if (isa<ConvWOp>(op)) 691 return linalgOpToLoopsImpl<LoopTy, ConvWOp>(op, builder); 692 if (isa<ConvNWCOp>(op)) 693 return linalgOpToLoopsImpl<LoopTy, ConvNWCOp>(op, builder); 694 if (isa<ConvNCWOp>(op)) 695 return linalgOpToLoopsImpl<LoopTy, ConvNCWOp>(op, builder); 696 if (isa<ConvHWOp>(op)) 697 return linalgOpToLoopsImpl<LoopTy, ConvHWOp>(op, builder); 698 if (isa<ConvNHWCOp>(op)) 699 return linalgOpToLoopsImpl<LoopTy, ConvNHWCOp>(op, builder); 700 if (isa<ConvNCHWOp>(op)) 701 return linalgOpToLoopsImpl<LoopTy, ConvNCHWOp>(op, builder); 702 if (isa<ConvDHWOp>(op)) 703 return linalgOpToLoopsImpl<LoopTy, ConvDHWOp>(op, builder); 704 if (isa<ConvNDHWCOp>(op)) 705 return linalgOpToLoopsImpl<LoopTy, ConvNDHWCOp>(op, builder); 706 if (isa<ConvNCDHWOp>(op)) 707 return linalgOpToLoopsImpl<LoopTy, ConvNCDHWOp>(op, builder); 708 llvm_unreachable("Unexpected op in linalgOpToLoopsImpl"); 709 } 710 711 /// Emits a loop nest with the proper body for `op`. 712 template <typename LoopTy> 713 Optional<LinalgLoops> mlir::linalg::linalgLowerOpToLoops(OpBuilder &builder, 714 Operation *op) { 715 return linalgOpToLoopsImplSwitch<LoopTy>(op, builder); 716 } 717 718 template Optional<LinalgLoops> 719 mlir::linalg::linalgLowerOpToLoops<AffineForOp>(OpBuilder &builder, 720 Operation *op); 721 template Optional<LinalgLoops> 722 mlir::linalg::linalgLowerOpToLoops<scf::ForOp>(OpBuilder &builder, 723 Operation *op); 724 template Optional<LinalgLoops> 725 mlir::linalg::linalgLowerOpToLoops<scf::ParallelOp>(OpBuilder &builder, 726 Operation *op); 727 728 /// Emits a loop nest of `affine.for` with the proper body for `op`. 729 LogicalResult mlir::linalg::linalgOpToAffineLoops(OpBuilder &builder, 730 Operation *op) { 731 Optional<LinalgLoops> loops = linalgLowerOpToLoops<AffineForOp>(builder, op); 732 return loops ? success() : failure(); 733 } 734 735 /// Emits a loop nest of `scf.for` with the proper body for `op`. 736 LogicalResult mlir::linalg::linalgOpToLoops(OpBuilder &builder, Operation *op) { 737 Optional<LinalgLoops> loops = linalgLowerOpToLoops<scf::ForOp>(builder, op); 738 return loops ? success() : failure(); 739 } 740 741 /// Emits a loop nest of `scf.parallel` with the proper body for `op`. 742 LogicalResult mlir::linalg::linalgOpToParallelLoops(OpBuilder &builder, 743 Operation *op) { 744 Optional<LinalgLoops> loops = 745 linalgLowerOpToLoops<scf::ParallelOp>(builder, op); 746 return loops ? success() : failure(); 747 } 748