1 //===- Pattern.cpp - Conversion pattern to the LLVM dialect ---------------===// 2 // 3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. 4 // See https://llvm.org/LICENSE.txt for license information. 5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception 6 // 7 //===----------------------------------------------------------------------===// 8 9 #include "mlir/Conversion/LLVMCommon/Pattern.h" 10 #include "mlir/Dialect/LLVMIR/FunctionCallUtils.h" 11 #include "mlir/Dialect/LLVMIR/LLVMDialect.h" 12 #include "mlir/Dialect/LLVMIR/LLVMTypes.h" 13 #include "mlir/IR/AffineMap.h" 14 15 using namespace mlir; 16 17 //===----------------------------------------------------------------------===// 18 // ConvertToLLVMPattern 19 //===----------------------------------------------------------------------===// 20 21 ConvertToLLVMPattern::ConvertToLLVMPattern(StringRef rootOpName, 22 MLIRContext *context, 23 LLVMTypeConverter &typeConverter, 24 PatternBenefit benefit) 25 : ConversionPattern(typeConverter, rootOpName, benefit, context) {} 26 27 LLVMTypeConverter *ConvertToLLVMPattern::getTypeConverter() const { 28 return static_cast<LLVMTypeConverter *>( 29 ConversionPattern::getTypeConverter()); 30 } 31 32 LLVM::LLVMDialect &ConvertToLLVMPattern::getDialect() const { 33 return *getTypeConverter()->getDialect(); 34 } 35 36 Type ConvertToLLVMPattern::getIndexType() const { 37 return getTypeConverter()->getIndexType(); 38 } 39 40 Type ConvertToLLVMPattern::getIntPtrType(unsigned addressSpace) const { 41 return IntegerType::get(&getTypeConverter()->getContext(), 42 getTypeConverter()->getPointerBitwidth(addressSpace)); 43 } 44 45 Type ConvertToLLVMPattern::getVoidType() const { 46 return LLVM::LLVMVoidType::get(&getTypeConverter()->getContext()); 47 } 48 49 Type ConvertToLLVMPattern::getVoidPtrType() const { 50 return LLVM::LLVMPointerType::get( 51 IntegerType::get(&getTypeConverter()->getContext(), 8)); 52 } 53 54 Value ConvertToLLVMPattern::createIndexAttrConstant(OpBuilder &builder, 55 Location loc, 56 Type resultType, 57 int64_t value) { 58 return builder.create<LLVM::ConstantOp>( 59 loc, resultType, builder.getIntegerAttr(builder.getIndexType(), value)); 60 } 61 62 Value ConvertToLLVMPattern::createIndexConstant( 63 ConversionPatternRewriter &builder, Location loc, uint64_t value) const { 64 return createIndexAttrConstant(builder, loc, getIndexType(), value); 65 } 66 67 Value ConvertToLLVMPattern::getStridedElementPtr( 68 Location loc, MemRefType type, Value memRefDesc, ValueRange indices, 69 ConversionPatternRewriter &rewriter) const { 70 71 int64_t offset; 72 SmallVector<int64_t, 4> strides; 73 auto successStrides = getStridesAndOffset(type, strides, offset); 74 assert(succeeded(successStrides) && "unexpected non-strided memref"); 75 (void)successStrides; 76 77 MemRefDescriptor memRefDescriptor(memRefDesc); 78 Value base = memRefDescriptor.alignedPtr(rewriter, loc); 79 80 Value index; 81 if (offset != 0) // Skip if offset is zero. 82 index = ShapedType::isDynamicStrideOrOffset(offset) 83 ? memRefDescriptor.offset(rewriter, loc) 84 : createIndexConstant(rewriter, loc, offset); 85 86 for (int i = 0, e = indices.size(); i < e; ++i) { 87 Value increment = indices[i]; 88 if (strides[i] != 1) { // Skip if stride is 1. 89 Value stride = ShapedType::isDynamicStrideOrOffset(strides[i]) 90 ? memRefDescriptor.stride(rewriter, loc, i) 91 : createIndexConstant(rewriter, loc, strides[i]); 92 increment = rewriter.create<LLVM::MulOp>(loc, increment, stride); 93 } 94 index = 95 index ? rewriter.create<LLVM::AddOp>(loc, index, increment) : increment; 96 } 97 98 Type elementPtrType = memRefDescriptor.getElementPtrType(); 99 return index ? rewriter.create<LLVM::GEPOp>(loc, elementPtrType, base, index) 100 : base; 101 } 102 103 // Check if the MemRefType `type` is supported by the lowering. We currently 104 // only support memrefs with identity maps. 105 bool ConvertToLLVMPattern::isConvertibleAndHasIdentityMaps( 106 MemRefType type) const { 107 if (!typeConverter->convertType(type.getElementType())) 108 return false; 109 return type.getLayout().isIdentity(); 110 } 111 112 Type ConvertToLLVMPattern::getElementPtrType(MemRefType type) const { 113 auto elementType = type.getElementType(); 114 auto structElementType = typeConverter->convertType(elementType); 115 return LLVM::LLVMPointerType::get(structElementType, 116 type.getMemorySpaceAsInt()); 117 } 118 119 void ConvertToLLVMPattern::getMemRefDescriptorSizes( 120 Location loc, MemRefType memRefType, ValueRange dynamicSizes, 121 ConversionPatternRewriter &rewriter, SmallVectorImpl<Value> &sizes, 122 SmallVectorImpl<Value> &strides, Value &sizeBytes) const { 123 assert(isConvertibleAndHasIdentityMaps(memRefType) && 124 "layout maps must have been normalized away"); 125 assert(count(memRefType.getShape(), ShapedType::kDynamicSize) == 126 static_cast<ssize_t>(dynamicSizes.size()) && 127 "dynamicSizes size doesn't match dynamic sizes count in memref shape"); 128 129 sizes.reserve(memRefType.getRank()); 130 unsigned dynamicIndex = 0; 131 for (int64_t size : memRefType.getShape()) { 132 sizes.push_back(size == ShapedType::kDynamicSize 133 ? dynamicSizes[dynamicIndex++] 134 : createIndexConstant(rewriter, loc, size)); 135 } 136 137 // Strides: iterate sizes in reverse order and multiply. 138 int64_t stride = 1; 139 Value runningStride = createIndexConstant(rewriter, loc, 1); 140 strides.resize(memRefType.getRank()); 141 for (auto i = memRefType.getRank(); i-- > 0;) { 142 strides[i] = runningStride; 143 144 int64_t size = memRefType.getShape()[i]; 145 if (size == 0) 146 continue; 147 bool useSizeAsStride = stride == 1; 148 if (size == ShapedType::kDynamicSize) 149 stride = ShapedType::kDynamicSize; 150 if (stride != ShapedType::kDynamicSize) 151 stride *= size; 152 153 if (useSizeAsStride) 154 runningStride = sizes[i]; 155 else if (stride == ShapedType::kDynamicSize) 156 runningStride = 157 rewriter.create<LLVM::MulOp>(loc, runningStride, sizes[i]); 158 else 159 runningStride = createIndexConstant(rewriter, loc, stride); 160 } 161 162 // Buffer size in bytes. 163 Type elementPtrType = getElementPtrType(memRefType); 164 Value nullPtr = rewriter.create<LLVM::NullOp>(loc, elementPtrType); 165 Value gepPtr = rewriter.create<LLVM::GEPOp>(loc, elementPtrType, nullPtr, 166 ArrayRef<Value>{runningStride}); 167 sizeBytes = rewriter.create<LLVM::PtrToIntOp>(loc, getIndexType(), gepPtr); 168 } 169 170 Value ConvertToLLVMPattern::getSizeInBytes( 171 Location loc, Type type, ConversionPatternRewriter &rewriter) const { 172 // Compute the size of an individual element. This emits the MLIR equivalent 173 // of the following sizeof(...) implementation in LLVM IR: 174 // %0 = getelementptr %elementType* null, %indexType 1 175 // %1 = ptrtoint %elementType* %0 to %indexType 176 // which is a common pattern of getting the size of a type in bytes. 177 auto convertedPtrType = 178 LLVM::LLVMPointerType::get(typeConverter->convertType(type)); 179 auto nullPtr = rewriter.create<LLVM::NullOp>(loc, convertedPtrType); 180 auto gep = rewriter.create<LLVM::GEPOp>( 181 loc, convertedPtrType, nullPtr, 182 ArrayRef<Value>{createIndexConstant(rewriter, loc, 1)}); 183 return rewriter.create<LLVM::PtrToIntOp>(loc, getIndexType(), gep); 184 } 185 186 Value ConvertToLLVMPattern::getNumElements( 187 Location loc, ArrayRef<Value> shape, 188 ConversionPatternRewriter &rewriter) const { 189 // Compute the total number of memref elements. 190 Value numElements = 191 shape.empty() ? createIndexConstant(rewriter, loc, 1) : shape.front(); 192 for (unsigned i = 1, e = shape.size(); i < e; ++i) 193 numElements = rewriter.create<LLVM::MulOp>(loc, numElements, shape[i]); 194 return numElements; 195 } 196 197 /// Creates and populates the memref descriptor struct given all its fields. 198 MemRefDescriptor ConvertToLLVMPattern::createMemRefDescriptor( 199 Location loc, MemRefType memRefType, Value allocatedPtr, Value alignedPtr, 200 ArrayRef<Value> sizes, ArrayRef<Value> strides, 201 ConversionPatternRewriter &rewriter) const { 202 auto structType = typeConverter->convertType(memRefType); 203 auto memRefDescriptor = MemRefDescriptor::undef(rewriter, loc, structType); 204 205 // Field 1: Allocated pointer, used for malloc/free. 206 memRefDescriptor.setAllocatedPtr(rewriter, loc, allocatedPtr); 207 208 // Field 2: Actual aligned pointer to payload. 209 memRefDescriptor.setAlignedPtr(rewriter, loc, alignedPtr); 210 211 // Field 3: Offset in aligned pointer. 212 memRefDescriptor.setOffset(rewriter, loc, 213 createIndexConstant(rewriter, loc, 0)); 214 215 // Fields 4: Sizes. 216 for (const auto &en : llvm::enumerate(sizes)) 217 memRefDescriptor.setSize(rewriter, loc, en.index(), en.value()); 218 219 // Field 5: Strides. 220 for (const auto &en : llvm::enumerate(strides)) 221 memRefDescriptor.setStride(rewriter, loc, en.index(), en.value()); 222 223 return memRefDescriptor; 224 } 225 226 LogicalResult ConvertToLLVMPattern::copyUnrankedDescriptors( 227 OpBuilder &builder, Location loc, TypeRange origTypes, 228 SmallVectorImpl<Value> &operands, bool toDynamic) const { 229 assert(origTypes.size() == operands.size() && 230 "expected as may original types as operands"); 231 232 // Find operands of unranked memref type and store them. 233 SmallVector<UnrankedMemRefDescriptor, 4> unrankedMemrefs; 234 for (unsigned i = 0, e = operands.size(); i < e; ++i) 235 if (origTypes[i].isa<UnrankedMemRefType>()) 236 unrankedMemrefs.emplace_back(operands[i]); 237 238 if (unrankedMemrefs.empty()) 239 return success(); 240 241 // Compute allocation sizes. 242 SmallVector<Value, 4> sizes; 243 UnrankedMemRefDescriptor::computeSizes(builder, loc, *getTypeConverter(), 244 unrankedMemrefs, sizes); 245 246 // Get frequently used types. 247 MLIRContext *context = builder.getContext(); 248 Type voidPtrType = LLVM::LLVMPointerType::get(IntegerType::get(context, 8)); 249 auto i1Type = IntegerType::get(context, 1); 250 Type indexType = getTypeConverter()->getIndexType(); 251 252 // Find the malloc and free, or declare them if necessary. 253 auto module = builder.getInsertionPoint()->getParentOfType<ModuleOp>(); 254 LLVM::LLVMFuncOp freeFunc, mallocFunc; 255 if (toDynamic) 256 mallocFunc = LLVM::lookupOrCreateMallocFn(module, indexType); 257 if (!toDynamic) 258 freeFunc = LLVM::lookupOrCreateFreeFn(module); 259 260 // Initialize shared constants. 261 Value zero = 262 builder.create<LLVM::ConstantOp>(loc, i1Type, builder.getBoolAttr(false)); 263 264 unsigned unrankedMemrefPos = 0; 265 for (unsigned i = 0, e = operands.size(); i < e; ++i) { 266 Type type = origTypes[i]; 267 if (!type.isa<UnrankedMemRefType>()) 268 continue; 269 Value allocationSize = sizes[unrankedMemrefPos++]; 270 UnrankedMemRefDescriptor desc(operands[i]); 271 272 // Allocate memory, copy, and free the source if necessary. 273 Value memory = 274 toDynamic 275 ? builder.create<LLVM::CallOp>(loc, mallocFunc, allocationSize) 276 .getResult(0) 277 : builder.create<LLVM::AllocaOp>(loc, voidPtrType, allocationSize, 278 /*alignment=*/0); 279 Value source = desc.memRefDescPtr(builder, loc); 280 builder.create<LLVM::MemcpyOp>(loc, memory, source, allocationSize, zero); 281 if (!toDynamic) 282 builder.create<LLVM::CallOp>(loc, freeFunc, source); 283 284 // Create a new descriptor. The same descriptor can be returned multiple 285 // times, attempting to modify its pointer can lead to memory leaks 286 // (allocated twice and overwritten) or double frees (the caller does not 287 // know if the descriptor points to the same memory). 288 Type descriptorType = getTypeConverter()->convertType(type); 289 if (!descriptorType) 290 return failure(); 291 auto updatedDesc = 292 UnrankedMemRefDescriptor::undef(builder, loc, descriptorType); 293 Value rank = desc.rank(builder, loc); 294 updatedDesc.setRank(builder, loc, rank); 295 updatedDesc.setMemRefDescPtr(builder, loc, memory); 296 297 operands[i] = updatedDesc; 298 } 299 300 return success(); 301 } 302 303 //===----------------------------------------------------------------------===// 304 // Detail methods 305 //===----------------------------------------------------------------------===// 306 307 /// Replaces the given operation "op" with a new operation of type "targetOp" 308 /// and given operands. 309 LogicalResult LLVM::detail::oneToOneRewrite( 310 Operation *op, StringRef targetOp, ValueRange operands, 311 LLVMTypeConverter &typeConverter, ConversionPatternRewriter &rewriter) { 312 unsigned numResults = op->getNumResults(); 313 314 Type packedType; 315 if (numResults != 0) { 316 packedType = typeConverter.packFunctionResults(op->getResultTypes()); 317 if (!packedType) 318 return failure(); 319 } 320 321 // Create the operation through state since we don't know its C++ type. 322 OperationState state(op->getLoc(), targetOp); 323 state.addTypes(packedType); 324 state.addOperands(operands); 325 state.addAttributes(op->getAttrs()); 326 Operation *newOp = rewriter.createOperation(state); 327 328 // If the operation produced 0 or 1 result, return them immediately. 329 if (numResults == 0) 330 return rewriter.eraseOp(op), success(); 331 if (numResults == 1) 332 return rewriter.replaceOp(op, newOp->getResult(0)), success(); 333 334 // Otherwise, it had been converted to an operation producing a structure. 335 // Extract individual results from the structure and return them as list. 336 SmallVector<Value, 4> results; 337 results.reserve(numResults); 338 for (unsigned i = 0; i < numResults; ++i) { 339 auto type = typeConverter.convertType(op->getResult(i).getType()); 340 results.push_back(rewriter.create<LLVM::ExtractValueOp>( 341 op->getLoc(), type, newOp->getResult(0), rewriter.getI64ArrayAttr(i))); 342 } 343 rewriter.replaceOp(op, results); 344 return success(); 345 } 346