//===- MemRefToLLVM.cpp - MemRef to LLVM dialect conversion ---------------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//

#include "mlir/Conversion/MemRefToLLVM/MemRefToLLVM.h"
#include "../PassDetail.h"
#include "mlir/Analysis/DataLayoutAnalysis.h"
#include "mlir/Conversion/LLVMCommon/ConversionTarget.h"
#include "mlir/Conversion/LLVMCommon/Pattern.h"
#include "mlir/Conversion/LLVMCommon/TypeConverter.h"
#include "mlir/Conversion/MemRefToLLVM/AllocLikeConversion.h"
#include "mlir/Dialect/Arithmetic/IR/Arithmetic.h"
#include "mlir/Dialect/Func/IR/FuncOps.h"
#include "mlir/Dialect/LLVMIR/FunctionCallUtils.h"
#include "mlir/Dialect/LLVMIR/LLVMDialect.h"
#include "mlir/Dialect/MemRef/IR/MemRef.h"
#include "mlir/IR/AffineMap.h"
#include "mlir/IR/BlockAndValueMapping.h"
#include "llvm/ADT/SmallBitVector.h"

using namespace mlir;

namespace {

bool isStaticStrideOrOffset(int64_t strideOrOffset) {
  return !ShapedType::isDynamicStrideOrOffset(strideOrOffset);
}

struct AllocOpLowering : public AllocLikeOpLLVMLowering {
  AllocOpLowering(LLVMTypeConverter &converter)
      : AllocLikeOpLLVMLowering(memref::AllocOp::getOperationName(),
                                converter) {}

  std::tuple<Value, Value> allocateBuffer(ConversionPatternRewriter &rewriter,
                                          Location loc, Value sizeBytes,
                                          Operation *op) const override {
    // Heap allocations.
    memref::AllocOp allocOp = cast<memref::AllocOp>(op);
    MemRefType memRefType = allocOp.getType();

    Value alignment;
    if (auto alignmentAttr = allocOp.getAlignment()) {
      alignment = createIndexConstant(rewriter, loc, *alignmentAttr);
    } else if (!memRefType.getElementType().isSignlessIntOrIndexOrFloat()) {
      // In the case where no alignment is specified, we may want to override
      // `malloc`'s behavior. `malloc` typically aligns at the size of the
      // biggest scalar on a target HW. For non-scalars, use the natural
      // alignment of the LLVM type given by the LLVM DataLayout.
      alignment = getSizeInBytes(loc, memRefType.getElementType(), rewriter);
    }

    if (alignment) {
      // Adjust the allocation size to consider alignment.
      sizeBytes = rewriter.create<LLVM::AddOp>(loc, sizeBytes, alignment);
    }

    // Allocate the underlying buffer and store a pointer to it in the MemRef
    // descriptor.
    Type elementPtrType = this->getElementPtrType(memRefType);
    auto allocFuncOp = LLVM::lookupOrCreateMallocFn(
        allocOp->getParentOfType<ModuleOp>(), getIndexType());
    auto results = createLLVMCall(rewriter, loc, allocFuncOp, {sizeBytes},
                                  getVoidPtrType());
    Value allocatedPtr =
        rewriter.create<LLVM::BitcastOp>(loc, elementPtrType, results[0]);

    Value alignedPtr = allocatedPtr;
    if (alignment) {
      // Compute the aligned type pointer.
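      //
      // Illustrative sketch (not emitted literally): because `alignment`
      // extra bytes were added to the allocation size above, an aligned
      // address always exists inside the buffer and can be computed as
      //   aligned = allocated + (alignment - allocated % alignment) % alignment
      // `createAligned` below emits an equivalent computation on index-typed
      // integers; the exact instruction sequence is up to that helper.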
      Value allocatedInt =
          rewriter.create<LLVM::PtrToIntOp>(loc, getIndexType(), allocatedPtr);
      Value alignmentInt =
          createAligned(rewriter, loc, allocatedInt, alignment);
      alignedPtr =
          rewriter.create<LLVM::IntToPtrOp>(loc, elementPtrType, alignmentInt);
    }

    return std::make_tuple(allocatedPtr, alignedPtr);
  }
};

struct AlignedAllocOpLowering : public AllocLikeOpLLVMLowering {
  AlignedAllocOpLowering(LLVMTypeConverter &converter)
      : AllocLikeOpLLVMLowering(memref::AllocOp::getOperationName(),
                                converter) {}

  /// Returns the memref's element size in bytes using the data layout active
  /// at `op`.
  // TODO: there are other places where this is used. Expose publicly?
  unsigned getMemRefEltSizeInBytes(MemRefType memRefType,
                                   Operation *op) const {
    const DataLayout *layout = &defaultLayout;
    if (const DataLayoutAnalysis *analysis =
            getTypeConverter()->getDataLayoutAnalysis()) {
      layout = &analysis->getAbove(op);
    }
    Type elementType = memRefType.getElementType();
    if (auto memRefElementType = elementType.dyn_cast<MemRefType>())
      return getTypeConverter()->getMemRefDescriptorSize(memRefElementType,
                                                         *layout);
    if (auto memRefElementType = elementType.dyn_cast<UnrankedMemRefType>())
      return getTypeConverter()->getUnrankedMemRefDescriptorSize(
          memRefElementType, *layout);
    return layout->getTypeSize(elementType);
  }

  /// Returns true if the memref size in bytes is known to be a multiple of
  /// `factor`, assuming the data layout active at `op`.
  bool isMemRefSizeMultipleOf(MemRefType type, uint64_t factor,
                              Operation *op) const {
    uint64_t sizeDivisor = getMemRefEltSizeInBytes(type, op);
    for (unsigned i = 0, e = type.getRank(); i < e; i++) {
      if (ShapedType::isDynamic(type.getDimSize(i)))
        continue;
      sizeDivisor = sizeDivisor * type.getDimSize(i);
    }
    return sizeDivisor % factor == 0;
  }

  /// Returns the alignment to be used for the allocation call itself.
  /// aligned_alloc requires the alignment to be a power of two, and the
  /// allocation size to be a multiple of that alignment.
  int64_t getAllocationAlignment(memref::AllocOp allocOp) const {
    if (Optional<uint64_t> alignment = allocOp.getAlignment())
      return *alignment;

    // Whenever we don't have alignment set, we will use an alignment
    // consistent with the element type; since the alignment has to be a power
    // of two, we bump the element size to the next power of two if it isn't
    // one already.
    auto eltSizeBytes = getMemRefEltSizeInBytes(allocOp.getType(), allocOp);
    return std::max(kMinAlignedAllocAlignment,
                    llvm::PowerOf2Ceil(eltSizeBytes));
  }

  std::tuple<Value, Value> allocateBuffer(ConversionPatternRewriter &rewriter,
                                          Location loc, Value sizeBytes,
                                          Operation *op) const override {
    // Heap allocations.
    memref::AllocOp allocOp = cast<memref::AllocOp>(op);
    MemRefType memRefType = allocOp.getType();
    int64_t alignment = getAllocationAlignment(allocOp);
    Value allocAlignment = createIndexConstant(rewriter, loc, alignment);

    // aligned_alloc requires size to be a multiple of alignment; we will pad
    // the size to the next multiple if necessary.
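    //
    // Worked example (illustrative only): for memref<5xf32> the element data
    // occupies 20 bytes; with the default 16-byte minimum alignment this is
    // padded up to 32 bytes before calling aligned_alloc, since 20 is not a
    // multiple of 16. A memref<4xf32> (16 bytes) is passed through unchanged.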
    if (!isMemRefSizeMultipleOf(memRefType, alignment, op))
      sizeBytes = createAligned(rewriter, loc, sizeBytes, allocAlignment);

    Type elementPtrType = this->getElementPtrType(memRefType);
    auto allocFuncOp = LLVM::lookupOrCreateAlignedAllocFn(
        allocOp->getParentOfType<ModuleOp>(), getIndexType());
    auto results =
        createLLVMCall(rewriter, loc, allocFuncOp, {allocAlignment, sizeBytes},
                       getVoidPtrType());
    Value allocatedPtr =
        rewriter.create<LLVM::BitcastOp>(loc, elementPtrType, results[0]);

    return std::make_tuple(allocatedPtr, allocatedPtr);
  }

  /// The minimum alignment to use with aligned_alloc (has to be a power of 2).
  static constexpr uint64_t kMinAlignedAllocAlignment = 16UL;

  /// Default layout to use in absence of the corresponding analysis.
  DataLayout defaultLayout;
};

// Out-of-line definition, required until C++17.
constexpr uint64_t AlignedAllocOpLowering::kMinAlignedAllocAlignment;

struct AllocaOpLowering : public AllocLikeOpLLVMLowering {
  AllocaOpLowering(LLVMTypeConverter &converter)
      : AllocLikeOpLLVMLowering(memref::AllocaOp::getOperationName(),
                                converter) {}

  /// Allocates the underlying buffer with an llvm.alloca. Since the alloca
  /// already honors the requested alignment, the allocated and aligned
  /// pointers it returns are one and the same.
  std::tuple<Value, Value> allocateBuffer(ConversionPatternRewriter &rewriter,
                                          Location loc, Value sizeBytes,
                                          Operation *op) const override {
    // With alloca, one gets a pointer to the element type right away for
    // stack allocations.
    auto allocaOp = cast<memref::AllocaOp>(op);
    auto elementPtrType = this->getElementPtrType(allocaOp.getType());

    auto allocatedElementPtr = rewriter.create<LLVM::AllocaOp>(
        loc, elementPtrType, sizeBytes,
        allocaOp.getAlignment() ? *allocaOp.getAlignment() : 0);

    return std::make_tuple(allocatedElementPtr, allocatedElementPtr);
  }
};

struct AllocaScopeOpLowering
    : public ConvertOpToLLVMPattern<memref::AllocaScopeOp> {
  using ConvertOpToLLVMPattern<memref::AllocaScopeOp>::ConvertOpToLLVMPattern;

  LogicalResult
  matchAndRewrite(memref::AllocaScopeOp allocaScopeOp, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    OpBuilder::InsertionGuard guard(rewriter);
    Location loc = allocaScopeOp.getLoc();

    // Split the current block before the AllocaScopeOp to create the inlining
    // point.
    auto *currentBlock = rewriter.getInsertionBlock();
    auto *remainingOpsBlock =
        rewriter.splitBlock(currentBlock, rewriter.getInsertionPoint());
    Block *continueBlock;
    if (allocaScopeOp.getNumResults() == 0) {
      continueBlock = remainingOpsBlock;
    } else {
      continueBlock = rewriter.createBlock(
          remainingOpsBlock, allocaScopeOp.getResultTypes(),
          SmallVector<Location>(allocaScopeOp->getNumResults(),
                                allocaScopeOp.getLoc()));
      rewriter.create<LLVM::BrOp>(loc, ValueRange(), remainingOpsBlock);
    }

    // Inline body region.
    Block *beforeBody = &allocaScopeOp.getBodyRegion().front();
    Block *afterBody = &allocaScopeOp.getBodyRegion().back();
    rewriter.inlineRegionBefore(allocaScopeOp.getBodyRegion(), continueBlock);

    // Save stack and then branch into the body of the region.
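    //
    // The resulting CFG is roughly (illustrative sketch):
    //
    //   currentBlock:
    //     %sp = llvm.intr.stacksave
    //     llvm.br ^body
    //   ^body:
    //     <inlined alloca_scope body>
    //     llvm.intr.stackrestore %sp
    //     llvm.br ^continue(<results>)
    //   ^continue(...):
    //     <code after the alloca_scope>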
    rewriter.setInsertionPointToEnd(currentBlock);
    auto stackSaveOp =
        rewriter.create<LLVM::StackSaveOp>(loc, getVoidPtrType());
    rewriter.create<LLVM::BrOp>(loc, ValueRange(), beforeBody);

    // Replace the alloca_scope return with a branch that jumps out of the
    // body. Stack restore before leaving the body region.
    rewriter.setInsertionPointToEnd(afterBody);
    auto returnOp =
        cast<memref::AllocaScopeReturnOp>(afterBody->getTerminator());
    auto branchOp = rewriter.replaceOpWithNewOp<LLVM::BrOp>(
        returnOp, returnOp.getResults(), continueBlock);

    // Insert stack restore before jumping out of the body of the region.
    rewriter.setInsertionPoint(branchOp);
    rewriter.create<LLVM::StackRestoreOp>(loc, stackSaveOp);

    // Replace the op with values returned from the body region.
    rewriter.replaceOp(allocaScopeOp, continueBlock->getArguments());

    return success();
  }
};

struct AssumeAlignmentOpLowering
    : public ConvertOpToLLVMPattern<memref::AssumeAlignmentOp> {
  using ConvertOpToLLVMPattern<
      memref::AssumeAlignmentOp>::ConvertOpToLLVMPattern;

  LogicalResult
  matchAndRewrite(memref::AssumeAlignmentOp op, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    Value memref = adaptor.getMemref();
    unsigned alignment = op.getAlignment();
    auto loc = op.getLoc();

    MemRefDescriptor memRefDescriptor(memref);
    Value ptr = memRefDescriptor.alignedPtr(rewriter, memref.getLoc());

    // Emit llvm.assume(memref.alignedPtr & (alignment - 1) == 0). Notice that
    // the asserted memref.alignedPtr isn't used anywhere else, as the real
    // users like load/store/views always re-extract memref.alignedPtr as they
    // get lowered.
    //
    // This relies on LLVM's CSE optimization (potentially after SROA), since
    // after CSE all memref.alignedPtr instances get de-duplicated into the
    // same pointer SSA value.
    auto intPtrType =
        getIntPtrType(memRefDescriptor.getElementPtrType().getAddressSpace());
    Value zero = createIndexAttrConstant(rewriter, loc, intPtrType, 0);
    Value mask =
        createIndexAttrConstant(rewriter, loc, intPtrType, alignment - 1);
    Value ptrValue = rewriter.create<LLVM::PtrToIntOp>(loc, intPtrType, ptr);
    rewriter.create<LLVM::AssumeOp>(
        loc, rewriter.create<LLVM::ICmpOp>(
                 loc, LLVM::ICmpPredicate::eq,
                 rewriter.create<LLVM::AndOp>(loc, ptrValue, mask), zero));

    rewriter.eraseOp(op);
    return success();
  }
};

// A `dealloc` is converted into a call to `free` on the underlying data
// buffer. The memref descriptor being an SSA value, there is no need to clean
// it up in any way.
struct DeallocOpLowering : public ConvertOpToLLVMPattern<memref::DeallocOp> {
  using ConvertOpToLLVMPattern<memref::DeallocOp>::ConvertOpToLLVMPattern;

  explicit DeallocOpLowering(LLVMTypeConverter &converter)
      : ConvertOpToLLVMPattern<memref::DeallocOp>(converter) {}

  LogicalResult
  matchAndRewrite(memref::DeallocOp op, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    // Insert the `free` declaration if it is not already present.
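    //
    // The overall shape of the lowering is (illustrative):
    //   memref.dealloc %m
    // becomes
    //   %alloc = llvm.extractvalue %desc[0]   // allocated (not aligned) ptr
    //   %void = llvm.bitcast %alloc : ... to !llvm.ptr<i8>
    //   llvm.call @free(%void) : (!llvm.ptr<i8>) -> ()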
    auto freeFunc = LLVM::lookupOrCreateFreeFn(op->getParentOfType<ModuleOp>());
    MemRefDescriptor memref(adaptor.getMemref());
    Value casted = rewriter.create<LLVM::BitcastOp>(
        op.getLoc(), getVoidPtrType(),
        memref.allocatedPtr(rewriter, op.getLoc()));
    rewriter.replaceOpWithNewOp<LLVM::CallOp>(
        op, TypeRange(), SymbolRefAttr::get(freeFunc), casted);
    return success();
  }
};

struct AlignedDeallocOpLowering
    : public ConvertOpToLLVMPattern<memref::DeallocOp> {
  using ConvertOpToLLVMPattern<memref::DeallocOp>::ConvertOpToLLVMPattern;

  explicit AlignedDeallocOpLowering(LLVMTypeConverter &converter)
      : ConvertOpToLLVMPattern<memref::DeallocOp>(converter) {}

  LogicalResult
  matchAndRewrite(memref::DeallocOp op, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    // Insert the aligned free function declaration if it is not already
    // present.
    auto freeFunc =
        LLVM::lookupOrCreateAlignedFreeFn(op->getParentOfType<ModuleOp>());
    MemRefDescriptor memref(adaptor.getMemref());
    Value casted = rewriter.create<LLVM::BitcastOp>(
        op.getLoc(), getVoidPtrType(),
        memref.allocatedPtr(rewriter, op.getLoc()));
    rewriter.replaceOpWithNewOp<LLVM::CallOp>(
        op, TypeRange(), SymbolRefAttr::get(freeFunc), casted);
    return success();
  }
};

// A `dim` is converted to a constant for static sizes and to an access to the
// size stored in the memref descriptor for dynamic sizes.
struct DimOpLowering : public ConvertOpToLLVMPattern<memref::DimOp> {
  using ConvertOpToLLVMPattern<memref::DimOp>::ConvertOpToLLVMPattern;

  LogicalResult
  matchAndRewrite(memref::DimOp dimOp, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    Type operandType = dimOp.getSource().getType();
    if (operandType.isa<UnrankedMemRefType>()) {
      rewriter.replaceOp(
          dimOp, {extractSizeOfUnrankedMemRef(
                     operandType, dimOp, adaptor.getOperands(), rewriter)});
      return success();
    }
    if (operandType.isa<MemRefType>()) {
      rewriter.replaceOp(
          dimOp, {extractSizeOfRankedMemRef(operandType, dimOp,
                                            adaptor.getOperands(), rewriter)});
      return success();
    }
    llvm_unreachable("expected MemRefType or UnrankedMemRefType");
  }

private:
  Value extractSizeOfUnrankedMemRef(Type operandType, memref::DimOp dimOp,
                                    OpAdaptor adaptor,
                                    ConversionPatternRewriter &rewriter) const {
    Location loc = dimOp.getLoc();

    auto unrankedMemRefType = operandType.cast<UnrankedMemRefType>();
    auto scalarMemRefType =
        MemRefType::get({}, unrankedMemRefType.getElementType());
    unsigned addressSpace = unrankedMemRefType.getMemorySpaceAsInt();

    // Extract pointer to the underlying ranked descriptor and bitcast it to a
    // memref<element_type> descriptor pointer to minimize the number of GEP
    // operations.
    UnrankedMemRefDescriptor unrankedDesc(adaptor.getSource());
    Value underlyingRankedDesc = unrankedDesc.memRefDescPtr(rewriter, loc);
    Value scalarMemRefDescPtr = rewriter.create<LLVM::BitcastOp>(
        loc,
        LLVM::LLVMPointerType::get(
            typeConverter->convertType(scalarMemRefType), addressSpace),
        underlyingRankedDesc);

    // Get pointer to offset field of memref<element_type> descriptor.
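    //
    // A ranked descriptor is a struct of the form
    //   { T* allocated, T* aligned, index offset,
    //     index sizes[rank], index strides[rank] }
    // so the offset lives at field index 2 and the sizes array starts right
    // after it; this is why the size of dimension `d` is read from
    // `offsetPtr + 1 + d` below.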
    Type indexPtrTy = LLVM::LLVMPointerType::get(
        getTypeConverter()->getIndexType(), addressSpace);
    Value two = rewriter.create<LLVM::ConstantOp>(
        loc, typeConverter->convertType(rewriter.getI32Type()),
        rewriter.getI32IntegerAttr(2));
    Value offsetPtr = rewriter.create<LLVM::GEPOp>(
        loc, indexPtrTy, scalarMemRefDescPtr,
        ValueRange({createIndexConstant(rewriter, loc, 0), two}));

    // The size value that we have to extract can be obtained using GEPOp with
    // `dimOp.index() + 1` as the index argument.
    Value idxPlusOne = rewriter.create<LLVM::AddOp>(
        loc, createIndexConstant(rewriter, loc, 1), adaptor.getIndex());
    Value sizePtr = rewriter.create<LLVM::GEPOp>(loc, indexPtrTy, offsetPtr,
                                                 ValueRange({idxPlusOne}));
    return rewriter.create<LLVM::LoadOp>(loc, sizePtr);
  }

  Optional<int64_t> getConstantDimIndex(memref::DimOp dimOp) const {
    if (Optional<int64_t> idx = dimOp.getConstantIndex())
      return idx;

    if (auto constantOp = dimOp.getIndex().getDefiningOp<LLVM::ConstantOp>())
      return constantOp.getValue()
          .cast<IntegerAttr>()
          .getValue()
          .getSExtValue();

    return llvm::None;
  }

  Value extractSizeOfRankedMemRef(Type operandType, memref::DimOp dimOp,
                                  OpAdaptor adaptor,
                                  ConversionPatternRewriter &rewriter) const {
    Location loc = dimOp.getLoc();

    // Take advantage of the index being constant whenever possible.
    MemRefType memRefType = operandType.cast<MemRefType>();
    if (Optional<int64_t> index = getConstantDimIndex(dimOp)) {
      int64_t i = *index;
      if (memRefType.isDynamicDim(i)) {
        // Extract the dynamic size from the memref descriptor.
        MemRefDescriptor descriptor(adaptor.getSource());
        return descriptor.size(rewriter, loc, i);
      }
      // Use a constant for the static size.
      int64_t dimSize = memRefType.getDimSize(i);
      return createIndexConstant(rewriter, loc, dimSize);
    }
    Value index = adaptor.getIndex();
    int64_t rank = memRefType.getRank();
    MemRefDescriptor memrefDescriptor(adaptor.getSource());
    return memrefDescriptor.size(rewriter, loc, index, rank);
  }
};

/// Common base for load and store operations on MemRefs. Restricts the match
/// to supported MemRef types. Provides functionality to emit code accessing a
/// specific element of the underlying data buffer.
template <typename Derived>
struct LoadStoreOpLowering : public ConvertOpToLLVMPattern<Derived> {
  using ConvertOpToLLVMPattern<Derived>::ConvertOpToLLVMPattern;
  using ConvertOpToLLVMPattern<Derived>::isConvertibleAndHasIdentityMaps;
  using Base = LoadStoreOpLowering<Derived>;

  LogicalResult match(Derived op) const override {
    MemRefType type = op.getMemRefType();
    return isConvertibleAndHasIdentityMaps(type) ? success() : failure();
  }
};

/// Wrap an llvm.cmpxchg operation in a while loop so that the operation can be
/// retried until it succeeds in atomically storing a new value into memory.
///
///      +---------------------------------+
///      |   <code before the AtomicRMWOp> |
///      |   <compute initial %loaded>     |
///      |   cf.br loop(%loaded)           |
///      +---------------------------------+
///             |
///  -------|   |
///  |      v   v
///  |   +--------------------------------+
///  |   | loop(%loaded):                 |
///  |   |   <body contents>              |
///  |   |   %pair = cmpxchg              |
///  |   |   %new = %pair[0]              |
///  |   |   %ok = %pair[1]               |
///  |   |   cf.cond_br %ok, end, loop(%new) |
///  |   +--------------------------------+
///  |          |         |
///  |----------          |
///                       v
///      +--------------------------------+
///      | end:                           |
///      |   <code after the AtomicRMWOp> |
///      +--------------------------------+
///
struct GenericAtomicRMWOpLowering
    : public LoadStoreOpLowering<memref::GenericAtomicRMWOp> {
  using Base::Base;

  LogicalResult
  matchAndRewrite(memref::GenericAtomicRMWOp atomicOp, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    auto loc = atomicOp.getLoc();
    Type valueType =
        typeConverter->convertType(atomicOp.getResult().getType());

    // Split the block into initial, loop, and ending parts.
    auto *initBlock = rewriter.getInsertionBlock();
    auto *loopBlock = rewriter.createBlock(
        initBlock->getParent(), std::next(Region::iterator(initBlock)),
        valueType, loc);
    auto *endBlock = rewriter.createBlock(
        loopBlock->getParent(), std::next(Region::iterator(loopBlock)));

    // Operations range to be moved to `endBlock`.
    auto opsToMoveStart = atomicOp->getIterator();
    auto opsToMoveEnd = initBlock->back().getIterator();

    // Compute the loaded value and branch to the loop block.
    rewriter.setInsertionPointToEnd(initBlock);
    auto memRefType = atomicOp.getMemref().getType().cast<MemRefType>();
    auto dataPtr = getStridedElementPtr(loc, memRefType, adaptor.getMemref(),
                                        adaptor.getIndices(), rewriter);
    Value init = rewriter.create<LLVM::LoadOp>(loc, dataPtr);
    rewriter.create<LLVM::BrOp>(loc, init, loopBlock);

    // Prepare the body of the loop block.
    rewriter.setInsertionPointToStart(loopBlock);

    // Clone the GenericAtomicRMWOp region and extract the result.
    auto loopArgument = loopBlock->getArgument(0);
    BlockAndValueMapping mapping;
    mapping.map(atomicOp.getCurrentValue(), loopArgument);
    Block &entryBlock = atomicOp.body().front();
    for (auto &nestedOp : entryBlock.without_terminator()) {
      Operation *clone = rewriter.clone(nestedOp, mapping);
      mapping.map(nestedOp.getResults(), clone->getResults());
    }
    Value result = mapping.lookup(entryBlock.getTerminator()->getOperand(0));

    // Prepare the epilog of the loop block.
    // Append the cmpxchg op to the end of the loop block.
    auto successOrdering = LLVM::AtomicOrdering::acq_rel;
    auto failureOrdering = LLVM::AtomicOrdering::monotonic;
    auto boolType = IntegerType::get(rewriter.getContext(), 1);
    auto pairType = LLVM::LLVMStructType::getLiteral(rewriter.getContext(),
                                                     {valueType, boolType});
    auto cmpxchg = rewriter.create<LLVM::AtomicCmpXchgOp>(
        loc, pairType, dataPtr, loopArgument, result, successOrdering,
        failureOrdering);
    // Extract the %new_loaded and %ok values from the pair.
    Value newLoaded = rewriter.create<LLVM::ExtractValueOp>(
        loc, valueType, cmpxchg, rewriter.getI64ArrayAttr({0}));
    Value ok = rewriter.create<LLVM::ExtractValueOp>(
        loc, boolType, cmpxchg, rewriter.getI64ArrayAttr({1}));

    // Conditionally branch to the end or back to the loop depending on %ok.
    rewriter.create<LLVM::CondBrOp>(loc, ok, endBlock, ArrayRef<Value>(),
                                    loopBlock, newLoaded);

    rewriter.setInsertionPointToEnd(endBlock);
    moveOpsRange(atomicOp.getResult(), newLoaded, std::next(opsToMoveStart),
                 std::next(opsToMoveEnd), rewriter);

    // The 'result' of the atomic_rmw op is the newly loaded value.
    rewriter.replaceOp(atomicOp, {newLoaded});

    return success();
  }

private:
  // Clones a segment of ops [start, end) and erases the original.
  void moveOpsRange(ValueRange oldResult, ValueRange newResult,
                    Block::iterator start, Block::iterator end,
                    ConversionPatternRewriter &rewriter) const {
    BlockAndValueMapping mapping;
    mapping.map(oldResult, newResult);
    SmallVector<Operation *, 2> opsToErase;
    for (auto it = start; it != end; ++it) {
      rewriter.clone(*it, mapping);
      opsToErase.push_back(&*it);
    }
    for (auto *it : opsToErase)
      rewriter.eraseOp(it);
  }
};

/// Returns the LLVM type of the global variable given the memref type `type`.
static Type convertGlobalMemrefTypeToLLVM(MemRefType type,
                                          LLVMTypeConverter &typeConverter) {
  // LLVM type for a global memref will be a multi-dimensional array. For
  // declarations or uninitialized global memrefs, we can potentially flatten
  // this to a 1D array. However, for memref.global's with an initial value,
  // we do not intend to flatten the ElementsAttribute when going from std ->
  // LLVM dialect, so the LLVM type needs to be a multi-dimensional array.
  Type elementType = typeConverter.convertType(type.getElementType());
  Type arrayTy = elementType;
  // Shape has the outermost dim at index 0, so need to walk it backwards.
  for (int64_t dim : llvm::reverse(type.getShape()))
    arrayTy = LLVM::LLVMArrayType::get(arrayTy, dim);
  return arrayTy;
}

/// GlobalMemrefOp is lowered to an LLVM global variable.
struct GlobalMemrefOpLowering
    : public ConvertOpToLLVMPattern<memref::GlobalOp> {
  using ConvertOpToLLVMPattern<memref::GlobalOp>::ConvertOpToLLVMPattern;

  LogicalResult
  matchAndRewrite(memref::GlobalOp global, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    MemRefType type = global.getType();
    if (!isConvertibleAndHasIdentityMaps(type))
      return failure();

    Type arrayTy = convertGlobalMemrefTypeToLLVM(type, *getTypeConverter());

    LLVM::Linkage linkage =
        global.isPublic() ? LLVM::Linkage::External : LLVM::Linkage::Private;

    Attribute initialValue = nullptr;
    if (!global.isExternal() && !global.isUninitialized()) {
      auto elementsAttr = global.getInitialValue()->cast<ElementsAttr>();
      initialValue = elementsAttr;

      // For scalar memrefs, the global variable created is of the element
      // type, so unpack the elements attribute to extract the value.
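      //
      // For example (illustrative): `memref.global @c : memref<f32> =
      // dense<1.0>` becomes an `llvm.mlir.global` of type `f32` initialized
      // with the scalar 1.0 rather than with a one-element array attribute.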
      if (type.getRank() == 0)
        initialValue = elementsAttr.getSplatValue<Attribute>();
    }

    uint64_t alignment = global.getAlignment().value_or(0);

    auto newGlobal = rewriter.replaceOpWithNewOp<LLVM::GlobalOp>(
        global, arrayTy, global.getConstant(), linkage, global.getSymName(),
        initialValue, alignment, type.getMemorySpaceAsInt());
    if (!global.isExternal() && global.isUninitialized()) {
      Block *blk = new Block();
      newGlobal.getInitializerRegion().push_back(blk);
      rewriter.setInsertionPointToStart(blk);
      Value undef[] = {
          rewriter.create<LLVM::UndefOp>(global.getLoc(), arrayTy)};
      rewriter.create<LLVM::ReturnOp>(global.getLoc(), undef);
    }
    return success();
  }
};

/// GetGlobalMemrefOp is lowered into a MemRef descriptor with the pointer to
/// the first element stashed into the descriptor. This derives from
/// `AllocLikeOpLLVMLowering` to share the MemRef descriptor construction.
struct GetGlobalMemrefOpLowering : public AllocLikeOpLLVMLowering {
  GetGlobalMemrefOpLowering(LLVMTypeConverter &converter)
      : AllocLikeOpLLVMLowering(memref::GetGlobalOp::getOperationName(),
                                converter) {}

  /// Buffer "allocation" for memref.get_global op is getting the address of
  /// the global variable referenced.
  std::tuple<Value, Value> allocateBuffer(ConversionPatternRewriter &rewriter,
                                          Location loc, Value sizeBytes,
                                          Operation *op) const override {
    auto getGlobalOp = cast<memref::GetGlobalOp>(op);
    MemRefType type = getGlobalOp.getResult().getType().cast<MemRefType>();
    unsigned memSpace = type.getMemorySpaceAsInt();

    Type arrayTy = convertGlobalMemrefTypeToLLVM(type, *getTypeConverter());
    auto addressOf = rewriter.create<LLVM::AddressOfOp>(
        loc, LLVM::LLVMPointerType::get(arrayTy, memSpace),
        getGlobalOp.getName());

    // Get the address of the first element in the array by creating a GEP with
    // the address of the GV as the base, and (rank + 1) number of 0 indices.
    Type elementType = typeConverter->convertType(type.getElementType());
    Type elementPtrType = LLVM::LLVMPointerType::get(elementType, memSpace);

    SmallVector<Value> operands;
    operands.insert(operands.end(), type.getRank() + 1,
                    createIndexConstant(rewriter, loc, 0));
    auto gep =
        rewriter.create<LLVM::GEPOp>(loc, elementPtrType, addressOf, operands);

    // We do not expect the memref obtained using `memref.get_global` to be
    // ever deallocated. Set the allocated pointer to a known bad value to
    // help debug if that ever happens.
    auto intPtrType = getIntPtrType(memSpace);
    Value deadBeefConst =
        createIndexAttrConstant(rewriter, op->getLoc(), intPtrType, 0xdeadbeef);
    auto deadBeefPtr =
        rewriter.create<LLVM::IntToPtrOp>(loc, elementPtrType, deadBeefConst);

    // Both allocated and aligned pointers are the same. We could potentially
    // stash a nullptr for the allocated pointer since we do not expect any
    // dealloc.
    return std::make_tuple(deadBeefPtr, gep);
  }
};

// Load operation is lowered to obtaining a pointer to the indexed element
// and loading it.
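//
// For instance (illustrative), `memref.load %m[%i, %j]` becomes a GEP computed
// from the descriptor's aligned pointer, offset and strides, i.e. roughly
// `aligned + offset + %i * stride0 + %j * stride1`, followed by an `llvm.load`
// of the resulting address.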
struct LoadOpLowering : public LoadStoreOpLowering<memref::LoadOp> {
  using Base::Base;

  LogicalResult
  matchAndRewrite(memref::LoadOp loadOp, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    auto type = loadOp.getMemRefType();

    Value dataPtr =
        getStridedElementPtr(loadOp.getLoc(), type, adaptor.getMemref(),
                             adaptor.getIndices(), rewriter);
    rewriter.replaceOpWithNewOp<LLVM::LoadOp>(loadOp, dataPtr);
    return success();
  }
};

// Store operation is lowered to obtaining a pointer to the indexed element,
// and storing the given value to it.
struct StoreOpLowering : public LoadStoreOpLowering<memref::StoreOp> {
  using Base::Base;

  LogicalResult
  matchAndRewrite(memref::StoreOp op, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    auto type = op.getMemRefType();

    Value dataPtr = getStridedElementPtr(op.getLoc(), type, adaptor.getMemref(),
                                         adaptor.getIndices(), rewriter);
    rewriter.replaceOpWithNewOp<LLVM::StoreOp>(op, adaptor.getValue(), dataPtr);
    return success();
  }
};

// The prefetch operation is lowered in a way similar to the load operation
// except that the llvm.prefetch operation is used for replacement.
struct PrefetchOpLowering : public LoadStoreOpLowering<memref::PrefetchOp> {
  using Base::Base;

  LogicalResult
  matchAndRewrite(memref::PrefetchOp prefetchOp, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    auto type = prefetchOp.getMemRefType();
    auto loc = prefetchOp.getLoc();

    Value dataPtr = getStridedElementPtr(loc, type, adaptor.getMemref(),
                                         adaptor.getIndices(), rewriter);

    // Replace with llvm.prefetch.
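    //
    // llvm.prefetch takes the address plus three i32 arguments: is-write
    // (0/1), locality hint (0-3), and cache type (1 = data, 0 = instruction),
    // which mirror the memref.prefetch attributes one to one.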
    auto llvmI32Type = typeConverter->convertType(rewriter.getIntegerType(32));
    auto isWrite = rewriter.create<LLVM::ConstantOp>(
        loc, llvmI32Type, rewriter.getI32IntegerAttr(prefetchOp.getIsWrite()));
    auto localityHint = rewriter.create<LLVM::ConstantOp>(
        loc, llvmI32Type,
        rewriter.getI32IntegerAttr(prefetchOp.getLocalityHint()));
    auto isData = rewriter.create<LLVM::ConstantOp>(
        loc, llvmI32Type,
        rewriter.getI32IntegerAttr(prefetchOp.getIsDataCache()));

    rewriter.replaceOpWithNewOp<LLVM::Prefetch>(prefetchOp, dataPtr, isWrite,
                                                localityHint, isData);
    return success();
  }
};

struct RankOpLowering : public ConvertOpToLLVMPattern<memref::RankOp> {
  using ConvertOpToLLVMPattern<memref::RankOp>::ConvertOpToLLVMPattern;

  LogicalResult
  matchAndRewrite(memref::RankOp op, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    Location loc = op.getLoc();
    Type operandType = op.getMemref().getType();
    if (auto unrankedMemRefType = operandType.dyn_cast<UnrankedMemRefType>()) {
      UnrankedMemRefDescriptor desc(adaptor.getMemref());
      rewriter.replaceOp(op, {desc.rank(rewriter, loc)});
      return success();
    }
    if (auto rankedMemRefType = operandType.dyn_cast<MemRefType>()) {
      rewriter.replaceOp(
          op, {createIndexConstant(rewriter, loc, rankedMemRefType.getRank())});
      return success();
    }
    return failure();
  }
};

struct MemRefCastOpLowering : public ConvertOpToLLVMPattern<memref::CastOp> {
  using ConvertOpToLLVMPattern<memref::CastOp>::ConvertOpToLLVMPattern;

  LogicalResult match(memref::CastOp memRefCastOp) const override {
    Type srcType = memRefCastOp.getOperand().getType();
    Type dstType = memRefCastOp.getType();

    // memref::CastOp reduces to a bitcast in the ranked MemRef case and can be
    // used for type erasure. For now it must preserve the underlying element
    // type and requires the source and result types to have the same rank.
    // Therefore, perform a sanity check that the underlying structs are the
    // same. Once op semantics are relaxed we can revisit.
    if (srcType.isa<MemRefType>() && dstType.isa<MemRefType>())
      return success(typeConverter->convertType(srcType) ==
                     typeConverter->convertType(dstType));

    // At least one of the operands has unranked type.
    assert(srcType.isa<UnrankedMemRefType>() ||
           dstType.isa<UnrankedMemRefType>());

    // Unranked to unranked cast is disallowed.
    return !(srcType.isa<UnrankedMemRefType>() &&
             dstType.isa<UnrankedMemRefType>())
               ? success()
               : failure();
  }

  void rewrite(memref::CastOp memRefCastOp, OpAdaptor adaptor,
               ConversionPatternRewriter &rewriter) const override {
    auto srcType = memRefCastOp.getOperand().getType();
    auto dstType = memRefCastOp.getType();
    auto targetStructType = typeConverter->convertType(memRefCastOp.getType());
    auto loc = memRefCastOp.getLoc();

    // For ranked/ranked case, just keep the original descriptor.
    if (srcType.isa<MemRefType>() && dstType.isa<MemRefType>())
      return rewriter.replaceOp(memRefCastOp, {adaptor.getSource()});

    if (srcType.isa<MemRefType>() && dstType.isa<UnrankedMemRefType>()) {
      // Casting ranked to unranked memref type:
      // Set the rank in the destination from the memref type.
      // Allocate space on the stack and copy the src memref descriptor.
      // Set the ptr in the destination to the stack space.
      auto srcMemRefType = srcType.cast<MemRefType>();
      int64_t rank = srcMemRefType.getRank();
      // ptr = AllocaOp sizeof(MemRefDescriptor)
      auto ptr = getTypeConverter()->promoteOneMemRefDescriptor(
          loc, adaptor.getSource(), rewriter);
      // voidptr = BitCastOp srcType* to void*
      auto voidPtr =
          rewriter.create<LLVM::BitcastOp>(loc, getVoidPtrType(), ptr)
              .getResult();
      // rank = ConstantOp srcRank
      auto rankVal = rewriter.create<LLVM::ConstantOp>(
          loc, getIndexType(), rewriter.getIndexAttr(rank));
      // undef = UndefOp
      UnrankedMemRefDescriptor memRefDesc =
          UnrankedMemRefDescriptor::undef(rewriter, loc, targetStructType);
      // d1 = InsertValueOp undef, rank, 0
      memRefDesc.setRank(rewriter, loc, rankVal);
      // d2 = InsertValueOp d1, voidptr, 1
      memRefDesc.setMemRefDescPtr(rewriter, loc, voidPtr);
      rewriter.replaceOp(memRefCastOp, (Value)memRefDesc);

    } else if (srcType.isa<UnrankedMemRefType>() && dstType.isa<MemRefType>()) {
      // Casting from unranked type to ranked.
      // The operation is assumed to be doing a correct cast. If the
      // destination type mismatches the unranked type, it is undefined
      // behavior.
      UnrankedMemRefDescriptor memRefDesc(adaptor.getSource());
      // ptr = ExtractValueOp src, 1
      auto ptr = memRefDesc.memRefDescPtr(rewriter, loc);
      // castPtr = BitCastOp i8* to structTy*
      auto castPtr =
          rewriter
              .create<LLVM::BitcastOp>(
                  loc, LLVM::LLVMPointerType::get(targetStructType), ptr)
              .getResult();
      // struct = LoadOp castPtr
      auto loadOp = rewriter.create<LLVM::LoadOp>(loc, castPtr);
      rewriter.replaceOp(memRefCastOp, loadOp.getResult());
    } else {
      llvm_unreachable("Unsupported unranked memref to unranked memref cast");
    }
  }
};

/// Pattern to lower a `memref.copy` to llvm.
///
/// For memrefs with identity layouts, the copy is lowered to the llvm
/// `memcpy` intrinsic. For non-identity layouts, the copy is lowered to a call
/// to the generic `MemrefCopyFn`.
struct MemRefCopyOpLowering : public ConvertOpToLLVMPattern<memref::CopyOp> {
  using ConvertOpToLLVMPattern<memref::CopyOp>::ConvertOpToLLVMPattern;

  LogicalResult
  lowerToMemCopyIntrinsic(memref::CopyOp op, OpAdaptor adaptor,
                          ConversionPatternRewriter &rewriter) const {
    auto loc = op.getLoc();
    auto srcType = op.getSource().getType().dyn_cast<MemRefType>();

    MemRefDescriptor srcDesc(adaptor.getSource());

    // Compute number of elements.
    Value numElements = rewriter.create<LLVM::ConstantOp>(
        loc, getIndexType(), rewriter.getIndexAttr(1));
    for (int pos = 0; pos < srcType.getRank(); ++pos) {
      auto size = srcDesc.size(rewriter, loc, pos);
      numElements = rewriter.create<LLVM::MulOp>(loc, numElements, size);
    }

    // Get element size.
    auto sizeInBytes = getSizeInBytes(loc, srcType.getElementType(), rewriter);
    // Compute total.
    Value totalSize =
        rewriter.create<LLVM::MulOp>(loc, numElements, sizeInBytes);

    Value srcBasePtr = srcDesc.alignedPtr(rewriter, loc);
    Value srcOffset = srcDesc.offset(rewriter, loc);
    Value srcPtr = rewriter.create<LLVM::GEPOp>(loc, srcBasePtr.getType(),
                                                srcBasePtr, srcOffset);
    MemRefDescriptor targetDesc(adaptor.getTarget());
    Value targetBasePtr = targetDesc.alignedPtr(rewriter, loc);
    Value targetOffset = targetDesc.offset(rewriter, loc);
    Value targetPtr = rewriter.create<LLVM::GEPOp>(loc, targetBasePtr.getType(),
                                                   targetBasePtr, targetOffset);
    Value isVolatile = rewriter.create<LLVM::ConstantOp>(
        loc, typeConverter->convertType(rewriter.getI1Type()),
        rewriter.getBoolAttr(false));
    rewriter.create<LLVM::MemcpyOp>(loc, targetPtr, srcPtr, totalSize,
                                    isVolatile);
    rewriter.eraseOp(op);

    return success();
  }

  LogicalResult
  lowerToMemCopyFunctionCall(memref::CopyOp op, OpAdaptor adaptor,
                             ConversionPatternRewriter &rewriter) const {
    auto loc = op.getLoc();
    auto srcType = op.getSource().getType().cast<BaseMemRefType>();
    auto targetType = op.getTarget().getType().cast<BaseMemRefType>();

    // First make sure we have an unranked memref descriptor representation.
    auto makeUnranked = [&, this](Value ranked, BaseMemRefType type) {
      auto rank = rewriter.create<LLVM::ConstantOp>(
          loc, getIndexType(), rewriter.getIndexAttr(type.getRank()));
      auto *typeConverter = getTypeConverter();
      auto ptr =
          typeConverter->promoteOneMemRefDescriptor(loc, ranked, rewriter);
      auto voidPtr =
          rewriter.create<LLVM::BitcastOp>(loc, getVoidPtrType(), ptr)
              .getResult();
      auto unrankedType =
          UnrankedMemRefType::get(type.getElementType(), type.getMemorySpace());
      return UnrankedMemRefDescriptor::pack(rewriter, loc, *typeConverter,
                                            unrankedType,
                                            ValueRange{rank, voidPtr});
    };

    Value unrankedSource = srcType.hasRank()
                               ? makeUnranked(adaptor.getSource(), srcType)
                               : adaptor.getSource();
    Value unrankedTarget = targetType.hasRank()
                               ? makeUnranked(adaptor.getTarget(), targetType)
                               : adaptor.getTarget();

    // Now promote the unranked descriptors to the stack.
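    //
    // The generic copy function expects pointers to the two unranked
    // descriptors (plus the element size), so each descriptor value is stored
    // into a one-element stack slot and the slot's address is passed to the
    // call emitted below.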
    auto one = rewriter.create<LLVM::ConstantOp>(loc, getIndexType(),
                                                 rewriter.getIndexAttr(1));
    auto promote = [&](Value desc) {
      auto ptrType = LLVM::LLVMPointerType::get(desc.getType());
      auto allocated =
          rewriter.create<LLVM::AllocaOp>(loc, ptrType, ValueRange{one});
      rewriter.create<LLVM::StoreOp>(loc, desc, allocated);
      return allocated;
    };

    auto sourcePtr = promote(unrankedSource);
    auto targetPtr = promote(unrankedTarget);

    unsigned typeSize =
        mlir::DataLayout::closest(op).getTypeSize(srcType.getElementType());
    auto elemSize = rewriter.create<LLVM::ConstantOp>(
        loc, getIndexType(), rewriter.getIndexAttr(typeSize));
    auto copyFn = LLVM::lookupOrCreateMemRefCopyFn(
        op->getParentOfType<ModuleOp>(), getIndexType(), sourcePtr.getType());
    rewriter.create<LLVM::CallOp>(loc, copyFn,
                                  ValueRange{elemSize, sourcePtr, targetPtr});
    rewriter.eraseOp(op);

    return success();
  }

  LogicalResult
  matchAndRewrite(memref::CopyOp op, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    auto srcType = op.getSource().getType().cast<BaseMemRefType>();
    auto targetType = op.getTarget().getType().cast<BaseMemRefType>();

    auto isContiguousMemrefType = [](BaseMemRefType type) {
      auto memrefType = type.dyn_cast<mlir::MemRefType>();
      // We can use memcpy for memrefs if they have an identity layout or are
      // contiguous with an arbitrary offset. Ignore empty memrefs, which are a
      // special case handled by memrefCopy.
      return memrefType &&
             (memrefType.getLayout().isIdentity() ||
              (memrefType.hasStaticShape() && memrefType.getNumElements() > 0 &&
               isStaticShapeAndContiguousRowMajor(memrefType)));
    };

    if (isContiguousMemrefType(srcType) && isContiguousMemrefType(targetType))
      return lowerToMemCopyIntrinsic(op, adaptor, rewriter);

    return lowerToMemCopyFunctionCall(op, adaptor, rewriter);
  }
};

/// Extracts allocated, aligned pointers and offset from a ranked or unranked
/// memref type. In the unranked case, the fields are extracted from the
/// underlying ranked descriptor.
static void extractPointersAndOffset(Location loc,
                                     ConversionPatternRewriter &rewriter,
                                     LLVMTypeConverter &typeConverter,
                                     Value originalOperand,
                                     Value convertedOperand,
                                     Value *allocatedPtr, Value *alignedPtr,
                                     Value *offset = nullptr) {
  Type operandType = originalOperand.getType();
  if (operandType.isa<MemRefType>()) {
    MemRefDescriptor desc(convertedOperand);
    *allocatedPtr = desc.allocatedPtr(rewriter, loc);
    *alignedPtr = desc.alignedPtr(rewriter, loc);
    if (offset != nullptr)
      *offset = desc.offset(rewriter, loc);
    return;
  }

  unsigned memorySpace =
      operandType.cast<UnrankedMemRefType>().getMemorySpaceAsInt();
  Type elementType = operandType.cast<UnrankedMemRefType>().getElementType();
  Type llvmElementType = typeConverter.convertType(elementType);
  Type elementPtrPtrType = LLVM::LLVMPointerType::get(
      LLVM::LLVMPointerType::get(llvmElementType, memorySpace));

  // Extract pointer to the underlying ranked memref descriptor and cast it to
  // ElemType**.
  UnrankedMemRefDescriptor unrankedDesc(convertedOperand);
  Value underlyingDescPtr = unrankedDesc.memRefDescPtr(rewriter, loc);

  *allocatedPtr = UnrankedMemRefDescriptor::allocatedPtr(
      rewriter, loc, underlyingDescPtr, elementPtrPtrType);
  *alignedPtr = UnrankedMemRefDescriptor::alignedPtr(
      rewriter, loc, typeConverter, underlyingDescPtr, elementPtrPtrType);
  if (offset != nullptr) {
    *offset = UnrankedMemRefDescriptor::offset(
        rewriter, loc, typeConverter, underlyingDescPtr, elementPtrPtrType);
  }
}

struct MemRefReinterpretCastOpLowering
    : public ConvertOpToLLVMPattern<memref::ReinterpretCastOp> {
  using ConvertOpToLLVMPattern<
      memref::ReinterpretCastOp>::ConvertOpToLLVMPattern;

  LogicalResult
  matchAndRewrite(memref::ReinterpretCastOp castOp, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    Type srcType = castOp.getSource().getType();

    Value descriptor;
    if (failed(convertSourceMemRefToDescriptor(rewriter, srcType, castOp,
                                               adaptor, &descriptor)))
      return failure();
    rewriter.replaceOp(castOp, {descriptor});
    return success();
  }

private:
  LogicalResult convertSourceMemRefToDescriptor(
      ConversionPatternRewriter &rewriter, Type srcType,
      memref::ReinterpretCastOp castOp,
      memref::ReinterpretCastOp::Adaptor adaptor, Value *descriptor) const {
    MemRefType targetMemRefType =
        castOp.getResult().getType().cast<MemRefType>();
    auto llvmTargetDescriptorTy = typeConverter->convertType(targetMemRefType)
                                      .dyn_cast_or_null<LLVM::LLVMStructType>();
    if (!llvmTargetDescriptorTy)
      return failure();

    // Create descriptor.
    Location loc = castOp.getLoc();
    auto desc = MemRefDescriptor::undef(rewriter, loc, llvmTargetDescriptorTy);

    // Set allocated and aligned pointers.
    Value allocatedPtr, alignedPtr;
    extractPointersAndOffset(loc, rewriter, *getTypeConverter(),
                             castOp.getSource(), adaptor.getSource(),
                             &allocatedPtr, &alignedPtr);
    desc.setAllocatedPtr(rewriter, loc, allocatedPtr);
    desc.setAlignedPtr(rewriter, loc, alignedPtr);

    // Set offset.
    if (castOp.isDynamicOffset(0))
      desc.setOffset(rewriter, loc, adaptor.getOffsets()[0]);
    else
      desc.setConstantOffset(rewriter, loc, castOp.getStaticOffset(0));

    // Set sizes and strides.
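    //
    // Illustrative sketch: for something like
    //   memref.reinterpret_cast %src to offset: [0], sizes: [%n, 4],
    //                                            strides: [4, 1]
    // the loop below sets size 0 from the dynamic operand %n, size 1 to the
    // constant 4, and both strides to constants, mirroring which entries of
    // the static sizes/strides are marked dynamic.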
    unsigned dynSizeId = 0;
    unsigned dynStrideId = 0;
    for (unsigned i = 0, e = targetMemRefType.getRank(); i < e; ++i) {
      if (castOp.isDynamicSize(i))
        desc.setSize(rewriter, loc, i, adaptor.getSizes()[dynSizeId++]);
      else
        desc.setConstantSize(rewriter, loc, i, castOp.getStaticSize(i));

      if (castOp.isDynamicStride(i))
        desc.setStride(rewriter, loc, i, adaptor.getStrides()[dynStrideId++]);
      else
        desc.setConstantStride(rewriter, loc, i, castOp.getStaticStride(i));
    }
    *descriptor = desc;
    return success();
  }
};

struct MemRefReshapeOpLowering
    : public ConvertOpToLLVMPattern<memref::ReshapeOp> {
  using ConvertOpToLLVMPattern<memref::ReshapeOp>::ConvertOpToLLVMPattern;

  LogicalResult
  matchAndRewrite(memref::ReshapeOp reshapeOp, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    Type srcType = reshapeOp.getSource().getType();

    Value descriptor;
    if (failed(convertSourceMemRefToDescriptor(rewriter, srcType, reshapeOp,
                                               adaptor, &descriptor)))
      return failure();
    rewriter.replaceOp(reshapeOp, {descriptor});
    return success();
  }

private:
  LogicalResult
  convertSourceMemRefToDescriptor(ConversionPatternRewriter &rewriter,
                                  Type srcType, memref::ReshapeOp reshapeOp,
                                  memref::ReshapeOp::Adaptor adaptor,
                                  Value *descriptor) const {
    auto shapeMemRefType = reshapeOp.getShape().getType().cast<MemRefType>();
    if (shapeMemRefType.hasStaticShape()) {
      MemRefType targetMemRefType =
          reshapeOp.getResult().getType().cast<MemRefType>();
      auto llvmTargetDescriptorTy =
          typeConverter->convertType(targetMemRefType)
              .dyn_cast_or_null<LLVM::LLVMStructType>();
      if (!llvmTargetDescriptorTy)
        return failure();

      // Create descriptor.
      Location loc = reshapeOp.getLoc();
      auto desc =
          MemRefDescriptor::undef(rewriter, loc, llvmTargetDescriptorTy);

      // Set allocated and aligned pointers.
      Value allocatedPtr, alignedPtr;
      extractPointersAndOffset(loc, rewriter, *getTypeConverter(),
                               reshapeOp.getSource(), adaptor.getSource(),
                               &allocatedPtr, &alignedPtr);
      desc.setAllocatedPtr(rewriter, loc, allocatedPtr);
      desc.setAlignedPtr(rewriter, loc, alignedPtr);

      // Extract the offset and strides from the type.
      int64_t offset;
      SmallVector<int64_t> strides;
      if (failed(getStridesAndOffset(targetMemRefType, strides, offset)))
        return rewriter.notifyMatchFailure(
            reshapeOp, "failed to get stride and offset exprs");

      if (!isStaticStrideOrOffset(offset))
        return rewriter.notifyMatchFailure(reshapeOp,
                                           "dynamic offset is unsupported");

      desc.setConstantOffset(rewriter, loc, offset);

      assert(targetMemRefType.getLayout().isIdentity() &&
             "Identity layout map is a precondition of a valid reshape op");

      Value stride = nullptr;
      int64_t targetRank = targetMemRefType.getRank();
      for (auto i : llvm::reverse(llvm::seq<int64_t>(0, targetRank))) {
        if (!ShapedType::isDynamicStrideOrOffset(strides[i])) {
          // If the stride for this dimension is static, use its constant value
          // directly.
          stride = createIndexConstant(rewriter, loc, strides[i]);
        } else if (!stride) {
          // `stride` is null only in the first iteration of the loop.
          // However, since the target memref has an identity layout, we can
          // safely set the innermost stride to 1.
          stride = createIndexConstant(rewriter, loc, 1);
        }

        Value dimSize;
        int64_t size = targetMemRefType.getDimSize(i);
        // If the size of this dimension is dynamic, then load it at runtime
        // from the shape operand.
        if (!ShapedType::isDynamic(size)) {
          dimSize = createIndexConstant(rewriter, loc, size);
        } else {
          Value shapeOp = reshapeOp.getShape();
          Value index = createIndexConstant(rewriter, loc, i);
          dimSize = rewriter.create<memref::LoadOp>(loc, shapeOp, index);
        }

        desc.setSize(rewriter, loc, i, dimSize);
        desc.setStride(rewriter, loc, i, stride);

        // Prepare the stride value for the next dimension.
        stride = rewriter.create<LLVM::MulOp>(loc, stride, dimSize);
      }

      *descriptor = desc;
      return success();
    }

    // The shape is a rank-1 memref with unknown length.
    Location loc = reshapeOp.getLoc();
    MemRefDescriptor shapeDesc(adaptor.getShape());
    Value resultRank = shapeDesc.size(rewriter, loc, 0);

    // Extract address space and element type.
    auto targetType =
        reshapeOp.getResult().getType().cast<UnrankedMemRefType>();
    unsigned addressSpace = targetType.getMemorySpaceAsInt();
    Type elementType = targetType.getElementType();

    // Create the unranked memref descriptor that holds the ranked one. The
    // inner descriptor is allocated on the stack.
    auto targetDesc = UnrankedMemRefDescriptor::undef(
        rewriter, loc, typeConverter->convertType(targetType));
    targetDesc.setRank(rewriter, loc, resultRank);
    SmallVector<Value, 4> sizes;
    UnrankedMemRefDescriptor::computeSizes(rewriter, loc, *getTypeConverter(),
                                           targetDesc, sizes);
    Value underlyingDescPtr = rewriter.create<LLVM::AllocaOp>(
        loc, getVoidPtrType(), sizes.front(), llvm::None);
    targetDesc.setMemRefDescPtr(rewriter, loc, underlyingDescPtr);

    // Extract pointers and offset from the source memref.
    Value allocatedPtr, alignedPtr, offset;
    extractPointersAndOffset(loc, rewriter, *getTypeConverter(),
                             reshapeOp.getSource(), adaptor.getSource(),
                             &allocatedPtr, &alignedPtr, &offset);

    // Set pointers and offset.
    Type llvmElementType = typeConverter->convertType(elementType);
    auto elementPtrPtrType = LLVM::LLVMPointerType::get(
        LLVM::LLVMPointerType::get(llvmElementType, addressSpace));
    UnrankedMemRefDescriptor::setAllocatedPtr(rewriter, loc, underlyingDescPtr,
                                              elementPtrPtrType, allocatedPtr);
    UnrankedMemRefDescriptor::setAlignedPtr(rewriter, loc, *getTypeConverter(),
                                            underlyingDescPtr,
                                            elementPtrPtrType, alignedPtr);
    UnrankedMemRefDescriptor::setOffset(rewriter, loc, *getTypeConverter(),
                                        underlyingDescPtr, elementPtrPtrType,
                                        offset);

    // Use the offset pointer as base for further addressing. Copy over the
    // new shape and compute strides. For this, we create a loop from rank-1
    // to 0.
1244 Value targetSizesBase = UnrankedMemRefDescriptor::sizeBasePtr( 1245 rewriter, loc, *getTypeConverter(), underlyingDescPtr, 1246 elementPtrPtrType); 1247 Value targetStridesBase = UnrankedMemRefDescriptor::strideBasePtr( 1248 rewriter, loc, *getTypeConverter(), targetSizesBase, resultRank); 1249 Value shapeOperandPtr = shapeDesc.alignedPtr(rewriter, loc); 1250 Value oneIndex = createIndexConstant(rewriter, loc, 1); 1251 Value resultRankMinusOne = 1252 rewriter.create<LLVM::SubOp>(loc, resultRank, oneIndex); 1253 1254 Block *initBlock = rewriter.getInsertionBlock(); 1255 Type indexType = getTypeConverter()->getIndexType(); 1256 Block::iterator remainingOpsIt = std::next(rewriter.getInsertionPoint()); 1257 1258 Block *condBlock = rewriter.createBlock(initBlock->getParent(), {}, 1259 {indexType, indexType}, {loc, loc}); 1260 1261 // Move the remaining initBlock ops to condBlock. 1262 Block *remainingBlock = rewriter.splitBlock(initBlock, remainingOpsIt); 1263 rewriter.mergeBlocks(remainingBlock, condBlock, ValueRange()); 1264 1265 rewriter.setInsertionPointToEnd(initBlock); 1266 rewriter.create<LLVM::BrOp>(loc, ValueRange({resultRankMinusOne, oneIndex}), 1267 condBlock); 1268 rewriter.setInsertionPointToStart(condBlock); 1269 Value indexArg = condBlock->getArgument(0); 1270 Value strideArg = condBlock->getArgument(1); 1271 1272 Value zeroIndex = createIndexConstant(rewriter, loc, 0); 1273 Value pred = rewriter.create<LLVM::ICmpOp>( 1274 loc, IntegerType::get(rewriter.getContext(), 1), 1275 LLVM::ICmpPredicate::sge, indexArg, zeroIndex); 1276 1277 Block *bodyBlock = 1278 rewriter.splitBlock(condBlock, rewriter.getInsertionPoint()); 1279 rewriter.setInsertionPointToStart(bodyBlock); 1280 1281 // Copy size from shape to descriptor. 1282 Type llvmIndexPtrType = LLVM::LLVMPointerType::get(indexType); 1283 Value sizeLoadGep = rewriter.create<LLVM::GEPOp>( 1284 loc, llvmIndexPtrType, shapeOperandPtr, ValueRange{indexArg}); 1285 Value size = rewriter.create<LLVM::LoadOp>(loc, sizeLoadGep); 1286 UnrankedMemRefDescriptor::setSize(rewriter, loc, *getTypeConverter(), 1287 targetSizesBase, indexArg, size); 1288 1289 // Write stride value and compute next one. 1290 UnrankedMemRefDescriptor::setStride(rewriter, loc, *getTypeConverter(), 1291 targetStridesBase, indexArg, strideArg); 1292 Value nextStride = rewriter.create<LLVM::MulOp>(loc, strideArg, size); 1293 1294 // Decrement loop counter and branch back. 1295 Value decrement = rewriter.create<LLVM::SubOp>(loc, indexArg, oneIndex); 1296 rewriter.create<LLVM::BrOp>(loc, ValueRange({decrement, nextStride}), 1297 condBlock); 1298 1299 Block *remainder = 1300 rewriter.splitBlock(bodyBlock, rewriter.getInsertionPoint()); 1301 1302 // Hook up the cond exit to the remainder. 1303 rewriter.setInsertionPointToEnd(condBlock); 1304 rewriter.create<LLVM::CondBrOp>(loc, pred, bodyBlock, llvm::None, remainder, 1305 llvm::None); 1306 1307 // Reset position to beginning of new remainder block. 1308 rewriter.setInsertionPointToStart(remainder); 1309 1310 *descriptor = targetDesc; 1311 return success(); 1312 } 1313 }; 1314 1315 /// Helper function to convert a vector of `OpFoldResult`s into a vector of 1316 /// `Value`s. 
static SmallVector<Value> getAsValues(OpBuilder &b, Location loc,
                                      Type &llvmIndexType,
                                      ArrayRef<OpFoldResult> valueOrAttrVec) {
  return llvm::to_vector<4>(
      llvm::map_range(valueOrAttrVec, [&](OpFoldResult value) -> Value {
        if (auto attr = value.dyn_cast<Attribute>())
          return b.create<LLVM::ConstantOp>(loc, llvmIndexType, attr);
        return value.get<Value>();
      }));
}

/// Compute a map that for a given dimension of the expanded type gives the
/// dimension in the collapsed type it maps to. Essentially it is the inverse
/// of the `reassociation` maps.
static DenseMap<int64_t, int64_t>
getExpandedDimToCollapsedDimMap(ArrayRef<ReassociationIndices> reassociation) {
  llvm::DenseMap<int64_t, int64_t> expandedDimToCollapsedDim;
  for (auto &en : enumerate(reassociation)) {
    for (auto dim : en.value())
      expandedDimToCollapsedDim[dim] = en.index();
  }
  return expandedDimToCollapsedDim;
}

static OpFoldResult
getExpandedOutputDimSize(OpBuilder &b, Location loc, Type &llvmIndexType,
                         int64_t outDimIndex, ArrayRef<int64_t> outStaticShape,
                         MemRefDescriptor &inDesc,
                         ArrayRef<int64_t> inStaticShape,
                         ArrayRef<ReassociationIndices> reassociation,
                         DenseMap<int64_t, int64_t> &outDimToInDimMap) {
  int64_t outDimSize = outStaticShape[outDimIndex];
  if (!ShapedType::isDynamic(outDimSize))
    return b.getIndexAttr(outDimSize);

  // Compute the product of all the output dim sizes in the same reassociation
  // group except the current dim.
  int64_t inDimIndex = outDimToInDimMap[outDimIndex];
  int64_t otherDimSizesMul = 1;
  for (auto otherDimIndex : reassociation[inDimIndex]) {
    if (otherDimIndex == static_cast<unsigned>(outDimIndex))
      continue;
    int64_t otherDimSize = outStaticShape[otherDimIndex];
    assert(!ShapedType::isDynamic(otherDimSize) &&
           "single dimension cannot be expanded into multiple dynamic "
           "dimensions");
    otherDimSizesMul *= otherDimSize;
  }

  // outDimSize = inDimSize / otherDimSizesMul
  int64_t inDimSize = inStaticShape[inDimIndex];
  Value inDimSizeDynamic =
      ShapedType::isDynamic(inDimSize)
          ? inDesc.size(b, loc, inDimIndex)
          : b.create<LLVM::ConstantOp>(loc, llvmIndexType,
                                       b.getIndexAttr(inDimSize));
  Value outDimSizeDynamic = b.create<LLVM::SDivOp>(
      loc, inDimSizeDynamic,
      b.create<LLVM::ConstantOp>(loc, llvmIndexType,
                                 b.getIndexAttr(otherDimSizesMul)));
  return outDimSizeDynamic;
}

static OpFoldResult getCollapsedOutputDimSize(
    OpBuilder &b, Location loc, Type &llvmIndexType, int64_t outDimIndex,
    int64_t outDimSize, ArrayRef<int64_t> inStaticShape,
    MemRefDescriptor &inDesc, ArrayRef<ReassociationIndices> reassociation) {
  if (!ShapedType::isDynamic(outDimSize))
    return b.getIndexAttr(outDimSize);

  Value c1 = b.create<LLVM::ConstantOp>(loc, llvmIndexType, b.getIndexAttr(1));
  Value outDimSizeDynamic = c1;
  for (auto inDimIndex : reassociation[outDimIndex]) {
    int64_t inDimSize = inStaticShape[inDimIndex];
    Value inDimSizeDynamic =
        ShapedType::isDynamic(inDimSize)
            ? inDesc.size(b, loc, inDimIndex)
            : b.create<LLVM::ConstantOp>(loc, llvmIndexType,
                                         b.getIndexAttr(inDimSize));
    outDimSizeDynamic =
        b.create<LLVM::MulOp>(loc, outDimSizeDynamic, inDimSizeDynamic);
  }
  return outDimSizeDynamic;
}

static SmallVector<OpFoldResult, 4>
getCollapsedOutputShape(OpBuilder &b, Location loc, Type &llvmIndexType,
                        ArrayRef<ReassociationIndices> reassociation,
                        ArrayRef<int64_t> inStaticShape,
                        MemRefDescriptor &inDesc,
                        ArrayRef<int64_t> outStaticShape) {
  return llvm::to_vector<4>(llvm::map_range(
      llvm::seq<int64_t>(0, outStaticShape.size()), [&](int64_t outDimIndex) {
        return getCollapsedOutputDimSize(b, loc, llvmIndexType, outDimIndex,
                                         outStaticShape[outDimIndex],
                                         inStaticShape, inDesc, reassociation);
      }));
}

static SmallVector<OpFoldResult, 4>
getExpandedOutputShape(OpBuilder &b, Location loc, Type &llvmIndexType,
                       ArrayRef<ReassociationIndices> reassociation,
                       ArrayRef<int64_t> inStaticShape,
                       MemRefDescriptor &inDesc,
                       ArrayRef<int64_t> outStaticShape) {
  DenseMap<int64_t, int64_t> outDimToInDimMap =
      getExpandedDimToCollapsedDimMap(reassociation);
  return llvm::to_vector<4>(llvm::map_range(
      llvm::seq<int64_t>(0, outStaticShape.size()), [&](int64_t outDimIndex) {
        return getExpandedOutputDimSize(b, loc, llvmIndexType, outDimIndex,
                                        outStaticShape, inDesc, inStaticShape,
                                        reassociation, outDimToInDimMap);
      }));
}

static SmallVector<Value>
getDynamicOutputShape(OpBuilder &b, Location loc, Type &llvmIndexType,
                      ArrayRef<ReassociationIndices> reassociation,
                      ArrayRef<int64_t> inStaticShape, MemRefDescriptor &inDesc,
                      ArrayRef<int64_t> outStaticShape) {
  return outStaticShape.size() < inStaticShape.size()
             ? getAsValues(b, loc, llvmIndexType,
                           getCollapsedOutputShape(b, loc, llvmIndexType,
                                                   reassociation, inStaticShape,
                                                   inDesc, outStaticShape))
             : getAsValues(b, loc, llvmIndexType,
                           getExpandedOutputShape(b, loc, llvmIndexType,
                                                  reassociation, inStaticShape,
                                                  inDesc, outStaticShape));
}

static void fillInStridesForExpandedMemDescriptor(
    OpBuilder &b, Location loc, MemRefType srcType, MemRefDescriptor &srcDesc,
    MemRefDescriptor &dstDesc, ArrayRef<ReassociationIndices> reassociation) {
  // See comments for computeExpandedLayoutMap for details on how the strides
  // are calculated.
  for (auto &en : llvm::enumerate(reassociation)) {
    auto currentStrideToExpand = srcDesc.stride(b, loc, en.index());
    for (auto dstIndex : llvm::reverse(en.value())) {
      dstDesc.setStride(b, loc, dstIndex, currentStrideToExpand);
      Value size = dstDesc.size(b, loc, dstIndex);
      currentStrideToExpand =
          b.create<LLVM::MulOp>(loc, size, currentStrideToExpand);
    }
  }
}

static void fillInStridesForCollapsedMemDescriptor(
    ConversionPatternRewriter &rewriter, Location loc, Operation *op,
    TypeConverter *typeConverter, MemRefType srcType, MemRefDescriptor &srcDesc,
    MemRefDescriptor &dstDesc, ArrayRef<ReassociationIndices> reassociation) {
  // See comments for computeCollapsedLayoutMap for details on how the strides
  // are calculated.
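  // Illustrative sketch (assuming an identity-layout source): collapsing
  // memref<4x5x6xf32> with reassociation [[0, 1], [2]] starts from source
  // strides [30, 6, 1] and yields result strides [6, 1], i.e. each collapsed
  // stride is the stride of the innermost non-unit dimension of its group.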
  auto srcShape = srcType.getShape();
  for (auto &en : llvm::enumerate(reassociation)) {
    rewriter.setInsertionPoint(op);
    auto dstIndex = en.index();
    ArrayRef<int64_t> ref = llvm::makeArrayRef(en.value());
    while (srcShape[ref.back()] == 1 && ref.size() > 1)
      ref = ref.drop_back();
    if (!ShapedType::isDynamic(srcShape[ref.back()]) || ref.size() == 1) {
      dstDesc.setStride(rewriter, loc, dstIndex,
                        srcDesc.stride(rewriter, loc, ref.back()));
    } else {
      // Iterate over the source strides in reverse order. Skip over the
      // dimensions whose size is 1.
      // TODO: we should take the minimum stride in the reassociation group
      // instead of just the first where the dimension is not 1.
      //
      // +------------------------------------------------------+
      // | curEntry:                                            |
      // |   %srcStride = strides[srcIndex]                     |
      // |   %neOne = cmp sizes[srcIndex], 1                    +--+
      // |   cf.cond_br %neOne, continue(%srcStride), nextEntry |  |
      // +-------------------------+----------------------------+  |
      //                           |                               |
      //                           v                               |
      //             +-----------------------------+               |
      //             | nextEntry:                  |               |
      //             |   ...                       +---+           |
      //             +--------------+--------------+   |           |
      //                            |                  |           |
      //                            v                  |           |
      //             +-----------------------------+   |           |
      //             | nextEntry:                  |   |           |
      //             |   ...                       |   |           |
      //             +--------------+--------------+   | +---------+
      //                            |                  | |
      //                            v                  v v
      //            +--------------------------------------------------+
      //            | continue(%newStride):                            |
      //            |   %newMemRefDes = setStride(%newStride, dstIndex)|
      //            +--------------------------------------------------+
      OpBuilder::InsertionGuard guard(rewriter);
      Block *initBlock = rewriter.getInsertionBlock();
      Block *continueBlock =
          rewriter.splitBlock(initBlock, rewriter.getInsertionPoint());
      continueBlock->insertArgument(unsigned(0), srcDesc.getIndexType(), loc);
      rewriter.setInsertionPointToStart(continueBlock);
      dstDesc.setStride(rewriter, loc, dstIndex, continueBlock->getArgument(0));

      Block *curEntryBlock = initBlock;
      Block *nextEntryBlock;
      for (auto srcIndex : llvm::reverse(ref)) {
        if (srcShape[srcIndex] == 1 && srcIndex != ref.front())
          continue;
        rewriter.setInsertionPointToEnd(curEntryBlock);
        Value srcStride = srcDesc.stride(rewriter, loc, srcIndex);
        if (srcIndex == ref.front()) {
          rewriter.create<LLVM::BrOp>(loc, srcStride, continueBlock);
          break;
        }
        Value one = rewriter.create<LLVM::ConstantOp>(
            loc, typeConverter->convertType(rewriter.getI64Type()),
            rewriter.getI32IntegerAttr(1));
        Value predNeOne = rewriter.create<LLVM::ICmpOp>(
            loc, LLVM::ICmpPredicate::ne, srcDesc.size(rewriter, loc, srcIndex),
            one);
        {
          OpBuilder::InsertionGuard guard(rewriter);
          nextEntryBlock = rewriter.createBlock(
              initBlock->getParent(), Region::iterator(continueBlock), {});
        }
        rewriter.create<LLVM::CondBrOp>(loc, predNeOne, continueBlock,
                                        srcStride, nextEntryBlock, llvm::None);
        curEntryBlock = nextEntryBlock;
      }
    }
  }
}

static void fillInDynamicStridesForMemDescriptor(
    ConversionPatternRewriter &b, Location loc, Operation *op,
    TypeConverter *typeConverter, MemRefType srcType, MemRefType dstType,
    MemRefDescriptor &srcDesc, MemRefDescriptor &dstDesc,
    ArrayRef<ReassociationIndices> reassociation) {
  if (srcType.getRank() > dstType.getRank())
    fillInStridesForCollapsedMemDescriptor(b, loc, op, typeConverter, srcType,
                                           srcDesc, dstDesc, reassociation);
  else
    fillInStridesForExpandedMemDescriptor(b, loc, srcType, srcDesc, dstDesc,
                                          reassociation);
}

// ReshapeOp creates a new view descriptor of the proper rank. Both static and
// dynamic sizes/strides of the result are supported; dynamic strides are
// recomputed from the reassociation indices below when needed.
template <typename ReshapeOp>
class ReassociatingReshapeOpConversion
    : public ConvertOpToLLVMPattern<ReshapeOp> {
public:
  using ConvertOpToLLVMPattern<ReshapeOp>::ConvertOpToLLVMPattern;
  using ReshapeOpAdaptor = typename ReshapeOp::Adaptor;

  LogicalResult
  matchAndRewrite(ReshapeOp reshapeOp, typename ReshapeOp::Adaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    MemRefType dstType = reshapeOp.getResultType();
    MemRefType srcType = reshapeOp.getSrcType();

    int64_t offset;
    SmallVector<int64_t, 4> strides;
    if (failed(getStridesAndOffset(dstType, strides, offset))) {
      return rewriter.notifyMatchFailure(
          reshapeOp, "failed to get stride and offset exprs");
    }

    MemRefDescriptor srcDesc(adaptor.getSrc());
    Location loc = reshapeOp->getLoc();
    auto dstDesc = MemRefDescriptor::undef(
        rewriter, loc, this->typeConverter->convertType(dstType));
    dstDesc.setAllocatedPtr(rewriter, loc, srcDesc.allocatedPtr(rewriter, loc));
    dstDesc.setAlignedPtr(rewriter, loc, srcDesc.alignedPtr(rewriter, loc));
    dstDesc.setOffset(rewriter, loc, srcDesc.offset(rewriter, loc));

    ArrayRef<int64_t> srcStaticShape = srcType.getShape();
    ArrayRef<int64_t> dstStaticShape = dstType.getShape();
    Type llvmIndexType =
        this->typeConverter->convertType(rewriter.getIndexType());
    SmallVector<Value> dstShape = getDynamicOutputShape(
        rewriter, loc, llvmIndexType, reshapeOp.getReassociationIndices(),
        srcStaticShape, srcDesc, dstStaticShape);
    for (auto &en : llvm::enumerate(dstShape))
      dstDesc.setSize(rewriter, loc, en.index(), en.value());

    if (llvm::all_of(strides, isStaticStrideOrOffset)) {
      for (auto &en : llvm::enumerate(strides))
        dstDesc.setConstantStride(rewriter, loc, en.index(), en.value());
    } else if (srcType.getLayout().isIdentity() &&
               dstType.getLayout().isIdentity()) {
      Value c1 = rewriter.create<LLVM::ConstantOp>(loc, llvmIndexType,
                                                   rewriter.getIndexAttr(1));
      Value stride = c1;
      for (auto dimIndex :
           llvm::reverse(llvm::seq<int64_t>(0, dstShape.size()))) {
        dstDesc.setStride(rewriter, loc, dimIndex, stride);
        stride = rewriter.create<LLVM::MulOp>(loc, dstShape[dimIndex], stride);
      }
    } else {
      // There could be mixed static/dynamic strides. For simplicity, we
      // recompute all strides if there is at least one dynamic stride.
      fillInDynamicStridesForMemDescriptor(
          rewriter, loc, reshapeOp, this->typeConverter, srcType, dstType,
          srcDesc, dstDesc, reshapeOp.getReassociationIndices());
    }
    rewriter.replaceOp(reshapeOp, {dstDesc});
    return success();
  }
};

/// Conversion pattern that transforms a subview op into:
///   1. An `llvm.mlir.undef` operation to create a memref descriptor
///   2. Updates to the descriptor to introduce the data ptr, offset, size
///      and stride.
/// The subview op is replaced by the descriptor.
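/// For instance (an illustrative sketch, result type abridged):
///
///   %0 = memref.subview %src[%i, 0] [4, 4] [1, 1] : memref<8x16xf32> to ...
///
/// reuses the source pointers, sets the offset to `srcOffset + %i * 16`, the
/// sizes to [4, 4], and the strides to [16, 1] (the source strides scaled by
/// the subview strides).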
struct SubViewOpLowering : public ConvertOpToLLVMPattern<memref::SubViewOp> {
  using ConvertOpToLLVMPattern<memref::SubViewOp>::ConvertOpToLLVMPattern;

  LogicalResult
  matchAndRewrite(memref::SubViewOp subViewOp, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    auto loc = subViewOp.getLoc();

    auto sourceMemRefType = subViewOp.getSource().getType().cast<MemRefType>();
    auto sourceElementTy =
        typeConverter->convertType(sourceMemRefType.getElementType());

    auto viewMemRefType = subViewOp.getType();
    auto inferredType =
        memref::SubViewOp::inferResultType(
            subViewOp.getSourceType(),
            extractFromI64ArrayAttr(subViewOp.getStaticOffsets()),
            extractFromI64ArrayAttr(subViewOp.getStaticSizes()),
            extractFromI64ArrayAttr(subViewOp.getStaticStrides()))
            .cast<MemRefType>();
    auto targetElementTy =
        typeConverter->convertType(viewMemRefType.getElementType());
    auto targetDescTy = typeConverter->convertType(viewMemRefType);
    if (!sourceElementTy || !targetDescTy || !targetElementTy ||
        !LLVM::isCompatibleType(sourceElementTy) ||
        !LLVM::isCompatibleType(targetElementTy) ||
        !LLVM::isCompatibleType(targetDescTy))
      return failure();

    // Extract the offset and strides from the type.
    int64_t offset;
    SmallVector<int64_t, 4> strides;
    auto successStrides = getStridesAndOffset(inferredType, strides, offset);
    if (failed(successStrides))
      return failure();

    // Create the descriptor.
    if (!LLVM::isCompatibleType(adaptor.getOperands().front().getType()))
      return failure();
    MemRefDescriptor sourceMemRef(adaptor.getOperands().front());
    auto targetMemRef = MemRefDescriptor::undef(rewriter, loc, targetDescTy);

    // Copy the buffer pointer from the old descriptor to the new one.
    Value extracted = sourceMemRef.allocatedPtr(rewriter, loc);
    Value bitcastPtr = rewriter.create<LLVM::BitcastOp>(
        loc,
        LLVM::LLVMPointerType::get(targetElementTy,
                                   viewMemRefType.getMemorySpaceAsInt()),
        extracted);
    targetMemRef.setAllocatedPtr(rewriter, loc, bitcastPtr);

    // Copy the aligned pointer from the old descriptor to the new one.
    extracted = sourceMemRef.alignedPtr(rewriter, loc);
    bitcastPtr = rewriter.create<LLVM::BitcastOp>(
        loc,
        LLVM::LLVMPointerType::get(targetElementTy,
                                   viewMemRefType.getMemorySpaceAsInt()),
        extracted);
    targetMemRef.setAlignedPtr(rewriter, loc, bitcastPtr);

    size_t inferredShapeRank = inferredType.getRank();
    size_t resultShapeRank = viewMemRefType.getRank();

    // Extract strides needed to compute offset.
    SmallVector<Value, 4> strideValues;
    strideValues.reserve(inferredShapeRank);
    for (unsigned i = 0; i < inferredShapeRank; ++i)
      strideValues.push_back(sourceMemRef.stride(rewriter, loc, i));

    // Offset.
    auto llvmIndexType = typeConverter->convertType(rewriter.getIndexType());
    if (!ShapedType::isDynamicStrideOrOffset(offset)) {
      targetMemRef.setConstantOffset(rewriter, loc, offset);
    } else {
      Value baseOffset = sourceMemRef.offset(rewriter, loc);
      // `inferredShapeRank` may be larger than the number of offset operands
      // because of trailing semantics. In this case, the offset is guaranteed
      // to be interpreted as 0 and we can just skip the extra dimensions.
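      // In other words, the dynamic offset computed below is the usual
      // linearization
      //   newOffset = baseOffset + sum_i(offset_i * strideValues[i]),
      // accumulated one dimension at a time with llvm.mul and llvm.add.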
      for (unsigned i = 0, e = std::min(inferredShapeRank,
                                        subViewOp.getMixedOffsets().size());
           i < e; ++i) {
        Value offset =
            // TODO: need OpFoldResult ODS adaptor to clean this up.
            subViewOp.isDynamicOffset(i)
                ? adaptor.getOperands()[subViewOp.getIndexOfDynamicOffset(i)]
                : rewriter.create<LLVM::ConstantOp>(
                      loc, llvmIndexType,
                      rewriter.getI64IntegerAttr(subViewOp.getStaticOffset(i)));
        Value mul = rewriter.create<LLVM::MulOp>(loc, offset, strideValues[i]);
        baseOffset = rewriter.create<LLVM::AddOp>(loc, baseOffset, mul);
      }
      targetMemRef.setOffset(rewriter, loc, baseOffset);
    }

    // Update sizes and strides.
    SmallVector<OpFoldResult> mixedSizes = subViewOp.getMixedSizes();
    SmallVector<OpFoldResult> mixedStrides = subViewOp.getMixedStrides();
    assert(mixedSizes.size() == mixedStrides.size() &&
           "expected sizes and strides of equal length");
    llvm::SmallBitVector unusedDims = subViewOp.getDroppedDims();
    for (int i = inferredShapeRank - 1, j = resultShapeRank - 1;
         i >= 0 && j >= 0; --i) {
      if (unusedDims.test(i))
        continue;

      // `i` may exceed the number of entries in subViewOp.getMixedSizes()
      // because of trailing semantics. In this case, the size is guaranteed to
      // be interpreted as Dim and the stride as 1.
      Value size, stride;
      if (static_cast<unsigned>(i) >= mixedSizes.size()) {
        // If the static size is available, use it directly. This is similar to
        // the folding of dim(constant-op) but removes the need for dim to be
        // aware of LLVM constants and for this pass to be aware of std
        // constants.
        int64_t staticSize =
            subViewOp.getSource().getType().cast<MemRefType>().getShape()[i];
        if (staticSize != ShapedType::kDynamicSize) {
          size = rewriter.create<LLVM::ConstantOp>(
              loc, llvmIndexType, rewriter.getI64IntegerAttr(staticSize));
        } else {
          Value pos = rewriter.create<LLVM::ConstantOp>(
              loc, llvmIndexType, rewriter.getI64IntegerAttr(i));
          Value dim =
              rewriter.create<memref::DimOp>(loc, subViewOp.getSource(), pos);
          auto cast = rewriter.create<UnrealizedConversionCastOp>(
              loc, llvmIndexType, dim);
          size = cast.getResult(0);
        }
        stride = rewriter.create<LLVM::ConstantOp>(
            loc, llvmIndexType, rewriter.getI64IntegerAttr(1));
      } else {
        // TODO: need OpFoldResult ODS adaptor to clean this up.
        size =
            subViewOp.isDynamicSize(i)
                ? adaptor.getOperands()[subViewOp.getIndexOfDynamicSize(i)]
                : rewriter.create<LLVM::ConstantOp>(
                      loc, llvmIndexType,
                      rewriter.getI64IntegerAttr(subViewOp.getStaticSize(i)));
        if (!ShapedType::isDynamicStrideOrOffset(strides[i])) {
          stride = rewriter.create<LLVM::ConstantOp>(
              loc, llvmIndexType, rewriter.getI64IntegerAttr(strides[i]));
        } else {
          stride =
              subViewOp.isDynamicStride(i)
                  ? adaptor.getOperands()[subViewOp.getIndexOfDynamicStride(i)]
                  : rewriter.create<LLVM::ConstantOp>(
                        loc, llvmIndexType,
                        rewriter.getI64IntegerAttr(
                            subViewOp.getStaticStride(i)));
          stride = rewriter.create<LLVM::MulOp>(loc, stride, strideValues[i]);
        }
      }
      targetMemRef.setSize(rewriter, loc, j, size);
      targetMemRef.setStride(rewriter, loc, j, stride);
      j--;
    }

    rewriter.replaceOp(subViewOp, {targetMemRef});
    return success();
  }
};

/// Conversion pattern that transforms a transpose op into:
///   1. An `llvm.mlir.undef` operation to create a memref descriptor.
///   2. Updates to the descriptor to introduce the data ptr, offset, sizes and
///      strides, where the sizes and strides are permutations of the original
///      values.
/// The transpose op is replaced by the descriptor.
class TransposeOpLowering : public ConvertOpToLLVMPattern<memref::TransposeOp> {
public:
  using ConvertOpToLLVMPattern<memref::TransposeOp>::ConvertOpToLLVMPattern;

  LogicalResult
  matchAndRewrite(memref::TransposeOp transposeOp, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    auto loc = transposeOp.getLoc();
    MemRefDescriptor viewMemRef(adaptor.getIn());

    // No permutation, early exit.
    if (transposeOp.getPermutation().isIdentity())
      return rewriter.replaceOp(transposeOp, {viewMemRef}), success();

    auto targetMemRef = MemRefDescriptor::undef(
        rewriter, loc, typeConverter->convertType(transposeOp.getShapedType()));

    // Copy the base and aligned pointers from the old descriptor to the new
    // one.
    targetMemRef.setAllocatedPtr(rewriter, loc,
                                 viewMemRef.allocatedPtr(rewriter, loc));
    targetMemRef.setAlignedPtr(rewriter, loc,
                               viewMemRef.alignedPtr(rewriter, loc));

    // Copy the offset from the old descriptor to the new one.
    targetMemRef.setOffset(rewriter, loc, viewMemRef.offset(rewriter, loc));

    // Iterate over the dimensions and apply size/stride permutation.
    for (const auto &en :
         llvm::enumerate(transposeOp.getPermutation().getResults())) {
      int sourcePos = en.index();
      int targetPos = en.value().cast<AffineDimExpr>().getPosition();
      targetMemRef.setSize(rewriter, loc, targetPos,
                           viewMemRef.size(rewriter, loc, sourcePos));
      targetMemRef.setStride(rewriter, loc, targetPos,
                             viewMemRef.stride(rewriter, loc, sourcePos));
    }

    rewriter.replaceOp(transposeOp, {targetMemRef});
    return success();
  }
};

/// Conversion pattern that transforms a view op into:
///   1. An `llvm.mlir.undef` operation to create a memref descriptor
///   2. Updates to the descriptor to introduce the data ptr, offset, size
///      and stride.
/// The view op is replaced by the descriptor.
struct ViewOpLowering : public ConvertOpToLLVMPattern<memref::ViewOp> {
  using ConvertOpToLLVMPattern<memref::ViewOp>::ConvertOpToLLVMPattern;

  // Build and return the value for the idx^th shape dimension, either by
  // returning the constant shape dimension or the proper dynamic size (found
  // by counting the dynamic dimensions preceding `idx`).
  Value getSize(ConversionPatternRewriter &rewriter, Location loc,
                ArrayRef<int64_t> shape, ValueRange dynamicSizes,
                unsigned idx) const {
    assert(idx < shape.size());
    if (!ShapedType::isDynamic(shape[idx]))
      return createIndexConstant(rewriter, loc, shape[idx]);
    // Count the number of dynamic dims in range [0, idx]
    unsigned nDynamic = llvm::count_if(shape.take_front(idx), [](int64_t v) {
      return ShapedType::isDynamic(v);
    });
    return dynamicSizes[nDynamic];
  }

  // Build and return the idx^th stride, either by returning the constant
  // stride or by computing the dynamic stride from the current
  // `runningStride` and `nextSize`.
  // The caller should keep a running stride and update it with the result
  // returned by this function.
  Value getStride(ConversionPatternRewriter &rewriter, Location loc,
                  ArrayRef<int64_t> strides, Value nextSize,
                  Value runningStride, unsigned idx) const {
    assert(idx < strides.size());
    if (!ShapedType::isDynamicStrideOrOffset(strides[idx]))
      return createIndexConstant(rewriter, loc, strides[idx]);
    if (nextSize)
      return runningStride
                 ? rewriter.create<LLVM::MulOp>(loc, runningStride, nextSize)
                 : nextSize;
    assert(!runningStride);
    return createIndexConstant(rewriter, loc, 1);
  }

  LogicalResult
  matchAndRewrite(memref::ViewOp viewOp, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    auto loc = viewOp.getLoc();

    auto viewMemRefType = viewOp.getType();
    auto targetElementTy =
        typeConverter->convertType(viewMemRefType.getElementType());
    auto targetDescTy = typeConverter->convertType(viewMemRefType);
    if (!targetDescTy || !targetElementTy ||
        !LLVM::isCompatibleType(targetElementTy) ||
        !LLVM::isCompatibleType(targetDescTy))
      return viewOp.emitWarning("Target descriptor type not converted to LLVM"),
             failure();

    int64_t offset;
    SmallVector<int64_t, 4> strides;
    auto successStrides = getStridesAndOffset(viewMemRefType, strides, offset);
    if (failed(successStrides))
      return viewOp.emitWarning("cannot cast to non-strided shape"), failure();
    assert(offset == 0 && "expected offset to be 0");

    // Target memref must be contiguous in memory (innermost stride is 1), or
    // empty (special case when at least one of the memref dimensions is 0).
    if (!strides.empty() && (strides.back() != 1 && strides.back() != 0))
      return viewOp.emitWarning("cannot cast to non-contiguous shape"),
             failure();

    // Create the descriptor.
    MemRefDescriptor sourceMemRef(adaptor.getSource());
    auto targetMemRef = MemRefDescriptor::undef(rewriter, loc, targetDescTy);

    // Field 1: Copy the allocated pointer, used for malloc/free.
    Value allocatedPtr = sourceMemRef.allocatedPtr(rewriter, loc);
    auto srcMemRefType = viewOp.getSource().getType().cast<MemRefType>();
    Value bitcastPtr = rewriter.create<LLVM::BitcastOp>(
        loc,
        LLVM::LLVMPointerType::get(targetElementTy,
                                   srcMemRefType.getMemorySpaceAsInt()),
        allocatedPtr);
    targetMemRef.setAllocatedPtr(rewriter, loc, bitcastPtr);

    // Field 2: Copy the actual aligned pointer to payload.
    Value alignedPtr = sourceMemRef.alignedPtr(rewriter, loc);
    alignedPtr = rewriter.create<LLVM::GEPOp>(
        loc, alignedPtr.getType(), alignedPtr, adaptor.getByteShift());
    bitcastPtr = rewriter.create<LLVM::BitcastOp>(
        loc,
        LLVM::LLVMPointerType::get(targetElementTy,
                                   srcMemRefType.getMemorySpaceAsInt()),
        alignedPtr);
    targetMemRef.setAlignedPtr(rewriter, loc, bitcastPtr);

    // Field 3: The offset in the resulting type must be 0. This is because of
    // the type change: an offset on srcType* may not be expressible as an
    // offset on dstType*.
    targetMemRef.setOffset(rewriter, loc,
                           createIndexConstant(rewriter, loc, offset));

    // Early exit for 0-D corner case.
    if (viewMemRefType.getRank() == 0)
      return rewriter.replaceOp(viewOp, {targetMemRef}), success();

    // Fields 4 and 5: Update sizes and strides.
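    // The strides are rebuilt innermost-first with a running stride. As an
    // illustrative example (assuming a contiguous memref<?x?x8xf32> result),
    // the loop below produces the strides [8 * %size1, 8, 1].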
    Value stride = nullptr, nextSize = nullptr;
    for (int i = viewMemRefType.getRank() - 1; i >= 0; --i) {
      // Update size.
      Value size = getSize(rewriter, loc, viewMemRefType.getShape(),
                           adaptor.getSizes(), i);
      targetMemRef.setSize(rewriter, loc, i, size);
      // Update stride.
      stride = getStride(rewriter, loc, strides, nextSize, stride, i);
      targetMemRef.setStride(rewriter, loc, i, stride);
      nextSize = size;
    }

    rewriter.replaceOp(viewOp, {targetMemRef});
    return success();
  }
};

//===----------------------------------------------------------------------===//
// AtomicRMWOpLowering
//===----------------------------------------------------------------------===//

/// Try to match the kind of a memref.atomic_rmw to determine whether to use a
/// lowering to llvm.atomicrmw or fall back to llvm.cmpxchg.
static Optional<LLVM::AtomicBinOp>
matchSimpleAtomicOp(memref::AtomicRMWOp atomicOp) {
  switch (atomicOp.getKind()) {
  case arith::AtomicRMWKind::addf:
    return LLVM::AtomicBinOp::fadd;
  case arith::AtomicRMWKind::addi:
    return LLVM::AtomicBinOp::add;
  case arith::AtomicRMWKind::assign:
    return LLVM::AtomicBinOp::xchg;
  case arith::AtomicRMWKind::maxs:
    return LLVM::AtomicBinOp::max;
  case arith::AtomicRMWKind::maxu:
    return LLVM::AtomicBinOp::umax;
  case arith::AtomicRMWKind::mins:
    return LLVM::AtomicBinOp::min;
  case arith::AtomicRMWKind::minu:
    return LLVM::AtomicBinOp::umin;
  case arith::AtomicRMWKind::ori:
    return LLVM::AtomicBinOp::_or;
  case arith::AtomicRMWKind::andi:
    return LLVM::AtomicBinOp::_and;
  default:
    return llvm::None;
  }
  llvm_unreachable("Invalid AtomicRMWKind");
}

struct AtomicRMWOpLowering : public LoadStoreOpLowering<memref::AtomicRMWOp> {
  using Base::Base;

  LogicalResult
  matchAndRewrite(memref::AtomicRMWOp atomicOp, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    if (failed(match(atomicOp)))
      return failure();
    auto maybeKind = matchSimpleAtomicOp(atomicOp);
    if (!maybeKind)
      return failure();
    auto resultType = adaptor.getValue().getType();
    auto memRefType = atomicOp.getMemRefType();
    auto dataPtr =
        getStridedElementPtr(atomicOp.getLoc(), memRefType, adaptor.getMemref(),
                             adaptor.getIndices(), rewriter);
    rewriter.replaceOpWithNewOp<LLVM::AtomicRMWOp>(
        atomicOp, resultType, *maybeKind, dataPtr, adaptor.getValue(),
        LLVM::AtomicOrdering::acq_rel);
    return success();
  }
};

} // namespace

void mlir::populateMemRefToLLVMConversionPatterns(LLVMTypeConverter &converter,
                                                  RewritePatternSet &patterns) {
  // clang-format off
  patterns.add<
      AllocaOpLowering,
      AllocaScopeOpLowering,
      AtomicRMWOpLowering,
      AssumeAlignmentOpLowering,
      DimOpLowering,
      GenericAtomicRMWOpLowering,
      GlobalMemrefOpLowering,
      GetGlobalMemrefOpLowering,
      LoadOpLowering,
      MemRefCastOpLowering,
      MemRefCopyOpLowering,
      MemRefReinterpretCastOpLowering,
      MemRefReshapeOpLowering,
      PrefetchOpLowering,
      RankOpLowering,
      ReassociatingReshapeOpConversion<memref::ExpandShapeOp>,
      ReassociatingReshapeOpConversion<memref::CollapseShapeOp>,
      StoreOpLowering,
      SubViewOpLowering,
      TransposeOpLowering,
      ViewOpLowering>(converter);
  // clang-format on
  auto allocLowering = converter.getOptions().allocLowering;
  if (allocLowering == LowerToLLVMOptions::AllocLowering::AlignedAlloc)
    patterns.add<AlignedAllocOpLowering, AlignedDeallocOpLowering>(converter);
  else if (allocLowering == LowerToLLVMOptions::AllocLowering::Malloc)
    patterns.add<AllocOpLowering, DeallocOpLowering>(converter);
}

namespace {
struct MemRefToLLVMPass : public ConvertMemRefToLLVMBase<MemRefToLLVMPass> {
  MemRefToLLVMPass() = default;

  void runOnOperation() override {
    Operation *op = getOperation();
    const auto &dataLayoutAnalysis = getAnalysis<DataLayoutAnalysis>();
    LowerToLLVMOptions options(&getContext(),
                               dataLayoutAnalysis.getAtOrAbove(op));
    options.allocLowering =
        (useAlignedAlloc ? LowerToLLVMOptions::AllocLowering::AlignedAlloc
                         : LowerToLLVMOptions::AllocLowering::Malloc);
    if (indexBitwidth != kDeriveIndexBitwidthFromDataLayout)
      options.overrideIndexBitwidth(indexBitwidth);

    LLVMTypeConverter typeConverter(&getContext(), options,
                                    &dataLayoutAnalysis);
    RewritePatternSet patterns(&getContext());
    populateMemRefToLLVMConversionPatterns(typeConverter, patterns);
    LLVMConversionTarget target(getContext());
    target.addLegalOp<func::FuncOp>();
    if (failed(applyPartialConversion(op, target, std::move(patterns))))
      signalPassFailure();
  }
};
} // namespace

std::unique_ptr<Pass> mlir::createMemRefToLLVMPass() {
  return std::make_unique<MemRefToLLVMPass>();
}
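// Example usage (illustrative): the resulting pass is typically run as
//   mlir-opt input.mlir -convert-memref-to-llvm
// or added to a pipeline with
//   pm.addPass(mlir::createMemRefToLLVMPass());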