175e5f0aaSAlex Zinenko //===- MemRefToLLVM.cpp - MemRef to LLVM dialect conversion ---------------===//
275e5f0aaSAlex Zinenko //
375e5f0aaSAlex Zinenko // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
475e5f0aaSAlex Zinenko // See https://llvm.org/LICENSE.txt for license information.
575e5f0aaSAlex Zinenko // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
675e5f0aaSAlex Zinenko //
775e5f0aaSAlex Zinenko //===----------------------------------------------------------------------===//
875e5f0aaSAlex Zinenko 
975e5f0aaSAlex Zinenko #include "mlir/Conversion/MemRefToLLVM/MemRefToLLVM.h"
1075e5f0aaSAlex Zinenko #include "../PassDetail.h"
1175e5f0aaSAlex Zinenko #include "mlir/Analysis/DataLayoutAnalysis.h"
1275e5f0aaSAlex Zinenko #include "mlir/Conversion/LLVMCommon/ConversionTarget.h"
1375e5f0aaSAlex Zinenko #include "mlir/Conversion/LLVMCommon/Pattern.h"
1475e5f0aaSAlex Zinenko #include "mlir/Conversion/LLVMCommon/TypeConverter.h"
1575e5f0aaSAlex Zinenko #include "mlir/Conversion/MemRefToLLVM/AllocLikeConversion.h"
16eda6f907SRiver Riddle #include "mlir/Dialect/Arithmetic/IR/Arithmetic.h"
1736550692SRiver Riddle #include "mlir/Dialect/Func/IR/FuncOps.h"
1875e5f0aaSAlex Zinenko #include "mlir/Dialect/LLVMIR/FunctionCallUtils.h"
1975e5f0aaSAlex Zinenko #include "mlir/Dialect/LLVMIR/LLVMDialect.h"
2075e5f0aaSAlex Zinenko #include "mlir/Dialect/MemRef/IR/MemRef.h"
2175e5f0aaSAlex Zinenko #include "mlir/IR/AffineMap.h"
2275e5f0aaSAlex Zinenko #include "mlir/IR/BlockAndValueMapping.h"
236635c12aSBenjamin Kramer #include "llvm/ADT/SmallBitVector.h"
2475e5f0aaSAlex Zinenko 
2575e5f0aaSAlex Zinenko using namespace mlir;
2675e5f0aaSAlex Zinenko 
2775e5f0aaSAlex Zinenko namespace {
2875e5f0aaSAlex Zinenko 
isStaticStrideOrOffset(int64_t strideOrOffset)295380e30eSAshay Rane bool isStaticStrideOrOffset(int64_t strideOrOffset) {
305380e30eSAshay Rane   return !ShapedType::isDynamicStrideOrOffset(strideOrOffset);
315380e30eSAshay Rane }
325380e30eSAshay Rane 
3375e5f0aaSAlex Zinenko struct AllocOpLowering : public AllocLikeOpLLVMLowering {
AllocOpLowering__anon7a9e10510111::AllocOpLowering3475e5f0aaSAlex Zinenko   AllocOpLowering(LLVMTypeConverter &converter)
3575e5f0aaSAlex Zinenko       : AllocLikeOpLLVMLowering(memref::AllocOp::getOperationName(),
3675e5f0aaSAlex Zinenko                                 converter) {}
3775e5f0aaSAlex Zinenko 
getAllocFn__anon7a9e10510111::AllocOpLowering38a8601f11SMichele Scuttari   LLVM::LLVMFuncOp getAllocFn(ModuleOp module) const {
39a8601f11SMichele Scuttari     bool useGenericFn = getTypeConverter()->getOptions().useGenericFunctions;
40a8601f11SMichele Scuttari 
41a8601f11SMichele Scuttari     if (useGenericFn)
42a8601f11SMichele Scuttari       return LLVM::lookupOrCreateGenericAllocFn(module, getIndexType());
43a8601f11SMichele Scuttari 
44a8601f11SMichele Scuttari     return LLVM::lookupOrCreateMallocFn(module, getIndexType());
45a8601f11SMichele Scuttari   }
46a8601f11SMichele Scuttari 
allocateBuffer__anon7a9e10510111::AllocOpLowering4775e5f0aaSAlex Zinenko   std::tuple<Value, Value> allocateBuffer(ConversionPatternRewriter &rewriter,
4875e5f0aaSAlex Zinenko                                           Location loc, Value sizeBytes,
4975e5f0aaSAlex Zinenko                                           Operation *op) const override {
5075e5f0aaSAlex Zinenko     // Heap allocations.
5175e5f0aaSAlex Zinenko     memref::AllocOp allocOp = cast<memref::AllocOp>(op);
5275e5f0aaSAlex Zinenko     MemRefType memRefType = allocOp.getType();
5375e5f0aaSAlex Zinenko 
5475e5f0aaSAlex Zinenko     Value alignment;
55136d746eSJacques Pienaar     if (auto alignmentAttr = allocOp.getAlignment()) {
5675e5f0aaSAlex Zinenko       alignment = createIndexConstant(rewriter, loc, *alignmentAttr);
5775e5f0aaSAlex Zinenko     } else if (!memRefType.getElementType().isSignlessIntOrIndexOrFloat()) {
5875e5f0aaSAlex Zinenko       // In the case where no alignment is specified, we may want to override
5975e5f0aaSAlex Zinenko       // `malloc's` behavior. `malloc` typically aligns at the size of the
6075e5f0aaSAlex Zinenko       // biggest scalar on a target HW. For non-scalars, use the natural
6175e5f0aaSAlex Zinenko       // alignment of the LLVM type given by the LLVM DataLayout.
6275e5f0aaSAlex Zinenko       alignment = getSizeInBytes(loc, memRefType.getElementType(), rewriter);
6375e5f0aaSAlex Zinenko     }
6475e5f0aaSAlex Zinenko 
6575e5f0aaSAlex Zinenko     if (alignment) {
6675e5f0aaSAlex Zinenko       // Adjust the allocation size to consider alignment.
6775e5f0aaSAlex Zinenko       sizeBytes = rewriter.create<LLVM::AddOp>(loc, sizeBytes, alignment);
6875e5f0aaSAlex Zinenko     }
6975e5f0aaSAlex Zinenko 
7075e5f0aaSAlex Zinenko     // Allocate the underlying buffer and store a pointer to it in the MemRef
7175e5f0aaSAlex Zinenko     // descriptor.
7275e5f0aaSAlex Zinenko     Type elementPtrType = this->getElementPtrType(memRefType);
73a8601f11SMichele Scuttari     auto allocFuncOp = getAllocFn(allocOp->getParentOfType<ModuleOp>());
7475e5f0aaSAlex Zinenko     auto results = createLLVMCall(rewriter, loc, allocFuncOp, {sizeBytes},
7575e5f0aaSAlex Zinenko                                   getVoidPtrType());
7675e5f0aaSAlex Zinenko     Value allocatedPtr =
7775e5f0aaSAlex Zinenko         rewriter.create<LLVM::BitcastOp>(loc, elementPtrType, results[0]);
7875e5f0aaSAlex Zinenko 
7975e5f0aaSAlex Zinenko     Value alignedPtr = allocatedPtr;
8075e5f0aaSAlex Zinenko     if (alignment) {
8175e5f0aaSAlex Zinenko       // Compute the aligned type pointer.
8275e5f0aaSAlex Zinenko       Value allocatedInt =
8375e5f0aaSAlex Zinenko           rewriter.create<LLVM::PtrToIntOp>(loc, getIndexType(), allocatedPtr);
8475e5f0aaSAlex Zinenko       Value alignmentInt =
8575e5f0aaSAlex Zinenko           createAligned(rewriter, loc, allocatedInt, alignment);
8675e5f0aaSAlex Zinenko       alignedPtr =
8775e5f0aaSAlex Zinenko           rewriter.create<LLVM::IntToPtrOp>(loc, elementPtrType, alignmentInt);
8875e5f0aaSAlex Zinenko     }
8975e5f0aaSAlex Zinenko 
9075e5f0aaSAlex Zinenko     return std::make_tuple(allocatedPtr, alignedPtr);
9175e5f0aaSAlex Zinenko   }
9275e5f0aaSAlex Zinenko };
9375e5f0aaSAlex Zinenko 
9475e5f0aaSAlex Zinenko struct AlignedAllocOpLowering : public AllocLikeOpLLVMLowering {
AlignedAllocOpLowering__anon7a9e10510111::AlignedAllocOpLowering9575e5f0aaSAlex Zinenko   AlignedAllocOpLowering(LLVMTypeConverter &converter)
9675e5f0aaSAlex Zinenko       : AllocLikeOpLLVMLowering(memref::AllocOp::getOperationName(),
9775e5f0aaSAlex Zinenko                                 converter) {}
9875e5f0aaSAlex Zinenko 
9975e5f0aaSAlex Zinenko   /// Returns the memref's element size in bytes using the data layout active at
10075e5f0aaSAlex Zinenko   /// `op`.
10175e5f0aaSAlex Zinenko   // TODO: there are other places where this is used. Expose publicly?
getMemRefEltSizeInBytes__anon7a9e10510111::AlignedAllocOpLowering10275e5f0aaSAlex Zinenko   unsigned getMemRefEltSizeInBytes(MemRefType memRefType, Operation *op) const {
10375e5f0aaSAlex Zinenko     const DataLayout *layout = &defaultLayout;
10475e5f0aaSAlex Zinenko     if (const DataLayoutAnalysis *analysis =
10575e5f0aaSAlex Zinenko             getTypeConverter()->getDataLayoutAnalysis()) {
10675e5f0aaSAlex Zinenko       layout = &analysis->getAbove(op);
10775e5f0aaSAlex Zinenko     }
10875e5f0aaSAlex Zinenko     Type elementType = memRefType.getElementType();
10975e5f0aaSAlex Zinenko     if (auto memRefElementType = elementType.dyn_cast<MemRefType>())
11075e5f0aaSAlex Zinenko       return getTypeConverter()->getMemRefDescriptorSize(memRefElementType,
11175e5f0aaSAlex Zinenko                                                          *layout);
11275e5f0aaSAlex Zinenko     if (auto memRefElementType = elementType.dyn_cast<UnrankedMemRefType>())
11375e5f0aaSAlex Zinenko       return getTypeConverter()->getUnrankedMemRefDescriptorSize(
11475e5f0aaSAlex Zinenko           memRefElementType, *layout);
11575e5f0aaSAlex Zinenko     return layout->getTypeSize(elementType);
11675e5f0aaSAlex Zinenko   }
11775e5f0aaSAlex Zinenko 
11875e5f0aaSAlex Zinenko   /// Returns true if the memref size in bytes is known to be a multiple of
11975e5f0aaSAlex Zinenko   /// factor assuming the data layout active at `op`.
isMemRefSizeMultipleOf__anon7a9e10510111::AlignedAllocOpLowering12075e5f0aaSAlex Zinenko   bool isMemRefSizeMultipleOf(MemRefType type, uint64_t factor,
12175e5f0aaSAlex Zinenko                               Operation *op) const {
12275e5f0aaSAlex Zinenko     uint64_t sizeDivisor = getMemRefEltSizeInBytes(type, op);
12375e5f0aaSAlex Zinenko     for (unsigned i = 0, e = type.getRank(); i < e; i++) {
124676bfb2aSRiver Riddle       if (ShapedType::isDynamic(type.getDimSize(i)))
12575e5f0aaSAlex Zinenko         continue;
12675e5f0aaSAlex Zinenko       sizeDivisor = sizeDivisor * type.getDimSize(i);
12775e5f0aaSAlex Zinenko     }
12875e5f0aaSAlex Zinenko     return sizeDivisor % factor == 0;
12975e5f0aaSAlex Zinenko   }
13075e5f0aaSAlex Zinenko 
13175e5f0aaSAlex Zinenko   /// Returns the alignment to be used for the allocation call itself.
13275e5f0aaSAlex Zinenko   /// aligned_alloc requires the allocation size to be a power of two, and the
13375e5f0aaSAlex Zinenko   /// allocation size to be a multiple of alignment,
getAllocationAlignment__anon7a9e10510111::AlignedAllocOpLowering13475e5f0aaSAlex Zinenko   int64_t getAllocationAlignment(memref::AllocOp allocOp) const {
135136d746eSJacques Pienaar     if (Optional<uint64_t> alignment = allocOp.getAlignment())
13675e5f0aaSAlex Zinenko       return *alignment;
13775e5f0aaSAlex Zinenko 
13875e5f0aaSAlex Zinenko     // Whenever we don't have alignment set, we will use an alignment
13975e5f0aaSAlex Zinenko     // consistent with the element type; since the allocation size has to be a
14075e5f0aaSAlex Zinenko     // power of two, we will bump to the next power of two if it already isn't.
14175e5f0aaSAlex Zinenko     auto eltSizeBytes = getMemRefEltSizeInBytes(allocOp.getType(), allocOp);
14275e5f0aaSAlex Zinenko     return std::max(kMinAlignedAllocAlignment,
14375e5f0aaSAlex Zinenko                     llvm::PowerOf2Ceil(eltSizeBytes));
14475e5f0aaSAlex Zinenko   }
14575e5f0aaSAlex Zinenko 
getAllocFn__anon7a9e10510111::AlignedAllocOpLowering146a8601f11SMichele Scuttari   LLVM::LLVMFuncOp getAllocFn(ModuleOp module) const {
147a8601f11SMichele Scuttari     bool useGenericFn = getTypeConverter()->getOptions().useGenericFunctions;
148a8601f11SMichele Scuttari 
149a8601f11SMichele Scuttari     if (useGenericFn)
150a8601f11SMichele Scuttari       return LLVM::lookupOrCreateGenericAlignedAllocFn(module, getIndexType());
151a8601f11SMichele Scuttari 
152a8601f11SMichele Scuttari     return LLVM::lookupOrCreateAlignedAllocFn(module, getIndexType());
153a8601f11SMichele Scuttari   }
154a8601f11SMichele Scuttari 
allocateBuffer__anon7a9e10510111::AlignedAllocOpLowering15575e5f0aaSAlex Zinenko   std::tuple<Value, Value> allocateBuffer(ConversionPatternRewriter &rewriter,
15675e5f0aaSAlex Zinenko                                           Location loc, Value sizeBytes,
15775e5f0aaSAlex Zinenko                                           Operation *op) const override {
15875e5f0aaSAlex Zinenko     // Heap allocations.
15975e5f0aaSAlex Zinenko     memref::AllocOp allocOp = cast<memref::AllocOp>(op);
16075e5f0aaSAlex Zinenko     MemRefType memRefType = allocOp.getType();
16175e5f0aaSAlex Zinenko     int64_t alignment = getAllocationAlignment(allocOp);
16275e5f0aaSAlex Zinenko     Value allocAlignment = createIndexConstant(rewriter, loc, alignment);
16375e5f0aaSAlex Zinenko 
16475e5f0aaSAlex Zinenko     // aligned_alloc requires size to be a multiple of alignment; we will pad
16575e5f0aaSAlex Zinenko     // the size to the next multiple if necessary.
16675e5f0aaSAlex Zinenko     if (!isMemRefSizeMultipleOf(memRefType, alignment, op))
16775e5f0aaSAlex Zinenko       sizeBytes = createAligned(rewriter, loc, sizeBytes, allocAlignment);
16875e5f0aaSAlex Zinenko 
16975e5f0aaSAlex Zinenko     Type elementPtrType = this->getElementPtrType(memRefType);
170a8601f11SMichele Scuttari     auto allocFuncOp = getAllocFn(allocOp->getParentOfType<ModuleOp>());
17175e5f0aaSAlex Zinenko     auto results =
17275e5f0aaSAlex Zinenko         createLLVMCall(rewriter, loc, allocFuncOp, {allocAlignment, sizeBytes},
17375e5f0aaSAlex Zinenko                        getVoidPtrType());
17475e5f0aaSAlex Zinenko     Value allocatedPtr =
17575e5f0aaSAlex Zinenko         rewriter.create<LLVM::BitcastOp>(loc, elementPtrType, results[0]);
17675e5f0aaSAlex Zinenko 
17775e5f0aaSAlex Zinenko     return std::make_tuple(allocatedPtr, allocatedPtr);
17875e5f0aaSAlex Zinenko   }
17975e5f0aaSAlex Zinenko 
18075e5f0aaSAlex Zinenko   /// The minimum alignment to use with aligned_alloc (has to be a power of 2).
18175e5f0aaSAlex Zinenko   static constexpr uint64_t kMinAlignedAllocAlignment = 16UL;
18275e5f0aaSAlex Zinenko 
18375e5f0aaSAlex Zinenko   /// Default layout to use in absence of the corresponding analysis.
18475e5f0aaSAlex Zinenko   DataLayout defaultLayout;
18575e5f0aaSAlex Zinenko };
18675e5f0aaSAlex Zinenko 
18775e5f0aaSAlex Zinenko // Out of line definition, required till C++17.
18875e5f0aaSAlex Zinenko constexpr uint64_t AlignedAllocOpLowering::kMinAlignedAllocAlignment;
18975e5f0aaSAlex Zinenko 
19075e5f0aaSAlex Zinenko struct AllocaOpLowering : public AllocLikeOpLLVMLowering {
AllocaOpLowering__anon7a9e10510111::AllocaOpLowering19175e5f0aaSAlex Zinenko   AllocaOpLowering(LLVMTypeConverter &converter)
19275e5f0aaSAlex Zinenko       : AllocLikeOpLLVMLowering(memref::AllocaOp::getOperationName(),
19375e5f0aaSAlex Zinenko                                 converter) {}
19475e5f0aaSAlex Zinenko 
19575e5f0aaSAlex Zinenko   /// Allocates the underlying buffer using the right call. `allocatedBytePtr`
19675e5f0aaSAlex Zinenko   /// is set to null for stack allocations. `accessAlignment` is set if
19775e5f0aaSAlex Zinenko   /// alignment is needed post allocation (for eg. in conjunction with malloc).
allocateBuffer__anon7a9e10510111::AllocaOpLowering19875e5f0aaSAlex Zinenko   std::tuple<Value, Value> allocateBuffer(ConversionPatternRewriter &rewriter,
19975e5f0aaSAlex Zinenko                                           Location loc, Value sizeBytes,
20075e5f0aaSAlex Zinenko                                           Operation *op) const override {
20175e5f0aaSAlex Zinenko 
20275e5f0aaSAlex Zinenko     // With alloca, one gets a pointer to the element type right away.
20375e5f0aaSAlex Zinenko     // For stack allocations.
20475e5f0aaSAlex Zinenko     auto allocaOp = cast<memref::AllocaOp>(op);
20575e5f0aaSAlex Zinenko     auto elementPtrType = this->getElementPtrType(allocaOp.getType());
20675e5f0aaSAlex Zinenko 
20775e5f0aaSAlex Zinenko     auto allocatedElementPtr = rewriter.create<LLVM::AllocaOp>(
208*2789c4f5SKazu Hirata         loc, elementPtrType, sizeBytes, allocaOp.getAlignment().value_or(0));
20975e5f0aaSAlex Zinenko 
21075e5f0aaSAlex Zinenko     return std::make_tuple(allocatedElementPtr, allocatedElementPtr);
21175e5f0aaSAlex Zinenko   }
21275e5f0aaSAlex Zinenko };
21375e5f0aaSAlex Zinenko 
21475e5f0aaSAlex Zinenko struct AllocaScopeOpLowering
21575e5f0aaSAlex Zinenko     : public ConvertOpToLLVMPattern<memref::AllocaScopeOp> {
21675e5f0aaSAlex Zinenko   using ConvertOpToLLVMPattern<memref::AllocaScopeOp>::ConvertOpToLLVMPattern;
21775e5f0aaSAlex Zinenko 
21875e5f0aaSAlex Zinenko   LogicalResult
matchAndRewrite__anon7a9e10510111::AllocaScopeOpLowering219ef976337SRiver Riddle   matchAndRewrite(memref::AllocaScopeOp allocaScopeOp, OpAdaptor adaptor,
22075e5f0aaSAlex Zinenko                   ConversionPatternRewriter &rewriter) const override {
22175e5f0aaSAlex Zinenko     OpBuilder::InsertionGuard guard(rewriter);
22275e5f0aaSAlex Zinenko     Location loc = allocaScopeOp.getLoc();
22375e5f0aaSAlex Zinenko 
22475e5f0aaSAlex Zinenko     // Split the current block before the AllocaScopeOp to create the inlining
22575e5f0aaSAlex Zinenko     // point.
22675e5f0aaSAlex Zinenko     auto *currentBlock = rewriter.getInsertionBlock();
22775e5f0aaSAlex Zinenko     auto *remainingOpsBlock =
22875e5f0aaSAlex Zinenko         rewriter.splitBlock(currentBlock, rewriter.getInsertionPoint());
22975e5f0aaSAlex Zinenko     Block *continueBlock;
23075e5f0aaSAlex Zinenko     if (allocaScopeOp.getNumResults() == 0) {
23175e5f0aaSAlex Zinenko       continueBlock = remainingOpsBlock;
23275e5f0aaSAlex Zinenko     } else {
233e084679fSRiver Riddle       continueBlock = rewriter.createBlock(
234e084679fSRiver Riddle           remainingOpsBlock, allocaScopeOp.getResultTypes(),
235e084679fSRiver Riddle           SmallVector<Location>(allocaScopeOp->getNumResults(),
236e084679fSRiver Riddle                                 allocaScopeOp.getLoc()));
23775e5f0aaSAlex Zinenko       rewriter.create<LLVM::BrOp>(loc, ValueRange(), remainingOpsBlock);
23875e5f0aaSAlex Zinenko     }
23975e5f0aaSAlex Zinenko 
24075e5f0aaSAlex Zinenko     // Inline body region.
241136d746eSJacques Pienaar     Block *beforeBody = &allocaScopeOp.getBodyRegion().front();
242136d746eSJacques Pienaar     Block *afterBody = &allocaScopeOp.getBodyRegion().back();
243136d746eSJacques Pienaar     rewriter.inlineRegionBefore(allocaScopeOp.getBodyRegion(), continueBlock);
24475e5f0aaSAlex Zinenko 
24575e5f0aaSAlex Zinenko     // Save stack and then branch into the body of the region.
24675e5f0aaSAlex Zinenko     rewriter.setInsertionPointToEnd(currentBlock);
24775e5f0aaSAlex Zinenko     auto stackSaveOp =
24875e5f0aaSAlex Zinenko         rewriter.create<LLVM::StackSaveOp>(loc, getVoidPtrType());
24975e5f0aaSAlex Zinenko     rewriter.create<LLVM::BrOp>(loc, ValueRange(), beforeBody);
25075e5f0aaSAlex Zinenko 
25175e5f0aaSAlex Zinenko     // Replace the alloca_scope return with a branch that jumps out of the body.
25275e5f0aaSAlex Zinenko     // Stack restore before leaving the body region.
25375e5f0aaSAlex Zinenko     rewriter.setInsertionPointToEnd(afterBody);
25475e5f0aaSAlex Zinenko     auto returnOp =
25575e5f0aaSAlex Zinenko         cast<memref::AllocaScopeReturnOp>(afterBody->getTerminator());
25675e5f0aaSAlex Zinenko     auto branchOp = rewriter.replaceOpWithNewOp<LLVM::BrOp>(
257136d746eSJacques Pienaar         returnOp, returnOp.getResults(), continueBlock);
25875e5f0aaSAlex Zinenko 
25975e5f0aaSAlex Zinenko     // Insert stack restore before jumping out the body of the region.
26075e5f0aaSAlex Zinenko     rewriter.setInsertionPoint(branchOp);
26175e5f0aaSAlex Zinenko     rewriter.create<LLVM::StackRestoreOp>(loc, stackSaveOp);
26275e5f0aaSAlex Zinenko 
26375e5f0aaSAlex Zinenko     // Replace the op with values return from the body region.
26475e5f0aaSAlex Zinenko     rewriter.replaceOp(allocaScopeOp, continueBlock->getArguments());
26575e5f0aaSAlex Zinenko 
26675e5f0aaSAlex Zinenko     return success();
26775e5f0aaSAlex Zinenko   }
26875e5f0aaSAlex Zinenko };
26975e5f0aaSAlex Zinenko 
27075e5f0aaSAlex Zinenko struct AssumeAlignmentOpLowering
27175e5f0aaSAlex Zinenko     : public ConvertOpToLLVMPattern<memref::AssumeAlignmentOp> {
27275e5f0aaSAlex Zinenko   using ConvertOpToLLVMPattern<
27375e5f0aaSAlex Zinenko       memref::AssumeAlignmentOp>::ConvertOpToLLVMPattern;
27475e5f0aaSAlex Zinenko 
27575e5f0aaSAlex Zinenko   LogicalResult
matchAndRewrite__anon7a9e10510111::AssumeAlignmentOpLowering276ef976337SRiver Riddle   matchAndRewrite(memref::AssumeAlignmentOp op, OpAdaptor adaptor,
27775e5f0aaSAlex Zinenko                   ConversionPatternRewriter &rewriter) const override {
278136d746eSJacques Pienaar     Value memref = adaptor.getMemref();
279136d746eSJacques Pienaar     unsigned alignment = op.getAlignment();
28075e5f0aaSAlex Zinenko     auto loc = op.getLoc();
28175e5f0aaSAlex Zinenko 
28275e5f0aaSAlex Zinenko     MemRefDescriptor memRefDescriptor(memref);
28375e5f0aaSAlex Zinenko     Value ptr = memRefDescriptor.alignedPtr(rewriter, memref.getLoc());
28475e5f0aaSAlex Zinenko 
28575e5f0aaSAlex Zinenko     // Emit llvm.assume(memref.alignedPtr & (alignment - 1) == 0). Notice that
28675e5f0aaSAlex Zinenko     // the asserted memref.alignedPtr isn't used anywhere else, as the real
28775e5f0aaSAlex Zinenko     // users like load/store/views always re-extract memref.alignedPtr as they
28875e5f0aaSAlex Zinenko     // get lowered.
28975e5f0aaSAlex Zinenko     //
29075e5f0aaSAlex Zinenko     // This relies on LLVM's CSE optimization (potentially after SROA), since
29175e5f0aaSAlex Zinenko     // after CSE all memref.alignedPtr instances get de-duplicated into the same
29275e5f0aaSAlex Zinenko     // pointer SSA value.
29375e5f0aaSAlex Zinenko     auto intPtrType =
29475e5f0aaSAlex Zinenko         getIntPtrType(memRefDescriptor.getElementPtrType().getAddressSpace());
29575e5f0aaSAlex Zinenko     Value zero = createIndexAttrConstant(rewriter, loc, intPtrType, 0);
29675e5f0aaSAlex Zinenko     Value mask =
29775e5f0aaSAlex Zinenko         createIndexAttrConstant(rewriter, loc, intPtrType, alignment - 1);
29875e5f0aaSAlex Zinenko     Value ptrValue = rewriter.create<LLVM::PtrToIntOp>(loc, intPtrType, ptr);
29975e5f0aaSAlex Zinenko     rewriter.create<LLVM::AssumeOp>(
30075e5f0aaSAlex Zinenko         loc, rewriter.create<LLVM::ICmpOp>(
30175e5f0aaSAlex Zinenko                  loc, LLVM::ICmpPredicate::eq,
30275e5f0aaSAlex Zinenko                  rewriter.create<LLVM::AndOp>(loc, ptrValue, mask), zero));
30375e5f0aaSAlex Zinenko 
30475e5f0aaSAlex Zinenko     rewriter.eraseOp(op);
30575e5f0aaSAlex Zinenko     return success();
30675e5f0aaSAlex Zinenko   }
30775e5f0aaSAlex Zinenko };
30875e5f0aaSAlex Zinenko 
30975e5f0aaSAlex Zinenko // A `dealloc` is converted into a call to `free` on the underlying data buffer.
31075e5f0aaSAlex Zinenko // The memref descriptor being an SSA value, there is no need to clean it up
31175e5f0aaSAlex Zinenko // in any way.
31275e5f0aaSAlex Zinenko struct DeallocOpLowering : public ConvertOpToLLVMPattern<memref::DeallocOp> {
31375e5f0aaSAlex Zinenko   using ConvertOpToLLVMPattern<memref::DeallocOp>::ConvertOpToLLVMPattern;
31475e5f0aaSAlex Zinenko 
DeallocOpLowering__anon7a9e10510111::DeallocOpLowering31575e5f0aaSAlex Zinenko   explicit DeallocOpLowering(LLVMTypeConverter &converter)
31675e5f0aaSAlex Zinenko       : ConvertOpToLLVMPattern<memref::DeallocOp>(converter) {}
31775e5f0aaSAlex Zinenko 
getFreeFn__anon7a9e10510111::DeallocOpLowering318a8601f11SMichele Scuttari   LLVM::LLVMFuncOp getFreeFn(ModuleOp module) const {
319a8601f11SMichele Scuttari     bool useGenericFn = getTypeConverter()->getOptions().useGenericFunctions;
320a8601f11SMichele Scuttari 
321a8601f11SMichele Scuttari     if (useGenericFn)
322a8601f11SMichele Scuttari       return LLVM::lookupOrCreateGenericFreeFn(module);
323a8601f11SMichele Scuttari 
324a8601f11SMichele Scuttari     return LLVM::lookupOrCreateFreeFn(module);
325a8601f11SMichele Scuttari   }
326a8601f11SMichele Scuttari 
32775e5f0aaSAlex Zinenko   LogicalResult
matchAndRewrite__anon7a9e10510111::DeallocOpLowering328ef976337SRiver Riddle   matchAndRewrite(memref::DeallocOp op, OpAdaptor adaptor,
32975e5f0aaSAlex Zinenko                   ConversionPatternRewriter &rewriter) const override {
33075e5f0aaSAlex Zinenko     // Insert the `free` declaration if it is not already present.
331a8601f11SMichele Scuttari     auto freeFunc = getFreeFn(op->getParentOfType<ModuleOp>());
332136d746eSJacques Pienaar     MemRefDescriptor memref(adaptor.getMemref());
33375e5f0aaSAlex Zinenko     Value casted = rewriter.create<LLVM::BitcastOp>(
33475e5f0aaSAlex Zinenko         op.getLoc(), getVoidPtrType(),
33575e5f0aaSAlex Zinenko         memref.allocatedPtr(rewriter, op.getLoc()));
33675e5f0aaSAlex Zinenko     rewriter.replaceOpWithNewOp<LLVM::CallOp>(
337faf1c224SChris Lattner         op, TypeRange(), SymbolRefAttr::get(freeFunc), casted);
33875e5f0aaSAlex Zinenko     return success();
33975e5f0aaSAlex Zinenko   }
34075e5f0aaSAlex Zinenko };
34175e5f0aaSAlex Zinenko 
34275e5f0aaSAlex Zinenko // A `dim` is converted to a constant for static sizes and to an access to the
34375e5f0aaSAlex Zinenko // size stored in the memref descriptor for dynamic sizes.
34475e5f0aaSAlex Zinenko struct DimOpLowering : public ConvertOpToLLVMPattern<memref::DimOp> {
34575e5f0aaSAlex Zinenko   using ConvertOpToLLVMPattern<memref::DimOp>::ConvertOpToLLVMPattern;
34675e5f0aaSAlex Zinenko 
34775e5f0aaSAlex Zinenko   LogicalResult
matchAndRewrite__anon7a9e10510111::DimOpLowering348ef976337SRiver Riddle   matchAndRewrite(memref::DimOp dimOp, OpAdaptor adaptor,
34975e5f0aaSAlex Zinenko                   ConversionPatternRewriter &rewriter) const override {
350136d746eSJacques Pienaar     Type operandType = dimOp.getSource().getType();
35175e5f0aaSAlex Zinenko     if (operandType.isa<UnrankedMemRefType>()) {
352ef976337SRiver Riddle       rewriter.replaceOp(
353ef976337SRiver Riddle           dimOp, {extractSizeOfUnrankedMemRef(
354ef976337SRiver Riddle                      operandType, dimOp, adaptor.getOperands(), rewriter)});
35575e5f0aaSAlex Zinenko 
35675e5f0aaSAlex Zinenko       return success();
35775e5f0aaSAlex Zinenko     }
35875e5f0aaSAlex Zinenko     if (operandType.isa<MemRefType>()) {
359ef976337SRiver Riddle       rewriter.replaceOp(
360ef976337SRiver Riddle           dimOp, {extractSizeOfRankedMemRef(operandType, dimOp,
361ef976337SRiver Riddle                                             adaptor.getOperands(), rewriter)});
36275e5f0aaSAlex Zinenko       return success();
36375e5f0aaSAlex Zinenko     }
36475e5f0aaSAlex Zinenko     llvm_unreachable("expected MemRefType or UnrankedMemRefType");
36575e5f0aaSAlex Zinenko   }
36675e5f0aaSAlex Zinenko 
36775e5f0aaSAlex Zinenko private:
extractSizeOfUnrankedMemRef__anon7a9e10510111::DimOpLowering36875e5f0aaSAlex Zinenko   Value extractSizeOfUnrankedMemRef(Type operandType, memref::DimOp dimOp,
369ef976337SRiver Riddle                                     OpAdaptor adaptor,
37075e5f0aaSAlex Zinenko                                     ConversionPatternRewriter &rewriter) const {
37175e5f0aaSAlex Zinenko     Location loc = dimOp.getLoc();
37275e5f0aaSAlex Zinenko 
37375e5f0aaSAlex Zinenko     auto unrankedMemRefType = operandType.cast<UnrankedMemRefType>();
37475e5f0aaSAlex Zinenko     auto scalarMemRefType =
37575e5f0aaSAlex Zinenko         MemRefType::get({}, unrankedMemRefType.getElementType());
37675e5f0aaSAlex Zinenko     unsigned addressSpace = unrankedMemRefType.getMemorySpaceAsInt();
37775e5f0aaSAlex Zinenko 
37875e5f0aaSAlex Zinenko     // Extract pointer to the underlying ranked descriptor and bitcast it to a
37975e5f0aaSAlex Zinenko     // memref<element_type> descriptor pointer to minimize the number of GEP
38075e5f0aaSAlex Zinenko     // operations.
381136d746eSJacques Pienaar     UnrankedMemRefDescriptor unrankedDesc(adaptor.getSource());
38275e5f0aaSAlex Zinenko     Value underlyingRankedDesc = unrankedDesc.memRefDescPtr(rewriter, loc);
38375e5f0aaSAlex Zinenko     Value scalarMemRefDescPtr = rewriter.create<LLVM::BitcastOp>(
38475e5f0aaSAlex Zinenko         loc,
38575e5f0aaSAlex Zinenko         LLVM::LLVMPointerType::get(typeConverter->convertType(scalarMemRefType),
38675e5f0aaSAlex Zinenko                                    addressSpace),
38775e5f0aaSAlex Zinenko         underlyingRankedDesc);
38875e5f0aaSAlex Zinenko 
38975e5f0aaSAlex Zinenko     // Get pointer to offset field of memref<element_type> descriptor.
39075e5f0aaSAlex Zinenko     Type indexPtrTy = LLVM::LLVMPointerType::get(
39175e5f0aaSAlex Zinenko         getTypeConverter()->getIndexType(), addressSpace);
39275e5f0aaSAlex Zinenko     Value two = rewriter.create<LLVM::ConstantOp>(
39375e5f0aaSAlex Zinenko         loc, typeConverter->convertType(rewriter.getI32Type()),
39475e5f0aaSAlex Zinenko         rewriter.getI32IntegerAttr(2));
39575e5f0aaSAlex Zinenko     Value offsetPtr = rewriter.create<LLVM::GEPOp>(
39675e5f0aaSAlex Zinenko         loc, indexPtrTy, scalarMemRefDescPtr,
39775e5f0aaSAlex Zinenko         ValueRange({createIndexConstant(rewriter, loc, 0), two}));
39875e5f0aaSAlex Zinenko 
39975e5f0aaSAlex Zinenko     // The size value that we have to extract can be obtained using GEPop with
40075e5f0aaSAlex Zinenko     // `dimOp.index() + 1` index argument.
40175e5f0aaSAlex Zinenko     Value idxPlusOne = rewriter.create<LLVM::AddOp>(
402136d746eSJacques Pienaar         loc, createIndexConstant(rewriter, loc, 1), adaptor.getIndex());
40375e5f0aaSAlex Zinenko     Value sizePtr = rewriter.create<LLVM::GEPOp>(loc, indexPtrTy, offsetPtr,
40475e5f0aaSAlex Zinenko                                                  ValueRange({idxPlusOne}));
40575e5f0aaSAlex Zinenko     return rewriter.create<LLVM::LoadOp>(loc, sizePtr);
40675e5f0aaSAlex Zinenko   }
40775e5f0aaSAlex Zinenko 
getConstantDimIndex__anon7a9e10510111::DimOpLowering40875e5f0aaSAlex Zinenko   Optional<int64_t> getConstantDimIndex(memref::DimOp dimOp) const {
40975e5f0aaSAlex Zinenko     if (Optional<int64_t> idx = dimOp.getConstantIndex())
41075e5f0aaSAlex Zinenko       return idx;
41175e5f0aaSAlex Zinenko 
412136d746eSJacques Pienaar     if (auto constantOp = dimOp.getIndex().getDefiningOp<LLVM::ConstantOp>())
413cfb72fd3SJacques Pienaar       return constantOp.getValue()
414cfb72fd3SJacques Pienaar           .cast<IntegerAttr>()
415cfb72fd3SJacques Pienaar           .getValue()
416cfb72fd3SJacques Pienaar           .getSExtValue();
41775e5f0aaSAlex Zinenko 
41875e5f0aaSAlex Zinenko     return llvm::None;
41975e5f0aaSAlex Zinenko   }
42075e5f0aaSAlex Zinenko 
extractSizeOfRankedMemRef__anon7a9e10510111::DimOpLowering42175e5f0aaSAlex Zinenko   Value extractSizeOfRankedMemRef(Type operandType, memref::DimOp dimOp,
422ef976337SRiver Riddle                                   OpAdaptor adaptor,
42375e5f0aaSAlex Zinenko                                   ConversionPatternRewriter &rewriter) const {
42475e5f0aaSAlex Zinenko     Location loc = dimOp.getLoc();
425ef976337SRiver Riddle 
42675e5f0aaSAlex Zinenko     // Take advantage if index is constant.
42775e5f0aaSAlex Zinenko     MemRefType memRefType = operandType.cast<MemRefType>();
42875e5f0aaSAlex Zinenko     if (Optional<int64_t> index = getConstantDimIndex(dimOp)) {
4296d5fc1e3SKazu Hirata       int64_t i = *index;
43075e5f0aaSAlex Zinenko       if (memRefType.isDynamicDim(i)) {
43175e5f0aaSAlex Zinenko         // extract dynamic size from the memref descriptor.
432136d746eSJacques Pienaar         MemRefDescriptor descriptor(adaptor.getSource());
43375e5f0aaSAlex Zinenko         return descriptor.size(rewriter, loc, i);
43475e5f0aaSAlex Zinenko       }
43575e5f0aaSAlex Zinenko       // Use constant for static size.
43675e5f0aaSAlex Zinenko       int64_t dimSize = memRefType.getDimSize(i);
43775e5f0aaSAlex Zinenko       return createIndexConstant(rewriter, loc, dimSize);
43875e5f0aaSAlex Zinenko     }
439136d746eSJacques Pienaar     Value index = adaptor.getIndex();
44075e5f0aaSAlex Zinenko     int64_t rank = memRefType.getRank();
441136d746eSJacques Pienaar     MemRefDescriptor memrefDescriptor(adaptor.getSource());
44275e5f0aaSAlex Zinenko     return memrefDescriptor.size(rewriter, loc, index, rank);
44375e5f0aaSAlex Zinenko   }
44475e5f0aaSAlex Zinenko };
44575e5f0aaSAlex Zinenko 
446632a4f88SRiver Riddle /// Common base for load and store operations on MemRefs. Restricts the match
447632a4f88SRiver Riddle /// to supported MemRef types. Provides functionality to emit code accessing a
448632a4f88SRiver Riddle /// specific element of the underlying data buffer.
449632a4f88SRiver Riddle template <typename Derived>
450632a4f88SRiver Riddle struct LoadStoreOpLowering : public ConvertOpToLLVMPattern<Derived> {
451632a4f88SRiver Riddle   using ConvertOpToLLVMPattern<Derived>::ConvertOpToLLVMPattern;
452632a4f88SRiver Riddle   using ConvertOpToLLVMPattern<Derived>::isConvertibleAndHasIdentityMaps;
453632a4f88SRiver Riddle   using Base = LoadStoreOpLowering<Derived>;
454632a4f88SRiver Riddle 
match__anon7a9e10510111::LoadStoreOpLowering455632a4f88SRiver Riddle   LogicalResult match(Derived op) const override {
456632a4f88SRiver Riddle     MemRefType type = op.getMemRefType();
457632a4f88SRiver Riddle     return isConvertibleAndHasIdentityMaps(type) ? success() : failure();
458632a4f88SRiver Riddle   }
459632a4f88SRiver Riddle };
460632a4f88SRiver Riddle 
/// Wraps an llvm.cmpxchg operation in a while loop so that the operation can
/// be retried until it succeeds in atomically storing a new value into memory.
///
///      +----------------------------------+
///      |   <code before the AtomicRMWOp>  |
///      |   <compute initial %loaded>     |
///      |   llvm.br loop(%loaded)          |
///      +----------------------------------+
///             |
///  -------|   |
///  |      v   v
///  |   +-------------------------------------+
///  |   | loop(%loaded):                      |
///  |   |   <body contents>                   |
///  |   |   %pair = cmpxchg                   |
///  |   |   %new = %pair[0]                   |
///  |   |   %ok = %pair[1]                    |
///  |   |   llvm.cond_br %ok, end, loop(%new) |
///  |   +-------------------------------------+
///  |          |        |
///  |-----------        |
///                      v
///      +----------------------------------+
///      | end:                             |
///      |   <code after the AtomicRMWOp>   |
///      +----------------------------------+
///
/// Lowers memref.generic_atomic_rmw to a compare-and-swap retry loop. The
/// block containing the op is split into three blocks:
///  - the initial block loads the current memory value and branches to the
///    loop;
///  - the loop block re-runs the op's region on the value passed in and tries
///    to publish the result with llvm.cmpxchg;
///  - the end block receives every operation that followed the atomic op.
struct GenericAtomicRMWOpLowering
    : public LoadStoreOpLowering<memref::GenericAtomicRMWOp> {
  using Base::Base;

  LogicalResult
  matchAndRewrite(memref::GenericAtomicRMWOp atomicOp, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    auto loc = atomicOp.getLoc();
    Type valueType = typeConverter->convertType(atomicOp.getResult().getType());

    // Split the block into initial, loop, and ending parts.
    // NOTE(review): assumes the rewriter's insertion block is the block that
    // contains `atomicOp`. The conversion driver normally positions the
    // rewriter at the matched op; confirm if this pattern is reused elsewhere.
    auto *initBlock = rewriter.getInsertionBlock();
    // The loop block takes a single `valueType` argument: the value assumed to
    // be in memory at the start of the current iteration.
    auto *loopBlock = rewriter.createBlock(
        initBlock->getParent(), std::next(Region::iterator(initBlock)),
        valueType, loc);
    auto *endBlock = rewriter.createBlock(
        loopBlock->getParent(), std::next(Region::iterator(loopBlock)));

    // Operations range to be moved to `endBlock`.
    // Both iterators are captured *before* anything is appended to
    // `initBlock` below. The range later handed to moveOpsRange therefore
    // covers exactly the ops that originally followed `atomicOp`, through the
    // block's original last op.
    auto opsToMoveStart = atomicOp->getIterator();
    auto opsToMoveEnd = initBlock->back().getIterator();

    // Compute the loaded value and branch to the loop block.
    rewriter.setInsertionPointToEnd(initBlock);
    auto memRefType = atomicOp.getMemref().getType().cast<MemRefType>();
    auto dataPtr = getStridedElementPtr(loc, memRefType, adaptor.getMemref(),
                                        adaptor.getIndices(), rewriter);
    Value init = rewriter.create<LLVM::LoadOp>(loc, dataPtr);
    rewriter.create<LLVM::BrOp>(loc, init, loopBlock);

    // Prepare the body of the loop block.
    rewriter.setInsertionPointToStart(loopBlock);

    // Clone the GenericAtomicRMWOp region and extract the result.
    // The region's "current value" is remapped to the loop block argument, so
    // every iteration recomputes the new value from the latest observed one.
    auto loopArgument = loopBlock->getArgument(0);
    BlockAndValueMapping mapping;
    mapping.map(atomicOp.getCurrentValue(), loopArgument);
    Block &entryBlock = atomicOp.body().front();
    for (auto &nestedOp : entryBlock.without_terminator()) {
      Operation *clone = rewriter.clone(nestedOp, mapping);
      mapping.map(nestedOp.getResults(), clone->getResults());
    }
    // Operand 0 of the region terminator is the value to be stored.
    Value result = mapping.lookup(entryBlock.getTerminator()->getOperand(0));

    // Prepare the epilog of the loop block.
    // Append the cmpxchg op to the end of the loop block.
    // acq_rel ordering on success, monotonic on failure.
    auto successOrdering = LLVM::AtomicOrdering::acq_rel;
    auto failureOrdering = LLVM::AtomicOrdering::monotonic;
    auto boolType = IntegerType::get(rewriter.getContext(), 1);
    // cmpxchg yields a {value, i1} pair: element 0 is the value found in
    // memory, element 1 is the success flag.
    auto pairType = LLVM::LLVMStructType::getLiteral(rewriter.getContext(),
                                                     {valueType, boolType});
    auto cmpxchg = rewriter.create<LLVM::AtomicCmpXchgOp>(
        loc, pairType, dataPtr, loopArgument, result, successOrdering,
        failureOrdering);
    // Extract the %new_loaded and %ok values from the pair.
    Value newLoaded = rewriter.create<LLVM::ExtractValueOp>(
        loc, valueType, cmpxchg, rewriter.getI64ArrayAttr({0}));
    Value ok = rewriter.create<LLVM::ExtractValueOp>(
        loc, boolType, cmpxchg, rewriter.getI64ArrayAttr({1}));

    // Conditionally branch to the end or back to the loop depending on %ok.
    // On failure, the value just observed in memory becomes the next
    // iteration's loop argument.
    rewriter.create<LLVM::CondBrOp>(loc, ok, endBlock, ArrayRef<Value>(),
                                    loopBlock, newLoaded);

    // Re-home the ops that followed the atomic op into `endBlock`. Uses of the
    // atomic op's result among them are rewired to `newLoaded`.
    rewriter.setInsertionPointToEnd(endBlock);
    moveOpsRange(atomicOp.getResult(), newLoaded, std::next(opsToMoveStart),
                 std::next(opsToMoveEnd), rewriter);

    // The 'result' of the atomic_rmw op is the newly loaded value.
    // NOTE(review): on a successful cmpxchg, element 0 equals the compared
    // value (`loopArgument`), i.e. the memory contents *before* the update.
    // Confirm this matches the documented result of memref.generic_atomic_rmw.
    rewriter.replaceOp(atomicOp, {newLoaded});

    return success();
  }

private:
  // Clones a segment of ops [start, end) and erases the original.
  // Clones are inserted at the rewriter's current insertion point. Uses of
  // `oldResult` inside the cloned ops are replaced by `newResult`. Erasure is
  // deferred until every op has been cloned, so the iteration over [start,
  // end) is not disturbed.
  void moveOpsRange(ValueRange oldResult, ValueRange newResult,
                    Block::iterator start, Block::iterator end,
                    ConversionPatternRewriter &rewriter) const {
    BlockAndValueMapping mapping;
    mapping.map(oldResult, newResult);
    SmallVector<Operation *, 2> opsToErase;
    for (auto it = start; it != end; ++it) {
      rewriter.clone(*it, mapping);
      opsToErase.push_back(&*it);
    }
    for (auto *it : opsToErase)
      rewriter.eraseOp(it);
  }
};
578632a4f88SRiver Riddle 
57975e5f0aaSAlex Zinenko /// Returns the LLVM type of the global variable given the memref type `type`.
convertGlobalMemrefTypeToLLVM(MemRefType type,LLVMTypeConverter & typeConverter)58075e5f0aaSAlex Zinenko static Type convertGlobalMemrefTypeToLLVM(MemRefType type,
58175e5f0aaSAlex Zinenko                                           LLVMTypeConverter &typeConverter) {
58275e5f0aaSAlex Zinenko   // LLVM type for a global memref will be a multi-dimension array. For
58375e5f0aaSAlex Zinenko   // declarations or uninitialized global memrefs, we can potentially flatten
58475e5f0aaSAlex Zinenko   // this to a 1D array. However, for memref.global's with an initial value,
58575e5f0aaSAlex Zinenko   // we do not intend to flatten the ElementsAttribute when going from std ->
58675e5f0aaSAlex Zinenko   // LLVM dialect, so the LLVM type needs to me a multi-dimension array.
58775e5f0aaSAlex Zinenko   Type elementType = typeConverter.convertType(type.getElementType());
58875e5f0aaSAlex Zinenko   Type arrayTy = elementType;
58975e5f0aaSAlex Zinenko   // Shape has the outermost dim at index 0, so need to walk it backwards
59075e5f0aaSAlex Zinenko   for (int64_t dim : llvm::reverse(type.getShape()))
59175e5f0aaSAlex Zinenko     arrayTy = LLVM::LLVMArrayType::get(arrayTy, dim);
59275e5f0aaSAlex Zinenko   return arrayTy;
59375e5f0aaSAlex Zinenko }
59475e5f0aaSAlex Zinenko 
59575e5f0aaSAlex Zinenko /// GlobalMemrefOp is lowered to a LLVM Global Variable.
59675e5f0aaSAlex Zinenko struct GlobalMemrefOpLowering
59775e5f0aaSAlex Zinenko     : public ConvertOpToLLVMPattern<memref::GlobalOp> {
59875e5f0aaSAlex Zinenko   using ConvertOpToLLVMPattern<memref::GlobalOp>::ConvertOpToLLVMPattern;
59975e5f0aaSAlex Zinenko 
60075e5f0aaSAlex Zinenko   LogicalResult
matchAndRewrite__anon7a9e10510111::GlobalMemrefOpLowering601ef976337SRiver Riddle   matchAndRewrite(memref::GlobalOp global, OpAdaptor adaptor,
60275e5f0aaSAlex Zinenko                   ConversionPatternRewriter &rewriter) const override {
603136d746eSJacques Pienaar     MemRefType type = global.getType();
60475e5f0aaSAlex Zinenko     if (!isConvertibleAndHasIdentityMaps(type))
60575e5f0aaSAlex Zinenko       return failure();
60675e5f0aaSAlex Zinenko 
60775e5f0aaSAlex Zinenko     Type arrayTy = convertGlobalMemrefTypeToLLVM(type, *getTypeConverter());
60875e5f0aaSAlex Zinenko 
60975e5f0aaSAlex Zinenko     LLVM::Linkage linkage =
61075e5f0aaSAlex Zinenko         global.isPublic() ? LLVM::Linkage::External : LLVM::Linkage::Private;
61175e5f0aaSAlex Zinenko 
61275e5f0aaSAlex Zinenko     Attribute initialValue = nullptr;
61375e5f0aaSAlex Zinenko     if (!global.isExternal() && !global.isUninitialized()) {
614136d746eSJacques Pienaar       auto elementsAttr = global.getInitialValue()->cast<ElementsAttr>();
61575e5f0aaSAlex Zinenko       initialValue = elementsAttr;
61675e5f0aaSAlex Zinenko 
61775e5f0aaSAlex Zinenko       // For scalar memrefs, the global variable created is of the element type,
61875e5f0aaSAlex Zinenko       // so unpack the elements attribute to extract the value.
61975e5f0aaSAlex Zinenko       if (type.getRank() == 0)
620937e40a8SRiver Riddle         initialValue = elementsAttr.getSplatValue<Attribute>();
62175e5f0aaSAlex Zinenko     }
62275e5f0aaSAlex Zinenko 
623136d746eSJacques Pienaar     uint64_t alignment = global.getAlignment().value_or(0);
6248276ac13SEugene Zhulenev 
6258c2ff7b6SWilliam S. Moses     auto newGlobal = rewriter.replaceOpWithNewOp<LLVM::GlobalOp>(
626136d746eSJacques Pienaar         global, arrayTy, global.getConstant(), linkage, global.getSymName(),
6278276ac13SEugene Zhulenev         initialValue, alignment, type.getMemorySpaceAsInt());
6288c2ff7b6SWilliam S. Moses     if (!global.isExternal() && global.isUninitialized()) {
6298c2ff7b6SWilliam S. Moses       Block *blk = new Block();
6308c2ff7b6SWilliam S. Moses       newGlobal.getInitializerRegion().push_back(blk);
6318c2ff7b6SWilliam S. Moses       rewriter.setInsertionPointToStart(blk);
6328c2ff7b6SWilliam S. Moses       Value undef[] = {
6338c2ff7b6SWilliam S. Moses           rewriter.create<LLVM::UndefOp>(global.getLoc(), arrayTy)};
6348c2ff7b6SWilliam S. Moses       rewriter.create<LLVM::ReturnOp>(global.getLoc(), undef);
6358c2ff7b6SWilliam S. Moses     }
63675e5f0aaSAlex Zinenko     return success();
63775e5f0aaSAlex Zinenko   }
63875e5f0aaSAlex Zinenko };
63975e5f0aaSAlex Zinenko 
64075e5f0aaSAlex Zinenko /// GetGlobalMemrefOp is lowered into a Memref descriptor with the pointer to
64175e5f0aaSAlex Zinenko /// the first element stashed into the descriptor. This reuses
64275e5f0aaSAlex Zinenko /// `AllocLikeOpLowering` to reuse the Memref descriptor construction.
64375e5f0aaSAlex Zinenko struct GetGlobalMemrefOpLowering : public AllocLikeOpLLVMLowering {
GetGlobalMemrefOpLowering__anon7a9e10510111::GetGlobalMemrefOpLowering64475e5f0aaSAlex Zinenko   GetGlobalMemrefOpLowering(LLVMTypeConverter &converter)
64575e5f0aaSAlex Zinenko       : AllocLikeOpLLVMLowering(memref::GetGlobalOp::getOperationName(),
64675e5f0aaSAlex Zinenko                                 converter) {}
64775e5f0aaSAlex Zinenko 
64875e5f0aaSAlex Zinenko   /// Buffer "allocation" for memref.get_global op is getting the address of
64975e5f0aaSAlex Zinenko   /// the global variable referenced.
allocateBuffer__anon7a9e10510111::GetGlobalMemrefOpLowering65075e5f0aaSAlex Zinenko   std::tuple<Value, Value> allocateBuffer(ConversionPatternRewriter &rewriter,
65175e5f0aaSAlex Zinenko                                           Location loc, Value sizeBytes,
65275e5f0aaSAlex Zinenko                                           Operation *op) const override {
65375e5f0aaSAlex Zinenko     auto getGlobalOp = cast<memref::GetGlobalOp>(op);
654136d746eSJacques Pienaar     MemRefType type = getGlobalOp.getResult().getType().cast<MemRefType>();
65575e5f0aaSAlex Zinenko     unsigned memSpace = type.getMemorySpaceAsInt();
65675e5f0aaSAlex Zinenko 
65775e5f0aaSAlex Zinenko     Type arrayTy = convertGlobalMemrefTypeToLLVM(type, *getTypeConverter());
65875e5f0aaSAlex Zinenko     auto addressOf = rewriter.create<LLVM::AddressOfOp>(
659136d746eSJacques Pienaar         loc, LLVM::LLVMPointerType::get(arrayTy, memSpace),
660136d746eSJacques Pienaar         getGlobalOp.getName());
66175e5f0aaSAlex Zinenko 
66275e5f0aaSAlex Zinenko     // Get the address of the first element in the array by creating a GEP with
66375e5f0aaSAlex Zinenko     // the address of the GV as the base, and (rank + 1) number of 0 indices.
66475e5f0aaSAlex Zinenko     Type elementType = typeConverter->convertType(type.getElementType());
66575e5f0aaSAlex Zinenko     Type elementPtrType = LLVM::LLVMPointerType::get(elementType, memSpace);
66675e5f0aaSAlex Zinenko 
667cafaa350SAlex Zinenko     SmallVector<Value> operands;
66875e5f0aaSAlex Zinenko     operands.insert(operands.end(), type.getRank() + 1,
66975e5f0aaSAlex Zinenko                     createIndexConstant(rewriter, loc, 0));
670cafaa350SAlex Zinenko     auto gep =
671cafaa350SAlex Zinenko         rewriter.create<LLVM::GEPOp>(loc, elementPtrType, addressOf, operands);
67275e5f0aaSAlex Zinenko 
67375e5f0aaSAlex Zinenko     // We do not expect the memref obtained using `memref.get_global` to be
67475e5f0aaSAlex Zinenko     // ever deallocated. Set the allocated pointer to be known bad value to
67575e5f0aaSAlex Zinenko     // help debug if that ever happens.
67675e5f0aaSAlex Zinenko     auto intPtrType = getIntPtrType(memSpace);
67775e5f0aaSAlex Zinenko     Value deadBeefConst =
67875e5f0aaSAlex Zinenko         createIndexAttrConstant(rewriter, op->getLoc(), intPtrType, 0xdeadbeef);
67975e5f0aaSAlex Zinenko     auto deadBeefPtr =
68075e5f0aaSAlex Zinenko         rewriter.create<LLVM::IntToPtrOp>(loc, elementPtrType, deadBeefConst);
68175e5f0aaSAlex Zinenko 
68275e5f0aaSAlex Zinenko     // Both allocated and aligned pointers are same. We could potentially stash
68375e5f0aaSAlex Zinenko     // a nullptr for the allocated pointer since we do not expect any dealloc.
68475e5f0aaSAlex Zinenko     return std::make_tuple(deadBeefPtr, gep);
68575e5f0aaSAlex Zinenko   }
68675e5f0aaSAlex Zinenko };
68775e5f0aaSAlex Zinenko 
68875e5f0aaSAlex Zinenko // Load operation is lowered to obtaining a pointer to the indexed element
68975e5f0aaSAlex Zinenko // and loading it.
69075e5f0aaSAlex Zinenko struct LoadOpLowering : public LoadStoreOpLowering<memref::LoadOp> {
69175e5f0aaSAlex Zinenko   using Base::Base;
69275e5f0aaSAlex Zinenko 
69375e5f0aaSAlex Zinenko   LogicalResult
matchAndRewrite__anon7a9e10510111::LoadOpLowering694ef976337SRiver Riddle   matchAndRewrite(memref::LoadOp loadOp, OpAdaptor adaptor,
69575e5f0aaSAlex Zinenko                   ConversionPatternRewriter &rewriter) const override {
69675e5f0aaSAlex Zinenko     auto type = loadOp.getMemRefType();
69775e5f0aaSAlex Zinenko 
698136d746eSJacques Pienaar     Value dataPtr =
699136d746eSJacques Pienaar         getStridedElementPtr(loadOp.getLoc(), type, adaptor.getMemref(),
700136d746eSJacques Pienaar                              adaptor.getIndices(), rewriter);
70175e5f0aaSAlex Zinenko     rewriter.replaceOpWithNewOp<LLVM::LoadOp>(loadOp, dataPtr);
70275e5f0aaSAlex Zinenko     return success();
70375e5f0aaSAlex Zinenko   }
70475e5f0aaSAlex Zinenko };
70575e5f0aaSAlex Zinenko 
70675e5f0aaSAlex Zinenko // Store operation is lowered to obtaining a pointer to the indexed element,
70775e5f0aaSAlex Zinenko // and storing the given value to it.
70875e5f0aaSAlex Zinenko struct StoreOpLowering : public LoadStoreOpLowering<memref::StoreOp> {
70975e5f0aaSAlex Zinenko   using Base::Base;
71075e5f0aaSAlex Zinenko 
71175e5f0aaSAlex Zinenko   LogicalResult
matchAndRewrite__anon7a9e10510111::StoreOpLowering712ef976337SRiver Riddle   matchAndRewrite(memref::StoreOp op, OpAdaptor adaptor,
71375e5f0aaSAlex Zinenko                   ConversionPatternRewriter &rewriter) const override {
71475e5f0aaSAlex Zinenko     auto type = op.getMemRefType();
71575e5f0aaSAlex Zinenko 
716136d746eSJacques Pienaar     Value dataPtr = getStridedElementPtr(op.getLoc(), type, adaptor.getMemref(),
717136d746eSJacques Pienaar                                          adaptor.getIndices(), rewriter);
718136d746eSJacques Pienaar     rewriter.replaceOpWithNewOp<LLVM::StoreOp>(op, adaptor.getValue(), dataPtr);
71975e5f0aaSAlex Zinenko     return success();
72075e5f0aaSAlex Zinenko   }
72175e5f0aaSAlex Zinenko };
72275e5f0aaSAlex Zinenko 
72375e5f0aaSAlex Zinenko // The prefetch operation is lowered in a way similar to the load operation
72475e5f0aaSAlex Zinenko // except that the llvm.prefetch operation is used for replacement.
72575e5f0aaSAlex Zinenko struct PrefetchOpLowering : public LoadStoreOpLowering<memref::PrefetchOp> {
72675e5f0aaSAlex Zinenko   using Base::Base;
72775e5f0aaSAlex Zinenko 
72875e5f0aaSAlex Zinenko   LogicalResult
matchAndRewrite__anon7a9e10510111::PrefetchOpLowering729ef976337SRiver Riddle   matchAndRewrite(memref::PrefetchOp prefetchOp, OpAdaptor adaptor,
73075e5f0aaSAlex Zinenko                   ConversionPatternRewriter &rewriter) const override {
73175e5f0aaSAlex Zinenko     auto type = prefetchOp.getMemRefType();
73275e5f0aaSAlex Zinenko     auto loc = prefetchOp.getLoc();
73375e5f0aaSAlex Zinenko 
734136d746eSJacques Pienaar     Value dataPtr = getStridedElementPtr(loc, type, adaptor.getMemref(),
735136d746eSJacques Pienaar                                          adaptor.getIndices(), rewriter);
73675e5f0aaSAlex Zinenko 
73775e5f0aaSAlex Zinenko     // Replace with llvm.prefetch.
73875e5f0aaSAlex Zinenko     auto llvmI32Type = typeConverter->convertType(rewriter.getIntegerType(32));
73975e5f0aaSAlex Zinenko     auto isWrite = rewriter.create<LLVM::ConstantOp>(
740136d746eSJacques Pienaar         loc, llvmI32Type, rewriter.getI32IntegerAttr(prefetchOp.getIsWrite()));
74175e5f0aaSAlex Zinenko     auto localityHint = rewriter.create<LLVM::ConstantOp>(
74275e5f0aaSAlex Zinenko         loc, llvmI32Type,
743136d746eSJacques Pienaar         rewriter.getI32IntegerAttr(prefetchOp.getLocalityHint()));
74475e5f0aaSAlex Zinenko     auto isData = rewriter.create<LLVM::ConstantOp>(
745136d746eSJacques Pienaar         loc, llvmI32Type,
746136d746eSJacques Pienaar         rewriter.getI32IntegerAttr(prefetchOp.getIsDataCache()));
74775e5f0aaSAlex Zinenko 
74875e5f0aaSAlex Zinenko     rewriter.replaceOpWithNewOp<LLVM::Prefetch>(prefetchOp, dataPtr, isWrite,
74975e5f0aaSAlex Zinenko                                                 localityHint, isData);
75075e5f0aaSAlex Zinenko     return success();
75175e5f0aaSAlex Zinenko   }
75275e5f0aaSAlex Zinenko };
75375e5f0aaSAlex Zinenko 
75415f8f3e2SAlexander Belyaev struct RankOpLowering : public ConvertOpToLLVMPattern<memref::RankOp> {
75515f8f3e2SAlexander Belyaev   using ConvertOpToLLVMPattern<memref::RankOp>::ConvertOpToLLVMPattern;
75615f8f3e2SAlexander Belyaev 
75715f8f3e2SAlexander Belyaev   LogicalResult
matchAndRewrite__anon7a9e10510111::RankOpLowering75815f8f3e2SAlexander Belyaev   matchAndRewrite(memref::RankOp op, OpAdaptor adaptor,
75915f8f3e2SAlexander Belyaev                   ConversionPatternRewriter &rewriter) const override {
76015f8f3e2SAlexander Belyaev     Location loc = op.getLoc();
761136d746eSJacques Pienaar     Type operandType = op.getMemref().getType();
76215f8f3e2SAlexander Belyaev     if (auto unrankedMemRefType = operandType.dyn_cast<UnrankedMemRefType>()) {
763136d746eSJacques Pienaar       UnrankedMemRefDescriptor desc(adaptor.getMemref());
76415f8f3e2SAlexander Belyaev       rewriter.replaceOp(op, {desc.rank(rewriter, loc)});
76515f8f3e2SAlexander Belyaev       return success();
76615f8f3e2SAlexander Belyaev     }
76715f8f3e2SAlexander Belyaev     if (auto rankedMemRefType = operandType.dyn_cast<MemRefType>()) {
76815f8f3e2SAlexander Belyaev       rewriter.replaceOp(
76915f8f3e2SAlexander Belyaev           op, {createIndexConstant(rewriter, loc, rankedMemRefType.getRank())});
77015f8f3e2SAlexander Belyaev       return success();
77115f8f3e2SAlexander Belyaev     }
77215f8f3e2SAlexander Belyaev     return failure();
77315f8f3e2SAlexander Belyaev   }
77415f8f3e2SAlexander Belyaev };
77515f8f3e2SAlexander Belyaev 
77675e5f0aaSAlex Zinenko struct MemRefCastOpLowering : public ConvertOpToLLVMPattern<memref::CastOp> {
77775e5f0aaSAlex Zinenko   using ConvertOpToLLVMPattern<memref::CastOp>::ConvertOpToLLVMPattern;
77875e5f0aaSAlex Zinenko 
match__anon7a9e10510111::MemRefCastOpLowering77975e5f0aaSAlex Zinenko   LogicalResult match(memref::CastOp memRefCastOp) const override {
78075e5f0aaSAlex Zinenko     Type srcType = memRefCastOp.getOperand().getType();
78175e5f0aaSAlex Zinenko     Type dstType = memRefCastOp.getType();
78275e5f0aaSAlex Zinenko 
78375e5f0aaSAlex Zinenko     // memref::CastOp reduce to bitcast in the ranked MemRef case and can be
78475e5f0aaSAlex Zinenko     // used for type erasure. For now they must preserve underlying element type
78575e5f0aaSAlex Zinenko     // and require source and result type to have the same rank. Therefore,
78675e5f0aaSAlex Zinenko     // perform a sanity check that the underlying structs are the same. Once op
78775e5f0aaSAlex Zinenko     // semantics are relaxed we can revisit.
78875e5f0aaSAlex Zinenko     if (srcType.isa<MemRefType>() && dstType.isa<MemRefType>())
78975e5f0aaSAlex Zinenko       return success(typeConverter->convertType(srcType) ==
79075e5f0aaSAlex Zinenko                      typeConverter->convertType(dstType));
79175e5f0aaSAlex Zinenko 
79275e5f0aaSAlex Zinenko     // At least one of the operands is unranked type
79375e5f0aaSAlex Zinenko     assert(srcType.isa<UnrankedMemRefType>() ||
79475e5f0aaSAlex Zinenko            dstType.isa<UnrankedMemRefType>());
79575e5f0aaSAlex Zinenko 
79675e5f0aaSAlex Zinenko     // Unranked to unranked cast is disallowed
79775e5f0aaSAlex Zinenko     return !(srcType.isa<UnrankedMemRefType>() &&
79875e5f0aaSAlex Zinenko              dstType.isa<UnrankedMemRefType>())
79975e5f0aaSAlex Zinenko                ? success()
80075e5f0aaSAlex Zinenko                : failure();
80175e5f0aaSAlex Zinenko   }
80275e5f0aaSAlex Zinenko 
rewrite__anon7a9e10510111::MemRefCastOpLowering803ef976337SRiver Riddle   void rewrite(memref::CastOp memRefCastOp, OpAdaptor adaptor,
80475e5f0aaSAlex Zinenko                ConversionPatternRewriter &rewriter) const override {
80575e5f0aaSAlex Zinenko     auto srcType = memRefCastOp.getOperand().getType();
80675e5f0aaSAlex Zinenko     auto dstType = memRefCastOp.getType();
80775e5f0aaSAlex Zinenko     auto targetStructType = typeConverter->convertType(memRefCastOp.getType());
80875e5f0aaSAlex Zinenko     auto loc = memRefCastOp.getLoc();
80975e5f0aaSAlex Zinenko 
81075e5f0aaSAlex Zinenko     // For ranked/ranked case, just keep the original descriptor.
81175e5f0aaSAlex Zinenko     if (srcType.isa<MemRefType>() && dstType.isa<MemRefType>())
812136d746eSJacques Pienaar       return rewriter.replaceOp(memRefCastOp, {adaptor.getSource()});
81375e5f0aaSAlex Zinenko 
81475e5f0aaSAlex Zinenko     if (srcType.isa<MemRefType>() && dstType.isa<UnrankedMemRefType>()) {
81575e5f0aaSAlex Zinenko       // Casting ranked to unranked memref type
81675e5f0aaSAlex Zinenko       // Set the rank in the destination from the memref type
81775e5f0aaSAlex Zinenko       // Allocate space on the stack and copy the src memref descriptor
81875e5f0aaSAlex Zinenko       // Set the ptr in the destination to the stack space
81975e5f0aaSAlex Zinenko       auto srcMemRefType = srcType.cast<MemRefType>();
82075e5f0aaSAlex Zinenko       int64_t rank = srcMemRefType.getRank();
82175e5f0aaSAlex Zinenko       // ptr = AllocaOp sizeof(MemRefDescriptor)
82275e5f0aaSAlex Zinenko       auto ptr = getTypeConverter()->promoteOneMemRefDescriptor(
823136d746eSJacques Pienaar           loc, adaptor.getSource(), rewriter);
82475e5f0aaSAlex Zinenko       // voidptr = BitCastOp srcType* to void*
82575e5f0aaSAlex Zinenko       auto voidPtr =
82675e5f0aaSAlex Zinenko           rewriter.create<LLVM::BitcastOp>(loc, getVoidPtrType(), ptr)
82775e5f0aaSAlex Zinenko               .getResult();
82875e5f0aaSAlex Zinenko       // rank = ConstantOp srcRank
82975e5f0aaSAlex Zinenko       auto rankVal = rewriter.create<LLVM::ConstantOp>(
8305b02a480SAdrian Kuegel           loc, getIndexType(), rewriter.getIndexAttr(rank));
83175e5f0aaSAlex Zinenko       // undef = UndefOp
83275e5f0aaSAlex Zinenko       UnrankedMemRefDescriptor memRefDesc =
83375e5f0aaSAlex Zinenko           UnrankedMemRefDescriptor::undef(rewriter, loc, targetStructType);
83475e5f0aaSAlex Zinenko       // d1 = InsertValueOp undef, rank, 0
83575e5f0aaSAlex Zinenko       memRefDesc.setRank(rewriter, loc, rankVal);
83675e5f0aaSAlex Zinenko       // d2 = InsertValueOp d1, voidptr, 1
83775e5f0aaSAlex Zinenko       memRefDesc.setMemRefDescPtr(rewriter, loc, voidPtr);
83875e5f0aaSAlex Zinenko       rewriter.replaceOp(memRefCastOp, (Value)memRefDesc);
83975e5f0aaSAlex Zinenko 
84075e5f0aaSAlex Zinenko     } else if (srcType.isa<UnrankedMemRefType>() && dstType.isa<MemRefType>()) {
84175e5f0aaSAlex Zinenko       // Casting from unranked type to ranked.
84275e5f0aaSAlex Zinenko       // The operation is assumed to be doing a correct cast. If the destination
84375e5f0aaSAlex Zinenko       // type mismatches the unranked the type, it is undefined behavior.
844136d746eSJacques Pienaar       UnrankedMemRefDescriptor memRefDesc(adaptor.getSource());
84575e5f0aaSAlex Zinenko       // ptr = ExtractValueOp src, 1
84675e5f0aaSAlex Zinenko       auto ptr = memRefDesc.memRefDescPtr(rewriter, loc);
84775e5f0aaSAlex Zinenko       // castPtr = BitCastOp i8* to structTy*
84875e5f0aaSAlex Zinenko       auto castPtr =
84975e5f0aaSAlex Zinenko           rewriter
85075e5f0aaSAlex Zinenko               .create<LLVM::BitcastOp>(
85175e5f0aaSAlex Zinenko                   loc, LLVM::LLVMPointerType::get(targetStructType), ptr)
85275e5f0aaSAlex Zinenko               .getResult();
85375e5f0aaSAlex Zinenko       // struct = LoadOp castPtr
85475e5f0aaSAlex Zinenko       auto loadOp = rewriter.create<LLVM::LoadOp>(loc, castPtr);
85575e5f0aaSAlex Zinenko       rewriter.replaceOp(memRefCastOp, loadOp.getResult());
85675e5f0aaSAlex Zinenko     } else {
85775e5f0aaSAlex Zinenko       llvm_unreachable("Unsupported unranked memref to unranked memref cast");
85875e5f0aaSAlex Zinenko     }
85975e5f0aaSAlex Zinenko   }
86075e5f0aaSAlex Zinenko };
86175e5f0aaSAlex Zinenko 
862ab95ba70SStephan Herhut /// Pattern to lower a `memref.copy` to llvm.
863ab95ba70SStephan Herhut ///
864ab95ba70SStephan Herhut /// For memrefs with identity layouts, the copy is lowered to the llvm
865ab95ba70SStephan Herhut /// `memcpy` intrinsic. For non-identity layouts, the copy is lowered to a call
866ab95ba70SStephan Herhut /// to the generic `MemrefCopyFn`.
86775e5f0aaSAlex Zinenko struct MemRefCopyOpLowering : public ConvertOpToLLVMPattern<memref::CopyOp> {
86875e5f0aaSAlex Zinenko   using ConvertOpToLLVMPattern<memref::CopyOp>::ConvertOpToLLVMPattern;
86975e5f0aaSAlex Zinenko 
87075e5f0aaSAlex Zinenko   LogicalResult
lowerToMemCopyIntrinsic__anon7a9e10510111::MemRefCopyOpLowering871ab95ba70SStephan Herhut   lowerToMemCopyIntrinsic(memref::CopyOp op, OpAdaptor adaptor,
872ab95ba70SStephan Herhut                           ConversionPatternRewriter &rewriter) const {
873ab95ba70SStephan Herhut     auto loc = op.getLoc();
874136d746eSJacques Pienaar     auto srcType = op.getSource().getType().dyn_cast<MemRefType>();
875ab95ba70SStephan Herhut 
876136d746eSJacques Pienaar     MemRefDescriptor srcDesc(adaptor.getSource());
877ab95ba70SStephan Herhut 
878ab95ba70SStephan Herhut     // Compute number of elements.
879aa3cabe3SStephan Herhut     Value numElements = rewriter.create<LLVM::ConstantOp>(
880aa3cabe3SStephan Herhut         loc, getIndexType(), rewriter.getIndexAttr(1));
881ab95ba70SStephan Herhut     for (int pos = 0; pos < srcType.getRank(); ++pos) {
882ab95ba70SStephan Herhut       auto size = srcDesc.size(rewriter, loc, pos);
883aa3cabe3SStephan Herhut       numElements = rewriter.create<LLVM::MulOp>(loc, numElements, size);
884ab95ba70SStephan Herhut     }
885aa3cabe3SStephan Herhut 
886ab95ba70SStephan Herhut     // Get element size.
887ab95ba70SStephan Herhut     auto sizeInBytes = getSizeInBytes(loc, srcType.getElementType(), rewriter);
888ab95ba70SStephan Herhut     // Compute total.
889ab95ba70SStephan Herhut     Value totalSize =
890ab95ba70SStephan Herhut         rewriter.create<LLVM::MulOp>(loc, numElements, sizeInBytes);
891ab95ba70SStephan Herhut 
892ab95ba70SStephan Herhut     Value srcBasePtr = srcDesc.alignedPtr(rewriter, loc);
89327cd2a62SBenjamin Kramer     Value srcOffset = srcDesc.offset(rewriter, loc);
89427cd2a62SBenjamin Kramer     Value srcPtr = rewriter.create<LLVM::GEPOp>(loc, srcBasePtr.getType(),
89527cd2a62SBenjamin Kramer                                                 srcBasePtr, srcOffset);
896136d746eSJacques Pienaar     MemRefDescriptor targetDesc(adaptor.getTarget());
897ab95ba70SStephan Herhut     Value targetBasePtr = targetDesc.alignedPtr(rewriter, loc);
89827cd2a62SBenjamin Kramer     Value targetOffset = targetDesc.offset(rewriter, loc);
89927cd2a62SBenjamin Kramer     Value targetPtr = rewriter.create<LLVM::GEPOp>(loc, targetBasePtr.getType(),
90027cd2a62SBenjamin Kramer                                                    targetBasePtr, targetOffset);
901ab95ba70SStephan Herhut     Value isVolatile = rewriter.create<LLVM::ConstantOp>(
902ab95ba70SStephan Herhut         loc, typeConverter->convertType(rewriter.getI1Type()),
903ab95ba70SStephan Herhut         rewriter.getBoolAttr(false));
90427cd2a62SBenjamin Kramer     rewriter.create<LLVM::MemcpyOp>(loc, targetPtr, srcPtr, totalSize,
905ab95ba70SStephan Herhut                                     isVolatile);
906ab95ba70SStephan Herhut     rewriter.eraseOp(op);
907ab95ba70SStephan Herhut 
908ab95ba70SStephan Herhut     return success();
909ab95ba70SStephan Herhut   }
910ab95ba70SStephan Herhut 
911ab95ba70SStephan Herhut   LogicalResult
lowerToMemCopyFunctionCall__anon7a9e10510111::MemRefCopyOpLowering912ab95ba70SStephan Herhut   lowerToMemCopyFunctionCall(memref::CopyOp op, OpAdaptor adaptor,
913ab95ba70SStephan Herhut                              ConversionPatternRewriter &rewriter) const {
91475e5f0aaSAlex Zinenko     auto loc = op.getLoc();
915136d746eSJacques Pienaar     auto srcType = op.getSource().getType().cast<BaseMemRefType>();
916136d746eSJacques Pienaar     auto targetType = op.getTarget().getType().cast<BaseMemRefType>();
91775e5f0aaSAlex Zinenko 
91875e5f0aaSAlex Zinenko     // First make sure we have an unranked memref descriptor representation.
91975e5f0aaSAlex Zinenko     auto makeUnranked = [&, this](Value ranked, BaseMemRefType type) {
92075e5f0aaSAlex Zinenko       auto rank = rewriter.create<LLVM::ConstantOp>(
92175e5f0aaSAlex Zinenko           loc, getIndexType(), rewriter.getIndexAttr(type.getRank()));
92275e5f0aaSAlex Zinenko       auto *typeConverter = getTypeConverter();
92375e5f0aaSAlex Zinenko       auto ptr =
92475e5f0aaSAlex Zinenko           typeConverter->promoteOneMemRefDescriptor(loc, ranked, rewriter);
92575e5f0aaSAlex Zinenko       auto voidPtr =
92675e5f0aaSAlex Zinenko           rewriter.create<LLVM::BitcastOp>(loc, getVoidPtrType(), ptr)
92775e5f0aaSAlex Zinenko               .getResult();
92875e5f0aaSAlex Zinenko       auto unrankedType =
92975e5f0aaSAlex Zinenko           UnrankedMemRefType::get(type.getElementType(), type.getMemorySpace());
93075e5f0aaSAlex Zinenko       return UnrankedMemRefDescriptor::pack(rewriter, loc, *typeConverter,
93175e5f0aaSAlex Zinenko                                             unrankedType,
93275e5f0aaSAlex Zinenko                                             ValueRange{rank, voidPtr});
93375e5f0aaSAlex Zinenko     };
93475e5f0aaSAlex Zinenko 
93575e5f0aaSAlex Zinenko     Value unrankedSource = srcType.hasRank()
936136d746eSJacques Pienaar                                ? makeUnranked(adaptor.getSource(), srcType)
937136d746eSJacques Pienaar                                : adaptor.getSource();
93875e5f0aaSAlex Zinenko     Value unrankedTarget = targetType.hasRank()
939136d746eSJacques Pienaar                                ? makeUnranked(adaptor.getTarget(), targetType)
940136d746eSJacques Pienaar                                : adaptor.getTarget();
94175e5f0aaSAlex Zinenko 
94275e5f0aaSAlex Zinenko     // Now promote the unranked descriptors to the stack.
94375e5f0aaSAlex Zinenko     auto one = rewriter.create<LLVM::ConstantOp>(loc, getIndexType(),
94475e5f0aaSAlex Zinenko                                                  rewriter.getIndexAttr(1));
94575e5f0aaSAlex Zinenko     auto promote = [&](Value desc) {
94675e5f0aaSAlex Zinenko       auto ptrType = LLVM::LLVMPointerType::get(desc.getType());
94775e5f0aaSAlex Zinenko       auto allocated =
94875e5f0aaSAlex Zinenko           rewriter.create<LLVM::AllocaOp>(loc, ptrType, ValueRange{one});
94975e5f0aaSAlex Zinenko       rewriter.create<LLVM::StoreOp>(loc, desc, allocated);
95075e5f0aaSAlex Zinenko       return allocated;
95175e5f0aaSAlex Zinenko     };
95275e5f0aaSAlex Zinenko 
95375e5f0aaSAlex Zinenko     auto sourcePtr = promote(unrankedSource);
95475e5f0aaSAlex Zinenko     auto targetPtr = promote(unrankedTarget);
95575e5f0aaSAlex Zinenko 
9562219f9f5SAdrian Kuegel     unsigned typeSize =
9572219f9f5SAdrian Kuegel         mlir::DataLayout::closest(op).getTypeSize(srcType.getElementType());
95875e5f0aaSAlex Zinenko     auto elemSize = rewriter.create<LLVM::ConstantOp>(
9592219f9f5SAdrian Kuegel         loc, getIndexType(), rewriter.getIndexAttr(typeSize));
96075e5f0aaSAlex Zinenko     auto copyFn = LLVM::lookupOrCreateMemRefCopyFn(
96175e5f0aaSAlex Zinenko         op->getParentOfType<ModuleOp>(), getIndexType(), sourcePtr.getType());
96275e5f0aaSAlex Zinenko     rewriter.create<LLVM::CallOp>(loc, copyFn,
96375e5f0aaSAlex Zinenko                                   ValueRange{elemSize, sourcePtr, targetPtr});
96475e5f0aaSAlex Zinenko     rewriter.eraseOp(op);
96575e5f0aaSAlex Zinenko 
96675e5f0aaSAlex Zinenko     return success();
96775e5f0aaSAlex Zinenko   }
968ab95ba70SStephan Herhut 
969ab95ba70SStephan Herhut   LogicalResult
matchAndRewrite__anon7a9e10510111::MemRefCopyOpLowering970ab95ba70SStephan Herhut   matchAndRewrite(memref::CopyOp op, OpAdaptor adaptor,
971ab95ba70SStephan Herhut                   ConversionPatternRewriter &rewriter) const override {
972136d746eSJacques Pienaar     auto srcType = op.getSource().getType().cast<BaseMemRefType>();
973136d746eSJacques Pienaar     auto targetType = op.getTarget().getType().cast<BaseMemRefType>();
974ab95ba70SStephan Herhut 
97527cd2a62SBenjamin Kramer     auto isContiguousMemrefType = [](BaseMemRefType type) {
97627cd2a62SBenjamin Kramer       auto memrefType = type.dyn_cast<mlir::MemRefType>();
97727cd2a62SBenjamin Kramer       // We can use memcpy for memrefs if they have an identity layout or are
97827cd2a62SBenjamin Kramer       // contiguous with an arbitrary offset. Ignore empty memrefs, which is a
97927cd2a62SBenjamin Kramer       // special case handled by memrefCopy.
98027cd2a62SBenjamin Kramer       return memrefType &&
98127cd2a62SBenjamin Kramer              (memrefType.getLayout().isIdentity() ||
98227cd2a62SBenjamin Kramer               (memrefType.hasStaticShape() && memrefType.getNumElements() > 0 &&
98327cd2a62SBenjamin Kramer                isStaticShapeAndContiguousRowMajor(memrefType)));
98427cd2a62SBenjamin Kramer     };
98527cd2a62SBenjamin Kramer 
98627cd2a62SBenjamin Kramer     if (isContiguousMemrefType(srcType) && isContiguousMemrefType(targetType))
987ab95ba70SStephan Herhut       return lowerToMemCopyIntrinsic(op, adaptor, rewriter);
988ab95ba70SStephan Herhut 
989ab95ba70SStephan Herhut     return lowerToMemCopyFunctionCall(op, adaptor, rewriter);
990ab95ba70SStephan Herhut   }
99175e5f0aaSAlex Zinenko };
99275e5f0aaSAlex Zinenko 
99375e5f0aaSAlex Zinenko /// Extracts allocated, aligned pointers and offset from a ranked or unranked
99475e5f0aaSAlex Zinenko /// memref type. In unranked case, the fields are extracted from the underlying
99575e5f0aaSAlex Zinenko /// ranked descriptor.
extractPointersAndOffset(Location loc,ConversionPatternRewriter & rewriter,LLVMTypeConverter & typeConverter,Value originalOperand,Value convertedOperand,Value * allocatedPtr,Value * alignedPtr,Value * offset=nullptr)99675e5f0aaSAlex Zinenko static void extractPointersAndOffset(Location loc,
99775e5f0aaSAlex Zinenko                                      ConversionPatternRewriter &rewriter,
99875e5f0aaSAlex Zinenko                                      LLVMTypeConverter &typeConverter,
99975e5f0aaSAlex Zinenko                                      Value originalOperand,
100075e5f0aaSAlex Zinenko                                      Value convertedOperand,
100175e5f0aaSAlex Zinenko                                      Value *allocatedPtr, Value *alignedPtr,
100275e5f0aaSAlex Zinenko                                      Value *offset = nullptr) {
100375e5f0aaSAlex Zinenko   Type operandType = originalOperand.getType();
100475e5f0aaSAlex Zinenko   if (operandType.isa<MemRefType>()) {
100575e5f0aaSAlex Zinenko     MemRefDescriptor desc(convertedOperand);
100675e5f0aaSAlex Zinenko     *allocatedPtr = desc.allocatedPtr(rewriter, loc);
100775e5f0aaSAlex Zinenko     *alignedPtr = desc.alignedPtr(rewriter, loc);
100875e5f0aaSAlex Zinenko     if (offset != nullptr)
100975e5f0aaSAlex Zinenko       *offset = desc.offset(rewriter, loc);
101075e5f0aaSAlex Zinenko     return;
101175e5f0aaSAlex Zinenko   }
101275e5f0aaSAlex Zinenko 
101375e5f0aaSAlex Zinenko   unsigned memorySpace =
101475e5f0aaSAlex Zinenko       operandType.cast<UnrankedMemRefType>().getMemorySpaceAsInt();
101575e5f0aaSAlex Zinenko   Type elementType = operandType.cast<UnrankedMemRefType>().getElementType();
101675e5f0aaSAlex Zinenko   Type llvmElementType = typeConverter.convertType(elementType);
101775e5f0aaSAlex Zinenko   Type elementPtrPtrType = LLVM::LLVMPointerType::get(
101875e5f0aaSAlex Zinenko       LLVM::LLVMPointerType::get(llvmElementType, memorySpace));
101975e5f0aaSAlex Zinenko 
102075e5f0aaSAlex Zinenko   // Extract pointer to the underlying ranked memref descriptor and cast it to
102175e5f0aaSAlex Zinenko   // ElemType**.
102275e5f0aaSAlex Zinenko   UnrankedMemRefDescriptor unrankedDesc(convertedOperand);
102375e5f0aaSAlex Zinenko   Value underlyingDescPtr = unrankedDesc.memRefDescPtr(rewriter, loc);
102475e5f0aaSAlex Zinenko 
102575e5f0aaSAlex Zinenko   *allocatedPtr = UnrankedMemRefDescriptor::allocatedPtr(
102675e5f0aaSAlex Zinenko       rewriter, loc, underlyingDescPtr, elementPtrPtrType);
102775e5f0aaSAlex Zinenko   *alignedPtr = UnrankedMemRefDescriptor::alignedPtr(
102875e5f0aaSAlex Zinenko       rewriter, loc, typeConverter, underlyingDescPtr, elementPtrPtrType);
102975e5f0aaSAlex Zinenko   if (offset != nullptr) {
103075e5f0aaSAlex Zinenko     *offset = UnrankedMemRefDescriptor::offset(
103175e5f0aaSAlex Zinenko         rewriter, loc, typeConverter, underlyingDescPtr, elementPtrPtrType);
103275e5f0aaSAlex Zinenko   }
103375e5f0aaSAlex Zinenko }
103475e5f0aaSAlex Zinenko 
103575e5f0aaSAlex Zinenko struct MemRefReinterpretCastOpLowering
103675e5f0aaSAlex Zinenko     : public ConvertOpToLLVMPattern<memref::ReinterpretCastOp> {
103775e5f0aaSAlex Zinenko   using ConvertOpToLLVMPattern<
103875e5f0aaSAlex Zinenko       memref::ReinterpretCastOp>::ConvertOpToLLVMPattern;
103975e5f0aaSAlex Zinenko 
104075e5f0aaSAlex Zinenko   LogicalResult
matchAndRewrite__anon7a9e10510111::MemRefReinterpretCastOpLowering1041ef976337SRiver Riddle   matchAndRewrite(memref::ReinterpretCastOp castOp, OpAdaptor adaptor,
104275e5f0aaSAlex Zinenko                   ConversionPatternRewriter &rewriter) const override {
1043136d746eSJacques Pienaar     Type srcType = castOp.getSource().getType();
104475e5f0aaSAlex Zinenko 
104575e5f0aaSAlex Zinenko     Value descriptor;
104675e5f0aaSAlex Zinenko     if (failed(convertSourceMemRefToDescriptor(rewriter, srcType, castOp,
104775e5f0aaSAlex Zinenko                                                adaptor, &descriptor)))
104875e5f0aaSAlex Zinenko       return failure();
104975e5f0aaSAlex Zinenko     rewriter.replaceOp(castOp, {descriptor});
105075e5f0aaSAlex Zinenko     return success();
105175e5f0aaSAlex Zinenko   }
105275e5f0aaSAlex Zinenko 
105375e5f0aaSAlex Zinenko private:
convertSourceMemRefToDescriptor__anon7a9e10510111::MemRefReinterpretCastOpLowering105475e5f0aaSAlex Zinenko   LogicalResult convertSourceMemRefToDescriptor(
105575e5f0aaSAlex Zinenko       ConversionPatternRewriter &rewriter, Type srcType,
105675e5f0aaSAlex Zinenko       memref::ReinterpretCastOp castOp,
105775e5f0aaSAlex Zinenko       memref::ReinterpretCastOp::Adaptor adaptor, Value *descriptor) const {
105875e5f0aaSAlex Zinenko     MemRefType targetMemRefType =
105975e5f0aaSAlex Zinenko         castOp.getResult().getType().cast<MemRefType>();
106075e5f0aaSAlex Zinenko     auto llvmTargetDescriptorTy = typeConverter->convertType(targetMemRefType)
106175e5f0aaSAlex Zinenko                                       .dyn_cast_or_null<LLVM::LLVMStructType>();
106275e5f0aaSAlex Zinenko     if (!llvmTargetDescriptorTy)
106375e5f0aaSAlex Zinenko       return failure();
106475e5f0aaSAlex Zinenko 
106575e5f0aaSAlex Zinenko     // Create descriptor.
106675e5f0aaSAlex Zinenko     Location loc = castOp.getLoc();
106775e5f0aaSAlex Zinenko     auto desc = MemRefDescriptor::undef(rewriter, loc, llvmTargetDescriptorTy);
106875e5f0aaSAlex Zinenko 
106975e5f0aaSAlex Zinenko     // Set allocated and aligned pointers.
107075e5f0aaSAlex Zinenko     Value allocatedPtr, alignedPtr;
107175e5f0aaSAlex Zinenko     extractPointersAndOffset(loc, rewriter, *getTypeConverter(),
1072136d746eSJacques Pienaar                              castOp.getSource(), adaptor.getSource(),
1073136d746eSJacques Pienaar                              &allocatedPtr, &alignedPtr);
107475e5f0aaSAlex Zinenko     desc.setAllocatedPtr(rewriter, loc, allocatedPtr);
107575e5f0aaSAlex Zinenko     desc.setAlignedPtr(rewriter, loc, alignedPtr);
107675e5f0aaSAlex Zinenko 
107775e5f0aaSAlex Zinenko     // Set offset.
107875e5f0aaSAlex Zinenko     if (castOp.isDynamicOffset(0))
1079136d746eSJacques Pienaar       desc.setOffset(rewriter, loc, adaptor.getOffsets()[0]);
108075e5f0aaSAlex Zinenko     else
108175e5f0aaSAlex Zinenko       desc.setConstantOffset(rewriter, loc, castOp.getStaticOffset(0));
108275e5f0aaSAlex Zinenko 
108375e5f0aaSAlex Zinenko     // Set sizes and strides.
108475e5f0aaSAlex Zinenko     unsigned dynSizeId = 0;
108575e5f0aaSAlex Zinenko     unsigned dynStrideId = 0;
108675e5f0aaSAlex Zinenko     for (unsigned i = 0, e = targetMemRefType.getRank(); i < e; ++i) {
108775e5f0aaSAlex Zinenko       if (castOp.isDynamicSize(i))
1088136d746eSJacques Pienaar         desc.setSize(rewriter, loc, i, adaptor.getSizes()[dynSizeId++]);
108975e5f0aaSAlex Zinenko       else
109075e5f0aaSAlex Zinenko         desc.setConstantSize(rewriter, loc, i, castOp.getStaticSize(i));
109175e5f0aaSAlex Zinenko 
109275e5f0aaSAlex Zinenko       if (castOp.isDynamicStride(i))
1093136d746eSJacques Pienaar         desc.setStride(rewriter, loc, i, adaptor.getStrides()[dynStrideId++]);
109475e5f0aaSAlex Zinenko       else
109575e5f0aaSAlex Zinenko         desc.setConstantStride(rewriter, loc, i, castOp.getStaticStride(i));
109675e5f0aaSAlex Zinenko     }
109775e5f0aaSAlex Zinenko     *descriptor = desc;
109875e5f0aaSAlex Zinenko     return success();
109975e5f0aaSAlex Zinenko   }
110075e5f0aaSAlex Zinenko };
110175e5f0aaSAlex Zinenko 
110275e5f0aaSAlex Zinenko struct MemRefReshapeOpLowering
110375e5f0aaSAlex Zinenko     : public ConvertOpToLLVMPattern<memref::ReshapeOp> {
110475e5f0aaSAlex Zinenko   using ConvertOpToLLVMPattern<memref::ReshapeOp>::ConvertOpToLLVMPattern;
110575e5f0aaSAlex Zinenko 
110675e5f0aaSAlex Zinenko   LogicalResult
matchAndRewrite__anon7a9e10510111::MemRefReshapeOpLowering1107ef976337SRiver Riddle   matchAndRewrite(memref::ReshapeOp reshapeOp, OpAdaptor adaptor,
110875e5f0aaSAlex Zinenko                   ConversionPatternRewriter &rewriter) const override {
1109136d746eSJacques Pienaar     Type srcType = reshapeOp.getSource().getType();
111075e5f0aaSAlex Zinenko 
111175e5f0aaSAlex Zinenko     Value descriptor;
111275e5f0aaSAlex Zinenko     if (failed(convertSourceMemRefToDescriptor(rewriter, srcType, reshapeOp,
111375e5f0aaSAlex Zinenko                                                adaptor, &descriptor)))
111475e5f0aaSAlex Zinenko       return failure();
1115ef976337SRiver Riddle     rewriter.replaceOp(reshapeOp, {descriptor});
111675e5f0aaSAlex Zinenko     return success();
111775e5f0aaSAlex Zinenko   }
111875e5f0aaSAlex Zinenko 
111975e5f0aaSAlex Zinenko private:
112075e5f0aaSAlex Zinenko   LogicalResult
convertSourceMemRefToDescriptor__anon7a9e10510111::MemRefReshapeOpLowering112175e5f0aaSAlex Zinenko   convertSourceMemRefToDescriptor(ConversionPatternRewriter &rewriter,
112275e5f0aaSAlex Zinenko                                   Type srcType, memref::ReshapeOp reshapeOp,
112375e5f0aaSAlex Zinenko                                   memref::ReshapeOp::Adaptor adaptor,
112475e5f0aaSAlex Zinenko                                   Value *descriptor) const {
1125136d746eSJacques Pienaar     auto shapeMemRefType = reshapeOp.getShape().getType().cast<MemRefType>();
11265380e30eSAshay Rane     if (shapeMemRefType.hasStaticShape()) {
11275380e30eSAshay Rane       MemRefType targetMemRefType =
11285380e30eSAshay Rane           reshapeOp.getResult().getType().cast<MemRefType>();
11295380e30eSAshay Rane       auto llvmTargetDescriptorTy =
11305380e30eSAshay Rane           typeConverter->convertType(targetMemRefType)
11315380e30eSAshay Rane               .dyn_cast_or_null<LLVM::LLVMStructType>();
11325380e30eSAshay Rane       if (!llvmTargetDescriptorTy)
113375e5f0aaSAlex Zinenko         return failure();
113475e5f0aaSAlex Zinenko 
11355380e30eSAshay Rane       // Create descriptor.
11365380e30eSAshay Rane       Location loc = reshapeOp.getLoc();
11375380e30eSAshay Rane       auto desc =
11385380e30eSAshay Rane           MemRefDescriptor::undef(rewriter, loc, llvmTargetDescriptorTy);
11395380e30eSAshay Rane 
11405380e30eSAshay Rane       // Set allocated and aligned pointers.
11415380e30eSAshay Rane       Value allocatedPtr, alignedPtr;
11425380e30eSAshay Rane       extractPointersAndOffset(loc, rewriter, *getTypeConverter(),
1143136d746eSJacques Pienaar                                reshapeOp.getSource(), adaptor.getSource(),
11445380e30eSAshay Rane                                &allocatedPtr, &alignedPtr);
11455380e30eSAshay Rane       desc.setAllocatedPtr(rewriter, loc, allocatedPtr);
11465380e30eSAshay Rane       desc.setAlignedPtr(rewriter, loc, alignedPtr);
11475380e30eSAshay Rane 
11485380e30eSAshay Rane       // Extract the offset and strides from the type.
11495380e30eSAshay Rane       int64_t offset;
11505380e30eSAshay Rane       SmallVector<int64_t> strides;
11515380e30eSAshay Rane       if (failed(getStridesAndOffset(targetMemRefType, strides, offset)))
11525380e30eSAshay Rane         return rewriter.notifyMatchFailure(
11535380e30eSAshay Rane             reshapeOp, "failed to get stride and offset exprs");
11545380e30eSAshay Rane 
11555380e30eSAshay Rane       if (!isStaticStrideOrOffset(offset))
11565380e30eSAshay Rane         return rewriter.notifyMatchFailure(reshapeOp,
11575380e30eSAshay Rane                                            "dynamic offset is unsupported");
11585380e30eSAshay Rane 
11595380e30eSAshay Rane       desc.setConstantOffset(rewriter, loc, offset);
11605fee1799SAshay Rane 
11615fee1799SAshay Rane       assert(targetMemRefType.getLayout().isIdentity() &&
11625fee1799SAshay Rane              "Identity layout map is a precondition of a valid reshape op");
11635fee1799SAshay Rane 
11645fee1799SAshay Rane       Value stride = nullptr;
11655fee1799SAshay Rane       int64_t targetRank = targetMemRefType.getRank();
11665fee1799SAshay Rane       for (auto i : llvm::reverse(llvm::seq<int64_t>(0, targetRank))) {
11675fee1799SAshay Rane         if (!ShapedType::isDynamicStrideOrOffset(strides[i])) {
11685fee1799SAshay Rane           // If the stride for this dimension is dynamic, then use the product
11695fee1799SAshay Rane           // of the sizes of the inner dimensions.
11705fee1799SAshay Rane           stride = createIndexConstant(rewriter, loc, strides[i]);
11715fee1799SAshay Rane         } else if (!stride) {
11725fee1799SAshay Rane           // `stride` is null only in the first iteration of the loop.  However,
11735fee1799SAshay Rane           // since the target memref has an identity layout, we can safely set
11745fee1799SAshay Rane           // the innermost stride to 1.
11755fee1799SAshay Rane           stride = createIndexConstant(rewriter, loc, 1);
11765fee1799SAshay Rane         }
11775fee1799SAshay Rane 
11785fee1799SAshay Rane         Value dimSize;
11795fee1799SAshay Rane         int64_t size = targetMemRefType.getDimSize(i);
11805fee1799SAshay Rane         // If the size of this dimension is dynamic, then load it at runtime
11815fee1799SAshay Rane         // from the shape operand.
11825fee1799SAshay Rane         if (!ShapedType::isDynamic(size)) {
11835fee1799SAshay Rane           dimSize = createIndexConstant(rewriter, loc, size);
11845fee1799SAshay Rane         } else {
1185136d746eSJacques Pienaar           Value shapeOp = reshapeOp.getShape();
11865fee1799SAshay Rane           Value index = createIndexConstant(rewriter, loc, i);
11875fee1799SAshay Rane           dimSize = rewriter.create<memref::LoadOp>(loc, shapeOp, index);
1188d4217e6cSIvan Butygin           Type indexType = getIndexType();
1189d4217e6cSIvan Butygin           if (dimSize.getType() != indexType)
1190d4217e6cSIvan Butygin             dimSize = typeConverter->materializeTargetConversion(
1191d4217e6cSIvan Butygin                 rewriter, loc, indexType, dimSize);
1192d4217e6cSIvan Butygin           assert(dimSize && "Invalid memref element type");
11935fee1799SAshay Rane         }
11945fee1799SAshay Rane 
11955fee1799SAshay Rane         desc.setSize(rewriter, loc, i, dimSize);
11965fee1799SAshay Rane         desc.setStride(rewriter, loc, i, stride);
11975fee1799SAshay Rane 
11985fee1799SAshay Rane         // Prepare the stride value for the next dimension.
11995fee1799SAshay Rane         stride = rewriter.create<LLVM::MulOp>(loc, stride, dimSize);
12005380e30eSAshay Rane       }
12015380e30eSAshay Rane 
12025380e30eSAshay Rane       *descriptor = desc;
12035380e30eSAshay Rane       return success();
12045380e30eSAshay Rane     }
12055380e30eSAshay Rane 
120675e5f0aaSAlex Zinenko     // The shape is a rank-1 tensor with unknown length.
120775e5f0aaSAlex Zinenko     Location loc = reshapeOp.getLoc();
1208136d746eSJacques Pienaar     MemRefDescriptor shapeDesc(adaptor.getShape());
120975e5f0aaSAlex Zinenko     Value resultRank = shapeDesc.size(rewriter, loc, 0);
121075e5f0aaSAlex Zinenko 
121175e5f0aaSAlex Zinenko     // Extract address space and element type.
121275e5f0aaSAlex Zinenko     auto targetType =
121375e5f0aaSAlex Zinenko         reshapeOp.getResult().getType().cast<UnrankedMemRefType>();
121475e5f0aaSAlex Zinenko     unsigned addressSpace = targetType.getMemorySpaceAsInt();
121575e5f0aaSAlex Zinenko     Type elementType = targetType.getElementType();
121675e5f0aaSAlex Zinenko 
121775e5f0aaSAlex Zinenko     // Create the unranked memref descriptor that holds the ranked one. The
121875e5f0aaSAlex Zinenko     // inner descriptor is allocated on stack.
121975e5f0aaSAlex Zinenko     auto targetDesc = UnrankedMemRefDescriptor::undef(
122075e5f0aaSAlex Zinenko         rewriter, loc, typeConverter->convertType(targetType));
122175e5f0aaSAlex Zinenko     targetDesc.setRank(rewriter, loc, resultRank);
122275e5f0aaSAlex Zinenko     SmallVector<Value, 4> sizes;
122375e5f0aaSAlex Zinenko     UnrankedMemRefDescriptor::computeSizes(rewriter, loc, *getTypeConverter(),
122475e5f0aaSAlex Zinenko                                            targetDesc, sizes);
122575e5f0aaSAlex Zinenko     Value underlyingDescPtr = rewriter.create<LLVM::AllocaOp>(
122675e5f0aaSAlex Zinenko         loc, getVoidPtrType(), sizes.front(), llvm::None);
122775e5f0aaSAlex Zinenko     targetDesc.setMemRefDescPtr(rewriter, loc, underlyingDescPtr);
122875e5f0aaSAlex Zinenko 
122975e5f0aaSAlex Zinenko     // Extract pointers and offset from the source memref.
123075e5f0aaSAlex Zinenko     Value allocatedPtr, alignedPtr, offset;
123175e5f0aaSAlex Zinenko     extractPointersAndOffset(loc, rewriter, *getTypeConverter(),
1232136d746eSJacques Pienaar                              reshapeOp.getSource(), adaptor.getSource(),
123375e5f0aaSAlex Zinenko                              &allocatedPtr, &alignedPtr, &offset);
123475e5f0aaSAlex Zinenko 
123575e5f0aaSAlex Zinenko     // Set pointers and offset.
123675e5f0aaSAlex Zinenko     Type llvmElementType = typeConverter->convertType(elementType);
123775e5f0aaSAlex Zinenko     auto elementPtrPtrType = LLVM::LLVMPointerType::get(
123875e5f0aaSAlex Zinenko         LLVM::LLVMPointerType::get(llvmElementType, addressSpace));
123975e5f0aaSAlex Zinenko     UnrankedMemRefDescriptor::setAllocatedPtr(rewriter, loc, underlyingDescPtr,
124075e5f0aaSAlex Zinenko                                               elementPtrPtrType, allocatedPtr);
124175e5f0aaSAlex Zinenko     UnrankedMemRefDescriptor::setAlignedPtr(rewriter, loc, *getTypeConverter(),
124275e5f0aaSAlex Zinenko                                             underlyingDescPtr,
124375e5f0aaSAlex Zinenko                                             elementPtrPtrType, alignedPtr);
124475e5f0aaSAlex Zinenko     UnrankedMemRefDescriptor::setOffset(rewriter, loc, *getTypeConverter(),
124575e5f0aaSAlex Zinenko                                         underlyingDescPtr, elementPtrPtrType,
124675e5f0aaSAlex Zinenko                                         offset);
124775e5f0aaSAlex Zinenko 
124875e5f0aaSAlex Zinenko     // Use the offset pointer as base for further addressing. Copy over the new
124975e5f0aaSAlex Zinenko     // shape and compute strides. For this, we create a loop from rank-1 to 0.
125075e5f0aaSAlex Zinenko     Value targetSizesBase = UnrankedMemRefDescriptor::sizeBasePtr(
125175e5f0aaSAlex Zinenko         rewriter, loc, *getTypeConverter(), underlyingDescPtr,
125275e5f0aaSAlex Zinenko         elementPtrPtrType);
125375e5f0aaSAlex Zinenko     Value targetStridesBase = UnrankedMemRefDescriptor::strideBasePtr(
125475e5f0aaSAlex Zinenko         rewriter, loc, *getTypeConverter(), targetSizesBase, resultRank);
125575e5f0aaSAlex Zinenko     Value shapeOperandPtr = shapeDesc.alignedPtr(rewriter, loc);
125675e5f0aaSAlex Zinenko     Value oneIndex = createIndexConstant(rewriter, loc, 1);
125775e5f0aaSAlex Zinenko     Value resultRankMinusOne =
125875e5f0aaSAlex Zinenko         rewriter.create<LLVM::SubOp>(loc, resultRank, oneIndex);
125975e5f0aaSAlex Zinenko 
126075e5f0aaSAlex Zinenko     Block *initBlock = rewriter.getInsertionBlock();
126175e5f0aaSAlex Zinenko     Type indexType = getTypeConverter()->getIndexType();
126275e5f0aaSAlex Zinenko     Block::iterator remainingOpsIt = std::next(rewriter.getInsertionPoint());
126375e5f0aaSAlex Zinenko 
126475e5f0aaSAlex Zinenko     Block *condBlock = rewriter.createBlock(initBlock->getParent(), {},
1265e084679fSRiver Riddle                                             {indexType, indexType}, {loc, loc});
126675e5f0aaSAlex Zinenko 
126775e5f0aaSAlex Zinenko     // Move the remaining initBlock ops to condBlock.
126875e5f0aaSAlex Zinenko     Block *remainingBlock = rewriter.splitBlock(initBlock, remainingOpsIt);
126975e5f0aaSAlex Zinenko     rewriter.mergeBlocks(remainingBlock, condBlock, ValueRange());
127075e5f0aaSAlex Zinenko 
127175e5f0aaSAlex Zinenko     rewriter.setInsertionPointToEnd(initBlock);
127275e5f0aaSAlex Zinenko     rewriter.create<LLVM::BrOp>(loc, ValueRange({resultRankMinusOne, oneIndex}),
127375e5f0aaSAlex Zinenko                                 condBlock);
127475e5f0aaSAlex Zinenko     rewriter.setInsertionPointToStart(condBlock);
127575e5f0aaSAlex Zinenko     Value indexArg = condBlock->getArgument(0);
127675e5f0aaSAlex Zinenko     Value strideArg = condBlock->getArgument(1);
127775e5f0aaSAlex Zinenko 
127875e5f0aaSAlex Zinenko     Value zeroIndex = createIndexConstant(rewriter, loc, 0);
127975e5f0aaSAlex Zinenko     Value pred = rewriter.create<LLVM::ICmpOp>(
128075e5f0aaSAlex Zinenko         loc, IntegerType::get(rewriter.getContext(), 1),
128175e5f0aaSAlex Zinenko         LLVM::ICmpPredicate::sge, indexArg, zeroIndex);
128275e5f0aaSAlex Zinenko 
128375e5f0aaSAlex Zinenko     Block *bodyBlock =
128475e5f0aaSAlex Zinenko         rewriter.splitBlock(condBlock, rewriter.getInsertionPoint());
128575e5f0aaSAlex Zinenko     rewriter.setInsertionPointToStart(bodyBlock);
128675e5f0aaSAlex Zinenko 
128775e5f0aaSAlex Zinenko     // Copy size from shape to descriptor.
128875e5f0aaSAlex Zinenko     Type llvmIndexPtrType = LLVM::LLVMPointerType::get(indexType);
128975e5f0aaSAlex Zinenko     Value sizeLoadGep = rewriter.create<LLVM::GEPOp>(
129075e5f0aaSAlex Zinenko         loc, llvmIndexPtrType, shapeOperandPtr, ValueRange{indexArg});
129175e5f0aaSAlex Zinenko     Value size = rewriter.create<LLVM::LoadOp>(loc, sizeLoadGep);
129275e5f0aaSAlex Zinenko     UnrankedMemRefDescriptor::setSize(rewriter, loc, *getTypeConverter(),
129375e5f0aaSAlex Zinenko                                       targetSizesBase, indexArg, size);
129475e5f0aaSAlex Zinenko 
129575e5f0aaSAlex Zinenko     // Write stride value and compute next one.
129675e5f0aaSAlex Zinenko     UnrankedMemRefDescriptor::setStride(rewriter, loc, *getTypeConverter(),
129775e5f0aaSAlex Zinenko                                         targetStridesBase, indexArg, strideArg);
129875e5f0aaSAlex Zinenko     Value nextStride = rewriter.create<LLVM::MulOp>(loc, strideArg, size);
129975e5f0aaSAlex Zinenko 
130075e5f0aaSAlex Zinenko     // Decrement loop counter and branch back.
130175e5f0aaSAlex Zinenko     Value decrement = rewriter.create<LLVM::SubOp>(loc, indexArg, oneIndex);
130275e5f0aaSAlex Zinenko     rewriter.create<LLVM::BrOp>(loc, ValueRange({decrement, nextStride}),
130375e5f0aaSAlex Zinenko                                 condBlock);
130475e5f0aaSAlex Zinenko 
130575e5f0aaSAlex Zinenko     Block *remainder =
130675e5f0aaSAlex Zinenko         rewriter.splitBlock(bodyBlock, rewriter.getInsertionPoint());
130775e5f0aaSAlex Zinenko 
130875e5f0aaSAlex Zinenko     // Hook up the cond exit to the remainder.
130975e5f0aaSAlex Zinenko     rewriter.setInsertionPointToEnd(condBlock);
131075e5f0aaSAlex Zinenko     rewriter.create<LLVM::CondBrOp>(loc, pred, bodyBlock, llvm::None, remainder,
131175e5f0aaSAlex Zinenko                                     llvm::None);
131275e5f0aaSAlex Zinenko 
131375e5f0aaSAlex Zinenko     // Reset position to beginning of new remainder block.
131475e5f0aaSAlex Zinenko     rewriter.setInsertionPointToStart(remainder);
131575e5f0aaSAlex Zinenko 
131675e5f0aaSAlex Zinenko     *descriptor = targetDesc;
131775e5f0aaSAlex Zinenko     return success();
131875e5f0aaSAlex Zinenko   }
131975e5f0aaSAlex Zinenko };
132075e5f0aaSAlex Zinenko 
1321381c3b92SYi Zhang /// Helper function to convert a vector of `OpFoldResult`s into a vector of
1322381c3b92SYi Zhang /// `Value`s.
getAsValues(OpBuilder & b,Location loc,Type & llvmIndexType,ArrayRef<OpFoldResult> valueOrAttrVec)1323381c3b92SYi Zhang static SmallVector<Value> getAsValues(OpBuilder &b, Location loc,
1324381c3b92SYi Zhang                                       Type &llvmIndexType,
1325381c3b92SYi Zhang                                       ArrayRef<OpFoldResult> valueOrAttrVec) {
1326381c3b92SYi Zhang   return llvm::to_vector<4>(
1327381c3b92SYi Zhang       llvm::map_range(valueOrAttrVec, [&](OpFoldResult value) -> Value {
1328381c3b92SYi Zhang         if (auto attr = value.dyn_cast<Attribute>())
1329381c3b92SYi Zhang           return b.create<LLVM::ConstantOp>(loc, llvmIndexType, attr);
1330381c3b92SYi Zhang         return value.get<Value>();
1331381c3b92SYi Zhang       }));
1332381c3b92SYi Zhang }
1333381c3b92SYi Zhang 
1334381c3b92SYi Zhang /// Compute a map that for a given dimension of the expanded type gives the
1335381c3b92SYi Zhang /// dimension in the collapsed type it maps to. Essentially its the inverse of
1336381c3b92SYi Zhang /// the `reassocation` maps.
1337381c3b92SYi Zhang static DenseMap<int64_t, int64_t>
getExpandedDimToCollapsedDimMap(ArrayRef<ReassociationIndices> reassociation)1338381c3b92SYi Zhang getExpandedDimToCollapsedDimMap(ArrayRef<ReassociationIndices> reassociation) {
1339381c3b92SYi Zhang   llvm::DenseMap<int64_t, int64_t> expandedDimToCollapsedDim;
1340381c3b92SYi Zhang   for (auto &en : enumerate(reassociation)) {
1341381c3b92SYi Zhang     for (auto dim : en.value())
1342381c3b92SYi Zhang       expandedDimToCollapsedDim[dim] = en.index();
1343381c3b92SYi Zhang   }
1344381c3b92SYi Zhang   return expandedDimToCollapsedDim;
1345381c3b92SYi Zhang }
1346381c3b92SYi Zhang 
1347381c3b92SYi Zhang static OpFoldResult
getExpandedOutputDimSize(OpBuilder & b,Location loc,Type & llvmIndexType,int64_t outDimIndex,ArrayRef<int64_t> outStaticShape,MemRefDescriptor & inDesc,ArrayRef<int64_t> inStaticShape,ArrayRef<ReassociationIndices> reassocation,DenseMap<int64_t,int64_t> & outDimToInDimMap)1348381c3b92SYi Zhang getExpandedOutputDimSize(OpBuilder &b, Location loc, Type &llvmIndexType,
1349381c3b92SYi Zhang                          int64_t outDimIndex, ArrayRef<int64_t> outStaticShape,
1350381c3b92SYi Zhang                          MemRefDescriptor &inDesc,
1351381c3b92SYi Zhang                          ArrayRef<int64_t> inStaticShape,
1352381c3b92SYi Zhang                          ArrayRef<ReassociationIndices> reassocation,
1353381c3b92SYi Zhang                          DenseMap<int64_t, int64_t> &outDimToInDimMap) {
1354381c3b92SYi Zhang   int64_t outDimSize = outStaticShape[outDimIndex];
1355381c3b92SYi Zhang   if (!ShapedType::isDynamic(outDimSize))
1356381c3b92SYi Zhang     return b.getIndexAttr(outDimSize);
1357381c3b92SYi Zhang 
1358381c3b92SYi Zhang   // Calculate the multiplication of all the out dim sizes except the
1359381c3b92SYi Zhang   // current dim.
1360381c3b92SYi Zhang   int64_t inDimIndex = outDimToInDimMap[outDimIndex];
1361381c3b92SYi Zhang   int64_t otherDimSizesMul = 1;
1362381c3b92SYi Zhang   for (auto otherDimIndex : reassocation[inDimIndex]) {
1363381c3b92SYi Zhang     if (otherDimIndex == static_cast<unsigned>(outDimIndex))
1364381c3b92SYi Zhang       continue;
1365381c3b92SYi Zhang     int64_t otherDimSize = outStaticShape[otherDimIndex];
1366381c3b92SYi Zhang     assert(!ShapedType::isDynamic(otherDimSize) &&
1367381c3b92SYi Zhang            "single dimension cannot be expanded into multiple dynamic "
1368381c3b92SYi Zhang            "dimensions");
1369381c3b92SYi Zhang     otherDimSizesMul *= otherDimSize;
1370381c3b92SYi Zhang   }
1371381c3b92SYi Zhang 
1372381c3b92SYi Zhang   // outDimSize = inDimSize / otherOutDimSizesMul
1373381c3b92SYi Zhang   int64_t inDimSize = inStaticShape[inDimIndex];
1374381c3b92SYi Zhang   Value inDimSizeDynamic =
1375381c3b92SYi Zhang       ShapedType::isDynamic(inDimSize)
1376381c3b92SYi Zhang           ? inDesc.size(b, loc, inDimIndex)
1377381c3b92SYi Zhang           : b.create<LLVM::ConstantOp>(loc, llvmIndexType,
1378381c3b92SYi Zhang                                        b.getIndexAttr(inDimSize));
1379381c3b92SYi Zhang   Value outDimSizeDynamic = b.create<LLVM::SDivOp>(
1380381c3b92SYi Zhang       loc, inDimSizeDynamic,
1381381c3b92SYi Zhang       b.create<LLVM::ConstantOp>(loc, llvmIndexType,
1382381c3b92SYi Zhang                                  b.getIndexAttr(otherDimSizesMul)));
1383381c3b92SYi Zhang   return outDimSizeDynamic;
1384381c3b92SYi Zhang }
1385381c3b92SYi Zhang 
getCollapsedOutputDimSize(OpBuilder & b,Location loc,Type & llvmIndexType,int64_t outDimIndex,int64_t outDimSize,ArrayRef<int64_t> inStaticShape,MemRefDescriptor & inDesc,ArrayRef<ReassociationIndices> reassocation)1386381c3b92SYi Zhang static OpFoldResult getCollapsedOutputDimSize(
1387381c3b92SYi Zhang     OpBuilder &b, Location loc, Type &llvmIndexType, int64_t outDimIndex,
1388381c3b92SYi Zhang     int64_t outDimSize, ArrayRef<int64_t> inStaticShape,
1389381c3b92SYi Zhang     MemRefDescriptor &inDesc, ArrayRef<ReassociationIndices> reassocation) {
1390381c3b92SYi Zhang   if (!ShapedType::isDynamic(outDimSize))
1391381c3b92SYi Zhang     return b.getIndexAttr(outDimSize);
1392381c3b92SYi Zhang 
1393381c3b92SYi Zhang   Value c1 = b.create<LLVM::ConstantOp>(loc, llvmIndexType, b.getIndexAttr(1));
1394381c3b92SYi Zhang   Value outDimSizeDynamic = c1;
1395381c3b92SYi Zhang   for (auto inDimIndex : reassocation[outDimIndex]) {
1396381c3b92SYi Zhang     int64_t inDimSize = inStaticShape[inDimIndex];
1397381c3b92SYi Zhang     Value inDimSizeDynamic =
1398381c3b92SYi Zhang         ShapedType::isDynamic(inDimSize)
1399381c3b92SYi Zhang             ? inDesc.size(b, loc, inDimIndex)
1400381c3b92SYi Zhang             : b.create<LLVM::ConstantOp>(loc, llvmIndexType,
1401381c3b92SYi Zhang                                          b.getIndexAttr(inDimSize));
1402381c3b92SYi Zhang     outDimSizeDynamic =
1403381c3b92SYi Zhang         b.create<LLVM::MulOp>(loc, outDimSizeDynamic, inDimSizeDynamic);
1404381c3b92SYi Zhang   }
1405381c3b92SYi Zhang   return outDimSizeDynamic;
1406381c3b92SYi Zhang }
1407381c3b92SYi Zhang 
1408381c3b92SYi Zhang static SmallVector<OpFoldResult, 4>
getCollapsedOutputShape(OpBuilder & b,Location loc,Type & llvmIndexType,ArrayRef<ReassociationIndices> reassociation,ArrayRef<int64_t> inStaticShape,MemRefDescriptor & inDesc,ArrayRef<int64_t> outStaticShape)1409381c3b92SYi Zhang getCollapsedOutputShape(OpBuilder &b, Location loc, Type &llvmIndexType,
1410e1318078SYi Zhang                         ArrayRef<ReassociationIndices> reassociation,
1411381c3b92SYi Zhang                         ArrayRef<int64_t> inStaticShape,
1412381c3b92SYi Zhang                         MemRefDescriptor &inDesc,
1413381c3b92SYi Zhang                         ArrayRef<int64_t> outStaticShape) {
1414381c3b92SYi Zhang   return llvm::to_vector<4>(llvm::map_range(
1415381c3b92SYi Zhang       llvm::seq<int64_t>(0, outStaticShape.size()), [&](int64_t outDimIndex) {
1416381c3b92SYi Zhang         return getCollapsedOutputDimSize(b, loc, llvmIndexType, outDimIndex,
1417381c3b92SYi Zhang                                          outStaticShape[outDimIndex],
1418e1318078SYi Zhang                                          inStaticShape, inDesc, reassociation);
1419381c3b92SYi Zhang       }));
1420381c3b92SYi Zhang }
1421381c3b92SYi Zhang 
1422381c3b92SYi Zhang static SmallVector<OpFoldResult, 4>
getExpandedOutputShape(OpBuilder & b,Location loc,Type & llvmIndexType,ArrayRef<ReassociationIndices> reassociation,ArrayRef<int64_t> inStaticShape,MemRefDescriptor & inDesc,ArrayRef<int64_t> outStaticShape)1423381c3b92SYi Zhang getExpandedOutputShape(OpBuilder &b, Location loc, Type &llvmIndexType,
1424e1318078SYi Zhang                        ArrayRef<ReassociationIndices> reassociation,
1425381c3b92SYi Zhang                        ArrayRef<int64_t> inStaticShape,
1426381c3b92SYi Zhang                        MemRefDescriptor &inDesc,
1427381c3b92SYi Zhang                        ArrayRef<int64_t> outStaticShape) {
1428381c3b92SYi Zhang   DenseMap<int64_t, int64_t> outDimToInDimMap =
1429e1318078SYi Zhang       getExpandedDimToCollapsedDimMap(reassociation);
1430381c3b92SYi Zhang   return llvm::to_vector<4>(llvm::map_range(
1431381c3b92SYi Zhang       llvm::seq<int64_t>(0, outStaticShape.size()), [&](int64_t outDimIndex) {
1432381c3b92SYi Zhang         return getExpandedOutputDimSize(b, loc, llvmIndexType, outDimIndex,
1433381c3b92SYi Zhang                                         outStaticShape, inDesc, inStaticShape,
1434e1318078SYi Zhang                                         reassociation, outDimToInDimMap);
1435381c3b92SYi Zhang       }));
1436381c3b92SYi Zhang }
1437381c3b92SYi Zhang 
1438381c3b92SYi Zhang static SmallVector<Value>
getDynamicOutputShape(OpBuilder & b,Location loc,Type & llvmIndexType,ArrayRef<ReassociationIndices> reassociation,ArrayRef<int64_t> inStaticShape,MemRefDescriptor & inDesc,ArrayRef<int64_t> outStaticShape)1439381c3b92SYi Zhang getDynamicOutputShape(OpBuilder &b, Location loc, Type &llvmIndexType,
1440e1318078SYi Zhang                       ArrayRef<ReassociationIndices> reassociation,
1441381c3b92SYi Zhang                       ArrayRef<int64_t> inStaticShape, MemRefDescriptor &inDesc,
1442381c3b92SYi Zhang                       ArrayRef<int64_t> outStaticShape) {
1443381c3b92SYi Zhang   return outStaticShape.size() < inStaticShape.size()
1444381c3b92SYi Zhang              ? getAsValues(b, loc, llvmIndexType,
1445381c3b92SYi Zhang                            getCollapsedOutputShape(b, loc, llvmIndexType,
1446e1318078SYi Zhang                                                    reassociation, inStaticShape,
1447381c3b92SYi Zhang                                                    inDesc, outStaticShape))
1448381c3b92SYi Zhang              : getAsValues(b, loc, llvmIndexType,
1449381c3b92SYi Zhang                            getExpandedOutputShape(b, loc, llvmIndexType,
1450e1318078SYi Zhang                                                   reassociation, inStaticShape,
1451381c3b92SYi Zhang                                                   inDesc, outStaticShape));
1452381c3b92SYi Zhang }
1453381c3b92SYi Zhang 
fillInStridesForExpandedMemDescriptor(OpBuilder & b,Location loc,MemRefType srcType,MemRefDescriptor & srcDesc,MemRefDescriptor & dstDesc,ArrayRef<ReassociationIndices> reassociation)1454e1318078SYi Zhang static void fillInStridesForExpandedMemDescriptor(
1455e1318078SYi Zhang     OpBuilder &b, Location loc, MemRefType srcType, MemRefDescriptor &srcDesc,
1456e1318078SYi Zhang     MemRefDescriptor &dstDesc, ArrayRef<ReassociationIndices> reassociation) {
1457e1318078SYi Zhang   // See comments for computeExpandedLayoutMap for details on how the strides
1458e1318078SYi Zhang   // are calculated.
1459e1318078SYi Zhang   for (auto &en : llvm::enumerate(reassociation)) {
1460e1318078SYi Zhang     auto currentStrideToExpand = srcDesc.stride(b, loc, en.index());
1461e1318078SYi Zhang     for (auto dstIndex : llvm::reverse(en.value())) {
1462e1318078SYi Zhang       dstDesc.setStride(b, loc, dstIndex, currentStrideToExpand);
1463e1318078SYi Zhang       Value size = dstDesc.size(b, loc, dstIndex);
1464e1318078SYi Zhang       currentStrideToExpand =
1465e1318078SYi Zhang           b.create<LLVM::MulOp>(loc, size, currentStrideToExpand);
1466e1318078SYi Zhang     }
1467e1318078SYi Zhang   }
1468e1318078SYi Zhang }
1469e1318078SYi Zhang 
fillInStridesForCollapsedMemDescriptor(ConversionPatternRewriter & rewriter,Location loc,Operation * op,TypeConverter * typeConverter,MemRefType srcType,MemRefDescriptor & srcDesc,MemRefDescriptor & dstDesc,ArrayRef<ReassociationIndices> reassociation)1470e1318078SYi Zhang static void fillInStridesForCollapsedMemDescriptor(
1471e1318078SYi Zhang     ConversionPatternRewriter &rewriter, Location loc, Operation *op,
1472e1318078SYi Zhang     TypeConverter *typeConverter, MemRefType srcType, MemRefDescriptor &srcDesc,
1473e1318078SYi Zhang     MemRefDescriptor &dstDesc, ArrayRef<ReassociationIndices> reassociation) {
1474e1318078SYi Zhang   // See comments for computeCollapsedLayoutMap for details on how the strides
1475e1318078SYi Zhang   // are calculated.
1476e1318078SYi Zhang   auto srcShape = srcType.getShape();
1477e1318078SYi Zhang   for (auto &en : llvm::enumerate(reassociation)) {
1478e1318078SYi Zhang     rewriter.setInsertionPoint(op);
1479e1318078SYi Zhang     auto dstIndex = en.index();
1480e1318078SYi Zhang     ArrayRef<int64_t> ref = llvm::makeArrayRef(en.value());
1481e1318078SYi Zhang     while (srcShape[ref.back()] == 1 && ref.size() > 1)
1482e1318078SYi Zhang       ref = ref.drop_back();
1483e1318078SYi Zhang     if (!ShapedType::isDynamic(srcShape[ref.back()]) || ref.size() == 1) {
1484e1318078SYi Zhang       dstDesc.setStride(rewriter, loc, dstIndex,
1485e1318078SYi Zhang                         srcDesc.stride(rewriter, loc, ref.back()));
1486e1318078SYi Zhang     } else {
1487e1318078SYi Zhang       // Iterate over the source strides in reverse order. Skip over the
1488e1318078SYi Zhang       // dimensions whose size is 1.
1489e1318078SYi Zhang       // TODO: we should take the minimum stride in the reassociation group
1490e1318078SYi Zhang       // instead of just the first where the dimension is not 1.
1491e1318078SYi Zhang       //
1492e1318078SYi Zhang       // +------------------------------------------------------+
1493e1318078SYi Zhang       // | curEntry:                                            |
1494e1318078SYi Zhang       // |   %srcStride = strides[srcIndex]                     |
1495e1318078SYi Zhang       // |   %neOne = cmp sizes[srcIndex],1                     +--+
1496e1318078SYi Zhang       // |   cf.cond_br %neOne, continue(%srcStride), nextEntry |  |
1497e1318078SYi Zhang       // +-------------------------+----------------------------+  |
1498e1318078SYi Zhang       //                           |                               |
1499e1318078SYi Zhang       //                           v                               |
1500e1318078SYi Zhang       //            +-----------------------------+                |
1501e1318078SYi Zhang       //            | nextEntry:                  |                |
1502e1318078SYi Zhang       //            |   ...                       +---+            |
1503e1318078SYi Zhang       //            +--------------+--------------+   |            |
1504e1318078SYi Zhang       //                           |                  |            |
1505e1318078SYi Zhang       //                           v                  |            |
1506e1318078SYi Zhang       //            +-----------------------------+   |            |
1507e1318078SYi Zhang       //            | nextEntry:                  |   |            |
1508e1318078SYi Zhang       //            |   ...                       |   |            |
1509e1318078SYi Zhang       //            +--------------+--------------+   |   +--------+
1510e1318078SYi Zhang       //                           |                  |   |
1511e1318078SYi Zhang       //                           v                  v   v
1512e1318078SYi Zhang       //   +--------------------------------------------------+
1513e1318078SYi Zhang       //   | continue(%newStride):                            |
1514e1318078SYi Zhang       //   |   %newMemRefDes = setStride(%newStride,dstIndex) |
1515e1318078SYi Zhang       //   +--------------------------------------------------+
1516e1318078SYi Zhang       OpBuilder::InsertionGuard guard(rewriter);
1517e1318078SYi Zhang       Block *initBlock = rewriter.getInsertionBlock();
1518e1318078SYi Zhang       Block *continueBlock =
1519e1318078SYi Zhang           rewriter.splitBlock(initBlock, rewriter.getInsertionPoint());
1520e1318078SYi Zhang       continueBlock->insertArgument(unsigned(0), srcDesc.getIndexType(), loc);
1521e1318078SYi Zhang       rewriter.setInsertionPointToStart(continueBlock);
1522e1318078SYi Zhang       dstDesc.setStride(rewriter, loc, dstIndex, continueBlock->getArgument(0));
1523e1318078SYi Zhang 
1524e1318078SYi Zhang       Block *curEntryBlock = initBlock;
1525e1318078SYi Zhang       Block *nextEntryBlock;
1526e1318078SYi Zhang       for (auto srcIndex : llvm::reverse(ref)) {
1527e1318078SYi Zhang         if (srcShape[srcIndex] == 1 && srcIndex != ref.front())
1528e1318078SYi Zhang           continue;
1529e1318078SYi Zhang         rewriter.setInsertionPointToEnd(curEntryBlock);
1530e1318078SYi Zhang         Value srcStride = srcDesc.stride(rewriter, loc, srcIndex);
1531e1318078SYi Zhang         if (srcIndex == ref.front()) {
1532e1318078SYi Zhang           rewriter.create<LLVM::BrOp>(loc, srcStride, continueBlock);
1533e1318078SYi Zhang           break;
1534e1318078SYi Zhang         }
1535e1318078SYi Zhang         Value one = rewriter.create<LLVM::ConstantOp>(
1536e1318078SYi Zhang             loc, typeConverter->convertType(rewriter.getI64Type()),
1537e1318078SYi Zhang             rewriter.getI32IntegerAttr(1));
1538e1318078SYi Zhang         Value predNeOne = rewriter.create<LLVM::ICmpOp>(
1539e1318078SYi Zhang             loc, LLVM::ICmpPredicate::ne, srcDesc.size(rewriter, loc, srcIndex),
1540e1318078SYi Zhang             one);
1541e1318078SYi Zhang         {
1542e1318078SYi Zhang           OpBuilder::InsertionGuard guard(rewriter);
1543e1318078SYi Zhang           nextEntryBlock = rewriter.createBlock(
1544e1318078SYi Zhang               initBlock->getParent(), Region::iterator(continueBlock), {});
1545e1318078SYi Zhang         }
1546e1318078SYi Zhang         rewriter.create<LLVM::CondBrOp>(loc, predNeOne, continueBlock,
1547e1318078SYi Zhang                                         srcStride, nextEntryBlock, llvm::None);
1548e1318078SYi Zhang         curEntryBlock = nextEntryBlock;
1549e1318078SYi Zhang       }
1550e1318078SYi Zhang     }
1551e1318078SYi Zhang   }
1552e1318078SYi Zhang }
1553e1318078SYi Zhang 
fillInDynamicStridesForMemDescriptor(ConversionPatternRewriter & b,Location loc,Operation * op,TypeConverter * typeConverter,MemRefType srcType,MemRefType dstType,MemRefDescriptor & srcDesc,MemRefDescriptor & dstDesc,ArrayRef<ReassociationIndices> reassociation)1554e1318078SYi Zhang static void fillInDynamicStridesForMemDescriptor(
1555e1318078SYi Zhang     ConversionPatternRewriter &b, Location loc, Operation *op,
1556e1318078SYi Zhang     TypeConverter *typeConverter, MemRefType srcType, MemRefType dstType,
1557e1318078SYi Zhang     MemRefDescriptor &srcDesc, MemRefDescriptor &dstDesc,
1558e1318078SYi Zhang     ArrayRef<ReassociationIndices> reassociation) {
1559e1318078SYi Zhang   if (srcType.getRank() > dstType.getRank())
1560e1318078SYi Zhang     fillInStridesForCollapsedMemDescriptor(b, loc, op, typeConverter, srcType,
1561e1318078SYi Zhang                                            srcDesc, dstDesc, reassociation);
1562e1318078SYi Zhang   else
1563e1318078SYi Zhang     fillInStridesForExpandedMemDescriptor(b, loc, srcType, srcDesc, dstDesc,
1564e1318078SYi Zhang                                           reassociation);
1565e1318078SYi Zhang }
1566e1318078SYi Zhang 
156746ef86b5SAlexander Belyaev // ReshapeOp creates a new view descriptor of the proper rank.
156846ef86b5SAlexander Belyaev // For now, the only conversion supported is for target MemRef with static sizes
156946ef86b5SAlexander Belyaev // and strides.
157046ef86b5SAlexander Belyaev template <typename ReshapeOp>
157146ef86b5SAlexander Belyaev class ReassociatingReshapeOpConversion
157246ef86b5SAlexander Belyaev     : public ConvertOpToLLVMPattern<ReshapeOp> {
157346ef86b5SAlexander Belyaev public:
157446ef86b5SAlexander Belyaev   using ConvertOpToLLVMPattern<ReshapeOp>::ConvertOpToLLVMPattern;
157546ef86b5SAlexander Belyaev   using ReshapeOpAdaptor = typename ReshapeOp::Adaptor;
157646ef86b5SAlexander Belyaev 
157746ef86b5SAlexander Belyaev   LogicalResult
matchAndRewrite(ReshapeOp reshapeOp,typename ReshapeOp::Adaptor adaptor,ConversionPatternRewriter & rewriter) const1578ef976337SRiver Riddle   matchAndRewrite(ReshapeOp reshapeOp, typename ReshapeOp::Adaptor adaptor,
157946ef86b5SAlexander Belyaev                   ConversionPatternRewriter &rewriter) const override {
158046ef86b5SAlexander Belyaev     MemRefType dstType = reshapeOp.getResultType();
1581381c3b92SYi Zhang     MemRefType srcType = reshapeOp.getSrcType();
1582499703e9SBenoit Jacob 
158346ef86b5SAlexander Belyaev     int64_t offset;
158446ef86b5SAlexander Belyaev     SmallVector<int64_t, 4> strides;
1585381c3b92SYi Zhang     if (failed(getStridesAndOffset(dstType, strides, offset))) {
1586381c3b92SYi Zhang       return rewriter.notifyMatchFailure(
1587381c3b92SYi Zhang           reshapeOp, "failed to get stride and offset exprs");
1588381c3b92SYi Zhang     }
158946ef86b5SAlexander Belyaev 
1590136d746eSJacques Pienaar     MemRefDescriptor srcDesc(adaptor.getSrc());
159146ef86b5SAlexander Belyaev     Location loc = reshapeOp->getLoc();
1592381c3b92SYi Zhang     auto dstDesc = MemRefDescriptor::undef(
1593381c3b92SYi Zhang         rewriter, loc, this->typeConverter->convertType(dstType));
1594381c3b92SYi Zhang     dstDesc.setAllocatedPtr(rewriter, loc, srcDesc.allocatedPtr(rewriter, loc));
1595381c3b92SYi Zhang     dstDesc.setAlignedPtr(rewriter, loc, srcDesc.alignedPtr(rewriter, loc));
1596381c3b92SYi Zhang     dstDesc.setOffset(rewriter, loc, srcDesc.offset(rewriter, loc));
1597381c3b92SYi Zhang 
1598381c3b92SYi Zhang     ArrayRef<int64_t> srcStaticShape = srcType.getShape();
1599381c3b92SYi Zhang     ArrayRef<int64_t> dstStaticShape = dstType.getShape();
1600381c3b92SYi Zhang     Type llvmIndexType =
1601381c3b92SYi Zhang         this->typeConverter->convertType(rewriter.getIndexType());
1602381c3b92SYi Zhang     SmallVector<Value> dstShape = getDynamicOutputShape(
1603381c3b92SYi Zhang         rewriter, loc, llvmIndexType, reshapeOp.getReassociationIndices(),
1604381c3b92SYi Zhang         srcStaticShape, srcDesc, dstStaticShape);
1605381c3b92SYi Zhang     for (auto &en : llvm::enumerate(dstShape))
1606381c3b92SYi Zhang       dstDesc.setSize(rewriter, loc, en.index(), en.value());
1607381c3b92SYi Zhang 
16085380e30eSAshay Rane     if (llvm::all_of(strides, isStaticStrideOrOffset)) {
1609381c3b92SYi Zhang       for (auto &en : llvm::enumerate(strides))
1610381c3b92SYi Zhang         dstDesc.setConstantStride(rewriter, loc, en.index(), en.value());
1611e1318078SYi Zhang     } else if (srcType.getLayout().isIdentity() &&
1612e1318078SYi Zhang                dstType.getLayout().isIdentity()) {
1613381c3b92SYi Zhang       Value c1 = rewriter.create<LLVM::ConstantOp>(loc, llvmIndexType,
1614381c3b92SYi Zhang                                                    rewriter.getIndexAttr(1));
1615381c3b92SYi Zhang       Value stride = c1;
1616381c3b92SYi Zhang       for (auto dimIndex :
1617381c3b92SYi Zhang            llvm::reverse(llvm::seq<int64_t>(0, dstShape.size()))) {
1618381c3b92SYi Zhang         dstDesc.setStride(rewriter, loc, dimIndex, stride);
1619381c3b92SYi Zhang         stride = rewriter.create<LLVM::MulOp>(loc, dstShape[dimIndex], stride);
1620381c3b92SYi Zhang       }
1621e1318078SYi Zhang     } else {
1622e1318078SYi Zhang       // There could be mixed static/dynamic strides. For simplicity, we
1623e1318078SYi Zhang       // recompute all strides if there is at least one dynamic stride.
1624e1318078SYi Zhang       fillInDynamicStridesForMemDescriptor(
1625e1318078SYi Zhang           rewriter, loc, reshapeOp, this->typeConverter, srcType, dstType,
1626e1318078SYi Zhang           srcDesc, dstDesc, reshapeOp.getReassociationIndices());
1627381c3b92SYi Zhang     }
1628381c3b92SYi Zhang     rewriter.replaceOp(reshapeOp, {dstDesc});
162946ef86b5SAlexander Belyaev     return success();
163046ef86b5SAlexander Belyaev   }
163146ef86b5SAlexander Belyaev };
1632381c3b92SYi Zhang 
163375e5f0aaSAlex Zinenko /// Conversion pattern that transforms a subview op into:
163475e5f0aaSAlex Zinenko ///   1. An `llvm.mlir.undef` operation to create a memref descriptor
163575e5f0aaSAlex Zinenko ///   2. Updates to the descriptor to introduce the data ptr, offset, size
163675e5f0aaSAlex Zinenko ///      and stride.
163775e5f0aaSAlex Zinenko /// The subview op is replaced by the descriptor.
163875e5f0aaSAlex Zinenko struct SubViewOpLowering : public ConvertOpToLLVMPattern<memref::SubViewOp> {
163975e5f0aaSAlex Zinenko   using ConvertOpToLLVMPattern<memref::SubViewOp>::ConvertOpToLLVMPattern;
164075e5f0aaSAlex Zinenko 
164175e5f0aaSAlex Zinenko   LogicalResult
matchAndRewrite__anon7a9e10510111::SubViewOpLowering1642ef976337SRiver Riddle   matchAndRewrite(memref::SubViewOp subViewOp, OpAdaptor adaptor,
164375e5f0aaSAlex Zinenko                   ConversionPatternRewriter &rewriter) const override {
164475e5f0aaSAlex Zinenko     auto loc = subViewOp.getLoc();
164575e5f0aaSAlex Zinenko 
1646136d746eSJacques Pienaar     auto sourceMemRefType = subViewOp.getSource().getType().cast<MemRefType>();
164775e5f0aaSAlex Zinenko     auto sourceElementTy =
164875e5f0aaSAlex Zinenko         typeConverter->convertType(sourceMemRefType.getElementType());
164975e5f0aaSAlex Zinenko 
165075e5f0aaSAlex Zinenko     auto viewMemRefType = subViewOp.getType();
1651136d746eSJacques Pienaar     auto inferredType =
1652136d746eSJacques Pienaar         memref::SubViewOp::inferResultType(
165375e5f0aaSAlex Zinenko             subViewOp.getSourceType(),
1654136d746eSJacques Pienaar             extractFromI64ArrayAttr(subViewOp.getStaticOffsets()),
1655136d746eSJacques Pienaar             extractFromI64ArrayAttr(subViewOp.getStaticSizes()),
1656136d746eSJacques Pienaar             extractFromI64ArrayAttr(subViewOp.getStaticStrides()))
165775e5f0aaSAlex Zinenko             .cast<MemRefType>();
165875e5f0aaSAlex Zinenko     auto targetElementTy =
165975e5f0aaSAlex Zinenko         typeConverter->convertType(viewMemRefType.getElementType());
166075e5f0aaSAlex Zinenko     auto targetDescTy = typeConverter->convertType(viewMemRefType);
166175e5f0aaSAlex Zinenko     if (!sourceElementTy || !targetDescTy || !targetElementTy ||
166275e5f0aaSAlex Zinenko         !LLVM::isCompatibleType(sourceElementTy) ||
166375e5f0aaSAlex Zinenko         !LLVM::isCompatibleType(targetElementTy) ||
166475e5f0aaSAlex Zinenko         !LLVM::isCompatibleType(targetDescTy))
166575e5f0aaSAlex Zinenko       return failure();
166675e5f0aaSAlex Zinenko 
166775e5f0aaSAlex Zinenko     // Extract the offset and strides from the type.
166875e5f0aaSAlex Zinenko     int64_t offset;
166975e5f0aaSAlex Zinenko     SmallVector<int64_t, 4> strides;
167075e5f0aaSAlex Zinenko     auto successStrides = getStridesAndOffset(inferredType, strides, offset);
167175e5f0aaSAlex Zinenko     if (failed(successStrides))
167275e5f0aaSAlex Zinenko       return failure();
167375e5f0aaSAlex Zinenko 
167475e5f0aaSAlex Zinenko     // Create the descriptor.
1675ef976337SRiver Riddle     if (!LLVM::isCompatibleType(adaptor.getOperands().front().getType()))
167675e5f0aaSAlex Zinenko       return failure();
1677ef976337SRiver Riddle     MemRefDescriptor sourceMemRef(adaptor.getOperands().front());
167875e5f0aaSAlex Zinenko     auto targetMemRef = MemRefDescriptor::undef(rewriter, loc, targetDescTy);
167975e5f0aaSAlex Zinenko 
168075e5f0aaSAlex Zinenko     // Copy the buffer pointer from the old descriptor to the new one.
168175e5f0aaSAlex Zinenko     Value extracted = sourceMemRef.allocatedPtr(rewriter, loc);
168275e5f0aaSAlex Zinenko     Value bitcastPtr = rewriter.create<LLVM::BitcastOp>(
168375e5f0aaSAlex Zinenko         loc,
168475e5f0aaSAlex Zinenko         LLVM::LLVMPointerType::get(targetElementTy,
168575e5f0aaSAlex Zinenko                                    viewMemRefType.getMemorySpaceAsInt()),
168675e5f0aaSAlex Zinenko         extracted);
168775e5f0aaSAlex Zinenko     targetMemRef.setAllocatedPtr(rewriter, loc, bitcastPtr);
168875e5f0aaSAlex Zinenko 
168975e5f0aaSAlex Zinenko     // Copy the aligned pointer from the old descriptor to the new one.
169075e5f0aaSAlex Zinenko     extracted = sourceMemRef.alignedPtr(rewriter, loc);
169175e5f0aaSAlex Zinenko     bitcastPtr = rewriter.create<LLVM::BitcastOp>(
169275e5f0aaSAlex Zinenko         loc,
169375e5f0aaSAlex Zinenko         LLVM::LLVMPointerType::get(targetElementTy,
169475e5f0aaSAlex Zinenko                                    viewMemRefType.getMemorySpaceAsInt()),
169575e5f0aaSAlex Zinenko         extracted);
169675e5f0aaSAlex Zinenko     targetMemRef.setAlignedPtr(rewriter, loc, bitcastPtr);
169775e5f0aaSAlex Zinenko 
16984cf9bf6cSMaheshRavishankar     size_t inferredShapeRank = inferredType.getRank();
16994cf9bf6cSMaheshRavishankar     size_t resultShapeRank = viewMemRefType.getRank();
170075e5f0aaSAlex Zinenko 
170175e5f0aaSAlex Zinenko     // Extract strides needed to compute offset.
170275e5f0aaSAlex Zinenko     SmallVector<Value, 4> strideValues;
170375e5f0aaSAlex Zinenko     strideValues.reserve(inferredShapeRank);
170475e5f0aaSAlex Zinenko     for (unsigned i = 0; i < inferredShapeRank; ++i)
170575e5f0aaSAlex Zinenko       strideValues.push_back(sourceMemRef.stride(rewriter, loc, i));
170675e5f0aaSAlex Zinenko 
170775e5f0aaSAlex Zinenko     // Offset.
170875e5f0aaSAlex Zinenko     auto llvmIndexType = typeConverter->convertType(rewriter.getIndexType());
170975e5f0aaSAlex Zinenko     if (!ShapedType::isDynamicStrideOrOffset(offset)) {
171075e5f0aaSAlex Zinenko       targetMemRef.setConstantOffset(rewriter, loc, offset);
171175e5f0aaSAlex Zinenko     } else {
171275e5f0aaSAlex Zinenko       Value baseOffset = sourceMemRef.offset(rewriter, loc);
171375e5f0aaSAlex Zinenko       // `inferredShapeRank` may be larger than the number of offset operands
171475e5f0aaSAlex Zinenko       // because of trailing semantics. In this case, the offset is guaranteed
171575e5f0aaSAlex Zinenko       // to be interpreted as 0 and we can just skip the extra dimensions.
171675e5f0aaSAlex Zinenko       for (unsigned i = 0, e = std::min(inferredShapeRank,
171775e5f0aaSAlex Zinenko                                         subViewOp.getMixedOffsets().size());
171875e5f0aaSAlex Zinenko            i < e; ++i) {
171975e5f0aaSAlex Zinenko         Value offset =
172075e5f0aaSAlex Zinenko             // TODO: need OpFoldResult ODS adaptor to clean this up.
172175e5f0aaSAlex Zinenko             subViewOp.isDynamicOffset(i)
1722ef976337SRiver Riddle                 ? adaptor.getOperands()[subViewOp.getIndexOfDynamicOffset(i)]
172375e5f0aaSAlex Zinenko                 : rewriter.create<LLVM::ConstantOp>(
172475e5f0aaSAlex Zinenko                       loc, llvmIndexType,
172575e5f0aaSAlex Zinenko                       rewriter.getI64IntegerAttr(subViewOp.getStaticOffset(i)));
172675e5f0aaSAlex Zinenko         Value mul = rewriter.create<LLVM::MulOp>(loc, offset, strideValues[i]);
172775e5f0aaSAlex Zinenko         baseOffset = rewriter.create<LLVM::AddOp>(loc, baseOffset, mul);
172875e5f0aaSAlex Zinenko       }
172975e5f0aaSAlex Zinenko       targetMemRef.setOffset(rewriter, loc, baseOffset);
173075e5f0aaSAlex Zinenko     }
173175e5f0aaSAlex Zinenko 
173275e5f0aaSAlex Zinenko     // Update sizes and strides.
173375e5f0aaSAlex Zinenko     SmallVector<OpFoldResult> mixedSizes = subViewOp.getMixedSizes();
173475e5f0aaSAlex Zinenko     SmallVector<OpFoldResult> mixedStrides = subViewOp.getMixedStrides();
173575e5f0aaSAlex Zinenko     assert(mixedSizes.size() == mixedStrides.size() &&
173675e5f0aaSAlex Zinenko            "expected sizes and strides of equal length");
17376635c12aSBenjamin Kramer     llvm::SmallBitVector unusedDims = subViewOp.getDroppedDims();
173875e5f0aaSAlex Zinenko     for (int i = inferredShapeRank - 1, j = resultShapeRank - 1;
173975e5f0aaSAlex Zinenko          i >= 0 && j >= 0; --i) {
17406635c12aSBenjamin Kramer       if (unusedDims.test(i))
174175e5f0aaSAlex Zinenko         continue;
174275e5f0aaSAlex Zinenko 
174375e5f0aaSAlex Zinenko       // `i` may overflow subViewOp.getMixedSizes because of trailing semantics.
174475e5f0aaSAlex Zinenko       // In this case, the size is guaranteed to be interpreted as Dim and the
174575e5f0aaSAlex Zinenko       // stride as 1.
174675e5f0aaSAlex Zinenko       Value size, stride;
174775e5f0aaSAlex Zinenko       if (static_cast<unsigned>(i) >= mixedSizes.size()) {
174875e5f0aaSAlex Zinenko         // If the static size is available, use it directly. This is similar to
174975e5f0aaSAlex Zinenko         // the folding of dim(constant-op) but removes the need for dim to be
175075e5f0aaSAlex Zinenko         // aware of LLVM constants and for this pass to be aware of std
175175e5f0aaSAlex Zinenko         // constants.
175275e5f0aaSAlex Zinenko         int64_t staticSize =
1753136d746eSJacques Pienaar             subViewOp.getSource().getType().cast<MemRefType>().getShape()[i];
175475e5f0aaSAlex Zinenko         if (staticSize != ShapedType::kDynamicSize) {
175575e5f0aaSAlex Zinenko           size = rewriter.create<LLVM::ConstantOp>(
175675e5f0aaSAlex Zinenko               loc, llvmIndexType, rewriter.getI64IntegerAttr(staticSize));
175775e5f0aaSAlex Zinenko         } else {
175875e5f0aaSAlex Zinenko           Value pos = rewriter.create<LLVM::ConstantOp>(
175975e5f0aaSAlex Zinenko               loc, llvmIndexType, rewriter.getI64IntegerAttr(i));
1760881dc34fSAlex Zinenko           Value dim =
1761136d746eSJacques Pienaar               rewriter.create<memref::DimOp>(loc, subViewOp.getSource(), pos);
1762881dc34fSAlex Zinenko           auto cast = rewriter.create<UnrealizedConversionCastOp>(
1763881dc34fSAlex Zinenko               loc, llvmIndexType, dim);
1764881dc34fSAlex Zinenko           size = cast.getResult(0);
176575e5f0aaSAlex Zinenko         }
176675e5f0aaSAlex Zinenko         stride = rewriter.create<LLVM::ConstantOp>(
176775e5f0aaSAlex Zinenko             loc, llvmIndexType, rewriter.getI64IntegerAttr(1));
176875e5f0aaSAlex Zinenko       } else {
176975e5f0aaSAlex Zinenko         // TODO: need OpFoldResult ODS adaptor to clean this up.
177075e5f0aaSAlex Zinenko         size =
177175e5f0aaSAlex Zinenko             subViewOp.isDynamicSize(i)
1772ef976337SRiver Riddle                 ? adaptor.getOperands()[subViewOp.getIndexOfDynamicSize(i)]
177375e5f0aaSAlex Zinenko                 : rewriter.create<LLVM::ConstantOp>(
177475e5f0aaSAlex Zinenko                       loc, llvmIndexType,
177575e5f0aaSAlex Zinenko                       rewriter.getI64IntegerAttr(subViewOp.getStaticSize(i)));
177675e5f0aaSAlex Zinenko         if (!ShapedType::isDynamicStrideOrOffset(strides[i])) {
177775e5f0aaSAlex Zinenko           stride = rewriter.create<LLVM::ConstantOp>(
177875e5f0aaSAlex Zinenko               loc, llvmIndexType, rewriter.getI64IntegerAttr(strides[i]));
177975e5f0aaSAlex Zinenko         } else {
1780ef976337SRiver Riddle           stride =
1781ef976337SRiver Riddle               subViewOp.isDynamicStride(i)
1782ef976337SRiver Riddle                   ? adaptor.getOperands()[subViewOp.getIndexOfDynamicStride(i)]
178375e5f0aaSAlex Zinenko                   : rewriter.create<LLVM::ConstantOp>(
178475e5f0aaSAlex Zinenko                         loc, llvmIndexType,
178575e5f0aaSAlex Zinenko                         rewriter.getI64IntegerAttr(
178675e5f0aaSAlex Zinenko                             subViewOp.getStaticStride(i)));
178775e5f0aaSAlex Zinenko           stride = rewriter.create<LLVM::MulOp>(loc, stride, strideValues[i]);
178875e5f0aaSAlex Zinenko         }
178975e5f0aaSAlex Zinenko       }
179075e5f0aaSAlex Zinenko       targetMemRef.setSize(rewriter, loc, j, size);
179175e5f0aaSAlex Zinenko       targetMemRef.setStride(rewriter, loc, j, stride);
179275e5f0aaSAlex Zinenko       j--;
179375e5f0aaSAlex Zinenko     }
179475e5f0aaSAlex Zinenko 
179575e5f0aaSAlex Zinenko     rewriter.replaceOp(subViewOp, {targetMemRef});
179675e5f0aaSAlex Zinenko     return success();
179775e5f0aaSAlex Zinenko   }
179875e5f0aaSAlex Zinenko };
179975e5f0aaSAlex Zinenko 
180075e5f0aaSAlex Zinenko /// Conversion pattern that transforms a transpose op into:
180175e5f0aaSAlex Zinenko ///   1. A function entry `alloca` operation to allocate a ViewDescriptor.
180275e5f0aaSAlex Zinenko ///   2. A load of the ViewDescriptor from the pointer allocated in 1.
180375e5f0aaSAlex Zinenko ///   3. Updates to the ViewDescriptor to introduce the data ptr, offset, size
180475e5f0aaSAlex Zinenko ///      and stride. Size and stride are permutations of the original values.
180575e5f0aaSAlex Zinenko ///   4. A store of the resulting ViewDescriptor to the alloca'ed pointer.
180675e5f0aaSAlex Zinenko /// The transpose op is replaced by the alloca'ed pointer.
180775e5f0aaSAlex Zinenko class TransposeOpLowering : public ConvertOpToLLVMPattern<memref::TransposeOp> {
180875e5f0aaSAlex Zinenko public:
180975e5f0aaSAlex Zinenko   using ConvertOpToLLVMPattern<memref::TransposeOp>::ConvertOpToLLVMPattern;
181075e5f0aaSAlex Zinenko 
181175e5f0aaSAlex Zinenko   LogicalResult
matchAndRewrite(memref::TransposeOp transposeOp,OpAdaptor adaptor,ConversionPatternRewriter & rewriter) const1812ef976337SRiver Riddle   matchAndRewrite(memref::TransposeOp transposeOp, OpAdaptor adaptor,
181375e5f0aaSAlex Zinenko                   ConversionPatternRewriter &rewriter) const override {
181475e5f0aaSAlex Zinenko     auto loc = transposeOp.getLoc();
1815136d746eSJacques Pienaar     MemRefDescriptor viewMemRef(adaptor.getIn());
181675e5f0aaSAlex Zinenko 
181775e5f0aaSAlex Zinenko     // No permutation, early exit.
1818136d746eSJacques Pienaar     if (transposeOp.getPermutation().isIdentity())
181975e5f0aaSAlex Zinenko       return rewriter.replaceOp(transposeOp, {viewMemRef}), success();
182075e5f0aaSAlex Zinenko 
182175e5f0aaSAlex Zinenko     auto targetMemRef = MemRefDescriptor::undef(
182275e5f0aaSAlex Zinenko         rewriter, loc, typeConverter->convertType(transposeOp.getShapedType()));
182375e5f0aaSAlex Zinenko 
182475e5f0aaSAlex Zinenko     // Copy the base and aligned pointers from the old descriptor to the new
182575e5f0aaSAlex Zinenko     // one.
182675e5f0aaSAlex Zinenko     targetMemRef.setAllocatedPtr(rewriter, loc,
182775e5f0aaSAlex Zinenko                                  viewMemRef.allocatedPtr(rewriter, loc));
182875e5f0aaSAlex Zinenko     targetMemRef.setAlignedPtr(rewriter, loc,
182975e5f0aaSAlex Zinenko                                viewMemRef.alignedPtr(rewriter, loc));
183075e5f0aaSAlex Zinenko 
183175e5f0aaSAlex Zinenko     // Copy the offset pointer from the old descriptor to the new one.
183275e5f0aaSAlex Zinenko     targetMemRef.setOffset(rewriter, loc, viewMemRef.offset(rewriter, loc));
183375e5f0aaSAlex Zinenko 
183475e5f0aaSAlex Zinenko     // Iterate over the dimensions and apply size/stride permutation.
1835e4853be2SMehdi Amini     for (const auto &en :
1836136d746eSJacques Pienaar          llvm::enumerate(transposeOp.getPermutation().getResults())) {
183775e5f0aaSAlex Zinenko       int sourcePos = en.index();
183875e5f0aaSAlex Zinenko       int targetPos = en.value().cast<AffineDimExpr>().getPosition();
183975e5f0aaSAlex Zinenko       targetMemRef.setSize(rewriter, loc, targetPos,
184075e5f0aaSAlex Zinenko                            viewMemRef.size(rewriter, loc, sourcePos));
184175e5f0aaSAlex Zinenko       targetMemRef.setStride(rewriter, loc, targetPos,
184275e5f0aaSAlex Zinenko                              viewMemRef.stride(rewriter, loc, sourcePos));
184375e5f0aaSAlex Zinenko     }
184475e5f0aaSAlex Zinenko 
184575e5f0aaSAlex Zinenko     rewriter.replaceOp(transposeOp, {targetMemRef});
184675e5f0aaSAlex Zinenko     return success();
184775e5f0aaSAlex Zinenko   }
184875e5f0aaSAlex Zinenko };
184975e5f0aaSAlex Zinenko 
185075e5f0aaSAlex Zinenko /// Conversion pattern that transforms an op into:
185175e5f0aaSAlex Zinenko ///   1. An `llvm.mlir.undef` operation to create a memref descriptor
185275e5f0aaSAlex Zinenko ///   2. Updates to the descriptor to introduce the data ptr, offset, size
185375e5f0aaSAlex Zinenko ///      and stride.
185475e5f0aaSAlex Zinenko /// The view op is replaced by the descriptor.
185575e5f0aaSAlex Zinenko struct ViewOpLowering : public ConvertOpToLLVMPattern<memref::ViewOp> {
185675e5f0aaSAlex Zinenko   using ConvertOpToLLVMPattern<memref::ViewOp>::ConvertOpToLLVMPattern;
185775e5f0aaSAlex Zinenko 
185875e5f0aaSAlex Zinenko   // Build and return the value for the idx^th shape dimension, either by
185975e5f0aaSAlex Zinenko   // returning the constant shape dimension or counting the proper dynamic size.
getSize__anon7a9e10510111::ViewOpLowering186075e5f0aaSAlex Zinenko   Value getSize(ConversionPatternRewriter &rewriter, Location loc,
186175e5f0aaSAlex Zinenko                 ArrayRef<int64_t> shape, ValueRange dynamicSizes,
186275e5f0aaSAlex Zinenko                 unsigned idx) const {
186375e5f0aaSAlex Zinenko     assert(idx < shape.size());
186475e5f0aaSAlex Zinenko     if (!ShapedType::isDynamic(shape[idx]))
186575e5f0aaSAlex Zinenko       return createIndexConstant(rewriter, loc, shape[idx]);
186675e5f0aaSAlex Zinenko     // Count the number of dynamic dims in range [0, idx]
1867380a1b20SKazu Hirata     unsigned nDynamic =
1868380a1b20SKazu Hirata         llvm::count_if(shape.take_front(idx), ShapedType::isDynamic);
186975e5f0aaSAlex Zinenko     return dynamicSizes[nDynamic];
187075e5f0aaSAlex Zinenko   }
187175e5f0aaSAlex Zinenko 
187275e5f0aaSAlex Zinenko   // Build and return the idx^th stride, either by returning the constant stride
187375e5f0aaSAlex Zinenko   // or by computing the dynamic stride from the current `runningStride` and
187475e5f0aaSAlex Zinenko   // `nextSize`. The caller should keep a running stride and update it with the
187575e5f0aaSAlex Zinenko   // result returned by this function.
getStride__anon7a9e10510111::ViewOpLowering187675e5f0aaSAlex Zinenko   Value getStride(ConversionPatternRewriter &rewriter, Location loc,
187775e5f0aaSAlex Zinenko                   ArrayRef<int64_t> strides, Value nextSize,
187875e5f0aaSAlex Zinenko                   Value runningStride, unsigned idx) const {
187975e5f0aaSAlex Zinenko     assert(idx < strides.size());
1880676bfb2aSRiver Riddle     if (!ShapedType::isDynamicStrideOrOffset(strides[idx]))
188175e5f0aaSAlex Zinenko       return createIndexConstant(rewriter, loc, strides[idx]);
188275e5f0aaSAlex Zinenko     if (nextSize)
188375e5f0aaSAlex Zinenko       return runningStride
188475e5f0aaSAlex Zinenko                  ? rewriter.create<LLVM::MulOp>(loc, runningStride, nextSize)
188575e5f0aaSAlex Zinenko                  : nextSize;
188675e5f0aaSAlex Zinenko     assert(!runningStride);
188775e5f0aaSAlex Zinenko     return createIndexConstant(rewriter, loc, 1);
188875e5f0aaSAlex Zinenko   }
188975e5f0aaSAlex Zinenko 
189075e5f0aaSAlex Zinenko   LogicalResult
matchAndRewrite__anon7a9e10510111::ViewOpLowering1891ef976337SRiver Riddle   matchAndRewrite(memref::ViewOp viewOp, OpAdaptor adaptor,
189275e5f0aaSAlex Zinenko                   ConversionPatternRewriter &rewriter) const override {
189375e5f0aaSAlex Zinenko     auto loc = viewOp.getLoc();
189475e5f0aaSAlex Zinenko 
189575e5f0aaSAlex Zinenko     auto viewMemRefType = viewOp.getType();
189675e5f0aaSAlex Zinenko     auto targetElementTy =
189775e5f0aaSAlex Zinenko         typeConverter->convertType(viewMemRefType.getElementType());
189875e5f0aaSAlex Zinenko     auto targetDescTy = typeConverter->convertType(viewMemRefType);
189975e5f0aaSAlex Zinenko     if (!targetDescTy || !targetElementTy ||
190075e5f0aaSAlex Zinenko         !LLVM::isCompatibleType(targetElementTy) ||
190175e5f0aaSAlex Zinenko         !LLVM::isCompatibleType(targetDescTy))
190275e5f0aaSAlex Zinenko       return viewOp.emitWarning("Target descriptor type not converted to LLVM"),
190375e5f0aaSAlex Zinenko              failure();
190475e5f0aaSAlex Zinenko 
190575e5f0aaSAlex Zinenko     int64_t offset;
190675e5f0aaSAlex Zinenko     SmallVector<int64_t, 4> strides;
190775e5f0aaSAlex Zinenko     auto successStrides = getStridesAndOffset(viewMemRefType, strides, offset);
190875e5f0aaSAlex Zinenko     if (failed(successStrides))
190975e5f0aaSAlex Zinenko       return viewOp.emitWarning("cannot cast to non-strided shape"), failure();
191075e5f0aaSAlex Zinenko     assert(offset == 0 && "expected offset to be 0");
191175e5f0aaSAlex Zinenko 
1912705f048cSEugene Zhulenev     // Target memref must be contiguous in memory (innermost stride is 1), or
1913705f048cSEugene Zhulenev     // empty (special case when at least one of the memref dimensions is 0).
1914705f048cSEugene Zhulenev     if (!strides.empty() && (strides.back() != 1 && strides.back() != 0))
1915705f048cSEugene Zhulenev       return viewOp.emitWarning("cannot cast to non-contiguous shape"),
1916705f048cSEugene Zhulenev              failure();
1917705f048cSEugene Zhulenev 
191875e5f0aaSAlex Zinenko     // Create the descriptor.
1919136d746eSJacques Pienaar     MemRefDescriptor sourceMemRef(adaptor.getSource());
192075e5f0aaSAlex Zinenko     auto targetMemRef = MemRefDescriptor::undef(rewriter, loc, targetDescTy);
192175e5f0aaSAlex Zinenko 
192275e5f0aaSAlex Zinenko     // Field 1: Copy the allocated pointer, used for malloc/free.
192375e5f0aaSAlex Zinenko     Value allocatedPtr = sourceMemRef.allocatedPtr(rewriter, loc);
1924136d746eSJacques Pienaar     auto srcMemRefType = viewOp.getSource().getType().cast<MemRefType>();
192575e5f0aaSAlex Zinenko     Value bitcastPtr = rewriter.create<LLVM::BitcastOp>(
192675e5f0aaSAlex Zinenko         loc,
192775e5f0aaSAlex Zinenko         LLVM::LLVMPointerType::get(targetElementTy,
192875e5f0aaSAlex Zinenko                                    srcMemRefType.getMemorySpaceAsInt()),
192975e5f0aaSAlex Zinenko         allocatedPtr);
193075e5f0aaSAlex Zinenko     targetMemRef.setAllocatedPtr(rewriter, loc, bitcastPtr);
193175e5f0aaSAlex Zinenko 
193275e5f0aaSAlex Zinenko     // Field 2: Copy the actual aligned pointer to payload.
193375e5f0aaSAlex Zinenko     Value alignedPtr = sourceMemRef.alignedPtr(rewriter, loc);
1934136d746eSJacques Pienaar     alignedPtr = rewriter.create<LLVM::GEPOp>(
1935136d746eSJacques Pienaar         loc, alignedPtr.getType(), alignedPtr, adaptor.getByteShift());
193675e5f0aaSAlex Zinenko     bitcastPtr = rewriter.create<LLVM::BitcastOp>(
193775e5f0aaSAlex Zinenko         loc,
193875e5f0aaSAlex Zinenko         LLVM::LLVMPointerType::get(targetElementTy,
193975e5f0aaSAlex Zinenko                                    srcMemRefType.getMemorySpaceAsInt()),
194075e5f0aaSAlex Zinenko         alignedPtr);
194175e5f0aaSAlex Zinenko     targetMemRef.setAlignedPtr(rewriter, loc, bitcastPtr);
194275e5f0aaSAlex Zinenko 
194375e5f0aaSAlex Zinenko     // Field 3: The offset in the resulting type must be 0. This is because of
194475e5f0aaSAlex Zinenko     // the type change: an offset on srcType* may not be expressible as an
194575e5f0aaSAlex Zinenko     // offset on dstType*.
194675e5f0aaSAlex Zinenko     targetMemRef.setOffset(rewriter, loc,
194775e5f0aaSAlex Zinenko                            createIndexConstant(rewriter, loc, offset));
194875e5f0aaSAlex Zinenko 
194975e5f0aaSAlex Zinenko     // Early exit for 0-D corner case.
195075e5f0aaSAlex Zinenko     if (viewMemRefType.getRank() == 0)
195175e5f0aaSAlex Zinenko       return rewriter.replaceOp(viewOp, {targetMemRef}), success();
195275e5f0aaSAlex Zinenko 
195375e5f0aaSAlex Zinenko     // Fields 4 and 5: Update sizes and strides.
195475e5f0aaSAlex Zinenko     Value stride = nullptr, nextSize = nullptr;
195575e5f0aaSAlex Zinenko     for (int i = viewMemRefType.getRank() - 1; i >= 0; --i) {
195675e5f0aaSAlex Zinenko       // Update size.
1957136d746eSJacques Pienaar       Value size = getSize(rewriter, loc, viewMemRefType.getShape(),
1958136d746eSJacques Pienaar                            adaptor.getSizes(), i);
195975e5f0aaSAlex Zinenko       targetMemRef.setSize(rewriter, loc, i, size);
196075e5f0aaSAlex Zinenko       // Update stride.
196175e5f0aaSAlex Zinenko       stride = getStride(rewriter, loc, strides, nextSize, stride, i);
196275e5f0aaSAlex Zinenko       targetMemRef.setStride(rewriter, loc, i, stride);
196375e5f0aaSAlex Zinenko       nextSize = size;
196475e5f0aaSAlex Zinenko     }
196575e5f0aaSAlex Zinenko 
196675e5f0aaSAlex Zinenko     rewriter.replaceOp(viewOp, {targetMemRef});
196775e5f0aaSAlex Zinenko     return success();
196875e5f0aaSAlex Zinenko   }
196975e5f0aaSAlex Zinenko };
197075e5f0aaSAlex Zinenko 
1971a6a583daSWilliam S. Moses //===----------------------------------------------------------------------===//
1972a6a583daSWilliam S. Moses // AtomicRMWOpLowering
1973a6a583daSWilliam S. Moses //===----------------------------------------------------------------------===//
1974a6a583daSWilliam S. Moses 
197523aa5a74SRiver Riddle /// Try to match the kind of a memref.atomic_rmw to determine whether to use a
1976a6a583daSWilliam S. Moses /// lowering to llvm.atomicrmw or fallback to llvm.cmpxchg.
1977a6a583daSWilliam S. Moses static Optional<LLVM::AtomicBinOp>
matchSimpleAtomicOp(memref::AtomicRMWOp atomicOp)1978a6a583daSWilliam S. Moses matchSimpleAtomicOp(memref::AtomicRMWOp atomicOp) {
1979136d746eSJacques Pienaar   switch (atomicOp.getKind()) {
1980a6a583daSWilliam S. Moses   case arith::AtomicRMWKind::addf:
1981a6a583daSWilliam S. Moses     return LLVM::AtomicBinOp::fadd;
1982a6a583daSWilliam S. Moses   case arith::AtomicRMWKind::addi:
1983a6a583daSWilliam S. Moses     return LLVM::AtomicBinOp::add;
1984a6a583daSWilliam S. Moses   case arith::AtomicRMWKind::assign:
1985a6a583daSWilliam S. Moses     return LLVM::AtomicBinOp::xchg;
1986a6a583daSWilliam S. Moses   case arith::AtomicRMWKind::maxs:
1987a6a583daSWilliam S. Moses     return LLVM::AtomicBinOp::max;
1988a6a583daSWilliam S. Moses   case arith::AtomicRMWKind::maxu:
1989a6a583daSWilliam S. Moses     return LLVM::AtomicBinOp::umax;
1990a6a583daSWilliam S. Moses   case arith::AtomicRMWKind::mins:
1991a6a583daSWilliam S. Moses     return LLVM::AtomicBinOp::min;
1992a6a583daSWilliam S. Moses   case arith::AtomicRMWKind::minu:
1993a6a583daSWilliam S. Moses     return LLVM::AtomicBinOp::umin;
1994a6a583daSWilliam S. Moses   case arith::AtomicRMWKind::ori:
1995a6a583daSWilliam S. Moses     return LLVM::AtomicBinOp::_or;
1996a6a583daSWilliam S. Moses   case arith::AtomicRMWKind::andi:
1997a6a583daSWilliam S. Moses     return LLVM::AtomicBinOp::_and;
1998a6a583daSWilliam S. Moses   default:
1999a6a583daSWilliam S. Moses     return llvm::None;
2000a6a583daSWilliam S. Moses   }
2001a6a583daSWilliam S. Moses   llvm_unreachable("Invalid AtomicRMWKind");
2002a6a583daSWilliam S. Moses }
2003a6a583daSWilliam S. Moses 
2004a6a583daSWilliam S. Moses struct AtomicRMWOpLowering : public LoadStoreOpLowering<memref::AtomicRMWOp> {
2005a6a583daSWilliam S. Moses   using Base::Base;
2006a6a583daSWilliam S. Moses 
2007a6a583daSWilliam S. Moses   LogicalResult
matchAndRewrite__anon7a9e10510111::AtomicRMWOpLowering2008a6a583daSWilliam S. Moses   matchAndRewrite(memref::AtomicRMWOp atomicOp, OpAdaptor adaptor,
2009a6a583daSWilliam S. Moses                   ConversionPatternRewriter &rewriter) const override {
2010a6a583daSWilliam S. Moses     if (failed(match(atomicOp)))
2011a6a583daSWilliam S. Moses       return failure();
2012a6a583daSWilliam S. Moses     auto maybeKind = matchSimpleAtomicOp(atomicOp);
2013a6a583daSWilliam S. Moses     if (!maybeKind)
2014a6a583daSWilliam S. Moses       return failure();
2015136d746eSJacques Pienaar     auto resultType = adaptor.getValue().getType();
2016a6a583daSWilliam S. Moses     auto memRefType = atomicOp.getMemRefType();
2017a6a583daSWilliam S. Moses     auto dataPtr =
2018136d746eSJacques Pienaar         getStridedElementPtr(atomicOp.getLoc(), memRefType, adaptor.getMemref(),
2019136d746eSJacques Pienaar                              adaptor.getIndices(), rewriter);
2020a6a583daSWilliam S. Moses     rewriter.replaceOpWithNewOp<LLVM::AtomicRMWOp>(
2021136d746eSJacques Pienaar         atomicOp, resultType, *maybeKind, dataPtr, adaptor.getValue(),
2022a6a583daSWilliam S. Moses         LLVM::AtomicOrdering::acq_rel);
2023a6a583daSWilliam S. Moses     return success();
2024a6a583daSWilliam S. Moses   }
2025a6a583daSWilliam S. Moses };
2026a6a583daSWilliam S. Moses 
202775e5f0aaSAlex Zinenko } // namespace
202875e5f0aaSAlex Zinenko 
populateMemRefToLLVMConversionPatterns(LLVMTypeConverter & converter,RewritePatternSet & patterns)202975e5f0aaSAlex Zinenko void mlir::populateMemRefToLLVMConversionPatterns(LLVMTypeConverter &converter,
203075e5f0aaSAlex Zinenko                                                   RewritePatternSet &patterns) {
203175e5f0aaSAlex Zinenko   // clang-format off
203275e5f0aaSAlex Zinenko   patterns.add<
203375e5f0aaSAlex Zinenko       AllocaOpLowering,
203475e5f0aaSAlex Zinenko       AllocaScopeOpLowering,
2035a6a583daSWilliam S. Moses       AtomicRMWOpLowering,
203675e5f0aaSAlex Zinenko       AssumeAlignmentOpLowering,
203775e5f0aaSAlex Zinenko       DimOpLowering,
2038632a4f88SRiver Riddle       GenericAtomicRMWOpLowering,
203975e5f0aaSAlex Zinenko       GlobalMemrefOpLowering,
204075e5f0aaSAlex Zinenko       GetGlobalMemrefOpLowering,
204175e5f0aaSAlex Zinenko       LoadOpLowering,
204275e5f0aaSAlex Zinenko       MemRefCastOpLowering,
204375e5f0aaSAlex Zinenko       MemRefCopyOpLowering,
204475e5f0aaSAlex Zinenko       MemRefReinterpretCastOpLowering,
204575e5f0aaSAlex Zinenko       MemRefReshapeOpLowering,
204675e5f0aaSAlex Zinenko       PrefetchOpLowering,
204715f8f3e2SAlexander Belyaev       RankOpLowering,
204846ef86b5SAlexander Belyaev       ReassociatingReshapeOpConversion<memref::ExpandShapeOp>,
204946ef86b5SAlexander Belyaev       ReassociatingReshapeOpConversion<memref::CollapseShapeOp>,
205075e5f0aaSAlex Zinenko       StoreOpLowering,
205175e5f0aaSAlex Zinenko       SubViewOpLowering,
205275e5f0aaSAlex Zinenko       TransposeOpLowering,
205375e5f0aaSAlex Zinenko       ViewOpLowering>(converter);
205475e5f0aaSAlex Zinenko   // clang-format on
205575e5f0aaSAlex Zinenko   auto allocLowering = converter.getOptions().allocLowering;
205675e5f0aaSAlex Zinenko   if (allocLowering == LowerToLLVMOptions::AllocLowering::AlignedAlloc)
2057d04c2b2fSMehdi Amini     patterns.add<AlignedAllocOpLowering, DeallocOpLowering>(converter);
205875e5f0aaSAlex Zinenko   else if (allocLowering == LowerToLLVMOptions::AllocLowering::Malloc)
205975e5f0aaSAlex Zinenko     patterns.add<AllocOpLowering, DeallocOpLowering>(converter);
206075e5f0aaSAlex Zinenko }
206175e5f0aaSAlex Zinenko 
206275e5f0aaSAlex Zinenko namespace {
206375e5f0aaSAlex Zinenko struct MemRefToLLVMPass : public ConvertMemRefToLLVMBase<MemRefToLLVMPass> {
206475e5f0aaSAlex Zinenko   MemRefToLLVMPass() = default;
206575e5f0aaSAlex Zinenko 
runOnOperation__anon7a9e10510811::MemRefToLLVMPass206675e5f0aaSAlex Zinenko   void runOnOperation() override {
206775e5f0aaSAlex Zinenko     Operation *op = getOperation();
206875e5f0aaSAlex Zinenko     const auto &dataLayoutAnalysis = getAnalysis<DataLayoutAnalysis>();
206975e5f0aaSAlex Zinenko     LowerToLLVMOptions options(&getContext(),
207075e5f0aaSAlex Zinenko                                dataLayoutAnalysis.getAtOrAbove(op));
207175e5f0aaSAlex Zinenko     options.allocLowering =
207275e5f0aaSAlex Zinenko         (useAlignedAlloc ? LowerToLLVMOptions::AllocLowering::AlignedAlloc
207375e5f0aaSAlex Zinenko                          : LowerToLLVMOptions::AllocLowering::Malloc);
2074a8601f11SMichele Scuttari 
2075a8601f11SMichele Scuttari     options.useGenericFunctions = useGenericFunctions;
2076a8601f11SMichele Scuttari 
207775e5f0aaSAlex Zinenko     if (indexBitwidth != kDeriveIndexBitwidthFromDataLayout)
207875e5f0aaSAlex Zinenko       options.overrideIndexBitwidth(indexBitwidth);
207975e5f0aaSAlex Zinenko 
208075e5f0aaSAlex Zinenko     LLVMTypeConverter typeConverter(&getContext(), options,
208175e5f0aaSAlex Zinenko                                     &dataLayoutAnalysis);
208275e5f0aaSAlex Zinenko     RewritePatternSet patterns(&getContext());
208375e5f0aaSAlex Zinenko     populateMemRefToLLVMConversionPatterns(typeConverter, patterns);
208475e5f0aaSAlex Zinenko     LLVMConversionTarget target(getContext());
208558ceae95SRiver Riddle     target.addLegalOp<func::FuncOp>();
208675e5f0aaSAlex Zinenko     if (failed(applyPartialConversion(op, target, std::move(patterns))))
208775e5f0aaSAlex Zinenko       signalPassFailure();
208875e5f0aaSAlex Zinenko   }
208975e5f0aaSAlex Zinenko };
209075e5f0aaSAlex Zinenko } // namespace
209175e5f0aaSAlex Zinenko 
createMemRefToLLVMPass()209275e5f0aaSAlex Zinenko std::unique_ptr<Pass> mlir::createMemRefToLLVMPass() {
209375e5f0aaSAlex Zinenko   return std::make_unique<MemRefToLLVMPass>();
209475e5f0aaSAlex Zinenko }
2095