1 //===- Pattern.cpp - Conversion pattern to the LLVM dialect ---------------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 
9 #include "mlir/Conversion/LLVMCommon/Pattern.h"
10 #include "mlir/Dialect/LLVMIR/FunctionCallUtils.h"
11 #include "mlir/Dialect/LLVMIR/LLVMDialect.h"
12 #include "mlir/Dialect/LLVMIR/LLVMTypes.h"
13 #include "mlir/IR/AffineMap.h"
14 #include "mlir/IR/BuiltinAttributes.h"
15 
16 using namespace mlir;
17 
18 //===----------------------------------------------------------------------===//
19 // ConvertToLLVMPattern
20 //===----------------------------------------------------------------------===//
21 
ConvertToLLVMPattern(StringRef rootOpName,MLIRContext * context,LLVMTypeConverter & typeConverter,PatternBenefit benefit)22 ConvertToLLVMPattern::ConvertToLLVMPattern(StringRef rootOpName,
23                                            MLIRContext *context,
24                                            LLVMTypeConverter &typeConverter,
25                                            PatternBenefit benefit)
26     : ConversionPattern(typeConverter, rootOpName, benefit, context) {}
27 
getTypeConverter() const28 LLVMTypeConverter *ConvertToLLVMPattern::getTypeConverter() const {
29   return static_cast<LLVMTypeConverter *>(
30       ConversionPattern::getTypeConverter());
31 }
32 
getDialect() const33 LLVM::LLVMDialect &ConvertToLLVMPattern::getDialect() const {
34   return *getTypeConverter()->getDialect();
35 }
36 
getIndexType() const37 Type ConvertToLLVMPattern::getIndexType() const {
38   return getTypeConverter()->getIndexType();
39 }
40 
getIntPtrType(unsigned addressSpace) const41 Type ConvertToLLVMPattern::getIntPtrType(unsigned addressSpace) const {
42   return IntegerType::get(&getTypeConverter()->getContext(),
43                           getTypeConverter()->getPointerBitwidth(addressSpace));
44 }
45 
getVoidType() const46 Type ConvertToLLVMPattern::getVoidType() const {
47   return LLVM::LLVMVoidType::get(&getTypeConverter()->getContext());
48 }
49 
getVoidPtrType() const50 Type ConvertToLLVMPattern::getVoidPtrType() const {
51   return LLVM::LLVMPointerType::get(
52       IntegerType::get(&getTypeConverter()->getContext(), 8));
53 }
54 
createIndexAttrConstant(OpBuilder & builder,Location loc,Type resultType,int64_t value)55 Value ConvertToLLVMPattern::createIndexAttrConstant(OpBuilder &builder,
56                                                     Location loc,
57                                                     Type resultType,
58                                                     int64_t value) {
59   return builder.create<LLVM::ConstantOp>(
60       loc, resultType, builder.getIntegerAttr(builder.getIndexType(), value));
61 }
62 
createIndexConstant(ConversionPatternRewriter & builder,Location loc,uint64_t value) const63 Value ConvertToLLVMPattern::createIndexConstant(
64     ConversionPatternRewriter &builder, Location loc, uint64_t value) const {
65   return createIndexAttrConstant(builder, loc, getIndexType(), value);
66 }
67 
getStridedElementPtr(Location loc,MemRefType type,Value memRefDesc,ValueRange indices,ConversionPatternRewriter & rewriter) const68 Value ConvertToLLVMPattern::getStridedElementPtr(
69     Location loc, MemRefType type, Value memRefDesc, ValueRange indices,
70     ConversionPatternRewriter &rewriter) const {
71 
72   int64_t offset;
73   SmallVector<int64_t, 4> strides;
74   auto successStrides = getStridesAndOffset(type, strides, offset);
75   assert(succeeded(successStrides) && "unexpected non-strided memref");
76   (void)successStrides;
77 
78   MemRefDescriptor memRefDescriptor(memRefDesc);
79   Value base = memRefDescriptor.alignedPtr(rewriter, loc);
80 
81   Value index;
82   if (offset != 0) // Skip if offset is zero.
83     index = ShapedType::isDynamicStrideOrOffset(offset)
84                 ? memRefDescriptor.offset(rewriter, loc)
85                 : createIndexConstant(rewriter, loc, offset);
86 
87   for (int i = 0, e = indices.size(); i < e; ++i) {
88     Value increment = indices[i];
89     if (strides[i] != 1) { // Skip if stride is 1.
90       Value stride = ShapedType::isDynamicStrideOrOffset(strides[i])
91                          ? memRefDescriptor.stride(rewriter, loc, i)
92                          : createIndexConstant(rewriter, loc, strides[i]);
93       increment = rewriter.create<LLVM::MulOp>(loc, increment, stride);
94     }
95     index =
96         index ? rewriter.create<LLVM::AddOp>(loc, index, increment) : increment;
97   }
98 
99   Type elementPtrType = memRefDescriptor.getElementPtrType();
100   return index ? rewriter.create<LLVM::GEPOp>(loc, elementPtrType, base, index)
101                : base;
102 }
103 
104 // Check if the MemRefType `type` is supported by the lowering. We currently
105 // only support memrefs with identity maps.
isConvertibleAndHasIdentityMaps(MemRefType type) const106 bool ConvertToLLVMPattern::isConvertibleAndHasIdentityMaps(
107     MemRefType type) const {
108   if (!typeConverter->convertType(type.getElementType()))
109     return false;
110   return type.getLayout().isIdentity();
111 }
112 
getElementPtrType(MemRefType type) const113 Type ConvertToLLVMPattern::getElementPtrType(MemRefType type) const {
114   auto elementType = type.getElementType();
115   auto structElementType = typeConverter->convertType(elementType);
116   return LLVM::LLVMPointerType::get(structElementType,
117                                     type.getMemorySpaceAsInt());
118 }
119 
/// Materializes the descriptor fields of an identity-layout memref of type
/// `memRefType`:
///   - `sizes`: one Value per dimension. Static sizes become constants, and
///     dynamic ones are taken in order from `dynamicSizes`.
///   - `strides`: contiguous strides computed from innermost to outermost
///     dimension.
///   - `sizeBytes`: the buffer size in bytes. It is computed as the address
///     of element #N from a null element pointer, converted to the index
///     type, where N is the final running stride.
void ConvertToLLVMPattern::getMemRefDescriptorSizes(
    Location loc, MemRefType memRefType, ValueRange dynamicSizes,
    ConversionPatternRewriter &rewriter, SmallVectorImpl<Value> &sizes,
    SmallVectorImpl<Value> &strides, Value &sizeBytes) const {
  assert(isConvertibleAndHasIdentityMaps(memRefType) &&
         "layout maps must have been normalized away");
  assert(count(memRefType.getShape(), ShapedType::kDynamicSize) ==
             static_cast<ssize_t>(dynamicSizes.size()) &&
         "dynamicSizes size doesn't match dynamic sizes count in memref shape");

  // Sizes: static extents become constants; dynamic extents consume
  // `dynamicSizes` left to right.
  sizes.reserve(memRefType.getRank());
  unsigned dynamicIndex = 0;
  for (int64_t size : memRefType.getShape()) {
    sizes.push_back(size == ShapedType::kDynamicSize
                        ? dynamicSizes[dynamicIndex++]
                        : createIndexConstant(rewriter, loc, size));
  }

  // Strides: iterate sizes in reverse order and multiply.
  // `stride` is the static product of the trailing sizes seen so far. It
  // becomes kDynamicSize once any dynamic size is encountered.
  // `runningStride` is the same product as an SSA value.
  int64_t stride = 1;
  Value runningStride = createIndexConstant(rewriter, loc, 1);
  strides.resize(memRefType.getRank());
  for (auto i = memRefType.getRank(); i-- > 0;) {
    strides[i] = runningStride;

    int64_t size = memRefType.getShape()[i];
    // Zero-sized dimensions are skipped. They are not folded into the running
    // product, so neither outer strides nor `sizeBytes` below become zero.
    if (size == 0)
      continue;
    // While the product so far is exactly 1, the new product equals this
    // dimension's size. The existing size Value is reused instead of
    // emitting a new op.
    bool useSizeAsStride = stride == 1;
    if (size == ShapedType::kDynamicSize)
      stride = ShapedType::kDynamicSize;
    if (stride != ShapedType::kDynamicSize)
      stride *= size;

    if (useSizeAsStride)
      runningStride = sizes[i];
    else if (stride == ShapedType::kDynamicSize)
      runningStride =
          rewriter.create<LLVM::MulOp>(loc, runningStride, sizes[i]);
    else
      runningStride = createIndexConstant(rewriter, loc, stride);
  }

  // Buffer size in bytes.
  // After the loop, `runningStride` is the product of all non-zero sizes.
  // Indexing that many elements past a null pointer and converting to an
  // integer yields the size in bytes.
  Type elementPtrType = getElementPtrType(memRefType);
  Value nullPtr = rewriter.create<LLVM::NullOp>(loc, elementPtrType);
  Value gepPtr = rewriter.create<LLVM::GEPOp>(loc, elementPtrType, nullPtr,
                                              ArrayRef<Value>{runningStride});
  sizeBytes = rewriter.create<LLVM::PtrToIntOp>(loc, getIndexType(), gepPtr);
}
170 
getSizeInBytes(Location loc,Type type,ConversionPatternRewriter & rewriter) const171 Value ConvertToLLVMPattern::getSizeInBytes(
172     Location loc, Type type, ConversionPatternRewriter &rewriter) const {
173   // Compute the size of an individual element. This emits the MLIR equivalent
174   // of the following sizeof(...) implementation in LLVM IR:
175   //   %0 = getelementptr %elementType* null, %indexType 1
176   //   %1 = ptrtoint %elementType* %0 to %indexType
177   // which is a common pattern of getting the size of a type in bytes.
178   auto convertedPtrType =
179       LLVM::LLVMPointerType::get(typeConverter->convertType(type));
180   auto nullPtr = rewriter.create<LLVM::NullOp>(loc, convertedPtrType);
181   auto gep = rewriter.create<LLVM::GEPOp>(
182       loc, convertedPtrType, nullPtr,
183       ArrayRef<Value>{createIndexConstant(rewriter, loc, 1)});
184   return rewriter.create<LLVM::PtrToIntOp>(loc, getIndexType(), gep);
185 }
186 
getNumElements(Location loc,ArrayRef<Value> shape,ConversionPatternRewriter & rewriter) const187 Value ConvertToLLVMPattern::getNumElements(
188     Location loc, ArrayRef<Value> shape,
189     ConversionPatternRewriter &rewriter) const {
190   // Compute the total number of memref elements.
191   Value numElements =
192       shape.empty() ? createIndexConstant(rewriter, loc, 1) : shape.front();
193   for (unsigned i = 1, e = shape.size(); i < e; ++i)
194     numElements = rewriter.create<LLVM::MulOp>(loc, numElements, shape[i]);
195   return numElements;
196 }
197 
198 /// Creates and populates the memref descriptor struct given all its fields.
createMemRefDescriptor(Location loc,MemRefType memRefType,Value allocatedPtr,Value alignedPtr,ArrayRef<Value> sizes,ArrayRef<Value> strides,ConversionPatternRewriter & rewriter) const199 MemRefDescriptor ConvertToLLVMPattern::createMemRefDescriptor(
200     Location loc, MemRefType memRefType, Value allocatedPtr, Value alignedPtr,
201     ArrayRef<Value> sizes, ArrayRef<Value> strides,
202     ConversionPatternRewriter &rewriter) const {
203   auto structType = typeConverter->convertType(memRefType);
204   auto memRefDescriptor = MemRefDescriptor::undef(rewriter, loc, structType);
205 
206   // Field 1: Allocated pointer, used for malloc/free.
207   memRefDescriptor.setAllocatedPtr(rewriter, loc, allocatedPtr);
208 
209   // Field 2: Actual aligned pointer to payload.
210   memRefDescriptor.setAlignedPtr(rewriter, loc, alignedPtr);
211 
212   // Field 3: Offset in aligned pointer.
213   memRefDescriptor.setOffset(rewriter, loc,
214                              createIndexConstant(rewriter, loc, 0));
215 
216   // Fields 4: Sizes.
217   for (const auto &en : llvm::enumerate(sizes))
218     memRefDescriptor.setSize(rewriter, loc, en.index(), en.value());
219 
220   // Field 5: Strides.
221   for (const auto &en : llvm::enumerate(strides))
222     memRefDescriptor.setStride(rewriter, loc, en.index(), en.value());
223 
224   return memRefDescriptor;
225 }
226 
copyUnrankedDescriptors(OpBuilder & builder,Location loc,TypeRange origTypes,SmallVectorImpl<Value> & operands,bool toDynamic) const227 LogicalResult ConvertToLLVMPattern::copyUnrankedDescriptors(
228     OpBuilder &builder, Location loc, TypeRange origTypes,
229     SmallVectorImpl<Value> &operands, bool toDynamic) const {
230   assert(origTypes.size() == operands.size() &&
231          "expected as may original types as operands");
232 
233   // Find operands of unranked memref type and store them.
234   SmallVector<UnrankedMemRefDescriptor, 4> unrankedMemrefs;
235   for (unsigned i = 0, e = operands.size(); i < e; ++i)
236     if (origTypes[i].isa<UnrankedMemRefType>())
237       unrankedMemrefs.emplace_back(operands[i]);
238 
239   if (unrankedMemrefs.empty())
240     return success();
241 
242   // Compute allocation sizes.
243   SmallVector<Value, 4> sizes;
244   UnrankedMemRefDescriptor::computeSizes(builder, loc, *getTypeConverter(),
245                                          unrankedMemrefs, sizes);
246 
247   // Get frequently used types.
248   MLIRContext *context = builder.getContext();
249   Type voidPtrType = LLVM::LLVMPointerType::get(IntegerType::get(context, 8));
250   auto i1Type = IntegerType::get(context, 1);
251   Type indexType = getTypeConverter()->getIndexType();
252 
253   // Find the malloc and free, or declare them if necessary.
254   auto module = builder.getInsertionPoint()->getParentOfType<ModuleOp>();
255   LLVM::LLVMFuncOp freeFunc, mallocFunc;
256   if (toDynamic)
257     mallocFunc = LLVM::lookupOrCreateMallocFn(module, indexType);
258   if (!toDynamic)
259     freeFunc = LLVM::lookupOrCreateFreeFn(module);
260 
261   // Initialize shared constants.
262   Value zero =
263       builder.create<LLVM::ConstantOp>(loc, i1Type, builder.getBoolAttr(false));
264 
265   unsigned unrankedMemrefPos = 0;
266   for (unsigned i = 0, e = operands.size(); i < e; ++i) {
267     Type type = origTypes[i];
268     if (!type.isa<UnrankedMemRefType>())
269       continue;
270     Value allocationSize = sizes[unrankedMemrefPos++];
271     UnrankedMemRefDescriptor desc(operands[i]);
272 
273     // Allocate memory, copy, and free the source if necessary.
274     Value memory =
275         toDynamic
276             ? builder.create<LLVM::CallOp>(loc, mallocFunc, allocationSize)
277                   .getResult(0)
278             : builder.create<LLVM::AllocaOp>(loc, voidPtrType, allocationSize,
279                                              /*alignment=*/0);
280     Value source = desc.memRefDescPtr(builder, loc);
281     builder.create<LLVM::MemcpyOp>(loc, memory, source, allocationSize, zero);
282     if (!toDynamic)
283       builder.create<LLVM::CallOp>(loc, freeFunc, source);
284 
285     // Create a new descriptor. The same descriptor can be returned multiple
286     // times, attempting to modify its pointer can lead to memory leaks
287     // (allocated twice and overwritten) or double frees (the caller does not
288     // know if the descriptor points to the same memory).
289     Type descriptorType = getTypeConverter()->convertType(type);
290     if (!descriptorType)
291       return failure();
292     auto updatedDesc =
293         UnrankedMemRefDescriptor::undef(builder, loc, descriptorType);
294     Value rank = desc.rank(builder, loc);
295     updatedDesc.setRank(builder, loc, rank);
296     updatedDesc.setMemRefDescPtr(builder, loc, memory);
297 
298     operands[i] = updatedDesc;
299   }
300 
301   return success();
302 }
303 
304 //===----------------------------------------------------------------------===//
305 // Detail methods
306 //===----------------------------------------------------------------------===//
307 
308 /// Replaces the given operation "op" with a new operation of type "targetOp"
309 /// and given operands.
oneToOneRewrite(Operation * op,StringRef targetOp,ValueRange operands,LLVMTypeConverter & typeConverter,ConversionPatternRewriter & rewriter)310 LogicalResult LLVM::detail::oneToOneRewrite(
311     Operation *op, StringRef targetOp, ValueRange operands,
312     LLVMTypeConverter &typeConverter, ConversionPatternRewriter &rewriter) {
313   unsigned numResults = op->getNumResults();
314 
315   Type packedType;
316   if (numResults != 0) {
317     packedType = typeConverter.packFunctionResults(op->getResultTypes());
318     if (!packedType)
319       return failure();
320   }
321 
322   // Create the operation through state since we don't know its C++ type.
323   Operation *newOp =
324       rewriter.create(op->getLoc(), rewriter.getStringAttr(targetOp), operands,
325                       packedType, op->getAttrs());
326 
327   // If the operation produced 0 or 1 result, return them immediately.
328   if (numResults == 0)
329     return rewriter.eraseOp(op), success();
330   if (numResults == 1)
331     return rewriter.replaceOp(op, newOp->getResult(0)), success();
332 
333   // Otherwise, it had been converted to an operation producing a structure.
334   // Extract individual results from the structure and return them as list.
335   SmallVector<Value, 4> results;
336   results.reserve(numResults);
337   for (unsigned i = 0; i < numResults; ++i) {
338     auto type = typeConverter.convertType(op->getResult(i).getType());
339     results.push_back(rewriter.create<LLVM::ExtractValueOp>(
340         op->getLoc(), type, newOp->getResult(0), rewriter.getI64ArrayAttr(i)));
341   }
342   rewriter.replaceOp(op, results);
343   return success();
344 }
345