1 //===- MemRefToLLVM.cpp - MemRef to LLVM dialect conversion ---------------===// 2 // 3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. 4 // See https://llvm.org/LICENSE.txt for license information. 5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception 6 // 7 //===----------------------------------------------------------------------===// 8 9 #include "mlir/Conversion/MemRefToLLVM/MemRefToLLVM.h" 10 #include "../PassDetail.h" 11 #include "mlir/Analysis/DataLayoutAnalysis.h" 12 #include "mlir/Conversion/LLVMCommon/ConversionTarget.h" 13 #include "mlir/Conversion/LLVMCommon/Pattern.h" 14 #include "mlir/Conversion/LLVMCommon/TypeConverter.h" 15 #include "mlir/Conversion/MemRefToLLVM/AllocLikeConversion.h" 16 #include "mlir/Dialect/Arithmetic/IR/Arithmetic.h" 17 #include "mlir/Dialect/Func/IR/FuncOps.h" 18 #include "mlir/Dialect/LLVMIR/FunctionCallUtils.h" 19 #include "mlir/Dialect/LLVMIR/LLVMDialect.h" 20 #include "mlir/Dialect/MemRef/IR/MemRef.h" 21 #include "mlir/IR/AffineMap.h" 22 #include "mlir/IR/BlockAndValueMapping.h" 23 #include "llvm/ADT/SmallBitVector.h" 24 25 using namespace mlir; 26 27 namespace { 28 29 bool isStaticStrideOrOffset(int64_t strideOrOffset) { 30 return !ShapedType::isDynamicStrideOrOffset(strideOrOffset); 31 } 32 33 struct AllocOpLowering : public AllocLikeOpLLVMLowering { 34 AllocOpLowering(LLVMTypeConverter &converter) 35 : AllocLikeOpLLVMLowering(memref::AllocOp::getOperationName(), 36 converter) {} 37 38 std::tuple<Value, Value> allocateBuffer(ConversionPatternRewriter &rewriter, 39 Location loc, Value sizeBytes, 40 Operation *op) const override { 41 // Heap allocations. 42 memref::AllocOp allocOp = cast<memref::AllocOp>(op); 43 MemRefType memRefType = allocOp.getType(); 44 45 Value alignment; 46 if (auto alignmentAttr = allocOp.getAlignment()) { 47 alignment = createIndexConstant(rewriter, loc, *alignmentAttr); 48 } else if (!memRefType.getElementType().isSignlessIntOrIndexOrFloat()) { 49 // In the case where no alignment is specified, we may want to override 50 // `malloc's` behavior. `malloc` typically aligns at the size of the 51 // biggest scalar on a target HW. For non-scalars, use the natural 52 // alignment of the LLVM type given by the LLVM DataLayout. 53 alignment = getSizeInBytes(loc, memRefType.getElementType(), rewriter); 54 } 55 56 if (alignment) { 57 // Adjust the allocation size to consider alignment. 58 sizeBytes = rewriter.create<LLVM::AddOp>(loc, sizeBytes, alignment); 59 } 60 61 // Allocate the underlying buffer and store a pointer to it in the MemRef 62 // descriptor. 63 Type elementPtrType = this->getElementPtrType(memRefType); 64 auto allocFuncOp = LLVM::lookupOrCreateMallocFn( 65 allocOp->getParentOfType<ModuleOp>(), getIndexType()); 66 auto results = createLLVMCall(rewriter, loc, allocFuncOp, {sizeBytes}, 67 getVoidPtrType()); 68 Value allocatedPtr = 69 rewriter.create<LLVM::BitcastOp>(loc, elementPtrType, results[0]); 70 71 Value alignedPtr = allocatedPtr; 72 if (alignment) { 73 // Compute the aligned type pointer. 
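      // The pointer returned by malloc is only guaranteed to be aligned for
      // the largest scalar type, so it is rounded up by hand here: reinterpret
      // it as an integer, bump it to the next multiple of `alignment`, and
      // cast it back to a pointer. The extra `alignment` bytes added to
      // `sizeBytes` above guarantee that the rounded-up address still leaves
      // enough room for the requested buffer.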
      Value allocatedInt =
          rewriter.create<LLVM::PtrToIntOp>(loc, getIndexType(), allocatedPtr);
      Value alignmentInt =
          createAligned(rewriter, loc, allocatedInt, alignment);
      alignedPtr =
          rewriter.create<LLVM::IntToPtrOp>(loc, elementPtrType, alignmentInt);
    }

    return std::make_tuple(allocatedPtr, alignedPtr);
  }
};

struct AlignedAllocOpLowering : public AllocLikeOpLLVMLowering {
  AlignedAllocOpLowering(LLVMTypeConverter &converter)
      : AllocLikeOpLLVMLowering(memref::AllocOp::getOperationName(),
                                converter) {}

  /// Returns the memref's element size in bytes using the data layout active
  /// at `op`.
  // TODO: there are other places where this is used. Expose publicly?
  unsigned getMemRefEltSizeInBytes(MemRefType memRefType, Operation *op) const {
    const DataLayout *layout = &defaultLayout;
    if (const DataLayoutAnalysis *analysis =
            getTypeConverter()->getDataLayoutAnalysis()) {
      layout = &analysis->getAbove(op);
    }
    Type elementType = memRefType.getElementType();
    if (auto memRefElementType = elementType.dyn_cast<MemRefType>())
      return getTypeConverter()->getMemRefDescriptorSize(memRefElementType,
                                                         *layout);
    if (auto memRefElementType = elementType.dyn_cast<UnrankedMemRefType>())
      return getTypeConverter()->getUnrankedMemRefDescriptorSize(
          memRefElementType, *layout);
    return layout->getTypeSize(elementType);
  }

  /// Returns true if the memref size in bytes is known to be a multiple of
  /// `factor`, assuming the data layout active at `op`.
  bool isMemRefSizeMultipleOf(MemRefType type, uint64_t factor,
                              Operation *op) const {
    uint64_t sizeDivisor = getMemRefEltSizeInBytes(type, op);
    for (unsigned i = 0, e = type.getRank(); i < e; i++) {
      if (ShapedType::isDynamic(type.getDimSize(i)))
        continue;
      sizeDivisor = sizeDivisor * type.getDimSize(i);
    }
    return sizeDivisor % factor == 0;
  }

  /// Returns the alignment to be used for the allocation call itself.
  /// aligned_alloc requires the alignment to be a power of two and the
  /// allocation size to be a multiple of that alignment.
  int64_t getAllocationAlignment(memref::AllocOp allocOp) const {
    if (Optional<uint64_t> alignment = allocOp.getAlignment())
      return *alignment;

    // Whenever no alignment is specified, use an alignment consistent with
    // the element type; since the alignment has to be a power of two, bump
    // the element size to the next power of two if it isn't one already.
    auto eltSizeBytes = getMemRefEltSizeInBytes(allocOp.getType(), allocOp);
    return std::max(kMinAlignedAllocAlignment,
                    llvm::PowerOf2Ceil(eltSizeBytes));
  }

  std::tuple<Value, Value> allocateBuffer(ConversionPatternRewriter &rewriter,
                                          Location loc, Value sizeBytes,
                                          Operation *op) const override {
    // Heap allocations.
    memref::AllocOp allocOp = cast<memref::AllocOp>(op);
    MemRefType memRefType = allocOp.getType();
    int64_t alignment = getAllocationAlignment(allocOp);
    Value allocAlignment = createIndexConstant(rewriter, loc, alignment);

    // aligned_alloc requires size to be a multiple of alignment; we will pad
    // the size to the next multiple if necessary.
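    // Illustrative example (hypothetical type, not taken from the surrounding
    // code): for a memref<7xi8> allocated with alignment 16, the static size
    // of 7 bytes is not a multiple of 16, so the byte count is rounded up to
    // 16 before the aligned_alloc call; statically known multiples skip the
    // extra rounding arithmetic below.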
149 if (!isMemRefSizeMultipleOf(memRefType, alignment, op)) 150 sizeBytes = createAligned(rewriter, loc, sizeBytes, allocAlignment); 151 152 Type elementPtrType = this->getElementPtrType(memRefType); 153 auto allocFuncOp = LLVM::lookupOrCreateAlignedAllocFn( 154 allocOp->getParentOfType<ModuleOp>(), getIndexType()); 155 auto results = 156 createLLVMCall(rewriter, loc, allocFuncOp, {allocAlignment, sizeBytes}, 157 getVoidPtrType()); 158 Value allocatedPtr = 159 rewriter.create<LLVM::BitcastOp>(loc, elementPtrType, results[0]); 160 161 return std::make_tuple(allocatedPtr, allocatedPtr); 162 } 163 164 /// The minimum alignment to use with aligned_alloc (has to be a power of 2). 165 static constexpr uint64_t kMinAlignedAllocAlignment = 16UL; 166 167 /// Default layout to use in absence of the corresponding analysis. 168 DataLayout defaultLayout; 169 }; 170 171 // Out of line definition, required till C++17. 172 constexpr uint64_t AlignedAllocOpLowering::kMinAlignedAllocAlignment; 173 174 struct AllocaOpLowering : public AllocLikeOpLLVMLowering { 175 AllocaOpLowering(LLVMTypeConverter &converter) 176 : AllocLikeOpLLVMLowering(memref::AllocaOp::getOperationName(), 177 converter) {} 178 179 /// Allocates the underlying buffer using the right call. `allocatedBytePtr` 180 /// is set to null for stack allocations. `accessAlignment` is set if 181 /// alignment is needed post allocation (for eg. in conjunction with malloc). 182 std::tuple<Value, Value> allocateBuffer(ConversionPatternRewriter &rewriter, 183 Location loc, Value sizeBytes, 184 Operation *op) const override { 185 186 // With alloca, one gets a pointer to the element type right away. 187 // For stack allocations. 188 auto allocaOp = cast<memref::AllocaOp>(op); 189 auto elementPtrType = this->getElementPtrType(allocaOp.getType()); 190 191 auto allocatedElementPtr = rewriter.create<LLVM::AllocaOp>( 192 loc, elementPtrType, sizeBytes, 193 allocaOp.getAlignment() ? *allocaOp.getAlignment() : 0); 194 195 return std::make_tuple(allocatedElementPtr, allocatedElementPtr); 196 } 197 }; 198 199 struct AllocaScopeOpLowering 200 : public ConvertOpToLLVMPattern<memref::AllocaScopeOp> { 201 using ConvertOpToLLVMPattern<memref::AllocaScopeOp>::ConvertOpToLLVMPattern; 202 203 LogicalResult 204 matchAndRewrite(memref::AllocaScopeOp allocaScopeOp, OpAdaptor adaptor, 205 ConversionPatternRewriter &rewriter) const override { 206 OpBuilder::InsertionGuard guard(rewriter); 207 Location loc = allocaScopeOp.getLoc(); 208 209 // Split the current block before the AllocaScopeOp to create the inlining 210 // point. 211 auto *currentBlock = rewriter.getInsertionBlock(); 212 auto *remainingOpsBlock = 213 rewriter.splitBlock(currentBlock, rewriter.getInsertionPoint()); 214 Block *continueBlock; 215 if (allocaScopeOp.getNumResults() == 0) { 216 continueBlock = remainingOpsBlock; 217 } else { 218 continueBlock = rewriter.createBlock( 219 remainingOpsBlock, allocaScopeOp.getResultTypes(), 220 SmallVector<Location>(allocaScopeOp->getNumResults(), 221 allocaScopeOp.getLoc())); 222 rewriter.create<LLVM::BrOp>(loc, ValueRange(), remainingOpsBlock); 223 } 224 225 // Inline body region. 226 Block *beforeBody = &allocaScopeOp.getBodyRegion().front(); 227 Block *afterBody = &allocaScopeOp.getBodyRegion().back(); 228 rewriter.inlineRegionBefore(allocaScopeOp.getBodyRegion(), continueBlock); 229 230 // Save stack and then branch into the body of the region. 
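    // Roughly, the resulting CFG looks like this (sketch only, block names
    // are illustrative):
    //
    //   currentBlock:
    //     %sp = llvm.intr.stacksave
    //     llvm.br ^beforeBody
    //   ... inlined body blocks ...
    //   afterBody:
    //     llvm.intr.stackrestore %sp
    //     llvm.br ^continueBlock(results...)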
231 rewriter.setInsertionPointToEnd(currentBlock); 232 auto stackSaveOp = 233 rewriter.create<LLVM::StackSaveOp>(loc, getVoidPtrType()); 234 rewriter.create<LLVM::BrOp>(loc, ValueRange(), beforeBody); 235 236 // Replace the alloca_scope return with a branch that jumps out of the body. 237 // Stack restore before leaving the body region. 238 rewriter.setInsertionPointToEnd(afterBody); 239 auto returnOp = 240 cast<memref::AllocaScopeReturnOp>(afterBody->getTerminator()); 241 auto branchOp = rewriter.replaceOpWithNewOp<LLVM::BrOp>( 242 returnOp, returnOp.getResults(), continueBlock); 243 244 // Insert stack restore before jumping out the body of the region. 245 rewriter.setInsertionPoint(branchOp); 246 rewriter.create<LLVM::StackRestoreOp>(loc, stackSaveOp); 247 248 // Replace the op with values return from the body region. 249 rewriter.replaceOp(allocaScopeOp, continueBlock->getArguments()); 250 251 return success(); 252 } 253 }; 254 255 struct AssumeAlignmentOpLowering 256 : public ConvertOpToLLVMPattern<memref::AssumeAlignmentOp> { 257 using ConvertOpToLLVMPattern< 258 memref::AssumeAlignmentOp>::ConvertOpToLLVMPattern; 259 260 LogicalResult 261 matchAndRewrite(memref::AssumeAlignmentOp op, OpAdaptor adaptor, 262 ConversionPatternRewriter &rewriter) const override { 263 Value memref = adaptor.getMemref(); 264 unsigned alignment = op.getAlignment(); 265 auto loc = op.getLoc(); 266 267 MemRefDescriptor memRefDescriptor(memref); 268 Value ptr = memRefDescriptor.alignedPtr(rewriter, memref.getLoc()); 269 270 // Emit llvm.assume(memref.alignedPtr & (alignment - 1) == 0). Notice that 271 // the asserted memref.alignedPtr isn't used anywhere else, as the real 272 // users like load/store/views always re-extract memref.alignedPtr as they 273 // get lowered. 274 // 275 // This relies on LLVM's CSE optimization (potentially after SROA), since 276 // after CSE all memref.alignedPtr instances get de-duplicated into the same 277 // pointer SSA value. 278 auto intPtrType = 279 getIntPtrType(memRefDescriptor.getElementPtrType().getAddressSpace()); 280 Value zero = createIndexAttrConstant(rewriter, loc, intPtrType, 0); 281 Value mask = 282 createIndexAttrConstant(rewriter, loc, intPtrType, alignment - 1); 283 Value ptrValue = rewriter.create<LLVM::PtrToIntOp>(loc, intPtrType, ptr); 284 rewriter.create<LLVM::AssumeOp>( 285 loc, rewriter.create<LLVM::ICmpOp>( 286 loc, LLVM::ICmpPredicate::eq, 287 rewriter.create<LLVM::AndOp>(loc, ptrValue, mask), zero)); 288 289 rewriter.eraseOp(op); 290 return success(); 291 } 292 }; 293 294 // A `dealloc` is converted into a call to `free` on the underlying data buffer. 295 // The memref descriptor being an SSA value, there is no need to clean it up 296 // in any way. 297 struct DeallocOpLowering : public ConvertOpToLLVMPattern<memref::DeallocOp> { 298 using ConvertOpToLLVMPattern<memref::DeallocOp>::ConvertOpToLLVMPattern; 299 300 explicit DeallocOpLowering(LLVMTypeConverter &converter) 301 : ConvertOpToLLVMPattern<memref::DeallocOp>(converter) {} 302 303 LogicalResult 304 matchAndRewrite(memref::DeallocOp op, OpAdaptor adaptor, 305 ConversionPatternRewriter &rewriter) const override { 306 // Insert the `free` declaration if it is not already present. 
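    // The call below frees the *allocated* pointer stored in the descriptor
    // (not the aligned one), bitcast to the generic i8 pointer type that the
    // `free` declaration expects.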
307 auto freeFunc = LLVM::lookupOrCreateFreeFn(op->getParentOfType<ModuleOp>()); 308 MemRefDescriptor memref(adaptor.getMemref()); 309 Value casted = rewriter.create<LLVM::BitcastOp>( 310 op.getLoc(), getVoidPtrType(), 311 memref.allocatedPtr(rewriter, op.getLoc())); 312 rewriter.replaceOpWithNewOp<LLVM::CallOp>( 313 op, TypeRange(), SymbolRefAttr::get(freeFunc), casted); 314 return success(); 315 } 316 }; 317 318 // A `dim` is converted to a constant for static sizes and to an access to the 319 // size stored in the memref descriptor for dynamic sizes. 320 struct DimOpLowering : public ConvertOpToLLVMPattern<memref::DimOp> { 321 using ConvertOpToLLVMPattern<memref::DimOp>::ConvertOpToLLVMPattern; 322 323 LogicalResult 324 matchAndRewrite(memref::DimOp dimOp, OpAdaptor adaptor, 325 ConversionPatternRewriter &rewriter) const override { 326 Type operandType = dimOp.getSource().getType(); 327 if (operandType.isa<UnrankedMemRefType>()) { 328 rewriter.replaceOp( 329 dimOp, {extractSizeOfUnrankedMemRef( 330 operandType, dimOp, adaptor.getOperands(), rewriter)}); 331 332 return success(); 333 } 334 if (operandType.isa<MemRefType>()) { 335 rewriter.replaceOp( 336 dimOp, {extractSizeOfRankedMemRef(operandType, dimOp, 337 adaptor.getOperands(), rewriter)}); 338 return success(); 339 } 340 llvm_unreachable("expected MemRefType or UnrankedMemRefType"); 341 } 342 343 private: 344 Value extractSizeOfUnrankedMemRef(Type operandType, memref::DimOp dimOp, 345 OpAdaptor adaptor, 346 ConversionPatternRewriter &rewriter) const { 347 Location loc = dimOp.getLoc(); 348 349 auto unrankedMemRefType = operandType.cast<UnrankedMemRefType>(); 350 auto scalarMemRefType = 351 MemRefType::get({}, unrankedMemRefType.getElementType()); 352 unsigned addressSpace = unrankedMemRefType.getMemorySpaceAsInt(); 353 354 // Extract pointer to the underlying ranked descriptor and bitcast it to a 355 // memref<element_type> descriptor pointer to minimize the number of GEP 356 // operations. 357 UnrankedMemRefDescriptor unrankedDesc(adaptor.getSource()); 358 Value underlyingRankedDesc = unrankedDesc.memRefDescPtr(rewriter, loc); 359 Value scalarMemRefDescPtr = rewriter.create<LLVM::BitcastOp>( 360 loc, 361 LLVM::LLVMPointerType::get(typeConverter->convertType(scalarMemRefType), 362 addressSpace), 363 underlyingRankedDesc); 364 365 // Get pointer to offset field of memref<element_type> descriptor. 366 Type indexPtrTy = LLVM::LLVMPointerType::get( 367 getTypeConverter()->getIndexType(), addressSpace); 368 Value two = rewriter.create<LLVM::ConstantOp>( 369 loc, typeConverter->convertType(rewriter.getI32Type()), 370 rewriter.getI32IntegerAttr(2)); 371 Value offsetPtr = rewriter.create<LLVM::GEPOp>( 372 loc, indexPtrTy, scalarMemRefDescPtr, 373 ValueRange({createIndexConstant(rewriter, loc, 0), two})); 374 375 // The size value that we have to extract can be obtained using GEPop with 376 // `dimOp.index() + 1` index argument. 
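    // This works because the converted ranked descriptor is laid out as
    // { allocated ptr, aligned ptr, offset, sizes[rank], strides[rank] }:
    // `offsetPtr` above points at the offset field (index 2), so stepping
    // `1 + dimOp.index()` index-sized slots past it lands on sizes[index].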
377 Value idxPlusOne = rewriter.create<LLVM::AddOp>( 378 loc, createIndexConstant(rewriter, loc, 1), adaptor.getIndex()); 379 Value sizePtr = rewriter.create<LLVM::GEPOp>(loc, indexPtrTy, offsetPtr, 380 ValueRange({idxPlusOne})); 381 return rewriter.create<LLVM::LoadOp>(loc, sizePtr); 382 } 383 384 Optional<int64_t> getConstantDimIndex(memref::DimOp dimOp) const { 385 if (Optional<int64_t> idx = dimOp.getConstantIndex()) 386 return idx; 387 388 if (auto constantOp = dimOp.getIndex().getDefiningOp<LLVM::ConstantOp>()) 389 return constantOp.getValue() 390 .cast<IntegerAttr>() 391 .getValue() 392 .getSExtValue(); 393 394 return llvm::None; 395 } 396 397 Value extractSizeOfRankedMemRef(Type operandType, memref::DimOp dimOp, 398 OpAdaptor adaptor, 399 ConversionPatternRewriter &rewriter) const { 400 Location loc = dimOp.getLoc(); 401 402 // Take advantage if index is constant. 403 MemRefType memRefType = operandType.cast<MemRefType>(); 404 if (Optional<int64_t> index = getConstantDimIndex(dimOp)) { 405 int64_t i = *index; 406 if (memRefType.isDynamicDim(i)) { 407 // extract dynamic size from the memref descriptor. 408 MemRefDescriptor descriptor(adaptor.getSource()); 409 return descriptor.size(rewriter, loc, i); 410 } 411 // Use constant for static size. 412 int64_t dimSize = memRefType.getDimSize(i); 413 return createIndexConstant(rewriter, loc, dimSize); 414 } 415 Value index = adaptor.getIndex(); 416 int64_t rank = memRefType.getRank(); 417 MemRefDescriptor memrefDescriptor(adaptor.getSource()); 418 return memrefDescriptor.size(rewriter, loc, index, rank); 419 } 420 }; 421 422 /// Common base for load and store operations on MemRefs. Restricts the match 423 /// to supported MemRef types. Provides functionality to emit code accessing a 424 /// specific element of the underlying data buffer. 425 template <typename Derived> 426 struct LoadStoreOpLowering : public ConvertOpToLLVMPattern<Derived> { 427 using ConvertOpToLLVMPattern<Derived>::ConvertOpToLLVMPattern; 428 using ConvertOpToLLVMPattern<Derived>::isConvertibleAndHasIdentityMaps; 429 using Base = LoadStoreOpLowering<Derived>; 430 431 LogicalResult match(Derived op) const override { 432 MemRefType type = op.getMemRefType(); 433 return isConvertibleAndHasIdentityMaps(type) ? success() : failure(); 434 } 435 }; 436 437 /// Wrap a llvm.cmpxchg operation in a while loop so that the operation can be 438 /// retried until it succeeds in atomically storing a new value into memory. 
///
///      +---------------------------------+
///      |   <code before the AtomicRMWOp> |
///      |   <compute initial %loaded>     |
///      |   cf.br loop(%loaded)           |
///      +---------------------------------+
///             |
///  -------|   |
///  |      v   v
///  |   +--------------------------------+
///  |   | loop(%loaded):                 |
///  |   |   <body contents>              |
///  |   |   %pair = cmpxchg              |
///  |   |   %new = %pair[0]              |
///  |   |   %ok = %pair[1]               |
///  |   |   cf.cond_br %ok, end, loop(%new) |
///  |   +--------------------------------+
///  |           |       |
///  |-----------        |
///                      v
///      +--------------------------------+
///      | end:                           |
///      |   <code after the AtomicRMWOp> |
///      +--------------------------------+
///
struct GenericAtomicRMWOpLowering
    : public LoadStoreOpLowering<memref::GenericAtomicRMWOp> {
  using Base::Base;

  LogicalResult
  matchAndRewrite(memref::GenericAtomicRMWOp atomicOp, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    auto loc = atomicOp.getLoc();
    Type valueType = typeConverter->convertType(atomicOp.getResult().getType());

    // Split the block into initial, loop, and ending parts.
    auto *initBlock = rewriter.getInsertionBlock();
    auto *loopBlock = rewriter.createBlock(
        initBlock->getParent(), std::next(Region::iterator(initBlock)),
        valueType, loc);
    auto *endBlock = rewriter.createBlock(
        loopBlock->getParent(), std::next(Region::iterator(loopBlock)));

    // Operations range to be moved to `endBlock`.
    auto opsToMoveStart = atomicOp->getIterator();
    auto opsToMoveEnd = initBlock->back().getIterator();

    // Compute the loaded value and branch to the loop block.
    rewriter.setInsertionPointToEnd(initBlock);
    auto memRefType = atomicOp.getMemref().getType().cast<MemRefType>();
    auto dataPtr = getStridedElementPtr(loc, memRefType, adaptor.getMemref(),
                                        adaptor.getIndices(), rewriter);
    Value init = rewriter.create<LLVM::LoadOp>(loc, dataPtr);
    rewriter.create<LLVM::BrOp>(loc, init, loopBlock);

    // Prepare the body of the loop block.
    rewriter.setInsertionPointToStart(loopBlock);

    // Clone the GenericAtomicRMWOp region and extract the result.
    auto loopArgument = loopBlock->getArgument(0);
    BlockAndValueMapping mapping;
    mapping.map(atomicOp.getCurrentValue(), loopArgument);
    Block &entryBlock = atomicOp.body().front();
    for (auto &nestedOp : entryBlock.without_terminator()) {
      Operation *clone = rewriter.clone(nestedOp, mapping);
      mapping.map(nestedOp.getResults(), clone->getResults());
    }
    Value result = mapping.lookup(entryBlock.getTerminator()->getOperand(0));

    // Prepare the epilog of the loop block.
    // Append the cmpxchg op to the end of the loop block.
    auto successOrdering = LLVM::AtomicOrdering::acq_rel;
    auto failureOrdering = LLVM::AtomicOrdering::monotonic;
    auto boolType = IntegerType::get(rewriter.getContext(), 1);
    auto pairType = LLVM::LLVMStructType::getLiteral(rewriter.getContext(),
                                                     {valueType, boolType});
    auto cmpxchg = rewriter.create<LLVM::AtomicCmpXchgOp>(
        loc, pairType, dataPtr, loopArgument, result, successOrdering,
        failureOrdering);
    // Extract the %new_loaded and %ok values from the pair.
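    // LLVM's cmpxchg yields a {value, i1} pair: element 0 is the value that
    // was actually loaded from memory and element 1 is the success flag, so
    // the extracted %new_loaded feeds the next loop iteration while %ok
    // decides whether to exit.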
    Value newLoaded = rewriter.create<LLVM::ExtractValueOp>(
        loc, valueType, cmpxchg, rewriter.getI64ArrayAttr({0}));
    Value ok = rewriter.create<LLVM::ExtractValueOp>(
        loc, boolType, cmpxchg, rewriter.getI64ArrayAttr({1}));

    // Conditionally branch to the end or back to the loop depending on %ok.
    rewriter.create<LLVM::CondBrOp>(loc, ok, endBlock, ArrayRef<Value>(),
                                    loopBlock, newLoaded);

    rewriter.setInsertionPointToEnd(endBlock);
    moveOpsRange(atomicOp.getResult(), newLoaded, std::next(opsToMoveStart),
                 std::next(opsToMoveEnd), rewriter);

    // The 'result' of the atomic_rmw op is the newly loaded value.
    rewriter.replaceOp(atomicOp, {newLoaded});

    return success();
  }

private:
  // Clones a segment of ops [start, end) and erases the original.
  void moveOpsRange(ValueRange oldResult, ValueRange newResult,
                    Block::iterator start, Block::iterator end,
                    ConversionPatternRewriter &rewriter) const {
    BlockAndValueMapping mapping;
    mapping.map(oldResult, newResult);
    SmallVector<Operation *, 2> opsToErase;
    for (auto it = start; it != end; ++it) {
      rewriter.clone(*it, mapping);
      opsToErase.push_back(&*it);
    }
    for (auto *it : opsToErase)
      rewriter.eraseOp(it);
  }
};

/// Returns the LLVM type of the global variable given the memref type `type`.
static Type convertGlobalMemrefTypeToLLVM(MemRefType type,
                                          LLVMTypeConverter &typeConverter) {
  // LLVM type for a global memref will be a multi-dimension array. For
  // declarations or uninitialized global memrefs, we can potentially flatten
  // this to a 1D array. However, for memref.global's with an initial value,
  // we do not intend to flatten the ElementsAttribute when going from std ->
  // LLVM dialect, so the LLVM type needs to be a multi-dimension array.
  Type elementType = typeConverter.convertType(type.getElementType());
  Type arrayTy = elementType;
  // Shape has the outermost dim at index 0, so need to walk it backwards.
  for (int64_t dim : llvm::reverse(type.getShape()))
    arrayTy = LLVM::LLVMArrayType::get(arrayTy, dim);
  return arrayTy;
}

/// GlobalMemrefOp is lowered to an LLVM Global Variable.
struct GlobalMemrefOpLowering
    : public ConvertOpToLLVMPattern<memref::GlobalOp> {
  using ConvertOpToLLVMPattern<memref::GlobalOp>::ConvertOpToLLVMPattern;

  LogicalResult
  matchAndRewrite(memref::GlobalOp global, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    MemRefType type = global.getType();
    if (!isConvertibleAndHasIdentityMaps(type))
      return failure();

    Type arrayTy = convertGlobalMemrefTypeToLLVM(type, *getTypeConverter());

    LLVM::Linkage linkage =
        global.isPublic() ? LLVM::Linkage::External : LLVM::Linkage::Private;

    Attribute initialValue = nullptr;
    if (!global.isExternal() && !global.isUninitialized()) {
      auto elementsAttr = global.getInitialValue()->cast<ElementsAttr>();
      initialValue = elementsAttr;

      // For scalar memrefs, the global variable created is of the element
      // type, so unpack the elements attribute to extract the value.
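      // For instance (illustrative only), a memref.global of type memref<f32>
      // initialized with dense<1.0> becomes an LLVM global of f32 holding
      // 1.0 rather than a zero-dimensional array.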
595 if (type.getRank() == 0) 596 initialValue = elementsAttr.getSplatValue<Attribute>(); 597 } 598 599 uint64_t alignment = global.getAlignment().value_or(0); 600 601 auto newGlobal = rewriter.replaceOpWithNewOp<LLVM::GlobalOp>( 602 global, arrayTy, global.getConstant(), linkage, global.getSymName(), 603 initialValue, alignment, type.getMemorySpaceAsInt()); 604 if (!global.isExternal() && global.isUninitialized()) { 605 Block *blk = new Block(); 606 newGlobal.getInitializerRegion().push_back(blk); 607 rewriter.setInsertionPointToStart(blk); 608 Value undef[] = { 609 rewriter.create<LLVM::UndefOp>(global.getLoc(), arrayTy)}; 610 rewriter.create<LLVM::ReturnOp>(global.getLoc(), undef); 611 } 612 return success(); 613 } 614 }; 615 616 /// GetGlobalMemrefOp is lowered into a Memref descriptor with the pointer to 617 /// the first element stashed into the descriptor. This reuses 618 /// `AllocLikeOpLowering` to reuse the Memref descriptor construction. 619 struct GetGlobalMemrefOpLowering : public AllocLikeOpLLVMLowering { 620 GetGlobalMemrefOpLowering(LLVMTypeConverter &converter) 621 : AllocLikeOpLLVMLowering(memref::GetGlobalOp::getOperationName(), 622 converter) {} 623 624 /// Buffer "allocation" for memref.get_global op is getting the address of 625 /// the global variable referenced. 626 std::tuple<Value, Value> allocateBuffer(ConversionPatternRewriter &rewriter, 627 Location loc, Value sizeBytes, 628 Operation *op) const override { 629 auto getGlobalOp = cast<memref::GetGlobalOp>(op); 630 MemRefType type = getGlobalOp.getResult().getType().cast<MemRefType>(); 631 unsigned memSpace = type.getMemorySpaceAsInt(); 632 633 Type arrayTy = convertGlobalMemrefTypeToLLVM(type, *getTypeConverter()); 634 auto addressOf = rewriter.create<LLVM::AddressOfOp>( 635 loc, LLVM::LLVMPointerType::get(arrayTy, memSpace), 636 getGlobalOp.getName()); 637 638 // Get the address of the first element in the array by creating a GEP with 639 // the address of the GV as the base, and (rank + 1) number of 0 indices. 640 Type elementType = typeConverter->convertType(type.getElementType()); 641 Type elementPtrType = LLVM::LLVMPointerType::get(elementType, memSpace); 642 643 SmallVector<Value> operands; 644 operands.insert(operands.end(), type.getRank() + 1, 645 createIndexConstant(rewriter, loc, 0)); 646 auto gep = 647 rewriter.create<LLVM::GEPOp>(loc, elementPtrType, addressOf, operands); 648 649 // We do not expect the memref obtained using `memref.get_global` to be 650 // ever deallocated. Set the allocated pointer to be known bad value to 651 // help debug if that ever happens. 652 auto intPtrType = getIntPtrType(memSpace); 653 Value deadBeefConst = 654 createIndexAttrConstant(rewriter, op->getLoc(), intPtrType, 0xdeadbeef); 655 auto deadBeefPtr = 656 rewriter.create<LLVM::IntToPtrOp>(loc, elementPtrType, deadBeefConst); 657 658 // Both allocated and aligned pointers are same. We could potentially stash 659 // a nullptr for the allocated pointer since we do not expect any dealloc. 660 return std::make_tuple(deadBeefPtr, gep); 661 } 662 }; 663 664 // Load operation is lowered to obtaining a pointer to the indexed element 665 // and loading it. 
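// Roughly, for a 2-D memref with strides [s0, s1] and offset `off`, the
// accessed address is
//   alignedPtr + (off + i0 * s0 + i1 * s1) * sizeof(element),
// and getStridedElementPtr emits this index arithmetic followed by a GEP,
// taking the strides and offset either from constants or from the descriptor.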
666 struct LoadOpLowering : public LoadStoreOpLowering<memref::LoadOp> { 667 using Base::Base; 668 669 LogicalResult 670 matchAndRewrite(memref::LoadOp loadOp, OpAdaptor adaptor, 671 ConversionPatternRewriter &rewriter) const override { 672 auto type = loadOp.getMemRefType(); 673 674 Value dataPtr = 675 getStridedElementPtr(loadOp.getLoc(), type, adaptor.getMemref(), 676 adaptor.getIndices(), rewriter); 677 rewriter.replaceOpWithNewOp<LLVM::LoadOp>(loadOp, dataPtr); 678 return success(); 679 } 680 }; 681 682 // Store operation is lowered to obtaining a pointer to the indexed element, 683 // and storing the given value to it. 684 struct StoreOpLowering : public LoadStoreOpLowering<memref::StoreOp> { 685 using Base::Base; 686 687 LogicalResult 688 matchAndRewrite(memref::StoreOp op, OpAdaptor adaptor, 689 ConversionPatternRewriter &rewriter) const override { 690 auto type = op.getMemRefType(); 691 692 Value dataPtr = getStridedElementPtr(op.getLoc(), type, adaptor.getMemref(), 693 adaptor.getIndices(), rewriter); 694 rewriter.replaceOpWithNewOp<LLVM::StoreOp>(op, adaptor.getValue(), dataPtr); 695 return success(); 696 } 697 }; 698 699 // The prefetch operation is lowered in a way similar to the load operation 700 // except that the llvm.prefetch operation is used for replacement. 701 struct PrefetchOpLowering : public LoadStoreOpLowering<memref::PrefetchOp> { 702 using Base::Base; 703 704 LogicalResult 705 matchAndRewrite(memref::PrefetchOp prefetchOp, OpAdaptor adaptor, 706 ConversionPatternRewriter &rewriter) const override { 707 auto type = prefetchOp.getMemRefType(); 708 auto loc = prefetchOp.getLoc(); 709 710 Value dataPtr = getStridedElementPtr(loc, type, adaptor.getMemref(), 711 adaptor.getIndices(), rewriter); 712 713 // Replace with llvm.prefetch. 
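    // The prefetch intrinsic takes three i32 immediates besides the address:
    // a read/write flag (0 = read, 1 = write), a locality hint in [0, 3]
    // (3 = keep in cache as long as possible), and a cache-type flag
    // (1 = data cache, 0 = instruction cache). The constants below carry the
    // corresponding attributes of the memref.prefetch op.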
714 auto llvmI32Type = typeConverter->convertType(rewriter.getIntegerType(32)); 715 auto isWrite = rewriter.create<LLVM::ConstantOp>( 716 loc, llvmI32Type, rewriter.getI32IntegerAttr(prefetchOp.getIsWrite())); 717 auto localityHint = rewriter.create<LLVM::ConstantOp>( 718 loc, llvmI32Type, 719 rewriter.getI32IntegerAttr(prefetchOp.getLocalityHint())); 720 auto isData = rewriter.create<LLVM::ConstantOp>( 721 loc, llvmI32Type, 722 rewriter.getI32IntegerAttr(prefetchOp.getIsDataCache())); 723 724 rewriter.replaceOpWithNewOp<LLVM::Prefetch>(prefetchOp, dataPtr, isWrite, 725 localityHint, isData); 726 return success(); 727 } 728 }; 729 730 struct RankOpLowering : public ConvertOpToLLVMPattern<memref::RankOp> { 731 using ConvertOpToLLVMPattern<memref::RankOp>::ConvertOpToLLVMPattern; 732 733 LogicalResult 734 matchAndRewrite(memref::RankOp op, OpAdaptor adaptor, 735 ConversionPatternRewriter &rewriter) const override { 736 Location loc = op.getLoc(); 737 Type operandType = op.getMemref().getType(); 738 if (auto unrankedMemRefType = operandType.dyn_cast<UnrankedMemRefType>()) { 739 UnrankedMemRefDescriptor desc(adaptor.getMemref()); 740 rewriter.replaceOp(op, {desc.rank(rewriter, loc)}); 741 return success(); 742 } 743 if (auto rankedMemRefType = operandType.dyn_cast<MemRefType>()) { 744 rewriter.replaceOp( 745 op, {createIndexConstant(rewriter, loc, rankedMemRefType.getRank())}); 746 return success(); 747 } 748 return failure(); 749 } 750 }; 751 752 struct MemRefCastOpLowering : public ConvertOpToLLVMPattern<memref::CastOp> { 753 using ConvertOpToLLVMPattern<memref::CastOp>::ConvertOpToLLVMPattern; 754 755 LogicalResult match(memref::CastOp memRefCastOp) const override { 756 Type srcType = memRefCastOp.getOperand().getType(); 757 Type dstType = memRefCastOp.getType(); 758 759 // memref::CastOp reduce to bitcast in the ranked MemRef case and can be 760 // used for type erasure. For now they must preserve underlying element type 761 // and require source and result type to have the same rank. Therefore, 762 // perform a sanity check that the underlying structs are the same. Once op 763 // semantics are relaxed we can revisit. 764 if (srcType.isa<MemRefType>() && dstType.isa<MemRefType>()) 765 return success(typeConverter->convertType(srcType) == 766 typeConverter->convertType(dstType)); 767 768 // At least one of the operands is unranked type 769 assert(srcType.isa<UnrankedMemRefType>() || 770 dstType.isa<UnrankedMemRefType>()); 771 772 // Unranked to unranked cast is disallowed 773 return !(srcType.isa<UnrankedMemRefType>() && 774 dstType.isa<UnrankedMemRefType>()) 775 ? success() 776 : failure(); 777 } 778 779 void rewrite(memref::CastOp memRefCastOp, OpAdaptor adaptor, 780 ConversionPatternRewriter &rewriter) const override { 781 auto srcType = memRefCastOp.getOperand().getType(); 782 auto dstType = memRefCastOp.getType(); 783 auto targetStructType = typeConverter->convertType(memRefCastOp.getType()); 784 auto loc = memRefCastOp.getLoc(); 785 786 // For ranked/ranked case, just keep the original descriptor. 
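    // A ranked-to-ranked cast only relaxes or tightens static shape
    // information; match() already verified that both sides convert to the
    // same LLVM descriptor struct, so the source descriptor can be reused
    // as-is.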
787 if (srcType.isa<MemRefType>() && dstType.isa<MemRefType>()) 788 return rewriter.replaceOp(memRefCastOp, {adaptor.getSource()}); 789 790 if (srcType.isa<MemRefType>() && dstType.isa<UnrankedMemRefType>()) { 791 // Casting ranked to unranked memref type 792 // Set the rank in the destination from the memref type 793 // Allocate space on the stack and copy the src memref descriptor 794 // Set the ptr in the destination to the stack space 795 auto srcMemRefType = srcType.cast<MemRefType>(); 796 int64_t rank = srcMemRefType.getRank(); 797 // ptr = AllocaOp sizeof(MemRefDescriptor) 798 auto ptr = getTypeConverter()->promoteOneMemRefDescriptor( 799 loc, adaptor.getSource(), rewriter); 800 // voidptr = BitCastOp srcType* to void* 801 auto voidPtr = 802 rewriter.create<LLVM::BitcastOp>(loc, getVoidPtrType(), ptr) 803 .getResult(); 804 // rank = ConstantOp srcRank 805 auto rankVal = rewriter.create<LLVM::ConstantOp>( 806 loc, getIndexType(), rewriter.getIndexAttr(rank)); 807 // undef = UndefOp 808 UnrankedMemRefDescriptor memRefDesc = 809 UnrankedMemRefDescriptor::undef(rewriter, loc, targetStructType); 810 // d1 = InsertValueOp undef, rank, 0 811 memRefDesc.setRank(rewriter, loc, rankVal); 812 // d2 = InsertValueOp d1, voidptr, 1 813 memRefDesc.setMemRefDescPtr(rewriter, loc, voidPtr); 814 rewriter.replaceOp(memRefCastOp, (Value)memRefDesc); 815 816 } else if (srcType.isa<UnrankedMemRefType>() && dstType.isa<MemRefType>()) { 817 // Casting from unranked type to ranked. 818 // The operation is assumed to be doing a correct cast. If the destination 819 // type mismatches the unranked the type, it is undefined behavior. 820 UnrankedMemRefDescriptor memRefDesc(adaptor.getSource()); 821 // ptr = ExtractValueOp src, 1 822 auto ptr = memRefDesc.memRefDescPtr(rewriter, loc); 823 // castPtr = BitCastOp i8* to structTy* 824 auto castPtr = 825 rewriter 826 .create<LLVM::BitcastOp>( 827 loc, LLVM::LLVMPointerType::get(targetStructType), ptr) 828 .getResult(); 829 // struct = LoadOp castPtr 830 auto loadOp = rewriter.create<LLVM::LoadOp>(loc, castPtr); 831 rewriter.replaceOp(memRefCastOp, loadOp.getResult()); 832 } else { 833 llvm_unreachable("Unsupported unranked memref to unranked memref cast"); 834 } 835 } 836 }; 837 838 /// Pattern to lower a `memref.copy` to llvm. 839 /// 840 /// For memrefs with identity layouts, the copy is lowered to the llvm 841 /// `memcpy` intrinsic. For non-identity layouts, the copy is lowered to a call 842 /// to the generic `MemrefCopyFn`. 843 struct MemRefCopyOpLowering : public ConvertOpToLLVMPattern<memref::CopyOp> { 844 using ConvertOpToLLVMPattern<memref::CopyOp>::ConvertOpToLLVMPattern; 845 846 LogicalResult 847 lowerToMemCopyIntrinsic(memref::CopyOp op, OpAdaptor adaptor, 848 ConversionPatternRewriter &rewriter) const { 849 auto loc = op.getLoc(); 850 auto srcType = op.getSource().getType().dyn_cast<MemRefType>(); 851 852 MemRefDescriptor srcDesc(adaptor.getSource()); 853 854 // Compute number of elements. 855 Value numElements = rewriter.create<LLVM::ConstantOp>( 856 loc, getIndexType(), rewriter.getIndexAttr(1)); 857 for (int pos = 0; pos < srcType.getRank(); ++pos) { 858 auto size = srcDesc.size(rewriter, loc, pos); 859 numElements = rewriter.create<LLVM::MulOp>(loc, numElements, size); 860 } 861 862 // Get element size. 863 auto sizeInBytes = getSizeInBytes(loc, srcType.getElementType(), rewriter); 864 // Compute total. 
865 Value totalSize = 866 rewriter.create<LLVM::MulOp>(loc, numElements, sizeInBytes); 867 868 Value srcBasePtr = srcDesc.alignedPtr(rewriter, loc); 869 Value srcOffset = srcDesc.offset(rewriter, loc); 870 Value srcPtr = rewriter.create<LLVM::GEPOp>(loc, srcBasePtr.getType(), 871 srcBasePtr, srcOffset); 872 MemRefDescriptor targetDesc(adaptor.getTarget()); 873 Value targetBasePtr = targetDesc.alignedPtr(rewriter, loc); 874 Value targetOffset = targetDesc.offset(rewriter, loc); 875 Value targetPtr = rewriter.create<LLVM::GEPOp>(loc, targetBasePtr.getType(), 876 targetBasePtr, targetOffset); 877 Value isVolatile = rewriter.create<LLVM::ConstantOp>( 878 loc, typeConverter->convertType(rewriter.getI1Type()), 879 rewriter.getBoolAttr(false)); 880 rewriter.create<LLVM::MemcpyOp>(loc, targetPtr, srcPtr, totalSize, 881 isVolatile); 882 rewriter.eraseOp(op); 883 884 return success(); 885 } 886 887 LogicalResult 888 lowerToMemCopyFunctionCall(memref::CopyOp op, OpAdaptor adaptor, 889 ConversionPatternRewriter &rewriter) const { 890 auto loc = op.getLoc(); 891 auto srcType = op.getSource().getType().cast<BaseMemRefType>(); 892 auto targetType = op.getTarget().getType().cast<BaseMemRefType>(); 893 894 // First make sure we have an unranked memref descriptor representation. 895 auto makeUnranked = [&, this](Value ranked, BaseMemRefType type) { 896 auto rank = rewriter.create<LLVM::ConstantOp>( 897 loc, getIndexType(), rewriter.getIndexAttr(type.getRank())); 898 auto *typeConverter = getTypeConverter(); 899 auto ptr = 900 typeConverter->promoteOneMemRefDescriptor(loc, ranked, rewriter); 901 auto voidPtr = 902 rewriter.create<LLVM::BitcastOp>(loc, getVoidPtrType(), ptr) 903 .getResult(); 904 auto unrankedType = 905 UnrankedMemRefType::get(type.getElementType(), type.getMemorySpace()); 906 return UnrankedMemRefDescriptor::pack(rewriter, loc, *typeConverter, 907 unrankedType, 908 ValueRange{rank, voidPtr}); 909 }; 910 911 Value unrankedSource = srcType.hasRank() 912 ? makeUnranked(adaptor.getSource(), srcType) 913 : adaptor.getSource(); 914 Value unrankedTarget = targetType.hasRank() 915 ? makeUnranked(adaptor.getTarget(), targetType) 916 : adaptor.getTarget(); 917 918 // Now promote the unranked descriptors to the stack. 
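    // The generic `MemrefCopyFn` takes the descriptors by pointer, so each
    // unranked descriptor value is written into a one-element alloca (the
    // `promote` lambda below), and those pointers, together with the element
    // size, become the arguments of the copy call.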
919 auto one = rewriter.create<LLVM::ConstantOp>(loc, getIndexType(), 920 rewriter.getIndexAttr(1)); 921 auto promote = [&](Value desc) { 922 auto ptrType = LLVM::LLVMPointerType::get(desc.getType()); 923 auto allocated = 924 rewriter.create<LLVM::AllocaOp>(loc, ptrType, ValueRange{one}); 925 rewriter.create<LLVM::StoreOp>(loc, desc, allocated); 926 return allocated; 927 }; 928 929 auto sourcePtr = promote(unrankedSource); 930 auto targetPtr = promote(unrankedTarget); 931 932 unsigned typeSize = 933 mlir::DataLayout::closest(op).getTypeSize(srcType.getElementType()); 934 auto elemSize = rewriter.create<LLVM::ConstantOp>( 935 loc, getIndexType(), rewriter.getIndexAttr(typeSize)); 936 auto copyFn = LLVM::lookupOrCreateMemRefCopyFn( 937 op->getParentOfType<ModuleOp>(), getIndexType(), sourcePtr.getType()); 938 rewriter.create<LLVM::CallOp>(loc, copyFn, 939 ValueRange{elemSize, sourcePtr, targetPtr}); 940 rewriter.eraseOp(op); 941 942 return success(); 943 } 944 945 LogicalResult 946 matchAndRewrite(memref::CopyOp op, OpAdaptor adaptor, 947 ConversionPatternRewriter &rewriter) const override { 948 auto srcType = op.getSource().getType().cast<BaseMemRefType>(); 949 auto targetType = op.getTarget().getType().cast<BaseMemRefType>(); 950 951 auto isContiguousMemrefType = [](BaseMemRefType type) { 952 auto memrefType = type.dyn_cast<mlir::MemRefType>(); 953 // We can use memcpy for memrefs if they have an identity layout or are 954 // contiguous with an arbitrary offset. Ignore empty memrefs, which is a 955 // special case handled by memrefCopy. 956 return memrefType && 957 (memrefType.getLayout().isIdentity() || 958 (memrefType.hasStaticShape() && memrefType.getNumElements() > 0 && 959 isStaticShapeAndContiguousRowMajor(memrefType))); 960 }; 961 962 if (isContiguousMemrefType(srcType) && isContiguousMemrefType(targetType)) 963 return lowerToMemCopyIntrinsic(op, adaptor, rewriter); 964 965 return lowerToMemCopyFunctionCall(op, adaptor, rewriter); 966 } 967 }; 968 969 /// Extracts allocated, aligned pointers and offset from a ranked or unranked 970 /// memref type. In unranked case, the fields are extracted from the underlying 971 /// ranked descriptor. 972 static void extractPointersAndOffset(Location loc, 973 ConversionPatternRewriter &rewriter, 974 LLVMTypeConverter &typeConverter, 975 Value originalOperand, 976 Value convertedOperand, 977 Value *allocatedPtr, Value *alignedPtr, 978 Value *offset = nullptr) { 979 Type operandType = originalOperand.getType(); 980 if (operandType.isa<MemRefType>()) { 981 MemRefDescriptor desc(convertedOperand); 982 *allocatedPtr = desc.allocatedPtr(rewriter, loc); 983 *alignedPtr = desc.alignedPtr(rewriter, loc); 984 if (offset != nullptr) 985 *offset = desc.offset(rewriter, loc); 986 return; 987 } 988 989 unsigned memorySpace = 990 operandType.cast<UnrankedMemRefType>().getMemorySpaceAsInt(); 991 Type elementType = operandType.cast<UnrankedMemRefType>().getElementType(); 992 Type llvmElementType = typeConverter.convertType(elementType); 993 Type elementPtrPtrType = LLVM::LLVMPointerType::get( 994 LLVM::LLVMPointerType::get(llvmElementType, memorySpace)); 995 996 // Extract pointer to the underlying ranked memref descriptor and cast it to 997 // ElemType**. 
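  // An unranked descriptor only stores the rank and an opaque pointer to a
  // ranked descriptor allocated elsewhere, so the allocated/aligned pointers
  // and the offset have to be read through that pointer with the static
  // UnrankedMemRefDescriptor helpers below.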
998 UnrankedMemRefDescriptor unrankedDesc(convertedOperand); 999 Value underlyingDescPtr = unrankedDesc.memRefDescPtr(rewriter, loc); 1000 1001 *allocatedPtr = UnrankedMemRefDescriptor::allocatedPtr( 1002 rewriter, loc, underlyingDescPtr, elementPtrPtrType); 1003 *alignedPtr = UnrankedMemRefDescriptor::alignedPtr( 1004 rewriter, loc, typeConverter, underlyingDescPtr, elementPtrPtrType); 1005 if (offset != nullptr) { 1006 *offset = UnrankedMemRefDescriptor::offset( 1007 rewriter, loc, typeConverter, underlyingDescPtr, elementPtrPtrType); 1008 } 1009 } 1010 1011 struct MemRefReinterpretCastOpLowering 1012 : public ConvertOpToLLVMPattern<memref::ReinterpretCastOp> { 1013 using ConvertOpToLLVMPattern< 1014 memref::ReinterpretCastOp>::ConvertOpToLLVMPattern; 1015 1016 LogicalResult 1017 matchAndRewrite(memref::ReinterpretCastOp castOp, OpAdaptor adaptor, 1018 ConversionPatternRewriter &rewriter) const override { 1019 Type srcType = castOp.getSource().getType(); 1020 1021 Value descriptor; 1022 if (failed(convertSourceMemRefToDescriptor(rewriter, srcType, castOp, 1023 adaptor, &descriptor))) 1024 return failure(); 1025 rewriter.replaceOp(castOp, {descriptor}); 1026 return success(); 1027 } 1028 1029 private: 1030 LogicalResult convertSourceMemRefToDescriptor( 1031 ConversionPatternRewriter &rewriter, Type srcType, 1032 memref::ReinterpretCastOp castOp, 1033 memref::ReinterpretCastOp::Adaptor adaptor, Value *descriptor) const { 1034 MemRefType targetMemRefType = 1035 castOp.getResult().getType().cast<MemRefType>(); 1036 auto llvmTargetDescriptorTy = typeConverter->convertType(targetMemRefType) 1037 .dyn_cast_or_null<LLVM::LLVMStructType>(); 1038 if (!llvmTargetDescriptorTy) 1039 return failure(); 1040 1041 // Create descriptor. 1042 Location loc = castOp.getLoc(); 1043 auto desc = MemRefDescriptor::undef(rewriter, loc, llvmTargetDescriptorTy); 1044 1045 // Set allocated and aligned pointers. 1046 Value allocatedPtr, alignedPtr; 1047 extractPointersAndOffset(loc, rewriter, *getTypeConverter(), 1048 castOp.getSource(), adaptor.getSource(), 1049 &allocatedPtr, &alignedPtr); 1050 desc.setAllocatedPtr(rewriter, loc, allocatedPtr); 1051 desc.setAlignedPtr(rewriter, loc, alignedPtr); 1052 1053 // Set offset. 1054 if (castOp.isDynamicOffset(0)) 1055 desc.setOffset(rewriter, loc, adaptor.getOffsets()[0]); 1056 else 1057 desc.setConstantOffset(rewriter, loc, castOp.getStaticOffset(0)); 1058 1059 // Set sizes and strides. 
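    // Static entries are materialized as constants while dynamic ones are
    // consumed from the converted size/stride operand lists in order. For
    // example (hypothetical op), a reinterpret_cast to sizes [10, %n] and
    // strides [%s, 1] takes size 0 and stride 1 from constants, and size 1 /
    // stride 0 from the first entries of adaptor.getSizes() /
    // adaptor.getStrides().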
1060 unsigned dynSizeId = 0; 1061 unsigned dynStrideId = 0; 1062 for (unsigned i = 0, e = targetMemRefType.getRank(); i < e; ++i) { 1063 if (castOp.isDynamicSize(i)) 1064 desc.setSize(rewriter, loc, i, adaptor.getSizes()[dynSizeId++]); 1065 else 1066 desc.setConstantSize(rewriter, loc, i, castOp.getStaticSize(i)); 1067 1068 if (castOp.isDynamicStride(i)) 1069 desc.setStride(rewriter, loc, i, adaptor.getStrides()[dynStrideId++]); 1070 else 1071 desc.setConstantStride(rewriter, loc, i, castOp.getStaticStride(i)); 1072 } 1073 *descriptor = desc; 1074 return success(); 1075 } 1076 }; 1077 1078 struct MemRefReshapeOpLowering 1079 : public ConvertOpToLLVMPattern<memref::ReshapeOp> { 1080 using ConvertOpToLLVMPattern<memref::ReshapeOp>::ConvertOpToLLVMPattern; 1081 1082 LogicalResult 1083 matchAndRewrite(memref::ReshapeOp reshapeOp, OpAdaptor adaptor, 1084 ConversionPatternRewriter &rewriter) const override { 1085 Type srcType = reshapeOp.getSource().getType(); 1086 1087 Value descriptor; 1088 if (failed(convertSourceMemRefToDescriptor(rewriter, srcType, reshapeOp, 1089 adaptor, &descriptor))) 1090 return failure(); 1091 rewriter.replaceOp(reshapeOp, {descriptor}); 1092 return success(); 1093 } 1094 1095 private: 1096 LogicalResult 1097 convertSourceMemRefToDescriptor(ConversionPatternRewriter &rewriter, 1098 Type srcType, memref::ReshapeOp reshapeOp, 1099 memref::ReshapeOp::Adaptor adaptor, 1100 Value *descriptor) const { 1101 auto shapeMemRefType = reshapeOp.getShape().getType().cast<MemRefType>(); 1102 if (shapeMemRefType.hasStaticShape()) { 1103 MemRefType targetMemRefType = 1104 reshapeOp.getResult().getType().cast<MemRefType>(); 1105 auto llvmTargetDescriptorTy = 1106 typeConverter->convertType(targetMemRefType) 1107 .dyn_cast_or_null<LLVM::LLVMStructType>(); 1108 if (!llvmTargetDescriptorTy) 1109 return failure(); 1110 1111 // Create descriptor. 1112 Location loc = reshapeOp.getLoc(); 1113 auto desc = 1114 MemRefDescriptor::undef(rewriter, loc, llvmTargetDescriptorTy); 1115 1116 // Set allocated and aligned pointers. 1117 Value allocatedPtr, alignedPtr; 1118 extractPointersAndOffset(loc, rewriter, *getTypeConverter(), 1119 reshapeOp.getSource(), adaptor.getSource(), 1120 &allocatedPtr, &alignedPtr); 1121 desc.setAllocatedPtr(rewriter, loc, allocatedPtr); 1122 desc.setAlignedPtr(rewriter, loc, alignedPtr); 1123 1124 // Extract the offset and strides from the type. 1125 int64_t offset; 1126 SmallVector<int64_t> strides; 1127 if (failed(getStridesAndOffset(targetMemRefType, strides, offset))) 1128 return rewriter.notifyMatchFailure( 1129 reshapeOp, "failed to get stride and offset exprs"); 1130 1131 if (!isStaticStrideOrOffset(offset)) 1132 return rewriter.notifyMatchFailure(reshapeOp, 1133 "dynamic offset is unsupported"); 1134 1135 desc.setConstantOffset(rewriter, loc, offset); 1136 1137 assert(targetMemRefType.getLayout().isIdentity() && 1138 "Identity layout map is a precondition of a valid reshape op"); 1139 1140 Value stride = nullptr; 1141 int64_t targetRank = targetMemRefType.getRank(); 1142 for (auto i : llvm::reverse(llvm::seq<int64_t>(0, targetRank))) { 1143 if (!ShapedType::isDynamicStrideOrOffset(strides[i])) { 1144 // If the stride for this dimension is dynamic, then use the product 1145 // of the sizes of the inner dimensions. 1146 stride = createIndexConstant(rewriter, loc, strides[i]); 1147 } else if (!stride) { 1148 // `stride` is null only in the first iteration of the loop. 
However, 1149 // since the target memref has an identity layout, we can safely set 1150 // the innermost stride to 1. 1151 stride = createIndexConstant(rewriter, loc, 1); 1152 } 1153 1154 Value dimSize; 1155 int64_t size = targetMemRefType.getDimSize(i); 1156 // If the size of this dimension is dynamic, then load it at runtime 1157 // from the shape operand. 1158 if (!ShapedType::isDynamic(size)) { 1159 dimSize = createIndexConstant(rewriter, loc, size); 1160 } else { 1161 Value shapeOp = reshapeOp.getShape(); 1162 Value index = createIndexConstant(rewriter, loc, i); 1163 dimSize = rewriter.create<memref::LoadOp>(loc, shapeOp, index); 1164 Type indexType = getIndexType(); 1165 if (dimSize.getType() != indexType) 1166 dimSize = typeConverter->materializeTargetConversion( 1167 rewriter, loc, indexType, dimSize); 1168 assert(dimSize && "Invalid memref element type"); 1169 } 1170 1171 desc.setSize(rewriter, loc, i, dimSize); 1172 desc.setStride(rewriter, loc, i, stride); 1173 1174 // Prepare the stride value for the next dimension. 1175 stride = rewriter.create<LLVM::MulOp>(loc, stride, dimSize); 1176 } 1177 1178 *descriptor = desc; 1179 return success(); 1180 } 1181 1182 // The shape is a rank-1 tensor with unknown length. 1183 Location loc = reshapeOp.getLoc(); 1184 MemRefDescriptor shapeDesc(adaptor.getShape()); 1185 Value resultRank = shapeDesc.size(rewriter, loc, 0); 1186 1187 // Extract address space and element type. 1188 auto targetType = 1189 reshapeOp.getResult().getType().cast<UnrankedMemRefType>(); 1190 unsigned addressSpace = targetType.getMemorySpaceAsInt(); 1191 Type elementType = targetType.getElementType(); 1192 1193 // Create the unranked memref descriptor that holds the ranked one. The 1194 // inner descriptor is allocated on stack. 1195 auto targetDesc = UnrankedMemRefDescriptor::undef( 1196 rewriter, loc, typeConverter->convertType(targetType)); 1197 targetDesc.setRank(rewriter, loc, resultRank); 1198 SmallVector<Value, 4> sizes; 1199 UnrankedMemRefDescriptor::computeSizes(rewriter, loc, *getTypeConverter(), 1200 targetDesc, sizes); 1201 Value underlyingDescPtr = rewriter.create<LLVM::AllocaOp>( 1202 loc, getVoidPtrType(), sizes.front(), llvm::None); 1203 targetDesc.setMemRefDescPtr(rewriter, loc, underlyingDescPtr); 1204 1205 // Extract pointers and offset from the source memref. 1206 Value allocatedPtr, alignedPtr, offset; 1207 extractPointersAndOffset(loc, rewriter, *getTypeConverter(), 1208 reshapeOp.getSource(), adaptor.getSource(), 1209 &allocatedPtr, &alignedPtr, &offset); 1210 1211 // Set pointers and offset. 1212 Type llvmElementType = typeConverter->convertType(elementType); 1213 auto elementPtrPtrType = LLVM::LLVMPointerType::get( 1214 LLVM::LLVMPointerType::get(llvmElementType, addressSpace)); 1215 UnrankedMemRefDescriptor::setAllocatedPtr(rewriter, loc, underlyingDescPtr, 1216 elementPtrPtrType, allocatedPtr); 1217 UnrankedMemRefDescriptor::setAlignedPtr(rewriter, loc, *getTypeConverter(), 1218 underlyingDescPtr, 1219 elementPtrPtrType, alignedPtr); 1220 UnrankedMemRefDescriptor::setOffset(rewriter, loc, *getTypeConverter(), 1221 underlyingDescPtr, elementPtrPtrType, 1222 offset); 1223 1224 // Use the offset pointer as base for further addressing. Copy over the new 1225 // shape and compute strides. For this, we create a loop from rank-1 to 0. 
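    // The loop is built as explicit blocks (sketch only, names illustrative):
    //   init:              br cond(%rank - 1, 1)
    //   cond(%i, %stride): %p = %i >= 0; cond_br %p, body, remainder
    //   body:              sizes[%i] = shape[%i]; strides[%i] = %stride;
    //                      br cond(%i - 1, %stride * shape[%i])
    // so each stride is the running product of the sizes processed so far,
    // starting from the innermost dimension.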
1226 Value targetSizesBase = UnrankedMemRefDescriptor::sizeBasePtr( 1227 rewriter, loc, *getTypeConverter(), underlyingDescPtr, 1228 elementPtrPtrType); 1229 Value targetStridesBase = UnrankedMemRefDescriptor::strideBasePtr( 1230 rewriter, loc, *getTypeConverter(), targetSizesBase, resultRank); 1231 Value shapeOperandPtr = shapeDesc.alignedPtr(rewriter, loc); 1232 Value oneIndex = createIndexConstant(rewriter, loc, 1); 1233 Value resultRankMinusOne = 1234 rewriter.create<LLVM::SubOp>(loc, resultRank, oneIndex); 1235 1236 Block *initBlock = rewriter.getInsertionBlock(); 1237 Type indexType = getTypeConverter()->getIndexType(); 1238 Block::iterator remainingOpsIt = std::next(rewriter.getInsertionPoint()); 1239 1240 Block *condBlock = rewriter.createBlock(initBlock->getParent(), {}, 1241 {indexType, indexType}, {loc, loc}); 1242 1243 // Move the remaining initBlock ops to condBlock. 1244 Block *remainingBlock = rewriter.splitBlock(initBlock, remainingOpsIt); 1245 rewriter.mergeBlocks(remainingBlock, condBlock, ValueRange()); 1246 1247 rewriter.setInsertionPointToEnd(initBlock); 1248 rewriter.create<LLVM::BrOp>(loc, ValueRange({resultRankMinusOne, oneIndex}), 1249 condBlock); 1250 rewriter.setInsertionPointToStart(condBlock); 1251 Value indexArg = condBlock->getArgument(0); 1252 Value strideArg = condBlock->getArgument(1); 1253 1254 Value zeroIndex = createIndexConstant(rewriter, loc, 0); 1255 Value pred = rewriter.create<LLVM::ICmpOp>( 1256 loc, IntegerType::get(rewriter.getContext(), 1), 1257 LLVM::ICmpPredicate::sge, indexArg, zeroIndex); 1258 1259 Block *bodyBlock = 1260 rewriter.splitBlock(condBlock, rewriter.getInsertionPoint()); 1261 rewriter.setInsertionPointToStart(bodyBlock); 1262 1263 // Copy size from shape to descriptor. 1264 Type llvmIndexPtrType = LLVM::LLVMPointerType::get(indexType); 1265 Value sizeLoadGep = rewriter.create<LLVM::GEPOp>( 1266 loc, llvmIndexPtrType, shapeOperandPtr, ValueRange{indexArg}); 1267 Value size = rewriter.create<LLVM::LoadOp>(loc, sizeLoadGep); 1268 UnrankedMemRefDescriptor::setSize(rewriter, loc, *getTypeConverter(), 1269 targetSizesBase, indexArg, size); 1270 1271 // Write stride value and compute next one. 1272 UnrankedMemRefDescriptor::setStride(rewriter, loc, *getTypeConverter(), 1273 targetStridesBase, indexArg, strideArg); 1274 Value nextStride = rewriter.create<LLVM::MulOp>(loc, strideArg, size); 1275 1276 // Decrement loop counter and branch back. 1277 Value decrement = rewriter.create<LLVM::SubOp>(loc, indexArg, oneIndex); 1278 rewriter.create<LLVM::BrOp>(loc, ValueRange({decrement, nextStride}), 1279 condBlock); 1280 1281 Block *remainder = 1282 rewriter.splitBlock(bodyBlock, rewriter.getInsertionPoint()); 1283 1284 // Hook up the cond exit to the remainder. 1285 rewriter.setInsertionPointToEnd(condBlock); 1286 rewriter.create<LLVM::CondBrOp>(loc, pred, bodyBlock, llvm::None, remainder, 1287 llvm::None); 1288 1289 // Reset position to beginning of new remainder block. 1290 rewriter.setInsertionPointToStart(remainder); 1291 1292 *descriptor = targetDesc; 1293 return success(); 1294 } 1295 }; 1296 1297 /// Helper function to convert a vector of `OpFoldResult`s into a vector of 1298 /// `Value`s. 
1299 static SmallVector<Value> getAsValues(OpBuilder &b, Location loc, 1300 Type &llvmIndexType, 1301 ArrayRef<OpFoldResult> valueOrAttrVec) { 1302 return llvm::to_vector<4>( 1303 llvm::map_range(valueOrAttrVec, [&](OpFoldResult value) -> Value { 1304 if (auto attr = value.dyn_cast<Attribute>()) 1305 return b.create<LLVM::ConstantOp>(loc, llvmIndexType, attr); 1306 return value.get<Value>(); 1307 })); 1308 } 1309 1310 /// Compute a map that for a given dimension of the expanded type gives the 1311 /// dimension in the collapsed type it maps to. Essentially its the inverse of 1312 /// the `reassocation` maps. 1313 static DenseMap<int64_t, int64_t> 1314 getExpandedDimToCollapsedDimMap(ArrayRef<ReassociationIndices> reassociation) { 1315 llvm::DenseMap<int64_t, int64_t> expandedDimToCollapsedDim; 1316 for (auto &en : enumerate(reassociation)) { 1317 for (auto dim : en.value()) 1318 expandedDimToCollapsedDim[dim] = en.index(); 1319 } 1320 return expandedDimToCollapsedDim; 1321 } 1322 1323 static OpFoldResult 1324 getExpandedOutputDimSize(OpBuilder &b, Location loc, Type &llvmIndexType, 1325 int64_t outDimIndex, ArrayRef<int64_t> outStaticShape, 1326 MemRefDescriptor &inDesc, 1327 ArrayRef<int64_t> inStaticShape, 1328 ArrayRef<ReassociationIndices> reassocation, 1329 DenseMap<int64_t, int64_t> &outDimToInDimMap) { 1330 int64_t outDimSize = outStaticShape[outDimIndex]; 1331 if (!ShapedType::isDynamic(outDimSize)) 1332 return b.getIndexAttr(outDimSize); 1333 1334 // Calculate the multiplication of all the out dim sizes except the 1335 // current dim. 1336 int64_t inDimIndex = outDimToInDimMap[outDimIndex]; 1337 int64_t otherDimSizesMul = 1; 1338 for (auto otherDimIndex : reassocation[inDimIndex]) { 1339 if (otherDimIndex == static_cast<unsigned>(outDimIndex)) 1340 continue; 1341 int64_t otherDimSize = outStaticShape[otherDimIndex]; 1342 assert(!ShapedType::isDynamic(otherDimSize) && 1343 "single dimension cannot be expanded into multiple dynamic " 1344 "dimensions"); 1345 otherDimSizesMul *= otherDimSize; 1346 } 1347 1348 // outDimSize = inDimSize / otherOutDimSizesMul 1349 int64_t inDimSize = inStaticShape[inDimIndex]; 1350 Value inDimSizeDynamic = 1351 ShapedType::isDynamic(inDimSize) 1352 ? inDesc.size(b, loc, inDimIndex) 1353 : b.create<LLVM::ConstantOp>(loc, llvmIndexType, 1354 b.getIndexAttr(inDimSize)); 1355 Value outDimSizeDynamic = b.create<LLVM::SDivOp>( 1356 loc, inDimSizeDynamic, 1357 b.create<LLVM::ConstantOp>(loc, llvmIndexType, 1358 b.getIndexAttr(otherDimSizesMul))); 1359 return outDimSizeDynamic; 1360 } 1361 1362 static OpFoldResult getCollapsedOutputDimSize( 1363 OpBuilder &b, Location loc, Type &llvmIndexType, int64_t outDimIndex, 1364 int64_t outDimSize, ArrayRef<int64_t> inStaticShape, 1365 MemRefDescriptor &inDesc, ArrayRef<ReassociationIndices> reassocation) { 1366 if (!ShapedType::isDynamic(outDimSize)) 1367 return b.getIndexAttr(outDimSize); 1368 1369 Value c1 = b.create<LLVM::ConstantOp>(loc, llvmIndexType, b.getIndexAttr(1)); 1370 Value outDimSizeDynamic = c1; 1371 for (auto inDimIndex : reassocation[outDimIndex]) { 1372 int64_t inDimSize = inStaticShape[inDimIndex]; 1373 Value inDimSizeDynamic = 1374 ShapedType::isDynamic(inDimSize) 1375 ? 
inDesc.size(b, loc, inDimIndex) 1376 : b.create<LLVM::ConstantOp>(loc, llvmIndexType, 1377 b.getIndexAttr(inDimSize)); 1378 outDimSizeDynamic = 1379 b.create<LLVM::MulOp>(loc, outDimSizeDynamic, inDimSizeDynamic); 1380 } 1381 return outDimSizeDynamic; 1382 } 1383 1384 static SmallVector<OpFoldResult, 4> 1385 getCollapsedOutputShape(OpBuilder &b, Location loc, Type &llvmIndexType, 1386 ArrayRef<ReassociationIndices> reassociation, 1387 ArrayRef<int64_t> inStaticShape, 1388 MemRefDescriptor &inDesc, 1389 ArrayRef<int64_t> outStaticShape) { 1390 return llvm::to_vector<4>(llvm::map_range( 1391 llvm::seq<int64_t>(0, outStaticShape.size()), [&](int64_t outDimIndex) { 1392 return getCollapsedOutputDimSize(b, loc, llvmIndexType, outDimIndex, 1393 outStaticShape[outDimIndex], 1394 inStaticShape, inDesc, reassociation); 1395 })); 1396 } 1397 1398 static SmallVector<OpFoldResult, 4> 1399 getExpandedOutputShape(OpBuilder &b, Location loc, Type &llvmIndexType, 1400 ArrayRef<ReassociationIndices> reassociation, 1401 ArrayRef<int64_t> inStaticShape, 1402 MemRefDescriptor &inDesc, 1403 ArrayRef<int64_t> outStaticShape) { 1404 DenseMap<int64_t, int64_t> outDimToInDimMap = 1405 getExpandedDimToCollapsedDimMap(reassociation); 1406 return llvm::to_vector<4>(llvm::map_range( 1407 llvm::seq<int64_t>(0, outStaticShape.size()), [&](int64_t outDimIndex) { 1408 return getExpandedOutputDimSize(b, loc, llvmIndexType, outDimIndex, 1409 outStaticShape, inDesc, inStaticShape, 1410 reassociation, outDimToInDimMap); 1411 })); 1412 } 1413 1414 static SmallVector<Value> 1415 getDynamicOutputShape(OpBuilder &b, Location loc, Type &llvmIndexType, 1416 ArrayRef<ReassociationIndices> reassociation, 1417 ArrayRef<int64_t> inStaticShape, MemRefDescriptor &inDesc, 1418 ArrayRef<int64_t> outStaticShape) { 1419 return outStaticShape.size() < inStaticShape.size() 1420 ? getAsValues(b, loc, llvmIndexType, 1421 getCollapsedOutputShape(b, loc, llvmIndexType, 1422 reassociation, inStaticShape, 1423 inDesc, outStaticShape)) 1424 : getAsValues(b, loc, llvmIndexType, 1425 getExpandedOutputShape(b, loc, llvmIndexType, 1426 reassociation, inStaticShape, 1427 inDesc, outStaticShape)); 1428 } 1429 1430 static void fillInStridesForExpandedMemDescriptor( 1431 OpBuilder &b, Location loc, MemRefType srcType, MemRefDescriptor &srcDesc, 1432 MemRefDescriptor &dstDesc, ArrayRef<ReassociationIndices> reassociation) { 1433 // See comments for computeExpandedLayoutMap for details on how the strides 1434 // are calculated. 1435 for (auto &en : llvm::enumerate(reassociation)) { 1436 auto currentStrideToExpand = srcDesc.stride(b, loc, en.index()); 1437 for (auto dstIndex : llvm::reverse(en.value())) { 1438 dstDesc.setStride(b, loc, dstIndex, currentStrideToExpand); 1439 Value size = dstDesc.size(b, loc, dstIndex); 1440 currentStrideToExpand = 1441 b.create<LLVM::MulOp>(loc, size, currentStrideToExpand); 1442 } 1443 } 1444 } 1445 1446 static void fillInStridesForCollapsedMemDescriptor( 1447 ConversionPatternRewriter &rewriter, Location loc, Operation *op, 1448 TypeConverter *typeConverter, MemRefType srcType, MemRefDescriptor &srcDesc, 1449 MemRefDescriptor &dstDesc, ArrayRef<ReassociationIndices> reassociation) { 1450 // See comments for computeCollapsedLayoutMap for details on how the strides 1451 // are calculated. 
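// Conceptually, the stride of a collapsed dimension is the stride of the
// innermost dimension of its reassociation group whose size is not 1 (falling
// back to the outermost dimension of the group when all sizes are 1). As a
// purely illustrative example, collapsing dims [0, 1, 2] of a source with
// sizes [s0, s1, 1] and strides [st0, st1, st2] picks st1 when s1 != 1 and
// st0 otherwise; when the sizes are dynamic this selection has to happen at
// runtime, which is what the branching code emitted below does.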
1452 auto srcShape = srcType.getShape(); 1453 for (auto &en : llvm::enumerate(reassociation)) { 1454 rewriter.setInsertionPoint(op); 1455 auto dstIndex = en.index(); 1456 ArrayRef<int64_t> ref = llvm::makeArrayRef(en.value()); 1457 while (srcShape[ref.back()] == 1 && ref.size() > 1) 1458 ref = ref.drop_back(); 1459 if (!ShapedType::isDynamic(srcShape[ref.back()]) || ref.size() == 1) { 1460 dstDesc.setStride(rewriter, loc, dstIndex, 1461 srcDesc.stride(rewriter, loc, ref.back())); 1462 } else { 1463 // Iterate over the source strides in reverse order. Skip over the 1464 // dimensions whose size is 1. 1465 // TODO: we should take the minimum stride in the reassociation group 1466 // instead of just the first where the dimension is not 1. 1467 // 1468 // +------------------------------------------------------+ 1469 // | curEntry: | 1470 // | %srcStride = strides[srcIndex] | 1471 // | %neOne = cmp sizes[srcIndex],1 +--+ 1472 // | cf.cond_br %neOne, continue(%srcStride), nextEntry | | 1473 // +-------------------------+----------------------------+ | 1474 // | | 1475 // v | 1476 // +-----------------------------+ | 1477 // | nextEntry: | | 1478 // | ... +---+ | 1479 // +--------------+--------------+ | | 1480 // | | | 1481 // v | | 1482 // +-----------------------------+ | | 1483 // | nextEntry: | | | 1484 // | ... | | | 1485 // +--------------+--------------+ | +--------+ 1486 // | | | 1487 // v v v 1488 // +--------------------------------------------------+ 1489 // | continue(%newStride): | 1490 // | %newMemRefDes = setStride(%newStride,dstIndex) | 1491 // +--------------------------------------------------+ 1492 OpBuilder::InsertionGuard guard(rewriter); 1493 Block *initBlock = rewriter.getInsertionBlock(); 1494 Block *continueBlock = 1495 rewriter.splitBlock(initBlock, rewriter.getInsertionPoint()); 1496 continueBlock->insertArgument(unsigned(0), srcDesc.getIndexType(), loc); 1497 rewriter.setInsertionPointToStart(continueBlock); 1498 dstDesc.setStride(rewriter, loc, dstIndex, continueBlock->getArgument(0)); 1499 1500 Block *curEntryBlock = initBlock; 1501 Block *nextEntryBlock; 1502 for (auto srcIndex : llvm::reverse(ref)) { 1503 if (srcShape[srcIndex] == 1 && srcIndex != ref.front()) 1504 continue; 1505 rewriter.setInsertionPointToEnd(curEntryBlock); 1506 Value srcStride = srcDesc.stride(rewriter, loc, srcIndex); 1507 if (srcIndex == ref.front()) { 1508 rewriter.create<LLVM::BrOp>(loc, srcStride, continueBlock); 1509 break; 1510 } 1511 Value one = rewriter.create<LLVM::ConstantOp>( 1512 loc, typeConverter->convertType(rewriter.getI64Type()), 1513 rewriter.getI32IntegerAttr(1)); 1514 Value predNeOne = rewriter.create<LLVM::ICmpOp>( 1515 loc, LLVM::ICmpPredicate::ne, srcDesc.size(rewriter, loc, srcIndex), 1516 one); 1517 { 1518 OpBuilder::InsertionGuard guard(rewriter); 1519 nextEntryBlock = rewriter.createBlock( 1520 initBlock->getParent(), Region::iterator(continueBlock), {}); 1521 } 1522 rewriter.create<LLVM::CondBrOp>(loc, predNeOne, continueBlock, 1523 srcStride, nextEntryBlock, llvm::None); 1524 curEntryBlock = nextEntryBlock; 1525 } 1526 } 1527 } 1528 } 1529 1530 static void fillInDynamicStridesForMemDescriptor( 1531 ConversionPatternRewriter &b, Location loc, Operation *op, 1532 TypeConverter *typeConverter, MemRefType srcType, MemRefType dstType, 1533 MemRefDescriptor &srcDesc, MemRefDescriptor &dstDesc, 1534 ArrayRef<ReassociationIndices> reassociation) { 1535 if (srcType.getRank() > dstType.getRank()) 1536 fillInStridesForCollapsedMemDescriptor(b, loc, op, typeConverter, 
srcType,
1537                                            srcDesc, dstDesc, reassociation);
1538   else
1539     fillInStridesForExpandedMemDescriptor(b, loc, srcType, srcDesc, dstDesc,
1540                                           reassociation);
1541 }
1542
1543 // ReshapeOp creates a new view descriptor of the proper rank.
1544 // The target MemRef only needs to have a strided layout; both static and
1545 // dynamic sizes and strides are supported.
1546 template <typename ReshapeOp>
1547 class ReassociatingReshapeOpConversion
1548     : public ConvertOpToLLVMPattern<ReshapeOp> {
1549 public:
1550   using ConvertOpToLLVMPattern<ReshapeOp>::ConvertOpToLLVMPattern;
1551   using ReshapeOpAdaptor = typename ReshapeOp::Adaptor;
1552
1553   LogicalResult
1554   matchAndRewrite(ReshapeOp reshapeOp, typename ReshapeOp::Adaptor adaptor,
1555                   ConversionPatternRewriter &rewriter) const override {
1556     MemRefType dstType = reshapeOp.getResultType();
1557     MemRefType srcType = reshapeOp.getSrcType();
1558
1559     int64_t offset;
1560     SmallVector<int64_t, 4> strides;
1561     if (failed(getStridesAndOffset(dstType, strides, offset))) {
1562       return rewriter.notifyMatchFailure(
1563           reshapeOp, "failed to get stride and offset exprs");
1564     }
1565
1566     MemRefDescriptor srcDesc(adaptor.getSrc());
1567     Location loc = reshapeOp->getLoc();
1568     auto dstDesc = MemRefDescriptor::undef(
1569         rewriter, loc, this->typeConverter->convertType(dstType));
1570     dstDesc.setAllocatedPtr(rewriter, loc, srcDesc.allocatedPtr(rewriter, loc));
1571     dstDesc.setAlignedPtr(rewriter, loc, srcDesc.alignedPtr(rewriter, loc));
1572     dstDesc.setOffset(rewriter, loc, srcDesc.offset(rewriter, loc));
1573
1574     ArrayRef<int64_t> srcStaticShape = srcType.getShape();
1575     ArrayRef<int64_t> dstStaticShape = dstType.getShape();
1576     Type llvmIndexType =
1577         this->typeConverter->convertType(rewriter.getIndexType());
1578     SmallVector<Value> dstShape = getDynamicOutputShape(
1579         rewriter, loc, llvmIndexType, reshapeOp.getReassociationIndices(),
1580         srcStaticShape, srcDesc, dstStaticShape);
1581     for (auto &en : llvm::enumerate(dstShape))
1582       dstDesc.setSize(rewriter, loc, en.index(), en.value());
1583
1584     if (llvm::all_of(strides, isStaticStrideOrOffset)) {
1585       for (auto &en : llvm::enumerate(strides))
1586         dstDesc.setConstantStride(rewriter, loc, en.index(), en.value());
1587     } else if (srcType.getLayout().isIdentity() &&
1588                dstType.getLayout().isIdentity()) {
1589       Value c1 = rewriter.create<LLVM::ConstantOp>(loc, llvmIndexType,
1590                                                    rewriter.getIndexAttr(1));
1591       Value stride = c1;
1592       for (auto dimIndex :
1593            llvm::reverse(llvm::seq<int64_t>(0, dstShape.size()))) {
1594         dstDesc.setStride(rewriter, loc, dimIndex, stride);
1595         stride = rewriter.create<LLVM::MulOp>(loc, dstShape[dimIndex], stride);
1596       }
1597     } else {
1598       // There could be mixed static/dynamic strides. For simplicity, we
1599       // recompute all strides if there is at least one dynamic stride.
1600       fillInDynamicStridesForMemDescriptor(
1601           rewriter, loc, reshapeOp, this->typeConverter, srcType, dstType,
1602           srcDesc, dstDesc, reshapeOp.getReassociationIndices());
1603     }
1604     rewriter.replaceOp(reshapeOp, {dstDesc});
1605     return success();
1606   }
1607 };
1608
1609 /// Conversion pattern that transforms a subview op into:
1610 /// 1. An `llvm.mlir.undef` operation to create a memref descriptor
1611 /// 2. Updates to the descriptor to introduce the data ptr, offset, size
1612 /// and stride.
1613 /// The subview op is replaced by the descriptor.
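/// Conceptually, for a subview with offsets o_i, sizes s_i and relative
/// strides t_i applied to a source descriptor with offset src_off and strides
/// src_stride_i, the resulting descriptor holds (illustrative sketch only):
///   offset   = src_off + sum_i(o_i * src_stride_i)
///   size_i   = s_i
///   stride_i = t_i * src_stride_i
/// with rank-reduced (dropped) dimensions skipped in the result.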
1614 struct SubViewOpLowering : public ConvertOpToLLVMPattern<memref::SubViewOp> { 1615 using ConvertOpToLLVMPattern<memref::SubViewOp>::ConvertOpToLLVMPattern; 1616 1617 LogicalResult 1618 matchAndRewrite(memref::SubViewOp subViewOp, OpAdaptor adaptor, 1619 ConversionPatternRewriter &rewriter) const override { 1620 auto loc = subViewOp.getLoc(); 1621 1622 auto sourceMemRefType = subViewOp.getSource().getType().cast<MemRefType>(); 1623 auto sourceElementTy = 1624 typeConverter->convertType(sourceMemRefType.getElementType()); 1625 1626 auto viewMemRefType = subViewOp.getType(); 1627 auto inferredType = 1628 memref::SubViewOp::inferResultType( 1629 subViewOp.getSourceType(), 1630 extractFromI64ArrayAttr(subViewOp.getStaticOffsets()), 1631 extractFromI64ArrayAttr(subViewOp.getStaticSizes()), 1632 extractFromI64ArrayAttr(subViewOp.getStaticStrides())) 1633 .cast<MemRefType>(); 1634 auto targetElementTy = 1635 typeConverter->convertType(viewMemRefType.getElementType()); 1636 auto targetDescTy = typeConverter->convertType(viewMemRefType); 1637 if (!sourceElementTy || !targetDescTy || !targetElementTy || 1638 !LLVM::isCompatibleType(sourceElementTy) || 1639 !LLVM::isCompatibleType(targetElementTy) || 1640 !LLVM::isCompatibleType(targetDescTy)) 1641 return failure(); 1642 1643 // Extract the offset and strides from the type. 1644 int64_t offset; 1645 SmallVector<int64_t, 4> strides; 1646 auto successStrides = getStridesAndOffset(inferredType, strides, offset); 1647 if (failed(successStrides)) 1648 return failure(); 1649 1650 // Create the descriptor. 1651 if (!LLVM::isCompatibleType(adaptor.getOperands().front().getType())) 1652 return failure(); 1653 MemRefDescriptor sourceMemRef(adaptor.getOperands().front()); 1654 auto targetMemRef = MemRefDescriptor::undef(rewriter, loc, targetDescTy); 1655 1656 // Copy the buffer pointer from the old descriptor to the new one. 1657 Value extracted = sourceMemRef.allocatedPtr(rewriter, loc); 1658 Value bitcastPtr = rewriter.create<LLVM::BitcastOp>( 1659 loc, 1660 LLVM::LLVMPointerType::get(targetElementTy, 1661 viewMemRefType.getMemorySpaceAsInt()), 1662 extracted); 1663 targetMemRef.setAllocatedPtr(rewriter, loc, bitcastPtr); 1664 1665 // Copy the aligned pointer from the old descriptor to the new one. 1666 extracted = sourceMemRef.alignedPtr(rewriter, loc); 1667 bitcastPtr = rewriter.create<LLVM::BitcastOp>( 1668 loc, 1669 LLVM::LLVMPointerType::get(targetElementTy, 1670 viewMemRefType.getMemorySpaceAsInt()), 1671 extracted); 1672 targetMemRef.setAlignedPtr(rewriter, loc, bitcastPtr); 1673 1674 size_t inferredShapeRank = inferredType.getRank(); 1675 size_t resultShapeRank = viewMemRefType.getRank(); 1676 1677 // Extract strides needed to compute offset. 1678 SmallVector<Value, 4> strideValues; 1679 strideValues.reserve(inferredShapeRank); 1680 for (unsigned i = 0; i < inferredShapeRank; ++i) 1681 strideValues.push_back(sourceMemRef.stride(rewriter, loc, i)); 1682 1683 // Offset. 1684 auto llvmIndexType = typeConverter->convertType(rewriter.getIndexType()); 1685 if (!ShapedType::isDynamicStrideOrOffset(offset)) { 1686 targetMemRef.setConstantOffset(rewriter, loc, offset); 1687 } else { 1688 Value baseOffset = sourceMemRef.offset(rewriter, loc); 1689 // `inferredShapeRank` may be larger than the number of offset operands 1690 // because of trailing semantics. In this case, the offset is guaranteed 1691 // to be interpreted as 0 and we can just skip the extra dimensions. 
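// For example, if only two offset operands are provided for a rank-3 inferred
// type, the missing trailing offset behaves as 0 and contributes nothing to
// `baseOffset`.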
1692 for (unsigned i = 0, e = std::min(inferredShapeRank, 1693 subViewOp.getMixedOffsets().size()); 1694 i < e; ++i) { 1695 Value offset = 1696 // TODO: need OpFoldResult ODS adaptor to clean this up. 1697 subViewOp.isDynamicOffset(i) 1698 ? adaptor.getOperands()[subViewOp.getIndexOfDynamicOffset(i)] 1699 : rewriter.create<LLVM::ConstantOp>( 1700 loc, llvmIndexType, 1701 rewriter.getI64IntegerAttr(subViewOp.getStaticOffset(i))); 1702 Value mul = rewriter.create<LLVM::MulOp>(loc, offset, strideValues[i]); 1703 baseOffset = rewriter.create<LLVM::AddOp>(loc, baseOffset, mul); 1704 } 1705 targetMemRef.setOffset(rewriter, loc, baseOffset); 1706 } 1707 1708 // Update sizes and strides. 1709 SmallVector<OpFoldResult> mixedSizes = subViewOp.getMixedSizes(); 1710 SmallVector<OpFoldResult> mixedStrides = subViewOp.getMixedStrides(); 1711 assert(mixedSizes.size() == mixedStrides.size() && 1712 "expected sizes and strides of equal length"); 1713 llvm::SmallBitVector unusedDims = subViewOp.getDroppedDims(); 1714 for (int i = inferredShapeRank - 1, j = resultShapeRank - 1; 1715 i >= 0 && j >= 0; --i) { 1716 if (unusedDims.test(i)) 1717 continue; 1718 1719 // `i` may overflow subViewOp.getMixedSizes because of trailing semantics. 1720 // In this case, the size is guaranteed to be interpreted as Dim and the 1721 // stride as 1. 1722 Value size, stride; 1723 if (static_cast<unsigned>(i) >= mixedSizes.size()) { 1724 // If the static size is available, use it directly. This is similar to 1725 // the folding of dim(constant-op) but removes the need for dim to be 1726 // aware of LLVM constants and for this pass to be aware of std 1727 // constants. 1728 int64_t staticSize = 1729 subViewOp.getSource().getType().cast<MemRefType>().getShape()[i]; 1730 if (staticSize != ShapedType::kDynamicSize) { 1731 size = rewriter.create<LLVM::ConstantOp>( 1732 loc, llvmIndexType, rewriter.getI64IntegerAttr(staticSize)); 1733 } else { 1734 Value pos = rewriter.create<LLVM::ConstantOp>( 1735 loc, llvmIndexType, rewriter.getI64IntegerAttr(i)); 1736 Value dim = 1737 rewriter.create<memref::DimOp>(loc, subViewOp.getSource(), pos); 1738 auto cast = rewriter.create<UnrealizedConversionCastOp>( 1739 loc, llvmIndexType, dim); 1740 size = cast.getResult(0); 1741 } 1742 stride = rewriter.create<LLVM::ConstantOp>( 1743 loc, llvmIndexType, rewriter.getI64IntegerAttr(1)); 1744 } else { 1745 // TODO: need OpFoldResult ODS adaptor to clean this up. 1746 size = 1747 subViewOp.isDynamicSize(i) 1748 ? adaptor.getOperands()[subViewOp.getIndexOfDynamicSize(i)] 1749 : rewriter.create<LLVM::ConstantOp>( 1750 loc, llvmIndexType, 1751 rewriter.getI64IntegerAttr(subViewOp.getStaticSize(i))); 1752 if (!ShapedType::isDynamicStrideOrOffset(strides[i])) { 1753 stride = rewriter.create<LLVM::ConstantOp>( 1754 loc, llvmIndexType, rewriter.getI64IntegerAttr(strides[i])); 1755 } else { 1756 stride = 1757 subViewOp.isDynamicStride(i) 1758 ? adaptor.getOperands()[subViewOp.getIndexOfDynamicStride(i)] 1759 : rewriter.create<LLVM::ConstantOp>( 1760 loc, llvmIndexType, 1761 rewriter.getI64IntegerAttr( 1762 subViewOp.getStaticStride(i))); 1763 stride = rewriter.create<LLVM::MulOp>(loc, stride, strideValues[i]); 1764 } 1765 } 1766 targetMemRef.setSize(rewriter, loc, j, size); 1767 targetMemRef.setStride(rewriter, loc, j, stride); 1768 j--; 1769 } 1770 1771 rewriter.replaceOp(subViewOp, {targetMemRef}); 1772 return success(); 1773 } 1774 }; 1775 1776 /// Conversion pattern that transforms a transpose op into: 1777 /// 1. 
An `llvm.mlir.undef` operation to create a memref descriptor.
1778 /// 2. Copies of the allocated and aligned pointers and of the offset from the
1779 ///    source descriptor to the new one.
1780 /// 3. Updates to the descriptor's sizes and strides, permuted according to the
1781 ///    transpose permutation map.
1782 /// The transpose op is replaced by the new descriptor.
1783 class TransposeOpLowering : public ConvertOpToLLVMPattern<memref::TransposeOp> {
1784 public:
1785   using ConvertOpToLLVMPattern<memref::TransposeOp>::ConvertOpToLLVMPattern;
1786
1787   LogicalResult
1788   matchAndRewrite(memref::TransposeOp transposeOp, OpAdaptor adaptor,
1789                   ConversionPatternRewriter &rewriter) const override {
1790     auto loc = transposeOp.getLoc();
1791     MemRefDescriptor viewMemRef(adaptor.getIn());
1792
1793     // No permutation, early exit.
1794     if (transposeOp.getPermutation().isIdentity())
1795       return rewriter.replaceOp(transposeOp, {viewMemRef}), success();
1796
1797     auto targetMemRef = MemRefDescriptor::undef(
1798         rewriter, loc, typeConverter->convertType(transposeOp.getShapedType()));
1799
1800     // Copy the base and aligned pointers from the old descriptor to the new
1801     // one.
1802     targetMemRef.setAllocatedPtr(rewriter, loc,
1803                                  viewMemRef.allocatedPtr(rewriter, loc));
1804     targetMemRef.setAlignedPtr(rewriter, loc,
1805                                viewMemRef.alignedPtr(rewriter, loc));
1806
1807     // Copy the offset from the old descriptor to the new one.
1808     targetMemRef.setOffset(rewriter, loc, viewMemRef.offset(rewriter, loc));
1809
1810     // Iterate over the dimensions and apply size/stride permutation.
1811     for (const auto &en :
1812          llvm::enumerate(transposeOp.getPermutation().getResults())) {
1813       int sourcePos = en.index();
1814       int targetPos = en.value().cast<AffineDimExpr>().getPosition();
1815       targetMemRef.setSize(rewriter, loc, targetPos,
1816                            viewMemRef.size(rewriter, loc, sourcePos));
1817       targetMemRef.setStride(rewriter, loc, targetPos,
1818                              viewMemRef.stride(rewriter, loc, sourcePos));
1819     }
1820
1821     rewriter.replaceOp(transposeOp, {targetMemRef});
1822     return success();
1823   }
1824 };
1825
1826 /// Conversion pattern that transforms a view op into:
1827 /// 1. An `llvm.mlir.undef` operation to create a memref descriptor
1828 /// 2. Updates to the descriptor to introduce the data ptr, offset, size
1829 /// and stride.
1830 /// The view op is replaced by the descriptor.
1831 struct ViewOpLowering : public ConvertOpToLLVMPattern<memref::ViewOp> {
1832   using ConvertOpToLLVMPattern<memref::ViewOp>::ConvertOpToLLVMPattern;
1833
1834   // Build and return the value for the idx^th shape dimension, either by
1835   // returning the constant shape dimension or selecting the matching dynamic size.
1836   Value getSize(ConversionPatternRewriter &rewriter, Location loc,
1837                 ArrayRef<int64_t> shape, ValueRange dynamicSizes,
1838                 unsigned idx) const {
1839     assert(idx < shape.size());
1840     if (!ShapedType::isDynamic(shape[idx]))
1841       return createIndexConstant(rewriter, loc, shape[idx]);
1842     // Count the number of dynamic dims in range [0, idx)
1843     unsigned nDynamic = llvm::count_if(shape.take_front(idx), [](int64_t v) {
1844       return ShapedType::isDynamic(v);
1845     });
1846     return dynamicSizes[nDynamic];
1847   }
1848
1849   // Build and return the idx^th stride, either by returning the constant stride
1850   // or by computing the dynamic stride from the current `runningStride` and
1851   // `nextSize`.
The caller should keep a running stride and update it with the 1852 // result returned by this function. 1853 Value getStride(ConversionPatternRewriter &rewriter, Location loc, 1854 ArrayRef<int64_t> strides, Value nextSize, 1855 Value runningStride, unsigned idx) const { 1856 assert(idx < strides.size()); 1857 if (!ShapedType::isDynamicStrideOrOffset(strides[idx])) 1858 return createIndexConstant(rewriter, loc, strides[idx]); 1859 if (nextSize) 1860 return runningStride 1861 ? rewriter.create<LLVM::MulOp>(loc, runningStride, nextSize) 1862 : nextSize; 1863 assert(!runningStride); 1864 return createIndexConstant(rewriter, loc, 1); 1865 } 1866 1867 LogicalResult 1868 matchAndRewrite(memref::ViewOp viewOp, OpAdaptor adaptor, 1869 ConversionPatternRewriter &rewriter) const override { 1870 auto loc = viewOp.getLoc(); 1871 1872 auto viewMemRefType = viewOp.getType(); 1873 auto targetElementTy = 1874 typeConverter->convertType(viewMemRefType.getElementType()); 1875 auto targetDescTy = typeConverter->convertType(viewMemRefType); 1876 if (!targetDescTy || !targetElementTy || 1877 !LLVM::isCompatibleType(targetElementTy) || 1878 !LLVM::isCompatibleType(targetDescTy)) 1879 return viewOp.emitWarning("Target descriptor type not converted to LLVM"), 1880 failure(); 1881 1882 int64_t offset; 1883 SmallVector<int64_t, 4> strides; 1884 auto successStrides = getStridesAndOffset(viewMemRefType, strides, offset); 1885 if (failed(successStrides)) 1886 return viewOp.emitWarning("cannot cast to non-strided shape"), failure(); 1887 assert(offset == 0 && "expected offset to be 0"); 1888 1889 // Target memref must be contiguous in memory (innermost stride is 1), or 1890 // empty (special case when at least one of the memref dimensions is 0). 1891 if (!strides.empty() && (strides.back() != 1 && strides.back() != 0)) 1892 return viewOp.emitWarning("cannot cast to non-contiguous shape"), 1893 failure(); 1894 1895 // Create the descriptor. 1896 MemRefDescriptor sourceMemRef(adaptor.getSource()); 1897 auto targetMemRef = MemRefDescriptor::undef(rewriter, loc, targetDescTy); 1898 1899 // Field 1: Copy the allocated pointer, used for malloc/free. 1900 Value allocatedPtr = sourceMemRef.allocatedPtr(rewriter, loc); 1901 auto srcMemRefType = viewOp.getSource().getType().cast<MemRefType>(); 1902 Value bitcastPtr = rewriter.create<LLVM::BitcastOp>( 1903 loc, 1904 LLVM::LLVMPointerType::get(targetElementTy, 1905 srcMemRefType.getMemorySpaceAsInt()), 1906 allocatedPtr); 1907 targetMemRef.setAllocatedPtr(rewriter, loc, bitcastPtr); 1908 1909 // Field 2: Copy the actual aligned pointer to payload. 1910 Value alignedPtr = sourceMemRef.alignedPtr(rewriter, loc); 1911 alignedPtr = rewriter.create<LLVM::GEPOp>( 1912 loc, alignedPtr.getType(), alignedPtr, adaptor.getByteShift()); 1913 bitcastPtr = rewriter.create<LLVM::BitcastOp>( 1914 loc, 1915 LLVM::LLVMPointerType::get(targetElementTy, 1916 srcMemRefType.getMemorySpaceAsInt()), 1917 alignedPtr); 1918 targetMemRef.setAlignedPtr(rewriter, loc, bitcastPtr); 1919 1920 // Field 3: The offset in the resulting type must be 0. This is because of 1921 // the type change: an offset on srcType* may not be expressible as an 1922 // offset on dstType*. 1923 targetMemRef.setOffset(rewriter, loc, 1924 createIndexConstant(rewriter, loc, offset)); 1925 1926 // Early exit for 0-D corner case. 1927 if (viewMemRefType.getRank() == 0) 1928 return rewriter.replaceOp(viewOp, {targetMemRef}), success(); 1929 1930 // Fields 4 and 5: Update sizes and strides. 
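// Because the target view is contiguous, the strides can be rebuilt from the
// sizes walking from the innermost dimension outwards. Conceptually:
//   stride[rank-1] = 1
//   stride[i]      = stride[i+1] * size[i+1]
// Statically known strides are emitted directly as constants.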
1931 Value stride = nullptr, nextSize = nullptr; 1932 for (int i = viewMemRefType.getRank() - 1; i >= 0; --i) { 1933 // Update size. 1934 Value size = getSize(rewriter, loc, viewMemRefType.getShape(), 1935 adaptor.getSizes(), i); 1936 targetMemRef.setSize(rewriter, loc, i, size); 1937 // Update stride. 1938 stride = getStride(rewriter, loc, strides, nextSize, stride, i); 1939 targetMemRef.setStride(rewriter, loc, i, stride); 1940 nextSize = size; 1941 } 1942 1943 rewriter.replaceOp(viewOp, {targetMemRef}); 1944 return success(); 1945 } 1946 }; 1947 1948 //===----------------------------------------------------------------------===// 1949 // AtomicRMWOpLowering 1950 //===----------------------------------------------------------------------===// 1951 1952 /// Try to match the kind of a memref.atomic_rmw to determine whether to use a 1953 /// lowering to llvm.atomicrmw or fallback to llvm.cmpxchg. 1954 static Optional<LLVM::AtomicBinOp> 1955 matchSimpleAtomicOp(memref::AtomicRMWOp atomicOp) { 1956 switch (atomicOp.getKind()) { 1957 case arith::AtomicRMWKind::addf: 1958 return LLVM::AtomicBinOp::fadd; 1959 case arith::AtomicRMWKind::addi: 1960 return LLVM::AtomicBinOp::add; 1961 case arith::AtomicRMWKind::assign: 1962 return LLVM::AtomicBinOp::xchg; 1963 case arith::AtomicRMWKind::maxs: 1964 return LLVM::AtomicBinOp::max; 1965 case arith::AtomicRMWKind::maxu: 1966 return LLVM::AtomicBinOp::umax; 1967 case arith::AtomicRMWKind::mins: 1968 return LLVM::AtomicBinOp::min; 1969 case arith::AtomicRMWKind::minu: 1970 return LLVM::AtomicBinOp::umin; 1971 case arith::AtomicRMWKind::ori: 1972 return LLVM::AtomicBinOp::_or; 1973 case arith::AtomicRMWKind::andi: 1974 return LLVM::AtomicBinOp::_and; 1975 default: 1976 return llvm::None; 1977 } 1978 llvm_unreachable("Invalid AtomicRMWKind"); 1979 } 1980 1981 struct AtomicRMWOpLowering : public LoadStoreOpLowering<memref::AtomicRMWOp> { 1982 using Base::Base; 1983 1984 LogicalResult 1985 matchAndRewrite(memref::AtomicRMWOp atomicOp, OpAdaptor adaptor, 1986 ConversionPatternRewriter &rewriter) const override { 1987 if (failed(match(atomicOp))) 1988 return failure(); 1989 auto maybeKind = matchSimpleAtomicOp(atomicOp); 1990 if (!maybeKind) 1991 return failure(); 1992 auto resultType = adaptor.getValue().getType(); 1993 auto memRefType = atomicOp.getMemRefType(); 1994 auto dataPtr = 1995 getStridedElementPtr(atomicOp.getLoc(), memRefType, adaptor.getMemref(), 1996 adaptor.getIndices(), rewriter); 1997 rewriter.replaceOpWithNewOp<LLVM::AtomicRMWOp>( 1998 atomicOp, resultType, *maybeKind, dataPtr, adaptor.getValue(), 1999 LLVM::AtomicOrdering::acq_rel); 2000 return success(); 2001 } 2002 }; 2003 2004 } // namespace 2005 2006 void mlir::populateMemRefToLLVMConversionPatterns(LLVMTypeConverter &converter, 2007 RewritePatternSet &patterns) { 2008 // clang-format off 2009 patterns.add< 2010 AllocaOpLowering, 2011 AllocaScopeOpLowering, 2012 AtomicRMWOpLowering, 2013 AssumeAlignmentOpLowering, 2014 DimOpLowering, 2015 GenericAtomicRMWOpLowering, 2016 GlobalMemrefOpLowering, 2017 GetGlobalMemrefOpLowering, 2018 LoadOpLowering, 2019 MemRefCastOpLowering, 2020 MemRefCopyOpLowering, 2021 MemRefReinterpretCastOpLowering, 2022 MemRefReshapeOpLowering, 2023 PrefetchOpLowering, 2024 RankOpLowering, 2025 ReassociatingReshapeOpConversion<memref::ExpandShapeOp>, 2026 ReassociatingReshapeOpConversion<memref::CollapseShapeOp>, 2027 StoreOpLowering, 2028 SubViewOpLowering, 2029 TransposeOpLowering, 2030 ViewOpLowering>(converter); 2031 // clang-format on 2032 auto allocLowering = 
converter.getOptions().allocLowering; 2033 if (allocLowering == LowerToLLVMOptions::AllocLowering::AlignedAlloc) 2034 patterns.add<AlignedAllocOpLowering, DeallocOpLowering>(converter); 2035 else if (allocLowering == LowerToLLVMOptions::AllocLowering::Malloc) 2036 patterns.add<AllocOpLowering, DeallocOpLowering>(converter); 2037 } 2038 2039 namespace { 2040 struct MemRefToLLVMPass : public ConvertMemRefToLLVMBase<MemRefToLLVMPass> { 2041 MemRefToLLVMPass() = default; 2042 2043 void runOnOperation() override { 2044 Operation *op = getOperation(); 2045 const auto &dataLayoutAnalysis = getAnalysis<DataLayoutAnalysis>(); 2046 LowerToLLVMOptions options(&getContext(), 2047 dataLayoutAnalysis.getAtOrAbove(op)); 2048 options.allocLowering = 2049 (useAlignedAlloc ? LowerToLLVMOptions::AllocLowering::AlignedAlloc 2050 : LowerToLLVMOptions::AllocLowering::Malloc); 2051 if (indexBitwidth != kDeriveIndexBitwidthFromDataLayout) 2052 options.overrideIndexBitwidth(indexBitwidth); 2053 2054 LLVMTypeConverter typeConverter(&getContext(), options, 2055 &dataLayoutAnalysis); 2056 RewritePatternSet patterns(&getContext()); 2057 populateMemRefToLLVMConversionPatterns(typeConverter, patterns); 2058 LLVMConversionTarget target(getContext()); 2059 target.addLegalOp<func::FuncOp>(); 2060 if (failed(applyPartialConversion(op, target, std::move(patterns)))) 2061 signalPassFailure(); 2062 } 2063 }; 2064 } // namespace 2065 2066 std::unique_ptr<Pass> mlir::createMemRefToLLVMPass() { 2067 return std::make_unique<MemRefToLLVMPass>(); 2068 } 2069