1a54f4eaeSMogball //===- ArithmeticToLLVM.cpp - Arithmetic to LLVM dialect conversion -------===//
2a54f4eaeSMogball //
3a54f4eaeSMogball // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4a54f4eaeSMogball // See https://llvm.org/LICENSE.txt for license information.
5a54f4eaeSMogball // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6a54f4eaeSMogball //
7a54f4eaeSMogball //===----------------------------------------------------------------------===//
8a54f4eaeSMogball 
9a54f4eaeSMogball #include "mlir/Conversion/ArithmeticToLLVM/ArithmeticToLLVM.h"
10a54f4eaeSMogball #include "../PassDetail.h"
11a54f4eaeSMogball #include "mlir/Conversion/LLVMCommon/ConversionTarget.h"
12a54f4eaeSMogball #include "mlir/Conversion/LLVMCommon/VectorPattern.h"
13a54f4eaeSMogball #include "mlir/Dialect/Arithmetic/IR/Arithmetic.h"
14a54f4eaeSMogball #include "mlir/Dialect/LLVMIR/LLVMDialect.h"
15a54f4eaeSMogball #include "mlir/IR/TypeUtilities.h"
16a54f4eaeSMogball 
17a54f4eaeSMogball using namespace mlir;
18a54f4eaeSMogball 
19a54f4eaeSMogball namespace {
20a54f4eaeSMogball 
21a54f4eaeSMogball //===----------------------------------------------------------------------===//
22a54f4eaeSMogball // Straightforward Op Lowerings
23a54f4eaeSMogball //===----------------------------------------------------------------------===//
24a54f4eaeSMogball 
25a54f4eaeSMogball using AddIOpLowering = VectorConvertToLLVMPattern<arith::AddIOp, LLVM::AddOp>;
26a54f4eaeSMogball using SubIOpLowering = VectorConvertToLLVMPattern<arith::SubIOp, LLVM::SubOp>;
27a54f4eaeSMogball using MulIOpLowering = VectorConvertToLLVMPattern<arith::MulIOp, LLVM::MulOp>;
28a54f4eaeSMogball using DivUIOpLowering =
29a54f4eaeSMogball     VectorConvertToLLVMPattern<arith::DivUIOp, LLVM::UDivOp>;
30a54f4eaeSMogball using DivSIOpLowering =
31a54f4eaeSMogball     VectorConvertToLLVMPattern<arith::DivSIOp, LLVM::SDivOp>;
32a54f4eaeSMogball using RemUIOpLowering =
33a54f4eaeSMogball     VectorConvertToLLVMPattern<arith::RemUIOp, LLVM::URemOp>;
34a54f4eaeSMogball using RemSIOpLowering =
35a54f4eaeSMogball     VectorConvertToLLVMPattern<arith::RemSIOp, LLVM::SRemOp>;
36a54f4eaeSMogball using AndIOpLowering = VectorConvertToLLVMPattern<arith::AndIOp, LLVM::AndOp>;
37a54f4eaeSMogball using OrIOpLowering = VectorConvertToLLVMPattern<arith::OrIOp, LLVM::OrOp>;
38a54f4eaeSMogball using XOrIOpLowering = VectorConvertToLLVMPattern<arith::XOrIOp, LLVM::XOrOp>;
39a54f4eaeSMogball using ShLIOpLowering = VectorConvertToLLVMPattern<arith::ShLIOp, LLVM::ShlOp>;
40a54f4eaeSMogball using ShRUIOpLowering =
41a54f4eaeSMogball     VectorConvertToLLVMPattern<arith::ShRUIOp, LLVM::LShrOp>;
42a54f4eaeSMogball using ShRSIOpLowering =
43a54f4eaeSMogball     VectorConvertToLLVMPattern<arith::ShRSIOp, LLVM::AShrOp>;
44a54f4eaeSMogball using NegFOpLowering = VectorConvertToLLVMPattern<arith::NegFOp, LLVM::FNegOp>;
45a54f4eaeSMogball using AddFOpLowering = VectorConvertToLLVMPattern<arith::AddFOp, LLVM::FAddOp>;
46a54f4eaeSMogball using SubFOpLowering = VectorConvertToLLVMPattern<arith::SubFOp, LLVM::FSubOp>;
47a54f4eaeSMogball using MulFOpLowering = VectorConvertToLLVMPattern<arith::MulFOp, LLVM::FMulOp>;
48a54f4eaeSMogball using DivFOpLowering = VectorConvertToLLVMPattern<arith::DivFOp, LLVM::FDivOp>;
49a54f4eaeSMogball using RemFOpLowering = VectorConvertToLLVMPattern<arith::RemFOp, LLVM::FRemOp>;
50a54f4eaeSMogball using ExtUIOpLowering =
51a54f4eaeSMogball     VectorConvertToLLVMPattern<arith::ExtUIOp, LLVM::ZExtOp>;
52a54f4eaeSMogball using ExtSIOpLowering =
53a54f4eaeSMogball     VectorConvertToLLVMPattern<arith::ExtSIOp, LLVM::SExtOp>;
54a54f4eaeSMogball using ExtFOpLowering = VectorConvertToLLVMPattern<arith::ExtFOp, LLVM::FPExtOp>;
55a54f4eaeSMogball using TruncIOpLowering =
56a54f4eaeSMogball     VectorConvertToLLVMPattern<arith::TruncIOp, LLVM::TruncOp>;
57a54f4eaeSMogball using TruncFOpLowering =
58a54f4eaeSMogball     VectorConvertToLLVMPattern<arith::TruncFOp, LLVM::FPTruncOp>;
59a54f4eaeSMogball using UIToFPOpLowering =
60a54f4eaeSMogball     VectorConvertToLLVMPattern<arith::UIToFPOp, LLVM::UIToFPOp>;
61a54f4eaeSMogball using SIToFPOpLowering =
62a54f4eaeSMogball     VectorConvertToLLVMPattern<arith::SIToFPOp, LLVM::SIToFPOp>;
63a54f4eaeSMogball using FPToUIOpLowering =
64a54f4eaeSMogball     VectorConvertToLLVMPattern<arith::FPToUIOp, LLVM::FPToUIOp>;
65a54f4eaeSMogball using FPToSIOpLowering =
66a54f4eaeSMogball     VectorConvertToLLVMPattern<arith::FPToSIOp, LLVM::FPToSIOp>;
67a54f4eaeSMogball using BitcastOpLowering =
68a54f4eaeSMogball     VectorConvertToLLVMPattern<arith::BitcastOp, LLVM::BitcastOp>;
69*dec8af70SRiver Riddle using SelectOpLowering =
70*dec8af70SRiver Riddle     VectorConvertToLLVMPattern<arith::SelectOp, LLVM::SelectOp>;
71a54f4eaeSMogball 
72a54f4eaeSMogball //===----------------------------------------------------------------------===//
73a54f4eaeSMogball // Op Lowering Patterns
74a54f4eaeSMogball //===----------------------------------------------------------------------===//
75a54f4eaeSMogball 
76a54f4eaeSMogball /// Directly lower to LLVM op.
77a54f4eaeSMogball struct ConstantOpLowering : public ConvertOpToLLVMPattern<arith::ConstantOp> {
78a54f4eaeSMogball   using ConvertOpToLLVMPattern<arith::ConstantOp>::ConvertOpToLLVMPattern;
79a54f4eaeSMogball 
80a54f4eaeSMogball   LogicalResult
81a54f4eaeSMogball   matchAndRewrite(arith::ConstantOp op, OpAdaptor adaptor,
82a54f4eaeSMogball                   ConversionPatternRewriter &rewriter) const override;
83a54f4eaeSMogball };
84a54f4eaeSMogball 
85a54f4eaeSMogball /// The lowering of index_cast becomes an integer conversion since index
86a54f4eaeSMogball /// becomes an integer.  If the bit width of the source and target integer
87a54f4eaeSMogball /// types is the same, just erase the cast.  If the target type is wider,
88a54f4eaeSMogball /// sign-extend the value, otherwise truncate it.
89a54f4eaeSMogball struct IndexCastOpLowering : public ConvertOpToLLVMPattern<arith::IndexCastOp> {
90a54f4eaeSMogball   using ConvertOpToLLVMPattern<arith::IndexCastOp>::ConvertOpToLLVMPattern;
91a54f4eaeSMogball 
92a54f4eaeSMogball   LogicalResult
93a54f4eaeSMogball   matchAndRewrite(arith::IndexCastOp op, OpAdaptor adaptor,
94a54f4eaeSMogball                   ConversionPatternRewriter &rewriter) const override;
95a54f4eaeSMogball };
96a54f4eaeSMogball 
97a54f4eaeSMogball struct CmpIOpLowering : public ConvertOpToLLVMPattern<arith::CmpIOp> {
98a54f4eaeSMogball   using ConvertOpToLLVMPattern<arith::CmpIOp>::ConvertOpToLLVMPattern;
99a54f4eaeSMogball 
100a54f4eaeSMogball   LogicalResult
101a54f4eaeSMogball   matchAndRewrite(arith::CmpIOp op, OpAdaptor adaptor,
102a54f4eaeSMogball                   ConversionPatternRewriter &rewriter) const override;
103a54f4eaeSMogball };
104a54f4eaeSMogball 
105a54f4eaeSMogball struct CmpFOpLowering : public ConvertOpToLLVMPattern<arith::CmpFOp> {
106a54f4eaeSMogball   using ConvertOpToLLVMPattern<arith::CmpFOp>::ConvertOpToLLVMPattern;
107a54f4eaeSMogball 
108a54f4eaeSMogball   LogicalResult
109a54f4eaeSMogball   matchAndRewrite(arith::CmpFOp op, OpAdaptor adaptor,
110a54f4eaeSMogball                   ConversionPatternRewriter &rewriter) const override;
111a54f4eaeSMogball };
112a54f4eaeSMogball 
113be0a7e9fSMehdi Amini } // namespace
114a54f4eaeSMogball 
115a54f4eaeSMogball //===----------------------------------------------------------------------===//
116a54f4eaeSMogball // ConstantOpLowering
117a54f4eaeSMogball //===----------------------------------------------------------------------===//
118a54f4eaeSMogball 
119a54f4eaeSMogball LogicalResult
matchAndRewrite(arith::ConstantOp op,OpAdaptor adaptor,ConversionPatternRewriter & rewriter) const120a54f4eaeSMogball ConstantOpLowering::matchAndRewrite(arith::ConstantOp op, OpAdaptor adaptor,
121a54f4eaeSMogball                                     ConversionPatternRewriter &rewriter) const {
122a54f4eaeSMogball   return LLVM::detail::oneToOneRewrite(op, LLVM::ConstantOp::getOperationName(),
123a54f4eaeSMogball                                        adaptor.getOperands(),
124a54f4eaeSMogball                                        *getTypeConverter(), rewriter);
125a54f4eaeSMogball }
126a54f4eaeSMogball 
127a54f4eaeSMogball //===----------------------------------------------------------------------===//
128a54f4eaeSMogball // IndexCastOpLowering
129a54f4eaeSMogball //===----------------------------------------------------------------------===//
130a54f4eaeSMogball 
matchAndRewrite(arith::IndexCastOp op,OpAdaptor adaptor,ConversionPatternRewriter & rewriter) const131a54f4eaeSMogball LogicalResult IndexCastOpLowering::matchAndRewrite(
132a54f4eaeSMogball     arith::IndexCastOp op, OpAdaptor adaptor,
133a54f4eaeSMogball     ConversionPatternRewriter &rewriter) const {
134a54f4eaeSMogball   auto targetType = typeConverter->convertType(op.getResult().getType());
135a54f4eaeSMogball   auto targetElementType =
136a54f4eaeSMogball       typeConverter->convertType(getElementTypeOrSelf(op.getResult()))
137a54f4eaeSMogball           .cast<IntegerType>();
138a54f4eaeSMogball   auto sourceElementType =
139cfb72fd3SJacques Pienaar       getElementTypeOrSelf(adaptor.getIn()).cast<IntegerType>();
140a54f4eaeSMogball   unsigned targetBits = targetElementType.getWidth();
141a54f4eaeSMogball   unsigned sourceBits = sourceElementType.getWidth();
142a54f4eaeSMogball 
143a54f4eaeSMogball   if (targetBits == sourceBits)
144cfb72fd3SJacques Pienaar     rewriter.replaceOp(op, adaptor.getIn());
145a54f4eaeSMogball   else if (targetBits < sourceBits)
146cfb72fd3SJacques Pienaar     rewriter.replaceOpWithNewOp<LLVM::TruncOp>(op, targetType, adaptor.getIn());
147a54f4eaeSMogball   else
148cfb72fd3SJacques Pienaar     rewriter.replaceOpWithNewOp<LLVM::SExtOp>(op, targetType, adaptor.getIn());
149a54f4eaeSMogball   return success();
150a54f4eaeSMogball }
151a54f4eaeSMogball 
152a54f4eaeSMogball //===----------------------------------------------------------------------===//
153a54f4eaeSMogball // CmpIOpLowering
154a54f4eaeSMogball //===----------------------------------------------------------------------===//
155a54f4eaeSMogball 
// Maps an arith.cmpi / arith.cmpf predicate onto the matching LLVM dialect
// predicate enum. This relies on both enums assigning the same numerical
// value to each predicate, so a plain cast of the enumerator is enough.
template <typename LLVMPredType, typename PredType>
static auto convertCmpPredicate(PredType pred) -> LLVMPredType {
  LLVMPredType converted = static_cast<LLVMPredType>(pred);
  return converted;
}
162a54f4eaeSMogball 
163a54f4eaeSMogball LogicalResult
matchAndRewrite(arith::CmpIOp op,OpAdaptor adaptor,ConversionPatternRewriter & rewriter) const164a54f4eaeSMogball CmpIOpLowering::matchAndRewrite(arith::CmpIOp op, OpAdaptor adaptor,
165a54f4eaeSMogball                                 ConversionPatternRewriter &rewriter) const {
166cfb72fd3SJacques Pienaar   auto operandType = adaptor.getLhs().getType();
167a54f4eaeSMogball   auto resultType = op.getResult().getType();
168a54f4eaeSMogball 
169a54f4eaeSMogball   // Handle the scalar and 1D vector cases.
170a54f4eaeSMogball   if (!operandType.isa<LLVM::LLVMArrayType>()) {
171a54f4eaeSMogball     rewriter.replaceOpWithNewOp<LLVM::ICmpOp>(
172a54f4eaeSMogball         op, typeConverter->convertType(resultType),
173a54f4eaeSMogball         convertCmpPredicate<LLVM::ICmpPredicate>(op.getPredicate()),
174cfb72fd3SJacques Pienaar         adaptor.getLhs(), adaptor.getRhs());
175a54f4eaeSMogball     return success();
176a54f4eaeSMogball   }
177a54f4eaeSMogball 
178a54f4eaeSMogball   auto vectorType = resultType.dyn_cast<VectorType>();
179a54f4eaeSMogball   if (!vectorType)
180a54f4eaeSMogball     return rewriter.notifyMatchFailure(op, "expected vector result type");
181a54f4eaeSMogball 
182a54f4eaeSMogball   return LLVM::detail::handleMultidimensionalVectors(
183a54f4eaeSMogball       op.getOperation(), adaptor.getOperands(), *getTypeConverter(),
184a54f4eaeSMogball       [&](Type llvm1DVectorTy, ValueRange operands) {
185a54f4eaeSMogball         OpAdaptor adaptor(operands);
186a54f4eaeSMogball         return rewriter.create<LLVM::ICmpOp>(
187a54f4eaeSMogball             op.getLoc(), llvm1DVectorTy,
188a54f4eaeSMogball             convertCmpPredicate<LLVM::ICmpPredicate>(op.getPredicate()),
189cfb72fd3SJacques Pienaar             adaptor.getLhs(), adaptor.getRhs());
190a54f4eaeSMogball       },
191a54f4eaeSMogball       rewriter);
192a54f4eaeSMogball }
193a54f4eaeSMogball 
194a54f4eaeSMogball //===----------------------------------------------------------------------===//
195a54f4eaeSMogball // CmpFOpLowering
196a54f4eaeSMogball //===----------------------------------------------------------------------===//
197a54f4eaeSMogball 
198a54f4eaeSMogball LogicalResult
matchAndRewrite(arith::CmpFOp op,OpAdaptor adaptor,ConversionPatternRewriter & rewriter) const199a54f4eaeSMogball CmpFOpLowering::matchAndRewrite(arith::CmpFOp op, OpAdaptor adaptor,
200a54f4eaeSMogball                                 ConversionPatternRewriter &rewriter) const {
201cfb72fd3SJacques Pienaar   auto operandType = adaptor.getLhs().getType();
202a54f4eaeSMogball   auto resultType = op.getResult().getType();
203a54f4eaeSMogball 
204a54f4eaeSMogball   // Handle the scalar and 1D vector cases.
205a54f4eaeSMogball   if (!operandType.isa<LLVM::LLVMArrayType>()) {
206a54f4eaeSMogball     rewriter.replaceOpWithNewOp<LLVM::FCmpOp>(
207a54f4eaeSMogball         op, typeConverter->convertType(resultType),
208a54f4eaeSMogball         convertCmpPredicate<LLVM::FCmpPredicate>(op.getPredicate()),
209cfb72fd3SJacques Pienaar         adaptor.getLhs(), adaptor.getRhs());
210a54f4eaeSMogball     return success();
211a54f4eaeSMogball   }
212a54f4eaeSMogball 
213a54f4eaeSMogball   auto vectorType = resultType.dyn_cast<VectorType>();
214a54f4eaeSMogball   if (!vectorType)
215a54f4eaeSMogball     return rewriter.notifyMatchFailure(op, "expected vector result type");
216a54f4eaeSMogball 
217a54f4eaeSMogball   return LLVM::detail::handleMultidimensionalVectors(
218a54f4eaeSMogball       op.getOperation(), adaptor.getOperands(), *getTypeConverter(),
219a54f4eaeSMogball       [&](Type llvm1DVectorTy, ValueRange operands) {
220a54f4eaeSMogball         OpAdaptor adaptor(operands);
221a54f4eaeSMogball         return rewriter.create<LLVM::FCmpOp>(
222a54f4eaeSMogball             op.getLoc(), llvm1DVectorTy,
223a54f4eaeSMogball             convertCmpPredicate<LLVM::FCmpPredicate>(op.getPredicate()),
224cfb72fd3SJacques Pienaar             adaptor.getLhs(), adaptor.getRhs());
225a54f4eaeSMogball       },
226a54f4eaeSMogball       rewriter);
227a54f4eaeSMogball }
228a54f4eaeSMogball 
229a54f4eaeSMogball //===----------------------------------------------------------------------===//
230a54f4eaeSMogball // Pass Definition
231a54f4eaeSMogball //===----------------------------------------------------------------------===//
232a54f4eaeSMogball 
233a54f4eaeSMogball namespace {
234a54f4eaeSMogball struct ConvertArithmeticToLLVMPass
235a54f4eaeSMogball     : public ConvertArithmeticToLLVMBase<ConvertArithmeticToLLVMPass> {
236a54f4eaeSMogball   ConvertArithmeticToLLVMPass() = default;
237a54f4eaeSMogball 
runOnOperation__anonb4369bed0411::ConvertArithmeticToLLVMPass23841574554SRiver Riddle   void runOnOperation() override {
239a54f4eaeSMogball     LLVMConversionTarget target(getContext());
240a54f4eaeSMogball     RewritePatternSet patterns(&getContext());
241a54f4eaeSMogball 
242a54f4eaeSMogball     LowerToLLVMOptions options(&getContext());
243a54f4eaeSMogball     if (indexBitwidth != kDeriveIndexBitwidthFromDataLayout)
244a54f4eaeSMogball       options.overrideIndexBitwidth(indexBitwidth);
245a54f4eaeSMogball 
246a54f4eaeSMogball     LLVMTypeConverter converter(&getContext(), options);
247a54f4eaeSMogball     mlir::arith::populateArithmeticToLLVMConversionPatterns(converter,
248a54f4eaeSMogball                                                             patterns);
249a54f4eaeSMogball 
25041574554SRiver Riddle     if (failed(applyPartialConversion(getOperation(), target,
25141574554SRiver Riddle                                       std::move(patterns))))
252a54f4eaeSMogball       signalPassFailure();
253a54f4eaeSMogball   }
254a54f4eaeSMogball };
255be0a7e9fSMehdi Amini } // namespace
256a54f4eaeSMogball 
257a54f4eaeSMogball //===----------------------------------------------------------------------===//
258a54f4eaeSMogball // Pattern Population
259a54f4eaeSMogball //===----------------------------------------------------------------------===//
260a54f4eaeSMogball 
populateArithmeticToLLVMConversionPatterns(LLVMTypeConverter & converter,RewritePatternSet & patterns)261a54f4eaeSMogball void mlir::arith::populateArithmeticToLLVMConversionPatterns(
262a54f4eaeSMogball     LLVMTypeConverter &converter, RewritePatternSet &patterns) {
263a54f4eaeSMogball   // clang-format off
264a54f4eaeSMogball   patterns.add<
265a54f4eaeSMogball     ConstantOpLowering,
266a54f4eaeSMogball     AddIOpLowering,
267a54f4eaeSMogball     SubIOpLowering,
268a54f4eaeSMogball     MulIOpLowering,
269a54f4eaeSMogball     DivUIOpLowering,
270a54f4eaeSMogball     DivSIOpLowering,
271a54f4eaeSMogball     RemUIOpLowering,
272a54f4eaeSMogball     RemSIOpLowering,
273a54f4eaeSMogball     AndIOpLowering,
274a54f4eaeSMogball     OrIOpLowering,
275a54f4eaeSMogball     XOrIOpLowering,
276a54f4eaeSMogball     ShLIOpLowering,
277a54f4eaeSMogball     ShRUIOpLowering,
278a54f4eaeSMogball     ShRSIOpLowering,
279a54f4eaeSMogball     NegFOpLowering,
280a54f4eaeSMogball     AddFOpLowering,
281a54f4eaeSMogball     SubFOpLowering,
282a54f4eaeSMogball     MulFOpLowering,
283a54f4eaeSMogball     DivFOpLowering,
284a54f4eaeSMogball     RemFOpLowering,
285a54f4eaeSMogball     ExtUIOpLowering,
286a54f4eaeSMogball     ExtSIOpLowering,
287a54f4eaeSMogball     ExtFOpLowering,
288a54f4eaeSMogball     TruncIOpLowering,
289a54f4eaeSMogball     TruncFOpLowering,
290a54f4eaeSMogball     UIToFPOpLowering,
291a54f4eaeSMogball     SIToFPOpLowering,
292a54f4eaeSMogball     FPToUIOpLowering,
293a54f4eaeSMogball     FPToSIOpLowering,
294a54f4eaeSMogball     IndexCastOpLowering,
295a54f4eaeSMogball     BitcastOpLowering,
296a54f4eaeSMogball     CmpIOpLowering,
297*dec8af70SRiver Riddle     CmpFOpLowering,
298*dec8af70SRiver Riddle     SelectOpLowering
299a54f4eaeSMogball   >(converter);
300a54f4eaeSMogball   // clang-format on
301a54f4eaeSMogball }
302a54f4eaeSMogball 
createConvertArithmeticToLLVMPass()303a54f4eaeSMogball std::unique_ptr<Pass> mlir::arith::createConvertArithmeticToLLVMPass() {
304a54f4eaeSMogball   return std::make_unique<ConvertArithmeticToLLVMPass>();
305a54f4eaeSMogball }
306