126e59cc1SAlex Zinenko //===- MathToLLVM.cpp - Math to LLVM dialect conversion -------------------===//
226e59cc1SAlex Zinenko //
326e59cc1SAlex Zinenko // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
426e59cc1SAlex Zinenko // See https://llvm.org/LICENSE.txt for license information.
526e59cc1SAlex Zinenko // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
626e59cc1SAlex Zinenko //
726e59cc1SAlex Zinenko //===----------------------------------------------------------------------===//
826e59cc1SAlex Zinenko 
926e59cc1SAlex Zinenko #include "mlir/Conversion/MathToLLVM/MathToLLVM.h"
1026e59cc1SAlex Zinenko #include "../PassDetail.h"
1126e59cc1SAlex Zinenko #include "mlir/Conversion/LLVMCommon/ConversionTarget.h"
1226e59cc1SAlex Zinenko #include "mlir/Conversion/LLVMCommon/Pattern.h"
1326e59cc1SAlex Zinenko #include "mlir/Conversion/LLVMCommon/VectorPattern.h"
1426e59cc1SAlex Zinenko #include "mlir/Dialect/LLVMIR/LLVMDialect.h"
1526e59cc1SAlex Zinenko #include "mlir/Dialect/Math/IR/Math.h"
1626e59cc1SAlex Zinenko #include "mlir/IR/TypeUtilities.h"
1726e59cc1SAlex Zinenko 
1826e59cc1SAlex Zinenko using namespace mlir;
1926e59cc1SAlex Zinenko 
2026e59cc1SAlex Zinenko namespace {
21a54f4eaeSMogball using AbsOpLowering = VectorConvertToLLVMPattern<math::AbsOp, LLVM::FAbsOp>;
22a54f4eaeSMogball using CeilOpLowering = VectorConvertToLLVMPattern<math::CeilOp, LLVM::FCeilOp>;
23a54f4eaeSMogball using CopySignOpLowering =
24a54f4eaeSMogball     VectorConvertToLLVMPattern<math::CopySignOp, LLVM::CopySignOp>;
2526e59cc1SAlex Zinenko using CosOpLowering = VectorConvertToLLVMPattern<math::CosOp, LLVM::CosOp>;
2626e59cc1SAlex Zinenko using ExpOpLowering = VectorConvertToLLVMPattern<math::ExpOp, LLVM::ExpOp>;
2726e59cc1SAlex Zinenko using Exp2OpLowering = VectorConvertToLLVMPattern<math::Exp2Op, LLVM::Exp2Op>;
28a54f4eaeSMogball using FloorOpLowering =
29a54f4eaeSMogball     VectorConvertToLLVMPattern<math::FloorOp, LLVM::FFloorOp>;
30a54f4eaeSMogball using FmaOpLowering = VectorConvertToLLVMPattern<math::FmaOp, LLVM::FMAOp>;
3126e59cc1SAlex Zinenko using Log10OpLowering =
3226e59cc1SAlex Zinenko     VectorConvertToLLVMPattern<math::Log10Op, LLVM::Log10Op>;
3326e59cc1SAlex Zinenko using Log2OpLowering = VectorConvertToLLVMPattern<math::Log2Op, LLVM::Log2Op>;
3426e59cc1SAlex Zinenko using LogOpLowering = VectorConvertToLLVMPattern<math::LogOp, LLVM::LogOp>;
3526e59cc1SAlex Zinenko using PowFOpLowering = VectorConvertToLLVMPattern<math::PowFOp, LLVM::PowOp>;
3626e59cc1SAlex Zinenko using SinOpLowering = VectorConvertToLLVMPattern<math::SinOp, LLVM::SinOp>;
3726e59cc1SAlex Zinenko using SqrtOpLowering = VectorConvertToLLVMPattern<math::SqrtOp, LLVM::SqrtOp>;
3826e59cc1SAlex Zinenko 
3926e59cc1SAlex Zinenko // A `expm1` is converted into `exp - 1`.
4026e59cc1SAlex Zinenko struct ExpM1OpLowering : public ConvertOpToLLVMPattern<math::ExpM1Op> {
4126e59cc1SAlex Zinenko   using ConvertOpToLLVMPattern<math::ExpM1Op>::ConvertOpToLLVMPattern;
4226e59cc1SAlex Zinenko 
4326e59cc1SAlex Zinenko   LogicalResult
44ef976337SRiver Riddle   matchAndRewrite(math::ExpM1Op op, OpAdaptor adaptor,
4526e59cc1SAlex Zinenko                   ConversionPatternRewriter &rewriter) const override {
46*62fea88bSJacques Pienaar     auto operandType = adaptor.getOperand().getType();
4726e59cc1SAlex Zinenko 
4826e59cc1SAlex Zinenko     if (!operandType || !LLVM::isCompatibleType(operandType))
4926e59cc1SAlex Zinenko       return failure();
5026e59cc1SAlex Zinenko 
5126e59cc1SAlex Zinenko     auto loc = op.getLoc();
5226e59cc1SAlex Zinenko     auto resultType = op.getResult().getType();
5326e59cc1SAlex Zinenko     auto floatType = getElementTypeOrSelf(resultType).cast<FloatType>();
5426e59cc1SAlex Zinenko     auto floatOne = rewriter.getFloatAttr(floatType, 1.0);
5526e59cc1SAlex Zinenko 
5626e59cc1SAlex Zinenko     if (!operandType.isa<LLVM::LLVMArrayType>()) {
5726e59cc1SAlex Zinenko       LLVM::ConstantOp one;
5826e59cc1SAlex Zinenko       if (LLVM::isCompatibleVectorType(operandType)) {
5926e59cc1SAlex Zinenko         one = rewriter.create<LLVM::ConstantOp>(
6026e59cc1SAlex Zinenko             loc, operandType,
6126e59cc1SAlex Zinenko             SplatElementsAttr::get(resultType.cast<ShapedType>(), floatOne));
6226e59cc1SAlex Zinenko       } else {
6326e59cc1SAlex Zinenko         one = rewriter.create<LLVM::ConstantOp>(loc, operandType, floatOne);
6426e59cc1SAlex Zinenko       }
65*62fea88bSJacques Pienaar       auto exp = rewriter.create<LLVM::ExpOp>(loc, adaptor.getOperand());
6626e59cc1SAlex Zinenko       rewriter.replaceOpWithNewOp<LLVM::FSubOp>(op, operandType, exp, one);
6726e59cc1SAlex Zinenko       return success();
6826e59cc1SAlex Zinenko     }
6926e59cc1SAlex Zinenko 
7026e59cc1SAlex Zinenko     auto vectorType = resultType.dyn_cast<VectorType>();
7126e59cc1SAlex Zinenko     if (!vectorType)
7226e59cc1SAlex Zinenko       return rewriter.notifyMatchFailure(op, "expected vector result type");
7326e59cc1SAlex Zinenko 
7426e59cc1SAlex Zinenko     return LLVM::detail::handleMultidimensionalVectors(
75ef976337SRiver Riddle         op.getOperation(), adaptor.getOperands(), *getTypeConverter(),
7626e59cc1SAlex Zinenko         [&](Type llvm1DVectorTy, ValueRange operands) {
7726e59cc1SAlex Zinenko           auto splatAttr = SplatElementsAttr::get(
7826e59cc1SAlex Zinenko               mlir::VectorType::get(
7926e59cc1SAlex Zinenko                   {LLVM::getVectorNumElements(llvm1DVectorTy).getFixedValue()},
8026e59cc1SAlex Zinenko                   floatType),
8126e59cc1SAlex Zinenko               floatOne);
8226e59cc1SAlex Zinenko           auto one =
8326e59cc1SAlex Zinenko               rewriter.create<LLVM::ConstantOp>(loc, llvm1DVectorTy, splatAttr);
8426e59cc1SAlex Zinenko           auto exp =
8526e59cc1SAlex Zinenko               rewriter.create<LLVM::ExpOp>(loc, llvm1DVectorTy, operands[0]);
8626e59cc1SAlex Zinenko           return rewriter.create<LLVM::FSubOp>(loc, llvm1DVectorTy, exp, one);
8726e59cc1SAlex Zinenko         },
8826e59cc1SAlex Zinenko         rewriter);
8926e59cc1SAlex Zinenko   }
9026e59cc1SAlex Zinenko };
9126e59cc1SAlex Zinenko 
9226e59cc1SAlex Zinenko // A `log1p` is converted into `log(1 + ...)`.
9326e59cc1SAlex Zinenko struct Log1pOpLowering : public ConvertOpToLLVMPattern<math::Log1pOp> {
9426e59cc1SAlex Zinenko   using ConvertOpToLLVMPattern<math::Log1pOp>::ConvertOpToLLVMPattern;
9526e59cc1SAlex Zinenko 
9626e59cc1SAlex Zinenko   LogicalResult
97ef976337SRiver Riddle   matchAndRewrite(math::Log1pOp op, OpAdaptor adaptor,
9826e59cc1SAlex Zinenko                   ConversionPatternRewriter &rewriter) const override {
99*62fea88bSJacques Pienaar     auto operandType = adaptor.getOperand().getType();
10026e59cc1SAlex Zinenko 
10126e59cc1SAlex Zinenko     if (!operandType || !LLVM::isCompatibleType(operandType))
10226e59cc1SAlex Zinenko       return rewriter.notifyMatchFailure(op, "unsupported operand type");
10326e59cc1SAlex Zinenko 
10426e59cc1SAlex Zinenko     auto loc = op.getLoc();
10526e59cc1SAlex Zinenko     auto resultType = op.getResult().getType();
10626e59cc1SAlex Zinenko     auto floatType = getElementTypeOrSelf(resultType).cast<FloatType>();
10726e59cc1SAlex Zinenko     auto floatOne = rewriter.getFloatAttr(floatType, 1.0);
10826e59cc1SAlex Zinenko 
10926e59cc1SAlex Zinenko     if (!operandType.isa<LLVM::LLVMArrayType>()) {
11026e59cc1SAlex Zinenko       LLVM::ConstantOp one =
11126e59cc1SAlex Zinenko           LLVM::isCompatibleVectorType(operandType)
11226e59cc1SAlex Zinenko               ? rewriter.create<LLVM::ConstantOp>(
11326e59cc1SAlex Zinenko                     loc, operandType,
11426e59cc1SAlex Zinenko                     SplatElementsAttr::get(resultType.cast<ShapedType>(),
11526e59cc1SAlex Zinenko                                            floatOne))
11626e59cc1SAlex Zinenko               : rewriter.create<LLVM::ConstantOp>(loc, operandType, floatOne);
11726e59cc1SAlex Zinenko 
11826e59cc1SAlex Zinenko       auto add = rewriter.create<LLVM::FAddOp>(loc, operandType, one,
119*62fea88bSJacques Pienaar                                                adaptor.getOperand());
12026e59cc1SAlex Zinenko       rewriter.replaceOpWithNewOp<LLVM::LogOp>(op, operandType, add);
12126e59cc1SAlex Zinenko       return success();
12226e59cc1SAlex Zinenko     }
12326e59cc1SAlex Zinenko 
12426e59cc1SAlex Zinenko     auto vectorType = resultType.dyn_cast<VectorType>();
12526e59cc1SAlex Zinenko     if (!vectorType)
12626e59cc1SAlex Zinenko       return rewriter.notifyMatchFailure(op, "expected vector result type");
12726e59cc1SAlex Zinenko 
12826e59cc1SAlex Zinenko     return LLVM::detail::handleMultidimensionalVectors(
129ef976337SRiver Riddle         op.getOperation(), adaptor.getOperands(), *getTypeConverter(),
13026e59cc1SAlex Zinenko         [&](Type llvm1DVectorTy, ValueRange operands) {
13126e59cc1SAlex Zinenko           auto splatAttr = SplatElementsAttr::get(
13226e59cc1SAlex Zinenko               mlir::VectorType::get(
13326e59cc1SAlex Zinenko                   {LLVM::getVectorNumElements(llvm1DVectorTy).getFixedValue()},
13426e59cc1SAlex Zinenko                   floatType),
13526e59cc1SAlex Zinenko               floatOne);
13626e59cc1SAlex Zinenko           auto one =
13726e59cc1SAlex Zinenko               rewriter.create<LLVM::ConstantOp>(loc, llvm1DVectorTy, splatAttr);
13826e59cc1SAlex Zinenko           auto add = rewriter.create<LLVM::FAddOp>(loc, llvm1DVectorTy, one,
13926e59cc1SAlex Zinenko                                                    operands[0]);
14026e59cc1SAlex Zinenko           return rewriter.create<LLVM::LogOp>(loc, llvm1DVectorTy, add);
14126e59cc1SAlex Zinenko         },
14226e59cc1SAlex Zinenko         rewriter);
14326e59cc1SAlex Zinenko   }
14426e59cc1SAlex Zinenko };
14526e59cc1SAlex Zinenko 
14626e59cc1SAlex Zinenko // A `rsqrt` is converted into `1 / sqrt`.
14726e59cc1SAlex Zinenko struct RsqrtOpLowering : public ConvertOpToLLVMPattern<math::RsqrtOp> {
14826e59cc1SAlex Zinenko   using ConvertOpToLLVMPattern<math::RsqrtOp>::ConvertOpToLLVMPattern;
14926e59cc1SAlex Zinenko 
15026e59cc1SAlex Zinenko   LogicalResult
151ef976337SRiver Riddle   matchAndRewrite(math::RsqrtOp op, OpAdaptor adaptor,
15226e59cc1SAlex Zinenko                   ConversionPatternRewriter &rewriter) const override {
153*62fea88bSJacques Pienaar     auto operandType = adaptor.getOperand().getType();
15426e59cc1SAlex Zinenko 
15526e59cc1SAlex Zinenko     if (!operandType || !LLVM::isCompatibleType(operandType))
15626e59cc1SAlex Zinenko       return failure();
15726e59cc1SAlex Zinenko 
15826e59cc1SAlex Zinenko     auto loc = op.getLoc();
15926e59cc1SAlex Zinenko     auto resultType = op.getResult().getType();
16026e59cc1SAlex Zinenko     auto floatType = getElementTypeOrSelf(resultType).cast<FloatType>();
16126e59cc1SAlex Zinenko     auto floatOne = rewriter.getFloatAttr(floatType, 1.0);
16226e59cc1SAlex Zinenko 
16326e59cc1SAlex Zinenko     if (!operandType.isa<LLVM::LLVMArrayType>()) {
16426e59cc1SAlex Zinenko       LLVM::ConstantOp one;
16526e59cc1SAlex Zinenko       if (LLVM::isCompatibleVectorType(operandType)) {
16626e59cc1SAlex Zinenko         one = rewriter.create<LLVM::ConstantOp>(
16726e59cc1SAlex Zinenko             loc, operandType,
16826e59cc1SAlex Zinenko             SplatElementsAttr::get(resultType.cast<ShapedType>(), floatOne));
16926e59cc1SAlex Zinenko       } else {
17026e59cc1SAlex Zinenko         one = rewriter.create<LLVM::ConstantOp>(loc, operandType, floatOne);
17126e59cc1SAlex Zinenko       }
172*62fea88bSJacques Pienaar       auto sqrt = rewriter.create<LLVM::SqrtOp>(loc, adaptor.getOperand());
17326e59cc1SAlex Zinenko       rewriter.replaceOpWithNewOp<LLVM::FDivOp>(op, operandType, one, sqrt);
17426e59cc1SAlex Zinenko       return success();
17526e59cc1SAlex Zinenko     }
17626e59cc1SAlex Zinenko 
17726e59cc1SAlex Zinenko     auto vectorType = resultType.dyn_cast<VectorType>();
17826e59cc1SAlex Zinenko     if (!vectorType)
17926e59cc1SAlex Zinenko       return failure();
18026e59cc1SAlex Zinenko 
18126e59cc1SAlex Zinenko     return LLVM::detail::handleMultidimensionalVectors(
182ef976337SRiver Riddle         op.getOperation(), adaptor.getOperands(), *getTypeConverter(),
18326e59cc1SAlex Zinenko         [&](Type llvm1DVectorTy, ValueRange operands) {
18426e59cc1SAlex Zinenko           auto splatAttr = SplatElementsAttr::get(
18526e59cc1SAlex Zinenko               mlir::VectorType::get(
18626e59cc1SAlex Zinenko                   {LLVM::getVectorNumElements(llvm1DVectorTy).getFixedValue()},
18726e59cc1SAlex Zinenko                   floatType),
18826e59cc1SAlex Zinenko               floatOne);
18926e59cc1SAlex Zinenko           auto one =
19026e59cc1SAlex Zinenko               rewriter.create<LLVM::ConstantOp>(loc, llvm1DVectorTy, splatAttr);
19126e59cc1SAlex Zinenko           auto sqrt =
19226e59cc1SAlex Zinenko               rewriter.create<LLVM::SqrtOp>(loc, llvm1DVectorTy, operands[0]);
19326e59cc1SAlex Zinenko           return rewriter.create<LLVM::FDivOp>(loc, llvm1DVectorTy, one, sqrt);
19426e59cc1SAlex Zinenko         },
19526e59cc1SAlex Zinenko         rewriter);
19626e59cc1SAlex Zinenko   }
19726e59cc1SAlex Zinenko };
19826e59cc1SAlex Zinenko 
19926e59cc1SAlex Zinenko struct ConvertMathToLLVMPass
20026e59cc1SAlex Zinenko     : public ConvertMathToLLVMBase<ConvertMathToLLVMPass> {
20126e59cc1SAlex Zinenko   ConvertMathToLLVMPass() = default;
20226e59cc1SAlex Zinenko 
20326e59cc1SAlex Zinenko   void runOnFunction() override {
20426e59cc1SAlex Zinenko     RewritePatternSet patterns(&getContext());
20526e59cc1SAlex Zinenko     LLVMTypeConverter converter(&getContext());
20626e59cc1SAlex Zinenko     populateMathToLLVMConversionPatterns(converter, patterns);
20726e59cc1SAlex Zinenko     LLVMConversionTarget target(getContext());
20826e59cc1SAlex Zinenko     if (failed(
20926e59cc1SAlex Zinenko             applyPartialConversion(getFunction(), target, std::move(patterns))))
21026e59cc1SAlex Zinenko       signalPassFailure();
21126e59cc1SAlex Zinenko   }
21226e59cc1SAlex Zinenko };
21326e59cc1SAlex Zinenko } // namespace
21426e59cc1SAlex Zinenko 
21526e59cc1SAlex Zinenko void mlir::populateMathToLLVMConversionPatterns(LLVMTypeConverter &converter,
21626e59cc1SAlex Zinenko                                                 RewritePatternSet &patterns) {
21726e59cc1SAlex Zinenko   // clang-format off
21826e59cc1SAlex Zinenko   patterns.add<
219a54f4eaeSMogball     AbsOpLowering,
220a54f4eaeSMogball     CeilOpLowering,
221a54f4eaeSMogball     CopySignOpLowering,
22226e59cc1SAlex Zinenko     CosOpLowering,
22326e59cc1SAlex Zinenko     ExpOpLowering,
22426e59cc1SAlex Zinenko     Exp2OpLowering,
22526e59cc1SAlex Zinenko     ExpM1OpLowering,
226a54f4eaeSMogball     FloorOpLowering,
227a54f4eaeSMogball     FmaOpLowering,
22826e59cc1SAlex Zinenko     Log10OpLowering,
22926e59cc1SAlex Zinenko     Log1pOpLowering,
23026e59cc1SAlex Zinenko     Log2OpLowering,
23126e59cc1SAlex Zinenko     LogOpLowering,
23226e59cc1SAlex Zinenko     PowFOpLowering,
23326e59cc1SAlex Zinenko     RsqrtOpLowering,
23426e59cc1SAlex Zinenko     SinOpLowering,
23526e59cc1SAlex Zinenko     SqrtOpLowering
23626e59cc1SAlex Zinenko   >(converter);
23726e59cc1SAlex Zinenko   // clang-format on
23826e59cc1SAlex Zinenko }
23926e59cc1SAlex Zinenko 
24026e59cc1SAlex Zinenko std::unique_ptr<Pass> mlir::createConvertMathToLLVMPass() {
24126e59cc1SAlex Zinenko   return std::make_unique<ConvertMathToLLVMPass>();
24226e59cc1SAlex Zinenko }
243