1 //===- LinalgToLLVM.cpp - conversion from Linalg to LLVM dialect ----------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 
9 #include "mlir/Conversion/LinalgToLLVM/LinalgToLLVM.h"
10 
11 #include "../PassDetail.h"
12 #include "mlir/Conversion/AffineToStandard/AffineToStandard.h"
13 #include "mlir/Conversion/SCFToStandard/SCFToStandard.h"
14 #include "mlir/Conversion/StandardToLLVM/ConvertStandardToLLVM.h"
15 #include "mlir/Conversion/StandardToLLVM/ConvertStandardToLLVMPass.h"
16 #include "mlir/Conversion/VectorToLLVM/ConvertVectorToLLVM.h"
17 #include "mlir/Conversion/VectorToSCF/VectorToSCF.h"
18 #include "mlir/Dialect/LLVMIR/LLVMDialect.h"
19 #include "mlir/Dialect/Linalg/IR/LinalgOps.h"
20 #include "mlir/Dialect/Linalg/IR/LinalgTypes.h"
21 #include "mlir/Dialect/Linalg/Passes.h"
22 #include "mlir/Dialect/StandardOps/EDSC/Intrinsics.h"
23 #include "mlir/IR/AffineExpr.h"
24 #include "mlir/IR/AffineMap.h"
25 #include "mlir/IR/Attributes.h"
26 #include "mlir/IR/Builders.h"
27 #include "mlir/IR/MLIRContext.h"
28 #include "mlir/IR/Module.h"
29 #include "mlir/IR/Operation.h"
30 #include "mlir/IR/PatternMatch.h"
31 #include "mlir/IR/StandardTypes.h"
32 #include "mlir/IR/Types.h"
33 #include "mlir/Support/LogicalResult.h"
34 #include "mlir/Transforms/DialectConversion.h"
35 #include "mlir/Transforms/Passes.h"
36 #include "llvm/ADT/SetVector.h"
37 #include "llvm/IR/DerivedTypes.h"
38 #include "llvm/IR/Module.h"
39 #include "llvm/IR/Type.h"
40 #include "llvm/Support/Allocator.h"
41 #include "llvm/Support/ErrorHandling.h"
42 
43 using namespace mlir;
44 using namespace mlir::edsc;
45 using namespace mlir::edsc::intrinsics;
46 using namespace mlir::LLVM;
47 using namespace mlir::linalg;
48 
49 using llvm_add = ValueBuilder<LLVM::AddOp>;
50 using llvm_bitcast = ValueBuilder<LLVM::BitcastOp>;
51 using llvm_constant = ValueBuilder<LLVM::ConstantOp>;
52 using llvm_extractvalue = ValueBuilder<LLVM::ExtractValueOp>;
53 using llvm_gep = ValueBuilder<LLVM::GEPOp>;
54 using llvm_insertvalue = ValueBuilder<LLVM::InsertValueOp>;
55 using llvm_call = OperationBuilder<LLVM::CallOp>;
56 using llvm_icmp = ValueBuilder<LLVM::ICmpOp>;
57 using llvm_load = ValueBuilder<LLVM::LoadOp>;
58 using llvm_store = OperationBuilder<LLVM::StoreOp>;
59 using llvm_select = ValueBuilder<LLVM::SelectOp>;
60 using llvm_mul = ValueBuilder<LLVM::MulOp>;
61 using llvm_ptrtoint = ValueBuilder<LLVM::PtrToIntOp>;
62 using llvm_sub = ValueBuilder<LLVM::SubOp>;
63 using llvm_undef = ValueBuilder<LLVM::UndefOp>;
64 using llvm_urem = ValueBuilder<LLVM::URemOp>;
65 using llvm_alloca = ValueBuilder<LLVM::AllocaOp>;
66 using llvm_return = OperationBuilder<LLVM::ReturnOp>;
67 
68 template <typename T>
69 static LLVMType getPtrToElementType(T containerType,
70                                     LLVMTypeConverter &lowering) {
71   return lowering.convertType(containerType.getElementType())
72       .template cast<LLVMType>()
73       .getPointerTo();
74 }
75 
76 /// Convert the given range descriptor type to the LLVMIR dialect.
77 /// Range descriptor contains the range bounds and the step as 64-bit integers.
78 ///
79 /// struct {
80 ///   int64_t min;
81 ///   int64_t max;
82 ///   int64_t step;
83 /// };
84 static Type convertRangeType(RangeType t, LLVMTypeConverter &converter) {
85   auto *context = t.getContext();
86   auto int64Ty = converter.convertType(IntegerType::get(64, context))
87                      .cast<LLVM::LLVMType>();
88   return LLVMType::getStructTy(int64Ty, int64Ty, int64Ty);
89 }
90 
91 namespace {
92 /// EDSC-compatible wrapper for MemRefDescriptor.
class BaseViewConversionHelper {
public:
  /// Creates a descriptor initialized to `undef` of the given converted
  /// (LLVM dialect) memref descriptor type.
  BaseViewConversionHelper(Type type)
      : d(MemRefDescriptor::undef(rewriter(), loc(), type)) {}

  /// Wraps an existing descriptor value.
  BaseViewConversionHelper(Value v) : d(v) {}

  /// Wrappers around MemRefDescriptor that use EDSC builder and location.
  Value allocatedPtr() { return d.allocatedPtr(rewriter(), loc()); }
  void setAllocatedPtr(Value v) { d.setAllocatedPtr(rewriter(), loc(), v); }
  Value alignedPtr() { return d.alignedPtr(rewriter(), loc()); }
  void setAlignedPtr(Value v) { d.setAlignedPtr(rewriter(), loc(), v); }
  Value offset() { return d.offset(rewriter(), loc()); }
  void setOffset(Value v) { d.setOffset(rewriter(), loc(), v); }
  Value size(unsigned i) { return d.size(rewriter(), loc(), i); }
  void setSize(unsigned i, Value v) { d.setSize(rewriter(), loc(), i, v); }
  void setConstantSize(unsigned i, int64_t v) {
    d.setConstantSize(rewriter(), loc(), i, v);
  }
  Value stride(unsigned i) { return d.stride(rewriter(), loc(), i); }
  void setStride(unsigned i, Value v) { d.setStride(rewriter(), loc(), i, v); }
  void setConstantStride(unsigned i, int64_t v) {
    d.setConstantStride(rewriter(), loc(), i, v);
  }

  /// Implicit conversion back to the underlying descriptor value.
  operator Value() { return d; }

private:
  // Builder and location are pulled from the ambient EDSC ScopedContext, so
  // instances are only valid while such a context is live.
  OpBuilder &rewriter() { return ScopedContext::getBuilderRef(); }
  Location loc() { return ScopedContext::getLocation(); }

  MemRefDescriptor d;
};
126 
127 // RangeOp creates a new range descriptor.
128 class RangeOpConversion : public ConvertToLLVMPattern {
129 public:
130   explicit RangeOpConversion(MLIRContext *context, LLVMTypeConverter &lowering_)
131       : ConvertToLLVMPattern(RangeOp::getOperationName(), context, lowering_) {}
132 
133   LogicalResult
134   matchAndRewrite(Operation *op, ArrayRef<Value> operands,
135                   ConversionPatternRewriter &rewriter) const override {
136     auto rangeOp = cast<RangeOp>(op);
137     auto rangeDescriptorTy =
138         convertRangeType(rangeOp.getType().cast<RangeType>(), typeConverter);
139 
140     edsc::ScopedContext context(rewriter, op->getLoc());
141 
142     // Fill in an aggregate value of the descriptor.
143     RangeOpOperandAdaptor adaptor(operands);
144     Value desc = llvm_undef(rangeDescriptorTy);
145     desc = llvm_insertvalue(desc, adaptor.min(), rewriter.getI64ArrayAttr(0));
146     desc = llvm_insertvalue(desc, adaptor.max(), rewriter.getI64ArrayAttr(1));
147     desc = llvm_insertvalue(desc, adaptor.step(), rewriter.getI64ArrayAttr(2));
148     rewriter.replaceOp(op, desc);
149     return success();
150   }
151 };
152 
153 // ReshapeOp creates a new view descriptor of the proper rank.
154 // For now, the only conversion supported is for target MemRef with static sizes
155 // and strides.
156 class ReshapeOpConversion : public ConvertToLLVMPattern {
157 public:
158   explicit ReshapeOpConversion(MLIRContext *context,
159                                LLVMTypeConverter &lowering_)
160       : ConvertToLLVMPattern(ReshapeOp::getOperationName(), context,
161                              lowering_) {}
162 
163   LogicalResult
164   matchAndRewrite(Operation *op, ArrayRef<Value> operands,
165                   ConversionPatternRewriter &rewriter) const override {
166     auto reshapeOp = cast<ReshapeOp>(op);
167     MemRefType dstType = reshapeOp.getResultType();
168 
169     if (!dstType.hasStaticShape())
170       return failure();
171 
172     int64_t offset;
173     SmallVector<int64_t, 4> strides;
174     auto res = getStridesAndOffset(dstType, strides, offset);
175     if (failed(res) || llvm::any_of(strides, [](int64_t val) {
176           return ShapedType::isDynamicStrideOrOffset(val);
177         }))
178       return failure();
179 
180     edsc::ScopedContext context(rewriter, op->getLoc());
181     ReshapeOpOperandAdaptor adaptor(operands);
182     BaseViewConversionHelper baseDesc(adaptor.src());
183     BaseViewConversionHelper desc(typeConverter.convertType(dstType));
184     desc.setAllocatedPtr(baseDesc.allocatedPtr());
185     desc.setAlignedPtr(baseDesc.alignedPtr());
186     desc.setOffset(baseDesc.offset());
187     for (auto en : llvm::enumerate(dstType.getShape()))
188       desc.setConstantSize(en.index(), en.value());
189     for (auto en : llvm::enumerate(strides))
190       desc.setConstantStride(en.index(), en.value());
191     rewriter.replaceOp(op, {desc});
192     return success();
193   }
194 };
195 
196 /// Conversion pattern that transforms a linalg.slice op into:
197 ///   1. An "undef" value for the ViewDescriptor.
198 ///   2. Updates to the ViewDescriptor to introduce the data ptr, offset, size
199 ///      and stride corresponding to the region of memory within the bounds of
200 ///      the parent view.
201 /// The linalg.slice op is replaced by the alloca'ed pointer.
class SliceOpConversion : public ConvertToLLVMPattern {
public:
  explicit SliceOpConversion(MLIRContext *context, LLVMTypeConverter &lowering_)
      : ConvertToLLVMPattern(SliceOp::getOperationName(), context, lowering_) {}

  LogicalResult
  matchAndRewrite(Operation *op, ArrayRef<Value> operands,
                  ConversionPatternRewriter &rewriter) const override {
    edsc::ScopedContext context(rewriter, op->getLoc());
    SliceOpOperandAdaptor adaptor(operands);
    // Descriptor of the view being sliced.
    BaseViewConversionHelper baseDesc(adaptor.view());

    auto sliceOp = cast<SliceOp>(op);
    auto memRefType = sliceOp.getBaseViewType();
    auto int64Ty = typeConverter.convertType(rewriter.getIntegerType(64))
                       .cast<LLVM::LLVMType>();

    // Descriptor of the resulting (possibly rank-reduced) view, starting as
    // undef and filled in field by field below.
    BaseViewConversionHelper desc(
        typeConverter.convertType(sliceOp.getShapedType()));

    // TODO(ntv): extract sizes and emit asserts.
    // Cache the base view's strides; they are reused for both the offset
    // computation and the per-dimension stride computation.
    SmallVector<Value, 4> strides(memRefType.getRank());
    for (int i = 0, e = memRefType.getRank(); i < e; ++i)
      strides[i] = baseDesc.stride(i);

    // Helper building the I64 array attribute used as an extractvalue
    // position.
    auto pos = [&rewriter](ArrayRef<int64_t> values) {
      return rewriter.getI64ArrayAttr(values);
    };

    // Compute base offset: the base view's offset plus min * stride for each
    // dimension. A range indexing contributes its lower bound (field 0 of
    // the range descriptor); a scalar indexing contributes the index itself.
    Value baseOffset = baseDesc.offset();
    for (int i = 0, e = memRefType.getRank(); i < e; ++i) {
      Value indexing = adaptor.indexings()[i];
      Value min = indexing;
      if (sliceOp.indexing(i).getType().isa<RangeType>())
        min = llvm_extractvalue(int64Ty, indexing, pos(0));
      baseOffset = llvm_add(baseOffset, llvm_mul(min, strides[i]));
    }

    // Insert the base and aligned pointers.
    desc.setAllocatedPtr(baseDesc.allocatedPtr());
    desc.setAlignedPtr(baseDesc.alignedPtr());

    // Insert base offset.
    desc.setOffset(baseOffset);

    // Corner case, no sizes or strides: early return the descriptor.
    if (sliceOp.getShapedType().getRank() == 0)
      return rewriter.replaceOp(op, {desc}), success();

    Value zero = llvm_constant(
        int64Ty, rewriter.getIntegerAttr(rewriter.getIndexType(), 0));
    // Compute and insert view sizes (max - min along the range) and strides.
    // Skip the non-range operands as they will be projected away from the view.
    // `numNewDims` tracks the position in the (rank-reduced) result: only
    // range indexings produce a dimension there.
    int numNewDims = 0;
    for (auto en : llvm::enumerate(sliceOp.indexings())) {
      Value indexing = en.value();
      if (indexing.getType().isa<RangeType>()) {
        int rank = en.index();
        // Unpack (min, max, step) from the range descriptor struct.
        Value rangeDescriptor = adaptor.indexings()[rank];
        Value min = llvm_extractvalue(int64Ty, rangeDescriptor, pos(0));
        Value max = llvm_extractvalue(int64Ty, rangeDescriptor, pos(1));
        Value step = llvm_extractvalue(int64Ty, rangeDescriptor, pos(2));
        Value baseSize = baseDesc.size(rank);

        // Bound upper by base view upper bound.
        max = llvm_select(llvm_icmp(ICmpPredicate::slt, max, baseSize), max,
                          baseSize);
        Value size = llvm_sub(max, min);
        // Bound lower by zero.
        size =
            llvm_select(llvm_icmp(ICmpPredicate::slt, size, zero), zero, size);
        // The new stride is the base stride scaled by the range's step.
        Value stride = llvm_mul(strides[rank], step);
        desc.setSize(numNewDims, size);
        desc.setStride(numNewDims, stride);
        ++numNewDims;
      }
    }

    rewriter.replaceOp(op, {desc});
    return success();
  }
};
285 
286 /// Conversion pattern that transforms a linalg.transpose op into:
287 ///   1. A function entry `alloca` operation to allocate a ViewDescriptor.
288 ///   2. A load of the ViewDescriptor from the pointer allocated in 1.
289 ///   3. Updates to the ViewDescriptor to introduce the data ptr, offset, size
290 ///      and stride. Size and stride are permutations of the original values.
291 ///   4. A store of the resulting ViewDescriptor to the alloca'ed pointer.
292 /// The linalg.transpose op is replaced by the alloca'ed pointer.
293 class TransposeOpConversion : public ConvertToLLVMPattern {
294 public:
295   explicit TransposeOpConversion(MLIRContext *context,
296                                  LLVMTypeConverter &lowering_)
297       : ConvertToLLVMPattern(TransposeOp::getOperationName(), context,
298                              lowering_) {}
299 
300   LogicalResult
301   matchAndRewrite(Operation *op, ArrayRef<Value> operands,
302                   ConversionPatternRewriter &rewriter) const override {
303     // Initialize the common boilerplate and alloca at the top of the FuncOp.
304     edsc::ScopedContext context(rewriter, op->getLoc());
305     TransposeOpOperandAdaptor adaptor(operands);
306     BaseViewConversionHelper baseDesc(adaptor.view());
307 
308     auto transposeOp = cast<TransposeOp>(op);
309     // No permutation, early exit.
310     if (transposeOp.permutation().isIdentity())
311       return rewriter.replaceOp(op, {baseDesc}), success();
312 
313     BaseViewConversionHelper desc(
314         typeConverter.convertType(transposeOp.getShapedType()));
315 
316     // Copy the base and aligned pointers from the old descriptor to the new
317     // one.
318     desc.setAllocatedPtr(baseDesc.allocatedPtr());
319     desc.setAlignedPtr(baseDesc.alignedPtr());
320 
321     // Copy the offset pointer from the old descriptor to the new one.
322     desc.setOffset(baseDesc.offset());
323 
324     // Iterate over the dimensions and apply size/stride permutation.
325     for (auto en : llvm::enumerate(transposeOp.permutation().getResults())) {
326       int sourcePos = en.index();
327       int targetPos = en.value().cast<AffineDimExpr>().getPosition();
328       desc.setSize(targetPos, baseDesc.size(sourcePos));
329       desc.setStride(targetPos, baseDesc.stride(sourcePos));
330     }
331 
332     rewriter.replaceOp(op, {desc});
333     return success();
334   }
335 };
336 
// YieldOp produces an LLVM::ReturnOp.
class YieldOpConversion : public ConvertToLLVMPattern {
public:
  explicit YieldOpConversion(MLIRContext *context, LLVMTypeConverter &lowering_)
      : ConvertToLLVMPattern(YieldOp::getOperationName(), context, lowering_) {}

  LogicalResult
  matchAndRewrite(Operation *op, ArrayRef<Value> operands,
                  ConversionPatternRewriter &rewriter) const override {
    // Replace the linalg.yield terminator with an llvm.return of the
    // already-converted operands.
    rewriter.replaceOpWithNewOp<LLVM::ReturnOp>(op, operands);
    return success();
  }
};
350 } // namespace
351 
352 /// Populate the given list with patterns that convert from Linalg to LLVM.
353 void mlir::populateLinalgToLLVMConversionPatterns(
354     LLVMTypeConverter &converter, OwningRewritePatternList &patterns,
355     MLIRContext *ctx) {
356   patterns.insert<RangeOpConversion, ReshapeOpConversion, SliceOpConversion,
357                   TransposeOpConversion, YieldOpConversion>(ctx, converter);
358 
359   // Populate the type conversions for the linalg types.
360   converter.addConversion(
361       [&](RangeType type) { return convertRangeType(type, converter); });
362 }
363 
namespace {
/// Module pass converting Linalg (plus the affine/scf/std/vector ops it may
/// contain) to the LLVM dialect. Boilerplate (name, options) comes from the
/// tablegen-generated ConvertLinalgToLLVMBase.
struct ConvertLinalgToLLVMPass
    : public ConvertLinalgToLLVMBase<ConvertLinalgToLLVMPass> {
  void runOnOperation() override;
};
} // namespace
370 
371 void ConvertLinalgToLLVMPass::runOnOperation() {
372   auto module = getOperation();
373 
374   // Convert to the LLVM IR dialect using the converter defined above.
375   OwningRewritePatternList patterns;
376   LLVMTypeConverter converter(&getContext());
377   populateAffineToStdConversionPatterns(patterns, &getContext());
378   populateLoopToStdConversionPatterns(patterns, &getContext());
379   populateStdToLLVMConversionPatterns(converter, patterns);
380   populateVectorToSCFConversionPatterns(patterns, &getContext());
381   populateVectorToLLVMMatrixConversionPatterns(converter, patterns);
382   populateVectorToLLVMConversionPatterns(converter, patterns);
383   populateLinalgToLLVMConversionPatterns(converter, patterns, &getContext());
384 
385   LLVMConversionTarget target(getContext());
386   target.addDynamicallyLegalOp<FuncOp>(
387       [&](FuncOp op) { return converter.isSignatureLegal(op.getType()); });
388   target.addLegalOp<ModuleOp, ModuleTerminatorOp>();
389   if (failed(applyFullConversion(module, target, patterns, &converter)))
390     signalPassFailure();
391 }
392 
/// Public factory for the Linalg-to-LLVM conversion pass.
std::unique_ptr<OperationPass<ModuleOp>> mlir::createConvertLinalgToLLVMPass() {
  return std::make_unique<ConvertLinalgToLLVMPass>();
}
396