1*26e59cc1SAlex Zinenko //===- MathToLLVM.cpp - Math to LLVM dialect conversion -------------------===//
2*26e59cc1SAlex Zinenko //
3*26e59cc1SAlex Zinenko // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4*26e59cc1SAlex Zinenko // See https://llvm.org/LICENSE.txt for license information.
5*26e59cc1SAlex Zinenko // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6*26e59cc1SAlex Zinenko //
7*26e59cc1SAlex Zinenko //===----------------------------------------------------------------------===//
8*26e59cc1SAlex Zinenko 
9*26e59cc1SAlex Zinenko #include "mlir/Conversion/MathToLLVM/MathToLLVM.h"
10*26e59cc1SAlex Zinenko #include "../PassDetail.h"
11*26e59cc1SAlex Zinenko #include "mlir/Conversion/LLVMCommon/ConversionTarget.h"
12*26e59cc1SAlex Zinenko #include "mlir/Conversion/LLVMCommon/Pattern.h"
13*26e59cc1SAlex Zinenko #include "mlir/Conversion/LLVMCommon/VectorPattern.h"
14*26e59cc1SAlex Zinenko #include "mlir/Dialect/LLVMIR/LLVMDialect.h"
15*26e59cc1SAlex Zinenko #include "mlir/Dialect/Math/IR/Math.h"
16*26e59cc1SAlex Zinenko #include "mlir/IR/TypeUtilities.h"
17*26e59cc1SAlex Zinenko 
18*26e59cc1SAlex Zinenko using namespace mlir;
19*26e59cc1SAlex Zinenko 
20*26e59cc1SAlex Zinenko namespace {
21*26e59cc1SAlex Zinenko using CosOpLowering = VectorConvertToLLVMPattern<math::CosOp, LLVM::CosOp>;
22*26e59cc1SAlex Zinenko using ExpOpLowering = VectorConvertToLLVMPattern<math::ExpOp, LLVM::ExpOp>;
23*26e59cc1SAlex Zinenko using Exp2OpLowering = VectorConvertToLLVMPattern<math::Exp2Op, LLVM::Exp2Op>;
24*26e59cc1SAlex Zinenko using Log10OpLowering =
25*26e59cc1SAlex Zinenko     VectorConvertToLLVMPattern<math::Log10Op, LLVM::Log10Op>;
26*26e59cc1SAlex Zinenko using Log2OpLowering = VectorConvertToLLVMPattern<math::Log2Op, LLVM::Log2Op>;
27*26e59cc1SAlex Zinenko using LogOpLowering = VectorConvertToLLVMPattern<math::LogOp, LLVM::LogOp>;
28*26e59cc1SAlex Zinenko using PowFOpLowering = VectorConvertToLLVMPattern<math::PowFOp, LLVM::PowOp>;
29*26e59cc1SAlex Zinenko using SinOpLowering = VectorConvertToLLVMPattern<math::SinOp, LLVM::SinOp>;
30*26e59cc1SAlex Zinenko using SqrtOpLowering = VectorConvertToLLVMPattern<math::SqrtOp, LLVM::SqrtOp>;
31*26e59cc1SAlex Zinenko 
32*26e59cc1SAlex Zinenko // A `expm1` is converted into `exp - 1`.
33*26e59cc1SAlex Zinenko struct ExpM1OpLowering : public ConvertOpToLLVMPattern<math::ExpM1Op> {
34*26e59cc1SAlex Zinenko   using ConvertOpToLLVMPattern<math::ExpM1Op>::ConvertOpToLLVMPattern;
35*26e59cc1SAlex Zinenko 
36*26e59cc1SAlex Zinenko   LogicalResult
37*26e59cc1SAlex Zinenko   matchAndRewrite(math::ExpM1Op op, ArrayRef<Value> operands,
38*26e59cc1SAlex Zinenko                   ConversionPatternRewriter &rewriter) const override {
39*26e59cc1SAlex Zinenko     math::ExpM1Op::Adaptor transformed(operands);
40*26e59cc1SAlex Zinenko     auto operandType = transformed.operand().getType();
41*26e59cc1SAlex Zinenko 
42*26e59cc1SAlex Zinenko     if (!operandType || !LLVM::isCompatibleType(operandType))
43*26e59cc1SAlex Zinenko       return failure();
44*26e59cc1SAlex Zinenko 
45*26e59cc1SAlex Zinenko     auto loc = op.getLoc();
46*26e59cc1SAlex Zinenko     auto resultType = op.getResult().getType();
47*26e59cc1SAlex Zinenko     auto floatType = getElementTypeOrSelf(resultType).cast<FloatType>();
48*26e59cc1SAlex Zinenko     auto floatOne = rewriter.getFloatAttr(floatType, 1.0);
49*26e59cc1SAlex Zinenko 
50*26e59cc1SAlex Zinenko     if (!operandType.isa<LLVM::LLVMArrayType>()) {
51*26e59cc1SAlex Zinenko       LLVM::ConstantOp one;
52*26e59cc1SAlex Zinenko       if (LLVM::isCompatibleVectorType(operandType)) {
53*26e59cc1SAlex Zinenko         one = rewriter.create<LLVM::ConstantOp>(
54*26e59cc1SAlex Zinenko             loc, operandType,
55*26e59cc1SAlex Zinenko             SplatElementsAttr::get(resultType.cast<ShapedType>(), floatOne));
56*26e59cc1SAlex Zinenko       } else {
57*26e59cc1SAlex Zinenko         one = rewriter.create<LLVM::ConstantOp>(loc, operandType, floatOne);
58*26e59cc1SAlex Zinenko       }
59*26e59cc1SAlex Zinenko       auto exp = rewriter.create<LLVM::ExpOp>(loc, transformed.operand());
60*26e59cc1SAlex Zinenko       rewriter.replaceOpWithNewOp<LLVM::FSubOp>(op, operandType, exp, one);
61*26e59cc1SAlex Zinenko       return success();
62*26e59cc1SAlex Zinenko     }
63*26e59cc1SAlex Zinenko 
64*26e59cc1SAlex Zinenko     auto vectorType = resultType.dyn_cast<VectorType>();
65*26e59cc1SAlex Zinenko     if (!vectorType)
66*26e59cc1SAlex Zinenko       return rewriter.notifyMatchFailure(op, "expected vector result type");
67*26e59cc1SAlex Zinenko 
68*26e59cc1SAlex Zinenko     return LLVM::detail::handleMultidimensionalVectors(
69*26e59cc1SAlex Zinenko         op.getOperation(), operands, *getTypeConverter(),
70*26e59cc1SAlex Zinenko         [&](Type llvm1DVectorTy, ValueRange operands) {
71*26e59cc1SAlex Zinenko           auto splatAttr = SplatElementsAttr::get(
72*26e59cc1SAlex Zinenko               mlir::VectorType::get(
73*26e59cc1SAlex Zinenko                   {LLVM::getVectorNumElements(llvm1DVectorTy).getFixedValue()},
74*26e59cc1SAlex Zinenko                   floatType),
75*26e59cc1SAlex Zinenko               floatOne);
76*26e59cc1SAlex Zinenko           auto one =
77*26e59cc1SAlex Zinenko               rewriter.create<LLVM::ConstantOp>(loc, llvm1DVectorTy, splatAttr);
78*26e59cc1SAlex Zinenko           auto exp =
79*26e59cc1SAlex Zinenko               rewriter.create<LLVM::ExpOp>(loc, llvm1DVectorTy, operands[0]);
80*26e59cc1SAlex Zinenko           return rewriter.create<LLVM::FSubOp>(loc, llvm1DVectorTy, exp, one);
81*26e59cc1SAlex Zinenko         },
82*26e59cc1SAlex Zinenko         rewriter);
83*26e59cc1SAlex Zinenko   }
84*26e59cc1SAlex Zinenko };
85*26e59cc1SAlex Zinenko 
86*26e59cc1SAlex Zinenko // A `log1p` is converted into `log(1 + ...)`.
87*26e59cc1SAlex Zinenko struct Log1pOpLowering : public ConvertOpToLLVMPattern<math::Log1pOp> {
88*26e59cc1SAlex Zinenko   using ConvertOpToLLVMPattern<math::Log1pOp>::ConvertOpToLLVMPattern;
89*26e59cc1SAlex Zinenko 
90*26e59cc1SAlex Zinenko   LogicalResult
91*26e59cc1SAlex Zinenko   matchAndRewrite(math::Log1pOp op, ArrayRef<Value> operands,
92*26e59cc1SAlex Zinenko                   ConversionPatternRewriter &rewriter) const override {
93*26e59cc1SAlex Zinenko     math::Log1pOp::Adaptor transformed(operands);
94*26e59cc1SAlex Zinenko     auto operandType = transformed.operand().getType();
95*26e59cc1SAlex Zinenko 
96*26e59cc1SAlex Zinenko     if (!operandType || !LLVM::isCompatibleType(operandType))
97*26e59cc1SAlex Zinenko       return rewriter.notifyMatchFailure(op, "unsupported operand type");
98*26e59cc1SAlex Zinenko 
99*26e59cc1SAlex Zinenko     auto loc = op.getLoc();
100*26e59cc1SAlex Zinenko     auto resultType = op.getResult().getType();
101*26e59cc1SAlex Zinenko     auto floatType = getElementTypeOrSelf(resultType).cast<FloatType>();
102*26e59cc1SAlex Zinenko     auto floatOne = rewriter.getFloatAttr(floatType, 1.0);
103*26e59cc1SAlex Zinenko 
104*26e59cc1SAlex Zinenko     if (!operandType.isa<LLVM::LLVMArrayType>()) {
105*26e59cc1SAlex Zinenko       LLVM::ConstantOp one =
106*26e59cc1SAlex Zinenko           LLVM::isCompatibleVectorType(operandType)
107*26e59cc1SAlex Zinenko               ? rewriter.create<LLVM::ConstantOp>(
108*26e59cc1SAlex Zinenko                     loc, operandType,
109*26e59cc1SAlex Zinenko                     SplatElementsAttr::get(resultType.cast<ShapedType>(),
110*26e59cc1SAlex Zinenko                                            floatOne))
111*26e59cc1SAlex Zinenko               : rewriter.create<LLVM::ConstantOp>(loc, operandType, floatOne);
112*26e59cc1SAlex Zinenko 
113*26e59cc1SAlex Zinenko       auto add = rewriter.create<LLVM::FAddOp>(loc, operandType, one,
114*26e59cc1SAlex Zinenko                                                transformed.operand());
115*26e59cc1SAlex Zinenko       rewriter.replaceOpWithNewOp<LLVM::LogOp>(op, operandType, add);
116*26e59cc1SAlex Zinenko       return success();
117*26e59cc1SAlex Zinenko     }
118*26e59cc1SAlex Zinenko 
119*26e59cc1SAlex Zinenko     auto vectorType = resultType.dyn_cast<VectorType>();
120*26e59cc1SAlex Zinenko     if (!vectorType)
121*26e59cc1SAlex Zinenko       return rewriter.notifyMatchFailure(op, "expected vector result type");
122*26e59cc1SAlex Zinenko 
123*26e59cc1SAlex Zinenko     return LLVM::detail::handleMultidimensionalVectors(
124*26e59cc1SAlex Zinenko         op.getOperation(), operands, *getTypeConverter(),
125*26e59cc1SAlex Zinenko         [&](Type llvm1DVectorTy, ValueRange operands) {
126*26e59cc1SAlex Zinenko           auto splatAttr = SplatElementsAttr::get(
127*26e59cc1SAlex Zinenko               mlir::VectorType::get(
128*26e59cc1SAlex Zinenko                   {LLVM::getVectorNumElements(llvm1DVectorTy).getFixedValue()},
129*26e59cc1SAlex Zinenko                   floatType),
130*26e59cc1SAlex Zinenko               floatOne);
131*26e59cc1SAlex Zinenko           auto one =
132*26e59cc1SAlex Zinenko               rewriter.create<LLVM::ConstantOp>(loc, llvm1DVectorTy, splatAttr);
133*26e59cc1SAlex Zinenko           auto add = rewriter.create<LLVM::FAddOp>(loc, llvm1DVectorTy, one,
134*26e59cc1SAlex Zinenko                                                    operands[0]);
135*26e59cc1SAlex Zinenko           return rewriter.create<LLVM::LogOp>(loc, llvm1DVectorTy, add);
136*26e59cc1SAlex Zinenko         },
137*26e59cc1SAlex Zinenko         rewriter);
138*26e59cc1SAlex Zinenko   }
139*26e59cc1SAlex Zinenko };
140*26e59cc1SAlex Zinenko 
141*26e59cc1SAlex Zinenko // A `rsqrt` is converted into `1 / sqrt`.
142*26e59cc1SAlex Zinenko struct RsqrtOpLowering : public ConvertOpToLLVMPattern<math::RsqrtOp> {
143*26e59cc1SAlex Zinenko   using ConvertOpToLLVMPattern<math::RsqrtOp>::ConvertOpToLLVMPattern;
144*26e59cc1SAlex Zinenko 
145*26e59cc1SAlex Zinenko   LogicalResult
146*26e59cc1SAlex Zinenko   matchAndRewrite(math::RsqrtOp op, ArrayRef<Value> operands,
147*26e59cc1SAlex Zinenko                   ConversionPatternRewriter &rewriter) const override {
148*26e59cc1SAlex Zinenko     math::RsqrtOp::Adaptor transformed(operands);
149*26e59cc1SAlex Zinenko     auto operandType = transformed.operand().getType();
150*26e59cc1SAlex Zinenko 
151*26e59cc1SAlex Zinenko     if (!operandType || !LLVM::isCompatibleType(operandType))
152*26e59cc1SAlex Zinenko       return failure();
153*26e59cc1SAlex Zinenko 
154*26e59cc1SAlex Zinenko     auto loc = op.getLoc();
155*26e59cc1SAlex Zinenko     auto resultType = op.getResult().getType();
156*26e59cc1SAlex Zinenko     auto floatType = getElementTypeOrSelf(resultType).cast<FloatType>();
157*26e59cc1SAlex Zinenko     auto floatOne = rewriter.getFloatAttr(floatType, 1.0);
158*26e59cc1SAlex Zinenko 
159*26e59cc1SAlex Zinenko     if (!operandType.isa<LLVM::LLVMArrayType>()) {
160*26e59cc1SAlex Zinenko       LLVM::ConstantOp one;
161*26e59cc1SAlex Zinenko       if (LLVM::isCompatibleVectorType(operandType)) {
162*26e59cc1SAlex Zinenko         one = rewriter.create<LLVM::ConstantOp>(
163*26e59cc1SAlex Zinenko             loc, operandType,
164*26e59cc1SAlex Zinenko             SplatElementsAttr::get(resultType.cast<ShapedType>(), floatOne));
165*26e59cc1SAlex Zinenko       } else {
166*26e59cc1SAlex Zinenko         one = rewriter.create<LLVM::ConstantOp>(loc, operandType, floatOne);
167*26e59cc1SAlex Zinenko       }
168*26e59cc1SAlex Zinenko       auto sqrt = rewriter.create<LLVM::SqrtOp>(loc, transformed.operand());
169*26e59cc1SAlex Zinenko       rewriter.replaceOpWithNewOp<LLVM::FDivOp>(op, operandType, one, sqrt);
170*26e59cc1SAlex Zinenko       return success();
171*26e59cc1SAlex Zinenko     }
172*26e59cc1SAlex Zinenko 
173*26e59cc1SAlex Zinenko     auto vectorType = resultType.dyn_cast<VectorType>();
174*26e59cc1SAlex Zinenko     if (!vectorType)
175*26e59cc1SAlex Zinenko       return failure();
176*26e59cc1SAlex Zinenko 
177*26e59cc1SAlex Zinenko     return LLVM::detail::handleMultidimensionalVectors(
178*26e59cc1SAlex Zinenko         op.getOperation(), operands, *getTypeConverter(),
179*26e59cc1SAlex Zinenko         [&](Type llvm1DVectorTy, ValueRange operands) {
180*26e59cc1SAlex Zinenko           auto splatAttr = SplatElementsAttr::get(
181*26e59cc1SAlex Zinenko               mlir::VectorType::get(
182*26e59cc1SAlex Zinenko                   {LLVM::getVectorNumElements(llvm1DVectorTy).getFixedValue()},
183*26e59cc1SAlex Zinenko                   floatType),
184*26e59cc1SAlex Zinenko               floatOne);
185*26e59cc1SAlex Zinenko           auto one =
186*26e59cc1SAlex Zinenko               rewriter.create<LLVM::ConstantOp>(loc, llvm1DVectorTy, splatAttr);
187*26e59cc1SAlex Zinenko           auto sqrt =
188*26e59cc1SAlex Zinenko               rewriter.create<LLVM::SqrtOp>(loc, llvm1DVectorTy, operands[0]);
189*26e59cc1SAlex Zinenko           return rewriter.create<LLVM::FDivOp>(loc, llvm1DVectorTy, one, sqrt);
190*26e59cc1SAlex Zinenko         },
191*26e59cc1SAlex Zinenko         rewriter);
192*26e59cc1SAlex Zinenko   }
193*26e59cc1SAlex Zinenko };
194*26e59cc1SAlex Zinenko 
195*26e59cc1SAlex Zinenko struct ConvertMathToLLVMPass
196*26e59cc1SAlex Zinenko     : public ConvertMathToLLVMBase<ConvertMathToLLVMPass> {
197*26e59cc1SAlex Zinenko   ConvertMathToLLVMPass() = default;
198*26e59cc1SAlex Zinenko 
199*26e59cc1SAlex Zinenko   void runOnFunction() override {
200*26e59cc1SAlex Zinenko     RewritePatternSet patterns(&getContext());
201*26e59cc1SAlex Zinenko     LLVMTypeConverter converter(&getContext());
202*26e59cc1SAlex Zinenko     populateMathToLLVMConversionPatterns(converter, patterns);
203*26e59cc1SAlex Zinenko     LLVMConversionTarget target(getContext());
204*26e59cc1SAlex Zinenko     target.addLegalOp<LLVM::DialectCastOp>();
205*26e59cc1SAlex Zinenko     if (failed(
206*26e59cc1SAlex Zinenko             applyPartialConversion(getFunction(), target, std::move(patterns))))
207*26e59cc1SAlex Zinenko       signalPassFailure();
208*26e59cc1SAlex Zinenko   }
209*26e59cc1SAlex Zinenko };
210*26e59cc1SAlex Zinenko } // namespace
211*26e59cc1SAlex Zinenko 
212*26e59cc1SAlex Zinenko void mlir::populateMathToLLVMConversionPatterns(LLVMTypeConverter &converter,
213*26e59cc1SAlex Zinenko                                                 RewritePatternSet &patterns) {
214*26e59cc1SAlex Zinenko   // clang-format off
215*26e59cc1SAlex Zinenko   patterns.add<
216*26e59cc1SAlex Zinenko     CosOpLowering,
217*26e59cc1SAlex Zinenko     ExpOpLowering,
218*26e59cc1SAlex Zinenko     Exp2OpLowering,
219*26e59cc1SAlex Zinenko     ExpM1OpLowering,
220*26e59cc1SAlex Zinenko     Log10OpLowering,
221*26e59cc1SAlex Zinenko     Log1pOpLowering,
222*26e59cc1SAlex Zinenko     Log2OpLowering,
223*26e59cc1SAlex Zinenko     LogOpLowering,
224*26e59cc1SAlex Zinenko     PowFOpLowering,
225*26e59cc1SAlex Zinenko     RsqrtOpLowering,
226*26e59cc1SAlex Zinenko     SinOpLowering,
227*26e59cc1SAlex Zinenko     SqrtOpLowering
228*26e59cc1SAlex Zinenko   >(converter);
229*26e59cc1SAlex Zinenko   // clang-format on
230*26e59cc1SAlex Zinenko }
231*26e59cc1SAlex Zinenko 
232*26e59cc1SAlex Zinenko std::unique_ptr<Pass> mlir::createConvertMathToLLVMPass() {
233*26e59cc1SAlex Zinenko   return std::make_unique<ConvertMathToLLVMPass>();
234*26e59cc1SAlex Zinenko }
235