1 //===- LinalgToStandard.cpp - conversion from Linalg to Standard dialect --===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 
9 #include "mlir/Conversion/LinalgToStandard/LinalgToStandard.h"
10 
11 #include "../PassDetail.h"
12 #include "mlir/Dialect/Affine/IR/AffineOps.h"
13 #include "mlir/Dialect/Linalg/IR/LinalgOps.h"
14 #include "mlir/Dialect/Linalg/Transforms/Transforms.h"
15 #include "mlir/Dialect/MemRef/IR/MemRef.h"
16 #include "mlir/Dialect/SCF/SCF.h"
17 #include "mlir/Dialect/StandardOps/IR/Ops.h"
18 
19 using namespace mlir;
20 using namespace mlir::linalg;
21 
22 /// Helper function to extract the operand types that are passed to the
23 /// generated CallOp. MemRefTypes have their layout canonicalized since the
24 /// information is not used in signature generation.
25 /// Note that static size information is not modified.
26 static SmallVector<Type, 4> extractOperandTypes(Operation *op) {
27   SmallVector<Type, 4> result;
28   result.reserve(op->getNumOperands());
29   if (auto indexedGenericOp = dyn_cast<IndexedGenericOp>(op)) {
30     auto *ctx = op->getContext();
31     auto numLoops = indexedGenericOp.getNumLoops();
32     result.reserve(op->getNumOperands() + numLoops);
33     result.assign(numLoops, IndexType::get(ctx));
34   }
35   for (auto type : op->getOperandTypes()) {
36     // The underlying descriptor type (e.g. LLVM) does not have layout
37     // information. Canonicalizing the type at the level of std when going into
38     // a library call avoids needing to introduce DialectCastOp.
39     if (auto memrefType = type.dyn_cast<MemRefType>())
40       result.push_back(eraseStridedLayout(memrefType));
41     else
42       result.push_back(type);
43   }
44   return result;
45 }
46 
47 // Get a SymbolRefAttr containing the library function name for the LinalgOp.
48 // If the library function does not exist, insert a declaration.
49 static FlatSymbolRefAttr getLibraryCallSymbolRef(Operation *op,
50                                                  PatternRewriter &rewriter) {
51   auto linalgOp = cast<LinalgOp>(op);
52   auto fnName = linalgOp.getLibraryCallName();
53   if (fnName.empty()) {
54     op->emitWarning("No library call defined for: ") << *op;
55     return {};
56   }
57 
58   // fnName is a dynamic std::string, unique it via a SymbolRefAttr.
59   FlatSymbolRefAttr fnNameAttr = rewriter.getSymbolRefAttr(fnName);
60   auto module = op->getParentOfType<ModuleOp>();
61   if (module.lookupSymbol(fnName)) {
62     return fnNameAttr;
63   }
64 
65   SmallVector<Type, 4> inputTypes(extractOperandTypes(op));
66   assert(op->getNumResults() == 0 &&
67          "Library call for linalg operation can be generated only for ops that "
68          "have void return types");
69   auto libFnType = rewriter.getFunctionType(inputTypes, {});
70 
71   OpBuilder::InsertionGuard guard(rewriter);
72   // Insert before module terminator.
73   rewriter.setInsertionPoint(module.getBody(),
74                              std::prev(module.getBody()->end()));
75   FuncOp funcOp =
76       rewriter.create<FuncOp>(op->getLoc(), fnNameAttr.getValue(), libFnType);
77   // Insert a function attribute that will trigger the emission of the
78   // corresponding `_mlir_ciface_xxx` interface so that external libraries see
79   // a normalized ABI. This interface is added during std to llvm conversion.
80   funcOp->setAttr("llvm.emit_c_interface", UnitAttr::get(op->getContext()));
81   funcOp.setPrivate();
82   return fnNameAttr;
83 }
84 
85 static SmallVector<Value, 4>
86 createTypeCanonicalizedMemRefOperands(OpBuilder &b, Location loc,
87                                       ValueRange operands) {
88   SmallVector<Value, 4> res;
89   res.reserve(operands.size());
90   for (auto op : operands) {
91     auto memrefType = op.getType().dyn_cast<MemRefType>();
92     if (!memrefType) {
93       res.push_back(op);
94       continue;
95     }
96     Value cast =
97         b.create<memref::CastOp>(loc, eraseStridedLayout(memrefType), op);
98     res.push_back(cast);
99   }
100   return res;
101 }
102 
103 LogicalResult mlir::linalg::LinalgOpToLibraryCallRewrite::matchAndRewrite(
104     LinalgOp op, PatternRewriter &rewriter) const {
105   // Only LinalgOp for which there is no specialized pattern go through this.
106   if (isa<CopyOp>(op) || isa<IndexedGenericOp>(op))
107     return failure();
108 
109   auto libraryCallName = getLibraryCallSymbolRef(op, rewriter);
110   if (!libraryCallName)
111     return failure();
112 
113   rewriter.replaceOpWithNewOp<mlir::CallOp>(
114       op, libraryCallName.getValue(), TypeRange(),
115       createTypeCanonicalizedMemRefOperands(rewriter, op->getLoc(),
116                                             op->getOperands()));
117   return success();
118 }
119 
120 LogicalResult mlir::linalg::CopyOpToLibraryCallRewrite::matchAndRewrite(
121     CopyOp op, PatternRewriter &rewriter) const {
122   auto inputPerm = op.inputPermutation();
123   if (inputPerm.hasValue() && !inputPerm->isIdentity())
124     return failure();
125   auto outputPerm = op.outputPermutation();
126   if (outputPerm.hasValue() && !outputPerm->isIdentity())
127     return failure();
128 
129   auto libraryCallName = getLibraryCallSymbolRef(op, rewriter);
130   if (!libraryCallName)
131     return failure();
132 
133   rewriter.replaceOpWithNewOp<mlir::CallOp>(
134       op, libraryCallName.getValue(), TypeRange(),
135       createTypeCanonicalizedMemRefOperands(rewriter, op.getLoc(),
136                                             op.getOperands()));
137   return success();
138 }
139 
140 LogicalResult mlir::linalg::CopyTransposeRewrite::matchAndRewrite(
141     CopyOp op, PatternRewriter &rewriter) const {
142   Value in = op.input(), out = op.output();
143 
144   // If either inputPerm or outputPerm are non-identities, insert transposes.
145   auto inputPerm = op.inputPermutation();
146   if (inputPerm.hasValue() && !inputPerm->isIdentity())
147     in = rewriter.create<memref::TransposeOp>(op.getLoc(), in,
148                                               AffineMapAttr::get(*inputPerm));
149   auto outputPerm = op.outputPermutation();
150   if (outputPerm.hasValue() && !outputPerm->isIdentity())
151     out = rewriter.create<memref::TransposeOp>(op.getLoc(), out,
152                                                AffineMapAttr::get(*outputPerm));
153 
154   // If nothing was transposed, fail and let the conversion kick in.
155   if (in == op.input() && out == op.output())
156     return failure();
157 
158   auto libraryCallName = getLibraryCallSymbolRef(op, rewriter);
159   if (!libraryCallName)
160     return failure();
161 
162   rewriter.replaceOpWithNewOp<mlir::CallOp>(
163       op, libraryCallName.getValue(), TypeRange(),
164       createTypeCanonicalizedMemRefOperands(rewriter, op.getLoc(), {in, out}));
165   return success();
166 }
167 
168 LogicalResult
169 mlir::linalg::IndexedGenericOpToLibraryCallRewrite::matchAndRewrite(
170     IndexedGenericOp op, PatternRewriter &rewriter) const {
171   auto libraryCallName = getLibraryCallSymbolRef(op, rewriter);
172   if (!libraryCallName)
173     return failure();
174 
175   // TODO: Use induction variables values instead of zeros, when
176   // IndexedGenericOp is tiled.
177   auto zero = rewriter.create<mlir::ConstantOp>(
178       op.getLoc(), rewriter.getIntegerAttr(rewriter.getIndexType(), 0));
179   auto indexedGenericOp = cast<IndexedGenericOp>(op);
180   auto numLoops = indexedGenericOp.getNumLoops();
181   SmallVector<Value, 4> operands;
182   operands.reserve(numLoops + op.getNumOperands());
183   for (unsigned i = 0; i < numLoops; ++i)
184     operands.push_back(zero);
185   for (auto operand : op.getOperands())
186     operands.push_back(operand);
187   rewriter.replaceOpWithNewOp<mlir::CallOp>(
188       op, libraryCallName.getValue(), TypeRange(),
189       createTypeCanonicalizedMemRefOperands(rewriter, op.getLoc(), operands));
190   return success();
191 }
192 
193 /// Populate the given list with patterns that convert from Linalg to Standard.
194 void mlir::linalg::populateLinalgToStandardConversionPatterns(
195     RewritePatternSet &patterns) {
196   // TODO: ConvOp conversion needs to export a descriptor with relevant
197   // attribute values such as kernel striding and dilation.
198   // clang-format off
199   patterns.add<
200       CopyOpToLibraryCallRewrite,
201       CopyTransposeRewrite,
202       IndexedGenericOpToLibraryCallRewrite,
203       LinalgOpToLibraryCallRewrite>(patterns.getContext());
204   // clang-format on
205 }
206 
207 namespace {
208 struct ConvertLinalgToStandardPass
209     : public ConvertLinalgToStandardBase<ConvertLinalgToStandardPass> {
210   void runOnOperation() override;
211 };
212 } // namespace
213 
214 void ConvertLinalgToStandardPass::runOnOperation() {
215   auto module = getOperation();
216   ConversionTarget target(getContext());
217   target.addLegalDialect<AffineDialect, memref::MemRefDialect, scf::SCFDialect,
218                          StandardOpsDialect>();
219   target.addLegalOp<ModuleOp, FuncOp, ReturnOp>();
220   target.addLegalOp<linalg::ReshapeOp, linalg::RangeOp>();
221   RewritePatternSet patterns(&getContext());
222   populateLinalgToStandardConversionPatterns(patterns);
223   if (failed(applyFullConversion(module, target, std::move(patterns))))
224     signalPassFailure();
225 }
226 
227 std::unique_ptr<OperationPass<ModuleOp>>
228 mlir::createConvertLinalgToStandardPass() {
229   return std::make_unique<ConvertLinalgToStandardPass>();
230 }
231