//===- MemRefToLLVM.cpp - MemRef to LLVM dialect conversion ---------------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//

#include "mlir/Conversion/MemRefToLLVM/MemRefToLLVM.h"
#include "../PassDetail.h"
#include "mlir/Analysis/DataLayoutAnalysis.h"
#include "mlir/Conversion/LLVMCommon/ConversionTarget.h"
#include "mlir/Conversion/LLVMCommon/Pattern.h"
#include "mlir/Conversion/LLVMCommon/TypeConverter.h"
#include "mlir/Conversion/MemRefToLLVM/AllocLikeConversion.h"
#include "mlir/Dialect/Arithmetic/IR/Arithmetic.h"
#include "mlir/Dialect/Func/IR/FuncOps.h"
#include "mlir/Dialect/LLVMIR/FunctionCallUtils.h"
#include "mlir/Dialect/LLVMIR/LLVMDialect.h"
#include "mlir/Dialect/MemRef/IR/MemRef.h"
#include "mlir/IR/AffineMap.h"
#include "mlir/IR/BlockAndValueMapping.h"
#include "llvm/ADT/SmallBitVector.h"

using namespace mlir;

namespace {

struct AllocOpLowering : public AllocLikeOpLLVMLowering {
  AllocOpLowering(LLVMTypeConverter &converter)
      : AllocLikeOpLLVMLowering(memref::AllocOp::getOperationName(),
                                converter) {}

  std::tuple<Value, Value> allocateBuffer(ConversionPatternRewriter &rewriter,
                                          Location loc, Value sizeBytes,
                                          Operation *op) const override {
    // Heap allocations.
    memref::AllocOp allocOp = cast<memref::AllocOp>(op);
    MemRefType memRefType = allocOp.getType();

    Value alignment;
    if (auto alignmentAttr = allocOp.alignment()) {
      alignment = createIndexConstant(rewriter, loc, *alignmentAttr);
    } else if (!memRefType.getElementType().isSignlessIntOrIndexOrFloat()) {
      // In the case where no alignment is specified, we may want to override
      // `malloc's` behavior. `malloc` typically aligns at the size of the
      // biggest scalar on a target HW. For non-scalars, use the natural
      // alignment of the LLVM type given by the LLVM DataLayout.
      alignment = getSizeInBytes(loc, memRefType.getElementType(), rewriter);
    }

    if (alignment) {
      // Adjust the allocation size to consider alignment.
      sizeBytes = rewriter.create<LLVM::AddOp>(loc, sizeBytes, alignment);
    }

    // Allocate the underlying buffer and store a pointer to it in the MemRef
    // descriptor.
    Type elementPtrType = this->getElementPtrType(memRefType);
    auto allocFuncOp = LLVM::lookupOrCreateMallocFn(
        allocOp->getParentOfType<ModuleOp>(), getIndexType());
    auto results = createLLVMCall(rewriter, loc, allocFuncOp, {sizeBytes},
                                  getVoidPtrType());
    Value allocatedPtr =
        rewriter.create<LLVM::BitcastOp>(loc, elementPtrType, results[0]);

    Value alignedPtr = allocatedPtr;
    if (alignment) {
      // Compute the aligned type pointer.
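      // createAligned rounds the address up to the requested alignment in the
      // integer domain, essentially:
      //   aligned = (allocated + align - 1) - ((allocated + align - 1) % align)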
      Value allocatedInt =
          rewriter.create<LLVM::PtrToIntOp>(loc, getIndexType(), allocatedPtr);
      Value alignmentInt =
          createAligned(rewriter, loc, allocatedInt, alignment);
      alignedPtr =
          rewriter.create<LLVM::IntToPtrOp>(loc, elementPtrType, alignmentInt);
    }

    return std::make_tuple(allocatedPtr, alignedPtr);
  }
};

struct AlignedAllocOpLowering : public AllocLikeOpLLVMLowering {
  AlignedAllocOpLowering(LLVMTypeConverter &converter)
      : AllocLikeOpLLVMLowering(memref::AllocOp::getOperationName(),
                                converter) {}

  /// Returns the memref's element size in bytes using the data layout active
  /// at `op`.
  // TODO: there are other places where this is used. Expose publicly?
  unsigned getMemRefEltSizeInBytes(MemRefType memRefType, Operation *op) const {
    const DataLayout *layout = &defaultLayout;
    if (const DataLayoutAnalysis *analysis =
            getTypeConverter()->getDataLayoutAnalysis()) {
      layout = &analysis->getAbove(op);
    }
    Type elementType = memRefType.getElementType();
    if (auto memRefElementType = elementType.dyn_cast<MemRefType>())
      return getTypeConverter()->getMemRefDescriptorSize(memRefElementType,
                                                         *layout);
    if (auto memRefElementType = elementType.dyn_cast<UnrankedMemRefType>())
      return getTypeConverter()->getUnrankedMemRefDescriptorSize(
          memRefElementType, *layout);
    return layout->getTypeSize(elementType);
  }

  /// Returns true if the memref size in bytes is known to be a multiple of
  /// `factor`, assuming the data layout active at `op`.
  bool isMemRefSizeMultipleOf(MemRefType type, uint64_t factor,
                              Operation *op) const {
    uint64_t sizeDivisor = getMemRefEltSizeInBytes(type, op);
    for (unsigned i = 0, e = type.getRank(); i < e; i++) {
      if (ShapedType::isDynamic(type.getDimSize(i)))
        continue;
      sizeDivisor = sizeDivisor * type.getDimSize(i);
    }
    return sizeDivisor % factor == 0;
  }

  /// Returns the alignment to be used for the allocation call itself.
  /// aligned_alloc requires the alignment to be a power of two and the
  /// allocation size to be a multiple of that alignment.
  int64_t getAllocationAlignment(memref::AllocOp allocOp) const {
    if (Optional<uint64_t> alignment = allocOp.alignment())
      return *alignment;

    // Whenever we don't have alignment set, we will use an alignment
    // consistent with the element type; since the alignment has to be a
    // power of two, we will bump to the next power of two if it isn't
    // already.
    auto eltSizeBytes = getMemRefEltSizeInBytes(allocOp.getType(), allocOp);
    return std::max(kMinAlignedAllocAlignment,
                    llvm::PowerOf2Ceil(eltSizeBytes));
  }

  std::tuple<Value, Value> allocateBuffer(ConversionPatternRewriter &rewriter,
                                          Location loc, Value sizeBytes,
                                          Operation *op) const override {
    // Heap allocations.
    memref::AllocOp allocOp = cast<memref::AllocOp>(op);
    MemRefType memRefType = allocOp.getType();
    int64_t alignment = getAllocationAlignment(allocOp);
    Value allocAlignment = createIndexConstant(rewriter, loc, alignment);

    // aligned_alloc requires size to be a multiple of alignment; we will pad
    // the size to the next multiple if necessary.
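    // createAligned rounds sizeBytes up to the next multiple of the alignment,
    // e.g. a 24-byte allocation with 16-byte alignment is bumped to 32 bytes.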
    if (!isMemRefSizeMultipleOf(memRefType, alignment, op))
      sizeBytes = createAligned(rewriter, loc, sizeBytes, allocAlignment);

    Type elementPtrType = this->getElementPtrType(memRefType);
    auto allocFuncOp = LLVM::lookupOrCreateAlignedAllocFn(
        allocOp->getParentOfType<ModuleOp>(), getIndexType());
    auto results =
        createLLVMCall(rewriter, loc, allocFuncOp, {allocAlignment, sizeBytes},
                       getVoidPtrType());
    Value allocatedPtr =
        rewriter.create<LLVM::BitcastOp>(loc, elementPtrType, results[0]);

    return std::make_tuple(allocatedPtr, allocatedPtr);
  }

  /// The minimum alignment to use with aligned_alloc (has to be a power of 2).
  static constexpr uint64_t kMinAlignedAllocAlignment = 16UL;

  /// Default layout to use in absence of the corresponding analysis.
  DataLayout defaultLayout;
};

// Out of line definition, required till C++17.
constexpr uint64_t AlignedAllocOpLowering::kMinAlignedAllocAlignment;

struct AllocaOpLowering : public AllocLikeOpLLVMLowering {
  AllocaOpLowering(LLVMTypeConverter &converter)
      : AllocLikeOpLLVMLowering(memref::AllocaOp::getOperationName(),
                                converter) {}

  /// Allocates the underlying buffer with an llvm.alloca. Since alloca already
  /// returns a usable pointer, the allocated and aligned pointers are the
  /// same.
  std::tuple<Value, Value> allocateBuffer(ConversionPatternRewriter &rewriter,
                                          Location loc, Value sizeBytes,
                                          Operation *op) const override {

    // With alloca, one gets a pointer to the element type right away for stack
    // allocations.
    auto allocaOp = cast<memref::AllocaOp>(op);
    auto elementPtrType = this->getElementPtrType(allocaOp.getType());

    auto allocatedElementPtr = rewriter.create<LLVM::AllocaOp>(
        loc, elementPtrType, sizeBytes,
        allocaOp.alignment() ? *allocaOp.alignment() : 0);

    return std::make_tuple(allocatedElementPtr, allocatedElementPtr);
  }
};

struct AllocaScopeOpLowering
    : public ConvertOpToLLVMPattern<memref::AllocaScopeOp> {
  using ConvertOpToLLVMPattern<memref::AllocaScopeOp>::ConvertOpToLLVMPattern;

  LogicalResult
  matchAndRewrite(memref::AllocaScopeOp allocaScopeOp, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    OpBuilder::InsertionGuard guard(rewriter);
    Location loc = allocaScopeOp.getLoc();

    // Split the current block before the AllocaScopeOp to create the inlining
    // point.
    auto *currentBlock = rewriter.getInsertionBlock();
    auto *remainingOpsBlock =
        rewriter.splitBlock(currentBlock, rewriter.getInsertionPoint());
    Block *continueBlock;
    if (allocaScopeOp.getNumResults() == 0) {
      continueBlock = remainingOpsBlock;
    } else {
      continueBlock = rewriter.createBlock(
          remainingOpsBlock, allocaScopeOp.getResultTypes(),
          SmallVector<Location>(allocaScopeOp->getNumResults(),
                                allocaScopeOp.getLoc()));
      rewriter.create<LLVM::BrOp>(loc, ValueRange(), remainingOpsBlock);
    }

    // Inline body region.
    Block *beforeBody = &allocaScopeOp.bodyRegion().front();
    Block *afterBody = &allocaScopeOp.bodyRegion().back();
    rewriter.inlineRegionBefore(allocaScopeOp.bodyRegion(), continueBlock);

    // Save stack and then branch into the body of the region.
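    // The llvm.intr.stacksave / llvm.intr.stackrestore pair brackets the
    // inlined body so that any alloca emitted inside the scope is released
    // when control leaves it.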
    rewriter.setInsertionPointToEnd(currentBlock);
    auto stackSaveOp =
        rewriter.create<LLVM::StackSaveOp>(loc, getVoidPtrType());
    rewriter.create<LLVM::BrOp>(loc, ValueRange(), beforeBody);

    // Replace the alloca_scope return with a branch that jumps out of the
    // body. Stack restore before leaving the body region.
    rewriter.setInsertionPointToEnd(afterBody);
    auto returnOp =
        cast<memref::AllocaScopeReturnOp>(afterBody->getTerminator());
    auto branchOp = rewriter.replaceOpWithNewOp<LLVM::BrOp>(
        returnOp, returnOp.results(), continueBlock);

    // Insert stack restore before jumping out of the body of the region.
    rewriter.setInsertionPoint(branchOp);
    rewriter.create<LLVM::StackRestoreOp>(loc, stackSaveOp);

    // Replace the op with the values returned from the body region.
    rewriter.replaceOp(allocaScopeOp, continueBlock->getArguments());

    return success();
  }
};

struct AssumeAlignmentOpLowering
    : public ConvertOpToLLVMPattern<memref::AssumeAlignmentOp> {
  using ConvertOpToLLVMPattern<
      memref::AssumeAlignmentOp>::ConvertOpToLLVMPattern;

  LogicalResult
  matchAndRewrite(memref::AssumeAlignmentOp op, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    Value memref = adaptor.memref();
    unsigned alignment = op.alignment();
    auto loc = op.getLoc();

    MemRefDescriptor memRefDescriptor(memref);
    Value ptr = memRefDescriptor.alignedPtr(rewriter, memref.getLoc());

    // Emit llvm.assume(memref.alignedPtr & (alignment - 1) == 0). Notice that
    // the asserted memref.alignedPtr isn't used anywhere else, as the real
    // users like load/store/views always re-extract memref.alignedPtr as they
    // get lowered.
    //
    // This relies on LLVM's CSE optimization (potentially after SROA), since
    // after CSE all memref.alignedPtr instances get de-duplicated into the same
    // pointer SSA value.
    auto intPtrType =
        getIntPtrType(memRefDescriptor.getElementPtrType().getAddressSpace());
    Value zero = createIndexAttrConstant(rewriter, loc, intPtrType, 0);
    Value mask =
        createIndexAttrConstant(rewriter, loc, intPtrType, alignment - 1);
    Value ptrValue = rewriter.create<LLVM::PtrToIntOp>(loc, intPtrType, ptr);
    rewriter.create<LLVM::AssumeOp>(
        loc, rewriter.create<LLVM::ICmpOp>(
                 loc, LLVM::ICmpPredicate::eq,
                 rewriter.create<LLVM::AndOp>(loc, ptrValue, mask), zero));

    rewriter.eraseOp(op);
    return success();
  }
};

// A `dealloc` is converted into a call to `free` on the underlying data
// buffer. Since the memref descriptor is an SSA value, there is no need to
// clean it up in any way.
struct DeallocOpLowering : public ConvertOpToLLVMPattern<memref::DeallocOp> {
  using ConvertOpToLLVMPattern<memref::DeallocOp>::ConvertOpToLLVMPattern;

  explicit DeallocOpLowering(LLVMTypeConverter &converter)
      : ConvertOpToLLVMPattern<memref::DeallocOp>(converter) {}

  LogicalResult
  matchAndRewrite(memref::DeallocOp op, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    // Insert the `free` declaration if it is not already present.
    auto freeFunc = LLVM::lookupOrCreateFreeFn(op->getParentOfType<ModuleOp>());
    MemRefDescriptor memref(adaptor.memref());
    Value casted = rewriter.create<LLVM::BitcastOp>(
        op.getLoc(), getVoidPtrType(),
        memref.allocatedPtr(rewriter, op.getLoc()));
    rewriter.replaceOpWithNewOp<LLVM::CallOp>(
        op, TypeRange(), SymbolRefAttr::get(freeFunc), casted);
    return success();
  }
};

// A `dim` is converted to a constant for static sizes and to an access to the
// size stored in the memref descriptor for dynamic sizes.
struct DimOpLowering : public ConvertOpToLLVMPattern<memref::DimOp> {
  using ConvertOpToLLVMPattern<memref::DimOp>::ConvertOpToLLVMPattern;

  LogicalResult
  matchAndRewrite(memref::DimOp dimOp, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    Type operandType = dimOp.source().getType();
    if (operandType.isa<UnrankedMemRefType>()) {
      rewriter.replaceOp(
          dimOp, {extractSizeOfUnrankedMemRef(
                     operandType, dimOp, adaptor.getOperands(), rewriter)});

      return success();
    }
    if (operandType.isa<MemRefType>()) {
      rewriter.replaceOp(
          dimOp, {extractSizeOfRankedMemRef(operandType, dimOp,
                                            adaptor.getOperands(), rewriter)});
      return success();
    }
    llvm_unreachable("expected MemRefType or UnrankedMemRefType");
  }

private:
  Value extractSizeOfUnrankedMemRef(Type operandType, memref::DimOp dimOp,
                                    OpAdaptor adaptor,
                                    ConversionPatternRewriter &rewriter) const {
    Location loc = dimOp.getLoc();

    auto unrankedMemRefType = operandType.cast<UnrankedMemRefType>();
    auto scalarMemRefType =
        MemRefType::get({}, unrankedMemRefType.getElementType());
    unsigned addressSpace = unrankedMemRefType.getMemorySpaceAsInt();

    // Extract pointer to the underlying ranked descriptor and bitcast it to a
    // memref<element_type> descriptor pointer to minimize the number of GEP
    // operations.
    UnrankedMemRefDescriptor unrankedDesc(adaptor.source());
    Value underlyingRankedDesc = unrankedDesc.memRefDescPtr(rewriter, loc);
    Value scalarMemRefDescPtr = rewriter.create<LLVM::BitcastOp>(
        loc,
        LLVM::LLVMPointerType::get(typeConverter->convertType(scalarMemRefType),
                                   addressSpace),
        underlyingRankedDesc);

    // Get pointer to offset field of memref<element_type> descriptor.
    Type indexPtrTy = LLVM::LLVMPointerType::get(
        getTypeConverter()->getIndexType(), addressSpace);
    Value two = rewriter.create<LLVM::ConstantOp>(
        loc, typeConverter->convertType(rewriter.getI32Type()),
        rewriter.getI32IntegerAttr(2));
    Value offsetPtr = rewriter.create<LLVM::GEPOp>(
        loc, indexPtrTy, scalarMemRefDescPtr,
        ValueRange({createIndexConstant(rewriter, loc, 0), two}));

    // The size value that we have to extract can be obtained using GEPOp with
    // `dimOp.index() + 1` index argument.
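    // The ranked descriptor is laid out as
    //   { allocatedPtr, alignedPtr, offset, sizes[rank], strides[rank] },
    // so the sizes array starts right after the offset field addressed above.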
    Value idxPlusOne = rewriter.create<LLVM::AddOp>(
        loc, createIndexConstant(rewriter, loc, 1), adaptor.index());
    Value sizePtr = rewriter.create<LLVM::GEPOp>(loc, indexPtrTy, offsetPtr,
                                                 ValueRange({idxPlusOne}));
    return rewriter.create<LLVM::LoadOp>(loc, sizePtr);
  }

  Optional<int64_t> getConstantDimIndex(memref::DimOp dimOp) const {
    if (Optional<int64_t> idx = dimOp.getConstantIndex())
      return idx;

    if (auto constantOp = dimOp.index().getDefiningOp<LLVM::ConstantOp>())
      return constantOp.getValue()
          .cast<IntegerAttr>()
          .getValue()
          .getSExtValue();

    return llvm::None;
  }

  Value extractSizeOfRankedMemRef(Type operandType, memref::DimOp dimOp,
                                  OpAdaptor adaptor,
                                  ConversionPatternRewriter &rewriter) const {
    Location loc = dimOp.getLoc();

    // Take advantage of a constant index if available.
    MemRefType memRefType = operandType.cast<MemRefType>();
    if (Optional<int64_t> index = getConstantDimIndex(dimOp)) {
      int64_t i = index.getValue();
      if (memRefType.isDynamicDim(i)) {
        // Extract the dynamic size from the memref descriptor.
        MemRefDescriptor descriptor(adaptor.source());
        return descriptor.size(rewriter, loc, i);
      }
      // Use a constant for the static size.
      int64_t dimSize = memRefType.getDimSize(i);
      return createIndexConstant(rewriter, loc, dimSize);
    }
    Value index = adaptor.index();
    int64_t rank = memRefType.getRank();
    MemRefDescriptor memrefDescriptor(adaptor.source());
    return memrefDescriptor.size(rewriter, loc, index, rank);
  }
};

/// Common base for load and store operations on MemRefs. Restricts the match
/// to supported MemRef types. Provides functionality to emit code accessing a
/// specific element of the underlying data buffer.
template <typename Derived>
struct LoadStoreOpLowering : public ConvertOpToLLVMPattern<Derived> {
  using ConvertOpToLLVMPattern<Derived>::ConvertOpToLLVMPattern;
  using ConvertOpToLLVMPattern<Derived>::isConvertibleAndHasIdentityMaps;
  using Base = LoadStoreOpLowering<Derived>;

  LogicalResult match(Derived op) const override {
    MemRefType type = op.getMemRefType();
    return isConvertibleAndHasIdentityMaps(type) ? success() : failure();
  }
};

/// Wrap a llvm.cmpxchg operation in a while loop so that the operation can be
/// retried until it succeeds in atomically storing a new value into memory.
///
///      +---------------------------------+
///      |   <code before the AtomicRMWOp> |
///      |   <compute initial %loaded>     |
///      |   cf.br loop(%loaded)           |
///      +---------------------------------+
///             |
///  -------|   |
///  |      v   v
///  |   +--------------------------------+
///  |   | loop(%loaded):                 |
///  |   |   <body contents>              |
///  |   |   %pair = cmpxchg              |
///  |   |   %ok = %pair[0]               |
///  |   |   %new = %pair[1]              |
///  |   |   cf.cond_br %ok, end, loop(%new) |
///  |   +--------------------------------+
///  |          |        |
///  |-----------        |
///                      v
///      +--------------------------------+
///      | end:                           |
///      |   <code after the AtomicRMWOp> |
///      +--------------------------------+
///
struct GenericAtomicRMWOpLowering
    : public LoadStoreOpLowering<memref::GenericAtomicRMWOp> {
  using Base::Base;

  LogicalResult
  matchAndRewrite(memref::GenericAtomicRMWOp atomicOp, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    auto loc = atomicOp.getLoc();
    Type valueType = typeConverter->convertType(atomicOp.getResult().getType());

    // Split the block into initial, loop, and ending parts.
    auto *initBlock = rewriter.getInsertionBlock();
    auto *loopBlock = rewriter.createBlock(
        initBlock->getParent(), std::next(Region::iterator(initBlock)),
        valueType, loc);
    auto *endBlock = rewriter.createBlock(
        loopBlock->getParent(), std::next(Region::iterator(loopBlock)));

    // Operations range to be moved to `endBlock`.
    auto opsToMoveStart = atomicOp->getIterator();
    auto opsToMoveEnd = initBlock->back().getIterator();

    // Compute the loaded value and branch to the loop block.
    rewriter.setInsertionPointToEnd(initBlock);
    auto memRefType = atomicOp.memref().getType().cast<MemRefType>();
    auto dataPtr = getStridedElementPtr(loc, memRefType, adaptor.memref(),
                                        adaptor.indices(), rewriter);
    Value init = rewriter.create<LLVM::LoadOp>(loc, dataPtr);
    rewriter.create<LLVM::BrOp>(loc, init, loopBlock);

    // Prepare the body of the loop block.
    rewriter.setInsertionPointToStart(loopBlock);

    // Clone the GenericAtomicRMWOp region and extract the result.
    auto loopArgument = loopBlock->getArgument(0);
    BlockAndValueMapping mapping;
    mapping.map(atomicOp.getCurrentValue(), loopArgument);
    Block &entryBlock = atomicOp.body().front();
    for (auto &nestedOp : entryBlock.without_terminator()) {
      Operation *clone = rewriter.clone(nestedOp, mapping);
      mapping.map(nestedOp.getResults(), clone->getResults());
    }
    Value result = mapping.lookup(entryBlock.getTerminator()->getOperand(0));

    // Prepare the epilog of the loop block.
    // Append the cmpxchg op to the end of the loop block.
    auto successOrdering = LLVM::AtomicOrdering::acq_rel;
    auto failureOrdering = LLVM::AtomicOrdering::monotonic;
    auto boolType = IntegerType::get(rewriter.getContext(), 1);
    auto pairType = LLVM::LLVMStructType::getLiteral(rewriter.getContext(),
                                                     {valueType, boolType});
    auto cmpxchg = rewriter.create<LLVM::AtomicCmpXchgOp>(
        loc, pairType, dataPtr, loopArgument, result, successOrdering,
        failureOrdering);
    // Extract the %new_loaded and %ok values from the pair.
    Value newLoaded = rewriter.create<LLVM::ExtractValueOp>(
        loc, valueType, cmpxchg, rewriter.getI64ArrayAttr({0}));
    Value ok = rewriter.create<LLVM::ExtractValueOp>(
        loc, boolType, cmpxchg, rewriter.getI64ArrayAttr({1}));

    // Conditionally branch to the end or back to the loop depending on %ok.
    rewriter.create<LLVM::CondBrOp>(loc, ok, endBlock, ArrayRef<Value>(),
                                    loopBlock, newLoaded);

    rewriter.setInsertionPointToEnd(endBlock);
    moveOpsRange(atomicOp.getResult(), newLoaded, std::next(opsToMoveStart),
                 std::next(opsToMoveEnd), rewriter);

    // The 'result' of the atomic_rmw op is the newly loaded value.
    rewriter.replaceOp(atomicOp, {newLoaded});

    return success();
  }

private:
  // Clones a segment of ops [start, end) and erases the original.
  void moveOpsRange(ValueRange oldResult, ValueRange newResult,
                    Block::iterator start, Block::iterator end,
                    ConversionPatternRewriter &rewriter) const {
    BlockAndValueMapping mapping;
    mapping.map(oldResult, newResult);
    SmallVector<Operation *, 2> opsToErase;
    for (auto it = start; it != end; ++it) {
      rewriter.clone(*it, mapping);
      opsToErase.push_back(&*it);
    }
    for (auto *it : opsToErase)
      rewriter.eraseOp(it);
  }
};

/// Returns the LLVM type of the global variable given the memref type `type`.
static Type convertGlobalMemrefTypeToLLVM(MemRefType type,
                                          LLVMTypeConverter &typeConverter) {
  // LLVM type for a global memref will be a multi-dimension array. For
  // declarations or uninitialized global memrefs, we can potentially flatten
  // this to a 1D array. However, for memref.global's with an initial value,
  // we do not intend to flatten the ElementsAttribute when going from std ->
  // LLVM dialect, so the LLVM type needs to be a multi-dimension array.
  Type elementType = typeConverter.convertType(type.getElementType());
  Type arrayTy = elementType;
  // Shape has the outermost dim at index 0, so we need to walk it backwards.
  for (int64_t dim : llvm::reverse(type.getShape()))
    arrayTy = LLVM::LLVMArrayType::get(arrayTy, dim);
  return arrayTy;
}

/// GlobalMemrefOp is lowered to an LLVM global variable.
struct GlobalMemrefOpLowering
    : public ConvertOpToLLVMPattern<memref::GlobalOp> {
  using ConvertOpToLLVMPattern<memref::GlobalOp>::ConvertOpToLLVMPattern;

  LogicalResult
  matchAndRewrite(memref::GlobalOp global, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    MemRefType type = global.type();
    if (!isConvertibleAndHasIdentityMaps(type))
      return failure();

    Type arrayTy = convertGlobalMemrefTypeToLLVM(type, *getTypeConverter());

    LLVM::Linkage linkage =
        global.isPublic() ? LLVM::Linkage::External : LLVM::Linkage::Private;

    Attribute initialValue = nullptr;
    if (!global.isExternal() && !global.isUninitialized()) {
      auto elementsAttr = global.initial_value()->cast<ElementsAttr>();
      initialValue = elementsAttr;

      // For scalar memrefs, the global variable created is of the element
      // type, so unpack the elements attribute to extract the value.
      if (type.getRank() == 0)
        initialValue = elementsAttr.getSplatValue<Attribute>();
    }

    uint64_t alignment = global.alignment().getValueOr(0);

    auto newGlobal = rewriter.replaceOpWithNewOp<LLVM::GlobalOp>(
        global, arrayTy, global.constant(), linkage, global.sym_name(),
        initialValue, alignment, type.getMemorySpaceAsInt());
    if (!global.isExternal() && global.isUninitialized()) {
      Block *blk = new Block();
      newGlobal.getInitializerRegion().push_back(blk);
      rewriter.setInsertionPointToStart(blk);
      Value undef[] = {
          rewriter.create<LLVM::UndefOp>(global.getLoc(), arrayTy)};
      rewriter.create<LLVM::ReturnOp>(global.getLoc(), undef);
    }
    return success();
  }
};

/// GetGlobalMemrefOp is lowered into a MemRef descriptor with the pointer to
/// the first element stashed into the descriptor. This uses
/// `AllocLikeOpLLVMLowering` to reuse the MemRef descriptor construction.
struct GetGlobalMemrefOpLowering : public AllocLikeOpLLVMLowering {
  GetGlobalMemrefOpLowering(LLVMTypeConverter &converter)
      : AllocLikeOpLLVMLowering(memref::GetGlobalOp::getOperationName(),
                                converter) {}

  /// Buffer "allocation" for memref.get_global op is getting the address of
  /// the global variable referenced.
  std::tuple<Value, Value> allocateBuffer(ConversionPatternRewriter &rewriter,
                                          Location loc, Value sizeBytes,
                                          Operation *op) const override {
    auto getGlobalOp = cast<memref::GetGlobalOp>(op);
    MemRefType type = getGlobalOp.result().getType().cast<MemRefType>();
    unsigned memSpace = type.getMemorySpaceAsInt();

    Type arrayTy = convertGlobalMemrefTypeToLLVM(type, *getTypeConverter());
    auto addressOf = rewriter.create<LLVM::AddressOfOp>(
        loc, LLVM::LLVMPointerType::get(arrayTy, memSpace), getGlobalOp.name());

    // Get the address of the first element in the array by creating a GEP with
    // the address of the GV as the base, and (rank + 1) number of 0 indices.
    Type elementType = typeConverter->convertType(type.getElementType());
    Type elementPtrType = LLVM::LLVMPointerType::get(elementType, memSpace);

    SmallVector<Value> operands;
    operands.insert(operands.end(), type.getRank() + 1,
                    createIndexConstant(rewriter, loc, 0));
    auto gep =
        rewriter.create<LLVM::GEPOp>(loc, elementPtrType, addressOf, operands);

    // We do not expect the memref obtained using `memref.get_global` to be
    // ever deallocated. Set the allocated pointer to a known bad value to
    // help debug if that ever happens.
    auto intPtrType = getIntPtrType(memSpace);
    Value deadBeefConst =
        createIndexAttrConstant(rewriter, op->getLoc(), intPtrType, 0xdeadbeef);
    auto deadBeefPtr =
        rewriter.create<LLVM::IntToPtrOp>(loc, elementPtrType, deadBeefConst);

    // Both allocated and aligned pointers are the same. We could potentially
    // stash a nullptr for the allocated pointer since we do not expect any
    // dealloc.
    return std::make_tuple(deadBeefPtr, gep);
  }
};

// Load operation is lowered to obtaining a pointer to the indexed element
// and loading it.
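// For example, a load from a 2-D memref with strides (%s0, %s1) and offset
// %off becomes, roughly:
//   %idx  = add %off, add(mul(%i, %s0), mul(%j, %s1))
//   %addr = llvm.getelementptr %alignedPtr[%idx]
//   %val  = llvm.load %addr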
struct LoadOpLowering : public LoadStoreOpLowering<memref::LoadOp> {
  using Base::Base;

  LogicalResult
  matchAndRewrite(memref::LoadOp loadOp, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    auto type = loadOp.getMemRefType();

    Value dataPtr = getStridedElementPtr(
        loadOp.getLoc(), type, adaptor.memref(), adaptor.indices(), rewriter);
    rewriter.replaceOpWithNewOp<LLVM::LoadOp>(loadOp, dataPtr);
    return success();
  }
};

// Store operation is lowered to obtaining a pointer to the indexed element,
// and storing the given value to it.
struct StoreOpLowering : public LoadStoreOpLowering<memref::StoreOp> {
  using Base::Base;

  LogicalResult
  matchAndRewrite(memref::StoreOp op, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    auto type = op.getMemRefType();

    Value dataPtr = getStridedElementPtr(op.getLoc(), type, adaptor.memref(),
                                         adaptor.indices(), rewriter);
    rewriter.replaceOpWithNewOp<LLVM::StoreOp>(op, adaptor.value(), dataPtr);
    return success();
  }
};

// The prefetch operation is lowered in a way similar to the load operation
// except that the llvm.prefetch operation is used for replacement.
struct PrefetchOpLowering : public LoadStoreOpLowering<memref::PrefetchOp> {
  using Base::Base;

  LogicalResult
  matchAndRewrite(memref::PrefetchOp prefetchOp, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    auto type = prefetchOp.getMemRefType();
    auto loc = prefetchOp.getLoc();

    Value dataPtr = getStridedElementPtr(loc, type, adaptor.memref(),
                                         adaptor.indices(), rewriter);

    // Replace with llvm.prefetch.
    auto llvmI32Type = typeConverter->convertType(rewriter.getIntegerType(32));
    auto isWrite = rewriter.create<LLVM::ConstantOp>(
        loc, llvmI32Type, rewriter.getI32IntegerAttr(prefetchOp.isWrite()));
    auto localityHint = rewriter.create<LLVM::ConstantOp>(
        loc, llvmI32Type,
        rewriter.getI32IntegerAttr(prefetchOp.localityHint()));
    auto isData = rewriter.create<LLVM::ConstantOp>(
        loc, llvmI32Type, rewriter.getI32IntegerAttr(prefetchOp.isDataCache()));

    rewriter.replaceOpWithNewOp<LLVM::Prefetch>(prefetchOp, dataPtr, isWrite,
                                                localityHint, isData);
    return success();
  }
};

struct RankOpLowering : public ConvertOpToLLVMPattern<memref::RankOp> {
  using ConvertOpToLLVMPattern<memref::RankOp>::ConvertOpToLLVMPattern;

  LogicalResult
  matchAndRewrite(memref::RankOp op, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    Location loc = op.getLoc();
    Type operandType = op.memref().getType();
    if (auto unrankedMemRefType = operandType.dyn_cast<UnrankedMemRefType>()) {
      UnrankedMemRefDescriptor desc(adaptor.memref());
      rewriter.replaceOp(op, {desc.rank(rewriter, loc)});
      return success();
    }
    if (auto rankedMemRefType = operandType.dyn_cast<MemRefType>()) {
      rewriter.replaceOp(
          op, {createIndexConstant(rewriter, loc, rankedMemRefType.getRank())});
      return success();
    }
    return failure();
  }
};

struct MemRefCastOpLowering : public ConvertOpToLLVMPattern<memref::CastOp> {
  using ConvertOpToLLVMPattern<memref::CastOp>::ConvertOpToLLVMPattern;

  LogicalResult match(memref::CastOp memRefCastOp) const override {
    Type srcType = memRefCastOp.getOperand().getType();
    Type dstType = memRefCastOp.getType();

    // memref::CastOp reduces to bitcast in the ranked MemRef case and can be
    // used for type erasure. For now it must preserve the underlying element
    // type and requires the source and result types to have the same rank.
    // Therefore, perform a sanity check that the underlying structs are the
    // same. Once op semantics are relaxed we can revisit.
    if (srcType.isa<MemRefType>() && dstType.isa<MemRefType>())
      return success(typeConverter->convertType(srcType) ==
                     typeConverter->convertType(dstType));

    // At least one of the operands is of unranked type.
    assert(srcType.isa<UnrankedMemRefType>() ||
           dstType.isa<UnrankedMemRefType>());

    // Unranked-to-unranked casts are disallowed.
    return !(srcType.isa<UnrankedMemRefType>() &&
             dstType.isa<UnrankedMemRefType>())
               ? success()
               : failure();
  }

  void rewrite(memref::CastOp memRefCastOp, OpAdaptor adaptor,
               ConversionPatternRewriter &rewriter) const override {
    auto srcType = memRefCastOp.getOperand().getType();
    auto dstType = memRefCastOp.getType();
    auto targetStructType = typeConverter->convertType(memRefCastOp.getType());
    auto loc = memRefCastOp.getLoc();

    // For the ranked/ranked case, just keep the original descriptor.
    if (srcType.isa<MemRefType>() && dstType.isa<MemRefType>())
      return rewriter.replaceOp(memRefCastOp, {adaptor.source()});

    if (srcType.isa<MemRefType>() && dstType.isa<UnrankedMemRefType>()) {
      // Casting ranked to unranked memref type:
      // Set the rank in the destination from the memref type.
      // Allocate space on the stack and copy the src memref descriptor.
      // Set the ptr in the destination to the stack space.
      auto srcMemRefType = srcType.cast<MemRefType>();
      int64_t rank = srcMemRefType.getRank();
      // ptr = AllocaOp sizeof(MemRefDescriptor)
      auto ptr = getTypeConverter()->promoteOneMemRefDescriptor(
          loc, adaptor.source(), rewriter);
      // voidptr = BitCastOp srcType* to void*
      auto voidPtr =
          rewriter.create<LLVM::BitcastOp>(loc, getVoidPtrType(), ptr)
              .getResult();
      // rank = ConstantOp srcRank
      auto rankVal = rewriter.create<LLVM::ConstantOp>(
          loc, getIndexType(), rewriter.getIndexAttr(rank));
      // undef = UndefOp
      UnrankedMemRefDescriptor memRefDesc =
          UnrankedMemRefDescriptor::undef(rewriter, loc, targetStructType);
      // d1 = InsertValueOp undef, rank, 0
      memRefDesc.setRank(rewriter, loc, rankVal);
      // d2 = InsertValueOp d1, voidptr, 1
      memRefDesc.setMemRefDescPtr(rewriter, loc, voidPtr);
      rewriter.replaceOp(memRefCastOp, (Value)memRefDesc);

    } else if (srcType.isa<UnrankedMemRefType>() && dstType.isa<MemRefType>()) {
      // Casting from unranked type to ranked. The operation is assumed to be
      // doing a correct cast: if the destination type mismatches the unranked
      // type, the behavior is undefined.
      UnrankedMemRefDescriptor memRefDesc(adaptor.source());
      // ptr = ExtractValueOp src, 1
      auto ptr = memRefDesc.memRefDescPtr(rewriter, loc);
      // castPtr = BitCastOp i8* to structTy*
      auto castPtr =
          rewriter
              .create<LLVM::BitcastOp>(
                  loc, LLVM::LLVMPointerType::get(targetStructType), ptr)
              .getResult();
      // struct = LoadOp castPtr
      auto loadOp = rewriter.create<LLVM::LoadOp>(loc, castPtr);
      rewriter.replaceOp(memRefCastOp, loadOp.getResult());
    } else {
      llvm_unreachable("Unsupported unranked memref to unranked memref cast");
    }
  }
};

/// Pattern to lower a `memref.copy` to llvm.
///
/// For memrefs with identity layouts, the copy is lowered to the llvm
/// `memcpy` intrinsic. For non-identity layouts, the copy is lowered to a call
/// to the generic `MemrefCopyFn`.
struct MemRefCopyOpLowering : public ConvertOpToLLVMPattern<memref::CopyOp> {
  using ConvertOpToLLVMPattern<memref::CopyOp>::ConvertOpToLLVMPattern;

  LogicalResult
  lowerToMemCopyIntrinsic(memref::CopyOp op, OpAdaptor adaptor,
                          ConversionPatternRewriter &rewriter) const {
    auto loc = op.getLoc();
    auto srcType = op.source().getType().dyn_cast<MemRefType>();

    MemRefDescriptor srcDesc(adaptor.source());

    // Compute number of elements.
    Value numElements = rewriter.create<LLVM::ConstantOp>(
        loc, getIndexType(), rewriter.getIndexAttr(1));
    for (int pos = 0; pos < srcType.getRank(); ++pos) {
      auto size = srcDesc.size(rewriter, loc, pos);
      numElements = rewriter.create<LLVM::MulOp>(loc, numElements, size);
    }

    // Get element size.
    auto sizeInBytes = getSizeInBytes(loc, srcType.getElementType(), rewriter);
    // Compute total.
    Value totalSize =
        rewriter.create<LLVM::MulOp>(loc, numElements, sizeInBytes);

    Value srcBasePtr = srcDesc.alignedPtr(rewriter, loc);
    Value srcOffset = srcDesc.offset(rewriter, loc);
    Value srcPtr = rewriter.create<LLVM::GEPOp>(loc, srcBasePtr.getType(),
                                                srcBasePtr, srcOffset);
    MemRefDescriptor targetDesc(adaptor.target());
    Value targetBasePtr = targetDesc.alignedPtr(rewriter, loc);
    Value targetOffset = targetDesc.offset(rewriter, loc);
    Value targetPtr = rewriter.create<LLVM::GEPOp>(loc, targetBasePtr.getType(),
                                                   targetBasePtr, targetOffset);
    Value isVolatile = rewriter.create<LLVM::ConstantOp>(
        loc, typeConverter->convertType(rewriter.getI1Type()),
        rewriter.getBoolAttr(false));
    rewriter.create<LLVM::MemcpyOp>(loc, targetPtr, srcPtr, totalSize,
                                    isVolatile);
    rewriter.eraseOp(op);

    return success();
  }

  LogicalResult
  lowerToMemCopyFunctionCall(memref::CopyOp op, OpAdaptor adaptor,
                             ConversionPatternRewriter &rewriter) const {
    auto loc = op.getLoc();
    auto srcType = op.source().getType().cast<BaseMemRefType>();
    auto targetType = op.target().getType().cast<BaseMemRefType>();

    // First make sure we have an unranked memref descriptor representation.
    auto makeUnranked = [&, this](Value ranked, BaseMemRefType type) {
      auto rank = rewriter.create<LLVM::ConstantOp>(
          loc, getIndexType(), rewriter.getIndexAttr(type.getRank()));
      auto *typeConverter = getTypeConverter();
      auto ptr =
          typeConverter->promoteOneMemRefDescriptor(loc, ranked, rewriter);
      auto voidPtr =
          rewriter.create<LLVM::BitcastOp>(loc, getVoidPtrType(), ptr)
              .getResult();
      auto unrankedType =
          UnrankedMemRefType::get(type.getElementType(), type.getMemorySpace());
      return UnrankedMemRefDescriptor::pack(rewriter, loc, *typeConverter,
                                            unrankedType,
                                            ValueRange{rank, voidPtr});
    };

    Value unrankedSource = srcType.hasRank()
                               ? makeUnranked(adaptor.source(), srcType)
                               : adaptor.source();
    Value unrankedTarget = targetType.hasRank()
                               ? makeUnranked(adaptor.target(), targetType)
                               : adaptor.target();

    // Now promote the unranked descriptors to the stack.
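    // Each unranked descriptor is stored into a one-element alloca so that the
    // runtime copy function can take it by pointer.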
    auto one = rewriter.create<LLVM::ConstantOp>(loc, getIndexType(),
                                                 rewriter.getIndexAttr(1));
    auto promote = [&](Value desc) {
      auto ptrType = LLVM::LLVMPointerType::get(desc.getType());
      auto allocated =
          rewriter.create<LLVM::AllocaOp>(loc, ptrType, ValueRange{one});
      rewriter.create<LLVM::StoreOp>(loc, desc, allocated);
      return allocated;
    };

    auto sourcePtr = promote(unrankedSource);
    auto targetPtr = promote(unrankedTarget);

    unsigned typeSize =
        mlir::DataLayout::closest(op).getTypeSize(srcType.getElementType());
    auto elemSize = rewriter.create<LLVM::ConstantOp>(
        loc, getIndexType(), rewriter.getIndexAttr(typeSize));
    auto copyFn = LLVM::lookupOrCreateMemRefCopyFn(
        op->getParentOfType<ModuleOp>(), getIndexType(), sourcePtr.getType());
    rewriter.create<LLVM::CallOp>(loc, copyFn,
                                  ValueRange{elemSize, sourcePtr, targetPtr});
    rewriter.eraseOp(op);

    return success();
  }

  LogicalResult
  matchAndRewrite(memref::CopyOp op, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    auto srcType = op.source().getType().cast<BaseMemRefType>();
    auto targetType = op.target().getType().cast<BaseMemRefType>();

    auto isContiguousMemrefType = [](BaseMemRefType type) {
      auto memrefType = type.dyn_cast<mlir::MemRefType>();
      // We can use memcpy for memrefs if they have an identity layout or are
      // contiguous with an arbitrary offset. Ignore empty memrefs, which are a
      // special case handled by memrefCopy.
      return memrefType &&
             (memrefType.getLayout().isIdentity() ||
              (memrefType.hasStaticShape() && memrefType.getNumElements() > 0 &&
               isStaticShapeAndContiguousRowMajor(memrefType)));
    };

    if (isContiguousMemrefType(srcType) && isContiguousMemrefType(targetType))
      return lowerToMemCopyIntrinsic(op, adaptor, rewriter);

    return lowerToMemCopyFunctionCall(op, adaptor, rewriter);
  }
};

/// Extracts allocated, aligned pointers and offset from a ranked or unranked
/// memref type. In the unranked case, the fields are extracted from the
/// underlying ranked descriptor.
static void extractPointersAndOffset(Location loc,
                                     ConversionPatternRewriter &rewriter,
                                     LLVMTypeConverter &typeConverter,
                                     Value originalOperand,
                                     Value convertedOperand,
                                     Value *allocatedPtr, Value *alignedPtr,
                                     Value *offset = nullptr) {
  Type operandType = originalOperand.getType();
  if (operandType.isa<MemRefType>()) {
    MemRefDescriptor desc(convertedOperand);
    *allocatedPtr = desc.allocatedPtr(rewriter, loc);
    *alignedPtr = desc.alignedPtr(rewriter, loc);
    if (offset != nullptr)
      *offset = desc.offset(rewriter, loc);
    return;
  }

  unsigned memorySpace =
      operandType.cast<UnrankedMemRefType>().getMemorySpaceAsInt();
  Type elementType = operandType.cast<UnrankedMemRefType>().getElementType();
  Type llvmElementType = typeConverter.convertType(elementType);
  Type elementPtrPtrType = LLVM::LLVMPointerType::get(
      LLVM::LLVMPointerType::get(llvmElementType, memorySpace));

  // Extract pointer to the underlying ranked memref descriptor and cast it to
  // ElemType**.
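  // (An unranked descriptor is just { rank, void *rankedDescriptor }; the
  // ranked descriptor it points to has the usual
  // { allocatedPtr, alignedPtr, offset, sizes, strides } layout.)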
  UnrankedMemRefDescriptor unrankedDesc(convertedOperand);
  Value underlyingDescPtr = unrankedDesc.memRefDescPtr(rewriter, loc);

  *allocatedPtr = UnrankedMemRefDescriptor::allocatedPtr(
      rewriter, loc, underlyingDescPtr, elementPtrPtrType);
  *alignedPtr = UnrankedMemRefDescriptor::alignedPtr(
      rewriter, loc, typeConverter, underlyingDescPtr, elementPtrPtrType);
  if (offset != nullptr) {
    *offset = UnrankedMemRefDescriptor::offset(
        rewriter, loc, typeConverter, underlyingDescPtr, elementPtrPtrType);
  }
}

struct MemRefReinterpretCastOpLowering
    : public ConvertOpToLLVMPattern<memref::ReinterpretCastOp> {
  using ConvertOpToLLVMPattern<
      memref::ReinterpretCastOp>::ConvertOpToLLVMPattern;

  LogicalResult
  matchAndRewrite(memref::ReinterpretCastOp castOp, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    Type srcType = castOp.source().getType();

    Value descriptor;
    if (failed(convertSourceMemRefToDescriptor(rewriter, srcType, castOp,
                                               adaptor, &descriptor)))
      return failure();
    rewriter.replaceOp(castOp, {descriptor});
    return success();
  }

private:
  LogicalResult convertSourceMemRefToDescriptor(
      ConversionPatternRewriter &rewriter, Type srcType,
      memref::ReinterpretCastOp castOp,
      memref::ReinterpretCastOp::Adaptor adaptor, Value *descriptor) const {
    MemRefType targetMemRefType =
        castOp.getResult().getType().cast<MemRefType>();
    auto llvmTargetDescriptorTy = typeConverter->convertType(targetMemRefType)
                                      .dyn_cast_or_null<LLVM::LLVMStructType>();
    if (!llvmTargetDescriptorTy)
      return failure();

    // Create descriptor.
    Location loc = castOp.getLoc();
    auto desc = MemRefDescriptor::undef(rewriter, loc, llvmTargetDescriptorTy);

    // Set allocated and aligned pointers.
    Value allocatedPtr, alignedPtr;
    extractPointersAndOffset(loc, rewriter, *getTypeConverter(),
                             castOp.source(), adaptor.source(), &allocatedPtr,
                             &alignedPtr);
    desc.setAllocatedPtr(rewriter, loc, allocatedPtr);
    desc.setAlignedPtr(rewriter, loc, alignedPtr);

    // Set offset.
    if (castOp.isDynamicOffset(0))
      desc.setOffset(rewriter, loc, adaptor.offsets()[0]);
    else
      desc.setConstantOffset(rewriter, loc, castOp.getStaticOffset(0));

    // Set sizes and strides.
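    // Dynamic sizes and strides are consumed from the adaptor operands in
    // order; static ones come directly from the op's attributes.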
    unsigned dynSizeId = 0;
    unsigned dynStrideId = 0;
    for (unsigned i = 0, e = targetMemRefType.getRank(); i < e; ++i) {
      if (castOp.isDynamicSize(i))
        desc.setSize(rewriter, loc, i, adaptor.sizes()[dynSizeId++]);
      else
        desc.setConstantSize(rewriter, loc, i, castOp.getStaticSize(i));

      if (castOp.isDynamicStride(i))
        desc.setStride(rewriter, loc, i, adaptor.strides()[dynStrideId++]);
      else
        desc.setConstantStride(rewriter, loc, i, castOp.getStaticStride(i));
    }
    *descriptor = desc;
    return success();
  }
};

struct MemRefReshapeOpLowering
    : public ConvertOpToLLVMPattern<memref::ReshapeOp> {
  using ConvertOpToLLVMPattern<memref::ReshapeOp>::ConvertOpToLLVMPattern;

  LogicalResult
  matchAndRewrite(memref::ReshapeOp reshapeOp, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    Type srcType = reshapeOp.source().getType();

    Value descriptor;
    if (failed(convertSourceMemRefToDescriptor(rewriter, srcType, reshapeOp,
                                               adaptor, &descriptor)))
      return failure();
    rewriter.replaceOp(reshapeOp, {descriptor});
    return success();
  }

private:
  LogicalResult
  convertSourceMemRefToDescriptor(ConversionPatternRewriter &rewriter,
                                  Type srcType, memref::ReshapeOp reshapeOp,
                                  memref::ReshapeOp::Adaptor adaptor,
                                  Value *descriptor) const {
    // Conversion for statically-known shape args is performed via
    // `memref_reinterpret_cast`.
    auto shapeMemRefType = reshapeOp.shape().getType().cast<MemRefType>();
    if (shapeMemRefType.hasStaticShape())
      return failure();

    // The shape is a rank-1 tensor with unknown length.
    Location loc = reshapeOp.getLoc();
    MemRefDescriptor shapeDesc(adaptor.shape());
    Value resultRank = shapeDesc.size(rewriter, loc, 0);

    // Extract address space and element type.
    auto targetType =
        reshapeOp.getResult().getType().cast<UnrankedMemRefType>();
    unsigned addressSpace = targetType.getMemorySpaceAsInt();
    Type elementType = targetType.getElementType();

    // Create the unranked memref descriptor that holds the ranked one. The
    // inner descriptor is allocated on the stack.
    auto targetDesc = UnrankedMemRefDescriptor::undef(
        rewriter, loc, typeConverter->convertType(targetType));
    targetDesc.setRank(rewriter, loc, resultRank);
    SmallVector<Value, 4> sizes;
    UnrankedMemRefDescriptor::computeSizes(rewriter, loc, *getTypeConverter(),
                                           targetDesc, sizes);
    Value underlyingDescPtr = rewriter.create<LLVM::AllocaOp>(
        loc, getVoidPtrType(), sizes.front(), llvm::None);
    targetDesc.setMemRefDescPtr(rewriter, loc, underlyingDescPtr);

    // Extract pointers and offset from the source memref.
    Value allocatedPtr, alignedPtr, offset;
    extractPointersAndOffset(loc, rewriter, *getTypeConverter(),
                             reshapeOp.source(), adaptor.source(),
                             &allocatedPtr, &alignedPtr, &offset);

    // Set pointers and offset.
    Type llvmElementType = typeConverter->convertType(elementType);
    auto elementPtrPtrType = LLVM::LLVMPointerType::get(
        LLVM::LLVMPointerType::get(llvmElementType, addressSpace));
    UnrankedMemRefDescriptor::setAllocatedPtr(rewriter, loc, underlyingDescPtr,
                                              elementPtrPtrType, allocatedPtr);
    UnrankedMemRefDescriptor::setAlignedPtr(rewriter, loc, *getTypeConverter(),
                                            underlyingDescPtr,
                                            elementPtrPtrType, alignedPtr);
    UnrankedMemRefDescriptor::setOffset(rewriter, loc, *getTypeConverter(),
                                        underlyingDescPtr, elementPtrPtrType,
                                        offset);

    // Use the offset pointer as base for further addressing. Copy over the new
    // shape and compute strides. For this, we create a loop from rank-1 to 0.
    Value targetSizesBase = UnrankedMemRefDescriptor::sizeBasePtr(
        rewriter, loc, *getTypeConverter(), underlyingDescPtr,
        elementPtrPtrType);
    Value targetStridesBase = UnrankedMemRefDescriptor::strideBasePtr(
        rewriter, loc, *getTypeConverter(), targetSizesBase, resultRank);
    Value shapeOperandPtr = shapeDesc.alignedPtr(rewriter, loc);
    Value oneIndex = createIndexConstant(rewriter, loc, 1);
    Value resultRankMinusOne =
        rewriter.create<LLVM::SubOp>(loc, resultRank, oneIndex);

    Block *initBlock = rewriter.getInsertionBlock();
    Type indexType = getTypeConverter()->getIndexType();
    Block::iterator remainingOpsIt = std::next(rewriter.getInsertionPoint());

    Block *condBlock = rewriter.createBlock(initBlock->getParent(), {},
                                            {indexType, indexType}, {loc, loc});

    // Move the remaining initBlock ops to condBlock.
    Block *remainingBlock = rewriter.splitBlock(initBlock, remainingOpsIt);
    rewriter.mergeBlocks(remainingBlock, condBlock, ValueRange());

    rewriter.setInsertionPointToEnd(initBlock);
    rewriter.create<LLVM::BrOp>(loc, ValueRange({resultRankMinusOne, oneIndex}),
                                condBlock);
    rewriter.setInsertionPointToStart(condBlock);
    Value indexArg = condBlock->getArgument(0);
    Value strideArg = condBlock->getArgument(1);

    Value zeroIndex = createIndexConstant(rewriter, loc, 0);
    Value pred = rewriter.create<LLVM::ICmpOp>(
        loc, IntegerType::get(rewriter.getContext(), 1),
        LLVM::ICmpPredicate::sge, indexArg, zeroIndex);

    Block *bodyBlock =
        rewriter.splitBlock(condBlock, rewriter.getInsertionPoint());
    rewriter.setInsertionPointToStart(bodyBlock);

    // Copy size from shape to descriptor.
    Type llvmIndexPtrType = LLVM::LLVMPointerType::get(indexType);
    Value sizeLoadGep = rewriter.create<LLVM::GEPOp>(
        loc, llvmIndexPtrType, shapeOperandPtr, ValueRange{indexArg});
    Value size = rewriter.create<LLVM::LoadOp>(loc, sizeLoadGep);
    UnrankedMemRefDescriptor::setSize(rewriter, loc, *getTypeConverter(),
                                      targetSizesBase, indexArg, size);

    // Write stride value and compute next one.
    UnrankedMemRefDescriptor::setStride(rewriter, loc, *getTypeConverter(),
                                        targetStridesBase, indexArg, strideArg);
    Value nextStride = rewriter.create<LLVM::MulOp>(loc, strideArg, size);

    // Decrement loop counter and branch back.
    Value decrement = rewriter.create<LLVM::SubOp>(loc, indexArg, oneIndex);
    rewriter.create<LLVM::BrOp>(loc, ValueRange({decrement, nextStride}),
                                condBlock);

    Block *remainder =
        rewriter.splitBlock(bodyBlock, rewriter.getInsertionPoint());

    // Hook up the cond exit to the remainder.
    rewriter.setInsertionPointToEnd(condBlock);
    rewriter.create<LLVM::CondBrOp>(loc, pred, bodyBlock, llvm::None, remainder,
                                    llvm::None);

    // Reset position to beginning of new remainder block.
    rewriter.setInsertionPointToStart(remainder);

    *descriptor = targetDesc;
    return success();
  }
};

/// Helper function to convert a vector of `OpFoldResult`s into a vector of
/// `Value`s.
static SmallVector<Value> getAsValues(OpBuilder &b, Location loc,
                                      Type &llvmIndexType,
                                      ArrayRef<OpFoldResult> valueOrAttrVec) {
  return llvm::to_vector<4>(
      llvm::map_range(valueOrAttrVec, [&](OpFoldResult value) -> Value {
        if (auto attr = value.dyn_cast<Attribute>())
          return b.create<LLVM::ConstantOp>(loc, llvmIndexType, attr);
        return value.get<Value>();
      }));
}

/// Compute a map that for a given dimension of the expanded type gives the
/// dimension in the collapsed type it maps to. Essentially it's the inverse of
/// the `reassociation` maps.
static DenseMap<int64_t, int64_t>
getExpandedDimToCollapsedDimMap(ArrayRef<ReassociationIndices> reassociation) {
  llvm::DenseMap<int64_t, int64_t> expandedDimToCollapsedDim;
  for (auto &en : enumerate(reassociation)) {
    for (auto dim : en.value())
      expandedDimToCollapsedDim[dim] = en.index();
  }
  return expandedDimToCollapsedDim;
}

static OpFoldResult
getExpandedOutputDimSize(OpBuilder &b, Location loc, Type &llvmIndexType,
                         int64_t outDimIndex, ArrayRef<int64_t> outStaticShape,
                         MemRefDescriptor &inDesc,
                         ArrayRef<int64_t> inStaticShape,
                         ArrayRef<ReassociationIndices> reassociation,
                         DenseMap<int64_t, int64_t> &outDimToInDimMap) {
  int64_t outDimSize = outStaticShape[outDimIndex];
  if (!ShapedType::isDynamic(outDimSize))
    return b.getIndexAttr(outDimSize);

  // Calculate the product of all the output dim sizes in the group except the
  // current dim.
  int64_t inDimIndex = outDimToInDimMap[outDimIndex];
  int64_t otherDimSizesMul = 1;
  for (auto otherDimIndex : reassociation[inDimIndex]) {
    if (otherDimIndex == static_cast<unsigned>(outDimIndex))
      continue;
    int64_t otherDimSize = outStaticShape[otherDimIndex];
    assert(!ShapedType::isDynamic(otherDimSize) &&
           "single dimension cannot be expanded into multiple dynamic "
           "dimensions");
    otherDimSizesMul *= otherDimSize;
  }

  // outDimSize = inDimSize / otherOutDimSizesMul
  int64_t inDimSize = inStaticShape[inDimIndex];
  Value inDimSizeDynamic =
      ShapedType::isDynamic(inDimSize)
          ? inDesc.size(b, loc, inDimIndex)
          : b.create<LLVM::ConstantOp>(loc, llvmIndexType,
                                       b.getIndexAttr(inDimSize));
  Value outDimSizeDynamic = b.create<LLVM::SDivOp>(
      loc, inDimSizeDynamic,
      b.create<LLVM::ConstantOp>(loc, llvmIndexType,
                                 b.getIndexAttr(otherDimSizesMul)));
  return outDimSizeDynamic;
}

static OpFoldResult getCollapsedOutputDimSize(
    OpBuilder &b, Location loc, Type &llvmIndexType, int64_t outDimIndex,
    int64_t outDimSize, ArrayRef<int64_t> inStaticShape,
    MemRefDescriptor &inDesc, ArrayRef<ReassociationIndices> reassociation) {
  if (!ShapedType::isDynamic(outDimSize))
    return b.getIndexAttr(outDimSize);

  Value c1 = b.create<LLVM::ConstantOp>(loc, llvmIndexType, b.getIndexAttr(1));
  Value outDimSizeDynamic = c1;
  for (auto inDimIndex : reassociation[outDimIndex]) {
    int64_t inDimSize = inStaticShape[inDimIndex];
    Value inDimSizeDynamic =
        ShapedType::isDynamic(inDimSize)
            ? inDesc.size(b, loc, inDimIndex)
            : b.create<LLVM::ConstantOp>(loc, llvmIndexType,
                                         b.getIndexAttr(inDimSize));
    outDimSizeDynamic =
        b.create<LLVM::MulOp>(loc, outDimSizeDynamic, inDimSizeDynamic);
  }
  return outDimSizeDynamic;
}

static SmallVector<OpFoldResult, 4>
getCollapsedOutputShape(OpBuilder &b, Location loc, Type &llvmIndexType,
                        ArrayRef<ReassociationIndices> reassociation,
                        ArrayRef<int64_t> inStaticShape,
                        MemRefDescriptor &inDesc,
                        ArrayRef<int64_t> outStaticShape) {
  return llvm::to_vector<4>(llvm::map_range(
      llvm::seq<int64_t>(0, outStaticShape.size()), [&](int64_t outDimIndex) {
        return getCollapsedOutputDimSize(b, loc, llvmIndexType, outDimIndex,
                                         outStaticShape[outDimIndex],
                                         inStaticShape, inDesc, reassociation);
      }));
}

static SmallVector<OpFoldResult, 4>
getExpandedOutputShape(OpBuilder &b, Location loc, Type &llvmIndexType,
                       ArrayRef<ReassociationIndices> reassociation,
                       ArrayRef<int64_t> inStaticShape,
                       MemRefDescriptor &inDesc,
                       ArrayRef<int64_t> outStaticShape) {
  DenseMap<int64_t, int64_t> outDimToInDimMap =
      getExpandedDimToCollapsedDimMap(reassociation);
  return llvm::to_vector<4>(llvm::map_range(
      llvm::seq<int64_t>(0, outStaticShape.size()), [&](int64_t outDimIndex) {
        return getExpandedOutputDimSize(b, loc, llvmIndexType, outDimIndex,
                                        outStaticShape, inDesc, inStaticShape,
                                        reassociation, outDimToInDimMap);
      }));
}

static SmallVector<Value>
getDynamicOutputShape(OpBuilder &b, Location loc, Type &llvmIndexType,
                      ArrayRef<ReassociationIndices> reassociation,
                      ArrayRef<int64_t> inStaticShape, MemRefDescriptor &inDesc,
                      ArrayRef<int64_t> outStaticShape) {
  return outStaticShape.size() < inStaticShape.size()
             ? getAsValues(b, loc, llvmIndexType,
                           getCollapsedOutputShape(b, loc, llvmIndexType,
                                                   reassociation, inStaticShape,
                                                   inDesc, outStaticShape))
             : getAsValues(b, loc, llvmIndexType,
                           getExpandedOutputShape(b, loc, llvmIndexType,
                                                  reassociation, inStaticShape,
                                                  inDesc, outStaticShape));
}

static void fillInStridesForExpandedMemDescriptor(
    OpBuilder &b, Location loc, MemRefType srcType, MemRefDescriptor &srcDesc,
    MemRefDescriptor &dstDesc, ArrayRef<ReassociationIndices> reassociation) {
  // See comments for computeExpandedLayoutMap for details on how the strides
  // are calculated.
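  // For example, expanding a source dimension with stride S into result
  // dimensions of sizes (A, B) produces strides (B * S, S): the innermost
  // expanded dimension keeps the original stride and each outer one is
  // multiplied by the sizes to its right.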
  for (auto &en : llvm::enumerate(reassociation)) {
    auto currentStrideToExpand = srcDesc.stride(b, loc, en.index());
    for (auto dstIndex : llvm::reverse(en.value())) {
      dstDesc.setStride(b, loc, dstIndex, currentStrideToExpand);
      Value size = dstDesc.size(b, loc, dstIndex);
      currentStrideToExpand =
          b.create<LLVM::MulOp>(loc, size, currentStrideToExpand);
    }
  }
}

static void fillInStridesForCollapsedMemDescriptor(
    ConversionPatternRewriter &rewriter, Location loc, Operation *op,
    TypeConverter *typeConverter, MemRefType srcType, MemRefDescriptor &srcDesc,
    MemRefDescriptor &dstDesc, ArrayRef<ReassociationIndices> reassociation) {
  // See comments for computeCollapsedLayoutMap for details on how the strides
  // are calculated.
  auto srcShape = srcType.getShape();
  for (auto &en : llvm::enumerate(reassociation)) {
    rewriter.setInsertionPoint(op);
    auto dstIndex = en.index();
    ArrayRef<int64_t> ref = llvm::makeArrayRef(en.value());
    while (srcShape[ref.back()] == 1 && ref.size() > 1)
      ref = ref.drop_back();
    if (!ShapedType::isDynamic(srcShape[ref.back()]) || ref.size() == 1) {
      dstDesc.setStride(rewriter, loc, dstIndex,
                        srcDesc.stride(rewriter, loc, ref.back()));
    } else {
      // Iterate over the source strides in reverse order. Skip over the
      // dimensions whose size is 1.
      // TODO: we should take the minimum stride in the reassociation group
      // instead of just the first where the dimension is not 1.
      //
      // +------------------------------------------------------+
      // | curEntry:                                            |
      // |   %srcStride = strides[srcIndex]                     |
      // |   %neOne = cmp sizes[srcIndex],1                     +--+
      // |   cf.cond_br %neOne, continue(%srcStride), nextEntry |  |
      // +-------------------------+----------------------------+  |
      //                           |                               |
      //                           v                               |
      //             +-----------------------------+               |
      //             | nextEntry:                  |               |
      //             |   ...                       +---+           |
      //             +--------------+--------------+   |           |
      //                            |                  |           |
      //                            v                  |           |
      //             +-----------------------------+   |           |
      //             | nextEntry:                  |   |           |
      //             |   ...                       |   |           |
      //           +--------------+--------------+   |   +--------+
      //                          |                  |   |
      //                          v                  v   v
      //           +--------------------------------------------------+
      //           | continue(%newStride):                            |
      //           |   %newMemRefDes = setStride(%newStride,dstIndex) |
      //           +--------------------------------------------------+
      OpBuilder::InsertionGuard guard(rewriter);
      Block *initBlock = rewriter.getInsertionBlock();
      Block *continueBlock =
          rewriter.splitBlock(initBlock, rewriter.getInsertionPoint());
      continueBlock->insertArgument(unsigned(0), srcDesc.getIndexType(), loc);
      rewriter.setInsertionPointToStart(continueBlock);
      dstDesc.setStride(rewriter, loc, dstIndex, continueBlock->getArgument(0));

      Block *curEntryBlock = initBlock;
      Block *nextEntryBlock;
      for (auto srcIndex : llvm::reverse(ref)) {
        if (srcShape[srcIndex] == 1 && srcIndex != ref.front())
          continue;
        rewriter.setInsertionPointToEnd(curEntryBlock);
        Value srcStride = srcDesc.stride(rewriter, loc, srcIndex);
        if (srcIndex == ref.front()) {
          rewriter.create<LLVM::BrOp>(loc, srcStride, continueBlock);
          break;
        }
        Value one = rewriter.create<LLVM::ConstantOp>(
            loc, typeConverter->convertType(rewriter.getI64Type()),
            rewriter.getI32IntegerAttr(1));
        Value predNeOne = rewriter.create<LLVM::ICmpOp>(
            loc, LLVM::ICmpPredicate::ne, srcDesc.size(rewriter, loc, srcIndex),
            one);
        {
          OpBuilder::InsertionGuard guard(rewriter);
          nextEntryBlock = rewriter.createBlock(
              initBlock->getParent(), Region::iterator(continueBlock), {});
        }
        rewriter.create<LLVM::CondBrOp>(loc, predNeOne, continueBlock,
                                        srcStride, nextEntryBlock, llvm::None);
        curEntryBlock = nextEntryBlock;
      }
    }
  }
}

static void fillInDynamicStridesForMemDescriptor(
    ConversionPatternRewriter &b, Location loc, Operation *op,
    TypeConverter *typeConverter, MemRefType srcType, MemRefType dstType,
    MemRefDescriptor &srcDesc, MemRefDescriptor &dstDesc,
    ArrayRef<ReassociationIndices> reassociation) {
  if (srcType.getRank() > dstType.getRank())
    fillInStridesForCollapsedMemDescriptor(b, loc, op, typeConverter, srcType,
                                           srcDesc, dstDesc, reassociation);
  else
    fillInStridesForExpandedMemDescriptor(b, loc, srcType, srcDesc, dstDesc,
                                          reassociation);
}

// ReshapeOp creates a new view descriptor of the proper rank.
// Result sizes are taken from the static result shape where known and are
// otherwise recomputed from the source descriptor; strides are either the
// static result strides or are recomputed dynamically.
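// For example (schematically), a collapse such as
//   %1 = memref.collapse_shape %0 [[0, 1]] : memref<4x?xf32> into memref<?xf32>
// only rewrites the descriptor: the allocated/aligned pointers and the offset
// are copied from %0, the collapsed size is computed as 4 * dim(%0, 1), and
// the stride is the statically known result stride when available, or is
// recomputed from the source strides otherwise.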
template <typename ReshapeOp>
class ReassociatingReshapeOpConversion
    : public ConvertOpToLLVMPattern<ReshapeOp> {
public:
  using ConvertOpToLLVMPattern<ReshapeOp>::ConvertOpToLLVMPattern;
  using ReshapeOpAdaptor = typename ReshapeOp::Adaptor;

  LogicalResult
  matchAndRewrite(ReshapeOp reshapeOp, typename ReshapeOp::Adaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    MemRefType dstType = reshapeOp.getResultType();
    MemRefType srcType = reshapeOp.getSrcType();

    int64_t offset;
    SmallVector<int64_t, 4> strides;
    if (failed(getStridesAndOffset(dstType, strides, offset))) {
      return rewriter.notifyMatchFailure(
          reshapeOp, "failed to get stride and offset exprs");
    }

    MemRefDescriptor srcDesc(adaptor.src());
    Location loc = reshapeOp->getLoc();
    auto dstDesc = MemRefDescriptor::undef(
        rewriter, loc, this->typeConverter->convertType(dstType));
    dstDesc.setAllocatedPtr(rewriter, loc, srcDesc.allocatedPtr(rewriter, loc));
    dstDesc.setAlignedPtr(rewriter, loc, srcDesc.alignedPtr(rewriter, loc));
    dstDesc.setOffset(rewriter, loc, srcDesc.offset(rewriter, loc));

    ArrayRef<int64_t> srcStaticShape = srcType.getShape();
    ArrayRef<int64_t> dstStaticShape = dstType.getShape();
    Type llvmIndexType =
        this->typeConverter->convertType(rewriter.getIndexType());
    SmallVector<Value> dstShape = getDynamicOutputShape(
        rewriter, loc, llvmIndexType, reshapeOp.getReassociationIndices(),
        srcStaticShape, srcDesc, dstStaticShape);
    for (auto &en : llvm::enumerate(dstShape))
      dstDesc.setSize(rewriter, loc, en.index(), en.value());

    auto isStaticStride = [](int64_t stride) {
      return !ShapedType::isDynamicStrideOrOffset(stride);
    };
    if (llvm::all_of(strides, isStaticStride)) {
      for (auto &en : llvm::enumerate(strides))
        dstDesc.setConstantStride(rewriter, loc, en.index(), en.value());
    } else if (srcType.getLayout().isIdentity() &&
               dstType.getLayout().isIdentity()) {
      Value c1 = rewriter.create<LLVM::ConstantOp>(loc, llvmIndexType,
                                                   rewriter.getIndexAttr(1));
      Value stride = c1;
      for (auto dimIndex :
           llvm::reverse(llvm::seq<int64_t>(0, dstShape.size()))) {
        dstDesc.setStride(rewriter, loc, dimIndex, stride);
        stride = rewriter.create<LLVM::MulOp>(loc, dstShape[dimIndex], stride);
      }
    } else {
      // There could be mixed static/dynamic strides. For simplicity, we
      // recompute all strides if there is at least one dynamic stride.
      fillInDynamicStridesForMemDescriptor(
          rewriter, loc, reshapeOp, this->typeConverter, srcType, dstType,
          srcDesc, dstDesc, reshapeOp.getReassociationIndices());
    }
    rewriter.replaceOp(reshapeOp, {dstDesc});
    return success();
  }
};

/// Conversion pattern that transforms a subview op into:
///   1. An `llvm.mlir.undef` operation to create a memref descriptor
///   2. Updates to the descriptor to introduce the data ptr, offset, size
///      and stride.
/// The subview op is replaced by the descriptor.
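/// For example (schematically, with made-up shapes and `#map` standing for the
/// strided result layout):
///   %1 = memref.subview %0[%off, 4] [16, 4] [1, 1]
///       : memref<64x4xf32> to memref<16x4xf32, #map>
/// produces a descriptor that reuses the pointers of %0, sets
///   offset = offset(%0) + %off * stride(%0, 0) + 4 * stride(%0, 1)
/// and fills in the sizes and strides from the static or dynamic operands.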
struct SubViewOpLowering : public ConvertOpToLLVMPattern<memref::SubViewOp> {
  using ConvertOpToLLVMPattern<memref::SubViewOp>::ConvertOpToLLVMPattern;

  LogicalResult
  matchAndRewrite(memref::SubViewOp subViewOp, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    auto loc = subViewOp.getLoc();

    auto sourceMemRefType = subViewOp.source().getType().cast<MemRefType>();
    auto sourceElementTy =
        typeConverter->convertType(sourceMemRefType.getElementType());

    auto viewMemRefType = subViewOp.getType();
    auto inferredType =
        memref::SubViewOp::inferResultType(
            subViewOp.getSourceType(),
            extractFromI64ArrayAttr(subViewOp.static_offsets()),
            extractFromI64ArrayAttr(subViewOp.static_sizes()),
            extractFromI64ArrayAttr(subViewOp.static_strides()))
            .cast<MemRefType>();
    auto targetElementTy =
        typeConverter->convertType(viewMemRefType.getElementType());
    auto targetDescTy = typeConverter->convertType(viewMemRefType);
    if (!sourceElementTy || !targetDescTy || !targetElementTy ||
        !LLVM::isCompatibleType(sourceElementTy) ||
        !LLVM::isCompatibleType(targetElementTy) ||
        !LLVM::isCompatibleType(targetDescTy))
      return failure();

    // Extract the offset and strides from the type.
    int64_t offset;
    SmallVector<int64_t, 4> strides;
    auto successStrides = getStridesAndOffset(inferredType, strides, offset);
    if (failed(successStrides))
      return failure();

    // Create the descriptor.
    if (!LLVM::isCompatibleType(adaptor.getOperands().front().getType()))
      return failure();
    MemRefDescriptor sourceMemRef(adaptor.getOperands().front());
    auto targetMemRef = MemRefDescriptor::undef(rewriter, loc, targetDescTy);

    // Copy the buffer pointer from the old descriptor to the new one.
    Value extracted = sourceMemRef.allocatedPtr(rewriter, loc);
    Value bitcastPtr = rewriter.create<LLVM::BitcastOp>(
        loc,
        LLVM::LLVMPointerType::get(targetElementTy,
                                   viewMemRefType.getMemorySpaceAsInt()),
        extracted);
    targetMemRef.setAllocatedPtr(rewriter, loc, bitcastPtr);

    // Copy the aligned pointer from the old descriptor to the new one.
    extracted = sourceMemRef.alignedPtr(rewriter, loc);
    bitcastPtr = rewriter.create<LLVM::BitcastOp>(
        loc,
        LLVM::LLVMPointerType::get(targetElementTy,
                                   viewMemRefType.getMemorySpaceAsInt()),
        extracted);
    targetMemRef.setAlignedPtr(rewriter, loc, bitcastPtr);

    size_t inferredShapeRank = inferredType.getRank();
    size_t resultShapeRank = viewMemRefType.getRank();

    // Extract strides needed to compute offset.
    SmallVector<Value, 4> strideValues;
    strideValues.reserve(inferredShapeRank);
    for (unsigned i = 0; i < inferredShapeRank; ++i)
      strideValues.push_back(sourceMemRef.stride(rewriter, loc, i));

    // Offset.
    auto llvmIndexType = typeConverter->convertType(rewriter.getIndexType());
    if (!ShapedType::isDynamicStrideOrOffset(offset)) {
      targetMemRef.setConstantOffset(rewriter, loc, offset);
    } else {
      Value baseOffset = sourceMemRef.offset(rewriter, loc);
      // `inferredShapeRank` may be larger than the number of offset operands
      // because of trailing semantics. In this case, the offset is guaranteed
      // to be interpreted as 0 and we can just skip the extra dimensions.
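      // In other words, the dynamic offset is accumulated as
      //   baseOffset += offset_i * stride_i
      // over the leading dimensions that do have an offset operand.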
      for (unsigned i = 0, e = std::min(inferredShapeRank,
                                        subViewOp.getMixedOffsets().size());
           i < e; ++i) {
        Value offset =
            // TODO: need OpFoldResult ODS adaptor to clean this up.
            subViewOp.isDynamicOffset(i)
                ? adaptor.getOperands()[subViewOp.getIndexOfDynamicOffset(i)]
                : rewriter.create<LLVM::ConstantOp>(
                      loc, llvmIndexType,
                      rewriter.getI64IntegerAttr(subViewOp.getStaticOffset(i)));
        Value mul = rewriter.create<LLVM::MulOp>(loc, offset, strideValues[i]);
        baseOffset = rewriter.create<LLVM::AddOp>(loc, baseOffset, mul);
      }
      targetMemRef.setOffset(rewriter, loc, baseOffset);
    }

    // Update sizes and strides.
    SmallVector<OpFoldResult> mixedSizes = subViewOp.getMixedSizes();
    SmallVector<OpFoldResult> mixedStrides = subViewOp.getMixedStrides();
    assert(mixedSizes.size() == mixedStrides.size() &&
           "expected sizes and strides of equal length");
    llvm::SmallBitVector unusedDims = subViewOp.getDroppedDims();
    for (int i = inferredShapeRank - 1, j = resultShapeRank - 1;
         i >= 0 && j >= 0; --i) {
      if (unusedDims.test(i))
        continue;

      // `i` may index past the end of subViewOp.getMixedSizes() because of
      // trailing semantics. In this case, the size is guaranteed to be the
      // full source dimension size and the stride to be 1.
      Value size, stride;
      if (static_cast<unsigned>(i) >= mixedSizes.size()) {
        // If the static size is available, use it directly. This is similar to
        // the folding of dim(constant-op), but it removes the need for dim to
        // be aware of LLVM constants and for this pass to be aware of std
        // constants.
        int64_t staticSize =
            subViewOp.source().getType().cast<MemRefType>().getShape()[i];
        if (staticSize != ShapedType::kDynamicSize) {
          size = rewriter.create<LLVM::ConstantOp>(
              loc, llvmIndexType, rewriter.getI64IntegerAttr(staticSize));
        } else {
          Value pos = rewriter.create<LLVM::ConstantOp>(
              loc, llvmIndexType, rewriter.getI64IntegerAttr(i));
          Value dim =
              rewriter.create<memref::DimOp>(loc, subViewOp.source(), pos);
          auto cast = rewriter.create<UnrealizedConversionCastOp>(
              loc, llvmIndexType, dim);
          size = cast.getResult(0);
        }
        stride = rewriter.create<LLVM::ConstantOp>(
            loc, llvmIndexType, rewriter.getI64IntegerAttr(1));
      } else {
        // TODO: need OpFoldResult ODS adaptor to clean this up.
        size =
            subViewOp.isDynamicSize(i)
                ? adaptor.getOperands()[subViewOp.getIndexOfDynamicSize(i)]
                : rewriter.create<LLVM::ConstantOp>(
                      loc, llvmIndexType,
                      rewriter.getI64IntegerAttr(subViewOp.getStaticSize(i)));
        if (!ShapedType::isDynamicStrideOrOffset(strides[i])) {
          stride = rewriter.create<LLVM::ConstantOp>(
              loc, llvmIndexType, rewriter.getI64IntegerAttr(strides[i]));
        } else {
          stride =
              subViewOp.isDynamicStride(i)
                  ? adaptor.getOperands()[subViewOp.getIndexOfDynamicStride(i)]
                  : rewriter.create<LLVM::ConstantOp>(
                        loc, llvmIndexType,
                        rewriter.getI64IntegerAttr(
                            subViewOp.getStaticStride(i)));
          stride = rewriter.create<LLVM::MulOp>(loc, stride, strideValues[i]);
        }
      }
      targetMemRef.setSize(rewriter, loc, j, size);
      targetMemRef.setStride(rewriter, loc, j, stride);
      j--;
    }

    rewriter.replaceOp(subViewOp, {targetMemRef});
    return success();
  }
};

/// Conversion pattern that transforms a transpose op into:
///   1. An `llvm.mlir.undef` operation to create a memref descriptor.
///   2. Updates to the descriptor to introduce the data ptr, offset, size
///      and stride. Size and stride are permutations of the original values.
/// The transpose op is replaced by the descriptor.
class TransposeOpLowering : public ConvertOpToLLVMPattern<memref::TransposeOp> {
public:
  using ConvertOpToLLVMPattern<memref::TransposeOp>::ConvertOpToLLVMPattern;

  LogicalResult
  matchAndRewrite(memref::TransposeOp transposeOp, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    auto loc = transposeOp.getLoc();
    MemRefDescriptor viewMemRef(adaptor.in());

    // No permutation, early exit.
    if (transposeOp.permutation().isIdentity())
      return rewriter.replaceOp(transposeOp, {viewMemRef}), success();

    auto targetMemRef = MemRefDescriptor::undef(
        rewriter, loc, typeConverter->convertType(transposeOp.getShapedType()));

    // Copy the base and aligned pointers from the old descriptor to the new
    // one.
    targetMemRef.setAllocatedPtr(rewriter, loc,
                                 viewMemRef.allocatedPtr(rewriter, loc));
    targetMemRef.setAlignedPtr(rewriter, loc,
                               viewMemRef.alignedPtr(rewriter, loc));

    // Copy the offset from the old descriptor to the new one.
    targetMemRef.setOffset(rewriter, loc, viewMemRef.offset(rewriter, loc));

    // Iterate over the dimensions and apply the size/stride permutation.
    for (const auto &en :
         llvm::enumerate(transposeOp.permutation().getResults())) {
      int sourcePos = en.index();
      int targetPos = en.value().cast<AffineDimExpr>().getPosition();
      targetMemRef.setSize(rewriter, loc, targetPos,
                           viewMemRef.size(rewriter, loc, sourcePos));
      targetMemRef.setStride(rewriter, loc, targetPos,
                             viewMemRef.stride(rewriter, loc, sourcePos));
    }

    rewriter.replaceOp(transposeOp, {targetMemRef});
    return success();
  }
};

/// Conversion pattern that transforms a memref.view op into:
///   1. An `llvm.mlir.undef` operation to create a memref descriptor
///   2. Updates to the descriptor to introduce the data ptr, offset, size
///      and stride.
/// The view op is replaced by the descriptor.
struct ViewOpLowering : public ConvertOpToLLVMPattern<memref::ViewOp> {
  using ConvertOpToLLVMPattern<memref::ViewOp>::ConvertOpToLLVMPattern;

  // Build and return the value for the idx^th shape dimension, either by
  // returning the constant shape dimension or the proper dynamic size operand.
  Value getSize(ConversionPatternRewriter &rewriter, Location loc,
                ArrayRef<int64_t> shape, ValueRange dynamicSizes,
                unsigned idx) const {
    assert(idx < shape.size());
    if (!ShapedType::isDynamic(shape[idx]))
      return createIndexConstant(rewriter, loc, shape[idx]);
    // Count the number of dynamic dims in range [0, idx).
    unsigned nDynamic = llvm::count_if(shape.take_front(idx), [](int64_t v) {
      return ShapedType::isDynamic(v);
    });
    return dynamicSizes[nDynamic];
  }

  // Build and return the idx^th stride, either by returning the constant
  // stride or by computing the dynamic stride from the current
  // `runningStride` and `nextSize`.
  // The caller should keep a running stride and update it with the result
  // returned by this function.
  Value getStride(ConversionPatternRewriter &rewriter, Location loc,
                  ArrayRef<int64_t> strides, Value nextSize,
                  Value runningStride, unsigned idx) const {
    assert(idx < strides.size());
    if (!ShapedType::isDynamicStrideOrOffset(strides[idx]))
      return createIndexConstant(rewriter, loc, strides[idx]);
    if (nextSize)
      return runningStride
                 ? rewriter.create<LLVM::MulOp>(loc, runningStride, nextSize)
                 : nextSize;
    assert(!runningStride);
    return createIndexConstant(rewriter, loc, 1);
  }

  LogicalResult
  matchAndRewrite(memref::ViewOp viewOp, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    auto loc = viewOp.getLoc();

    auto viewMemRefType = viewOp.getType();
    auto targetElementTy =
        typeConverter->convertType(viewMemRefType.getElementType());
    auto targetDescTy = typeConverter->convertType(viewMemRefType);
    if (!targetDescTy || !targetElementTy ||
        !LLVM::isCompatibleType(targetElementTy) ||
        !LLVM::isCompatibleType(targetDescTy))
      return viewOp.emitWarning("Target descriptor type not converted to LLVM"),
             failure();

    int64_t offset;
    SmallVector<int64_t, 4> strides;
    auto successStrides = getStridesAndOffset(viewMemRefType, strides, offset);
    if (failed(successStrides))
      return viewOp.emitWarning("cannot cast to non-strided shape"), failure();
    assert(offset == 0 && "expected offset to be 0");

    // Create the descriptor.
    MemRefDescriptor sourceMemRef(adaptor.source());
    auto targetMemRef = MemRefDescriptor::undef(rewriter, loc, targetDescTy);

    // Field 1: Copy the allocated pointer, used for malloc/free.
    Value allocatedPtr = sourceMemRef.allocatedPtr(rewriter, loc);
    auto srcMemRefType = viewOp.source().getType().cast<MemRefType>();
    Value bitcastPtr = rewriter.create<LLVM::BitcastOp>(
        loc,
        LLVM::LLVMPointerType::get(targetElementTy,
                                   srcMemRefType.getMemorySpaceAsInt()),
        allocatedPtr);
    targetMemRef.setAllocatedPtr(rewriter, loc, bitcastPtr);

    // Field 2: Copy the actual aligned pointer to payload.
    Value alignedPtr = sourceMemRef.alignedPtr(rewriter, loc);
    alignedPtr = rewriter.create<LLVM::GEPOp>(loc, alignedPtr.getType(),
                                              alignedPtr, adaptor.byte_shift());
    bitcastPtr = rewriter.create<LLVM::BitcastOp>(
        loc,
        LLVM::LLVMPointerType::get(targetElementTy,
                                   srcMemRefType.getMemorySpaceAsInt()),
        alignedPtr);
    targetMemRef.setAlignedPtr(rewriter, loc, bitcastPtr);

    // Field 3: The offset in the resulting type must be 0. This is because of
    // the type change: an offset on srcType* may not be expressible as an
    // offset on dstType*.
    targetMemRef.setOffset(rewriter, loc,
                           createIndexConstant(rewriter, loc, offset));

    // Early exit for 0-D corner case.
    if (viewMemRefType.getRank() == 0)
      return rewriter.replaceOp(viewOp, {targetMemRef}), success();

    // Fields 4 and 5: Update sizes and strides.
    if (strides.back() != 1)
      return viewOp.emitWarning("cannot cast to non-contiguous shape"),
             failure();
    Value stride = nullptr, nextSize = nullptr;
    for (int i = viewMemRefType.getRank() - 1; i >= 0; --i) {
      // Update size.
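      // Note: `getSize` resolves dimension i either to a constant (static
      // shape) or to the matching entry of `adaptor.sizes()`, which holds the
      // dynamic sizes in source dimension order.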
      Value size =
          getSize(rewriter, loc, viewMemRefType.getShape(), adaptor.sizes(), i);
      targetMemRef.setSize(rewriter, loc, i, size);
      // Update stride.
      stride = getStride(rewriter, loc, strides, nextSize, stride, i);
      targetMemRef.setStride(rewriter, loc, i, stride);
      nextSize = size;
    }

    rewriter.replaceOp(viewOp, {targetMemRef});
    return success();
  }
};

//===----------------------------------------------------------------------===//
// AtomicRMWOpLowering
//===----------------------------------------------------------------------===//

/// Try to match the kind of a memref.atomic_rmw to determine whether to use a
/// lowering to llvm.atomicrmw or fall back to llvm.cmpxchg.
static Optional<LLVM::AtomicBinOp>
matchSimpleAtomicOp(memref::AtomicRMWOp atomicOp) {
  switch (atomicOp.kind()) {
  case arith::AtomicRMWKind::addf:
    return LLVM::AtomicBinOp::fadd;
  case arith::AtomicRMWKind::addi:
    return LLVM::AtomicBinOp::add;
  case arith::AtomicRMWKind::assign:
    return LLVM::AtomicBinOp::xchg;
  case arith::AtomicRMWKind::maxs:
    return LLVM::AtomicBinOp::max;
  case arith::AtomicRMWKind::maxu:
    return LLVM::AtomicBinOp::umax;
  case arith::AtomicRMWKind::mins:
    return LLVM::AtomicBinOp::min;
  case arith::AtomicRMWKind::minu:
    return LLVM::AtomicBinOp::umin;
  case arith::AtomicRMWKind::ori:
    return LLVM::AtomicBinOp::_or;
  case arith::AtomicRMWKind::andi:
    return LLVM::AtomicBinOp::_and;
  default:
    return llvm::None;
  }
  llvm_unreachable("Invalid AtomicRMWKind");
}

struct AtomicRMWOpLowering : public LoadStoreOpLowering<memref::AtomicRMWOp> {
  using Base::Base;

  LogicalResult
  matchAndRewrite(memref::AtomicRMWOp atomicOp, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    if (failed(match(atomicOp)))
      return failure();
    auto maybeKind = matchSimpleAtomicOp(atomicOp);
    if (!maybeKind)
      return failure();
    auto resultType = adaptor.value().getType();
    auto memRefType = atomicOp.getMemRefType();
    auto dataPtr =
        getStridedElementPtr(atomicOp.getLoc(), memRefType, adaptor.memref(),
                             adaptor.indices(), rewriter);
    rewriter.replaceOpWithNewOp<LLVM::AtomicRMWOp>(
        atomicOp, resultType, *maybeKind, dataPtr, adaptor.value(),
        LLVM::AtomicOrdering::acq_rel);
    return success();
  }
};

} // namespace

void mlir::populateMemRefToLLVMConversionPatterns(LLVMTypeConverter &converter,
                                                  RewritePatternSet &patterns) {
  // clang-format off
  patterns.add<
      AllocaOpLowering,
      AllocaScopeOpLowering,
      AtomicRMWOpLowering,
      AssumeAlignmentOpLowering,
      DimOpLowering,
      GenericAtomicRMWOpLowering,
      GlobalMemrefOpLowering,
      GetGlobalMemrefOpLowering,
      LoadOpLowering,
      MemRefCastOpLowering,
      MemRefCopyOpLowering,
      MemRefReinterpretCastOpLowering,
      MemRefReshapeOpLowering,
      PrefetchOpLowering,
      RankOpLowering,
      ReassociatingReshapeOpConversion<memref::ExpandShapeOp>,
      ReassociatingReshapeOpConversion<memref::CollapseShapeOp>,
      StoreOpLowering,
      SubViewOpLowering,
      TransposeOpLowering,
      ViewOpLowering>(converter);
  // clang-format on
  auto allocLowering = converter.getOptions().allocLowering;
  if (allocLowering == LowerToLLVMOptions::AllocLowering::AlignedAlloc)
    patterns.add<AlignedAllocOpLowering, DeallocOpLowering>(converter);
  else if (allocLowering == LowerToLLVMOptions::AllocLowering::Malloc)
    patterns.add<AllocOpLowering, DeallocOpLowering>(converter);
}

namespace {
struct MemRefToLLVMPass : public ConvertMemRefToLLVMBase<MemRefToLLVMPass> {
  MemRefToLLVMPass() = default;

  void runOnOperation() override {
    Operation *op = getOperation();
    const auto &dataLayoutAnalysis = getAnalysis<DataLayoutAnalysis>();
    LowerToLLVMOptions options(&getContext(),
                               dataLayoutAnalysis.getAtOrAbove(op));
    options.allocLowering =
        (useAlignedAlloc ? LowerToLLVMOptions::AllocLowering::AlignedAlloc
                         : LowerToLLVMOptions::AllocLowering::Malloc);
    if (indexBitwidth != kDeriveIndexBitwidthFromDataLayout)
      options.overrideIndexBitwidth(indexBitwidth);

    LLVMTypeConverter typeConverter(&getContext(), options,
                                    &dataLayoutAnalysis);
    RewritePatternSet patterns(&getContext());
    populateMemRefToLLVMConversionPatterns(typeConverter, patterns);
    LLVMConversionTarget target(getContext());
    target.addLegalOp<func::FuncOp>();
    if (failed(applyPartialConversion(op, target, std::move(patterns))))
      signalPassFailure();
  }
};
} // namespace

std::unique_ptr<Pass> mlir::createMemRefToLLVMPass() {
  return std::make_unique<MemRefToLLVMPass>();
}
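
// Usage sketch (illustrative only, not part of the conversion itself): the
// pass created above can be scheduled from C++ roughly as follows, assuming
// the embedding tool sets up the PassManager and additionally includes
// "mlir/Pass/PassManager.h":
//
//   void addMemRefToLLVMLowering(mlir::PassManager &pm) {
//     pm.addPass(mlir::createMemRefToLLVMPass());
//   }
//
// The same lowering is typically exercised from the command line through
// mlir-opt with the pass's registered flag (convert-memref-to-llvm).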