126e59cc1SAlex Zinenko //===- MathToLLVM.cpp - Math to LLVM dialect conversion -------------------===// 226e59cc1SAlex Zinenko // 326e59cc1SAlex Zinenko // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. 426e59cc1SAlex Zinenko // See https://llvm.org/LICENSE.txt for license information. 526e59cc1SAlex Zinenko // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception 626e59cc1SAlex Zinenko // 726e59cc1SAlex Zinenko //===----------------------------------------------------------------------===// 826e59cc1SAlex Zinenko 926e59cc1SAlex Zinenko #include "mlir/Conversion/MathToLLVM/MathToLLVM.h" 1026e59cc1SAlex Zinenko #include "../PassDetail.h" 1126e59cc1SAlex Zinenko #include "mlir/Conversion/LLVMCommon/ConversionTarget.h" 1226e59cc1SAlex Zinenko #include "mlir/Conversion/LLVMCommon/Pattern.h" 1326e59cc1SAlex Zinenko #include "mlir/Conversion/LLVMCommon/VectorPattern.h" 1426e59cc1SAlex Zinenko #include "mlir/Dialect/LLVMIR/LLVMDialect.h" 1526e59cc1SAlex Zinenko #include "mlir/Dialect/Math/IR/Math.h" 1626e59cc1SAlex Zinenko #include "mlir/IR/TypeUtilities.h" 1726e59cc1SAlex Zinenko 1826e59cc1SAlex Zinenko using namespace mlir; 1926e59cc1SAlex Zinenko 2026e59cc1SAlex Zinenko namespace { 21a54f4eaeSMogball using AbsOpLowering = VectorConvertToLLVMPattern<math::AbsOp, LLVM::FAbsOp>; 22a54f4eaeSMogball using CeilOpLowering = VectorConvertToLLVMPattern<math::CeilOp, LLVM::FCeilOp>; 23a54f4eaeSMogball using CopySignOpLowering = 24a54f4eaeSMogball VectorConvertToLLVMPattern<math::CopySignOp, LLVM::CopySignOp>; 2526e59cc1SAlex Zinenko using CosOpLowering = VectorConvertToLLVMPattern<math::CosOp, LLVM::CosOp>; 26c5fef77bSRob Suderman using CtPopFOpLowering = 27c5fef77bSRob Suderman VectorConvertToLLVMPattern<math::CtPopOp, LLVM::CtPopOp>; 2826e59cc1SAlex Zinenko using ExpOpLowering = VectorConvertToLLVMPattern<math::ExpOp, LLVM::ExpOp>; 2926e59cc1SAlex Zinenko using Exp2OpLowering = VectorConvertToLLVMPattern<math::Exp2Op, LLVM::Exp2Op>; 30a54f4eaeSMogball using FloorOpLowering = 31a54f4eaeSMogball VectorConvertToLLVMPattern<math::FloorOp, LLVM::FFloorOp>; 32a54f4eaeSMogball using FmaOpLowering = VectorConvertToLLVMPattern<math::FmaOp, LLVM::FMAOp>; 3326e59cc1SAlex Zinenko using Log10OpLowering = 3426e59cc1SAlex Zinenko VectorConvertToLLVMPattern<math::Log10Op, LLVM::Log10Op>; 3526e59cc1SAlex Zinenko using Log2OpLowering = VectorConvertToLLVMPattern<math::Log2Op, LLVM::Log2Op>; 3626e59cc1SAlex Zinenko using LogOpLowering = VectorConvertToLLVMPattern<math::LogOp, LLVM::LogOp>; 3726e59cc1SAlex Zinenko using PowFOpLowering = VectorConvertToLLVMPattern<math::PowFOp, LLVM::PowOp>; 3826e59cc1SAlex Zinenko using SinOpLowering = VectorConvertToLLVMPattern<math::SinOp, LLVM::SinOp>; 3926e59cc1SAlex Zinenko using SqrtOpLowering = VectorConvertToLLVMPattern<math::SqrtOp, LLVM::SqrtOp>; 4026e59cc1SAlex Zinenko 4123149d52SRob Suderman // A `CtLz/CtTz(a)` is converted into `CtLz/CtTz(a, false)`. 4223149d52SRob Suderman template <typename MathOp, typename LLVMOp> 4323149d52SRob Suderman struct CountOpLowering : public ConvertOpToLLVMPattern<MathOp> { 4423149d52SRob Suderman using ConvertOpToLLVMPattern<MathOp>::ConvertOpToLLVMPattern; 4523149d52SRob Suderman using Super = CountOpLowering<MathOp, LLVMOp>; 4623149d52SRob Suderman 4723149d52SRob Suderman LogicalResult 4823149d52SRob Suderman matchAndRewrite(MathOp op, typename MathOp::Adaptor adaptor, 4923149d52SRob Suderman ConversionPatternRewriter &rewriter) const override { 5023149d52SRob Suderman auto operandType = adaptor.getOperand().getType(); 5123149d52SRob Suderman 5223149d52SRob Suderman if (!operandType || !LLVM::isCompatibleType(operandType)) 5323149d52SRob Suderman return failure(); 5423149d52SRob Suderman 5523149d52SRob Suderman auto loc = op.getLoc(); 5623149d52SRob Suderman auto resultType = op.getResult().getType(); 5723149d52SRob Suderman auto boolType = rewriter.getIntegerType(1); 5823149d52SRob Suderman auto boolZero = rewriter.getIntegerAttr(boolType, 0); 5923149d52SRob Suderman 6023149d52SRob Suderman if (!operandType.template isa<LLVM::LLVMArrayType>()) { 6123149d52SRob Suderman LLVM::ConstantOp zero = 6223149d52SRob Suderman rewriter.create<LLVM::ConstantOp>(loc, boolType, boolZero); 6323149d52SRob Suderman rewriter.replaceOpWithNewOp<LLVMOp>(op, resultType, adaptor.getOperand(), 6423149d52SRob Suderman zero); 6523149d52SRob Suderman return success(); 6623149d52SRob Suderman } 6723149d52SRob Suderman 6823149d52SRob Suderman auto vectorType = resultType.template dyn_cast<VectorType>(); 6923149d52SRob Suderman if (!vectorType) 7023149d52SRob Suderman return failure(); 7123149d52SRob Suderman 7223149d52SRob Suderman return LLVM::detail::handleMultidimensionalVectors( 7323149d52SRob Suderman op.getOperation(), adaptor.getOperands(), *this->getTypeConverter(), 7423149d52SRob Suderman [&](Type llvm1DVectorTy, ValueRange operands) { 7523149d52SRob Suderman LLVM::ConstantOp zero = 7623149d52SRob Suderman rewriter.create<LLVM::ConstantOp>(loc, boolType, boolZero); 77*cb4a5eaeSRobert Suderman return rewriter.create<LLVMOp>(loc, llvm1DVectorTy, operands[0], 78*cb4a5eaeSRobert Suderman zero); 7923149d52SRob Suderman }, 8023149d52SRob Suderman rewriter); 8123149d52SRob Suderman } 8223149d52SRob Suderman }; 8323149d52SRob Suderman 8423149d52SRob Suderman using CountLeadingZerosOpLowering = 8523149d52SRob Suderman CountOpLowering<math::CountLeadingZerosOp, LLVM::CountLeadingZerosOp>; 8623149d52SRob Suderman using CountTrailingZerosOpLowering = 8723149d52SRob Suderman CountOpLowering<math::CountTrailingZerosOp, LLVM::CountTrailingZerosOp>; 8823149d52SRob Suderman 8926e59cc1SAlex Zinenko // A `expm1` is converted into `exp - 1`. 9026e59cc1SAlex Zinenko struct ExpM1OpLowering : public ConvertOpToLLVMPattern<math::ExpM1Op> { 9126e59cc1SAlex Zinenko using ConvertOpToLLVMPattern<math::ExpM1Op>::ConvertOpToLLVMPattern; 9226e59cc1SAlex Zinenko 9326e59cc1SAlex Zinenko LogicalResult 94ef976337SRiver Riddle matchAndRewrite(math::ExpM1Op op, OpAdaptor adaptor, 9526e59cc1SAlex Zinenko ConversionPatternRewriter &rewriter) const override { 9662fea88bSJacques Pienaar auto operandType = adaptor.getOperand().getType(); 9726e59cc1SAlex Zinenko 9826e59cc1SAlex Zinenko if (!operandType || !LLVM::isCompatibleType(operandType)) 9926e59cc1SAlex Zinenko return failure(); 10026e59cc1SAlex Zinenko 10126e59cc1SAlex Zinenko auto loc = op.getLoc(); 10226e59cc1SAlex Zinenko auto resultType = op.getResult().getType(); 10326e59cc1SAlex Zinenko auto floatType = getElementTypeOrSelf(resultType).cast<FloatType>(); 10426e59cc1SAlex Zinenko auto floatOne = rewriter.getFloatAttr(floatType, 1.0); 10526e59cc1SAlex Zinenko 10626e59cc1SAlex Zinenko if (!operandType.isa<LLVM::LLVMArrayType>()) { 10726e59cc1SAlex Zinenko LLVM::ConstantOp one; 10826e59cc1SAlex Zinenko if (LLVM::isCompatibleVectorType(operandType)) { 10926e59cc1SAlex Zinenko one = rewriter.create<LLVM::ConstantOp>( 11026e59cc1SAlex Zinenko loc, operandType, 11126e59cc1SAlex Zinenko SplatElementsAttr::get(resultType.cast<ShapedType>(), floatOne)); 11226e59cc1SAlex Zinenko } else { 11326e59cc1SAlex Zinenko one = rewriter.create<LLVM::ConstantOp>(loc, operandType, floatOne); 11426e59cc1SAlex Zinenko } 11562fea88bSJacques Pienaar auto exp = rewriter.create<LLVM::ExpOp>(loc, adaptor.getOperand()); 11626e59cc1SAlex Zinenko rewriter.replaceOpWithNewOp<LLVM::FSubOp>(op, operandType, exp, one); 11726e59cc1SAlex Zinenko return success(); 11826e59cc1SAlex Zinenko } 11926e59cc1SAlex Zinenko 12026e59cc1SAlex Zinenko auto vectorType = resultType.dyn_cast<VectorType>(); 12126e59cc1SAlex Zinenko if (!vectorType) 12226e59cc1SAlex Zinenko return rewriter.notifyMatchFailure(op, "expected vector result type"); 12326e59cc1SAlex Zinenko 12426e59cc1SAlex Zinenko return LLVM::detail::handleMultidimensionalVectors( 125ef976337SRiver Riddle op.getOperation(), adaptor.getOperands(), *getTypeConverter(), 12626e59cc1SAlex Zinenko [&](Type llvm1DVectorTy, ValueRange operands) { 12726e59cc1SAlex Zinenko auto splatAttr = SplatElementsAttr::get( 12826e59cc1SAlex Zinenko mlir::VectorType::get( 12926e59cc1SAlex Zinenko {LLVM::getVectorNumElements(llvm1DVectorTy).getFixedValue()}, 13026e59cc1SAlex Zinenko floatType), 13126e59cc1SAlex Zinenko floatOne); 13226e59cc1SAlex Zinenko auto one = 13326e59cc1SAlex Zinenko rewriter.create<LLVM::ConstantOp>(loc, llvm1DVectorTy, splatAttr); 13426e59cc1SAlex Zinenko auto exp = 13526e59cc1SAlex Zinenko rewriter.create<LLVM::ExpOp>(loc, llvm1DVectorTy, operands[0]); 13626e59cc1SAlex Zinenko return rewriter.create<LLVM::FSubOp>(loc, llvm1DVectorTy, exp, one); 13726e59cc1SAlex Zinenko }, 13826e59cc1SAlex Zinenko rewriter); 13926e59cc1SAlex Zinenko } 14026e59cc1SAlex Zinenko }; 14126e59cc1SAlex Zinenko 14226e59cc1SAlex Zinenko // A `log1p` is converted into `log(1 + ...)`. 14326e59cc1SAlex Zinenko struct Log1pOpLowering : public ConvertOpToLLVMPattern<math::Log1pOp> { 14426e59cc1SAlex Zinenko using ConvertOpToLLVMPattern<math::Log1pOp>::ConvertOpToLLVMPattern; 14526e59cc1SAlex Zinenko 14626e59cc1SAlex Zinenko LogicalResult 147ef976337SRiver Riddle matchAndRewrite(math::Log1pOp op, OpAdaptor adaptor, 14826e59cc1SAlex Zinenko ConversionPatternRewriter &rewriter) const override { 14962fea88bSJacques Pienaar auto operandType = adaptor.getOperand().getType(); 15026e59cc1SAlex Zinenko 15126e59cc1SAlex Zinenko if (!operandType || !LLVM::isCompatibleType(operandType)) 15226e59cc1SAlex Zinenko return rewriter.notifyMatchFailure(op, "unsupported operand type"); 15326e59cc1SAlex Zinenko 15426e59cc1SAlex Zinenko auto loc = op.getLoc(); 15526e59cc1SAlex Zinenko auto resultType = op.getResult().getType(); 15626e59cc1SAlex Zinenko auto floatType = getElementTypeOrSelf(resultType).cast<FloatType>(); 15726e59cc1SAlex Zinenko auto floatOne = rewriter.getFloatAttr(floatType, 1.0); 15826e59cc1SAlex Zinenko 15926e59cc1SAlex Zinenko if (!operandType.isa<LLVM::LLVMArrayType>()) { 16026e59cc1SAlex Zinenko LLVM::ConstantOp one = 16126e59cc1SAlex Zinenko LLVM::isCompatibleVectorType(operandType) 16226e59cc1SAlex Zinenko ? rewriter.create<LLVM::ConstantOp>( 16326e59cc1SAlex Zinenko loc, operandType, 16426e59cc1SAlex Zinenko SplatElementsAttr::get(resultType.cast<ShapedType>(), 16526e59cc1SAlex Zinenko floatOne)) 16626e59cc1SAlex Zinenko : rewriter.create<LLVM::ConstantOp>(loc, operandType, floatOne); 16726e59cc1SAlex Zinenko 16826e59cc1SAlex Zinenko auto add = rewriter.create<LLVM::FAddOp>(loc, operandType, one, 16962fea88bSJacques Pienaar adaptor.getOperand()); 17026e59cc1SAlex Zinenko rewriter.replaceOpWithNewOp<LLVM::LogOp>(op, operandType, add); 17126e59cc1SAlex Zinenko return success(); 17226e59cc1SAlex Zinenko } 17326e59cc1SAlex Zinenko 17426e59cc1SAlex Zinenko auto vectorType = resultType.dyn_cast<VectorType>(); 17526e59cc1SAlex Zinenko if (!vectorType) 17626e59cc1SAlex Zinenko return rewriter.notifyMatchFailure(op, "expected vector result type"); 17726e59cc1SAlex Zinenko 17826e59cc1SAlex Zinenko return LLVM::detail::handleMultidimensionalVectors( 179ef976337SRiver Riddle op.getOperation(), adaptor.getOperands(), *getTypeConverter(), 18026e59cc1SAlex Zinenko [&](Type llvm1DVectorTy, ValueRange operands) { 18126e59cc1SAlex Zinenko auto splatAttr = SplatElementsAttr::get( 18226e59cc1SAlex Zinenko mlir::VectorType::get( 18326e59cc1SAlex Zinenko {LLVM::getVectorNumElements(llvm1DVectorTy).getFixedValue()}, 18426e59cc1SAlex Zinenko floatType), 18526e59cc1SAlex Zinenko floatOne); 18626e59cc1SAlex Zinenko auto one = 18726e59cc1SAlex Zinenko rewriter.create<LLVM::ConstantOp>(loc, llvm1DVectorTy, splatAttr); 18826e59cc1SAlex Zinenko auto add = rewriter.create<LLVM::FAddOp>(loc, llvm1DVectorTy, one, 18926e59cc1SAlex Zinenko operands[0]); 19026e59cc1SAlex Zinenko return rewriter.create<LLVM::LogOp>(loc, llvm1DVectorTy, add); 19126e59cc1SAlex Zinenko }, 19226e59cc1SAlex Zinenko rewriter); 19326e59cc1SAlex Zinenko } 19426e59cc1SAlex Zinenko }; 19526e59cc1SAlex Zinenko 19626e59cc1SAlex Zinenko // A `rsqrt` is converted into `1 / sqrt`. 19726e59cc1SAlex Zinenko struct RsqrtOpLowering : public ConvertOpToLLVMPattern<math::RsqrtOp> { 19826e59cc1SAlex Zinenko using ConvertOpToLLVMPattern<math::RsqrtOp>::ConvertOpToLLVMPattern; 19926e59cc1SAlex Zinenko 20026e59cc1SAlex Zinenko LogicalResult 201ef976337SRiver Riddle matchAndRewrite(math::RsqrtOp op, OpAdaptor adaptor, 20226e59cc1SAlex Zinenko ConversionPatternRewriter &rewriter) const override { 20362fea88bSJacques Pienaar auto operandType = adaptor.getOperand().getType(); 20426e59cc1SAlex Zinenko 20526e59cc1SAlex Zinenko if (!operandType || !LLVM::isCompatibleType(operandType)) 20626e59cc1SAlex Zinenko return failure(); 20726e59cc1SAlex Zinenko 20826e59cc1SAlex Zinenko auto loc = op.getLoc(); 20926e59cc1SAlex Zinenko auto resultType = op.getResult().getType(); 21026e59cc1SAlex Zinenko auto floatType = getElementTypeOrSelf(resultType).cast<FloatType>(); 21126e59cc1SAlex Zinenko auto floatOne = rewriter.getFloatAttr(floatType, 1.0); 21226e59cc1SAlex Zinenko 21326e59cc1SAlex Zinenko if (!operandType.isa<LLVM::LLVMArrayType>()) { 21426e59cc1SAlex Zinenko LLVM::ConstantOp one; 21526e59cc1SAlex Zinenko if (LLVM::isCompatibleVectorType(operandType)) { 21626e59cc1SAlex Zinenko one = rewriter.create<LLVM::ConstantOp>( 21726e59cc1SAlex Zinenko loc, operandType, 21826e59cc1SAlex Zinenko SplatElementsAttr::get(resultType.cast<ShapedType>(), floatOne)); 21926e59cc1SAlex Zinenko } else { 22026e59cc1SAlex Zinenko one = rewriter.create<LLVM::ConstantOp>(loc, operandType, floatOne); 22126e59cc1SAlex Zinenko } 22262fea88bSJacques Pienaar auto sqrt = rewriter.create<LLVM::SqrtOp>(loc, adaptor.getOperand()); 22326e59cc1SAlex Zinenko rewriter.replaceOpWithNewOp<LLVM::FDivOp>(op, operandType, one, sqrt); 22426e59cc1SAlex Zinenko return success(); 22526e59cc1SAlex Zinenko } 22626e59cc1SAlex Zinenko 22726e59cc1SAlex Zinenko auto vectorType = resultType.dyn_cast<VectorType>(); 22826e59cc1SAlex Zinenko if (!vectorType) 22926e59cc1SAlex Zinenko return failure(); 23026e59cc1SAlex Zinenko 23126e59cc1SAlex Zinenko return LLVM::detail::handleMultidimensionalVectors( 232ef976337SRiver Riddle op.getOperation(), adaptor.getOperands(), *getTypeConverter(), 23326e59cc1SAlex Zinenko [&](Type llvm1DVectorTy, ValueRange operands) { 23426e59cc1SAlex Zinenko auto splatAttr = SplatElementsAttr::get( 23526e59cc1SAlex Zinenko mlir::VectorType::get( 23626e59cc1SAlex Zinenko {LLVM::getVectorNumElements(llvm1DVectorTy).getFixedValue()}, 23726e59cc1SAlex Zinenko floatType), 23826e59cc1SAlex Zinenko floatOne); 23926e59cc1SAlex Zinenko auto one = 24026e59cc1SAlex Zinenko rewriter.create<LLVM::ConstantOp>(loc, llvm1DVectorTy, splatAttr); 24126e59cc1SAlex Zinenko auto sqrt = 24226e59cc1SAlex Zinenko rewriter.create<LLVM::SqrtOp>(loc, llvm1DVectorTy, operands[0]); 24326e59cc1SAlex Zinenko return rewriter.create<LLVM::FDivOp>(loc, llvm1DVectorTy, one, sqrt); 24426e59cc1SAlex Zinenko }, 24526e59cc1SAlex Zinenko rewriter); 24626e59cc1SAlex Zinenko } 24726e59cc1SAlex Zinenko }; 24826e59cc1SAlex Zinenko 24926e59cc1SAlex Zinenko struct ConvertMathToLLVMPass 25026e59cc1SAlex Zinenko : public ConvertMathToLLVMBase<ConvertMathToLLVMPass> { 25126e59cc1SAlex Zinenko ConvertMathToLLVMPass() = default; 25226e59cc1SAlex Zinenko 25341574554SRiver Riddle void runOnOperation() override { 25426e59cc1SAlex Zinenko RewritePatternSet patterns(&getContext()); 25526e59cc1SAlex Zinenko LLVMTypeConverter converter(&getContext()); 25626e59cc1SAlex Zinenko populateMathToLLVMConversionPatterns(converter, patterns); 25726e59cc1SAlex Zinenko LLVMConversionTarget target(getContext()); 25841574554SRiver Riddle if (failed(applyPartialConversion(getOperation(), target, 25941574554SRiver Riddle std::move(patterns)))) 26026e59cc1SAlex Zinenko signalPassFailure(); 26126e59cc1SAlex Zinenko } 26226e59cc1SAlex Zinenko }; 26326e59cc1SAlex Zinenko } // namespace 26426e59cc1SAlex Zinenko 26526e59cc1SAlex Zinenko void mlir::populateMathToLLVMConversionPatterns(LLVMTypeConverter &converter, 26626e59cc1SAlex Zinenko RewritePatternSet &patterns) { 26726e59cc1SAlex Zinenko // clang-format off 26826e59cc1SAlex Zinenko patterns.add< 269a54f4eaeSMogball AbsOpLowering, 270a54f4eaeSMogball CeilOpLowering, 271a54f4eaeSMogball CopySignOpLowering, 27226e59cc1SAlex Zinenko CosOpLowering, 27323149d52SRob Suderman CountLeadingZerosOpLowering, 27423149d52SRob Suderman CountTrailingZerosOpLowering, 275c5fef77bSRob Suderman CtPopFOpLowering, 27626e59cc1SAlex Zinenko ExpOpLowering, 27726e59cc1SAlex Zinenko Exp2OpLowering, 27826e59cc1SAlex Zinenko ExpM1OpLowering, 279a54f4eaeSMogball FloorOpLowering, 280a54f4eaeSMogball FmaOpLowering, 28126e59cc1SAlex Zinenko Log10OpLowering, 28226e59cc1SAlex Zinenko Log1pOpLowering, 28326e59cc1SAlex Zinenko Log2OpLowering, 28426e59cc1SAlex Zinenko LogOpLowering, 28526e59cc1SAlex Zinenko PowFOpLowering, 28626e59cc1SAlex Zinenko RsqrtOpLowering, 28726e59cc1SAlex Zinenko SinOpLowering, 28826e59cc1SAlex Zinenko SqrtOpLowering 28926e59cc1SAlex Zinenko >(converter); 29026e59cc1SAlex Zinenko // clang-format on 29126e59cc1SAlex Zinenko } 29226e59cc1SAlex Zinenko 29326e59cc1SAlex Zinenko std::unique_ptr<Pass> mlir::createConvertMathToLLVMPass() { 29426e59cc1SAlex Zinenko return std::make_unique<ConvertMathToLLVMPass>(); 29526e59cc1SAlex Zinenko } 296