126e59cc1SAlex Zinenko //===- MathToLLVM.cpp - Math to LLVM dialect conversion -------------------===// 226e59cc1SAlex Zinenko // 326e59cc1SAlex Zinenko // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. 426e59cc1SAlex Zinenko // See https://llvm.org/LICENSE.txt for license information. 526e59cc1SAlex Zinenko // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception 626e59cc1SAlex Zinenko // 726e59cc1SAlex Zinenko //===----------------------------------------------------------------------===// 826e59cc1SAlex Zinenko 926e59cc1SAlex Zinenko #include "mlir/Conversion/MathToLLVM/MathToLLVM.h" 1026e59cc1SAlex Zinenko #include "../PassDetail.h" 1126e59cc1SAlex Zinenko #include "mlir/Conversion/LLVMCommon/ConversionTarget.h" 1226e59cc1SAlex Zinenko #include "mlir/Conversion/LLVMCommon/Pattern.h" 1326e59cc1SAlex Zinenko #include "mlir/Conversion/LLVMCommon/VectorPattern.h" 1426e59cc1SAlex Zinenko #include "mlir/Dialect/LLVMIR/LLVMDialect.h" 1526e59cc1SAlex Zinenko #include "mlir/Dialect/Math/IR/Math.h" 1626e59cc1SAlex Zinenko #include "mlir/IR/TypeUtilities.h" 1726e59cc1SAlex Zinenko 1826e59cc1SAlex Zinenko using namespace mlir; 1926e59cc1SAlex Zinenko 2026e59cc1SAlex Zinenko namespace { 21a54f4eaeSMogball using AbsOpLowering = VectorConvertToLLVMPattern<math::AbsOp, LLVM::FAbsOp>; 22a54f4eaeSMogball using CeilOpLowering = VectorConvertToLLVMPattern<math::CeilOp, LLVM::FCeilOp>; 23a54f4eaeSMogball using CopySignOpLowering = 24a54f4eaeSMogball VectorConvertToLLVMPattern<math::CopySignOp, LLVM::CopySignOp>; 2526e59cc1SAlex Zinenko using CosOpLowering = VectorConvertToLLVMPattern<math::CosOp, LLVM::CosOp>; 26c5fef77bSRob Suderman using CtPopFOpLowering = 27c5fef77bSRob Suderman VectorConvertToLLVMPattern<math::CtPopOp, LLVM::CtPopOp>; 2826e59cc1SAlex Zinenko using ExpOpLowering = VectorConvertToLLVMPattern<math::ExpOp, LLVM::ExpOp>; 2926e59cc1SAlex Zinenko using Exp2OpLowering = VectorConvertToLLVMPattern<math::Exp2Op, LLVM::Exp2Op>; 30a54f4eaeSMogball using FloorOpLowering = 31a54f4eaeSMogball VectorConvertToLLVMPattern<math::FloorOp, LLVM::FFloorOp>; 32a54f4eaeSMogball using FmaOpLowering = VectorConvertToLLVMPattern<math::FmaOp, LLVM::FMAOp>; 3326e59cc1SAlex Zinenko using Log10OpLowering = 3426e59cc1SAlex Zinenko VectorConvertToLLVMPattern<math::Log10Op, LLVM::Log10Op>; 3526e59cc1SAlex Zinenko using Log2OpLowering = VectorConvertToLLVMPattern<math::Log2Op, LLVM::Log2Op>; 3626e59cc1SAlex Zinenko using LogOpLowering = VectorConvertToLLVMPattern<math::LogOp, LLVM::LogOp>; 3726e59cc1SAlex Zinenko using PowFOpLowering = VectorConvertToLLVMPattern<math::PowFOp, LLVM::PowOp>; 3826e59cc1SAlex Zinenko using SinOpLowering = VectorConvertToLLVMPattern<math::SinOp, LLVM::SinOp>; 3926e59cc1SAlex Zinenko using SqrtOpLowering = VectorConvertToLLVMPattern<math::SqrtOp, LLVM::SqrtOp>; 40*a0fc94abSlorenzo chelini using RoundOpLowering = 41*a0fc94abSlorenzo chelini VectorConvertToLLVMPattern<math::RoundOp, LLVM::RoundOp>; 4226e59cc1SAlex Zinenko 4323149d52SRob Suderman // A `CtLz/CtTz(a)` is converted into `CtLz/CtTz(a, false)`. 4423149d52SRob Suderman template <typename MathOp, typename LLVMOp> 4523149d52SRob Suderman struct CountOpLowering : public ConvertOpToLLVMPattern<MathOp> { 4623149d52SRob Suderman using ConvertOpToLLVMPattern<MathOp>::ConvertOpToLLVMPattern; 4723149d52SRob Suderman using Super = CountOpLowering<MathOp, LLVMOp>; 4823149d52SRob Suderman 4923149d52SRob Suderman LogicalResult 5023149d52SRob Suderman matchAndRewrite(MathOp op, typename MathOp::Adaptor adaptor, 5123149d52SRob Suderman ConversionPatternRewriter &rewriter) const override { 5223149d52SRob Suderman auto operandType = adaptor.getOperand().getType(); 5323149d52SRob Suderman 5423149d52SRob Suderman if (!operandType || !LLVM::isCompatibleType(operandType)) 5523149d52SRob Suderman return failure(); 5623149d52SRob Suderman 5723149d52SRob Suderman auto loc = op.getLoc(); 5823149d52SRob Suderman auto resultType = op.getResult().getType(); 5923149d52SRob Suderman auto boolType = rewriter.getIntegerType(1); 6023149d52SRob Suderman auto boolZero = rewriter.getIntegerAttr(boolType, 0); 6123149d52SRob Suderman 6223149d52SRob Suderman if (!operandType.template isa<LLVM::LLVMArrayType>()) { 6323149d52SRob Suderman LLVM::ConstantOp zero = 6423149d52SRob Suderman rewriter.create<LLVM::ConstantOp>(loc, boolType, boolZero); 6523149d52SRob Suderman rewriter.replaceOpWithNewOp<LLVMOp>(op, resultType, adaptor.getOperand(), 6623149d52SRob Suderman zero); 6723149d52SRob Suderman return success(); 6823149d52SRob Suderman } 6923149d52SRob Suderman 7023149d52SRob Suderman auto vectorType = resultType.template dyn_cast<VectorType>(); 7123149d52SRob Suderman if (!vectorType) 7223149d52SRob Suderman return failure(); 7323149d52SRob Suderman 7423149d52SRob Suderman return LLVM::detail::handleMultidimensionalVectors( 7523149d52SRob Suderman op.getOperation(), adaptor.getOperands(), *this->getTypeConverter(), 7623149d52SRob Suderman [&](Type llvm1DVectorTy, ValueRange operands) { 7723149d52SRob Suderman LLVM::ConstantOp zero = 7823149d52SRob Suderman rewriter.create<LLVM::ConstantOp>(loc, boolType, boolZero); 79cb4a5eaeSRobert Suderman return rewriter.create<LLVMOp>(loc, llvm1DVectorTy, operands[0], 80cb4a5eaeSRobert Suderman zero); 8123149d52SRob Suderman }, 8223149d52SRob Suderman rewriter); 8323149d52SRob Suderman } 8423149d52SRob Suderman }; 8523149d52SRob Suderman 8623149d52SRob Suderman using CountLeadingZerosOpLowering = 8723149d52SRob Suderman CountOpLowering<math::CountLeadingZerosOp, LLVM::CountLeadingZerosOp>; 8823149d52SRob Suderman using CountTrailingZerosOpLowering = 8923149d52SRob Suderman CountOpLowering<math::CountTrailingZerosOp, LLVM::CountTrailingZerosOp>; 9023149d52SRob Suderman 9126e59cc1SAlex Zinenko // A `expm1` is converted into `exp - 1`. 9226e59cc1SAlex Zinenko struct ExpM1OpLowering : public ConvertOpToLLVMPattern<math::ExpM1Op> { 9326e59cc1SAlex Zinenko using ConvertOpToLLVMPattern<math::ExpM1Op>::ConvertOpToLLVMPattern; 9426e59cc1SAlex Zinenko 9526e59cc1SAlex Zinenko LogicalResult 96ef976337SRiver Riddle matchAndRewrite(math::ExpM1Op op, OpAdaptor adaptor, 9726e59cc1SAlex Zinenko ConversionPatternRewriter &rewriter) const override { 9862fea88bSJacques Pienaar auto operandType = adaptor.getOperand().getType(); 9926e59cc1SAlex Zinenko 10026e59cc1SAlex Zinenko if (!operandType || !LLVM::isCompatibleType(operandType)) 10126e59cc1SAlex Zinenko return failure(); 10226e59cc1SAlex Zinenko 10326e59cc1SAlex Zinenko auto loc = op.getLoc(); 10426e59cc1SAlex Zinenko auto resultType = op.getResult().getType(); 10526e59cc1SAlex Zinenko auto floatType = getElementTypeOrSelf(resultType).cast<FloatType>(); 10626e59cc1SAlex Zinenko auto floatOne = rewriter.getFloatAttr(floatType, 1.0); 10726e59cc1SAlex Zinenko 10826e59cc1SAlex Zinenko if (!operandType.isa<LLVM::LLVMArrayType>()) { 10926e59cc1SAlex Zinenko LLVM::ConstantOp one; 11026e59cc1SAlex Zinenko if (LLVM::isCompatibleVectorType(operandType)) { 11126e59cc1SAlex Zinenko one = rewriter.create<LLVM::ConstantOp>( 11226e59cc1SAlex Zinenko loc, operandType, 11326e59cc1SAlex Zinenko SplatElementsAttr::get(resultType.cast<ShapedType>(), floatOne)); 11426e59cc1SAlex Zinenko } else { 11526e59cc1SAlex Zinenko one = rewriter.create<LLVM::ConstantOp>(loc, operandType, floatOne); 11626e59cc1SAlex Zinenko } 11762fea88bSJacques Pienaar auto exp = rewriter.create<LLVM::ExpOp>(loc, adaptor.getOperand()); 11826e59cc1SAlex Zinenko rewriter.replaceOpWithNewOp<LLVM::FSubOp>(op, operandType, exp, one); 11926e59cc1SAlex Zinenko return success(); 12026e59cc1SAlex Zinenko } 12126e59cc1SAlex Zinenko 12226e59cc1SAlex Zinenko auto vectorType = resultType.dyn_cast<VectorType>(); 12326e59cc1SAlex Zinenko if (!vectorType) 12426e59cc1SAlex Zinenko return rewriter.notifyMatchFailure(op, "expected vector result type"); 12526e59cc1SAlex Zinenko 12626e59cc1SAlex Zinenko return LLVM::detail::handleMultidimensionalVectors( 127ef976337SRiver Riddle op.getOperation(), adaptor.getOperands(), *getTypeConverter(), 12826e59cc1SAlex Zinenko [&](Type llvm1DVectorTy, ValueRange operands) { 12926e59cc1SAlex Zinenko auto splatAttr = SplatElementsAttr::get( 13026e59cc1SAlex Zinenko mlir::VectorType::get( 13126e59cc1SAlex Zinenko {LLVM::getVectorNumElements(llvm1DVectorTy).getFixedValue()}, 13226e59cc1SAlex Zinenko floatType), 13326e59cc1SAlex Zinenko floatOne); 13426e59cc1SAlex Zinenko auto one = 13526e59cc1SAlex Zinenko rewriter.create<LLVM::ConstantOp>(loc, llvm1DVectorTy, splatAttr); 13626e59cc1SAlex Zinenko auto exp = 13726e59cc1SAlex Zinenko rewriter.create<LLVM::ExpOp>(loc, llvm1DVectorTy, operands[0]); 13826e59cc1SAlex Zinenko return rewriter.create<LLVM::FSubOp>(loc, llvm1DVectorTy, exp, one); 13926e59cc1SAlex Zinenko }, 14026e59cc1SAlex Zinenko rewriter); 14126e59cc1SAlex Zinenko } 14226e59cc1SAlex Zinenko }; 14326e59cc1SAlex Zinenko 14426e59cc1SAlex Zinenko // A `log1p` is converted into `log(1 + ...)`. 14526e59cc1SAlex Zinenko struct Log1pOpLowering : public ConvertOpToLLVMPattern<math::Log1pOp> { 14626e59cc1SAlex Zinenko using ConvertOpToLLVMPattern<math::Log1pOp>::ConvertOpToLLVMPattern; 14726e59cc1SAlex Zinenko 14826e59cc1SAlex Zinenko LogicalResult 149ef976337SRiver Riddle matchAndRewrite(math::Log1pOp op, OpAdaptor adaptor, 15026e59cc1SAlex Zinenko ConversionPatternRewriter &rewriter) const override { 15162fea88bSJacques Pienaar auto operandType = adaptor.getOperand().getType(); 15226e59cc1SAlex Zinenko 15326e59cc1SAlex Zinenko if (!operandType || !LLVM::isCompatibleType(operandType)) 15426e59cc1SAlex Zinenko return rewriter.notifyMatchFailure(op, "unsupported operand type"); 15526e59cc1SAlex Zinenko 15626e59cc1SAlex Zinenko auto loc = op.getLoc(); 15726e59cc1SAlex Zinenko auto resultType = op.getResult().getType(); 15826e59cc1SAlex Zinenko auto floatType = getElementTypeOrSelf(resultType).cast<FloatType>(); 15926e59cc1SAlex Zinenko auto floatOne = rewriter.getFloatAttr(floatType, 1.0); 16026e59cc1SAlex Zinenko 16126e59cc1SAlex Zinenko if (!operandType.isa<LLVM::LLVMArrayType>()) { 16226e59cc1SAlex Zinenko LLVM::ConstantOp one = 16326e59cc1SAlex Zinenko LLVM::isCompatibleVectorType(operandType) 16426e59cc1SAlex Zinenko ? rewriter.create<LLVM::ConstantOp>( 16526e59cc1SAlex Zinenko loc, operandType, 16626e59cc1SAlex Zinenko SplatElementsAttr::get(resultType.cast<ShapedType>(), 16726e59cc1SAlex Zinenko floatOne)) 16826e59cc1SAlex Zinenko : rewriter.create<LLVM::ConstantOp>(loc, operandType, floatOne); 16926e59cc1SAlex Zinenko 17026e59cc1SAlex Zinenko auto add = rewriter.create<LLVM::FAddOp>(loc, operandType, one, 17162fea88bSJacques Pienaar adaptor.getOperand()); 17226e59cc1SAlex Zinenko rewriter.replaceOpWithNewOp<LLVM::LogOp>(op, operandType, add); 17326e59cc1SAlex Zinenko return success(); 17426e59cc1SAlex Zinenko } 17526e59cc1SAlex Zinenko 17626e59cc1SAlex Zinenko auto vectorType = resultType.dyn_cast<VectorType>(); 17726e59cc1SAlex Zinenko if (!vectorType) 17826e59cc1SAlex Zinenko return rewriter.notifyMatchFailure(op, "expected vector result type"); 17926e59cc1SAlex Zinenko 18026e59cc1SAlex Zinenko return LLVM::detail::handleMultidimensionalVectors( 181ef976337SRiver Riddle op.getOperation(), adaptor.getOperands(), *getTypeConverter(), 18226e59cc1SAlex Zinenko [&](Type llvm1DVectorTy, ValueRange operands) { 18326e59cc1SAlex Zinenko auto splatAttr = SplatElementsAttr::get( 18426e59cc1SAlex Zinenko mlir::VectorType::get( 18526e59cc1SAlex Zinenko {LLVM::getVectorNumElements(llvm1DVectorTy).getFixedValue()}, 18626e59cc1SAlex Zinenko floatType), 18726e59cc1SAlex Zinenko floatOne); 18826e59cc1SAlex Zinenko auto one = 18926e59cc1SAlex Zinenko rewriter.create<LLVM::ConstantOp>(loc, llvm1DVectorTy, splatAttr); 19026e59cc1SAlex Zinenko auto add = rewriter.create<LLVM::FAddOp>(loc, llvm1DVectorTy, one, 19126e59cc1SAlex Zinenko operands[0]); 19226e59cc1SAlex Zinenko return rewriter.create<LLVM::LogOp>(loc, llvm1DVectorTy, add); 19326e59cc1SAlex Zinenko }, 19426e59cc1SAlex Zinenko rewriter); 19526e59cc1SAlex Zinenko } 19626e59cc1SAlex Zinenko }; 19726e59cc1SAlex Zinenko 19826e59cc1SAlex Zinenko // A `rsqrt` is converted into `1 / sqrt`. 19926e59cc1SAlex Zinenko struct RsqrtOpLowering : public ConvertOpToLLVMPattern<math::RsqrtOp> { 20026e59cc1SAlex Zinenko using ConvertOpToLLVMPattern<math::RsqrtOp>::ConvertOpToLLVMPattern; 20126e59cc1SAlex Zinenko 20226e59cc1SAlex Zinenko LogicalResult 203ef976337SRiver Riddle matchAndRewrite(math::RsqrtOp op, OpAdaptor adaptor, 20426e59cc1SAlex Zinenko ConversionPatternRewriter &rewriter) const override { 20562fea88bSJacques Pienaar auto operandType = adaptor.getOperand().getType(); 20626e59cc1SAlex Zinenko 20726e59cc1SAlex Zinenko if (!operandType || !LLVM::isCompatibleType(operandType)) 20826e59cc1SAlex Zinenko return failure(); 20926e59cc1SAlex Zinenko 21026e59cc1SAlex Zinenko auto loc = op.getLoc(); 21126e59cc1SAlex Zinenko auto resultType = op.getResult().getType(); 21226e59cc1SAlex Zinenko auto floatType = getElementTypeOrSelf(resultType).cast<FloatType>(); 21326e59cc1SAlex Zinenko auto floatOne = rewriter.getFloatAttr(floatType, 1.0); 21426e59cc1SAlex Zinenko 21526e59cc1SAlex Zinenko if (!operandType.isa<LLVM::LLVMArrayType>()) { 21626e59cc1SAlex Zinenko LLVM::ConstantOp one; 21726e59cc1SAlex Zinenko if (LLVM::isCompatibleVectorType(operandType)) { 21826e59cc1SAlex Zinenko one = rewriter.create<LLVM::ConstantOp>( 21926e59cc1SAlex Zinenko loc, operandType, 22026e59cc1SAlex Zinenko SplatElementsAttr::get(resultType.cast<ShapedType>(), floatOne)); 22126e59cc1SAlex Zinenko } else { 22226e59cc1SAlex Zinenko one = rewriter.create<LLVM::ConstantOp>(loc, operandType, floatOne); 22326e59cc1SAlex Zinenko } 22462fea88bSJacques Pienaar auto sqrt = rewriter.create<LLVM::SqrtOp>(loc, adaptor.getOperand()); 22526e59cc1SAlex Zinenko rewriter.replaceOpWithNewOp<LLVM::FDivOp>(op, operandType, one, sqrt); 22626e59cc1SAlex Zinenko return success(); 22726e59cc1SAlex Zinenko } 22826e59cc1SAlex Zinenko 22926e59cc1SAlex Zinenko auto vectorType = resultType.dyn_cast<VectorType>(); 23026e59cc1SAlex Zinenko if (!vectorType) 23126e59cc1SAlex Zinenko return failure(); 23226e59cc1SAlex Zinenko 23326e59cc1SAlex Zinenko return LLVM::detail::handleMultidimensionalVectors( 234ef976337SRiver Riddle op.getOperation(), adaptor.getOperands(), *getTypeConverter(), 23526e59cc1SAlex Zinenko [&](Type llvm1DVectorTy, ValueRange operands) { 23626e59cc1SAlex Zinenko auto splatAttr = SplatElementsAttr::get( 23726e59cc1SAlex Zinenko mlir::VectorType::get( 23826e59cc1SAlex Zinenko {LLVM::getVectorNumElements(llvm1DVectorTy).getFixedValue()}, 23926e59cc1SAlex Zinenko floatType), 24026e59cc1SAlex Zinenko floatOne); 24126e59cc1SAlex Zinenko auto one = 24226e59cc1SAlex Zinenko rewriter.create<LLVM::ConstantOp>(loc, llvm1DVectorTy, splatAttr); 24326e59cc1SAlex Zinenko auto sqrt = 24426e59cc1SAlex Zinenko rewriter.create<LLVM::SqrtOp>(loc, llvm1DVectorTy, operands[0]); 24526e59cc1SAlex Zinenko return rewriter.create<LLVM::FDivOp>(loc, llvm1DVectorTy, one, sqrt); 24626e59cc1SAlex Zinenko }, 24726e59cc1SAlex Zinenko rewriter); 24826e59cc1SAlex Zinenko } 24926e59cc1SAlex Zinenko }; 25026e59cc1SAlex Zinenko 25126e59cc1SAlex Zinenko struct ConvertMathToLLVMPass 25226e59cc1SAlex Zinenko : public ConvertMathToLLVMBase<ConvertMathToLLVMPass> { 25326e59cc1SAlex Zinenko ConvertMathToLLVMPass() = default; 25426e59cc1SAlex Zinenko 25541574554SRiver Riddle void runOnOperation() override { 25626e59cc1SAlex Zinenko RewritePatternSet patterns(&getContext()); 25726e59cc1SAlex Zinenko LLVMTypeConverter converter(&getContext()); 25826e59cc1SAlex Zinenko populateMathToLLVMConversionPatterns(converter, patterns); 25926e59cc1SAlex Zinenko LLVMConversionTarget target(getContext()); 26041574554SRiver Riddle if (failed(applyPartialConversion(getOperation(), target, 26141574554SRiver Riddle std::move(patterns)))) 26226e59cc1SAlex Zinenko signalPassFailure(); 26326e59cc1SAlex Zinenko } 26426e59cc1SAlex Zinenko }; 26526e59cc1SAlex Zinenko } // namespace 26626e59cc1SAlex Zinenko 26726e59cc1SAlex Zinenko void mlir::populateMathToLLVMConversionPatterns(LLVMTypeConverter &converter, 26826e59cc1SAlex Zinenko RewritePatternSet &patterns) { 26926e59cc1SAlex Zinenko // clang-format off 27026e59cc1SAlex Zinenko patterns.add< 271a54f4eaeSMogball AbsOpLowering, 272a54f4eaeSMogball CeilOpLowering, 273a54f4eaeSMogball CopySignOpLowering, 27426e59cc1SAlex Zinenko CosOpLowering, 27523149d52SRob Suderman CountLeadingZerosOpLowering, 27623149d52SRob Suderman CountTrailingZerosOpLowering, 277c5fef77bSRob Suderman CtPopFOpLowering, 27826e59cc1SAlex Zinenko ExpOpLowering, 27926e59cc1SAlex Zinenko Exp2OpLowering, 28026e59cc1SAlex Zinenko ExpM1OpLowering, 281a54f4eaeSMogball FloorOpLowering, 282a54f4eaeSMogball FmaOpLowering, 28326e59cc1SAlex Zinenko Log10OpLowering, 28426e59cc1SAlex Zinenko Log1pOpLowering, 28526e59cc1SAlex Zinenko Log2OpLowering, 28626e59cc1SAlex Zinenko LogOpLowering, 28726e59cc1SAlex Zinenko PowFOpLowering, 28826e59cc1SAlex Zinenko RsqrtOpLowering, 28926e59cc1SAlex Zinenko SinOpLowering, 290*a0fc94abSlorenzo chelini SqrtOpLowering, 291*a0fc94abSlorenzo chelini RoundOpLowering 29226e59cc1SAlex Zinenko >(converter); 29326e59cc1SAlex Zinenko // clang-format on 29426e59cc1SAlex Zinenko } 29526e59cc1SAlex Zinenko 29626e59cc1SAlex Zinenko std::unique_ptr<Pass> mlir::createConvertMathToLLVMPass() { 29726e59cc1SAlex Zinenko return std::make_unique<ConvertMathToLLVMPass>(); 29826e59cc1SAlex Zinenko } 299