//===- MemRefToLLVM.cpp - MemRef to LLVM dialect conversion ---------------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//

#include "mlir/Conversion/MemRefToLLVM/MemRefToLLVM.h"
#include "../PassDetail.h"
#include "mlir/Analysis/DataLayoutAnalysis.h"
#include "mlir/Conversion/LLVMCommon/ConversionTarget.h"
#include "mlir/Conversion/LLVMCommon/Pattern.h"
#include "mlir/Conversion/LLVMCommon/TypeConverter.h"
#include "mlir/Conversion/MemRefToLLVM/AllocLikeConversion.h"
#include "mlir/Dialect/LLVMIR/FunctionCallUtils.h"
#include "mlir/Dialect/LLVMIR/LLVMDialect.h"
#include "mlir/Dialect/MemRef/IR/MemRef.h"
#include "mlir/IR/AffineMap.h"
#include "mlir/IR/BlockAndValueMapping.h"

using namespace mlir;

namespace {

struct AllocOpLowering : public AllocLikeOpLLVMLowering {
  AllocOpLowering(LLVMTypeConverter &converter)
      : AllocLikeOpLLVMLowering(memref::AllocOp::getOperationName(),
                                converter) {}

  std::tuple<Value, Value> allocateBuffer(ConversionPatternRewriter &rewriter,
                                          Location loc, Value sizeBytes,
                                          Operation *op) const override {
    // Heap allocations.
    memref::AllocOp allocOp = cast<memref::AllocOp>(op);
    MemRefType memRefType = allocOp.getType();

    Value alignment;
    if (auto alignmentAttr = allocOp.alignment()) {
      alignment = createIndexConstant(rewriter, loc, *alignmentAttr);
    } else if (!memRefType.getElementType().isSignlessIntOrIndexOrFloat()) {
      // In the case where no alignment is specified, we may want to override
      // `malloc's` behavior. `malloc` typically aligns at the size of the
      // biggest scalar on a target HW. For non-scalars, use the natural
      // alignment of the LLVM type given by the LLVM DataLayout.
      alignment = getSizeInBytes(loc, memRefType.getElementType(), rewriter);
    }

    if (alignment) {
      // Adjust the allocation size to consider alignment.
      sizeBytes = rewriter.create<LLVM::AddOp>(loc, sizeBytes, alignment);
    }

    // Allocate the underlying buffer and store a pointer to it in the MemRef
    // descriptor.
    Type elementPtrType = this->getElementPtrType(memRefType);
    auto allocFuncOp = LLVM::lookupOrCreateMallocFn(
        allocOp->getParentOfType<ModuleOp>(), getIndexType());
    auto results = createLLVMCall(rewriter, loc, allocFuncOp, {sizeBytes},
                                  getVoidPtrType());
    Value allocatedPtr =
        rewriter.create<LLVM::BitcastOp>(loc, elementPtrType, results[0]);

    Value alignedPtr = allocatedPtr;
    if (alignment) {
      // Compute the aligned type pointer.
      Value allocatedInt =
          rewriter.create<LLVM::PtrToIntOp>(loc, getIndexType(), allocatedPtr);
      Value alignmentInt =
          createAligned(rewriter, loc, allocatedInt, alignment);
      alignedPtr =
          rewriter.create<LLVM::IntToPtrOp>(loc, elementPtrType, alignmentInt);
    }

    return std::make_tuple(allocatedPtr, alignedPtr);
  }
};

struct AlignedAllocOpLowering : public AllocLikeOpLLVMLowering {
  AlignedAllocOpLowering(LLVMTypeConverter &converter)
      : AllocLikeOpLLVMLowering(memref::AllocOp::getOperationName(),
                                converter) {}

  /// Returns the memref's element size in bytes using the data layout active
  /// at `op`.
  // TODO: there are other places where this is used. Expose publicly?
  unsigned getMemRefEltSizeInBytes(MemRefType memRefType, Operation *op) const {
    const DataLayout *layout = &defaultLayout;
    if (const DataLayoutAnalysis *analysis =
            getTypeConverter()->getDataLayoutAnalysis()) {
      layout = &analysis->getAbove(op);
    }
    Type elementType = memRefType.getElementType();
    if (auto memRefElementType = elementType.dyn_cast<MemRefType>())
      return getTypeConverter()->getMemRefDescriptorSize(memRefElementType,
                                                         *layout);
    if (auto memRefElementType = elementType.dyn_cast<UnrankedMemRefType>())
      return getTypeConverter()->getUnrankedMemRefDescriptorSize(
          memRefElementType, *layout);
    return layout->getTypeSize(elementType);
  }

  /// Returns true if the memref size in bytes is known to be a multiple of
  /// `factor`, assuming the data layout active at `op`.
  bool isMemRefSizeMultipleOf(MemRefType type, uint64_t factor,
                              Operation *op) const {
    uint64_t sizeDivisor = getMemRefEltSizeInBytes(type, op);
    for (unsigned i = 0, e = type.getRank(); i < e; i++) {
      if (ShapedType::isDynamic(type.getDimSize(i)))
        continue;
      sizeDivisor = sizeDivisor * type.getDimSize(i);
    }
    return sizeDivisor % factor == 0;
  }

  /// Returns the alignment to be used for the allocation call itself.
  /// aligned_alloc requires the allocation size to be a power of two, and the
  /// allocation size to be a multiple of the alignment.
  int64_t getAllocationAlignment(memref::AllocOp allocOp) const {
    if (Optional<uint64_t> alignment = allocOp.alignment())
      return *alignment;

    // Whenever we don't have alignment set, we will use an alignment
    // consistent with the element type; since the allocation size has to be a
    // power of two, we will bump to the next power of two if it isn't already.
    auto eltSizeBytes = getMemRefEltSizeInBytes(allocOp.getType(), allocOp);
    return std::max(kMinAlignedAllocAlignment,
                    llvm::PowerOf2Ceil(eltSizeBytes));
  }

  std::tuple<Value, Value> allocateBuffer(ConversionPatternRewriter &rewriter,
                                          Location loc, Value sizeBytes,
                                          Operation *op) const override {
    // Heap allocations.
    memref::AllocOp allocOp = cast<memref::AllocOp>(op);
    MemRefType memRefType = allocOp.getType();
    int64_t alignment = getAllocationAlignment(allocOp);
    Value allocAlignment = createIndexConstant(rewriter, loc, alignment);

    // aligned_alloc requires size to be a multiple of alignment; we will pad
    // the size to the next multiple if necessary.
    if (!isMemRefSizeMultipleOf(memRefType, alignment, op))
      sizeBytes = createAligned(rewriter, loc, sizeBytes, allocAlignment);

    Type elementPtrType = this->getElementPtrType(memRefType);
    auto allocFuncOp = LLVM::lookupOrCreateAlignedAllocFn(
        allocOp->getParentOfType<ModuleOp>(), getIndexType());
    auto results =
        createLLVMCall(rewriter, loc, allocFuncOp, {allocAlignment, sizeBytes},
                       getVoidPtrType());
    Value allocatedPtr =
        rewriter.create<LLVM::BitcastOp>(loc, elementPtrType, results[0]);

    return std::make_tuple(allocatedPtr, allocatedPtr);
  }

  /// The minimum alignment to use with aligned_alloc (has to be a power of 2).
  static constexpr uint64_t kMinAlignedAllocAlignment = 16UL;

  /// Default layout to use in absence of the corresponding analysis.
  DataLayout defaultLayout;
};

// Out of line definition, required till C++17.
constexpr uint64_t AlignedAllocOpLowering::kMinAlignedAllocAlignment;

struct AllocaOpLowering : public AllocLikeOpLLVMLowering {
  AllocaOpLowering(LLVMTypeConverter &converter)
      : AllocLikeOpLLVMLowering(memref::AllocaOp::getOperationName(),
                                converter) {}

  /// Allocates the underlying buffer directly on the stack with an
  /// `llvm.alloca`, forwarding the requested alignment. The allocated and
  /// aligned pointers are identical in this case.
  std::tuple<Value, Value> allocateBuffer(ConversionPatternRewriter &rewriter,
                                          Location loc, Value sizeBytes,
                                          Operation *op) const override {

    // With alloca, one gets a pointer to the element type right away.
    // For stack allocations.
    auto allocaOp = cast<memref::AllocaOp>(op);
    auto elementPtrType = this->getElementPtrType(allocaOp.getType());

    auto allocatedElementPtr = rewriter.create<LLVM::AllocaOp>(
        loc, elementPtrType, sizeBytes,
        allocaOp.alignment() ? *allocaOp.alignment() : 0);

    return std::make_tuple(allocatedElementPtr, allocatedElementPtr);
  }
};

struct AllocaScopeOpLowering
    : public ConvertOpToLLVMPattern<memref::AllocaScopeOp> {
  using ConvertOpToLLVMPattern<memref::AllocaScopeOp>::ConvertOpToLLVMPattern;

  LogicalResult
  matchAndRewrite(memref::AllocaScopeOp allocaScopeOp, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    OpBuilder::InsertionGuard guard(rewriter);
    Location loc = allocaScopeOp.getLoc();

    // Split the current block before the AllocaScopeOp to create the inlining
    // point.
    auto *currentBlock = rewriter.getInsertionBlock();
    auto *remainingOpsBlock =
        rewriter.splitBlock(currentBlock, rewriter.getInsertionPoint());
    Block *continueBlock;
    if (allocaScopeOp.getNumResults() == 0) {
      continueBlock = remainingOpsBlock;
    } else {
      continueBlock = rewriter.createBlock(
          remainingOpsBlock, allocaScopeOp.getResultTypes(),
          SmallVector<Location>(allocaScopeOp->getNumResults(),
                                allocaScopeOp.getLoc()));
      rewriter.create<LLVM::BrOp>(loc, ValueRange(), remainingOpsBlock);
    }

    // Inline body region.
    Block *beforeBody = &allocaScopeOp.bodyRegion().front();
    Block *afterBody = &allocaScopeOp.bodyRegion().back();
    rewriter.inlineRegionBefore(allocaScopeOp.bodyRegion(), continueBlock);

    // Save stack and then branch into the body of the region.
    rewriter.setInsertionPointToEnd(currentBlock);
    auto stackSaveOp =
        rewriter.create<LLVM::StackSaveOp>(loc, getVoidPtrType());
    rewriter.create<LLVM::BrOp>(loc, ValueRange(), beforeBody);

    // Replace the alloca_scope return with a branch that jumps out of the
    // body. Stack restore before leaving the body region.
    rewriter.setInsertionPointToEnd(afterBody);
    auto returnOp =
        cast<memref::AllocaScopeReturnOp>(afterBody->getTerminator());
    auto branchOp = rewriter.replaceOpWithNewOp<LLVM::BrOp>(
        returnOp, returnOp.results(), continueBlock);

    // Insert stack restore before jumping out of the body of the region.
    rewriter.setInsertionPoint(branchOp);
    rewriter.create<LLVM::StackRestoreOp>(loc, stackSaveOp);

    // Replace the op with the values returned from the body region.
    rewriter.replaceOp(allocaScopeOp, continueBlock->getArguments());

    return success();
  }
};

struct AssumeAlignmentOpLowering
    : public ConvertOpToLLVMPattern<memref::AssumeAlignmentOp> {
  using ConvertOpToLLVMPattern<
      memref::AssumeAlignmentOp>::ConvertOpToLLVMPattern;

  LogicalResult
  matchAndRewrite(memref::AssumeAlignmentOp op, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    Value memref = adaptor.memref();
    unsigned alignment = op.alignment();
    auto loc = op.getLoc();

    MemRefDescriptor memRefDescriptor(memref);
    Value ptr = memRefDescriptor.alignedPtr(rewriter, memref.getLoc());

    // Emit llvm.assume(memref.alignedPtr & (alignment - 1) == 0). Notice that
    // the asserted memref.alignedPtr isn't used anywhere else, as the real
    // users like load/store/views always re-extract memref.alignedPtr as they
    // get lowered.
    //
    // This relies on LLVM's CSE optimization (potentially after SROA), since
    // after CSE all memref.alignedPtr instances get de-duplicated into the
    // same pointer SSA value.
    auto intPtrType =
        getIntPtrType(memRefDescriptor.getElementPtrType().getAddressSpace());
    Value zero = createIndexAttrConstant(rewriter, loc, intPtrType, 0);
    Value mask =
        createIndexAttrConstant(rewriter, loc, intPtrType, alignment - 1);
    Value ptrValue = rewriter.create<LLVM::PtrToIntOp>(loc, intPtrType, ptr);
    rewriter.create<LLVM::AssumeOp>(
        loc, rewriter.create<LLVM::ICmpOp>(
                 loc, LLVM::ICmpPredicate::eq,
                 rewriter.create<LLVM::AndOp>(loc, ptrValue, mask), zero));

    rewriter.eraseOp(op);
    return success();
  }
};

// A `dealloc` is converted into a call to `free` on the underlying data
// buffer. The memref descriptor being an SSA value, there is no need to clean
// it up in any way.
struct DeallocOpLowering : public ConvertOpToLLVMPattern<memref::DeallocOp> {
  using ConvertOpToLLVMPattern<memref::DeallocOp>::ConvertOpToLLVMPattern;

  explicit DeallocOpLowering(LLVMTypeConverter &converter)
      : ConvertOpToLLVMPattern<memref::DeallocOp>(converter) {}

  LogicalResult
  matchAndRewrite(memref::DeallocOp op, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    // Insert the `free` declaration if it is not already present.
    auto freeFunc = LLVM::lookupOrCreateFreeFn(op->getParentOfType<ModuleOp>());
    MemRefDescriptor memref(adaptor.memref());
    Value casted = rewriter.create<LLVM::BitcastOp>(
        op.getLoc(), getVoidPtrType(),
        memref.allocatedPtr(rewriter, op.getLoc()));
    rewriter.replaceOpWithNewOp<LLVM::CallOp>(
        op, TypeRange(), SymbolRefAttr::get(freeFunc), casted);
    return success();
  }
};

// A `dim` is converted to a constant for static sizes and to an access to the
// size stored in the memref descriptor for dynamic sizes.
struct DimOpLowering : public ConvertOpToLLVMPattern<memref::DimOp> {
  using ConvertOpToLLVMPattern<memref::DimOp>::ConvertOpToLLVMPattern;

  LogicalResult
  matchAndRewrite(memref::DimOp dimOp, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    Type operandType = dimOp.source().getType();
    if (operandType.isa<UnrankedMemRefType>()) {
      rewriter.replaceOp(
          dimOp, {extractSizeOfUnrankedMemRef(
                     operandType, dimOp, adaptor.getOperands(), rewriter)});

      return success();
    }
    if (operandType.isa<MemRefType>()) {
      rewriter.replaceOp(
          dimOp, {extractSizeOfRankedMemRef(operandType, dimOp,
                                            adaptor.getOperands(), rewriter)});
      return success();
    }
    llvm_unreachable("expected MemRefType or UnrankedMemRefType");
  }

private:
  Value extractSizeOfUnrankedMemRef(Type operandType, memref::DimOp dimOp,
                                    OpAdaptor adaptor,
                                    ConversionPatternRewriter &rewriter) const {
    Location loc = dimOp.getLoc();

    auto unrankedMemRefType = operandType.cast<UnrankedMemRefType>();
    auto scalarMemRefType =
        MemRefType::get({}, unrankedMemRefType.getElementType());
    unsigned addressSpace = unrankedMemRefType.getMemorySpaceAsInt();

    // Extract pointer to the underlying ranked descriptor and bitcast it to a
    // memref<element_type> descriptor pointer to minimize the number of GEP
    // operations.
    UnrankedMemRefDescriptor unrankedDesc(adaptor.source());
    Value underlyingRankedDesc = unrankedDesc.memRefDescPtr(rewriter, loc);
    Value scalarMemRefDescPtr = rewriter.create<LLVM::BitcastOp>(
        loc,
        LLVM::LLVMPointerType::get(typeConverter->convertType(scalarMemRefType),
                                   addressSpace),
        underlyingRankedDesc);

    // Get pointer to offset field of memref<element_type> descriptor.
    Type indexPtrTy = LLVM::LLVMPointerType::get(
        getTypeConverter()->getIndexType(), addressSpace);
    Value two = rewriter.create<LLVM::ConstantOp>(
        loc, typeConverter->convertType(rewriter.getI32Type()),
        rewriter.getI32IntegerAttr(2));
    Value offsetPtr = rewriter.create<LLVM::GEPOp>(
        loc, indexPtrTy, scalarMemRefDescPtr,
        ValueRange({createIndexConstant(rewriter, loc, 0), two}));

    // The size value that we have to extract can be obtained using GEPOp with
    // `dimOp.index() + 1` as the index argument.
    Value idxPlusOne = rewriter.create<LLVM::AddOp>(
        loc, createIndexConstant(rewriter, loc, 1), adaptor.index());
    Value sizePtr = rewriter.create<LLVM::GEPOp>(loc, indexPtrTy, offsetPtr,
                                                 ValueRange({idxPlusOne}));
    return rewriter.create<LLVM::LoadOp>(loc, sizePtr);
  }

  Optional<int64_t> getConstantDimIndex(memref::DimOp dimOp) const {
    if (Optional<int64_t> idx = dimOp.getConstantIndex())
      return idx;

    if (auto constantOp = dimOp.index().getDefiningOp<LLVM::ConstantOp>())
      return constantOp.getValue()
          .cast<IntegerAttr>()
          .getValue()
          .getSExtValue();

    return llvm::None;
  }

  Value extractSizeOfRankedMemRef(Type operandType, memref::DimOp dimOp,
                                  OpAdaptor adaptor,
                                  ConversionPatternRewriter &rewriter) const {
    Location loc = dimOp.getLoc();

    // Take advantage of a constant index when available.
    MemRefType memRefType = operandType.cast<MemRefType>();
    if (Optional<int64_t> index = getConstantDimIndex(dimOp)) {
      int64_t i = index.getValue();
      if (memRefType.isDynamicDim(i)) {
        // Extract the dynamic size from the memref descriptor.
        MemRefDescriptor descriptor(adaptor.source());
        return descriptor.size(rewriter, loc, i);
      }
      // Use constant for static size.
      int64_t dimSize = memRefType.getDimSize(i);
      return createIndexConstant(rewriter, loc, dimSize);
    }
    Value index = adaptor.index();
    int64_t rank = memRefType.getRank();
    MemRefDescriptor memrefDescriptor(adaptor.source());
    return memrefDescriptor.size(rewriter, loc, index, rank);
  }
};

/// Common base for load and store operations on MemRefs. Restricts the match
/// to supported MemRef types. Provides functionality to emit code accessing a
/// specific element of the underlying data buffer.
template <typename Derived>
struct LoadStoreOpLowering : public ConvertOpToLLVMPattern<Derived> {
  using ConvertOpToLLVMPattern<Derived>::ConvertOpToLLVMPattern;
  using ConvertOpToLLVMPattern<Derived>::isConvertibleAndHasIdentityMaps;
  using Base = LoadStoreOpLowering<Derived>;

  LogicalResult match(Derived op) const override {
    MemRefType type = op.getMemRefType();
    return isConvertibleAndHasIdentityMaps(type) ? success() : failure();
  }
};

/// Wrap a llvm.cmpxchg operation in a while loop so that the operation can be
/// retried until it succeeds in atomically storing a new value into memory.
///
///      +---------------------------------+
///      |   <code before the AtomicRMWOp> |
///      |   <compute initial %loaded>     |
///      |   br loop(%loaded)              |
///      +---------------------------------+
///             |
///  -------|   |
///  |      v   v
///  |   +--------------------------------+
///  |   | loop(%loaded):                 |
///  |   |   <body contents>              |
///  |   |   %pair = cmpxchg              |
///  |   |   %ok = %pair[0]               |
///  |   |   %new = %pair[1]              |
///  |   |   cond_br %ok, end, loop(%new) |
///  |   +--------------------------------+
///  |          |        |
///  |----------         |
///                      v
///      +--------------------------------+
///      | end:                           |
///      |   <code after the AtomicRMWOp> |
///      +--------------------------------+
///
struct GenericAtomicRMWOpLowering
    : public LoadStoreOpLowering<memref::GenericAtomicRMWOp> {
  using Base::Base;

  LogicalResult
  matchAndRewrite(memref::GenericAtomicRMWOp atomicOp, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    auto loc = atomicOp.getLoc();
    Type valueType = typeConverter->convertType(atomicOp.getResult().getType());

    // Split the block into initial, loop, and ending parts.
    auto *initBlock = rewriter.getInsertionBlock();
    auto *loopBlock = rewriter.createBlock(
        initBlock->getParent(), std::next(Region::iterator(initBlock)),
        valueType, loc);
    auto *endBlock = rewriter.createBlock(
        loopBlock->getParent(), std::next(Region::iterator(loopBlock)));

    // Operations range to be moved to `endBlock`.
    auto opsToMoveStart = atomicOp->getIterator();
    auto opsToMoveEnd = initBlock->back().getIterator();

    // Compute the loaded value and branch to the loop block.
    rewriter.setInsertionPointToEnd(initBlock);
    auto memRefType = atomicOp.memref().getType().cast<MemRefType>();
    auto dataPtr = getStridedElementPtr(loc, memRefType, adaptor.memref(),
                                        adaptor.indices(), rewriter);
    Value init = rewriter.create<LLVM::LoadOp>(loc, dataPtr);
    rewriter.create<LLVM::BrOp>(loc, init, loopBlock);

    // Prepare the body of the loop block.
    rewriter.setInsertionPointToStart(loopBlock);

    // Clone the GenericAtomicRMWOp region and extract the result.
    auto loopArgument = loopBlock->getArgument(0);
    BlockAndValueMapping mapping;
    mapping.map(atomicOp.getCurrentValue(), loopArgument);
    Block &entryBlock = atomicOp.body().front();
    for (auto &nestedOp : entryBlock.without_terminator()) {
      Operation *clone = rewriter.clone(nestedOp, mapping);
      mapping.map(nestedOp.getResults(), clone->getResults());
    }
    Value result = mapping.lookup(entryBlock.getTerminator()->getOperand(0));

    // Prepare the epilog of the loop block.
    // Append the cmpxchg op to the end of the loop block.
    auto successOrdering = LLVM::AtomicOrdering::acq_rel;
    auto failureOrdering = LLVM::AtomicOrdering::monotonic;
    auto boolType = IntegerType::get(rewriter.getContext(), 1);
    auto pairType = LLVM::LLVMStructType::getLiteral(rewriter.getContext(),
                                                     {valueType, boolType});
    auto cmpxchg = rewriter.create<LLVM::AtomicCmpXchgOp>(
        loc, pairType, dataPtr, loopArgument, result, successOrdering,
        failureOrdering);
    // Extract the %new_loaded and %ok values from the pair.
    Value newLoaded = rewriter.create<LLVM::ExtractValueOp>(
        loc, valueType, cmpxchg, rewriter.getI64ArrayAttr({0}));
    Value ok = rewriter.create<LLVM::ExtractValueOp>(
        loc, boolType, cmpxchg, rewriter.getI64ArrayAttr({1}));

    // Conditionally branch to the end or back to the loop depending on %ok.
    rewriter.create<LLVM::CondBrOp>(loc, ok, endBlock, ArrayRef<Value>(),
                                    loopBlock, newLoaded);

    rewriter.setInsertionPointToEnd(endBlock);
    moveOpsRange(atomicOp.getResult(), newLoaded, std::next(opsToMoveStart),
                 std::next(opsToMoveEnd), rewriter);

    // The 'result' of the atomic_rmw op is the newly loaded value.
    rewriter.replaceOp(atomicOp, {newLoaded});

    return success();
  }

private:
  // Clones a segment of ops [start, end) and erases the original.
  void moveOpsRange(ValueRange oldResult, ValueRange newResult,
                    Block::iterator start, Block::iterator end,
                    ConversionPatternRewriter &rewriter) const {
    BlockAndValueMapping mapping;
    mapping.map(oldResult, newResult);
    SmallVector<Operation *, 2> opsToErase;
    for (auto it = start; it != end; ++it) {
      rewriter.clone(*it, mapping);
      opsToErase.push_back(&*it);
    }
    for (auto *it : opsToErase)
      rewriter.eraseOp(it);
  }
};

/// Returns the LLVM type of the global variable given the memref type `type`.
static Type convertGlobalMemrefTypeToLLVM(MemRefType type,
                                          LLVMTypeConverter &typeConverter) {
  // LLVM type for a global memref will be a multi-dimension array. For
  // declarations or uninitialized global memrefs, we can potentially flatten
  // this to a 1D array. However, for memref.global's with an initial value,
  // we do not intend to flatten the ElementsAttribute when going from std ->
  // LLVM dialect, so the LLVM type needs to be a multi-dimension array.
  Type elementType = typeConverter.convertType(type.getElementType());
  Type arrayTy = elementType;
  // Shape has the outermost dim at index 0, so need to walk it backwards.
  for (int64_t dim : llvm::reverse(type.getShape()))
    arrayTy = LLVM::LLVMArrayType::get(arrayTy, dim);
  return arrayTy;
}

/// GlobalMemrefOp is lowered to an LLVM global variable.
struct GlobalMemrefOpLowering
    : public ConvertOpToLLVMPattern<memref::GlobalOp> {
  using ConvertOpToLLVMPattern<memref::GlobalOp>::ConvertOpToLLVMPattern;

  LogicalResult
  matchAndRewrite(memref::GlobalOp global, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    MemRefType type = global.type();
    if (!isConvertibleAndHasIdentityMaps(type))
      return failure();

    Type arrayTy = convertGlobalMemrefTypeToLLVM(type, *getTypeConverter());

    LLVM::Linkage linkage =
        global.isPublic() ? LLVM::Linkage::External : LLVM::Linkage::Private;

    Attribute initialValue = nullptr;
    if (!global.isExternal() && !global.isUninitialized()) {
      auto elementsAttr = global.initial_value()->cast<ElementsAttr>();
      initialValue = elementsAttr;

      // For scalar memrefs, the global variable created is of the element
      // type, so unpack the elements attribute to extract the value.
      if (type.getRank() == 0)
        initialValue = elementsAttr.getSplatValue<Attribute>();
    }

    uint64_t alignment = global.alignment().getValueOr(0);

    auto newGlobal = rewriter.replaceOpWithNewOp<LLVM::GlobalOp>(
        global, arrayTy, global.constant(), linkage, global.sym_name(),
        initialValue, alignment, type.getMemorySpaceAsInt());
    if (!global.isExternal() && global.isUninitialized()) {
      Block *blk = new Block();
      newGlobal.getInitializerRegion().push_back(blk);
      rewriter.setInsertionPointToStart(blk);
      Value undef[] = {
          rewriter.create<LLVM::UndefOp>(global.getLoc(), arrayTy)};
      rewriter.create<LLVM::ReturnOp>(global.getLoc(), undef);
    }
    return success();
  }
};

/// GetGlobalMemrefOp is lowered into a MemRef descriptor with the pointer to
/// the first element stashed into the descriptor. This uses
/// `AllocLikeOpLowering` to reuse the MemRef descriptor construction.
struct GetGlobalMemrefOpLowering : public AllocLikeOpLLVMLowering {
  GetGlobalMemrefOpLowering(LLVMTypeConverter &converter)
      : AllocLikeOpLLVMLowering(memref::GetGlobalOp::getOperationName(),
                                converter) {}

  /// Buffer "allocation" for memref.get_global op is getting the address of
  /// the referenced global variable.
  std::tuple<Value, Value> allocateBuffer(ConversionPatternRewriter &rewriter,
                                          Location loc, Value sizeBytes,
                                          Operation *op) const override {
    auto getGlobalOp = cast<memref::GetGlobalOp>(op);
    MemRefType type = getGlobalOp.result().getType().cast<MemRefType>();
    unsigned memSpace = type.getMemorySpaceAsInt();

    Type arrayTy = convertGlobalMemrefTypeToLLVM(type, *getTypeConverter());
    auto addressOf = rewriter.create<LLVM::AddressOfOp>(
        loc, LLVM::LLVMPointerType::get(arrayTy, memSpace), getGlobalOp.name());

    // Get the address of the first element in the array by creating a GEP with
    // the address of the GV as the base, and (rank + 1) number of 0 indices.
    Type elementType = typeConverter->convertType(type.getElementType());
    Type elementPtrType = LLVM::LLVMPointerType::get(elementType, memSpace);

    SmallVector<Value> operands;
    operands.insert(operands.end(), type.getRank() + 1,
                    createIndexConstant(rewriter, loc, 0));
    auto gep =
        rewriter.create<LLVM::GEPOp>(loc, elementPtrType, addressOf, operands);

    // We do not expect the memref obtained using `memref.get_global` to be
    // ever deallocated. Set the allocated pointer to a known bad value to
    // help debug if that ever happens.
    auto intPtrType = getIntPtrType(memSpace);
    Value deadBeefConst =
        createIndexAttrConstant(rewriter, op->getLoc(), intPtrType, 0xdeadbeef);
    auto deadBeefPtr =
        rewriter.create<LLVM::IntToPtrOp>(loc, elementPtrType, deadBeefConst);

    // Both allocated and aligned pointers are the same. We could potentially
    // stash a nullptr for the allocated pointer since we do not expect any
    // dealloc.
    return std::make_tuple(deadBeefPtr, gep);
  }
};

// Load operation is lowered to obtaining a pointer to the indexed element
// and loading it.
struct LoadOpLowering : public LoadStoreOpLowering<memref::LoadOp> {
  using Base::Base;

  LogicalResult
  matchAndRewrite(memref::LoadOp loadOp, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    auto type = loadOp.getMemRefType();

    Value dataPtr = getStridedElementPtr(
        loadOp.getLoc(), type, adaptor.memref(), adaptor.indices(), rewriter);
    rewriter.replaceOpWithNewOp<LLVM::LoadOp>(loadOp, dataPtr);
    return success();
  }
};

// Store operation is lowered to obtaining a pointer to the indexed element,
// and storing the given value to it.
struct StoreOpLowering : public LoadStoreOpLowering<memref::StoreOp> {
  using Base::Base;

  LogicalResult
  matchAndRewrite(memref::StoreOp op, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    auto type = op.getMemRefType();

    Value dataPtr = getStridedElementPtr(op.getLoc(), type, adaptor.memref(),
                                         adaptor.indices(), rewriter);
    rewriter.replaceOpWithNewOp<LLVM::StoreOp>(op, adaptor.value(), dataPtr);
    return success();
  }
};

// The prefetch operation is lowered in a way similar to the load operation
// except that the llvm.prefetch operation is used for replacement.
struct PrefetchOpLowering : public LoadStoreOpLowering<memref::PrefetchOp> {
  using Base::Base;

  LogicalResult
  matchAndRewrite(memref::PrefetchOp prefetchOp, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    auto type = prefetchOp.getMemRefType();
    auto loc = prefetchOp.getLoc();

    Value dataPtr = getStridedElementPtr(loc, type, adaptor.memref(),
                                         adaptor.indices(), rewriter);

    // Replace with llvm.prefetch.
    auto llvmI32Type = typeConverter->convertType(rewriter.getIntegerType(32));
    auto isWrite = rewriter.create<LLVM::ConstantOp>(
        loc, llvmI32Type, rewriter.getI32IntegerAttr(prefetchOp.isWrite()));
    auto localityHint = rewriter.create<LLVM::ConstantOp>(
        loc, llvmI32Type,
        rewriter.getI32IntegerAttr(prefetchOp.localityHint()));
    auto isData = rewriter.create<LLVM::ConstantOp>(
        loc, llvmI32Type, rewriter.getI32IntegerAttr(prefetchOp.isDataCache()));

    rewriter.replaceOpWithNewOp<LLVM::Prefetch>(prefetchOp, dataPtr, isWrite,
                                                localityHint, isData);
    return success();
  }
};

struct RankOpLowering : public ConvertOpToLLVMPattern<memref::RankOp> {
  using ConvertOpToLLVMPattern<memref::RankOp>::ConvertOpToLLVMPattern;

  LogicalResult
  matchAndRewrite(memref::RankOp op, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    Location loc = op.getLoc();
    Type operandType = op.memref().getType();
    if (auto unrankedMemRefType = operandType.dyn_cast<UnrankedMemRefType>()) {
      UnrankedMemRefDescriptor desc(adaptor.memref());
      rewriter.replaceOp(op, {desc.rank(rewriter, loc)});
      return success();
    }
    if (auto rankedMemRefType = operandType.dyn_cast<MemRefType>()) {
      rewriter.replaceOp(
          op, {createIndexConstant(rewriter, loc, rankedMemRefType.getRank())});
      return success();
    }
    return failure();
  }
};

struct MemRefCastOpLowering : public ConvertOpToLLVMPattern<memref::CastOp> {
  using ConvertOpToLLVMPattern<memref::CastOp>::ConvertOpToLLVMPattern;

  LogicalResult match(memref::CastOp memRefCastOp) const override {
    Type srcType = memRefCastOp.getOperand().getType();
    Type dstType = memRefCastOp.getType();

    // memref::CastOp reduces to a bitcast in the ranked MemRef case and can be
    // used for type erasure. For now it must preserve the underlying element
    // type and requires the source and result types to have the same rank.
    // Therefore, perform a sanity check that the underlying structs are the
    // same. Once op semantics are relaxed we can revisit.
    if (srcType.isa<MemRefType>() && dstType.isa<MemRefType>())
      return success(typeConverter->convertType(srcType) ==
                     typeConverter->convertType(dstType));

    // At least one of the operands is an unranked type.
    assert(srcType.isa<UnrankedMemRefType>() ||
           dstType.isa<UnrankedMemRefType>());

    // Unranked to unranked cast is disallowed.
    return !(srcType.isa<UnrankedMemRefType>() &&
             dstType.isa<UnrankedMemRefType>())
               ? success()
               : failure();
  }

  void rewrite(memref::CastOp memRefCastOp, OpAdaptor adaptor,
               ConversionPatternRewriter &rewriter) const override {
    auto srcType = memRefCastOp.getOperand().getType();
    auto dstType = memRefCastOp.getType();
    auto targetStructType = typeConverter->convertType(memRefCastOp.getType());
    auto loc = memRefCastOp.getLoc();

    // For the ranked/ranked case, just keep the original descriptor.
    if (srcType.isa<MemRefType>() && dstType.isa<MemRefType>())
      return rewriter.replaceOp(memRefCastOp, {adaptor.source()});

    if (srcType.isa<MemRefType>() && dstType.isa<UnrankedMemRefType>()) {
      // Casting ranked to unranked memref type:
      // Set the rank in the destination from the memref type.
      // Allocate space on the stack and copy the src memref descriptor.
      // Set the ptr in the destination to the stack space.
      auto srcMemRefType = srcType.cast<MemRefType>();
      int64_t rank = srcMemRefType.getRank();
      // ptr = AllocaOp sizeof(MemRefDescriptor)
      auto ptr = getTypeConverter()->promoteOneMemRefDescriptor(
          loc, adaptor.source(), rewriter);
      // voidptr = BitCastOp srcType* to void*
      auto voidPtr =
          rewriter.create<LLVM::BitcastOp>(loc, getVoidPtrType(), ptr)
              .getResult();
      // rank = ConstantOp srcRank
      auto rankVal = rewriter.create<LLVM::ConstantOp>(
          loc, typeConverter->convertType(rewriter.getIntegerType(64)),
          rewriter.getI64IntegerAttr(rank));
      // undef = UndefOp
      UnrankedMemRefDescriptor memRefDesc =
          UnrankedMemRefDescriptor::undef(rewriter, loc, targetStructType);
      // d1 = InsertValueOp undef, rank, 0
      memRefDesc.setRank(rewriter, loc, rankVal);
      // d2 = InsertValueOp d1, voidptr, 1
      memRefDesc.setMemRefDescPtr(rewriter, loc, voidPtr);
      rewriter.replaceOp(memRefCastOp, (Value)memRefDesc);

    } else if (srcType.isa<UnrankedMemRefType>() && dstType.isa<MemRefType>()) {
      // Casting from unranked type to ranked.
      // The operation is assumed to be doing a correct cast. If the destination
      // type mismatches the unranked type, it is undefined behavior.
      UnrankedMemRefDescriptor memRefDesc(adaptor.source());
      // ptr = ExtractValueOp src, 1
      auto ptr = memRefDesc.memRefDescPtr(rewriter, loc);
      // castPtr = BitCastOp i8* to structTy*
      auto castPtr =
          rewriter
              .create<LLVM::BitcastOp>(
                  loc, LLVM::LLVMPointerType::get(targetStructType), ptr)
              .getResult();
      // struct = LoadOp castPtr
      auto loadOp = rewriter.create<LLVM::LoadOp>(loc, castPtr);
      rewriter.replaceOp(memRefCastOp, loadOp.getResult());
    } else {
      llvm_unreachable("Unsupported unranked memref to unranked memref cast");
    }
  }
};

/// Pattern to lower a `memref.copy` to llvm.
///
/// For memrefs with identity layouts, the copy is lowered to the llvm
/// `memcpy` intrinsic. For non-identity layouts, the copy is lowered to a call
/// to the generic `MemrefCopyFn`.
struct MemRefCopyOpLowering : public ConvertOpToLLVMPattern<memref::CopyOp> {
  using ConvertOpToLLVMPattern<memref::CopyOp>::ConvertOpToLLVMPattern;

  LogicalResult
  lowerToMemCopyIntrinsic(memref::CopyOp op, OpAdaptor adaptor,
                          ConversionPatternRewriter &rewriter) const {
    auto loc = op.getLoc();
    auto srcType = op.source().getType().dyn_cast<MemRefType>();

    MemRefDescriptor srcDesc(adaptor.source());

    // Compute number of elements.
    Value numElements = rewriter.create<LLVM::ConstantOp>(
        loc, getIndexType(), rewriter.getIndexAttr(1));
    for (int pos = 0; pos < srcType.getRank(); ++pos) {
      auto size = srcDesc.size(rewriter, loc, pos);
      numElements = rewriter.create<LLVM::MulOp>(loc, numElements, size);
    }

    // Get element size.
    auto sizeInBytes = getSizeInBytes(loc, srcType.getElementType(), rewriter);
    // Compute total.
    Value totalSize =
        rewriter.create<LLVM::MulOp>(loc, numElements, sizeInBytes);

    Value srcBasePtr = srcDesc.alignedPtr(rewriter, loc);
    MemRefDescriptor targetDesc(adaptor.target());
    Value targetBasePtr = targetDesc.alignedPtr(rewriter, loc);
    Value isVolatile = rewriter.create<LLVM::ConstantOp>(
        loc, typeConverter->convertType(rewriter.getI1Type()),
        rewriter.getBoolAttr(false));
    rewriter.create<LLVM::MemcpyOp>(loc, targetBasePtr, srcBasePtr, totalSize,
                                    isVolatile);
    rewriter.eraseOp(op);

    return success();
  }

  LogicalResult
  lowerToMemCopyFunctionCall(memref::CopyOp op, OpAdaptor adaptor,
                             ConversionPatternRewriter &rewriter) const {
    auto loc = op.getLoc();
    auto srcType = op.source().getType().cast<BaseMemRefType>();
    auto targetType = op.target().getType().cast<BaseMemRefType>();

    // First make sure we have an unranked memref descriptor representation.
    auto makeUnranked = [&, this](Value ranked, BaseMemRefType type) {
      auto rank = rewriter.create<LLVM::ConstantOp>(
          loc, getIndexType(), rewriter.getIndexAttr(type.getRank()));
      auto *typeConverter = getTypeConverter();
      auto ptr =
          typeConverter->promoteOneMemRefDescriptor(loc, ranked, rewriter);
      auto voidPtr =
          rewriter.create<LLVM::BitcastOp>(loc, getVoidPtrType(), ptr)
              .getResult();
      auto unrankedType =
          UnrankedMemRefType::get(type.getElementType(), type.getMemorySpace());
      return UnrankedMemRefDescriptor::pack(rewriter, loc, *typeConverter,
                                            unrankedType,
                                            ValueRange{rank, voidPtr});
    };

    Value unrankedSource = srcType.hasRank()
                               ? makeUnranked(adaptor.source(), srcType)
                               : adaptor.source();
    Value unrankedTarget = targetType.hasRank()
                               ? makeUnranked(adaptor.target(), targetType)
                               : adaptor.target();

    // Now promote the unranked descriptors to the stack.
    auto one = rewriter.create<LLVM::ConstantOp>(loc, getIndexType(),
                                                 rewriter.getIndexAttr(1));
    auto promote = [&](Value desc) {
      auto ptrType = LLVM::LLVMPointerType::get(desc.getType());
      auto allocated =
          rewriter.create<LLVM::AllocaOp>(loc, ptrType, ValueRange{one});
      rewriter.create<LLVM::StoreOp>(loc, desc, allocated);
      return allocated;
    };

    auto sourcePtr = promote(unrankedSource);
    auto targetPtr = promote(unrankedTarget);

    unsigned bitwidth = mlir::DataLayout::closest(op).getTypeSizeInBits(
        srcType.getElementType());
    auto elemSize = rewriter.create<LLVM::ConstantOp>(
        loc, getIndexType(), rewriter.getIndexAttr(bitwidth / 8));
    auto copyFn = LLVM::lookupOrCreateMemRefCopyFn(
        op->getParentOfType<ModuleOp>(), getIndexType(), sourcePtr.getType());
    rewriter.create<LLVM::CallOp>(loc, copyFn,
                                  ValueRange{elemSize, sourcePtr, targetPtr});
    rewriter.eraseOp(op);

    return success();
  }

  LogicalResult
  matchAndRewrite(memref::CopyOp op, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    auto srcType = op.source().getType().cast<BaseMemRefType>();
    auto targetType = op.target().getType().cast<BaseMemRefType>();

    if (srcType.hasRank() &&
        srcType.cast<MemRefType>().getLayout().isIdentity() &&
        targetType.hasRank() &&
        targetType.cast<MemRefType>().getLayout().isIdentity())
      return lowerToMemCopyIntrinsic(op, adaptor, rewriter);

    return lowerToMemCopyFunctionCall(op, adaptor, rewriter);
  }
};

/// Extracts allocated, aligned pointers and offset from a ranked or unranked
/// memref type. In the unranked case, the fields are extracted from the
/// underlying ranked descriptor.
static void extractPointersAndOffset(Location loc,
                                     ConversionPatternRewriter &rewriter,
                                     LLVMTypeConverter &typeConverter,
                                     Value originalOperand,
                                     Value convertedOperand,
                                     Value *allocatedPtr, Value *alignedPtr,
                                     Value *offset = nullptr) {
  Type operandType = originalOperand.getType();
  if (operandType.isa<MemRefType>()) {
    MemRefDescriptor desc(convertedOperand);
    *allocatedPtr = desc.allocatedPtr(rewriter, loc);
    *alignedPtr = desc.alignedPtr(rewriter, loc);
    if (offset != nullptr)
      *offset = desc.offset(rewriter, loc);
    return;
  }

  unsigned memorySpace =
      operandType.cast<UnrankedMemRefType>().getMemorySpaceAsInt();
  Type elementType = operandType.cast<UnrankedMemRefType>().getElementType();
  Type llvmElementType = typeConverter.convertType(elementType);
  Type elementPtrPtrType = LLVM::LLVMPointerType::get(
      LLVM::LLVMPointerType::get(llvmElementType, memorySpace));

  // Extract pointer to the underlying ranked memref descriptor and cast it to
  // ElemType**.
  UnrankedMemRefDescriptor unrankedDesc(convertedOperand);
  Value underlyingDescPtr = unrankedDesc.memRefDescPtr(rewriter, loc);

  *allocatedPtr = UnrankedMemRefDescriptor::allocatedPtr(
      rewriter, loc, underlyingDescPtr, elementPtrPtrType);
  *alignedPtr = UnrankedMemRefDescriptor::alignedPtr(
      rewriter, loc, typeConverter, underlyingDescPtr, elementPtrPtrType);
  if (offset != nullptr) {
    *offset = UnrankedMemRefDescriptor::offset(
        rewriter, loc, typeConverter, underlyingDescPtr, elementPtrPtrType);
  }
}

struct MemRefReinterpretCastOpLowering
    : public ConvertOpToLLVMPattern<memref::ReinterpretCastOp> {
  using ConvertOpToLLVMPattern<
      memref::ReinterpretCastOp>::ConvertOpToLLVMPattern;

  LogicalResult
  matchAndRewrite(memref::ReinterpretCastOp castOp, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    Type srcType = castOp.source().getType();

    Value descriptor;
    if (failed(convertSourceMemRefToDescriptor(rewriter, srcType, castOp,
                                               adaptor, &descriptor)))
      return failure();
    rewriter.replaceOp(castOp, {descriptor});
    return success();
  }

private:
  LogicalResult convertSourceMemRefToDescriptor(
      ConversionPatternRewriter &rewriter, Type srcType,
      memref::ReinterpretCastOp castOp,
      memref::ReinterpretCastOp::Adaptor adaptor, Value *descriptor) const {
    MemRefType targetMemRefType =
        castOp.getResult().getType().cast<MemRefType>();
    auto llvmTargetDescriptorTy = typeConverter->convertType(targetMemRefType)
                                      .dyn_cast_or_null<LLVM::LLVMStructType>();
    if (!llvmTargetDescriptorTy)
      return failure();

    // Create descriptor.
    Location loc = castOp.getLoc();
    auto desc = MemRefDescriptor::undef(rewriter, loc, llvmTargetDescriptorTy);

    // Set allocated and aligned pointers.
    Value allocatedPtr, alignedPtr;
    extractPointersAndOffset(loc, rewriter, *getTypeConverter(),
                             castOp.source(), adaptor.source(), &allocatedPtr,
                             &alignedPtr);
    desc.setAllocatedPtr(rewriter, loc, allocatedPtr);
    desc.setAlignedPtr(rewriter, loc, alignedPtr);

    // Set offset.
    if (castOp.isDynamicOffset(0))
      desc.setOffset(rewriter, loc, adaptor.offsets()[0]);
    else
      desc.setConstantOffset(rewriter, loc, castOp.getStaticOffset(0));

    // Set sizes and strides.
    unsigned dynSizeId = 0;
    unsigned dynStrideId = 0;
    for (unsigned i = 0, e = targetMemRefType.getRank(); i < e; ++i) {
      if (castOp.isDynamicSize(i))
        desc.setSize(rewriter, loc, i, adaptor.sizes()[dynSizeId++]);
      else
        desc.setConstantSize(rewriter, loc, i, castOp.getStaticSize(i));

      if (castOp.isDynamicStride(i))
        desc.setStride(rewriter, loc, i, adaptor.strides()[dynStrideId++]);
      else
        desc.setConstantStride(rewriter, loc, i, castOp.getStaticStride(i));
    }
    *descriptor = desc;
    return success();
  }
};

struct MemRefReshapeOpLowering
    : public ConvertOpToLLVMPattern<memref::ReshapeOp> {
  using ConvertOpToLLVMPattern<memref::ReshapeOp>::ConvertOpToLLVMPattern;

  LogicalResult
  matchAndRewrite(memref::ReshapeOp reshapeOp, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    Type srcType = reshapeOp.source().getType();

    Value descriptor;
    if (failed(convertSourceMemRefToDescriptor(rewriter, srcType, reshapeOp,
                                               adaptor, &descriptor)))
      return failure();
    rewriter.replaceOp(reshapeOp, {descriptor});
    return success();
  }

private:
  LogicalResult
  convertSourceMemRefToDescriptor(ConversionPatternRewriter &rewriter,
                                  Type srcType, memref::ReshapeOp reshapeOp,
                                  memref::ReshapeOp::Adaptor adaptor,
                                  Value *descriptor) const {
    // Conversion for statically-known shape args is performed via
    // `memref_reinterpret_cast`.
    auto shapeMemRefType = reshapeOp.shape().getType().cast<MemRefType>();
    if (shapeMemRefType.hasStaticShape())
      return failure();

    // The shape is a rank-1 tensor with unknown length.
    Location loc = reshapeOp.getLoc();
    MemRefDescriptor shapeDesc(adaptor.shape());
    Value resultRank = shapeDesc.size(rewriter, loc, 0);

    // Extract address space and element type.
    auto targetType =
        reshapeOp.getResult().getType().cast<UnrankedMemRefType>();
    unsigned addressSpace = targetType.getMemorySpaceAsInt();
    Type elementType = targetType.getElementType();

    // Create the unranked memref descriptor that holds the ranked one. The
    // inner descriptor is allocated on stack.
    auto targetDesc = UnrankedMemRefDescriptor::undef(
        rewriter, loc, typeConverter->convertType(targetType));
    targetDesc.setRank(rewriter, loc, resultRank);
    SmallVector<Value, 4> sizes;
    UnrankedMemRefDescriptor::computeSizes(rewriter, loc, *getTypeConverter(),
                                           targetDesc, sizes);
    Value underlyingDescPtr = rewriter.create<LLVM::AllocaOp>(
        loc, getVoidPtrType(), sizes.front(), llvm::None);
    targetDesc.setMemRefDescPtr(rewriter, loc, underlyingDescPtr);

    // Extract pointers and offset from the source memref.
    Value allocatedPtr, alignedPtr, offset;
    extractPointersAndOffset(loc, rewriter, *getTypeConverter(),
                             reshapeOp.source(), adaptor.source(),
                             &allocatedPtr, &alignedPtr, &offset);

    // Set pointers and offset.
    Type llvmElementType = typeConverter->convertType(elementType);
    auto elementPtrPtrType = LLVM::LLVMPointerType::get(
        LLVM::LLVMPointerType::get(llvmElementType, addressSpace));
    UnrankedMemRefDescriptor::setAllocatedPtr(rewriter, loc, underlyingDescPtr,
                                              elementPtrPtrType, allocatedPtr);
    UnrankedMemRefDescriptor::setAlignedPtr(rewriter, loc, *getTypeConverter(),
                                            underlyingDescPtr,
                                            elementPtrPtrType, alignedPtr);
    UnrankedMemRefDescriptor::setOffset(rewriter, loc, *getTypeConverter(),
                                        underlyingDescPtr, elementPtrPtrType,
                                        offset);

    // Use the offset pointer as base for further addressing. Copy over the new
    // shape and compute strides. For this, we create a loop from rank-1 to 0.
    Value targetSizesBase = UnrankedMemRefDescriptor::sizeBasePtr(
        rewriter, loc, *getTypeConverter(), underlyingDescPtr,
        elementPtrPtrType);
    Value targetStridesBase = UnrankedMemRefDescriptor::strideBasePtr(
        rewriter, loc, *getTypeConverter(), targetSizesBase, resultRank);
    Value shapeOperandPtr = shapeDesc.alignedPtr(rewriter, loc);
    Value oneIndex = createIndexConstant(rewriter, loc, 1);
    Value resultRankMinusOne =
        rewriter.create<LLVM::SubOp>(loc, resultRank, oneIndex);

    Block *initBlock = rewriter.getInsertionBlock();
    Type indexType = getTypeConverter()->getIndexType();
    Block::iterator remainingOpsIt = std::next(rewriter.getInsertionPoint());

    Block *condBlock = rewriter.createBlock(initBlock->getParent(), {},
                                            {indexType, indexType}, {loc, loc});

    // Move the remaining initBlock ops to condBlock.
    Block *remainingBlock = rewriter.splitBlock(initBlock, remainingOpsIt);
    rewriter.mergeBlocks(remainingBlock, condBlock, ValueRange());

    rewriter.setInsertionPointToEnd(initBlock);
    rewriter.create<LLVM::BrOp>(loc, ValueRange({resultRankMinusOne, oneIndex}),
                                condBlock);
    rewriter.setInsertionPointToStart(condBlock);
    Value indexArg = condBlock->getArgument(0);
    Value strideArg = condBlock->getArgument(1);

    Value zeroIndex = createIndexConstant(rewriter, loc, 0);
    Value pred = rewriter.create<LLVM::ICmpOp>(
        loc, IntegerType::get(rewriter.getContext(), 1),
        LLVM::ICmpPredicate::sge, indexArg, zeroIndex);

    Block *bodyBlock =
        rewriter.splitBlock(condBlock, rewriter.getInsertionPoint());
    rewriter.setInsertionPointToStart(bodyBlock);

    // Copy size from shape to descriptor.
    Type llvmIndexPtrType = LLVM::LLVMPointerType::get(indexType);
    Value sizeLoadGep = rewriter.create<LLVM::GEPOp>(
        loc, llvmIndexPtrType, shapeOperandPtr, ValueRange{indexArg});
    Value size = rewriter.create<LLVM::LoadOp>(loc, sizeLoadGep);
    UnrankedMemRefDescriptor::setSize(rewriter, loc, *getTypeConverter(),
                                      targetSizesBase, indexArg, size);

    // Write stride value and compute next one.
    UnrankedMemRefDescriptor::setStride(rewriter, loc, *getTypeConverter(),
                                        targetStridesBase, indexArg, strideArg);
    Value nextStride = rewriter.create<LLVM::MulOp>(loc, strideArg, size);

    // Decrement loop counter and branch back.
    Value decrement = rewriter.create<LLVM::SubOp>(loc, indexArg, oneIndex);
    rewriter.create<LLVM::BrOp>(loc, ValueRange({decrement, nextStride}),
                                condBlock);

    Block *remainder =
        rewriter.splitBlock(bodyBlock, rewriter.getInsertionPoint());

    // Hook up the cond exit to the remainder.
    rewriter.setInsertionPointToEnd(condBlock);
    rewriter.create<LLVM::CondBrOp>(loc, pred, bodyBlock, llvm::None, remainder,
                                    llvm::None);

    // Reset position to beginning of new remainder block.
    rewriter.setInsertionPointToStart(remainder);

    *descriptor = targetDesc;
    return success();
  }
};

/// Helper function to convert a vector of `OpFoldResult`s into a vector of
/// `Value`s.
static SmallVector<Value> getAsValues(OpBuilder &b, Location loc,
                                      Type &llvmIndexType,
                                      ArrayRef<OpFoldResult> valueOrAttrVec) {
  return llvm::to_vector<4>(
      llvm::map_range(valueOrAttrVec, [&](OpFoldResult value) -> Value {
        if (auto attr = value.dyn_cast<Attribute>())
          return b.create<LLVM::ConstantOp>(loc, llvmIndexType, attr);
        return value.get<Value>();
      }));
}

/// Compute a map that for a given dimension of the expanded type gives the
/// dimension in the collapsed type it maps to. Essentially it is the inverse
/// of the `reassociation` maps.
static DenseMap<int64_t, int64_t>
getExpandedDimToCollapsedDimMap(ArrayRef<ReassociationIndices> reassociation) {
  llvm::DenseMap<int64_t, int64_t> expandedDimToCollapsedDim;
  for (auto &en : enumerate(reassociation)) {
    for (auto dim : en.value())
      expandedDimToCollapsedDim[dim] = en.index();
  }
  return expandedDimToCollapsedDim;
}

static OpFoldResult
getExpandedOutputDimSize(OpBuilder &b, Location loc, Type &llvmIndexType,
                         int64_t outDimIndex, ArrayRef<int64_t> outStaticShape,
                         MemRefDescriptor &inDesc,
                         ArrayRef<int64_t> inStaticShape,
                         ArrayRef<ReassociationIndices> reassocation,
                         DenseMap<int64_t, int64_t> &outDimToInDimMap) {
  int64_t outDimSize = outStaticShape[outDimIndex];
  if (!ShapedType::isDynamic(outDimSize))
    return b.getIndexAttr(outDimSize);

  // Calculate the multiplication of all the out dim sizes except the
  // current dim.
  int64_t inDimIndex = outDimToInDimMap[outDimIndex];
  int64_t otherDimSizesMul = 1;
  for (auto otherDimIndex : reassocation[inDimIndex]) {
    if (otherDimIndex == static_cast<unsigned>(outDimIndex))
      continue;
    int64_t otherDimSize = outStaticShape[otherDimIndex];
    assert(!ShapedType::isDynamic(otherDimSize) &&
           "single dimension cannot be expanded into multiple dynamic "
           "dimensions");
    otherDimSizesMul *= otherDimSize;
  }

  // outDimSize = inDimSize / otherOutDimSizesMul
  int64_t inDimSize = inStaticShape[inDimIndex];
  Value inDimSizeDynamic =
      ShapedType::isDynamic(inDimSize)
          ?
          inDesc.size(b, loc, inDimIndex)
          : b.create<LLVM::ConstantOp>(loc, llvmIndexType,
                                       b.getIndexAttr(inDimSize));
  Value outDimSizeDynamic = b.create<LLVM::SDivOp>(
      loc, inDimSizeDynamic,
      b.create<LLVM::ConstantOp>(loc, llvmIndexType,
                                 b.getIndexAttr(otherDimSizesMul)));
  return outDimSizeDynamic;
}

static OpFoldResult getCollapsedOutputDimSize(
    OpBuilder &b, Location loc, Type &llvmIndexType, int64_t outDimIndex,
    int64_t outDimSize, ArrayRef<int64_t> inStaticShape,
    MemRefDescriptor &inDesc, ArrayRef<ReassociationIndices> reassocation) {
  if (!ShapedType::isDynamic(outDimSize))
    return b.getIndexAttr(outDimSize);

  Value c1 = b.create<LLVM::ConstantOp>(loc, llvmIndexType, b.getIndexAttr(1));
  Value outDimSizeDynamic = c1;
  for (auto inDimIndex : reassocation[outDimIndex]) {
    int64_t inDimSize = inStaticShape[inDimIndex];
    Value inDimSizeDynamic =
        ShapedType::isDynamic(inDimSize)
            ? inDesc.size(b, loc, inDimIndex)
            : b.create<LLVM::ConstantOp>(loc, llvmIndexType,
                                         b.getIndexAttr(inDimSize));
    outDimSizeDynamic =
        b.create<LLVM::MulOp>(loc, outDimSizeDynamic, inDimSizeDynamic);
  }
  return outDimSizeDynamic;
}

static SmallVector<OpFoldResult, 4>
getCollapsedOutputShape(OpBuilder &b, Location loc, Type &llvmIndexType,
                        ArrayRef<ReassociationIndices> reassocation,
                        ArrayRef<int64_t> inStaticShape,
                        MemRefDescriptor &inDesc,
                        ArrayRef<int64_t> outStaticShape) {
  return llvm::to_vector<4>(llvm::map_range(
      llvm::seq<int64_t>(0, outStaticShape.size()), [&](int64_t outDimIndex) {
        return getCollapsedOutputDimSize(b, loc, llvmIndexType, outDimIndex,
                                         outStaticShape[outDimIndex],
                                         inStaticShape, inDesc, reassocation);
      }));
}

static SmallVector<OpFoldResult, 4>
getExpandedOutputShape(OpBuilder &b, Location loc, Type &llvmIndexType,
                       ArrayRef<ReassociationIndices> reassocation,
                       ArrayRef<int64_t> inStaticShape,
                       MemRefDescriptor &inDesc,
                       ArrayRef<int64_t> outStaticShape) {
  DenseMap<int64_t, int64_t> outDimToInDimMap =
      getExpandedDimToCollapsedDimMap(reassocation);
  return llvm::to_vector<4>(llvm::map_range(
      llvm::seq<int64_t>(0, outStaticShape.size()), [&](int64_t outDimIndex) {
        return getExpandedOutputDimSize(b, loc, llvmIndexType, outDimIndex,
                                        outStaticShape, inDesc, inStaticShape,
                                        reassocation, outDimToInDimMap);
      }));
}

static SmallVector<Value>
getDynamicOutputShape(OpBuilder &b, Location loc, Type &llvmIndexType,
                      ArrayRef<ReassociationIndices> reassocation,
                      ArrayRef<int64_t> inStaticShape, MemRefDescriptor &inDesc,
                      ArrayRef<int64_t> outStaticShape) {
  return outStaticShape.size() < inStaticShape.size()
             ? getAsValues(b, loc, llvmIndexType,
                           getCollapsedOutputShape(b, loc, llvmIndexType,
                                                   reassocation, inStaticShape,
                                                   inDesc, outStaticShape))
             : getAsValues(b, loc, llvmIndexType,
                           getExpandedOutputShape(b, loc, llvmIndexType,
                                                  reassocation, inStaticShape,
                                                  inDesc, outStaticShape));
}

// ReshapeOp creates a new view descriptor of the proper rank.
// For now, the only conversion supported is for target MemRef with static
// sizes and strides.
template <typename ReshapeOp>
class ReassociatingReshapeOpConversion
    : public ConvertOpToLLVMPattern<ReshapeOp> {
public:
  using ConvertOpToLLVMPattern<ReshapeOp>::ConvertOpToLLVMPattern;
  using ReshapeOpAdaptor = typename ReshapeOp::Adaptor;

  LogicalResult
  matchAndRewrite(ReshapeOp reshapeOp, typename ReshapeOp::Adaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    MemRefType dstType = reshapeOp.getResultType();
    MemRefType srcType = reshapeOp.getSrcType();

    // The condition on the layouts can be ignored when all shapes are static.
    if (!srcType.hasStaticShape() || !dstType.hasStaticShape()) {
      if (!srcType.getLayout().isIdentity() ||
          !dstType.getLayout().isIdentity()) {
        return rewriter.notifyMatchFailure(
            reshapeOp, "only empty layout map is supported");
      }
    }

    int64_t offset;
    SmallVector<int64_t, 4> strides;
    if (failed(getStridesAndOffset(dstType, strides, offset))) {
      return rewriter.notifyMatchFailure(
          reshapeOp, "failed to get stride and offset exprs");
    }

    MemRefDescriptor srcDesc(adaptor.src());
    Location loc = reshapeOp->getLoc();
    auto dstDesc = MemRefDescriptor::undef(
        rewriter, loc, this->typeConverter->convertType(dstType));
    dstDesc.setAllocatedPtr(rewriter, loc, srcDesc.allocatedPtr(rewriter, loc));
    dstDesc.setAlignedPtr(rewriter, loc, srcDesc.alignedPtr(rewriter, loc));
    dstDesc.setOffset(rewriter, loc, srcDesc.offset(rewriter, loc));

    ArrayRef<int64_t> srcStaticShape = srcType.getShape();
    ArrayRef<int64_t> dstStaticShape = dstType.getShape();
    Type llvmIndexType =
        this->typeConverter->convertType(rewriter.getIndexType());
    SmallVector<Value> dstShape = getDynamicOutputShape(
        rewriter, loc, llvmIndexType, reshapeOp.getReassociationIndices(),
        srcStaticShape, srcDesc, dstStaticShape);
    for (auto &en : llvm::enumerate(dstShape))
      dstDesc.setSize(rewriter, loc, en.index(), en.value());

    auto isStaticStride = [](int64_t stride) {
      return !ShapedType::isDynamicStrideOrOffset(stride);
    };
    if (llvm::all_of(strides, isStaticStride)) {
      for (auto &en : llvm::enumerate(strides))
        dstDesc.setConstantStride(rewriter, loc, en.index(), en.value());
    } else {
      Value c1 = rewriter.create<LLVM::ConstantOp>(loc, llvmIndexType,
                                                   rewriter.getIndexAttr(1));
      Value stride = c1;
      for (auto dimIndex :
           llvm::reverse(llvm::seq<int64_t>(0, dstShape.size()))) {
        dstDesc.setStride(rewriter, loc, dimIndex, stride);
        stride = rewriter.create<LLVM::MulOp>(loc, dstShape[dimIndex], stride);
      }
    }
    rewriter.replaceOp(reshapeOp, {dstDesc});
    return success();
  }
};

/// Conversion pattern that transforms a subview op into:
///   1. An `llvm.mlir.undef` operation to create a memref descriptor
///   2. Updates to the descriptor to introduce the data ptr, offset, size
///      and stride.
/// The subview op is replaced by the descriptor.
struct SubViewOpLowering : public ConvertOpToLLVMPattern<memref::SubViewOp> {
  using ConvertOpToLLVMPattern<memref::SubViewOp>::ConvertOpToLLVMPattern;

  LogicalResult
  matchAndRewrite(memref::SubViewOp subViewOp, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    auto loc = subViewOp.getLoc();

    auto sourceMemRefType = subViewOp.source().getType().cast<MemRefType>();
    auto sourceElementTy =
        typeConverter->convertType(sourceMemRefType.getElementType());

    auto viewMemRefType = subViewOp.getType();
    auto inferredType =
        memref::SubViewOp::inferResultType(
            subViewOp.getSourceType(),
            extractFromI64ArrayAttr(subViewOp.static_offsets()),
            extractFromI64ArrayAttr(subViewOp.static_sizes()),
            extractFromI64ArrayAttr(subViewOp.static_strides()))
            .cast<MemRefType>();
    auto targetElementTy =
        typeConverter->convertType(viewMemRefType.getElementType());
    auto targetDescTy = typeConverter->convertType(viewMemRefType);
    if (!sourceElementTy || !targetDescTy || !targetElementTy ||
        !LLVM::isCompatibleType(sourceElementTy) ||
        !LLVM::isCompatibleType(targetElementTy) ||
        !LLVM::isCompatibleType(targetDescTy))
      return failure();

    // Extract the offset and strides from the type.
    int64_t offset;
    SmallVector<int64_t, 4> strides;
    auto successStrides = getStridesAndOffset(inferredType, strides, offset);
    if (failed(successStrides))
      return failure();

    // Create the descriptor.
    if (!LLVM::isCompatibleType(adaptor.getOperands().front().getType()))
      return failure();
    MemRefDescriptor sourceMemRef(adaptor.getOperands().front());
    auto targetMemRef = MemRefDescriptor::undef(rewriter, loc, targetDescTy);

    // Copy the buffer pointer from the old descriptor to the new one.
    Value extracted = sourceMemRef.allocatedPtr(rewriter, loc);
    Value bitcastPtr = rewriter.create<LLVM::BitcastOp>(
        loc,
        LLVM::LLVMPointerType::get(targetElementTy,
                                   viewMemRefType.getMemorySpaceAsInt()),
        extracted);
    targetMemRef.setAllocatedPtr(rewriter, loc, bitcastPtr);

    // Copy the aligned pointer from the old descriptor to the new one.
    extracted = sourceMemRef.alignedPtr(rewriter, loc);
    bitcastPtr = rewriter.create<LLVM::BitcastOp>(
        loc,
        LLVM::LLVMPointerType::get(targetElementTy,
                                   viewMemRefType.getMemorySpaceAsInt()),
        extracted);
    targetMemRef.setAlignedPtr(rewriter, loc, bitcastPtr);

    size_t inferredShapeRank = inferredType.getRank();
    size_t resultShapeRank = viewMemRefType.getRank();

    // Extract strides needed to compute offset.
    SmallVector<Value, 4> strideValues;
    strideValues.reserve(inferredShapeRank);
    for (unsigned i = 0; i < inferredShapeRank; ++i)
      strideValues.push_back(sourceMemRef.stride(rewriter, loc, i));

    // Offset.
    auto llvmIndexType = typeConverter->convertType(rewriter.getIndexType());
    if (!ShapedType::isDynamicStrideOrOffset(offset)) {
      targetMemRef.setConstantOffset(rewriter, loc, offset);
    } else {
      Value baseOffset = sourceMemRef.offset(rewriter, loc);
      // `inferredShapeRank` may be larger than the number of offset operands
      // because of trailing semantics. In this case, the offset is guaranteed
      // to be interpreted as 0 and we can just skip the extra dimensions.
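      // In other words, for subview offsets (o0, o1, ...) and source strides
      // (s0, s1, ...), the new offset is
      //   baseOffset + o0 * s0 + o1 * s1 + ...
      // built below with one mul/add pair per dimension.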
      for (unsigned i = 0, e = std::min(inferredShapeRank,
                                        subViewOp.getMixedOffsets().size());
           i < e; ++i) {
        Value offset =
            // TODO: need OpFoldResult ODS adaptor to clean this up.
            subViewOp.isDynamicOffset(i)
                ? adaptor.getOperands()[subViewOp.getIndexOfDynamicOffset(i)]
                : rewriter.create<LLVM::ConstantOp>(
                      loc, llvmIndexType,
                      rewriter.getI64IntegerAttr(subViewOp.getStaticOffset(i)));
        Value mul = rewriter.create<LLVM::MulOp>(loc, offset, strideValues[i]);
        baseOffset = rewriter.create<LLVM::AddOp>(loc, baseOffset, mul);
      }
      targetMemRef.setOffset(rewriter, loc, baseOffset);
    }

    // Update sizes and strides.
    SmallVector<OpFoldResult> mixedSizes = subViewOp.getMixedSizes();
    SmallVector<OpFoldResult> mixedStrides = subViewOp.getMixedStrides();
    assert(mixedSizes.size() == mixedStrides.size() &&
           "expected sizes and strides of equal length");
    llvm::SmallDenseSet<unsigned> unusedDims = subViewOp.getDroppedDims();
    for (int i = inferredShapeRank - 1, j = resultShapeRank - 1;
         i >= 0 && j >= 0; --i) {
      if (unusedDims.contains(i))
        continue;

      // `i` may exceed the number of entries in subViewOp.getMixedSizes()
      // because of trailing semantics. In this case, the size is guaranteed to
      // be the full source dimension and the stride to be 1.
      Value size, stride;
      if (static_cast<unsigned>(i) >= mixedSizes.size()) {
        // If the static size is available, use it directly. This is similar to
        // the folding of dim(constant-op) but removes the need for dim to be
        // aware of LLVM constants and for this pass to be aware of std
        // constants.
        int64_t staticSize =
            subViewOp.source().getType().cast<MemRefType>().getShape()[i];
        if (staticSize != ShapedType::kDynamicSize) {
          size = rewriter.create<LLVM::ConstantOp>(
              loc, llvmIndexType, rewriter.getI64IntegerAttr(staticSize));
        } else {
          Value pos = rewriter.create<LLVM::ConstantOp>(
              loc, llvmIndexType, rewriter.getI64IntegerAttr(i));
          Value dim =
              rewriter.create<memref::DimOp>(loc, subViewOp.source(), pos);
          auto cast = rewriter.create<UnrealizedConversionCastOp>(
              loc, llvmIndexType, dim);
          size = cast.getResult(0);
        }
        stride = rewriter.create<LLVM::ConstantOp>(
            loc, llvmIndexType, rewriter.getI64IntegerAttr(1));
      } else {
        // TODO: need OpFoldResult ODS adaptor to clean this up.
        size =
            subViewOp.isDynamicSize(i)
                ? adaptor.getOperands()[subViewOp.getIndexOfDynamicSize(i)]
                : rewriter.create<LLVM::ConstantOp>(
                      loc, llvmIndexType,
                      rewriter.getI64IntegerAttr(subViewOp.getStaticSize(i)));
        if (!ShapedType::isDynamicStrideOrOffset(strides[i])) {
          stride = rewriter.create<LLVM::ConstantOp>(
              loc, llvmIndexType, rewriter.getI64IntegerAttr(strides[i]));
        } else {
          stride =
              subViewOp.isDynamicStride(i)
                  ? adaptor.getOperands()[subViewOp.getIndexOfDynamicStride(i)]
                  : rewriter.create<LLVM::ConstantOp>(
                        loc, llvmIndexType,
                        rewriter.getI64IntegerAttr(
                            subViewOp.getStaticStride(i)));
          stride = rewriter.create<LLVM::MulOp>(loc, stride, strideValues[i]);
        }
      }
      targetMemRef.setSize(rewriter, loc, j, size);
      targetMemRef.setStride(rewriter, loc, j, stride);
      j--;
    }

    rewriter.replaceOp(subViewOp, {targetMemRef});
    return success();
  }
};

/// Conversion pattern that transforms a transpose op into:
///   1. An `llvm.mlir.undef` operation to create a memref descriptor.
///   2. Updates to the descriptor to introduce the data pointer, offset, and
///      the permuted sizes and strides.
/// The transpose op is replaced by the descriptor.
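// Illustrative sketch (hypothetical IR; `#map` abbreviates the permuted
// strided layout): for
//   %1 = memref.transpose %0 (i, j) -> (j, i)
//        : memref<?x?xf32> to memref<?x?xf32, #map>
// the result descriptor reuses the pointers and offset of %0 and swaps the two
// size/stride pairs; no data is moved.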
class TransposeOpLowering : public ConvertOpToLLVMPattern<memref::TransposeOp> {
public:
  using ConvertOpToLLVMPattern<memref::TransposeOp>::ConvertOpToLLVMPattern;

  LogicalResult
  matchAndRewrite(memref::TransposeOp transposeOp, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    auto loc = transposeOp.getLoc();
    MemRefDescriptor viewMemRef(adaptor.in());

    // No permutation, early exit.
    if (transposeOp.permutation().isIdentity())
      return rewriter.replaceOp(transposeOp, {viewMemRef}), success();

    auto targetMemRef = MemRefDescriptor::undef(
        rewriter, loc, typeConverter->convertType(transposeOp.getShapedType()));

    // Copy the base and aligned pointers from the old descriptor to the new
    // one.
    targetMemRef.setAllocatedPtr(rewriter, loc,
                                 viewMemRef.allocatedPtr(rewriter, loc));
    targetMemRef.setAlignedPtr(rewriter, loc,
                               viewMemRef.alignedPtr(rewriter, loc));

    // Copy the offset from the old descriptor to the new one.
    targetMemRef.setOffset(rewriter, loc, viewMemRef.offset(rewriter, loc));

    // Iterate over the dimensions and apply size/stride permutation.
    for (const auto &en :
         llvm::enumerate(transposeOp.permutation().getResults())) {
      int sourcePos = en.index();
      int targetPos = en.value().cast<AffineDimExpr>().getPosition();
      targetMemRef.setSize(rewriter, loc, targetPos,
                           viewMemRef.size(rewriter, loc, sourcePos));
      targetMemRef.setStride(rewriter, loc, targetPos,
                             viewMemRef.stride(rewriter, loc, sourcePos));
    }

    rewriter.replaceOp(transposeOp, {targetMemRef});
    return success();
  }
};

/// Conversion pattern that transforms a view op into:
///   1. An `llvm.mlir.undef` operation to create a memref descriptor
///   2. Updates to the descriptor to introduce the data ptr, offset, size
///      and stride.
/// The view op is replaced by the descriptor.
struct ViewOpLowering : public ConvertOpToLLVMPattern<memref::ViewOp> {
  using ConvertOpToLLVMPattern<memref::ViewOp>::ConvertOpToLLVMPattern;

  // Build and return the value for the idx^th shape dimension, either by
  // returning the constant shape dimension or by looking up the proper dynamic
  // size operand.
  Value getSize(ConversionPatternRewriter &rewriter, Location loc,
                ArrayRef<int64_t> shape, ValueRange dynamicSizes,
                unsigned idx) const {
    assert(idx < shape.size());
    if (!ShapedType::isDynamic(shape[idx]))
      return createIndexConstant(rewriter, loc, shape[idx]);
    // Count the number of dynamic dims in range [0, idx) to find which dynamic
    // size operand corresponds to `idx`.
    unsigned nDynamic = llvm::count_if(shape.take_front(idx), [](int64_t v) {
      return ShapedType::isDynamic(v);
    });
    return dynamicSizes[nDynamic];
  }

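  // Worked example (hypothetical shape): for shape [4, ?, 8, ?] and idx == 3,
  // one dynamic dimension precedes idx, so the result is dynamicSizes[1].
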
  // Build and return the idx^th stride, either by returning the constant
  // stride or by computing the dynamic stride from the current `runningStride`
  // and `nextSize`. The caller should keep a running stride and update it with
  // the result returned by this function.
  Value getStride(ConversionPatternRewriter &rewriter, Location loc,
                  ArrayRef<int64_t> strides, Value nextSize,
                  Value runningStride, unsigned idx) const {
    assert(idx < strides.size());
    if (!ShapedType::isDynamicStrideOrOffset(strides[idx]))
      return createIndexConstant(rewriter, loc, strides[idx]);
    if (nextSize)
      return runningStride
                 ? rewriter.create<LLVM::MulOp>(loc, runningStride, nextSize)
                 : nextSize;
    assert(!runningStride);
    return createIndexConstant(rewriter, loc, 1);
  }

  LogicalResult
  matchAndRewrite(memref::ViewOp viewOp, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    auto loc = viewOp.getLoc();

    auto viewMemRefType = viewOp.getType();
    auto targetElementTy =
        typeConverter->convertType(viewMemRefType.getElementType());
    auto targetDescTy = typeConverter->convertType(viewMemRefType);
    if (!targetDescTy || !targetElementTy ||
        !LLVM::isCompatibleType(targetElementTy) ||
        !LLVM::isCompatibleType(targetDescTy))
      return viewOp.emitWarning("Target descriptor type not converted to LLVM"),
             failure();

    int64_t offset;
    SmallVector<int64_t, 4> strides;
    auto successStrides = getStridesAndOffset(viewMemRefType, strides, offset);
    if (failed(successStrides))
      return viewOp.emitWarning("cannot cast to non-strided shape"), failure();
    assert(offset == 0 && "expected offset to be 0");

    // Create the descriptor.
    MemRefDescriptor sourceMemRef(adaptor.source());
    auto targetMemRef = MemRefDescriptor::undef(rewriter, loc, targetDescTy);

    // Field 1: Copy the allocated pointer, used for malloc/free.
    Value allocatedPtr = sourceMemRef.allocatedPtr(rewriter, loc);
    auto srcMemRefType = viewOp.source().getType().cast<MemRefType>();
    Value bitcastPtr = rewriter.create<LLVM::BitcastOp>(
        loc,
        LLVM::LLVMPointerType::get(targetElementTy,
                                   srcMemRefType.getMemorySpaceAsInt()),
        allocatedPtr);
    targetMemRef.setAllocatedPtr(rewriter, loc, bitcastPtr);

    // Field 2: Copy the actual aligned pointer to payload.
    Value alignedPtr = sourceMemRef.alignedPtr(rewriter, loc);
    alignedPtr = rewriter.create<LLVM::GEPOp>(loc, alignedPtr.getType(),
                                              alignedPtr, adaptor.byte_shift());
    bitcastPtr = rewriter.create<LLVM::BitcastOp>(
        loc,
        LLVM::LLVMPointerType::get(targetElementTy,
                                   srcMemRefType.getMemorySpaceAsInt()),
        alignedPtr);
    targetMemRef.setAlignedPtr(rewriter, loc, bitcastPtr);

    // Field 3: The offset in the resulting type must be 0. This is because of
    // the type change: an offset on srcType* may not be expressible as an
    // offset on dstType*.
    targetMemRef.setOffset(rewriter, loc,
                           createIndexConstant(rewriter, loc, offset));

    // Early exit for 0-D corner case.
    if (viewMemRefType.getRank() == 0)
      return rewriter.replaceOp(viewOp, {targetMemRef}), success();

    // Fields 4 and 5: Update sizes and strides.
    if (strides.back() != 1)
      return viewOp.emitWarning("cannot cast to non-contiguous shape"),
             failure();
    Value stride = nullptr, nextSize = nullptr;
    for (int i = viewMemRefType.getRank() - 1; i >= 0; --i) {
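      // Illustratively, in the contiguous case the running stride built below
      // is 1 for the innermost dimension, then size[n-1], then
      // size[n-1] * size[n-2], and so on outwards.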
      // Update size.
      Value size =
          getSize(rewriter, loc, viewMemRefType.getShape(), adaptor.sizes(), i);
      targetMemRef.setSize(rewriter, loc, i, size);
      // Update stride.
      stride = getStride(rewriter, loc, strides, nextSize, stride, i);
      targetMemRef.setStride(rewriter, loc, i, stride);
      nextSize = size;
    }

    rewriter.replaceOp(viewOp, {targetMemRef});
    return success();
  }
};

//===----------------------------------------------------------------------===//
// AtomicRMWOpLowering
//===----------------------------------------------------------------------===//

/// Try to match the kind of a memref.atomic_rmw to determine whether to use a
/// lowering to llvm.atomicrmw or fall back to llvm.cmpxchg.
static Optional<LLVM::AtomicBinOp>
matchSimpleAtomicOp(memref::AtomicRMWOp atomicOp) {
  switch (atomicOp.kind()) {
  case arith::AtomicRMWKind::addf:
    return LLVM::AtomicBinOp::fadd;
  case arith::AtomicRMWKind::addi:
    return LLVM::AtomicBinOp::add;
  case arith::AtomicRMWKind::assign:
    return LLVM::AtomicBinOp::xchg;
  case arith::AtomicRMWKind::maxs:
    return LLVM::AtomicBinOp::max;
  case arith::AtomicRMWKind::maxu:
    return LLVM::AtomicBinOp::umax;
  case arith::AtomicRMWKind::mins:
    return LLVM::AtomicBinOp::min;
  case arith::AtomicRMWKind::minu:
    return LLVM::AtomicBinOp::umin;
  case arith::AtomicRMWKind::ori:
    return LLVM::AtomicBinOp::_or;
  case arith::AtomicRMWKind::andi:
    return LLVM::AtomicBinOp::_and;
  default:
    return llvm::None;
  }
  llvm_unreachable("Invalid AtomicRMWKind");
}

struct AtomicRMWOpLowering : public LoadStoreOpLowering<memref::AtomicRMWOp> {
  using Base::Base;

  LogicalResult
  matchAndRewrite(memref::AtomicRMWOp atomicOp, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    if (failed(match(atomicOp)))
      return failure();
    auto maybeKind = matchSimpleAtomicOp(atomicOp);
    if (!maybeKind)
      return failure();
    auto resultType = adaptor.value().getType();
    auto memRefType = atomicOp.getMemRefType();
    auto dataPtr =
        getStridedElementPtr(atomicOp.getLoc(), memRefType, adaptor.memref(),
                             adaptor.indices(), rewriter);
    rewriter.replaceOpWithNewOp<LLVM::AtomicRMWOp>(
        atomicOp, resultType, *maybeKind, dataPtr, adaptor.value(),
        LLVM::AtomicOrdering::acq_rel);
    return success();
  }
};

} // namespace

void mlir::populateMemRefToLLVMConversionPatterns(LLVMTypeConverter &converter,
                                                  RewritePatternSet &patterns) {
  // clang-format off
  patterns.add<
      AllocaOpLowering,
      AllocaScopeOpLowering,
      AtomicRMWOpLowering,
      AssumeAlignmentOpLowering,
      DimOpLowering,
      GenericAtomicRMWOpLowering,
      GlobalMemrefOpLowering,
      GetGlobalMemrefOpLowering,
      LoadOpLowering,
      MemRefCastOpLowering,
      MemRefCopyOpLowering,
      MemRefReinterpretCastOpLowering,
      MemRefReshapeOpLowering,
      PrefetchOpLowering,
      RankOpLowering,
      ReassociatingReshapeOpConversion<memref::ExpandShapeOp>,
      ReassociatingReshapeOpConversion<memref::CollapseShapeOp>,
      StoreOpLowering,
      SubViewOpLowering,
      TransposeOpLowering,
      ViewOpLowering>(converter);
  // clang-format on
  auto allocLowering = converter.getOptions().allocLowering;
  if (allocLowering == LowerToLLVMOptions::AllocLowering::AlignedAlloc)
    patterns.add<AlignedAllocOpLowering, DeallocOpLowering>(converter);
  else if (allocLowering == LowerToLLVMOptions::AllocLowering::Malloc)
    patterns.add<AllocOpLowering, DeallocOpLowering>(converter);
}

namespace {
struct MemRefToLLVMPass : public ConvertMemRefToLLVMBase<MemRefToLLVMPass> {
  MemRefToLLVMPass() = default;

  void runOnOperation() override {
    Operation *op = getOperation();
    const auto &dataLayoutAnalysis = getAnalysis<DataLayoutAnalysis>();
    LowerToLLVMOptions options(&getContext(),
                               dataLayoutAnalysis.getAtOrAbove(op));
    options.allocLowering =
        (useAlignedAlloc ? LowerToLLVMOptions::AllocLowering::AlignedAlloc
                         : LowerToLLVMOptions::AllocLowering::Malloc);
    if (indexBitwidth != kDeriveIndexBitwidthFromDataLayout)
      options.overrideIndexBitwidth(indexBitwidth);

    LLVMTypeConverter typeConverter(&getContext(), options,
                                    &dataLayoutAnalysis);
    RewritePatternSet patterns(&getContext());
    populateMemRefToLLVMConversionPatterns(typeConverter, patterns);
    LLVMConversionTarget target(getContext());
    target.addLegalOp<FuncOp>();
    if (failed(applyPartialConversion(op, target, std::move(patterns))))
      signalPassFailure();
  }
};
} // namespace

std::unique_ptr<Pass> mlir::createMemRefToLLVMPass() {
  return std::make_unique<MemRefToLLVMPass>();
}
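// Usage sketch (illustrative; the flag name is assumed from this revision of
// MLIR): the conversion is typically exercised via
//   mlir-opt --convert-memref-to-llvm input.mlir
// or programmatically by adding createMemRefToLLVMPass() to a PassManager
// before translating the module to LLVM IR.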