126e59cc1SAlex Zinenko //===- MathToLLVM.cpp - Math to LLVM dialect conversion -------------------===// 226e59cc1SAlex Zinenko // 326e59cc1SAlex Zinenko // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. 426e59cc1SAlex Zinenko // See https://llvm.org/LICENSE.txt for license information. 526e59cc1SAlex Zinenko // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception 626e59cc1SAlex Zinenko // 726e59cc1SAlex Zinenko //===----------------------------------------------------------------------===// 826e59cc1SAlex Zinenko 926e59cc1SAlex Zinenko #include "mlir/Conversion/MathToLLVM/MathToLLVM.h" 1026e59cc1SAlex Zinenko #include "../PassDetail.h" 1126e59cc1SAlex Zinenko #include "mlir/Conversion/LLVMCommon/ConversionTarget.h" 1226e59cc1SAlex Zinenko #include "mlir/Conversion/LLVMCommon/Pattern.h" 1326e59cc1SAlex Zinenko #include "mlir/Conversion/LLVMCommon/VectorPattern.h" 1426e59cc1SAlex Zinenko #include "mlir/Dialect/LLVMIR/LLVMDialect.h" 1526e59cc1SAlex Zinenko #include "mlir/Dialect/Math/IR/Math.h" 1626e59cc1SAlex Zinenko #include "mlir/IR/TypeUtilities.h" 1726e59cc1SAlex Zinenko 1826e59cc1SAlex Zinenko using namespace mlir; 1926e59cc1SAlex Zinenko 2026e59cc1SAlex Zinenko namespace { 21*a54f4eaeSMogball using AbsOpLowering = VectorConvertToLLVMPattern<math::AbsOp, LLVM::FAbsOp>; 22*a54f4eaeSMogball using CeilOpLowering = VectorConvertToLLVMPattern<math::CeilOp, LLVM::FCeilOp>; 23*a54f4eaeSMogball using CopySignOpLowering = 24*a54f4eaeSMogball VectorConvertToLLVMPattern<math::CopySignOp, LLVM::CopySignOp>; 2526e59cc1SAlex Zinenko using CosOpLowering = VectorConvertToLLVMPattern<math::CosOp, LLVM::CosOp>; 2626e59cc1SAlex Zinenko using ExpOpLowering = VectorConvertToLLVMPattern<math::ExpOp, LLVM::ExpOp>; 2726e59cc1SAlex Zinenko using Exp2OpLowering = VectorConvertToLLVMPattern<math::Exp2Op, LLVM::Exp2Op>; 28*a54f4eaeSMogball using FloorOpLowering = 29*a54f4eaeSMogball VectorConvertToLLVMPattern<math::FloorOp, LLVM::FFloorOp>; 30*a54f4eaeSMogball using FmaOpLowering = VectorConvertToLLVMPattern<math::FmaOp, LLVM::FMAOp>; 3126e59cc1SAlex Zinenko using Log10OpLowering = 3226e59cc1SAlex Zinenko VectorConvertToLLVMPattern<math::Log10Op, LLVM::Log10Op>; 3326e59cc1SAlex Zinenko using Log2OpLowering = VectorConvertToLLVMPattern<math::Log2Op, LLVM::Log2Op>; 3426e59cc1SAlex Zinenko using LogOpLowering = VectorConvertToLLVMPattern<math::LogOp, LLVM::LogOp>; 3526e59cc1SAlex Zinenko using PowFOpLowering = VectorConvertToLLVMPattern<math::PowFOp, LLVM::PowOp>; 3626e59cc1SAlex Zinenko using SinOpLowering = VectorConvertToLLVMPattern<math::SinOp, LLVM::SinOp>; 3726e59cc1SAlex Zinenko using SqrtOpLowering = VectorConvertToLLVMPattern<math::SqrtOp, LLVM::SqrtOp>; 3826e59cc1SAlex Zinenko 3926e59cc1SAlex Zinenko // A `expm1` is converted into `exp - 1`. 4026e59cc1SAlex Zinenko struct ExpM1OpLowering : public ConvertOpToLLVMPattern<math::ExpM1Op> { 4126e59cc1SAlex Zinenko using ConvertOpToLLVMPattern<math::ExpM1Op>::ConvertOpToLLVMPattern; 4226e59cc1SAlex Zinenko 4326e59cc1SAlex Zinenko LogicalResult 44ef976337SRiver Riddle matchAndRewrite(math::ExpM1Op op, OpAdaptor adaptor, 4526e59cc1SAlex Zinenko ConversionPatternRewriter &rewriter) const override { 46ef976337SRiver Riddle auto operandType = adaptor.operand().getType(); 4726e59cc1SAlex Zinenko 4826e59cc1SAlex Zinenko if (!operandType || !LLVM::isCompatibleType(operandType)) 4926e59cc1SAlex Zinenko return failure(); 5026e59cc1SAlex Zinenko 5126e59cc1SAlex Zinenko auto loc = op.getLoc(); 5226e59cc1SAlex Zinenko auto resultType = op.getResult().getType(); 5326e59cc1SAlex Zinenko auto floatType = getElementTypeOrSelf(resultType).cast<FloatType>(); 5426e59cc1SAlex Zinenko auto floatOne = rewriter.getFloatAttr(floatType, 1.0); 5526e59cc1SAlex Zinenko 5626e59cc1SAlex Zinenko if (!operandType.isa<LLVM::LLVMArrayType>()) { 5726e59cc1SAlex Zinenko LLVM::ConstantOp one; 5826e59cc1SAlex Zinenko if (LLVM::isCompatibleVectorType(operandType)) { 5926e59cc1SAlex Zinenko one = rewriter.create<LLVM::ConstantOp>( 6026e59cc1SAlex Zinenko loc, operandType, 6126e59cc1SAlex Zinenko SplatElementsAttr::get(resultType.cast<ShapedType>(), floatOne)); 6226e59cc1SAlex Zinenko } else { 6326e59cc1SAlex Zinenko one = rewriter.create<LLVM::ConstantOp>(loc, operandType, floatOne); 6426e59cc1SAlex Zinenko } 65ef976337SRiver Riddle auto exp = rewriter.create<LLVM::ExpOp>(loc, adaptor.operand()); 6626e59cc1SAlex Zinenko rewriter.replaceOpWithNewOp<LLVM::FSubOp>(op, operandType, exp, one); 6726e59cc1SAlex Zinenko return success(); 6826e59cc1SAlex Zinenko } 6926e59cc1SAlex Zinenko 7026e59cc1SAlex Zinenko auto vectorType = resultType.dyn_cast<VectorType>(); 7126e59cc1SAlex Zinenko if (!vectorType) 7226e59cc1SAlex Zinenko return rewriter.notifyMatchFailure(op, "expected vector result type"); 7326e59cc1SAlex Zinenko 7426e59cc1SAlex Zinenko return LLVM::detail::handleMultidimensionalVectors( 75ef976337SRiver Riddle op.getOperation(), adaptor.getOperands(), *getTypeConverter(), 7626e59cc1SAlex Zinenko [&](Type llvm1DVectorTy, ValueRange operands) { 7726e59cc1SAlex Zinenko auto splatAttr = SplatElementsAttr::get( 7826e59cc1SAlex Zinenko mlir::VectorType::get( 7926e59cc1SAlex Zinenko {LLVM::getVectorNumElements(llvm1DVectorTy).getFixedValue()}, 8026e59cc1SAlex Zinenko floatType), 8126e59cc1SAlex Zinenko floatOne); 8226e59cc1SAlex Zinenko auto one = 8326e59cc1SAlex Zinenko rewriter.create<LLVM::ConstantOp>(loc, llvm1DVectorTy, splatAttr); 8426e59cc1SAlex Zinenko auto exp = 8526e59cc1SAlex Zinenko rewriter.create<LLVM::ExpOp>(loc, llvm1DVectorTy, operands[0]); 8626e59cc1SAlex Zinenko return rewriter.create<LLVM::FSubOp>(loc, llvm1DVectorTy, exp, one); 8726e59cc1SAlex Zinenko }, 8826e59cc1SAlex Zinenko rewriter); 8926e59cc1SAlex Zinenko } 9026e59cc1SAlex Zinenko }; 9126e59cc1SAlex Zinenko 9226e59cc1SAlex Zinenko // A `log1p` is converted into `log(1 + ...)`. 9326e59cc1SAlex Zinenko struct Log1pOpLowering : public ConvertOpToLLVMPattern<math::Log1pOp> { 9426e59cc1SAlex Zinenko using ConvertOpToLLVMPattern<math::Log1pOp>::ConvertOpToLLVMPattern; 9526e59cc1SAlex Zinenko 9626e59cc1SAlex Zinenko LogicalResult 97ef976337SRiver Riddle matchAndRewrite(math::Log1pOp op, OpAdaptor adaptor, 9826e59cc1SAlex Zinenko ConversionPatternRewriter &rewriter) const override { 99ef976337SRiver Riddle auto operandType = adaptor.operand().getType(); 10026e59cc1SAlex Zinenko 10126e59cc1SAlex Zinenko if (!operandType || !LLVM::isCompatibleType(operandType)) 10226e59cc1SAlex Zinenko return rewriter.notifyMatchFailure(op, "unsupported operand type"); 10326e59cc1SAlex Zinenko 10426e59cc1SAlex Zinenko auto loc = op.getLoc(); 10526e59cc1SAlex Zinenko auto resultType = op.getResult().getType(); 10626e59cc1SAlex Zinenko auto floatType = getElementTypeOrSelf(resultType).cast<FloatType>(); 10726e59cc1SAlex Zinenko auto floatOne = rewriter.getFloatAttr(floatType, 1.0); 10826e59cc1SAlex Zinenko 10926e59cc1SAlex Zinenko if (!operandType.isa<LLVM::LLVMArrayType>()) { 11026e59cc1SAlex Zinenko LLVM::ConstantOp one = 11126e59cc1SAlex Zinenko LLVM::isCompatibleVectorType(operandType) 11226e59cc1SAlex Zinenko ? rewriter.create<LLVM::ConstantOp>( 11326e59cc1SAlex Zinenko loc, operandType, 11426e59cc1SAlex Zinenko SplatElementsAttr::get(resultType.cast<ShapedType>(), 11526e59cc1SAlex Zinenko floatOne)) 11626e59cc1SAlex Zinenko : rewriter.create<LLVM::ConstantOp>(loc, operandType, floatOne); 11726e59cc1SAlex Zinenko 11826e59cc1SAlex Zinenko auto add = rewriter.create<LLVM::FAddOp>(loc, operandType, one, 119ef976337SRiver Riddle adaptor.operand()); 12026e59cc1SAlex Zinenko rewriter.replaceOpWithNewOp<LLVM::LogOp>(op, operandType, add); 12126e59cc1SAlex Zinenko return success(); 12226e59cc1SAlex Zinenko } 12326e59cc1SAlex Zinenko 12426e59cc1SAlex Zinenko auto vectorType = resultType.dyn_cast<VectorType>(); 12526e59cc1SAlex Zinenko if (!vectorType) 12626e59cc1SAlex Zinenko return rewriter.notifyMatchFailure(op, "expected vector result type"); 12726e59cc1SAlex Zinenko 12826e59cc1SAlex Zinenko return LLVM::detail::handleMultidimensionalVectors( 129ef976337SRiver Riddle op.getOperation(), adaptor.getOperands(), *getTypeConverter(), 13026e59cc1SAlex Zinenko [&](Type llvm1DVectorTy, ValueRange operands) { 13126e59cc1SAlex Zinenko auto splatAttr = SplatElementsAttr::get( 13226e59cc1SAlex Zinenko mlir::VectorType::get( 13326e59cc1SAlex Zinenko {LLVM::getVectorNumElements(llvm1DVectorTy).getFixedValue()}, 13426e59cc1SAlex Zinenko floatType), 13526e59cc1SAlex Zinenko floatOne); 13626e59cc1SAlex Zinenko auto one = 13726e59cc1SAlex Zinenko rewriter.create<LLVM::ConstantOp>(loc, llvm1DVectorTy, splatAttr); 13826e59cc1SAlex Zinenko auto add = rewriter.create<LLVM::FAddOp>(loc, llvm1DVectorTy, one, 13926e59cc1SAlex Zinenko operands[0]); 14026e59cc1SAlex Zinenko return rewriter.create<LLVM::LogOp>(loc, llvm1DVectorTy, add); 14126e59cc1SAlex Zinenko }, 14226e59cc1SAlex Zinenko rewriter); 14326e59cc1SAlex Zinenko } 14426e59cc1SAlex Zinenko }; 14526e59cc1SAlex Zinenko 14626e59cc1SAlex Zinenko // A `rsqrt` is converted into `1 / sqrt`. 14726e59cc1SAlex Zinenko struct RsqrtOpLowering : public ConvertOpToLLVMPattern<math::RsqrtOp> { 14826e59cc1SAlex Zinenko using ConvertOpToLLVMPattern<math::RsqrtOp>::ConvertOpToLLVMPattern; 14926e59cc1SAlex Zinenko 15026e59cc1SAlex Zinenko LogicalResult 151ef976337SRiver Riddle matchAndRewrite(math::RsqrtOp op, OpAdaptor adaptor, 15226e59cc1SAlex Zinenko ConversionPatternRewriter &rewriter) const override { 153ef976337SRiver Riddle auto operandType = adaptor.operand().getType(); 15426e59cc1SAlex Zinenko 15526e59cc1SAlex Zinenko if (!operandType || !LLVM::isCompatibleType(operandType)) 15626e59cc1SAlex Zinenko return failure(); 15726e59cc1SAlex Zinenko 15826e59cc1SAlex Zinenko auto loc = op.getLoc(); 15926e59cc1SAlex Zinenko auto resultType = op.getResult().getType(); 16026e59cc1SAlex Zinenko auto floatType = getElementTypeOrSelf(resultType).cast<FloatType>(); 16126e59cc1SAlex Zinenko auto floatOne = rewriter.getFloatAttr(floatType, 1.0); 16226e59cc1SAlex Zinenko 16326e59cc1SAlex Zinenko if (!operandType.isa<LLVM::LLVMArrayType>()) { 16426e59cc1SAlex Zinenko LLVM::ConstantOp one; 16526e59cc1SAlex Zinenko if (LLVM::isCompatibleVectorType(operandType)) { 16626e59cc1SAlex Zinenko one = rewriter.create<LLVM::ConstantOp>( 16726e59cc1SAlex Zinenko loc, operandType, 16826e59cc1SAlex Zinenko SplatElementsAttr::get(resultType.cast<ShapedType>(), floatOne)); 16926e59cc1SAlex Zinenko } else { 17026e59cc1SAlex Zinenko one = rewriter.create<LLVM::ConstantOp>(loc, operandType, floatOne); 17126e59cc1SAlex Zinenko } 172ef976337SRiver Riddle auto sqrt = rewriter.create<LLVM::SqrtOp>(loc, adaptor.operand()); 17326e59cc1SAlex Zinenko rewriter.replaceOpWithNewOp<LLVM::FDivOp>(op, operandType, one, sqrt); 17426e59cc1SAlex Zinenko return success(); 17526e59cc1SAlex Zinenko } 17626e59cc1SAlex Zinenko 17726e59cc1SAlex Zinenko auto vectorType = resultType.dyn_cast<VectorType>(); 17826e59cc1SAlex Zinenko if (!vectorType) 17926e59cc1SAlex Zinenko return failure(); 18026e59cc1SAlex Zinenko 18126e59cc1SAlex Zinenko return LLVM::detail::handleMultidimensionalVectors( 182ef976337SRiver Riddle op.getOperation(), adaptor.getOperands(), *getTypeConverter(), 18326e59cc1SAlex Zinenko [&](Type llvm1DVectorTy, ValueRange operands) { 18426e59cc1SAlex Zinenko auto splatAttr = SplatElementsAttr::get( 18526e59cc1SAlex Zinenko mlir::VectorType::get( 18626e59cc1SAlex Zinenko {LLVM::getVectorNumElements(llvm1DVectorTy).getFixedValue()}, 18726e59cc1SAlex Zinenko floatType), 18826e59cc1SAlex Zinenko floatOne); 18926e59cc1SAlex Zinenko auto one = 19026e59cc1SAlex Zinenko rewriter.create<LLVM::ConstantOp>(loc, llvm1DVectorTy, splatAttr); 19126e59cc1SAlex Zinenko auto sqrt = 19226e59cc1SAlex Zinenko rewriter.create<LLVM::SqrtOp>(loc, llvm1DVectorTy, operands[0]); 19326e59cc1SAlex Zinenko return rewriter.create<LLVM::FDivOp>(loc, llvm1DVectorTy, one, sqrt); 19426e59cc1SAlex Zinenko }, 19526e59cc1SAlex Zinenko rewriter); 19626e59cc1SAlex Zinenko } 19726e59cc1SAlex Zinenko }; 19826e59cc1SAlex Zinenko 19926e59cc1SAlex Zinenko struct ConvertMathToLLVMPass 20026e59cc1SAlex Zinenko : public ConvertMathToLLVMBase<ConvertMathToLLVMPass> { 20126e59cc1SAlex Zinenko ConvertMathToLLVMPass() = default; 20226e59cc1SAlex Zinenko 20326e59cc1SAlex Zinenko void runOnFunction() override { 20426e59cc1SAlex Zinenko RewritePatternSet patterns(&getContext()); 20526e59cc1SAlex Zinenko LLVMTypeConverter converter(&getContext()); 20626e59cc1SAlex Zinenko populateMathToLLVMConversionPatterns(converter, patterns); 20726e59cc1SAlex Zinenko LLVMConversionTarget target(getContext()); 20826e59cc1SAlex Zinenko if (failed( 20926e59cc1SAlex Zinenko applyPartialConversion(getFunction(), target, std::move(patterns)))) 21026e59cc1SAlex Zinenko signalPassFailure(); 21126e59cc1SAlex Zinenko } 21226e59cc1SAlex Zinenko }; 21326e59cc1SAlex Zinenko } // namespace 21426e59cc1SAlex Zinenko 21526e59cc1SAlex Zinenko void mlir::populateMathToLLVMConversionPatterns(LLVMTypeConverter &converter, 21626e59cc1SAlex Zinenko RewritePatternSet &patterns) { 21726e59cc1SAlex Zinenko // clang-format off 21826e59cc1SAlex Zinenko patterns.add< 219*a54f4eaeSMogball AbsOpLowering, 220*a54f4eaeSMogball CeilOpLowering, 221*a54f4eaeSMogball CopySignOpLowering, 22226e59cc1SAlex Zinenko CosOpLowering, 22326e59cc1SAlex Zinenko ExpOpLowering, 22426e59cc1SAlex Zinenko Exp2OpLowering, 22526e59cc1SAlex Zinenko ExpM1OpLowering, 226*a54f4eaeSMogball FloorOpLowering, 227*a54f4eaeSMogball FmaOpLowering, 22826e59cc1SAlex Zinenko Log10OpLowering, 22926e59cc1SAlex Zinenko Log1pOpLowering, 23026e59cc1SAlex Zinenko Log2OpLowering, 23126e59cc1SAlex Zinenko LogOpLowering, 23226e59cc1SAlex Zinenko PowFOpLowering, 23326e59cc1SAlex Zinenko RsqrtOpLowering, 23426e59cc1SAlex Zinenko SinOpLowering, 23526e59cc1SAlex Zinenko SqrtOpLowering 23626e59cc1SAlex Zinenko >(converter); 23726e59cc1SAlex Zinenko // clang-format on 23826e59cc1SAlex Zinenko } 23926e59cc1SAlex Zinenko 24026e59cc1SAlex Zinenko std::unique_ptr<Pass> mlir::createConvertMathToLLVMPass() { 24126e59cc1SAlex Zinenko return std::make_unique<ConvertMathToLLVMPass>(); 24226e59cc1SAlex Zinenko } 243