126e59cc1SAlex Zinenko //===- MathToLLVM.cpp - Math to LLVM dialect conversion -------------------===//
226e59cc1SAlex Zinenko //
326e59cc1SAlex Zinenko // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
426e59cc1SAlex Zinenko // See https://llvm.org/LICENSE.txt for license information.
526e59cc1SAlex Zinenko // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
626e59cc1SAlex Zinenko //
726e59cc1SAlex Zinenko //===----------------------------------------------------------------------===//
826e59cc1SAlex Zinenko 
926e59cc1SAlex Zinenko #include "mlir/Conversion/MathToLLVM/MathToLLVM.h"
1026e59cc1SAlex Zinenko #include "../PassDetail.h"
1126e59cc1SAlex Zinenko #include "mlir/Conversion/LLVMCommon/ConversionTarget.h"
1226e59cc1SAlex Zinenko #include "mlir/Conversion/LLVMCommon/Pattern.h"
1326e59cc1SAlex Zinenko #include "mlir/Conversion/LLVMCommon/VectorPattern.h"
1426e59cc1SAlex Zinenko #include "mlir/Dialect/LLVMIR/LLVMDialect.h"
1526e59cc1SAlex Zinenko #include "mlir/Dialect/Math/IR/Math.h"
1626e59cc1SAlex Zinenko #include "mlir/IR/TypeUtilities.h"
1726e59cc1SAlex Zinenko 
1826e59cc1SAlex Zinenko using namespace mlir;
1926e59cc1SAlex Zinenko 
2026e59cc1SAlex Zinenko namespace {
21a54f4eaeSMogball using AbsOpLowering = VectorConvertToLLVMPattern<math::AbsOp, LLVM::FAbsOp>;
22a54f4eaeSMogball using CeilOpLowering = VectorConvertToLLVMPattern<math::CeilOp, LLVM::FCeilOp>;
23a54f4eaeSMogball using CopySignOpLowering =
24a54f4eaeSMogball     VectorConvertToLLVMPattern<math::CopySignOp, LLVM::CopySignOp>;
2526e59cc1SAlex Zinenko using CosOpLowering = VectorConvertToLLVMPattern<math::CosOp, LLVM::CosOp>;
26*c5fef77bSRob Suderman using CtPopFOpLowering =
27*c5fef77bSRob Suderman     VectorConvertToLLVMPattern<math::CtPopOp, LLVM::CtPopOp>;
2826e59cc1SAlex Zinenko using ExpOpLowering = VectorConvertToLLVMPattern<math::ExpOp, LLVM::ExpOp>;
2926e59cc1SAlex Zinenko using Exp2OpLowering = VectorConvertToLLVMPattern<math::Exp2Op, LLVM::Exp2Op>;
30a54f4eaeSMogball using FloorOpLowering =
31a54f4eaeSMogball     VectorConvertToLLVMPattern<math::FloorOp, LLVM::FFloorOp>;
32a54f4eaeSMogball using FmaOpLowering = VectorConvertToLLVMPattern<math::FmaOp, LLVM::FMAOp>;
3326e59cc1SAlex Zinenko using Log10OpLowering =
3426e59cc1SAlex Zinenko     VectorConvertToLLVMPattern<math::Log10Op, LLVM::Log10Op>;
3526e59cc1SAlex Zinenko using Log2OpLowering = VectorConvertToLLVMPattern<math::Log2Op, LLVM::Log2Op>;
3626e59cc1SAlex Zinenko using LogOpLowering = VectorConvertToLLVMPattern<math::LogOp, LLVM::LogOp>;
3726e59cc1SAlex Zinenko using PowFOpLowering = VectorConvertToLLVMPattern<math::PowFOp, LLVM::PowOp>;
3826e59cc1SAlex Zinenko using SinOpLowering = VectorConvertToLLVMPattern<math::SinOp, LLVM::SinOp>;
3926e59cc1SAlex Zinenko using SqrtOpLowering = VectorConvertToLLVMPattern<math::SqrtOp, LLVM::SqrtOp>;
4026e59cc1SAlex Zinenko 
4126e59cc1SAlex Zinenko // A `expm1` is converted into `exp - 1`.
4226e59cc1SAlex Zinenko struct ExpM1OpLowering : public ConvertOpToLLVMPattern<math::ExpM1Op> {
4326e59cc1SAlex Zinenko   using ConvertOpToLLVMPattern<math::ExpM1Op>::ConvertOpToLLVMPattern;
4426e59cc1SAlex Zinenko 
4526e59cc1SAlex Zinenko   LogicalResult
46ef976337SRiver Riddle   matchAndRewrite(math::ExpM1Op op, OpAdaptor adaptor,
4726e59cc1SAlex Zinenko                   ConversionPatternRewriter &rewriter) const override {
4862fea88bSJacques Pienaar     auto operandType = adaptor.getOperand().getType();
4926e59cc1SAlex Zinenko 
5026e59cc1SAlex Zinenko     if (!operandType || !LLVM::isCompatibleType(operandType))
5126e59cc1SAlex Zinenko       return failure();
5226e59cc1SAlex Zinenko 
5326e59cc1SAlex Zinenko     auto loc = op.getLoc();
5426e59cc1SAlex Zinenko     auto resultType = op.getResult().getType();
5526e59cc1SAlex Zinenko     auto floatType = getElementTypeOrSelf(resultType).cast<FloatType>();
5626e59cc1SAlex Zinenko     auto floatOne = rewriter.getFloatAttr(floatType, 1.0);
5726e59cc1SAlex Zinenko 
5826e59cc1SAlex Zinenko     if (!operandType.isa<LLVM::LLVMArrayType>()) {
5926e59cc1SAlex Zinenko       LLVM::ConstantOp one;
6026e59cc1SAlex Zinenko       if (LLVM::isCompatibleVectorType(operandType)) {
6126e59cc1SAlex Zinenko         one = rewriter.create<LLVM::ConstantOp>(
6226e59cc1SAlex Zinenko             loc, operandType,
6326e59cc1SAlex Zinenko             SplatElementsAttr::get(resultType.cast<ShapedType>(), floatOne));
6426e59cc1SAlex Zinenko       } else {
6526e59cc1SAlex Zinenko         one = rewriter.create<LLVM::ConstantOp>(loc, operandType, floatOne);
6626e59cc1SAlex Zinenko       }
6762fea88bSJacques Pienaar       auto exp = rewriter.create<LLVM::ExpOp>(loc, adaptor.getOperand());
6826e59cc1SAlex Zinenko       rewriter.replaceOpWithNewOp<LLVM::FSubOp>(op, operandType, exp, one);
6926e59cc1SAlex Zinenko       return success();
7026e59cc1SAlex Zinenko     }
7126e59cc1SAlex Zinenko 
7226e59cc1SAlex Zinenko     auto vectorType = resultType.dyn_cast<VectorType>();
7326e59cc1SAlex Zinenko     if (!vectorType)
7426e59cc1SAlex Zinenko       return rewriter.notifyMatchFailure(op, "expected vector result type");
7526e59cc1SAlex Zinenko 
7626e59cc1SAlex Zinenko     return LLVM::detail::handleMultidimensionalVectors(
77ef976337SRiver Riddle         op.getOperation(), adaptor.getOperands(), *getTypeConverter(),
7826e59cc1SAlex Zinenko         [&](Type llvm1DVectorTy, ValueRange operands) {
7926e59cc1SAlex Zinenko           auto splatAttr = SplatElementsAttr::get(
8026e59cc1SAlex Zinenko               mlir::VectorType::get(
8126e59cc1SAlex Zinenko                   {LLVM::getVectorNumElements(llvm1DVectorTy).getFixedValue()},
8226e59cc1SAlex Zinenko                   floatType),
8326e59cc1SAlex Zinenko               floatOne);
8426e59cc1SAlex Zinenko           auto one =
8526e59cc1SAlex Zinenko               rewriter.create<LLVM::ConstantOp>(loc, llvm1DVectorTy, splatAttr);
8626e59cc1SAlex Zinenko           auto exp =
8726e59cc1SAlex Zinenko               rewriter.create<LLVM::ExpOp>(loc, llvm1DVectorTy, operands[0]);
8826e59cc1SAlex Zinenko           return rewriter.create<LLVM::FSubOp>(loc, llvm1DVectorTy, exp, one);
8926e59cc1SAlex Zinenko         },
9026e59cc1SAlex Zinenko         rewriter);
9126e59cc1SAlex Zinenko   }
9226e59cc1SAlex Zinenko };
9326e59cc1SAlex Zinenko 
9426e59cc1SAlex Zinenko // A `log1p` is converted into `log(1 + ...)`.
9526e59cc1SAlex Zinenko struct Log1pOpLowering : public ConvertOpToLLVMPattern<math::Log1pOp> {
9626e59cc1SAlex Zinenko   using ConvertOpToLLVMPattern<math::Log1pOp>::ConvertOpToLLVMPattern;
9726e59cc1SAlex Zinenko 
9826e59cc1SAlex Zinenko   LogicalResult
99ef976337SRiver Riddle   matchAndRewrite(math::Log1pOp op, OpAdaptor adaptor,
10026e59cc1SAlex Zinenko                   ConversionPatternRewriter &rewriter) const override {
10162fea88bSJacques Pienaar     auto operandType = adaptor.getOperand().getType();
10226e59cc1SAlex Zinenko 
10326e59cc1SAlex Zinenko     if (!operandType || !LLVM::isCompatibleType(operandType))
10426e59cc1SAlex Zinenko       return rewriter.notifyMatchFailure(op, "unsupported operand type");
10526e59cc1SAlex Zinenko 
10626e59cc1SAlex Zinenko     auto loc = op.getLoc();
10726e59cc1SAlex Zinenko     auto resultType = op.getResult().getType();
10826e59cc1SAlex Zinenko     auto floatType = getElementTypeOrSelf(resultType).cast<FloatType>();
10926e59cc1SAlex Zinenko     auto floatOne = rewriter.getFloatAttr(floatType, 1.0);
11026e59cc1SAlex Zinenko 
11126e59cc1SAlex Zinenko     if (!operandType.isa<LLVM::LLVMArrayType>()) {
11226e59cc1SAlex Zinenko       LLVM::ConstantOp one =
11326e59cc1SAlex Zinenko           LLVM::isCompatibleVectorType(operandType)
11426e59cc1SAlex Zinenko               ? rewriter.create<LLVM::ConstantOp>(
11526e59cc1SAlex Zinenko                     loc, operandType,
11626e59cc1SAlex Zinenko                     SplatElementsAttr::get(resultType.cast<ShapedType>(),
11726e59cc1SAlex Zinenko                                            floatOne))
11826e59cc1SAlex Zinenko               : rewriter.create<LLVM::ConstantOp>(loc, operandType, floatOne);
11926e59cc1SAlex Zinenko 
12026e59cc1SAlex Zinenko       auto add = rewriter.create<LLVM::FAddOp>(loc, operandType, one,
12162fea88bSJacques Pienaar                                                adaptor.getOperand());
12226e59cc1SAlex Zinenko       rewriter.replaceOpWithNewOp<LLVM::LogOp>(op, operandType, add);
12326e59cc1SAlex Zinenko       return success();
12426e59cc1SAlex Zinenko     }
12526e59cc1SAlex Zinenko 
12626e59cc1SAlex Zinenko     auto vectorType = resultType.dyn_cast<VectorType>();
12726e59cc1SAlex Zinenko     if (!vectorType)
12826e59cc1SAlex Zinenko       return rewriter.notifyMatchFailure(op, "expected vector result type");
12926e59cc1SAlex Zinenko 
13026e59cc1SAlex Zinenko     return LLVM::detail::handleMultidimensionalVectors(
131ef976337SRiver Riddle         op.getOperation(), adaptor.getOperands(), *getTypeConverter(),
13226e59cc1SAlex Zinenko         [&](Type llvm1DVectorTy, ValueRange operands) {
13326e59cc1SAlex Zinenko           auto splatAttr = SplatElementsAttr::get(
13426e59cc1SAlex Zinenko               mlir::VectorType::get(
13526e59cc1SAlex Zinenko                   {LLVM::getVectorNumElements(llvm1DVectorTy).getFixedValue()},
13626e59cc1SAlex Zinenko                   floatType),
13726e59cc1SAlex Zinenko               floatOne);
13826e59cc1SAlex Zinenko           auto one =
13926e59cc1SAlex Zinenko               rewriter.create<LLVM::ConstantOp>(loc, llvm1DVectorTy, splatAttr);
14026e59cc1SAlex Zinenko           auto add = rewriter.create<LLVM::FAddOp>(loc, llvm1DVectorTy, one,
14126e59cc1SAlex Zinenko                                                    operands[0]);
14226e59cc1SAlex Zinenko           return rewriter.create<LLVM::LogOp>(loc, llvm1DVectorTy, add);
14326e59cc1SAlex Zinenko         },
14426e59cc1SAlex Zinenko         rewriter);
14526e59cc1SAlex Zinenko   }
14626e59cc1SAlex Zinenko };
14726e59cc1SAlex Zinenko 
14826e59cc1SAlex Zinenko // A `rsqrt` is converted into `1 / sqrt`.
14926e59cc1SAlex Zinenko struct RsqrtOpLowering : public ConvertOpToLLVMPattern<math::RsqrtOp> {
15026e59cc1SAlex Zinenko   using ConvertOpToLLVMPattern<math::RsqrtOp>::ConvertOpToLLVMPattern;
15126e59cc1SAlex Zinenko 
15226e59cc1SAlex Zinenko   LogicalResult
153ef976337SRiver Riddle   matchAndRewrite(math::RsqrtOp op, OpAdaptor adaptor,
15426e59cc1SAlex Zinenko                   ConversionPatternRewriter &rewriter) const override {
15562fea88bSJacques Pienaar     auto operandType = adaptor.getOperand().getType();
15626e59cc1SAlex Zinenko 
15726e59cc1SAlex Zinenko     if (!operandType || !LLVM::isCompatibleType(operandType))
15826e59cc1SAlex Zinenko       return failure();
15926e59cc1SAlex Zinenko 
16026e59cc1SAlex Zinenko     auto loc = op.getLoc();
16126e59cc1SAlex Zinenko     auto resultType = op.getResult().getType();
16226e59cc1SAlex Zinenko     auto floatType = getElementTypeOrSelf(resultType).cast<FloatType>();
16326e59cc1SAlex Zinenko     auto floatOne = rewriter.getFloatAttr(floatType, 1.0);
16426e59cc1SAlex Zinenko 
16526e59cc1SAlex Zinenko     if (!operandType.isa<LLVM::LLVMArrayType>()) {
16626e59cc1SAlex Zinenko       LLVM::ConstantOp one;
16726e59cc1SAlex Zinenko       if (LLVM::isCompatibleVectorType(operandType)) {
16826e59cc1SAlex Zinenko         one = rewriter.create<LLVM::ConstantOp>(
16926e59cc1SAlex Zinenko             loc, operandType,
17026e59cc1SAlex Zinenko             SplatElementsAttr::get(resultType.cast<ShapedType>(), floatOne));
17126e59cc1SAlex Zinenko       } else {
17226e59cc1SAlex Zinenko         one = rewriter.create<LLVM::ConstantOp>(loc, operandType, floatOne);
17326e59cc1SAlex Zinenko       }
17462fea88bSJacques Pienaar       auto sqrt = rewriter.create<LLVM::SqrtOp>(loc, adaptor.getOperand());
17526e59cc1SAlex Zinenko       rewriter.replaceOpWithNewOp<LLVM::FDivOp>(op, operandType, one, sqrt);
17626e59cc1SAlex Zinenko       return success();
17726e59cc1SAlex Zinenko     }
17826e59cc1SAlex Zinenko 
17926e59cc1SAlex Zinenko     auto vectorType = resultType.dyn_cast<VectorType>();
18026e59cc1SAlex Zinenko     if (!vectorType)
18126e59cc1SAlex Zinenko       return failure();
18226e59cc1SAlex Zinenko 
18326e59cc1SAlex Zinenko     return LLVM::detail::handleMultidimensionalVectors(
184ef976337SRiver Riddle         op.getOperation(), adaptor.getOperands(), *getTypeConverter(),
18526e59cc1SAlex Zinenko         [&](Type llvm1DVectorTy, ValueRange operands) {
18626e59cc1SAlex Zinenko           auto splatAttr = SplatElementsAttr::get(
18726e59cc1SAlex Zinenko               mlir::VectorType::get(
18826e59cc1SAlex Zinenko                   {LLVM::getVectorNumElements(llvm1DVectorTy).getFixedValue()},
18926e59cc1SAlex Zinenko                   floatType),
19026e59cc1SAlex Zinenko               floatOne);
19126e59cc1SAlex Zinenko           auto one =
19226e59cc1SAlex Zinenko               rewriter.create<LLVM::ConstantOp>(loc, llvm1DVectorTy, splatAttr);
19326e59cc1SAlex Zinenko           auto sqrt =
19426e59cc1SAlex Zinenko               rewriter.create<LLVM::SqrtOp>(loc, llvm1DVectorTy, operands[0]);
19526e59cc1SAlex Zinenko           return rewriter.create<LLVM::FDivOp>(loc, llvm1DVectorTy, one, sqrt);
19626e59cc1SAlex Zinenko         },
19726e59cc1SAlex Zinenko         rewriter);
19826e59cc1SAlex Zinenko   }
19926e59cc1SAlex Zinenko };
20026e59cc1SAlex Zinenko 
20126e59cc1SAlex Zinenko struct ConvertMathToLLVMPass
20226e59cc1SAlex Zinenko     : public ConvertMathToLLVMBase<ConvertMathToLLVMPass> {
20326e59cc1SAlex Zinenko   ConvertMathToLLVMPass() = default;
20426e59cc1SAlex Zinenko 
20526e59cc1SAlex Zinenko   void runOnFunction() override {
20626e59cc1SAlex Zinenko     RewritePatternSet patterns(&getContext());
20726e59cc1SAlex Zinenko     LLVMTypeConverter converter(&getContext());
20826e59cc1SAlex Zinenko     populateMathToLLVMConversionPatterns(converter, patterns);
20926e59cc1SAlex Zinenko     LLVMConversionTarget target(getContext());
21026e59cc1SAlex Zinenko     if (failed(
21126e59cc1SAlex Zinenko             applyPartialConversion(getFunction(), target, std::move(patterns))))
21226e59cc1SAlex Zinenko       signalPassFailure();
21326e59cc1SAlex Zinenko   }
21426e59cc1SAlex Zinenko };
21526e59cc1SAlex Zinenko } // namespace
21626e59cc1SAlex Zinenko 
21726e59cc1SAlex Zinenko void mlir::populateMathToLLVMConversionPatterns(LLVMTypeConverter &converter,
21826e59cc1SAlex Zinenko                                                 RewritePatternSet &patterns) {
21926e59cc1SAlex Zinenko   // clang-format off
22026e59cc1SAlex Zinenko   patterns.add<
221a54f4eaeSMogball     AbsOpLowering,
222a54f4eaeSMogball     CeilOpLowering,
223a54f4eaeSMogball     CopySignOpLowering,
22426e59cc1SAlex Zinenko     CosOpLowering,
225*c5fef77bSRob Suderman     CtPopFOpLowering,
22626e59cc1SAlex Zinenko     ExpOpLowering,
22726e59cc1SAlex Zinenko     Exp2OpLowering,
22826e59cc1SAlex Zinenko     ExpM1OpLowering,
229a54f4eaeSMogball     FloorOpLowering,
230a54f4eaeSMogball     FmaOpLowering,
23126e59cc1SAlex Zinenko     Log10OpLowering,
23226e59cc1SAlex Zinenko     Log1pOpLowering,
23326e59cc1SAlex Zinenko     Log2OpLowering,
23426e59cc1SAlex Zinenko     LogOpLowering,
23526e59cc1SAlex Zinenko     PowFOpLowering,
23626e59cc1SAlex Zinenko     RsqrtOpLowering,
23726e59cc1SAlex Zinenko     SinOpLowering,
23826e59cc1SAlex Zinenko     SqrtOpLowering
23926e59cc1SAlex Zinenko   >(converter);
24026e59cc1SAlex Zinenko   // clang-format on
24126e59cc1SAlex Zinenko }
24226e59cc1SAlex Zinenko 
24326e59cc1SAlex Zinenko std::unique_ptr<Pass> mlir::createConvertMathToLLVMPass() {
24426e59cc1SAlex Zinenko   return std::make_unique<ConvertMathToLLVMPass>();
24526e59cc1SAlex Zinenko }
246