126e59cc1SAlex Zinenko //===- MathToLLVM.cpp - Math to LLVM dialect conversion -------------------===//
226e59cc1SAlex Zinenko //
326e59cc1SAlex Zinenko // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
426e59cc1SAlex Zinenko // See https://llvm.org/LICENSE.txt for license information.
526e59cc1SAlex Zinenko // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
626e59cc1SAlex Zinenko //
726e59cc1SAlex Zinenko //===----------------------------------------------------------------------===//
826e59cc1SAlex Zinenko 
926e59cc1SAlex Zinenko #include "mlir/Conversion/MathToLLVM/MathToLLVM.h"
1026e59cc1SAlex Zinenko #include "../PassDetail.h"
1126e59cc1SAlex Zinenko #include "mlir/Conversion/LLVMCommon/ConversionTarget.h"
1226e59cc1SAlex Zinenko #include "mlir/Conversion/LLVMCommon/Pattern.h"
1326e59cc1SAlex Zinenko #include "mlir/Conversion/LLVMCommon/VectorPattern.h"
1426e59cc1SAlex Zinenko #include "mlir/Dialect/LLVMIR/LLVMDialect.h"
1526e59cc1SAlex Zinenko #include "mlir/Dialect/Math/IR/Math.h"
1626e59cc1SAlex Zinenko #include "mlir/IR/TypeUtilities.h"
1726e59cc1SAlex Zinenko 
1826e59cc1SAlex Zinenko using namespace mlir;
1926e59cc1SAlex Zinenko 
2026e59cc1SAlex Zinenko namespace {
21*a54f4eaeSMogball using AbsOpLowering = VectorConvertToLLVMPattern<math::AbsOp, LLVM::FAbsOp>;
22*a54f4eaeSMogball using CeilOpLowering = VectorConvertToLLVMPattern<math::CeilOp, LLVM::FCeilOp>;
23*a54f4eaeSMogball using CopySignOpLowering =
24*a54f4eaeSMogball     VectorConvertToLLVMPattern<math::CopySignOp, LLVM::CopySignOp>;
2526e59cc1SAlex Zinenko using CosOpLowering = VectorConvertToLLVMPattern<math::CosOp, LLVM::CosOp>;
2626e59cc1SAlex Zinenko using ExpOpLowering = VectorConvertToLLVMPattern<math::ExpOp, LLVM::ExpOp>;
2726e59cc1SAlex Zinenko using Exp2OpLowering = VectorConvertToLLVMPattern<math::Exp2Op, LLVM::Exp2Op>;
28*a54f4eaeSMogball using FloorOpLowering =
29*a54f4eaeSMogball     VectorConvertToLLVMPattern<math::FloorOp, LLVM::FFloorOp>;
30*a54f4eaeSMogball using FmaOpLowering = VectorConvertToLLVMPattern<math::FmaOp, LLVM::FMAOp>;
3126e59cc1SAlex Zinenko using Log10OpLowering =
3226e59cc1SAlex Zinenko     VectorConvertToLLVMPattern<math::Log10Op, LLVM::Log10Op>;
3326e59cc1SAlex Zinenko using Log2OpLowering = VectorConvertToLLVMPattern<math::Log2Op, LLVM::Log2Op>;
3426e59cc1SAlex Zinenko using LogOpLowering = VectorConvertToLLVMPattern<math::LogOp, LLVM::LogOp>;
3526e59cc1SAlex Zinenko using PowFOpLowering = VectorConvertToLLVMPattern<math::PowFOp, LLVM::PowOp>;
3626e59cc1SAlex Zinenko using SinOpLowering = VectorConvertToLLVMPattern<math::SinOp, LLVM::SinOp>;
3726e59cc1SAlex Zinenko using SqrtOpLowering = VectorConvertToLLVMPattern<math::SqrtOp, LLVM::SqrtOp>;
3826e59cc1SAlex Zinenko 
3926e59cc1SAlex Zinenko // A `expm1` is converted into `exp - 1`.
4026e59cc1SAlex Zinenko struct ExpM1OpLowering : public ConvertOpToLLVMPattern<math::ExpM1Op> {
4126e59cc1SAlex Zinenko   using ConvertOpToLLVMPattern<math::ExpM1Op>::ConvertOpToLLVMPattern;
4226e59cc1SAlex Zinenko 
4326e59cc1SAlex Zinenko   LogicalResult
44ef976337SRiver Riddle   matchAndRewrite(math::ExpM1Op op, OpAdaptor adaptor,
4526e59cc1SAlex Zinenko                   ConversionPatternRewriter &rewriter) const override {
46ef976337SRiver Riddle     auto operandType = adaptor.operand().getType();
4726e59cc1SAlex Zinenko 
4826e59cc1SAlex Zinenko     if (!operandType || !LLVM::isCompatibleType(operandType))
4926e59cc1SAlex Zinenko       return failure();
5026e59cc1SAlex Zinenko 
5126e59cc1SAlex Zinenko     auto loc = op.getLoc();
5226e59cc1SAlex Zinenko     auto resultType = op.getResult().getType();
5326e59cc1SAlex Zinenko     auto floatType = getElementTypeOrSelf(resultType).cast<FloatType>();
5426e59cc1SAlex Zinenko     auto floatOne = rewriter.getFloatAttr(floatType, 1.0);
5526e59cc1SAlex Zinenko 
5626e59cc1SAlex Zinenko     if (!operandType.isa<LLVM::LLVMArrayType>()) {
5726e59cc1SAlex Zinenko       LLVM::ConstantOp one;
5826e59cc1SAlex Zinenko       if (LLVM::isCompatibleVectorType(operandType)) {
5926e59cc1SAlex Zinenko         one = rewriter.create<LLVM::ConstantOp>(
6026e59cc1SAlex Zinenko             loc, operandType,
6126e59cc1SAlex Zinenko             SplatElementsAttr::get(resultType.cast<ShapedType>(), floatOne));
6226e59cc1SAlex Zinenko       } else {
6326e59cc1SAlex Zinenko         one = rewriter.create<LLVM::ConstantOp>(loc, operandType, floatOne);
6426e59cc1SAlex Zinenko       }
65ef976337SRiver Riddle       auto exp = rewriter.create<LLVM::ExpOp>(loc, adaptor.operand());
6626e59cc1SAlex Zinenko       rewriter.replaceOpWithNewOp<LLVM::FSubOp>(op, operandType, exp, one);
6726e59cc1SAlex Zinenko       return success();
6826e59cc1SAlex Zinenko     }
6926e59cc1SAlex Zinenko 
7026e59cc1SAlex Zinenko     auto vectorType = resultType.dyn_cast<VectorType>();
7126e59cc1SAlex Zinenko     if (!vectorType)
7226e59cc1SAlex Zinenko       return rewriter.notifyMatchFailure(op, "expected vector result type");
7326e59cc1SAlex Zinenko 
7426e59cc1SAlex Zinenko     return LLVM::detail::handleMultidimensionalVectors(
75ef976337SRiver Riddle         op.getOperation(), adaptor.getOperands(), *getTypeConverter(),
7626e59cc1SAlex Zinenko         [&](Type llvm1DVectorTy, ValueRange operands) {
7726e59cc1SAlex Zinenko           auto splatAttr = SplatElementsAttr::get(
7826e59cc1SAlex Zinenko               mlir::VectorType::get(
7926e59cc1SAlex Zinenko                   {LLVM::getVectorNumElements(llvm1DVectorTy).getFixedValue()},
8026e59cc1SAlex Zinenko                   floatType),
8126e59cc1SAlex Zinenko               floatOne);
8226e59cc1SAlex Zinenko           auto one =
8326e59cc1SAlex Zinenko               rewriter.create<LLVM::ConstantOp>(loc, llvm1DVectorTy, splatAttr);
8426e59cc1SAlex Zinenko           auto exp =
8526e59cc1SAlex Zinenko               rewriter.create<LLVM::ExpOp>(loc, llvm1DVectorTy, operands[0]);
8626e59cc1SAlex Zinenko           return rewriter.create<LLVM::FSubOp>(loc, llvm1DVectorTy, exp, one);
8726e59cc1SAlex Zinenko         },
8826e59cc1SAlex Zinenko         rewriter);
8926e59cc1SAlex Zinenko   }
9026e59cc1SAlex Zinenko };
9126e59cc1SAlex Zinenko 
9226e59cc1SAlex Zinenko // A `log1p` is converted into `log(1 + ...)`.
9326e59cc1SAlex Zinenko struct Log1pOpLowering : public ConvertOpToLLVMPattern<math::Log1pOp> {
9426e59cc1SAlex Zinenko   using ConvertOpToLLVMPattern<math::Log1pOp>::ConvertOpToLLVMPattern;
9526e59cc1SAlex Zinenko 
9626e59cc1SAlex Zinenko   LogicalResult
97ef976337SRiver Riddle   matchAndRewrite(math::Log1pOp op, OpAdaptor adaptor,
9826e59cc1SAlex Zinenko                   ConversionPatternRewriter &rewriter) const override {
99ef976337SRiver Riddle     auto operandType = adaptor.operand().getType();
10026e59cc1SAlex Zinenko 
10126e59cc1SAlex Zinenko     if (!operandType || !LLVM::isCompatibleType(operandType))
10226e59cc1SAlex Zinenko       return rewriter.notifyMatchFailure(op, "unsupported operand type");
10326e59cc1SAlex Zinenko 
10426e59cc1SAlex Zinenko     auto loc = op.getLoc();
10526e59cc1SAlex Zinenko     auto resultType = op.getResult().getType();
10626e59cc1SAlex Zinenko     auto floatType = getElementTypeOrSelf(resultType).cast<FloatType>();
10726e59cc1SAlex Zinenko     auto floatOne = rewriter.getFloatAttr(floatType, 1.0);
10826e59cc1SAlex Zinenko 
10926e59cc1SAlex Zinenko     if (!operandType.isa<LLVM::LLVMArrayType>()) {
11026e59cc1SAlex Zinenko       LLVM::ConstantOp one =
11126e59cc1SAlex Zinenko           LLVM::isCompatibleVectorType(operandType)
11226e59cc1SAlex Zinenko               ? rewriter.create<LLVM::ConstantOp>(
11326e59cc1SAlex Zinenko                     loc, operandType,
11426e59cc1SAlex Zinenko                     SplatElementsAttr::get(resultType.cast<ShapedType>(),
11526e59cc1SAlex Zinenko                                            floatOne))
11626e59cc1SAlex Zinenko               : rewriter.create<LLVM::ConstantOp>(loc, operandType, floatOne);
11726e59cc1SAlex Zinenko 
11826e59cc1SAlex Zinenko       auto add = rewriter.create<LLVM::FAddOp>(loc, operandType, one,
119ef976337SRiver Riddle                                                adaptor.operand());
12026e59cc1SAlex Zinenko       rewriter.replaceOpWithNewOp<LLVM::LogOp>(op, operandType, add);
12126e59cc1SAlex Zinenko       return success();
12226e59cc1SAlex Zinenko     }
12326e59cc1SAlex Zinenko 
12426e59cc1SAlex Zinenko     auto vectorType = resultType.dyn_cast<VectorType>();
12526e59cc1SAlex Zinenko     if (!vectorType)
12626e59cc1SAlex Zinenko       return rewriter.notifyMatchFailure(op, "expected vector result type");
12726e59cc1SAlex Zinenko 
12826e59cc1SAlex Zinenko     return LLVM::detail::handleMultidimensionalVectors(
129ef976337SRiver Riddle         op.getOperation(), adaptor.getOperands(), *getTypeConverter(),
13026e59cc1SAlex Zinenko         [&](Type llvm1DVectorTy, ValueRange operands) {
13126e59cc1SAlex Zinenko           auto splatAttr = SplatElementsAttr::get(
13226e59cc1SAlex Zinenko               mlir::VectorType::get(
13326e59cc1SAlex Zinenko                   {LLVM::getVectorNumElements(llvm1DVectorTy).getFixedValue()},
13426e59cc1SAlex Zinenko                   floatType),
13526e59cc1SAlex Zinenko               floatOne);
13626e59cc1SAlex Zinenko           auto one =
13726e59cc1SAlex Zinenko               rewriter.create<LLVM::ConstantOp>(loc, llvm1DVectorTy, splatAttr);
13826e59cc1SAlex Zinenko           auto add = rewriter.create<LLVM::FAddOp>(loc, llvm1DVectorTy, one,
13926e59cc1SAlex Zinenko                                                    operands[0]);
14026e59cc1SAlex Zinenko           return rewriter.create<LLVM::LogOp>(loc, llvm1DVectorTy, add);
14126e59cc1SAlex Zinenko         },
14226e59cc1SAlex Zinenko         rewriter);
14326e59cc1SAlex Zinenko   }
14426e59cc1SAlex Zinenko };
14526e59cc1SAlex Zinenko 
14626e59cc1SAlex Zinenko // A `rsqrt` is converted into `1 / sqrt`.
14726e59cc1SAlex Zinenko struct RsqrtOpLowering : public ConvertOpToLLVMPattern<math::RsqrtOp> {
14826e59cc1SAlex Zinenko   using ConvertOpToLLVMPattern<math::RsqrtOp>::ConvertOpToLLVMPattern;
14926e59cc1SAlex Zinenko 
15026e59cc1SAlex Zinenko   LogicalResult
151ef976337SRiver Riddle   matchAndRewrite(math::RsqrtOp op, OpAdaptor adaptor,
15226e59cc1SAlex Zinenko                   ConversionPatternRewriter &rewriter) const override {
153ef976337SRiver Riddle     auto operandType = adaptor.operand().getType();
15426e59cc1SAlex Zinenko 
15526e59cc1SAlex Zinenko     if (!operandType || !LLVM::isCompatibleType(operandType))
15626e59cc1SAlex Zinenko       return failure();
15726e59cc1SAlex Zinenko 
15826e59cc1SAlex Zinenko     auto loc = op.getLoc();
15926e59cc1SAlex Zinenko     auto resultType = op.getResult().getType();
16026e59cc1SAlex Zinenko     auto floatType = getElementTypeOrSelf(resultType).cast<FloatType>();
16126e59cc1SAlex Zinenko     auto floatOne = rewriter.getFloatAttr(floatType, 1.0);
16226e59cc1SAlex Zinenko 
16326e59cc1SAlex Zinenko     if (!operandType.isa<LLVM::LLVMArrayType>()) {
16426e59cc1SAlex Zinenko       LLVM::ConstantOp one;
16526e59cc1SAlex Zinenko       if (LLVM::isCompatibleVectorType(operandType)) {
16626e59cc1SAlex Zinenko         one = rewriter.create<LLVM::ConstantOp>(
16726e59cc1SAlex Zinenko             loc, operandType,
16826e59cc1SAlex Zinenko             SplatElementsAttr::get(resultType.cast<ShapedType>(), floatOne));
16926e59cc1SAlex Zinenko       } else {
17026e59cc1SAlex Zinenko         one = rewriter.create<LLVM::ConstantOp>(loc, operandType, floatOne);
17126e59cc1SAlex Zinenko       }
172ef976337SRiver Riddle       auto sqrt = rewriter.create<LLVM::SqrtOp>(loc, adaptor.operand());
17326e59cc1SAlex Zinenko       rewriter.replaceOpWithNewOp<LLVM::FDivOp>(op, operandType, one, sqrt);
17426e59cc1SAlex Zinenko       return success();
17526e59cc1SAlex Zinenko     }
17626e59cc1SAlex Zinenko 
17726e59cc1SAlex Zinenko     auto vectorType = resultType.dyn_cast<VectorType>();
17826e59cc1SAlex Zinenko     if (!vectorType)
17926e59cc1SAlex Zinenko       return failure();
18026e59cc1SAlex Zinenko 
18126e59cc1SAlex Zinenko     return LLVM::detail::handleMultidimensionalVectors(
182ef976337SRiver Riddle         op.getOperation(), adaptor.getOperands(), *getTypeConverter(),
18326e59cc1SAlex Zinenko         [&](Type llvm1DVectorTy, ValueRange operands) {
18426e59cc1SAlex Zinenko           auto splatAttr = SplatElementsAttr::get(
18526e59cc1SAlex Zinenko               mlir::VectorType::get(
18626e59cc1SAlex Zinenko                   {LLVM::getVectorNumElements(llvm1DVectorTy).getFixedValue()},
18726e59cc1SAlex Zinenko                   floatType),
18826e59cc1SAlex Zinenko               floatOne);
18926e59cc1SAlex Zinenko           auto one =
19026e59cc1SAlex Zinenko               rewriter.create<LLVM::ConstantOp>(loc, llvm1DVectorTy, splatAttr);
19126e59cc1SAlex Zinenko           auto sqrt =
19226e59cc1SAlex Zinenko               rewriter.create<LLVM::SqrtOp>(loc, llvm1DVectorTy, operands[0]);
19326e59cc1SAlex Zinenko           return rewriter.create<LLVM::FDivOp>(loc, llvm1DVectorTy, one, sqrt);
19426e59cc1SAlex Zinenko         },
19526e59cc1SAlex Zinenko         rewriter);
19626e59cc1SAlex Zinenko   }
19726e59cc1SAlex Zinenko };
19826e59cc1SAlex Zinenko 
19926e59cc1SAlex Zinenko struct ConvertMathToLLVMPass
20026e59cc1SAlex Zinenko     : public ConvertMathToLLVMBase<ConvertMathToLLVMPass> {
20126e59cc1SAlex Zinenko   ConvertMathToLLVMPass() = default;
20226e59cc1SAlex Zinenko 
20326e59cc1SAlex Zinenko   void runOnFunction() override {
20426e59cc1SAlex Zinenko     RewritePatternSet patterns(&getContext());
20526e59cc1SAlex Zinenko     LLVMTypeConverter converter(&getContext());
20626e59cc1SAlex Zinenko     populateMathToLLVMConversionPatterns(converter, patterns);
20726e59cc1SAlex Zinenko     LLVMConversionTarget target(getContext());
20826e59cc1SAlex Zinenko     if (failed(
20926e59cc1SAlex Zinenko             applyPartialConversion(getFunction(), target, std::move(patterns))))
21026e59cc1SAlex Zinenko       signalPassFailure();
21126e59cc1SAlex Zinenko   }
21226e59cc1SAlex Zinenko };
21326e59cc1SAlex Zinenko } // namespace
21426e59cc1SAlex Zinenko 
21526e59cc1SAlex Zinenko void mlir::populateMathToLLVMConversionPatterns(LLVMTypeConverter &converter,
21626e59cc1SAlex Zinenko                                                 RewritePatternSet &patterns) {
21726e59cc1SAlex Zinenko   // clang-format off
21826e59cc1SAlex Zinenko   patterns.add<
219*a54f4eaeSMogball     AbsOpLowering,
220*a54f4eaeSMogball     CeilOpLowering,
221*a54f4eaeSMogball     CopySignOpLowering,
22226e59cc1SAlex Zinenko     CosOpLowering,
22326e59cc1SAlex Zinenko     ExpOpLowering,
22426e59cc1SAlex Zinenko     Exp2OpLowering,
22526e59cc1SAlex Zinenko     ExpM1OpLowering,
226*a54f4eaeSMogball     FloorOpLowering,
227*a54f4eaeSMogball     FmaOpLowering,
22826e59cc1SAlex Zinenko     Log10OpLowering,
22926e59cc1SAlex Zinenko     Log1pOpLowering,
23026e59cc1SAlex Zinenko     Log2OpLowering,
23126e59cc1SAlex Zinenko     LogOpLowering,
23226e59cc1SAlex Zinenko     PowFOpLowering,
23326e59cc1SAlex Zinenko     RsqrtOpLowering,
23426e59cc1SAlex Zinenko     SinOpLowering,
23526e59cc1SAlex Zinenko     SqrtOpLowering
23626e59cc1SAlex Zinenko   >(converter);
23726e59cc1SAlex Zinenko   // clang-format on
23826e59cc1SAlex Zinenko }
23926e59cc1SAlex Zinenko 
24026e59cc1SAlex Zinenko std::unique_ptr<Pass> mlir::createConvertMathToLLVMPass() {
24126e59cc1SAlex Zinenko   return std::make_unique<ConvertMathToLLVMPass>();
24226e59cc1SAlex Zinenko }
243