//===- MemRefToLLVM.cpp - MemRef to LLVM dialect conversion ---------------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//

#include "mlir/Conversion/MemRefToLLVM/MemRefToLLVM.h"
#include "../PassDetail.h"
#include "mlir/Analysis/DataLayoutAnalysis.h"
#include "mlir/Conversion/LLVMCommon/ConversionTarget.h"
#include "mlir/Conversion/LLVMCommon/Pattern.h"
#include "mlir/Conversion/LLVMCommon/TypeConverter.h"
#include "mlir/Conversion/MemRefToLLVM/AllocLikeConversion.h"
#include "mlir/Dialect/LLVMIR/FunctionCallUtils.h"
#include "mlir/Dialect/LLVMIR/LLVMDialect.h"
#include "mlir/Dialect/MemRef/IR/MemRef.h"
#include "mlir/IR/AffineMap.h"
#include "mlir/IR/BlockAndValueMapping.h"

using namespace mlir;

namespace {

struct AllocOpLowering : public AllocLikeOpLLVMLowering {
  AllocOpLowering(LLVMTypeConverter &converter)
      : AllocLikeOpLLVMLowering(memref::AllocOp::getOperationName(),
                                converter) {}

  std::tuple<Value, Value> allocateBuffer(ConversionPatternRewriter &rewriter,
                                          Location loc, Value sizeBytes,
                                          Operation *op) const override {
    // Heap allocations.
    memref::AllocOp allocOp = cast<memref::AllocOp>(op);
    MemRefType memRefType = allocOp.getType();

    Value alignment;
    if (auto alignmentAttr = allocOp.alignment()) {
      alignment = createIndexConstant(rewriter, loc, *alignmentAttr);
    } else if (!memRefType.getElementType().isSignlessIntOrIndexOrFloat()) {
      // In the case where no alignment is specified, we may want to override
      // `malloc's` behavior. `malloc` typically aligns at the size of the
      // biggest scalar on a target HW. For non-scalars, use the natural
      // alignment of the LLVM type given by the LLVM DataLayout.
      alignment = getSizeInBytes(loc, memRefType.getElementType(), rewriter);
    }

    if (alignment) {
      // Adjust the allocation size to consider alignment.
      sizeBytes = rewriter.create<LLVM::AddOp>(loc, sizeBytes, alignment);
    }

    // Allocate the underlying buffer and store a pointer to it in the MemRef
    // descriptor.
    Type elementPtrType = this->getElementPtrType(memRefType);
    auto allocFuncOp = LLVM::lookupOrCreateMallocFn(
        allocOp->getParentOfType<ModuleOp>(), getIndexType());
    auto results = createLLVMCall(rewriter, loc, allocFuncOp, {sizeBytes},
                                  getVoidPtrType());
    Value allocatedPtr =
        rewriter.create<LLVM::BitcastOp>(loc, elementPtrType, results[0]);

    Value alignedPtr = allocatedPtr;
    if (alignment) {
      // Compute the aligned type pointer.
      Value allocatedInt =
          rewriter.create<LLVM::PtrToIntOp>(loc, getIndexType(), allocatedPtr);
      Value alignmentInt =
          createAligned(rewriter, loc, allocatedInt, alignment);
      alignedPtr =
          rewriter.create<LLVM::IntToPtrOp>(loc, elementPtrType, alignmentInt);
    }

    return std::make_tuple(allocatedPtr, alignedPtr);
  }
};

struct AlignedAllocOpLowering : public AllocLikeOpLLVMLowering {
  AlignedAllocOpLowering(LLVMTypeConverter &converter)
      : AllocLikeOpLLVMLowering(memref::AllocOp::getOperationName(),
                                converter) {}

  /// Returns the memref's element size in bytes using the data layout active
  /// at `op`.
  // TODO: there are other places where this is used. Expose publicly?
  unsigned getMemRefEltSizeInBytes(MemRefType memRefType, Operation *op) const {
    const DataLayout *layout = &defaultLayout;
    if (const DataLayoutAnalysis *analysis =
            getTypeConverter()->getDataLayoutAnalysis()) {
      layout = &analysis->getAbove(op);
    }
    Type elementType = memRefType.getElementType();
    if (auto memRefElementType = elementType.dyn_cast<MemRefType>())
      return getTypeConverter()->getMemRefDescriptorSize(memRefElementType,
                                                         *layout);
    if (auto memRefElementType = elementType.dyn_cast<UnrankedMemRefType>())
      return getTypeConverter()->getUnrankedMemRefDescriptorSize(
          memRefElementType, *layout);
    return layout->getTypeSize(elementType);
  }

  /// Returns true if the memref size in bytes is known to be a multiple of
  /// factor assuming the data layout active at `op`.
  bool isMemRefSizeMultipleOf(MemRefType type, uint64_t factor,
                              Operation *op) const {
    uint64_t sizeDivisor = getMemRefEltSizeInBytes(type, op);
    for (unsigned i = 0, e = type.getRank(); i < e; i++) {
      if (ShapedType::isDynamic(type.getDimSize(i)))
        continue;
      sizeDivisor = sizeDivisor * type.getDimSize(i);
    }
    return sizeDivisor % factor == 0;
  }

  /// Returns the alignment to be used for the allocation call itself.
  /// aligned_alloc requires the alignment to be a power of two, and the
  /// allocation size to be a multiple of that alignment.
  int64_t getAllocationAlignment(memref::AllocOp allocOp) const {
    if (Optional<uint64_t> alignment = allocOp.alignment())
      return *alignment;

    // Whenever we don't have alignment set, we will use an alignment
    // consistent with the element type; since the alignment has to be a power
    // of two, we bump it to the next power of two if it isn't one already.
    auto eltSizeBytes = getMemRefEltSizeInBytes(allocOp.getType(), allocOp);
    return std::max(kMinAlignedAllocAlignment,
                    llvm::PowerOf2Ceil(eltSizeBytes));
  }

  std::tuple<Value, Value> allocateBuffer(ConversionPatternRewriter &rewriter,
                                          Location loc, Value sizeBytes,
                                          Operation *op) const override {
    // Heap allocations.
    memref::AllocOp allocOp = cast<memref::AllocOp>(op);
    MemRefType memRefType = allocOp.getType();
    int64_t alignment = getAllocationAlignment(allocOp);
    Value allocAlignment = createIndexConstant(rewriter, loc, alignment);

    // aligned_alloc requires size to be a multiple of alignment; we will pad
    // the size to the next multiple if necessary.
    if (!isMemRefSizeMultipleOf(memRefType, alignment, op))
      sizeBytes = createAligned(rewriter, loc, sizeBytes, allocAlignment);

    Type elementPtrType = this->getElementPtrType(memRefType);
    auto allocFuncOp = LLVM::lookupOrCreateAlignedAllocFn(
        allocOp->getParentOfType<ModuleOp>(), getIndexType());
    auto results =
        createLLVMCall(rewriter, loc, allocFuncOp, {allocAlignment, sizeBytes},
                       getVoidPtrType());
    Value allocatedPtr =
        rewriter.create<LLVM::BitcastOp>(loc, elementPtrType, results[0]);

    return std::make_tuple(allocatedPtr, allocatedPtr);
  }

  /// The minimum alignment to use with aligned_alloc (has to be a power of 2).
  static constexpr uint64_t kMinAlignedAllocAlignment = 16UL;

  /// Default layout to use in absence of the corresponding analysis.
  DataLayout defaultLayout;
};

// Out of line definition, required till C++17.
constexpr uint64_t AlignedAllocOpLowering::kMinAlignedAllocAlignment;

struct AllocaOpLowering : public AllocLikeOpLLVMLowering {
  AllocaOpLowering(LLVMTypeConverter &converter)
      : AllocLikeOpLLVMLowering(memref::AllocaOp::getOperationName(),
                                converter) {}

  /// Allocates the underlying buffer using the right call. `allocatedBytePtr`
  /// is set to null for stack allocations. `accessAlignment` is set if
  /// alignment is needed post allocation (e.g., in conjunction with malloc).
  std::tuple<Value, Value> allocateBuffer(ConversionPatternRewriter &rewriter,
                                          Location loc, Value sizeBytes,
                                          Operation *op) const override {

    // With alloca, one gets a pointer to the element type right away.
    // For stack allocations.
    auto allocaOp = cast<memref::AllocaOp>(op);
    auto elementPtrType = this->getElementPtrType(allocaOp.getType());

    auto allocatedElementPtr = rewriter.create<LLVM::AllocaOp>(
        loc, elementPtrType, sizeBytes,
        allocaOp.alignment() ? *allocaOp.alignment() : 0);

    return std::make_tuple(allocatedElementPtr, allocatedElementPtr);
  }
};

struct AllocaScopeOpLowering
    : public ConvertOpToLLVMPattern<memref::AllocaScopeOp> {
  using ConvertOpToLLVMPattern<memref::AllocaScopeOp>::ConvertOpToLLVMPattern;

  LogicalResult
  matchAndRewrite(memref::AllocaScopeOp allocaScopeOp, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    OpBuilder::InsertionGuard guard(rewriter);
    Location loc = allocaScopeOp.getLoc();

    // Split the current block before the AllocaScopeOp to create the inlining
    // point.
    auto *currentBlock = rewriter.getInsertionBlock();
    auto *remainingOpsBlock =
        rewriter.splitBlock(currentBlock, rewriter.getInsertionPoint());
    Block *continueBlock;
    if (allocaScopeOp.getNumResults() == 0) {
      continueBlock = remainingOpsBlock;
    } else {
      continueBlock = rewriter.createBlock(
          remainingOpsBlock, allocaScopeOp.getResultTypes(),
          SmallVector<Location>(allocaScopeOp->getNumResults(),
                                allocaScopeOp.getLoc()));
      rewriter.create<LLVM::BrOp>(loc, ValueRange(), remainingOpsBlock);
    }

    // Inline body region.
    Block *beforeBody = &allocaScopeOp.bodyRegion().front();
    Block *afterBody = &allocaScopeOp.bodyRegion().back();
    rewriter.inlineRegionBefore(allocaScopeOp.bodyRegion(), continueBlock);

    // Save stack and then branch into the body of the region.
    rewriter.setInsertionPointToEnd(currentBlock);
    auto stackSaveOp =
        rewriter.create<LLVM::StackSaveOp>(loc, getVoidPtrType());
    rewriter.create<LLVM::BrOp>(loc, ValueRange(), beforeBody);

    // Replace the alloca_scope return with a branch that jumps out of the body.
    // Stack restore before leaving the body region.
    rewriter.setInsertionPointToEnd(afterBody);
    auto returnOp =
        cast<memref::AllocaScopeReturnOp>(afterBody->getTerminator());
    auto branchOp = rewriter.replaceOpWithNewOp<LLVM::BrOp>(
        returnOp, returnOp.results(), continueBlock);

    // Insert stack restore before jumping out of the body of the region.
    rewriter.setInsertionPoint(branchOp);
    rewriter.create<LLVM::StackRestoreOp>(loc, stackSaveOp);

    // Replace the op with values returned from the body region.
    rewriter.replaceOp(allocaScopeOp, continueBlock->getArguments());

    return success();
  }
};

struct AssumeAlignmentOpLowering
    : public ConvertOpToLLVMPattern<memref::AssumeAlignmentOp> {
  using ConvertOpToLLVMPattern<
      memref::AssumeAlignmentOp>::ConvertOpToLLVMPattern;

  LogicalResult
  matchAndRewrite(memref::AssumeAlignmentOp op, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    Value memref = adaptor.memref();
    unsigned alignment = op.alignment();
    auto loc = op.getLoc();

    MemRefDescriptor memRefDescriptor(memref);
    Value ptr = memRefDescriptor.alignedPtr(rewriter, memref.getLoc());

    // Emit llvm.assume(memref.alignedPtr & (alignment - 1) == 0). Notice that
    // the asserted memref.alignedPtr isn't used anywhere else, as the real
    // users like load/store/views always re-extract memref.alignedPtr as they
    // get lowered.
    //
    // This relies on LLVM's CSE optimization (potentially after SROA), since
    // after CSE all memref.alignedPtr instances get de-duplicated into the same
    // pointer SSA value.
    auto intPtrType =
        getIntPtrType(memRefDescriptor.getElementPtrType().getAddressSpace());
    Value zero = createIndexAttrConstant(rewriter, loc, intPtrType, 0);
    Value mask =
        createIndexAttrConstant(rewriter, loc, intPtrType, alignment - 1);
    Value ptrValue = rewriter.create<LLVM::PtrToIntOp>(loc, intPtrType, ptr);
    rewriter.create<LLVM::AssumeOp>(
        loc, rewriter.create<LLVM::ICmpOp>(
                 loc, LLVM::ICmpPredicate::eq,
                 rewriter.create<LLVM::AndOp>(loc, ptrValue, mask), zero));

    rewriter.eraseOp(op);
    return success();
  }
};

// A `dealloc` is converted into a call to `free` on the underlying data buffer.
// The memref descriptor being an SSA value, there is no need to clean it up
// in any way.
struct DeallocOpLowering : public ConvertOpToLLVMPattern<memref::DeallocOp> {
  using ConvertOpToLLVMPattern<memref::DeallocOp>::ConvertOpToLLVMPattern;

  explicit DeallocOpLowering(LLVMTypeConverter &converter)
      : ConvertOpToLLVMPattern<memref::DeallocOp>(converter) {}

  LogicalResult
  matchAndRewrite(memref::DeallocOp op, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    // Insert the `free` declaration if it is not already present.
    auto freeFunc = LLVM::lookupOrCreateFreeFn(op->getParentOfType<ModuleOp>());
    MemRefDescriptor memref(adaptor.memref());
    Value casted = rewriter.create<LLVM::BitcastOp>(
        op.getLoc(), getVoidPtrType(),
        memref.allocatedPtr(rewriter, op.getLoc()));
    rewriter.replaceOpWithNewOp<LLVM::CallOp>(
        op, TypeRange(), SymbolRefAttr::get(freeFunc), casted);
    return success();
  }
};

// A `dim` is converted to a constant for static sizes and to an access to the
// size stored in the memref descriptor for dynamic sizes.
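//
// As a rough, hand-written illustration (not taken from the test suite), for
//   %d = memref.dim %m, %c1 : memref<4x?xf32>
// the dynamic size is read out of the converted descriptor, roughly as
//   %d = llvm.extractvalue %desc[3, 1]   // the sizes array is field 3
// while querying static dimension 0 folds to `llvm.mlir.constant(4 : index)`.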
struct DimOpLowering : public ConvertOpToLLVMPattern<memref::DimOp> {
  using ConvertOpToLLVMPattern<memref::DimOp>::ConvertOpToLLVMPattern;

  LogicalResult
  matchAndRewrite(memref::DimOp dimOp, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    Type operandType = dimOp.source().getType();
    if (operandType.isa<UnrankedMemRefType>()) {
      rewriter.replaceOp(
          dimOp, {extractSizeOfUnrankedMemRef(
                     operandType, dimOp, adaptor.getOperands(), rewriter)});

      return success();
    }
    if (operandType.isa<MemRefType>()) {
      rewriter.replaceOp(
          dimOp, {extractSizeOfRankedMemRef(operandType, dimOp,
                                            adaptor.getOperands(), rewriter)});
      return success();
    }
    llvm_unreachable("expected MemRefType or UnrankedMemRefType");
  }

private:
  Value extractSizeOfUnrankedMemRef(Type operandType, memref::DimOp dimOp,
                                    OpAdaptor adaptor,
                                    ConversionPatternRewriter &rewriter) const {
    Location loc = dimOp.getLoc();

    auto unrankedMemRefType = operandType.cast<UnrankedMemRefType>();
    auto scalarMemRefType =
        MemRefType::get({}, unrankedMemRefType.getElementType());
    unsigned addressSpace = unrankedMemRefType.getMemorySpaceAsInt();

    // Extract pointer to the underlying ranked descriptor and bitcast it to a
    // memref<element_type> descriptor pointer to minimize the number of GEP
    // operations.
    UnrankedMemRefDescriptor unrankedDesc(adaptor.source());
    Value underlyingRankedDesc = unrankedDesc.memRefDescPtr(rewriter, loc);
    Value scalarMemRefDescPtr = rewriter.create<LLVM::BitcastOp>(
        loc,
        LLVM::LLVMPointerType::get(typeConverter->convertType(scalarMemRefType),
                                   addressSpace),
        underlyingRankedDesc);

    // Get pointer to offset field of memref<element_type> descriptor.
    Type indexPtrTy = LLVM::LLVMPointerType::get(
        getTypeConverter()->getIndexType(), addressSpace);
    Value two = rewriter.create<LLVM::ConstantOp>(
        loc, typeConverter->convertType(rewriter.getI32Type()),
        rewriter.getI32IntegerAttr(2));
    Value offsetPtr = rewriter.create<LLVM::GEPOp>(
        loc, indexPtrTy, scalarMemRefDescPtr,
        ValueRange({createIndexConstant(rewriter, loc, 0), two}));

    // The size value that we have to extract can be obtained using GEPOp with
    // `dimOp.index() + 1` as the index argument.
    Value idxPlusOne = rewriter.create<LLVM::AddOp>(
        loc, createIndexConstant(rewriter, loc, 1), adaptor.index());
    Value sizePtr = rewriter.create<LLVM::GEPOp>(loc, indexPtrTy, offsetPtr,
                                                 ValueRange({idxPlusOne}));
    return rewriter.create<LLVM::LoadOp>(loc, sizePtr);
  }

  Optional<int64_t> getConstantDimIndex(memref::DimOp dimOp) const {
    if (Optional<int64_t> idx = dimOp.getConstantIndex())
      return idx;

    if (auto constantOp = dimOp.index().getDefiningOp<LLVM::ConstantOp>())
      return constantOp.getValue()
          .cast<IntegerAttr>()
          .getValue()
          .getSExtValue();

    return llvm::None;
  }

  Value extractSizeOfRankedMemRef(Type operandType, memref::DimOp dimOp,
                                  OpAdaptor adaptor,
                                  ConversionPatternRewriter &rewriter) const {
    Location loc = dimOp.getLoc();

    // Take advantage of a constant dimension index if available.
    MemRefType memRefType = operandType.cast<MemRefType>();
    if (Optional<int64_t> index = getConstantDimIndex(dimOp)) {
      int64_t i = index.getValue();
      if (memRefType.isDynamicDim(i)) {
        // Extract the dynamic size from the memref descriptor.
        MemRefDescriptor descriptor(adaptor.source());
        return descriptor.size(rewriter, loc, i);
      }
      // Use constant for static size.
      int64_t dimSize = memRefType.getDimSize(i);
      return createIndexConstant(rewriter, loc, dimSize);
    }
    Value index = adaptor.index();
    int64_t rank = memRefType.getRank();
    MemRefDescriptor memrefDescriptor(adaptor.source());
    return memrefDescriptor.size(rewriter, loc, index, rank);
  }
};

/// Common base for load and store operations on MemRefs. Restricts the match
/// to supported MemRef types. Provides functionality to emit code accessing a
/// specific element of the underlying data buffer.
template <typename Derived>
struct LoadStoreOpLowering : public ConvertOpToLLVMPattern<Derived> {
  using ConvertOpToLLVMPattern<Derived>::ConvertOpToLLVMPattern;
  using ConvertOpToLLVMPattern<Derived>::isConvertibleAndHasIdentityMaps;
  using Base = LoadStoreOpLowering<Derived>;

  LogicalResult match(Derived op) const override {
    MemRefType type = op.getMemRefType();
    return isConvertibleAndHasIdentityMaps(type) ? success() : failure();
  }
};

/// Wrap a llvm.cmpxchg operation in a while loop so that the operation can be
/// retried until it succeeds in atomically storing a new value into memory.
///
///      +---------------------------------+
///      |   <code before the AtomicRMWOp> |
///      |   <compute initial %loaded>     |
///      |   br loop(%loaded)              |
///      +---------------------------------+
///             |
///  -------|   |
///  |      v   v
///  |   +--------------------------------+
///  |   | loop(%loaded):                 |
///  |   |   <body contents>              |
///  |   |   %pair = cmpxchg              |
///  |   |   %ok = %pair[0]               |
///  |   |   %new = %pair[1]              |
///  |   |   cond_br %ok, end, loop(%new) |
///  |   +--------------------------------+
///  |          |        |
///  |----------         |
///                      v
///      +--------------------------------+
///      | end:                           |
///      |   <code after the AtomicRMWOp> |
///      +--------------------------------+
///
struct GenericAtomicRMWOpLowering
    : public LoadStoreOpLowering<memref::GenericAtomicRMWOp> {
  using Base::Base;

  LogicalResult
  matchAndRewrite(memref::GenericAtomicRMWOp atomicOp, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    auto loc = atomicOp.getLoc();
    Type valueType = typeConverter->convertType(atomicOp.getResult().getType());

    // Split the block into initial, loop, and ending parts.
    auto *initBlock = rewriter.getInsertionBlock();
    auto *loopBlock = rewriter.createBlock(
        initBlock->getParent(), std::next(Region::iterator(initBlock)),
        valueType, loc);
    auto *endBlock = rewriter.createBlock(
        loopBlock->getParent(), std::next(Region::iterator(loopBlock)));

    // Operations range to be moved to `endBlock`.
    auto opsToMoveStart = atomicOp->getIterator();
    auto opsToMoveEnd = initBlock->back().getIterator();

    // Compute the loaded value and branch to the loop block.
    rewriter.setInsertionPointToEnd(initBlock);
    auto memRefType = atomicOp.memref().getType().cast<MemRefType>();
    auto dataPtr = getStridedElementPtr(loc, memRefType, adaptor.memref(),
                                        adaptor.indices(), rewriter);
    Value init = rewriter.create<LLVM::LoadOp>(loc, dataPtr);
    rewriter.create<LLVM::BrOp>(loc, init, loopBlock);

    // Prepare the body of the loop block.
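    // The loop block receives the most recently observed value as its single
    // block argument; the cloned user region below computes the candidate new
    // value from it, which the cmpxchg then tries to commit.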
    rewriter.setInsertionPointToStart(loopBlock);

    // Clone the GenericAtomicRMWOp region and extract the result.
    auto loopArgument = loopBlock->getArgument(0);
    BlockAndValueMapping mapping;
    mapping.map(atomicOp.getCurrentValue(), loopArgument);
    Block &entryBlock = atomicOp.body().front();
    for (auto &nestedOp : entryBlock.without_terminator()) {
      Operation *clone = rewriter.clone(nestedOp, mapping);
      mapping.map(nestedOp.getResults(), clone->getResults());
    }
    Value result = mapping.lookup(entryBlock.getTerminator()->getOperand(0));

    // Prepare the epilog of the loop block.
    // Append the cmpxchg op to the end of the loop block.
    auto successOrdering = LLVM::AtomicOrdering::acq_rel;
    auto failureOrdering = LLVM::AtomicOrdering::monotonic;
    auto boolType = IntegerType::get(rewriter.getContext(), 1);
    auto pairType = LLVM::LLVMStructType::getLiteral(rewriter.getContext(),
                                                     {valueType, boolType});
    auto cmpxchg = rewriter.create<LLVM::AtomicCmpXchgOp>(
        loc, pairType, dataPtr, loopArgument, result, successOrdering,
        failureOrdering);
    // Extract the %new_loaded and %ok values from the pair.
    Value newLoaded = rewriter.create<LLVM::ExtractValueOp>(
        loc, valueType, cmpxchg, rewriter.getI64ArrayAttr({0}));
    Value ok = rewriter.create<LLVM::ExtractValueOp>(
        loc, boolType, cmpxchg, rewriter.getI64ArrayAttr({1}));

    // Conditionally branch to the end or back to the loop depending on %ok.
    rewriter.create<LLVM::CondBrOp>(loc, ok, endBlock, ArrayRef<Value>(),
                                    loopBlock, newLoaded);

    rewriter.setInsertionPointToEnd(endBlock);
    moveOpsRange(atomicOp.getResult(), newLoaded, std::next(opsToMoveStart),
                 std::next(opsToMoveEnd), rewriter);

    // The 'result' of the atomic_rmw op is the newly loaded value.
    rewriter.replaceOp(atomicOp, {newLoaded});

    return success();
  }

private:
  // Clones a segment of ops [start, end) and erases the original.
  void moveOpsRange(ValueRange oldResult, ValueRange newResult,
                    Block::iterator start, Block::iterator end,
                    ConversionPatternRewriter &rewriter) const {
    BlockAndValueMapping mapping;
    mapping.map(oldResult, newResult);
    SmallVector<Operation *, 2> opsToErase;
    for (auto it = start; it != end; ++it) {
      rewriter.clone(*it, mapping);
      opsToErase.push_back(&*it);
    }
    for (auto *it : opsToErase)
      rewriter.eraseOp(it);
  }
};

/// Returns the LLVM type of the global variable given the memref type `type`.
static Type convertGlobalMemrefTypeToLLVM(MemRefType type,
                                          LLVMTypeConverter &typeConverter) {
  // LLVM type for a global memref will be a multi-dimension array. For
  // declarations or uninitialized global memrefs, we can potentially flatten
  // this to a 1D array. However, for memref.global's with an initial value,
  // we do not intend to flatten the ElementsAttribute when going from std ->
  // LLVM dialect, so the LLVM type needs to be a multi-dimensional array.
  Type elementType = typeConverter.convertType(type.getElementType());
  Type arrayTy = elementType;
  // Shape has the outermost dim at index 0, so need to walk it backwards.
  for (int64_t dim : llvm::reverse(type.getShape()))
    arrayTy = LLVM::LLVMArrayType::get(arrayTy, dim);
  return arrayTy;
}

/// GlobalMemrefOp is lowered to a LLVM Global Variable.
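///
/// A minimal sketch of the expected mapping (illustrative only, not verbatim
/// from the tests):
///   memref.global "private" constant @gv : memref<2x3xf32> = dense<1.0>
/// becomes an llvm.mlir.global of type !llvm.array<2 x array<3 x f32>> that
/// carries the same dense elements attribute as its initializer.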
struct GlobalMemrefOpLowering
    : public ConvertOpToLLVMPattern<memref::GlobalOp> {
  using ConvertOpToLLVMPattern<memref::GlobalOp>::ConvertOpToLLVMPattern;

  LogicalResult
  matchAndRewrite(memref::GlobalOp global, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    MemRefType type = global.type();
    if (!isConvertibleAndHasIdentityMaps(type))
      return failure();

    Type arrayTy = convertGlobalMemrefTypeToLLVM(type, *getTypeConverter());

    LLVM::Linkage linkage =
        global.isPublic() ? LLVM::Linkage::External : LLVM::Linkage::Private;

    Attribute initialValue = nullptr;
    if (!global.isExternal() && !global.isUninitialized()) {
      auto elementsAttr = global.initial_value()->cast<ElementsAttr>();
      initialValue = elementsAttr;

      // For scalar memrefs, the global variable created is of the element type,
      // so unpack the elements attribute to extract the value.
      if (type.getRank() == 0)
        initialValue = elementsAttr.getSplatValue<Attribute>();
    }

    uint64_t alignment = global.alignment().getValueOr(0);

    auto newGlobal = rewriter.replaceOpWithNewOp<LLVM::GlobalOp>(
        global, arrayTy, global.constant(), linkage, global.sym_name(),
        initialValue, alignment, type.getMemorySpaceAsInt());
    if (!global.isExternal() && global.isUninitialized()) {
      Block *blk = new Block();
      newGlobal.getInitializerRegion().push_back(blk);
      rewriter.setInsertionPointToStart(blk);
      Value undef[] = {
          rewriter.create<LLVM::UndefOp>(global.getLoc(), arrayTy)};
      rewriter.create<LLVM::ReturnOp>(global.getLoc(), undef);
    }
    return success();
  }
};

/// GetGlobalMemrefOp is lowered into a Memref descriptor with the pointer to
/// the first element stashed into the descriptor. This reuses
/// `AllocLikeOpLowering` to reuse the Memref descriptor construction.
struct GetGlobalMemrefOpLowering : public AllocLikeOpLLVMLowering {
  GetGlobalMemrefOpLowering(LLVMTypeConverter &converter)
      : AllocLikeOpLLVMLowering(memref::GetGlobalOp::getOperationName(),
                                converter) {}

  /// Buffer "allocation" for memref.get_global op is getting the address of
  /// the global variable referenced.
  std::tuple<Value, Value> allocateBuffer(ConversionPatternRewriter &rewriter,
                                          Location loc, Value sizeBytes,
                                          Operation *op) const override {
    auto getGlobalOp = cast<memref::GetGlobalOp>(op);
    MemRefType type = getGlobalOp.result().getType().cast<MemRefType>();
    unsigned memSpace = type.getMemorySpaceAsInt();

    Type arrayTy = convertGlobalMemrefTypeToLLVM(type, *getTypeConverter());
    auto addressOf = rewriter.create<LLVM::AddressOfOp>(
        loc, LLVM::LLVMPointerType::get(arrayTy, memSpace), getGlobalOp.name());

    // Get the address of the first element in the array by creating a GEP with
    // the address of the GV as the base, and (rank + 1) number of 0 indices.
    Type elementType = typeConverter->convertType(type.getElementType());
    Type elementPtrType = LLVM::LLVMPointerType::get(elementType, memSpace);

    SmallVector<Value> operands;
    operands.insert(operands.end(), type.getRank() + 1,
                    createIndexConstant(rewriter, loc, 0));
    auto gep =
        rewriter.create<LLVM::GEPOp>(loc, elementPtrType, addressOf, operands);

    // We do not expect the memref obtained using `memref.get_global` to be
    // ever deallocated. Set the allocated pointer to a known bad value to
    // help debug if that ever happens.
    auto intPtrType = getIntPtrType(memSpace);
    Value deadBeefConst =
        createIndexAttrConstant(rewriter, op->getLoc(), intPtrType, 0xdeadbeef);
    auto deadBeefPtr =
        rewriter.create<LLVM::IntToPtrOp>(loc, elementPtrType, deadBeefConst);

    // Both allocated and aligned pointers are the same. We could potentially
    // stash a nullptr for the allocated pointer since we do not expect any
    // dealloc.
    return std::make_tuple(deadBeefPtr, gep);
  }
};

// Load operation is lowered to obtaining a pointer to the indexed element
// and loading it.
struct LoadOpLowering : public LoadStoreOpLowering<memref::LoadOp> {
  using Base::Base;

  LogicalResult
  matchAndRewrite(memref::LoadOp loadOp, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    auto type = loadOp.getMemRefType();

    Value dataPtr = getStridedElementPtr(
        loadOp.getLoc(), type, adaptor.memref(), adaptor.indices(), rewriter);
    rewriter.replaceOpWithNewOp<LLVM::LoadOp>(loadOp, dataPtr);
    return success();
  }
};

// Store operation is lowered to obtaining a pointer to the indexed element,
// and storing the given value to it.
struct StoreOpLowering : public LoadStoreOpLowering<memref::StoreOp> {
  using Base::Base;

  LogicalResult
  matchAndRewrite(memref::StoreOp op, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    auto type = op.getMemRefType();

    Value dataPtr = getStridedElementPtr(op.getLoc(), type, adaptor.memref(),
                                         adaptor.indices(), rewriter);
    rewriter.replaceOpWithNewOp<LLVM::StoreOp>(op, adaptor.value(), dataPtr);
    return success();
  }
};

// The prefetch operation is lowered in a way similar to the load operation
// except that the llvm.prefetch operation is used for replacement.
struct PrefetchOpLowering : public LoadStoreOpLowering<memref::PrefetchOp> {
  using Base::Base;

  LogicalResult
  matchAndRewrite(memref::PrefetchOp prefetchOp, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    auto type = prefetchOp.getMemRefType();
    auto loc = prefetchOp.getLoc();

    Value dataPtr = getStridedElementPtr(loc, type, adaptor.memref(),
                                         adaptor.indices(), rewriter);

    // Replace with llvm.prefetch.
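    // Besides the address, the intrinsic takes three i32 immediates: whether
    // the access will be a write, the temporal locality hint (0-3), and
    // whether the prefetch targets the data or the instruction cache.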
    auto llvmI32Type = typeConverter->convertType(rewriter.getIntegerType(32));
    auto isWrite = rewriter.create<LLVM::ConstantOp>(
        loc, llvmI32Type, rewriter.getI32IntegerAttr(prefetchOp.isWrite()));
    auto localityHint = rewriter.create<LLVM::ConstantOp>(
        loc, llvmI32Type,
        rewriter.getI32IntegerAttr(prefetchOp.localityHint()));
    auto isData = rewriter.create<LLVM::ConstantOp>(
        loc, llvmI32Type, rewriter.getI32IntegerAttr(prefetchOp.isDataCache()));

    rewriter.replaceOpWithNewOp<LLVM::Prefetch>(prefetchOp, dataPtr, isWrite,
                                                localityHint, isData);
    return success();
  }
};

struct RankOpLowering : public ConvertOpToLLVMPattern<memref::RankOp> {
  using ConvertOpToLLVMPattern<memref::RankOp>::ConvertOpToLLVMPattern;

  LogicalResult
  matchAndRewrite(memref::RankOp op, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    Location loc = op.getLoc();
    Type operandType = op.memref().getType();
    if (auto unrankedMemRefType = operandType.dyn_cast<UnrankedMemRefType>()) {
      UnrankedMemRefDescriptor desc(adaptor.memref());
      rewriter.replaceOp(op, {desc.rank(rewriter, loc)});
      return success();
    }
    if (auto rankedMemRefType = operandType.dyn_cast<MemRefType>()) {
      rewriter.replaceOp(
          op, {createIndexConstant(rewriter, loc, rankedMemRefType.getRank())});
      return success();
    }
    return failure();
  }
};

struct MemRefCastOpLowering : public ConvertOpToLLVMPattern<memref::CastOp> {
  using ConvertOpToLLVMPattern<memref::CastOp>::ConvertOpToLLVMPattern;

  LogicalResult match(memref::CastOp memRefCastOp) const override {
    Type srcType = memRefCastOp.getOperand().getType();
    Type dstType = memRefCastOp.getType();

    // memref::CastOp reduces to a bitcast in the ranked MemRef case and can be
    // used for type erasure. For now it must preserve the underlying element
    // type and requires the source and result types to have the same rank.
    // Therefore, perform a sanity check that the underlying structs are the
    // same. Once op semantics are relaxed we can revisit.
    if (srcType.isa<MemRefType>() && dstType.isa<MemRefType>())
      return success(typeConverter->convertType(srcType) ==
                     typeConverter->convertType(dstType));

    // At least one of the operands is an unranked type.
    assert(srcType.isa<UnrankedMemRefType>() ||
           dstType.isa<UnrankedMemRefType>());

    // Unranked to unranked cast is disallowed.
    return !(srcType.isa<UnrankedMemRefType>() &&
             dstType.isa<UnrankedMemRefType>())
               ? success()
               : failure();
  }

  void rewrite(memref::CastOp memRefCastOp, OpAdaptor adaptor,
               ConversionPatternRewriter &rewriter) const override {
    auto srcType = memRefCastOp.getOperand().getType();
    auto dstType = memRefCastOp.getType();
    auto targetStructType = typeConverter->convertType(memRefCastOp.getType());
    auto loc = memRefCastOp.getLoc();

    // For ranked/ranked case, just keep the original descriptor.
    if (srcType.isa<MemRefType>() && dstType.isa<MemRefType>())
      return rewriter.replaceOp(memRefCastOp, {adaptor.source()});

    if (srcType.isa<MemRefType>() && dstType.isa<UnrankedMemRefType>()) {
      // Casting ranked to unranked memref type
      // Set the rank in the destination from the memref type
      // Allocate space on the stack and copy the src memref descriptor
      // Set the ptr in the destination to the stack space
      auto srcMemRefType = srcType.cast<MemRefType>();
      int64_t rank = srcMemRefType.getRank();
      // ptr = AllocaOp sizeof(MemRefDescriptor)
      auto ptr = getTypeConverter()->promoteOneMemRefDescriptor(
          loc, adaptor.source(), rewriter);
      // voidptr = BitCastOp srcType* to void*
      auto voidPtr =
          rewriter.create<LLVM::BitcastOp>(loc, getVoidPtrType(), ptr)
              .getResult();
      // rank = ConstantOp srcRank
      auto rankVal = rewriter.create<LLVM::ConstantOp>(
          loc, typeConverter->convertType(rewriter.getIntegerType(64)),
          rewriter.getI64IntegerAttr(rank));
      // undef = UndefOp
      UnrankedMemRefDescriptor memRefDesc =
          UnrankedMemRefDescriptor::undef(rewriter, loc, targetStructType);
      // d1 = InsertValueOp undef, rank, 0
      memRefDesc.setRank(rewriter, loc, rankVal);
      // d2 = InsertValueOp d1, voidptr, 1
      memRefDesc.setMemRefDescPtr(rewriter, loc, voidPtr);
      rewriter.replaceOp(memRefCastOp, (Value)memRefDesc);

    } else if (srcType.isa<UnrankedMemRefType>() && dstType.isa<MemRefType>()) {
      // Casting from unranked type to ranked.
      // The operation is assumed to be doing a correct cast. If the
      // destination type mismatches the unranked type, it is undefined
      // behavior.
      UnrankedMemRefDescriptor memRefDesc(adaptor.source());
      // ptr = ExtractValueOp src, 1
      auto ptr = memRefDesc.memRefDescPtr(rewriter, loc);
      // castPtr = BitCastOp i8* to structTy*
      auto castPtr =
          rewriter
              .create<LLVM::BitcastOp>(
                  loc, LLVM::LLVMPointerType::get(targetStructType), ptr)
              .getResult();
      // struct = LoadOp castPtr
      auto loadOp = rewriter.create<LLVM::LoadOp>(loc, castPtr);
      rewriter.replaceOp(memRefCastOp, loadOp.getResult());
    } else {
      llvm_unreachable("Unsupported unranked memref to unranked memref cast");
    }
  }
};

/// Pattern to lower a `memref.copy` to llvm.
///
/// For memrefs with identity layouts, the copy is lowered to the llvm
/// `memcpy` intrinsic. For non-identity layouts, the copy is lowered to a call
/// to the generic `MemrefCopyFn`.
struct MemRefCopyOpLowering : public ConvertOpToLLVMPattern<memref::CopyOp> {
  using ConvertOpToLLVMPattern<memref::CopyOp>::ConvertOpToLLVMPattern;

  LogicalResult
  lowerToMemCopyIntrinsic(memref::CopyOp op, OpAdaptor adaptor,
                          ConversionPatternRewriter &rewriter) const {
    auto loc = op.getLoc();
    auto srcType = op.source().getType().dyn_cast<MemRefType>();

    MemRefDescriptor srcDesc(adaptor.source());

    // Compute number of elements.
    Value numElements = rewriter.create<LLVM::ConstantOp>(
        loc, getIndexType(), rewriter.getIndexAttr(1));
    for (int pos = 0; pos < srcType.getRank(); ++pos) {
      auto size = srcDesc.size(rewriter, loc, pos);
      numElements = rewriter.create<LLVM::MulOp>(loc, numElements, size);
    }

    // Get element size.
    auto sizeInBytes = getSizeInBytes(loc, srcType.getElementType(), rewriter);
    // Compute total.
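    // Total size in bytes = number of elements * element size; this becomes
    // the length operand of the llvm.memcpy emitted below.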
    Value totalSize =
        rewriter.create<LLVM::MulOp>(loc, numElements, sizeInBytes);

    Value srcBasePtr = srcDesc.alignedPtr(rewriter, loc);
    MemRefDescriptor targetDesc(adaptor.target());
    Value targetBasePtr = targetDesc.alignedPtr(rewriter, loc);
    Value isVolatile = rewriter.create<LLVM::ConstantOp>(
        loc, typeConverter->convertType(rewriter.getI1Type()),
        rewriter.getBoolAttr(false));
    rewriter.create<LLVM::MemcpyOp>(loc, targetBasePtr, srcBasePtr, totalSize,
                                    isVolatile);
    rewriter.eraseOp(op);

    return success();
  }

  LogicalResult
  lowerToMemCopyFunctionCall(memref::CopyOp op, OpAdaptor adaptor,
                             ConversionPatternRewriter &rewriter) const {
    auto loc = op.getLoc();
    auto srcType = op.source().getType().cast<BaseMemRefType>();
    auto targetType = op.target().getType().cast<BaseMemRefType>();

    // First make sure we have an unranked memref descriptor representation.
    auto makeUnranked = [&, this](Value ranked, BaseMemRefType type) {
      auto rank = rewriter.create<LLVM::ConstantOp>(
          loc, getIndexType(), rewriter.getIndexAttr(type.getRank()));
      auto *typeConverter = getTypeConverter();
      auto ptr =
          typeConverter->promoteOneMemRefDescriptor(loc, ranked, rewriter);
      auto voidPtr =
          rewriter.create<LLVM::BitcastOp>(loc, getVoidPtrType(), ptr)
              .getResult();
      auto unrankedType =
          UnrankedMemRefType::get(type.getElementType(), type.getMemorySpace());
      return UnrankedMemRefDescriptor::pack(rewriter, loc, *typeConverter,
                                            unrankedType,
                                            ValueRange{rank, voidPtr});
    };

    Value unrankedSource = srcType.hasRank()
                               ? makeUnranked(adaptor.source(), srcType)
                               : adaptor.source();
    Value unrankedTarget = targetType.hasRank()
                               ? makeUnranked(adaptor.target(), targetType)
                               : adaptor.target();

    // Now promote the unranked descriptors to the stack.
    auto one = rewriter.create<LLVM::ConstantOp>(loc, getIndexType(),
                                                 rewriter.getIndexAttr(1));
    auto promote = [&](Value desc) {
      auto ptrType = LLVM::LLVMPointerType::get(desc.getType());
      auto allocated =
          rewriter.create<LLVM::AllocaOp>(loc, ptrType, ValueRange{one});
      rewriter.create<LLVM::StoreOp>(loc, desc, allocated);
      return allocated;
    };

    auto sourcePtr = promote(unrankedSource);
    auto targetPtr = promote(unrankedTarget);

    auto elemSize = rewriter.create<LLVM::ConstantOp>(
        loc, getIndexType(),
        rewriter.getIndexAttr(srcType.getElementTypeBitWidth() / 8));
    auto copyFn = LLVM::lookupOrCreateMemRefCopyFn(
        op->getParentOfType<ModuleOp>(), getIndexType(), sourcePtr.getType());
    rewriter.create<LLVM::CallOp>(loc, copyFn,
                                  ValueRange{elemSize, sourcePtr, targetPtr});
    rewriter.eraseOp(op);

    return success();
  }

  LogicalResult
  matchAndRewrite(memref::CopyOp op, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    auto srcType = op.source().getType().cast<BaseMemRefType>();
    auto targetType = op.target().getType().cast<BaseMemRefType>();

    if (srcType.hasRank() &&
        srcType.cast<MemRefType>().getLayout().isIdentity() &&
        targetType.hasRank() &&
        targetType.cast<MemRefType>().getLayout().isIdentity())
      return lowerToMemCopyIntrinsic(op, adaptor, rewriter);

    return lowerToMemCopyFunctionCall(op, adaptor, rewriter);
  }
};

/// Extracts allocated, aligned pointers and offset from a ranked or unranked
/// memref type.
/// In the unranked case, the fields are extracted from the underlying ranked
/// descriptor.
static void extractPointersAndOffset(Location loc,
                                     ConversionPatternRewriter &rewriter,
                                     LLVMTypeConverter &typeConverter,
                                     Value originalOperand,
                                     Value convertedOperand,
                                     Value *allocatedPtr, Value *alignedPtr,
                                     Value *offset = nullptr) {
  Type operandType = originalOperand.getType();
  if (operandType.isa<MemRefType>()) {
    MemRefDescriptor desc(convertedOperand);
    *allocatedPtr = desc.allocatedPtr(rewriter, loc);
    *alignedPtr = desc.alignedPtr(rewriter, loc);
    if (offset != nullptr)
      *offset = desc.offset(rewriter, loc);
    return;
  }

  unsigned memorySpace =
      operandType.cast<UnrankedMemRefType>().getMemorySpaceAsInt();
  Type elementType = operandType.cast<UnrankedMemRefType>().getElementType();
  Type llvmElementType = typeConverter.convertType(elementType);
  Type elementPtrPtrType = LLVM::LLVMPointerType::get(
      LLVM::LLVMPointerType::get(llvmElementType, memorySpace));

  // Extract pointer to the underlying ranked memref descriptor and cast it to
  // ElemType**.
  UnrankedMemRefDescriptor unrankedDesc(convertedOperand);
  Value underlyingDescPtr = unrankedDesc.memRefDescPtr(rewriter, loc);

  *allocatedPtr = UnrankedMemRefDescriptor::allocatedPtr(
      rewriter, loc, underlyingDescPtr, elementPtrPtrType);
  *alignedPtr = UnrankedMemRefDescriptor::alignedPtr(
      rewriter, loc, typeConverter, underlyingDescPtr, elementPtrPtrType);
  if (offset != nullptr) {
    *offset = UnrankedMemRefDescriptor::offset(
        rewriter, loc, typeConverter, underlyingDescPtr, elementPtrPtrType);
  }
}

struct MemRefReinterpretCastOpLowering
    : public ConvertOpToLLVMPattern<memref::ReinterpretCastOp> {
  using ConvertOpToLLVMPattern<
      memref::ReinterpretCastOp>::ConvertOpToLLVMPattern;

  LogicalResult
  matchAndRewrite(memref::ReinterpretCastOp castOp, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    Type srcType = castOp.source().getType();

    Value descriptor;
    if (failed(convertSourceMemRefToDescriptor(rewriter, srcType, castOp,
                                               adaptor, &descriptor)))
      return failure();
    rewriter.replaceOp(castOp, {descriptor});
    return success();
  }

private:
  LogicalResult convertSourceMemRefToDescriptor(
      ConversionPatternRewriter &rewriter, Type srcType,
      memref::ReinterpretCastOp castOp,
      memref::ReinterpretCastOp::Adaptor adaptor, Value *descriptor) const {
    MemRefType targetMemRefType =
        castOp.getResult().getType().cast<MemRefType>();
    auto llvmTargetDescriptorTy = typeConverter->convertType(targetMemRefType)
                                      .dyn_cast_or_null<LLVM::LLVMStructType>();
    if (!llvmTargetDescriptorTy)
      return failure();

    // Create descriptor.
    Location loc = castOp.getLoc();
    auto desc = MemRefDescriptor::undef(rewriter, loc, llvmTargetDescriptorTy);

    // Set allocated and aligned pointers.
    Value allocatedPtr, alignedPtr;
    extractPointersAndOffset(loc, rewriter, *getTypeConverter(),
                             castOp.source(), adaptor.source(), &allocatedPtr,
                             &alignedPtr);
    desc.setAllocatedPtr(rewriter, loc, allocatedPtr);
    desc.setAlignedPtr(rewriter, loc, alignedPtr);

    // Set offset.
    if (castOp.isDynamicOffset(0))
      desc.setOffset(rewriter, loc, adaptor.offsets()[0]);
    else
      desc.setConstantOffset(rewriter, loc, castOp.getStaticOffset(0));

    // Set sizes and strides.
    unsigned dynSizeId = 0;
    unsigned dynStrideId = 0;
    for (unsigned i = 0, e = targetMemRefType.getRank(); i < e; ++i) {
      if (castOp.isDynamicSize(i))
        desc.setSize(rewriter, loc, i, adaptor.sizes()[dynSizeId++]);
      else
        desc.setConstantSize(rewriter, loc, i, castOp.getStaticSize(i));

      if (castOp.isDynamicStride(i))
        desc.setStride(rewriter, loc, i, adaptor.strides()[dynStrideId++]);
      else
        desc.setConstantStride(rewriter, loc, i, castOp.getStaticStride(i));
    }
    *descriptor = desc;
    return success();
  }
};

struct MemRefReshapeOpLowering
    : public ConvertOpToLLVMPattern<memref::ReshapeOp> {
  using ConvertOpToLLVMPattern<memref::ReshapeOp>::ConvertOpToLLVMPattern;

  LogicalResult
  matchAndRewrite(memref::ReshapeOp reshapeOp, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    Type srcType = reshapeOp.source().getType();

    Value descriptor;
    if (failed(convertSourceMemRefToDescriptor(rewriter, srcType, reshapeOp,
                                               adaptor, &descriptor)))
      return failure();
    rewriter.replaceOp(reshapeOp, {descriptor});
    return success();
  }

private:
  LogicalResult
  convertSourceMemRefToDescriptor(ConversionPatternRewriter &rewriter,
                                  Type srcType, memref::ReshapeOp reshapeOp,
                                  memref::ReshapeOp::Adaptor adaptor,
                                  Value *descriptor) const {
    // Conversion for statically-known shape args is performed via
    // `memref_reinterpret_cast`.
    auto shapeMemRefType = reshapeOp.shape().getType().cast<MemRefType>();
    if (shapeMemRefType.hasStaticShape())
      return failure();

    // The shape is a rank-1 tensor with unknown length.
    Location loc = reshapeOp.getLoc();
    MemRefDescriptor shapeDesc(adaptor.shape());
    Value resultRank = shapeDesc.size(rewriter, loc, 0);

    // Extract address space and element type.
    auto targetType =
        reshapeOp.getResult().getType().cast<UnrankedMemRefType>();
    unsigned addressSpace = targetType.getMemorySpaceAsInt();
    Type elementType = targetType.getElementType();

    // Create the unranked memref descriptor that holds the ranked one. The
    // inner descriptor is allocated on the stack.
    auto targetDesc = UnrankedMemRefDescriptor::undef(
        rewriter, loc, typeConverter->convertType(targetType));
    targetDesc.setRank(rewriter, loc, resultRank);
    SmallVector<Value, 4> sizes;
    UnrankedMemRefDescriptor::computeSizes(rewriter, loc, *getTypeConverter(),
                                           targetDesc, sizes);
    Value underlyingDescPtr = rewriter.create<LLVM::AllocaOp>(
        loc, getVoidPtrType(), sizes.front(), llvm::None);
    targetDesc.setMemRefDescPtr(rewriter, loc, underlyingDescPtr);

    // Extract pointers and offset from the source memref.
    Value allocatedPtr, alignedPtr, offset;
    extractPointersAndOffset(loc, rewriter, *getTypeConverter(),
                             reshapeOp.source(), adaptor.source(),
                             &allocatedPtr, &alignedPtr, &offset);

    // Set pointers and offset.
    Type llvmElementType = typeConverter->convertType(elementType);
    auto elementPtrPtrType = LLVM::LLVMPointerType::get(
        LLVM::LLVMPointerType::get(llvmElementType, addressSpace));
    UnrankedMemRefDescriptor::setAllocatedPtr(rewriter, loc, underlyingDescPtr,
                                              elementPtrPtrType, allocatedPtr);
    UnrankedMemRefDescriptor::setAlignedPtr(rewriter, loc, *getTypeConverter(),
                                            underlyingDescPtr,
                                            elementPtrPtrType, alignedPtr);
    UnrankedMemRefDescriptor::setOffset(rewriter, loc, *getTypeConverter(),
                                        underlyingDescPtr, elementPtrPtrType,
                                        offset);

    // Use the offset pointer as base for further addressing. Copy over the new
    // shape and compute strides. For this, we create a loop from rank-1 to 0.
    Value targetSizesBase = UnrankedMemRefDescriptor::sizeBasePtr(
        rewriter, loc, *getTypeConverter(), underlyingDescPtr,
        elementPtrPtrType);
    Value targetStridesBase = UnrankedMemRefDescriptor::strideBasePtr(
        rewriter, loc, *getTypeConverter(), targetSizesBase, resultRank);
    Value shapeOperandPtr = shapeDesc.alignedPtr(rewriter, loc);
    Value oneIndex = createIndexConstant(rewriter, loc, 1);
    Value resultRankMinusOne =
        rewriter.create<LLVM::SubOp>(loc, resultRank, oneIndex);

    Block *initBlock = rewriter.getInsertionBlock();
    Type indexType = getTypeConverter()->getIndexType();
    Block::iterator remainingOpsIt = std::next(rewriter.getInsertionPoint());

    Block *condBlock = rewriter.createBlock(initBlock->getParent(), {},
                                            {indexType, indexType}, {loc, loc});

    // Move the remaining initBlock ops to condBlock.
    Block *remainingBlock = rewriter.splitBlock(initBlock, remainingOpsIt);
    rewriter.mergeBlocks(remainingBlock, condBlock, ValueRange());

    rewriter.setInsertionPointToEnd(initBlock);
    rewriter.create<LLVM::BrOp>(loc, ValueRange({resultRankMinusOne, oneIndex}),
                                condBlock);
    rewriter.setInsertionPointToStart(condBlock);
    Value indexArg = condBlock->getArgument(0);
    Value strideArg = condBlock->getArgument(1);

    Value zeroIndex = createIndexConstant(rewriter, loc, 0);
    Value pred = rewriter.create<LLVM::ICmpOp>(
        loc, IntegerType::get(rewriter.getContext(), 1),
        LLVM::ICmpPredicate::sge, indexArg, zeroIndex);

    Block *bodyBlock =
        rewriter.splitBlock(condBlock, rewriter.getInsertionPoint());
    rewriter.setInsertionPointToStart(bodyBlock);

    // Copy size from shape to descriptor.
    Type llvmIndexPtrType = LLVM::LLVMPointerType::get(indexType);
    Value sizeLoadGep = rewriter.create<LLVM::GEPOp>(
        loc, llvmIndexPtrType, shapeOperandPtr, ValueRange{indexArg});
    Value size = rewriter.create<LLVM::LoadOp>(loc, sizeLoadGep);
    UnrankedMemRefDescriptor::setSize(rewriter, loc, *getTypeConverter(),
                                      targetSizesBase, indexArg, size);

    // Write stride value and compute next one.
    UnrankedMemRefDescriptor::setStride(rewriter, loc, *getTypeConverter(),
                                        targetStridesBase, indexArg, strideArg);
    Value nextStride = rewriter.create<LLVM::MulOp>(loc, strideArg, size);

    // Decrement loop counter and branch back.
    Value decrement = rewriter.create<LLVM::SubOp>(loc, indexArg, oneIndex);
    rewriter.create<LLVM::BrOp>(loc, ValueRange({decrement, nextStride}),
                                condBlock);

    Block *remainder =
        rewriter.splitBlock(bodyBlock, rewriter.getInsertionPoint());

    // Hook up the cond exit to the remainder.
    rewriter.setInsertionPointToEnd(condBlock);
    rewriter.create<LLVM::CondBrOp>(loc, pred, bodyBlock, llvm::None, remainder,
                                    llvm::None);

    // Reset position to beginning of new remainder block.
    rewriter.setInsertionPointToStart(remainder);

    *descriptor = targetDesc;
    return success();
  }
};

/// Helper function to convert a vector of `OpFoldResult`s into a vector of
/// `Value`s.
static SmallVector<Value> getAsValues(OpBuilder &b, Location loc,
                                      Type &llvmIndexType,
                                      ArrayRef<OpFoldResult> valueOrAttrVec) {
  return llvm::to_vector<4>(
      llvm::map_range(valueOrAttrVec, [&](OpFoldResult value) -> Value {
        if (auto attr = value.dyn_cast<Attribute>())
          return b.create<LLVM::ConstantOp>(loc, llvmIndexType, attr);
        return value.get<Value>();
      }));
}

/// Compute a map that for a given dimension of the expanded type gives the
/// dimension in the collapsed type it maps to. Essentially it's the inverse of
/// the `reassociation` maps.
static DenseMap<int64_t, int64_t>
getExpandedDimToCollapsedDimMap(ArrayRef<ReassociationIndices> reassociation) {
  llvm::DenseMap<int64_t, int64_t> expandedDimToCollapsedDim;
  for (auto &en : enumerate(reassociation)) {
    for (auto dim : en.value())
      expandedDimToCollapsedDim[dim] = en.index();
  }
  return expandedDimToCollapsedDim;
}

static OpFoldResult
getExpandedOutputDimSize(OpBuilder &b, Location loc, Type &llvmIndexType,
                         int64_t outDimIndex, ArrayRef<int64_t> outStaticShape,
                         MemRefDescriptor &inDesc,
                         ArrayRef<int64_t> inStaticShape,
                         ArrayRef<ReassociationIndices> reassocation,
                         DenseMap<int64_t, int64_t> &outDimToInDimMap) {
  int64_t outDimSize = outStaticShape[outDimIndex];
  if (!ShapedType::isDynamic(outDimSize))
    return b.getIndexAttr(outDimSize);

  // Calculate the product of all the output dim sizes in the same
  // reassociation group except the current dim.
  int64_t inDimIndex = outDimToInDimMap[outDimIndex];
  int64_t otherDimSizesMul = 1;
  for (auto otherDimIndex : reassocation[inDimIndex]) {
    if (otherDimIndex == static_cast<unsigned>(outDimIndex))
      continue;
    int64_t otherDimSize = outStaticShape[otherDimIndex];
    assert(!ShapedType::isDynamic(otherDimSize) &&
           "single dimension cannot be expanded into multiple dynamic "
           "dimensions");
    otherDimSizesMul *= otherDimSize;
  }

  // outDimSize = inDimSize / otherOutDimSizesMul
  int64_t inDimSize = inStaticShape[inDimIndex];
  Value inDimSizeDynamic =
      ShapedType::isDynamic(inDimSize)
          ?
          inDesc.size(b, loc, inDimIndex)
          : b.create<LLVM::ConstantOp>(loc, llvmIndexType,
                                       b.getIndexAttr(inDimSize));
  Value outDimSizeDynamic = b.create<LLVM::SDivOp>(
      loc, inDimSizeDynamic,
      b.create<LLVM::ConstantOp>(loc, llvmIndexType,
                                 b.getIndexAttr(otherDimSizesMul)));
  return outDimSizeDynamic;
}

static OpFoldResult getCollapsedOutputDimSize(
    OpBuilder &b, Location loc, Type &llvmIndexType, int64_t outDimIndex,
    int64_t outDimSize, ArrayRef<int64_t> inStaticShape,
    MemRefDescriptor &inDesc, ArrayRef<ReassociationIndices> reassocation) {
  if (!ShapedType::isDynamic(outDimSize))
    return b.getIndexAttr(outDimSize);

  Value c1 = b.create<LLVM::ConstantOp>(loc, llvmIndexType, b.getIndexAttr(1));
  Value outDimSizeDynamic = c1;
  for (auto inDimIndex : reassocation[outDimIndex]) {
    int64_t inDimSize = inStaticShape[inDimIndex];
    Value inDimSizeDynamic =
        ShapedType::isDynamic(inDimSize)
            ? inDesc.size(b, loc, inDimIndex)
            : b.create<LLVM::ConstantOp>(loc, llvmIndexType,
                                         b.getIndexAttr(inDimSize));
    outDimSizeDynamic =
        b.create<LLVM::MulOp>(loc, outDimSizeDynamic, inDimSizeDynamic);
  }
  return outDimSizeDynamic;
}

static SmallVector<OpFoldResult, 4>
getCollapsedOutputShape(OpBuilder &b, Location loc, Type &llvmIndexType,
                        ArrayRef<ReassociationIndices> reassocation,
                        ArrayRef<int64_t> inStaticShape,
                        MemRefDescriptor &inDesc,
                        ArrayRef<int64_t> outStaticShape) {
  return llvm::to_vector<4>(llvm::map_range(
      llvm::seq<int64_t>(0, outStaticShape.size()), [&](int64_t outDimIndex) {
        return getCollapsedOutputDimSize(b, loc, llvmIndexType, outDimIndex,
                                         outStaticShape[outDimIndex],
                                         inStaticShape, inDesc, reassocation);
      }));
}

static SmallVector<OpFoldResult, 4>
getExpandedOutputShape(OpBuilder &b, Location loc, Type &llvmIndexType,
                       ArrayRef<ReassociationIndices> reassocation,
                       ArrayRef<int64_t> inStaticShape,
                       MemRefDescriptor &inDesc,
                       ArrayRef<int64_t> outStaticShape) {
  DenseMap<int64_t, int64_t> outDimToInDimMap =
      getExpandedDimToCollapsedDimMap(reassocation);
  return llvm::to_vector<4>(llvm::map_range(
      llvm::seq<int64_t>(0, outStaticShape.size()), [&](int64_t outDimIndex) {
        return getExpandedOutputDimSize(b, loc, llvmIndexType, outDimIndex,
                                        outStaticShape, inDesc, inStaticShape,
                                        reassocation, outDimToInDimMap);
      }));
}

static SmallVector<Value>
getDynamicOutputShape(OpBuilder &b, Location loc, Type &llvmIndexType,
                      ArrayRef<ReassociationIndices> reassocation,
                      ArrayRef<int64_t> inStaticShape, MemRefDescriptor &inDesc,
                      ArrayRef<int64_t> outStaticShape) {
  return outStaticShape.size() < inStaticShape.size()
             ? getAsValues(b, loc, llvmIndexType,
                           getCollapsedOutputShape(b, loc, llvmIndexType,
                                                   reassocation, inStaticShape,
                                                   inDesc, outStaticShape))
             : getAsValues(b, loc, llvmIndexType,
                           getExpandedOutputShape(b, loc, llvmIndexType,
                                                  reassocation, inStaticShape,
                                                  inDesc, outStaticShape));
}

// ReshapeOp creates a new view descriptor of the proper rank.
// For now, the only conversion supported is for target MemRef with static sizes
// and strides.
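//
// Rough illustration (hand-written, not from the test suite): lowering
//   %r = memref.collapse_shape %m [[0, 1]] : memref<4x?xf32> into memref<?xf32>
// reuses the source descriptor's allocated/aligned pointers and offset,
// computes the single result size as 4 * size(dim 1) read from the source
// descriptor, and fills in the result strides (innermost stride 1).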
template <typename ReshapeOp>
class ReassociatingReshapeOpConversion
    : public ConvertOpToLLVMPattern<ReshapeOp> {
public:
  using ConvertOpToLLVMPattern<ReshapeOp>::ConvertOpToLLVMPattern;
  using ReshapeOpAdaptor = typename ReshapeOp::Adaptor;

  LogicalResult
  matchAndRewrite(ReshapeOp reshapeOp, typename ReshapeOp::Adaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    MemRefType dstType = reshapeOp.getResultType();
    MemRefType srcType = reshapeOp.getSrcType();

    // The condition on the layouts can be ignored when all shapes are static.
    if (!srcType.hasStaticShape() || !dstType.hasStaticShape()) {
      if (!srcType.getLayout().isIdentity() ||
          !dstType.getLayout().isIdentity()) {
        return rewriter.notifyMatchFailure(
            reshapeOp, "only empty layout map is supported");
      }
    }

    int64_t offset;
    SmallVector<int64_t, 4> strides;
    if (failed(getStridesAndOffset(dstType, strides, offset))) {
      return rewriter.notifyMatchFailure(
          reshapeOp, "failed to get stride and offset exprs");
    }

    MemRefDescriptor srcDesc(adaptor.src());
    Location loc = reshapeOp->getLoc();
    auto dstDesc = MemRefDescriptor::undef(
        rewriter, loc, this->typeConverter->convertType(dstType));
    dstDesc.setAllocatedPtr(rewriter, loc, srcDesc.allocatedPtr(rewriter, loc));
    dstDesc.setAlignedPtr(rewriter, loc, srcDesc.alignedPtr(rewriter, loc));
    dstDesc.setOffset(rewriter, loc, srcDesc.offset(rewriter, loc));

    ArrayRef<int64_t> srcStaticShape = srcType.getShape();
    ArrayRef<int64_t> dstStaticShape = dstType.getShape();
    Type llvmIndexType =
        this->typeConverter->convertType(rewriter.getIndexType());
    SmallVector<Value> dstShape = getDynamicOutputShape(
        rewriter, loc, llvmIndexType, reshapeOp.getReassociationIndices(),
        srcStaticShape, srcDesc, dstStaticShape);
    for (auto &en : llvm::enumerate(dstShape))
      dstDesc.setSize(rewriter, loc, en.index(), en.value());

    auto isStaticStride = [](int64_t stride) {
      return !ShapedType::isDynamicStrideOrOffset(stride);
    };
    if (llvm::all_of(strides, isStaticStride)) {
      for (auto &en : llvm::enumerate(strides))
        dstDesc.setConstantStride(rewriter, loc, en.index(), en.value());
    } else {
      Value c1 = rewriter.create<LLVM::ConstantOp>(loc, llvmIndexType,
                                                   rewriter.getIndexAttr(1));
      Value stride = c1;
      for (auto dimIndex :
           llvm::reverse(llvm::seq<int64_t>(0, dstShape.size()))) {
        dstDesc.setStride(rewriter, loc, dimIndex, stride);
        stride = rewriter.create<LLVM::MulOp>(loc, dstShape[dimIndex], stride);
      }
    }
    rewriter.replaceOp(reshapeOp, {dstDesc});
    return success();
  }
};

/// Conversion pattern that transforms a subview op into:
///   1. An `llvm.mlir.undef` operation to create a memref descriptor
///   2. Updates to the descriptor to introduce the data ptr, offset, size
///      and stride.
/// The subview op is replaced by the descriptor.
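///
/// For instance (a hand-written sketch, not a verbatim test case), lowering
///   %v = memref.subview %src[%o, 4] [8, 16] [1, 1]
///       : memref<64x64xf32> to memref<8x16xf32, #map>
/// keeps the source pointers, sets the result offset to %o * 64 + 4 (the
/// source strides being 64 and 1), and writes the static sizes 8x16 and
/// strides 64x1 into the new descriptor.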
struct SubViewOpLowering : public ConvertOpToLLVMPattern<memref::SubViewOp> {
  using ConvertOpToLLVMPattern<memref::SubViewOp>::ConvertOpToLLVMPattern;

  LogicalResult
  matchAndRewrite(memref::SubViewOp subViewOp, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    auto loc = subViewOp.getLoc();

    auto sourceMemRefType = subViewOp.source().getType().cast<MemRefType>();
    auto sourceElementTy =
        typeConverter->convertType(sourceMemRefType.getElementType());

    auto viewMemRefType = subViewOp.getType();
    auto inferredType =
        memref::SubViewOp::inferResultType(
            subViewOp.getSourceType(),
            extractFromI64ArrayAttr(subViewOp.static_offsets()),
            extractFromI64ArrayAttr(subViewOp.static_sizes()),
            extractFromI64ArrayAttr(subViewOp.static_strides()))
            .cast<MemRefType>();
    auto targetElementTy =
        typeConverter->convertType(viewMemRefType.getElementType());
    auto targetDescTy = typeConverter->convertType(viewMemRefType);
    if (!sourceElementTy || !targetDescTy || !targetElementTy ||
        !LLVM::isCompatibleType(sourceElementTy) ||
        !LLVM::isCompatibleType(targetElementTy) ||
        !LLVM::isCompatibleType(targetDescTy))
      return failure();

    // Extract the offset and strides from the type.
    int64_t offset;
    SmallVector<int64_t, 4> strides;
    auto successStrides = getStridesAndOffset(inferredType, strides, offset);
    if (failed(successStrides))
      return failure();

    // Create the descriptor.
    if (!LLVM::isCompatibleType(adaptor.getOperands().front().getType()))
      return failure();
    MemRefDescriptor sourceMemRef(adaptor.getOperands().front());
    auto targetMemRef = MemRefDescriptor::undef(rewriter, loc, targetDescTy);

    // Copy the buffer pointer from the old descriptor to the new one.
    Value extracted = sourceMemRef.allocatedPtr(rewriter, loc);
    Value bitcastPtr = rewriter.create<LLVM::BitcastOp>(
        loc,
        LLVM::LLVMPointerType::get(targetElementTy,
                                   viewMemRefType.getMemorySpaceAsInt()),
        extracted);
    targetMemRef.setAllocatedPtr(rewriter, loc, bitcastPtr);

    // Copy the aligned pointer from the old descriptor to the new one.
    extracted = sourceMemRef.alignedPtr(rewriter, loc);
    bitcastPtr = rewriter.create<LLVM::BitcastOp>(
        loc,
        LLVM::LLVMPointerType::get(targetElementTy,
                                   viewMemRefType.getMemorySpaceAsInt()),
        extracted);
    targetMemRef.setAlignedPtr(rewriter, loc, bitcastPtr);

    size_t inferredShapeRank = inferredType.getRank();
    size_t resultShapeRank = viewMemRefType.getRank();

    // Extract strides needed to compute offset.
    SmallVector<Value, 4> strideValues;
    strideValues.reserve(inferredShapeRank);
    for (unsigned i = 0; i < inferredShapeRank; ++i)
      strideValues.push_back(sourceMemRef.stride(rewriter, loc, i));

    // Offset.
    auto llvmIndexType = typeConverter->convertType(rewriter.getIndexType());
    if (!ShapedType::isDynamicStrideOrOffset(offset)) {
      targetMemRef.setConstantOffset(rewriter, loc, offset);
    } else {
      Value baseOffset = sourceMemRef.offset(rewriter, loc);
      // `inferredShapeRank` may be larger than the number of offset operands
      // because of trailing semantics. In this case, the offset is guaranteed
      // to be interpreted as 0 and we can just skip the extra dimensions.
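      // In other words, the loop below accumulates the linearized offset
      //   baseOffset += offset_i * sourceStride_i
      // over the leading dimensions that do have offset operands.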
      for (unsigned i = 0, e = std::min(inferredShapeRank,
                                        subViewOp.getMixedOffsets().size());
           i < e; ++i) {
        Value offset =
            // TODO: need OpFoldResult ODS adaptor to clean this up.
            subViewOp.isDynamicOffset(i)
                ? adaptor.getOperands()[subViewOp.getIndexOfDynamicOffset(i)]
                : rewriter.create<LLVM::ConstantOp>(
                      loc, llvmIndexType,
                      rewriter.getI64IntegerAttr(subViewOp.getStaticOffset(i)));
        Value mul = rewriter.create<LLVM::MulOp>(loc, offset, strideValues[i]);
        baseOffset = rewriter.create<LLVM::AddOp>(loc, baseOffset, mul);
      }
      targetMemRef.setOffset(rewriter, loc, baseOffset);
    }

    // Update sizes and strides.
    SmallVector<OpFoldResult> mixedSizes = subViewOp.getMixedSizes();
    SmallVector<OpFoldResult> mixedStrides = subViewOp.getMixedStrides();
    assert(mixedSizes.size() == mixedStrides.size() &&
           "expected sizes and strides of equal length");
    llvm::SmallDenseSet<unsigned> unusedDims = subViewOp.getDroppedDims();
    for (int i = inferredShapeRank - 1, j = resultShapeRank - 1;
         i >= 0 && j >= 0; --i) {
      if (unusedDims.contains(i))
        continue;

      // `i` may exceed the number of entries in subViewOp.getMixedSizes()
      // because of trailing semantics. In this case, the size is guaranteed to
      // be interpreted as Dim and the stride as 1.
      Value size, stride;
      if (static_cast<unsigned>(i) >= mixedSizes.size()) {
        // If the static size is available, use it directly. This is similar to
        // the folding of dim(constant-op) but removes the need for dim to be
        // aware of LLVM constants and for this pass to be aware of std
        // constants.
        int64_t staticSize =
            subViewOp.source().getType().cast<MemRefType>().getShape()[i];
        if (staticSize != ShapedType::kDynamicSize) {
          size = rewriter.create<LLVM::ConstantOp>(
              loc, llvmIndexType, rewriter.getI64IntegerAttr(staticSize));
        } else {
          Value pos = rewriter.create<LLVM::ConstantOp>(
              loc, llvmIndexType, rewriter.getI64IntegerAttr(i));
          Value dim =
              rewriter.create<memref::DimOp>(loc, subViewOp.source(), pos);
          auto cast = rewriter.create<UnrealizedConversionCastOp>(
              loc, llvmIndexType, dim);
          size = cast.getResult(0);
        }
        stride = rewriter.create<LLVM::ConstantOp>(
            loc, llvmIndexType, rewriter.getI64IntegerAttr(1));
      } else {
        // TODO: need OpFoldResult ODS adaptor to clean this up.
        size =
            subViewOp.isDynamicSize(i)
                ? adaptor.getOperands()[subViewOp.getIndexOfDynamicSize(i)]
                : rewriter.create<LLVM::ConstantOp>(
                      loc, llvmIndexType,
                      rewriter.getI64IntegerAttr(subViewOp.getStaticSize(i)));
        if (!ShapedType::isDynamicStrideOrOffset(strides[i])) {
          stride = rewriter.create<LLVM::ConstantOp>(
              loc, llvmIndexType, rewriter.getI64IntegerAttr(strides[i]));
        } else {
          stride =
              subViewOp.isDynamicStride(i)
                  ? adaptor.getOperands()[subViewOp.getIndexOfDynamicStride(i)]
                  : rewriter.create<LLVM::ConstantOp>(
                        loc, llvmIndexType,
                        rewriter.getI64IntegerAttr(
                            subViewOp.getStaticStride(i)));
          stride = rewriter.create<LLVM::MulOp>(loc, stride, strideValues[i]);
        }
      }
      targetMemRef.setSize(rewriter, loc, j, size);
      targetMemRef.setStride(rewriter, loc, j, stride);
      j--;
    }

    rewriter.replaceOp(subViewOp, {targetMemRef});
    return success();
  }
};

/// Conversion pattern that transforms a transpose op into:
/// 1. An `llvm.mlir.undef` operation to create a memref descriptor.
/// 2. Updates to the descriptor to copy the data pointer, aligned pointer and
///    offset from the source, and to permute the sizes and strides according
///    to the transpose permutation (e.g. a (d0, d1) -> (d1, d0) permutation
///    swaps the two sizes and the two strides).
/// The transpose op is replaced by the descriptor.
class TransposeOpLowering : public ConvertOpToLLVMPattern<memref::TransposeOp> {
public:
  using ConvertOpToLLVMPattern<memref::TransposeOp>::ConvertOpToLLVMPattern;

  LogicalResult
  matchAndRewrite(memref::TransposeOp transposeOp, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    auto loc = transposeOp.getLoc();
    MemRefDescriptor viewMemRef(adaptor.in());

    // No permutation, early exit.
    if (transposeOp.permutation().isIdentity())
      return rewriter.replaceOp(transposeOp, {viewMemRef}), success();

    auto targetMemRef = MemRefDescriptor::undef(
        rewriter, loc,
        typeConverter->convertType(transposeOp.getShapedType()));

    // Copy the base and aligned pointers from the old descriptor to the new
    // one.
    targetMemRef.setAllocatedPtr(rewriter, loc,
                                 viewMemRef.allocatedPtr(rewriter, loc));
    targetMemRef.setAlignedPtr(rewriter, loc,
                               viewMemRef.alignedPtr(rewriter, loc));

    // Copy the offset from the old descriptor to the new one.
    targetMemRef.setOffset(rewriter, loc, viewMemRef.offset(rewriter, loc));

    // Iterate over the dimensions and apply size/stride permutation.
    for (const auto &en :
         llvm::enumerate(transposeOp.permutation().getResults())) {
      int sourcePos = en.index();
      int targetPos = en.value().cast<AffineDimExpr>().getPosition();
      targetMemRef.setSize(rewriter, loc, targetPos,
                           viewMemRef.size(rewriter, loc, sourcePos));
      targetMemRef.setStride(rewriter, loc, targetPos,
                             viewMemRef.stride(rewriter, loc, sourcePos));
    }

    rewriter.replaceOp(transposeOp, {targetMemRef});
    return success();
  }
};

/// Conversion pattern that transforms a view op into:
/// 1. An `llvm.mlir.undef` operation to create a memref descriptor
/// 2. Updates to the descriptor to introduce the data ptr, offset, size
///    and stride.
/// The view op is replaced by the descriptor.
struct ViewOpLowering : public ConvertOpToLLVMPattern<memref::ViewOp> {
  using ConvertOpToLLVMPattern<memref::ViewOp>::ConvertOpToLLVMPattern;

  // Build and return the value for the idx^th shape dimension, either by
  // returning the constant shape dimension or counting the proper dynamic
  // size.
  Value getSize(ConversionPatternRewriter &rewriter, Location loc,
                ArrayRef<int64_t> shape, ValueRange dynamicSizes,
                unsigned idx) const {
    assert(idx < shape.size());
    if (!ShapedType::isDynamic(shape[idx]))
      return createIndexConstant(rewriter, loc, shape[idx]);
    // Count the number of dynamic dims in range [0, idx].
    unsigned nDynamic = llvm::count_if(shape.take_front(idx), [](int64_t v) {
      return ShapedType::isDynamic(v);
    });
    return dynamicSizes[nDynamic];
  }

  // Build and return the idx^th stride, either by returning the constant
  // stride or by computing the dynamic stride from the current
  // `runningStride` and `nextSize`.
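  // For example (illustrative only): for a contiguous view with sizes
  // [2, 3, 4] and no statically known strides, walking from the innermost
  // dimension outwards yields the strides 1, 4 and then 12.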
  // The caller should keep a running stride and update it with the result
  // returned by this function.
  Value getStride(ConversionPatternRewriter &rewriter, Location loc,
                  ArrayRef<int64_t> strides, Value nextSize,
                  Value runningStride, unsigned idx) const {
    assert(idx < strides.size());
    if (!ShapedType::isDynamicStrideOrOffset(strides[idx]))
      return createIndexConstant(rewriter, loc, strides[idx]);
    if (nextSize)
      return runningStride
                 ? rewriter.create<LLVM::MulOp>(loc, runningStride, nextSize)
                 : nextSize;
    assert(!runningStride);
    return createIndexConstant(rewriter, loc, 1);
  }

  LogicalResult
  matchAndRewrite(memref::ViewOp viewOp, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    auto loc = viewOp.getLoc();

    auto viewMemRefType = viewOp.getType();
    auto targetElementTy =
        typeConverter->convertType(viewMemRefType.getElementType());
    auto targetDescTy = typeConverter->convertType(viewMemRefType);
    if (!targetDescTy || !targetElementTy ||
        !LLVM::isCompatibleType(targetElementTy) ||
        !LLVM::isCompatibleType(targetDescTy))
      return viewOp.emitWarning("Target descriptor type not converted to LLVM"),
             failure();

    int64_t offset;
    SmallVector<int64_t, 4> strides;
    auto successStrides = getStridesAndOffset(viewMemRefType, strides, offset);
    if (failed(successStrides))
      return viewOp.emitWarning("cannot cast to non-strided shape"), failure();
    assert(offset == 0 && "expected offset to be 0");

    // Create the descriptor.
    MemRefDescriptor sourceMemRef(adaptor.source());
    auto targetMemRef = MemRefDescriptor::undef(rewriter, loc, targetDescTy);

    // Field 1: Copy the allocated pointer, used for malloc/free.
    Value allocatedPtr = sourceMemRef.allocatedPtr(rewriter, loc);
    auto srcMemRefType = viewOp.source().getType().cast<MemRefType>();
    Value bitcastPtr = rewriter.create<LLVM::BitcastOp>(
        loc,
        LLVM::LLVMPointerType::get(targetElementTy,
                                   srcMemRefType.getMemorySpaceAsInt()),
        allocatedPtr);
    targetMemRef.setAllocatedPtr(rewriter, loc, bitcastPtr);

    // Field 2: Copy the actual aligned pointer to payload.
    Value alignedPtr = sourceMemRef.alignedPtr(rewriter, loc);
    alignedPtr = rewriter.create<LLVM::GEPOp>(loc, alignedPtr.getType(),
                                              alignedPtr, adaptor.byte_shift());
    bitcastPtr = rewriter.create<LLVM::BitcastOp>(
        loc,
        LLVM::LLVMPointerType::get(targetElementTy,
                                   srcMemRefType.getMemorySpaceAsInt()),
        alignedPtr);
    targetMemRef.setAlignedPtr(rewriter, loc, bitcastPtr);

    // Field 3: The offset in the resulting type must be 0. This is because of
    // the type change: an offset on srcType* may not be expressible as an
    // offset on dstType*.
    targetMemRef.setOffset(rewriter, loc,
                           createIndexConstant(rewriter, loc, offset));

    // Early exit for 0-D corner case.
    if (viewMemRefType.getRank() == 0)
      return rewriter.replaceOp(viewOp, {targetMemRef}), success();

    // Fields 4 and 5: Update sizes and strides.
    if (strides.back() != 1)
      return viewOp.emitWarning("cannot cast to non-contiguous shape"),
             failure();
    Value stride = nullptr, nextSize = nullptr;
    for (int i = viewMemRefType.getRank() - 1; i >= 0; --i) {
      // Update size.
      Value size =
          getSize(rewriter, loc, viewMemRefType.getShape(), adaptor.sizes(), i);
      targetMemRef.setSize(rewriter, loc, i, size);
      // Update stride.
      stride = getStride(rewriter, loc, strides, nextSize, stride, i);
      targetMemRef.setStride(rewriter, loc, i, stride);
      nextSize = size;
    }

    rewriter.replaceOp(viewOp, {targetMemRef});
    return success();
  }
};

//===----------------------------------------------------------------------===//
// AtomicRMWOpLowering
//===----------------------------------------------------------------------===//

/// Try to match the kind of a memref.atomic_rmw to determine whether to use a
/// lowering to llvm.atomicrmw or fall back to llvm.cmpxchg.
static Optional<LLVM::AtomicBinOp>
matchSimpleAtomicOp(memref::AtomicRMWOp atomicOp) {
  switch (atomicOp.kind()) {
  case arith::AtomicRMWKind::addf:
    return LLVM::AtomicBinOp::fadd;
  case arith::AtomicRMWKind::addi:
    return LLVM::AtomicBinOp::add;
  case arith::AtomicRMWKind::assign:
    return LLVM::AtomicBinOp::xchg;
  case arith::AtomicRMWKind::maxs:
    return LLVM::AtomicBinOp::max;
  case arith::AtomicRMWKind::maxu:
    return LLVM::AtomicBinOp::umax;
  case arith::AtomicRMWKind::mins:
    return LLVM::AtomicBinOp::min;
  case arith::AtomicRMWKind::minu:
    return LLVM::AtomicBinOp::umin;
  case arith::AtomicRMWKind::ori:
    return LLVM::AtomicBinOp::_or;
  case arith::AtomicRMWKind::andi:
    return LLVM::AtomicBinOp::_and;
  default:
    return llvm::None;
  }
  llvm_unreachable("Invalid AtomicRMWKind");
}

struct AtomicRMWOpLowering : public LoadStoreOpLowering<memref::AtomicRMWOp> {
  using Base::Base;

  LogicalResult
  matchAndRewrite(memref::AtomicRMWOp atomicOp, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    if (failed(match(atomicOp)))
      return failure();
    auto maybeKind = matchSimpleAtomicOp(atomicOp);
    if (!maybeKind)
      return failure();
    auto resultType = adaptor.value().getType();
    auto memRefType = atomicOp.getMemRefType();
    auto dataPtr =
        getStridedElementPtr(atomicOp.getLoc(), memRefType, adaptor.memref(),
                             adaptor.indices(), rewriter);
    rewriter.replaceOpWithNewOp<LLVM::AtomicRMWOp>(
        atomicOp, resultType, *maybeKind, dataPtr, adaptor.value(),
        LLVM::AtomicOrdering::acq_rel);
    return success();
  }
};

} // namespace

void mlir::populateMemRefToLLVMConversionPatterns(LLVMTypeConverter &converter,
                                                  RewritePatternSet &patterns) {
  // clang-format off
  patterns.add<
      AllocaOpLowering,
      AllocaScopeOpLowering,
      AtomicRMWOpLowering,
      AssumeAlignmentOpLowering,
      DimOpLowering,
      GenericAtomicRMWOpLowering,
      GlobalMemrefOpLowering,
      GetGlobalMemrefOpLowering,
      LoadOpLowering,
      MemRefCastOpLowering,
      MemRefCopyOpLowering,
      MemRefReinterpretCastOpLowering,
      MemRefReshapeOpLowering,
      PrefetchOpLowering,
      RankOpLowering,
      ReassociatingReshapeOpConversion<memref::ExpandShapeOp>,
      ReassociatingReshapeOpConversion<memref::CollapseShapeOp>,
      StoreOpLowering,
      SubViewOpLowering,
      TransposeOpLowering,
      ViewOpLowering>(converter);
  // clang-format on
  auto allocLowering = converter.getOptions().allocLowering;
  if (allocLowering == LowerToLLVMOptions::AllocLowering::AlignedAlloc)
    patterns.add<AlignedAllocOpLowering,
                 DeallocOpLowering>(converter);
  else if (allocLowering == LowerToLLVMOptions::AllocLowering::Malloc)
    patterns.add<AllocOpLowering, DeallocOpLowering>(converter);
}

namespace {
struct MemRefToLLVMPass : public ConvertMemRefToLLVMBase<MemRefToLLVMPass> {
  MemRefToLLVMPass() = default;

  void runOnOperation() override {
    Operation *op = getOperation();
    const auto &dataLayoutAnalysis = getAnalysis<DataLayoutAnalysis>();
    LowerToLLVMOptions options(&getContext(),
                               dataLayoutAnalysis.getAtOrAbove(op));
    options.allocLowering =
        (useAlignedAlloc ? LowerToLLVMOptions::AllocLowering::AlignedAlloc
                         : LowerToLLVMOptions::AllocLowering::Malloc);
    if (indexBitwidth != kDeriveIndexBitwidthFromDataLayout)
      options.overrideIndexBitwidth(indexBitwidth);

    LLVMTypeConverter typeConverter(&getContext(), options,
                                    &dataLayoutAnalysis);
    RewritePatternSet patterns(&getContext());
    populateMemRefToLLVMConversionPatterns(typeConverter, patterns);
    LLVMConversionTarget target(getContext());
    target.addLegalOp<FuncOp>();
    if (failed(applyPartialConversion(op, target, std::move(patterns))))
      signalPassFailure();
  }
};
} // namespace

std::unique_ptr<Pass> mlir::createMemRefToLLVMPass() {
  return std::make_unique<MemRefToLLVMPass>();
}
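
// A typical way to schedule this pass (illustrative only; the surrounding
// pipeline depends on the client and usually also lowers the remaining
// dialects to LLVM):
//
//   mlir::PassManager pm(&context);
//   pm.addPass(mlir::createMemRefToLLVMPass());
//   if (mlir::failed(pm.run(module)))
//     /* handle the error */;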