1 //===- Pattern.cpp - Conversion pattern to the LLVM dialect ---------------===// 2 // 3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. 4 // See https://llvm.org/LICENSE.txt for license information. 5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception 6 // 7 //===----------------------------------------------------------------------===// 8 9 #include "mlir/Conversion/LLVMCommon/Pattern.h" 10 #include "mlir/Dialect/LLVMIR/FunctionCallUtils.h" 11 #include "mlir/Dialect/LLVMIR/LLVMDialect.h" 12 #include "mlir/Dialect/LLVMIR/LLVMTypes.h" 13 #include "mlir/IR/AffineMap.h" 14 #include "mlir/IR/BuiltinAttributes.h" 15 16 using namespace mlir; 17 18 //===----------------------------------------------------------------------===// 19 // ConvertToLLVMPattern 20 //===----------------------------------------------------------------------===// 21 22 ConvertToLLVMPattern::ConvertToLLVMPattern(StringRef rootOpName, 23 MLIRContext *context, 24 LLVMTypeConverter &typeConverter, 25 PatternBenefit benefit) 26 : ConversionPattern(typeConverter, rootOpName, benefit, context) {} 27 28 LLVMTypeConverter *ConvertToLLVMPattern::getTypeConverter() const { 29 return static_cast<LLVMTypeConverter *>( 30 ConversionPattern::getTypeConverter()); 31 } 32 33 LLVM::LLVMDialect &ConvertToLLVMPattern::getDialect() const { 34 return *getTypeConverter()->getDialect(); 35 } 36 37 Type ConvertToLLVMPattern::getIndexType() const { 38 return getTypeConverter()->getIndexType(); 39 } 40 41 Type ConvertToLLVMPattern::getIntPtrType(unsigned addressSpace) const { 42 return IntegerType::get(&getTypeConverter()->getContext(), 43 getTypeConverter()->getPointerBitwidth(addressSpace)); 44 } 45 46 Type ConvertToLLVMPattern::getVoidType() const { 47 return LLVM::LLVMVoidType::get(&getTypeConverter()->getContext()); 48 } 49 50 Type ConvertToLLVMPattern::getVoidPtrType() const { 51 return LLVM::LLVMPointerType::get( 52 IntegerType::get(&getTypeConverter()->getContext(), 8)); 53 } 54 55 Value ConvertToLLVMPattern::createIndexAttrConstant(OpBuilder &builder, 56 Location loc, 57 Type resultType, 58 int64_t value) { 59 return builder.create<LLVM::ConstantOp>( 60 loc, resultType, builder.getIntegerAttr(builder.getIndexType(), value)); 61 } 62 63 Value ConvertToLLVMPattern::createIndexConstant( 64 ConversionPatternRewriter &builder, Location loc, uint64_t value) const { 65 return createIndexAttrConstant(builder, loc, getIndexType(), value); 66 } 67 68 Value ConvertToLLVMPattern::getStridedElementPtr( 69 Location loc, MemRefType type, Value memRefDesc, ValueRange indices, 70 ConversionPatternRewriter &rewriter) const { 71 72 int64_t offset; 73 SmallVector<int64_t, 4> strides; 74 auto successStrides = getStridesAndOffset(type, strides, offset); 75 assert(succeeded(successStrides) && "unexpected non-strided memref"); 76 (void)successStrides; 77 78 MemRefDescriptor memRefDescriptor(memRefDesc); 79 Value base = memRefDescriptor.alignedPtr(rewriter, loc); 80 81 Value index; 82 if (offset != 0) // Skip if offset is zero. 83 index = ShapedType::isDynamicStrideOrOffset(offset) 84 ? memRefDescriptor.offset(rewriter, loc) 85 : createIndexConstant(rewriter, loc, offset); 86 87 for (int i = 0, e = indices.size(); i < e; ++i) { 88 Value increment = indices[i]; 89 if (strides[i] != 1) { // Skip if stride is 1. 90 Value stride = ShapedType::isDynamicStrideOrOffset(strides[i]) 91 ? memRefDescriptor.stride(rewriter, loc, i) 92 : createIndexConstant(rewriter, loc, strides[i]); 93 increment = rewriter.create<LLVM::MulOp>(loc, increment, stride); 94 } 95 index = 96 index ? rewriter.create<LLVM::AddOp>(loc, index, increment) : increment; 97 } 98 99 Type elementPtrType = memRefDescriptor.getElementPtrType(); 100 return index ? rewriter.create<LLVM::GEPOp>(loc, elementPtrType, base, index) 101 : base; 102 } 103 104 // Check if the MemRefType `type` is supported by the lowering. We currently 105 // only support memrefs with identity maps. 106 bool ConvertToLLVMPattern::isConvertibleAndHasIdentityMaps( 107 MemRefType type) const { 108 if (!typeConverter->convertType(type.getElementType())) 109 return false; 110 return type.getLayout().isIdentity(); 111 } 112 113 Type ConvertToLLVMPattern::getElementPtrType(MemRefType type) const { 114 auto elementType = type.getElementType(); 115 auto structElementType = typeConverter->convertType(elementType); 116 return LLVM::LLVMPointerType::get(structElementType, 117 type.getMemorySpaceAsInt()); 118 } 119 120 void ConvertToLLVMPattern::getMemRefDescriptorSizes( 121 Location loc, MemRefType memRefType, ValueRange dynamicSizes, 122 ConversionPatternRewriter &rewriter, SmallVectorImpl<Value> &sizes, 123 SmallVectorImpl<Value> &strides, Value &sizeBytes) const { 124 assert(isConvertibleAndHasIdentityMaps(memRefType) && 125 "layout maps must have been normalized away"); 126 assert(count(memRefType.getShape(), ShapedType::kDynamicSize) == 127 static_cast<ssize_t>(dynamicSizes.size()) && 128 "dynamicSizes size doesn't match dynamic sizes count in memref shape"); 129 130 sizes.reserve(memRefType.getRank()); 131 unsigned dynamicIndex = 0; 132 for (int64_t size : memRefType.getShape()) { 133 sizes.push_back(size == ShapedType::kDynamicSize 134 ? dynamicSizes[dynamicIndex++] 135 : createIndexConstant(rewriter, loc, size)); 136 } 137 138 // Strides: iterate sizes in reverse order and multiply. 139 int64_t stride = 1; 140 Value runningStride = createIndexConstant(rewriter, loc, 1); 141 strides.resize(memRefType.getRank()); 142 for (auto i = memRefType.getRank(); i-- > 0;) { 143 strides[i] = runningStride; 144 145 int64_t size = memRefType.getShape()[i]; 146 if (size == 0) 147 continue; 148 bool useSizeAsStride = stride == 1; 149 if (size == ShapedType::kDynamicSize) 150 stride = ShapedType::kDynamicSize; 151 if (stride != ShapedType::kDynamicSize) 152 stride *= size; 153 154 if (useSizeAsStride) 155 runningStride = sizes[i]; 156 else if (stride == ShapedType::kDynamicSize) 157 runningStride = 158 rewriter.create<LLVM::MulOp>(loc, runningStride, sizes[i]); 159 else 160 runningStride = createIndexConstant(rewriter, loc, stride); 161 } 162 163 // Buffer size in bytes. 164 Type elementPtrType = getElementPtrType(memRefType); 165 Value nullPtr = rewriter.create<LLVM::NullOp>(loc, elementPtrType); 166 Value gepPtr = rewriter.create<LLVM::GEPOp>(loc, elementPtrType, nullPtr, 167 ArrayRef<Value>{runningStride}); 168 sizeBytes = rewriter.create<LLVM::PtrToIntOp>(loc, getIndexType(), gepPtr); 169 } 170 171 Value ConvertToLLVMPattern::getSizeInBytes( 172 Location loc, Type type, ConversionPatternRewriter &rewriter) const { 173 // Compute the size of an individual element. This emits the MLIR equivalent 174 // of the following sizeof(...) implementation in LLVM IR: 175 // %0 = getelementptr %elementType* null, %indexType 1 176 // %1 = ptrtoint %elementType* %0 to %indexType 177 // which is a common pattern of getting the size of a type in bytes. 178 auto convertedPtrType = 179 LLVM::LLVMPointerType::get(typeConverter->convertType(type)); 180 auto nullPtr = rewriter.create<LLVM::NullOp>(loc, convertedPtrType); 181 auto gep = rewriter.create<LLVM::GEPOp>( 182 loc, convertedPtrType, nullPtr, 183 ArrayRef<Value>{createIndexConstant(rewriter, loc, 1)}); 184 return rewriter.create<LLVM::PtrToIntOp>(loc, getIndexType(), gep); 185 } 186 187 Value ConvertToLLVMPattern::getNumElements( 188 Location loc, ArrayRef<Value> shape, 189 ConversionPatternRewriter &rewriter) const { 190 // Compute the total number of memref elements. 191 Value numElements = 192 shape.empty() ? createIndexConstant(rewriter, loc, 1) : shape.front(); 193 for (unsigned i = 1, e = shape.size(); i < e; ++i) 194 numElements = rewriter.create<LLVM::MulOp>(loc, numElements, shape[i]); 195 return numElements; 196 } 197 198 /// Creates and populates the memref descriptor struct given all its fields. 199 MemRefDescriptor ConvertToLLVMPattern::createMemRefDescriptor( 200 Location loc, MemRefType memRefType, Value allocatedPtr, Value alignedPtr, 201 ArrayRef<Value> sizes, ArrayRef<Value> strides, 202 ConversionPatternRewriter &rewriter) const { 203 auto structType = typeConverter->convertType(memRefType); 204 auto memRefDescriptor = MemRefDescriptor::undef(rewriter, loc, structType); 205 206 // Field 1: Allocated pointer, used for malloc/free. 207 memRefDescriptor.setAllocatedPtr(rewriter, loc, allocatedPtr); 208 209 // Field 2: Actual aligned pointer to payload. 210 memRefDescriptor.setAlignedPtr(rewriter, loc, alignedPtr); 211 212 // Field 3: Offset in aligned pointer. 213 memRefDescriptor.setOffset(rewriter, loc, 214 createIndexConstant(rewriter, loc, 0)); 215 216 // Fields 4: Sizes. 217 for (const auto &en : llvm::enumerate(sizes)) 218 memRefDescriptor.setSize(rewriter, loc, en.index(), en.value()); 219 220 // Field 5: Strides. 221 for (const auto &en : llvm::enumerate(strides)) 222 memRefDescriptor.setStride(rewriter, loc, en.index(), en.value()); 223 224 return memRefDescriptor; 225 } 226 227 LogicalResult ConvertToLLVMPattern::copyUnrankedDescriptors( 228 OpBuilder &builder, Location loc, TypeRange origTypes, 229 SmallVectorImpl<Value> &operands, bool toDynamic) const { 230 assert(origTypes.size() == operands.size() && 231 "expected as may original types as operands"); 232 233 // Find operands of unranked memref type and store them. 234 SmallVector<UnrankedMemRefDescriptor, 4> unrankedMemrefs; 235 for (unsigned i = 0, e = operands.size(); i < e; ++i) 236 if (origTypes[i].isa<UnrankedMemRefType>()) 237 unrankedMemrefs.emplace_back(operands[i]); 238 239 if (unrankedMemrefs.empty()) 240 return success(); 241 242 // Compute allocation sizes. 243 SmallVector<Value, 4> sizes; 244 UnrankedMemRefDescriptor::computeSizes(builder, loc, *getTypeConverter(), 245 unrankedMemrefs, sizes); 246 247 // Get frequently used types. 248 MLIRContext *context = builder.getContext(); 249 Type voidPtrType = LLVM::LLVMPointerType::get(IntegerType::get(context, 8)); 250 auto i1Type = IntegerType::get(context, 1); 251 Type indexType = getTypeConverter()->getIndexType(); 252 253 // Find the malloc and free, or declare them if necessary. 254 auto module = builder.getInsertionPoint()->getParentOfType<ModuleOp>(); 255 LLVM::LLVMFuncOp freeFunc, mallocFunc; 256 if (toDynamic) 257 mallocFunc = LLVM::lookupOrCreateMallocFn(module, indexType); 258 if (!toDynamic) 259 freeFunc = LLVM::lookupOrCreateFreeFn(module); 260 261 // Initialize shared constants. 262 Value zero = 263 builder.create<LLVM::ConstantOp>(loc, i1Type, builder.getBoolAttr(false)); 264 265 unsigned unrankedMemrefPos = 0; 266 for (unsigned i = 0, e = operands.size(); i < e; ++i) { 267 Type type = origTypes[i]; 268 if (!type.isa<UnrankedMemRefType>()) 269 continue; 270 Value allocationSize = sizes[unrankedMemrefPos++]; 271 UnrankedMemRefDescriptor desc(operands[i]); 272 273 // Allocate memory, copy, and free the source if necessary. 274 Value memory = 275 toDynamic 276 ? builder.create<LLVM::CallOp>(loc, mallocFunc, allocationSize) 277 .getResult(0) 278 : builder.create<LLVM::AllocaOp>(loc, voidPtrType, allocationSize, 279 /*alignment=*/0); 280 Value source = desc.memRefDescPtr(builder, loc); 281 builder.create<LLVM::MemcpyOp>(loc, memory, source, allocationSize, zero); 282 if (!toDynamic) 283 builder.create<LLVM::CallOp>(loc, freeFunc, source); 284 285 // Create a new descriptor. The same descriptor can be returned multiple 286 // times, attempting to modify its pointer can lead to memory leaks 287 // (allocated twice and overwritten) or double frees (the caller does not 288 // know if the descriptor points to the same memory). 289 Type descriptorType = getTypeConverter()->convertType(type); 290 if (!descriptorType) 291 return failure(); 292 auto updatedDesc = 293 UnrankedMemRefDescriptor::undef(builder, loc, descriptorType); 294 Value rank = desc.rank(builder, loc); 295 updatedDesc.setRank(builder, loc, rank); 296 updatedDesc.setMemRefDescPtr(builder, loc, memory); 297 298 operands[i] = updatedDesc; 299 } 300 301 return success(); 302 } 303 304 //===----------------------------------------------------------------------===// 305 // Detail methods 306 //===----------------------------------------------------------------------===// 307 308 /// Replaces the given operation "op" with a new operation of type "targetOp" 309 /// and given operands. 310 LogicalResult LLVM::detail::oneToOneRewrite( 311 Operation *op, StringRef targetOp, ValueRange operands, 312 LLVMTypeConverter &typeConverter, ConversionPatternRewriter &rewriter) { 313 unsigned numResults = op->getNumResults(); 314 315 Type packedType; 316 if (numResults != 0) { 317 packedType = typeConverter.packFunctionResults(op->getResultTypes()); 318 if (!packedType) 319 return failure(); 320 } 321 322 // Create the operation through state since we don't know its C++ type. 323 Operation *newOp = 324 rewriter.create(op->getLoc(), rewriter.getStringAttr(targetOp), operands, 325 packedType, op->getAttrs()); 326 327 // If the operation produced 0 or 1 result, return them immediately. 328 if (numResults == 0) 329 return rewriter.eraseOp(op), success(); 330 if (numResults == 1) 331 return rewriter.replaceOp(op, newOp->getResult(0)), success(); 332 333 // Otherwise, it had been converted to an operation producing a structure. 334 // Extract individual results from the structure and return them as list. 335 SmallVector<Value, 4> results; 336 results.reserve(numResults); 337 for (unsigned i = 0; i < numResults; ++i) { 338 auto type = typeConverter.convertType(op->getResult(i).getType()); 339 results.push_back(rewriter.create<LLVM::ExtractValueOp>( 340 op->getLoc(), type, newOp->getResult(0), rewriter.getI64ArrayAttr(i))); 341 } 342 rewriter.replaceOp(op, results); 343 return success(); 344 } 345