1 //===- MathToLLVM.cpp - Math to LLVM dialect conversion -------------------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 
9 #include "mlir/Conversion/MathToLLVM/MathToLLVM.h"
10 #include "../PassDetail.h"
11 #include "mlir/Conversion/LLVMCommon/ConversionTarget.h"
12 #include "mlir/Conversion/LLVMCommon/Pattern.h"
13 #include "mlir/Conversion/LLVMCommon/VectorPattern.h"
14 #include "mlir/Dialect/LLVMIR/LLVMDialect.h"
15 #include "mlir/Dialect/Math/IR/Math.h"
16 #include "mlir/IR/TypeUtilities.h"
17 
18 using namespace mlir;
19 
20 namespace {
21 using AbsOpLowering = VectorConvertToLLVMPattern<math::AbsOp, LLVM::FAbsOp>;
22 using CeilOpLowering = VectorConvertToLLVMPattern<math::CeilOp, LLVM::FCeilOp>;
23 using CopySignOpLowering =
24     VectorConvertToLLVMPattern<math::CopySignOp, LLVM::CopySignOp>;
25 using CosOpLowering = VectorConvertToLLVMPattern<math::CosOp, LLVM::CosOp>;
26 using CtPopFOpLowering =
27     VectorConvertToLLVMPattern<math::CtPopOp, LLVM::CtPopOp>;
28 using Exp2OpLowering = VectorConvertToLLVMPattern<math::Exp2Op, LLVM::Exp2Op>;
29 using ExpOpLowering = VectorConvertToLLVMPattern<math::ExpOp, LLVM::ExpOp>;
30 using FloorOpLowering =
31     VectorConvertToLLVMPattern<math::FloorOp, LLVM::FFloorOp>;
32 using FmaOpLowering = VectorConvertToLLVMPattern<math::FmaOp, LLVM::FMAOp>;
33 using Log10OpLowering =
34     VectorConvertToLLVMPattern<math::Log10Op, LLVM::Log10Op>;
35 using Log2OpLowering = VectorConvertToLLVMPattern<math::Log2Op, LLVM::Log2Op>;
36 using LogOpLowering = VectorConvertToLLVMPattern<math::LogOp, LLVM::LogOp>;
37 using PowFOpLowering = VectorConvertToLLVMPattern<math::PowFOp, LLVM::PowOp>;
38 using RoundOpLowering =
39     VectorConvertToLLVMPattern<math::RoundOp, LLVM::RoundOp>;
40 using SinOpLowering = VectorConvertToLLVMPattern<math::SinOp, LLVM::SinOp>;
41 using SqrtOpLowering = VectorConvertToLLVMPattern<math::SqrtOp, LLVM::SqrtOp>;
42 
43 // A `CtLz/CtTz(a)` is converted into `CtLz/CtTz(a, false)`.
44 template <typename MathOp, typename LLVMOp>
45 struct CountOpLowering : public ConvertOpToLLVMPattern<MathOp> {
46   using ConvertOpToLLVMPattern<MathOp>::ConvertOpToLLVMPattern;
47   using Super = CountOpLowering<MathOp, LLVMOp>;
48 
49   LogicalResult
matchAndRewrite__anoncde9de6d0111::CountOpLowering50   matchAndRewrite(MathOp op, typename MathOp::Adaptor adaptor,
51                   ConversionPatternRewriter &rewriter) const override {
52     auto operandType = adaptor.getOperand().getType();
53 
54     if (!operandType || !LLVM::isCompatibleType(operandType))
55       return failure();
56 
57     auto loc = op.getLoc();
58     auto resultType = op.getResult().getType();
59     auto boolType = rewriter.getIntegerType(1);
60     auto boolZero = rewriter.getIntegerAttr(boolType, 0);
61 
62     if (!operandType.template isa<LLVM::LLVMArrayType>()) {
63       LLVM::ConstantOp zero =
64           rewriter.create<LLVM::ConstantOp>(loc, boolType, boolZero);
65       rewriter.replaceOpWithNewOp<LLVMOp>(op, resultType, adaptor.getOperand(),
66                                           zero);
67       return success();
68     }
69 
70     auto vectorType = resultType.template dyn_cast<VectorType>();
71     if (!vectorType)
72       return failure();
73 
74     return LLVM::detail::handleMultidimensionalVectors(
75         op.getOperation(), adaptor.getOperands(), *this->getTypeConverter(),
76         [&](Type llvm1DVectorTy, ValueRange operands) {
77           LLVM::ConstantOp zero =
78               rewriter.create<LLVM::ConstantOp>(loc, boolType, boolZero);
79           return rewriter.create<LLVMOp>(loc, llvm1DVectorTy, operands[0],
80                                          zero);
81         },
82         rewriter);
83   }
84 };
85 
86 using CountLeadingZerosOpLowering =
87     CountOpLowering<math::CountLeadingZerosOp, LLVM::CountLeadingZerosOp>;
88 using CountTrailingZerosOpLowering =
89     CountOpLowering<math::CountTrailingZerosOp, LLVM::CountTrailingZerosOp>;
90 
91 // A `expm1` is converted into `exp - 1`.
92 struct ExpM1OpLowering : public ConvertOpToLLVMPattern<math::ExpM1Op> {
93   using ConvertOpToLLVMPattern<math::ExpM1Op>::ConvertOpToLLVMPattern;
94 
95   LogicalResult
matchAndRewrite__anoncde9de6d0111::ExpM1OpLowering96   matchAndRewrite(math::ExpM1Op op, OpAdaptor adaptor,
97                   ConversionPatternRewriter &rewriter) const override {
98     auto operandType = adaptor.getOperand().getType();
99 
100     if (!operandType || !LLVM::isCompatibleType(operandType))
101       return failure();
102 
103     auto loc = op.getLoc();
104     auto resultType = op.getResult().getType();
105     auto floatType = getElementTypeOrSelf(resultType).cast<FloatType>();
106     auto floatOne = rewriter.getFloatAttr(floatType, 1.0);
107 
108     if (!operandType.isa<LLVM::LLVMArrayType>()) {
109       LLVM::ConstantOp one;
110       if (LLVM::isCompatibleVectorType(operandType)) {
111         one = rewriter.create<LLVM::ConstantOp>(
112             loc, operandType,
113             SplatElementsAttr::get(resultType.cast<ShapedType>(), floatOne));
114       } else {
115         one = rewriter.create<LLVM::ConstantOp>(loc, operandType, floatOne);
116       }
117       auto exp = rewriter.create<LLVM::ExpOp>(loc, adaptor.getOperand());
118       rewriter.replaceOpWithNewOp<LLVM::FSubOp>(op, operandType, exp, one);
119       return success();
120     }
121 
122     auto vectorType = resultType.dyn_cast<VectorType>();
123     if (!vectorType)
124       return rewriter.notifyMatchFailure(op, "expected vector result type");
125 
126     return LLVM::detail::handleMultidimensionalVectors(
127         op.getOperation(), adaptor.getOperands(), *getTypeConverter(),
128         [&](Type llvm1DVectorTy, ValueRange operands) {
129           auto splatAttr = SplatElementsAttr::get(
130               mlir::VectorType::get(
131                   {LLVM::getVectorNumElements(llvm1DVectorTy).getFixedValue()},
132                   floatType),
133               floatOne);
134           auto one =
135               rewriter.create<LLVM::ConstantOp>(loc, llvm1DVectorTy, splatAttr);
136           auto exp =
137               rewriter.create<LLVM::ExpOp>(loc, llvm1DVectorTy, operands[0]);
138           return rewriter.create<LLVM::FSubOp>(loc, llvm1DVectorTy, exp, one);
139         },
140         rewriter);
141   }
142 };
143 
144 // A `log1p` is converted into `log(1 + ...)`.
145 struct Log1pOpLowering : public ConvertOpToLLVMPattern<math::Log1pOp> {
146   using ConvertOpToLLVMPattern<math::Log1pOp>::ConvertOpToLLVMPattern;
147 
148   LogicalResult
matchAndRewrite__anoncde9de6d0111::Log1pOpLowering149   matchAndRewrite(math::Log1pOp op, OpAdaptor adaptor,
150                   ConversionPatternRewriter &rewriter) const override {
151     auto operandType = adaptor.getOperand().getType();
152 
153     if (!operandType || !LLVM::isCompatibleType(operandType))
154       return rewriter.notifyMatchFailure(op, "unsupported operand type");
155 
156     auto loc = op.getLoc();
157     auto resultType = op.getResult().getType();
158     auto floatType = getElementTypeOrSelf(resultType).cast<FloatType>();
159     auto floatOne = rewriter.getFloatAttr(floatType, 1.0);
160 
161     if (!operandType.isa<LLVM::LLVMArrayType>()) {
162       LLVM::ConstantOp one =
163           LLVM::isCompatibleVectorType(operandType)
164               ? rewriter.create<LLVM::ConstantOp>(
165                     loc, operandType,
166                     SplatElementsAttr::get(resultType.cast<ShapedType>(),
167                                            floatOne))
168               : rewriter.create<LLVM::ConstantOp>(loc, operandType, floatOne);
169 
170       auto add = rewriter.create<LLVM::FAddOp>(loc, operandType, one,
171                                                adaptor.getOperand());
172       rewriter.replaceOpWithNewOp<LLVM::LogOp>(op, operandType, add);
173       return success();
174     }
175 
176     auto vectorType = resultType.dyn_cast<VectorType>();
177     if (!vectorType)
178       return rewriter.notifyMatchFailure(op, "expected vector result type");
179 
180     return LLVM::detail::handleMultidimensionalVectors(
181         op.getOperation(), adaptor.getOperands(), *getTypeConverter(),
182         [&](Type llvm1DVectorTy, ValueRange operands) {
183           auto splatAttr = SplatElementsAttr::get(
184               mlir::VectorType::get(
185                   {LLVM::getVectorNumElements(llvm1DVectorTy).getFixedValue()},
186                   floatType),
187               floatOne);
188           auto one =
189               rewriter.create<LLVM::ConstantOp>(loc, llvm1DVectorTy, splatAttr);
190           auto add = rewriter.create<LLVM::FAddOp>(loc, llvm1DVectorTy, one,
191                                                    operands[0]);
192           return rewriter.create<LLVM::LogOp>(loc, llvm1DVectorTy, add);
193         },
194         rewriter);
195   }
196 };
197 
198 // A `rsqrt` is converted into `1 / sqrt`.
199 struct RsqrtOpLowering : public ConvertOpToLLVMPattern<math::RsqrtOp> {
200   using ConvertOpToLLVMPattern<math::RsqrtOp>::ConvertOpToLLVMPattern;
201 
202   LogicalResult
matchAndRewrite__anoncde9de6d0111::RsqrtOpLowering203   matchAndRewrite(math::RsqrtOp op, OpAdaptor adaptor,
204                   ConversionPatternRewriter &rewriter) const override {
205     auto operandType = adaptor.getOperand().getType();
206 
207     if (!operandType || !LLVM::isCompatibleType(operandType))
208       return failure();
209 
210     auto loc = op.getLoc();
211     auto resultType = op.getResult().getType();
212     auto floatType = getElementTypeOrSelf(resultType).cast<FloatType>();
213     auto floatOne = rewriter.getFloatAttr(floatType, 1.0);
214 
215     if (!operandType.isa<LLVM::LLVMArrayType>()) {
216       LLVM::ConstantOp one;
217       if (LLVM::isCompatibleVectorType(operandType)) {
218         one = rewriter.create<LLVM::ConstantOp>(
219             loc, operandType,
220             SplatElementsAttr::get(resultType.cast<ShapedType>(), floatOne));
221       } else {
222         one = rewriter.create<LLVM::ConstantOp>(loc, operandType, floatOne);
223       }
224       auto sqrt = rewriter.create<LLVM::SqrtOp>(loc, adaptor.getOperand());
225       rewriter.replaceOpWithNewOp<LLVM::FDivOp>(op, operandType, one, sqrt);
226       return success();
227     }
228 
229     auto vectorType = resultType.dyn_cast<VectorType>();
230     if (!vectorType)
231       return failure();
232 
233     return LLVM::detail::handleMultidimensionalVectors(
234         op.getOperation(), adaptor.getOperands(), *getTypeConverter(),
235         [&](Type llvm1DVectorTy, ValueRange operands) {
236           auto splatAttr = SplatElementsAttr::get(
237               mlir::VectorType::get(
238                   {LLVM::getVectorNumElements(llvm1DVectorTy).getFixedValue()},
239                   floatType),
240               floatOne);
241           auto one =
242               rewriter.create<LLVM::ConstantOp>(loc, llvm1DVectorTy, splatAttr);
243           auto sqrt =
244               rewriter.create<LLVM::SqrtOp>(loc, llvm1DVectorTy, operands[0]);
245           return rewriter.create<LLVM::FDivOp>(loc, llvm1DVectorTy, one, sqrt);
246         },
247         rewriter);
248   }
249 };
250 
251 struct ConvertMathToLLVMPass
252     : public ConvertMathToLLVMBase<ConvertMathToLLVMPass> {
253   ConvertMathToLLVMPass() = default;
254 
runOnOperation__anoncde9de6d0111::ConvertMathToLLVMPass255   void runOnOperation() override {
256     RewritePatternSet patterns(&getContext());
257     LLVMTypeConverter converter(&getContext());
258     populateMathToLLVMConversionPatterns(converter, patterns);
259     LLVMConversionTarget target(getContext());
260     if (failed(applyPartialConversion(getOperation(), target,
261                                       std::move(patterns))))
262       signalPassFailure();
263   }
264 };
265 } // namespace
266 
populateMathToLLVMConversionPatterns(LLVMTypeConverter & converter,RewritePatternSet & patterns)267 void mlir::populateMathToLLVMConversionPatterns(LLVMTypeConverter &converter,
268                                                 RewritePatternSet &patterns) {
269   // clang-format off
270   patterns.add<
271     AbsOpLowering,
272     CeilOpLowering,
273     CopySignOpLowering,
274     CosOpLowering,
275     CountLeadingZerosOpLowering,
276     CountTrailingZerosOpLowering,
277     CtPopFOpLowering,
278     Exp2OpLowering,
279     ExpM1OpLowering,
280     ExpOpLowering,
281     FloorOpLowering,
282     FmaOpLowering,
283     Log10OpLowering,
284     Log1pOpLowering,
285     Log2OpLowering,
286     LogOpLowering,
287     PowFOpLowering,
288     RoundOpLowering,
289     RsqrtOpLowering,
290     SinOpLowering,
291     SqrtOpLowering
292   >(converter);
293   // clang-format on
294 }
295 
createConvertMathToLLVMPass()296 std::unique_ptr<Pass> mlir::createConvertMathToLLVMPass() {
297   return std::make_unique<ConvertMathToLLVMPass>();
298 }
299