126e59cc1SAlex Zinenko //===- MathToLLVM.cpp - Math to LLVM dialect conversion -------------------===//
226e59cc1SAlex Zinenko //
326e59cc1SAlex Zinenko // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
426e59cc1SAlex Zinenko // See https://llvm.org/LICENSE.txt for license information.
526e59cc1SAlex Zinenko // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
626e59cc1SAlex Zinenko //
726e59cc1SAlex Zinenko //===----------------------------------------------------------------------===//
826e59cc1SAlex Zinenko 
926e59cc1SAlex Zinenko #include "mlir/Conversion/MathToLLVM/MathToLLVM.h"
1026e59cc1SAlex Zinenko #include "../PassDetail.h"
1126e59cc1SAlex Zinenko #include "mlir/Conversion/LLVMCommon/ConversionTarget.h"
1226e59cc1SAlex Zinenko #include "mlir/Conversion/LLVMCommon/Pattern.h"
1326e59cc1SAlex Zinenko #include "mlir/Conversion/LLVMCommon/VectorPattern.h"
1426e59cc1SAlex Zinenko #include "mlir/Dialect/LLVMIR/LLVMDialect.h"
1526e59cc1SAlex Zinenko #include "mlir/Dialect/Math/IR/Math.h"
1626e59cc1SAlex Zinenko #include "mlir/IR/TypeUtilities.h"
1726e59cc1SAlex Zinenko 
1826e59cc1SAlex Zinenko using namespace mlir;
1926e59cc1SAlex Zinenko 
2026e59cc1SAlex Zinenko namespace {
2126e59cc1SAlex Zinenko using CosOpLowering = VectorConvertToLLVMPattern<math::CosOp, LLVM::CosOp>;
2226e59cc1SAlex Zinenko using ExpOpLowering = VectorConvertToLLVMPattern<math::ExpOp, LLVM::ExpOp>;
2326e59cc1SAlex Zinenko using Exp2OpLowering = VectorConvertToLLVMPattern<math::Exp2Op, LLVM::Exp2Op>;
2426e59cc1SAlex Zinenko using Log10OpLowering =
2526e59cc1SAlex Zinenko     VectorConvertToLLVMPattern<math::Log10Op, LLVM::Log10Op>;
2626e59cc1SAlex Zinenko using Log2OpLowering = VectorConvertToLLVMPattern<math::Log2Op, LLVM::Log2Op>;
2726e59cc1SAlex Zinenko using LogOpLowering = VectorConvertToLLVMPattern<math::LogOp, LLVM::LogOp>;
2826e59cc1SAlex Zinenko using PowFOpLowering = VectorConvertToLLVMPattern<math::PowFOp, LLVM::PowOp>;
2926e59cc1SAlex Zinenko using SinOpLowering = VectorConvertToLLVMPattern<math::SinOp, LLVM::SinOp>;
3026e59cc1SAlex Zinenko using SqrtOpLowering = VectorConvertToLLVMPattern<math::SqrtOp, LLVM::SqrtOp>;
3126e59cc1SAlex Zinenko 
3226e59cc1SAlex Zinenko // A `expm1` is converted into `exp - 1`.
3326e59cc1SAlex Zinenko struct ExpM1OpLowering : public ConvertOpToLLVMPattern<math::ExpM1Op> {
3426e59cc1SAlex Zinenko   using ConvertOpToLLVMPattern<math::ExpM1Op>::ConvertOpToLLVMPattern;
3526e59cc1SAlex Zinenko 
3626e59cc1SAlex Zinenko   LogicalResult
37*ef976337SRiver Riddle   matchAndRewrite(math::ExpM1Op op, OpAdaptor adaptor,
3826e59cc1SAlex Zinenko                   ConversionPatternRewriter &rewriter) const override {
39*ef976337SRiver Riddle     auto operandType = adaptor.operand().getType();
4026e59cc1SAlex Zinenko 
4126e59cc1SAlex Zinenko     if (!operandType || !LLVM::isCompatibleType(operandType))
4226e59cc1SAlex Zinenko       return failure();
4326e59cc1SAlex Zinenko 
4426e59cc1SAlex Zinenko     auto loc = op.getLoc();
4526e59cc1SAlex Zinenko     auto resultType = op.getResult().getType();
4626e59cc1SAlex Zinenko     auto floatType = getElementTypeOrSelf(resultType).cast<FloatType>();
4726e59cc1SAlex Zinenko     auto floatOne = rewriter.getFloatAttr(floatType, 1.0);
4826e59cc1SAlex Zinenko 
4926e59cc1SAlex Zinenko     if (!operandType.isa<LLVM::LLVMArrayType>()) {
5026e59cc1SAlex Zinenko       LLVM::ConstantOp one;
5126e59cc1SAlex Zinenko       if (LLVM::isCompatibleVectorType(operandType)) {
5226e59cc1SAlex Zinenko         one = rewriter.create<LLVM::ConstantOp>(
5326e59cc1SAlex Zinenko             loc, operandType,
5426e59cc1SAlex Zinenko             SplatElementsAttr::get(resultType.cast<ShapedType>(), floatOne));
5526e59cc1SAlex Zinenko       } else {
5626e59cc1SAlex Zinenko         one = rewriter.create<LLVM::ConstantOp>(loc, operandType, floatOne);
5726e59cc1SAlex Zinenko       }
58*ef976337SRiver Riddle       auto exp = rewriter.create<LLVM::ExpOp>(loc, adaptor.operand());
5926e59cc1SAlex Zinenko       rewriter.replaceOpWithNewOp<LLVM::FSubOp>(op, operandType, exp, one);
6026e59cc1SAlex Zinenko       return success();
6126e59cc1SAlex Zinenko     }
6226e59cc1SAlex Zinenko 
6326e59cc1SAlex Zinenko     auto vectorType = resultType.dyn_cast<VectorType>();
6426e59cc1SAlex Zinenko     if (!vectorType)
6526e59cc1SAlex Zinenko       return rewriter.notifyMatchFailure(op, "expected vector result type");
6626e59cc1SAlex Zinenko 
6726e59cc1SAlex Zinenko     return LLVM::detail::handleMultidimensionalVectors(
68*ef976337SRiver Riddle         op.getOperation(), adaptor.getOperands(), *getTypeConverter(),
6926e59cc1SAlex Zinenko         [&](Type llvm1DVectorTy, ValueRange operands) {
7026e59cc1SAlex Zinenko           auto splatAttr = SplatElementsAttr::get(
7126e59cc1SAlex Zinenko               mlir::VectorType::get(
7226e59cc1SAlex Zinenko                   {LLVM::getVectorNumElements(llvm1DVectorTy).getFixedValue()},
7326e59cc1SAlex Zinenko                   floatType),
7426e59cc1SAlex Zinenko               floatOne);
7526e59cc1SAlex Zinenko           auto one =
7626e59cc1SAlex Zinenko               rewriter.create<LLVM::ConstantOp>(loc, llvm1DVectorTy, splatAttr);
7726e59cc1SAlex Zinenko           auto exp =
7826e59cc1SAlex Zinenko               rewriter.create<LLVM::ExpOp>(loc, llvm1DVectorTy, operands[0]);
7926e59cc1SAlex Zinenko           return rewriter.create<LLVM::FSubOp>(loc, llvm1DVectorTy, exp, one);
8026e59cc1SAlex Zinenko         },
8126e59cc1SAlex Zinenko         rewriter);
8226e59cc1SAlex Zinenko   }
8326e59cc1SAlex Zinenko };
8426e59cc1SAlex Zinenko 
8526e59cc1SAlex Zinenko // A `log1p` is converted into `log(1 + ...)`.
8626e59cc1SAlex Zinenko struct Log1pOpLowering : public ConvertOpToLLVMPattern<math::Log1pOp> {
8726e59cc1SAlex Zinenko   using ConvertOpToLLVMPattern<math::Log1pOp>::ConvertOpToLLVMPattern;
8826e59cc1SAlex Zinenko 
8926e59cc1SAlex Zinenko   LogicalResult
90*ef976337SRiver Riddle   matchAndRewrite(math::Log1pOp op, OpAdaptor adaptor,
9126e59cc1SAlex Zinenko                   ConversionPatternRewriter &rewriter) const override {
92*ef976337SRiver Riddle     auto operandType = adaptor.operand().getType();
9326e59cc1SAlex Zinenko 
9426e59cc1SAlex Zinenko     if (!operandType || !LLVM::isCompatibleType(operandType))
9526e59cc1SAlex Zinenko       return rewriter.notifyMatchFailure(op, "unsupported operand type");
9626e59cc1SAlex Zinenko 
9726e59cc1SAlex Zinenko     auto loc = op.getLoc();
9826e59cc1SAlex Zinenko     auto resultType = op.getResult().getType();
9926e59cc1SAlex Zinenko     auto floatType = getElementTypeOrSelf(resultType).cast<FloatType>();
10026e59cc1SAlex Zinenko     auto floatOne = rewriter.getFloatAttr(floatType, 1.0);
10126e59cc1SAlex Zinenko 
10226e59cc1SAlex Zinenko     if (!operandType.isa<LLVM::LLVMArrayType>()) {
10326e59cc1SAlex Zinenko       LLVM::ConstantOp one =
10426e59cc1SAlex Zinenko           LLVM::isCompatibleVectorType(operandType)
10526e59cc1SAlex Zinenko               ? rewriter.create<LLVM::ConstantOp>(
10626e59cc1SAlex Zinenko                     loc, operandType,
10726e59cc1SAlex Zinenko                     SplatElementsAttr::get(resultType.cast<ShapedType>(),
10826e59cc1SAlex Zinenko                                            floatOne))
10926e59cc1SAlex Zinenko               : rewriter.create<LLVM::ConstantOp>(loc, operandType, floatOne);
11026e59cc1SAlex Zinenko 
11126e59cc1SAlex Zinenko       auto add = rewriter.create<LLVM::FAddOp>(loc, operandType, one,
112*ef976337SRiver Riddle                                                adaptor.operand());
11326e59cc1SAlex Zinenko       rewriter.replaceOpWithNewOp<LLVM::LogOp>(op, operandType, add);
11426e59cc1SAlex Zinenko       return success();
11526e59cc1SAlex Zinenko     }
11626e59cc1SAlex Zinenko 
11726e59cc1SAlex Zinenko     auto vectorType = resultType.dyn_cast<VectorType>();
11826e59cc1SAlex Zinenko     if (!vectorType)
11926e59cc1SAlex Zinenko       return rewriter.notifyMatchFailure(op, "expected vector result type");
12026e59cc1SAlex Zinenko 
12126e59cc1SAlex Zinenko     return LLVM::detail::handleMultidimensionalVectors(
122*ef976337SRiver Riddle         op.getOperation(), adaptor.getOperands(), *getTypeConverter(),
12326e59cc1SAlex Zinenko         [&](Type llvm1DVectorTy, ValueRange operands) {
12426e59cc1SAlex Zinenko           auto splatAttr = SplatElementsAttr::get(
12526e59cc1SAlex Zinenko               mlir::VectorType::get(
12626e59cc1SAlex Zinenko                   {LLVM::getVectorNumElements(llvm1DVectorTy).getFixedValue()},
12726e59cc1SAlex Zinenko                   floatType),
12826e59cc1SAlex Zinenko               floatOne);
12926e59cc1SAlex Zinenko           auto one =
13026e59cc1SAlex Zinenko               rewriter.create<LLVM::ConstantOp>(loc, llvm1DVectorTy, splatAttr);
13126e59cc1SAlex Zinenko           auto add = rewriter.create<LLVM::FAddOp>(loc, llvm1DVectorTy, one,
13226e59cc1SAlex Zinenko                                                    operands[0]);
13326e59cc1SAlex Zinenko           return rewriter.create<LLVM::LogOp>(loc, llvm1DVectorTy, add);
13426e59cc1SAlex Zinenko         },
13526e59cc1SAlex Zinenko         rewriter);
13626e59cc1SAlex Zinenko   }
13726e59cc1SAlex Zinenko };
13826e59cc1SAlex Zinenko 
13926e59cc1SAlex Zinenko // A `rsqrt` is converted into `1 / sqrt`.
14026e59cc1SAlex Zinenko struct RsqrtOpLowering : public ConvertOpToLLVMPattern<math::RsqrtOp> {
14126e59cc1SAlex Zinenko   using ConvertOpToLLVMPattern<math::RsqrtOp>::ConvertOpToLLVMPattern;
14226e59cc1SAlex Zinenko 
14326e59cc1SAlex Zinenko   LogicalResult
144*ef976337SRiver Riddle   matchAndRewrite(math::RsqrtOp op, OpAdaptor adaptor,
14526e59cc1SAlex Zinenko                   ConversionPatternRewriter &rewriter) const override {
146*ef976337SRiver Riddle     auto operandType = adaptor.operand().getType();
14726e59cc1SAlex Zinenko 
14826e59cc1SAlex Zinenko     if (!operandType || !LLVM::isCompatibleType(operandType))
14926e59cc1SAlex Zinenko       return failure();
15026e59cc1SAlex Zinenko 
15126e59cc1SAlex Zinenko     auto loc = op.getLoc();
15226e59cc1SAlex Zinenko     auto resultType = op.getResult().getType();
15326e59cc1SAlex Zinenko     auto floatType = getElementTypeOrSelf(resultType).cast<FloatType>();
15426e59cc1SAlex Zinenko     auto floatOne = rewriter.getFloatAttr(floatType, 1.0);
15526e59cc1SAlex Zinenko 
15626e59cc1SAlex Zinenko     if (!operandType.isa<LLVM::LLVMArrayType>()) {
15726e59cc1SAlex Zinenko       LLVM::ConstantOp one;
15826e59cc1SAlex Zinenko       if (LLVM::isCompatibleVectorType(operandType)) {
15926e59cc1SAlex Zinenko         one = rewriter.create<LLVM::ConstantOp>(
16026e59cc1SAlex Zinenko             loc, operandType,
16126e59cc1SAlex Zinenko             SplatElementsAttr::get(resultType.cast<ShapedType>(), floatOne));
16226e59cc1SAlex Zinenko       } else {
16326e59cc1SAlex Zinenko         one = rewriter.create<LLVM::ConstantOp>(loc, operandType, floatOne);
16426e59cc1SAlex Zinenko       }
165*ef976337SRiver Riddle       auto sqrt = rewriter.create<LLVM::SqrtOp>(loc, adaptor.operand());
16626e59cc1SAlex Zinenko       rewriter.replaceOpWithNewOp<LLVM::FDivOp>(op, operandType, one, sqrt);
16726e59cc1SAlex Zinenko       return success();
16826e59cc1SAlex Zinenko     }
16926e59cc1SAlex Zinenko 
17026e59cc1SAlex Zinenko     auto vectorType = resultType.dyn_cast<VectorType>();
17126e59cc1SAlex Zinenko     if (!vectorType)
17226e59cc1SAlex Zinenko       return failure();
17326e59cc1SAlex Zinenko 
17426e59cc1SAlex Zinenko     return LLVM::detail::handleMultidimensionalVectors(
175*ef976337SRiver Riddle         op.getOperation(), adaptor.getOperands(), *getTypeConverter(),
17626e59cc1SAlex Zinenko         [&](Type llvm1DVectorTy, ValueRange operands) {
17726e59cc1SAlex Zinenko           auto splatAttr = SplatElementsAttr::get(
17826e59cc1SAlex Zinenko               mlir::VectorType::get(
17926e59cc1SAlex Zinenko                   {LLVM::getVectorNumElements(llvm1DVectorTy).getFixedValue()},
18026e59cc1SAlex Zinenko                   floatType),
18126e59cc1SAlex Zinenko               floatOne);
18226e59cc1SAlex Zinenko           auto one =
18326e59cc1SAlex Zinenko               rewriter.create<LLVM::ConstantOp>(loc, llvm1DVectorTy, splatAttr);
18426e59cc1SAlex Zinenko           auto sqrt =
18526e59cc1SAlex Zinenko               rewriter.create<LLVM::SqrtOp>(loc, llvm1DVectorTy, operands[0]);
18626e59cc1SAlex Zinenko           return rewriter.create<LLVM::FDivOp>(loc, llvm1DVectorTy, one, sqrt);
18726e59cc1SAlex Zinenko         },
18826e59cc1SAlex Zinenko         rewriter);
18926e59cc1SAlex Zinenko   }
19026e59cc1SAlex Zinenko };
19126e59cc1SAlex Zinenko 
19226e59cc1SAlex Zinenko struct ConvertMathToLLVMPass
19326e59cc1SAlex Zinenko     : public ConvertMathToLLVMBase<ConvertMathToLLVMPass> {
19426e59cc1SAlex Zinenko   ConvertMathToLLVMPass() = default;
19526e59cc1SAlex Zinenko 
19626e59cc1SAlex Zinenko   void runOnFunction() override {
19726e59cc1SAlex Zinenko     RewritePatternSet patterns(&getContext());
19826e59cc1SAlex Zinenko     LLVMTypeConverter converter(&getContext());
19926e59cc1SAlex Zinenko     populateMathToLLVMConversionPatterns(converter, patterns);
20026e59cc1SAlex Zinenko     LLVMConversionTarget target(getContext());
20126e59cc1SAlex Zinenko     if (failed(
20226e59cc1SAlex Zinenko             applyPartialConversion(getFunction(), target, std::move(patterns))))
20326e59cc1SAlex Zinenko       signalPassFailure();
20426e59cc1SAlex Zinenko   }
20526e59cc1SAlex Zinenko };
20626e59cc1SAlex Zinenko } // namespace
20726e59cc1SAlex Zinenko 
20826e59cc1SAlex Zinenko void mlir::populateMathToLLVMConversionPatterns(LLVMTypeConverter &converter,
20926e59cc1SAlex Zinenko                                                 RewritePatternSet &patterns) {
21026e59cc1SAlex Zinenko   // clang-format off
21126e59cc1SAlex Zinenko   patterns.add<
21226e59cc1SAlex Zinenko     CosOpLowering,
21326e59cc1SAlex Zinenko     ExpOpLowering,
21426e59cc1SAlex Zinenko     Exp2OpLowering,
21526e59cc1SAlex Zinenko     ExpM1OpLowering,
21626e59cc1SAlex Zinenko     Log10OpLowering,
21726e59cc1SAlex Zinenko     Log1pOpLowering,
21826e59cc1SAlex Zinenko     Log2OpLowering,
21926e59cc1SAlex Zinenko     LogOpLowering,
22026e59cc1SAlex Zinenko     PowFOpLowering,
22126e59cc1SAlex Zinenko     RsqrtOpLowering,
22226e59cc1SAlex Zinenko     SinOpLowering,
22326e59cc1SAlex Zinenko     SqrtOpLowering
22426e59cc1SAlex Zinenko   >(converter);
22526e59cc1SAlex Zinenko   // clang-format on
22626e59cc1SAlex Zinenko }
22726e59cc1SAlex Zinenko 
22826e59cc1SAlex Zinenko std::unique_ptr<Pass> mlir::createConvertMathToLLVMPass() {
22926e59cc1SAlex Zinenko   return std::make_unique<ConvertMathToLLVMPass>();
23026e59cc1SAlex Zinenko }
231