1 //===- LinalgToLLVM.cpp - conversion from Linalg to LLVM dialect ----------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 
9 #include "mlir/Conversion/LinalgToLLVM/LinalgToLLVM.h"
10 
11 #include "../PassDetail.h"
12 #include "mlir/Conversion/AffineToStandard/AffineToStandard.h"
13 #include "mlir/Conversion/SCFToStandard/SCFToStandard.h"
14 #include "mlir/Conversion/StandardToLLVM/ConvertStandardToLLVM.h"
15 #include "mlir/Conversion/StandardToLLVM/ConvertStandardToLLVMPass.h"
16 #include "mlir/Conversion/VectorToLLVM/ConvertVectorToLLVM.h"
17 #include "mlir/Conversion/VectorToSCF/VectorToSCF.h"
18 #include "mlir/Dialect/LLVMIR/LLVMDialect.h"
19 #include "mlir/Dialect/Linalg/IR/LinalgOps.h"
20 #include "mlir/Dialect/Linalg/IR/LinalgTypes.h"
21 #include "mlir/Dialect/Linalg/Passes.h"
22 #include "mlir/Dialect/SCF/SCF.h"
23 #include "mlir/Dialect/StandardOps/EDSC/Intrinsics.h"
24 #include "mlir/IR/AffineExpr.h"
25 #include "mlir/IR/AffineMap.h"
26 #include "mlir/IR/Attributes.h"
27 #include "mlir/IR/Builders.h"
28 #include "mlir/IR/BuiltinOps.h"
29 #include "mlir/IR/BuiltinTypes.h"
30 #include "mlir/IR/MLIRContext.h"
31 #include "mlir/IR/Operation.h"
32 #include "mlir/IR/PatternMatch.h"
33 #include "mlir/IR/Types.h"
34 #include "mlir/Support/LogicalResult.h"
35 #include "mlir/Transforms/DialectConversion.h"
36 #include "mlir/Transforms/Passes.h"
37 #include "llvm/ADT/SetVector.h"
38 #include "llvm/IR/DerivedTypes.h"
39 #include "llvm/IR/Module.h"
40 #include "llvm/IR/Type.h"
41 #include "llvm/Support/Allocator.h"
42 #include "llvm/Support/ErrorHandling.h"
43 
44 using namespace mlir;
45 using namespace mlir::edsc;
46 using namespace mlir::edsc::intrinsics;
47 using namespace mlir::LLVM;
48 using namespace mlir::linalg;
49 
50 using llvm_add = ValueBuilder<LLVM::AddOp>;
51 using llvm_bitcast = ValueBuilder<LLVM::BitcastOp>;
52 using llvm_constant = ValueBuilder<LLVM::ConstantOp>;
53 using llvm_extractvalue = ValueBuilder<LLVM::ExtractValueOp>;
54 using llvm_gep = ValueBuilder<LLVM::GEPOp>;
55 using llvm_insertvalue = ValueBuilder<LLVM::InsertValueOp>;
56 using llvm_call = OperationBuilder<LLVM::CallOp>;
57 using llvm_icmp = ValueBuilder<LLVM::ICmpOp>;
58 using llvm_load = ValueBuilder<LLVM::LoadOp>;
59 using llvm_store = OperationBuilder<LLVM::StoreOp>;
60 using llvm_select = ValueBuilder<LLVM::SelectOp>;
61 using llvm_mul = ValueBuilder<LLVM::MulOp>;
62 using llvm_ptrtoint = ValueBuilder<LLVM::PtrToIntOp>;
63 using llvm_sub = ValueBuilder<LLVM::SubOp>;
64 using llvm_undef = ValueBuilder<LLVM::UndefOp>;
65 using llvm_urem = ValueBuilder<LLVM::URemOp>;
66 using llvm_alloca = ValueBuilder<LLVM::AllocaOp>;
67 using llvm_return = OperationBuilder<LLVM::ReturnOp>;
68 
69 template <typename T>
70 static Type getPtrToElementType(T containerType, LLVMTypeConverter &lowering) {
71   return LLVMPointerType::get(
72       lowering.convertType(containerType.getElementType()));
73 }
74 
75 /// Convert the given range descriptor type to the LLVMIR dialect.
76 /// Range descriptor contains the range bounds and the step as 64-bit integers.
77 ///
78 /// struct {
79 ///   int64_t min;
80 ///   int64_t max;
81 ///   int64_t step;
82 /// };
83 static Type convertRangeType(RangeType t, LLVMTypeConverter &converter) {
84   auto *context = t.getContext();
85   auto int64Ty = converter.convertType(IntegerType::get(context, 64));
86   return LLVMStructType::getLiteral(context, {int64Ty, int64Ty, int64Ty});
87 }
88 
89 namespace {
90 /// EDSC-compatible wrapper for MemRefDescriptor.
91 class BaseViewConversionHelper {
92 public:
93   BaseViewConversionHelper(Type type)
94       : d(MemRefDescriptor::undef(rewriter(), loc(), type)) {}
95 
96   BaseViewConversionHelper(Value v) : d(v) {}
97 
98   /// Wrappers around MemRefDescriptor that use EDSC builder and location.
99   Value allocatedPtr() { return d.allocatedPtr(rewriter(), loc()); }
100   void setAllocatedPtr(Value v) { d.setAllocatedPtr(rewriter(), loc(), v); }
101   Value alignedPtr() { return d.alignedPtr(rewriter(), loc()); }
102   void setAlignedPtr(Value v) { d.setAlignedPtr(rewriter(), loc(), v); }
103   Value offset() { return d.offset(rewriter(), loc()); }
104   void setOffset(Value v) { d.setOffset(rewriter(), loc(), v); }
105   Value size(unsigned i) { return d.size(rewriter(), loc(), i); }
106   void setSize(unsigned i, Value v) { d.setSize(rewriter(), loc(), i, v); }
107   void setConstantSize(unsigned i, int64_t v) {
108     d.setConstantSize(rewriter(), loc(), i, v);
109   }
110   Value stride(unsigned i) { return d.stride(rewriter(), loc(), i); }
111   void setStride(unsigned i, Value v) { d.setStride(rewriter(), loc(), i, v); }
112   void setConstantStride(unsigned i, int64_t v) {
113     d.setConstantStride(rewriter(), loc(), i, v);
114   }
115 
116   operator Value() { return d; }
117 
118 private:
119   OpBuilder &rewriter() { return ScopedContext::getBuilderRef(); }
120   Location loc() { return ScopedContext::getLocation(); }
121 
122   MemRefDescriptor d;
123 };
124 
125 // RangeOp creates a new range descriptor.
126 class RangeOpConversion : public ConvertOpToLLVMPattern<RangeOp> {
127 public:
128   using ConvertOpToLLVMPattern<RangeOp>::ConvertOpToLLVMPattern;
129 
130   LogicalResult
131   matchAndRewrite(RangeOp rangeOp, ArrayRef<Value> operands,
132                   ConversionPatternRewriter &rewriter) const override {
133     auto rangeDescriptorTy = convertRangeType(
134         rangeOp.getType().cast<RangeType>(), *getTypeConverter());
135 
136     edsc::ScopedContext context(rewriter, rangeOp->getLoc());
137 
138     // Fill in an aggregate value of the descriptor.
139     RangeOpAdaptor adaptor(operands);
140     Value desc = llvm_undef(rangeDescriptorTy);
141     desc = llvm_insertvalue(desc, adaptor.min(), rewriter.getI64ArrayAttr(0));
142     desc = llvm_insertvalue(desc, adaptor.max(), rewriter.getI64ArrayAttr(1));
143     desc = llvm_insertvalue(desc, adaptor.step(), rewriter.getI64ArrayAttr(2));
144     rewriter.replaceOp(rangeOp, desc);
145     return success();
146   }
147 };
148 
149 // ReshapeOp creates a new view descriptor of the proper rank.
150 // For now, the only conversion supported is for target MemRef with static sizes
151 // and strides.
152 class ReshapeOpConversion : public ConvertOpToLLVMPattern<ReshapeOp> {
153 public:
154   using ConvertOpToLLVMPattern<ReshapeOp>::ConvertOpToLLVMPattern;
155 
156   LogicalResult
157   matchAndRewrite(ReshapeOp reshapeOp, ArrayRef<Value> operands,
158                   ConversionPatternRewriter &rewriter) const override {
159     MemRefType dstType = reshapeOp.getResultType();
160 
161     if (!dstType.hasStaticShape())
162       return failure();
163 
164     int64_t offset;
165     SmallVector<int64_t, 4> strides;
166     auto res = getStridesAndOffset(dstType, strides, offset);
167     if (failed(res) || llvm::any_of(strides, [](int64_t val) {
168           return ShapedType::isDynamicStrideOrOffset(val);
169         }))
170       return failure();
171 
172     edsc::ScopedContext context(rewriter, reshapeOp->getLoc());
173     ReshapeOpAdaptor adaptor(operands);
174     BaseViewConversionHelper baseDesc(adaptor.src());
175     BaseViewConversionHelper desc(typeConverter->convertType(dstType));
176     desc.setAllocatedPtr(baseDesc.allocatedPtr());
177     desc.setAlignedPtr(baseDesc.alignedPtr());
178     desc.setOffset(baseDesc.offset());
179     for (auto en : llvm::enumerate(dstType.getShape()))
180       desc.setConstantSize(en.index(), en.value());
181     for (auto en : llvm::enumerate(strides))
182       desc.setConstantStride(en.index(), en.value());
183     rewriter.replaceOp(reshapeOp, {desc});
184     return success();
185   }
186 };
187 
188 // YieldOp produces and LLVM::ReturnOp.
189 class YieldOpConversion : public ConvertOpToLLVMPattern<linalg::YieldOp> {
190 public:
191   using ConvertOpToLLVMPattern<linalg::YieldOp>::ConvertOpToLLVMPattern;
192 
193   LogicalResult
194   matchAndRewrite(linalg::YieldOp op, ArrayRef<Value> operands,
195                   ConversionPatternRewriter &rewriter) const override {
196     rewriter.replaceOpWithNewOp<LLVM::ReturnOp>(op, operands);
197     return success();
198   }
199 };
200 } // namespace
201 
202 /// Populate the given list with patterns that convert from Linalg to LLVM.
203 void mlir::populateLinalgToLLVMConversionPatterns(LLVMTypeConverter &converter,
204                                                   RewritePatternSet &patterns) {
205   patterns.add<RangeOpConversion, ReshapeOpConversion, YieldOpConversion>(
206       converter);
207 
208   // Populate the type conversions for the linalg types.
209   converter.addConversion(
210       [&](RangeType type) { return convertRangeType(type, converter); });
211 }
212 
213 namespace {
214 struct ConvertLinalgToLLVMPass
215     : public ConvertLinalgToLLVMBase<ConvertLinalgToLLVMPass> {
216   void runOnOperation() override;
217 };
218 } // namespace
219 
220 void ConvertLinalgToLLVMPass::runOnOperation() {
221   auto module = getOperation();
222 
223   // Convert to the LLVM IR dialect using the converter defined above.
224   RewritePatternSet patterns(&getContext());
225   LLVMTypeConverter converter(&getContext());
226   populateLinalgToLLVMConversionPatterns(converter, patterns);
227 
228   LLVMConversionTarget target(getContext());
229   target.addIllegalOp<RangeOp>();
230   target.addLegalOp<ModuleOp, LLVM::DialectCastOp>();
231   if (failed(applyPartialConversion(module, target, std::move(patterns))))
232     signalPassFailure();
233 }
234 
235 std::unique_ptr<OperationPass<ModuleOp>> mlir::createConvertLinalgToLLVMPass() {
236   return std::make_unique<ConvertLinalgToLLVMPass>();
237 }
238