//===- MemRefToLLVM.cpp - MemRef to LLVM dialect conversion ---------------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//

#include "mlir/Conversion/MemRefToLLVM/MemRefToLLVM.h"
#include "../PassDetail.h"
#include "mlir/Analysis/DataLayoutAnalysis.h"
#include "mlir/Conversion/LLVMCommon/ConversionTarget.h"
#include "mlir/Conversion/LLVMCommon/Pattern.h"
#include "mlir/Conversion/LLVMCommon/TypeConverter.h"
#include "mlir/Conversion/MemRefToLLVM/AllocLikeConversion.h"
#include "mlir/Dialect/Arithmetic/IR/Arithmetic.h"
#include "mlir/Dialect/Func/IR/FuncOps.h"
#include "mlir/Dialect/LLVMIR/FunctionCallUtils.h"
#include "mlir/Dialect/LLVMIR/LLVMDialect.h"
#include "mlir/Dialect/MemRef/IR/MemRef.h"
#include "mlir/IR/AffineMap.h"
#include "mlir/IR/BlockAndValueMapping.h"
#include "llvm/ADT/SmallBitVector.h"

using namespace mlir;

namespace {

struct AllocOpLowering : public AllocLikeOpLLVMLowering {
  AllocOpLowering(LLVMTypeConverter &converter)
      : AllocLikeOpLLVMLowering(memref::AllocOp::getOperationName(),
                                converter) {}

  std::tuple<Value, Value> allocateBuffer(ConversionPatternRewriter &rewriter,
                                          Location loc, Value sizeBytes,
                                          Operation *op) const override {
    // Heap allocations.
    memref::AllocOp allocOp = cast<memref::AllocOp>(op);
    MemRefType memRefType = allocOp.getType();

    Value alignment;
    if (auto alignmentAttr = allocOp.alignment()) {
      alignment = createIndexConstant(rewriter, loc, *alignmentAttr);
    } else if (!memRefType.getElementType().isSignlessIntOrIndexOrFloat()) {
      // In the case where no alignment is specified, we may want to override
      // `malloc's` behavior. `malloc` typically aligns at the size of the
      // biggest scalar on a target HW. For non-scalars, use the natural
      // alignment of the LLVM type given by the LLVM DataLayout.
      alignment = getSizeInBytes(loc, memRefType.getElementType(), rewriter);
    }

    if (alignment) {
      // Adjust the allocation size to consider alignment.
      sizeBytes = rewriter.create<LLVM::AddOp>(loc, sizeBytes, alignment);
    }

    // Allocate the underlying buffer and store a pointer to it in the MemRef
    // descriptor.
    Type elementPtrType = this->getElementPtrType(memRefType);
    auto allocFuncOp = LLVM::lookupOrCreateMallocFn(
        allocOp->getParentOfType<ModuleOp>(), getIndexType());
    auto results = createLLVMCall(rewriter, loc, allocFuncOp, {sizeBytes},
                                  getVoidPtrType());
    Value allocatedPtr =
        rewriter.create<LLVM::BitcastOp>(loc, elementPtrType, results[0]);

    Value alignedPtr = allocatedPtr;
    if (alignment) {
      // Compute the aligned type pointer.
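      // `createAligned` (from AllocLikeConversion) rounds the address up to
      // the next multiple of `alignment`; assuming the usual round-up formula
      // ((addr + alignment - 1) / alignment) * alignment, a malloc result of
      // 0x1004 with alignment 16 yields 0x1010. The extra `alignment` bytes
      // added to `sizeBytes` above guarantee that the rounded-up pointer
      // still has the requested number of bytes available.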
      Value allocatedInt =
          rewriter.create<LLVM::PtrToIntOp>(loc, getIndexType(), allocatedPtr);
      Value alignmentInt =
          createAligned(rewriter, loc, allocatedInt, alignment);
      alignedPtr =
          rewriter.create<LLVM::IntToPtrOp>(loc, elementPtrType, alignmentInt);
    }

    return std::make_tuple(allocatedPtr, alignedPtr);
  }
};

struct AlignedAllocOpLowering : public AllocLikeOpLLVMLowering {
  AlignedAllocOpLowering(LLVMTypeConverter &converter)
      : AllocLikeOpLLVMLowering(memref::AllocOp::getOperationName(),
                                converter) {}

  /// Returns the memref's element size in bytes using the data layout active
  /// at `op`.
  // TODO: there are other places where this is used. Expose publicly?
  unsigned getMemRefEltSizeInBytes(MemRefType memRefType,
                                   Operation *op) const {
    const DataLayout *layout = &defaultLayout;
    if (const DataLayoutAnalysis *analysis =
            getTypeConverter()->getDataLayoutAnalysis()) {
      layout = &analysis->getAbove(op);
    }
    Type elementType = memRefType.getElementType();
    if (auto memRefElementType = elementType.dyn_cast<MemRefType>())
      return getTypeConverter()->getMemRefDescriptorSize(memRefElementType,
                                                         *layout);
    if (auto memRefElementType = elementType.dyn_cast<UnrankedMemRefType>())
      return getTypeConverter()->getUnrankedMemRefDescriptorSize(
          memRefElementType, *layout);
    return layout->getTypeSize(elementType);
  }

  /// Returns true if the memref size in bytes is known to be a multiple of
  /// `factor`, assuming the data layout active at `op`.
  bool isMemRefSizeMultipleOf(MemRefType type, uint64_t factor,
                              Operation *op) const {
    uint64_t sizeDivisor = getMemRefEltSizeInBytes(type, op);
    for (unsigned i = 0, e = type.getRank(); i < e; i++) {
      if (ShapedType::isDynamic(type.getDimSize(i)))
        continue;
      sizeDivisor = sizeDivisor * type.getDimSize(i);
    }
    return sizeDivisor % factor == 0;
  }

  /// Returns the alignment to be used for the allocation call itself.
  /// aligned_alloc requires the alignment to be a power of two, and the
  /// allocation size to be a multiple of that alignment.
  int64_t getAllocationAlignment(memref::AllocOp allocOp) const {
    if (Optional<uint64_t> alignment = allocOp.alignment())
      return *alignment;

    // Whenever we don't have an alignment set, we will use an alignment
    // consistent with the element type; since the alignment has to be a power
    // of two, we bump the element size to the next power of two if it isn't
    // one already.
    auto eltSizeBytes = getMemRefEltSizeInBytes(allocOp.getType(), allocOp);
    return std::max(kMinAlignedAllocAlignment,
                    llvm::PowerOf2Ceil(eltSizeBytes));
  }

  std::tuple<Value, Value> allocateBuffer(ConversionPatternRewriter &rewriter,
                                          Location loc, Value sizeBytes,
                                          Operation *op) const override {
    // Heap allocations.
    memref::AllocOp allocOp = cast<memref::AllocOp>(op);
    MemRefType memRefType = allocOp.getType();
    int64_t alignment = getAllocationAlignment(allocOp);
    Value allocAlignment = createIndexConstant(rewriter, loc, alignment);

    // aligned_alloc requires the size to be a multiple of the alignment; we
    // pad the size to the next multiple if necessary.
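    // Worked example (illustrative numbers, not taken from the op): for
    // memref<5xf64> the natural size is 5 * 8 = 40 bytes; with an alignment
    // of 16 this is not a multiple, so the allocation size is rounded up to
    // 48 bytes here.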
    if (!isMemRefSizeMultipleOf(memRefType, alignment, op))
      sizeBytes = createAligned(rewriter, loc, sizeBytes, allocAlignment);

    Type elementPtrType = this->getElementPtrType(memRefType);
    auto allocFuncOp = LLVM::lookupOrCreateAlignedAllocFn(
        allocOp->getParentOfType<ModuleOp>(), getIndexType());
    auto results =
        createLLVMCall(rewriter, loc, allocFuncOp, {allocAlignment, sizeBytes},
                       getVoidPtrType());
    Value allocatedPtr =
        rewriter.create<LLVM::BitcastOp>(loc, elementPtrType, results[0]);

    return std::make_tuple(allocatedPtr, allocatedPtr);
  }

  /// The minimum alignment to use with aligned_alloc (has to be a power of 2).
  static constexpr uint64_t kMinAlignedAllocAlignment = 16UL;

  /// Default layout to use in absence of the corresponding analysis.
  DataLayout defaultLayout;
};

// Out-of-line definition, required until C++17.
constexpr uint64_t AlignedAllocOpLowering::kMinAlignedAllocAlignment;

struct AllocaOpLowering : public AllocLikeOpLLVMLowering {
  AllocaOpLowering(LLVMTypeConverter &converter)
      : AllocLikeOpLLVMLowering(memref::AllocaOp::getOperationName(),
                                converter) {}

  /// Allocates the underlying buffer on the stack using an LLVM alloca. Both
  /// returned pointers are the same: alloca takes the alignment directly, so
  /// no post-allocation realignment (e.g., as done in conjunction with malloc)
  /// is needed.
  std::tuple<Value, Value> allocateBuffer(ConversionPatternRewriter &rewriter,
                                          Location loc, Value sizeBytes,
                                          Operation *op) const override {

    // With alloca, one gets a pointer to the element type right away.
    auto allocaOp = cast<memref::AllocaOp>(op);
    auto elementPtrType = this->getElementPtrType(allocaOp.getType());

    auto allocatedElementPtr = rewriter.create<LLVM::AllocaOp>(
        loc, elementPtrType, sizeBytes,
        allocaOp.alignment() ? *allocaOp.alignment() : 0);

    return std::make_tuple(allocatedElementPtr, allocatedElementPtr);
  }
};

struct AllocaScopeOpLowering
    : public ConvertOpToLLVMPattern<memref::AllocaScopeOp> {
  using ConvertOpToLLVMPattern<memref::AllocaScopeOp>::ConvertOpToLLVMPattern;

  LogicalResult
  matchAndRewrite(memref::AllocaScopeOp allocaScopeOp, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    OpBuilder::InsertionGuard guard(rewriter);
    Location loc = allocaScopeOp.getLoc();

    // Split the current block before the AllocaScopeOp to create the inlining
    // point.
    auto *currentBlock = rewriter.getInsertionBlock();
    auto *remainingOpsBlock =
        rewriter.splitBlock(currentBlock, rewriter.getInsertionPoint());
    Block *continueBlock;
    if (allocaScopeOp.getNumResults() == 0) {
      continueBlock = remainingOpsBlock;
    } else {
      continueBlock = rewriter.createBlock(
          remainingOpsBlock, allocaScopeOp.getResultTypes(),
          SmallVector<Location>(allocaScopeOp->getNumResults(),
                                allocaScopeOp.getLoc()));
      rewriter.create<LLVM::BrOp>(loc, ValueRange(), remainingOpsBlock);
    }

    // Inline body region.
    Block *beforeBody = &allocaScopeOp.bodyRegion().front();
    Block *afterBody = &allocaScopeOp.bodyRegion().back();
    rewriter.inlineRegionBefore(allocaScopeOp.bodyRegion(), continueBlock);

    // Save the stack, then branch into the body of the region.
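    // The llvm.intr.stacksave / llvm.intr.stackrestore pair brackets the
    // inlined body: any llvm.alloca emitted inside the scope is reclaimed
    // when the saved stack pointer is restored on exit, which implements the
    // scoping semantics of memref.alloca_scope.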
    rewriter.setInsertionPointToEnd(currentBlock);
    auto stackSaveOp =
        rewriter.create<LLVM::StackSaveOp>(loc, getVoidPtrType());
    rewriter.create<LLVM::BrOp>(loc, ValueRange(), beforeBody);

    // Replace the alloca_scope return with a branch that jumps out of the
    // body. Restore the stack before leaving the body region.
    rewriter.setInsertionPointToEnd(afterBody);
    auto returnOp =
        cast<memref::AllocaScopeReturnOp>(afterBody->getTerminator());
    auto branchOp = rewriter.replaceOpWithNewOp<LLVM::BrOp>(
        returnOp, returnOp.results(), continueBlock);

    // Insert the stack restore before jumping out of the body of the region.
    rewriter.setInsertionPoint(branchOp);
    rewriter.create<LLVM::StackRestoreOp>(loc, stackSaveOp);

    // Replace the op with the values returned from the body region.
    rewriter.replaceOp(allocaScopeOp, continueBlock->getArguments());

    return success();
  }
};

struct AssumeAlignmentOpLowering
    : public ConvertOpToLLVMPattern<memref::AssumeAlignmentOp> {
  using ConvertOpToLLVMPattern<
      memref::AssumeAlignmentOp>::ConvertOpToLLVMPattern;

  LogicalResult
  matchAndRewrite(memref::AssumeAlignmentOp op, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    Value memref = adaptor.memref();
    unsigned alignment = op.alignment();
    auto loc = op.getLoc();

    MemRefDescriptor memRefDescriptor(memref);
    Value ptr = memRefDescriptor.alignedPtr(rewriter, memref.getLoc());

    // Emit llvm.assume(memref.alignedPtr & (alignment - 1) == 0). Notice that
    // the asserted memref.alignedPtr isn't used anywhere else, as the real
    // users like load/store/views always re-extract memref.alignedPtr as they
    // get lowered.
    //
    // This relies on LLVM's CSE optimization (potentially after SROA), since
    // after CSE all memref.alignedPtr instances get de-duplicated into the
    // same pointer SSA value.
    auto intPtrType =
        getIntPtrType(memRefDescriptor.getElementPtrType().getAddressSpace());
    Value zero = createIndexAttrConstant(rewriter, loc, intPtrType, 0);
    Value mask =
        createIndexAttrConstant(rewriter, loc, intPtrType, alignment - 1);
    Value ptrValue = rewriter.create<LLVM::PtrToIntOp>(loc, intPtrType, ptr);
    rewriter.create<LLVM::AssumeOp>(
        loc, rewriter.create<LLVM::ICmpOp>(
                 loc, LLVM::ICmpPredicate::eq,
                 rewriter.create<LLVM::AndOp>(loc, ptrValue, mask), zero));

    rewriter.eraseOp(op);
    return success();
  }
};

// A `dealloc` is converted into a call to `free` on the underlying data
// buffer. The memref descriptor being an SSA value, there is no need to clean
// it up in any way.
struct DeallocOpLowering : public ConvertOpToLLVMPattern<memref::DeallocOp> {
  using ConvertOpToLLVMPattern<memref::DeallocOp>::ConvertOpToLLVMPattern;

  explicit DeallocOpLowering(LLVMTypeConverter &converter)
      : ConvertOpToLLVMPattern<memref::DeallocOp>(converter) {}

  LogicalResult
  matchAndRewrite(memref::DeallocOp op, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    // Insert the `free` declaration if it is not already present.
    auto freeFunc =
        LLVM::lookupOrCreateFreeFn(op->getParentOfType<ModuleOp>());
    MemRefDescriptor memref(adaptor.memref());
    Value casted = rewriter.create<LLVM::BitcastOp>(
        op.getLoc(), getVoidPtrType(),
        memref.allocatedPtr(rewriter, op.getLoc()));
    rewriter.replaceOpWithNewOp<LLVM::CallOp>(
        op, TypeRange(), SymbolRefAttr::get(freeFunc), casted);
    return success();
  }
};

// A `dim` is converted to a constant for static sizes and to an access to the
// size stored in the memref descriptor for dynamic sizes.
struct DimOpLowering : public ConvertOpToLLVMPattern<memref::DimOp> {
  using ConvertOpToLLVMPattern<memref::DimOp>::ConvertOpToLLVMPattern;

  LogicalResult
  matchAndRewrite(memref::DimOp dimOp, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    Type operandType = dimOp.source().getType();
    if (operandType.isa<UnrankedMemRefType>()) {
      rewriter.replaceOp(
          dimOp, {extractSizeOfUnrankedMemRef(
                     operandType, dimOp, adaptor.getOperands(), rewriter)});

      return success();
    }
    if (operandType.isa<MemRefType>()) {
      rewriter.replaceOp(
          dimOp, {extractSizeOfRankedMemRef(operandType, dimOp,
                                            adaptor.getOperands(), rewriter)});
      return success();
    }
    llvm_unreachable("expected MemRefType or UnrankedMemRefType");
  }

private:
  Value extractSizeOfUnrankedMemRef(Type operandType, memref::DimOp dimOp,
                                    OpAdaptor adaptor,
                                    ConversionPatternRewriter &rewriter) const {
    Location loc = dimOp.getLoc();

    auto unrankedMemRefType = operandType.cast<UnrankedMemRefType>();
    auto scalarMemRefType =
        MemRefType::get({}, unrankedMemRefType.getElementType());
    unsigned addressSpace = unrankedMemRefType.getMemorySpaceAsInt();

    // Extract pointer to the underlying ranked descriptor and bitcast it to a
    // memref<element_type> descriptor pointer to minimize the number of GEP
    // operations.
    UnrankedMemRefDescriptor unrankedDesc(adaptor.source());
    Value underlyingRankedDesc = unrankedDesc.memRefDescPtr(rewriter, loc);
    Value scalarMemRefDescPtr = rewriter.create<LLVM::BitcastOp>(
        loc,
        LLVM::LLVMPointerType::get(
            typeConverter->convertType(scalarMemRefType), addressSpace),
        underlyingRankedDesc);

    // Get pointer to offset field of memref<element_type> descriptor.
    Type indexPtrTy = LLVM::LLVMPointerType::get(
        getTypeConverter()->getIndexType(), addressSpace);
    Value two = rewriter.create<LLVM::ConstantOp>(
        loc, typeConverter->convertType(rewriter.getI32Type()),
        rewriter.getI32IntegerAttr(2));
    Value offsetPtr = rewriter.create<LLVM::GEPOp>(
        loc, indexPtrTy, scalarMemRefDescPtr,
        ValueRange({createIndexConstant(rewriter, loc, 0), two}));

    // The size value that we have to extract can be obtained using GEPOp with
    // `dimOp.index() + 1` as the index argument.
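    // For reference: a ranked descriptor is laid out as (allocatedPtr,
    // alignedPtr, offset, sizes[rank], strides[rank]), so field 2 is the
    // offset and the size of dimension `d` sits `d + 1` index slots past it.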
    Value idxPlusOne = rewriter.create<LLVM::AddOp>(
        loc, createIndexConstant(rewriter, loc, 1), adaptor.index());
    Value sizePtr = rewriter.create<LLVM::GEPOp>(loc, indexPtrTy, offsetPtr,
                                                 ValueRange({idxPlusOne}));
    return rewriter.create<LLVM::LoadOp>(loc, sizePtr);
  }

  Optional<int64_t> getConstantDimIndex(memref::DimOp dimOp) const {
    if (Optional<int64_t> idx = dimOp.getConstantIndex())
      return idx;

    if (auto constantOp = dimOp.index().getDefiningOp<LLVM::ConstantOp>())
      return constantOp.getValue()
          .cast<IntegerAttr>()
          .getValue()
          .getSExtValue();

    return llvm::None;
  }

  Value extractSizeOfRankedMemRef(Type operandType, memref::DimOp dimOp,
                                  OpAdaptor adaptor,
                                  ConversionPatternRewriter &rewriter) const {
    Location loc = dimOp.getLoc();

    // Take advantage of a constant index, if any.
    MemRefType memRefType = operandType.cast<MemRefType>();
    if (Optional<int64_t> index = getConstantDimIndex(dimOp)) {
      int64_t i = index.getValue();
      if (memRefType.isDynamicDim(i)) {
        // Extract the dynamic size from the memref descriptor.
        MemRefDescriptor descriptor(adaptor.source());
        return descriptor.size(rewriter, loc, i);
      }
      // Use a constant for the static size.
      int64_t dimSize = memRefType.getDimSize(i);
      return createIndexConstant(rewriter, loc, dimSize);
    }
    Value index = adaptor.index();
    int64_t rank = memRefType.getRank();
    MemRefDescriptor memrefDescriptor(adaptor.source());
    return memrefDescriptor.size(rewriter, loc, index, rank);
  }
};

/// Common base for load and store operations on MemRefs. Restricts the match
/// to supported MemRef types. Provides functionality to emit code accessing a
/// specific element of the underlying data buffer.
template <typename Derived>
struct LoadStoreOpLowering : public ConvertOpToLLVMPattern<Derived> {
  using ConvertOpToLLVMPattern<Derived>::ConvertOpToLLVMPattern;
  using ConvertOpToLLVMPattern<Derived>::isConvertibleAndHasIdentityMaps;
  using Base = LoadStoreOpLowering<Derived>;

  LogicalResult match(Derived op) const override {
    MemRefType type = op.getMemRefType();
    return isConvertibleAndHasIdentityMaps(type) ? success() : failure();
  }
};

/// Wrap a llvm.cmpxchg operation in a while loop so that the operation can be
/// retried until it succeeds in atomically storing a new value into memory.
///
///      +---------------------------------+
///      |   <code before the AtomicRMWOp> |
///      |   <compute initial %loaded>     |
///      |   cf.br loop(%loaded)           |
///      +---------------------------------+
///             |
///  -------|   |
///  |      v   v
///  |   +-----------------------------------+
///  |   | loop(%loaded):                    |
///  |   |   <body contents>                 |
///  |   |   %pair = cmpxchg                 |
///  |   |   %new = %pair[0]                 |
///  |   |   %ok = %pair[1]                  |
///  |   |   cf.cond_br %ok, end, loop(%new) |
///  |   +-----------------------------------+
///  |          |         |
///  |-----------         |
///                       v
///      +--------------------------------+
///      | end:                           |
///      |   <code after the AtomicRMWOp> |
///      +--------------------------------+
///
struct GenericAtomicRMWOpLowering
    : public LoadStoreOpLowering<memref::GenericAtomicRMWOp> {
  using Base::Base;

  LogicalResult
  matchAndRewrite(memref::GenericAtomicRMWOp atomicOp, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    auto loc = atomicOp.getLoc();
    Type valueType =
        typeConverter->convertType(atomicOp.getResult().getType());

    // Split the block into initial, loop, and ending parts.
    auto *initBlock = rewriter.getInsertionBlock();
    auto *loopBlock = rewriter.createBlock(
        initBlock->getParent(), std::next(Region::iterator(initBlock)),
        valueType, loc);
    auto *endBlock = rewriter.createBlock(
        loopBlock->getParent(), std::next(Region::iterator(loopBlock)));

    // Operations range to be moved to `endBlock`.
    auto opsToMoveStart = atomicOp->getIterator();
    auto opsToMoveEnd = initBlock->back().getIterator();

    // Compute the loaded value and branch to the loop block.
    rewriter.setInsertionPointToEnd(initBlock);
    auto memRefType = atomicOp.memref().getType().cast<MemRefType>();
    auto dataPtr = getStridedElementPtr(loc, memRefType, adaptor.memref(),
                                        adaptor.indices(), rewriter);
    Value init = rewriter.create<LLVM::LoadOp>(loc, dataPtr);
    rewriter.create<LLVM::BrOp>(loc, init, loopBlock);

    // Prepare the body of the loop block.
    rewriter.setInsertionPointToStart(loopBlock);

    // Clone the GenericAtomicRMWOp region and extract the result.
    auto loopArgument = loopBlock->getArgument(0);
    BlockAndValueMapping mapping;
    mapping.map(atomicOp.getCurrentValue(), loopArgument);
    Block &entryBlock = atomicOp.body().front();
    for (auto &nestedOp : entryBlock.without_terminator()) {
      Operation *clone = rewriter.clone(nestedOp, mapping);
      mapping.map(nestedOp.getResults(), clone->getResults());
    }
    Value result = mapping.lookup(entryBlock.getTerminator()->getOperand(0));

    // Prepare the epilog of the loop block.
    // Append the cmpxchg op to the end of the loop block.
    auto successOrdering = LLVM::AtomicOrdering::acq_rel;
    auto failureOrdering = LLVM::AtomicOrdering::monotonic;
    auto boolType = IntegerType::get(rewriter.getContext(), 1);
    auto pairType = LLVM::LLVMStructType::getLiteral(rewriter.getContext(),
                                                     {valueType, boolType});
    auto cmpxchg = rewriter.create<LLVM::AtomicCmpXchgOp>(
        loc, pairType, dataPtr, loopArgument, result, successOrdering,
        failureOrdering);
    // Extract the %new_loaded and %ok values from the pair.
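    // The cmpxchg result struct is (valueType, i1): index 0 holds the value
    // actually read from memory, index 1 holds the success bit.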
    Value newLoaded = rewriter.create<LLVM::ExtractValueOp>(
        loc, valueType, cmpxchg, rewriter.getI64ArrayAttr({0}));
    Value ok = rewriter.create<LLVM::ExtractValueOp>(
        loc, boolType, cmpxchg, rewriter.getI64ArrayAttr({1}));

    // Conditionally branch to the end or back to the loop depending on %ok.
    rewriter.create<LLVM::CondBrOp>(loc, ok, endBlock, ArrayRef<Value>(),
                                    loopBlock, newLoaded);

    rewriter.setInsertionPointToEnd(endBlock);
    moveOpsRange(atomicOp.getResult(), newLoaded, std::next(opsToMoveStart),
                 std::next(opsToMoveEnd), rewriter);

    // The 'result' of the atomic_rmw op is the newly loaded value.
    rewriter.replaceOp(atomicOp, {newLoaded});

    return success();
  }

private:
  // Clones a segment of ops [start, end) and erases the original.
  void moveOpsRange(ValueRange oldResult, ValueRange newResult,
                    Block::iterator start, Block::iterator end,
                    ConversionPatternRewriter &rewriter) const {
    BlockAndValueMapping mapping;
    mapping.map(oldResult, newResult);
    SmallVector<Operation *, 2> opsToErase;
    for (auto it = start; it != end; ++it) {
      rewriter.clone(*it, mapping);
      opsToErase.push_back(&*it);
    }
    for (auto *it : opsToErase)
      rewriter.eraseOp(it);
  }
};

/// Returns the LLVM type of the global variable given the memref type `type`.
static Type convertGlobalMemrefTypeToLLVM(MemRefType type,
                                          LLVMTypeConverter &typeConverter) {
  // The LLVM type for a global memref will be a multi-dimensional array. For
  // declarations or uninitialized global memrefs, we could potentially flatten
  // this to a 1D array. However, for memref.global's with an initial value,
  // we do not intend to flatten the ElementsAttribute when going from std ->
  // LLVM dialect, so the LLVM type needs to be a multi-dimensional array.
  Type elementType = typeConverter.convertType(type.getElementType());
  Type arrayTy = elementType;
  // The shape has the outermost dim at index 0, so we need to walk it
  // backwards.
  for (int64_t dim : llvm::reverse(type.getShape()))
    arrayTy = LLVM::LLVMArrayType::get(arrayTy, dim);
  return arrayTy;
}

/// GlobalMemrefOp is lowered to an LLVM global variable.
struct GlobalMemrefOpLowering
    : public ConvertOpToLLVMPattern<memref::GlobalOp> {
  using ConvertOpToLLVMPattern<memref::GlobalOp>::ConvertOpToLLVMPattern;

  LogicalResult
  matchAndRewrite(memref::GlobalOp global, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    MemRefType type = global.type();
    if (!isConvertibleAndHasIdentityMaps(type))
      return failure();

    Type arrayTy = convertGlobalMemrefTypeToLLVM(type, *getTypeConverter());

    LLVM::Linkage linkage =
        global.isPublic() ? LLVM::Linkage::External : LLVM::Linkage::Private;

    Attribute initialValue = nullptr;
    if (!global.isExternal() && !global.isUninitialized()) {
      auto elementsAttr = global.initial_value()->cast<ElementsAttr>();
      initialValue = elementsAttr;

      // For scalar memrefs, the global variable created is of the element
      // type, so unpack the elements attribute to extract the value.
      if (type.getRank() == 0)
        initialValue = elementsAttr.getSplatValue<Attribute>();
    }

    uint64_t alignment = global.alignment().getValueOr(0);

    auto newGlobal = rewriter.replaceOpWithNewOp<LLVM::GlobalOp>(
        global, arrayTy, global.constant(), linkage, global.sym_name(),
        initialValue, alignment, type.getMemorySpaceAsInt());
    if (!global.isExternal() && global.isUninitialized()) {
      Block *blk = new Block();
      newGlobal.getInitializerRegion().push_back(blk);
      rewriter.setInsertionPointToStart(blk);
      Value undef[] = {
          rewriter.create<LLVM::UndefOp>(global.getLoc(), arrayTy)};
      rewriter.create<LLVM::ReturnOp>(global.getLoc(), undef);
    }
    return success();
  }
};

/// GetGlobalMemrefOp is lowered into a Memref descriptor with the pointer to
/// the first element stashed into the descriptor. This reuses
/// `AllocLikeOpLowering` to reuse the Memref descriptor construction.
struct GetGlobalMemrefOpLowering : public AllocLikeOpLLVMLowering {
  GetGlobalMemrefOpLowering(LLVMTypeConverter &converter)
      : AllocLikeOpLLVMLowering(memref::GetGlobalOp::getOperationName(),
                                converter) {}

  /// Buffer "allocation" for a memref.get_global op is getting the address of
  /// the referenced global variable.
  std::tuple<Value, Value> allocateBuffer(ConversionPatternRewriter &rewriter,
                                          Location loc, Value sizeBytes,
                                          Operation *op) const override {
    auto getGlobalOp = cast<memref::GetGlobalOp>(op);
    MemRefType type = getGlobalOp.result().getType().cast<MemRefType>();
    unsigned memSpace = type.getMemorySpaceAsInt();

    Type arrayTy = convertGlobalMemrefTypeToLLVM(type, *getTypeConverter());
    auto addressOf = rewriter.create<LLVM::AddressOfOp>(
        loc, LLVM::LLVMPointerType::get(arrayTy, memSpace),
        getGlobalOp.name());

    // Get the address of the first element in the array by creating a GEP with
    // the address of the GV as the base, and (rank + 1) number of 0 indices.
    Type elementType = typeConverter->convertType(type.getElementType());
    Type elementPtrType = LLVM::LLVMPointerType::get(elementType, memSpace);

    SmallVector<Value> operands;
    operands.insert(operands.end(), type.getRank() + 1,
                    createIndexConstant(rewriter, loc, 0));
    auto gep =
        rewriter.create<LLVM::GEPOp>(loc, elementPtrType, addressOf, operands);

    // We do not expect the memref obtained using `memref.get_global` to ever
    // be deallocated. Set the allocated pointer to a known bad value to help
    // debug if that ever happens.
    auto intPtrType = getIntPtrType(memSpace);
    Value deadBeefConst =
        createIndexAttrConstant(rewriter, op->getLoc(), intPtrType,
                                0xdeadbeef);
    auto deadBeefPtr =
        rewriter.create<LLVM::IntToPtrOp>(loc, elementPtrType, deadBeefConst);

    // Both the allocated and aligned pointers are the same. We could
    // potentially stash a nullptr for the allocated pointer since we do not
    // expect any dealloc.
    return std::make_tuple(deadBeefPtr, gep);
  }
};

// Load operation is lowered to obtaining a pointer to the indexed element
// and loading it.
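// For illustration, a load such as
//   %0 = memref.load %m[%i, %j] : memref<?x?xf32>
// lowers to roughly (value names invented): extract the aligned pointer from
// the descriptor, linearize the index as offset + %i * stride0 + %j * stride1,
// llvm.getelementptr to the element, and llvm.load from that pointer.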
struct LoadOpLowering : public LoadStoreOpLowering<memref::LoadOp> {
  using Base::Base;

  LogicalResult
  matchAndRewrite(memref::LoadOp loadOp, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    auto type = loadOp.getMemRefType();

    Value dataPtr = getStridedElementPtr(
        loadOp.getLoc(), type, adaptor.memref(), adaptor.indices(), rewriter);
    rewriter.replaceOpWithNewOp<LLVM::LoadOp>(loadOp, dataPtr);
    return success();
  }
};

// Store operation is lowered to obtaining a pointer to the indexed element,
// and storing the given value to it.
struct StoreOpLowering : public LoadStoreOpLowering<memref::StoreOp> {
  using Base::Base;

  LogicalResult
  matchAndRewrite(memref::StoreOp op, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    auto type = op.getMemRefType();

    Value dataPtr = getStridedElementPtr(op.getLoc(), type, adaptor.memref(),
                                         adaptor.indices(), rewriter);
    rewriter.replaceOpWithNewOp<LLVM::StoreOp>(op, adaptor.value(), dataPtr);
    return success();
  }
};

// The prefetch operation is lowered in a way similar to the load operation
// except that the llvm.prefetch operation is used for replacement.
struct PrefetchOpLowering : public LoadStoreOpLowering<memref::PrefetchOp> {
  using Base::Base;

  LogicalResult
  matchAndRewrite(memref::PrefetchOp prefetchOp, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    auto type = prefetchOp.getMemRefType();
    auto loc = prefetchOp.getLoc();

    Value dataPtr = getStridedElementPtr(loc, type, adaptor.memref(),
                                         adaptor.indices(), rewriter);

    // Replace with llvm.prefetch.
    auto llvmI32Type = typeConverter->convertType(rewriter.getIntegerType(32));
    auto isWrite = rewriter.create<LLVM::ConstantOp>(
        loc, llvmI32Type, rewriter.getI32IntegerAttr(prefetchOp.isWrite()));
    auto localityHint = rewriter.create<LLVM::ConstantOp>(
        loc, llvmI32Type,
        rewriter.getI32IntegerAttr(prefetchOp.localityHint()));
    auto isData = rewriter.create<LLVM::ConstantOp>(
        loc, llvmI32Type, rewriter.getI32IntegerAttr(prefetchOp.isDataCache()));

    rewriter.replaceOpWithNewOp<LLVM::Prefetch>(prefetchOp, dataPtr, isWrite,
                                                localityHint, isData);
    return success();
  }
};

struct RankOpLowering : public ConvertOpToLLVMPattern<memref::RankOp> {
  using ConvertOpToLLVMPattern<memref::RankOp>::ConvertOpToLLVMPattern;

  LogicalResult
  matchAndRewrite(memref::RankOp op, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    Location loc = op.getLoc();
    Type operandType = op.memref().getType();
    if (auto unrankedMemRefType = operandType.dyn_cast<UnrankedMemRefType>()) {
      UnrankedMemRefDescriptor desc(adaptor.memref());
      rewriter.replaceOp(op, {desc.rank(rewriter, loc)});
      return success();
    }
    if (auto rankedMemRefType = operandType.dyn_cast<MemRefType>()) {
      rewriter.replaceOp(
          op, {createIndexConstant(rewriter, loc, rankedMemRefType.getRank())});
      return success();
    }
    return failure();
  }
};

struct MemRefCastOpLowering : public ConvertOpToLLVMPattern<memref::CastOp> {
  using ConvertOpToLLVMPattern<memref::CastOp>::ConvertOpToLLVMPattern;

  LogicalResult match(memref::CastOp memRefCastOp) const override {
    Type srcType = memRefCastOp.getOperand().getType();
    Type dstType = memRefCastOp.getType();

    // memref::CastOp reduces to a bitcast in the ranked MemRef case and can be
    // used for type erasure. For now, the cast must preserve the underlying
    // element type and the source and result types must have the same rank.
    // Therefore, perform a sanity check that the underlying structs are the
    // same. Once op semantics are relaxed we can revisit.
    if (srcType.isa<MemRefType>() && dstType.isa<MemRefType>())
      return success(typeConverter->convertType(srcType) ==
                     typeConverter->convertType(dstType));

    // At least one of the operands is an unranked type.
    assert(srcType.isa<UnrankedMemRefType>() ||
           dstType.isa<UnrankedMemRefType>());

    // Unranked-to-unranked casts are disallowed.
    return !(srcType.isa<UnrankedMemRefType>() &&
             dstType.isa<UnrankedMemRefType>())
               ? success()
               : failure();
  }

  void rewrite(memref::CastOp memRefCastOp, OpAdaptor adaptor,
               ConversionPatternRewriter &rewriter) const override {
    auto srcType = memRefCastOp.getOperand().getType();
    auto dstType = memRefCastOp.getType();
    auto targetStructType = typeConverter->convertType(memRefCastOp.getType());
    auto loc = memRefCastOp.getLoc();

    // For the ranked/ranked case, just keep the original descriptor.
    if (srcType.isa<MemRefType>() && dstType.isa<MemRefType>())
      return rewriter.replaceOp(memRefCastOp, {adaptor.source()});

    if (srcType.isa<MemRefType>() && dstType.isa<UnrankedMemRefType>()) {
      // Casting a ranked to an unranked memref type:
      // - set the rank in the destination from the memref type;
      // - allocate space on the stack and copy the src memref descriptor;
      // - set the ptr in the destination to the stack space.
      auto srcMemRefType = srcType.cast<MemRefType>();
      int64_t rank = srcMemRefType.getRank();
      // ptr = AllocaOp sizeof(MemRefDescriptor)
      auto ptr = getTypeConverter()->promoteOneMemRefDescriptor(
          loc, adaptor.source(), rewriter);
      // voidptr = BitCastOp srcType* to void*
      auto voidPtr =
          rewriter.create<LLVM::BitcastOp>(loc, getVoidPtrType(), ptr)
              .getResult();
      // rank = ConstantOp srcRank
      auto rankVal = rewriter.create<LLVM::ConstantOp>(
          loc, getIndexType(), rewriter.getIndexAttr(rank));
      // undef = UndefOp
      UnrankedMemRefDescriptor memRefDesc =
          UnrankedMemRefDescriptor::undef(rewriter, loc, targetStructType);
      // d1 = InsertValueOp undef, rank, 0
      memRefDesc.setRank(rewriter, loc, rankVal);
      // d2 = InsertValueOp d1, voidptr, 1
      memRefDesc.setMemRefDescPtr(rewriter, loc, voidPtr);
      rewriter.replaceOp(memRefCastOp, (Value)memRefDesc);

    } else if (srcType.isa<UnrankedMemRefType>() &&
               dstType.isa<MemRefType>()) {
      // Casting from an unranked type to a ranked one. The operation is
      // assumed to be doing a correct cast: if the destination type
      // mismatches the unranked type, it is undefined behavior.
      UnrankedMemRefDescriptor memRefDesc(adaptor.source());
      // ptr = ExtractValueOp src, 1
      auto ptr = memRefDesc.memRefDescPtr(rewriter, loc);
      // castPtr = BitCastOp i8* to structTy*
      auto castPtr =
          rewriter
              .create<LLVM::BitcastOp>(
                  loc, LLVM::LLVMPointerType::get(targetStructType), ptr)
              .getResult();
      // struct = LoadOp castPtr
      auto loadOp = rewriter.create<LLVM::LoadOp>(loc, castPtr);
      rewriter.replaceOp(memRefCastOp, loadOp.getResult());
    } else {
      llvm_unreachable("Unsupported unranked memref to unranked memref cast");
    }
  }
};

/// Pattern to lower a `memref.copy` to llvm.
///
/// For memrefs with identity layouts, the copy is lowered to the llvm
/// `memcpy` intrinsic. For non-identity layouts, the copy is lowered to a call
/// to the generic `MemrefCopyFn`.
struct MemRefCopyOpLowering : public ConvertOpToLLVMPattern<memref::CopyOp> {
  using ConvertOpToLLVMPattern<memref::CopyOp>::ConvertOpToLLVMPattern;

  LogicalResult
  lowerToMemCopyIntrinsic(memref::CopyOp op, OpAdaptor adaptor,
                          ConversionPatternRewriter &rewriter) const {
    auto loc = op.getLoc();
    auto srcType = op.source().getType().dyn_cast<MemRefType>();

    MemRefDescriptor srcDesc(adaptor.source());

    // Compute number of elements.
    Value numElements = rewriter.create<LLVM::ConstantOp>(
        loc, getIndexType(), rewriter.getIndexAttr(1));
    for (int pos = 0; pos < srcType.getRank(); ++pos) {
      auto size = srcDesc.size(rewriter, loc, pos);
      numElements = rewriter.create<LLVM::MulOp>(loc, numElements, size);
    }

    // Get element size.
    auto sizeInBytes = getSizeInBytes(loc, srcType.getElementType(), rewriter);
    // Compute total.
    Value totalSize =
        rewriter.create<LLVM::MulOp>(loc, numElements, sizeInBytes);

    Value srcBasePtr = srcDesc.alignedPtr(rewriter, loc);
    Value srcOffset = srcDesc.offset(rewriter, loc);
    Value srcPtr = rewriter.create<LLVM::GEPOp>(loc, srcBasePtr.getType(),
                                                srcBasePtr, srcOffset);
    MemRefDescriptor targetDesc(adaptor.target());
    Value targetBasePtr = targetDesc.alignedPtr(rewriter, loc);
    Value targetOffset = targetDesc.offset(rewriter, loc);
    Value targetPtr = rewriter.create<LLVM::GEPOp>(loc, targetBasePtr.getType(),
                                                   targetBasePtr, targetOffset);
    Value isVolatile = rewriter.create<LLVM::ConstantOp>(
        loc, typeConverter->convertType(rewriter.getI1Type()),
        rewriter.getBoolAttr(false));
    rewriter.create<LLVM::MemcpyOp>(loc, targetPtr, srcPtr, totalSize,
                                    isVolatile);
    rewriter.eraseOp(op);

    return success();
  }

  LogicalResult
  lowerToMemCopyFunctionCall(memref::CopyOp op, OpAdaptor adaptor,
                             ConversionPatternRewriter &rewriter) const {
    auto loc = op.getLoc();
    auto srcType = op.source().getType().cast<BaseMemRefType>();
    auto targetType = op.target().getType().cast<BaseMemRefType>();

    // First make sure we have an unranked memref descriptor representation.
    auto makeUnranked = [&, this](Value ranked, BaseMemRefType type) {
      auto rank = rewriter.create<LLVM::ConstantOp>(
          loc, getIndexType(), rewriter.getIndexAttr(type.getRank()));
      auto *typeConverter = getTypeConverter();
      auto ptr =
          typeConverter->promoteOneMemRefDescriptor(loc, ranked, rewriter);
      auto voidPtr =
          rewriter.create<LLVM::BitcastOp>(loc, getVoidPtrType(), ptr)
              .getResult();
      auto unrankedType =
          UnrankedMemRefType::get(type.getElementType(), type.getMemorySpace());
      return UnrankedMemRefDescriptor::pack(rewriter, loc, *typeConverter,
                                            unrankedType,
                                            ValueRange{rank, voidPtr});
    };

    Value unrankedSource = srcType.hasRank()
                               ? makeUnranked(adaptor.source(), srcType)
                               : adaptor.source();
    Value unrankedTarget = targetType.hasRank()
                               ? makeUnranked(adaptor.target(), targetType)
                               : adaptor.target();

    // Now promote the unranked descriptors to the stack.
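    // The generic copy helper takes its memref arguments indirectly, so each
    // unranked descriptor is spilled into a one-element alloca and the
    // alloca's address is passed to the call emitted below.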
    auto one = rewriter.create<LLVM::ConstantOp>(loc, getIndexType(),
                                                 rewriter.getIndexAttr(1));
    auto promote = [&](Value desc) {
      auto ptrType = LLVM::LLVMPointerType::get(desc.getType());
      auto allocated =
          rewriter.create<LLVM::AllocaOp>(loc, ptrType, ValueRange{one});
      rewriter.create<LLVM::StoreOp>(loc, desc, allocated);
      return allocated;
    };

    auto sourcePtr = promote(unrankedSource);
    auto targetPtr = promote(unrankedTarget);

    unsigned typeSize =
        mlir::DataLayout::closest(op).getTypeSize(srcType.getElementType());
    auto elemSize = rewriter.create<LLVM::ConstantOp>(
        loc, getIndexType(), rewriter.getIndexAttr(typeSize));
    auto copyFn = LLVM::lookupOrCreateMemRefCopyFn(
        op->getParentOfType<ModuleOp>(), getIndexType(), sourcePtr.getType());
    rewriter.create<LLVM::CallOp>(loc, copyFn,
                                  ValueRange{elemSize, sourcePtr, targetPtr});
    rewriter.eraseOp(op);

    return success();
  }

  LogicalResult
  matchAndRewrite(memref::CopyOp op, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    auto srcType = op.source().getType().cast<BaseMemRefType>();
    auto targetType = op.target().getType().cast<BaseMemRefType>();

    auto isContiguousMemrefType = [](BaseMemRefType type) {
      auto memrefType = type.dyn_cast<mlir::MemRefType>();
      // We can use memcpy for memrefs if they have an identity layout or are
      // contiguous with an arbitrary offset. Ignore empty memrefs, which are a
      // special case handled by memrefCopy.
      return memrefType &&
             (memrefType.getLayout().isIdentity() ||
              (memrefType.hasStaticShape() && memrefType.getNumElements() > 0 &&
               isStaticShapeAndContiguousRowMajor(memrefType)));
    };

    if (isContiguousMemrefType(srcType) && isContiguousMemrefType(targetType))
      return lowerToMemCopyIntrinsic(op, adaptor, rewriter);

    return lowerToMemCopyFunctionCall(op, adaptor, rewriter);
  }
};

/// Extracts allocated, aligned pointers and offset from a ranked or unranked
/// memref type. In the unranked case, the fields are extracted from the
/// underlying ranked descriptor.
static void extractPointersAndOffset(Location loc,
                                     ConversionPatternRewriter &rewriter,
                                     LLVMTypeConverter &typeConverter,
                                     Value originalOperand,
                                     Value convertedOperand,
                                     Value *allocatedPtr, Value *alignedPtr,
                                     Value *offset = nullptr) {
  Type operandType = originalOperand.getType();
  if (operandType.isa<MemRefType>()) {
    MemRefDescriptor desc(convertedOperand);
    *allocatedPtr = desc.allocatedPtr(rewriter, loc);
    *alignedPtr = desc.alignedPtr(rewriter, loc);
    if (offset != nullptr)
      *offset = desc.offset(rewriter, loc);
    return;
  }

  unsigned memorySpace =
      operandType.cast<UnrankedMemRefType>().getMemorySpaceAsInt();
  Type elementType = operandType.cast<UnrankedMemRefType>().getElementType();
  Type llvmElementType = typeConverter.convertType(elementType);
  Type elementPtrPtrType = LLVM::LLVMPointerType::get(
      LLVM::LLVMPointerType::get(llvmElementType, memorySpace));

  // Extract the pointer to the underlying ranked memref descriptor and cast
  // it to ElemType**.
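  // An unranked descriptor holds only (rank, pointer-to-ranked-descriptor),
  // so every field access goes through the opaque pointer below; the i8* is
  // cast to ElemType** so that GEPs can address the pointer fields.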
  UnrankedMemRefDescriptor unrankedDesc(convertedOperand);
  Value underlyingDescPtr = unrankedDesc.memRefDescPtr(rewriter, loc);

  *allocatedPtr = UnrankedMemRefDescriptor::allocatedPtr(
      rewriter, loc, underlyingDescPtr, elementPtrPtrType);
  *alignedPtr = UnrankedMemRefDescriptor::alignedPtr(
      rewriter, loc, typeConverter, underlyingDescPtr, elementPtrPtrType);
  if (offset != nullptr) {
    *offset = UnrankedMemRefDescriptor::offset(
        rewriter, loc, typeConverter, underlyingDescPtr, elementPtrPtrType);
  }
}

struct MemRefReinterpretCastOpLowering
    : public ConvertOpToLLVMPattern<memref::ReinterpretCastOp> {
  using ConvertOpToLLVMPattern<
      memref::ReinterpretCastOp>::ConvertOpToLLVMPattern;

  LogicalResult
  matchAndRewrite(memref::ReinterpretCastOp castOp, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    Type srcType = castOp.source().getType();

    Value descriptor;
    if (failed(convertSourceMemRefToDescriptor(rewriter, srcType, castOp,
                                               adaptor, &descriptor)))
      return failure();
    rewriter.replaceOp(castOp, {descriptor});
    return success();
  }

private:
  LogicalResult convertSourceMemRefToDescriptor(
      ConversionPatternRewriter &rewriter, Type srcType,
      memref::ReinterpretCastOp castOp,
      memref::ReinterpretCastOp::Adaptor adaptor, Value *descriptor) const {
    MemRefType targetMemRefType =
        castOp.getResult().getType().cast<MemRefType>();
    auto llvmTargetDescriptorTy = typeConverter->convertType(targetMemRefType)
                                      .dyn_cast_or_null<LLVM::LLVMStructType>();
    if (!llvmTargetDescriptorTy)
      return failure();

    // Create descriptor.
    Location loc = castOp.getLoc();
    auto desc = MemRefDescriptor::undef(rewriter, loc, llvmTargetDescriptorTy);

    // Set allocated and aligned pointers.
    Value allocatedPtr, alignedPtr;
    extractPointersAndOffset(loc, rewriter, *getTypeConverter(),
                             castOp.source(), adaptor.source(), &allocatedPtr,
                             &alignedPtr);
    desc.setAllocatedPtr(rewriter, loc, allocatedPtr);
    desc.setAlignedPtr(rewriter, loc, alignedPtr);

    // Set offset.
    if (castOp.isDynamicOffset(0))
      desc.setOffset(rewriter, loc, adaptor.offsets()[0]);
    else
      desc.setConstantOffset(rewriter, loc, castOp.getStaticOffset(0));

    // Set sizes and strides.
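    // Dynamic sizes and strides are consumed from the adaptor's operand lists
    // in order; static ones are materialized as constants from the op's
    // attributes.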
    unsigned dynSizeId = 0;
    unsigned dynStrideId = 0;
    for (unsigned i = 0, e = targetMemRefType.getRank(); i < e; ++i) {
      if (castOp.isDynamicSize(i))
        desc.setSize(rewriter, loc, i, adaptor.sizes()[dynSizeId++]);
      else
        desc.setConstantSize(rewriter, loc, i, castOp.getStaticSize(i));

      if (castOp.isDynamicStride(i))
        desc.setStride(rewriter, loc, i, adaptor.strides()[dynStrideId++]);
      else
        desc.setConstantStride(rewriter, loc, i, castOp.getStaticStride(i));
    }
    *descriptor = desc;
    return success();
  }
};

struct MemRefReshapeOpLowering
    : public ConvertOpToLLVMPattern<memref::ReshapeOp> {
  using ConvertOpToLLVMPattern<memref::ReshapeOp>::ConvertOpToLLVMPattern;

  LogicalResult
  matchAndRewrite(memref::ReshapeOp reshapeOp, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    Type srcType = reshapeOp.source().getType();

    Value descriptor;
    if (failed(convertSourceMemRefToDescriptor(rewriter, srcType, reshapeOp,
                                               adaptor, &descriptor)))
      return failure();
    rewriter.replaceOp(reshapeOp, {descriptor});
    return success();
  }

private:
  LogicalResult
  convertSourceMemRefToDescriptor(ConversionPatternRewriter &rewriter,
                                  Type srcType, memref::ReshapeOp reshapeOp,
                                  memref::ReshapeOp::Adaptor adaptor,
                                  Value *descriptor) const {
    // Conversion for statically-known shape args is performed via
    // `memref.reinterpret_cast`.
    auto shapeMemRefType = reshapeOp.shape().getType().cast<MemRefType>();
    if (shapeMemRefType.hasStaticShape())
      return failure();

    // The shape is a rank-1 tensor with unknown length.
    Location loc = reshapeOp.getLoc();
    MemRefDescriptor shapeDesc(adaptor.shape());
    Value resultRank = shapeDesc.size(rewriter, loc, 0);

    // Extract address space and element type.
    auto targetType =
        reshapeOp.getResult().getType().cast<UnrankedMemRefType>();
    unsigned addressSpace = targetType.getMemorySpaceAsInt();
    Type elementType = targetType.getElementType();

    // Create the unranked memref descriptor that holds the ranked one. The
    // inner descriptor is allocated on the stack.
    auto targetDesc = UnrankedMemRefDescriptor::undef(
        rewriter, loc, typeConverter->convertType(targetType));
    targetDesc.setRank(rewriter, loc, resultRank);
    SmallVector<Value, 4> sizes;
    UnrankedMemRefDescriptor::computeSizes(rewriter, loc, *getTypeConverter(),
                                           targetDesc, sizes);
    Value underlyingDescPtr = rewriter.create<LLVM::AllocaOp>(
        loc, getVoidPtrType(), sizes.front(), llvm::None);
    targetDesc.setMemRefDescPtr(rewriter, loc, underlyingDescPtr);

    // Extract pointers and offset from the source memref.
    Value allocatedPtr, alignedPtr, offset;
    extractPointersAndOffset(loc, rewriter, *getTypeConverter(),
                             reshapeOp.source(), adaptor.source(),
                             &allocatedPtr, &alignedPtr, &offset);

    // Set pointers and offset.
    Type llvmElementType = typeConverter->convertType(elementType);
    auto elementPtrPtrType = LLVM::LLVMPointerType::get(
        LLVM::LLVMPointerType::get(llvmElementType, addressSpace));
    UnrankedMemRefDescriptor::setAllocatedPtr(rewriter, loc, underlyingDescPtr,
                                              elementPtrPtrType, allocatedPtr);
    UnrankedMemRefDescriptor::setAlignedPtr(rewriter, loc, *getTypeConverter(),
                                            underlyingDescPtr,
                                            elementPtrPtrType, alignedPtr);
    UnrankedMemRefDescriptor::setOffset(rewriter, loc, *getTypeConverter(),
                                        underlyingDescPtr, elementPtrPtrType,
                                        offset);

    // Use the offset pointer as base for further addressing. Copy over the
    // new shape and compute strides. For this, we create a loop from rank-1
    // to 0.
    Value targetSizesBase = UnrankedMemRefDescriptor::sizeBasePtr(
        rewriter, loc, *getTypeConverter(), underlyingDescPtr,
        elementPtrPtrType);
    Value targetStridesBase = UnrankedMemRefDescriptor::strideBasePtr(
        rewriter, loc, *getTypeConverter(), targetSizesBase, resultRank);
    Value shapeOperandPtr = shapeDesc.alignedPtr(rewriter, loc);
    Value oneIndex = createIndexConstant(rewriter, loc, 1);
    Value resultRankMinusOne =
        rewriter.create<LLVM::SubOp>(loc, resultRank, oneIndex);

    Block *initBlock = rewriter.getInsertionBlock();
    Type indexType = getTypeConverter()->getIndexType();
    Block::iterator remainingOpsIt = std::next(rewriter.getInsertionPoint());

    Block *condBlock = rewriter.createBlock(initBlock->getParent(), {},
                                            {indexType, indexType}, {loc, loc});

    // Move the remaining initBlock ops to condBlock.
    Block *remainingBlock = rewriter.splitBlock(initBlock, remainingOpsIt);
    rewriter.mergeBlocks(remainingBlock, condBlock, ValueRange());

    rewriter.setInsertionPointToEnd(initBlock);
    rewriter.create<LLVM::BrOp>(loc, ValueRange({resultRankMinusOne, oneIndex}),
                                condBlock);
    rewriter.setInsertionPointToStart(condBlock);
    Value indexArg = condBlock->getArgument(0);
    Value strideArg = condBlock->getArgument(1);

    Value zeroIndex = createIndexConstant(rewriter, loc, 0);
    Value pred = rewriter.create<LLVM::ICmpOp>(
        loc, IntegerType::get(rewriter.getContext(), 1),
        LLVM::ICmpPredicate::sge, indexArg, zeroIndex);

    Block *bodyBlock =
        rewriter.splitBlock(condBlock, rewriter.getInsertionPoint());
    rewriter.setInsertionPointToStart(bodyBlock);

    // Copy size from shape to descriptor.
    Type llvmIndexPtrType = LLVM::LLVMPointerType::get(indexType);
    Value sizeLoadGep = rewriter.create<LLVM::GEPOp>(
        loc, llvmIndexPtrType, shapeOperandPtr, ValueRange{indexArg});
    Value size = rewriter.create<LLVM::LoadOp>(loc, sizeLoadGep);
    UnrankedMemRefDescriptor::setSize(rewriter, loc, *getTypeConverter(),
                                      targetSizesBase, indexArg, size);

    // Write stride value and compute next one.
    UnrankedMemRefDescriptor::setStride(rewriter, loc, *getTypeConverter(),
                                        targetStridesBase, indexArg, strideArg);
    Value nextStride = rewriter.create<LLVM::MulOp>(loc, strideArg, size);

    // Decrement loop counter and branch back.
    Value decrement = rewriter.create<LLVM::SubOp>(loc, indexArg, oneIndex);
    rewriter.create<LLVM::BrOp>(loc, ValueRange({decrement, nextStride}),
                                condBlock);

    Block *remainder =
        rewriter.splitBlock(bodyBlock, rewriter.getInsertionPoint());

    // Hook up the cond exit to the remainder.
    rewriter.setInsertionPointToEnd(condBlock);
    rewriter.create<LLVM::CondBrOp>(loc, pred, bodyBlock, llvm::None, remainder,
                                    llvm::None);

    // Reset position to the beginning of the new remainder block.
    rewriter.setInsertionPointToStart(remainder);

    *descriptor = targetDesc;
    return success();
  }
};

/// Helper function to convert a vector of `OpFoldResult`s into a vector of
/// `Value`s.
static SmallVector<Value> getAsValues(OpBuilder &b, Location loc,
                                      Type &llvmIndexType,
                                      ArrayRef<OpFoldResult> valueOrAttrVec) {
  return llvm::to_vector<4>(
      llvm::map_range(valueOrAttrVec, [&](OpFoldResult value) -> Value {
        if (auto attr = value.dyn_cast<Attribute>())
          return b.create<LLVM::ConstantOp>(loc, llvmIndexType, attr);
        return value.get<Value>();
      }));
}

/// Compute a map that, for a given dimension of the expanded type, gives the
/// dimension in the collapsed type it maps to. Essentially it's the inverse
/// of the `reassociation` maps.
static DenseMap<int64_t, int64_t>
getExpandedDimToCollapsedDimMap(ArrayRef<ReassociationIndices> reassociation) {
  llvm::DenseMap<int64_t, int64_t> expandedDimToCollapsedDim;
  for (auto &en : enumerate(reassociation)) {
    for (auto dim : en.value())
      expandedDimToCollapsedDim[dim] = en.index();
  }
  return expandedDimToCollapsedDim;
}

static OpFoldResult
getExpandedOutputDimSize(OpBuilder &b, Location loc, Type &llvmIndexType,
                         int64_t outDimIndex, ArrayRef<int64_t> outStaticShape,
                         MemRefDescriptor &inDesc,
                         ArrayRef<int64_t> inStaticShape,
                         ArrayRef<ReassociationIndices> reassociation,
                         DenseMap<int64_t, int64_t> &outDimToInDimMap) {
  int64_t outDimSize = outStaticShape[outDimIndex];
  if (!ShapedType::isDynamic(outDimSize))
    return b.getIndexAttr(outDimSize);

  // Calculate the product of all the output dim sizes except the current dim.
  int64_t inDimIndex = outDimToInDimMap[outDimIndex];
  int64_t otherDimSizesMul = 1;
  for (auto otherDimIndex : reassociation[inDimIndex]) {
    if (otherDimIndex == static_cast<unsigned>(outDimIndex))
      continue;
    int64_t otherDimSize = outStaticShape[otherDimIndex];
    assert(!ShapedType::isDynamic(otherDimSize) &&
           "single dimension cannot be expanded into multiple dynamic "
           "dimensions");
    otherDimSizesMul *= otherDimSize;
  }

  // outDimSize = inDimSize / otherDimSizesMul
  int64_t inDimSize = inStaticShape[inDimIndex];
  Value inDimSizeDynamic =
      ShapedType::isDynamic(inDimSize)
          ? inDesc.size(b, loc, inDimIndex)
          : b.create<LLVM::ConstantOp>(loc, llvmIndexType,
                                       b.getIndexAttr(inDimSize));
  Value outDimSizeDynamic = b.create<LLVM::SDivOp>(
      loc, inDimSizeDynamic,
      b.create<LLVM::ConstantOp>(loc, llvmIndexType,
                                 b.getIndexAttr(otherDimSizesMul)));
  return outDimSizeDynamic;
}

static OpFoldResult getCollapsedOutputDimSize(
    OpBuilder &b, Location loc, Type &llvmIndexType, int64_t outDimIndex,
    int64_t outDimSize, ArrayRef<int64_t> inStaticShape,
    MemRefDescriptor &inDesc, ArrayRef<ReassociationIndices> reassociation) {
  if (!ShapedType::isDynamic(outDimSize))
    return b.getIndexAttr(outDimSize);

  Value c1 = b.create<LLVM::ConstantOp>(loc, llvmIndexType, b.getIndexAttr(1));
  Value outDimSizeDynamic = c1;
  for (auto inDimIndex : reassociation[outDimIndex]) {
    int64_t inDimSize = inStaticShape[inDimIndex];
    Value inDimSizeDynamic =
        ShapedType::isDynamic(inDimSize)
            ? inDesc.size(b, loc, inDimIndex)
            : b.create<LLVM::ConstantOp>(loc, llvmIndexType,
                                         b.getIndexAttr(inDimSize));
    outDimSizeDynamic =
        b.create<LLVM::MulOp>(loc, outDimSizeDynamic, inDimSizeDynamic);
  }
  return outDimSizeDynamic;
}

static SmallVector<OpFoldResult, 4>
getCollapsedOutputShape(OpBuilder &b, Location loc, Type &llvmIndexType,
                        ArrayRef<ReassociationIndices> reassociation,
                        ArrayRef<int64_t> inStaticShape,
                        MemRefDescriptor &inDesc,
                        ArrayRef<int64_t> outStaticShape) {
  return llvm::to_vector<4>(llvm::map_range(
      llvm::seq<int64_t>(0, outStaticShape.size()), [&](int64_t outDimIndex) {
        return getCollapsedOutputDimSize(b, loc, llvmIndexType, outDimIndex,
                                         outStaticShape[outDimIndex],
                                         inStaticShape, inDesc, reassociation);
      }));
}

static SmallVector<OpFoldResult, 4>
getExpandedOutputShape(OpBuilder &b, Location loc, Type &llvmIndexType,
                       ArrayRef<ReassociationIndices> reassociation,
                       ArrayRef<int64_t> inStaticShape,
                       MemRefDescriptor &inDesc,
                       ArrayRef<int64_t> outStaticShape) {
  DenseMap<int64_t, int64_t> outDimToInDimMap =
      getExpandedDimToCollapsedDimMap(reassociation);
  return llvm::to_vector<4>(llvm::map_range(
      llvm::seq<int64_t>(0, outStaticShape.size()), [&](int64_t outDimIndex) {
        return getExpandedOutputDimSize(b, loc, llvmIndexType, outDimIndex,
                                        outStaticShape, inDesc, inStaticShape,
                                        reassociation, outDimToInDimMap);
      }));
}

static SmallVector<Value>
getDynamicOutputShape(OpBuilder &b, Location loc, Type &llvmIndexType,
                      ArrayRef<ReassociationIndices> reassociation,
                      ArrayRef<int64_t> inStaticShape,
                      MemRefDescriptor &inDesc,
                      ArrayRef<int64_t> outStaticShape) {
  return outStaticShape.size() < inStaticShape.size()
             ? getAsValues(b, loc, llvmIndexType,
                           getCollapsedOutputShape(b, loc, llvmIndexType,
                                                   reassociation,
                                                   inStaticShape, inDesc,
                                                   outStaticShape))
             : getAsValues(b, loc, llvmIndexType,
                           getExpandedOutputShape(b, loc, llvmIndexType,
                                                  reassociation, inStaticShape,
                                                  inDesc, outStaticShape));
}

// ReshapeOp creates a new view descriptor of the proper rank. For now, the
// only conversion supported is for target MemRefs with static sizes and
// strides.
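// For illustration: collapsing memref<4x?xf32> into memref<?xf32> with
// reassociation [[0, 1]] multiplies the (possibly dynamic) source sizes,
// while expanding memref<?xf32> into memref<4x?xf32> divides the dynamic
// source size by the product of the static output sizes (4 here); at most
// one output dimension per group may be dynamic.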
template <typename ReshapeOp>
class ReassociatingReshapeOpConversion
    : public ConvertOpToLLVMPattern<ReshapeOp> {
public:
  using ConvertOpToLLVMPattern<ReshapeOp>::ConvertOpToLLVMPattern;
  using ReshapeOpAdaptor = typename ReshapeOp::Adaptor;

  LogicalResult
  matchAndRewrite(ReshapeOp reshapeOp, typename ReshapeOp::Adaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    MemRefType dstType = reshapeOp.getResultType();
    MemRefType srcType = reshapeOp.getSrcType();

    // The condition on the layouts can be ignored when all shapes are static.
    if (!srcType.hasStaticShape() || !dstType.hasStaticShape()) {
      if (!srcType.getLayout().isIdentity() ||
          !dstType.getLayout().isIdentity()) {
        return rewriter.notifyMatchFailure(
            reshapeOp, "only empty layout map is supported");
      }
    }

    int64_t offset;
    SmallVector<int64_t, 4> strides;
    if (failed(getStridesAndOffset(dstType, strides, offset))) {
      return rewriter.notifyMatchFailure(
          reshapeOp, "failed to get stride and offset exprs");
    }

    MemRefDescriptor srcDesc(adaptor.src());
    Location loc = reshapeOp->getLoc();
    auto dstDesc = MemRefDescriptor::undef(
        rewriter, loc, this->typeConverter->convertType(dstType));
    dstDesc.setAllocatedPtr(rewriter, loc, srcDesc.allocatedPtr(rewriter, loc));
    dstDesc.setAlignedPtr(rewriter, loc, srcDesc.alignedPtr(rewriter, loc));
    dstDesc.setOffset(rewriter, loc, srcDesc.offset(rewriter, loc));

    ArrayRef<int64_t> srcStaticShape = srcType.getShape();
    ArrayRef<int64_t> dstStaticShape = dstType.getShape();
    Type llvmIndexType =
        this->typeConverter->convertType(rewriter.getIndexType());
    SmallVector<Value> dstShape = getDynamicOutputShape(
        rewriter, loc, llvmIndexType, reshapeOp.getReassociationIndices(),
        srcStaticShape, srcDesc, dstStaticShape);
    for (auto &en : llvm::enumerate(dstShape))
      dstDesc.setSize(rewriter, loc, en.index(), en.value());

    auto isStaticStride = [](int64_t stride) {
      return !ShapedType::isDynamicStrideOrOffset(stride);
    };
    if (llvm::all_of(strides, isStaticStride)) {
      for (auto &en : llvm::enumerate(strides))
        dstDesc.setConstantStride(rewriter, loc, en.index(), en.value());
    } else {
      Value c1 = rewriter.create<LLVM::ConstantOp>(loc, llvmIndexType,
                                                   rewriter.getIndexAttr(1));
      Value stride = c1;
      for (auto dimIndex :
           llvm::reverse(llvm::seq<int64_t>(0, dstShape.size()))) {
        dstDesc.setStride(rewriter, loc, dimIndex, stride);
        stride = rewriter.create<LLVM::MulOp>(loc, dstShape[dimIndex], stride);
      }
    }
    rewriter.replaceOp(reshapeOp, {dstDesc});
    return success();
  }
};

/// Conversion pattern that transforms a subview op into:
///    1. An `llvm.mlir.undef` operation to create a memref descriptor,
///    2. Updates to the descriptor to introduce the data ptr, offset, size
///       and stride.
/// The subview op is replaced by the descriptor.
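/// For illustration (a sketch, not verbatim output): for
///   %1 = memref.subview %0[%off, 0] [4, %sz] [1, 1] : memref<8x16xf32> to ...
/// the loops below compute, in the new descriptor,
///   offset  = src_offset + %off * src_stride0 + 0 * src_stride1,
///   sizes   = (4, %sz),
///   strides = (1 * src_stride0, 1 * src_stride1).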
struct SubViewOpLowering : public ConvertOpToLLVMPattern<memref::SubViewOp> {
  using ConvertOpToLLVMPattern<memref::SubViewOp>::ConvertOpToLLVMPattern;

  LogicalResult
  matchAndRewrite(memref::SubViewOp subViewOp, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    auto loc = subViewOp.getLoc();

    auto sourceMemRefType = subViewOp.source().getType().cast<MemRefType>();
    auto sourceElementTy =
        typeConverter->convertType(sourceMemRefType.getElementType());

    auto viewMemRefType = subViewOp.getType();
    auto inferredType =
        memref::SubViewOp::inferResultType(
            subViewOp.getSourceType(),
            extractFromI64ArrayAttr(subViewOp.static_offsets()),
            extractFromI64ArrayAttr(subViewOp.static_sizes()),
            extractFromI64ArrayAttr(subViewOp.static_strides()))
            .cast<MemRefType>();
    auto targetElementTy =
        typeConverter->convertType(viewMemRefType.getElementType());
    auto targetDescTy = typeConverter->convertType(viewMemRefType);
    if (!sourceElementTy || !targetDescTy || !targetElementTy ||
        !LLVM::isCompatibleType(sourceElementTy) ||
        !LLVM::isCompatibleType(targetElementTy) ||
        !LLVM::isCompatibleType(targetDescTy))
      return failure();

    // Extract the offset and strides from the type.
    int64_t offset;
    SmallVector<int64_t, 4> strides;
    auto successStrides = getStridesAndOffset(inferredType, strides, offset);
    if (failed(successStrides))
      return failure();

    // Create the descriptor.
    if (!LLVM::isCompatibleType(adaptor.getOperands().front().getType()))
      return failure();
    MemRefDescriptor sourceMemRef(adaptor.getOperands().front());
    auto targetMemRef = MemRefDescriptor::undef(rewriter, loc, targetDescTy);

    // Copy the buffer pointer from the old descriptor to the new one.
    Value extracted = sourceMemRef.allocatedPtr(rewriter, loc);
    Value bitcastPtr = rewriter.create<LLVM::BitcastOp>(
        loc,
        LLVM::LLVMPointerType::get(targetElementTy,
                                   viewMemRefType.getMemorySpaceAsInt()),
        extracted);
    targetMemRef.setAllocatedPtr(rewriter, loc, bitcastPtr);

    // Copy the aligned pointer from the old descriptor to the new one.
    extracted = sourceMemRef.alignedPtr(rewriter, loc);
    bitcastPtr = rewriter.create<LLVM::BitcastOp>(
        loc,
        LLVM::LLVMPointerType::get(targetElementTy,
                                   viewMemRefType.getMemorySpaceAsInt()),
        extracted);
    targetMemRef.setAlignedPtr(rewriter, loc, bitcastPtr);

    size_t inferredShapeRank = inferredType.getRank();
    size_t resultShapeRank = viewMemRefType.getRank();

    // Extract strides needed to compute offset.
    SmallVector<Value, 4> strideValues;
    strideValues.reserve(inferredShapeRank);
    for (unsigned i = 0; i < inferredShapeRank; ++i)
      strideValues.push_back(sourceMemRef.stride(rewriter, loc, i));

    // Offset.
    auto llvmIndexType = typeConverter->convertType(rewriter.getIndexType());
    if (!ShapedType::isDynamicStrideOrOffset(offset)) {
      targetMemRef.setConstantOffset(rewriter, loc, offset);
    } else {
      Value baseOffset = sourceMemRef.offset(rewriter, loc);
      // `inferredShapeRank` may be larger than the number of offset operands
      // because of trailing semantics. In this case, the offset is guaranteed
      // to be interpreted as 0 and we can just skip the extra dimensions.
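      //
      // As a hypothetical illustration: with source strides [%st0, %st1] and
      // mixed offsets [%o0, 2], the loop below computes
      //   offset = baseOffset + %o0 * %st0 + 2 * %st1.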
      for (unsigned i = 0, e = std::min(inferredShapeRank,
                                        subViewOp.getMixedOffsets().size());
           i < e; ++i) {
        Value offset =
            // TODO: need OpFoldResult ODS adaptor to clean this up.
            subViewOp.isDynamicOffset(i)
                ? adaptor.getOperands()[subViewOp.getIndexOfDynamicOffset(i)]
                : rewriter.create<LLVM::ConstantOp>(
                      loc, llvmIndexType,
                      rewriter.getI64IntegerAttr(subViewOp.getStaticOffset(i)));
        Value mul = rewriter.create<LLVM::MulOp>(loc, offset, strideValues[i]);
        baseOffset = rewriter.create<LLVM::AddOp>(loc, baseOffset, mul);
      }
      targetMemRef.setOffset(rewriter, loc, baseOffset);
    }

    // Update sizes and strides.
    SmallVector<OpFoldResult> mixedSizes = subViewOp.getMixedSizes();
    SmallVector<OpFoldResult> mixedStrides = subViewOp.getMixedStrides();
    assert(mixedSizes.size() == mixedStrides.size() &&
           "expected sizes and strides of equal length");
    llvm::SmallBitVector unusedDims = subViewOp.getDroppedDims();
    for (int i = inferredShapeRank - 1, j = resultShapeRank - 1;
         i >= 0 && j >= 0; --i) {
      if (unusedDims.test(i))
        continue;

      // `i` may exceed the number of entries in subViewOp.getMixedSizes()
      // because of trailing semantics. In this case, the size is guaranteed
      // to be the full source dimension and the stride to be 1.
      Value size, stride;
      if (static_cast<unsigned>(i) >= mixedSizes.size()) {
        // If the static size is available, use it directly. This is similar
        // to the folding of dim(constant-op), but removes the need for dim to
        // be aware of LLVM constants and for this pass to be aware of std
        // constants.
        int64_t staticSize =
            subViewOp.source().getType().cast<MemRefType>().getShape()[i];
        if (staticSize != ShapedType::kDynamicSize) {
          size = rewriter.create<LLVM::ConstantOp>(
              loc, llvmIndexType, rewriter.getI64IntegerAttr(staticSize));
        } else {
          Value pos = rewriter.create<LLVM::ConstantOp>(
              loc, llvmIndexType, rewriter.getI64IntegerAttr(i));
          Value dim =
              rewriter.create<memref::DimOp>(loc, subViewOp.source(), pos);
          auto cast = rewriter.create<UnrealizedConversionCastOp>(
              loc, llvmIndexType, dim);
          size = cast.getResult(0);
        }
        stride = rewriter.create<LLVM::ConstantOp>(
            loc, llvmIndexType, rewriter.getI64IntegerAttr(1));
      } else {
        // TODO: need OpFoldResult ODS adaptor to clean this up.
        size =
            subViewOp.isDynamicSize(i)
                ? adaptor.getOperands()[subViewOp.getIndexOfDynamicSize(i)]
                : rewriter.create<LLVM::ConstantOp>(
                      loc, llvmIndexType,
                      rewriter.getI64IntegerAttr(subViewOp.getStaticSize(i)));
        if (!ShapedType::isDynamicStrideOrOffset(strides[i])) {
          stride = rewriter.create<LLVM::ConstantOp>(
              loc, llvmIndexType, rewriter.getI64IntegerAttr(strides[i]));
        } else {
          stride =
              subViewOp.isDynamicStride(i)
                  ? adaptor.getOperands()[subViewOp.getIndexOfDynamicStride(i)]
                  : rewriter.create<LLVM::ConstantOp>(
                        loc, llvmIndexType,
                        rewriter.getI64IntegerAttr(
                            subViewOp.getStaticStride(i)));
          stride = rewriter.create<LLVM::MulOp>(loc, stride, strideValues[i]);
        }
      }
      targetMemRef.setSize(rewriter, loc, j, size);
      targetMemRef.setStride(rewriter, loc, j, stride);
      j--;
    }

    rewriter.replaceOp(subViewOp, {targetMemRef});
    return success();
  }
};

/// Conversion pattern that transforms a transpose op into:
/// 1. An `llvm.mlir.undef` operation to create a memref descriptor.
/// 2. Updates to the descriptor to introduce the data ptr, offset, and the
///    permuted sizes and strides.
/// The transpose op is replaced by the descriptor.
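///
/// Illustrative example with hypothetical types (result layout elided):
///   %t = memref.transpose %m (i, j) -> (j, i) : memref<?x?xf32> to ...
/// reuses the pointers and offset of the descriptor of %m and swaps the two
/// sizes and the two strides when filling in the new descriptor.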
class TransposeOpLowering : public ConvertOpToLLVMPattern<memref::TransposeOp> {
public:
  using ConvertOpToLLVMPattern<memref::TransposeOp>::ConvertOpToLLVMPattern;

  LogicalResult
  matchAndRewrite(memref::TransposeOp transposeOp, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    auto loc = transposeOp.getLoc();
    MemRefDescriptor viewMemRef(adaptor.in());

    // No permutation, early exit.
    if (transposeOp.permutation().isIdentity())
      return rewriter.replaceOp(transposeOp, {viewMemRef}), success();

    auto targetMemRef = MemRefDescriptor::undef(
        rewriter, loc, typeConverter->convertType(transposeOp.getShapedType()));

    // Copy the base and aligned pointers from the old descriptor to the new
    // one.
    targetMemRef.setAllocatedPtr(rewriter, loc,
                                 viewMemRef.allocatedPtr(rewriter, loc));
    targetMemRef.setAlignedPtr(rewriter, loc,
                               viewMemRef.alignedPtr(rewriter, loc));

    // Copy the offset from the old descriptor to the new one.
    targetMemRef.setOffset(rewriter, loc, viewMemRef.offset(rewriter, loc));

    // Iterate over the dimensions and apply the size/stride permutation.
    for (const auto &en :
         llvm::enumerate(transposeOp.permutation().getResults())) {
      int sourcePos = en.index();
      int targetPos = en.value().cast<AffineDimExpr>().getPosition();
      targetMemRef.setSize(rewriter, loc, targetPos,
                           viewMemRef.size(rewriter, loc, sourcePos));
      targetMemRef.setStride(rewriter, loc, targetPos,
                             viewMemRef.stride(rewriter, loc, sourcePos));
    }

    rewriter.replaceOp(transposeOp, {targetMemRef});
    return success();
  }
};

/// Conversion pattern that transforms a view op into:
/// 1. An `llvm.mlir.undef` operation to create a memref descriptor.
/// 2. Updates to the descriptor to introduce the data ptr, offset, size
///    and stride.
/// The view op is replaced by the descriptor.
struct ViewOpLowering : public ConvertOpToLLVMPattern<memref::ViewOp> {
  using ConvertOpToLLVMPattern<memref::ViewOp>::ConvertOpToLLVMPattern;

  // Build and return the value for the idx^th shape dimension, either by
  // returning the constant shape dimension or by selecting the proper dynamic
  // size operand.
  Value getSize(ConversionPatternRewriter &rewriter, Location loc,
                ArrayRef<int64_t> shape, ValueRange dynamicSizes,
                unsigned idx) const {
    assert(idx < shape.size());
    if (!ShapedType::isDynamic(shape[idx]))
      return createIndexConstant(rewriter, loc, shape[idx]);
    // Count the number of dynamic dims in range [0, idx).
    unsigned nDynamic = llvm::count_if(shape.take_front(idx), [](int64_t v) {
      return ShapedType::isDynamic(v);
    });
    return dynamicSizes[nDynamic];
  }

  // Build and return the idx^th stride, either by returning the constant
  // stride or by computing the dynamic stride from the current
  // `runningStride` and `nextSize`. The caller should keep a running stride
  // and update it with the result returned by this function.
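  //
  // Hypothetical illustration: for a contiguous memref<4x?xf32> result, the
  // innermost stride is the constant 1, while the outer stride is dynamic and
  // is materialized as runningStride * nextSize, i.e. 1 * size(1).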
  Value getStride(ConversionPatternRewriter &rewriter, Location loc,
                  ArrayRef<int64_t> strides, Value nextSize,
                  Value runningStride, unsigned idx) const {
    assert(idx < strides.size());
    if (!ShapedType::isDynamicStrideOrOffset(strides[idx]))
      return createIndexConstant(rewriter, loc, strides[idx]);
    if (nextSize)
      return runningStride
                 ? rewriter.create<LLVM::MulOp>(loc, runningStride, nextSize)
                 : nextSize;
    assert(!runningStride);
    return createIndexConstant(rewriter, loc, 1);
  }

  LogicalResult
  matchAndRewrite(memref::ViewOp viewOp, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    auto loc = viewOp.getLoc();

    auto viewMemRefType = viewOp.getType();
    auto targetElementTy =
        typeConverter->convertType(viewMemRefType.getElementType());
    auto targetDescTy = typeConverter->convertType(viewMemRefType);
    if (!targetDescTy || !targetElementTy ||
        !LLVM::isCompatibleType(targetElementTy) ||
        !LLVM::isCompatibleType(targetDescTy))
      return viewOp.emitWarning("target descriptor type not converted to LLVM"),
             failure();

    int64_t offset;
    SmallVector<int64_t, 4> strides;
    auto successStrides = getStridesAndOffset(viewMemRefType, strides, offset);
    if (failed(successStrides))
      return viewOp.emitWarning("cannot cast to non-strided shape"), failure();
    assert(offset == 0 && "expected offset to be 0");

    // Create the descriptor.
    MemRefDescriptor sourceMemRef(adaptor.source());
    auto targetMemRef = MemRefDescriptor::undef(rewriter, loc, targetDescTy);

    // Field 1: Copy the allocated pointer, used for malloc/free.
    Value allocatedPtr = sourceMemRef.allocatedPtr(rewriter, loc);
    auto srcMemRefType = viewOp.source().getType().cast<MemRefType>();
    Value bitcastPtr = rewriter.create<LLVM::BitcastOp>(
        loc,
        LLVM::LLVMPointerType::get(targetElementTy,
                                   srcMemRefType.getMemorySpaceAsInt()),
        allocatedPtr);
    targetMemRef.setAllocatedPtr(rewriter, loc, bitcastPtr);

    // Field 2: Copy the actual aligned pointer to the payload, shifted by
    // `byte_shift`.
    Value alignedPtr = sourceMemRef.alignedPtr(rewriter, loc);
    alignedPtr = rewriter.create<LLVM::GEPOp>(loc, alignedPtr.getType(),
                                              alignedPtr, adaptor.byte_shift());
    bitcastPtr = rewriter.create<LLVM::BitcastOp>(
        loc,
        LLVM::LLVMPointerType::get(targetElementTy,
                                   srcMemRefType.getMemorySpaceAsInt()),
        alignedPtr);
    targetMemRef.setAlignedPtr(rewriter, loc, bitcastPtr);

    // Field 3: The offset in the resulting type must be 0. This is because of
    // the type change: an offset on srcType* may not be expressible as an
    // offset on dstType*.
    targetMemRef.setOffset(rewriter, loc,
                           createIndexConstant(rewriter, loc, offset));

    // Early exit for the 0-D corner case.
    if (viewMemRefType.getRank() == 0)
      return rewriter.replaceOp(viewOp, {targetMemRef}), success();

    // Fields 4 and 5: Update the sizes and strides.
    if (strides.back() != 1)
      return viewOp.emitWarning("cannot cast to non-contiguous shape"),
             failure();
    Value stride = nullptr, nextSize = nullptr;
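    // Illustrative trace with a hypothetical type: for memref<4x?x8xf32>, the
    // loop below produces sizes [4, %s1, 8] and strides [8 * %s1, 8, 1],
    // where %s1 denotes the dynamic size operand.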
    for (int i = viewMemRefType.getRank() - 1; i >= 0; --i) {
      // Update the size.
      Value size =
          getSize(rewriter, loc, viewMemRefType.getShape(), adaptor.sizes(), i);
      targetMemRef.setSize(rewriter, loc, i, size);
      // Update the stride.
      stride = getStride(rewriter, loc, strides, nextSize, stride, i);
      targetMemRef.setStride(rewriter, loc, i, stride);
      nextSize = size;
    }

    rewriter.replaceOp(viewOp, {targetMemRef});
    return success();
  }
};

//===----------------------------------------------------------------------===//
// AtomicRMWOpLowering
//===----------------------------------------------------------------------===//

/// Try to match the kind of a memref.atomic_rmw to determine whether to use a
/// lowering to llvm.atomicrmw or to fall back to llvm.cmpxchg.
static Optional<LLVM::AtomicBinOp>
matchSimpleAtomicOp(memref::AtomicRMWOp atomicOp) {
  switch (atomicOp.kind()) {
  case arith::AtomicRMWKind::addf:
    return LLVM::AtomicBinOp::fadd;
  case arith::AtomicRMWKind::addi:
    return LLVM::AtomicBinOp::add;
  case arith::AtomicRMWKind::assign:
    return LLVM::AtomicBinOp::xchg;
  case arith::AtomicRMWKind::maxs:
    return LLVM::AtomicBinOp::max;
  case arith::AtomicRMWKind::maxu:
    return LLVM::AtomicBinOp::umax;
  case arith::AtomicRMWKind::mins:
    return LLVM::AtomicBinOp::min;
  case arith::AtomicRMWKind::minu:
    return LLVM::AtomicBinOp::umin;
  case arith::AtomicRMWKind::ori:
    return LLVM::AtomicBinOp::_or;
  case arith::AtomicRMWKind::andi:
    return LLVM::AtomicBinOp::_and;
  default:
    return llvm::None;
  }
}

struct AtomicRMWOpLowering : public LoadStoreOpLowering<memref::AtomicRMWOp> {
  using Base::Base;

  LogicalResult
  matchAndRewrite(memref::AtomicRMWOp atomicOp, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    if (failed(match(atomicOp)))
      return failure();
    auto maybeKind = matchSimpleAtomicOp(atomicOp);
    if (!maybeKind)
      return failure();
    auto resultType = adaptor.value().getType();
    auto memRefType = atomicOp.getMemRefType();
    auto dataPtr =
        getStridedElementPtr(atomicOp.getLoc(), memRefType, adaptor.memref(),
                             adaptor.indices(), rewriter);
    rewriter.replaceOpWithNewOp<LLVM::AtomicRMWOp>(
        atomicOp, resultType, *maybeKind, dataPtr, adaptor.value(),
        LLVM::AtomicOrdering::acq_rel);
    return success();
  }
};

} // namespace

void mlir::populateMemRefToLLVMConversionPatterns(LLVMTypeConverter &converter,
                                                  RewritePatternSet &patterns) {
  // clang-format off
  patterns.add<
      AllocaOpLowering,
      AllocaScopeOpLowering,
      AtomicRMWOpLowering,
      AssumeAlignmentOpLowering,
      DimOpLowering,
      GenericAtomicRMWOpLowering,
      GlobalMemrefOpLowering,
      GetGlobalMemrefOpLowering,
      LoadOpLowering,
      MemRefCastOpLowering,
      MemRefCopyOpLowering,
      MemRefReinterpretCastOpLowering,
      MemRefReshapeOpLowering,
      PrefetchOpLowering,
      RankOpLowering,
      ReassociatingReshapeOpConversion<memref::ExpandShapeOp>,
      ReassociatingReshapeOpConversion<memref::CollapseShapeOp>,
      StoreOpLowering,
      SubViewOpLowering,
      TransposeOpLowering,
      ViewOpLowering>(converter);
  // clang-format on
  auto allocLowering = converter.getOptions().allocLowering;
  if (allocLowering == LowerToLLVMOptions::AllocLowering::AlignedAlloc)
    patterns.add<AlignedAllocOpLowering, DeallocOpLowering>(converter);
  else if (allocLowering == LowerToLLVMOptions::AllocLowering::Malloc)
    patterns.add<AllocOpLowering, DeallocOpLowering>(converter);
}
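// A minimal usage sketch (mirroring the pass below): clients may reuse these
// patterns in their own conversions, e.g.
//   LLVMTypeConverter typeConverter(ctx, options);
//   RewritePatternSet patterns(ctx);
//   populateMemRefToLLVMConversionPatterns(typeConverter, patterns);
//   LLVMConversionTarget target(*ctx);
//   (void)applyPartialConversion(op, target, std::move(patterns));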
namespace {
struct MemRefToLLVMPass : public ConvertMemRefToLLVMBase<MemRefToLLVMPass> {
  MemRefToLLVMPass() = default;

  void runOnOperation() override {
    Operation *op = getOperation();
    const auto &dataLayoutAnalysis = getAnalysis<DataLayoutAnalysis>();
    LowerToLLVMOptions options(&getContext(),
                               dataLayoutAnalysis.getAtOrAbove(op));
    options.allocLowering =
        (useAlignedAlloc ? LowerToLLVMOptions::AllocLowering::AlignedAlloc
                         : LowerToLLVMOptions::AllocLowering::Malloc);
    if (indexBitwidth != kDeriveIndexBitwidthFromDataLayout)
      options.overrideIndexBitwidth(indexBitwidth);

    LLVMTypeConverter typeConverter(&getContext(), options,
                                    &dataLayoutAnalysis);
    RewritePatternSet patterns(&getContext());
    populateMemRefToLLVMConversionPatterns(typeConverter, patterns);
    LLVMConversionTarget target(getContext());
    target.addLegalOp<func::FuncOp>();
    if (failed(applyPartialConversion(op, target, std::move(patterns))))
      signalPassFailure();
  }
};
} // namespace

std::unique_ptr<Pass> mlir::createMemRefToLLVMPass() {
  return std::make_unique<MemRefToLLVMPass>();
}