1 //===- MathToLLVM.cpp - Math to LLVM dialect conversion -------------------===// 2 // 3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. 4 // See https://llvm.org/LICENSE.txt for license information. 5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception 6 // 7 //===----------------------------------------------------------------------===// 8 9 #include "mlir/Conversion/MathToLLVM/MathToLLVM.h" 10 #include "../PassDetail.h" 11 #include "mlir/Conversion/LLVMCommon/ConversionTarget.h" 12 #include "mlir/Conversion/LLVMCommon/Pattern.h" 13 #include "mlir/Conversion/LLVMCommon/VectorPattern.h" 14 #include "mlir/Dialect/LLVMIR/LLVMDialect.h" 15 #include "mlir/Dialect/Math/IR/Math.h" 16 #include "mlir/IR/TypeUtilities.h" 17 18 using namespace mlir; 19 20 namespace { 21 using CosOpLowering = VectorConvertToLLVMPattern<math::CosOp, LLVM::CosOp>; 22 using ExpOpLowering = VectorConvertToLLVMPattern<math::ExpOp, LLVM::ExpOp>; 23 using Exp2OpLowering = VectorConvertToLLVMPattern<math::Exp2Op, LLVM::Exp2Op>; 24 using Log10OpLowering = 25 VectorConvertToLLVMPattern<math::Log10Op, LLVM::Log10Op>; 26 using Log2OpLowering = VectorConvertToLLVMPattern<math::Log2Op, LLVM::Log2Op>; 27 using LogOpLowering = VectorConvertToLLVMPattern<math::LogOp, LLVM::LogOp>; 28 using PowFOpLowering = VectorConvertToLLVMPattern<math::PowFOp, LLVM::PowOp>; 29 using SinOpLowering = VectorConvertToLLVMPattern<math::SinOp, LLVM::SinOp>; 30 using SqrtOpLowering = VectorConvertToLLVMPattern<math::SqrtOp, LLVM::SqrtOp>; 31 32 // A `expm1` is converted into `exp - 1`. 33 struct ExpM1OpLowering : public ConvertOpToLLVMPattern<math::ExpM1Op> { 34 using ConvertOpToLLVMPattern<math::ExpM1Op>::ConvertOpToLLVMPattern; 35 36 LogicalResult 37 matchAndRewrite(math::ExpM1Op op, ArrayRef<Value> operands, 38 ConversionPatternRewriter &rewriter) const override { 39 math::ExpM1Op::Adaptor transformed(operands); 40 auto operandType = transformed.operand().getType(); 41 42 if (!operandType || !LLVM::isCompatibleType(operandType)) 43 return failure(); 44 45 auto loc = op.getLoc(); 46 auto resultType = op.getResult().getType(); 47 auto floatType = getElementTypeOrSelf(resultType).cast<FloatType>(); 48 auto floatOne = rewriter.getFloatAttr(floatType, 1.0); 49 50 if (!operandType.isa<LLVM::LLVMArrayType>()) { 51 LLVM::ConstantOp one; 52 if (LLVM::isCompatibleVectorType(operandType)) { 53 one = rewriter.create<LLVM::ConstantOp>( 54 loc, operandType, 55 SplatElementsAttr::get(resultType.cast<ShapedType>(), floatOne)); 56 } else { 57 one = rewriter.create<LLVM::ConstantOp>(loc, operandType, floatOne); 58 } 59 auto exp = rewriter.create<LLVM::ExpOp>(loc, transformed.operand()); 60 rewriter.replaceOpWithNewOp<LLVM::FSubOp>(op, operandType, exp, one); 61 return success(); 62 } 63 64 auto vectorType = resultType.dyn_cast<VectorType>(); 65 if (!vectorType) 66 return rewriter.notifyMatchFailure(op, "expected vector result type"); 67 68 return LLVM::detail::handleMultidimensionalVectors( 69 op.getOperation(), operands, *getTypeConverter(), 70 [&](Type llvm1DVectorTy, ValueRange operands) { 71 auto splatAttr = SplatElementsAttr::get( 72 mlir::VectorType::get( 73 {LLVM::getVectorNumElements(llvm1DVectorTy).getFixedValue()}, 74 floatType), 75 floatOne); 76 auto one = 77 rewriter.create<LLVM::ConstantOp>(loc, llvm1DVectorTy, splatAttr); 78 auto exp = 79 rewriter.create<LLVM::ExpOp>(loc, llvm1DVectorTy, operands[0]); 80 return rewriter.create<LLVM::FSubOp>(loc, llvm1DVectorTy, exp, one); 81 }, 82 rewriter); 83 } 84 }; 85 86 // A `log1p` is converted into `log(1 + ...)`. 87 struct Log1pOpLowering : public ConvertOpToLLVMPattern<math::Log1pOp> { 88 using ConvertOpToLLVMPattern<math::Log1pOp>::ConvertOpToLLVMPattern; 89 90 LogicalResult 91 matchAndRewrite(math::Log1pOp op, ArrayRef<Value> operands, 92 ConversionPatternRewriter &rewriter) const override { 93 math::Log1pOp::Adaptor transformed(operands); 94 auto operandType = transformed.operand().getType(); 95 96 if (!operandType || !LLVM::isCompatibleType(operandType)) 97 return rewriter.notifyMatchFailure(op, "unsupported operand type"); 98 99 auto loc = op.getLoc(); 100 auto resultType = op.getResult().getType(); 101 auto floatType = getElementTypeOrSelf(resultType).cast<FloatType>(); 102 auto floatOne = rewriter.getFloatAttr(floatType, 1.0); 103 104 if (!operandType.isa<LLVM::LLVMArrayType>()) { 105 LLVM::ConstantOp one = 106 LLVM::isCompatibleVectorType(operandType) 107 ? rewriter.create<LLVM::ConstantOp>( 108 loc, operandType, 109 SplatElementsAttr::get(resultType.cast<ShapedType>(), 110 floatOne)) 111 : rewriter.create<LLVM::ConstantOp>(loc, operandType, floatOne); 112 113 auto add = rewriter.create<LLVM::FAddOp>(loc, operandType, one, 114 transformed.operand()); 115 rewriter.replaceOpWithNewOp<LLVM::LogOp>(op, operandType, add); 116 return success(); 117 } 118 119 auto vectorType = resultType.dyn_cast<VectorType>(); 120 if (!vectorType) 121 return rewriter.notifyMatchFailure(op, "expected vector result type"); 122 123 return LLVM::detail::handleMultidimensionalVectors( 124 op.getOperation(), operands, *getTypeConverter(), 125 [&](Type llvm1DVectorTy, ValueRange operands) { 126 auto splatAttr = SplatElementsAttr::get( 127 mlir::VectorType::get( 128 {LLVM::getVectorNumElements(llvm1DVectorTy).getFixedValue()}, 129 floatType), 130 floatOne); 131 auto one = 132 rewriter.create<LLVM::ConstantOp>(loc, llvm1DVectorTy, splatAttr); 133 auto add = rewriter.create<LLVM::FAddOp>(loc, llvm1DVectorTy, one, 134 operands[0]); 135 return rewriter.create<LLVM::LogOp>(loc, llvm1DVectorTy, add); 136 }, 137 rewriter); 138 } 139 }; 140 141 // A `rsqrt` is converted into `1 / sqrt`. 142 struct RsqrtOpLowering : public ConvertOpToLLVMPattern<math::RsqrtOp> { 143 using ConvertOpToLLVMPattern<math::RsqrtOp>::ConvertOpToLLVMPattern; 144 145 LogicalResult 146 matchAndRewrite(math::RsqrtOp op, ArrayRef<Value> operands, 147 ConversionPatternRewriter &rewriter) const override { 148 math::RsqrtOp::Adaptor transformed(operands); 149 auto operandType = transformed.operand().getType(); 150 151 if (!operandType || !LLVM::isCompatibleType(operandType)) 152 return failure(); 153 154 auto loc = op.getLoc(); 155 auto resultType = op.getResult().getType(); 156 auto floatType = getElementTypeOrSelf(resultType).cast<FloatType>(); 157 auto floatOne = rewriter.getFloatAttr(floatType, 1.0); 158 159 if (!operandType.isa<LLVM::LLVMArrayType>()) { 160 LLVM::ConstantOp one; 161 if (LLVM::isCompatibleVectorType(operandType)) { 162 one = rewriter.create<LLVM::ConstantOp>( 163 loc, operandType, 164 SplatElementsAttr::get(resultType.cast<ShapedType>(), floatOne)); 165 } else { 166 one = rewriter.create<LLVM::ConstantOp>(loc, operandType, floatOne); 167 } 168 auto sqrt = rewriter.create<LLVM::SqrtOp>(loc, transformed.operand()); 169 rewriter.replaceOpWithNewOp<LLVM::FDivOp>(op, operandType, one, sqrt); 170 return success(); 171 } 172 173 auto vectorType = resultType.dyn_cast<VectorType>(); 174 if (!vectorType) 175 return failure(); 176 177 return LLVM::detail::handleMultidimensionalVectors( 178 op.getOperation(), operands, *getTypeConverter(), 179 [&](Type llvm1DVectorTy, ValueRange operands) { 180 auto splatAttr = SplatElementsAttr::get( 181 mlir::VectorType::get( 182 {LLVM::getVectorNumElements(llvm1DVectorTy).getFixedValue()}, 183 floatType), 184 floatOne); 185 auto one = 186 rewriter.create<LLVM::ConstantOp>(loc, llvm1DVectorTy, splatAttr); 187 auto sqrt = 188 rewriter.create<LLVM::SqrtOp>(loc, llvm1DVectorTy, operands[0]); 189 return rewriter.create<LLVM::FDivOp>(loc, llvm1DVectorTy, one, sqrt); 190 }, 191 rewriter); 192 } 193 }; 194 195 struct ConvertMathToLLVMPass 196 : public ConvertMathToLLVMBase<ConvertMathToLLVMPass> { 197 ConvertMathToLLVMPass() = default; 198 199 void runOnFunction() override { 200 RewritePatternSet patterns(&getContext()); 201 LLVMTypeConverter converter(&getContext()); 202 populateMathToLLVMConversionPatterns(converter, patterns); 203 LLVMConversionTarget target(getContext()); 204 if (failed( 205 applyPartialConversion(getFunction(), target, std::move(patterns)))) 206 signalPassFailure(); 207 } 208 }; 209 } // namespace 210 211 void mlir::populateMathToLLVMConversionPatterns(LLVMTypeConverter &converter, 212 RewritePatternSet &patterns) { 213 // clang-format off 214 patterns.add< 215 CosOpLowering, 216 ExpOpLowering, 217 Exp2OpLowering, 218 ExpM1OpLowering, 219 Log10OpLowering, 220 Log1pOpLowering, 221 Log2OpLowering, 222 LogOpLowering, 223 PowFOpLowering, 224 RsqrtOpLowering, 225 SinOpLowering, 226 SqrtOpLowering 227 >(converter); 228 // clang-format on 229 } 230 231 std::unique_ptr<Pass> mlir::createConvertMathToLLVMPass() { 232 return std::make_unique<ConvertMathToLLVMPass>(); 233 } 234