1 //===- Pattern.cpp - Conversion pattern to the LLVM dialect ---------------===// 2 // 3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. 4 // See https://llvm.org/LICENSE.txt for license information. 5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception 6 // 7 //===----------------------------------------------------------------------===// 8 9 #include "mlir/Conversion/LLVMCommon/Pattern.h" 10 #include "mlir/Dialect/LLVMIR/LLVMDialect.h" 11 #include "mlir/Dialect/LLVMIR/LLVMTypes.h" 12 #include "mlir/IR/AffineMap.h" 13 14 using namespace mlir; 15 16 //===----------------------------------------------------------------------===// 17 // ConvertToLLVMPattern 18 //===----------------------------------------------------------------------===// 19 20 ConvertToLLVMPattern::ConvertToLLVMPattern(StringRef rootOpName, 21 MLIRContext *context, 22 LLVMTypeConverter &typeConverter, 23 PatternBenefit benefit) 24 : ConversionPattern(typeConverter, rootOpName, benefit, context) {} 25 26 LLVMTypeConverter *ConvertToLLVMPattern::getTypeConverter() const { 27 return static_cast<LLVMTypeConverter *>( 28 ConversionPattern::getTypeConverter()); 29 } 30 31 LLVM::LLVMDialect &ConvertToLLVMPattern::getDialect() const { 32 return *getTypeConverter()->getDialect(); 33 } 34 35 Type ConvertToLLVMPattern::getIndexType() const { 36 return getTypeConverter()->getIndexType(); 37 } 38 39 Type ConvertToLLVMPattern::getIntPtrType(unsigned addressSpace) const { 40 return IntegerType::get(&getTypeConverter()->getContext(), 41 getTypeConverter()->getPointerBitwidth(addressSpace)); 42 } 43 44 Type ConvertToLLVMPattern::getVoidType() const { 45 return LLVM::LLVMVoidType::get(&getTypeConverter()->getContext()); 46 } 47 48 Type ConvertToLLVMPattern::getVoidPtrType() const { 49 return LLVM::LLVMPointerType::get( 50 IntegerType::get(&getTypeConverter()->getContext(), 8)); 51 } 52 53 Value ConvertToLLVMPattern::createIndexAttrConstant(OpBuilder &builder, 54 Location loc, 55 Type resultType, 56 int64_t value) { 57 return builder.create<LLVM::ConstantOp>( 58 loc, resultType, builder.getIntegerAttr(builder.getIndexType(), value)); 59 } 60 61 Value ConvertToLLVMPattern::createIndexConstant( 62 ConversionPatternRewriter &builder, Location loc, uint64_t value) const { 63 return createIndexAttrConstant(builder, loc, getIndexType(), value); 64 } 65 66 Value ConvertToLLVMPattern::getStridedElementPtr( 67 Location loc, MemRefType type, Value memRefDesc, ValueRange indices, 68 ConversionPatternRewriter &rewriter) const { 69 70 int64_t offset; 71 SmallVector<int64_t, 4> strides; 72 auto successStrides = getStridesAndOffset(type, strides, offset); 73 assert(succeeded(successStrides) && "unexpected non-strided memref"); 74 (void)successStrides; 75 76 MemRefDescriptor memRefDescriptor(memRefDesc); 77 Value base = memRefDescriptor.alignedPtr(rewriter, loc); 78 79 Value index; 80 if (offset != 0) // Skip if offset is zero. 81 index = MemRefType::isDynamicStrideOrOffset(offset) 82 ? memRefDescriptor.offset(rewriter, loc) 83 : createIndexConstant(rewriter, loc, offset); 84 85 for (int i = 0, e = indices.size(); i < e; ++i) { 86 Value increment = indices[i]; 87 if (strides[i] != 1) { // Skip if stride is 1. 88 Value stride = MemRefType::isDynamicStrideOrOffset(strides[i]) 89 ? memRefDescriptor.stride(rewriter, loc, i) 90 : createIndexConstant(rewriter, loc, strides[i]); 91 increment = rewriter.create<LLVM::MulOp>(loc, increment, stride); 92 } 93 index = 94 index ? rewriter.create<LLVM::AddOp>(loc, index, increment) : increment; 95 } 96 97 Type elementPtrType = memRefDescriptor.getElementPtrType(); 98 return index ? rewriter.create<LLVM::GEPOp>(loc, elementPtrType, base, index) 99 : base; 100 } 101 102 // Check if the MemRefType `type` is supported by the lowering. We currently 103 // only support memrefs with identity maps. 104 bool ConvertToLLVMPattern::isConvertibleAndHasIdentityMaps( 105 MemRefType type) const { 106 if (!typeConverter->convertType(type.getElementType())) 107 return false; 108 return type.getAffineMaps().empty() || 109 llvm::all_of(type.getAffineMaps(), 110 [](AffineMap map) { return map.isIdentity(); }); 111 } 112 113 Type ConvertToLLVMPattern::getElementPtrType(MemRefType type) const { 114 auto elementType = type.getElementType(); 115 auto structElementType = typeConverter->convertType(elementType); 116 return LLVM::LLVMPointerType::get(structElementType, 117 type.getMemorySpaceAsInt()); 118 } 119 120 void ConvertToLLVMPattern::getMemRefDescriptorSizes( 121 Location loc, MemRefType memRefType, ValueRange dynamicSizes, 122 ConversionPatternRewriter &rewriter, SmallVectorImpl<Value> &sizes, 123 SmallVectorImpl<Value> &strides, Value &sizeBytes) const { 124 assert(isConvertibleAndHasIdentityMaps(memRefType) && 125 "layout maps must have been normalized away"); 126 assert(count(memRefType.getShape(), ShapedType::kDynamicSize) == 127 static_cast<ssize_t>(dynamicSizes.size()) && 128 "dynamicSizes size doesn't match dynamic sizes count in memref shape"); 129 130 sizes.reserve(memRefType.getRank()); 131 unsigned dynamicIndex = 0; 132 for (int64_t size : memRefType.getShape()) { 133 sizes.push_back(size == ShapedType::kDynamicSize 134 ? dynamicSizes[dynamicIndex++] 135 : createIndexConstant(rewriter, loc, size)); 136 } 137 138 // Strides: iterate sizes in reverse order and multiply. 139 int64_t stride = 1; 140 Value runningStride = createIndexConstant(rewriter, loc, 1); 141 strides.resize(memRefType.getRank()); 142 for (auto i = memRefType.getRank(); i-- > 0;) { 143 strides[i] = runningStride; 144 145 int64_t size = memRefType.getShape()[i]; 146 if (size == 0) 147 continue; 148 bool useSizeAsStride = stride == 1; 149 if (size == ShapedType::kDynamicSize) 150 stride = ShapedType::kDynamicSize; 151 if (stride != ShapedType::kDynamicSize) 152 stride *= size; 153 154 if (useSizeAsStride) 155 runningStride = sizes[i]; 156 else if (stride == ShapedType::kDynamicSize) 157 runningStride = 158 rewriter.create<LLVM::MulOp>(loc, runningStride, sizes[i]); 159 else 160 runningStride = createIndexConstant(rewriter, loc, stride); 161 } 162 163 // Buffer size in bytes. 164 Type elementPtrType = getElementPtrType(memRefType); 165 Value nullPtr = rewriter.create<LLVM::NullOp>(loc, elementPtrType); 166 Value gepPtr = rewriter.create<LLVM::GEPOp>( 167 loc, elementPtrType, ArrayRef<Value>{nullPtr, runningStride}); 168 sizeBytes = rewriter.create<LLVM::PtrToIntOp>(loc, getIndexType(), gepPtr); 169 } 170 171 Value ConvertToLLVMPattern::getSizeInBytes( 172 Location loc, Type type, ConversionPatternRewriter &rewriter) const { 173 // Compute the size of an individual element. This emits the MLIR equivalent 174 // of the following sizeof(...) implementation in LLVM IR: 175 // %0 = getelementptr %elementType* null, %indexType 1 176 // %1 = ptrtoint %elementType* %0 to %indexType 177 // which is a common pattern of getting the size of a type in bytes. 178 auto convertedPtrType = 179 LLVM::LLVMPointerType::get(typeConverter->convertType(type)); 180 auto nullPtr = rewriter.create<LLVM::NullOp>(loc, convertedPtrType); 181 auto gep = rewriter.create<LLVM::GEPOp>( 182 loc, convertedPtrType, 183 ArrayRef<Value>{nullPtr, createIndexConstant(rewriter, loc, 1)}); 184 return rewriter.create<LLVM::PtrToIntOp>(loc, getIndexType(), gep); 185 } 186 187 Value ConvertToLLVMPattern::getNumElements( 188 Location loc, ArrayRef<Value> shape, 189 ConversionPatternRewriter &rewriter) const { 190 // Compute the total number of memref elements. 191 Value numElements = 192 shape.empty() ? createIndexConstant(rewriter, loc, 1) : shape.front(); 193 for (unsigned i = 1, e = shape.size(); i < e; ++i) 194 numElements = rewriter.create<LLVM::MulOp>(loc, numElements, shape[i]); 195 return numElements; 196 } 197 198 /// Creates and populates the memref descriptor struct given all its fields. 199 MemRefDescriptor ConvertToLLVMPattern::createMemRefDescriptor( 200 Location loc, MemRefType memRefType, Value allocatedPtr, Value alignedPtr, 201 ArrayRef<Value> sizes, ArrayRef<Value> strides, 202 ConversionPatternRewriter &rewriter) const { 203 auto structType = typeConverter->convertType(memRefType); 204 auto memRefDescriptor = MemRefDescriptor::undef(rewriter, loc, structType); 205 206 // Field 1: Allocated pointer, used for malloc/free. 207 memRefDescriptor.setAllocatedPtr(rewriter, loc, allocatedPtr); 208 209 // Field 2: Actual aligned pointer to payload. 210 memRefDescriptor.setAlignedPtr(rewriter, loc, alignedPtr); 211 212 // Field 3: Offset in aligned pointer. 213 memRefDescriptor.setOffset(rewriter, loc, 214 createIndexConstant(rewriter, loc, 0)); 215 216 // Fields 4: Sizes. 217 for (auto en : llvm::enumerate(sizes)) 218 memRefDescriptor.setSize(rewriter, loc, en.index(), en.value()); 219 220 // Field 5: Strides. 221 for (auto en : llvm::enumerate(strides)) 222 memRefDescriptor.setStride(rewriter, loc, en.index(), en.value()); 223 224 return memRefDescriptor; 225 } 226 227 //===----------------------------------------------------------------------===// 228 // Detail methods 229 //===----------------------------------------------------------------------===// 230 231 /// Replaces the given operation "op" with a new operation of type "targetOp" 232 /// and given operands. 233 LogicalResult LLVM::detail::oneToOneRewrite( 234 Operation *op, StringRef targetOp, ValueRange operands, 235 LLVMTypeConverter &typeConverter, ConversionPatternRewriter &rewriter) { 236 unsigned numResults = op->getNumResults(); 237 238 Type packedType; 239 if (numResults != 0) { 240 packedType = typeConverter.packFunctionResults(op->getResultTypes()); 241 if (!packedType) 242 return failure(); 243 } 244 245 // Create the operation through state since we don't know its C++ type. 246 OperationState state(op->getLoc(), targetOp); 247 state.addTypes(packedType); 248 state.addOperands(operands); 249 state.addAttributes(op->getAttrs()); 250 Operation *newOp = rewriter.createOperation(state); 251 252 // If the operation produced 0 or 1 result, return them immediately. 253 if (numResults == 0) 254 return rewriter.eraseOp(op), success(); 255 if (numResults == 1) 256 return rewriter.replaceOp(op, newOp->getResult(0)), success(); 257 258 // Otherwise, it had been converted to an operation producing a structure. 259 // Extract individual results from the structure and return them as list. 260 SmallVector<Value, 4> results; 261 results.reserve(numResults); 262 for (unsigned i = 0; i < numResults; ++i) { 263 auto type = typeConverter.convertType(op->getResult(i).getType()); 264 results.push_back(rewriter.create<LLVM::ExtractValueOp>( 265 op->getLoc(), type, newOp->getResult(0), rewriter.getI64ArrayAttr(i))); 266 } 267 rewriter.replaceOp(op, results); 268 return success(); 269 } 270