126e59cc1SAlex Zinenko //===- MathToLLVM.cpp - Math to LLVM dialect conversion -------------------===// 226e59cc1SAlex Zinenko // 326e59cc1SAlex Zinenko // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. 426e59cc1SAlex Zinenko // See https://llvm.org/LICENSE.txt for license information. 526e59cc1SAlex Zinenko // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception 626e59cc1SAlex Zinenko // 726e59cc1SAlex Zinenko //===----------------------------------------------------------------------===// 826e59cc1SAlex Zinenko 926e59cc1SAlex Zinenko #include "mlir/Conversion/MathToLLVM/MathToLLVM.h" 1026e59cc1SAlex Zinenko #include "../PassDetail.h" 1126e59cc1SAlex Zinenko #include "mlir/Conversion/LLVMCommon/ConversionTarget.h" 1226e59cc1SAlex Zinenko #include "mlir/Conversion/LLVMCommon/Pattern.h" 1326e59cc1SAlex Zinenko #include "mlir/Conversion/LLVMCommon/VectorPattern.h" 1426e59cc1SAlex Zinenko #include "mlir/Dialect/LLVMIR/LLVMDialect.h" 1526e59cc1SAlex Zinenko #include "mlir/Dialect/Math/IR/Math.h" 1626e59cc1SAlex Zinenko #include "mlir/IR/TypeUtilities.h" 1726e59cc1SAlex Zinenko 1826e59cc1SAlex Zinenko using namespace mlir; 1926e59cc1SAlex Zinenko 2026e59cc1SAlex Zinenko namespace { 21a54f4eaeSMogball using AbsOpLowering = VectorConvertToLLVMPattern<math::AbsOp, LLVM::FAbsOp>; 22a54f4eaeSMogball using CeilOpLowering = VectorConvertToLLVMPattern<math::CeilOp, LLVM::FCeilOp>; 23a54f4eaeSMogball using CopySignOpLowering = 24a54f4eaeSMogball VectorConvertToLLVMPattern<math::CopySignOp, LLVM::CopySignOp>; 2526e59cc1SAlex Zinenko using CosOpLowering = VectorConvertToLLVMPattern<math::CosOp, LLVM::CosOp>; 26*c5fef77bSRob Suderman using CtPopFOpLowering = 27*c5fef77bSRob Suderman VectorConvertToLLVMPattern<math::CtPopOp, LLVM::CtPopOp>; 2826e59cc1SAlex Zinenko using ExpOpLowering = VectorConvertToLLVMPattern<math::ExpOp, LLVM::ExpOp>; 2926e59cc1SAlex Zinenko using Exp2OpLowering = VectorConvertToLLVMPattern<math::Exp2Op, LLVM::Exp2Op>; 30a54f4eaeSMogball using FloorOpLowering = 31a54f4eaeSMogball VectorConvertToLLVMPattern<math::FloorOp, LLVM::FFloorOp>; 32a54f4eaeSMogball using FmaOpLowering = VectorConvertToLLVMPattern<math::FmaOp, LLVM::FMAOp>; 3326e59cc1SAlex Zinenko using Log10OpLowering = 3426e59cc1SAlex Zinenko VectorConvertToLLVMPattern<math::Log10Op, LLVM::Log10Op>; 3526e59cc1SAlex Zinenko using Log2OpLowering = VectorConvertToLLVMPattern<math::Log2Op, LLVM::Log2Op>; 3626e59cc1SAlex Zinenko using LogOpLowering = VectorConvertToLLVMPattern<math::LogOp, LLVM::LogOp>; 3726e59cc1SAlex Zinenko using PowFOpLowering = VectorConvertToLLVMPattern<math::PowFOp, LLVM::PowOp>; 3826e59cc1SAlex Zinenko using SinOpLowering = VectorConvertToLLVMPattern<math::SinOp, LLVM::SinOp>; 3926e59cc1SAlex Zinenko using SqrtOpLowering = VectorConvertToLLVMPattern<math::SqrtOp, LLVM::SqrtOp>; 4026e59cc1SAlex Zinenko 4126e59cc1SAlex Zinenko // A `expm1` is converted into `exp - 1`. 4226e59cc1SAlex Zinenko struct ExpM1OpLowering : public ConvertOpToLLVMPattern<math::ExpM1Op> { 4326e59cc1SAlex Zinenko using ConvertOpToLLVMPattern<math::ExpM1Op>::ConvertOpToLLVMPattern; 4426e59cc1SAlex Zinenko 4526e59cc1SAlex Zinenko LogicalResult 46ef976337SRiver Riddle matchAndRewrite(math::ExpM1Op op, OpAdaptor adaptor, 4726e59cc1SAlex Zinenko ConversionPatternRewriter &rewriter) const override { 4862fea88bSJacques Pienaar auto operandType = adaptor.getOperand().getType(); 4926e59cc1SAlex Zinenko 5026e59cc1SAlex Zinenko if (!operandType || !LLVM::isCompatibleType(operandType)) 5126e59cc1SAlex Zinenko return failure(); 5226e59cc1SAlex Zinenko 5326e59cc1SAlex Zinenko auto loc = op.getLoc(); 5426e59cc1SAlex Zinenko auto resultType = op.getResult().getType(); 5526e59cc1SAlex Zinenko auto floatType = getElementTypeOrSelf(resultType).cast<FloatType>(); 5626e59cc1SAlex Zinenko auto floatOne = rewriter.getFloatAttr(floatType, 1.0); 5726e59cc1SAlex Zinenko 5826e59cc1SAlex Zinenko if (!operandType.isa<LLVM::LLVMArrayType>()) { 5926e59cc1SAlex Zinenko LLVM::ConstantOp one; 6026e59cc1SAlex Zinenko if (LLVM::isCompatibleVectorType(operandType)) { 6126e59cc1SAlex Zinenko one = rewriter.create<LLVM::ConstantOp>( 6226e59cc1SAlex Zinenko loc, operandType, 6326e59cc1SAlex Zinenko SplatElementsAttr::get(resultType.cast<ShapedType>(), floatOne)); 6426e59cc1SAlex Zinenko } else { 6526e59cc1SAlex Zinenko one = rewriter.create<LLVM::ConstantOp>(loc, operandType, floatOne); 6626e59cc1SAlex Zinenko } 6762fea88bSJacques Pienaar auto exp = rewriter.create<LLVM::ExpOp>(loc, adaptor.getOperand()); 6826e59cc1SAlex Zinenko rewriter.replaceOpWithNewOp<LLVM::FSubOp>(op, operandType, exp, one); 6926e59cc1SAlex Zinenko return success(); 7026e59cc1SAlex Zinenko } 7126e59cc1SAlex Zinenko 7226e59cc1SAlex Zinenko auto vectorType = resultType.dyn_cast<VectorType>(); 7326e59cc1SAlex Zinenko if (!vectorType) 7426e59cc1SAlex Zinenko return rewriter.notifyMatchFailure(op, "expected vector result type"); 7526e59cc1SAlex Zinenko 7626e59cc1SAlex Zinenko return LLVM::detail::handleMultidimensionalVectors( 77ef976337SRiver Riddle op.getOperation(), adaptor.getOperands(), *getTypeConverter(), 7826e59cc1SAlex Zinenko [&](Type llvm1DVectorTy, ValueRange operands) { 7926e59cc1SAlex Zinenko auto splatAttr = SplatElementsAttr::get( 8026e59cc1SAlex Zinenko mlir::VectorType::get( 8126e59cc1SAlex Zinenko {LLVM::getVectorNumElements(llvm1DVectorTy).getFixedValue()}, 8226e59cc1SAlex Zinenko floatType), 8326e59cc1SAlex Zinenko floatOne); 8426e59cc1SAlex Zinenko auto one = 8526e59cc1SAlex Zinenko rewriter.create<LLVM::ConstantOp>(loc, llvm1DVectorTy, splatAttr); 8626e59cc1SAlex Zinenko auto exp = 8726e59cc1SAlex Zinenko rewriter.create<LLVM::ExpOp>(loc, llvm1DVectorTy, operands[0]); 8826e59cc1SAlex Zinenko return rewriter.create<LLVM::FSubOp>(loc, llvm1DVectorTy, exp, one); 8926e59cc1SAlex Zinenko }, 9026e59cc1SAlex Zinenko rewriter); 9126e59cc1SAlex Zinenko } 9226e59cc1SAlex Zinenko }; 9326e59cc1SAlex Zinenko 9426e59cc1SAlex Zinenko // A `log1p` is converted into `log(1 + ...)`. 9526e59cc1SAlex Zinenko struct Log1pOpLowering : public ConvertOpToLLVMPattern<math::Log1pOp> { 9626e59cc1SAlex Zinenko using ConvertOpToLLVMPattern<math::Log1pOp>::ConvertOpToLLVMPattern; 9726e59cc1SAlex Zinenko 9826e59cc1SAlex Zinenko LogicalResult 99ef976337SRiver Riddle matchAndRewrite(math::Log1pOp op, OpAdaptor adaptor, 10026e59cc1SAlex Zinenko ConversionPatternRewriter &rewriter) const override { 10162fea88bSJacques Pienaar auto operandType = adaptor.getOperand().getType(); 10226e59cc1SAlex Zinenko 10326e59cc1SAlex Zinenko if (!operandType || !LLVM::isCompatibleType(operandType)) 10426e59cc1SAlex Zinenko return rewriter.notifyMatchFailure(op, "unsupported operand type"); 10526e59cc1SAlex Zinenko 10626e59cc1SAlex Zinenko auto loc = op.getLoc(); 10726e59cc1SAlex Zinenko auto resultType = op.getResult().getType(); 10826e59cc1SAlex Zinenko auto floatType = getElementTypeOrSelf(resultType).cast<FloatType>(); 10926e59cc1SAlex Zinenko auto floatOne = rewriter.getFloatAttr(floatType, 1.0); 11026e59cc1SAlex Zinenko 11126e59cc1SAlex Zinenko if (!operandType.isa<LLVM::LLVMArrayType>()) { 11226e59cc1SAlex Zinenko LLVM::ConstantOp one = 11326e59cc1SAlex Zinenko LLVM::isCompatibleVectorType(operandType) 11426e59cc1SAlex Zinenko ? rewriter.create<LLVM::ConstantOp>( 11526e59cc1SAlex Zinenko loc, operandType, 11626e59cc1SAlex Zinenko SplatElementsAttr::get(resultType.cast<ShapedType>(), 11726e59cc1SAlex Zinenko floatOne)) 11826e59cc1SAlex Zinenko : rewriter.create<LLVM::ConstantOp>(loc, operandType, floatOne); 11926e59cc1SAlex Zinenko 12026e59cc1SAlex Zinenko auto add = rewriter.create<LLVM::FAddOp>(loc, operandType, one, 12162fea88bSJacques Pienaar adaptor.getOperand()); 12226e59cc1SAlex Zinenko rewriter.replaceOpWithNewOp<LLVM::LogOp>(op, operandType, add); 12326e59cc1SAlex Zinenko return success(); 12426e59cc1SAlex Zinenko } 12526e59cc1SAlex Zinenko 12626e59cc1SAlex Zinenko auto vectorType = resultType.dyn_cast<VectorType>(); 12726e59cc1SAlex Zinenko if (!vectorType) 12826e59cc1SAlex Zinenko return rewriter.notifyMatchFailure(op, "expected vector result type"); 12926e59cc1SAlex Zinenko 13026e59cc1SAlex Zinenko return LLVM::detail::handleMultidimensionalVectors( 131ef976337SRiver Riddle op.getOperation(), adaptor.getOperands(), *getTypeConverter(), 13226e59cc1SAlex Zinenko [&](Type llvm1DVectorTy, ValueRange operands) { 13326e59cc1SAlex Zinenko auto splatAttr = SplatElementsAttr::get( 13426e59cc1SAlex Zinenko mlir::VectorType::get( 13526e59cc1SAlex Zinenko {LLVM::getVectorNumElements(llvm1DVectorTy).getFixedValue()}, 13626e59cc1SAlex Zinenko floatType), 13726e59cc1SAlex Zinenko floatOne); 13826e59cc1SAlex Zinenko auto one = 13926e59cc1SAlex Zinenko rewriter.create<LLVM::ConstantOp>(loc, llvm1DVectorTy, splatAttr); 14026e59cc1SAlex Zinenko auto add = rewriter.create<LLVM::FAddOp>(loc, llvm1DVectorTy, one, 14126e59cc1SAlex Zinenko operands[0]); 14226e59cc1SAlex Zinenko return rewriter.create<LLVM::LogOp>(loc, llvm1DVectorTy, add); 14326e59cc1SAlex Zinenko }, 14426e59cc1SAlex Zinenko rewriter); 14526e59cc1SAlex Zinenko } 14626e59cc1SAlex Zinenko }; 14726e59cc1SAlex Zinenko 14826e59cc1SAlex Zinenko // A `rsqrt` is converted into `1 / sqrt`. 14926e59cc1SAlex Zinenko struct RsqrtOpLowering : public ConvertOpToLLVMPattern<math::RsqrtOp> { 15026e59cc1SAlex Zinenko using ConvertOpToLLVMPattern<math::RsqrtOp>::ConvertOpToLLVMPattern; 15126e59cc1SAlex Zinenko 15226e59cc1SAlex Zinenko LogicalResult 153ef976337SRiver Riddle matchAndRewrite(math::RsqrtOp op, OpAdaptor adaptor, 15426e59cc1SAlex Zinenko ConversionPatternRewriter &rewriter) const override { 15562fea88bSJacques Pienaar auto operandType = adaptor.getOperand().getType(); 15626e59cc1SAlex Zinenko 15726e59cc1SAlex Zinenko if (!operandType || !LLVM::isCompatibleType(operandType)) 15826e59cc1SAlex Zinenko return failure(); 15926e59cc1SAlex Zinenko 16026e59cc1SAlex Zinenko auto loc = op.getLoc(); 16126e59cc1SAlex Zinenko auto resultType = op.getResult().getType(); 16226e59cc1SAlex Zinenko auto floatType = getElementTypeOrSelf(resultType).cast<FloatType>(); 16326e59cc1SAlex Zinenko auto floatOne = rewriter.getFloatAttr(floatType, 1.0); 16426e59cc1SAlex Zinenko 16526e59cc1SAlex Zinenko if (!operandType.isa<LLVM::LLVMArrayType>()) { 16626e59cc1SAlex Zinenko LLVM::ConstantOp one; 16726e59cc1SAlex Zinenko if (LLVM::isCompatibleVectorType(operandType)) { 16826e59cc1SAlex Zinenko one = rewriter.create<LLVM::ConstantOp>( 16926e59cc1SAlex Zinenko loc, operandType, 17026e59cc1SAlex Zinenko SplatElementsAttr::get(resultType.cast<ShapedType>(), floatOne)); 17126e59cc1SAlex Zinenko } else { 17226e59cc1SAlex Zinenko one = rewriter.create<LLVM::ConstantOp>(loc, operandType, floatOne); 17326e59cc1SAlex Zinenko } 17462fea88bSJacques Pienaar auto sqrt = rewriter.create<LLVM::SqrtOp>(loc, adaptor.getOperand()); 17526e59cc1SAlex Zinenko rewriter.replaceOpWithNewOp<LLVM::FDivOp>(op, operandType, one, sqrt); 17626e59cc1SAlex Zinenko return success(); 17726e59cc1SAlex Zinenko } 17826e59cc1SAlex Zinenko 17926e59cc1SAlex Zinenko auto vectorType = resultType.dyn_cast<VectorType>(); 18026e59cc1SAlex Zinenko if (!vectorType) 18126e59cc1SAlex Zinenko return failure(); 18226e59cc1SAlex Zinenko 18326e59cc1SAlex Zinenko return LLVM::detail::handleMultidimensionalVectors( 184ef976337SRiver Riddle op.getOperation(), adaptor.getOperands(), *getTypeConverter(), 18526e59cc1SAlex Zinenko [&](Type llvm1DVectorTy, ValueRange operands) { 18626e59cc1SAlex Zinenko auto splatAttr = SplatElementsAttr::get( 18726e59cc1SAlex Zinenko mlir::VectorType::get( 18826e59cc1SAlex Zinenko {LLVM::getVectorNumElements(llvm1DVectorTy).getFixedValue()}, 18926e59cc1SAlex Zinenko floatType), 19026e59cc1SAlex Zinenko floatOne); 19126e59cc1SAlex Zinenko auto one = 19226e59cc1SAlex Zinenko rewriter.create<LLVM::ConstantOp>(loc, llvm1DVectorTy, splatAttr); 19326e59cc1SAlex Zinenko auto sqrt = 19426e59cc1SAlex Zinenko rewriter.create<LLVM::SqrtOp>(loc, llvm1DVectorTy, operands[0]); 19526e59cc1SAlex Zinenko return rewriter.create<LLVM::FDivOp>(loc, llvm1DVectorTy, one, sqrt); 19626e59cc1SAlex Zinenko }, 19726e59cc1SAlex Zinenko rewriter); 19826e59cc1SAlex Zinenko } 19926e59cc1SAlex Zinenko }; 20026e59cc1SAlex Zinenko 20126e59cc1SAlex Zinenko struct ConvertMathToLLVMPass 20226e59cc1SAlex Zinenko : public ConvertMathToLLVMBase<ConvertMathToLLVMPass> { 20326e59cc1SAlex Zinenko ConvertMathToLLVMPass() = default; 20426e59cc1SAlex Zinenko 20526e59cc1SAlex Zinenko void runOnFunction() override { 20626e59cc1SAlex Zinenko RewritePatternSet patterns(&getContext()); 20726e59cc1SAlex Zinenko LLVMTypeConverter converter(&getContext()); 20826e59cc1SAlex Zinenko populateMathToLLVMConversionPatterns(converter, patterns); 20926e59cc1SAlex Zinenko LLVMConversionTarget target(getContext()); 21026e59cc1SAlex Zinenko if (failed( 21126e59cc1SAlex Zinenko applyPartialConversion(getFunction(), target, std::move(patterns)))) 21226e59cc1SAlex Zinenko signalPassFailure(); 21326e59cc1SAlex Zinenko } 21426e59cc1SAlex Zinenko }; 21526e59cc1SAlex Zinenko } // namespace 21626e59cc1SAlex Zinenko 21726e59cc1SAlex Zinenko void mlir::populateMathToLLVMConversionPatterns(LLVMTypeConverter &converter, 21826e59cc1SAlex Zinenko RewritePatternSet &patterns) { 21926e59cc1SAlex Zinenko // clang-format off 22026e59cc1SAlex Zinenko patterns.add< 221a54f4eaeSMogball AbsOpLowering, 222a54f4eaeSMogball CeilOpLowering, 223a54f4eaeSMogball CopySignOpLowering, 22426e59cc1SAlex Zinenko CosOpLowering, 225*c5fef77bSRob Suderman CtPopFOpLowering, 22626e59cc1SAlex Zinenko ExpOpLowering, 22726e59cc1SAlex Zinenko Exp2OpLowering, 22826e59cc1SAlex Zinenko ExpM1OpLowering, 229a54f4eaeSMogball FloorOpLowering, 230a54f4eaeSMogball FmaOpLowering, 23126e59cc1SAlex Zinenko Log10OpLowering, 23226e59cc1SAlex Zinenko Log1pOpLowering, 23326e59cc1SAlex Zinenko Log2OpLowering, 23426e59cc1SAlex Zinenko LogOpLowering, 23526e59cc1SAlex Zinenko PowFOpLowering, 23626e59cc1SAlex Zinenko RsqrtOpLowering, 23726e59cc1SAlex Zinenko SinOpLowering, 23826e59cc1SAlex Zinenko SqrtOpLowering 23926e59cc1SAlex Zinenko >(converter); 24026e59cc1SAlex Zinenko // clang-format on 24126e59cc1SAlex Zinenko } 24226e59cc1SAlex Zinenko 24326e59cc1SAlex Zinenko std::unique_ptr<Pass> mlir::createConvertMathToLLVMPass() { 24426e59cc1SAlex Zinenko return std::make_unique<ConvertMathToLLVMPass>(); 24526e59cc1SAlex Zinenko } 246