1 //===- LinalgToStandard.cpp - conversion from Linalg to Standard dialect --===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 
9 #include "mlir/Conversion/LinalgToStandard/LinalgToStandard.h"
10 
11 #include "../PassDetail.h"
12 #include "mlir/Dialect/Affine/IR/AffineOps.h"
13 #include "mlir/Dialect/Linalg/IR/LinalgOps.h"
14 #include "mlir/Dialect/SCF/SCF.h"
15 #include "mlir/Dialect/StandardOps/IR/Ops.h"
16 
17 using namespace mlir;
18 using namespace mlir::linalg;
19 
20 /// Helper function to extract the operand types that are passed to the
21 /// generated CallOp. MemRefTypes have their layout canonicalized since the
22 /// information is not used in signature generation.
23 /// Note that static size information is not modified.
24 template <typename LinalgOp>
25 static SmallVector<Type, 4> extractOperandTypes(Operation *op) {
26   SmallVector<Type, 4> result;
27   result.reserve(op->getNumOperands());
28   for (auto type : op->getOperandTypes()) {
29     // The underlying descriptor type (e.g. LLVM) does not have layout
30     // information. Canonicalizing the type at the level of std when going into
31     // a library call avoids needing to introduce DialectCastOp.
32     if (auto memrefType = type.dyn_cast<MemRefType>())
33       result.push_back(eraseStridedLayout(memrefType));
34     else
35       result.push_back(type);
36   }
37   return result;
38 }
39 
40 template <>
41 SmallVector<Type, 4> extractOperandTypes<IndexedGenericOp>(Operation *op) {
42   auto *ctx = op->getContext();
43   auto indexedGenericOp = cast<IndexedGenericOp>(op);
44   auto numLoops = indexedGenericOp.getNumLoops();
45 
46   SmallVector<Type, 4> result(numLoops, IndexType::get(ctx));
47   auto canonicalizedOperands = extractOperandTypes<LinalgOp>(op);
48   result.append(canonicalizedOperands.begin(), canonicalizedOperands.end());
49   return result;
50 }
51 
52 // Get a SymbolRefAttr containing the library function name for the LinalgOp.
53 // If the library function does not exist, insert a declaration.
54 template <typename LinalgOp>
55 static FlatSymbolRefAttr getLibraryCallSymbolRef(Operation *op,
56                                                  PatternRewriter &rewriter) {
57   auto linalgOp = cast<LinalgOp>(op);
58   auto fnName = linalgOp.getLibraryCallName();
59   if (fnName.empty()) {
60     op->emitWarning("No library call defined for: ") << *op;
61     return {};
62   }
63 
64   // fnName is a dynamic std::string, unique it via a SymbolRefAttr.
65   FlatSymbolRefAttr fnNameAttr = rewriter.getSymbolRefAttr(fnName);
66   auto module = op->getParentOfType<ModuleOp>();
67   if (module.lookupSymbol(fnName)) {
68     return fnNameAttr;
69   }
70 
71   SmallVector<Type, 4> inputTypes(extractOperandTypes<LinalgOp>(op));
72   assert(op->getNumResults() == 0 &&
73          "Library call for linalg operation can be generated only for ops that "
74          "have void return types");
75   auto libFnType = FunctionType::get(inputTypes, {}, rewriter.getContext());
76 
77   OpBuilder::InsertionGuard guard(rewriter);
78   // Insert before module terminator.
79   rewriter.setInsertionPoint(module.getBody(),
80                              std::prev(module.getBody()->end()));
81   FuncOp funcOp =
82       rewriter.create<FuncOp>(op->getLoc(), fnNameAttr.getValue(), libFnType);
83   // Insert a function attribute that will trigger the emission of the
84   // corresponding `_mlir_ciface_xxx` interface so that external libraries see
85   // a normalized ABI. This interface is added during std to llvm conversion.
86   funcOp.setAttr("llvm.emit_c_interface", UnitAttr::get(op->getContext()));
87   return fnNameAttr;
88 }
89 
90 namespace {
91 
92 SmallVector<Value, 4>
93 createTypeCanonicalizedMemRefOperands(OpBuilder &b, Location loc,
94                                       ValueRange operands) {
95   SmallVector<Value, 4> res;
96   res.reserve(operands.size());
97   for (auto op : operands) {
98     auto memrefType = op.getType().dyn_cast<MemRefType>();
99     if (!memrefType) {
100       res.push_back(op);
101       continue;
102     }
103     Value cast =
104         b.create<MemRefCastOp>(loc, eraseStridedLayout(memrefType), op);
105     res.push_back(cast);
106   }
107   return res;
108 }
109 
110 // LinalgOpConversion<LinalgOp> creates a new call to the type-canonicalized
111 // `LinalgOp::getLibraryCallName()` function.
112 // The implementation of the function can be either in the same module or in an
113 // externally linked library.
114 template <typename LinalgOp>
115 class LinalgOpConversion : public OpRewritePattern<LinalgOp> {
116 public:
117   using OpRewritePattern<LinalgOp>::OpRewritePattern;
118 
119   LogicalResult matchAndRewrite(LinalgOp op,
120                                 PatternRewriter &rewriter) const override {
121     auto libraryCallName = getLibraryCallSymbolRef<LinalgOp>(op, rewriter);
122     if (!libraryCallName)
123       return failure();
124 
125     rewriter.replaceOpWithNewOp<mlir::CallOp>(
126         op, libraryCallName.getValue(), ArrayRef<Type>{},
127         createTypeCanonicalizedMemRefOperands(rewriter, op.getLoc(),
128                                               op.getOperands()));
129     return success();
130   }
131 };
132 
133 /// Conversion pattern specialization for CopyOp. This kicks in when both input
134 /// and output permutations are left unspecified or are the identity.
135 template <>
136 class LinalgOpConversion<CopyOp> : public OpRewritePattern<CopyOp> {
137 public:
138   using OpRewritePattern<CopyOp>::OpRewritePattern;
139 
140   LogicalResult matchAndRewrite(CopyOp op,
141                                 PatternRewriter &rewriter) const override {
142     auto inputPerm = op.inputPermutation();
143     if (inputPerm.hasValue() && !inputPerm->isIdentity())
144       return failure();
145     auto outputPerm = op.outputPermutation();
146     if (outputPerm.hasValue() && !outputPerm->isIdentity())
147       return failure();
148 
149     auto libraryCallName = getLibraryCallSymbolRef<CopyOp>(op, rewriter);
150     if (!libraryCallName)
151       return failure();
152 
153     rewriter.replaceOpWithNewOp<mlir::CallOp>(
154         op, libraryCallName.getValue(), ArrayRef<Type>{},
155         createTypeCanonicalizedMemRefOperands(rewriter, op.getLoc(),
156                                               op.getOperands()));
157     return success();
158   }
159 };
160 
161 /// Conversion pattern specialization for IndexedGenericOp.
162 template <>
163 class LinalgOpConversion<IndexedGenericOp>
164     : public OpRewritePattern<IndexedGenericOp> {
165 public:
166   using OpRewritePattern<IndexedGenericOp>::OpRewritePattern;
167 
168   LogicalResult matchAndRewrite(IndexedGenericOp op,
169                                 PatternRewriter &rewriter) const override {
170     auto libraryCallName =
171         getLibraryCallSymbolRef<IndexedGenericOp>(op, rewriter);
172     if (!libraryCallName)
173       return failure();
174 
175     // TODO: Use induction variables values instead of zeros, when
176     // IndexedGenericOp is tiled.
177     auto zero = rewriter.create<mlir::ConstantOp>(
178         op.getLoc(), rewriter.getIntegerAttr(rewriter.getIndexType(), 0));
179     auto indexedGenericOp = cast<IndexedGenericOp>(op);
180     auto numLoops = indexedGenericOp.getNumLoops();
181     SmallVector<Value, 4> operands;
182     operands.reserve(numLoops + op.getNumOperands());
183     for (unsigned i = 0; i < numLoops; ++i)
184       operands.push_back(zero);
185     for (auto operand : op.getOperands())
186       operands.push_back(operand);
187     rewriter.replaceOpWithNewOp<mlir::CallOp>(
188         op, libraryCallName.getValue(), ArrayRef<Type>{},
189         createTypeCanonicalizedMemRefOperands(rewriter, op.getLoc(), operands));
190     return success();
191   }
192 };
193 
194 /// A non-conversion rewrite pattern kicks in to convert CopyOp with
195 /// permutations into a sequence of TransposeOp and permutation-free CopyOp.
196 /// This interplays together with TransposeOpConversion and
197 /// LinalgConversion<CopyOp> to create a path to the LLVM dialect.
198 class CopyTransposeConversion : public OpRewritePattern<CopyOp> {
199 public:
200   using OpRewritePattern<CopyOp>::OpRewritePattern;
201 
202   LogicalResult matchAndRewrite(CopyOp op,
203                                 PatternRewriter &rewriter) const override {
204     Value in = op.input(), out = op.output();
205 
206     // If either inputPerm or outputPerm are non-identities, insert transposes.
207     auto inputPerm = op.inputPermutation();
208     if (inputPerm.hasValue() && !inputPerm->isIdentity())
209       in = rewriter.create<linalg::TransposeOp>(op.getLoc(), in,
210                                                 AffineMapAttr::get(*inputPerm));
211     auto outputPerm = op.outputPermutation();
212     if (outputPerm.hasValue() && !outputPerm->isIdentity())
213       out = rewriter.create<linalg::TransposeOp>(
214           op.getLoc(), out, AffineMapAttr::get(*outputPerm));
215 
216     // If nothing was transposed, fail and let the conversion kick in.
217     if (in == op.input() && out == op.output())
218       return failure();
219 
220     rewriter.replaceOpWithNewOp<CopyOp>(op, in, out);
221     return success();
222   }
223 };
224 } // namespace
225 
/// Populate the given list with patterns that convert from Linalg to Standard.
/// Each listed op kind gets a LinalgOpConversion that replaces the op with a
/// call to its library function. CopyOp is matched by two patterns with
/// disjoint match conditions:
/// - CopyTransposeConversion handles copies with a non-identity input or
///   output permutation.
/// - LinalgOpConversion<CopyOp> handles the remaining copies.
/// `ctx` is forwarded to the constructor of every inserted pattern.
void mlir::populateLinalgToStandardConversionPatterns(
    OwningRewritePatternList &patterns, MLIRContext *ctx) {
  // TODO: ConvOp conversion needs to export a descriptor with relevant
  // attribute values such as kernel striding and dilation.
  // clang-format off
  patterns.insert<
      CopyTransposeConversion,
      LinalgOpConversion<ConvOp>,
      LinalgOpConversion<PoolingMaxOp>,
      LinalgOpConversion<PoolingMinOp>,
      LinalgOpConversion<PoolingSumOp>,
      LinalgOpConversion<CopyOp>,
      LinalgOpConversion<FillOp>,
      LinalgOpConversion<GenericOp>,
      LinalgOpConversion<IndexedGenericOp>>(ctx);
  // TODO: collect all auto-generated named ops with a tblgen directive.
  // Named ops below all go through the primary LinalgOpConversion template.
  patterns.insert<
      LinalgOpConversion<DotOp>,
      LinalgOpConversion<BatchMatmulOp>,
      LinalgOpConversion<MatvecOp>,
      LinalgOpConversion<VecmatOp>,
      LinalgOpConversion<MatmulOp>,
      LinalgOpConversion<ConvWOp>,
      LinalgOpConversion<ConvNWCOp>,
      LinalgOpConversion<ConvNCWOp>,
      LinalgOpConversion<ConvHWOp>,
      LinalgOpConversion<ConvNHWCOp>,
      LinalgOpConversion<ConvNCHWOp>,
      LinalgOpConversion<ConvDHWOp>,
      LinalgOpConversion<ConvNDHWCOp>,
      LinalgOpConversion<ConvNCDHWOp>>(ctx);
  // clang-format on
}
260 
261 namespace {
262 struct ConvertLinalgToStandardPass
263     : public ConvertLinalgToStandardBase<ConvertLinalgToStandardPass> {
264   void runOnOperation() override;
265 };
266 } // namespace
267 
268 void ConvertLinalgToStandardPass::runOnOperation() {
269   auto module = getOperation();
270   ConversionTarget target(getContext());
271   target.addLegalDialect<AffineDialect, scf::SCFDialect, StandardOpsDialect>();
272   target.addLegalOp<ModuleOp, FuncOp, ModuleTerminatorOp, ReturnOp>();
273   target.addLegalOp<linalg::TransposeOp, linalg::ReshapeOp, linalg::RangeOp>();
274   OwningRewritePatternList patterns;
275   populateLinalgToStandardConversionPatterns(patterns, &getContext());
276   if (failed(applyFullConversion(module, target, patterns)))
277     signalPassFailure();
278 }
279 
/// Returns a newly allocated ConvertLinalgToStandardPass, exposed as a pass
/// that operates on ModuleOp.
std::unique_ptr<OperationPass<ModuleOp>>
mlir::createConvertLinalgToStandardPass() {
  return std::make_unique<ConvertLinalgToStandardPass>();
}
284