1 //===- MathToLLVM.cpp - Math to LLVM dialect conversion -------------------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 
9 #include "mlir/Conversion/MathToLLVM/MathToLLVM.h"
10 #include "../PassDetail.h"
11 #include "mlir/Conversion/LLVMCommon/ConversionTarget.h"
12 #include "mlir/Conversion/LLVMCommon/Pattern.h"
13 #include "mlir/Conversion/LLVMCommon/VectorPattern.h"
14 #include "mlir/Dialect/LLVMIR/LLVMDialect.h"
15 #include "mlir/Dialect/Math/IR/Math.h"
16 #include "mlir/IR/TypeUtilities.h"
17 
18 using namespace mlir;
19 
20 namespace {
21 using AbsOpLowering = VectorConvertToLLVMPattern<math::AbsOp, LLVM::FAbsOp>;
22 using CeilOpLowering = VectorConvertToLLVMPattern<math::CeilOp, LLVM::FCeilOp>;
23 using CopySignOpLowering =
24     VectorConvertToLLVMPattern<math::CopySignOp, LLVM::CopySignOp>;
25 using CosOpLowering = VectorConvertToLLVMPattern<math::CosOp, LLVM::CosOp>;
26 using ExpOpLowering = VectorConvertToLLVMPattern<math::ExpOp, LLVM::ExpOp>;
27 using Exp2OpLowering = VectorConvertToLLVMPattern<math::Exp2Op, LLVM::Exp2Op>;
28 using FloorOpLowering =
29     VectorConvertToLLVMPattern<math::FloorOp, LLVM::FFloorOp>;
30 using FmaOpLowering = VectorConvertToLLVMPattern<math::FmaOp, LLVM::FMAOp>;
31 using Log10OpLowering =
32     VectorConvertToLLVMPattern<math::Log10Op, LLVM::Log10Op>;
33 using Log2OpLowering = VectorConvertToLLVMPattern<math::Log2Op, LLVM::Log2Op>;
34 using LogOpLowering = VectorConvertToLLVMPattern<math::LogOp, LLVM::LogOp>;
35 using PowFOpLowering = VectorConvertToLLVMPattern<math::PowFOp, LLVM::PowOp>;
36 using SinOpLowering = VectorConvertToLLVMPattern<math::SinOp, LLVM::SinOp>;
37 using SqrtOpLowering = VectorConvertToLLVMPattern<math::SqrtOp, LLVM::SqrtOp>;
38 
39 // A `expm1` is converted into `exp - 1`.
40 struct ExpM1OpLowering : public ConvertOpToLLVMPattern<math::ExpM1Op> {
41   using ConvertOpToLLVMPattern<math::ExpM1Op>::ConvertOpToLLVMPattern;
42 
43   LogicalResult
44   matchAndRewrite(math::ExpM1Op op, OpAdaptor adaptor,
45                   ConversionPatternRewriter &rewriter) const override {
46     auto operandType = adaptor.operand().getType();
47 
48     if (!operandType || !LLVM::isCompatibleType(operandType))
49       return failure();
50 
51     auto loc = op.getLoc();
52     auto resultType = op.getResult().getType();
53     auto floatType = getElementTypeOrSelf(resultType).cast<FloatType>();
54     auto floatOne = rewriter.getFloatAttr(floatType, 1.0);
55 
56     if (!operandType.isa<LLVM::LLVMArrayType>()) {
57       LLVM::ConstantOp one;
58       if (LLVM::isCompatibleVectorType(operandType)) {
59         one = rewriter.create<LLVM::ConstantOp>(
60             loc, operandType,
61             SplatElementsAttr::get(resultType.cast<ShapedType>(), floatOne));
62       } else {
63         one = rewriter.create<LLVM::ConstantOp>(loc, operandType, floatOne);
64       }
65       auto exp = rewriter.create<LLVM::ExpOp>(loc, adaptor.operand());
66       rewriter.replaceOpWithNewOp<LLVM::FSubOp>(op, operandType, exp, one);
67       return success();
68     }
69 
70     auto vectorType = resultType.dyn_cast<VectorType>();
71     if (!vectorType)
72       return rewriter.notifyMatchFailure(op, "expected vector result type");
73 
74     return LLVM::detail::handleMultidimensionalVectors(
75         op.getOperation(), adaptor.getOperands(), *getTypeConverter(),
76         [&](Type llvm1DVectorTy, ValueRange operands) {
77           auto splatAttr = SplatElementsAttr::get(
78               mlir::VectorType::get(
79                   {LLVM::getVectorNumElements(llvm1DVectorTy).getFixedValue()},
80                   floatType),
81               floatOne);
82           auto one =
83               rewriter.create<LLVM::ConstantOp>(loc, llvm1DVectorTy, splatAttr);
84           auto exp =
85               rewriter.create<LLVM::ExpOp>(loc, llvm1DVectorTy, operands[0]);
86           return rewriter.create<LLVM::FSubOp>(loc, llvm1DVectorTy, exp, one);
87         },
88         rewriter);
89   }
90 };
91 
92 // A `log1p` is converted into `log(1 + ...)`.
93 struct Log1pOpLowering : public ConvertOpToLLVMPattern<math::Log1pOp> {
94   using ConvertOpToLLVMPattern<math::Log1pOp>::ConvertOpToLLVMPattern;
95 
96   LogicalResult
97   matchAndRewrite(math::Log1pOp op, OpAdaptor adaptor,
98                   ConversionPatternRewriter &rewriter) const override {
99     auto operandType = adaptor.operand().getType();
100 
101     if (!operandType || !LLVM::isCompatibleType(operandType))
102       return rewriter.notifyMatchFailure(op, "unsupported operand type");
103 
104     auto loc = op.getLoc();
105     auto resultType = op.getResult().getType();
106     auto floatType = getElementTypeOrSelf(resultType).cast<FloatType>();
107     auto floatOne = rewriter.getFloatAttr(floatType, 1.0);
108 
109     if (!operandType.isa<LLVM::LLVMArrayType>()) {
110       LLVM::ConstantOp one =
111           LLVM::isCompatibleVectorType(operandType)
112               ? rewriter.create<LLVM::ConstantOp>(
113                     loc, operandType,
114                     SplatElementsAttr::get(resultType.cast<ShapedType>(),
115                                            floatOne))
116               : rewriter.create<LLVM::ConstantOp>(loc, operandType, floatOne);
117 
118       auto add = rewriter.create<LLVM::FAddOp>(loc, operandType, one,
119                                                adaptor.operand());
120       rewriter.replaceOpWithNewOp<LLVM::LogOp>(op, operandType, add);
121       return success();
122     }
123 
124     auto vectorType = resultType.dyn_cast<VectorType>();
125     if (!vectorType)
126       return rewriter.notifyMatchFailure(op, "expected vector result type");
127 
128     return LLVM::detail::handleMultidimensionalVectors(
129         op.getOperation(), adaptor.getOperands(), *getTypeConverter(),
130         [&](Type llvm1DVectorTy, ValueRange operands) {
131           auto splatAttr = SplatElementsAttr::get(
132               mlir::VectorType::get(
133                   {LLVM::getVectorNumElements(llvm1DVectorTy).getFixedValue()},
134                   floatType),
135               floatOne);
136           auto one =
137               rewriter.create<LLVM::ConstantOp>(loc, llvm1DVectorTy, splatAttr);
138           auto add = rewriter.create<LLVM::FAddOp>(loc, llvm1DVectorTy, one,
139                                                    operands[0]);
140           return rewriter.create<LLVM::LogOp>(loc, llvm1DVectorTy, add);
141         },
142         rewriter);
143   }
144 };
145 
146 // A `rsqrt` is converted into `1 / sqrt`.
147 struct RsqrtOpLowering : public ConvertOpToLLVMPattern<math::RsqrtOp> {
148   using ConvertOpToLLVMPattern<math::RsqrtOp>::ConvertOpToLLVMPattern;
149 
150   LogicalResult
151   matchAndRewrite(math::RsqrtOp op, OpAdaptor adaptor,
152                   ConversionPatternRewriter &rewriter) const override {
153     auto operandType = adaptor.operand().getType();
154 
155     if (!operandType || !LLVM::isCompatibleType(operandType))
156       return failure();
157 
158     auto loc = op.getLoc();
159     auto resultType = op.getResult().getType();
160     auto floatType = getElementTypeOrSelf(resultType).cast<FloatType>();
161     auto floatOne = rewriter.getFloatAttr(floatType, 1.0);
162 
163     if (!operandType.isa<LLVM::LLVMArrayType>()) {
164       LLVM::ConstantOp one;
165       if (LLVM::isCompatibleVectorType(operandType)) {
166         one = rewriter.create<LLVM::ConstantOp>(
167             loc, operandType,
168             SplatElementsAttr::get(resultType.cast<ShapedType>(), floatOne));
169       } else {
170         one = rewriter.create<LLVM::ConstantOp>(loc, operandType, floatOne);
171       }
172       auto sqrt = rewriter.create<LLVM::SqrtOp>(loc, adaptor.operand());
173       rewriter.replaceOpWithNewOp<LLVM::FDivOp>(op, operandType, one, sqrt);
174       return success();
175     }
176 
177     auto vectorType = resultType.dyn_cast<VectorType>();
178     if (!vectorType)
179       return failure();
180 
181     return LLVM::detail::handleMultidimensionalVectors(
182         op.getOperation(), adaptor.getOperands(), *getTypeConverter(),
183         [&](Type llvm1DVectorTy, ValueRange operands) {
184           auto splatAttr = SplatElementsAttr::get(
185               mlir::VectorType::get(
186                   {LLVM::getVectorNumElements(llvm1DVectorTy).getFixedValue()},
187                   floatType),
188               floatOne);
189           auto one =
190               rewriter.create<LLVM::ConstantOp>(loc, llvm1DVectorTy, splatAttr);
191           auto sqrt =
192               rewriter.create<LLVM::SqrtOp>(loc, llvm1DVectorTy, operands[0]);
193           return rewriter.create<LLVM::FDivOp>(loc, llvm1DVectorTy, one, sqrt);
194         },
195         rewriter);
196   }
197 };
198 
199 struct ConvertMathToLLVMPass
200     : public ConvertMathToLLVMBase<ConvertMathToLLVMPass> {
201   ConvertMathToLLVMPass() = default;
202 
203   void runOnFunction() override {
204     RewritePatternSet patterns(&getContext());
205     LLVMTypeConverter converter(&getContext());
206     populateMathToLLVMConversionPatterns(converter, patterns);
207     LLVMConversionTarget target(getContext());
208     if (failed(
209             applyPartialConversion(getFunction(), target, std::move(patterns))))
210       signalPassFailure();
211   }
212 };
213 } // namespace
214 
215 void mlir::populateMathToLLVMConversionPatterns(LLVMTypeConverter &converter,
216                                                 RewritePatternSet &patterns) {
217   // clang-format off
218   patterns.add<
219     AbsOpLowering,
220     CeilOpLowering,
221     CopySignOpLowering,
222     CosOpLowering,
223     ExpOpLowering,
224     Exp2OpLowering,
225     ExpM1OpLowering,
226     FloorOpLowering,
227     FmaOpLowering,
228     Log10OpLowering,
229     Log1pOpLowering,
230     Log2OpLowering,
231     LogOpLowering,
232     PowFOpLowering,
233     RsqrtOpLowering,
234     SinOpLowering,
235     SqrtOpLowering
236   >(converter);
237   // clang-format on
238 }
239 
240 std::unique_ptr<Pass> mlir::createConvertMathToLLVMPass() {
241   return std::make_unique<ConvertMathToLLVMPass>();
242 }
243