//===- MemRefToLLVM.cpp - MemRef to LLVM dialect conversion ---------------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//

#include "mlir/Conversion/MemRefToLLVM/MemRefToLLVM.h"
#include "../PassDetail.h"
#include "mlir/Analysis/DataLayoutAnalysis.h"
#include "mlir/Conversion/LLVMCommon/ConversionTarget.h"
#include "mlir/Conversion/LLVMCommon/Pattern.h"
#include "mlir/Conversion/LLVMCommon/TypeConverter.h"
#include "mlir/Conversion/MemRefToLLVM/AllocLikeConversion.h"
#include "mlir/Dialect/LLVMIR/FunctionCallUtils.h"
#include "mlir/Dialect/LLVMIR/LLVMDialect.h"
#include "mlir/Dialect/MemRef/IR/MemRef.h"
#include "mlir/IR/AffineMap.h"
#include "mlir/IR/BlockAndValueMapping.h"

using namespace mlir;

namespace {

struct AllocOpLowering : public AllocLikeOpLLVMLowering {
  AllocOpLowering(LLVMTypeConverter &converter)
      : AllocLikeOpLLVMLowering(memref::AllocOp::getOperationName(),
                                converter) {}

  std::tuple<Value, Value> allocateBuffer(ConversionPatternRewriter &rewriter,
                                          Location loc, Value sizeBytes,
                                          Operation *op) const override {
    // Heap allocations.
    memref::AllocOp allocOp = cast<memref::AllocOp>(op);
    MemRefType memRefType = allocOp.getType();

    Value alignment;
    if (auto alignmentAttr = allocOp.alignment()) {
      alignment = createIndexConstant(rewriter, loc, *alignmentAttr);
    } else if (!memRefType.getElementType().isSignlessIntOrIndexOrFloat()) {
      // In the case where no alignment is specified, we may want to override
      // `malloc`'s behavior. `malloc` typically aligns at the size of the
      // biggest scalar on a target HW. For non-scalars, use the natural
      // alignment of the LLVM type given by the LLVM DataLayout.
      alignment = getSizeInBytes(loc, memRefType.getElementType(), rewriter);
    }

    if (alignment) {
      // Adjust the allocation size to consider alignment.
      sizeBytes = rewriter.create<LLVM::AddOp>(loc, sizeBytes, alignment);
    }

    // Allocate the underlying buffer and store a pointer to it in the MemRef
    // descriptor.
    Type elementPtrType = this->getElementPtrType(memRefType);
    auto allocFuncOp = LLVM::lookupOrCreateMallocFn(
        allocOp->getParentOfType<ModuleOp>(), getIndexType());
    auto results = createLLVMCall(rewriter, loc, allocFuncOp, {sizeBytes},
                                  getVoidPtrType());
    Value allocatedPtr =
        rewriter.create<LLVM::BitcastOp>(loc, elementPtrType, results[0]);

    Value alignedPtr = allocatedPtr;
    if (alignment) {
      // Compute the aligned pointer.
      Value allocatedInt =
          rewriter.create<LLVM::PtrToIntOp>(loc, getIndexType(), allocatedPtr);
      Value alignmentInt =
          createAligned(rewriter, loc, allocatedInt, alignment);
      alignedPtr =
          rewriter.create<LLVM::IntToPtrOp>(loc, elementPtrType, alignmentInt);
    }

    return std::make_tuple(allocatedPtr, alignedPtr);
  }
};
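
// Illustrative sketch (not verbatim pass output): with the malloc-based
// lowering above, an allocation such as
//   %0 = memref.alloc(%d) : memref<?xf32>
// becomes, roughly,
//   %size = ... number of bytes, padded by the element alignment if needed ...
//   %raw  = llvm.call @malloc(%size) : (i64) -> !llvm.ptr<i8>
//   %buf  = llvm.bitcast %raw : !llvm.ptr<i8> to !llvm.ptr<f32>
//   ... optionally round %buf up to the requested alignment ...
// followed by construction of the memref descriptor around %buf.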

struct AlignedAllocOpLowering : public AllocLikeOpLLVMLowering {
  AlignedAllocOpLowering(LLVMTypeConverter &converter)
      : AllocLikeOpLLVMLowering(memref::AllocOp::getOperationName(),
                                converter) {}

  /// Returns the memref's element size in bytes using the data layout active
  /// at `op`.
  // TODO: there are other places where this is used. Expose publicly?
  unsigned getMemRefEltSizeInBytes(MemRefType memRefType,
                                   Operation *op) const {
    const DataLayout *layout = &defaultLayout;
    if (const DataLayoutAnalysis *analysis =
            getTypeConverter()->getDataLayoutAnalysis()) {
      layout = &analysis->getAbove(op);
    }
    Type elementType = memRefType.getElementType();
    if (auto memRefElementType = elementType.dyn_cast<MemRefType>())
      return getTypeConverter()->getMemRefDescriptorSize(memRefElementType,
                                                         *layout);
    if (auto memRefElementType = elementType.dyn_cast<UnrankedMemRefType>())
      return getTypeConverter()->getUnrankedMemRefDescriptorSize(
          memRefElementType, *layout);
    return layout->getTypeSize(elementType);
  }

  /// Returns true if the memref size in bytes is known to be a multiple of
  /// `factor` assuming the data layout active at `op`.
  bool isMemRefSizeMultipleOf(MemRefType type, uint64_t factor,
                              Operation *op) const {
    uint64_t sizeDivisor = getMemRefEltSizeInBytes(type, op);
    for (unsigned i = 0, e = type.getRank(); i < e; i++) {
      if (type.isDynamic(type.getDimSize(i)))
        continue;
      sizeDivisor = sizeDivisor * type.getDimSize(i);
    }
    return sizeDivisor % factor == 0;
  }

  /// Returns the alignment to be used for the allocation call itself.
  /// aligned_alloc requires the alignment to be a power of two and the
  /// allocation size to be a multiple of that alignment.
  int64_t getAllocationAlignment(memref::AllocOp allocOp) const {
    if (Optional<uint64_t> alignment = allocOp.alignment())
      return *alignment;

    // Whenever no alignment is set, use an alignment consistent with the
    // element type; since the alignment has to be a power of two, bump it to
    // the next power of two if it isn't one already.
    auto eltSizeBytes = getMemRefEltSizeInBytes(allocOp.getType(), allocOp);
    return std::max(kMinAlignedAllocAlignment,
                    llvm::PowerOf2Ceil(eltSizeBytes));
  }

  std::tuple<Value, Value> allocateBuffer(ConversionPatternRewriter &rewriter,
                                          Location loc, Value sizeBytes,
                                          Operation *op) const override {
    // Heap allocations.
    memref::AllocOp allocOp = cast<memref::AllocOp>(op);
    MemRefType memRefType = allocOp.getType();
    int64_t alignment = getAllocationAlignment(allocOp);
    Value allocAlignment = createIndexConstant(rewriter, loc, alignment);

    // aligned_alloc requires size to be a multiple of alignment; we will pad
    // the size to the next multiple if necessary.
    if (!isMemRefSizeMultipleOf(memRefType, alignment, op))
      sizeBytes = createAligned(rewriter, loc, sizeBytes, allocAlignment);

    Type elementPtrType = this->getElementPtrType(memRefType);
    auto allocFuncOp = LLVM::lookupOrCreateAlignedAllocFn(
        allocOp->getParentOfType<ModuleOp>(), getIndexType());
    auto results =
        createLLVMCall(rewriter, loc, allocFuncOp, {allocAlignment, sizeBytes},
                       getVoidPtrType());
    Value allocatedPtr =
        rewriter.create<LLVM::BitcastOp>(loc, elementPtrType, results[0]);

    return std::make_tuple(allocatedPtr, allocatedPtr);
  }

  /// The minimum alignment to use with aligned_alloc (has to be a power of 2).
  static constexpr uint64_t kMinAlignedAllocAlignment = 16UL;

  /// Default layout to use in absence of the corresponding analysis.
  DataLayout defaultLayout;
};

// Out-of-line definition, required until C++17.
constexpr uint64_t AlignedAllocOpLowering::kMinAlignedAllocAlignment;
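
// Illustrative sketch (not verbatim pass output): with the aligned_alloc-based
// lowering above, `%0 = memref.alloc() {alignment = 32} : memref<64xf32>`
// becomes, roughly,
//   %align = llvm.mlir.constant(32 : index) : i64
//   %size  = ... number of bytes, rounded up to a multiple of %align ...
//   %raw   = llvm.call @aligned_alloc(%align, %size)
//              : (i64, i64) -> !llvm.ptr<i8>
//   %buf   = llvm.bitcast %raw : !llvm.ptr<i8> to !llvm.ptr<f32>
// with %buf used as both the allocated and the aligned pointer of the
// resulting memref descriptor.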

struct AllocaOpLowering : public AllocLikeOpLLVMLowering {
  AllocaOpLowering(LLVMTypeConverter &converter)
      : AllocLikeOpLLVMLowering(memref::AllocaOp::getOperationName(),
                                converter) {}

  /// Allocates the underlying buffer using an llvm.alloca operation. For stack
  /// allocations there is no separate aligned pointer: the allocation itself
  /// carries the requested alignment, so the same pointer is returned twice.
  std::tuple<Value, Value> allocateBuffer(ConversionPatternRewriter &rewriter,
                                          Location loc, Value sizeBytes,
                                          Operation *op) const override {

    // With alloca, one gets a pointer to the element type right away.
    auto allocaOp = cast<memref::AllocaOp>(op);
    auto elementPtrType = this->getElementPtrType(allocaOp.getType());

    auto allocatedElementPtr = rewriter.create<LLVM::AllocaOp>(
        loc, elementPtrType, sizeBytes,
        allocaOp.alignment() ? *allocaOp.alignment() : 0);

    return std::make_tuple(allocatedElementPtr, allocatedElementPtr);
  }
};
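
// memref.alloca_scope is lowered below by inlining its body between a pair of
// stack-save / stack-restore intrinsic calls. Illustrative control flow
// produced by the pattern (sketch, not verbatim output):
//
//   ^current:              // original block, truncated at the op
//     %sp = stacksave
//     br ^body
//   ^body ... ^bodyEnd:    // inlined alloca_scope region
//     stackrestore %sp
//     br ^continue(<results>)
//   ^continue:             // remainder of the original block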

struct AllocaScopeOpLowering
    : public ConvertOpToLLVMPattern<memref::AllocaScopeOp> {
  using ConvertOpToLLVMPattern<memref::AllocaScopeOp>::ConvertOpToLLVMPattern;

  LogicalResult
  matchAndRewrite(memref::AllocaScopeOp allocaScopeOp, ArrayRef<Value> operands,
                  ConversionPatternRewriter &rewriter) const override {
    OpBuilder::InsertionGuard guard(rewriter);
    Location loc = allocaScopeOp.getLoc();

    // Split the current block before the AllocaScopeOp to create the inlining
    // point.
    auto *currentBlock = rewriter.getInsertionBlock();
    auto *remainingOpsBlock =
        rewriter.splitBlock(currentBlock, rewriter.getInsertionPoint());
    Block *continueBlock;
    if (allocaScopeOp.getNumResults() == 0) {
      continueBlock = remainingOpsBlock;
    } else {
      continueBlock = rewriter.createBlock(remainingOpsBlock,
                                           allocaScopeOp.getResultTypes());
      rewriter.create<LLVM::BrOp>(loc, ValueRange(), remainingOpsBlock);
    }

    // Inline body region.
    Block *beforeBody = &allocaScopeOp.bodyRegion().front();
    Block *afterBody = &allocaScopeOp.bodyRegion().back();
    rewriter.inlineRegionBefore(allocaScopeOp.bodyRegion(), continueBlock);

    // Save stack and then branch into the body of the region.
    rewriter.setInsertionPointToEnd(currentBlock);
    auto stackSaveOp =
        rewriter.create<LLVM::StackSaveOp>(loc, getVoidPtrType());
    rewriter.create<LLVM::BrOp>(loc, ValueRange(), beforeBody);

    // Replace the alloca_scope return with a branch that jumps out of the
    // body. Restore the stack before leaving the body region.
    rewriter.setInsertionPointToEnd(afterBody);
    auto returnOp =
        cast<memref::AllocaScopeReturnOp>(afterBody->getTerminator());
    auto branchOp = rewriter.replaceOpWithNewOp<LLVM::BrOp>(
        returnOp, returnOp.results(), continueBlock);

    // Insert stack restore before jumping out of the body of the region.
    rewriter.setInsertionPoint(branchOp);
    rewriter.create<LLVM::StackRestoreOp>(loc, stackSaveOp);

    // Replace the op with the values returned from the body region.
    rewriter.replaceOp(allocaScopeOp, continueBlock->getArguments());

    return success();
  }
};

struct AssumeAlignmentOpLowering
    : public ConvertOpToLLVMPattern<memref::AssumeAlignmentOp> {
  using ConvertOpToLLVMPattern<
      memref::AssumeAlignmentOp>::ConvertOpToLLVMPattern;

  LogicalResult
  matchAndRewrite(memref::AssumeAlignmentOp op, ArrayRef<Value> operands,
                  ConversionPatternRewriter &rewriter) const override {
    memref::AssumeAlignmentOp::Adaptor transformed(operands);
    Value memref = transformed.memref();
    unsigned alignment = op.alignment();
    auto loc = op.getLoc();

    MemRefDescriptor memRefDescriptor(memref);
    Value ptr = memRefDescriptor.alignedPtr(rewriter, memref.getLoc());

    // Emit llvm.assume(memref.alignedPtr & (alignment - 1) == 0). Notice that
    // the asserted memref.alignedPtr isn't used anywhere else, as the real
    // users like load/store/views always re-extract memref.alignedPtr as they
    // get lowered.
    //
    // This relies on LLVM's CSE optimization (potentially after SROA), since
    // after CSE all memref.alignedPtr instances get de-duplicated into the
    // same pointer SSA value.
    auto intPtrType =
        getIntPtrType(memRefDescriptor.getElementPtrType().getAddressSpace());
    Value zero = createIndexAttrConstant(rewriter, loc, intPtrType, 0);
    Value mask =
        createIndexAttrConstant(rewriter, loc, intPtrType, alignment - 1);
    Value ptrValue = rewriter.create<LLVM::PtrToIntOp>(loc, intPtrType, ptr);
    rewriter.create<LLVM::AssumeOp>(
        loc, rewriter.create<LLVM::ICmpOp>(
                 loc, LLVM::ICmpPredicate::eq,
                 rewriter.create<LLVM::AndOp>(loc, ptrValue, mask), zero));

    rewriter.eraseOp(op);
    return success();
  }
};

// A `dealloc` is converted into a call to `free` on the underlying data
// buffer. The memref descriptor being an SSA value, there is no need to clean
// it up in any way.
struct DeallocOpLowering : public ConvertOpToLLVMPattern<memref::DeallocOp> {
  using ConvertOpToLLVMPattern<memref::DeallocOp>::ConvertOpToLLVMPattern;

  explicit DeallocOpLowering(LLVMTypeConverter &converter)
      : ConvertOpToLLVMPattern<memref::DeallocOp>(converter) {}

  LogicalResult
  matchAndRewrite(memref::DeallocOp op, ArrayRef<Value> operands,
                  ConversionPatternRewriter &rewriter) const override {
    assert(operands.size() == 1 && "dealloc takes one operand");
    memref::DeallocOp::Adaptor transformed(operands);

    // Insert the `free` declaration if it is not already present.
    auto freeFunc = LLVM::lookupOrCreateFreeFn(op->getParentOfType<ModuleOp>());
    MemRefDescriptor memref(transformed.memref());
    Value casted = rewriter.create<LLVM::BitcastOp>(
        op.getLoc(), getVoidPtrType(),
        memref.allocatedPtr(rewriter, op.getLoc()));
    rewriter.replaceOpWithNewOp<LLVM::CallOp>(
        op, TypeRange(), rewriter.getSymbolRefAttr(freeFunc), casted);
    return success();
  }
};
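
// Illustrative result of the dealloc lowering above (sketch, not verbatim):
//   %alloc = llvm.extractvalue %desc[0] : !llvm.struct<...>  // allocated ptr
//   %void  = llvm.bitcast %alloc : !llvm.ptr<f32> to !llvm.ptr<i8>
//   llvm.call @free(%void) : (!llvm.ptr<i8>) -> ()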

// A `dim` is converted to a constant for static sizes and to an access to the
// size stored in the memref descriptor for dynamic sizes.
struct DimOpLowering : public ConvertOpToLLVMPattern<memref::DimOp> {
  using ConvertOpToLLVMPattern<memref::DimOp>::ConvertOpToLLVMPattern;

  LogicalResult
  matchAndRewrite(memref::DimOp dimOp, ArrayRef<Value> operands,
                  ConversionPatternRewriter &rewriter) const override {
    Type operandType = dimOp.source().getType();
    if (operandType.isa<UnrankedMemRefType>()) {
      rewriter.replaceOp(dimOp, {extractSizeOfUnrankedMemRef(
                                    operandType, dimOp, operands, rewriter)});

      return success();
    }
    if (operandType.isa<MemRefType>()) {
      rewriter.replaceOp(dimOp, {extractSizeOfRankedMemRef(
                                    operandType, dimOp, operands, rewriter)});
      return success();
    }
    llvm_unreachable("expected MemRefType or UnrankedMemRefType");
  }

private:
  Value extractSizeOfUnrankedMemRef(Type operandType, memref::DimOp dimOp,
                                    ArrayRef<Value> operands,
                                    ConversionPatternRewriter &rewriter) const {
    Location loc = dimOp.getLoc();
    memref::DimOp::Adaptor transformed(operands);

    auto unrankedMemRefType = operandType.cast<UnrankedMemRefType>();
    auto scalarMemRefType =
        MemRefType::get({}, unrankedMemRefType.getElementType());
    unsigned addressSpace = unrankedMemRefType.getMemorySpaceAsInt();

    // Extract pointer to the underlying ranked descriptor and bitcast it to a
    // memref<element_type> descriptor pointer to minimize the number of GEP
    // operations.
    UnrankedMemRefDescriptor unrankedDesc(transformed.source());
    Value underlyingRankedDesc = unrankedDesc.memRefDescPtr(rewriter, loc);
    Value scalarMemRefDescPtr = rewriter.create<LLVM::BitcastOp>(
        loc,
        LLVM::LLVMPointerType::get(
            typeConverter->convertType(scalarMemRefType), addressSpace),
        underlyingRankedDesc);

    // Get pointer to offset field of memref<element_type> descriptor.
    Type indexPtrTy = LLVM::LLVMPointerType::get(
        getTypeConverter()->getIndexType(), addressSpace);
    Value two = rewriter.create<LLVM::ConstantOp>(
        loc, typeConverter->convertType(rewriter.getI32Type()),
        rewriter.getI32IntegerAttr(2));
    Value offsetPtr = rewriter.create<LLVM::GEPOp>(
        loc, indexPtrTy, scalarMemRefDescPtr,
        ValueRange({createIndexConstant(rewriter, loc, 0), two}));

    // The size value that we have to extract can be obtained using GEPOp with
    // `dimOp.index() + 1` as the index argument.
    Value idxPlusOne = rewriter.create<LLVM::AddOp>(
        loc, createIndexConstant(rewriter, loc, 1), transformed.index());
    Value sizePtr = rewriter.create<LLVM::GEPOp>(loc, indexPtrTy, offsetPtr,
                                                 ValueRange({idxPlusOne}));
    return rewriter.create<LLVM::LoadOp>(loc, sizePtr);
  }

  Optional<int64_t> getConstantDimIndex(memref::DimOp dimOp) const {
    if (Optional<int64_t> idx = dimOp.getConstantIndex())
      return idx;

    if (auto constantOp = dimOp.index().getDefiningOp<LLVM::ConstantOp>())
      return constantOp.value().cast<IntegerAttr>().getValue().getSExtValue();

    return llvm::None;
  }

  Value extractSizeOfRankedMemRef(Type operandType, memref::DimOp dimOp,
                                  ArrayRef<Value> operands,
                                  ConversionPatternRewriter &rewriter) const {
    Location loc = dimOp.getLoc();
    memref::DimOp::Adaptor transformed(operands);
    // Take advantage of the constant index when possible.
    MemRefType memRefType = operandType.cast<MemRefType>();
    if (Optional<int64_t> index = getConstantDimIndex(dimOp)) {
      int64_t i = index.getValue();
      if (memRefType.isDynamicDim(i)) {
        // Extract the dynamic size from the memref descriptor.
        MemRefDescriptor descriptor(transformed.source());
        return descriptor.size(rewriter, loc, i);
      }
      // Use a constant for the static size.
      int64_t dimSize = memRefType.getDimSize(i);
      return createIndexConstant(rewriter, loc, dimSize);
    }
    Value index = transformed.index();
    int64_t rank = memRefType.getRank();
    MemRefDescriptor memrefDescriptor(transformed.source());
    return memrefDescriptor.size(rewriter, loc, index, rank);
  }
};

/// Returns the LLVM type of the global variable given the memref type `type`.
static Type convertGlobalMemrefTypeToLLVM(MemRefType type,
                                          LLVMTypeConverter &typeConverter) {
  // The LLVM type for a global memref will be a multi-dimensional array. For
  // declarations or uninitialized global memrefs, we can potentially flatten
  // this to a 1D array. However, for memref.global's with an initial value,
  // we do not intend to flatten the ElementsAttribute when going from std ->
  // LLVM dialect, so the LLVM type needs to be a multi-dimensional array.
  Type elementType = typeConverter.convertType(type.getElementType());
  Type arrayTy = elementType;
  // The shape has the outermost dim at index 0, so we need to walk it
  // backwards.
  for (int64_t dim : llvm::reverse(type.getShape()))
    arrayTy = LLVM::LLVMArrayType::get(arrayTy, dim);
  return arrayTy;
}

/// GlobalMemrefOp is lowered to an LLVM global variable.
struct GlobalMemrefOpLowering
    : public ConvertOpToLLVMPattern<memref::GlobalOp> {
  using ConvertOpToLLVMPattern<memref::GlobalOp>::ConvertOpToLLVMPattern;

  LogicalResult
  matchAndRewrite(memref::GlobalOp global, ArrayRef<Value> operands,
                  ConversionPatternRewriter &rewriter) const override {
    MemRefType type = global.type().cast<MemRefType>();
    if (!isConvertibleAndHasIdentityMaps(type))
      return failure();

    Type arrayTy = convertGlobalMemrefTypeToLLVM(type, *getTypeConverter());

    LLVM::Linkage linkage =
        global.isPublic() ? LLVM::Linkage::External : LLVM::Linkage::Private;

    Attribute initialValue = nullptr;
    if (!global.isExternal() && !global.isUninitialized()) {
      auto elementsAttr = global.initial_value()->cast<ElementsAttr>();
      initialValue = elementsAttr;

      // For scalar memrefs, the global variable created is of the element
      // type, so unpack the elements attribute to extract the value.
      if (type.getRank() == 0)
        initialValue = elementsAttr.getValue({});
    }

    rewriter.replaceOpWithNewOp<LLVM::GlobalOp>(
        global, arrayTy, global.constant(), linkage, global.sym_name(),
        initialValue, /*alignment=*/0, type.getMemorySpaceAsInt());
    return success();
  }
};
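
// Illustrative sketch (not verbatim pass output): a global such as
//   memref.global "private" constant @cst : memref<2xf32> = dense<[1.0, 2.0]>
// is rewritten by the pattern above into roughly
//   llvm.mlir.global private constant @cst(dense<[1.0, 2.0]> : tensor<2xf32>)
//       : !llvm.array<2 x f32>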

/// GetGlobalMemrefOp is lowered into a memref descriptor with the pointer to
/// the first element stashed into the descriptor. This derives from
/// `AllocLikeOpLLVMLowering` to reuse the memref descriptor construction.
struct GetGlobalMemrefOpLowering : public AllocLikeOpLLVMLowering {
  GetGlobalMemrefOpLowering(LLVMTypeConverter &converter)
      : AllocLikeOpLLVMLowering(memref::GetGlobalOp::getOperationName(),
                                converter) {}

  /// Buffer "allocation" for the memref.get_global op is getting the address
  /// of the referenced global variable.
  std::tuple<Value, Value> allocateBuffer(ConversionPatternRewriter &rewriter,
                                          Location loc, Value sizeBytes,
                                          Operation *op) const override {
    auto getGlobalOp = cast<memref::GetGlobalOp>(op);
    MemRefType type = getGlobalOp.result().getType().cast<MemRefType>();
    unsigned memSpace = type.getMemorySpaceAsInt();

    Type arrayTy = convertGlobalMemrefTypeToLLVM(type, *getTypeConverter());
    auto addressOf = rewriter.create<LLVM::AddressOfOp>(
        loc, LLVM::LLVMPointerType::get(arrayTy, memSpace), getGlobalOp.name());

    // Get the address of the first element in the array by creating a GEP with
    // the address of the GV as the base, and (rank + 1) number of 0 indices.
    Type elementType = typeConverter->convertType(type.getElementType());
    Type elementPtrType = LLVM::LLVMPointerType::get(elementType, memSpace);

    SmallVector<Value, 4> operands = {addressOf};
    operands.insert(operands.end(), type.getRank() + 1,
                    createIndexConstant(rewriter, loc, 0));
    auto gep = rewriter.create<LLVM::GEPOp>(loc, elementPtrType, operands);

    // We do not expect the memref obtained using `memref.get_global` to ever
    // be deallocated. Set the allocated pointer to a known bad value to help
    // debug if that ever happens.
    auto intPtrType = getIntPtrType(memSpace);
    Value deadBeefConst =
        createIndexAttrConstant(rewriter, op->getLoc(), intPtrType, 0xdeadbeef);
    auto deadBeefPtr =
        rewriter.create<LLVM::IntToPtrOp>(loc, elementPtrType, deadBeefConst);

    // Both allocated and aligned pointers are the same. We could potentially
    // stash a nullptr for the allocated pointer since we do not expect any
    // dealloc.
    return std::make_tuple(deadBeefPtr, gep);
  }
};

// Common base for load and store operations on MemRefs. Restricts the match
// to supported MemRef types. Provides functionality to emit code accessing a
// specific element of the underlying data buffer.
template <typename Derived>
struct LoadStoreOpLowering : public ConvertOpToLLVMPattern<Derived> {
  using ConvertOpToLLVMPattern<Derived>::ConvertOpToLLVMPattern;
  using ConvertOpToLLVMPattern<Derived>::isConvertibleAndHasIdentityMaps;
  using Base = LoadStoreOpLowering<Derived>;

  LogicalResult match(Derived op) const override {
    MemRefType type = op.getMemRefType();
    return isConvertibleAndHasIdentityMaps(type) ? success() : failure();
  }
};

// Load operation is lowered to obtaining a pointer to the indexed element
// and loading it.
struct LoadOpLowering : public LoadStoreOpLowering<memref::LoadOp> {
  using Base::Base;

  LogicalResult
  matchAndRewrite(memref::LoadOp loadOp, ArrayRef<Value> operands,
                  ConversionPatternRewriter &rewriter) const override {
    memref::LoadOp::Adaptor transformed(operands);
    auto type = loadOp.getMemRefType();

    Value dataPtr =
        getStridedElementPtr(loadOp.getLoc(), type, transformed.memref(),
                             transformed.indices(), rewriter);
    rewriter.replaceOpWithNewOp<LLVM::LoadOp>(loadOp, dataPtr);
    return success();
  }
};
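
// The load, store and prefetch patterns rely on getStridedElementPtr, which
// for a memref with offset `off` and strides `s0, s1, ...` computes (sketch):
//   elementPtr = alignedPtr + off + i0 * s0 + i1 * s1 + ...
// where `i0, i1, ...` are the access indices, all expressed in element units
// and materialized as LLVM mul/add operations followed by a GEP.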

// Store operation is lowered to obtaining a pointer to the indexed element,
// and storing the given value to it.
struct StoreOpLowering : public LoadStoreOpLowering<memref::StoreOp> {
  using Base::Base;

  LogicalResult
  matchAndRewrite(memref::StoreOp op, ArrayRef<Value> operands,
                  ConversionPatternRewriter &rewriter) const override {
    auto type = op.getMemRefType();
    memref::StoreOp::Adaptor transformed(operands);

    Value dataPtr =
        getStridedElementPtr(op.getLoc(), type, transformed.memref(),
                             transformed.indices(), rewriter);
    rewriter.replaceOpWithNewOp<LLVM::StoreOp>(op, transformed.value(),
                                               dataPtr);
    return success();
  }
};

// The prefetch operation is lowered in a way similar to the load operation
// except that the llvm.prefetch operation is used for replacement.
struct PrefetchOpLowering : public LoadStoreOpLowering<memref::PrefetchOp> {
  using Base::Base;

  LogicalResult
  matchAndRewrite(memref::PrefetchOp prefetchOp, ArrayRef<Value> operands,
                  ConversionPatternRewriter &rewriter) const override {
    memref::PrefetchOp::Adaptor transformed(operands);
    auto type = prefetchOp.getMemRefType();
    auto loc = prefetchOp.getLoc();

    Value dataPtr = getStridedElementPtr(loc, type, transformed.memref(),
                                         transformed.indices(), rewriter);

    // Replace with llvm.prefetch.
    auto llvmI32Type = typeConverter->convertType(rewriter.getIntegerType(32));
    auto isWrite = rewriter.create<LLVM::ConstantOp>(
        loc, llvmI32Type, rewriter.getI32IntegerAttr(prefetchOp.isWrite()));
    auto localityHint = rewriter.create<LLVM::ConstantOp>(
        loc, llvmI32Type,
        rewriter.getI32IntegerAttr(prefetchOp.localityHint()));
    auto isData = rewriter.create<LLVM::ConstantOp>(
        loc, llvmI32Type, rewriter.getI32IntegerAttr(prefetchOp.isDataCache()));

    rewriter.replaceOpWithNewOp<LLVM::Prefetch>(prefetchOp, dataPtr, isWrite,
                                                localityHint, isData);
    return success();
  }
};

struct MemRefCastOpLowering : public ConvertOpToLLVMPattern<memref::CastOp> {
  using ConvertOpToLLVMPattern<memref::CastOp>::ConvertOpToLLVMPattern;

  LogicalResult match(memref::CastOp memRefCastOp) const override {
    Type srcType = memRefCastOp.getOperand().getType();
    Type dstType = memRefCastOp.getType();

    // memref::CastOp reduces to a bitcast in the ranked MemRef case and can be
    // used for type erasure. For now it must preserve the underlying element
    // type and requires source and result types to have the same rank.
    // Therefore, perform a sanity check that the underlying structs are the
    // same. Once op semantics are relaxed we can revisit.
    if (srcType.isa<MemRefType>() && dstType.isa<MemRefType>())
      return success(typeConverter->convertType(srcType) ==
                     typeConverter->convertType(dstType));

    // At least one of the operands has an unranked memref type.
    assert(srcType.isa<UnrankedMemRefType>() ||
           dstType.isa<UnrankedMemRefType>());

    // Unranked-to-unranked casts are disallowed.
    return !(srcType.isa<UnrankedMemRefType>() &&
             dstType.isa<UnrankedMemRefType>())
               ? success()
               : failure();
  }

  void rewrite(memref::CastOp memRefCastOp, ArrayRef<Value> operands,
               ConversionPatternRewriter &rewriter) const override {
    memref::CastOp::Adaptor transformed(operands);

    auto srcType = memRefCastOp.getOperand().getType();
    auto dstType = memRefCastOp.getType();
    auto targetStructType = typeConverter->convertType(memRefCastOp.getType());
    auto loc = memRefCastOp.getLoc();

    // For the ranked/ranked case, just keep the original descriptor.
    if (srcType.isa<MemRefType>() && dstType.isa<MemRefType>())
      return rewriter.replaceOp(memRefCastOp, {transformed.source()});

    if (srcType.isa<MemRefType>() && dstType.isa<UnrankedMemRefType>()) {
      // Casting ranked to unranked memref type:
      //  - set the rank in the destination from the memref type;
      //  - allocate space on the stack and copy the src memref descriptor;
      //  - set the ptr in the destination to the stack space.
      auto srcMemRefType = srcType.cast<MemRefType>();
      int64_t rank = srcMemRefType.getRank();
      // ptr = AllocaOp sizeof(MemRefDescriptor)
      auto ptr = getTypeConverter()->promoteOneMemRefDescriptor(
          loc, transformed.source(), rewriter);
      // voidptr = BitCastOp srcType* to void*
      auto voidPtr =
          rewriter.create<LLVM::BitcastOp>(loc, getVoidPtrType(), ptr)
              .getResult();
      // rank = ConstantOp srcRank
      auto rankVal = rewriter.create<LLVM::ConstantOp>(
          loc, typeConverter->convertType(rewriter.getIntegerType(64)),
          rewriter.getI64IntegerAttr(rank));
      // undef = UndefOp
      UnrankedMemRefDescriptor memRefDesc =
          UnrankedMemRefDescriptor::undef(rewriter, loc, targetStructType);
      // d1 = InsertValueOp undef, rank, 0
      memRefDesc.setRank(rewriter, loc, rankVal);
      // d2 = InsertValueOp d1, voidptr, 1
      memRefDesc.setMemRefDescPtr(rewriter, loc, voidPtr);
      rewriter.replaceOp(memRefCastOp, (Value)memRefDesc);

    } else if (srcType.isa<UnrankedMemRefType>() && dstType.isa<MemRefType>()) {
      // Casting from unranked type to ranked. The operation is assumed to be
      // doing a correct cast. If the destination type mismatches the unranked
      // type, the behavior is undefined.
      UnrankedMemRefDescriptor memRefDesc(transformed.source());
      // ptr = ExtractValueOp src, 1
      auto ptr = memRefDesc.memRefDescPtr(rewriter, loc);
      // castPtr = BitCastOp i8* to structTy*
      auto castPtr =
          rewriter
              .create<LLVM::BitcastOp>(
                  loc, LLVM::LLVMPointerType::get(targetStructType), ptr)
              .getResult();
      // struct = LoadOp castPtr
      auto loadOp = rewriter.create<LLVM::LoadOp>(loc, castPtr);
      rewriter.replaceOp(memRefCastOp, loadOp.getResult());
    } else {
      llvm_unreachable("Unsupported unranked memref to unranked memref cast");
    }
  }
};
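
// memref.copy is lowered below to a call into a runtime copy routine: both
// operands are first wrapped into unranked descriptors, those descriptors are
// spilled to the stack, and their addresses (together with the element size)
// are passed to the helper declared via lookupOrCreateMemRefCopyFn (typically
// named `memrefCopy`). Roughly (sketch, not verbatim output):
//   %src = ... unranked descriptor for the source, stored on the stack ...
//   %dst = ... unranked descriptor for the target, stored on the stack ...
//   llvm.call @memrefCopy(%elemSize, %src, %dst)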

struct MemRefCopyOpLowering : public ConvertOpToLLVMPattern<memref::CopyOp> {
  using ConvertOpToLLVMPattern<memref::CopyOp>::ConvertOpToLLVMPattern;

  LogicalResult
  matchAndRewrite(memref::CopyOp op, ArrayRef<Value> operands,
                  ConversionPatternRewriter &rewriter) const override {
    auto loc = op.getLoc();
    memref::CopyOp::Adaptor adaptor(operands);
    auto srcType = op.source().getType().cast<BaseMemRefType>();
    auto targetType = op.target().getType().cast<BaseMemRefType>();

    // First make sure we have an unranked memref descriptor representation.
    auto makeUnranked = [&, this](Value ranked, BaseMemRefType type) {
      auto rank = rewriter.create<LLVM::ConstantOp>(
          loc, getIndexType(), rewriter.getIndexAttr(type.getRank()));
      auto *typeConverter = getTypeConverter();
      auto ptr =
          typeConverter->promoteOneMemRefDescriptor(loc, ranked, rewriter);
      auto voidPtr =
          rewriter.create<LLVM::BitcastOp>(loc, getVoidPtrType(), ptr)
              .getResult();
      auto unrankedType =
          UnrankedMemRefType::get(type.getElementType(), type.getMemorySpace());
      return UnrankedMemRefDescriptor::pack(rewriter, loc, *typeConverter,
                                            unrankedType,
                                            ValueRange{rank, voidPtr});
    };

    Value unrankedSource = srcType.hasRank()
                               ? makeUnranked(adaptor.source(), srcType)
                               : adaptor.source();
    Value unrankedTarget = targetType.hasRank()
                               ? makeUnranked(adaptor.target(), targetType)
                               : adaptor.target();

    // Now promote the unranked descriptors to the stack.
    auto one = rewriter.create<LLVM::ConstantOp>(loc, getIndexType(),
                                                 rewriter.getIndexAttr(1));
    auto promote = [&](Value desc) {
      auto ptrType = LLVM::LLVMPointerType::get(desc.getType());
      auto allocated =
          rewriter.create<LLVM::AllocaOp>(loc, ptrType, ValueRange{one});
      rewriter.create<LLVM::StoreOp>(loc, desc, allocated);
      return allocated;
    };

    auto sourcePtr = promote(unrankedSource);
    auto targetPtr = promote(unrankedTarget);

    auto elemSize = rewriter.create<LLVM::ConstantOp>(
        loc, getIndexType(),
        rewriter.getIndexAttr(srcType.getElementTypeBitWidth() / 8));
    auto copyFn = LLVM::lookupOrCreateMemRefCopyFn(
        op->getParentOfType<ModuleOp>(), getIndexType(), sourcePtr.getType());
    rewriter.create<LLVM::CallOp>(loc, copyFn,
                                  ValueRange{elemSize, sourcePtr, targetPtr});
    rewriter.eraseOp(op);

    return success();
  }
};

/// Extracts allocated, aligned pointers and offset from a ranked or unranked
/// memref type. In the unranked case, the fields are extracted from the
/// underlying ranked descriptor.
static void extractPointersAndOffset(Location loc,
                                     ConversionPatternRewriter &rewriter,
                                     LLVMTypeConverter &typeConverter,
                                     Value originalOperand,
                                     Value convertedOperand,
                                     Value *allocatedPtr, Value *alignedPtr,
                                     Value *offset = nullptr) {
  Type operandType = originalOperand.getType();
  if (operandType.isa<MemRefType>()) {
    MemRefDescriptor desc(convertedOperand);
    *allocatedPtr = desc.allocatedPtr(rewriter, loc);
    *alignedPtr = desc.alignedPtr(rewriter, loc);
    if (offset != nullptr)
      *offset = desc.offset(rewriter, loc);
    return;
  }

  unsigned memorySpace =
      operandType.cast<UnrankedMemRefType>().getMemorySpaceAsInt();
  Type elementType = operandType.cast<UnrankedMemRefType>().getElementType();
  Type llvmElementType = typeConverter.convertType(elementType);
  Type elementPtrPtrType = LLVM::LLVMPointerType::get(
      LLVM::LLVMPointerType::get(llvmElementType, memorySpace));

  // Extract the pointer to the underlying ranked memref descriptor and cast it
  // to ElemType**.
  UnrankedMemRefDescriptor unrankedDesc(convertedOperand);
  Value underlyingDescPtr = unrankedDesc.memRefDescPtr(rewriter, loc);

  *allocatedPtr = UnrankedMemRefDescriptor::allocatedPtr(
      rewriter, loc, underlyingDescPtr, elementPtrPtrType);
  *alignedPtr = UnrankedMemRefDescriptor::alignedPtr(
      rewriter, loc, typeConverter, underlyingDescPtr, elementPtrPtrType);
  if (offset != nullptr) {
    *offset = UnrankedMemRefDescriptor::offset(
        rewriter, loc, typeConverter, underlyingDescPtr, elementPtrPtrType);
  }
}

struct MemRefReinterpretCastOpLowering
    : public ConvertOpToLLVMPattern<memref::ReinterpretCastOp> {
  using ConvertOpToLLVMPattern<
      memref::ReinterpretCastOp>::ConvertOpToLLVMPattern;

  LogicalResult
  matchAndRewrite(memref::ReinterpretCastOp castOp, ArrayRef<Value> operands,
                  ConversionPatternRewriter &rewriter) const override {
    memref::ReinterpretCastOp::Adaptor adaptor(operands,
                                               castOp->getAttrDictionary());
    Type srcType = castOp.source().getType();

    Value descriptor;
    if (failed(convertSourceMemRefToDescriptor(rewriter, srcType, castOp,
                                               adaptor, &descriptor)))
      return failure();
    rewriter.replaceOp(castOp, {descriptor});
    return success();
  }

private:
  LogicalResult convertSourceMemRefToDescriptor(
      ConversionPatternRewriter &rewriter, Type srcType,
      memref::ReinterpretCastOp castOp,
      memref::ReinterpretCastOp::Adaptor adaptor, Value *descriptor) const {
    MemRefType targetMemRefType =
        castOp.getResult().getType().cast<MemRefType>();
    auto llvmTargetDescriptorTy = typeConverter->convertType(targetMemRefType)
                                      .dyn_cast_or_null<LLVM::LLVMStructType>();
    if (!llvmTargetDescriptorTy)
      return failure();

    // Create the descriptor.
    Location loc = castOp.getLoc();
    auto desc = MemRefDescriptor::undef(rewriter, loc, llvmTargetDescriptorTy);

    // Set allocated and aligned pointers.
    Value allocatedPtr, alignedPtr;
    extractPointersAndOffset(loc, rewriter, *getTypeConverter(),
                             castOp.source(), adaptor.source(), &allocatedPtr,
                             &alignedPtr);
    desc.setAllocatedPtr(rewriter, loc, allocatedPtr);
    desc.setAlignedPtr(rewriter, loc, alignedPtr);

    // Set offset.
    if (castOp.isDynamicOffset(0))
      desc.setOffset(rewriter, loc, adaptor.offsets()[0]);
    else
      desc.setConstantOffset(rewriter, loc, castOp.getStaticOffset(0));

    // Set sizes and strides.
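    // Static sizes and strides are materialized as constants; dynamic ones are
    // consumed from the adaptor's `sizes`/`strides` operand lists in order.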
    unsigned dynSizeId = 0;
    unsigned dynStrideId = 0;
    for (unsigned i = 0, e = targetMemRefType.getRank(); i < e; ++i) {
      if (castOp.isDynamicSize(i))
        desc.setSize(rewriter, loc, i, adaptor.sizes()[dynSizeId++]);
      else
        desc.setConstantSize(rewriter, loc, i, castOp.getStaticSize(i));

      if (castOp.isDynamicStride(i))
        desc.setStride(rewriter, loc, i, adaptor.strides()[dynStrideId++]);
      else
        desc.setConstantStride(rewriter, loc, i, castOp.getStaticStride(i));
    }
    *descriptor = desc;
    return success();
  }
};

struct MemRefReshapeOpLowering
    : public ConvertOpToLLVMPattern<memref::ReshapeOp> {
  using ConvertOpToLLVMPattern<memref::ReshapeOp>::ConvertOpToLLVMPattern;

  LogicalResult
  matchAndRewrite(memref::ReshapeOp reshapeOp, ArrayRef<Value> operands,
                  ConversionPatternRewriter &rewriter) const override {
    auto *op = reshapeOp.getOperation();
    memref::ReshapeOp::Adaptor adaptor(operands, op->getAttrDictionary());
    Type srcType = reshapeOp.source().getType();

    Value descriptor;
    if (failed(convertSourceMemRefToDescriptor(rewriter, srcType, reshapeOp,
                                               adaptor, &descriptor)))
      return failure();
    rewriter.replaceOp(op, {descriptor});
    return success();
  }

private:
  LogicalResult
  convertSourceMemRefToDescriptor(ConversionPatternRewriter &rewriter,
                                  Type srcType, memref::ReshapeOp reshapeOp,
                                  memref::ReshapeOp::Adaptor adaptor,
                                  Value *descriptor) const {
    // Conversion for statically-known shape args is performed via
    // `memref_reinterpret_cast`.
    auto shapeMemRefType = reshapeOp.shape().getType().cast<MemRefType>();
    if (shapeMemRefType.hasStaticShape())
      return failure();

    // The shape is a rank-1 tensor with unknown length.
    Location loc = reshapeOp.getLoc();
    MemRefDescriptor shapeDesc(adaptor.shape());
    Value resultRank = shapeDesc.size(rewriter, loc, 0);

    // Extract address space and element type.
    auto targetType =
        reshapeOp.getResult().getType().cast<UnrankedMemRefType>();
    unsigned addressSpace = targetType.getMemorySpaceAsInt();
    Type elementType = targetType.getElementType();

    // Create the unranked memref descriptor that holds the ranked one. The
    // inner descriptor is allocated on the stack.
    auto targetDesc = UnrankedMemRefDescriptor::undef(
        rewriter, loc, typeConverter->convertType(targetType));
    targetDesc.setRank(rewriter, loc, resultRank);
    SmallVector<Value, 4> sizes;
    UnrankedMemRefDescriptor::computeSizes(rewriter, loc, *getTypeConverter(),
                                           targetDesc, sizes);
    Value underlyingDescPtr = rewriter.create<LLVM::AllocaOp>(
        loc, getVoidPtrType(), sizes.front(), llvm::None);
    targetDesc.setMemRefDescPtr(rewriter, loc, underlyingDescPtr);

    // Extract pointers and offset from the source memref.
    Value allocatedPtr, alignedPtr, offset;
    extractPointersAndOffset(loc, rewriter, *getTypeConverter(),
                             reshapeOp.source(), adaptor.source(),
                             &allocatedPtr, &alignedPtr, &offset);

    // Set pointers and offset.
    Type llvmElementType = typeConverter->convertType(elementType);
    auto elementPtrPtrType = LLVM::LLVMPointerType::get(
        LLVM::LLVMPointerType::get(llvmElementType, addressSpace));
    UnrankedMemRefDescriptor::setAllocatedPtr(rewriter, loc, underlyingDescPtr,
                                              elementPtrPtrType, allocatedPtr);
    UnrankedMemRefDescriptor::setAlignedPtr(rewriter, loc, *getTypeConverter(),
                                            underlyingDescPtr,
                                            elementPtrPtrType, alignedPtr);
    UnrankedMemRefDescriptor::setOffset(rewriter, loc, *getTypeConverter(),
                                        underlyingDescPtr, elementPtrPtrType,
                                        offset);

    // Use the offset pointer as base for further addressing. Copy over the new
    // shape and compute strides. For this, we create a loop from rank-1 to 0.
    Value targetSizesBase = UnrankedMemRefDescriptor::sizeBasePtr(
        rewriter, loc, *getTypeConverter(), underlyingDescPtr,
        elementPtrPtrType);
    Value targetStridesBase = UnrankedMemRefDescriptor::strideBasePtr(
        rewriter, loc, *getTypeConverter(), targetSizesBase, resultRank);
    Value shapeOperandPtr = shapeDesc.alignedPtr(rewriter, loc);
    Value oneIndex = createIndexConstant(rewriter, loc, 1);
    Value resultRankMinusOne =
        rewriter.create<LLVM::SubOp>(loc, resultRank, oneIndex);

    Block *initBlock = rewriter.getInsertionBlock();
    Type indexType = getTypeConverter()->getIndexType();
    Block::iterator remainingOpsIt = std::next(rewriter.getInsertionPoint());

    Block *condBlock = rewriter.createBlock(initBlock->getParent(), {},
                                            {indexType, indexType});

    // Move the remaining initBlock ops to condBlock.
    Block *remainingBlock = rewriter.splitBlock(initBlock, remainingOpsIt);
    rewriter.mergeBlocks(remainingBlock, condBlock, ValueRange());

    rewriter.setInsertionPointToEnd(initBlock);
    rewriter.create<LLVM::BrOp>(loc, ValueRange({resultRankMinusOne, oneIndex}),
                                condBlock);
    rewriter.setInsertionPointToStart(condBlock);
    Value indexArg = condBlock->getArgument(0);
    Value strideArg = condBlock->getArgument(1);

    Value zeroIndex = createIndexConstant(rewriter, loc, 0);
    Value pred = rewriter.create<LLVM::ICmpOp>(
        loc, IntegerType::get(rewriter.getContext(), 1),
        LLVM::ICmpPredicate::sge, indexArg, zeroIndex);

    Block *bodyBlock =
        rewriter.splitBlock(condBlock, rewriter.getInsertionPoint());
    rewriter.setInsertionPointToStart(bodyBlock);

    // Copy size from shape to descriptor.
    Type llvmIndexPtrType = LLVM::LLVMPointerType::get(indexType);
    Value sizeLoadGep = rewriter.create<LLVM::GEPOp>(
        loc, llvmIndexPtrType, shapeOperandPtr, ValueRange{indexArg});
    Value size = rewriter.create<LLVM::LoadOp>(loc, sizeLoadGep);
    UnrankedMemRefDescriptor::setSize(rewriter, loc, *getTypeConverter(),
                                      targetSizesBase, indexArg, size);

    // Write stride value and compute next one.
    UnrankedMemRefDescriptor::setStride(rewriter, loc, *getTypeConverter(),
                                        targetStridesBase, indexArg, strideArg);
    Value nextStride = rewriter.create<LLVM::MulOp>(loc, strideArg, size);

    // Decrement loop counter and branch back.
    Value decrement = rewriter.create<LLVM::SubOp>(loc, indexArg, oneIndex);
    rewriter.create<LLVM::BrOp>(loc, ValueRange({decrement, nextStride}),
                                condBlock);

    Block *remainder =
        rewriter.splitBlock(bodyBlock, rewriter.getInsertionPoint());

    // Hook up the cond exit to the remainder.
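    // The resulting control flow is, roughly:
    //   initBlock:                 br condBlock(rank - 1, 1)
    //   condBlock(%idx, %stride):  cond_br (%idx >= 0), bodyBlock, remainder
    //   bodyBlock:                 copy size[%idx], store %stride,
    //                              br condBlock(%idx - 1, %stride * size[%idx])
    //   remainder:                 rest of the original block.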
    rewriter.setInsertionPointToEnd(condBlock);
    rewriter.create<LLVM::CondBrOp>(loc, pred, bodyBlock, llvm::None, remainder,
                                    llvm::None);

    // Reset position to the beginning of the new remainder block.
    rewriter.setInsertionPointToStart(remainder);

    *descriptor = targetDesc;
    return success();
  }
};

/// Conversion pattern that transforms a subview op into:
///   1. An `llvm.mlir.undef` operation to create a memref descriptor.
///   2. Updates to the descriptor to introduce the data ptr, offset, size
///      and stride.
/// The subview op is replaced by the descriptor.
struct SubViewOpLowering : public ConvertOpToLLVMPattern<memref::SubViewOp> {
  using ConvertOpToLLVMPattern<memref::SubViewOp>::ConvertOpToLLVMPattern;

  LogicalResult
  matchAndRewrite(memref::SubViewOp subViewOp, ArrayRef<Value> operands,
                  ConversionPatternRewriter &rewriter) const override {
    auto loc = subViewOp.getLoc();

    auto sourceMemRefType = subViewOp.source().getType().cast<MemRefType>();
    auto sourceElementTy =
        typeConverter->convertType(sourceMemRefType.getElementType());

    auto viewMemRefType = subViewOp.getType();
    auto inferredType =
        memref::SubViewOp::inferResultType(
            subViewOp.getSourceType(),
            extractFromI64ArrayAttr(subViewOp.static_offsets()),
            extractFromI64ArrayAttr(subViewOp.static_sizes()),
            extractFromI64ArrayAttr(subViewOp.static_strides()))
            .cast<MemRefType>();
    auto targetElementTy =
        typeConverter->convertType(viewMemRefType.getElementType());
    auto targetDescTy = typeConverter->convertType(viewMemRefType);
    if (!sourceElementTy || !targetDescTy || !targetElementTy ||
        !LLVM::isCompatibleType(sourceElementTy) ||
        !LLVM::isCompatibleType(targetElementTy) ||
        !LLVM::isCompatibleType(targetDescTy))
      return failure();

    // Extract the offset and strides from the type.
    int64_t offset;
    SmallVector<int64_t, 4> strides;
    auto successStrides = getStridesAndOffset(inferredType, strides, offset);
    if (failed(successStrides))
      return failure();

    // Create the descriptor.
    if (!LLVM::isCompatibleType(operands.front().getType()))
      return failure();
    MemRefDescriptor sourceMemRef(operands.front());
    auto targetMemRef = MemRefDescriptor::undef(rewriter, loc, targetDescTy);

    // Copy the buffer pointer from the old descriptor to the new one.
    Value extracted = sourceMemRef.allocatedPtr(rewriter, loc);
    Value bitcastPtr = rewriter.create<LLVM::BitcastOp>(
        loc,
        LLVM::LLVMPointerType::get(targetElementTy,
                                   viewMemRefType.getMemorySpaceAsInt()),
        extracted);
    targetMemRef.setAllocatedPtr(rewriter, loc, bitcastPtr);

    // Copy the aligned pointer from the old descriptor to the new one.
    extracted = sourceMemRef.alignedPtr(rewriter, loc);
    bitcastPtr = rewriter.create<LLVM::BitcastOp>(
        loc,
        LLVM::LLVMPointerType::get(targetElementTy,
                                   viewMemRefType.getMemorySpaceAsInt()),
        extracted);
    targetMemRef.setAlignedPtr(rewriter, loc, bitcastPtr);

    auto shape = viewMemRefType.getShape();
    auto inferredShape = inferredType.getShape();
    size_t inferredShapeRank = inferredShape.size();
    size_t resultShapeRank = shape.size();
    llvm::SmallDenseSet<unsigned> unusedDims =
        computeRankReductionMask(inferredShape, shape).getValue();

    // Extract strides needed to compute offset.
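    // The offset of the resulting view is derived from the source descriptor
    // as (sketch):
    //   targetOffset = srcOffset + sum_i(subviewOffset_i * srcStride_i)
    // using the source strides extracted below.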
    SmallVector<Value, 4> strideValues;
    strideValues.reserve(inferredShapeRank);
    for (unsigned i = 0; i < inferredShapeRank; ++i)
      strideValues.push_back(sourceMemRef.stride(rewriter, loc, i));

    // Offset.
    auto llvmIndexType = typeConverter->convertType(rewriter.getIndexType());
    if (!ShapedType::isDynamicStrideOrOffset(offset)) {
      targetMemRef.setConstantOffset(rewriter, loc, offset);
    } else {
      Value baseOffset = sourceMemRef.offset(rewriter, loc);
      // `inferredShapeRank` may be larger than the number of offset operands
      // because of trailing semantics. In this case, the offset is guaranteed
      // to be interpreted as 0 and we can just skip the extra dimensions.
      for (unsigned i = 0, e = std::min(inferredShapeRank,
                                        subViewOp.getMixedOffsets().size());
           i < e; ++i) {
        Value offset =
            // TODO: need OpFoldResult ODS adaptor to clean this up.
            subViewOp.isDynamicOffset(i)
                ? operands[subViewOp.getIndexOfDynamicOffset(i)]
                : rewriter.create<LLVM::ConstantOp>(
                      loc, llvmIndexType,
                      rewriter.getI64IntegerAttr(subViewOp.getStaticOffset(i)));
        Value mul = rewriter.create<LLVM::MulOp>(loc, offset, strideValues[i]);
        baseOffset = rewriter.create<LLVM::AddOp>(loc, baseOffset, mul);
      }
      targetMemRef.setOffset(rewriter, loc, baseOffset);
    }

    // Update sizes and strides.
    SmallVector<OpFoldResult> mixedSizes = subViewOp.getMixedSizes();
    SmallVector<OpFoldResult> mixedStrides = subViewOp.getMixedStrides();
    assert(mixedSizes.size() == mixedStrides.size() &&
           "expected sizes and strides of equal length");
    for (int i = inferredShapeRank - 1, j = resultShapeRank - 1;
         i >= 0 && j >= 0; --i) {
      if (unusedDims.contains(i))
        continue;

      // `i` may overflow subViewOp.getMixedSizes because of trailing
      // semantics. In this case, the size is guaranteed to be interpreted as
      // Dim and the stride as 1.
      Value size, stride;
      if (static_cast<unsigned>(i) >= mixedSizes.size()) {
        // If the static size is available, use it directly. This is similar to
        // the folding of dim(constant-op) but removes the need for dim to be
        // aware of LLVM constants and for this pass to be aware of std
        // constants.
        int64_t staticSize =
            subViewOp.source().getType().cast<MemRefType>().getShape()[i];
        if (staticSize != ShapedType::kDynamicSize) {
          size = rewriter.create<LLVM::ConstantOp>(
              loc, llvmIndexType, rewriter.getI64IntegerAttr(staticSize));
        } else {
          Value pos = rewriter.create<LLVM::ConstantOp>(
              loc, llvmIndexType, rewriter.getI64IntegerAttr(i));
          size = rewriter.create<LLVM::DialectCastOp>(
              loc, llvmIndexType,
              rewriter.create<memref::DimOp>(loc, subViewOp.source(), pos));
        }
        stride = rewriter.create<LLVM::ConstantOp>(
            loc, llvmIndexType, rewriter.getI64IntegerAttr(1));
      } else {
        // TODO: need OpFoldResult ODS adaptor to clean this up.
        size =
            subViewOp.isDynamicSize(i)
                ? operands[subViewOp.getIndexOfDynamicSize(i)]
                : rewriter.create<LLVM::ConstantOp>(
                      loc, llvmIndexType,
                      rewriter.getI64IntegerAttr(subViewOp.getStaticSize(i)));
        if (!ShapedType::isDynamicStrideOrOffset(strides[i])) {
          stride = rewriter.create<LLVM::ConstantOp>(
              loc, llvmIndexType, rewriter.getI64IntegerAttr(strides[i]));
        } else {
          stride = subViewOp.isDynamicStride(i)
                       ? operands[subViewOp.getIndexOfDynamicStride(i)]
                       : rewriter.create<LLVM::ConstantOp>(
                             loc, llvmIndexType,
                             rewriter.getI64IntegerAttr(
                                 subViewOp.getStaticStride(i)));
          stride = rewriter.create<LLVM::MulOp>(loc, stride, strideValues[i]);
        }
      }
      targetMemRef.setSize(rewriter, loc, j, size);
      targetMemRef.setStride(rewriter, loc, j, stride);
      j--;
    }

    rewriter.replaceOp(subViewOp, {targetMemRef});
    return success();
  }
};

/// Conversion pattern that transforms a transpose op into a new memref
/// descriptor whose sizes and strides are a permutation of the original ones,
/// as prescribed by the op's permutation map. The transpose op is replaced by
/// the descriptor.
class TransposeOpLowering : public ConvertOpToLLVMPattern<memref::TransposeOp> {
public:
  using ConvertOpToLLVMPattern<memref::TransposeOp>::ConvertOpToLLVMPattern;

  LogicalResult
  matchAndRewrite(memref::TransposeOp transposeOp, ArrayRef<Value> operands,
                  ConversionPatternRewriter &rewriter) const override {
    auto loc = transposeOp.getLoc();
    memref::TransposeOpAdaptor adaptor(operands);
    MemRefDescriptor viewMemRef(adaptor.in());

    // No permutation, early exit.
    if (transposeOp.permutation().isIdentity())
      return rewriter.replaceOp(transposeOp, {viewMemRef}), success();

    auto targetMemRef = MemRefDescriptor::undef(
        rewriter, loc, typeConverter->convertType(transposeOp.getShapedType()));

    // Copy the base and aligned pointers from the old descriptor to the new
    // one.
    targetMemRef.setAllocatedPtr(rewriter, loc,
                                 viewMemRef.allocatedPtr(rewriter, loc));
    targetMemRef.setAlignedPtr(rewriter, loc,
                               viewMemRef.alignedPtr(rewriter, loc));

    // Copy the offset from the old descriptor to the new one.
    targetMemRef.setOffset(rewriter, loc, viewMemRef.offset(rewriter, loc));

    // Iterate over the dimensions and apply size/stride permutation.
    for (auto en : llvm::enumerate(transposeOp.permutation().getResults())) {
      int sourcePos = en.index();
      int targetPos = en.value().cast<AffineDimExpr>().getPosition();
      targetMemRef.setSize(rewriter, loc, targetPos,
                           viewMemRef.size(rewriter, loc, sourcePos));
      targetMemRef.setStride(rewriter, loc, targetPos,
                             viewMemRef.stride(rewriter, loc, sourcePos));
    }

    rewriter.replaceOp(transposeOp, {targetMemRef});
    return success();
  }
};

/// Conversion pattern that transforms a view op into:
///   1. An `llvm.mlir.undef` operation to create a memref descriptor.
///   2. Updates to the descriptor to introduce the data ptr, offset, size
///      and stride.
/// The view op is replaced by the descriptor.
struct ViewOpLowering : public ConvertOpToLLVMPattern<memref::ViewOp> {
  using ConvertOpToLLVMPattern<memref::ViewOp>::ConvertOpToLLVMPattern;

  // Build and return the value for the idx^th shape dimension, either by
  // returning the constant shape dimension or counting the proper dynamic
  // size.
  Value getSize(ConversionPatternRewriter &rewriter, Location loc,
                ArrayRef<int64_t> shape, ValueRange dynamicSizes,
                unsigned idx) const {
    assert(idx < shape.size());
    if (!ShapedType::isDynamic(shape[idx]))
      return createIndexConstant(rewriter, loc, shape[idx]);
    // Count the number of dynamic dims in range [0, idx].
    unsigned nDynamic = llvm::count_if(shape.take_front(idx), [](int64_t v) {
      return ShapedType::isDynamic(v);
    });
    return dynamicSizes[nDynamic];
  }

  // Build and return the idx^th stride, either by returning the constant
  // stride or by computing the dynamic stride from the current
  // `runningStride` and `nextSize`. The caller should keep a running stride
  // and update it with the result returned by this function.
  Value getStride(ConversionPatternRewriter &rewriter, Location loc,
                  ArrayRef<int64_t> strides, Value nextSize,
                  Value runningStride, unsigned idx) const {
    assert(idx < strides.size());
    if (!MemRefType::isDynamicStrideOrOffset(strides[idx]))
      return createIndexConstant(rewriter, loc, strides[idx]);
    if (nextSize)
      return runningStride
                 ? rewriter.create<LLVM::MulOp>(loc, runningStride, nextSize)
                 : nextSize;
    assert(!runningStride);
    return createIndexConstant(rewriter, loc, 1);
  }

  LogicalResult
  matchAndRewrite(memref::ViewOp viewOp, ArrayRef<Value> operands,
                  ConversionPatternRewriter &rewriter) const override {
    auto loc = viewOp.getLoc();
    memref::ViewOpAdaptor adaptor(operands);

    auto viewMemRefType = viewOp.getType();
    auto targetElementTy =
        typeConverter->convertType(viewMemRefType.getElementType());
    auto targetDescTy = typeConverter->convertType(viewMemRefType);
    if (!targetDescTy || !targetElementTy ||
        !LLVM::isCompatibleType(targetElementTy) ||
        !LLVM::isCompatibleType(targetDescTy))
      return viewOp.emitWarning("Target descriptor type not converted to LLVM"),
             failure();

    int64_t offset;
    SmallVector<int64_t, 4> strides;
    auto successStrides = getStridesAndOffset(viewMemRefType, strides, offset);
    if (failed(successStrides))
      return viewOp.emitWarning("cannot cast to non-strided shape"), failure();
    assert(offset == 0 && "expected offset to be 0");

    // Create the descriptor.
    MemRefDescriptor sourceMemRef(adaptor.source());
    auto targetMemRef = MemRefDescriptor::undef(rewriter, loc, targetDescTy);

    // Field 1: Copy the allocated pointer, used for malloc/free.
    Value allocatedPtr = sourceMemRef.allocatedPtr(rewriter, loc);
    auto srcMemRefType = viewOp.source().getType().cast<MemRefType>();
    Value bitcastPtr = rewriter.create<LLVM::BitcastOp>(
        loc,
        LLVM::LLVMPointerType::get(targetElementTy,
                                   srcMemRefType.getMemorySpaceAsInt()),
        allocatedPtr);
    targetMemRef.setAllocatedPtr(rewriter, loc, bitcastPtr);

    // Field 2: Copy the actual aligned pointer to payload.
    Value alignedPtr = sourceMemRef.alignedPtr(rewriter, loc);
    alignedPtr = rewriter.create<LLVM::GEPOp>(loc, alignedPtr.getType(),
                                              alignedPtr, adaptor.byte_shift());
    bitcastPtr = rewriter.create<LLVM::BitcastOp>(
        loc,
        LLVM::LLVMPointerType::get(targetElementTy,
                                   srcMemRefType.getMemorySpaceAsInt()),
        alignedPtr);
    targetMemRef.setAlignedPtr(rewriter, loc, bitcastPtr);

    // Field 3: The offset in the resulting type must be 0. This is because of
    // the type change: an offset on srcType* may not be expressible as an
    // offset on dstType*.
    targetMemRef.setOffset(rewriter, loc,
                           createIndexConstant(rewriter, loc, offset));

    // Early exit for the 0-D corner case.
    if (viewMemRefType.getRank() == 0)
      return rewriter.replaceOp(viewOp, {targetMemRef}), success();

    // Fields 4 and 5: Update sizes and strides.
    if (strides.back() != 1)
      return viewOp.emitWarning("cannot cast to non-contiguous shape"),
             failure();
    Value stride = nullptr, nextSize = nullptr;
    for (int i = viewMemRefType.getRank() - 1; i >= 0; --i) {
      // Update size.
      Value size =
          getSize(rewriter, loc, viewMemRefType.getShape(), adaptor.sizes(), i);
      targetMemRef.setSize(rewriter, loc, i, size);
      // Update stride.
      stride = getStride(rewriter, loc, strides, nextSize, stride, i);
      targetMemRef.setStride(rewriter, loc, i, stride);
      nextSize = size;
    }

    rewriter.replaceOp(viewOp, {targetMemRef});
    return success();
  }
};

} // namespace

void mlir::populateMemRefToLLVMConversionPatterns(LLVMTypeConverter &converter,
                                                  RewritePatternSet &patterns) {
  // clang-format off
  patterns.add<
      AllocaOpLowering,
      AllocaScopeOpLowering,
      AssumeAlignmentOpLowering,
      DimOpLowering,
      DeallocOpLowering,
      GlobalMemrefOpLowering,
      GetGlobalMemrefOpLowering,
      LoadOpLowering,
      MemRefCastOpLowering,
      MemRefCopyOpLowering,
      MemRefReinterpretCastOpLowering,
      MemRefReshapeOpLowering,
      PrefetchOpLowering,
      StoreOpLowering,
      SubViewOpLowering,
      TransposeOpLowering,
      ViewOpLowering>(converter);
  // clang-format on
  auto allocLowering = converter.getOptions().allocLowering;
  if (allocLowering == LowerToLLVMOptions::AllocLowering::AlignedAlloc)
    patterns.add<AlignedAllocOpLowering, DeallocOpLowering>(converter);
  else if (allocLowering == LowerToLLVMOptions::AllocLowering::Malloc)
    patterns.add<AllocOpLowering, DeallocOpLowering>(converter);
}

namespace {
struct MemRefToLLVMPass : public ConvertMemRefToLLVMBase<MemRefToLLVMPass> {
  MemRefToLLVMPass() = default;

  void runOnOperation() override {
    Operation *op = getOperation();
    const auto &dataLayoutAnalysis = getAnalysis<DataLayoutAnalysis>();
    LowerToLLVMOptions options(&getContext(),
                               dataLayoutAnalysis.getAtOrAbove(op));
    options.allocLowering =
        (useAlignedAlloc ? LowerToLLVMOptions::AllocLowering::AlignedAlloc
                         : LowerToLLVMOptions::AllocLowering::Malloc);
    if (indexBitwidth != kDeriveIndexBitwidthFromDataLayout)
      options.overrideIndexBitwidth(indexBitwidth);

    LLVMTypeConverter typeConverter(&getContext(), options,
                                    &dataLayoutAnalysis);
    RewritePatternSet patterns(&getContext());
    populateMemRefToLLVMConversionPatterns(typeConverter, patterns);
    LLVMConversionTarget target(getContext());
    target.addLegalOp<LLVM::DialectCastOp, FuncOp>();
    if (failed(applyPartialConversion(op, target, std::move(patterns))))
      signalPassFailure();
  }
};
} // namespace

std::unique_ptr<Pass> mlir::createMemRefToLLVMPass() {
  return std::make_unique<MemRefToLLVMPass>();
}