1 //===- MathToLLVM.cpp - Math to LLVM dialect conversion -------------------===// 2 // 3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. 4 // See https://llvm.org/LICENSE.txt for license information. 5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception 6 // 7 //===----------------------------------------------------------------------===// 8 9 #include "mlir/Conversion/MathToLLVM/MathToLLVM.h" 10 #include "../PassDetail.h" 11 #include "mlir/Conversion/LLVMCommon/ConversionTarget.h" 12 #include "mlir/Conversion/LLVMCommon/Pattern.h" 13 #include "mlir/Conversion/LLVMCommon/VectorPattern.h" 14 #include "mlir/Dialect/LLVMIR/LLVMDialect.h" 15 #include "mlir/Dialect/Math/IR/Math.h" 16 #include "mlir/IR/TypeUtilities.h" 17 18 using namespace mlir; 19 20 namespace { 21 using AbsOpLowering = VectorConvertToLLVMPattern<math::AbsOp, LLVM::FAbsOp>; 22 using CeilOpLowering = VectorConvertToLLVMPattern<math::CeilOp, LLVM::FCeilOp>; 23 using CopySignOpLowering = 24 VectorConvertToLLVMPattern<math::CopySignOp, LLVM::CopySignOp>; 25 using CosOpLowering = VectorConvertToLLVMPattern<math::CosOp, LLVM::CosOp>; 26 using CtPopFOpLowering = 27 VectorConvertToLLVMPattern<math::CtPopOp, LLVM::CtPopOp>; 28 using ExpOpLowering = VectorConvertToLLVMPattern<math::ExpOp, LLVM::ExpOp>; 29 using Exp2OpLowering = VectorConvertToLLVMPattern<math::Exp2Op, LLVM::Exp2Op>; 30 using FloorOpLowering = 31 VectorConvertToLLVMPattern<math::FloorOp, LLVM::FFloorOp>; 32 using FmaOpLowering = VectorConvertToLLVMPattern<math::FmaOp, LLVM::FMAOp>; 33 using Log10OpLowering = 34 VectorConvertToLLVMPattern<math::Log10Op, LLVM::Log10Op>; 35 using Log2OpLowering = VectorConvertToLLVMPattern<math::Log2Op, LLVM::Log2Op>; 36 using LogOpLowering = VectorConvertToLLVMPattern<math::LogOp, LLVM::LogOp>; 37 using PowFOpLowering = VectorConvertToLLVMPattern<math::PowFOp, LLVM::PowOp>; 38 using SinOpLowering = VectorConvertToLLVMPattern<math::SinOp, LLVM::SinOp>; 39 using SqrtOpLowering = VectorConvertToLLVMPattern<math::SqrtOp, LLVM::SqrtOp>; 40 41 // A `CtLz/CtTz(a)` is converted into `CtLz/CtTz(a, false)`. 42 template <typename MathOp, typename LLVMOp> 43 struct CountOpLowering : public ConvertOpToLLVMPattern<MathOp> { 44 using ConvertOpToLLVMPattern<MathOp>::ConvertOpToLLVMPattern; 45 using Super = CountOpLowering<MathOp, LLVMOp>; 46 47 LogicalResult 48 matchAndRewrite(MathOp op, typename MathOp::Adaptor adaptor, 49 ConversionPatternRewriter &rewriter) const override { 50 auto operandType = adaptor.getOperand().getType(); 51 52 if (!operandType || !LLVM::isCompatibleType(operandType)) 53 return failure(); 54 55 auto loc = op.getLoc(); 56 auto resultType = op.getResult().getType(); 57 auto boolType = rewriter.getIntegerType(1); 58 auto boolZero = rewriter.getIntegerAttr(boolType, 0); 59 60 if (!operandType.template isa<LLVM::LLVMArrayType>()) { 61 LLVM::ConstantOp zero = 62 rewriter.create<LLVM::ConstantOp>(loc, boolType, boolZero); 63 rewriter.replaceOpWithNewOp<LLVMOp>(op, resultType, adaptor.getOperand(), 64 zero); 65 return success(); 66 } 67 68 auto vectorType = resultType.template dyn_cast<VectorType>(); 69 if (!vectorType) 70 return failure(); 71 72 return LLVM::detail::handleMultidimensionalVectors( 73 op.getOperation(), adaptor.getOperands(), *this->getTypeConverter(), 74 [&](Type llvm1DVectorTy, ValueRange operands) { 75 LLVM::ConstantOp zero = 76 rewriter.create<LLVM::ConstantOp>(loc, boolType, boolZero); 77 return rewriter.replaceOpWithNewOp<LLVMOp>(op, llvm1DVectorTy, 78 operands[0], zero); 79 }, 80 rewriter); 81 } 82 }; 83 84 using CountLeadingZerosOpLowering = 85 CountOpLowering<math::CountLeadingZerosOp, LLVM::CountLeadingZerosOp>; 86 using CountTrailingZerosOpLowering = 87 CountOpLowering<math::CountTrailingZerosOp, LLVM::CountTrailingZerosOp>; 88 89 // A `expm1` is converted into `exp - 1`. 90 struct ExpM1OpLowering : public ConvertOpToLLVMPattern<math::ExpM1Op> { 91 using ConvertOpToLLVMPattern<math::ExpM1Op>::ConvertOpToLLVMPattern; 92 93 LogicalResult 94 matchAndRewrite(math::ExpM1Op op, OpAdaptor adaptor, 95 ConversionPatternRewriter &rewriter) const override { 96 auto operandType = adaptor.getOperand().getType(); 97 98 if (!operandType || !LLVM::isCompatibleType(operandType)) 99 return failure(); 100 101 auto loc = op.getLoc(); 102 auto resultType = op.getResult().getType(); 103 auto floatType = getElementTypeOrSelf(resultType).cast<FloatType>(); 104 auto floatOne = rewriter.getFloatAttr(floatType, 1.0); 105 106 if (!operandType.isa<LLVM::LLVMArrayType>()) { 107 LLVM::ConstantOp one; 108 if (LLVM::isCompatibleVectorType(operandType)) { 109 one = rewriter.create<LLVM::ConstantOp>( 110 loc, operandType, 111 SplatElementsAttr::get(resultType.cast<ShapedType>(), floatOne)); 112 } else { 113 one = rewriter.create<LLVM::ConstantOp>(loc, operandType, floatOne); 114 } 115 auto exp = rewriter.create<LLVM::ExpOp>(loc, adaptor.getOperand()); 116 rewriter.replaceOpWithNewOp<LLVM::FSubOp>(op, operandType, exp, one); 117 return success(); 118 } 119 120 auto vectorType = resultType.dyn_cast<VectorType>(); 121 if (!vectorType) 122 return rewriter.notifyMatchFailure(op, "expected vector result type"); 123 124 return LLVM::detail::handleMultidimensionalVectors( 125 op.getOperation(), adaptor.getOperands(), *getTypeConverter(), 126 [&](Type llvm1DVectorTy, ValueRange operands) { 127 auto splatAttr = SplatElementsAttr::get( 128 mlir::VectorType::get( 129 {LLVM::getVectorNumElements(llvm1DVectorTy).getFixedValue()}, 130 floatType), 131 floatOne); 132 auto one = 133 rewriter.create<LLVM::ConstantOp>(loc, llvm1DVectorTy, splatAttr); 134 auto exp = 135 rewriter.create<LLVM::ExpOp>(loc, llvm1DVectorTy, operands[0]); 136 return rewriter.create<LLVM::FSubOp>(loc, llvm1DVectorTy, exp, one); 137 }, 138 rewriter); 139 } 140 }; 141 142 // A `log1p` is converted into `log(1 + ...)`. 143 struct Log1pOpLowering : public ConvertOpToLLVMPattern<math::Log1pOp> { 144 using ConvertOpToLLVMPattern<math::Log1pOp>::ConvertOpToLLVMPattern; 145 146 LogicalResult 147 matchAndRewrite(math::Log1pOp op, OpAdaptor adaptor, 148 ConversionPatternRewriter &rewriter) const override { 149 auto operandType = adaptor.getOperand().getType(); 150 151 if (!operandType || !LLVM::isCompatibleType(operandType)) 152 return rewriter.notifyMatchFailure(op, "unsupported operand type"); 153 154 auto loc = op.getLoc(); 155 auto resultType = op.getResult().getType(); 156 auto floatType = getElementTypeOrSelf(resultType).cast<FloatType>(); 157 auto floatOne = rewriter.getFloatAttr(floatType, 1.0); 158 159 if (!operandType.isa<LLVM::LLVMArrayType>()) { 160 LLVM::ConstantOp one = 161 LLVM::isCompatibleVectorType(operandType) 162 ? rewriter.create<LLVM::ConstantOp>( 163 loc, operandType, 164 SplatElementsAttr::get(resultType.cast<ShapedType>(), 165 floatOne)) 166 : rewriter.create<LLVM::ConstantOp>(loc, operandType, floatOne); 167 168 auto add = rewriter.create<LLVM::FAddOp>(loc, operandType, one, 169 adaptor.getOperand()); 170 rewriter.replaceOpWithNewOp<LLVM::LogOp>(op, operandType, add); 171 return success(); 172 } 173 174 auto vectorType = resultType.dyn_cast<VectorType>(); 175 if (!vectorType) 176 return rewriter.notifyMatchFailure(op, "expected vector result type"); 177 178 return LLVM::detail::handleMultidimensionalVectors( 179 op.getOperation(), adaptor.getOperands(), *getTypeConverter(), 180 [&](Type llvm1DVectorTy, ValueRange operands) { 181 auto splatAttr = SplatElementsAttr::get( 182 mlir::VectorType::get( 183 {LLVM::getVectorNumElements(llvm1DVectorTy).getFixedValue()}, 184 floatType), 185 floatOne); 186 auto one = 187 rewriter.create<LLVM::ConstantOp>(loc, llvm1DVectorTy, splatAttr); 188 auto add = rewriter.create<LLVM::FAddOp>(loc, llvm1DVectorTy, one, 189 operands[0]); 190 return rewriter.create<LLVM::LogOp>(loc, llvm1DVectorTy, add); 191 }, 192 rewriter); 193 } 194 }; 195 196 // A `rsqrt` is converted into `1 / sqrt`. 197 struct RsqrtOpLowering : public ConvertOpToLLVMPattern<math::RsqrtOp> { 198 using ConvertOpToLLVMPattern<math::RsqrtOp>::ConvertOpToLLVMPattern; 199 200 LogicalResult 201 matchAndRewrite(math::RsqrtOp op, OpAdaptor adaptor, 202 ConversionPatternRewriter &rewriter) const override { 203 auto operandType = adaptor.getOperand().getType(); 204 205 if (!operandType || !LLVM::isCompatibleType(operandType)) 206 return failure(); 207 208 auto loc = op.getLoc(); 209 auto resultType = op.getResult().getType(); 210 auto floatType = getElementTypeOrSelf(resultType).cast<FloatType>(); 211 auto floatOne = rewriter.getFloatAttr(floatType, 1.0); 212 213 if (!operandType.isa<LLVM::LLVMArrayType>()) { 214 LLVM::ConstantOp one; 215 if (LLVM::isCompatibleVectorType(operandType)) { 216 one = rewriter.create<LLVM::ConstantOp>( 217 loc, operandType, 218 SplatElementsAttr::get(resultType.cast<ShapedType>(), floatOne)); 219 } else { 220 one = rewriter.create<LLVM::ConstantOp>(loc, operandType, floatOne); 221 } 222 auto sqrt = rewriter.create<LLVM::SqrtOp>(loc, adaptor.getOperand()); 223 rewriter.replaceOpWithNewOp<LLVM::FDivOp>(op, operandType, one, sqrt); 224 return success(); 225 } 226 227 auto vectorType = resultType.dyn_cast<VectorType>(); 228 if (!vectorType) 229 return failure(); 230 231 return LLVM::detail::handleMultidimensionalVectors( 232 op.getOperation(), adaptor.getOperands(), *getTypeConverter(), 233 [&](Type llvm1DVectorTy, ValueRange operands) { 234 auto splatAttr = SplatElementsAttr::get( 235 mlir::VectorType::get( 236 {LLVM::getVectorNumElements(llvm1DVectorTy).getFixedValue()}, 237 floatType), 238 floatOne); 239 auto one = 240 rewriter.create<LLVM::ConstantOp>(loc, llvm1DVectorTy, splatAttr); 241 auto sqrt = 242 rewriter.create<LLVM::SqrtOp>(loc, llvm1DVectorTy, operands[0]); 243 return rewriter.create<LLVM::FDivOp>(loc, llvm1DVectorTy, one, sqrt); 244 }, 245 rewriter); 246 } 247 }; 248 249 struct ConvertMathToLLVMPass 250 : public ConvertMathToLLVMBase<ConvertMathToLLVMPass> { 251 ConvertMathToLLVMPass() = default; 252 253 void runOnOperation() override { 254 RewritePatternSet patterns(&getContext()); 255 LLVMTypeConverter converter(&getContext()); 256 populateMathToLLVMConversionPatterns(converter, patterns); 257 LLVMConversionTarget target(getContext()); 258 if (failed(applyPartialConversion(getOperation(), target, 259 std::move(patterns)))) 260 signalPassFailure(); 261 } 262 }; 263 } // namespace 264 265 void mlir::populateMathToLLVMConversionPatterns(LLVMTypeConverter &converter, 266 RewritePatternSet &patterns) { 267 // clang-format off 268 patterns.add< 269 AbsOpLowering, 270 CeilOpLowering, 271 CopySignOpLowering, 272 CosOpLowering, 273 CountLeadingZerosOpLowering, 274 CountTrailingZerosOpLowering, 275 CtPopFOpLowering, 276 ExpOpLowering, 277 Exp2OpLowering, 278 ExpM1OpLowering, 279 FloorOpLowering, 280 FmaOpLowering, 281 Log10OpLowering, 282 Log1pOpLowering, 283 Log2OpLowering, 284 LogOpLowering, 285 PowFOpLowering, 286 RsqrtOpLowering, 287 SinOpLowering, 288 SqrtOpLowering 289 >(converter); 290 // clang-format on 291 } 292 293 std::unique_ptr<Pass> mlir::createConvertMathToLLVMPass() { 294 return std::make_unique<ConvertMathToLLVMPass>(); 295 } 296