//===- MemRefToLLVM.cpp - MemRef to LLVM dialect conversion ---------------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//

#include "mlir/Conversion/MemRefToLLVM/MemRefToLLVM.h"
#include "../PassDetail.h"
#include "mlir/Analysis/DataLayoutAnalysis.h"
#include "mlir/Conversion/LLVMCommon/ConversionTarget.h"
#include "mlir/Conversion/LLVMCommon/Pattern.h"
#include "mlir/Conversion/LLVMCommon/TypeConverter.h"
#include "mlir/Conversion/MemRefToLLVM/AllocLikeConversion.h"
#include "mlir/Dialect/LLVMIR/FunctionCallUtils.h"
#include "mlir/Dialect/LLVMIR/LLVMDialect.h"
#include "mlir/Dialect/MemRef/IR/MemRef.h"
#include "mlir/IR/AffineMap.h"
#include "mlir/IR/BlockAndValueMapping.h"

using namespace mlir;

namespace {

struct AllocOpLowering : public AllocLikeOpLLVMLowering {
  AllocOpLowering(LLVMTypeConverter &converter)
      : AllocLikeOpLLVMLowering(memref::AllocOp::getOperationName(),
                                converter) {}

  std::tuple<Value, Value> allocateBuffer(ConversionPatternRewriter &rewriter,
                                          Location loc, Value sizeBytes,
                                          Operation *op) const override {
    // Heap allocations.
    memref::AllocOp allocOp = cast<memref::AllocOp>(op);
    MemRefType memRefType = allocOp.getType();

    Value alignment;
    if (auto alignmentAttr = allocOp.alignment()) {
      alignment = createIndexConstant(rewriter, loc, *alignmentAttr);
    } else if (!memRefType.getElementType().isSignlessIntOrIndexOrFloat()) {
      // In the case where no alignment is specified, we may want to override
      // `malloc`'s behavior. `malloc` typically aligns at the size of the
      // biggest scalar on the target HW. For non-scalars, use the natural
      // alignment of the LLVM type given by the LLVM DataLayout.
      alignment = getSizeInBytes(loc, memRefType.getElementType(), rewriter);
    }

    if (alignment) {
      // Adjust the allocation size to consider alignment.
      sizeBytes = rewriter.create<LLVM::AddOp>(loc, sizeBytes, alignment);
    }

    // Allocate the underlying buffer and store a pointer to it in the MemRef
    // descriptor.
    Type elementPtrType = this->getElementPtrType(memRefType);
    auto allocFuncOp = LLVM::lookupOrCreateMallocFn(
        allocOp->getParentOfType<ModuleOp>(), getIndexType());
    auto results = createLLVMCall(rewriter, loc, allocFuncOp, {sizeBytes},
                                  getVoidPtrType());
    Value allocatedPtr =
        rewriter.create<LLVM::BitcastOp>(loc, elementPtrType, results[0]);

    Value alignedPtr = allocatedPtr;
    if (alignment) {
      // Compute the aligned pointer.
      Value allocatedInt =
          rewriter.create<LLVM::PtrToIntOp>(loc, getIndexType(), allocatedPtr);
      Value alignmentInt =
          createAligned(rewriter, loc, allocatedInt, alignment);
      alignedPtr =
          rewriter.create<LLVM::IntToPtrOp>(loc, elementPtrType, alignmentInt);
    }

    return std::make_tuple(allocatedPtr, alignedPtr);
  }
};

struct AlignedAllocOpLowering : public AllocLikeOpLLVMLowering {
  AlignedAllocOpLowering(LLVMTypeConverter &converter)
      : AllocLikeOpLLVMLowering(memref::AllocOp::getOperationName(),
                                converter) {}

  /// Returns the memref's element size in bytes using the data layout active
  /// at `op`.
  // TODO: there are other places where this is used. Expose publicly?
  unsigned getMemRefEltSizeInBytes(MemRefType memRefType, Operation *op) const {
    const DataLayout *layout = &defaultLayout;
    if (const DataLayoutAnalysis *analysis =
            getTypeConverter()->getDataLayoutAnalysis()) {
      layout = &analysis->getAbove(op);
    }
    Type elementType = memRefType.getElementType();
    if (auto memRefElementType = elementType.dyn_cast<MemRefType>())
      return getTypeConverter()->getMemRefDescriptorSize(memRefElementType,
                                                         *layout);
    if (auto memRefElementType = elementType.dyn_cast<UnrankedMemRefType>())
      return getTypeConverter()->getUnrankedMemRefDescriptorSize(
          memRefElementType, *layout);
    return layout->getTypeSize(elementType);
  }

  /// Returns true if the memref size in bytes is known to be a multiple of
  /// `factor`, assuming the data layout active at `op`.
  bool isMemRefSizeMultipleOf(MemRefType type, uint64_t factor,
                              Operation *op) const {
    uint64_t sizeDivisor = getMemRefEltSizeInBytes(type, op);
    for (unsigned i = 0, e = type.getRank(); i < e; i++) {
      if (ShapedType::isDynamic(type.getDimSize(i)))
        continue;
      sizeDivisor = sizeDivisor * type.getDimSize(i);
    }
    return sizeDivisor % factor == 0;
  }

  /// Returns the alignment to be used for the allocation call itself.
  /// `aligned_alloc` requires the alignment to be a power of two and the
  /// allocation size to be a multiple of that alignment.
  int64_t getAllocationAlignment(memref::AllocOp allocOp) const {
    if (Optional<uint64_t> alignment = allocOp.alignment())
      return *alignment;

    // Whenever we don't have alignment set, we will use an alignment
    // consistent with the element type; since the alignment has to be a power
    // of two, bump it up to the next power of two if it isn't one already.
    auto eltSizeBytes = getMemRefEltSizeInBytes(allocOp.getType(), allocOp);
    return std::max(kMinAlignedAllocAlignment,
                    llvm::PowerOf2Ceil(eltSizeBytes));
  }

  std::tuple<Value, Value> allocateBuffer(ConversionPatternRewriter &rewriter,
                                          Location loc, Value sizeBytes,
                                          Operation *op) const override {
    // Heap allocations.
    memref::AllocOp allocOp = cast<memref::AllocOp>(op);
    MemRefType memRefType = allocOp.getType();
    int64_t alignment = getAllocationAlignment(allocOp);
    Value allocAlignment = createIndexConstant(rewriter, loc, alignment);

    // aligned_alloc requires the size to be a multiple of the alignment; pad
    // the size up to the next multiple if necessary.
    if (!isMemRefSizeMultipleOf(memRefType, alignment, op))
      sizeBytes = createAligned(rewriter, loc, sizeBytes, allocAlignment);

    Type elementPtrType = this->getElementPtrType(memRefType);
    auto allocFuncOp = LLVM::lookupOrCreateAlignedAllocFn(
        allocOp->getParentOfType<ModuleOp>(), getIndexType());
    auto results =
        createLLVMCall(rewriter, loc, allocFuncOp, {allocAlignment, sizeBytes},
                       getVoidPtrType());
    Value allocatedPtr =
        rewriter.create<LLVM::BitcastOp>(loc, elementPtrType, results[0]);

    return std::make_tuple(allocatedPtr, allocatedPtr);
  }

  /// The minimum alignment to use with aligned_alloc (has to be a power of 2).
  static constexpr uint64_t kMinAlignedAllocAlignment = 16UL;

  /// Default layout to use in absence of the corresponding analysis.
  DataLayout defaultLayout;
};
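
// For illustration only (a rough sketch, not produced verbatim by this file):
// with the aligned-allocation pattern above, an op such as
//   %0 = memref.alloc(%n) {alignment = 32} : memref<?xf32>
// is expected to lower, on a 64-bit target where `index` becomes `i64` and
// modulo the size computation, to something like
//   %p = llvm.call @aligned_alloc(%c32, %bytes) : (i64, i64) -> !llvm.ptr<i8>
//   %q = llvm.bitcast %p : !llvm.ptr<i8> to !llvm.ptr<f32>
// with %q used as both the allocated and the aligned pointer of the resulting
// descriptor.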

// Out-of-line definition, required until C++17.
constexpr uint64_t AlignedAllocOpLowering::kMinAlignedAllocAlignment;

struct AllocaOpLowering : public AllocLikeOpLLVMLowering {
  AllocaOpLowering(LLVMTypeConverter &converter)
      : AllocLikeOpLLVMLowering(memref::AllocaOp::getOperationName(),
                                converter) {}

  /// Allocates the underlying buffer using an `llvm.alloca`. For stack
  /// allocations the allocated and aligned pointers coincide, so no separate
  /// post-allocation alignment step is needed (e.g., as would be with malloc).
  std::tuple<Value, Value> allocateBuffer(ConversionPatternRewriter &rewriter,
                                          Location loc, Value sizeBytes,
                                          Operation *op) const override {

    // With alloca, one gets a pointer to the element type right away (these
    // are stack allocations).
    auto allocaOp = cast<memref::AllocaOp>(op);
    auto elementPtrType = this->getElementPtrType(allocaOp.getType());

    auto allocatedElementPtr = rewriter.create<LLVM::AllocaOp>(
        loc, elementPtrType, sizeBytes,
        allocaOp.alignment() ? *allocaOp.alignment() : 0);

    return std::make_tuple(allocatedElementPtr, allocatedElementPtr);
  }
};

struct AllocaScopeOpLowering
    : public ConvertOpToLLVMPattern<memref::AllocaScopeOp> {
  using ConvertOpToLLVMPattern<memref::AllocaScopeOp>::ConvertOpToLLVMPattern;

  LogicalResult
  matchAndRewrite(memref::AllocaScopeOp allocaScopeOp, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    OpBuilder::InsertionGuard guard(rewriter);
    Location loc = allocaScopeOp.getLoc();

    // Split the current block before the AllocaScopeOp to create the inlining
    // point.
    auto *currentBlock = rewriter.getInsertionBlock();
    auto *remainingOpsBlock =
        rewriter.splitBlock(currentBlock, rewriter.getInsertionPoint());
    Block *continueBlock;
    if (allocaScopeOp.getNumResults() == 0) {
      continueBlock = remainingOpsBlock;
    } else {
      continueBlock = rewriter.createBlock(remainingOpsBlock,
                                           allocaScopeOp.getResultTypes());
      rewriter.create<LLVM::BrOp>(loc, ValueRange(), remainingOpsBlock);
    }

    // Inline the body region.
    Block *beforeBody = &allocaScopeOp.bodyRegion().front();
    Block *afterBody = &allocaScopeOp.bodyRegion().back();
    rewriter.inlineRegionBefore(allocaScopeOp.bodyRegion(), continueBlock);

    // Save the stack and then branch into the body of the region.
    rewriter.setInsertionPointToEnd(currentBlock);
    auto stackSaveOp =
        rewriter.create<LLVM::StackSaveOp>(loc, getVoidPtrType());
    rewriter.create<LLVM::BrOp>(loc, ValueRange(), beforeBody);

    // Replace the alloca_scope return with a branch that jumps out of the
    // body; the stack is restored before leaving the body region.
    rewriter.setInsertionPointToEnd(afterBody);
    auto returnOp =
        cast<memref::AllocaScopeReturnOp>(afterBody->getTerminator());
    auto branchOp = rewriter.replaceOpWithNewOp<LLVM::BrOp>(
        returnOp, returnOp.results(), continueBlock);

    // Insert the stack restore before jumping out of the body of the region.
    rewriter.setInsertionPoint(branchOp);
    rewriter.create<LLVM::StackRestoreOp>(loc, stackSaveOp);

    // Replace the op with the values returned from the body region.
    rewriter.replaceOp(allocaScopeOp, continueBlock->getArguments());

    return success();
  }
};

struct AssumeAlignmentOpLowering
    : public ConvertOpToLLVMPattern<memref::AssumeAlignmentOp> {
  using ConvertOpToLLVMPattern<
      memref::AssumeAlignmentOp>::ConvertOpToLLVMPattern;

  LogicalResult
  matchAndRewrite(memref::AssumeAlignmentOp op, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    Value memref = adaptor.memref();
    unsigned alignment = op.alignment();
    auto loc = op.getLoc();

    MemRefDescriptor memRefDescriptor(memref);
    Value ptr = memRefDescriptor.alignedPtr(rewriter, memref.getLoc());

    // Emit llvm.assume(memref.alignedPtr & (alignment - 1) == 0). Notice that
    // the asserted memref.alignedPtr isn't used anywhere else, as the real
    // users like load/store/views always re-extract memref.alignedPtr as they
    // get lowered.
    //
    // This relies on LLVM's CSE optimization (potentially after SROA), since
    // after CSE all memref.alignedPtr instances get de-duplicated into the
    // same pointer SSA value.
    auto intPtrType =
        getIntPtrType(memRefDescriptor.getElementPtrType().getAddressSpace());
    Value zero = createIndexAttrConstant(rewriter, loc, intPtrType, 0);
    Value mask =
        createIndexAttrConstant(rewriter, loc, intPtrType, alignment - 1);
    Value ptrValue = rewriter.create<LLVM::PtrToIntOp>(loc, intPtrType, ptr);
    rewriter.create<LLVM::AssumeOp>(
        loc, rewriter.create<LLVM::ICmpOp>(
                 loc, LLVM::ICmpPredicate::eq,
                 rewriter.create<LLVM::AndOp>(loc, ptrValue, mask), zero));

    rewriter.eraseOp(op);
    return success();
  }
};

// A `dealloc` is converted into a call to `free` on the underlying data
// buffer. The memref descriptor being an SSA value, there is no need to clean
// it up in any way.
struct DeallocOpLowering : public ConvertOpToLLVMPattern<memref::DeallocOp> {
  using ConvertOpToLLVMPattern<memref::DeallocOp>::ConvertOpToLLVMPattern;

  explicit DeallocOpLowering(LLVMTypeConverter &converter)
      : ConvertOpToLLVMPattern<memref::DeallocOp>(converter) {}

  LogicalResult
  matchAndRewrite(memref::DeallocOp op, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    // Insert the `free` declaration if it is not already present.
    auto freeFunc = LLVM::lookupOrCreateFreeFn(op->getParentOfType<ModuleOp>());
    MemRefDescriptor memref(adaptor.memref());
    Value casted = rewriter.create<LLVM::BitcastOp>(
        op.getLoc(), getVoidPtrType(),
        memref.allocatedPtr(rewriter, op.getLoc()));
    rewriter.replaceOpWithNewOp<LLVM::CallOp>(
        op, TypeRange(), SymbolRefAttr::get(freeFunc), casted);
    return success();
  }
};
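
// For illustration only (a rough sketch, not produced verbatim by this file):
// the deallocation pattern above is expected to turn
//   memref.dealloc %m : memref<?xf32>
// into roughly
//   %alloc = llvm.extractvalue %desc[0] : !llvm.struct<...>  // allocated ptr
//   %void  = llvm.bitcast %alloc : !llvm.ptr<f32> to !llvm.ptr<i8>
//   llvm.call @free(%void) : (!llvm.ptr<i8>) -> ()
// Note that `free` is called on the allocated pointer, not the aligned one.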

// A `dim` is converted to a constant for static sizes and to an access to the
// size stored in the memref descriptor for dynamic sizes.
struct DimOpLowering : public ConvertOpToLLVMPattern<memref::DimOp> {
  using ConvertOpToLLVMPattern<memref::DimOp>::ConvertOpToLLVMPattern;

  LogicalResult
  matchAndRewrite(memref::DimOp dimOp, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    Type operandType = dimOp.source().getType();
    if (operandType.isa<UnrankedMemRefType>()) {
      rewriter.replaceOp(
          dimOp, {extractSizeOfUnrankedMemRef(
                     operandType, dimOp, adaptor.getOperands(), rewriter)});

      return success();
    }
    if (operandType.isa<MemRefType>()) {
      rewriter.replaceOp(
          dimOp, {extractSizeOfRankedMemRef(operandType, dimOp,
                                            adaptor.getOperands(), rewriter)});
      return success();
    }
    llvm_unreachable("expected MemRefType or UnrankedMemRefType");
  }

private:
  Value extractSizeOfUnrankedMemRef(Type operandType, memref::DimOp dimOp,
                                    OpAdaptor adaptor,
                                    ConversionPatternRewriter &rewriter) const {
    Location loc = dimOp.getLoc();

    auto unrankedMemRefType = operandType.cast<UnrankedMemRefType>();
    auto scalarMemRefType =
        MemRefType::get({}, unrankedMemRefType.getElementType());
    unsigned addressSpace = unrankedMemRefType.getMemorySpaceAsInt();

    // Extract pointer to the underlying ranked descriptor and bitcast it to a
    // memref<element_type> descriptor pointer to minimize the number of GEP
    // operations.
    UnrankedMemRefDescriptor unrankedDesc(adaptor.source());
    Value underlyingRankedDesc = unrankedDesc.memRefDescPtr(rewriter, loc);
    Value scalarMemRefDescPtr = rewriter.create<LLVM::BitcastOp>(
        loc,
        LLVM::LLVMPointerType::get(typeConverter->convertType(scalarMemRefType),
                                   addressSpace),
        underlyingRankedDesc);

    // Get pointer to offset field of memref<element_type> descriptor.
    Type indexPtrTy = LLVM::LLVMPointerType::get(
        getTypeConverter()->getIndexType(), addressSpace);
    Value two = rewriter.create<LLVM::ConstantOp>(
        loc, typeConverter->convertType(rewriter.getI32Type()),
        rewriter.getI32IntegerAttr(2));
    Value offsetPtr = rewriter.create<LLVM::GEPOp>(
        loc, indexPtrTy, scalarMemRefDescPtr,
        ValueRange({createIndexConstant(rewriter, loc, 0), two}));

    // The size value that we have to extract can be obtained using GEPOp with
    // `dimOp.index() + 1` as the index argument.
    Value idxPlusOne = rewriter.create<LLVM::AddOp>(
        loc, createIndexConstant(rewriter, loc, 1), adaptor.index());
    Value sizePtr = rewriter.create<LLVM::GEPOp>(loc, indexPtrTy, offsetPtr,
                                                 ValueRange({idxPlusOne}));
    return rewriter.create<LLVM::LoadOp>(loc, sizePtr);
  }

  Optional<int64_t> getConstantDimIndex(memref::DimOp dimOp) const {
    if (Optional<int64_t> idx = dimOp.getConstantIndex())
      return idx;

    if (auto constantOp = dimOp.index().getDefiningOp<LLVM::ConstantOp>())
      return constantOp.getValue()
          .cast<IntegerAttr>()
          .getValue()
          .getSExtValue();

    return llvm::None;
  }

  Value extractSizeOfRankedMemRef(Type operandType, memref::DimOp dimOp,
                                  OpAdaptor adaptor,
                                  ConversionPatternRewriter &rewriter) const {
    Location loc = dimOp.getLoc();

    // Take advantage of the index being constant, if possible.
    MemRefType memRefType = operandType.cast<MemRefType>();
    if (Optional<int64_t> index = getConstantDimIndex(dimOp)) {
      int64_t i = index.getValue();
      if (memRefType.isDynamicDim(i)) {
        // Extract the dynamic size from the memref descriptor.
        MemRefDescriptor descriptor(adaptor.source());
        return descriptor.size(rewriter, loc, i);
      }
      // Use constant for static size.
      int64_t dimSize = memRefType.getDimSize(i);
      return createIndexConstant(rewriter, loc, dimSize);
    }
    Value index = adaptor.index();
    int64_t rank = memRefType.getRank();
    MemRefDescriptor memrefDescriptor(adaptor.source());
    return memrefDescriptor.size(rewriter, loc, index, rank);
  }
};

/// Returns the LLVM type of the global variable given the memref type `type`.
static Type convertGlobalMemrefTypeToLLVM(MemRefType type,
                                          LLVMTypeConverter &typeConverter) {
  // The LLVM type for a global memref will be a multi-dimensional array. For
  // declarations or uninitialized global memrefs, we can potentially flatten
  // this to a 1-D array. However, for memref.global's with an initial value,
  // we do not intend to flatten the ElementsAttribute when going from std ->
  // LLVM dialect, so the LLVM type needs to be a multi-dimensional array.
  Type elementType = typeConverter.convertType(type.getElementType());
  Type arrayTy = elementType;
  // The shape has the outermost dim at index 0, so we need to walk it
  // backwards.
  for (int64_t dim : llvm::reverse(type.getShape()))
    arrayTy = LLVM::LLVMArrayType::get(arrayTy, dim);
  return arrayTy;
}

/// GlobalMemrefOp is lowered to an LLVM global variable.
struct GlobalMemrefOpLowering
    : public ConvertOpToLLVMPattern<memref::GlobalOp> {
  using ConvertOpToLLVMPattern<memref::GlobalOp>::ConvertOpToLLVMPattern;

  LogicalResult
  matchAndRewrite(memref::GlobalOp global, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    MemRefType type = global.type();
    if (!isConvertibleAndHasIdentityMaps(type))
      return failure();

    Type arrayTy = convertGlobalMemrefTypeToLLVM(type, *getTypeConverter());

    LLVM::Linkage linkage =
        global.isPublic() ? LLVM::Linkage::External : LLVM::Linkage::Private;

    Attribute initialValue = nullptr;
    if (!global.isExternal() && !global.isUninitialized()) {
      auto elementsAttr = global.initial_value()->cast<ElementsAttr>();
      initialValue = elementsAttr;

      // For scalar memrefs, the global variable created is of the element
      // type, so unpack the elements attribute to extract the value.
      if (type.getRank() == 0)
        initialValue = elementsAttr.getSplatValue<Attribute>();
    }

    uint64_t alignment = global.alignment().getValueOr(0);

    auto newGlobal = rewriter.replaceOpWithNewOp<LLVM::GlobalOp>(
        global, arrayTy, global.constant(), linkage, global.sym_name(),
        initialValue, alignment, type.getMemorySpaceAsInt());
    if (!global.isExternal() && global.isUninitialized()) {
      Block *blk = new Block();
      newGlobal.getInitializerRegion().push_back(blk);
      rewriter.setInsertionPointToStart(blk);
      Value undef[] = {
          rewriter.create<LLVM::UndefOp>(global.getLoc(), arrayTy)};
      rewriter.create<LLVM::ReturnOp>(global.getLoc(), undef);
    }
    return success();
  }
};
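
// For illustration only (a rough sketch; the exact printed form may differ):
// a definition such as
//   memref.global "private" constant @cst : memref<2xf32> = dense<[1.0, 2.0]>
// is expected to become roughly
//   llvm.mlir.global private constant @cst(dense<[1.0, 2.0]> : tensor<2xf32>)
//       : !llvm.array<2 x f32>
// while an uninitialized (non-external) global gets an initializer region that
// returns an undef value of the array type.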

/// GetGlobalMemrefOp is lowered into a MemRef descriptor with the pointer to
/// the first element stashed into the descriptor. This derives from
/// `AllocLikeOpLLVMLowering` so that the MemRef descriptor construction is
/// reused.
struct GetGlobalMemrefOpLowering : public AllocLikeOpLLVMLowering {
  GetGlobalMemrefOpLowering(LLVMTypeConverter &converter)
      : AllocLikeOpLLVMLowering(memref::GetGlobalOp::getOperationName(),
                                converter) {}

  /// Buffer "allocation" for memref.get_global op is getting the address of
  /// the global variable referenced.
  std::tuple<Value, Value> allocateBuffer(ConversionPatternRewriter &rewriter,
                                          Location loc, Value sizeBytes,
                                          Operation *op) const override {
    auto getGlobalOp = cast<memref::GetGlobalOp>(op);
    MemRefType type = getGlobalOp.result().getType().cast<MemRefType>();
    unsigned memSpace = type.getMemorySpaceAsInt();

    Type arrayTy = convertGlobalMemrefTypeToLLVM(type, *getTypeConverter());
    auto addressOf = rewriter.create<LLVM::AddressOfOp>(
        loc, LLVM::LLVMPointerType::get(arrayTy, memSpace), getGlobalOp.name());

    // Get the address of the first element in the array by creating a GEP with
    // the address of the GV as the base, and (rank + 1) number of 0 indices.
    Type elementType = typeConverter->convertType(type.getElementType());
    Type elementPtrType = LLVM::LLVMPointerType::get(elementType, memSpace);

    SmallVector<Value> operands;
    operands.insert(operands.end(), type.getRank() + 1,
                    createIndexConstant(rewriter, loc, 0));
    auto gep =
        rewriter.create<LLVM::GEPOp>(loc, elementPtrType, addressOf, operands);

    // We do not expect the memref obtained using `memref.get_global` to be
    // ever deallocated. Set the allocated pointer to a known bad value to
    // help debug if that ever happens.
    auto intPtrType = getIntPtrType(memSpace);
    Value deadBeefConst =
        createIndexAttrConstant(rewriter, op->getLoc(), intPtrType, 0xdeadbeef);
    auto deadBeefPtr =
        rewriter.create<LLVM::IntToPtrOp>(loc, elementPtrType, deadBeefConst);

    // Both the allocated and aligned pointers are the same. We could
    // potentially stash a nullptr for the allocated pointer since we do not
    // expect any dealloc.
    return std::make_tuple(deadBeefPtr, gep);
  }
};

// Common base for load and store operations on MemRefs. Restricts the match
// to supported MemRef types. Provides functionality to emit code accessing a
// specific element of the underlying data buffer.
template <typename Derived>
struct LoadStoreOpLowering : public ConvertOpToLLVMPattern<Derived> {
  using ConvertOpToLLVMPattern<Derived>::ConvertOpToLLVMPattern;
  using ConvertOpToLLVMPattern<Derived>::isConvertibleAndHasIdentityMaps;
  using Base = LoadStoreOpLowering<Derived>;

  LogicalResult match(Derived op) const override {
    MemRefType type = op.getMemRefType();
    return isConvertibleAndHasIdentityMaps(type) ? success() : failure();
  }
};

// The load operation is lowered to obtaining a pointer to the indexed element
// and loading it.
struct LoadOpLowering : public LoadStoreOpLowering<memref::LoadOp> {
  using Base::Base;

  LogicalResult
  matchAndRewrite(memref::LoadOp loadOp, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    auto type = loadOp.getMemRefType();

    Value dataPtr = getStridedElementPtr(
        loadOp.getLoc(), type, adaptor.memref(), adaptor.indices(), rewriter);
    rewriter.replaceOpWithNewOp<LLVM::LoadOp>(loadOp, dataPtr);
    return success();
  }
};
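
// For illustration only (a rough sketch, assuming an identity layout): a load
// such as
//   %v = memref.load %m[%i, %j] : memref<4x?xf32>
// is expected to lower roughly to a linearized address computation followed
// by an llvm.load:
//   %aligned = llvm.extractvalue %desc[1] : !llvm.struct<...>   // aligned ptr
//   %lin     = <offset + %i * stride0 + %j * stride1>
//   %addr    = llvm.getelementptr %aligned[%lin] : ... -> !llvm.ptr<f32>
//   %v       = llvm.load %addr : !llvm.ptr<f32>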

// The store operation is lowered to obtaining a pointer to the indexed
// element, and storing the given value to it.
struct StoreOpLowering : public LoadStoreOpLowering<memref::StoreOp> {
  using Base::Base;

  LogicalResult
  matchAndRewrite(memref::StoreOp op, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    auto type = op.getMemRefType();

    Value dataPtr = getStridedElementPtr(op.getLoc(), type, adaptor.memref(),
                                         adaptor.indices(), rewriter);
    rewriter.replaceOpWithNewOp<LLVM::StoreOp>(op, adaptor.value(), dataPtr);
    return success();
  }
};

// The prefetch operation is lowered in a way similar to the load operation,
// except that the llvm.prefetch operation is used for replacement.
struct PrefetchOpLowering : public LoadStoreOpLowering<memref::PrefetchOp> {
  using Base::Base;

  LogicalResult
  matchAndRewrite(memref::PrefetchOp prefetchOp, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    auto type = prefetchOp.getMemRefType();
    auto loc = prefetchOp.getLoc();

    Value dataPtr = getStridedElementPtr(loc, type, adaptor.memref(),
                                         adaptor.indices(), rewriter);

    // Replace with llvm.prefetch.
    auto llvmI32Type = typeConverter->convertType(rewriter.getIntegerType(32));
    auto isWrite = rewriter.create<LLVM::ConstantOp>(
        loc, llvmI32Type, rewriter.getI32IntegerAttr(prefetchOp.isWrite()));
    auto localityHint = rewriter.create<LLVM::ConstantOp>(
        loc, llvmI32Type,
        rewriter.getI32IntegerAttr(prefetchOp.localityHint()));
    auto isData = rewriter.create<LLVM::ConstantOp>(
        loc, llvmI32Type, rewriter.getI32IntegerAttr(prefetchOp.isDataCache()));

    rewriter.replaceOpWithNewOp<LLVM::Prefetch>(prefetchOp, dataPtr, isWrite,
                                                localityHint, isData);
    return success();
  }
};

struct RankOpLowering : public ConvertOpToLLVMPattern<memref::RankOp> {
  using ConvertOpToLLVMPattern<memref::RankOp>::ConvertOpToLLVMPattern;

  LogicalResult
  matchAndRewrite(memref::RankOp op, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    Location loc = op.getLoc();
    Type operandType = op.memref().getType();
    if (auto unrankedMemRefType = operandType.dyn_cast<UnrankedMemRefType>()) {
      UnrankedMemRefDescriptor desc(adaptor.memref());
      rewriter.replaceOp(op, {desc.rank(rewriter, loc)});
      return success();
    }
    if (auto rankedMemRefType = operandType.dyn_cast<MemRefType>()) {
      rewriter.replaceOp(
          op, {createIndexConstant(rewriter, loc, rankedMemRefType.getRank())});
      return success();
    }
    return failure();
  }
};

struct MemRefCastOpLowering : public ConvertOpToLLVMPattern<memref::CastOp> {
  using ConvertOpToLLVMPattern<memref::CastOp>::ConvertOpToLLVMPattern;

  LogicalResult match(memref::CastOp memRefCastOp) const override {
    Type srcType = memRefCastOp.getOperand().getType();
    Type dstType = memRefCastOp.getType();

    // memref::CastOp reduces to a bitcast in the ranked MemRef case and can be
    // used for type erasure. For now it must preserve the underlying element
    // type and requires the source and result types to have the same rank.
    // Therefore, perform a sanity check that the underlying structs are the
    // same. Once the op semantics are relaxed we can revisit this.
    if (srcType.isa<MemRefType>() && dstType.isa<MemRefType>())
      return success(typeConverter->convertType(srcType) ==
                     typeConverter->convertType(dstType));

    // At least one of the operands is an unranked memref type.
    assert(srcType.isa<UnrankedMemRefType>() ||
           dstType.isa<UnrankedMemRefType>());

    // Unranked-to-unranked casts are disallowed.
    return !(srcType.isa<UnrankedMemRefType>() &&
             dstType.isa<UnrankedMemRefType>())
               ? success()
               : failure();
  }

  void rewrite(memref::CastOp memRefCastOp, OpAdaptor adaptor,
               ConversionPatternRewriter &rewriter) const override {
    auto srcType = memRefCastOp.getOperand().getType();
    auto dstType = memRefCastOp.getType();
    auto targetStructType = typeConverter->convertType(memRefCastOp.getType());
    auto loc = memRefCastOp.getLoc();

    // For the ranked/ranked case, just keep the original descriptor.
    if (srcType.isa<MemRefType>() && dstType.isa<MemRefType>())
      return rewriter.replaceOp(memRefCastOp, {adaptor.source()});

    if (srcType.isa<MemRefType>() && dstType.isa<UnrankedMemRefType>()) {
      // Casting a ranked memref type to an unranked one:
      // - set the rank in the destination from the memref type;
      // - allocate space on the stack and copy the src memref descriptor;
      // - set the ptr in the destination to the stack space.
      auto srcMemRefType = srcType.cast<MemRefType>();
      int64_t rank = srcMemRefType.getRank();
      // ptr = AllocaOp sizeof(MemRefDescriptor)
      auto ptr = getTypeConverter()->promoteOneMemRefDescriptor(
          loc, adaptor.source(), rewriter);
      // voidptr = BitCastOp srcType* to void*
      auto voidPtr =
          rewriter.create<LLVM::BitcastOp>(loc, getVoidPtrType(), ptr)
              .getResult();
      // rank = ConstantOp srcRank
      auto rankVal = rewriter.create<LLVM::ConstantOp>(
          loc, typeConverter->convertType(rewriter.getIntegerType(64)),
          rewriter.getI64IntegerAttr(rank));
      // undef = UndefOp
      UnrankedMemRefDescriptor memRefDesc =
          UnrankedMemRefDescriptor::undef(rewriter, loc, targetStructType);
      // d1 = InsertValueOp undef, rank, 0
      memRefDesc.setRank(rewriter, loc, rankVal);
      // d2 = InsertValueOp d1, voidptr, 1
      memRefDesc.setMemRefDescPtr(rewriter, loc, voidPtr);
      rewriter.replaceOp(memRefCastOp, (Value)memRefDesc);

    } else if (srcType.isa<UnrankedMemRefType>() && dstType.isa<MemRefType>()) {
      // Casting from an unranked type to a ranked one. The operation is
      // assumed to be doing a correct cast: if the destination type does not
      // match the type of the underlying ranked descriptor, the behavior is
      // undefined.
      UnrankedMemRefDescriptor memRefDesc(adaptor.source());
      // ptr = ExtractValueOp src, 1
      auto ptr = memRefDesc.memRefDescPtr(rewriter, loc);
      // castPtr = BitCastOp i8* to structTy*
      auto castPtr =
          rewriter
              .create<LLVM::BitcastOp>(
                  loc, LLVM::LLVMPointerType::get(targetStructType), ptr)
              .getResult();
      // struct = LoadOp castPtr
      auto loadOp = rewriter.create<LLVM::LoadOp>(loc, castPtr);
      rewriter.replaceOp(memRefCastOp, loadOp.getResult());
    } else {
      llvm_unreachable("Unsupported unranked memref to unranked memref cast");
    }
  }
};
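
// For illustration only (a rough sketch): casting a ranked memref to an
// unranked one, e.g.
//   %u = memref.cast %m : memref<4xf32> to memref<*xf32>
// is lowered roughly by spilling the ranked descriptor to the stack and
// packing its address together with the rank:
//   %p    = llvm.alloca ...                     // copy of the ranked descriptor
//   %void = llvm.bitcast %p : !llvm.ptr<struct<...>> to !llvm.ptr<i8>
//   %u0   = llvm.mlir.undef : !llvm.struct<(i64, ptr<i8>)>
//   %u1   = llvm.insertvalue %rank, %u0[0] : !llvm.struct<(i64, ptr<i8>)>
//   %u2   = llvm.insertvalue %void, %u1[1] : !llvm.struct<(i64, ptr<i8>)>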

/// Pattern to lower a `memref.copy` to LLVM.
///
/// For memrefs with identity layouts, the copy is lowered to the LLVM
/// `memcpy` intrinsic. For non-identity layouts, the copy is lowered to a call
/// to the generic `MemrefCopyFn`.
struct MemRefCopyOpLowering : public ConvertOpToLLVMPattern<memref::CopyOp> {
  using ConvertOpToLLVMPattern<memref::CopyOp>::ConvertOpToLLVMPattern;

  LogicalResult
  lowerToMemCopyIntrinsic(memref::CopyOp op, OpAdaptor adaptor,
                          ConversionPatternRewriter &rewriter) const {
    auto loc = op.getLoc();
    auto srcType = op.source().getType().dyn_cast<MemRefType>();

    MemRefDescriptor srcDesc(adaptor.source());

    // Compute the number of elements.
    Value numElements = rewriter.create<LLVM::ConstantOp>(
        loc, getIndexType(), rewriter.getIndexAttr(1));
    for (int pos = 0; pos < srcType.getRank(); ++pos) {
      auto size = srcDesc.size(rewriter, loc, pos);
      numElements = rewriter.create<LLVM::MulOp>(loc, numElements, size);
    }

    // Get the element size.
    auto sizeInBytes = getSizeInBytes(loc, srcType.getElementType(), rewriter);
    // Compute the total size in bytes.
    Value totalSize =
        rewriter.create<LLVM::MulOp>(loc, numElements, sizeInBytes);

    Value srcBasePtr = srcDesc.alignedPtr(rewriter, loc);
    MemRefDescriptor targetDesc(adaptor.target());
    Value targetBasePtr = targetDesc.alignedPtr(rewriter, loc);
    Value isVolatile = rewriter.create<LLVM::ConstantOp>(
        loc, typeConverter->convertType(rewriter.getI1Type()),
        rewriter.getBoolAttr(false));
    rewriter.create<LLVM::MemcpyOp>(loc, targetBasePtr, srcBasePtr, totalSize,
                                    isVolatile);
    rewriter.eraseOp(op);

    return success();
  }

  LogicalResult
  lowerToMemCopyFunctionCall(memref::CopyOp op, OpAdaptor adaptor,
                             ConversionPatternRewriter &rewriter) const {
    auto loc = op.getLoc();
    auto srcType = op.source().getType().cast<BaseMemRefType>();
    auto targetType = op.target().getType().cast<BaseMemRefType>();

    // First make sure we have an unranked memref descriptor representation.
    auto makeUnranked = [&, this](Value ranked, BaseMemRefType type) {
      auto rank = rewriter.create<LLVM::ConstantOp>(
          loc, getIndexType(), rewriter.getIndexAttr(type.getRank()));
      auto *typeConverter = getTypeConverter();
      auto ptr =
          typeConverter->promoteOneMemRefDescriptor(loc, ranked, rewriter);
      auto voidPtr =
          rewriter.create<LLVM::BitcastOp>(loc, getVoidPtrType(), ptr)
              .getResult();
      auto unrankedType =
          UnrankedMemRefType::get(type.getElementType(), type.getMemorySpace());
      return UnrankedMemRefDescriptor::pack(rewriter, loc, *typeConverter,
                                            unrankedType,
                                            ValueRange{rank, voidPtr});
    };

    Value unrankedSource = srcType.hasRank()
                               ? makeUnranked(adaptor.source(), srcType)
                               : adaptor.source();
    Value unrankedTarget = targetType.hasRank()
                               ? makeUnranked(adaptor.target(), targetType)
                               : adaptor.target();

    // Now promote the unranked descriptors to the stack.
    auto one = rewriter.create<LLVM::ConstantOp>(loc, getIndexType(),
                                                 rewriter.getIndexAttr(1));
    auto promote = [&](Value desc) {
      auto ptrType = LLVM::LLVMPointerType::get(desc.getType());
      auto allocated =
          rewriter.create<LLVM::AllocaOp>(loc, ptrType, ValueRange{one});
      rewriter.create<LLVM::StoreOp>(loc, desc, allocated);
      return allocated;
    };

    auto sourcePtr = promote(unrankedSource);
    auto targetPtr = promote(unrankedTarget);

    auto elemSize = rewriter.create<LLVM::ConstantOp>(
        loc, getIndexType(),
        rewriter.getIndexAttr(srcType.getElementTypeBitWidth() / 8));
    auto copyFn = LLVM::lookupOrCreateMemRefCopyFn(
        op->getParentOfType<ModuleOp>(), getIndexType(), sourcePtr.getType());
    rewriter.create<LLVM::CallOp>(loc, copyFn,
                                  ValueRange{elemSize, sourcePtr, targetPtr});
    rewriter.eraseOp(op);

    return success();
  }

  LogicalResult
  matchAndRewrite(memref::CopyOp op, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    auto srcType = op.source().getType().cast<BaseMemRefType>();
    auto targetType = op.target().getType().cast<BaseMemRefType>();

    if (srcType.hasRank() &&
        srcType.cast<MemRefType>().getLayout().isIdentity() &&
        targetType.hasRank() &&
        targetType.cast<MemRefType>().getLayout().isIdentity())
      return lowerToMemCopyIntrinsic(op, adaptor, rewriter);

    return lowerToMemCopyFunctionCall(op, adaptor, rewriter);
  }
};

/// Extracts the allocated and aligned pointers and the offset from a ranked or
/// unranked memref. In the unranked case, the fields are extracted from the
/// underlying ranked descriptor.
static void extractPointersAndOffset(Location loc,
                                     ConversionPatternRewriter &rewriter,
                                     LLVMTypeConverter &typeConverter,
                                     Value originalOperand,
                                     Value convertedOperand,
                                     Value *allocatedPtr, Value *alignedPtr,
                                     Value *offset = nullptr) {
  Type operandType = originalOperand.getType();
  if (operandType.isa<MemRefType>()) {
    MemRefDescriptor desc(convertedOperand);
    *allocatedPtr = desc.allocatedPtr(rewriter, loc);
    *alignedPtr = desc.alignedPtr(rewriter, loc);
    if (offset != nullptr)
      *offset = desc.offset(rewriter, loc);
    return;
  }

  unsigned memorySpace =
      operandType.cast<UnrankedMemRefType>().getMemorySpaceAsInt();
  Type elementType = operandType.cast<UnrankedMemRefType>().getElementType();
  Type llvmElementType = typeConverter.convertType(elementType);
  Type elementPtrPtrType = LLVM::LLVMPointerType::get(
      LLVM::LLVMPointerType::get(llvmElementType, memorySpace));

  // Extract the pointer to the underlying ranked memref descriptor and cast it
  // to ElemType**.
  UnrankedMemRefDescriptor unrankedDesc(convertedOperand);
  Value underlyingDescPtr = unrankedDesc.memRefDescPtr(rewriter, loc);

  *allocatedPtr = UnrankedMemRefDescriptor::allocatedPtr(
      rewriter, loc, underlyingDescPtr, elementPtrPtrType);
  *alignedPtr = UnrankedMemRefDescriptor::alignedPtr(
      rewriter, loc, typeConverter, underlyingDescPtr, elementPtrPtrType);
  if (offset != nullptr) {
    *offset = UnrankedMemRefDescriptor::offset(
        rewriter, loc, typeConverter, underlyingDescPtr, elementPtrPtrType);
  }
}

struct MemRefReinterpretCastOpLowering
    : public ConvertOpToLLVMPattern<memref::ReinterpretCastOp> {
  using ConvertOpToLLVMPattern<
      memref::ReinterpretCastOp>::ConvertOpToLLVMPattern;

  LogicalResult
  matchAndRewrite(memref::ReinterpretCastOp castOp, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    Type srcType = castOp.source().getType();

    Value descriptor;
    if (failed(convertSourceMemRefToDescriptor(rewriter, srcType, castOp,
                                               adaptor, &descriptor)))
      return failure();
    rewriter.replaceOp(castOp, {descriptor});
    return success();
  }

private:
  LogicalResult convertSourceMemRefToDescriptor(
      ConversionPatternRewriter &rewriter, Type srcType,
      memref::ReinterpretCastOp castOp,
      memref::ReinterpretCastOp::Adaptor adaptor, Value *descriptor) const {
    MemRefType targetMemRefType =
        castOp.getResult().getType().cast<MemRefType>();
    auto llvmTargetDescriptorTy = typeConverter->convertType(targetMemRefType)
                                      .dyn_cast_or_null<LLVM::LLVMStructType>();
    if (!llvmTargetDescriptorTy)
      return failure();

    // Create descriptor.
    Location loc = castOp.getLoc();
    auto desc = MemRefDescriptor::undef(rewriter, loc, llvmTargetDescriptorTy);

    // Set allocated and aligned pointers.
    Value allocatedPtr, alignedPtr;
    extractPointersAndOffset(loc, rewriter, *getTypeConverter(),
                             castOp.source(), adaptor.source(), &allocatedPtr,
                             &alignedPtr);
    desc.setAllocatedPtr(rewriter, loc, allocatedPtr);
    desc.setAlignedPtr(rewriter, loc, alignedPtr);

    // Set offset.
    if (castOp.isDynamicOffset(0))
      desc.setOffset(rewriter, loc, adaptor.offsets()[0]);
    else
      desc.setConstantOffset(rewriter, loc, castOp.getStaticOffset(0));

    // Set sizes and strides.
    unsigned dynSizeId = 0;
    unsigned dynStrideId = 0;
    for (unsigned i = 0, e = targetMemRefType.getRank(); i < e; ++i) {
      if (castOp.isDynamicSize(i))
        desc.setSize(rewriter, loc, i, adaptor.sizes()[dynSizeId++]);
      else
        desc.setConstantSize(rewriter, loc, i, castOp.getStaticSize(i));

      if (castOp.isDynamicStride(i))
        desc.setStride(rewriter, loc, i, adaptor.strides()[dynStrideId++]);
      else
        desc.setConstantStride(rewriter, loc, i, castOp.getStaticStride(i));
    }
    *descriptor = desc;
    return success();
  }
};

struct MemRefReshapeOpLowering
    : public ConvertOpToLLVMPattern<memref::ReshapeOp> {
  using ConvertOpToLLVMPattern<memref::ReshapeOp>::ConvertOpToLLVMPattern;

  LogicalResult
  matchAndRewrite(memref::ReshapeOp reshapeOp, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    Type srcType = reshapeOp.source().getType();

    Value descriptor;
    if (failed(convertSourceMemRefToDescriptor(rewriter, srcType, reshapeOp,
                                               adaptor, &descriptor)))
      return failure();
    rewriter.replaceOp(reshapeOp, {descriptor});
    return success();
  }

private:
  LogicalResult
  convertSourceMemRefToDescriptor(ConversionPatternRewriter &rewriter,
                                  Type srcType, memref::ReshapeOp reshapeOp,
                                  memref::ReshapeOp::Adaptor adaptor,
                                  Value *descriptor) const {
    // Conversion for statically-known shape args is performed via
    // `memref_reinterpret_cast`.
    auto shapeMemRefType = reshapeOp.shape().getType().cast<MemRefType>();
    if (shapeMemRefType.hasStaticShape())
      return failure();

    // The shape is a rank-1 memref with unknown length.
    Location loc = reshapeOp.getLoc();
    MemRefDescriptor shapeDesc(adaptor.shape());
    Value resultRank = shapeDesc.size(rewriter, loc, 0);

    // Extract address space and element type.
    auto targetType =
        reshapeOp.getResult().getType().cast<UnrankedMemRefType>();
    unsigned addressSpace = targetType.getMemorySpaceAsInt();
    Type elementType = targetType.getElementType();

    // Create the unranked memref descriptor that holds the ranked one. The
    // inner descriptor is allocated on the stack.
    auto targetDesc = UnrankedMemRefDescriptor::undef(
        rewriter, loc, typeConverter->convertType(targetType));
    targetDesc.setRank(rewriter, loc, resultRank);
    SmallVector<Value, 4> sizes;
    UnrankedMemRefDescriptor::computeSizes(rewriter, loc, *getTypeConverter(),
                                           targetDesc, sizes);
    Value underlyingDescPtr = rewriter.create<LLVM::AllocaOp>(
        loc, getVoidPtrType(), sizes.front(), llvm::None);
    targetDesc.setMemRefDescPtr(rewriter, loc, underlyingDescPtr);

    // Extract pointers and offset from the source memref.
    Value allocatedPtr, alignedPtr, offset;
    extractPointersAndOffset(loc, rewriter, *getTypeConverter(),
                             reshapeOp.source(), adaptor.source(),
                             &allocatedPtr, &alignedPtr, &offset);

    // Set pointers and offset.
    Type llvmElementType = typeConverter->convertType(elementType);
    auto elementPtrPtrType = LLVM::LLVMPointerType::get(
        LLVM::LLVMPointerType::get(llvmElementType, addressSpace));
    UnrankedMemRefDescriptor::setAllocatedPtr(rewriter, loc, underlyingDescPtr,
                                              elementPtrPtrType, allocatedPtr);
    UnrankedMemRefDescriptor::setAlignedPtr(rewriter, loc, *getTypeConverter(),
                                            underlyingDescPtr,
                                            elementPtrPtrType, alignedPtr);
    UnrankedMemRefDescriptor::setOffset(rewriter, loc, *getTypeConverter(),
                                        underlyingDescPtr, elementPtrPtrType,
                                        offset);

    // Use the offset pointer as base for further addressing. Copy over the new
    // shape and compute strides. For this, we create a loop from rank-1 to 0.
    Value targetSizesBase = UnrankedMemRefDescriptor::sizeBasePtr(
        rewriter, loc, *getTypeConverter(), underlyingDescPtr,
        elementPtrPtrType);
    Value targetStridesBase = UnrankedMemRefDescriptor::strideBasePtr(
        rewriter, loc, *getTypeConverter(), targetSizesBase, resultRank);
    Value shapeOperandPtr = shapeDesc.alignedPtr(rewriter, loc);
    Value oneIndex = createIndexConstant(rewriter, loc, 1);
    Value resultRankMinusOne =
        rewriter.create<LLVM::SubOp>(loc, resultRank, oneIndex);

    Block *initBlock = rewriter.getInsertionBlock();
    Type indexType = getTypeConverter()->getIndexType();
    Block::iterator remainingOpsIt = std::next(rewriter.getInsertionPoint());

    Block *condBlock = rewriter.createBlock(initBlock->getParent(), {},
                                            {indexType, indexType});

    // Move the remaining initBlock ops to condBlock.
    Block *remainingBlock = rewriter.splitBlock(initBlock, remainingOpsIt);
    rewriter.mergeBlocks(remainingBlock, condBlock, ValueRange());

    rewriter.setInsertionPointToEnd(initBlock);
    rewriter.create<LLVM::BrOp>(loc, ValueRange({resultRankMinusOne, oneIndex}),
                                condBlock);
    rewriter.setInsertionPointToStart(condBlock);
    Value indexArg = condBlock->getArgument(0);
    Value strideArg = condBlock->getArgument(1);

    Value zeroIndex = createIndexConstant(rewriter, loc, 0);
    Value pred = rewriter.create<LLVM::ICmpOp>(
        loc, IntegerType::get(rewriter.getContext(), 1),
        LLVM::ICmpPredicate::sge, indexArg, zeroIndex);

    Block *bodyBlock =
        rewriter.splitBlock(condBlock, rewriter.getInsertionPoint());
    rewriter.setInsertionPointToStart(bodyBlock);

    // Copy size from shape to descriptor.
    Type llvmIndexPtrType = LLVM::LLVMPointerType::get(indexType);
    Value sizeLoadGep = rewriter.create<LLVM::GEPOp>(
        loc, llvmIndexPtrType, shapeOperandPtr, ValueRange{indexArg});
    Value size = rewriter.create<LLVM::LoadOp>(loc, sizeLoadGep);
    UnrankedMemRefDescriptor::setSize(rewriter, loc, *getTypeConverter(),
                                      targetSizesBase, indexArg, size);

    // Write stride value and compute next one.
    UnrankedMemRefDescriptor::setStride(rewriter, loc, *getTypeConverter(),
                                        targetStridesBase, indexArg, strideArg);
    Value nextStride = rewriter.create<LLVM::MulOp>(loc, strideArg, size);

    // Decrement loop counter and branch back.
    Value decrement = rewriter.create<LLVM::SubOp>(loc, indexArg, oneIndex);
    rewriter.create<LLVM::BrOp>(loc, ValueRange({decrement, nextStride}),
                                condBlock);

    Block *remainder =
        rewriter.splitBlock(bodyBlock, rewriter.getInsertionPoint());

    // Hook up the cond exit to the remainder.
    rewriter.setInsertionPointToEnd(condBlock);
    rewriter.create<LLVM::CondBrOp>(loc, pred, bodyBlock, llvm::None, remainder,
                                    llvm::None);

    // Reset the insertion point to the beginning of the new remainder block.
    rewriter.setInsertionPointToStart(remainder);

    *descriptor = targetDesc;
    return success();
  }
};

/// Helper function to convert a vector of `OpFoldResult`s into a vector of
/// `Value`s.
static SmallVector<Value> getAsValues(OpBuilder &b, Location loc,
                                      Type &llvmIndexType,
                                      ArrayRef<OpFoldResult> valueOrAttrVec) {
  return llvm::to_vector<4>(
      llvm::map_range(valueOrAttrVec, [&](OpFoldResult value) -> Value {
        if (auto attr = value.dyn_cast<Attribute>())
          return b.create<LLVM::ConstantOp>(loc, llvmIndexType, attr);
        return value.get<Value>();
      }));
}

/// Computes a map that, for a given dimension of the expanded type, gives the
/// dimension in the collapsed type it maps to. Essentially it is the inverse
/// of the `reassociation` maps.
static DenseMap<int64_t, int64_t>
getExpandedDimToCollapsedDimMap(ArrayRef<ReassociationIndices> reassociation) {
  llvm::DenseMap<int64_t, int64_t> expandedDimToCollapsedDim;
  for (auto &en : enumerate(reassociation)) {
    for (auto dim : en.value())
      expandedDimToCollapsedDim[dim] = en.index();
  }
  return expandedDimToCollapsedDim;
}

static OpFoldResult
getExpandedOutputDimSize(OpBuilder &b, Location loc, Type &llvmIndexType,
                         int64_t outDimIndex, ArrayRef<int64_t> outStaticShape,
                         MemRefDescriptor &inDesc,
                         ArrayRef<int64_t> inStaticShape,
                         ArrayRef<ReassociationIndices> reassocation,
                         DenseMap<int64_t, int64_t> &outDimToInDimMap) {
  int64_t outDimSize = outStaticShape[outDimIndex];
  if (!ShapedType::isDynamic(outDimSize))
    return b.getIndexAttr(outDimSize);

  // Calculate the product of all the output dim sizes in the same
  // reassociation group, except the current dim.
  int64_t inDimIndex = outDimToInDimMap[outDimIndex];
  int64_t otherDimSizesMul = 1;
  for (auto otherDimIndex : reassocation[inDimIndex]) {
    if (otherDimIndex == static_cast<unsigned>(outDimIndex))
      continue;
    int64_t otherDimSize = outStaticShape[otherDimIndex];
    assert(!ShapedType::isDynamic(otherDimSize) &&
           "single dimension cannot be expanded into multiple dynamic "
           "dimensions");
    otherDimSizesMul *= otherDimSize;
  }

  // outDimSize = inDimSize / otherOutDimSizesMul
  int64_t inDimSize = inStaticShape[inDimIndex];
  Value inDimSizeDynamic =
      ShapedType::isDynamic(inDimSize)
          ? inDesc.size(b, loc, inDimIndex)
          : b.create<LLVM::ConstantOp>(loc, llvmIndexType,
                                       b.getIndexAttr(inDimSize));
  Value outDimSizeDynamic = b.create<LLVM::SDivOp>(
      loc, inDimSizeDynamic,
      b.create<LLVM::ConstantOp>(loc, llvmIndexType,
                                 b.getIndexAttr(otherDimSizesMul)));
  return outDimSizeDynamic;
}

static OpFoldResult getCollapsedOutputDimSize(
    OpBuilder &b, Location loc, Type &llvmIndexType, int64_t outDimIndex,
    int64_t outDimSize, ArrayRef<int64_t> inStaticShape,
    MemRefDescriptor &inDesc, ArrayRef<ReassociationIndices> reassocation) {
  if (!ShapedType::isDynamic(outDimSize))
    return b.getIndexAttr(outDimSize);

  Value c1 = b.create<LLVM::ConstantOp>(loc, llvmIndexType, b.getIndexAttr(1));
  Value outDimSizeDynamic = c1;
  for (auto inDimIndex : reassocation[outDimIndex]) {
    int64_t inDimSize = inStaticShape[inDimIndex];
    Value inDimSizeDynamic =
        ShapedType::isDynamic(inDimSize)
            ? inDesc.size(b, loc, inDimIndex)
            : b.create<LLVM::ConstantOp>(loc, llvmIndexType,
                                         b.getIndexAttr(inDimSize));
    outDimSizeDynamic =
        b.create<LLVM::MulOp>(loc, outDimSizeDynamic, inDimSizeDynamic);
  }
  return outDimSizeDynamic;
}

static SmallVector<OpFoldResult, 4>
getCollapsedOutputShape(OpBuilder &b, Location loc, Type &llvmIndexType,
                        ArrayRef<ReassociationIndices> reassocation,
                        ArrayRef<int64_t> inStaticShape,
                        MemRefDescriptor &inDesc,
                        ArrayRef<int64_t> outStaticShape) {
  return llvm::to_vector<4>(llvm::map_range(
      llvm::seq<int64_t>(0, outStaticShape.size()), [&](int64_t outDimIndex) {
        return getCollapsedOutputDimSize(b, loc, llvmIndexType, outDimIndex,
                                         outStaticShape[outDimIndex],
                                         inStaticShape, inDesc, reassocation);
      }));
}

static SmallVector<OpFoldResult, 4>
getExpandedOutputShape(OpBuilder &b, Location loc, Type &llvmIndexType,
                       ArrayRef<ReassociationIndices> reassocation,
                       ArrayRef<int64_t> inStaticShape,
                       MemRefDescriptor &inDesc,
                       ArrayRef<int64_t> outStaticShape) {
  DenseMap<int64_t, int64_t> outDimToInDimMap =
      getExpandedDimToCollapsedDimMap(reassocation);
  return llvm::to_vector<4>(llvm::map_range(
      llvm::seq<int64_t>(0, outStaticShape.size()), [&](int64_t outDimIndex) {
        return getExpandedOutputDimSize(b, loc, llvmIndexType, outDimIndex,
                                        outStaticShape, inDesc, inStaticShape,
                                        reassocation, outDimToInDimMap);
      }));
}

static SmallVector<Value>
getDynamicOutputShape(OpBuilder &b, Location loc, Type &llvmIndexType,
                      ArrayRef<ReassociationIndices> reassocation,
                      ArrayRef<int64_t> inStaticShape, MemRefDescriptor &inDesc,
                      ArrayRef<int64_t> outStaticShape) {
  return outStaticShape.size() < inStaticShape.size()
             ? getAsValues(b, loc, llvmIndexType,
                           getCollapsedOutputShape(b, loc, llvmIndexType,
                                                   reassocation, inStaticShape,
                                                   inDesc, outStaticShape))
             : getAsValues(b, loc, llvmIndexType,
                           getExpandedOutputShape(b, loc, llvmIndexType,
                                                  reassocation, inStaticShape,
                                                  inDesc, outStaticShape));
}

// ReshapeOp creates a new view descriptor of the proper rank; sizes and
// strides that are not statically known are computed at run time from the
// source descriptor.
template <typename ReshapeOp>
class ReassociatingReshapeOpConversion
    : public ConvertOpToLLVMPattern<ReshapeOp> {
public:
  using ConvertOpToLLVMPattern<ReshapeOp>::ConvertOpToLLVMPattern;
  using ReshapeOpAdaptor = typename ReshapeOp::Adaptor;

  LogicalResult
  matchAndRewrite(ReshapeOp reshapeOp, typename ReshapeOp::Adaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    MemRefType dstType = reshapeOp.getResultType();
    MemRefType srcType = reshapeOp.getSrcType();

    // The condition on the layouts can be ignored when all shapes are static.
    if (!srcType.hasStaticShape() || !dstType.hasStaticShape()) {
      if (!srcType.getLayout().isIdentity() ||
          !dstType.getLayout().isIdentity()) {
        return rewriter.notifyMatchFailure(
            reshapeOp, "only empty layout map is supported");
      }
    }

    int64_t offset;
    SmallVector<int64_t, 4> strides;
    if (failed(getStridesAndOffset(dstType, strides, offset))) {
      return rewriter.notifyMatchFailure(
          reshapeOp, "failed to get stride and offset exprs");
    }

    MemRefDescriptor srcDesc(adaptor.src());
    Location loc = reshapeOp->getLoc();
    auto dstDesc = MemRefDescriptor::undef(
        rewriter, loc, this->typeConverter->convertType(dstType));
    dstDesc.setAllocatedPtr(rewriter, loc, srcDesc.allocatedPtr(rewriter, loc));
    dstDesc.setAlignedPtr(rewriter, loc, srcDesc.alignedPtr(rewriter, loc));
    dstDesc.setOffset(rewriter, loc, srcDesc.offset(rewriter, loc));

    ArrayRef<int64_t> srcStaticShape = srcType.getShape();
    ArrayRef<int64_t> dstStaticShape = dstType.getShape();
    Type llvmIndexType =
        this->typeConverter->convertType(rewriter.getIndexType());
    SmallVector<Value> dstShape = getDynamicOutputShape(
        rewriter, loc, llvmIndexType, reshapeOp.getReassociationIndices(),
        srcStaticShape, srcDesc, dstStaticShape);
    for (auto &en : llvm::enumerate(dstShape))
      dstDesc.setSize(rewriter, loc, en.index(), en.value());

    auto isStaticStride = [](int64_t stride) {
      return !ShapedType::isDynamicStrideOrOffset(stride);
    };
    if (llvm::all_of(strides, isStaticStride)) {
      for (auto &en : llvm::enumerate(strides))
        dstDesc.setConstantStride(rewriter, loc, en.index(), en.value());
    } else {
      Value c1 = rewriter.create<LLVM::ConstantOp>(loc, llvmIndexType,
                                                   rewriter.getIndexAttr(1));
      Value stride = c1;
      for (auto dimIndex :
           llvm::reverse(llvm::seq<int64_t>(0, dstShape.size()))) {
        dstDesc.setStride(rewriter, loc, dimIndex, stride);
        stride = rewriter.create<LLVM::MulOp>(loc, dstShape[dimIndex], stride);
      }
    }
    rewriter.replaceOp(reshapeOp, {dstDesc});
    return success();
  }
};
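
// For illustration only (a rough sketch): collapsing a statically shaped
// memref, e.g.
//   %0 = memref.collapse_shape %m [[0, 1], [2]]
//       : memref<2x3x4xf32> into memref<6x4xf32>
// reuses the source pointers and offset and fills in a fresh descriptor with
// sizes {6, 4} and static strides {4, 1}. Dynamic sizes are instead computed
// at run time by multiplying (collapse) or dividing (expand) the grouped
// source sizes, as implemented by the helpers above.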

/// Conversion pattern that transforms a subview op into:
///   1. An `llvm.mlir.undef` operation to create a memref descriptor.
///   2. Updates to the descriptor to introduce the data ptr, offset, size and
///      stride.
/// The subview op is replaced by the descriptor.
struct SubViewOpLowering : public ConvertOpToLLVMPattern<memref::SubViewOp> {
  using ConvertOpToLLVMPattern<memref::SubViewOp>::ConvertOpToLLVMPattern;

  LogicalResult
  matchAndRewrite(memref::SubViewOp subViewOp, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    auto loc = subViewOp.getLoc();

    auto sourceMemRefType = subViewOp.source().getType().cast<MemRefType>();
    auto sourceElementTy =
        typeConverter->convertType(sourceMemRefType.getElementType());

    auto viewMemRefType = subViewOp.getType();
    auto inferredType =
        memref::SubViewOp::inferResultType(
            subViewOp.getSourceType(),
            extractFromI64ArrayAttr(subViewOp.static_offsets()),
            extractFromI64ArrayAttr(subViewOp.static_sizes()),
            extractFromI64ArrayAttr(subViewOp.static_strides()))
            .cast<MemRefType>();
    auto targetElementTy =
        typeConverter->convertType(viewMemRefType.getElementType());
    auto targetDescTy = typeConverter->convertType(viewMemRefType);
    if (!sourceElementTy || !targetDescTy || !targetElementTy ||
        !LLVM::isCompatibleType(sourceElementTy) ||
        !LLVM::isCompatibleType(targetElementTy) ||
        !LLVM::isCompatibleType(targetDescTy))
      return failure();

    // Extract the offset and strides from the type.
    int64_t offset;
    SmallVector<int64_t, 4> strides;
    auto successStrides = getStridesAndOffset(inferredType, strides, offset);
    if (failed(successStrides))
      return failure();

    // Create the descriptor.
    if (!LLVM::isCompatibleType(adaptor.getOperands().front().getType()))
      return failure();
    MemRefDescriptor sourceMemRef(adaptor.getOperands().front());
    auto targetMemRef = MemRefDescriptor::undef(rewriter, loc, targetDescTy);

    // Copy the buffer pointer from the old descriptor to the new one.
    Value extracted = sourceMemRef.allocatedPtr(rewriter, loc);
    Value bitcastPtr = rewriter.create<LLVM::BitcastOp>(
        loc,
        LLVM::LLVMPointerType::get(targetElementTy,
                                   viewMemRefType.getMemorySpaceAsInt()),
        extracted);
    targetMemRef.setAllocatedPtr(rewriter, loc, bitcastPtr);

    // Copy the aligned pointer from the old descriptor to the new one.
    extracted = sourceMemRef.alignedPtr(rewriter, loc);
    bitcastPtr = rewriter.create<LLVM::BitcastOp>(
        loc,
        LLVM::LLVMPointerType::get(targetElementTy,
                                   viewMemRefType.getMemorySpaceAsInt()),
        extracted);
    targetMemRef.setAlignedPtr(rewriter, loc, bitcastPtr);

    size_t inferredShapeRank = inferredType.getRank();
    size_t resultShapeRank = viewMemRefType.getRank();

    // Extract strides needed to compute offset.
    SmallVector<Value, 4> strideValues;
    strideValues.reserve(inferredShapeRank);
    for (unsigned i = 0; i < inferredShapeRank; ++i)
      strideValues.push_back(sourceMemRef.stride(rewriter, loc, i));

    // Offset.
    auto llvmIndexType = typeConverter->convertType(rewriter.getIndexType());
    if (!ShapedType::isDynamicStrideOrOffset(offset)) {
      targetMemRef.setConstantOffset(rewriter, loc, offset);
    } else {
      Value baseOffset = sourceMemRef.offset(rewriter, loc);
      // `inferredShapeRank` may be larger than the number of offset operands
      // because of trailing semantics. In this case, the offset is guaranteed
      // to be interpreted as 0 and we can just skip the extra dimensions.
      for (unsigned i = 0, e = std::min(inferredShapeRank,
                                        subViewOp.getMixedOffsets().size());
           i < e; ++i) {
        Value offset =
            // TODO: need OpFoldResult ODS adaptor to clean this up.
            subViewOp.isDynamicOffset(i)
                ? adaptor.getOperands()[subViewOp.getIndexOfDynamicOffset(i)]
                : rewriter.create<LLVM::ConstantOp>(
                      loc, llvmIndexType,
                      rewriter.getI64IntegerAttr(subViewOp.getStaticOffset(i)));
        Value mul = rewriter.create<LLVM::MulOp>(loc, offset, strideValues[i]);
        baseOffset = rewriter.create<LLVM::AddOp>(loc, baseOffset, mul);
      }
      targetMemRef.setOffset(rewriter, loc, baseOffset);
    }

    // Update sizes and strides.
    SmallVector<OpFoldResult> mixedSizes = subViewOp.getMixedSizes();
    SmallVector<OpFoldResult> mixedStrides = subViewOp.getMixedStrides();
    assert(mixedSizes.size() == mixedStrides.size() &&
           "expected sizes and strides of equal length");
    llvm::SmallDenseSet<unsigned> unusedDims = subViewOp.getDroppedDims();
    for (int i = inferredShapeRank - 1, j = resultShapeRank - 1;
         i >= 0 && j >= 0; --i) {
      if (unusedDims.contains(i))
        continue;

      // `i` may exceed the number of entries in subViewOp.getMixedSizes()
      // because of trailing semantics. In this case, the size is guaranteed to
      // be interpreted as the full source dimension and the stride as 1.
      Value size, stride;
      if (static_cast<unsigned>(i) >= mixedSizes.size()) {
        // If the static size is available, use it directly. This is similar to
        // the folding of dim(constant-op) but removes the need for dim to be
        // aware of LLVM constants and for this pass to be aware of std
        // constants.
        int64_t staticSize =
            subViewOp.source().getType().cast<MemRefType>().getShape()[i];
        if (staticSize != ShapedType::kDynamicSize) {
          size = rewriter.create<LLVM::ConstantOp>(
              loc, llvmIndexType, rewriter.getI64IntegerAttr(staticSize));
        } else {
          Value pos = rewriter.create<LLVM::ConstantOp>(
              loc, llvmIndexType, rewriter.getI64IntegerAttr(i));
          Value dim =
              rewriter.create<memref::DimOp>(loc, subViewOp.source(), pos);
          auto cast = rewriter.create<UnrealizedConversionCastOp>(
              loc, llvmIndexType, dim);
          size = cast.getResult(0);
        }
        stride = rewriter.create<LLVM::ConstantOp>(
            loc, llvmIndexType, rewriter.getI64IntegerAttr(1));
      } else {
        // TODO: need OpFoldResult ODS adaptor to clean this up.
        size =
            subViewOp.isDynamicSize(i)
                ? adaptor.getOperands()[subViewOp.getIndexOfDynamicSize(i)]
                : rewriter.create<LLVM::ConstantOp>(
                      loc, llvmIndexType,
                      rewriter.getI64IntegerAttr(subViewOp.getStaticSize(i)));
        if (!ShapedType::isDynamicStrideOrOffset(strides[i])) {
          stride = rewriter.create<LLVM::ConstantOp>(
              loc, llvmIndexType, rewriter.getI64IntegerAttr(strides[i]));
        } else {
          stride =
              subViewOp.isDynamicStride(i)
                  ? adaptor.getOperands()[subViewOp.getIndexOfDynamicStride(i)]
                  : rewriter.create<LLVM::ConstantOp>(
                        loc, llvmIndexType,
                        rewriter.getI64IntegerAttr(
                            subViewOp.getStaticStride(i)));
          stride = rewriter.create<LLVM::MulOp>(loc, stride, strideValues[i]);
        }
      }
      targetMemRef.setSize(rewriter, loc, j, size);
      targetMemRef.setStride(rewriter, loc, j, stride);
      j--;
    }

    rewriter.replaceOp(subViewOp, {targetMemRef});
    return success();
  }
};

/// Conversion pattern that transforms a transpose op into:
///   1. An `llvm.mlir.undef` operation to create a memref descriptor
///   2. Updates to the descriptor to introduce the data ptr, offset, size
///      and stride. Size and stride are permutations of the original values.
/// The transpose op is replaced by the descriptor.
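///
/// For example (an illustrative sketch), transposing a 2-D memref with the
/// permutation map (d0, d1) -> (d1, d0) yields a descriptor that reuses the
/// source pointers and offset but with the size and stride of dimensions 0
/// and 1 swapped.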
class TransposeOpLowering : public ConvertOpToLLVMPattern<memref::TransposeOp> {
public:
  using ConvertOpToLLVMPattern<memref::TransposeOp>::ConvertOpToLLVMPattern;

  LogicalResult
  matchAndRewrite(memref::TransposeOp transposeOp, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    auto loc = transposeOp.getLoc();
    MemRefDescriptor viewMemRef(adaptor.in());

    // No permutation, early exit.
    if (transposeOp.permutation().isIdentity())
      return rewriter.replaceOp(transposeOp, {viewMemRef}), success();

    auto targetMemRef = MemRefDescriptor::undef(
        rewriter, loc, typeConverter->convertType(transposeOp.getShapedType()));

    // Copy the base and aligned pointers from the old descriptor to the new
    // one.
    targetMemRef.setAllocatedPtr(rewriter, loc,
                                 viewMemRef.allocatedPtr(rewriter, loc));
    targetMemRef.setAlignedPtr(rewriter, loc,
                               viewMemRef.alignedPtr(rewriter, loc));

    // Copy the offset from the old descriptor to the new one.
    targetMemRef.setOffset(rewriter, loc, viewMemRef.offset(rewriter, loc));

    // Iterate over the dimensions and apply size/stride permutation.
    for (const auto &en :
         llvm::enumerate(transposeOp.permutation().getResults())) {
      int sourcePos = en.index();
      int targetPos = en.value().cast<AffineDimExpr>().getPosition();
      targetMemRef.setSize(rewriter, loc, targetPos,
                           viewMemRef.size(rewriter, loc, sourcePos));
      targetMemRef.setStride(rewriter, loc, targetPos,
                             viewMemRef.stride(rewriter, loc, sourcePos));
    }

    rewriter.replaceOp(transposeOp, {targetMemRef});
    return success();
  }
};

/// Conversion pattern that transforms a `memref.view` op into:
///   1. An `llvm.mlir.undef` operation to create a memref descriptor
///   2. Updates to the descriptor to introduce the data ptr, offset, size
///      and stride.
/// The view op is replaced by the descriptor.
struct ViewOpLowering : public ConvertOpToLLVMPattern<memref::ViewOp> {
  using ConvertOpToLLVMPattern<memref::ViewOp>::ConvertOpToLLVMPattern;

  // Build and return the value for the idx^th shape dimension, either by
  // returning the constant shape dimension or by selecting the proper dynamic
  // size.
  Value getSize(ConversionPatternRewriter &rewriter, Location loc,
                ArrayRef<int64_t> shape, ValueRange dynamicSizes,
                unsigned idx) const {
    assert(idx < shape.size());
    if (!ShapedType::isDynamic(shape[idx]))
      return createIndexConstant(rewriter, loc, shape[idx]);
    // Count the number of dynamic dims in range [0, idx].
    unsigned nDynamic = llvm::count_if(shape.take_front(idx), [](int64_t v) {
      return ShapedType::isDynamic(v);
    });
    return dynamicSizes[nDynamic];
  }

  // Build and return the idx^th stride, either by returning the constant
  // stride or by computing the dynamic stride from the current
  // `runningStride` and `nextSize`. The caller should keep a running stride
  // and update it with the result returned by this function.
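  //
  // For example, for a view result of type memref<4x?xf32>, the innermost
  // stride is the constant 1, while stride 0 is dynamic and is computed as
  // the running product of the inner sizes (here the dynamic size of
  // dimension 1).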
  Value getStride(ConversionPatternRewriter &rewriter, Location loc,
                  ArrayRef<int64_t> strides, Value nextSize,
                  Value runningStride, unsigned idx) const {
    assert(idx < strides.size());
    if (!ShapedType::isDynamicStrideOrOffset(strides[idx]))
      return createIndexConstant(rewriter, loc, strides[idx]);
    if (nextSize)
      return runningStride
                 ? rewriter.create<LLVM::MulOp>(loc, runningStride, nextSize)
                 : nextSize;
    assert(!runningStride);
    return createIndexConstant(rewriter, loc, 1);
  }

  LogicalResult
  matchAndRewrite(memref::ViewOp viewOp, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    auto loc = viewOp.getLoc();

    auto viewMemRefType = viewOp.getType();
    auto targetElementTy =
        typeConverter->convertType(viewMemRefType.getElementType());
    auto targetDescTy = typeConverter->convertType(viewMemRefType);
    if (!targetDescTy || !targetElementTy ||
        !LLVM::isCompatibleType(targetElementTy) ||
        !LLVM::isCompatibleType(targetDescTy))
      return viewOp.emitWarning("Target descriptor type not converted to LLVM"),
             failure();

    int64_t offset;
    SmallVector<int64_t, 4> strides;
    auto successStrides = getStridesAndOffset(viewMemRefType, strides, offset);
    if (failed(successStrides))
      return viewOp.emitWarning("cannot cast to non-strided shape"), failure();
    assert(offset == 0 && "expected offset to be 0");

    // Create the descriptor.
    MemRefDescriptor sourceMemRef(adaptor.source());
    auto targetMemRef = MemRefDescriptor::undef(rewriter, loc, targetDescTy);

    // Field 1: Copy the allocated pointer, used for malloc/free.
    Value allocatedPtr = sourceMemRef.allocatedPtr(rewriter, loc);
    auto srcMemRefType = viewOp.source().getType().cast<MemRefType>();
    Value bitcastPtr = rewriter.create<LLVM::BitcastOp>(
        loc,
        LLVM::LLVMPointerType::get(targetElementTy,
                                   srcMemRefType.getMemorySpaceAsInt()),
        allocatedPtr);
    targetMemRef.setAllocatedPtr(rewriter, loc, bitcastPtr);

    // Field 2: Copy the actual aligned pointer to payload.
    Value alignedPtr = sourceMemRef.alignedPtr(rewriter, loc);
    alignedPtr = rewriter.create<LLVM::GEPOp>(loc, alignedPtr.getType(),
                                              alignedPtr, adaptor.byte_shift());
    bitcastPtr = rewriter.create<LLVM::BitcastOp>(
        loc,
        LLVM::LLVMPointerType::get(targetElementTy,
                                   srcMemRefType.getMemorySpaceAsInt()),
        alignedPtr);
    targetMemRef.setAlignedPtr(rewriter, loc, bitcastPtr);

    // Field 3: The offset in the resulting type must be 0. This is because of
    // the type change: an offset on srcType* may not be expressible as an
    // offset on dstType*.
    targetMemRef.setOffset(rewriter, loc,
                           createIndexConstant(rewriter, loc, offset));

    // Early exit for 0-D corner case.
    if (viewMemRefType.getRank() == 0)
      return rewriter.replaceOp(viewOp, {targetMemRef}), success();

    // Fields 4 and 5: Update sizes and strides.
    if (strides.back() != 1)
      return viewOp.emitWarning("cannot cast to non-contiguous shape"),
             failure();
    Value stride = nullptr, nextSize = nullptr;
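    // Walk the dimensions from innermost to outermost so that each stride can
    // be derived from the previously computed stride and size (see getStride).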
    for (int i = viewMemRefType.getRank() - 1; i >= 0; --i) {
      // Update size.
      Value size =
          getSize(rewriter, loc, viewMemRefType.getShape(), adaptor.sizes(), i);
      targetMemRef.setSize(rewriter, loc, i, size);
      // Update stride.
      stride = getStride(rewriter, loc, strides, nextSize, stride, i);
      targetMemRef.setStride(rewriter, loc, i, stride);
      nextSize = size;
    }

    rewriter.replaceOp(viewOp, {targetMemRef});
    return success();
  }
};

//===----------------------------------------------------------------------===//
// AtomicRMWOpLowering
//===----------------------------------------------------------------------===//

/// Try to match the kind of a memref.atomic_rmw to determine whether to use a
/// lowering to llvm.atomicrmw or fallback to llvm.cmpxchg.
static Optional<LLVM::AtomicBinOp>
matchSimpleAtomicOp(memref::AtomicRMWOp atomicOp) {
  switch (atomicOp.kind()) {
  case arith::AtomicRMWKind::addf:
    return LLVM::AtomicBinOp::fadd;
  case arith::AtomicRMWKind::addi:
    return LLVM::AtomicBinOp::add;
  case arith::AtomicRMWKind::assign:
    return LLVM::AtomicBinOp::xchg;
  case arith::AtomicRMWKind::maxs:
    return LLVM::AtomicBinOp::max;
  case arith::AtomicRMWKind::maxu:
    return LLVM::AtomicBinOp::umax;
  case arith::AtomicRMWKind::mins:
    return LLVM::AtomicBinOp::min;
  case arith::AtomicRMWKind::minu:
    return LLVM::AtomicBinOp::umin;
  case arith::AtomicRMWKind::ori:
    return LLVM::AtomicBinOp::_or;
  case arith::AtomicRMWKind::andi:
    return LLVM::AtomicBinOp::_and;
  default:
    return llvm::None;
  }
  llvm_unreachable("Invalid AtomicRMWKind");
}

struct AtomicRMWOpLowering : public LoadStoreOpLowering<memref::AtomicRMWOp> {
  using Base::Base;

  LogicalResult
  matchAndRewrite(memref::AtomicRMWOp atomicOp, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    if (failed(match(atomicOp)))
      return failure();
    auto maybeKind = matchSimpleAtomicOp(atomicOp);
    if (!maybeKind)
      return failure();
    auto resultType = adaptor.value().getType();
    auto memRefType = atomicOp.getMemRefType();
    auto dataPtr =
        getStridedElementPtr(atomicOp.getLoc(), memRefType, adaptor.memref(),
                             adaptor.indices(), rewriter);
    rewriter.replaceOpWithNewOp<LLVM::AtomicRMWOp>(
        atomicOp, resultType, *maybeKind, dataPtr, adaptor.value(),
        LLVM::AtomicOrdering::acq_rel);
    return success();
  }
};

} // namespace

void mlir::populateMemRefToLLVMConversionPatterns(LLVMTypeConverter &converter,
                                                  RewritePatternSet &patterns) {
  // clang-format off
  patterns.add<
      AllocaOpLowering,
      AllocaScopeOpLowering,
      AtomicRMWOpLowering,
      AssumeAlignmentOpLowering,
      DimOpLowering,
      GlobalMemrefOpLowering,
      GetGlobalMemrefOpLowering,
      LoadOpLowering,
      MemRefCastOpLowering,
      MemRefCopyOpLowering,
      MemRefReinterpretCastOpLowering,
      MemRefReshapeOpLowering,
      PrefetchOpLowering,
      RankOpLowering,
      ReassociatingReshapeOpConversion<memref::ExpandShapeOp>,
      ReassociatingReshapeOpConversion<memref::CollapseShapeOp>,
      StoreOpLowering,
      SubViewOpLowering,
      TransposeOpLowering,
      ViewOpLowering>(converter);
  // clang-format on
  auto allocLowering = converter.getOptions().allocLowering;
  if (allocLowering == LowerToLLVMOptions::AllocLowering::AlignedAlloc)
    patterns.add<AlignedAllocOpLowering, DeallocOpLowering>(converter);
  else if (allocLowering == LowerToLLVMOptions::AllocLowering::Malloc)
    patterns.add<AllocOpLowering, DeallocOpLowering>(converter);
}

namespace {
struct MemRefToLLVMPass : public ConvertMemRefToLLVMBase<MemRefToLLVMPass> {
  MemRefToLLVMPass() = default;

  void runOnOperation() override {
    Operation *op = getOperation();
    const auto &dataLayoutAnalysis = getAnalysis<DataLayoutAnalysis>();
    LowerToLLVMOptions options(&getContext(),
                               dataLayoutAnalysis.getAtOrAbove(op));
    options.allocLowering =
        (useAlignedAlloc ? LowerToLLVMOptions::AllocLowering::AlignedAlloc
                         : LowerToLLVMOptions::AllocLowering::Malloc);
    if (indexBitwidth != kDeriveIndexBitwidthFromDataLayout)
      options.overrideIndexBitwidth(indexBitwidth);

    LLVMTypeConverter typeConverter(&getContext(), options,
                                    &dataLayoutAnalysis);
    RewritePatternSet patterns(&getContext());
    populateMemRefToLLVMConversionPatterns(typeConverter, patterns);
    LLVMConversionTarget target(getContext());
    target.addLegalOp<FuncOp>();
    if (failed(applyPartialConversion(op, target, std::move(patterns))))
      signalPassFailure();
  }
};
} // namespace

std::unique_ptr<Pass> mlir::createMemRefToLLVMPass() {
  return std::make_unique<MemRefToLLVMPass>();
}
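
// Illustrative usage (a minimal sketch, not part of this file's API): the
// pass created above can be added to a pass manager alongside the other
// to-LLVM conversions, for example:
//
//   mlir::PassManager pm(&context);
//   pm.addPass(mlir::createMemRefToLLVMPass());
//   if (failed(pm.run(module)))
//     signalFailure();
//
// Only memref-dialect ops are converted here; remaining ops (e.g. func or
// arith) need their own conversion passes to reach a fully LLVM-dialect
// module.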