1 //===- MathToLLVM.cpp - Math to LLVM dialect conversion -------------------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8
9 #include "mlir/Conversion/MathToLLVM/MathToLLVM.h"
10 #include "../PassDetail.h"
11 #include "mlir/Conversion/LLVMCommon/ConversionTarget.h"
12 #include "mlir/Conversion/LLVMCommon/Pattern.h"
13 #include "mlir/Conversion/LLVMCommon/VectorPattern.h"
14 #include "mlir/Dialect/LLVMIR/LLVMDialect.h"
15 #include "mlir/Dialect/Math/IR/Math.h"
16 #include "mlir/IR/TypeUtilities.h"
17
18 using namespace mlir;
19
20 namespace {
21 using AbsOpLowering = VectorConvertToLLVMPattern<math::AbsOp, LLVM::FAbsOp>;
22 using CeilOpLowering = VectorConvertToLLVMPattern<math::CeilOp, LLVM::FCeilOp>;
23 using CopySignOpLowering =
24 VectorConvertToLLVMPattern<math::CopySignOp, LLVM::CopySignOp>;
25 using CosOpLowering = VectorConvertToLLVMPattern<math::CosOp, LLVM::CosOp>;
26 using CtPopFOpLowering =
27 VectorConvertToLLVMPattern<math::CtPopOp, LLVM::CtPopOp>;
28 using Exp2OpLowering = VectorConvertToLLVMPattern<math::Exp2Op, LLVM::Exp2Op>;
29 using ExpOpLowering = VectorConvertToLLVMPattern<math::ExpOp, LLVM::ExpOp>;
30 using FloorOpLowering =
31 VectorConvertToLLVMPattern<math::FloorOp, LLVM::FFloorOp>;
32 using FmaOpLowering = VectorConvertToLLVMPattern<math::FmaOp, LLVM::FMAOp>;
33 using Log10OpLowering =
34 VectorConvertToLLVMPattern<math::Log10Op, LLVM::Log10Op>;
35 using Log2OpLowering = VectorConvertToLLVMPattern<math::Log2Op, LLVM::Log2Op>;
36 using LogOpLowering = VectorConvertToLLVMPattern<math::LogOp, LLVM::LogOp>;
37 using PowFOpLowering = VectorConvertToLLVMPattern<math::PowFOp, LLVM::PowOp>;
38 using RoundOpLowering =
39 VectorConvertToLLVMPattern<math::RoundOp, LLVM::RoundOp>;
40 using SinOpLowering = VectorConvertToLLVMPattern<math::SinOp, LLVM::SinOp>;
41 using SqrtOpLowering = VectorConvertToLLVMPattern<math::SqrtOp, LLVM::SqrtOp>;
42
43 // A `CtLz/CtTz(a)` is converted into `CtLz/CtTz(a, false)`.
44 template <typename MathOp, typename LLVMOp>
45 struct CountOpLowering : public ConvertOpToLLVMPattern<MathOp> {
46 using ConvertOpToLLVMPattern<MathOp>::ConvertOpToLLVMPattern;
47 using Super = CountOpLowering<MathOp, LLVMOp>;
48
49 LogicalResult
matchAndRewrite__anoncde9de6d0111::CountOpLowering50 matchAndRewrite(MathOp op, typename MathOp::Adaptor adaptor,
51 ConversionPatternRewriter &rewriter) const override {
52 auto operandType = adaptor.getOperand().getType();
53
54 if (!operandType || !LLVM::isCompatibleType(operandType))
55 return failure();
56
57 auto loc = op.getLoc();
58 auto resultType = op.getResult().getType();
59 auto boolType = rewriter.getIntegerType(1);
60 auto boolZero = rewriter.getIntegerAttr(boolType, 0);
61
62 if (!operandType.template isa<LLVM::LLVMArrayType>()) {
63 LLVM::ConstantOp zero =
64 rewriter.create<LLVM::ConstantOp>(loc, boolType, boolZero);
65 rewriter.replaceOpWithNewOp<LLVMOp>(op, resultType, adaptor.getOperand(),
66 zero);
67 return success();
68 }
69
70 auto vectorType = resultType.template dyn_cast<VectorType>();
71 if (!vectorType)
72 return failure();
73
74 return LLVM::detail::handleMultidimensionalVectors(
75 op.getOperation(), adaptor.getOperands(), *this->getTypeConverter(),
76 [&](Type llvm1DVectorTy, ValueRange operands) {
77 LLVM::ConstantOp zero =
78 rewriter.create<LLVM::ConstantOp>(loc, boolType, boolZero);
79 return rewriter.create<LLVMOp>(loc, llvm1DVectorTy, operands[0],
80 zero);
81 },
82 rewriter);
83 }
84 };
85
86 using CountLeadingZerosOpLowering =
87 CountOpLowering<math::CountLeadingZerosOp, LLVM::CountLeadingZerosOp>;
88 using CountTrailingZerosOpLowering =
89 CountOpLowering<math::CountTrailingZerosOp, LLVM::CountTrailingZerosOp>;
90
91 // A `expm1` is converted into `exp - 1`.
92 struct ExpM1OpLowering : public ConvertOpToLLVMPattern<math::ExpM1Op> {
93 using ConvertOpToLLVMPattern<math::ExpM1Op>::ConvertOpToLLVMPattern;
94
95 LogicalResult
matchAndRewrite__anoncde9de6d0111::ExpM1OpLowering96 matchAndRewrite(math::ExpM1Op op, OpAdaptor adaptor,
97 ConversionPatternRewriter &rewriter) const override {
98 auto operandType = adaptor.getOperand().getType();
99
100 if (!operandType || !LLVM::isCompatibleType(operandType))
101 return failure();
102
103 auto loc = op.getLoc();
104 auto resultType = op.getResult().getType();
105 auto floatType = getElementTypeOrSelf(resultType).cast<FloatType>();
106 auto floatOne = rewriter.getFloatAttr(floatType, 1.0);
107
108 if (!operandType.isa<LLVM::LLVMArrayType>()) {
109 LLVM::ConstantOp one;
110 if (LLVM::isCompatibleVectorType(operandType)) {
111 one = rewriter.create<LLVM::ConstantOp>(
112 loc, operandType,
113 SplatElementsAttr::get(resultType.cast<ShapedType>(), floatOne));
114 } else {
115 one = rewriter.create<LLVM::ConstantOp>(loc, operandType, floatOne);
116 }
117 auto exp = rewriter.create<LLVM::ExpOp>(loc, adaptor.getOperand());
118 rewriter.replaceOpWithNewOp<LLVM::FSubOp>(op, operandType, exp, one);
119 return success();
120 }
121
122 auto vectorType = resultType.dyn_cast<VectorType>();
123 if (!vectorType)
124 return rewriter.notifyMatchFailure(op, "expected vector result type");
125
126 return LLVM::detail::handleMultidimensionalVectors(
127 op.getOperation(), adaptor.getOperands(), *getTypeConverter(),
128 [&](Type llvm1DVectorTy, ValueRange operands) {
129 auto splatAttr = SplatElementsAttr::get(
130 mlir::VectorType::get(
131 {LLVM::getVectorNumElements(llvm1DVectorTy).getFixedValue()},
132 floatType),
133 floatOne);
134 auto one =
135 rewriter.create<LLVM::ConstantOp>(loc, llvm1DVectorTy, splatAttr);
136 auto exp =
137 rewriter.create<LLVM::ExpOp>(loc, llvm1DVectorTy, operands[0]);
138 return rewriter.create<LLVM::FSubOp>(loc, llvm1DVectorTy, exp, one);
139 },
140 rewriter);
141 }
142 };
143
144 // A `log1p` is converted into `log(1 + ...)`.
145 struct Log1pOpLowering : public ConvertOpToLLVMPattern<math::Log1pOp> {
146 using ConvertOpToLLVMPattern<math::Log1pOp>::ConvertOpToLLVMPattern;
147
148 LogicalResult
matchAndRewrite__anoncde9de6d0111::Log1pOpLowering149 matchAndRewrite(math::Log1pOp op, OpAdaptor adaptor,
150 ConversionPatternRewriter &rewriter) const override {
151 auto operandType = adaptor.getOperand().getType();
152
153 if (!operandType || !LLVM::isCompatibleType(operandType))
154 return rewriter.notifyMatchFailure(op, "unsupported operand type");
155
156 auto loc = op.getLoc();
157 auto resultType = op.getResult().getType();
158 auto floatType = getElementTypeOrSelf(resultType).cast<FloatType>();
159 auto floatOne = rewriter.getFloatAttr(floatType, 1.0);
160
161 if (!operandType.isa<LLVM::LLVMArrayType>()) {
162 LLVM::ConstantOp one =
163 LLVM::isCompatibleVectorType(operandType)
164 ? rewriter.create<LLVM::ConstantOp>(
165 loc, operandType,
166 SplatElementsAttr::get(resultType.cast<ShapedType>(),
167 floatOne))
168 : rewriter.create<LLVM::ConstantOp>(loc, operandType, floatOne);
169
170 auto add = rewriter.create<LLVM::FAddOp>(loc, operandType, one,
171 adaptor.getOperand());
172 rewriter.replaceOpWithNewOp<LLVM::LogOp>(op, operandType, add);
173 return success();
174 }
175
176 auto vectorType = resultType.dyn_cast<VectorType>();
177 if (!vectorType)
178 return rewriter.notifyMatchFailure(op, "expected vector result type");
179
180 return LLVM::detail::handleMultidimensionalVectors(
181 op.getOperation(), adaptor.getOperands(), *getTypeConverter(),
182 [&](Type llvm1DVectorTy, ValueRange operands) {
183 auto splatAttr = SplatElementsAttr::get(
184 mlir::VectorType::get(
185 {LLVM::getVectorNumElements(llvm1DVectorTy).getFixedValue()},
186 floatType),
187 floatOne);
188 auto one =
189 rewriter.create<LLVM::ConstantOp>(loc, llvm1DVectorTy, splatAttr);
190 auto add = rewriter.create<LLVM::FAddOp>(loc, llvm1DVectorTy, one,
191 operands[0]);
192 return rewriter.create<LLVM::LogOp>(loc, llvm1DVectorTy, add);
193 },
194 rewriter);
195 }
196 };
197
198 // A `rsqrt` is converted into `1 / sqrt`.
199 struct RsqrtOpLowering : public ConvertOpToLLVMPattern<math::RsqrtOp> {
200 using ConvertOpToLLVMPattern<math::RsqrtOp>::ConvertOpToLLVMPattern;
201
202 LogicalResult
matchAndRewrite__anoncde9de6d0111::RsqrtOpLowering203 matchAndRewrite(math::RsqrtOp op, OpAdaptor adaptor,
204 ConversionPatternRewriter &rewriter) const override {
205 auto operandType = adaptor.getOperand().getType();
206
207 if (!operandType || !LLVM::isCompatibleType(operandType))
208 return failure();
209
210 auto loc = op.getLoc();
211 auto resultType = op.getResult().getType();
212 auto floatType = getElementTypeOrSelf(resultType).cast<FloatType>();
213 auto floatOne = rewriter.getFloatAttr(floatType, 1.0);
214
215 if (!operandType.isa<LLVM::LLVMArrayType>()) {
216 LLVM::ConstantOp one;
217 if (LLVM::isCompatibleVectorType(operandType)) {
218 one = rewriter.create<LLVM::ConstantOp>(
219 loc, operandType,
220 SplatElementsAttr::get(resultType.cast<ShapedType>(), floatOne));
221 } else {
222 one = rewriter.create<LLVM::ConstantOp>(loc, operandType, floatOne);
223 }
224 auto sqrt = rewriter.create<LLVM::SqrtOp>(loc, adaptor.getOperand());
225 rewriter.replaceOpWithNewOp<LLVM::FDivOp>(op, operandType, one, sqrt);
226 return success();
227 }
228
229 auto vectorType = resultType.dyn_cast<VectorType>();
230 if (!vectorType)
231 return failure();
232
233 return LLVM::detail::handleMultidimensionalVectors(
234 op.getOperation(), adaptor.getOperands(), *getTypeConverter(),
235 [&](Type llvm1DVectorTy, ValueRange operands) {
236 auto splatAttr = SplatElementsAttr::get(
237 mlir::VectorType::get(
238 {LLVM::getVectorNumElements(llvm1DVectorTy).getFixedValue()},
239 floatType),
240 floatOne);
241 auto one =
242 rewriter.create<LLVM::ConstantOp>(loc, llvm1DVectorTy, splatAttr);
243 auto sqrt =
244 rewriter.create<LLVM::SqrtOp>(loc, llvm1DVectorTy, operands[0]);
245 return rewriter.create<LLVM::FDivOp>(loc, llvm1DVectorTy, one, sqrt);
246 },
247 rewriter);
248 }
249 };
250
251 struct ConvertMathToLLVMPass
252 : public ConvertMathToLLVMBase<ConvertMathToLLVMPass> {
253 ConvertMathToLLVMPass() = default;
254
runOnOperation__anoncde9de6d0111::ConvertMathToLLVMPass255 void runOnOperation() override {
256 RewritePatternSet patterns(&getContext());
257 LLVMTypeConverter converter(&getContext());
258 populateMathToLLVMConversionPatterns(converter, patterns);
259 LLVMConversionTarget target(getContext());
260 if (failed(applyPartialConversion(getOperation(), target,
261 std::move(patterns))))
262 signalPassFailure();
263 }
264 };
265 } // namespace
266
populateMathToLLVMConversionPatterns(LLVMTypeConverter & converter,RewritePatternSet & patterns)267 void mlir::populateMathToLLVMConversionPatterns(LLVMTypeConverter &converter,
268 RewritePatternSet &patterns) {
269 // clang-format off
270 patterns.add<
271 AbsOpLowering,
272 CeilOpLowering,
273 CopySignOpLowering,
274 CosOpLowering,
275 CountLeadingZerosOpLowering,
276 CountTrailingZerosOpLowering,
277 CtPopFOpLowering,
278 Exp2OpLowering,
279 ExpM1OpLowering,
280 ExpOpLowering,
281 FloorOpLowering,
282 FmaOpLowering,
283 Log10OpLowering,
284 Log1pOpLowering,
285 Log2OpLowering,
286 LogOpLowering,
287 PowFOpLowering,
288 RoundOpLowering,
289 RsqrtOpLowering,
290 SinOpLowering,
291 SqrtOpLowering
292 >(converter);
293 // clang-format on
294 }
295
createConvertMathToLLVMPass()296 std::unique_ptr<Pass> mlir::createConvertMathToLLVMPass() {
297 return std::make_unique<ConvertMathToLLVMPass>();
298 }
299