1 //===- LinalgToLLVM.cpp - conversion from Linalg to LLVM dialect ----------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 
9 #include "mlir/Conversion/LinalgToLLVM/LinalgToLLVM.h"
10 
11 #include "../PassDetail.h"
12 #include "mlir/Conversion/AffineToStandard/AffineToStandard.h"
13 #include "mlir/Conversion/LLVMCommon/ConversionTarget.h"
14 #include "mlir/Conversion/LLVMCommon/Pattern.h"
15 #include "mlir/Conversion/LLVMCommon/TypeConverter.h"
16 #include "mlir/Conversion/MemRefToLLVM/MemRefToLLVM.h"
17 #include "mlir/Conversion/SCFToStandard/SCFToStandard.h"
18 #include "mlir/Conversion/VectorToLLVM/ConvertVectorToLLVM.h"
19 #include "mlir/Conversion/VectorToSCF/VectorToSCF.h"
20 #include "mlir/Dialect/LLVMIR/LLVMDialect.h"
21 #include "mlir/Dialect/Linalg/IR/LinalgOps.h"
22 #include "mlir/Dialect/Linalg/IR/LinalgTypes.h"
23 #include "mlir/Dialect/Linalg/Passes.h"
24 #include "mlir/Dialect/SCF/SCF.h"
25 #include "mlir/IR/AffineExpr.h"
26 #include "mlir/IR/AffineMap.h"
27 #include "mlir/IR/Attributes.h"
28 #include "mlir/IR/Builders.h"
29 #include "mlir/IR/BuiltinOps.h"
30 #include "mlir/IR/BuiltinTypes.h"
31 #include "mlir/IR/MLIRContext.h"
32 #include "mlir/IR/Operation.h"
33 #include "mlir/IR/PatternMatch.h"
34 #include "mlir/IR/Types.h"
35 #include "mlir/Support/LogicalResult.h"
36 #include "mlir/Transforms/DialectConversion.h"
37 #include "mlir/Transforms/Passes.h"
38 #include "llvm/ADT/SetVector.h"
39 #include "llvm/IR/DerivedTypes.h"
40 #include "llvm/IR/Module.h"
41 #include "llvm/IR/Type.h"
42 #include "llvm/Support/Allocator.h"
43 #include "llvm/Support/ErrorHandling.h"
44 
45 using namespace mlir;
46 using namespace mlir::LLVM;
47 using namespace mlir::linalg;
48 
49 template <typename T>
50 static Type getPtrToElementType(T containerType, LLVMTypeConverter &lowering) {
51   return LLVMPointerType::get(
52       lowering.convertType(containerType.getElementType()));
53 }
54 
55 /// Convert the given range descriptor type to the LLVMIR dialect.
56 /// Range descriptor contains the range bounds and the step as 64-bit integers.
57 ///
58 /// struct {
59 ///   int64_t min;
60 ///   int64_t max;
61 ///   int64_t step;
62 /// };
63 static Type convertRangeType(RangeType t, LLVMTypeConverter &converter) {
64   auto *context = t.getContext();
65   auto int64Ty = converter.convertType(IntegerType::get(context, 64));
66   return LLVMStructType::getLiteral(context, {int64Ty, int64Ty, int64Ty});
67 }
68 
69 namespace {
70 // RangeOp creates a new range descriptor.
71 class RangeOpConversion : public ConvertOpToLLVMPattern<RangeOp> {
72 public:
73   using ConvertOpToLLVMPattern<RangeOp>::ConvertOpToLLVMPattern;
74 
75   LogicalResult
76   matchAndRewrite(RangeOp rangeOp, ArrayRef<Value> operands,
77                   ConversionPatternRewriter &rewriter) const override {
78     auto rangeDescriptorTy = convertRangeType(
79         rangeOp.getType().cast<RangeType>(), *getTypeConverter());
80 
81     ImplicitLocOpBuilder b(rangeOp->getLoc(), rewriter);
82 
83     // Fill in an aggregate value of the descriptor.
84     RangeOpAdaptor adaptor(operands);
85     Value desc = b.create<LLVM::UndefOp>(rangeDescriptorTy);
86     desc = b.create<LLVM::InsertValueOp>(desc, adaptor.min(),
87                                          rewriter.getI64ArrayAttr(0));
88     desc = b.create<LLVM::InsertValueOp>(desc, adaptor.max(),
89                                          rewriter.getI64ArrayAttr(1));
90     desc = b.create<LLVM::InsertValueOp>(desc, adaptor.step(),
91                                          rewriter.getI64ArrayAttr(2));
92     rewriter.replaceOp(rangeOp, desc);
93     return success();
94   }
95 };
96 
97 
98 // YieldOp produces and LLVM::ReturnOp.
99 class YieldOpConversion : public ConvertOpToLLVMPattern<linalg::YieldOp> {
100 public:
101   using ConvertOpToLLVMPattern<linalg::YieldOp>::ConvertOpToLLVMPattern;
102 
103   LogicalResult
104   matchAndRewrite(linalg::YieldOp op, ArrayRef<Value> operands,
105                   ConversionPatternRewriter &rewriter) const override {
106     rewriter.replaceOpWithNewOp<LLVM::ReturnOp>(op, operands);
107     return success();
108   }
109 };
110 } // namespace
111 
112 /// Populate the given list with patterns that convert from Linalg to LLVM.
113 void mlir::populateLinalgToLLVMConversionPatterns(LLVMTypeConverter &converter,
114                                                   RewritePatternSet &patterns) {
115   patterns.add<RangeOpConversion, YieldOpConversion>(converter);
116 
117   // Populate the type conversions for the linalg types.
118   converter.addConversion(
119       [&](RangeType type) { return convertRangeType(type, converter); });
120 }
121 
122 namespace {
123 struct ConvertLinalgToLLVMPass
124     : public ConvertLinalgToLLVMBase<ConvertLinalgToLLVMPass> {
125   void runOnOperation() override;
126 };
127 } // namespace
128 
129 void ConvertLinalgToLLVMPass::runOnOperation() {
130   auto module = getOperation();
131 
132   // Convert to the LLVM IR dialect using the converter defined above.
133   RewritePatternSet patterns(&getContext());
134   LLVMTypeConverter converter(&getContext());
135   populateLinalgToLLVMConversionPatterns(converter, patterns);
136   populateMemRefToLLVMConversionPatterns(converter, patterns);
137 
138   LLVMConversionTarget target(getContext());
139   target.addIllegalOp<RangeOp>();
140   target.addLegalOp<ModuleOp>();
141   if (failed(applyPartialConversion(module, target, std::move(patterns))))
142     signalPassFailure();
143 }
144 
145 std::unique_ptr<OperationPass<ModuleOp>> mlir::createConvertLinalgToLLVMPass() {
146   return std::make_unique<ConvertLinalgToLLVMPass>();
147 }
148