1a54f4eaeSMogball //===- ArithmeticToLLVM.cpp - Arithmetic to LLVM dialect conversion -------===//
2a54f4eaeSMogball //
3a54f4eaeSMogball // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4a54f4eaeSMogball // See https://llvm.org/LICENSE.txt for license information.
5a54f4eaeSMogball // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6a54f4eaeSMogball //
7a54f4eaeSMogball //===----------------------------------------------------------------------===//
8a54f4eaeSMogball
9a54f4eaeSMogball #include "mlir/Conversion/ArithmeticToLLVM/ArithmeticToLLVM.h"
10a54f4eaeSMogball #include "../PassDetail.h"
11a54f4eaeSMogball #include "mlir/Conversion/LLVMCommon/ConversionTarget.h"
12a54f4eaeSMogball #include "mlir/Conversion/LLVMCommon/VectorPattern.h"
13a54f4eaeSMogball #include "mlir/Dialect/Arithmetic/IR/Arithmetic.h"
14a54f4eaeSMogball #include "mlir/Dialect/LLVMIR/LLVMDialect.h"
15a54f4eaeSMogball #include "mlir/IR/TypeUtilities.h"
16a54f4eaeSMogball
17a54f4eaeSMogball using namespace mlir;
18a54f4eaeSMogball
19a54f4eaeSMogball namespace {
20a54f4eaeSMogball
21a54f4eaeSMogball //===----------------------------------------------------------------------===//
22a54f4eaeSMogball // Straightforward Op Lowerings
23a54f4eaeSMogball //===----------------------------------------------------------------------===//
24a54f4eaeSMogball
25a54f4eaeSMogball using AddIOpLowering = VectorConvertToLLVMPattern<arith::AddIOp, LLVM::AddOp>;
26a54f4eaeSMogball using SubIOpLowering = VectorConvertToLLVMPattern<arith::SubIOp, LLVM::SubOp>;
27a54f4eaeSMogball using MulIOpLowering = VectorConvertToLLVMPattern<arith::MulIOp, LLVM::MulOp>;
28a54f4eaeSMogball using DivUIOpLowering =
29a54f4eaeSMogball VectorConvertToLLVMPattern<arith::DivUIOp, LLVM::UDivOp>;
30a54f4eaeSMogball using DivSIOpLowering =
31a54f4eaeSMogball VectorConvertToLLVMPattern<arith::DivSIOp, LLVM::SDivOp>;
32a54f4eaeSMogball using RemUIOpLowering =
33a54f4eaeSMogball VectorConvertToLLVMPattern<arith::RemUIOp, LLVM::URemOp>;
34a54f4eaeSMogball using RemSIOpLowering =
35a54f4eaeSMogball VectorConvertToLLVMPattern<arith::RemSIOp, LLVM::SRemOp>;
36a54f4eaeSMogball using AndIOpLowering = VectorConvertToLLVMPattern<arith::AndIOp, LLVM::AndOp>;
37a54f4eaeSMogball using OrIOpLowering = VectorConvertToLLVMPattern<arith::OrIOp, LLVM::OrOp>;
38a54f4eaeSMogball using XOrIOpLowering = VectorConvertToLLVMPattern<arith::XOrIOp, LLVM::XOrOp>;
39a54f4eaeSMogball using ShLIOpLowering = VectorConvertToLLVMPattern<arith::ShLIOp, LLVM::ShlOp>;
40a54f4eaeSMogball using ShRUIOpLowering =
41a54f4eaeSMogball VectorConvertToLLVMPattern<arith::ShRUIOp, LLVM::LShrOp>;
42a54f4eaeSMogball using ShRSIOpLowering =
43a54f4eaeSMogball VectorConvertToLLVMPattern<arith::ShRSIOp, LLVM::AShrOp>;
44a54f4eaeSMogball using NegFOpLowering = VectorConvertToLLVMPattern<arith::NegFOp, LLVM::FNegOp>;
45a54f4eaeSMogball using AddFOpLowering = VectorConvertToLLVMPattern<arith::AddFOp, LLVM::FAddOp>;
46a54f4eaeSMogball using SubFOpLowering = VectorConvertToLLVMPattern<arith::SubFOp, LLVM::FSubOp>;
47a54f4eaeSMogball using MulFOpLowering = VectorConvertToLLVMPattern<arith::MulFOp, LLVM::FMulOp>;
48a54f4eaeSMogball using DivFOpLowering = VectorConvertToLLVMPattern<arith::DivFOp, LLVM::FDivOp>;
49a54f4eaeSMogball using RemFOpLowering = VectorConvertToLLVMPattern<arith::RemFOp, LLVM::FRemOp>;
50a54f4eaeSMogball using ExtUIOpLowering =
51a54f4eaeSMogball VectorConvertToLLVMPattern<arith::ExtUIOp, LLVM::ZExtOp>;
52a54f4eaeSMogball using ExtSIOpLowering =
53a54f4eaeSMogball VectorConvertToLLVMPattern<arith::ExtSIOp, LLVM::SExtOp>;
54a54f4eaeSMogball using ExtFOpLowering = VectorConvertToLLVMPattern<arith::ExtFOp, LLVM::FPExtOp>;
55a54f4eaeSMogball using TruncIOpLowering =
56a54f4eaeSMogball VectorConvertToLLVMPattern<arith::TruncIOp, LLVM::TruncOp>;
57a54f4eaeSMogball using TruncFOpLowering =
58a54f4eaeSMogball VectorConvertToLLVMPattern<arith::TruncFOp, LLVM::FPTruncOp>;
59a54f4eaeSMogball using UIToFPOpLowering =
60a54f4eaeSMogball VectorConvertToLLVMPattern<arith::UIToFPOp, LLVM::UIToFPOp>;
61a54f4eaeSMogball using SIToFPOpLowering =
62a54f4eaeSMogball VectorConvertToLLVMPattern<arith::SIToFPOp, LLVM::SIToFPOp>;
63a54f4eaeSMogball using FPToUIOpLowering =
64a54f4eaeSMogball VectorConvertToLLVMPattern<arith::FPToUIOp, LLVM::FPToUIOp>;
65a54f4eaeSMogball using FPToSIOpLowering =
66a54f4eaeSMogball VectorConvertToLLVMPattern<arith::FPToSIOp, LLVM::FPToSIOp>;
67a54f4eaeSMogball using BitcastOpLowering =
68a54f4eaeSMogball VectorConvertToLLVMPattern<arith::BitcastOp, LLVM::BitcastOp>;
69*dec8af70SRiver Riddle using SelectOpLowering =
70*dec8af70SRiver Riddle VectorConvertToLLVMPattern<arith::SelectOp, LLVM::SelectOp>;
71a54f4eaeSMogball
72a54f4eaeSMogball //===----------------------------------------------------------------------===//
73a54f4eaeSMogball // Op Lowering Patterns
74a54f4eaeSMogball //===----------------------------------------------------------------------===//
75a54f4eaeSMogball
76a54f4eaeSMogball /// Directly lower to LLVM op.
77a54f4eaeSMogball struct ConstantOpLowering : public ConvertOpToLLVMPattern<arith::ConstantOp> {
78a54f4eaeSMogball using ConvertOpToLLVMPattern<arith::ConstantOp>::ConvertOpToLLVMPattern;
79a54f4eaeSMogball
80a54f4eaeSMogball LogicalResult
81a54f4eaeSMogball matchAndRewrite(arith::ConstantOp op, OpAdaptor adaptor,
82a54f4eaeSMogball ConversionPatternRewriter &rewriter) const override;
83a54f4eaeSMogball };
84a54f4eaeSMogball
85a54f4eaeSMogball /// The lowering of index_cast becomes an integer conversion since index
86a54f4eaeSMogball /// becomes an integer. If the bit width of the source and target integer
87a54f4eaeSMogball /// types is the same, just erase the cast. If the target type is wider,
88a54f4eaeSMogball /// sign-extend the value, otherwise truncate it.
89a54f4eaeSMogball struct IndexCastOpLowering : public ConvertOpToLLVMPattern<arith::IndexCastOp> {
90a54f4eaeSMogball using ConvertOpToLLVMPattern<arith::IndexCastOp>::ConvertOpToLLVMPattern;
91a54f4eaeSMogball
92a54f4eaeSMogball LogicalResult
93a54f4eaeSMogball matchAndRewrite(arith::IndexCastOp op, OpAdaptor adaptor,
94a54f4eaeSMogball ConversionPatternRewriter &rewriter) const override;
95a54f4eaeSMogball };
96a54f4eaeSMogball
97a54f4eaeSMogball struct CmpIOpLowering : public ConvertOpToLLVMPattern<arith::CmpIOp> {
98a54f4eaeSMogball using ConvertOpToLLVMPattern<arith::CmpIOp>::ConvertOpToLLVMPattern;
99a54f4eaeSMogball
100a54f4eaeSMogball LogicalResult
101a54f4eaeSMogball matchAndRewrite(arith::CmpIOp op, OpAdaptor adaptor,
102a54f4eaeSMogball ConversionPatternRewriter &rewriter) const override;
103a54f4eaeSMogball };
104a54f4eaeSMogball
105a54f4eaeSMogball struct CmpFOpLowering : public ConvertOpToLLVMPattern<arith::CmpFOp> {
106a54f4eaeSMogball using ConvertOpToLLVMPattern<arith::CmpFOp>::ConvertOpToLLVMPattern;
107a54f4eaeSMogball
108a54f4eaeSMogball LogicalResult
109a54f4eaeSMogball matchAndRewrite(arith::CmpFOp op, OpAdaptor adaptor,
110a54f4eaeSMogball ConversionPatternRewriter &rewriter) const override;
111a54f4eaeSMogball };
112a54f4eaeSMogball
113be0a7e9fSMehdi Amini } // namespace
114a54f4eaeSMogball
115a54f4eaeSMogball //===----------------------------------------------------------------------===//
116a54f4eaeSMogball // ConstantOpLowering
117a54f4eaeSMogball //===----------------------------------------------------------------------===//
118a54f4eaeSMogball
119a54f4eaeSMogball LogicalResult
matchAndRewrite(arith::ConstantOp op,OpAdaptor adaptor,ConversionPatternRewriter & rewriter) const120a54f4eaeSMogball ConstantOpLowering::matchAndRewrite(arith::ConstantOp op, OpAdaptor adaptor,
121a54f4eaeSMogball ConversionPatternRewriter &rewriter) const {
122a54f4eaeSMogball return LLVM::detail::oneToOneRewrite(op, LLVM::ConstantOp::getOperationName(),
123a54f4eaeSMogball adaptor.getOperands(),
124a54f4eaeSMogball *getTypeConverter(), rewriter);
125a54f4eaeSMogball }
126a54f4eaeSMogball
127a54f4eaeSMogball //===----------------------------------------------------------------------===//
128a54f4eaeSMogball // IndexCastOpLowering
129a54f4eaeSMogball //===----------------------------------------------------------------------===//
130a54f4eaeSMogball
matchAndRewrite(arith::IndexCastOp op,OpAdaptor adaptor,ConversionPatternRewriter & rewriter) const131a54f4eaeSMogball LogicalResult IndexCastOpLowering::matchAndRewrite(
132a54f4eaeSMogball arith::IndexCastOp op, OpAdaptor adaptor,
133a54f4eaeSMogball ConversionPatternRewriter &rewriter) const {
134a54f4eaeSMogball auto targetType = typeConverter->convertType(op.getResult().getType());
135a54f4eaeSMogball auto targetElementType =
136a54f4eaeSMogball typeConverter->convertType(getElementTypeOrSelf(op.getResult()))
137a54f4eaeSMogball .cast<IntegerType>();
138a54f4eaeSMogball auto sourceElementType =
139cfb72fd3SJacques Pienaar getElementTypeOrSelf(adaptor.getIn()).cast<IntegerType>();
140a54f4eaeSMogball unsigned targetBits = targetElementType.getWidth();
141a54f4eaeSMogball unsigned sourceBits = sourceElementType.getWidth();
142a54f4eaeSMogball
143a54f4eaeSMogball if (targetBits == sourceBits)
144cfb72fd3SJacques Pienaar rewriter.replaceOp(op, adaptor.getIn());
145a54f4eaeSMogball else if (targetBits < sourceBits)
146cfb72fd3SJacques Pienaar rewriter.replaceOpWithNewOp<LLVM::TruncOp>(op, targetType, adaptor.getIn());
147a54f4eaeSMogball else
148cfb72fd3SJacques Pienaar rewriter.replaceOpWithNewOp<LLVM::SExtOp>(op, targetType, adaptor.getIn());
149a54f4eaeSMogball return success();
150a54f4eaeSMogball }
151a54f4eaeSMogball
152a54f4eaeSMogball //===----------------------------------------------------------------------===//
153a54f4eaeSMogball // CmpIOpLowering
154a54f4eaeSMogball //===----------------------------------------------------------------------===//
155a54f4eaeSMogball
/// Maps an arith comparison predicate onto its LLVM dialect counterpart.
/// The arith and LLVM predicate enums are numbered identically, so a
/// static_cast of the enumerator is the whole translation.
///
/// \tparam LLVMPredType  destination LLVM predicate enum (spelled explicitly).
/// \tparam PredType      source arith predicate enum (deduced).
template <typename LLVMPredType, typename PredType>
static LLVMPredType convertCmpPredicate(PredType arithPredicate) {
  LLVMPredType llvmPredicate = static_cast<LLVMPredType>(arithPredicate);
  return llvmPredicate;
}
162a54f4eaeSMogball
163a54f4eaeSMogball LogicalResult
matchAndRewrite(arith::CmpIOp op,OpAdaptor adaptor,ConversionPatternRewriter & rewriter) const164a54f4eaeSMogball CmpIOpLowering::matchAndRewrite(arith::CmpIOp op, OpAdaptor adaptor,
165a54f4eaeSMogball ConversionPatternRewriter &rewriter) const {
166cfb72fd3SJacques Pienaar auto operandType = adaptor.getLhs().getType();
167a54f4eaeSMogball auto resultType = op.getResult().getType();
168a54f4eaeSMogball
169a54f4eaeSMogball // Handle the scalar and 1D vector cases.
170a54f4eaeSMogball if (!operandType.isa<LLVM::LLVMArrayType>()) {
171a54f4eaeSMogball rewriter.replaceOpWithNewOp<LLVM::ICmpOp>(
172a54f4eaeSMogball op, typeConverter->convertType(resultType),
173a54f4eaeSMogball convertCmpPredicate<LLVM::ICmpPredicate>(op.getPredicate()),
174cfb72fd3SJacques Pienaar adaptor.getLhs(), adaptor.getRhs());
175a54f4eaeSMogball return success();
176a54f4eaeSMogball }
177a54f4eaeSMogball
178a54f4eaeSMogball auto vectorType = resultType.dyn_cast<VectorType>();
179a54f4eaeSMogball if (!vectorType)
180a54f4eaeSMogball return rewriter.notifyMatchFailure(op, "expected vector result type");
181a54f4eaeSMogball
182a54f4eaeSMogball return LLVM::detail::handleMultidimensionalVectors(
183a54f4eaeSMogball op.getOperation(), adaptor.getOperands(), *getTypeConverter(),
184a54f4eaeSMogball [&](Type llvm1DVectorTy, ValueRange operands) {
185a54f4eaeSMogball OpAdaptor adaptor(operands);
186a54f4eaeSMogball return rewriter.create<LLVM::ICmpOp>(
187a54f4eaeSMogball op.getLoc(), llvm1DVectorTy,
188a54f4eaeSMogball convertCmpPredicate<LLVM::ICmpPredicate>(op.getPredicate()),
189cfb72fd3SJacques Pienaar adaptor.getLhs(), adaptor.getRhs());
190a54f4eaeSMogball },
191a54f4eaeSMogball rewriter);
192a54f4eaeSMogball }
193a54f4eaeSMogball
194a54f4eaeSMogball //===----------------------------------------------------------------------===//
195a54f4eaeSMogball // CmpFOpLowering
196a54f4eaeSMogball //===----------------------------------------------------------------------===//
197a54f4eaeSMogball
198a54f4eaeSMogball LogicalResult
matchAndRewrite(arith::CmpFOp op,OpAdaptor adaptor,ConversionPatternRewriter & rewriter) const199a54f4eaeSMogball CmpFOpLowering::matchAndRewrite(arith::CmpFOp op, OpAdaptor adaptor,
200a54f4eaeSMogball ConversionPatternRewriter &rewriter) const {
201cfb72fd3SJacques Pienaar auto operandType = adaptor.getLhs().getType();
202a54f4eaeSMogball auto resultType = op.getResult().getType();
203a54f4eaeSMogball
204a54f4eaeSMogball // Handle the scalar and 1D vector cases.
205a54f4eaeSMogball if (!operandType.isa<LLVM::LLVMArrayType>()) {
206a54f4eaeSMogball rewriter.replaceOpWithNewOp<LLVM::FCmpOp>(
207a54f4eaeSMogball op, typeConverter->convertType(resultType),
208a54f4eaeSMogball convertCmpPredicate<LLVM::FCmpPredicate>(op.getPredicate()),
209cfb72fd3SJacques Pienaar adaptor.getLhs(), adaptor.getRhs());
210a54f4eaeSMogball return success();
211a54f4eaeSMogball }
212a54f4eaeSMogball
213a54f4eaeSMogball auto vectorType = resultType.dyn_cast<VectorType>();
214a54f4eaeSMogball if (!vectorType)
215a54f4eaeSMogball return rewriter.notifyMatchFailure(op, "expected vector result type");
216a54f4eaeSMogball
217a54f4eaeSMogball return LLVM::detail::handleMultidimensionalVectors(
218a54f4eaeSMogball op.getOperation(), adaptor.getOperands(), *getTypeConverter(),
219a54f4eaeSMogball [&](Type llvm1DVectorTy, ValueRange operands) {
220a54f4eaeSMogball OpAdaptor adaptor(operands);
221a54f4eaeSMogball return rewriter.create<LLVM::FCmpOp>(
222a54f4eaeSMogball op.getLoc(), llvm1DVectorTy,
223a54f4eaeSMogball convertCmpPredicate<LLVM::FCmpPredicate>(op.getPredicate()),
224cfb72fd3SJacques Pienaar adaptor.getLhs(), adaptor.getRhs());
225a54f4eaeSMogball },
226a54f4eaeSMogball rewriter);
227a54f4eaeSMogball }
228a54f4eaeSMogball
229a54f4eaeSMogball //===----------------------------------------------------------------------===//
230a54f4eaeSMogball // Pass Definition
231a54f4eaeSMogball //===----------------------------------------------------------------------===//
232a54f4eaeSMogball
233a54f4eaeSMogball namespace {
234a54f4eaeSMogball struct ConvertArithmeticToLLVMPass
235a54f4eaeSMogball : public ConvertArithmeticToLLVMBase<ConvertArithmeticToLLVMPass> {
236a54f4eaeSMogball ConvertArithmeticToLLVMPass() = default;
237a54f4eaeSMogball
runOnOperation__anonb4369bed0411::ConvertArithmeticToLLVMPass23841574554SRiver Riddle void runOnOperation() override {
239a54f4eaeSMogball LLVMConversionTarget target(getContext());
240a54f4eaeSMogball RewritePatternSet patterns(&getContext());
241a54f4eaeSMogball
242a54f4eaeSMogball LowerToLLVMOptions options(&getContext());
243a54f4eaeSMogball if (indexBitwidth != kDeriveIndexBitwidthFromDataLayout)
244a54f4eaeSMogball options.overrideIndexBitwidth(indexBitwidth);
245a54f4eaeSMogball
246a54f4eaeSMogball LLVMTypeConverter converter(&getContext(), options);
247a54f4eaeSMogball mlir::arith::populateArithmeticToLLVMConversionPatterns(converter,
248a54f4eaeSMogball patterns);
249a54f4eaeSMogball
25041574554SRiver Riddle if (failed(applyPartialConversion(getOperation(), target,
25141574554SRiver Riddle std::move(patterns))))
252a54f4eaeSMogball signalPassFailure();
253a54f4eaeSMogball }
254a54f4eaeSMogball };
255be0a7e9fSMehdi Amini } // namespace
256a54f4eaeSMogball
257a54f4eaeSMogball //===----------------------------------------------------------------------===//
258a54f4eaeSMogball // Pattern Population
259a54f4eaeSMogball //===----------------------------------------------------------------------===//
260a54f4eaeSMogball
populateArithmeticToLLVMConversionPatterns(LLVMTypeConverter & converter,RewritePatternSet & patterns)261a54f4eaeSMogball void mlir::arith::populateArithmeticToLLVMConversionPatterns(
262a54f4eaeSMogball LLVMTypeConverter &converter, RewritePatternSet &patterns) {
263a54f4eaeSMogball // clang-format off
264a54f4eaeSMogball patterns.add<
265a54f4eaeSMogball ConstantOpLowering,
266a54f4eaeSMogball AddIOpLowering,
267a54f4eaeSMogball SubIOpLowering,
268a54f4eaeSMogball MulIOpLowering,
269a54f4eaeSMogball DivUIOpLowering,
270a54f4eaeSMogball DivSIOpLowering,
271a54f4eaeSMogball RemUIOpLowering,
272a54f4eaeSMogball RemSIOpLowering,
273a54f4eaeSMogball AndIOpLowering,
274a54f4eaeSMogball OrIOpLowering,
275a54f4eaeSMogball XOrIOpLowering,
276a54f4eaeSMogball ShLIOpLowering,
277a54f4eaeSMogball ShRUIOpLowering,
278a54f4eaeSMogball ShRSIOpLowering,
279a54f4eaeSMogball NegFOpLowering,
280a54f4eaeSMogball AddFOpLowering,
281a54f4eaeSMogball SubFOpLowering,
282a54f4eaeSMogball MulFOpLowering,
283a54f4eaeSMogball DivFOpLowering,
284a54f4eaeSMogball RemFOpLowering,
285a54f4eaeSMogball ExtUIOpLowering,
286a54f4eaeSMogball ExtSIOpLowering,
287a54f4eaeSMogball ExtFOpLowering,
288a54f4eaeSMogball TruncIOpLowering,
289a54f4eaeSMogball TruncFOpLowering,
290a54f4eaeSMogball UIToFPOpLowering,
291a54f4eaeSMogball SIToFPOpLowering,
292a54f4eaeSMogball FPToUIOpLowering,
293a54f4eaeSMogball FPToSIOpLowering,
294a54f4eaeSMogball IndexCastOpLowering,
295a54f4eaeSMogball BitcastOpLowering,
296a54f4eaeSMogball CmpIOpLowering,
297*dec8af70SRiver Riddle CmpFOpLowering,
298*dec8af70SRiver Riddle SelectOpLowering
299a54f4eaeSMogball >(converter);
300a54f4eaeSMogball // clang-format on
301a54f4eaeSMogball }
302a54f4eaeSMogball
createConvertArithmeticToLLVMPass()303a54f4eaeSMogball std::unique_ptr<Pass> mlir::arith::createConvertArithmeticToLLVMPass() {
304a54f4eaeSMogball return std::make_unique<ConvertArithmeticToLLVMPass>();
305a54f4eaeSMogball }
306