//===- MemRefToLLVM.cpp - MemRef to LLVM dialect conversion --------------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//

#include "mlir/Conversion/MemRefToLLVM/MemRefToLLVM.h"
#include "../PassDetail.h"
#include "mlir/Analysis/DataLayoutAnalysis.h"
#include "mlir/Conversion/LLVMCommon/ConversionTarget.h"
#include "mlir/Conversion/LLVMCommon/Pattern.h"
#include "mlir/Conversion/LLVMCommon/TypeConverter.h"
#include "mlir/Conversion/MemRefToLLVM/AllocLikeConversion.h"
#include "mlir/Dialect/Arithmetic/IR/Arithmetic.h"
#include "mlir/Dialect/Func/IR/FuncOps.h"
#include "mlir/Dialect/LLVMIR/FunctionCallUtils.h"
#include "mlir/Dialect/LLVMIR/LLVMDialect.h"
#include "mlir/Dialect/MemRef/IR/MemRef.h"
#include "mlir/IR/AffineMap.h"
#include "mlir/IR/BlockAndValueMapping.h"
#include "llvm/ADT/SmallBitVector.h"

using namespace mlir;

namespace {

bool isStaticStrideOrOffset(int64_t strideOrOffset) {
  return !ShapedType::isDynamicStrideOrOffset(strideOrOffset);
}

struct AllocOpLowering : public AllocLikeOpLLVMLowering {
  AllocOpLowering(LLVMTypeConverter &converter)
      : AllocLikeOpLLVMLowering(memref::AllocOp::getOperationName(),
                                converter) {}

  std::tuple<Value, Value> allocateBuffer(ConversionPatternRewriter &rewriter,
                                          Location loc, Value sizeBytes,
                                          Operation *op) const override {
    // Heap allocations.
    memref::AllocOp allocOp = cast<memref::AllocOp>(op);
    MemRefType memRefType = allocOp.getType();

    Value alignment;
    if (auto alignmentAttr = allocOp.alignment()) {
      alignment = createIndexConstant(rewriter, loc, *alignmentAttr);
    } else if (!memRefType.getElementType().isSignlessIntOrIndexOrFloat()) {
      // In the case where no alignment is specified, we may want to override
      // `malloc's` behavior. `malloc` typically aligns at the size of the
      // biggest scalar on a target HW. For non-scalars, use the natural
      // alignment of the LLVM type given by the LLVM DataLayout.
      alignment = getSizeInBytes(loc, memRefType.getElementType(), rewriter);
    }

    if (alignment) {
      // Adjust the allocation size to consider alignment.
      sizeBytes = rewriter.create<LLVM::AddOp>(loc, sizeBytes, alignment);
    }

    // Allocate the underlying buffer and store a pointer to it in the MemRef
    // descriptor.
    Type elementPtrType = this->getElementPtrType(memRefType);
    auto allocFuncOp = LLVM::lookupOrCreateMallocFn(
        allocOp->getParentOfType<ModuleOp>(), getIndexType());
    auto results = createLLVMCall(rewriter, loc, allocFuncOp, {sizeBytes},
                                  getVoidPtrType());
    Value allocatedPtr =
        rewriter.create<LLVM::BitcastOp>(loc, elementPtrType, results[0]);

    Value alignedPtr = allocatedPtr;
    if (alignment) {
      // Compute the aligned type pointer.
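      // The allocated pointer is reinterpreted as an integer, rounded up to
      // the next multiple of `alignment` via createAligned, and converted
      // back to a pointer of the element type.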
      Value allocatedInt =
          rewriter.create<LLVM::PtrToIntOp>(loc, getIndexType(), allocatedPtr);
      Value alignmentInt =
          createAligned(rewriter, loc, allocatedInt, alignment);
      alignedPtr =
          rewriter.create<LLVM::IntToPtrOp>(loc, elementPtrType, alignmentInt);
    }

    return std::make_tuple(allocatedPtr, alignedPtr);
  }
};

struct AlignedAllocOpLowering : public AllocLikeOpLLVMLowering {
  AlignedAllocOpLowering(LLVMTypeConverter &converter)
      : AllocLikeOpLLVMLowering(memref::AllocOp::getOperationName(),
                                converter) {}

  /// Returns the memref's element size in bytes using the data layout active
  /// at `op`.
  // TODO: there are other places where this is used. Expose publicly?
  unsigned getMemRefEltSizeInBytes(MemRefType memRefType,
                                   Operation *op) const {
    const DataLayout *layout = &defaultLayout;
    if (const DataLayoutAnalysis *analysis =
            getTypeConverter()->getDataLayoutAnalysis()) {
      layout = &analysis->getAbove(op);
    }
    Type elementType = memRefType.getElementType();
    if (auto memRefElementType = elementType.dyn_cast<MemRefType>())
      return getTypeConverter()->getMemRefDescriptorSize(memRefElementType,
                                                         *layout);
    if (auto memRefElementType = elementType.dyn_cast<UnrankedMemRefType>())
      return getTypeConverter()->getUnrankedMemRefDescriptorSize(
          memRefElementType, *layout);
    return layout->getTypeSize(elementType);
  }

  /// Returns true if the memref size in bytes is known to be a multiple of
  /// `factor`, assuming the data layout active at `op`.
  bool isMemRefSizeMultipleOf(MemRefType type, uint64_t factor,
                              Operation *op) const {
    uint64_t sizeDivisor = getMemRefEltSizeInBytes(type, op);
    for (unsigned i = 0, e = type.getRank(); i < e; i++) {
      if (ShapedType::isDynamic(type.getDimSize(i)))
        continue;
      sizeDivisor = sizeDivisor * type.getDimSize(i);
    }
    return sizeDivisor % factor == 0;
  }

  /// Returns the alignment to be used for the allocation call itself.
  /// aligned_alloc requires the alignment to be a power of two, and the
  /// allocation size to be a multiple of that alignment.
  int64_t getAllocationAlignment(memref::AllocOp allocOp) const {
    if (Optional<uint64_t> alignment = allocOp.alignment())
      return *alignment;

    // Whenever we don't have an alignment set, we will use an alignment
    // consistent with the element type; since the alignment has to be a
    // power of two, we bump the element size to the next power of two if it
    // isn't one already.
    auto eltSizeBytes = getMemRefEltSizeInBytes(allocOp.getType(), allocOp);
    return std::max(kMinAlignedAllocAlignment,
                    llvm::PowerOf2Ceil(eltSizeBytes));
  }

  std::tuple<Value, Value> allocateBuffer(ConversionPatternRewriter &rewriter,
                                          Location loc, Value sizeBytes,
                                          Operation *op) const override {
    // Heap allocations.
    memref::AllocOp allocOp = cast<memref::AllocOp>(op);
    MemRefType memRefType = allocOp.getType();
    int64_t alignment = getAllocationAlignment(allocOp);
    Value allocAlignment = createIndexConstant(rewriter, loc, alignment);

    // aligned_alloc requires the size to be a multiple of the alignment; we
    // pad the size to the next multiple if necessary.
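    // For example, a memref with a static size of 12 bytes allocated at the
    // default 16-byte alignment gets its allocation size rounded up to 16;
    // when the static size is already a multiple of the alignment, no
    // rounding code is emitted.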
    if (!isMemRefSizeMultipleOf(memRefType, alignment, op))
      sizeBytes = createAligned(rewriter, loc, sizeBytes, allocAlignment);

    Type elementPtrType = this->getElementPtrType(memRefType);
    auto allocFuncOp = LLVM::lookupOrCreateAlignedAllocFn(
        allocOp->getParentOfType<ModuleOp>(), getIndexType());
    auto results =
        createLLVMCall(rewriter, loc, allocFuncOp, {allocAlignment, sizeBytes},
                       getVoidPtrType());
    Value allocatedPtr =
        rewriter.create<LLVM::BitcastOp>(loc, elementPtrType, results[0]);

    return std::make_tuple(allocatedPtr, allocatedPtr);
  }

  /// The minimum alignment to use with aligned_alloc (has to be a power of 2).
  static constexpr uint64_t kMinAlignedAllocAlignment = 16UL;

  /// Default layout to use in absence of the corresponding analysis.
  DataLayout defaultLayout;
};

// Out of line definition, required till C++17.
constexpr uint64_t AlignedAllocOpLowering::kMinAlignedAllocAlignment;

struct AllocaOpLowering : public AllocLikeOpLLVMLowering {
  AllocaOpLowering(LLVMTypeConverter &converter)
      : AllocLikeOpLLVMLowering(memref::AllocaOp::getOperationName(),
                                converter) {}

  /// Allocates the underlying buffer using the right call. `allocatedBytePtr`
  /// is set to null for stack allocations. `accessAlignment` is set if
  /// alignment is needed post allocation (e.g., in conjunction with malloc).
  std::tuple<Value, Value> allocateBuffer(ConversionPatternRewriter &rewriter,
                                          Location loc, Value sizeBytes,
                                          Operation *op) const override {

    // With alloca, one gets a pointer to the element type right away.
    // For stack allocations.
    auto allocaOp = cast<memref::AllocaOp>(op);
    auto elementPtrType = this->getElementPtrType(allocaOp.getType());

    auto allocatedElementPtr = rewriter.create<LLVM::AllocaOp>(
        loc, elementPtrType, sizeBytes,
        allocaOp.alignment() ? *allocaOp.alignment() : 0);

    return std::make_tuple(allocatedElementPtr, allocatedElementPtr);
  }
};

struct AllocaScopeOpLowering
    : public ConvertOpToLLVMPattern<memref::AllocaScopeOp> {
  using ConvertOpToLLVMPattern<memref::AllocaScopeOp>::ConvertOpToLLVMPattern;

  LogicalResult
  matchAndRewrite(memref::AllocaScopeOp allocaScopeOp, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    OpBuilder::InsertionGuard guard(rewriter);
    Location loc = allocaScopeOp.getLoc();

    // Split the current block before the AllocaScopeOp to create the inlining
    // point.
    auto *currentBlock = rewriter.getInsertionBlock();
    auto *remainingOpsBlock =
        rewriter.splitBlock(currentBlock, rewriter.getInsertionPoint());
    Block *continueBlock;
    if (allocaScopeOp.getNumResults() == 0) {
      continueBlock = remainingOpsBlock;
    } else {
      continueBlock = rewriter.createBlock(
          remainingOpsBlock, allocaScopeOp.getResultTypes(),
          SmallVector<Location>(allocaScopeOp->getNumResults(),
                                allocaScopeOp.getLoc()));
      rewriter.create<LLVM::BrOp>(loc, ValueRange(), remainingOpsBlock);
    }

    // Inline body region.
    Block *beforeBody = &allocaScopeOp.bodyRegion().front();
    Block *afterBody = &allocaScopeOp.bodyRegion().back();
    rewriter.inlineRegionBefore(allocaScopeOp.bodyRegion(), continueBlock);

    // Save stack and then branch into the body of the region.
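    // The inlined body is bracketed by stacksave/stackrestore so that any
    // stack allocation performed inside the scope is reclaimed when control
    // leaves it.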
    rewriter.setInsertionPointToEnd(currentBlock);
    auto stackSaveOp =
        rewriter.create<LLVM::StackSaveOp>(loc, getVoidPtrType());
    rewriter.create<LLVM::BrOp>(loc, ValueRange(), beforeBody);

    // Replace the alloca_scope return with a branch that jumps out of the
    // body. Stack restore before leaving the body region.
    rewriter.setInsertionPointToEnd(afterBody);
    auto returnOp =
        cast<memref::AllocaScopeReturnOp>(afterBody->getTerminator());
    auto branchOp = rewriter.replaceOpWithNewOp<LLVM::BrOp>(
        returnOp, returnOp.results(), continueBlock);

    // Insert stack restore before jumping out of the body of the region.
    rewriter.setInsertionPoint(branchOp);
    rewriter.create<LLVM::StackRestoreOp>(loc, stackSaveOp);

    // Replace the op with values returned from the body region.
    rewriter.replaceOp(allocaScopeOp, continueBlock->getArguments());

    return success();
  }
};

struct AssumeAlignmentOpLowering
    : public ConvertOpToLLVMPattern<memref::AssumeAlignmentOp> {
  using ConvertOpToLLVMPattern<
      memref::AssumeAlignmentOp>::ConvertOpToLLVMPattern;

  LogicalResult
  matchAndRewrite(memref::AssumeAlignmentOp op, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    Value memref = adaptor.memref();
    unsigned alignment = op.alignment();
    auto loc = op.getLoc();

    MemRefDescriptor memRefDescriptor(memref);
    Value ptr = memRefDescriptor.alignedPtr(rewriter, memref.getLoc());

    // Emit llvm.assume(memref.alignedPtr & (alignment - 1) == 0). Notice that
    // the asserted memref.alignedPtr isn't used anywhere else, as the real
    // users like load/store/views always re-extract memref.alignedPtr as they
    // get lowered.
    //
    // This relies on LLVM's CSE optimization (potentially after SROA), since
    // after CSE all memref.alignedPtr instances get de-duplicated into the
    // same pointer SSA value.
    auto intPtrType =
        getIntPtrType(memRefDescriptor.getElementPtrType().getAddressSpace());
    Value zero = createIndexAttrConstant(rewriter, loc, intPtrType, 0);
    Value mask =
        createIndexAttrConstant(rewriter, loc, intPtrType, alignment - 1);
    Value ptrValue = rewriter.create<LLVM::PtrToIntOp>(loc, intPtrType, ptr);
    rewriter.create<LLVM::AssumeOp>(
        loc, rewriter.create<LLVM::ICmpOp>(
                 loc, LLVM::ICmpPredicate::eq,
                 rewriter.create<LLVM::AndOp>(loc, ptrValue, mask), zero));

    rewriter.eraseOp(op);
    return success();
  }
};

// A `dealloc` is converted into a call to `free` on the underlying data
// buffer. The memref descriptor being an SSA value, there is no need to clean
// it up in any way.
struct DeallocOpLowering : public ConvertOpToLLVMPattern<memref::DeallocOp> {
  using ConvertOpToLLVMPattern<memref::DeallocOp>::ConvertOpToLLVMPattern;

  explicit DeallocOpLowering(LLVMTypeConverter &converter)
      : ConvertOpToLLVMPattern<memref::DeallocOp>(converter) {}

  LogicalResult
  matchAndRewrite(memref::DeallocOp op, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    // Insert the `free` declaration if it is not already present.
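    // Note that `free` must be passed the allocated (not the aligned)
    // pointer, i.e. the pointer originally returned by the allocation
    // function.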
    auto freeFunc =
        LLVM::lookupOrCreateFreeFn(op->getParentOfType<ModuleOp>());
    MemRefDescriptor memref(adaptor.memref());
    Value casted = rewriter.create<LLVM::BitcastOp>(
        op.getLoc(), getVoidPtrType(),
        memref.allocatedPtr(rewriter, op.getLoc()));
    rewriter.replaceOpWithNewOp<LLVM::CallOp>(
        op, TypeRange(), SymbolRefAttr::get(freeFunc), casted);
    return success();
  }
};

// A `dim` is converted to a constant for static sizes and to an access to the
// size stored in the memref descriptor for dynamic sizes.
struct DimOpLowering : public ConvertOpToLLVMPattern<memref::DimOp> {
  using ConvertOpToLLVMPattern<memref::DimOp>::ConvertOpToLLVMPattern;

  LogicalResult
  matchAndRewrite(memref::DimOp dimOp, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    Type operandType = dimOp.source().getType();
    if (operandType.isa<UnrankedMemRefType>()) {
      rewriter.replaceOp(
          dimOp, {extractSizeOfUnrankedMemRef(
                     operandType, dimOp, adaptor.getOperands(), rewriter)});

      return success();
    }
    if (operandType.isa<MemRefType>()) {
      rewriter.replaceOp(
          dimOp, {extractSizeOfRankedMemRef(operandType, dimOp,
                                            adaptor.getOperands(), rewriter)});
      return success();
    }
    llvm_unreachable("expected MemRefType or UnrankedMemRefType");
  }

private:
  Value extractSizeOfUnrankedMemRef(Type operandType, memref::DimOp dimOp,
                                    OpAdaptor adaptor,
                                    ConversionPatternRewriter &rewriter) const {
    Location loc = dimOp.getLoc();

    auto unrankedMemRefType = operandType.cast<UnrankedMemRefType>();
    auto scalarMemRefType =
        MemRefType::get({}, unrankedMemRefType.getElementType());
    unsigned addressSpace = unrankedMemRefType.getMemorySpaceAsInt();

    // Extract pointer to the underlying ranked descriptor and bitcast it to a
    // memref<element_type> descriptor pointer to minimize the number of GEP
    // operations.
    UnrankedMemRefDescriptor unrankedDesc(adaptor.source());
    Value underlyingRankedDesc = unrankedDesc.memRefDescPtr(rewriter, loc);
    Value scalarMemRefDescPtr = rewriter.create<LLVM::BitcastOp>(
        loc,
        LLVM::LLVMPointerType::get(
            typeConverter->convertType(scalarMemRefType), addressSpace),
        underlyingRankedDesc);

    // Get pointer to offset field of memref<element_type> descriptor.
    Type indexPtrTy = LLVM::LLVMPointerType::get(
        getTypeConverter()->getIndexType(), addressSpace);
    Value two = rewriter.create<LLVM::ConstantOp>(
        loc, typeConverter->convertType(rewriter.getI32Type()),
        rewriter.getI32IntegerAttr(2));
    Value offsetPtr = rewriter.create<LLVM::GEPOp>(
        loc, indexPtrTy, scalarMemRefDescPtr,
        ValueRange({createIndexConstant(rewriter, loc, 0), two}));

    // The size value that we have to extract can be obtained using a GEPOp
    // with `dimOp.index() + 1` as the index argument.
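    // In the converted descriptor, the sizes array immediately follows the
    // offset field, so stepping `dimOp.index() + 1` index-typed elements past
    // `offsetPtr` yields a pointer to sizes[index].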
    Value idxPlusOne = rewriter.create<LLVM::AddOp>(
        loc, createIndexConstant(rewriter, loc, 1), adaptor.index());
    Value sizePtr = rewriter.create<LLVM::GEPOp>(loc, indexPtrTy, offsetPtr,
                                                 ValueRange({idxPlusOne}));
    return rewriter.create<LLVM::LoadOp>(loc, sizePtr);
  }

  Optional<int64_t> getConstantDimIndex(memref::DimOp dimOp) const {
    if (Optional<int64_t> idx = dimOp.getConstantIndex())
      return idx;

    if (auto constantOp = dimOp.index().getDefiningOp<LLVM::ConstantOp>())
      return constantOp.getValue()
          .cast<IntegerAttr>()
          .getValue()
          .getSExtValue();

    return llvm::None;
  }

  Value extractSizeOfRankedMemRef(Type operandType, memref::DimOp dimOp,
                                  OpAdaptor adaptor,
                                  ConversionPatternRewriter &rewriter) const {
    Location loc = dimOp.getLoc();

    // Take advantage if index is constant.
    MemRefType memRefType = operandType.cast<MemRefType>();
    if (Optional<int64_t> index = getConstantDimIndex(dimOp)) {
      int64_t i = index.getValue();
      if (memRefType.isDynamicDim(i)) {
        // extract dynamic size from the memref descriptor.
        MemRefDescriptor descriptor(adaptor.source());
        return descriptor.size(rewriter, loc, i);
      }
      // Use constant for static size.
      int64_t dimSize = memRefType.getDimSize(i);
      return createIndexConstant(rewriter, loc, dimSize);
    }
    Value index = adaptor.index();
    int64_t rank = memRefType.getRank();
    MemRefDescriptor memrefDescriptor(adaptor.source());
    return memrefDescriptor.size(rewriter, loc, index, rank);
  }
};

/// Common base for load and store operations on MemRefs. Restricts the match
/// to supported MemRef types. Provides functionality to emit code accessing a
/// specific element of the underlying data buffer.
template <typename Derived>
struct LoadStoreOpLowering : public ConvertOpToLLVMPattern<Derived> {
  using ConvertOpToLLVMPattern<Derived>::ConvertOpToLLVMPattern;
  using ConvertOpToLLVMPattern<Derived>::isConvertibleAndHasIdentityMaps;
  using Base = LoadStoreOpLowering<Derived>;

  LogicalResult match(Derived op) const override {
    MemRefType type = op.getMemRefType();
    return isConvertibleAndHasIdentityMaps(type) ? success() : failure();
  }
};

/// Wrap a llvm.cmpxchg operation in a while loop so that the operation can be
/// retried until it succeeds in atomically storing a new value into memory.
///
///      +---------------------------------+
///      |   <code before the AtomicRMWOp> |
///      |   <compute initial %loaded>     |
///      |   cf.br loop(%loaded)           |
///      +---------------------------------+
///             |
///  -------|   |
///  |      v   v
///  |   +--------------------------------+
///  |   |   loop(%loaded):               |
///  |   |     <body contents>            |
///  |   |     %pair = cmpxchg            |
///  |   |     %new = %pair[0]            |
///  |   |     %ok = %pair[1]             |
///  |   |     cf.cond_br %ok, end, loop(%new)
///  |   +--------------------------------+
///  |          |         |
///  |-----------         |
///                       v
///      +--------------------------------+
///      |   end:                         |
///      |   <code after the AtomicRMWOp> |
///      +--------------------------------+
///
struct GenericAtomicRMWOpLowering
    : public LoadStoreOpLowering<memref::GenericAtomicRMWOp> {
  using Base::Base;

  LogicalResult
  matchAndRewrite(memref::GenericAtomicRMWOp atomicOp, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    auto loc = atomicOp.getLoc();
    Type valueType =
        typeConverter->convertType(atomicOp.getResult().getType());

    // Split the block into initial, loop, and ending parts.
    auto *initBlock = rewriter.getInsertionBlock();
    auto *loopBlock = rewriter.createBlock(
        initBlock->getParent(), std::next(Region::iterator(initBlock)),
        valueType, loc);
    auto *endBlock = rewriter.createBlock(
        loopBlock->getParent(), std::next(Region::iterator(loopBlock)));

    // Operations range to be moved to `endBlock`.
    auto opsToMoveStart = atomicOp->getIterator();
    auto opsToMoveEnd = initBlock->back().getIterator();

    // Compute the loaded value and branch to the loop block.
    rewriter.setInsertionPointToEnd(initBlock);
    auto memRefType = atomicOp.memref().getType().cast<MemRefType>();
    auto dataPtr = getStridedElementPtr(loc, memRefType, adaptor.memref(),
                                        adaptor.indices(), rewriter);
    Value init = rewriter.create<LLVM::LoadOp>(loc, dataPtr);
    rewriter.create<LLVM::BrOp>(loc, init, loopBlock);

    // Prepare the body of the loop block.
    rewriter.setInsertionPointToStart(loopBlock);

    // Clone the GenericAtomicRMWOp region and extract the result.
    auto loopArgument = loopBlock->getArgument(0);
    BlockAndValueMapping mapping;
    mapping.map(atomicOp.getCurrentValue(), loopArgument);
    Block &entryBlock = atomicOp.body().front();
    for (auto &nestedOp : entryBlock.without_terminator()) {
      Operation *clone = rewriter.clone(nestedOp, mapping);
      mapping.map(nestedOp.getResults(), clone->getResults());
    }
    Value result = mapping.lookup(entryBlock.getTerminator()->getOperand(0));

    // Prepare the epilog of the loop block.
    // Append the cmpxchg op to the end of the loop block.
    auto successOrdering = LLVM::AtomicOrdering::acq_rel;
    auto failureOrdering = LLVM::AtomicOrdering::monotonic;
    auto boolType = IntegerType::get(rewriter.getContext(), 1);
    auto pairType = LLVM::LLVMStructType::getLiteral(rewriter.getContext(),
                                                     {valueType, boolType});
    auto cmpxchg = rewriter.create<LLVM::AtomicCmpXchgOp>(
        loc, pairType, dataPtr, loopArgument, result, successOrdering,
        failureOrdering);
    // Extract the %new_loaded and %ok values from the pair.
    Value newLoaded = rewriter.create<LLVM::ExtractValueOp>(
        loc, valueType, cmpxchg, rewriter.getI64ArrayAttr({0}));
    Value ok = rewriter.create<LLVM::ExtractValueOp>(
        loc, boolType, cmpxchg, rewriter.getI64ArrayAttr({1}));

    // Conditionally branch to the end or back to the loop depending on %ok.
    rewriter.create<LLVM::CondBrOp>(loc, ok, endBlock, ArrayRef<Value>(),
                                    loopBlock, newLoaded);

    rewriter.setInsertionPointToEnd(endBlock);
    moveOpsRange(atomicOp.getResult(), newLoaded, std::next(opsToMoveStart),
                 std::next(opsToMoveEnd), rewriter);

    // The 'result' of the atomic_rmw op is the newly loaded value.
    rewriter.replaceOp(atomicOp, {newLoaded});

    return success();
  }

private:
  // Clones a segment of ops [start, end) and erases the original.
  void moveOpsRange(ValueRange oldResult, ValueRange newResult,
                    Block::iterator start, Block::iterator end,
                    ConversionPatternRewriter &rewriter) const {
    BlockAndValueMapping mapping;
    mapping.map(oldResult, newResult);
    SmallVector<Operation *, 2> opsToErase;
    for (auto it = start; it != end; ++it) {
      rewriter.clone(*it, mapping);
      opsToErase.push_back(&*it);
    }
    for (auto *it : opsToErase)
      rewriter.eraseOp(it);
  }
};

/// Returns the LLVM type of the global variable given the memref type `type`.
static Type convertGlobalMemrefTypeToLLVM(MemRefType type,
                                          LLVMTypeConverter &typeConverter) {
  // LLVM type for a global memref will be a multi-dimension array. For
  // declarations or uninitialized global memrefs, we can potentially flatten
  // this to a 1D array. However, for memref.global's with an initial value,
  // we do not intend to flatten the ElementsAttribute when going from std ->
  // LLVM dialect, so the LLVM type needs to be a multi-dimension array.
  Type elementType = typeConverter.convertType(type.getElementType());
  Type arrayTy = elementType;
  // Shape has the outermost dim at index 0, so need to walk it backwards.
  for (int64_t dim : llvm::reverse(type.getShape()))
    arrayTy = LLVM::LLVMArrayType::get(arrayTy, dim);
  return arrayTy;
}

/// GlobalMemrefOp is lowered to a LLVM Global Variable.
struct GlobalMemrefOpLowering
    : public ConvertOpToLLVMPattern<memref::GlobalOp> {
  using ConvertOpToLLVMPattern<memref::GlobalOp>::ConvertOpToLLVMPattern;

  LogicalResult
  matchAndRewrite(memref::GlobalOp global, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    MemRefType type = global.type();
    if (!isConvertibleAndHasIdentityMaps(type))
      return failure();

    Type arrayTy = convertGlobalMemrefTypeToLLVM(type, *getTypeConverter());

    LLVM::Linkage linkage =
        global.isPublic() ? LLVM::Linkage::External : LLVM::Linkage::Private;

    Attribute initialValue = nullptr;
    if (!global.isExternal() && !global.isUninitialized()) {
      auto elementsAttr = global.initial_value()->cast<ElementsAttr>();
      initialValue = elementsAttr;

      // For scalar memrefs, the global variable created is of the element
      // type, so unpack the elements attribute to extract the value.
      if (type.getRank() == 0)
        initialValue = elementsAttr.getSplatValue<Attribute>();
    }

    uint64_t alignment = global.alignment().value_or(0);

    auto newGlobal = rewriter.replaceOpWithNewOp<LLVM::GlobalOp>(
        global, arrayTy, global.constant(), linkage, global.sym_name(),
        initialValue, alignment, type.getMemorySpaceAsInt());
    if (!global.isExternal() && global.isUninitialized()) {
      Block *blk = new Block();
      newGlobal.getInitializerRegion().push_back(blk);
      rewriter.setInsertionPointToStart(blk);
      Value undef[] = {
          rewriter.create<LLVM::UndefOp>(global.getLoc(), arrayTy)};
      rewriter.create<LLVM::ReturnOp>(global.getLoc(), undef);
    }
    return success();
  }
};

/// GetGlobalMemrefOp is lowered into a Memref descriptor with the pointer to
/// the first element stashed into the descriptor. This reuses
/// `AllocLikeOpLowering` to reuse the Memref descriptor construction.
struct GetGlobalMemrefOpLowering : public AllocLikeOpLLVMLowering {
  GetGlobalMemrefOpLowering(LLVMTypeConverter &converter)
      : AllocLikeOpLLVMLowering(memref::GetGlobalOp::getOperationName(),
                                converter) {}

  /// Buffer "allocation" for memref.get_global op is getting the address of
  /// the global variable referenced.
  std::tuple<Value, Value> allocateBuffer(ConversionPatternRewriter &rewriter,
                                          Location loc, Value sizeBytes,
                                          Operation *op) const override {
    auto getGlobalOp = cast<memref::GetGlobalOp>(op);
    MemRefType type = getGlobalOp.result().getType().cast<MemRefType>();
    unsigned memSpace = type.getMemorySpaceAsInt();

    Type arrayTy = convertGlobalMemrefTypeToLLVM(type, *getTypeConverter());
    auto addressOf = rewriter.create<LLVM::AddressOfOp>(
        loc, LLVM::LLVMPointerType::get(arrayTy, memSpace),
        getGlobalOp.name());

    // Get the address of the first element in the array by creating a GEP
    // with the address of the GV as the base, and (rank + 1) number of 0
    // indices.
    Type elementType = typeConverter->convertType(type.getElementType());
    Type elementPtrType = LLVM::LLVMPointerType::get(elementType, memSpace);

    SmallVector<Value> operands;
    operands.insert(operands.end(), type.getRank() + 1,
                    createIndexConstant(rewriter, loc, 0));
    auto gep =
        rewriter.create<LLVM::GEPOp>(loc, elementPtrType, addressOf, operands);

    // We do not expect the memref obtained using `memref.get_global` to be
    // ever deallocated. Set the allocated pointer to a known bad value to
    // help debug if that ever happens.
    auto intPtrType = getIntPtrType(memSpace);
    Value deadBeefConst = createIndexAttrConstant(rewriter, op->getLoc(),
                                                  intPtrType, 0xdeadbeef);
    auto deadBeefPtr =
        rewriter.create<LLVM::IntToPtrOp>(loc, elementPtrType, deadBeefConst);

    // Both allocated and aligned pointers are the same. We could potentially
    // stash a nullptr for the allocated pointer since we do not expect any
    // dealloc.
    return std::make_tuple(deadBeefPtr, gep);
  }
};

// Load operation is lowered to obtaining a pointer to the indexed element
// and loading it.
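// For example, for a 2-D memref the element address is computed (via
// getStridedElementPtr) as alignedPtr + offset + i0 * stride0 + i1 * stride1,
// and the result feeds the llvm.load.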
struct LoadOpLowering : public LoadStoreOpLowering<memref::LoadOp> {
  using Base::Base;

  LogicalResult
  matchAndRewrite(memref::LoadOp loadOp, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    auto type = loadOp.getMemRefType();

    Value dataPtr = getStridedElementPtr(
        loadOp.getLoc(), type, adaptor.memref(), adaptor.indices(), rewriter);
    rewriter.replaceOpWithNewOp<LLVM::LoadOp>(loadOp, dataPtr);
    return success();
  }
};

// Store operation is lowered to obtaining a pointer to the indexed element,
// and storing the given value to it.
struct StoreOpLowering : public LoadStoreOpLowering<memref::StoreOp> {
  using Base::Base;

  LogicalResult
  matchAndRewrite(memref::StoreOp op, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    auto type = op.getMemRefType();

    Value dataPtr = getStridedElementPtr(op.getLoc(), type, adaptor.memref(),
                                         adaptor.indices(), rewriter);
    rewriter.replaceOpWithNewOp<LLVM::StoreOp>(op, adaptor.value(), dataPtr);
    return success();
  }
};

// The prefetch operation is lowered in a way similar to the load operation
// except that the llvm.prefetch operation is used for replacement.
struct PrefetchOpLowering : public LoadStoreOpLowering<memref::PrefetchOp> {
  using Base::Base;

  LogicalResult
  matchAndRewrite(memref::PrefetchOp prefetchOp, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    auto type = prefetchOp.getMemRefType();
    auto loc = prefetchOp.getLoc();

    Value dataPtr = getStridedElementPtr(loc, type, adaptor.memref(),
                                         adaptor.indices(), rewriter);

    // Replace with llvm.prefetch.
    auto llvmI32Type = typeConverter->convertType(rewriter.getIntegerType(32));
    auto isWrite = rewriter.create<LLVM::ConstantOp>(
        loc, llvmI32Type, rewriter.getI32IntegerAttr(prefetchOp.isWrite()));
    auto localityHint = rewriter.create<LLVM::ConstantOp>(
        loc, llvmI32Type,
        rewriter.getI32IntegerAttr(prefetchOp.localityHint()));
    auto isData = rewriter.create<LLVM::ConstantOp>(
        loc, llvmI32Type,
        rewriter.getI32IntegerAttr(prefetchOp.isDataCache()));

    rewriter.replaceOpWithNewOp<LLVM::Prefetch>(prefetchOp, dataPtr, isWrite,
                                                localityHint, isData);
    return success();
  }
};

struct RankOpLowering : public ConvertOpToLLVMPattern<memref::RankOp> {
  using ConvertOpToLLVMPattern<memref::RankOp>::ConvertOpToLLVMPattern;

  LogicalResult
  matchAndRewrite(memref::RankOp op, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    Location loc = op.getLoc();
    Type operandType = op.memref().getType();
    if (auto unrankedMemRefType = operandType.dyn_cast<UnrankedMemRefType>()) {
      UnrankedMemRefDescriptor desc(adaptor.memref());
      rewriter.replaceOp(op, {desc.rank(rewriter, loc)});
      return success();
    }
    if (auto rankedMemRefType = operandType.dyn_cast<MemRefType>()) {
      rewriter.replaceOp(
          op, {createIndexConstant(rewriter, loc, rankedMemRefType.getRank())});
      return success();
    }
    return failure();
  }
};

struct MemRefCastOpLowering : public ConvertOpToLLVMPattern<memref::CastOp> {
  using ConvertOpToLLVMPattern<memref::CastOp>::ConvertOpToLLVMPattern;

  LogicalResult match(memref::CastOp memRefCastOp) const override {
    Type srcType = memRefCastOp.getOperand().getType();
    Type dstType = memRefCastOp.getType();

    // memref::CastOp reduces to bitcast in the ranked MemRef case and can be
    // used for type erasure. For now they must preserve underlying element
    // type and require source and result type to have the same rank.
    // Therefore, perform a sanity check that the underlying structs are the
    // same. Once op semantics are relaxed we can revisit.
    if (srcType.isa<MemRefType>() && dstType.isa<MemRefType>())
      return success(typeConverter->convertType(srcType) ==
                     typeConverter->convertType(dstType));

    // At least one of the operands is an unranked type.
    assert(srcType.isa<UnrankedMemRefType>() ||
           dstType.isa<UnrankedMemRefType>());

    // Unranked to unranked cast is disallowed.
    return !(srcType.isa<UnrankedMemRefType>() &&
             dstType.isa<UnrankedMemRefType>())
               ? success()
               : failure();
  }

  void rewrite(memref::CastOp memRefCastOp, OpAdaptor adaptor,
               ConversionPatternRewriter &rewriter) const override {
    auto srcType = memRefCastOp.getOperand().getType();
    auto dstType = memRefCastOp.getType();
    auto targetStructType = typeConverter->convertType(memRefCastOp.getType());
    auto loc = memRefCastOp.getLoc();

    // For ranked/ranked case, just keep the original descriptor.
    if (srcType.isa<MemRefType>() && dstType.isa<MemRefType>())
      return rewriter.replaceOp(memRefCastOp, {adaptor.source()});

    if (srcType.isa<MemRefType>() && dstType.isa<UnrankedMemRefType>()) {
      // Casting ranked to unranked memref type:
      // set the rank in the destination from the memref type, allocate space
      // on the stack and copy the src memref descriptor, then set the ptr in
      // the destination to the stack space.
      auto srcMemRefType = srcType.cast<MemRefType>();
      int64_t rank = srcMemRefType.getRank();
      // ptr = AllocaOp sizeof(MemRefDescriptor)
      auto ptr = getTypeConverter()->promoteOneMemRefDescriptor(
          loc, adaptor.source(), rewriter);
      // voidptr = BitCastOp srcType* to void*
      auto voidPtr =
          rewriter.create<LLVM::BitcastOp>(loc, getVoidPtrType(), ptr)
              .getResult();
      // rank = ConstantOp srcRank
      auto rankVal = rewriter.create<LLVM::ConstantOp>(
          loc, getIndexType(), rewriter.getIndexAttr(rank));
      // undef = UndefOp
      UnrankedMemRefDescriptor memRefDesc =
          UnrankedMemRefDescriptor::undef(rewriter, loc, targetStructType);
      // d1 = InsertValueOp undef, rank, 0
      memRefDesc.setRank(rewriter, loc, rankVal);
      // d2 = InsertValueOp d1, voidptr, 1
      memRefDesc.setMemRefDescPtr(rewriter, loc, voidPtr);
      rewriter.replaceOp(memRefCastOp, (Value)memRefDesc);

    } else if (srcType.isa<UnrankedMemRefType>() &&
               dstType.isa<MemRefType>()) {
      // Casting from unranked type to ranked.
      // The operation is assumed to be doing a correct cast. If the
      // destination type mismatches the unranked type, it is undefined
      // behavior.
      UnrankedMemRefDescriptor memRefDesc(adaptor.source());
      // ptr = ExtractValueOp src, 1
      auto ptr = memRefDesc.memRefDescPtr(rewriter, loc);
      // castPtr = BitCastOp i8* to structTy*
      auto castPtr =
          rewriter
              .create<LLVM::BitcastOp>(
                  loc, LLVM::LLVMPointerType::get(targetStructType), ptr)
              .getResult();
      // struct = LoadOp castPtr
      auto loadOp = rewriter.create<LLVM::LoadOp>(loc, castPtr);
      rewriter.replaceOp(memRefCastOp, loadOp.getResult());
    } else {
      llvm_unreachable("Unsupported unranked memref to unranked memref cast");
    }
  }
};

/// Pattern to lower a `memref.copy` to llvm.
///
/// For memrefs with identity layouts, the copy is lowered to the llvm
/// `memcpy` intrinsic. For non-identity layouts, the copy is lowered to a
/// call to the generic `MemrefCopyFn`.
struct MemRefCopyOpLowering : public ConvertOpToLLVMPattern<memref::CopyOp> {
  using ConvertOpToLLVMPattern<memref::CopyOp>::ConvertOpToLLVMPattern;

  LogicalResult
  lowerToMemCopyIntrinsic(memref::CopyOp op, OpAdaptor adaptor,
                          ConversionPatternRewriter &rewriter) const {
    auto loc = op.getLoc();
    auto srcType = op.source().getType().dyn_cast<MemRefType>();

    MemRefDescriptor srcDesc(adaptor.source());

    // Compute number of elements.
    Value numElements = rewriter.create<LLVM::ConstantOp>(
        loc, getIndexType(), rewriter.getIndexAttr(1));
    for (int pos = 0; pos < srcType.getRank(); ++pos) {
      auto size = srcDesc.size(rewriter, loc, pos);
      numElements = rewriter.create<LLVM::MulOp>(loc, numElements, size);
    }

    // Get element size.
    auto sizeInBytes = getSizeInBytes(loc, srcType.getElementType(), rewriter);
    // Compute total.
    Value totalSize =
        rewriter.create<LLVM::MulOp>(loc, numElements, sizeInBytes);

    Value srcBasePtr = srcDesc.alignedPtr(rewriter, loc);
    Value srcOffset = srcDesc.offset(rewriter, loc);
    Value srcPtr = rewriter.create<LLVM::GEPOp>(loc, srcBasePtr.getType(),
                                                srcBasePtr, srcOffset);
    MemRefDescriptor targetDesc(adaptor.target());
    Value targetBasePtr = targetDesc.alignedPtr(rewriter, loc);
    Value targetOffset = targetDesc.offset(rewriter, loc);
    Value targetPtr = rewriter.create<LLVM::GEPOp>(loc, targetBasePtr.getType(),
                                                   targetBasePtr, targetOffset);
    Value isVolatile = rewriter.create<LLVM::ConstantOp>(
        loc, typeConverter->convertType(rewriter.getI1Type()),
        rewriter.getBoolAttr(false));
    rewriter.create<LLVM::MemcpyOp>(loc, targetPtr, srcPtr, totalSize,
                                    isVolatile);
    rewriter.eraseOp(op);

    return success();
  }

  LogicalResult
  lowerToMemCopyFunctionCall(memref::CopyOp op, OpAdaptor adaptor,
                             ConversionPatternRewriter &rewriter) const {
    auto loc = op.getLoc();
    auto srcType = op.source().getType().cast<BaseMemRefType>();
    auto targetType = op.target().getType().cast<BaseMemRefType>();

    // First make sure we have an unranked memref descriptor representation.
    auto makeUnranked = [&, this](Value ranked, BaseMemRefType type) {
      auto rank = rewriter.create<LLVM::ConstantOp>(
          loc, getIndexType(), rewriter.getIndexAttr(type.getRank()));
      auto *typeConverter = getTypeConverter();
      auto ptr =
          typeConverter->promoteOneMemRefDescriptor(loc, ranked, rewriter);
      auto voidPtr =
          rewriter.create<LLVM::BitcastOp>(loc, getVoidPtrType(), ptr)
              .getResult();
      auto unrankedType =
          UnrankedMemRefType::get(type.getElementType(), type.getMemorySpace());
      return UnrankedMemRefDescriptor::pack(rewriter, loc, *typeConverter,
                                            unrankedType,
                                            ValueRange{rank, voidPtr});
    };

    Value unrankedSource = srcType.hasRank()
                               ? makeUnranked(adaptor.source(), srcType)
                               : adaptor.source();
    Value unrankedTarget = targetType.hasRank()
                               ? makeUnranked(adaptor.target(), targetType)
                               : adaptor.target();

    // Now promote the unranked descriptors to the stack.
    auto one = rewriter.create<LLVM::ConstantOp>(loc, getIndexType(),
                                                 rewriter.getIndexAttr(1));
    auto promote = [&](Value desc) {
      auto ptrType = LLVM::LLVMPointerType::get(desc.getType());
      auto allocated =
          rewriter.create<LLVM::AllocaOp>(loc, ptrType, ValueRange{one});
      rewriter.create<LLVM::StoreOp>(loc, desc, allocated);
      return allocated;
    };

    auto sourcePtr = promote(unrankedSource);
    auto targetPtr = promote(unrankedTarget);

    unsigned typeSize =
        mlir::DataLayout::closest(op).getTypeSize(srcType.getElementType());
    auto elemSize = rewriter.create<LLVM::ConstantOp>(
        loc, getIndexType(), rewriter.getIndexAttr(typeSize));
    auto copyFn = LLVM::lookupOrCreateMemRefCopyFn(
        op->getParentOfType<ModuleOp>(), getIndexType(), sourcePtr.getType());
    rewriter.create<LLVM::CallOp>(loc, copyFn,
                                  ValueRange{elemSize, sourcePtr, targetPtr});
    rewriter.eraseOp(op);

    return success();
  }

  LogicalResult
  matchAndRewrite(memref::CopyOp op, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    auto srcType = op.source().getType().cast<BaseMemRefType>();
    auto targetType = op.target().getType().cast<BaseMemRefType>();

    auto isContiguousMemrefType = [](BaseMemRefType type) {
      auto memrefType = type.dyn_cast<mlir::MemRefType>();
      // We can use memcpy for memrefs if they have an identity layout or are
      // contiguous with an arbitrary offset. Ignore empty memrefs, which is a
      // special case handled by memrefCopy.
      return memrefType &&
             (memrefType.getLayout().isIdentity() ||
              (memrefType.hasStaticShape() && memrefType.getNumElements() > 0 &&
               isStaticShapeAndContiguousRowMajor(memrefType)));
    };

    if (isContiguousMemrefType(srcType) && isContiguousMemrefType(targetType))
      return lowerToMemCopyIntrinsic(op, adaptor, rewriter);

    return lowerToMemCopyFunctionCall(op, adaptor, rewriter);
  }
};

/// Extracts allocated, aligned pointers and offset from a ranked or unranked
/// memref type. In unranked case, the fields are extracted from the underlying
/// ranked descriptor.
static void extractPointersAndOffset(Location loc,
                                     ConversionPatternRewriter &rewriter,
                                     LLVMTypeConverter &typeConverter,
                                     Value originalOperand,
                                     Value convertedOperand,
                                     Value *allocatedPtr, Value *alignedPtr,
                                     Value *offset = nullptr) {
  Type operandType = originalOperand.getType();
  if (operandType.isa<MemRefType>()) {
    MemRefDescriptor desc(convertedOperand);
    *allocatedPtr = desc.allocatedPtr(rewriter, loc);
    *alignedPtr = desc.alignedPtr(rewriter, loc);
    if (offset != nullptr)
      *offset = desc.offset(rewriter, loc);
    return;
  }

  unsigned memorySpace =
      operandType.cast<UnrankedMemRefType>().getMemorySpaceAsInt();
  Type elementType = operandType.cast<UnrankedMemRefType>().getElementType();
  Type llvmElementType = typeConverter.convertType(elementType);
  Type elementPtrPtrType = LLVM::LLVMPointerType::get(
      LLVM::LLVMPointerType::get(llvmElementType, memorySpace));

  // Extract pointer to the underlying ranked memref descriptor and cast it to
  // ElemType**.
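  // An unranked descriptor only holds a rank and an opaque pointer to the
  // ranked descriptor, so the pointers and the offset must be read through
  // that pointer.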
  UnrankedMemRefDescriptor unrankedDesc(convertedOperand);
  Value underlyingDescPtr = unrankedDesc.memRefDescPtr(rewriter, loc);

  *allocatedPtr = UnrankedMemRefDescriptor::allocatedPtr(
      rewriter, loc, underlyingDescPtr, elementPtrPtrType);
  *alignedPtr = UnrankedMemRefDescriptor::alignedPtr(
      rewriter, loc, typeConverter, underlyingDescPtr, elementPtrPtrType);
  if (offset != nullptr) {
    *offset = UnrankedMemRefDescriptor::offset(
        rewriter, loc, typeConverter, underlyingDescPtr, elementPtrPtrType);
  }
}

struct MemRefReinterpretCastOpLowering
    : public ConvertOpToLLVMPattern<memref::ReinterpretCastOp> {
  using ConvertOpToLLVMPattern<
      memref::ReinterpretCastOp>::ConvertOpToLLVMPattern;

  LogicalResult
  matchAndRewrite(memref::ReinterpretCastOp castOp, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    Type srcType = castOp.source().getType();

    Value descriptor;
    if (failed(convertSourceMemRefToDescriptor(rewriter, srcType, castOp,
                                               adaptor, &descriptor)))
      return failure();
    rewriter.replaceOp(castOp, {descriptor});
    return success();
  }

private:
  LogicalResult convertSourceMemRefToDescriptor(
      ConversionPatternRewriter &rewriter, Type srcType,
      memref::ReinterpretCastOp castOp,
      memref::ReinterpretCastOp::Adaptor adaptor, Value *descriptor) const {
    MemRefType targetMemRefType =
        castOp.getResult().getType().cast<MemRefType>();
    auto llvmTargetDescriptorTy = typeConverter->convertType(targetMemRefType)
                                      .dyn_cast_or_null<LLVM::LLVMStructType>();
    if (!llvmTargetDescriptorTy)
      return failure();

    // Create descriptor.
    Location loc = castOp.getLoc();
    auto desc = MemRefDescriptor::undef(rewriter, loc, llvmTargetDescriptorTy);

    // Set allocated and aligned pointers.
    Value allocatedPtr, alignedPtr;
    extractPointersAndOffset(loc, rewriter, *getTypeConverter(),
                             castOp.source(), adaptor.source(), &allocatedPtr,
                             &alignedPtr);
    desc.setAllocatedPtr(rewriter, loc, allocatedPtr);
    desc.setAlignedPtr(rewriter, loc, alignedPtr);

    // Set offset.
    if (castOp.isDynamicOffset(0))
      desc.setOffset(rewriter, loc, adaptor.offsets()[0]);
    else
      desc.setConstantOffset(rewriter, loc, castOp.getStaticOffset(0));

    // Set sizes and strides.
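    // Dynamic sizes and strides consume the converted SSA operands in order;
    // static ones are materialized as constants from the op's attributes.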
    unsigned dynSizeId = 0;
    unsigned dynStrideId = 0;
    for (unsigned i = 0, e = targetMemRefType.getRank(); i < e; ++i) {
      if (castOp.isDynamicSize(i))
        desc.setSize(rewriter, loc, i, adaptor.sizes()[dynSizeId++]);
      else
        desc.setConstantSize(rewriter, loc, i, castOp.getStaticSize(i));

      if (castOp.isDynamicStride(i))
        desc.setStride(rewriter, loc, i, adaptor.strides()[dynStrideId++]);
      else
        desc.setConstantStride(rewriter, loc, i, castOp.getStaticStride(i));
    }
    *descriptor = desc;
    return success();
  }
};

struct MemRefReshapeOpLowering
    : public ConvertOpToLLVMPattern<memref::ReshapeOp> {
  using ConvertOpToLLVMPattern<memref::ReshapeOp>::ConvertOpToLLVMPattern;

  LogicalResult
  matchAndRewrite(memref::ReshapeOp reshapeOp, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    Type srcType = reshapeOp.source().getType();

    Value descriptor;
    if (failed(convertSourceMemRefToDescriptor(rewriter, srcType, reshapeOp,
                                               adaptor, &descriptor)))
      return failure();
    rewriter.replaceOp(reshapeOp, {descriptor});
    return success();
  }

private:
  LogicalResult
  convertSourceMemRefToDescriptor(ConversionPatternRewriter &rewriter,
                                  Type srcType, memref::ReshapeOp reshapeOp,
                                  memref::ReshapeOp::Adaptor adaptor,
                                  Value *descriptor) const {
    auto shapeMemRefType = reshapeOp.shape().getType().cast<MemRefType>();
    if (shapeMemRefType.hasStaticShape()) {
      MemRefType targetMemRefType =
          reshapeOp.getResult().getType().cast<MemRefType>();
      auto llvmTargetDescriptorTy =
          typeConverter->convertType(targetMemRefType)
              .dyn_cast_or_null<LLVM::LLVMStructType>();
      if (!llvmTargetDescriptorTy)
        return failure();

      // Create descriptor.
      Location loc = reshapeOp.getLoc();
      auto desc =
          MemRefDescriptor::undef(rewriter, loc, llvmTargetDescriptorTy);

      // Set allocated and aligned pointers.
      Value allocatedPtr, alignedPtr;
      extractPointersAndOffset(loc, rewriter, *getTypeConverter(),
                               reshapeOp.source(), adaptor.source(),
                               &allocatedPtr, &alignedPtr);
      desc.setAllocatedPtr(rewriter, loc, allocatedPtr);
      desc.setAlignedPtr(rewriter, loc, alignedPtr);

      // Extract the offset and strides from the type.
      int64_t offset;
      SmallVector<int64_t> strides;
      if (failed(getStridesAndOffset(targetMemRefType, strides, offset)))
        return rewriter.notifyMatchFailure(
            reshapeOp, "failed to get stride and offset exprs");

      if (!isStaticStrideOrOffset(offset))
        return rewriter.notifyMatchFailure(reshapeOp,
                                           "dynamic offset is unsupported");

      desc.setConstantOffset(rewriter, loc, offset);

      assert(targetMemRefType.getLayout().isIdentity() &&
             "Identity layout map is a precondition of a valid reshape op");

      Value stride = nullptr;
      int64_t targetRank = targetMemRefType.getRank();
      for (auto i : llvm::reverse(llvm::seq<int64_t>(0, targetRank))) {
        if (!ShapedType::isDynamicStrideOrOffset(strides[i])) {
          // If the stride for this dimension is static, use the constant
          // value directly.
          stride = createIndexConstant(rewriter, loc, strides[i]);
        } else if (!stride) {
          // `stride` is null only in the first iteration of the loop. However,
          // since the target memref has an identity layout, we can safely set
          // the innermost stride to 1.
          stride = createIndexConstant(rewriter, loc, 1);
        }

        Value dimSize;
        int64_t size = targetMemRefType.getDimSize(i);
        // If the size of this dimension is dynamic, then load it at runtime
        // from the shape operand.
        if (!ShapedType::isDynamic(size)) {
          dimSize = createIndexConstant(rewriter, loc, size);
        } else {
          Value shapeOp = reshapeOp.shape();
          Value index = createIndexConstant(rewriter, loc, i);
          dimSize = rewriter.create<memref::LoadOp>(loc, shapeOp, index);
        }

        desc.setSize(rewriter, loc, i, dimSize);
        desc.setStride(rewriter, loc, i, stride);

        // Prepare the stride value for the next dimension.
        stride = rewriter.create<LLVM::MulOp>(loc, stride, dimSize);
      }

      *descriptor = desc;
      return success();
    }

    // The shape is a rank-1 tensor with unknown length.
    Location loc = reshapeOp.getLoc();
    MemRefDescriptor shapeDesc(adaptor.shape());
    Value resultRank = shapeDesc.size(rewriter, loc, 0);

    // Extract address space and element type.
    auto targetType =
        reshapeOp.getResult().getType().cast<UnrankedMemRefType>();
    unsigned addressSpace = targetType.getMemorySpaceAsInt();
    Type elementType = targetType.getElementType();

    // Create the unranked memref descriptor that holds the ranked one. The
    // inner descriptor is allocated on stack.
    auto targetDesc = UnrankedMemRefDescriptor::undef(
        rewriter, loc, typeConverter->convertType(targetType));
    targetDesc.setRank(rewriter, loc, resultRank);
    SmallVector<Value, 4> sizes;
    UnrankedMemRefDescriptor::computeSizes(rewriter, loc, *getTypeConverter(),
                                           targetDesc, sizes);
    Value underlyingDescPtr = rewriter.create<LLVM::AllocaOp>(
        loc, getVoidPtrType(), sizes.front(), llvm::None);
    targetDesc.setMemRefDescPtr(rewriter, loc, underlyingDescPtr);

    // Extract pointers and offset from the source memref.
    Value allocatedPtr, alignedPtr, offset;
    extractPointersAndOffset(loc, rewriter, *getTypeConverter(),
                             reshapeOp.source(), adaptor.source(),
                             &allocatedPtr, &alignedPtr, &offset);

    // Set pointers and offset.
    Type llvmElementType = typeConverter->convertType(elementType);
    auto elementPtrPtrType = LLVM::LLVMPointerType::get(
        LLVM::LLVMPointerType::get(llvmElementType, addressSpace));
    UnrankedMemRefDescriptor::setAllocatedPtr(rewriter, loc, underlyingDescPtr,
                                              elementPtrPtrType, allocatedPtr);
    UnrankedMemRefDescriptor::setAlignedPtr(rewriter, loc, *getTypeConverter(),
                                            underlyingDescPtr,
                                            elementPtrPtrType, alignedPtr);
    UnrankedMemRefDescriptor::setOffset(rewriter, loc, *getTypeConverter(),
                                        underlyingDescPtr, elementPtrPtrType,
                                        offset);

    // Use the offset pointer as base for further addressing. Copy over the
    // new shape and compute strides. For this, we create a loop from rank-1
    // to 0.
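    // The loop is built from raw LLVM blocks: a condition block carries the
    // current dimension index and the running stride as arguments, the body
    // block copies the size from the shape operand and stores the stride, and
    // control reaches the remainder block once the index drops below zero.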
    Value targetSizesBase = UnrankedMemRefDescriptor::sizeBasePtr(
        rewriter, loc, *getTypeConverter(), underlyingDescPtr,
        elementPtrPtrType);
    Value targetStridesBase = UnrankedMemRefDescriptor::strideBasePtr(
        rewriter, loc, *getTypeConverter(), targetSizesBase, resultRank);
    Value shapeOperandPtr = shapeDesc.alignedPtr(rewriter, loc);
    Value oneIndex = createIndexConstant(rewriter, loc, 1);
    Value resultRankMinusOne =
        rewriter.create<LLVM::SubOp>(loc, resultRank, oneIndex);

    Block *initBlock = rewriter.getInsertionBlock();
    Type indexType = getTypeConverter()->getIndexType();
    Block::iterator remainingOpsIt = std::next(rewriter.getInsertionPoint());

    Block *condBlock = rewriter.createBlock(initBlock->getParent(), {},
                                            {indexType, indexType}, {loc, loc});

    // Move the remaining initBlock ops to condBlock.
    Block *remainingBlock = rewriter.splitBlock(initBlock, remainingOpsIt);
    rewriter.mergeBlocks(remainingBlock, condBlock, ValueRange());

    rewriter.setInsertionPointToEnd(initBlock);
    rewriter.create<LLVM::BrOp>(loc, ValueRange({resultRankMinusOne, oneIndex}),
                                condBlock);
    rewriter.setInsertionPointToStart(condBlock);
    Value indexArg = condBlock->getArgument(0);
    Value strideArg = condBlock->getArgument(1);

    Value zeroIndex = createIndexConstant(rewriter, loc, 0);
    Value pred = rewriter.create<LLVM::ICmpOp>(
        loc, IntegerType::get(rewriter.getContext(), 1),
        LLVM::ICmpPredicate::sge, indexArg, zeroIndex);

    Block *bodyBlock =
        rewriter.splitBlock(condBlock, rewriter.getInsertionPoint());
    rewriter.setInsertionPointToStart(bodyBlock);

    // Copy size from shape to descriptor.
    Type llvmIndexPtrType = LLVM::LLVMPointerType::get(indexType);
    Value sizeLoadGep = rewriter.create<LLVM::GEPOp>(
        loc, llvmIndexPtrType, shapeOperandPtr, ValueRange{indexArg});
    Value size = rewriter.create<LLVM::LoadOp>(loc, sizeLoadGep);
    UnrankedMemRefDescriptor::setSize(rewriter, loc, *getTypeConverter(),
                                      targetSizesBase, indexArg, size);

    // Write stride value and compute next one.
    UnrankedMemRefDescriptor::setStride(rewriter, loc, *getTypeConverter(),
                                        targetStridesBase, indexArg, strideArg);
    Value nextStride = rewriter.create<LLVM::MulOp>(loc, strideArg, size);

    // Decrement loop counter and branch back.
    Value decrement = rewriter.create<LLVM::SubOp>(loc, indexArg, oneIndex);
    rewriter.create<LLVM::BrOp>(loc, ValueRange({decrement, nextStride}),
                                condBlock);

    Block *remainder =
        rewriter.splitBlock(bodyBlock, rewriter.getInsertionPoint());

    // Hook up the cond exit to the remainder.
    rewriter.setInsertionPointToEnd(condBlock);
    rewriter.create<LLVM::CondBrOp>(loc, pred, bodyBlock, llvm::None,
                                    remainder, llvm::None);

    // Reset position to beginning of new remainder block.
    rewriter.setInsertionPointToStart(remainder);

    *descriptor = targetDesc;
    return success();
  }
};

/// Helper function to convert a vector of `OpFoldResult`s into a vector of
/// `Value`s.
static SmallVector<Value> getAsValues(OpBuilder &b, Location loc,
                                      Type &llvmIndexType,
                                      ArrayRef<OpFoldResult> valueOrAttrVec) {
  return llvm::to_vector<4>(
      llvm::map_range(valueOrAttrVec, [&](OpFoldResult value) -> Value {
        if (auto attr = value.dyn_cast<Attribute>())
          return b.create<LLVM::ConstantOp>(loc, llvmIndexType, attr);
        return value.get<Value>();
      }));
}

/// Compute a map that for a given dimension of the expanded type gives the
/// dimension in the collapsed type it maps to. Essentially it's the inverse
/// of the `reassociation` maps.
static DenseMap<int64_t, int64_t>
getExpandedDimToCollapsedDimMap(ArrayRef<ReassociationIndices> reassociation) {
  llvm::DenseMap<int64_t, int64_t> expandedDimToCollapsedDim;
  for (auto &en : enumerate(reassociation)) {
    for (auto dim : en.value())
      expandedDimToCollapsedDim[dim] = en.index();
  }
  return expandedDimToCollapsedDim;
}

static OpFoldResult
getExpandedOutputDimSize(OpBuilder &b, Location loc, Type &llvmIndexType,
                         int64_t outDimIndex, ArrayRef<int64_t> outStaticShape,
                         MemRefDescriptor &inDesc,
                         ArrayRef<int64_t> inStaticShape,
                         ArrayRef<ReassociationIndices> reassocation,
                         DenseMap<int64_t, int64_t> &outDimToInDimMap) {
  int64_t outDimSize = outStaticShape[outDimIndex];
  if (!ShapedType::isDynamic(outDimSize))
    return b.getIndexAttr(outDimSize);

  // Calculate the multiplication of all the out dim sizes except the
  // current dim.
  int64_t inDimIndex = outDimToInDimMap[outDimIndex];
  int64_t otherDimSizesMul = 1;
  for (auto otherDimIndex : reassocation[inDimIndex]) {
    if (otherDimIndex == static_cast<unsigned>(outDimIndex))
      continue;
    int64_t otherDimSize = outStaticShape[otherDimIndex];
    assert(!ShapedType::isDynamic(otherDimSize) &&
           "single dimension cannot be expanded into multiple dynamic "
           "dimensions");
    otherDimSizesMul *= otherDimSize;
  }

  // outDimSize = inDimSize / otherOutDimSizesMul
  int64_t inDimSize = inStaticShape[inDimIndex];
  Value inDimSizeDynamic =
      ShapedType::isDynamic(inDimSize)
          ? inDesc.size(b, loc, inDimIndex)
          : b.create<LLVM::ConstantOp>(loc, llvmIndexType,
                                       b.getIndexAttr(inDimSize));
  Value outDimSizeDynamic = b.create<LLVM::SDivOp>(
      loc, inDimSizeDynamic,
      b.create<LLVM::ConstantOp>(loc, llvmIndexType,
                                 b.getIndexAttr(otherDimSizesMul)));
  return outDimSizeDynamic;
}

static OpFoldResult getCollapsedOutputDimSize(
    OpBuilder &b, Location loc, Type &llvmIndexType, int64_t outDimIndex,
    int64_t outDimSize, ArrayRef<int64_t> inStaticShape,
    MemRefDescriptor &inDesc, ArrayRef<ReassociationIndices> reassocation) {
  if (!ShapedType::isDynamic(outDimSize))
    return b.getIndexAttr(outDimSize);

  Value c1 = b.create<LLVM::ConstantOp>(loc, llvmIndexType, b.getIndexAttr(1));
  Value outDimSizeDynamic = c1;
  for (auto inDimIndex : reassocation[outDimIndex]) {
    int64_t inDimSize = inStaticShape[inDimIndex];
    Value inDimSizeDynamic =
        ShapedType::isDynamic(inDimSize)
inDesc.size(b, loc, inDimIndex) 1368 : b.create<LLVM::ConstantOp>(loc, llvmIndexType, 1369 b.getIndexAttr(inDimSize)); 1370 outDimSizeDynamic = 1371 b.create<LLVM::MulOp>(loc, outDimSizeDynamic, inDimSizeDynamic); 1372 } 1373 return outDimSizeDynamic; 1374 } 1375 1376 static SmallVector<OpFoldResult, 4> 1377 getCollapsedOutputShape(OpBuilder &b, Location loc, Type &llvmIndexType, 1378 ArrayRef<ReassociationIndices> reassociation, 1379 ArrayRef<int64_t> inStaticShape, 1380 MemRefDescriptor &inDesc, 1381 ArrayRef<int64_t> outStaticShape) { 1382 return llvm::to_vector<4>(llvm::map_range( 1383 llvm::seq<int64_t>(0, outStaticShape.size()), [&](int64_t outDimIndex) { 1384 return getCollapsedOutputDimSize(b, loc, llvmIndexType, outDimIndex, 1385 outStaticShape[outDimIndex], 1386 inStaticShape, inDesc, reassociation); 1387 })); 1388 } 1389 1390 static SmallVector<OpFoldResult, 4> 1391 getExpandedOutputShape(OpBuilder &b, Location loc, Type &llvmIndexType, 1392 ArrayRef<ReassociationIndices> reassociation, 1393 ArrayRef<int64_t> inStaticShape, 1394 MemRefDescriptor &inDesc, 1395 ArrayRef<int64_t> outStaticShape) { 1396 DenseMap<int64_t, int64_t> outDimToInDimMap = 1397 getExpandedDimToCollapsedDimMap(reassociation); 1398 return llvm::to_vector<4>(llvm::map_range( 1399 llvm::seq<int64_t>(0, outStaticShape.size()), [&](int64_t outDimIndex) { 1400 return getExpandedOutputDimSize(b, loc, llvmIndexType, outDimIndex, 1401 outStaticShape, inDesc, inStaticShape, 1402 reassociation, outDimToInDimMap); 1403 })); 1404 } 1405 1406 static SmallVector<Value> 1407 getDynamicOutputShape(OpBuilder &b, Location loc, Type &llvmIndexType, 1408 ArrayRef<ReassociationIndices> reassociation, 1409 ArrayRef<int64_t> inStaticShape, MemRefDescriptor &inDesc, 1410 ArrayRef<int64_t> outStaticShape) { 1411 return outStaticShape.size() < inStaticShape.size() 1412 ? getAsValues(b, loc, llvmIndexType, 1413 getCollapsedOutputShape(b, loc, llvmIndexType, 1414 reassociation, inStaticShape, 1415 inDesc, outStaticShape)) 1416 : getAsValues(b, loc, llvmIndexType, 1417 getExpandedOutputShape(b, loc, llvmIndexType, 1418 reassociation, inStaticShape, 1419 inDesc, outStaticShape)); 1420 } 1421 1422 static void fillInStridesForExpandedMemDescriptor( 1423 OpBuilder &b, Location loc, MemRefType srcType, MemRefDescriptor &srcDesc, 1424 MemRefDescriptor &dstDesc, ArrayRef<ReassociationIndices> reassociation) { 1425 // See comments for computeExpandedLayoutMap for details on how the strides 1426 // are calculated. 1427 for (auto &en : llvm::enumerate(reassociation)) { 1428 auto currentStrideToExpand = srcDesc.stride(b, loc, en.index()); 1429 for (auto dstIndex : llvm::reverse(en.value())) { 1430 dstDesc.setStride(b, loc, dstIndex, currentStrideToExpand); 1431 Value size = dstDesc.size(b, loc, dstIndex); 1432 currentStrideToExpand = 1433 b.create<LLVM::MulOp>(loc, size, currentStrideToExpand); 1434 } 1435 } 1436 } 1437 1438 static void fillInStridesForCollapsedMemDescriptor( 1439 ConversionPatternRewriter &rewriter, Location loc, Operation *op, 1440 TypeConverter *typeConverter, MemRefType srcType, MemRefDescriptor &srcDesc, 1441 MemRefDescriptor &dstDesc, ArrayRef<ReassociationIndices> reassociation) { 1442 // See comments for computeCollapsedLayoutMap for details on how the strides 1443 // are calculated. 
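// As a rough sketch of the logic below: for each reassociation group, the
// stride of the collapsed dimension is the stride of the innermost dimension
// of the group whose size is not 1. When that can be decided statically the
// stride is copied directly; otherwise a small CFG is emitted that picks the
// stride at runtime. For example (illustrative only), collapsing
// memref<4x5xf32> with strides [5, 1] through reassociation [[0, 1]] yields a
// one-dimensional memref with stride 1.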
1444 auto srcShape = srcType.getShape(); 1445 for (auto &en : llvm::enumerate(reassociation)) { 1446 rewriter.setInsertionPoint(op); 1447 auto dstIndex = en.index(); 1448 ArrayRef<int64_t> ref = llvm::makeArrayRef(en.value()); 1449 while (srcShape[ref.back()] == 1 && ref.size() > 1) 1450 ref = ref.drop_back(); 1451 if (!ShapedType::isDynamic(srcShape[ref.back()]) || ref.size() == 1) { 1452 dstDesc.setStride(rewriter, loc, dstIndex, 1453 srcDesc.stride(rewriter, loc, ref.back())); 1454 } else { 1455 // Iterate over the source strides in reverse order. Skip over the 1456 // dimensions whose size is 1. 1457 // TODO: we should take the minimum stride in the reassociation group 1458 // instead of just the first where the dimension is not 1. 1459 // 1460 // +------------------------------------------------------+ 1461 // | curEntry: | 1462 // | %srcStride = strides[srcIndex] | 1463 // | %neOne = cmp sizes[srcIndex],1 +--+ 1464 // | cf.cond_br %neOne, continue(%srcStride), nextEntry | | 1465 // +-------------------------+----------------------------+ | 1466 // | | 1467 // v | 1468 // +-----------------------------+ | 1469 // | nextEntry: | | 1470 // | ... +---+ | 1471 // +--------------+--------------+ | | 1472 // | | | 1473 // v | | 1474 // +-----------------------------+ | | 1475 // | nextEntry: | | | 1476 // | ... | | | 1477 // +--------------+--------------+ | +--------+ 1478 // | | | 1479 // v v v 1480 // +--------------------------------------------------+ 1481 // | continue(%newStride): | 1482 // | %newMemRefDes = setStride(%newStride,dstIndex) | 1483 // +--------------------------------------------------+ 1484 OpBuilder::InsertionGuard guard(rewriter); 1485 Block *initBlock = rewriter.getInsertionBlock(); 1486 Block *continueBlock = 1487 rewriter.splitBlock(initBlock, rewriter.getInsertionPoint()); 1488 continueBlock->insertArgument(unsigned(0), srcDesc.getIndexType(), loc); 1489 rewriter.setInsertionPointToStart(continueBlock); 1490 dstDesc.setStride(rewriter, loc, dstIndex, continueBlock->getArgument(0)); 1491 1492 Block *curEntryBlock = initBlock; 1493 Block *nextEntryBlock; 1494 for (auto srcIndex : llvm::reverse(ref)) { 1495 if (srcShape[srcIndex] == 1 && srcIndex != ref.front()) 1496 continue; 1497 rewriter.setInsertionPointToEnd(curEntryBlock); 1498 Value srcStride = srcDesc.stride(rewriter, loc, srcIndex); 1499 if (srcIndex == ref.front()) { 1500 rewriter.create<LLVM::BrOp>(loc, srcStride, continueBlock); 1501 break; 1502 } 1503 Value one = rewriter.create<LLVM::ConstantOp>( 1504 loc, typeConverter->convertType(rewriter.getI64Type()), 1505 rewriter.getI32IntegerAttr(1)); 1506 Value predNeOne = rewriter.create<LLVM::ICmpOp>( 1507 loc, LLVM::ICmpPredicate::ne, srcDesc.size(rewriter, loc, srcIndex), 1508 one); 1509 { 1510 OpBuilder::InsertionGuard guard(rewriter); 1511 nextEntryBlock = rewriter.createBlock( 1512 initBlock->getParent(), Region::iterator(continueBlock), {}); 1513 } 1514 rewriter.create<LLVM::CondBrOp>(loc, predNeOne, continueBlock, 1515 srcStride, nextEntryBlock, llvm::None); 1516 curEntryBlock = nextEntryBlock; 1517 } 1518 } 1519 } 1520 } 1521 1522 static void fillInDynamicStridesForMemDescriptor( 1523 ConversionPatternRewriter &b, Location loc, Operation *op, 1524 TypeConverter *typeConverter, MemRefType srcType, MemRefType dstType, 1525 MemRefDescriptor &srcDesc, MemRefDescriptor &dstDesc, 1526 ArrayRef<ReassociationIndices> reassociation) { 1527 if (srcType.getRank() > dstType.getRank()) 1528 fillInStridesForCollapsedMemDescriptor(b, loc, op, typeConverter, 
srcType,
1529 srcDesc, dstDesc, reassociation);
1530 else
1531 fillInStridesForExpandedMemDescriptor(b, loc, srcType, srcDesc, dstDesc,
1532 reassociation);
1533 }
1534
1535 // ReshapeOp creates a new view descriptor of the proper rank.
1536 // Static strides of the target type are used directly; otherwise the strides
1537 // are recomputed from the sizes (identity layouts) or filled in dynamically.
1538 template <typename ReshapeOp>
1539 class ReassociatingReshapeOpConversion
1540 : public ConvertOpToLLVMPattern<ReshapeOp> {
1541 public:
1542 using ConvertOpToLLVMPattern<ReshapeOp>::ConvertOpToLLVMPattern;
1543 using ReshapeOpAdaptor = typename ReshapeOp::Adaptor;
1544
1545 LogicalResult
1546 matchAndRewrite(ReshapeOp reshapeOp, typename ReshapeOp::Adaptor adaptor,
1547 ConversionPatternRewriter &rewriter) const override {
1548 MemRefType dstType = reshapeOp.getResultType();
1549 MemRefType srcType = reshapeOp.getSrcType();
1550
1551 int64_t offset;
1552 SmallVector<int64_t, 4> strides;
1553 if (failed(getStridesAndOffset(dstType, strides, offset))) {
1554 return rewriter.notifyMatchFailure(
1555 reshapeOp, "failed to get stride and offset exprs");
1556 }
1557
1558 MemRefDescriptor srcDesc(adaptor.src());
1559 Location loc = reshapeOp->getLoc();
1560 auto dstDesc = MemRefDescriptor::undef(
1561 rewriter, loc, this->typeConverter->convertType(dstType));
1562 dstDesc.setAllocatedPtr(rewriter, loc, srcDesc.allocatedPtr(rewriter, loc));
1563 dstDesc.setAlignedPtr(rewriter, loc, srcDesc.alignedPtr(rewriter, loc));
1564 dstDesc.setOffset(rewriter, loc, srcDesc.offset(rewriter, loc));
1565
1566 ArrayRef<int64_t> srcStaticShape = srcType.getShape();
1567 ArrayRef<int64_t> dstStaticShape = dstType.getShape();
1568 Type llvmIndexType =
1569 this->typeConverter->convertType(rewriter.getIndexType());
1570 SmallVector<Value> dstShape = getDynamicOutputShape(
1571 rewriter, loc, llvmIndexType, reshapeOp.getReassociationIndices(),
1572 srcStaticShape, srcDesc, dstStaticShape);
1573 for (auto &en : llvm::enumerate(dstShape))
1574 dstDesc.setSize(rewriter, loc, en.index(), en.value());
1575
1576 if (llvm::all_of(strides, isStaticStrideOrOffset)) {
1577 for (auto &en : llvm::enumerate(strides))
1578 dstDesc.setConstantStride(rewriter, loc, en.index(), en.value());
1579 } else if (srcType.getLayout().isIdentity() &&
1580 dstType.getLayout().isIdentity()) {
1581 Value c1 = rewriter.create<LLVM::ConstantOp>(loc, llvmIndexType,
1582 rewriter.getIndexAttr(1));
1583 Value stride = c1;
1584 for (auto dimIndex :
1585 llvm::reverse(llvm::seq<int64_t>(0, dstShape.size()))) {
1586 dstDesc.setStride(rewriter, loc, dimIndex, stride);
1587 stride = rewriter.create<LLVM::MulOp>(loc, dstShape[dimIndex], stride);
1588 }
1589 } else {
1590 // There could be mixed static/dynamic strides. For simplicity, we
1591 // recompute all strides if there is at least one dynamic stride.
1592 fillInDynamicStridesForMemDescriptor(
1593 rewriter, loc, reshapeOp, this->typeConverter, srcType, dstType,
1594 srcDesc, dstDesc, reshapeOp.getReassociationIndices());
1595 }
1596 rewriter.replaceOp(reshapeOp, {dstDesc});
1597 return success();
1598 }
1599 };
1600
1601 /// Conversion pattern that transforms a subview op into:
1602 /// 1. An `llvm.mlir.undef` operation to create a memref descriptor
1603 /// 2. Updates to the descriptor to introduce the data ptr, offset, sizes
1604 /// and strides.
1605 /// The subview op is replaced by the descriptor.
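/// As a rough, illustrative example (not exact syntax): lowering
///   %1 = memref.subview %0[%off][%sz][%st] : memref<?xf32, ...> to ...
/// produces a descriptor that reuses the pointers of %0, with
/// offset = src_offset + %off * src_stride, size = %sz, and
/// stride = %st * src_stride.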
1606 struct SubViewOpLowering : public ConvertOpToLLVMPattern<memref::SubViewOp> { 1607 using ConvertOpToLLVMPattern<memref::SubViewOp>::ConvertOpToLLVMPattern; 1608 1609 LogicalResult 1610 matchAndRewrite(memref::SubViewOp subViewOp, OpAdaptor adaptor, 1611 ConversionPatternRewriter &rewriter) const override { 1612 auto loc = subViewOp.getLoc(); 1613 1614 auto sourceMemRefType = subViewOp.source().getType().cast<MemRefType>(); 1615 auto sourceElementTy = 1616 typeConverter->convertType(sourceMemRefType.getElementType()); 1617 1618 auto viewMemRefType = subViewOp.getType(); 1619 auto inferredType = memref::SubViewOp::inferResultType( 1620 subViewOp.getSourceType(), 1621 extractFromI64ArrayAttr(subViewOp.static_offsets()), 1622 extractFromI64ArrayAttr(subViewOp.static_sizes()), 1623 extractFromI64ArrayAttr(subViewOp.static_strides())) 1624 .cast<MemRefType>(); 1625 auto targetElementTy = 1626 typeConverter->convertType(viewMemRefType.getElementType()); 1627 auto targetDescTy = typeConverter->convertType(viewMemRefType); 1628 if (!sourceElementTy || !targetDescTy || !targetElementTy || 1629 !LLVM::isCompatibleType(sourceElementTy) || 1630 !LLVM::isCompatibleType(targetElementTy) || 1631 !LLVM::isCompatibleType(targetDescTy)) 1632 return failure(); 1633 1634 // Extract the offset and strides from the type. 1635 int64_t offset; 1636 SmallVector<int64_t, 4> strides; 1637 auto successStrides = getStridesAndOffset(inferredType, strides, offset); 1638 if (failed(successStrides)) 1639 return failure(); 1640 1641 // Create the descriptor. 1642 if (!LLVM::isCompatibleType(adaptor.getOperands().front().getType())) 1643 return failure(); 1644 MemRefDescriptor sourceMemRef(adaptor.getOperands().front()); 1645 auto targetMemRef = MemRefDescriptor::undef(rewriter, loc, targetDescTy); 1646 1647 // Copy the buffer pointer from the old descriptor to the new one. 1648 Value extracted = sourceMemRef.allocatedPtr(rewriter, loc); 1649 Value bitcastPtr = rewriter.create<LLVM::BitcastOp>( 1650 loc, 1651 LLVM::LLVMPointerType::get(targetElementTy, 1652 viewMemRefType.getMemorySpaceAsInt()), 1653 extracted); 1654 targetMemRef.setAllocatedPtr(rewriter, loc, bitcastPtr); 1655 1656 // Copy the aligned pointer from the old descriptor to the new one. 1657 extracted = sourceMemRef.alignedPtr(rewriter, loc); 1658 bitcastPtr = rewriter.create<LLVM::BitcastOp>( 1659 loc, 1660 LLVM::LLVMPointerType::get(targetElementTy, 1661 viewMemRefType.getMemorySpaceAsInt()), 1662 extracted); 1663 targetMemRef.setAlignedPtr(rewriter, loc, bitcastPtr); 1664 1665 size_t inferredShapeRank = inferredType.getRank(); 1666 size_t resultShapeRank = viewMemRefType.getRank(); 1667 1668 // Extract strides needed to compute offset. 1669 SmallVector<Value, 4> strideValues; 1670 strideValues.reserve(inferredShapeRank); 1671 for (unsigned i = 0; i < inferredShapeRank; ++i) 1672 strideValues.push_back(sourceMemRef.stride(rewriter, loc, i)); 1673 1674 // Offset. 1675 auto llvmIndexType = typeConverter->convertType(rewriter.getIndexType()); 1676 if (!ShapedType::isDynamicStrideOrOffset(offset)) { 1677 targetMemRef.setConstantOffset(rewriter, loc, offset); 1678 } else { 1679 Value baseOffset = sourceMemRef.offset(rewriter, loc); 1680 // `inferredShapeRank` may be larger than the number of offset operands 1681 // because of trailing semantics. In this case, the offset is guaranteed 1682 // to be interpreted as 0 and we can just skip the extra dimensions. 
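// Put differently (a sketch of the loop below), the final offset is
//   baseOffset + sum_i(offset_i * sourceStride_i)
// taken over the leading dimensions for which an offset operand exists.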
1683 for (unsigned i = 0, e = std::min(inferredShapeRank, 1684 subViewOp.getMixedOffsets().size()); 1685 i < e; ++i) { 1686 Value offset = 1687 // TODO: need OpFoldResult ODS adaptor to clean this up. 1688 subViewOp.isDynamicOffset(i) 1689 ? adaptor.getOperands()[subViewOp.getIndexOfDynamicOffset(i)] 1690 : rewriter.create<LLVM::ConstantOp>( 1691 loc, llvmIndexType, 1692 rewriter.getI64IntegerAttr(subViewOp.getStaticOffset(i))); 1693 Value mul = rewriter.create<LLVM::MulOp>(loc, offset, strideValues[i]); 1694 baseOffset = rewriter.create<LLVM::AddOp>(loc, baseOffset, mul); 1695 } 1696 targetMemRef.setOffset(rewriter, loc, baseOffset); 1697 } 1698 1699 // Update sizes and strides. 1700 SmallVector<OpFoldResult> mixedSizes = subViewOp.getMixedSizes(); 1701 SmallVector<OpFoldResult> mixedStrides = subViewOp.getMixedStrides(); 1702 assert(mixedSizes.size() == mixedStrides.size() && 1703 "expected sizes and strides of equal length"); 1704 llvm::SmallBitVector unusedDims = subViewOp.getDroppedDims(); 1705 for (int i = inferredShapeRank - 1, j = resultShapeRank - 1; 1706 i >= 0 && j >= 0; --i) { 1707 if (unusedDims.test(i)) 1708 continue; 1709 1710 // `i` may overflow subViewOp.getMixedSizes because of trailing semantics. 1711 // In this case, the size is guaranteed to be interpreted as Dim and the 1712 // stride as 1. 1713 Value size, stride; 1714 if (static_cast<unsigned>(i) >= mixedSizes.size()) { 1715 // If the static size is available, use it directly. This is similar to 1716 // the folding of dim(constant-op) but removes the need for dim to be 1717 // aware of LLVM constants and for this pass to be aware of std 1718 // constants. 1719 int64_t staticSize = 1720 subViewOp.source().getType().cast<MemRefType>().getShape()[i]; 1721 if (staticSize != ShapedType::kDynamicSize) { 1722 size = rewriter.create<LLVM::ConstantOp>( 1723 loc, llvmIndexType, rewriter.getI64IntegerAttr(staticSize)); 1724 } else { 1725 Value pos = rewriter.create<LLVM::ConstantOp>( 1726 loc, llvmIndexType, rewriter.getI64IntegerAttr(i)); 1727 Value dim = 1728 rewriter.create<memref::DimOp>(loc, subViewOp.source(), pos); 1729 auto cast = rewriter.create<UnrealizedConversionCastOp>( 1730 loc, llvmIndexType, dim); 1731 size = cast.getResult(0); 1732 } 1733 stride = rewriter.create<LLVM::ConstantOp>( 1734 loc, llvmIndexType, rewriter.getI64IntegerAttr(1)); 1735 } else { 1736 // TODO: need OpFoldResult ODS adaptor to clean this up. 1737 size = 1738 subViewOp.isDynamicSize(i) 1739 ? adaptor.getOperands()[subViewOp.getIndexOfDynamicSize(i)] 1740 : rewriter.create<LLVM::ConstantOp>( 1741 loc, llvmIndexType, 1742 rewriter.getI64IntegerAttr(subViewOp.getStaticSize(i))); 1743 if (!ShapedType::isDynamicStrideOrOffset(strides[i])) { 1744 stride = rewriter.create<LLVM::ConstantOp>( 1745 loc, llvmIndexType, rewriter.getI64IntegerAttr(strides[i])); 1746 } else { 1747 stride = 1748 subViewOp.isDynamicStride(i) 1749 ? adaptor.getOperands()[subViewOp.getIndexOfDynamicStride(i)] 1750 : rewriter.create<LLVM::ConstantOp>( 1751 loc, llvmIndexType, 1752 rewriter.getI64IntegerAttr( 1753 subViewOp.getStaticStride(i))); 1754 stride = rewriter.create<LLVM::MulOp>(loc, stride, strideValues[i]); 1755 } 1756 } 1757 targetMemRef.setSize(rewriter, loc, j, size); 1758 targetMemRef.setStride(rewriter, loc, j, stride); 1759 j--; 1760 } 1761 1762 rewriter.replaceOp(subViewOp, {targetMemRef}); 1763 return success(); 1764 } 1765 }; 1766 1767 /// Conversion pattern that transforms a transpose op into: 1768 /// 1. 
An `llvm.mlir.undef` operation to create the target memref descriptor.
1769 /// 2. Copies of the allocated and aligned pointers and of the offset from
1770 /// the source descriptor into the new one.
1771 /// 3. Updates to the sizes and strides of the new descriptor, permuted
1772 /// according to the transposition map.
1773 /// The transpose op is replaced by the new descriptor.
1774 class TransposeOpLowering : public ConvertOpToLLVMPattern<memref::TransposeOp> {
1775 public:
1776 using ConvertOpToLLVMPattern<memref::TransposeOp>::ConvertOpToLLVMPattern;
1777
1778 LogicalResult
1779 matchAndRewrite(memref::TransposeOp transposeOp, OpAdaptor adaptor,
1780 ConversionPatternRewriter &rewriter) const override {
1781 auto loc = transposeOp.getLoc();
1782 MemRefDescriptor viewMemRef(adaptor.in());
1783
1784 // No permutation, early exit.
1785 if (transposeOp.permutation().isIdentity())
1786 return rewriter.replaceOp(transposeOp, {viewMemRef}), success();
1787
1788 auto targetMemRef = MemRefDescriptor::undef(
1789 rewriter, loc, typeConverter->convertType(transposeOp.getShapedType()));
1790
1791 // Copy the base and aligned pointers from the old descriptor to the new
1792 // one.
1793 targetMemRef.setAllocatedPtr(rewriter, loc,
1794 viewMemRef.allocatedPtr(rewriter, loc));
1795 targetMemRef.setAlignedPtr(rewriter, loc,
1796 viewMemRef.alignedPtr(rewriter, loc));
1797
1798 // Copy the offset from the old descriptor to the new one.
1799 targetMemRef.setOffset(rewriter, loc, viewMemRef.offset(rewriter, loc));
1800
1801 // Iterate over the dimensions and apply size/stride permutation.
1802 for (const auto &en :
1803 llvm::enumerate(transposeOp.permutation().getResults())) {
1804 int sourcePos = en.index();
1805 int targetPos = en.value().cast<AffineDimExpr>().getPosition();
1806 targetMemRef.setSize(rewriter, loc, targetPos,
1807 viewMemRef.size(rewriter, loc, sourcePos));
1808 targetMemRef.setStride(rewriter, loc, targetPos,
1809 viewMemRef.stride(rewriter, loc, sourcePos));
1810 }
1811
1812 rewriter.replaceOp(transposeOp, {targetMemRef});
1813 return success();
1814 }
1815 };
1816
1817 /// Conversion pattern that transforms a view op into:
1818 /// 1. An `llvm.mlir.undef` operation to create a memref descriptor
1819 /// 2. Updates to the descriptor to introduce the data ptr, offset, sizes
1820 /// and strides.
1821 /// The view op is replaced by the descriptor.
1822 struct ViewOpLowering : public ConvertOpToLLVMPattern<memref::ViewOp> {
1823 using ConvertOpToLLVMPattern<memref::ViewOp>::ConvertOpToLLVMPattern;
1824
1825 // Build and return the value for the idx^th shape dimension, either by
1826 // returning the constant shape dimension or selecting the proper dynamic size.
1827 Value getSize(ConversionPatternRewriter &rewriter, Location loc,
1828 ArrayRef<int64_t> shape, ValueRange dynamicSizes,
1829 unsigned idx) const {
1830 assert(idx < shape.size());
1831 if (!ShapedType::isDynamic(shape[idx]))
1832 return createIndexConstant(rewriter, loc, shape[idx]);
1833 // Count the number of dynamic dims in range [0, idx)
1834 unsigned nDynamic = llvm::count_if(shape.take_front(idx), [](int64_t v) {
1835 return ShapedType::isDynamic(v);
1836 });
1837 return dynamicSizes[nDynamic];
1838 }
1839
1840 // Build and return the idx^th stride, either by returning the constant stride
1841 // or by computing the dynamic stride from the current `runningStride` and
1842 // `nextSize`.
The caller should keep a running stride and update it with the 1843 // result returned by this function. 1844 Value getStride(ConversionPatternRewriter &rewriter, Location loc, 1845 ArrayRef<int64_t> strides, Value nextSize, 1846 Value runningStride, unsigned idx) const { 1847 assert(idx < strides.size()); 1848 if (!ShapedType::isDynamicStrideOrOffset(strides[idx])) 1849 return createIndexConstant(rewriter, loc, strides[idx]); 1850 if (nextSize) 1851 return runningStride 1852 ? rewriter.create<LLVM::MulOp>(loc, runningStride, nextSize) 1853 : nextSize; 1854 assert(!runningStride); 1855 return createIndexConstant(rewriter, loc, 1); 1856 } 1857 1858 LogicalResult 1859 matchAndRewrite(memref::ViewOp viewOp, OpAdaptor adaptor, 1860 ConversionPatternRewriter &rewriter) const override { 1861 auto loc = viewOp.getLoc(); 1862 1863 auto viewMemRefType = viewOp.getType(); 1864 auto targetElementTy = 1865 typeConverter->convertType(viewMemRefType.getElementType()); 1866 auto targetDescTy = typeConverter->convertType(viewMemRefType); 1867 if (!targetDescTy || !targetElementTy || 1868 !LLVM::isCompatibleType(targetElementTy) || 1869 !LLVM::isCompatibleType(targetDescTy)) 1870 return viewOp.emitWarning("Target descriptor type not converted to LLVM"), 1871 failure(); 1872 1873 int64_t offset; 1874 SmallVector<int64_t, 4> strides; 1875 auto successStrides = getStridesAndOffset(viewMemRefType, strides, offset); 1876 if (failed(successStrides)) 1877 return viewOp.emitWarning("cannot cast to non-strided shape"), failure(); 1878 assert(offset == 0 && "expected offset to be 0"); 1879 1880 // Target memref must be contiguous in memory (innermost stride is 1), or 1881 // empty (special case when at least one of the memref dimensions is 0). 1882 if (!strides.empty() && (strides.back() != 1 && strides.back() != 0)) 1883 return viewOp.emitWarning("cannot cast to non-contiguous shape"), 1884 failure(); 1885 1886 // Create the descriptor. 1887 MemRefDescriptor sourceMemRef(adaptor.source()); 1888 auto targetMemRef = MemRefDescriptor::undef(rewriter, loc, targetDescTy); 1889 1890 // Field 1: Copy the allocated pointer, used for malloc/free. 1891 Value allocatedPtr = sourceMemRef.allocatedPtr(rewriter, loc); 1892 auto srcMemRefType = viewOp.source().getType().cast<MemRefType>(); 1893 Value bitcastPtr = rewriter.create<LLVM::BitcastOp>( 1894 loc, 1895 LLVM::LLVMPointerType::get(targetElementTy, 1896 srcMemRefType.getMemorySpaceAsInt()), 1897 allocatedPtr); 1898 targetMemRef.setAllocatedPtr(rewriter, loc, bitcastPtr); 1899 1900 // Field 2: Copy the actual aligned pointer to payload. 1901 Value alignedPtr = sourceMemRef.alignedPtr(rewriter, loc); 1902 alignedPtr = rewriter.create<LLVM::GEPOp>(loc, alignedPtr.getType(), 1903 alignedPtr, adaptor.byte_shift()); 1904 bitcastPtr = rewriter.create<LLVM::BitcastOp>( 1905 loc, 1906 LLVM::LLVMPointerType::get(targetElementTy, 1907 srcMemRefType.getMemorySpaceAsInt()), 1908 alignedPtr); 1909 targetMemRef.setAlignedPtr(rewriter, loc, bitcastPtr); 1910 1911 // Field 3: The offset in the resulting type must be 0. This is because of 1912 // the type change: an offset on srcType* may not be expressible as an 1913 // offset on dstType*. 1914 targetMemRef.setOffset(rewriter, loc, 1915 createIndexConstant(rewriter, loc, offset)); 1916 1917 // Early exit for 0-D corner case. 1918 if (viewMemRefType.getRank() == 0) 1919 return rewriter.replaceOp(viewOp, {targetMemRef}), success(); 1920 1921 // Fields 4 and 5: Update sizes and strides. 
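// Walking from the innermost dimension outwards, the dynamic strides follow
// the recurrence (a sketch of the loop below):
//   stride[rank-1] = 1, stride[i] = stride[i+1] * size[i+1],
// while statically known strides are emitted as constants directly.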
1922 Value stride = nullptr, nextSize = nullptr; 1923 for (int i = viewMemRefType.getRank() - 1; i >= 0; --i) { 1924 // Update size. 1925 Value size = 1926 getSize(rewriter, loc, viewMemRefType.getShape(), adaptor.sizes(), i); 1927 targetMemRef.setSize(rewriter, loc, i, size); 1928 // Update stride. 1929 stride = getStride(rewriter, loc, strides, nextSize, stride, i); 1930 targetMemRef.setStride(rewriter, loc, i, stride); 1931 nextSize = size; 1932 } 1933 1934 rewriter.replaceOp(viewOp, {targetMemRef}); 1935 return success(); 1936 } 1937 }; 1938 1939 //===----------------------------------------------------------------------===// 1940 // AtomicRMWOpLowering 1941 //===----------------------------------------------------------------------===// 1942 1943 /// Try to match the kind of a memref.atomic_rmw to determine whether to use a 1944 /// lowering to llvm.atomicrmw or fallback to llvm.cmpxchg. 1945 static Optional<LLVM::AtomicBinOp> 1946 matchSimpleAtomicOp(memref::AtomicRMWOp atomicOp) { 1947 switch (atomicOp.kind()) { 1948 case arith::AtomicRMWKind::addf: 1949 return LLVM::AtomicBinOp::fadd; 1950 case arith::AtomicRMWKind::addi: 1951 return LLVM::AtomicBinOp::add; 1952 case arith::AtomicRMWKind::assign: 1953 return LLVM::AtomicBinOp::xchg; 1954 case arith::AtomicRMWKind::maxs: 1955 return LLVM::AtomicBinOp::max; 1956 case arith::AtomicRMWKind::maxu: 1957 return LLVM::AtomicBinOp::umax; 1958 case arith::AtomicRMWKind::mins: 1959 return LLVM::AtomicBinOp::min; 1960 case arith::AtomicRMWKind::minu: 1961 return LLVM::AtomicBinOp::umin; 1962 case arith::AtomicRMWKind::ori: 1963 return LLVM::AtomicBinOp::_or; 1964 case arith::AtomicRMWKind::andi: 1965 return LLVM::AtomicBinOp::_and; 1966 default: 1967 return llvm::None; 1968 } 1969 llvm_unreachable("Invalid AtomicRMWKind"); 1970 } 1971 1972 struct AtomicRMWOpLowering : public LoadStoreOpLowering<memref::AtomicRMWOp> { 1973 using Base::Base; 1974 1975 LogicalResult 1976 matchAndRewrite(memref::AtomicRMWOp atomicOp, OpAdaptor adaptor, 1977 ConversionPatternRewriter &rewriter) const override { 1978 if (failed(match(atomicOp))) 1979 return failure(); 1980 auto maybeKind = matchSimpleAtomicOp(atomicOp); 1981 if (!maybeKind) 1982 return failure(); 1983 auto resultType = adaptor.value().getType(); 1984 auto memRefType = atomicOp.getMemRefType(); 1985 auto dataPtr = 1986 getStridedElementPtr(atomicOp.getLoc(), memRefType, adaptor.memref(), 1987 adaptor.indices(), rewriter); 1988 rewriter.replaceOpWithNewOp<LLVM::AtomicRMWOp>( 1989 atomicOp, resultType, *maybeKind, dataPtr, adaptor.value(), 1990 LLVM::AtomicOrdering::acq_rel); 1991 return success(); 1992 } 1993 }; 1994 1995 } // namespace 1996 1997 void mlir::populateMemRefToLLVMConversionPatterns(LLVMTypeConverter &converter, 1998 RewritePatternSet &patterns) { 1999 // clang-format off 2000 patterns.add< 2001 AllocaOpLowering, 2002 AllocaScopeOpLowering, 2003 AtomicRMWOpLowering, 2004 AssumeAlignmentOpLowering, 2005 DimOpLowering, 2006 GenericAtomicRMWOpLowering, 2007 GlobalMemrefOpLowering, 2008 GetGlobalMemrefOpLowering, 2009 LoadOpLowering, 2010 MemRefCastOpLowering, 2011 MemRefCopyOpLowering, 2012 MemRefReinterpretCastOpLowering, 2013 MemRefReshapeOpLowering, 2014 PrefetchOpLowering, 2015 RankOpLowering, 2016 ReassociatingReshapeOpConversion<memref::ExpandShapeOp>, 2017 ReassociatingReshapeOpConversion<memref::CollapseShapeOp>, 2018 StoreOpLowering, 2019 SubViewOpLowering, 2020 TransposeOpLowering, 2021 ViewOpLowering>(converter); 2022 // clang-format on 2023 auto allocLowering = 
converter.getOptions().allocLowering; 2024 if (allocLowering == LowerToLLVMOptions::AllocLowering::AlignedAlloc) 2025 patterns.add<AlignedAllocOpLowering, DeallocOpLowering>(converter); 2026 else if (allocLowering == LowerToLLVMOptions::AllocLowering::Malloc) 2027 patterns.add<AllocOpLowering, DeallocOpLowering>(converter); 2028 } 2029 2030 namespace { 2031 struct MemRefToLLVMPass : public ConvertMemRefToLLVMBase<MemRefToLLVMPass> { 2032 MemRefToLLVMPass() = default; 2033 2034 void runOnOperation() override { 2035 Operation *op = getOperation(); 2036 const auto &dataLayoutAnalysis = getAnalysis<DataLayoutAnalysis>(); 2037 LowerToLLVMOptions options(&getContext(), 2038 dataLayoutAnalysis.getAtOrAbove(op)); 2039 options.allocLowering = 2040 (useAlignedAlloc ? LowerToLLVMOptions::AllocLowering::AlignedAlloc 2041 : LowerToLLVMOptions::AllocLowering::Malloc); 2042 if (indexBitwidth != kDeriveIndexBitwidthFromDataLayout) 2043 options.overrideIndexBitwidth(indexBitwidth); 2044 2045 LLVMTypeConverter typeConverter(&getContext(), options, 2046 &dataLayoutAnalysis); 2047 RewritePatternSet patterns(&getContext()); 2048 populateMemRefToLLVMConversionPatterns(typeConverter, patterns); 2049 LLVMConversionTarget target(getContext()); 2050 target.addLegalOp<func::FuncOp>(); 2051 if (failed(applyPartialConversion(op, target, std::move(patterns)))) 2052 signalPassFailure(); 2053 } 2054 }; 2055 } // namespace 2056 2057 std::unique_ptr<Pass> mlir::createMemRefToLLVMPass() { 2058 return std::make_unique<MemRefToLLVMPass>(); 2059 } 2060
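// Typical usage (illustrative): run this conversion via
// `mlir-opt --convert-memref-to-llvm`, or create the pass programmatically
// with createMemRefToLLVMPass() and add it to a PassManager.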