//===- MemRefToLLVM.cpp - MemRef to LLVM dialect conversion ---------------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//

#include "mlir/Conversion/MemRefToLLVM/MemRefToLLVM.h"
#include "../PassDetail.h"
#include "mlir/Analysis/DataLayoutAnalysis.h"
#include "mlir/Conversion/LLVMCommon/ConversionTarget.h"
#include "mlir/Conversion/LLVMCommon/Pattern.h"
#include "mlir/Conversion/LLVMCommon/TypeConverter.h"
#include "mlir/Conversion/MemRefToLLVM/AllocLikeConversion.h"
#include "mlir/Dialect/LLVMIR/FunctionCallUtils.h"
#include "mlir/Dialect/LLVMIR/LLVMDialect.h"
#include "mlir/Dialect/MemRef/IR/MemRef.h"
#include "mlir/IR/AffineMap.h"
#include "mlir/IR/BlockAndValueMapping.h"
#include "llvm/ADT/SmallBitVector.h"

using namespace mlir;

namespace {

struct AllocOpLowering : public AllocLikeOpLLVMLowering {
  AllocOpLowering(LLVMTypeConverter &converter)
      : AllocLikeOpLLVMLowering(memref::AllocOp::getOperationName(),
                                converter) {}

  std::tuple<Value, Value> allocateBuffer(ConversionPatternRewriter &rewriter,
                                          Location loc, Value sizeBytes,
                                          Operation *op) const override {
    // Heap allocations.
    memref::AllocOp allocOp = cast<memref::AllocOp>(op);
    MemRefType memRefType = allocOp.getType();

    Value alignment;
    if (auto alignmentAttr = allocOp.alignment()) {
      alignment = createIndexConstant(rewriter, loc, *alignmentAttr);
    } else if (!memRefType.getElementType().isSignlessIntOrIndexOrFloat()) {
      // In the case where no alignment is specified, we may want to override
      // `malloc's` behavior. `malloc` typically aligns at the size of the
      // biggest scalar on a target HW. For non-scalars, use the natural
      // alignment of the LLVM type given by the LLVM DataLayout.
      alignment = getSizeInBytes(loc, memRefType.getElementType(), rewriter);
    }

    if (alignment) {
      // Adjust the allocation size to consider alignment.
      sizeBytes = rewriter.create<LLVM::AddOp>(loc, sizeBytes, alignment);
    }

    // Allocate the underlying buffer and store a pointer to it in the MemRef
    // descriptor.
    Type elementPtrType = this->getElementPtrType(memRefType);
    auto allocFuncOp = LLVM::lookupOrCreateMallocFn(
        allocOp->getParentOfType<ModuleOp>(), getIndexType());
    auto results = createLLVMCall(rewriter, loc, allocFuncOp, {sizeBytes},
                                  getVoidPtrType());
    Value allocatedPtr =
        rewriter.create<LLVM::BitcastOp>(loc, elementPtrType, results[0]);

    Value alignedPtr = allocatedPtr;
    if (alignment) {
      // Compute the aligned type pointer.
      Value allocatedInt =
          rewriter.create<LLVM::PtrToIntOp>(loc, getIndexType(), allocatedPtr);
      Value alignmentInt =
          createAligned(rewriter, loc, allocatedInt, alignment);
      alignedPtr =
          rewriter.create<LLVM::IntToPtrOp>(loc, elementPtrType, alignmentInt);
    }

    return std::make_tuple(allocatedPtr, alignedPtr);
  }
};

struct AlignedAllocOpLowering : public AllocLikeOpLLVMLowering {
  AlignedAllocOpLowering(LLVMTypeConverter &converter)
      : AllocLikeOpLLVMLowering(memref::AllocOp::getOperationName(),
                                converter) {}

  /// Returns the memref's element size in bytes using the data layout active
  /// at `op`.
  // TODO: there are other places where this is used. Expose publicly?
  unsigned getMemRefEltSizeInBytes(MemRefType memRefType, Operation *op) const {
    const DataLayout *layout = &defaultLayout;
    if (const DataLayoutAnalysis *analysis =
            getTypeConverter()->getDataLayoutAnalysis()) {
      layout = &analysis->getAbove(op);
    }
    Type elementType = memRefType.getElementType();
    if (auto memRefElementType = elementType.dyn_cast<MemRefType>())
      return getTypeConverter()->getMemRefDescriptorSize(memRefElementType,
                                                         *layout);
    if (auto memRefElementType = elementType.dyn_cast<UnrankedMemRefType>())
      return getTypeConverter()->getUnrankedMemRefDescriptorSize(
          memRefElementType, *layout);
    return layout->getTypeSize(elementType);
  }

  /// Returns true if the memref size in bytes is known to be a multiple of
  /// `factor`, assuming the data layout active at `op`.
  bool isMemRefSizeMultipleOf(MemRefType type, uint64_t factor,
                              Operation *op) const {
    uint64_t sizeDivisor = getMemRefEltSizeInBytes(type, op);
    for (unsigned i = 0, e = type.getRank(); i < e; i++) {
      if (ShapedType::isDynamic(type.getDimSize(i)))
        continue;
      sizeDivisor = sizeDivisor * type.getDimSize(i);
    }
    return sizeDivisor % factor == 0;
  }

  /// Returns the alignment to be used for the allocation call itself.
  /// aligned_alloc requires the alignment to be a power of two and the
  /// allocation size to be a multiple of that alignment.
  int64_t getAllocationAlignment(memref::AllocOp allocOp) const {
    if (Optional<uint64_t> alignment = allocOp.alignment())
      return *alignment;

    // Whenever we don't have alignment set, we will use an alignment
    // consistent with the element type; since the alignment has to be a power
    // of two, we bump the element size to the next power of two if it isn't
    // one already.
    auto eltSizeBytes = getMemRefEltSizeInBytes(allocOp.getType(), allocOp);
    return std::max(kMinAlignedAllocAlignment,
                    llvm::PowerOf2Ceil(eltSizeBytes));
  }

  std::tuple<Value, Value> allocateBuffer(ConversionPatternRewriter &rewriter,
                                          Location loc, Value sizeBytes,
                                          Operation *op) const override {
    // Heap allocations.
    memref::AllocOp allocOp = cast<memref::AllocOp>(op);
    MemRefType memRefType = allocOp.getType();
    int64_t alignment = getAllocationAlignment(allocOp);
    Value allocAlignment = createIndexConstant(rewriter, loc, alignment);

    // aligned_alloc requires the size to be a multiple of the alignment; pad
    // the size to the next multiple if necessary.
    if (!isMemRefSizeMultipleOf(memRefType, alignment, op))
      sizeBytes = createAligned(rewriter, loc, sizeBytes, allocAlignment);

    Type elementPtrType = this->getElementPtrType(memRefType);
    auto allocFuncOp = LLVM::lookupOrCreateAlignedAllocFn(
        allocOp->getParentOfType<ModuleOp>(), getIndexType());
    auto results =
        createLLVMCall(rewriter, loc, allocFuncOp, {allocAlignment, sizeBytes},
                       getVoidPtrType());
    Value allocatedPtr =
        rewriter.create<LLVM::BitcastOp>(loc, elementPtrType, results[0]);

    return std::make_tuple(allocatedPtr, allocatedPtr);
  }

  /// The minimum alignment to use with aligned_alloc (has to be a power of 2).
  static constexpr uint64_t kMinAlignedAllocAlignment = 16UL;

  /// Default layout to use in absence of the corresponding analysis.
  DataLayout defaultLayout;
};

// Out-of-line definition, required until C++17.
constexpr uint64_t AlignedAllocOpLowering::kMinAlignedAllocAlignment;

struct AllocaOpLowering : public AllocLikeOpLLVMLowering {
  AllocaOpLowering(LLVMTypeConverter &converter)
      : AllocLikeOpLLVMLowering(memref::AllocaOp::getOperationName(),
                                converter) {}

  /// Allocates the underlying buffer using the right call. `allocatedBytePtr`
  /// is set to null for stack allocations. `accessAlignment` is set if
  /// alignment is needed post allocation (e.g., in conjunction with malloc).
  std::tuple<Value, Value> allocateBuffer(ConversionPatternRewriter &rewriter,
                                          Location loc, Value sizeBytes,
                                          Operation *op) const override {

    // With alloca, one gets a pointer to the element type right away.
    // For stack allocations.
    auto allocaOp = cast<memref::AllocaOp>(op);
    auto elementPtrType = this->getElementPtrType(allocaOp.getType());

    auto allocatedElementPtr = rewriter.create<LLVM::AllocaOp>(
        loc, elementPtrType, sizeBytes,
        allocaOp.alignment() ? *allocaOp.alignment() : 0);

    return std::make_tuple(allocatedElementPtr, allocatedElementPtr);
  }
};

struct AllocaScopeOpLowering
    : public ConvertOpToLLVMPattern<memref::AllocaScopeOp> {
  using ConvertOpToLLVMPattern<memref::AllocaScopeOp>::ConvertOpToLLVMPattern;

  LogicalResult
  matchAndRewrite(memref::AllocaScopeOp allocaScopeOp, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    OpBuilder::InsertionGuard guard(rewriter);
    Location loc = allocaScopeOp.getLoc();

    // Split the current block before the AllocaScopeOp to create the inlining
    // point.
    auto *currentBlock = rewriter.getInsertionBlock();
    auto *remainingOpsBlock =
        rewriter.splitBlock(currentBlock, rewriter.getInsertionPoint());
    Block *continueBlock;
    if (allocaScopeOp.getNumResults() == 0) {
      continueBlock = remainingOpsBlock;
    } else {
      continueBlock = rewriter.createBlock(
          remainingOpsBlock, allocaScopeOp.getResultTypes(),
          SmallVector<Location>(allocaScopeOp->getNumResults(),
                                allocaScopeOp.getLoc()));
      rewriter.create<LLVM::BrOp>(loc, ValueRange(), remainingOpsBlock);
    }

    // Inline the body region.
    Block *beforeBody = &allocaScopeOp.bodyRegion().front();
    Block *afterBody = &allocaScopeOp.bodyRegion().back();
    rewriter.inlineRegionBefore(allocaScopeOp.bodyRegion(), continueBlock);

    // Save the stack and then branch into the body of the region.
    rewriter.setInsertionPointToEnd(currentBlock);
    auto stackSaveOp =
        rewriter.create<LLVM::StackSaveOp>(loc, getVoidPtrType());
    rewriter.create<LLVM::BrOp>(loc, ValueRange(), beforeBody);

    // Replace the alloca_scope return with a branch that jumps out of the
    // body. Restore the stack before leaving the body region.
    rewriter.setInsertionPointToEnd(afterBody);
    auto returnOp =
        cast<memref::AllocaScopeReturnOp>(afterBody->getTerminator());
    auto branchOp = rewriter.replaceOpWithNewOp<LLVM::BrOp>(
        returnOp, returnOp.results(), continueBlock);

    // Insert stack restore before jumping out of the body of the region.
    rewriter.setInsertionPoint(branchOp);
    rewriter.create<LLVM::StackRestoreOp>(loc, stackSaveOp);

    // Replace the op with the values returned from the body region.
    rewriter.replaceOp(allocaScopeOp, continueBlock->getArguments());

    return success();
  }
};

struct AssumeAlignmentOpLowering
    : public ConvertOpToLLVMPattern<memref::AssumeAlignmentOp> {
  using ConvertOpToLLVMPattern<
      memref::AssumeAlignmentOp>::ConvertOpToLLVMPattern;

  LogicalResult
  matchAndRewrite(memref::AssumeAlignmentOp op, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    Value memref = adaptor.memref();
    unsigned alignment = op.alignment();
    auto loc = op.getLoc();

    MemRefDescriptor memRefDescriptor(memref);
    Value ptr = memRefDescriptor.alignedPtr(rewriter, memref.getLoc());

    // Emit llvm.assume(memref.alignedPtr & (alignment - 1) == 0). Notice that
    // the asserted memref.alignedPtr isn't used anywhere else, as the real
    // users like load/store/views always re-extract memref.alignedPtr as they
    // get lowered.
    //
    // This relies on LLVM's CSE optimization (potentially after SROA), since
    // after CSE all memref.alignedPtr instances get de-duplicated into the same
    // pointer SSA value.
    auto intPtrType =
        getIntPtrType(memRefDescriptor.getElementPtrType().getAddressSpace());
    Value zero = createIndexAttrConstant(rewriter, loc, intPtrType, 0);
    Value mask =
        createIndexAttrConstant(rewriter, loc, intPtrType, alignment - 1);
    Value ptrValue = rewriter.create<LLVM::PtrToIntOp>(loc, intPtrType, ptr);
    rewriter.create<LLVM::AssumeOp>(
        loc, rewriter.create<LLVM::ICmpOp>(
                 loc, LLVM::ICmpPredicate::eq,
                 rewriter.create<LLVM::AndOp>(loc, ptrValue, mask), zero));

    rewriter.eraseOp(op);
    return success();
  }
};

// A `dealloc` is converted into a call to `free` on the underlying data
// buffer. The memref descriptor being an SSA value, there is no need to clean
// it up in any way.
struct DeallocOpLowering : public ConvertOpToLLVMPattern<memref::DeallocOp> {
  using ConvertOpToLLVMPattern<memref::DeallocOp>::ConvertOpToLLVMPattern;

  explicit DeallocOpLowering(LLVMTypeConverter &converter)
      : ConvertOpToLLVMPattern<memref::DeallocOp>(converter) {}

  LogicalResult
  matchAndRewrite(memref::DeallocOp op, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    // Insert the `free` declaration if it is not already present.
    auto freeFunc = LLVM::lookupOrCreateFreeFn(op->getParentOfType<ModuleOp>());
    MemRefDescriptor memref(adaptor.memref());
    Value casted = rewriter.create<LLVM::BitcastOp>(
        op.getLoc(), getVoidPtrType(),
        memref.allocatedPtr(rewriter, op.getLoc()));
    rewriter.replaceOpWithNewOp<LLVM::CallOp>(
        op, TypeRange(), SymbolRefAttr::get(freeFunc), casted);
    return success();
  }
};

// A `dim` is converted to a constant for static sizes and to an access to the
// size stored in the memref descriptor for dynamic sizes.
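//
// As an illustrative sketch only (types abbreviated, not verbatim output of
// this pattern): a dynamic dimension query such as
//   %d = memref.dim %m, %c1 : memref<4x?xf32>
// reads the corresponding size field of the converted descriptor, roughly
// `llvm.extractvalue %desc[3, 1]`, while a static dimension simply becomes an
// llvm.mlir.constant.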
struct DimOpLowering : public ConvertOpToLLVMPattern<memref::DimOp> {
  using ConvertOpToLLVMPattern<memref::DimOp>::ConvertOpToLLVMPattern;

  LogicalResult
  matchAndRewrite(memref::DimOp dimOp, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    Type operandType = dimOp.source().getType();
    if (operandType.isa<UnrankedMemRefType>()) {
      rewriter.replaceOp(
          dimOp, {extractSizeOfUnrankedMemRef(
                     operandType, dimOp, adaptor.getOperands(), rewriter)});

      return success();
    }
    if (operandType.isa<MemRefType>()) {
      rewriter.replaceOp(
          dimOp, {extractSizeOfRankedMemRef(operandType, dimOp,
                                            adaptor.getOperands(), rewriter)});
      return success();
    }
    llvm_unreachable("expected MemRefType or UnrankedMemRefType");
  }

private:
  Value extractSizeOfUnrankedMemRef(Type operandType, memref::DimOp dimOp,
                                    OpAdaptor adaptor,
                                    ConversionPatternRewriter &rewriter) const {
    Location loc = dimOp.getLoc();

    auto unrankedMemRefType = operandType.cast<UnrankedMemRefType>();
    auto scalarMemRefType =
        MemRefType::get({}, unrankedMemRefType.getElementType());
    unsigned addressSpace = unrankedMemRefType.getMemorySpaceAsInt();

    // Extract pointer to the underlying ranked descriptor and bitcast it to a
    // memref<element_type> descriptor pointer to minimize the number of GEP
    // operations.
    UnrankedMemRefDescriptor unrankedDesc(adaptor.source());
    Value underlyingRankedDesc = unrankedDesc.memRefDescPtr(rewriter, loc);
    Value scalarMemRefDescPtr = rewriter.create<LLVM::BitcastOp>(
        loc,
        LLVM::LLVMPointerType::get(typeConverter->convertType(scalarMemRefType),
                                   addressSpace),
        underlyingRankedDesc);

    // Get pointer to offset field of memref<element_type> descriptor.
    Type indexPtrTy = LLVM::LLVMPointerType::get(
        getTypeConverter()->getIndexType(), addressSpace);
    Value two = rewriter.create<LLVM::ConstantOp>(
        loc, typeConverter->convertType(rewriter.getI32Type()),
        rewriter.getI32IntegerAttr(2));
    Value offsetPtr = rewriter.create<LLVM::GEPOp>(
        loc, indexPtrTy, scalarMemRefDescPtr,
        ValueRange({createIndexConstant(rewriter, loc, 0), two}));

    // The size value that we have to extract can be obtained using GEPOp with
    // `dimOp.index() + 1` as the index argument.
    Value idxPlusOne = rewriter.create<LLVM::AddOp>(
        loc, createIndexConstant(rewriter, loc, 1), adaptor.index());
    Value sizePtr = rewriter.create<LLVM::GEPOp>(loc, indexPtrTy, offsetPtr,
                                                 ValueRange({idxPlusOne}));
    return rewriter.create<LLVM::LoadOp>(loc, sizePtr);
  }

  Optional<int64_t> getConstantDimIndex(memref::DimOp dimOp) const {
    if (Optional<int64_t> idx = dimOp.getConstantIndex())
      return idx;

    if (auto constantOp = dimOp.index().getDefiningOp<LLVM::ConstantOp>())
      return constantOp.getValue()
          .cast<IntegerAttr>()
          .getValue()
          .getSExtValue();

    return llvm::None;
  }

  Value extractSizeOfRankedMemRef(Type operandType, memref::DimOp dimOp,
                                  OpAdaptor adaptor,
                                  ConversionPatternRewriter &rewriter) const {
    Location loc = dimOp.getLoc();

    // Take advantage if the index is constant.
    MemRefType memRefType = operandType.cast<MemRefType>();
    if (Optional<int64_t> index = getConstantDimIndex(dimOp)) {
      int64_t i = index.getValue();
      if (memRefType.isDynamicDim(i)) {
        // Extract the dynamic size from the memref descriptor.
        MemRefDescriptor descriptor(adaptor.source());
        return descriptor.size(rewriter, loc, i);
      }
      // Use constant for static size.
      int64_t dimSize = memRefType.getDimSize(i);
      return createIndexConstant(rewriter, loc, dimSize);
    }
    Value index = adaptor.index();
    int64_t rank = memRefType.getRank();
    MemRefDescriptor memrefDescriptor(adaptor.source());
    return memrefDescriptor.size(rewriter, loc, index, rank);
  }
};

/// Common base for load and store operations on MemRefs. Restricts the match
/// to supported MemRef types. Provides functionality to emit code accessing a
/// specific element of the underlying data buffer.
template <typename Derived>
struct LoadStoreOpLowering : public ConvertOpToLLVMPattern<Derived> {
  using ConvertOpToLLVMPattern<Derived>::ConvertOpToLLVMPattern;
  using ConvertOpToLLVMPattern<Derived>::isConvertibleAndHasIdentityMaps;
  using Base = LoadStoreOpLowering<Derived>;

  LogicalResult match(Derived op) const override {
    MemRefType type = op.getMemRefType();
    return isConvertibleAndHasIdentityMaps(type) ? success() : failure();
  }
};

/// Wrap a llvm.cmpxchg operation in a while loop so that the operation can be
/// retried until it succeeds in atomically storing a new value into memory.
///
///      +---------------------------------+
///      |  <code before the AtomicRMWOp>  |
///      |  <compute initial %loaded>      |
///      |  cf.br loop(%loaded)            |
///      +---------------------------------+
///             |
///  -------|   |
///  |      v   v
///  |   +------------------------------------+
///  |   | loop(%loaded):                     |
///  |   |   <body contents>                  |
///  |   |   %pair = cmpxchg                  |
///  |   |   %ok = %pair[0]                   |
///  |   |   %new = %pair[1]                  |
///  |   |   cf.cond_br %ok, end, loop(%new)  |
///  |   +------------------------------------+
///  |          |        |
///  |-----------        |
///                      v
///      +---------------------------------+
///      | end:                            |
///      |  <code after the AtomicRMWOp>   |
///      +---------------------------------+
///
struct GenericAtomicRMWOpLowering
    : public LoadStoreOpLowering<memref::GenericAtomicRMWOp> {
  using Base::Base;

  LogicalResult
  matchAndRewrite(memref::GenericAtomicRMWOp atomicOp, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    auto loc = atomicOp.getLoc();
    Type valueType = typeConverter->convertType(atomicOp.getResult().getType());

    // Split the block into initial, loop, and ending parts.
    auto *initBlock = rewriter.getInsertionBlock();
    auto *loopBlock = rewriter.createBlock(
        initBlock->getParent(), std::next(Region::iterator(initBlock)),
        valueType, loc);
    auto *endBlock = rewriter.createBlock(
        loopBlock->getParent(), std::next(Region::iterator(loopBlock)));

    // Operations range to be moved to `endBlock`.
    auto opsToMoveStart = atomicOp->getIterator();
    auto opsToMoveEnd = initBlock->back().getIterator();

    // Compute the loaded value and branch to the loop block.
    rewriter.setInsertionPointToEnd(initBlock);
    auto memRefType = atomicOp.memref().getType().cast<MemRefType>();
    auto dataPtr = getStridedElementPtr(loc, memRefType, adaptor.memref(),
                                        adaptor.indices(), rewriter);
    Value init = rewriter.create<LLVM::LoadOp>(loc, dataPtr);
    rewriter.create<LLVM::BrOp>(loc, init, loopBlock);

    // Prepare the body of the loop block.
    rewriter.setInsertionPointToStart(loopBlock);

    // Clone the GenericAtomicRMWOp region and extract the result.
    auto loopArgument = loopBlock->getArgument(0);
    BlockAndValueMapping mapping;
    mapping.map(atomicOp.getCurrentValue(), loopArgument);
    Block &entryBlock = atomicOp.body().front();
    for (auto &nestedOp : entryBlock.without_terminator()) {
      Operation *clone = rewriter.clone(nestedOp, mapping);
      mapping.map(nestedOp.getResults(), clone->getResults());
    }
    Value result = mapping.lookup(entryBlock.getTerminator()->getOperand(0));

    // Prepare the epilog of the loop block.
    // Append the cmpxchg op to the end of the loop block.
    auto successOrdering = LLVM::AtomicOrdering::acq_rel;
    auto failureOrdering = LLVM::AtomicOrdering::monotonic;
    auto boolType = IntegerType::get(rewriter.getContext(), 1);
    auto pairType = LLVM::LLVMStructType::getLiteral(rewriter.getContext(),
                                                     {valueType, boolType});
    auto cmpxchg = rewriter.create<LLVM::AtomicCmpXchgOp>(
        loc, pairType, dataPtr, loopArgument, result, successOrdering,
        failureOrdering);
    // Extract the %new_loaded and %ok values from the pair.
    Value newLoaded = rewriter.create<LLVM::ExtractValueOp>(
        loc, valueType, cmpxchg, rewriter.getI64ArrayAttr({0}));
    Value ok = rewriter.create<LLVM::ExtractValueOp>(
        loc, boolType, cmpxchg, rewriter.getI64ArrayAttr({1}));

    // Conditionally branch to the end or back to the loop depending on %ok.
    rewriter.create<LLVM::CondBrOp>(loc, ok, endBlock, ArrayRef<Value>(),
                                    loopBlock, newLoaded);

    rewriter.setInsertionPointToEnd(endBlock);
    moveOpsRange(atomicOp.getResult(), newLoaded, std::next(opsToMoveStart),
                 std::next(opsToMoveEnd), rewriter);

    // The 'result' of the atomic_rmw op is the newly loaded value.
    rewriter.replaceOp(atomicOp, {newLoaded});

    return success();
  }

private:
  // Clones a segment of ops [start, end) and erases the original.
  void moveOpsRange(ValueRange oldResult, ValueRange newResult,
                    Block::iterator start, Block::iterator end,
                    ConversionPatternRewriter &rewriter) const {
    BlockAndValueMapping mapping;
    mapping.map(oldResult, newResult);
    SmallVector<Operation *, 2> opsToErase;
    for (auto it = start; it != end; ++it) {
      rewriter.clone(*it, mapping);
      opsToErase.push_back(&*it);
    }
    for (auto *it : opsToErase)
      rewriter.eraseOp(it);
  }
};

/// Returns the LLVM type of the global variable given the memref type `type`.
static Type convertGlobalMemrefTypeToLLVM(MemRefType type,
                                          LLVMTypeConverter &typeConverter) {
  // LLVM type for a global memref will be a multi-dimension array. For
  // declarations or uninitialized global memrefs, we can potentially flatten
  // this to a 1D array. However, for memref.global's with an initial value,
  // we do not intend to flatten the ElementsAttribute when going from std ->
  // LLVM dialect, so the LLVM type needs to be a multi-dimension array.
  Type elementType = typeConverter.convertType(type.getElementType());
  Type arrayTy = elementType;
  // Shape has the outermost dim at index 0, so need to walk it backwards.
  for (int64_t dim : llvm::reverse(type.getShape()))
    arrayTy = LLVM::LLVMArrayType::get(arrayTy, dim);
  return arrayTy;
}

/// GlobalMemrefOp is lowered to a LLVM Global Variable.
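///
/// A hedged sketch of the intended shape of the output (linkage, alignment
/// and address space depend on the op's attributes; not verbatim output):
///   memref.global "private" constant @lut : memref<2xf32> = dense<[1.0, 2.0]>
/// becomes, roughly,
///   llvm.mlir.global private constant @lut(dense<[1.0, 2.0]> : tensor<2xf32>)
///       : !llvm.array<2 x f32>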
struct GlobalMemrefOpLowering
    : public ConvertOpToLLVMPattern<memref::GlobalOp> {
  using ConvertOpToLLVMPattern<memref::GlobalOp>::ConvertOpToLLVMPattern;

  LogicalResult
  matchAndRewrite(memref::GlobalOp global, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    MemRefType type = global.type();
    if (!isConvertibleAndHasIdentityMaps(type))
      return failure();

    Type arrayTy = convertGlobalMemrefTypeToLLVM(type, *getTypeConverter());

    LLVM::Linkage linkage =
        global.isPublic() ? LLVM::Linkage::External : LLVM::Linkage::Private;

    Attribute initialValue = nullptr;
    if (!global.isExternal() && !global.isUninitialized()) {
      auto elementsAttr = global.initial_value()->cast<ElementsAttr>();
      initialValue = elementsAttr;

      // For scalar memrefs, the global variable created is of the element
      // type, so unpack the elements attribute to extract the value.
      if (type.getRank() == 0)
        initialValue = elementsAttr.getSplatValue<Attribute>();
    }

    uint64_t alignment = global.alignment().getValueOr(0);

    auto newGlobal = rewriter.replaceOpWithNewOp<LLVM::GlobalOp>(
        global, arrayTy, global.constant(), linkage, global.sym_name(),
        initialValue, alignment, type.getMemorySpaceAsInt());
    if (!global.isExternal() && global.isUninitialized()) {
      Block *blk = new Block();
      newGlobal.getInitializerRegion().push_back(blk);
      rewriter.setInsertionPointToStart(blk);
      Value undef[] = {
          rewriter.create<LLVM::UndefOp>(global.getLoc(), arrayTy)};
      rewriter.create<LLVM::ReturnOp>(global.getLoc(), undef);
    }
    return success();
  }
};

/// GetGlobalMemrefOp is lowered into a MemRef descriptor with the pointer to
/// the first element stashed into the descriptor. This uses
/// `AllocLikeOpLowering` to reuse the MemRef descriptor construction.
struct GetGlobalMemrefOpLowering : public AllocLikeOpLLVMLowering {
  GetGlobalMemrefOpLowering(LLVMTypeConverter &converter)
      : AllocLikeOpLLVMLowering(memref::GetGlobalOp::getOperationName(),
                                converter) {}

  /// Buffer "allocation" for memref.get_global op is getting the address of
  /// the referenced global variable.
  std::tuple<Value, Value> allocateBuffer(ConversionPatternRewriter &rewriter,
                                          Location loc, Value sizeBytes,
                                          Operation *op) const override {
    auto getGlobalOp = cast<memref::GetGlobalOp>(op);
    MemRefType type = getGlobalOp.result().getType().cast<MemRefType>();
    unsigned memSpace = type.getMemorySpaceAsInt();

    Type arrayTy = convertGlobalMemrefTypeToLLVM(type, *getTypeConverter());
    auto addressOf = rewriter.create<LLVM::AddressOfOp>(
        loc, LLVM::LLVMPointerType::get(arrayTy, memSpace), getGlobalOp.name());

    // Get the address of the first element in the array by creating a GEP with
    // the address of the GV as the base, and (rank + 1) number of 0 indices.
    Type elementType = typeConverter->convertType(type.getElementType());
    Type elementPtrType = LLVM::LLVMPointerType::get(elementType, memSpace);

    SmallVector<Value> operands;
    operands.insert(operands.end(), type.getRank() + 1,
                    createIndexConstant(rewriter, loc, 0));
    auto gep =
        rewriter.create<LLVM::GEPOp>(loc, elementPtrType, addressOf, operands);

    // We do not expect the memref obtained using `memref.get_global` to be
    // ever deallocated. Set the allocated pointer to a known bad value to
    // help debug if that ever happens.
    auto intPtrType = getIntPtrType(memSpace);
    Value deadBeefConst =
        createIndexAttrConstant(rewriter, op->getLoc(), intPtrType, 0xdeadbeef);
    auto deadBeefPtr =
        rewriter.create<LLVM::IntToPtrOp>(loc, elementPtrType, deadBeefConst);

    // Both allocated and aligned pointers are the same. We could potentially
    // stash a nullptr for the allocated pointer since we do not expect any
    // dealloc.
    return std::make_tuple(deadBeefPtr, gep);
  }
};

// Load operation is lowered to obtaining a pointer to the indexed element
// and loading it.
struct LoadOpLowering : public LoadStoreOpLowering<memref::LoadOp> {
  using Base::Base;

  LogicalResult
  matchAndRewrite(memref::LoadOp loadOp, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    auto type = loadOp.getMemRefType();

    Value dataPtr = getStridedElementPtr(
        loadOp.getLoc(), type, adaptor.memref(), adaptor.indices(), rewriter);
    rewriter.replaceOpWithNewOp<LLVM::LoadOp>(loadOp, dataPtr);
    return success();
  }
};

// Store operation is lowered to obtaining a pointer to the indexed element,
// and storing the given value to it.
struct StoreOpLowering : public LoadStoreOpLowering<memref::StoreOp> {
  using Base::Base;

  LogicalResult
  matchAndRewrite(memref::StoreOp op, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    auto type = op.getMemRefType();

    Value dataPtr = getStridedElementPtr(op.getLoc(), type, adaptor.memref(),
                                         adaptor.indices(), rewriter);
    rewriter.replaceOpWithNewOp<LLVM::StoreOp>(op, adaptor.value(), dataPtr);
    return success();
  }
};

// The prefetch operation is lowered in a way similar to the load operation
// except that the llvm.prefetch operation is used for replacement.
struct PrefetchOpLowering : public LoadStoreOpLowering<memref::PrefetchOp> {
  using Base::Base;

  LogicalResult
  matchAndRewrite(memref::PrefetchOp prefetchOp, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    auto type = prefetchOp.getMemRefType();
    auto loc = prefetchOp.getLoc();

    Value dataPtr = getStridedElementPtr(loc, type, adaptor.memref(),
                                         adaptor.indices(), rewriter);

    // Replace with llvm.prefetch.
    auto llvmI32Type = typeConverter->convertType(rewriter.getIntegerType(32));
    auto isWrite = rewriter.create<LLVM::ConstantOp>(
        loc, llvmI32Type, rewriter.getI32IntegerAttr(prefetchOp.isWrite()));
    auto localityHint = rewriter.create<LLVM::ConstantOp>(
        loc, llvmI32Type,
        rewriter.getI32IntegerAttr(prefetchOp.localityHint()));
    auto isData = rewriter.create<LLVM::ConstantOp>(
        loc, llvmI32Type, rewriter.getI32IntegerAttr(prefetchOp.isDataCache()));

    rewriter.replaceOpWithNewOp<LLVM::Prefetch>(prefetchOp, dataPtr, isWrite,
                                                localityHint, isData);
    return success();
  }
};

struct RankOpLowering : public ConvertOpToLLVMPattern<memref::RankOp> {
  using ConvertOpToLLVMPattern<memref::RankOp>::ConvertOpToLLVMPattern;

  LogicalResult
  matchAndRewrite(memref::RankOp op, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    Location loc = op.getLoc();
    Type operandType = op.memref().getType();
    if (auto unrankedMemRefType = operandType.dyn_cast<UnrankedMemRefType>()) {
      UnrankedMemRefDescriptor desc(adaptor.memref());
      rewriter.replaceOp(op, {desc.rank(rewriter, loc)});
      return success();
    }
    if (auto rankedMemRefType = operandType.dyn_cast<MemRefType>()) {
      rewriter.replaceOp(
          op, {createIndexConstant(rewriter, loc, rankedMemRefType.getRank())});
      return success();
    }
    return failure();
  }
};

struct MemRefCastOpLowering : public ConvertOpToLLVMPattern<memref::CastOp> {
  using ConvertOpToLLVMPattern<memref::CastOp>::ConvertOpToLLVMPattern;

  LogicalResult match(memref::CastOp memRefCastOp) const override {
    Type srcType = memRefCastOp.getOperand().getType();
    Type dstType = memRefCastOp.getType();

    // memref::CastOp reduces to a bitcast in the ranked MemRef case and can be
    // used for type erasure. For now it must preserve the underlying element
    // type and requires source and result types to have the same rank.
    // Therefore, perform a sanity check that the underlying structs are the
    // same. Once op semantics are relaxed we can revisit.
    if (srcType.isa<MemRefType>() && dstType.isa<MemRefType>())
      return success(typeConverter->convertType(srcType) ==
                     typeConverter->convertType(dstType));

    // At least one of the operands is an unranked type.
    assert(srcType.isa<UnrankedMemRefType>() ||
           dstType.isa<UnrankedMemRefType>());

    // Unranked-to-unranked casts are disallowed.
    return !(srcType.isa<UnrankedMemRefType>() &&
             dstType.isa<UnrankedMemRefType>())
               ? success()
               : failure();
  }

  void rewrite(memref::CastOp memRefCastOp, OpAdaptor adaptor,
               ConversionPatternRewriter &rewriter) const override {
    auto srcType = memRefCastOp.getOperand().getType();
    auto dstType = memRefCastOp.getType();
    auto targetStructType = typeConverter->convertType(memRefCastOp.getType());
    auto loc = memRefCastOp.getLoc();

    // For the ranked/ranked case, just keep the original descriptor.
    if (srcType.isa<MemRefType>() && dstType.isa<MemRefType>())
      return rewriter.replaceOp(memRefCastOp, {adaptor.source()});

    if (srcType.isa<MemRefType>() && dstType.isa<UnrankedMemRefType>()) {
      // Casting ranked to unranked memref type:
      //   set the rank in the destination from the memref type,
      //   allocate space on the stack and copy the src memref descriptor,
      //   set the ptr in the destination to the stack space.
      auto srcMemRefType = srcType.cast<MemRefType>();
      int64_t rank = srcMemRefType.getRank();
      // ptr = AllocaOp sizeof(MemRefDescriptor)
      auto ptr = getTypeConverter()->promoteOneMemRefDescriptor(
          loc, adaptor.source(), rewriter);
      // voidptr = BitCastOp srcType* to void*
      auto voidPtr =
          rewriter.create<LLVM::BitcastOp>(loc, getVoidPtrType(), ptr)
              .getResult();
      // rank = ConstantOp srcRank
      auto rankVal = rewriter.create<LLVM::ConstantOp>(
          loc, getIndexType(), rewriter.getIndexAttr(rank));
      // undef = UndefOp
      UnrankedMemRefDescriptor memRefDesc =
          UnrankedMemRefDescriptor::undef(rewriter, loc, targetStructType);
      // d1 = InsertValueOp undef, rank, 0
      memRefDesc.setRank(rewriter, loc, rankVal);
      // d2 = InsertValueOp d1, voidptr, 1
      memRefDesc.setMemRefDescPtr(rewriter, loc, voidPtr);
      rewriter.replaceOp(memRefCastOp, (Value)memRefDesc);

    } else if (srcType.isa<UnrankedMemRefType>() && dstType.isa<MemRefType>()) {
      // Casting from unranked to ranked.
      // The operation is assumed to be doing a correct cast. If the
      // destination type mismatches the unranked type, it is undefined
      // behavior.
      UnrankedMemRefDescriptor memRefDesc(adaptor.source());
      // ptr = ExtractValueOp src, 1
      auto ptr = memRefDesc.memRefDescPtr(rewriter, loc);
      // castPtr = BitCastOp i8* to structTy*
      auto castPtr =
          rewriter
              .create<LLVM::BitcastOp>(
                  loc, LLVM::LLVMPointerType::get(targetStructType), ptr)
              .getResult();
      // struct = LoadOp castPtr
      auto loadOp = rewriter.create<LLVM::LoadOp>(loc, castPtr);
      rewriter.replaceOp(memRefCastOp, loadOp.getResult());
    } else {
      llvm_unreachable("Unsupported unranked memref to unranked memref cast");
    }
  }
};

/// Pattern to lower a `memref.copy` to llvm.
///
/// For memrefs with identity layouts, the copy is lowered to the llvm
/// `memcpy` intrinsic. For non-identity layouts, the copy is lowered to a call
/// to the generic `MemrefCopyFn`.
struct MemRefCopyOpLowering : public ConvertOpToLLVMPattern<memref::CopyOp> {
  using ConvertOpToLLVMPattern<memref::CopyOp>::ConvertOpToLLVMPattern;

  LogicalResult
  lowerToMemCopyIntrinsic(memref::CopyOp op, OpAdaptor adaptor,
                          ConversionPatternRewriter &rewriter) const {
    auto loc = op.getLoc();
    auto srcType = op.source().getType().dyn_cast<MemRefType>();

    MemRefDescriptor srcDesc(adaptor.source());

    // Compute the number of elements.
    Value numElements = rewriter.create<LLVM::ConstantOp>(
        loc, getIndexType(), rewriter.getIndexAttr(1));
    for (int pos = 0; pos < srcType.getRank(); ++pos) {
      auto size = srcDesc.size(rewriter, loc, pos);
      numElements = rewriter.create<LLVM::MulOp>(loc, numElements, size);
    }

    // Get the element size.
    auto sizeInBytes = getSizeInBytes(loc, srcType.getElementType(), rewriter);
    // Compute the total size in bytes.
    Value totalSize =
        rewriter.create<LLVM::MulOp>(loc, numElements, sizeInBytes);

    Value srcBasePtr = srcDesc.alignedPtr(rewriter, loc);
    MemRefDescriptor targetDesc(adaptor.target());
    Value targetBasePtr = targetDesc.alignedPtr(rewriter, loc);
    Value isVolatile = rewriter.create<LLVM::ConstantOp>(
        loc, typeConverter->convertType(rewriter.getI1Type()),
        rewriter.getBoolAttr(false));
    rewriter.create<LLVM::MemcpyOp>(loc, targetBasePtr, srcBasePtr, totalSize,
                                    isVolatile);
    rewriter.eraseOp(op);

    return success();
  }

  LogicalResult
  lowerToMemCopyFunctionCall(memref::CopyOp op, OpAdaptor adaptor,
                             ConversionPatternRewriter &rewriter) const {
    auto loc = op.getLoc();
    auto srcType = op.source().getType().cast<BaseMemRefType>();
    auto targetType = op.target().getType().cast<BaseMemRefType>();

    // First make sure we have an unranked memref descriptor representation.
    auto makeUnranked = [&, this](Value ranked, BaseMemRefType type) {
      auto rank = rewriter.create<LLVM::ConstantOp>(
          loc, getIndexType(), rewriter.getIndexAttr(type.getRank()));
      auto *typeConverter = getTypeConverter();
      auto ptr =
          typeConverter->promoteOneMemRefDescriptor(loc, ranked, rewriter);
      auto voidPtr =
          rewriter.create<LLVM::BitcastOp>(loc, getVoidPtrType(), ptr)
              .getResult();
      auto unrankedType =
          UnrankedMemRefType::get(type.getElementType(), type.getMemorySpace());
      return UnrankedMemRefDescriptor::pack(rewriter, loc, *typeConverter,
                                            unrankedType,
                                            ValueRange{rank, voidPtr});
    };

    Value unrankedSource = srcType.hasRank()
                               ? makeUnranked(adaptor.source(), srcType)
                               : adaptor.source();
    Value unrankedTarget = targetType.hasRank()
                               ? makeUnranked(adaptor.target(), targetType)
                               : adaptor.target();

    // Now promote the unranked descriptors to the stack.
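    // Roughly, and only as a sketch, each descriptor value is materialized in
    // memory as:
    //   %one = llvm.mlir.constant(1 : index)
    //   %ptr = llvm.alloca %one x !unranked_descriptor_type
    //   llvm.store %desc, %ptr
    // so that a pointer to it can be passed to the generic copy function.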
    auto one = rewriter.create<LLVM::ConstantOp>(loc, getIndexType(),
                                                 rewriter.getIndexAttr(1));
    auto promote = [&](Value desc) {
      auto ptrType = LLVM::LLVMPointerType::get(desc.getType());
      auto allocated =
          rewriter.create<LLVM::AllocaOp>(loc, ptrType, ValueRange{one});
      rewriter.create<LLVM::StoreOp>(loc, desc, allocated);
      return allocated;
    };

    auto sourcePtr = promote(unrankedSource);
    auto targetPtr = promote(unrankedTarget);

    unsigned typeSize =
        mlir::DataLayout::closest(op).getTypeSize(srcType.getElementType());
    auto elemSize = rewriter.create<LLVM::ConstantOp>(
        loc, getIndexType(), rewriter.getIndexAttr(typeSize));
    auto copyFn = LLVM::lookupOrCreateMemRefCopyFn(
        op->getParentOfType<ModuleOp>(), getIndexType(), sourcePtr.getType());
    rewriter.create<LLVM::CallOp>(loc, copyFn,
                                  ValueRange{elemSize, sourcePtr, targetPtr});
    rewriter.eraseOp(op);

    return success();
  }

  LogicalResult
  matchAndRewrite(memref::CopyOp op, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    auto srcType = op.source().getType().cast<BaseMemRefType>();
    auto targetType = op.target().getType().cast<BaseMemRefType>();

    if (srcType.hasRank() &&
        srcType.cast<MemRefType>().getLayout().isIdentity() &&
        targetType.hasRank() &&
        targetType.cast<MemRefType>().getLayout().isIdentity())
      return lowerToMemCopyIntrinsic(op, adaptor, rewriter);

    return lowerToMemCopyFunctionCall(op, adaptor, rewriter);
  }
};

/// Extracts allocated, aligned pointers and offset from a ranked or unranked
/// memref type. In the unranked case, the fields are extracted from the
/// underlying ranked descriptor.
static void extractPointersAndOffset(Location loc,
                                     ConversionPatternRewriter &rewriter,
                                     LLVMTypeConverter &typeConverter,
                                     Value originalOperand,
                                     Value convertedOperand,
                                     Value *allocatedPtr, Value *alignedPtr,
                                     Value *offset = nullptr) {
  Type operandType = originalOperand.getType();
  if (operandType.isa<MemRefType>()) {
    MemRefDescriptor desc(convertedOperand);
    *allocatedPtr = desc.allocatedPtr(rewriter, loc);
    *alignedPtr = desc.alignedPtr(rewriter, loc);
    if (offset != nullptr)
      *offset = desc.offset(rewriter, loc);
    return;
  }

  unsigned memorySpace =
      operandType.cast<UnrankedMemRefType>().getMemorySpaceAsInt();
  Type elementType = operandType.cast<UnrankedMemRefType>().getElementType();
  Type llvmElementType = typeConverter.convertType(elementType);
  Type elementPtrPtrType = LLVM::LLVMPointerType::get(
      LLVM::LLVMPointerType::get(llvmElementType, memorySpace));

  // Extract the pointer to the underlying ranked memref descriptor and cast it
  // to ElemType**.
  UnrankedMemRefDescriptor unrankedDesc(convertedOperand);
  Value underlyingDescPtr = unrankedDesc.memRefDescPtr(rewriter, loc);

  *allocatedPtr = UnrankedMemRefDescriptor::allocatedPtr(
      rewriter, loc, underlyingDescPtr, elementPtrPtrType);
  *alignedPtr = UnrankedMemRefDescriptor::alignedPtr(
      rewriter, loc, typeConverter, underlyingDescPtr, elementPtrPtrType);
  if (offset != nullptr) {
    *offset = UnrankedMemRefDescriptor::offset(
        rewriter, loc, typeConverter, underlyingDescPtr, elementPtrPtrType);
  }
}

struct MemRefReinterpretCastOpLowering
    : public ConvertOpToLLVMPattern<memref::ReinterpretCastOp> {
  using ConvertOpToLLVMPattern<
      memref::ReinterpretCastOp>::ConvertOpToLLVMPattern;

  LogicalResult
  matchAndRewrite(memref::ReinterpretCastOp castOp, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    Type srcType = castOp.source().getType();

    Value descriptor;
    if (failed(convertSourceMemRefToDescriptor(rewriter, srcType, castOp,
                                               adaptor, &descriptor)))
      return failure();
    rewriter.replaceOp(castOp, {descriptor});
    return success();
  }

private:
  LogicalResult convertSourceMemRefToDescriptor(
      ConversionPatternRewriter &rewriter, Type srcType,
      memref::ReinterpretCastOp castOp,
      memref::ReinterpretCastOp::Adaptor adaptor, Value *descriptor) const {
    MemRefType targetMemRefType =
        castOp.getResult().getType().cast<MemRefType>();
    auto llvmTargetDescriptorTy = typeConverter->convertType(targetMemRefType)
                                      .dyn_cast_or_null<LLVM::LLVMStructType>();
    if (!llvmTargetDescriptorTy)
      return failure();

    // Create descriptor.
    Location loc = castOp.getLoc();
    auto desc = MemRefDescriptor::undef(rewriter, loc, llvmTargetDescriptorTy);

    // Set allocated and aligned pointers.
    Value allocatedPtr, alignedPtr;
    extractPointersAndOffset(loc, rewriter, *getTypeConverter(),
                             castOp.source(), adaptor.source(), &allocatedPtr,
                             &alignedPtr);
    desc.setAllocatedPtr(rewriter, loc, allocatedPtr);
    desc.setAlignedPtr(rewriter, loc, alignedPtr);

    // Set offset.
    if (castOp.isDynamicOffset(0))
      desc.setOffset(rewriter, loc, adaptor.offsets()[0]);
    else
      desc.setConstantOffset(rewriter, loc, castOp.getStaticOffset(0));

    // Set sizes and strides.
    unsigned dynSizeId = 0;
    unsigned dynStrideId = 0;
    for (unsigned i = 0, e = targetMemRefType.getRank(); i < e; ++i) {
      if (castOp.isDynamicSize(i))
        desc.setSize(rewriter, loc, i, adaptor.sizes()[dynSizeId++]);
      else
        desc.setConstantSize(rewriter, loc, i, castOp.getStaticSize(i));

      if (castOp.isDynamicStride(i))
        desc.setStride(rewriter, loc, i, adaptor.strides()[dynStrideId++]);
      else
        desc.setConstantStride(rewriter, loc, i, castOp.getStaticStride(i));
    }
    *descriptor = desc;
    return success();
  }
};

struct MemRefReshapeOpLowering
    : public ConvertOpToLLVMPattern<memref::ReshapeOp> {
  using ConvertOpToLLVMPattern<memref::ReshapeOp>::ConvertOpToLLVMPattern;

  LogicalResult
  matchAndRewrite(memref::ReshapeOp reshapeOp, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    Type srcType = reshapeOp.source().getType();

    Value descriptor;
    if (failed(convertSourceMemRefToDescriptor(rewriter, srcType, reshapeOp,
                                               adaptor, &descriptor)))
      return failure();
    rewriter.replaceOp(reshapeOp, {descriptor});
    return success();
  }

private:
  LogicalResult
  convertSourceMemRefToDescriptor(ConversionPatternRewriter &rewriter,
                                  Type srcType, memref::ReshapeOp reshapeOp,
                                  memref::ReshapeOp::Adaptor adaptor,
                                  Value *descriptor) const {
    // Conversion for statically-known shape args is performed via
    // `memref_reinterpret_cast`.
    auto shapeMemRefType = reshapeOp.shape().getType().cast<MemRefType>();
    if (shapeMemRefType.hasStaticShape())
      return failure();

    // The shape is a rank-1 tensor with unknown length.
    Location loc = reshapeOp.getLoc();
    MemRefDescriptor shapeDesc(adaptor.shape());
    Value resultRank = shapeDesc.size(rewriter, loc, 0);

    // Extract address space and element type.
    auto targetType =
        reshapeOp.getResult().getType().cast<UnrankedMemRefType>();
    unsigned addressSpace = targetType.getMemorySpaceAsInt();
    Type elementType = targetType.getElementType();

    // Create the unranked memref descriptor that holds the ranked one. The
    // inner descriptor is allocated on stack.
    auto targetDesc = UnrankedMemRefDescriptor::undef(
        rewriter, loc, typeConverter->convertType(targetType));
    targetDesc.setRank(rewriter, loc, resultRank);
    SmallVector<Value, 4> sizes;
    UnrankedMemRefDescriptor::computeSizes(rewriter, loc, *getTypeConverter(),
                                           targetDesc, sizes);
    Value underlyingDescPtr = rewriter.create<LLVM::AllocaOp>(
        loc, getVoidPtrType(), sizes.front(), llvm::None);
    targetDesc.setMemRefDescPtr(rewriter, loc, underlyingDescPtr);

    // Extract pointers and offset from the source memref.
    Value allocatedPtr, alignedPtr, offset;
    extractPointersAndOffset(loc, rewriter, *getTypeConverter(),
                             reshapeOp.source(), adaptor.source(),
                             &allocatedPtr, &alignedPtr, &offset);

    // Set pointers and offset.
    Type llvmElementType = typeConverter->convertType(elementType);
    auto elementPtrPtrType = LLVM::LLVMPointerType::get(
        LLVM::LLVMPointerType::get(llvmElementType, addressSpace));
    UnrankedMemRefDescriptor::setAllocatedPtr(rewriter, loc, underlyingDescPtr,
                                              elementPtrPtrType, allocatedPtr);
    UnrankedMemRefDescriptor::setAlignedPtr(rewriter, loc, *getTypeConverter(),
                                            underlyingDescPtr,
                                            elementPtrPtrType, alignedPtr);
    UnrankedMemRefDescriptor::setOffset(rewriter, loc, *getTypeConverter(),
                                        underlyingDescPtr, elementPtrPtrType,
                                        offset);

    // Use the offset pointer as base for further addressing. Copy over the new
    // shape and compute strides. For this, we create a loop from rank-1 to 0.
    Value targetSizesBase = UnrankedMemRefDescriptor::sizeBasePtr(
        rewriter, loc, *getTypeConverter(), underlyingDescPtr,
        elementPtrPtrType);
    Value targetStridesBase = UnrankedMemRefDescriptor::strideBasePtr(
        rewriter, loc, *getTypeConverter(), targetSizesBase, resultRank);
    Value shapeOperandPtr = shapeDesc.alignedPtr(rewriter, loc);
    Value oneIndex = createIndexConstant(rewriter, loc, 1);
    Value resultRankMinusOne =
        rewriter.create<LLVM::SubOp>(loc, resultRank, oneIndex);

    Block *initBlock = rewriter.getInsertionBlock();
    Type indexType = getTypeConverter()->getIndexType();
    Block::iterator remainingOpsIt = std::next(rewriter.getInsertionPoint());

    Block *condBlock = rewriter.createBlock(initBlock->getParent(), {},
                                            {indexType, indexType}, {loc, loc});

    // Move the remaining initBlock ops to condBlock.
    Block *remainingBlock = rewriter.splitBlock(initBlock, remainingOpsIt);
    rewriter.mergeBlocks(remainingBlock, condBlock, ValueRange());

    rewriter.setInsertionPointToEnd(initBlock);
    rewriter.create<LLVM::BrOp>(loc, ValueRange({resultRankMinusOne, oneIndex}),
                                condBlock);
    rewriter.setInsertionPointToStart(condBlock);
    Value indexArg = condBlock->getArgument(0);
    Value strideArg = condBlock->getArgument(1);

    Value zeroIndex = createIndexConstant(rewriter, loc, 0);
    Value pred = rewriter.create<LLVM::ICmpOp>(
        loc, IntegerType::get(rewriter.getContext(), 1),
        LLVM::ICmpPredicate::sge, indexArg, zeroIndex);

    Block *bodyBlock =
        rewriter.splitBlock(condBlock, rewriter.getInsertionPoint());
    rewriter.setInsertionPointToStart(bodyBlock);

    // Copy size from shape to descriptor.
    Type llvmIndexPtrType = LLVM::LLVMPointerType::get(indexType);
    Value sizeLoadGep = rewriter.create<LLVM::GEPOp>(
        loc, llvmIndexPtrType, shapeOperandPtr, ValueRange{indexArg});
    Value size = rewriter.create<LLVM::LoadOp>(loc, sizeLoadGep);
    UnrankedMemRefDescriptor::setSize(rewriter, loc, *getTypeConverter(),
                                      targetSizesBase, indexArg, size);

    // Write stride value and compute next one.
    UnrankedMemRefDescriptor::setStride(rewriter, loc, *getTypeConverter(),
                                        targetStridesBase, indexArg, strideArg);
    Value nextStride = rewriter.create<LLVM::MulOp>(loc, strideArg, size);

    // Decrement loop counter and branch back.
    Value decrement = rewriter.create<LLVM::SubOp>(loc, indexArg, oneIndex);
    rewriter.create<LLVM::BrOp>(loc, ValueRange({decrement, nextStride}),
                                condBlock);

    Block *remainder =
        rewriter.splitBlock(bodyBlock, rewriter.getInsertionPoint());

    // Hook up the cond exit to the remainder.
    rewriter.setInsertionPointToEnd(condBlock);
    rewriter.create<LLVM::CondBrOp>(loc, pred, bodyBlock, llvm::None, remainder,
                                    llvm::None);

    // Reset position to beginning of new remainder block.
    rewriter.setInsertionPointToStart(remainder);

    *descriptor = targetDesc;
    return success();
  }
};

/// Helper function to convert a vector of `OpFoldResult`s into a vector of
/// `Value`s.
static SmallVector<Value> getAsValues(OpBuilder &b, Location loc,
                                      Type &llvmIndexType,
                                      ArrayRef<OpFoldResult> valueOrAttrVec) {
  return llvm::to_vector<4>(
      llvm::map_range(valueOrAttrVec, [&](OpFoldResult value) -> Value {
        if (auto attr = value.dyn_cast<Attribute>())
          return b.create<LLVM::ConstantOp>(loc, llvmIndexType, attr);
        return value.get<Value>();
      }));
}

/// Compute a map that for a given dimension of the expanded type gives the
/// dimension in the collapsed type it maps to. Essentially it is the inverse
/// of the `reassociation` maps.
static DenseMap<int64_t, int64_t>
getExpandedDimToCollapsedDimMap(ArrayRef<ReassociationIndices> reassociation) {
  llvm::DenseMap<int64_t, int64_t> expandedDimToCollapsedDim;
  for (auto &en : enumerate(reassociation)) {
    for (auto dim : en.value())
      expandedDimToCollapsedDim[dim] = en.index();
  }
  return expandedDimToCollapsedDim;
}

static OpFoldResult
getExpandedOutputDimSize(OpBuilder &b, Location loc, Type &llvmIndexType,
                         int64_t outDimIndex, ArrayRef<int64_t> outStaticShape,
                         MemRefDescriptor &inDesc,
                         ArrayRef<int64_t> inStaticShape,
                         ArrayRef<ReassociationIndices> reassocation,
                         DenseMap<int64_t, int64_t> &outDimToInDimMap) {
  int64_t outDimSize = outStaticShape[outDimIndex];
  if (!ShapedType::isDynamic(outDimSize))
    return b.getIndexAttr(outDimSize);

  // Calculate the multiplication of all the out dim sizes except the
  // current dim.
  int64_t inDimIndex = outDimToInDimMap[outDimIndex];
  int64_t otherDimSizesMul = 1;
  for (auto otherDimIndex : reassocation[inDimIndex]) {
    if (otherDimIndex == static_cast<unsigned>(outDimIndex))
      continue;
    int64_t otherDimSize = outStaticShape[otherDimIndex];
    assert(!ShapedType::isDynamic(otherDimSize) &&
           "single dimension cannot be expanded into multiple dynamic "
           "dimensions");
    otherDimSizesMul *= otherDimSize;
  }

  // outDimSize = inDimSize / otherOutDimSizesMul
  int64_t inDimSize = inStaticShape[inDimIndex];
  Value inDimSizeDynamic =
      ShapedType::isDynamic(inDimSize)
          ? inDesc.size(b, loc, inDimIndex)
          : b.create<LLVM::ConstantOp>(loc, llvmIndexType,
                                       b.getIndexAttr(inDimSize));
  Value outDimSizeDynamic = b.create<LLVM::SDivOp>(
      loc, inDimSizeDynamic,
      b.create<LLVM::ConstantOp>(loc, llvmIndexType,
                                 b.getIndexAttr(otherDimSizesMul)));
  return outDimSizeDynamic;
}

static OpFoldResult getCollapsedOutputDimSize(
    OpBuilder &b, Location loc, Type &llvmIndexType, int64_t outDimIndex,
    int64_t outDimSize, ArrayRef<int64_t> inStaticShape,
    MemRefDescriptor &inDesc, ArrayRef<ReassociationIndices> reassocation) {
  if (!ShapedType::isDynamic(outDimSize))
    return b.getIndexAttr(outDimSize);

  Value c1 = b.create<LLVM::ConstantOp>(loc, llvmIndexType, b.getIndexAttr(1));
  Value outDimSizeDynamic = c1;
  for (auto inDimIndex : reassocation[outDimIndex]) {
    int64_t inDimSize = inStaticShape[inDimIndex];
    Value inDimSizeDynamic =
        ShapedType::isDynamic(inDimSize)
            ? inDesc.size(b, loc, inDimIndex)
            : b.create<LLVM::ConstantOp>(loc, llvmIndexType,
                                         b.getIndexAttr(inDimSize));
    outDimSizeDynamic =
        b.create<LLVM::MulOp>(loc, outDimSizeDynamic, inDimSizeDynamic);
  }
  return outDimSizeDynamic;
}

static SmallVector<OpFoldResult, 4>
getCollapsedOutputShape(OpBuilder &b, Location loc, Type &llvmIndexType,
                        ArrayRef<ReassociationIndices> reassocation,
                        ArrayRef<int64_t> inStaticShape,
                        MemRefDescriptor &inDesc,
                        ArrayRef<int64_t> outStaticShape) {
  return llvm::to_vector<4>(llvm::map_range(
      llvm::seq<int64_t>(0, outStaticShape.size()), [&](int64_t outDimIndex) {
        return getCollapsedOutputDimSize(b, loc, llvmIndexType, outDimIndex,
                                         outStaticShape[outDimIndex],
                                         inStaticShape, inDesc, reassocation);
      }));
}

static SmallVector<OpFoldResult, 4>
getExpandedOutputShape(OpBuilder &b, Location loc, Type &llvmIndexType,
                       ArrayRef<ReassociationIndices> reassocation,
                       ArrayRef<int64_t> inStaticShape,
                       MemRefDescriptor &inDesc,
                       ArrayRef<int64_t> outStaticShape) {
  DenseMap<int64_t, int64_t> outDimToInDimMap =
      getExpandedDimToCollapsedDimMap(reassocation);
  return llvm::to_vector<4>(llvm::map_range(
      llvm::seq<int64_t>(0, outStaticShape.size()), [&](int64_t outDimIndex) {
        return getExpandedOutputDimSize(b, loc, llvmIndexType, outDimIndex,
                                        outStaticShape, inDesc, inStaticShape,
                                        reassocation, outDimToInDimMap);
      }));
}

static SmallVector<Value>
getDynamicOutputShape(OpBuilder &b, Location loc, Type &llvmIndexType,
                      ArrayRef<ReassociationIndices> reassocation,
                      ArrayRef<int64_t> inStaticShape, MemRefDescriptor &inDesc,
                      ArrayRef<int64_t> outStaticShape) {
  return outStaticShape.size() < inStaticShape.size()
             ? getAsValues(b, loc, llvmIndexType,
                           getCollapsedOutputShape(b, loc, llvmIndexType,
                                                   reassocation, inStaticShape,
                                                   inDesc, outStaticShape))
             : getAsValues(b, loc, llvmIndexType,
                           getExpandedOutputShape(b, loc, llvmIndexType,
                                                  reassocation, inStaticShape,
                                                  inDesc, outStaticShape));
}

// ReshapeOp creates a new view descriptor of the proper rank.
// For now, the only conversion supported is for target MemRef with static
// sizes and strides.
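//
// As an illustrative sketch only: for
//   %r = memref.collapse_shape %src [[0, 1]] : memref<4x5xf32> into memref<20xf32>
// the pattern builds a fresh descriptor that reuses the source's allocated and
// aligned pointers and offset, and fills in the collapsed sizes and strides
// computed by the helpers above.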
template <typename ReshapeOp>
class ReassociatingReshapeOpConversion
    : public ConvertOpToLLVMPattern<ReshapeOp> {
public:
  using ConvertOpToLLVMPattern<ReshapeOp>::ConvertOpToLLVMPattern;
  using ReshapeOpAdaptor = typename ReshapeOp::Adaptor;

  LogicalResult
  matchAndRewrite(ReshapeOp reshapeOp, typename ReshapeOp::Adaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    MemRefType dstType = reshapeOp.getResultType();
    MemRefType srcType = reshapeOp.getSrcType();

    // The condition on the layouts can be ignored when all shapes are static.
    if (!srcType.hasStaticShape() || !dstType.hasStaticShape()) {
      if (!srcType.getLayout().isIdentity() ||
          !dstType.getLayout().isIdentity()) {
        return rewriter.notifyMatchFailure(
            reshapeOp, "only empty layout map is supported");
      }
    }

    int64_t offset;
    SmallVector<int64_t, 4> strides;
    if (failed(getStridesAndOffset(dstType, strides, offset))) {
      return rewriter.notifyMatchFailure(
          reshapeOp, "failed to get stride and offset exprs");
    }

    MemRefDescriptor srcDesc(adaptor.src());
    Location loc = reshapeOp->getLoc();
    auto dstDesc = MemRefDescriptor::undef(
        rewriter, loc, this->typeConverter->convertType(dstType));
    dstDesc.setAllocatedPtr(rewriter, loc, srcDesc.allocatedPtr(rewriter, loc));
    dstDesc.setAlignedPtr(rewriter, loc, srcDesc.alignedPtr(rewriter, loc));
    dstDesc.setOffset(rewriter, loc, srcDesc.offset(rewriter, loc));

    ArrayRef<int64_t> srcStaticShape = srcType.getShape();
    ArrayRef<int64_t> dstStaticShape = dstType.getShape();
    Type llvmIndexType =
        this->typeConverter->convertType(rewriter.getIndexType());
    SmallVector<Value> dstShape = getDynamicOutputShape(
        rewriter, loc, llvmIndexType, reshapeOp.getReassociationIndices(),
        srcStaticShape, srcDesc, dstStaticShape);
    for (auto &en : llvm::enumerate(dstShape))
      dstDesc.setSize(rewriter, loc, en.index(), en.value());

    auto isStaticStride = [](int64_t stride) {
      return !ShapedType::isDynamicStrideOrOffset(stride);
    };
    if (llvm::all_of(strides, isStaticStride)) {
      for (auto &en : llvm::enumerate(strides))
        dstDesc.setConstantStride(rewriter, loc, en.index(), en.value());
    } else {
      Value c1 = rewriter.create<LLVM::ConstantOp>(loc, llvmIndexType,
                                                   rewriter.getIndexAttr(1));
      Value stride = c1;
      for (auto dimIndex :
           llvm::reverse(llvm::seq<int64_t>(0, dstShape.size()))) {
        dstDesc.setStride(rewriter, loc, dimIndex, stride);
        stride = rewriter.create<LLVM::MulOp>(loc, dstShape[dimIndex], stride);
      }
    }
    rewriter.replaceOp(reshapeOp, {dstDesc});
    return success();
  }
};

/// Conversion pattern that transforms a subview op into:
///   1. An `llvm.mlir.undef` operation to create a memref descriptor
///   2. Updates to the descriptor to introduce the data ptr, offset, size
///      and stride.
/// The subview op is replaced by the descriptor.
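///
/// A hedged, illustrative sketch (the `#layout` on the result stands for
/// whatever strided layout the subview type carries): for
///   %v = memref.subview %src[%o0, %o1] [4, 4] [1, 1]
///       : memref<8x8xf32> to memref<4x4xf32, #layout>
/// the pattern copies the base pointers from the source descriptor and sets
/// offset = src_offset + %o0 * 8 + %o1 * 1, size_i = 4, stride_i = src_stride_i.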
/// Conversion pattern that transforms a subview op into:
///   1. An `llvm.mlir.undef` operation to create a memref descriptor
///   2. Updates to the descriptor to introduce the data ptr, offset, size
///      and stride.
/// The subview op is replaced by the descriptor.
struct SubViewOpLowering : public ConvertOpToLLVMPattern<memref::SubViewOp> {
  using ConvertOpToLLVMPattern<memref::SubViewOp>::ConvertOpToLLVMPattern;

  LogicalResult
  matchAndRewrite(memref::SubViewOp subViewOp, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    auto loc = subViewOp.getLoc();

    auto sourceMemRefType = subViewOp.source().getType().cast<MemRefType>();
    auto sourceElementTy =
        typeConverter->convertType(sourceMemRefType.getElementType());

    auto viewMemRefType = subViewOp.getType();
    auto inferredType =
        memref::SubViewOp::inferResultType(
            subViewOp.getSourceType(),
            extractFromI64ArrayAttr(subViewOp.static_offsets()),
            extractFromI64ArrayAttr(subViewOp.static_sizes()),
            extractFromI64ArrayAttr(subViewOp.static_strides()))
            .cast<MemRefType>();
    auto targetElementTy =
        typeConverter->convertType(viewMemRefType.getElementType());
    auto targetDescTy = typeConverter->convertType(viewMemRefType);
    if (!sourceElementTy || !targetDescTy || !targetElementTy ||
        !LLVM::isCompatibleType(sourceElementTy) ||
        !LLVM::isCompatibleType(targetElementTy) ||
        !LLVM::isCompatibleType(targetDescTy))
      return failure();

    // Extract the offset and strides from the type.
    int64_t offset;
    SmallVector<int64_t, 4> strides;
    auto successStrides = getStridesAndOffset(inferredType, strides, offset);
    if (failed(successStrides))
      return failure();

    // Create the descriptor.
    if (!LLVM::isCompatibleType(adaptor.getOperands().front().getType()))
      return failure();
    MemRefDescriptor sourceMemRef(adaptor.getOperands().front());
    auto targetMemRef = MemRefDescriptor::undef(rewriter, loc, targetDescTy);

    // Copy the buffer pointer from the old descriptor to the new one.
    Value extracted = sourceMemRef.allocatedPtr(rewriter, loc);
    Value bitcastPtr = rewriter.create<LLVM::BitcastOp>(
        loc,
        LLVM::LLVMPointerType::get(targetElementTy,
                                   viewMemRefType.getMemorySpaceAsInt()),
        extracted);
    targetMemRef.setAllocatedPtr(rewriter, loc, bitcastPtr);

    // Copy the aligned pointer from the old descriptor to the new one.
    extracted = sourceMemRef.alignedPtr(rewriter, loc);
    bitcastPtr = rewriter.create<LLVM::BitcastOp>(
        loc,
        LLVM::LLVMPointerType::get(targetElementTy,
                                   viewMemRefType.getMemorySpaceAsInt()),
        extracted);
    targetMemRef.setAlignedPtr(rewriter, loc, bitcastPtr);

    size_t inferredShapeRank = inferredType.getRank();
    size_t resultShapeRank = viewMemRefType.getRank();

    // Extract strides needed to compute offset.
    SmallVector<Value, 4> strideValues;
    strideValues.reserve(inferredShapeRank);
    for (unsigned i = 0; i < inferredShapeRank; ++i)
      strideValues.push_back(sourceMemRef.stride(rewriter, loc, i));

    // Offset.
    auto llvmIndexType = typeConverter->convertType(rewriter.getIndexType());
    if (!ShapedType::isDynamicStrideOrOffset(offset)) {
      targetMemRef.setConstantOffset(rewriter, loc, offset);
    } else {
      Value baseOffset = sourceMemRef.offset(rewriter, loc);
      // `inferredShapeRank` may be larger than the number of offset operands
      // because of trailing semantics. In this case, the offset is guaranteed
      // to be interpreted as 0 and we can just skip the extra dimensions.
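      // Illustrative note: the loop below folds the subview offsets into the
      // base offset as baseOffset + sum_i(offset_i * stride_i) over the
      // leading dimensions that carry explicit offset operands.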
      for (unsigned i = 0, e = std::min(inferredShapeRank,
                                        subViewOp.getMixedOffsets().size());
           i < e; ++i) {
        Value offset =
            // TODO: need OpFoldResult ODS adaptor to clean this up.
            subViewOp.isDynamicOffset(i)
                ? adaptor.getOperands()[subViewOp.getIndexOfDynamicOffset(i)]
                : rewriter.create<LLVM::ConstantOp>(
                      loc, llvmIndexType,
                      rewriter.getI64IntegerAttr(subViewOp.getStaticOffset(i)));
        Value mul = rewriter.create<LLVM::MulOp>(loc, offset, strideValues[i]);
        baseOffset = rewriter.create<LLVM::AddOp>(loc, baseOffset, mul);
      }
      targetMemRef.setOffset(rewriter, loc, baseOffset);
    }

    // Update sizes and strides.
    SmallVector<OpFoldResult> mixedSizes = subViewOp.getMixedSizes();
    SmallVector<OpFoldResult> mixedStrides = subViewOp.getMixedStrides();
    assert(mixedSizes.size() == mixedStrides.size() &&
           "expected sizes and strides of equal length");
    llvm::SmallBitVector unusedDims = subViewOp.getDroppedDims();
    for (int i = inferredShapeRank - 1, j = resultShapeRank - 1;
         i >= 0 && j >= 0; --i) {
      if (unusedDims.test(i))
        continue;

      // `i` may overflow subViewOp.getMixedSizes because of trailing
      // semantics. In this case, the size is guaranteed to be interpreted as
      // Dim and the stride as 1.
      Value size, stride;
      if (static_cast<unsigned>(i) >= mixedSizes.size()) {
        // If the static size is available, use it directly. This is similar
        // to the folding of dim(constant-op) but removes the need for dim to
        // be aware of LLVM constants and for this pass to be aware of std
        // constants.
        int64_t staticSize =
            subViewOp.source().getType().cast<MemRefType>().getShape()[i];
        if (staticSize != ShapedType::kDynamicSize) {
          size = rewriter.create<LLVM::ConstantOp>(
              loc, llvmIndexType, rewriter.getI64IntegerAttr(staticSize));
        } else {
          Value pos = rewriter.create<LLVM::ConstantOp>(
              loc, llvmIndexType, rewriter.getI64IntegerAttr(i));
          Value dim =
              rewriter.create<memref::DimOp>(loc, subViewOp.source(), pos);
          auto cast = rewriter.create<UnrealizedConversionCastOp>(
              loc, llvmIndexType, dim);
          size = cast.getResult(0);
        }
        stride = rewriter.create<LLVM::ConstantOp>(
            loc, llvmIndexType, rewriter.getI64IntegerAttr(1));
      } else {
        // TODO: need OpFoldResult ODS adaptor to clean this up.
        size =
            subViewOp.isDynamicSize(i)
                ? adaptor.getOperands()[subViewOp.getIndexOfDynamicSize(i)]
                : rewriter.create<LLVM::ConstantOp>(
                      loc, llvmIndexType,
                      rewriter.getI64IntegerAttr(subViewOp.getStaticSize(i)));
        if (!ShapedType::isDynamicStrideOrOffset(strides[i])) {
          stride = rewriter.create<LLVM::ConstantOp>(
              loc, llvmIndexType, rewriter.getI64IntegerAttr(strides[i]));
        } else {
          stride =
              subViewOp.isDynamicStride(i)
                  ? adaptor.getOperands()[subViewOp.getIndexOfDynamicStride(i)]
                  : rewriter.create<LLVM::ConstantOp>(
                        loc, llvmIndexType,
                        rewriter.getI64IntegerAttr(
                            subViewOp.getStaticStride(i)));
          stride = rewriter.create<LLVM::MulOp>(loc, stride, strideValues[i]);
        }
      }
      targetMemRef.setSize(rewriter, loc, j, size);
      targetMemRef.setStride(rewriter, loc, j, stride);
      j--;
    }

    rewriter.replaceOp(subViewOp, {targetMemRef});
    return success();
  }
};

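// Illustrative example for the subview lowering above: `memref.subview
// %src[2] [8] [1]` of a `memref<16xf32>` reuses the source pointers and, since
// the inferred offset (2) and stride (1) are static, writes them as constants
// into the result descriptor together with the constant size 8.
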
/// Conversion pattern that transforms a transpose op into:
///   1. An `llvm.mlir.undef` operation to create a memref descriptor.
///   2. Updates to the descriptor to introduce the data ptr, offset, and the
///      permuted sizes and strides.
/// The transpose op is replaced by the descriptor.
class TransposeOpLowering
    : public ConvertOpToLLVMPattern<memref::TransposeOp> {
public:
  using ConvertOpToLLVMPattern<memref::TransposeOp>::ConvertOpToLLVMPattern;

  LogicalResult
  matchAndRewrite(memref::TransposeOp transposeOp, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    auto loc = transposeOp.getLoc();
    MemRefDescriptor viewMemRef(adaptor.in());

    // No permutation, early exit.
    if (transposeOp.permutation().isIdentity())
      return rewriter.replaceOp(transposeOp, {viewMemRef}), success();

    auto targetMemRef = MemRefDescriptor::undef(
        rewriter, loc,
        typeConverter->convertType(transposeOp.getShapedType()));

    // Copy the base and aligned pointers from the old descriptor to the new
    // one.
    targetMemRef.setAllocatedPtr(rewriter, loc,
                                 viewMemRef.allocatedPtr(rewriter, loc));
    targetMemRef.setAlignedPtr(rewriter, loc,
                               viewMemRef.alignedPtr(rewriter, loc));

    // Copy the offset from the old descriptor to the new one.
    targetMemRef.setOffset(rewriter, loc, viewMemRef.offset(rewriter, loc));

    // Iterate over the dimensions and apply size/stride permutation.
    for (const auto &en :
         llvm::enumerate(transposeOp.permutation().getResults())) {
      int sourcePos = en.index();
      int targetPos = en.value().cast<AffineDimExpr>().getPosition();
      targetMemRef.setSize(rewriter, loc, targetPos,
                           viewMemRef.size(rewriter, loc, sourcePos));
      targetMemRef.setStride(rewriter, loc, targetPos,
                             viewMemRef.stride(rewriter, loc, sourcePos));
    }

    rewriter.replaceOp(transposeOp, {targetMemRef});
    return success();
  }
};

/// Conversion pattern that transforms a view op into:
///   1. An `llvm.mlir.undef` operation to create a memref descriptor
///   2. Updates to the descriptor to introduce the data ptr, offset, size
///      and stride.
/// The view op is replaced by the descriptor.
struct ViewOpLowering : public ConvertOpToLLVMPattern<memref::ViewOp> {
  using ConvertOpToLLVMPattern<memref::ViewOp>::ConvertOpToLLVMPattern;

  // Build and return the value for the idx^th shape dimension, either by
  // returning the constant shape dimension or counting the proper dynamic
  // size.
  Value getSize(ConversionPatternRewriter &rewriter, Location loc,
                ArrayRef<int64_t> shape, ValueRange dynamicSizes,
                unsigned idx) const {
    assert(idx < shape.size());
    if (!ShapedType::isDynamic(shape[idx]))
      return createIndexConstant(rewriter, loc, shape[idx]);
    // Count the number of dynamic dims in range [0, idx).
    unsigned nDynamic = llvm::count_if(shape.take_front(idx), [](int64_t v) {
      return ShapedType::isDynamic(v);
    });
    return dynamicSizes[nDynamic];
  }

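  // For example (illustrative), calling getSize above with shape [2, ?, 4, ?]
  // and idx == 3 finds one dynamic dimension before idx, so the second entry
  // of `dynamicSizes` is returned.
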
  // Build and return the idx^th stride, either by returning the constant
  // stride or by computing the dynamic stride from the current
  // `runningStride` and `nextSize`. The caller should keep a running stride
  // and update it with the result returned by this function.
  Value getStride(ConversionPatternRewriter &rewriter, Location loc,
                  ArrayRef<int64_t> strides, Value nextSize,
                  Value runningStride, unsigned idx) const {
    assert(idx < strides.size());
    if (!ShapedType::isDynamicStrideOrOffset(strides[idx]))
      return createIndexConstant(rewriter, loc, strides[idx]);
    if (nextSize)
      return runningStride
                 ? rewriter.create<LLVM::MulOp>(loc, runningStride, nextSize)
                 : nextSize;
    assert(!runningStride);
    return createIndexConstant(rewriter, loc, 1);
  }

  LogicalResult
  matchAndRewrite(memref::ViewOp viewOp, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    auto loc = viewOp.getLoc();

    auto viewMemRefType = viewOp.getType();
    auto targetElementTy =
        typeConverter->convertType(viewMemRefType.getElementType());
    auto targetDescTy = typeConverter->convertType(viewMemRefType);
    if (!targetDescTy || !targetElementTy ||
        !LLVM::isCompatibleType(targetElementTy) ||
        !LLVM::isCompatibleType(targetDescTy))
      return viewOp.emitWarning("Target descriptor type not converted to LLVM"),
             failure();

    int64_t offset;
    SmallVector<int64_t, 4> strides;
    auto successStrides = getStridesAndOffset(viewMemRefType, strides, offset);
    if (failed(successStrides))
      return viewOp.emitWarning("cannot cast to non-strided shape"), failure();
    assert(offset == 0 && "expected offset to be 0");

    // Create the descriptor.
    MemRefDescriptor sourceMemRef(adaptor.source());
    auto targetMemRef = MemRefDescriptor::undef(rewriter, loc, targetDescTy);

    // Field 1: Copy the allocated pointer, used for malloc/free.
    Value allocatedPtr = sourceMemRef.allocatedPtr(rewriter, loc);
    auto srcMemRefType = viewOp.source().getType().cast<MemRefType>();
    Value bitcastPtr = rewriter.create<LLVM::BitcastOp>(
        loc,
        LLVM::LLVMPointerType::get(targetElementTy,
                                   srcMemRefType.getMemorySpaceAsInt()),
        allocatedPtr);
    targetMemRef.setAllocatedPtr(rewriter, loc, bitcastPtr);

    // Field 2: Copy the actual aligned pointer to payload.
    Value alignedPtr = sourceMemRef.alignedPtr(rewriter, loc);
    alignedPtr = rewriter.create<LLVM::GEPOp>(loc, alignedPtr.getType(),
                                              alignedPtr, adaptor.byte_shift());
    bitcastPtr = rewriter.create<LLVM::BitcastOp>(
        loc,
        LLVM::LLVMPointerType::get(targetElementTy,
                                   srcMemRefType.getMemorySpaceAsInt()),
        alignedPtr);
    targetMemRef.setAlignedPtr(rewriter, loc, bitcastPtr);

    // Field 3: The offset in the resulting type must be 0. This is because of
    // the type change: an offset on srcType* may not be expressible as an
    // offset on dstType*.
    targetMemRef.setOffset(rewriter, loc,
                           createIndexConstant(rewriter, loc, offset));

    // Early exit for 0-D corner case.
    if (viewMemRefType.getRank() == 0)
      return rewriter.replaceOp(viewOp, {targetMemRef}), success();

    // Fields 4 and 5: Update sizes and strides.
    if (strides.back() != 1)
      return viewOp.emitWarning("cannot cast to non-contiguous shape"),
             failure();
    Value stride = nullptr, nextSize = nullptr;
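    // Illustrative note: sizes and strides are filled from the innermost
    // dimension outwards, so each dynamic stride is the running product of the
    // sizes of all inner dimensions.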
    for (int i = viewMemRefType.getRank() - 1; i >= 0; --i) {
      // Update size.
      Value size =
          getSize(rewriter, loc, viewMemRefType.getShape(), adaptor.sizes(), i);
      targetMemRef.setSize(rewriter, loc, i, size);
      // Update stride.
      stride = getStride(rewriter, loc, strides, nextSize, stride, i);
      targetMemRef.setStride(rewriter, loc, i, stride);
      nextSize = size;
    }

    rewriter.replaceOp(viewOp, {targetMemRef});
    return success();
  }
};

//===----------------------------------------------------------------------===//
// AtomicRMWOpLowering
//===----------------------------------------------------------------------===//

/// Try to match the kind of a memref.atomic_rmw to determine whether to use a
/// lowering to llvm.atomicrmw or fall back to llvm.cmpxchg.
static Optional<LLVM::AtomicBinOp>
matchSimpleAtomicOp(memref::AtomicRMWOp atomicOp) {
  switch (atomicOp.kind()) {
  case arith::AtomicRMWKind::addf:
    return LLVM::AtomicBinOp::fadd;
  case arith::AtomicRMWKind::addi:
    return LLVM::AtomicBinOp::add;
  case arith::AtomicRMWKind::assign:
    return LLVM::AtomicBinOp::xchg;
  case arith::AtomicRMWKind::maxs:
    return LLVM::AtomicBinOp::max;
  case arith::AtomicRMWKind::maxu:
    return LLVM::AtomicBinOp::umax;
  case arith::AtomicRMWKind::mins:
    return LLVM::AtomicBinOp::min;
  case arith::AtomicRMWKind::minu:
    return LLVM::AtomicBinOp::umin;
  case arith::AtomicRMWKind::ori:
    return LLVM::AtomicBinOp::_or;
  case arith::AtomicRMWKind::andi:
    return LLVM::AtomicBinOp::_and;
  default:
    return llvm::None;
  }
  llvm_unreachable("Invalid AtomicRMWKind");
}

struct AtomicRMWOpLowering : public LoadStoreOpLowering<memref::AtomicRMWOp> {
  using Base::Base;

  LogicalResult
  matchAndRewrite(memref::AtomicRMWOp atomicOp, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    if (failed(match(atomicOp)))
      return failure();
    auto maybeKind = matchSimpleAtomicOp(atomicOp);
    if (!maybeKind)
      return failure();
    auto resultType = adaptor.value().getType();
    auto memRefType = atomicOp.getMemRefType();
    auto dataPtr =
        getStridedElementPtr(atomicOp.getLoc(), memRefType, adaptor.memref(),
                             adaptor.indices(), rewriter);
    rewriter.replaceOpWithNewOp<LLVM::AtomicRMWOp>(
        atomicOp, resultType, *maybeKind, dataPtr, adaptor.value(),
        LLVM::AtomicOrdering::acq_rel);
    return success();
  }
};

} // namespace

void mlir::populateMemRefToLLVMConversionPatterns(LLVMTypeConverter &converter,
                                                  RewritePatternSet &patterns) {
  // clang-format off
  patterns.add<
      AllocaOpLowering,
      AllocaScopeOpLowering,
      AtomicRMWOpLowering,
      AssumeAlignmentOpLowering,
      DimOpLowering,
      GenericAtomicRMWOpLowering,
      GlobalMemrefOpLowering,
      GetGlobalMemrefOpLowering,
      LoadOpLowering,
      MemRefCastOpLowering,
      MemRefCopyOpLowering,
      MemRefReinterpretCastOpLowering,
      MemRefReshapeOpLowering,
      PrefetchOpLowering,
      RankOpLowering,
      ReassociatingReshapeOpConversion<memref::ExpandShapeOp>,
      ReassociatingReshapeOpConversion<memref::CollapseShapeOp>,
      StoreOpLowering,
      SubViewOpLowering,
      TransposeOpLowering,
      ViewOpLowering>(converter);
  // clang-format on
  auto allocLowering = converter.getOptions().allocLowering;
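  // memref.alloc is lowered either through aligned_alloc or plain malloc
  // depending on the pass options; memref.dealloc is lowered to a call to
  // free in both cases.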
  if (allocLowering == LowerToLLVMOptions::AllocLowering::AlignedAlloc)
    patterns.add<AlignedAllocOpLowering, DeallocOpLowering>(converter);
  else if (allocLowering == LowerToLLVMOptions::AllocLowering::Malloc)
    patterns.add<AllocOpLowering, DeallocOpLowering>(converter);
}

namespace {
struct MemRefToLLVMPass : public ConvertMemRefToLLVMBase<MemRefToLLVMPass> {
  MemRefToLLVMPass() = default;

  void runOnOperation() override {
    Operation *op = getOperation();
    const auto &dataLayoutAnalysis = getAnalysis<DataLayoutAnalysis>();
    LowerToLLVMOptions options(&getContext(),
                               dataLayoutAnalysis.getAtOrAbove(op));
    options.allocLowering =
        (useAlignedAlloc ? LowerToLLVMOptions::AllocLowering::AlignedAlloc
                         : LowerToLLVMOptions::AllocLowering::Malloc);
    if (indexBitwidth != kDeriveIndexBitwidthFromDataLayout)
      options.overrideIndexBitwidth(indexBitwidth);

    LLVMTypeConverter typeConverter(&getContext(), options,
                                    &dataLayoutAnalysis);
    RewritePatternSet patterns(&getContext());
    populateMemRefToLLVMConversionPatterns(typeConverter, patterns);
    LLVMConversionTarget target(getContext());
    target.addLegalOp<FuncOp>();
    if (failed(applyPartialConversion(op, target, std::move(patterns))))
      signalPassFailure();
  }
};
} // namespace

std::unique_ptr<Pass> mlir::createMemRefToLLVMPass() {
  return std::make_unique<MemRefToLLVMPass>();
}
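// Typical usage (illustrative): the pass can be appended to a pass pipeline
// with `pm.addPass(createMemRefToLLVMPass())`, or requested from the command
// line via `mlir-opt --convert-memref-to-llvm`.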