1 //===- MemRefToLLVM.cpp - MemRef to LLVM dialect conversion ---------------===// 2 // 3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. 4 // See https://llvm.org/LICENSE.txt for license information. 5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception 6 // 7 //===----------------------------------------------------------------------===// 8 9 #include "mlir/Conversion/MemRefToLLVM/MemRefToLLVM.h" 10 #include "../PassDetail.h" 11 #include "mlir/Analysis/DataLayoutAnalysis.h" 12 #include "mlir/Conversion/LLVMCommon/ConversionTarget.h" 13 #include "mlir/Conversion/LLVMCommon/Pattern.h" 14 #include "mlir/Conversion/LLVMCommon/TypeConverter.h" 15 #include "mlir/Conversion/MemRefToLLVM/AllocLikeConversion.h" 16 #include "mlir/Dialect/LLVMIR/FunctionCallUtils.h" 17 #include "mlir/Dialect/LLVMIR/LLVMDialect.h" 18 #include "mlir/Dialect/MemRef/IR/MemRef.h" 19 #include "mlir/IR/AffineMap.h" 20 #include "mlir/IR/BlockAndValueMapping.h" 21 #include "llvm/ADT/SmallBitVector.h" 22 23 using namespace mlir; 24 25 namespace { 26 27 struct AllocOpLowering : public AllocLikeOpLLVMLowering { 28 AllocOpLowering(LLVMTypeConverter &converter) 29 : AllocLikeOpLLVMLowering(memref::AllocOp::getOperationName(), 30 converter) {} 31 32 std::tuple<Value, Value> allocateBuffer(ConversionPatternRewriter &rewriter, 33 Location loc, Value sizeBytes, 34 Operation *op) const override { 35 // Heap allocations. 36 memref::AllocOp allocOp = cast<memref::AllocOp>(op); 37 MemRefType memRefType = allocOp.getType(); 38 39 Value alignment; 40 if (auto alignmentAttr = allocOp.alignment()) { 41 alignment = createIndexConstant(rewriter, loc, *alignmentAttr); 42 } else if (!memRefType.getElementType().isSignlessIntOrIndexOrFloat()) { 43 // In the case where no alignment is specified, we may want to override 44 // `malloc's` behavior. `malloc` typically aligns at the size of the 45 // biggest scalar on a target HW. For non-scalars, use the natural 46 // alignment of the LLVM type given by the LLVM DataLayout. 47 alignment = getSizeInBytes(loc, memRefType.getElementType(), rewriter); 48 } 49 50 if (alignment) { 51 // Adjust the allocation size to consider alignment. 52 sizeBytes = rewriter.create<LLVM::AddOp>(loc, sizeBytes, alignment); 53 } 54 55 // Allocate the underlying buffer and store a pointer to it in the MemRef 56 // descriptor. 57 Type elementPtrType = this->getElementPtrType(memRefType); 58 auto allocFuncOp = LLVM::lookupOrCreateMallocFn( 59 allocOp->getParentOfType<ModuleOp>(), getIndexType()); 60 auto results = createLLVMCall(rewriter, loc, allocFuncOp, {sizeBytes}, 61 getVoidPtrType()); 62 Value allocatedPtr = 63 rewriter.create<LLVM::BitcastOp>(loc, elementPtrType, results[0]); 64 65 Value alignedPtr = allocatedPtr; 66 if (alignment) { 67 // Compute the aligned type pointer. 68 Value allocatedInt = 69 rewriter.create<LLVM::PtrToIntOp>(loc, getIndexType(), allocatedPtr); 70 Value alignmentInt = 71 createAligned(rewriter, loc, allocatedInt, alignment); 72 alignedPtr = 73 rewriter.create<LLVM::IntToPtrOp>(loc, elementPtrType, alignmentInt); 74 } 75 76 return std::make_tuple(allocatedPtr, alignedPtr); 77 } 78 }; 79 80 struct AlignedAllocOpLowering : public AllocLikeOpLLVMLowering { 81 AlignedAllocOpLowering(LLVMTypeConverter &converter) 82 : AllocLikeOpLLVMLowering(memref::AllocOp::getOperationName(), 83 converter) {} 84 85 /// Returns the memref's element size in bytes using the data layout active at 86 /// `op`. 87 // TODO: there are other places where this is used. Expose publicly? 
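  // Note: for memref-of-memref elements this returns the size of the nested
  // descriptor, not of the data it transitively points to.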
  unsigned getMemRefEltSizeInBytes(MemRefType memRefType, Operation *op) const {
    const DataLayout *layout = &defaultLayout;
    if (const DataLayoutAnalysis *analysis =
            getTypeConverter()->getDataLayoutAnalysis()) {
      layout = &analysis->getAbove(op);
    }
    Type elementType = memRefType.getElementType();
    if (auto memRefElementType = elementType.dyn_cast<MemRefType>())
      return getTypeConverter()->getMemRefDescriptorSize(memRefElementType,
                                                         *layout);
    if (auto memRefElementType = elementType.dyn_cast<UnrankedMemRefType>())
      return getTypeConverter()->getUnrankedMemRefDescriptorSize(
          memRefElementType, *layout);
    return layout->getTypeSize(elementType);
  }

  /// Returns true if the memref size in bytes is known to be a multiple of
  /// `factor`, assuming the data layout active at `op`.
  bool isMemRefSizeMultipleOf(MemRefType type, uint64_t factor,
                              Operation *op) const {
    uint64_t sizeDivisor = getMemRefEltSizeInBytes(type, op);
    for (unsigned i = 0, e = type.getRank(); i < e; i++) {
      if (ShapedType::isDynamic(type.getDimSize(i)))
        continue;
      sizeDivisor = sizeDivisor * type.getDimSize(i);
    }
    return sizeDivisor % factor == 0;
  }

  /// Returns the alignment to be used for the allocation call itself.
  /// aligned_alloc requires the alignment to be a power of two, and the
  /// allocation size to be a multiple of that alignment.
  int64_t getAllocationAlignment(memref::AllocOp allocOp) const {
    if (Optional<uint64_t> alignment = allocOp.alignment())
      return *alignment;

    // Whenever we don't have alignment set, we will use an alignment
    // consistent with the element type; since the alignment has to be a
    // power of two, we bump the element size to the next power of two if it
    // isn't one already.
    auto eltSizeBytes = getMemRefEltSizeInBytes(allocOp.getType(), allocOp);
    return std::max(kMinAlignedAllocAlignment,
                    llvm::PowerOf2Ceil(eltSizeBytes));
  }

  std::tuple<Value, Value> allocateBuffer(ConversionPatternRewriter &rewriter,
                                          Location loc, Value sizeBytes,
                                          Operation *op) const override {
    // Heap allocations.
    memref::AllocOp allocOp = cast<memref::AllocOp>(op);
    MemRefType memRefType = allocOp.getType();
    int64_t alignment = getAllocationAlignment(allocOp);
    Value allocAlignment = createIndexConstant(rewriter, loc, alignment);

    // aligned_alloc requires size to be a multiple of alignment; we will pad
    // the size to the next multiple if necessary.
    if (!isMemRefSizeMultipleOf(memRefType, alignment, op))
      sizeBytes = createAligned(rewriter, loc, sizeBytes, allocAlignment);

    Type elementPtrType = this->getElementPtrType(memRefType);
    auto allocFuncOp = LLVM::lookupOrCreateAlignedAllocFn(
        allocOp->getParentOfType<ModuleOp>(), getIndexType());
    auto results =
        createLLVMCall(rewriter, loc, allocFuncOp, {allocAlignment, sizeBytes},
                       getVoidPtrType());
    Value allocatedPtr =
        rewriter.create<LLVM::BitcastOp>(loc, elementPtrType, results[0]);

    return std::make_tuple(allocatedPtr, allocatedPtr);
  }

  /// The minimum alignment to use with aligned_alloc (has to be a power of 2).
  static constexpr uint64_t kMinAlignedAllocAlignment = 16UL;

  /// Default layout to use in absence of the corresponding analysis.
  DataLayout defaultLayout;
};

// Out of line definition, required till C++17.
166 constexpr uint64_t AlignedAllocOpLowering::kMinAlignedAllocAlignment; 167 168 struct AllocaOpLowering : public AllocLikeOpLLVMLowering { 169 AllocaOpLowering(LLVMTypeConverter &converter) 170 : AllocLikeOpLLVMLowering(memref::AllocaOp::getOperationName(), 171 converter) {} 172 173 /// Allocates the underlying buffer using the right call. `allocatedBytePtr` 174 /// is set to null for stack allocations. `accessAlignment` is set if 175 /// alignment is needed post allocation (for eg. in conjunction with malloc). 176 std::tuple<Value, Value> allocateBuffer(ConversionPatternRewriter &rewriter, 177 Location loc, Value sizeBytes, 178 Operation *op) const override { 179 180 // With alloca, one gets a pointer to the element type right away. 181 // For stack allocations. 182 auto allocaOp = cast<memref::AllocaOp>(op); 183 auto elementPtrType = this->getElementPtrType(allocaOp.getType()); 184 185 auto allocatedElementPtr = rewriter.create<LLVM::AllocaOp>( 186 loc, elementPtrType, sizeBytes, 187 allocaOp.alignment() ? *allocaOp.alignment() : 0); 188 189 return std::make_tuple(allocatedElementPtr, allocatedElementPtr); 190 } 191 }; 192 193 struct AllocaScopeOpLowering 194 : public ConvertOpToLLVMPattern<memref::AllocaScopeOp> { 195 using ConvertOpToLLVMPattern<memref::AllocaScopeOp>::ConvertOpToLLVMPattern; 196 197 LogicalResult 198 matchAndRewrite(memref::AllocaScopeOp allocaScopeOp, OpAdaptor adaptor, 199 ConversionPatternRewriter &rewriter) const override { 200 OpBuilder::InsertionGuard guard(rewriter); 201 Location loc = allocaScopeOp.getLoc(); 202 203 // Split the current block before the AllocaScopeOp to create the inlining 204 // point. 205 auto *currentBlock = rewriter.getInsertionBlock(); 206 auto *remainingOpsBlock = 207 rewriter.splitBlock(currentBlock, rewriter.getInsertionPoint()); 208 Block *continueBlock; 209 if (allocaScopeOp.getNumResults() == 0) { 210 continueBlock = remainingOpsBlock; 211 } else { 212 continueBlock = rewriter.createBlock( 213 remainingOpsBlock, allocaScopeOp.getResultTypes(), 214 SmallVector<Location>(allocaScopeOp->getNumResults(), 215 allocaScopeOp.getLoc())); 216 rewriter.create<LLVM::BrOp>(loc, ValueRange(), remainingOpsBlock); 217 } 218 219 // Inline body region. 220 Block *beforeBody = &allocaScopeOp.bodyRegion().front(); 221 Block *afterBody = &allocaScopeOp.bodyRegion().back(); 222 rewriter.inlineRegionBefore(allocaScopeOp.bodyRegion(), continueBlock); 223 224 // Save stack and then branch into the body of the region. 225 rewriter.setInsertionPointToEnd(currentBlock); 226 auto stackSaveOp = 227 rewriter.create<LLVM::StackSaveOp>(loc, getVoidPtrType()); 228 rewriter.create<LLVM::BrOp>(loc, ValueRange(), beforeBody); 229 230 // Replace the alloca_scope return with a branch that jumps out of the body. 231 // Stack restore before leaving the body region. 232 rewriter.setInsertionPointToEnd(afterBody); 233 auto returnOp = 234 cast<memref::AllocaScopeReturnOp>(afterBody->getTerminator()); 235 auto branchOp = rewriter.replaceOpWithNewOp<LLVM::BrOp>( 236 returnOp, returnOp.results(), continueBlock); 237 238 // Insert stack restore before jumping out the body of the region. 239 rewriter.setInsertionPoint(branchOp); 240 rewriter.create<LLVM::StackRestoreOp>(loc, stackSaveOp); 241 242 // Replace the op with values return from the body region. 
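    // The results of the alloca_scope, if any, were forwarded to
    // `continueBlock` as block arguments by the branch created above.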
243 rewriter.replaceOp(allocaScopeOp, continueBlock->getArguments()); 244 245 return success(); 246 } 247 }; 248 249 struct AssumeAlignmentOpLowering 250 : public ConvertOpToLLVMPattern<memref::AssumeAlignmentOp> { 251 using ConvertOpToLLVMPattern< 252 memref::AssumeAlignmentOp>::ConvertOpToLLVMPattern; 253 254 LogicalResult 255 matchAndRewrite(memref::AssumeAlignmentOp op, OpAdaptor adaptor, 256 ConversionPatternRewriter &rewriter) const override { 257 Value memref = adaptor.memref(); 258 unsigned alignment = op.alignment(); 259 auto loc = op.getLoc(); 260 261 MemRefDescriptor memRefDescriptor(memref); 262 Value ptr = memRefDescriptor.alignedPtr(rewriter, memref.getLoc()); 263 264 // Emit llvm.assume(memref.alignedPtr & (alignment - 1) == 0). Notice that 265 // the asserted memref.alignedPtr isn't used anywhere else, as the real 266 // users like load/store/views always re-extract memref.alignedPtr as they 267 // get lowered. 268 // 269 // This relies on LLVM's CSE optimization (potentially after SROA), since 270 // after CSE all memref.alignedPtr instances get de-duplicated into the same 271 // pointer SSA value. 272 auto intPtrType = 273 getIntPtrType(memRefDescriptor.getElementPtrType().getAddressSpace()); 274 Value zero = createIndexAttrConstant(rewriter, loc, intPtrType, 0); 275 Value mask = 276 createIndexAttrConstant(rewriter, loc, intPtrType, alignment - 1); 277 Value ptrValue = rewriter.create<LLVM::PtrToIntOp>(loc, intPtrType, ptr); 278 rewriter.create<LLVM::AssumeOp>( 279 loc, rewriter.create<LLVM::ICmpOp>( 280 loc, LLVM::ICmpPredicate::eq, 281 rewriter.create<LLVM::AndOp>(loc, ptrValue, mask), zero)); 282 283 rewriter.eraseOp(op); 284 return success(); 285 } 286 }; 287 288 // A `dealloc` is converted into a call to `free` on the underlying data buffer. 289 // The memref descriptor being an SSA value, there is no need to clean it up 290 // in any way. 291 struct DeallocOpLowering : public ConvertOpToLLVMPattern<memref::DeallocOp> { 292 using ConvertOpToLLVMPattern<memref::DeallocOp>::ConvertOpToLLVMPattern; 293 294 explicit DeallocOpLowering(LLVMTypeConverter &converter) 295 : ConvertOpToLLVMPattern<memref::DeallocOp>(converter) {} 296 297 LogicalResult 298 matchAndRewrite(memref::DeallocOp op, OpAdaptor adaptor, 299 ConversionPatternRewriter &rewriter) const override { 300 // Insert the `free` declaration if it is not already present. 301 auto freeFunc = LLVM::lookupOrCreateFreeFn(op->getParentOfType<ModuleOp>()); 302 MemRefDescriptor memref(adaptor.memref()); 303 Value casted = rewriter.create<LLVM::BitcastOp>( 304 op.getLoc(), getVoidPtrType(), 305 memref.allocatedPtr(rewriter, op.getLoc())); 306 rewriter.replaceOpWithNewOp<LLVM::CallOp>( 307 op, TypeRange(), SymbolRefAttr::get(freeFunc), casted); 308 return success(); 309 } 310 }; 311 312 // A `dim` is converted to a constant for static sizes and to an access to the 313 // size stored in the memref descriptor for dynamic sizes. 
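// For example (schematically), `memref.dim %m, %c1 : memref<4x?xf32>` becomes
// a read of the dynamic size #1 from the descriptor, while asking for
// dimension 0 simply folds to the constant 4.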
314 struct DimOpLowering : public ConvertOpToLLVMPattern<memref::DimOp> { 315 using ConvertOpToLLVMPattern<memref::DimOp>::ConvertOpToLLVMPattern; 316 317 LogicalResult 318 matchAndRewrite(memref::DimOp dimOp, OpAdaptor adaptor, 319 ConversionPatternRewriter &rewriter) const override { 320 Type operandType = dimOp.source().getType(); 321 if (operandType.isa<UnrankedMemRefType>()) { 322 rewriter.replaceOp( 323 dimOp, {extractSizeOfUnrankedMemRef( 324 operandType, dimOp, adaptor.getOperands(), rewriter)}); 325 326 return success(); 327 } 328 if (operandType.isa<MemRefType>()) { 329 rewriter.replaceOp( 330 dimOp, {extractSizeOfRankedMemRef(operandType, dimOp, 331 adaptor.getOperands(), rewriter)}); 332 return success(); 333 } 334 llvm_unreachable("expected MemRefType or UnrankedMemRefType"); 335 } 336 337 private: 338 Value extractSizeOfUnrankedMemRef(Type operandType, memref::DimOp dimOp, 339 OpAdaptor adaptor, 340 ConversionPatternRewriter &rewriter) const { 341 Location loc = dimOp.getLoc(); 342 343 auto unrankedMemRefType = operandType.cast<UnrankedMemRefType>(); 344 auto scalarMemRefType = 345 MemRefType::get({}, unrankedMemRefType.getElementType()); 346 unsigned addressSpace = unrankedMemRefType.getMemorySpaceAsInt(); 347 348 // Extract pointer to the underlying ranked descriptor and bitcast it to a 349 // memref<element_type> descriptor pointer to minimize the number of GEP 350 // operations. 351 UnrankedMemRefDescriptor unrankedDesc(adaptor.source()); 352 Value underlyingRankedDesc = unrankedDesc.memRefDescPtr(rewriter, loc); 353 Value scalarMemRefDescPtr = rewriter.create<LLVM::BitcastOp>( 354 loc, 355 LLVM::LLVMPointerType::get(typeConverter->convertType(scalarMemRefType), 356 addressSpace), 357 underlyingRankedDesc); 358 359 // Get pointer to offset field of memref<element_type> descriptor. 360 Type indexPtrTy = LLVM::LLVMPointerType::get( 361 getTypeConverter()->getIndexType(), addressSpace); 362 Value two = rewriter.create<LLVM::ConstantOp>( 363 loc, typeConverter->convertType(rewriter.getI32Type()), 364 rewriter.getI32IntegerAttr(2)); 365 Value offsetPtr = rewriter.create<LLVM::GEPOp>( 366 loc, indexPtrTy, scalarMemRefDescPtr, 367 ValueRange({createIndexConstant(rewriter, loc, 0), two})); 368 369 // The size value that we have to extract can be obtained using GEPop with 370 // `dimOp.index() + 1` index argument. 371 Value idxPlusOne = rewriter.create<LLVM::AddOp>( 372 loc, createIndexConstant(rewriter, loc, 1), adaptor.index()); 373 Value sizePtr = rewriter.create<LLVM::GEPOp>(loc, indexPtrTy, offsetPtr, 374 ValueRange({idxPlusOne})); 375 return rewriter.create<LLVM::LoadOp>(loc, sizePtr); 376 } 377 378 Optional<int64_t> getConstantDimIndex(memref::DimOp dimOp) const { 379 if (Optional<int64_t> idx = dimOp.getConstantIndex()) 380 return idx; 381 382 if (auto constantOp = dimOp.index().getDefiningOp<LLVM::ConstantOp>()) 383 return constantOp.getValue() 384 .cast<IntegerAttr>() 385 .getValue() 386 .getSExtValue(); 387 388 return llvm::None; 389 } 390 391 Value extractSizeOfRankedMemRef(Type operandType, memref::DimOp dimOp, 392 OpAdaptor adaptor, 393 ConversionPatternRewriter &rewriter) const { 394 Location loc = dimOp.getLoc(); 395 396 // Take advantage if index is constant. 397 MemRefType memRefType = operandType.cast<MemRefType>(); 398 if (Optional<int64_t> index = getConstantDimIndex(dimOp)) { 399 int64_t i = index.getValue(); 400 if (memRefType.isDynamicDim(i)) { 401 // extract dynamic size from the memref descriptor. 
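        // (Per-dimension sizes are stored directly in the descriptor.)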
402 MemRefDescriptor descriptor(adaptor.source()); 403 return descriptor.size(rewriter, loc, i); 404 } 405 // Use constant for static size. 406 int64_t dimSize = memRefType.getDimSize(i); 407 return createIndexConstant(rewriter, loc, dimSize); 408 } 409 Value index = adaptor.index(); 410 int64_t rank = memRefType.getRank(); 411 MemRefDescriptor memrefDescriptor(adaptor.source()); 412 return memrefDescriptor.size(rewriter, loc, index, rank); 413 } 414 }; 415 416 /// Common base for load and store operations on MemRefs. Restricts the match 417 /// to supported MemRef types. Provides functionality to emit code accessing a 418 /// specific element of the underlying data buffer. 419 template <typename Derived> 420 struct LoadStoreOpLowering : public ConvertOpToLLVMPattern<Derived> { 421 using ConvertOpToLLVMPattern<Derived>::ConvertOpToLLVMPattern; 422 using ConvertOpToLLVMPattern<Derived>::isConvertibleAndHasIdentityMaps; 423 using Base = LoadStoreOpLowering<Derived>; 424 425 LogicalResult match(Derived op) const override { 426 MemRefType type = op.getMemRefType(); 427 return isConvertibleAndHasIdentityMaps(type) ? success() : failure(); 428 } 429 }; 430 431 /// Wrap a llvm.cmpxchg operation in a while loop so that the operation can be 432 /// retried until it succeeds in atomically storing a new value into memory. 433 /// 434 /// +---------------------------------+ 435 /// | <code before the AtomicRMWOp> | 436 /// | <compute initial %loaded> | 437 /// | cf.br loop(%loaded) | 438 /// +---------------------------------+ 439 /// | 440 /// -------| | 441 /// | v v 442 /// | +--------------------------------+ 443 /// | | loop(%loaded): | 444 /// | | <body contents> | 445 /// | | %pair = cmpxchg | 446 /// | | %ok = %pair[0] | 447 /// | | %new = %pair[1] | 448 /// | | cf.cond_br %ok, end, loop(%new) | 449 /// | +--------------------------------+ 450 /// | | | 451 /// |----------- | 452 /// v 453 /// +--------------------------------+ 454 /// | end: | 455 /// | <code after the AtomicRMWOp> | 456 /// +--------------------------------+ 457 /// 458 struct GenericAtomicRMWOpLowering 459 : public LoadStoreOpLowering<memref::GenericAtomicRMWOp> { 460 using Base::Base; 461 462 LogicalResult 463 matchAndRewrite(memref::GenericAtomicRMWOp atomicOp, OpAdaptor adaptor, 464 ConversionPatternRewriter &rewriter) const override { 465 auto loc = atomicOp.getLoc(); 466 Type valueType = typeConverter->convertType(atomicOp.getResult().getType()); 467 468 // Split the block into initial, loop, and ending parts. 469 auto *initBlock = rewriter.getInsertionBlock(); 470 auto *loopBlock = rewriter.createBlock( 471 initBlock->getParent(), std::next(Region::iterator(initBlock)), 472 valueType, loc); 473 auto *endBlock = rewriter.createBlock( 474 loopBlock->getParent(), std::next(Region::iterator(loopBlock))); 475 476 // Operations range to be moved to `endBlock`. 477 auto opsToMoveStart = atomicOp->getIterator(); 478 auto opsToMoveEnd = initBlock->back().getIterator(); 479 480 // Compute the loaded value and branch to the loop block. 481 rewriter.setInsertionPointToEnd(initBlock); 482 auto memRefType = atomicOp.memref().getType().cast<MemRefType>(); 483 auto dataPtr = getStridedElementPtr(loc, memRefType, adaptor.memref(), 484 adaptor.indices(), rewriter); 485 Value init = rewriter.create<LLVM::LoadOp>(loc, dataPtr); 486 rewriter.create<LLVM::BrOp>(loc, init, loopBlock); 487 488 // Prepare the body of the loop block. 
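    // The body recomputes the value from the current loop argument and then
    // tries to publish it with a cmpxchg, as in the diagram above.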
    rewriter.setInsertionPointToStart(loopBlock);

    // Clone the GenericAtomicRMWOp region and extract the result.
    auto loopArgument = loopBlock->getArgument(0);
    BlockAndValueMapping mapping;
    mapping.map(atomicOp.getCurrentValue(), loopArgument);
    Block &entryBlock = atomicOp.body().front();
    for (auto &nestedOp : entryBlock.without_terminator()) {
      Operation *clone = rewriter.clone(nestedOp, mapping);
      mapping.map(nestedOp.getResults(), clone->getResults());
    }
    Value result = mapping.lookup(entryBlock.getTerminator()->getOperand(0));

    // Prepare the epilog of the loop block.
    // Append the cmpxchg op to the end of the loop block.
    auto successOrdering = LLVM::AtomicOrdering::acq_rel;
    auto failureOrdering = LLVM::AtomicOrdering::monotonic;
    auto boolType = IntegerType::get(rewriter.getContext(), 1);
    auto pairType = LLVM::LLVMStructType::getLiteral(rewriter.getContext(),
                                                     {valueType, boolType});
    auto cmpxchg = rewriter.create<LLVM::AtomicCmpXchgOp>(
        loc, pairType, dataPtr, loopArgument, result, successOrdering,
        failureOrdering);
    // Extract the %new_loaded and %ok values from the pair.
    Value newLoaded = rewriter.create<LLVM::ExtractValueOp>(
        loc, valueType, cmpxchg, rewriter.getI64ArrayAttr({0}));
    Value ok = rewriter.create<LLVM::ExtractValueOp>(
        loc, boolType, cmpxchg, rewriter.getI64ArrayAttr({1}));

    // Conditionally branch to the end or back to the loop depending on %ok.
    rewriter.create<LLVM::CondBrOp>(loc, ok, endBlock, ArrayRef<Value>(),
                                    loopBlock, newLoaded);

    rewriter.setInsertionPointToEnd(endBlock);
    moveOpsRange(atomicOp.getResult(), newLoaded, std::next(opsToMoveStart),
                 std::next(opsToMoveEnd), rewriter);

    // The 'result' of the atomic_rmw op is the newly loaded value.
    rewriter.replaceOp(atomicOp, {newLoaded});

    return success();
  }

private:
  // Clones a segment of ops [start, end) and erases the original.
  void moveOpsRange(ValueRange oldResult, ValueRange newResult,
                    Block::iterator start, Block::iterator end,
                    ConversionPatternRewriter &rewriter) const {
    BlockAndValueMapping mapping;
    mapping.map(oldResult, newResult);
    SmallVector<Operation *, 2> opsToErase;
    for (auto it = start; it != end; ++it) {
      rewriter.clone(*it, mapping);
      opsToErase.push_back(&*it);
    }
    for (auto *it : opsToErase)
      rewriter.eraseOp(it);
  }
};

/// Returns the LLVM type of the global variable given the memref type `type`.
static Type convertGlobalMemrefTypeToLLVM(MemRefType type,
                                          LLVMTypeConverter &typeConverter) {
  // LLVM type for a global memref will be a multi-dimensional array. For
  // declarations or uninitialized global memrefs, we can potentially flatten
  // this to a 1D array. However, for memref.global's with an initial value,
  // we do not intend to flatten the ElementsAttribute when going from std ->
  // LLVM dialect, so the LLVM type needs to be a multi-dimensional array.
  Type elementType = typeConverter.convertType(type.getElementType());
  Type arrayTy = elementType;
  // Shape has the outermost dim at index 0, so we need to walk it backwards.
  for (int64_t dim : llvm::reverse(type.getShape()))
    arrayTy = LLVM::LLVMArrayType::get(arrayTy, dim);
  return arrayTy;
}

/// GlobalMemrefOp is lowered to a LLVM Global Variable.
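/// For example (schematically), a definition such as
///   memref.global "private" constant @cst : memref<2xf32> = dense<[1.0, 2.0]>
/// becomes an llvm.mlir.global of type !llvm.array<2 x f32> carrying the same
/// initial value.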
566 struct GlobalMemrefOpLowering 567 : public ConvertOpToLLVMPattern<memref::GlobalOp> { 568 using ConvertOpToLLVMPattern<memref::GlobalOp>::ConvertOpToLLVMPattern; 569 570 LogicalResult 571 matchAndRewrite(memref::GlobalOp global, OpAdaptor adaptor, 572 ConversionPatternRewriter &rewriter) const override { 573 MemRefType type = global.type(); 574 if (!isConvertibleAndHasIdentityMaps(type)) 575 return failure(); 576 577 Type arrayTy = convertGlobalMemrefTypeToLLVM(type, *getTypeConverter()); 578 579 LLVM::Linkage linkage = 580 global.isPublic() ? LLVM::Linkage::External : LLVM::Linkage::Private; 581 582 Attribute initialValue = nullptr; 583 if (!global.isExternal() && !global.isUninitialized()) { 584 auto elementsAttr = global.initial_value()->cast<ElementsAttr>(); 585 initialValue = elementsAttr; 586 587 // For scalar memrefs, the global variable created is of the element type, 588 // so unpack the elements attribute to extract the value. 589 if (type.getRank() == 0) 590 initialValue = elementsAttr.getSplatValue<Attribute>(); 591 } 592 593 uint64_t alignment = global.alignment().getValueOr(0); 594 595 auto newGlobal = rewriter.replaceOpWithNewOp<LLVM::GlobalOp>( 596 global, arrayTy, global.constant(), linkage, global.sym_name(), 597 initialValue, alignment, type.getMemorySpaceAsInt()); 598 if (!global.isExternal() && global.isUninitialized()) { 599 Block *blk = new Block(); 600 newGlobal.getInitializerRegion().push_back(blk); 601 rewriter.setInsertionPointToStart(blk); 602 Value undef[] = { 603 rewriter.create<LLVM::UndefOp>(global.getLoc(), arrayTy)}; 604 rewriter.create<LLVM::ReturnOp>(global.getLoc(), undef); 605 } 606 return success(); 607 } 608 }; 609 610 /// GetGlobalMemrefOp is lowered into a Memref descriptor with the pointer to 611 /// the first element stashed into the descriptor. This reuses 612 /// `AllocLikeOpLowering` to reuse the Memref descriptor construction. 613 struct GetGlobalMemrefOpLowering : public AllocLikeOpLLVMLowering { 614 GetGlobalMemrefOpLowering(LLVMTypeConverter &converter) 615 : AllocLikeOpLLVMLowering(memref::GetGlobalOp::getOperationName(), 616 converter) {} 617 618 /// Buffer "allocation" for memref.get_global op is getting the address of 619 /// the global variable referenced. 620 std::tuple<Value, Value> allocateBuffer(ConversionPatternRewriter &rewriter, 621 Location loc, Value sizeBytes, 622 Operation *op) const override { 623 auto getGlobalOp = cast<memref::GetGlobalOp>(op); 624 MemRefType type = getGlobalOp.result().getType().cast<MemRefType>(); 625 unsigned memSpace = type.getMemorySpaceAsInt(); 626 627 Type arrayTy = convertGlobalMemrefTypeToLLVM(type, *getTypeConverter()); 628 auto addressOf = rewriter.create<LLVM::AddressOfOp>( 629 loc, LLVM::LLVMPointerType::get(arrayTy, memSpace), getGlobalOp.name()); 630 631 // Get the address of the first element in the array by creating a GEP with 632 // the address of the GV as the base, and (rank + 1) number of 0 indices. 633 Type elementType = typeConverter->convertType(type.getElementType()); 634 Type elementPtrType = LLVM::LLVMPointerType::get(elementType, memSpace); 635 636 SmallVector<Value> operands; 637 operands.insert(operands.end(), type.getRank() + 1, 638 createIndexConstant(rewriter, loc, 0)); 639 auto gep = 640 rewriter.create<LLVM::GEPOp>(loc, elementPtrType, addressOf, operands); 641 642 // We do not expect the memref obtained using `memref.get_global` to be 643 // ever deallocated. Set the allocated pointer to be known bad value to 644 // help debug if that ever happens. 
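    // (0xdeadbeef is deliberately not something `free` could ever have been
    // handed, so a stray dealloc is easy to spot when debugging.)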
645 auto intPtrType = getIntPtrType(memSpace); 646 Value deadBeefConst = 647 createIndexAttrConstant(rewriter, op->getLoc(), intPtrType, 0xdeadbeef); 648 auto deadBeefPtr = 649 rewriter.create<LLVM::IntToPtrOp>(loc, elementPtrType, deadBeefConst); 650 651 // Both allocated and aligned pointers are same. We could potentially stash 652 // a nullptr for the allocated pointer since we do not expect any dealloc. 653 return std::make_tuple(deadBeefPtr, gep); 654 } 655 }; 656 657 // Load operation is lowered to obtaining a pointer to the indexed element 658 // and loading it. 659 struct LoadOpLowering : public LoadStoreOpLowering<memref::LoadOp> { 660 using Base::Base; 661 662 LogicalResult 663 matchAndRewrite(memref::LoadOp loadOp, OpAdaptor adaptor, 664 ConversionPatternRewriter &rewriter) const override { 665 auto type = loadOp.getMemRefType(); 666 667 Value dataPtr = getStridedElementPtr( 668 loadOp.getLoc(), type, adaptor.memref(), adaptor.indices(), rewriter); 669 rewriter.replaceOpWithNewOp<LLVM::LoadOp>(loadOp, dataPtr); 670 return success(); 671 } 672 }; 673 674 // Store operation is lowered to obtaining a pointer to the indexed element, 675 // and storing the given value to it. 676 struct StoreOpLowering : public LoadStoreOpLowering<memref::StoreOp> { 677 using Base::Base; 678 679 LogicalResult 680 matchAndRewrite(memref::StoreOp op, OpAdaptor adaptor, 681 ConversionPatternRewriter &rewriter) const override { 682 auto type = op.getMemRefType(); 683 684 Value dataPtr = getStridedElementPtr(op.getLoc(), type, adaptor.memref(), 685 adaptor.indices(), rewriter); 686 rewriter.replaceOpWithNewOp<LLVM::StoreOp>(op, adaptor.value(), dataPtr); 687 return success(); 688 } 689 }; 690 691 // The prefetch operation is lowered in a way similar to the load operation 692 // except that the llvm.prefetch operation is used for replacement. 693 struct PrefetchOpLowering : public LoadStoreOpLowering<memref::PrefetchOp> { 694 using Base::Base; 695 696 LogicalResult 697 matchAndRewrite(memref::PrefetchOp prefetchOp, OpAdaptor adaptor, 698 ConversionPatternRewriter &rewriter) const override { 699 auto type = prefetchOp.getMemRefType(); 700 auto loc = prefetchOp.getLoc(); 701 702 Value dataPtr = getStridedElementPtr(loc, type, adaptor.memref(), 703 adaptor.indices(), rewriter); 704 705 // Replace with llvm.prefetch. 
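    // llvm.prefetch takes its read/write, locality and cache-type operands as
    // i32 constants, so materialize them from the op's attributes first.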
706 auto llvmI32Type = typeConverter->convertType(rewriter.getIntegerType(32)); 707 auto isWrite = rewriter.create<LLVM::ConstantOp>( 708 loc, llvmI32Type, rewriter.getI32IntegerAttr(prefetchOp.isWrite())); 709 auto localityHint = rewriter.create<LLVM::ConstantOp>( 710 loc, llvmI32Type, 711 rewriter.getI32IntegerAttr(prefetchOp.localityHint())); 712 auto isData = rewriter.create<LLVM::ConstantOp>( 713 loc, llvmI32Type, rewriter.getI32IntegerAttr(prefetchOp.isDataCache())); 714 715 rewriter.replaceOpWithNewOp<LLVM::Prefetch>(prefetchOp, dataPtr, isWrite, 716 localityHint, isData); 717 return success(); 718 } 719 }; 720 721 struct RankOpLowering : public ConvertOpToLLVMPattern<memref::RankOp> { 722 using ConvertOpToLLVMPattern<memref::RankOp>::ConvertOpToLLVMPattern; 723 724 LogicalResult 725 matchAndRewrite(memref::RankOp op, OpAdaptor adaptor, 726 ConversionPatternRewriter &rewriter) const override { 727 Location loc = op.getLoc(); 728 Type operandType = op.memref().getType(); 729 if (auto unrankedMemRefType = operandType.dyn_cast<UnrankedMemRefType>()) { 730 UnrankedMemRefDescriptor desc(adaptor.memref()); 731 rewriter.replaceOp(op, {desc.rank(rewriter, loc)}); 732 return success(); 733 } 734 if (auto rankedMemRefType = operandType.dyn_cast<MemRefType>()) { 735 rewriter.replaceOp( 736 op, {createIndexConstant(rewriter, loc, rankedMemRefType.getRank())}); 737 return success(); 738 } 739 return failure(); 740 } 741 }; 742 743 struct MemRefCastOpLowering : public ConvertOpToLLVMPattern<memref::CastOp> { 744 using ConvertOpToLLVMPattern<memref::CastOp>::ConvertOpToLLVMPattern; 745 746 LogicalResult match(memref::CastOp memRefCastOp) const override { 747 Type srcType = memRefCastOp.getOperand().getType(); 748 Type dstType = memRefCastOp.getType(); 749 750 // memref::CastOp reduce to bitcast in the ranked MemRef case and can be 751 // used for type erasure. For now they must preserve underlying element type 752 // and require source and result type to have the same rank. Therefore, 753 // perform a sanity check that the underlying structs are the same. Once op 754 // semantics are relaxed we can revisit. 755 if (srcType.isa<MemRefType>() && dstType.isa<MemRefType>()) 756 return success(typeConverter->convertType(srcType) == 757 typeConverter->convertType(dstType)); 758 759 // At least one of the operands is unranked type 760 assert(srcType.isa<UnrankedMemRefType>() || 761 dstType.isa<UnrankedMemRefType>()); 762 763 // Unranked to unranked cast is disallowed 764 return !(srcType.isa<UnrankedMemRefType>() && 765 dstType.isa<UnrankedMemRefType>()) 766 ? success() 767 : failure(); 768 } 769 770 void rewrite(memref::CastOp memRefCastOp, OpAdaptor adaptor, 771 ConversionPatternRewriter &rewriter) const override { 772 auto srcType = memRefCastOp.getOperand().getType(); 773 auto dstType = memRefCastOp.getType(); 774 auto targetStructType = typeConverter->convertType(memRefCastOp.getType()); 775 auto loc = memRefCastOp.getLoc(); 776 777 // For ranked/ranked case, just keep the original descriptor. 
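    // (match() already verified that source and destination convert to the
    // same descriptor type, so reusing the source descriptor is safe.)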
778 if (srcType.isa<MemRefType>() && dstType.isa<MemRefType>()) 779 return rewriter.replaceOp(memRefCastOp, {adaptor.source()}); 780 781 if (srcType.isa<MemRefType>() && dstType.isa<UnrankedMemRefType>()) { 782 // Casting ranked to unranked memref type 783 // Set the rank in the destination from the memref type 784 // Allocate space on the stack and copy the src memref descriptor 785 // Set the ptr in the destination to the stack space 786 auto srcMemRefType = srcType.cast<MemRefType>(); 787 int64_t rank = srcMemRefType.getRank(); 788 // ptr = AllocaOp sizeof(MemRefDescriptor) 789 auto ptr = getTypeConverter()->promoteOneMemRefDescriptor( 790 loc, adaptor.source(), rewriter); 791 // voidptr = BitCastOp srcType* to void* 792 auto voidPtr = 793 rewriter.create<LLVM::BitcastOp>(loc, getVoidPtrType(), ptr) 794 .getResult(); 795 // rank = ConstantOp srcRank 796 auto rankVal = rewriter.create<LLVM::ConstantOp>( 797 loc, typeConverter->convertType(rewriter.getIntegerType(64)), 798 rewriter.getI64IntegerAttr(rank)); 799 // undef = UndefOp 800 UnrankedMemRefDescriptor memRefDesc = 801 UnrankedMemRefDescriptor::undef(rewriter, loc, targetStructType); 802 // d1 = InsertValueOp undef, rank, 0 803 memRefDesc.setRank(rewriter, loc, rankVal); 804 // d2 = InsertValueOp d1, voidptr, 1 805 memRefDesc.setMemRefDescPtr(rewriter, loc, voidPtr); 806 rewriter.replaceOp(memRefCastOp, (Value)memRefDesc); 807 808 } else if (srcType.isa<UnrankedMemRefType>() && dstType.isa<MemRefType>()) { 809 // Casting from unranked type to ranked. 810 // The operation is assumed to be doing a correct cast. If the destination 811 // type mismatches the unranked the type, it is undefined behavior. 812 UnrankedMemRefDescriptor memRefDesc(adaptor.source()); 813 // ptr = ExtractValueOp src, 1 814 auto ptr = memRefDesc.memRefDescPtr(rewriter, loc); 815 // castPtr = BitCastOp i8* to structTy* 816 auto castPtr = 817 rewriter 818 .create<LLVM::BitcastOp>( 819 loc, LLVM::LLVMPointerType::get(targetStructType), ptr) 820 .getResult(); 821 // struct = LoadOp castPtr 822 auto loadOp = rewriter.create<LLVM::LoadOp>(loc, castPtr); 823 rewriter.replaceOp(memRefCastOp, loadOp.getResult()); 824 } else { 825 llvm_unreachable("Unsupported unranked memref to unranked memref cast"); 826 } 827 } 828 }; 829 830 /// Pattern to lower a `memref.copy` to llvm. 831 /// 832 /// For memrefs with identity layouts, the copy is lowered to the llvm 833 /// `memcpy` intrinsic. For non-identity layouts, the copy is lowered to a call 834 /// to the generic `MemrefCopyFn`. 835 struct MemRefCopyOpLowering : public ConvertOpToLLVMPattern<memref::CopyOp> { 836 using ConvertOpToLLVMPattern<memref::CopyOp>::ConvertOpToLLVMPattern; 837 838 LogicalResult 839 lowerToMemCopyIntrinsic(memref::CopyOp op, OpAdaptor adaptor, 840 ConversionPatternRewriter &rewriter) const { 841 auto loc = op.getLoc(); 842 auto srcType = op.source().getType().dyn_cast<MemRefType>(); 843 844 MemRefDescriptor srcDesc(adaptor.source()); 845 846 // Compute number of elements. 847 Value numElements = rewriter.create<LLVM::ConstantOp>( 848 loc, getIndexType(), rewriter.getIndexAttr(1)); 849 for (int pos = 0; pos < srcType.getRank(); ++pos) { 850 auto size = srcDesc.size(rewriter, loc, pos); 851 numElements = rewriter.create<LLVM::MulOp>(loc, numElements, size); 852 } 853 854 // Get element size. 855 auto sizeInBytes = getSizeInBytes(loc, srcType.getElementType(), rewriter); 856 // Compute total. 
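    // totalSize = numElements * sizeof(element).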
857 Value totalSize = 858 rewriter.create<LLVM::MulOp>(loc, numElements, sizeInBytes); 859 860 Value srcBasePtr = srcDesc.alignedPtr(rewriter, loc); 861 MemRefDescriptor targetDesc(adaptor.target()); 862 Value targetBasePtr = targetDesc.alignedPtr(rewriter, loc); 863 Value isVolatile = rewriter.create<LLVM::ConstantOp>( 864 loc, typeConverter->convertType(rewriter.getI1Type()), 865 rewriter.getBoolAttr(false)); 866 rewriter.create<LLVM::MemcpyOp>(loc, targetBasePtr, srcBasePtr, totalSize, 867 isVolatile); 868 rewriter.eraseOp(op); 869 870 return success(); 871 } 872 873 LogicalResult 874 lowerToMemCopyFunctionCall(memref::CopyOp op, OpAdaptor adaptor, 875 ConversionPatternRewriter &rewriter) const { 876 auto loc = op.getLoc(); 877 auto srcType = op.source().getType().cast<BaseMemRefType>(); 878 auto targetType = op.target().getType().cast<BaseMemRefType>(); 879 880 // First make sure we have an unranked memref descriptor representation. 881 auto makeUnranked = [&, this](Value ranked, BaseMemRefType type) { 882 auto rank = rewriter.create<LLVM::ConstantOp>( 883 loc, getIndexType(), rewriter.getIndexAttr(type.getRank())); 884 auto *typeConverter = getTypeConverter(); 885 auto ptr = 886 typeConverter->promoteOneMemRefDescriptor(loc, ranked, rewriter); 887 auto voidPtr = 888 rewriter.create<LLVM::BitcastOp>(loc, getVoidPtrType(), ptr) 889 .getResult(); 890 auto unrankedType = 891 UnrankedMemRefType::get(type.getElementType(), type.getMemorySpace()); 892 return UnrankedMemRefDescriptor::pack(rewriter, loc, *typeConverter, 893 unrankedType, 894 ValueRange{rank, voidPtr}); 895 }; 896 897 Value unrankedSource = srcType.hasRank() 898 ? makeUnranked(adaptor.source(), srcType) 899 : adaptor.source(); 900 Value unrankedTarget = targetType.hasRank() 901 ? makeUnranked(adaptor.target(), targetType) 902 : adaptor.target(); 903 904 // Now promote the unranked descriptors to the stack. 
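    // The external copy function takes pointers to the unranked descriptors,
    // so store each descriptor value into a fresh stack slot.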
905 auto one = rewriter.create<LLVM::ConstantOp>(loc, getIndexType(), 906 rewriter.getIndexAttr(1)); 907 auto promote = [&](Value desc) { 908 auto ptrType = LLVM::LLVMPointerType::get(desc.getType()); 909 auto allocated = 910 rewriter.create<LLVM::AllocaOp>(loc, ptrType, ValueRange{one}); 911 rewriter.create<LLVM::StoreOp>(loc, desc, allocated); 912 return allocated; 913 }; 914 915 auto sourcePtr = promote(unrankedSource); 916 auto targetPtr = promote(unrankedTarget); 917 918 unsigned bitwidth = mlir::DataLayout::closest(op).getTypeSizeInBits( 919 srcType.getElementType()); 920 auto elemSize = rewriter.create<LLVM::ConstantOp>( 921 loc, getIndexType(), rewriter.getIndexAttr(bitwidth / 8)); 922 auto copyFn = LLVM::lookupOrCreateMemRefCopyFn( 923 op->getParentOfType<ModuleOp>(), getIndexType(), sourcePtr.getType()); 924 rewriter.create<LLVM::CallOp>(loc, copyFn, 925 ValueRange{elemSize, sourcePtr, targetPtr}); 926 rewriter.eraseOp(op); 927 928 return success(); 929 } 930 931 LogicalResult 932 matchAndRewrite(memref::CopyOp op, OpAdaptor adaptor, 933 ConversionPatternRewriter &rewriter) const override { 934 auto srcType = op.source().getType().cast<BaseMemRefType>(); 935 auto targetType = op.target().getType().cast<BaseMemRefType>(); 936 937 if (srcType.hasRank() && 938 srcType.cast<MemRefType>().getLayout().isIdentity() && 939 targetType.hasRank() && 940 targetType.cast<MemRefType>().getLayout().isIdentity()) 941 return lowerToMemCopyIntrinsic(op, adaptor, rewriter); 942 943 return lowerToMemCopyFunctionCall(op, adaptor, rewriter); 944 } 945 }; 946 947 /// Extracts allocated, aligned pointers and offset from a ranked or unranked 948 /// memref type. In unranked case, the fields are extracted from the underlying 949 /// ranked descriptor. 950 static void extractPointersAndOffset(Location loc, 951 ConversionPatternRewriter &rewriter, 952 LLVMTypeConverter &typeConverter, 953 Value originalOperand, 954 Value convertedOperand, 955 Value *allocatedPtr, Value *alignedPtr, 956 Value *offset = nullptr) { 957 Type operandType = originalOperand.getType(); 958 if (operandType.isa<MemRefType>()) { 959 MemRefDescriptor desc(convertedOperand); 960 *allocatedPtr = desc.allocatedPtr(rewriter, loc); 961 *alignedPtr = desc.alignedPtr(rewriter, loc); 962 if (offset != nullptr) 963 *offset = desc.offset(rewriter, loc); 964 return; 965 } 966 967 unsigned memorySpace = 968 operandType.cast<UnrankedMemRefType>().getMemorySpaceAsInt(); 969 Type elementType = operandType.cast<UnrankedMemRefType>().getElementType(); 970 Type llvmElementType = typeConverter.convertType(elementType); 971 Type elementPtrPtrType = LLVM::LLVMPointerType::get( 972 LLVM::LLVMPointerType::get(llvmElementType, memorySpace)); 973 974 // Extract pointer to the underlying ranked memref descriptor and cast it to 975 // ElemType**. 
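  // The first two descriptor fields are the allocated and aligned pointers,
  // hence the ElemType** view of the underlying storage.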
976 UnrankedMemRefDescriptor unrankedDesc(convertedOperand); 977 Value underlyingDescPtr = unrankedDesc.memRefDescPtr(rewriter, loc); 978 979 *allocatedPtr = UnrankedMemRefDescriptor::allocatedPtr( 980 rewriter, loc, underlyingDescPtr, elementPtrPtrType); 981 *alignedPtr = UnrankedMemRefDescriptor::alignedPtr( 982 rewriter, loc, typeConverter, underlyingDescPtr, elementPtrPtrType); 983 if (offset != nullptr) { 984 *offset = UnrankedMemRefDescriptor::offset( 985 rewriter, loc, typeConverter, underlyingDescPtr, elementPtrPtrType); 986 } 987 } 988 989 struct MemRefReinterpretCastOpLowering 990 : public ConvertOpToLLVMPattern<memref::ReinterpretCastOp> { 991 using ConvertOpToLLVMPattern< 992 memref::ReinterpretCastOp>::ConvertOpToLLVMPattern; 993 994 LogicalResult 995 matchAndRewrite(memref::ReinterpretCastOp castOp, OpAdaptor adaptor, 996 ConversionPatternRewriter &rewriter) const override { 997 Type srcType = castOp.source().getType(); 998 999 Value descriptor; 1000 if (failed(convertSourceMemRefToDescriptor(rewriter, srcType, castOp, 1001 adaptor, &descriptor))) 1002 return failure(); 1003 rewriter.replaceOp(castOp, {descriptor}); 1004 return success(); 1005 } 1006 1007 private: 1008 LogicalResult convertSourceMemRefToDescriptor( 1009 ConversionPatternRewriter &rewriter, Type srcType, 1010 memref::ReinterpretCastOp castOp, 1011 memref::ReinterpretCastOp::Adaptor adaptor, Value *descriptor) const { 1012 MemRefType targetMemRefType = 1013 castOp.getResult().getType().cast<MemRefType>(); 1014 auto llvmTargetDescriptorTy = typeConverter->convertType(targetMemRefType) 1015 .dyn_cast_or_null<LLVM::LLVMStructType>(); 1016 if (!llvmTargetDescriptorTy) 1017 return failure(); 1018 1019 // Create descriptor. 1020 Location loc = castOp.getLoc(); 1021 auto desc = MemRefDescriptor::undef(rewriter, loc, llvmTargetDescriptorTy); 1022 1023 // Set allocated and aligned pointers. 1024 Value allocatedPtr, alignedPtr; 1025 extractPointersAndOffset(loc, rewriter, *getTypeConverter(), 1026 castOp.source(), adaptor.source(), &allocatedPtr, 1027 &alignedPtr); 1028 desc.setAllocatedPtr(rewriter, loc, allocatedPtr); 1029 desc.setAlignedPtr(rewriter, loc, alignedPtr); 1030 1031 // Set offset. 1032 if (castOp.isDynamicOffset(0)) 1033 desc.setOffset(rewriter, loc, adaptor.offsets()[0]); 1034 else 1035 desc.setConstantOffset(rewriter, loc, castOp.getStaticOffset(0)); 1036 1037 // Set sizes and strides. 
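    // Static values come straight from the op's attributes; dynamic ones are
    // consumed from the adaptor's operand lists in order.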
1038 unsigned dynSizeId = 0; 1039 unsigned dynStrideId = 0; 1040 for (unsigned i = 0, e = targetMemRefType.getRank(); i < e; ++i) { 1041 if (castOp.isDynamicSize(i)) 1042 desc.setSize(rewriter, loc, i, adaptor.sizes()[dynSizeId++]); 1043 else 1044 desc.setConstantSize(rewriter, loc, i, castOp.getStaticSize(i)); 1045 1046 if (castOp.isDynamicStride(i)) 1047 desc.setStride(rewriter, loc, i, adaptor.strides()[dynStrideId++]); 1048 else 1049 desc.setConstantStride(rewriter, loc, i, castOp.getStaticStride(i)); 1050 } 1051 *descriptor = desc; 1052 return success(); 1053 } 1054 }; 1055 1056 struct MemRefReshapeOpLowering 1057 : public ConvertOpToLLVMPattern<memref::ReshapeOp> { 1058 using ConvertOpToLLVMPattern<memref::ReshapeOp>::ConvertOpToLLVMPattern; 1059 1060 LogicalResult 1061 matchAndRewrite(memref::ReshapeOp reshapeOp, OpAdaptor adaptor, 1062 ConversionPatternRewriter &rewriter) const override { 1063 Type srcType = reshapeOp.source().getType(); 1064 1065 Value descriptor; 1066 if (failed(convertSourceMemRefToDescriptor(rewriter, srcType, reshapeOp, 1067 adaptor, &descriptor))) 1068 return failure(); 1069 rewriter.replaceOp(reshapeOp, {descriptor}); 1070 return success(); 1071 } 1072 1073 private: 1074 LogicalResult 1075 convertSourceMemRefToDescriptor(ConversionPatternRewriter &rewriter, 1076 Type srcType, memref::ReshapeOp reshapeOp, 1077 memref::ReshapeOp::Adaptor adaptor, 1078 Value *descriptor) const { 1079 // Conversion for statically-known shape args is performed via 1080 // `memref_reinterpret_cast`. 1081 auto shapeMemRefType = reshapeOp.shape().getType().cast<MemRefType>(); 1082 if (shapeMemRefType.hasStaticShape()) 1083 return failure(); 1084 1085 // The shape is a rank-1 tensor with unknown length. 1086 Location loc = reshapeOp.getLoc(); 1087 MemRefDescriptor shapeDesc(adaptor.shape()); 1088 Value resultRank = shapeDesc.size(rewriter, loc, 0); 1089 1090 // Extract address space and element type. 1091 auto targetType = 1092 reshapeOp.getResult().getType().cast<UnrankedMemRefType>(); 1093 unsigned addressSpace = targetType.getMemorySpaceAsInt(); 1094 Type elementType = targetType.getElementType(); 1095 1096 // Create the unranked memref descriptor that holds the ranked one. The 1097 // inner descriptor is allocated on stack. 1098 auto targetDesc = UnrankedMemRefDescriptor::undef( 1099 rewriter, loc, typeConverter->convertType(targetType)); 1100 targetDesc.setRank(rewriter, loc, resultRank); 1101 SmallVector<Value, 4> sizes; 1102 UnrankedMemRefDescriptor::computeSizes(rewriter, loc, *getTypeConverter(), 1103 targetDesc, sizes); 1104 Value underlyingDescPtr = rewriter.create<LLVM::AllocaOp>( 1105 loc, getVoidPtrType(), sizes.front(), llvm::None); 1106 targetDesc.setMemRefDescPtr(rewriter, loc, underlyingDescPtr); 1107 1108 // Extract pointers and offset from the source memref. 1109 Value allocatedPtr, alignedPtr, offset; 1110 extractPointersAndOffset(loc, rewriter, *getTypeConverter(), 1111 reshapeOp.source(), adaptor.source(), 1112 &allocatedPtr, &alignedPtr, &offset); 1113 1114 // Set pointers and offset. 
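    // All field writes below go through `underlyingDescPtr`, the stack slot
    // holding the inner ranked descriptor.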
1115 Type llvmElementType = typeConverter->convertType(elementType); 1116 auto elementPtrPtrType = LLVM::LLVMPointerType::get( 1117 LLVM::LLVMPointerType::get(llvmElementType, addressSpace)); 1118 UnrankedMemRefDescriptor::setAllocatedPtr(rewriter, loc, underlyingDescPtr, 1119 elementPtrPtrType, allocatedPtr); 1120 UnrankedMemRefDescriptor::setAlignedPtr(rewriter, loc, *getTypeConverter(), 1121 underlyingDescPtr, 1122 elementPtrPtrType, alignedPtr); 1123 UnrankedMemRefDescriptor::setOffset(rewriter, loc, *getTypeConverter(), 1124 underlyingDescPtr, elementPtrPtrType, 1125 offset); 1126 1127 // Use the offset pointer as base for further addressing. Copy over the new 1128 // shape and compute strides. For this, we create a loop from rank-1 to 0. 1129 Value targetSizesBase = UnrankedMemRefDescriptor::sizeBasePtr( 1130 rewriter, loc, *getTypeConverter(), underlyingDescPtr, 1131 elementPtrPtrType); 1132 Value targetStridesBase = UnrankedMemRefDescriptor::strideBasePtr( 1133 rewriter, loc, *getTypeConverter(), targetSizesBase, resultRank); 1134 Value shapeOperandPtr = shapeDesc.alignedPtr(rewriter, loc); 1135 Value oneIndex = createIndexConstant(rewriter, loc, 1); 1136 Value resultRankMinusOne = 1137 rewriter.create<LLVM::SubOp>(loc, resultRank, oneIndex); 1138 1139 Block *initBlock = rewriter.getInsertionBlock(); 1140 Type indexType = getTypeConverter()->getIndexType(); 1141 Block::iterator remainingOpsIt = std::next(rewriter.getInsertionPoint()); 1142 1143 Block *condBlock = rewriter.createBlock(initBlock->getParent(), {}, 1144 {indexType, indexType}, {loc, loc}); 1145 1146 // Move the remaining initBlock ops to condBlock. 1147 Block *remainingBlock = rewriter.splitBlock(initBlock, remainingOpsIt); 1148 rewriter.mergeBlocks(remainingBlock, condBlock, ValueRange()); 1149 1150 rewriter.setInsertionPointToEnd(initBlock); 1151 rewriter.create<LLVM::BrOp>(loc, ValueRange({resultRankMinusOne, oneIndex}), 1152 condBlock); 1153 rewriter.setInsertionPointToStart(condBlock); 1154 Value indexArg = condBlock->getArgument(0); 1155 Value strideArg = condBlock->getArgument(1); 1156 1157 Value zeroIndex = createIndexConstant(rewriter, loc, 0); 1158 Value pred = rewriter.create<LLVM::ICmpOp>( 1159 loc, IntegerType::get(rewriter.getContext(), 1), 1160 LLVM::ICmpPredicate::sge, indexArg, zeroIndex); 1161 1162 Block *bodyBlock = 1163 rewriter.splitBlock(condBlock, rewriter.getInsertionPoint()); 1164 rewriter.setInsertionPointToStart(bodyBlock); 1165 1166 // Copy size from shape to descriptor. 1167 Type llvmIndexPtrType = LLVM::LLVMPointerType::get(indexType); 1168 Value sizeLoadGep = rewriter.create<LLVM::GEPOp>( 1169 loc, llvmIndexPtrType, shapeOperandPtr, ValueRange{indexArg}); 1170 Value size = rewriter.create<LLVM::LoadOp>(loc, sizeLoadGep); 1171 UnrankedMemRefDescriptor::setSize(rewriter, loc, *getTypeConverter(), 1172 targetSizesBase, indexArg, size); 1173 1174 // Write stride value and compute next one. 1175 UnrankedMemRefDescriptor::setStride(rewriter, loc, *getTypeConverter(), 1176 targetStridesBase, indexArg, strideArg); 1177 Value nextStride = rewriter.create<LLVM::MulOp>(loc, strideArg, size); 1178 1179 // Decrement loop counter and branch back. 1180 Value decrement = rewriter.create<LLVM::SubOp>(loc, indexArg, oneIndex); 1181 rewriter.create<LLVM::BrOp>(loc, ValueRange({decrement, nextStride}), 1182 condBlock); 1183 1184 Block *remainder = 1185 rewriter.splitBlock(bodyBlock, rewriter.getInsertionPoint()); 1186 1187 // Hook up the cond exit to the remainder. 
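    // Loop structure: initBlock -> condBlock <-> bodyBlock, with condBlock
    // exiting to `remainder` once the dimension index drops below zero.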
1188 rewriter.setInsertionPointToEnd(condBlock); 1189 rewriter.create<LLVM::CondBrOp>(loc, pred, bodyBlock, llvm::None, remainder, 1190 llvm::None); 1191 1192 // Reset position to beginning of new remainder block. 1193 rewriter.setInsertionPointToStart(remainder); 1194 1195 *descriptor = targetDesc; 1196 return success(); 1197 } 1198 }; 1199 1200 /// Helper function to convert a vector of `OpFoldResult`s into a vector of 1201 /// `Value`s. 1202 static SmallVector<Value> getAsValues(OpBuilder &b, Location loc, 1203 Type &llvmIndexType, 1204 ArrayRef<OpFoldResult> valueOrAttrVec) { 1205 return llvm::to_vector<4>( 1206 llvm::map_range(valueOrAttrVec, [&](OpFoldResult value) -> Value { 1207 if (auto attr = value.dyn_cast<Attribute>()) 1208 return b.create<LLVM::ConstantOp>(loc, llvmIndexType, attr); 1209 return value.get<Value>(); 1210 })); 1211 } 1212 1213 /// Compute a map that for a given dimension of the expanded type gives the 1214 /// dimension in the collapsed type it maps to. Essentially its the inverse of 1215 /// the `reassocation` maps. 1216 static DenseMap<int64_t, int64_t> 1217 getExpandedDimToCollapsedDimMap(ArrayRef<ReassociationIndices> reassociation) { 1218 llvm::DenseMap<int64_t, int64_t> expandedDimToCollapsedDim; 1219 for (auto &en : enumerate(reassociation)) { 1220 for (auto dim : en.value()) 1221 expandedDimToCollapsedDim[dim] = en.index(); 1222 } 1223 return expandedDimToCollapsedDim; 1224 } 1225 1226 static OpFoldResult 1227 getExpandedOutputDimSize(OpBuilder &b, Location loc, Type &llvmIndexType, 1228 int64_t outDimIndex, ArrayRef<int64_t> outStaticShape, 1229 MemRefDescriptor &inDesc, 1230 ArrayRef<int64_t> inStaticShape, 1231 ArrayRef<ReassociationIndices> reassocation, 1232 DenseMap<int64_t, int64_t> &outDimToInDimMap) { 1233 int64_t outDimSize = outStaticShape[outDimIndex]; 1234 if (!ShapedType::isDynamic(outDimSize)) 1235 return b.getIndexAttr(outDimSize); 1236 1237 // Calculate the multiplication of all the out dim sizes except the 1238 // current dim. 1239 int64_t inDimIndex = outDimToInDimMap[outDimIndex]; 1240 int64_t otherDimSizesMul = 1; 1241 for (auto otherDimIndex : reassocation[inDimIndex]) { 1242 if (otherDimIndex == static_cast<unsigned>(outDimIndex)) 1243 continue; 1244 int64_t otherDimSize = outStaticShape[otherDimIndex]; 1245 assert(!ShapedType::isDynamic(otherDimSize) && 1246 "single dimension cannot be expanded into multiple dynamic " 1247 "dimensions"); 1248 otherDimSizesMul *= otherDimSize; 1249 } 1250 1251 // outDimSize = inDimSize / otherOutDimSizesMul 1252 int64_t inDimSize = inStaticShape[inDimIndex]; 1253 Value inDimSizeDynamic = 1254 ShapedType::isDynamic(inDimSize) 1255 ? 
inDesc.size(b, loc, inDimIndex) 1256 : b.create<LLVM::ConstantOp>(loc, llvmIndexType, 1257 b.getIndexAttr(inDimSize)); 1258 Value outDimSizeDynamic = b.create<LLVM::SDivOp>( 1259 loc, inDimSizeDynamic, 1260 b.create<LLVM::ConstantOp>(loc, llvmIndexType, 1261 b.getIndexAttr(otherDimSizesMul))); 1262 return outDimSizeDynamic; 1263 } 1264 1265 static OpFoldResult getCollapsedOutputDimSize( 1266 OpBuilder &b, Location loc, Type &llvmIndexType, int64_t outDimIndex, 1267 int64_t outDimSize, ArrayRef<int64_t> inStaticShape, 1268 MemRefDescriptor &inDesc, ArrayRef<ReassociationIndices> reassocation) { 1269 if (!ShapedType::isDynamic(outDimSize)) 1270 return b.getIndexAttr(outDimSize); 1271 1272 Value c1 = b.create<LLVM::ConstantOp>(loc, llvmIndexType, b.getIndexAttr(1)); 1273 Value outDimSizeDynamic = c1; 1274 for (auto inDimIndex : reassocation[outDimIndex]) { 1275 int64_t inDimSize = inStaticShape[inDimIndex]; 1276 Value inDimSizeDynamic = 1277 ShapedType::isDynamic(inDimSize) 1278 ? inDesc.size(b, loc, inDimIndex) 1279 : b.create<LLVM::ConstantOp>(loc, llvmIndexType, 1280 b.getIndexAttr(inDimSize)); 1281 outDimSizeDynamic = 1282 b.create<LLVM::MulOp>(loc, outDimSizeDynamic, inDimSizeDynamic); 1283 } 1284 return outDimSizeDynamic; 1285 } 1286 1287 static SmallVector<OpFoldResult, 4> 1288 getCollapsedOutputShape(OpBuilder &b, Location loc, Type &llvmIndexType, 1289 ArrayRef<ReassociationIndices> reassocation, 1290 ArrayRef<int64_t> inStaticShape, 1291 MemRefDescriptor &inDesc, 1292 ArrayRef<int64_t> outStaticShape) { 1293 return llvm::to_vector<4>(llvm::map_range( 1294 llvm::seq<int64_t>(0, outStaticShape.size()), [&](int64_t outDimIndex) { 1295 return getCollapsedOutputDimSize(b, loc, llvmIndexType, outDimIndex, 1296 outStaticShape[outDimIndex], 1297 inStaticShape, inDesc, reassocation); 1298 })); 1299 } 1300 1301 static SmallVector<OpFoldResult, 4> 1302 getExpandedOutputShape(OpBuilder &b, Location loc, Type &llvmIndexType, 1303 ArrayRef<ReassociationIndices> reassocation, 1304 ArrayRef<int64_t> inStaticShape, 1305 MemRefDescriptor &inDesc, 1306 ArrayRef<int64_t> outStaticShape) { 1307 DenseMap<int64_t, int64_t> outDimToInDimMap = 1308 getExpandedDimToCollapsedDimMap(reassocation); 1309 return llvm::to_vector<4>(llvm::map_range( 1310 llvm::seq<int64_t>(0, outStaticShape.size()), [&](int64_t outDimIndex) { 1311 return getExpandedOutputDimSize(b, loc, llvmIndexType, outDimIndex, 1312 outStaticShape, inDesc, inStaticShape, 1313 reassocation, outDimToInDimMap); 1314 })); 1315 } 1316 1317 static SmallVector<Value> 1318 getDynamicOutputShape(OpBuilder &b, Location loc, Type &llvmIndexType, 1319 ArrayRef<ReassociationIndices> reassocation, 1320 ArrayRef<int64_t> inStaticShape, MemRefDescriptor &inDesc, 1321 ArrayRef<int64_t> outStaticShape) { 1322 return outStaticShape.size() < inStaticShape.size() 1323 ? getAsValues(b, loc, llvmIndexType, 1324 getCollapsedOutputShape(b, loc, llvmIndexType, 1325 reassocation, inStaticShape, 1326 inDesc, outStaticShape)) 1327 : getAsValues(b, loc, llvmIndexType, 1328 getExpandedOutputShape(b, loc, llvmIndexType, 1329 reassocation, inStaticShape, 1330 inDesc, outStaticShape)); 1331 } 1332 1333 // ReshapeOp creates a new view descriptor of the proper rank. 1334 // For now, the only conversion supported is for target MemRef with static sizes 1335 // and strides. 
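// For example (schematically),
//   %r = memref.collapse_shape %m [[0, 1]] : memref<2x4xf32> into memref<8xf32>
// reuses the source's pointers and offset and only rewrites the size and
// stride arrays of the result descriptor.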
1336 template <typename ReshapeOp> 1337 class ReassociatingReshapeOpConversion 1338 : public ConvertOpToLLVMPattern<ReshapeOp> { 1339 public: 1340 using ConvertOpToLLVMPattern<ReshapeOp>::ConvertOpToLLVMPattern; 1341 using ReshapeOpAdaptor = typename ReshapeOp::Adaptor; 1342 1343 LogicalResult 1344 matchAndRewrite(ReshapeOp reshapeOp, typename ReshapeOp::Adaptor adaptor, 1345 ConversionPatternRewriter &rewriter) const override { 1346 MemRefType dstType = reshapeOp.getResultType(); 1347 MemRefType srcType = reshapeOp.getSrcType(); 1348 1349 // The condition on the layouts can be ignored when all shapes are static. 1350 if (!srcType.hasStaticShape() || !dstType.hasStaticShape()) { 1351 if (!srcType.getLayout().isIdentity() || 1352 !dstType.getLayout().isIdentity()) { 1353 return rewriter.notifyMatchFailure( 1354 reshapeOp, "only empty layout map is supported"); 1355 } 1356 } 1357 1358 int64_t offset; 1359 SmallVector<int64_t, 4> strides; 1360 if (failed(getStridesAndOffset(dstType, strides, offset))) { 1361 return rewriter.notifyMatchFailure( 1362 reshapeOp, "failed to get stride and offset exprs"); 1363 } 1364 1365 MemRefDescriptor srcDesc(adaptor.src()); 1366 Location loc = reshapeOp->getLoc(); 1367 auto dstDesc = MemRefDescriptor::undef( 1368 rewriter, loc, this->typeConverter->convertType(dstType)); 1369 dstDesc.setAllocatedPtr(rewriter, loc, srcDesc.allocatedPtr(rewriter, loc)); 1370 dstDesc.setAlignedPtr(rewriter, loc, srcDesc.alignedPtr(rewriter, loc)); 1371 dstDesc.setOffset(rewriter, loc, srcDesc.offset(rewriter, loc)); 1372 1373 ArrayRef<int64_t> srcStaticShape = srcType.getShape(); 1374 ArrayRef<int64_t> dstStaticShape = dstType.getShape(); 1375 Type llvmIndexType = 1376 this->typeConverter->convertType(rewriter.getIndexType()); 1377 SmallVector<Value> dstShape = getDynamicOutputShape( 1378 rewriter, loc, llvmIndexType, reshapeOp.getReassociationIndices(), 1379 srcStaticShape, srcDesc, dstStaticShape); 1380 for (auto &en : llvm::enumerate(dstShape)) 1381 dstDesc.setSize(rewriter, loc, en.index(), en.value()); 1382 1383 auto isStaticStride = [](int64_t stride) { 1384 return !ShapedType::isDynamicStrideOrOffset(stride); 1385 }; 1386 if (llvm::all_of(strides, isStaticStride)) { 1387 for (auto &en : llvm::enumerate(strides)) 1388 dstDesc.setConstantStride(rewriter, loc, en.index(), en.value()); 1389 } else { 1390 Value c1 = rewriter.create<LLVM::ConstantOp>(loc, llvmIndexType, 1391 rewriter.getIndexAttr(1)); 1392 Value stride = c1; 1393 for (auto dimIndex : 1394 llvm::reverse(llvm::seq<int64_t>(0, dstShape.size()))) { 1395 dstDesc.setStride(rewriter, loc, dimIndex, stride); 1396 stride = rewriter.create<LLVM::MulOp>(loc, dstShape[dimIndex], stride); 1397 } 1398 } 1399 rewriter.replaceOp(reshapeOp, {dstDesc}); 1400 return success(); 1401 } 1402 }; 1403 1404 /// Conversion pattern that transforms a subview op into: 1405 /// 1. An `llvm.mlir.undef` operation to create a memref descriptor 1406 /// 2. Updates to the descriptor to introduce the data ptr, offset, size 1407 /// and stride. 1408 /// The subview op is replaced by the descriptor. 
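/// No data is copied: the subview only produces a new descriptor whose offset,
/// sizes and strides are derived from the source descriptor and the op's
/// static/dynamic operands.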
struct SubViewOpLowering : public ConvertOpToLLVMPattern<memref::SubViewOp> {
  using ConvertOpToLLVMPattern<memref::SubViewOp>::ConvertOpToLLVMPattern;

  LogicalResult
  matchAndRewrite(memref::SubViewOp subViewOp, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    auto loc = subViewOp.getLoc();

    auto sourceMemRefType = subViewOp.source().getType().cast<MemRefType>();
    auto sourceElementTy =
        typeConverter->convertType(sourceMemRefType.getElementType());

    auto viewMemRefType = subViewOp.getType();
    auto inferredType = memref::SubViewOp::inferResultType(
                            subViewOp.getSourceType(),
                            extractFromI64ArrayAttr(subViewOp.static_offsets()),
                            extractFromI64ArrayAttr(subViewOp.static_sizes()),
                            extractFromI64ArrayAttr(subViewOp.static_strides()))
                            .cast<MemRefType>();
    auto targetElementTy =
        typeConverter->convertType(viewMemRefType.getElementType());
    auto targetDescTy = typeConverter->convertType(viewMemRefType);
    if (!sourceElementTy || !targetDescTy || !targetElementTy ||
        !LLVM::isCompatibleType(sourceElementTy) ||
        !LLVM::isCompatibleType(targetElementTy) ||
        !LLVM::isCompatibleType(targetDescTy))
      return failure();

    // Extract the offset and strides from the type.
    int64_t offset;
    SmallVector<int64_t, 4> strides;
    auto successStrides = getStridesAndOffset(inferredType, strides, offset);
    if (failed(successStrides))
      return failure();

    // Create the descriptor.
    if (!LLVM::isCompatibleType(adaptor.getOperands().front().getType()))
      return failure();
    MemRefDescriptor sourceMemRef(adaptor.getOperands().front());
    auto targetMemRef = MemRefDescriptor::undef(rewriter, loc, targetDescTy);

    // Copy the buffer pointer from the old descriptor to the new one.
    Value extracted = sourceMemRef.allocatedPtr(rewriter, loc);
    Value bitcastPtr = rewriter.create<LLVM::BitcastOp>(
        loc,
        LLVM::LLVMPointerType::get(targetElementTy,
                                   viewMemRefType.getMemorySpaceAsInt()),
        extracted);
    targetMemRef.setAllocatedPtr(rewriter, loc, bitcastPtr);

    // Copy the aligned pointer from the old descriptor to the new one.
    extracted = sourceMemRef.alignedPtr(rewriter, loc);
    bitcastPtr = rewriter.create<LLVM::BitcastOp>(
        loc,
        LLVM::LLVMPointerType::get(targetElementTy,
                                   viewMemRefType.getMemorySpaceAsInt()),
        extracted);
    targetMemRef.setAlignedPtr(rewriter, loc, bitcastPtr);

    size_t inferredShapeRank = inferredType.getRank();
    size_t resultShapeRank = viewMemRefType.getRank();

    // Extract strides needed to compute offset.
    SmallVector<Value, 4> strideValues;
    strideValues.reserve(inferredShapeRank);
    for (unsigned i = 0; i < inferredShapeRank; ++i)
      strideValues.push_back(sourceMemRef.stride(rewriter, loc, i));

    // Offset.
    auto llvmIndexType = typeConverter->convertType(rewriter.getIndexType());
    if (!ShapedType::isDynamicStrideOrOffset(offset)) {
      targetMemRef.setConstantOffset(rewriter, loc, offset);
    } else {
      Value baseOffset = sourceMemRef.offset(rewriter, loc);
      // `inferredShapeRank` may be larger than the number of offset operands
      // because of trailing semantics. In this case, the offset is guaranteed
      // to be interpreted as 0 and we can just skip the extra dimensions.
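      // The loop below folds the subview offsets into the source offset as
      //   baseOffset += offset_i * stride_i
      // for each leading dimension that provides an explicit offset.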
      for (unsigned i = 0, e = std::min(inferredShapeRank,
                                        subViewOp.getMixedOffsets().size());
           i < e; ++i) {
        Value offset =
            // TODO: need OpFoldResult ODS adaptor to clean this up.
            subViewOp.isDynamicOffset(i)
                ? adaptor.getOperands()[subViewOp.getIndexOfDynamicOffset(i)]
                : rewriter.create<LLVM::ConstantOp>(
                      loc, llvmIndexType,
                      rewriter.getI64IntegerAttr(subViewOp.getStaticOffset(i)));
        Value mul = rewriter.create<LLVM::MulOp>(loc, offset, strideValues[i]);
        baseOffset = rewriter.create<LLVM::AddOp>(loc, baseOffset, mul);
      }
      targetMemRef.setOffset(rewriter, loc, baseOffset);
    }

    // Update sizes and strides.
    SmallVector<OpFoldResult> mixedSizes = subViewOp.getMixedSizes();
    SmallVector<OpFoldResult> mixedStrides = subViewOp.getMixedStrides();
    assert(mixedSizes.size() == mixedStrides.size() &&
           "expected sizes and strides of equal length");
    llvm::SmallBitVector unusedDims = subViewOp.getDroppedDims();
    for (int i = inferredShapeRank - 1, j = resultShapeRank - 1;
         i >= 0 && j >= 0; --i) {
      if (unusedDims.test(i))
        continue;

      // `i` may exceed the number of entries in subViewOp.getMixedSizes()
      // because of trailing semantics. In this case, the size is guaranteed to
      // be interpreted as the full source dimension and the stride as 1.
      Value size, stride;
      if (static_cast<unsigned>(i) >= mixedSizes.size()) {
        // If the static size is available, use it directly. This is similar to
        // the folding of dim(constant-op) but removes the need for dim to be
        // aware of LLVM constants and for this pass to be aware of std
        // constants.
        int64_t staticSize =
            subViewOp.source().getType().cast<MemRefType>().getShape()[i];
        if (staticSize != ShapedType::kDynamicSize) {
          size = rewriter.create<LLVM::ConstantOp>(
              loc, llvmIndexType, rewriter.getI64IntegerAttr(staticSize));
        } else {
          Value pos = rewriter.create<LLVM::ConstantOp>(
              loc, llvmIndexType, rewriter.getI64IntegerAttr(i));
          Value dim =
              rewriter.create<memref::DimOp>(loc, subViewOp.source(), pos);
          auto cast = rewriter.create<UnrealizedConversionCastOp>(
              loc, llvmIndexType, dim);
          size = cast.getResult(0);
        }
        stride = rewriter.create<LLVM::ConstantOp>(
            loc, llvmIndexType, rewriter.getI64IntegerAttr(1));
      } else {
        // TODO: need OpFoldResult ODS adaptor to clean this up.
        size =
            subViewOp.isDynamicSize(i)
                ? adaptor.getOperands()[subViewOp.getIndexOfDynamicSize(i)]
                : rewriter.create<LLVM::ConstantOp>(
                      loc, llvmIndexType,
                      rewriter.getI64IntegerAttr(subViewOp.getStaticSize(i)));
        if (!ShapedType::isDynamicStrideOrOffset(strides[i])) {
          stride = rewriter.create<LLVM::ConstantOp>(
              loc, llvmIndexType, rewriter.getI64IntegerAttr(strides[i]));
        } else {
          stride =
              subViewOp.isDynamicStride(i)
                  ? adaptor.getOperands()[subViewOp.getIndexOfDynamicStride(i)]
                  : rewriter.create<LLVM::ConstantOp>(
                        loc, llvmIndexType,
                        rewriter.getI64IntegerAttr(
                            subViewOp.getStaticStride(i)));
          stride = rewriter.create<LLVM::MulOp>(loc, stride, strideValues[i]);
        }
      }
      targetMemRef.setSize(rewriter, loc, j, size);
      targetMemRef.setStride(rewriter, loc, j, stride);
      j--;
    }

    rewriter.replaceOp(subViewOp, {targetMemRef});
    return success();
  }
};

/// Conversion pattern that transforms a transpose op into:
///   1. An `llvm.mlir.undef` operation to create a memref descriptor.
///   2. Updates to the descriptor to introduce the data ptr, offset, and the
///      permuted sizes and strides.
/// The transpose op is replaced by the descriptor.
class TransposeOpLowering : public ConvertOpToLLVMPattern<memref::TransposeOp> {
public:
  using ConvertOpToLLVMPattern<memref::TransposeOp>::ConvertOpToLLVMPattern;

  LogicalResult
  matchAndRewrite(memref::TransposeOp transposeOp, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    auto loc = transposeOp.getLoc();
    MemRefDescriptor viewMemRef(adaptor.in());

    // No permutation, early exit.
    if (transposeOp.permutation().isIdentity())
      return rewriter.replaceOp(transposeOp, {viewMemRef}), success();

    auto targetMemRef = MemRefDescriptor::undef(
        rewriter, loc, typeConverter->convertType(transposeOp.getShapedType()));

    // Copy the base and aligned pointers from the old descriptor to the new
    // one.
    targetMemRef.setAllocatedPtr(rewriter, loc,
                                 viewMemRef.allocatedPtr(rewriter, loc));
    targetMemRef.setAlignedPtr(rewriter, loc,
                               viewMemRef.alignedPtr(rewriter, loc));

    // Copy the offset from the old descriptor to the new one.
    targetMemRef.setOffset(rewriter, loc, viewMemRef.offset(rewriter, loc));

    // Iterate over the dimensions and apply size/stride permutation.
    for (const auto &en :
         llvm::enumerate(transposeOp.permutation().getResults())) {
      int sourcePos = en.index();
      int targetPos = en.value().cast<AffineDimExpr>().getPosition();
      targetMemRef.setSize(rewriter, loc, targetPos,
                           viewMemRef.size(rewriter, loc, sourcePos));
      targetMemRef.setStride(rewriter, loc, targetPos,
                             viewMemRef.stride(rewriter, loc, sourcePos));
    }

    rewriter.replaceOp(transposeOp, {targetMemRef});
    return success();
  }
};

/// Conversion pattern that transforms a view op into:
///   1. An `llvm.mlir.undef` operation to create a memref descriptor
///   2. Updates to the descriptor to introduce the data ptr, offset, size
///      and stride.
/// The view op is replaced by the descriptor.
struct ViewOpLowering : public ConvertOpToLLVMPattern<memref::ViewOp> {
  using ConvertOpToLLVMPattern<memref::ViewOp>::ConvertOpToLLVMPattern;

  // Build and return the value for the idx^th shape dimension, either by
  // returning the constant shape dimension or by selecting the proper dynamic
  // size operand (counting the dynamic dimensions that precede `idx`).
  Value getSize(ConversionPatternRewriter &rewriter, Location loc,
                ArrayRef<int64_t> shape, ValueRange dynamicSizes,
                unsigned idx) const {
    assert(idx < shape.size());
    if (!ShapedType::isDynamic(shape[idx]))
      return createIndexConstant(rewriter, loc, shape[idx]);
    // Count the number of dynamic dims in range [0, idx)
    unsigned nDynamic = llvm::count_if(shape.take_front(idx), [](int64_t v) {
      return ShapedType::isDynamic(v);
    });
    return dynamicSizes[nDynamic];
  }

  // Build and return the idx^th stride, either by returning the constant stride
  // or by computing the dynamic stride from the current `runningStride` and
  // `nextSize`. The caller should keep a running stride and update it with the
  // result returned by this function.
  Value getStride(ConversionPatternRewriter &rewriter, Location loc,
                  ArrayRef<int64_t> strides, Value nextSize,
                  Value runningStride, unsigned idx) const {
    assert(idx < strides.size());
    if (!ShapedType::isDynamicStrideOrOffset(strides[idx]))
      return createIndexConstant(rewriter, loc, strides[idx]);
    if (nextSize)
      return runningStride
                 ? rewriter.create<LLVM::MulOp>(loc, runningStride, nextSize)
                 : nextSize;
    assert(!runningStride);
    return createIndexConstant(rewriter, loc, 1);
  }

  LogicalResult
  matchAndRewrite(memref::ViewOp viewOp, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    auto loc = viewOp.getLoc();

    auto viewMemRefType = viewOp.getType();
    auto targetElementTy =
        typeConverter->convertType(viewMemRefType.getElementType());
    auto targetDescTy = typeConverter->convertType(viewMemRefType);
    if (!targetDescTy || !targetElementTy ||
        !LLVM::isCompatibleType(targetElementTy) ||
        !LLVM::isCompatibleType(targetDescTy))
      return viewOp.emitWarning("Target descriptor type not converted to LLVM"),
             failure();

    int64_t offset;
    SmallVector<int64_t, 4> strides;
    auto successStrides = getStridesAndOffset(viewMemRefType, strides, offset);
    if (failed(successStrides))
      return viewOp.emitWarning("cannot cast to non-strided shape"), failure();
    assert(offset == 0 && "expected offset to be 0");

    // Create the descriptor.
    MemRefDescriptor sourceMemRef(adaptor.source());
    auto targetMemRef = MemRefDescriptor::undef(rewriter, loc, targetDescTy);

    // Field 1: Copy the allocated pointer, used for malloc/free.
    Value allocatedPtr = sourceMemRef.allocatedPtr(rewriter, loc);
    auto srcMemRefType = viewOp.source().getType().cast<MemRefType>();
    Value bitcastPtr = rewriter.create<LLVM::BitcastOp>(
        loc,
        LLVM::LLVMPointerType::get(targetElementTy,
                                   srcMemRefType.getMemorySpaceAsInt()),
        allocatedPtr);
    targetMemRef.setAllocatedPtr(rewriter, loc, bitcastPtr);

    // Field 2: Copy the actual aligned pointer to payload.
    Value alignedPtr = sourceMemRef.alignedPtr(rewriter, loc);
    alignedPtr = rewriter.create<LLVM::GEPOp>(loc, alignedPtr.getType(),
                                              alignedPtr, adaptor.byte_shift());
    bitcastPtr = rewriter.create<LLVM::BitcastOp>(
        loc,
        LLVM::LLVMPointerType::get(targetElementTy,
                                   srcMemRefType.getMemorySpaceAsInt()),
        alignedPtr);
    targetMemRef.setAlignedPtr(rewriter, loc, bitcastPtr);

    // Field 3: The offset in the resulting type must be 0. This is because of
    // the type change: an offset on srcType* may not be expressible as an
    // offset on dstType*.
    targetMemRef.setOffset(rewriter, loc,
                           createIndexConstant(rewriter, loc, offset));

    // Early exit for 0-D corner case.
    if (viewMemRefType.getRank() == 0)
      return rewriter.replaceOp(viewOp, {targetMemRef}), success();

    // Fields 4 and 5: Update sizes and strides.
    if (strides.back() != 1)
      return viewOp.emitWarning("cannot cast to non-contiguous shape"),
             failure();
    Value stride = nullptr, nextSize = nullptr;
    for (int i = viewMemRefType.getRank() - 1; i >= 0; --i) {
      // Update size.
      Value size =
          getSize(rewriter, loc, viewMemRefType.getShape(), adaptor.sizes(), i);
      targetMemRef.setSize(rewriter, loc, i, size);
      // Update stride.
      stride = getStride(rewriter, loc, strides, nextSize, stride, i);
      targetMemRef.setStride(rewriter, loc, i, stride);
      nextSize = size;
    }

    rewriter.replaceOp(viewOp, {targetMemRef});
    return success();
  }
};

//===----------------------------------------------------------------------===//
// AtomicRMWOpLowering
//===----------------------------------------------------------------------===//

/// Try to match the kind of a memref.atomic_rmw to determine whether to use a
/// lowering to llvm.atomicrmw or fall back to llvm.cmpxchg.
static Optional<LLVM::AtomicBinOp>
matchSimpleAtomicOp(memref::AtomicRMWOp atomicOp) {
  switch (atomicOp.kind()) {
  case arith::AtomicRMWKind::addf:
    return LLVM::AtomicBinOp::fadd;
  case arith::AtomicRMWKind::addi:
    return LLVM::AtomicBinOp::add;
  case arith::AtomicRMWKind::assign:
    return LLVM::AtomicBinOp::xchg;
  case arith::AtomicRMWKind::maxs:
    return LLVM::AtomicBinOp::max;
  case arith::AtomicRMWKind::maxu:
    return LLVM::AtomicBinOp::umax;
  case arith::AtomicRMWKind::mins:
    return LLVM::AtomicBinOp::min;
  case arith::AtomicRMWKind::minu:
    return LLVM::AtomicBinOp::umin;
  case arith::AtomicRMWKind::ori:
    return LLVM::AtomicBinOp::_or;
  case arith::AtomicRMWKind::andi:
    return LLVM::AtomicBinOp::_and;
  default:
    return llvm::None;
  }
  llvm_unreachable("Invalid AtomicRMWKind");
}

struct AtomicRMWOpLowering : public LoadStoreOpLowering<memref::AtomicRMWOp> {
  using Base::Base;

  LogicalResult
  matchAndRewrite(memref::AtomicRMWOp atomicOp, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    if (failed(match(atomicOp)))
      return failure();
    auto maybeKind = matchSimpleAtomicOp(atomicOp);
    if (!maybeKind)
      return failure();
    auto resultType = adaptor.value().getType();
    auto memRefType = atomicOp.getMemRefType();
    auto dataPtr =
        getStridedElementPtr(atomicOp.getLoc(), memRefType, adaptor.memref(),
                             adaptor.indices(), rewriter);
    rewriter.replaceOpWithNewOp<LLVM::AtomicRMWOp>(
        atomicOp, resultType, *maybeKind, dataPtr, adaptor.value(),
        LLVM::AtomicOrdering::acq_rel);
    return success();
  }
};

} // namespace

void mlir::populateMemRefToLLVMConversionPatterns(LLVMTypeConverter &converter,
                                                  RewritePatternSet &patterns) {
  // clang-format off
  patterns.add<
      AllocaOpLowering,
      AllocaScopeOpLowering,
      AtomicRMWOpLowering,
      AssumeAlignmentOpLowering,
      DimOpLowering,
      GenericAtomicRMWOpLowering,
      GlobalMemrefOpLowering,
      GetGlobalMemrefOpLowering,
      LoadOpLowering,
      MemRefCastOpLowering,
      MemRefCopyOpLowering,
      MemRefReinterpretCastOpLowering,
      MemRefReshapeOpLowering,
      PrefetchOpLowering,
      RankOpLowering,
      ReassociatingReshapeOpConversion<memref::ExpandShapeOp>,
      ReassociatingReshapeOpConversion<memref::CollapseShapeOp>,
      StoreOpLowering,
      SubViewOpLowering,
      TransposeOpLowering,
      ViewOpLowering>(converter);
  // clang-format on
  auto allocLowering = converter.getOptions().allocLowering;
  if (allocLowering == LowerToLLVMOptions::AllocLowering::AlignedAlloc)
    patterns.add<AlignedAllocOpLowering, DeallocOpLowering>(converter);
  else if (allocLowering == LowerToLLVMOptions::AllocLowering::Malloc)
    patterns.add<AllocOpLowering, DeallocOpLowering>(converter);
}

namespace {
struct MemRefToLLVMPass : public ConvertMemRefToLLVMBase<MemRefToLLVMPass> {
  MemRefToLLVMPass() = default;

  void runOnOperation() override {
    Operation *op = getOperation();
    const auto &dataLayoutAnalysis = getAnalysis<DataLayoutAnalysis>();
    LowerToLLVMOptions options(&getContext(),
                               dataLayoutAnalysis.getAtOrAbove(op));
    options.allocLowering =
        (useAlignedAlloc ? LowerToLLVMOptions::AllocLowering::AlignedAlloc
                         : LowerToLLVMOptions::AllocLowering::Malloc);
    if (indexBitwidth != kDeriveIndexBitwidthFromDataLayout)
      options.overrideIndexBitwidth(indexBitwidth);

    LLVMTypeConverter typeConverter(&getContext(), options,
                                    &dataLayoutAnalysis);
    RewritePatternSet patterns(&getContext());
    populateMemRefToLLVMConversionPatterns(typeConverter, patterns);
    LLVMConversionTarget target(getContext());
    target.addLegalOp<FuncOp>();
    if (failed(applyPartialConversion(op, target, std::move(patterns))))
      signalPassFailure();
  }
};
} // namespace

std::unique_ptr<Pass> mlir::createMemRefToLLVMPass() {
  return std::make_unique<MemRefToLLVMPass>();
}