//===- MemRefToLLVM.cpp - MemRef to LLVM dialect conversion ---------------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//

#include "mlir/Conversion/MemRefToLLVM/MemRefToLLVM.h"
#include "../PassDetail.h"
#include "mlir/Analysis/DataLayoutAnalysis.h"
#include "mlir/Conversion/LLVMCommon/ConversionTarget.h"
#include "mlir/Conversion/LLVMCommon/Pattern.h"
#include "mlir/Conversion/LLVMCommon/TypeConverter.h"
#include "mlir/Conversion/MemRefToLLVM/AllocLikeConversion.h"
#include "mlir/Dialect/LLVMIR/FunctionCallUtils.h"
#include "mlir/Dialect/LLVMIR/LLVMDialect.h"
#include "mlir/Dialect/MemRef/IR/MemRef.h"
#include "mlir/IR/AffineMap.h"
#include "mlir/IR/BlockAndValueMapping.h"

using namespace mlir;

namespace {

struct AllocOpLowering : public AllocLikeOpLLVMLowering {
  AllocOpLowering(LLVMTypeConverter &converter)
      : AllocLikeOpLLVMLowering(memref::AllocOp::getOperationName(),
                                converter) {}

  std::tuple<Value, Value> allocateBuffer(ConversionPatternRewriter &rewriter,
                                          Location loc, Value sizeBytes,
                                          Operation *op) const override {
    // Heap allocations.
    memref::AllocOp allocOp = cast<memref::AllocOp>(op);
    MemRefType memRefType = allocOp.getType();

    Value alignment;
    if (auto alignmentAttr = allocOp.alignment()) {
      alignment = createIndexConstant(rewriter, loc, *alignmentAttr);
    } else if (!memRefType.getElementType().isSignlessIntOrIndexOrFloat()) {
      // In the case where no alignment is specified, we may want to override
      // `malloc`'s behavior. `malloc` typically aligns at the size of the
      // biggest scalar on the target HW. For non-scalars, use the natural
      // alignment of the LLVM type given by the LLVM DataLayout.
      alignment = getSizeInBytes(loc, memRefType.getElementType(), rewriter);
    }

    if (alignment) {
      // Adjust the allocation size to account for alignment.
      sizeBytes = rewriter.create<LLVM::AddOp>(loc, sizeBytes, alignment);
    }

    // Allocate the underlying buffer and store a pointer to it in the MemRef
    // descriptor.
    Type elementPtrType = this->getElementPtrType(memRefType);
    auto allocFuncOp = LLVM::lookupOrCreateMallocFn(
        allocOp->getParentOfType<ModuleOp>(), getIndexType());
    auto results = createLLVMCall(rewriter, loc, allocFuncOp, {sizeBytes},
                                  getVoidPtrType());
    Value allocatedPtr =
        rewriter.create<LLVM::BitcastOp>(loc, elementPtrType, results[0]);

    Value alignedPtr = allocatedPtr;
    if (alignment) {
      // Compute the aligned pointer.
      Value allocatedInt =
          rewriter.create<LLVM::PtrToIntOp>(loc, getIndexType(), allocatedPtr);
      Value alignmentInt =
          createAligned(rewriter, loc, allocatedInt, alignment);
      alignedPtr =
          rewriter.create<LLVM::IntToPtrOp>(loc, elementPtrType, alignmentInt);
    }

    return std::make_tuple(allocatedPtr, alignedPtr);
  }
};

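// For orientation, a rough sketch of what the pattern above produces for an
// unaligned `%0 = memref.alloc() : memref<4xf32>` (illustrative only; the
// exact types and constants depend on the configured index bitwidth and data
// layout):
//
//   %size = llvm.mlir.constant(16 : index) : i64
//   %raw  = llvm.call @malloc(%size) : (i64) -> !llvm.ptr<i8>
//   %buf  = llvm.bitcast %raw : !llvm.ptr<i8> to !llvm.ptr<f32>
//
// The surrounding AllocLikeOpLLVMLowering then populates the memref descriptor
// with %buf as both the allocated and the aligned pointer.
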
struct AlignedAllocOpLowering : public AllocLikeOpLLVMLowering {
  AlignedAllocOpLowering(LLVMTypeConverter &converter)
      : AllocLikeOpLLVMLowering(memref::AllocOp::getOperationName(),
                                converter) {}

  /// Returns the memref's element size in bytes using the data layout active
  /// at `op`.
  // TODO: there are other places where this is used. Expose publicly?
  unsigned getMemRefEltSizeInBytes(MemRefType memRefType,
                                   Operation *op) const {
    const DataLayout *layout = &defaultLayout;
    if (const DataLayoutAnalysis *analysis =
            getTypeConverter()->getDataLayoutAnalysis()) {
      layout = &analysis->getAbove(op);
    }
    Type elementType = memRefType.getElementType();
    if (auto memRefElementType = elementType.dyn_cast<MemRefType>())
      return getTypeConverter()->getMemRefDescriptorSize(memRefElementType,
                                                         *layout);
    if (auto memRefElementType = elementType.dyn_cast<UnrankedMemRefType>())
      return getTypeConverter()->getUnrankedMemRefDescriptorSize(
          memRefElementType, *layout);
    return layout->getTypeSize(elementType);
  }

  /// Returns true if the memref size in bytes is known to be a multiple of
  /// `factor`, assuming the data layout active at `op`.
  bool isMemRefSizeMultipleOf(MemRefType type, uint64_t factor,
                              Operation *op) const {
    uint64_t sizeDivisor = getMemRefEltSizeInBytes(type, op);
    for (unsigned i = 0, e = type.getRank(); i < e; i++) {
      if (ShapedType::isDynamic(type.getDimSize(i)))
        continue;
      sizeDivisor = sizeDivisor * type.getDimSize(i);
    }
    return sizeDivisor % factor == 0;
  }

  /// Returns the alignment to be used for the allocation call itself.
  /// aligned_alloc requires the allocation size to be a multiple of the
  /// alignment, and the alignment itself to be a power of two.
  int64_t getAllocationAlignment(memref::AllocOp allocOp) const {
    if (Optional<uint64_t> alignment = allocOp.alignment())
      return *alignment;

    // Whenever we don't have an alignment set, we use an alignment consistent
    // with the element type; since the alignment has to be a power of two, we
    // bump it to the next power of two if it isn't one already.
    auto eltSizeBytes = getMemRefEltSizeInBytes(allocOp.getType(), allocOp);
    return std::max(kMinAlignedAllocAlignment,
                    llvm::PowerOf2Ceil(eltSizeBytes));
  }

  std::tuple<Value, Value> allocateBuffer(ConversionPatternRewriter &rewriter,
                                          Location loc, Value sizeBytes,
                                          Operation *op) const override {
    // Heap allocations.
    memref::AllocOp allocOp = cast<memref::AllocOp>(op);
    MemRefType memRefType = allocOp.getType();
    int64_t alignment = getAllocationAlignment(allocOp);
    Value allocAlignment = createIndexConstant(rewriter, loc, alignment);

    // aligned_alloc requires the size to be a multiple of the alignment; pad
    // the size to the next multiple if necessary.
    if (!isMemRefSizeMultipleOf(memRefType, alignment, op))
      sizeBytes = createAligned(rewriter, loc, sizeBytes, allocAlignment);

    Type elementPtrType = this->getElementPtrType(memRefType);
    auto allocFuncOp = LLVM::lookupOrCreateAlignedAllocFn(
        allocOp->getParentOfType<ModuleOp>(), getIndexType());
    auto results =
        createLLVMCall(rewriter, loc, allocFuncOp, {allocAlignment, sizeBytes},
                       getVoidPtrType());
    Value allocatedPtr =
        rewriter.create<LLVM::BitcastOp>(loc, elementPtrType, results[0]);

    return std::make_tuple(allocatedPtr, allocatedPtr);
  }

  /// The minimum alignment to use with aligned_alloc (has to be a power of 2).
  static constexpr uint64_t kMinAlignedAllocAlignment = 16UL;

  /// Default layout to use in absence of the corresponding analysis.
  DataLayout defaultLayout;
};

// Out-of-line definition, required until C++17.
constexpr uint64_t AlignedAllocOpLowering::kMinAlignedAllocAlignment;

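// Sketch of the aligned path, assuming a 64-bit index type: when this pattern
// is selected, `%0 = memref.alloc() {alignment = 32 : i64} : memref<8xf32>`
// lowers roughly to
//
//   %align = llvm.mlir.constant(32 : index) : i64
//   %size  = llvm.mlir.constant(32 : index) : i64
//   %raw   = llvm.call @aligned_alloc(%align, %size) : (i64, i64) -> !llvm.ptr<i8>
//   %buf   = llvm.bitcast %raw : !llvm.ptr<i8> to !llvm.ptr<f32>
//
// with the size padded up to a multiple of the alignment when it is not
// already one.
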
struct AllocaOpLowering : public AllocLikeOpLLVMLowering {
  AllocaOpLowering(LLVMTypeConverter &converter)
      : AllocLikeOpLLVMLowering(memref::AllocaOp::getOperationName(),
                                converter) {}

  /// Allocates the underlying buffer using an `llvm.alloca`. `alloca` yields a
  /// pointer to the element type directly, so the allocated and aligned
  /// pointers coincide and no post-allocation re-alignment (e.g. as needed in
  /// conjunction with malloc) is required.
  std::tuple<Value, Value> allocateBuffer(ConversionPatternRewriter &rewriter,
                                          Location loc, Value sizeBytes,
                                          Operation *op) const override {

    // With alloca, one gets a pointer to the element type right away.
    // For stack allocations.
    auto allocaOp = cast<memref::AllocaOp>(op);
    auto elementPtrType = this->getElementPtrType(allocaOp.getType());

    auto allocatedElementPtr = rewriter.create<LLVM::AllocaOp>(
        loc, elementPtrType, sizeBytes,
        allocaOp.alignment() ? *allocaOp.alignment() : 0);

    return std::make_tuple(allocatedElementPtr, allocatedElementPtr);
  }
};

struct AllocaScopeOpLowering
    : public ConvertOpToLLVMPattern<memref::AllocaScopeOp> {
  using ConvertOpToLLVMPattern<memref::AllocaScopeOp>::ConvertOpToLLVMPattern;

  LogicalResult
  matchAndRewrite(memref::AllocaScopeOp allocaScopeOp, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    OpBuilder::InsertionGuard guard(rewriter);
    Location loc = allocaScopeOp.getLoc();

    // Split the current block before the AllocaScopeOp to create the inlining
    // point.
    auto *currentBlock = rewriter.getInsertionBlock();
    auto *remainingOpsBlock =
        rewriter.splitBlock(currentBlock, rewriter.getInsertionPoint());
    Block *continueBlock;
    if (allocaScopeOp.getNumResults() == 0) {
      continueBlock = remainingOpsBlock;
    } else {
      continueBlock = rewriter.createBlock(
          remainingOpsBlock, allocaScopeOp.getResultTypes(),
          SmallVector<Location>(allocaScopeOp->getNumResults(),
                                allocaScopeOp.getLoc()));
      rewriter.create<LLVM::BrOp>(loc, ValueRange(), remainingOpsBlock);
    }

    // Inline the body region.
    Block *beforeBody = &allocaScopeOp.bodyRegion().front();
    Block *afterBody = &allocaScopeOp.bodyRegion().back();
    rewriter.inlineRegionBefore(allocaScopeOp.bodyRegion(), continueBlock);

    // Save the stack and then branch into the body of the region.
    rewriter.setInsertionPointToEnd(currentBlock);
    auto stackSaveOp =
        rewriter.create<LLVM::StackSaveOp>(loc, getVoidPtrType());
    rewriter.create<LLVM::BrOp>(loc, ValueRange(), beforeBody);

    // Replace the alloca_scope return with a branch that jumps out of the
    // body. The stack is restored before leaving the body region.
    rewriter.setInsertionPointToEnd(afterBody);
    auto returnOp =
        cast<memref::AllocaScopeReturnOp>(afterBody->getTerminator());
    auto branchOp = rewriter.replaceOpWithNewOp<LLVM::BrOp>(
        returnOp, returnOp.results(), continueBlock);

    // Insert the stack restore before jumping out of the body of the region.
    rewriter.setInsertionPoint(branchOp);
    rewriter.create<LLVM::StackRestoreOp>(loc, stackSaveOp);

    // Replace the op with the values returned from the body region.
    rewriter.replaceOp(allocaScopeOp, continueBlock->getArguments());

    return success();
  }
};

struct AssumeAlignmentOpLowering
    : public ConvertOpToLLVMPattern<memref::AssumeAlignmentOp> {
  using ConvertOpToLLVMPattern<
      memref::AssumeAlignmentOp>::ConvertOpToLLVMPattern;

  LogicalResult
  matchAndRewrite(memref::AssumeAlignmentOp op, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    Value memref = adaptor.memref();
    unsigned alignment = op.alignment();
    auto loc = op.getLoc();

    MemRefDescriptor memRefDescriptor(memref);
    Value ptr = memRefDescriptor.alignedPtr(rewriter, memref.getLoc());

    // Emit llvm.assume(memref.alignedPtr & (alignment - 1) == 0). Notice that
    // the asserted memref.alignedPtr isn't used anywhere else, as the real
    // users like load/store/views always re-extract memref.alignedPtr as they
    // get lowered.
    //
    // This relies on LLVM's CSE optimization (potentially after SROA), since
    // after CSE all memref.alignedPtr instances get de-duplicated into the same
    // pointer SSA value.
    auto intPtrType =
        getIntPtrType(memRefDescriptor.getElementPtrType().getAddressSpace());
    Value zero = createIndexAttrConstant(rewriter, loc, intPtrType, 0);
    Value mask =
        createIndexAttrConstant(rewriter, loc, intPtrType, alignment - 1);
    Value ptrValue = rewriter.create<LLVM::PtrToIntOp>(loc, intPtrType, ptr);
    rewriter.create<LLVM::AssumeOp>(
        loc, rewriter.create<LLVM::ICmpOp>(
                 loc, LLVM::ICmpPredicate::eq,
                 rewriter.create<LLVM::AndOp>(loc, ptrValue, mask), zero));

    rewriter.eraseOp(op);
    return success();
  }
};

// A `dealloc` is converted into a call to `free` on the underlying data
// buffer. The memref descriptor being an SSA value, there is no need to clean
// it up in any way.
struct DeallocOpLowering : public ConvertOpToLLVMPattern<memref::DeallocOp> {
  using ConvertOpToLLVMPattern<memref::DeallocOp>::ConvertOpToLLVMPattern;

  explicit DeallocOpLowering(LLVMTypeConverter &converter)
      : ConvertOpToLLVMPattern<memref::DeallocOp>(converter) {}

  LogicalResult
  matchAndRewrite(memref::DeallocOp op, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    // Insert the `free` declaration if it is not already present.
    auto freeFunc = LLVM::lookupOrCreateFreeFn(op->getParentOfType<ModuleOp>());
    MemRefDescriptor memref(adaptor.memref());
    Value casted = rewriter.create<LLVM::BitcastOp>(
        op.getLoc(), getVoidPtrType(),
        memref.allocatedPtr(rewriter, op.getLoc()));
    rewriter.replaceOpWithNewOp<LLVM::CallOp>(
        op, TypeRange(), SymbolRefAttr::get(freeFunc), casted);
    return success();
  }
};

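// Sketch of the resulting IR for `memref.dealloc %m : memref<?xf32>`
// (illustrative; the field index follows the standard descriptor layout, where
// field 0 holds the allocated pointer):
//
//   %alloc = llvm.extractvalue %desc[0] : !llvm.struct<(ptr<f32>, ptr<f32>, i64, ...)>
//   %void  = llvm.bitcast %alloc : !llvm.ptr<f32> to !llvm.ptr<i8>
//   llvm.call @free(%void) : (!llvm.ptr<i8>) -> ()
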
// A `dim` is converted to a constant for static sizes and to an access to the
// size stored in the memref descriptor for dynamic sizes.
struct DimOpLowering : public ConvertOpToLLVMPattern<memref::DimOp> {
  using ConvertOpToLLVMPattern<memref::DimOp>::ConvertOpToLLVMPattern;

  LogicalResult
  matchAndRewrite(memref::DimOp dimOp, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    Type operandType = dimOp.source().getType();
    if (operandType.isa<UnrankedMemRefType>()) {
      rewriter.replaceOp(
          dimOp, {extractSizeOfUnrankedMemRef(
                     operandType, dimOp, adaptor.getOperands(), rewriter)});

      return success();
    }
    if (operandType.isa<MemRefType>()) {
      rewriter.replaceOp(
          dimOp, {extractSizeOfRankedMemRef(operandType, dimOp,
                                            adaptor.getOperands(), rewriter)});
      return success();
    }
    llvm_unreachable("expected MemRefType or UnrankedMemRefType");
  }

private:
  Value extractSizeOfUnrankedMemRef(Type operandType, memref::DimOp dimOp,
                                    OpAdaptor adaptor,
                                    ConversionPatternRewriter &rewriter) const {
    Location loc = dimOp.getLoc();

    auto unrankedMemRefType = operandType.cast<UnrankedMemRefType>();
    auto scalarMemRefType =
        MemRefType::get({}, unrankedMemRefType.getElementType());
    unsigned addressSpace = unrankedMemRefType.getMemorySpaceAsInt();

    // Extract pointer to the underlying ranked descriptor and bitcast it to a
    // memref<element_type> descriptor pointer to minimize the number of GEP
    // operations.
    UnrankedMemRefDescriptor unrankedDesc(adaptor.source());
    Value underlyingRankedDesc = unrankedDesc.memRefDescPtr(rewriter, loc);
    Value scalarMemRefDescPtr = rewriter.create<LLVM::BitcastOp>(
        loc,
        LLVM::LLVMPointerType::get(typeConverter->convertType(scalarMemRefType),
                                   addressSpace),
        underlyingRankedDesc);

    // Get pointer to offset field of memref<element_type> descriptor.
    Type indexPtrTy = LLVM::LLVMPointerType::get(
        getTypeConverter()->getIndexType(), addressSpace);
    Value two = rewriter.create<LLVM::ConstantOp>(
        loc, typeConverter->convertType(rewriter.getI32Type()),
        rewriter.getI32IntegerAttr(2));
    Value offsetPtr = rewriter.create<LLVM::GEPOp>(
        loc, indexPtrTy, scalarMemRefDescPtr,
        ValueRange({createIndexConstant(rewriter, loc, 0), two}));

    // The size value that we have to extract can be obtained using GEPOp with
    // `dimOp.index() + 1` as the index argument.
    Value idxPlusOne = rewriter.create<LLVM::AddOp>(
        loc, createIndexConstant(rewriter, loc, 1), adaptor.index());
    Value sizePtr = rewriter.create<LLVM::GEPOp>(loc, indexPtrTy, offsetPtr,
                                                 ValueRange({idxPlusOne}));
    return rewriter.create<LLVM::LoadOp>(loc, sizePtr);
  }

  Optional<int64_t> getConstantDimIndex(memref::DimOp dimOp) const {
    if (Optional<int64_t> idx = dimOp.getConstantIndex())
      return idx;

    if (auto constantOp = dimOp.index().getDefiningOp<LLVM::ConstantOp>())
      return constantOp.getValue()
          .cast<IntegerAttr>()
          .getValue()
          .getSExtValue();

    return llvm::None;
  }

  Value extractSizeOfRankedMemRef(Type operandType, memref::DimOp dimOp,
                                  OpAdaptor adaptor,
                                  ConversionPatternRewriter &rewriter) const {
    Location loc = dimOp.getLoc();

    // Take advantage of the index being a constant, if possible.
    MemRefType memRefType = operandType.cast<MemRefType>();
    if (Optional<int64_t> index = getConstantDimIndex(dimOp)) {
      int64_t i = index.getValue();
      if (memRefType.isDynamicDim(i)) {
        // Extract the dynamic size from the memref descriptor.
        MemRefDescriptor descriptor(adaptor.source());
        return descriptor.size(rewriter, loc, i);
      }
      // Use a constant for the static size.
      int64_t dimSize = memRefType.getDimSize(i);
      return createIndexConstant(rewriter, loc, dimSize);
    }
    Value index = adaptor.index();
    int64_t rank = memRefType.getRank();
    MemRefDescriptor memrefDescriptor(adaptor.source());
    return memrefDescriptor.size(rewriter, loc, index, rank);
  }
};

/// Returns the LLVM type of the global variable given the memref type `type`.
static Type convertGlobalMemrefTypeToLLVM(MemRefType type,
                                          LLVMTypeConverter &typeConverter) {
  // The LLVM type for a global memref will be a multi-dimensional array. For
  // declarations or uninitialized global memrefs, we could potentially flatten
  // this to a 1D array. However, for memref.global's with an initial value,
  // we do not intend to flatten the ElementsAttribute when going from std ->
  // LLVM dialect, so the LLVM type needs to be a multi-dimensional array.
  Type elementType = typeConverter.convertType(type.getElementType());
  Type arrayTy = elementType;
  // Shape has the outermost dim at index 0, so we need to walk it backwards.
  for (int64_t dim : llvm::reverse(type.getShape()))
    arrayTy = LLVM::LLVMArrayType::get(arrayTy, dim);
  return arrayTy;
}

/// GlobalMemrefOp is lowered to an LLVM global variable.
struct GlobalMemrefOpLowering
    : public ConvertOpToLLVMPattern<memref::GlobalOp> {
  using ConvertOpToLLVMPattern<memref::GlobalOp>::ConvertOpToLLVMPattern;

  LogicalResult
  matchAndRewrite(memref::GlobalOp global, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    MemRefType type = global.type();
    if (!isConvertibleAndHasIdentityMaps(type))
      return failure();

    Type arrayTy = convertGlobalMemrefTypeToLLVM(type, *getTypeConverter());

    LLVM::Linkage linkage =
        global.isPublic() ? LLVM::Linkage::External : LLVM::Linkage::Private;

    Attribute initialValue = nullptr;
    if (!global.isExternal() && !global.isUninitialized()) {
      auto elementsAttr = global.initial_value()->cast<ElementsAttr>();
      initialValue = elementsAttr;

      // For scalar memrefs, the global variable created is of the element
      // type, so unpack the elements attribute to extract the value.
      if (type.getRank() == 0)
        initialValue = elementsAttr.getSplatValue<Attribute>();
    }

    uint64_t alignment = global.alignment().getValueOr(0);

    auto newGlobal = rewriter.replaceOpWithNewOp<LLVM::GlobalOp>(
        global, arrayTy, global.constant(), linkage, global.sym_name(),
        initialValue, alignment, type.getMemorySpaceAsInt());
    if (!global.isExternal() && global.isUninitialized()) {
      Block *blk = new Block();
      newGlobal.getInitializerRegion().push_back(blk);
      rewriter.setInsertionPointToStart(blk);
      Value undef[] = {
          rewriter.create<LLVM::UndefOp>(global.getLoc(), arrayTy)};
      rewriter.create<LLVM::ReturnOp>(global.getLoc(), undef);
    }
    return success();
  }
};

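// Example of the mapping performed by the pattern above (a sketch, not
// verbatim output):
//
//   memref.global "private" constant @cst : memref<2x3xf32> = dense<1.0>
//
// becomes an LLVM global of multi-dimensional array type:
//
//   llvm.mlir.global private constant @cst(dense<1.0> : tensor<2x3xf32>)
//       : !llvm.array<2 x array<3 x f32>>
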
/// GetGlobalMemrefOp is lowered into a MemRef descriptor with the pointer to
/// the first element stashed into the descriptor. This reuses
/// `AllocLikeOpLowering` to reuse the MemRef descriptor construction.
struct GetGlobalMemrefOpLowering : public AllocLikeOpLLVMLowering {
  GetGlobalMemrefOpLowering(LLVMTypeConverter &converter)
      : AllocLikeOpLLVMLowering(memref::GetGlobalOp::getOperationName(),
                                converter) {}

  /// Buffer "allocation" for the memref.get_global op is getting the address
  /// of the referenced global variable.
  std::tuple<Value, Value> allocateBuffer(ConversionPatternRewriter &rewriter,
                                          Location loc, Value sizeBytes,
                                          Operation *op) const override {
    auto getGlobalOp = cast<memref::GetGlobalOp>(op);
    MemRefType type = getGlobalOp.result().getType().cast<MemRefType>();
    unsigned memSpace = type.getMemorySpaceAsInt();

    Type arrayTy = convertGlobalMemrefTypeToLLVM(type, *getTypeConverter());
    auto addressOf = rewriter.create<LLVM::AddressOfOp>(
        loc, LLVM::LLVMPointerType::get(arrayTy, memSpace), getGlobalOp.name());

    // Get the address of the first element in the array by creating a GEP with
    // the address of the GV as the base, and (rank + 1) number of 0 indices.
    Type elementType = typeConverter->convertType(type.getElementType());
    Type elementPtrType = LLVM::LLVMPointerType::get(elementType, memSpace);

    SmallVector<Value> operands;
    operands.insert(operands.end(), type.getRank() + 1,
                    createIndexConstant(rewriter, loc, 0));
    auto gep =
        rewriter.create<LLVM::GEPOp>(loc, elementPtrType, addressOf, operands);

    // We do not expect the memref obtained using `memref.get_global` to be
    // ever deallocated. Set the allocated pointer to a known bad value to
    // help debug if that ever happens.
    auto intPtrType = getIntPtrType(memSpace);
    Value deadBeefConst =
        createIndexAttrConstant(rewriter, op->getLoc(), intPtrType, 0xdeadbeef);
    auto deadBeefPtr =
        rewriter.create<LLVM::IntToPtrOp>(loc, elementPtrType, deadBeefConst);

    // Both allocated and aligned pointers are the same. We could potentially
    // stash a nullptr for the allocated pointer since we do not expect any
    // dealloc.
    return std::make_tuple(deadBeefPtr, gep);
  }
};

// Common base for load and store operations on MemRefs. Restricts the match
// to supported MemRef types. Provides functionality to emit code accessing a
// specific element of the underlying data buffer.
template <typename Derived>
struct LoadStoreOpLowering : public ConvertOpToLLVMPattern<Derived> {
  using ConvertOpToLLVMPattern<Derived>::ConvertOpToLLVMPattern;
  using ConvertOpToLLVMPattern<Derived>::isConvertibleAndHasIdentityMaps;
  using Base = LoadStoreOpLowering<Derived>;

  LogicalResult match(Derived op) const override {
    MemRefType type = op.getMemRefType();
    return isConvertibleAndHasIdentityMaps(type) ? success() : failure();
  }
};

// Load operation is lowered to obtaining a pointer to the indexed element
// and loading it.
struct LoadOpLowering : public LoadStoreOpLowering<memref::LoadOp> {
  using Base::Base;

  LogicalResult
  matchAndRewrite(memref::LoadOp loadOp, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    auto type = loadOp.getMemRefType();

    Value dataPtr = getStridedElementPtr(
        loadOp.getLoc(), type, adaptor.memref(), adaptor.indices(), rewriter);
    rewriter.replaceOpWithNewOp<LLVM::LoadOp>(loadOp, dataPtr);
    return success();
  }
};

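// Sketch of the lowering of `memref.load %m[%i] : memref<?xf32>` produced by
// the pattern above together with getStridedElementPtr (illustrative; field
// indices follow the standard descriptor layout):
//
//   %aligned = llvm.extractvalue %desc[1] : !llvm.struct<(...)>
//   %offset  = llvm.extractvalue %desc[2] : !llvm.struct<(...)>
//   ... combine %offset, sizes, and strides with %i into a linear index %idx ...
//   %addr    = llvm.getelementptr %aligned[%idx] : (!llvm.ptr<f32>, i64) -> !llvm.ptr<f32>
//   %value   = llvm.load %addr : !llvm.ptr<f32>
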
// Store operation is lowered to obtaining a pointer to the indexed element,
// and storing the given value to it.
struct StoreOpLowering : public LoadStoreOpLowering<memref::StoreOp> {
  using Base::Base;

  LogicalResult
  matchAndRewrite(memref::StoreOp op, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    auto type = op.getMemRefType();

    Value dataPtr = getStridedElementPtr(op.getLoc(), type, adaptor.memref(),
                                         adaptor.indices(), rewriter);
    rewriter.replaceOpWithNewOp<LLVM::StoreOp>(op, adaptor.value(), dataPtr);
    return success();
  }
};

// The prefetch operation is lowered in a way similar to the load operation
// except that the llvm.prefetch operation is used for replacement.
struct PrefetchOpLowering : public LoadStoreOpLowering<memref::PrefetchOp> {
  using Base::Base;

  LogicalResult
  matchAndRewrite(memref::PrefetchOp prefetchOp, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    auto type = prefetchOp.getMemRefType();
    auto loc = prefetchOp.getLoc();

    Value dataPtr = getStridedElementPtr(loc, type, adaptor.memref(),
                                         adaptor.indices(), rewriter);

    // Replace with llvm.prefetch.
    auto llvmI32Type = typeConverter->convertType(rewriter.getIntegerType(32));
    auto isWrite = rewriter.create<LLVM::ConstantOp>(
        loc, llvmI32Type, rewriter.getI32IntegerAttr(prefetchOp.isWrite()));
    auto localityHint = rewriter.create<LLVM::ConstantOp>(
        loc, llvmI32Type,
        rewriter.getI32IntegerAttr(prefetchOp.localityHint()));
    auto isData = rewriter.create<LLVM::ConstantOp>(
        loc, llvmI32Type, rewriter.getI32IntegerAttr(prefetchOp.isDataCache()));

    rewriter.replaceOpWithNewOp<LLVM::Prefetch>(prefetchOp, dataPtr, isWrite,
                                                localityHint, isData);
    return success();
  }
};

struct RankOpLowering : public ConvertOpToLLVMPattern<memref::RankOp> {
  using ConvertOpToLLVMPattern<memref::RankOp>::ConvertOpToLLVMPattern;

  LogicalResult
  matchAndRewrite(memref::RankOp op, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    Location loc = op.getLoc();
    Type operandType = op.memref().getType();
    if (auto unrankedMemRefType = operandType.dyn_cast<UnrankedMemRefType>()) {
      UnrankedMemRefDescriptor desc(adaptor.memref());
      rewriter.replaceOp(op, {desc.rank(rewriter, loc)});
      return success();
    }
    if (auto rankedMemRefType = operandType.dyn_cast<MemRefType>()) {
      rewriter.replaceOp(
          op, {createIndexConstant(rewriter, loc, rankedMemRefType.getRank())});
      return success();
    }
    return failure();
  }
};

struct MemRefCastOpLowering : public ConvertOpToLLVMPattern<memref::CastOp> {
  using ConvertOpToLLVMPattern<memref::CastOp>::ConvertOpToLLVMPattern;

  LogicalResult match(memref::CastOp memRefCastOp) const override {
    Type srcType = memRefCastOp.getOperand().getType();
    Type dstType = memRefCastOp.getType();

    // memref::CastOp reduces to a bitcast in the ranked MemRef case and can be
    // used for type erasure. For now it must preserve the underlying element
    // type and requires source and result types to have the same rank.
    // Therefore, perform a sanity check that the underlying structs are the
    // same. Once op semantics are relaxed we can revisit.
    if (srcType.isa<MemRefType>() && dstType.isa<MemRefType>())
      return success(typeConverter->convertType(srcType) ==
                     typeConverter->convertType(dstType));

    // At least one of the operands is an unranked type.
    assert(srcType.isa<UnrankedMemRefType>() ||
           dstType.isa<UnrankedMemRefType>());

    // Unranked-to-unranked casts are disallowed.
    return !(srcType.isa<UnrankedMemRefType>() &&
             dstType.isa<UnrankedMemRefType>())
               ? success()
               : failure();
  }

  void rewrite(memref::CastOp memRefCastOp, OpAdaptor adaptor,
               ConversionPatternRewriter &rewriter) const override {
    auto srcType = memRefCastOp.getOperand().getType();
    auto dstType = memRefCastOp.getType();
    auto targetStructType = typeConverter->convertType(memRefCastOp.getType());
    auto loc = memRefCastOp.getLoc();

    // For the ranked/ranked case, just keep the original descriptor.
    if (srcType.isa<MemRefType>() && dstType.isa<MemRefType>())
      return rewriter.replaceOp(memRefCastOp, {adaptor.source()});

    if (srcType.isa<MemRefType>() && dstType.isa<UnrankedMemRefType>()) {
      // Casting ranked to unranked memref type:
      //   Set the rank in the destination from the memref type.
      //   Allocate space on the stack and copy the src memref descriptor.
      //   Set the ptr in the destination to the stack space.
      auto srcMemRefType = srcType.cast<MemRefType>();
      int64_t rank = srcMemRefType.getRank();
      // ptr = AllocaOp sizeof(MemRefDescriptor)
      auto ptr = getTypeConverter()->promoteOneMemRefDescriptor(
          loc, adaptor.source(), rewriter);
      // voidptr = BitCastOp srcType* to void*
      auto voidPtr =
          rewriter.create<LLVM::BitcastOp>(loc, getVoidPtrType(), ptr)
              .getResult();
      // rank = ConstantOp srcRank
      auto rankVal = rewriter.create<LLVM::ConstantOp>(
          loc, typeConverter->convertType(rewriter.getIntegerType(64)),
          rewriter.getI64IntegerAttr(rank));
      // undef = UndefOp
      UnrankedMemRefDescriptor memRefDesc =
          UnrankedMemRefDescriptor::undef(rewriter, loc, targetStructType);
      // d1 = InsertValueOp undef, rank, 0
      memRefDesc.setRank(rewriter, loc, rankVal);
      // d2 = InsertValueOp d1, voidptr, 1
      memRefDesc.setMemRefDescPtr(rewriter, loc, voidPtr);
      rewriter.replaceOp(memRefCastOp, (Value)memRefDesc);

    } else if (srcType.isa<UnrankedMemRefType>() && dstType.isa<MemRefType>()) {
      // Casting from unranked to ranked type.
      // The operation is assumed to be doing a correct cast. If the
      // destination type mismatches the unranked type, the behavior is
      // undefined.
      UnrankedMemRefDescriptor memRefDesc(adaptor.source());
      // ptr = ExtractValueOp src, 1
      auto ptr = memRefDesc.memRefDescPtr(rewriter, loc);
      // castPtr = BitCastOp i8* to structTy*
      auto castPtr =
          rewriter
              .create<LLVM::BitcastOp>(
                  loc, LLVM::LLVMPointerType::get(targetStructType), ptr)
              .getResult();
      // struct = LoadOp castPtr
      auto loadOp = rewriter.create<LLVM::LoadOp>(loc, castPtr);
      rewriter.replaceOp(memRefCastOp, loadOp.getResult());
    } else {
      llvm_unreachable("Unsupported unranked memref to unranked memref cast");
    }
  }
};

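// As an illustration of the unranked cases handled above: lowering
// `memref.cast %r : memref<4x?xf32> to memref<*xf32>` produces a two-field
// unranked descriptor {rank = 2, ptr}, where ptr is an i8* pointing at a
// stack copy of the ranked descriptor of %r; the reverse cast simply reloads
// the ranked descriptor through that pointer.
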
/// Pattern to lower a `memref.copy` to llvm.
///
/// For memrefs with identity layouts, the copy is lowered to the llvm
/// `memcpy` intrinsic. For non-identity layouts, the copy is lowered to a call
/// to the generic `MemrefCopyFn`.
struct MemRefCopyOpLowering : public ConvertOpToLLVMPattern<memref::CopyOp> {
  using ConvertOpToLLVMPattern<memref::CopyOp>::ConvertOpToLLVMPattern;

  LogicalResult
  lowerToMemCopyIntrinsic(memref::CopyOp op, OpAdaptor adaptor,
                          ConversionPatternRewriter &rewriter) const {
    auto loc = op.getLoc();
    auto srcType = op.source().getType().dyn_cast<MemRefType>();

    MemRefDescriptor srcDesc(adaptor.source());

    // Compute number of elements.
    Value numElements = rewriter.create<LLVM::ConstantOp>(
        loc, getIndexType(), rewriter.getIndexAttr(1));
    for (int pos = 0; pos < srcType.getRank(); ++pos) {
      auto size = srcDesc.size(rewriter, loc, pos);
      numElements = rewriter.create<LLVM::MulOp>(loc, numElements, size);
    }

    // Get element size.
    auto sizeInBytes = getSizeInBytes(loc, srcType.getElementType(), rewriter);
    // Compute total.
    Value totalSize =
        rewriter.create<LLVM::MulOp>(loc, numElements, sizeInBytes);

    Value srcBasePtr = srcDesc.alignedPtr(rewriter, loc);
    MemRefDescriptor targetDesc(adaptor.target());
    Value targetBasePtr = targetDesc.alignedPtr(rewriter, loc);
    Value isVolatile = rewriter.create<LLVM::ConstantOp>(
        loc, typeConverter->convertType(rewriter.getI1Type()),
        rewriter.getBoolAttr(false));
    rewriter.create<LLVM::MemcpyOp>(loc, targetBasePtr, srcBasePtr, totalSize,
                                    isVolatile);
    rewriter.eraseOp(op);

    return success();
  }

  LogicalResult
  lowerToMemCopyFunctionCall(memref::CopyOp op, OpAdaptor adaptor,
                             ConversionPatternRewriter &rewriter) const {
    auto loc = op.getLoc();
    auto srcType = op.source().getType().cast<BaseMemRefType>();
    auto targetType = op.target().getType().cast<BaseMemRefType>();

    // First make sure we have an unranked memref descriptor representation.
    auto makeUnranked = [&, this](Value ranked, BaseMemRefType type) {
      auto rank = rewriter.create<LLVM::ConstantOp>(
          loc, getIndexType(), rewriter.getIndexAttr(type.getRank()));
      auto *typeConverter = getTypeConverter();
      auto ptr =
          typeConverter->promoteOneMemRefDescriptor(loc, ranked, rewriter);
      auto voidPtr =
          rewriter.create<LLVM::BitcastOp>(loc, getVoidPtrType(), ptr)
              .getResult();
      auto unrankedType =
          UnrankedMemRefType::get(type.getElementType(), type.getMemorySpace());
      return UnrankedMemRefDescriptor::pack(rewriter, loc, *typeConverter,
                                            unrankedType,
                                            ValueRange{rank, voidPtr});
    };

    Value unrankedSource = srcType.hasRank()
                               ? makeUnranked(adaptor.source(), srcType)
                               : adaptor.source();
    Value unrankedTarget = targetType.hasRank()
                               ? makeUnranked(adaptor.target(), targetType)
                               : adaptor.target();

    // Now promote the unranked descriptors to the stack.
    auto one = rewriter.create<LLVM::ConstantOp>(loc, getIndexType(),
                                                 rewriter.getIndexAttr(1));
    auto promote = [&](Value desc) {
      auto ptrType = LLVM::LLVMPointerType::get(desc.getType());
      auto allocated =
          rewriter.create<LLVM::AllocaOp>(loc, ptrType, ValueRange{one});
      rewriter.create<LLVM::StoreOp>(loc, desc, allocated);
      return allocated;
    };

    auto sourcePtr = promote(unrankedSource);
    auto targetPtr = promote(unrankedTarget);

    auto elemSize = rewriter.create<LLVM::ConstantOp>(
        loc, getIndexType(),
        rewriter.getIndexAttr(srcType.getElementTypeBitWidth() / 8));
    auto copyFn = LLVM::lookupOrCreateMemRefCopyFn(
        op->getParentOfType<ModuleOp>(), getIndexType(), sourcePtr.getType());
    rewriter.create<LLVM::CallOp>(loc, copyFn,
                                  ValueRange{elemSize, sourcePtr, targetPtr});
    rewriter.eraseOp(op);

    return success();
  }

  LogicalResult
  matchAndRewrite(memref::CopyOp op, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    auto srcType = op.source().getType().cast<BaseMemRefType>();
    auto targetType = op.target().getType().cast<BaseMemRefType>();

    if (srcType.hasRank() &&
        srcType.cast<MemRefType>().getLayout().isIdentity() &&
        targetType.hasRank() &&
        targetType.cast<MemRefType>().getLayout().isIdentity())
      return lowerToMemCopyIntrinsic(op, adaptor, rewriter);

    return lowerToMemCopyFunctionCall(op, adaptor, rewriter);
  }
};

/// Extracts allocated, aligned pointers and offset from a ranked or unranked
/// memref type. In the unranked case, the fields are extracted from the
/// underlying ranked descriptor.
static void extractPointersAndOffset(Location loc,
                                     ConversionPatternRewriter &rewriter,
                                     LLVMTypeConverter &typeConverter,
                                     Value originalOperand,
                                     Value convertedOperand,
                                     Value *allocatedPtr, Value *alignedPtr,
                                     Value *offset = nullptr) {
  Type operandType = originalOperand.getType();
  if (operandType.isa<MemRefType>()) {
    MemRefDescriptor desc(convertedOperand);
    *allocatedPtr = desc.allocatedPtr(rewriter, loc);
    *alignedPtr = desc.alignedPtr(rewriter, loc);
    if (offset != nullptr)
      *offset = desc.offset(rewriter, loc);
    return;
  }

  unsigned memorySpace =
      operandType.cast<UnrankedMemRefType>().getMemorySpaceAsInt();
  Type elementType = operandType.cast<UnrankedMemRefType>().getElementType();
  Type llvmElementType = typeConverter.convertType(elementType);
  Type elementPtrPtrType = LLVM::LLVMPointerType::get(
      LLVM::LLVMPointerType::get(llvmElementType, memorySpace));

  // Extract pointer to the underlying ranked memref descriptor and cast it to
  // ElemType**.
  UnrankedMemRefDescriptor unrankedDesc(convertedOperand);
  Value underlyingDescPtr = unrankedDesc.memRefDescPtr(rewriter, loc);

  *allocatedPtr = UnrankedMemRefDescriptor::allocatedPtr(
      rewriter, loc, underlyingDescPtr, elementPtrPtrType);
  *alignedPtr = UnrankedMemRefDescriptor::alignedPtr(
      rewriter, loc, typeConverter, underlyingDescPtr, elementPtrPtrType);
  if (offset != nullptr) {
    *offset = UnrankedMemRefDescriptor::offset(
        rewriter, loc, typeConverter, underlyingDescPtr, elementPtrPtrType);
  }
}

struct MemRefReinterpretCastOpLowering
    : public ConvertOpToLLVMPattern<memref::ReinterpretCastOp> {
  using ConvertOpToLLVMPattern<
      memref::ReinterpretCastOp>::ConvertOpToLLVMPattern;

  LogicalResult
  matchAndRewrite(memref::ReinterpretCastOp castOp, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    Type srcType = castOp.source().getType();

    Value descriptor;
    if (failed(convertSourceMemRefToDescriptor(rewriter, srcType, castOp,
                                               adaptor, &descriptor)))
      return failure();
    rewriter.replaceOp(castOp, {descriptor});
    return success();
  }

private:
  LogicalResult convertSourceMemRefToDescriptor(
      ConversionPatternRewriter &rewriter, Type srcType,
      memref::ReinterpretCastOp castOp,
      memref::ReinterpretCastOp::Adaptor adaptor, Value *descriptor) const {
    MemRefType targetMemRefType =
        castOp.getResult().getType().cast<MemRefType>();
    auto llvmTargetDescriptorTy = typeConverter->convertType(targetMemRefType)
                                      .dyn_cast_or_null<LLVM::LLVMStructType>();
    if (!llvmTargetDescriptorTy)
      return failure();

    // Create descriptor.
    Location loc = castOp.getLoc();
    auto desc = MemRefDescriptor::undef(rewriter, loc, llvmTargetDescriptorTy);

    // Set allocated and aligned pointers.
    Value allocatedPtr, alignedPtr;
    extractPointersAndOffset(loc, rewriter, *getTypeConverter(),
                             castOp.source(), adaptor.source(), &allocatedPtr,
                             &alignedPtr);
    desc.setAllocatedPtr(rewriter, loc, allocatedPtr);
    desc.setAlignedPtr(rewriter, loc, alignedPtr);

    // Set offset.
    if (castOp.isDynamicOffset(0))
      desc.setOffset(rewriter, loc, adaptor.offsets()[0]);
    else
      desc.setConstantOffset(rewriter, loc, castOp.getStaticOffset(0));

    // Set sizes and strides.
    unsigned dynSizeId = 0;
    unsigned dynStrideId = 0;
    for (unsigned i = 0, e = targetMemRefType.getRank(); i < e; ++i) {
      if (castOp.isDynamicSize(i))
        desc.setSize(rewriter, loc, i, adaptor.sizes()[dynSizeId++]);
      else
        desc.setConstantSize(rewriter, loc, i, castOp.getStaticSize(i));

      if (castOp.isDynamicStride(i))
        desc.setStride(rewriter, loc, i, adaptor.strides()[dynStrideId++]);
      else
        desc.setConstantStride(rewriter, loc, i, castOp.getStaticStride(i));
    }
    *descriptor = desc;
    return success();
  }
};

struct MemRefReshapeOpLowering
    : public ConvertOpToLLVMPattern<memref::ReshapeOp> {
  using ConvertOpToLLVMPattern<memref::ReshapeOp>::ConvertOpToLLVMPattern;

  LogicalResult
  matchAndRewrite(memref::ReshapeOp reshapeOp, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    Type srcType = reshapeOp.source().getType();

    Value descriptor;
    if (failed(convertSourceMemRefToDescriptor(rewriter, srcType, reshapeOp,
                                               adaptor, &descriptor)))
      return failure();
    rewriter.replaceOp(reshapeOp, {descriptor});
    return success();
  }

private:
  LogicalResult
  convertSourceMemRefToDescriptor(ConversionPatternRewriter &rewriter,
                                  Type srcType, memref::ReshapeOp reshapeOp,
                                  memref::ReshapeOp::Adaptor adaptor,
                                  Value *descriptor) const {
    // Conversion for statically-known shape args is performed via
    // `memref_reinterpret_cast`.
    auto shapeMemRefType = reshapeOp.shape().getType().cast<MemRefType>();
    if (shapeMemRefType.hasStaticShape())
      return failure();

    // The shape is a rank-1 tensor with unknown length.
    Location loc = reshapeOp.getLoc();
    MemRefDescriptor shapeDesc(adaptor.shape());
    Value resultRank = shapeDesc.size(rewriter, loc, 0);

    // Extract address space and element type.
    auto targetType =
        reshapeOp.getResult().getType().cast<UnrankedMemRefType>();
    unsigned addressSpace = targetType.getMemorySpaceAsInt();
    Type elementType = targetType.getElementType();

    // Create the unranked memref descriptor that holds the ranked one. The
    // inner descriptor is allocated on the stack.
    auto targetDesc = UnrankedMemRefDescriptor::undef(
        rewriter, loc, typeConverter->convertType(targetType));
    targetDesc.setRank(rewriter, loc, resultRank);
    SmallVector<Value, 4> sizes;
    UnrankedMemRefDescriptor::computeSizes(rewriter, loc, *getTypeConverter(),
                                           targetDesc, sizes);
    Value underlyingDescPtr = rewriter.create<LLVM::AllocaOp>(
        loc, getVoidPtrType(), sizes.front(), llvm::None);
    targetDesc.setMemRefDescPtr(rewriter, loc, underlyingDescPtr);

    // Extract pointers and offset from the source memref.
    Value allocatedPtr, alignedPtr, offset;
    extractPointersAndOffset(loc, rewriter, *getTypeConverter(),
                             reshapeOp.source(), adaptor.source(),
                             &allocatedPtr, &alignedPtr, &offset);

    // Set pointers and offset.
    Type llvmElementType = typeConverter->convertType(elementType);
    auto elementPtrPtrType = LLVM::LLVMPointerType::get(
        LLVM::LLVMPointerType::get(llvmElementType, addressSpace));
    UnrankedMemRefDescriptor::setAllocatedPtr(rewriter, loc, underlyingDescPtr,
                                              elementPtrPtrType, allocatedPtr);
    UnrankedMemRefDescriptor::setAlignedPtr(rewriter, loc, *getTypeConverter(),
                                            underlyingDescPtr,
                                            elementPtrPtrType, alignedPtr);
    UnrankedMemRefDescriptor::setOffset(rewriter, loc, *getTypeConverter(),
                                        underlyingDescPtr, elementPtrPtrType,
                                        offset);

    // Use the offset pointer as base for further addressing. Copy over the new
    // shape and compute strides. For this, we create a loop from rank - 1 to 0.
    Value targetSizesBase = UnrankedMemRefDescriptor::sizeBasePtr(
        rewriter, loc, *getTypeConverter(), underlyingDescPtr,
        elementPtrPtrType);
    Value targetStridesBase = UnrankedMemRefDescriptor::strideBasePtr(
        rewriter, loc, *getTypeConverter(), targetSizesBase, resultRank);
    Value shapeOperandPtr = shapeDesc.alignedPtr(rewriter, loc);
    Value oneIndex = createIndexConstant(rewriter, loc, 1);
    Value resultRankMinusOne =
        rewriter.create<LLVM::SubOp>(loc, resultRank, oneIndex);

    Block *initBlock = rewriter.getInsertionBlock();
    Type indexType = getTypeConverter()->getIndexType();
    Block::iterator remainingOpsIt = std::next(rewriter.getInsertionPoint());

    Block *condBlock = rewriter.createBlock(initBlock->getParent(), {},
                                            {indexType, indexType}, {loc, loc});

    // Move the remaining initBlock ops to condBlock.
    Block *remainingBlock = rewriter.splitBlock(initBlock, remainingOpsIt);
    rewriter.mergeBlocks(remainingBlock, condBlock, ValueRange());

    rewriter.setInsertionPointToEnd(initBlock);
    rewriter.create<LLVM::BrOp>(loc, ValueRange({resultRankMinusOne, oneIndex}),
                                condBlock);
    rewriter.setInsertionPointToStart(condBlock);
    Value indexArg = condBlock->getArgument(0);
    Value strideArg = condBlock->getArgument(1);

    Value zeroIndex = createIndexConstant(rewriter, loc, 0);
    Value pred = rewriter.create<LLVM::ICmpOp>(
        loc, IntegerType::get(rewriter.getContext(), 1),
        LLVM::ICmpPredicate::sge, indexArg, zeroIndex);

    Block *bodyBlock =
        rewriter.splitBlock(condBlock, rewriter.getInsertionPoint());
    rewriter.setInsertionPointToStart(bodyBlock);

    // Copy size from shape to descriptor.
    Type llvmIndexPtrType = LLVM::LLVMPointerType::get(indexType);
    Value sizeLoadGep = rewriter.create<LLVM::GEPOp>(
        loc, llvmIndexPtrType, shapeOperandPtr, ValueRange{indexArg});
    Value size = rewriter.create<LLVM::LoadOp>(loc, sizeLoadGep);
    UnrankedMemRefDescriptor::setSize(rewriter, loc, *getTypeConverter(),
                                      targetSizesBase, indexArg, size);

    // Write stride value and compute the next one.
    UnrankedMemRefDescriptor::setStride(rewriter, loc, *getTypeConverter(),
                                        targetStridesBase, indexArg, strideArg);
    Value nextStride = rewriter.create<LLVM::MulOp>(loc, strideArg, size);

    // Decrement loop counter and branch back.
    Value decrement = rewriter.create<LLVM::SubOp>(loc, indexArg, oneIndex);
    rewriter.create<LLVM::BrOp>(loc, ValueRange({decrement, nextStride}),
                                condBlock);

    Block *remainder =
        rewriter.splitBlock(bodyBlock, rewriter.getInsertionPoint());

    // Hook up the cond exit to the remainder.
    rewriter.setInsertionPointToEnd(condBlock);
    rewriter.create<LLVM::CondBrOp>(loc, pred, bodyBlock, llvm::None, remainder,
                                    llvm::None);

    // Reset position to the beginning of the new remainder block.
    rewriter.setInsertionPointToStart(remainder);

    *descriptor = targetDesc;
    return success();
  }
};

/// Helper function to convert a vector of `OpFoldResult`s into a vector of
/// `Value`s.
static SmallVector<Value> getAsValues(OpBuilder &b, Location loc,
                                      Type &llvmIndexType,
                                      ArrayRef<OpFoldResult> valueOrAttrVec) {
  return llvm::to_vector<4>(
      llvm::map_range(valueOrAttrVec, [&](OpFoldResult value) -> Value {
        if (auto attr = value.dyn_cast<Attribute>())
          return b.create<LLVM::ConstantOp>(loc, llvmIndexType, attr);
        return value.get<Value>();
      }));
}

/// Compute a map that, for a given dimension of the expanded type, gives the
/// dimension in the collapsed type it maps to. Essentially it's the inverse of
/// the `reassociation` maps.
static DenseMap<int64_t, int64_t>
getExpandedDimToCollapsedDimMap(ArrayRef<ReassociationIndices> reassociation) {
  llvm::DenseMap<int64_t, int64_t> expandedDimToCollapsedDim;
  for (auto &en : enumerate(reassociation)) {
    for (auto dim : en.value())
      expandedDimToCollapsedDim[dim] = en.index();
  }
  return expandedDimToCollapsedDim;
}

static OpFoldResult
getExpandedOutputDimSize(OpBuilder &b, Location loc, Type &llvmIndexType,
                         int64_t outDimIndex, ArrayRef<int64_t> outStaticShape,
                         MemRefDescriptor &inDesc,
                         ArrayRef<int64_t> inStaticShape,
                         ArrayRef<ReassociationIndices> reassocation,
                         DenseMap<int64_t, int64_t> &outDimToInDimMap) {
  int64_t outDimSize = outStaticShape[outDimIndex];
  if (!ShapedType::isDynamic(outDimSize))
    return b.getIndexAttr(outDimSize);

  // Calculate the product of all the output dim sizes except the current dim.
  int64_t inDimIndex = outDimToInDimMap[outDimIndex];
  int64_t otherDimSizesMul = 1;
  for (auto otherDimIndex : reassocation[inDimIndex]) {
    if (otherDimIndex == static_cast<unsigned>(outDimIndex))
      continue;
    int64_t otherDimSize = outStaticShape[otherDimIndex];
    assert(!ShapedType::isDynamic(otherDimSize) &&
           "single dimension cannot be expanded into multiple dynamic "
           "dimensions");
    otherDimSizesMul *= otherDimSize;
  }

  // outDimSize = inDimSize / otherOutDimSizesMul
  int64_t inDimSize = inStaticShape[inDimIndex];
  Value inDimSizeDynamic =
      ShapedType::isDynamic(inDimSize)
          ? inDesc.size(b, loc, inDimIndex)
          : b.create<LLVM::ConstantOp>(loc, llvmIndexType,
                                       b.getIndexAttr(inDimSize));
  Value outDimSizeDynamic = b.create<LLVM::SDivOp>(
      loc, inDimSizeDynamic,
      b.create<LLVM::ConstantOp>(loc, llvmIndexType,
                                 b.getIndexAttr(otherDimSizesMul)));
  return outDimSizeDynamic;
}

static OpFoldResult getCollapsedOutputDimSize(
    OpBuilder &b, Location loc, Type &llvmIndexType, int64_t outDimIndex,
    int64_t outDimSize, ArrayRef<int64_t> inStaticShape,
    MemRefDescriptor &inDesc, ArrayRef<ReassociationIndices> reassocation) {
  if (!ShapedType::isDynamic(outDimSize))
    return b.getIndexAttr(outDimSize);

  Value c1 = b.create<LLVM::ConstantOp>(loc, llvmIndexType, b.getIndexAttr(1));
  Value outDimSizeDynamic = c1;
  for (auto inDimIndex : reassocation[outDimIndex]) {
    int64_t inDimSize = inStaticShape[inDimIndex];
    Value inDimSizeDynamic =
        ShapedType::isDynamic(inDimSize)
            ? inDesc.size(b, loc, inDimIndex)
            : b.create<LLVM::ConstantOp>(loc, llvmIndexType,
                                         b.getIndexAttr(inDimSize));
    outDimSizeDynamic =
        b.create<LLVM::MulOp>(loc, outDimSizeDynamic, inDimSizeDynamic);
  }
  return outDimSizeDynamic;
}

static SmallVector<OpFoldResult, 4>
getCollapsedOutputShape(OpBuilder &b, Location loc, Type &llvmIndexType,
                        ArrayRef<ReassociationIndices> reassocation,
                        ArrayRef<int64_t> inStaticShape,
                        MemRefDescriptor &inDesc,
                        ArrayRef<int64_t> outStaticShape) {
  return llvm::to_vector<4>(llvm::map_range(
      llvm::seq<int64_t>(0, outStaticShape.size()), [&](int64_t outDimIndex) {
        return getCollapsedOutputDimSize(b, loc, llvmIndexType, outDimIndex,
                                         outStaticShape[outDimIndex],
                                         inStaticShape, inDesc, reassocation);
      }));
}

static SmallVector<OpFoldResult, 4>
getExpandedOutputShape(OpBuilder &b, Location loc, Type &llvmIndexType,
                       ArrayRef<ReassociationIndices> reassocation,
                       ArrayRef<int64_t> inStaticShape,
                       MemRefDescriptor &inDesc,
                       ArrayRef<int64_t> outStaticShape) {
  DenseMap<int64_t, int64_t> outDimToInDimMap =
      getExpandedDimToCollapsedDimMap(reassocation);
  return llvm::to_vector<4>(llvm::map_range(
      llvm::seq<int64_t>(0, outStaticShape.size()), [&](int64_t outDimIndex) {
        return getExpandedOutputDimSize(b, loc, llvmIndexType, outDimIndex,
                                        outStaticShape, inDesc, inStaticShape,
                                        reassocation, outDimToInDimMap);
      }));
}

static SmallVector<Value>
getDynamicOutputShape(OpBuilder &b, Location loc, Type &llvmIndexType,
                      ArrayRef<ReassociationIndices> reassocation,
                      ArrayRef<int64_t> inStaticShape, MemRefDescriptor &inDesc,
                      ArrayRef<int64_t> outStaticShape) {
  return outStaticShape.size() < inStaticShape.size()
             ? getAsValues(b, loc, llvmIndexType,
                           getCollapsedOutputShape(b, loc, llvmIndexType,
                                                   reassocation, inStaticShape,
                                                   inDesc, outStaticShape))
             : getAsValues(b, loc, llvmIndexType,
                           getExpandedOutputShape(b, loc, llvmIndexType,
                                                  reassocation, inStaticShape,
                                                  inDesc, outStaticShape));
}

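// To illustrate the helpers above: for a collapsing reshape with
// reassociation [[0, 1], [2]] from memref<?x4x?xf32> to memref<?x?xf32>, the
// first result size is computed as size(0) * 4 and the second as size(2),
// with dynamic sizes read from the source descriptor. The expanding direction
// instead divides the source size by the product of the other (static) sizes
// in the group, as done in getExpandedOutputDimSize.
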
// ReshapeOp creates a new view descriptor of the proper rank.
// For now, only memrefs with a fully static shape or an identity layout map
// are supported.
template <typename ReshapeOp>
class ReassociatingReshapeOpConversion
    : public ConvertOpToLLVMPattern<ReshapeOp> {
public:
  using ConvertOpToLLVMPattern<ReshapeOp>::ConvertOpToLLVMPattern;
  using ReshapeOpAdaptor = typename ReshapeOp::Adaptor;

  LogicalResult
  matchAndRewrite(ReshapeOp reshapeOp, typename ReshapeOp::Adaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    MemRefType dstType = reshapeOp.getResultType();
    MemRefType srcType = reshapeOp.getSrcType();

    // The condition on the layouts can be ignored when all shapes are static.
    if (!srcType.hasStaticShape() || !dstType.hasStaticShape()) {
      if (!srcType.getLayout().isIdentity() ||
          !dstType.getLayout().isIdentity()) {
        return rewriter.notifyMatchFailure(
            reshapeOp, "only empty layout map is supported");
      }
    }

    int64_t offset;
    SmallVector<int64_t, 4> strides;
    if (failed(getStridesAndOffset(dstType, strides, offset))) {
      return rewriter.notifyMatchFailure(
          reshapeOp, "failed to get stride and offset exprs");
    }

    MemRefDescriptor srcDesc(adaptor.src());
    Location loc = reshapeOp->getLoc();
    auto dstDesc = MemRefDescriptor::undef(
        rewriter, loc, this->typeConverter->convertType(dstType));
    dstDesc.setAllocatedPtr(rewriter, loc, srcDesc.allocatedPtr(rewriter, loc));
    dstDesc.setAlignedPtr(rewriter, loc, srcDesc.alignedPtr(rewriter, loc));
    dstDesc.setOffset(rewriter, loc, srcDesc.offset(rewriter, loc));

    ArrayRef<int64_t> srcStaticShape = srcType.getShape();
    ArrayRef<int64_t> dstStaticShape = dstType.getShape();
    Type llvmIndexType =
        this->typeConverter->convertType(rewriter.getIndexType());
    SmallVector<Value> dstShape = getDynamicOutputShape(
        rewriter, loc, llvmIndexType, reshapeOp.getReassociationIndices(),
        srcStaticShape, srcDesc, dstStaticShape);
    for (auto &en : llvm::enumerate(dstShape))
      dstDesc.setSize(rewriter, loc, en.index(), en.value());

    auto isStaticStride = [](int64_t stride) {
      return !ShapedType::isDynamicStrideOrOffset(stride);
    };
    if (llvm::all_of(strides, isStaticStride)) {
      for (auto &en : llvm::enumerate(strides))
        dstDesc.setConstantStride(rewriter, loc, en.index(), en.value());
    } else {
      Value c1 = rewriter.create<LLVM::ConstantOp>(loc, llvmIndexType,
                                                   rewriter.getIndexAttr(1));
      Value stride = c1;
      for (auto dimIndex :
           llvm::reverse(llvm::seq<int64_t>(0, dstShape.size()))) {
        dstDesc.setStride(rewriter, loc, dimIndex, stride);
        stride = rewriter.create<LLVM::MulOp>(loc, dstShape[dimIndex], stride);
      }
    }
    rewriter.replaceOp(reshapeOp, {dstDesc});
    return success();
  }
};

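// A sketch of the net effect of the pattern above for
// `%0 = memref.collapse_shape %arg [[0, 1]] : memref<4x?xf32> into memref<?xf32>`:
// the allocated/aligned pointers and the offset are copied from the source
// descriptor unchanged, the single result size is rebuilt as 4 * size(1), and
// the strides are either taken from the static result type or recomputed as a
// running product of the sizes.
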
/// Conversion pattern that transforms a subview op into:
///   1. An `llvm.mlir.undef` operation to create a memref descriptor.
///   2. Updates to the descriptor to introduce the data ptr, offset, size
///      and stride.
/// The subview op is replaced by the descriptor.
struct SubViewOpLowering : public ConvertOpToLLVMPattern<memref::SubViewOp> {
  using ConvertOpToLLVMPattern<memref::SubViewOp>::ConvertOpToLLVMPattern;

  LogicalResult
  matchAndRewrite(memref::SubViewOp subViewOp, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    auto loc = subViewOp.getLoc();

    auto sourceMemRefType = subViewOp.source().getType().cast<MemRefType>();
    auto sourceElementTy =
        typeConverter->convertType(sourceMemRefType.getElementType());

    auto viewMemRefType = subViewOp.getType();
    auto inferredType =
        memref::SubViewOp::inferResultType(
            subViewOp.getSourceType(),
            extractFromI64ArrayAttr(subViewOp.static_offsets()),
            extractFromI64ArrayAttr(subViewOp.static_sizes()),
            extractFromI64ArrayAttr(subViewOp.static_strides()))
            .cast<MemRefType>();
    auto targetElementTy =
        typeConverter->convertType(viewMemRefType.getElementType());
    auto targetDescTy = typeConverter->convertType(viewMemRefType);
    if (!sourceElementTy || !targetDescTy || !targetElementTy ||
        !LLVM::isCompatibleType(sourceElementTy) ||
        !LLVM::isCompatibleType(targetElementTy) ||
        !LLVM::isCompatibleType(targetDescTy))
      return failure();

    // Extract the offset and strides from the type.
    int64_t offset;
    SmallVector<int64_t, 4> strides;
    auto successStrides = getStridesAndOffset(inferredType, strides, offset);
    if (failed(successStrides))
      return failure();

    // Create the descriptor.
    if (!LLVM::isCompatibleType(adaptor.getOperands().front().getType()))
      return failure();
    MemRefDescriptor sourceMemRef(adaptor.getOperands().front());
    auto targetMemRef = MemRefDescriptor::undef(rewriter, loc, targetDescTy);

    // Copy the buffer pointer from the old descriptor to the new one.
    Value extracted = sourceMemRef.allocatedPtr(rewriter, loc);
    Value bitcastPtr = rewriter.create<LLVM::BitcastOp>(
        loc,
        LLVM::LLVMPointerType::get(targetElementTy,
                                   viewMemRefType.getMemorySpaceAsInt()),
        extracted);
    targetMemRef.setAllocatedPtr(rewriter, loc, bitcastPtr);

    // Copy the aligned pointer from the old descriptor to the new one.
    extracted = sourceMemRef.alignedPtr(rewriter, loc);
    bitcastPtr = rewriter.create<LLVM::BitcastOp>(
        loc,
        LLVM::LLVMPointerType::get(targetElementTy,
                                   viewMemRefType.getMemorySpaceAsInt()),
        extracted);
    targetMemRef.setAlignedPtr(rewriter, loc, bitcastPtr);

    size_t inferredShapeRank = inferredType.getRank();
    size_t resultShapeRank = viewMemRefType.getRank();

    // Extract strides needed to compute offset.
    SmallVector<Value, 4> strideValues;
    strideValues.reserve(inferredShapeRank);
    for (unsigned i = 0; i < inferredShapeRank; ++i)
      strideValues.push_back(sourceMemRef.stride(rewriter, loc, i));

    // Offset.
    auto llvmIndexType = typeConverter->convertType(rewriter.getIndexType());
    if (!ShapedType::isDynamicStrideOrOffset(offset)) {
      targetMemRef.setConstantOffset(rewriter, loc, offset);
    } else {
      Value baseOffset = sourceMemRef.offset(rewriter, loc);
      // `inferredShapeRank` may be larger than the number of offset operands
      // because of trailing semantics. In this case, the offset is guaranteed
      // to be interpreted as 0 and we can just skip the extra dimensions.
      for (unsigned i = 0, e = std::min(inferredShapeRank,
                                        subViewOp.getMixedOffsets().size());
           i < e; ++i) {
        Value offset =
            // TODO: need OpFoldResult ODS adaptor to clean this up.
            subViewOp.isDynamicOffset(i)
                ? adaptor.getOperands()[subViewOp.getIndexOfDynamicOffset(i)]
                : rewriter.create<LLVM::ConstantOp>(
                      loc, llvmIndexType,
                      rewriter.getI64IntegerAttr(subViewOp.getStaticOffset(i)));
        Value mul = rewriter.create<LLVM::MulOp>(loc, offset, strideValues[i]);
        baseOffset = rewriter.create<LLVM::AddOp>(loc, baseOffset, mul);
      }
      targetMemRef.setOffset(rewriter, loc, baseOffset);
    }

    // Update sizes and strides.
    SmallVector<OpFoldResult> mixedSizes = subViewOp.getMixedSizes();
    SmallVector<OpFoldResult> mixedStrides = subViewOp.getMixedStrides();
    assert(mixedSizes.size() == mixedStrides.size() &&
           "expected sizes and strides of equal length");
    llvm::SmallDenseSet<unsigned> unusedDims = subViewOp.getDroppedDims();
    for (int i = inferredShapeRank - 1, j = resultShapeRank - 1;
         i >= 0 && j >= 0; --i) {
      if (unusedDims.contains(i))
        continue;

      // `i` may exceed the number of entries in subViewOp.getMixedSizes()
      // because of trailing semantics. In that case, the size is the full
      // source dimension and the stride is 1.
      Value size, stride;
      if (static_cast<unsigned>(i) >= mixedSizes.size()) {
        // If the static size is available, use it directly. This is similar to
        // the folding of dim(constant-op) but removes the need for dim to be
        // aware of LLVM constants and for this pass to be aware of std
        // constants.
        int64_t staticSize =
            subViewOp.source().getType().cast<MemRefType>().getShape()[i];
        if (staticSize != ShapedType::kDynamicSize) {
          size = rewriter.create<LLVM::ConstantOp>(
              loc, llvmIndexType, rewriter.getI64IntegerAttr(staticSize));
        } else {
          Value pos = rewriter.create<LLVM::ConstantOp>(
              loc, llvmIndexType, rewriter.getI64IntegerAttr(i));
          Value dim =
              rewriter.create<memref::DimOp>(loc, subViewOp.source(), pos);
          auto cast = rewriter.create<UnrealizedConversionCastOp>(
              loc, llvmIndexType, dim);
          size = cast.getResult(0);
        }
        stride = rewriter.create<LLVM::ConstantOp>(
            loc, llvmIndexType, rewriter.getI64IntegerAttr(1));
      } else {
        // TODO: need OpFoldResult ODS adaptor to clean this up.
        size =
            subViewOp.isDynamicSize(i)
                ? adaptor.getOperands()[subViewOp.getIndexOfDynamicSize(i)]
                : rewriter.create<LLVM::ConstantOp>(
                      loc, llvmIndexType,
                      rewriter.getI64IntegerAttr(subViewOp.getStaticSize(i)));
        if (!ShapedType::isDynamicStrideOrOffset(strides[i])) {
          stride = rewriter.create<LLVM::ConstantOp>(
              loc, llvmIndexType, rewriter.getI64IntegerAttr(strides[i]));
        } else {
          stride =
              subViewOp.isDynamicStride(i)
                  ? adaptor.getOperands()[subViewOp.getIndexOfDynamicStride(i)]
                  : rewriter.create<LLVM::ConstantOp>(
                        loc, llvmIndexType,
                        rewriter.getI64IntegerAttr(
                            subViewOp.getStaticStride(i)));
          stride = rewriter.create<LLVM::MulOp>(loc, stride, strideValues[i]);
        }
      }
      targetMemRef.setSize(rewriter, loc, j, size);
      targetMemRef.setStride(rewriter, loc, j, stride);
      j--;
    }

    rewriter.replaceOp(subViewOp, {targetMemRef});
    return success();
  }
};

/// Conversion pattern that transforms a transpose op into:
///   1. An `llvm.mlir.undef` operation to create a memref descriptor.
///   2. Updates to the descriptor to introduce the data ptr, offset, sizes
///      and strides; sizes and strides are permutations of the original
///      values.
/// The transpose op is replaced by the descriptor.
class TransposeOpLowering : public ConvertOpToLLVMPattern<memref::TransposeOp> {
public:
  using ConvertOpToLLVMPattern<memref::TransposeOp>::ConvertOpToLLVMPattern;

  LogicalResult
  matchAndRewrite(memref::TransposeOp transposeOp, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    auto loc = transposeOp.getLoc();
    MemRefDescriptor viewMemRef(adaptor.in());

    // No permutation, early exit.
    if (transposeOp.permutation().isIdentity())
      return rewriter.replaceOp(transposeOp, {viewMemRef}), success();

    auto targetMemRef = MemRefDescriptor::undef(
        rewriter, loc, typeConverter->convertType(transposeOp.getShapedType()));

    // Copy the base and aligned pointers from the old descriptor to the new
    // one.
    targetMemRef.setAllocatedPtr(rewriter, loc,
                                 viewMemRef.allocatedPtr(rewriter, loc));
    targetMemRef.setAlignedPtr(rewriter, loc,
                               viewMemRef.alignedPtr(rewriter, loc));

    // Copy the offset from the old descriptor to the new one.
    targetMemRef.setOffset(rewriter, loc, viewMemRef.offset(rewriter, loc));

    // Iterate over the dimensions and apply size/stride permutation.
    for (const auto &en :
         llvm::enumerate(transposeOp.permutation().getResults())) {
      int sourcePos = en.index();
      int targetPos = en.value().cast<AffineDimExpr>().getPosition();
      targetMemRef.setSize(rewriter, loc, targetPos,
                           viewMemRef.size(rewriter, loc, sourcePos));
      targetMemRef.setStride(rewriter, loc, targetPos,
                             viewMemRef.stride(rewriter, loc, sourcePos));
    }

    rewriter.replaceOp(transposeOp, {targetMemRef});
    return success();
  }
};

/// Conversion pattern that transforms a view op into:
///   1. An `llvm.mlir.undef` operation to create a memref descriptor.
///   2. Updates to the descriptor to introduce the data ptr, offset, size
///      and stride.
/// The view op is replaced by the descriptor.
struct ViewOpLowering : public ConvertOpToLLVMPattern<memref::ViewOp> {
  using ConvertOpToLLVMPattern<memref::ViewOp>::ConvertOpToLLVMPattern;

  // Build and return the value for the idx^th shape dimension, either by
  // returning the constant shape dimension or by selecting the proper dynamic
  // size operand.
  Value getSize(ConversionPatternRewriter &rewriter, Location loc,
                ArrayRef<int64_t> shape, ValueRange dynamicSizes,
                unsigned idx) const {
    assert(idx < shape.size());
    if (!ShapedType::isDynamic(shape[idx]))
      return createIndexConstant(rewriter, loc, shape[idx]);
    // Count the number of dynamic dims in range [0, idx).
    unsigned nDynamic = llvm::count_if(shape.take_front(idx), [](int64_t v) {
      return ShapedType::isDynamic(v);
    });
    return dynamicSizes[nDynamic];
  }

  // Build and return the idx^th stride, either by returning the constant
  // stride or by computing the dynamic stride from the current
  // `runningStride` and `nextSize`.
  // The caller should keep a running stride and update it with the result
  // returned by this function.
  Value getStride(ConversionPatternRewriter &rewriter, Location loc,
                  ArrayRef<int64_t> strides, Value nextSize,
                  Value runningStride, unsigned idx) const {
    assert(idx < strides.size());
    if (!ShapedType::isDynamicStrideOrOffset(strides[idx]))
      return createIndexConstant(rewriter, loc, strides[idx]);
    if (nextSize)
      return runningStride
                 ? rewriter.create<LLVM::MulOp>(loc, runningStride, nextSize)
                 : nextSize;
    assert(!runningStride);
    return createIndexConstant(rewriter, loc, 1);
  }

  LogicalResult
  matchAndRewrite(memref::ViewOp viewOp, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    auto loc = viewOp.getLoc();

    auto viewMemRefType = viewOp.getType();
    auto targetElementTy =
        typeConverter->convertType(viewMemRefType.getElementType());
    auto targetDescTy = typeConverter->convertType(viewMemRefType);
    if (!targetDescTy || !targetElementTy ||
        !LLVM::isCompatibleType(targetElementTy) ||
        !LLVM::isCompatibleType(targetDescTy))
      return viewOp.emitWarning("Target descriptor type not converted to LLVM"),
             failure();

    int64_t offset;
    SmallVector<int64_t, 4> strides;
    auto successStrides = getStridesAndOffset(viewMemRefType, strides, offset);
    if (failed(successStrides))
      return viewOp.emitWarning("cannot cast to non-strided shape"), failure();
    assert(offset == 0 && "expected offset to be 0");

    // Create the descriptor.
    MemRefDescriptor sourceMemRef(adaptor.source());
    auto targetMemRef = MemRefDescriptor::undef(rewriter, loc, targetDescTy);

    // Field 1: Copy the allocated pointer, used for malloc/free.
    Value allocatedPtr = sourceMemRef.allocatedPtr(rewriter, loc);
    auto srcMemRefType = viewOp.source().getType().cast<MemRefType>();
    Value bitcastPtr = rewriter.create<LLVM::BitcastOp>(
        loc,
        LLVM::LLVMPointerType::get(targetElementTy,
                                   srcMemRefType.getMemorySpaceAsInt()),
        allocatedPtr);
    targetMemRef.setAllocatedPtr(rewriter, loc, bitcastPtr);

    // Field 2: Copy the actual aligned pointer to payload.
    Value alignedPtr = sourceMemRef.alignedPtr(rewriter, loc);
    alignedPtr = rewriter.create<LLVM::GEPOp>(loc, alignedPtr.getType(),
                                              alignedPtr, adaptor.byte_shift());
    bitcastPtr = rewriter.create<LLVM::BitcastOp>(
        loc,
        LLVM::LLVMPointerType::get(targetElementTy,
                                   srcMemRefType.getMemorySpaceAsInt()),
        alignedPtr);
    targetMemRef.setAlignedPtr(rewriter, loc, bitcastPtr);

    // Field 3: The offset in the resulting type must be 0. This is because of
    // the type change: an offset on srcType* may not be expressible as an
    // offset on dstType*.
    targetMemRef.setOffset(rewriter, loc,
                           createIndexConstant(rewriter, loc, offset));

    // Early exit for 0-D corner case.
    if (viewMemRefType.getRank() == 0)
      return rewriter.replaceOp(viewOp, {targetMemRef}), success();

    // Fields 4 and 5: Update sizes and strides.
    if (strides.back() != 1)
      return viewOp.emitWarning("cannot cast to non-contiguous shape"),
             failure();
    Value stride = nullptr, nextSize = nullptr;
    for (int i = viewMemRefType.getRank() - 1; i >= 0; --i) {
      // Update size.
      Value size =
          getSize(rewriter, loc, viewMemRefType.getShape(), adaptor.sizes(), i);
      targetMemRef.setSize(rewriter, loc, i, size);
      // Update stride.
      stride = getStride(rewriter, loc, strides, nextSize, stride, i);
      targetMemRef.setStride(rewriter, loc, i, stride);
      nextSize = size;
    }

    rewriter.replaceOp(viewOp, {targetMemRef});
    return success();
  }
};

//===----------------------------------------------------------------------===//
// AtomicRMWOpLowering
//===----------------------------------------------------------------------===//

/// Try to match the kind of a memref.atomic_rmw to determine whether to use a
/// lowering to llvm.atomicrmw or to fall back to llvm.cmpxchg.
static Optional<LLVM::AtomicBinOp>
matchSimpleAtomicOp(memref::AtomicRMWOp atomicOp) {
  switch (atomicOp.kind()) {
  case arith::AtomicRMWKind::addf:
    return LLVM::AtomicBinOp::fadd;
  case arith::AtomicRMWKind::addi:
    return LLVM::AtomicBinOp::add;
  case arith::AtomicRMWKind::assign:
    return LLVM::AtomicBinOp::xchg;
  case arith::AtomicRMWKind::maxs:
    return LLVM::AtomicBinOp::max;
  case arith::AtomicRMWKind::maxu:
    return LLVM::AtomicBinOp::umax;
  case arith::AtomicRMWKind::mins:
    return LLVM::AtomicBinOp::min;
  case arith::AtomicRMWKind::minu:
    return LLVM::AtomicBinOp::umin;
  case arith::AtomicRMWKind::ori:
    return LLVM::AtomicBinOp::_or;
  case arith::AtomicRMWKind::andi:
    return LLVM::AtomicBinOp::_and;
  default:
    return llvm::None;
  }
  llvm_unreachable("Invalid AtomicRMWKind");
}

struct AtomicRMWOpLowering : public LoadStoreOpLowering<memref::AtomicRMWOp> {
  using Base::Base;

  LogicalResult
  matchAndRewrite(memref::AtomicRMWOp atomicOp, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    if (failed(match(atomicOp)))
      return failure();
    auto maybeKind = matchSimpleAtomicOp(atomicOp);
    if (!maybeKind)
      return failure();
    auto resultType = adaptor.value().getType();
    auto memRefType = atomicOp.getMemRefType();
    auto dataPtr =
        getStridedElementPtr(atomicOp.getLoc(), memRefType, adaptor.memref(),
                             adaptor.indices(), rewriter);
    rewriter.replaceOpWithNewOp<LLVM::AtomicRMWOp>(
        atomicOp, resultType, *maybeKind, dataPtr, adaptor.value(),
        LLVM::AtomicOrdering::acq_rel);
    return success();
  }
};

} // namespace

void mlir::populateMemRefToLLVMConversionPatterns(LLVMTypeConverter &converter,
                                                  RewritePatternSet &patterns) {
  // clang-format off
  patterns.add<
      AllocaOpLowering,
      AllocaScopeOpLowering,
      AtomicRMWOpLowering,
      AssumeAlignmentOpLowering,
      DimOpLowering,
      GlobalMemrefOpLowering,
      GetGlobalMemrefOpLowering,
      LoadOpLowering,
      MemRefCastOpLowering,
      MemRefCopyOpLowering,
      MemRefReinterpretCastOpLowering,
      MemRefReshapeOpLowering,
      PrefetchOpLowering,
      RankOpLowering,
      ReassociatingReshapeOpConversion<memref::ExpandShapeOp>,
      ReassociatingReshapeOpConversion<memref::CollapseShapeOp>,
      StoreOpLowering,
      SubViewOpLowering,
      TransposeOpLowering,
      ViewOpLowering>(converter);
  // clang-format on
  auto allocLowering = converter.getOptions().allocLowering;
  if (allocLowering == LowerToLLVMOptions::AllocLowering::AlignedAlloc)
    patterns.add<AlignedAllocOpLowering, DeallocOpLowering>(converter);
  else if (allocLowering == LowerToLLVMOptions::AllocLowering::Malloc)
    patterns.add<AllocOpLowering, DeallocOpLowering>(converter);
}

namespace {
struct MemRefToLLVMPass : public ConvertMemRefToLLVMBase<MemRefToLLVMPass> {
  MemRefToLLVMPass() = default;

  void runOnOperation() override {
    Operation *op = getOperation();
    const auto &dataLayoutAnalysis = getAnalysis<DataLayoutAnalysis>();
    LowerToLLVMOptions options(&getContext(),
                               dataLayoutAnalysis.getAtOrAbove(op));
    options.allocLowering =
        (useAlignedAlloc ? LowerToLLVMOptions::AllocLowering::AlignedAlloc
                         : LowerToLLVMOptions::AllocLowering::Malloc);
    if (indexBitwidth != kDeriveIndexBitwidthFromDataLayout)
      options.overrideIndexBitwidth(indexBitwidth);

    LLVMTypeConverter typeConverter(&getContext(), options,
                                    &dataLayoutAnalysis);
    RewritePatternSet patterns(&getContext());
    populateMemRefToLLVMConversionPatterns(typeConverter, patterns);
    LLVMConversionTarget target(getContext());
    target.addLegalOp<FuncOp>();
    if (failed(applyPartialConversion(op, target, std::move(patterns))))
      signalPassFailure();
  }
};
} // namespace

std::unique_ptr<Pass> mlir::createMemRefToLLVMPass() {
  return std::make_unique<MemRefToLLVMPass>();
}
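
// Example usage (a sketch; the pass argument and option spellings below are
// assumed to match this pass's tablegen definition and are not restated in
// this file):
//
//   mlir-opt --convert-memref-to-llvm='use-aligned-alloc=1 index-bitwidth=32' \
//       input.mlir
//
// Programmatic clients can instead add `createMemRefToLLVMPass()` to a pass
// manager; by default the index bitwidth is then derived from the module's
// data layout.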