1 //===- MathToLLVM.cpp - Math to LLVM dialect conversion -------------------===// 2 // 3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. 4 // See https://llvm.org/LICENSE.txt for license information. 5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception 6 // 7 //===----------------------------------------------------------------------===// 8 9 #include "mlir/Conversion/MathToLLVM/MathToLLVM.h" 10 #include "../PassDetail.h" 11 #include "mlir/Conversion/LLVMCommon/ConversionTarget.h" 12 #include "mlir/Conversion/LLVMCommon/Pattern.h" 13 #include "mlir/Conversion/LLVMCommon/VectorPattern.h" 14 #include "mlir/Dialect/LLVMIR/LLVMDialect.h" 15 #include "mlir/Dialect/Math/IR/Math.h" 16 #include "mlir/IR/TypeUtilities.h" 17 18 using namespace mlir; 19 20 namespace { 21 using AbsOpLowering = VectorConvertToLLVMPattern<math::AbsOp, LLVM::FAbsOp>; 22 using CeilOpLowering = VectorConvertToLLVMPattern<math::CeilOp, LLVM::FCeilOp>; 23 using CopySignOpLowering = 24 VectorConvertToLLVMPattern<math::CopySignOp, LLVM::CopySignOp>; 25 using CosOpLowering = VectorConvertToLLVMPattern<math::CosOp, LLVM::CosOp>; 26 using ExpOpLowering = VectorConvertToLLVMPattern<math::ExpOp, LLVM::ExpOp>; 27 using Exp2OpLowering = VectorConvertToLLVMPattern<math::Exp2Op, LLVM::Exp2Op>; 28 using FloorOpLowering = 29 VectorConvertToLLVMPattern<math::FloorOp, LLVM::FFloorOp>; 30 using FmaOpLowering = VectorConvertToLLVMPattern<math::FmaOp, LLVM::FMAOp>; 31 using Log10OpLowering = 32 VectorConvertToLLVMPattern<math::Log10Op, LLVM::Log10Op>; 33 using Log2OpLowering = VectorConvertToLLVMPattern<math::Log2Op, LLVM::Log2Op>; 34 using LogOpLowering = VectorConvertToLLVMPattern<math::LogOp, LLVM::LogOp>; 35 using PowFOpLowering = VectorConvertToLLVMPattern<math::PowFOp, LLVM::PowOp>; 36 using SinOpLowering = VectorConvertToLLVMPattern<math::SinOp, LLVM::SinOp>; 37 using SqrtOpLowering = VectorConvertToLLVMPattern<math::SqrtOp, LLVM::SqrtOp>; 38 39 // A `expm1` is converted into `exp - 1`. 40 struct ExpM1OpLowering : public ConvertOpToLLVMPattern<math::ExpM1Op> { 41 using ConvertOpToLLVMPattern<math::ExpM1Op>::ConvertOpToLLVMPattern; 42 43 LogicalResult 44 matchAndRewrite(math::ExpM1Op op, OpAdaptor adaptor, 45 ConversionPatternRewriter &rewriter) const override { 46 auto operandType = adaptor.operand().getType(); 47 48 if (!operandType || !LLVM::isCompatibleType(operandType)) 49 return failure(); 50 51 auto loc = op.getLoc(); 52 auto resultType = op.getResult().getType(); 53 auto floatType = getElementTypeOrSelf(resultType).cast<FloatType>(); 54 auto floatOne = rewriter.getFloatAttr(floatType, 1.0); 55 56 if (!operandType.isa<LLVM::LLVMArrayType>()) { 57 LLVM::ConstantOp one; 58 if (LLVM::isCompatibleVectorType(operandType)) { 59 one = rewriter.create<LLVM::ConstantOp>( 60 loc, operandType, 61 SplatElementsAttr::get(resultType.cast<ShapedType>(), floatOne)); 62 } else { 63 one = rewriter.create<LLVM::ConstantOp>(loc, operandType, floatOne); 64 } 65 auto exp = rewriter.create<LLVM::ExpOp>(loc, adaptor.operand()); 66 rewriter.replaceOpWithNewOp<LLVM::FSubOp>(op, operandType, exp, one); 67 return success(); 68 } 69 70 auto vectorType = resultType.dyn_cast<VectorType>(); 71 if (!vectorType) 72 return rewriter.notifyMatchFailure(op, "expected vector result type"); 73 74 return LLVM::detail::handleMultidimensionalVectors( 75 op.getOperation(), adaptor.getOperands(), *getTypeConverter(), 76 [&](Type llvm1DVectorTy, ValueRange operands) { 77 auto splatAttr = SplatElementsAttr::get( 78 mlir::VectorType::get( 79 {LLVM::getVectorNumElements(llvm1DVectorTy).getFixedValue()}, 80 floatType), 81 floatOne); 82 auto one = 83 rewriter.create<LLVM::ConstantOp>(loc, llvm1DVectorTy, splatAttr); 84 auto exp = 85 rewriter.create<LLVM::ExpOp>(loc, llvm1DVectorTy, operands[0]); 86 return rewriter.create<LLVM::FSubOp>(loc, llvm1DVectorTy, exp, one); 87 }, 88 rewriter); 89 } 90 }; 91 92 // A `log1p` is converted into `log(1 + ...)`. 93 struct Log1pOpLowering : public ConvertOpToLLVMPattern<math::Log1pOp> { 94 using ConvertOpToLLVMPattern<math::Log1pOp>::ConvertOpToLLVMPattern; 95 96 LogicalResult 97 matchAndRewrite(math::Log1pOp op, OpAdaptor adaptor, 98 ConversionPatternRewriter &rewriter) const override { 99 auto operandType = adaptor.operand().getType(); 100 101 if (!operandType || !LLVM::isCompatibleType(operandType)) 102 return rewriter.notifyMatchFailure(op, "unsupported operand type"); 103 104 auto loc = op.getLoc(); 105 auto resultType = op.getResult().getType(); 106 auto floatType = getElementTypeOrSelf(resultType).cast<FloatType>(); 107 auto floatOne = rewriter.getFloatAttr(floatType, 1.0); 108 109 if (!operandType.isa<LLVM::LLVMArrayType>()) { 110 LLVM::ConstantOp one = 111 LLVM::isCompatibleVectorType(operandType) 112 ? rewriter.create<LLVM::ConstantOp>( 113 loc, operandType, 114 SplatElementsAttr::get(resultType.cast<ShapedType>(), 115 floatOne)) 116 : rewriter.create<LLVM::ConstantOp>(loc, operandType, floatOne); 117 118 auto add = rewriter.create<LLVM::FAddOp>(loc, operandType, one, 119 adaptor.operand()); 120 rewriter.replaceOpWithNewOp<LLVM::LogOp>(op, operandType, add); 121 return success(); 122 } 123 124 auto vectorType = resultType.dyn_cast<VectorType>(); 125 if (!vectorType) 126 return rewriter.notifyMatchFailure(op, "expected vector result type"); 127 128 return LLVM::detail::handleMultidimensionalVectors( 129 op.getOperation(), adaptor.getOperands(), *getTypeConverter(), 130 [&](Type llvm1DVectorTy, ValueRange operands) { 131 auto splatAttr = SplatElementsAttr::get( 132 mlir::VectorType::get( 133 {LLVM::getVectorNumElements(llvm1DVectorTy).getFixedValue()}, 134 floatType), 135 floatOne); 136 auto one = 137 rewriter.create<LLVM::ConstantOp>(loc, llvm1DVectorTy, splatAttr); 138 auto add = rewriter.create<LLVM::FAddOp>(loc, llvm1DVectorTy, one, 139 operands[0]); 140 return rewriter.create<LLVM::LogOp>(loc, llvm1DVectorTy, add); 141 }, 142 rewriter); 143 } 144 }; 145 146 // A `rsqrt` is converted into `1 / sqrt`. 147 struct RsqrtOpLowering : public ConvertOpToLLVMPattern<math::RsqrtOp> { 148 using ConvertOpToLLVMPattern<math::RsqrtOp>::ConvertOpToLLVMPattern; 149 150 LogicalResult 151 matchAndRewrite(math::RsqrtOp op, OpAdaptor adaptor, 152 ConversionPatternRewriter &rewriter) const override { 153 auto operandType = adaptor.operand().getType(); 154 155 if (!operandType || !LLVM::isCompatibleType(operandType)) 156 return failure(); 157 158 auto loc = op.getLoc(); 159 auto resultType = op.getResult().getType(); 160 auto floatType = getElementTypeOrSelf(resultType).cast<FloatType>(); 161 auto floatOne = rewriter.getFloatAttr(floatType, 1.0); 162 163 if (!operandType.isa<LLVM::LLVMArrayType>()) { 164 LLVM::ConstantOp one; 165 if (LLVM::isCompatibleVectorType(operandType)) { 166 one = rewriter.create<LLVM::ConstantOp>( 167 loc, operandType, 168 SplatElementsAttr::get(resultType.cast<ShapedType>(), floatOne)); 169 } else { 170 one = rewriter.create<LLVM::ConstantOp>(loc, operandType, floatOne); 171 } 172 auto sqrt = rewriter.create<LLVM::SqrtOp>(loc, adaptor.operand()); 173 rewriter.replaceOpWithNewOp<LLVM::FDivOp>(op, operandType, one, sqrt); 174 return success(); 175 } 176 177 auto vectorType = resultType.dyn_cast<VectorType>(); 178 if (!vectorType) 179 return failure(); 180 181 return LLVM::detail::handleMultidimensionalVectors( 182 op.getOperation(), adaptor.getOperands(), *getTypeConverter(), 183 [&](Type llvm1DVectorTy, ValueRange operands) { 184 auto splatAttr = SplatElementsAttr::get( 185 mlir::VectorType::get( 186 {LLVM::getVectorNumElements(llvm1DVectorTy).getFixedValue()}, 187 floatType), 188 floatOne); 189 auto one = 190 rewriter.create<LLVM::ConstantOp>(loc, llvm1DVectorTy, splatAttr); 191 auto sqrt = 192 rewriter.create<LLVM::SqrtOp>(loc, llvm1DVectorTy, operands[0]); 193 return rewriter.create<LLVM::FDivOp>(loc, llvm1DVectorTy, one, sqrt); 194 }, 195 rewriter); 196 } 197 }; 198 199 struct ConvertMathToLLVMPass 200 : public ConvertMathToLLVMBase<ConvertMathToLLVMPass> { 201 ConvertMathToLLVMPass() = default; 202 203 void runOnFunction() override { 204 RewritePatternSet patterns(&getContext()); 205 LLVMTypeConverter converter(&getContext()); 206 populateMathToLLVMConversionPatterns(converter, patterns); 207 LLVMConversionTarget target(getContext()); 208 if (failed( 209 applyPartialConversion(getFunction(), target, std::move(patterns)))) 210 signalPassFailure(); 211 } 212 }; 213 } // namespace 214 215 void mlir::populateMathToLLVMConversionPatterns(LLVMTypeConverter &converter, 216 RewritePatternSet &patterns) { 217 // clang-format off 218 patterns.add< 219 AbsOpLowering, 220 CeilOpLowering, 221 CopySignOpLowering, 222 CosOpLowering, 223 ExpOpLowering, 224 Exp2OpLowering, 225 ExpM1OpLowering, 226 FloorOpLowering, 227 FmaOpLowering, 228 Log10OpLowering, 229 Log1pOpLowering, 230 Log2OpLowering, 231 LogOpLowering, 232 PowFOpLowering, 233 RsqrtOpLowering, 234 SinOpLowering, 235 SqrtOpLowering 236 >(converter); 237 // clang-format on 238 } 239 240 std::unique_ptr<Pass> mlir::createConvertMathToLLVMPass() { 241 return std::make_unique<ConvertMathToLLVMPass>(); 242 } 243