//===- Loops.cpp - conversion from Linalg named and generic ops to loops --===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//

#include "PassDetail.h"
#include "mlir/Dialect/Arithmetic/IR/Arithmetic.h"
#include "mlir/Dialect/Arithmetic/Utils/Utils.h"
#include "mlir/Dialect/Linalg/IR/Linalg.h"
#include "mlir/Dialect/Linalg/Passes.h"
#include "mlir/Dialect/Linalg/Transforms/Transforms.h"
#include "mlir/Dialect/Linalg/Utils/Utils.h"
#include "mlir/Dialect/SCF/Transforms/Transforms.h"
#include "mlir/Dialect/SCF/Utils/AffineCanonicalizationUtils.h"
#include "mlir/IR/AffineExpr.h"
#include "mlir/IR/AffineMap.h"
#include "mlir/IR/BlockAndValueMapping.h"
#include "mlir/Support/LLVM.h"
#include "mlir/Transforms/DialectConversion.h"
#include "mlir/Transforms/FoldUtils.h"
#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
#include "llvm/ADT/TypeSwitch.h"

using namespace mlir;
using namespace mlir::linalg;

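/// Builds one `affine.apply` per result expression of `map`: each result is
/// wrapped into its own single-result map, the map and its operands (taken
/// from `vals`) are canonicalized, and the resulting values are returned in
/// order.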
static SmallVector<Value> makeCanonicalAffineApplies(OpBuilder &b, Location loc,
                                                     AffineMap map,
                                                     ArrayRef<Value> vals) {
  if (map.isEmpty())
    return {};

  assert(map.getNumInputs() == vals.size());
  SmallVector<Value> res;
  res.reserve(map.getNumResults());
  auto dims = map.getNumDims();
  for (auto e : map.getResults()) {
    auto exprMap = AffineMap::get(dims, map.getNumSymbols(), e);
    SmallVector<Value> operands(vals.begin(), vals.end());
    canonicalizeMapAndOperands(&exprMap, &operands);
    res.push_back(b.create<AffineApplyOp>(loc, exprMap, operands));
  }
  return res;
}

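/// Inlines the single-block region of `op` by cloning its payload operations
/// with the block arguments mapped to `indexedValues`, then emits a store of
/// each terminator operand into the matching buffer in `outputBuffers` at the
/// indices given by `indexing`.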
template <typename LoadOpTy, typename StoreOpTy, typename OpType>
static void inlineRegionAndEmitStore(OpBuilder &b, Location loc, OpType op,
                                     ArrayRef<Value> indexedValues,
                                     ArrayRef<SmallVector<Value>> indexing,
                                     ArrayRef<Value> outputBuffers) {
  auto &block = op->getRegion(0).front();
  BlockAndValueMapping map;
  map.map(block.getArguments(), indexedValues);
  for (auto &op : block.without_terminator()) {
    auto *newOp = b.clone(op, map);
    map.map(op.getResults(), newOp->getResults());
  }

  Operation *terminator = block.getTerminator();
  for (OpOperand &operand : terminator->getOpOperands()) {
    Value toStore = map.lookupOrDefault(operand.get());
    b.create<StoreOpTy>(loc, toStore, outputBuffers[operand.getOperandNumber()],
                        indexing[operand.getOperandNumber()]);
  }
}

// Holds the input and output indices of a SingleInputPoolingOp `op`, as
// computed by `getInputAndOutputIndices`.
struct InputAndOutputIndices {
  SmallVector<Value> inputs;
  SmallVector<Value> outputs;
};
template <typename SingleInputPoolingOp>
static InputAndOutputIndices
getInputAndOutputIndices(OpBuilder &b, Location loc, ArrayRef<Value> allIvs,
                         SingleInputPoolingOp op) {
  auto mapsRange = op.getIndexingMapsArray();
  auto maps = llvm::to_vector<8>(
      llvm::map_range(mapsRange, [](AffineMapAttr a) { return a.getValue(); }));
  return InputAndOutputIndices{
      makeCanonicalAffineApplies(b, loc, maps[0], allIvs),
      makeCanonicalAffineApplies(b, loc, maps[2], allIvs)};
}

/// Emits the MLIR for the scalar part of the generic op by:
///   1. Emitting load ops for each input and output view in order. This is
///      achieved by applying the appropriate input or output map to the
///      enclosing induction variables.
///   2. Emitting a call to `op.fun()` that takes as arguments the scalars
///      from point 1. above.
///   3. Emitting store ops to store the results of 2. to the output
///      views.
///
/// An example output may resemble:
///
/// ```
///   scf.for %i = %c0 to %0 step %c1 {
///     scf.for %j = %c0 to %1 step %c1 {
///       scf.for %k = %c0 to %4 step %c1 {
///         %11 = load %arg0[%i, %j] :
///           memref<?x?xf32, stride_specification>
///         %12 = load %arg1[%i, %j, %k] :
///           memref<?x?x?xf32, stride_specification>
///         %13 = load %arg2[%i, %k, %j] :
///           memref<?x?x?xf32, stride_specification>
///         %14:2 = call @foo(%11, %12, %13) : (f32, f32, f32) -> (f32, f32)
///         store %14#0, %arg1[%i, %j, %k] :
///           memref<?x?x?xf32, stride_specification>
///         store %14#1, %arg2[%i, %k, %j] :
///           memref<?x?x?xf32, stride_specification>
///       }
///     }
///   }
/// ```
template <typename LoadOpTy, typename StoreOpTy>
static void emitScalarImplementation(OpBuilder &b, Location loc,
                                     ArrayRef<Value> allIvs,
                                     LinalgOp linalgOp) {
  assert(linalgOp.hasBufferSemantics() &&
         "expected linalg op with buffer semantics");
  SmallVector<Value> indexedValues;
  indexedValues.reserve(linalgOp.getNumInputsAndOutputs());

  auto allIvsPlusDims = SmallVector<Value>(allIvs.begin(), allIvs.end());

  // TODO: Avoid the loads if the corresponding argument of the
  // region has no uses.
  // 1.a. Emit load from input operand or for scalars access the operand
  // itself.
  for (OpOperand *inputOperand : linalgOp.getInputOperands()) {
    if (linalgOp.isScalar(inputOperand)) {
      indexedValues.push_back(inputOperand->get());
      continue;
    }
    auto indexing = makeCanonicalAffineApplies(
        b, loc, linalgOp.getTiedIndexingMap(inputOperand), allIvsPlusDims);
    indexedValues.push_back(
        b.create<LoadOpTy>(loc, inputOperand->get(), indexing));
  }
  // 1.b. Emit load from output views.
  for (OpOperand *outputOperand : linalgOp.getOutputOperands()) {
    SmallVector<Value> indexing = makeCanonicalAffineApplies(
        b, loc, linalgOp.getTiedIndexingMap(outputOperand), allIvsPlusDims);
    indexedValues.push_back(
        b.create<LoadOpTy>(loc, outputOperand->get(), indexing));
  }

  // TODO: When a region inliner exists, use it.
  // 2. Inline region, currently only works for a single basic block.
  // 3. Emit store.
  SmallVector<SmallVector<Value>, 8> indexing;
  SmallVector<Value> outputBuffers;
  for (OpOperand *outputOperand : linalgOp.getOutputBufferOperands()) {
    indexing.push_back(makeCanonicalAffineApplies(
        b, loc, linalgOp.getTiedIndexingMap(outputOperand), allIvsPlusDims));
    outputBuffers.push_back(outputOperand->get());
  }
  inlineRegionAndEmitStore<LoadOpTy, StoreOpTy>(b, loc, linalgOp, indexedValues,
                                                indexing, outputBuffers);
}

/// Replaces the index operations in the body of the loop nest with the
/// matching induction variables.
static void replaceIndexOpsByInductionVariables(LinalgOp linalgOp,
                                                PatternRewriter &rewriter,
                                                ArrayRef<Operation *> loopOps) {
  // Extract the induction variables of the loop nest from outer to inner.
  SmallVector<Value> allIvs;
  for (Operation *loopOp : loopOps) {
    llvm::TypeSwitch<Operation *>(loopOp)
        .Case([&](scf::ParallelOp parallelOp) {
          allIvs.append(parallelOp.getInductionVars().begin(),
                        parallelOp.getInductionVars().end());
        })
        .Case([&](scf::ForOp forOp) {
          allIvs.push_back(forOp.getInductionVar());
        })
        .Case([&](AffineForOp affineForOp) {
          allIvs.push_back(affineForOp.getInductionVar());
        })
        .Default([&](Operation *op) { assert(false && "unexpected op"); });
  }
  assert(linalgOp.getNumLoops() == allIvs.size() &&
         "expected the number of loops and induction variables to match");
  // Replace the index operations in the body of the innermost loop op.
  if (!loopOps.empty()) {
    LoopLikeOpInterface loopOp = loopOps.back();
    for (IndexOp indexOp :
         llvm::make_early_inc_range(loopOp.getLoopBody().getOps<IndexOp>()))
      rewriter.replaceOp(indexOp, allIvs[indexOp.dim()]);
  }
}

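/// Shared implementation for all loop kinds: materializes a loop nest of
/// `LoopTy` around `linalgOp`, emits the scalar body inside it, collects the
/// generated loop operations, and rewires `linalg.index` ops to the induction
/// variables. Returns the loops on success, failure otherwise.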
template <typename LoopTy>
static FailureOr<LinalgLoops> linalgOpToLoopsImpl(PatternRewriter &rewriter,
                                                   LinalgOp linalgOp) {
  using LoadOpTy =
      typename std::conditional<std::is_same<LoopTy, AffineForOp>::value,
                                AffineLoadOp, memref::LoadOp>::type;
  using StoreOpTy =
      typename std::conditional<std::is_same<LoopTy, AffineForOp>::value,
                                AffineStoreOp, memref::StoreOp>::type;

  // The flattened loopToOperandRangesMaps is expected to be an invertible
  // permutation map (which is asserted in the inverse calculation).
  assert(linalgOp.hasBufferSemantics() &&
         "expected linalg op with buffer semantics");

  auto loopRanges = linalgOp.createLoopRanges(rewriter, linalgOp.getLoc());
  auto iteratorTypes = llvm::to_vector<4>(linalgOp.iterator_types().getValue());

  SmallVector<Value> allIvs;
  GenerateLoopNest<LoopTy>::doit(
      rewriter, linalgOp.getLoc(), loopRanges, linalgOp, iteratorTypes,
      [&](OpBuilder &b, Location loc, ValueRange ivs,
          ValueRange operandValuesToUse) -> scf::ValueVector {
        assert(operandValuesToUse == linalgOp->getOperands() &&
               "expect operands are captured and not passed by loop argument");
        allIvs.append(ivs.begin(), ivs.end());
        emitScalarImplementation<LoadOpTy, StoreOpTy>(b, loc, allIvs, linalgOp);
        return scf::ValueVector{};
      });
  // Number of loop ops might be different from the number of ivs since some
  // loops like affine.parallel and scf.parallel have multiple ivs.
  SetVector<Operation *> loopSet;
  for (Value iv : allIvs) {
    if (!iv)
      return failure();
    // The induction variable is a block argument of the entry block of the
    // loop operation.
    BlockArgument ivVal = iv.dyn_cast<BlockArgument>();
    if (!ivVal)
      return failure();
    loopSet.insert(ivVal.getOwner()->getParentOp());
  }
  LinalgLoops loops(loopSet.begin(), loopSet.end());
  // Replace all index operations in the loop body.
  replaceIndexOpsByInductionVariables(linalgOp, rewriter, loops);
  return loops;
}

namespace {
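/// Rewrite pattern that converts any matched LinalgOp into an explicit loop
/// nest of `LoopType` and erases the original operation on success.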
template <typename LoopType>
class LinalgRewritePattern : public RewritePattern {
public:
  LinalgRewritePattern(MLIRContext *context)
      : RewritePattern(MatchAnyOpTypeTag(), /*benefit=*/1, context) {}

  LogicalResult matchAndRewrite(Operation *op,
                                PatternRewriter &rewriter) const override {
    auto linalgOp = dyn_cast<LinalgOp>(op);
    if (!linalgOp)
      return failure();
    if (failed(linalgOpToLoopsImpl<LoopType>(rewriter, linalgOp)))
      return failure();
    rewriter.eraseOp(op);
    return success();
  }
};

/// Local folding pattern for AffineApplyOp that we can apply greedily.
/// This replaces AffineApplyOp with the proper value in cases where the
/// associated map is trivial.
/// A trivial map here is defined as a map with a single result and either:
///   1. Zero operands + returns a single AffineConstantExpr
///   2. One operand + returns a single AffineDimExpr
///   3. One operand + returns a single AffineSymbolExpr
///
/// In the first case, the AffineApplyOp is replaced by a new constant. In the
/// other cases, it is replaced by its unique operand.
struct FoldAffineOp : public RewritePattern {
  FoldAffineOp(MLIRContext *context)
      : RewritePattern(AffineApplyOp::getOperationName(), 0, context) {}

  LogicalResult matchAndRewrite(Operation *op,
                                PatternRewriter &rewriter) const override {
    AffineApplyOp affineApplyOp = cast<AffineApplyOp>(op);
    auto map = affineApplyOp.getAffineMap();
    if (map.getNumResults() != 1 || map.getNumInputs() > 1)
      return failure();

    AffineExpr expr = map.getResult(0);
    if (map.getNumInputs() == 0) {
      if (auto val = expr.dyn_cast<AffineConstantExpr>()) {
        rewriter.replaceOpWithNewOp<arith::ConstantIndexOp>(op, val.getValue());
        return success();
      }
      return failure();
    }
    if (expr.dyn_cast<AffineDimExpr>() || expr.dyn_cast<AffineSymbolExpr>()) {
      rewriter.replaceOp(op, op->getOperand(0));
      return success();
    }
    return failure();
  }
};

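/// Populates a pattern set with the Linalg-to-loops rewrite, dim-op and
/// affine.apply canonicalizations, and the local AffineApplyOp folder, then
/// applies it greedily to `funcOp`.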
template <typename LoopType>
static void lowerLinalgToLoopsImpl(func::FuncOp funcOp) {
  MLIRContext *context = funcOp.getContext();
  RewritePatternSet patterns(context);
  patterns.add<LinalgRewritePattern<LoopType>>(context);
  memref::DimOp::getCanonicalizationPatterns(patterns, context);
  tensor::DimOp::getCanonicalizationPatterns(patterns, context);
  AffineApplyOp::getCanonicalizationPatterns(patterns, context);
  patterns.add<FoldAffineOp>(context);
  // Just apply the patterns greedily.
  (void)applyPatternsAndFoldGreedily(funcOp, std::move(patterns));
}

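/// Pass that lowers Linalg ops within a function to affine.for loops.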
struct LowerToAffineLoops
    : public LinalgLowerToAffineLoopsBase<LowerToAffineLoops> {
  void getDependentDialects(DialectRegistry &registry) const override {
    registry.insert<memref::MemRefDialect>();
  }
  void runOnOperation() override {
    lowerLinalgToLoopsImpl<AffineForOp>(getOperation());
  }
};

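/// Pass that lowers Linalg ops within a function to scf.for loops.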
struct LowerToLoops : public LinalgLowerToLoopsBase<LowerToLoops> {
  void getDependentDialects(DialectRegistry &registry) const override {
    registry.insert<memref::MemRefDialect, scf::SCFDialect>();
  }
  void runOnOperation() override {
    lowerLinalgToLoopsImpl<scf::ForOp>(getOperation());
  }
};

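/// Pass that lowers Linalg ops within a function to scf.parallel loops.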
struct LowerToParallelLoops
    : public LinalgLowerToParallelLoopsBase<LowerToParallelLoops> {
  void runOnOperation() override {
    lowerLinalgToLoopsImpl<scf::ParallelOp>(getOperation());
  }
};

} // namespace

std::unique_ptr<OperationPass<func::FuncOp>>
mlir::createConvertLinalgToLoopsPass() {
  return std::make_unique<LowerToLoops>();
}

std::unique_ptr<OperationPass<func::FuncOp>>
mlir::createConvertLinalgToParallelLoopsPass() {
  return std::make_unique<LowerToParallelLoops>();
}

std::unique_ptr<OperationPass<func::FuncOp>>
mlir::createConvertLinalgToAffineLoopsPass() {
  return std::make_unique<LowerToAffineLoops>();
}

/// Emits a loop nest of `affine.for` with the proper body for `linalgOp`.
FailureOr<LinalgLoops>
mlir::linalg::linalgOpToAffineLoops(PatternRewriter &rewriter,
                                    LinalgOp linalgOp) {
  return linalgOpToLoopsImpl<AffineForOp>(rewriter, linalgOp);
}

/// Emits a loop nest of `scf.for` with the proper body for `linalgOp`.
FailureOr<LinalgLoops> mlir::linalg::linalgOpToLoops(PatternRewriter &rewriter,
                                                     LinalgOp linalgOp) {
  return linalgOpToLoopsImpl<scf::ForOp>(rewriter, linalgOp);
}

/// Emits a loop nest of `scf.parallel` with the proper body for `linalgOp`.
FailureOr<LinalgLoops>
mlir::linalg::linalgOpToParallelLoops(PatternRewriter &rewriter,
                                      LinalgOp linalgOp) {
  return linalgOpToLoopsImpl<scf::ParallelOp>(rewriter, linalgOp);
}