1*a54f4eaeSMogball //===- ArithmeticToLLVM.cpp - Arithmetic to LLVM dialect conversion -------===//
2*a54f4eaeSMogball //
3*a54f4eaeSMogball // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4*a54f4eaeSMogball // See https://llvm.org/LICENSE.txt for license information.
5*a54f4eaeSMogball // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6*a54f4eaeSMogball //
7*a54f4eaeSMogball //===----------------------------------------------------------------------===//
8*a54f4eaeSMogball 
9*a54f4eaeSMogball #include "mlir/Conversion/ArithmeticToLLVM/ArithmeticToLLVM.h"
10*a54f4eaeSMogball #include "../PassDetail.h"
11*a54f4eaeSMogball #include "mlir/Conversion/LLVMCommon/ConversionTarget.h"
12*a54f4eaeSMogball #include "mlir/Conversion/LLVMCommon/VectorPattern.h"
13*a54f4eaeSMogball #include "mlir/Dialect/Arithmetic/IR/Arithmetic.h"
14*a54f4eaeSMogball #include "mlir/Dialect/LLVMIR/LLVMDialect.h"
15*a54f4eaeSMogball #include "mlir/IR/TypeUtilities.h"
16*a54f4eaeSMogball 
17*a54f4eaeSMogball using namespace mlir;
18*a54f4eaeSMogball 
19*a54f4eaeSMogball namespace {
20*a54f4eaeSMogball 
21*a54f4eaeSMogball //===----------------------------------------------------------------------===//
22*a54f4eaeSMogball // Straightforward Op Lowerings
23*a54f4eaeSMogball //===----------------------------------------------------------------------===//
24*a54f4eaeSMogball 
25*a54f4eaeSMogball using AddIOpLowering = VectorConvertToLLVMPattern<arith::AddIOp, LLVM::AddOp>;
26*a54f4eaeSMogball using SubIOpLowering = VectorConvertToLLVMPattern<arith::SubIOp, LLVM::SubOp>;
27*a54f4eaeSMogball using MulIOpLowering = VectorConvertToLLVMPattern<arith::MulIOp, LLVM::MulOp>;
28*a54f4eaeSMogball using DivUIOpLowering =
29*a54f4eaeSMogball     VectorConvertToLLVMPattern<arith::DivUIOp, LLVM::UDivOp>;
30*a54f4eaeSMogball using DivSIOpLowering =
31*a54f4eaeSMogball     VectorConvertToLLVMPattern<arith::DivSIOp, LLVM::SDivOp>;
32*a54f4eaeSMogball using RemUIOpLowering =
33*a54f4eaeSMogball     VectorConvertToLLVMPattern<arith::RemUIOp, LLVM::URemOp>;
34*a54f4eaeSMogball using RemSIOpLowering =
35*a54f4eaeSMogball     VectorConvertToLLVMPattern<arith::RemSIOp, LLVM::SRemOp>;
36*a54f4eaeSMogball using AndIOpLowering = VectorConvertToLLVMPattern<arith::AndIOp, LLVM::AndOp>;
37*a54f4eaeSMogball using OrIOpLowering = VectorConvertToLLVMPattern<arith::OrIOp, LLVM::OrOp>;
38*a54f4eaeSMogball using XOrIOpLowering = VectorConvertToLLVMPattern<arith::XOrIOp, LLVM::XOrOp>;
39*a54f4eaeSMogball using ShLIOpLowering = VectorConvertToLLVMPattern<arith::ShLIOp, LLVM::ShlOp>;
40*a54f4eaeSMogball using ShRUIOpLowering =
41*a54f4eaeSMogball     VectorConvertToLLVMPattern<arith::ShRUIOp, LLVM::LShrOp>;
42*a54f4eaeSMogball using ShRSIOpLowering =
43*a54f4eaeSMogball     VectorConvertToLLVMPattern<arith::ShRSIOp, LLVM::AShrOp>;
44*a54f4eaeSMogball using NegFOpLowering = VectorConvertToLLVMPattern<arith::NegFOp, LLVM::FNegOp>;
45*a54f4eaeSMogball using AddFOpLowering = VectorConvertToLLVMPattern<arith::AddFOp, LLVM::FAddOp>;
46*a54f4eaeSMogball using SubFOpLowering = VectorConvertToLLVMPattern<arith::SubFOp, LLVM::FSubOp>;
47*a54f4eaeSMogball using MulFOpLowering = VectorConvertToLLVMPattern<arith::MulFOp, LLVM::FMulOp>;
48*a54f4eaeSMogball using DivFOpLowering = VectorConvertToLLVMPattern<arith::DivFOp, LLVM::FDivOp>;
49*a54f4eaeSMogball using RemFOpLowering = VectorConvertToLLVMPattern<arith::RemFOp, LLVM::FRemOp>;
50*a54f4eaeSMogball using ExtUIOpLowering =
51*a54f4eaeSMogball     VectorConvertToLLVMPattern<arith::ExtUIOp, LLVM::ZExtOp>;
52*a54f4eaeSMogball using ExtSIOpLowering =
53*a54f4eaeSMogball     VectorConvertToLLVMPattern<arith::ExtSIOp, LLVM::SExtOp>;
54*a54f4eaeSMogball using ExtFOpLowering = VectorConvertToLLVMPattern<arith::ExtFOp, LLVM::FPExtOp>;
55*a54f4eaeSMogball using TruncIOpLowering =
56*a54f4eaeSMogball     VectorConvertToLLVMPattern<arith::TruncIOp, LLVM::TruncOp>;
57*a54f4eaeSMogball using TruncFOpLowering =
58*a54f4eaeSMogball     VectorConvertToLLVMPattern<arith::TruncFOp, LLVM::FPTruncOp>;
59*a54f4eaeSMogball using UIToFPOpLowering =
60*a54f4eaeSMogball     VectorConvertToLLVMPattern<arith::UIToFPOp, LLVM::UIToFPOp>;
61*a54f4eaeSMogball using SIToFPOpLowering =
62*a54f4eaeSMogball     VectorConvertToLLVMPattern<arith::SIToFPOp, LLVM::SIToFPOp>;
63*a54f4eaeSMogball using FPToUIOpLowering =
64*a54f4eaeSMogball     VectorConvertToLLVMPattern<arith::FPToUIOp, LLVM::FPToUIOp>;
65*a54f4eaeSMogball using FPToSIOpLowering =
66*a54f4eaeSMogball     VectorConvertToLLVMPattern<arith::FPToSIOp, LLVM::FPToSIOp>;
67*a54f4eaeSMogball using BitcastOpLowering =
68*a54f4eaeSMogball     VectorConvertToLLVMPattern<arith::BitcastOp, LLVM::BitcastOp>;
69*a54f4eaeSMogball 
70*a54f4eaeSMogball //===----------------------------------------------------------------------===//
71*a54f4eaeSMogball // Op Lowering Patterns
72*a54f4eaeSMogball //===----------------------------------------------------------------------===//
73*a54f4eaeSMogball 
74*a54f4eaeSMogball /// Directly lower to LLVM op.
75*a54f4eaeSMogball struct ConstantOpLowering : public ConvertOpToLLVMPattern<arith::ConstantOp> {
76*a54f4eaeSMogball   using ConvertOpToLLVMPattern<arith::ConstantOp>::ConvertOpToLLVMPattern;
77*a54f4eaeSMogball 
78*a54f4eaeSMogball   LogicalResult
79*a54f4eaeSMogball   matchAndRewrite(arith::ConstantOp op, OpAdaptor adaptor,
80*a54f4eaeSMogball                   ConversionPatternRewriter &rewriter) const override;
81*a54f4eaeSMogball };
82*a54f4eaeSMogball 
83*a54f4eaeSMogball /// The lowering of index_cast becomes an integer conversion since index
84*a54f4eaeSMogball /// becomes an integer.  If the bit width of the source and target integer
85*a54f4eaeSMogball /// types is the same, just erase the cast.  If the target type is wider,
86*a54f4eaeSMogball /// sign-extend the value, otherwise truncate it.
87*a54f4eaeSMogball struct IndexCastOpLowering : public ConvertOpToLLVMPattern<arith::IndexCastOp> {
88*a54f4eaeSMogball   using ConvertOpToLLVMPattern<arith::IndexCastOp>::ConvertOpToLLVMPattern;
89*a54f4eaeSMogball 
90*a54f4eaeSMogball   LogicalResult
91*a54f4eaeSMogball   matchAndRewrite(arith::IndexCastOp op, OpAdaptor adaptor,
92*a54f4eaeSMogball                   ConversionPatternRewriter &rewriter) const override;
93*a54f4eaeSMogball };
94*a54f4eaeSMogball 
95*a54f4eaeSMogball struct CmpIOpLowering : public ConvertOpToLLVMPattern<arith::CmpIOp> {
96*a54f4eaeSMogball   using ConvertOpToLLVMPattern<arith::CmpIOp>::ConvertOpToLLVMPattern;
97*a54f4eaeSMogball 
98*a54f4eaeSMogball   LogicalResult
99*a54f4eaeSMogball   matchAndRewrite(arith::CmpIOp op, OpAdaptor adaptor,
100*a54f4eaeSMogball                   ConversionPatternRewriter &rewriter) const override;
101*a54f4eaeSMogball };
102*a54f4eaeSMogball 
103*a54f4eaeSMogball struct CmpFOpLowering : public ConvertOpToLLVMPattern<arith::CmpFOp> {
104*a54f4eaeSMogball   using ConvertOpToLLVMPattern<arith::CmpFOp>::ConvertOpToLLVMPattern;
105*a54f4eaeSMogball 
106*a54f4eaeSMogball   LogicalResult
107*a54f4eaeSMogball   matchAndRewrite(arith::CmpFOp op, OpAdaptor adaptor,
108*a54f4eaeSMogball                   ConversionPatternRewriter &rewriter) const override;
109*a54f4eaeSMogball };
110*a54f4eaeSMogball 
111*a54f4eaeSMogball } // end anonymous namespace
112*a54f4eaeSMogball 
113*a54f4eaeSMogball //===----------------------------------------------------------------------===//
114*a54f4eaeSMogball // ConstantOpLowering
115*a54f4eaeSMogball //===----------------------------------------------------------------------===//
116*a54f4eaeSMogball 
117*a54f4eaeSMogball LogicalResult
118*a54f4eaeSMogball ConstantOpLowering::matchAndRewrite(arith::ConstantOp op, OpAdaptor adaptor,
119*a54f4eaeSMogball                                     ConversionPatternRewriter &rewriter) const {
120*a54f4eaeSMogball   return LLVM::detail::oneToOneRewrite(op, LLVM::ConstantOp::getOperationName(),
121*a54f4eaeSMogball                                        adaptor.getOperands(),
122*a54f4eaeSMogball                                        *getTypeConverter(), rewriter);
123*a54f4eaeSMogball }
124*a54f4eaeSMogball 
125*a54f4eaeSMogball //===----------------------------------------------------------------------===//
126*a54f4eaeSMogball // IndexCastOpLowering
127*a54f4eaeSMogball //===----------------------------------------------------------------------===//
128*a54f4eaeSMogball 
129*a54f4eaeSMogball LogicalResult IndexCastOpLowering::matchAndRewrite(
130*a54f4eaeSMogball     arith::IndexCastOp op, OpAdaptor adaptor,
131*a54f4eaeSMogball     ConversionPatternRewriter &rewriter) const {
132*a54f4eaeSMogball   auto targetType = typeConverter->convertType(op.getResult().getType());
133*a54f4eaeSMogball   auto targetElementType =
134*a54f4eaeSMogball       typeConverter->convertType(getElementTypeOrSelf(op.getResult()))
135*a54f4eaeSMogball           .cast<IntegerType>();
136*a54f4eaeSMogball   auto sourceElementType =
137*a54f4eaeSMogball       getElementTypeOrSelf(adaptor.in()).cast<IntegerType>();
138*a54f4eaeSMogball   unsigned targetBits = targetElementType.getWidth();
139*a54f4eaeSMogball   unsigned sourceBits = sourceElementType.getWidth();
140*a54f4eaeSMogball 
141*a54f4eaeSMogball   if (targetBits == sourceBits)
142*a54f4eaeSMogball     rewriter.replaceOp(op, adaptor.in());
143*a54f4eaeSMogball   else if (targetBits < sourceBits)
144*a54f4eaeSMogball     rewriter.replaceOpWithNewOp<LLVM::TruncOp>(op, targetType, adaptor.in());
145*a54f4eaeSMogball   else
146*a54f4eaeSMogball     rewriter.replaceOpWithNewOp<LLVM::SExtOp>(op, targetType, adaptor.in());
147*a54f4eaeSMogball   return success();
148*a54f4eaeSMogball }
149*a54f4eaeSMogball 
150*a54f4eaeSMogball //===----------------------------------------------------------------------===//
151*a54f4eaeSMogball // CmpIOpLowering
152*a54f4eaeSMogball //===----------------------------------------------------------------------===//
153*a54f4eaeSMogball 
/// Maps an arith comparison predicate (`arith.cmpi` / `arith.cmpf`) to the
/// matching LLVM dialect predicate enum. The two enums share numerical values,
/// so a value cast is all that is needed and no lookup table is involved.
/// The LLVM target enum is the first, explicitly supplied template argument.
/// The source enum is deduced from the argument.
template <typename TargetPred, typename SourcePred>
static auto convertCmpPredicate(SourcePred pred) -> TargetPred {
  return static_cast<TargetPred>(pred);
}
160*a54f4eaeSMogball 
161*a54f4eaeSMogball LogicalResult
162*a54f4eaeSMogball CmpIOpLowering::matchAndRewrite(arith::CmpIOp op, OpAdaptor adaptor,
163*a54f4eaeSMogball                                 ConversionPatternRewriter &rewriter) const {
164*a54f4eaeSMogball   auto operandType = adaptor.lhs().getType();
165*a54f4eaeSMogball   auto resultType = op.getResult().getType();
166*a54f4eaeSMogball 
167*a54f4eaeSMogball   // Handle the scalar and 1D vector cases.
168*a54f4eaeSMogball   if (!operandType.isa<LLVM::LLVMArrayType>()) {
169*a54f4eaeSMogball     rewriter.replaceOpWithNewOp<LLVM::ICmpOp>(
170*a54f4eaeSMogball         op, typeConverter->convertType(resultType),
171*a54f4eaeSMogball         convertCmpPredicate<LLVM::ICmpPredicate>(op.getPredicate()),
172*a54f4eaeSMogball         adaptor.lhs(), adaptor.rhs());
173*a54f4eaeSMogball     return success();
174*a54f4eaeSMogball   }
175*a54f4eaeSMogball 
176*a54f4eaeSMogball   auto vectorType = resultType.dyn_cast<VectorType>();
177*a54f4eaeSMogball   if (!vectorType)
178*a54f4eaeSMogball     return rewriter.notifyMatchFailure(op, "expected vector result type");
179*a54f4eaeSMogball 
180*a54f4eaeSMogball   return LLVM::detail::handleMultidimensionalVectors(
181*a54f4eaeSMogball       op.getOperation(), adaptor.getOperands(), *getTypeConverter(),
182*a54f4eaeSMogball       [&](Type llvm1DVectorTy, ValueRange operands) {
183*a54f4eaeSMogball         OpAdaptor adaptor(operands);
184*a54f4eaeSMogball         return rewriter.create<LLVM::ICmpOp>(
185*a54f4eaeSMogball             op.getLoc(), llvm1DVectorTy,
186*a54f4eaeSMogball             convertCmpPredicate<LLVM::ICmpPredicate>(op.getPredicate()),
187*a54f4eaeSMogball             adaptor.lhs(), adaptor.rhs());
188*a54f4eaeSMogball       },
189*a54f4eaeSMogball       rewriter);
190*a54f4eaeSMogball 
191*a54f4eaeSMogball   return success();
192*a54f4eaeSMogball }
193*a54f4eaeSMogball 
194*a54f4eaeSMogball //===----------------------------------------------------------------------===//
195*a54f4eaeSMogball // CmpFOpLowering
196*a54f4eaeSMogball //===----------------------------------------------------------------------===//
197*a54f4eaeSMogball 
198*a54f4eaeSMogball LogicalResult
199*a54f4eaeSMogball CmpFOpLowering::matchAndRewrite(arith::CmpFOp op, OpAdaptor adaptor,
200*a54f4eaeSMogball                                 ConversionPatternRewriter &rewriter) const {
201*a54f4eaeSMogball   auto operandType = adaptor.lhs().getType();
202*a54f4eaeSMogball   auto resultType = op.getResult().getType();
203*a54f4eaeSMogball 
204*a54f4eaeSMogball   // Handle the scalar and 1D vector cases.
205*a54f4eaeSMogball   if (!operandType.isa<LLVM::LLVMArrayType>()) {
206*a54f4eaeSMogball     rewriter.replaceOpWithNewOp<LLVM::FCmpOp>(
207*a54f4eaeSMogball         op, typeConverter->convertType(resultType),
208*a54f4eaeSMogball         convertCmpPredicate<LLVM::FCmpPredicate>(op.getPredicate()),
209*a54f4eaeSMogball         adaptor.lhs(), adaptor.rhs());
210*a54f4eaeSMogball     return success();
211*a54f4eaeSMogball   }
212*a54f4eaeSMogball 
213*a54f4eaeSMogball   auto vectorType = resultType.dyn_cast<VectorType>();
214*a54f4eaeSMogball   if (!vectorType)
215*a54f4eaeSMogball     return rewriter.notifyMatchFailure(op, "expected vector result type");
216*a54f4eaeSMogball 
217*a54f4eaeSMogball   return LLVM::detail::handleMultidimensionalVectors(
218*a54f4eaeSMogball       op.getOperation(), adaptor.getOperands(), *getTypeConverter(),
219*a54f4eaeSMogball       [&](Type llvm1DVectorTy, ValueRange operands) {
220*a54f4eaeSMogball         OpAdaptor adaptor(operands);
221*a54f4eaeSMogball         return rewriter.create<LLVM::FCmpOp>(
222*a54f4eaeSMogball             op.getLoc(), llvm1DVectorTy,
223*a54f4eaeSMogball             convertCmpPredicate<LLVM::FCmpPredicate>(op.getPredicate()),
224*a54f4eaeSMogball             adaptor.lhs(), adaptor.rhs());
225*a54f4eaeSMogball       },
226*a54f4eaeSMogball       rewriter);
227*a54f4eaeSMogball }
228*a54f4eaeSMogball 
229*a54f4eaeSMogball //===----------------------------------------------------------------------===//
230*a54f4eaeSMogball // Pass Definition
231*a54f4eaeSMogball //===----------------------------------------------------------------------===//
232*a54f4eaeSMogball 
233*a54f4eaeSMogball namespace {
234*a54f4eaeSMogball struct ConvertArithmeticToLLVMPass
235*a54f4eaeSMogball     : public ConvertArithmeticToLLVMBase<ConvertArithmeticToLLVMPass> {
236*a54f4eaeSMogball   ConvertArithmeticToLLVMPass() = default;
237*a54f4eaeSMogball 
238*a54f4eaeSMogball   void runOnFunction() override {
239*a54f4eaeSMogball     LLVMConversionTarget target(getContext());
240*a54f4eaeSMogball     RewritePatternSet patterns(&getContext());
241*a54f4eaeSMogball 
242*a54f4eaeSMogball     LowerToLLVMOptions options(&getContext());
243*a54f4eaeSMogball     if (indexBitwidth != kDeriveIndexBitwidthFromDataLayout)
244*a54f4eaeSMogball       options.overrideIndexBitwidth(indexBitwidth);
245*a54f4eaeSMogball 
246*a54f4eaeSMogball     LLVMTypeConverter converter(&getContext(), options);
247*a54f4eaeSMogball     mlir::arith::populateArithmeticToLLVMConversionPatterns(converter,
248*a54f4eaeSMogball                                                             patterns);
249*a54f4eaeSMogball 
250*a54f4eaeSMogball     if (failed(
251*a54f4eaeSMogball             applyPartialConversion(getFunction(), target, std::move(patterns))))
252*a54f4eaeSMogball       signalPassFailure();
253*a54f4eaeSMogball   }
254*a54f4eaeSMogball };
255*a54f4eaeSMogball } // end anonymous namespace
256*a54f4eaeSMogball 
257*a54f4eaeSMogball //===----------------------------------------------------------------------===//
258*a54f4eaeSMogball // Pattern Population
259*a54f4eaeSMogball //===----------------------------------------------------------------------===//
260*a54f4eaeSMogball 
261*a54f4eaeSMogball void mlir::arith::populateArithmeticToLLVMConversionPatterns(
262*a54f4eaeSMogball     LLVMTypeConverter &converter, RewritePatternSet &patterns) {
263*a54f4eaeSMogball   // clang-format off
264*a54f4eaeSMogball   patterns.add<
265*a54f4eaeSMogball     ConstantOpLowering,
266*a54f4eaeSMogball     AddIOpLowering,
267*a54f4eaeSMogball     SubIOpLowering,
268*a54f4eaeSMogball     MulIOpLowering,
269*a54f4eaeSMogball     DivUIOpLowering,
270*a54f4eaeSMogball     DivSIOpLowering,
271*a54f4eaeSMogball     RemUIOpLowering,
272*a54f4eaeSMogball     RemSIOpLowering,
273*a54f4eaeSMogball     AndIOpLowering,
274*a54f4eaeSMogball     OrIOpLowering,
275*a54f4eaeSMogball     XOrIOpLowering,
276*a54f4eaeSMogball     ShLIOpLowering,
277*a54f4eaeSMogball     ShRUIOpLowering,
278*a54f4eaeSMogball     ShRSIOpLowering,
279*a54f4eaeSMogball     NegFOpLowering,
280*a54f4eaeSMogball     AddFOpLowering,
281*a54f4eaeSMogball     SubFOpLowering,
282*a54f4eaeSMogball     MulFOpLowering,
283*a54f4eaeSMogball     DivFOpLowering,
284*a54f4eaeSMogball     RemFOpLowering,
285*a54f4eaeSMogball     ExtUIOpLowering,
286*a54f4eaeSMogball     ExtSIOpLowering,
287*a54f4eaeSMogball     ExtFOpLowering,
288*a54f4eaeSMogball     TruncIOpLowering,
289*a54f4eaeSMogball     TruncFOpLowering,
290*a54f4eaeSMogball     UIToFPOpLowering,
291*a54f4eaeSMogball     SIToFPOpLowering,
292*a54f4eaeSMogball     FPToUIOpLowering,
293*a54f4eaeSMogball     FPToSIOpLowering,
294*a54f4eaeSMogball     IndexCastOpLowering,
295*a54f4eaeSMogball     BitcastOpLowering,
296*a54f4eaeSMogball     CmpIOpLowering,
297*a54f4eaeSMogball     CmpFOpLowering
298*a54f4eaeSMogball   >(converter);
299*a54f4eaeSMogball   // clang-format on
300*a54f4eaeSMogball }
301*a54f4eaeSMogball 
302*a54f4eaeSMogball std::unique_ptr<Pass> mlir::arith::createConvertArithmeticToLLVMPass() {
303*a54f4eaeSMogball   return std::make_unique<ConvertArithmeticToLLVMPass>();
304*a54f4eaeSMogball }
305