1*26e59cc1SAlex Zinenko //===- MathToLLVM.cpp - Math to LLVM dialect conversion -------------------===// 2*26e59cc1SAlex Zinenko // 3*26e59cc1SAlex Zinenko // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. 4*26e59cc1SAlex Zinenko // See https://llvm.org/LICENSE.txt for license information. 5*26e59cc1SAlex Zinenko // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception 6*26e59cc1SAlex Zinenko // 7*26e59cc1SAlex Zinenko //===----------------------------------------------------------------------===// 8*26e59cc1SAlex Zinenko 9*26e59cc1SAlex Zinenko #include "mlir/Conversion/MathToLLVM/MathToLLVM.h" 10*26e59cc1SAlex Zinenko #include "../PassDetail.h" 11*26e59cc1SAlex Zinenko #include "mlir/Conversion/LLVMCommon/ConversionTarget.h" 12*26e59cc1SAlex Zinenko #include "mlir/Conversion/LLVMCommon/Pattern.h" 13*26e59cc1SAlex Zinenko #include "mlir/Conversion/LLVMCommon/VectorPattern.h" 14*26e59cc1SAlex Zinenko #include "mlir/Dialect/LLVMIR/LLVMDialect.h" 15*26e59cc1SAlex Zinenko #include "mlir/Dialect/Math/IR/Math.h" 16*26e59cc1SAlex Zinenko #include "mlir/IR/TypeUtilities.h" 17*26e59cc1SAlex Zinenko 18*26e59cc1SAlex Zinenko using namespace mlir; 19*26e59cc1SAlex Zinenko 20*26e59cc1SAlex Zinenko namespace { 21*26e59cc1SAlex Zinenko using CosOpLowering = VectorConvertToLLVMPattern<math::CosOp, LLVM::CosOp>; 22*26e59cc1SAlex Zinenko using ExpOpLowering = VectorConvertToLLVMPattern<math::ExpOp, LLVM::ExpOp>; 23*26e59cc1SAlex Zinenko using Exp2OpLowering = VectorConvertToLLVMPattern<math::Exp2Op, LLVM::Exp2Op>; 24*26e59cc1SAlex Zinenko using Log10OpLowering = 25*26e59cc1SAlex Zinenko VectorConvertToLLVMPattern<math::Log10Op, LLVM::Log10Op>; 26*26e59cc1SAlex Zinenko using Log2OpLowering = VectorConvertToLLVMPattern<math::Log2Op, LLVM::Log2Op>; 27*26e59cc1SAlex Zinenko using LogOpLowering = VectorConvertToLLVMPattern<math::LogOp, LLVM::LogOp>; 28*26e59cc1SAlex Zinenko using PowFOpLowering = VectorConvertToLLVMPattern<math::PowFOp, LLVM::PowOp>; 29*26e59cc1SAlex Zinenko using SinOpLowering = VectorConvertToLLVMPattern<math::SinOp, LLVM::SinOp>; 30*26e59cc1SAlex Zinenko using SqrtOpLowering = VectorConvertToLLVMPattern<math::SqrtOp, LLVM::SqrtOp>; 31*26e59cc1SAlex Zinenko 32*26e59cc1SAlex Zinenko // A `expm1` is converted into `exp - 1`. 33*26e59cc1SAlex Zinenko struct ExpM1OpLowering : public ConvertOpToLLVMPattern<math::ExpM1Op> { 34*26e59cc1SAlex Zinenko using ConvertOpToLLVMPattern<math::ExpM1Op>::ConvertOpToLLVMPattern; 35*26e59cc1SAlex Zinenko 36*26e59cc1SAlex Zinenko LogicalResult 37*26e59cc1SAlex Zinenko matchAndRewrite(math::ExpM1Op op, ArrayRef<Value> operands, 38*26e59cc1SAlex Zinenko ConversionPatternRewriter &rewriter) const override { 39*26e59cc1SAlex Zinenko math::ExpM1Op::Adaptor transformed(operands); 40*26e59cc1SAlex Zinenko auto operandType = transformed.operand().getType(); 41*26e59cc1SAlex Zinenko 42*26e59cc1SAlex Zinenko if (!operandType || !LLVM::isCompatibleType(operandType)) 43*26e59cc1SAlex Zinenko return failure(); 44*26e59cc1SAlex Zinenko 45*26e59cc1SAlex Zinenko auto loc = op.getLoc(); 46*26e59cc1SAlex Zinenko auto resultType = op.getResult().getType(); 47*26e59cc1SAlex Zinenko auto floatType = getElementTypeOrSelf(resultType).cast<FloatType>(); 48*26e59cc1SAlex Zinenko auto floatOne = rewriter.getFloatAttr(floatType, 1.0); 49*26e59cc1SAlex Zinenko 50*26e59cc1SAlex Zinenko if (!operandType.isa<LLVM::LLVMArrayType>()) { 51*26e59cc1SAlex Zinenko LLVM::ConstantOp one; 52*26e59cc1SAlex Zinenko if (LLVM::isCompatibleVectorType(operandType)) { 53*26e59cc1SAlex Zinenko one = rewriter.create<LLVM::ConstantOp>( 54*26e59cc1SAlex Zinenko loc, operandType, 55*26e59cc1SAlex Zinenko SplatElementsAttr::get(resultType.cast<ShapedType>(), floatOne)); 56*26e59cc1SAlex Zinenko } else { 57*26e59cc1SAlex Zinenko one = rewriter.create<LLVM::ConstantOp>(loc, operandType, floatOne); 58*26e59cc1SAlex Zinenko } 59*26e59cc1SAlex Zinenko auto exp = rewriter.create<LLVM::ExpOp>(loc, transformed.operand()); 60*26e59cc1SAlex Zinenko rewriter.replaceOpWithNewOp<LLVM::FSubOp>(op, operandType, exp, one); 61*26e59cc1SAlex Zinenko return success(); 62*26e59cc1SAlex Zinenko } 63*26e59cc1SAlex Zinenko 64*26e59cc1SAlex Zinenko auto vectorType = resultType.dyn_cast<VectorType>(); 65*26e59cc1SAlex Zinenko if (!vectorType) 66*26e59cc1SAlex Zinenko return rewriter.notifyMatchFailure(op, "expected vector result type"); 67*26e59cc1SAlex Zinenko 68*26e59cc1SAlex Zinenko return LLVM::detail::handleMultidimensionalVectors( 69*26e59cc1SAlex Zinenko op.getOperation(), operands, *getTypeConverter(), 70*26e59cc1SAlex Zinenko [&](Type llvm1DVectorTy, ValueRange operands) { 71*26e59cc1SAlex Zinenko auto splatAttr = SplatElementsAttr::get( 72*26e59cc1SAlex Zinenko mlir::VectorType::get( 73*26e59cc1SAlex Zinenko {LLVM::getVectorNumElements(llvm1DVectorTy).getFixedValue()}, 74*26e59cc1SAlex Zinenko floatType), 75*26e59cc1SAlex Zinenko floatOne); 76*26e59cc1SAlex Zinenko auto one = 77*26e59cc1SAlex Zinenko rewriter.create<LLVM::ConstantOp>(loc, llvm1DVectorTy, splatAttr); 78*26e59cc1SAlex Zinenko auto exp = 79*26e59cc1SAlex Zinenko rewriter.create<LLVM::ExpOp>(loc, llvm1DVectorTy, operands[0]); 80*26e59cc1SAlex Zinenko return rewriter.create<LLVM::FSubOp>(loc, llvm1DVectorTy, exp, one); 81*26e59cc1SAlex Zinenko }, 82*26e59cc1SAlex Zinenko rewriter); 83*26e59cc1SAlex Zinenko } 84*26e59cc1SAlex Zinenko }; 85*26e59cc1SAlex Zinenko 86*26e59cc1SAlex Zinenko // A `log1p` is converted into `log(1 + ...)`. 87*26e59cc1SAlex Zinenko struct Log1pOpLowering : public ConvertOpToLLVMPattern<math::Log1pOp> { 88*26e59cc1SAlex Zinenko using ConvertOpToLLVMPattern<math::Log1pOp>::ConvertOpToLLVMPattern; 89*26e59cc1SAlex Zinenko 90*26e59cc1SAlex Zinenko LogicalResult 91*26e59cc1SAlex Zinenko matchAndRewrite(math::Log1pOp op, ArrayRef<Value> operands, 92*26e59cc1SAlex Zinenko ConversionPatternRewriter &rewriter) const override { 93*26e59cc1SAlex Zinenko math::Log1pOp::Adaptor transformed(operands); 94*26e59cc1SAlex Zinenko auto operandType = transformed.operand().getType(); 95*26e59cc1SAlex Zinenko 96*26e59cc1SAlex Zinenko if (!operandType || !LLVM::isCompatibleType(operandType)) 97*26e59cc1SAlex Zinenko return rewriter.notifyMatchFailure(op, "unsupported operand type"); 98*26e59cc1SAlex Zinenko 99*26e59cc1SAlex Zinenko auto loc = op.getLoc(); 100*26e59cc1SAlex Zinenko auto resultType = op.getResult().getType(); 101*26e59cc1SAlex Zinenko auto floatType = getElementTypeOrSelf(resultType).cast<FloatType>(); 102*26e59cc1SAlex Zinenko auto floatOne = rewriter.getFloatAttr(floatType, 1.0); 103*26e59cc1SAlex Zinenko 104*26e59cc1SAlex Zinenko if (!operandType.isa<LLVM::LLVMArrayType>()) { 105*26e59cc1SAlex Zinenko LLVM::ConstantOp one = 106*26e59cc1SAlex Zinenko LLVM::isCompatibleVectorType(operandType) 107*26e59cc1SAlex Zinenko ? rewriter.create<LLVM::ConstantOp>( 108*26e59cc1SAlex Zinenko loc, operandType, 109*26e59cc1SAlex Zinenko SplatElementsAttr::get(resultType.cast<ShapedType>(), 110*26e59cc1SAlex Zinenko floatOne)) 111*26e59cc1SAlex Zinenko : rewriter.create<LLVM::ConstantOp>(loc, operandType, floatOne); 112*26e59cc1SAlex Zinenko 113*26e59cc1SAlex Zinenko auto add = rewriter.create<LLVM::FAddOp>(loc, operandType, one, 114*26e59cc1SAlex Zinenko transformed.operand()); 115*26e59cc1SAlex Zinenko rewriter.replaceOpWithNewOp<LLVM::LogOp>(op, operandType, add); 116*26e59cc1SAlex Zinenko return success(); 117*26e59cc1SAlex Zinenko } 118*26e59cc1SAlex Zinenko 119*26e59cc1SAlex Zinenko auto vectorType = resultType.dyn_cast<VectorType>(); 120*26e59cc1SAlex Zinenko if (!vectorType) 121*26e59cc1SAlex Zinenko return rewriter.notifyMatchFailure(op, "expected vector result type"); 122*26e59cc1SAlex Zinenko 123*26e59cc1SAlex Zinenko return LLVM::detail::handleMultidimensionalVectors( 124*26e59cc1SAlex Zinenko op.getOperation(), operands, *getTypeConverter(), 125*26e59cc1SAlex Zinenko [&](Type llvm1DVectorTy, ValueRange operands) { 126*26e59cc1SAlex Zinenko auto splatAttr = SplatElementsAttr::get( 127*26e59cc1SAlex Zinenko mlir::VectorType::get( 128*26e59cc1SAlex Zinenko {LLVM::getVectorNumElements(llvm1DVectorTy).getFixedValue()}, 129*26e59cc1SAlex Zinenko floatType), 130*26e59cc1SAlex Zinenko floatOne); 131*26e59cc1SAlex Zinenko auto one = 132*26e59cc1SAlex Zinenko rewriter.create<LLVM::ConstantOp>(loc, llvm1DVectorTy, splatAttr); 133*26e59cc1SAlex Zinenko auto add = rewriter.create<LLVM::FAddOp>(loc, llvm1DVectorTy, one, 134*26e59cc1SAlex Zinenko operands[0]); 135*26e59cc1SAlex Zinenko return rewriter.create<LLVM::LogOp>(loc, llvm1DVectorTy, add); 136*26e59cc1SAlex Zinenko }, 137*26e59cc1SAlex Zinenko rewriter); 138*26e59cc1SAlex Zinenko } 139*26e59cc1SAlex Zinenko }; 140*26e59cc1SAlex Zinenko 141*26e59cc1SAlex Zinenko // A `rsqrt` is converted into `1 / sqrt`. 142*26e59cc1SAlex Zinenko struct RsqrtOpLowering : public ConvertOpToLLVMPattern<math::RsqrtOp> { 143*26e59cc1SAlex Zinenko using ConvertOpToLLVMPattern<math::RsqrtOp>::ConvertOpToLLVMPattern; 144*26e59cc1SAlex Zinenko 145*26e59cc1SAlex Zinenko LogicalResult 146*26e59cc1SAlex Zinenko matchAndRewrite(math::RsqrtOp op, ArrayRef<Value> operands, 147*26e59cc1SAlex Zinenko ConversionPatternRewriter &rewriter) const override { 148*26e59cc1SAlex Zinenko math::RsqrtOp::Adaptor transformed(operands); 149*26e59cc1SAlex Zinenko auto operandType = transformed.operand().getType(); 150*26e59cc1SAlex Zinenko 151*26e59cc1SAlex Zinenko if (!operandType || !LLVM::isCompatibleType(operandType)) 152*26e59cc1SAlex Zinenko return failure(); 153*26e59cc1SAlex Zinenko 154*26e59cc1SAlex Zinenko auto loc = op.getLoc(); 155*26e59cc1SAlex Zinenko auto resultType = op.getResult().getType(); 156*26e59cc1SAlex Zinenko auto floatType = getElementTypeOrSelf(resultType).cast<FloatType>(); 157*26e59cc1SAlex Zinenko auto floatOne = rewriter.getFloatAttr(floatType, 1.0); 158*26e59cc1SAlex Zinenko 159*26e59cc1SAlex Zinenko if (!operandType.isa<LLVM::LLVMArrayType>()) { 160*26e59cc1SAlex Zinenko LLVM::ConstantOp one; 161*26e59cc1SAlex Zinenko if (LLVM::isCompatibleVectorType(operandType)) { 162*26e59cc1SAlex Zinenko one = rewriter.create<LLVM::ConstantOp>( 163*26e59cc1SAlex Zinenko loc, operandType, 164*26e59cc1SAlex Zinenko SplatElementsAttr::get(resultType.cast<ShapedType>(), floatOne)); 165*26e59cc1SAlex Zinenko } else { 166*26e59cc1SAlex Zinenko one = rewriter.create<LLVM::ConstantOp>(loc, operandType, floatOne); 167*26e59cc1SAlex Zinenko } 168*26e59cc1SAlex Zinenko auto sqrt = rewriter.create<LLVM::SqrtOp>(loc, transformed.operand()); 169*26e59cc1SAlex Zinenko rewriter.replaceOpWithNewOp<LLVM::FDivOp>(op, operandType, one, sqrt); 170*26e59cc1SAlex Zinenko return success(); 171*26e59cc1SAlex Zinenko } 172*26e59cc1SAlex Zinenko 173*26e59cc1SAlex Zinenko auto vectorType = resultType.dyn_cast<VectorType>(); 174*26e59cc1SAlex Zinenko if (!vectorType) 175*26e59cc1SAlex Zinenko return failure(); 176*26e59cc1SAlex Zinenko 177*26e59cc1SAlex Zinenko return LLVM::detail::handleMultidimensionalVectors( 178*26e59cc1SAlex Zinenko op.getOperation(), operands, *getTypeConverter(), 179*26e59cc1SAlex Zinenko [&](Type llvm1DVectorTy, ValueRange operands) { 180*26e59cc1SAlex Zinenko auto splatAttr = SplatElementsAttr::get( 181*26e59cc1SAlex Zinenko mlir::VectorType::get( 182*26e59cc1SAlex Zinenko {LLVM::getVectorNumElements(llvm1DVectorTy).getFixedValue()}, 183*26e59cc1SAlex Zinenko floatType), 184*26e59cc1SAlex Zinenko floatOne); 185*26e59cc1SAlex Zinenko auto one = 186*26e59cc1SAlex Zinenko rewriter.create<LLVM::ConstantOp>(loc, llvm1DVectorTy, splatAttr); 187*26e59cc1SAlex Zinenko auto sqrt = 188*26e59cc1SAlex Zinenko rewriter.create<LLVM::SqrtOp>(loc, llvm1DVectorTy, operands[0]); 189*26e59cc1SAlex Zinenko return rewriter.create<LLVM::FDivOp>(loc, llvm1DVectorTy, one, sqrt); 190*26e59cc1SAlex Zinenko }, 191*26e59cc1SAlex Zinenko rewriter); 192*26e59cc1SAlex Zinenko } 193*26e59cc1SAlex Zinenko }; 194*26e59cc1SAlex Zinenko 195*26e59cc1SAlex Zinenko struct ConvertMathToLLVMPass 196*26e59cc1SAlex Zinenko : public ConvertMathToLLVMBase<ConvertMathToLLVMPass> { 197*26e59cc1SAlex Zinenko ConvertMathToLLVMPass() = default; 198*26e59cc1SAlex Zinenko 199*26e59cc1SAlex Zinenko void runOnFunction() override { 200*26e59cc1SAlex Zinenko RewritePatternSet patterns(&getContext()); 201*26e59cc1SAlex Zinenko LLVMTypeConverter converter(&getContext()); 202*26e59cc1SAlex Zinenko populateMathToLLVMConversionPatterns(converter, patterns); 203*26e59cc1SAlex Zinenko LLVMConversionTarget target(getContext()); 204*26e59cc1SAlex Zinenko target.addLegalOp<LLVM::DialectCastOp>(); 205*26e59cc1SAlex Zinenko if (failed( 206*26e59cc1SAlex Zinenko applyPartialConversion(getFunction(), target, std::move(patterns)))) 207*26e59cc1SAlex Zinenko signalPassFailure(); 208*26e59cc1SAlex Zinenko } 209*26e59cc1SAlex Zinenko }; 210*26e59cc1SAlex Zinenko } // namespace 211*26e59cc1SAlex Zinenko 212*26e59cc1SAlex Zinenko void mlir::populateMathToLLVMConversionPatterns(LLVMTypeConverter &converter, 213*26e59cc1SAlex Zinenko RewritePatternSet &patterns) { 214*26e59cc1SAlex Zinenko // clang-format off 215*26e59cc1SAlex Zinenko patterns.add< 216*26e59cc1SAlex Zinenko CosOpLowering, 217*26e59cc1SAlex Zinenko ExpOpLowering, 218*26e59cc1SAlex Zinenko Exp2OpLowering, 219*26e59cc1SAlex Zinenko ExpM1OpLowering, 220*26e59cc1SAlex Zinenko Log10OpLowering, 221*26e59cc1SAlex Zinenko Log1pOpLowering, 222*26e59cc1SAlex Zinenko Log2OpLowering, 223*26e59cc1SAlex Zinenko LogOpLowering, 224*26e59cc1SAlex Zinenko PowFOpLowering, 225*26e59cc1SAlex Zinenko RsqrtOpLowering, 226*26e59cc1SAlex Zinenko SinOpLowering, 227*26e59cc1SAlex Zinenko SqrtOpLowering 228*26e59cc1SAlex Zinenko >(converter); 229*26e59cc1SAlex Zinenko // clang-format on 230*26e59cc1SAlex Zinenko } 231*26e59cc1SAlex Zinenko 232*26e59cc1SAlex Zinenko std::unique_ptr<Pass> mlir::createConvertMathToLLVMPass() { 233*26e59cc1SAlex Zinenko return std::make_unique<ConvertMathToLLVMPass>(); 234*26e59cc1SAlex Zinenko } 235