//===- MemRefToLLVM.cpp - MemRef to LLVM dialect conversion ---------------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//

#include "mlir/Conversion/MemRefToLLVM/MemRefToLLVM.h"
#include "../PassDetail.h"
#include "mlir/Analysis/DataLayoutAnalysis.h"
#include "mlir/Conversion/LLVMCommon/ConversionTarget.h"
#include "mlir/Conversion/LLVMCommon/Pattern.h"
#include "mlir/Conversion/LLVMCommon/TypeConverter.h"
#include "mlir/Conversion/MemRefToLLVM/AllocLikeConversion.h"
#include "mlir/Dialect/Arithmetic/IR/Arithmetic.h"
#include "mlir/Dialect/Func/IR/FuncOps.h"
#include "mlir/Dialect/LLVMIR/FunctionCallUtils.h"
#include "mlir/Dialect/LLVMIR/LLVMDialect.h"
#include "mlir/Dialect/MemRef/IR/MemRef.h"
#include "mlir/IR/AffineMap.h"
#include "mlir/IR/BlockAndValueMapping.h"
#include "llvm/ADT/SmallBitVector.h"

using namespace mlir;

namespace {

bool isStaticStrideOrOffset(int64_t strideOrOffset) {
  return !ShapedType::isDynamicStrideOrOffset(strideOrOffset);
}

struct AllocOpLowering : public AllocLikeOpLLVMLowering {
  AllocOpLowering(LLVMTypeConverter &converter)
      : AllocLikeOpLLVMLowering(memref::AllocOp::getOperationName(),
                                converter) {}

  std::tuple<Value, Value> allocateBuffer(ConversionPatternRewriter &rewriter,
                                          Location loc, Value sizeBytes,
                                          Operation *op) const override {
    // Heap allocations.
    memref::AllocOp allocOp = cast<memref::AllocOp>(op);
    MemRefType memRefType = allocOp.getType();

    Value alignment;
    if (auto alignmentAttr = allocOp.getAlignment()) {
      alignment = createIndexConstant(rewriter, loc, *alignmentAttr);
    } else if (!memRefType.getElementType().isSignlessIntOrIndexOrFloat()) {
      // In the case where no alignment is specified, we may want to override
      // `malloc`'s behavior. `malloc` typically aligns at the size of the
      // biggest scalar on a target HW. For non-scalars, use the natural
      // alignment of the LLVM type given by the LLVM DataLayout.
      alignment = getSizeInBytes(loc, memRefType.getElementType(), rewriter);
    }

    if (alignment) {
      // Adjust the allocation size to consider alignment.
      sizeBytes = rewriter.create<LLVM::AddOp>(loc, sizeBytes, alignment);
    }

    // Allocate the underlying buffer and store a pointer to it in the MemRef
    // descriptor.
    Type elementPtrType = this->getElementPtrType(memRefType);
    auto allocFuncOp = LLVM::lookupOrCreateMallocFn(
        allocOp->getParentOfType<ModuleOp>(), getIndexType());
    auto results = createLLVMCall(rewriter, loc, allocFuncOp, {sizeBytes},
                                  getVoidPtrType());
    Value allocatedPtr =
        rewriter.create<LLVM::BitcastOp>(loc, elementPtrType, results[0]);

    Value alignedPtr = allocatedPtr;
    if (alignment) {
      // Compute the aligned type pointer.
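      // Conceptually, the helper used below rounds the pointer value up to
      // the next multiple of `alignment`, i.e. roughly
      //   aligned = ((allocated + alignment - 1) / alignment) * alignment,
      // leaving the over-allocated slack before the aligned address unused.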
      Value allocatedInt =
          rewriter.create<LLVM::PtrToIntOp>(loc, getIndexType(), allocatedPtr);
      Value alignmentInt =
          createAligned(rewriter, loc, allocatedInt, alignment);
      alignedPtr =
          rewriter.create<LLVM::IntToPtrOp>(loc, elementPtrType, alignmentInt);
    }

    return std::make_tuple(allocatedPtr, alignedPtr);
  }
};

struct AlignedAllocOpLowering : public AllocLikeOpLLVMLowering {
  AlignedAllocOpLowering(LLVMTypeConverter &converter)
      : AllocLikeOpLLVMLowering(memref::AllocOp::getOperationName(),
                                converter) {}

  /// Returns the memref's element size in bytes using the data layout active
  /// at `op`.
  // TODO: there are other places where this is used. Expose publicly?
  unsigned getMemRefEltSizeInBytes(MemRefType memRefType,
                                   Operation *op) const {
    const DataLayout *layout = &defaultLayout;
    if (const DataLayoutAnalysis *analysis =
            getTypeConverter()->getDataLayoutAnalysis()) {
      layout = &analysis->getAbove(op);
    }
    Type elementType = memRefType.getElementType();
    if (auto memRefElementType = elementType.dyn_cast<MemRefType>())
      return getTypeConverter()->getMemRefDescriptorSize(memRefElementType,
                                                         *layout);
    if (auto memRefElementType = elementType.dyn_cast<UnrankedMemRefType>())
      return getTypeConverter()->getUnrankedMemRefDescriptorSize(
          memRefElementType, *layout);
    return layout->getTypeSize(elementType);
  }

  /// Returns true if the memref size in bytes is known to be a multiple of
  /// `factor`, assuming the data layout active at `op`.
  bool isMemRefSizeMultipleOf(MemRefType type, uint64_t factor,
                              Operation *op) const {
    uint64_t sizeDivisor = getMemRefEltSizeInBytes(type, op);
    for (unsigned i = 0, e = type.getRank(); i < e; i++) {
      if (ShapedType::isDynamic(type.getDimSize(i)))
        continue;
      sizeDivisor = sizeDivisor * type.getDimSize(i);
    }
    return sizeDivisor % factor == 0;
  }

  /// Returns the alignment to be used for the allocation call itself.
  /// aligned_alloc requires the alignment to be a power of two and the
  /// allocation size to be a multiple of that alignment.
  int64_t getAllocationAlignment(memref::AllocOp allocOp) const {
    if (Optional<uint64_t> alignment = allocOp.getAlignment())
      return *alignment;

    // Whenever we don't have alignment set, we will use an alignment
    // consistent with the element type; since the alignment has to be a power
    // of two, we bump the element size to the next power of two if it isn't
    // one already.
    auto eltSizeBytes = getMemRefEltSizeInBytes(allocOp.getType(), allocOp);
    return std::max(kMinAlignedAllocAlignment,
                    llvm::PowerOf2Ceil(eltSizeBytes));
  }

  std::tuple<Value, Value> allocateBuffer(ConversionPatternRewriter &rewriter,
                                          Location loc, Value sizeBytes,
                                          Operation *op) const override {
    // Heap allocations.
    memref::AllocOp allocOp = cast<memref::AllocOp>(op);
    MemRefType memRefType = allocOp.getType();
    int64_t alignment = getAllocationAlignment(allocOp);
    Value allocAlignment = createIndexConstant(rewriter, loc, alignment);

    // aligned_alloc requires the size to be a multiple of the alignment; pad
    // the size to the next multiple if necessary.
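    // For illustration: a memref<5xf32> occupies 20 bytes; with a 16-byte
    // alignment the padded allocation size becomes 32, the next multiple of
    // 16, which is what aligned_alloc expects.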
    if (!isMemRefSizeMultipleOf(memRefType, alignment, op))
      sizeBytes = createAligned(rewriter, loc, sizeBytes, allocAlignment);

    Type elementPtrType = this->getElementPtrType(memRefType);
    auto allocFuncOp = LLVM::lookupOrCreateAlignedAllocFn(
        allocOp->getParentOfType<ModuleOp>(), getIndexType());
    auto results =
        createLLVMCall(rewriter, loc, allocFuncOp, {allocAlignment, sizeBytes},
                       getVoidPtrType());
    Value allocatedPtr =
        rewriter.create<LLVM::BitcastOp>(loc, elementPtrType, results[0]);

    return std::make_tuple(allocatedPtr, allocatedPtr);
  }

  /// The minimum alignment to use with aligned_alloc (has to be a power of 2).
  static constexpr uint64_t kMinAlignedAllocAlignment = 16UL;

  /// Default layout to use in absence of the corresponding analysis.
  DataLayout defaultLayout;
};

// Out-of-line definition, required until C++17.
constexpr uint64_t AlignedAllocOpLowering::kMinAlignedAllocAlignment;

struct AllocaOpLowering : public AllocLikeOpLLVMLowering {
  AllocaOpLowering(LLVMTypeConverter &converter)
      : AllocLikeOpLLVMLowering(memref::AllocaOp::getOperationName(),
                                converter) {}

  /// Allocates the underlying buffer using an llvm.alloca. Unlike heap
  /// allocations (e.g. with malloc), the returned pointer already points to
  /// the element type and needs no post-allocation alignment, so the
  /// allocated and aligned pointers coincide.
  std::tuple<Value, Value> allocateBuffer(ConversionPatternRewriter &rewriter,
                                          Location loc, Value sizeBytes,
                                          Operation *op) const override {

    // With alloca, one gets a pointer to the element type right away: this is
    // a stack allocation.
    auto allocaOp = cast<memref::AllocaOp>(op);
    auto elementPtrType = this->getElementPtrType(allocaOp.getType());

    auto allocatedElementPtr = rewriter.create<LLVM::AllocaOp>(
        loc, elementPtrType, sizeBytes,
        allocaOp.getAlignment() ? *allocaOp.getAlignment() : 0);

    return std::make_tuple(allocatedElementPtr, allocatedElementPtr);
  }
};

struct AllocaScopeOpLowering
    : public ConvertOpToLLVMPattern<memref::AllocaScopeOp> {
  using ConvertOpToLLVMPattern<memref::AllocaScopeOp>::ConvertOpToLLVMPattern;

  LogicalResult
  matchAndRewrite(memref::AllocaScopeOp allocaScopeOp, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    OpBuilder::InsertionGuard guard(rewriter);
    Location loc = allocaScopeOp.getLoc();

    // Split the current block before the AllocaScopeOp to create the inlining
    // point.
    auto *currentBlock = rewriter.getInsertionBlock();
    auto *remainingOpsBlock =
        rewriter.splitBlock(currentBlock, rewriter.getInsertionPoint());
    Block *continueBlock;
    if (allocaScopeOp.getNumResults() == 0) {
      continueBlock = remainingOpsBlock;
    } else {
      continueBlock = rewriter.createBlock(
          remainingOpsBlock, allocaScopeOp.getResultTypes(),
          SmallVector<Location>(allocaScopeOp->getNumResults(),
                                allocaScopeOp.getLoc()));
      rewriter.create<LLVM::BrOp>(loc, ValueRange(), remainingOpsBlock);
    }

    // Inline body region.
    Block *beforeBody = &allocaScopeOp.getBodyRegion().front();
    Block *afterBody = &allocaScopeOp.getBodyRegion().back();
    rewriter.inlineRegionBefore(allocaScopeOp.getBodyRegion(), continueBlock);

    // Save stack and then branch into the body of the region.
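    // Roughly, the resulting CFG is:
    //   current block: %sp = llvm.intr.stacksave; llvm.br ^body
    //   body blocks:   inlined alloca_scope region, ending with
    //                  llvm.intr.stackrestore %sp; llvm.br ^continue
    //   continue:      ops that followed the alloca_scope
    // so stack allocations made inside the scope are released on exit.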
    rewriter.setInsertionPointToEnd(currentBlock);
    auto stackSaveOp =
        rewriter.create<LLVM::StackSaveOp>(loc, getVoidPtrType());
    rewriter.create<LLVM::BrOp>(loc, ValueRange(), beforeBody);

    // Replace the alloca_scope return with a branch that jumps out of the
    // body. The stack is restored before leaving the body region.
    rewriter.setInsertionPointToEnd(afterBody);
    auto returnOp =
        cast<memref::AllocaScopeReturnOp>(afterBody->getTerminator());
    auto branchOp = rewriter.replaceOpWithNewOp<LLVM::BrOp>(
        returnOp, returnOp.getResults(), continueBlock);

    // Insert stack restore before jumping out of the body of the region.
    rewriter.setInsertionPoint(branchOp);
    rewriter.create<LLVM::StackRestoreOp>(loc, stackSaveOp);

    // Replace the op with the values returned from the body region.
    rewriter.replaceOp(allocaScopeOp, continueBlock->getArguments());

    return success();
  }
};

struct AssumeAlignmentOpLowering
    : public ConvertOpToLLVMPattern<memref::AssumeAlignmentOp> {
  using ConvertOpToLLVMPattern<
      memref::AssumeAlignmentOp>::ConvertOpToLLVMPattern;

  LogicalResult
  matchAndRewrite(memref::AssumeAlignmentOp op, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    Value memref = adaptor.getMemref();
    unsigned alignment = op.getAlignment();
    auto loc = op.getLoc();

    MemRefDescriptor memRefDescriptor(memref);
    Value ptr = memRefDescriptor.alignedPtr(rewriter, memref.getLoc());

    // Emit llvm.assume((memref.alignedPtr & (alignment - 1)) == 0). Notice
    // that the asserted memref.alignedPtr isn't used anywhere else, as the
    // real users like load/store/views always re-extract memref.alignedPtr as
    // they get lowered.
    //
    // This relies on LLVM's CSE optimization (potentially after SROA), since
    // after CSE all memref.alignedPtr instances get de-duplicated into the
    // same pointer SSA value.
    auto intPtrType =
        getIntPtrType(memRefDescriptor.getElementPtrType().getAddressSpace());
    Value zero = createIndexAttrConstant(rewriter, loc, intPtrType, 0);
    Value mask =
        createIndexAttrConstant(rewriter, loc, intPtrType, alignment - 1);
    Value ptrValue = rewriter.create<LLVM::PtrToIntOp>(loc, intPtrType, ptr);
    rewriter.create<LLVM::AssumeOp>(
        loc, rewriter.create<LLVM::ICmpOp>(
                 loc, LLVM::ICmpPredicate::eq,
                 rewriter.create<LLVM::AndOp>(loc, ptrValue, mask), zero));

    rewriter.eraseOp(op);
    return success();
  }
};

// A `dealloc` is converted into a call to `free` on the underlying data
// buffer. Since the memref descriptor is an SSA value, there is no need to
// clean it up in any way.
struct DeallocOpLowering : public ConvertOpToLLVMPattern<memref::DeallocOp> {
  using ConvertOpToLLVMPattern<memref::DeallocOp>::ConvertOpToLLVMPattern;

  explicit DeallocOpLowering(LLVMTypeConverter &converter)
      : ConvertOpToLLVMPattern<memref::DeallocOp>(converter) {}

  LogicalResult
  matchAndRewrite(memref::DeallocOp op, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    // Insert the `free` declaration if it is not already present.
    auto freeFunc = LLVM::lookupOrCreateFreeFn(op->getParentOfType<ModuleOp>());
    MemRefDescriptor memref(adaptor.getMemref());
    Value casted = rewriter.create<LLVM::BitcastOp>(
        op.getLoc(), getVoidPtrType(),
        memref.allocatedPtr(rewriter, op.getLoc()));
    rewriter.replaceOpWithNewOp<LLVM::CallOp>(
        op, TypeRange(), SymbolRefAttr::get(freeFunc), casted);
    return success();
  }
};

// A `dim` is converted to a constant for static sizes and to an access to the
// size stored in the memref descriptor for dynamic sizes.
struct DimOpLowering : public ConvertOpToLLVMPattern<memref::DimOp> {
  using ConvertOpToLLVMPattern<memref::DimOp>::ConvertOpToLLVMPattern;

  LogicalResult
  matchAndRewrite(memref::DimOp dimOp, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    Type operandType = dimOp.getSource().getType();
    if (operandType.isa<UnrankedMemRefType>()) {
      rewriter.replaceOp(
          dimOp, {extractSizeOfUnrankedMemRef(
                     operandType, dimOp, adaptor.getOperands(), rewriter)});

      return success();
    }
    if (operandType.isa<MemRefType>()) {
      rewriter.replaceOp(
          dimOp, {extractSizeOfRankedMemRef(operandType, dimOp,
                                            adaptor.getOperands(), rewriter)});
      return success();
    }
    llvm_unreachable("expected MemRefType or UnrankedMemRefType");
  }

private:
  Value extractSizeOfUnrankedMemRef(Type operandType, memref::DimOp dimOp,
                                    OpAdaptor adaptor,
                                    ConversionPatternRewriter &rewriter) const {
    Location loc = dimOp.getLoc();

    auto unrankedMemRefType = operandType.cast<UnrankedMemRefType>();
    auto scalarMemRefType =
        MemRefType::get({}, unrankedMemRefType.getElementType());
    unsigned addressSpace = unrankedMemRefType.getMemorySpaceAsInt();

    // Extract pointer to the underlying ranked descriptor and bitcast it to a
    // memref<element_type> descriptor pointer to minimize the number of GEP
    // operations.
    UnrankedMemRefDescriptor unrankedDesc(adaptor.getSource());
    Value underlyingRankedDesc = unrankedDesc.memRefDescPtr(rewriter, loc);
    Value scalarMemRefDescPtr = rewriter.create<LLVM::BitcastOp>(
        loc,
        LLVM::LLVMPointerType::get(typeConverter->convertType(scalarMemRefType),
                                   addressSpace),
        underlyingRankedDesc);

    // Get pointer to offset field of memref<element_type> descriptor.
    Type indexPtrTy = LLVM::LLVMPointerType::get(
        getTypeConverter()->getIndexType(), addressSpace);
    Value two = rewriter.create<LLVM::ConstantOp>(
        loc, typeConverter->convertType(rewriter.getI32Type()),
        rewriter.getI32IntegerAttr(2));
    Value offsetPtr = rewriter.create<LLVM::GEPOp>(
        loc, indexPtrTy, scalarMemRefDescPtr,
        ValueRange({createIndexConstant(rewriter, loc, 0), two}));

    // The size value that we have to extract can be obtained with a GEP using
    // `dimOp.index() + 1` as the index argument.
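    // Descriptor layout reminder (illustrative): in the ranked descriptor the
    // offset is field 2 and the sizes array immediately follows it, so the
    // i-th size lives (i + 1) index-sized slots past the offset field.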
    Value idxPlusOne = rewriter.create<LLVM::AddOp>(
        loc, createIndexConstant(rewriter, loc, 1), adaptor.getIndex());
    Value sizePtr = rewriter.create<LLVM::GEPOp>(loc, indexPtrTy, offsetPtr,
                                                 ValueRange({idxPlusOne}));
    return rewriter.create<LLVM::LoadOp>(loc, sizePtr);
  }

  Optional<int64_t> getConstantDimIndex(memref::DimOp dimOp) const {
    if (Optional<int64_t> idx = dimOp.getConstantIndex())
      return idx;

    if (auto constantOp = dimOp.getIndex().getDefiningOp<LLVM::ConstantOp>())
      return constantOp.getValue()
          .cast<IntegerAttr>()
          .getValue()
          .getSExtValue();

    return llvm::None;
  }

  Value extractSizeOfRankedMemRef(Type operandType, memref::DimOp dimOp,
                                  OpAdaptor adaptor,
                                  ConversionPatternRewriter &rewriter) const {
    Location loc = dimOp.getLoc();

    // Take advantage of the constant index when available.
    MemRefType memRefType = operandType.cast<MemRefType>();
    if (Optional<int64_t> index = getConstantDimIndex(dimOp)) {
      int64_t i = *index;
      if (memRefType.isDynamicDim(i)) {
        // Extract the dynamic size from the memref descriptor.
        MemRefDescriptor descriptor(adaptor.getSource());
        return descriptor.size(rewriter, loc, i);
      }
      // Use a constant for the static size.
      int64_t dimSize = memRefType.getDimSize(i);
      return createIndexConstant(rewriter, loc, dimSize);
    }
    Value index = adaptor.getIndex();
    int64_t rank = memRefType.getRank();
    MemRefDescriptor memrefDescriptor(adaptor.getSource());
    return memrefDescriptor.size(rewriter, loc, index, rank);
  }
};

/// Common base for load and store operations on MemRefs. Restricts the match
/// to supported MemRef types. Provides functionality to emit code accessing a
/// specific element of the underlying data buffer.
template <typename Derived>
struct LoadStoreOpLowering : public ConvertOpToLLVMPattern<Derived> {
  using ConvertOpToLLVMPattern<Derived>::ConvertOpToLLVMPattern;
  using ConvertOpToLLVMPattern<Derived>::isConvertibleAndHasIdentityMaps;
  using Base = LoadStoreOpLowering<Derived>;

  LogicalResult match(Derived op) const override {
    MemRefType type = op.getMemRefType();
    return isConvertibleAndHasIdentityMaps(type) ? success() : failure();
  }
};

/// Wrap a llvm.cmpxchg operation in a while loop so that the operation can be
/// retried until it succeeds in atomically storing a new value into memory.
///
///      +---------------------------------+
///      |   <code before the AtomicRMWOp> |
///      |   <compute initial %loaded>     |
///      |   cf.br loop(%loaded)           |
///      +---------------------------------+
///             |
///  -------|   |
///  |      v   v
///  |   +--------------------------------+
///  |   | loop(%loaded):                 |
///  |   |   <body contents>              |
///  |   |   %pair = cmpxchg              |
///  |   |   %new = %pair[0]              |
///  |   |   %ok = %pair[1]               |
///  |   |   cf.cond_br %ok, end, loop(%new) |
///  |   +--------------------------------+
///  |          |        |
///  |-----------        |
///                      v
///      +--------------------------------+
///      | end:                           |
///      |   <code after the AtomicRMWOp> |
///      +--------------------------------+
///
struct GenericAtomicRMWOpLowering
    : public LoadStoreOpLowering<memref::GenericAtomicRMWOp> {
  using Base::Base;

  LogicalResult
  matchAndRewrite(memref::GenericAtomicRMWOp atomicOp, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    auto loc = atomicOp.getLoc();
    Type valueType = typeConverter->convertType(atomicOp.getResult().getType());

    // Split the block into initial, loop, and ending parts.
    auto *initBlock = rewriter.getInsertionBlock();
    auto *loopBlock = rewriter.createBlock(
        initBlock->getParent(), std::next(Region::iterator(initBlock)),
        valueType, loc);
    auto *endBlock = rewriter.createBlock(
        loopBlock->getParent(), std::next(Region::iterator(loopBlock)));

    // Operations range to be moved to `endBlock`.
    auto opsToMoveStart = atomicOp->getIterator();
    auto opsToMoveEnd = initBlock->back().getIterator();

    // Compute the loaded value and branch to the loop block.
    rewriter.setInsertionPointToEnd(initBlock);
    auto memRefType = atomicOp.getMemref().getType().cast<MemRefType>();
    auto dataPtr = getStridedElementPtr(loc, memRefType, adaptor.getMemref(),
                                        adaptor.getIndices(), rewriter);
    Value init = rewriter.create<LLVM::LoadOp>(loc, dataPtr);
    rewriter.create<LLVM::BrOp>(loc, init, loopBlock);

    // Prepare the body of the loop block.
    rewriter.setInsertionPointToStart(loopBlock);

    // Clone the GenericAtomicRMWOp region and extract the result.
    auto loopArgument = loopBlock->getArgument(0);
    BlockAndValueMapping mapping;
    mapping.map(atomicOp.getCurrentValue(), loopArgument);
    Block &entryBlock = atomicOp.body().front();
    for (auto &nestedOp : entryBlock.without_terminator()) {
      Operation *clone = rewriter.clone(nestedOp, mapping);
      mapping.map(nestedOp.getResults(), clone->getResults());
    }
    Value result = mapping.lookup(entryBlock.getTerminator()->getOperand(0));

    // Prepare the epilog of the loop block.
    // Append the cmpxchg op to the end of the loop block.
    auto successOrdering = LLVM::AtomicOrdering::acq_rel;
    auto failureOrdering = LLVM::AtomicOrdering::monotonic;
    auto boolType = IntegerType::get(rewriter.getContext(), 1);
    auto pairType = LLVM::LLVMStructType::getLiteral(rewriter.getContext(),
                                                     {valueType, boolType});
    auto cmpxchg = rewriter.create<LLVM::AtomicCmpXchgOp>(
        loc, pairType, dataPtr, loopArgument, result, successOrdering,
        failureOrdering);
    // Extract the %new_loaded and %ok values from the pair.
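    // llvm.cmpxchg produces a {valueType, i1} pair: element 0 is the value
    // actually read from memory, element 1 tells whether the exchange
    // succeeded.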
    Value newLoaded = rewriter.create<LLVM::ExtractValueOp>(
        loc, valueType, cmpxchg, rewriter.getI64ArrayAttr({0}));
    Value ok = rewriter.create<LLVM::ExtractValueOp>(
        loc, boolType, cmpxchg, rewriter.getI64ArrayAttr({1}));

    // Conditionally branch to the end or back to the loop depending on %ok.
    rewriter.create<LLVM::CondBrOp>(loc, ok, endBlock, ArrayRef<Value>(),
                                    loopBlock, newLoaded);

    rewriter.setInsertionPointToEnd(endBlock);
    moveOpsRange(atomicOp.getResult(), newLoaded, std::next(opsToMoveStart),
                 std::next(opsToMoveEnd), rewriter);

    // The 'result' of the atomic_rmw op is the newly loaded value.
    rewriter.replaceOp(atomicOp, {newLoaded});

    return success();
  }

private:
  // Clones a segment of ops [start, end) and erases the original ops.
  void moveOpsRange(ValueRange oldResult, ValueRange newResult,
                    Block::iterator start, Block::iterator end,
                    ConversionPatternRewriter &rewriter) const {
    BlockAndValueMapping mapping;
    mapping.map(oldResult, newResult);
    SmallVector<Operation *, 2> opsToErase;
    for (auto it = start; it != end; ++it) {
      rewriter.clone(*it, mapping);
      opsToErase.push_back(&*it);
    }
    for (auto *it : opsToErase)
      rewriter.eraseOp(it);
  }
};

/// Returns the LLVM type of the global variable given the memref type `type`.
static Type convertGlobalMemrefTypeToLLVM(MemRefType type,
                                          LLVMTypeConverter &typeConverter) {
  // LLVM type for a global memref will be a multi-dimensional array. For
  // declarations or uninitialized global memrefs, we can potentially flatten
  // this to a 1D array. However, for memref.global's with an initial value,
  // we do not intend to flatten the ElementsAttribute when going from std ->
  // LLVM dialect, so the LLVM type needs to be a multi-dimensional array.
  Type elementType = typeConverter.convertType(type.getElementType());
  Type arrayTy = elementType;
  // Shape has the outermost dim at index 0, so we need to walk it backwards.
  for (int64_t dim : llvm::reverse(type.getShape()))
    arrayTy = LLVM::LLVMArrayType::get(arrayTy, dim);
  return arrayTy;
}

/// GlobalMemrefOp is lowered to a LLVM Global Variable.
struct GlobalMemrefOpLowering
    : public ConvertOpToLLVMPattern<memref::GlobalOp> {
  using ConvertOpToLLVMPattern<memref::GlobalOp>::ConvertOpToLLVMPattern;

  LogicalResult
  matchAndRewrite(memref::GlobalOp global, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    MemRefType type = global.getType();
    if (!isConvertibleAndHasIdentityMaps(type))
      return failure();

    Type arrayTy = convertGlobalMemrefTypeToLLVM(type, *getTypeConverter());

    LLVM::Linkage linkage =
        global.isPublic() ? LLVM::Linkage::External : LLVM::Linkage::Private;

    Attribute initialValue = nullptr;
    if (!global.isExternal() && !global.isUninitialized()) {
      auto elementsAttr = global.getInitialValue()->cast<ElementsAttr>();
      initialValue = elementsAttr;

      // For scalar memrefs, the global variable created is of the element
      // type, so unpack the elements attribute to extract the value.
      if (type.getRank() == 0)
        initialValue = elementsAttr.getSplatValue<Attribute>();
    }

    uint64_t alignment = global.getAlignment().value_or(0);

    auto newGlobal = rewriter.replaceOpWithNewOp<LLVM::GlobalOp>(
        global, arrayTy, global.getConstant(), linkage, global.getSymName(),
        initialValue, alignment, type.getMemorySpaceAsInt());
    if (!global.isExternal() && global.isUninitialized()) {
      Block *blk = new Block();
      newGlobal.getInitializerRegion().push_back(blk);
      rewriter.setInsertionPointToStart(blk);
      Value undef[] = {
          rewriter.create<LLVM::UndefOp>(global.getLoc(), arrayTy)};
      rewriter.create<LLVM::ReturnOp>(global.getLoc(), undef);
    }
    return success();
  }
};

/// GetGlobalMemrefOp is lowered into a MemRef descriptor with the pointer to
/// the first element stashed into the descriptor. This derives from
/// `AllocLikeOpLLVMLowering` to reuse the MemRef descriptor construction.
struct GetGlobalMemrefOpLowering : public AllocLikeOpLLVMLowering {
  GetGlobalMemrefOpLowering(LLVMTypeConverter &converter)
      : AllocLikeOpLLVMLowering(memref::GetGlobalOp::getOperationName(),
                                converter) {}

  /// Buffer "allocation" for the memref.get_global op is getting the address
  /// of the referenced global variable.
  std::tuple<Value, Value> allocateBuffer(ConversionPatternRewriter &rewriter,
                                          Location loc, Value sizeBytes,
                                          Operation *op) const override {
    auto getGlobalOp = cast<memref::GetGlobalOp>(op);
    MemRefType type = getGlobalOp.getResult().getType().cast<MemRefType>();
    unsigned memSpace = type.getMemorySpaceAsInt();

    Type arrayTy = convertGlobalMemrefTypeToLLVM(type, *getTypeConverter());
    auto addressOf = rewriter.create<LLVM::AddressOfOp>(
        loc, LLVM::LLVMPointerType::get(arrayTy, memSpace),
        getGlobalOp.getName());

    // Get the address of the first element in the array by creating a GEP with
    // the address of the GV as the base, and (rank + 1) number of 0 indices.
    Type elementType = typeConverter->convertType(type.getElementType());
    Type elementPtrType = LLVM::LLVMPointerType::get(elementType, memSpace);

    SmallVector<Value> operands;
    operands.insert(operands.end(), type.getRank() + 1,
                    createIndexConstant(rewriter, loc, 0));
    auto gep =
        rewriter.create<LLVM::GEPOp>(loc, elementPtrType, addressOf, operands);

    // We do not expect the memref obtained using `memref.get_global` to ever
    // be deallocated. Set the allocated pointer to a known bad value to help
    // debug if that ever happens.
    auto intPtrType = getIntPtrType(memSpace);
    Value deadBeefConst =
        createIndexAttrConstant(rewriter, op->getLoc(), intPtrType, 0xdeadbeef);
    auto deadBeefPtr =
        rewriter.create<LLVM::IntToPtrOp>(loc, elementPtrType, deadBeefConst);

    // Both allocated and aligned pointers are the same. We could potentially
    // stash a nullptr for the allocated pointer since we do not expect any
    // dealloc.
    return std::make_tuple(deadBeefPtr, gep);
  }
};

// Load operation is lowered to obtaining a pointer to the indexed element
// and loading it.
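// For a ranked memref the address is computed from the descriptor roughly as
//   aligned_ptr + offset + i0 * stride0 + i1 * stride1 + ...
// (in element units) by getStridedElementPtr, and the element is then loaded
// from that pointer.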
struct LoadOpLowering : public LoadStoreOpLowering<memref::LoadOp> {
  using Base::Base;

  LogicalResult
  matchAndRewrite(memref::LoadOp loadOp, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    auto type = loadOp.getMemRefType();

    Value dataPtr =
        getStridedElementPtr(loadOp.getLoc(), type, adaptor.getMemref(),
                             adaptor.getIndices(), rewriter);
    rewriter.replaceOpWithNewOp<LLVM::LoadOp>(loadOp, dataPtr);
    return success();
  }
};

// Store operation is lowered to obtaining a pointer to the indexed element,
// and storing the given value to it.
struct StoreOpLowering : public LoadStoreOpLowering<memref::StoreOp> {
  using Base::Base;

  LogicalResult
  matchAndRewrite(memref::StoreOp op, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    auto type = op.getMemRefType();

    Value dataPtr = getStridedElementPtr(op.getLoc(), type, adaptor.getMemref(),
                                         adaptor.getIndices(), rewriter);
    rewriter.replaceOpWithNewOp<LLVM::StoreOp>(op, adaptor.getValue(), dataPtr);
    return success();
  }
};

// The prefetch operation is lowered in a way similar to the load operation
// except that the llvm.prefetch operation is used for replacement.
struct PrefetchOpLowering : public LoadStoreOpLowering<memref::PrefetchOp> {
  using Base::Base;

  LogicalResult
  matchAndRewrite(memref::PrefetchOp prefetchOp, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    auto type = prefetchOp.getMemRefType();
    auto loc = prefetchOp.getLoc();

    Value dataPtr = getStridedElementPtr(loc, type, adaptor.getMemref(),
                                         adaptor.getIndices(), rewriter);

    // Replace with llvm.prefetch.
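    // llvm.prefetch takes three i32 flags besides the address: whether the
    // access will be a write (0/1), a locality hint (0-3, 3 = keep in cache
    // as long as possible), and whether the data (1) or instruction (0) cache
    // is targeted.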
    auto llvmI32Type = typeConverter->convertType(rewriter.getIntegerType(32));
    auto isWrite = rewriter.create<LLVM::ConstantOp>(
        loc, llvmI32Type, rewriter.getI32IntegerAttr(prefetchOp.getIsWrite()));
    auto localityHint = rewriter.create<LLVM::ConstantOp>(
        loc, llvmI32Type,
        rewriter.getI32IntegerAttr(prefetchOp.getLocalityHint()));
    auto isData = rewriter.create<LLVM::ConstantOp>(
        loc, llvmI32Type,
        rewriter.getI32IntegerAttr(prefetchOp.getIsDataCache()));

    rewriter.replaceOpWithNewOp<LLVM::Prefetch>(prefetchOp, dataPtr, isWrite,
                                                localityHint, isData);
    return success();
  }
};

struct RankOpLowering : public ConvertOpToLLVMPattern<memref::RankOp> {
  using ConvertOpToLLVMPattern<memref::RankOp>::ConvertOpToLLVMPattern;

  LogicalResult
  matchAndRewrite(memref::RankOp op, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    Location loc = op.getLoc();
    Type operandType = op.getMemref().getType();
    if (auto unrankedMemRefType = operandType.dyn_cast<UnrankedMemRefType>()) {
      UnrankedMemRefDescriptor desc(adaptor.getMemref());
      rewriter.replaceOp(op, {desc.rank(rewriter, loc)});
      return success();
    }
    if (auto rankedMemRefType = operandType.dyn_cast<MemRefType>()) {
      rewriter.replaceOp(
          op, {createIndexConstant(rewriter, loc, rankedMemRefType.getRank())});
      return success();
    }
    return failure();
  }
};

struct MemRefCastOpLowering : public ConvertOpToLLVMPattern<memref::CastOp> {
  using ConvertOpToLLVMPattern<memref::CastOp>::ConvertOpToLLVMPattern;

  LogicalResult match(memref::CastOp memRefCastOp) const override {
    Type srcType = memRefCastOp.getOperand().getType();
    Type dstType = memRefCastOp.getType();

    // memref::CastOp reduces to a bitcast in the ranked MemRef case and can be
    // used for type erasure. For now it must preserve the underlying element
    // type and requires the source and result types to have the same rank.
    // Therefore, perform a sanity check that the underlying structs are the
    // same. Once op semantics are relaxed we can revisit.
    if (srcType.isa<MemRefType>() && dstType.isa<MemRefType>())
      return success(typeConverter->convertType(srcType) ==
                     typeConverter->convertType(dstType));

    // At least one of the operands is an unranked memref type.
    assert(srcType.isa<UnrankedMemRefType>() ||
           dstType.isa<UnrankedMemRefType>());

    // Unranked-to-unranked casts are disallowed.
    return !(srcType.isa<UnrankedMemRefType>() &&
             dstType.isa<UnrankedMemRefType>())
               ? success()
               : failure();
  }

  void rewrite(memref::CastOp memRefCastOp, OpAdaptor adaptor,
               ConversionPatternRewriter &rewriter) const override {
    auto srcType = memRefCastOp.getOperand().getType();
    auto dstType = memRefCastOp.getType();
    auto targetStructType = typeConverter->convertType(memRefCastOp.getType());
    auto loc = memRefCastOp.getLoc();

    // For the ranked/ranked case, just keep the original descriptor.
    if (srcType.isa<MemRefType>() && dstType.isa<MemRefType>())
      return rewriter.replaceOp(memRefCastOp, {adaptor.getSource()});

    if (srcType.isa<MemRefType>() && dstType.isa<UnrankedMemRefType>()) {
      // Casting ranked to unranked memref type:
      // Set the rank in the destination from the memref type.
      // Allocate space on the stack and copy the src memref descriptor.
      // Set the ptr in the destination to the stack space.
      auto srcMemRefType = srcType.cast<MemRefType>();
      int64_t rank = srcMemRefType.getRank();
      // ptr = AllocaOp sizeof(MemRefDescriptor)
      auto ptr = getTypeConverter()->promoteOneMemRefDescriptor(
          loc, adaptor.getSource(), rewriter);
      // voidptr = BitCastOp srcType* to void*
      auto voidPtr =
          rewriter.create<LLVM::BitcastOp>(loc, getVoidPtrType(), ptr)
              .getResult();
      // rank = ConstantOp srcRank
      auto rankVal = rewriter.create<LLVM::ConstantOp>(
          loc, getIndexType(), rewriter.getIndexAttr(rank));
      // undef = UndefOp
      UnrankedMemRefDescriptor memRefDesc =
          UnrankedMemRefDescriptor::undef(rewriter, loc, targetStructType);
      // d1 = InsertValueOp undef, rank, 0
      memRefDesc.setRank(rewriter, loc, rankVal);
      // d2 = InsertValueOp d1, voidptr, 1
      memRefDesc.setMemRefDescPtr(rewriter, loc, voidPtr);
      rewriter.replaceOp(memRefCastOp, (Value)memRefDesc);

    } else if (srcType.isa<UnrankedMemRefType>() && dstType.isa<MemRefType>()) {
      // Casting from unranked type to ranked.
      // The operation is assumed to be doing a correct cast. If the
      // destination type does not match the type wrapped by the unranked
      // descriptor, the behavior is undefined.
      UnrankedMemRefDescriptor memRefDesc(adaptor.getSource());
      // ptr = ExtractValueOp src, 1
      auto ptr = memRefDesc.memRefDescPtr(rewriter, loc);
      // castPtr = BitCastOp i8* to structTy*
      auto castPtr =
          rewriter
              .create<LLVM::BitcastOp>(
                  loc, LLVM::LLVMPointerType::get(targetStructType), ptr)
              .getResult();
      // struct = LoadOp castPtr
      auto loadOp = rewriter.create<LLVM::LoadOp>(loc, castPtr);
      rewriter.replaceOp(memRefCastOp, loadOp.getResult());
    } else {
      llvm_unreachable("Unsupported unranked memref to unranked memref cast");
    }
  }
};

/// Pattern to lower a `memref.copy` to llvm.
///
/// For memrefs with identity layouts, the copy is lowered to the llvm
/// `memcpy` intrinsic. For non-identity layouts, the copy is lowered to a call
/// to the generic `MemrefCopyFn`.
struct MemRefCopyOpLowering : public ConvertOpToLLVMPattern<memref::CopyOp> {
  using ConvertOpToLLVMPattern<memref::CopyOp>::ConvertOpToLLVMPattern;

  LogicalResult
  lowerToMemCopyIntrinsic(memref::CopyOp op, OpAdaptor adaptor,
                          ConversionPatternRewriter &rewriter) const {
    auto loc = op.getLoc();
    auto srcType = op.getSource().getType().dyn_cast<MemRefType>();

    MemRefDescriptor srcDesc(adaptor.getSource());

    // Compute the number of elements.
    Value numElements = rewriter.create<LLVM::ConstantOp>(
        loc, getIndexType(), rewriter.getIndexAttr(1));
    for (int pos = 0; pos < srcType.getRank(); ++pos) {
      auto size = srcDesc.size(rewriter, loc, pos);
      numElements = rewriter.create<LLVM::MulOp>(loc, numElements, size);
    }

    // Get the element size.
    auto sizeInBytes = getSizeInBytes(loc, srcType.getElementType(), rewriter);
    // Compute the total size in bytes.
    Value totalSize =
        rewriter.create<LLVM::MulOp>(loc, numElements, sizeInBytes);

    Value srcBasePtr = srcDesc.alignedPtr(rewriter, loc);
    Value srcOffset = srcDesc.offset(rewriter, loc);
    Value srcPtr = rewriter.create<LLVM::GEPOp>(loc, srcBasePtr.getType(),
                                                srcBasePtr, srcOffset);
    MemRefDescriptor targetDesc(adaptor.getTarget());
    Value targetBasePtr = targetDesc.alignedPtr(rewriter, loc);
    Value targetOffset = targetDesc.offset(rewriter, loc);
    Value targetPtr = rewriter.create<LLVM::GEPOp>(loc, targetBasePtr.getType(),
                                                   targetBasePtr, targetOffset);
    Value isVolatile = rewriter.create<LLVM::ConstantOp>(
        loc, typeConverter->convertType(rewriter.getI1Type()),
        rewriter.getBoolAttr(false));
    rewriter.create<LLVM::MemcpyOp>(loc, targetPtr, srcPtr, totalSize,
                                    isVolatile);
    rewriter.eraseOp(op);

    return success();
  }

  LogicalResult
  lowerToMemCopyFunctionCall(memref::CopyOp op, OpAdaptor adaptor,
                             ConversionPatternRewriter &rewriter) const {
    auto loc = op.getLoc();
    auto srcType = op.getSource().getType().cast<BaseMemRefType>();
    auto targetType = op.getTarget().getType().cast<BaseMemRefType>();

    // First make sure we have an unranked memref descriptor representation.
    auto makeUnranked = [&, this](Value ranked, BaseMemRefType type) {
      auto rank = rewriter.create<LLVM::ConstantOp>(
          loc, getIndexType(), rewriter.getIndexAttr(type.getRank()));
      auto *typeConverter = getTypeConverter();
      auto ptr =
          typeConverter->promoteOneMemRefDescriptor(loc, ranked, rewriter);
      auto voidPtr =
          rewriter.create<LLVM::BitcastOp>(loc, getVoidPtrType(), ptr)
              .getResult();
      auto unrankedType =
          UnrankedMemRefType::get(type.getElementType(), type.getMemorySpace());
      return UnrankedMemRefDescriptor::pack(rewriter, loc, *typeConverter,
                                            unrankedType,
                                            ValueRange{rank, voidPtr});
    };

    Value unrankedSource = srcType.hasRank()
                               ? makeUnranked(adaptor.getSource(), srcType)
                               : adaptor.getSource();
    Value unrankedTarget = targetType.hasRank()
                               ? makeUnranked(adaptor.getTarget(), targetType)
                               : adaptor.getTarget();

    // Now promote the unranked descriptors to the stack.
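    // The generic copy function takes the descriptors by pointer, so each
    // unranked descriptor value is stored into a dedicated stack slot below
    // and the slot addresses are what get passed to the call.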
    auto one = rewriter.create<LLVM::ConstantOp>(loc, getIndexType(),
                                                 rewriter.getIndexAttr(1));
    auto promote = [&](Value desc) {
      auto ptrType = LLVM::LLVMPointerType::get(desc.getType());
      auto allocated =
          rewriter.create<LLVM::AllocaOp>(loc, ptrType, ValueRange{one});
      rewriter.create<LLVM::StoreOp>(loc, desc, allocated);
      return allocated;
    };

    auto sourcePtr = promote(unrankedSource);
    auto targetPtr = promote(unrankedTarget);

    unsigned typeSize =
        mlir::DataLayout::closest(op).getTypeSize(srcType.getElementType());
    auto elemSize = rewriter.create<LLVM::ConstantOp>(
        loc, getIndexType(), rewriter.getIndexAttr(typeSize));
    auto copyFn = LLVM::lookupOrCreateMemRefCopyFn(
        op->getParentOfType<ModuleOp>(), getIndexType(), sourcePtr.getType());
    rewriter.create<LLVM::CallOp>(loc, copyFn,
                                  ValueRange{elemSize, sourcePtr, targetPtr});
    rewriter.eraseOp(op);

    return success();
  }

  LogicalResult
  matchAndRewrite(memref::CopyOp op, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    auto srcType = op.getSource().getType().cast<BaseMemRefType>();
    auto targetType = op.getTarget().getType().cast<BaseMemRefType>();

    auto isContiguousMemrefType = [](BaseMemRefType type) {
      auto memrefType = type.dyn_cast<mlir::MemRefType>();
      // We can use memcpy for memrefs if they have an identity layout or are
      // contiguous with an arbitrary offset. Ignore empty memrefs, which are a
      // special case handled by memrefCopy.
      return memrefType &&
             (memrefType.getLayout().isIdentity() ||
              (memrefType.hasStaticShape() && memrefType.getNumElements() > 0 &&
               isStaticShapeAndContiguousRowMajor(memrefType)));
    };

    if (isContiguousMemrefType(srcType) && isContiguousMemrefType(targetType))
      return lowerToMemCopyIntrinsic(op, adaptor, rewriter);

    return lowerToMemCopyFunctionCall(op, adaptor, rewriter);
  }
};

/// Extracts the allocated and aligned pointers and the offset from a ranked or
/// unranked memref type. In the unranked case, the fields are extracted from
/// the underlying ranked descriptor.
static void extractPointersAndOffset(Location loc,
                                     ConversionPatternRewriter &rewriter,
                                     LLVMTypeConverter &typeConverter,
                                     Value originalOperand,
                                     Value convertedOperand,
                                     Value *allocatedPtr, Value *alignedPtr,
                                     Value *offset = nullptr) {
  Type operandType = originalOperand.getType();
  if (operandType.isa<MemRefType>()) {
    MemRefDescriptor desc(convertedOperand);
    *allocatedPtr = desc.allocatedPtr(rewriter, loc);
    *alignedPtr = desc.alignedPtr(rewriter, loc);
    if (offset != nullptr)
      *offset = desc.offset(rewriter, loc);
    return;
  }

  unsigned memorySpace =
      operandType.cast<UnrankedMemRefType>().getMemorySpaceAsInt();
  Type elementType = operandType.cast<UnrankedMemRefType>().getElementType();
  Type llvmElementType = typeConverter.convertType(elementType);
  Type elementPtrPtrType = LLVM::LLVMPointerType::get(
      LLVM::LLVMPointerType::get(llvmElementType, memorySpace));

  // Extract pointer to the underlying ranked memref descriptor and cast it to
  // ElemType**.
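  // In the unranked case the ranked descriptor sits behind an opaque pointer
  // stored in the unranked descriptor, so the fields below are read through
  // that pointer instead of with direct extractvalue operations.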
  UnrankedMemRefDescriptor unrankedDesc(convertedOperand);
  Value underlyingDescPtr = unrankedDesc.memRefDescPtr(rewriter, loc);

  *allocatedPtr = UnrankedMemRefDescriptor::allocatedPtr(
      rewriter, loc, underlyingDescPtr, elementPtrPtrType);
  *alignedPtr = UnrankedMemRefDescriptor::alignedPtr(
      rewriter, loc, typeConverter, underlyingDescPtr, elementPtrPtrType);
  if (offset != nullptr) {
    *offset = UnrankedMemRefDescriptor::offset(
        rewriter, loc, typeConverter, underlyingDescPtr, elementPtrPtrType);
  }
}

struct MemRefReinterpretCastOpLowering
    : public ConvertOpToLLVMPattern<memref::ReinterpretCastOp> {
  using ConvertOpToLLVMPattern<
      memref::ReinterpretCastOp>::ConvertOpToLLVMPattern;

  LogicalResult
  matchAndRewrite(memref::ReinterpretCastOp castOp, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    Type srcType = castOp.getSource().getType();

    Value descriptor;
    if (failed(convertSourceMemRefToDescriptor(rewriter, srcType, castOp,
                                               adaptor, &descriptor)))
      return failure();
    rewriter.replaceOp(castOp, {descriptor});
    return success();
  }

private:
  LogicalResult convertSourceMemRefToDescriptor(
      ConversionPatternRewriter &rewriter, Type srcType,
      memref::ReinterpretCastOp castOp,
      memref::ReinterpretCastOp::Adaptor adaptor, Value *descriptor) const {
    MemRefType targetMemRefType =
        castOp.getResult().getType().cast<MemRefType>();
    auto llvmTargetDescriptorTy = typeConverter->convertType(targetMemRefType)
                                      .dyn_cast_or_null<LLVM::LLVMStructType>();
    if (!llvmTargetDescriptorTy)
      return failure();

    // Create descriptor.
    Location loc = castOp.getLoc();
    auto desc = MemRefDescriptor::undef(rewriter, loc, llvmTargetDescriptorTy);

    // Set allocated and aligned pointers.
    Value allocatedPtr, alignedPtr;
    extractPointersAndOffset(loc, rewriter, *getTypeConverter(),
                             castOp.getSource(), adaptor.getSource(),
                             &allocatedPtr, &alignedPtr);
    desc.setAllocatedPtr(rewriter, loc, allocatedPtr);
    desc.setAlignedPtr(rewriter, loc, alignedPtr);

    // Set offset.
    if (castOp.isDynamicOffset(0))
      desc.setOffset(rewriter, loc, adaptor.getOffsets()[0]);
    else
      desc.setConstantOffset(rewriter, loc, castOp.getStaticOffset(0));

    // Set sizes and strides.
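    // Dynamic sizes and strides are consumed from the adaptor's operand lists
    // in order; dynSizeId and dynStrideId below track how many dynamic values
    // have been used so far, while static ones become constants.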
    unsigned dynSizeId = 0;
    unsigned dynStrideId = 0;
    for (unsigned i = 0, e = targetMemRefType.getRank(); i < e; ++i) {
      if (castOp.isDynamicSize(i))
        desc.setSize(rewriter, loc, i, adaptor.getSizes()[dynSizeId++]);
      else
        desc.setConstantSize(rewriter, loc, i, castOp.getStaticSize(i));

      if (castOp.isDynamicStride(i))
        desc.setStride(rewriter, loc, i, adaptor.getStrides()[dynStrideId++]);
      else
        desc.setConstantStride(rewriter, loc, i, castOp.getStaticStride(i));
    }
    *descriptor = desc;
    return success();
  }
};

struct MemRefReshapeOpLowering
    : public ConvertOpToLLVMPattern<memref::ReshapeOp> {
  using ConvertOpToLLVMPattern<memref::ReshapeOp>::ConvertOpToLLVMPattern;

  LogicalResult
  matchAndRewrite(memref::ReshapeOp reshapeOp, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    Type srcType = reshapeOp.getSource().getType();

    Value descriptor;
    if (failed(convertSourceMemRefToDescriptor(rewriter, srcType, reshapeOp,
                                               adaptor, &descriptor)))
      return failure();
    rewriter.replaceOp(reshapeOp, {descriptor});
    return success();
  }

private:
  LogicalResult
  convertSourceMemRefToDescriptor(ConversionPatternRewriter &rewriter,
                                  Type srcType, memref::ReshapeOp reshapeOp,
                                  memref::ReshapeOp::Adaptor adaptor,
                                  Value *descriptor) const {
    auto shapeMemRefType = reshapeOp.getShape().getType().cast<MemRefType>();
    if (shapeMemRefType.hasStaticShape()) {
      MemRefType targetMemRefType =
          reshapeOp.getResult().getType().cast<MemRefType>();
      auto llvmTargetDescriptorTy =
          typeConverter->convertType(targetMemRefType)
              .dyn_cast_or_null<LLVM::LLVMStructType>();
      if (!llvmTargetDescriptorTy)
        return failure();

      // Create descriptor.
      Location loc = reshapeOp.getLoc();
      auto desc =
          MemRefDescriptor::undef(rewriter, loc, llvmTargetDescriptorTy);

      // Set allocated and aligned pointers.
      Value allocatedPtr, alignedPtr;
      extractPointersAndOffset(loc, rewriter, *getTypeConverter(),
                               reshapeOp.getSource(), adaptor.getSource(),
                               &allocatedPtr, &alignedPtr);
      desc.setAllocatedPtr(rewriter, loc, allocatedPtr);
      desc.setAlignedPtr(rewriter, loc, alignedPtr);

      // Extract the offset and strides from the type.
      int64_t offset;
      SmallVector<int64_t> strides;
      if (failed(getStridesAndOffset(targetMemRefType, strides, offset)))
        return rewriter.notifyMatchFailure(
            reshapeOp, "failed to get stride and offset exprs");

      if (!isStaticStrideOrOffset(offset))
        return rewriter.notifyMatchFailure(reshapeOp,
                                           "dynamic offset is unsupported");

      desc.setConstantOffset(rewriter, loc, offset);

      assert(targetMemRefType.getLayout().isIdentity() &&
             "Identity layout map is a precondition of a valid reshape op");

      Value stride = nullptr;
      int64_t targetRank = targetMemRefType.getRank();
      for (auto i : llvm::reverse(llvm::seq<int64_t>(0, targetRank))) {
        if (!ShapedType::isDynamicStrideOrOffset(strides[i])) {
          // If the stride for this dimension is static, materialize it as a
          // constant; dynamic strides use the running product of the inner
          // dimension sizes computed at the end of each iteration.
          stride = createIndexConstant(rewriter, loc, strides[i]);
        } else if (!stride) {
          // `stride` is null only in the first iteration of the loop.
          // However, since the target memref has an identity layout, we can
          // safely set the innermost stride to 1.
          stride = createIndexConstant(rewriter, loc, 1);
        }

        Value dimSize;
        int64_t size = targetMemRefType.getDimSize(i);
        // If the size of this dimension is dynamic, then load it at runtime
        // from the shape operand.
        if (!ShapedType::isDynamic(size)) {
          dimSize = createIndexConstant(rewriter, loc, size);
        } else {
          Value shapeOp = reshapeOp.getShape();
          Value index = createIndexConstant(rewriter, loc, i);
          dimSize = rewriter.create<memref::LoadOp>(loc, shapeOp, index);
        }

        desc.setSize(rewriter, loc, i, dimSize);
        desc.setStride(rewriter, loc, i, stride);

        // Prepare the stride value for the next dimension.
        stride = rewriter.create<LLVM::MulOp>(loc, stride, dimSize);
      }

      *descriptor = desc;
      return success();
    }

    // The shape is a rank-1 memref with unknown length.
    Location loc = reshapeOp.getLoc();
    MemRefDescriptor shapeDesc(adaptor.getShape());
    Value resultRank = shapeDesc.size(rewriter, loc, 0);

    // Extract the address space and element type.
    auto targetType =
        reshapeOp.getResult().getType().cast<UnrankedMemRefType>();
    unsigned addressSpace = targetType.getMemorySpaceAsInt();
    Type elementType = targetType.getElementType();

    // Create the unranked memref descriptor that holds the ranked one. The
    // inner descriptor is allocated on the stack.
    auto targetDesc = UnrankedMemRefDescriptor::undef(
        rewriter, loc, typeConverter->convertType(targetType));
    targetDesc.setRank(rewriter, loc, resultRank);
    SmallVector<Value, 4> sizes;
    UnrankedMemRefDescriptor::computeSizes(rewriter, loc, *getTypeConverter(),
                                           targetDesc, sizes);
    Value underlyingDescPtr = rewriter.create<LLVM::AllocaOp>(
        loc, getVoidPtrType(), sizes.front(), llvm::None);
    targetDesc.setMemRefDescPtr(rewriter, loc, underlyingDescPtr);

    // Extract pointers and offset from the source memref.
    Value allocatedPtr, alignedPtr, offset;
    extractPointersAndOffset(loc, rewriter, *getTypeConverter(),
                             reshapeOp.getSource(), adaptor.getSource(),
                             &allocatedPtr, &alignedPtr, &offset);

    // Set pointers and offset.
    Type llvmElementType = typeConverter->convertType(elementType);
    auto elementPtrPtrType = LLVM::LLVMPointerType::get(
        LLVM::LLVMPointerType::get(llvmElementType, addressSpace));
    UnrankedMemRefDescriptor::setAllocatedPtr(rewriter, loc, underlyingDescPtr,
                                              elementPtrPtrType, allocatedPtr);
    UnrankedMemRefDescriptor::setAlignedPtr(rewriter, loc, *getTypeConverter(),
                                            underlyingDescPtr,
                                            elementPtrPtrType, alignedPtr);
    UnrankedMemRefDescriptor::setOffset(rewriter, loc, *getTypeConverter(),
                                        underlyingDescPtr, elementPtrPtrType,
                                        offset);

    // Use the offset pointer as base for further addressing. Copy over the new
    // shape and compute strides. For this, we create a loop from rank-1 to 0.
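    // Because the rank is only known at runtime, the loop is emitted as an
    // explicit CFG below: a condition block carrying (index, stride) block
    // arguments, a body block that stores size[index] and the running stride,
    // and a back edge that multiplies the stride by the just-stored size.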
    Value targetSizesBase = UnrankedMemRefDescriptor::sizeBasePtr(
        rewriter, loc, *getTypeConverter(), underlyingDescPtr,
        elementPtrPtrType);
    Value targetStridesBase = UnrankedMemRefDescriptor::strideBasePtr(
        rewriter, loc, *getTypeConverter(), targetSizesBase, resultRank);
    Value shapeOperandPtr = shapeDesc.alignedPtr(rewriter, loc);
    Value oneIndex = createIndexConstant(rewriter, loc, 1);
    Value resultRankMinusOne =
        rewriter.create<LLVM::SubOp>(loc, resultRank, oneIndex);

    Block *initBlock = rewriter.getInsertionBlock();
    Type indexType = getTypeConverter()->getIndexType();
    Block::iterator remainingOpsIt = std::next(rewriter.getInsertionPoint());

    Block *condBlock = rewriter.createBlock(initBlock->getParent(), {},
                                            {indexType, indexType}, {loc, loc});

    // Move the remaining initBlock ops to condBlock.
    Block *remainingBlock = rewriter.splitBlock(initBlock, remainingOpsIt);
    rewriter.mergeBlocks(remainingBlock, condBlock, ValueRange());

    rewriter.setInsertionPointToEnd(initBlock);
    rewriter.create<LLVM::BrOp>(loc, ValueRange({resultRankMinusOne, oneIndex}),
                                condBlock);
    rewriter.setInsertionPointToStart(condBlock);
    Value indexArg = condBlock->getArgument(0);
    Value strideArg = condBlock->getArgument(1);

    Value zeroIndex = createIndexConstant(rewriter, loc, 0);
    Value pred = rewriter.create<LLVM::ICmpOp>(
        loc, IntegerType::get(rewriter.getContext(), 1),
        LLVM::ICmpPredicate::sge, indexArg, zeroIndex);

    Block *bodyBlock =
        rewriter.splitBlock(condBlock, rewriter.getInsertionPoint());
    rewriter.setInsertionPointToStart(bodyBlock);

    // Copy size from shape to descriptor.
    Type llvmIndexPtrType = LLVM::LLVMPointerType::get(indexType);
    Value sizeLoadGep = rewriter.create<LLVM::GEPOp>(
        loc, llvmIndexPtrType, shapeOperandPtr, ValueRange{indexArg});
    Value size = rewriter.create<LLVM::LoadOp>(loc, sizeLoadGep);
    UnrankedMemRefDescriptor::setSize(rewriter, loc, *getTypeConverter(),
                                      targetSizesBase, indexArg, size);

    // Write stride value and compute next one.
    UnrankedMemRefDescriptor::setStride(rewriter, loc, *getTypeConverter(),
                                        targetStridesBase, indexArg, strideArg);
    Value nextStride = rewriter.create<LLVM::MulOp>(loc, strideArg, size);

    // Decrement loop counter and branch back.
    Value decrement = rewriter.create<LLVM::SubOp>(loc, indexArg, oneIndex);
    rewriter.create<LLVM::BrOp>(loc, ValueRange({decrement, nextStride}),
                                condBlock);

    Block *remainder =
        rewriter.splitBlock(bodyBlock, rewriter.getInsertionPoint());

    // Hook up the cond exit to the remainder.
    rewriter.setInsertionPointToEnd(condBlock);
    rewriter.create<LLVM::CondBrOp>(loc, pred, bodyBlock, llvm::None, remainder,
                                    llvm::None);

    // Reset position to beginning of new remainder block.
    rewriter.setInsertionPointToStart(remainder);

    *descriptor = targetDesc;
    return success();
  }
};

/// Helper function to convert a vector of `OpFoldResult`s into a vector of
/// `Value`s.
static SmallVector<Value> getAsValues(OpBuilder &b, Location loc,
                                      Type &llvmIndexType,
                                      ArrayRef<OpFoldResult> valueOrAttrVec) {
  return llvm::to_vector<4>(
      llvm::map_range(valueOrAttrVec, [&](OpFoldResult value) -> Value {
        if (auto attr = value.dyn_cast<Attribute>())
          return b.create<LLVM::ConstantOp>(loc, llvmIndexType, attr);
        return value.get<Value>();
      }));
}

/// Compute a map that for a given dimension of the expanded type gives the
/// dimension in the collapsed type it maps to. Essentially, it is the inverse
/// of the `reassociation` maps.
static DenseMap<int64_t, int64_t>
getExpandedDimToCollapsedDimMap(ArrayRef<ReassociationIndices> reassociation) {
  llvm::DenseMap<int64_t, int64_t> expandedDimToCollapsedDim;
  for (auto &en : enumerate(reassociation)) {
    for (auto dim : en.value())
      expandedDimToCollapsedDim[dim] = en.index();
  }
  return expandedDimToCollapsedDim;
}

static OpFoldResult
getExpandedOutputDimSize(OpBuilder &b, Location loc, Type &llvmIndexType,
                         int64_t outDimIndex, ArrayRef<int64_t> outStaticShape,
                         MemRefDescriptor &inDesc,
                         ArrayRef<int64_t> inStaticShape,
                         ArrayRef<ReassociationIndices> reassociation,
                         DenseMap<int64_t, int64_t> &outDimToInDimMap) {
  int64_t outDimSize = outStaticShape[outDimIndex];
  if (!ShapedType::isDynamic(outDimSize))
    return b.getIndexAttr(outDimSize);

  // Calculate the product of all the output dim sizes except the current dim.
  int64_t inDimIndex = outDimToInDimMap[outDimIndex];
  int64_t otherDimSizesMul = 1;
  for (auto otherDimIndex : reassociation[inDimIndex]) {
    if (otherDimIndex == static_cast<unsigned>(outDimIndex))
      continue;
    int64_t otherDimSize = outStaticShape[otherDimIndex];
    assert(!ShapedType::isDynamic(otherDimSize) &&
           "single dimension cannot be expanded into multiple dynamic "
           "dimensions");
    otherDimSizesMul *= otherDimSize;
  }

  // outDimSize = inDimSize / otherOutDimSizesMul
  int64_t inDimSize = inStaticShape[inDimIndex];
  Value inDimSizeDynamic =
      ShapedType::isDynamic(inDimSize)
          ? inDesc.size(b, loc, inDimIndex)
          : b.create<LLVM::ConstantOp>(loc, llvmIndexType,
                                       b.getIndexAttr(inDimSize));
  Value outDimSizeDynamic = b.create<LLVM::SDivOp>(
      loc, inDimSizeDynamic,
      b.create<LLVM::ConstantOp>(loc, llvmIndexType,
                                 b.getIndexAttr(otherDimSizesMul)));
  return outDimSizeDynamic;
}

static OpFoldResult getCollapsedOutputDimSize(
    OpBuilder &b, Location loc, Type &llvmIndexType, int64_t outDimIndex,
    int64_t outDimSize, ArrayRef<int64_t> inStaticShape,
    MemRefDescriptor &inDesc, ArrayRef<ReassociationIndices> reassociation) {
  if (!ShapedType::isDynamic(outDimSize))
    return b.getIndexAttr(outDimSize);

  Value c1 = b.create<LLVM::ConstantOp>(loc, llvmIndexType, b.getIndexAttr(1));
  Value outDimSizeDynamic = c1;
  for (auto inDimIndex : reassociation[outDimIndex]) {
    int64_t inDimSize = inStaticShape[inDimIndex];
    Value inDimSizeDynamic =
        ShapedType::isDynamic(inDimSize)
            ? inDesc.size(b, loc, inDimIndex)
            : b.create<LLVM::ConstantOp>(loc, llvmIndexType,
                                         b.getIndexAttr(inDimSize));
    outDimSizeDynamic =
        b.create<LLVM::MulOp>(loc, outDimSizeDynamic, inDimSizeDynamic);
  }
  return outDimSizeDynamic;
}

static SmallVector<OpFoldResult, 4>
getCollapsedOutputShape(OpBuilder &b, Location loc, Type &llvmIndexType,
                        ArrayRef<ReassociationIndices> reassociation,
                        ArrayRef<int64_t> inStaticShape,
                        MemRefDescriptor &inDesc,
                        ArrayRef<int64_t> outStaticShape) {
  return llvm::to_vector<4>(llvm::map_range(
      llvm::seq<int64_t>(0, outStaticShape.size()), [&](int64_t outDimIndex) {
        return getCollapsedOutputDimSize(b, loc, llvmIndexType, outDimIndex,
                                         outStaticShape[outDimIndex],
                                         inStaticShape, inDesc, reassociation);
      }));
}

static SmallVector<OpFoldResult, 4>
getExpandedOutputShape(OpBuilder &b, Location loc, Type &llvmIndexType,
                       ArrayRef<ReassociationIndices> reassociation,
                       ArrayRef<int64_t> inStaticShape,
                       MemRefDescriptor &inDesc,
                       ArrayRef<int64_t> outStaticShape) {
  DenseMap<int64_t, int64_t> outDimToInDimMap =
      getExpandedDimToCollapsedDimMap(reassociation);
  return llvm::to_vector<4>(llvm::map_range(
      llvm::seq<int64_t>(0, outStaticShape.size()), [&](int64_t outDimIndex) {
        return getExpandedOutputDimSize(b, loc, llvmIndexType, outDimIndex,
                                        outStaticShape, inDesc, inStaticShape,
                                        reassociation, outDimToInDimMap);
      }));
}

static SmallVector<Value>
getDynamicOutputShape(OpBuilder &b, Location loc, Type &llvmIndexType,
                      ArrayRef<ReassociationIndices> reassociation,
                      ArrayRef<int64_t> inStaticShape, MemRefDescriptor &inDesc,
                      ArrayRef<int64_t> outStaticShape) {
  return outStaticShape.size() < inStaticShape.size()
             ? getAsValues(b, loc, llvmIndexType,
                           getCollapsedOutputShape(b, loc, llvmIndexType,
                                                   reassociation, inStaticShape,
                                                   inDesc, outStaticShape))
             : getAsValues(b, loc, llvmIndexType,
                           getExpandedOutputShape(b, loc, llvmIndexType,
                                                  reassociation, inStaticShape,
                                                  inDesc, outStaticShape));
}

static void fillInStridesForExpandedMemDescriptor(
    OpBuilder &b, Location loc, MemRefType srcType, MemRefDescriptor &srcDesc,
    MemRefDescriptor &dstDesc, ArrayRef<ReassociationIndices> reassociation) {
  // See comments for computeExpandedLayoutMap for details on how the strides
  // are calculated.
  for (auto &en : llvm::enumerate(reassociation)) {
    auto currentStrideToExpand = srcDesc.stride(b, loc, en.index());
    for (auto dstIndex : llvm::reverse(en.value())) {
      dstDesc.setStride(b, loc, dstIndex, currentStrideToExpand);
      Value size = dstDesc.size(b, loc, dstIndex);
      currentStrideToExpand =
          b.create<LLVM::MulOp>(loc, size, currentStrideToExpand);
    }
  }
}

static void fillInStridesForCollapsedMemDescriptor(
    ConversionPatternRewriter &rewriter, Location loc, Operation *op,
    TypeConverter *typeConverter, MemRefType srcType, MemRefDescriptor &srcDesc,
    MemRefDescriptor &dstDesc, ArrayRef<ReassociationIndices> reassociation) {
  // See comments for computeCollapsedLayoutMap for details on how the strides
  // are calculated.
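  // In short: after dropping trailing unit dimensions from a reassociation
  // group, if the innermost remaining dimension has a static size (or the
  // group is a single dimension), the collapsed stride is simply that
  // dimension's stride. Otherwise the stride has to be chosen at runtime among
  // the group's dimensions whose size is not 1, which is what the block-based
  // control flow emitted in the `else` branch below does.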
  auto srcShape = srcType.getShape();
  for (auto &en : llvm::enumerate(reassociation)) {
    rewriter.setInsertionPoint(op);
    auto dstIndex = en.index();
    ArrayRef<int64_t> ref = llvm::makeArrayRef(en.value());
    while (srcShape[ref.back()] == 1 && ref.size() > 1)
      ref = ref.drop_back();
    if (!ShapedType::isDynamic(srcShape[ref.back()]) || ref.size() == 1) {
      dstDesc.setStride(rewriter, loc, dstIndex,
                        srcDesc.stride(rewriter, loc, ref.back()));
    } else {
      // Iterate over the source strides in reverse order. Skip over the
      // dimensions whose size is 1.
      // TODO: we should take the minimum stride in the reassociation group
      // instead of just the first where the dimension is not 1.
      //
      //   +------------------------------------------------------+
      //   | curEntry:                                             |
      //   |   %srcStride = strides[srcIndex]                      |
      //   |   %neOne = cmp sizes[srcIndex], 1                     +--+
      //   |   cf.cond_br %neOne, continue(%srcStride), nextEntry  |  |
      //   +--------------------------+---------------------------+  |
      //                              |                              |
      //                              v                              |
      //               +-----------------------------+              |
      //               | nextEntry:                   |              |
      //               |   ...                        +---+          |
      //               +--------------+---------------+   |          |
      //                              |                   |          |
      //                              v                   |          |
      //               +-----------------------------+    |          |
      //               | nextEntry:                   |   |          |
      //               |   ...                        |   |          |
      //               +--------------+---------------+   |  +-------+
      //                              |                   |  |
      //                              v                   v  v
      //        +---------------------------------------------------+
      //        | continue(%newStride):                             |
      //        |   %newMemRefDes = setStride(%newStride, dstIndex) |
      //        +---------------------------------------------------+
      OpBuilder::InsertionGuard guard(rewriter);
      Block *initBlock = rewriter.getInsertionBlock();
      Block *continueBlock =
          rewriter.splitBlock(initBlock, rewriter.getInsertionPoint());
      continueBlock->insertArgument(unsigned(0), srcDesc.getIndexType(), loc);
      rewriter.setInsertionPointToStart(continueBlock);
      dstDesc.setStride(rewriter, loc, dstIndex, continueBlock->getArgument(0));

      Block *curEntryBlock = initBlock;
      Block *nextEntryBlock;
      for (auto srcIndex : llvm::reverse(ref)) {
        if (srcShape[srcIndex] == 1 && srcIndex != ref.front())
          continue;
        rewriter.setInsertionPointToEnd(curEntryBlock);
        Value srcStride = srcDesc.stride(rewriter, loc, srcIndex);
        if (srcIndex == ref.front()) {
          rewriter.create<LLVM::BrOp>(loc, srcStride, continueBlock);
          break;
        }
        Value one = rewriter.create<LLVM::ConstantOp>(
            loc, typeConverter->convertType(rewriter.getI64Type()),
            rewriter.getI32IntegerAttr(1));
        Value predNeOne = rewriter.create<LLVM::ICmpOp>(
            loc, LLVM::ICmpPredicate::ne, srcDesc.size(rewriter, loc, srcIndex),
            one);
        {
          OpBuilder::InsertionGuard guard(rewriter);
          nextEntryBlock = rewriter.createBlock(
              initBlock->getParent(), Region::iterator(continueBlock), {});
        }
        rewriter.create<LLVM::CondBrOp>(loc, predNeOne, continueBlock,
                                        srcStride, nextEntryBlock, llvm::None);
        curEntryBlock = nextEntryBlock;
      }
    }
  }
}

static void fillInDynamicStridesForMemDescriptor(
    ConversionPatternRewriter &b, Location loc, Operation *op,
    TypeConverter *typeConverter, MemRefType srcType, MemRefType dstType,
    MemRefDescriptor &srcDesc, MemRefDescriptor &dstDesc,
    ArrayRef<ReassociationIndices> reassociation) {
  if (srcType.getRank() > dstType.getRank())
    fillInStridesForCollapsedMemDescriptor(b, loc, op, typeConverter, srcType,
                                           srcDesc, dstDesc, reassociation);
  else
    fillInStridesForExpandedMemDescriptor(b, loc, srcType, srcDesc, dstDesc,
                                          reassociation);
}

// ReshapeOp creates a new view descriptor of the proper rank.
// The target memref must have a strided layout (i.e. `getStridesAndOffset`
// must succeed on its type); both static and dynamic sizes and strides are
// supported.
template <typename ReshapeOp>
class ReassociatingReshapeOpConversion
    : public ConvertOpToLLVMPattern<ReshapeOp> {
public:
  using ConvertOpToLLVMPattern<ReshapeOp>::ConvertOpToLLVMPattern;
  using ReshapeOpAdaptor = typename ReshapeOp::Adaptor;

  LogicalResult
  matchAndRewrite(ReshapeOp reshapeOp, typename ReshapeOp::Adaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    MemRefType dstType = reshapeOp.getResultType();
    MemRefType srcType = reshapeOp.getSrcType();

    int64_t offset;
    SmallVector<int64_t, 4> strides;
    if (failed(getStridesAndOffset(dstType, strides, offset))) {
      return rewriter.notifyMatchFailure(
          reshapeOp, "failed to get stride and offset exprs");
    }

    MemRefDescriptor srcDesc(adaptor.getSrc());
    Location loc = reshapeOp->getLoc();
    auto dstDesc = MemRefDescriptor::undef(
        rewriter, loc, this->typeConverter->convertType(dstType));
    dstDesc.setAllocatedPtr(rewriter, loc, srcDesc.allocatedPtr(rewriter, loc));
    dstDesc.setAlignedPtr(rewriter, loc, srcDesc.alignedPtr(rewriter, loc));
    dstDesc.setOffset(rewriter, loc, srcDesc.offset(rewriter, loc));

    ArrayRef<int64_t> srcStaticShape = srcType.getShape();
    ArrayRef<int64_t> dstStaticShape = dstType.getShape();
    Type llvmIndexType =
        this->typeConverter->convertType(rewriter.getIndexType());
    SmallVector<Value> dstShape = getDynamicOutputShape(
        rewriter, loc, llvmIndexType, reshapeOp.getReassociationIndices(),
        srcStaticShape, srcDesc, dstStaticShape);
    for (auto &en : llvm::enumerate(dstShape))
      dstDesc.setSize(rewriter, loc, en.index(), en.value());

    if (llvm::all_of(strides, isStaticStrideOrOffset)) {
      for (auto &en : llvm::enumerate(strides))
        dstDesc.setConstantStride(rewriter, loc, en.index(), en.value());
    } else if (srcType.getLayout().isIdentity() &&
               dstType.getLayout().isIdentity()) {
      Value c1 = rewriter.create<LLVM::ConstantOp>(loc, llvmIndexType,
                                                   rewriter.getIndexAttr(1));
      Value stride = c1;
      for (auto dimIndex :
           llvm::reverse(llvm::seq<int64_t>(0, dstShape.size()))) {
        dstDesc.setStride(rewriter, loc, dimIndex, stride);
        stride = rewriter.create<LLVM::MulOp>(loc, dstShape[dimIndex], stride);
      }
    } else {
      // There could be mixed static/dynamic strides. For simplicity, we
      // recompute all strides if there is at least one dynamic stride.
      fillInDynamicStridesForMemDescriptor(
          rewriter, loc, reshapeOp, this->typeConverter, srcType, dstType,
          srcDesc, dstDesc, reshapeOp.getReassociationIndices());
    }
    rewriter.replaceOp(reshapeOp, {dstDesc});
    return success();
  }
};

/// Conversion pattern that transforms a subview op into:
/// 1. An `llvm.mlir.undef` operation to create a memref descriptor
/// 2. Updates to the descriptor to introduce the data ptr, offset, size
///    and stride.
/// The subview op is replaced by the descriptor.
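///
/// As an illustration (not taken from a test), for a 2-D subview with dynamic
/// offsets [%o0, %o1], sizes [%s0, %s1] and strides [%st0, %st1], the new
/// descriptor holds:
///   offset  = src_offset + %o0 * src_stride0 + %o1 * src_stride1
///   sizes   = [%s0, %s1]
///   strides = [%st0 * src_stride0, %st1 * src_stride1]
/// Statically known values are materialized as constants, and dimensions
/// dropped by a rank-reducing subview are skipped.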
struct SubViewOpLowering : public ConvertOpToLLVMPattern<memref::SubViewOp> {
  using ConvertOpToLLVMPattern<memref::SubViewOp>::ConvertOpToLLVMPattern;

  LogicalResult
  matchAndRewrite(memref::SubViewOp subViewOp, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    auto loc = subViewOp.getLoc();

    auto sourceMemRefType = subViewOp.getSource().getType().cast<MemRefType>();
    auto sourceElementTy =
        typeConverter->convertType(sourceMemRefType.getElementType());

    auto viewMemRefType = subViewOp.getType();
    auto inferredType =
        memref::SubViewOp::inferResultType(
            subViewOp.getSourceType(),
            extractFromI64ArrayAttr(subViewOp.getStaticOffsets()),
            extractFromI64ArrayAttr(subViewOp.getStaticSizes()),
            extractFromI64ArrayAttr(subViewOp.getStaticStrides()))
            .cast<MemRefType>();
    auto targetElementTy =
        typeConverter->convertType(viewMemRefType.getElementType());
    auto targetDescTy = typeConverter->convertType(viewMemRefType);
    if (!sourceElementTy || !targetDescTy || !targetElementTy ||
        !LLVM::isCompatibleType(sourceElementTy) ||
        !LLVM::isCompatibleType(targetElementTy) ||
        !LLVM::isCompatibleType(targetDescTy))
      return failure();

    // Extract the offset and strides from the type.
    int64_t offset;
    SmallVector<int64_t, 4> strides;
    auto successStrides = getStridesAndOffset(inferredType, strides, offset);
    if (failed(successStrides))
      return failure();

    // Create the descriptor.
    if (!LLVM::isCompatibleType(adaptor.getOperands().front().getType()))
      return failure();
    MemRefDescriptor sourceMemRef(adaptor.getOperands().front());
    auto targetMemRef = MemRefDescriptor::undef(rewriter, loc, targetDescTy);

    // Copy the buffer pointer from the old descriptor to the new one.
    Value extracted = sourceMemRef.allocatedPtr(rewriter, loc);
    Value bitcastPtr = rewriter.create<LLVM::BitcastOp>(
        loc,
        LLVM::LLVMPointerType::get(targetElementTy,
                                   viewMemRefType.getMemorySpaceAsInt()),
        extracted);
    targetMemRef.setAllocatedPtr(rewriter, loc, bitcastPtr);

    // Copy the aligned pointer from the old descriptor to the new one.
    extracted = sourceMemRef.alignedPtr(rewriter, loc);
    bitcastPtr = rewriter.create<LLVM::BitcastOp>(
        loc,
        LLVM::LLVMPointerType::get(targetElementTy,
                                   viewMemRefType.getMemorySpaceAsInt()),
        extracted);
    targetMemRef.setAlignedPtr(rewriter, loc, bitcastPtr);

    size_t inferredShapeRank = inferredType.getRank();
    size_t resultShapeRank = viewMemRefType.getRank();

    // Extract strides needed to compute offset.
    SmallVector<Value, 4> strideValues;
    strideValues.reserve(inferredShapeRank);
    for (unsigned i = 0; i < inferredShapeRank; ++i)
      strideValues.push_back(sourceMemRef.stride(rewriter, loc, i));

    // Offset.
    auto llvmIndexType = typeConverter->convertType(rewriter.getIndexType());
    if (!ShapedType::isDynamicStrideOrOffset(offset)) {
      targetMemRef.setConstantOffset(rewriter, loc, offset);
    } else {
      Value baseOffset = sourceMemRef.offset(rewriter, loc);
      // `inferredShapeRank` may be larger than the number of offset operands
      // because of trailing semantics. In this case, the offset is guaranteed
      // to be interpreted as 0 and we can just skip the extra dimensions.
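      // The dynamic offset is accumulated as
      //   baseOffset += offset[i] * stride[i]
      // for every non-trailing dimension, where offset[i] is either the
      // dynamic operand or a constant built from the static offset attribute.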
      for (unsigned i = 0, e = std::min(inferredShapeRank,
                                        subViewOp.getMixedOffsets().size());
           i < e; ++i) {
        Value offset =
            // TODO: need OpFoldResult ODS adaptor to clean this up.
            subViewOp.isDynamicOffset(i)
                ? adaptor.getOperands()[subViewOp.getIndexOfDynamicOffset(i)]
                : rewriter.create<LLVM::ConstantOp>(
                      loc, llvmIndexType,
                      rewriter.getI64IntegerAttr(subViewOp.getStaticOffset(i)));
        Value mul = rewriter.create<LLVM::MulOp>(loc, offset, strideValues[i]);
        baseOffset = rewriter.create<LLVM::AddOp>(loc, baseOffset, mul);
      }
      targetMemRef.setOffset(rewriter, loc, baseOffset);
    }

    // Update sizes and strides.
    SmallVector<OpFoldResult> mixedSizes = subViewOp.getMixedSizes();
    SmallVector<OpFoldResult> mixedStrides = subViewOp.getMixedStrides();
    assert(mixedSizes.size() == mixedStrides.size() &&
           "expected sizes and strides of equal length");
    llvm::SmallBitVector unusedDims = subViewOp.getDroppedDims();
    for (int i = inferredShapeRank - 1, j = resultShapeRank - 1;
         i >= 0 && j >= 0; --i) {
      if (unusedDims.test(i))
        continue;

      // `i` may exceed the number of entries in subViewOp.getMixedSizes()
      // because of trailing semantics. In that case, the size is simply the
      // corresponding source dimension size (a `memref.dim` when dynamic) and
      // the stride is 1.
      Value size, stride;
      if (static_cast<unsigned>(i) >= mixedSizes.size()) {
        // If the static size is available, use it directly. This is similar to
        // the folding of dim(constant-op) but removes the need for dim to be
        // aware of LLVM constants and for this pass to be aware of std
        // constants.
        int64_t staticSize =
            subViewOp.getSource().getType().cast<MemRefType>().getShape()[i];
        if (staticSize != ShapedType::kDynamicSize) {
          size = rewriter.create<LLVM::ConstantOp>(
              loc, llvmIndexType, rewriter.getI64IntegerAttr(staticSize));
        } else {
          Value pos = rewriter.create<LLVM::ConstantOp>(
              loc, llvmIndexType, rewriter.getI64IntegerAttr(i));
          Value dim =
              rewriter.create<memref::DimOp>(loc, subViewOp.getSource(), pos);
          auto cast = rewriter.create<UnrealizedConversionCastOp>(
              loc, llvmIndexType, dim);
          size = cast.getResult(0);
        }
        stride = rewriter.create<LLVM::ConstantOp>(
            loc, llvmIndexType, rewriter.getI64IntegerAttr(1));
      } else {
        // TODO: need OpFoldResult ODS adaptor to clean this up.
        size =
            subViewOp.isDynamicSize(i)
                ? adaptor.getOperands()[subViewOp.getIndexOfDynamicSize(i)]
                : rewriter.create<LLVM::ConstantOp>(
                      loc, llvmIndexType,
                      rewriter.getI64IntegerAttr(subViewOp.getStaticSize(i)));
        if (!ShapedType::isDynamicStrideOrOffset(strides[i])) {
          stride = rewriter.create<LLVM::ConstantOp>(
              loc, llvmIndexType, rewriter.getI64IntegerAttr(strides[i]));
        } else {
          stride =
              subViewOp.isDynamicStride(i)
                  ? adaptor.getOperands()[subViewOp.getIndexOfDynamicStride(i)]
                  : rewriter.create<LLVM::ConstantOp>(
                        loc, llvmIndexType,
                        rewriter.getI64IntegerAttr(
                            subViewOp.getStaticStride(i)));
          stride = rewriter.create<LLVM::MulOp>(loc, stride, strideValues[i]);
        }
      }
      targetMemRef.setSize(rewriter, loc, j, size);
      targetMemRef.setStride(rewriter, loc, j, stride);
      j--;
    }

    rewriter.replaceOp(subViewOp, {targetMemRef});
    return success();
  }
};

/// Conversion pattern that transforms a transpose op into:
/// 1. An `llvm.mlir.undef` operation to create a memref descriptor
/// 2. Updates to the descriptor to introduce the data ptr, offset, size
///    and stride; sizes and strides are permutations of the original values.
/// The transpose op is replaced by the descriptor.
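///
/// For illustration only: with a permutation map (d0, d1, d2) -> (d1, d2, d0),
/// the loop below reads size/stride `i` from the source descriptor (the i-th
/// result of the map) and writes them to the target position named by that
/// result, i.e. source dim 0 goes to target dim 1, source dim 1 to target
/// dim 2, and source dim 2 to target dim 0.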
class TransposeOpLowering : public ConvertOpToLLVMPattern<memref::TransposeOp> {
public:
  using ConvertOpToLLVMPattern<memref::TransposeOp>::ConvertOpToLLVMPattern;

  LogicalResult
  matchAndRewrite(memref::TransposeOp transposeOp, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    auto loc = transposeOp.getLoc();
    MemRefDescriptor viewMemRef(adaptor.getIn());

    // No permutation, early exit.
    if (transposeOp.getPermutation().isIdentity())
      return rewriter.replaceOp(transposeOp, {viewMemRef}), success();

    auto targetMemRef = MemRefDescriptor::undef(
        rewriter, loc, typeConverter->convertType(transposeOp.getShapedType()));

    // Copy the base and aligned pointers from the old descriptor to the new
    // one.
    targetMemRef.setAllocatedPtr(rewriter, loc,
                                 viewMemRef.allocatedPtr(rewriter, loc));
    targetMemRef.setAlignedPtr(rewriter, loc,
                               viewMemRef.alignedPtr(rewriter, loc));

    // Copy the offset from the old descriptor to the new one.
    targetMemRef.setOffset(rewriter, loc, viewMemRef.offset(rewriter, loc));

    // Iterate over the dimensions and apply size/stride permutation.
    for (const auto &en :
         llvm::enumerate(transposeOp.getPermutation().getResults())) {
      int sourcePos = en.index();
      int targetPos = en.value().cast<AffineDimExpr>().getPosition();
      targetMemRef.setSize(rewriter, loc, targetPos,
                           viewMemRef.size(rewriter, loc, sourcePos));
      targetMemRef.setStride(rewriter, loc, targetPos,
                             viewMemRef.stride(rewriter, loc, sourcePos));
    }

    rewriter.replaceOp(transposeOp, {targetMemRef});
    return success();
  }
};

/// Conversion pattern that transforms a view op into:
/// 1. An `llvm.mlir.undef` operation to create a memref descriptor
/// 2. Updates to the descriptor to introduce the data ptr, offset, size
///    and stride.
/// The view op is replaced by the descriptor.
struct ViewOpLowering : public ConvertOpToLLVMPattern<memref::ViewOp> {
  using ConvertOpToLLVMPattern<memref::ViewOp>::ConvertOpToLLVMPattern;

  // Build and return the value for the idx^th shape dimension, either by
  // returning the constant shape dimension or by looking up the corresponding
  // dynamic size (counting how many dynamic dimensions precede it).
  Value getSize(ConversionPatternRewriter &rewriter, Location loc,
                ArrayRef<int64_t> shape, ValueRange dynamicSizes,
                unsigned idx) const {
    assert(idx < shape.size());
    if (!ShapedType::isDynamic(shape[idx]))
      return createIndexConstant(rewriter, loc, shape[idx]);
    // Count the number of dynamic dims in range [0, idx].
    unsigned nDynamic = llvm::count_if(shape.take_front(idx), [](int64_t v) {
      return ShapedType::isDynamic(v);
    });
    return dynamicSizes[nDynamic];
  }

  // Build and return the idx^th stride, either by returning the constant
  // stride or by computing the dynamic stride from the current
  // `runningStride` and `nextSize`.
  // The caller should keep a running stride and update it with the result
  // returned by this function.
  Value getStride(ConversionPatternRewriter &rewriter, Location loc,
                  ArrayRef<int64_t> strides, Value nextSize,
                  Value runningStride, unsigned idx) const {
    assert(idx < strides.size());
    if (!ShapedType::isDynamicStrideOrOffset(strides[idx]))
      return createIndexConstant(rewriter, loc, strides[idx]);
    if (nextSize)
      return runningStride
                 ? rewriter.create<LLVM::MulOp>(loc, runningStride, nextSize)
                 : nextSize;
    assert(!runningStride);
    return createIndexConstant(rewriter, loc, 1);
  }

  LogicalResult
  matchAndRewrite(memref::ViewOp viewOp, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    auto loc = viewOp.getLoc();

    auto viewMemRefType = viewOp.getType();
    auto targetElementTy =
        typeConverter->convertType(viewMemRefType.getElementType());
    auto targetDescTy = typeConverter->convertType(viewMemRefType);
    if (!targetDescTy || !targetElementTy ||
        !LLVM::isCompatibleType(targetElementTy) ||
        !LLVM::isCompatibleType(targetDescTy))
      return viewOp.emitWarning("Target descriptor type not converted to LLVM"),
             failure();

    int64_t offset;
    SmallVector<int64_t, 4> strides;
    auto successStrides = getStridesAndOffset(viewMemRefType, strides, offset);
    if (failed(successStrides))
      return viewOp.emitWarning("cannot cast to non-strided shape"), failure();
    assert(offset == 0 && "expected offset to be 0");

    // Target memref must be contiguous in memory (innermost stride is 1), or
    // empty (special case when at least one of the memref dimensions is 0).
    if (!strides.empty() && (strides.back() != 1 && strides.back() != 0))
      return viewOp.emitWarning("cannot cast to non-contiguous shape"),
             failure();

    // Create the descriptor.
    MemRefDescriptor sourceMemRef(adaptor.getSource());
    auto targetMemRef = MemRefDescriptor::undef(rewriter, loc, targetDescTy);

    // Field 1: Copy the allocated pointer, used for malloc/free.
    Value allocatedPtr = sourceMemRef.allocatedPtr(rewriter, loc);
    auto srcMemRefType = viewOp.getSource().getType().cast<MemRefType>();
    Value bitcastPtr = rewriter.create<LLVM::BitcastOp>(
        loc,
        LLVM::LLVMPointerType::get(targetElementTy,
                                   srcMemRefType.getMemorySpaceAsInt()),
        allocatedPtr);
    targetMemRef.setAllocatedPtr(rewriter, loc, bitcastPtr);

    // Field 2: Copy the actual aligned pointer to payload.
    Value alignedPtr = sourceMemRef.alignedPtr(rewriter, loc);
    alignedPtr = rewriter.create<LLVM::GEPOp>(
        loc, alignedPtr.getType(), alignedPtr, adaptor.getByteShift());
    bitcastPtr = rewriter.create<LLVM::BitcastOp>(
        loc,
        LLVM::LLVMPointerType::get(targetElementTy,
                                   srcMemRefType.getMemorySpaceAsInt()),
        alignedPtr);
    targetMemRef.setAlignedPtr(rewriter, loc, bitcastPtr);

    // Field 3: The offset in the resulting type must be 0. This is because of
    // the type change: an offset on srcType* may not be expressible as an
    // offset on dstType*.
    targetMemRef.setOffset(rewriter, loc,
                           createIndexConstant(rewriter, loc, offset));

    // Early exit for 0-D corner case.
    if (viewMemRefType.getRank() == 0)
      return rewriter.replaceOp(viewOp, {targetMemRef}), success();

    // Fields 4 and 5: Update sizes and strides.
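    // Strides are built from the innermost dimension outwards: the innermost
    // stride is 1 (the target was checked above to be contiguous), and each
    // enclosing stride is the previously computed stride multiplied by the
    // inner dimension's size, unless the stride is statically known.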
    Value stride = nullptr, nextSize = nullptr;
    for (int i = viewMemRefType.getRank() - 1; i >= 0; --i) {
      // Update size.
      Value size = getSize(rewriter, loc, viewMemRefType.getShape(),
                           adaptor.getSizes(), i);
      targetMemRef.setSize(rewriter, loc, i, size);
      // Update stride.
      stride = getStride(rewriter, loc, strides, nextSize, stride, i);
      targetMemRef.setStride(rewriter, loc, i, stride);
      nextSize = size;
    }

    rewriter.replaceOp(viewOp, {targetMemRef});
    return success();
  }
};

//===----------------------------------------------------------------------===//
// AtomicRMWOpLowering
//===----------------------------------------------------------------------===//

/// Try to match the kind of a memref.atomic_rmw to determine whether to use a
/// lowering to llvm.atomicrmw or fall back to llvm.cmpxchg.
static Optional<LLVM::AtomicBinOp>
matchSimpleAtomicOp(memref::AtomicRMWOp atomicOp) {
  switch (atomicOp.getKind()) {
  case arith::AtomicRMWKind::addf:
    return LLVM::AtomicBinOp::fadd;
  case arith::AtomicRMWKind::addi:
    return LLVM::AtomicBinOp::add;
  case arith::AtomicRMWKind::assign:
    return LLVM::AtomicBinOp::xchg;
  case arith::AtomicRMWKind::maxs:
    return LLVM::AtomicBinOp::max;
  case arith::AtomicRMWKind::maxu:
    return LLVM::AtomicBinOp::umax;
  case arith::AtomicRMWKind::mins:
    return LLVM::AtomicBinOp::min;
  case arith::AtomicRMWKind::minu:
    return LLVM::AtomicBinOp::umin;
  case arith::AtomicRMWKind::ori:
    return LLVM::AtomicBinOp::_or;
  case arith::AtomicRMWKind::andi:
    return LLVM::AtomicBinOp::_and;
  default:
    return llvm::None;
  }
  llvm_unreachable("Invalid AtomicRMWKind");
}

struct AtomicRMWOpLowering : public LoadStoreOpLowering<memref::AtomicRMWOp> {
  using Base::Base;

  LogicalResult
  matchAndRewrite(memref::AtomicRMWOp atomicOp, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    if (failed(match(atomicOp)))
      return failure();
    auto maybeKind = matchSimpleAtomicOp(atomicOp);
    if (!maybeKind)
      return failure();
    auto resultType = adaptor.getValue().getType();
    auto memRefType = atomicOp.getMemRefType();
    auto dataPtr =
        getStridedElementPtr(atomicOp.getLoc(), memRefType, adaptor.getMemref(),
                             adaptor.getIndices(), rewriter);
    rewriter.replaceOpWithNewOp<LLVM::AtomicRMWOp>(
        atomicOp, resultType, *maybeKind, dataPtr, adaptor.getValue(),
        LLVM::AtomicOrdering::acq_rel);
    return success();
  }
};

} // namespace

void mlir::populateMemRefToLLVMConversionPatterns(LLVMTypeConverter &converter,
                                                  RewritePatternSet &patterns) {
  // clang-format off
  patterns.add<
      AllocaOpLowering,
      AllocaScopeOpLowering,
      AtomicRMWOpLowering,
      AssumeAlignmentOpLowering,
      DimOpLowering,
      GenericAtomicRMWOpLowering,
      GlobalMemrefOpLowering,
      GetGlobalMemrefOpLowering,
      LoadOpLowering,
      MemRefCastOpLowering,
      MemRefCopyOpLowering,
      MemRefReinterpretCastOpLowering,
      MemRefReshapeOpLowering,
      PrefetchOpLowering,
      RankOpLowering,
      ReassociatingReshapeOpConversion<memref::ExpandShapeOp>,
      ReassociatingReshapeOpConversion<memref::CollapseShapeOp>,
      StoreOpLowering,
      SubViewOpLowering,
      TransposeOpLowering,
      ViewOpLowering>(converter);
  // clang-format on
  auto allocLowering = converter.getOptions().allocLowering;
  if (allocLowering == LowerToLLVMOptions::AllocLowering::AlignedAlloc)
    patterns.add<AlignedAllocOpLowering, DeallocOpLowering>(converter);
  else if (allocLowering == LowerToLLVMOptions::AllocLowering::Malloc)
    patterns.add<AllocOpLowering, DeallocOpLowering>(converter);
}

namespace {
struct MemRefToLLVMPass : public ConvertMemRefToLLVMBase<MemRefToLLVMPass> {
  MemRefToLLVMPass() = default;

  void runOnOperation() override {
    Operation *op = getOperation();
    const auto &dataLayoutAnalysis = getAnalysis<DataLayoutAnalysis>();
    LowerToLLVMOptions options(&getContext(),
                               dataLayoutAnalysis.getAtOrAbove(op));
    options.allocLowering =
        (useAlignedAlloc ? LowerToLLVMOptions::AllocLowering::AlignedAlloc
                         : LowerToLLVMOptions::AllocLowering::Malloc);
    if (indexBitwidth != kDeriveIndexBitwidthFromDataLayout)
      options.overrideIndexBitwidth(indexBitwidth);

    LLVMTypeConverter typeConverter(&getContext(), options,
                                    &dataLayoutAnalysis);
    RewritePatternSet patterns(&getContext());
    populateMemRefToLLVMConversionPatterns(typeConverter, patterns);
    LLVMConversionTarget target(getContext());
    target.addLegalOp<func::FuncOp>();
    if (failed(applyPartialConversion(op, target, std::move(patterns))))
      signalPassFailure();
  }
};
} // namespace

std::unique_ptr<Pass> mlir::createMemRefToLLVMPass() {
  return std::make_unique<MemRefToLLVMPass>();
}
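// Illustrative usage only (not part of this file): the pass is typically
// scheduled from C++ as
//
//   mlir::PassManager pm(&context);
//   pm.addPass(mlir::createMemRefToLLVMPass());
//   (void)pm.run(module);
//
// or driven from `mlir-opt` through the command-line name registered for
// ConvertMemRefToLLVMBase (e.g. `-convert-memref-to-llvm`).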