1 //===- Pattern.cpp - Conversion pattern to the LLVM dialect ---------------===// 2 // 3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. 4 // See https://llvm.org/LICENSE.txt for license information. 5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception 6 // 7 //===----------------------------------------------------------------------===// 8 9 #include "mlir/Conversion/LLVMCommon/Pattern.h" 10 #include "mlir/Dialect/LLVMIR/FunctionCallUtils.h" 11 #include "mlir/Dialect/LLVMIR/LLVMDialect.h" 12 #include "mlir/Dialect/LLVMIR/LLVMTypes.h" 13 #include "mlir/IR/AffineMap.h" 14 15 using namespace mlir; 16 17 //===----------------------------------------------------------------------===// 18 // ConvertToLLVMPattern 19 //===----------------------------------------------------------------------===// 20 21 ConvertToLLVMPattern::ConvertToLLVMPattern(StringRef rootOpName, 22 MLIRContext *context, 23 LLVMTypeConverter &typeConverter, 24 PatternBenefit benefit) 25 : ConversionPattern(typeConverter, rootOpName, benefit, context) {} 26 27 LLVMTypeConverter *ConvertToLLVMPattern::getTypeConverter() const { 28 return static_cast<LLVMTypeConverter *>( 29 ConversionPattern::getTypeConverter()); 30 } 31 32 LLVM::LLVMDialect &ConvertToLLVMPattern::getDialect() const { 33 return *getTypeConverter()->getDialect(); 34 } 35 36 Type ConvertToLLVMPattern::getIndexType() const { 37 return getTypeConverter()->getIndexType(); 38 } 39 40 Type ConvertToLLVMPattern::getIntPtrType(unsigned addressSpace) const { 41 return IntegerType::get(&getTypeConverter()->getContext(), 42 getTypeConverter()->getPointerBitwidth(addressSpace)); 43 } 44 45 Type ConvertToLLVMPattern::getVoidType() const { 46 return LLVM::LLVMVoidType::get(&getTypeConverter()->getContext()); 47 } 48 49 Type ConvertToLLVMPattern::getVoidPtrType() const { 50 return LLVM::LLVMPointerType::get( 51 IntegerType::get(&getTypeConverter()->getContext(), 8)); 52 } 53 54 Value ConvertToLLVMPattern::createIndexAttrConstant(OpBuilder &builder, 55 Location loc, 56 Type resultType, 57 int64_t value) { 58 return builder.create<LLVM::ConstantOp>( 59 loc, resultType, builder.getIntegerAttr(builder.getIndexType(), value)); 60 } 61 62 Value ConvertToLLVMPattern::createIndexConstant( 63 ConversionPatternRewriter &builder, Location loc, uint64_t value) const { 64 return createIndexAttrConstant(builder, loc, getIndexType(), value); 65 } 66 67 Value ConvertToLLVMPattern::getStridedElementPtr( 68 Location loc, MemRefType type, Value memRefDesc, ValueRange indices, 69 ConversionPatternRewriter &rewriter) const { 70 71 int64_t offset; 72 SmallVector<int64_t, 4> strides; 73 auto successStrides = getStridesAndOffset(type, strides, offset); 74 assert(succeeded(successStrides) && "unexpected non-strided memref"); 75 (void)successStrides; 76 77 MemRefDescriptor memRefDescriptor(memRefDesc); 78 Value base = memRefDescriptor.alignedPtr(rewriter, loc); 79 80 Value index; 81 if (offset != 0) // Skip if offset is zero. 82 index = MemRefType::isDynamicStrideOrOffset(offset) 83 ? memRefDescriptor.offset(rewriter, loc) 84 : createIndexConstant(rewriter, loc, offset); 85 86 for (int i = 0, e = indices.size(); i < e; ++i) { 87 Value increment = indices[i]; 88 if (strides[i] != 1) { // Skip if stride is 1. 89 Value stride = MemRefType::isDynamicStrideOrOffset(strides[i]) 90 ? memRefDescriptor.stride(rewriter, loc, i) 91 : createIndexConstant(rewriter, loc, strides[i]); 92 increment = rewriter.create<LLVM::MulOp>(loc, increment, stride); 93 } 94 index = 95 index ? rewriter.create<LLVM::AddOp>(loc, index, increment) : increment; 96 } 97 98 Type elementPtrType = memRefDescriptor.getElementPtrType(); 99 return index ? rewriter.create<LLVM::GEPOp>(loc, elementPtrType, base, index) 100 : base; 101 } 102 103 // Check if the MemRefType `type` is supported by the lowering. We currently 104 // only support memrefs with identity maps. 105 bool ConvertToLLVMPattern::isConvertibleAndHasIdentityMaps( 106 MemRefType type) const { 107 if (!typeConverter->convertType(type.getElementType())) 108 return false; 109 return type.getAffineMaps().empty() || 110 llvm::all_of(type.getAffineMaps(), 111 [](AffineMap map) { return map.isIdentity(); }); 112 } 113 114 Type ConvertToLLVMPattern::getElementPtrType(MemRefType type) const { 115 auto elementType = type.getElementType(); 116 auto structElementType = typeConverter->convertType(elementType); 117 return LLVM::LLVMPointerType::get(structElementType, 118 type.getMemorySpaceAsInt()); 119 } 120 121 void ConvertToLLVMPattern::getMemRefDescriptorSizes( 122 Location loc, MemRefType memRefType, ValueRange dynamicSizes, 123 ConversionPatternRewriter &rewriter, SmallVectorImpl<Value> &sizes, 124 SmallVectorImpl<Value> &strides, Value &sizeBytes) const { 125 assert(isConvertibleAndHasIdentityMaps(memRefType) && 126 "layout maps must have been normalized away"); 127 assert(count(memRefType.getShape(), ShapedType::kDynamicSize) == 128 static_cast<ssize_t>(dynamicSizes.size()) && 129 "dynamicSizes size doesn't match dynamic sizes count in memref shape"); 130 131 sizes.reserve(memRefType.getRank()); 132 unsigned dynamicIndex = 0; 133 for (int64_t size : memRefType.getShape()) { 134 sizes.push_back(size == ShapedType::kDynamicSize 135 ? dynamicSizes[dynamicIndex++] 136 : createIndexConstant(rewriter, loc, size)); 137 } 138 139 // Strides: iterate sizes in reverse order and multiply. 140 int64_t stride = 1; 141 Value runningStride = createIndexConstant(rewriter, loc, 1); 142 strides.resize(memRefType.getRank()); 143 for (auto i = memRefType.getRank(); i-- > 0;) { 144 strides[i] = runningStride; 145 146 int64_t size = memRefType.getShape()[i]; 147 if (size == 0) 148 continue; 149 bool useSizeAsStride = stride == 1; 150 if (size == ShapedType::kDynamicSize) 151 stride = ShapedType::kDynamicSize; 152 if (stride != ShapedType::kDynamicSize) 153 stride *= size; 154 155 if (useSizeAsStride) 156 runningStride = sizes[i]; 157 else if (stride == ShapedType::kDynamicSize) 158 runningStride = 159 rewriter.create<LLVM::MulOp>(loc, runningStride, sizes[i]); 160 else 161 runningStride = createIndexConstant(rewriter, loc, stride); 162 } 163 164 // Buffer size in bytes. 165 Type elementPtrType = getElementPtrType(memRefType); 166 Value nullPtr = rewriter.create<LLVM::NullOp>(loc, elementPtrType); 167 Value gepPtr = rewriter.create<LLVM::GEPOp>( 168 loc, elementPtrType, ArrayRef<Value>{nullPtr, runningStride}); 169 sizeBytes = rewriter.create<LLVM::PtrToIntOp>(loc, getIndexType(), gepPtr); 170 } 171 172 Value ConvertToLLVMPattern::getSizeInBytes( 173 Location loc, Type type, ConversionPatternRewriter &rewriter) const { 174 // Compute the size of an individual element. This emits the MLIR equivalent 175 // of the following sizeof(...) implementation in LLVM IR: 176 // %0 = getelementptr %elementType* null, %indexType 1 177 // %1 = ptrtoint %elementType* %0 to %indexType 178 // which is a common pattern of getting the size of a type in bytes. 179 auto convertedPtrType = 180 LLVM::LLVMPointerType::get(typeConverter->convertType(type)); 181 auto nullPtr = rewriter.create<LLVM::NullOp>(loc, convertedPtrType); 182 auto gep = rewriter.create<LLVM::GEPOp>( 183 loc, convertedPtrType, 184 ArrayRef<Value>{nullPtr, createIndexConstant(rewriter, loc, 1)}); 185 return rewriter.create<LLVM::PtrToIntOp>(loc, getIndexType(), gep); 186 } 187 188 Value ConvertToLLVMPattern::getNumElements( 189 Location loc, ArrayRef<Value> shape, 190 ConversionPatternRewriter &rewriter) const { 191 // Compute the total number of memref elements. 192 Value numElements = 193 shape.empty() ? createIndexConstant(rewriter, loc, 1) : shape.front(); 194 for (unsigned i = 1, e = shape.size(); i < e; ++i) 195 numElements = rewriter.create<LLVM::MulOp>(loc, numElements, shape[i]); 196 return numElements; 197 } 198 199 /// Creates and populates the memref descriptor struct given all its fields. 200 MemRefDescriptor ConvertToLLVMPattern::createMemRefDescriptor( 201 Location loc, MemRefType memRefType, Value allocatedPtr, Value alignedPtr, 202 ArrayRef<Value> sizes, ArrayRef<Value> strides, 203 ConversionPatternRewriter &rewriter) const { 204 auto structType = typeConverter->convertType(memRefType); 205 auto memRefDescriptor = MemRefDescriptor::undef(rewriter, loc, structType); 206 207 // Field 1: Allocated pointer, used for malloc/free. 208 memRefDescriptor.setAllocatedPtr(rewriter, loc, allocatedPtr); 209 210 // Field 2: Actual aligned pointer to payload. 211 memRefDescriptor.setAlignedPtr(rewriter, loc, alignedPtr); 212 213 // Field 3: Offset in aligned pointer. 214 memRefDescriptor.setOffset(rewriter, loc, 215 createIndexConstant(rewriter, loc, 0)); 216 217 // Fields 4: Sizes. 218 for (auto en : llvm::enumerate(sizes)) 219 memRefDescriptor.setSize(rewriter, loc, en.index(), en.value()); 220 221 // Field 5: Strides. 222 for (auto en : llvm::enumerate(strides)) 223 memRefDescriptor.setStride(rewriter, loc, en.index(), en.value()); 224 225 return memRefDescriptor; 226 } 227 228 LogicalResult ConvertToLLVMPattern::copyUnrankedDescriptors( 229 OpBuilder &builder, Location loc, TypeRange origTypes, 230 SmallVectorImpl<Value> &operands, bool toDynamic) const { 231 assert(origTypes.size() == operands.size() && 232 "expected as may original types as operands"); 233 234 // Find operands of unranked memref type and store them. 235 SmallVector<UnrankedMemRefDescriptor, 4> unrankedMemrefs; 236 for (unsigned i = 0, e = operands.size(); i < e; ++i) 237 if (origTypes[i].isa<UnrankedMemRefType>()) 238 unrankedMemrefs.emplace_back(operands[i]); 239 240 if (unrankedMemrefs.empty()) 241 return success(); 242 243 // Compute allocation sizes. 244 SmallVector<Value, 4> sizes; 245 UnrankedMemRefDescriptor::computeSizes(builder, loc, *getTypeConverter(), 246 unrankedMemrefs, sizes); 247 248 // Get frequently used types. 249 MLIRContext *context = builder.getContext(); 250 Type voidPtrType = LLVM::LLVMPointerType::get(IntegerType::get(context, 8)); 251 auto i1Type = IntegerType::get(context, 1); 252 Type indexType = getTypeConverter()->getIndexType(); 253 254 // Find the malloc and free, or declare them if necessary. 255 auto module = builder.getInsertionPoint()->getParentOfType<ModuleOp>(); 256 LLVM::LLVMFuncOp freeFunc, mallocFunc; 257 if (toDynamic) 258 mallocFunc = LLVM::lookupOrCreateMallocFn(module, indexType); 259 if (!toDynamic) 260 freeFunc = LLVM::lookupOrCreateFreeFn(module); 261 262 // Initialize shared constants. 263 Value zero = 264 builder.create<LLVM::ConstantOp>(loc, i1Type, builder.getBoolAttr(false)); 265 266 unsigned unrankedMemrefPos = 0; 267 for (unsigned i = 0, e = operands.size(); i < e; ++i) { 268 Type type = origTypes[i]; 269 if (!type.isa<UnrankedMemRefType>()) 270 continue; 271 Value allocationSize = sizes[unrankedMemrefPos++]; 272 UnrankedMemRefDescriptor desc(operands[i]); 273 274 // Allocate memory, copy, and free the source if necessary. 275 Value memory = 276 toDynamic 277 ? builder.create<LLVM::CallOp>(loc, mallocFunc, allocationSize) 278 .getResult(0) 279 : builder.create<LLVM::AllocaOp>(loc, voidPtrType, allocationSize, 280 /*alignment=*/0); 281 Value source = desc.memRefDescPtr(builder, loc); 282 builder.create<LLVM::MemcpyOp>(loc, memory, source, allocationSize, zero); 283 if (!toDynamic) 284 builder.create<LLVM::CallOp>(loc, freeFunc, source); 285 286 // Create a new descriptor. The same descriptor can be returned multiple 287 // times, attempting to modify its pointer can lead to memory leaks 288 // (allocated twice and overwritten) or double frees (the caller does not 289 // know if the descriptor points to the same memory). 290 Type descriptorType = getTypeConverter()->convertType(type); 291 if (!descriptorType) 292 return failure(); 293 auto updatedDesc = 294 UnrankedMemRefDescriptor::undef(builder, loc, descriptorType); 295 Value rank = desc.rank(builder, loc); 296 updatedDesc.setRank(builder, loc, rank); 297 updatedDesc.setMemRefDescPtr(builder, loc, memory); 298 299 operands[i] = updatedDesc; 300 } 301 302 return success(); 303 } 304 305 //===----------------------------------------------------------------------===// 306 // Detail methods 307 //===----------------------------------------------------------------------===// 308 309 /// Replaces the given operation "op" with a new operation of type "targetOp" 310 /// and given operands. 311 LogicalResult LLVM::detail::oneToOneRewrite( 312 Operation *op, StringRef targetOp, ValueRange operands, 313 LLVMTypeConverter &typeConverter, ConversionPatternRewriter &rewriter) { 314 unsigned numResults = op->getNumResults(); 315 316 Type packedType; 317 if (numResults != 0) { 318 packedType = typeConverter.packFunctionResults(op->getResultTypes()); 319 if (!packedType) 320 return failure(); 321 } 322 323 // Create the operation through state since we don't know its C++ type. 324 OperationState state(op->getLoc(), targetOp); 325 state.addTypes(packedType); 326 state.addOperands(operands); 327 state.addAttributes(op->getAttrs()); 328 Operation *newOp = rewriter.createOperation(state); 329 330 // If the operation produced 0 or 1 result, return them immediately. 331 if (numResults == 0) 332 return rewriter.eraseOp(op), success(); 333 if (numResults == 1) 334 return rewriter.replaceOp(op, newOp->getResult(0)), success(); 335 336 // Otherwise, it had been converted to an operation producing a structure. 337 // Extract individual results from the structure and return them as list. 338 SmallVector<Value, 4> results; 339 results.reserve(numResults); 340 for (unsigned i = 0; i < numResults; ++i) { 341 auto type = typeConverter.convertType(op->getResult(i).getType()); 342 results.push_back(rewriter.create<LLVM::ExtractValueOp>( 343 op->getLoc(), type, newOp->getResult(0), rewriter.getI64ArrayAttr(i))); 344 } 345 rewriter.replaceOp(op, results); 346 return success(); 347 } 348