1 //===- LinalgToLLVM.cpp - conversion from Linalg to LLVM dialect ----------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8
9 #include "mlir/Conversion/LinalgToLLVM/LinalgToLLVM.h"
10
11 #include "../PassDetail.h"
12 #include "mlir/Conversion/AffineToStandard/AffineToStandard.h"
13 #include "mlir/Conversion/LLVMCommon/ConversionTarget.h"
14 #include "mlir/Conversion/LLVMCommon/Pattern.h"
15 #include "mlir/Conversion/LLVMCommon/TypeConverter.h"
16 #include "mlir/Conversion/MemRefToLLVM/MemRefToLLVM.h"
17 #include "mlir/Conversion/SCFToControlFlow/SCFToControlFlow.h"
18 #include "mlir/Conversion/VectorToLLVM/ConvertVectorToLLVM.h"
19 #include "mlir/Conversion/VectorToSCF/VectorToSCF.h"
20 #include "mlir/Dialect/LLVMIR/LLVMDialect.h"
21 #include "mlir/Dialect/Linalg/IR/Linalg.h"
22 #include "mlir/Dialect/Linalg/Passes.h"
23 #include "mlir/Dialect/SCF/IR/SCF.h"
24 #include "mlir/IR/AffineExpr.h"
25 #include "mlir/IR/AffineMap.h"
26 #include "mlir/IR/Attributes.h"
27 #include "mlir/IR/Builders.h"
28 #include "mlir/IR/BuiltinOps.h"
29 #include "mlir/IR/BuiltinTypes.h"
30 #include "mlir/IR/MLIRContext.h"
31 #include "mlir/IR/Operation.h"
32 #include "mlir/IR/PatternMatch.h"
33 #include "mlir/IR/Types.h"
34 #include "mlir/Support/LogicalResult.h"
35 #include "mlir/Transforms/DialectConversion.h"
36 #include "mlir/Transforms/Passes.h"
37 #include "llvm/ADT/SetVector.h"
38 #include "llvm/IR/DerivedTypes.h"
39 #include "llvm/IR/Module.h"
40 #include "llvm/IR/Type.h"
41 #include "llvm/Support/Allocator.h"
42 #include "llvm/Support/ErrorHandling.h"
43
44 using namespace mlir;
45 using namespace mlir::LLVM;
46 using namespace mlir::linalg;
47
48 template <typename T>
getPtrToElementType(T containerType,LLVMTypeConverter & lowering)49 static Type getPtrToElementType(T containerType, LLVMTypeConverter &lowering) {
50 return LLVMPointerType::get(
51 lowering.convertType(containerType.getElementType()));
52 }
53
54 namespace {
55 // YieldOp produces and LLVM::ReturnOp.
56 class YieldOpConversion : public ConvertOpToLLVMPattern<linalg::YieldOp> {
57 public:
58 using ConvertOpToLLVMPattern<linalg::YieldOp>::ConvertOpToLLVMPattern;
59
60 LogicalResult
matchAndRewrite(linalg::YieldOp op,OpAdaptor adaptor,ConversionPatternRewriter & rewriter) const61 matchAndRewrite(linalg::YieldOp op, OpAdaptor adaptor,
62 ConversionPatternRewriter &rewriter) const override {
63 rewriter.replaceOpWithNewOp<LLVM::ReturnOp>(op, adaptor.getOperands());
64 return success();
65 }
66 };
67 } // namespace
68
69 /// Populate the given list with patterns that convert from Linalg to LLVM.
populateLinalgToLLVMConversionPatterns(LLVMTypeConverter & converter,RewritePatternSet & patterns)70 void mlir::populateLinalgToLLVMConversionPatterns(LLVMTypeConverter &converter,
71 RewritePatternSet &patterns) {
72 patterns.add<YieldOpConversion>(converter);
73 }
74
75 namespace {
76 struct ConvertLinalgToLLVMPass
77 : public ConvertLinalgToLLVMBase<ConvertLinalgToLLVMPass> {
78 void runOnOperation() override;
79 };
80 } // namespace
81
runOnOperation()82 void ConvertLinalgToLLVMPass::runOnOperation() {
83 auto module = getOperation();
84
85 // Convert to the LLVM IR dialect using the converter defined above.
86 RewritePatternSet patterns(&getContext());
87 LLVMTypeConverter converter(&getContext());
88 populateLinalgToLLVMConversionPatterns(converter, patterns);
89 populateMemRefToLLVMConversionPatterns(converter, patterns);
90
91 LLVMConversionTarget target(getContext());
92 target.addLegalOp<ModuleOp>();
93 if (failed(applyPartialConversion(module, target, std::move(patterns))))
94 signalPassFailure();
95 }
96
createConvertLinalgToLLVMPass()97 std::unique_ptr<OperationPass<ModuleOp>> mlir::createConvertLinalgToLLVMPass() {
98 return std::make_unique<ConvertLinalgToLLVMPass>();
99 }
100