//===- MemRefToLLVM.cpp - MemRef to LLVM dialect conversion --------------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//

#include "mlir/Conversion/MemRefToLLVM/MemRefToLLVM.h"
#include "../PassDetail.h"
#include "mlir/Analysis/DataLayoutAnalysis.h"
#include "mlir/Conversion/LLVMCommon/ConversionTarget.h"
#include "mlir/Conversion/LLVMCommon/Pattern.h"
#include "mlir/Conversion/LLVMCommon/TypeConverter.h"
#include "mlir/Conversion/MemRefToLLVM/AllocLikeConversion.h"
#include "mlir/Dialect/LLVMIR/FunctionCallUtils.h"
#include "mlir/Dialect/LLVMIR/LLVMDialect.h"
#include "mlir/Dialect/MemRef/IR/MemRef.h"
#include "mlir/IR/AffineMap.h"
#include "mlir/IR/BlockAndValueMapping.h"

using namespace mlir;

namespace {

struct AllocOpLowering : public AllocLikeOpLLVMLowering {
  AllocOpLowering(LLVMTypeConverter &converter)
      : AllocLikeOpLLVMLowering(memref::AllocOp::getOperationName(),
                                converter) {}

  std::tuple<Value, Value> allocateBuffer(ConversionPatternRewriter &rewriter,
                                          Location loc, Value sizeBytes,
                                          Operation *op) const override {
    // Heap allocations.
    memref::AllocOp allocOp = cast<memref::AllocOp>(op);
    MemRefType memRefType = allocOp.getType();

    Value alignment;
    if (auto alignmentAttr = allocOp.alignment()) {
      alignment = createIndexConstant(rewriter, loc, *alignmentAttr);
    } else if (!memRefType.getElementType().isSignlessIntOrIndexOrFloat()) {
      // In the case where no alignment is specified, we may want to override
      // `malloc`'s behavior. `malloc` typically aligns at the size of the
      // biggest scalar on a target HW. For non-scalars, use the natural
      // alignment of the LLVM type given by the LLVM DataLayout.
      alignment = getSizeInBytes(loc, memRefType.getElementType(), rewriter);
    }

    if (alignment) {
      // Adjust the allocation size to consider alignment.
      sizeBytes = rewriter.create<LLVM::AddOp>(loc, sizeBytes, alignment);
    }

    // Allocate the underlying buffer and store a pointer to it in the MemRef
    // descriptor.
    Type elementPtrType = this->getElementPtrType(memRefType);
    auto allocFuncOp = LLVM::lookupOrCreateMallocFn(
        allocOp->getParentOfType<ModuleOp>(), getIndexType());
    auto results = createLLVMCall(rewriter, loc, allocFuncOp, {sizeBytes},
                                  getVoidPtrType());
    Value allocatedPtr =
        rewriter.create<LLVM::BitcastOp>(loc, elementPtrType, results[0]);

    Value alignedPtr = allocatedPtr;
    if (alignment) {
      // Compute the aligned pointer.
      Value allocatedInt =
          rewriter.create<LLVM::PtrToIntOp>(loc, getIndexType(), allocatedPtr);
      Value alignmentInt =
          createAligned(rewriter, loc, allocatedInt, alignment);
      alignedPtr =
          rewriter.create<LLVM::IntToPtrOp>(loc, elementPtrType, alignmentInt);
    }

    return std::make_tuple(allocatedPtr, alignedPtr);
  }
};

struct AlignedAllocOpLowering : public AllocLikeOpLLVMLowering {
  AlignedAllocOpLowering(LLVMTypeConverter &converter)
      : AllocLikeOpLLVMLowering(memref::AllocOp::getOperationName(),
                                converter) {}

  /// Returns the memref's element size in bytes using the data layout active
  /// at `op`.
  // TODO: there are other places where this is used. Expose publicly?
  unsigned getMemRefEltSizeInBytes(MemRefType memRefType, Operation *op) const {
    const DataLayout *layout = &defaultLayout;
    if (const DataLayoutAnalysis *analysis =
            getTypeConverter()->getDataLayoutAnalysis()) {
      layout = &analysis->getAbove(op);
    }
    Type elementType = memRefType.getElementType();
    if (auto memRefElementType = elementType.dyn_cast<MemRefType>())
      return getTypeConverter()->getMemRefDescriptorSize(memRefElementType,
                                                         *layout);
    if (auto memRefElementType = elementType.dyn_cast<UnrankedMemRefType>())
      return getTypeConverter()->getUnrankedMemRefDescriptorSize(
          memRefElementType, *layout);
    return layout->getTypeSize(elementType);
  }

  /// Returns true if the memref size in bytes is known to be a multiple of
  /// `factor`, assuming the data layout active at `op`.
  bool isMemRefSizeMultipleOf(MemRefType type, uint64_t factor,
                              Operation *op) const {
    uint64_t sizeDivisor = getMemRefEltSizeInBytes(type, op);
    for (unsigned i = 0, e = type.getRank(); i < e; i++) {
      if (type.isDynamic(type.getDimSize(i)))
        continue;
      sizeDivisor = sizeDivisor * type.getDimSize(i);
    }
    return sizeDivisor % factor == 0;
  }

  /// Returns the alignment to be used for the allocation call itself.
  /// aligned_alloc requires the alignment to be a power of two and the
  /// allocation size to be a multiple of that alignment.
  int64_t getAllocationAlignment(memref::AllocOp allocOp) const {
    if (Optional<uint64_t> alignment = allocOp.alignment())
      return *alignment;

    // Whenever we don't have an alignment set, we will use an alignment
    // consistent with the element type; since the alignment has to be a power
    // of two, we bump the element size to the next power of two if it isn't
    // one already.
    auto eltSizeBytes = getMemRefEltSizeInBytes(allocOp.getType(), allocOp);
    return std::max(kMinAlignedAllocAlignment,
                    llvm::PowerOf2Ceil(eltSizeBytes));
  }

  std::tuple<Value, Value> allocateBuffer(ConversionPatternRewriter &rewriter,
                                          Location loc, Value sizeBytes,
                                          Operation *op) const override {
    // Heap allocations.
    memref::AllocOp allocOp = cast<memref::AllocOp>(op);
    MemRefType memRefType = allocOp.getType();
    int64_t alignment = getAllocationAlignment(allocOp);
    Value allocAlignment = createIndexConstant(rewriter, loc, alignment);

    // aligned_alloc requires the size to be a multiple of the alignment; pad
    // the size to the next multiple if necessary.
    if (!isMemRefSizeMultipleOf(memRefType, alignment, op))
      sizeBytes = createAligned(rewriter, loc, sizeBytes, allocAlignment);

    Type elementPtrType = this->getElementPtrType(memRefType);
    auto allocFuncOp = LLVM::lookupOrCreateAlignedAllocFn(
        allocOp->getParentOfType<ModuleOp>(), getIndexType());
    auto results =
        createLLVMCall(rewriter, loc, allocFuncOp, {allocAlignment, sizeBytes},
                       getVoidPtrType());
    Value allocatedPtr =
        rewriter.create<LLVM::BitcastOp>(loc, elementPtrType, results[0]);

    return std::make_tuple(allocatedPtr, allocatedPtr);
  }

  /// The minimum alignment to use with aligned_alloc (has to be a power of 2).
  static constexpr uint64_t kMinAlignedAllocAlignment = 16UL;

  /// Default layout to use in absence of the corresponding analysis.
  DataLayout defaultLayout;
};

// Out of line definition, required till C++17.
constexpr uint64_t AlignedAllocOpLowering::kMinAlignedAllocAlignment;

struct AllocaOpLowering : public AllocLikeOpLLVMLowering {
  AllocaOpLowering(LLVMTypeConverter &converter)
      : AllocLikeOpLLVMLowering(memref::AllocaOp::getOperationName(),
                                converter) {}

  /// Allocates the underlying buffer using an llvm.alloca operation. For stack
  /// allocations the allocated and aligned pointers are the same, and any
  /// requested alignment is passed directly to the alloca.
  std::tuple<Value, Value> allocateBuffer(ConversionPatternRewriter &rewriter,
                                          Location loc, Value sizeBytes,
                                          Operation *op) const override {

    // With alloca, one gets a pointer to the element type right away; this is
    // a stack allocation.
    auto allocaOp = cast<memref::AllocaOp>(op);
    auto elementPtrType = this->getElementPtrType(allocaOp.getType());

    auto allocatedElementPtr = rewriter.create<LLVM::AllocaOp>(
        loc, elementPtrType, sizeBytes,
        allocaOp.alignment() ? *allocaOp.alignment() : 0);

    return std::make_tuple(allocatedElementPtr, allocatedElementPtr);
  }
};

struct AllocaScopeOpLowering
    : public ConvertOpToLLVMPattern<memref::AllocaScopeOp> {
  using ConvertOpToLLVMPattern<memref::AllocaScopeOp>::ConvertOpToLLVMPattern;

  LogicalResult
  matchAndRewrite(memref::AllocaScopeOp allocaScopeOp, ArrayRef<Value> operands,
                  ConversionPatternRewriter &rewriter) const override {
    OpBuilder::InsertionGuard guard(rewriter);
    Location loc = allocaScopeOp.getLoc();

    // Split the current block before the AllocaScopeOp to create the inlining
    // point.
    auto *currentBlock = rewriter.getInsertionBlock();
    auto *remainingOpsBlock =
        rewriter.splitBlock(currentBlock, rewriter.getInsertionPoint());
    Block *continueBlock;
    if (allocaScopeOp.getNumResults() == 0) {
      continueBlock = remainingOpsBlock;
    } else {
      continueBlock = rewriter.createBlock(remainingOpsBlock,
                                           allocaScopeOp.getResultTypes());
      rewriter.create<LLVM::BrOp>(loc, ValueRange(), remainingOpsBlock);
    }

    // Inline body region.
    Block *beforeBody = &allocaScopeOp.bodyRegion().front();
    Block *afterBody = &allocaScopeOp.bodyRegion().back();
    rewriter.inlineRegionBefore(allocaScopeOp.bodyRegion(), continueBlock);

    // Save stack and then branch into the body of the region.
    rewriter.setInsertionPointToEnd(currentBlock);
    auto stackSaveOp =
        rewriter.create<LLVM::StackSaveOp>(loc, getVoidPtrType());
    rewriter.create<LLVM::BrOp>(loc, ValueRange(), beforeBody);

    // Replace the alloca_scope return with a branch that jumps out of the
    // body. Stack restore before leaving the body region.
    rewriter.setInsertionPointToEnd(afterBody);
    auto returnOp =
        cast<memref::AllocaScopeReturnOp>(afterBody->getTerminator());
    auto branchOp = rewriter.replaceOpWithNewOp<LLVM::BrOp>(
        returnOp, returnOp.results(), continueBlock);

    // Insert the stack restore before jumping out of the body of the region.
    rewriter.setInsertionPoint(branchOp);
    rewriter.create<LLVM::StackRestoreOp>(loc, stackSaveOp);

    // Replace the op with the values returned from the body region.
    rewriter.replaceOp(allocaScopeOp, continueBlock->getArguments());

    return success();
  }
};

struct AssumeAlignmentOpLowering
    : public ConvertOpToLLVMPattern<memref::AssumeAlignmentOp> {
  using ConvertOpToLLVMPattern<
      memref::AssumeAlignmentOp>::ConvertOpToLLVMPattern;

  LogicalResult
  matchAndRewrite(memref::AssumeAlignmentOp op, ArrayRef<Value> operands,
                  ConversionPatternRewriter &rewriter) const override {
    memref::AssumeAlignmentOp::Adaptor transformed(operands);
    Value memref = transformed.memref();
    unsigned alignment = op.alignment();
    auto loc = op.getLoc();

    MemRefDescriptor memRefDescriptor(memref);
    Value ptr = memRefDescriptor.alignedPtr(rewriter, memref.getLoc());

    // Emit llvm.assume(memref.alignedPtr & (alignment - 1) == 0). Notice that
    // the asserted memref.alignedPtr isn't used anywhere else, as the real
    // users like load/store/views always re-extract memref.alignedPtr as they
    // get lowered.
    //
    // This relies on LLVM's CSE optimization (potentially after SROA), since
    // after CSE all memref.alignedPtr instances get de-duplicated into the
    // same pointer SSA value.
    auto intPtrType =
        getIntPtrType(memRefDescriptor.getElementPtrType().getAddressSpace());
    Value zero = createIndexAttrConstant(rewriter, loc, intPtrType, 0);
    Value mask =
        createIndexAttrConstant(rewriter, loc, intPtrType, alignment - 1);
    Value ptrValue = rewriter.create<LLVM::PtrToIntOp>(loc, intPtrType, ptr);
    rewriter.create<LLVM::AssumeOp>(
        loc, rewriter.create<LLVM::ICmpOp>(
                 loc, LLVM::ICmpPredicate::eq,
                 rewriter.create<LLVM::AndOp>(loc, ptrValue, mask), zero));

    rewriter.eraseOp(op);
    return success();
  }
};

// A `dealloc` is converted into a call to `free` on the underlying data
// buffer. The memref descriptor being an SSA value, there is no need to clean
// it up in any way.
struct DeallocOpLowering : public ConvertOpToLLVMPattern<memref::DeallocOp> {
  using ConvertOpToLLVMPattern<memref::DeallocOp>::ConvertOpToLLVMPattern;

  explicit DeallocOpLowering(LLVMTypeConverter &converter)
      : ConvertOpToLLVMPattern<memref::DeallocOp>(converter) {}

  LogicalResult
  matchAndRewrite(memref::DeallocOp op, ArrayRef<Value> operands,
                  ConversionPatternRewriter &rewriter) const override {
    assert(operands.size() == 1 && "dealloc takes one operand");
    memref::DeallocOp::Adaptor transformed(operands);

    // Insert the `free` declaration if it is not already present.
    auto freeFunc = LLVM::lookupOrCreateFreeFn(op->getParentOfType<ModuleOp>());
    MemRefDescriptor memref(transformed.memref());
    Value casted = rewriter.create<LLVM::BitcastOp>(
        op.getLoc(), getVoidPtrType(),
        memref.allocatedPtr(rewriter, op.getLoc()));
    rewriter.replaceOpWithNewOp<LLVM::CallOp>(
        op, TypeRange(), rewriter.getSymbolRefAttr(freeFunc), casted);
    return success();
  }
};

// A `dim` is converted to a constant for static sizes and to an access to the
// size stored in the memref descriptor for dynamic sizes.
struct DimOpLowering : public ConvertOpToLLVMPattern<memref::DimOp> {
  using ConvertOpToLLVMPattern<memref::DimOp>::ConvertOpToLLVMPattern;

  LogicalResult
  matchAndRewrite(memref::DimOp dimOp, ArrayRef<Value> operands,
                  ConversionPatternRewriter &rewriter) const override {
    Type operandType = dimOp.source().getType();
    if (operandType.isa<UnrankedMemRefType>()) {
      rewriter.replaceOp(dimOp, {extractSizeOfUnrankedMemRef(
                                    operandType, dimOp, operands, rewriter)});

      return success();
    }
    if (operandType.isa<MemRefType>()) {
      rewriter.replaceOp(dimOp, {extractSizeOfRankedMemRef(
                                    operandType, dimOp, operands, rewriter)});
      return success();
    }
    llvm_unreachable("expected MemRefType or UnrankedMemRefType");
  }

private:
  Value extractSizeOfUnrankedMemRef(Type operandType, memref::DimOp dimOp,
                                    ArrayRef<Value> operands,
                                    ConversionPatternRewriter &rewriter) const {
    Location loc = dimOp.getLoc();
    memref::DimOp::Adaptor transformed(operands);

    auto unrankedMemRefType = operandType.cast<UnrankedMemRefType>();
    auto scalarMemRefType =
        MemRefType::get({}, unrankedMemRefType.getElementType());
    unsigned addressSpace = unrankedMemRefType.getMemorySpaceAsInt();

    // Extract pointer to the underlying ranked descriptor and bitcast it to a
    // memref<element_type> descriptor pointer to minimize the number of GEP
    // operations.
    UnrankedMemRefDescriptor unrankedDesc(transformed.source());
    Value underlyingRankedDesc = unrankedDesc.memRefDescPtr(rewriter, loc);
    Value scalarMemRefDescPtr = rewriter.create<LLVM::BitcastOp>(
        loc,
        LLVM::LLVMPointerType::get(typeConverter->convertType(scalarMemRefType),
                                   addressSpace),
        underlyingRankedDesc);

    // Get pointer to offset field of memref<element_type> descriptor.
    Type indexPtrTy = LLVM::LLVMPointerType::get(
        getTypeConverter()->getIndexType(), addressSpace);
    Value two = rewriter.create<LLVM::ConstantOp>(
        loc, typeConverter->convertType(rewriter.getI32Type()),
        rewriter.getI32IntegerAttr(2));
    Value offsetPtr = rewriter.create<LLVM::GEPOp>(
        loc, indexPtrTy, scalarMemRefDescPtr,
        ValueRange({createIndexConstant(rewriter, loc, 0), two}));

    // The size value that we have to extract can be obtained using GEPOp with
    // `dimOp.index() + 1` as the index argument.
    Value idxPlusOne = rewriter.create<LLVM::AddOp>(
        loc, createIndexConstant(rewriter, loc, 1), transformed.index());
    Value sizePtr = rewriter.create<LLVM::GEPOp>(loc, indexPtrTy, offsetPtr,
                                                 ValueRange({idxPlusOne}));
    return rewriter.create<LLVM::LoadOp>(loc, sizePtr);
  }

  Optional<int64_t> getConstantDimIndex(memref::DimOp dimOp) const {
    if (Optional<int64_t> idx = dimOp.getConstantIndex())
      return idx;

    if (auto constantOp = dimOp.index().getDefiningOp<LLVM::ConstantOp>())
      return constantOp.value().cast<IntegerAttr>().getValue().getSExtValue();

    return llvm::None;
  }

  Value extractSizeOfRankedMemRef(Type operandType, memref::DimOp dimOp,
                                  ArrayRef<Value> operands,
                                  ConversionPatternRewriter &rewriter) const {
    Location loc = dimOp.getLoc();
    memref::DimOp::Adaptor transformed(operands);
    // Take advantage if the index is constant.
    MemRefType memRefType = operandType.cast<MemRefType>();
    if (Optional<int64_t> index = getConstantDimIndex(dimOp)) {
      int64_t i = index.getValue();
      if (memRefType.isDynamicDim(i)) {
        // Extract the dynamic size from the memref descriptor.
        MemRefDescriptor descriptor(transformed.source());
        return descriptor.size(rewriter, loc, i);
      }
      // Use constant for static size.
      int64_t dimSize = memRefType.getDimSize(i);
      return createIndexConstant(rewriter, loc, dimSize);
    }
    Value index = transformed.index();
    int64_t rank = memRefType.getRank();
    MemRefDescriptor memrefDescriptor(transformed.source());
    return memrefDescriptor.size(rewriter, loc, index, rank);
  }
};

/// Returns the LLVM type of the global variable given the memref type `type`.
static Type convertGlobalMemrefTypeToLLVM(MemRefType type,
                                          LLVMTypeConverter &typeConverter) {
  // The LLVM type for a global memref will be a multi-dimensional array. For
  // declarations or uninitialized global memrefs, we can potentially flatten
  // this to a 1D array. However, for memref.global's with an initial value,
  // we do not intend to flatten the ElementsAttribute when going from std ->
  // LLVM dialect, so the LLVM type needs to be a multi-dimensional array.
  Type elementType = typeConverter.convertType(type.getElementType());
  Type arrayTy = elementType;
  // Shape has the outermost dim at index 0, so we need to walk it backwards.
  for (int64_t dim : llvm::reverse(type.getShape()))
    arrayTy = LLVM::LLVMArrayType::get(arrayTy, dim);
  return arrayTy;
}

/// GlobalMemrefOp is lowered to an LLVM global variable.
struct GlobalMemrefOpLowering
    : public ConvertOpToLLVMPattern<memref::GlobalOp> {
  using ConvertOpToLLVMPattern<memref::GlobalOp>::ConvertOpToLLVMPattern;

  LogicalResult
  matchAndRewrite(memref::GlobalOp global, ArrayRef<Value> operands,
                  ConversionPatternRewriter &rewriter) const override {
    MemRefType type = global.type().cast<MemRefType>();
    if (!isConvertibleAndHasIdentityMaps(type))
      return failure();

    Type arrayTy = convertGlobalMemrefTypeToLLVM(type, *getTypeConverter());

    LLVM::Linkage linkage =
        global.isPublic() ? LLVM::Linkage::External : LLVM::Linkage::Private;

    Attribute initialValue = nullptr;
    if (!global.isExternal() && !global.isUninitialized()) {
      auto elementsAttr = global.initial_value()->cast<ElementsAttr>();
      initialValue = elementsAttr;

      // For scalar memrefs, the global variable created is of the element
      // type, so unpack the elements attribute to extract the value.
      if (type.getRank() == 0)
        initialValue = elementsAttr.getValue({});
    }

    rewriter.replaceOpWithNewOp<LLVM::GlobalOp>(
        global, arrayTy, global.constant(), linkage, global.sym_name(),
        initialValue, /*alignment=*/0, type.getMemorySpaceAsInt());
    return success();
  }
};

/// GetGlobalMemrefOp is lowered into a MemRef descriptor with the pointer to
/// the first element stashed into the descriptor. This relies on
/// `AllocLikeOpLowering` for the MemRef descriptor construction.
struct GetGlobalMemrefOpLowering : public AllocLikeOpLLVMLowering {
  GetGlobalMemrefOpLowering(LLVMTypeConverter &converter)
      : AllocLikeOpLLVMLowering(memref::GetGlobalOp::getOperationName(),
                                converter) {}

  /// Buffer "allocation" for the memref.get_global op is getting the address
  /// of the global variable referenced.
  std::tuple<Value, Value> allocateBuffer(ConversionPatternRewriter &rewriter,
                                          Location loc, Value sizeBytes,
                                          Operation *op) const override {
    auto getGlobalOp = cast<memref::GetGlobalOp>(op);
    MemRefType type = getGlobalOp.result().getType().cast<MemRefType>();
    unsigned memSpace = type.getMemorySpaceAsInt();

    Type arrayTy = convertGlobalMemrefTypeToLLVM(type, *getTypeConverter());
    auto addressOf = rewriter.create<LLVM::AddressOfOp>(
        loc, LLVM::LLVMPointerType::get(arrayTy, memSpace), getGlobalOp.name());

    // Get the address of the first element in the array by creating a GEP with
    // the address of the GV as the base, and (rank + 1) number of 0 indices.
    Type elementType = typeConverter->convertType(type.getElementType());
    Type elementPtrType = LLVM::LLVMPointerType::get(elementType, memSpace);

    SmallVector<Value, 4> operands = {addressOf};
    operands.insert(operands.end(), type.getRank() + 1,
                    createIndexConstant(rewriter, loc, 0));
    auto gep = rewriter.create<LLVM::GEPOp>(loc, elementPtrType, operands);

    // We do not expect the memref obtained using `memref.get_global` to ever
    // be deallocated. Set the allocated pointer to a known bad value to help
    // debug if that ever happens.
    auto intPtrType = getIntPtrType(memSpace);
    Value deadBeefConst =
        createIndexAttrConstant(rewriter, op->getLoc(), intPtrType, 0xdeadbeef);
    auto deadBeefPtr =
        rewriter.create<LLVM::IntToPtrOp>(loc, elementPtrType, deadBeefConst);

    // Both allocated and aligned pointers are the same. We could potentially
    // stash a nullptr for the allocated pointer since we do not expect any
    // dealloc.
    return std::make_tuple(deadBeefPtr, gep);
  }
};

// Common base for load and store operations on MemRefs. Restricts the match
// to supported MemRef types. Provides functionality to emit code accessing a
// specific element of the underlying data buffer.
template <typename Derived>
struct LoadStoreOpLowering : public ConvertOpToLLVMPattern<Derived> {
  using ConvertOpToLLVMPattern<Derived>::ConvertOpToLLVMPattern;
  using ConvertOpToLLVMPattern<Derived>::isConvertibleAndHasIdentityMaps;
  using Base = LoadStoreOpLowering<Derived>;

  LogicalResult match(Derived op) const override {
    MemRefType type = op.getMemRefType();
    return isConvertibleAndHasIdentityMaps(type) ? success() : failure();
  }
};

// Load operation is lowered to obtaining a pointer to the indexed element
// and loading it.
struct LoadOpLowering : public LoadStoreOpLowering<memref::LoadOp> {
  using Base::Base;

  LogicalResult
  matchAndRewrite(memref::LoadOp loadOp, ArrayRef<Value> operands,
                  ConversionPatternRewriter &rewriter) const override {
    memref::LoadOp::Adaptor transformed(operands);
    auto type = loadOp.getMemRefType();

    Value dataPtr =
        getStridedElementPtr(loadOp.getLoc(), type, transformed.memref(),
                             transformed.indices(), rewriter);
    rewriter.replaceOpWithNewOp<LLVM::LoadOp>(loadOp, dataPtr);
    return success();
  }
};

// Store operation is lowered to obtaining a pointer to the indexed element,
// and storing the given value to it.
struct StoreOpLowering : public LoadStoreOpLowering<memref::StoreOp> {
  using Base::Base;

  LogicalResult
  matchAndRewrite(memref::StoreOp op, ArrayRef<Value> operands,
                  ConversionPatternRewriter &rewriter) const override {
    auto type = op.getMemRefType();
    memref::StoreOp::Adaptor transformed(operands);

    Value dataPtr =
        getStridedElementPtr(op.getLoc(), type, transformed.memref(),
                             transformed.indices(), rewriter);
    rewriter.replaceOpWithNewOp<LLVM::StoreOp>(op, transformed.value(),
                                               dataPtr);
    return success();
  }
};

// The prefetch operation is lowered in a way similar to the load operation
// except that the llvm.prefetch operation is used for replacement.
struct PrefetchOpLowering : public LoadStoreOpLowering<memref::PrefetchOp> {
  using Base::Base;

  LogicalResult
  matchAndRewrite(memref::PrefetchOp prefetchOp, ArrayRef<Value> operands,
                  ConversionPatternRewriter &rewriter) const override {
    memref::PrefetchOp::Adaptor transformed(operands);
    auto type = prefetchOp.getMemRefType();
    auto loc = prefetchOp.getLoc();

    Value dataPtr = getStridedElementPtr(loc, type, transformed.memref(),
                                         transformed.indices(), rewriter);

    // Replace with llvm.prefetch.
    auto llvmI32Type = typeConverter->convertType(rewriter.getIntegerType(32));
    auto isWrite = rewriter.create<LLVM::ConstantOp>(
        loc, llvmI32Type, rewriter.getI32IntegerAttr(prefetchOp.isWrite()));
    auto localityHint = rewriter.create<LLVM::ConstantOp>(
        loc, llvmI32Type,
        rewriter.getI32IntegerAttr(prefetchOp.localityHint()));
    auto isData = rewriter.create<LLVM::ConstantOp>(
        loc, llvmI32Type, rewriter.getI32IntegerAttr(prefetchOp.isDataCache()));

    rewriter.replaceOpWithNewOp<LLVM::Prefetch>(prefetchOp, dataPtr, isWrite,
                                                localityHint, isData);
    return success();
  }
};

struct MemRefCastOpLowering : public ConvertOpToLLVMPattern<memref::CastOp> {
  using ConvertOpToLLVMPattern<memref::CastOp>::ConvertOpToLLVMPattern;

  LogicalResult match(memref::CastOp memRefCastOp) const override {
    Type srcType = memRefCastOp.getOperand().getType();
    Type dstType = memRefCastOp.getType();

    // memref::CastOp reduces to a bitcast in the ranked MemRef case and can be
    // used for type erasure. For now it must preserve the underlying element
    // type and requires the source and result types to have the same rank.
    // Therefore, perform a sanity check that the underlying structs are the
    // same. Once the op semantics are relaxed we can revisit.
    if (srcType.isa<MemRefType>() && dstType.isa<MemRefType>())
      return success(typeConverter->convertType(srcType) ==
                     typeConverter->convertType(dstType));

    // At least one of the operands has an unranked memref type.
    assert(srcType.isa<UnrankedMemRefType>() ||
           dstType.isa<UnrankedMemRefType>());

    // Unranked to unranked cast is disallowed.
    return !(srcType.isa<UnrankedMemRefType>() &&
             dstType.isa<UnrankedMemRefType>())
               ? success()
               : failure();
  }

  void rewrite(memref::CastOp memRefCastOp, ArrayRef<Value> operands,
               ConversionPatternRewriter &rewriter) const override {
    memref::CastOp::Adaptor transformed(operands);

    auto srcType = memRefCastOp.getOperand().getType();
    auto dstType = memRefCastOp.getType();
    auto targetStructType = typeConverter->convertType(memRefCastOp.getType());
    auto loc = memRefCastOp.getLoc();

    // For the ranked/ranked case, just keep the original descriptor.
    if (srcType.isa<MemRefType>() && dstType.isa<MemRefType>())
      return rewriter.replaceOp(memRefCastOp, {transformed.source()});

    if (srcType.isa<MemRefType>() && dstType.isa<UnrankedMemRefType>()) {
      // Casting a ranked to an unranked memref type:
      //   - set the rank in the destination from the memref type;
      //   - allocate space on the stack and copy the src memref descriptor;
      //   - set the ptr in the destination to the stack space.
      auto srcMemRefType = srcType.cast<MemRefType>();
      int64_t rank = srcMemRefType.getRank();
      // ptr = AllocaOp sizeof(MemRefDescriptor)
      auto ptr = getTypeConverter()->promoteOneMemRefDescriptor(
          loc, transformed.source(), rewriter);
      // voidptr = BitCastOp srcType* to void*
      auto voidPtr =
          rewriter.create<LLVM::BitcastOp>(loc, getVoidPtrType(), ptr)
              .getResult();
      // rank = ConstantOp srcRank
      auto rankVal = rewriter.create<LLVM::ConstantOp>(
          loc, typeConverter->convertType(rewriter.getIntegerType(64)),
          rewriter.getI64IntegerAttr(rank));
      // undef = UndefOp
      UnrankedMemRefDescriptor memRefDesc =
          UnrankedMemRefDescriptor::undef(rewriter, loc, targetStructType);
      // d1 = InsertValueOp undef, rank, 0
      memRefDesc.setRank(rewriter, loc, rankVal);
      // d2 = InsertValueOp d1, voidptr, 1
      memRefDesc.setMemRefDescPtr(rewriter, loc, voidPtr);
      rewriter.replaceOp(memRefCastOp, (Value)memRefDesc);

    } else if (srcType.isa<UnrankedMemRefType>() && dstType.isa<MemRefType>()) {
      // Casting from an unranked type to a ranked one. The operation is
      // assumed to be doing a correct cast: if the destination type does not
      // match the runtime type of the unranked operand, the behavior is
      // undefined.
      UnrankedMemRefDescriptor memRefDesc(transformed.source());
      // ptr = ExtractValueOp src, 1
      auto ptr = memRefDesc.memRefDescPtr(rewriter, loc);
      // castPtr = BitCastOp i8* to structTy*
      auto castPtr =
          rewriter
              .create<LLVM::BitcastOp>(
                  loc, LLVM::LLVMPointerType::get(targetStructType), ptr)
              .getResult();
      // struct = LoadOp castPtr
      auto loadOp = rewriter.create<LLVM::LoadOp>(loc, castPtr);
      rewriter.replaceOp(memRefCastOp, loadOp.getResult());
    } else {
      llvm_unreachable("Unsupported unranked memref to unranked memref cast");
    }
  }
};

struct MemRefCopyOpLowering : public ConvertOpToLLVMPattern<memref::CopyOp> {
  using ConvertOpToLLVMPattern<memref::CopyOp>::ConvertOpToLLVMPattern;

  LogicalResult
  matchAndRewrite(memref::CopyOp op, ArrayRef<Value> operands,
                  ConversionPatternRewriter &rewriter) const override {
    auto loc = op.getLoc();
    memref::CopyOp::Adaptor adaptor(operands);
    auto srcType = op.source().getType().cast<BaseMemRefType>();
    auto targetType = op.target().getType().cast<BaseMemRefType>();

    // First make sure we have an unranked memref descriptor representation.
    auto makeUnranked = [&, this](Value ranked, BaseMemRefType type) {
      auto rank = rewriter.create<LLVM::ConstantOp>(
          loc, getIndexType(), rewriter.getIndexAttr(type.getRank()));
      auto *typeConverter = getTypeConverter();
      auto ptr =
          typeConverter->promoteOneMemRefDescriptor(loc, ranked, rewriter);
      auto voidPtr =
          rewriter.create<LLVM::BitcastOp>(loc, getVoidPtrType(), ptr)
              .getResult();
      auto unrankedType =
          UnrankedMemRefType::get(type.getElementType(), type.getMemorySpace());
      return UnrankedMemRefDescriptor::pack(rewriter, loc, *typeConverter,
                                            unrankedType,
                                            ValueRange{rank, voidPtr});
    };

    Value unrankedSource = srcType.hasRank()
                               ? makeUnranked(adaptor.source(), srcType)
                               : adaptor.source();
    Value unrankedTarget = targetType.hasRank()
                               ? makeUnranked(adaptor.target(), targetType)
                               : adaptor.target();

    // Now promote the unranked descriptors to the stack.
    auto one = rewriter.create<LLVM::ConstantOp>(loc, getIndexType(),
                                                 rewriter.getIndexAttr(1));
    auto promote = [&](Value desc) {
      auto ptrType = LLVM::LLVMPointerType::get(desc.getType());
      auto allocated =
          rewriter.create<LLVM::AllocaOp>(loc, ptrType, ValueRange{one});
      rewriter.create<LLVM::StoreOp>(loc, desc, allocated);
      return allocated;
    };

    auto sourcePtr = promote(unrankedSource);
    auto targetPtr = promote(unrankedTarget);

    auto elemSize = rewriter.create<LLVM::ConstantOp>(
        loc, getIndexType(),
        rewriter.getIndexAttr(srcType.getElementTypeBitWidth() / 8));
    auto copyFn = LLVM::lookupOrCreateMemRefCopyFn(
        op->getParentOfType<ModuleOp>(), getIndexType(), sourcePtr.getType());
    rewriter.create<LLVM::CallOp>(loc, copyFn,
                                  ValueRange{elemSize, sourcePtr, targetPtr});
    rewriter.eraseOp(op);

    return success();
  }
};

/// Extracts allocated, aligned pointers and offset from a ranked or unranked
/// memref type. In the unranked case, the fields are extracted from the
/// underlying ranked descriptor.
static void extractPointersAndOffset(Location loc,
                                     ConversionPatternRewriter &rewriter,
                                     LLVMTypeConverter &typeConverter,
                                     Value originalOperand,
                                     Value convertedOperand,
                                     Value *allocatedPtr, Value *alignedPtr,
                                     Value *offset = nullptr) {
  Type operandType = originalOperand.getType();
  if (operandType.isa<MemRefType>()) {
    MemRefDescriptor desc(convertedOperand);
    *allocatedPtr = desc.allocatedPtr(rewriter, loc);
    *alignedPtr = desc.alignedPtr(rewriter, loc);
    if (offset != nullptr)
      *offset = desc.offset(rewriter, loc);
    return;
  }

  unsigned memorySpace =
      operandType.cast<UnrankedMemRefType>().getMemorySpaceAsInt();
  Type elementType = operandType.cast<UnrankedMemRefType>().getElementType();
  Type llvmElementType = typeConverter.convertType(elementType);
  Type elementPtrPtrType = LLVM::LLVMPointerType::get(
      LLVM::LLVMPointerType::get(llvmElementType, memorySpace));

  // Extract pointer to the underlying ranked memref descriptor and cast it to
  // ElemType**.
  UnrankedMemRefDescriptor unrankedDesc(convertedOperand);
  Value underlyingDescPtr = unrankedDesc.memRefDescPtr(rewriter, loc);

  *allocatedPtr = UnrankedMemRefDescriptor::allocatedPtr(
      rewriter, loc, underlyingDescPtr, elementPtrPtrType);
  *alignedPtr = UnrankedMemRefDescriptor::alignedPtr(
      rewriter, loc, typeConverter, underlyingDescPtr, elementPtrPtrType);
  if (offset != nullptr) {
    *offset = UnrankedMemRefDescriptor::offset(
        rewriter, loc, typeConverter, underlyingDescPtr, elementPtrPtrType);
  }
}

struct MemRefReinterpretCastOpLowering
    : public ConvertOpToLLVMPattern<memref::ReinterpretCastOp> {
  using ConvertOpToLLVMPattern<
      memref::ReinterpretCastOp>::ConvertOpToLLVMPattern;

  LogicalResult
  matchAndRewrite(memref::ReinterpretCastOp castOp, ArrayRef<Value> operands,
                  ConversionPatternRewriter &rewriter) const override {
    memref::ReinterpretCastOp::Adaptor adaptor(operands,
                                               castOp->getAttrDictionary());
    Type srcType = castOp.source().getType();

    Value descriptor;
    if (failed(convertSourceMemRefToDescriptor(rewriter, srcType, castOp,
                                               adaptor, &descriptor)))
      return failure();
    rewriter.replaceOp(castOp, {descriptor});
    return success();
  }

private:
  LogicalResult convertSourceMemRefToDescriptor(
      ConversionPatternRewriter &rewriter, Type srcType,
      memref::ReinterpretCastOp castOp,
      memref::ReinterpretCastOp::Adaptor adaptor, Value *descriptor) const {
    MemRefType targetMemRefType =
        castOp.getResult().getType().cast<MemRefType>();
    auto llvmTargetDescriptorTy = typeConverter->convertType(targetMemRefType)
                                      .dyn_cast_or_null<LLVM::LLVMStructType>();
    if (!llvmTargetDescriptorTy)
      return failure();

    // Create descriptor.
    Location loc = castOp.getLoc();
    auto desc = MemRefDescriptor::undef(rewriter, loc, llvmTargetDescriptorTy);

    // Set allocated and aligned pointers.
    Value allocatedPtr, alignedPtr;
    extractPointersAndOffset(loc, rewriter, *getTypeConverter(),
                             castOp.source(), adaptor.source(), &allocatedPtr,
                             &alignedPtr);
    desc.setAllocatedPtr(rewriter, loc, allocatedPtr);
    desc.setAlignedPtr(rewriter, loc, alignedPtr);

    // Set offset.
    if (castOp.isDynamicOffset(0))
      desc.setOffset(rewriter, loc, adaptor.offsets()[0]);
    else
      desc.setConstantOffset(rewriter, loc, castOp.getStaticOffset(0));

    // Set sizes and strides.
    unsigned dynSizeId = 0;
    unsigned dynStrideId = 0;
    for (unsigned i = 0, e = targetMemRefType.getRank(); i < e; ++i) {
      if (castOp.isDynamicSize(i))
        desc.setSize(rewriter, loc, i, adaptor.sizes()[dynSizeId++]);
      else
        desc.setConstantSize(rewriter, loc, i, castOp.getStaticSize(i));

      if (castOp.isDynamicStride(i))
        desc.setStride(rewriter, loc, i, adaptor.strides()[dynStrideId++]);
      else
        desc.setConstantStride(rewriter, loc, i, castOp.getStaticStride(i));
    }
    *descriptor = desc;
    return success();
  }
};

struct MemRefReshapeOpLowering
    : public ConvertOpToLLVMPattern<memref::ReshapeOp> {
  using ConvertOpToLLVMPattern<memref::ReshapeOp>::ConvertOpToLLVMPattern;

  LogicalResult
  matchAndRewrite(memref::ReshapeOp reshapeOp, ArrayRef<Value> operands,
                  ConversionPatternRewriter &rewriter) const override {
    auto *op = reshapeOp.getOperation();
    memref::ReshapeOp::Adaptor adaptor(operands, op->getAttrDictionary());
    Type srcType = reshapeOp.source().getType();

    Value descriptor;
    if (failed(convertSourceMemRefToDescriptor(rewriter, srcType, reshapeOp,
                                               adaptor, &descriptor)))
      return failure();
    rewriter.replaceOp(op, {descriptor});
    return success();
  }

private:
  LogicalResult
  convertSourceMemRefToDescriptor(ConversionPatternRewriter &rewriter,
                                  Type srcType, memref::ReshapeOp reshapeOp,
                                  memref::ReshapeOp::Adaptor adaptor,
                                  Value *descriptor) const {
    // Conversion for statically-known shape args is performed via
    // `memref_reinterpret_cast`.
    auto shapeMemRefType = reshapeOp.shape().getType().cast<MemRefType>();
    if (shapeMemRefType.hasStaticShape())
      return failure();

    // The shape is a rank-1 memref with unknown length.
    Location loc = reshapeOp.getLoc();
    MemRefDescriptor shapeDesc(adaptor.shape());
    Value resultRank = shapeDesc.size(rewriter, loc, 0);

    // Extract address space and element type.
    auto targetType =
        reshapeOp.getResult().getType().cast<UnrankedMemRefType>();
    unsigned addressSpace = targetType.getMemorySpaceAsInt();
    Type elementType = targetType.getElementType();

    // Create the unranked memref descriptor that holds the ranked one. The
    // inner descriptor is allocated on the stack.
    auto targetDesc = UnrankedMemRefDescriptor::undef(
        rewriter, loc, typeConverter->convertType(targetType));
    targetDesc.setRank(rewriter, loc, resultRank);
    SmallVector<Value, 4> sizes;
    UnrankedMemRefDescriptor::computeSizes(rewriter, loc, *getTypeConverter(),
                                           targetDesc, sizes);
    Value underlyingDescPtr = rewriter.create<LLVM::AllocaOp>(
        loc, getVoidPtrType(), sizes.front(), llvm::None);
    targetDesc.setMemRefDescPtr(rewriter, loc, underlyingDescPtr);

    // Extract pointers and offset from the source memref.
    Value allocatedPtr, alignedPtr, offset;
    extractPointersAndOffset(loc, rewriter, *getTypeConverter(),
                             reshapeOp.source(), adaptor.source(),
                             &allocatedPtr, &alignedPtr, &offset);

    // Set pointers and offset.
    Type llvmElementType = typeConverter->convertType(elementType);
    auto elementPtrPtrType = LLVM::LLVMPointerType::get(
        LLVM::LLVMPointerType::get(llvmElementType, addressSpace));
    UnrankedMemRefDescriptor::setAllocatedPtr(rewriter, loc, underlyingDescPtr,
                                              elementPtrPtrType, allocatedPtr);
    UnrankedMemRefDescriptor::setAlignedPtr(rewriter, loc, *getTypeConverter(),
                                            underlyingDescPtr,
                                            elementPtrPtrType, alignedPtr);
    UnrankedMemRefDescriptor::setOffset(rewriter, loc, *getTypeConverter(),
                                        underlyingDescPtr, elementPtrPtrType,
                                        offset);

    // Use the offset pointer as base for further addressing. Copy over the new
    // shape and compute strides. For this, we create a loop from rank-1 to 0.
    Value targetSizesBase = UnrankedMemRefDescriptor::sizeBasePtr(
        rewriter, loc, *getTypeConverter(), underlyingDescPtr,
        elementPtrPtrType);
    Value targetStridesBase = UnrankedMemRefDescriptor::strideBasePtr(
        rewriter, loc, *getTypeConverter(), targetSizesBase, resultRank);
    Value shapeOperandPtr = shapeDesc.alignedPtr(rewriter, loc);
    Value oneIndex = createIndexConstant(rewriter, loc, 1);
    Value resultRankMinusOne =
        rewriter.create<LLVM::SubOp>(loc, resultRank, oneIndex);

    Block *initBlock = rewriter.getInsertionBlock();
    Type indexType = getTypeConverter()->getIndexType();
    Block::iterator remainingOpsIt = std::next(rewriter.getInsertionPoint());

    Block *condBlock = rewriter.createBlock(initBlock->getParent(), {},
                                            {indexType, indexType});

    // Move the remaining initBlock ops to condBlock.
    Block *remainingBlock = rewriter.splitBlock(initBlock, remainingOpsIt);
    rewriter.mergeBlocks(remainingBlock, condBlock, ValueRange());

    rewriter.setInsertionPointToEnd(initBlock);
    rewriter.create<LLVM::BrOp>(loc, ValueRange({resultRankMinusOne, oneIndex}),
                                condBlock);
    rewriter.setInsertionPointToStart(condBlock);
    Value indexArg = condBlock->getArgument(0);
    Value strideArg = condBlock->getArgument(1);

    Value zeroIndex = createIndexConstant(rewriter, loc, 0);
    Value pred = rewriter.create<LLVM::ICmpOp>(
        loc, IntegerType::get(rewriter.getContext(), 1),
        LLVM::ICmpPredicate::sge, indexArg, zeroIndex);

    Block *bodyBlock =
        rewriter.splitBlock(condBlock, rewriter.getInsertionPoint());
    rewriter.setInsertionPointToStart(bodyBlock);

    // Copy size from shape to descriptor.
    Type llvmIndexPtrType = LLVM::LLVMPointerType::get(indexType);
    Value sizeLoadGep = rewriter.create<LLVM::GEPOp>(
        loc, llvmIndexPtrType, shapeOperandPtr, ValueRange{indexArg});
    Value size = rewriter.create<LLVM::LoadOp>(loc, sizeLoadGep);
    UnrankedMemRefDescriptor::setSize(rewriter, loc, *getTypeConverter(),
                                      targetSizesBase, indexArg, size);

    // Write stride value and compute next one.
    UnrankedMemRefDescriptor::setStride(rewriter, loc, *getTypeConverter(),
                                        targetStridesBase, indexArg, strideArg);
    Value nextStride = rewriter.create<LLVM::MulOp>(loc, strideArg, size);

    // Decrement loop counter and branch back.
    Value decrement = rewriter.create<LLVM::SubOp>(loc, indexArg, oneIndex);
    rewriter.create<LLVM::BrOp>(loc, ValueRange({decrement, nextStride}),
                                condBlock);

    Block *remainder =
        rewriter.splitBlock(bodyBlock, rewriter.getInsertionPoint());

    // Hook up the cond exit to the remainder.
    rewriter.setInsertionPointToEnd(condBlock);
    rewriter.create<LLVM::CondBrOp>(loc, pred, bodyBlock, llvm::None, remainder,
                                    llvm::None);

    // Reset position to beginning of new remainder block.
    rewriter.setInsertionPointToStart(remainder);

    *descriptor = targetDesc;
    return success();
  }
};

// ReshapeOp creates a new view descriptor of the proper rank.
// For now, the only conversion supported is for a target MemRef with static
// sizes and strides.
template <typename ReshapeOp>
class ReassociatingReshapeOpConversion
    : public ConvertOpToLLVMPattern<ReshapeOp> {
public:
  using ConvertOpToLLVMPattern<ReshapeOp>::ConvertOpToLLVMPattern;
  using ReshapeOpAdaptor = typename ReshapeOp::Adaptor;

  LogicalResult
  matchAndRewrite(ReshapeOp reshapeOp, ArrayRef<Value> operands,
                  ConversionPatternRewriter &rewriter) const override {
    MemRefType dstType = reshapeOp.getResultType();

    if (!dstType.hasStaticShape())
      return failure();

    int64_t offset;
    SmallVector<int64_t, 4> strides;
    auto res = getStridesAndOffset(dstType, strides, offset);
    if (failed(res) || llvm::any_of(strides, [](int64_t val) {
          return ShapedType::isDynamicStrideOrOffset(val);
        }))
      return failure();

    ReshapeOpAdaptor adaptor(operands);
    MemRefDescriptor baseDesc(adaptor.src());
    Location loc = reshapeOp->getLoc();
    auto desc =
        MemRefDescriptor::undef(rewriter, reshapeOp->getLoc(),
                                this->typeConverter->convertType(dstType));
    desc.setAllocatedPtr(rewriter, loc, baseDesc.allocatedPtr(rewriter, loc));
    desc.setAlignedPtr(rewriter, loc, baseDesc.alignedPtr(rewriter, loc));
    desc.setOffset(rewriter, loc, baseDesc.offset(rewriter, loc));
    for (auto en : llvm::enumerate(dstType.getShape()))
      desc.setConstantSize(rewriter, loc, en.index(), en.value());
    for (auto en : llvm::enumerate(strides))
      desc.setConstantStride(rewriter, loc, en.index(), en.value());
    rewriter.replaceOp(reshapeOp, {desc});
    return success();
  }
};

/// Conversion pattern that transforms a subview op into:
///   1. An `llvm.mlir.undef` operation to create a memref descriptor.
///   2. Updates to the descriptor to introduce the data ptr, offset, size
///      and stride.
/// The subview op is replaced by the descriptor.
struct SubViewOpLowering : public ConvertOpToLLVMPattern<memref::SubViewOp> {
  using ConvertOpToLLVMPattern<memref::SubViewOp>::ConvertOpToLLVMPattern;

  LogicalResult
  matchAndRewrite(memref::SubViewOp subViewOp, ArrayRef<Value> operands,
                  ConversionPatternRewriter &rewriter) const override {
    auto loc = subViewOp.getLoc();

    auto sourceMemRefType = subViewOp.source().getType().cast<MemRefType>();
    auto sourceElementTy =
        typeConverter->convertType(sourceMemRefType.getElementType());

    auto viewMemRefType = subViewOp.getType();
    auto inferredType = memref::SubViewOp::inferResultType(
                            subViewOp.getSourceType(),
                            extractFromI64ArrayAttr(subViewOp.static_offsets()),
                            extractFromI64ArrayAttr(subViewOp.static_sizes()),
                            extractFromI64ArrayAttr(subViewOp.static_strides()))
                            .cast<MemRefType>();
    auto targetElementTy =
        typeConverter->convertType(viewMemRefType.getElementType());
    auto targetDescTy = typeConverter->convertType(viewMemRefType);
    if (!sourceElementTy || !targetDescTy || !targetElementTy ||
        !LLVM::isCompatibleType(sourceElementTy) ||
        !LLVM::isCompatibleType(targetElementTy) ||
        !LLVM::isCompatibleType(targetDescTy))
      return failure();

    // Extract the offset and strides from the type.
    int64_t offset;
    SmallVector<int64_t, 4> strides;
    auto successStrides = getStridesAndOffset(inferredType, strides, offset);
    if (failed(successStrides))
      return failure();

    // Create the descriptor.
    if (!LLVM::isCompatibleType(operands.front().getType()))
      return failure();
    MemRefDescriptor sourceMemRef(operands.front());
    auto targetMemRef = MemRefDescriptor::undef(rewriter, loc, targetDescTy);

    // Copy the buffer pointer from the old descriptor to the new one.
    Value extracted = sourceMemRef.allocatedPtr(rewriter, loc);
    Value bitcastPtr = rewriter.create<LLVM::BitcastOp>(
        loc,
        LLVM::LLVMPointerType::get(targetElementTy,
                                   viewMemRefType.getMemorySpaceAsInt()),
        extracted);
    targetMemRef.setAllocatedPtr(rewriter, loc, bitcastPtr);

    // Copy the aligned pointer from the old descriptor to the new one.
    extracted = sourceMemRef.alignedPtr(rewriter, loc);
    bitcastPtr = rewriter.create<LLVM::BitcastOp>(
        loc,
        LLVM::LLVMPointerType::get(targetElementTy,
                                   viewMemRefType.getMemorySpaceAsInt()),
        extracted);
    targetMemRef.setAlignedPtr(rewriter, loc, bitcastPtr);

    auto shape = viewMemRefType.getShape();
    auto inferredShape = inferredType.getShape();
    size_t inferredShapeRank = inferredShape.size();
    size_t resultShapeRank = shape.size();
    llvm::SmallDenseSet<unsigned> unusedDims =
        computeRankReductionMask(inferredShape, shape).getValue();

    // Extract strides needed to compute offset.
    SmallVector<Value, 4> strideValues;
    strideValues.reserve(inferredShapeRank);
    for (unsigned i = 0; i < inferredShapeRank; ++i)
      strideValues.push_back(sourceMemRef.stride(rewriter, loc, i));

    // Offset.
    auto llvmIndexType = typeConverter->convertType(rewriter.getIndexType());
    if (!ShapedType::isDynamicStrideOrOffset(offset)) {
      targetMemRef.setConstantOffset(rewriter, loc, offset);
    } else {
      Value baseOffset = sourceMemRef.offset(rewriter, loc);
      // `inferredShapeRank` may be larger than the number of offset operands
      // because of trailing semantics.
      // In this case, the offset is guaranteed to be interpreted as 0 and we
      // can just skip the extra dimensions.
      for (unsigned i = 0, e = std::min(inferredShapeRank,
                                        subViewOp.getMixedOffsets().size());
           i < e; ++i) {
        Value offset =
            // TODO: need OpFoldResult ODS adaptor to clean this up.
            subViewOp.isDynamicOffset(i)
                ? operands[subViewOp.getIndexOfDynamicOffset(i)]
                : rewriter.create<LLVM::ConstantOp>(
                      loc, llvmIndexType,
                      rewriter.getI64IntegerAttr(subViewOp.getStaticOffset(i)));
        Value mul = rewriter.create<LLVM::MulOp>(loc, offset, strideValues[i]);
        baseOffset = rewriter.create<LLVM::AddOp>(loc, baseOffset, mul);
      }
      targetMemRef.setOffset(rewriter, loc, baseOffset);
    }

    // Update sizes and strides.
    SmallVector<OpFoldResult> mixedSizes = subViewOp.getMixedSizes();
    SmallVector<OpFoldResult> mixedStrides = subViewOp.getMixedStrides();
    assert(mixedSizes.size() == mixedStrides.size() &&
           "expected sizes and strides of equal length");
    for (int i = inferredShapeRank - 1, j = resultShapeRank - 1;
         i >= 0 && j >= 0; --i) {
      if (unusedDims.contains(i))
        continue;

      // `i` may exceed the range of `subViewOp.getMixedSizes()` because of
      // trailing semantics. In this case, the size is guaranteed to be
      // interpreted as the source dimension size and the stride as 1.
      Value size, stride;
      if (static_cast<unsigned>(i) >= mixedSizes.size()) {
        // If the static size is available, use it directly. This is similar to
        // the folding of dim(constant-op) but removes the need for dim to be
        // aware of LLVM constants and for this pass to be aware of std
        // constants.
        int64_t staticSize =
            subViewOp.source().getType().cast<MemRefType>().getShape()[i];
        if (staticSize != ShapedType::kDynamicSize) {
          size = rewriter.create<LLVM::ConstantOp>(
              loc, llvmIndexType, rewriter.getI64IntegerAttr(staticSize));
        } else {
          Value pos = rewriter.create<LLVM::ConstantOp>(
              loc, llvmIndexType, rewriter.getI64IntegerAttr(i));
          Value dim =
              rewriter.create<memref::DimOp>(loc, subViewOp.source(), pos);
          auto cast = rewriter.create<UnrealizedConversionCastOp>(
              loc, llvmIndexType, dim);
          size = cast.getResult(0);
        }
        stride = rewriter.create<LLVM::ConstantOp>(
            loc, llvmIndexType, rewriter.getI64IntegerAttr(1));
      } else {
        // TODO: need OpFoldResult ODS adaptor to clean this up.
        size =
            subViewOp.isDynamicSize(i)
                ? operands[subViewOp.getIndexOfDynamicSize(i)]
                : rewriter.create<LLVM::ConstantOp>(
                      loc, llvmIndexType,
                      rewriter.getI64IntegerAttr(subViewOp.getStaticSize(i)));
        if (!ShapedType::isDynamicStrideOrOffset(strides[i])) {
          stride = rewriter.create<LLVM::ConstantOp>(
              loc, llvmIndexType, rewriter.getI64IntegerAttr(strides[i]));
        } else {
          stride = subViewOp.isDynamicStride(i)
                       ? operands[subViewOp.getIndexOfDynamicStride(i)]
                       : rewriter.create<LLVM::ConstantOp>(
                             loc, llvmIndexType,
                             rewriter.getI64IntegerAttr(
                                 subViewOp.getStaticStride(i)));
          stride = rewriter.create<LLVM::MulOp>(loc, stride, strideValues[i]);
        }
      }
      targetMemRef.setSize(rewriter, loc, j, size);
      targetMemRef.setStride(rewriter, loc, j, stride);
      j--;
    }

    rewriter.replaceOp(subViewOp, {targetMemRef});
    return success();
  }
};

/// Conversion pattern that transforms a transpose op into:
///   1. An `llvm.mlir.undef` operation to create a memref descriptor.
///   2. Updates to the descriptor to introduce the data ptr, offset, and the
///      permuted sizes and strides.
/// The transpose op is replaced by the descriptor.
class TransposeOpLowering : public ConvertOpToLLVMPattern<memref::TransposeOp> {
public:
  using ConvertOpToLLVMPattern<memref::TransposeOp>::ConvertOpToLLVMPattern;

  LogicalResult
  matchAndRewrite(memref::TransposeOp transposeOp, ArrayRef<Value> operands,
                  ConversionPatternRewriter &rewriter) const override {
    auto loc = transposeOp.getLoc();
    memref::TransposeOpAdaptor adaptor(operands);
    MemRefDescriptor viewMemRef(adaptor.in());

    // No permutation, early exit.
    if (transposeOp.permutation().isIdentity())
      return rewriter.replaceOp(transposeOp, {viewMemRef}), success();

    auto targetMemRef = MemRefDescriptor::undef(
        rewriter, loc, typeConverter->convertType(transposeOp.getShapedType()));

    // Copy the base and aligned pointers from the old descriptor to the new
    // one.
    targetMemRef.setAllocatedPtr(rewriter, loc,
                                 viewMemRef.allocatedPtr(rewriter, loc));
    targetMemRef.setAlignedPtr(rewriter, loc,
                               viewMemRef.alignedPtr(rewriter, loc));

    // Copy the offset from the old descriptor to the new one.
    targetMemRef.setOffset(rewriter, loc, viewMemRef.offset(rewriter, loc));

    // Iterate over the dimensions and apply size/stride permutation.
    for (auto en : llvm::enumerate(transposeOp.permutation().getResults())) {
      int sourcePos = en.index();
      int targetPos = en.value().cast<AffineDimExpr>().getPosition();
      targetMemRef.setSize(rewriter, loc, targetPos,
                           viewMemRef.size(rewriter, loc, sourcePos));
      targetMemRef.setStride(rewriter, loc, targetPos,
                             viewMemRef.stride(rewriter, loc, sourcePos));
    }

    rewriter.replaceOp(transposeOp, {targetMemRef});
    return success();
  }
};

/// Conversion pattern that transforms a view op into:
///   1. An `llvm.mlir.undef` operation to create a memref descriptor.
///   2. Updates to the descriptor to introduce the data ptr, offset, size
///      and stride.
/// The view op is replaced by the descriptor.
struct ViewOpLowering : public ConvertOpToLLVMPattern<memref::ViewOp> {
  using ConvertOpToLLVMPattern<memref::ViewOp>::ConvertOpToLLVMPattern;

  // Build and return the value for the idx^th shape dimension, either by
  // returning the constant shape dimension or counting the proper dynamic
  // size.
  Value getSize(ConversionPatternRewriter &rewriter, Location loc,
                ArrayRef<int64_t> shape, ValueRange dynamicSizes,
                unsigned idx) const {
    assert(idx < shape.size());
    if (!ShapedType::isDynamic(shape[idx]))
      return createIndexConstant(rewriter, loc, shape[idx]);
    // Count the number of dynamic dims in range [0, idx].
    unsigned nDynamic = llvm::count_if(shape.take_front(idx), [](int64_t v) {
      return ShapedType::isDynamic(v);
    });
    return dynamicSizes[nDynamic];
  }

  // Build and return the idx^th stride, either by returning the constant
  // stride or by computing the dynamic stride from the current
  // `runningStride` and `nextSize`.
  // The caller should keep a running stride and update it with the result
  // returned by this function.
  Value getStride(ConversionPatternRewriter &rewriter, Location loc,
                  ArrayRef<int64_t> strides, Value nextSize,
                  Value runningStride, unsigned idx) const {
    assert(idx < strides.size());
    if (!MemRefType::isDynamicStrideOrOffset(strides[idx]))
      return createIndexConstant(rewriter, loc, strides[idx]);
    if (nextSize)
      return runningStride
                 ? rewriter.create<LLVM::MulOp>(loc, runningStride, nextSize)
                 : nextSize;
    assert(!runningStride);
    return createIndexConstant(rewriter, loc, 1);
  }

  LogicalResult
  matchAndRewrite(memref::ViewOp viewOp, ArrayRef<Value> operands,
                  ConversionPatternRewriter &rewriter) const override {
    auto loc = viewOp.getLoc();
    memref::ViewOpAdaptor adaptor(operands);

    auto viewMemRefType = viewOp.getType();
    auto targetElementTy =
        typeConverter->convertType(viewMemRefType.getElementType());
    auto targetDescTy = typeConverter->convertType(viewMemRefType);
    if (!targetDescTy || !targetElementTy ||
        !LLVM::isCompatibleType(targetElementTy) ||
        !LLVM::isCompatibleType(targetDescTy))
      return viewOp.emitWarning("Target descriptor type not converted to LLVM"),
             failure();

    int64_t offset;
    SmallVector<int64_t, 4> strides;
    auto successStrides = getStridesAndOffset(viewMemRefType, strides, offset);
    if (failed(successStrides))
      return viewOp.emitWarning("cannot cast to non-strided shape"), failure();
    assert(offset == 0 && "expected offset to be 0");

    // Create the descriptor.
    MemRefDescriptor sourceMemRef(adaptor.source());
    auto targetMemRef = MemRefDescriptor::undef(rewriter, loc, targetDescTy);

    // Field 1: Copy the allocated pointer, used for malloc/free.
    Value allocatedPtr = sourceMemRef.allocatedPtr(rewriter, loc);
    auto srcMemRefType = viewOp.source().getType().cast<MemRefType>();
    Value bitcastPtr = rewriter.create<LLVM::BitcastOp>(
        loc,
        LLVM::LLVMPointerType::get(targetElementTy,
                                   srcMemRefType.getMemorySpaceAsInt()),
        allocatedPtr);
    targetMemRef.setAllocatedPtr(rewriter, loc, bitcastPtr);

    // Field 2: Copy the actual aligned pointer to payload.
    Value alignedPtr = sourceMemRef.alignedPtr(rewriter, loc);
    alignedPtr = rewriter.create<LLVM::GEPOp>(loc, alignedPtr.getType(),
                                              alignedPtr, adaptor.byte_shift());
    bitcastPtr = rewriter.create<LLVM::BitcastOp>(
        loc,
        LLVM::LLVMPointerType::get(targetElementTy,
                                   srcMemRefType.getMemorySpaceAsInt()),
        alignedPtr);
    targetMemRef.setAlignedPtr(rewriter, loc, bitcastPtr);

    // Field 3: The offset in the resulting type must be 0. This is because of
    // the type change: an offset on srcType* may not be expressible as an
    // offset on dstType*.
    targetMemRef.setOffset(rewriter, loc,
                           createIndexConstant(rewriter, loc, offset));

    // Early exit for 0-D corner case.
    if (viewMemRefType.getRank() == 0)
      return rewriter.replaceOp(viewOp, {targetMemRef}), success();

    // Fields 4 and 5: Update sizes and strides.
    if (strides.back() != 1)
      return viewOp.emitWarning("cannot cast to non-contiguous shape"),
             failure();
    Value stride = nullptr, nextSize = nullptr;
    for (int i = viewMemRefType.getRank() - 1; i >= 0; --i) {
      // Update size.
      Value size =
          getSize(rewriter, loc, viewMemRefType.getShape(), adaptor.sizes(), i);
      targetMemRef.setSize(rewriter, loc, i, size);
      // Update stride.
      stride = getStride(rewriter, loc, strides, nextSize, stride, i);
      targetMemRef.setStride(rewriter, loc, i, stride);
      nextSize = size;
    }

    rewriter.replaceOp(viewOp, {targetMemRef});
    return success();
  }
};

} // namespace

void mlir::populateMemRefToLLVMConversionPatterns(LLVMTypeConverter &converter,
                                                  RewritePatternSet &patterns) {
  // clang-format off
  patterns.add<
      AllocaOpLowering,
      AllocaScopeOpLowering,
      AssumeAlignmentOpLowering,
      DimOpLowering,
      DeallocOpLowering,
      GlobalMemrefOpLowering,
      GetGlobalMemrefOpLowering,
      LoadOpLowering,
      MemRefCastOpLowering,
      MemRefCopyOpLowering,
      MemRefReinterpretCastOpLowering,
      MemRefReshapeOpLowering,
      PrefetchOpLowering,
      ReassociatingReshapeOpConversion<memref::ExpandShapeOp>,
      ReassociatingReshapeOpConversion<memref::CollapseShapeOp>,
      StoreOpLowering,
      SubViewOpLowering,
      TransposeOpLowering,
      ViewOpLowering>(converter);
  // clang-format on
  auto allocLowering = converter.getOptions().allocLowering;
  if (allocLowering == LowerToLLVMOptions::AllocLowering::AlignedAlloc)
    patterns.add<AlignedAllocOpLowering, DeallocOpLowering>(converter);
  else if (allocLowering == LowerToLLVMOptions::AllocLowering::Malloc)
    patterns.add<AllocOpLowering, DeallocOpLowering>(converter);
}

namespace {
struct MemRefToLLVMPass : public ConvertMemRefToLLVMBase<MemRefToLLVMPass> {
  MemRefToLLVMPass() = default;

  void runOnOperation() override {
    Operation *op = getOperation();
    const auto &dataLayoutAnalysis = getAnalysis<DataLayoutAnalysis>();
    LowerToLLVMOptions options(&getContext(),
                               dataLayoutAnalysis.getAtOrAbove(op));
    options.allocLowering =
        (useAlignedAlloc ? LowerToLLVMOptions::AllocLowering::AlignedAlloc
                         : LowerToLLVMOptions::AllocLowering::Malloc);
    if (indexBitwidth != kDeriveIndexBitwidthFromDataLayout)
      options.overrideIndexBitwidth(indexBitwidth);

    LLVMTypeConverter typeConverter(&getContext(), options,
                                    &dataLayoutAnalysis);
    RewritePatternSet patterns(&getContext());
    populateMemRefToLLVMConversionPatterns(typeConverter, patterns);
    LLVMConversionTarget target(getContext());
    target.addLegalOp<FuncOp>();
    if (failed(applyPartialConversion(op, target, std::move(patterns))))
      signalPassFailure();
  }
};
} // namespace

std::unique_ptr<Pass> mlir::createMemRefToLLVMPass() {
  return std::make_unique<MemRefToLLVMPass>();
}