1 //===- MemRefToLLVM.cpp - MemRef to LLVM dialect conversion ---------------===// 2 // 3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. 4 // See https://llvm.org/LICENSE.txt for license information. 5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception 6 // 7 //===----------------------------------------------------------------------===// 8 9 #include "mlir/Conversion/MemRefToLLVM/MemRefToLLVM.h" 10 #include "../PassDetail.h" 11 #include "mlir/Analysis/DataLayoutAnalysis.h" 12 #include "mlir/Conversion/LLVMCommon/ConversionTarget.h" 13 #include "mlir/Conversion/LLVMCommon/Pattern.h" 14 #include "mlir/Conversion/LLVMCommon/TypeConverter.h" 15 #include "mlir/Conversion/MemRefToLLVM/AllocLikeConversion.h" 16 #include "mlir/Dialect/Func/IR/FuncOps.h" 17 #include "mlir/Dialect/LLVMIR/FunctionCallUtils.h" 18 #include "mlir/Dialect/LLVMIR/LLVMDialect.h" 19 #include "mlir/Dialect/MemRef/IR/MemRef.h" 20 #include "mlir/IR/AffineMap.h" 21 #include "mlir/IR/BlockAndValueMapping.h" 22 #include "llvm/ADT/SmallBitVector.h" 23 24 using namespace mlir; 25 26 namespace { 27 28 struct AllocOpLowering : public AllocLikeOpLLVMLowering { 29 AllocOpLowering(LLVMTypeConverter &converter) 30 : AllocLikeOpLLVMLowering(memref::AllocOp::getOperationName(), 31 converter) {} 32 33 std::tuple<Value, Value> allocateBuffer(ConversionPatternRewriter &rewriter, 34 Location loc, Value sizeBytes, 35 Operation *op) const override { 36 // Heap allocations. 37 memref::AllocOp allocOp = cast<memref::AllocOp>(op); 38 MemRefType memRefType = allocOp.getType(); 39 40 Value alignment; 41 if (auto alignmentAttr = allocOp.alignment()) { 42 alignment = createIndexConstant(rewriter, loc, *alignmentAttr); 43 } else if (!memRefType.getElementType().isSignlessIntOrIndexOrFloat()) { 44 // In the case where no alignment is specified, we may want to override 45 // `malloc's` behavior. `malloc` typically aligns at the size of the 46 // biggest scalar on a target HW. For non-scalars, use the natural 47 // alignment of the LLVM type given by the LLVM DataLayout. 48 alignment = getSizeInBytes(loc, memRefType.getElementType(), rewriter); 49 } 50 51 if (alignment) { 52 // Adjust the allocation size to consider alignment. 53 sizeBytes = rewriter.create<LLVM::AddOp>(loc, sizeBytes, alignment); 54 } 55 56 // Allocate the underlying buffer and store a pointer to it in the MemRef 57 // descriptor. 58 Type elementPtrType = this->getElementPtrType(memRefType); 59 auto allocFuncOp = LLVM::lookupOrCreateMallocFn( 60 allocOp->getParentOfType<ModuleOp>(), getIndexType()); 61 auto results = createLLVMCall(rewriter, loc, allocFuncOp, {sizeBytes}, 62 getVoidPtrType()); 63 Value allocatedPtr = 64 rewriter.create<LLVM::BitcastOp>(loc, elementPtrType, results[0]); 65 66 Value alignedPtr = allocatedPtr; 67 if (alignment) { 68 // Compute the aligned type pointer. 69 Value allocatedInt = 70 rewriter.create<LLVM::PtrToIntOp>(loc, getIndexType(), allocatedPtr); 71 Value alignmentInt = 72 createAligned(rewriter, loc, allocatedInt, alignment); 73 alignedPtr = 74 rewriter.create<LLVM::IntToPtrOp>(loc, elementPtrType, alignmentInt); 75 } 76 77 return std::make_tuple(allocatedPtr, alignedPtr); 78 } 79 }; 80 81 struct AlignedAllocOpLowering : public AllocLikeOpLLVMLowering { 82 AlignedAllocOpLowering(LLVMTypeConverter &converter) 83 : AllocLikeOpLLVMLowering(memref::AllocOp::getOperationName(), 84 converter) {} 85 86 /// Returns the memref's element size in bytes using the data layout active at 87 /// `op`. 
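/// For example, for memref<4x?xf32> this returns the byte size of f32 (4 on
/// common targets; the exact value comes from the active DataLayout), while a
/// memref-of-memref element yields the size of its descriptor instead.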
88 // TODO: there are other places where this is used. Expose publicly?
89 unsigned getMemRefEltSizeInBytes(MemRefType memRefType, Operation *op) const {
90 const DataLayout *layout = &defaultLayout;
91 if (const DataLayoutAnalysis *analysis =
92 getTypeConverter()->getDataLayoutAnalysis()) {
93 layout = &analysis->getAbove(op);
94 }
95 Type elementType = memRefType.getElementType();
96 if (auto memRefElementType = elementType.dyn_cast<MemRefType>())
97 return getTypeConverter()->getMemRefDescriptorSize(memRefElementType,
98 *layout);
99 if (auto memRefElementType = elementType.dyn_cast<UnrankedMemRefType>())
100 return getTypeConverter()->getUnrankedMemRefDescriptorSize(
101 memRefElementType, *layout);
102 return layout->getTypeSize(elementType);
103 }
104
105 /// Returns true if the memref size in bytes is known to be a multiple of
106 /// `factor`, assuming the data layout active at `op`.
107 bool isMemRefSizeMultipleOf(MemRefType type, uint64_t factor,
108 Operation *op) const {
109 uint64_t sizeDivisor = getMemRefEltSizeInBytes(type, op);
110 for (unsigned i = 0, e = type.getRank(); i < e; i++) {
111 if (ShapedType::isDynamic(type.getDimSize(i)))
112 continue;
113 sizeDivisor = sizeDivisor * type.getDimSize(i);
114 }
115 return sizeDivisor % factor == 0;
116 }
117
118 /// Returns the alignment to be used for the allocation call itself.
119 /// aligned_alloc requires the alignment to be a power of two and the
120 /// allocation size to be a multiple of that alignment.
121 int64_t getAllocationAlignment(memref::AllocOp allocOp) const {
122 if (Optional<uint64_t> alignment = allocOp.alignment())
123 return *alignment;
124
125 // Whenever we don't have an alignment set, we use an alignment consistent
126 // with the element type; since the alignment has to be a power of two, we
127 // bump the element size to the next power of two if it isn't one already.
128 auto eltSizeBytes = getMemRefEltSizeInBytes(allocOp.getType(), allocOp);
129 return std::max(kMinAlignedAllocAlignment,
130 llvm::PowerOf2Ceil(eltSizeBytes));
131 }
132
133 std::tuple<Value, Value> allocateBuffer(ConversionPatternRewriter &rewriter,
134 Location loc, Value sizeBytes,
135 Operation *op) const override {
136 // Heap allocations.
137 memref::AllocOp allocOp = cast<memref::AllocOp>(op);
138 MemRefType memRefType = allocOp.getType();
139 int64_t alignment = getAllocationAlignment(allocOp);
140 Value allocAlignment = createIndexConstant(rewriter, loc, alignment);
141
142 // aligned_alloc requires size to be a multiple of alignment; we will pad
143 // the size to the next multiple if necessary.
144 if (!isMemRefSizeMultipleOf(memRefType, alignment, op))
145 sizeBytes = createAligned(rewriter, loc, sizeBytes, allocAlignment);
146
147 Type elementPtrType = this->getElementPtrType(memRefType);
148 auto allocFuncOp = LLVM::lookupOrCreateAlignedAllocFn(
149 allocOp->getParentOfType<ModuleOp>(), getIndexType());
150 auto results =
151 createLLVMCall(rewriter, loc, allocFuncOp, {allocAlignment, sizeBytes},
152 getVoidPtrType());
153 Value allocatedPtr =
154 rewriter.create<LLVM::BitcastOp>(loc, elementPtrType, results[0]);
155
156 return std::make_tuple(allocatedPtr, allocatedPtr);
157 }
158
159 /// The minimum alignment to use with aligned_alloc (has to be a power of 2).
160 static constexpr uint64_t kMinAlignedAllocAlignment = 16UL;
161
162 /// Default layout to use in absence of the corresponding analysis.
163 DataLayout defaultLayout;
164 };
165
166 // Out-of-line definition, required until C++17.
167 constexpr uint64_t AlignedAllocOpLowering::kMinAlignedAllocAlignment;
168
169 struct AllocaOpLowering : public AllocLikeOpLLVMLowering {
170 AllocaOpLowering(LLVMTypeConverter &converter)
171 : AllocLikeOpLLVMLowering(memref::AllocaOp::getOperationName(),
172 converter) {}
173
174 /// Allocates the underlying buffer on the stack using an llvm.alloca. Since
175 /// alloca already yields a pointer with the requested alignment, the
176 /// allocated and aligned pointers are one and the same.
177 std::tuple<Value, Value> allocateBuffer(ConversionPatternRewriter &rewriter,
178 Location loc, Value sizeBytes,
179 Operation *op) const override {
180
181 // With alloca, one gets a pointer to the element type right away; this is
182 // all that is needed for stack allocations.
183 auto allocaOp = cast<memref::AllocaOp>(op);
184 auto elementPtrType = this->getElementPtrType(allocaOp.getType());
185
186 auto allocatedElementPtr = rewriter.create<LLVM::AllocaOp>(
187 loc, elementPtrType, sizeBytes,
188 allocaOp.alignment() ? *allocaOp.alignment() : 0);
189
190 return std::make_tuple(allocatedElementPtr, allocatedElementPtr);
191 }
192 };
193
194 struct AllocaScopeOpLowering
195 : public ConvertOpToLLVMPattern<memref::AllocaScopeOp> {
196 using ConvertOpToLLVMPattern<memref::AllocaScopeOp>::ConvertOpToLLVMPattern;
197
198 LogicalResult
199 matchAndRewrite(memref::AllocaScopeOp allocaScopeOp, OpAdaptor adaptor,
200 ConversionPatternRewriter &rewriter) const override {
201 OpBuilder::InsertionGuard guard(rewriter);
202 Location loc = allocaScopeOp.getLoc();
203
204 // Split the current block before the AllocaScopeOp to create the inlining
205 // point.
206 auto *currentBlock = rewriter.getInsertionBlock();
207 auto *remainingOpsBlock =
208 rewriter.splitBlock(currentBlock, rewriter.getInsertionPoint());
209 Block *continueBlock;
210 if (allocaScopeOp.getNumResults() == 0) {
211 continueBlock = remainingOpsBlock;
212 } else {
213 continueBlock = rewriter.createBlock(
214 remainingOpsBlock, allocaScopeOp.getResultTypes(),
215 SmallVector<Location>(allocaScopeOp->getNumResults(),
216 allocaScopeOp.getLoc()));
217 rewriter.create<LLVM::BrOp>(loc, ValueRange(), remainingOpsBlock);
218 }
219
220 // Inline body region.
221 Block *beforeBody = &allocaScopeOp.bodyRegion().front();
222 Block *afterBody = &allocaScopeOp.bodyRegion().back();
223 rewriter.inlineRegionBefore(allocaScopeOp.bodyRegion(), continueBlock);
224
225 // Save the stack and then branch into the body of the region.
226 rewriter.setInsertionPointToEnd(currentBlock);
227 auto stackSaveOp =
228 rewriter.create<LLVM::StackSaveOp>(loc, getVoidPtrType());
229 rewriter.create<LLVM::BrOp>(loc, ValueRange(), beforeBody);
230
231 // Replace the alloca_scope return with a branch that jumps out of the body.
232 // The stack is restored before leaving the body region.
233 rewriter.setInsertionPointToEnd(afterBody);
234 auto returnOp =
235 cast<memref::AllocaScopeReturnOp>(afterBody->getTerminator());
236 auto branchOp = rewriter.replaceOpWithNewOp<LLVM::BrOp>(
237 returnOp, returnOp.results(), continueBlock);
238
239 // Insert the stack restore before jumping out of the body of the region.
240 rewriter.setInsertionPoint(branchOp);
241 rewriter.create<LLVM::StackRestoreOp>(loc, stackSaveOp);
242
243 // Replace the op with the values returned from the body region.
244 rewriter.replaceOp(allocaScopeOp, continueBlock->getArguments()); 245 246 return success(); 247 } 248 }; 249 250 struct AssumeAlignmentOpLowering 251 : public ConvertOpToLLVMPattern<memref::AssumeAlignmentOp> { 252 using ConvertOpToLLVMPattern< 253 memref::AssumeAlignmentOp>::ConvertOpToLLVMPattern; 254 255 LogicalResult 256 matchAndRewrite(memref::AssumeAlignmentOp op, OpAdaptor adaptor, 257 ConversionPatternRewriter &rewriter) const override { 258 Value memref = adaptor.memref(); 259 unsigned alignment = op.alignment(); 260 auto loc = op.getLoc(); 261 262 MemRefDescriptor memRefDescriptor(memref); 263 Value ptr = memRefDescriptor.alignedPtr(rewriter, memref.getLoc()); 264 265 // Emit llvm.assume(memref.alignedPtr & (alignment - 1) == 0). Notice that 266 // the asserted memref.alignedPtr isn't used anywhere else, as the real 267 // users like load/store/views always re-extract memref.alignedPtr as they 268 // get lowered. 269 // 270 // This relies on LLVM's CSE optimization (potentially after SROA), since 271 // after CSE all memref.alignedPtr instances get de-duplicated into the same 272 // pointer SSA value. 273 auto intPtrType = 274 getIntPtrType(memRefDescriptor.getElementPtrType().getAddressSpace()); 275 Value zero = createIndexAttrConstant(rewriter, loc, intPtrType, 0); 276 Value mask = 277 createIndexAttrConstant(rewriter, loc, intPtrType, alignment - 1); 278 Value ptrValue = rewriter.create<LLVM::PtrToIntOp>(loc, intPtrType, ptr); 279 rewriter.create<LLVM::AssumeOp>( 280 loc, rewriter.create<LLVM::ICmpOp>( 281 loc, LLVM::ICmpPredicate::eq, 282 rewriter.create<LLVM::AndOp>(loc, ptrValue, mask), zero)); 283 284 rewriter.eraseOp(op); 285 return success(); 286 } 287 }; 288 289 // A `dealloc` is converted into a call to `free` on the underlying data buffer. 290 // The memref descriptor being an SSA value, there is no need to clean it up 291 // in any way. 292 struct DeallocOpLowering : public ConvertOpToLLVMPattern<memref::DeallocOp> { 293 using ConvertOpToLLVMPattern<memref::DeallocOp>::ConvertOpToLLVMPattern; 294 295 explicit DeallocOpLowering(LLVMTypeConverter &converter) 296 : ConvertOpToLLVMPattern<memref::DeallocOp>(converter) {} 297 298 LogicalResult 299 matchAndRewrite(memref::DeallocOp op, OpAdaptor adaptor, 300 ConversionPatternRewriter &rewriter) const override { 301 // Insert the `free` declaration if it is not already present. 302 auto freeFunc = LLVM::lookupOrCreateFreeFn(op->getParentOfType<ModuleOp>()); 303 MemRefDescriptor memref(adaptor.memref()); 304 Value casted = rewriter.create<LLVM::BitcastOp>( 305 op.getLoc(), getVoidPtrType(), 306 memref.allocatedPtr(rewriter, op.getLoc())); 307 rewriter.replaceOpWithNewOp<LLVM::CallOp>( 308 op, TypeRange(), SymbolRefAttr::get(freeFunc), casted); 309 return success(); 310 } 311 }; 312 313 // A `dim` is converted to a constant for static sizes and to an access to the 314 // size stored in the memref descriptor for dynamic sizes. 
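// For example (a sketch; %arg has type memref<4x?xf32> and %c0/%c1 are index
// constants):
//   %d0 = memref.dim %arg, %c0 : memref<4x?xf32>  // becomes a constant 4
//   %d1 = memref.dim %arg, %c1 : memref<4x?xf32>  // becomes an extractvalue
//                                                 // of the dynamic size field
//                                                 // of %arg's descriptor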
315 struct DimOpLowering : public ConvertOpToLLVMPattern<memref::DimOp> { 316 using ConvertOpToLLVMPattern<memref::DimOp>::ConvertOpToLLVMPattern; 317 318 LogicalResult 319 matchAndRewrite(memref::DimOp dimOp, OpAdaptor adaptor, 320 ConversionPatternRewriter &rewriter) const override { 321 Type operandType = dimOp.source().getType(); 322 if (operandType.isa<UnrankedMemRefType>()) { 323 rewriter.replaceOp( 324 dimOp, {extractSizeOfUnrankedMemRef( 325 operandType, dimOp, adaptor.getOperands(), rewriter)}); 326 327 return success(); 328 } 329 if (operandType.isa<MemRefType>()) { 330 rewriter.replaceOp( 331 dimOp, {extractSizeOfRankedMemRef(operandType, dimOp, 332 adaptor.getOperands(), rewriter)}); 333 return success(); 334 } 335 llvm_unreachable("expected MemRefType or UnrankedMemRefType"); 336 } 337 338 private: 339 Value extractSizeOfUnrankedMemRef(Type operandType, memref::DimOp dimOp, 340 OpAdaptor adaptor, 341 ConversionPatternRewriter &rewriter) const { 342 Location loc = dimOp.getLoc(); 343 344 auto unrankedMemRefType = operandType.cast<UnrankedMemRefType>(); 345 auto scalarMemRefType = 346 MemRefType::get({}, unrankedMemRefType.getElementType()); 347 unsigned addressSpace = unrankedMemRefType.getMemorySpaceAsInt(); 348 349 // Extract pointer to the underlying ranked descriptor and bitcast it to a 350 // memref<element_type> descriptor pointer to minimize the number of GEP 351 // operations. 352 UnrankedMemRefDescriptor unrankedDesc(adaptor.source()); 353 Value underlyingRankedDesc = unrankedDesc.memRefDescPtr(rewriter, loc); 354 Value scalarMemRefDescPtr = rewriter.create<LLVM::BitcastOp>( 355 loc, 356 LLVM::LLVMPointerType::get(typeConverter->convertType(scalarMemRefType), 357 addressSpace), 358 underlyingRankedDesc); 359 360 // Get pointer to offset field of memref<element_type> descriptor. 361 Type indexPtrTy = LLVM::LLVMPointerType::get( 362 getTypeConverter()->getIndexType(), addressSpace); 363 Value two = rewriter.create<LLVM::ConstantOp>( 364 loc, typeConverter->convertType(rewriter.getI32Type()), 365 rewriter.getI32IntegerAttr(2)); 366 Value offsetPtr = rewriter.create<LLVM::GEPOp>( 367 loc, indexPtrTy, scalarMemRefDescPtr, 368 ValueRange({createIndexConstant(rewriter, loc, 0), two})); 369 370 // The size value that we have to extract can be obtained using GEPop with 371 // `dimOp.index() + 1` index argument. 372 Value idxPlusOne = rewriter.create<LLVM::AddOp>( 373 loc, createIndexConstant(rewriter, loc, 1), adaptor.index()); 374 Value sizePtr = rewriter.create<LLVM::GEPOp>(loc, indexPtrTy, offsetPtr, 375 ValueRange({idxPlusOne})); 376 return rewriter.create<LLVM::LoadOp>(loc, sizePtr); 377 } 378 379 Optional<int64_t> getConstantDimIndex(memref::DimOp dimOp) const { 380 if (Optional<int64_t> idx = dimOp.getConstantIndex()) 381 return idx; 382 383 if (auto constantOp = dimOp.index().getDefiningOp<LLVM::ConstantOp>()) 384 return constantOp.getValue() 385 .cast<IntegerAttr>() 386 .getValue() 387 .getSExtValue(); 388 389 return llvm::None; 390 } 391 392 Value extractSizeOfRankedMemRef(Type operandType, memref::DimOp dimOp, 393 OpAdaptor adaptor, 394 ConversionPatternRewriter &rewriter) const { 395 Location loc = dimOp.getLoc(); 396 397 // Take advantage if index is constant. 398 MemRefType memRefType = operandType.cast<MemRefType>(); 399 if (Optional<int64_t> index = getConstantDimIndex(dimOp)) { 400 int64_t i = index.getValue(); 401 if (memRefType.isDynamicDim(i)) { 402 // extract dynamic size from the memref descriptor. 
403 MemRefDescriptor descriptor(adaptor.source()); 404 return descriptor.size(rewriter, loc, i); 405 } 406 // Use constant for static size. 407 int64_t dimSize = memRefType.getDimSize(i); 408 return createIndexConstant(rewriter, loc, dimSize); 409 } 410 Value index = adaptor.index(); 411 int64_t rank = memRefType.getRank(); 412 MemRefDescriptor memrefDescriptor(adaptor.source()); 413 return memrefDescriptor.size(rewriter, loc, index, rank); 414 } 415 }; 416 417 /// Common base for load and store operations on MemRefs. Restricts the match 418 /// to supported MemRef types. Provides functionality to emit code accessing a 419 /// specific element of the underlying data buffer. 420 template <typename Derived> 421 struct LoadStoreOpLowering : public ConvertOpToLLVMPattern<Derived> { 422 using ConvertOpToLLVMPattern<Derived>::ConvertOpToLLVMPattern; 423 using ConvertOpToLLVMPattern<Derived>::isConvertibleAndHasIdentityMaps; 424 using Base = LoadStoreOpLowering<Derived>; 425 426 LogicalResult match(Derived op) const override { 427 MemRefType type = op.getMemRefType(); 428 return isConvertibleAndHasIdentityMaps(type) ? success() : failure(); 429 } 430 }; 431 432 /// Wrap a llvm.cmpxchg operation in a while loop so that the operation can be 433 /// retried until it succeeds in atomically storing a new value into memory. 434 /// 435 /// +---------------------------------+ 436 /// | <code before the AtomicRMWOp> | 437 /// | <compute initial %loaded> | 438 /// | cf.br loop(%loaded) | 439 /// +---------------------------------+ 440 /// | 441 /// -------| | 442 /// | v v 443 /// | +--------------------------------+ 444 /// | | loop(%loaded): | 445 /// | | <body contents> | 446 /// | | %pair = cmpxchg | 447 /// | | %ok = %pair[0] | 448 /// | | %new = %pair[1] | 449 /// | | cf.cond_br %ok, end, loop(%new) | 450 /// | +--------------------------------+ 451 /// | | | 452 /// |----------- | 453 /// v 454 /// +--------------------------------+ 455 /// | end: | 456 /// | <code after the AtomicRMWOp> | 457 /// +--------------------------------+ 458 /// 459 struct GenericAtomicRMWOpLowering 460 : public LoadStoreOpLowering<memref::GenericAtomicRMWOp> { 461 using Base::Base; 462 463 LogicalResult 464 matchAndRewrite(memref::GenericAtomicRMWOp atomicOp, OpAdaptor adaptor, 465 ConversionPatternRewriter &rewriter) const override { 466 auto loc = atomicOp.getLoc(); 467 Type valueType = typeConverter->convertType(atomicOp.getResult().getType()); 468 469 // Split the block into initial, loop, and ending parts. 470 auto *initBlock = rewriter.getInsertionBlock(); 471 auto *loopBlock = rewriter.createBlock( 472 initBlock->getParent(), std::next(Region::iterator(initBlock)), 473 valueType, loc); 474 auto *endBlock = rewriter.createBlock( 475 loopBlock->getParent(), std::next(Region::iterator(loopBlock))); 476 477 // Operations range to be moved to `endBlock`. 478 auto opsToMoveStart = atomicOp->getIterator(); 479 auto opsToMoveEnd = initBlock->back().getIterator(); 480 481 // Compute the loaded value and branch to the loop block. 482 rewriter.setInsertionPointToEnd(initBlock); 483 auto memRefType = atomicOp.memref().getType().cast<MemRefType>(); 484 auto dataPtr = getStridedElementPtr(loc, memRefType, adaptor.memref(), 485 adaptor.indices(), rewriter); 486 Value init = rewriter.create<LLVM::LoadOp>(loc, dataPtr); 487 rewriter.create<LLVM::BrOp>(loc, init, loopBlock); 488 489 // Prepare the body of the loop block. 
490 rewriter.setInsertionPointToStart(loopBlock);
491
492 // Clone the GenericAtomicRMWOp region and extract the result.
493 auto loopArgument = loopBlock->getArgument(0);
494 BlockAndValueMapping mapping;
495 mapping.map(atomicOp.getCurrentValue(), loopArgument);
496 Block &entryBlock = atomicOp.body().front();
497 for (auto &nestedOp : entryBlock.without_terminator()) {
498 Operation *clone = rewriter.clone(nestedOp, mapping);
499 mapping.map(nestedOp.getResults(), clone->getResults());
500 }
501 Value result = mapping.lookup(entryBlock.getTerminator()->getOperand(0));
502
503 // Prepare the epilogue of the loop block: append the cmpxchg op to the end
504 // of the loop block.
505 auto successOrdering = LLVM::AtomicOrdering::acq_rel;
506 auto failureOrdering = LLVM::AtomicOrdering::monotonic;
507 auto boolType = IntegerType::get(rewriter.getContext(), 1);
508 auto pairType = LLVM::LLVMStructType::getLiteral(rewriter.getContext(),
509 {valueType, boolType});
510 auto cmpxchg = rewriter.create<LLVM::AtomicCmpXchgOp>(
511 loc, pairType, dataPtr, loopArgument, result, successOrdering,
512 failureOrdering);
513 // Extract the %new_loaded and %ok values from the pair.
514 Value newLoaded = rewriter.create<LLVM::ExtractValueOp>(
515 loc, valueType, cmpxchg, rewriter.getI64ArrayAttr({0}));
516 Value ok = rewriter.create<LLVM::ExtractValueOp>(
517 loc, boolType, cmpxchg, rewriter.getI64ArrayAttr({1}));
518
519 // Conditionally branch to the end or back to the loop depending on %ok.
520 rewriter.create<LLVM::CondBrOp>(loc, ok, endBlock, ArrayRef<Value>(),
521 loopBlock, newLoaded);
522
523 rewriter.setInsertionPointToEnd(endBlock);
524 moveOpsRange(atomicOp.getResult(), newLoaded, std::next(opsToMoveStart),
525 std::next(opsToMoveEnd), rewriter);
526
527 // The 'result' of the atomic_rmw op is the newly loaded value.
528 rewriter.replaceOp(atomicOp, {newLoaded});
529
530 return success();
531 }
532
533 private:
534 // Clones a segment of ops [start, end) and erases the originals.
535 void moveOpsRange(ValueRange oldResult, ValueRange newResult,
536 Block::iterator start, Block::iterator end,
537 ConversionPatternRewriter &rewriter) const {
538 BlockAndValueMapping mapping;
539 mapping.map(oldResult, newResult);
540 SmallVector<Operation *, 2> opsToErase;
541 for (auto it = start; it != end; ++it) {
542 rewriter.clone(*it, mapping);
543 opsToErase.push_back(&*it);
544 }
545 for (auto *it : opsToErase)
546 rewriter.eraseOp(it);
547 }
548 };
549
550 /// Returns the LLVM type of the global variable given the memref type `type`.
551 static Type convertGlobalMemrefTypeToLLVM(MemRefType type,
552 LLVMTypeConverter &typeConverter) {
553 // The LLVM type for a global memref will be a multi-dimensional array. For
554 // declarations or uninitialized global memrefs, we can potentially flatten
555 // this to a 1D array. However, for memref.global ops with an initial value,
556 // we do not intend to flatten the ElementsAttribute when lowering to the
557 // LLVM dialect, so the LLVM type needs to be a multi-dimensional array.
558 Type elementType = typeConverter.convertType(type.getElementType());
559 Type arrayTy = elementType;
560 // The shape has the outermost dimension at index 0, so walk it backwards.
561 for (int64_t dim : llvm::reverse(type.getShape()))
562 arrayTy = LLVM::LLVMArrayType::get(arrayTy, dim);
563 return arrayTy;
564 }
565
566 /// GlobalMemrefOp is lowered to an LLVM global variable.
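/// For example, roughly (a sketch; the exact LLVM types depend on the type
/// converter and the data layout):
///   memref.global "private" constant @cst : memref<2xf32> = dense<[1.0, 2.0]>
/// becomes
///   llvm.mlir.global private constant @cst(dense<[1.0, 2.0]> : tensor<2xf32>)
///       : !llvm.array<2 x f32>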
567 struct GlobalMemrefOpLowering 568 : public ConvertOpToLLVMPattern<memref::GlobalOp> { 569 using ConvertOpToLLVMPattern<memref::GlobalOp>::ConvertOpToLLVMPattern; 570 571 LogicalResult 572 matchAndRewrite(memref::GlobalOp global, OpAdaptor adaptor, 573 ConversionPatternRewriter &rewriter) const override { 574 MemRefType type = global.type(); 575 if (!isConvertibleAndHasIdentityMaps(type)) 576 return failure(); 577 578 Type arrayTy = convertGlobalMemrefTypeToLLVM(type, *getTypeConverter()); 579 580 LLVM::Linkage linkage = 581 global.isPublic() ? LLVM::Linkage::External : LLVM::Linkage::Private; 582 583 Attribute initialValue = nullptr; 584 if (!global.isExternal() && !global.isUninitialized()) { 585 auto elementsAttr = global.initial_value()->cast<ElementsAttr>(); 586 initialValue = elementsAttr; 587 588 // For scalar memrefs, the global variable created is of the element type, 589 // so unpack the elements attribute to extract the value. 590 if (type.getRank() == 0) 591 initialValue = elementsAttr.getSplatValue<Attribute>(); 592 } 593 594 uint64_t alignment = global.alignment().getValueOr(0); 595 596 auto newGlobal = rewriter.replaceOpWithNewOp<LLVM::GlobalOp>( 597 global, arrayTy, global.constant(), linkage, global.sym_name(), 598 initialValue, alignment, type.getMemorySpaceAsInt()); 599 if (!global.isExternal() && global.isUninitialized()) { 600 Block *blk = new Block(); 601 newGlobal.getInitializerRegion().push_back(blk); 602 rewriter.setInsertionPointToStart(blk); 603 Value undef[] = { 604 rewriter.create<LLVM::UndefOp>(global.getLoc(), arrayTy)}; 605 rewriter.create<LLVM::ReturnOp>(global.getLoc(), undef); 606 } 607 return success(); 608 } 609 }; 610 611 /// GetGlobalMemrefOp is lowered into a Memref descriptor with the pointer to 612 /// the first element stashed into the descriptor. This reuses 613 /// `AllocLikeOpLowering` to reuse the Memref descriptor construction. 614 struct GetGlobalMemrefOpLowering : public AllocLikeOpLLVMLowering { 615 GetGlobalMemrefOpLowering(LLVMTypeConverter &converter) 616 : AllocLikeOpLLVMLowering(memref::GetGlobalOp::getOperationName(), 617 converter) {} 618 619 /// Buffer "allocation" for memref.get_global op is getting the address of 620 /// the global variable referenced. 621 std::tuple<Value, Value> allocateBuffer(ConversionPatternRewriter &rewriter, 622 Location loc, Value sizeBytes, 623 Operation *op) const override { 624 auto getGlobalOp = cast<memref::GetGlobalOp>(op); 625 MemRefType type = getGlobalOp.result().getType().cast<MemRefType>(); 626 unsigned memSpace = type.getMemorySpaceAsInt(); 627 628 Type arrayTy = convertGlobalMemrefTypeToLLVM(type, *getTypeConverter()); 629 auto addressOf = rewriter.create<LLVM::AddressOfOp>( 630 loc, LLVM::LLVMPointerType::get(arrayTy, memSpace), getGlobalOp.name()); 631 632 // Get the address of the first element in the array by creating a GEP with 633 // the address of the GV as the base, and (rank + 1) number of 0 indices. 634 Type elementType = typeConverter->convertType(type.getElementType()); 635 Type elementPtrType = LLVM::LLVMPointerType::get(elementType, memSpace); 636 637 SmallVector<Value> operands; 638 operands.insert(operands.end(), type.getRank() + 1, 639 createIndexConstant(rewriter, loc, 0)); 640 auto gep = 641 rewriter.create<LLVM::GEPOp>(loc, elementPtrType, addressOf, operands); 642 643 // We do not expect the memref obtained using `memref.get_global` to be 644 // ever deallocated. Set the allocated pointer to be known bad value to 645 // help debug if that ever happens. 
646 auto intPtrType = getIntPtrType(memSpace);
647 Value deadBeefConst =
648 createIndexAttrConstant(rewriter, op->getLoc(), intPtrType, 0xdeadbeef);
649 auto deadBeefPtr =
650 rewriter.create<LLVM::IntToPtrOp>(loc, elementPtrType, deadBeefConst);
651
652 // The aligned pointer holds the address of the global's first element. We
653 // could potentially stash a nullptr for the allocated pointer since we do not expect any dealloc.
654 return std::make_tuple(deadBeefPtr, gep);
655 }
656 };
657
658 // A load operation is lowered to obtaining a pointer to the indexed element
659 // and loading it.
660 struct LoadOpLowering : public LoadStoreOpLowering<memref::LoadOp> {
661 using Base::Base;
662
663 LogicalResult
664 matchAndRewrite(memref::LoadOp loadOp, OpAdaptor adaptor,
665 ConversionPatternRewriter &rewriter) const override {
666 auto type = loadOp.getMemRefType();
667
668 Value dataPtr = getStridedElementPtr(
669 loadOp.getLoc(), type, adaptor.memref(), adaptor.indices(), rewriter);
670 rewriter.replaceOpWithNewOp<LLVM::LoadOp>(loadOp, dataPtr);
671 return success();
672 }
673 };
674
675 // A store operation is lowered to obtaining a pointer to the indexed element
676 // and storing the given value to it.
677 struct StoreOpLowering : public LoadStoreOpLowering<memref::StoreOp> {
678 using Base::Base;
679
680 LogicalResult
681 matchAndRewrite(memref::StoreOp op, OpAdaptor adaptor,
682 ConversionPatternRewriter &rewriter) const override {
683 auto type = op.getMemRefType();
684
685 Value dataPtr = getStridedElementPtr(op.getLoc(), type, adaptor.memref(),
686 adaptor.indices(), rewriter);
687 rewriter.replaceOpWithNewOp<LLVM::StoreOp>(op, adaptor.value(), dataPtr);
688 return success();
689 }
690 };
691
692 // The prefetch operation is lowered in a way similar to the load operation,
693 // except that the llvm.prefetch operation is used for the replacement.
694 struct PrefetchOpLowering : public LoadStoreOpLowering<memref::PrefetchOp> {
695 using Base::Base;
696
697 LogicalResult
698 matchAndRewrite(memref::PrefetchOp prefetchOp, OpAdaptor adaptor,
699 ConversionPatternRewriter &rewriter) const override {
700 auto type = prefetchOp.getMemRefType();
701 auto loc = prefetchOp.getLoc();
702
703 Value dataPtr = getStridedElementPtr(loc, type, adaptor.memref(),
704 adaptor.indices(), rewriter);
705
706 // Replace with llvm.prefetch.
707 auto llvmI32Type = typeConverter->convertType(rewriter.getIntegerType(32)); 708 auto isWrite = rewriter.create<LLVM::ConstantOp>( 709 loc, llvmI32Type, rewriter.getI32IntegerAttr(prefetchOp.isWrite())); 710 auto localityHint = rewriter.create<LLVM::ConstantOp>( 711 loc, llvmI32Type, 712 rewriter.getI32IntegerAttr(prefetchOp.localityHint())); 713 auto isData = rewriter.create<LLVM::ConstantOp>( 714 loc, llvmI32Type, rewriter.getI32IntegerAttr(prefetchOp.isDataCache())); 715 716 rewriter.replaceOpWithNewOp<LLVM::Prefetch>(prefetchOp, dataPtr, isWrite, 717 localityHint, isData); 718 return success(); 719 } 720 }; 721 722 struct RankOpLowering : public ConvertOpToLLVMPattern<memref::RankOp> { 723 using ConvertOpToLLVMPattern<memref::RankOp>::ConvertOpToLLVMPattern; 724 725 LogicalResult 726 matchAndRewrite(memref::RankOp op, OpAdaptor adaptor, 727 ConversionPatternRewriter &rewriter) const override { 728 Location loc = op.getLoc(); 729 Type operandType = op.memref().getType(); 730 if (auto unrankedMemRefType = operandType.dyn_cast<UnrankedMemRefType>()) { 731 UnrankedMemRefDescriptor desc(adaptor.memref()); 732 rewriter.replaceOp(op, {desc.rank(rewriter, loc)}); 733 return success(); 734 } 735 if (auto rankedMemRefType = operandType.dyn_cast<MemRefType>()) { 736 rewriter.replaceOp( 737 op, {createIndexConstant(rewriter, loc, rankedMemRefType.getRank())}); 738 return success(); 739 } 740 return failure(); 741 } 742 }; 743 744 struct MemRefCastOpLowering : public ConvertOpToLLVMPattern<memref::CastOp> { 745 using ConvertOpToLLVMPattern<memref::CastOp>::ConvertOpToLLVMPattern; 746 747 LogicalResult match(memref::CastOp memRefCastOp) const override { 748 Type srcType = memRefCastOp.getOperand().getType(); 749 Type dstType = memRefCastOp.getType(); 750 751 // memref::CastOp reduce to bitcast in the ranked MemRef case and can be 752 // used for type erasure. For now they must preserve underlying element type 753 // and require source and result type to have the same rank. Therefore, 754 // perform a sanity check that the underlying structs are the same. Once op 755 // semantics are relaxed we can revisit. 756 if (srcType.isa<MemRefType>() && dstType.isa<MemRefType>()) 757 return success(typeConverter->convertType(srcType) == 758 typeConverter->convertType(dstType)); 759 760 // At least one of the operands is unranked type 761 assert(srcType.isa<UnrankedMemRefType>() || 762 dstType.isa<UnrankedMemRefType>()); 763 764 // Unranked to unranked cast is disallowed 765 return !(srcType.isa<UnrankedMemRefType>() && 766 dstType.isa<UnrankedMemRefType>()) 767 ? success() 768 : failure(); 769 } 770 771 void rewrite(memref::CastOp memRefCastOp, OpAdaptor adaptor, 772 ConversionPatternRewriter &rewriter) const override { 773 auto srcType = memRefCastOp.getOperand().getType(); 774 auto dstType = memRefCastOp.getType(); 775 auto targetStructType = typeConverter->convertType(memRefCastOp.getType()); 776 auto loc = memRefCastOp.getLoc(); 777 778 // For ranked/ranked case, just keep the original descriptor. 
779 if (srcType.isa<MemRefType>() && dstType.isa<MemRefType>())
780 return rewriter.replaceOp(memRefCastOp, {adaptor.source()});
781
782 if (srcType.isa<MemRefType>() && dstType.isa<UnrankedMemRefType>()) {
783 // Casting a ranked memref to an unranked memref type:
784 // Set the rank in the destination from the memref type.
785 // Allocate space on the stack and copy the src memref descriptor.
786 // Set the ptr in the destination to the stack space.
787 auto srcMemRefType = srcType.cast<MemRefType>();
788 int64_t rank = srcMemRefType.getRank();
789 // ptr = AllocaOp sizeof(MemRefDescriptor)
790 auto ptr = getTypeConverter()->promoteOneMemRefDescriptor(
791 loc, adaptor.source(), rewriter);
792 // voidptr = BitCastOp srcType* to void*
793 auto voidPtr =
794 rewriter.create<LLVM::BitcastOp>(loc, getVoidPtrType(), ptr)
795 .getResult();
796 // rank = ConstantOp srcRank
797 auto rankVal = rewriter.create<LLVM::ConstantOp>(
798 loc, getIndexType(), rewriter.getIndexAttr(rank));
799 // undef = UndefOp
800 UnrankedMemRefDescriptor memRefDesc =
801 UnrankedMemRefDescriptor::undef(rewriter, loc, targetStructType);
802 // d1 = InsertValueOp undef, rank, 0
803 memRefDesc.setRank(rewriter, loc, rankVal);
804 // d2 = InsertValueOp d1, voidptr, 1
805 memRefDesc.setMemRefDescPtr(rewriter, loc, voidPtr);
806 rewriter.replaceOp(memRefCastOp, (Value)memRefDesc);
807
808 } else if (srcType.isa<UnrankedMemRefType>() && dstType.isa<MemRefType>()) {
809 // Casting from an unranked type to a ranked one. The operation is assumed
810 // to be doing a correct cast; if the destination type mismatches the
811 // runtime type of the unranked operand, the behavior is undefined.
812 UnrankedMemRefDescriptor memRefDesc(adaptor.source());
813 // ptr = ExtractValueOp src, 1
814 auto ptr = memRefDesc.memRefDescPtr(rewriter, loc);
815 // castPtr = BitCastOp i8* to structTy*
816 auto castPtr =
817 rewriter
818 .create<LLVM::BitcastOp>(
819 loc, LLVM::LLVMPointerType::get(targetStructType), ptr)
820 .getResult();
821 // struct = LoadOp castPtr
822 auto loadOp = rewriter.create<LLVM::LoadOp>(loc, castPtr);
823 rewriter.replaceOp(memRefCastOp, loadOp.getResult());
824 } else {
825 llvm_unreachable("Unsupported unranked memref to unranked memref cast");
826 }
827 }
828 };
829
830 /// Pattern to lower a `memref.copy` to LLVM.
831 ///
832 /// For memrefs with identity layouts, the copy is lowered to the llvm
833 /// `memcpy` intrinsic. For non-identity layouts, the copy is lowered to a call
834 /// to the generic `MemrefCopyFn`.
835 struct MemRefCopyOpLowering : public ConvertOpToLLVMPattern<memref::CopyOp> {
836 using ConvertOpToLLVMPattern<memref::CopyOp>::ConvertOpToLLVMPattern;
837
838 LogicalResult
839 lowerToMemCopyIntrinsic(memref::CopyOp op, OpAdaptor adaptor,
840 ConversionPatternRewriter &rewriter) const {
841 auto loc = op.getLoc();
842 auto srcType = op.source().getType().dyn_cast<MemRefType>();
843
844 MemRefDescriptor srcDesc(adaptor.source());
845
846 // Compute the number of elements.
847 Value numElements = rewriter.create<LLVM::ConstantOp>(
848 loc, getIndexType(), rewriter.getIndexAttr(1));
849 for (int pos = 0; pos < srcType.getRank(); ++pos) {
850 auto size = srcDesc.size(rewriter, loc, pos);
851 numElements = rewriter.create<LLVM::MulOp>(loc, numElements, size);
852 }
853
854 // Get the element size.
855 auto sizeInBytes = getSizeInBytes(loc, srcType.getElementType(), rewriter);
856 // Compute the total size in bytes.
857 Value totalSize = 858 rewriter.create<LLVM::MulOp>(loc, numElements, sizeInBytes); 859 860 Value srcBasePtr = srcDesc.alignedPtr(rewriter, loc); 861 Value srcOffset = srcDesc.offset(rewriter, loc); 862 Value srcPtr = rewriter.create<LLVM::GEPOp>(loc, srcBasePtr.getType(), 863 srcBasePtr, srcOffset); 864 MemRefDescriptor targetDesc(adaptor.target()); 865 Value targetBasePtr = targetDesc.alignedPtr(rewriter, loc); 866 Value targetOffset = targetDesc.offset(rewriter, loc); 867 Value targetPtr = rewriter.create<LLVM::GEPOp>(loc, targetBasePtr.getType(), 868 targetBasePtr, targetOffset); 869 Value isVolatile = rewriter.create<LLVM::ConstantOp>( 870 loc, typeConverter->convertType(rewriter.getI1Type()), 871 rewriter.getBoolAttr(false)); 872 rewriter.create<LLVM::MemcpyOp>(loc, targetPtr, srcPtr, totalSize, 873 isVolatile); 874 rewriter.eraseOp(op); 875 876 return success(); 877 } 878 879 LogicalResult 880 lowerToMemCopyFunctionCall(memref::CopyOp op, OpAdaptor adaptor, 881 ConversionPatternRewriter &rewriter) const { 882 auto loc = op.getLoc(); 883 auto srcType = op.source().getType().cast<BaseMemRefType>(); 884 auto targetType = op.target().getType().cast<BaseMemRefType>(); 885 886 // First make sure we have an unranked memref descriptor representation. 887 auto makeUnranked = [&, this](Value ranked, BaseMemRefType type) { 888 auto rank = rewriter.create<LLVM::ConstantOp>( 889 loc, getIndexType(), rewriter.getIndexAttr(type.getRank())); 890 auto *typeConverter = getTypeConverter(); 891 auto ptr = 892 typeConverter->promoteOneMemRefDescriptor(loc, ranked, rewriter); 893 auto voidPtr = 894 rewriter.create<LLVM::BitcastOp>(loc, getVoidPtrType(), ptr) 895 .getResult(); 896 auto unrankedType = 897 UnrankedMemRefType::get(type.getElementType(), type.getMemorySpace()); 898 return UnrankedMemRefDescriptor::pack(rewriter, loc, *typeConverter, 899 unrankedType, 900 ValueRange{rank, voidPtr}); 901 }; 902 903 Value unrankedSource = srcType.hasRank() 904 ? makeUnranked(adaptor.source(), srcType) 905 : adaptor.source(); 906 Value unrankedTarget = targetType.hasRank() 907 ? makeUnranked(adaptor.target(), targetType) 908 : adaptor.target(); 909 910 // Now promote the unranked descriptors to the stack. 
911 auto one = rewriter.create<LLVM::ConstantOp>(loc, getIndexType(), 912 rewriter.getIndexAttr(1)); 913 auto promote = [&](Value desc) { 914 auto ptrType = LLVM::LLVMPointerType::get(desc.getType()); 915 auto allocated = 916 rewriter.create<LLVM::AllocaOp>(loc, ptrType, ValueRange{one}); 917 rewriter.create<LLVM::StoreOp>(loc, desc, allocated); 918 return allocated; 919 }; 920 921 auto sourcePtr = promote(unrankedSource); 922 auto targetPtr = promote(unrankedTarget); 923 924 unsigned typeSize = 925 mlir::DataLayout::closest(op).getTypeSize(srcType.getElementType()); 926 auto elemSize = rewriter.create<LLVM::ConstantOp>( 927 loc, getIndexType(), rewriter.getIndexAttr(typeSize)); 928 auto copyFn = LLVM::lookupOrCreateMemRefCopyFn( 929 op->getParentOfType<ModuleOp>(), getIndexType(), sourcePtr.getType()); 930 rewriter.create<LLVM::CallOp>(loc, copyFn, 931 ValueRange{elemSize, sourcePtr, targetPtr}); 932 rewriter.eraseOp(op); 933 934 return success(); 935 } 936 937 LogicalResult 938 matchAndRewrite(memref::CopyOp op, OpAdaptor adaptor, 939 ConversionPatternRewriter &rewriter) const override { 940 auto srcType = op.source().getType().cast<BaseMemRefType>(); 941 auto targetType = op.target().getType().cast<BaseMemRefType>(); 942 943 auto isContiguousMemrefType = [](BaseMemRefType type) { 944 auto memrefType = type.dyn_cast<mlir::MemRefType>(); 945 // We can use memcpy for memrefs if they have an identity layout or are 946 // contiguous with an arbitrary offset. Ignore empty memrefs, which is a 947 // special case handled by memrefCopy. 948 return memrefType && 949 (memrefType.getLayout().isIdentity() || 950 (memrefType.hasStaticShape() && memrefType.getNumElements() > 0 && 951 isStaticShapeAndContiguousRowMajor(memrefType))); 952 }; 953 954 if (isContiguousMemrefType(srcType) && isContiguousMemrefType(targetType)) 955 return lowerToMemCopyIntrinsic(op, adaptor, rewriter); 956 957 return lowerToMemCopyFunctionCall(op, adaptor, rewriter); 958 } 959 }; 960 961 /// Extracts allocated, aligned pointers and offset from a ranked or unranked 962 /// memref type. In unranked case, the fields are extracted from the underlying 963 /// ranked descriptor. 964 static void extractPointersAndOffset(Location loc, 965 ConversionPatternRewriter &rewriter, 966 LLVMTypeConverter &typeConverter, 967 Value originalOperand, 968 Value convertedOperand, 969 Value *allocatedPtr, Value *alignedPtr, 970 Value *offset = nullptr) { 971 Type operandType = originalOperand.getType(); 972 if (operandType.isa<MemRefType>()) { 973 MemRefDescriptor desc(convertedOperand); 974 *allocatedPtr = desc.allocatedPtr(rewriter, loc); 975 *alignedPtr = desc.alignedPtr(rewriter, loc); 976 if (offset != nullptr) 977 *offset = desc.offset(rewriter, loc); 978 return; 979 } 980 981 unsigned memorySpace = 982 operandType.cast<UnrankedMemRefType>().getMemorySpaceAsInt(); 983 Type elementType = operandType.cast<UnrankedMemRefType>().getElementType(); 984 Type llvmElementType = typeConverter.convertType(elementType); 985 Type elementPtrPtrType = LLVM::LLVMPointerType::get( 986 LLVM::LLVMPointerType::get(llvmElementType, memorySpace)); 987 988 // Extract pointer to the underlying ranked memref descriptor and cast it to 989 // ElemType**. 
990 UnrankedMemRefDescriptor unrankedDesc(convertedOperand); 991 Value underlyingDescPtr = unrankedDesc.memRefDescPtr(rewriter, loc); 992 993 *allocatedPtr = UnrankedMemRefDescriptor::allocatedPtr( 994 rewriter, loc, underlyingDescPtr, elementPtrPtrType); 995 *alignedPtr = UnrankedMemRefDescriptor::alignedPtr( 996 rewriter, loc, typeConverter, underlyingDescPtr, elementPtrPtrType); 997 if (offset != nullptr) { 998 *offset = UnrankedMemRefDescriptor::offset( 999 rewriter, loc, typeConverter, underlyingDescPtr, elementPtrPtrType); 1000 } 1001 } 1002 1003 struct MemRefReinterpretCastOpLowering 1004 : public ConvertOpToLLVMPattern<memref::ReinterpretCastOp> { 1005 using ConvertOpToLLVMPattern< 1006 memref::ReinterpretCastOp>::ConvertOpToLLVMPattern; 1007 1008 LogicalResult 1009 matchAndRewrite(memref::ReinterpretCastOp castOp, OpAdaptor adaptor, 1010 ConversionPatternRewriter &rewriter) const override { 1011 Type srcType = castOp.source().getType(); 1012 1013 Value descriptor; 1014 if (failed(convertSourceMemRefToDescriptor(rewriter, srcType, castOp, 1015 adaptor, &descriptor))) 1016 return failure(); 1017 rewriter.replaceOp(castOp, {descriptor}); 1018 return success(); 1019 } 1020 1021 private: 1022 LogicalResult convertSourceMemRefToDescriptor( 1023 ConversionPatternRewriter &rewriter, Type srcType, 1024 memref::ReinterpretCastOp castOp, 1025 memref::ReinterpretCastOp::Adaptor adaptor, Value *descriptor) const { 1026 MemRefType targetMemRefType = 1027 castOp.getResult().getType().cast<MemRefType>(); 1028 auto llvmTargetDescriptorTy = typeConverter->convertType(targetMemRefType) 1029 .dyn_cast_or_null<LLVM::LLVMStructType>(); 1030 if (!llvmTargetDescriptorTy) 1031 return failure(); 1032 1033 // Create descriptor. 1034 Location loc = castOp.getLoc(); 1035 auto desc = MemRefDescriptor::undef(rewriter, loc, llvmTargetDescriptorTy); 1036 1037 // Set allocated and aligned pointers. 1038 Value allocatedPtr, alignedPtr; 1039 extractPointersAndOffset(loc, rewriter, *getTypeConverter(), 1040 castOp.source(), adaptor.source(), &allocatedPtr, 1041 &alignedPtr); 1042 desc.setAllocatedPtr(rewriter, loc, allocatedPtr); 1043 desc.setAlignedPtr(rewriter, loc, alignedPtr); 1044 1045 // Set offset. 1046 if (castOp.isDynamicOffset(0)) 1047 desc.setOffset(rewriter, loc, adaptor.offsets()[0]); 1048 else 1049 desc.setConstantOffset(rewriter, loc, castOp.getStaticOffset(0)); 1050 1051 // Set sizes and strides. 
1052 unsigned dynSizeId = 0; 1053 unsigned dynStrideId = 0; 1054 for (unsigned i = 0, e = targetMemRefType.getRank(); i < e; ++i) { 1055 if (castOp.isDynamicSize(i)) 1056 desc.setSize(rewriter, loc, i, adaptor.sizes()[dynSizeId++]); 1057 else 1058 desc.setConstantSize(rewriter, loc, i, castOp.getStaticSize(i)); 1059 1060 if (castOp.isDynamicStride(i)) 1061 desc.setStride(rewriter, loc, i, adaptor.strides()[dynStrideId++]); 1062 else 1063 desc.setConstantStride(rewriter, loc, i, castOp.getStaticStride(i)); 1064 } 1065 *descriptor = desc; 1066 return success(); 1067 } 1068 }; 1069 1070 struct MemRefReshapeOpLowering 1071 : public ConvertOpToLLVMPattern<memref::ReshapeOp> { 1072 using ConvertOpToLLVMPattern<memref::ReshapeOp>::ConvertOpToLLVMPattern; 1073 1074 LogicalResult 1075 matchAndRewrite(memref::ReshapeOp reshapeOp, OpAdaptor adaptor, 1076 ConversionPatternRewriter &rewriter) const override { 1077 Type srcType = reshapeOp.source().getType(); 1078 1079 Value descriptor; 1080 if (failed(convertSourceMemRefToDescriptor(rewriter, srcType, reshapeOp, 1081 adaptor, &descriptor))) 1082 return failure(); 1083 rewriter.replaceOp(reshapeOp, {descriptor}); 1084 return success(); 1085 } 1086 1087 private: 1088 LogicalResult 1089 convertSourceMemRefToDescriptor(ConversionPatternRewriter &rewriter, 1090 Type srcType, memref::ReshapeOp reshapeOp, 1091 memref::ReshapeOp::Adaptor adaptor, 1092 Value *descriptor) const { 1093 // Conversion for statically-known shape args is performed via 1094 // `memref_reinterpret_cast`. 1095 auto shapeMemRefType = reshapeOp.shape().getType().cast<MemRefType>(); 1096 if (shapeMemRefType.hasStaticShape()) 1097 return failure(); 1098 1099 // The shape is a rank-1 tensor with unknown length. 1100 Location loc = reshapeOp.getLoc(); 1101 MemRefDescriptor shapeDesc(adaptor.shape()); 1102 Value resultRank = shapeDesc.size(rewriter, loc, 0); 1103 1104 // Extract address space and element type. 1105 auto targetType = 1106 reshapeOp.getResult().getType().cast<UnrankedMemRefType>(); 1107 unsigned addressSpace = targetType.getMemorySpaceAsInt(); 1108 Type elementType = targetType.getElementType(); 1109 1110 // Create the unranked memref descriptor that holds the ranked one. The 1111 // inner descriptor is allocated on stack. 1112 auto targetDesc = UnrankedMemRefDescriptor::undef( 1113 rewriter, loc, typeConverter->convertType(targetType)); 1114 targetDesc.setRank(rewriter, loc, resultRank); 1115 SmallVector<Value, 4> sizes; 1116 UnrankedMemRefDescriptor::computeSizes(rewriter, loc, *getTypeConverter(), 1117 targetDesc, sizes); 1118 Value underlyingDescPtr = rewriter.create<LLVM::AllocaOp>( 1119 loc, getVoidPtrType(), sizes.front(), llvm::None); 1120 targetDesc.setMemRefDescPtr(rewriter, loc, underlyingDescPtr); 1121 1122 // Extract pointers and offset from the source memref. 1123 Value allocatedPtr, alignedPtr, offset; 1124 extractPointersAndOffset(loc, rewriter, *getTypeConverter(), 1125 reshapeOp.source(), adaptor.source(), 1126 &allocatedPtr, &alignedPtr, &offset); 1127 1128 // Set pointers and offset. 
1129 Type llvmElementType = typeConverter->convertType(elementType); 1130 auto elementPtrPtrType = LLVM::LLVMPointerType::get( 1131 LLVM::LLVMPointerType::get(llvmElementType, addressSpace)); 1132 UnrankedMemRefDescriptor::setAllocatedPtr(rewriter, loc, underlyingDescPtr, 1133 elementPtrPtrType, allocatedPtr); 1134 UnrankedMemRefDescriptor::setAlignedPtr(rewriter, loc, *getTypeConverter(), 1135 underlyingDescPtr, 1136 elementPtrPtrType, alignedPtr); 1137 UnrankedMemRefDescriptor::setOffset(rewriter, loc, *getTypeConverter(), 1138 underlyingDescPtr, elementPtrPtrType, 1139 offset); 1140 1141 // Use the offset pointer as base for further addressing. Copy over the new 1142 // shape and compute strides. For this, we create a loop from rank-1 to 0. 1143 Value targetSizesBase = UnrankedMemRefDescriptor::sizeBasePtr( 1144 rewriter, loc, *getTypeConverter(), underlyingDescPtr, 1145 elementPtrPtrType); 1146 Value targetStridesBase = UnrankedMemRefDescriptor::strideBasePtr( 1147 rewriter, loc, *getTypeConverter(), targetSizesBase, resultRank); 1148 Value shapeOperandPtr = shapeDesc.alignedPtr(rewriter, loc); 1149 Value oneIndex = createIndexConstant(rewriter, loc, 1); 1150 Value resultRankMinusOne = 1151 rewriter.create<LLVM::SubOp>(loc, resultRank, oneIndex); 1152 1153 Block *initBlock = rewriter.getInsertionBlock(); 1154 Type indexType = getTypeConverter()->getIndexType(); 1155 Block::iterator remainingOpsIt = std::next(rewriter.getInsertionPoint()); 1156 1157 Block *condBlock = rewriter.createBlock(initBlock->getParent(), {}, 1158 {indexType, indexType}, {loc, loc}); 1159 1160 // Move the remaining initBlock ops to condBlock. 1161 Block *remainingBlock = rewriter.splitBlock(initBlock, remainingOpsIt); 1162 rewriter.mergeBlocks(remainingBlock, condBlock, ValueRange()); 1163 1164 rewriter.setInsertionPointToEnd(initBlock); 1165 rewriter.create<LLVM::BrOp>(loc, ValueRange({resultRankMinusOne, oneIndex}), 1166 condBlock); 1167 rewriter.setInsertionPointToStart(condBlock); 1168 Value indexArg = condBlock->getArgument(0); 1169 Value strideArg = condBlock->getArgument(1); 1170 1171 Value zeroIndex = createIndexConstant(rewriter, loc, 0); 1172 Value pred = rewriter.create<LLVM::ICmpOp>( 1173 loc, IntegerType::get(rewriter.getContext(), 1), 1174 LLVM::ICmpPredicate::sge, indexArg, zeroIndex); 1175 1176 Block *bodyBlock = 1177 rewriter.splitBlock(condBlock, rewriter.getInsertionPoint()); 1178 rewriter.setInsertionPointToStart(bodyBlock); 1179 1180 // Copy size from shape to descriptor. 1181 Type llvmIndexPtrType = LLVM::LLVMPointerType::get(indexType); 1182 Value sizeLoadGep = rewriter.create<LLVM::GEPOp>( 1183 loc, llvmIndexPtrType, shapeOperandPtr, ValueRange{indexArg}); 1184 Value size = rewriter.create<LLVM::LoadOp>(loc, sizeLoadGep); 1185 UnrankedMemRefDescriptor::setSize(rewriter, loc, *getTypeConverter(), 1186 targetSizesBase, indexArg, size); 1187 1188 // Write stride value and compute next one. 1189 UnrankedMemRefDescriptor::setStride(rewriter, loc, *getTypeConverter(), 1190 targetStridesBase, indexArg, strideArg); 1191 Value nextStride = rewriter.create<LLVM::MulOp>(loc, strideArg, size); 1192 1193 // Decrement loop counter and branch back. 1194 Value decrement = rewriter.create<LLVM::SubOp>(loc, indexArg, oneIndex); 1195 rewriter.create<LLVM::BrOp>(loc, ValueRange({decrement, nextStride}), 1196 condBlock); 1197 1198 Block *remainder = 1199 rewriter.splitBlock(bodyBlock, rewriter.getInsertionPoint()); 1200 1201 // Hook up the cond exit to the remainder. 
1202 rewriter.setInsertionPointToEnd(condBlock); 1203 rewriter.create<LLVM::CondBrOp>(loc, pred, bodyBlock, llvm::None, remainder, 1204 llvm::None); 1205 1206 // Reset position to beginning of new remainder block. 1207 rewriter.setInsertionPointToStart(remainder); 1208 1209 *descriptor = targetDesc; 1210 return success(); 1211 } 1212 }; 1213 1214 /// Helper function to convert a vector of `OpFoldResult`s into a vector of 1215 /// `Value`s. 1216 static SmallVector<Value> getAsValues(OpBuilder &b, Location loc, 1217 Type &llvmIndexType, 1218 ArrayRef<OpFoldResult> valueOrAttrVec) { 1219 return llvm::to_vector<4>( 1220 llvm::map_range(valueOrAttrVec, [&](OpFoldResult value) -> Value { 1221 if (auto attr = value.dyn_cast<Attribute>()) 1222 return b.create<LLVM::ConstantOp>(loc, llvmIndexType, attr); 1223 return value.get<Value>(); 1224 })); 1225 } 1226 1227 /// Compute a map that for a given dimension of the expanded type gives the 1228 /// dimension in the collapsed type it maps to. Essentially its the inverse of 1229 /// the `reassocation` maps. 1230 static DenseMap<int64_t, int64_t> 1231 getExpandedDimToCollapsedDimMap(ArrayRef<ReassociationIndices> reassociation) { 1232 llvm::DenseMap<int64_t, int64_t> expandedDimToCollapsedDim; 1233 for (auto &en : enumerate(reassociation)) { 1234 for (auto dim : en.value()) 1235 expandedDimToCollapsedDim[dim] = en.index(); 1236 } 1237 return expandedDimToCollapsedDim; 1238 } 1239 1240 static OpFoldResult 1241 getExpandedOutputDimSize(OpBuilder &b, Location loc, Type &llvmIndexType, 1242 int64_t outDimIndex, ArrayRef<int64_t> outStaticShape, 1243 MemRefDescriptor &inDesc, 1244 ArrayRef<int64_t> inStaticShape, 1245 ArrayRef<ReassociationIndices> reassocation, 1246 DenseMap<int64_t, int64_t> &outDimToInDimMap) { 1247 int64_t outDimSize = outStaticShape[outDimIndex]; 1248 if (!ShapedType::isDynamic(outDimSize)) 1249 return b.getIndexAttr(outDimSize); 1250 1251 // Calculate the multiplication of all the out dim sizes except the 1252 // current dim. 1253 int64_t inDimIndex = outDimToInDimMap[outDimIndex]; 1254 int64_t otherDimSizesMul = 1; 1255 for (auto otherDimIndex : reassocation[inDimIndex]) { 1256 if (otherDimIndex == static_cast<unsigned>(outDimIndex)) 1257 continue; 1258 int64_t otherDimSize = outStaticShape[otherDimIndex]; 1259 assert(!ShapedType::isDynamic(otherDimSize) && 1260 "single dimension cannot be expanded into multiple dynamic " 1261 "dimensions"); 1262 otherDimSizesMul *= otherDimSize; 1263 } 1264 1265 // outDimSize = inDimSize / otherOutDimSizesMul 1266 int64_t inDimSize = inStaticShape[inDimIndex]; 1267 Value inDimSizeDynamic = 1268 ShapedType::isDynamic(inDimSize) 1269 ? 
inDesc.size(b, loc, inDimIndex) 1270 : b.create<LLVM::ConstantOp>(loc, llvmIndexType, 1271 b.getIndexAttr(inDimSize)); 1272 Value outDimSizeDynamic = b.create<LLVM::SDivOp>( 1273 loc, inDimSizeDynamic, 1274 b.create<LLVM::ConstantOp>(loc, llvmIndexType, 1275 b.getIndexAttr(otherDimSizesMul))); 1276 return outDimSizeDynamic; 1277 } 1278 1279 static OpFoldResult getCollapsedOutputDimSize( 1280 OpBuilder &b, Location loc, Type &llvmIndexType, int64_t outDimIndex, 1281 int64_t outDimSize, ArrayRef<int64_t> inStaticShape, 1282 MemRefDescriptor &inDesc, ArrayRef<ReassociationIndices> reassocation) { 1283 if (!ShapedType::isDynamic(outDimSize)) 1284 return b.getIndexAttr(outDimSize); 1285 1286 Value c1 = b.create<LLVM::ConstantOp>(loc, llvmIndexType, b.getIndexAttr(1)); 1287 Value outDimSizeDynamic = c1; 1288 for (auto inDimIndex : reassocation[outDimIndex]) { 1289 int64_t inDimSize = inStaticShape[inDimIndex]; 1290 Value inDimSizeDynamic = 1291 ShapedType::isDynamic(inDimSize) 1292 ? inDesc.size(b, loc, inDimIndex) 1293 : b.create<LLVM::ConstantOp>(loc, llvmIndexType, 1294 b.getIndexAttr(inDimSize)); 1295 outDimSizeDynamic = 1296 b.create<LLVM::MulOp>(loc, outDimSizeDynamic, inDimSizeDynamic); 1297 } 1298 return outDimSizeDynamic; 1299 } 1300 1301 static SmallVector<OpFoldResult, 4> 1302 getCollapsedOutputShape(OpBuilder &b, Location loc, Type &llvmIndexType, 1303 ArrayRef<ReassociationIndices> reassocation, 1304 ArrayRef<int64_t> inStaticShape, 1305 MemRefDescriptor &inDesc, 1306 ArrayRef<int64_t> outStaticShape) { 1307 return llvm::to_vector<4>(llvm::map_range( 1308 llvm::seq<int64_t>(0, outStaticShape.size()), [&](int64_t outDimIndex) { 1309 return getCollapsedOutputDimSize(b, loc, llvmIndexType, outDimIndex, 1310 outStaticShape[outDimIndex], 1311 inStaticShape, inDesc, reassocation); 1312 })); 1313 } 1314 1315 static SmallVector<OpFoldResult, 4> 1316 getExpandedOutputShape(OpBuilder &b, Location loc, Type &llvmIndexType, 1317 ArrayRef<ReassociationIndices> reassocation, 1318 ArrayRef<int64_t> inStaticShape, 1319 MemRefDescriptor &inDesc, 1320 ArrayRef<int64_t> outStaticShape) { 1321 DenseMap<int64_t, int64_t> outDimToInDimMap = 1322 getExpandedDimToCollapsedDimMap(reassocation); 1323 return llvm::to_vector<4>(llvm::map_range( 1324 llvm::seq<int64_t>(0, outStaticShape.size()), [&](int64_t outDimIndex) { 1325 return getExpandedOutputDimSize(b, loc, llvmIndexType, outDimIndex, 1326 outStaticShape, inDesc, inStaticShape, 1327 reassocation, outDimToInDimMap); 1328 })); 1329 } 1330 1331 static SmallVector<Value> 1332 getDynamicOutputShape(OpBuilder &b, Location loc, Type &llvmIndexType, 1333 ArrayRef<ReassociationIndices> reassocation, 1334 ArrayRef<int64_t> inStaticShape, MemRefDescriptor &inDesc, 1335 ArrayRef<int64_t> outStaticShape) { 1336 return outStaticShape.size() < inStaticShape.size() 1337 ? getAsValues(b, loc, llvmIndexType, 1338 getCollapsedOutputShape(b, loc, llvmIndexType, 1339 reassocation, inStaticShape, 1340 inDesc, outStaticShape)) 1341 : getAsValues(b, loc, llvmIndexType, 1342 getExpandedOutputShape(b, loc, llvmIndexType, 1343 reassocation, inStaticShape, 1344 inDesc, outStaticShape)); 1345 } 1346 1347 // ReshapeOp creates a new view descriptor of the proper rank. 1348 // For now, the only conversion supported is for target MemRef with static sizes 1349 // and strides. 
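// For example (a sketch of the collapsing case):
//   %r = memref.collapse_shape %m [[0, 1]] : memref<2x4xf32> into memref<8xf32>
// copies %m's allocated/aligned pointers and offset into the new descriptor
// and fills in the collapsed size (8) and stride (1) derived from the
// reassociation indices.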
1350 template <typename ReshapeOp> 1351 class ReassociatingReshapeOpConversion 1352 : public ConvertOpToLLVMPattern<ReshapeOp> { 1353 public: 1354 using ConvertOpToLLVMPattern<ReshapeOp>::ConvertOpToLLVMPattern; 1355 using ReshapeOpAdaptor = typename ReshapeOp::Adaptor; 1356 1357 LogicalResult 1358 matchAndRewrite(ReshapeOp reshapeOp, typename ReshapeOp::Adaptor adaptor, 1359 ConversionPatternRewriter &rewriter) const override { 1360 MemRefType dstType = reshapeOp.getResultType(); 1361 MemRefType srcType = reshapeOp.getSrcType(); 1362 1363 // The condition on the layouts can be ignored when all shapes are static. 1364 if (!srcType.hasStaticShape() || !dstType.hasStaticShape()) { 1365 if (!srcType.getLayout().isIdentity() || 1366 !dstType.getLayout().isIdentity()) { 1367 return rewriter.notifyMatchFailure( 1368 reshapeOp, "only empty layout map is supported"); 1369 } 1370 } 1371 1372 int64_t offset; 1373 SmallVector<int64_t, 4> strides; 1374 if (failed(getStridesAndOffset(dstType, strides, offset))) { 1375 return rewriter.notifyMatchFailure( 1376 reshapeOp, "failed to get stride and offset exprs"); 1377 } 1378 1379 MemRefDescriptor srcDesc(adaptor.src()); 1380 Location loc = reshapeOp->getLoc(); 1381 auto dstDesc = MemRefDescriptor::undef( 1382 rewriter, loc, this->typeConverter->convertType(dstType)); 1383 dstDesc.setAllocatedPtr(rewriter, loc, srcDesc.allocatedPtr(rewriter, loc)); 1384 dstDesc.setAlignedPtr(rewriter, loc, srcDesc.alignedPtr(rewriter, loc)); 1385 dstDesc.setOffset(rewriter, loc, srcDesc.offset(rewriter, loc)); 1386 1387 ArrayRef<int64_t> srcStaticShape = srcType.getShape(); 1388 ArrayRef<int64_t> dstStaticShape = dstType.getShape(); 1389 Type llvmIndexType = 1390 this->typeConverter->convertType(rewriter.getIndexType()); 1391 SmallVector<Value> dstShape = getDynamicOutputShape( 1392 rewriter, loc, llvmIndexType, reshapeOp.getReassociationIndices(), 1393 srcStaticShape, srcDesc, dstStaticShape); 1394 for (auto &en : llvm::enumerate(dstShape)) 1395 dstDesc.setSize(rewriter, loc, en.index(), en.value()); 1396 1397 auto isStaticStride = [](int64_t stride) { 1398 return !ShapedType::isDynamicStrideOrOffset(stride); 1399 }; 1400 if (llvm::all_of(strides, isStaticStride)) { 1401 for (auto &en : llvm::enumerate(strides)) 1402 dstDesc.setConstantStride(rewriter, loc, en.index(), en.value()); 1403 } else { 1404 Value c1 = rewriter.create<LLVM::ConstantOp>(loc, llvmIndexType, 1405 rewriter.getIndexAttr(1)); 1406 Value stride = c1; 1407 for (auto dimIndex : 1408 llvm::reverse(llvm::seq<int64_t>(0, dstShape.size()))) { 1409 dstDesc.setStride(rewriter, loc, dimIndex, stride); 1410 stride = rewriter.create<LLVM::MulOp>(loc, dstShape[dimIndex], stride); 1411 } 1412 } 1413 rewriter.replaceOp(reshapeOp, {dstDesc}); 1414 return success(); 1415 } 1416 }; 1417 1418 /// Conversion pattern that transforms a subview op into: 1419 /// 1. An `llvm.mlir.undef` operation to create a memref descriptor 1420 /// 2. Updates to the descriptor to introduce the data ptr, offset, size 1421 /// and stride. 1422 /// The subview op is replaced by the descriptor. 
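/// For example (a sketch; the result type is elided):
///   %v = memref.subview %src[%i, 0] [4, 4] [1, 1] : memref<8x8xf32> to ...
/// produces a descriptor that reuses %src's allocated and aligned pointers,
/// with offset = %src.offset + %i * 8 (the stride of dimension 0), sizes
/// [4, 4], and strides [8, 1].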

/// Conversion pattern that transforms a subview op into:
/// 1. An `llvm.mlir.undef` operation to create a memref descriptor
/// 2. Updates to the descriptor to introduce the data ptr, offset, size
///    and stride.
/// The subview op is replaced by the descriptor.
struct SubViewOpLowering : public ConvertOpToLLVMPattern<memref::SubViewOp> {
  using ConvertOpToLLVMPattern<memref::SubViewOp>::ConvertOpToLLVMPattern;

  LogicalResult
  matchAndRewrite(memref::SubViewOp subViewOp, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    auto loc = subViewOp.getLoc();

    auto sourceMemRefType = subViewOp.source().getType().cast<MemRefType>();
    auto sourceElementTy =
        typeConverter->convertType(sourceMemRefType.getElementType());

    auto viewMemRefType = subViewOp.getType();
    auto inferredType =
        memref::SubViewOp::inferResultType(
            subViewOp.getSourceType(),
            extractFromI64ArrayAttr(subViewOp.static_offsets()),
            extractFromI64ArrayAttr(subViewOp.static_sizes()),
            extractFromI64ArrayAttr(subViewOp.static_strides()))
            .cast<MemRefType>();
    auto targetElementTy =
        typeConverter->convertType(viewMemRefType.getElementType());
    auto targetDescTy = typeConverter->convertType(viewMemRefType);
    if (!sourceElementTy || !targetDescTy || !targetElementTy ||
        !LLVM::isCompatibleType(sourceElementTy) ||
        !LLVM::isCompatibleType(targetElementTy) ||
        !LLVM::isCompatibleType(targetDescTy))
      return failure();

    // Extract the offset and strides from the type.
    int64_t offset;
    SmallVector<int64_t, 4> strides;
    auto successStrides = getStridesAndOffset(inferredType, strides, offset);
    if (failed(successStrides))
      return failure();

    // Create the descriptor.
    if (!LLVM::isCompatibleType(adaptor.getOperands().front().getType()))
      return failure();
    MemRefDescriptor sourceMemRef(adaptor.getOperands().front());
    auto targetMemRef = MemRefDescriptor::undef(rewriter, loc, targetDescTy);

    // Copy the buffer pointer from the old descriptor to the new one.
    Value extracted = sourceMemRef.allocatedPtr(rewriter, loc);
    Value bitcastPtr = rewriter.create<LLVM::BitcastOp>(
        loc,
        LLVM::LLVMPointerType::get(targetElementTy,
                                   viewMemRefType.getMemorySpaceAsInt()),
        extracted);
    targetMemRef.setAllocatedPtr(rewriter, loc, bitcastPtr);

    // Copy the aligned pointer from the old descriptor to the new one.
    extracted = sourceMemRef.alignedPtr(rewriter, loc);
    bitcastPtr = rewriter.create<LLVM::BitcastOp>(
        loc,
        LLVM::LLVMPointerType::get(targetElementTy,
                                   viewMemRefType.getMemorySpaceAsInt()),
        extracted);
    targetMemRef.setAlignedPtr(rewriter, loc, bitcastPtr);

    size_t inferredShapeRank = inferredType.getRank();
    size_t resultShapeRank = viewMemRefType.getRank();

    // Extract strides needed to compute offset.
    SmallVector<Value, 4> strideValues;
    strideValues.reserve(inferredShapeRank);
    for (unsigned i = 0; i < inferredShapeRank; ++i)
      strideValues.push_back(sourceMemRef.stride(rewriter, loc, i));

    // Offset.
    auto llvmIndexType = typeConverter->convertType(rewriter.getIndexType());
    if (!ShapedType::isDynamicStrideOrOffset(offset)) {
      targetMemRef.setConstantOffset(rewriter, loc, offset);
    } else {
      Value baseOffset = sourceMemRef.offset(rewriter, loc);
      // `inferredShapeRank` may be larger than the number of offset operands
      // because of trailing semantics. In this case, the offset is guaranteed
      // to be interpreted as 0 and we can just skip the extra dimensions.
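      // The dynamic offset is accumulated below as
      //   baseOffset + sum_i(offset_i * stride_i),
      // one term per dimension that has an explicit offset operand.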
      for (unsigned i = 0, e = std::min(inferredShapeRank,
                                        subViewOp.getMixedOffsets().size());
           i < e; ++i) {
        Value offset =
            // TODO: need OpFoldResult ODS adaptor to clean this up.
            subViewOp.isDynamicOffset(i)
                ? adaptor.getOperands()[subViewOp.getIndexOfDynamicOffset(i)]
                : rewriter.create<LLVM::ConstantOp>(
                      loc, llvmIndexType,
                      rewriter.getI64IntegerAttr(subViewOp.getStaticOffset(i)));
        Value mul = rewriter.create<LLVM::MulOp>(loc, offset, strideValues[i]);
        baseOffset = rewriter.create<LLVM::AddOp>(loc, baseOffset, mul);
      }
      targetMemRef.setOffset(rewriter, loc, baseOffset);
    }

    // Update sizes and strides.
    SmallVector<OpFoldResult> mixedSizes = subViewOp.getMixedSizes();
    SmallVector<OpFoldResult> mixedStrides = subViewOp.getMixedStrides();
    assert(mixedSizes.size() == mixedStrides.size() &&
           "expected sizes and strides of equal length");
    llvm::SmallBitVector unusedDims = subViewOp.getDroppedDims();
    for (int i = inferredShapeRank - 1, j = resultShapeRank - 1;
         i >= 0 && j >= 0; --i) {
      if (unusedDims.test(i))
        continue;

      // `i` may exceed the number of entries in `mixedSizes` because of
      // trailing semantics. In this case, the size is guaranteed to be
      // interpreted as Dim and the stride as 1.
      Value size, stride;
      if (static_cast<unsigned>(i) >= mixedSizes.size()) {
        // If the static size is available, use it directly. This is similar to
        // the folding of dim(constant-op) but removes the need for dim to be
        // aware of LLVM constants and for this pass to be aware of std
        // constants.
        int64_t staticSize =
            subViewOp.source().getType().cast<MemRefType>().getShape()[i];
        if (staticSize != ShapedType::kDynamicSize) {
          size = rewriter.create<LLVM::ConstantOp>(
              loc, llvmIndexType, rewriter.getI64IntegerAttr(staticSize));
        } else {
          Value pos = rewriter.create<LLVM::ConstantOp>(
              loc, llvmIndexType, rewriter.getI64IntegerAttr(i));
          Value dim =
              rewriter.create<memref::DimOp>(loc, subViewOp.source(), pos);
          auto cast = rewriter.create<UnrealizedConversionCastOp>(
              loc, llvmIndexType, dim);
          size = cast.getResult(0);
        }
        stride = rewriter.create<LLVM::ConstantOp>(
            loc, llvmIndexType, rewriter.getI64IntegerAttr(1));
      } else {
        // TODO: need OpFoldResult ODS adaptor to clean this up.
        size =
            subViewOp.isDynamicSize(i)
                ? adaptor.getOperands()[subViewOp.getIndexOfDynamicSize(i)]
                : rewriter.create<LLVM::ConstantOp>(
                      loc, llvmIndexType,
                      rewriter.getI64IntegerAttr(subViewOp.getStaticSize(i)));
        if (!ShapedType::isDynamicStrideOrOffset(strides[i])) {
          stride = rewriter.create<LLVM::ConstantOp>(
              loc, llvmIndexType, rewriter.getI64IntegerAttr(strides[i]));
        } else {
          stride =
              subViewOp.isDynamicStride(i)
                  ? adaptor.getOperands()[subViewOp.getIndexOfDynamicStride(i)]
                  : rewriter.create<LLVM::ConstantOp>(
                        loc, llvmIndexType,
                        rewriter.getI64IntegerAttr(
                            subViewOp.getStaticStride(i)));
          stride = rewriter.create<LLVM::MulOp>(loc, stride, strideValues[i]);
        }
      }
      targetMemRef.setSize(rewriter, loc, j, size);
      targetMemRef.setStride(rewriter, loc, j, stride);
      j--;
    }

    rewriter.replaceOp(subViewOp, {targetMemRef});
    return success();
  }
};
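
// Illustrative example of the subview lowering (a sketch only; the result
// layout notation is elided and the exact IR depends on the converted index
// type): for
//   %0 = memref.subview %src[%i, 0] [4, 4] [1, 1] : memref<8x8xf32> to ...
// the new descriptor shares %src's pointers, its offset becomes %i * 8
// (the dynamic offset times the source stride of dim 0), its sizes are (4, 4)
// and its strides (8, 1).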

/// Conversion pattern that transforms a transpose op into:
/// 1. An `llvm.mlir.undef` operation to create a memref descriptor of the
///    result type.
/// 2. Copies of the allocated and aligned pointers and of the offset from the
///    source descriptor to the new one.
/// 3. Updates to the new descriptor's sizes and strides, permuted according
///    to the transpose's permutation map.
/// The transpose op is replaced by the new descriptor.
class TransposeOpLowering : public ConvertOpToLLVMPattern<memref::TransposeOp> {
public:
  using ConvertOpToLLVMPattern<memref::TransposeOp>::ConvertOpToLLVMPattern;

  LogicalResult
  matchAndRewrite(memref::TransposeOp transposeOp, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    auto loc = transposeOp.getLoc();
    MemRefDescriptor viewMemRef(adaptor.in());

    // No permutation, early exit.
    if (transposeOp.permutation().isIdentity())
      return rewriter.replaceOp(transposeOp, {viewMemRef}), success();

    auto targetMemRef = MemRefDescriptor::undef(
        rewriter, loc,
        typeConverter->convertType(transposeOp.getShapedType()));

    // Copy the base and aligned pointers from the old descriptor to the new
    // one.
    targetMemRef.setAllocatedPtr(rewriter, loc,
                                 viewMemRef.allocatedPtr(rewriter, loc));
    targetMemRef.setAlignedPtr(rewriter, loc,
                               viewMemRef.alignedPtr(rewriter, loc));

    // Copy the offset from the old descriptor to the new one.
    targetMemRef.setOffset(rewriter, loc, viewMemRef.offset(rewriter, loc));

    // Iterate over the dimensions and apply size/stride permutation.
    for (const auto &en :
         llvm::enumerate(transposeOp.permutation().getResults())) {
      int sourcePos = en.index();
      int targetPos = en.value().cast<AffineDimExpr>().getPosition();
      targetMemRef.setSize(rewriter, loc, targetPos,
                           viewMemRef.size(rewriter, loc, sourcePos));
      targetMemRef.setStride(rewriter, loc, targetPos,
                             viewMemRef.stride(rewriter, loc, sourcePos));
    }

    rewriter.replaceOp(transposeOp, {targetMemRef});
    return success();
  }
};

/// Conversion pattern that transforms a view op into:
/// 1. An `llvm.mlir.undef` operation to create a memref descriptor
/// 2. Updates to the descriptor to introduce the data ptr, offset, size
///    and stride.
/// The view op is replaced by the descriptor.
struct ViewOpLowering : public ConvertOpToLLVMPattern<memref::ViewOp> {
  using ConvertOpToLLVMPattern<memref::ViewOp>::ConvertOpToLLVMPattern;

  // Build and return the value for the idx^th shape dimension, either by
  // returning the constant shape dimension or counting the proper dynamic
  // size.
  Value getSize(ConversionPatternRewriter &rewriter, Location loc,
                ArrayRef<int64_t> shape, ValueRange dynamicSizes,
                unsigned idx) const {
    assert(idx < shape.size());
    if (!ShapedType::isDynamic(shape[idx]))
      return createIndexConstant(rewriter, loc, shape[idx]);
    // Count the number of dynamic dims in range [0, idx).
    unsigned nDynamic = llvm::count_if(shape.take_front(idx), [](int64_t v) {
      return ShapedType::isDynamic(v);
    });
    return dynamicSizes[nDynamic];
  }
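
  // For example, `getSize` above with shape (2, ?, 4, ?) and idx == 3 sees one
  // dynamic dimension in the prefix (2, ?, 4), so it returns dynamicSizes[1].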

  // Build and return the idx^th stride, either by returning the constant
  // stride or by computing the dynamic stride from the current `runningStride`
  // and `nextSize`. The caller should keep a running stride and update it with
  // the result returned by this function.
  Value getStride(ConversionPatternRewriter &rewriter, Location loc,
                  ArrayRef<int64_t> strides, Value nextSize,
                  Value runningStride, unsigned idx) const {
    assert(idx < strides.size());
    if (!ShapedType::isDynamicStrideOrOffset(strides[idx]))
      return createIndexConstant(rewriter, loc, strides[idx]);
    if (nextSize)
      return runningStride
                 ? rewriter.create<LLVM::MulOp>(loc, runningStride, nextSize)
                 : nextSize;
    assert(!runningStride);
    return createIndexConstant(rewriter, loc, 1);
  }

  LogicalResult
  matchAndRewrite(memref::ViewOp viewOp, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    auto loc = viewOp.getLoc();

    auto viewMemRefType = viewOp.getType();
    auto targetElementTy =
        typeConverter->convertType(viewMemRefType.getElementType());
    auto targetDescTy = typeConverter->convertType(viewMemRefType);
    if (!targetDescTy || !targetElementTy ||
        !LLVM::isCompatibleType(targetElementTy) ||
        !LLVM::isCompatibleType(targetDescTy))
      return viewOp.emitWarning("Target descriptor type not converted to LLVM"),
             failure();

    int64_t offset;
    SmallVector<int64_t, 4> strides;
    auto successStrides = getStridesAndOffset(viewMemRefType, strides, offset);
    if (failed(successStrides))
      return viewOp.emitWarning("cannot cast to non-strided shape"), failure();
    assert(offset == 0 && "expected offset to be 0");

    // Create the descriptor.
    MemRefDescriptor sourceMemRef(adaptor.source());
    auto targetMemRef = MemRefDescriptor::undef(rewriter, loc, targetDescTy);

    // Field 1: Copy the allocated pointer, used for malloc/free.
    Value allocatedPtr = sourceMemRef.allocatedPtr(rewriter, loc);
    auto srcMemRefType = viewOp.source().getType().cast<MemRefType>();
    Value bitcastPtr = rewriter.create<LLVM::BitcastOp>(
        loc,
        LLVM::LLVMPointerType::get(targetElementTy,
                                   srcMemRefType.getMemorySpaceAsInt()),
        allocatedPtr);
    targetMemRef.setAllocatedPtr(rewriter, loc, bitcastPtr);

    // Field 2: Copy the actual aligned pointer to payload.
    Value alignedPtr = sourceMemRef.alignedPtr(rewriter, loc);
    alignedPtr = rewriter.create<LLVM::GEPOp>(loc, alignedPtr.getType(),
                                              alignedPtr, adaptor.byte_shift());
    bitcastPtr = rewriter.create<LLVM::BitcastOp>(
        loc,
        LLVM::LLVMPointerType::get(targetElementTy,
                                   srcMemRefType.getMemorySpaceAsInt()),
        alignedPtr);
    targetMemRef.setAlignedPtr(rewriter, loc, bitcastPtr);

    // Field 3: The offset in the resulting type must be 0. This is because of
    // the type change: an offset on srcType* may not be expressible as an
    // offset on dstType*.
    targetMemRef.setOffset(rewriter, loc,
                           createIndexConstant(rewriter, loc, offset));

    // Early exit for 0-D corner case.
    if (viewMemRefType.getRank() == 0)
      return rewriter.replaceOp(viewOp, {targetMemRef}), success();

    // Fields 4 and 5: Update sizes and strides.
    if (strides.back() != 1)
      return viewOp.emitWarning("cannot cast to non-contiguous shape"),
             failure();
    Value stride = nullptr, nextSize = nullptr;
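    // Strides are rebuilt from the innermost dimension outwards as a running
    // product of sizes; e.g. a contiguous memref<AxBxCxf32> view ends up with
    // strides (B*C, C, 1).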
    for (int i = viewMemRefType.getRank() - 1; i >= 0; --i) {
      // Update size.
      Value size =
          getSize(rewriter, loc, viewMemRefType.getShape(), adaptor.sizes(), i);
      targetMemRef.setSize(rewriter, loc, i, size);
      // Update stride.
      stride = getStride(rewriter, loc, strides, nextSize, stride, i);
      targetMemRef.setStride(rewriter, loc, i, stride);
      nextSize = size;
    }

    rewriter.replaceOp(viewOp, {targetMemRef});
    return success();
  }
};

//===----------------------------------------------------------------------===//
// AtomicRMWOpLowering
//===----------------------------------------------------------------------===//

/// Try to match the kind of a memref.atomic_rmw to determine whether to use a
/// lowering to llvm.atomicrmw or fall back to llvm.cmpxchg.
static Optional<LLVM::AtomicBinOp>
matchSimpleAtomicOp(memref::AtomicRMWOp atomicOp) {
  switch (atomicOp.kind()) {
  case arith::AtomicRMWKind::addf:
    return LLVM::AtomicBinOp::fadd;
  case arith::AtomicRMWKind::addi:
    return LLVM::AtomicBinOp::add;
  case arith::AtomicRMWKind::assign:
    return LLVM::AtomicBinOp::xchg;
  case arith::AtomicRMWKind::maxs:
    return LLVM::AtomicBinOp::max;
  case arith::AtomicRMWKind::maxu:
    return LLVM::AtomicBinOp::umax;
  case arith::AtomicRMWKind::mins:
    return LLVM::AtomicBinOp::min;
  case arith::AtomicRMWKind::minu:
    return LLVM::AtomicBinOp::umin;
  case arith::AtomicRMWKind::ori:
    return LLVM::AtomicBinOp::_or;
  case arith::AtomicRMWKind::andi:
    return LLVM::AtomicBinOp::_and;
  default:
    return llvm::None;
  }
  llvm_unreachable("Invalid AtomicRMWKind");
}

struct AtomicRMWOpLowering : public LoadStoreOpLowering<memref::AtomicRMWOp> {
  using Base::Base;

  LogicalResult
  matchAndRewrite(memref::AtomicRMWOp atomicOp, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    if (failed(match(atomicOp)))
      return failure();
    auto maybeKind = matchSimpleAtomicOp(atomicOp);
    if (!maybeKind)
      return failure();
    auto resultType = adaptor.value().getType();
    auto memRefType = atomicOp.getMemRefType();
    auto dataPtr =
        getStridedElementPtr(atomicOp.getLoc(), memRefType, adaptor.memref(),
                             adaptor.indices(), rewriter);
    rewriter.replaceOpWithNewOp<LLVM::AtomicRMWOp>(
        atomicOp, resultType, *maybeKind, dataPtr, adaptor.value(),
        LLVM::AtomicOrdering::acq_rel);
    return success();
  }
};

} // namespace

void mlir::populateMemRefToLLVMConversionPatterns(LLVMTypeConverter &converter,
                                                  RewritePatternSet &patterns) {
  // clang-format off
  patterns.add<
      AllocaOpLowering,
      AllocaScopeOpLowering,
      AtomicRMWOpLowering,
      AssumeAlignmentOpLowering,
      DimOpLowering,
      GenericAtomicRMWOpLowering,
      GlobalMemrefOpLowering,
      GetGlobalMemrefOpLowering,
      LoadOpLowering,
      MemRefCastOpLowering,
      MemRefCopyOpLowering,
      MemRefReinterpretCastOpLowering,
      MemRefReshapeOpLowering,
      PrefetchOpLowering,
      RankOpLowering,
      ReassociatingReshapeOpConversion<memref::ExpandShapeOp>,
      ReassociatingReshapeOpConversion<memref::CollapseShapeOp>,
      StoreOpLowering,
      SubViewOpLowering,
      TransposeOpLowering,
      ViewOpLowering>(converter);
  // clang-format on
  auto allocLowering = converter.getOptions().allocLowering;
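  // Allocation lowering is selected via LowerToLLVMOptions: either the
  // aligned_alloc-based or the plain malloc-based alloc pattern is registered,
  // each paired with the matching DeallocOpLowering.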
  if (allocLowering == LowerToLLVMOptions::AllocLowering::AlignedAlloc)
    patterns.add<AlignedAllocOpLowering, DeallocOpLowering>(converter);
  else if (allocLowering == LowerToLLVMOptions::AllocLowering::Malloc)
    patterns.add<AllocOpLowering, DeallocOpLowering>(converter);
}

namespace {
struct MemRefToLLVMPass : public ConvertMemRefToLLVMBase<MemRefToLLVMPass> {
  MemRefToLLVMPass() = default;

  void runOnOperation() override {
    Operation *op = getOperation();
    const auto &dataLayoutAnalysis = getAnalysis<DataLayoutAnalysis>();
    LowerToLLVMOptions options(&getContext(),
                               dataLayoutAnalysis.getAtOrAbove(op));
    options.allocLowering =
        (useAlignedAlloc ? LowerToLLVMOptions::AllocLowering::AlignedAlloc
                         : LowerToLLVMOptions::AllocLowering::Malloc);
    if (indexBitwidth != kDeriveIndexBitwidthFromDataLayout)
      options.overrideIndexBitwidth(indexBitwidth);

    LLVMTypeConverter typeConverter(&getContext(), options,
                                    &dataLayoutAnalysis);
    RewritePatternSet patterns(&getContext());
    populateMemRefToLLVMConversionPatterns(typeConverter, patterns);
    LLVMConversionTarget target(getContext());
    target.addLegalOp<func::FuncOp>();
    if (failed(applyPartialConversion(op, target, std::move(patterns))))
      signalPassFailure();
  }
};
} // namespace

std::unique_ptr<Pass> mlir::createMemRefToLLVMPass() {
  return std::make_unique<MemRefToLLVMPass>();
}
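
// Typical usage (a sketch only, beyond the factory function above): schedule
// the pass on a module-level pipeline, e.g.
//   PassManager pm(&context);
//   pm.addPass(createMemRefToLLVMPass());
//   (void)pm.run(module);
// or run it from mlir-opt via the pass's registered command-line flag.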