1 //===- LinalgToStandard.cpp - conversion from Linalg to Standard dialect --===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 
9 #include "mlir/Conversion/LinalgToStandard/LinalgToStandard.h"
10 
11 #include "../PassDetail.h"
12 #include "mlir/Dialect/Affine/IR/AffineOps.h"
13 #include "mlir/Dialect/Func/IR/FuncOps.h"
14 #include "mlir/Dialect/LLVMIR/LLVMDialect.h"
15 #include "mlir/Dialect/Linalg/IR/Linalg.h"
16 #include "mlir/Dialect/Linalg/Transforms/Transforms.h"
17 #include "mlir/Dialect/MemRef/IR/MemRef.h"
18 #include "mlir/Dialect/SCF/IR/SCF.h"
19 
20 using namespace mlir;
21 using namespace mlir::linalg;
22 
23 /// Helper function to extract the operand types that are passed to the
24 /// generated CallOp. MemRefTypes have their layout canonicalized since the
25 /// information is not used in signature generation.
26 /// Note that static size information is not modified.
extractOperandTypes(Operation * op)27 static SmallVector<Type, 4> extractOperandTypes(Operation *op) {
28   SmallVector<Type, 4> result;
29   result.reserve(op->getNumOperands());
30   for (auto type : op->getOperandTypes()) {
31     // The underlying descriptor type (e.g. LLVM) does not have layout
32     // information. Canonicalizing the type at the level of std when going into
33     // a library call avoids needing to introduce DialectCastOp.
34     if (auto memrefType = type.dyn_cast<MemRefType>())
35       result.push_back(eraseStridedLayout(memrefType));
36     else
37       result.push_back(type);
38   }
39   return result;
40 }
41 
42 // Get a SymbolRefAttr containing the library function name for the LinalgOp.
43 // If the library function does not exist, insert a declaration.
getLibraryCallSymbolRef(Operation * op,PatternRewriter & rewriter)44 static FlatSymbolRefAttr getLibraryCallSymbolRef(Operation *op,
45                                                  PatternRewriter &rewriter) {
46   auto linalgOp = cast<LinalgOp>(op);
47   auto fnName = linalgOp.getLibraryCallName();
48   if (fnName.empty()) {
49     op->emitWarning("No library call defined for: ") << *op;
50     return {};
51   }
52 
53   // fnName is a dynamic std::string, unique it via a SymbolRefAttr.
54   FlatSymbolRefAttr fnNameAttr =
55       SymbolRefAttr::get(rewriter.getContext(), fnName);
56   auto module = op->getParentOfType<ModuleOp>();
57   if (module.lookupSymbol(fnNameAttr.getAttr()))
58     return fnNameAttr;
59 
60   SmallVector<Type, 4> inputTypes(extractOperandTypes(op));
61   assert(op->getNumResults() == 0 &&
62          "Library call for linalg operation can be generated only for ops that "
63          "have void return types");
64   auto libFnType = rewriter.getFunctionType(inputTypes, {});
65 
66   OpBuilder::InsertionGuard guard(rewriter);
67   // Insert before module terminator.
68   rewriter.setInsertionPoint(module.getBody(),
69                              std::prev(module.getBody()->end()));
70   func::FuncOp funcOp = rewriter.create<func::FuncOp>(
71       op->getLoc(), fnNameAttr.getValue(), libFnType);
72   // Insert a function attribute that will trigger the emission of the
73   // corresponding `_mlir_ciface_xxx` interface so that external libraries see
74   // a normalized ABI. This interface is added during std to llvm conversion.
75   funcOp->setAttr(LLVM::LLVMDialect::getEmitCWrapperAttrName(),
76                   UnitAttr::get(op->getContext()));
77   funcOp.setPrivate();
78   return fnNameAttr;
79 }
80 
81 static SmallVector<Value, 4>
createTypeCanonicalizedMemRefOperands(OpBuilder & b,Location loc,ValueRange operands)82 createTypeCanonicalizedMemRefOperands(OpBuilder &b, Location loc,
83                                       ValueRange operands) {
84   SmallVector<Value, 4> res;
85   res.reserve(operands.size());
86   for (auto op : operands) {
87     auto memrefType = op.getType().dyn_cast<MemRefType>();
88     if (!memrefType) {
89       res.push_back(op);
90       continue;
91     }
92     Value cast =
93         b.create<memref::CastOp>(loc, eraseStridedLayout(memrefType), op);
94     res.push_back(cast);
95   }
96   return res;
97 }
98 
matchAndRewrite(LinalgOp op,PatternRewriter & rewriter) const99 LogicalResult mlir::linalg::LinalgOpToLibraryCallRewrite::matchAndRewrite(
100     LinalgOp op, PatternRewriter &rewriter) const {
101   auto libraryCallName = getLibraryCallSymbolRef(op, rewriter);
102   if (!libraryCallName)
103     return failure();
104 
105   // TODO: Add support for more complex library call signatures that include
106   // indices or captured values.
107   rewriter.replaceOpWithNewOp<func::CallOp>(
108       op, libraryCallName.getValue(), TypeRange(),
109       createTypeCanonicalizedMemRefOperands(rewriter, op->getLoc(),
110                                             op->getOperands()));
111   return success();
112 }
113 
114 /// Populate the given list with patterns that convert from Linalg to Standard.
populateLinalgToStandardConversionPatterns(RewritePatternSet & patterns)115 void mlir::linalg::populateLinalgToStandardConversionPatterns(
116     RewritePatternSet &patterns) {
117   // TODO: ConvOp conversion needs to export a descriptor with relevant
118   // attribute values such as kernel striding and dilation.
119   patterns.add<LinalgOpToLibraryCallRewrite>(patterns.getContext());
120 }
121 
122 namespace {
123 struct ConvertLinalgToStandardPass
124     : public ConvertLinalgToStandardBase<ConvertLinalgToStandardPass> {
125   void runOnOperation() override;
126 };
127 } // namespace
128 
runOnOperation()129 void ConvertLinalgToStandardPass::runOnOperation() {
130   auto module = getOperation();
131   ConversionTarget target(getContext());
132   target.addLegalDialect<AffineDialect, arith::ArithmeticDialect,
133                          func::FuncDialect, memref::MemRefDialect,
134                          scf::SCFDialect>();
135   target.addLegalOp<ModuleOp, func::FuncOp, func::ReturnOp>();
136   RewritePatternSet patterns(&getContext());
137   populateLinalgToStandardConversionPatterns(patterns);
138   if (failed(applyFullConversion(module, target, std::move(patterns))))
139     signalPassFailure();
140 }
141 
142 std::unique_ptr<OperationPass<ModuleOp>>
createConvertLinalgToStandardPass()143 mlir::createConvertLinalgToStandardPass() {
144   return std::make_unique<ConvertLinalgToStandardPass>();
145 }
146