//===- MemRefToLLVM.cpp - MemRef to LLVM dialect conversion --------------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//

#include "mlir/Conversion/MemRefToLLVM/MemRefToLLVM.h"
#include "../PassDetail.h"
#include "mlir/Analysis/DataLayoutAnalysis.h"
#include "mlir/Conversion/LLVMCommon/ConversionTarget.h"
#include "mlir/Conversion/LLVMCommon/Pattern.h"
#include "mlir/Conversion/LLVMCommon/TypeConverter.h"
#include "mlir/Conversion/MemRefToLLVM/AllocLikeConversion.h"
#include "mlir/Dialect/LLVMIR/FunctionCallUtils.h"
#include "mlir/Dialect/LLVMIR/LLVMDialect.h"
#include "mlir/Dialect/MemRef/IR/MemRef.h"
#include "mlir/IR/AffineMap.h"
#include "mlir/IR/BlockAndValueMapping.h"
#include "llvm/ADT/SmallBitVector.h"

using namespace mlir;

namespace {

struct AllocOpLowering : public AllocLikeOpLLVMLowering {
  AllocOpLowering(LLVMTypeConverter &converter)
      : AllocLikeOpLLVMLowering(memref::AllocOp::getOperationName(),
                                converter) {}

  std::tuple<Value, Value> allocateBuffer(ConversionPatternRewriter &rewriter,
                                          Location loc, Value sizeBytes,
                                          Operation *op) const override {
    // Heap allocations.
    memref::AllocOp allocOp = cast<memref::AllocOp>(op);
    MemRefType memRefType = allocOp.getType();

    Value alignment;
    if (auto alignmentAttr = allocOp.alignment()) {
      alignment = createIndexConstant(rewriter, loc, *alignmentAttr);
    } else if (!memRefType.getElementType().isSignlessIntOrIndexOrFloat()) {
      // In the case where no alignment is specified, we may want to override
      // `malloc's` behavior. `malloc` typically aligns at the size of the
      // biggest scalar on a target HW. For non-scalars, use the natural
      // alignment of the LLVM type given by the LLVM DataLayout.
      alignment = getSizeInBytes(loc, memRefType.getElementType(), rewriter);
    }

    if (alignment) {
      // Adjust the allocation size to consider alignment.
      sizeBytes = rewriter.create<LLVM::AddOp>(loc, sizeBytes, alignment);
    }

    // Allocate the underlying buffer and store a pointer to it in the MemRef
    // descriptor.
    Type elementPtrType = this->getElementPtrType(memRefType);
    auto allocFuncOp = LLVM::lookupOrCreateMallocFn(
        allocOp->getParentOfType<ModuleOp>(), getIndexType());
    auto results = createLLVMCall(rewriter, loc, allocFuncOp, {sizeBytes},
                                  getVoidPtrType());
    Value allocatedPtr =
        rewriter.create<LLVM::BitcastOp>(loc, elementPtrType, results[0]);

    Value alignedPtr = allocatedPtr;
    if (alignment) {
      // Compute the aligned type pointer.
      Value allocatedInt =
          rewriter.create<LLVM::PtrToIntOp>(loc, getIndexType(), allocatedPtr);
      Value alignmentInt =
          createAligned(rewriter, loc, allocatedInt, alignment);
      alignedPtr =
          rewriter.create<LLVM::IntToPtrOp>(loc, elementPtrType, alignmentInt);
    }

    return std::make_tuple(allocatedPtr, alignedPtr);
  }
};

struct AlignedAllocOpLowering : public AllocLikeOpLLVMLowering {
  AlignedAllocOpLowering(LLVMTypeConverter &converter)
      : AllocLikeOpLLVMLowering(memref::AllocOp::getOperationName(),
                                converter) {}

  /// Returns the memref's element size in bytes using the data layout active
  /// at `op`.
  // TODO: there are other places where this is used. Expose publicly?
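  // For example, for memref<4x?xf32> this returns 4 (the f32 size under the
  // active data layout); for an element type that is itself a (un)ranked
  // memref, it returns the byte size of the nested descriptor struct instead.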
  unsigned getMemRefEltSizeInBytes(MemRefType memRefType, Operation *op) const {
    const DataLayout *layout = &defaultLayout;
    if (const DataLayoutAnalysis *analysis =
            getTypeConverter()->getDataLayoutAnalysis()) {
      layout = &analysis->getAbove(op);
    }
    Type elementType = memRefType.getElementType();
    if (auto memRefElementType = elementType.dyn_cast<MemRefType>())
      return getTypeConverter()->getMemRefDescriptorSize(memRefElementType,
                                                         *layout);
    if (auto memRefElementType = elementType.dyn_cast<UnrankedMemRefType>())
      return getTypeConverter()->getUnrankedMemRefDescriptorSize(
          memRefElementType, *layout);
    return layout->getTypeSize(elementType);
  }

  /// Returns true if the memref size in bytes is known to be a multiple of
  /// `factor` assuming the data layout active at `op`.
  bool isMemRefSizeMultipleOf(MemRefType type, uint64_t factor,
                              Operation *op) const {
    uint64_t sizeDivisor = getMemRefEltSizeInBytes(type, op);
    for (unsigned i = 0, e = type.getRank(); i < e; i++) {
      if (ShapedType::isDynamic(type.getDimSize(i)))
        continue;
      sizeDivisor = sizeDivisor * type.getDimSize(i);
    }
    return sizeDivisor % factor == 0;
  }

  /// Returns the alignment to be used for the allocation call itself.
  /// aligned_alloc requires the alignment to be a power of two, and the
  /// allocation size to be a multiple of that alignment.
  int64_t getAllocationAlignment(memref::AllocOp allocOp) const {
    if (Optional<uint64_t> alignment = allocOp.alignment())
      return *alignment;

    // Whenever we don't have alignment set, we will use an alignment
    // consistent with the element type; since the alignment has to be a power
    // of two, we will bump it to the next power of two if it isn't one
    // already.
    auto eltSizeBytes = getMemRefEltSizeInBytes(allocOp.getType(), allocOp);
    return std::max(kMinAlignedAllocAlignment,
                    llvm::PowerOf2Ceil(eltSizeBytes));
  }

  std::tuple<Value, Value> allocateBuffer(ConversionPatternRewriter &rewriter,
                                          Location loc, Value sizeBytes,
                                          Operation *op) const override {
    // Heap allocations.
    memref::AllocOp allocOp = cast<memref::AllocOp>(op);
    MemRefType memRefType = allocOp.getType();
    int64_t alignment = getAllocationAlignment(allocOp);
    Value allocAlignment = createIndexConstant(rewriter, loc, alignment);

    // aligned_alloc requires size to be a multiple of alignment; we will pad
    // the size to the next multiple if necessary.
    if (!isMemRefSizeMultipleOf(memRefType, alignment, op))
      sizeBytes = createAligned(rewriter, loc, sizeBytes, allocAlignment);

    Type elementPtrType = this->getElementPtrType(memRefType);
    auto allocFuncOp = LLVM::lookupOrCreateAlignedAllocFn(
        allocOp->getParentOfType<ModuleOp>(), getIndexType());
    auto results =
        createLLVMCall(rewriter, loc, allocFuncOp, {allocAlignment, sizeBytes},
                       getVoidPtrType());
    Value allocatedPtr =
        rewriter.create<LLVM::BitcastOp>(loc, elementPtrType, results[0]);

    return std::make_tuple(allocatedPtr, allocatedPtr);
  }

  /// The minimum alignment to use with aligned_alloc (has to be a power of 2).
  static constexpr uint64_t kMinAlignedAllocAlignment = 16UL;

  /// Default layout to use in absence of the corresponding analysis.
  DataLayout defaultLayout;
};

// Out-of-line definition, required until C++17.
constexpr uint64_t AlignedAllocOpLowering::kMinAlignedAllocAlignment;

struct AllocaOpLowering : public AllocLikeOpLLVMLowering {
  AllocaOpLowering(LLVMTypeConverter &converter)
      : AllocLikeOpLLVMLowering(memref::AllocaOp::getOperationName(),
                                converter) {}

  /// Allocates the underlying buffer with an llvm.alloca. For stack
  /// allocations the buffer is already aligned as requested, so the allocated
  /// and aligned pointers returned are the same value.
  std::tuple<Value, Value> allocateBuffer(ConversionPatternRewriter &rewriter,
                                          Location loc, Value sizeBytes,
                                          Operation *op) const override {

    // With alloca, one gets a pointer to the element type right away.
    // For stack allocations.
    auto allocaOp = cast<memref::AllocaOp>(op);
    auto elementPtrType = this->getElementPtrType(allocaOp.getType());

    auto allocatedElementPtr = rewriter.create<LLVM::AllocaOp>(
        loc, elementPtrType, sizeBytes,
        allocaOp.alignment() ? *allocaOp.alignment() : 0);

    return std::make_tuple(allocatedElementPtr, allocatedElementPtr);
  }
};

struct AllocaScopeOpLowering
    : public ConvertOpToLLVMPattern<memref::AllocaScopeOp> {
  using ConvertOpToLLVMPattern<memref::AllocaScopeOp>::ConvertOpToLLVMPattern;

  LogicalResult
  matchAndRewrite(memref::AllocaScopeOp allocaScopeOp, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    OpBuilder::InsertionGuard guard(rewriter);
    Location loc = allocaScopeOp.getLoc();

    // Split the current block before the AllocaScopeOp to create the inlining
    // point.
    auto *currentBlock = rewriter.getInsertionBlock();
    auto *remainingOpsBlock =
        rewriter.splitBlock(currentBlock, rewriter.getInsertionPoint());
    Block *continueBlock;
    if (allocaScopeOp.getNumResults() == 0) {
      continueBlock = remainingOpsBlock;
    } else {
      continueBlock = rewriter.createBlock(
          remainingOpsBlock, allocaScopeOp.getResultTypes(),
          SmallVector<Location>(allocaScopeOp->getNumResults(),
                                allocaScopeOp.getLoc()));
      rewriter.create<LLVM::BrOp>(loc, ValueRange(), remainingOpsBlock);
    }

    // Inline the body region.
    Block *beforeBody = &allocaScopeOp.bodyRegion().front();
    Block *afterBody = &allocaScopeOp.bodyRegion().back();
    rewriter.inlineRegionBefore(allocaScopeOp.bodyRegion(), continueBlock);

    // Save the stack and then branch into the body of the region.
    rewriter.setInsertionPointToEnd(currentBlock);
    auto stackSaveOp =
        rewriter.create<LLVM::StackSaveOp>(loc, getVoidPtrType());
    rewriter.create<LLVM::BrOp>(loc, ValueRange(), beforeBody);

    // Replace the alloca_scope return with a branch that jumps out of the
    // body. Stack restore before leaving the body region.
    rewriter.setInsertionPointToEnd(afterBody);
    auto returnOp =
        cast<memref::AllocaScopeReturnOp>(afterBody->getTerminator());
    auto branchOp = rewriter.replaceOpWithNewOp<LLVM::BrOp>(
        returnOp, returnOp.results(), continueBlock);

    // Insert stack restore before jumping out of the body of the region.
    rewriter.setInsertionPoint(branchOp);
    rewriter.create<LLVM::StackRestoreOp>(loc, stackSaveOp);

    // Replace the op with values returned from the body region.
    rewriter.replaceOp(allocaScopeOp, continueBlock->getArguments());

    return success();
  }
};

struct AssumeAlignmentOpLowering
    : public ConvertOpToLLVMPattern<memref::AssumeAlignmentOp> {
  using ConvertOpToLLVMPattern<
      memref::AssumeAlignmentOp>::ConvertOpToLLVMPattern;

  LogicalResult
  matchAndRewrite(memref::AssumeAlignmentOp op, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    Value memref = adaptor.memref();
    unsigned alignment = op.alignment();
    auto loc = op.getLoc();

    MemRefDescriptor memRefDescriptor(memref);
    Value ptr = memRefDescriptor.alignedPtr(rewriter, memref.getLoc());

    // Emit llvm.assume(memref.alignedPtr & (alignment - 1) == 0). Notice that
    // the asserted memref.alignedPtr isn't used anywhere else, as the real
    // users like load/store/views always re-extract memref.alignedPtr as they
    // get lowered.
    //
    // This relies on LLVM's CSE optimization (potentially after SROA), since
    // after CSE all memref.alignedPtr instances get de-duplicated into the
    // same pointer SSA value.
    auto intPtrType =
        getIntPtrType(memRefDescriptor.getElementPtrType().getAddressSpace());
    Value zero = createIndexAttrConstant(rewriter, loc, intPtrType, 0);
    Value mask =
        createIndexAttrConstant(rewriter, loc, intPtrType, alignment - 1);
    Value ptrValue = rewriter.create<LLVM::PtrToIntOp>(loc, intPtrType, ptr);
    rewriter.create<LLVM::AssumeOp>(
        loc, rewriter.create<LLVM::ICmpOp>(
                 loc, LLVM::ICmpPredicate::eq,
                 rewriter.create<LLVM::AndOp>(loc, ptrValue, mask), zero));

    rewriter.eraseOp(op);
    return success();
  }
};

// A `dealloc` is converted into a call to `free` on the underlying data
// buffer. The memref descriptor being an SSA value, there is no need to clean
// it up in any way.
struct DeallocOpLowering : public ConvertOpToLLVMPattern<memref::DeallocOp> {
  using ConvertOpToLLVMPattern<memref::DeallocOp>::ConvertOpToLLVMPattern;

  explicit DeallocOpLowering(LLVMTypeConverter &converter)
      : ConvertOpToLLVMPattern<memref::DeallocOp>(converter) {}

  LogicalResult
  matchAndRewrite(memref::DeallocOp op, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    // Insert the `free` declaration if it is not already present.
    auto freeFunc = LLVM::lookupOrCreateFreeFn(op->getParentOfType<ModuleOp>());
    MemRefDescriptor memref(adaptor.memref());
    Value casted = rewriter.create<LLVM::BitcastOp>(
        op.getLoc(), getVoidPtrType(),
        memref.allocatedPtr(rewriter, op.getLoc()));
    rewriter.replaceOpWithNewOp<LLVM::CallOp>(
        op, TypeRange(), SymbolRefAttr::get(freeFunc), casted);
    return success();
  }
};

// A `dim` is converted to a constant for static sizes and to an access to the
// size stored in the memref descriptor for dynamic sizes.
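// For illustration (value names are invented), with %m : memref<4x?xf32> and
// %desc its converted descriptor:
//   memref.dim %m, %c0  ->  llvm.mlir.constant(4 : index)
//   memref.dim %m, %c1  ->  llvm.extractvalue %desc[3, 1]  (the sizes array)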
struct DimOpLowering : public ConvertOpToLLVMPattern<memref::DimOp> {
  using ConvertOpToLLVMPattern<memref::DimOp>::ConvertOpToLLVMPattern;

  LogicalResult
  matchAndRewrite(memref::DimOp dimOp, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    Type operandType = dimOp.source().getType();
    if (operandType.isa<UnrankedMemRefType>()) {
      rewriter.replaceOp(
          dimOp, {extractSizeOfUnrankedMemRef(
                     operandType, dimOp, adaptor.getOperands(), rewriter)});

      return success();
    }
    if (operandType.isa<MemRefType>()) {
      rewriter.replaceOp(
          dimOp, {extractSizeOfRankedMemRef(operandType, dimOp,
                                            adaptor.getOperands(), rewriter)});
      return success();
    }
    llvm_unreachable("expected MemRefType or UnrankedMemRefType");
  }

private:
  Value extractSizeOfUnrankedMemRef(Type operandType, memref::DimOp dimOp,
                                    OpAdaptor adaptor,
                                    ConversionPatternRewriter &rewriter) const {
    Location loc = dimOp.getLoc();

    auto unrankedMemRefType = operandType.cast<UnrankedMemRefType>();
    auto scalarMemRefType =
        MemRefType::get({}, unrankedMemRefType.getElementType());
    unsigned addressSpace = unrankedMemRefType.getMemorySpaceAsInt();

    // Extract pointer to the underlying ranked descriptor and bitcast it to a
    // memref<element_type> descriptor pointer to minimize the number of GEP
    // operations.
    UnrankedMemRefDescriptor unrankedDesc(adaptor.source());
    Value underlyingRankedDesc = unrankedDesc.memRefDescPtr(rewriter, loc);
    Value scalarMemRefDescPtr = rewriter.create<LLVM::BitcastOp>(
        loc,
        LLVM::LLVMPointerType::get(typeConverter->convertType(scalarMemRefType),
                                   addressSpace),
        underlyingRankedDesc);

    // Get pointer to offset field of memref<element_type> descriptor.
    Type indexPtrTy = LLVM::LLVMPointerType::get(
        getTypeConverter()->getIndexType(), addressSpace);
    Value two = rewriter.create<LLVM::ConstantOp>(
        loc, typeConverter->convertType(rewriter.getI32Type()),
        rewriter.getI32IntegerAttr(2));
    Value offsetPtr = rewriter.create<LLVM::GEPOp>(
        loc, indexPtrTy, scalarMemRefDescPtr,
        ValueRange({createIndexConstant(rewriter, loc, 0), two}));

    // The size value that we have to extract can be obtained using GEPOp with
    // `dimOp.index() + 1` as the index argument.
    Value idxPlusOne = rewriter.create<LLVM::AddOp>(
        loc, createIndexConstant(rewriter, loc, 1), adaptor.index());
    Value sizePtr = rewriter.create<LLVM::GEPOp>(loc, indexPtrTy, offsetPtr,
                                                 ValueRange({idxPlusOne}));
    return rewriter.create<LLVM::LoadOp>(loc, sizePtr);
  }

  Optional<int64_t> getConstantDimIndex(memref::DimOp dimOp) const {
    if (Optional<int64_t> idx = dimOp.getConstantIndex())
      return idx;

    if (auto constantOp = dimOp.index().getDefiningOp<LLVM::ConstantOp>())
      return constantOp.getValue()
          .cast<IntegerAttr>()
          .getValue()
          .getSExtValue();

    return llvm::None;
  }

  Value extractSizeOfRankedMemRef(Type operandType, memref::DimOp dimOp,
                                  OpAdaptor adaptor,
                                  ConversionPatternRewriter &rewriter) const {
    Location loc = dimOp.getLoc();

    // Take advantage of the constant index when possible.
    MemRefType memRefType = operandType.cast<MemRefType>();
    if (Optional<int64_t> index = getConstantDimIndex(dimOp)) {
      int64_t i = index.getValue();
      if (memRefType.isDynamicDim(i)) {
        // Extract the dynamic size from the memref descriptor.
        MemRefDescriptor descriptor(adaptor.source());
        return descriptor.size(rewriter, loc, i);
      }
      // Use constant for static size.
      int64_t dimSize = memRefType.getDimSize(i);
      return createIndexConstant(rewriter, loc, dimSize);
    }
    Value index = adaptor.index();
    int64_t rank = memRefType.getRank();
    MemRefDescriptor memrefDescriptor(adaptor.source());
    return memrefDescriptor.size(rewriter, loc, index, rank);
  }
};

/// Common base for load and store operations on MemRefs. Restricts the match
/// to supported MemRef types. Provides functionality to emit code accessing a
/// specific element of the underlying data buffer.
template <typename Derived>
struct LoadStoreOpLowering : public ConvertOpToLLVMPattern<Derived> {
  using ConvertOpToLLVMPattern<Derived>::ConvertOpToLLVMPattern;
  using ConvertOpToLLVMPattern<Derived>::isConvertibleAndHasIdentityMaps;
  using Base = LoadStoreOpLowering<Derived>;

  LogicalResult match(Derived op) const override {
    MemRefType type = op.getMemRefType();
    return isConvertibleAndHasIdentityMaps(type) ? success() : failure();
  }
};

/// Wrap an llvm.cmpxchg operation in a while loop so that the operation can be
/// retried until it succeeds in atomically storing a new value into memory.
///
///      +---------------------------------+
///      |   <code before the AtomicRMWOp> |
///      |   <compute initial %loaded>     |
///      |   cf.br loop(%loaded)           |
///      +---------------------------------+
///             |
///  -------|   |
///  |      v   v
///  |   +--------------------------------+
///  |   |   loop(%loaded):               |
///  |   |     <body contents>            |
///  |   |     %pair = cmpxchg            |
///  |   |     %ok = %pair[0]             |
///  |   |     %new = %pair[1]            |
///  |   |     cf.cond_br %ok, end, loop(%new) |
///  |   +--------------------------------+
///  |          |         |
///  |-----------         |
///                       v
///      +--------------------------------+
///      |   end:                         |
///      |     <code after the AtomicRMWOp> |
///      +--------------------------------+
///
struct GenericAtomicRMWOpLowering
    : public LoadStoreOpLowering<memref::GenericAtomicRMWOp> {
  using Base::Base;

  LogicalResult
  matchAndRewrite(memref::GenericAtomicRMWOp atomicOp, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    auto loc = atomicOp.getLoc();
    Type valueType = typeConverter->convertType(atomicOp.getResult().getType());

    // Split the block into initial, loop, and ending parts.
    auto *initBlock = rewriter.getInsertionBlock();
    auto *loopBlock = rewriter.createBlock(
        initBlock->getParent(), std::next(Region::iterator(initBlock)),
        valueType, loc);
    auto *endBlock = rewriter.createBlock(
        loopBlock->getParent(), std::next(Region::iterator(loopBlock)));

    // Operations range to be moved to `endBlock`.
    auto opsToMoveStart = atomicOp->getIterator();
    auto opsToMoveEnd = initBlock->back().getIterator();

    // Compute the loaded value and branch to the loop block.
    rewriter.setInsertionPointToEnd(initBlock);
    auto memRefType = atomicOp.memref().getType().cast<MemRefType>();
    auto dataPtr = getStridedElementPtr(loc, memRefType, adaptor.memref(),
                                        adaptor.indices(), rewriter);
    Value init = rewriter.create<LLVM::LoadOp>(loc, dataPtr);
    rewriter.create<LLVM::BrOp>(loc, init, loopBlock);

    // Prepare the body of the loop block.
    rewriter.setInsertionPointToStart(loopBlock);

    // Clone the GenericAtomicRMWOp region and extract the result.
    auto loopArgument = loopBlock->getArgument(0);
    BlockAndValueMapping mapping;
    mapping.map(atomicOp.getCurrentValue(), loopArgument);
    Block &entryBlock = atomicOp.body().front();
    for (auto &nestedOp : entryBlock.without_terminator()) {
      Operation *clone = rewriter.clone(nestedOp, mapping);
      mapping.map(nestedOp.getResults(), clone->getResults());
    }
    Value result = mapping.lookup(entryBlock.getTerminator()->getOperand(0));

    // Prepare the epilog of the loop block.
    // Append the cmpxchg op to the end of the loop block.
    auto successOrdering = LLVM::AtomicOrdering::acq_rel;
    auto failureOrdering = LLVM::AtomicOrdering::monotonic;
    auto boolType = IntegerType::get(rewriter.getContext(), 1);
    auto pairType = LLVM::LLVMStructType::getLiteral(rewriter.getContext(),
                                                     {valueType, boolType});
    auto cmpxchg = rewriter.create<LLVM::AtomicCmpXchgOp>(
        loc, pairType, dataPtr, loopArgument, result, successOrdering,
        failureOrdering);
    // Extract the %new_loaded and %ok values from the pair.
    Value newLoaded = rewriter.create<LLVM::ExtractValueOp>(
        loc, valueType, cmpxchg, rewriter.getI64ArrayAttr({0}));
    Value ok = rewriter.create<LLVM::ExtractValueOp>(
        loc, boolType, cmpxchg, rewriter.getI64ArrayAttr({1}));

    // Conditionally branch to the end or back to the loop depending on %ok.
    rewriter.create<LLVM::CondBrOp>(loc, ok, endBlock, ArrayRef<Value>(),
                                    loopBlock, newLoaded);

    rewriter.setInsertionPointToEnd(endBlock);
    moveOpsRange(atomicOp.getResult(), newLoaded, std::next(opsToMoveStart),
                 std::next(opsToMoveEnd), rewriter);

    // The 'result' of the atomic_rmw op is the newly loaded value.
    rewriter.replaceOp(atomicOp, {newLoaded});

    return success();
  }

private:
  // Clones a segment of ops [start, end) and erases the original.
  void moveOpsRange(ValueRange oldResult, ValueRange newResult,
                    Block::iterator start, Block::iterator end,
                    ConversionPatternRewriter &rewriter) const {
    BlockAndValueMapping mapping;
    mapping.map(oldResult, newResult);
    SmallVector<Operation *, 2> opsToErase;
    for (auto it = start; it != end; ++it) {
      rewriter.clone(*it, mapping);
      opsToErase.push_back(&*it);
    }
    for (auto *it : opsToErase)
      rewriter.eraseOp(it);
  }
};

/// Returns the LLVM type of the global variable given the memref type `type`.
static Type convertGlobalMemrefTypeToLLVM(MemRefType type,
                                          LLVMTypeConverter &typeConverter) {
  // LLVM type for a global memref will be a multi-dimensional array. For
  // declarations or uninitialized global memrefs, we can potentially flatten
  // this to a 1D array. However, for memref.global's with an initial value,
  // we do not intend to flatten the ElementsAttribute when going from std ->
  // LLVM dialect, so the LLVM type needs to be a multi-dimensional array.
  Type elementType = typeConverter.convertType(type.getElementType());
  Type arrayTy = elementType;
  // Shape has the outermost dim at index 0, so we need to walk it backwards.
  for (int64_t dim : llvm::reverse(type.getShape()))
    arrayTy = LLVM::LLVMArrayType::get(arrayTy, dim);
  return arrayTy;
}

/// GlobalMemrefOp is lowered to an LLVM global variable.
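///
/// For illustration (roughly; details depend on the attributes involved):
///   memref.global "private" constant @cst : memref<2xf32> = dense<[1.0, 2.0]>
/// becomes
///   llvm.mlir.global private constant @cst(dense<[1.0, 2.0]> : tensor<2xf32>)
///       : !llvm.array<2 x f32>
/// Uninitialized (but defined) globals instead get an initializer region that
/// returns an undef value of the array type.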
struct GlobalMemrefOpLowering
    : public ConvertOpToLLVMPattern<memref::GlobalOp> {
  using ConvertOpToLLVMPattern<memref::GlobalOp>::ConvertOpToLLVMPattern;

  LogicalResult
  matchAndRewrite(memref::GlobalOp global, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    MemRefType type = global.type();
    if (!isConvertibleAndHasIdentityMaps(type))
      return failure();

    Type arrayTy = convertGlobalMemrefTypeToLLVM(type, *getTypeConverter());

    LLVM::Linkage linkage =
        global.isPublic() ? LLVM::Linkage::External : LLVM::Linkage::Private;

    Attribute initialValue = nullptr;
    if (!global.isExternal() && !global.isUninitialized()) {
      auto elementsAttr = global.initial_value()->cast<ElementsAttr>();
      initialValue = elementsAttr;

      // For scalar memrefs, the global variable created is of the element
      // type, so unpack the elements attribute to extract the value.
      if (type.getRank() == 0)
        initialValue = elementsAttr.getSplatValue<Attribute>();
    }

    uint64_t alignment = global.alignment().getValueOr(0);

    auto newGlobal = rewriter.replaceOpWithNewOp<LLVM::GlobalOp>(
        global, arrayTy, global.constant(), linkage, global.sym_name(),
        initialValue, alignment, type.getMemorySpaceAsInt());
    if (!global.isExternal() && global.isUninitialized()) {
      Block *blk = new Block();
      newGlobal.getInitializerRegion().push_back(blk);
      rewriter.setInsertionPointToStart(blk);
      Value undef[] = {
          rewriter.create<LLVM::UndefOp>(global.getLoc(), arrayTy)};
      rewriter.create<LLVM::ReturnOp>(global.getLoc(), undef);
    }
    return success();
  }
};

/// GetGlobalMemrefOp is lowered into a MemRef descriptor with the pointer to
/// the first element stashed into the descriptor. This builds on
/// `AllocLikeOpLowering` to reuse the MemRef descriptor construction.
struct GetGlobalMemrefOpLowering : public AllocLikeOpLLVMLowering {
  GetGlobalMemrefOpLowering(LLVMTypeConverter &converter)
      : AllocLikeOpLLVMLowering(memref::GetGlobalOp::getOperationName(),
                                converter) {}

  /// Buffer "allocation" for memref.get_global op is getting the address of
  /// the global variable referenced.
  std::tuple<Value, Value> allocateBuffer(ConversionPatternRewriter &rewriter,
                                          Location loc, Value sizeBytes,
                                          Operation *op) const override {
    auto getGlobalOp = cast<memref::GetGlobalOp>(op);
    MemRefType type = getGlobalOp.result().getType().cast<MemRefType>();
    unsigned memSpace = type.getMemorySpaceAsInt();

    Type arrayTy = convertGlobalMemrefTypeToLLVM(type, *getTypeConverter());
    auto addressOf = rewriter.create<LLVM::AddressOfOp>(
        loc, LLVM::LLVMPointerType::get(arrayTy, memSpace), getGlobalOp.name());

    // Get the address of the first element in the array by creating a GEP with
    // the address of the GV as the base, and (rank + 1) number of 0 indices.
    Type elementType = typeConverter->convertType(type.getElementType());
    Type elementPtrType = LLVM::LLVMPointerType::get(elementType, memSpace);

    SmallVector<Value> operands;
    operands.insert(operands.end(), type.getRank() + 1,
                    createIndexConstant(rewriter, loc, 0));
    auto gep =
        rewriter.create<LLVM::GEPOp>(loc, elementPtrType, addressOf, operands);

    // We do not expect the memref obtained using `memref.get_global` to be
    // ever deallocated. Set the allocated pointer to a known bad value to
    // help debug if that ever happens.
    auto intPtrType = getIntPtrType(memSpace);
    Value deadBeefConst =
        createIndexAttrConstant(rewriter, op->getLoc(), intPtrType, 0xdeadbeef);
    auto deadBeefPtr =
        rewriter.create<LLVM::IntToPtrOp>(loc, elementPtrType, deadBeefConst);

    // Both allocated and aligned pointers are the same. We could potentially
    // stash a nullptr for the allocated pointer since we do not expect any
    // dealloc.
    return std::make_tuple(deadBeefPtr, gep);
  }
};

// Load operation is lowered to obtaining a pointer to the indexed element
// and loading it.
struct LoadOpLowering : public LoadStoreOpLowering<memref::LoadOp> {
  using Base::Base;

  LogicalResult
  matchAndRewrite(memref::LoadOp loadOp, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    auto type = loadOp.getMemRefType();

    Value dataPtr = getStridedElementPtr(
        loadOp.getLoc(), type, adaptor.memref(), adaptor.indices(), rewriter);
    rewriter.replaceOpWithNewOp<LLVM::LoadOp>(loadOp, dataPtr);
    return success();
  }
};

// Store operation is lowered to obtaining a pointer to the indexed element,
// and storing the given value to it.
struct StoreOpLowering : public LoadStoreOpLowering<memref::StoreOp> {
  using Base::Base;

  LogicalResult
  matchAndRewrite(memref::StoreOp op, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    auto type = op.getMemRefType();

    Value dataPtr = getStridedElementPtr(op.getLoc(), type, adaptor.memref(),
                                         adaptor.indices(), rewriter);
    rewriter.replaceOpWithNewOp<LLVM::StoreOp>(op, adaptor.value(), dataPtr);
    return success();
  }
};

// The prefetch operation is lowered in a way similar to the load operation
// except that the llvm.prefetch operation is used for replacement.
struct PrefetchOpLowering : public LoadStoreOpLowering<memref::PrefetchOp> {
  using Base::Base;

  LogicalResult
  matchAndRewrite(memref::PrefetchOp prefetchOp, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    auto type = prefetchOp.getMemRefType();
    auto loc = prefetchOp.getLoc();

    Value dataPtr = getStridedElementPtr(loc, type, adaptor.memref(),
                                         adaptor.indices(), rewriter);

    // Replace with llvm.prefetch.
    auto llvmI32Type = typeConverter->convertType(rewriter.getIntegerType(32));
    auto isWrite = rewriter.create<LLVM::ConstantOp>(
        loc, llvmI32Type, rewriter.getI32IntegerAttr(prefetchOp.isWrite()));
    auto localityHint = rewriter.create<LLVM::ConstantOp>(
        loc, llvmI32Type,
        rewriter.getI32IntegerAttr(prefetchOp.localityHint()));
    auto isData = rewriter.create<LLVM::ConstantOp>(
        loc, llvmI32Type, rewriter.getI32IntegerAttr(prefetchOp.isDataCache()));

    rewriter.replaceOpWithNewOp<LLVM::Prefetch>(prefetchOp, dataPtr, isWrite,
                                                localityHint, isData);
    return success();
  }
};

struct RankOpLowering : public ConvertOpToLLVMPattern<memref::RankOp> {
  using ConvertOpToLLVMPattern<memref::RankOp>::ConvertOpToLLVMPattern;

  LogicalResult
  matchAndRewrite(memref::RankOp op, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    Location loc = op.getLoc();
    Type operandType = op.memref().getType();
    if (auto unrankedMemRefType = operandType.dyn_cast<UnrankedMemRefType>()) {
      UnrankedMemRefDescriptor desc(adaptor.memref());
      rewriter.replaceOp(op, {desc.rank(rewriter, loc)});
      return success();
    }
    if (auto rankedMemRefType = operandType.dyn_cast<MemRefType>()) {
      rewriter.replaceOp(
          op, {createIndexConstant(rewriter, loc, rankedMemRefType.getRank())});
      return success();
    }
    return failure();
  }
};

struct MemRefCastOpLowering : public ConvertOpToLLVMPattern<memref::CastOp> {
  using ConvertOpToLLVMPattern<memref::CastOp>::ConvertOpToLLVMPattern;

  LogicalResult match(memref::CastOp memRefCastOp) const override {
    Type srcType = memRefCastOp.getOperand().getType();
    Type dstType = memRefCastOp.getType();

    // memref::CastOp reduces to a bitcast in the ranked MemRef case and can be
    // used for type erasure. For now it must preserve the underlying element
    // type and requires source and result types to have the same rank.
    // Therefore, perform a sanity check that the underlying structs are the
    // same. Once op semantics are relaxed we can revisit.
    if (srcType.isa<MemRefType>() && dstType.isa<MemRefType>())
      return success(typeConverter->convertType(srcType) ==
                     typeConverter->convertType(dstType));

    // At least one of the operands is an unranked type.
    assert(srcType.isa<UnrankedMemRefType>() ||
           dstType.isa<UnrankedMemRefType>());

    // Unranked-to-unranked casts are disallowed.
    return !(srcType.isa<UnrankedMemRefType>() &&
             dstType.isa<UnrankedMemRefType>())
               ? success()
               : failure();
  }

  void rewrite(memref::CastOp memRefCastOp, OpAdaptor adaptor,
               ConversionPatternRewriter &rewriter) const override {
    auto srcType = memRefCastOp.getOperand().getType();
    auto dstType = memRefCastOp.getType();
    auto targetStructType = typeConverter->convertType(memRefCastOp.getType());
    auto loc = memRefCastOp.getLoc();

    // For the ranked-to-ranked case, just keep the original descriptor.
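    // (Such a cast only changes the static type; the match() above has
    // already checked that both types convert to the same descriptor struct,
    // so the operand descriptor can be reused as-is.)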
    if (srcType.isa<MemRefType>() && dstType.isa<MemRefType>())
      return rewriter.replaceOp(memRefCastOp, {adaptor.source()});

    if (srcType.isa<MemRefType>() && dstType.isa<UnrankedMemRefType>()) {
      // Casting ranked to unranked memref type:
      // Set the rank in the destination from the memref type.
      // Allocate space on the stack and copy the src memref descriptor.
      // Set the ptr in the destination to the stack space.
      auto srcMemRefType = srcType.cast<MemRefType>();
      int64_t rank = srcMemRefType.getRank();
      // ptr = AllocaOp sizeof(MemRefDescriptor)
      auto ptr = getTypeConverter()->promoteOneMemRefDescriptor(
          loc, adaptor.source(), rewriter);
      // voidptr = BitCastOp srcType* to void*
      auto voidPtr =
          rewriter.create<LLVM::BitcastOp>(loc, getVoidPtrType(), ptr)
              .getResult();
      // rank = ConstantOp srcRank
      auto rankVal = rewriter.create<LLVM::ConstantOp>(
          loc, getIndexType(), rewriter.getIndexAttr(rank));
      // undef = UndefOp
      UnrankedMemRefDescriptor memRefDesc =
          UnrankedMemRefDescriptor::undef(rewriter, loc, targetStructType);
      // d1 = InsertValueOp undef, rank, 0
      memRefDesc.setRank(rewriter, loc, rankVal);
      // d2 = InsertValueOp d1, voidptr, 1
      memRefDesc.setMemRefDescPtr(rewriter, loc, voidPtr);
      rewriter.replaceOp(memRefCastOp, (Value)memRefDesc);

    } else if (srcType.isa<UnrankedMemRefType>() && dstType.isa<MemRefType>()) {
      // Casting from unranked type to ranked.
      // The operation is assumed to be doing a correct cast. If the
      // destination type does not match the runtime type of the unranked
      // operand, the behavior is undefined.
      UnrankedMemRefDescriptor memRefDesc(adaptor.source());
      // ptr = ExtractValueOp src, 1
      auto ptr = memRefDesc.memRefDescPtr(rewriter, loc);
      // castPtr = BitCastOp i8* to structTy*
      auto castPtr =
          rewriter
              .create<LLVM::BitcastOp>(
                  loc, LLVM::LLVMPointerType::get(targetStructType), ptr)
              .getResult();
      // struct = LoadOp castPtr
      auto loadOp = rewriter.create<LLVM::LoadOp>(loc, castPtr);
      rewriter.replaceOp(memRefCastOp, loadOp.getResult());
    } else {
      llvm_unreachable("Unsupported unranked memref to unranked memref cast");
    }
  }
};

/// Pattern to lower a `memref.copy` to llvm.
///
/// For memrefs with identity layouts, the copy is lowered to the llvm
/// `memcpy` intrinsic. For non-identity layouts, the copy is lowered to a call
/// to the generic `MemrefCopyFn`.
struct MemRefCopyOpLowering : public ConvertOpToLLVMPattern<memref::CopyOp> {
  using ConvertOpToLLVMPattern<memref::CopyOp>::ConvertOpToLLVMPattern;

  LogicalResult
  lowerToMemCopyIntrinsic(memref::CopyOp op, OpAdaptor adaptor,
                          ConversionPatternRewriter &rewriter) const {
    auto loc = op.getLoc();
    auto srcType = op.source().getType().dyn_cast<MemRefType>();

    MemRefDescriptor srcDesc(adaptor.source());

    // Compute number of elements.
    Value numElements = rewriter.create<LLVM::ConstantOp>(
        loc, getIndexType(), rewriter.getIndexAttr(1));
    for (int pos = 0; pos < srcType.getRank(); ++pos) {
      auto size = srcDesc.size(rewriter, loc, pos);
      numElements = rewriter.create<LLVM::MulOp>(loc, numElements, size);
    }

    // Get element size.
    auto sizeInBytes = getSizeInBytes(loc, srcType.getElementType(), rewriter);
    // Compute total size in bytes.
    Value totalSize =
        rewriter.create<LLVM::MulOp>(loc, numElements, sizeInBytes);

    Value srcBasePtr = srcDesc.alignedPtr(rewriter, loc);
    Value srcOffset = srcDesc.offset(rewriter, loc);
    Value srcPtr = rewriter.create<LLVM::GEPOp>(loc, srcBasePtr.getType(),
                                                srcBasePtr, srcOffset);
    MemRefDescriptor targetDesc(adaptor.target());
    Value targetBasePtr = targetDesc.alignedPtr(rewriter, loc);
    Value targetOffset = targetDesc.offset(rewriter, loc);
    Value targetPtr = rewriter.create<LLVM::GEPOp>(loc, targetBasePtr.getType(),
                                                   targetBasePtr, targetOffset);
    Value isVolatile = rewriter.create<LLVM::ConstantOp>(
        loc, typeConverter->convertType(rewriter.getI1Type()),
        rewriter.getBoolAttr(false));
    rewriter.create<LLVM::MemcpyOp>(loc, targetPtr, srcPtr, totalSize,
                                    isVolatile);
    rewriter.eraseOp(op);

    return success();
  }

  LogicalResult
  lowerToMemCopyFunctionCall(memref::CopyOp op, OpAdaptor adaptor,
                             ConversionPatternRewriter &rewriter) const {
    auto loc = op.getLoc();
    auto srcType = op.source().getType().cast<BaseMemRefType>();
    auto targetType = op.target().getType().cast<BaseMemRefType>();

    // First make sure we have an unranked memref descriptor representation.
    auto makeUnranked = [&, this](Value ranked, BaseMemRefType type) {
      auto rank = rewriter.create<LLVM::ConstantOp>(
          loc, getIndexType(), rewriter.getIndexAttr(type.getRank()));
      auto *typeConverter = getTypeConverter();
      auto ptr =
          typeConverter->promoteOneMemRefDescriptor(loc, ranked, rewriter);
      auto voidPtr =
          rewriter.create<LLVM::BitcastOp>(loc, getVoidPtrType(), ptr)
              .getResult();
      auto unrankedType =
          UnrankedMemRefType::get(type.getElementType(), type.getMemorySpace());
      return UnrankedMemRefDescriptor::pack(rewriter, loc, *typeConverter,
                                            unrankedType,
                                            ValueRange{rank, voidPtr});
    };

    Value unrankedSource = srcType.hasRank()
                               ? makeUnranked(adaptor.source(), srcType)
                               : adaptor.source();
    Value unrankedTarget = targetType.hasRank()
                               ? makeUnranked(adaptor.target(), targetType)
                               : adaptor.target();

    // Now promote the unranked descriptors to the stack.
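    // (Each descriptor value is stored through an llvm.alloca so that it can
    // be passed by pointer to the generic copy function below.)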
    auto one = rewriter.create<LLVM::ConstantOp>(loc, getIndexType(),
                                                 rewriter.getIndexAttr(1));
    auto promote = [&](Value desc) {
      auto ptrType = LLVM::LLVMPointerType::get(desc.getType());
      auto allocated =
          rewriter.create<LLVM::AllocaOp>(loc, ptrType, ValueRange{one});
      rewriter.create<LLVM::StoreOp>(loc, desc, allocated);
      return allocated;
    };

    auto sourcePtr = promote(unrankedSource);
    auto targetPtr = promote(unrankedTarget);

    unsigned typeSize =
        mlir::DataLayout::closest(op).getTypeSize(srcType.getElementType());
    auto elemSize = rewriter.create<LLVM::ConstantOp>(
        loc, getIndexType(), rewriter.getIndexAttr(typeSize));
    auto copyFn = LLVM::lookupOrCreateMemRefCopyFn(
        op->getParentOfType<ModuleOp>(), getIndexType(), sourcePtr.getType());
    rewriter.create<LLVM::CallOp>(loc, copyFn,
                                  ValueRange{elemSize, sourcePtr, targetPtr});
    rewriter.eraseOp(op);

    return success();
  }

  LogicalResult
  matchAndRewrite(memref::CopyOp op, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    auto srcType = op.source().getType().cast<BaseMemRefType>();
    auto targetType = op.target().getType().cast<BaseMemRefType>();

    auto isContiguousMemrefType = [](BaseMemRefType type) {
      auto memrefType = type.dyn_cast<mlir::MemRefType>();
      // We can use memcpy for memrefs if they have an identity layout or are
      // contiguous with an arbitrary offset. Ignore empty memrefs, which are a
      // special case handled by memrefCopy.
      return memrefType &&
             (memrefType.getLayout().isIdentity() ||
              (memrefType.hasStaticShape() && memrefType.getNumElements() > 0 &&
               isStaticShapeAndContiguousRowMajor(memrefType)));
    };

    if (isContiguousMemrefType(srcType) && isContiguousMemrefType(targetType))
      return lowerToMemCopyIntrinsic(op, adaptor, rewriter);

    return lowerToMemCopyFunctionCall(op, adaptor, rewriter);
  }
};

/// Extracts allocated, aligned pointers and offset from a ranked or unranked
/// memref type. In the unranked case, the fields are extracted from the
/// underlying ranked descriptor.
static void extractPointersAndOffset(Location loc,
                                     ConversionPatternRewriter &rewriter,
                                     LLVMTypeConverter &typeConverter,
                                     Value originalOperand,
                                     Value convertedOperand,
                                     Value *allocatedPtr, Value *alignedPtr,
                                     Value *offset = nullptr) {
  Type operandType = originalOperand.getType();
  if (operandType.isa<MemRefType>()) {
    MemRefDescriptor desc(convertedOperand);
    *allocatedPtr = desc.allocatedPtr(rewriter, loc);
    *alignedPtr = desc.alignedPtr(rewriter, loc);
    if (offset != nullptr)
      *offset = desc.offset(rewriter, loc);
    return;
  }

  unsigned memorySpace =
      operandType.cast<UnrankedMemRefType>().getMemorySpaceAsInt();
  Type elementType = operandType.cast<UnrankedMemRefType>().getElementType();
  Type llvmElementType = typeConverter.convertType(elementType);
  Type elementPtrPtrType = LLVM::LLVMPointerType::get(
      LLVM::LLVMPointerType::get(llvmElementType, memorySpace));

  // Extract pointer to the underlying ranked memref descriptor and cast it to
  // ElemType**.
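  // (An unranked descriptor is a (rank, opaque pointer) pair; the pointed-to
  // ranked descriptor starts with the allocated pointer, the aligned pointer
  // and the offset, so an ElemType** view suffices to read those leading
  // fields.)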
  UnrankedMemRefDescriptor unrankedDesc(convertedOperand);
  Value underlyingDescPtr = unrankedDesc.memRefDescPtr(rewriter, loc);

  *allocatedPtr = UnrankedMemRefDescriptor::allocatedPtr(
      rewriter, loc, underlyingDescPtr, elementPtrPtrType);
  *alignedPtr = UnrankedMemRefDescriptor::alignedPtr(
      rewriter, loc, typeConverter, underlyingDescPtr, elementPtrPtrType);
  if (offset != nullptr) {
    *offset = UnrankedMemRefDescriptor::offset(
        rewriter, loc, typeConverter, underlyingDescPtr, elementPtrPtrType);
  }
}

struct MemRefReinterpretCastOpLowering
    : public ConvertOpToLLVMPattern<memref::ReinterpretCastOp> {
  using ConvertOpToLLVMPattern<
      memref::ReinterpretCastOp>::ConvertOpToLLVMPattern;

  LogicalResult
  matchAndRewrite(memref::ReinterpretCastOp castOp, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    Type srcType = castOp.source().getType();

    Value descriptor;
    if (failed(convertSourceMemRefToDescriptor(rewriter, srcType, castOp,
                                               adaptor, &descriptor)))
      return failure();
    rewriter.replaceOp(castOp, {descriptor});
    return success();
  }

private:
  LogicalResult convertSourceMemRefToDescriptor(
      ConversionPatternRewriter &rewriter, Type srcType,
      memref::ReinterpretCastOp castOp,
      memref::ReinterpretCastOp::Adaptor adaptor, Value *descriptor) const {
    MemRefType targetMemRefType =
        castOp.getResult().getType().cast<MemRefType>();
    auto llvmTargetDescriptorTy = typeConverter->convertType(targetMemRefType)
                                      .dyn_cast_or_null<LLVM::LLVMStructType>();
    if (!llvmTargetDescriptorTy)
      return failure();

    // Create descriptor.
    Location loc = castOp.getLoc();
    auto desc = MemRefDescriptor::undef(rewriter, loc, llvmTargetDescriptorTy);

    // Set allocated and aligned pointers.
    Value allocatedPtr, alignedPtr;
    extractPointersAndOffset(loc, rewriter, *getTypeConverter(),
                             castOp.source(), adaptor.source(), &allocatedPtr,
                             &alignedPtr);
    desc.setAllocatedPtr(rewriter, loc, allocatedPtr);
    desc.setAlignedPtr(rewriter, loc, alignedPtr);

    // Set offset.
    if (castOp.isDynamicOffset(0))
      desc.setOffset(rewriter, loc, adaptor.offsets()[0]);
    else
      desc.setConstantOffset(rewriter, loc, castOp.getStaticOffset(0));

    // Set sizes and strides.
    unsigned dynSizeId = 0;
    unsigned dynStrideId = 0;
    for (unsigned i = 0, e = targetMemRefType.getRank(); i < e; ++i) {
      if (castOp.isDynamicSize(i))
        desc.setSize(rewriter, loc, i, adaptor.sizes()[dynSizeId++]);
      else
        desc.setConstantSize(rewriter, loc, i, castOp.getStaticSize(i));

      if (castOp.isDynamicStride(i))
        desc.setStride(rewriter, loc, i, adaptor.strides()[dynStrideId++]);
      else
        desc.setConstantStride(rewriter, loc, i, castOp.getStaticStride(i));
    }
    *descriptor = desc;
    return success();
  }
};

struct MemRefReshapeOpLowering
    : public ConvertOpToLLVMPattern<memref::ReshapeOp> {
  using ConvertOpToLLVMPattern<memref::ReshapeOp>::ConvertOpToLLVMPattern;

  LogicalResult
  matchAndRewrite(memref::ReshapeOp reshapeOp, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    Type srcType = reshapeOp.source().getType();

    Value descriptor;
    if (failed(convertSourceMemRefToDescriptor(rewriter, srcType, reshapeOp,
                                               adaptor, &descriptor)))
      return failure();
    rewriter.replaceOp(reshapeOp, {descriptor});
    return success();
  }

private:
  LogicalResult
  convertSourceMemRefToDescriptor(ConversionPatternRewriter &rewriter,
                                  Type srcType, memref::ReshapeOp reshapeOp,
                                  memref::ReshapeOp::Adaptor adaptor,
                                  Value *descriptor) const {
    // Conversion for statically-known shape args is performed via
    // `memref_reinterpret_cast`.
    auto shapeMemRefType = reshapeOp.shape().getType().cast<MemRefType>();
    if (shapeMemRefType.hasStaticShape())
      return failure();

    // The shape is a rank-1 tensor with unknown length.
    Location loc = reshapeOp.getLoc();
    MemRefDescriptor shapeDesc(adaptor.shape());
    Value resultRank = shapeDesc.size(rewriter, loc, 0);

    // Extract address space and element type.
    auto targetType =
        reshapeOp.getResult().getType().cast<UnrankedMemRefType>();
    unsigned addressSpace = targetType.getMemorySpaceAsInt();
    Type elementType = targetType.getElementType();

    // Create the unranked memref descriptor that holds the ranked one. The
    // inner descriptor is allocated on stack.
    auto targetDesc = UnrankedMemRefDescriptor::undef(
        rewriter, loc, typeConverter->convertType(targetType));
    targetDesc.setRank(rewriter, loc, resultRank);
    SmallVector<Value, 4> sizes;
    UnrankedMemRefDescriptor::computeSizes(rewriter, loc, *getTypeConverter(),
                                           targetDesc, sizes);
    Value underlyingDescPtr = rewriter.create<LLVM::AllocaOp>(
        loc, getVoidPtrType(), sizes.front(), llvm::None);
    targetDesc.setMemRefDescPtr(rewriter, loc, underlyingDescPtr);

    // Extract pointers and offset from the source memref.
    Value allocatedPtr, alignedPtr, offset;
    extractPointersAndOffset(loc, rewriter, *getTypeConverter(),
                             reshapeOp.source(), adaptor.source(),
                             &allocatedPtr, &alignedPtr, &offset);

    // Set pointers and offset.
    Type llvmElementType = typeConverter->convertType(elementType);
    auto elementPtrPtrType = LLVM::LLVMPointerType::get(
        LLVM::LLVMPointerType::get(llvmElementType, addressSpace));
    UnrankedMemRefDescriptor::setAllocatedPtr(rewriter, loc, underlyingDescPtr,
                                              elementPtrPtrType, allocatedPtr);
    UnrankedMemRefDescriptor::setAlignedPtr(rewriter, loc, *getTypeConverter(),
                                            underlyingDescPtr,
                                            elementPtrPtrType, alignedPtr);
    UnrankedMemRefDescriptor::setOffset(rewriter, loc, *getTypeConverter(),
                                        underlyingDescPtr, elementPtrPtrType,
                                        offset);

    // Use the offset pointer as base for further addressing. Copy over the new
    // shape and compute strides. For this, we create a loop from rank - 1 to 0.
    Value targetSizesBase = UnrankedMemRefDescriptor::sizeBasePtr(
        rewriter, loc, *getTypeConverter(), underlyingDescPtr,
        elementPtrPtrType);
    Value targetStridesBase = UnrankedMemRefDescriptor::strideBasePtr(
        rewriter, loc, *getTypeConverter(), targetSizesBase, resultRank);
    Value shapeOperandPtr = shapeDesc.alignedPtr(rewriter, loc);
    Value oneIndex = createIndexConstant(rewriter, loc, 1);
    Value resultRankMinusOne =
        rewriter.create<LLVM::SubOp>(loc, resultRank, oneIndex);

    Block *initBlock = rewriter.getInsertionBlock();
    Type indexType = getTypeConverter()->getIndexType();
    Block::iterator remainingOpsIt = std::next(rewriter.getInsertionPoint());

    Block *condBlock = rewriter.createBlock(initBlock->getParent(), {},
                                            {indexType, indexType}, {loc, loc});

    // Move the remaining initBlock ops to condBlock.
    Block *remainingBlock = rewriter.splitBlock(initBlock, remainingOpsIt);
    rewriter.mergeBlocks(remainingBlock, condBlock, ValueRange());

    rewriter.setInsertionPointToEnd(initBlock);
    rewriter.create<LLVM::BrOp>(loc, ValueRange({resultRankMinusOne, oneIndex}),
                                condBlock);
    rewriter.setInsertionPointToStart(condBlock);
    Value indexArg = condBlock->getArgument(0);
    Value strideArg = condBlock->getArgument(1);

    Value zeroIndex = createIndexConstant(rewriter, loc, 0);
    Value pred = rewriter.create<LLVM::ICmpOp>(
        loc, IntegerType::get(rewriter.getContext(), 1),
        LLVM::ICmpPredicate::sge, indexArg, zeroIndex);

    Block *bodyBlock =
        rewriter.splitBlock(condBlock, rewriter.getInsertionPoint());
    rewriter.setInsertionPointToStart(bodyBlock);

    // Copy size from shape to descriptor.
    Type llvmIndexPtrType = LLVM::LLVMPointerType::get(indexType);
    Value sizeLoadGep = rewriter.create<LLVM::GEPOp>(
        loc, llvmIndexPtrType, shapeOperandPtr, ValueRange{indexArg});
    Value size = rewriter.create<LLVM::LoadOp>(loc, sizeLoadGep);
    UnrankedMemRefDescriptor::setSize(rewriter, loc, *getTypeConverter(),
                                      targetSizesBase, indexArg, size);

    // Write stride value and compute next one.
    UnrankedMemRefDescriptor::setStride(rewriter, loc, *getTypeConverter(),
                                        targetStridesBase, indexArg, strideArg);
    Value nextStride = rewriter.create<LLVM::MulOp>(loc, strideArg, size);

    // Decrement loop counter and branch back.
    Value decrement = rewriter.create<LLVM::SubOp>(loc, indexArg, oneIndex);
    rewriter.create<LLVM::BrOp>(loc, ValueRange({decrement, nextStride}),
                                condBlock);

    Block *remainder =
        rewriter.splitBlock(bodyBlock, rewriter.getInsertionPoint());

    // Hook up the cond exit to the remainder.
    rewriter.setInsertionPointToEnd(condBlock);
    rewriter.create<LLVM::CondBrOp>(loc, pred, bodyBlock, llvm::None, remainder,
                                    llvm::None);

    // Reset position to beginning of new remainder block.
    rewriter.setInsertionPointToStart(remainder);

    *descriptor = targetDesc;
    return success();
  }
};

/// Helper function to convert a vector of `OpFoldResult`s into a vector of
/// `Value`s.
static SmallVector<Value> getAsValues(OpBuilder &b, Location loc,
                                      Type &llvmIndexType,
                                      ArrayRef<OpFoldResult> valueOrAttrVec) {
  return llvm::to_vector<4>(
      llvm::map_range(valueOrAttrVec, [&](OpFoldResult value) -> Value {
        if (auto attr = value.dyn_cast<Attribute>())
          return b.create<LLVM::ConstantOp>(loc, llvmIndexType, attr);
        return value.get<Value>();
      }));
}

/// Compute a map that for a given dimension of the expanded type gives the
/// dimension in the collapsed type it maps to. Essentially it is the inverse
/// of the `reassociation` maps.
static DenseMap<int64_t, int64_t>
getExpandedDimToCollapsedDimMap(ArrayRef<ReassociationIndices> reassociation) {
  llvm::DenseMap<int64_t, int64_t> expandedDimToCollapsedDim;
  for (auto &en : enumerate(reassociation)) {
    for (auto dim : en.value())
      expandedDimToCollapsedDim[dim] = en.index();
  }
  return expandedDimToCollapsedDim;
}

static OpFoldResult
getExpandedOutputDimSize(OpBuilder &b, Location loc, Type &llvmIndexType,
                         int64_t outDimIndex, ArrayRef<int64_t> outStaticShape,
                         MemRefDescriptor &inDesc,
                         ArrayRef<int64_t> inStaticShape,
                         ArrayRef<ReassociationIndices> reassocation,
                         DenseMap<int64_t, int64_t> &outDimToInDimMap) {
  int64_t outDimSize = outStaticShape[outDimIndex];
  if (!ShapedType::isDynamic(outDimSize))
    return b.getIndexAttr(outDimSize);

  // Calculate the multiplication of all the output dim sizes except the
  // current dim.
  int64_t inDimIndex = outDimToInDimMap[outDimIndex];
  int64_t otherDimSizesMul = 1;
  for (auto otherDimIndex : reassocation[inDimIndex]) {
    if (otherDimIndex == static_cast<unsigned>(outDimIndex))
      continue;
    int64_t otherDimSize = outStaticShape[otherDimIndex];
    assert(!ShapedType::isDynamic(otherDimSize) &&
           "single dimension cannot be expanded into multiple dynamic "
           "dimensions");
    otherDimSizesMul *= otherDimSize;
  }

  // outDimSize = inDimSize / otherOutDimSizesMul
  int64_t inDimSize = inStaticShape[inDimIndex];
  Value inDimSizeDynamic =
      ShapedType::isDynamic(inDimSize)
          ? inDesc.size(b, loc, inDimIndex)
          : b.create<LLVM::ConstantOp>(loc, llvmIndexType,
                                       b.getIndexAttr(inDimSize));
  Value outDimSizeDynamic = b.create<LLVM::SDivOp>(
      loc, inDimSizeDynamic,
      b.create<LLVM::ConstantOp>(loc, llvmIndexType,
                                 b.getIndexAttr(otherDimSizesMul)));
  return outDimSizeDynamic;
}

static OpFoldResult getCollapsedOutputDimSize(
    OpBuilder &b, Location loc, Type &llvmIndexType, int64_t outDimIndex,
    int64_t outDimSize, ArrayRef<int64_t> inStaticShape,
    MemRefDescriptor &inDesc, ArrayRef<ReassociationIndices> reassocation) {
  if (!ShapedType::isDynamic(outDimSize))
    return b.getIndexAttr(outDimSize);

  Value c1 = b.create<LLVM::ConstantOp>(loc, llvmIndexType, b.getIndexAttr(1));
  Value outDimSizeDynamic = c1;
  for (auto inDimIndex : reassocation[outDimIndex]) {
    int64_t inDimSize = inStaticShape[inDimIndex];
    Value inDimSizeDynamic =
        ShapedType::isDynamic(inDimSize)
            ? inDesc.size(b, loc, inDimIndex)
            : b.create<LLVM::ConstantOp>(loc, llvmIndexType,
                                         b.getIndexAttr(inDimSize));
    outDimSizeDynamic =
        b.create<LLVM::MulOp>(loc, outDimSizeDynamic, inDimSizeDynamic);
  }
  return outDimSizeDynamic;
}

static SmallVector<OpFoldResult, 4>
getCollapsedOutputShape(OpBuilder &b, Location loc, Type &llvmIndexType,
                        ArrayRef<ReassociationIndices> reassocation,
                        ArrayRef<int64_t> inStaticShape,
                        MemRefDescriptor &inDesc,
                        ArrayRef<int64_t> outStaticShape) {
  return llvm::to_vector<4>(llvm::map_range(
      llvm::seq<int64_t>(0, outStaticShape.size()), [&](int64_t outDimIndex) {
        return getCollapsedOutputDimSize(b, loc, llvmIndexType, outDimIndex,
                                         outStaticShape[outDimIndex],
                                         inStaticShape, inDesc, reassocation);
      }));
}

static SmallVector<OpFoldResult, 4>
getExpandedOutputShape(OpBuilder &b, Location loc, Type &llvmIndexType,
                       ArrayRef<ReassociationIndices> reassocation,
                       ArrayRef<int64_t> inStaticShape,
                       MemRefDescriptor &inDesc,
                       ArrayRef<int64_t> outStaticShape) {
  DenseMap<int64_t, int64_t> outDimToInDimMap =
      getExpandedDimToCollapsedDimMap(reassocation);
  return llvm::to_vector<4>(llvm::map_range(
      llvm::seq<int64_t>(0, outStaticShape.size()), [&](int64_t outDimIndex) {
        return getExpandedOutputDimSize(b, loc, llvmIndexType, outDimIndex,
                                        outStaticShape, inDesc, inStaticShape,
                                        reassocation, outDimToInDimMap);
      }));
}

static SmallVector<Value>
getDynamicOutputShape(OpBuilder &b, Location loc, Type &llvmIndexType,
                      ArrayRef<ReassociationIndices> reassocation,
                      ArrayRef<int64_t> inStaticShape, MemRefDescriptor &inDesc,
                      ArrayRef<int64_t> outStaticShape) {
  return outStaticShape.size() < inStaticShape.size()
             ? getAsValues(b, loc, llvmIndexType,
                           getCollapsedOutputShape(b, loc, llvmIndexType,
                                                   reassocation, inStaticShape,
                                                   inDesc, outStaticShape))
             : getAsValues(b, loc, llvmIndexType,
                           getExpandedOutputShape(b, loc, llvmIndexType,
                                                  reassocation, inStaticShape,
                                                  inDesc, outStaticShape));
}

// ReshapeOp creates a new view descriptor of the proper rank.
// For now, the only conversion supported is for target MemRefs with static
// sizes and strides.
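//
// For illustration (invented names; identity layouts):
//   %r = memref.collapse_shape %arg [[0, 1]]
//       : memref<4x5xf32> into memref<20xf32>
// produces a descriptor that copies the allocated/aligned pointers and offset
// from %arg's descriptor and sets the single size to 20 and the stride to 1.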
template <typename ReshapeOp>
class ReassociatingReshapeOpConversion
    : public ConvertOpToLLVMPattern<ReshapeOp> {
public:
  using ConvertOpToLLVMPattern<ReshapeOp>::ConvertOpToLLVMPattern;
  using ReshapeOpAdaptor = typename ReshapeOp::Adaptor;

  LogicalResult
  matchAndRewrite(ReshapeOp reshapeOp, typename ReshapeOp::Adaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    MemRefType dstType = reshapeOp.getResultType();
    MemRefType srcType = reshapeOp.getSrcType();

    // The condition on the layouts can be ignored when all shapes are static.
    if (!srcType.hasStaticShape() || !dstType.hasStaticShape()) {
      if (!srcType.getLayout().isIdentity() ||
          !dstType.getLayout().isIdentity()) {
        return rewriter.notifyMatchFailure(
            reshapeOp, "only empty layout map is supported");
      }
    }

    int64_t offset;
    SmallVector<int64_t, 4> strides;
    if (failed(getStridesAndOffset(dstType, strides, offset))) {
      return rewriter.notifyMatchFailure(
          reshapeOp, "failed to get stride and offset exprs");
    }

    MemRefDescriptor srcDesc(adaptor.src());
    Location loc = reshapeOp->getLoc();
    auto dstDesc = MemRefDescriptor::undef(
        rewriter, loc, this->typeConverter->convertType(dstType));
    dstDesc.setAllocatedPtr(rewriter, loc, srcDesc.allocatedPtr(rewriter, loc));
    dstDesc.setAlignedPtr(rewriter, loc, srcDesc.alignedPtr(rewriter, loc));
    dstDesc.setOffset(rewriter, loc, srcDesc.offset(rewriter, loc));

    ArrayRef<int64_t> srcStaticShape = srcType.getShape();
    ArrayRef<int64_t> dstStaticShape = dstType.getShape();
    Type llvmIndexType =
        this->typeConverter->convertType(rewriter.getIndexType());
    SmallVector<Value> dstShape = getDynamicOutputShape(
        rewriter, loc, llvmIndexType, reshapeOp.getReassociationIndices(),
        srcStaticShape, srcDesc, dstStaticShape);
    for (auto &en : llvm::enumerate(dstShape))
      dstDesc.setSize(rewriter, loc, en.index(), en.value());

    auto isStaticStride = [](int64_t stride) {
      return !ShapedType::isDynamicStrideOrOffset(stride);
    };
    if (llvm::all_of(strides, isStaticStride)) {
      for (auto &en : llvm::enumerate(strides))
        dstDesc.setConstantStride(rewriter, loc, en.index(), en.value());
    } else {
      Value c1 = rewriter.create<LLVM::ConstantOp>(loc, llvmIndexType,
                                                   rewriter.getIndexAttr(1));
      Value stride = c1;
      for (auto dimIndex :
           llvm::reverse(llvm::seq<int64_t>(0, dstShape.size()))) {
        dstDesc.setStride(rewriter, loc, dimIndex, stride);
        stride = rewriter.create<LLVM::MulOp>(loc, dstShape[dimIndex], stride);
      }
    }
    rewriter.replaceOp(reshapeOp, {dstDesc});
    return success();
  }
};

/// Conversion pattern that transforms a subview op into:
///   1. An `llvm.mlir.undef` operation to create a memref descriptor.
///   2. Updates to the descriptor to introduce the data ptr, offset, size
///      and stride.
/// The subview op is replaced by the descriptor.
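///
/// For illustration (invented names; the result type is elided), lowering
///   %v = memref.subview %src[%i, 0] [4, 4] [1, 1] : memref<8x8xf32> to ...
/// reuses both pointers from %src's descriptor, sets the offset to
/// %i * 8 + 0 * 1, the sizes to (4, 4) and the strides to (8, 1).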
/// Conversion pattern that transforms a subview op into:
///   1. An `llvm.mlir.undef` operation to create a memref descriptor
///   2. Updates to the descriptor to introduce the data ptr, offset, size
///      and stride.
/// The subview op is replaced by the descriptor.
struct SubViewOpLowering : public ConvertOpToLLVMPattern<memref::SubViewOp> {
  using ConvertOpToLLVMPattern<memref::SubViewOp>::ConvertOpToLLVMPattern;

  LogicalResult
  matchAndRewrite(memref::SubViewOp subViewOp, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    auto loc = subViewOp.getLoc();

    auto sourceMemRefType = subViewOp.source().getType().cast<MemRefType>();
    auto sourceElementTy =
        typeConverter->convertType(sourceMemRefType.getElementType());

    auto viewMemRefType = subViewOp.getType();
    auto inferredType = memref::SubViewOp::inferResultType(
                            subViewOp.getSourceType(),
                            extractFromI64ArrayAttr(subViewOp.static_offsets()),
                            extractFromI64ArrayAttr(subViewOp.static_sizes()),
                            extractFromI64ArrayAttr(subViewOp.static_strides()))
                            .cast<MemRefType>();
    auto targetElementTy =
        typeConverter->convertType(viewMemRefType.getElementType());
    auto targetDescTy = typeConverter->convertType(viewMemRefType);
    if (!sourceElementTy || !targetDescTy || !targetElementTy ||
        !LLVM::isCompatibleType(sourceElementTy) ||
        !LLVM::isCompatibleType(targetElementTy) ||
        !LLVM::isCompatibleType(targetDescTy))
      return failure();

    // Extract the offset and strides from the type.
    int64_t offset;
    SmallVector<int64_t, 4> strides;
    auto successStrides = getStridesAndOffset(inferredType, strides, offset);
    if (failed(successStrides))
      return failure();

    // Create the descriptor.
    if (!LLVM::isCompatibleType(adaptor.getOperands().front().getType()))
      return failure();
    MemRefDescriptor sourceMemRef(adaptor.getOperands().front());
    auto targetMemRef = MemRefDescriptor::undef(rewriter, loc, targetDescTy);

    // Copy the buffer pointer from the old descriptor to the new one.
    Value extracted = sourceMemRef.allocatedPtr(rewriter, loc);
    Value bitcastPtr = rewriter.create<LLVM::BitcastOp>(
        loc,
        LLVM::LLVMPointerType::get(targetElementTy,
                                   viewMemRefType.getMemorySpaceAsInt()),
        extracted);
    targetMemRef.setAllocatedPtr(rewriter, loc, bitcastPtr);

    // Copy the aligned pointer from the old descriptor to the new one.
    extracted = sourceMemRef.alignedPtr(rewriter, loc);
    bitcastPtr = rewriter.create<LLVM::BitcastOp>(
        loc,
        LLVM::LLVMPointerType::get(targetElementTy,
                                   viewMemRefType.getMemorySpaceAsInt()),
        extracted);
    targetMemRef.setAlignedPtr(rewriter, loc, bitcastPtr);

    size_t inferredShapeRank = inferredType.getRank();
    size_t resultShapeRank = viewMemRefType.getRank();

    // Extract strides needed to compute offset.
    SmallVector<Value, 4> strideValues;
    strideValues.reserve(inferredShapeRank);
    for (unsigned i = 0; i < inferredShapeRank; ++i)
      strideValues.push_back(sourceMemRef.stride(rewriter, loc, i));

    // Offset.
    auto llvmIndexType = typeConverter->convertType(rewriter.getIndexType());
    if (!ShapedType::isDynamicStrideOrOffset(offset)) {
      targetMemRef.setConstantOffset(rewriter, loc, offset);
    } else {
      Value baseOffset = sourceMemRef.offset(rewriter, loc);
      // `inferredShapeRank` may be larger than the number of offset operands
      // because of trailing semantics. In this case, the offset is guaranteed
      // to be interpreted as 0 and we can just skip the extra dimensions.
      for (unsigned i = 0, e = std::min(inferredShapeRank,
                                        subViewOp.getMixedOffsets().size());
           i < e; ++i) {
        Value offset =
            // TODO: need OpFoldResult ODS adaptor to clean this up.
            subViewOp.isDynamicOffset(i)
                ? adaptor.getOperands()[subViewOp.getIndexOfDynamicOffset(i)]
                : rewriter.create<LLVM::ConstantOp>(
                      loc, llvmIndexType,
                      rewriter.getI64IntegerAttr(subViewOp.getStaticOffset(i)));
        Value mul = rewriter.create<LLVM::MulOp>(loc, offset, strideValues[i]);
        baseOffset = rewriter.create<LLVM::AddOp>(loc, baseOffset, mul);
      }
      targetMemRef.setOffset(rewriter, loc, baseOffset);
    }

    // Update sizes and strides.
    SmallVector<OpFoldResult> mixedSizes = subViewOp.getMixedSizes();
    SmallVector<OpFoldResult> mixedStrides = subViewOp.getMixedStrides();
    assert(mixedSizes.size() == mixedStrides.size() &&
           "expected sizes and strides of equal length");
    llvm::SmallBitVector unusedDims = subViewOp.getDroppedDims();
    for (int i = inferredShapeRank - 1, j = resultShapeRank - 1;
         i >= 0 && j >= 0; --i) {
      if (unusedDims.test(i))
        continue;

      // `i` may exceed the number of entries in `subViewOp.getMixedSizes()`
      // because of trailing semantics. In this case, the size is guaranteed to
      // be the full source dimension and the stride is 1.
      Value size, stride;
      if (static_cast<unsigned>(i) >= mixedSizes.size()) {
        // If the static size is available, use it directly. This is similar to
        // the folding of dim(constant-op) but removes the need for dim to be
        // aware of LLVM constants and for this pass to be aware of std
        // constants.
        int64_t staticSize =
            subViewOp.source().getType().cast<MemRefType>().getShape()[i];
        if (staticSize != ShapedType::kDynamicSize) {
          size = rewriter.create<LLVM::ConstantOp>(
              loc, llvmIndexType, rewriter.getI64IntegerAttr(staticSize));
        } else {
          Value pos = rewriter.create<LLVM::ConstantOp>(
              loc, llvmIndexType, rewriter.getI64IntegerAttr(i));
          Value dim =
              rewriter.create<memref::DimOp>(loc, subViewOp.source(), pos);
          auto cast = rewriter.create<UnrealizedConversionCastOp>(
              loc, llvmIndexType, dim);
          size = cast.getResult(0);
        }
        stride = rewriter.create<LLVM::ConstantOp>(
            loc, llvmIndexType, rewriter.getI64IntegerAttr(1));
      } else {
        // TODO: need OpFoldResult ODS adaptor to clean this up.
        size =
            subViewOp.isDynamicSize(i)
                ? adaptor.getOperands()[subViewOp.getIndexOfDynamicSize(i)]
                : rewriter.create<LLVM::ConstantOp>(
                      loc, llvmIndexType,
                      rewriter.getI64IntegerAttr(subViewOp.getStaticSize(i)));
        if (!ShapedType::isDynamicStrideOrOffset(strides[i])) {
          stride = rewriter.create<LLVM::ConstantOp>(
              loc, llvmIndexType, rewriter.getI64IntegerAttr(strides[i]));
        } else {
          stride =
              subViewOp.isDynamicStride(i)
                  ? adaptor.getOperands()[subViewOp.getIndexOfDynamicStride(i)]
                  : rewriter.create<LLVM::ConstantOp>(
                        loc, llvmIndexType,
                        rewriter.getI64IntegerAttr(
                            subViewOp.getStaticStride(i)));
          stride = rewriter.create<LLVM::MulOp>(loc, stride, strideValues[i]);
        }
      }
      targetMemRef.setSize(rewriter, loc, j, size);
      targetMemRef.setStride(rewriter, loc, j, stride);
      j--;
    }

    rewriter.replaceOp(subViewOp, {targetMemRef});
    return success();
  }
};
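// Illustrative sketch (assumed example, not verified compiler output): for
//
//   %1 = memref.subview %0[%o, 8] [16, 32] [2, 1]
//       : memref<64x64xf32> to memref<16x32xf32, #strided>
//
// where #strided stands for the inferred strided layout, the pattern keeps the
// source pointers (bitcast to the target element type) and fills the remaining
// descriptor fields as
//
//   offset  = %o * 64 + 8 * 1
//   sizes   = [16, 32]
//   strides = [2 * 64, 1 * 1] = [128, 1]
//
// where 64 and 1 are the source strides read back from the source descriptor.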
/// Conversion pattern that transforms a transpose op into:
///   1. An `llvm.mlir.undef` operation to create a memref descriptor.
///   2. Updates to the descriptor to introduce the data ptr, offset, size and
///      stride. Sizes and strides are permutations of the original values.
/// The transpose op is replaced by the descriptor.
class TransposeOpLowering : public ConvertOpToLLVMPattern<memref::TransposeOp> {
public:
  using ConvertOpToLLVMPattern<memref::TransposeOp>::ConvertOpToLLVMPattern;

  LogicalResult
  matchAndRewrite(memref::TransposeOp transposeOp, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    auto loc = transposeOp.getLoc();
    MemRefDescriptor viewMemRef(adaptor.in());

    // No permutation, early exit.
    if (transposeOp.permutation().isIdentity())
      return rewriter.replaceOp(transposeOp, {viewMemRef}), success();

    auto targetMemRef = MemRefDescriptor::undef(
        rewriter, loc, typeConverter->convertType(transposeOp.getShapedType()));

    // Copy the base and aligned pointers from the old descriptor to the new
    // one.
    targetMemRef.setAllocatedPtr(rewriter, loc,
                                 viewMemRef.allocatedPtr(rewriter, loc));
    targetMemRef.setAlignedPtr(rewriter, loc,
                               viewMemRef.alignedPtr(rewriter, loc));

    // Copy the offset from the old descriptor to the new one.
    targetMemRef.setOffset(rewriter, loc, viewMemRef.offset(rewriter, loc));

    // Iterate over the dimensions and apply size/stride permutation.
    for (const auto &en :
         llvm::enumerate(transposeOp.permutation().getResults())) {
      int sourcePos = en.index();
      int targetPos = en.value().cast<AffineDimExpr>().getPosition();
      targetMemRef.setSize(rewriter, loc, targetPos,
                           viewMemRef.size(rewriter, loc, sourcePos));
      targetMemRef.setStride(rewriter, loc, targetPos,
                             viewMemRef.stride(rewriter, loc, sourcePos));
    }

    rewriter.replaceOp(transposeOp, {targetMemRef});
    return success();
  }
};

/// Conversion pattern that transforms a view op into:
///   1. An `llvm.mlir.undef` operation to create a memref descriptor
///   2. Updates to the descriptor to introduce the data ptr, offset, size
///      and stride.
/// The view op is replaced by the descriptor.
struct ViewOpLowering : public ConvertOpToLLVMPattern<memref::ViewOp> {
  using ConvertOpToLLVMPattern<memref::ViewOp>::ConvertOpToLLVMPattern;

  // Build and return the value for the idx^th shape dimension, either by
  // returning the constant shape dimension or by selecting the proper operand
  // among the dynamic sizes.
  Value getSize(ConversionPatternRewriter &rewriter, Location loc,
                ArrayRef<int64_t> shape, ValueRange dynamicSizes,
                unsigned idx) const {
    assert(idx < shape.size());
    if (!ShapedType::isDynamic(shape[idx]))
      return createIndexConstant(rewriter, loc, shape[idx]);
    // Count the number of dynamic dims in range [0, idx).
    unsigned nDynamic = llvm::count_if(shape.take_front(idx), [](int64_t v) {
      return ShapedType::isDynamic(v);
    });
    return dynamicSizes[nDynamic];
  }
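  // For example (illustrative, hypothetical shape): with shape [4, ?, 8, ?] and
  // idx == 3, exactly one dynamic dimension precedes index 3, so the second
  // dynamic-size operand, dynamicSizes[1], is returned.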
  // Build and return the idx^th stride, either by returning the constant stride
  // or by computing the dynamic stride from the current `runningStride` and
  // `nextSize`. The caller should keep a running stride and update it with the
  // result returned by this function.
  Value getStride(ConversionPatternRewriter &rewriter, Location loc,
                  ArrayRef<int64_t> strides, Value nextSize,
                  Value runningStride, unsigned idx) const {
    assert(idx < strides.size());
    if (!ShapedType::isDynamicStrideOrOffset(strides[idx]))
      return createIndexConstant(rewriter, loc, strides[idx]);
    if (nextSize)
      return runningStride
                 ? rewriter.create<LLVM::MulOp>(loc, runningStride, nextSize)
                 : nextSize;
    assert(!runningStride);
    return createIndexConstant(rewriter, loc, 1);
  }

  LogicalResult
  matchAndRewrite(memref::ViewOp viewOp, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    auto loc = viewOp.getLoc();

    auto viewMemRefType = viewOp.getType();
    auto targetElementTy =
        typeConverter->convertType(viewMemRefType.getElementType());
    auto targetDescTy = typeConverter->convertType(viewMemRefType);
    if (!targetDescTy || !targetElementTy ||
        !LLVM::isCompatibleType(targetElementTy) ||
        !LLVM::isCompatibleType(targetDescTy))
      return viewOp.emitWarning("Target descriptor type not converted to LLVM"),
             failure();

    int64_t offset;
    SmallVector<int64_t, 4> strides;
    auto successStrides = getStridesAndOffset(viewMemRefType, strides, offset);
    if (failed(successStrides))
      return viewOp.emitWarning("cannot cast to non-strided shape"), failure();
    assert(offset == 0 && "expected offset to be 0");

    // Create the descriptor.
    MemRefDescriptor sourceMemRef(adaptor.source());
    auto targetMemRef = MemRefDescriptor::undef(rewriter, loc, targetDescTy);

    // Field 1: Copy the allocated pointer, used for malloc/free.
    Value allocatedPtr = sourceMemRef.allocatedPtr(rewriter, loc);
    auto srcMemRefType = viewOp.source().getType().cast<MemRefType>();
    Value bitcastPtr = rewriter.create<LLVM::BitcastOp>(
        loc,
        LLVM::LLVMPointerType::get(targetElementTy,
                                   srcMemRefType.getMemorySpaceAsInt()),
        allocatedPtr);
    targetMemRef.setAllocatedPtr(rewriter, loc, bitcastPtr);

    // Field 2: Copy the actual aligned pointer to payload.
    Value alignedPtr = sourceMemRef.alignedPtr(rewriter, loc);
    alignedPtr = rewriter.create<LLVM::GEPOp>(loc, alignedPtr.getType(),
                                              alignedPtr, adaptor.byte_shift());
    bitcastPtr = rewriter.create<LLVM::BitcastOp>(
        loc,
        LLVM::LLVMPointerType::get(targetElementTy,
                                   srcMemRefType.getMemorySpaceAsInt()),
        alignedPtr);
    targetMemRef.setAlignedPtr(rewriter, loc, bitcastPtr);

    // Field 3: The offset in the resulting type must be 0. This is because of
    // the type change: an offset on srcType* may not be expressible as an
    // offset on dstType*.
    targetMemRef.setOffset(rewriter, loc,
                           createIndexConstant(rewriter, loc, offset));

    // Early exit for 0-D corner case.
    if (viewMemRefType.getRank() == 0)
      return rewriter.replaceOp(viewOp, {targetMemRef}), success();

    // Fields 4 and 5: Update sizes and strides.
    if (strides.back() != 1)
      return viewOp.emitWarning("cannot cast to non-contiguous shape"),
             failure();
    Value stride = nullptr, nextSize = nullptr;
    for (int i = viewMemRefType.getRank() - 1; i >= 0; --i) {
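      // Sketch of the running computation (illustrative assumption): since the
      // innermost stride is required to be 1, successive getStride() calls
      // chain `runningStride * nextSize`, so a shape [A, B, C] ends up with
      // strides [B * C, C, 1].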
      // Update size.
      Value size =
          getSize(rewriter, loc, viewMemRefType.getShape(), adaptor.sizes(), i);
      targetMemRef.setSize(rewriter, loc, i, size);
      // Update stride.
      stride = getStride(rewriter, loc, strides, nextSize, stride, i);
      targetMemRef.setStride(rewriter, loc, i, stride);
      nextSize = size;
    }

    rewriter.replaceOp(viewOp, {targetMemRef});
    return success();
  }
};

//===----------------------------------------------------------------------===//
// AtomicRMWOpLowering
//===----------------------------------------------------------------------===//

/// Try to match the kind of a memref.atomic_rmw to determine whether to use a
/// lowering to llvm.atomicrmw or fall back to llvm.cmpxchg.
static Optional<LLVM::AtomicBinOp>
matchSimpleAtomicOp(memref::AtomicRMWOp atomicOp) {
  switch (atomicOp.kind()) {
  case arith::AtomicRMWKind::addf:
    return LLVM::AtomicBinOp::fadd;
  case arith::AtomicRMWKind::addi:
    return LLVM::AtomicBinOp::add;
  case arith::AtomicRMWKind::assign:
    return LLVM::AtomicBinOp::xchg;
  case arith::AtomicRMWKind::maxs:
    return LLVM::AtomicBinOp::max;
  case arith::AtomicRMWKind::maxu:
    return LLVM::AtomicBinOp::umax;
  case arith::AtomicRMWKind::mins:
    return LLVM::AtomicBinOp::min;
  case arith::AtomicRMWKind::minu:
    return LLVM::AtomicBinOp::umin;
  case arith::AtomicRMWKind::ori:
    return LLVM::AtomicBinOp::_or;
  case arith::AtomicRMWKind::andi:
    return LLVM::AtomicBinOp::_and;
  default:
    return llvm::None;
  }
  llvm_unreachable("Invalid AtomicRMWKind");
}

struct AtomicRMWOpLowering : public LoadStoreOpLowering<memref::AtomicRMWOp> {
  using Base::Base;

  LogicalResult
  matchAndRewrite(memref::AtomicRMWOp atomicOp, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    if (failed(match(atomicOp)))
      return failure();
    auto maybeKind = matchSimpleAtomicOp(atomicOp);
    if (!maybeKind)
      return failure();
    auto resultType = adaptor.value().getType();
    auto memRefType = atomicOp.getMemRefType();
    auto dataPtr =
        getStridedElementPtr(atomicOp.getLoc(), memRefType, adaptor.memref(),
                             adaptor.indices(), rewriter);
    rewriter.replaceOpWithNewOp<LLVM::AtomicRMWOp>(
        atomicOp, resultType, *maybeKind, dataPtr, adaptor.value(),
        LLVM::AtomicOrdering::acq_rel);
    return success();
  }
};

} // namespace

void mlir::populateMemRefToLLVMConversionPatterns(LLVMTypeConverter &converter,
                                                  RewritePatternSet &patterns) {
  // clang-format off
  patterns.add<
      AllocaOpLowering,
      AllocaScopeOpLowering,
      AtomicRMWOpLowering,
      AssumeAlignmentOpLowering,
      DimOpLowering,
      GenericAtomicRMWOpLowering,
      GlobalMemrefOpLowering,
      GetGlobalMemrefOpLowering,
      LoadOpLowering,
      MemRefCastOpLowering,
      MemRefCopyOpLowering,
      MemRefReinterpretCastOpLowering,
      MemRefReshapeOpLowering,
      PrefetchOpLowering,
      RankOpLowering,
      ReassociatingReshapeOpConversion<memref::ExpandShapeOp>,
      ReassociatingReshapeOpConversion<memref::CollapseShapeOp>,
      StoreOpLowering,
      SubViewOpLowering,
      TransposeOpLowering,
      ViewOpLowering>(converter);
  // clang-format on
  auto allocLowering = converter.getOptions().allocLowering;
  if (allocLowering == LowerToLLVMOptions::AllocLowering::AlignedAlloc)
    patterns.add<AlignedAllocOpLowering, DeallocOpLowering>(converter);
  else if (allocLowering == LowerToLLVMOptions::AllocLowering::Malloc)
    patterns.add<AllocOpLowering, DeallocOpLowering>(converter);
}

namespace {
struct MemRefToLLVMPass : public ConvertMemRefToLLVMBase<MemRefToLLVMPass> {
  MemRefToLLVMPass() = default;

  void runOnOperation() override {
    Operation *op = getOperation();
    const auto &dataLayoutAnalysis = getAnalysis<DataLayoutAnalysis>();
    LowerToLLVMOptions options(&getContext(),
                               dataLayoutAnalysis.getAtOrAbove(op));
    options.allocLowering =
        (useAlignedAlloc ? LowerToLLVMOptions::AllocLowering::AlignedAlloc
                         : LowerToLLVMOptions::AllocLowering::Malloc);
    if (indexBitwidth != kDeriveIndexBitwidthFromDataLayout)
      options.overrideIndexBitwidth(indexBitwidth);

    LLVMTypeConverter typeConverter(&getContext(), options,
                                    &dataLayoutAnalysis);
    RewritePatternSet patterns(&getContext());
    populateMemRefToLLVMConversionPatterns(typeConverter, patterns);
    LLVMConversionTarget target(getContext());
    target.addLegalOp<FuncOp>();
    if (failed(applyPartialConversion(op, target, std::move(patterns))))
      signalPassFailure();
  }
};
} // namespace

std::unique_ptr<Pass> mlir::createMemRefToLLVMPass() {
  return std::make_unique<MemRefToLLVMPass>();
}
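// Usage sketch (assumption, not part of this file's API surface): the lowering
// is typically exercised either through the pass created above, e.g.
//
//   mlir-opt --convert-memref-to-llvm input.mlir
//
// or by calling populateMemRefToLLVMConversionPatterns() from a custom
// conversion pass that sets up its own LLVMTypeConverter, RewritePatternSet,
// and LLVMConversionTarget, as runOnOperation() does above.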