1 //===- MathOps.cpp - MLIR operations for math implementation --------------===// 2 // 3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. 4 // See https://llvm.org/LICENSE.txt for license information. 5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception 6 // 7 //===----------------------------------------------------------------------===// 8 9 #include "mlir/Dialect/Arithmetic/IR/Arithmetic.h" 10 #include "mlir/Dialect/Math/IR/Math.h" 11 #include "mlir/IR/Builders.h" 12 13 using namespace mlir; 14 using namespace mlir::math; 15 16 //===----------------------------------------------------------------------===// 17 // TableGen'd op method definitions 18 //===----------------------------------------------------------------------===// 19 20 #define GET_OP_CLASSES 21 #include "mlir/Dialect/Math/IR/MathOps.cpp.inc" 22 23 //===----------------------------------------------------------------------===// 24 // CeilOp folder 25 //===----------------------------------------------------------------------===// 26 27 OpFoldResult math::CeilOp::fold(ArrayRef<Attribute> operands) { 28 auto constOperand = operands.front(); 29 if (!constOperand) 30 return {}; 31 32 auto attr = constOperand.dyn_cast<FloatAttr>(); 33 if (!attr) 34 return {}; 35 36 APFloat sourceVal = attr.getValue(); 37 sourceVal.roundToIntegral(llvm::RoundingMode::TowardPositive); 38 39 return FloatAttr::get(getType(), sourceVal); 40 } 41 42 //===----------------------------------------------------------------------===// 43 // Log2Op folder 44 //===----------------------------------------------------------------------===// 45 46 OpFoldResult math::Log2Op::fold(ArrayRef<Attribute> operands) { 47 auto constOperand = operands.front(); 48 if (!constOperand) 49 return {}; 50 51 auto attr = constOperand.dyn_cast<FloatAttr>(); 52 if (!attr) 53 return {}; 54 55 auto ft = getType().cast<FloatType>(); 56 57 APFloat apf = attr.getValue(); 58 59 if (apf.isNegative()) 60 return {}; 61 62 if (ft.getWidth() == 64) 63 return FloatAttr::get(getType(), log2(apf.convertToDouble())); 64 65 if (ft.getWidth() == 32) 66 return FloatAttr::get(getType(), log2f(apf.convertToFloat())); 67 68 return {}; 69 } 70 71 //===----------------------------------------------------------------------===// 72 // PowFOp folder 73 //===----------------------------------------------------------------------===// 74 75 OpFoldResult math::PowFOp::fold(ArrayRef<Attribute> operands) { 76 auto ft = getType().dyn_cast<FloatType>(); 77 if (!ft) 78 return {}; 79 80 APFloat vals[2]{APFloat(ft.getFloatSemantics()), 81 APFloat(ft.getFloatSemantics())}; 82 for (int i = 0; i < 2; ++i) { 83 if (!operands[i]) 84 return {}; 85 86 auto attr = operands[i].dyn_cast<FloatAttr>(); 87 if (!attr) 88 return {}; 89 90 vals[i] = attr.getValue(); 91 } 92 93 if (ft.getWidth() == 64) 94 return FloatAttr::get( 95 getType(), pow(vals[0].convertToDouble(), vals[1].convertToDouble())); 96 97 if (ft.getWidth() == 32) 98 return FloatAttr::get( 99 getType(), powf(vals[0].convertToFloat(), vals[1].convertToFloat())); 100 101 return {}; 102 } 103 104 OpFoldResult math::SqrtOp::fold(ArrayRef<Attribute> operands) { 105 auto constOperand = operands.front(); 106 if (!constOperand) 107 return {}; 108 109 auto attr = constOperand.dyn_cast<FloatAttr>(); 110 if (!attr) 111 return {}; 112 113 auto ft = getType().cast<FloatType>(); 114 115 APFloat apf = attr.getValue(); 116 117 if (apf.isNegative()) 118 return {}; 119 120 if (ft.getWidth() == 64) 121 return FloatAttr::get(getType(), sqrt(apf.convertToDouble())); 122 123 if (ft.getWidth() == 32) 124 return FloatAttr::get(getType(), sqrtf(apf.convertToFloat())); 125 126 return {}; 127 } 128 129 /// Materialize an integer or floating point constant. 130 Operation *math::MathDialect::materializeConstant(OpBuilder &builder, 131 Attribute value, Type type, 132 Location loc) { 133 return builder.create<arith::ConstantOp>(loc, value, type); 134 } 135