1*75e5f0aaSAlex Zinenko //===- AllocLikeConversion.cpp - LLVM conversion for alloc operations -----===//
2*75e5f0aaSAlex Zinenko //
3*75e5f0aaSAlex Zinenko // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4*75e5f0aaSAlex Zinenko // See https://llvm.org/LICENSE.txt for license information.
5*75e5f0aaSAlex Zinenko // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6*75e5f0aaSAlex Zinenko //
7*75e5f0aaSAlex Zinenko //===----------------------------------------------------------------------===//
8*75e5f0aaSAlex Zinenko 
9*75e5f0aaSAlex Zinenko #include "mlir/Conversion/MemRefToLLVM/AllocLikeConversion.h"
10*75e5f0aaSAlex Zinenko #include "mlir/Dialect/LLVMIR/LLVMDialect.h"
11*75e5f0aaSAlex Zinenko 
12*75e5f0aaSAlex Zinenko using namespace mlir;
13*75e5f0aaSAlex Zinenko 
createAligned(ConversionPatternRewriter & rewriter,Location loc,Value input,Value alignment)14*75e5f0aaSAlex Zinenko Value AllocLikeOpLLVMLowering::createAligned(
15*75e5f0aaSAlex Zinenko     ConversionPatternRewriter &rewriter, Location loc, Value input,
16*75e5f0aaSAlex Zinenko     Value alignment) {
17*75e5f0aaSAlex Zinenko   Value one = createIndexAttrConstant(rewriter, loc, alignment.getType(), 1);
18*75e5f0aaSAlex Zinenko   Value bump = rewriter.create<LLVM::SubOp>(loc, alignment, one);
19*75e5f0aaSAlex Zinenko   Value bumped = rewriter.create<LLVM::AddOp>(loc, input, bump);
20*75e5f0aaSAlex Zinenko   Value mod = rewriter.create<LLVM::URemOp>(loc, bumped, alignment);
21*75e5f0aaSAlex Zinenko   return rewriter.create<LLVM::SubOp>(loc, bumped, mod);
22*75e5f0aaSAlex Zinenko }
23*75e5f0aaSAlex Zinenko 
matchAndRewrite(Operation * op,ArrayRef<Value> operands,ConversionPatternRewriter & rewriter) const24*75e5f0aaSAlex Zinenko LogicalResult AllocLikeOpLLVMLowering::matchAndRewrite(
25*75e5f0aaSAlex Zinenko     Operation *op, ArrayRef<Value> operands,
26*75e5f0aaSAlex Zinenko     ConversionPatternRewriter &rewriter) const {
27*75e5f0aaSAlex Zinenko   MemRefType memRefType = getMemRefResultType(op);
28*75e5f0aaSAlex Zinenko   if (!isConvertibleAndHasIdentityMaps(memRefType))
29*75e5f0aaSAlex Zinenko     return rewriter.notifyMatchFailure(op, "incompatible memref type");
30*75e5f0aaSAlex Zinenko   auto loc = op->getLoc();
31*75e5f0aaSAlex Zinenko 
32*75e5f0aaSAlex Zinenko   // Get actual sizes of the memref as values: static sizes are constant
33*75e5f0aaSAlex Zinenko   // values and dynamic sizes are passed to 'alloc' as operands.  In case of
34*75e5f0aaSAlex Zinenko   // zero-dimensional memref, assume a scalar (size 1).
35*75e5f0aaSAlex Zinenko   SmallVector<Value, 4> sizes;
36*75e5f0aaSAlex Zinenko   SmallVector<Value, 4> strides;
37*75e5f0aaSAlex Zinenko   Value sizeBytes;
38*75e5f0aaSAlex Zinenko   this->getMemRefDescriptorSizes(loc, memRefType, operands, rewriter, sizes,
39*75e5f0aaSAlex Zinenko                                  strides, sizeBytes);
40*75e5f0aaSAlex Zinenko 
41*75e5f0aaSAlex Zinenko   // Allocate the underlying buffer.
42*75e5f0aaSAlex Zinenko   Value allocatedPtr;
43*75e5f0aaSAlex Zinenko   Value alignedPtr;
44*75e5f0aaSAlex Zinenko   std::tie(allocatedPtr, alignedPtr) =
45*75e5f0aaSAlex Zinenko       this->allocateBuffer(rewriter, loc, sizeBytes, op);
46*75e5f0aaSAlex Zinenko 
47*75e5f0aaSAlex Zinenko   // Create the MemRef descriptor.
48*75e5f0aaSAlex Zinenko   auto memRefDescriptor = this->createMemRefDescriptor(
49*75e5f0aaSAlex Zinenko       loc, memRefType, allocatedPtr, alignedPtr, sizes, strides, rewriter);
50*75e5f0aaSAlex Zinenko 
51*75e5f0aaSAlex Zinenko   // Return the final value of the descriptor.
52*75e5f0aaSAlex Zinenko   rewriter.replaceOp(op, {memRefDescriptor});
53*75e5f0aaSAlex Zinenko   return success();
54*75e5f0aaSAlex Zinenko }
55