1 //===- AllocLikeConversion.cpp - LLVM conversion for alloc operations -----===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8
9 #include "mlir/Conversion/MemRefToLLVM/AllocLikeConversion.h"
10 #include "mlir/Dialect/LLVMIR/LLVMDialect.h"
11
12 using namespace mlir;
13
createAligned(ConversionPatternRewriter & rewriter,Location loc,Value input,Value alignment)14 Value AllocLikeOpLLVMLowering::createAligned(
15 ConversionPatternRewriter &rewriter, Location loc, Value input,
16 Value alignment) {
17 Value one = createIndexAttrConstant(rewriter, loc, alignment.getType(), 1);
18 Value bump = rewriter.create<LLVM::SubOp>(loc, alignment, one);
19 Value bumped = rewriter.create<LLVM::AddOp>(loc, input, bump);
20 Value mod = rewriter.create<LLVM::URemOp>(loc, bumped, alignment);
21 return rewriter.create<LLVM::SubOp>(loc, bumped, mod);
22 }
23
matchAndRewrite(Operation * op,ArrayRef<Value> operands,ConversionPatternRewriter & rewriter) const24 LogicalResult AllocLikeOpLLVMLowering::matchAndRewrite(
25 Operation *op, ArrayRef<Value> operands,
26 ConversionPatternRewriter &rewriter) const {
27 MemRefType memRefType = getMemRefResultType(op);
28 if (!isConvertibleAndHasIdentityMaps(memRefType))
29 return rewriter.notifyMatchFailure(op, "incompatible memref type");
30 auto loc = op->getLoc();
31
32 // Get actual sizes of the memref as values: static sizes are constant
33 // values and dynamic sizes are passed to 'alloc' as operands. In case of
34 // zero-dimensional memref, assume a scalar (size 1).
35 SmallVector<Value, 4> sizes;
36 SmallVector<Value, 4> strides;
37 Value sizeBytes;
38 this->getMemRefDescriptorSizes(loc, memRefType, operands, rewriter, sizes,
39 strides, sizeBytes);
40
41 // Allocate the underlying buffer.
42 Value allocatedPtr;
43 Value alignedPtr;
44 std::tie(allocatedPtr, alignedPtr) =
45 this->allocateBuffer(rewriter, loc, sizeBytes, op);
46
47 // Create the MemRef descriptor.
48 auto memRefDescriptor = this->createMemRefDescriptor(
49 loc, memRefType, allocatedPtr, alignedPtr, sizes, strides, rewriter);
50
51 // Return the final value of the descriptor.
52 rewriter.replaceOp(op, {memRefDescriptor});
53 return success();
54 }
55