1 //===- MathToLLVM.cpp - Math to LLVM dialect conversion -------------------===// 2 // 3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. 4 // See https://llvm.org/LICENSE.txt for license information. 5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception 6 // 7 //===----------------------------------------------------------------------===// 8 9 #include "mlir/Conversion/MathToLLVM/MathToLLVM.h" 10 #include "../PassDetail.h" 11 #include "mlir/Conversion/LLVMCommon/ConversionTarget.h" 12 #include "mlir/Conversion/LLVMCommon/Pattern.h" 13 #include "mlir/Conversion/LLVMCommon/VectorPattern.h" 14 #include "mlir/Dialect/LLVMIR/LLVMDialect.h" 15 #include "mlir/Dialect/Math/IR/Math.h" 16 #include "mlir/IR/TypeUtilities.h" 17 18 using namespace mlir; 19 20 namespace { 21 using CosOpLowering = VectorConvertToLLVMPattern<math::CosOp, LLVM::CosOp>; 22 using ExpOpLowering = VectorConvertToLLVMPattern<math::ExpOp, LLVM::ExpOp>; 23 using Exp2OpLowering = VectorConvertToLLVMPattern<math::Exp2Op, LLVM::Exp2Op>; 24 using Log10OpLowering = 25 VectorConvertToLLVMPattern<math::Log10Op, LLVM::Log10Op>; 26 using Log2OpLowering = VectorConvertToLLVMPattern<math::Log2Op, LLVM::Log2Op>; 27 using LogOpLowering = VectorConvertToLLVMPattern<math::LogOp, LLVM::LogOp>; 28 using PowFOpLowering = VectorConvertToLLVMPattern<math::PowFOp, LLVM::PowOp>; 29 using SinOpLowering = VectorConvertToLLVMPattern<math::SinOp, LLVM::SinOp>; 30 using SqrtOpLowering = VectorConvertToLLVMPattern<math::SqrtOp, LLVM::SqrtOp>; 31 32 // A `expm1` is converted into `exp - 1`. 33 struct ExpM1OpLowering : public ConvertOpToLLVMPattern<math::ExpM1Op> { 34 using ConvertOpToLLVMPattern<math::ExpM1Op>::ConvertOpToLLVMPattern; 35 36 LogicalResult 37 matchAndRewrite(math::ExpM1Op op, OpAdaptor adaptor, 38 ConversionPatternRewriter &rewriter) const override { 39 auto operandType = adaptor.operand().getType(); 40 41 if (!operandType || !LLVM::isCompatibleType(operandType)) 42 return failure(); 43 44 auto loc = op.getLoc(); 45 auto resultType = op.getResult().getType(); 46 auto floatType = getElementTypeOrSelf(resultType).cast<FloatType>(); 47 auto floatOne = rewriter.getFloatAttr(floatType, 1.0); 48 49 if (!operandType.isa<LLVM::LLVMArrayType>()) { 50 LLVM::ConstantOp one; 51 if (LLVM::isCompatibleVectorType(operandType)) { 52 one = rewriter.create<LLVM::ConstantOp>( 53 loc, operandType, 54 SplatElementsAttr::get(resultType.cast<ShapedType>(), floatOne)); 55 } else { 56 one = rewriter.create<LLVM::ConstantOp>(loc, operandType, floatOne); 57 } 58 auto exp = rewriter.create<LLVM::ExpOp>(loc, adaptor.operand()); 59 rewriter.replaceOpWithNewOp<LLVM::FSubOp>(op, operandType, exp, one); 60 return success(); 61 } 62 63 auto vectorType = resultType.dyn_cast<VectorType>(); 64 if (!vectorType) 65 return rewriter.notifyMatchFailure(op, "expected vector result type"); 66 67 return LLVM::detail::handleMultidimensionalVectors( 68 op.getOperation(), adaptor.getOperands(), *getTypeConverter(), 69 [&](Type llvm1DVectorTy, ValueRange operands) { 70 auto splatAttr = SplatElementsAttr::get( 71 mlir::VectorType::get( 72 {LLVM::getVectorNumElements(llvm1DVectorTy).getFixedValue()}, 73 floatType), 74 floatOne); 75 auto one = 76 rewriter.create<LLVM::ConstantOp>(loc, llvm1DVectorTy, splatAttr); 77 auto exp = 78 rewriter.create<LLVM::ExpOp>(loc, llvm1DVectorTy, operands[0]); 79 return rewriter.create<LLVM::FSubOp>(loc, llvm1DVectorTy, exp, one); 80 }, 81 rewriter); 82 } 83 }; 84 85 // A `log1p` is converted into `log(1 + ...)`. 86 struct Log1pOpLowering : public ConvertOpToLLVMPattern<math::Log1pOp> { 87 using ConvertOpToLLVMPattern<math::Log1pOp>::ConvertOpToLLVMPattern; 88 89 LogicalResult 90 matchAndRewrite(math::Log1pOp op, OpAdaptor adaptor, 91 ConversionPatternRewriter &rewriter) const override { 92 auto operandType = adaptor.operand().getType(); 93 94 if (!operandType || !LLVM::isCompatibleType(operandType)) 95 return rewriter.notifyMatchFailure(op, "unsupported operand type"); 96 97 auto loc = op.getLoc(); 98 auto resultType = op.getResult().getType(); 99 auto floatType = getElementTypeOrSelf(resultType).cast<FloatType>(); 100 auto floatOne = rewriter.getFloatAttr(floatType, 1.0); 101 102 if (!operandType.isa<LLVM::LLVMArrayType>()) { 103 LLVM::ConstantOp one = 104 LLVM::isCompatibleVectorType(operandType) 105 ? rewriter.create<LLVM::ConstantOp>( 106 loc, operandType, 107 SplatElementsAttr::get(resultType.cast<ShapedType>(), 108 floatOne)) 109 : rewriter.create<LLVM::ConstantOp>(loc, operandType, floatOne); 110 111 auto add = rewriter.create<LLVM::FAddOp>(loc, operandType, one, 112 adaptor.operand()); 113 rewriter.replaceOpWithNewOp<LLVM::LogOp>(op, operandType, add); 114 return success(); 115 } 116 117 auto vectorType = resultType.dyn_cast<VectorType>(); 118 if (!vectorType) 119 return rewriter.notifyMatchFailure(op, "expected vector result type"); 120 121 return LLVM::detail::handleMultidimensionalVectors( 122 op.getOperation(), adaptor.getOperands(), *getTypeConverter(), 123 [&](Type llvm1DVectorTy, ValueRange operands) { 124 auto splatAttr = SplatElementsAttr::get( 125 mlir::VectorType::get( 126 {LLVM::getVectorNumElements(llvm1DVectorTy).getFixedValue()}, 127 floatType), 128 floatOne); 129 auto one = 130 rewriter.create<LLVM::ConstantOp>(loc, llvm1DVectorTy, splatAttr); 131 auto add = rewriter.create<LLVM::FAddOp>(loc, llvm1DVectorTy, one, 132 operands[0]); 133 return rewriter.create<LLVM::LogOp>(loc, llvm1DVectorTy, add); 134 }, 135 rewriter); 136 } 137 }; 138 139 // A `rsqrt` is converted into `1 / sqrt`. 140 struct RsqrtOpLowering : public ConvertOpToLLVMPattern<math::RsqrtOp> { 141 using ConvertOpToLLVMPattern<math::RsqrtOp>::ConvertOpToLLVMPattern; 142 143 LogicalResult 144 matchAndRewrite(math::RsqrtOp op, OpAdaptor adaptor, 145 ConversionPatternRewriter &rewriter) const override { 146 auto operandType = adaptor.operand().getType(); 147 148 if (!operandType || !LLVM::isCompatibleType(operandType)) 149 return failure(); 150 151 auto loc = op.getLoc(); 152 auto resultType = op.getResult().getType(); 153 auto floatType = getElementTypeOrSelf(resultType).cast<FloatType>(); 154 auto floatOne = rewriter.getFloatAttr(floatType, 1.0); 155 156 if (!operandType.isa<LLVM::LLVMArrayType>()) { 157 LLVM::ConstantOp one; 158 if (LLVM::isCompatibleVectorType(operandType)) { 159 one = rewriter.create<LLVM::ConstantOp>( 160 loc, operandType, 161 SplatElementsAttr::get(resultType.cast<ShapedType>(), floatOne)); 162 } else { 163 one = rewriter.create<LLVM::ConstantOp>(loc, operandType, floatOne); 164 } 165 auto sqrt = rewriter.create<LLVM::SqrtOp>(loc, adaptor.operand()); 166 rewriter.replaceOpWithNewOp<LLVM::FDivOp>(op, operandType, one, sqrt); 167 return success(); 168 } 169 170 auto vectorType = resultType.dyn_cast<VectorType>(); 171 if (!vectorType) 172 return failure(); 173 174 return LLVM::detail::handleMultidimensionalVectors( 175 op.getOperation(), adaptor.getOperands(), *getTypeConverter(), 176 [&](Type llvm1DVectorTy, ValueRange operands) { 177 auto splatAttr = SplatElementsAttr::get( 178 mlir::VectorType::get( 179 {LLVM::getVectorNumElements(llvm1DVectorTy).getFixedValue()}, 180 floatType), 181 floatOne); 182 auto one = 183 rewriter.create<LLVM::ConstantOp>(loc, llvm1DVectorTy, splatAttr); 184 auto sqrt = 185 rewriter.create<LLVM::SqrtOp>(loc, llvm1DVectorTy, operands[0]); 186 return rewriter.create<LLVM::FDivOp>(loc, llvm1DVectorTy, one, sqrt); 187 }, 188 rewriter); 189 } 190 }; 191 192 struct ConvertMathToLLVMPass 193 : public ConvertMathToLLVMBase<ConvertMathToLLVMPass> { 194 ConvertMathToLLVMPass() = default; 195 196 void runOnFunction() override { 197 RewritePatternSet patterns(&getContext()); 198 LLVMTypeConverter converter(&getContext()); 199 populateMathToLLVMConversionPatterns(converter, patterns); 200 LLVMConversionTarget target(getContext()); 201 if (failed( 202 applyPartialConversion(getFunction(), target, std::move(patterns)))) 203 signalPassFailure(); 204 } 205 }; 206 } // namespace 207 208 void mlir::populateMathToLLVMConversionPatterns(LLVMTypeConverter &converter, 209 RewritePatternSet &patterns) { 210 // clang-format off 211 patterns.add< 212 CosOpLowering, 213 ExpOpLowering, 214 Exp2OpLowering, 215 ExpM1OpLowering, 216 Log10OpLowering, 217 Log1pOpLowering, 218 Log2OpLowering, 219 LogOpLowering, 220 PowFOpLowering, 221 RsqrtOpLowering, 222 SinOpLowering, 223 SqrtOpLowering 224 >(converter); 225 // clang-format on 226 } 227 228 std::unique_ptr<Pass> mlir::createConvertMathToLLVMPass() { 229 return std::make_unique<ConvertMathToLLVMPass>(); 230 } 231