126e59cc1SAlex Zinenko //===- MathToLLVM.cpp - Math to LLVM dialect conversion -------------------===//
226e59cc1SAlex Zinenko //
326e59cc1SAlex Zinenko // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
426e59cc1SAlex Zinenko // See https://llvm.org/LICENSE.txt for license information.
526e59cc1SAlex Zinenko // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
626e59cc1SAlex Zinenko //
726e59cc1SAlex Zinenko //===----------------------------------------------------------------------===//
826e59cc1SAlex Zinenko 
926e59cc1SAlex Zinenko #include "mlir/Conversion/MathToLLVM/MathToLLVM.h"
1026e59cc1SAlex Zinenko #include "../PassDetail.h"
1126e59cc1SAlex Zinenko #include "mlir/Conversion/LLVMCommon/ConversionTarget.h"
1226e59cc1SAlex Zinenko #include "mlir/Conversion/LLVMCommon/Pattern.h"
1326e59cc1SAlex Zinenko #include "mlir/Conversion/LLVMCommon/VectorPattern.h"
1426e59cc1SAlex Zinenko #include "mlir/Dialect/LLVMIR/LLVMDialect.h"
1526e59cc1SAlex Zinenko #include "mlir/Dialect/Math/IR/Math.h"
1626e59cc1SAlex Zinenko #include "mlir/IR/TypeUtilities.h"
1726e59cc1SAlex Zinenko 
1826e59cc1SAlex Zinenko using namespace mlir;
1926e59cc1SAlex Zinenko 
2026e59cc1SAlex Zinenko namespace {
21a54f4eaeSMogball using AbsOpLowering = VectorConvertToLLVMPattern<math::AbsOp, LLVM::FAbsOp>;
22a54f4eaeSMogball using CeilOpLowering = VectorConvertToLLVMPattern<math::CeilOp, LLVM::FCeilOp>;
23a54f4eaeSMogball using CopySignOpLowering =
24a54f4eaeSMogball     VectorConvertToLLVMPattern<math::CopySignOp, LLVM::CopySignOp>;
2526e59cc1SAlex Zinenko using CosOpLowering = VectorConvertToLLVMPattern<math::CosOp, LLVM::CosOp>;
26c5fef77bSRob Suderman using CtPopFOpLowering =
27c5fef77bSRob Suderman     VectorConvertToLLVMPattern<math::CtPopOp, LLVM::CtPopOp>;
2826e59cc1SAlex Zinenko using ExpOpLowering = VectorConvertToLLVMPattern<math::ExpOp, LLVM::ExpOp>;
2926e59cc1SAlex Zinenko using Exp2OpLowering = VectorConvertToLLVMPattern<math::Exp2Op, LLVM::Exp2Op>;
30a54f4eaeSMogball using FloorOpLowering =
31a54f4eaeSMogball     VectorConvertToLLVMPattern<math::FloorOp, LLVM::FFloorOp>;
32a54f4eaeSMogball using FmaOpLowering = VectorConvertToLLVMPattern<math::FmaOp, LLVM::FMAOp>;
3326e59cc1SAlex Zinenko using Log10OpLowering =
3426e59cc1SAlex Zinenko     VectorConvertToLLVMPattern<math::Log10Op, LLVM::Log10Op>;
3526e59cc1SAlex Zinenko using Log2OpLowering = VectorConvertToLLVMPattern<math::Log2Op, LLVM::Log2Op>;
3626e59cc1SAlex Zinenko using LogOpLowering = VectorConvertToLLVMPattern<math::LogOp, LLVM::LogOp>;
3726e59cc1SAlex Zinenko using PowFOpLowering = VectorConvertToLLVMPattern<math::PowFOp, LLVM::PowOp>;
3826e59cc1SAlex Zinenko using SinOpLowering = VectorConvertToLLVMPattern<math::SinOp, LLVM::SinOp>;
3926e59cc1SAlex Zinenko using SqrtOpLowering = VectorConvertToLLVMPattern<math::SqrtOp, LLVM::SqrtOp>;
4026e59cc1SAlex Zinenko 
4123149d52SRob Suderman // A `CtLz/CtTz(a)` is converted into `CtLz/CtTz(a, false)`.
4223149d52SRob Suderman template <typename MathOp, typename LLVMOp>
4323149d52SRob Suderman struct CountOpLowering : public ConvertOpToLLVMPattern<MathOp> {
4423149d52SRob Suderman   using ConvertOpToLLVMPattern<MathOp>::ConvertOpToLLVMPattern;
4523149d52SRob Suderman   using Super = CountOpLowering<MathOp, LLVMOp>;
4623149d52SRob Suderman 
4723149d52SRob Suderman   LogicalResult
4823149d52SRob Suderman   matchAndRewrite(MathOp op, typename MathOp::Adaptor adaptor,
4923149d52SRob Suderman                   ConversionPatternRewriter &rewriter) const override {
5023149d52SRob Suderman     auto operandType = adaptor.getOperand().getType();
5123149d52SRob Suderman 
5223149d52SRob Suderman     if (!operandType || !LLVM::isCompatibleType(operandType))
5323149d52SRob Suderman       return failure();
5423149d52SRob Suderman 
5523149d52SRob Suderman     auto loc = op.getLoc();
5623149d52SRob Suderman     auto resultType = op.getResult().getType();
5723149d52SRob Suderman     auto boolType = rewriter.getIntegerType(1);
5823149d52SRob Suderman     auto boolZero = rewriter.getIntegerAttr(boolType, 0);
5923149d52SRob Suderman 
6023149d52SRob Suderman     if (!operandType.template isa<LLVM::LLVMArrayType>()) {
6123149d52SRob Suderman       LLVM::ConstantOp zero =
6223149d52SRob Suderman           rewriter.create<LLVM::ConstantOp>(loc, boolType, boolZero);
6323149d52SRob Suderman       rewriter.replaceOpWithNewOp<LLVMOp>(op, resultType, adaptor.getOperand(),
6423149d52SRob Suderman                                           zero);
6523149d52SRob Suderman       return success();
6623149d52SRob Suderman     }
6723149d52SRob Suderman 
6823149d52SRob Suderman     auto vectorType = resultType.template dyn_cast<VectorType>();
6923149d52SRob Suderman     if (!vectorType)
7023149d52SRob Suderman       return failure();
7123149d52SRob Suderman 
7223149d52SRob Suderman     return LLVM::detail::handleMultidimensionalVectors(
7323149d52SRob Suderman         op.getOperation(), adaptor.getOperands(), *this->getTypeConverter(),
7423149d52SRob Suderman         [&](Type llvm1DVectorTy, ValueRange operands) {
7523149d52SRob Suderman           LLVM::ConstantOp zero =
7623149d52SRob Suderman               rewriter.create<LLVM::ConstantOp>(loc, boolType, boolZero);
77*cb4a5eaeSRobert Suderman           return rewriter.create<LLVMOp>(loc, llvm1DVectorTy, operands[0],
78*cb4a5eaeSRobert Suderman                                          zero);
7923149d52SRob Suderman         },
8023149d52SRob Suderman         rewriter);
8123149d52SRob Suderman   }
8223149d52SRob Suderman };
8323149d52SRob Suderman 
8423149d52SRob Suderman using CountLeadingZerosOpLowering =
8523149d52SRob Suderman     CountOpLowering<math::CountLeadingZerosOp, LLVM::CountLeadingZerosOp>;
8623149d52SRob Suderman using CountTrailingZerosOpLowering =
8723149d52SRob Suderman     CountOpLowering<math::CountTrailingZerosOp, LLVM::CountTrailingZerosOp>;
8823149d52SRob Suderman 
8926e59cc1SAlex Zinenko // A `expm1` is converted into `exp - 1`.
9026e59cc1SAlex Zinenko struct ExpM1OpLowering : public ConvertOpToLLVMPattern<math::ExpM1Op> {
9126e59cc1SAlex Zinenko   using ConvertOpToLLVMPattern<math::ExpM1Op>::ConvertOpToLLVMPattern;
9226e59cc1SAlex Zinenko 
9326e59cc1SAlex Zinenko   LogicalResult
94ef976337SRiver Riddle   matchAndRewrite(math::ExpM1Op op, OpAdaptor adaptor,
9526e59cc1SAlex Zinenko                   ConversionPatternRewriter &rewriter) const override {
9662fea88bSJacques Pienaar     auto operandType = adaptor.getOperand().getType();
9726e59cc1SAlex Zinenko 
9826e59cc1SAlex Zinenko     if (!operandType || !LLVM::isCompatibleType(operandType))
9926e59cc1SAlex Zinenko       return failure();
10026e59cc1SAlex Zinenko 
10126e59cc1SAlex Zinenko     auto loc = op.getLoc();
10226e59cc1SAlex Zinenko     auto resultType = op.getResult().getType();
10326e59cc1SAlex Zinenko     auto floatType = getElementTypeOrSelf(resultType).cast<FloatType>();
10426e59cc1SAlex Zinenko     auto floatOne = rewriter.getFloatAttr(floatType, 1.0);
10526e59cc1SAlex Zinenko 
10626e59cc1SAlex Zinenko     if (!operandType.isa<LLVM::LLVMArrayType>()) {
10726e59cc1SAlex Zinenko       LLVM::ConstantOp one;
10826e59cc1SAlex Zinenko       if (LLVM::isCompatibleVectorType(operandType)) {
10926e59cc1SAlex Zinenko         one = rewriter.create<LLVM::ConstantOp>(
11026e59cc1SAlex Zinenko             loc, operandType,
11126e59cc1SAlex Zinenko             SplatElementsAttr::get(resultType.cast<ShapedType>(), floatOne));
11226e59cc1SAlex Zinenko       } else {
11326e59cc1SAlex Zinenko         one = rewriter.create<LLVM::ConstantOp>(loc, operandType, floatOne);
11426e59cc1SAlex Zinenko       }
11562fea88bSJacques Pienaar       auto exp = rewriter.create<LLVM::ExpOp>(loc, adaptor.getOperand());
11626e59cc1SAlex Zinenko       rewriter.replaceOpWithNewOp<LLVM::FSubOp>(op, operandType, exp, one);
11726e59cc1SAlex Zinenko       return success();
11826e59cc1SAlex Zinenko     }
11926e59cc1SAlex Zinenko 
12026e59cc1SAlex Zinenko     auto vectorType = resultType.dyn_cast<VectorType>();
12126e59cc1SAlex Zinenko     if (!vectorType)
12226e59cc1SAlex Zinenko       return rewriter.notifyMatchFailure(op, "expected vector result type");
12326e59cc1SAlex Zinenko 
12426e59cc1SAlex Zinenko     return LLVM::detail::handleMultidimensionalVectors(
125ef976337SRiver Riddle         op.getOperation(), adaptor.getOperands(), *getTypeConverter(),
12626e59cc1SAlex Zinenko         [&](Type llvm1DVectorTy, ValueRange operands) {
12726e59cc1SAlex Zinenko           auto splatAttr = SplatElementsAttr::get(
12826e59cc1SAlex Zinenko               mlir::VectorType::get(
12926e59cc1SAlex Zinenko                   {LLVM::getVectorNumElements(llvm1DVectorTy).getFixedValue()},
13026e59cc1SAlex Zinenko                   floatType),
13126e59cc1SAlex Zinenko               floatOne);
13226e59cc1SAlex Zinenko           auto one =
13326e59cc1SAlex Zinenko               rewriter.create<LLVM::ConstantOp>(loc, llvm1DVectorTy, splatAttr);
13426e59cc1SAlex Zinenko           auto exp =
13526e59cc1SAlex Zinenko               rewriter.create<LLVM::ExpOp>(loc, llvm1DVectorTy, operands[0]);
13626e59cc1SAlex Zinenko           return rewriter.create<LLVM::FSubOp>(loc, llvm1DVectorTy, exp, one);
13726e59cc1SAlex Zinenko         },
13826e59cc1SAlex Zinenko         rewriter);
13926e59cc1SAlex Zinenko   }
14026e59cc1SAlex Zinenko };
14126e59cc1SAlex Zinenko 
14226e59cc1SAlex Zinenko // A `log1p` is converted into `log(1 + ...)`.
14326e59cc1SAlex Zinenko struct Log1pOpLowering : public ConvertOpToLLVMPattern<math::Log1pOp> {
14426e59cc1SAlex Zinenko   using ConvertOpToLLVMPattern<math::Log1pOp>::ConvertOpToLLVMPattern;
14526e59cc1SAlex Zinenko 
14626e59cc1SAlex Zinenko   LogicalResult
147ef976337SRiver Riddle   matchAndRewrite(math::Log1pOp op, OpAdaptor adaptor,
14826e59cc1SAlex Zinenko                   ConversionPatternRewriter &rewriter) const override {
14962fea88bSJacques Pienaar     auto operandType = adaptor.getOperand().getType();
15026e59cc1SAlex Zinenko 
15126e59cc1SAlex Zinenko     if (!operandType || !LLVM::isCompatibleType(operandType))
15226e59cc1SAlex Zinenko       return rewriter.notifyMatchFailure(op, "unsupported operand type");
15326e59cc1SAlex Zinenko 
15426e59cc1SAlex Zinenko     auto loc = op.getLoc();
15526e59cc1SAlex Zinenko     auto resultType = op.getResult().getType();
15626e59cc1SAlex Zinenko     auto floatType = getElementTypeOrSelf(resultType).cast<FloatType>();
15726e59cc1SAlex Zinenko     auto floatOne = rewriter.getFloatAttr(floatType, 1.0);
15826e59cc1SAlex Zinenko 
15926e59cc1SAlex Zinenko     if (!operandType.isa<LLVM::LLVMArrayType>()) {
16026e59cc1SAlex Zinenko       LLVM::ConstantOp one =
16126e59cc1SAlex Zinenko           LLVM::isCompatibleVectorType(operandType)
16226e59cc1SAlex Zinenko               ? rewriter.create<LLVM::ConstantOp>(
16326e59cc1SAlex Zinenko                     loc, operandType,
16426e59cc1SAlex Zinenko                     SplatElementsAttr::get(resultType.cast<ShapedType>(),
16526e59cc1SAlex Zinenko                                            floatOne))
16626e59cc1SAlex Zinenko               : rewriter.create<LLVM::ConstantOp>(loc, operandType, floatOne);
16726e59cc1SAlex Zinenko 
16826e59cc1SAlex Zinenko       auto add = rewriter.create<LLVM::FAddOp>(loc, operandType, one,
16962fea88bSJacques Pienaar                                                adaptor.getOperand());
17026e59cc1SAlex Zinenko       rewriter.replaceOpWithNewOp<LLVM::LogOp>(op, operandType, add);
17126e59cc1SAlex Zinenko       return success();
17226e59cc1SAlex Zinenko     }
17326e59cc1SAlex Zinenko 
17426e59cc1SAlex Zinenko     auto vectorType = resultType.dyn_cast<VectorType>();
17526e59cc1SAlex Zinenko     if (!vectorType)
17626e59cc1SAlex Zinenko       return rewriter.notifyMatchFailure(op, "expected vector result type");
17726e59cc1SAlex Zinenko 
17826e59cc1SAlex Zinenko     return LLVM::detail::handleMultidimensionalVectors(
179ef976337SRiver Riddle         op.getOperation(), adaptor.getOperands(), *getTypeConverter(),
18026e59cc1SAlex Zinenko         [&](Type llvm1DVectorTy, ValueRange operands) {
18126e59cc1SAlex Zinenko           auto splatAttr = SplatElementsAttr::get(
18226e59cc1SAlex Zinenko               mlir::VectorType::get(
18326e59cc1SAlex Zinenko                   {LLVM::getVectorNumElements(llvm1DVectorTy).getFixedValue()},
18426e59cc1SAlex Zinenko                   floatType),
18526e59cc1SAlex Zinenko               floatOne);
18626e59cc1SAlex Zinenko           auto one =
18726e59cc1SAlex Zinenko               rewriter.create<LLVM::ConstantOp>(loc, llvm1DVectorTy, splatAttr);
18826e59cc1SAlex Zinenko           auto add = rewriter.create<LLVM::FAddOp>(loc, llvm1DVectorTy, one,
18926e59cc1SAlex Zinenko                                                    operands[0]);
19026e59cc1SAlex Zinenko           return rewriter.create<LLVM::LogOp>(loc, llvm1DVectorTy, add);
19126e59cc1SAlex Zinenko         },
19226e59cc1SAlex Zinenko         rewriter);
19326e59cc1SAlex Zinenko   }
19426e59cc1SAlex Zinenko };
19526e59cc1SAlex Zinenko 
19626e59cc1SAlex Zinenko // A `rsqrt` is converted into `1 / sqrt`.
19726e59cc1SAlex Zinenko struct RsqrtOpLowering : public ConvertOpToLLVMPattern<math::RsqrtOp> {
19826e59cc1SAlex Zinenko   using ConvertOpToLLVMPattern<math::RsqrtOp>::ConvertOpToLLVMPattern;
19926e59cc1SAlex Zinenko 
20026e59cc1SAlex Zinenko   LogicalResult
201ef976337SRiver Riddle   matchAndRewrite(math::RsqrtOp op, OpAdaptor adaptor,
20226e59cc1SAlex Zinenko                   ConversionPatternRewriter &rewriter) const override {
20362fea88bSJacques Pienaar     auto operandType = adaptor.getOperand().getType();
20426e59cc1SAlex Zinenko 
20526e59cc1SAlex Zinenko     if (!operandType || !LLVM::isCompatibleType(operandType))
20626e59cc1SAlex Zinenko       return failure();
20726e59cc1SAlex Zinenko 
20826e59cc1SAlex Zinenko     auto loc = op.getLoc();
20926e59cc1SAlex Zinenko     auto resultType = op.getResult().getType();
21026e59cc1SAlex Zinenko     auto floatType = getElementTypeOrSelf(resultType).cast<FloatType>();
21126e59cc1SAlex Zinenko     auto floatOne = rewriter.getFloatAttr(floatType, 1.0);
21226e59cc1SAlex Zinenko 
21326e59cc1SAlex Zinenko     if (!operandType.isa<LLVM::LLVMArrayType>()) {
21426e59cc1SAlex Zinenko       LLVM::ConstantOp one;
21526e59cc1SAlex Zinenko       if (LLVM::isCompatibleVectorType(operandType)) {
21626e59cc1SAlex Zinenko         one = rewriter.create<LLVM::ConstantOp>(
21726e59cc1SAlex Zinenko             loc, operandType,
21826e59cc1SAlex Zinenko             SplatElementsAttr::get(resultType.cast<ShapedType>(), floatOne));
21926e59cc1SAlex Zinenko       } else {
22026e59cc1SAlex Zinenko         one = rewriter.create<LLVM::ConstantOp>(loc, operandType, floatOne);
22126e59cc1SAlex Zinenko       }
22262fea88bSJacques Pienaar       auto sqrt = rewriter.create<LLVM::SqrtOp>(loc, adaptor.getOperand());
22326e59cc1SAlex Zinenko       rewriter.replaceOpWithNewOp<LLVM::FDivOp>(op, operandType, one, sqrt);
22426e59cc1SAlex Zinenko       return success();
22526e59cc1SAlex Zinenko     }
22626e59cc1SAlex Zinenko 
22726e59cc1SAlex Zinenko     auto vectorType = resultType.dyn_cast<VectorType>();
22826e59cc1SAlex Zinenko     if (!vectorType)
22926e59cc1SAlex Zinenko       return failure();
23026e59cc1SAlex Zinenko 
23126e59cc1SAlex Zinenko     return LLVM::detail::handleMultidimensionalVectors(
232ef976337SRiver Riddle         op.getOperation(), adaptor.getOperands(), *getTypeConverter(),
23326e59cc1SAlex Zinenko         [&](Type llvm1DVectorTy, ValueRange operands) {
23426e59cc1SAlex Zinenko           auto splatAttr = SplatElementsAttr::get(
23526e59cc1SAlex Zinenko               mlir::VectorType::get(
23626e59cc1SAlex Zinenko                   {LLVM::getVectorNumElements(llvm1DVectorTy).getFixedValue()},
23726e59cc1SAlex Zinenko                   floatType),
23826e59cc1SAlex Zinenko               floatOne);
23926e59cc1SAlex Zinenko           auto one =
24026e59cc1SAlex Zinenko               rewriter.create<LLVM::ConstantOp>(loc, llvm1DVectorTy, splatAttr);
24126e59cc1SAlex Zinenko           auto sqrt =
24226e59cc1SAlex Zinenko               rewriter.create<LLVM::SqrtOp>(loc, llvm1DVectorTy, operands[0]);
24326e59cc1SAlex Zinenko           return rewriter.create<LLVM::FDivOp>(loc, llvm1DVectorTy, one, sqrt);
24426e59cc1SAlex Zinenko         },
24526e59cc1SAlex Zinenko         rewriter);
24626e59cc1SAlex Zinenko   }
24726e59cc1SAlex Zinenko };
24826e59cc1SAlex Zinenko 
24926e59cc1SAlex Zinenko struct ConvertMathToLLVMPass
25026e59cc1SAlex Zinenko     : public ConvertMathToLLVMBase<ConvertMathToLLVMPass> {
25126e59cc1SAlex Zinenko   ConvertMathToLLVMPass() = default;
25226e59cc1SAlex Zinenko 
25341574554SRiver Riddle   void runOnOperation() override {
25426e59cc1SAlex Zinenko     RewritePatternSet patterns(&getContext());
25526e59cc1SAlex Zinenko     LLVMTypeConverter converter(&getContext());
25626e59cc1SAlex Zinenko     populateMathToLLVMConversionPatterns(converter, patterns);
25726e59cc1SAlex Zinenko     LLVMConversionTarget target(getContext());
25841574554SRiver Riddle     if (failed(applyPartialConversion(getOperation(), target,
25941574554SRiver Riddle                                       std::move(patterns))))
26026e59cc1SAlex Zinenko       signalPassFailure();
26126e59cc1SAlex Zinenko   }
26226e59cc1SAlex Zinenko };
26326e59cc1SAlex Zinenko } // namespace
26426e59cc1SAlex Zinenko 
26526e59cc1SAlex Zinenko void mlir::populateMathToLLVMConversionPatterns(LLVMTypeConverter &converter,
26626e59cc1SAlex Zinenko                                                 RewritePatternSet &patterns) {
26726e59cc1SAlex Zinenko   // clang-format off
26826e59cc1SAlex Zinenko   patterns.add<
269a54f4eaeSMogball     AbsOpLowering,
270a54f4eaeSMogball     CeilOpLowering,
271a54f4eaeSMogball     CopySignOpLowering,
27226e59cc1SAlex Zinenko     CosOpLowering,
27323149d52SRob Suderman     CountLeadingZerosOpLowering,
27423149d52SRob Suderman     CountTrailingZerosOpLowering,
275c5fef77bSRob Suderman     CtPopFOpLowering,
27626e59cc1SAlex Zinenko     ExpOpLowering,
27726e59cc1SAlex Zinenko     Exp2OpLowering,
27826e59cc1SAlex Zinenko     ExpM1OpLowering,
279a54f4eaeSMogball     FloorOpLowering,
280a54f4eaeSMogball     FmaOpLowering,
28126e59cc1SAlex Zinenko     Log10OpLowering,
28226e59cc1SAlex Zinenko     Log1pOpLowering,
28326e59cc1SAlex Zinenko     Log2OpLowering,
28426e59cc1SAlex Zinenko     LogOpLowering,
28526e59cc1SAlex Zinenko     PowFOpLowering,
28626e59cc1SAlex Zinenko     RsqrtOpLowering,
28726e59cc1SAlex Zinenko     SinOpLowering,
28826e59cc1SAlex Zinenko     SqrtOpLowering
28926e59cc1SAlex Zinenko   >(converter);
29026e59cc1SAlex Zinenko   // clang-format on
29126e59cc1SAlex Zinenko }
29226e59cc1SAlex Zinenko 
29326e59cc1SAlex Zinenko std::unique_ptr<Pass> mlir::createConvertMathToLLVMPass() {
29426e59cc1SAlex Zinenko   return std::make_unique<ConvertMathToLLVMPass>();
29526e59cc1SAlex Zinenko }
296