1 //===- LinalgToLLVM.cpp - conversion from Linalg to LLVM dialect ----------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 
9 #include "mlir/Conversion/LinalgToLLVM/LinalgToLLVM.h"
10 
11 #include "../PassDetail.h"
12 #include "mlir/Conversion/AffineToStandard/AffineToStandard.h"
13 #include "mlir/Conversion/SCFToStandard/SCFToStandard.h"
14 #include "mlir/Conversion/StandardToLLVM/ConvertStandardToLLVM.h"
15 #include "mlir/Conversion/StandardToLLVM/ConvertStandardToLLVMPass.h"
16 #include "mlir/Conversion/VectorToLLVM/ConvertVectorToLLVM.h"
17 #include "mlir/Conversion/VectorToSCF/VectorToSCF.h"
18 #include "mlir/Dialect/LLVMIR/LLVMDialect.h"
19 #include "mlir/Dialect/Linalg/IR/LinalgOps.h"
20 #include "mlir/Dialect/Linalg/IR/LinalgTypes.h"
21 #include "mlir/Dialect/Linalg/Passes.h"
22 #include "mlir/Dialect/SCF/SCF.h"
23 #include "mlir/Dialect/StandardOps/EDSC/Intrinsics.h"
24 #include "mlir/IR/AffineExpr.h"
25 #include "mlir/IR/AffineMap.h"
26 #include "mlir/IR/Attributes.h"
27 #include "mlir/IR/Builders.h"
28 #include "mlir/IR/BuiltinOps.h"
29 #include "mlir/IR/BuiltinTypes.h"
30 #include "mlir/IR/MLIRContext.h"
31 #include "mlir/IR/Operation.h"
32 #include "mlir/IR/PatternMatch.h"
33 #include "mlir/IR/Types.h"
34 #include "mlir/Support/LogicalResult.h"
35 #include "mlir/Transforms/DialectConversion.h"
36 #include "mlir/Transforms/Passes.h"
37 #include "llvm/ADT/SetVector.h"
38 #include "llvm/IR/DerivedTypes.h"
39 #include "llvm/IR/Module.h"
40 #include "llvm/IR/Type.h"
41 #include "llvm/Support/Allocator.h"
42 #include "llvm/Support/ErrorHandling.h"
43 
44 using namespace mlir;
45 using namespace mlir::edsc;
46 using namespace mlir::edsc::intrinsics;
47 using namespace mlir::LLVM;
48 using namespace mlir::linalg;
49 
50 using llvm_add = ValueBuilder<LLVM::AddOp>;
51 using llvm_bitcast = ValueBuilder<LLVM::BitcastOp>;
52 using llvm_constant = ValueBuilder<LLVM::ConstantOp>;
53 using llvm_extractvalue = ValueBuilder<LLVM::ExtractValueOp>;
54 using llvm_gep = ValueBuilder<LLVM::GEPOp>;
55 using llvm_insertvalue = ValueBuilder<LLVM::InsertValueOp>;
56 using llvm_call = OperationBuilder<LLVM::CallOp>;
57 using llvm_icmp = ValueBuilder<LLVM::ICmpOp>;
58 using llvm_load = ValueBuilder<LLVM::LoadOp>;
59 using llvm_store = OperationBuilder<LLVM::StoreOp>;
60 using llvm_select = ValueBuilder<LLVM::SelectOp>;
61 using llvm_mul = ValueBuilder<LLVM::MulOp>;
62 using llvm_ptrtoint = ValueBuilder<LLVM::PtrToIntOp>;
63 using llvm_sub = ValueBuilder<LLVM::SubOp>;
64 using llvm_undef = ValueBuilder<LLVM::UndefOp>;
65 using llvm_urem = ValueBuilder<LLVM::URemOp>;
66 using llvm_alloca = ValueBuilder<LLVM::AllocaOp>;
67 using llvm_return = OperationBuilder<LLVM::ReturnOp>;
68 
69 template <typename T>
70 static LLVMType getPtrToElementType(T containerType,
71                                     LLVMTypeConverter &lowering) {
72   return lowering.convertType(containerType.getElementType())
73       .template cast<LLVMType>()
74       .getPointerTo();
75 }
76 
77 /// Convert the given range descriptor type to the LLVMIR dialect.
78 /// Range descriptor contains the range bounds and the step as 64-bit integers.
79 ///
80 /// struct {
81 ///   int64_t min;
82 ///   int64_t max;
83 ///   int64_t step;
84 /// };
85 static Type convertRangeType(RangeType t, LLVMTypeConverter &converter) {
86   auto *context = t.getContext();
87   auto int64Ty = converter.convertType(IntegerType::get(context, 64))
88                      .cast<LLVM::LLVMType>();
89   return LLVMStructType::getLiteral(context, {int64Ty, int64Ty, int64Ty});
90 }
91 
92 namespace {
/// EDSC-compatible wrapper for MemRefDescriptor.
/// Forwards every accessor/mutator to MemRefDescriptor, supplying the builder
/// and location implicitly from the ambient edsc::ScopedContext so conversion
/// patterns can manipulate descriptors without threading rewriter/location
/// through every call.
class BaseViewConversionHelper {
public:
  /// Wraps a fresh `llvm.mlir.undef` descriptor of `type` (an
  /// already-converted LLVM struct type).
  BaseViewConversionHelper(Type type)
      : d(MemRefDescriptor::undef(rewriter(), loc(), type)) {}

  /// Wraps an existing descriptor value `v`.
  BaseViewConversionHelper(Value v) : d(v) {}

  /// Wrappers around MemRefDescriptor that use EDSC builder and location.
  Value allocatedPtr() { return d.allocatedPtr(rewriter(), loc()); }
  void setAllocatedPtr(Value v) { d.setAllocatedPtr(rewriter(), loc(), v); }
  Value alignedPtr() { return d.alignedPtr(rewriter(), loc()); }
  void setAlignedPtr(Value v) { d.setAlignedPtr(rewriter(), loc(), v); }
  Value offset() { return d.offset(rewriter(), loc()); }
  void setOffset(Value v) { d.setOffset(rewriter(), loc(), v); }
  Value size(unsigned i) { return d.size(rewriter(), loc(), i); }
  void setSize(unsigned i, Value v) { d.setSize(rewriter(), loc(), i, v); }
  void setConstantSize(unsigned i, int64_t v) {
    d.setConstantSize(rewriter(), loc(), i, v);
  }
  Value stride(unsigned i) { return d.stride(rewriter(), loc(), i); }
  void setStride(unsigned i, Value v) { d.setStride(rewriter(), loc(), i, v); }
  void setConstantStride(unsigned i, int64_t v) {
    d.setConstantStride(rewriter(), loc(), i, v);
  }

  /// Implicit conversion back to the underlying descriptor value.
  operator Value() { return d; }

private:
  // Builder and location come from the enclosing EDSC scope: callers must
  // have a live edsc::ScopedContext when invoking any method of this class.
  OpBuilder &rewriter() { return ScopedContext::getBuilderRef(); }
  Location loc() { return ScopedContext::getLocation(); }

  MemRefDescriptor d;
};
127 
128 // RangeOp creates a new range descriptor.
129 class RangeOpConversion : public ConvertOpToLLVMPattern<RangeOp> {
130 public:
131   using ConvertOpToLLVMPattern<RangeOp>::ConvertOpToLLVMPattern;
132 
133   LogicalResult
134   matchAndRewrite(RangeOp rangeOp, ArrayRef<Value> operands,
135                   ConversionPatternRewriter &rewriter) const override {
136     auto rangeDescriptorTy = convertRangeType(
137         rangeOp.getType().cast<RangeType>(), *getTypeConverter());
138 
139     edsc::ScopedContext context(rewriter, rangeOp->getLoc());
140 
141     // Fill in an aggregate value of the descriptor.
142     RangeOpAdaptor adaptor(operands);
143     Value desc = llvm_undef(rangeDescriptorTy);
144     desc = llvm_insertvalue(desc, adaptor.min(), rewriter.getI64ArrayAttr(0));
145     desc = llvm_insertvalue(desc, adaptor.max(), rewriter.getI64ArrayAttr(1));
146     desc = llvm_insertvalue(desc, adaptor.step(), rewriter.getI64ArrayAttr(2));
147     rewriter.replaceOp(rangeOp, desc);
148     return success();
149   }
150 };
151 
152 // ReshapeOp creates a new view descriptor of the proper rank.
153 // For now, the only conversion supported is for target MemRef with static sizes
154 // and strides.
155 class ReshapeOpConversion : public ConvertOpToLLVMPattern<ReshapeOp> {
156 public:
157   using ConvertOpToLLVMPattern<ReshapeOp>::ConvertOpToLLVMPattern;
158 
159   LogicalResult
160   matchAndRewrite(ReshapeOp reshapeOp, ArrayRef<Value> operands,
161                   ConversionPatternRewriter &rewriter) const override {
162     MemRefType dstType = reshapeOp.getResultType();
163 
164     if (!dstType.hasStaticShape())
165       return failure();
166 
167     int64_t offset;
168     SmallVector<int64_t, 4> strides;
169     auto res = getStridesAndOffset(dstType, strides, offset);
170     if (failed(res) || llvm::any_of(strides, [](int64_t val) {
171           return ShapedType::isDynamicStrideOrOffset(val);
172         }))
173       return failure();
174 
175     edsc::ScopedContext context(rewriter, reshapeOp->getLoc());
176     ReshapeOpAdaptor adaptor(operands);
177     BaseViewConversionHelper baseDesc(adaptor.src());
178     BaseViewConversionHelper desc(typeConverter->convertType(dstType));
179     desc.setAllocatedPtr(baseDesc.allocatedPtr());
180     desc.setAlignedPtr(baseDesc.alignedPtr());
181     desc.setOffset(baseDesc.offset());
182     for (auto en : llvm::enumerate(dstType.getShape()))
183       desc.setConstantSize(en.index(), en.value());
184     for (auto en : llvm::enumerate(strides))
185       desc.setConstantStride(en.index(), en.value());
186     rewriter.replaceOp(reshapeOp, {desc});
187     return success();
188   }
189 };
190 
/// Conversion pattern that transforms a linalg.slice op into:
///   1. An "undef" value for the ViewDescriptor.
///   2. Updates to the ViewDescriptor to introduce the data ptr, offset, size
///      and stride corresponding to the region of memory within the bounds of
///      the parent view.
/// The linalg.slice op is replaced by the updated view descriptor.
class SliceOpConversion : public ConvertOpToLLVMPattern<SliceOp> {
public:
  using ConvertOpToLLVMPattern<SliceOp>::ConvertOpToLLVMPattern;

  LogicalResult
  matchAndRewrite(SliceOp sliceOp, ArrayRef<Value> operands,
                  ConversionPatternRewriter &rewriter) const override {
    edsc::ScopedContext context(rewriter, sliceOp->getLoc());
    SliceOpAdaptor adaptor(operands);
    // Descriptor of the base view being sliced.
    BaseViewConversionHelper baseDesc(adaptor.view());

    auto memRefType = sliceOp.getBaseViewType();
    auto int64Ty = typeConverter->convertType(rewriter.getIntegerType(64))
                       .cast<LLVM::LLVMType>();

    // Descriptor of the result view, filled in field by field below.
    BaseViewConversionHelper desc(
        typeConverter->convertType(sliceOp.getShapedType()));

    // TODO: extract sizes and emit asserts.
    // Cache the base view's strides; reused for both the offset computation
    // and the per-dimension stride computation.
    SmallVector<Value, 4> strides(memRefType.getRank());
    for (int i = 0, e = memRefType.getRank(); i < e; ++i)
      strides[i] = baseDesc.stride(i);

    // Helper building the position attribute for llvm.extractvalue.
    auto pos = [&rewriter](ArrayRef<int64_t> values) {
      return rewriter.getI64ArrayAttr(values);
    };

    // Compute base offset: accumulate min_i * stride_i over all dimensions.
    // A range indexing contributes its `min` field (struct position 0); a
    // scalar indexing is used directly.
    Value baseOffset = baseDesc.offset();
    for (int i = 0, e = memRefType.getRank(); i < e; ++i) {
      Value indexing = adaptor.indexings()[i];
      Value min = indexing;
      if (sliceOp.indexing(i).getType().isa<RangeType>())
        min = llvm_extractvalue(int64Ty, indexing, pos(0));
      baseOffset = llvm_add(baseOffset, llvm_mul(min, strides[i]));
    }

    // Insert the base and aligned pointers.
    desc.setAllocatedPtr(baseDesc.allocatedPtr());
    desc.setAlignedPtr(baseDesc.alignedPtr());

    // Insert base offset.
    desc.setOffset(baseOffset);

    // Corner case, no sizes or strides: early return the descriptor.
    if (sliceOp.getShapedType().getRank() == 0)
      return rewriter.replaceOp(sliceOp, {desc}), success();

    Value zero = llvm_constant(
        int64Ty, rewriter.getIntegerAttr(rewriter.getIndexType(), 0));
    // Compute and insert view sizes (max - min along the range) and strides.
    // Skip the non-range operands as they will be projected away from the view.
    int numNewDims = 0;
    for (auto en : llvm::enumerate(sliceOp.indexings())) {
      Value indexing = en.value();
      if (indexing.getType().isa<RangeType>()) {
        int rank = en.index();
        // Unpack {min, max, step} from the range descriptor struct.
        Value rangeDescriptor = adaptor.indexings()[rank];
        Value min = llvm_extractvalue(int64Ty, rangeDescriptor, pos(0));
        Value max = llvm_extractvalue(int64Ty, rangeDescriptor, pos(1));
        Value step = llvm_extractvalue(int64Ty, rangeDescriptor, pos(2));
        Value baseSize = baseDesc.size(rank);

        // Bound upper by base view upper bound.
        max = llvm_select(llvm_icmp(ICmpPredicate::slt, max, baseSize), max,
                          baseSize);
        Value size = llvm_sub(max, min);
        // Bound lower by zero.
        size =
            llvm_select(llvm_icmp(ICmpPredicate::slt, size, zero), zero, size);
        // The result stride is the base stride scaled by the range step.
        Value stride = llvm_mul(strides[rank], step);
        desc.setSize(numNewDims, size);
        desc.setStride(numNewDims, stride);
        ++numNewDims;
      }
    }

    rewriter.replaceOp(sliceOp, {desc});
    return success();
  }
};
278 
// YieldOp produces an LLVM::ReturnOp.
class YieldOpConversion : public ConvertOpToLLVMPattern<linalg::YieldOp> {
public:
  using ConvertOpToLLVMPattern<linalg::YieldOp>::ConvertOpToLLVMPattern;

  LogicalResult
  matchAndRewrite(linalg::YieldOp op, ArrayRef<Value> operands,
                  ConversionPatternRewriter &rewriter) const override {
    // The already-converted yielded values become the return's operands.
    rewriter.replaceOpWithNewOp<LLVM::ReturnOp>(op, operands);
    return success();
  }
};
291 } // namespace
292 
293 /// Populate the given list with patterns that convert from Linalg to LLVM.
294 void mlir::populateLinalgToLLVMConversionPatterns(
295     LLVMTypeConverter &converter, OwningRewritePatternList &patterns) {
296   patterns.insert<RangeOpConversion, ReshapeOpConversion, SliceOpConversion,
297                   YieldOpConversion>(converter);
298 
299   // Populate the type conversions for the linalg types.
300   converter.addConversion(
301       [&](RangeType type) { return convertRangeType(type, converter); });
302 }
303 
namespace {
/// Module pass lowering Linalg to the LLVM dialect. Also pulls in the
/// affine/scf/std/vector lowering patterns so a full conversion succeeds
/// (see runOnOperation below).
struct ConvertLinalgToLLVMPass
    : public ConvertLinalgToLLVMBase<ConvertLinalgToLLVMPass> {
  void runOnOperation() override;
};
} // namespace
310 
311 void ConvertLinalgToLLVMPass::runOnOperation() {
312   auto module = getOperation();
313 
314   // Convert to the LLVM IR dialect using the converter defined above.
315   OwningRewritePatternList patterns;
316   LLVMTypeConverter converter(&getContext());
317   populateAffineToStdConversionPatterns(patterns, &getContext());
318   populateLoopToStdConversionPatterns(patterns, &getContext());
319   populateStdToLLVMConversionPatterns(converter, patterns);
320   populateVectorToSCFConversionPatterns(patterns, &getContext());
321   populateVectorToLLVMMatrixConversionPatterns(converter, patterns);
322   populateVectorToLLVMConversionPatterns(converter, patterns);
323   populateLinalgToLLVMConversionPatterns(converter, patterns);
324 
325   LLVMConversionTarget target(getContext());
326   target.addLegalOp<ModuleOp, ModuleTerminatorOp>();
327   if (failed(applyFullConversion(module, target, std::move(patterns))))
328     signalPassFailure();
329 }
330 
331 std::unique_ptr<OperationPass<ModuleOp>> mlir::createConvertLinalgToLLVMPass() {
332   return std::make_unique<ConvertLinalgToLLVMPass>();
333 }
334