1 //===- MathToLLVM.cpp - Math to LLVM dialect conversion -------------------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 
9 #include "mlir/Conversion/MathToLLVM/MathToLLVM.h"
10 #include "../PassDetail.h"
11 #include "mlir/Conversion/LLVMCommon/ConversionTarget.h"
12 #include "mlir/Conversion/LLVMCommon/Pattern.h"
13 #include "mlir/Conversion/LLVMCommon/VectorPattern.h"
14 #include "mlir/Dialect/LLVMIR/LLVMDialect.h"
15 #include "mlir/Dialect/Math/IR/Math.h"
16 #include "mlir/IR/TypeUtilities.h"
17 
18 using namespace mlir;
19 
20 namespace {
21 using CosOpLowering = VectorConvertToLLVMPattern<math::CosOp, LLVM::CosOp>;
22 using ExpOpLowering = VectorConvertToLLVMPattern<math::ExpOp, LLVM::ExpOp>;
23 using Exp2OpLowering = VectorConvertToLLVMPattern<math::Exp2Op, LLVM::Exp2Op>;
24 using Log10OpLowering =
25     VectorConvertToLLVMPattern<math::Log10Op, LLVM::Log10Op>;
26 using Log2OpLowering = VectorConvertToLLVMPattern<math::Log2Op, LLVM::Log2Op>;
27 using LogOpLowering = VectorConvertToLLVMPattern<math::LogOp, LLVM::LogOp>;
28 using PowFOpLowering = VectorConvertToLLVMPattern<math::PowFOp, LLVM::PowOp>;
29 using SinOpLowering = VectorConvertToLLVMPattern<math::SinOp, LLVM::SinOp>;
30 using SqrtOpLowering = VectorConvertToLLVMPattern<math::SqrtOp, LLVM::SqrtOp>;
31 
32 // A `expm1` is converted into `exp - 1`.
33 struct ExpM1OpLowering : public ConvertOpToLLVMPattern<math::ExpM1Op> {
34   using ConvertOpToLLVMPattern<math::ExpM1Op>::ConvertOpToLLVMPattern;
35 
36   LogicalResult
37   matchAndRewrite(math::ExpM1Op op, OpAdaptor adaptor,
38                   ConversionPatternRewriter &rewriter) const override {
39     auto operandType = adaptor.operand().getType();
40 
41     if (!operandType || !LLVM::isCompatibleType(operandType))
42       return failure();
43 
44     auto loc = op.getLoc();
45     auto resultType = op.getResult().getType();
46     auto floatType = getElementTypeOrSelf(resultType).cast<FloatType>();
47     auto floatOne = rewriter.getFloatAttr(floatType, 1.0);
48 
49     if (!operandType.isa<LLVM::LLVMArrayType>()) {
50       LLVM::ConstantOp one;
51       if (LLVM::isCompatibleVectorType(operandType)) {
52         one = rewriter.create<LLVM::ConstantOp>(
53             loc, operandType,
54             SplatElementsAttr::get(resultType.cast<ShapedType>(), floatOne));
55       } else {
56         one = rewriter.create<LLVM::ConstantOp>(loc, operandType, floatOne);
57       }
58       auto exp = rewriter.create<LLVM::ExpOp>(loc, adaptor.operand());
59       rewriter.replaceOpWithNewOp<LLVM::FSubOp>(op, operandType, exp, one);
60       return success();
61     }
62 
63     auto vectorType = resultType.dyn_cast<VectorType>();
64     if (!vectorType)
65       return rewriter.notifyMatchFailure(op, "expected vector result type");
66 
67     return LLVM::detail::handleMultidimensionalVectors(
68         op.getOperation(), adaptor.getOperands(), *getTypeConverter(),
69         [&](Type llvm1DVectorTy, ValueRange operands) {
70           auto splatAttr = SplatElementsAttr::get(
71               mlir::VectorType::get(
72                   {LLVM::getVectorNumElements(llvm1DVectorTy).getFixedValue()},
73                   floatType),
74               floatOne);
75           auto one =
76               rewriter.create<LLVM::ConstantOp>(loc, llvm1DVectorTy, splatAttr);
77           auto exp =
78               rewriter.create<LLVM::ExpOp>(loc, llvm1DVectorTy, operands[0]);
79           return rewriter.create<LLVM::FSubOp>(loc, llvm1DVectorTy, exp, one);
80         },
81         rewriter);
82   }
83 };
84 
85 // A `log1p` is converted into `log(1 + ...)`.
86 struct Log1pOpLowering : public ConvertOpToLLVMPattern<math::Log1pOp> {
87   using ConvertOpToLLVMPattern<math::Log1pOp>::ConvertOpToLLVMPattern;
88 
89   LogicalResult
90   matchAndRewrite(math::Log1pOp op, OpAdaptor adaptor,
91                   ConversionPatternRewriter &rewriter) const override {
92     auto operandType = adaptor.operand().getType();
93 
94     if (!operandType || !LLVM::isCompatibleType(operandType))
95       return rewriter.notifyMatchFailure(op, "unsupported operand type");
96 
97     auto loc = op.getLoc();
98     auto resultType = op.getResult().getType();
99     auto floatType = getElementTypeOrSelf(resultType).cast<FloatType>();
100     auto floatOne = rewriter.getFloatAttr(floatType, 1.0);
101 
102     if (!operandType.isa<LLVM::LLVMArrayType>()) {
103       LLVM::ConstantOp one =
104           LLVM::isCompatibleVectorType(operandType)
105               ? rewriter.create<LLVM::ConstantOp>(
106                     loc, operandType,
107                     SplatElementsAttr::get(resultType.cast<ShapedType>(),
108                                            floatOne))
109               : rewriter.create<LLVM::ConstantOp>(loc, operandType, floatOne);
110 
111       auto add = rewriter.create<LLVM::FAddOp>(loc, operandType, one,
112                                                adaptor.operand());
113       rewriter.replaceOpWithNewOp<LLVM::LogOp>(op, operandType, add);
114       return success();
115     }
116 
117     auto vectorType = resultType.dyn_cast<VectorType>();
118     if (!vectorType)
119       return rewriter.notifyMatchFailure(op, "expected vector result type");
120 
121     return LLVM::detail::handleMultidimensionalVectors(
122         op.getOperation(), adaptor.getOperands(), *getTypeConverter(),
123         [&](Type llvm1DVectorTy, ValueRange operands) {
124           auto splatAttr = SplatElementsAttr::get(
125               mlir::VectorType::get(
126                   {LLVM::getVectorNumElements(llvm1DVectorTy).getFixedValue()},
127                   floatType),
128               floatOne);
129           auto one =
130               rewriter.create<LLVM::ConstantOp>(loc, llvm1DVectorTy, splatAttr);
131           auto add = rewriter.create<LLVM::FAddOp>(loc, llvm1DVectorTy, one,
132                                                    operands[0]);
133           return rewriter.create<LLVM::LogOp>(loc, llvm1DVectorTy, add);
134         },
135         rewriter);
136   }
137 };
138 
139 // A `rsqrt` is converted into `1 / sqrt`.
140 struct RsqrtOpLowering : public ConvertOpToLLVMPattern<math::RsqrtOp> {
141   using ConvertOpToLLVMPattern<math::RsqrtOp>::ConvertOpToLLVMPattern;
142 
143   LogicalResult
144   matchAndRewrite(math::RsqrtOp op, OpAdaptor adaptor,
145                   ConversionPatternRewriter &rewriter) const override {
146     auto operandType = adaptor.operand().getType();
147 
148     if (!operandType || !LLVM::isCompatibleType(operandType))
149       return failure();
150 
151     auto loc = op.getLoc();
152     auto resultType = op.getResult().getType();
153     auto floatType = getElementTypeOrSelf(resultType).cast<FloatType>();
154     auto floatOne = rewriter.getFloatAttr(floatType, 1.0);
155 
156     if (!operandType.isa<LLVM::LLVMArrayType>()) {
157       LLVM::ConstantOp one;
158       if (LLVM::isCompatibleVectorType(operandType)) {
159         one = rewriter.create<LLVM::ConstantOp>(
160             loc, operandType,
161             SplatElementsAttr::get(resultType.cast<ShapedType>(), floatOne));
162       } else {
163         one = rewriter.create<LLVM::ConstantOp>(loc, operandType, floatOne);
164       }
165       auto sqrt = rewriter.create<LLVM::SqrtOp>(loc, adaptor.operand());
166       rewriter.replaceOpWithNewOp<LLVM::FDivOp>(op, operandType, one, sqrt);
167       return success();
168     }
169 
170     auto vectorType = resultType.dyn_cast<VectorType>();
171     if (!vectorType)
172       return failure();
173 
174     return LLVM::detail::handleMultidimensionalVectors(
175         op.getOperation(), adaptor.getOperands(), *getTypeConverter(),
176         [&](Type llvm1DVectorTy, ValueRange operands) {
177           auto splatAttr = SplatElementsAttr::get(
178               mlir::VectorType::get(
179                   {LLVM::getVectorNumElements(llvm1DVectorTy).getFixedValue()},
180                   floatType),
181               floatOne);
182           auto one =
183               rewriter.create<LLVM::ConstantOp>(loc, llvm1DVectorTy, splatAttr);
184           auto sqrt =
185               rewriter.create<LLVM::SqrtOp>(loc, llvm1DVectorTy, operands[0]);
186           return rewriter.create<LLVM::FDivOp>(loc, llvm1DVectorTy, one, sqrt);
187         },
188         rewriter);
189   }
190 };
191 
192 struct ConvertMathToLLVMPass
193     : public ConvertMathToLLVMBase<ConvertMathToLLVMPass> {
194   ConvertMathToLLVMPass() = default;
195 
196   void runOnFunction() override {
197     RewritePatternSet patterns(&getContext());
198     LLVMTypeConverter converter(&getContext());
199     populateMathToLLVMConversionPatterns(converter, patterns);
200     LLVMConversionTarget target(getContext());
201     if (failed(
202             applyPartialConversion(getFunction(), target, std::move(patterns))))
203       signalPassFailure();
204   }
205 };
206 } // namespace
207 
208 void mlir::populateMathToLLVMConversionPatterns(LLVMTypeConverter &converter,
209                                                 RewritePatternSet &patterns) {
210   // clang-format off
211   patterns.add<
212     CosOpLowering,
213     ExpOpLowering,
214     Exp2OpLowering,
215     ExpM1OpLowering,
216     Log10OpLowering,
217     Log1pOpLowering,
218     Log2OpLowering,
219     LogOpLowering,
220     PowFOpLowering,
221     RsqrtOpLowering,
222     SinOpLowering,
223     SqrtOpLowering
224   >(converter);
225   // clang-format on
226 }
227 
228 std::unique_ptr<Pass> mlir::createConvertMathToLLVMPass() {
229   return std::make_unique<ConvertMathToLLVMPass>();
230 }
231