1 //===- MathToLLVM.cpp - Math to LLVM dialect conversion -------------------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 
9 #include "mlir/Conversion/MathToLLVM/MathToLLVM.h"
10 #include "../PassDetail.h"
11 #include "mlir/Conversion/LLVMCommon/ConversionTarget.h"
12 #include "mlir/Conversion/LLVMCommon/Pattern.h"
13 #include "mlir/Conversion/LLVMCommon/VectorPattern.h"
14 #include "mlir/Dialect/LLVMIR/LLVMDialect.h"
15 #include "mlir/Dialect/Math/IR/Math.h"
16 #include "mlir/IR/TypeUtilities.h"
17 
18 using namespace mlir;
19 
20 namespace {
21 using AbsOpLowering = VectorConvertToLLVMPattern<math::AbsOp, LLVM::FAbsOp>;
22 using CeilOpLowering = VectorConvertToLLVMPattern<math::CeilOp, LLVM::FCeilOp>;
23 using CopySignOpLowering =
24     VectorConvertToLLVMPattern<math::CopySignOp, LLVM::CopySignOp>;
25 using CosOpLowering = VectorConvertToLLVMPattern<math::CosOp, LLVM::CosOp>;
26 using CtPopFOpLowering =
27     VectorConvertToLLVMPattern<math::CtPopOp, LLVM::CtPopOp>;
28 using ExpOpLowering = VectorConvertToLLVMPattern<math::ExpOp, LLVM::ExpOp>;
29 using Exp2OpLowering = VectorConvertToLLVMPattern<math::Exp2Op, LLVM::Exp2Op>;
30 using FloorOpLowering =
31     VectorConvertToLLVMPattern<math::FloorOp, LLVM::FFloorOp>;
32 using FmaOpLowering = VectorConvertToLLVMPattern<math::FmaOp, LLVM::FMAOp>;
33 using Log10OpLowering =
34     VectorConvertToLLVMPattern<math::Log10Op, LLVM::Log10Op>;
35 using Log2OpLowering = VectorConvertToLLVMPattern<math::Log2Op, LLVM::Log2Op>;
36 using LogOpLowering = VectorConvertToLLVMPattern<math::LogOp, LLVM::LogOp>;
37 using PowFOpLowering = VectorConvertToLLVMPattern<math::PowFOp, LLVM::PowOp>;
38 using SinOpLowering = VectorConvertToLLVMPattern<math::SinOp, LLVM::SinOp>;
39 using SqrtOpLowering = VectorConvertToLLVMPattern<math::SqrtOp, LLVM::SqrtOp>;
40 
41 // A `CtLz/CtTz(a)` is converted into `CtLz/CtTz(a, false)`.
42 template <typename MathOp, typename LLVMOp>
43 struct CountOpLowering : public ConvertOpToLLVMPattern<MathOp> {
44   using ConvertOpToLLVMPattern<MathOp>::ConvertOpToLLVMPattern;
45   using Super = CountOpLowering<MathOp, LLVMOp>;
46 
47   LogicalResult
48   matchAndRewrite(MathOp op, typename MathOp::Adaptor adaptor,
49                   ConversionPatternRewriter &rewriter) const override {
50     auto operandType = adaptor.getOperand().getType();
51 
52     if (!operandType || !LLVM::isCompatibleType(operandType))
53       return failure();
54 
55     auto loc = op.getLoc();
56     auto resultType = op.getResult().getType();
57     auto boolType = rewriter.getIntegerType(1);
58     auto boolZero = rewriter.getIntegerAttr(boolType, 0);
59 
60     if (!operandType.template isa<LLVM::LLVMArrayType>()) {
61       LLVM::ConstantOp zero =
62           rewriter.create<LLVM::ConstantOp>(loc, boolType, boolZero);
63       rewriter.replaceOpWithNewOp<LLVMOp>(op, resultType, adaptor.getOperand(),
64                                           zero);
65       return success();
66     }
67 
68     auto vectorType = resultType.template dyn_cast<VectorType>();
69     if (!vectorType)
70       return failure();
71 
72     return LLVM::detail::handleMultidimensionalVectors(
73         op.getOperation(), adaptor.getOperands(), *this->getTypeConverter(),
74         [&](Type llvm1DVectorTy, ValueRange operands) {
75           LLVM::ConstantOp zero =
76               rewriter.create<LLVM::ConstantOp>(loc, boolType, boolZero);
77           return rewriter.replaceOpWithNewOp<LLVMOp>(op, llvm1DVectorTy,
78                                                      operands[0], zero);
79         },
80         rewriter);
81   }
82 };
83 
84 using CountLeadingZerosOpLowering =
85     CountOpLowering<math::CountLeadingZerosOp, LLVM::CountLeadingZerosOp>;
86 using CountTrailingZerosOpLowering =
87     CountOpLowering<math::CountTrailingZerosOp, LLVM::CountTrailingZerosOp>;
88 
89 // A `expm1` is converted into `exp - 1`.
90 struct ExpM1OpLowering : public ConvertOpToLLVMPattern<math::ExpM1Op> {
91   using ConvertOpToLLVMPattern<math::ExpM1Op>::ConvertOpToLLVMPattern;
92 
93   LogicalResult
94   matchAndRewrite(math::ExpM1Op op, OpAdaptor adaptor,
95                   ConversionPatternRewriter &rewriter) const override {
96     auto operandType = adaptor.getOperand().getType();
97 
98     if (!operandType || !LLVM::isCompatibleType(operandType))
99       return failure();
100 
101     auto loc = op.getLoc();
102     auto resultType = op.getResult().getType();
103     auto floatType = getElementTypeOrSelf(resultType).cast<FloatType>();
104     auto floatOne = rewriter.getFloatAttr(floatType, 1.0);
105 
106     if (!operandType.isa<LLVM::LLVMArrayType>()) {
107       LLVM::ConstantOp one;
108       if (LLVM::isCompatibleVectorType(operandType)) {
109         one = rewriter.create<LLVM::ConstantOp>(
110             loc, operandType,
111             SplatElementsAttr::get(resultType.cast<ShapedType>(), floatOne));
112       } else {
113         one = rewriter.create<LLVM::ConstantOp>(loc, operandType, floatOne);
114       }
115       auto exp = rewriter.create<LLVM::ExpOp>(loc, adaptor.getOperand());
116       rewriter.replaceOpWithNewOp<LLVM::FSubOp>(op, operandType, exp, one);
117       return success();
118     }
119 
120     auto vectorType = resultType.dyn_cast<VectorType>();
121     if (!vectorType)
122       return rewriter.notifyMatchFailure(op, "expected vector result type");
123 
124     return LLVM::detail::handleMultidimensionalVectors(
125         op.getOperation(), adaptor.getOperands(), *getTypeConverter(),
126         [&](Type llvm1DVectorTy, ValueRange operands) {
127           auto splatAttr = SplatElementsAttr::get(
128               mlir::VectorType::get(
129                   {LLVM::getVectorNumElements(llvm1DVectorTy).getFixedValue()},
130                   floatType),
131               floatOne);
132           auto one =
133               rewriter.create<LLVM::ConstantOp>(loc, llvm1DVectorTy, splatAttr);
134           auto exp =
135               rewriter.create<LLVM::ExpOp>(loc, llvm1DVectorTy, operands[0]);
136           return rewriter.create<LLVM::FSubOp>(loc, llvm1DVectorTy, exp, one);
137         },
138         rewriter);
139   }
140 };
141 
142 // A `log1p` is converted into `log(1 + ...)`.
143 struct Log1pOpLowering : public ConvertOpToLLVMPattern<math::Log1pOp> {
144   using ConvertOpToLLVMPattern<math::Log1pOp>::ConvertOpToLLVMPattern;
145 
146   LogicalResult
147   matchAndRewrite(math::Log1pOp op, OpAdaptor adaptor,
148                   ConversionPatternRewriter &rewriter) const override {
149     auto operandType = adaptor.getOperand().getType();
150 
151     if (!operandType || !LLVM::isCompatibleType(operandType))
152       return rewriter.notifyMatchFailure(op, "unsupported operand type");
153 
154     auto loc = op.getLoc();
155     auto resultType = op.getResult().getType();
156     auto floatType = getElementTypeOrSelf(resultType).cast<FloatType>();
157     auto floatOne = rewriter.getFloatAttr(floatType, 1.0);
158 
159     if (!operandType.isa<LLVM::LLVMArrayType>()) {
160       LLVM::ConstantOp one =
161           LLVM::isCompatibleVectorType(operandType)
162               ? rewriter.create<LLVM::ConstantOp>(
163                     loc, operandType,
164                     SplatElementsAttr::get(resultType.cast<ShapedType>(),
165                                            floatOne))
166               : rewriter.create<LLVM::ConstantOp>(loc, operandType, floatOne);
167 
168       auto add = rewriter.create<LLVM::FAddOp>(loc, operandType, one,
169                                                adaptor.getOperand());
170       rewriter.replaceOpWithNewOp<LLVM::LogOp>(op, operandType, add);
171       return success();
172     }
173 
174     auto vectorType = resultType.dyn_cast<VectorType>();
175     if (!vectorType)
176       return rewriter.notifyMatchFailure(op, "expected vector result type");
177 
178     return LLVM::detail::handleMultidimensionalVectors(
179         op.getOperation(), adaptor.getOperands(), *getTypeConverter(),
180         [&](Type llvm1DVectorTy, ValueRange operands) {
181           auto splatAttr = SplatElementsAttr::get(
182               mlir::VectorType::get(
183                   {LLVM::getVectorNumElements(llvm1DVectorTy).getFixedValue()},
184                   floatType),
185               floatOne);
186           auto one =
187               rewriter.create<LLVM::ConstantOp>(loc, llvm1DVectorTy, splatAttr);
188           auto add = rewriter.create<LLVM::FAddOp>(loc, llvm1DVectorTy, one,
189                                                    operands[0]);
190           return rewriter.create<LLVM::LogOp>(loc, llvm1DVectorTy, add);
191         },
192         rewriter);
193   }
194 };
195 
196 // A `rsqrt` is converted into `1 / sqrt`.
197 struct RsqrtOpLowering : public ConvertOpToLLVMPattern<math::RsqrtOp> {
198   using ConvertOpToLLVMPattern<math::RsqrtOp>::ConvertOpToLLVMPattern;
199 
200   LogicalResult
201   matchAndRewrite(math::RsqrtOp op, OpAdaptor adaptor,
202                   ConversionPatternRewriter &rewriter) const override {
203     auto operandType = adaptor.getOperand().getType();
204 
205     if (!operandType || !LLVM::isCompatibleType(operandType))
206       return failure();
207 
208     auto loc = op.getLoc();
209     auto resultType = op.getResult().getType();
210     auto floatType = getElementTypeOrSelf(resultType).cast<FloatType>();
211     auto floatOne = rewriter.getFloatAttr(floatType, 1.0);
212 
213     if (!operandType.isa<LLVM::LLVMArrayType>()) {
214       LLVM::ConstantOp one;
215       if (LLVM::isCompatibleVectorType(operandType)) {
216         one = rewriter.create<LLVM::ConstantOp>(
217             loc, operandType,
218             SplatElementsAttr::get(resultType.cast<ShapedType>(), floatOne));
219       } else {
220         one = rewriter.create<LLVM::ConstantOp>(loc, operandType, floatOne);
221       }
222       auto sqrt = rewriter.create<LLVM::SqrtOp>(loc, adaptor.getOperand());
223       rewriter.replaceOpWithNewOp<LLVM::FDivOp>(op, operandType, one, sqrt);
224       return success();
225     }
226 
227     auto vectorType = resultType.dyn_cast<VectorType>();
228     if (!vectorType)
229       return failure();
230 
231     return LLVM::detail::handleMultidimensionalVectors(
232         op.getOperation(), adaptor.getOperands(), *getTypeConverter(),
233         [&](Type llvm1DVectorTy, ValueRange operands) {
234           auto splatAttr = SplatElementsAttr::get(
235               mlir::VectorType::get(
236                   {LLVM::getVectorNumElements(llvm1DVectorTy).getFixedValue()},
237                   floatType),
238               floatOne);
239           auto one =
240               rewriter.create<LLVM::ConstantOp>(loc, llvm1DVectorTy, splatAttr);
241           auto sqrt =
242               rewriter.create<LLVM::SqrtOp>(loc, llvm1DVectorTy, operands[0]);
243           return rewriter.create<LLVM::FDivOp>(loc, llvm1DVectorTy, one, sqrt);
244         },
245         rewriter);
246   }
247 };
248 
249 struct ConvertMathToLLVMPass
250     : public ConvertMathToLLVMBase<ConvertMathToLLVMPass> {
251   ConvertMathToLLVMPass() = default;
252 
253   void runOnFunction() override {
254     RewritePatternSet patterns(&getContext());
255     LLVMTypeConverter converter(&getContext());
256     populateMathToLLVMConversionPatterns(converter, patterns);
257     LLVMConversionTarget target(getContext());
258     if (failed(
259             applyPartialConversion(getFunction(), target, std::move(patterns))))
260       signalPassFailure();
261   }
262 };
263 } // namespace
264 
265 void mlir::populateMathToLLVMConversionPatterns(LLVMTypeConverter &converter,
266                                                 RewritePatternSet &patterns) {
267   // clang-format off
268   patterns.add<
269     AbsOpLowering,
270     CeilOpLowering,
271     CopySignOpLowering,
272     CosOpLowering,
273     CountLeadingZerosOpLowering,
274     CountTrailingZerosOpLowering,
275     CtPopFOpLowering,
276     ExpOpLowering,
277     Exp2OpLowering,
278     ExpM1OpLowering,
279     FloorOpLowering,
280     FmaOpLowering,
281     Log10OpLowering,
282     Log1pOpLowering,
283     Log2OpLowering,
284     LogOpLowering,
285     PowFOpLowering,
286     RsqrtOpLowering,
287     SinOpLowering,
288     SqrtOpLowering
289   >(converter);
290   // clang-format on
291 }
292 
/// Creates a new `ConvertMathToLLVMPass` instance; ownership is transferred to
/// the caller through the returned `std::unique_ptr`.
std::unique_ptr<Pass> mlir::createConvertMathToLLVMPass() {
  return std::make_unique<ConvertMathToLLVMPass>();
}
296