1 //===- Pattern.cpp - Conversion pattern to the LLVM dialect ---------------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 
9 #include "mlir/Conversion/LLVMCommon/Pattern.h"
10 #include "mlir/Dialect/LLVMIR/FunctionCallUtils.h"
11 #include "mlir/Dialect/LLVMIR/LLVMDialect.h"
12 #include "mlir/Dialect/LLVMIR/LLVMTypes.h"
13 #include "mlir/IR/AffineMap.h"
14 
15 using namespace mlir;
16 
17 //===----------------------------------------------------------------------===//
18 // ConvertToLLVMPattern
19 //===----------------------------------------------------------------------===//
20 
/// Creates a conversion pattern rooted at operations named `rootOpName`.
///
/// The LLVM-specific `typeConverter` is forwarded to the base
/// ConversionPattern; getTypeConverter() later downcasts the converter held by
/// the base class back to LLVMTypeConverter, which is only valid because it is
/// always supplied here.
ConvertToLLVMPattern::ConvertToLLVMPattern(StringRef rootOpName,
                                           MLIRContext *context,
                                           LLVMTypeConverter &typeConverter,
                                           PatternBenefit benefit)
    : ConversionPattern(typeConverter, rootOpName, benefit, context) {}
26 
27 LLVMTypeConverter *ConvertToLLVMPattern::getTypeConverter() const {
28   return static_cast<LLVMTypeConverter *>(
29       ConversionPattern::getTypeConverter());
30 }
31 
32 LLVM::LLVMDialect &ConvertToLLVMPattern::getDialect() const {
33   return *getTypeConverter()->getDialect();
34 }
35 
36 Type ConvertToLLVMPattern::getIndexType() const {
37   return getTypeConverter()->getIndexType();
38 }
39 
40 Type ConvertToLLVMPattern::getIntPtrType(unsigned addressSpace) const {
41   return IntegerType::get(&getTypeConverter()->getContext(),
42                           getTypeConverter()->getPointerBitwidth(addressSpace));
43 }
44 
45 Type ConvertToLLVMPattern::getVoidType() const {
46   return LLVM::LLVMVoidType::get(&getTypeConverter()->getContext());
47 }
48 
49 Type ConvertToLLVMPattern::getVoidPtrType() const {
50   return LLVM::LLVMPointerType::get(
51       IntegerType::get(&getTypeConverter()->getContext(), 8));
52 }
53 
54 Value ConvertToLLVMPattern::createIndexAttrConstant(OpBuilder &builder,
55                                                     Location loc,
56                                                     Type resultType,
57                                                     int64_t value) {
58   return builder.create<LLVM::ConstantOp>(
59       loc, resultType, builder.getIntegerAttr(builder.getIndexType(), value));
60 }
61 
62 Value ConvertToLLVMPattern::createIndexConstant(
63     ConversionPatternRewriter &builder, Location loc, uint64_t value) const {
64   return createIndexAttrConstant(builder, loc, getIndexType(), value);
65 }
66 
67 Value ConvertToLLVMPattern::getStridedElementPtr(
68     Location loc, MemRefType type, Value memRefDesc, ValueRange indices,
69     ConversionPatternRewriter &rewriter) const {
70 
71   int64_t offset;
72   SmallVector<int64_t, 4> strides;
73   auto successStrides = getStridesAndOffset(type, strides, offset);
74   assert(succeeded(successStrides) && "unexpected non-strided memref");
75   (void)successStrides;
76 
77   MemRefDescriptor memRefDescriptor(memRefDesc);
78   Value base = memRefDescriptor.alignedPtr(rewriter, loc);
79 
80   Value index;
81   if (offset != 0) // Skip if offset is zero.
82     index = ShapedType::isDynamicStrideOrOffset(offset)
83                 ? memRefDescriptor.offset(rewriter, loc)
84                 : createIndexConstant(rewriter, loc, offset);
85 
86   for (int i = 0, e = indices.size(); i < e; ++i) {
87     Value increment = indices[i];
88     if (strides[i] != 1) { // Skip if stride is 1.
89       Value stride = ShapedType::isDynamicStrideOrOffset(strides[i])
90                          ? memRefDescriptor.stride(rewriter, loc, i)
91                          : createIndexConstant(rewriter, loc, strides[i]);
92       increment = rewriter.create<LLVM::MulOp>(loc, increment, stride);
93     }
94     index =
95         index ? rewriter.create<LLVM::AddOp>(loc, index, increment) : increment;
96   }
97 
98   Type elementPtrType = memRefDescriptor.getElementPtrType();
99   return index ? rewriter.create<LLVM::GEPOp>(loc, elementPtrType, base, index)
100                : base;
101 }
102 
103 // Check if the MemRefType `type` is supported by the lowering. We currently
104 // only support memrefs with identity maps.
105 bool ConvertToLLVMPattern::isConvertibleAndHasIdentityMaps(
106     MemRefType type) const {
107   if (!typeConverter->convertType(type.getElementType()))
108     return false;
109   return type.getLayout().isIdentity();
110 }
111 
112 Type ConvertToLLVMPattern::getElementPtrType(MemRefType type) const {
113   auto elementType = type.getElementType();
114   auto structElementType = typeConverter->convertType(elementType);
115   return LLVM::LLVMPointerType::get(structElementType,
116                                     type.getMemorySpaceAsInt());
117 }
118 
/// Materializes the sizes, the contiguous (row-major) strides and the buffer
/// size in bytes of an identity-layout memref of type `memRefType`.
///
/// \param dynamicSizes one value per dynamic dimension of `memRefType`, in
///        dimension order.
/// \param sizes  receives one index value per dimension: the matching
///        `dynamicSizes` entry for dynamic dims, an index constant otherwise.
/// \param strides receives one index value per dimension. Each stride is the
///        product of the sizes of all inner dimensions, with zero-sized
///        dimensions skipped in that product.
/// \param sizeBytes receives the byte size of a buffer holding as many
///        elements as the product of all non-zero dimension sizes.
void ConvertToLLVMPattern::getMemRefDescriptorSizes(
    Location loc, MemRefType memRefType, ValueRange dynamicSizes,
    ConversionPatternRewriter &rewriter, SmallVectorImpl<Value> &sizes,
    SmallVectorImpl<Value> &strides, Value &sizeBytes) const {
  assert(isConvertibleAndHasIdentityMaps(memRefType) &&
         "layout maps must have been normalized away");
  assert(count(memRefType.getShape(), ShapedType::kDynamicSize) ==
             static_cast<ssize_t>(dynamicSizes.size()) &&
         "dynamicSizes size doesn't match dynamic sizes count in memref shape");

  // Sizes: dynamic dimensions consume `dynamicSizes` in order, static ones
  // become index constants.
  sizes.reserve(memRefType.getRank());
  unsigned dynamicIndex = 0;
  for (int64_t size : memRefType.getShape()) {
    sizes.push_back(size == ShapedType::kDynamicSize
                        ? dynamicSizes[dynamicIndex++]
                        : createIndexConstant(rewriter, loc, size));
  }

  // Strides: iterate sizes in reverse order and multiply.
  // `stride` tracks the product so far statically (or kDynamicSize once any
  // dynamic dimension has been folded in); `runningStride` is the matching IR
  // value and is the stride assigned to the next (outer) dimension.
  int64_t stride = 1;
  Value runningStride = createIndexConstant(rewriter, loc, 1);
  strides.resize(memRefType.getRank());
  for (auto i = memRefType.getRank(); i-- > 0;) {
    strides[i] = runningStride;

    int64_t size = memRefType.getShape()[i];
    // Zero-sized dimensions are left out of the product, so they do not turn
    // the strides of outer dimensions (or the allocation size) into zero.
    if (size == 0)
      continue;
    // While the product is still 1, the new product is just this dimension's
    // size value: reuse it instead of emitting a multiplication by one.
    bool useSizeAsStride = stride == 1;
    if (size == ShapedType::kDynamicSize)
      stride = ShapedType::kDynamicSize;
    if (stride != ShapedType::kDynamicSize)
      stride *= size;

    if (useSizeAsStride)
      runningStride = sizes[i];
    else if (stride == ShapedType::kDynamicSize)
      // Dynamic product: multiply at runtime.
      runningStride =
          rewriter.create<LLVM::MulOp>(loc, runningStride, sizes[i]);
    else
      // Fully static product: emit it as a single constant.
      runningStride = createIndexConstant(rewriter, loc, stride);
  }

  // Buffer size in bytes.
  // After the loop `runningStride` is the total element count (non-zero dims
  // only). GEP from a null element pointer by that count, then ptrtoint, yields
  // count * sizeof(element).
  Type elementPtrType = getElementPtrType(memRefType);
  Value nullPtr = rewriter.create<LLVM::NullOp>(loc, elementPtrType);
  Value gepPtr = rewriter.create<LLVM::GEPOp>(loc, elementPtrType, nullPtr,
                                              ArrayRef<Value>{runningStride});
  sizeBytes = rewriter.create<LLVM::PtrToIntOp>(loc, getIndexType(), gepPtr);
}
169 
170 Value ConvertToLLVMPattern::getSizeInBytes(
171     Location loc, Type type, ConversionPatternRewriter &rewriter) const {
172   // Compute the size of an individual element. This emits the MLIR equivalent
173   // of the following sizeof(...) implementation in LLVM IR:
174   //   %0 = getelementptr %elementType* null, %indexType 1
175   //   %1 = ptrtoint %elementType* %0 to %indexType
176   // which is a common pattern of getting the size of a type in bytes.
177   auto convertedPtrType =
178       LLVM::LLVMPointerType::get(typeConverter->convertType(type));
179   auto nullPtr = rewriter.create<LLVM::NullOp>(loc, convertedPtrType);
180   auto gep = rewriter.create<LLVM::GEPOp>(
181       loc, convertedPtrType, nullPtr,
182       ArrayRef<Value>{createIndexConstant(rewriter, loc, 1)});
183   return rewriter.create<LLVM::PtrToIntOp>(loc, getIndexType(), gep);
184 }
185 
186 Value ConvertToLLVMPattern::getNumElements(
187     Location loc, ArrayRef<Value> shape,
188     ConversionPatternRewriter &rewriter) const {
189   // Compute the total number of memref elements.
190   Value numElements =
191       shape.empty() ? createIndexConstant(rewriter, loc, 1) : shape.front();
192   for (unsigned i = 1, e = shape.size(); i < e; ++i)
193     numElements = rewriter.create<LLVM::MulOp>(loc, numElements, shape[i]);
194   return numElements;
195 }
196 
197 /// Creates and populates the memref descriptor struct given all its fields.
198 MemRefDescriptor ConvertToLLVMPattern::createMemRefDescriptor(
199     Location loc, MemRefType memRefType, Value allocatedPtr, Value alignedPtr,
200     ArrayRef<Value> sizes, ArrayRef<Value> strides,
201     ConversionPatternRewriter &rewriter) const {
202   auto structType = typeConverter->convertType(memRefType);
203   auto memRefDescriptor = MemRefDescriptor::undef(rewriter, loc, structType);
204 
205   // Field 1: Allocated pointer, used for malloc/free.
206   memRefDescriptor.setAllocatedPtr(rewriter, loc, allocatedPtr);
207 
208   // Field 2: Actual aligned pointer to payload.
209   memRefDescriptor.setAlignedPtr(rewriter, loc, alignedPtr);
210 
211   // Field 3: Offset in aligned pointer.
212   memRefDescriptor.setOffset(rewriter, loc,
213                              createIndexConstant(rewriter, loc, 0));
214 
215   // Fields 4: Sizes.
216   for (const auto &en : llvm::enumerate(sizes))
217     memRefDescriptor.setSize(rewriter, loc, en.index(), en.value());
218 
219   // Field 5: Strides.
220   for (const auto &en : llvm::enumerate(strides))
221     memRefDescriptor.setStride(rewriter, loc, en.index(), en.value());
222 
223   return memRefDescriptor;
224 }
225 
226 LogicalResult ConvertToLLVMPattern::copyUnrankedDescriptors(
227     OpBuilder &builder, Location loc, TypeRange origTypes,
228     SmallVectorImpl<Value> &operands, bool toDynamic) const {
229   assert(origTypes.size() == operands.size() &&
230          "expected as may original types as operands");
231 
232   // Find operands of unranked memref type and store them.
233   SmallVector<UnrankedMemRefDescriptor, 4> unrankedMemrefs;
234   for (unsigned i = 0, e = operands.size(); i < e; ++i)
235     if (origTypes[i].isa<UnrankedMemRefType>())
236       unrankedMemrefs.emplace_back(operands[i]);
237 
238   if (unrankedMemrefs.empty())
239     return success();
240 
241   // Compute allocation sizes.
242   SmallVector<Value, 4> sizes;
243   UnrankedMemRefDescriptor::computeSizes(builder, loc, *getTypeConverter(),
244                                          unrankedMemrefs, sizes);
245 
246   // Get frequently used types.
247   MLIRContext *context = builder.getContext();
248   Type voidPtrType = LLVM::LLVMPointerType::get(IntegerType::get(context, 8));
249   auto i1Type = IntegerType::get(context, 1);
250   Type indexType = getTypeConverter()->getIndexType();
251 
252   // Find the malloc and free, or declare them if necessary.
253   auto module = builder.getInsertionPoint()->getParentOfType<ModuleOp>();
254   LLVM::LLVMFuncOp freeFunc, mallocFunc;
255   if (toDynamic)
256     mallocFunc = LLVM::lookupOrCreateMallocFn(module, indexType);
257   if (!toDynamic)
258     freeFunc = LLVM::lookupOrCreateFreeFn(module);
259 
260   // Initialize shared constants.
261   Value zero =
262       builder.create<LLVM::ConstantOp>(loc, i1Type, builder.getBoolAttr(false));
263 
264   unsigned unrankedMemrefPos = 0;
265   for (unsigned i = 0, e = operands.size(); i < e; ++i) {
266     Type type = origTypes[i];
267     if (!type.isa<UnrankedMemRefType>())
268       continue;
269     Value allocationSize = sizes[unrankedMemrefPos++];
270     UnrankedMemRefDescriptor desc(operands[i]);
271 
272     // Allocate memory, copy, and free the source if necessary.
273     Value memory =
274         toDynamic
275             ? builder.create<LLVM::CallOp>(loc, mallocFunc, allocationSize)
276                   .getResult(0)
277             : builder.create<LLVM::AllocaOp>(loc, voidPtrType, allocationSize,
278                                              /*alignment=*/0);
279     Value source = desc.memRefDescPtr(builder, loc);
280     builder.create<LLVM::MemcpyOp>(loc, memory, source, allocationSize, zero);
281     if (!toDynamic)
282       builder.create<LLVM::CallOp>(loc, freeFunc, source);
283 
284     // Create a new descriptor. The same descriptor can be returned multiple
285     // times, attempting to modify its pointer can lead to memory leaks
286     // (allocated twice and overwritten) or double frees (the caller does not
287     // know if the descriptor points to the same memory).
288     Type descriptorType = getTypeConverter()->convertType(type);
289     if (!descriptorType)
290       return failure();
291     auto updatedDesc =
292         UnrankedMemRefDescriptor::undef(builder, loc, descriptorType);
293     Value rank = desc.rank(builder, loc);
294     updatedDesc.setRank(builder, loc, rank);
295     updatedDesc.setMemRefDescPtr(builder, loc, memory);
296 
297     operands[i] = updatedDesc;
298   }
299 
300   return success();
301 }
302 
303 //===----------------------------------------------------------------------===//
304 // Detail methods
305 //===----------------------------------------------------------------------===//
306 
307 /// Replaces the given operation "op" with a new operation of type "targetOp"
308 /// and given operands.
309 LogicalResult LLVM::detail::oneToOneRewrite(
310     Operation *op, StringRef targetOp, ValueRange operands,
311     LLVMTypeConverter &typeConverter, ConversionPatternRewriter &rewriter) {
312   unsigned numResults = op->getNumResults();
313 
314   Type packedType;
315   if (numResults != 0) {
316     packedType = typeConverter.packFunctionResults(op->getResultTypes());
317     if (!packedType)
318       return failure();
319   }
320 
321   // Create the operation through state since we don't know its C++ type.
322   OperationState state(op->getLoc(), targetOp);
323   state.addTypes(packedType);
324   state.addOperands(operands);
325   state.addAttributes(op->getAttrs());
326   Operation *newOp = rewriter.createOperation(state);
327 
328   // If the operation produced 0 or 1 result, return them immediately.
329   if (numResults == 0)
330     return rewriter.eraseOp(op), success();
331   if (numResults == 1)
332     return rewriter.replaceOp(op, newOp->getResult(0)), success();
333 
334   // Otherwise, it had been converted to an operation producing a structure.
335   // Extract individual results from the structure and return them as list.
336   SmallVector<Value, 4> results;
337   results.reserve(numResults);
338   for (unsigned i = 0; i < numResults; ++i) {
339     auto type = typeConverter.convertType(op->getResult(i).getType());
340     results.push_back(rewriter.create<LLVM::ExtractValueOp>(
341         op->getLoc(), type, newOp->getResult(0), rewriter.getI64ArrayAttr(i)));
342   }
343   rewriter.replaceOp(op, results);
344   return success();
345 }
346