126e59cc1SAlex Zinenko //===- MathToLLVM.cpp - Math to LLVM dialect conversion -------------------===// 226e59cc1SAlex Zinenko // 326e59cc1SAlex Zinenko // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. 426e59cc1SAlex Zinenko // See https://llvm.org/LICENSE.txt for license information. 526e59cc1SAlex Zinenko // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception 626e59cc1SAlex Zinenko // 726e59cc1SAlex Zinenko //===----------------------------------------------------------------------===// 826e59cc1SAlex Zinenko 926e59cc1SAlex Zinenko #include "mlir/Conversion/MathToLLVM/MathToLLVM.h" 1026e59cc1SAlex Zinenko #include "../PassDetail.h" 1126e59cc1SAlex Zinenko #include "mlir/Conversion/LLVMCommon/ConversionTarget.h" 1226e59cc1SAlex Zinenko #include "mlir/Conversion/LLVMCommon/Pattern.h" 1326e59cc1SAlex Zinenko #include "mlir/Conversion/LLVMCommon/VectorPattern.h" 1426e59cc1SAlex Zinenko #include "mlir/Dialect/LLVMIR/LLVMDialect.h" 1526e59cc1SAlex Zinenko #include "mlir/Dialect/Math/IR/Math.h" 1626e59cc1SAlex Zinenko #include "mlir/IR/TypeUtilities.h" 1726e59cc1SAlex Zinenko 1826e59cc1SAlex Zinenko using namespace mlir; 1926e59cc1SAlex Zinenko 2026e59cc1SAlex Zinenko namespace { 2126e59cc1SAlex Zinenko using CosOpLowering = VectorConvertToLLVMPattern<math::CosOp, LLVM::CosOp>; 2226e59cc1SAlex Zinenko using ExpOpLowering = VectorConvertToLLVMPattern<math::ExpOp, LLVM::ExpOp>; 2326e59cc1SAlex Zinenko using Exp2OpLowering = VectorConvertToLLVMPattern<math::Exp2Op, LLVM::Exp2Op>; 2426e59cc1SAlex Zinenko using Log10OpLowering = 2526e59cc1SAlex Zinenko VectorConvertToLLVMPattern<math::Log10Op, LLVM::Log10Op>; 2626e59cc1SAlex Zinenko using Log2OpLowering = VectorConvertToLLVMPattern<math::Log2Op, LLVM::Log2Op>; 2726e59cc1SAlex Zinenko using LogOpLowering = VectorConvertToLLVMPattern<math::LogOp, LLVM::LogOp>; 2826e59cc1SAlex Zinenko using PowFOpLowering = VectorConvertToLLVMPattern<math::PowFOp, LLVM::PowOp>; 2926e59cc1SAlex Zinenko using SinOpLowering = VectorConvertToLLVMPattern<math::SinOp, LLVM::SinOp>; 3026e59cc1SAlex Zinenko using SqrtOpLowering = VectorConvertToLLVMPattern<math::SqrtOp, LLVM::SqrtOp>; 3126e59cc1SAlex Zinenko 3226e59cc1SAlex Zinenko // A `expm1` is converted into `exp - 1`. 3326e59cc1SAlex Zinenko struct ExpM1OpLowering : public ConvertOpToLLVMPattern<math::ExpM1Op> { 3426e59cc1SAlex Zinenko using ConvertOpToLLVMPattern<math::ExpM1Op>::ConvertOpToLLVMPattern; 3526e59cc1SAlex Zinenko 3626e59cc1SAlex Zinenko LogicalResult 37*ef976337SRiver Riddle matchAndRewrite(math::ExpM1Op op, OpAdaptor adaptor, 3826e59cc1SAlex Zinenko ConversionPatternRewriter &rewriter) const override { 39*ef976337SRiver Riddle auto operandType = adaptor.operand().getType(); 4026e59cc1SAlex Zinenko 4126e59cc1SAlex Zinenko if (!operandType || !LLVM::isCompatibleType(operandType)) 4226e59cc1SAlex Zinenko return failure(); 4326e59cc1SAlex Zinenko 4426e59cc1SAlex Zinenko auto loc = op.getLoc(); 4526e59cc1SAlex Zinenko auto resultType = op.getResult().getType(); 4626e59cc1SAlex Zinenko auto floatType = getElementTypeOrSelf(resultType).cast<FloatType>(); 4726e59cc1SAlex Zinenko auto floatOne = rewriter.getFloatAttr(floatType, 1.0); 4826e59cc1SAlex Zinenko 4926e59cc1SAlex Zinenko if (!operandType.isa<LLVM::LLVMArrayType>()) { 5026e59cc1SAlex Zinenko LLVM::ConstantOp one; 5126e59cc1SAlex Zinenko if (LLVM::isCompatibleVectorType(operandType)) { 5226e59cc1SAlex Zinenko one = rewriter.create<LLVM::ConstantOp>( 5326e59cc1SAlex Zinenko loc, operandType, 5426e59cc1SAlex Zinenko SplatElementsAttr::get(resultType.cast<ShapedType>(), floatOne)); 5526e59cc1SAlex Zinenko } else { 5626e59cc1SAlex Zinenko one = rewriter.create<LLVM::ConstantOp>(loc, operandType, floatOne); 5726e59cc1SAlex Zinenko } 58*ef976337SRiver Riddle auto exp = rewriter.create<LLVM::ExpOp>(loc, adaptor.operand()); 5926e59cc1SAlex Zinenko rewriter.replaceOpWithNewOp<LLVM::FSubOp>(op, operandType, exp, one); 6026e59cc1SAlex Zinenko return success(); 6126e59cc1SAlex Zinenko } 6226e59cc1SAlex Zinenko 6326e59cc1SAlex Zinenko auto vectorType = resultType.dyn_cast<VectorType>(); 6426e59cc1SAlex Zinenko if (!vectorType) 6526e59cc1SAlex Zinenko return rewriter.notifyMatchFailure(op, "expected vector result type"); 6626e59cc1SAlex Zinenko 6726e59cc1SAlex Zinenko return LLVM::detail::handleMultidimensionalVectors( 68*ef976337SRiver Riddle op.getOperation(), adaptor.getOperands(), *getTypeConverter(), 6926e59cc1SAlex Zinenko [&](Type llvm1DVectorTy, ValueRange operands) { 7026e59cc1SAlex Zinenko auto splatAttr = SplatElementsAttr::get( 7126e59cc1SAlex Zinenko mlir::VectorType::get( 7226e59cc1SAlex Zinenko {LLVM::getVectorNumElements(llvm1DVectorTy).getFixedValue()}, 7326e59cc1SAlex Zinenko floatType), 7426e59cc1SAlex Zinenko floatOne); 7526e59cc1SAlex Zinenko auto one = 7626e59cc1SAlex Zinenko rewriter.create<LLVM::ConstantOp>(loc, llvm1DVectorTy, splatAttr); 7726e59cc1SAlex Zinenko auto exp = 7826e59cc1SAlex Zinenko rewriter.create<LLVM::ExpOp>(loc, llvm1DVectorTy, operands[0]); 7926e59cc1SAlex Zinenko return rewriter.create<LLVM::FSubOp>(loc, llvm1DVectorTy, exp, one); 8026e59cc1SAlex Zinenko }, 8126e59cc1SAlex Zinenko rewriter); 8226e59cc1SAlex Zinenko } 8326e59cc1SAlex Zinenko }; 8426e59cc1SAlex Zinenko 8526e59cc1SAlex Zinenko // A `log1p` is converted into `log(1 + ...)`. 8626e59cc1SAlex Zinenko struct Log1pOpLowering : public ConvertOpToLLVMPattern<math::Log1pOp> { 8726e59cc1SAlex Zinenko using ConvertOpToLLVMPattern<math::Log1pOp>::ConvertOpToLLVMPattern; 8826e59cc1SAlex Zinenko 8926e59cc1SAlex Zinenko LogicalResult 90*ef976337SRiver Riddle matchAndRewrite(math::Log1pOp op, OpAdaptor adaptor, 9126e59cc1SAlex Zinenko ConversionPatternRewriter &rewriter) const override { 92*ef976337SRiver Riddle auto operandType = adaptor.operand().getType(); 9326e59cc1SAlex Zinenko 9426e59cc1SAlex Zinenko if (!operandType || !LLVM::isCompatibleType(operandType)) 9526e59cc1SAlex Zinenko return rewriter.notifyMatchFailure(op, "unsupported operand type"); 9626e59cc1SAlex Zinenko 9726e59cc1SAlex Zinenko auto loc = op.getLoc(); 9826e59cc1SAlex Zinenko auto resultType = op.getResult().getType(); 9926e59cc1SAlex Zinenko auto floatType = getElementTypeOrSelf(resultType).cast<FloatType>(); 10026e59cc1SAlex Zinenko auto floatOne = rewriter.getFloatAttr(floatType, 1.0); 10126e59cc1SAlex Zinenko 10226e59cc1SAlex Zinenko if (!operandType.isa<LLVM::LLVMArrayType>()) { 10326e59cc1SAlex Zinenko LLVM::ConstantOp one = 10426e59cc1SAlex Zinenko LLVM::isCompatibleVectorType(operandType) 10526e59cc1SAlex Zinenko ? rewriter.create<LLVM::ConstantOp>( 10626e59cc1SAlex Zinenko loc, operandType, 10726e59cc1SAlex Zinenko SplatElementsAttr::get(resultType.cast<ShapedType>(), 10826e59cc1SAlex Zinenko floatOne)) 10926e59cc1SAlex Zinenko : rewriter.create<LLVM::ConstantOp>(loc, operandType, floatOne); 11026e59cc1SAlex Zinenko 11126e59cc1SAlex Zinenko auto add = rewriter.create<LLVM::FAddOp>(loc, operandType, one, 112*ef976337SRiver Riddle adaptor.operand()); 11326e59cc1SAlex Zinenko rewriter.replaceOpWithNewOp<LLVM::LogOp>(op, operandType, add); 11426e59cc1SAlex Zinenko return success(); 11526e59cc1SAlex Zinenko } 11626e59cc1SAlex Zinenko 11726e59cc1SAlex Zinenko auto vectorType = resultType.dyn_cast<VectorType>(); 11826e59cc1SAlex Zinenko if (!vectorType) 11926e59cc1SAlex Zinenko return rewriter.notifyMatchFailure(op, "expected vector result type"); 12026e59cc1SAlex Zinenko 12126e59cc1SAlex Zinenko return LLVM::detail::handleMultidimensionalVectors( 122*ef976337SRiver Riddle op.getOperation(), adaptor.getOperands(), *getTypeConverter(), 12326e59cc1SAlex Zinenko [&](Type llvm1DVectorTy, ValueRange operands) { 12426e59cc1SAlex Zinenko auto splatAttr = SplatElementsAttr::get( 12526e59cc1SAlex Zinenko mlir::VectorType::get( 12626e59cc1SAlex Zinenko {LLVM::getVectorNumElements(llvm1DVectorTy).getFixedValue()}, 12726e59cc1SAlex Zinenko floatType), 12826e59cc1SAlex Zinenko floatOne); 12926e59cc1SAlex Zinenko auto one = 13026e59cc1SAlex Zinenko rewriter.create<LLVM::ConstantOp>(loc, llvm1DVectorTy, splatAttr); 13126e59cc1SAlex Zinenko auto add = rewriter.create<LLVM::FAddOp>(loc, llvm1DVectorTy, one, 13226e59cc1SAlex Zinenko operands[0]); 13326e59cc1SAlex Zinenko return rewriter.create<LLVM::LogOp>(loc, llvm1DVectorTy, add); 13426e59cc1SAlex Zinenko }, 13526e59cc1SAlex Zinenko rewriter); 13626e59cc1SAlex Zinenko } 13726e59cc1SAlex Zinenko }; 13826e59cc1SAlex Zinenko 13926e59cc1SAlex Zinenko // A `rsqrt` is converted into `1 / sqrt`. 14026e59cc1SAlex Zinenko struct RsqrtOpLowering : public ConvertOpToLLVMPattern<math::RsqrtOp> { 14126e59cc1SAlex Zinenko using ConvertOpToLLVMPattern<math::RsqrtOp>::ConvertOpToLLVMPattern; 14226e59cc1SAlex Zinenko 14326e59cc1SAlex Zinenko LogicalResult 144*ef976337SRiver Riddle matchAndRewrite(math::RsqrtOp op, OpAdaptor adaptor, 14526e59cc1SAlex Zinenko ConversionPatternRewriter &rewriter) const override { 146*ef976337SRiver Riddle auto operandType = adaptor.operand().getType(); 14726e59cc1SAlex Zinenko 14826e59cc1SAlex Zinenko if (!operandType || !LLVM::isCompatibleType(operandType)) 14926e59cc1SAlex Zinenko return failure(); 15026e59cc1SAlex Zinenko 15126e59cc1SAlex Zinenko auto loc = op.getLoc(); 15226e59cc1SAlex Zinenko auto resultType = op.getResult().getType(); 15326e59cc1SAlex Zinenko auto floatType = getElementTypeOrSelf(resultType).cast<FloatType>(); 15426e59cc1SAlex Zinenko auto floatOne = rewriter.getFloatAttr(floatType, 1.0); 15526e59cc1SAlex Zinenko 15626e59cc1SAlex Zinenko if (!operandType.isa<LLVM::LLVMArrayType>()) { 15726e59cc1SAlex Zinenko LLVM::ConstantOp one; 15826e59cc1SAlex Zinenko if (LLVM::isCompatibleVectorType(operandType)) { 15926e59cc1SAlex Zinenko one = rewriter.create<LLVM::ConstantOp>( 16026e59cc1SAlex Zinenko loc, operandType, 16126e59cc1SAlex Zinenko SplatElementsAttr::get(resultType.cast<ShapedType>(), floatOne)); 16226e59cc1SAlex Zinenko } else { 16326e59cc1SAlex Zinenko one = rewriter.create<LLVM::ConstantOp>(loc, operandType, floatOne); 16426e59cc1SAlex Zinenko } 165*ef976337SRiver Riddle auto sqrt = rewriter.create<LLVM::SqrtOp>(loc, adaptor.operand()); 16626e59cc1SAlex Zinenko rewriter.replaceOpWithNewOp<LLVM::FDivOp>(op, operandType, one, sqrt); 16726e59cc1SAlex Zinenko return success(); 16826e59cc1SAlex Zinenko } 16926e59cc1SAlex Zinenko 17026e59cc1SAlex Zinenko auto vectorType = resultType.dyn_cast<VectorType>(); 17126e59cc1SAlex Zinenko if (!vectorType) 17226e59cc1SAlex Zinenko return failure(); 17326e59cc1SAlex Zinenko 17426e59cc1SAlex Zinenko return LLVM::detail::handleMultidimensionalVectors( 175*ef976337SRiver Riddle op.getOperation(), adaptor.getOperands(), *getTypeConverter(), 17626e59cc1SAlex Zinenko [&](Type llvm1DVectorTy, ValueRange operands) { 17726e59cc1SAlex Zinenko auto splatAttr = SplatElementsAttr::get( 17826e59cc1SAlex Zinenko mlir::VectorType::get( 17926e59cc1SAlex Zinenko {LLVM::getVectorNumElements(llvm1DVectorTy).getFixedValue()}, 18026e59cc1SAlex Zinenko floatType), 18126e59cc1SAlex Zinenko floatOne); 18226e59cc1SAlex Zinenko auto one = 18326e59cc1SAlex Zinenko rewriter.create<LLVM::ConstantOp>(loc, llvm1DVectorTy, splatAttr); 18426e59cc1SAlex Zinenko auto sqrt = 18526e59cc1SAlex Zinenko rewriter.create<LLVM::SqrtOp>(loc, llvm1DVectorTy, operands[0]); 18626e59cc1SAlex Zinenko return rewriter.create<LLVM::FDivOp>(loc, llvm1DVectorTy, one, sqrt); 18726e59cc1SAlex Zinenko }, 18826e59cc1SAlex Zinenko rewriter); 18926e59cc1SAlex Zinenko } 19026e59cc1SAlex Zinenko }; 19126e59cc1SAlex Zinenko 19226e59cc1SAlex Zinenko struct ConvertMathToLLVMPass 19326e59cc1SAlex Zinenko : public ConvertMathToLLVMBase<ConvertMathToLLVMPass> { 19426e59cc1SAlex Zinenko ConvertMathToLLVMPass() = default; 19526e59cc1SAlex Zinenko 19626e59cc1SAlex Zinenko void runOnFunction() override { 19726e59cc1SAlex Zinenko RewritePatternSet patterns(&getContext()); 19826e59cc1SAlex Zinenko LLVMTypeConverter converter(&getContext()); 19926e59cc1SAlex Zinenko populateMathToLLVMConversionPatterns(converter, patterns); 20026e59cc1SAlex Zinenko LLVMConversionTarget target(getContext()); 20126e59cc1SAlex Zinenko if (failed( 20226e59cc1SAlex Zinenko applyPartialConversion(getFunction(), target, std::move(patterns)))) 20326e59cc1SAlex Zinenko signalPassFailure(); 20426e59cc1SAlex Zinenko } 20526e59cc1SAlex Zinenko }; 20626e59cc1SAlex Zinenko } // namespace 20726e59cc1SAlex Zinenko 20826e59cc1SAlex Zinenko void mlir::populateMathToLLVMConversionPatterns(LLVMTypeConverter &converter, 20926e59cc1SAlex Zinenko RewritePatternSet &patterns) { 21026e59cc1SAlex Zinenko // clang-format off 21126e59cc1SAlex Zinenko patterns.add< 21226e59cc1SAlex Zinenko CosOpLowering, 21326e59cc1SAlex Zinenko ExpOpLowering, 21426e59cc1SAlex Zinenko Exp2OpLowering, 21526e59cc1SAlex Zinenko ExpM1OpLowering, 21626e59cc1SAlex Zinenko Log10OpLowering, 21726e59cc1SAlex Zinenko Log1pOpLowering, 21826e59cc1SAlex Zinenko Log2OpLowering, 21926e59cc1SAlex Zinenko LogOpLowering, 22026e59cc1SAlex Zinenko PowFOpLowering, 22126e59cc1SAlex Zinenko RsqrtOpLowering, 22226e59cc1SAlex Zinenko SinOpLowering, 22326e59cc1SAlex Zinenko SqrtOpLowering 22426e59cc1SAlex Zinenko >(converter); 22526e59cc1SAlex Zinenko // clang-format on 22626e59cc1SAlex Zinenko } 22726e59cc1SAlex Zinenko 22826e59cc1SAlex Zinenko std::unique_ptr<Pass> mlir::createConvertMathToLLVMPass() { 22926e59cc1SAlex Zinenko return std::make_unique<ConvertMathToLLVMPass>(); 23026e59cc1SAlex Zinenko } 231