//===- MemRefToLLVM.cpp - MemRef to LLVM dialect conversion --------------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//

#include "mlir/Conversion/MemRefToLLVM/MemRefToLLVM.h"
#include "../PassDetail.h"
#include "mlir/Analysis/DataLayoutAnalysis.h"
#include "mlir/Conversion/LLVMCommon/ConversionTarget.h"
#include "mlir/Conversion/LLVMCommon/Pattern.h"
#include "mlir/Conversion/LLVMCommon/TypeConverter.h"
#include "mlir/Conversion/MemRefToLLVM/AllocLikeConversion.h"
#include "mlir/Dialect/Arithmetic/IR/Arithmetic.h"
#include "mlir/Dialect/Func/IR/FuncOps.h"
#include "mlir/Dialect/LLVMIR/FunctionCallUtils.h"
#include "mlir/Dialect/LLVMIR/LLVMDialect.h"
#include "mlir/Dialect/MemRef/IR/MemRef.h"
#include "mlir/IR/AffineMap.h"
#include "mlir/IR/BlockAndValueMapping.h"
#include "llvm/ADT/SmallBitVector.h"

using namespace mlir;

namespace {

bool isStaticStrideOrOffset(int64_t strideOrOffset) {
  return !ShapedType::isDynamicStrideOrOffset(strideOrOffset);
}

struct AllocOpLowering : public AllocLikeOpLLVMLowering {
  AllocOpLowering(LLVMTypeConverter &converter)
      : AllocLikeOpLLVMLowering(memref::AllocOp::getOperationName(),
                                converter) {}

  LLVM::LLVMFuncOp getAllocFn(ModuleOp module) const {
    bool useGenericFn = getTypeConverter()->getOptions().useGenericFunctions;

    if (useGenericFn)
      return LLVM::lookupOrCreateGenericAllocFn(module, getIndexType());

    return LLVM::lookupOrCreateMallocFn(module, getIndexType());
  }

  std::tuple<Value, Value> allocateBuffer(ConversionPatternRewriter &rewriter,
                                          Location loc, Value sizeBytes,
                                          Operation *op) const override {
    // Heap allocations.
    memref::AllocOp allocOp = cast<memref::AllocOp>(op);
    MemRefType memRefType = allocOp.getType();

    Value alignment;
    if (auto alignmentAttr = allocOp.getAlignment()) {
      alignment = createIndexConstant(rewriter, loc, *alignmentAttr);
    } else if (!memRefType.getElementType().isSignlessIntOrIndexOrFloat()) {
      // In the case where no alignment is specified, we may want to override
      // `malloc's` behavior. `malloc` typically aligns at the size of the
      // biggest scalar on a target HW. For non-scalars, use the natural
      // alignment of the LLVM type given by the LLVM DataLayout.
      alignment = getSizeInBytes(loc, memRefType.getElementType(), rewriter);
    }

    if (alignment) {
      // Adjust the allocation size to consider alignment.
      sizeBytes = rewriter.create<LLVM::AddOp>(loc, sizeBytes, alignment);
    }

    // Allocate the underlying buffer and store a pointer to it in the MemRef
    // descriptor.
    Type elementPtrType = this->getElementPtrType(memRefType);
    auto allocFuncOp = getAllocFn(allocOp->getParentOfType<ModuleOp>());
    auto results = createLLVMCall(rewriter, loc, allocFuncOp, {sizeBytes},
                                  getVoidPtrType());
    Value allocatedPtr =
        rewriter.create<LLVM::BitcastOp>(loc, elementPtrType, results[0]);

    Value alignedPtr = allocatedPtr;
    if (alignment) {
      // Compute the aligned type pointer.
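      // The pointer returned by the allocation function is only guaranteed to
      // be malloc-aligned, so realign it by hand: convert it to an integer,
      // round it up to the next multiple of `alignment` (this is what
      // `createAligned` below is expected to compute, roughly
      // ((addr + alignment - 1) / alignment) * alignment), and convert the
      // result back to a pointer of the element type.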
      Value allocatedInt =
          rewriter.create<LLVM::PtrToIntOp>(loc, getIndexType(), allocatedPtr);
      Value alignmentInt =
          createAligned(rewriter, loc, allocatedInt, alignment);
      alignedPtr =
          rewriter.create<LLVM::IntToPtrOp>(loc, elementPtrType, alignmentInt);
    }

    return std::make_tuple(allocatedPtr, alignedPtr);
  }
};

struct AlignedAllocOpLowering : public AllocLikeOpLLVMLowering {
  AlignedAllocOpLowering(LLVMTypeConverter &converter)
      : AllocLikeOpLLVMLowering(memref::AllocOp::getOperationName(),
                                converter) {}

  /// Returns the memref's element size in bytes using the data layout active
  /// at `op`.
  // TODO: there are other places where this is used. Expose publicly?
  unsigned getMemRefEltSizeInBytes(MemRefType memRefType,
                                   Operation *op) const {
    const DataLayout *layout = &defaultLayout;
    if (const DataLayoutAnalysis *analysis =
            getTypeConverter()->getDataLayoutAnalysis()) {
      layout = &analysis->getAbove(op);
    }
    Type elementType = memRefType.getElementType();
    if (auto memRefElementType = elementType.dyn_cast<MemRefType>())
      return getTypeConverter()->getMemRefDescriptorSize(memRefElementType,
                                                         *layout);
    if (auto memRefElementType = elementType.dyn_cast<UnrankedMemRefType>())
      return getTypeConverter()->getUnrankedMemRefDescriptorSize(
          memRefElementType, *layout);
    return layout->getTypeSize(elementType);
  }

  /// Returns true if the memref size in bytes is known to be a multiple of
  /// `factor`, assuming the data layout active at `op`.
  bool isMemRefSizeMultipleOf(MemRefType type, uint64_t factor,
                              Operation *op) const {
    uint64_t sizeDivisor = getMemRefEltSizeInBytes(type, op);
    for (unsigned i = 0, e = type.getRank(); i < e; i++) {
      if (ShapedType::isDynamic(type.getDimSize(i)))
        continue;
      sizeDivisor = sizeDivisor * type.getDimSize(i);
    }
    return sizeDivisor % factor == 0;
  }

  /// Returns the alignment to be used for the allocation call itself.
  /// aligned_alloc requires the alignment to be a power of two and the
  /// allocation size to be a multiple of that alignment.
  int64_t getAllocationAlignment(memref::AllocOp allocOp) const {
    if (Optional<uint64_t> alignment = allocOp.getAlignment())
      return *alignment;

    // Whenever we don't have alignment set, we will use an alignment
    // consistent with the element type; since the alignment has to be a power
    // of two, we bump the element size to the next power of two if it isn't
    // one already.
    auto eltSizeBytes = getMemRefEltSizeInBytes(allocOp.getType(), allocOp);
    return std::max(kMinAlignedAllocAlignment,
                    llvm::PowerOf2Ceil(eltSizeBytes));
  }

  LLVM::LLVMFuncOp getAllocFn(ModuleOp module) const {
    bool useGenericFn = getTypeConverter()->getOptions().useGenericFunctions;

    if (useGenericFn)
      return LLVM::lookupOrCreateGenericAlignedAllocFn(module, getIndexType());

    return LLVM::lookupOrCreateAlignedAllocFn(module, getIndexType());
  }

  std::tuple<Value, Value> allocateBuffer(ConversionPatternRewriter &rewriter,
                                          Location loc, Value sizeBytes,
                                          Operation *op) const override {
    // Heap allocations.
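    // Illustrative example (the concrete numbers are assumptions, not output
    // produced by this pattern): for
    //   `memref.alloc() {alignment = 32} : memref<10xf32>`
    // the 40-byte size is first padded up to 64, the next multiple of the
    // alignment, and the buffer is then obtained with a call roughly
    // equivalent to `aligned_alloc(32, 64)`.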
    memref::AllocOp allocOp = cast<memref::AllocOp>(op);
    MemRefType memRefType = allocOp.getType();
    int64_t alignment = getAllocationAlignment(allocOp);
    Value allocAlignment = createIndexConstant(rewriter, loc, alignment);

    // aligned_alloc requires the size to be a multiple of the alignment; we
    // will pad the size to the next multiple if necessary.
    if (!isMemRefSizeMultipleOf(memRefType, alignment, op))
      sizeBytes = createAligned(rewriter, loc, sizeBytes, allocAlignment);

    Type elementPtrType = this->getElementPtrType(memRefType);
    auto allocFuncOp = getAllocFn(allocOp->getParentOfType<ModuleOp>());
    auto results =
        createLLVMCall(rewriter, loc, allocFuncOp, {allocAlignment, sizeBytes},
                       getVoidPtrType());
    Value allocatedPtr =
        rewriter.create<LLVM::BitcastOp>(loc, elementPtrType, results[0]);

    return std::make_tuple(allocatedPtr, allocatedPtr);
  }

  /// The minimum alignment to use with aligned_alloc (has to be a power of 2).
  static constexpr uint64_t kMinAlignedAllocAlignment = 16UL;

  /// Default layout to use in absence of the corresponding analysis.
  DataLayout defaultLayout;
};

// Out of line definition, required till C++17.
constexpr uint64_t AlignedAllocOpLowering::kMinAlignedAllocAlignment;

struct AllocaOpLowering : public AllocLikeOpLLVMLowering {
  AllocaOpLowering(LLVMTypeConverter &converter)
      : AllocLikeOpLLVMLowering(memref::AllocaOp::getOperationName(),
                                converter) {}

  /// Allocates the underlying buffer directly with an llvm.alloca. The alloca
  /// already yields a pointer to the element type with the requested
  /// alignment, so the allocated and aligned pointers are the same and no
  /// post-allocation alignment (as needed, e.g., with malloc) is required.
  std::tuple<Value, Value> allocateBuffer(ConversionPatternRewriter &rewriter,
                                          Location loc, Value sizeBytes,
                                          Operation *op) const override {

    // With alloca, one gets a pointer to the element type right away; this is
    // used for stack allocations.
    auto allocaOp = cast<memref::AllocaOp>(op);
    auto elementPtrType = this->getElementPtrType(allocaOp.getType());

    auto allocatedElementPtr = rewriter.create<LLVM::AllocaOp>(
        loc, elementPtrType, sizeBytes,
        allocaOp.getAlignment() ? *allocaOp.getAlignment() : 0);

    return std::make_tuple(allocatedElementPtr, allocatedElementPtr);
  }
};

struct AllocaScopeOpLowering
    : public ConvertOpToLLVMPattern<memref::AllocaScopeOp> {
  using ConvertOpToLLVMPattern<memref::AllocaScopeOp>::ConvertOpToLLVMPattern;

  LogicalResult
  matchAndRewrite(memref::AllocaScopeOp allocaScopeOp, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    OpBuilder::InsertionGuard guard(rewriter);
    Location loc = allocaScopeOp.getLoc();

    // Split the current block before the AllocaScopeOp to create the inlining
    // point.
    auto *currentBlock = rewriter.getInsertionBlock();
    auto *remainingOpsBlock =
        rewriter.splitBlock(currentBlock, rewriter.getInsertionPoint());
    Block *continueBlock;
    if (allocaScopeOp.getNumResults() == 0) {
      continueBlock = remainingOpsBlock;
    } else {
      continueBlock = rewriter.createBlock(
          remainingOpsBlock, allocaScopeOp.getResultTypes(),
          SmallVector<Location>(allocaScopeOp->getNumResults(),
                                allocaScopeOp.getLoc()));
      rewriter.create<LLVM::BrOp>(loc, ValueRange(), remainingOpsBlock);
    }

    // Inline body region.
    Block *beforeBody = &allocaScopeOp.getBodyRegion().front();
    Block *afterBody = &allocaScopeOp.getBodyRegion().back();
    rewriter.inlineRegionBefore(allocaScopeOp.getBodyRegion(), continueBlock);

    // Save stack and then branch into the body of the region.
    rewriter.setInsertionPointToEnd(currentBlock);
    auto stackSaveOp =
        rewriter.create<LLVM::StackSaveOp>(loc, getVoidPtrType());
    rewriter.create<LLVM::BrOp>(loc, ValueRange(), beforeBody);

    // Replace the alloca_scope return with a branch that jumps out of the
    // body. Stack restore before leaving the body region.
    rewriter.setInsertionPointToEnd(afterBody);
    auto returnOp =
        cast<memref::AllocaScopeReturnOp>(afterBody->getTerminator());
    auto branchOp = rewriter.replaceOpWithNewOp<LLVM::BrOp>(
        returnOp, returnOp.getResults(), continueBlock);

    // Insert stack restore before jumping out of the body of the region.
    rewriter.setInsertionPoint(branchOp);
    rewriter.create<LLVM::StackRestoreOp>(loc, stackSaveOp);

    // Replace the op with the values returned from the body region.
    rewriter.replaceOp(allocaScopeOp, continueBlock->getArguments());

    return success();
  }
};

struct AssumeAlignmentOpLowering
    : public ConvertOpToLLVMPattern<memref::AssumeAlignmentOp> {
  using ConvertOpToLLVMPattern<
      memref::AssumeAlignmentOp>::ConvertOpToLLVMPattern;

  LogicalResult
  matchAndRewrite(memref::AssumeAlignmentOp op, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    Value memref = adaptor.getMemref();
    unsigned alignment = op.getAlignment();
    auto loc = op.getLoc();

    MemRefDescriptor memRefDescriptor(memref);
    Value ptr = memRefDescriptor.alignedPtr(rewriter, memref.getLoc());

    // Emit llvm.assume(memref.alignedPtr & (alignment - 1) == 0). Notice that
    // the asserted memref.alignedPtr isn't used anywhere else, as the real
    // users like load/store/views always re-extract memref.alignedPtr as they
    // get lowered.
    //
    // This relies on LLVM's CSE optimization (potentially after SROA), since
    // after CSE all memref.alignedPtr instances get de-duplicated into the
    // same pointer SSA value.
    auto intPtrType =
        getIntPtrType(memRefDescriptor.getElementPtrType().getAddressSpace());
    Value zero = createIndexAttrConstant(rewriter, loc, intPtrType, 0);
    Value mask =
        createIndexAttrConstant(rewriter, loc, intPtrType, alignment - 1);
    Value ptrValue = rewriter.create<LLVM::PtrToIntOp>(loc, intPtrType, ptr);
    rewriter.create<LLVM::AssumeOp>(
        loc, rewriter.create<LLVM::ICmpOp>(
                 loc, LLVM::ICmpPredicate::eq,
                 rewriter.create<LLVM::AndOp>(loc, ptrValue, mask), zero));

    rewriter.eraseOp(op);
    return success();
  }
};

// A `dealloc` is converted into a call to `free` on the underlying data
// buffer. The memref descriptor being an SSA value, there is no need to clean
// it up in any way.
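// For example, `memref.dealloc %buf` becomes, roughly, a bitcast of the
// descriptor's allocated pointer to `i8*` followed by a call to `free` (or to
// the generic free function when generic allocation functions are requested).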
struct DeallocOpLowering : public ConvertOpToLLVMPattern<memref::DeallocOp> {
  using ConvertOpToLLVMPattern<memref::DeallocOp>::ConvertOpToLLVMPattern;

  explicit DeallocOpLowering(LLVMTypeConverter &converter)
      : ConvertOpToLLVMPattern<memref::DeallocOp>(converter) {}

  LLVM::LLVMFuncOp getFreeFn(ModuleOp module) const {
    bool useGenericFn = getTypeConverter()->getOptions().useGenericFunctions;

    if (useGenericFn)
      return LLVM::lookupOrCreateGenericFreeFn(module);

    return LLVM::lookupOrCreateFreeFn(module);
  }

  LogicalResult
  matchAndRewrite(memref::DeallocOp op, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    // Insert the `free` declaration if it is not already present.
    auto freeFunc = getFreeFn(op->getParentOfType<ModuleOp>());
    MemRefDescriptor memref(adaptor.getMemref());
    Value casted = rewriter.create<LLVM::BitcastOp>(
        op.getLoc(), getVoidPtrType(),
        memref.allocatedPtr(rewriter, op.getLoc()));
    rewriter.replaceOpWithNewOp<LLVM::CallOp>(
        op, TypeRange(), SymbolRefAttr::get(freeFunc), casted);
    return success();
  }
};

// A `dim` is converted to a constant for static sizes and to an access to the
// size stored in the memref descriptor for dynamic sizes.
struct DimOpLowering : public ConvertOpToLLVMPattern<memref::DimOp> {
  using ConvertOpToLLVMPattern<memref::DimOp>::ConvertOpToLLVMPattern;

  LogicalResult
  matchAndRewrite(memref::DimOp dimOp, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    Type operandType = dimOp.getSource().getType();
    if (operandType.isa<UnrankedMemRefType>()) {
      rewriter.replaceOp(
          dimOp, {extractSizeOfUnrankedMemRef(
                     operandType, dimOp, adaptor.getOperands(), rewriter)});

      return success();
    }
    if (operandType.isa<MemRefType>()) {
      rewriter.replaceOp(
          dimOp, {extractSizeOfRankedMemRef(operandType, dimOp,
                                            adaptor.getOperands(), rewriter)});
      return success();
    }
    llvm_unreachable("expected MemRefType or UnrankedMemRefType");
  }

private:
  Value extractSizeOfUnrankedMemRef(Type operandType, memref::DimOp dimOp,
                                    OpAdaptor adaptor,
                                    ConversionPatternRewriter &rewriter) const {
    Location loc = dimOp.getLoc();

    auto unrankedMemRefType = operandType.cast<UnrankedMemRefType>();
    auto scalarMemRefType =
        MemRefType::get({}, unrankedMemRefType.getElementType());
    unsigned addressSpace = unrankedMemRefType.getMemorySpaceAsInt();

    // Extract pointer to the underlying ranked descriptor and bitcast it to a
    // memref<element_type> descriptor pointer to minimize the number of GEP
    // operations.
    UnrankedMemRefDescriptor unrankedDesc(adaptor.getSource());
    Value underlyingRankedDesc = unrankedDesc.memRefDescPtr(rewriter, loc);
    Value scalarMemRefDescPtr = rewriter.create<LLVM::BitcastOp>(
        loc,
        LLVM::LLVMPointerType::get(typeConverter->convertType(scalarMemRefType),
                                   addressSpace),
        underlyingRankedDesc);

    // Get pointer to offset field of memref<element_type> descriptor.
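    // The ranked descriptor is assumed to have the usual layout
    // {allocatedPtr, alignedPtr, offset, sizes[rank], strides[rank]}, so field
    // 2 holds the offset and the size of dimension `i` lives `i + 1` index
    // slots after it; the two GEPs below rely on exactly this layout.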
    Type indexPtrTy = LLVM::LLVMPointerType::get(
        getTypeConverter()->getIndexType(), addressSpace);
    Value two = rewriter.create<LLVM::ConstantOp>(
        loc, typeConverter->convertType(rewriter.getI32Type()),
        rewriter.getI32IntegerAttr(2));
    Value offsetPtr = rewriter.create<LLVM::GEPOp>(
        loc, indexPtrTy, scalarMemRefDescPtr,
        ValueRange({createIndexConstant(rewriter, loc, 0), two}));

    // The size value that we have to extract can be obtained using GEPOp with
    // `dimOp.index() + 1` as the index argument.
    Value idxPlusOne = rewriter.create<LLVM::AddOp>(
        loc, createIndexConstant(rewriter, loc, 1), adaptor.getIndex());
    Value sizePtr = rewriter.create<LLVM::GEPOp>(loc, indexPtrTy, offsetPtr,
                                                 ValueRange({idxPlusOne}));
    return rewriter.create<LLVM::LoadOp>(loc, sizePtr);
  }

  Optional<int64_t> getConstantDimIndex(memref::DimOp dimOp) const {
    if (Optional<int64_t> idx = dimOp.getConstantIndex())
      return idx;

    if (auto constantOp = dimOp.getIndex().getDefiningOp<LLVM::ConstantOp>())
      return constantOp.getValue()
          .cast<IntegerAttr>()
          .getValue()
          .getSExtValue();

    return llvm::None;
  }

  Value extractSizeOfRankedMemRef(Type operandType, memref::DimOp dimOp,
                                  OpAdaptor adaptor,
                                  ConversionPatternRewriter &rewriter) const {
    Location loc = dimOp.getLoc();

    // Take advantage of the constant index when possible.
    MemRefType memRefType = operandType.cast<MemRefType>();
    if (Optional<int64_t> index = getConstantDimIndex(dimOp)) {
      int64_t i = *index;
      if (memRefType.isDynamicDim(i)) {
        // Extract the dynamic size from the memref descriptor.
        MemRefDescriptor descriptor(adaptor.getSource());
        return descriptor.size(rewriter, loc, i);
      }
      // Use a constant for the static size.
      int64_t dimSize = memRefType.getDimSize(i);
      return createIndexConstant(rewriter, loc, dimSize);
    }
    Value index = adaptor.getIndex();
    int64_t rank = memRefType.getRank();
    MemRefDescriptor memrefDescriptor(adaptor.getSource());
    return memrefDescriptor.size(rewriter, loc, index, rank);
  }
};

/// Common base for load and store operations on MemRefs. Restricts the match
/// to supported MemRef types. Provides functionality to emit code accessing a
/// specific element of the underlying data buffer.
template <typename Derived>
struct LoadStoreOpLowering : public ConvertOpToLLVMPattern<Derived> {
  using ConvertOpToLLVMPattern<Derived>::ConvertOpToLLVMPattern;
  using ConvertOpToLLVMPattern<Derived>::isConvertibleAndHasIdentityMaps;
  using Base = LoadStoreOpLowering<Derived>;

  LogicalResult match(Derived op) const override {
    MemRefType type = op.getMemRefType();
    return isConvertibleAndHasIdentityMaps(type) ? success() : failure();
  }
};

/// Wrap a llvm.cmpxchg operation in a while loop so that the operation can be
/// retried until it succeeds in atomically storing a new value into memory.
///
///      +---------------------------------+
///      |   <code before the AtomicRMWOp> |
///      |   <compute initial %loaded>     |
///      |   cf.br loop(%loaded)           |
///      +---------------------------------+
///             |
///  -------|   |
///  |      v   v
///  |   +--------------------------------+
///  |   | loop(%loaded):                 |
///  |   |   <body contents>              |
///  |   |   %pair = cmpxchg              |
///  |   |   %ok = %pair[0]               |
///  |   |   %new = %pair[1]              |
///  |   |   cf.cond_br %ok, end, loop(%new) |
///  |   +--------------------------------+
///  |          |        |
///  |-----------        |
///                      v
///      +--------------------------------+
///      | end:                           |
///      |   <code after the AtomicRMWOp> |
///      +--------------------------------+
///
struct GenericAtomicRMWOpLowering
    : public LoadStoreOpLowering<memref::GenericAtomicRMWOp> {
  using Base::Base;

  LogicalResult
  matchAndRewrite(memref::GenericAtomicRMWOp atomicOp, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    auto loc = atomicOp.getLoc();
    Type valueType =
        typeConverter->convertType(atomicOp.getResult().getType());

    // Split the block into initial, loop, and ending parts.
    auto *initBlock = rewriter.getInsertionBlock();
    auto *loopBlock = rewriter.createBlock(
        initBlock->getParent(), std::next(Region::iterator(initBlock)),
        valueType, loc);
    auto *endBlock = rewriter.createBlock(
        loopBlock->getParent(), std::next(Region::iterator(loopBlock)));

    // Operations range to be moved to `endBlock`.
    auto opsToMoveStart = atomicOp->getIterator();
    auto opsToMoveEnd = initBlock->back().getIterator();

    // Compute the loaded value and branch to the loop block.
    rewriter.setInsertionPointToEnd(initBlock);
    auto memRefType = atomicOp.getMemref().getType().cast<MemRefType>();
    auto dataPtr = getStridedElementPtr(loc, memRefType, adaptor.getMemref(),
                                        adaptor.getIndices(), rewriter);
    Value init = rewriter.create<LLVM::LoadOp>(loc, dataPtr);
    rewriter.create<LLVM::BrOp>(loc, init, loopBlock);

    // Prepare the body of the loop block.
    rewriter.setInsertionPointToStart(loopBlock);

    // Clone the GenericAtomicRMWOp region and extract the result.
    auto loopArgument = loopBlock->getArgument(0);
    BlockAndValueMapping mapping;
    mapping.map(atomicOp.getCurrentValue(), loopArgument);
    Block &entryBlock = atomicOp.body().front();
    for (auto &nestedOp : entryBlock.without_terminator()) {
      Operation *clone = rewriter.clone(nestedOp, mapping);
      mapping.map(nestedOp.getResults(), clone->getResults());
    }
    Value result = mapping.lookup(entryBlock.getTerminator()->getOperand(0));

    // Prepare the epilog of the loop block.
    // Append the cmpxchg op to the end of the loop block.
    auto successOrdering = LLVM::AtomicOrdering::acq_rel;
    auto failureOrdering = LLVM::AtomicOrdering::monotonic;
    auto boolType = IntegerType::get(rewriter.getContext(), 1);
    auto pairType = LLVM::LLVMStructType::getLiteral(rewriter.getContext(),
                                                     {valueType, boolType});
    auto cmpxchg = rewriter.create<LLVM::AtomicCmpXchgOp>(
        loc, pairType, dataPtr, loopArgument, result, successOrdering,
        failureOrdering);
    // Extract the %new_loaded and %ok values from the pair.
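    // The cmpxchg above yields a {valueType, i1} pair: the value actually read
    // from memory and a flag that is true iff the exchange succeeded.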
    Value newLoaded = rewriter.create<LLVM::ExtractValueOp>(
        loc, valueType, cmpxchg, rewriter.getI64ArrayAttr({0}));
    Value ok = rewriter.create<LLVM::ExtractValueOp>(
        loc, boolType, cmpxchg, rewriter.getI64ArrayAttr({1}));

    // Conditionally branch to the end or back to the loop depending on %ok.
    rewriter.create<LLVM::CondBrOp>(loc, ok, endBlock, ArrayRef<Value>(),
                                    loopBlock, newLoaded);

    rewriter.setInsertionPointToEnd(endBlock);
    moveOpsRange(atomicOp.getResult(), newLoaded, std::next(opsToMoveStart),
                 std::next(opsToMoveEnd), rewriter);

    // The 'result' of the atomic_rmw op is the newly loaded value.
    rewriter.replaceOp(atomicOp, {newLoaded});

    return success();
  }

private:
  // Clones a segment of ops [start, end) and erases the original.
  void moveOpsRange(ValueRange oldResult, ValueRange newResult,
                    Block::iterator start, Block::iterator end,
                    ConversionPatternRewriter &rewriter) const {
    BlockAndValueMapping mapping;
    mapping.map(oldResult, newResult);
    SmallVector<Operation *, 2> opsToErase;
    for (auto it = start; it != end; ++it) {
      rewriter.clone(*it, mapping);
      opsToErase.push_back(&*it);
    }
    for (auto *it : opsToErase)
      rewriter.eraseOp(it);
  }
};

/// Returns the LLVM type of the global variable given the memref type `type`.
static Type convertGlobalMemrefTypeToLLVM(MemRefType type,
                                          LLVMTypeConverter &typeConverter) {
  // The LLVM type for a global memref will be a multi-dimensional array. For
  // declarations or uninitialized global memrefs, we can potentially flatten
  // this to a 1-D array. However, for memref.global's with an initial value,
  // we do not intend to flatten the ElementsAttribute when going from the
  // memref dialect to the LLVM dialect, so the LLVM type needs to be a
  // multi-dimensional array.
  Type elementType = typeConverter.convertType(type.getElementType());
  Type arrayTy = elementType;
  // Shape has the outermost dim at index 0, so we need to walk it backwards.
  for (int64_t dim : llvm::reverse(type.getShape()))
    arrayTy = LLVM::LLVMArrayType::get(arrayTy, dim);
  return arrayTy;
}

/// GlobalMemrefOp is lowered to an LLVM global variable.
struct GlobalMemrefOpLowering
    : public ConvertOpToLLVMPattern<memref::GlobalOp> {
  using ConvertOpToLLVMPattern<memref::GlobalOp>::ConvertOpToLLVMPattern;

  LogicalResult
  matchAndRewrite(memref::GlobalOp global, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    MemRefType type = global.getType();
    if (!isConvertibleAndHasIdentityMaps(type))
      return failure();

    Type arrayTy = convertGlobalMemrefTypeToLLVM(type, *getTypeConverter());

    LLVM::Linkage linkage =
        global.isPublic() ? LLVM::Linkage::External : LLVM::Linkage::Private;

    Attribute initialValue = nullptr;
    if (!global.isExternal() && !global.isUninitialized()) {
      auto elementsAttr = global.getInitialValue()->cast<ElementsAttr>();
      initialValue = elementsAttr;

      // For scalar memrefs, the global variable created is of the element
      // type, so unpack the elements attribute to extract the value.
      if (type.getRank() == 0)
        initialValue = elementsAttr.getSplatValue<Attribute>();
    }

    uint64_t alignment = global.getAlignment().value_or(0);

    auto newGlobal = rewriter.replaceOpWithNewOp<LLVM::GlobalOp>(
        global, arrayTy, global.getConstant(), linkage, global.getSymName(),
        initialValue, alignment, type.getMemorySpaceAsInt());
    if (!global.isExternal() && global.isUninitialized()) {
      Block *blk = new Block();
      newGlobal.getInitializerRegion().push_back(blk);
      rewriter.setInsertionPointToStart(blk);
      Value undef[] = {
          rewriter.create<LLVM::UndefOp>(global.getLoc(), arrayTy)};
      rewriter.create<LLVM::ReturnOp>(global.getLoc(), undef);
    }
    return success();
  }
};

/// GetGlobalMemrefOp is lowered into a MemRef descriptor with the pointer to
/// the first element stashed into the descriptor. This derives from
/// `AllocLikeOpLLVMLowering` so that the MemRef descriptor construction is
/// shared.
struct GetGlobalMemrefOpLowering : public AllocLikeOpLLVMLowering {
  GetGlobalMemrefOpLowering(LLVMTypeConverter &converter)
      : AllocLikeOpLLVMLowering(memref::GetGlobalOp::getOperationName(),
                                converter) {}

  /// Buffer "allocation" for the memref.get_global op is getting the address
  /// of the referenced global variable.
  std::tuple<Value, Value> allocateBuffer(ConversionPatternRewriter &rewriter,
                                          Location loc, Value sizeBytes,
                                          Operation *op) const override {
    auto getGlobalOp = cast<memref::GetGlobalOp>(op);
    MemRefType type = getGlobalOp.getResult().getType().cast<MemRefType>();
    unsigned memSpace = type.getMemorySpaceAsInt();

    Type arrayTy = convertGlobalMemrefTypeToLLVM(type, *getTypeConverter());
    auto addressOf = rewriter.create<LLVM::AddressOfOp>(
        loc, LLVM::LLVMPointerType::get(arrayTy, memSpace),
        getGlobalOp.getName());

    // Get the address of the first element in the array by creating a GEP with
    // the address of the GV as the base, and (rank + 1) number of 0 indices.
    Type elementType = typeConverter->convertType(type.getElementType());
    Type elementPtrType = LLVM::LLVMPointerType::get(elementType, memSpace);

    SmallVector<Value> operands;
    operands.insert(operands.end(), type.getRank() + 1,
                    createIndexConstant(rewriter, loc, 0));
    auto gep =
        rewriter.create<LLVM::GEPOp>(loc, elementPtrType, addressOf, operands);

    // We do not expect the memref obtained using `memref.get_global` to be
    // ever deallocated. Set the allocated pointer to a known bad value to
    // help debug if that ever happens.
    auto intPtrType = getIntPtrType(memSpace);
    Value deadBeefConst =
        createIndexAttrConstant(rewriter, op->getLoc(), intPtrType, 0xdeadbeef);
    auto deadBeefPtr =
        rewriter.create<LLVM::IntToPtrOp>(loc, elementPtrType, deadBeefConst);

    // Both allocated and aligned pointers are the same. We could potentially
    // stash a nullptr for the allocated pointer since we do not expect any
    // dealloc.
    return std::make_tuple(deadBeefPtr, gep);
  }
};

// Load operation is lowered to obtaining a pointer to the indexed element
// and loading it.
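// For instance (illustrative sketch only, assuming a 1-D dynamically sized
// memref), `memref.load %m[%i] : memref<?xf32>` turns into a GEP off the
// descriptor's aligned pointer at the linearized index computed by
// `getStridedElementPtr`, followed by an `llvm.load` of the element.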
struct LoadOpLowering : public LoadStoreOpLowering<memref::LoadOp> {
  using Base::Base;

  LogicalResult
  matchAndRewrite(memref::LoadOp loadOp, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    auto type = loadOp.getMemRefType();

    Value dataPtr =
        getStridedElementPtr(loadOp.getLoc(), type, adaptor.getMemref(),
                             adaptor.getIndices(), rewriter);
    rewriter.replaceOpWithNewOp<LLVM::LoadOp>(loadOp, dataPtr);
    return success();
  }
};

// Store operation is lowered to obtaining a pointer to the indexed element,
// and storing the given value to it.
struct StoreOpLowering : public LoadStoreOpLowering<memref::StoreOp> {
  using Base::Base;

  LogicalResult
  matchAndRewrite(memref::StoreOp op, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    auto type = op.getMemRefType();

    Value dataPtr = getStridedElementPtr(op.getLoc(), type, adaptor.getMemref(),
                                         adaptor.getIndices(), rewriter);
    rewriter.replaceOpWithNewOp<LLVM::StoreOp>(op, adaptor.getValue(), dataPtr);
    return success();
  }
};

// The prefetch operation is lowered in a way similar to the load operation
// except that the llvm.prefetch operation is used for replacement.
struct PrefetchOpLowering : public LoadStoreOpLowering<memref::PrefetchOp> {
  using Base::Base;

  LogicalResult
  matchAndRewrite(memref::PrefetchOp prefetchOp, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    auto type = prefetchOp.getMemRefType();
    auto loc = prefetchOp.getLoc();

    Value dataPtr = getStridedElementPtr(loc, type, adaptor.getMemref(),
                                         adaptor.getIndices(), rewriter);

    // Replace with llvm.prefetch.
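    // llvm.prefetch expects, besides the address, three i32 arguments: whether
    // the access is a write, the locality hint (0-3), and whether the data or
    // the instruction cache is targeted. They are materialized below as
    // constants from the corresponding op attributes.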
    auto llvmI32Type = typeConverter->convertType(rewriter.getIntegerType(32));
    auto isWrite = rewriter.create<LLVM::ConstantOp>(
        loc, llvmI32Type, rewriter.getI32IntegerAttr(prefetchOp.getIsWrite()));
    auto localityHint = rewriter.create<LLVM::ConstantOp>(
        loc, llvmI32Type,
        rewriter.getI32IntegerAttr(prefetchOp.getLocalityHint()));
    auto isData = rewriter.create<LLVM::ConstantOp>(
        loc, llvmI32Type,
        rewriter.getI32IntegerAttr(prefetchOp.getIsDataCache()));

    rewriter.replaceOpWithNewOp<LLVM::Prefetch>(prefetchOp, dataPtr, isWrite,
                                                localityHint, isData);
    return success();
  }
};

struct RankOpLowering : public ConvertOpToLLVMPattern<memref::RankOp> {
  using ConvertOpToLLVMPattern<memref::RankOp>::ConvertOpToLLVMPattern;

  LogicalResult
  matchAndRewrite(memref::RankOp op, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    Location loc = op.getLoc();
    Type operandType = op.getMemref().getType();
    if (auto unrankedMemRefType = operandType.dyn_cast<UnrankedMemRefType>()) {
      UnrankedMemRefDescriptor desc(adaptor.getMemref());
      rewriter.replaceOp(op, {desc.rank(rewriter, loc)});
      return success();
    }
    if (auto rankedMemRefType = operandType.dyn_cast<MemRefType>()) {
      rewriter.replaceOp(
          op, {createIndexConstant(rewriter, loc, rankedMemRefType.getRank())});
      return success();
    }
    return failure();
  }
};

struct MemRefCastOpLowering : public ConvertOpToLLVMPattern<memref::CastOp> {
  using ConvertOpToLLVMPattern<memref::CastOp>::ConvertOpToLLVMPattern;

  LogicalResult match(memref::CastOp memRefCastOp) const override {
    Type srcType = memRefCastOp.getOperand().getType();
    Type dstType = memRefCastOp.getType();

    // memref::CastOp reduces to a bitcast in the ranked MemRef case and can be
    // used for type erasure. For now it must preserve the underlying element
    // type and requires the source and result types to have the same rank.
    // Therefore, perform a sanity check that the underlying structs are the
    // same. Once op semantics are relaxed we can revisit.
    if (srcType.isa<MemRefType>() && dstType.isa<MemRefType>())
      return success(typeConverter->convertType(srcType) ==
                     typeConverter->convertType(dstType));

    // At least one of the operands is of unranked memref type.
    assert(srcType.isa<UnrankedMemRefType>() ||
           dstType.isa<UnrankedMemRefType>());

    // Unranked to unranked casts are disallowed.
    return !(srcType.isa<UnrankedMemRefType>() &&
             dstType.isa<UnrankedMemRefType>())
               ? success()
               : failure();
  }

  void rewrite(memref::CastOp memRefCastOp, OpAdaptor adaptor,
               ConversionPatternRewriter &rewriter) const override {
    auto srcType = memRefCastOp.getOperand().getType();
    auto dstType = memRefCastOp.getType();
    auto targetStructType = typeConverter->convertType(memRefCastOp.getType());
    auto loc = memRefCastOp.getLoc();

    // For the ranked/ranked case, just keep the original descriptor.
    if (srcType.isa<MemRefType>() && dstType.isa<MemRefType>())
      return rewriter.replaceOp(memRefCastOp, {adaptor.getSource()});

    if (srcType.isa<MemRefType>() && dstType.isa<UnrankedMemRefType>()) {
      // Casting ranked to unranked memref type:
      // Set the rank in the destination from the memref type.
      // Allocate space on the stack and copy the src memref descriptor.
      // Set the ptr in the destination to the stack space.
      auto srcMemRefType = srcType.cast<MemRefType>();
      int64_t rank = srcMemRefType.getRank();
      // ptr = AllocaOp sizeof(MemRefDescriptor)
      auto ptr = getTypeConverter()->promoteOneMemRefDescriptor(
          loc, adaptor.getSource(), rewriter);
      // voidptr = BitCastOp srcType* to void*
      auto voidPtr =
          rewriter.create<LLVM::BitcastOp>(loc, getVoidPtrType(), ptr)
              .getResult();
      // rank = ConstantOp srcRank
      auto rankVal = rewriter.create<LLVM::ConstantOp>(
          loc, getIndexType(), rewriter.getIndexAttr(rank));
      // undef = UndefOp
      UnrankedMemRefDescriptor memRefDesc =
          UnrankedMemRefDescriptor::undef(rewriter, loc, targetStructType);
      // d1 = InsertValueOp undef, rank, 0
      memRefDesc.setRank(rewriter, loc, rankVal);
      // d2 = InsertValueOp d1, voidptr, 1
      memRefDesc.setMemRefDescPtr(rewriter, loc, voidPtr);
      rewriter.replaceOp(memRefCastOp, (Value)memRefDesc);

    } else if (srcType.isa<UnrankedMemRefType>() && dstType.isa<MemRefType>()) {
      // Casting from unranked type to ranked.
      // The operation is assumed to be doing a correct cast. If the
      // destination type mismatches the unranked type, it is undefined
      // behavior.
      UnrankedMemRefDescriptor memRefDesc(adaptor.getSource());
      // ptr = ExtractValueOp src, 1
      auto ptr = memRefDesc.memRefDescPtr(rewriter, loc);
      // castPtr = BitCastOp i8* to structTy*
      auto castPtr =
          rewriter
              .create<LLVM::BitcastOp>(
                  loc, LLVM::LLVMPointerType::get(targetStructType), ptr)
              .getResult();
      // struct = LoadOp castPtr
      auto loadOp = rewriter.create<LLVM::LoadOp>(loc, castPtr);
      rewriter.replaceOp(memRefCastOp, loadOp.getResult());
    } else {
      llvm_unreachable("Unsupported unranked memref to unranked memref cast");
    }
  }
};

/// Pattern to lower a `memref.copy` to llvm.
///
/// For memrefs with identity layouts, the copy is lowered to the llvm
/// `memcpy` intrinsic. For non-identity layouts, the copy is lowered to a call
/// to the generic `MemrefCopyFn`.
struct MemRefCopyOpLowering : public ConvertOpToLLVMPattern<memref::CopyOp> {
  using ConvertOpToLLVMPattern<memref::CopyOp>::ConvertOpToLLVMPattern;

  LogicalResult
  lowerToMemCopyIntrinsic(memref::CopyOp op, OpAdaptor adaptor,
                          ConversionPatternRewriter &rewriter) const {
    auto loc = op.getLoc();
    auto srcType = op.getSource().getType().dyn_cast<MemRefType>();

    MemRefDescriptor srcDesc(adaptor.getSource());

    // Compute number of elements.
    Value numElements = rewriter.create<LLVM::ConstantOp>(
        loc, getIndexType(), rewriter.getIndexAttr(1));
    for (int pos = 0; pos < srcType.getRank(); ++pos) {
      auto size = srcDesc.size(rewriter, loc, pos);
      numElements = rewriter.create<LLVM::MulOp>(loc, numElements, size);
    }

    // Get element size.
    auto sizeInBytes = getSizeInBytes(loc, srcType.getElementType(), rewriter);
    // Compute total.
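    // Total copy size in bytes = number of elements * element size; both
    // factors are index-typed values computed above.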
    Value totalSize =
        rewriter.create<LLVM::MulOp>(loc, numElements, sizeInBytes);

    Value srcBasePtr = srcDesc.alignedPtr(rewriter, loc);
    Value srcOffset = srcDesc.offset(rewriter, loc);
    Value srcPtr = rewriter.create<LLVM::GEPOp>(loc, srcBasePtr.getType(),
                                                srcBasePtr, srcOffset);
    MemRefDescriptor targetDesc(adaptor.getTarget());
    Value targetBasePtr = targetDesc.alignedPtr(rewriter, loc);
    Value targetOffset = targetDesc.offset(rewriter, loc);
    Value targetPtr = rewriter.create<LLVM::GEPOp>(loc, targetBasePtr.getType(),
                                                   targetBasePtr, targetOffset);
    Value isVolatile = rewriter.create<LLVM::ConstantOp>(
        loc, typeConverter->convertType(rewriter.getI1Type()),
        rewriter.getBoolAttr(false));
    rewriter.create<LLVM::MemcpyOp>(loc, targetPtr, srcPtr, totalSize,
                                    isVolatile);
    rewriter.eraseOp(op);

    return success();
  }

  LogicalResult
  lowerToMemCopyFunctionCall(memref::CopyOp op, OpAdaptor adaptor,
                             ConversionPatternRewriter &rewriter) const {
    auto loc = op.getLoc();
    auto srcType = op.getSource().getType().cast<BaseMemRefType>();
    auto targetType = op.getTarget().getType().cast<BaseMemRefType>();

    // First make sure we have an unranked memref descriptor representation.
    auto makeUnranked = [&, this](Value ranked, BaseMemRefType type) {
      auto rank = rewriter.create<LLVM::ConstantOp>(
          loc, getIndexType(), rewriter.getIndexAttr(type.getRank()));
      auto *typeConverter = getTypeConverter();
      auto ptr =
          typeConverter->promoteOneMemRefDescriptor(loc, ranked, rewriter);
      auto voidPtr =
          rewriter.create<LLVM::BitcastOp>(loc, getVoidPtrType(), ptr)
              .getResult();
      auto unrankedType =
          UnrankedMemRefType::get(type.getElementType(), type.getMemorySpace());
      return UnrankedMemRefDescriptor::pack(rewriter, loc, *typeConverter,
                                            unrankedType,
                                            ValueRange{rank, voidPtr});
    };

    Value unrankedSource = srcType.hasRank()
                               ? makeUnranked(adaptor.getSource(), srcType)
                               : adaptor.getSource();
    Value unrankedTarget = targetType.hasRank()
                               ? makeUnranked(adaptor.getTarget(), targetType)
                               : adaptor.getTarget();

    // Now promote the unranked descriptors to the stack.
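    // The memrefCopy runtime function takes the descriptors by pointer, so
    // each one is stored into a fresh single-element llvm.alloca whose address
    // is then passed to the call.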
    auto one = rewriter.create<LLVM::ConstantOp>(loc, getIndexType(),
                                                 rewriter.getIndexAttr(1));
    auto promote = [&](Value desc) {
      auto ptrType = LLVM::LLVMPointerType::get(desc.getType());
      auto allocated =
          rewriter.create<LLVM::AllocaOp>(loc, ptrType, ValueRange{one});
      rewriter.create<LLVM::StoreOp>(loc, desc, allocated);
      return allocated;
    };

    auto sourcePtr = promote(unrankedSource);
    auto targetPtr = promote(unrankedTarget);

    unsigned typeSize =
        mlir::DataLayout::closest(op).getTypeSize(srcType.getElementType());
    auto elemSize = rewriter.create<LLVM::ConstantOp>(
        loc, getIndexType(), rewriter.getIndexAttr(typeSize));
    auto copyFn = LLVM::lookupOrCreateMemRefCopyFn(
        op->getParentOfType<ModuleOp>(), getIndexType(), sourcePtr.getType());
    rewriter.create<LLVM::CallOp>(loc, copyFn,
                                  ValueRange{elemSize, sourcePtr, targetPtr});
    rewriter.eraseOp(op);

    return success();
  }

  LogicalResult
  matchAndRewrite(memref::CopyOp op, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    auto srcType = op.getSource().getType().cast<BaseMemRefType>();
    auto targetType = op.getTarget().getType().cast<BaseMemRefType>();

    auto isContiguousMemrefType = [](BaseMemRefType type) {
      auto memrefType = type.dyn_cast<mlir::MemRefType>();
      // We can use memcpy for memrefs if they have an identity layout or are
      // contiguous with an arbitrary offset. Ignore empty memrefs, which are a
      // special case handled by memrefCopy.
      return memrefType &&
             (memrefType.getLayout().isIdentity() ||
              (memrefType.hasStaticShape() && memrefType.getNumElements() > 0 &&
               isStaticShapeAndContiguousRowMajor(memrefType)));
    };

    if (isContiguousMemrefType(srcType) && isContiguousMemrefType(targetType))
      return lowerToMemCopyIntrinsic(op, adaptor, rewriter);

    return lowerToMemCopyFunctionCall(op, adaptor, rewriter);
  }
};

/// Extracts the allocated and aligned pointers and the offset from a ranked or
/// unranked memref type. In the unranked case, the fields are extracted from
/// the underlying ranked descriptor.
static void extractPointersAndOffset(Location loc,
                                     ConversionPatternRewriter &rewriter,
                                     LLVMTypeConverter &typeConverter,
                                     Value originalOperand,
                                     Value convertedOperand,
                                     Value *allocatedPtr, Value *alignedPtr,
                                     Value *offset = nullptr) {
  Type operandType = originalOperand.getType();
  if (operandType.isa<MemRefType>()) {
    MemRefDescriptor desc(convertedOperand);
    *allocatedPtr = desc.allocatedPtr(rewriter, loc);
    *alignedPtr = desc.alignedPtr(rewriter, loc);
    if (offset != nullptr)
      *offset = desc.offset(rewriter, loc);
    return;
  }

  unsigned memorySpace =
      operandType.cast<UnrankedMemRefType>().getMemorySpaceAsInt();
  Type elementType = operandType.cast<UnrankedMemRefType>().getElementType();
  Type llvmElementType = typeConverter.convertType(elementType);
  Type elementPtrPtrType = LLVM::LLVMPointerType::get(
      LLVM::LLVMPointerType::get(llvmElementType, memorySpace));

  // Extract pointer to the underlying ranked memref descriptor and cast it to
  // ElemType**.
  UnrankedMemRefDescriptor unrankedDesc(convertedOperand);
  Value underlyingDescPtr = unrankedDesc.memRefDescPtr(rewriter, loc);

  *allocatedPtr = UnrankedMemRefDescriptor::allocatedPtr(
      rewriter, loc, underlyingDescPtr, elementPtrPtrType);
  *alignedPtr = UnrankedMemRefDescriptor::alignedPtr(
      rewriter, loc, typeConverter, underlyingDescPtr, elementPtrPtrType);
  if (offset != nullptr) {
    *offset = UnrankedMemRefDescriptor::offset(
        rewriter, loc, typeConverter, underlyingDescPtr, elementPtrPtrType);
  }
}

struct MemRefReinterpretCastOpLowering
    : public ConvertOpToLLVMPattern<memref::ReinterpretCastOp> {
  using ConvertOpToLLVMPattern<
      memref::ReinterpretCastOp>::ConvertOpToLLVMPattern;

  LogicalResult
  matchAndRewrite(memref::ReinterpretCastOp castOp, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    Type srcType = castOp.getSource().getType();

    Value descriptor;
    if (failed(convertSourceMemRefToDescriptor(rewriter, srcType, castOp,
                                               adaptor, &descriptor)))
      return failure();
    rewriter.replaceOp(castOp, {descriptor});
    return success();
  }

private:
  LogicalResult convertSourceMemRefToDescriptor(
      ConversionPatternRewriter &rewriter, Type srcType,
      memref::ReinterpretCastOp castOp,
      memref::ReinterpretCastOp::Adaptor adaptor, Value *descriptor) const {
    MemRefType targetMemRefType =
        castOp.getResult().getType().cast<MemRefType>();
    auto llvmTargetDescriptorTy = typeConverter->convertType(targetMemRefType)
                                      .dyn_cast_or_null<LLVM::LLVMStructType>();
    if (!llvmTargetDescriptorTy)
      return failure();

    // Create descriptor.
    Location loc = castOp.getLoc();
    auto desc = MemRefDescriptor::undef(rewriter, loc, llvmTargetDescriptorTy);

    // Set allocated and aligned pointers.
    Value allocatedPtr, alignedPtr;
    extractPointersAndOffset(loc, rewriter, *getTypeConverter(),
                             castOp.getSource(), adaptor.getSource(),
                             &allocatedPtr, &alignedPtr);
    desc.setAllocatedPtr(rewriter, loc, allocatedPtr);
    desc.setAlignedPtr(rewriter, loc, alignedPtr);

    // Set offset.
    if (castOp.isDynamicOffset(0))
      desc.setOffset(rewriter, loc, adaptor.getOffsets()[0]);
    else
      desc.setConstantOffset(rewriter, loc, castOp.getStaticOffset(0));

    // Set sizes and strides.
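    // Static sizes and strides from the op are written into the descriptor as
    // constants; dynamic ones consume the corresponding SSA operands in the
    // order they appear.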
    unsigned dynSizeId = 0;
    unsigned dynStrideId = 0;
    for (unsigned i = 0, e = targetMemRefType.getRank(); i < e; ++i) {
      if (castOp.isDynamicSize(i))
        desc.setSize(rewriter, loc, i, adaptor.getSizes()[dynSizeId++]);
      else
        desc.setConstantSize(rewriter, loc, i, castOp.getStaticSize(i));

      if (castOp.isDynamicStride(i))
        desc.setStride(rewriter, loc, i, adaptor.getStrides()[dynStrideId++]);
      else
        desc.setConstantStride(rewriter, loc, i, castOp.getStaticStride(i));
    }
    *descriptor = desc;
    return success();
  }
};

struct MemRefReshapeOpLowering
    : public ConvertOpToLLVMPattern<memref::ReshapeOp> {
  using ConvertOpToLLVMPattern<memref::ReshapeOp>::ConvertOpToLLVMPattern;

  LogicalResult
  matchAndRewrite(memref::ReshapeOp reshapeOp, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    Type srcType = reshapeOp.getSource().getType();

    Value descriptor;
    if (failed(convertSourceMemRefToDescriptor(rewriter, srcType, reshapeOp,
                                               adaptor, &descriptor)))
      return failure();
    rewriter.replaceOp(reshapeOp, {descriptor});
    return success();
  }

private:
  LogicalResult
  convertSourceMemRefToDescriptor(ConversionPatternRewriter &rewriter,
                                  Type srcType, memref::ReshapeOp reshapeOp,
                                  memref::ReshapeOp::Adaptor adaptor,
                                  Value *descriptor) const {
    auto shapeMemRefType = reshapeOp.getShape().getType().cast<MemRefType>();
    if (shapeMemRefType.hasStaticShape()) {
      MemRefType targetMemRefType =
          reshapeOp.getResult().getType().cast<MemRefType>();
      auto llvmTargetDescriptorTy =
          typeConverter->convertType(targetMemRefType)
              .dyn_cast_or_null<LLVM::LLVMStructType>();
      if (!llvmTargetDescriptorTy)
        return failure();

      // Create descriptor.
      Location loc = reshapeOp.getLoc();
      auto desc =
          MemRefDescriptor::undef(rewriter, loc, llvmTargetDescriptorTy);

      // Set allocated and aligned pointers.
      Value allocatedPtr, alignedPtr;
      extractPointersAndOffset(loc, rewriter, *getTypeConverter(),
                               reshapeOp.getSource(), adaptor.getSource(),
                               &allocatedPtr, &alignedPtr);
      desc.setAllocatedPtr(rewriter, loc, allocatedPtr);
      desc.setAlignedPtr(rewriter, loc, alignedPtr);

      // Extract the offset and strides from the type.
      int64_t offset;
      SmallVector<int64_t> strides;
      if (failed(getStridesAndOffset(targetMemRefType, strides, offset)))
        return rewriter.notifyMatchFailure(
            reshapeOp, "failed to get stride and offset exprs");

      if (!isStaticStrideOrOffset(offset))
        return rewriter.notifyMatchFailure(reshapeOp,
                                           "dynamic offset is unsupported");

      desc.setConstantOffset(rewriter, loc, offset);

      assert(targetMemRefType.getLayout().isIdentity() &&
             "Identity layout map is a precondition of a valid reshape op");

      Value stride = nullptr;
      int64_t targetRank = targetMemRefType.getRank();
      for (auto i : llvm::reverse(llvm::seq<int64_t>(0, targetRank))) {
        if (!ShapedType::isDynamicStrideOrOffset(strides[i])) {
          // If the stride for this dimension is static, materialize it as a
          // constant; dynamic strides are instead accumulated below as the
          // running product of the inner dimension sizes.
          stride = createIndexConstant(rewriter, loc, strides[i]);
        } else if (!stride) {
          // `stride` is null only in the first iteration of the loop. However,
          // since the target memref has an identity layout, we can safely set
          // the innermost stride to 1.
          stride = createIndexConstant(rewriter, loc, 1);
        }

        Value dimSize;
        int64_t size = targetMemRefType.getDimSize(i);
        // If the size of this dimension is dynamic, then load it at runtime
        // from the shape operand.
        if (!ShapedType::isDynamic(size)) {
          dimSize = createIndexConstant(rewriter, loc, size);
        } else {
          Value shapeOp = reshapeOp.getShape();
          Value index = createIndexConstant(rewriter, loc, i);
          dimSize = rewriter.create<memref::LoadOp>(loc, shapeOp, index);
          Type indexType = getIndexType();
          if (dimSize.getType() != indexType)
            dimSize = typeConverter->materializeTargetConversion(
                rewriter, loc, indexType, dimSize);
          assert(dimSize && "Invalid memref element type");
        }

        desc.setSize(rewriter, loc, i, dimSize);
        desc.setStride(rewriter, loc, i, stride);

        // Prepare the stride value for the next dimension.
        stride = rewriter.create<LLVM::MulOp>(loc, stride, dimSize);
      }

      *descriptor = desc;
      return success();
    }

    // The shape is a rank-1 memref with unknown length.
    Location loc = reshapeOp.getLoc();
    MemRefDescriptor shapeDesc(adaptor.getShape());
    Value resultRank = shapeDesc.size(rewriter, loc, 0);

    // Extract address space and element type.
    auto targetType =
        reshapeOp.getResult().getType().cast<UnrankedMemRefType>();
    unsigned addressSpace = targetType.getMemorySpaceAsInt();
    Type elementType = targetType.getElementType();

    // Create the unranked memref descriptor that holds the ranked one. The
    // inner descriptor is allocated on stack.
    auto targetDesc = UnrankedMemRefDescriptor::undef(
        rewriter, loc, typeConverter->convertType(targetType));
    targetDesc.setRank(rewriter, loc, resultRank);
    SmallVector<Value, 4> sizes;
    UnrankedMemRefDescriptor::computeSizes(rewriter, loc, *getTypeConverter(),
                                           targetDesc, sizes);
    Value underlyingDescPtr = rewriter.create<LLVM::AllocaOp>(
        loc, getVoidPtrType(), sizes.front(), llvm::None);
    targetDesc.setMemRefDescPtr(rewriter, loc, underlyingDescPtr);

    // Extract pointers and offset from the source memref.
    Value allocatedPtr, alignedPtr, offset;
    extractPointersAndOffset(loc, rewriter, *getTypeConverter(),
                             reshapeOp.getSource(), adaptor.getSource(),
                             &allocatedPtr, &alignedPtr, &offset);

    // Set pointers and offset.
    Type llvmElementType = typeConverter->convertType(elementType);
    auto elementPtrPtrType = LLVM::LLVMPointerType::get(
        LLVM::LLVMPointerType::get(llvmElementType, addressSpace));
    UnrankedMemRefDescriptor::setAllocatedPtr(rewriter, loc, underlyingDescPtr,
                                              elementPtrPtrType, allocatedPtr);
    UnrankedMemRefDescriptor::setAlignedPtr(rewriter, loc, *getTypeConverter(),
                                            underlyingDescPtr,
                                            elementPtrPtrType, alignedPtr);
    UnrankedMemRefDescriptor::setOffset(rewriter, loc, *getTypeConverter(),
                                        underlyingDescPtr, elementPtrPtrType,
                                        offset);

    // Use the offset pointer as base for further addressing. Copy over the new
    // shape and compute strides. For this, we create a loop from rank-1 to 0.
    Value targetSizesBase = UnrankedMemRefDescriptor::sizeBasePtr(
        rewriter, loc, *getTypeConverter(), underlyingDescPtr,
        elementPtrPtrType);
    Value targetStridesBase = UnrankedMemRefDescriptor::strideBasePtr(
        rewriter, loc, *getTypeConverter(), targetSizesBase, resultRank);
    Value shapeOperandPtr = shapeDesc.alignedPtr(rewriter, loc);
    Value oneIndex = createIndexConstant(rewriter, loc, 1);
    Value resultRankMinusOne =
        rewriter.create<LLVM::SubOp>(loc, resultRank, oneIndex);

    Block *initBlock = rewriter.getInsertionBlock();
    Type indexType = getTypeConverter()->getIndexType();
    Block::iterator remainingOpsIt = std::next(rewriter.getInsertionPoint());

    Block *condBlock = rewriter.createBlock(initBlock->getParent(), {},
                                            {indexType, indexType}, {loc, loc});

    // Move the remaining initBlock ops to condBlock.
    Block *remainingBlock = rewriter.splitBlock(initBlock, remainingOpsIt);
    rewriter.mergeBlocks(remainingBlock, condBlock, ValueRange());

    rewriter.setInsertionPointToEnd(initBlock);
    rewriter.create<LLVM::BrOp>(loc, ValueRange({resultRankMinusOne, oneIndex}),
                                condBlock);
    rewriter.setInsertionPointToStart(condBlock);
    Value indexArg = condBlock->getArgument(0);
    Value strideArg = condBlock->getArgument(1);

    Value zeroIndex = createIndexConstant(rewriter, loc, 0);
    Value pred = rewriter.create<LLVM::ICmpOp>(
        loc, IntegerType::get(rewriter.getContext(), 1),
        LLVM::ICmpPredicate::sge, indexArg, zeroIndex);

    Block *bodyBlock =
        rewriter.splitBlock(condBlock, rewriter.getInsertionPoint());
    rewriter.setInsertionPointToStart(bodyBlock);

    // Copy size from shape to descriptor.
    Type llvmIndexPtrType = LLVM::LLVMPointerType::get(indexType);
    Value sizeLoadGep = rewriter.create<LLVM::GEPOp>(
        loc, llvmIndexPtrType, shapeOperandPtr, ValueRange{indexArg});
    Value size = rewriter.create<LLVM::LoadOp>(loc, sizeLoadGep);
    UnrankedMemRefDescriptor::setSize(rewriter, loc, *getTypeConverter(),
                                      targetSizesBase, indexArg, size);

    // Write stride value and compute next one.
    UnrankedMemRefDescriptor::setStride(rewriter, loc, *getTypeConverter(),
                                        targetStridesBase, indexArg, strideArg);
    Value nextStride = rewriter.create<LLVM::MulOp>(loc, strideArg, size);

    // Decrement loop counter and branch back.
    Value decrement = rewriter.create<LLVM::SubOp>(loc, indexArg, oneIndex);
    rewriter.create<LLVM::BrOp>(loc, ValueRange({decrement, nextStride}),
                                condBlock);

    Block *remainder =
        rewriter.splitBlock(bodyBlock, rewriter.getInsertionPoint());

    // Hook up the cond exit to the remainder.
    rewriter.setInsertionPointToEnd(condBlock);
    rewriter.create<LLVM::CondBrOp>(loc, pred, bodyBlock, llvm::None, remainder,
                                    llvm::None);

    // Reset position to beginning of new remainder block.
    rewriter.setInsertionPointToStart(remainder);

    *descriptor = targetDesc;
    return success();
  }
};

/// Helper function to convert a vector of `OpFoldResult`s into a vector of
/// `Value`s.
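/// Attribute entries are materialized as `LLVM::ConstantOp`s of
/// `llvmIndexType`; entries that are already values are passed through
/// unchanged.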
static SmallVector<Value> getAsValues(OpBuilder &b, Location loc,
                                      Type &llvmIndexType,
                                      ArrayRef<OpFoldResult> valueOrAttrVec) {
  return llvm::to_vector<4>(
      llvm::map_range(valueOrAttrVec, [&](OpFoldResult value) -> Value {
        if (auto attr = value.dyn_cast<Attribute>())
          return b.create<LLVM::ConstantOp>(loc, llvmIndexType, attr);
        return value.get<Value>();
      }));
}

/// Compute a map that for a given dimension of the expanded type gives the
/// dimension in the collapsed type it maps to. Essentially it is the inverse
/// of the `reassociation` maps.
static DenseMap<int64_t, int64_t>
getExpandedDimToCollapsedDimMap(ArrayRef<ReassociationIndices> reassociation) {
  llvm::DenseMap<int64_t, int64_t> expandedDimToCollapsedDim;
  for (auto &en : enumerate(reassociation)) {
    for (auto dim : en.value())
      expandedDimToCollapsedDim[dim] = en.index();
  }
  return expandedDimToCollapsedDim;
}

static OpFoldResult
getExpandedOutputDimSize(OpBuilder &b, Location loc, Type &llvmIndexType,
                         int64_t outDimIndex, ArrayRef<int64_t> outStaticShape,
                         MemRefDescriptor &inDesc,
                         ArrayRef<int64_t> inStaticShape,
                         ArrayRef<ReassociationIndices> reassociation,
                         DenseMap<int64_t, int64_t> &outDimToInDimMap) {
  int64_t outDimSize = outStaticShape[outDimIndex];
  if (!ShapedType::isDynamic(outDimSize))
    return b.getIndexAttr(outDimSize);

  // Calculate the product of all the out dim sizes except the current dim.
  int64_t inDimIndex = outDimToInDimMap[outDimIndex];
  int64_t otherDimSizesMul = 1;
  for (auto otherDimIndex : reassociation[inDimIndex]) {
    if (otherDimIndex == static_cast<unsigned>(outDimIndex))
      continue;
    int64_t otherDimSize = outStaticShape[otherDimIndex];
    assert(!ShapedType::isDynamic(otherDimSize) &&
           "single dimension cannot be expanded into multiple dynamic "
           "dimensions");
    otherDimSizesMul *= otherDimSize;
  }

  // outDimSize = inDimSize / otherOutDimSizesMul
  int64_t inDimSize = inStaticShape[inDimIndex];
  Value inDimSizeDynamic =
      ShapedType::isDynamic(inDimSize)
          ? inDesc.size(b, loc, inDimIndex)
          : b.create<LLVM::ConstantOp>(loc, llvmIndexType,
                                       b.getIndexAttr(inDimSize));
  Value outDimSizeDynamic = b.create<LLVM::SDivOp>(
      loc, inDimSizeDynamic,
      b.create<LLVM::ConstantOp>(loc, llvmIndexType,
                                 b.getIndexAttr(otherDimSizesMul)));
  return outDimSizeDynamic;
}

static OpFoldResult getCollapsedOutputDimSize(
    OpBuilder &b, Location loc, Type &llvmIndexType, int64_t outDimIndex,
    int64_t outDimSize, ArrayRef<int64_t> inStaticShape,
    MemRefDescriptor &inDesc, ArrayRef<ReassociationIndices> reassociation) {
  if (!ShapedType::isDynamic(outDimSize))
    return b.getIndexAttr(outDimSize);

  Value c1 = b.create<LLVM::ConstantOp>(loc, llvmIndexType, b.getIndexAttr(1));
  Value outDimSizeDynamic = c1;
  for (auto inDimIndex : reassociation[outDimIndex]) {
    int64_t inDimSize = inStaticShape[inDimIndex];
    Value inDimSizeDynamic =
        ShapedType::isDynamic(inDimSize)
            ? inDesc.size(b, loc, inDimIndex)
            : b.create<LLVM::ConstantOp>(loc, llvmIndexType,
                                         b.getIndexAttr(inDimSize));
    outDimSizeDynamic =
        b.create<LLVM::MulOp>(loc, outDimSizeDynamic, inDimSizeDynamic);
  }
  return outDimSizeDynamic;
}

static SmallVector<OpFoldResult, 4>
getCollapsedOutputShape(OpBuilder &b, Location loc, Type &llvmIndexType,
                        ArrayRef<ReassociationIndices> reassociation,
                        ArrayRef<int64_t> inStaticShape,
                        MemRefDescriptor &inDesc,
                        ArrayRef<int64_t> outStaticShape) {
  return llvm::to_vector<4>(llvm::map_range(
      llvm::seq<int64_t>(0, outStaticShape.size()), [&](int64_t outDimIndex) {
        return getCollapsedOutputDimSize(b, loc, llvmIndexType, outDimIndex,
                                         outStaticShape[outDimIndex],
                                         inStaticShape, inDesc, reassociation);
      }));
}

static SmallVector<OpFoldResult, 4>
getExpandedOutputShape(OpBuilder &b, Location loc, Type &llvmIndexType,
                       ArrayRef<ReassociationIndices> reassociation,
                       ArrayRef<int64_t> inStaticShape,
                       MemRefDescriptor &inDesc,
                       ArrayRef<int64_t> outStaticShape) {
  DenseMap<int64_t, int64_t> outDimToInDimMap =
      getExpandedDimToCollapsedDimMap(reassociation);
  return llvm::to_vector<4>(llvm::map_range(
      llvm::seq<int64_t>(0, outStaticShape.size()), [&](int64_t outDimIndex) {
        return getExpandedOutputDimSize(b, loc, llvmIndexType, outDimIndex,
                                        outStaticShape, inDesc, inStaticShape,
                                        reassociation, outDimToInDimMap);
      }));
}

static SmallVector<Value>
getDynamicOutputShape(OpBuilder &b, Location loc, Type &llvmIndexType,
                      ArrayRef<ReassociationIndices> reassociation,
                      ArrayRef<int64_t> inStaticShape, MemRefDescriptor &inDesc,
                      ArrayRef<int64_t> outStaticShape) {
  return outStaticShape.size() < inStaticShape.size()
             ? getAsValues(b, loc, llvmIndexType,
                           getCollapsedOutputShape(b, loc, llvmIndexType,
                                                   reassociation, inStaticShape,
                                                   inDesc, outStaticShape))
             : getAsValues(b, loc, llvmIndexType,
                           getExpandedOutputShape(b, loc, llvmIndexType,
                                                  reassociation, inStaticShape,
                                                  inDesc, outStaticShape));
}

static void fillInStridesForExpandedMemDescriptor(
    OpBuilder &b, Location loc, MemRefType srcType, MemRefDescriptor &srcDesc,
    MemRefDescriptor &dstDesc, ArrayRef<ReassociationIndices> reassociation) {
  // See comments for computeExpandedLayoutMap for details on how the strides
  // are calculated.
  for (auto &en : llvm::enumerate(reassociation)) {
    auto currentStrideToExpand = srcDesc.stride(b, loc, en.index());
    for (auto dstIndex : llvm::reverse(en.value())) {
      dstDesc.setStride(b, loc, dstIndex, currentStrideToExpand);
      Value size = dstDesc.size(b, loc, dstIndex);
      currentStrideToExpand =
          b.create<LLVM::MulOp>(loc, size, currentStrideToExpand);
    }
  }
}

static void fillInStridesForCollapsedMemDescriptor(
    ConversionPatternRewriter &rewriter, Location loc, Operation *op,
    TypeConverter *typeConverter, MemRefType srcType, MemRefDescriptor &srcDesc,
    MemRefDescriptor &dstDesc, ArrayRef<ReassociationIndices> reassociation) {
  // See comments for computeCollapsedLayoutMap for details on how the strides
  // are calculated.
1477 auto srcShape = srcType.getShape(); 1478 for (auto &en : llvm::enumerate(reassociation)) { 1479 rewriter.setInsertionPoint(op); 1480 auto dstIndex = en.index(); 1481 ArrayRef<int64_t> ref = llvm::makeArrayRef(en.value()); 1482 while (srcShape[ref.back()] == 1 && ref.size() > 1) 1483 ref = ref.drop_back(); 1484 if (!ShapedType::isDynamic(srcShape[ref.back()]) || ref.size() == 1) { 1485 dstDesc.setStride(rewriter, loc, dstIndex, 1486 srcDesc.stride(rewriter, loc, ref.back())); 1487 } else { 1488 // Iterate over the source strides in reverse order. Skip over the 1489 // dimensions whose size is 1. 1490 // TODO: we should take the minimum stride in the reassociation group 1491 // instead of just the first where the dimension is not 1. 1492 // 1493 // +------------------------------------------------------+ 1494 // | curEntry: | 1495 // | %srcStride = strides[srcIndex] | 1496 // | %neOne = cmp sizes[srcIndex],1 +--+ 1497 // | cf.cond_br %neOne, continue(%srcStride), nextEntry | | 1498 // +-------------------------+----------------------------+ | 1499 // | | 1500 // v | 1501 // +-----------------------------+ | 1502 // | nextEntry: | | 1503 // | ... +---+ | 1504 // +--------------+--------------+ | | 1505 // | | | 1506 // v | | 1507 // +-----------------------------+ | | 1508 // | nextEntry: | | | 1509 // | ... | | | 1510 // +--------------+--------------+ | +--------+ 1511 // | | | 1512 // v v v 1513 // +--------------------------------------------------+ 1514 // | continue(%newStride): | 1515 // | %newMemRefDes = setStride(%newStride,dstIndex) | 1516 // +--------------------------------------------------+ 1517 OpBuilder::InsertionGuard guard(rewriter); 1518 Block *initBlock = rewriter.getInsertionBlock(); 1519 Block *continueBlock = 1520 rewriter.splitBlock(initBlock, rewriter.getInsertionPoint()); 1521 continueBlock->insertArgument(unsigned(0), srcDesc.getIndexType(), loc); 1522 rewriter.setInsertionPointToStart(continueBlock); 1523 dstDesc.setStride(rewriter, loc, dstIndex, continueBlock->getArgument(0)); 1524 1525 Block *curEntryBlock = initBlock; 1526 Block *nextEntryBlock; 1527 for (auto srcIndex : llvm::reverse(ref)) { 1528 if (srcShape[srcIndex] == 1 && srcIndex != ref.front()) 1529 continue; 1530 rewriter.setInsertionPointToEnd(curEntryBlock); 1531 Value srcStride = srcDesc.stride(rewriter, loc, srcIndex); 1532 if (srcIndex == ref.front()) { 1533 rewriter.create<LLVM::BrOp>(loc, srcStride, continueBlock); 1534 break; 1535 } 1536 Value one = rewriter.create<LLVM::ConstantOp>( 1537 loc, typeConverter->convertType(rewriter.getI64Type()), 1538 rewriter.getI32IntegerAttr(1)); 1539 Value predNeOne = rewriter.create<LLVM::ICmpOp>( 1540 loc, LLVM::ICmpPredicate::ne, srcDesc.size(rewriter, loc, srcIndex), 1541 one); 1542 { 1543 OpBuilder::InsertionGuard guard(rewriter); 1544 nextEntryBlock = rewriter.createBlock( 1545 initBlock->getParent(), Region::iterator(continueBlock), {}); 1546 } 1547 rewriter.create<LLVM::CondBrOp>(loc, predNeOne, continueBlock, 1548 srcStride, nextEntryBlock, llvm::None); 1549 curEntryBlock = nextEntryBlock; 1550 } 1551 } 1552 } 1553 } 1554 1555 static void fillInDynamicStridesForMemDescriptor( 1556 ConversionPatternRewriter &b, Location loc, Operation *op, 1557 TypeConverter *typeConverter, MemRefType srcType, MemRefType dstType, 1558 MemRefDescriptor &srcDesc, MemRefDescriptor &dstDesc, 1559 ArrayRef<ReassociationIndices> reassociation) { 1560 if (srcType.getRank() > dstType.getRank()) 1561 fillInStridesForCollapsedMemDescriptor(b, loc, op, typeConverter, 
srcType,
                                           srcDesc, dstDesc, reassociation);
  else
    fillInStridesForExpandedMemDescriptor(b, loc, srcType, srcDesc, dstDesc,
                                          reassociation);
}

// ReshapeOp creates a new view descriptor of the proper rank. Static sizes and
// strides of the result type are set directly; dynamic sizes and strides are
// recomputed from the source descriptor using the reassociation indices.
template <typename ReshapeOp>
class ReassociatingReshapeOpConversion
    : public ConvertOpToLLVMPattern<ReshapeOp> {
public:
  using ConvertOpToLLVMPattern<ReshapeOp>::ConvertOpToLLVMPattern;
  using ReshapeOpAdaptor = typename ReshapeOp::Adaptor;

  LogicalResult
  matchAndRewrite(ReshapeOp reshapeOp, typename ReshapeOp::Adaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    MemRefType dstType = reshapeOp.getResultType();
    MemRefType srcType = reshapeOp.getSrcType();

    int64_t offset;
    SmallVector<int64_t, 4> strides;
    if (failed(getStridesAndOffset(dstType, strides, offset))) {
      return rewriter.notifyMatchFailure(
          reshapeOp, "failed to get stride and offset exprs");
    }

    MemRefDescriptor srcDesc(adaptor.getSrc());
    Location loc = reshapeOp->getLoc();
    auto dstDesc = MemRefDescriptor::undef(
        rewriter, loc, this->typeConverter->convertType(dstType));
    dstDesc.setAllocatedPtr(rewriter, loc, srcDesc.allocatedPtr(rewriter, loc));
    dstDesc.setAlignedPtr(rewriter, loc, srcDesc.alignedPtr(rewriter, loc));
    dstDesc.setOffset(rewriter, loc, srcDesc.offset(rewriter, loc));

    ArrayRef<int64_t> srcStaticShape = srcType.getShape();
    ArrayRef<int64_t> dstStaticShape = dstType.getShape();
    Type llvmIndexType =
        this->typeConverter->convertType(rewriter.getIndexType());
    SmallVector<Value> dstShape = getDynamicOutputShape(
        rewriter, loc, llvmIndexType, reshapeOp.getReassociationIndices(),
        srcStaticShape, srcDesc, dstStaticShape);
    for (auto &en : llvm::enumerate(dstShape))
      dstDesc.setSize(rewriter, loc, en.index(), en.value());

    if (llvm::all_of(strides, isStaticStrideOrOffset)) {
      for (auto &en : llvm::enumerate(strides))
        dstDesc.setConstantStride(rewriter, loc, en.index(), en.value());
    } else if (srcType.getLayout().isIdentity() &&
               dstType.getLayout().isIdentity()) {
      Value c1 = rewriter.create<LLVM::ConstantOp>(loc, llvmIndexType,
                                                   rewriter.getIndexAttr(1));
      Value stride = c1;
      for (auto dimIndex :
           llvm::reverse(llvm::seq<int64_t>(0, dstShape.size()))) {
        dstDesc.setStride(rewriter, loc, dimIndex, stride);
        stride = rewriter.create<LLVM::MulOp>(loc, dstShape[dimIndex], stride);
      }
    } else {
      // There could be mixed static/dynamic strides. For simplicity, we
      // recompute all strides if there is at least one dynamic stride.
      fillInDynamicStridesForMemDescriptor(
          rewriter, loc, reshapeOp, this->typeConverter, srcType, dstType,
          srcDesc, dstDesc, reshapeOp.getReassociationIndices());
    }
    rewriter.replaceOp(reshapeOp, {dstDesc});
    return success();
  }
};

/// Conversion pattern that transforms a subview op into:
///   1. An `llvm.mlir.undef` operation to create a memref descriptor
///   2. Updates to the descriptor to introduce the data ptr, offset, size
///      and stride.
/// The subview op is replaced by the descriptor.
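///
/// As a rough, illustrative sketch (not the exact IR emitted), lowering
///   %1 = memref.subview %0[%off] [%sz] [%st] : memref<?xf32> to memref<?xf32, #map>
/// amounts to descriptor updates of the form
///   offset(%1)    = offset(%0) + %off * stride(%0, 0)
///   size(%1, 0)   = %sz
///   stride(%1, 0) = %st * stride(%0, 0)
/// with rank-reduced (dropped) dimensions skipped in the result descriptor.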
1639 struct SubViewOpLowering : public ConvertOpToLLVMPattern<memref::SubViewOp> { 1640 using ConvertOpToLLVMPattern<memref::SubViewOp>::ConvertOpToLLVMPattern; 1641 1642 LogicalResult 1643 matchAndRewrite(memref::SubViewOp subViewOp, OpAdaptor adaptor, 1644 ConversionPatternRewriter &rewriter) const override { 1645 auto loc = subViewOp.getLoc(); 1646 1647 auto sourceMemRefType = subViewOp.getSource().getType().cast<MemRefType>(); 1648 auto sourceElementTy = 1649 typeConverter->convertType(sourceMemRefType.getElementType()); 1650 1651 auto viewMemRefType = subViewOp.getType(); 1652 auto inferredType = 1653 memref::SubViewOp::inferResultType( 1654 subViewOp.getSourceType(), 1655 extractFromI64ArrayAttr(subViewOp.getStaticOffsets()), 1656 extractFromI64ArrayAttr(subViewOp.getStaticSizes()), 1657 extractFromI64ArrayAttr(subViewOp.getStaticStrides())) 1658 .cast<MemRefType>(); 1659 auto targetElementTy = 1660 typeConverter->convertType(viewMemRefType.getElementType()); 1661 auto targetDescTy = typeConverter->convertType(viewMemRefType); 1662 if (!sourceElementTy || !targetDescTy || !targetElementTy || 1663 !LLVM::isCompatibleType(sourceElementTy) || 1664 !LLVM::isCompatibleType(targetElementTy) || 1665 !LLVM::isCompatibleType(targetDescTy)) 1666 return failure(); 1667 1668 // Extract the offset and strides from the type. 1669 int64_t offset; 1670 SmallVector<int64_t, 4> strides; 1671 auto successStrides = getStridesAndOffset(inferredType, strides, offset); 1672 if (failed(successStrides)) 1673 return failure(); 1674 1675 // Create the descriptor. 1676 if (!LLVM::isCompatibleType(adaptor.getOperands().front().getType())) 1677 return failure(); 1678 MemRefDescriptor sourceMemRef(adaptor.getOperands().front()); 1679 auto targetMemRef = MemRefDescriptor::undef(rewriter, loc, targetDescTy); 1680 1681 // Copy the buffer pointer from the old descriptor to the new one. 1682 Value extracted = sourceMemRef.allocatedPtr(rewriter, loc); 1683 Value bitcastPtr = rewriter.create<LLVM::BitcastOp>( 1684 loc, 1685 LLVM::LLVMPointerType::get(targetElementTy, 1686 viewMemRefType.getMemorySpaceAsInt()), 1687 extracted); 1688 targetMemRef.setAllocatedPtr(rewriter, loc, bitcastPtr); 1689 1690 // Copy the aligned pointer from the old descriptor to the new one. 1691 extracted = sourceMemRef.alignedPtr(rewriter, loc); 1692 bitcastPtr = rewriter.create<LLVM::BitcastOp>( 1693 loc, 1694 LLVM::LLVMPointerType::get(targetElementTy, 1695 viewMemRefType.getMemorySpaceAsInt()), 1696 extracted); 1697 targetMemRef.setAlignedPtr(rewriter, loc, bitcastPtr); 1698 1699 size_t inferredShapeRank = inferredType.getRank(); 1700 size_t resultShapeRank = viewMemRefType.getRank(); 1701 1702 // Extract strides needed to compute offset. 1703 SmallVector<Value, 4> strideValues; 1704 strideValues.reserve(inferredShapeRank); 1705 for (unsigned i = 0; i < inferredShapeRank; ++i) 1706 strideValues.push_back(sourceMemRef.stride(rewriter, loc, i)); 1707 1708 // Offset. 1709 auto llvmIndexType = typeConverter->convertType(rewriter.getIndexType()); 1710 if (!ShapedType::isDynamicStrideOrOffset(offset)) { 1711 targetMemRef.setConstantOffset(rewriter, loc, offset); 1712 } else { 1713 Value baseOffset = sourceMemRef.offset(rewriter, loc); 1714 // `inferredShapeRank` may be larger than the number of offset operands 1715 // because of trailing semantics. In this case, the offset is guaranteed 1716 // to be interpreted as 0 and we can just skip the extra dimensions. 
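      // Concretely, the loop below accumulates
      //   baseOffset += offset_i * stride(source, i)
      // over the leading dimensions that do carry an explicit offset.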
1717 for (unsigned i = 0, e = std::min(inferredShapeRank, 1718 subViewOp.getMixedOffsets().size()); 1719 i < e; ++i) { 1720 Value offset = 1721 // TODO: need OpFoldResult ODS adaptor to clean this up. 1722 subViewOp.isDynamicOffset(i) 1723 ? adaptor.getOperands()[subViewOp.getIndexOfDynamicOffset(i)] 1724 : rewriter.create<LLVM::ConstantOp>( 1725 loc, llvmIndexType, 1726 rewriter.getI64IntegerAttr(subViewOp.getStaticOffset(i))); 1727 Value mul = rewriter.create<LLVM::MulOp>(loc, offset, strideValues[i]); 1728 baseOffset = rewriter.create<LLVM::AddOp>(loc, baseOffset, mul); 1729 } 1730 targetMemRef.setOffset(rewriter, loc, baseOffset); 1731 } 1732 1733 // Update sizes and strides. 1734 SmallVector<OpFoldResult> mixedSizes = subViewOp.getMixedSizes(); 1735 SmallVector<OpFoldResult> mixedStrides = subViewOp.getMixedStrides(); 1736 assert(mixedSizes.size() == mixedStrides.size() && 1737 "expected sizes and strides of equal length"); 1738 llvm::SmallBitVector unusedDims = subViewOp.getDroppedDims(); 1739 for (int i = inferredShapeRank - 1, j = resultShapeRank - 1; 1740 i >= 0 && j >= 0; --i) { 1741 if (unusedDims.test(i)) 1742 continue; 1743 1744 // `i` may overflow subViewOp.getMixedSizes because of trailing semantics. 1745 // In this case, the size is guaranteed to be interpreted as Dim and the 1746 // stride as 1. 1747 Value size, stride; 1748 if (static_cast<unsigned>(i) >= mixedSizes.size()) { 1749 // If the static size is available, use it directly. This is similar to 1750 // the folding of dim(constant-op) but removes the need for dim to be 1751 // aware of LLVM constants and for this pass to be aware of std 1752 // constants. 1753 int64_t staticSize = 1754 subViewOp.getSource().getType().cast<MemRefType>().getShape()[i]; 1755 if (staticSize != ShapedType::kDynamicSize) { 1756 size = rewriter.create<LLVM::ConstantOp>( 1757 loc, llvmIndexType, rewriter.getI64IntegerAttr(staticSize)); 1758 } else { 1759 Value pos = rewriter.create<LLVM::ConstantOp>( 1760 loc, llvmIndexType, rewriter.getI64IntegerAttr(i)); 1761 Value dim = 1762 rewriter.create<memref::DimOp>(loc, subViewOp.getSource(), pos); 1763 auto cast = rewriter.create<UnrealizedConversionCastOp>( 1764 loc, llvmIndexType, dim); 1765 size = cast.getResult(0); 1766 } 1767 stride = rewriter.create<LLVM::ConstantOp>( 1768 loc, llvmIndexType, rewriter.getI64IntegerAttr(1)); 1769 } else { 1770 // TODO: need OpFoldResult ODS adaptor to clean this up. 1771 size = 1772 subViewOp.isDynamicSize(i) 1773 ? adaptor.getOperands()[subViewOp.getIndexOfDynamicSize(i)] 1774 : rewriter.create<LLVM::ConstantOp>( 1775 loc, llvmIndexType, 1776 rewriter.getI64IntegerAttr(subViewOp.getStaticSize(i))); 1777 if (!ShapedType::isDynamicStrideOrOffset(strides[i])) { 1778 stride = rewriter.create<LLVM::ConstantOp>( 1779 loc, llvmIndexType, rewriter.getI64IntegerAttr(strides[i])); 1780 } else { 1781 stride = 1782 subViewOp.isDynamicStride(i) 1783 ? adaptor.getOperands()[subViewOp.getIndexOfDynamicStride(i)] 1784 : rewriter.create<LLVM::ConstantOp>( 1785 loc, llvmIndexType, 1786 rewriter.getI64IntegerAttr( 1787 subViewOp.getStaticStride(i))); 1788 stride = rewriter.create<LLVM::MulOp>(loc, stride, strideValues[i]); 1789 } 1790 } 1791 targetMemRef.setSize(rewriter, loc, j, size); 1792 targetMemRef.setStride(rewriter, loc, j, stride); 1793 j--; 1794 } 1795 1796 rewriter.replaceOp(subViewOp, {targetMemRef}); 1797 return success(); 1798 } 1799 }; 1800 1801 /// Conversion pattern that transforms a transpose op into: 1802 /// 1. 
An `llvm.mlir.undef` operation to create a memref descriptor.
///   2. Updates to the descriptor to introduce the data ptr, offset, and the
///      permuted sizes and strides.
/// The transpose op is replaced by the descriptor.
class TransposeOpLowering : public ConvertOpToLLVMPattern<memref::TransposeOp> {
public:
  using ConvertOpToLLVMPattern<memref::TransposeOp>::ConvertOpToLLVMPattern;

  LogicalResult
  matchAndRewrite(memref::TransposeOp transposeOp, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    auto loc = transposeOp.getLoc();
    MemRefDescriptor viewMemRef(adaptor.getIn());

    // No permutation, early exit.
    if (transposeOp.getPermutation().isIdentity())
      return rewriter.replaceOp(transposeOp, {viewMemRef}), success();

    auto targetMemRef = MemRefDescriptor::undef(
        rewriter, loc, typeConverter->convertType(transposeOp.getShapedType()));

    // Copy the base and aligned pointers from the old descriptor to the new
    // one.
    targetMemRef.setAllocatedPtr(rewriter, loc,
                                 viewMemRef.allocatedPtr(rewriter, loc));
    targetMemRef.setAlignedPtr(rewriter, loc,
                               viewMemRef.alignedPtr(rewriter, loc));

    // Copy the offset from the old descriptor to the new one.
    targetMemRef.setOffset(rewriter, loc, viewMemRef.offset(rewriter, loc));

    // Iterate over the dimensions and apply size/stride permutation.
    for (const auto &en :
         llvm::enumerate(transposeOp.getPermutation().getResults())) {
      int sourcePos = en.index();
      int targetPos = en.value().cast<AffineDimExpr>().getPosition();
      targetMemRef.setSize(rewriter, loc, targetPos,
                           viewMemRef.size(rewriter, loc, sourcePos));
      targetMemRef.setStride(rewriter, loc, targetPos,
                             viewMemRef.stride(rewriter, loc, sourcePos));
    }

    rewriter.replaceOp(transposeOp, {targetMemRef});
    return success();
  }
};

/// Conversion pattern that transforms a view op into:
///   1. An `llvm.mlir.undef` operation to create a memref descriptor
///   2. Updates to the descriptor to introduce the data ptr, offset, size
///      and stride.
/// The view op is replaced by the descriptor.
struct ViewOpLowering : public ConvertOpToLLVMPattern<memref::ViewOp> {
  using ConvertOpToLLVMPattern<memref::ViewOp>::ConvertOpToLLVMPattern;

  // Build and return the value for the idx^th shape dimension, either by
  // returning the constant shape dimension or by picking the corresponding
  // dynamic size operand.
  Value getSize(ConversionPatternRewriter &rewriter, Location loc,
                ArrayRef<int64_t> shape, ValueRange dynamicSizes,
                unsigned idx) const {
    assert(idx < shape.size());
    if (!ShapedType::isDynamic(shape[idx]))
      return createIndexConstant(rewriter, loc, shape[idx]);
    // Count the number of dynamic dims in range [0, idx) to find the position
    // of this dimension's size in `dynamicSizes`.
    unsigned nDynamic =
        llvm::count_if(shape.take_front(idx), ShapedType::isDynamic);
    return dynamicSizes[nDynamic];
  }

  // Build and return the idx^th stride, either by returning the constant stride
  // or by computing the dynamic stride from the current `runningStride` and
  // `nextSize`.
The caller should keep a running stride and update it with the 1876 // result returned by this function. 1877 Value getStride(ConversionPatternRewriter &rewriter, Location loc, 1878 ArrayRef<int64_t> strides, Value nextSize, 1879 Value runningStride, unsigned idx) const { 1880 assert(idx < strides.size()); 1881 if (!ShapedType::isDynamicStrideOrOffset(strides[idx])) 1882 return createIndexConstant(rewriter, loc, strides[idx]); 1883 if (nextSize) 1884 return runningStride 1885 ? rewriter.create<LLVM::MulOp>(loc, runningStride, nextSize) 1886 : nextSize; 1887 assert(!runningStride); 1888 return createIndexConstant(rewriter, loc, 1); 1889 } 1890 1891 LogicalResult 1892 matchAndRewrite(memref::ViewOp viewOp, OpAdaptor adaptor, 1893 ConversionPatternRewriter &rewriter) const override { 1894 auto loc = viewOp.getLoc(); 1895 1896 auto viewMemRefType = viewOp.getType(); 1897 auto targetElementTy = 1898 typeConverter->convertType(viewMemRefType.getElementType()); 1899 auto targetDescTy = typeConverter->convertType(viewMemRefType); 1900 if (!targetDescTy || !targetElementTy || 1901 !LLVM::isCompatibleType(targetElementTy) || 1902 !LLVM::isCompatibleType(targetDescTy)) 1903 return viewOp.emitWarning("Target descriptor type not converted to LLVM"), 1904 failure(); 1905 1906 int64_t offset; 1907 SmallVector<int64_t, 4> strides; 1908 auto successStrides = getStridesAndOffset(viewMemRefType, strides, offset); 1909 if (failed(successStrides)) 1910 return viewOp.emitWarning("cannot cast to non-strided shape"), failure(); 1911 assert(offset == 0 && "expected offset to be 0"); 1912 1913 // Target memref must be contiguous in memory (innermost stride is 1), or 1914 // empty (special case when at least one of the memref dimensions is 0). 1915 if (!strides.empty() && (strides.back() != 1 && strides.back() != 0)) 1916 return viewOp.emitWarning("cannot cast to non-contiguous shape"), 1917 failure(); 1918 1919 // Create the descriptor. 1920 MemRefDescriptor sourceMemRef(adaptor.getSource()); 1921 auto targetMemRef = MemRefDescriptor::undef(rewriter, loc, targetDescTy); 1922 1923 // Field 1: Copy the allocated pointer, used for malloc/free. 1924 Value allocatedPtr = sourceMemRef.allocatedPtr(rewriter, loc); 1925 auto srcMemRefType = viewOp.getSource().getType().cast<MemRefType>(); 1926 Value bitcastPtr = rewriter.create<LLVM::BitcastOp>( 1927 loc, 1928 LLVM::LLVMPointerType::get(targetElementTy, 1929 srcMemRefType.getMemorySpaceAsInt()), 1930 allocatedPtr); 1931 targetMemRef.setAllocatedPtr(rewriter, loc, bitcastPtr); 1932 1933 // Field 2: Copy the actual aligned pointer to payload. 1934 Value alignedPtr = sourceMemRef.alignedPtr(rewriter, loc); 1935 alignedPtr = rewriter.create<LLVM::GEPOp>( 1936 loc, alignedPtr.getType(), alignedPtr, adaptor.getByteShift()); 1937 bitcastPtr = rewriter.create<LLVM::BitcastOp>( 1938 loc, 1939 LLVM::LLVMPointerType::get(targetElementTy, 1940 srcMemRefType.getMemorySpaceAsInt()), 1941 alignedPtr); 1942 targetMemRef.setAlignedPtr(rewriter, loc, bitcastPtr); 1943 1944 // Field 3: The offset in the resulting type must be 0. This is because of 1945 // the type change: an offset on srcType* may not be expressible as an 1946 // offset on dstType*. 1947 targetMemRef.setOffset(rewriter, loc, 1948 createIndexConstant(rewriter, loc, offset)); 1949 1950 // Early exit for 0-D corner case. 1951 if (viewMemRefType.getRank() == 0) 1952 return rewriter.replaceOp(viewOp, {targetMemRef}), success(); 1953 1954 // Fields 4 and 5: Update sizes and strides. 
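    // The strides are rebuilt from the innermost dimension outwards: the
    // innermost stride is the constant 1 (or 0 for the empty-memref case,
    // per the contiguity check above), and each outer stride is the running
    // product of the inner sizes, accumulated through the `stride`/`nextSize`
    // pair in the loop below.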
1955 Value stride = nullptr, nextSize = nullptr; 1956 for (int i = viewMemRefType.getRank() - 1; i >= 0; --i) { 1957 // Update size. 1958 Value size = getSize(rewriter, loc, viewMemRefType.getShape(), 1959 adaptor.getSizes(), i); 1960 targetMemRef.setSize(rewriter, loc, i, size); 1961 // Update stride. 1962 stride = getStride(rewriter, loc, strides, nextSize, stride, i); 1963 targetMemRef.setStride(rewriter, loc, i, stride); 1964 nextSize = size; 1965 } 1966 1967 rewriter.replaceOp(viewOp, {targetMemRef}); 1968 return success(); 1969 } 1970 }; 1971 1972 //===----------------------------------------------------------------------===// 1973 // AtomicRMWOpLowering 1974 //===----------------------------------------------------------------------===// 1975 1976 /// Try to match the kind of a memref.atomic_rmw to determine whether to use a 1977 /// lowering to llvm.atomicrmw or fallback to llvm.cmpxchg. 1978 static Optional<LLVM::AtomicBinOp> 1979 matchSimpleAtomicOp(memref::AtomicRMWOp atomicOp) { 1980 switch (atomicOp.getKind()) { 1981 case arith::AtomicRMWKind::addf: 1982 return LLVM::AtomicBinOp::fadd; 1983 case arith::AtomicRMWKind::addi: 1984 return LLVM::AtomicBinOp::add; 1985 case arith::AtomicRMWKind::assign: 1986 return LLVM::AtomicBinOp::xchg; 1987 case arith::AtomicRMWKind::maxs: 1988 return LLVM::AtomicBinOp::max; 1989 case arith::AtomicRMWKind::maxu: 1990 return LLVM::AtomicBinOp::umax; 1991 case arith::AtomicRMWKind::mins: 1992 return LLVM::AtomicBinOp::min; 1993 case arith::AtomicRMWKind::minu: 1994 return LLVM::AtomicBinOp::umin; 1995 case arith::AtomicRMWKind::ori: 1996 return LLVM::AtomicBinOp::_or; 1997 case arith::AtomicRMWKind::andi: 1998 return LLVM::AtomicBinOp::_and; 1999 default: 2000 return llvm::None; 2001 } 2002 llvm_unreachable("Invalid AtomicRMWKind"); 2003 } 2004 2005 struct AtomicRMWOpLowering : public LoadStoreOpLowering<memref::AtomicRMWOp> { 2006 using Base::Base; 2007 2008 LogicalResult 2009 matchAndRewrite(memref::AtomicRMWOp atomicOp, OpAdaptor adaptor, 2010 ConversionPatternRewriter &rewriter) const override { 2011 if (failed(match(atomicOp))) 2012 return failure(); 2013 auto maybeKind = matchSimpleAtomicOp(atomicOp); 2014 if (!maybeKind) 2015 return failure(); 2016 auto resultType = adaptor.getValue().getType(); 2017 auto memRefType = atomicOp.getMemRefType(); 2018 auto dataPtr = 2019 getStridedElementPtr(atomicOp.getLoc(), memRefType, adaptor.getMemref(), 2020 adaptor.getIndices(), rewriter); 2021 rewriter.replaceOpWithNewOp<LLVM::AtomicRMWOp>( 2022 atomicOp, resultType, *maybeKind, dataPtr, adaptor.getValue(), 2023 LLVM::AtomicOrdering::acq_rel); 2024 return success(); 2025 } 2026 }; 2027 2028 } // namespace 2029 2030 void mlir::populateMemRefToLLVMConversionPatterns(LLVMTypeConverter &converter, 2031 RewritePatternSet &patterns) { 2032 // clang-format off 2033 patterns.add< 2034 AllocaOpLowering, 2035 AllocaScopeOpLowering, 2036 AtomicRMWOpLowering, 2037 AssumeAlignmentOpLowering, 2038 DimOpLowering, 2039 GenericAtomicRMWOpLowering, 2040 GlobalMemrefOpLowering, 2041 GetGlobalMemrefOpLowering, 2042 LoadOpLowering, 2043 MemRefCastOpLowering, 2044 MemRefCopyOpLowering, 2045 MemRefReinterpretCastOpLowering, 2046 MemRefReshapeOpLowering, 2047 PrefetchOpLowering, 2048 RankOpLowering, 2049 ReassociatingReshapeOpConversion<memref::ExpandShapeOp>, 2050 ReassociatingReshapeOpConversion<memref::CollapseShapeOp>, 2051 StoreOpLowering, 2052 SubViewOpLowering, 2053 TransposeOpLowering, 2054 ViewOpLowering>(converter); 2055 // clang-format on 2056 auto allocLowering = 
converter.getOptions().allocLowering; 2057 if (allocLowering == LowerToLLVMOptions::AllocLowering::AlignedAlloc) 2058 patterns.add<AlignedAllocOpLowering, DeallocOpLowering>(converter); 2059 else if (allocLowering == LowerToLLVMOptions::AllocLowering::Malloc) 2060 patterns.add<AllocOpLowering, DeallocOpLowering>(converter); 2061 } 2062 2063 namespace { 2064 struct MemRefToLLVMPass : public ConvertMemRefToLLVMBase<MemRefToLLVMPass> { 2065 MemRefToLLVMPass() = default; 2066 2067 void runOnOperation() override { 2068 Operation *op = getOperation(); 2069 const auto &dataLayoutAnalysis = getAnalysis<DataLayoutAnalysis>(); 2070 LowerToLLVMOptions options(&getContext(), 2071 dataLayoutAnalysis.getAtOrAbove(op)); 2072 options.allocLowering = 2073 (useAlignedAlloc ? LowerToLLVMOptions::AllocLowering::AlignedAlloc 2074 : LowerToLLVMOptions::AllocLowering::Malloc); 2075 2076 options.useGenericFunctions = useGenericFunctions; 2077 2078 if (indexBitwidth != kDeriveIndexBitwidthFromDataLayout) 2079 options.overrideIndexBitwidth(indexBitwidth); 2080 2081 LLVMTypeConverter typeConverter(&getContext(), options, 2082 &dataLayoutAnalysis); 2083 RewritePatternSet patterns(&getContext()); 2084 populateMemRefToLLVMConversionPatterns(typeConverter, patterns); 2085 LLVMConversionTarget target(getContext()); 2086 target.addLegalOp<func::FuncOp>(); 2087 if (failed(applyPartialConversion(op, target, std::move(patterns)))) 2088 signalPassFailure(); 2089 } 2090 }; 2091 } // namespace 2092 2093 std::unique_ptr<Pass> mlir::createMemRefToLLVMPass() { 2094 return std::make_unique<MemRefToLLVMPass>(); 2095 } 2096
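
// Note (illustrative; option spellings assumed from the corresponding
// Passes.td entry of this MLIR revision, so double-check the registered pass
// options): from the command line this conversion is typically driven as
//   mlir-opt --convert-memref-to-llvm="index-bitwidth=32 use-aligned-alloc=1" input.mlir
// where `use-aligned-alloc` selects the AlignedAlloc lowering wired up in
// populateMemRefToLLVMConversionPatterns above.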