//===- Loops.cpp - conversion from Linalg named and generic ops to loops --===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//

#include "PassDetail.h"
#include "mlir/Dialect/Arithmetic/IR/Arithmetic.h"
#include "mlir/Dialect/Arithmetic/Utils/Utils.h"
#include "mlir/Dialect/Linalg/IR/Linalg.h"
#include "mlir/Dialect/Linalg/Passes.h"
#include "mlir/Dialect/Linalg/Transforms/Transforms.h"
#include "mlir/Dialect/Linalg/Utils/Utils.h"
#include "mlir/Dialect/SCF/Transforms/Transforms.h"
#include "mlir/Dialect/SCF/Utils/AffineCanonicalizationUtils.h"
#include "mlir/IR/AffineExpr.h"
#include "mlir/IR/AffineMap.h"
#include "mlir/IR/BlockAndValueMapping.h"
#include "mlir/Support/LLVM.h"
#include "mlir/Transforms/DialectConversion.h"
#include "mlir/Transforms/FoldUtils.h"
#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
#include "llvm/ADT/TypeSwitch.h"

using namespace mlir;
using namespace mlir::linalg;

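/// Creates one `affine.apply` op per result of `map`, using `vals` as
/// operands, and returns the resulting values. Each single-result map is
/// canonicalized together with its operands before the op is built.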
static SmallVector<Value> makeCanonicalAffineApplies(OpBuilder &b, Location loc,
                                                     AffineMap map,
                                                     ArrayRef<Value> vals) {
  if (map.isEmpty())
    return {};

  assert(map.getNumInputs() == vals.size());
  SmallVector<Value> res;
  res.reserve(map.getNumResults());
  auto dims = map.getNumDims();
  for (auto e : map.getResults()) {
    auto exprMap = AffineMap::get(dims, map.getNumSymbols(), e);
    SmallVector<Value> operands(vals.begin(), vals.end());
    canonicalizeMapAndOperands(&exprMap, &operands);
    res.push_back(b.create<AffineApplyOp>(loc, exprMap, operands));
  }
  return res;
}

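/// Inlines the single-block region of `op` at the current insertion point by
/// cloning its body with the block arguments mapped to `indexedValues`, then
/// emits one store per terminator operand into the matching output buffer,
/// using the precomputed `indexing` values as store indices.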
template <typename LoadOpTy, typename StoreOpTy, typename OpType>
static void inlineRegionAndEmitStore(OpBuilder &b, Location loc, OpType op,
                                     ArrayRef<Value> indexedValues,
                                     ArrayRef<SmallVector<Value>> indexing,
                                     ArrayRef<Value> outputBuffers) {
  auto &block = op->getRegion(0).front();
  BlockAndValueMapping map;
  map.map(block.getArguments(), indexedValues);
  for (auto &op : block.without_terminator()) {
    auto *newOp = b.clone(op, map);
    map.map(op.getResults(), newOp->getResults());
  }

  Operation *terminator = block.getTerminator();
  for (OpOperand &operand : terminator->getOpOperands()) {
    Value toStore = map.lookupOrDefault(operand.get());
    b.create<StoreOpTy>(loc, toStore, outputBuffers[operand.getOperandNumber()],
                        indexing[operand.getOperandNumber()]);
  }
}

// Returns a pair that contains input indices and output indices of a
// SingleInputPoolingOp `op`.
struct InputAndOutputIndices {
  SmallVector<Value> inputs;
  SmallVector<Value> outputs;
};
template <typename SingleInputPoolingOp>
static InputAndOutputIndices
getInputAndOutputIndices(OpBuilder &b, Location loc, ArrayRef<Value> allIvs,
                         SingleInputPoolingOp op) {
  auto mapsRange = op.getIndexingMapsArray();
  auto maps = llvm::to_vector<8>(
      llvm::map_range(mapsRange, [](AffineMapAttr a) { return a.getValue(); }));
  return InputAndOutputIndices{
      makeCanonicalAffineApplies(b, loc, maps[0], allIvs),
      makeCanonicalAffineApplies(b, loc, maps[2], allIvs)};
}

/// Emits the MLIR for the scalar part of the generic op by:
///   1. Emitting load ops for each input and output view in order. This is
///      achieved by applying the appropriate input or output map to the
///      enclosing induction variables.
///   2. Emitting a call to `op.fun()` that takes as arguments the scalars
///      from point 1. above.
///   3. Emitting store ops to store the results of 2. to the output
///      views.
///
/// An example output may resemble:
///
/// ```
///    scf.for %i = %c0 to %0 step %c1 {
///      scf.for %j = %c0 to %1 step %c1 {
///        scf.for %k = %c0 to %4 step %c1 {
///          %11 = load %arg0[%i, %j] :
///            memref<?x?xf32, stride_specification>
///          %12 = load %arg1[%i, %j, %k] :
///            memref<?x?x?xf32, stride_specification>
///          %13 = load %arg2[%i, %k, %j] :
///            memref<?x?x?xf32, stride_specification>
///          %14:2 = call @foo(%11, %12, %13) : (f32, f32, f32) -> (f32, f32)
///          store %14#0, %arg1[%i, %j, %k] :
///            memref<?x?x?xf32, stride_specification>
///          store %14#1, %arg2[%i, %k, %j] :
///            memref<?x?x?xf32, stride_specification>
///        }
///      }
///    }
/// ```
template <typename LoadOpTy, typename StoreOpTy>
static void emitScalarImplementation(OpBuilder &b, Location loc,
                                     ArrayRef<Value> allIvs,
                                     LinalgOp linalgOp) {
  assert(linalgOp.hasBufferSemantics() &&
         "expected linalg op with buffer semantics");
  SmallVector<Value> indexedValues;
  indexedValues.reserve(linalgOp.getNumInputsAndOutputs());

  auto allIvsPlusDims = SmallVector<Value>(allIvs.begin(), allIvs.end());

  // TODO: Avoid the loads if the corresponding argument of the
  // region has no uses.
  // 1.a. Emit load from input operand or for scalars access the operand itself.
  for (OpOperand *inputOperand : linalgOp.getInputOperands()) {
    if (linalgOp.isScalar(inputOperand)) {
      indexedValues.push_back(inputOperand->get());
      continue;
    }
    auto indexing = makeCanonicalAffineApplies(
        b, loc, linalgOp.getTiedIndexingMap(inputOperand), allIvsPlusDims);
    indexedValues.push_back(
        b.create<LoadOpTy>(loc, inputOperand->get(), indexing));
  }
  // 1.b. Emit load from output views.
  for (OpOperand *outputOperand : linalgOp.getOutputOperands()) {
    SmallVector<Value> indexing = makeCanonicalAffineApplies(
        b, loc, linalgOp.getTiedIndexingMap(outputOperand), allIvsPlusDims);
    indexedValues.push_back(
        b.create<LoadOpTy>(loc, outputOperand->get(), indexing));
  }

  // TODO: When a region inliner exists, use it.
  // 2. Inline region, currently only works for a single basic block.
  // 3. Emit store.
  SmallVector<SmallVector<Value>, 8> indexing;
  SmallVector<Value> outputBuffers;
  for (OpOperand *outputOperand : linalgOp.getOutputBufferOperands()) {
    indexing.push_back(makeCanonicalAffineApplies(
        b, loc, linalgOp.getTiedIndexingMap(outputOperand), allIvsPlusDims));
    outputBuffers.push_back(outputOperand->get());
  }
  inlineRegionAndEmitStore<LoadOpTy, StoreOpTy>(b, loc, linalgOp, indexedValues,
                                                indexing, outputBuffers);
}

/// Replace the index operations in the body of the loop nest by the matching
/// induction variables.
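///
/// For example (illustrative), an op in the innermost loop body such as
///   %d1 = linalg.index 1 : index
/// is replaced by the induction variable of the loop that implements
/// dimension 1 of the iteration space.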
static void replaceIndexOpsByInductionVariables(LinalgOp linalgOp,
                                                PatternRewriter &rewriter,
                                                ArrayRef<Operation *> loopOps) {
  // Extract the induction variables of the loop nest from outer to inner.
  SmallVector<Value> allIvs;
  for (Operation *loopOp : loopOps) {
    llvm::TypeSwitch<Operation *>(loopOp)
        .Case([&](scf::ParallelOp parallelOp) {
          allIvs.append(parallelOp.getInductionVars().begin(),
                        parallelOp.getInductionVars().end());
        })
        .Case([&](scf::ForOp forOp) {
          allIvs.push_back(forOp.getInductionVar());
        })
        .Case([&](AffineForOp affineForOp) {
          allIvs.push_back(affineForOp.getInductionVar());
        })
        .Default([&](Operation *op) { assert(false && "unexpected op"); });
  }
  assert(linalgOp.getNumLoops() == allIvs.size() &&
         "expected the number of loops and induction variables to match");
  // Replace the index operations in the body of the innermost loop op.
  if (!loopOps.empty()) {
    LoopLikeOpInterface loopOp = loopOps.back();
    for (IndexOp indexOp :
         llvm::make_early_inc_range(loopOp.getLoopBody().getOps<IndexOp>()))
      rewriter.replaceOp(indexOp, allIvs[indexOp.dim()]);
  }
}

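/// Generates an explicit loop nest of `LoopTy` (affine.for, scf.for, or
/// scf.parallel) that computes `linalgOp`: the loop ranges come from the op's
/// iteration space, the scalar body is emitted via emitScalarImplementation,
/// and linalg.index ops are rewritten to the generated induction variables.
/// Returns the generated loop operations.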
template <typename LoopTy>
static FailureOr<LinalgLoops> linalgOpToLoopsImpl(PatternRewriter &rewriter,
                                                  LinalgOp linalgOp) {
  using LoadOpTy =
      typename std::conditional<std::is_same<LoopTy, AffineForOp>::value,
                                AffineLoadOp, memref::LoadOp>::type;
  using StoreOpTy =
      typename std::conditional<std::is_same<LoopTy, AffineForOp>::value,
                                AffineStoreOp, memref::StoreOp>::type;

  // The flattened loopToOperandRangesMaps is expected to be an invertible
  // permutation map (which is asserted in the inverse calculation).
  assert(linalgOp.hasBufferSemantics() &&
         "expected linalg op with buffer semantics");

  auto loopRanges = linalgOp.createLoopRanges(rewriter, linalgOp.getLoc());
  auto iteratorTypes = llvm::to_vector<4>(linalgOp.iterator_types().getValue());

  SmallVector<Value> allIvs;
  GenerateLoopNest<LoopTy>::doit(
      rewriter, linalgOp.getLoc(), loopRanges, linalgOp, iteratorTypes,
      [&](OpBuilder &b, Location loc, ValueRange ivs,
          ValueRange operandValuesToUse) -> scf::ValueVector {
        assert(operandValuesToUse == linalgOp->getOperands() &&
               "expect operands are captured and not passed by loop argument");
        allIvs.append(ivs.begin(), ivs.end());
        emitScalarImplementation<LoadOpTy, StoreOpTy>(b, loc, allIvs, linalgOp);
        return scf::ValueVector{};
      });
  // Number of loop ops might be different from the number of ivs since some
  // loops like affine.parallel and scf.parallel have multiple ivs.
  SetVector<Operation *> loopSet;
  for (Value iv : allIvs) {
    if (!iv)
      return failure();
    // The induction variable is a block argument of the entry block of the
    // loop operation.
    BlockArgument ivVal = iv.dyn_cast<BlockArgument>();
    if (!ivVal)
      return failure();
    loopSet.insert(ivVal.getOwner()->getParentOp());
  }
  LinalgLoops loops(loopSet.begin(), loopSet.end());
  // Replace all index operations in the loop body.
  replaceIndexOpsByInductionVariables(linalgOp, rewriter, loops);
  return loops;
}

namespace {
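/// Rewrites any LinalgOp (with buffer semantics) into an explicit loop nest of
/// `LoopType` and erases the original op on success.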
template <typename LoopType>
class LinalgRewritePattern : public RewritePattern {
public:
  LinalgRewritePattern(MLIRContext *context)
      : RewritePattern(MatchAnyOpTypeTag(), /*benefit=*/1, context) {}

  LogicalResult matchAndRewrite(Operation *op,
                                PatternRewriter &rewriter) const override {
    auto linalgOp = dyn_cast<LinalgOp>(op);
    if (!linalgOp)
      return failure();
    if (failed(linalgOpToLoopsImpl<LoopType>(rewriter, linalgOp)))
      return failure();
    rewriter.eraseOp(op);
    return success();
  }
};

/// Local folding pattern for AffineApplyOp that we can apply greedily.
/// This replaces AffineApplyOp by the proper value in cases where the
/// associated map is trivial.
/// A trivial map here is defined as a map with a single result and either:
///   1. Zero operand + returns a single AffineConstantExpr
///   2. One operand + returns a single AffineDimExpr
///   3. One operand + returns a single AffineSymbolExpr
///
/// In the first case, the AffineApplyOp is replaced by a new constant. In the
/// other cases, it is replaced by its unique operand.
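///
/// For example (illustrative):
///   %0 = affine.apply affine_map<() -> (7)>()       folds to `arith.constant 7 : index`
///   %1 = affine.apply affine_map<(d0) -> (d0)>(%iv) is replaced by %iv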
struct FoldAffineOp : public RewritePattern {
  FoldAffineOp(MLIRContext *context)
      : RewritePattern(AffineApplyOp::getOperationName(), 0, context) {}

  LogicalResult matchAndRewrite(Operation *op,
                                PatternRewriter &rewriter) const override {
    AffineApplyOp affineApplyOp = cast<AffineApplyOp>(op);
    auto map = affineApplyOp.getAffineMap();
    if (map.getNumResults() != 1 || map.getNumInputs() > 1)
      return failure();

    AffineExpr expr = map.getResult(0);
    if (map.getNumInputs() == 0) {
      if (auto val = expr.dyn_cast<AffineConstantExpr>()) {
        rewriter.replaceOpWithNewOp<arith::ConstantIndexOp>(op, val.getValue());
        return success();
      }
      return failure();
    }
    if (expr.dyn_cast<AffineDimExpr>() || expr.dyn_cast<AffineSymbolExpr>()) {
      rewriter.replaceOp(op, op->getOperand(0));
      return success();
    }
    return failure();
  }
};

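/// Lowers all Linalg ops nested under `funcOp` to loops of `LoopType` by
/// greedily applying LinalgRewritePattern together with dim/affine.apply
/// canonicalizations and the local FoldAffineOp folding defined above.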
template <typename LoopType>
static void lowerLinalgToLoopsImpl(func::FuncOp funcOp) {
  MLIRContext *context = funcOp.getContext();
  RewritePatternSet patterns(context);
  patterns.add<LinalgRewritePattern<LoopType>>(context);
  memref::DimOp::getCanonicalizationPatterns(patterns, context);
  tensor::DimOp::getCanonicalizationPatterns(patterns, context);
  AffineApplyOp::getCanonicalizationPatterns(patterns, context);
  patterns.add<FoldAffineOp>(context);
  // Just apply the patterns greedily.
  (void)applyPatternsAndFoldGreedily(funcOp, std::move(patterns));
}

struct LowerToAffineLoops
    : public LinalgLowerToAffineLoopsBase<LowerToAffineLoops> {
  void getDependentDialects(DialectRegistry &registry) const override {
    registry.insert<memref::MemRefDialect>();
  }
  void runOnOperation() override {
    lowerLinalgToLoopsImpl<AffineForOp>(getOperation());
  }
};

struct LowerToLoops : public LinalgLowerToLoopsBase<LowerToLoops> {
  void getDependentDialects(DialectRegistry &registry) const override {
    registry.insert<memref::MemRefDialect, scf::SCFDialect>();
  }
  void runOnOperation() override {
    lowerLinalgToLoopsImpl<scf::ForOp>(getOperation());
  }
};

struct LowerToParallelLoops
    : public LinalgLowerToParallelLoopsBase<LowerToParallelLoops> {
  void runOnOperation() override {
    lowerLinalgToLoopsImpl<scf::ParallelOp>(getOperation());
  }
};

} // namespace

std::unique_ptr<OperationPass<func::FuncOp>>
mlir::createConvertLinalgToLoopsPass() {
  return std::make_unique<LowerToLoops>();
}

std::unique_ptr<OperationPass<func::FuncOp>>
mlir::createConvertLinalgToParallelLoopsPass() {
  return std::make_unique<LowerToParallelLoops>();
}

std::unique_ptr<OperationPass<func::FuncOp>>
mlir::createConvertLinalgToAffineLoopsPass() {
  return std::make_unique<LowerToAffineLoops>();
}

/// Emits a loop nest of `affine.for` with the proper body for `linalgOp`.
FailureOr<LinalgLoops>
mlir::linalg::linalgOpToAffineLoops(PatternRewriter &rewriter,
                                    LinalgOp linalgOp) {
  return linalgOpToLoopsImpl<AffineForOp>(rewriter, linalgOp);
}

/// Emits a loop nest of `scf.for` with the proper body for `linalgOp`.
FailureOr<LinalgLoops> mlir::linalg::linalgOpToLoops(PatternRewriter &rewriter,
                                                     LinalgOp linalgOp) {
  return linalgOpToLoopsImpl<scf::ForOp>(rewriter, linalgOp);
}

/// Emits a loop nest of `scf.parallel` with the proper body for `linalgOp`.
FailureOr<LinalgLoops>
mlir::linalg::linalgOpToParallelLoops(PatternRewriter &rewriter,
                                      LinalgOp linalgOp) {
  return linalgOpToLoopsImpl<scf::ParallelOp>(rewriter, linalgOp);
}