1 //===- LinalgToStandard.cpp - conversion from Linalg to Standard dialect --===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 
9 #include "mlir/Conversion/LinalgToStandard/LinalgToStandard.h"
10 
11 #include "../PassDetail.h"
12 #include "mlir/Dialect/Affine/IR/AffineOps.h"
13 #include "mlir/Dialect/Linalg/IR/LinalgOps.h"
14 #include "mlir/Dialect/Linalg/Transforms/Transforms.h"
15 #include "mlir/Dialect/SCF/SCF.h"
16 #include "mlir/Dialect/StandardOps/IR/Ops.h"
17 
18 using namespace mlir;
19 using namespace mlir::linalg;
20 
21 /// Helper function to extract the operand types that are passed to the
22 /// generated CallOp. MemRefTypes have their layout canonicalized since the
23 /// information is not used in signature generation.
24 /// Note that static size information is not modified.
25 static SmallVector<Type, 4> extractOperandTypes(Operation *op) {
26   SmallVector<Type, 4> result;
27   result.reserve(op->getNumOperands());
28   if (auto indexedGenericOp = dyn_cast<IndexedGenericOp>(op)) {
29     auto *ctx = op->getContext();
30     auto numLoops = indexedGenericOp.getNumLoops();
31     result.reserve(op->getNumOperands() + numLoops);
32     result.assign(numLoops, IndexType::get(ctx));
33   }
34   for (auto type : op->getOperandTypes()) {
35     // The underlying descriptor type (e.g. LLVM) does not have layout
36     // information. Canonicalizing the type at the level of std when going into
37     // a library call avoids needing to introduce DialectCastOp.
38     if (auto memrefType = type.dyn_cast<MemRefType>())
39       result.push_back(eraseStridedLayout(memrefType));
40     else
41       result.push_back(type);
42   }
43   return result;
44 }
45 
46 // Get a SymbolRefAttr containing the library function name for the LinalgOp.
47 // If the library function does not exist, insert a declaration.
48 static FlatSymbolRefAttr getLibraryCallSymbolRef(Operation *op,
49                                                  PatternRewriter &rewriter) {
50   auto linalgOp = cast<LinalgOp>(op);
51   auto fnName = linalgOp.getLibraryCallName();
52   if (fnName.empty()) {
53     op->emitWarning("No library call defined for: ") << *op;
54     return {};
55   }
56 
57   // fnName is a dynamic std::string, unique it via a SymbolRefAttr.
58   FlatSymbolRefAttr fnNameAttr = rewriter.getSymbolRefAttr(fnName);
59   auto module = op->getParentOfType<ModuleOp>();
60   if (module.lookupSymbol(fnName)) {
61     return fnNameAttr;
62   }
63 
64   SmallVector<Type, 4> inputTypes(extractOperandTypes(op));
65   assert(op->getNumResults() == 0 &&
66          "Library call for linalg operation can be generated only for ops that "
67          "have void return types");
68   auto libFnType = FunctionType::get(inputTypes, {}, rewriter.getContext());
69 
70   OpBuilder::InsertionGuard guard(rewriter);
71   // Insert before module terminator.
72   rewriter.setInsertionPoint(module.getBody(),
73                              std::prev(module.getBody()->end()));
74   FuncOp funcOp =
75       rewriter.create<FuncOp>(op->getLoc(), fnNameAttr.getValue(), libFnType);
76   // Insert a function attribute that will trigger the emission of the
77   // corresponding `_mlir_ciface_xxx` interface so that external libraries see
78   // a normalized ABI. This interface is added during std to llvm conversion.
79   funcOp.setAttr("llvm.emit_c_interface", UnitAttr::get(op->getContext()));
80   return fnNameAttr;
81 }
82 
83 static SmallVector<Value, 4>
84 createTypeCanonicalizedMemRefOperands(OpBuilder &b, Location loc,
85                                       ValueRange operands) {
86   SmallVector<Value, 4> res;
87   res.reserve(operands.size());
88   for (auto op : operands) {
89     auto memrefType = op.getType().dyn_cast<MemRefType>();
90     if (!memrefType) {
91       res.push_back(op);
92       continue;
93     }
94     Value cast =
95         b.create<MemRefCastOp>(loc, eraseStridedLayout(memrefType), op);
96     res.push_back(cast);
97   }
98   return res;
99 }
100 
101 LogicalResult mlir::linalg::LinalgOpToLibraryCallRewrite::matchAndRewrite(
102     Operation *op, PatternRewriter &rewriter) const {
103   // Only LinalgOp for which there is no specialized pattern go through this.
104   if (!isa<LinalgOp>(op) || isa<CopyOp>(op) || isa<IndexedGenericOp>(op))
105     return failure();
106 
107   auto libraryCallName = getLibraryCallSymbolRef(op, rewriter);
108   if (!libraryCallName)
109     return failure();
110 
111   rewriter.replaceOpWithNewOp<mlir::CallOp>(
112       op, libraryCallName.getValue(), TypeRange(),
113       createTypeCanonicalizedMemRefOperands(rewriter, op->getLoc(),
114                                             op->getOperands()));
115   return success();
116 }
117 
118 LogicalResult mlir::linalg::CopyOpToLibraryCallRewrite::matchAndRewrite(
119     CopyOp op, PatternRewriter &rewriter) const {
120   auto inputPerm = op.inputPermutation();
121   if (inputPerm.hasValue() && !inputPerm->isIdentity())
122     return failure();
123   auto outputPerm = op.outputPermutation();
124   if (outputPerm.hasValue() && !outputPerm->isIdentity())
125     return failure();
126 
127   auto libraryCallName = getLibraryCallSymbolRef(op, rewriter);
128   if (!libraryCallName)
129     return failure();
130 
131   rewriter.replaceOpWithNewOp<mlir::CallOp>(
132       op, libraryCallName.getValue(), TypeRange(),
133       createTypeCanonicalizedMemRefOperands(rewriter, op.getLoc(),
134                                             op.getOperands()));
135   return success();
136 }
137 
138 LogicalResult mlir::linalg::CopyTransposeRewrite::matchAndRewrite(
139     CopyOp op, PatternRewriter &rewriter) const {
140   Value in = op.input(), out = op.output();
141 
142   // If either inputPerm or outputPerm are non-identities, insert transposes.
143   auto inputPerm = op.inputPermutation();
144   if (inputPerm.hasValue() && !inputPerm->isIdentity())
145     in = rewriter.create<TransposeOp>(op.getLoc(), in,
146                                       AffineMapAttr::get(*inputPerm));
147   auto outputPerm = op.outputPermutation();
148   if (outputPerm.hasValue() && !outputPerm->isIdentity())
149     out = rewriter.create<TransposeOp>(op.getLoc(), out,
150                                        AffineMapAttr::get(*outputPerm));
151 
152   // If nothing was transposed, fail and let the conversion kick in.
153   if (in == op.input() && out == op.output())
154     return failure();
155 
156   rewriter.replaceOpWithNewOp<CopyOp>(op, in, out);
157   return success();
158 }
159 
160 LogicalResult
161 mlir::linalg::IndexedGenericOpToLibraryCallRewrite::matchAndRewrite(
162     IndexedGenericOp op, PatternRewriter &rewriter) const {
163   auto libraryCallName = getLibraryCallSymbolRef(op, rewriter);
164   if (!libraryCallName)
165     return failure();
166 
167   // TODO: Use induction variables values instead of zeros, when
168   // IndexedGenericOp is tiled.
169   auto zero = rewriter.create<mlir::ConstantOp>(
170       op.getLoc(), rewriter.getIntegerAttr(rewriter.getIndexType(), 0));
171   auto indexedGenericOp = cast<IndexedGenericOp>(op);
172   auto numLoops = indexedGenericOp.getNumLoops();
173   SmallVector<Value, 4> operands;
174   operands.reserve(numLoops + op.getNumOperands());
175   for (unsigned i = 0; i < numLoops; ++i)
176     operands.push_back(zero);
177   for (auto operand : op.getOperands())
178     operands.push_back(operand);
179   rewriter.replaceOpWithNewOp<mlir::CallOp>(
180       op, libraryCallName.getValue(), TypeRange(),
181       createTypeCanonicalizedMemRefOperands(rewriter, op.getLoc(), operands));
182   return success();
183 }
184 
185 /// Populate the given list with patterns that convert from Linalg to Standard.
186 void mlir::linalg::populateLinalgToStandardConversionPatterns(
187     OwningRewritePatternList &patterns, MLIRContext *ctx) {
188   // TODO: ConvOp conversion needs to export a descriptor with relevant
189   // attribute values such as kernel striding and dilation.
190   // clang-format off
191   patterns.insert<
192       CopyOpToLibraryCallRewrite,
193       CopyTransposeRewrite,
194       IndexedGenericOpToLibraryCallRewrite>(ctx);
195   patterns.insert<LinalgOpToLibraryCallRewrite>();
196   // clang-format on
197 }
198 
199 namespace {
200 struct ConvertLinalgToStandardPass
201     : public ConvertLinalgToStandardBase<ConvertLinalgToStandardPass> {
202   void runOnOperation() override;
203 };
204 } // namespace
205 
206 void ConvertLinalgToStandardPass::runOnOperation() {
207   auto module = getOperation();
208   ConversionTarget target(getContext());
209   target.addLegalDialect<AffineDialect, scf::SCFDialect, StandardOpsDialect>();
210   target.addLegalOp<ModuleOp, FuncOp, ModuleTerminatorOp, ReturnOp>();
211   target.addLegalOp<linalg::ReshapeOp, linalg::RangeOp>();
212   OwningRewritePatternList patterns;
213   populateLinalgToStandardConversionPatterns(patterns, &getContext());
214   if (failed(applyFullConversion(module, target, std::move(patterns))))
215     signalPassFailure();
216 }
217 
218 std::unique_ptr<OperationPass<ModuleOp>>
219 mlir::createConvertLinalgToStandardPass() {
220   return std::make_unique<ConvertLinalgToStandardPass>();
221 }
222