1 //===- MathOps.cpp - MLIR operations for math implementation --------------===// 2 // 3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. 4 // See https://llvm.org/LICENSE.txt for license information. 5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception 6 // 7 //===----------------------------------------------------------------------===// 8 9 #include "mlir/Dialect/Arithmetic/IR/Arithmetic.h" 10 #include "mlir/Dialect/Math/IR/Math.h" 11 #include "mlir/IR/Builders.h" 12 13 using namespace mlir; 14 using namespace mlir::math; 15 16 //===----------------------------------------------------------------------===// 17 // TableGen'd op method definitions 18 //===----------------------------------------------------------------------===// 19 20 #define GET_OP_CLASSES 21 #include "mlir/Dialect/Math/IR/MathOps.cpp.inc" 22 23 //===----------------------------------------------------------------------===// 24 // AbsOp folder 25 //===----------------------------------------------------------------------===// 26 27 OpFoldResult math::AbsOp::fold(ArrayRef<Attribute> operands) { 28 auto constOperand = operands.front(); 29 if (!constOperand) 30 return {}; 31 32 auto attr = constOperand.dyn_cast<FloatAttr>(); 33 if (!attr) 34 return {}; 35 36 auto ft = getType().cast<FloatType>(); 37 38 APFloat apf = attr.getValue(); 39 40 if (ft.getWidth() == 64) 41 return FloatAttr::get(getType(), fabs(apf.convertToDouble())); 42 43 if (ft.getWidth() == 32) 44 return FloatAttr::get(getType(), fabsf(apf.convertToFloat())); 45 46 return {}; 47 } 48 49 //===----------------------------------------------------------------------===// 50 // CeilOp folder 51 //===----------------------------------------------------------------------===// 52 53 OpFoldResult math::CeilOp::fold(ArrayRef<Attribute> operands) { 54 auto constOperand = operands.front(); 55 if (!constOperand) 56 return {}; 57 58 auto attr = constOperand.dyn_cast<FloatAttr>(); 59 if (!attr) 60 return {}; 61 62 APFloat sourceVal = attr.getValue(); 63 sourceVal.roundToIntegral(llvm::RoundingMode::TowardPositive); 64 65 return FloatAttr::get(getType(), sourceVal); 66 } 67 68 //===----------------------------------------------------------------------===// 69 // CopySignOp folder 70 //===----------------------------------------------------------------------===// 71 72 OpFoldResult math::CopySignOp::fold(ArrayRef<Attribute> operands) { 73 auto ft = getType().dyn_cast<FloatType>(); 74 if (!ft) 75 return {}; 76 77 APFloat vals[2]{APFloat(ft.getFloatSemantics()), 78 APFloat(ft.getFloatSemantics())}; 79 for (int i = 0; i < 2; ++i) { 80 if (!operands[i]) 81 return {}; 82 83 auto attr = operands[i].dyn_cast<FloatAttr>(); 84 if (!attr) 85 return {}; 86 87 vals[i] = attr.getValue(); 88 } 89 90 vals[0].copySign(vals[1]); 91 92 return FloatAttr::get(getType(), vals[0]); 93 } 94 95 //===----------------------------------------------------------------------===// 96 // CountLeadingZerosOp folder 97 //===----------------------------------------------------------------------===// 98 99 OpFoldResult math::CountLeadingZerosOp::fold(ArrayRef<Attribute> operands) { 100 auto constOperand = operands.front(); 101 if (!constOperand) 102 return {}; 103 104 auto attr = constOperand.dyn_cast<IntegerAttr>(); 105 if (!attr) 106 return {}; 107 108 return IntegerAttr::get(getType(), attr.getValue().countLeadingZeros()); 109 } 110 111 //===----------------------------------------------------------------------===// 112 // CountTrailingZerosOp folder 113 //===----------------------------------------------------------------------===// 114 115 OpFoldResult math::CountTrailingZerosOp::fold(ArrayRef<Attribute> operands) { 116 auto constOperand = operands.front(); 117 if (!constOperand) 118 return {}; 119 120 auto attr = constOperand.dyn_cast<IntegerAttr>(); 121 if (!attr) 122 return {}; 123 124 return IntegerAttr::get(getType(), attr.getValue().countTrailingZeros()); 125 } 126 127 //===----------------------------------------------------------------------===// 128 // CtPopOp folder 129 //===----------------------------------------------------------------------===// 130 131 OpFoldResult math::CtPopOp::fold(ArrayRef<Attribute> operands) { 132 auto constOperand = operands.front(); 133 if (!constOperand) 134 return {}; 135 136 auto attr = constOperand.dyn_cast<IntegerAttr>(); 137 if (!attr) 138 return {}; 139 140 return IntegerAttr::get(getType(), attr.getValue().countPopulation()); 141 } 142 143 //===----------------------------------------------------------------------===// 144 // Log2Op folder 145 //===----------------------------------------------------------------------===// 146 147 OpFoldResult math::Log2Op::fold(ArrayRef<Attribute> operands) { 148 auto constOperand = operands.front(); 149 if (!constOperand) 150 return {}; 151 152 auto attr = constOperand.dyn_cast<FloatAttr>(); 153 if (!attr) 154 return {}; 155 156 auto ft = getType().cast<FloatType>(); 157 158 APFloat apf = attr.getValue(); 159 160 if (apf.isNegative()) 161 return {}; 162 163 if (ft.getWidth() == 64) 164 return FloatAttr::get(getType(), log2(apf.convertToDouble())); 165 166 if (ft.getWidth() == 32) 167 return FloatAttr::get(getType(), log2f(apf.convertToFloat())); 168 169 return {}; 170 } 171 172 //===----------------------------------------------------------------------===// 173 // PowFOp folder 174 //===----------------------------------------------------------------------===// 175 176 OpFoldResult math::PowFOp::fold(ArrayRef<Attribute> operands) { 177 auto ft = getType().dyn_cast<FloatType>(); 178 if (!ft) 179 return {}; 180 181 APFloat vals[2]{APFloat(ft.getFloatSemantics()), 182 APFloat(ft.getFloatSemantics())}; 183 for (int i = 0; i < 2; ++i) { 184 if (!operands[i]) 185 return {}; 186 187 auto attr = operands[i].dyn_cast<FloatAttr>(); 188 if (!attr) 189 return {}; 190 191 vals[i] = attr.getValue(); 192 } 193 194 if (ft.getWidth() == 64) 195 return FloatAttr::get( 196 getType(), pow(vals[0].convertToDouble(), vals[1].convertToDouble())); 197 198 if (ft.getWidth() == 32) 199 return FloatAttr::get( 200 getType(), powf(vals[0].convertToFloat(), vals[1].convertToFloat())); 201 202 return {}; 203 } 204 205 OpFoldResult math::SqrtOp::fold(ArrayRef<Attribute> operands) { 206 auto constOperand = operands.front(); 207 if (!constOperand) 208 return {}; 209 210 auto attr = constOperand.dyn_cast<FloatAttr>(); 211 if (!attr) 212 return {}; 213 214 auto ft = getType().cast<FloatType>(); 215 216 APFloat apf = attr.getValue(); 217 218 if (apf.isNegative()) 219 return {}; 220 221 if (ft.getWidth() == 64) 222 return FloatAttr::get(getType(), sqrt(apf.convertToDouble())); 223 224 if (ft.getWidth() == 32) 225 return FloatAttr::get(getType(), sqrtf(apf.convertToFloat())); 226 227 return {}; 228 } 229 230 /// Materialize an integer or floating point constant. 231 Operation *math::MathDialect::materializeConstant(OpBuilder &builder, 232 Attribute value, Type type, 233 Location loc) { 234 return builder.create<arith::ConstantOp>(loc, value, type); 235 } 236