126e59cc1SAlex Zinenko //===- MathToLLVM.cpp - Math to LLVM dialect conversion -------------------===// 226e59cc1SAlex Zinenko // 326e59cc1SAlex Zinenko // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. 426e59cc1SAlex Zinenko // See https://llvm.org/LICENSE.txt for license information. 526e59cc1SAlex Zinenko // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception 626e59cc1SAlex Zinenko // 726e59cc1SAlex Zinenko //===----------------------------------------------------------------------===// 826e59cc1SAlex Zinenko 926e59cc1SAlex Zinenko #include "mlir/Conversion/MathToLLVM/MathToLLVM.h" 1026e59cc1SAlex Zinenko #include "../PassDetail.h" 1126e59cc1SAlex Zinenko #include "mlir/Conversion/LLVMCommon/ConversionTarget.h" 1226e59cc1SAlex Zinenko #include "mlir/Conversion/LLVMCommon/Pattern.h" 1326e59cc1SAlex Zinenko #include "mlir/Conversion/LLVMCommon/VectorPattern.h" 1426e59cc1SAlex Zinenko #include "mlir/Dialect/LLVMIR/LLVMDialect.h" 1526e59cc1SAlex Zinenko #include "mlir/Dialect/Math/IR/Math.h" 1626e59cc1SAlex Zinenko #include "mlir/IR/TypeUtilities.h" 1726e59cc1SAlex Zinenko 1826e59cc1SAlex Zinenko using namespace mlir; 1926e59cc1SAlex Zinenko 2026e59cc1SAlex Zinenko namespace { 21a54f4eaeSMogball using AbsOpLowering = VectorConvertToLLVMPattern<math::AbsOp, LLVM::FAbsOp>; 22a54f4eaeSMogball using CeilOpLowering = VectorConvertToLLVMPattern<math::CeilOp, LLVM::FCeilOp>; 23a54f4eaeSMogball using CopySignOpLowering = 24a54f4eaeSMogball VectorConvertToLLVMPattern<math::CopySignOp, LLVM::CopySignOp>; 2526e59cc1SAlex Zinenko using CosOpLowering = VectorConvertToLLVMPattern<math::CosOp, LLVM::CosOp>; 2626e59cc1SAlex Zinenko using ExpOpLowering = VectorConvertToLLVMPattern<math::ExpOp, LLVM::ExpOp>; 2726e59cc1SAlex Zinenko using Exp2OpLowering = VectorConvertToLLVMPattern<math::Exp2Op, LLVM::Exp2Op>; 28a54f4eaeSMogball using FloorOpLowering = 29a54f4eaeSMogball VectorConvertToLLVMPattern<math::FloorOp, LLVM::FFloorOp>; 30a54f4eaeSMogball using FmaOpLowering = VectorConvertToLLVMPattern<math::FmaOp, LLVM::FMAOp>; 3126e59cc1SAlex Zinenko using Log10OpLowering = 3226e59cc1SAlex Zinenko VectorConvertToLLVMPattern<math::Log10Op, LLVM::Log10Op>; 3326e59cc1SAlex Zinenko using Log2OpLowering = VectorConvertToLLVMPattern<math::Log2Op, LLVM::Log2Op>; 3426e59cc1SAlex Zinenko using LogOpLowering = VectorConvertToLLVMPattern<math::LogOp, LLVM::LogOp>; 3526e59cc1SAlex Zinenko using PowFOpLowering = VectorConvertToLLVMPattern<math::PowFOp, LLVM::PowOp>; 3626e59cc1SAlex Zinenko using SinOpLowering = VectorConvertToLLVMPattern<math::SinOp, LLVM::SinOp>; 3726e59cc1SAlex Zinenko using SqrtOpLowering = VectorConvertToLLVMPattern<math::SqrtOp, LLVM::SqrtOp>; 3826e59cc1SAlex Zinenko 3926e59cc1SAlex Zinenko // A `expm1` is converted into `exp - 1`. 4026e59cc1SAlex Zinenko struct ExpM1OpLowering : public ConvertOpToLLVMPattern<math::ExpM1Op> { 4126e59cc1SAlex Zinenko using ConvertOpToLLVMPattern<math::ExpM1Op>::ConvertOpToLLVMPattern; 4226e59cc1SAlex Zinenko 4326e59cc1SAlex Zinenko LogicalResult 44ef976337SRiver Riddle matchAndRewrite(math::ExpM1Op op, OpAdaptor adaptor, 4526e59cc1SAlex Zinenko ConversionPatternRewriter &rewriter) const override { 46*62fea88bSJacques Pienaar auto operandType = adaptor.getOperand().getType(); 4726e59cc1SAlex Zinenko 4826e59cc1SAlex Zinenko if (!operandType || !LLVM::isCompatibleType(operandType)) 4926e59cc1SAlex Zinenko return failure(); 5026e59cc1SAlex Zinenko 5126e59cc1SAlex Zinenko auto loc = op.getLoc(); 5226e59cc1SAlex Zinenko auto resultType = op.getResult().getType(); 5326e59cc1SAlex Zinenko auto floatType = getElementTypeOrSelf(resultType).cast<FloatType>(); 5426e59cc1SAlex Zinenko auto floatOne = rewriter.getFloatAttr(floatType, 1.0); 5526e59cc1SAlex Zinenko 5626e59cc1SAlex Zinenko if (!operandType.isa<LLVM::LLVMArrayType>()) { 5726e59cc1SAlex Zinenko LLVM::ConstantOp one; 5826e59cc1SAlex Zinenko if (LLVM::isCompatibleVectorType(operandType)) { 5926e59cc1SAlex Zinenko one = rewriter.create<LLVM::ConstantOp>( 6026e59cc1SAlex Zinenko loc, operandType, 6126e59cc1SAlex Zinenko SplatElementsAttr::get(resultType.cast<ShapedType>(), floatOne)); 6226e59cc1SAlex Zinenko } else { 6326e59cc1SAlex Zinenko one = rewriter.create<LLVM::ConstantOp>(loc, operandType, floatOne); 6426e59cc1SAlex Zinenko } 65*62fea88bSJacques Pienaar auto exp = rewriter.create<LLVM::ExpOp>(loc, adaptor.getOperand()); 6626e59cc1SAlex Zinenko rewriter.replaceOpWithNewOp<LLVM::FSubOp>(op, operandType, exp, one); 6726e59cc1SAlex Zinenko return success(); 6826e59cc1SAlex Zinenko } 6926e59cc1SAlex Zinenko 7026e59cc1SAlex Zinenko auto vectorType = resultType.dyn_cast<VectorType>(); 7126e59cc1SAlex Zinenko if (!vectorType) 7226e59cc1SAlex Zinenko return rewriter.notifyMatchFailure(op, "expected vector result type"); 7326e59cc1SAlex Zinenko 7426e59cc1SAlex Zinenko return LLVM::detail::handleMultidimensionalVectors( 75ef976337SRiver Riddle op.getOperation(), adaptor.getOperands(), *getTypeConverter(), 7626e59cc1SAlex Zinenko [&](Type llvm1DVectorTy, ValueRange operands) { 7726e59cc1SAlex Zinenko auto splatAttr = SplatElementsAttr::get( 7826e59cc1SAlex Zinenko mlir::VectorType::get( 7926e59cc1SAlex Zinenko {LLVM::getVectorNumElements(llvm1DVectorTy).getFixedValue()}, 8026e59cc1SAlex Zinenko floatType), 8126e59cc1SAlex Zinenko floatOne); 8226e59cc1SAlex Zinenko auto one = 8326e59cc1SAlex Zinenko rewriter.create<LLVM::ConstantOp>(loc, llvm1DVectorTy, splatAttr); 8426e59cc1SAlex Zinenko auto exp = 8526e59cc1SAlex Zinenko rewriter.create<LLVM::ExpOp>(loc, llvm1DVectorTy, operands[0]); 8626e59cc1SAlex Zinenko return rewriter.create<LLVM::FSubOp>(loc, llvm1DVectorTy, exp, one); 8726e59cc1SAlex Zinenko }, 8826e59cc1SAlex Zinenko rewriter); 8926e59cc1SAlex Zinenko } 9026e59cc1SAlex Zinenko }; 9126e59cc1SAlex Zinenko 9226e59cc1SAlex Zinenko // A `log1p` is converted into `log(1 + ...)`. 9326e59cc1SAlex Zinenko struct Log1pOpLowering : public ConvertOpToLLVMPattern<math::Log1pOp> { 9426e59cc1SAlex Zinenko using ConvertOpToLLVMPattern<math::Log1pOp>::ConvertOpToLLVMPattern; 9526e59cc1SAlex Zinenko 9626e59cc1SAlex Zinenko LogicalResult 97ef976337SRiver Riddle matchAndRewrite(math::Log1pOp op, OpAdaptor adaptor, 9826e59cc1SAlex Zinenko ConversionPatternRewriter &rewriter) const override { 99*62fea88bSJacques Pienaar auto operandType = adaptor.getOperand().getType(); 10026e59cc1SAlex Zinenko 10126e59cc1SAlex Zinenko if (!operandType || !LLVM::isCompatibleType(operandType)) 10226e59cc1SAlex Zinenko return rewriter.notifyMatchFailure(op, "unsupported operand type"); 10326e59cc1SAlex Zinenko 10426e59cc1SAlex Zinenko auto loc = op.getLoc(); 10526e59cc1SAlex Zinenko auto resultType = op.getResult().getType(); 10626e59cc1SAlex Zinenko auto floatType = getElementTypeOrSelf(resultType).cast<FloatType>(); 10726e59cc1SAlex Zinenko auto floatOne = rewriter.getFloatAttr(floatType, 1.0); 10826e59cc1SAlex Zinenko 10926e59cc1SAlex Zinenko if (!operandType.isa<LLVM::LLVMArrayType>()) { 11026e59cc1SAlex Zinenko LLVM::ConstantOp one = 11126e59cc1SAlex Zinenko LLVM::isCompatibleVectorType(operandType) 11226e59cc1SAlex Zinenko ? rewriter.create<LLVM::ConstantOp>( 11326e59cc1SAlex Zinenko loc, operandType, 11426e59cc1SAlex Zinenko SplatElementsAttr::get(resultType.cast<ShapedType>(), 11526e59cc1SAlex Zinenko floatOne)) 11626e59cc1SAlex Zinenko : rewriter.create<LLVM::ConstantOp>(loc, operandType, floatOne); 11726e59cc1SAlex Zinenko 11826e59cc1SAlex Zinenko auto add = rewriter.create<LLVM::FAddOp>(loc, operandType, one, 119*62fea88bSJacques Pienaar adaptor.getOperand()); 12026e59cc1SAlex Zinenko rewriter.replaceOpWithNewOp<LLVM::LogOp>(op, operandType, add); 12126e59cc1SAlex Zinenko return success(); 12226e59cc1SAlex Zinenko } 12326e59cc1SAlex Zinenko 12426e59cc1SAlex Zinenko auto vectorType = resultType.dyn_cast<VectorType>(); 12526e59cc1SAlex Zinenko if (!vectorType) 12626e59cc1SAlex Zinenko return rewriter.notifyMatchFailure(op, "expected vector result type"); 12726e59cc1SAlex Zinenko 12826e59cc1SAlex Zinenko return LLVM::detail::handleMultidimensionalVectors( 129ef976337SRiver Riddle op.getOperation(), adaptor.getOperands(), *getTypeConverter(), 13026e59cc1SAlex Zinenko [&](Type llvm1DVectorTy, ValueRange operands) { 13126e59cc1SAlex Zinenko auto splatAttr = SplatElementsAttr::get( 13226e59cc1SAlex Zinenko mlir::VectorType::get( 13326e59cc1SAlex Zinenko {LLVM::getVectorNumElements(llvm1DVectorTy).getFixedValue()}, 13426e59cc1SAlex Zinenko floatType), 13526e59cc1SAlex Zinenko floatOne); 13626e59cc1SAlex Zinenko auto one = 13726e59cc1SAlex Zinenko rewriter.create<LLVM::ConstantOp>(loc, llvm1DVectorTy, splatAttr); 13826e59cc1SAlex Zinenko auto add = rewriter.create<LLVM::FAddOp>(loc, llvm1DVectorTy, one, 13926e59cc1SAlex Zinenko operands[0]); 14026e59cc1SAlex Zinenko return rewriter.create<LLVM::LogOp>(loc, llvm1DVectorTy, add); 14126e59cc1SAlex Zinenko }, 14226e59cc1SAlex Zinenko rewriter); 14326e59cc1SAlex Zinenko } 14426e59cc1SAlex Zinenko }; 14526e59cc1SAlex Zinenko 14626e59cc1SAlex Zinenko // A `rsqrt` is converted into `1 / sqrt`. 14726e59cc1SAlex Zinenko struct RsqrtOpLowering : public ConvertOpToLLVMPattern<math::RsqrtOp> { 14826e59cc1SAlex Zinenko using ConvertOpToLLVMPattern<math::RsqrtOp>::ConvertOpToLLVMPattern; 14926e59cc1SAlex Zinenko 15026e59cc1SAlex Zinenko LogicalResult 151ef976337SRiver Riddle matchAndRewrite(math::RsqrtOp op, OpAdaptor adaptor, 15226e59cc1SAlex Zinenko ConversionPatternRewriter &rewriter) const override { 153*62fea88bSJacques Pienaar auto operandType = adaptor.getOperand().getType(); 15426e59cc1SAlex Zinenko 15526e59cc1SAlex Zinenko if (!operandType || !LLVM::isCompatibleType(operandType)) 15626e59cc1SAlex Zinenko return failure(); 15726e59cc1SAlex Zinenko 15826e59cc1SAlex Zinenko auto loc = op.getLoc(); 15926e59cc1SAlex Zinenko auto resultType = op.getResult().getType(); 16026e59cc1SAlex Zinenko auto floatType = getElementTypeOrSelf(resultType).cast<FloatType>(); 16126e59cc1SAlex Zinenko auto floatOne = rewriter.getFloatAttr(floatType, 1.0); 16226e59cc1SAlex Zinenko 16326e59cc1SAlex Zinenko if (!operandType.isa<LLVM::LLVMArrayType>()) { 16426e59cc1SAlex Zinenko LLVM::ConstantOp one; 16526e59cc1SAlex Zinenko if (LLVM::isCompatibleVectorType(operandType)) { 16626e59cc1SAlex Zinenko one = rewriter.create<LLVM::ConstantOp>( 16726e59cc1SAlex Zinenko loc, operandType, 16826e59cc1SAlex Zinenko SplatElementsAttr::get(resultType.cast<ShapedType>(), floatOne)); 16926e59cc1SAlex Zinenko } else { 17026e59cc1SAlex Zinenko one = rewriter.create<LLVM::ConstantOp>(loc, operandType, floatOne); 17126e59cc1SAlex Zinenko } 172*62fea88bSJacques Pienaar auto sqrt = rewriter.create<LLVM::SqrtOp>(loc, adaptor.getOperand()); 17326e59cc1SAlex Zinenko rewriter.replaceOpWithNewOp<LLVM::FDivOp>(op, operandType, one, sqrt); 17426e59cc1SAlex Zinenko return success(); 17526e59cc1SAlex Zinenko } 17626e59cc1SAlex Zinenko 17726e59cc1SAlex Zinenko auto vectorType = resultType.dyn_cast<VectorType>(); 17826e59cc1SAlex Zinenko if (!vectorType) 17926e59cc1SAlex Zinenko return failure(); 18026e59cc1SAlex Zinenko 18126e59cc1SAlex Zinenko return LLVM::detail::handleMultidimensionalVectors( 182ef976337SRiver Riddle op.getOperation(), adaptor.getOperands(), *getTypeConverter(), 18326e59cc1SAlex Zinenko [&](Type llvm1DVectorTy, ValueRange operands) { 18426e59cc1SAlex Zinenko auto splatAttr = SplatElementsAttr::get( 18526e59cc1SAlex Zinenko mlir::VectorType::get( 18626e59cc1SAlex Zinenko {LLVM::getVectorNumElements(llvm1DVectorTy).getFixedValue()}, 18726e59cc1SAlex Zinenko floatType), 18826e59cc1SAlex Zinenko floatOne); 18926e59cc1SAlex Zinenko auto one = 19026e59cc1SAlex Zinenko rewriter.create<LLVM::ConstantOp>(loc, llvm1DVectorTy, splatAttr); 19126e59cc1SAlex Zinenko auto sqrt = 19226e59cc1SAlex Zinenko rewriter.create<LLVM::SqrtOp>(loc, llvm1DVectorTy, operands[0]); 19326e59cc1SAlex Zinenko return rewriter.create<LLVM::FDivOp>(loc, llvm1DVectorTy, one, sqrt); 19426e59cc1SAlex Zinenko }, 19526e59cc1SAlex Zinenko rewriter); 19626e59cc1SAlex Zinenko } 19726e59cc1SAlex Zinenko }; 19826e59cc1SAlex Zinenko 19926e59cc1SAlex Zinenko struct ConvertMathToLLVMPass 20026e59cc1SAlex Zinenko : public ConvertMathToLLVMBase<ConvertMathToLLVMPass> { 20126e59cc1SAlex Zinenko ConvertMathToLLVMPass() = default; 20226e59cc1SAlex Zinenko 20326e59cc1SAlex Zinenko void runOnFunction() override { 20426e59cc1SAlex Zinenko RewritePatternSet patterns(&getContext()); 20526e59cc1SAlex Zinenko LLVMTypeConverter converter(&getContext()); 20626e59cc1SAlex Zinenko populateMathToLLVMConversionPatterns(converter, patterns); 20726e59cc1SAlex Zinenko LLVMConversionTarget target(getContext()); 20826e59cc1SAlex Zinenko if (failed( 20926e59cc1SAlex Zinenko applyPartialConversion(getFunction(), target, std::move(patterns)))) 21026e59cc1SAlex Zinenko signalPassFailure(); 21126e59cc1SAlex Zinenko } 21226e59cc1SAlex Zinenko }; 21326e59cc1SAlex Zinenko } // namespace 21426e59cc1SAlex Zinenko 21526e59cc1SAlex Zinenko void mlir::populateMathToLLVMConversionPatterns(LLVMTypeConverter &converter, 21626e59cc1SAlex Zinenko RewritePatternSet &patterns) { 21726e59cc1SAlex Zinenko // clang-format off 21826e59cc1SAlex Zinenko patterns.add< 219a54f4eaeSMogball AbsOpLowering, 220a54f4eaeSMogball CeilOpLowering, 221a54f4eaeSMogball CopySignOpLowering, 22226e59cc1SAlex Zinenko CosOpLowering, 22326e59cc1SAlex Zinenko ExpOpLowering, 22426e59cc1SAlex Zinenko Exp2OpLowering, 22526e59cc1SAlex Zinenko ExpM1OpLowering, 226a54f4eaeSMogball FloorOpLowering, 227a54f4eaeSMogball FmaOpLowering, 22826e59cc1SAlex Zinenko Log10OpLowering, 22926e59cc1SAlex Zinenko Log1pOpLowering, 23026e59cc1SAlex Zinenko Log2OpLowering, 23126e59cc1SAlex Zinenko LogOpLowering, 23226e59cc1SAlex Zinenko PowFOpLowering, 23326e59cc1SAlex Zinenko RsqrtOpLowering, 23426e59cc1SAlex Zinenko SinOpLowering, 23526e59cc1SAlex Zinenko SqrtOpLowering 23626e59cc1SAlex Zinenko >(converter); 23726e59cc1SAlex Zinenko // clang-format on 23826e59cc1SAlex Zinenko } 23926e59cc1SAlex Zinenko 24026e59cc1SAlex Zinenko std::unique_ptr<Pass> mlir::createConvertMathToLLVMPass() { 24126e59cc1SAlex Zinenko return std::make_unique<ConvertMathToLLVMPass>(); 24226e59cc1SAlex Zinenko } 243