1*75e5f0aaSAlex Zinenko //===- AllocLikeConversion.cpp - LLVM conversion for alloc operations -----===//
2*75e5f0aaSAlex Zinenko //
3*75e5f0aaSAlex Zinenko // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4*75e5f0aaSAlex Zinenko // See https://llvm.org/LICENSE.txt for license information.
5*75e5f0aaSAlex Zinenko // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6*75e5f0aaSAlex Zinenko //
7*75e5f0aaSAlex Zinenko //===----------------------------------------------------------------------===//
8*75e5f0aaSAlex Zinenko
9*75e5f0aaSAlex Zinenko #include "mlir/Conversion/MemRefToLLVM/AllocLikeConversion.h"
10*75e5f0aaSAlex Zinenko #include "mlir/Dialect/LLVMIR/LLVMDialect.h"
11*75e5f0aaSAlex Zinenko
12*75e5f0aaSAlex Zinenko using namespace mlir;
13*75e5f0aaSAlex Zinenko
createAligned(ConversionPatternRewriter & rewriter,Location loc,Value input,Value alignment)14*75e5f0aaSAlex Zinenko Value AllocLikeOpLLVMLowering::createAligned(
15*75e5f0aaSAlex Zinenko ConversionPatternRewriter &rewriter, Location loc, Value input,
16*75e5f0aaSAlex Zinenko Value alignment) {
17*75e5f0aaSAlex Zinenko Value one = createIndexAttrConstant(rewriter, loc, alignment.getType(), 1);
18*75e5f0aaSAlex Zinenko Value bump = rewriter.create<LLVM::SubOp>(loc, alignment, one);
19*75e5f0aaSAlex Zinenko Value bumped = rewriter.create<LLVM::AddOp>(loc, input, bump);
20*75e5f0aaSAlex Zinenko Value mod = rewriter.create<LLVM::URemOp>(loc, bumped, alignment);
21*75e5f0aaSAlex Zinenko return rewriter.create<LLVM::SubOp>(loc, bumped, mod);
22*75e5f0aaSAlex Zinenko }
23*75e5f0aaSAlex Zinenko
matchAndRewrite(Operation * op,ArrayRef<Value> operands,ConversionPatternRewriter & rewriter) const24*75e5f0aaSAlex Zinenko LogicalResult AllocLikeOpLLVMLowering::matchAndRewrite(
25*75e5f0aaSAlex Zinenko Operation *op, ArrayRef<Value> operands,
26*75e5f0aaSAlex Zinenko ConversionPatternRewriter &rewriter) const {
27*75e5f0aaSAlex Zinenko MemRefType memRefType = getMemRefResultType(op);
28*75e5f0aaSAlex Zinenko if (!isConvertibleAndHasIdentityMaps(memRefType))
29*75e5f0aaSAlex Zinenko return rewriter.notifyMatchFailure(op, "incompatible memref type");
30*75e5f0aaSAlex Zinenko auto loc = op->getLoc();
31*75e5f0aaSAlex Zinenko
32*75e5f0aaSAlex Zinenko // Get actual sizes of the memref as values: static sizes are constant
33*75e5f0aaSAlex Zinenko // values and dynamic sizes are passed to 'alloc' as operands. In case of
34*75e5f0aaSAlex Zinenko // zero-dimensional memref, assume a scalar (size 1).
35*75e5f0aaSAlex Zinenko SmallVector<Value, 4> sizes;
36*75e5f0aaSAlex Zinenko SmallVector<Value, 4> strides;
37*75e5f0aaSAlex Zinenko Value sizeBytes;
38*75e5f0aaSAlex Zinenko this->getMemRefDescriptorSizes(loc, memRefType, operands, rewriter, sizes,
39*75e5f0aaSAlex Zinenko strides, sizeBytes);
40*75e5f0aaSAlex Zinenko
41*75e5f0aaSAlex Zinenko // Allocate the underlying buffer.
42*75e5f0aaSAlex Zinenko Value allocatedPtr;
43*75e5f0aaSAlex Zinenko Value alignedPtr;
44*75e5f0aaSAlex Zinenko std::tie(allocatedPtr, alignedPtr) =
45*75e5f0aaSAlex Zinenko this->allocateBuffer(rewriter, loc, sizeBytes, op);
46*75e5f0aaSAlex Zinenko
47*75e5f0aaSAlex Zinenko // Create the MemRef descriptor.
48*75e5f0aaSAlex Zinenko auto memRefDescriptor = this->createMemRefDescriptor(
49*75e5f0aaSAlex Zinenko loc, memRefType, allocatedPtr, alignedPtr, sizes, strides, rewriter);
50*75e5f0aaSAlex Zinenko
51*75e5f0aaSAlex Zinenko // Return the final value of the descriptor.
52*75e5f0aaSAlex Zinenko rewriter.replaceOp(op, {memRefDescriptor});
53*75e5f0aaSAlex Zinenko return success();
54*75e5f0aaSAlex Zinenko }
55