//===- MemRefToLLVM.cpp - MemRef to LLVM dialect conversion ---------------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//

#include "mlir/Conversion/MemRefToLLVM/MemRefToLLVM.h"
#include "../PassDetail.h"
#include "mlir/Analysis/DataLayoutAnalysis.h"
#include "mlir/Conversion/LLVMCommon/ConversionTarget.h"
#include "mlir/Conversion/LLVMCommon/Pattern.h"
#include "mlir/Conversion/LLVMCommon/TypeConverter.h"
#include "mlir/Conversion/MemRefToLLVM/AllocLikeConversion.h"
#include "mlir/Dialect/Arithmetic/IR/Arithmetic.h"
#include "mlir/Dialect/Func/IR/FuncOps.h"
#include "mlir/Dialect/LLVMIR/FunctionCallUtils.h"
#include "mlir/Dialect/LLVMIR/LLVMDialect.h"
#include "mlir/Dialect/MemRef/IR/MemRef.h"
#include "mlir/IR/AffineMap.h"
#include "mlir/IR/BlockAndValueMapping.h"
#include "llvm/ADT/SmallBitVector.h"

using namespace mlir;

namespace {

bool isStaticStrideOrOffset(int64_t strideOrOffset) {
  return !ShapedType::isDynamicStrideOrOffset(strideOrOffset);
}

struct AllocOpLowering : public AllocLikeOpLLVMLowering {
  AllocOpLowering(LLVMTypeConverter &converter)
      : AllocLikeOpLLVMLowering(memref::AllocOp::getOperationName(),
                                converter) {}

  std::tuple<Value, Value> allocateBuffer(ConversionPatternRewriter &rewriter,
                                          Location loc, Value sizeBytes,
                                          Operation *op) const override {
    // Heap allocations.
    memref::AllocOp allocOp = cast<memref::AllocOp>(op);
    MemRefType memRefType = allocOp.getType();

    Value alignment;
    if (auto alignmentAttr = allocOp.alignment()) {
      alignment = createIndexConstant(rewriter, loc, *alignmentAttr);
    } else if (!memRefType.getElementType().isSignlessIntOrIndexOrFloat()) {
      // In the case where no alignment is specified, we may want to override
      // `malloc's` behavior. `malloc` typically aligns at the size of the
      // biggest scalar on a target HW. For non-scalars, use the natural
      // alignment of the LLVM type given by the LLVM DataLayout.
      alignment = getSizeInBytes(loc, memRefType.getElementType(), rewriter);
    }

    if (alignment) {
      // Adjust the allocation size to consider alignment.
      sizeBytes = rewriter.create<LLVM::AddOp>(loc, sizeBytes, alignment);
    }

    // Allocate the underlying buffer and store a pointer to it in the MemRef
    // descriptor.
    Type elementPtrType = this->getElementPtrType(memRefType);
    auto allocFuncOp = LLVM::lookupOrCreateMallocFn(
        allocOp->getParentOfType<ModuleOp>(), getIndexType());
    auto results = createLLVMCall(rewriter, loc, allocFuncOp, {sizeBytes},
                                  getVoidPtrType());
    Value allocatedPtr =
        rewriter.create<LLVM::BitcastOp>(loc, elementPtrType, results[0]);

    Value alignedPtr = allocatedPtr;
    if (alignment) {
      // Compute the aligned type pointer.
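      // Rounding sketch (illustrative, not emitted verbatim): `createAligned`
      // below rounds the allocated address up to the next multiple of
      // `alignment`, conceptually:
      //   bumped  = allocatedInt + alignment - 1
      //   aligned = bumped - (bumped % alignment)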
      Value allocatedInt =
          rewriter.create<LLVM::PtrToIntOp>(loc, getIndexType(), allocatedPtr);
      Value alignmentInt =
          createAligned(rewriter, loc, allocatedInt, alignment);
      alignedPtr =
          rewriter.create<LLVM::IntToPtrOp>(loc, elementPtrType, alignmentInt);
    }

    return std::make_tuple(allocatedPtr, alignedPtr);
  }
};

struct AlignedAllocOpLowering : public AllocLikeOpLLVMLowering {
  AlignedAllocOpLowering(LLVMTypeConverter &converter)
      : AllocLikeOpLLVMLowering(memref::AllocOp::getOperationName(),
                                converter) {}

  /// Returns the memref's element size in bytes using the data layout active
  /// at `op`.
  // TODO: there are other places where this is used. Expose publicly?
  unsigned getMemRefEltSizeInBytes(MemRefType memRefType, Operation *op) const {
    const DataLayout *layout = &defaultLayout;
    if (const DataLayoutAnalysis *analysis =
            getTypeConverter()->getDataLayoutAnalysis()) {
      layout = &analysis->getAbove(op);
    }
    Type elementType = memRefType.getElementType();
    if (auto memRefElementType = elementType.dyn_cast<MemRefType>())
      return getTypeConverter()->getMemRefDescriptorSize(memRefElementType,
                                                         *layout);
    if (auto memRefElementType = elementType.dyn_cast<UnrankedMemRefType>())
      return getTypeConverter()->getUnrankedMemRefDescriptorSize(
          memRefElementType, *layout);
    return layout->getTypeSize(elementType);
  }

  /// Returns true if the memref size in bytes is known to be a multiple of
  /// `factor`, assuming the data layout active at `op`.
  bool isMemRefSizeMultipleOf(MemRefType type, uint64_t factor,
                              Operation *op) const {
    uint64_t sizeDivisor = getMemRefEltSizeInBytes(type, op);
    for (unsigned i = 0, e = type.getRank(); i < e; i++) {
      if (ShapedType::isDynamic(type.getDimSize(i)))
        continue;
      sizeDivisor = sizeDivisor * type.getDimSize(i);
    }
    return sizeDivisor % factor == 0;
  }

  /// Returns the alignment to be used for the allocation call itself.
  /// aligned_alloc requires the alignment to be a power of two and the
  /// allocation size to be a multiple of that alignment.
  int64_t getAllocationAlignment(memref::AllocOp allocOp) const {
    if (Optional<uint64_t> alignment = allocOp.alignment())
      return *alignment;

    // Whenever we don't have alignment set, we will use an alignment
    // consistent with the element type; since the alignment has to be a power
    // of two, we will bump to the next power of two if it isn't already.
    auto eltSizeBytes = getMemRefEltSizeInBytes(allocOp.getType(), allocOp);
    return std::max(kMinAlignedAllocAlignment,
                    llvm::PowerOf2Ceil(eltSizeBytes));
  }

  std::tuple<Value, Value> allocateBuffer(ConversionPatternRewriter &rewriter,
                                          Location loc, Value sizeBytes,
                                          Operation *op) const override {
    // Heap allocations.
    memref::AllocOp allocOp = cast<memref::AllocOp>(op);
    MemRefType memRefType = allocOp.getType();
    int64_t alignment = getAllocationAlignment(allocOp);
    Value allocAlignment = createIndexConstant(rewriter, loc, alignment);

    // aligned_alloc requires size to be a multiple of alignment; we will pad
    // the size to the next multiple if necessary.
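    // Worked example (illustrative): for a memref<6xf32> (24 bytes) with a
    // 16-byte alignment, the size passed to aligned_alloc is padded from 24
    // up to 32 bytes by `createAligned` below.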
    if (!isMemRefSizeMultipleOf(memRefType, alignment, op))
      sizeBytes = createAligned(rewriter, loc, sizeBytes, allocAlignment);

    Type elementPtrType = this->getElementPtrType(memRefType);
    auto allocFuncOp = LLVM::lookupOrCreateAlignedAllocFn(
        allocOp->getParentOfType<ModuleOp>(), getIndexType());
    auto results =
        createLLVMCall(rewriter, loc, allocFuncOp, {allocAlignment, sizeBytes},
                       getVoidPtrType());
    Value allocatedPtr =
        rewriter.create<LLVM::BitcastOp>(loc, elementPtrType, results[0]);

    return std::make_tuple(allocatedPtr, allocatedPtr);
  }

  /// The minimum alignment to use with aligned_alloc (has to be a power of 2).
  static constexpr uint64_t kMinAlignedAllocAlignment = 16UL;

  /// Default layout to use in absence of the corresponding analysis.
  DataLayout defaultLayout;
};

// Out-of-line definition, required until C++17.
constexpr uint64_t AlignedAllocOpLowering::kMinAlignedAllocAlignment;

struct AllocaOpLowering : public AllocLikeOpLLVMLowering {
  AllocaOpLowering(LLVMTypeConverter &converter)
      : AllocLikeOpLLVMLowering(memref::AllocaOp::getOperationName(),
                                converter) {}

  /// Allocates the underlying buffer using an llvm.alloca operation, which
  /// directly yields a pointer to the element type; the allocated and aligned
  /// pointers returned are therefore the same.
  std::tuple<Value, Value> allocateBuffer(ConversionPatternRewriter &rewriter,
                                          Location loc, Value sizeBytes,
                                          Operation *op) const override {

    // With alloca, one gets a pointer to the element type right away; this is
    // used for stack allocations.
    auto allocaOp = cast<memref::AllocaOp>(op);
    auto elementPtrType = this->getElementPtrType(allocaOp.getType());

    auto allocatedElementPtr = rewriter.create<LLVM::AllocaOp>(
        loc, elementPtrType, sizeBytes,
        allocaOp.alignment() ? *allocaOp.alignment() : 0);

    return std::make_tuple(allocatedElementPtr, allocatedElementPtr);
  }
};

struct AllocaScopeOpLowering
    : public ConvertOpToLLVMPattern<memref::AllocaScopeOp> {
  using ConvertOpToLLVMPattern<memref::AllocaScopeOp>::ConvertOpToLLVMPattern;

  LogicalResult
  matchAndRewrite(memref::AllocaScopeOp allocaScopeOp, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    OpBuilder::InsertionGuard guard(rewriter);
    Location loc = allocaScopeOp.getLoc();

    // Split the current block before the AllocaScopeOp to create the inlining
    // point.
    auto *currentBlock = rewriter.getInsertionBlock();
    auto *remainingOpsBlock =
        rewriter.splitBlock(currentBlock, rewriter.getInsertionPoint());
    Block *continueBlock;
    if (allocaScopeOp.getNumResults() == 0) {
      continueBlock = remainingOpsBlock;
    } else {
      continueBlock = rewriter.createBlock(
          remainingOpsBlock, allocaScopeOp.getResultTypes(),
          SmallVector<Location>(allocaScopeOp->getNumResults(),
                                allocaScopeOp.getLoc()));
      rewriter.create<LLVM::BrOp>(loc, ValueRange(), remainingOpsBlock);
    }

    // Inline body region.
    Block *beforeBody = &allocaScopeOp.bodyRegion().front();
    Block *afterBody = &allocaScopeOp.bodyRegion().back();
    rewriter.inlineRegionBefore(allocaScopeOp.bodyRegion(), continueBlock);

    // Save stack and then branch into the body of the region.
    rewriter.setInsertionPointToEnd(currentBlock);
    auto stackSaveOp =
        rewriter.create<LLVM::StackSaveOp>(loc, getVoidPtrType());
    rewriter.create<LLVM::BrOp>(loc, ValueRange(), beforeBody);

    // Replace the alloca_scope return with a branch that jumps out of the
    // body. Stack restore before leaving the body region.
    rewriter.setInsertionPointToEnd(afterBody);
    auto returnOp =
        cast<memref::AllocaScopeReturnOp>(afterBody->getTerminator());
    auto branchOp = rewriter.replaceOpWithNewOp<LLVM::BrOp>(
        returnOp, returnOp.results(), continueBlock);

    // Insert stack restore before jumping out of the body of the region.
    rewriter.setInsertionPoint(branchOp);
    rewriter.create<LLVM::StackRestoreOp>(loc, stackSaveOp);

    // Replace the op with the values returned from the body region.
    rewriter.replaceOp(allocaScopeOp, continueBlock->getArguments());

    return success();
  }
};

struct AssumeAlignmentOpLowering
    : public ConvertOpToLLVMPattern<memref::AssumeAlignmentOp> {
  using ConvertOpToLLVMPattern<
      memref::AssumeAlignmentOp>::ConvertOpToLLVMPattern;

  LogicalResult
  matchAndRewrite(memref::AssumeAlignmentOp op, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    Value memref = adaptor.memref();
    unsigned alignment = op.alignment();
    auto loc = op.getLoc();

    MemRefDescriptor memRefDescriptor(memref);
    Value ptr = memRefDescriptor.alignedPtr(rewriter, memref.getLoc());

    // Emit llvm.assume(memref.alignedPtr & (alignment - 1) == 0). Notice that
    // the asserted memref.alignedPtr isn't used anywhere else, as the real
    // users like load/store/views always re-extract memref.alignedPtr as they
    // get lowered.
    //
    // This relies on LLVM's CSE optimization (potentially after SROA), since
    // after CSE all memref.alignedPtr instances get de-duplicated into the same
    // pointer SSA value.
    auto intPtrType =
        getIntPtrType(memRefDescriptor.getElementPtrType().getAddressSpace());
    Value zero = createIndexAttrConstant(rewriter, loc, intPtrType, 0);
    Value mask =
        createIndexAttrConstant(rewriter, loc, intPtrType, alignment - 1);
    Value ptrValue = rewriter.create<LLVM::PtrToIntOp>(loc, intPtrType, ptr);
    rewriter.create<LLVM::AssumeOp>(
        loc, rewriter.create<LLVM::ICmpOp>(
                 loc, LLVM::ICmpPredicate::eq,
                 rewriter.create<LLVM::AndOp>(loc, ptrValue, mask), zero));

    rewriter.eraseOp(op);
    return success();
  }
};

// A `dealloc` is converted into a call to `free` on the underlying data buffer.
// The memref descriptor being an SSA value, there is no need to clean it up
// in any way.
struct DeallocOpLowering : public ConvertOpToLLVMPattern<memref::DeallocOp> {
  using ConvertOpToLLVMPattern<memref::DeallocOp>::ConvertOpToLLVMPattern;

  explicit DeallocOpLowering(LLVMTypeConverter &converter)
      : ConvertOpToLLVMPattern<memref::DeallocOp>(converter) {}

  LogicalResult
  matchAndRewrite(memref::DeallocOp op, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    // Insert the `free` declaration if it is not already present.
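    // Roughly the following is produced (illustrative SSA names, element type
    // f32, descriptor type abbreviated):
    //   %0 = llvm.extractvalue %desc[0]   // allocated pointer
    //   %1 = llvm.bitcast %0 : !llvm.ptr<f32> to !llvm.ptr<i8>
    //   llvm.call @free(%1) : (!llvm.ptr<i8>) -> ()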
    auto freeFunc = LLVM::lookupOrCreateFreeFn(op->getParentOfType<ModuleOp>());
    MemRefDescriptor memref(adaptor.memref());
    Value casted = rewriter.create<LLVM::BitcastOp>(
        op.getLoc(), getVoidPtrType(),
        memref.allocatedPtr(rewriter, op.getLoc()));
    rewriter.replaceOpWithNewOp<LLVM::CallOp>(
        op, TypeRange(), SymbolRefAttr::get(freeFunc), casted);
    return success();
  }
};

// A `dim` is converted to a constant for static sizes and to an access to the
// size stored in the memref descriptor for dynamic sizes.
struct DimOpLowering : public ConvertOpToLLVMPattern<memref::DimOp> {
  using ConvertOpToLLVMPattern<memref::DimOp>::ConvertOpToLLVMPattern;

  LogicalResult
  matchAndRewrite(memref::DimOp dimOp, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    Type operandType = dimOp.source().getType();
    if (operandType.isa<UnrankedMemRefType>()) {
      rewriter.replaceOp(
          dimOp, {extractSizeOfUnrankedMemRef(
                     operandType, dimOp, adaptor.getOperands(), rewriter)});

      return success();
    }
    if (operandType.isa<MemRefType>()) {
      rewriter.replaceOp(
          dimOp, {extractSizeOfRankedMemRef(operandType, dimOp,
                                            adaptor.getOperands(), rewriter)});
      return success();
    }
    llvm_unreachable("expected MemRefType or UnrankedMemRefType");
  }

private:
  Value extractSizeOfUnrankedMemRef(Type operandType, memref::DimOp dimOp,
                                    OpAdaptor adaptor,
                                    ConversionPatternRewriter &rewriter) const {
    Location loc = dimOp.getLoc();

    auto unrankedMemRefType = operandType.cast<UnrankedMemRefType>();
    auto scalarMemRefType =
        MemRefType::get({}, unrankedMemRefType.getElementType());
    unsigned addressSpace = unrankedMemRefType.getMemorySpaceAsInt();

    // Extract pointer to the underlying ranked descriptor and bitcast it to a
    // memref<element_type> descriptor pointer to minimize the number of GEP
    // operations.
    UnrankedMemRefDescriptor unrankedDesc(adaptor.source());
    Value underlyingRankedDesc = unrankedDesc.memRefDescPtr(rewriter, loc);
    Value scalarMemRefDescPtr = rewriter.create<LLVM::BitcastOp>(
        loc,
        LLVM::LLVMPointerType::get(typeConverter->convertType(scalarMemRefType),
                                   addressSpace),
        underlyingRankedDesc);

    // Get pointer to offset field of memref<element_type> descriptor.
    Type indexPtrTy = LLVM::LLVMPointerType::get(
        getTypeConverter()->getIndexType(), addressSpace);
    Value two = rewriter.create<LLVM::ConstantOp>(
        loc, typeConverter->convertType(rewriter.getI32Type()),
        rewriter.getI32IntegerAttr(2));
    Value offsetPtr = rewriter.create<LLVM::GEPOp>(
        loc, indexPtrTy, scalarMemRefDescPtr,
        ValueRange({createIndexConstant(rewriter, loc, 0), two}));

    // The size value that we have to extract can be obtained using a GEPOp
    // with `dimOp.index() + 1` as the index argument.
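    // The converted ranked descriptor is laid out as
    //   {allocatedPtr, alignedPtr, offset, sizes[rank], strides[rank]},
    // and `offsetPtr` points at the offset field, so the i-th size lives at
    // `offsetPtr + 1 + i` (hence the `+ 1` below).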
    Value idxPlusOne = rewriter.create<LLVM::AddOp>(
        loc, createIndexConstant(rewriter, loc, 1), adaptor.index());
    Value sizePtr = rewriter.create<LLVM::GEPOp>(loc, indexPtrTy, offsetPtr,
                                                 ValueRange({idxPlusOne}));
    return rewriter.create<LLVM::LoadOp>(loc, sizePtr);
  }

  Optional<int64_t> getConstantDimIndex(memref::DimOp dimOp) const {
    if (Optional<int64_t> idx = dimOp.getConstantIndex())
      return idx;

    if (auto constantOp = dimOp.index().getDefiningOp<LLVM::ConstantOp>())
      return constantOp.getValue()
          .cast<IntegerAttr>()
          .getValue()
          .getSExtValue();

    return llvm::None;
  }

  Value extractSizeOfRankedMemRef(Type operandType, memref::DimOp dimOp,
                                  OpAdaptor adaptor,
                                  ConversionPatternRewriter &rewriter) const {
    Location loc = dimOp.getLoc();

    // Take advantage of a constant index when available.
    MemRefType memRefType = operandType.cast<MemRefType>();
    if (Optional<int64_t> index = getConstantDimIndex(dimOp)) {
      int64_t i = index.getValue();
      if (memRefType.isDynamicDim(i)) {
        // Extract the dynamic size from the memref descriptor.
        MemRefDescriptor descriptor(adaptor.source());
        return descriptor.size(rewriter, loc, i);
      }
      // Use a constant for the static size.
      int64_t dimSize = memRefType.getDimSize(i);
      return createIndexConstant(rewriter, loc, dimSize);
    }
    Value index = adaptor.index();
    int64_t rank = memRefType.getRank();
    MemRefDescriptor memrefDescriptor(adaptor.source());
    return memrefDescriptor.size(rewriter, loc, index, rank);
  }
};

/// Common base for load and store operations on MemRefs. Restricts the match
/// to supported MemRef types. Provides functionality to emit code accessing a
/// specific element of the underlying data buffer.
template <typename Derived>
struct LoadStoreOpLowering : public ConvertOpToLLVMPattern<Derived> {
  using ConvertOpToLLVMPattern<Derived>::ConvertOpToLLVMPattern;
  using ConvertOpToLLVMPattern<Derived>::isConvertibleAndHasIdentityMaps;
  using Base = LoadStoreOpLowering<Derived>;

  LogicalResult match(Derived op) const override {
    MemRefType type = op.getMemRefType();
    return isConvertibleAndHasIdentityMaps(type) ? success() : failure();
  }
};

/// Wrap an llvm.cmpxchg operation in a while loop so that the operation can be
/// retried until it succeeds in atomically storing a new value into memory.
///
///      +---------------------------------+
///      |   <code before the AtomicRMWOp> |
///      |   <compute initial %loaded>     |
///      |   cf.br loop(%loaded)           |
///      +---------------------------------+
///             |
///  -------|   |
///  |      v   v
///  |   +--------------------------------+
///  |   | loop(%loaded):                 |
///  |   |   <body contents>              |
///  |   |   %pair = cmpxchg              |
///  |   |   %new = %pair[0]              |
///  |   |   %ok = %pair[1]               |
///  |   |   cf.cond_br %ok, end, loop(%new) |
///  |   +--------------------------------+
///  |          |         |
///  |----------          |
///                       v
///      +--------------------------------+
///      | end:                           |
///      |   <code after the AtomicRMWOp> |
///      +--------------------------------+
///
struct GenericAtomicRMWOpLowering
    : public LoadStoreOpLowering<memref::GenericAtomicRMWOp> {
  using Base::Base;

  LogicalResult
  matchAndRewrite(memref::GenericAtomicRMWOp atomicOp, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    auto loc = atomicOp.getLoc();
    Type valueType = typeConverter->convertType(atomicOp.getResult().getType());

    // Split the block into initial, loop, and ending parts.
    auto *initBlock = rewriter.getInsertionBlock();
    auto *loopBlock = rewriter.createBlock(
        initBlock->getParent(), std::next(Region::iterator(initBlock)),
        valueType, loc);
    auto *endBlock = rewriter.createBlock(
        loopBlock->getParent(), std::next(Region::iterator(loopBlock)));

    // Operations range to be moved to `endBlock`.
    auto opsToMoveStart = atomicOp->getIterator();
    auto opsToMoveEnd = initBlock->back().getIterator();

    // Compute the loaded value and branch to the loop block.
    rewriter.setInsertionPointToEnd(initBlock);
    auto memRefType = atomicOp.memref().getType().cast<MemRefType>();
    auto dataPtr = getStridedElementPtr(loc, memRefType, adaptor.memref(),
                                        adaptor.indices(), rewriter);
    Value init = rewriter.create<LLVM::LoadOp>(loc, dataPtr);
    rewriter.create<LLVM::BrOp>(loc, init, loopBlock);

    // Prepare the body of the loop block.
    rewriter.setInsertionPointToStart(loopBlock);

    // Clone the GenericAtomicRMWOp region and extract the result.
    auto loopArgument = loopBlock->getArgument(0);
    BlockAndValueMapping mapping;
    mapping.map(atomicOp.getCurrentValue(), loopArgument);
    Block &entryBlock = atomicOp.body().front();
    for (auto &nestedOp : entryBlock.without_terminator()) {
      Operation *clone = rewriter.clone(nestedOp, mapping);
      mapping.map(nestedOp.getResults(), clone->getResults());
    }
    Value result = mapping.lookup(entryBlock.getTerminator()->getOperand(0));

    // Prepare the epilog of the loop block.
    // Append the cmpxchg op to the end of the loop block.
    auto successOrdering = LLVM::AtomicOrdering::acq_rel;
    auto failureOrdering = LLVM::AtomicOrdering::monotonic;
    auto boolType = IntegerType::get(rewriter.getContext(), 1);
    auto pairType = LLVM::LLVMStructType::getLiteral(rewriter.getContext(),
                                                     {valueType, boolType});
    auto cmpxchg = rewriter.create<LLVM::AtomicCmpXchgOp>(
        loc, pairType, dataPtr, loopArgument, result, successOrdering,
        failureOrdering);
    // Extract the %new_loaded and %ok values from the pair.
    Value newLoaded = rewriter.create<LLVM::ExtractValueOp>(
        loc, valueType, cmpxchg, rewriter.getI64ArrayAttr({0}));
    Value ok = rewriter.create<LLVM::ExtractValueOp>(
        loc, boolType, cmpxchg, rewriter.getI64ArrayAttr({1}));

    // Conditionally branch to the end or back to the loop depending on %ok.
    rewriter.create<LLVM::CondBrOp>(loc, ok, endBlock, ArrayRef<Value>(),
                                    loopBlock, newLoaded);

    rewriter.setInsertionPointToEnd(endBlock);
    moveOpsRange(atomicOp.getResult(), newLoaded, std::next(opsToMoveStart),
                 std::next(opsToMoveEnd), rewriter);

    // The 'result' of the atomic_rmw op is the newly loaded value.
    rewriter.replaceOp(atomicOp, {newLoaded});

    return success();
  }

private:
  // Clones a segment of ops [start, end) and erases the original.
  void moveOpsRange(ValueRange oldResult, ValueRange newResult,
                    Block::iterator start, Block::iterator end,
                    ConversionPatternRewriter &rewriter) const {
    BlockAndValueMapping mapping;
    mapping.map(oldResult, newResult);
    SmallVector<Operation *, 2> opsToErase;
    for (auto it = start; it != end; ++it) {
      rewriter.clone(*it, mapping);
      opsToErase.push_back(&*it);
    }
    for (auto *it : opsToErase)
      rewriter.eraseOp(it);
  }
};

/// Returns the LLVM type of the global variable given the memref type `type`.
static Type convertGlobalMemrefTypeToLLVM(MemRefType type,
                                          LLVMTypeConverter &typeConverter) {
  // The LLVM type for a global memref will be a multi-dimensional array. For
  // declarations or uninitialized global memrefs, we can potentially flatten
  // this to a 1D array. However, for memref.global's with an initial value,
  // we do not intend to flatten the ElementsAttribute when going from std ->
  // LLVM dialect, so the LLVM type needs to be a multi-dimensional array.
  Type elementType = typeConverter.convertType(type.getElementType());
  Type arrayTy = elementType;
  // Shape has the outermost dim at index 0, so we need to walk it backwards.
  for (int64_t dim : llvm::reverse(type.getShape()))
    arrayTy = LLVM::LLVMArrayType::get(arrayTy, dim);
  return arrayTy;
}

/// GlobalMemrefOp is lowered to an LLVM global variable.
struct GlobalMemrefOpLowering
    : public ConvertOpToLLVMPattern<memref::GlobalOp> {
  using ConvertOpToLLVMPattern<memref::GlobalOp>::ConvertOpToLLVMPattern;

  LogicalResult
  matchAndRewrite(memref::GlobalOp global, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    MemRefType type = global.type();
    if (!isConvertibleAndHasIdentityMaps(type))
      return failure();

    Type arrayTy = convertGlobalMemrefTypeToLLVM(type, *getTypeConverter());

    LLVM::Linkage linkage =
        global.isPublic() ? LLVM::Linkage::External : LLVM::Linkage::Private;

    Attribute initialValue = nullptr;
    if (!global.isExternal() && !global.isUninitialized()) {
      auto elementsAttr = global.initial_value()->cast<ElementsAttr>();
      initialValue = elementsAttr;

      // For scalar memrefs, the global variable created is of the element type,
      // so unpack the elements attribute to extract the value.
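      // For instance (illustrative), `memref.global @c : memref<f32> =
      // dense<1.0>` becomes an LLVM global of type f32 whose initializer is
      // the unpacked splat value 1.0 rather than a dense elements attribute.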
      if (type.getRank() == 0)
        initialValue = elementsAttr.getSplatValue<Attribute>();
    }

    uint64_t alignment = global.alignment().getValueOr(0);

    auto newGlobal = rewriter.replaceOpWithNewOp<LLVM::GlobalOp>(
        global, arrayTy, global.constant(), linkage, global.sym_name(),
        initialValue, alignment, type.getMemorySpaceAsInt());
    if (!global.isExternal() && global.isUninitialized()) {
      Block *blk = new Block();
      newGlobal.getInitializerRegion().push_back(blk);
      rewriter.setInsertionPointToStart(blk);
      Value undef[] = {
          rewriter.create<LLVM::UndefOp>(global.getLoc(), arrayTy)};
      rewriter.create<LLVM::ReturnOp>(global.getLoc(), undef);
    }
    return success();
  }
};

/// GetGlobalMemrefOp is lowered into a MemRef descriptor with the pointer to
/// the first element stashed into the descriptor. This reuses
/// `AllocLikeOpLowering` for the MemRef descriptor construction.
struct GetGlobalMemrefOpLowering : public AllocLikeOpLLVMLowering {
  GetGlobalMemrefOpLowering(LLVMTypeConverter &converter)
      : AllocLikeOpLLVMLowering(memref::GetGlobalOp::getOperationName(),
                                converter) {}

  /// Buffer "allocation" for memref.get_global op is getting the address of
  /// the global variable referenced.
  std::tuple<Value, Value> allocateBuffer(ConversionPatternRewriter &rewriter,
                                          Location loc, Value sizeBytes,
                                          Operation *op) const override {
    auto getGlobalOp = cast<memref::GetGlobalOp>(op);
    MemRefType type = getGlobalOp.result().getType().cast<MemRefType>();
    unsigned memSpace = type.getMemorySpaceAsInt();

    Type arrayTy = convertGlobalMemrefTypeToLLVM(type, *getTypeConverter());
    auto addressOf = rewriter.create<LLVM::AddressOfOp>(
        loc, LLVM::LLVMPointerType::get(arrayTy, memSpace), getGlobalOp.name());

    // Get the address of the first element in the array by creating a GEP with
    // the address of the GV as the base, and (rank + 1) number of 0 indices.
    Type elementType = typeConverter->convertType(type.getElementType());
    Type elementPtrType = LLVM::LLVMPointerType::get(elementType, memSpace);

    SmallVector<Value> operands;
    operands.insert(operands.end(), type.getRank() + 1,
                    createIndexConstant(rewriter, loc, 0));
    auto gep =
        rewriter.create<LLVM::GEPOp>(loc, elementPtrType, addressOf, operands);

    // We do not expect the memref obtained using `memref.get_global` to ever
    // be deallocated. Set the allocated pointer to a known bad value to help
    // debug if that ever happens.
    auto intPtrType = getIntPtrType(memSpace);
    Value deadBeefConst =
        createIndexAttrConstant(rewriter, op->getLoc(), intPtrType, 0xdeadbeef);
    auto deadBeefPtr =
        rewriter.create<LLVM::IntToPtrOp>(loc, elementPtrType, deadBeefConst);

    // Both the allocated and aligned pointers are the same. We could
    // potentially stash a nullptr for the allocated pointer since we do not
    // expect any dealloc.
    return std::make_tuple(deadBeefPtr, gep);
  }
};

// Load operation is lowered to obtaining a pointer to the indexed element
// and loading it.
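// For example (illustrative SSA names, identity-layout memref<?xf32>):
//   %v = memref.load %m[%i] : memref<?xf32>
// becomes roughly
//   %p = llvm.getelementptr %aligned[%i]
//          : (!llvm.ptr<f32>, i64) -> !llvm.ptr<f32>
//   %v = llvm.load %p : !llvm.ptr<f32>
// where %aligned is the aligned pointer extracted from the converted memref
// descriptor (offset handling omitted for brevity).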
struct LoadOpLowering : public LoadStoreOpLowering<memref::LoadOp> {
  using Base::Base;

  LogicalResult
  matchAndRewrite(memref::LoadOp loadOp, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    auto type = loadOp.getMemRefType();

    Value dataPtr = getStridedElementPtr(
        loadOp.getLoc(), type, adaptor.memref(), adaptor.indices(), rewriter);
    rewriter.replaceOpWithNewOp<LLVM::LoadOp>(loadOp, dataPtr);
    return success();
  }
};

// Store operation is lowered to obtaining a pointer to the indexed element,
// and storing the given value to it.
struct StoreOpLowering : public LoadStoreOpLowering<memref::StoreOp> {
  using Base::Base;

  LogicalResult
  matchAndRewrite(memref::StoreOp op, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    auto type = op.getMemRefType();

    Value dataPtr = getStridedElementPtr(op.getLoc(), type, adaptor.memref(),
                                         adaptor.indices(), rewriter);
    rewriter.replaceOpWithNewOp<LLVM::StoreOp>(op, adaptor.value(), dataPtr);
    return success();
  }
};

// The prefetch operation is lowered in a way similar to the load operation
// except that the llvm.prefetch operation is used for replacement.
struct PrefetchOpLowering : public LoadStoreOpLowering<memref::PrefetchOp> {
  using Base::Base;

  LogicalResult
  matchAndRewrite(memref::PrefetchOp prefetchOp, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    auto type = prefetchOp.getMemRefType();
    auto loc = prefetchOp.getLoc();

    Value dataPtr = getStridedElementPtr(loc, type, adaptor.memref(),
                                         adaptor.indices(), rewriter);

    // Replace with llvm.prefetch.
    auto llvmI32Type = typeConverter->convertType(rewriter.getIntegerType(32));
    auto isWrite = rewriter.create<LLVM::ConstantOp>(
        loc, llvmI32Type, rewriter.getI32IntegerAttr(prefetchOp.isWrite()));
    auto localityHint = rewriter.create<LLVM::ConstantOp>(
        loc, llvmI32Type,
        rewriter.getI32IntegerAttr(prefetchOp.localityHint()));
    auto isData = rewriter.create<LLVM::ConstantOp>(
        loc, llvmI32Type, rewriter.getI32IntegerAttr(prefetchOp.isDataCache()));

    rewriter.replaceOpWithNewOp<LLVM::Prefetch>(prefetchOp, dataPtr, isWrite,
                                                localityHint, isData);
    return success();
  }
};

struct RankOpLowering : public ConvertOpToLLVMPattern<memref::RankOp> {
  using ConvertOpToLLVMPattern<memref::RankOp>::ConvertOpToLLVMPattern;

  LogicalResult
  matchAndRewrite(memref::RankOp op, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    Location loc = op.getLoc();
    Type operandType = op.memref().getType();
    if (auto unrankedMemRefType = operandType.dyn_cast<UnrankedMemRefType>()) {
      UnrankedMemRefDescriptor desc(adaptor.memref());
      rewriter.replaceOp(op, {desc.rank(rewriter, loc)});
      return success();
    }
    if (auto rankedMemRefType = operandType.dyn_cast<MemRefType>()) {
      rewriter.replaceOp(
          op, {createIndexConstant(rewriter, loc, rankedMemRefType.getRank())});
      return success();
    }
    return failure();
  }
};

struct MemRefCastOpLowering : public ConvertOpToLLVMPattern<memref::CastOp> {
  using ConvertOpToLLVMPattern<memref::CastOp>::ConvertOpToLLVMPattern;

  LogicalResult match(memref::CastOp memRefCastOp) const override {
    Type srcType = memRefCastOp.getOperand().getType();
    Type dstType = memRefCastOp.getType();

    // memref::CastOp reduces to a bitcast in the ranked MemRef case and can be
    // used for type erasure. For now it must preserve the underlying element
    // type and requires source and result types to have the same rank.
    // Therefore, perform a sanity check that the underlying structs are the
    // same. Once op semantics are relaxed we can revisit.
    if (srcType.isa<MemRefType>() && dstType.isa<MemRefType>())
      return success(typeConverter->convertType(srcType) ==
                     typeConverter->convertType(dstType));

    // At least one of the operands is an unranked type.
    assert(srcType.isa<UnrankedMemRefType>() ||
           dstType.isa<UnrankedMemRefType>());

    // An unranked-to-unranked cast is disallowed.
    return !(srcType.isa<UnrankedMemRefType>() &&
             dstType.isa<UnrankedMemRefType>())
               ? success()
               : failure();
  }

  void rewrite(memref::CastOp memRefCastOp, OpAdaptor adaptor,
               ConversionPatternRewriter &rewriter) const override {
    auto srcType = memRefCastOp.getOperand().getType();
    auto dstType = memRefCastOp.getType();
    auto targetStructType = typeConverter->convertType(memRefCastOp.getType());
    auto loc = memRefCastOp.getLoc();

    // For the ranked/ranked case, just keep the original descriptor.
    if (srcType.isa<MemRefType>() && dstType.isa<MemRefType>())
      return rewriter.replaceOp(memRefCastOp, {adaptor.source()});

    if (srcType.isa<MemRefType>() && dstType.isa<UnrankedMemRefType>()) {
      // Casting a ranked memref type to an unranked one:
      //   - set the rank in the destination from the memref type;
      //   - allocate space on the stack and copy the src memref descriptor;
      //   - set the ptr in the destination to the stack space.
      auto srcMemRefType = srcType.cast<MemRefType>();
      int64_t rank = srcMemRefType.getRank();
      // ptr = AllocaOp sizeof(MemRefDescriptor)
      auto ptr = getTypeConverter()->promoteOneMemRefDescriptor(
          loc, adaptor.source(), rewriter);
      // voidptr = BitCastOp srcType* to void*
      auto voidPtr =
          rewriter.create<LLVM::BitcastOp>(loc, getVoidPtrType(), ptr)
              .getResult();
      // rank = ConstantOp srcRank
      auto rankVal = rewriter.create<LLVM::ConstantOp>(
          loc, getIndexType(), rewriter.getIndexAttr(rank));
      // undef = UndefOp
      UnrankedMemRefDescriptor memRefDesc =
          UnrankedMemRefDescriptor::undef(rewriter, loc, targetStructType);
      // d1 = InsertValueOp undef, rank, 0
      memRefDesc.setRank(rewriter, loc, rankVal);
      // d2 = InsertValueOp d1, voidptr, 1
      memRefDesc.setMemRefDescPtr(rewriter, loc, voidPtr);
      rewriter.replaceOp(memRefCastOp, (Value)memRefDesc);

    } else if (srcType.isa<UnrankedMemRefType>() && dstType.isa<MemRefType>()) {
      // Casting from an unranked type to a ranked one. The operation is
      // assumed to be doing a correct cast: if the destination type mismatches
      // the unranked type, the behavior is undefined.
      UnrankedMemRefDescriptor memRefDesc(adaptor.source());
      // ptr = ExtractValueOp src, 1
      auto ptr = memRefDesc.memRefDescPtr(rewriter, loc);
      // castPtr = BitCastOp i8* to structTy*
      auto castPtr =
          rewriter
              .create<LLVM::BitcastOp>(
                  loc, LLVM::LLVMPointerType::get(targetStructType), ptr)
              .getResult();
      // struct = LoadOp castPtr
      auto loadOp = rewriter.create<LLVM::LoadOp>(loc, castPtr);
      rewriter.replaceOp(memRefCastOp, loadOp.getResult());
    } else {
      llvm_unreachable("Unsupported unranked memref to unranked memref cast");
    }
  }
};

/// Pattern to lower a `memref.copy` to llvm.
///
/// For memrefs with identity layouts, the copy is lowered to the llvm
/// `memcpy` intrinsic. For non-identity layouts, the copy is lowered to a call
/// to the generic `MemrefCopyFn`.
struct MemRefCopyOpLowering : public ConvertOpToLLVMPattern<memref::CopyOp> {
  using ConvertOpToLLVMPattern<memref::CopyOp>::ConvertOpToLLVMPattern;

  LogicalResult
  lowerToMemCopyIntrinsic(memref::CopyOp op, OpAdaptor adaptor,
                          ConversionPatternRewriter &rewriter) const {
    auto loc = op.getLoc();
    auto srcType = op.source().getType().dyn_cast<MemRefType>();

    MemRefDescriptor srcDesc(adaptor.source());

    // Compute number of elements.
    Value numElements = rewriter.create<LLVM::ConstantOp>(
        loc, getIndexType(), rewriter.getIndexAttr(1));
    for (int pos = 0; pos < srcType.getRank(); ++pos) {
      auto size = srcDesc.size(rewriter, loc, pos);
      numElements = rewriter.create<LLVM::MulOp>(loc, numElements, size);
    }

    // Get element size.
    auto sizeInBytes = getSizeInBytes(loc, srcType.getElementType(), rewriter);
    // Compute total.
    Value totalSize =
        rewriter.create<LLVM::MulOp>(loc, numElements, sizeInBytes);

    Value srcBasePtr = srcDesc.alignedPtr(rewriter, loc);
    Value srcOffset = srcDesc.offset(rewriter, loc);
    Value srcPtr = rewriter.create<LLVM::GEPOp>(loc, srcBasePtr.getType(),
                                                srcBasePtr, srcOffset);
    MemRefDescriptor targetDesc(adaptor.target());
    Value targetBasePtr = targetDesc.alignedPtr(rewriter, loc);
    Value targetOffset = targetDesc.offset(rewriter, loc);
    Value targetPtr = rewriter.create<LLVM::GEPOp>(loc, targetBasePtr.getType(),
                                                   targetBasePtr, targetOffset);
    Value isVolatile = rewriter.create<LLVM::ConstantOp>(
        loc, typeConverter->convertType(rewriter.getI1Type()),
        rewriter.getBoolAttr(false));
    rewriter.create<LLVM::MemcpyOp>(loc, targetPtr, srcPtr, totalSize,
                                    isVolatile);
    rewriter.eraseOp(op);

    return success();
  }

  LogicalResult
  lowerToMemCopyFunctionCall(memref::CopyOp op, OpAdaptor adaptor,
                             ConversionPatternRewriter &rewriter) const {
    auto loc = op.getLoc();
    auto srcType = op.source().getType().cast<BaseMemRefType>();
    auto targetType = op.target().getType().cast<BaseMemRefType>();

    // First make sure we have an unranked memref descriptor representation.
    auto makeUnranked = [&, this](Value ranked, BaseMemRefType type) {
      auto rank = rewriter.create<LLVM::ConstantOp>(
          loc, getIndexType(), rewriter.getIndexAttr(type.getRank()));
      auto *typeConverter = getTypeConverter();
      auto ptr =
          typeConverter->promoteOneMemRefDescriptor(loc, ranked, rewriter);
      auto voidPtr =
          rewriter.create<LLVM::BitcastOp>(loc, getVoidPtrType(), ptr)
              .getResult();
      auto unrankedType =
          UnrankedMemRefType::get(type.getElementType(), type.getMemorySpace());
      return UnrankedMemRefDescriptor::pack(rewriter, loc, *typeConverter,
                                            unrankedType,
                                            ValueRange{rank, voidPtr});
    };

    Value unrankedSource = srcType.hasRank()
                               ? makeUnranked(adaptor.source(), srcType)
                               : adaptor.source();
    Value unrankedTarget = targetType.hasRank()
                               ? makeUnranked(adaptor.target(), targetType)
                               : adaptor.target();

    // Now promote the unranked descriptors to the stack.
    auto one = rewriter.create<LLVM::ConstantOp>(loc, getIndexType(),
                                                 rewriter.getIndexAttr(1));
    auto promote = [&](Value desc) {
      auto ptrType = LLVM::LLVMPointerType::get(desc.getType());
      auto allocated =
          rewriter.create<LLVM::AllocaOp>(loc, ptrType, ValueRange{one});
      rewriter.create<LLVM::StoreOp>(loc, desc, allocated);
      return allocated;
    };

    auto sourcePtr = promote(unrankedSource);
    auto targetPtr = promote(unrankedTarget);

    unsigned typeSize =
        mlir::DataLayout::closest(op).getTypeSize(srcType.getElementType());
    auto elemSize = rewriter.create<LLVM::ConstantOp>(
        loc, getIndexType(), rewriter.getIndexAttr(typeSize));
    auto copyFn = LLVM::lookupOrCreateMemRefCopyFn(
        op->getParentOfType<ModuleOp>(), getIndexType(), sourcePtr.getType());
    rewriter.create<LLVM::CallOp>(loc, copyFn,
                                  ValueRange{elemSize, sourcePtr, targetPtr});
    rewriter.eraseOp(op);

    return success();
  }

  LogicalResult
  matchAndRewrite(memref::CopyOp op, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    auto srcType = op.source().getType().cast<BaseMemRefType>();
    auto targetType = op.target().getType().cast<BaseMemRefType>();

    auto isContiguousMemrefType = [](BaseMemRefType type) {
      auto memrefType = type.dyn_cast<mlir::MemRefType>();
      // We can use memcpy for memrefs if they have an identity layout or are
      // contiguous with an arbitrary offset. Ignore empty memrefs, which are a
      // special case handled by memrefCopy.
      return memrefType &&
             (memrefType.getLayout().isIdentity() ||
              (memrefType.hasStaticShape() && memrefType.getNumElements() > 0 &&
               isStaticShapeAndContiguousRowMajor(memrefType)));
    };

    if (isContiguousMemrefType(srcType) && isContiguousMemrefType(targetType))
      return lowerToMemCopyIntrinsic(op, adaptor, rewriter);

    return lowerToMemCopyFunctionCall(op, adaptor, rewriter);
  }
};

/// Extracts allocated, aligned pointers and offset from a ranked or unranked
/// memref type. In the unranked case, the fields are extracted from the
/// underlying ranked descriptor.
static void extractPointersAndOffset(Location loc,
                                     ConversionPatternRewriter &rewriter,
                                     LLVMTypeConverter &typeConverter,
                                     Value originalOperand,
                                     Value convertedOperand,
                                     Value *allocatedPtr, Value *alignedPtr,
                                     Value *offset = nullptr) {
  Type operandType = originalOperand.getType();
  if (operandType.isa<MemRefType>()) {
    MemRefDescriptor desc(convertedOperand);
    *allocatedPtr = desc.allocatedPtr(rewriter, loc);
    *alignedPtr = desc.alignedPtr(rewriter, loc);
    if (offset != nullptr)
      *offset = desc.offset(rewriter, loc);
    return;
  }

  unsigned memorySpace =
      operandType.cast<UnrankedMemRefType>().getMemorySpaceAsInt();
  Type elementType = operandType.cast<UnrankedMemRefType>().getElementType();
  Type llvmElementType = typeConverter.convertType(elementType);
  Type elementPtrPtrType = LLVM::LLVMPointerType::get(
      LLVM::LLVMPointerType::get(llvmElementType, memorySpace));

  // Extract pointer to the underlying ranked memref descriptor and cast it to
  // ElemType**.
  UnrankedMemRefDescriptor unrankedDesc(convertedOperand);
  Value underlyingDescPtr = unrankedDesc.memRefDescPtr(rewriter, loc);

  *allocatedPtr = UnrankedMemRefDescriptor::allocatedPtr(
      rewriter, loc, underlyingDescPtr, elementPtrPtrType);
  *alignedPtr = UnrankedMemRefDescriptor::alignedPtr(
      rewriter, loc, typeConverter, underlyingDescPtr, elementPtrPtrType);
  if (offset != nullptr) {
    *offset = UnrankedMemRefDescriptor::offset(
        rewriter, loc, typeConverter, underlyingDescPtr, elementPtrPtrType);
  }
}

struct MemRefReinterpretCastOpLowering
    : public ConvertOpToLLVMPattern<memref::ReinterpretCastOp> {
  using ConvertOpToLLVMPattern<
      memref::ReinterpretCastOp>::ConvertOpToLLVMPattern;

  LogicalResult
  matchAndRewrite(memref::ReinterpretCastOp castOp, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    Type srcType = castOp.source().getType();

    Value descriptor;
    if (failed(convertSourceMemRefToDescriptor(rewriter, srcType, castOp,
                                               adaptor, &descriptor)))
      return failure();
    rewriter.replaceOp(castOp, {descriptor});
    return success();
  }

private:
  LogicalResult convertSourceMemRefToDescriptor(
      ConversionPatternRewriter &rewriter, Type srcType,
      memref::ReinterpretCastOp castOp,
      memref::ReinterpretCastOp::Adaptor adaptor, Value *descriptor) const {
    MemRefType targetMemRefType =
        castOp.getResult().getType().cast<MemRefType>();
    auto llvmTargetDescriptorTy = typeConverter->convertType(targetMemRefType)
                                      .dyn_cast_or_null<LLVM::LLVMStructType>();
    if (!llvmTargetDescriptorTy)
      return failure();

    // Create descriptor.
    Location loc = castOp.getLoc();
    auto desc = MemRefDescriptor::undef(rewriter, loc, llvmTargetDescriptorTy);

    // Set allocated and aligned pointers.
    Value allocatedPtr, alignedPtr;
    extractPointersAndOffset(loc, rewriter, *getTypeConverter(),
                             castOp.source(), adaptor.source(), &allocatedPtr,
                             &alignedPtr);
    desc.setAllocatedPtr(rewriter, loc, allocatedPtr);
    desc.setAlignedPtr(rewriter, loc, alignedPtr);

    // Set offset.
    if (castOp.isDynamicOffset(0))
      desc.setOffset(rewriter, loc, adaptor.offsets()[0]);
    else
      desc.setConstantOffset(rewriter, loc, castOp.getStaticOffset(0));

    // Set sizes and strides.
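    // For example (illustrative, type annotations omitted):
    //   memref.reinterpret_cast %src to offset: [0], sizes: [%d, 4],
    //       strides: [4, 1]
    // stores %d and the constant 4 into the size fields, and the constants 4
    // and 1 into the stride fields of the descriptor in the loop below.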
    unsigned dynSizeId = 0;
    unsigned dynStrideId = 0;
    for (unsigned i = 0, e = targetMemRefType.getRank(); i < e; ++i) {
      if (castOp.isDynamicSize(i))
        desc.setSize(rewriter, loc, i, adaptor.sizes()[dynSizeId++]);
      else
        desc.setConstantSize(rewriter, loc, i, castOp.getStaticSize(i));

      if (castOp.isDynamicStride(i))
        desc.setStride(rewriter, loc, i, adaptor.strides()[dynStrideId++]);
      else
        desc.setConstantStride(rewriter, loc, i, castOp.getStaticStride(i));
    }
    *descriptor = desc;
    return success();
  }
};

struct MemRefReshapeOpLowering
    : public ConvertOpToLLVMPattern<memref::ReshapeOp> {
  using ConvertOpToLLVMPattern<memref::ReshapeOp>::ConvertOpToLLVMPattern;

  LogicalResult
  matchAndRewrite(memref::ReshapeOp reshapeOp, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    Type srcType = reshapeOp.source().getType();

    Value descriptor;
    if (failed(convertSourceMemRefToDescriptor(rewriter, srcType, reshapeOp,
                                               adaptor, &descriptor)))
      return failure();
    rewriter.replaceOp(reshapeOp, {descriptor});
    return success();
  }

private:
  LogicalResult
  convertSourceMemRefToDescriptor(ConversionPatternRewriter &rewriter,
                                  Type srcType, memref::ReshapeOp reshapeOp,
                                  memref::ReshapeOp::Adaptor adaptor,
                                  Value *descriptor) const {
    auto shapeMemRefType = reshapeOp.shape().getType().cast<MemRefType>();
    if (shapeMemRefType.hasStaticShape()) {
      MemRefType targetMemRefType =
          reshapeOp.getResult().getType().cast<MemRefType>();
      auto llvmTargetDescriptorTy =
          typeConverter->convertType(targetMemRefType)
              .dyn_cast_or_null<LLVM::LLVMStructType>();
      if (!llvmTargetDescriptorTy)
        return failure();

      // Create descriptor.
      Location loc = reshapeOp.getLoc();
      auto desc =
          MemRefDescriptor::undef(rewriter, loc, llvmTargetDescriptorTy);

      // Set allocated and aligned pointers.
      Value allocatedPtr, alignedPtr;
      extractPointersAndOffset(loc, rewriter, *getTypeConverter(),
                               reshapeOp.source(), adaptor.source(),
                               &allocatedPtr, &alignedPtr);
      desc.setAllocatedPtr(rewriter, loc, allocatedPtr);
      desc.setAlignedPtr(rewriter, loc, alignedPtr);

      // Extract the offset and strides from the type.
      int64_t offset;
      SmallVector<int64_t> strides;
      if (failed(getStridesAndOffset(targetMemRefType, strides, offset)))
        return rewriter.notifyMatchFailure(
            reshapeOp, "failed to get stride and offset exprs");

      if (!isStaticStrideOrOffset(offset))
        return rewriter.notifyMatchFailure(reshapeOp,
                                           "dynamic offset is unsupported");
      if (!llvm::all_of(strides, isStaticStrideOrOffset))
        return rewriter.notifyMatchFailure(reshapeOp,
                                           "dynamic strides are unsupported");

      desc.setConstantOffset(rewriter, loc, offset);
      for (unsigned i = 0, e = targetMemRefType.getRank(); i < e; ++i) {
        desc.setConstantSize(rewriter, loc, i, targetMemRefType.getDimSize(i));
        desc.setConstantStride(rewriter, loc, i, strides[i]);
      }

      *descriptor = desc;
      return success();
    }

    // The shape is a rank-1 tensor with unknown length.
    Location loc = reshapeOp.getLoc();
    MemRefDescriptor shapeDesc(adaptor.shape());
    Value resultRank = shapeDesc.size(rewriter, loc, 0);

    // Extract address space and element type.
    auto targetType =
        reshapeOp.getResult().getType().cast<UnrankedMemRefType>();
    unsigned addressSpace = targetType.getMemorySpaceAsInt();
    Type elementType = targetType.getElementType();

    // Create the unranked memref descriptor that holds the ranked one. The
    // inner descriptor is allocated on stack.
    auto targetDesc = UnrankedMemRefDescriptor::undef(
        rewriter, loc, typeConverter->convertType(targetType));
    targetDesc.setRank(rewriter, loc, resultRank);
    SmallVector<Value, 4> sizes;
    UnrankedMemRefDescriptor::computeSizes(rewriter, loc, *getTypeConverter(),
                                           targetDesc, sizes);
    Value underlyingDescPtr = rewriter.create<LLVM::AllocaOp>(
        loc, getVoidPtrType(), sizes.front(), llvm::None);
    targetDesc.setMemRefDescPtr(rewriter, loc, underlyingDescPtr);

    // Extract pointers and offset from the source memref.
    Value allocatedPtr, alignedPtr, offset;
    extractPointersAndOffset(loc, rewriter, *getTypeConverter(),
                             reshapeOp.source(), adaptor.source(),
                             &allocatedPtr, &alignedPtr, &offset);

    // Set pointers and offset.
    Type llvmElementType = typeConverter->convertType(elementType);
    auto elementPtrPtrType = LLVM::LLVMPointerType::get(
        LLVM::LLVMPointerType::get(llvmElementType, addressSpace));
    UnrankedMemRefDescriptor::setAllocatedPtr(rewriter, loc, underlyingDescPtr,
                                              elementPtrPtrType, allocatedPtr);
    UnrankedMemRefDescriptor::setAlignedPtr(rewriter, loc, *getTypeConverter(),
                                            underlyingDescPtr,
                                            elementPtrPtrType, alignedPtr);
    UnrankedMemRefDescriptor::setOffset(rewriter, loc, *getTypeConverter(),
                                        underlyingDescPtr, elementPtrPtrType,
                                        offset);

    // Use the offset pointer as base for further addressing. Copy over the new
    // shape and compute strides. For this, we create a loop from rank-1 to 0.
    Value targetSizesBase = UnrankedMemRefDescriptor::sizeBasePtr(
        rewriter, loc, *getTypeConverter(), underlyingDescPtr,
        elementPtrPtrType);
    Value targetStridesBase = UnrankedMemRefDescriptor::strideBasePtr(
        rewriter, loc, *getTypeConverter(), targetSizesBase, resultRank);
    Value shapeOperandPtr = shapeDesc.alignedPtr(rewriter, loc);
    Value oneIndex = createIndexConstant(rewriter, loc, 1);
    Value resultRankMinusOne =
        rewriter.create<LLVM::SubOp>(loc, resultRank, oneIndex);

    Block *initBlock = rewriter.getInsertionBlock();
    Type indexType = getTypeConverter()->getIndexType();
    Block::iterator remainingOpsIt = std::next(rewriter.getInsertionPoint());

    Block *condBlock = rewriter.createBlock(initBlock->getParent(), {},
                                            {indexType, indexType}, {loc, loc});

    // Move the remaining initBlock ops to condBlock.
    Block *remainingBlock = rewriter.splitBlock(initBlock, remainingOpsIt);
    rewriter.mergeBlocks(remainingBlock, condBlock, ValueRange());

    rewriter.setInsertionPointToEnd(initBlock);
    rewriter.create<LLVM::BrOp>(loc, ValueRange({resultRankMinusOne, oneIndex}),
                                condBlock);
    rewriter.setInsertionPointToStart(condBlock);
    Value indexArg = condBlock->getArgument(0);
    Value strideArg = condBlock->getArgument(1);

    Value zeroIndex = createIndexConstant(rewriter, loc, 0);
    Value pred = rewriter.create<LLVM::ICmpOp>(
        loc, IntegerType::get(rewriter.getContext(), 1),
        LLVM::ICmpPredicate::sge, indexArg, zeroIndex);

    Block *bodyBlock =
        rewriter.splitBlock(condBlock, rewriter.getInsertionPoint());
    rewriter.setInsertionPointToStart(bodyBlock);

    // Copy size from shape to descriptor.
    Type llvmIndexPtrType = LLVM::LLVMPointerType::get(indexType);
    Value sizeLoadGep = rewriter.create<LLVM::GEPOp>(
        loc, llvmIndexPtrType, shapeOperandPtr, ValueRange{indexArg});
    Value size = rewriter.create<LLVM::LoadOp>(loc, sizeLoadGep);
    UnrankedMemRefDescriptor::setSize(rewriter, loc, *getTypeConverter(),
                                      targetSizesBase, indexArg, size);

    // Write stride value and compute next one.
    UnrankedMemRefDescriptor::setStride(rewriter, loc, *getTypeConverter(),
                                        targetStridesBase, indexArg, strideArg);
    Value nextStride = rewriter.create<LLVM::MulOp>(loc, strideArg, size);

    // Decrement loop counter and branch back.
    Value decrement = rewriter.create<LLVM::SubOp>(loc, indexArg, oneIndex);
    rewriter.create<LLVM::BrOp>(loc, ValueRange({decrement, nextStride}),
                                condBlock);

    Block *remainder =
        rewriter.splitBlock(bodyBlock, rewriter.getInsertionPoint());

    // Hook up the cond exit to the remainder.
    rewriter.setInsertionPointToEnd(condBlock);
    rewriter.create<LLVM::CondBrOp>(loc, pred, bodyBlock, llvm::None, remainder,
                                    llvm::None);

    // Reset position to beginning of new remainder block.
    rewriter.setInsertionPointToStart(remainder);

    *descriptor = targetDesc;
    return success();
  }
};

/// Helper function to convert a vector of `OpFoldResult`s into a vector of
/// `Value`s.
static SmallVector<Value> getAsValues(OpBuilder &b, Location loc,
                                      Type &llvmIndexType,
                                      ArrayRef<OpFoldResult> valueOrAttrVec) {
  return llvm::to_vector<4>(
      llvm::map_range(valueOrAttrVec, [&](OpFoldResult value) -> Value {
        if (auto attr = value.dyn_cast<Attribute>())
          return b.create<LLVM::ConstantOp>(loc, llvmIndexType, attr);
        return value.get<Value>();
      }));
}

/// Computes a map that, for a given dimension of the expanded type, gives the
/// dimension in the collapsed type it maps to. Essentially it is the inverse
/// of the `reassocation` maps.
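/// For example (illustrative), the reassociation [[0, 1], [2]] yields the map
/// {0 -> 0, 1 -> 0, 2 -> 1}: expanded dims 0 and 1 both collapse into dim 0,
/// and expanded dim 2 collapses into dim 1.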
static DenseMap<int64_t, int64_t>
getExpandedDimToCollapsedDimMap(ArrayRef<ReassociationIndices> reassociation) {
  llvm::DenseMap<int64_t, int64_t> expandedDimToCollapsedDim;
  for (auto &en : enumerate(reassociation)) {
    for (auto dim : en.value())
      expandedDimToCollapsedDim[dim] = en.index();
  }
  return expandedDimToCollapsedDim;
}

static OpFoldResult
getExpandedOutputDimSize(OpBuilder &b, Location loc, Type &llvmIndexType,
                         int64_t outDimIndex, ArrayRef<int64_t> outStaticShape,
                         MemRefDescriptor &inDesc,
                         ArrayRef<int64_t> inStaticShape,
                         ArrayRef<ReassociationIndices> reassocation,
                         DenseMap<int64_t, int64_t> &outDimToInDimMap) {
  int64_t outDimSize = outStaticShape[outDimIndex];
  if (!ShapedType::isDynamic(outDimSize))
    return b.getIndexAttr(outDimSize);

  // Calculate the product of all the output dim sizes of the group except the
  // current dim.
  int64_t inDimIndex = outDimToInDimMap[outDimIndex];
  int64_t otherDimSizesMul = 1;
  for (auto otherDimIndex : reassocation[inDimIndex]) {
    if (otherDimIndex == static_cast<unsigned>(outDimIndex))
      continue;
    int64_t otherDimSize = outStaticShape[otherDimIndex];
    assert(!ShapedType::isDynamic(otherDimSize) &&
           "single dimension cannot be expanded into multiple dynamic "
           "dimensions");
    otherDimSizesMul *= otherDimSize;
  }

  // outDimSize = inDimSize / otherOutDimSizesMul
  int64_t inDimSize = inStaticShape[inDimIndex];
  Value inDimSizeDynamic =
      ShapedType::isDynamic(inDimSize)
          ? inDesc.size(b, loc, inDimIndex)
          : b.create<LLVM::ConstantOp>(loc, llvmIndexType,
                                       b.getIndexAttr(inDimSize));
  Value outDimSizeDynamic = b.create<LLVM::SDivOp>(
      loc, inDimSizeDynamic,
      b.create<LLVM::ConstantOp>(loc, llvmIndexType,
                                 b.getIndexAttr(otherDimSizesMul)));
  return outDimSizeDynamic;
}

static OpFoldResult getCollapsedOutputDimSize(
    OpBuilder &b, Location loc, Type &llvmIndexType, int64_t outDimIndex,
    int64_t outDimSize, ArrayRef<int64_t> inStaticShape,
    MemRefDescriptor &inDesc, ArrayRef<ReassociationIndices> reassocation) {
  if (!ShapedType::isDynamic(outDimSize))
    return b.getIndexAttr(outDimSize);

  Value c1 = b.create<LLVM::ConstantOp>(loc, llvmIndexType, b.getIndexAttr(1));
  Value outDimSizeDynamic = c1;
  for (auto inDimIndex : reassocation[outDimIndex]) {
    int64_t inDimSize = inStaticShape[inDimIndex];
    Value inDimSizeDynamic =
        ShapedType::isDynamic(inDimSize)
            ? inDesc.size(b, loc, inDimIndex)
            : b.create<LLVM::ConstantOp>(loc, llvmIndexType,
                                         b.getIndexAttr(inDimSize));
    outDimSizeDynamic =
        b.create<LLVM::MulOp>(loc, outDimSizeDynamic, inDimSizeDynamic);
  }
  return outDimSizeDynamic;
}

static SmallVector<OpFoldResult, 4>
getCollapsedOutputShape(OpBuilder &b, Location loc, Type &llvmIndexType,
                        ArrayRef<ReassociationIndices> reassociation,
                        ArrayRef<int64_t> inStaticShape,
                        MemRefDescriptor &inDesc,
                        ArrayRef<int64_t> outStaticShape) {
  return llvm::to_vector<4>(llvm::map_range(
      llvm::seq<int64_t>(0, outStaticShape.size()), [&](int64_t outDimIndex) {
        return getCollapsedOutputDimSize(b, loc, llvmIndexType, outDimIndex,
                                         outStaticShape[outDimIndex],
                                         inStaticShape, inDesc, reassociation);
      }));
}

static SmallVector<OpFoldResult, 4>
getExpandedOutputShape(OpBuilder &b, Location loc, Type &llvmIndexType,
                       ArrayRef<ReassociationIndices> reassociation,
                       ArrayRef<int64_t> inStaticShape,
                       MemRefDescriptor &inDesc,
                       ArrayRef<int64_t> outStaticShape) {
  DenseMap<int64_t, int64_t> outDimToInDimMap =
      getExpandedDimToCollapsedDimMap(reassociation);
  return llvm::to_vector<4>(llvm::map_range(
      llvm::seq<int64_t>(0, outStaticShape.size()), [&](int64_t outDimIndex) {
        return getExpandedOutputDimSize(b, loc, llvmIndexType, outDimIndex,
                                        outStaticShape, inDesc, inStaticShape,
                                        reassociation, outDimToInDimMap);
      }));
}

static SmallVector<Value>
getDynamicOutputShape(OpBuilder &b, Location loc, Type &llvmIndexType,
                      ArrayRef<ReassociationIndices> reassociation,
                      ArrayRef<int64_t> inStaticShape, MemRefDescriptor &inDesc,
                      ArrayRef<int64_t> outStaticShape) {
  return outStaticShape.size() < inStaticShape.size()
             ? getAsValues(b, loc, llvmIndexType,
                           getCollapsedOutputShape(b, loc, llvmIndexType,
                                                   reassociation, inStaticShape,
                                                   inDesc, outStaticShape))
             : getAsValues(b, loc, llvmIndexType,
                           getExpandedOutputShape(b, loc, llvmIndexType,
                                                  reassociation, inStaticShape,
                                                  inDesc, outStaticShape));
}

static void fillInStridesForExpandedMemDescriptor(
    OpBuilder &b, Location loc, MemRefType srcType, MemRefDescriptor &srcDesc,
    MemRefDescriptor &dstDesc, ArrayRef<ReassociationIndices> reassociation) {
  // See comments for computeExpandedLayoutMap for details on how the strides
  // are calculated.
  for (auto &en : llvm::enumerate(reassociation)) {
    auto currentStrideToExpand = srcDesc.stride(b, loc, en.index());
    for (auto dstIndex : llvm::reverse(en.value())) {
      dstDesc.setStride(b, loc, dstIndex, currentStrideToExpand);
      Value size = dstDesc.size(b, loc, dstIndex);
      currentStrideToExpand =
          b.create<LLVM::MulOp>(loc, size, currentStrideToExpand);
    }
  }
}

static void fillInStridesForCollapsedMemDescriptor(
    ConversionPatternRewriter &rewriter, Location loc, Operation *op,
    TypeConverter *typeConverter, MemRefType srcType, MemRefDescriptor &srcDesc,
    MemRefDescriptor &dstDesc, ArrayRef<ReassociationIndices> reassociation) {
  // See comments for computeCollapsedLayoutMap for details on how the strides
  // are calculated.
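  // Illustrative example: collapsing the group [0, 1] of a ?x4 memref with
  // strides [4, 1] into a rank-1 memref uses the stride of the innermost
  // dimension of the group whose size is not 1 (here dimension 1, stride 1)
  // as the stride of the collapsed dimension.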

static void fillInStridesForCollapsedMemDescriptor(
    ConversionPatternRewriter &rewriter, Location loc, Operation *op,
    TypeConverter *typeConverter, MemRefType srcType, MemRefDescriptor &srcDesc,
    MemRefDescriptor &dstDesc, ArrayRef<ReassociationIndices> reassociation) {
  // See comments for computeCollapsedLayoutMap for details on how the strides
  // are calculated.
  auto srcShape = srcType.getShape();
  for (auto &en : llvm::enumerate(reassociation)) {
    rewriter.setInsertionPoint(op);
    auto dstIndex = en.index();
    ArrayRef<int64_t> ref = llvm::makeArrayRef(en.value());
    while (srcShape[ref.back()] == 1 && ref.size() > 1)
      ref = ref.drop_back();
    if (!ShapedType::isDynamic(srcShape[ref.back()]) || ref.size() == 1) {
      dstDesc.setStride(rewriter, loc, dstIndex,
                        srcDesc.stride(rewriter, loc, ref.back()));
    } else {
      // Iterate over the source strides in reverse order. Skip over the
      // dimensions whose size is 1.
      // TODO: we should take the minimum stride in the reassociation group
      // instead of just the first where the dimension is not 1.
      //
      // +------------------------------------------------------+
      // | curEntry:                                             |
      // |   %srcStride = strides[srcIndex]                      |
      // |   %neOne = cmp sizes[srcIndex], 1                     +--+
      // |   cf.cond_br %neOne, continue(%srcStride), nextEntry  |  |
      // +-------------------------+----------------------------+  |
      //                           |                                |
      //                           v                                |
      //             +-----------------------------+                |
      //             | nextEntry:                  |                |
      //             |   ...                       +---+            |
      //             +--------------+--------------+   |            |
      //                            |                  |            |
      //                            v                  |            |
      //             +-----------------------------+   |            |
      //             | nextEntry:                  |   |            |
      //             |   ...                       |   |            |
      //             +--------------+--------------+   | +----------+
      //                            |                  | |
      //                            v                  v v
      //        +--------------------------------------------------+
      //        | continue(%newStride):                            |
      //        |   %newMemRefDes = setStride(%newStride,dstIndex) |
      //        +--------------------------------------------------+
      OpBuilder::InsertionGuard guard(rewriter);
      Block *initBlock = rewriter.getInsertionBlock();
      Block *continueBlock =
          rewriter.splitBlock(initBlock, rewriter.getInsertionPoint());
      continueBlock->insertArgument(unsigned(0), srcDesc.getIndexType(), loc);
      rewriter.setInsertionPointToStart(continueBlock);
      dstDesc.setStride(rewriter, loc, dstIndex, continueBlock->getArgument(0));

      Block *curEntryBlock = initBlock;
      Block *nextEntryBlock;
      for (auto srcIndex : llvm::reverse(ref)) {
        if (srcShape[srcIndex] == 1 && srcIndex != ref.front())
          continue;
        rewriter.setInsertionPointToEnd(curEntryBlock);
        Value srcStride = srcDesc.stride(rewriter, loc, srcIndex);
        if (srcIndex == ref.front()) {
          rewriter.create<LLVM::BrOp>(loc, srcStride, continueBlock);
          break;
        }
        Value one = rewriter.create<LLVM::ConstantOp>(
            loc, typeConverter->convertType(rewriter.getI64Type()),
            rewriter.getI64IntegerAttr(1));
        Value predNeOne = rewriter.create<LLVM::ICmpOp>(
            loc, LLVM::ICmpPredicate::ne, srcDesc.size(rewriter, loc, srcIndex),
            one);
        {
          OpBuilder::InsertionGuard guard(rewriter);
          nextEntryBlock = rewriter.createBlock(
              initBlock->getParent(), Region::iterator(continueBlock), {});
        }
        rewriter.create<LLVM::CondBrOp>(loc, predNeOne, continueBlock,
                                        srcStride, nextEntryBlock, llvm::None);
        curEntryBlock = nextEntryBlock;
      }
    }
  }
}
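
// Roughly speaking (sketch of the collapse helper above): when a group of
// dynamically sized dimensions is collapsed, the emitted branches pick, at
// runtime, the stride of the innermost dimension of the group whose size is
// not 1, falling back to the stride of the group's outermost dimension when
// all inner sizes are 1.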

static void fillInDynamicStridesForMemDescriptor(
    ConversionPatternRewriter &b, Location loc, Operation *op,
    TypeConverter *typeConverter, MemRefType srcType, MemRefType dstType,
    MemRefDescriptor &srcDesc, MemRefDescriptor &dstDesc,
    ArrayRef<ReassociationIndices> reassociation) {
  if (srcType.getRank() > dstType.getRank())
    fillInStridesForCollapsedMemDescriptor(b, loc, op, typeConverter, srcType,
                                           srcDesc, dstDesc, reassociation);
  else
    fillInStridesForExpandedMemDescriptor(b, loc, srcType, srcDesc, dstDesc,
                                          reassociation);
}

// ReshapeOp (memref.expand_shape / memref.collapse_shape) creates a new view
// descriptor of the proper rank: the allocated and aligned pointers and the
// offset are copied from the source descriptor, while sizes and strides are
// recomputed, handling both the fully static and the dynamic cases.
template <typename ReshapeOp>
class ReassociatingReshapeOpConversion
    : public ConvertOpToLLVMPattern<ReshapeOp> {
public:
  using ConvertOpToLLVMPattern<ReshapeOp>::ConvertOpToLLVMPattern;
  using ReshapeOpAdaptor = typename ReshapeOp::Adaptor;

  LogicalResult
  matchAndRewrite(ReshapeOp reshapeOp, typename ReshapeOp::Adaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    MemRefType dstType = reshapeOp.getResultType();
    MemRefType srcType = reshapeOp.getSrcType();

    int64_t offset;
    SmallVector<int64_t, 4> strides;
    if (failed(getStridesAndOffset(dstType, strides, offset))) {
      return rewriter.notifyMatchFailure(
          reshapeOp, "failed to get stride and offset exprs");
    }

    MemRefDescriptor srcDesc(adaptor.src());
    Location loc = reshapeOp->getLoc();
    auto dstDesc = MemRefDescriptor::undef(
        rewriter, loc, this->typeConverter->convertType(dstType));
    dstDesc.setAllocatedPtr(rewriter, loc, srcDesc.allocatedPtr(rewriter, loc));
    dstDesc.setAlignedPtr(rewriter, loc, srcDesc.alignedPtr(rewriter, loc));
    dstDesc.setOffset(rewriter, loc, srcDesc.offset(rewriter, loc));

    ArrayRef<int64_t> srcStaticShape = srcType.getShape();
    ArrayRef<int64_t> dstStaticShape = dstType.getShape();
    Type llvmIndexType =
        this->typeConverter->convertType(rewriter.getIndexType());
    SmallVector<Value> dstShape = getDynamicOutputShape(
        rewriter, loc, llvmIndexType, reshapeOp.getReassociationIndices(),
        srcStaticShape, srcDesc, dstStaticShape);
    for (auto &en : llvm::enumerate(dstShape))
      dstDesc.setSize(rewriter, loc, en.index(), en.value());

    if (llvm::all_of(strides, isStaticStrideOrOffset)) {
      for (auto &en : llvm::enumerate(strides))
        dstDesc.setConstantStride(rewriter, loc, en.index(), en.value());
    } else if (srcType.getLayout().isIdentity() &&
               dstType.getLayout().isIdentity()) {
      Value c1 = rewriter.create<LLVM::ConstantOp>(loc, llvmIndexType,
                                                   rewriter.getIndexAttr(1));
      Value stride = c1;
      for (auto dimIndex :
           llvm::reverse(llvm::seq<int64_t>(0, dstShape.size()))) {
        dstDesc.setStride(rewriter, loc, dimIndex, stride);
        stride = rewriter.create<LLVM::MulOp>(loc, dstShape[dimIndex], stride);
      }
    } else {
      // There could be mixed static/dynamic strides. For simplicity, we
      // recompute all strides if there is at least one dynamic stride.
      fillInDynamicStridesForMemDescriptor(
          rewriter, loc, reshapeOp, this->typeConverter, srcType, dstType,
          srcDesc, dstDesc, reshapeOp.getReassociationIndices());
    }
    rewriter.replaceOp(reshapeOp, {dstDesc});
    return success();
  }
};
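
// Illustrative sketch for the pattern above (not verbatim IR): lowering
//   %r = memref.collapse_shape %m [[0, 1]] : memref<4x8xf32> into memref<32xf32>
// moves no data; only a new descriptor is built whose sizes and strides are
// recomputed from %m's descriptor as outlined above.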

/// Conversion pattern that transforms a subview op into:
///   1. An `llvm.mlir.undef` operation to create a memref descriptor
///   2. Updates to the descriptor to introduce the data ptr, offset, size
///      and stride.
/// The subview op is replaced by the descriptor.
struct SubViewOpLowering : public ConvertOpToLLVMPattern<memref::SubViewOp> {
  using ConvertOpToLLVMPattern<memref::SubViewOp>::ConvertOpToLLVMPattern;

  LogicalResult
  matchAndRewrite(memref::SubViewOp subViewOp, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    auto loc = subViewOp.getLoc();

    auto sourceMemRefType = subViewOp.source().getType().cast<MemRefType>();
    auto sourceElementTy =
        typeConverter->convertType(sourceMemRefType.getElementType());

    auto viewMemRefType = subViewOp.getType();
    auto inferredType =
        memref::SubViewOp::inferResultType(
            subViewOp.getSourceType(),
            extractFromI64ArrayAttr(subViewOp.static_offsets()),
            extractFromI64ArrayAttr(subViewOp.static_sizes()),
            extractFromI64ArrayAttr(subViewOp.static_strides()))
            .cast<MemRefType>();
    auto targetElementTy =
        typeConverter->convertType(viewMemRefType.getElementType());
    auto targetDescTy = typeConverter->convertType(viewMemRefType);
    if (!sourceElementTy || !targetDescTy || !targetElementTy ||
        !LLVM::isCompatibleType(sourceElementTy) ||
        !LLVM::isCompatibleType(targetElementTy) ||
        !LLVM::isCompatibleType(targetDescTy))
      return failure();

    // Extract the offset and strides from the type.
    int64_t offset;
    SmallVector<int64_t, 4> strides;
    auto successStrides = getStridesAndOffset(inferredType, strides, offset);
    if (failed(successStrides))
      return failure();

    // Create the descriptor.
    if (!LLVM::isCompatibleType(adaptor.getOperands().front().getType()))
      return failure();
    MemRefDescriptor sourceMemRef(adaptor.getOperands().front());
    auto targetMemRef = MemRefDescriptor::undef(rewriter, loc, targetDescTy);

    // Copy the buffer pointer from the old descriptor to the new one.
    Value extracted = sourceMemRef.allocatedPtr(rewriter, loc);
    Value bitcastPtr = rewriter.create<LLVM::BitcastOp>(
        loc,
        LLVM::LLVMPointerType::get(targetElementTy,
                                   viewMemRefType.getMemorySpaceAsInt()),
        extracted);
    targetMemRef.setAllocatedPtr(rewriter, loc, bitcastPtr);

    // Copy the aligned pointer from the old descriptor to the new one.
    extracted = sourceMemRef.alignedPtr(rewriter, loc);
    bitcastPtr = rewriter.create<LLVM::BitcastOp>(
        loc,
        LLVM::LLVMPointerType::get(targetElementTy,
                                   viewMemRefType.getMemorySpaceAsInt()),
        extracted);
    targetMemRef.setAlignedPtr(rewriter, loc, bitcastPtr);

    size_t inferredShapeRank = inferredType.getRank();
    size_t resultShapeRank = viewMemRefType.getRank();

    // Extract strides needed to compute offset.
    SmallVector<Value, 4> strideValues;
    strideValues.reserve(inferredShapeRank);
    for (unsigned i = 0; i < inferredShapeRank; ++i)
      strideValues.push_back(sourceMemRef.stride(rewriter, loc, i));

    // Offset.
    auto llvmIndexType = typeConverter->convertType(rewriter.getIndexType());
    if (!ShapedType::isDynamicStrideOrOffset(offset)) {
      targetMemRef.setConstantOffset(rewriter, loc, offset);
    } else {
      Value baseOffset = sourceMemRef.offset(rewriter, loc);
      // `inferredShapeRank` may be larger than the number of offset operands
      // because of trailing semantics. In this case, the offset is guaranteed
      // to be interpreted as 0 and we can just skip the extra dimensions.
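      // In other words (sketch): for offsets (%o0, %o1) and source strides
      // (%s0, %s1), the dynamic offset becomes
      //   baseOffset + %o0 * %s0 + %o1 * %s1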
      for (unsigned i = 0, e = std::min(inferredShapeRank,
                                        subViewOp.getMixedOffsets().size());
           i < e; ++i) {
        Value offset =
            // TODO: need OpFoldResult ODS adaptor to clean this up.
            subViewOp.isDynamicOffset(i)
                ? adaptor.getOperands()[subViewOp.getIndexOfDynamicOffset(i)]
                : rewriter.create<LLVM::ConstantOp>(
                      loc, llvmIndexType,
                      rewriter.getI64IntegerAttr(subViewOp.getStaticOffset(i)));
        Value mul = rewriter.create<LLVM::MulOp>(loc, offset, strideValues[i]);
        baseOffset = rewriter.create<LLVM::AddOp>(loc, baseOffset, mul);
      }
      targetMemRef.setOffset(rewriter, loc, baseOffset);
    }

    // Update sizes and strides.
    SmallVector<OpFoldResult> mixedSizes = subViewOp.getMixedSizes();
    SmallVector<OpFoldResult> mixedStrides = subViewOp.getMixedStrides();
    assert(mixedSizes.size() == mixedStrides.size() &&
           "expected sizes and strides of equal length");
    llvm::SmallBitVector unusedDims = subViewOp.getDroppedDims();
    for (int i = inferredShapeRank - 1, j = resultShapeRank - 1;
         i >= 0 && j >= 0; --i) {
      if (unusedDims.test(i))
        continue;

      // `i` may exceed the number of entries in subViewOp.getMixedSizes()
      // because of trailing semantics. In this case, the size is guaranteed to
      // be interpreted as Dim and the stride as 1.
      Value size, stride;
      if (static_cast<unsigned>(i) >= mixedSizes.size()) {
        // If the static size is available, use it directly. This is similar to
        // the folding of dim(constant-op) but removes the need for dim to be
        // aware of LLVM constants and for this pass to be aware of std
        // constants.
        int64_t staticSize =
            subViewOp.source().getType().cast<MemRefType>().getShape()[i];
        if (staticSize != ShapedType::kDynamicSize) {
          size = rewriter.create<LLVM::ConstantOp>(
              loc, llvmIndexType, rewriter.getI64IntegerAttr(staticSize));
        } else {
          Value pos = rewriter.create<LLVM::ConstantOp>(
              loc, llvmIndexType, rewriter.getI64IntegerAttr(i));
          Value dim =
              rewriter.create<memref::DimOp>(loc, subViewOp.source(), pos);
          auto cast = rewriter.create<UnrealizedConversionCastOp>(
              loc, llvmIndexType, dim);
          size = cast.getResult(0);
        }
        stride = rewriter.create<LLVM::ConstantOp>(
            loc, llvmIndexType, rewriter.getI64IntegerAttr(1));
      } else {
        // TODO: need OpFoldResult ODS adaptor to clean this up.
        size =
            subViewOp.isDynamicSize(i)
                ? adaptor.getOperands()[subViewOp.getIndexOfDynamicSize(i)]
                : rewriter.create<LLVM::ConstantOp>(
                      loc, llvmIndexType,
                      rewriter.getI64IntegerAttr(subViewOp.getStaticSize(i)));
        if (!ShapedType::isDynamicStrideOrOffset(strides[i])) {
          stride = rewriter.create<LLVM::ConstantOp>(
              loc, llvmIndexType, rewriter.getI64IntegerAttr(strides[i]));
        } else {
          stride =
              subViewOp.isDynamicStride(i)
                  ? adaptor.getOperands()[subViewOp.getIndexOfDynamicStride(i)]
                  : rewriter.create<LLVM::ConstantOp>(
                        loc, llvmIndexType,
                        rewriter.getI64IntegerAttr(
                            subViewOp.getStaticStride(i)));
          stride = rewriter.create<LLVM::MulOp>(loc, stride, strideValues[i]);
        }
      }
      targetMemRef.setSize(rewriter, loc, j, size);
      targetMemRef.setStride(rewriter, loc, j, stride);
      j--;
    }

    rewriter.replaceOp(subViewOp, {targetMemRef});
    return success();
  }
};
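
// Illustrative sketch for the subview lowering above (not verbatim IR): for
//   %v = memref.subview %m[2, 4] [8, 8] [1, 1] : memref<16x16xf32> to ...
// only the descriptor is rewritten: offset 2 * 16 + 4 = 36, sizes (8, 8) and
// strides (16, 1); the allocated and aligned pointers are carried over
// (modulo a bitcast).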

/// Conversion pattern that transforms a transpose op into:
///   1. An `llvm.mlir.undef` operation to create a memref descriptor
///   2. Updates to the descriptor to introduce the data ptr, offset, and the
///      permuted sizes and strides.
/// The transpose op is replaced by the descriptor.
class TransposeOpLowering : public ConvertOpToLLVMPattern<memref::TransposeOp> {
public:
  using ConvertOpToLLVMPattern<memref::TransposeOp>::ConvertOpToLLVMPattern;

  LogicalResult
  matchAndRewrite(memref::TransposeOp transposeOp, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    auto loc = transposeOp.getLoc();
    MemRefDescriptor viewMemRef(adaptor.in());

    // No permutation, early exit.
    if (transposeOp.permutation().isIdentity())
      return rewriter.replaceOp(transposeOp, {viewMemRef}), success();

    auto targetMemRef = MemRefDescriptor::undef(
        rewriter, loc, typeConverter->convertType(transposeOp.getShapedType()));

    // Copy the base and aligned pointers from the old descriptor to the new
    // one.
    targetMemRef.setAllocatedPtr(rewriter, loc,
                                 viewMemRef.allocatedPtr(rewriter, loc));
    targetMemRef.setAlignedPtr(rewriter, loc,
                               viewMemRef.alignedPtr(rewriter, loc));

    // Copy the offset from the old descriptor to the new one.
    targetMemRef.setOffset(rewriter, loc, viewMemRef.offset(rewriter, loc));

    // Iterate over the dimensions and apply size/stride permutation.
    for (const auto &en :
         llvm::enumerate(transposeOp.permutation().getResults())) {
      int sourcePos = en.index();
      int targetPos = en.value().cast<AffineDimExpr>().getPosition();
      targetMemRef.setSize(rewriter, loc, targetPos,
                           viewMemRef.size(rewriter, loc, sourcePos));
      targetMemRef.setStride(rewriter, loc, targetPos,
                             viewMemRef.stride(rewriter, loc, sourcePos));
    }

    rewriter.replaceOp(transposeOp, {targetMemRef});
    return success();
  }
};

/// Conversion pattern that transforms a view op into:
///   1. An `llvm.mlir.undef` operation to create a memref descriptor
///   2. Updates to the descriptor to introduce the data ptr, offset, size
///      and stride.
/// The view op is replaced by the descriptor.
struct ViewOpLowering : public ConvertOpToLLVMPattern<memref::ViewOp> {
  using ConvertOpToLLVMPattern<memref::ViewOp>::ConvertOpToLLVMPattern;

  // Build and return the value for the idx^th shape dimension, either by
  // returning the constant shape dimension or by selecting the proper dynamic
  // size operand.
  Value getSize(ConversionPatternRewriter &rewriter, Location loc,
                ArrayRef<int64_t> shape, ValueRange dynamicSizes,
                unsigned idx) const {
    assert(idx < shape.size());
    if (!ShapedType::isDynamic(shape[idx]))
      return createIndexConstant(rewriter, loc, shape[idx]);
    // Count the number of dynamic dims in range [0, idx).
    unsigned nDynamic = llvm::count_if(shape.take_front(idx), [](int64_t v) {
      return ShapedType::isDynamic(v);
    });
    return dynamicSizes[nDynamic];
  }
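
  // For example (sketch): with shape (4, ?, 8, ?) and dynamicSizes (%a, %b),
  // getSize returns a constant 4 for idx 0, %a for idx 1, a constant 8 for
  // idx 2, and %b for idx 3.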

  // Build and return the idx^th stride, either by returning the constant
  // stride or by computing the dynamic stride from the current
  // `runningStride` and `nextSize`. The caller should keep a running stride
  // and update it with the result returned by this function.
  Value getStride(ConversionPatternRewriter &rewriter, Location loc,
                  ArrayRef<int64_t> strides, Value nextSize,
                  Value runningStride, unsigned idx) const {
    assert(idx < strides.size());
    if (!ShapedType::isDynamicStrideOrOffset(strides[idx]))
      return createIndexConstant(rewriter, loc, strides[idx]);
    if (nextSize)
      return runningStride
                 ? rewriter.create<LLVM::MulOp>(loc, runningStride, nextSize)
                 : nextSize;
    assert(!runningStride);
    return createIndexConstant(rewriter, loc, 1);
  }

  LogicalResult
  matchAndRewrite(memref::ViewOp viewOp, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    auto loc = viewOp.getLoc();

    auto viewMemRefType = viewOp.getType();
    auto targetElementTy =
        typeConverter->convertType(viewMemRefType.getElementType());
    auto targetDescTy = typeConverter->convertType(viewMemRefType);
    if (!targetDescTy || !targetElementTy ||
        !LLVM::isCompatibleType(targetElementTy) ||
        !LLVM::isCompatibleType(targetDescTy))
      return viewOp.emitWarning("target descriptor type not converted to LLVM"),
             failure();

    int64_t offset;
    SmallVector<int64_t, 4> strides;
    auto successStrides = getStridesAndOffset(viewMemRefType, strides, offset);
    if (failed(successStrides))
      return viewOp.emitWarning("cannot cast to non-strided shape"), failure();
    assert(offset == 0 && "expected offset to be 0");

    // Target memref must be contiguous in memory (innermost stride is 1), or
    // empty (special case when at least one of the memref dimensions is 0).
    if (!strides.empty() && (strides.back() != 1 && strides.back() != 0))
      return viewOp.emitWarning("cannot cast to non-contiguous shape"),
             failure();

    // Create the descriptor.
    MemRefDescriptor sourceMemRef(adaptor.source());
    auto targetMemRef = MemRefDescriptor::undef(rewriter, loc, targetDescTy);

    // Field 1: Copy the allocated pointer, used for malloc/free.
    Value allocatedPtr = sourceMemRef.allocatedPtr(rewriter, loc);
    auto srcMemRefType = viewOp.source().getType().cast<MemRefType>();
    Value bitcastPtr = rewriter.create<LLVM::BitcastOp>(
        loc,
        LLVM::LLVMPointerType::get(targetElementTy,
                                   srcMemRefType.getMemorySpaceAsInt()),
        allocatedPtr);
    targetMemRef.setAllocatedPtr(rewriter, loc, bitcastPtr);

    // Field 2: Copy the actual aligned pointer to payload.
    Value alignedPtr = sourceMemRef.alignedPtr(rewriter, loc);
    alignedPtr = rewriter.create<LLVM::GEPOp>(loc, alignedPtr.getType(),
                                              alignedPtr, adaptor.byte_shift());
    bitcastPtr = rewriter.create<LLVM::BitcastOp>(
        loc,
        LLVM::LLVMPointerType::get(targetElementTy,
                                   srcMemRefType.getMemorySpaceAsInt()),
        alignedPtr);
    targetMemRef.setAlignedPtr(rewriter, loc, bitcastPtr);

    // Field 3: The offset in the resulting type must be 0. This is because of
    // the type change: an offset on srcType* may not be expressible as an
    // offset on dstType*.
    targetMemRef.setOffset(rewriter, loc,
                           createIndexConstant(rewriter, loc, offset));

    // Early exit for 0-D corner case.
    if (viewMemRefType.getRank() == 0)
      return rewriter.replaceOp(viewOp, {targetMemRef}), success();

    // Fields 4 and 5: Update sizes and strides.
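    // Strides are rebuilt right to left as a running product of the sizes
    // (sketch): for a contiguous memref<?x4xf32> view this yields strides
    // (4, 1), with dynamic strides computed via mul ops on the size values.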
    Value stride = nullptr, nextSize = nullptr;
    for (int i = viewMemRefType.getRank() - 1; i >= 0; --i) {
      // Update size.
      Value size =
          getSize(rewriter, loc, viewMemRefType.getShape(), adaptor.sizes(), i);
      targetMemRef.setSize(rewriter, loc, i, size);
      // Update stride.
      stride = getStride(rewriter, loc, strides, nextSize, stride, i);
      targetMemRef.setStride(rewriter, loc, i, stride);
      nextSize = size;
    }

    rewriter.replaceOp(viewOp, {targetMemRef});
    return success();
  }
};

//===----------------------------------------------------------------------===//
// AtomicRMWOpLowering
//===----------------------------------------------------------------------===//

/// Try to match the kind of a memref.atomic_rmw to determine whether to use a
/// lowering to llvm.atomicrmw or fall back to llvm.cmpxchg.
static Optional<LLVM::AtomicBinOp>
matchSimpleAtomicOp(memref::AtomicRMWOp atomicOp) {
  switch (atomicOp.kind()) {
  case arith::AtomicRMWKind::addf:
    return LLVM::AtomicBinOp::fadd;
  case arith::AtomicRMWKind::addi:
    return LLVM::AtomicBinOp::add;
  case arith::AtomicRMWKind::assign:
    return LLVM::AtomicBinOp::xchg;
  case arith::AtomicRMWKind::maxs:
    return LLVM::AtomicBinOp::max;
  case arith::AtomicRMWKind::maxu:
    return LLVM::AtomicBinOp::umax;
  case arith::AtomicRMWKind::mins:
    return LLVM::AtomicBinOp::min;
  case arith::AtomicRMWKind::minu:
    return LLVM::AtomicBinOp::umin;
  case arith::AtomicRMWKind::ori:
    return LLVM::AtomicBinOp::_or;
  case arith::AtomicRMWKind::andi:
    return LLVM::AtomicBinOp::_and;
  default:
    return llvm::None;
  }
  llvm_unreachable("Invalid AtomicRMWKind");
}

struct AtomicRMWOpLowering : public LoadStoreOpLowering<memref::AtomicRMWOp> {
  using Base::Base;

  LogicalResult
  matchAndRewrite(memref::AtomicRMWOp atomicOp, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    if (failed(match(atomicOp)))
      return failure();
    auto maybeKind = matchSimpleAtomicOp(atomicOp);
    if (!maybeKind)
      return failure();
    auto resultType = adaptor.value().getType();
    auto memRefType = atomicOp.getMemRefType();
    auto dataPtr =
        getStridedElementPtr(atomicOp.getLoc(), memRefType, adaptor.memref(),
                             adaptor.indices(), rewriter);
    rewriter.replaceOpWithNewOp<LLVM::AtomicRMWOp>(
        atomicOp, resultType, *maybeKind, dataPtr, adaptor.value(),
        LLVM::AtomicOrdering::acq_rel);
    return success();
  }
};

} // namespace

void mlir::populateMemRefToLLVMConversionPatterns(LLVMTypeConverter &converter,
                                                  RewritePatternSet &patterns) {
  // clang-format off
  patterns.add<
      AllocaOpLowering,
      AllocaScopeOpLowering,
      AtomicRMWOpLowering,
      AssumeAlignmentOpLowering,
      DimOpLowering,
      GenericAtomicRMWOpLowering,
      GlobalMemrefOpLowering,
      GetGlobalMemrefOpLowering,
      LoadOpLowering,
      MemRefCastOpLowering,
      MemRefCopyOpLowering,
      MemRefReinterpretCastOpLowering,
      MemRefReshapeOpLowering,
      PrefetchOpLowering,
      RankOpLowering,
      ReassociatingReshapeOpConversion<memref::ExpandShapeOp>,
      ReassociatingReshapeOpConversion<memref::CollapseShapeOp>,
      StoreOpLowering,
      SubViewOpLowering,
      TransposeOpLowering,
      ViewOpLowering>(converter);
  // clang-format on
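  // memref.alloc/memref.dealloc are registered separately below because the
  // alloc lowering depends on the pass options (aligned_alloc vs. malloc).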
  auto allocLowering = converter.getOptions().allocLowering;
  if (allocLowering == LowerToLLVMOptions::AllocLowering::AlignedAlloc)
    patterns.add<AlignedAllocOpLowering, DeallocOpLowering>(converter);
  else if (allocLowering == LowerToLLVMOptions::AllocLowering::Malloc)
    patterns.add<AllocOpLowering, DeallocOpLowering>(converter);
}

namespace {
struct MemRefToLLVMPass : public ConvertMemRefToLLVMBase<MemRefToLLVMPass> {
  MemRefToLLVMPass() = default;

  void runOnOperation() override {
    Operation *op = getOperation();
    const auto &dataLayoutAnalysis = getAnalysis<DataLayoutAnalysis>();
    LowerToLLVMOptions options(&getContext(),
                               dataLayoutAnalysis.getAtOrAbove(op));
    options.allocLowering =
        (useAlignedAlloc ? LowerToLLVMOptions::AllocLowering::AlignedAlloc
                         : LowerToLLVMOptions::AllocLowering::Malloc);
    if (indexBitwidth != kDeriveIndexBitwidthFromDataLayout)
      options.overrideIndexBitwidth(indexBitwidth);

    LLVMTypeConverter typeConverter(&getContext(), options,
                                    &dataLayoutAnalysis);
    RewritePatternSet patterns(&getContext());
    populateMemRefToLLVMConversionPatterns(typeConverter, patterns);
    LLVMConversionTarget target(getContext());
    target.addLegalOp<func::FuncOp>();
    if (failed(applyPartialConversion(op, target, std::move(patterns))))
      signalPassFailure();
  }
};
} // namespace

std::unique_ptr<Pass> mlir::createMemRefToLLVMPass() {
  return std::make_unique<MemRefToLLVMPass>();
}