126e59cc1SAlex Zinenko //===- MathToLLVM.cpp - Math to LLVM dialect conversion -------------------===//
226e59cc1SAlex Zinenko //
326e59cc1SAlex Zinenko // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
426e59cc1SAlex Zinenko // See https://llvm.org/LICENSE.txt for license information.
526e59cc1SAlex Zinenko // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
626e59cc1SAlex Zinenko //
726e59cc1SAlex Zinenko //===----------------------------------------------------------------------===//
826e59cc1SAlex Zinenko 
926e59cc1SAlex Zinenko #include "mlir/Conversion/MathToLLVM/MathToLLVM.h"
1026e59cc1SAlex Zinenko #include "../PassDetail.h"
1126e59cc1SAlex Zinenko #include "mlir/Conversion/LLVMCommon/ConversionTarget.h"
1226e59cc1SAlex Zinenko #include "mlir/Conversion/LLVMCommon/Pattern.h"
1326e59cc1SAlex Zinenko #include "mlir/Conversion/LLVMCommon/VectorPattern.h"
1426e59cc1SAlex Zinenko #include "mlir/Dialect/LLVMIR/LLVMDialect.h"
1526e59cc1SAlex Zinenko #include "mlir/Dialect/Math/IR/Math.h"
1626e59cc1SAlex Zinenko #include "mlir/IR/TypeUtilities.h"
1726e59cc1SAlex Zinenko 
1826e59cc1SAlex Zinenko using namespace mlir;
1926e59cc1SAlex Zinenko 
2026e59cc1SAlex Zinenko namespace {
21a54f4eaeSMogball using AbsOpLowering = VectorConvertToLLVMPattern<math::AbsOp, LLVM::FAbsOp>;
22a54f4eaeSMogball using CeilOpLowering = VectorConvertToLLVMPattern<math::CeilOp, LLVM::FCeilOp>;
23a54f4eaeSMogball using CopySignOpLowering =
24a54f4eaeSMogball     VectorConvertToLLVMPattern<math::CopySignOp, LLVM::CopySignOp>;
2526e59cc1SAlex Zinenko using CosOpLowering = VectorConvertToLLVMPattern<math::CosOp, LLVM::CosOp>;
26c5fef77bSRob Suderman using CtPopFOpLowering =
27c5fef77bSRob Suderman     VectorConvertToLLVMPattern<math::CtPopOp, LLVM::CtPopOp>;
2826e59cc1SAlex Zinenko using ExpOpLowering = VectorConvertToLLVMPattern<math::ExpOp, LLVM::ExpOp>;
2926e59cc1SAlex Zinenko using Exp2OpLowering = VectorConvertToLLVMPattern<math::Exp2Op, LLVM::Exp2Op>;
30a54f4eaeSMogball using FloorOpLowering =
31a54f4eaeSMogball     VectorConvertToLLVMPattern<math::FloorOp, LLVM::FFloorOp>;
32a54f4eaeSMogball using FmaOpLowering = VectorConvertToLLVMPattern<math::FmaOp, LLVM::FMAOp>;
3326e59cc1SAlex Zinenko using Log10OpLowering =
3426e59cc1SAlex Zinenko     VectorConvertToLLVMPattern<math::Log10Op, LLVM::Log10Op>;
3526e59cc1SAlex Zinenko using Log2OpLowering = VectorConvertToLLVMPattern<math::Log2Op, LLVM::Log2Op>;
3626e59cc1SAlex Zinenko using LogOpLowering = VectorConvertToLLVMPattern<math::LogOp, LLVM::LogOp>;
3726e59cc1SAlex Zinenko using PowFOpLowering = VectorConvertToLLVMPattern<math::PowFOp, LLVM::PowOp>;
3826e59cc1SAlex Zinenko using SinOpLowering = VectorConvertToLLVMPattern<math::SinOp, LLVM::SinOp>;
3926e59cc1SAlex Zinenko using SqrtOpLowering = VectorConvertToLLVMPattern<math::SqrtOp, LLVM::SqrtOp>;
40*a0fc94abSlorenzo chelini using RoundOpLowering =
41*a0fc94abSlorenzo chelini     VectorConvertToLLVMPattern<math::RoundOp, LLVM::RoundOp>;
4226e59cc1SAlex Zinenko 
4323149d52SRob Suderman // A `CtLz/CtTz(a)` is converted into `CtLz/CtTz(a, false)`.
4423149d52SRob Suderman template <typename MathOp, typename LLVMOp>
4523149d52SRob Suderman struct CountOpLowering : public ConvertOpToLLVMPattern<MathOp> {
4623149d52SRob Suderman   using ConvertOpToLLVMPattern<MathOp>::ConvertOpToLLVMPattern;
4723149d52SRob Suderman   using Super = CountOpLowering<MathOp, LLVMOp>;
4823149d52SRob Suderman 
4923149d52SRob Suderman   LogicalResult
5023149d52SRob Suderman   matchAndRewrite(MathOp op, typename MathOp::Adaptor adaptor,
5123149d52SRob Suderman                   ConversionPatternRewriter &rewriter) const override {
5223149d52SRob Suderman     auto operandType = adaptor.getOperand().getType();
5323149d52SRob Suderman 
5423149d52SRob Suderman     if (!operandType || !LLVM::isCompatibleType(operandType))
5523149d52SRob Suderman       return failure();
5623149d52SRob Suderman 
5723149d52SRob Suderman     auto loc = op.getLoc();
5823149d52SRob Suderman     auto resultType = op.getResult().getType();
5923149d52SRob Suderman     auto boolType = rewriter.getIntegerType(1);
6023149d52SRob Suderman     auto boolZero = rewriter.getIntegerAttr(boolType, 0);
6123149d52SRob Suderman 
6223149d52SRob Suderman     if (!operandType.template isa<LLVM::LLVMArrayType>()) {
6323149d52SRob Suderman       LLVM::ConstantOp zero =
6423149d52SRob Suderman           rewriter.create<LLVM::ConstantOp>(loc, boolType, boolZero);
6523149d52SRob Suderman       rewriter.replaceOpWithNewOp<LLVMOp>(op, resultType, adaptor.getOperand(),
6623149d52SRob Suderman                                           zero);
6723149d52SRob Suderman       return success();
6823149d52SRob Suderman     }
6923149d52SRob Suderman 
7023149d52SRob Suderman     auto vectorType = resultType.template dyn_cast<VectorType>();
7123149d52SRob Suderman     if (!vectorType)
7223149d52SRob Suderman       return failure();
7323149d52SRob Suderman 
7423149d52SRob Suderman     return LLVM::detail::handleMultidimensionalVectors(
7523149d52SRob Suderman         op.getOperation(), adaptor.getOperands(), *this->getTypeConverter(),
7623149d52SRob Suderman         [&](Type llvm1DVectorTy, ValueRange operands) {
7723149d52SRob Suderman           LLVM::ConstantOp zero =
7823149d52SRob Suderman               rewriter.create<LLVM::ConstantOp>(loc, boolType, boolZero);
79cb4a5eaeSRobert Suderman           return rewriter.create<LLVMOp>(loc, llvm1DVectorTy, operands[0],
80cb4a5eaeSRobert Suderman                                          zero);
8123149d52SRob Suderman         },
8223149d52SRob Suderman         rewriter);
8323149d52SRob Suderman   }
8423149d52SRob Suderman };
8523149d52SRob Suderman 
8623149d52SRob Suderman using CountLeadingZerosOpLowering =
8723149d52SRob Suderman     CountOpLowering<math::CountLeadingZerosOp, LLVM::CountLeadingZerosOp>;
8823149d52SRob Suderman using CountTrailingZerosOpLowering =
8923149d52SRob Suderman     CountOpLowering<math::CountTrailingZerosOp, LLVM::CountTrailingZerosOp>;
9023149d52SRob Suderman 
9126e59cc1SAlex Zinenko // A `expm1` is converted into `exp - 1`.
9226e59cc1SAlex Zinenko struct ExpM1OpLowering : public ConvertOpToLLVMPattern<math::ExpM1Op> {
9326e59cc1SAlex Zinenko   using ConvertOpToLLVMPattern<math::ExpM1Op>::ConvertOpToLLVMPattern;
9426e59cc1SAlex Zinenko 
9526e59cc1SAlex Zinenko   LogicalResult
96ef976337SRiver Riddle   matchAndRewrite(math::ExpM1Op op, OpAdaptor adaptor,
9726e59cc1SAlex Zinenko                   ConversionPatternRewriter &rewriter) const override {
9862fea88bSJacques Pienaar     auto operandType = adaptor.getOperand().getType();
9926e59cc1SAlex Zinenko 
10026e59cc1SAlex Zinenko     if (!operandType || !LLVM::isCompatibleType(operandType))
10126e59cc1SAlex Zinenko       return failure();
10226e59cc1SAlex Zinenko 
10326e59cc1SAlex Zinenko     auto loc = op.getLoc();
10426e59cc1SAlex Zinenko     auto resultType = op.getResult().getType();
10526e59cc1SAlex Zinenko     auto floatType = getElementTypeOrSelf(resultType).cast<FloatType>();
10626e59cc1SAlex Zinenko     auto floatOne = rewriter.getFloatAttr(floatType, 1.0);
10726e59cc1SAlex Zinenko 
10826e59cc1SAlex Zinenko     if (!operandType.isa<LLVM::LLVMArrayType>()) {
10926e59cc1SAlex Zinenko       LLVM::ConstantOp one;
11026e59cc1SAlex Zinenko       if (LLVM::isCompatibleVectorType(operandType)) {
11126e59cc1SAlex Zinenko         one = rewriter.create<LLVM::ConstantOp>(
11226e59cc1SAlex Zinenko             loc, operandType,
11326e59cc1SAlex Zinenko             SplatElementsAttr::get(resultType.cast<ShapedType>(), floatOne));
11426e59cc1SAlex Zinenko       } else {
11526e59cc1SAlex Zinenko         one = rewriter.create<LLVM::ConstantOp>(loc, operandType, floatOne);
11626e59cc1SAlex Zinenko       }
11762fea88bSJacques Pienaar       auto exp = rewriter.create<LLVM::ExpOp>(loc, adaptor.getOperand());
11826e59cc1SAlex Zinenko       rewriter.replaceOpWithNewOp<LLVM::FSubOp>(op, operandType, exp, one);
11926e59cc1SAlex Zinenko       return success();
12026e59cc1SAlex Zinenko     }
12126e59cc1SAlex Zinenko 
12226e59cc1SAlex Zinenko     auto vectorType = resultType.dyn_cast<VectorType>();
12326e59cc1SAlex Zinenko     if (!vectorType)
12426e59cc1SAlex Zinenko       return rewriter.notifyMatchFailure(op, "expected vector result type");
12526e59cc1SAlex Zinenko 
12626e59cc1SAlex Zinenko     return LLVM::detail::handleMultidimensionalVectors(
127ef976337SRiver Riddle         op.getOperation(), adaptor.getOperands(), *getTypeConverter(),
12826e59cc1SAlex Zinenko         [&](Type llvm1DVectorTy, ValueRange operands) {
12926e59cc1SAlex Zinenko           auto splatAttr = SplatElementsAttr::get(
13026e59cc1SAlex Zinenko               mlir::VectorType::get(
13126e59cc1SAlex Zinenko                   {LLVM::getVectorNumElements(llvm1DVectorTy).getFixedValue()},
13226e59cc1SAlex Zinenko                   floatType),
13326e59cc1SAlex Zinenko               floatOne);
13426e59cc1SAlex Zinenko           auto one =
13526e59cc1SAlex Zinenko               rewriter.create<LLVM::ConstantOp>(loc, llvm1DVectorTy, splatAttr);
13626e59cc1SAlex Zinenko           auto exp =
13726e59cc1SAlex Zinenko               rewriter.create<LLVM::ExpOp>(loc, llvm1DVectorTy, operands[0]);
13826e59cc1SAlex Zinenko           return rewriter.create<LLVM::FSubOp>(loc, llvm1DVectorTy, exp, one);
13926e59cc1SAlex Zinenko         },
14026e59cc1SAlex Zinenko         rewriter);
14126e59cc1SAlex Zinenko   }
14226e59cc1SAlex Zinenko };
14326e59cc1SAlex Zinenko 
14426e59cc1SAlex Zinenko // A `log1p` is converted into `log(1 + ...)`.
14526e59cc1SAlex Zinenko struct Log1pOpLowering : public ConvertOpToLLVMPattern<math::Log1pOp> {
14626e59cc1SAlex Zinenko   using ConvertOpToLLVMPattern<math::Log1pOp>::ConvertOpToLLVMPattern;
14726e59cc1SAlex Zinenko 
14826e59cc1SAlex Zinenko   LogicalResult
149ef976337SRiver Riddle   matchAndRewrite(math::Log1pOp op, OpAdaptor adaptor,
15026e59cc1SAlex Zinenko                   ConversionPatternRewriter &rewriter) const override {
15162fea88bSJacques Pienaar     auto operandType = adaptor.getOperand().getType();
15226e59cc1SAlex Zinenko 
15326e59cc1SAlex Zinenko     if (!operandType || !LLVM::isCompatibleType(operandType))
15426e59cc1SAlex Zinenko       return rewriter.notifyMatchFailure(op, "unsupported operand type");
15526e59cc1SAlex Zinenko 
15626e59cc1SAlex Zinenko     auto loc = op.getLoc();
15726e59cc1SAlex Zinenko     auto resultType = op.getResult().getType();
15826e59cc1SAlex Zinenko     auto floatType = getElementTypeOrSelf(resultType).cast<FloatType>();
15926e59cc1SAlex Zinenko     auto floatOne = rewriter.getFloatAttr(floatType, 1.0);
16026e59cc1SAlex Zinenko 
16126e59cc1SAlex Zinenko     if (!operandType.isa<LLVM::LLVMArrayType>()) {
16226e59cc1SAlex Zinenko       LLVM::ConstantOp one =
16326e59cc1SAlex Zinenko           LLVM::isCompatibleVectorType(operandType)
16426e59cc1SAlex Zinenko               ? rewriter.create<LLVM::ConstantOp>(
16526e59cc1SAlex Zinenko                     loc, operandType,
16626e59cc1SAlex Zinenko                     SplatElementsAttr::get(resultType.cast<ShapedType>(),
16726e59cc1SAlex Zinenko                                            floatOne))
16826e59cc1SAlex Zinenko               : rewriter.create<LLVM::ConstantOp>(loc, operandType, floatOne);
16926e59cc1SAlex Zinenko 
17026e59cc1SAlex Zinenko       auto add = rewriter.create<LLVM::FAddOp>(loc, operandType, one,
17162fea88bSJacques Pienaar                                                adaptor.getOperand());
17226e59cc1SAlex Zinenko       rewriter.replaceOpWithNewOp<LLVM::LogOp>(op, operandType, add);
17326e59cc1SAlex Zinenko       return success();
17426e59cc1SAlex Zinenko     }
17526e59cc1SAlex Zinenko 
17626e59cc1SAlex Zinenko     auto vectorType = resultType.dyn_cast<VectorType>();
17726e59cc1SAlex Zinenko     if (!vectorType)
17826e59cc1SAlex Zinenko       return rewriter.notifyMatchFailure(op, "expected vector result type");
17926e59cc1SAlex Zinenko 
18026e59cc1SAlex Zinenko     return LLVM::detail::handleMultidimensionalVectors(
181ef976337SRiver Riddle         op.getOperation(), adaptor.getOperands(), *getTypeConverter(),
18226e59cc1SAlex Zinenko         [&](Type llvm1DVectorTy, ValueRange operands) {
18326e59cc1SAlex Zinenko           auto splatAttr = SplatElementsAttr::get(
18426e59cc1SAlex Zinenko               mlir::VectorType::get(
18526e59cc1SAlex Zinenko                   {LLVM::getVectorNumElements(llvm1DVectorTy).getFixedValue()},
18626e59cc1SAlex Zinenko                   floatType),
18726e59cc1SAlex Zinenko               floatOne);
18826e59cc1SAlex Zinenko           auto one =
18926e59cc1SAlex Zinenko               rewriter.create<LLVM::ConstantOp>(loc, llvm1DVectorTy, splatAttr);
19026e59cc1SAlex Zinenko           auto add = rewriter.create<LLVM::FAddOp>(loc, llvm1DVectorTy, one,
19126e59cc1SAlex Zinenko                                                    operands[0]);
19226e59cc1SAlex Zinenko           return rewriter.create<LLVM::LogOp>(loc, llvm1DVectorTy, add);
19326e59cc1SAlex Zinenko         },
19426e59cc1SAlex Zinenko         rewriter);
19526e59cc1SAlex Zinenko   }
19626e59cc1SAlex Zinenko };
19726e59cc1SAlex Zinenko 
19826e59cc1SAlex Zinenko // A `rsqrt` is converted into `1 / sqrt`.
19926e59cc1SAlex Zinenko struct RsqrtOpLowering : public ConvertOpToLLVMPattern<math::RsqrtOp> {
20026e59cc1SAlex Zinenko   using ConvertOpToLLVMPattern<math::RsqrtOp>::ConvertOpToLLVMPattern;
20126e59cc1SAlex Zinenko 
20226e59cc1SAlex Zinenko   LogicalResult
203ef976337SRiver Riddle   matchAndRewrite(math::RsqrtOp op, OpAdaptor adaptor,
20426e59cc1SAlex Zinenko                   ConversionPatternRewriter &rewriter) const override {
20562fea88bSJacques Pienaar     auto operandType = adaptor.getOperand().getType();
20626e59cc1SAlex Zinenko 
20726e59cc1SAlex Zinenko     if (!operandType || !LLVM::isCompatibleType(operandType))
20826e59cc1SAlex Zinenko       return failure();
20926e59cc1SAlex Zinenko 
21026e59cc1SAlex Zinenko     auto loc = op.getLoc();
21126e59cc1SAlex Zinenko     auto resultType = op.getResult().getType();
21226e59cc1SAlex Zinenko     auto floatType = getElementTypeOrSelf(resultType).cast<FloatType>();
21326e59cc1SAlex Zinenko     auto floatOne = rewriter.getFloatAttr(floatType, 1.0);
21426e59cc1SAlex Zinenko 
21526e59cc1SAlex Zinenko     if (!operandType.isa<LLVM::LLVMArrayType>()) {
21626e59cc1SAlex Zinenko       LLVM::ConstantOp one;
21726e59cc1SAlex Zinenko       if (LLVM::isCompatibleVectorType(operandType)) {
21826e59cc1SAlex Zinenko         one = rewriter.create<LLVM::ConstantOp>(
21926e59cc1SAlex Zinenko             loc, operandType,
22026e59cc1SAlex Zinenko             SplatElementsAttr::get(resultType.cast<ShapedType>(), floatOne));
22126e59cc1SAlex Zinenko       } else {
22226e59cc1SAlex Zinenko         one = rewriter.create<LLVM::ConstantOp>(loc, operandType, floatOne);
22326e59cc1SAlex Zinenko       }
22462fea88bSJacques Pienaar       auto sqrt = rewriter.create<LLVM::SqrtOp>(loc, adaptor.getOperand());
22526e59cc1SAlex Zinenko       rewriter.replaceOpWithNewOp<LLVM::FDivOp>(op, operandType, one, sqrt);
22626e59cc1SAlex Zinenko       return success();
22726e59cc1SAlex Zinenko     }
22826e59cc1SAlex Zinenko 
22926e59cc1SAlex Zinenko     auto vectorType = resultType.dyn_cast<VectorType>();
23026e59cc1SAlex Zinenko     if (!vectorType)
23126e59cc1SAlex Zinenko       return failure();
23226e59cc1SAlex Zinenko 
23326e59cc1SAlex Zinenko     return LLVM::detail::handleMultidimensionalVectors(
234ef976337SRiver Riddle         op.getOperation(), adaptor.getOperands(), *getTypeConverter(),
23526e59cc1SAlex Zinenko         [&](Type llvm1DVectorTy, ValueRange operands) {
23626e59cc1SAlex Zinenko           auto splatAttr = SplatElementsAttr::get(
23726e59cc1SAlex Zinenko               mlir::VectorType::get(
23826e59cc1SAlex Zinenko                   {LLVM::getVectorNumElements(llvm1DVectorTy).getFixedValue()},
23926e59cc1SAlex Zinenko                   floatType),
24026e59cc1SAlex Zinenko               floatOne);
24126e59cc1SAlex Zinenko           auto one =
24226e59cc1SAlex Zinenko               rewriter.create<LLVM::ConstantOp>(loc, llvm1DVectorTy, splatAttr);
24326e59cc1SAlex Zinenko           auto sqrt =
24426e59cc1SAlex Zinenko               rewriter.create<LLVM::SqrtOp>(loc, llvm1DVectorTy, operands[0]);
24526e59cc1SAlex Zinenko           return rewriter.create<LLVM::FDivOp>(loc, llvm1DVectorTy, one, sqrt);
24626e59cc1SAlex Zinenko         },
24726e59cc1SAlex Zinenko         rewriter);
24826e59cc1SAlex Zinenko   }
24926e59cc1SAlex Zinenko };
25026e59cc1SAlex Zinenko 
25126e59cc1SAlex Zinenko struct ConvertMathToLLVMPass
25226e59cc1SAlex Zinenko     : public ConvertMathToLLVMBase<ConvertMathToLLVMPass> {
25326e59cc1SAlex Zinenko   ConvertMathToLLVMPass() = default;
25426e59cc1SAlex Zinenko 
25541574554SRiver Riddle   void runOnOperation() override {
25626e59cc1SAlex Zinenko     RewritePatternSet patterns(&getContext());
25726e59cc1SAlex Zinenko     LLVMTypeConverter converter(&getContext());
25826e59cc1SAlex Zinenko     populateMathToLLVMConversionPatterns(converter, patterns);
25926e59cc1SAlex Zinenko     LLVMConversionTarget target(getContext());
26041574554SRiver Riddle     if (failed(applyPartialConversion(getOperation(), target,
26141574554SRiver Riddle                                       std::move(patterns))))
26226e59cc1SAlex Zinenko       signalPassFailure();
26326e59cc1SAlex Zinenko   }
26426e59cc1SAlex Zinenko };
26526e59cc1SAlex Zinenko } // namespace
26626e59cc1SAlex Zinenko 
26726e59cc1SAlex Zinenko void mlir::populateMathToLLVMConversionPatterns(LLVMTypeConverter &converter,
26826e59cc1SAlex Zinenko                                                 RewritePatternSet &patterns) {
26926e59cc1SAlex Zinenko   // clang-format off
27026e59cc1SAlex Zinenko   patterns.add<
271a54f4eaeSMogball     AbsOpLowering,
272a54f4eaeSMogball     CeilOpLowering,
273a54f4eaeSMogball     CopySignOpLowering,
27426e59cc1SAlex Zinenko     CosOpLowering,
27523149d52SRob Suderman     CountLeadingZerosOpLowering,
27623149d52SRob Suderman     CountTrailingZerosOpLowering,
277c5fef77bSRob Suderman     CtPopFOpLowering,
27826e59cc1SAlex Zinenko     ExpOpLowering,
27926e59cc1SAlex Zinenko     Exp2OpLowering,
28026e59cc1SAlex Zinenko     ExpM1OpLowering,
281a54f4eaeSMogball     FloorOpLowering,
282a54f4eaeSMogball     FmaOpLowering,
28326e59cc1SAlex Zinenko     Log10OpLowering,
28426e59cc1SAlex Zinenko     Log1pOpLowering,
28526e59cc1SAlex Zinenko     Log2OpLowering,
28626e59cc1SAlex Zinenko     LogOpLowering,
28726e59cc1SAlex Zinenko     PowFOpLowering,
28826e59cc1SAlex Zinenko     RsqrtOpLowering,
28926e59cc1SAlex Zinenko     SinOpLowering,
290*a0fc94abSlorenzo chelini     SqrtOpLowering,
291*a0fc94abSlorenzo chelini     RoundOpLowering
29226e59cc1SAlex Zinenko   >(converter);
29326e59cc1SAlex Zinenko   // clang-format on
29426e59cc1SAlex Zinenko }
29526e59cc1SAlex Zinenko 
29626e59cc1SAlex Zinenko std::unique_ptr<Pass> mlir::createConvertMathToLLVMPass() {
29726e59cc1SAlex Zinenko   return std::make_unique<ConvertMathToLLVMPass>();
29826e59cc1SAlex Zinenko }
299