1 //===- AllocLikeConversion.cpp - LLVM conversion for alloc operations -----===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 
9 #include "mlir/Conversion/MemRefToLLVM/AllocLikeConversion.h"
10 #include "mlir/Dialect/LLVMIR/LLVMDialect.h"
11 
12 using namespace mlir;
13 
createAligned(ConversionPatternRewriter & rewriter,Location loc,Value input,Value alignment)14 Value AllocLikeOpLLVMLowering::createAligned(
15     ConversionPatternRewriter &rewriter, Location loc, Value input,
16     Value alignment) {
17   Value one = createIndexAttrConstant(rewriter, loc, alignment.getType(), 1);
18   Value bump = rewriter.create<LLVM::SubOp>(loc, alignment, one);
19   Value bumped = rewriter.create<LLVM::AddOp>(loc, input, bump);
20   Value mod = rewriter.create<LLVM::URemOp>(loc, bumped, alignment);
21   return rewriter.create<LLVM::SubOp>(loc, bumped, mod);
22 }
23 
matchAndRewrite(Operation * op,ArrayRef<Value> operands,ConversionPatternRewriter & rewriter) const24 LogicalResult AllocLikeOpLLVMLowering::matchAndRewrite(
25     Operation *op, ArrayRef<Value> operands,
26     ConversionPatternRewriter &rewriter) const {
27   MemRefType memRefType = getMemRefResultType(op);
28   if (!isConvertibleAndHasIdentityMaps(memRefType))
29     return rewriter.notifyMatchFailure(op, "incompatible memref type");
30   auto loc = op->getLoc();
31 
32   // Get actual sizes of the memref as values: static sizes are constant
33   // values and dynamic sizes are passed to 'alloc' as operands.  In case of
34   // zero-dimensional memref, assume a scalar (size 1).
35   SmallVector<Value, 4> sizes;
36   SmallVector<Value, 4> strides;
37   Value sizeBytes;
38   this->getMemRefDescriptorSizes(loc, memRefType, operands, rewriter, sizes,
39                                  strides, sizeBytes);
40 
41   // Allocate the underlying buffer.
42   Value allocatedPtr;
43   Value alignedPtr;
44   std::tie(allocatedPtr, alignedPtr) =
45       this->allocateBuffer(rewriter, loc, sizeBytes, op);
46 
47   // Create the MemRef descriptor.
48   auto memRefDescriptor = this->createMemRefDescriptor(
49       loc, memRefType, allocatedPtr, alignedPtr, sizes, strides, rewriter);
50 
51   // Return the final value of the descriptor.
52   rewriter.replaceOp(op, {memRefDescriptor});
53   return success();
54 }
55