1 //===- MathToLLVM.cpp - Math to LLVM dialect conversion -------------------===// 2 // 3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. 4 // See https://llvm.org/LICENSE.txt for license information. 5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception 6 // 7 //===----------------------------------------------------------------------===// 8 9 #include "mlir/Conversion/MathToLLVM/MathToLLVM.h" 10 #include "../PassDetail.h" 11 #include "mlir/Conversion/LLVMCommon/ConversionTarget.h" 12 #include "mlir/Conversion/LLVMCommon/Pattern.h" 13 #include "mlir/Conversion/LLVMCommon/VectorPattern.h" 14 #include "mlir/Dialect/LLVMIR/LLVMDialect.h" 15 #include "mlir/Dialect/Math/IR/Math.h" 16 #include "mlir/IR/TypeUtilities.h" 17 18 using namespace mlir; 19 20 namespace { 21 using AbsOpLowering = VectorConvertToLLVMPattern<math::AbsOp, LLVM::FAbsOp>; 22 using CeilOpLowering = VectorConvertToLLVMPattern<math::CeilOp, LLVM::FCeilOp>; 23 using CopySignOpLowering = 24 VectorConvertToLLVMPattern<math::CopySignOp, LLVM::CopySignOp>; 25 using CosOpLowering = VectorConvertToLLVMPattern<math::CosOp, LLVM::CosOp>; 26 using CtPopFOpLowering = 27 VectorConvertToLLVMPattern<math::CtPopOp, LLVM::CtPopOp>; 28 using ExpOpLowering = VectorConvertToLLVMPattern<math::ExpOp, LLVM::ExpOp>; 29 using Exp2OpLowering = VectorConvertToLLVMPattern<math::Exp2Op, LLVM::Exp2Op>; 30 using FloorOpLowering = 31 VectorConvertToLLVMPattern<math::FloorOp, LLVM::FFloorOp>; 32 using FmaOpLowering = VectorConvertToLLVMPattern<math::FmaOp, LLVM::FMAOp>; 33 using Log10OpLowering = 34 VectorConvertToLLVMPattern<math::Log10Op, LLVM::Log10Op>; 35 using Log2OpLowering = VectorConvertToLLVMPattern<math::Log2Op, LLVM::Log2Op>; 36 using LogOpLowering = VectorConvertToLLVMPattern<math::LogOp, LLVM::LogOp>; 37 using PowFOpLowering = VectorConvertToLLVMPattern<math::PowFOp, LLVM::PowOp>; 38 using SinOpLowering = VectorConvertToLLVMPattern<math::SinOp, LLVM::SinOp>; 39 using SqrtOpLowering = VectorConvertToLLVMPattern<math::SqrtOp, LLVM::SqrtOp>; 40 41 // A `expm1` is converted into `exp - 1`. 42 struct ExpM1OpLowering : public ConvertOpToLLVMPattern<math::ExpM1Op> { 43 using ConvertOpToLLVMPattern<math::ExpM1Op>::ConvertOpToLLVMPattern; 44 45 LogicalResult 46 matchAndRewrite(math::ExpM1Op op, OpAdaptor adaptor, 47 ConversionPatternRewriter &rewriter) const override { 48 auto operandType = adaptor.getOperand().getType(); 49 50 if (!operandType || !LLVM::isCompatibleType(operandType)) 51 return failure(); 52 53 auto loc = op.getLoc(); 54 auto resultType = op.getResult().getType(); 55 auto floatType = getElementTypeOrSelf(resultType).cast<FloatType>(); 56 auto floatOne = rewriter.getFloatAttr(floatType, 1.0); 57 58 if (!operandType.isa<LLVM::LLVMArrayType>()) { 59 LLVM::ConstantOp one; 60 if (LLVM::isCompatibleVectorType(operandType)) { 61 one = rewriter.create<LLVM::ConstantOp>( 62 loc, operandType, 63 SplatElementsAttr::get(resultType.cast<ShapedType>(), floatOne)); 64 } else { 65 one = rewriter.create<LLVM::ConstantOp>(loc, operandType, floatOne); 66 } 67 auto exp = rewriter.create<LLVM::ExpOp>(loc, adaptor.getOperand()); 68 rewriter.replaceOpWithNewOp<LLVM::FSubOp>(op, operandType, exp, one); 69 return success(); 70 } 71 72 auto vectorType = resultType.dyn_cast<VectorType>(); 73 if (!vectorType) 74 return rewriter.notifyMatchFailure(op, "expected vector result type"); 75 76 return LLVM::detail::handleMultidimensionalVectors( 77 op.getOperation(), adaptor.getOperands(), *getTypeConverter(), 78 [&](Type llvm1DVectorTy, ValueRange operands) { 79 auto splatAttr = SplatElementsAttr::get( 80 mlir::VectorType::get( 81 {LLVM::getVectorNumElements(llvm1DVectorTy).getFixedValue()}, 82 floatType), 83 floatOne); 84 auto one = 85 rewriter.create<LLVM::ConstantOp>(loc, llvm1DVectorTy, splatAttr); 86 auto exp = 87 rewriter.create<LLVM::ExpOp>(loc, llvm1DVectorTy, operands[0]); 88 return rewriter.create<LLVM::FSubOp>(loc, llvm1DVectorTy, exp, one); 89 }, 90 rewriter); 91 } 92 }; 93 94 // A `log1p` is converted into `log(1 + ...)`. 95 struct Log1pOpLowering : public ConvertOpToLLVMPattern<math::Log1pOp> { 96 using ConvertOpToLLVMPattern<math::Log1pOp>::ConvertOpToLLVMPattern; 97 98 LogicalResult 99 matchAndRewrite(math::Log1pOp op, OpAdaptor adaptor, 100 ConversionPatternRewriter &rewriter) const override { 101 auto operandType = adaptor.getOperand().getType(); 102 103 if (!operandType || !LLVM::isCompatibleType(operandType)) 104 return rewriter.notifyMatchFailure(op, "unsupported operand type"); 105 106 auto loc = op.getLoc(); 107 auto resultType = op.getResult().getType(); 108 auto floatType = getElementTypeOrSelf(resultType).cast<FloatType>(); 109 auto floatOne = rewriter.getFloatAttr(floatType, 1.0); 110 111 if (!operandType.isa<LLVM::LLVMArrayType>()) { 112 LLVM::ConstantOp one = 113 LLVM::isCompatibleVectorType(operandType) 114 ? rewriter.create<LLVM::ConstantOp>( 115 loc, operandType, 116 SplatElementsAttr::get(resultType.cast<ShapedType>(), 117 floatOne)) 118 : rewriter.create<LLVM::ConstantOp>(loc, operandType, floatOne); 119 120 auto add = rewriter.create<LLVM::FAddOp>(loc, operandType, one, 121 adaptor.getOperand()); 122 rewriter.replaceOpWithNewOp<LLVM::LogOp>(op, operandType, add); 123 return success(); 124 } 125 126 auto vectorType = resultType.dyn_cast<VectorType>(); 127 if (!vectorType) 128 return rewriter.notifyMatchFailure(op, "expected vector result type"); 129 130 return LLVM::detail::handleMultidimensionalVectors( 131 op.getOperation(), adaptor.getOperands(), *getTypeConverter(), 132 [&](Type llvm1DVectorTy, ValueRange operands) { 133 auto splatAttr = SplatElementsAttr::get( 134 mlir::VectorType::get( 135 {LLVM::getVectorNumElements(llvm1DVectorTy).getFixedValue()}, 136 floatType), 137 floatOne); 138 auto one = 139 rewriter.create<LLVM::ConstantOp>(loc, llvm1DVectorTy, splatAttr); 140 auto add = rewriter.create<LLVM::FAddOp>(loc, llvm1DVectorTy, one, 141 operands[0]); 142 return rewriter.create<LLVM::LogOp>(loc, llvm1DVectorTy, add); 143 }, 144 rewriter); 145 } 146 }; 147 148 // A `rsqrt` is converted into `1 / sqrt`. 149 struct RsqrtOpLowering : public ConvertOpToLLVMPattern<math::RsqrtOp> { 150 using ConvertOpToLLVMPattern<math::RsqrtOp>::ConvertOpToLLVMPattern; 151 152 LogicalResult 153 matchAndRewrite(math::RsqrtOp op, OpAdaptor adaptor, 154 ConversionPatternRewriter &rewriter) const override { 155 auto operandType = adaptor.getOperand().getType(); 156 157 if (!operandType || !LLVM::isCompatibleType(operandType)) 158 return failure(); 159 160 auto loc = op.getLoc(); 161 auto resultType = op.getResult().getType(); 162 auto floatType = getElementTypeOrSelf(resultType).cast<FloatType>(); 163 auto floatOne = rewriter.getFloatAttr(floatType, 1.0); 164 165 if (!operandType.isa<LLVM::LLVMArrayType>()) { 166 LLVM::ConstantOp one; 167 if (LLVM::isCompatibleVectorType(operandType)) { 168 one = rewriter.create<LLVM::ConstantOp>( 169 loc, operandType, 170 SplatElementsAttr::get(resultType.cast<ShapedType>(), floatOne)); 171 } else { 172 one = rewriter.create<LLVM::ConstantOp>(loc, operandType, floatOne); 173 } 174 auto sqrt = rewriter.create<LLVM::SqrtOp>(loc, adaptor.getOperand()); 175 rewriter.replaceOpWithNewOp<LLVM::FDivOp>(op, operandType, one, sqrt); 176 return success(); 177 } 178 179 auto vectorType = resultType.dyn_cast<VectorType>(); 180 if (!vectorType) 181 return failure(); 182 183 return LLVM::detail::handleMultidimensionalVectors( 184 op.getOperation(), adaptor.getOperands(), *getTypeConverter(), 185 [&](Type llvm1DVectorTy, ValueRange operands) { 186 auto splatAttr = SplatElementsAttr::get( 187 mlir::VectorType::get( 188 {LLVM::getVectorNumElements(llvm1DVectorTy).getFixedValue()}, 189 floatType), 190 floatOne); 191 auto one = 192 rewriter.create<LLVM::ConstantOp>(loc, llvm1DVectorTy, splatAttr); 193 auto sqrt = 194 rewriter.create<LLVM::SqrtOp>(loc, llvm1DVectorTy, operands[0]); 195 return rewriter.create<LLVM::FDivOp>(loc, llvm1DVectorTy, one, sqrt); 196 }, 197 rewriter); 198 } 199 }; 200 201 struct ConvertMathToLLVMPass 202 : public ConvertMathToLLVMBase<ConvertMathToLLVMPass> { 203 ConvertMathToLLVMPass() = default; 204 205 void runOnFunction() override { 206 RewritePatternSet patterns(&getContext()); 207 LLVMTypeConverter converter(&getContext()); 208 populateMathToLLVMConversionPatterns(converter, patterns); 209 LLVMConversionTarget target(getContext()); 210 if (failed( 211 applyPartialConversion(getFunction(), target, std::move(patterns)))) 212 signalPassFailure(); 213 } 214 }; 215 } // namespace 216 217 void mlir::populateMathToLLVMConversionPatterns(LLVMTypeConverter &converter, 218 RewritePatternSet &patterns) { 219 // clang-format off 220 patterns.add< 221 AbsOpLowering, 222 CeilOpLowering, 223 CopySignOpLowering, 224 CosOpLowering, 225 CtPopFOpLowering, 226 ExpOpLowering, 227 Exp2OpLowering, 228 ExpM1OpLowering, 229 FloorOpLowering, 230 FmaOpLowering, 231 Log10OpLowering, 232 Log1pOpLowering, 233 Log2OpLowering, 234 LogOpLowering, 235 PowFOpLowering, 236 RsqrtOpLowering, 237 SinOpLowering, 238 SqrtOpLowering 239 >(converter); 240 // clang-format on 241 } 242 243 std::unique_ptr<Pass> mlir::createConvertMathToLLVMPass() { 244 return std::make_unique<ConvertMathToLLVMPass>(); 245 } 246