1 //===- MathToLLVM.cpp - Math to LLVM dialect conversion -------------------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 
9 #include "mlir/Conversion/MathToLLVM/MathToLLVM.h"
10 #include "../PassDetail.h"
11 #include "mlir/Conversion/LLVMCommon/ConversionTarget.h"
12 #include "mlir/Conversion/LLVMCommon/Pattern.h"
13 #include "mlir/Conversion/LLVMCommon/VectorPattern.h"
14 #include "mlir/Dialect/LLVMIR/LLVMDialect.h"
15 #include "mlir/Dialect/Math/IR/Math.h"
16 #include "mlir/IR/TypeUtilities.h"
17 
18 using namespace mlir;
19 
20 namespace {
21 using AbsOpLowering = VectorConvertToLLVMPattern<math::AbsOp, LLVM::FAbsOp>;
22 using CeilOpLowering = VectorConvertToLLVMPattern<math::CeilOp, LLVM::FCeilOp>;
23 using CopySignOpLowering =
24     VectorConvertToLLVMPattern<math::CopySignOp, LLVM::CopySignOp>;
25 using CosOpLowering = VectorConvertToLLVMPattern<math::CosOp, LLVM::CosOp>;
26 using CtPopFOpLowering =
27     VectorConvertToLLVMPattern<math::CtPopOp, LLVM::CtPopOp>;
28 using ExpOpLowering = VectorConvertToLLVMPattern<math::ExpOp, LLVM::ExpOp>;
29 using Exp2OpLowering = VectorConvertToLLVMPattern<math::Exp2Op, LLVM::Exp2Op>;
30 using FloorOpLowering =
31     VectorConvertToLLVMPattern<math::FloorOp, LLVM::FFloorOp>;
32 using FmaOpLowering = VectorConvertToLLVMPattern<math::FmaOp, LLVM::FMAOp>;
33 using Log10OpLowering =
34     VectorConvertToLLVMPattern<math::Log10Op, LLVM::Log10Op>;
35 using Log2OpLowering = VectorConvertToLLVMPattern<math::Log2Op, LLVM::Log2Op>;
36 using LogOpLowering = VectorConvertToLLVMPattern<math::LogOp, LLVM::LogOp>;
37 using PowFOpLowering = VectorConvertToLLVMPattern<math::PowFOp, LLVM::PowOp>;
38 using SinOpLowering = VectorConvertToLLVMPattern<math::SinOp, LLVM::SinOp>;
39 using SqrtOpLowering = VectorConvertToLLVMPattern<math::SqrtOp, LLVM::SqrtOp>;
40 
41 // A `expm1` is converted into `exp - 1`.
42 struct ExpM1OpLowering : public ConvertOpToLLVMPattern<math::ExpM1Op> {
43   using ConvertOpToLLVMPattern<math::ExpM1Op>::ConvertOpToLLVMPattern;
44 
45   LogicalResult
46   matchAndRewrite(math::ExpM1Op op, OpAdaptor adaptor,
47                   ConversionPatternRewriter &rewriter) const override {
48     auto operandType = adaptor.getOperand().getType();
49 
50     if (!operandType || !LLVM::isCompatibleType(operandType))
51       return failure();
52 
53     auto loc = op.getLoc();
54     auto resultType = op.getResult().getType();
55     auto floatType = getElementTypeOrSelf(resultType).cast<FloatType>();
56     auto floatOne = rewriter.getFloatAttr(floatType, 1.0);
57 
58     if (!operandType.isa<LLVM::LLVMArrayType>()) {
59       LLVM::ConstantOp one;
60       if (LLVM::isCompatibleVectorType(operandType)) {
61         one = rewriter.create<LLVM::ConstantOp>(
62             loc, operandType,
63             SplatElementsAttr::get(resultType.cast<ShapedType>(), floatOne));
64       } else {
65         one = rewriter.create<LLVM::ConstantOp>(loc, operandType, floatOne);
66       }
67       auto exp = rewriter.create<LLVM::ExpOp>(loc, adaptor.getOperand());
68       rewriter.replaceOpWithNewOp<LLVM::FSubOp>(op, operandType, exp, one);
69       return success();
70     }
71 
72     auto vectorType = resultType.dyn_cast<VectorType>();
73     if (!vectorType)
74       return rewriter.notifyMatchFailure(op, "expected vector result type");
75 
76     return LLVM::detail::handleMultidimensionalVectors(
77         op.getOperation(), adaptor.getOperands(), *getTypeConverter(),
78         [&](Type llvm1DVectorTy, ValueRange operands) {
79           auto splatAttr = SplatElementsAttr::get(
80               mlir::VectorType::get(
81                   {LLVM::getVectorNumElements(llvm1DVectorTy).getFixedValue()},
82                   floatType),
83               floatOne);
84           auto one =
85               rewriter.create<LLVM::ConstantOp>(loc, llvm1DVectorTy, splatAttr);
86           auto exp =
87               rewriter.create<LLVM::ExpOp>(loc, llvm1DVectorTy, operands[0]);
88           return rewriter.create<LLVM::FSubOp>(loc, llvm1DVectorTy, exp, one);
89         },
90         rewriter);
91   }
92 };
93 
94 // A `log1p` is converted into `log(1 + ...)`.
95 struct Log1pOpLowering : public ConvertOpToLLVMPattern<math::Log1pOp> {
96   using ConvertOpToLLVMPattern<math::Log1pOp>::ConvertOpToLLVMPattern;
97 
98   LogicalResult
99   matchAndRewrite(math::Log1pOp op, OpAdaptor adaptor,
100                   ConversionPatternRewriter &rewriter) const override {
101     auto operandType = adaptor.getOperand().getType();
102 
103     if (!operandType || !LLVM::isCompatibleType(operandType))
104       return rewriter.notifyMatchFailure(op, "unsupported operand type");
105 
106     auto loc = op.getLoc();
107     auto resultType = op.getResult().getType();
108     auto floatType = getElementTypeOrSelf(resultType).cast<FloatType>();
109     auto floatOne = rewriter.getFloatAttr(floatType, 1.0);
110 
111     if (!operandType.isa<LLVM::LLVMArrayType>()) {
112       LLVM::ConstantOp one =
113           LLVM::isCompatibleVectorType(operandType)
114               ? rewriter.create<LLVM::ConstantOp>(
115                     loc, operandType,
116                     SplatElementsAttr::get(resultType.cast<ShapedType>(),
117                                            floatOne))
118               : rewriter.create<LLVM::ConstantOp>(loc, operandType, floatOne);
119 
120       auto add = rewriter.create<LLVM::FAddOp>(loc, operandType, one,
121                                                adaptor.getOperand());
122       rewriter.replaceOpWithNewOp<LLVM::LogOp>(op, operandType, add);
123       return success();
124     }
125 
126     auto vectorType = resultType.dyn_cast<VectorType>();
127     if (!vectorType)
128       return rewriter.notifyMatchFailure(op, "expected vector result type");
129 
130     return LLVM::detail::handleMultidimensionalVectors(
131         op.getOperation(), adaptor.getOperands(), *getTypeConverter(),
132         [&](Type llvm1DVectorTy, ValueRange operands) {
133           auto splatAttr = SplatElementsAttr::get(
134               mlir::VectorType::get(
135                   {LLVM::getVectorNumElements(llvm1DVectorTy).getFixedValue()},
136                   floatType),
137               floatOne);
138           auto one =
139               rewriter.create<LLVM::ConstantOp>(loc, llvm1DVectorTy, splatAttr);
140           auto add = rewriter.create<LLVM::FAddOp>(loc, llvm1DVectorTy, one,
141                                                    operands[0]);
142           return rewriter.create<LLVM::LogOp>(loc, llvm1DVectorTy, add);
143         },
144         rewriter);
145   }
146 };
147 
148 // A `rsqrt` is converted into `1 / sqrt`.
149 struct RsqrtOpLowering : public ConvertOpToLLVMPattern<math::RsqrtOp> {
150   using ConvertOpToLLVMPattern<math::RsqrtOp>::ConvertOpToLLVMPattern;
151 
152   LogicalResult
153   matchAndRewrite(math::RsqrtOp op, OpAdaptor adaptor,
154                   ConversionPatternRewriter &rewriter) const override {
155     auto operandType = adaptor.getOperand().getType();
156 
157     if (!operandType || !LLVM::isCompatibleType(operandType))
158       return failure();
159 
160     auto loc = op.getLoc();
161     auto resultType = op.getResult().getType();
162     auto floatType = getElementTypeOrSelf(resultType).cast<FloatType>();
163     auto floatOne = rewriter.getFloatAttr(floatType, 1.0);
164 
165     if (!operandType.isa<LLVM::LLVMArrayType>()) {
166       LLVM::ConstantOp one;
167       if (LLVM::isCompatibleVectorType(operandType)) {
168         one = rewriter.create<LLVM::ConstantOp>(
169             loc, operandType,
170             SplatElementsAttr::get(resultType.cast<ShapedType>(), floatOne));
171       } else {
172         one = rewriter.create<LLVM::ConstantOp>(loc, operandType, floatOne);
173       }
174       auto sqrt = rewriter.create<LLVM::SqrtOp>(loc, adaptor.getOperand());
175       rewriter.replaceOpWithNewOp<LLVM::FDivOp>(op, operandType, one, sqrt);
176       return success();
177     }
178 
179     auto vectorType = resultType.dyn_cast<VectorType>();
180     if (!vectorType)
181       return failure();
182 
183     return LLVM::detail::handleMultidimensionalVectors(
184         op.getOperation(), adaptor.getOperands(), *getTypeConverter(),
185         [&](Type llvm1DVectorTy, ValueRange operands) {
186           auto splatAttr = SplatElementsAttr::get(
187               mlir::VectorType::get(
188                   {LLVM::getVectorNumElements(llvm1DVectorTy).getFixedValue()},
189                   floatType),
190               floatOne);
191           auto one =
192               rewriter.create<LLVM::ConstantOp>(loc, llvm1DVectorTy, splatAttr);
193           auto sqrt =
194               rewriter.create<LLVM::SqrtOp>(loc, llvm1DVectorTy, operands[0]);
195           return rewriter.create<LLVM::FDivOp>(loc, llvm1DVectorTy, one, sqrt);
196         },
197         rewriter);
198   }
199 };
200 
201 struct ConvertMathToLLVMPass
202     : public ConvertMathToLLVMBase<ConvertMathToLLVMPass> {
203   ConvertMathToLLVMPass() = default;
204 
205   void runOnFunction() override {
206     RewritePatternSet patterns(&getContext());
207     LLVMTypeConverter converter(&getContext());
208     populateMathToLLVMConversionPatterns(converter, patterns);
209     LLVMConversionTarget target(getContext());
210     if (failed(
211             applyPartialConversion(getFunction(), target, std::move(patterns))))
212       signalPassFailure();
213   }
214 };
215 } // namespace
216 
217 void mlir::populateMathToLLVMConversionPatterns(LLVMTypeConverter &converter,
218                                                 RewritePatternSet &patterns) {
219   // clang-format off
220   patterns.add<
221     AbsOpLowering,
222     CeilOpLowering,
223     CopySignOpLowering,
224     CosOpLowering,
225     CtPopFOpLowering,
226     ExpOpLowering,
227     Exp2OpLowering,
228     ExpM1OpLowering,
229     FloorOpLowering,
230     FmaOpLowering,
231     Log10OpLowering,
232     Log1pOpLowering,
233     Log2OpLowering,
234     LogOpLowering,
235     PowFOpLowering,
236     RsqrtOpLowering,
237     SinOpLowering,
238     SqrtOpLowering
239   >(converter);
240   // clang-format on
241 }
242 
243 std::unique_ptr<Pass> mlir::createConvertMathToLLVMPass() {
244   return std::make_unique<ConvertMathToLLVMPass>();
245 }
246