126e59cc1SAlex Zinenko //===- MathToLLVM.cpp - Math to LLVM dialect conversion -------------------===//
226e59cc1SAlex Zinenko //
326e59cc1SAlex Zinenko // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
426e59cc1SAlex Zinenko // See https://llvm.org/LICENSE.txt for license information.
526e59cc1SAlex Zinenko // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
626e59cc1SAlex Zinenko //
726e59cc1SAlex Zinenko //===----------------------------------------------------------------------===//
826e59cc1SAlex Zinenko
926e59cc1SAlex Zinenko #include "mlir/Conversion/MathToLLVM/MathToLLVM.h"
1026e59cc1SAlex Zinenko #include "../PassDetail.h"
1126e59cc1SAlex Zinenko #include "mlir/Conversion/LLVMCommon/ConversionTarget.h"
1226e59cc1SAlex Zinenko #include "mlir/Conversion/LLVMCommon/Pattern.h"
1326e59cc1SAlex Zinenko #include "mlir/Conversion/LLVMCommon/VectorPattern.h"
1426e59cc1SAlex Zinenko #include "mlir/Dialect/LLVMIR/LLVMDialect.h"
1526e59cc1SAlex Zinenko #include "mlir/Dialect/Math/IR/Math.h"
1626e59cc1SAlex Zinenko #include "mlir/IR/TypeUtilities.h"
1726e59cc1SAlex Zinenko
1826e59cc1SAlex Zinenko using namespace mlir;
1926e59cc1SAlex Zinenko
2026e59cc1SAlex Zinenko namespace {
21a54f4eaeSMogball using AbsOpLowering = VectorConvertToLLVMPattern<math::AbsOp, LLVM::FAbsOp>;
22a54f4eaeSMogball using CeilOpLowering = VectorConvertToLLVMPattern<math::CeilOp, LLVM::FCeilOp>;
23a54f4eaeSMogball using CopySignOpLowering =
24a54f4eaeSMogball VectorConvertToLLVMPattern<math::CopySignOp, LLVM::CopySignOp>;
2526e59cc1SAlex Zinenko using CosOpLowering = VectorConvertToLLVMPattern<math::CosOp, LLVM::CosOp>;
26c5fef77bSRob Suderman using CtPopFOpLowering =
27c5fef77bSRob Suderman VectorConvertToLLVMPattern<math::CtPopOp, LLVM::CtPopOp>;
2826e59cc1SAlex Zinenko using Exp2OpLowering = VectorConvertToLLVMPattern<math::Exp2Op, LLVM::Exp2Op>;
29*2a3c07f8Slorenzo chelini using ExpOpLowering = VectorConvertToLLVMPattern<math::ExpOp, LLVM::ExpOp>;
30a54f4eaeSMogball using FloorOpLowering =
31a54f4eaeSMogball VectorConvertToLLVMPattern<math::FloorOp, LLVM::FFloorOp>;
32a54f4eaeSMogball using FmaOpLowering = VectorConvertToLLVMPattern<math::FmaOp, LLVM::FMAOp>;
3326e59cc1SAlex Zinenko using Log10OpLowering =
3426e59cc1SAlex Zinenko VectorConvertToLLVMPattern<math::Log10Op, LLVM::Log10Op>;
3526e59cc1SAlex Zinenko using Log2OpLowering = VectorConvertToLLVMPattern<math::Log2Op, LLVM::Log2Op>;
3626e59cc1SAlex Zinenko using LogOpLowering = VectorConvertToLLVMPattern<math::LogOp, LLVM::LogOp>;
3726e59cc1SAlex Zinenko using PowFOpLowering = VectorConvertToLLVMPattern<math::PowFOp, LLVM::PowOp>;
38a0fc94abSlorenzo chelini using RoundOpLowering =
39a0fc94abSlorenzo chelini VectorConvertToLLVMPattern<math::RoundOp, LLVM::RoundOp>;
40*2a3c07f8Slorenzo chelini using SinOpLowering = VectorConvertToLLVMPattern<math::SinOp, LLVM::SinOp>;
41*2a3c07f8Slorenzo chelini using SqrtOpLowering = VectorConvertToLLVMPattern<math::SqrtOp, LLVM::SqrtOp>;
4226e59cc1SAlex Zinenko
4323149d52SRob Suderman // A `CtLz/CtTz(a)` is converted into `CtLz/CtTz(a, false)`.
4423149d52SRob Suderman template <typename MathOp, typename LLVMOp>
4523149d52SRob Suderman struct CountOpLowering : public ConvertOpToLLVMPattern<MathOp> {
4623149d52SRob Suderman using ConvertOpToLLVMPattern<MathOp>::ConvertOpToLLVMPattern;
4723149d52SRob Suderman using Super = CountOpLowering<MathOp, LLVMOp>;
4823149d52SRob Suderman
4923149d52SRob Suderman LogicalResult
matchAndRewrite__anoncde9de6d0111::CountOpLowering5023149d52SRob Suderman matchAndRewrite(MathOp op, typename MathOp::Adaptor adaptor,
5123149d52SRob Suderman ConversionPatternRewriter &rewriter) const override {
5223149d52SRob Suderman auto operandType = adaptor.getOperand().getType();
5323149d52SRob Suderman
5423149d52SRob Suderman if (!operandType || !LLVM::isCompatibleType(operandType))
5523149d52SRob Suderman return failure();
5623149d52SRob Suderman
5723149d52SRob Suderman auto loc = op.getLoc();
5823149d52SRob Suderman auto resultType = op.getResult().getType();
5923149d52SRob Suderman auto boolType = rewriter.getIntegerType(1);
6023149d52SRob Suderman auto boolZero = rewriter.getIntegerAttr(boolType, 0);
6123149d52SRob Suderman
6223149d52SRob Suderman if (!operandType.template isa<LLVM::LLVMArrayType>()) {
6323149d52SRob Suderman LLVM::ConstantOp zero =
6423149d52SRob Suderman rewriter.create<LLVM::ConstantOp>(loc, boolType, boolZero);
6523149d52SRob Suderman rewriter.replaceOpWithNewOp<LLVMOp>(op, resultType, adaptor.getOperand(),
6623149d52SRob Suderman zero);
6723149d52SRob Suderman return success();
6823149d52SRob Suderman }
6923149d52SRob Suderman
7023149d52SRob Suderman auto vectorType = resultType.template dyn_cast<VectorType>();
7123149d52SRob Suderman if (!vectorType)
7223149d52SRob Suderman return failure();
7323149d52SRob Suderman
7423149d52SRob Suderman return LLVM::detail::handleMultidimensionalVectors(
7523149d52SRob Suderman op.getOperation(), adaptor.getOperands(), *this->getTypeConverter(),
7623149d52SRob Suderman [&](Type llvm1DVectorTy, ValueRange operands) {
7723149d52SRob Suderman LLVM::ConstantOp zero =
7823149d52SRob Suderman rewriter.create<LLVM::ConstantOp>(loc, boolType, boolZero);
79cb4a5eaeSRobert Suderman return rewriter.create<LLVMOp>(loc, llvm1DVectorTy, operands[0],
80cb4a5eaeSRobert Suderman zero);
8123149d52SRob Suderman },
8223149d52SRob Suderman rewriter);
8323149d52SRob Suderman }
8423149d52SRob Suderman };
8523149d52SRob Suderman
8623149d52SRob Suderman using CountLeadingZerosOpLowering =
8723149d52SRob Suderman CountOpLowering<math::CountLeadingZerosOp, LLVM::CountLeadingZerosOp>;
8823149d52SRob Suderman using CountTrailingZerosOpLowering =
8923149d52SRob Suderman CountOpLowering<math::CountTrailingZerosOp, LLVM::CountTrailingZerosOp>;
9023149d52SRob Suderman
9126e59cc1SAlex Zinenko // A `expm1` is converted into `exp - 1`.
9226e59cc1SAlex Zinenko struct ExpM1OpLowering : public ConvertOpToLLVMPattern<math::ExpM1Op> {
9326e59cc1SAlex Zinenko using ConvertOpToLLVMPattern<math::ExpM1Op>::ConvertOpToLLVMPattern;
9426e59cc1SAlex Zinenko
9526e59cc1SAlex Zinenko LogicalResult
matchAndRewrite__anoncde9de6d0111::ExpM1OpLowering96ef976337SRiver Riddle matchAndRewrite(math::ExpM1Op op, OpAdaptor adaptor,
9726e59cc1SAlex Zinenko ConversionPatternRewriter &rewriter) const override {
9862fea88bSJacques Pienaar auto operandType = adaptor.getOperand().getType();
9926e59cc1SAlex Zinenko
10026e59cc1SAlex Zinenko if (!operandType || !LLVM::isCompatibleType(operandType))
10126e59cc1SAlex Zinenko return failure();
10226e59cc1SAlex Zinenko
10326e59cc1SAlex Zinenko auto loc = op.getLoc();
10426e59cc1SAlex Zinenko auto resultType = op.getResult().getType();
10526e59cc1SAlex Zinenko auto floatType = getElementTypeOrSelf(resultType).cast<FloatType>();
10626e59cc1SAlex Zinenko auto floatOne = rewriter.getFloatAttr(floatType, 1.0);
10726e59cc1SAlex Zinenko
10826e59cc1SAlex Zinenko if (!operandType.isa<LLVM::LLVMArrayType>()) {
10926e59cc1SAlex Zinenko LLVM::ConstantOp one;
11026e59cc1SAlex Zinenko if (LLVM::isCompatibleVectorType(operandType)) {
11126e59cc1SAlex Zinenko one = rewriter.create<LLVM::ConstantOp>(
11226e59cc1SAlex Zinenko loc, operandType,
11326e59cc1SAlex Zinenko SplatElementsAttr::get(resultType.cast<ShapedType>(), floatOne));
11426e59cc1SAlex Zinenko } else {
11526e59cc1SAlex Zinenko one = rewriter.create<LLVM::ConstantOp>(loc, operandType, floatOne);
11626e59cc1SAlex Zinenko }
11762fea88bSJacques Pienaar auto exp = rewriter.create<LLVM::ExpOp>(loc, adaptor.getOperand());
11826e59cc1SAlex Zinenko rewriter.replaceOpWithNewOp<LLVM::FSubOp>(op, operandType, exp, one);
11926e59cc1SAlex Zinenko return success();
12026e59cc1SAlex Zinenko }
12126e59cc1SAlex Zinenko
12226e59cc1SAlex Zinenko auto vectorType = resultType.dyn_cast<VectorType>();
12326e59cc1SAlex Zinenko if (!vectorType)
12426e59cc1SAlex Zinenko return rewriter.notifyMatchFailure(op, "expected vector result type");
12526e59cc1SAlex Zinenko
12626e59cc1SAlex Zinenko return LLVM::detail::handleMultidimensionalVectors(
127ef976337SRiver Riddle op.getOperation(), adaptor.getOperands(), *getTypeConverter(),
12826e59cc1SAlex Zinenko [&](Type llvm1DVectorTy, ValueRange operands) {
12926e59cc1SAlex Zinenko auto splatAttr = SplatElementsAttr::get(
13026e59cc1SAlex Zinenko mlir::VectorType::get(
13126e59cc1SAlex Zinenko {LLVM::getVectorNumElements(llvm1DVectorTy).getFixedValue()},
13226e59cc1SAlex Zinenko floatType),
13326e59cc1SAlex Zinenko floatOne);
13426e59cc1SAlex Zinenko auto one =
13526e59cc1SAlex Zinenko rewriter.create<LLVM::ConstantOp>(loc, llvm1DVectorTy, splatAttr);
13626e59cc1SAlex Zinenko auto exp =
13726e59cc1SAlex Zinenko rewriter.create<LLVM::ExpOp>(loc, llvm1DVectorTy, operands[0]);
13826e59cc1SAlex Zinenko return rewriter.create<LLVM::FSubOp>(loc, llvm1DVectorTy, exp, one);
13926e59cc1SAlex Zinenko },
14026e59cc1SAlex Zinenko rewriter);
14126e59cc1SAlex Zinenko }
14226e59cc1SAlex Zinenko };
14326e59cc1SAlex Zinenko
14426e59cc1SAlex Zinenko // A `log1p` is converted into `log(1 + ...)`.
14526e59cc1SAlex Zinenko struct Log1pOpLowering : public ConvertOpToLLVMPattern<math::Log1pOp> {
14626e59cc1SAlex Zinenko using ConvertOpToLLVMPattern<math::Log1pOp>::ConvertOpToLLVMPattern;
14726e59cc1SAlex Zinenko
14826e59cc1SAlex Zinenko LogicalResult
matchAndRewrite__anoncde9de6d0111::Log1pOpLowering149ef976337SRiver Riddle matchAndRewrite(math::Log1pOp op, OpAdaptor adaptor,
15026e59cc1SAlex Zinenko ConversionPatternRewriter &rewriter) const override {
15162fea88bSJacques Pienaar auto operandType = adaptor.getOperand().getType();
15226e59cc1SAlex Zinenko
15326e59cc1SAlex Zinenko if (!operandType || !LLVM::isCompatibleType(operandType))
15426e59cc1SAlex Zinenko return rewriter.notifyMatchFailure(op, "unsupported operand type");
15526e59cc1SAlex Zinenko
15626e59cc1SAlex Zinenko auto loc = op.getLoc();
15726e59cc1SAlex Zinenko auto resultType = op.getResult().getType();
15826e59cc1SAlex Zinenko auto floatType = getElementTypeOrSelf(resultType).cast<FloatType>();
15926e59cc1SAlex Zinenko auto floatOne = rewriter.getFloatAttr(floatType, 1.0);
16026e59cc1SAlex Zinenko
16126e59cc1SAlex Zinenko if (!operandType.isa<LLVM::LLVMArrayType>()) {
16226e59cc1SAlex Zinenko LLVM::ConstantOp one =
16326e59cc1SAlex Zinenko LLVM::isCompatibleVectorType(operandType)
16426e59cc1SAlex Zinenko ? rewriter.create<LLVM::ConstantOp>(
16526e59cc1SAlex Zinenko loc, operandType,
16626e59cc1SAlex Zinenko SplatElementsAttr::get(resultType.cast<ShapedType>(),
16726e59cc1SAlex Zinenko floatOne))
16826e59cc1SAlex Zinenko : rewriter.create<LLVM::ConstantOp>(loc, operandType, floatOne);
16926e59cc1SAlex Zinenko
17026e59cc1SAlex Zinenko auto add = rewriter.create<LLVM::FAddOp>(loc, operandType, one,
17162fea88bSJacques Pienaar adaptor.getOperand());
17226e59cc1SAlex Zinenko rewriter.replaceOpWithNewOp<LLVM::LogOp>(op, operandType, add);
17326e59cc1SAlex Zinenko return success();
17426e59cc1SAlex Zinenko }
17526e59cc1SAlex Zinenko
17626e59cc1SAlex Zinenko auto vectorType = resultType.dyn_cast<VectorType>();
17726e59cc1SAlex Zinenko if (!vectorType)
17826e59cc1SAlex Zinenko return rewriter.notifyMatchFailure(op, "expected vector result type");
17926e59cc1SAlex Zinenko
18026e59cc1SAlex Zinenko return LLVM::detail::handleMultidimensionalVectors(
181ef976337SRiver Riddle op.getOperation(), adaptor.getOperands(), *getTypeConverter(),
18226e59cc1SAlex Zinenko [&](Type llvm1DVectorTy, ValueRange operands) {
18326e59cc1SAlex Zinenko auto splatAttr = SplatElementsAttr::get(
18426e59cc1SAlex Zinenko mlir::VectorType::get(
18526e59cc1SAlex Zinenko {LLVM::getVectorNumElements(llvm1DVectorTy).getFixedValue()},
18626e59cc1SAlex Zinenko floatType),
18726e59cc1SAlex Zinenko floatOne);
18826e59cc1SAlex Zinenko auto one =
18926e59cc1SAlex Zinenko rewriter.create<LLVM::ConstantOp>(loc, llvm1DVectorTy, splatAttr);
19026e59cc1SAlex Zinenko auto add = rewriter.create<LLVM::FAddOp>(loc, llvm1DVectorTy, one,
19126e59cc1SAlex Zinenko operands[0]);
19226e59cc1SAlex Zinenko return rewriter.create<LLVM::LogOp>(loc, llvm1DVectorTy, add);
19326e59cc1SAlex Zinenko },
19426e59cc1SAlex Zinenko rewriter);
19526e59cc1SAlex Zinenko }
19626e59cc1SAlex Zinenko };
19726e59cc1SAlex Zinenko
19826e59cc1SAlex Zinenko // A `rsqrt` is converted into `1 / sqrt`.
19926e59cc1SAlex Zinenko struct RsqrtOpLowering : public ConvertOpToLLVMPattern<math::RsqrtOp> {
20026e59cc1SAlex Zinenko using ConvertOpToLLVMPattern<math::RsqrtOp>::ConvertOpToLLVMPattern;
20126e59cc1SAlex Zinenko
20226e59cc1SAlex Zinenko LogicalResult
matchAndRewrite__anoncde9de6d0111::RsqrtOpLowering203ef976337SRiver Riddle matchAndRewrite(math::RsqrtOp op, OpAdaptor adaptor,
20426e59cc1SAlex Zinenko ConversionPatternRewriter &rewriter) const override {
20562fea88bSJacques Pienaar auto operandType = adaptor.getOperand().getType();
20626e59cc1SAlex Zinenko
20726e59cc1SAlex Zinenko if (!operandType || !LLVM::isCompatibleType(operandType))
20826e59cc1SAlex Zinenko return failure();
20926e59cc1SAlex Zinenko
21026e59cc1SAlex Zinenko auto loc = op.getLoc();
21126e59cc1SAlex Zinenko auto resultType = op.getResult().getType();
21226e59cc1SAlex Zinenko auto floatType = getElementTypeOrSelf(resultType).cast<FloatType>();
21326e59cc1SAlex Zinenko auto floatOne = rewriter.getFloatAttr(floatType, 1.0);
21426e59cc1SAlex Zinenko
21526e59cc1SAlex Zinenko if (!operandType.isa<LLVM::LLVMArrayType>()) {
21626e59cc1SAlex Zinenko LLVM::ConstantOp one;
21726e59cc1SAlex Zinenko if (LLVM::isCompatibleVectorType(operandType)) {
21826e59cc1SAlex Zinenko one = rewriter.create<LLVM::ConstantOp>(
21926e59cc1SAlex Zinenko loc, operandType,
22026e59cc1SAlex Zinenko SplatElementsAttr::get(resultType.cast<ShapedType>(), floatOne));
22126e59cc1SAlex Zinenko } else {
22226e59cc1SAlex Zinenko one = rewriter.create<LLVM::ConstantOp>(loc, operandType, floatOne);
22326e59cc1SAlex Zinenko }
22462fea88bSJacques Pienaar auto sqrt = rewriter.create<LLVM::SqrtOp>(loc, adaptor.getOperand());
22526e59cc1SAlex Zinenko rewriter.replaceOpWithNewOp<LLVM::FDivOp>(op, operandType, one, sqrt);
22626e59cc1SAlex Zinenko return success();
22726e59cc1SAlex Zinenko }
22826e59cc1SAlex Zinenko
22926e59cc1SAlex Zinenko auto vectorType = resultType.dyn_cast<VectorType>();
23026e59cc1SAlex Zinenko if (!vectorType)
23126e59cc1SAlex Zinenko return failure();
23226e59cc1SAlex Zinenko
23326e59cc1SAlex Zinenko return LLVM::detail::handleMultidimensionalVectors(
234ef976337SRiver Riddle op.getOperation(), adaptor.getOperands(), *getTypeConverter(),
23526e59cc1SAlex Zinenko [&](Type llvm1DVectorTy, ValueRange operands) {
23626e59cc1SAlex Zinenko auto splatAttr = SplatElementsAttr::get(
23726e59cc1SAlex Zinenko mlir::VectorType::get(
23826e59cc1SAlex Zinenko {LLVM::getVectorNumElements(llvm1DVectorTy).getFixedValue()},
23926e59cc1SAlex Zinenko floatType),
24026e59cc1SAlex Zinenko floatOne);
24126e59cc1SAlex Zinenko auto one =
24226e59cc1SAlex Zinenko rewriter.create<LLVM::ConstantOp>(loc, llvm1DVectorTy, splatAttr);
24326e59cc1SAlex Zinenko auto sqrt =
24426e59cc1SAlex Zinenko rewriter.create<LLVM::SqrtOp>(loc, llvm1DVectorTy, operands[0]);
24526e59cc1SAlex Zinenko return rewriter.create<LLVM::FDivOp>(loc, llvm1DVectorTy, one, sqrt);
24626e59cc1SAlex Zinenko },
24726e59cc1SAlex Zinenko rewriter);
24826e59cc1SAlex Zinenko }
24926e59cc1SAlex Zinenko };
25026e59cc1SAlex Zinenko
25126e59cc1SAlex Zinenko struct ConvertMathToLLVMPass
25226e59cc1SAlex Zinenko : public ConvertMathToLLVMBase<ConvertMathToLLVMPass> {
25326e59cc1SAlex Zinenko ConvertMathToLLVMPass() = default;
25426e59cc1SAlex Zinenko
runOnOperation__anoncde9de6d0111::ConvertMathToLLVMPass25541574554SRiver Riddle void runOnOperation() override {
25626e59cc1SAlex Zinenko RewritePatternSet patterns(&getContext());
25726e59cc1SAlex Zinenko LLVMTypeConverter converter(&getContext());
25826e59cc1SAlex Zinenko populateMathToLLVMConversionPatterns(converter, patterns);
25926e59cc1SAlex Zinenko LLVMConversionTarget target(getContext());
26041574554SRiver Riddle if (failed(applyPartialConversion(getOperation(), target,
26141574554SRiver Riddle std::move(patterns))))
26226e59cc1SAlex Zinenko signalPassFailure();
26326e59cc1SAlex Zinenko }
26426e59cc1SAlex Zinenko };
26526e59cc1SAlex Zinenko } // namespace
26626e59cc1SAlex Zinenko
populateMathToLLVMConversionPatterns(LLVMTypeConverter & converter,RewritePatternSet & patterns)26726e59cc1SAlex Zinenko void mlir::populateMathToLLVMConversionPatterns(LLVMTypeConverter &converter,
26826e59cc1SAlex Zinenko RewritePatternSet &patterns) {
26926e59cc1SAlex Zinenko // clang-format off
27026e59cc1SAlex Zinenko patterns.add<
271a54f4eaeSMogball AbsOpLowering,
272a54f4eaeSMogball CeilOpLowering,
273a54f4eaeSMogball CopySignOpLowering,
27426e59cc1SAlex Zinenko CosOpLowering,
27523149d52SRob Suderman CountLeadingZerosOpLowering,
27623149d52SRob Suderman CountTrailingZerosOpLowering,
277c5fef77bSRob Suderman CtPopFOpLowering,
27826e59cc1SAlex Zinenko Exp2OpLowering,
27926e59cc1SAlex Zinenko ExpM1OpLowering,
280*2a3c07f8Slorenzo chelini ExpOpLowering,
281a54f4eaeSMogball FloorOpLowering,
282a54f4eaeSMogball FmaOpLowering,
28326e59cc1SAlex Zinenko Log10OpLowering,
28426e59cc1SAlex Zinenko Log1pOpLowering,
28526e59cc1SAlex Zinenko Log2OpLowering,
28626e59cc1SAlex Zinenko LogOpLowering,
28726e59cc1SAlex Zinenko PowFOpLowering,
288*2a3c07f8Slorenzo chelini RoundOpLowering,
28926e59cc1SAlex Zinenko RsqrtOpLowering,
29026e59cc1SAlex Zinenko SinOpLowering,
291*2a3c07f8Slorenzo chelini SqrtOpLowering
29226e59cc1SAlex Zinenko >(converter);
29326e59cc1SAlex Zinenko // clang-format on
29426e59cc1SAlex Zinenko }
29526e59cc1SAlex Zinenko
createConvertMathToLLVMPass()29626e59cc1SAlex Zinenko std::unique_ptr<Pass> mlir::createConvertMathToLLVMPass() {
29726e59cc1SAlex Zinenko return std::make_unique<ConvertMathToLLVMPass>();
29826e59cc1SAlex Zinenko }
299