1 //===- LinalgToLLVM.cpp - conversion from Linalg to LLVM dialect ----------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 
9 #include "mlir/Conversion/LinalgToLLVM/LinalgToLLVM.h"
10 
11 #include "../PassDetail.h"
12 #include "mlir/Conversion/AffineToStandard/AffineToStandard.h"
13 #include "mlir/Conversion/SCFToStandard/SCFToStandard.h"
14 #include "mlir/Conversion/StandardToLLVM/ConvertStandardToLLVM.h"
15 #include "mlir/Conversion/StandardToLLVM/ConvertStandardToLLVMPass.h"
16 #include "mlir/Conversion/VectorToLLVM/ConvertVectorToLLVM.h"
17 #include "mlir/Conversion/VectorToSCF/VectorToSCF.h"
18 #include "mlir/Dialect/LLVMIR/LLVMDialect.h"
19 #include "mlir/Dialect/Linalg/IR/LinalgOps.h"
20 #include "mlir/Dialect/Linalg/IR/LinalgTypes.h"
21 #include "mlir/Dialect/Linalg/Passes.h"
22 #include "mlir/Dialect/SCF/SCF.h"
23 #include "mlir/IR/AffineExpr.h"
24 #include "mlir/IR/AffineMap.h"
25 #include "mlir/IR/Attributes.h"
26 #include "mlir/IR/Builders.h"
27 #include "mlir/IR/BuiltinOps.h"
28 #include "mlir/IR/BuiltinTypes.h"
29 #include "mlir/IR/MLIRContext.h"
30 #include "mlir/IR/Operation.h"
31 #include "mlir/IR/PatternMatch.h"
32 #include "mlir/IR/Types.h"
33 #include "mlir/Support/LogicalResult.h"
34 #include "mlir/Transforms/DialectConversion.h"
35 #include "mlir/Transforms/Passes.h"
36 #include "llvm/ADT/SetVector.h"
37 #include "llvm/IR/DerivedTypes.h"
38 #include "llvm/IR/Module.h"
39 #include "llvm/IR/Type.h"
40 #include "llvm/Support/Allocator.h"
41 #include "llvm/Support/ErrorHandling.h"
42 
43 using namespace mlir;
44 using namespace mlir::LLVM;
45 using namespace mlir::linalg;
46 
47 template <typename T>
48 static Type getPtrToElementType(T containerType, LLVMTypeConverter &lowering) {
49   return LLVMPointerType::get(
50       lowering.convertType(containerType.getElementType()));
51 }
52 
53 /// Convert the given range descriptor type to the LLVMIR dialect.
54 /// Range descriptor contains the range bounds and the step as 64-bit integers.
55 ///
56 /// struct {
57 ///   int64_t min;
58 ///   int64_t max;
59 ///   int64_t step;
60 /// };
61 static Type convertRangeType(RangeType t, LLVMTypeConverter &converter) {
62   auto *context = t.getContext();
63   auto int64Ty = converter.convertType(IntegerType::get(context, 64));
64   return LLVMStructType::getLiteral(context, {int64Ty, int64Ty, int64Ty});
65 }
66 
67 namespace {
68 // RangeOp creates a new range descriptor.
69 class RangeOpConversion : public ConvertOpToLLVMPattern<RangeOp> {
70 public:
71   using ConvertOpToLLVMPattern<RangeOp>::ConvertOpToLLVMPattern;
72 
73   LogicalResult
74   matchAndRewrite(RangeOp rangeOp, ArrayRef<Value> operands,
75                   ConversionPatternRewriter &rewriter) const override {
76     auto rangeDescriptorTy = convertRangeType(
77         rangeOp.getType().cast<RangeType>(), *getTypeConverter());
78 
79     ImplicitLocOpBuilder b(rangeOp->getLoc(), rewriter);
80 
81     // Fill in an aggregate value of the descriptor.
82     RangeOpAdaptor adaptor(operands);
83     Value desc = b.create<LLVM::UndefOp>(rangeDescriptorTy);
84     desc = b.create<LLVM::InsertValueOp>(desc, adaptor.min(),
85                                          rewriter.getI64ArrayAttr(0));
86     desc = b.create<LLVM::InsertValueOp>(desc, adaptor.max(),
87                                          rewriter.getI64ArrayAttr(1));
88     desc = b.create<LLVM::InsertValueOp>(desc, adaptor.step(),
89                                          rewriter.getI64ArrayAttr(2));
90     rewriter.replaceOp(rangeOp, desc);
91     return success();
92   }
93 };
94 
95 // ReshapeOp creates a new view descriptor of the proper rank.
96 // For now, the only conversion supported is for target MemRef with static sizes
97 // and strides.
98 template <typename ReshapeOp>
99 class ReshapeOpConversion : public ConvertOpToLLVMPattern<ReshapeOp> {
100 public:
101   using ConvertOpToLLVMPattern<ReshapeOp>::ConvertOpToLLVMPattern;
102   using ReshapeOpAdaptor = typename ReshapeOp::Adaptor;
103 
104   LogicalResult
105   matchAndRewrite(ReshapeOp reshapeOp, ArrayRef<Value> operands,
106                   ConversionPatternRewriter &rewriter) const override {
107     MemRefType dstType = reshapeOp.getResultType();
108 
109     if (!dstType.hasStaticShape())
110       return failure();
111 
112     int64_t offset;
113     SmallVector<int64_t, 4> strides;
114     auto res = getStridesAndOffset(dstType, strides, offset);
115     if (failed(res) || llvm::any_of(strides, [](int64_t val) {
116           return ShapedType::isDynamicStrideOrOffset(val);
117         }))
118       return failure();
119 
120     ReshapeOpAdaptor adaptor(operands);
121     MemRefDescriptor baseDesc(adaptor.src());
122     Location loc = reshapeOp->getLoc();
123     auto desc =
124         MemRefDescriptor::undef(rewriter, reshapeOp->getLoc(),
125                                 this->typeConverter->convertType(dstType));
126     desc.setAllocatedPtr(rewriter, loc, baseDesc.allocatedPtr(rewriter, loc));
127     desc.setAlignedPtr(rewriter, loc, baseDesc.alignedPtr(rewriter, loc));
128     desc.setOffset(rewriter, loc, baseDesc.offset(rewriter, loc));
129     for (auto en : llvm::enumerate(dstType.getShape()))
130       desc.setConstantSize(rewriter, loc, en.index(), en.value());
131     for (auto en : llvm::enumerate(strides))
132       desc.setConstantStride(rewriter, loc, en.index(), en.value());
133     rewriter.replaceOp(reshapeOp, {desc});
134     return success();
135   }
136 };
137 
138 // YieldOp produces and LLVM::ReturnOp.
139 class YieldOpConversion : public ConvertOpToLLVMPattern<linalg::YieldOp> {
140 public:
141   using ConvertOpToLLVMPattern<linalg::YieldOp>::ConvertOpToLLVMPattern;
142 
143   LogicalResult
144   matchAndRewrite(linalg::YieldOp op, ArrayRef<Value> operands,
145                   ConversionPatternRewriter &rewriter) const override {
146     rewriter.replaceOpWithNewOp<LLVM::ReturnOp>(op, operands);
147     return success();
148   }
149 };
150 } // namespace
151 
152 /// Populate the given list with patterns that convert from Linalg to LLVM.
153 void mlir::populateLinalgToLLVMConversionPatterns(LLVMTypeConverter &converter,
154                                                   RewritePatternSet &patterns) {
155   patterns.add<RangeOpConversion, ReshapeOpConversion<ExpandShapeOp>,
156                ReshapeOpConversion<CollapseShapeOp>, YieldOpConversion>(
157       converter);
158 
159   // Populate the type conversions for the linalg types.
160   converter.addConversion(
161       [&](RangeType type) { return convertRangeType(type, converter); });
162 }
163 
164 namespace {
165 struct ConvertLinalgToLLVMPass
166     : public ConvertLinalgToLLVMBase<ConvertLinalgToLLVMPass> {
167   void runOnOperation() override;
168 };
169 } // namespace
170 
171 void ConvertLinalgToLLVMPass::runOnOperation() {
172   auto module = getOperation();
173 
174   // Convert to the LLVM IR dialect using the converter defined above.
175   RewritePatternSet patterns(&getContext());
176   LLVMTypeConverter converter(&getContext());
177   populateLinalgToLLVMConversionPatterns(converter, patterns);
178 
179   LLVMConversionTarget target(getContext());
180   target.addIllegalOp<RangeOp>();
181   target.addLegalOp<ModuleOp, LLVM::DialectCastOp>();
182   if (failed(applyPartialConversion(module, target, std::move(patterns))))
183     signalPassFailure();
184 }
185 
186 std::unique_ptr<OperationPass<ModuleOp>> mlir::createConvertLinalgToLLVMPass() {
187   return std::make_unique<ConvertLinalgToLLVMPass>();
188 }
189