1 //===- Pattern.cpp - Conversion pattern to the LLVM dialect ---------------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 
9 #include "mlir/Conversion/LLVMCommon/Pattern.h"
10 #include "mlir/Dialect/LLVMIR/LLVMDialect.h"
11 #include "mlir/Dialect/LLVMIR/LLVMTypes.h"
12 #include "mlir/IR/AffineMap.h"
13 
14 using namespace mlir;
15 
16 //===----------------------------------------------------------------------===//
17 // ConvertToLLVMPattern
18 //===----------------------------------------------------------------------===//
19 
20 ConvertToLLVMPattern::ConvertToLLVMPattern(StringRef rootOpName,
21                                            MLIRContext *context,
22                                            LLVMTypeConverter &typeConverter,
23                                            PatternBenefit benefit)
24     : ConversionPattern(typeConverter, rootOpName, benefit, context) {}
25 
26 LLVMTypeConverter *ConvertToLLVMPattern::getTypeConverter() const {
27   return static_cast<LLVMTypeConverter *>(
28       ConversionPattern::getTypeConverter());
29 }
30 
31 LLVM::LLVMDialect &ConvertToLLVMPattern::getDialect() const {
32   return *getTypeConverter()->getDialect();
33 }
34 
35 Type ConvertToLLVMPattern::getIndexType() const {
36   return getTypeConverter()->getIndexType();
37 }
38 
39 Type ConvertToLLVMPattern::getIntPtrType(unsigned addressSpace) const {
40   return IntegerType::get(&getTypeConverter()->getContext(),
41                           getTypeConverter()->getPointerBitwidth(addressSpace));
42 }
43 
44 Type ConvertToLLVMPattern::getVoidType() const {
45   return LLVM::LLVMVoidType::get(&getTypeConverter()->getContext());
46 }
47 
48 Type ConvertToLLVMPattern::getVoidPtrType() const {
49   return LLVM::LLVMPointerType::get(
50       IntegerType::get(&getTypeConverter()->getContext(), 8));
51 }
52 
53 Value ConvertToLLVMPattern::createIndexAttrConstant(OpBuilder &builder,
54                                                     Location loc,
55                                                     Type resultType,
56                                                     int64_t value) {
57   return builder.create<LLVM::ConstantOp>(
58       loc, resultType, builder.getIntegerAttr(builder.getIndexType(), value));
59 }
60 
61 Value ConvertToLLVMPattern::createIndexConstant(
62     ConversionPatternRewriter &builder, Location loc, uint64_t value) const {
63   return createIndexAttrConstant(builder, loc, getIndexType(), value);
64 }
65 
66 Value ConvertToLLVMPattern::getStridedElementPtr(
67     Location loc, MemRefType type, Value memRefDesc, ValueRange indices,
68     ConversionPatternRewriter &rewriter) const {
69 
70   int64_t offset;
71   SmallVector<int64_t, 4> strides;
72   auto successStrides = getStridesAndOffset(type, strides, offset);
73   assert(succeeded(successStrides) && "unexpected non-strided memref");
74   (void)successStrides;
75 
76   MemRefDescriptor memRefDescriptor(memRefDesc);
77   Value base = memRefDescriptor.alignedPtr(rewriter, loc);
78 
79   Value index;
80   if (offset != 0) // Skip if offset is zero.
81     index = MemRefType::isDynamicStrideOrOffset(offset)
82                 ? memRefDescriptor.offset(rewriter, loc)
83                 : createIndexConstant(rewriter, loc, offset);
84 
85   for (int i = 0, e = indices.size(); i < e; ++i) {
86     Value increment = indices[i];
87     if (strides[i] != 1) { // Skip if stride is 1.
88       Value stride = MemRefType::isDynamicStrideOrOffset(strides[i])
89                          ? memRefDescriptor.stride(rewriter, loc, i)
90                          : createIndexConstant(rewriter, loc, strides[i]);
91       increment = rewriter.create<LLVM::MulOp>(loc, increment, stride);
92     }
93     index =
94         index ? rewriter.create<LLVM::AddOp>(loc, index, increment) : increment;
95   }
96 
97   Type elementPtrType = memRefDescriptor.getElementPtrType();
98   return index ? rewriter.create<LLVM::GEPOp>(loc, elementPtrType, base, index)
99                : base;
100 }
101 
102 // Check if the MemRefType `type` is supported by the lowering. We currently
103 // only support memrefs with identity maps.
104 bool ConvertToLLVMPattern::isConvertibleAndHasIdentityMaps(
105     MemRefType type) const {
106   if (!typeConverter->convertType(type.getElementType()))
107     return false;
108   return type.getAffineMaps().empty() ||
109          llvm::all_of(type.getAffineMaps(),
110                       [](AffineMap map) { return map.isIdentity(); });
111 }
112 
113 Type ConvertToLLVMPattern::getElementPtrType(MemRefType type) const {
114   auto elementType = type.getElementType();
115   auto structElementType = typeConverter->convertType(elementType);
116   return LLVM::LLVMPointerType::get(structElementType,
117                                     type.getMemorySpaceAsInt());
118 }
119 
/// Computes, as LLVM dialect values emitted at `loc`, the per-dimension sizes
/// and strides of a memref of type `memRefType`, plus the byte size of its
/// buffer. The memref is expected to have an identity layout.
///
/// - `dynamicSizes` supplies one value per dynamic (`?`) dimension of the
///   shape, in shape order; static dimensions become index constants.
/// - `sizes` receives one value per dimension (appended via `push_back`).
///   NOTE(review): `sizes` is later indexed by dimension (`sizes[i]`), so this
///   assumes it is empty on entry — confirm with callers.
/// - `strides` is resized to the rank and filled with row-major strides: the
///   innermost stride is 1 and each outer stride is the product of the inner
///   sizes.
/// - `sizeBytes` receives the buffer size in bytes, computed with the
///   `getelementptr null, N` + `ptrtoint` idiom. N is the final running stride
///   (1 for a rank-0 memref).
void ConvertToLLVMPattern::getMemRefDescriptorSizes(
    Location loc, MemRefType memRefType, ValueRange dynamicSizes,
    ConversionPatternRewriter &rewriter, SmallVectorImpl<Value> &sizes,
    SmallVectorImpl<Value> &strides, Value &sizeBytes) const {
  assert(isConvertibleAndHasIdentityMaps(memRefType) &&
         "layout maps must have been normalized away");
  assert(count(memRefType.getShape(), ShapedType::kDynamicSize) ==
             static_cast<ssize_t>(dynamicSizes.size()) &&
         "dynamicSizes size doesn't match dynamic sizes count in memref shape");

  // Sizes: dynamic dimensions consume `dynamicSizes` in order; static ones
  // are materialized as index constants.
  sizes.reserve(memRefType.getRank());
  unsigned dynamicIndex = 0;
  for (int64_t size : memRefType.getShape()) {
    sizes.push_back(size == ShapedType::kDynamicSize
                        ? dynamicSizes[dynamicIndex++]
                        : createIndexConstant(rewriter, loc, size));
  }

  // Strides: iterate sizes in reverse order and multiply.
  // `stride` tracks the static product of the sizes visited so far. Once a
  // dynamic size is seen it becomes kDynamicSize and stays so, because the
  // multiplication below is then skipped. `runningStride` is the
  // corresponding IR value.
  int64_t stride = 1;
  Value runningStride = createIndexConstant(rewriter, loc, 1);
  strides.resize(memRefType.getRank());
  for (auto i = memRefType.getRank(); i-- > 0;) {
    strides[i] = runningStride;

    int64_t size = memRefType.getShape()[i];
    // Zero-sized dimensions do not contribute to the running product.
    // NOTE(review): as a consequence, `sizeBytes` is nonzero for shapes
    // containing a 0 dimension (e.g. 4x0 yields 4 elements) — confirm this
    // over-allocation is intended.
    if (size == 0)
      continue;
    // While the product so far is statically 1, the next running stride is
    // exactly this dimension's size value, so reuse it instead of emitting a
    // multiply or a new constant.
    bool useSizeAsStride = stride == 1;
    if (size == ShapedType::kDynamicSize)
      stride = ShapedType::kDynamicSize;
    if (stride != ShapedType::kDynamicSize)
      stride *= size;

    if (useSizeAsStride)
      runningStride = sizes[i];
    else if (stride == ShapedType::kDynamicSize)
      runningStride =
          rewriter.create<LLVM::MulOp>(loc, runningStride, sizes[i]);
    else
      runningStride = createIndexConstant(rewriter, loc, stride);
  }

  // Buffer size in bytes.
  // The address of element #runningStride past a null element pointer,
  // converted to an integer, is runningStride * sizeof(element).
  Type elementPtrType = getElementPtrType(memRefType);
  Value nullPtr = rewriter.create<LLVM::NullOp>(loc, elementPtrType);
  Value gepPtr = rewriter.create<LLVM::GEPOp>(
      loc, elementPtrType, ArrayRef<Value>{nullPtr, runningStride});
  sizeBytes = rewriter.create<LLVM::PtrToIntOp>(loc, getIndexType(), gepPtr);
}
170 
171 Value ConvertToLLVMPattern::getSizeInBytes(
172     Location loc, Type type, ConversionPatternRewriter &rewriter) const {
173   // Compute the size of an individual element. This emits the MLIR equivalent
174   // of the following sizeof(...) implementation in LLVM IR:
175   //   %0 = getelementptr %elementType* null, %indexType 1
176   //   %1 = ptrtoint %elementType* %0 to %indexType
177   // which is a common pattern of getting the size of a type in bytes.
178   auto convertedPtrType =
179       LLVM::LLVMPointerType::get(typeConverter->convertType(type));
180   auto nullPtr = rewriter.create<LLVM::NullOp>(loc, convertedPtrType);
181   auto gep = rewriter.create<LLVM::GEPOp>(
182       loc, convertedPtrType,
183       ArrayRef<Value>{nullPtr, createIndexConstant(rewriter, loc, 1)});
184   return rewriter.create<LLVM::PtrToIntOp>(loc, getIndexType(), gep);
185 }
186 
187 Value ConvertToLLVMPattern::getNumElements(
188     Location loc, ArrayRef<Value> shape,
189     ConversionPatternRewriter &rewriter) const {
190   // Compute the total number of memref elements.
191   Value numElements =
192       shape.empty() ? createIndexConstant(rewriter, loc, 1) : shape.front();
193   for (unsigned i = 1, e = shape.size(); i < e; ++i)
194     numElements = rewriter.create<LLVM::MulOp>(loc, numElements, shape[i]);
195   return numElements;
196 }
197 
198 /// Creates and populates the memref descriptor struct given all its fields.
199 MemRefDescriptor ConvertToLLVMPattern::createMemRefDescriptor(
200     Location loc, MemRefType memRefType, Value allocatedPtr, Value alignedPtr,
201     ArrayRef<Value> sizes, ArrayRef<Value> strides,
202     ConversionPatternRewriter &rewriter) const {
203   auto structType = typeConverter->convertType(memRefType);
204   auto memRefDescriptor = MemRefDescriptor::undef(rewriter, loc, structType);
205 
206   // Field 1: Allocated pointer, used for malloc/free.
207   memRefDescriptor.setAllocatedPtr(rewriter, loc, allocatedPtr);
208 
209   // Field 2: Actual aligned pointer to payload.
210   memRefDescriptor.setAlignedPtr(rewriter, loc, alignedPtr);
211 
212   // Field 3: Offset in aligned pointer.
213   memRefDescriptor.setOffset(rewriter, loc,
214                              createIndexConstant(rewriter, loc, 0));
215 
216   // Fields 4: Sizes.
217   for (auto en : llvm::enumerate(sizes))
218     memRefDescriptor.setSize(rewriter, loc, en.index(), en.value());
219 
220   // Field 5: Strides.
221   for (auto en : llvm::enumerate(strides))
222     memRefDescriptor.setStride(rewriter, loc, en.index(), en.value());
223 
224   return memRefDescriptor;
225 }
226 
227 //===----------------------------------------------------------------------===//
228 // Detail methods
229 //===----------------------------------------------------------------------===//
230 
231 /// Replaces the given operation "op" with a new operation of type "targetOp"
232 /// and given operands.
233 LogicalResult LLVM::detail::oneToOneRewrite(
234     Operation *op, StringRef targetOp, ValueRange operands,
235     LLVMTypeConverter &typeConverter, ConversionPatternRewriter &rewriter) {
236   unsigned numResults = op->getNumResults();
237 
238   Type packedType;
239   if (numResults != 0) {
240     packedType = typeConverter.packFunctionResults(op->getResultTypes());
241     if (!packedType)
242       return failure();
243   }
244 
245   // Create the operation through state since we don't know its C++ type.
246   OperationState state(op->getLoc(), targetOp);
247   state.addTypes(packedType);
248   state.addOperands(operands);
249   state.addAttributes(op->getAttrs());
250   Operation *newOp = rewriter.createOperation(state);
251 
252   // If the operation produced 0 or 1 result, return them immediately.
253   if (numResults == 0)
254     return rewriter.eraseOp(op), success();
255   if (numResults == 1)
256     return rewriter.replaceOp(op, newOp->getResult(0)), success();
257 
258   // Otherwise, it had been converted to an operation producing a structure.
259   // Extract individual results from the structure and return them as list.
260   SmallVector<Value, 4> results;
261   results.reserve(numResults);
262   for (unsigned i = 0; i < numResults; ++i) {
263     auto type = typeConverter.convertType(op->getResult(i).getType());
264     results.push_back(rewriter.create<LLVM::ExtractValueOp>(
265         op->getLoc(), type, newOp->getResult(0), rewriter.getI64ArrayAttr(i)));
266   }
267   rewriter.replaceOp(op, results);
268   return success();
269 }
270