175e5f0aaSAlex Zinenko //===- MemRefToLLVM.cpp - MemRef to LLVM dialect conversion ---------------===//
275e5f0aaSAlex Zinenko //
375e5f0aaSAlex Zinenko // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
475e5f0aaSAlex Zinenko // See https://llvm.org/LICENSE.txt for license information.
575e5f0aaSAlex Zinenko // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
675e5f0aaSAlex Zinenko //
775e5f0aaSAlex Zinenko //===----------------------------------------------------------------------===//
875e5f0aaSAlex Zinenko
975e5f0aaSAlex Zinenko #include "mlir/Conversion/MemRefToLLVM/MemRefToLLVM.h"
1075e5f0aaSAlex Zinenko #include "../PassDetail.h"
1175e5f0aaSAlex Zinenko #include "mlir/Analysis/DataLayoutAnalysis.h"
1275e5f0aaSAlex Zinenko #include "mlir/Conversion/LLVMCommon/ConversionTarget.h"
1375e5f0aaSAlex Zinenko #include "mlir/Conversion/LLVMCommon/Pattern.h"
1475e5f0aaSAlex Zinenko #include "mlir/Conversion/LLVMCommon/TypeConverter.h"
1575e5f0aaSAlex Zinenko #include "mlir/Conversion/MemRefToLLVM/AllocLikeConversion.h"
16eda6f907SRiver Riddle #include "mlir/Dialect/Arithmetic/IR/Arithmetic.h"
1736550692SRiver Riddle #include "mlir/Dialect/Func/IR/FuncOps.h"
1875e5f0aaSAlex Zinenko #include "mlir/Dialect/LLVMIR/FunctionCallUtils.h"
1975e5f0aaSAlex Zinenko #include "mlir/Dialect/LLVMIR/LLVMDialect.h"
2075e5f0aaSAlex Zinenko #include "mlir/Dialect/MemRef/IR/MemRef.h"
2175e5f0aaSAlex Zinenko #include "mlir/IR/AffineMap.h"
2275e5f0aaSAlex Zinenko #include "mlir/IR/BlockAndValueMapping.h"
236635c12aSBenjamin Kramer #include "llvm/ADT/SmallBitVector.h"
2475e5f0aaSAlex Zinenko
2575e5f0aaSAlex Zinenko using namespace mlir;
2675e5f0aaSAlex Zinenko
2775e5f0aaSAlex Zinenko namespace {
2875e5f0aaSAlex Zinenko
isStaticStrideOrOffset(int64_t strideOrOffset)295380e30eSAshay Rane bool isStaticStrideOrOffset(int64_t strideOrOffset) {
305380e30eSAshay Rane return !ShapedType::isDynamicStrideOrOffset(strideOrOffset);
315380e30eSAshay Rane }
325380e30eSAshay Rane
3375e5f0aaSAlex Zinenko struct AllocOpLowering : public AllocLikeOpLLVMLowering {
AllocOpLowering__anon7a9e10510111::AllocOpLowering3475e5f0aaSAlex Zinenko AllocOpLowering(LLVMTypeConverter &converter)
3575e5f0aaSAlex Zinenko : AllocLikeOpLLVMLowering(memref::AllocOp::getOperationName(),
3675e5f0aaSAlex Zinenko converter) {}
3775e5f0aaSAlex Zinenko
getAllocFn__anon7a9e10510111::AllocOpLowering38a8601f11SMichele Scuttari LLVM::LLVMFuncOp getAllocFn(ModuleOp module) const {
39a8601f11SMichele Scuttari bool useGenericFn = getTypeConverter()->getOptions().useGenericFunctions;
40a8601f11SMichele Scuttari
41a8601f11SMichele Scuttari if (useGenericFn)
42a8601f11SMichele Scuttari return LLVM::lookupOrCreateGenericAllocFn(module, getIndexType());
43a8601f11SMichele Scuttari
44a8601f11SMichele Scuttari return LLVM::lookupOrCreateMallocFn(module, getIndexType());
45a8601f11SMichele Scuttari }
46a8601f11SMichele Scuttari
allocateBuffer__anon7a9e10510111::AllocOpLowering4775e5f0aaSAlex Zinenko std::tuple<Value, Value> allocateBuffer(ConversionPatternRewriter &rewriter,
4875e5f0aaSAlex Zinenko Location loc, Value sizeBytes,
4975e5f0aaSAlex Zinenko Operation *op) const override {
5075e5f0aaSAlex Zinenko // Heap allocations.
5175e5f0aaSAlex Zinenko memref::AllocOp allocOp = cast<memref::AllocOp>(op);
5275e5f0aaSAlex Zinenko MemRefType memRefType = allocOp.getType();
5375e5f0aaSAlex Zinenko
5475e5f0aaSAlex Zinenko Value alignment;
55136d746eSJacques Pienaar if (auto alignmentAttr = allocOp.getAlignment()) {
5675e5f0aaSAlex Zinenko alignment = createIndexConstant(rewriter, loc, *alignmentAttr);
5775e5f0aaSAlex Zinenko } else if (!memRefType.getElementType().isSignlessIntOrIndexOrFloat()) {
5875e5f0aaSAlex Zinenko // In the case where no alignment is specified, we may want to override
5975e5f0aaSAlex Zinenko // `malloc's` behavior. `malloc` typically aligns at the size of the
6075e5f0aaSAlex Zinenko // biggest scalar on a target HW. For non-scalars, use the natural
6175e5f0aaSAlex Zinenko // alignment of the LLVM type given by the LLVM DataLayout.
6275e5f0aaSAlex Zinenko alignment = getSizeInBytes(loc, memRefType.getElementType(), rewriter);
6375e5f0aaSAlex Zinenko }
6475e5f0aaSAlex Zinenko
6575e5f0aaSAlex Zinenko if (alignment) {
6675e5f0aaSAlex Zinenko // Adjust the allocation size to consider alignment.
6775e5f0aaSAlex Zinenko sizeBytes = rewriter.create<LLVM::AddOp>(loc, sizeBytes, alignment);
6875e5f0aaSAlex Zinenko }
6975e5f0aaSAlex Zinenko
7075e5f0aaSAlex Zinenko // Allocate the underlying buffer and store a pointer to it in the MemRef
7175e5f0aaSAlex Zinenko // descriptor.
7275e5f0aaSAlex Zinenko Type elementPtrType = this->getElementPtrType(memRefType);
73a8601f11SMichele Scuttari auto allocFuncOp = getAllocFn(allocOp->getParentOfType<ModuleOp>());
7475e5f0aaSAlex Zinenko auto results = createLLVMCall(rewriter, loc, allocFuncOp, {sizeBytes},
7575e5f0aaSAlex Zinenko getVoidPtrType());
7675e5f0aaSAlex Zinenko Value allocatedPtr =
7775e5f0aaSAlex Zinenko rewriter.create<LLVM::BitcastOp>(loc, elementPtrType, results[0]);
7875e5f0aaSAlex Zinenko
7975e5f0aaSAlex Zinenko Value alignedPtr = allocatedPtr;
8075e5f0aaSAlex Zinenko if (alignment) {
8175e5f0aaSAlex Zinenko // Compute the aligned type pointer.
8275e5f0aaSAlex Zinenko Value allocatedInt =
8375e5f0aaSAlex Zinenko rewriter.create<LLVM::PtrToIntOp>(loc, getIndexType(), allocatedPtr);
8475e5f0aaSAlex Zinenko Value alignmentInt =
8575e5f0aaSAlex Zinenko createAligned(rewriter, loc, allocatedInt, alignment);
8675e5f0aaSAlex Zinenko alignedPtr =
8775e5f0aaSAlex Zinenko rewriter.create<LLVM::IntToPtrOp>(loc, elementPtrType, alignmentInt);
8875e5f0aaSAlex Zinenko }
8975e5f0aaSAlex Zinenko
9075e5f0aaSAlex Zinenko return std::make_tuple(allocatedPtr, alignedPtr);
9175e5f0aaSAlex Zinenko }
9275e5f0aaSAlex Zinenko };
9375e5f0aaSAlex Zinenko
9475e5f0aaSAlex Zinenko struct AlignedAllocOpLowering : public AllocLikeOpLLVMLowering {
AlignedAllocOpLowering__anon7a9e10510111::AlignedAllocOpLowering9575e5f0aaSAlex Zinenko AlignedAllocOpLowering(LLVMTypeConverter &converter)
9675e5f0aaSAlex Zinenko : AllocLikeOpLLVMLowering(memref::AllocOp::getOperationName(),
9775e5f0aaSAlex Zinenko converter) {}
9875e5f0aaSAlex Zinenko
9975e5f0aaSAlex Zinenko /// Returns the memref's element size in bytes using the data layout active at
10075e5f0aaSAlex Zinenko /// `op`.
10175e5f0aaSAlex Zinenko // TODO: there are other places where this is used. Expose publicly?
getMemRefEltSizeInBytes__anon7a9e10510111::AlignedAllocOpLowering10275e5f0aaSAlex Zinenko unsigned getMemRefEltSizeInBytes(MemRefType memRefType, Operation *op) const {
10375e5f0aaSAlex Zinenko const DataLayout *layout = &defaultLayout;
10475e5f0aaSAlex Zinenko if (const DataLayoutAnalysis *analysis =
10575e5f0aaSAlex Zinenko getTypeConverter()->getDataLayoutAnalysis()) {
10675e5f0aaSAlex Zinenko layout = &analysis->getAbove(op);
10775e5f0aaSAlex Zinenko }
10875e5f0aaSAlex Zinenko Type elementType = memRefType.getElementType();
10975e5f0aaSAlex Zinenko if (auto memRefElementType = elementType.dyn_cast<MemRefType>())
11075e5f0aaSAlex Zinenko return getTypeConverter()->getMemRefDescriptorSize(memRefElementType,
11175e5f0aaSAlex Zinenko *layout);
11275e5f0aaSAlex Zinenko if (auto memRefElementType = elementType.dyn_cast<UnrankedMemRefType>())
11375e5f0aaSAlex Zinenko return getTypeConverter()->getUnrankedMemRefDescriptorSize(
11475e5f0aaSAlex Zinenko memRefElementType, *layout);
11575e5f0aaSAlex Zinenko return layout->getTypeSize(elementType);
11675e5f0aaSAlex Zinenko }
11775e5f0aaSAlex Zinenko
11875e5f0aaSAlex Zinenko /// Returns true if the memref size in bytes is known to be a multiple of
11975e5f0aaSAlex Zinenko /// factor assuming the data layout active at `op`.
isMemRefSizeMultipleOf__anon7a9e10510111::AlignedAllocOpLowering12075e5f0aaSAlex Zinenko bool isMemRefSizeMultipleOf(MemRefType type, uint64_t factor,
12175e5f0aaSAlex Zinenko Operation *op) const {
12275e5f0aaSAlex Zinenko uint64_t sizeDivisor = getMemRefEltSizeInBytes(type, op);
12375e5f0aaSAlex Zinenko for (unsigned i = 0, e = type.getRank(); i < e; i++) {
124676bfb2aSRiver Riddle if (ShapedType::isDynamic(type.getDimSize(i)))
12575e5f0aaSAlex Zinenko continue;
12675e5f0aaSAlex Zinenko sizeDivisor = sizeDivisor * type.getDimSize(i);
12775e5f0aaSAlex Zinenko }
12875e5f0aaSAlex Zinenko return sizeDivisor % factor == 0;
12975e5f0aaSAlex Zinenko }
13075e5f0aaSAlex Zinenko
13175e5f0aaSAlex Zinenko /// Returns the alignment to be used for the allocation call itself.
13275e5f0aaSAlex Zinenko /// aligned_alloc requires the allocation size to be a power of two, and the
13375e5f0aaSAlex Zinenko /// allocation size to be a multiple of alignment,
getAllocationAlignment__anon7a9e10510111::AlignedAllocOpLowering13475e5f0aaSAlex Zinenko int64_t getAllocationAlignment(memref::AllocOp allocOp) const {
135136d746eSJacques Pienaar if (Optional<uint64_t> alignment = allocOp.getAlignment())
13675e5f0aaSAlex Zinenko return *alignment;
13775e5f0aaSAlex Zinenko
13875e5f0aaSAlex Zinenko // Whenever we don't have alignment set, we will use an alignment
13975e5f0aaSAlex Zinenko // consistent with the element type; since the allocation size has to be a
14075e5f0aaSAlex Zinenko // power of two, we will bump to the next power of two if it already isn't.
14175e5f0aaSAlex Zinenko auto eltSizeBytes = getMemRefEltSizeInBytes(allocOp.getType(), allocOp);
14275e5f0aaSAlex Zinenko return std::max(kMinAlignedAllocAlignment,
14375e5f0aaSAlex Zinenko llvm::PowerOf2Ceil(eltSizeBytes));
14475e5f0aaSAlex Zinenko }
14575e5f0aaSAlex Zinenko
getAllocFn__anon7a9e10510111::AlignedAllocOpLowering146a8601f11SMichele Scuttari LLVM::LLVMFuncOp getAllocFn(ModuleOp module) const {
147a8601f11SMichele Scuttari bool useGenericFn = getTypeConverter()->getOptions().useGenericFunctions;
148a8601f11SMichele Scuttari
149a8601f11SMichele Scuttari if (useGenericFn)
150a8601f11SMichele Scuttari return LLVM::lookupOrCreateGenericAlignedAllocFn(module, getIndexType());
151a8601f11SMichele Scuttari
152a8601f11SMichele Scuttari return LLVM::lookupOrCreateAlignedAllocFn(module, getIndexType());
153a8601f11SMichele Scuttari }
154a8601f11SMichele Scuttari
allocateBuffer__anon7a9e10510111::AlignedAllocOpLowering15575e5f0aaSAlex Zinenko std::tuple<Value, Value> allocateBuffer(ConversionPatternRewriter &rewriter,
15675e5f0aaSAlex Zinenko Location loc, Value sizeBytes,
15775e5f0aaSAlex Zinenko Operation *op) const override {
15875e5f0aaSAlex Zinenko // Heap allocations.
15975e5f0aaSAlex Zinenko memref::AllocOp allocOp = cast<memref::AllocOp>(op);
16075e5f0aaSAlex Zinenko MemRefType memRefType = allocOp.getType();
16175e5f0aaSAlex Zinenko int64_t alignment = getAllocationAlignment(allocOp);
16275e5f0aaSAlex Zinenko Value allocAlignment = createIndexConstant(rewriter, loc, alignment);
16375e5f0aaSAlex Zinenko
16475e5f0aaSAlex Zinenko // aligned_alloc requires size to be a multiple of alignment; we will pad
16575e5f0aaSAlex Zinenko // the size to the next multiple if necessary.
16675e5f0aaSAlex Zinenko if (!isMemRefSizeMultipleOf(memRefType, alignment, op))
16775e5f0aaSAlex Zinenko sizeBytes = createAligned(rewriter, loc, sizeBytes, allocAlignment);
16875e5f0aaSAlex Zinenko
16975e5f0aaSAlex Zinenko Type elementPtrType = this->getElementPtrType(memRefType);
170a8601f11SMichele Scuttari auto allocFuncOp = getAllocFn(allocOp->getParentOfType<ModuleOp>());
17175e5f0aaSAlex Zinenko auto results =
17275e5f0aaSAlex Zinenko createLLVMCall(rewriter, loc, allocFuncOp, {allocAlignment, sizeBytes},
17375e5f0aaSAlex Zinenko getVoidPtrType());
17475e5f0aaSAlex Zinenko Value allocatedPtr =
17575e5f0aaSAlex Zinenko rewriter.create<LLVM::BitcastOp>(loc, elementPtrType, results[0]);
17675e5f0aaSAlex Zinenko
17775e5f0aaSAlex Zinenko return std::make_tuple(allocatedPtr, allocatedPtr);
17875e5f0aaSAlex Zinenko }
17975e5f0aaSAlex Zinenko
18075e5f0aaSAlex Zinenko /// The minimum alignment to use with aligned_alloc (has to be a power of 2).
18175e5f0aaSAlex Zinenko static constexpr uint64_t kMinAlignedAllocAlignment = 16UL;
18275e5f0aaSAlex Zinenko
18375e5f0aaSAlex Zinenko /// Default layout to use in absence of the corresponding analysis.
18475e5f0aaSAlex Zinenko DataLayout defaultLayout;
18575e5f0aaSAlex Zinenko };
18675e5f0aaSAlex Zinenko
18775e5f0aaSAlex Zinenko // Out of line definition, required till C++17.
18875e5f0aaSAlex Zinenko constexpr uint64_t AlignedAllocOpLowering::kMinAlignedAllocAlignment;
18975e5f0aaSAlex Zinenko
19075e5f0aaSAlex Zinenko struct AllocaOpLowering : public AllocLikeOpLLVMLowering {
AllocaOpLowering__anon7a9e10510111::AllocaOpLowering19175e5f0aaSAlex Zinenko AllocaOpLowering(LLVMTypeConverter &converter)
19275e5f0aaSAlex Zinenko : AllocLikeOpLLVMLowering(memref::AllocaOp::getOperationName(),
19375e5f0aaSAlex Zinenko converter) {}
19475e5f0aaSAlex Zinenko
19575e5f0aaSAlex Zinenko /// Allocates the underlying buffer using the right call. `allocatedBytePtr`
19675e5f0aaSAlex Zinenko /// is set to null for stack allocations. `accessAlignment` is set if
19775e5f0aaSAlex Zinenko /// alignment is needed post allocation (for eg. in conjunction with malloc).
allocateBuffer__anon7a9e10510111::AllocaOpLowering19875e5f0aaSAlex Zinenko std::tuple<Value, Value> allocateBuffer(ConversionPatternRewriter &rewriter,
19975e5f0aaSAlex Zinenko Location loc, Value sizeBytes,
20075e5f0aaSAlex Zinenko Operation *op) const override {
20175e5f0aaSAlex Zinenko
20275e5f0aaSAlex Zinenko // With alloca, one gets a pointer to the element type right away.
20375e5f0aaSAlex Zinenko // For stack allocations.
20475e5f0aaSAlex Zinenko auto allocaOp = cast<memref::AllocaOp>(op);
20575e5f0aaSAlex Zinenko auto elementPtrType = this->getElementPtrType(allocaOp.getType());
20675e5f0aaSAlex Zinenko
20775e5f0aaSAlex Zinenko auto allocatedElementPtr = rewriter.create<LLVM::AllocaOp>(
208*2789c4f5SKazu Hirata loc, elementPtrType, sizeBytes, allocaOp.getAlignment().value_or(0));
20975e5f0aaSAlex Zinenko
21075e5f0aaSAlex Zinenko return std::make_tuple(allocatedElementPtr, allocatedElementPtr);
21175e5f0aaSAlex Zinenko }
21275e5f0aaSAlex Zinenko };
21375e5f0aaSAlex Zinenko
21475e5f0aaSAlex Zinenko struct AllocaScopeOpLowering
21575e5f0aaSAlex Zinenko : public ConvertOpToLLVMPattern<memref::AllocaScopeOp> {
21675e5f0aaSAlex Zinenko using ConvertOpToLLVMPattern<memref::AllocaScopeOp>::ConvertOpToLLVMPattern;
21775e5f0aaSAlex Zinenko
21875e5f0aaSAlex Zinenko LogicalResult
matchAndRewrite__anon7a9e10510111::AllocaScopeOpLowering219ef976337SRiver Riddle matchAndRewrite(memref::AllocaScopeOp allocaScopeOp, OpAdaptor adaptor,
22075e5f0aaSAlex Zinenko ConversionPatternRewriter &rewriter) const override {
22175e5f0aaSAlex Zinenko OpBuilder::InsertionGuard guard(rewriter);
22275e5f0aaSAlex Zinenko Location loc = allocaScopeOp.getLoc();
22375e5f0aaSAlex Zinenko
22475e5f0aaSAlex Zinenko // Split the current block before the AllocaScopeOp to create the inlining
22575e5f0aaSAlex Zinenko // point.
22675e5f0aaSAlex Zinenko auto *currentBlock = rewriter.getInsertionBlock();
22775e5f0aaSAlex Zinenko auto *remainingOpsBlock =
22875e5f0aaSAlex Zinenko rewriter.splitBlock(currentBlock, rewriter.getInsertionPoint());
22975e5f0aaSAlex Zinenko Block *continueBlock;
23075e5f0aaSAlex Zinenko if (allocaScopeOp.getNumResults() == 0) {
23175e5f0aaSAlex Zinenko continueBlock = remainingOpsBlock;
23275e5f0aaSAlex Zinenko } else {
233e084679fSRiver Riddle continueBlock = rewriter.createBlock(
234e084679fSRiver Riddle remainingOpsBlock, allocaScopeOp.getResultTypes(),
235e084679fSRiver Riddle SmallVector<Location>(allocaScopeOp->getNumResults(),
236e084679fSRiver Riddle allocaScopeOp.getLoc()));
23775e5f0aaSAlex Zinenko rewriter.create<LLVM::BrOp>(loc, ValueRange(), remainingOpsBlock);
23875e5f0aaSAlex Zinenko }
23975e5f0aaSAlex Zinenko
24075e5f0aaSAlex Zinenko // Inline body region.
241136d746eSJacques Pienaar Block *beforeBody = &allocaScopeOp.getBodyRegion().front();
242136d746eSJacques Pienaar Block *afterBody = &allocaScopeOp.getBodyRegion().back();
243136d746eSJacques Pienaar rewriter.inlineRegionBefore(allocaScopeOp.getBodyRegion(), continueBlock);
24475e5f0aaSAlex Zinenko
24575e5f0aaSAlex Zinenko // Save stack and then branch into the body of the region.
24675e5f0aaSAlex Zinenko rewriter.setInsertionPointToEnd(currentBlock);
24775e5f0aaSAlex Zinenko auto stackSaveOp =
24875e5f0aaSAlex Zinenko rewriter.create<LLVM::StackSaveOp>(loc, getVoidPtrType());
24975e5f0aaSAlex Zinenko rewriter.create<LLVM::BrOp>(loc, ValueRange(), beforeBody);
25075e5f0aaSAlex Zinenko
25175e5f0aaSAlex Zinenko // Replace the alloca_scope return with a branch that jumps out of the body.
25275e5f0aaSAlex Zinenko // Stack restore before leaving the body region.
25375e5f0aaSAlex Zinenko rewriter.setInsertionPointToEnd(afterBody);
25475e5f0aaSAlex Zinenko auto returnOp =
25575e5f0aaSAlex Zinenko cast<memref::AllocaScopeReturnOp>(afterBody->getTerminator());
25675e5f0aaSAlex Zinenko auto branchOp = rewriter.replaceOpWithNewOp<LLVM::BrOp>(
257136d746eSJacques Pienaar returnOp, returnOp.getResults(), continueBlock);
25875e5f0aaSAlex Zinenko
25975e5f0aaSAlex Zinenko // Insert stack restore before jumping out the body of the region.
26075e5f0aaSAlex Zinenko rewriter.setInsertionPoint(branchOp);
26175e5f0aaSAlex Zinenko rewriter.create<LLVM::StackRestoreOp>(loc, stackSaveOp);
26275e5f0aaSAlex Zinenko
26375e5f0aaSAlex Zinenko // Replace the op with values return from the body region.
26475e5f0aaSAlex Zinenko rewriter.replaceOp(allocaScopeOp, continueBlock->getArguments());
26575e5f0aaSAlex Zinenko
26675e5f0aaSAlex Zinenko return success();
26775e5f0aaSAlex Zinenko }
26875e5f0aaSAlex Zinenko };
26975e5f0aaSAlex Zinenko
27075e5f0aaSAlex Zinenko struct AssumeAlignmentOpLowering
27175e5f0aaSAlex Zinenko : public ConvertOpToLLVMPattern<memref::AssumeAlignmentOp> {
27275e5f0aaSAlex Zinenko using ConvertOpToLLVMPattern<
27375e5f0aaSAlex Zinenko memref::AssumeAlignmentOp>::ConvertOpToLLVMPattern;
27475e5f0aaSAlex Zinenko
27575e5f0aaSAlex Zinenko LogicalResult
matchAndRewrite__anon7a9e10510111::AssumeAlignmentOpLowering276ef976337SRiver Riddle matchAndRewrite(memref::AssumeAlignmentOp op, OpAdaptor adaptor,
27775e5f0aaSAlex Zinenko ConversionPatternRewriter &rewriter) const override {
278136d746eSJacques Pienaar Value memref = adaptor.getMemref();
279136d746eSJacques Pienaar unsigned alignment = op.getAlignment();
28075e5f0aaSAlex Zinenko auto loc = op.getLoc();
28175e5f0aaSAlex Zinenko
28275e5f0aaSAlex Zinenko MemRefDescriptor memRefDescriptor(memref);
28375e5f0aaSAlex Zinenko Value ptr = memRefDescriptor.alignedPtr(rewriter, memref.getLoc());
28475e5f0aaSAlex Zinenko
28575e5f0aaSAlex Zinenko // Emit llvm.assume(memref.alignedPtr & (alignment - 1) == 0). Notice that
28675e5f0aaSAlex Zinenko // the asserted memref.alignedPtr isn't used anywhere else, as the real
28775e5f0aaSAlex Zinenko // users like load/store/views always re-extract memref.alignedPtr as they
28875e5f0aaSAlex Zinenko // get lowered.
28975e5f0aaSAlex Zinenko //
29075e5f0aaSAlex Zinenko // This relies on LLVM's CSE optimization (potentially after SROA), since
29175e5f0aaSAlex Zinenko // after CSE all memref.alignedPtr instances get de-duplicated into the same
29275e5f0aaSAlex Zinenko // pointer SSA value.
29375e5f0aaSAlex Zinenko auto intPtrType =
29475e5f0aaSAlex Zinenko getIntPtrType(memRefDescriptor.getElementPtrType().getAddressSpace());
29575e5f0aaSAlex Zinenko Value zero = createIndexAttrConstant(rewriter, loc, intPtrType, 0);
29675e5f0aaSAlex Zinenko Value mask =
29775e5f0aaSAlex Zinenko createIndexAttrConstant(rewriter, loc, intPtrType, alignment - 1);
29875e5f0aaSAlex Zinenko Value ptrValue = rewriter.create<LLVM::PtrToIntOp>(loc, intPtrType, ptr);
29975e5f0aaSAlex Zinenko rewriter.create<LLVM::AssumeOp>(
30075e5f0aaSAlex Zinenko loc, rewriter.create<LLVM::ICmpOp>(
30175e5f0aaSAlex Zinenko loc, LLVM::ICmpPredicate::eq,
30275e5f0aaSAlex Zinenko rewriter.create<LLVM::AndOp>(loc, ptrValue, mask), zero));
30375e5f0aaSAlex Zinenko
30475e5f0aaSAlex Zinenko rewriter.eraseOp(op);
30575e5f0aaSAlex Zinenko return success();
30675e5f0aaSAlex Zinenko }
30775e5f0aaSAlex Zinenko };
30875e5f0aaSAlex Zinenko
30975e5f0aaSAlex Zinenko // A `dealloc` is converted into a call to `free` on the underlying data buffer.
31075e5f0aaSAlex Zinenko // The memref descriptor being an SSA value, there is no need to clean it up
31175e5f0aaSAlex Zinenko // in any way.
31275e5f0aaSAlex Zinenko struct DeallocOpLowering : public ConvertOpToLLVMPattern<memref::DeallocOp> {
31375e5f0aaSAlex Zinenko using ConvertOpToLLVMPattern<memref::DeallocOp>::ConvertOpToLLVMPattern;
31475e5f0aaSAlex Zinenko
DeallocOpLowering__anon7a9e10510111::DeallocOpLowering31575e5f0aaSAlex Zinenko explicit DeallocOpLowering(LLVMTypeConverter &converter)
31675e5f0aaSAlex Zinenko : ConvertOpToLLVMPattern<memref::DeallocOp>(converter) {}
31775e5f0aaSAlex Zinenko
getFreeFn__anon7a9e10510111::DeallocOpLowering318a8601f11SMichele Scuttari LLVM::LLVMFuncOp getFreeFn(ModuleOp module) const {
319a8601f11SMichele Scuttari bool useGenericFn = getTypeConverter()->getOptions().useGenericFunctions;
320a8601f11SMichele Scuttari
321a8601f11SMichele Scuttari if (useGenericFn)
322a8601f11SMichele Scuttari return LLVM::lookupOrCreateGenericFreeFn(module);
323a8601f11SMichele Scuttari
324a8601f11SMichele Scuttari return LLVM::lookupOrCreateFreeFn(module);
325a8601f11SMichele Scuttari }
326a8601f11SMichele Scuttari
32775e5f0aaSAlex Zinenko LogicalResult
matchAndRewrite__anon7a9e10510111::DeallocOpLowering328ef976337SRiver Riddle matchAndRewrite(memref::DeallocOp op, OpAdaptor adaptor,
32975e5f0aaSAlex Zinenko ConversionPatternRewriter &rewriter) const override {
33075e5f0aaSAlex Zinenko // Insert the `free` declaration if it is not already present.
331a8601f11SMichele Scuttari auto freeFunc = getFreeFn(op->getParentOfType<ModuleOp>());
332136d746eSJacques Pienaar MemRefDescriptor memref(adaptor.getMemref());
33375e5f0aaSAlex Zinenko Value casted = rewriter.create<LLVM::BitcastOp>(
33475e5f0aaSAlex Zinenko op.getLoc(), getVoidPtrType(),
33575e5f0aaSAlex Zinenko memref.allocatedPtr(rewriter, op.getLoc()));
33675e5f0aaSAlex Zinenko rewriter.replaceOpWithNewOp<LLVM::CallOp>(
337faf1c224SChris Lattner op, TypeRange(), SymbolRefAttr::get(freeFunc), casted);
33875e5f0aaSAlex Zinenko return success();
33975e5f0aaSAlex Zinenko }
34075e5f0aaSAlex Zinenko };
34175e5f0aaSAlex Zinenko
34275e5f0aaSAlex Zinenko // A `dim` is converted to a constant for static sizes and to an access to the
34375e5f0aaSAlex Zinenko // size stored in the memref descriptor for dynamic sizes.
34475e5f0aaSAlex Zinenko struct DimOpLowering : public ConvertOpToLLVMPattern<memref::DimOp> {
34575e5f0aaSAlex Zinenko using ConvertOpToLLVMPattern<memref::DimOp>::ConvertOpToLLVMPattern;
34675e5f0aaSAlex Zinenko
34775e5f0aaSAlex Zinenko LogicalResult
matchAndRewrite__anon7a9e10510111::DimOpLowering348ef976337SRiver Riddle matchAndRewrite(memref::DimOp dimOp, OpAdaptor adaptor,
34975e5f0aaSAlex Zinenko ConversionPatternRewriter &rewriter) const override {
350136d746eSJacques Pienaar Type operandType = dimOp.getSource().getType();
35175e5f0aaSAlex Zinenko if (operandType.isa<UnrankedMemRefType>()) {
352ef976337SRiver Riddle rewriter.replaceOp(
353ef976337SRiver Riddle dimOp, {extractSizeOfUnrankedMemRef(
354ef976337SRiver Riddle operandType, dimOp, adaptor.getOperands(), rewriter)});
35575e5f0aaSAlex Zinenko
35675e5f0aaSAlex Zinenko return success();
35775e5f0aaSAlex Zinenko }
35875e5f0aaSAlex Zinenko if (operandType.isa<MemRefType>()) {
359ef976337SRiver Riddle rewriter.replaceOp(
360ef976337SRiver Riddle dimOp, {extractSizeOfRankedMemRef(operandType, dimOp,
361ef976337SRiver Riddle adaptor.getOperands(), rewriter)});
36275e5f0aaSAlex Zinenko return success();
36375e5f0aaSAlex Zinenko }
36475e5f0aaSAlex Zinenko llvm_unreachable("expected MemRefType or UnrankedMemRefType");
36575e5f0aaSAlex Zinenko }
36675e5f0aaSAlex Zinenko
36775e5f0aaSAlex Zinenko private:
extractSizeOfUnrankedMemRef__anon7a9e10510111::DimOpLowering36875e5f0aaSAlex Zinenko Value extractSizeOfUnrankedMemRef(Type operandType, memref::DimOp dimOp,
369ef976337SRiver Riddle OpAdaptor adaptor,
37075e5f0aaSAlex Zinenko ConversionPatternRewriter &rewriter) const {
37175e5f0aaSAlex Zinenko Location loc = dimOp.getLoc();
37275e5f0aaSAlex Zinenko
37375e5f0aaSAlex Zinenko auto unrankedMemRefType = operandType.cast<UnrankedMemRefType>();
37475e5f0aaSAlex Zinenko auto scalarMemRefType =
37575e5f0aaSAlex Zinenko MemRefType::get({}, unrankedMemRefType.getElementType());
37675e5f0aaSAlex Zinenko unsigned addressSpace = unrankedMemRefType.getMemorySpaceAsInt();
37775e5f0aaSAlex Zinenko
37875e5f0aaSAlex Zinenko // Extract pointer to the underlying ranked descriptor and bitcast it to a
37975e5f0aaSAlex Zinenko // memref<element_type> descriptor pointer to minimize the number of GEP
38075e5f0aaSAlex Zinenko // operations.
381136d746eSJacques Pienaar UnrankedMemRefDescriptor unrankedDesc(adaptor.getSource());
38275e5f0aaSAlex Zinenko Value underlyingRankedDesc = unrankedDesc.memRefDescPtr(rewriter, loc);
38375e5f0aaSAlex Zinenko Value scalarMemRefDescPtr = rewriter.create<LLVM::BitcastOp>(
38475e5f0aaSAlex Zinenko loc,
38575e5f0aaSAlex Zinenko LLVM::LLVMPointerType::get(typeConverter->convertType(scalarMemRefType),
38675e5f0aaSAlex Zinenko addressSpace),
38775e5f0aaSAlex Zinenko underlyingRankedDesc);
38875e5f0aaSAlex Zinenko
38975e5f0aaSAlex Zinenko // Get pointer to offset field of memref<element_type> descriptor.
39075e5f0aaSAlex Zinenko Type indexPtrTy = LLVM::LLVMPointerType::get(
39175e5f0aaSAlex Zinenko getTypeConverter()->getIndexType(), addressSpace);
39275e5f0aaSAlex Zinenko Value two = rewriter.create<LLVM::ConstantOp>(
39375e5f0aaSAlex Zinenko loc, typeConverter->convertType(rewriter.getI32Type()),
39475e5f0aaSAlex Zinenko rewriter.getI32IntegerAttr(2));
39575e5f0aaSAlex Zinenko Value offsetPtr = rewriter.create<LLVM::GEPOp>(
39675e5f0aaSAlex Zinenko loc, indexPtrTy, scalarMemRefDescPtr,
39775e5f0aaSAlex Zinenko ValueRange({createIndexConstant(rewriter, loc, 0), two}));
39875e5f0aaSAlex Zinenko
39975e5f0aaSAlex Zinenko // The size value that we have to extract can be obtained using GEPop with
40075e5f0aaSAlex Zinenko // `dimOp.index() + 1` index argument.
40175e5f0aaSAlex Zinenko Value idxPlusOne = rewriter.create<LLVM::AddOp>(
402136d746eSJacques Pienaar loc, createIndexConstant(rewriter, loc, 1), adaptor.getIndex());
40375e5f0aaSAlex Zinenko Value sizePtr = rewriter.create<LLVM::GEPOp>(loc, indexPtrTy, offsetPtr,
40475e5f0aaSAlex Zinenko ValueRange({idxPlusOne}));
40575e5f0aaSAlex Zinenko return rewriter.create<LLVM::LoadOp>(loc, sizePtr);
40675e5f0aaSAlex Zinenko }
40775e5f0aaSAlex Zinenko
getConstantDimIndex__anon7a9e10510111::DimOpLowering40875e5f0aaSAlex Zinenko Optional<int64_t> getConstantDimIndex(memref::DimOp dimOp) const {
40975e5f0aaSAlex Zinenko if (Optional<int64_t> idx = dimOp.getConstantIndex())
41075e5f0aaSAlex Zinenko return idx;
41175e5f0aaSAlex Zinenko
412136d746eSJacques Pienaar if (auto constantOp = dimOp.getIndex().getDefiningOp<LLVM::ConstantOp>())
413cfb72fd3SJacques Pienaar return constantOp.getValue()
414cfb72fd3SJacques Pienaar .cast<IntegerAttr>()
415cfb72fd3SJacques Pienaar .getValue()
416cfb72fd3SJacques Pienaar .getSExtValue();
41775e5f0aaSAlex Zinenko
41875e5f0aaSAlex Zinenko return llvm::None;
41975e5f0aaSAlex Zinenko }
42075e5f0aaSAlex Zinenko
extractSizeOfRankedMemRef__anon7a9e10510111::DimOpLowering42175e5f0aaSAlex Zinenko Value extractSizeOfRankedMemRef(Type operandType, memref::DimOp dimOp,
422ef976337SRiver Riddle OpAdaptor adaptor,
42375e5f0aaSAlex Zinenko ConversionPatternRewriter &rewriter) const {
42475e5f0aaSAlex Zinenko Location loc = dimOp.getLoc();
425ef976337SRiver Riddle
42675e5f0aaSAlex Zinenko // Take advantage if index is constant.
42775e5f0aaSAlex Zinenko MemRefType memRefType = operandType.cast<MemRefType>();
42875e5f0aaSAlex Zinenko if (Optional<int64_t> index = getConstantDimIndex(dimOp)) {
4296d5fc1e3SKazu Hirata int64_t i = *index;
43075e5f0aaSAlex Zinenko if (memRefType.isDynamicDim(i)) {
43175e5f0aaSAlex Zinenko // extract dynamic size from the memref descriptor.
432136d746eSJacques Pienaar MemRefDescriptor descriptor(adaptor.getSource());
43375e5f0aaSAlex Zinenko return descriptor.size(rewriter, loc, i);
43475e5f0aaSAlex Zinenko }
43575e5f0aaSAlex Zinenko // Use constant for static size.
43675e5f0aaSAlex Zinenko int64_t dimSize = memRefType.getDimSize(i);
43775e5f0aaSAlex Zinenko return createIndexConstant(rewriter, loc, dimSize);
43875e5f0aaSAlex Zinenko }
439136d746eSJacques Pienaar Value index = adaptor.getIndex();
44075e5f0aaSAlex Zinenko int64_t rank = memRefType.getRank();
441136d746eSJacques Pienaar MemRefDescriptor memrefDescriptor(adaptor.getSource());
44275e5f0aaSAlex Zinenko return memrefDescriptor.size(rewriter, loc, index, rank);
44375e5f0aaSAlex Zinenko }
44475e5f0aaSAlex Zinenko };
44575e5f0aaSAlex Zinenko
446632a4f88SRiver Riddle /// Common base for load and store operations on MemRefs. Restricts the match
447632a4f88SRiver Riddle /// to supported MemRef types. Provides functionality to emit code accessing a
448632a4f88SRiver Riddle /// specific element of the underlying data buffer.
449632a4f88SRiver Riddle template <typename Derived>
450632a4f88SRiver Riddle struct LoadStoreOpLowering : public ConvertOpToLLVMPattern<Derived> {
451632a4f88SRiver Riddle using ConvertOpToLLVMPattern<Derived>::ConvertOpToLLVMPattern;
452632a4f88SRiver Riddle using ConvertOpToLLVMPattern<Derived>::isConvertibleAndHasIdentityMaps;
453632a4f88SRiver Riddle using Base = LoadStoreOpLowering<Derived>;
454632a4f88SRiver Riddle
match__anon7a9e10510111::LoadStoreOpLowering455632a4f88SRiver Riddle LogicalResult match(Derived op) const override {
456632a4f88SRiver Riddle MemRefType type = op.getMemRefType();
457632a4f88SRiver Riddle return isConvertibleAndHasIdentityMaps(type) ? success() : failure();
458632a4f88SRiver Riddle }
459632a4f88SRiver Riddle };
460632a4f88SRiver Riddle
461632a4f88SRiver Riddle /// Wrap a llvm.cmpxchg operation in a while loop so that the operation can be
462632a4f88SRiver Riddle /// retried until it succeeds in atomically storing a new value into memory.
463632a4f88SRiver Riddle ///
464632a4f88SRiver Riddle /// +---------------------------------+
465632a4f88SRiver Riddle /// | <code before the AtomicRMWOp> |
466632a4f88SRiver Riddle /// | <compute initial %loaded> |
467ace01605SRiver Riddle /// | cf.br loop(%loaded) |
468632a4f88SRiver Riddle /// +---------------------------------+
469632a4f88SRiver Riddle /// |
470632a4f88SRiver Riddle /// -------| |
471632a4f88SRiver Riddle /// | v v
472632a4f88SRiver Riddle /// | +--------------------------------+
473632a4f88SRiver Riddle /// | | loop(%loaded): |
474632a4f88SRiver Riddle /// | | <body contents> |
475632a4f88SRiver Riddle /// | | %pair = cmpxchg |
476632a4f88SRiver Riddle /// | | %ok = %pair[0] |
477632a4f88SRiver Riddle /// | | %new = %pair[1] |
478ace01605SRiver Riddle /// | | cf.cond_br %ok, end, loop(%new) |
479632a4f88SRiver Riddle /// | +--------------------------------+
480632a4f88SRiver Riddle /// | | |
481632a4f88SRiver Riddle /// |----------- |
482632a4f88SRiver Riddle /// v
483632a4f88SRiver Riddle /// +--------------------------------+
484632a4f88SRiver Riddle /// | end: |
485632a4f88SRiver Riddle /// | <code after the AtomicRMWOp> |
486632a4f88SRiver Riddle /// +--------------------------------+
487632a4f88SRiver Riddle ///
488632a4f88SRiver Riddle struct GenericAtomicRMWOpLowering
489632a4f88SRiver Riddle : public LoadStoreOpLowering<memref::GenericAtomicRMWOp> {
490632a4f88SRiver Riddle using Base::Base;
491632a4f88SRiver Riddle
492632a4f88SRiver Riddle LogicalResult
matchAndRewrite__anon7a9e10510111::GenericAtomicRMWOpLowering493632a4f88SRiver Riddle matchAndRewrite(memref::GenericAtomicRMWOp atomicOp, OpAdaptor adaptor,
494632a4f88SRiver Riddle ConversionPatternRewriter &rewriter) const override {
495632a4f88SRiver Riddle auto loc = atomicOp.getLoc();
496632a4f88SRiver Riddle Type valueType = typeConverter->convertType(atomicOp.getResult().getType());
497632a4f88SRiver Riddle
498632a4f88SRiver Riddle // Split the block into initial, loop, and ending parts.
499632a4f88SRiver Riddle auto *initBlock = rewriter.getInsertionBlock();
500632a4f88SRiver Riddle auto *loopBlock = rewriter.createBlock(
501632a4f88SRiver Riddle initBlock->getParent(), std::next(Region::iterator(initBlock)),
502632a4f88SRiver Riddle valueType, loc);
503632a4f88SRiver Riddle auto *endBlock = rewriter.createBlock(
504632a4f88SRiver Riddle loopBlock->getParent(), std::next(Region::iterator(loopBlock)));
505632a4f88SRiver Riddle
506632a4f88SRiver Riddle // Operations range to be moved to `endBlock`.
507632a4f88SRiver Riddle auto opsToMoveStart = atomicOp->getIterator();
508632a4f88SRiver Riddle auto opsToMoveEnd = initBlock->back().getIterator();
509632a4f88SRiver Riddle
510632a4f88SRiver Riddle // Compute the loaded value and branch to the loop block.
511632a4f88SRiver Riddle rewriter.setInsertionPointToEnd(initBlock);
512136d746eSJacques Pienaar auto memRefType = atomicOp.getMemref().getType().cast<MemRefType>();
513136d746eSJacques Pienaar auto dataPtr = getStridedElementPtr(loc, memRefType, adaptor.getMemref(),
514136d746eSJacques Pienaar adaptor.getIndices(), rewriter);
515632a4f88SRiver Riddle Value init = rewriter.create<LLVM::LoadOp>(loc, dataPtr);
516632a4f88SRiver Riddle rewriter.create<LLVM::BrOp>(loc, init, loopBlock);
517632a4f88SRiver Riddle
518632a4f88SRiver Riddle // Prepare the body of the loop block.
519632a4f88SRiver Riddle rewriter.setInsertionPointToStart(loopBlock);
520632a4f88SRiver Riddle
521632a4f88SRiver Riddle // Clone the GenericAtomicRMWOp region and extract the result.
522632a4f88SRiver Riddle auto loopArgument = loopBlock->getArgument(0);
523632a4f88SRiver Riddle BlockAndValueMapping mapping;
524632a4f88SRiver Riddle mapping.map(atomicOp.getCurrentValue(), loopArgument);
525632a4f88SRiver Riddle Block &entryBlock = atomicOp.body().front();
526632a4f88SRiver Riddle for (auto &nestedOp : entryBlock.without_terminator()) {
527632a4f88SRiver Riddle Operation *clone = rewriter.clone(nestedOp, mapping);
528632a4f88SRiver Riddle mapping.map(nestedOp.getResults(), clone->getResults());
529632a4f88SRiver Riddle }
530632a4f88SRiver Riddle Value result = mapping.lookup(entryBlock.getTerminator()->getOperand(0));
531632a4f88SRiver Riddle
532632a4f88SRiver Riddle // Prepare the epilog of the loop block.
533632a4f88SRiver Riddle // Append the cmpxchg op to the end of the loop block.
534632a4f88SRiver Riddle auto successOrdering = LLVM::AtomicOrdering::acq_rel;
535632a4f88SRiver Riddle auto failureOrdering = LLVM::AtomicOrdering::monotonic;
536632a4f88SRiver Riddle auto boolType = IntegerType::get(rewriter.getContext(), 1);
537632a4f88SRiver Riddle auto pairType = LLVM::LLVMStructType::getLiteral(rewriter.getContext(),
538632a4f88SRiver Riddle {valueType, boolType});
539632a4f88SRiver Riddle auto cmpxchg = rewriter.create<LLVM::AtomicCmpXchgOp>(
540632a4f88SRiver Riddle loc, pairType, dataPtr, loopArgument, result, successOrdering,
541632a4f88SRiver Riddle failureOrdering);
542632a4f88SRiver Riddle // Extract the %new_loaded and %ok values from the pair.
543632a4f88SRiver Riddle Value newLoaded = rewriter.create<LLVM::ExtractValueOp>(
544632a4f88SRiver Riddle loc, valueType, cmpxchg, rewriter.getI64ArrayAttr({0}));
545632a4f88SRiver Riddle Value ok = rewriter.create<LLVM::ExtractValueOp>(
546632a4f88SRiver Riddle loc, boolType, cmpxchg, rewriter.getI64ArrayAttr({1}));
547632a4f88SRiver Riddle
548632a4f88SRiver Riddle // Conditionally branch to the end or back to the loop depending on %ok.
549632a4f88SRiver Riddle rewriter.create<LLVM::CondBrOp>(loc, ok, endBlock, ArrayRef<Value>(),
550632a4f88SRiver Riddle loopBlock, newLoaded);
551632a4f88SRiver Riddle
552632a4f88SRiver Riddle rewriter.setInsertionPointToEnd(endBlock);
553632a4f88SRiver Riddle moveOpsRange(atomicOp.getResult(), newLoaded, std::next(opsToMoveStart),
554632a4f88SRiver Riddle std::next(opsToMoveEnd), rewriter);
555632a4f88SRiver Riddle
556632a4f88SRiver Riddle // The 'result' of the atomic_rmw op is the newly loaded value.
557632a4f88SRiver Riddle rewriter.replaceOp(atomicOp, {newLoaded});
558632a4f88SRiver Riddle
559632a4f88SRiver Riddle return success();
560632a4f88SRiver Riddle }
561632a4f88SRiver Riddle
562632a4f88SRiver Riddle private:
563632a4f88SRiver Riddle // Clones a segment of ops [start, end) and erases the original.
moveOpsRange__anon7a9e10510111::GenericAtomicRMWOpLowering564632a4f88SRiver Riddle void moveOpsRange(ValueRange oldResult, ValueRange newResult,
565632a4f88SRiver Riddle Block::iterator start, Block::iterator end,
566632a4f88SRiver Riddle ConversionPatternRewriter &rewriter) const {
567632a4f88SRiver Riddle BlockAndValueMapping mapping;
568632a4f88SRiver Riddle mapping.map(oldResult, newResult);
569632a4f88SRiver Riddle SmallVector<Operation *, 2> opsToErase;
570632a4f88SRiver Riddle for (auto it = start; it != end; ++it) {
571632a4f88SRiver Riddle rewriter.clone(*it, mapping);
572632a4f88SRiver Riddle opsToErase.push_back(&*it);
573632a4f88SRiver Riddle }
574632a4f88SRiver Riddle for (auto *it : opsToErase)
575632a4f88SRiver Riddle rewriter.eraseOp(it);
576632a4f88SRiver Riddle }
577632a4f88SRiver Riddle };
578632a4f88SRiver Riddle
57975e5f0aaSAlex Zinenko /// Returns the LLVM type of the global variable given the memref type `type`.
convertGlobalMemrefTypeToLLVM(MemRefType type,LLVMTypeConverter & typeConverter)58075e5f0aaSAlex Zinenko static Type convertGlobalMemrefTypeToLLVM(MemRefType type,
58175e5f0aaSAlex Zinenko LLVMTypeConverter &typeConverter) {
58275e5f0aaSAlex Zinenko // LLVM type for a global memref will be a multi-dimension array. For
58375e5f0aaSAlex Zinenko // declarations or uninitialized global memrefs, we can potentially flatten
58475e5f0aaSAlex Zinenko // this to a 1D array. However, for memref.global's with an initial value,
58575e5f0aaSAlex Zinenko // we do not intend to flatten the ElementsAttribute when going from std ->
58675e5f0aaSAlex Zinenko // LLVM dialect, so the LLVM type needs to me a multi-dimension array.
58775e5f0aaSAlex Zinenko Type elementType = typeConverter.convertType(type.getElementType());
58875e5f0aaSAlex Zinenko Type arrayTy = elementType;
58975e5f0aaSAlex Zinenko // Shape has the outermost dim at index 0, so need to walk it backwards
59075e5f0aaSAlex Zinenko for (int64_t dim : llvm::reverse(type.getShape()))
59175e5f0aaSAlex Zinenko arrayTy = LLVM::LLVMArrayType::get(arrayTy, dim);
59275e5f0aaSAlex Zinenko return arrayTy;
59375e5f0aaSAlex Zinenko }
59475e5f0aaSAlex Zinenko
59575e5f0aaSAlex Zinenko /// GlobalMemrefOp is lowered to a LLVM Global Variable.
59675e5f0aaSAlex Zinenko struct GlobalMemrefOpLowering
59775e5f0aaSAlex Zinenko : public ConvertOpToLLVMPattern<memref::GlobalOp> {
59875e5f0aaSAlex Zinenko using ConvertOpToLLVMPattern<memref::GlobalOp>::ConvertOpToLLVMPattern;
59975e5f0aaSAlex Zinenko
60075e5f0aaSAlex Zinenko LogicalResult
matchAndRewrite__anon7a9e10510111::GlobalMemrefOpLowering601ef976337SRiver Riddle matchAndRewrite(memref::GlobalOp global, OpAdaptor adaptor,
60275e5f0aaSAlex Zinenko ConversionPatternRewriter &rewriter) const override {
603136d746eSJacques Pienaar MemRefType type = global.getType();
60475e5f0aaSAlex Zinenko if (!isConvertibleAndHasIdentityMaps(type))
60575e5f0aaSAlex Zinenko return failure();
60675e5f0aaSAlex Zinenko
60775e5f0aaSAlex Zinenko Type arrayTy = convertGlobalMemrefTypeToLLVM(type, *getTypeConverter());
60875e5f0aaSAlex Zinenko
60975e5f0aaSAlex Zinenko LLVM::Linkage linkage =
61075e5f0aaSAlex Zinenko global.isPublic() ? LLVM::Linkage::External : LLVM::Linkage::Private;
61175e5f0aaSAlex Zinenko
61275e5f0aaSAlex Zinenko Attribute initialValue = nullptr;
61375e5f0aaSAlex Zinenko if (!global.isExternal() && !global.isUninitialized()) {
614136d746eSJacques Pienaar auto elementsAttr = global.getInitialValue()->cast<ElementsAttr>();
61575e5f0aaSAlex Zinenko initialValue = elementsAttr;
61675e5f0aaSAlex Zinenko
61775e5f0aaSAlex Zinenko // For scalar memrefs, the global variable created is of the element type,
61875e5f0aaSAlex Zinenko // so unpack the elements attribute to extract the value.
61975e5f0aaSAlex Zinenko if (type.getRank() == 0)
620937e40a8SRiver Riddle initialValue = elementsAttr.getSplatValue<Attribute>();
62175e5f0aaSAlex Zinenko }
62275e5f0aaSAlex Zinenko
623136d746eSJacques Pienaar uint64_t alignment = global.getAlignment().value_or(0);
6248276ac13SEugene Zhulenev
6258c2ff7b6SWilliam S. Moses auto newGlobal = rewriter.replaceOpWithNewOp<LLVM::GlobalOp>(
626136d746eSJacques Pienaar global, arrayTy, global.getConstant(), linkage, global.getSymName(),
6278276ac13SEugene Zhulenev initialValue, alignment, type.getMemorySpaceAsInt());
6288c2ff7b6SWilliam S. Moses if (!global.isExternal() && global.isUninitialized()) {
6298c2ff7b6SWilliam S. Moses Block *blk = new Block();
6308c2ff7b6SWilliam S. Moses newGlobal.getInitializerRegion().push_back(blk);
6318c2ff7b6SWilliam S. Moses rewriter.setInsertionPointToStart(blk);
6328c2ff7b6SWilliam S. Moses Value undef[] = {
6338c2ff7b6SWilliam S. Moses rewriter.create<LLVM::UndefOp>(global.getLoc(), arrayTy)};
6348c2ff7b6SWilliam S. Moses rewriter.create<LLVM::ReturnOp>(global.getLoc(), undef);
6358c2ff7b6SWilliam S. Moses }
63675e5f0aaSAlex Zinenko return success();
63775e5f0aaSAlex Zinenko }
63875e5f0aaSAlex Zinenko };
63975e5f0aaSAlex Zinenko
64075e5f0aaSAlex Zinenko /// GetGlobalMemrefOp is lowered into a Memref descriptor with the pointer to
64175e5f0aaSAlex Zinenko /// the first element stashed into the descriptor. This reuses
64275e5f0aaSAlex Zinenko /// `AllocLikeOpLowering` to reuse the Memref descriptor construction.
64375e5f0aaSAlex Zinenko struct GetGlobalMemrefOpLowering : public AllocLikeOpLLVMLowering {
GetGlobalMemrefOpLowering__anon7a9e10510111::GetGlobalMemrefOpLowering64475e5f0aaSAlex Zinenko GetGlobalMemrefOpLowering(LLVMTypeConverter &converter)
64575e5f0aaSAlex Zinenko : AllocLikeOpLLVMLowering(memref::GetGlobalOp::getOperationName(),
64675e5f0aaSAlex Zinenko converter) {}
64775e5f0aaSAlex Zinenko
64875e5f0aaSAlex Zinenko /// Buffer "allocation" for memref.get_global op is getting the address of
64975e5f0aaSAlex Zinenko /// the global variable referenced.
allocateBuffer__anon7a9e10510111::GetGlobalMemrefOpLowering65075e5f0aaSAlex Zinenko std::tuple<Value, Value> allocateBuffer(ConversionPatternRewriter &rewriter,
65175e5f0aaSAlex Zinenko Location loc, Value sizeBytes,
65275e5f0aaSAlex Zinenko Operation *op) const override {
65375e5f0aaSAlex Zinenko auto getGlobalOp = cast<memref::GetGlobalOp>(op);
654136d746eSJacques Pienaar MemRefType type = getGlobalOp.getResult().getType().cast<MemRefType>();
65575e5f0aaSAlex Zinenko unsigned memSpace = type.getMemorySpaceAsInt();
65675e5f0aaSAlex Zinenko
65775e5f0aaSAlex Zinenko Type arrayTy = convertGlobalMemrefTypeToLLVM(type, *getTypeConverter());
65875e5f0aaSAlex Zinenko auto addressOf = rewriter.create<LLVM::AddressOfOp>(
659136d746eSJacques Pienaar loc, LLVM::LLVMPointerType::get(arrayTy, memSpace),
660136d746eSJacques Pienaar getGlobalOp.getName());
66175e5f0aaSAlex Zinenko
66275e5f0aaSAlex Zinenko // Get the address of the first element in the array by creating a GEP with
66375e5f0aaSAlex Zinenko // the address of the GV as the base, and (rank + 1) number of 0 indices.
66475e5f0aaSAlex Zinenko Type elementType = typeConverter->convertType(type.getElementType());
66575e5f0aaSAlex Zinenko Type elementPtrType = LLVM::LLVMPointerType::get(elementType, memSpace);
66675e5f0aaSAlex Zinenko
667cafaa350SAlex Zinenko SmallVector<Value> operands;
66875e5f0aaSAlex Zinenko operands.insert(operands.end(), type.getRank() + 1,
66975e5f0aaSAlex Zinenko createIndexConstant(rewriter, loc, 0));
670cafaa350SAlex Zinenko auto gep =
671cafaa350SAlex Zinenko rewriter.create<LLVM::GEPOp>(loc, elementPtrType, addressOf, operands);
67275e5f0aaSAlex Zinenko
67375e5f0aaSAlex Zinenko // We do not expect the memref obtained using `memref.get_global` to be
67475e5f0aaSAlex Zinenko // ever deallocated. Set the allocated pointer to be known bad value to
67575e5f0aaSAlex Zinenko // help debug if that ever happens.
67675e5f0aaSAlex Zinenko auto intPtrType = getIntPtrType(memSpace);
67775e5f0aaSAlex Zinenko Value deadBeefConst =
67875e5f0aaSAlex Zinenko createIndexAttrConstant(rewriter, op->getLoc(), intPtrType, 0xdeadbeef);
67975e5f0aaSAlex Zinenko auto deadBeefPtr =
68075e5f0aaSAlex Zinenko rewriter.create<LLVM::IntToPtrOp>(loc, elementPtrType, deadBeefConst);
68175e5f0aaSAlex Zinenko
68275e5f0aaSAlex Zinenko // Both allocated and aligned pointers are same. We could potentially stash
68375e5f0aaSAlex Zinenko // a nullptr for the allocated pointer since we do not expect any dealloc.
68475e5f0aaSAlex Zinenko return std::make_tuple(deadBeefPtr, gep);
68575e5f0aaSAlex Zinenko }
68675e5f0aaSAlex Zinenko };
68775e5f0aaSAlex Zinenko
68875e5f0aaSAlex Zinenko // Load operation is lowered to obtaining a pointer to the indexed element
68975e5f0aaSAlex Zinenko // and loading it.
69075e5f0aaSAlex Zinenko struct LoadOpLowering : public LoadStoreOpLowering<memref::LoadOp> {
69175e5f0aaSAlex Zinenko using Base::Base;
69275e5f0aaSAlex Zinenko
69375e5f0aaSAlex Zinenko LogicalResult
matchAndRewrite__anon7a9e10510111::LoadOpLowering694ef976337SRiver Riddle matchAndRewrite(memref::LoadOp loadOp, OpAdaptor adaptor,
69575e5f0aaSAlex Zinenko ConversionPatternRewriter &rewriter) const override {
69675e5f0aaSAlex Zinenko auto type = loadOp.getMemRefType();
69775e5f0aaSAlex Zinenko
698136d746eSJacques Pienaar Value dataPtr =
699136d746eSJacques Pienaar getStridedElementPtr(loadOp.getLoc(), type, adaptor.getMemref(),
700136d746eSJacques Pienaar adaptor.getIndices(), rewriter);
70175e5f0aaSAlex Zinenko rewriter.replaceOpWithNewOp<LLVM::LoadOp>(loadOp, dataPtr);
70275e5f0aaSAlex Zinenko return success();
70375e5f0aaSAlex Zinenko }
70475e5f0aaSAlex Zinenko };
70575e5f0aaSAlex Zinenko
70675e5f0aaSAlex Zinenko // Store operation is lowered to obtaining a pointer to the indexed element,
70775e5f0aaSAlex Zinenko // and storing the given value to it.
70875e5f0aaSAlex Zinenko struct StoreOpLowering : public LoadStoreOpLowering<memref::StoreOp> {
70975e5f0aaSAlex Zinenko using Base::Base;
71075e5f0aaSAlex Zinenko
71175e5f0aaSAlex Zinenko LogicalResult
matchAndRewrite__anon7a9e10510111::StoreOpLowering712ef976337SRiver Riddle matchAndRewrite(memref::StoreOp op, OpAdaptor adaptor,
71375e5f0aaSAlex Zinenko ConversionPatternRewriter &rewriter) const override {
71475e5f0aaSAlex Zinenko auto type = op.getMemRefType();
71575e5f0aaSAlex Zinenko
716136d746eSJacques Pienaar Value dataPtr = getStridedElementPtr(op.getLoc(), type, adaptor.getMemref(),
717136d746eSJacques Pienaar adaptor.getIndices(), rewriter);
718136d746eSJacques Pienaar rewriter.replaceOpWithNewOp<LLVM::StoreOp>(op, adaptor.getValue(), dataPtr);
71975e5f0aaSAlex Zinenko return success();
72075e5f0aaSAlex Zinenko }
72175e5f0aaSAlex Zinenko };
72275e5f0aaSAlex Zinenko
72375e5f0aaSAlex Zinenko // The prefetch operation is lowered in a way similar to the load operation
72475e5f0aaSAlex Zinenko // except that the llvm.prefetch operation is used for replacement.
72575e5f0aaSAlex Zinenko struct PrefetchOpLowering : public LoadStoreOpLowering<memref::PrefetchOp> {
72675e5f0aaSAlex Zinenko using Base::Base;
72775e5f0aaSAlex Zinenko
72875e5f0aaSAlex Zinenko LogicalResult
matchAndRewrite__anon7a9e10510111::PrefetchOpLowering729ef976337SRiver Riddle matchAndRewrite(memref::PrefetchOp prefetchOp, OpAdaptor adaptor,
73075e5f0aaSAlex Zinenko ConversionPatternRewriter &rewriter) const override {
73175e5f0aaSAlex Zinenko auto type = prefetchOp.getMemRefType();
73275e5f0aaSAlex Zinenko auto loc = prefetchOp.getLoc();
73375e5f0aaSAlex Zinenko
734136d746eSJacques Pienaar Value dataPtr = getStridedElementPtr(loc, type, adaptor.getMemref(),
735136d746eSJacques Pienaar adaptor.getIndices(), rewriter);
73675e5f0aaSAlex Zinenko
73775e5f0aaSAlex Zinenko // Replace with llvm.prefetch.
73875e5f0aaSAlex Zinenko auto llvmI32Type = typeConverter->convertType(rewriter.getIntegerType(32));
73975e5f0aaSAlex Zinenko auto isWrite = rewriter.create<LLVM::ConstantOp>(
740136d746eSJacques Pienaar loc, llvmI32Type, rewriter.getI32IntegerAttr(prefetchOp.getIsWrite()));
74175e5f0aaSAlex Zinenko auto localityHint = rewriter.create<LLVM::ConstantOp>(
74275e5f0aaSAlex Zinenko loc, llvmI32Type,
743136d746eSJacques Pienaar rewriter.getI32IntegerAttr(prefetchOp.getLocalityHint()));
74475e5f0aaSAlex Zinenko auto isData = rewriter.create<LLVM::ConstantOp>(
745136d746eSJacques Pienaar loc, llvmI32Type,
746136d746eSJacques Pienaar rewriter.getI32IntegerAttr(prefetchOp.getIsDataCache()));
74775e5f0aaSAlex Zinenko
74875e5f0aaSAlex Zinenko rewriter.replaceOpWithNewOp<LLVM::Prefetch>(prefetchOp, dataPtr, isWrite,
74975e5f0aaSAlex Zinenko localityHint, isData);
75075e5f0aaSAlex Zinenko return success();
75175e5f0aaSAlex Zinenko }
75275e5f0aaSAlex Zinenko };
75375e5f0aaSAlex Zinenko
75415f8f3e2SAlexander Belyaev struct RankOpLowering : public ConvertOpToLLVMPattern<memref::RankOp> {
75515f8f3e2SAlexander Belyaev using ConvertOpToLLVMPattern<memref::RankOp>::ConvertOpToLLVMPattern;
75615f8f3e2SAlexander Belyaev
75715f8f3e2SAlexander Belyaev LogicalResult
matchAndRewrite__anon7a9e10510111::RankOpLowering75815f8f3e2SAlexander Belyaev matchAndRewrite(memref::RankOp op, OpAdaptor adaptor,
75915f8f3e2SAlexander Belyaev ConversionPatternRewriter &rewriter) const override {
76015f8f3e2SAlexander Belyaev Location loc = op.getLoc();
761136d746eSJacques Pienaar Type operandType = op.getMemref().getType();
76215f8f3e2SAlexander Belyaev if (auto unrankedMemRefType = operandType.dyn_cast<UnrankedMemRefType>()) {
763136d746eSJacques Pienaar UnrankedMemRefDescriptor desc(adaptor.getMemref());
76415f8f3e2SAlexander Belyaev rewriter.replaceOp(op, {desc.rank(rewriter, loc)});
76515f8f3e2SAlexander Belyaev return success();
76615f8f3e2SAlexander Belyaev }
76715f8f3e2SAlexander Belyaev if (auto rankedMemRefType = operandType.dyn_cast<MemRefType>()) {
76815f8f3e2SAlexander Belyaev rewriter.replaceOp(
76915f8f3e2SAlexander Belyaev op, {createIndexConstant(rewriter, loc, rankedMemRefType.getRank())});
77015f8f3e2SAlexander Belyaev return success();
77115f8f3e2SAlexander Belyaev }
77215f8f3e2SAlexander Belyaev return failure();
77315f8f3e2SAlexander Belyaev }
77415f8f3e2SAlexander Belyaev };
77515f8f3e2SAlexander Belyaev
77675e5f0aaSAlex Zinenko struct MemRefCastOpLowering : public ConvertOpToLLVMPattern<memref::CastOp> {
77775e5f0aaSAlex Zinenko using ConvertOpToLLVMPattern<memref::CastOp>::ConvertOpToLLVMPattern;
77875e5f0aaSAlex Zinenko
match__anon7a9e10510111::MemRefCastOpLowering77975e5f0aaSAlex Zinenko LogicalResult match(memref::CastOp memRefCastOp) const override {
78075e5f0aaSAlex Zinenko Type srcType = memRefCastOp.getOperand().getType();
78175e5f0aaSAlex Zinenko Type dstType = memRefCastOp.getType();
78275e5f0aaSAlex Zinenko
78375e5f0aaSAlex Zinenko // memref::CastOp reduce to bitcast in the ranked MemRef case and can be
78475e5f0aaSAlex Zinenko // used for type erasure. For now they must preserve underlying element type
78575e5f0aaSAlex Zinenko // and require source and result type to have the same rank. Therefore,
78675e5f0aaSAlex Zinenko // perform a sanity check that the underlying structs are the same. Once op
78775e5f0aaSAlex Zinenko // semantics are relaxed we can revisit.
78875e5f0aaSAlex Zinenko if (srcType.isa<MemRefType>() && dstType.isa<MemRefType>())
78975e5f0aaSAlex Zinenko return success(typeConverter->convertType(srcType) ==
79075e5f0aaSAlex Zinenko typeConverter->convertType(dstType));
79175e5f0aaSAlex Zinenko
79275e5f0aaSAlex Zinenko // At least one of the operands is unranked type
79375e5f0aaSAlex Zinenko assert(srcType.isa<UnrankedMemRefType>() ||
79475e5f0aaSAlex Zinenko dstType.isa<UnrankedMemRefType>());
79575e5f0aaSAlex Zinenko
79675e5f0aaSAlex Zinenko // Unranked to unranked cast is disallowed
79775e5f0aaSAlex Zinenko return !(srcType.isa<UnrankedMemRefType>() &&
79875e5f0aaSAlex Zinenko dstType.isa<UnrankedMemRefType>())
79975e5f0aaSAlex Zinenko ? success()
80075e5f0aaSAlex Zinenko : failure();
80175e5f0aaSAlex Zinenko }
80275e5f0aaSAlex Zinenko
rewrite__anon7a9e10510111::MemRefCastOpLowering803ef976337SRiver Riddle void rewrite(memref::CastOp memRefCastOp, OpAdaptor adaptor,
80475e5f0aaSAlex Zinenko ConversionPatternRewriter &rewriter) const override {
80575e5f0aaSAlex Zinenko auto srcType = memRefCastOp.getOperand().getType();
80675e5f0aaSAlex Zinenko auto dstType = memRefCastOp.getType();
80775e5f0aaSAlex Zinenko auto targetStructType = typeConverter->convertType(memRefCastOp.getType());
80875e5f0aaSAlex Zinenko auto loc = memRefCastOp.getLoc();
80975e5f0aaSAlex Zinenko
81075e5f0aaSAlex Zinenko // For ranked/ranked case, just keep the original descriptor.
81175e5f0aaSAlex Zinenko if (srcType.isa<MemRefType>() && dstType.isa<MemRefType>())
812136d746eSJacques Pienaar return rewriter.replaceOp(memRefCastOp, {adaptor.getSource()});
81375e5f0aaSAlex Zinenko
81475e5f0aaSAlex Zinenko if (srcType.isa<MemRefType>() && dstType.isa<UnrankedMemRefType>()) {
81575e5f0aaSAlex Zinenko // Casting ranked to unranked memref type
81675e5f0aaSAlex Zinenko // Set the rank in the destination from the memref type
81775e5f0aaSAlex Zinenko // Allocate space on the stack and copy the src memref descriptor
81875e5f0aaSAlex Zinenko // Set the ptr in the destination to the stack space
81975e5f0aaSAlex Zinenko auto srcMemRefType = srcType.cast<MemRefType>();
82075e5f0aaSAlex Zinenko int64_t rank = srcMemRefType.getRank();
82175e5f0aaSAlex Zinenko // ptr = AllocaOp sizeof(MemRefDescriptor)
82275e5f0aaSAlex Zinenko auto ptr = getTypeConverter()->promoteOneMemRefDescriptor(
823136d746eSJacques Pienaar loc, adaptor.getSource(), rewriter);
82475e5f0aaSAlex Zinenko // voidptr = BitCastOp srcType* to void*
82575e5f0aaSAlex Zinenko auto voidPtr =
82675e5f0aaSAlex Zinenko rewriter.create<LLVM::BitcastOp>(loc, getVoidPtrType(), ptr)
82775e5f0aaSAlex Zinenko .getResult();
82875e5f0aaSAlex Zinenko // rank = ConstantOp srcRank
82975e5f0aaSAlex Zinenko auto rankVal = rewriter.create<LLVM::ConstantOp>(
8305b02a480SAdrian Kuegel loc, getIndexType(), rewriter.getIndexAttr(rank));
83175e5f0aaSAlex Zinenko // undef = UndefOp
83275e5f0aaSAlex Zinenko UnrankedMemRefDescriptor memRefDesc =
83375e5f0aaSAlex Zinenko UnrankedMemRefDescriptor::undef(rewriter, loc, targetStructType);
83475e5f0aaSAlex Zinenko // d1 = InsertValueOp undef, rank, 0
83575e5f0aaSAlex Zinenko memRefDesc.setRank(rewriter, loc, rankVal);
83675e5f0aaSAlex Zinenko // d2 = InsertValueOp d1, voidptr, 1
83775e5f0aaSAlex Zinenko memRefDesc.setMemRefDescPtr(rewriter, loc, voidPtr);
83875e5f0aaSAlex Zinenko rewriter.replaceOp(memRefCastOp, (Value)memRefDesc);
83975e5f0aaSAlex Zinenko
84075e5f0aaSAlex Zinenko } else if (srcType.isa<UnrankedMemRefType>() && dstType.isa<MemRefType>()) {
84175e5f0aaSAlex Zinenko // Casting from unranked type to ranked.
84275e5f0aaSAlex Zinenko // The operation is assumed to be doing a correct cast. If the destination
84375e5f0aaSAlex Zinenko // type mismatches the unranked the type, it is undefined behavior.
844136d746eSJacques Pienaar UnrankedMemRefDescriptor memRefDesc(adaptor.getSource());
84575e5f0aaSAlex Zinenko // ptr = ExtractValueOp src, 1
84675e5f0aaSAlex Zinenko auto ptr = memRefDesc.memRefDescPtr(rewriter, loc);
84775e5f0aaSAlex Zinenko // castPtr = BitCastOp i8* to structTy*
84875e5f0aaSAlex Zinenko auto castPtr =
84975e5f0aaSAlex Zinenko rewriter
85075e5f0aaSAlex Zinenko .create<LLVM::BitcastOp>(
85175e5f0aaSAlex Zinenko loc, LLVM::LLVMPointerType::get(targetStructType), ptr)
85275e5f0aaSAlex Zinenko .getResult();
85375e5f0aaSAlex Zinenko // struct = LoadOp castPtr
85475e5f0aaSAlex Zinenko auto loadOp = rewriter.create<LLVM::LoadOp>(loc, castPtr);
85575e5f0aaSAlex Zinenko rewriter.replaceOp(memRefCastOp, loadOp.getResult());
85675e5f0aaSAlex Zinenko } else {
85775e5f0aaSAlex Zinenko llvm_unreachable("Unsupported unranked memref to unranked memref cast");
85875e5f0aaSAlex Zinenko }
85975e5f0aaSAlex Zinenko }
86075e5f0aaSAlex Zinenko };
86175e5f0aaSAlex Zinenko
862ab95ba70SStephan Herhut /// Pattern to lower a `memref.copy` to llvm.
863ab95ba70SStephan Herhut ///
864ab95ba70SStephan Herhut /// For memrefs with identity layouts, the copy is lowered to the llvm
865ab95ba70SStephan Herhut /// `memcpy` intrinsic. For non-identity layouts, the copy is lowered to a call
866ab95ba70SStephan Herhut /// to the generic `MemrefCopyFn`.
86775e5f0aaSAlex Zinenko struct MemRefCopyOpLowering : public ConvertOpToLLVMPattern<memref::CopyOp> {
86875e5f0aaSAlex Zinenko using ConvertOpToLLVMPattern<memref::CopyOp>::ConvertOpToLLVMPattern;
86975e5f0aaSAlex Zinenko
87075e5f0aaSAlex Zinenko LogicalResult
lowerToMemCopyIntrinsic__anon7a9e10510111::MemRefCopyOpLowering871ab95ba70SStephan Herhut lowerToMemCopyIntrinsic(memref::CopyOp op, OpAdaptor adaptor,
872ab95ba70SStephan Herhut ConversionPatternRewriter &rewriter) const {
873ab95ba70SStephan Herhut auto loc = op.getLoc();
874136d746eSJacques Pienaar auto srcType = op.getSource().getType().dyn_cast<MemRefType>();
875ab95ba70SStephan Herhut
876136d746eSJacques Pienaar MemRefDescriptor srcDesc(adaptor.getSource());
877ab95ba70SStephan Herhut
878ab95ba70SStephan Herhut // Compute number of elements.
879aa3cabe3SStephan Herhut Value numElements = rewriter.create<LLVM::ConstantOp>(
880aa3cabe3SStephan Herhut loc, getIndexType(), rewriter.getIndexAttr(1));
881ab95ba70SStephan Herhut for (int pos = 0; pos < srcType.getRank(); ++pos) {
882ab95ba70SStephan Herhut auto size = srcDesc.size(rewriter, loc, pos);
883aa3cabe3SStephan Herhut numElements = rewriter.create<LLVM::MulOp>(loc, numElements, size);
884ab95ba70SStephan Herhut }
885aa3cabe3SStephan Herhut
886ab95ba70SStephan Herhut // Get element size.
887ab95ba70SStephan Herhut auto sizeInBytes = getSizeInBytes(loc, srcType.getElementType(), rewriter);
888ab95ba70SStephan Herhut // Compute total.
889ab95ba70SStephan Herhut Value totalSize =
890ab95ba70SStephan Herhut rewriter.create<LLVM::MulOp>(loc, numElements, sizeInBytes);
891ab95ba70SStephan Herhut
892ab95ba70SStephan Herhut Value srcBasePtr = srcDesc.alignedPtr(rewriter, loc);
89327cd2a62SBenjamin Kramer Value srcOffset = srcDesc.offset(rewriter, loc);
89427cd2a62SBenjamin Kramer Value srcPtr = rewriter.create<LLVM::GEPOp>(loc, srcBasePtr.getType(),
89527cd2a62SBenjamin Kramer srcBasePtr, srcOffset);
896136d746eSJacques Pienaar MemRefDescriptor targetDesc(adaptor.getTarget());
897ab95ba70SStephan Herhut Value targetBasePtr = targetDesc.alignedPtr(rewriter, loc);
89827cd2a62SBenjamin Kramer Value targetOffset = targetDesc.offset(rewriter, loc);
89927cd2a62SBenjamin Kramer Value targetPtr = rewriter.create<LLVM::GEPOp>(loc, targetBasePtr.getType(),
90027cd2a62SBenjamin Kramer targetBasePtr, targetOffset);
901ab95ba70SStephan Herhut Value isVolatile = rewriter.create<LLVM::ConstantOp>(
902ab95ba70SStephan Herhut loc, typeConverter->convertType(rewriter.getI1Type()),
903ab95ba70SStephan Herhut rewriter.getBoolAttr(false));
90427cd2a62SBenjamin Kramer rewriter.create<LLVM::MemcpyOp>(loc, targetPtr, srcPtr, totalSize,
905ab95ba70SStephan Herhut isVolatile);
906ab95ba70SStephan Herhut rewriter.eraseOp(op);
907ab95ba70SStephan Herhut
908ab95ba70SStephan Herhut return success();
909ab95ba70SStephan Herhut }
910ab95ba70SStephan Herhut
911ab95ba70SStephan Herhut LogicalResult
lowerToMemCopyFunctionCall__anon7a9e10510111::MemRefCopyOpLowering912ab95ba70SStephan Herhut lowerToMemCopyFunctionCall(memref::CopyOp op, OpAdaptor adaptor,
913ab95ba70SStephan Herhut ConversionPatternRewriter &rewriter) const {
91475e5f0aaSAlex Zinenko auto loc = op.getLoc();
915136d746eSJacques Pienaar auto srcType = op.getSource().getType().cast<BaseMemRefType>();
916136d746eSJacques Pienaar auto targetType = op.getTarget().getType().cast<BaseMemRefType>();
91775e5f0aaSAlex Zinenko
91875e5f0aaSAlex Zinenko // First make sure we have an unranked memref descriptor representation.
91975e5f0aaSAlex Zinenko auto makeUnranked = [&, this](Value ranked, BaseMemRefType type) {
92075e5f0aaSAlex Zinenko auto rank = rewriter.create<LLVM::ConstantOp>(
92175e5f0aaSAlex Zinenko loc, getIndexType(), rewriter.getIndexAttr(type.getRank()));
92275e5f0aaSAlex Zinenko auto *typeConverter = getTypeConverter();
92375e5f0aaSAlex Zinenko auto ptr =
92475e5f0aaSAlex Zinenko typeConverter->promoteOneMemRefDescriptor(loc, ranked, rewriter);
92575e5f0aaSAlex Zinenko auto voidPtr =
92675e5f0aaSAlex Zinenko rewriter.create<LLVM::BitcastOp>(loc, getVoidPtrType(), ptr)
92775e5f0aaSAlex Zinenko .getResult();
92875e5f0aaSAlex Zinenko auto unrankedType =
92975e5f0aaSAlex Zinenko UnrankedMemRefType::get(type.getElementType(), type.getMemorySpace());
93075e5f0aaSAlex Zinenko return UnrankedMemRefDescriptor::pack(rewriter, loc, *typeConverter,
93175e5f0aaSAlex Zinenko unrankedType,
93275e5f0aaSAlex Zinenko ValueRange{rank, voidPtr});
93375e5f0aaSAlex Zinenko };
93475e5f0aaSAlex Zinenko
93575e5f0aaSAlex Zinenko Value unrankedSource = srcType.hasRank()
936136d746eSJacques Pienaar ? makeUnranked(adaptor.getSource(), srcType)
937136d746eSJacques Pienaar : adaptor.getSource();
93875e5f0aaSAlex Zinenko Value unrankedTarget = targetType.hasRank()
939136d746eSJacques Pienaar ? makeUnranked(adaptor.getTarget(), targetType)
940136d746eSJacques Pienaar : adaptor.getTarget();
94175e5f0aaSAlex Zinenko
94275e5f0aaSAlex Zinenko // Now promote the unranked descriptors to the stack.
94375e5f0aaSAlex Zinenko auto one = rewriter.create<LLVM::ConstantOp>(loc, getIndexType(),
94475e5f0aaSAlex Zinenko rewriter.getIndexAttr(1));
94575e5f0aaSAlex Zinenko auto promote = [&](Value desc) {
94675e5f0aaSAlex Zinenko auto ptrType = LLVM::LLVMPointerType::get(desc.getType());
94775e5f0aaSAlex Zinenko auto allocated =
94875e5f0aaSAlex Zinenko rewriter.create<LLVM::AllocaOp>(loc, ptrType, ValueRange{one});
94975e5f0aaSAlex Zinenko rewriter.create<LLVM::StoreOp>(loc, desc, allocated);
95075e5f0aaSAlex Zinenko return allocated;
95175e5f0aaSAlex Zinenko };
95275e5f0aaSAlex Zinenko
95375e5f0aaSAlex Zinenko auto sourcePtr = promote(unrankedSource);
95475e5f0aaSAlex Zinenko auto targetPtr = promote(unrankedTarget);
95575e5f0aaSAlex Zinenko
9562219f9f5SAdrian Kuegel unsigned typeSize =
9572219f9f5SAdrian Kuegel mlir::DataLayout::closest(op).getTypeSize(srcType.getElementType());
95875e5f0aaSAlex Zinenko auto elemSize = rewriter.create<LLVM::ConstantOp>(
9592219f9f5SAdrian Kuegel loc, getIndexType(), rewriter.getIndexAttr(typeSize));
96075e5f0aaSAlex Zinenko auto copyFn = LLVM::lookupOrCreateMemRefCopyFn(
96175e5f0aaSAlex Zinenko op->getParentOfType<ModuleOp>(), getIndexType(), sourcePtr.getType());
96275e5f0aaSAlex Zinenko rewriter.create<LLVM::CallOp>(loc, copyFn,
96375e5f0aaSAlex Zinenko ValueRange{elemSize, sourcePtr, targetPtr});
96475e5f0aaSAlex Zinenko rewriter.eraseOp(op);
96575e5f0aaSAlex Zinenko
96675e5f0aaSAlex Zinenko return success();
96775e5f0aaSAlex Zinenko }
968ab95ba70SStephan Herhut
969ab95ba70SStephan Herhut LogicalResult
matchAndRewrite__anon7a9e10510111::MemRefCopyOpLowering970ab95ba70SStephan Herhut matchAndRewrite(memref::CopyOp op, OpAdaptor adaptor,
971ab95ba70SStephan Herhut ConversionPatternRewriter &rewriter) const override {
972136d746eSJacques Pienaar auto srcType = op.getSource().getType().cast<BaseMemRefType>();
973136d746eSJacques Pienaar auto targetType = op.getTarget().getType().cast<BaseMemRefType>();
974ab95ba70SStephan Herhut
97527cd2a62SBenjamin Kramer auto isContiguousMemrefType = [](BaseMemRefType type) {
97627cd2a62SBenjamin Kramer auto memrefType = type.dyn_cast<mlir::MemRefType>();
97727cd2a62SBenjamin Kramer // We can use memcpy for memrefs if they have an identity layout or are
97827cd2a62SBenjamin Kramer // contiguous with an arbitrary offset. Ignore empty memrefs, which is a
97927cd2a62SBenjamin Kramer // special case handled by memrefCopy.
98027cd2a62SBenjamin Kramer return memrefType &&
98127cd2a62SBenjamin Kramer (memrefType.getLayout().isIdentity() ||
98227cd2a62SBenjamin Kramer (memrefType.hasStaticShape() && memrefType.getNumElements() > 0 &&
98327cd2a62SBenjamin Kramer isStaticShapeAndContiguousRowMajor(memrefType)));
98427cd2a62SBenjamin Kramer };
98527cd2a62SBenjamin Kramer
98627cd2a62SBenjamin Kramer if (isContiguousMemrefType(srcType) && isContiguousMemrefType(targetType))
987ab95ba70SStephan Herhut return lowerToMemCopyIntrinsic(op, adaptor, rewriter);
988ab95ba70SStephan Herhut
989ab95ba70SStephan Herhut return lowerToMemCopyFunctionCall(op, adaptor, rewriter);
990ab95ba70SStephan Herhut }
99175e5f0aaSAlex Zinenko };
99275e5f0aaSAlex Zinenko
99375e5f0aaSAlex Zinenko /// Extracts allocated, aligned pointers and offset from a ranked or unranked
99475e5f0aaSAlex Zinenko /// memref type. In unranked case, the fields are extracted from the underlying
99575e5f0aaSAlex Zinenko /// ranked descriptor.
extractPointersAndOffset(Location loc,ConversionPatternRewriter & rewriter,LLVMTypeConverter & typeConverter,Value originalOperand,Value convertedOperand,Value * allocatedPtr,Value * alignedPtr,Value * offset=nullptr)99675e5f0aaSAlex Zinenko static void extractPointersAndOffset(Location loc,
99775e5f0aaSAlex Zinenko ConversionPatternRewriter &rewriter,
99875e5f0aaSAlex Zinenko LLVMTypeConverter &typeConverter,
99975e5f0aaSAlex Zinenko Value originalOperand,
100075e5f0aaSAlex Zinenko Value convertedOperand,
100175e5f0aaSAlex Zinenko Value *allocatedPtr, Value *alignedPtr,
100275e5f0aaSAlex Zinenko Value *offset = nullptr) {
100375e5f0aaSAlex Zinenko Type operandType = originalOperand.getType();
100475e5f0aaSAlex Zinenko if (operandType.isa<MemRefType>()) {
100575e5f0aaSAlex Zinenko MemRefDescriptor desc(convertedOperand);
100675e5f0aaSAlex Zinenko *allocatedPtr = desc.allocatedPtr(rewriter, loc);
100775e5f0aaSAlex Zinenko *alignedPtr = desc.alignedPtr(rewriter, loc);
100875e5f0aaSAlex Zinenko if (offset != nullptr)
100975e5f0aaSAlex Zinenko *offset = desc.offset(rewriter, loc);
101075e5f0aaSAlex Zinenko return;
101175e5f0aaSAlex Zinenko }
101275e5f0aaSAlex Zinenko
101375e5f0aaSAlex Zinenko unsigned memorySpace =
101475e5f0aaSAlex Zinenko operandType.cast<UnrankedMemRefType>().getMemorySpaceAsInt();
101575e5f0aaSAlex Zinenko Type elementType = operandType.cast<UnrankedMemRefType>().getElementType();
101675e5f0aaSAlex Zinenko Type llvmElementType = typeConverter.convertType(elementType);
101775e5f0aaSAlex Zinenko Type elementPtrPtrType = LLVM::LLVMPointerType::get(
101875e5f0aaSAlex Zinenko LLVM::LLVMPointerType::get(llvmElementType, memorySpace));
101975e5f0aaSAlex Zinenko
102075e5f0aaSAlex Zinenko // Extract pointer to the underlying ranked memref descriptor and cast it to
102175e5f0aaSAlex Zinenko // ElemType**.
102275e5f0aaSAlex Zinenko UnrankedMemRefDescriptor unrankedDesc(convertedOperand);
102375e5f0aaSAlex Zinenko Value underlyingDescPtr = unrankedDesc.memRefDescPtr(rewriter, loc);
102475e5f0aaSAlex Zinenko
102575e5f0aaSAlex Zinenko *allocatedPtr = UnrankedMemRefDescriptor::allocatedPtr(
102675e5f0aaSAlex Zinenko rewriter, loc, underlyingDescPtr, elementPtrPtrType);
102775e5f0aaSAlex Zinenko *alignedPtr = UnrankedMemRefDescriptor::alignedPtr(
102875e5f0aaSAlex Zinenko rewriter, loc, typeConverter, underlyingDescPtr, elementPtrPtrType);
102975e5f0aaSAlex Zinenko if (offset != nullptr) {
103075e5f0aaSAlex Zinenko *offset = UnrankedMemRefDescriptor::offset(
103175e5f0aaSAlex Zinenko rewriter, loc, typeConverter, underlyingDescPtr, elementPtrPtrType);
103275e5f0aaSAlex Zinenko }
103375e5f0aaSAlex Zinenko }
103475e5f0aaSAlex Zinenko
103575e5f0aaSAlex Zinenko struct MemRefReinterpretCastOpLowering
103675e5f0aaSAlex Zinenko : public ConvertOpToLLVMPattern<memref::ReinterpretCastOp> {
103775e5f0aaSAlex Zinenko using ConvertOpToLLVMPattern<
103875e5f0aaSAlex Zinenko memref::ReinterpretCastOp>::ConvertOpToLLVMPattern;
103975e5f0aaSAlex Zinenko
104075e5f0aaSAlex Zinenko LogicalResult
matchAndRewrite__anon7a9e10510111::MemRefReinterpretCastOpLowering1041ef976337SRiver Riddle matchAndRewrite(memref::ReinterpretCastOp castOp, OpAdaptor adaptor,
104275e5f0aaSAlex Zinenko ConversionPatternRewriter &rewriter) const override {
1043136d746eSJacques Pienaar Type srcType = castOp.getSource().getType();
104475e5f0aaSAlex Zinenko
104575e5f0aaSAlex Zinenko Value descriptor;
104675e5f0aaSAlex Zinenko if (failed(convertSourceMemRefToDescriptor(rewriter, srcType, castOp,
104775e5f0aaSAlex Zinenko adaptor, &descriptor)))
104875e5f0aaSAlex Zinenko return failure();
104975e5f0aaSAlex Zinenko rewriter.replaceOp(castOp, {descriptor});
105075e5f0aaSAlex Zinenko return success();
105175e5f0aaSAlex Zinenko }
105275e5f0aaSAlex Zinenko
105375e5f0aaSAlex Zinenko private:
convertSourceMemRefToDescriptor__anon7a9e10510111::MemRefReinterpretCastOpLowering105475e5f0aaSAlex Zinenko LogicalResult convertSourceMemRefToDescriptor(
105575e5f0aaSAlex Zinenko ConversionPatternRewriter &rewriter, Type srcType,
105675e5f0aaSAlex Zinenko memref::ReinterpretCastOp castOp,
105775e5f0aaSAlex Zinenko memref::ReinterpretCastOp::Adaptor adaptor, Value *descriptor) const {
105875e5f0aaSAlex Zinenko MemRefType targetMemRefType =
105975e5f0aaSAlex Zinenko castOp.getResult().getType().cast<MemRefType>();
106075e5f0aaSAlex Zinenko auto llvmTargetDescriptorTy = typeConverter->convertType(targetMemRefType)
106175e5f0aaSAlex Zinenko .dyn_cast_or_null<LLVM::LLVMStructType>();
106275e5f0aaSAlex Zinenko if (!llvmTargetDescriptorTy)
106375e5f0aaSAlex Zinenko return failure();
106475e5f0aaSAlex Zinenko
106575e5f0aaSAlex Zinenko // Create descriptor.
106675e5f0aaSAlex Zinenko Location loc = castOp.getLoc();
106775e5f0aaSAlex Zinenko auto desc = MemRefDescriptor::undef(rewriter, loc, llvmTargetDescriptorTy);
106875e5f0aaSAlex Zinenko
106975e5f0aaSAlex Zinenko // Set allocated and aligned pointers.
107075e5f0aaSAlex Zinenko Value allocatedPtr, alignedPtr;
107175e5f0aaSAlex Zinenko extractPointersAndOffset(loc, rewriter, *getTypeConverter(),
1072136d746eSJacques Pienaar castOp.getSource(), adaptor.getSource(),
1073136d746eSJacques Pienaar &allocatedPtr, &alignedPtr);
107475e5f0aaSAlex Zinenko desc.setAllocatedPtr(rewriter, loc, allocatedPtr);
107575e5f0aaSAlex Zinenko desc.setAlignedPtr(rewriter, loc, alignedPtr);
107675e5f0aaSAlex Zinenko
107775e5f0aaSAlex Zinenko // Set offset.
107875e5f0aaSAlex Zinenko if (castOp.isDynamicOffset(0))
1079136d746eSJacques Pienaar desc.setOffset(rewriter, loc, adaptor.getOffsets()[0]);
108075e5f0aaSAlex Zinenko else
108175e5f0aaSAlex Zinenko desc.setConstantOffset(rewriter, loc, castOp.getStaticOffset(0));
108275e5f0aaSAlex Zinenko
108375e5f0aaSAlex Zinenko // Set sizes and strides.
108475e5f0aaSAlex Zinenko unsigned dynSizeId = 0;
108575e5f0aaSAlex Zinenko unsigned dynStrideId = 0;
108675e5f0aaSAlex Zinenko for (unsigned i = 0, e = targetMemRefType.getRank(); i < e; ++i) {
108775e5f0aaSAlex Zinenko if (castOp.isDynamicSize(i))
1088136d746eSJacques Pienaar desc.setSize(rewriter, loc, i, adaptor.getSizes()[dynSizeId++]);
108975e5f0aaSAlex Zinenko else
109075e5f0aaSAlex Zinenko desc.setConstantSize(rewriter, loc, i, castOp.getStaticSize(i));
109175e5f0aaSAlex Zinenko
109275e5f0aaSAlex Zinenko if (castOp.isDynamicStride(i))
1093136d746eSJacques Pienaar desc.setStride(rewriter, loc, i, adaptor.getStrides()[dynStrideId++]);
109475e5f0aaSAlex Zinenko else
109575e5f0aaSAlex Zinenko desc.setConstantStride(rewriter, loc, i, castOp.getStaticStride(i));
109675e5f0aaSAlex Zinenko }
109775e5f0aaSAlex Zinenko *descriptor = desc;
109875e5f0aaSAlex Zinenko return success();
109975e5f0aaSAlex Zinenko }
110075e5f0aaSAlex Zinenko };
110175e5f0aaSAlex Zinenko
110275e5f0aaSAlex Zinenko struct MemRefReshapeOpLowering
110375e5f0aaSAlex Zinenko : public ConvertOpToLLVMPattern<memref::ReshapeOp> {
110475e5f0aaSAlex Zinenko using ConvertOpToLLVMPattern<memref::ReshapeOp>::ConvertOpToLLVMPattern;
110575e5f0aaSAlex Zinenko
110675e5f0aaSAlex Zinenko LogicalResult
matchAndRewrite__anon7a9e10510111::MemRefReshapeOpLowering1107ef976337SRiver Riddle matchAndRewrite(memref::ReshapeOp reshapeOp, OpAdaptor adaptor,
110875e5f0aaSAlex Zinenko ConversionPatternRewriter &rewriter) const override {
1109136d746eSJacques Pienaar Type srcType = reshapeOp.getSource().getType();
111075e5f0aaSAlex Zinenko
111175e5f0aaSAlex Zinenko Value descriptor;
111275e5f0aaSAlex Zinenko if (failed(convertSourceMemRefToDescriptor(rewriter, srcType, reshapeOp,
111375e5f0aaSAlex Zinenko adaptor, &descriptor)))
111475e5f0aaSAlex Zinenko return failure();
1115ef976337SRiver Riddle rewriter.replaceOp(reshapeOp, {descriptor});
111675e5f0aaSAlex Zinenko return success();
111775e5f0aaSAlex Zinenko }
111875e5f0aaSAlex Zinenko
111975e5f0aaSAlex Zinenko private:
112075e5f0aaSAlex Zinenko LogicalResult
convertSourceMemRefToDescriptor__anon7a9e10510111::MemRefReshapeOpLowering112175e5f0aaSAlex Zinenko convertSourceMemRefToDescriptor(ConversionPatternRewriter &rewriter,
112275e5f0aaSAlex Zinenko Type srcType, memref::ReshapeOp reshapeOp,
112375e5f0aaSAlex Zinenko memref::ReshapeOp::Adaptor adaptor,
112475e5f0aaSAlex Zinenko Value *descriptor) const {
1125136d746eSJacques Pienaar auto shapeMemRefType = reshapeOp.getShape().getType().cast<MemRefType>();
11265380e30eSAshay Rane if (shapeMemRefType.hasStaticShape()) {
11275380e30eSAshay Rane MemRefType targetMemRefType =
11285380e30eSAshay Rane reshapeOp.getResult().getType().cast<MemRefType>();
11295380e30eSAshay Rane auto llvmTargetDescriptorTy =
11305380e30eSAshay Rane typeConverter->convertType(targetMemRefType)
11315380e30eSAshay Rane .dyn_cast_or_null<LLVM::LLVMStructType>();
11325380e30eSAshay Rane if (!llvmTargetDescriptorTy)
113375e5f0aaSAlex Zinenko return failure();
113475e5f0aaSAlex Zinenko
11355380e30eSAshay Rane // Create descriptor.
11365380e30eSAshay Rane Location loc = reshapeOp.getLoc();
11375380e30eSAshay Rane auto desc =
11385380e30eSAshay Rane MemRefDescriptor::undef(rewriter, loc, llvmTargetDescriptorTy);
11395380e30eSAshay Rane
11405380e30eSAshay Rane // Set allocated and aligned pointers.
11415380e30eSAshay Rane Value allocatedPtr, alignedPtr;
11425380e30eSAshay Rane extractPointersAndOffset(loc, rewriter, *getTypeConverter(),
1143136d746eSJacques Pienaar reshapeOp.getSource(), adaptor.getSource(),
11445380e30eSAshay Rane &allocatedPtr, &alignedPtr);
11455380e30eSAshay Rane desc.setAllocatedPtr(rewriter, loc, allocatedPtr);
11465380e30eSAshay Rane desc.setAlignedPtr(rewriter, loc, alignedPtr);
11475380e30eSAshay Rane
11485380e30eSAshay Rane // Extract the offset and strides from the type.
11495380e30eSAshay Rane int64_t offset;
11505380e30eSAshay Rane SmallVector<int64_t> strides;
11515380e30eSAshay Rane if (failed(getStridesAndOffset(targetMemRefType, strides, offset)))
11525380e30eSAshay Rane return rewriter.notifyMatchFailure(
11535380e30eSAshay Rane reshapeOp, "failed to get stride and offset exprs");
11545380e30eSAshay Rane
11555380e30eSAshay Rane if (!isStaticStrideOrOffset(offset))
11565380e30eSAshay Rane return rewriter.notifyMatchFailure(reshapeOp,
11575380e30eSAshay Rane "dynamic offset is unsupported");
11585380e30eSAshay Rane
11595380e30eSAshay Rane desc.setConstantOffset(rewriter, loc, offset);
11605fee1799SAshay Rane
11615fee1799SAshay Rane assert(targetMemRefType.getLayout().isIdentity() &&
11625fee1799SAshay Rane "Identity layout map is a precondition of a valid reshape op");
11635fee1799SAshay Rane
11645fee1799SAshay Rane Value stride = nullptr;
11655fee1799SAshay Rane int64_t targetRank = targetMemRefType.getRank();
11665fee1799SAshay Rane for (auto i : llvm::reverse(llvm::seq<int64_t>(0, targetRank))) {
11675fee1799SAshay Rane if (!ShapedType::isDynamicStrideOrOffset(strides[i])) {
11685fee1799SAshay Rane // If the stride for this dimension is dynamic, then use the product
11695fee1799SAshay Rane // of the sizes of the inner dimensions.
11705fee1799SAshay Rane stride = createIndexConstant(rewriter, loc, strides[i]);
11715fee1799SAshay Rane } else if (!stride) {
11725fee1799SAshay Rane // `stride` is null only in the first iteration of the loop. However,
11735fee1799SAshay Rane // since the target memref has an identity layout, we can safely set
11745fee1799SAshay Rane // the innermost stride to 1.
11755fee1799SAshay Rane stride = createIndexConstant(rewriter, loc, 1);
11765fee1799SAshay Rane }
11775fee1799SAshay Rane
11785fee1799SAshay Rane Value dimSize;
11795fee1799SAshay Rane int64_t size = targetMemRefType.getDimSize(i);
11805fee1799SAshay Rane // If the size of this dimension is dynamic, then load it at runtime
11815fee1799SAshay Rane // from the shape operand.
11825fee1799SAshay Rane if (!ShapedType::isDynamic(size)) {
11835fee1799SAshay Rane dimSize = createIndexConstant(rewriter, loc, size);
11845fee1799SAshay Rane } else {
1185136d746eSJacques Pienaar Value shapeOp = reshapeOp.getShape();
11865fee1799SAshay Rane Value index = createIndexConstant(rewriter, loc, i);
11875fee1799SAshay Rane dimSize = rewriter.create<memref::LoadOp>(loc, shapeOp, index);
1188d4217e6cSIvan Butygin Type indexType = getIndexType();
1189d4217e6cSIvan Butygin if (dimSize.getType() != indexType)
1190d4217e6cSIvan Butygin dimSize = typeConverter->materializeTargetConversion(
1191d4217e6cSIvan Butygin rewriter, loc, indexType, dimSize);
1192d4217e6cSIvan Butygin assert(dimSize && "Invalid memref element type");
11935fee1799SAshay Rane }
11945fee1799SAshay Rane
11955fee1799SAshay Rane desc.setSize(rewriter, loc, i, dimSize);
11965fee1799SAshay Rane desc.setStride(rewriter, loc, i, stride);
11975fee1799SAshay Rane
11985fee1799SAshay Rane // Prepare the stride value for the next dimension.
11995fee1799SAshay Rane stride = rewriter.create<LLVM::MulOp>(loc, stride, dimSize);
12005380e30eSAshay Rane }
12015380e30eSAshay Rane
12025380e30eSAshay Rane *descriptor = desc;
12035380e30eSAshay Rane return success();
12045380e30eSAshay Rane }
12055380e30eSAshay Rane
120675e5f0aaSAlex Zinenko // The shape is a rank-1 tensor with unknown length.
120775e5f0aaSAlex Zinenko Location loc = reshapeOp.getLoc();
1208136d746eSJacques Pienaar MemRefDescriptor shapeDesc(adaptor.getShape());
120975e5f0aaSAlex Zinenko Value resultRank = shapeDesc.size(rewriter, loc, 0);
121075e5f0aaSAlex Zinenko
121175e5f0aaSAlex Zinenko // Extract address space and element type.
121275e5f0aaSAlex Zinenko auto targetType =
121375e5f0aaSAlex Zinenko reshapeOp.getResult().getType().cast<UnrankedMemRefType>();
121475e5f0aaSAlex Zinenko unsigned addressSpace = targetType.getMemorySpaceAsInt();
121575e5f0aaSAlex Zinenko Type elementType = targetType.getElementType();
121675e5f0aaSAlex Zinenko
121775e5f0aaSAlex Zinenko // Create the unranked memref descriptor that holds the ranked one. The
121875e5f0aaSAlex Zinenko // inner descriptor is allocated on stack.
121975e5f0aaSAlex Zinenko auto targetDesc = UnrankedMemRefDescriptor::undef(
122075e5f0aaSAlex Zinenko rewriter, loc, typeConverter->convertType(targetType));
122175e5f0aaSAlex Zinenko targetDesc.setRank(rewriter, loc, resultRank);
122275e5f0aaSAlex Zinenko SmallVector<Value, 4> sizes;
122375e5f0aaSAlex Zinenko UnrankedMemRefDescriptor::computeSizes(rewriter, loc, *getTypeConverter(),
122475e5f0aaSAlex Zinenko targetDesc, sizes);
122575e5f0aaSAlex Zinenko Value underlyingDescPtr = rewriter.create<LLVM::AllocaOp>(
122675e5f0aaSAlex Zinenko loc, getVoidPtrType(), sizes.front(), llvm::None);
122775e5f0aaSAlex Zinenko targetDesc.setMemRefDescPtr(rewriter, loc, underlyingDescPtr);
122875e5f0aaSAlex Zinenko
122975e5f0aaSAlex Zinenko // Extract pointers and offset from the source memref.
123075e5f0aaSAlex Zinenko Value allocatedPtr, alignedPtr, offset;
123175e5f0aaSAlex Zinenko extractPointersAndOffset(loc, rewriter, *getTypeConverter(),
1232136d746eSJacques Pienaar reshapeOp.getSource(), adaptor.getSource(),
123375e5f0aaSAlex Zinenko &allocatedPtr, &alignedPtr, &offset);
123475e5f0aaSAlex Zinenko
123575e5f0aaSAlex Zinenko // Set pointers and offset.
123675e5f0aaSAlex Zinenko Type llvmElementType = typeConverter->convertType(elementType);
123775e5f0aaSAlex Zinenko auto elementPtrPtrType = LLVM::LLVMPointerType::get(
123875e5f0aaSAlex Zinenko LLVM::LLVMPointerType::get(llvmElementType, addressSpace));
123975e5f0aaSAlex Zinenko UnrankedMemRefDescriptor::setAllocatedPtr(rewriter, loc, underlyingDescPtr,
124075e5f0aaSAlex Zinenko elementPtrPtrType, allocatedPtr);
124175e5f0aaSAlex Zinenko UnrankedMemRefDescriptor::setAlignedPtr(rewriter, loc, *getTypeConverter(),
124275e5f0aaSAlex Zinenko underlyingDescPtr,
124375e5f0aaSAlex Zinenko elementPtrPtrType, alignedPtr);
124475e5f0aaSAlex Zinenko UnrankedMemRefDescriptor::setOffset(rewriter, loc, *getTypeConverter(),
124575e5f0aaSAlex Zinenko underlyingDescPtr, elementPtrPtrType,
124675e5f0aaSAlex Zinenko offset);
124775e5f0aaSAlex Zinenko
124875e5f0aaSAlex Zinenko // Use the offset pointer as base for further addressing. Copy over the new
124975e5f0aaSAlex Zinenko // shape and compute strides. For this, we create a loop from rank-1 to 0.
125075e5f0aaSAlex Zinenko Value targetSizesBase = UnrankedMemRefDescriptor::sizeBasePtr(
125175e5f0aaSAlex Zinenko rewriter, loc, *getTypeConverter(), underlyingDescPtr,
125275e5f0aaSAlex Zinenko elementPtrPtrType);
125375e5f0aaSAlex Zinenko Value targetStridesBase = UnrankedMemRefDescriptor::strideBasePtr(
125475e5f0aaSAlex Zinenko rewriter, loc, *getTypeConverter(), targetSizesBase, resultRank);
125575e5f0aaSAlex Zinenko Value shapeOperandPtr = shapeDesc.alignedPtr(rewriter, loc);
125675e5f0aaSAlex Zinenko Value oneIndex = createIndexConstant(rewriter, loc, 1);
125775e5f0aaSAlex Zinenko Value resultRankMinusOne =
125875e5f0aaSAlex Zinenko rewriter.create<LLVM::SubOp>(loc, resultRank, oneIndex);
125975e5f0aaSAlex Zinenko
126075e5f0aaSAlex Zinenko Block *initBlock = rewriter.getInsertionBlock();
126175e5f0aaSAlex Zinenko Type indexType = getTypeConverter()->getIndexType();
126275e5f0aaSAlex Zinenko Block::iterator remainingOpsIt = std::next(rewriter.getInsertionPoint());
126375e5f0aaSAlex Zinenko
126475e5f0aaSAlex Zinenko Block *condBlock = rewriter.createBlock(initBlock->getParent(), {},
1265e084679fSRiver Riddle {indexType, indexType}, {loc, loc});
126675e5f0aaSAlex Zinenko
126775e5f0aaSAlex Zinenko // Move the remaining initBlock ops to condBlock.
126875e5f0aaSAlex Zinenko Block *remainingBlock = rewriter.splitBlock(initBlock, remainingOpsIt);
126975e5f0aaSAlex Zinenko rewriter.mergeBlocks(remainingBlock, condBlock, ValueRange());
127075e5f0aaSAlex Zinenko
127175e5f0aaSAlex Zinenko rewriter.setInsertionPointToEnd(initBlock);
127275e5f0aaSAlex Zinenko rewriter.create<LLVM::BrOp>(loc, ValueRange({resultRankMinusOne, oneIndex}),
127375e5f0aaSAlex Zinenko condBlock);
127475e5f0aaSAlex Zinenko rewriter.setInsertionPointToStart(condBlock);
127575e5f0aaSAlex Zinenko Value indexArg = condBlock->getArgument(0);
127675e5f0aaSAlex Zinenko Value strideArg = condBlock->getArgument(1);
127775e5f0aaSAlex Zinenko
127875e5f0aaSAlex Zinenko Value zeroIndex = createIndexConstant(rewriter, loc, 0);
127975e5f0aaSAlex Zinenko Value pred = rewriter.create<LLVM::ICmpOp>(
128075e5f0aaSAlex Zinenko loc, IntegerType::get(rewriter.getContext(), 1),
128175e5f0aaSAlex Zinenko LLVM::ICmpPredicate::sge, indexArg, zeroIndex);
128275e5f0aaSAlex Zinenko
128375e5f0aaSAlex Zinenko Block *bodyBlock =
128475e5f0aaSAlex Zinenko rewriter.splitBlock(condBlock, rewriter.getInsertionPoint());
128575e5f0aaSAlex Zinenko rewriter.setInsertionPointToStart(bodyBlock);
128675e5f0aaSAlex Zinenko
128775e5f0aaSAlex Zinenko // Copy size from shape to descriptor.
128875e5f0aaSAlex Zinenko Type llvmIndexPtrType = LLVM::LLVMPointerType::get(indexType);
128975e5f0aaSAlex Zinenko Value sizeLoadGep = rewriter.create<LLVM::GEPOp>(
129075e5f0aaSAlex Zinenko loc, llvmIndexPtrType, shapeOperandPtr, ValueRange{indexArg});
129175e5f0aaSAlex Zinenko Value size = rewriter.create<LLVM::LoadOp>(loc, sizeLoadGep);
129275e5f0aaSAlex Zinenko UnrankedMemRefDescriptor::setSize(rewriter, loc, *getTypeConverter(),
129375e5f0aaSAlex Zinenko targetSizesBase, indexArg, size);
129475e5f0aaSAlex Zinenko
129575e5f0aaSAlex Zinenko // Write stride value and compute next one.
129675e5f0aaSAlex Zinenko UnrankedMemRefDescriptor::setStride(rewriter, loc, *getTypeConverter(),
129775e5f0aaSAlex Zinenko targetStridesBase, indexArg, strideArg);
129875e5f0aaSAlex Zinenko Value nextStride = rewriter.create<LLVM::MulOp>(loc, strideArg, size);
129975e5f0aaSAlex Zinenko
130075e5f0aaSAlex Zinenko // Decrement loop counter and branch back.
130175e5f0aaSAlex Zinenko Value decrement = rewriter.create<LLVM::SubOp>(loc, indexArg, oneIndex);
130275e5f0aaSAlex Zinenko rewriter.create<LLVM::BrOp>(loc, ValueRange({decrement, nextStride}),
130375e5f0aaSAlex Zinenko condBlock);
130475e5f0aaSAlex Zinenko
130575e5f0aaSAlex Zinenko Block *remainder =
130675e5f0aaSAlex Zinenko rewriter.splitBlock(bodyBlock, rewriter.getInsertionPoint());
130775e5f0aaSAlex Zinenko
130875e5f0aaSAlex Zinenko // Hook up the cond exit to the remainder.
130975e5f0aaSAlex Zinenko rewriter.setInsertionPointToEnd(condBlock);
131075e5f0aaSAlex Zinenko rewriter.create<LLVM::CondBrOp>(loc, pred, bodyBlock, llvm::None, remainder,
131175e5f0aaSAlex Zinenko llvm::None);
131275e5f0aaSAlex Zinenko
131375e5f0aaSAlex Zinenko // Reset position to beginning of new remainder block.
131475e5f0aaSAlex Zinenko rewriter.setInsertionPointToStart(remainder);
131575e5f0aaSAlex Zinenko
131675e5f0aaSAlex Zinenko *descriptor = targetDesc;
131775e5f0aaSAlex Zinenko return success();
131875e5f0aaSAlex Zinenko }
131975e5f0aaSAlex Zinenko };
132075e5f0aaSAlex Zinenko
1321381c3b92SYi Zhang /// Helper function to convert a vector of `OpFoldResult`s into a vector of
1322381c3b92SYi Zhang /// `Value`s.
getAsValues(OpBuilder & b,Location loc,Type & llvmIndexType,ArrayRef<OpFoldResult> valueOrAttrVec)1323381c3b92SYi Zhang static SmallVector<Value> getAsValues(OpBuilder &b, Location loc,
1324381c3b92SYi Zhang Type &llvmIndexType,
1325381c3b92SYi Zhang ArrayRef<OpFoldResult> valueOrAttrVec) {
1326381c3b92SYi Zhang return llvm::to_vector<4>(
1327381c3b92SYi Zhang llvm::map_range(valueOrAttrVec, [&](OpFoldResult value) -> Value {
1328381c3b92SYi Zhang if (auto attr = value.dyn_cast<Attribute>())
1329381c3b92SYi Zhang return b.create<LLVM::ConstantOp>(loc, llvmIndexType, attr);
1330381c3b92SYi Zhang return value.get<Value>();
1331381c3b92SYi Zhang }));
1332381c3b92SYi Zhang }
1333381c3b92SYi Zhang
1334381c3b92SYi Zhang /// Compute a map that for a given dimension of the expanded type gives the
1335381c3b92SYi Zhang /// dimension in the collapsed type it maps to. Essentially its the inverse of
1336381c3b92SYi Zhang /// the `reassocation` maps.
1337381c3b92SYi Zhang static DenseMap<int64_t, int64_t>
getExpandedDimToCollapsedDimMap(ArrayRef<ReassociationIndices> reassociation)1338381c3b92SYi Zhang getExpandedDimToCollapsedDimMap(ArrayRef<ReassociationIndices> reassociation) {
1339381c3b92SYi Zhang llvm::DenseMap<int64_t, int64_t> expandedDimToCollapsedDim;
1340381c3b92SYi Zhang for (auto &en : enumerate(reassociation)) {
1341381c3b92SYi Zhang for (auto dim : en.value())
1342381c3b92SYi Zhang expandedDimToCollapsedDim[dim] = en.index();
1343381c3b92SYi Zhang }
1344381c3b92SYi Zhang return expandedDimToCollapsedDim;
1345381c3b92SYi Zhang }
1346381c3b92SYi Zhang
1347381c3b92SYi Zhang static OpFoldResult
getExpandedOutputDimSize(OpBuilder & b,Location loc,Type & llvmIndexType,int64_t outDimIndex,ArrayRef<int64_t> outStaticShape,MemRefDescriptor & inDesc,ArrayRef<int64_t> inStaticShape,ArrayRef<ReassociationIndices> reassocation,DenseMap<int64_t,int64_t> & outDimToInDimMap)1348381c3b92SYi Zhang getExpandedOutputDimSize(OpBuilder &b, Location loc, Type &llvmIndexType,
1349381c3b92SYi Zhang int64_t outDimIndex, ArrayRef<int64_t> outStaticShape,
1350381c3b92SYi Zhang MemRefDescriptor &inDesc,
1351381c3b92SYi Zhang ArrayRef<int64_t> inStaticShape,
1352381c3b92SYi Zhang ArrayRef<ReassociationIndices> reassocation,
1353381c3b92SYi Zhang DenseMap<int64_t, int64_t> &outDimToInDimMap) {
1354381c3b92SYi Zhang int64_t outDimSize = outStaticShape[outDimIndex];
1355381c3b92SYi Zhang if (!ShapedType::isDynamic(outDimSize))
1356381c3b92SYi Zhang return b.getIndexAttr(outDimSize);
1357381c3b92SYi Zhang
1358381c3b92SYi Zhang // Calculate the multiplication of all the out dim sizes except the
1359381c3b92SYi Zhang // current dim.
1360381c3b92SYi Zhang int64_t inDimIndex = outDimToInDimMap[outDimIndex];
1361381c3b92SYi Zhang int64_t otherDimSizesMul = 1;
1362381c3b92SYi Zhang for (auto otherDimIndex : reassocation[inDimIndex]) {
1363381c3b92SYi Zhang if (otherDimIndex == static_cast<unsigned>(outDimIndex))
1364381c3b92SYi Zhang continue;
1365381c3b92SYi Zhang int64_t otherDimSize = outStaticShape[otherDimIndex];
1366381c3b92SYi Zhang assert(!ShapedType::isDynamic(otherDimSize) &&
1367381c3b92SYi Zhang "single dimension cannot be expanded into multiple dynamic "
1368381c3b92SYi Zhang "dimensions");
1369381c3b92SYi Zhang otherDimSizesMul *= otherDimSize;
1370381c3b92SYi Zhang }
1371381c3b92SYi Zhang
1372381c3b92SYi Zhang // outDimSize = inDimSize / otherOutDimSizesMul
1373381c3b92SYi Zhang int64_t inDimSize = inStaticShape[inDimIndex];
1374381c3b92SYi Zhang Value inDimSizeDynamic =
1375381c3b92SYi Zhang ShapedType::isDynamic(inDimSize)
1376381c3b92SYi Zhang ? inDesc.size(b, loc, inDimIndex)
1377381c3b92SYi Zhang : b.create<LLVM::ConstantOp>(loc, llvmIndexType,
1378381c3b92SYi Zhang b.getIndexAttr(inDimSize));
1379381c3b92SYi Zhang Value outDimSizeDynamic = b.create<LLVM::SDivOp>(
1380381c3b92SYi Zhang loc, inDimSizeDynamic,
1381381c3b92SYi Zhang b.create<LLVM::ConstantOp>(loc, llvmIndexType,
1382381c3b92SYi Zhang b.getIndexAttr(otherDimSizesMul)));
1383381c3b92SYi Zhang return outDimSizeDynamic;
1384381c3b92SYi Zhang }
1385381c3b92SYi Zhang
getCollapsedOutputDimSize(OpBuilder & b,Location loc,Type & llvmIndexType,int64_t outDimIndex,int64_t outDimSize,ArrayRef<int64_t> inStaticShape,MemRefDescriptor & inDesc,ArrayRef<ReassociationIndices> reassocation)1386381c3b92SYi Zhang static OpFoldResult getCollapsedOutputDimSize(
1387381c3b92SYi Zhang OpBuilder &b, Location loc, Type &llvmIndexType, int64_t outDimIndex,
1388381c3b92SYi Zhang int64_t outDimSize, ArrayRef<int64_t> inStaticShape,
1389381c3b92SYi Zhang MemRefDescriptor &inDesc, ArrayRef<ReassociationIndices> reassocation) {
1390381c3b92SYi Zhang if (!ShapedType::isDynamic(outDimSize))
1391381c3b92SYi Zhang return b.getIndexAttr(outDimSize);
1392381c3b92SYi Zhang
1393381c3b92SYi Zhang Value c1 = b.create<LLVM::ConstantOp>(loc, llvmIndexType, b.getIndexAttr(1));
1394381c3b92SYi Zhang Value outDimSizeDynamic = c1;
1395381c3b92SYi Zhang for (auto inDimIndex : reassocation[outDimIndex]) {
1396381c3b92SYi Zhang int64_t inDimSize = inStaticShape[inDimIndex];
1397381c3b92SYi Zhang Value inDimSizeDynamic =
1398381c3b92SYi Zhang ShapedType::isDynamic(inDimSize)
1399381c3b92SYi Zhang ? inDesc.size(b, loc, inDimIndex)
1400381c3b92SYi Zhang : b.create<LLVM::ConstantOp>(loc, llvmIndexType,
1401381c3b92SYi Zhang b.getIndexAttr(inDimSize));
1402381c3b92SYi Zhang outDimSizeDynamic =
1403381c3b92SYi Zhang b.create<LLVM::MulOp>(loc, outDimSizeDynamic, inDimSizeDynamic);
1404381c3b92SYi Zhang }
1405381c3b92SYi Zhang return outDimSizeDynamic;
1406381c3b92SYi Zhang }
1407381c3b92SYi Zhang
1408381c3b92SYi Zhang static SmallVector<OpFoldResult, 4>
getCollapsedOutputShape(OpBuilder & b,Location loc,Type & llvmIndexType,ArrayRef<ReassociationIndices> reassociation,ArrayRef<int64_t> inStaticShape,MemRefDescriptor & inDesc,ArrayRef<int64_t> outStaticShape)1409381c3b92SYi Zhang getCollapsedOutputShape(OpBuilder &b, Location loc, Type &llvmIndexType,
1410e1318078SYi Zhang ArrayRef<ReassociationIndices> reassociation,
1411381c3b92SYi Zhang ArrayRef<int64_t> inStaticShape,
1412381c3b92SYi Zhang MemRefDescriptor &inDesc,
1413381c3b92SYi Zhang ArrayRef<int64_t> outStaticShape) {
1414381c3b92SYi Zhang return llvm::to_vector<4>(llvm::map_range(
1415381c3b92SYi Zhang llvm::seq<int64_t>(0, outStaticShape.size()), [&](int64_t outDimIndex) {
1416381c3b92SYi Zhang return getCollapsedOutputDimSize(b, loc, llvmIndexType, outDimIndex,
1417381c3b92SYi Zhang outStaticShape[outDimIndex],
1418e1318078SYi Zhang inStaticShape, inDesc, reassociation);
1419381c3b92SYi Zhang }));
1420381c3b92SYi Zhang }
1421381c3b92SYi Zhang
1422381c3b92SYi Zhang static SmallVector<OpFoldResult, 4>
getExpandedOutputShape(OpBuilder & b,Location loc,Type & llvmIndexType,ArrayRef<ReassociationIndices> reassociation,ArrayRef<int64_t> inStaticShape,MemRefDescriptor & inDesc,ArrayRef<int64_t> outStaticShape)1423381c3b92SYi Zhang getExpandedOutputShape(OpBuilder &b, Location loc, Type &llvmIndexType,
1424e1318078SYi Zhang ArrayRef<ReassociationIndices> reassociation,
1425381c3b92SYi Zhang ArrayRef<int64_t> inStaticShape,
1426381c3b92SYi Zhang MemRefDescriptor &inDesc,
1427381c3b92SYi Zhang ArrayRef<int64_t> outStaticShape) {
1428381c3b92SYi Zhang DenseMap<int64_t, int64_t> outDimToInDimMap =
1429e1318078SYi Zhang getExpandedDimToCollapsedDimMap(reassociation);
1430381c3b92SYi Zhang return llvm::to_vector<4>(llvm::map_range(
1431381c3b92SYi Zhang llvm::seq<int64_t>(0, outStaticShape.size()), [&](int64_t outDimIndex) {
1432381c3b92SYi Zhang return getExpandedOutputDimSize(b, loc, llvmIndexType, outDimIndex,
1433381c3b92SYi Zhang outStaticShape, inDesc, inStaticShape,
1434e1318078SYi Zhang reassociation, outDimToInDimMap);
1435381c3b92SYi Zhang }));
1436381c3b92SYi Zhang }
1437381c3b92SYi Zhang
1438381c3b92SYi Zhang static SmallVector<Value>
getDynamicOutputShape(OpBuilder & b,Location loc,Type & llvmIndexType,ArrayRef<ReassociationIndices> reassociation,ArrayRef<int64_t> inStaticShape,MemRefDescriptor & inDesc,ArrayRef<int64_t> outStaticShape)1439381c3b92SYi Zhang getDynamicOutputShape(OpBuilder &b, Location loc, Type &llvmIndexType,
1440e1318078SYi Zhang ArrayRef<ReassociationIndices> reassociation,
1441381c3b92SYi Zhang ArrayRef<int64_t> inStaticShape, MemRefDescriptor &inDesc,
1442381c3b92SYi Zhang ArrayRef<int64_t> outStaticShape) {
1443381c3b92SYi Zhang return outStaticShape.size() < inStaticShape.size()
1444381c3b92SYi Zhang ? getAsValues(b, loc, llvmIndexType,
1445381c3b92SYi Zhang getCollapsedOutputShape(b, loc, llvmIndexType,
1446e1318078SYi Zhang reassociation, inStaticShape,
1447381c3b92SYi Zhang inDesc, outStaticShape))
1448381c3b92SYi Zhang : getAsValues(b, loc, llvmIndexType,
1449381c3b92SYi Zhang getExpandedOutputShape(b, loc, llvmIndexType,
1450e1318078SYi Zhang reassociation, inStaticShape,
1451381c3b92SYi Zhang inDesc, outStaticShape));
1452381c3b92SYi Zhang }
1453381c3b92SYi Zhang
fillInStridesForExpandedMemDescriptor(OpBuilder & b,Location loc,MemRefType srcType,MemRefDescriptor & srcDesc,MemRefDescriptor & dstDesc,ArrayRef<ReassociationIndices> reassociation)1454e1318078SYi Zhang static void fillInStridesForExpandedMemDescriptor(
1455e1318078SYi Zhang OpBuilder &b, Location loc, MemRefType srcType, MemRefDescriptor &srcDesc,
1456e1318078SYi Zhang MemRefDescriptor &dstDesc, ArrayRef<ReassociationIndices> reassociation) {
1457e1318078SYi Zhang // See comments for computeExpandedLayoutMap for details on how the strides
1458e1318078SYi Zhang // are calculated.
1459e1318078SYi Zhang for (auto &en : llvm::enumerate(reassociation)) {
1460e1318078SYi Zhang auto currentStrideToExpand = srcDesc.stride(b, loc, en.index());
1461e1318078SYi Zhang for (auto dstIndex : llvm::reverse(en.value())) {
1462e1318078SYi Zhang dstDesc.setStride(b, loc, dstIndex, currentStrideToExpand);
1463e1318078SYi Zhang Value size = dstDesc.size(b, loc, dstIndex);
1464e1318078SYi Zhang currentStrideToExpand =
1465e1318078SYi Zhang b.create<LLVM::MulOp>(loc, size, currentStrideToExpand);
1466e1318078SYi Zhang }
1467e1318078SYi Zhang }
1468e1318078SYi Zhang }
1469e1318078SYi Zhang
fillInStridesForCollapsedMemDescriptor(ConversionPatternRewriter & rewriter,Location loc,Operation * op,TypeConverter * typeConverter,MemRefType srcType,MemRefDescriptor & srcDesc,MemRefDescriptor & dstDesc,ArrayRef<ReassociationIndices> reassociation)1470e1318078SYi Zhang static void fillInStridesForCollapsedMemDescriptor(
1471e1318078SYi Zhang ConversionPatternRewriter &rewriter, Location loc, Operation *op,
1472e1318078SYi Zhang TypeConverter *typeConverter, MemRefType srcType, MemRefDescriptor &srcDesc,
1473e1318078SYi Zhang MemRefDescriptor &dstDesc, ArrayRef<ReassociationIndices> reassociation) {
1474e1318078SYi Zhang // See comments for computeCollapsedLayoutMap for details on how the strides
1475e1318078SYi Zhang // are calculated.
1476e1318078SYi Zhang auto srcShape = srcType.getShape();
1477e1318078SYi Zhang for (auto &en : llvm::enumerate(reassociation)) {
1478e1318078SYi Zhang rewriter.setInsertionPoint(op);
1479e1318078SYi Zhang auto dstIndex = en.index();
1480e1318078SYi Zhang ArrayRef<int64_t> ref = llvm::makeArrayRef(en.value());
1481e1318078SYi Zhang while (srcShape[ref.back()] == 1 && ref.size() > 1)
1482e1318078SYi Zhang ref = ref.drop_back();
1483e1318078SYi Zhang if (!ShapedType::isDynamic(srcShape[ref.back()]) || ref.size() == 1) {
1484e1318078SYi Zhang dstDesc.setStride(rewriter, loc, dstIndex,
1485e1318078SYi Zhang srcDesc.stride(rewriter, loc, ref.back()));
1486e1318078SYi Zhang } else {
1487e1318078SYi Zhang // Iterate over the source strides in reverse order. Skip over the
1488e1318078SYi Zhang // dimensions whose size is 1.
1489e1318078SYi Zhang // TODO: we should take the minimum stride in the reassociation group
1490e1318078SYi Zhang // instead of just the first where the dimension is not 1.
1491e1318078SYi Zhang //
1492e1318078SYi Zhang // +------------------------------------------------------+
1493e1318078SYi Zhang // | curEntry: |
1494e1318078SYi Zhang // | %srcStride = strides[srcIndex] |
1495e1318078SYi Zhang // | %neOne = cmp sizes[srcIndex],1 +--+
1496e1318078SYi Zhang // | cf.cond_br %neOne, continue(%srcStride), nextEntry | |
1497e1318078SYi Zhang // +-------------------------+----------------------------+ |
1498e1318078SYi Zhang // | |
1499e1318078SYi Zhang // v |
1500e1318078SYi Zhang // +-----------------------------+ |
1501e1318078SYi Zhang // | nextEntry: | |
1502e1318078SYi Zhang // | ... +---+ |
1503e1318078SYi Zhang // +--------------+--------------+ | |
1504e1318078SYi Zhang // | | |
1505e1318078SYi Zhang // v | |
1506e1318078SYi Zhang // +-----------------------------+ | |
1507e1318078SYi Zhang // | nextEntry: | | |
1508e1318078SYi Zhang // | ... | | |
1509e1318078SYi Zhang // +--------------+--------------+ | +--------+
1510e1318078SYi Zhang // | | |
1511e1318078SYi Zhang // v v v
1512e1318078SYi Zhang // +--------------------------------------------------+
1513e1318078SYi Zhang // | continue(%newStride): |
1514e1318078SYi Zhang // | %newMemRefDes = setStride(%newStride,dstIndex) |
1515e1318078SYi Zhang // +--------------------------------------------------+
1516e1318078SYi Zhang OpBuilder::InsertionGuard guard(rewriter);
1517e1318078SYi Zhang Block *initBlock = rewriter.getInsertionBlock();
1518e1318078SYi Zhang Block *continueBlock =
1519e1318078SYi Zhang rewriter.splitBlock(initBlock, rewriter.getInsertionPoint());
1520e1318078SYi Zhang continueBlock->insertArgument(unsigned(0), srcDesc.getIndexType(), loc);
1521e1318078SYi Zhang rewriter.setInsertionPointToStart(continueBlock);
1522e1318078SYi Zhang dstDesc.setStride(rewriter, loc, dstIndex, continueBlock->getArgument(0));
1523e1318078SYi Zhang
1524e1318078SYi Zhang Block *curEntryBlock = initBlock;
1525e1318078SYi Zhang Block *nextEntryBlock;
1526e1318078SYi Zhang for (auto srcIndex : llvm::reverse(ref)) {
1527e1318078SYi Zhang if (srcShape[srcIndex] == 1 && srcIndex != ref.front())
1528e1318078SYi Zhang continue;
1529e1318078SYi Zhang rewriter.setInsertionPointToEnd(curEntryBlock);
1530e1318078SYi Zhang Value srcStride = srcDesc.stride(rewriter, loc, srcIndex);
1531e1318078SYi Zhang if (srcIndex == ref.front()) {
1532e1318078SYi Zhang rewriter.create<LLVM::BrOp>(loc, srcStride, continueBlock);
1533e1318078SYi Zhang break;
1534e1318078SYi Zhang }
1535e1318078SYi Zhang Value one = rewriter.create<LLVM::ConstantOp>(
1536e1318078SYi Zhang loc, typeConverter->convertType(rewriter.getI64Type()),
1537e1318078SYi Zhang rewriter.getI32IntegerAttr(1));
1538e1318078SYi Zhang Value predNeOne = rewriter.create<LLVM::ICmpOp>(
1539e1318078SYi Zhang loc, LLVM::ICmpPredicate::ne, srcDesc.size(rewriter, loc, srcIndex),
1540e1318078SYi Zhang one);
1541e1318078SYi Zhang {
1542e1318078SYi Zhang OpBuilder::InsertionGuard guard(rewriter);
1543e1318078SYi Zhang nextEntryBlock = rewriter.createBlock(
1544e1318078SYi Zhang initBlock->getParent(), Region::iterator(continueBlock), {});
1545e1318078SYi Zhang }
1546e1318078SYi Zhang rewriter.create<LLVM::CondBrOp>(loc, predNeOne, continueBlock,
1547e1318078SYi Zhang srcStride, nextEntryBlock, llvm::None);
1548e1318078SYi Zhang curEntryBlock = nextEntryBlock;
1549e1318078SYi Zhang }
1550e1318078SYi Zhang }
1551e1318078SYi Zhang }
1552e1318078SYi Zhang }
1553e1318078SYi Zhang
fillInDynamicStridesForMemDescriptor(ConversionPatternRewriter & b,Location loc,Operation * op,TypeConverter * typeConverter,MemRefType srcType,MemRefType dstType,MemRefDescriptor & srcDesc,MemRefDescriptor & dstDesc,ArrayRef<ReassociationIndices> reassociation)1554e1318078SYi Zhang static void fillInDynamicStridesForMemDescriptor(
1555e1318078SYi Zhang ConversionPatternRewriter &b, Location loc, Operation *op,
1556e1318078SYi Zhang TypeConverter *typeConverter, MemRefType srcType, MemRefType dstType,
1557e1318078SYi Zhang MemRefDescriptor &srcDesc, MemRefDescriptor &dstDesc,
1558e1318078SYi Zhang ArrayRef<ReassociationIndices> reassociation) {
1559e1318078SYi Zhang if (srcType.getRank() > dstType.getRank())
1560e1318078SYi Zhang fillInStridesForCollapsedMemDescriptor(b, loc, op, typeConverter, srcType,
1561e1318078SYi Zhang srcDesc, dstDesc, reassociation);
1562e1318078SYi Zhang else
1563e1318078SYi Zhang fillInStridesForExpandedMemDescriptor(b, loc, srcType, srcDesc, dstDesc,
1564e1318078SYi Zhang reassociation);
1565e1318078SYi Zhang }
1566e1318078SYi Zhang
156746ef86b5SAlexander Belyaev // ReshapeOp creates a new view descriptor of the proper rank.
156846ef86b5SAlexander Belyaev // For now, the only conversion supported is for target MemRef with static sizes
156946ef86b5SAlexander Belyaev // and strides.
157046ef86b5SAlexander Belyaev template <typename ReshapeOp>
157146ef86b5SAlexander Belyaev class ReassociatingReshapeOpConversion
157246ef86b5SAlexander Belyaev : public ConvertOpToLLVMPattern<ReshapeOp> {
157346ef86b5SAlexander Belyaev public:
157446ef86b5SAlexander Belyaev using ConvertOpToLLVMPattern<ReshapeOp>::ConvertOpToLLVMPattern;
157546ef86b5SAlexander Belyaev using ReshapeOpAdaptor = typename ReshapeOp::Adaptor;
157646ef86b5SAlexander Belyaev
157746ef86b5SAlexander Belyaev LogicalResult
matchAndRewrite(ReshapeOp reshapeOp,typename ReshapeOp::Adaptor adaptor,ConversionPatternRewriter & rewriter) const1578ef976337SRiver Riddle matchAndRewrite(ReshapeOp reshapeOp, typename ReshapeOp::Adaptor adaptor,
157946ef86b5SAlexander Belyaev ConversionPatternRewriter &rewriter) const override {
158046ef86b5SAlexander Belyaev MemRefType dstType = reshapeOp.getResultType();
1581381c3b92SYi Zhang MemRefType srcType = reshapeOp.getSrcType();
1582499703e9SBenoit Jacob
158346ef86b5SAlexander Belyaev int64_t offset;
158446ef86b5SAlexander Belyaev SmallVector<int64_t, 4> strides;
1585381c3b92SYi Zhang if (failed(getStridesAndOffset(dstType, strides, offset))) {
1586381c3b92SYi Zhang return rewriter.notifyMatchFailure(
1587381c3b92SYi Zhang reshapeOp, "failed to get stride and offset exprs");
1588381c3b92SYi Zhang }
158946ef86b5SAlexander Belyaev
1590136d746eSJacques Pienaar MemRefDescriptor srcDesc(adaptor.getSrc());
159146ef86b5SAlexander Belyaev Location loc = reshapeOp->getLoc();
1592381c3b92SYi Zhang auto dstDesc = MemRefDescriptor::undef(
1593381c3b92SYi Zhang rewriter, loc, this->typeConverter->convertType(dstType));
1594381c3b92SYi Zhang dstDesc.setAllocatedPtr(rewriter, loc, srcDesc.allocatedPtr(rewriter, loc));
1595381c3b92SYi Zhang dstDesc.setAlignedPtr(rewriter, loc, srcDesc.alignedPtr(rewriter, loc));
1596381c3b92SYi Zhang dstDesc.setOffset(rewriter, loc, srcDesc.offset(rewriter, loc));
1597381c3b92SYi Zhang
1598381c3b92SYi Zhang ArrayRef<int64_t> srcStaticShape = srcType.getShape();
1599381c3b92SYi Zhang ArrayRef<int64_t> dstStaticShape = dstType.getShape();
1600381c3b92SYi Zhang Type llvmIndexType =
1601381c3b92SYi Zhang this->typeConverter->convertType(rewriter.getIndexType());
1602381c3b92SYi Zhang SmallVector<Value> dstShape = getDynamicOutputShape(
1603381c3b92SYi Zhang rewriter, loc, llvmIndexType, reshapeOp.getReassociationIndices(),
1604381c3b92SYi Zhang srcStaticShape, srcDesc, dstStaticShape);
1605381c3b92SYi Zhang for (auto &en : llvm::enumerate(dstShape))
1606381c3b92SYi Zhang dstDesc.setSize(rewriter, loc, en.index(), en.value());
1607381c3b92SYi Zhang
16085380e30eSAshay Rane if (llvm::all_of(strides, isStaticStrideOrOffset)) {
1609381c3b92SYi Zhang for (auto &en : llvm::enumerate(strides))
1610381c3b92SYi Zhang dstDesc.setConstantStride(rewriter, loc, en.index(), en.value());
1611e1318078SYi Zhang } else if (srcType.getLayout().isIdentity() &&
1612e1318078SYi Zhang dstType.getLayout().isIdentity()) {
1613381c3b92SYi Zhang Value c1 = rewriter.create<LLVM::ConstantOp>(loc, llvmIndexType,
1614381c3b92SYi Zhang rewriter.getIndexAttr(1));
1615381c3b92SYi Zhang Value stride = c1;
1616381c3b92SYi Zhang for (auto dimIndex :
1617381c3b92SYi Zhang llvm::reverse(llvm::seq<int64_t>(0, dstShape.size()))) {
1618381c3b92SYi Zhang dstDesc.setStride(rewriter, loc, dimIndex, stride);
1619381c3b92SYi Zhang stride = rewriter.create<LLVM::MulOp>(loc, dstShape[dimIndex], stride);
1620381c3b92SYi Zhang }
1621e1318078SYi Zhang } else {
1622e1318078SYi Zhang // There could be mixed static/dynamic strides. For simplicity, we
1623e1318078SYi Zhang // recompute all strides if there is at least one dynamic stride.
1624e1318078SYi Zhang fillInDynamicStridesForMemDescriptor(
1625e1318078SYi Zhang rewriter, loc, reshapeOp, this->typeConverter, srcType, dstType,
1626e1318078SYi Zhang srcDesc, dstDesc, reshapeOp.getReassociationIndices());
1627381c3b92SYi Zhang }
1628381c3b92SYi Zhang rewriter.replaceOp(reshapeOp, {dstDesc});
162946ef86b5SAlexander Belyaev return success();
163046ef86b5SAlexander Belyaev }
163146ef86b5SAlexander Belyaev };
1632381c3b92SYi Zhang
163375e5f0aaSAlex Zinenko /// Conversion pattern that transforms a subview op into:
163475e5f0aaSAlex Zinenko /// 1. An `llvm.mlir.undef` operation to create a memref descriptor
163575e5f0aaSAlex Zinenko /// 2. Updates to the descriptor to introduce the data ptr, offset, size
163675e5f0aaSAlex Zinenko /// and stride.
163775e5f0aaSAlex Zinenko /// The subview op is replaced by the descriptor.
163875e5f0aaSAlex Zinenko struct SubViewOpLowering : public ConvertOpToLLVMPattern<memref::SubViewOp> {
163975e5f0aaSAlex Zinenko using ConvertOpToLLVMPattern<memref::SubViewOp>::ConvertOpToLLVMPattern;
164075e5f0aaSAlex Zinenko
164175e5f0aaSAlex Zinenko LogicalResult
matchAndRewrite__anon7a9e10510111::SubViewOpLowering1642ef976337SRiver Riddle matchAndRewrite(memref::SubViewOp subViewOp, OpAdaptor adaptor,
164375e5f0aaSAlex Zinenko ConversionPatternRewriter &rewriter) const override {
164475e5f0aaSAlex Zinenko auto loc = subViewOp.getLoc();
164575e5f0aaSAlex Zinenko
1646136d746eSJacques Pienaar auto sourceMemRefType = subViewOp.getSource().getType().cast<MemRefType>();
164775e5f0aaSAlex Zinenko auto sourceElementTy =
164875e5f0aaSAlex Zinenko typeConverter->convertType(sourceMemRefType.getElementType());
164975e5f0aaSAlex Zinenko
165075e5f0aaSAlex Zinenko auto viewMemRefType = subViewOp.getType();
1651136d746eSJacques Pienaar auto inferredType =
1652136d746eSJacques Pienaar memref::SubViewOp::inferResultType(
165375e5f0aaSAlex Zinenko subViewOp.getSourceType(),
1654136d746eSJacques Pienaar extractFromI64ArrayAttr(subViewOp.getStaticOffsets()),
1655136d746eSJacques Pienaar extractFromI64ArrayAttr(subViewOp.getStaticSizes()),
1656136d746eSJacques Pienaar extractFromI64ArrayAttr(subViewOp.getStaticStrides()))
165775e5f0aaSAlex Zinenko .cast<MemRefType>();
165875e5f0aaSAlex Zinenko auto targetElementTy =
165975e5f0aaSAlex Zinenko typeConverter->convertType(viewMemRefType.getElementType());
166075e5f0aaSAlex Zinenko auto targetDescTy = typeConverter->convertType(viewMemRefType);
166175e5f0aaSAlex Zinenko if (!sourceElementTy || !targetDescTy || !targetElementTy ||
166275e5f0aaSAlex Zinenko !LLVM::isCompatibleType(sourceElementTy) ||
166375e5f0aaSAlex Zinenko !LLVM::isCompatibleType(targetElementTy) ||
166475e5f0aaSAlex Zinenko !LLVM::isCompatibleType(targetDescTy))
166575e5f0aaSAlex Zinenko return failure();
166675e5f0aaSAlex Zinenko
166775e5f0aaSAlex Zinenko // Extract the offset and strides from the type.
166875e5f0aaSAlex Zinenko int64_t offset;
166975e5f0aaSAlex Zinenko SmallVector<int64_t, 4> strides;
167075e5f0aaSAlex Zinenko auto successStrides = getStridesAndOffset(inferredType, strides, offset);
167175e5f0aaSAlex Zinenko if (failed(successStrides))
167275e5f0aaSAlex Zinenko return failure();
167375e5f0aaSAlex Zinenko
167475e5f0aaSAlex Zinenko // Create the descriptor.
1675ef976337SRiver Riddle if (!LLVM::isCompatibleType(adaptor.getOperands().front().getType()))
167675e5f0aaSAlex Zinenko return failure();
1677ef976337SRiver Riddle MemRefDescriptor sourceMemRef(adaptor.getOperands().front());
167875e5f0aaSAlex Zinenko auto targetMemRef = MemRefDescriptor::undef(rewriter, loc, targetDescTy);
167975e5f0aaSAlex Zinenko
168075e5f0aaSAlex Zinenko // Copy the buffer pointer from the old descriptor to the new one.
168175e5f0aaSAlex Zinenko Value extracted = sourceMemRef.allocatedPtr(rewriter, loc);
168275e5f0aaSAlex Zinenko Value bitcastPtr = rewriter.create<LLVM::BitcastOp>(
168375e5f0aaSAlex Zinenko loc,
168475e5f0aaSAlex Zinenko LLVM::LLVMPointerType::get(targetElementTy,
168575e5f0aaSAlex Zinenko viewMemRefType.getMemorySpaceAsInt()),
168675e5f0aaSAlex Zinenko extracted);
168775e5f0aaSAlex Zinenko targetMemRef.setAllocatedPtr(rewriter, loc, bitcastPtr);
168875e5f0aaSAlex Zinenko
168975e5f0aaSAlex Zinenko // Copy the aligned pointer from the old descriptor to the new one.
169075e5f0aaSAlex Zinenko extracted = sourceMemRef.alignedPtr(rewriter, loc);
169175e5f0aaSAlex Zinenko bitcastPtr = rewriter.create<LLVM::BitcastOp>(
169275e5f0aaSAlex Zinenko loc,
169375e5f0aaSAlex Zinenko LLVM::LLVMPointerType::get(targetElementTy,
169475e5f0aaSAlex Zinenko viewMemRefType.getMemorySpaceAsInt()),
169575e5f0aaSAlex Zinenko extracted);
169675e5f0aaSAlex Zinenko targetMemRef.setAlignedPtr(rewriter, loc, bitcastPtr);
169775e5f0aaSAlex Zinenko
16984cf9bf6cSMaheshRavishankar size_t inferredShapeRank = inferredType.getRank();
16994cf9bf6cSMaheshRavishankar size_t resultShapeRank = viewMemRefType.getRank();
170075e5f0aaSAlex Zinenko
170175e5f0aaSAlex Zinenko // Extract strides needed to compute offset.
170275e5f0aaSAlex Zinenko SmallVector<Value, 4> strideValues;
170375e5f0aaSAlex Zinenko strideValues.reserve(inferredShapeRank);
170475e5f0aaSAlex Zinenko for (unsigned i = 0; i < inferredShapeRank; ++i)
170575e5f0aaSAlex Zinenko strideValues.push_back(sourceMemRef.stride(rewriter, loc, i));
170675e5f0aaSAlex Zinenko
170775e5f0aaSAlex Zinenko // Offset.
170875e5f0aaSAlex Zinenko auto llvmIndexType = typeConverter->convertType(rewriter.getIndexType());
170975e5f0aaSAlex Zinenko if (!ShapedType::isDynamicStrideOrOffset(offset)) {
171075e5f0aaSAlex Zinenko targetMemRef.setConstantOffset(rewriter, loc, offset);
171175e5f0aaSAlex Zinenko } else {
171275e5f0aaSAlex Zinenko Value baseOffset = sourceMemRef.offset(rewriter, loc);
171375e5f0aaSAlex Zinenko // `inferredShapeRank` may be larger than the number of offset operands
171475e5f0aaSAlex Zinenko // because of trailing semantics. In this case, the offset is guaranteed
171575e5f0aaSAlex Zinenko // to be interpreted as 0 and we can just skip the extra dimensions.
171675e5f0aaSAlex Zinenko for (unsigned i = 0, e = std::min(inferredShapeRank,
171775e5f0aaSAlex Zinenko subViewOp.getMixedOffsets().size());
171875e5f0aaSAlex Zinenko i < e; ++i) {
171975e5f0aaSAlex Zinenko Value offset =
172075e5f0aaSAlex Zinenko // TODO: need OpFoldResult ODS adaptor to clean this up.
172175e5f0aaSAlex Zinenko subViewOp.isDynamicOffset(i)
1722ef976337SRiver Riddle ? adaptor.getOperands()[subViewOp.getIndexOfDynamicOffset(i)]
172375e5f0aaSAlex Zinenko : rewriter.create<LLVM::ConstantOp>(
172475e5f0aaSAlex Zinenko loc, llvmIndexType,
172575e5f0aaSAlex Zinenko rewriter.getI64IntegerAttr(subViewOp.getStaticOffset(i)));
172675e5f0aaSAlex Zinenko Value mul = rewriter.create<LLVM::MulOp>(loc, offset, strideValues[i]);
172775e5f0aaSAlex Zinenko baseOffset = rewriter.create<LLVM::AddOp>(loc, baseOffset, mul);
172875e5f0aaSAlex Zinenko }
172975e5f0aaSAlex Zinenko targetMemRef.setOffset(rewriter, loc, baseOffset);
173075e5f0aaSAlex Zinenko }
173175e5f0aaSAlex Zinenko
173275e5f0aaSAlex Zinenko // Update sizes and strides.
173375e5f0aaSAlex Zinenko SmallVector<OpFoldResult> mixedSizes = subViewOp.getMixedSizes();
173475e5f0aaSAlex Zinenko SmallVector<OpFoldResult> mixedStrides = subViewOp.getMixedStrides();
173575e5f0aaSAlex Zinenko assert(mixedSizes.size() == mixedStrides.size() &&
173675e5f0aaSAlex Zinenko "expected sizes and strides of equal length");
17376635c12aSBenjamin Kramer llvm::SmallBitVector unusedDims = subViewOp.getDroppedDims();
173875e5f0aaSAlex Zinenko for (int i = inferredShapeRank - 1, j = resultShapeRank - 1;
173975e5f0aaSAlex Zinenko i >= 0 && j >= 0; --i) {
17406635c12aSBenjamin Kramer if (unusedDims.test(i))
174175e5f0aaSAlex Zinenko continue;
174275e5f0aaSAlex Zinenko
174375e5f0aaSAlex Zinenko // `i` may overflow subViewOp.getMixedSizes because of trailing semantics.
174475e5f0aaSAlex Zinenko // In this case, the size is guaranteed to be interpreted as Dim and the
174575e5f0aaSAlex Zinenko // stride as 1.
174675e5f0aaSAlex Zinenko Value size, stride;
174775e5f0aaSAlex Zinenko if (static_cast<unsigned>(i) >= mixedSizes.size()) {
174875e5f0aaSAlex Zinenko // If the static size is available, use it directly. This is similar to
174975e5f0aaSAlex Zinenko // the folding of dim(constant-op) but removes the need for dim to be
175075e5f0aaSAlex Zinenko // aware of LLVM constants and for this pass to be aware of std
175175e5f0aaSAlex Zinenko // constants.
175275e5f0aaSAlex Zinenko int64_t staticSize =
1753136d746eSJacques Pienaar subViewOp.getSource().getType().cast<MemRefType>().getShape()[i];
175475e5f0aaSAlex Zinenko if (staticSize != ShapedType::kDynamicSize) {
175575e5f0aaSAlex Zinenko size = rewriter.create<LLVM::ConstantOp>(
175675e5f0aaSAlex Zinenko loc, llvmIndexType, rewriter.getI64IntegerAttr(staticSize));
175775e5f0aaSAlex Zinenko } else {
175875e5f0aaSAlex Zinenko Value pos = rewriter.create<LLVM::ConstantOp>(
175975e5f0aaSAlex Zinenko loc, llvmIndexType, rewriter.getI64IntegerAttr(i));
1760881dc34fSAlex Zinenko Value dim =
1761136d746eSJacques Pienaar rewriter.create<memref::DimOp>(loc, subViewOp.getSource(), pos);
1762881dc34fSAlex Zinenko auto cast = rewriter.create<UnrealizedConversionCastOp>(
1763881dc34fSAlex Zinenko loc, llvmIndexType, dim);
1764881dc34fSAlex Zinenko size = cast.getResult(0);
176575e5f0aaSAlex Zinenko }
176675e5f0aaSAlex Zinenko stride = rewriter.create<LLVM::ConstantOp>(
176775e5f0aaSAlex Zinenko loc, llvmIndexType, rewriter.getI64IntegerAttr(1));
176875e5f0aaSAlex Zinenko } else {
176975e5f0aaSAlex Zinenko // TODO: need OpFoldResult ODS adaptor to clean this up.
177075e5f0aaSAlex Zinenko size =
177175e5f0aaSAlex Zinenko subViewOp.isDynamicSize(i)
1772ef976337SRiver Riddle ? adaptor.getOperands()[subViewOp.getIndexOfDynamicSize(i)]
177375e5f0aaSAlex Zinenko : rewriter.create<LLVM::ConstantOp>(
177475e5f0aaSAlex Zinenko loc, llvmIndexType,
177575e5f0aaSAlex Zinenko rewriter.getI64IntegerAttr(subViewOp.getStaticSize(i)));
177675e5f0aaSAlex Zinenko if (!ShapedType::isDynamicStrideOrOffset(strides[i])) {
177775e5f0aaSAlex Zinenko stride = rewriter.create<LLVM::ConstantOp>(
177875e5f0aaSAlex Zinenko loc, llvmIndexType, rewriter.getI64IntegerAttr(strides[i]));
177975e5f0aaSAlex Zinenko } else {
1780ef976337SRiver Riddle stride =
1781ef976337SRiver Riddle subViewOp.isDynamicStride(i)
1782ef976337SRiver Riddle ? adaptor.getOperands()[subViewOp.getIndexOfDynamicStride(i)]
178375e5f0aaSAlex Zinenko : rewriter.create<LLVM::ConstantOp>(
178475e5f0aaSAlex Zinenko loc, llvmIndexType,
178575e5f0aaSAlex Zinenko rewriter.getI64IntegerAttr(
178675e5f0aaSAlex Zinenko subViewOp.getStaticStride(i)));
178775e5f0aaSAlex Zinenko stride = rewriter.create<LLVM::MulOp>(loc, stride, strideValues[i]);
178875e5f0aaSAlex Zinenko }
178975e5f0aaSAlex Zinenko }
179075e5f0aaSAlex Zinenko targetMemRef.setSize(rewriter, loc, j, size);
179175e5f0aaSAlex Zinenko targetMemRef.setStride(rewriter, loc, j, stride);
179275e5f0aaSAlex Zinenko j--;
179375e5f0aaSAlex Zinenko }
179475e5f0aaSAlex Zinenko
179575e5f0aaSAlex Zinenko rewriter.replaceOp(subViewOp, {targetMemRef});
179675e5f0aaSAlex Zinenko return success();
179775e5f0aaSAlex Zinenko }
179875e5f0aaSAlex Zinenko };
179975e5f0aaSAlex Zinenko
180075e5f0aaSAlex Zinenko /// Conversion pattern that transforms a transpose op into:
180175e5f0aaSAlex Zinenko /// 1. A function entry `alloca` operation to allocate a ViewDescriptor.
180275e5f0aaSAlex Zinenko /// 2. A load of the ViewDescriptor from the pointer allocated in 1.
180375e5f0aaSAlex Zinenko /// 3. Updates to the ViewDescriptor to introduce the data ptr, offset, size
180475e5f0aaSAlex Zinenko /// and stride. Size and stride are permutations of the original values.
180575e5f0aaSAlex Zinenko /// 4. A store of the resulting ViewDescriptor to the alloca'ed pointer.
180675e5f0aaSAlex Zinenko /// The transpose op is replaced by the alloca'ed pointer.
180775e5f0aaSAlex Zinenko class TransposeOpLowering : public ConvertOpToLLVMPattern<memref::TransposeOp> {
180875e5f0aaSAlex Zinenko public:
180975e5f0aaSAlex Zinenko using ConvertOpToLLVMPattern<memref::TransposeOp>::ConvertOpToLLVMPattern;
181075e5f0aaSAlex Zinenko
181175e5f0aaSAlex Zinenko LogicalResult
matchAndRewrite(memref::TransposeOp transposeOp,OpAdaptor adaptor,ConversionPatternRewriter & rewriter) const1812ef976337SRiver Riddle matchAndRewrite(memref::TransposeOp transposeOp, OpAdaptor adaptor,
181375e5f0aaSAlex Zinenko ConversionPatternRewriter &rewriter) const override {
181475e5f0aaSAlex Zinenko auto loc = transposeOp.getLoc();
1815136d746eSJacques Pienaar MemRefDescriptor viewMemRef(adaptor.getIn());
181675e5f0aaSAlex Zinenko
181775e5f0aaSAlex Zinenko // No permutation, early exit.
1818136d746eSJacques Pienaar if (transposeOp.getPermutation().isIdentity())
181975e5f0aaSAlex Zinenko return rewriter.replaceOp(transposeOp, {viewMemRef}), success();
182075e5f0aaSAlex Zinenko
182175e5f0aaSAlex Zinenko auto targetMemRef = MemRefDescriptor::undef(
182275e5f0aaSAlex Zinenko rewriter, loc, typeConverter->convertType(transposeOp.getShapedType()));
182375e5f0aaSAlex Zinenko
182475e5f0aaSAlex Zinenko // Copy the base and aligned pointers from the old descriptor to the new
182575e5f0aaSAlex Zinenko // one.
182675e5f0aaSAlex Zinenko targetMemRef.setAllocatedPtr(rewriter, loc,
182775e5f0aaSAlex Zinenko viewMemRef.allocatedPtr(rewriter, loc));
182875e5f0aaSAlex Zinenko targetMemRef.setAlignedPtr(rewriter, loc,
182975e5f0aaSAlex Zinenko viewMemRef.alignedPtr(rewriter, loc));
183075e5f0aaSAlex Zinenko
183175e5f0aaSAlex Zinenko // Copy the offset pointer from the old descriptor to the new one.
183275e5f0aaSAlex Zinenko targetMemRef.setOffset(rewriter, loc, viewMemRef.offset(rewriter, loc));
183375e5f0aaSAlex Zinenko
183475e5f0aaSAlex Zinenko // Iterate over the dimensions and apply size/stride permutation.
1835e4853be2SMehdi Amini for (const auto &en :
1836136d746eSJacques Pienaar llvm::enumerate(transposeOp.getPermutation().getResults())) {
183775e5f0aaSAlex Zinenko int sourcePos = en.index();
183875e5f0aaSAlex Zinenko int targetPos = en.value().cast<AffineDimExpr>().getPosition();
183975e5f0aaSAlex Zinenko targetMemRef.setSize(rewriter, loc, targetPos,
184075e5f0aaSAlex Zinenko viewMemRef.size(rewriter, loc, sourcePos));
184175e5f0aaSAlex Zinenko targetMemRef.setStride(rewriter, loc, targetPos,
184275e5f0aaSAlex Zinenko viewMemRef.stride(rewriter, loc, sourcePos));
184375e5f0aaSAlex Zinenko }
184475e5f0aaSAlex Zinenko
184575e5f0aaSAlex Zinenko rewriter.replaceOp(transposeOp, {targetMemRef});
184675e5f0aaSAlex Zinenko return success();
184775e5f0aaSAlex Zinenko }
184875e5f0aaSAlex Zinenko };
184975e5f0aaSAlex Zinenko
185075e5f0aaSAlex Zinenko /// Conversion pattern that transforms an op into:
185175e5f0aaSAlex Zinenko /// 1. An `llvm.mlir.undef` operation to create a memref descriptor
185275e5f0aaSAlex Zinenko /// 2. Updates to the descriptor to introduce the data ptr, offset, size
185375e5f0aaSAlex Zinenko /// and stride.
185475e5f0aaSAlex Zinenko /// The view op is replaced by the descriptor.
185575e5f0aaSAlex Zinenko struct ViewOpLowering : public ConvertOpToLLVMPattern<memref::ViewOp> {
185675e5f0aaSAlex Zinenko using ConvertOpToLLVMPattern<memref::ViewOp>::ConvertOpToLLVMPattern;
185775e5f0aaSAlex Zinenko
185875e5f0aaSAlex Zinenko // Build and return the value for the idx^th shape dimension, either by
185975e5f0aaSAlex Zinenko // returning the constant shape dimension or counting the proper dynamic size.
getSize__anon7a9e10510111::ViewOpLowering186075e5f0aaSAlex Zinenko Value getSize(ConversionPatternRewriter &rewriter, Location loc,
186175e5f0aaSAlex Zinenko ArrayRef<int64_t> shape, ValueRange dynamicSizes,
186275e5f0aaSAlex Zinenko unsigned idx) const {
186375e5f0aaSAlex Zinenko assert(idx < shape.size());
186475e5f0aaSAlex Zinenko if (!ShapedType::isDynamic(shape[idx]))
186575e5f0aaSAlex Zinenko return createIndexConstant(rewriter, loc, shape[idx]);
186675e5f0aaSAlex Zinenko // Count the number of dynamic dims in range [0, idx]
1867380a1b20SKazu Hirata unsigned nDynamic =
1868380a1b20SKazu Hirata llvm::count_if(shape.take_front(idx), ShapedType::isDynamic);
186975e5f0aaSAlex Zinenko return dynamicSizes[nDynamic];
187075e5f0aaSAlex Zinenko }
187175e5f0aaSAlex Zinenko
187275e5f0aaSAlex Zinenko // Build and return the idx^th stride, either by returning the constant stride
187375e5f0aaSAlex Zinenko // or by computing the dynamic stride from the current `runningStride` and
187475e5f0aaSAlex Zinenko // `nextSize`. The caller should keep a running stride and update it with the
187575e5f0aaSAlex Zinenko // result returned by this function.
getStride__anon7a9e10510111::ViewOpLowering187675e5f0aaSAlex Zinenko Value getStride(ConversionPatternRewriter &rewriter, Location loc,
187775e5f0aaSAlex Zinenko ArrayRef<int64_t> strides, Value nextSize,
187875e5f0aaSAlex Zinenko Value runningStride, unsigned idx) const {
187975e5f0aaSAlex Zinenko assert(idx < strides.size());
1880676bfb2aSRiver Riddle if (!ShapedType::isDynamicStrideOrOffset(strides[idx]))
188175e5f0aaSAlex Zinenko return createIndexConstant(rewriter, loc, strides[idx]);
188275e5f0aaSAlex Zinenko if (nextSize)
188375e5f0aaSAlex Zinenko return runningStride
188475e5f0aaSAlex Zinenko ? rewriter.create<LLVM::MulOp>(loc, runningStride, nextSize)
188575e5f0aaSAlex Zinenko : nextSize;
188675e5f0aaSAlex Zinenko assert(!runningStride);
188775e5f0aaSAlex Zinenko return createIndexConstant(rewriter, loc, 1);
188875e5f0aaSAlex Zinenko }
188975e5f0aaSAlex Zinenko
189075e5f0aaSAlex Zinenko LogicalResult
matchAndRewrite__anon7a9e10510111::ViewOpLowering1891ef976337SRiver Riddle matchAndRewrite(memref::ViewOp viewOp, OpAdaptor adaptor,
189275e5f0aaSAlex Zinenko ConversionPatternRewriter &rewriter) const override {
189375e5f0aaSAlex Zinenko auto loc = viewOp.getLoc();
189475e5f0aaSAlex Zinenko
189575e5f0aaSAlex Zinenko auto viewMemRefType = viewOp.getType();
189675e5f0aaSAlex Zinenko auto targetElementTy =
189775e5f0aaSAlex Zinenko typeConverter->convertType(viewMemRefType.getElementType());
189875e5f0aaSAlex Zinenko auto targetDescTy = typeConverter->convertType(viewMemRefType);
189975e5f0aaSAlex Zinenko if (!targetDescTy || !targetElementTy ||
190075e5f0aaSAlex Zinenko !LLVM::isCompatibleType(targetElementTy) ||
190175e5f0aaSAlex Zinenko !LLVM::isCompatibleType(targetDescTy))
190275e5f0aaSAlex Zinenko return viewOp.emitWarning("Target descriptor type not converted to LLVM"),
190375e5f0aaSAlex Zinenko failure();
190475e5f0aaSAlex Zinenko
190575e5f0aaSAlex Zinenko int64_t offset;
190675e5f0aaSAlex Zinenko SmallVector<int64_t, 4> strides;
190775e5f0aaSAlex Zinenko auto successStrides = getStridesAndOffset(viewMemRefType, strides, offset);
190875e5f0aaSAlex Zinenko if (failed(successStrides))
190975e5f0aaSAlex Zinenko return viewOp.emitWarning("cannot cast to non-strided shape"), failure();
191075e5f0aaSAlex Zinenko assert(offset == 0 && "expected offset to be 0");
191175e5f0aaSAlex Zinenko
1912705f048cSEugene Zhulenev // Target memref must be contiguous in memory (innermost stride is 1), or
1913705f048cSEugene Zhulenev // empty (special case when at least one of the memref dimensions is 0).
1914705f048cSEugene Zhulenev if (!strides.empty() && (strides.back() != 1 && strides.back() != 0))
1915705f048cSEugene Zhulenev return viewOp.emitWarning("cannot cast to non-contiguous shape"),
1916705f048cSEugene Zhulenev failure();
1917705f048cSEugene Zhulenev
191875e5f0aaSAlex Zinenko // Create the descriptor.
1919136d746eSJacques Pienaar MemRefDescriptor sourceMemRef(adaptor.getSource());
192075e5f0aaSAlex Zinenko auto targetMemRef = MemRefDescriptor::undef(rewriter, loc, targetDescTy);
192175e5f0aaSAlex Zinenko
192275e5f0aaSAlex Zinenko // Field 1: Copy the allocated pointer, used for malloc/free.
192375e5f0aaSAlex Zinenko Value allocatedPtr = sourceMemRef.allocatedPtr(rewriter, loc);
1924136d746eSJacques Pienaar auto srcMemRefType = viewOp.getSource().getType().cast<MemRefType>();
192575e5f0aaSAlex Zinenko Value bitcastPtr = rewriter.create<LLVM::BitcastOp>(
192675e5f0aaSAlex Zinenko loc,
192775e5f0aaSAlex Zinenko LLVM::LLVMPointerType::get(targetElementTy,
192875e5f0aaSAlex Zinenko srcMemRefType.getMemorySpaceAsInt()),
192975e5f0aaSAlex Zinenko allocatedPtr);
193075e5f0aaSAlex Zinenko targetMemRef.setAllocatedPtr(rewriter, loc, bitcastPtr);
193175e5f0aaSAlex Zinenko
193275e5f0aaSAlex Zinenko // Field 2: Copy the actual aligned pointer to payload.
193375e5f0aaSAlex Zinenko Value alignedPtr = sourceMemRef.alignedPtr(rewriter, loc);
1934136d746eSJacques Pienaar alignedPtr = rewriter.create<LLVM::GEPOp>(
1935136d746eSJacques Pienaar loc, alignedPtr.getType(), alignedPtr, adaptor.getByteShift());
193675e5f0aaSAlex Zinenko bitcastPtr = rewriter.create<LLVM::BitcastOp>(
193775e5f0aaSAlex Zinenko loc,
193875e5f0aaSAlex Zinenko LLVM::LLVMPointerType::get(targetElementTy,
193975e5f0aaSAlex Zinenko srcMemRefType.getMemorySpaceAsInt()),
194075e5f0aaSAlex Zinenko alignedPtr);
194175e5f0aaSAlex Zinenko targetMemRef.setAlignedPtr(rewriter, loc, bitcastPtr);
194275e5f0aaSAlex Zinenko
194375e5f0aaSAlex Zinenko // Field 3: The offset in the resulting type must be 0. This is because of
194475e5f0aaSAlex Zinenko // the type change: an offset on srcType* may not be expressible as an
194575e5f0aaSAlex Zinenko // offset on dstType*.
194675e5f0aaSAlex Zinenko targetMemRef.setOffset(rewriter, loc,
194775e5f0aaSAlex Zinenko createIndexConstant(rewriter, loc, offset));
194875e5f0aaSAlex Zinenko
194975e5f0aaSAlex Zinenko // Early exit for 0-D corner case.
195075e5f0aaSAlex Zinenko if (viewMemRefType.getRank() == 0)
195175e5f0aaSAlex Zinenko return rewriter.replaceOp(viewOp, {targetMemRef}), success();
195275e5f0aaSAlex Zinenko
195375e5f0aaSAlex Zinenko // Fields 4 and 5: Update sizes and strides.
195475e5f0aaSAlex Zinenko Value stride = nullptr, nextSize = nullptr;
195575e5f0aaSAlex Zinenko for (int i = viewMemRefType.getRank() - 1; i >= 0; --i) {
195675e5f0aaSAlex Zinenko // Update size.
1957136d746eSJacques Pienaar Value size = getSize(rewriter, loc, viewMemRefType.getShape(),
1958136d746eSJacques Pienaar adaptor.getSizes(), i);
195975e5f0aaSAlex Zinenko targetMemRef.setSize(rewriter, loc, i, size);
196075e5f0aaSAlex Zinenko // Update stride.
196175e5f0aaSAlex Zinenko stride = getStride(rewriter, loc, strides, nextSize, stride, i);
196275e5f0aaSAlex Zinenko targetMemRef.setStride(rewriter, loc, i, stride);
196375e5f0aaSAlex Zinenko nextSize = size;
196475e5f0aaSAlex Zinenko }
196575e5f0aaSAlex Zinenko
196675e5f0aaSAlex Zinenko rewriter.replaceOp(viewOp, {targetMemRef});
196775e5f0aaSAlex Zinenko return success();
196875e5f0aaSAlex Zinenko }
196975e5f0aaSAlex Zinenko };
197075e5f0aaSAlex Zinenko
1971a6a583daSWilliam S. Moses //===----------------------------------------------------------------------===//
1972a6a583daSWilliam S. Moses // AtomicRMWOpLowering
1973a6a583daSWilliam S. Moses //===----------------------------------------------------------------------===//
1974a6a583daSWilliam S. Moses
197523aa5a74SRiver Riddle /// Try to match the kind of a memref.atomic_rmw to determine whether to use a
1976a6a583daSWilliam S. Moses /// lowering to llvm.atomicrmw or fallback to llvm.cmpxchg.
1977a6a583daSWilliam S. Moses static Optional<LLVM::AtomicBinOp>
matchSimpleAtomicOp(memref::AtomicRMWOp atomicOp)1978a6a583daSWilliam S. Moses matchSimpleAtomicOp(memref::AtomicRMWOp atomicOp) {
1979136d746eSJacques Pienaar switch (atomicOp.getKind()) {
1980a6a583daSWilliam S. Moses case arith::AtomicRMWKind::addf:
1981a6a583daSWilliam S. Moses return LLVM::AtomicBinOp::fadd;
1982a6a583daSWilliam S. Moses case arith::AtomicRMWKind::addi:
1983a6a583daSWilliam S. Moses return LLVM::AtomicBinOp::add;
1984a6a583daSWilliam S. Moses case arith::AtomicRMWKind::assign:
1985a6a583daSWilliam S. Moses return LLVM::AtomicBinOp::xchg;
1986a6a583daSWilliam S. Moses case arith::AtomicRMWKind::maxs:
1987a6a583daSWilliam S. Moses return LLVM::AtomicBinOp::max;
1988a6a583daSWilliam S. Moses case arith::AtomicRMWKind::maxu:
1989a6a583daSWilliam S. Moses return LLVM::AtomicBinOp::umax;
1990a6a583daSWilliam S. Moses case arith::AtomicRMWKind::mins:
1991a6a583daSWilliam S. Moses return LLVM::AtomicBinOp::min;
1992a6a583daSWilliam S. Moses case arith::AtomicRMWKind::minu:
1993a6a583daSWilliam S. Moses return LLVM::AtomicBinOp::umin;
1994a6a583daSWilliam S. Moses case arith::AtomicRMWKind::ori:
1995a6a583daSWilliam S. Moses return LLVM::AtomicBinOp::_or;
1996a6a583daSWilliam S. Moses case arith::AtomicRMWKind::andi:
1997a6a583daSWilliam S. Moses return LLVM::AtomicBinOp::_and;
1998a6a583daSWilliam S. Moses default:
1999a6a583daSWilliam S. Moses return llvm::None;
2000a6a583daSWilliam S. Moses }
2001a6a583daSWilliam S. Moses llvm_unreachable("Invalid AtomicRMWKind");
2002a6a583daSWilliam S. Moses }
2003a6a583daSWilliam S. Moses
2004a6a583daSWilliam S. Moses struct AtomicRMWOpLowering : public LoadStoreOpLowering<memref::AtomicRMWOp> {
2005a6a583daSWilliam S. Moses using Base::Base;
2006a6a583daSWilliam S. Moses
2007a6a583daSWilliam S. Moses LogicalResult
matchAndRewrite__anon7a9e10510111::AtomicRMWOpLowering2008a6a583daSWilliam S. Moses matchAndRewrite(memref::AtomicRMWOp atomicOp, OpAdaptor adaptor,
2009a6a583daSWilliam S. Moses ConversionPatternRewriter &rewriter) const override {
2010a6a583daSWilliam S. Moses if (failed(match(atomicOp)))
2011a6a583daSWilliam S. Moses return failure();
2012a6a583daSWilliam S. Moses auto maybeKind = matchSimpleAtomicOp(atomicOp);
2013a6a583daSWilliam S. Moses if (!maybeKind)
2014a6a583daSWilliam S. Moses return failure();
2015136d746eSJacques Pienaar auto resultType = adaptor.getValue().getType();
2016a6a583daSWilliam S. Moses auto memRefType = atomicOp.getMemRefType();
2017a6a583daSWilliam S. Moses auto dataPtr =
2018136d746eSJacques Pienaar getStridedElementPtr(atomicOp.getLoc(), memRefType, adaptor.getMemref(),
2019136d746eSJacques Pienaar adaptor.getIndices(), rewriter);
2020a6a583daSWilliam S. Moses rewriter.replaceOpWithNewOp<LLVM::AtomicRMWOp>(
2021136d746eSJacques Pienaar atomicOp, resultType, *maybeKind, dataPtr, adaptor.getValue(),
2022a6a583daSWilliam S. Moses LLVM::AtomicOrdering::acq_rel);
2023a6a583daSWilliam S. Moses return success();
2024a6a583daSWilliam S. Moses }
2025a6a583daSWilliam S. Moses };
2026a6a583daSWilliam S. Moses
202775e5f0aaSAlex Zinenko } // namespace
202875e5f0aaSAlex Zinenko
populateMemRefToLLVMConversionPatterns(LLVMTypeConverter & converter,RewritePatternSet & patterns)202975e5f0aaSAlex Zinenko void mlir::populateMemRefToLLVMConversionPatterns(LLVMTypeConverter &converter,
203075e5f0aaSAlex Zinenko RewritePatternSet &patterns) {
203175e5f0aaSAlex Zinenko // clang-format off
203275e5f0aaSAlex Zinenko patterns.add<
203375e5f0aaSAlex Zinenko AllocaOpLowering,
203475e5f0aaSAlex Zinenko AllocaScopeOpLowering,
2035a6a583daSWilliam S. Moses AtomicRMWOpLowering,
203675e5f0aaSAlex Zinenko AssumeAlignmentOpLowering,
203775e5f0aaSAlex Zinenko DimOpLowering,
2038632a4f88SRiver Riddle GenericAtomicRMWOpLowering,
203975e5f0aaSAlex Zinenko GlobalMemrefOpLowering,
204075e5f0aaSAlex Zinenko GetGlobalMemrefOpLowering,
204175e5f0aaSAlex Zinenko LoadOpLowering,
204275e5f0aaSAlex Zinenko MemRefCastOpLowering,
204375e5f0aaSAlex Zinenko MemRefCopyOpLowering,
204475e5f0aaSAlex Zinenko MemRefReinterpretCastOpLowering,
204575e5f0aaSAlex Zinenko MemRefReshapeOpLowering,
204675e5f0aaSAlex Zinenko PrefetchOpLowering,
204715f8f3e2SAlexander Belyaev RankOpLowering,
204846ef86b5SAlexander Belyaev ReassociatingReshapeOpConversion<memref::ExpandShapeOp>,
204946ef86b5SAlexander Belyaev ReassociatingReshapeOpConversion<memref::CollapseShapeOp>,
205075e5f0aaSAlex Zinenko StoreOpLowering,
205175e5f0aaSAlex Zinenko SubViewOpLowering,
205275e5f0aaSAlex Zinenko TransposeOpLowering,
205375e5f0aaSAlex Zinenko ViewOpLowering>(converter);
205475e5f0aaSAlex Zinenko // clang-format on
205575e5f0aaSAlex Zinenko auto allocLowering = converter.getOptions().allocLowering;
205675e5f0aaSAlex Zinenko if (allocLowering == LowerToLLVMOptions::AllocLowering::AlignedAlloc)
2057d04c2b2fSMehdi Amini patterns.add<AlignedAllocOpLowering, DeallocOpLowering>(converter);
205875e5f0aaSAlex Zinenko else if (allocLowering == LowerToLLVMOptions::AllocLowering::Malloc)
205975e5f0aaSAlex Zinenko patterns.add<AllocOpLowering, DeallocOpLowering>(converter);
206075e5f0aaSAlex Zinenko }
206175e5f0aaSAlex Zinenko
206275e5f0aaSAlex Zinenko namespace {
206375e5f0aaSAlex Zinenko struct MemRefToLLVMPass : public ConvertMemRefToLLVMBase<MemRefToLLVMPass> {
206475e5f0aaSAlex Zinenko MemRefToLLVMPass() = default;
206575e5f0aaSAlex Zinenko
runOnOperation__anon7a9e10510811::MemRefToLLVMPass206675e5f0aaSAlex Zinenko void runOnOperation() override {
206775e5f0aaSAlex Zinenko Operation *op = getOperation();
206875e5f0aaSAlex Zinenko const auto &dataLayoutAnalysis = getAnalysis<DataLayoutAnalysis>();
206975e5f0aaSAlex Zinenko LowerToLLVMOptions options(&getContext(),
207075e5f0aaSAlex Zinenko dataLayoutAnalysis.getAtOrAbove(op));
207175e5f0aaSAlex Zinenko options.allocLowering =
207275e5f0aaSAlex Zinenko (useAlignedAlloc ? LowerToLLVMOptions::AllocLowering::AlignedAlloc
207375e5f0aaSAlex Zinenko : LowerToLLVMOptions::AllocLowering::Malloc);
2074a8601f11SMichele Scuttari
2075a8601f11SMichele Scuttari options.useGenericFunctions = useGenericFunctions;
2076a8601f11SMichele Scuttari
207775e5f0aaSAlex Zinenko if (indexBitwidth != kDeriveIndexBitwidthFromDataLayout)
207875e5f0aaSAlex Zinenko options.overrideIndexBitwidth(indexBitwidth);
207975e5f0aaSAlex Zinenko
208075e5f0aaSAlex Zinenko LLVMTypeConverter typeConverter(&getContext(), options,
208175e5f0aaSAlex Zinenko &dataLayoutAnalysis);
208275e5f0aaSAlex Zinenko RewritePatternSet patterns(&getContext());
208375e5f0aaSAlex Zinenko populateMemRefToLLVMConversionPatterns(typeConverter, patterns);
208475e5f0aaSAlex Zinenko LLVMConversionTarget target(getContext());
208558ceae95SRiver Riddle target.addLegalOp<func::FuncOp>();
208675e5f0aaSAlex Zinenko if (failed(applyPartialConversion(op, target, std::move(patterns))))
208775e5f0aaSAlex Zinenko signalPassFailure();
208875e5f0aaSAlex Zinenko }
208975e5f0aaSAlex Zinenko };
209075e5f0aaSAlex Zinenko } // namespace
209175e5f0aaSAlex Zinenko
createMemRefToLLVMPass()209275e5f0aaSAlex Zinenko std::unique_ptr<Pass> mlir::createMemRefToLLVMPass() {
209375e5f0aaSAlex Zinenko return std::make_unique<MemRefToLLVMPass>();
209475e5f0aaSAlex Zinenko }
2095