1 //===- ArithmeticOps.cpp - MLIR Arithmetic dialect ops implementation -----===// 2 // 3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. 4 // See https://llvm.org/LICENSE.txt for license information. 5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception 6 // 7 //===----------------------------------------------------------------------===// 8 9 #include <utility> 10 11 #include "mlir/Dialect/Arithmetic/IR/Arithmetic.h" 12 #include "mlir/Dialect/CommonFolders.h" 13 #include "mlir/IR/Builders.h" 14 #include "mlir/IR/Matchers.h" 15 #include "mlir/IR/OpImplementation.h" 16 #include "mlir/IR/PatternMatch.h" 17 #include "mlir/IR/TypeUtilities.h" 18 19 #include "llvm/ADT/APSInt.h" 20 21 using namespace mlir; 22 using namespace mlir::arith; 23 24 //===----------------------------------------------------------------------===// 25 // Pattern helpers 26 //===----------------------------------------------------------------------===// 27 28 static IntegerAttr addIntegerAttrs(PatternRewriter &builder, Value res, 29 Attribute lhs, Attribute rhs) { 30 return builder.getIntegerAttr(res.getType(), 31 lhs.cast<IntegerAttr>().getInt() + 32 rhs.cast<IntegerAttr>().getInt()); 33 } 34 35 static IntegerAttr subIntegerAttrs(PatternRewriter &builder, Value res, 36 Attribute lhs, Attribute rhs) { 37 return builder.getIntegerAttr(res.getType(), 38 lhs.cast<IntegerAttr>().getInt() - 39 rhs.cast<IntegerAttr>().getInt()); 40 } 41 42 /// Invert an integer comparison predicate. 43 static arith::CmpIPredicate invertPredicate(arith::CmpIPredicate pred) { 44 switch (pred) { 45 case arith::CmpIPredicate::eq: 46 return arith::CmpIPredicate::ne; 47 case arith::CmpIPredicate::ne: 48 return arith::CmpIPredicate::eq; 49 case arith::CmpIPredicate::slt: 50 return arith::CmpIPredicate::sge; 51 case arith::CmpIPredicate::sle: 52 return arith::CmpIPredicate::sgt; 53 case arith::CmpIPredicate::sgt: 54 return arith::CmpIPredicate::sle; 55 case arith::CmpIPredicate::sge: 56 return arith::CmpIPredicate::slt; 57 case arith::CmpIPredicate::ult: 58 return arith::CmpIPredicate::uge; 59 case arith::CmpIPredicate::ule: 60 return arith::CmpIPredicate::ugt; 61 case arith::CmpIPredicate::ugt: 62 return arith::CmpIPredicate::ule; 63 case arith::CmpIPredicate::uge: 64 return arith::CmpIPredicate::ult; 65 } 66 llvm_unreachable("unknown cmpi predicate kind"); 67 } 68 69 static arith::CmpIPredicateAttr invertPredicate(arith::CmpIPredicateAttr pred) { 70 return arith::CmpIPredicateAttr::get(pred.getContext(), 71 invertPredicate(pred.getValue())); 72 } 73 74 //===----------------------------------------------------------------------===// 75 // TableGen'd canonicalization patterns 76 //===----------------------------------------------------------------------===// 77 78 namespace { 79 #include "ArithmeticCanonicalization.inc" 80 } // namespace 81 82 //===----------------------------------------------------------------------===// 83 // ConstantOp 84 //===----------------------------------------------------------------------===// 85 86 void arith::ConstantOp::getAsmResultNames( 87 function_ref<void(Value, StringRef)> setNameFn) { 88 auto type = getType(); 89 if (auto intCst = getValue().dyn_cast<IntegerAttr>()) { 90 auto intType = type.dyn_cast<IntegerType>(); 91 92 // Sugar i1 constants with 'true' and 'false'. 93 if (intType && intType.getWidth() == 1) 94 return setNameFn(getResult(), (intCst.getInt() ? "true" : "false")); 95 96 // Otherwise, build a compex name with the value and type. 97 SmallString<32> specialNameBuffer; 98 llvm::raw_svector_ostream specialName(specialNameBuffer); 99 specialName << 'c' << intCst.getInt(); 100 if (intType) 101 specialName << '_' << type; 102 setNameFn(getResult(), specialName.str()); 103 } else { 104 setNameFn(getResult(), "cst"); 105 } 106 } 107 108 /// TODO: disallow arith.constant to return anything other than signless integer 109 /// or float like. 110 static LogicalResult verify(arith::ConstantOp op) { 111 auto type = op.getType(); 112 // The value's type must match the return type. 113 if (op.getValue().getType() != type) { 114 return op.emitOpError() << "value type " << op.getValue().getType() 115 << " must match return type: " << type; 116 } 117 // Integer values must be signless. 118 if (type.isa<IntegerType>() && !type.cast<IntegerType>().isSignless()) 119 return op.emitOpError("integer return type must be signless"); 120 // Any float or elements attribute are acceptable. 121 if (!op.getValue().isa<IntegerAttr, FloatAttr, ElementsAttr>()) { 122 return op.emitOpError( 123 "value must be an integer, float, or elements attribute"); 124 } 125 return success(); 126 } 127 128 bool arith::ConstantOp::isBuildableWith(Attribute value, Type type) { 129 // The value's type must be the same as the provided type. 130 if (value.getType() != type) 131 return false; 132 // Integer values must be signless. 133 if (type.isa<IntegerType>() && !type.cast<IntegerType>().isSignless()) 134 return false; 135 // Integer, float, and element attributes are buildable. 136 return value.isa<IntegerAttr, FloatAttr, ElementsAttr>(); 137 } 138 139 OpFoldResult arith::ConstantOp::fold(ArrayRef<Attribute> operands) { 140 return getValue(); 141 } 142 143 void arith::ConstantIntOp::build(OpBuilder &builder, OperationState &result, 144 int64_t value, unsigned width) { 145 auto type = builder.getIntegerType(width); 146 arith::ConstantOp::build(builder, result, type, 147 builder.getIntegerAttr(type, value)); 148 } 149 150 void arith::ConstantIntOp::build(OpBuilder &builder, OperationState &result, 151 int64_t value, Type type) { 152 assert(type.isSignlessInteger() && 153 "ConstantIntOp can only have signless integer type values"); 154 arith::ConstantOp::build(builder, result, type, 155 builder.getIntegerAttr(type, value)); 156 } 157 158 bool arith::ConstantIntOp::classof(Operation *op) { 159 if (auto constOp = dyn_cast_or_null<arith::ConstantOp>(op)) 160 return constOp.getType().isSignlessInteger(); 161 return false; 162 } 163 164 void arith::ConstantFloatOp::build(OpBuilder &builder, OperationState &result, 165 const APFloat &value, FloatType type) { 166 arith::ConstantOp::build(builder, result, type, 167 builder.getFloatAttr(type, value)); 168 } 169 170 bool arith::ConstantFloatOp::classof(Operation *op) { 171 if (auto constOp = dyn_cast_or_null<arith::ConstantOp>(op)) 172 return constOp.getType().isa<FloatType>(); 173 return false; 174 } 175 176 void arith::ConstantIndexOp::build(OpBuilder &builder, OperationState &result, 177 int64_t value) { 178 arith::ConstantOp::build(builder, result, builder.getIndexType(), 179 builder.getIndexAttr(value)); 180 } 181 182 bool arith::ConstantIndexOp::classof(Operation *op) { 183 if (auto constOp = dyn_cast_or_null<arith::ConstantOp>(op)) 184 return constOp.getType().isIndex(); 185 return false; 186 } 187 188 //===----------------------------------------------------------------------===// 189 // AddIOp 190 //===----------------------------------------------------------------------===// 191 192 OpFoldResult arith::AddIOp::fold(ArrayRef<Attribute> operands) { 193 // addi(x, 0) -> x 194 if (matchPattern(getRhs(), m_Zero())) 195 return getLhs(); 196 197 return constFoldBinaryOp<IntegerAttr>( 198 operands, [](APInt a, const APInt &b) { return std::move(a) + b; }); 199 } 200 201 void arith::AddIOp::getCanonicalizationPatterns( 202 OwningRewritePatternList &patterns, MLIRContext *context) { 203 patterns.insert<AddIAddConstant, AddISubConstantRHS, AddISubConstantLHS>( 204 context); 205 } 206 207 //===----------------------------------------------------------------------===// 208 // SubIOp 209 //===----------------------------------------------------------------------===// 210 211 OpFoldResult arith::SubIOp::fold(ArrayRef<Attribute> operands) { 212 // subi(x,x) -> 0 213 if (getOperand(0) == getOperand(1)) 214 return Builder(getContext()).getZeroAttr(getType()); 215 // subi(x,0) -> x 216 if (matchPattern(getRhs(), m_Zero())) 217 return getLhs(); 218 219 return constFoldBinaryOp<IntegerAttr>( 220 operands, [](APInt a, const APInt &b) { return std::move(a) - b; }); 221 } 222 223 void arith::SubIOp::getCanonicalizationPatterns( 224 OwningRewritePatternList &patterns, MLIRContext *context) { 225 patterns.insert<SubIRHSAddConstant, SubILHSAddConstant, SubIRHSSubConstantRHS, 226 SubIRHSSubConstantLHS, SubILHSSubConstantRHS, 227 SubILHSSubConstantLHS>(context); 228 } 229 230 //===----------------------------------------------------------------------===// 231 // MulIOp 232 //===----------------------------------------------------------------------===// 233 234 OpFoldResult arith::MulIOp::fold(ArrayRef<Attribute> operands) { 235 // muli(x, 0) -> 0 236 if (matchPattern(getRhs(), m_Zero())) 237 return getRhs(); 238 // muli(x, 1) -> x 239 if (matchPattern(getRhs(), m_One())) 240 return getOperand(0); 241 // TODO: Handle the overflow case. 242 243 // default folder 244 return constFoldBinaryOp<IntegerAttr>( 245 operands, [](const APInt &a, const APInt &b) { return a * b; }); 246 } 247 248 //===----------------------------------------------------------------------===// 249 // DivUIOp 250 //===----------------------------------------------------------------------===// 251 252 OpFoldResult arith::DivUIOp::fold(ArrayRef<Attribute> operands) { 253 // Don't fold if it would require a division by zero. 254 bool div0 = false; 255 auto result = 256 constFoldBinaryOp<IntegerAttr>(operands, [&](APInt a, const APInt &b) { 257 if (div0 || !b) { 258 div0 = true; 259 return a; 260 } 261 return a.udiv(b); 262 }); 263 264 // Fold out division by one. Assumes all tensors of all ones are splats. 265 if (auto rhs = operands[1].dyn_cast_or_null<IntegerAttr>()) { 266 if (rhs.getValue() == 1) 267 return getLhs(); 268 } else if (auto rhs = operands[1].dyn_cast_or_null<SplatElementsAttr>()) { 269 if (rhs.getSplatValue<IntegerAttr>().getValue() == 1) 270 return getLhs(); 271 } 272 273 return div0 ? Attribute() : result; 274 } 275 276 //===----------------------------------------------------------------------===// 277 // DivSIOp 278 //===----------------------------------------------------------------------===// 279 280 OpFoldResult arith::DivSIOp::fold(ArrayRef<Attribute> operands) { 281 // Don't fold if it would overflow or if it requires a division by zero. 282 bool overflowOrDiv0 = false; 283 auto result = 284 constFoldBinaryOp<IntegerAttr>(operands, [&](APInt a, const APInt &b) { 285 if (overflowOrDiv0 || !b) { 286 overflowOrDiv0 = true; 287 return a; 288 } 289 return a.sdiv_ov(b, overflowOrDiv0); 290 }); 291 292 // Fold out division by one. Assumes all tensors of all ones are splats. 293 if (auto rhs = operands[1].dyn_cast_or_null<IntegerAttr>()) { 294 if (rhs.getValue() == 1) 295 return getLhs(); 296 } else if (auto rhs = operands[1].dyn_cast_or_null<SplatElementsAttr>()) { 297 if (rhs.getSplatValue<IntegerAttr>().getValue() == 1) 298 return getLhs(); 299 } 300 301 return overflowOrDiv0 ? Attribute() : result; 302 } 303 304 //===----------------------------------------------------------------------===// 305 // Ceil and floor division folding helpers 306 //===----------------------------------------------------------------------===// 307 308 static APInt signedCeilNonnegInputs(const APInt &a, const APInt &b, 309 bool &overflow) { 310 // Returns (a-1)/b + 1 311 APInt one(a.getBitWidth(), 1, true); // Signed value 1. 312 APInt val = a.ssub_ov(one, overflow).sdiv_ov(b, overflow); 313 return val.sadd_ov(one, overflow); 314 } 315 316 //===----------------------------------------------------------------------===// 317 // CeilDivUIOp 318 //===----------------------------------------------------------------------===// 319 320 OpFoldResult arith::CeilDivUIOp::fold(ArrayRef<Attribute> operands) { 321 bool overflowOrDiv0 = false; 322 auto result = 323 constFoldBinaryOp<IntegerAttr>(operands, [&](APInt a, const APInt &b) { 324 if (overflowOrDiv0 || !b) { 325 overflowOrDiv0 = true; 326 return a; 327 } 328 APInt quotient = a.udiv(b); 329 if (!a.urem(b)) 330 return quotient; 331 APInt one(a.getBitWidth(), 1, true); 332 return quotient.uadd_ov(one, overflowOrDiv0); 333 }); 334 // Fold out ceil division by one. Assumes all tensors of all ones are 335 // splats. 336 if (auto rhs = operands[1].dyn_cast_or_null<IntegerAttr>()) { 337 if (rhs.getValue() == 1) 338 return getLhs(); 339 } else if (auto rhs = operands[1].dyn_cast_or_null<SplatElementsAttr>()) { 340 if (rhs.getSplatValue<IntegerAttr>().getValue() == 1) 341 return getLhs(); 342 } 343 344 return overflowOrDiv0 ? Attribute() : result; 345 } 346 347 //===----------------------------------------------------------------------===// 348 // CeilDivSIOp 349 //===----------------------------------------------------------------------===// 350 351 OpFoldResult arith::CeilDivSIOp::fold(ArrayRef<Attribute> operands) { 352 // Don't fold if it would overflow or if it requires a division by zero. 353 bool overflowOrDiv0 = false; 354 auto result = 355 constFoldBinaryOp<IntegerAttr>(operands, [&](APInt a, const APInt &b) { 356 if (overflowOrDiv0 || !b) { 357 overflowOrDiv0 = true; 358 return a; 359 } 360 if (!a) 361 return a; 362 // After this point we know that neither a or b are zero. 363 unsigned bits = a.getBitWidth(); 364 APInt zero = APInt::getZero(bits); 365 bool aGtZero = a.sgt(zero); 366 bool bGtZero = b.sgt(zero); 367 if (aGtZero && bGtZero) { 368 // Both positive, return ceil(a, b). 369 return signedCeilNonnegInputs(a, b, overflowOrDiv0); 370 } 371 if (!aGtZero && !bGtZero) { 372 // Both negative, return ceil(-a, -b). 373 APInt posA = zero.ssub_ov(a, overflowOrDiv0); 374 APInt posB = zero.ssub_ov(b, overflowOrDiv0); 375 return signedCeilNonnegInputs(posA, posB, overflowOrDiv0); 376 } 377 if (!aGtZero && bGtZero) { 378 // A is negative, b is positive, return - ( -a / b). 379 APInt posA = zero.ssub_ov(a, overflowOrDiv0); 380 APInt div = posA.sdiv_ov(b, overflowOrDiv0); 381 return zero.ssub_ov(div, overflowOrDiv0); 382 } 383 // A is positive, b is negative, return - (a / -b). 384 APInt posB = zero.ssub_ov(b, overflowOrDiv0); 385 APInt div = a.sdiv_ov(posB, overflowOrDiv0); 386 return zero.ssub_ov(div, overflowOrDiv0); 387 }); 388 389 // Fold out ceil division by one. Assumes all tensors of all ones are 390 // splats. 391 if (auto rhs = operands[1].dyn_cast_or_null<IntegerAttr>()) { 392 if (rhs.getValue() == 1) 393 return getLhs(); 394 } else if (auto rhs = operands[1].dyn_cast_or_null<SplatElementsAttr>()) { 395 if (rhs.getSplatValue<IntegerAttr>().getValue() == 1) 396 return getLhs(); 397 } 398 399 return overflowOrDiv0 ? Attribute() : result; 400 } 401 402 //===----------------------------------------------------------------------===// 403 // FloorDivSIOp 404 //===----------------------------------------------------------------------===// 405 406 OpFoldResult arith::FloorDivSIOp::fold(ArrayRef<Attribute> operands) { 407 // Don't fold if it would overflow or if it requires a division by zero. 408 bool overflowOrDiv0 = false; 409 auto result = 410 constFoldBinaryOp<IntegerAttr>(operands, [&](APInt a, const APInt &b) { 411 if (overflowOrDiv0 || !b) { 412 overflowOrDiv0 = true; 413 return a; 414 } 415 if (!a) 416 return a; 417 // After this point we know that neither a or b are zero. 418 unsigned bits = a.getBitWidth(); 419 APInt zero = APInt::getZero(bits); 420 bool aGtZero = a.sgt(zero); 421 bool bGtZero = b.sgt(zero); 422 if (aGtZero && bGtZero) { 423 // Both positive, return a / b. 424 return a.sdiv_ov(b, overflowOrDiv0); 425 } 426 if (!aGtZero && !bGtZero) { 427 // Both negative, return -a / -b. 428 APInt posA = zero.ssub_ov(a, overflowOrDiv0); 429 APInt posB = zero.ssub_ov(b, overflowOrDiv0); 430 return posA.sdiv_ov(posB, overflowOrDiv0); 431 } 432 if (!aGtZero && bGtZero) { 433 // A is negative, b is positive, return - ceil(-a, b). 434 APInt posA = zero.ssub_ov(a, overflowOrDiv0); 435 APInt ceil = signedCeilNonnegInputs(posA, b, overflowOrDiv0); 436 return zero.ssub_ov(ceil, overflowOrDiv0); 437 } 438 // A is positive, b is negative, return - ceil(a, -b). 439 APInt posB = zero.ssub_ov(b, overflowOrDiv0); 440 APInt ceil = signedCeilNonnegInputs(a, posB, overflowOrDiv0); 441 return zero.ssub_ov(ceil, overflowOrDiv0); 442 }); 443 444 // Fold out floor division by one. Assumes all tensors of all ones are 445 // splats. 446 if (auto rhs = operands[1].dyn_cast_or_null<IntegerAttr>()) { 447 if (rhs.getValue() == 1) 448 return getLhs(); 449 } else if (auto rhs = operands[1].dyn_cast_or_null<SplatElementsAttr>()) { 450 if (rhs.getSplatValue<IntegerAttr>().getValue() == 1) 451 return getLhs(); 452 } 453 454 return overflowOrDiv0 ? Attribute() : result; 455 } 456 457 //===----------------------------------------------------------------------===// 458 // RemUIOp 459 //===----------------------------------------------------------------------===// 460 461 OpFoldResult arith::RemUIOp::fold(ArrayRef<Attribute> operands) { 462 auto rhs = operands.back().dyn_cast_or_null<IntegerAttr>(); 463 if (!rhs) 464 return {}; 465 auto rhsValue = rhs.getValue(); 466 467 // x % 1 = 0 468 if (rhsValue.isOneValue()) 469 return IntegerAttr::get(rhs.getType(), APInt(rhsValue.getBitWidth(), 0)); 470 471 // Don't fold if it requires division by zero. 472 if (rhsValue.isNullValue()) 473 return {}; 474 475 auto lhs = operands.front().dyn_cast_or_null<IntegerAttr>(); 476 if (!lhs) 477 return {}; 478 return IntegerAttr::get(lhs.getType(), lhs.getValue().urem(rhsValue)); 479 } 480 481 //===----------------------------------------------------------------------===// 482 // RemSIOp 483 //===----------------------------------------------------------------------===// 484 485 OpFoldResult arith::RemSIOp::fold(ArrayRef<Attribute> operands) { 486 auto rhs = operands.back().dyn_cast_or_null<IntegerAttr>(); 487 if (!rhs) 488 return {}; 489 auto rhsValue = rhs.getValue(); 490 491 // x % 1 = 0 492 if (rhsValue.isOneValue()) 493 return IntegerAttr::get(rhs.getType(), APInt(rhsValue.getBitWidth(), 0)); 494 495 // Don't fold if it requires division by zero. 496 if (rhsValue.isNullValue()) 497 return {}; 498 499 auto lhs = operands.front().dyn_cast_or_null<IntegerAttr>(); 500 if (!lhs) 501 return {}; 502 return IntegerAttr::get(lhs.getType(), lhs.getValue().srem(rhsValue)); 503 } 504 505 //===----------------------------------------------------------------------===// 506 // AndIOp 507 //===----------------------------------------------------------------------===// 508 509 OpFoldResult arith::AndIOp::fold(ArrayRef<Attribute> operands) { 510 /// and(x, 0) -> 0 511 if (matchPattern(getRhs(), m_Zero())) 512 return getRhs(); 513 /// and(x, allOnes) -> x 514 APInt intValue; 515 if (matchPattern(getRhs(), m_ConstantInt(&intValue)) && intValue.isAllOnes()) 516 return getLhs(); 517 518 return constFoldBinaryOp<IntegerAttr>( 519 operands, [](APInt a, const APInt &b) { return std::move(a) & b; }); 520 } 521 522 //===----------------------------------------------------------------------===// 523 // OrIOp 524 //===----------------------------------------------------------------------===// 525 526 OpFoldResult arith::OrIOp::fold(ArrayRef<Attribute> operands) { 527 /// or(x, 0) -> x 528 if (matchPattern(getRhs(), m_Zero())) 529 return getLhs(); 530 /// or(x, <all ones>) -> <all ones> 531 if (auto rhsAttr = operands[1].dyn_cast_or_null<IntegerAttr>()) 532 if (rhsAttr.getValue().isAllOnes()) 533 return rhsAttr; 534 535 return constFoldBinaryOp<IntegerAttr>( 536 operands, [](APInt a, const APInt &b) { return std::move(a) | b; }); 537 } 538 539 //===----------------------------------------------------------------------===// 540 // XOrIOp 541 //===----------------------------------------------------------------------===// 542 543 OpFoldResult arith::XOrIOp::fold(ArrayRef<Attribute> operands) { 544 /// xor(x, 0) -> x 545 if (matchPattern(getRhs(), m_Zero())) 546 return getLhs(); 547 /// xor(x, x) -> 0 548 if (getLhs() == getRhs()) 549 return Builder(getContext()).getZeroAttr(getType()); 550 551 return constFoldBinaryOp<IntegerAttr>( 552 operands, [](APInt a, const APInt &b) { return std::move(a) ^ b; }); 553 } 554 555 void arith::XOrIOp::getCanonicalizationPatterns( 556 OwningRewritePatternList &patterns, MLIRContext *context) { 557 patterns.insert<XOrINotCmpI>(context); 558 } 559 560 //===----------------------------------------------------------------------===// 561 // AddFOp 562 //===----------------------------------------------------------------------===// 563 564 OpFoldResult arith::AddFOp::fold(ArrayRef<Attribute> operands) { 565 return constFoldBinaryOp<FloatAttr>( 566 operands, [](const APFloat &a, const APFloat &b) { return a + b; }); 567 } 568 569 //===----------------------------------------------------------------------===// 570 // SubFOp 571 //===----------------------------------------------------------------------===// 572 573 OpFoldResult arith::SubFOp::fold(ArrayRef<Attribute> operands) { 574 return constFoldBinaryOp<FloatAttr>( 575 operands, [](const APFloat &a, const APFloat &b) { return a - b; }); 576 } 577 578 //===----------------------------------------------------------------------===// 579 // MaxSIOp 580 //===----------------------------------------------------------------------===// 581 582 OpFoldResult MaxSIOp::fold(ArrayRef<Attribute> operands) { 583 assert(operands.size() == 2 && "binary operation takes two operands"); 584 585 // maxsi(x,x) -> x 586 if (getLhs() == getRhs()) 587 return getRhs(); 588 589 APInt intValue; 590 // maxsi(x,MAX_INT) -> MAX_INT 591 if (matchPattern(getRhs(), m_ConstantInt(&intValue)) && 592 intValue.isMaxSignedValue()) 593 return getRhs(); 594 595 // maxsi(x, MIN_INT) -> x 596 if (matchPattern(getRhs(), m_ConstantInt(&intValue)) && 597 intValue.isMinSignedValue()) 598 return getLhs(); 599 600 return constFoldBinaryOp<IntegerAttr>(operands, 601 [](const APInt &a, const APInt &b) { 602 return llvm::APIntOps::smax(a, b); 603 }); 604 } 605 606 //===----------------------------------------------------------------------===// 607 // MaxUIOp 608 //===----------------------------------------------------------------------===// 609 610 OpFoldResult MaxUIOp::fold(ArrayRef<Attribute> operands) { 611 assert(operands.size() == 2 && "binary operation takes two operands"); 612 613 // maxui(x,x) -> x 614 if (getLhs() == getRhs()) 615 return getRhs(); 616 617 APInt intValue; 618 // maxui(x,MAX_INT) -> MAX_INT 619 if (matchPattern(getRhs(), m_ConstantInt(&intValue)) && intValue.isMaxValue()) 620 return getRhs(); 621 622 // maxui(x, MIN_INT) -> x 623 if (matchPattern(getRhs(), m_ConstantInt(&intValue)) && intValue.isMinValue()) 624 return getLhs(); 625 626 return constFoldBinaryOp<IntegerAttr>(operands, 627 [](const APInt &a, const APInt &b) { 628 return llvm::APIntOps::umax(a, b); 629 }); 630 } 631 632 //===----------------------------------------------------------------------===// 633 // MinSIOp 634 //===----------------------------------------------------------------------===// 635 636 OpFoldResult MinSIOp::fold(ArrayRef<Attribute> operands) { 637 assert(operands.size() == 2 && "binary operation takes two operands"); 638 639 // minsi(x,x) -> x 640 if (getLhs() == getRhs()) 641 return getRhs(); 642 643 APInt intValue; 644 // minsi(x,MIN_INT) -> MIN_INT 645 if (matchPattern(getRhs(), m_ConstantInt(&intValue)) && 646 intValue.isMinSignedValue()) 647 return getRhs(); 648 649 // minsi(x, MAX_INT) -> x 650 if (matchPattern(getRhs(), m_ConstantInt(&intValue)) && 651 intValue.isMaxSignedValue()) 652 return getLhs(); 653 654 return constFoldBinaryOp<IntegerAttr>(operands, 655 [](const APInt &a, const APInt &b) { 656 return llvm::APIntOps::smin(a, b); 657 }); 658 } 659 660 //===----------------------------------------------------------------------===// 661 // MinUIOp 662 //===----------------------------------------------------------------------===// 663 664 OpFoldResult MinUIOp::fold(ArrayRef<Attribute> operands) { 665 assert(operands.size() == 2 && "binary operation takes two operands"); 666 667 // minui(x,x) -> x 668 if (getLhs() == getRhs()) 669 return getRhs(); 670 671 APInt intValue; 672 // minui(x,MIN_INT) -> MIN_INT 673 if (matchPattern(getRhs(), m_ConstantInt(&intValue)) && intValue.isMinValue()) 674 return getRhs(); 675 676 // minui(x, MAX_INT) -> x 677 if (matchPattern(getRhs(), m_ConstantInt(&intValue)) && intValue.isMaxValue()) 678 return getLhs(); 679 680 return constFoldBinaryOp<IntegerAttr>(operands, 681 [](const APInt &a, const APInt &b) { 682 return llvm::APIntOps::umin(a, b); 683 }); 684 } 685 686 //===----------------------------------------------------------------------===// 687 // MulFOp 688 //===----------------------------------------------------------------------===// 689 690 OpFoldResult arith::MulFOp::fold(ArrayRef<Attribute> operands) { 691 return constFoldBinaryOp<FloatAttr>( 692 operands, [](const APFloat &a, const APFloat &b) { return a * b; }); 693 } 694 695 //===----------------------------------------------------------------------===// 696 // DivFOp 697 //===----------------------------------------------------------------------===// 698 699 OpFoldResult arith::DivFOp::fold(ArrayRef<Attribute> operands) { 700 return constFoldBinaryOp<FloatAttr>( 701 operands, [](const APFloat &a, const APFloat &b) { return a / b; }); 702 } 703 704 //===----------------------------------------------------------------------===// 705 // Utility functions for verifying cast ops 706 //===----------------------------------------------------------------------===// 707 708 template <typename... Types> 709 using type_list = std::tuple<Types...> *; 710 711 /// Returns a non-null type only if the provided type is one of the allowed 712 /// types or one of the allowed shaped types of the allowed types. Returns the 713 /// element type if a valid shaped type is provided. 714 template <typename... ShapedTypes, typename... ElementTypes> 715 static Type getUnderlyingType(Type type, type_list<ShapedTypes...>, 716 type_list<ElementTypes...>) { 717 if (type.isa<ShapedType>() && !type.isa<ShapedTypes...>()) 718 return {}; 719 720 auto underlyingType = getElementTypeOrSelf(type); 721 if (!underlyingType.isa<ElementTypes...>()) 722 return {}; 723 724 return underlyingType; 725 } 726 727 /// Get allowed underlying types for vectors and tensors. 728 template <typename... ElementTypes> 729 static Type getTypeIfLike(Type type) { 730 return getUnderlyingType(type, type_list<VectorType, TensorType>(), 731 type_list<ElementTypes...>()); 732 } 733 734 /// Get allowed underlying types for vectors, tensors, and memrefs. 735 template <typename... ElementTypes> 736 static Type getTypeIfLikeOrMemRef(Type type) { 737 return getUnderlyingType(type, 738 type_list<VectorType, TensorType, MemRefType>(), 739 type_list<ElementTypes...>()); 740 } 741 742 static bool areValidCastInputsAndOutputs(TypeRange inputs, TypeRange outputs) { 743 return inputs.size() == 1 && outputs.size() == 1 && 744 succeeded(verifyCompatibleShapes(inputs.front(), outputs.front())); 745 } 746 747 //===----------------------------------------------------------------------===// 748 // Verifiers for integer and floating point extension/truncation ops 749 //===----------------------------------------------------------------------===// 750 751 // Extend ops can only extend to a wider type. 752 template <typename ValType, typename Op> 753 static LogicalResult verifyExtOp(Op op) { 754 Type srcType = getElementTypeOrSelf(op.getIn().getType()); 755 Type dstType = getElementTypeOrSelf(op.getType()); 756 757 if (srcType.cast<ValType>().getWidth() >= dstType.cast<ValType>().getWidth()) 758 return op.emitError("result type ") 759 << dstType << " must be wider than operand type " << srcType; 760 761 return success(); 762 } 763 764 // Truncate ops can only truncate to a shorter type. 765 template <typename ValType, typename Op> 766 static LogicalResult verifyTruncateOp(Op op) { 767 Type srcType = getElementTypeOrSelf(op.getIn().getType()); 768 Type dstType = getElementTypeOrSelf(op.getType()); 769 770 if (srcType.cast<ValType>().getWidth() <= dstType.cast<ValType>().getWidth()) 771 return op.emitError("result type ") 772 << dstType << " must be shorter than operand type " << srcType; 773 774 return success(); 775 } 776 777 /// Validate a cast that changes the width of a type. 778 template <template <typename> class WidthComparator, typename... ElementTypes> 779 static bool checkWidthChangeCast(TypeRange inputs, TypeRange outputs) { 780 if (!areValidCastInputsAndOutputs(inputs, outputs)) 781 return false; 782 783 auto srcType = getTypeIfLike<ElementTypes...>(inputs.front()); 784 auto dstType = getTypeIfLike<ElementTypes...>(outputs.front()); 785 if (!srcType || !dstType) 786 return false; 787 788 return WidthComparator<unsigned>()(dstType.getIntOrFloatBitWidth(), 789 srcType.getIntOrFloatBitWidth()); 790 } 791 792 //===----------------------------------------------------------------------===// 793 // ExtUIOp 794 //===----------------------------------------------------------------------===// 795 796 OpFoldResult arith::ExtUIOp::fold(ArrayRef<Attribute> operands) { 797 if (auto lhs = operands[0].dyn_cast_or_null<IntegerAttr>()) 798 return IntegerAttr::get( 799 getType(), lhs.getValue().zext(getType().getIntOrFloatBitWidth())); 800 801 if (auto lhs = getIn().getDefiningOp<ExtUIOp>()) { 802 getInMutable().assign(lhs.getIn()); 803 return getResult(); 804 } 805 806 return {}; 807 } 808 809 bool arith::ExtUIOp::areCastCompatible(TypeRange inputs, TypeRange outputs) { 810 return checkWidthChangeCast<std::greater, IntegerType>(inputs, outputs); 811 } 812 813 //===----------------------------------------------------------------------===// 814 // ExtSIOp 815 //===----------------------------------------------------------------------===// 816 817 OpFoldResult arith::ExtSIOp::fold(ArrayRef<Attribute> operands) { 818 if (auto lhs = operands[0].dyn_cast_or_null<IntegerAttr>()) 819 return IntegerAttr::get( 820 getType(), lhs.getValue().sext(getType().getIntOrFloatBitWidth())); 821 822 if (auto lhs = getIn().getDefiningOp<ExtSIOp>()) { 823 getInMutable().assign(lhs.getIn()); 824 return getResult(); 825 } 826 827 return {}; 828 } 829 830 bool arith::ExtSIOp::areCastCompatible(TypeRange inputs, TypeRange outputs) { 831 return checkWidthChangeCast<std::greater, IntegerType>(inputs, outputs); 832 } 833 834 void arith::ExtSIOp::getCanonicalizationPatterns( 835 OwningRewritePatternList &patterns, MLIRContext *context) { 836 patterns.insert<ExtSIOfExtUI>(context); 837 } 838 839 //===----------------------------------------------------------------------===// 840 // ExtFOp 841 //===----------------------------------------------------------------------===// 842 843 bool arith::ExtFOp::areCastCompatible(TypeRange inputs, TypeRange outputs) { 844 return checkWidthChangeCast<std::greater, FloatType>(inputs, outputs); 845 } 846 847 //===----------------------------------------------------------------------===// 848 // TruncIOp 849 //===----------------------------------------------------------------------===// 850 851 OpFoldResult arith::TruncIOp::fold(ArrayRef<Attribute> operands) { 852 // trunci(zexti(a)) -> a 853 // trunci(sexti(a)) -> a 854 if (matchPattern(getOperand(), m_Op<arith::ExtUIOp>()) || 855 matchPattern(getOperand(), m_Op<arith::ExtSIOp>())) 856 return getOperand().getDefiningOp()->getOperand(0); 857 858 assert(operands.size() == 1 && "unary operation takes one operand"); 859 860 if (!operands[0]) 861 return {}; 862 863 if (auto lhs = operands[0].dyn_cast<IntegerAttr>()) { 864 return IntegerAttr::get( 865 getType(), lhs.getValue().trunc(getType().getIntOrFloatBitWidth())); 866 } 867 868 return {}; 869 } 870 871 bool arith::TruncIOp::areCastCompatible(TypeRange inputs, TypeRange outputs) { 872 return checkWidthChangeCast<std::less, IntegerType>(inputs, outputs); 873 } 874 875 //===----------------------------------------------------------------------===// 876 // TruncFOp 877 //===----------------------------------------------------------------------===// 878 879 /// Perform safe const propagation for truncf, i.e. only propagate if FP value 880 /// can be represented without precision loss or rounding. 881 OpFoldResult arith::TruncFOp::fold(ArrayRef<Attribute> operands) { 882 assert(operands.size() == 1 && "unary operation takes one operand"); 883 884 auto constOperand = operands.front(); 885 if (!constOperand || !constOperand.isa<FloatAttr>()) 886 return {}; 887 888 // Convert to target type via 'double'. 889 double sourceValue = 890 constOperand.dyn_cast<FloatAttr>().getValue().convertToDouble(); 891 auto targetAttr = FloatAttr::get(getType(), sourceValue); 892 893 // Propagate if constant's value does not change after truncation. 894 if (sourceValue == targetAttr.getValue().convertToDouble()) 895 return targetAttr; 896 897 return {}; 898 } 899 900 bool arith::TruncFOp::areCastCompatible(TypeRange inputs, TypeRange outputs) { 901 return checkWidthChangeCast<std::less, FloatType>(inputs, outputs); 902 } 903 904 //===----------------------------------------------------------------------===// 905 // Verifiers for casts between integers and floats. 906 //===----------------------------------------------------------------------===// 907 908 template <typename From, typename To> 909 static bool checkIntFloatCast(TypeRange inputs, TypeRange outputs) { 910 if (!areValidCastInputsAndOutputs(inputs, outputs)) 911 return false; 912 913 auto srcType = getTypeIfLike<From>(inputs.front()); 914 auto dstType = getTypeIfLike<To>(outputs.back()); 915 916 return srcType && dstType; 917 } 918 919 //===----------------------------------------------------------------------===// 920 // UIToFPOp 921 //===----------------------------------------------------------------------===// 922 923 bool arith::UIToFPOp::areCastCompatible(TypeRange inputs, TypeRange outputs) { 924 return checkIntFloatCast<IntegerType, FloatType>(inputs, outputs); 925 } 926 927 OpFoldResult arith::UIToFPOp::fold(ArrayRef<Attribute> operands) { 928 if (auto lhs = operands[0].dyn_cast_or_null<IntegerAttr>()) { 929 const APInt &api = lhs.getValue(); 930 FloatType floatTy = getType().cast<FloatType>(); 931 APFloat apf(floatTy.getFloatSemantics(), 932 APInt::getZero(floatTy.getWidth())); 933 apf.convertFromAPInt(api, /*IsSigned=*/false, APFloat::rmNearestTiesToEven); 934 return FloatAttr::get(floatTy, apf); 935 } 936 return {}; 937 } 938 939 //===----------------------------------------------------------------------===// 940 // SIToFPOp 941 //===----------------------------------------------------------------------===// 942 943 bool arith::SIToFPOp::areCastCompatible(TypeRange inputs, TypeRange outputs) { 944 return checkIntFloatCast<IntegerType, FloatType>(inputs, outputs); 945 } 946 947 OpFoldResult arith::SIToFPOp::fold(ArrayRef<Attribute> operands) { 948 if (auto lhs = operands[0].dyn_cast_or_null<IntegerAttr>()) { 949 const APInt &api = lhs.getValue(); 950 FloatType floatTy = getType().cast<FloatType>(); 951 APFloat apf(floatTy.getFloatSemantics(), 952 APInt::getZero(floatTy.getWidth())); 953 apf.convertFromAPInt(api, /*IsSigned=*/true, APFloat::rmNearestTiesToEven); 954 return FloatAttr::get(floatTy, apf); 955 } 956 return {}; 957 } 958 //===----------------------------------------------------------------------===// 959 // FPToUIOp 960 //===----------------------------------------------------------------------===// 961 962 bool arith::FPToUIOp::areCastCompatible(TypeRange inputs, TypeRange outputs) { 963 return checkIntFloatCast<FloatType, IntegerType>(inputs, outputs); 964 } 965 966 OpFoldResult arith::FPToUIOp::fold(ArrayRef<Attribute> operands) { 967 if (auto lhs = operands[0].dyn_cast_or_null<FloatAttr>()) { 968 const APFloat &apf = lhs.getValue(); 969 IntegerType intTy = getType().cast<IntegerType>(); 970 bool ignored; 971 APSInt api(intTy.getWidth(), /*isUnsigned=*/true); 972 if (APFloat::opInvalidOp == 973 apf.convertToInteger(api, APFloat::rmTowardZero, &ignored)) { 974 // Undefined behavior invoked - the destination type can't represent 975 // the input constant. 976 return {}; 977 } 978 return IntegerAttr::get(getType(), api); 979 } 980 981 return {}; 982 } 983 984 //===----------------------------------------------------------------------===// 985 // FPToSIOp 986 //===----------------------------------------------------------------------===// 987 988 bool arith::FPToSIOp::areCastCompatible(TypeRange inputs, TypeRange outputs) { 989 return checkIntFloatCast<FloatType, IntegerType>(inputs, outputs); 990 } 991 992 OpFoldResult arith::FPToSIOp::fold(ArrayRef<Attribute> operands) { 993 if (auto lhs = operands[0].dyn_cast_or_null<FloatAttr>()) { 994 const APFloat &apf = lhs.getValue(); 995 IntegerType intTy = getType().cast<IntegerType>(); 996 bool ignored; 997 APSInt api(intTy.getWidth(), /*isUnsigned=*/false); 998 if (APFloat::opInvalidOp == 999 apf.convertToInteger(api, APFloat::rmTowardZero, &ignored)) { 1000 // Undefined behavior invoked - the destination type can't represent 1001 // the input constant. 1002 return {}; 1003 } 1004 return IntegerAttr::get(getType(), api); 1005 } 1006 1007 return {}; 1008 } 1009 1010 //===----------------------------------------------------------------------===// 1011 // IndexCastOp 1012 //===----------------------------------------------------------------------===// 1013 1014 bool arith::IndexCastOp::areCastCompatible(TypeRange inputs, 1015 TypeRange outputs) { 1016 if (!areValidCastInputsAndOutputs(inputs, outputs)) 1017 return false; 1018 1019 auto srcType = getTypeIfLikeOrMemRef<IntegerType, IndexType>(inputs.front()); 1020 auto dstType = getTypeIfLikeOrMemRef<IntegerType, IndexType>(outputs.front()); 1021 if (!srcType || !dstType) 1022 return false; 1023 1024 return (srcType.isIndex() && dstType.isSignlessInteger()) || 1025 (srcType.isSignlessInteger() && dstType.isIndex()); 1026 } 1027 1028 OpFoldResult arith::IndexCastOp::fold(ArrayRef<Attribute> operands) { 1029 // index_cast(constant) -> constant 1030 // A little hack because we go through int. Otherwise, the size of the 1031 // constant might need to change. 1032 if (auto value = operands[0].dyn_cast_or_null<IntegerAttr>()) 1033 return IntegerAttr::get(getType(), value.getInt()); 1034 1035 return {}; 1036 } 1037 1038 void arith::IndexCastOp::getCanonicalizationPatterns( 1039 OwningRewritePatternList &patterns, MLIRContext *context) { 1040 patterns.insert<IndexCastOfIndexCast, IndexCastOfExtSI>(context); 1041 } 1042 1043 //===----------------------------------------------------------------------===// 1044 // BitcastOp 1045 //===----------------------------------------------------------------------===// 1046 1047 bool arith::BitcastOp::areCastCompatible(TypeRange inputs, TypeRange outputs) { 1048 if (!areValidCastInputsAndOutputs(inputs, outputs)) 1049 return false; 1050 1051 auto srcType = 1052 getTypeIfLikeOrMemRef<IntegerType, IndexType, FloatType>(inputs.front()); 1053 auto dstType = 1054 getTypeIfLikeOrMemRef<IntegerType, IndexType, FloatType>(outputs.front()); 1055 if (!srcType || !dstType) 1056 return false; 1057 1058 return srcType.getIntOrFloatBitWidth() == dstType.getIntOrFloatBitWidth(); 1059 } 1060 1061 OpFoldResult arith::BitcastOp::fold(ArrayRef<Attribute> operands) { 1062 assert(operands.size() == 1 && "bitcast op expects 1 operand"); 1063 1064 auto resType = getType(); 1065 auto operand = operands[0]; 1066 if (!operand) 1067 return {}; 1068 1069 /// Bitcast dense elements. 1070 if (auto denseAttr = operand.dyn_cast_or_null<DenseElementsAttr>()) 1071 return denseAttr.bitcast(resType.cast<ShapedType>().getElementType()); 1072 /// Other shaped types unhandled. 1073 if (resType.isa<ShapedType>()) 1074 return {}; 1075 1076 /// Bitcast integer or float to integer or float. 1077 APInt bits = operand.isa<FloatAttr>() 1078 ? operand.cast<FloatAttr>().getValue().bitcastToAPInt() 1079 : operand.cast<IntegerAttr>().getValue(); 1080 1081 if (auto resFloatType = resType.dyn_cast<FloatType>()) 1082 return FloatAttr::get(resType, 1083 APFloat(resFloatType.getFloatSemantics(), bits)); 1084 return IntegerAttr::get(resType, bits); 1085 } 1086 1087 void arith::BitcastOp::getCanonicalizationPatterns( 1088 OwningRewritePatternList &patterns, MLIRContext *context) { 1089 patterns.insert<BitcastOfBitcast>(context); 1090 } 1091 1092 //===----------------------------------------------------------------------===// 1093 // Helpers for compare ops 1094 //===----------------------------------------------------------------------===// 1095 1096 /// Return the type of the same shape (scalar, vector or tensor) containing i1. 1097 static Type getI1SameShape(Type type) { 1098 auto i1Type = IntegerType::get(type.getContext(), 1); 1099 if (auto tensorType = type.dyn_cast<RankedTensorType>()) 1100 return RankedTensorType::get(tensorType.getShape(), i1Type); 1101 if (type.isa<UnrankedTensorType>()) 1102 return UnrankedTensorType::get(i1Type); 1103 if (auto vectorType = type.dyn_cast<VectorType>()) 1104 return VectorType::get(vectorType.getShape(), i1Type, 1105 vectorType.getNumScalableDims()); 1106 return i1Type; 1107 } 1108 1109 //===----------------------------------------------------------------------===// 1110 // CmpIOp 1111 //===----------------------------------------------------------------------===// 1112 1113 /// Compute `lhs` `pred` `rhs`, where `pred` is one of the known integer 1114 /// comparison predicates. 1115 bool mlir::arith::applyCmpPredicate(arith::CmpIPredicate predicate, 1116 const APInt &lhs, const APInt &rhs) { 1117 switch (predicate) { 1118 case arith::CmpIPredicate::eq: 1119 return lhs.eq(rhs); 1120 case arith::CmpIPredicate::ne: 1121 return lhs.ne(rhs); 1122 case arith::CmpIPredicate::slt: 1123 return lhs.slt(rhs); 1124 case arith::CmpIPredicate::sle: 1125 return lhs.sle(rhs); 1126 case arith::CmpIPredicate::sgt: 1127 return lhs.sgt(rhs); 1128 case arith::CmpIPredicate::sge: 1129 return lhs.sge(rhs); 1130 case arith::CmpIPredicate::ult: 1131 return lhs.ult(rhs); 1132 case arith::CmpIPredicate::ule: 1133 return lhs.ule(rhs); 1134 case arith::CmpIPredicate::ugt: 1135 return lhs.ugt(rhs); 1136 case arith::CmpIPredicate::uge: 1137 return lhs.uge(rhs); 1138 } 1139 llvm_unreachable("unknown cmpi predicate kind"); 1140 } 1141 1142 /// Returns true if the predicate is true for two equal operands. 1143 static bool applyCmpPredicateToEqualOperands(arith::CmpIPredicate predicate) { 1144 switch (predicate) { 1145 case arith::CmpIPredicate::eq: 1146 case arith::CmpIPredicate::sle: 1147 case arith::CmpIPredicate::sge: 1148 case arith::CmpIPredicate::ule: 1149 case arith::CmpIPredicate::uge: 1150 return true; 1151 case arith::CmpIPredicate::ne: 1152 case arith::CmpIPredicate::slt: 1153 case arith::CmpIPredicate::sgt: 1154 case arith::CmpIPredicate::ult: 1155 case arith::CmpIPredicate::ugt: 1156 return false; 1157 } 1158 llvm_unreachable("unknown cmpi predicate kind"); 1159 } 1160 1161 static Attribute getBoolAttribute(Type type, MLIRContext *ctx, bool value) { 1162 auto boolAttr = BoolAttr::get(ctx, value); 1163 ShapedType shapedType = type.dyn_cast_or_null<ShapedType>(); 1164 if (!shapedType) 1165 return boolAttr; 1166 return DenseElementsAttr::get(shapedType, boolAttr); 1167 } 1168 1169 OpFoldResult arith::CmpIOp::fold(ArrayRef<Attribute> operands) { 1170 assert(operands.size() == 2 && "cmpi takes two operands"); 1171 1172 // cmpi(pred, x, x) 1173 if (getLhs() == getRhs()) { 1174 auto val = applyCmpPredicateToEqualOperands(getPredicate()); 1175 return getBoolAttribute(getType(), getContext(), val); 1176 } 1177 1178 if (matchPattern(getRhs(), m_Zero())) { 1179 if (auto extOp = getLhs().getDefiningOp<ExtSIOp>()) { 1180 if (extOp.getOperand().getType().cast<IntegerType>().getWidth() == 1) { 1181 // extsi(%x : i1 -> iN) != 0 -> %x 1182 if (getPredicate() == arith::CmpIPredicate::ne) { 1183 return extOp.getOperand(); 1184 } 1185 } 1186 } 1187 if (auto extOp = getLhs().getDefiningOp<ExtUIOp>()) { 1188 if (extOp.getOperand().getType().cast<IntegerType>().getWidth() == 1) { 1189 // extui(%x : i1 -> iN) != 0 -> %x 1190 if (getPredicate() == arith::CmpIPredicate::ne) { 1191 return extOp.getOperand(); 1192 } 1193 } 1194 } 1195 } 1196 1197 auto lhs = operands.front().dyn_cast_or_null<IntegerAttr>(); 1198 auto rhs = operands.back().dyn_cast_or_null<IntegerAttr>(); 1199 if (!lhs || !rhs) 1200 return {}; 1201 1202 auto val = applyCmpPredicate(getPredicate(), lhs.getValue(), rhs.getValue()); 1203 return BoolAttr::get(getContext(), val); 1204 } 1205 1206 //===----------------------------------------------------------------------===// 1207 // CmpFOp 1208 //===----------------------------------------------------------------------===// 1209 1210 /// Compute `lhs` `pred` `rhs`, where `pred` is one of the known floating point 1211 /// comparison predicates. 1212 bool mlir::arith::applyCmpPredicate(arith::CmpFPredicate predicate, 1213 const APFloat &lhs, const APFloat &rhs) { 1214 auto cmpResult = lhs.compare(rhs); 1215 switch (predicate) { 1216 case arith::CmpFPredicate::AlwaysFalse: 1217 return false; 1218 case arith::CmpFPredicate::OEQ: 1219 return cmpResult == APFloat::cmpEqual; 1220 case arith::CmpFPredicate::OGT: 1221 return cmpResult == APFloat::cmpGreaterThan; 1222 case arith::CmpFPredicate::OGE: 1223 return cmpResult == APFloat::cmpGreaterThan || 1224 cmpResult == APFloat::cmpEqual; 1225 case arith::CmpFPredicate::OLT: 1226 return cmpResult == APFloat::cmpLessThan; 1227 case arith::CmpFPredicate::OLE: 1228 return cmpResult == APFloat::cmpLessThan || cmpResult == APFloat::cmpEqual; 1229 case arith::CmpFPredicate::ONE: 1230 return cmpResult != APFloat::cmpUnordered && cmpResult != APFloat::cmpEqual; 1231 case arith::CmpFPredicate::ORD: 1232 return cmpResult != APFloat::cmpUnordered; 1233 case arith::CmpFPredicate::UEQ: 1234 return cmpResult == APFloat::cmpUnordered || cmpResult == APFloat::cmpEqual; 1235 case arith::CmpFPredicate::UGT: 1236 return cmpResult == APFloat::cmpUnordered || 1237 cmpResult == APFloat::cmpGreaterThan; 1238 case arith::CmpFPredicate::UGE: 1239 return cmpResult == APFloat::cmpUnordered || 1240 cmpResult == APFloat::cmpGreaterThan || 1241 cmpResult == APFloat::cmpEqual; 1242 case arith::CmpFPredicate::ULT: 1243 return cmpResult == APFloat::cmpUnordered || 1244 cmpResult == APFloat::cmpLessThan; 1245 case arith::CmpFPredicate::ULE: 1246 return cmpResult == APFloat::cmpUnordered || 1247 cmpResult == APFloat::cmpLessThan || cmpResult == APFloat::cmpEqual; 1248 case arith::CmpFPredicate::UNE: 1249 return cmpResult != APFloat::cmpEqual; 1250 case arith::CmpFPredicate::UNO: 1251 return cmpResult == APFloat::cmpUnordered; 1252 case arith::CmpFPredicate::AlwaysTrue: 1253 return true; 1254 } 1255 llvm_unreachable("unknown cmpf predicate kind"); 1256 } 1257 1258 OpFoldResult arith::CmpFOp::fold(ArrayRef<Attribute> operands) { 1259 assert(operands.size() == 2 && "cmpf takes two operands"); 1260 1261 auto lhs = operands.front().dyn_cast_or_null<FloatAttr>(); 1262 auto rhs = operands.back().dyn_cast_or_null<FloatAttr>(); 1263 1264 if (!lhs || !rhs) 1265 return {}; 1266 1267 auto val = applyCmpPredicate(getPredicate(), lhs.getValue(), rhs.getValue()); 1268 return BoolAttr::get(getContext(), val); 1269 } 1270 1271 //===----------------------------------------------------------------------===// 1272 // Atomic Enum 1273 //===----------------------------------------------------------------------===// 1274 1275 /// Returns the identity value attribute associated with an AtomicRMWKind op. 1276 Attribute mlir::arith::getIdentityValueAttr(AtomicRMWKind kind, Type resultType, 1277 OpBuilder &builder, Location loc) { 1278 switch (kind) { 1279 case AtomicRMWKind::maxf: 1280 return builder.getFloatAttr( 1281 resultType, 1282 APFloat::getInf(resultType.cast<FloatType>().getFloatSemantics(), 1283 /*Negative=*/true)); 1284 case AtomicRMWKind::addf: 1285 case AtomicRMWKind::addi: 1286 case AtomicRMWKind::maxu: 1287 case AtomicRMWKind::ori: 1288 return builder.getZeroAttr(resultType); 1289 case AtomicRMWKind::andi: 1290 return builder.getIntegerAttr( 1291 resultType, 1292 APInt::getAllOnes(resultType.cast<IntegerType>().getWidth())); 1293 case AtomicRMWKind::maxs: 1294 return builder.getIntegerAttr( 1295 resultType, 1296 APInt::getSignedMinValue(resultType.cast<IntegerType>().getWidth())); 1297 case AtomicRMWKind::minf: 1298 return builder.getFloatAttr( 1299 resultType, 1300 APFloat::getInf(resultType.cast<FloatType>().getFloatSemantics(), 1301 /*Negative=*/false)); 1302 case AtomicRMWKind::mins: 1303 return builder.getIntegerAttr( 1304 resultType, 1305 APInt::getSignedMaxValue(resultType.cast<IntegerType>().getWidth())); 1306 case AtomicRMWKind::minu: 1307 return builder.getIntegerAttr( 1308 resultType, 1309 APInt::getMaxValue(resultType.cast<IntegerType>().getWidth())); 1310 case AtomicRMWKind::muli: 1311 return builder.getIntegerAttr(resultType, 1); 1312 case AtomicRMWKind::mulf: 1313 return builder.getFloatAttr(resultType, 1); 1314 // TODO: Add remaining reduction operations. 1315 default: 1316 (void)emitOptionalError(loc, "Reduction operation type not supported"); 1317 break; 1318 } 1319 return nullptr; 1320 } 1321 1322 /// Returns the identity value associated with an AtomicRMWKind op. 1323 Value mlir::arith::getIdentityValue(AtomicRMWKind op, Type resultType, 1324 OpBuilder &builder, Location loc) { 1325 Attribute attr = getIdentityValueAttr(op, resultType, builder, loc); 1326 return builder.create<arith::ConstantOp>(loc, attr); 1327 } 1328 1329 /// Return the value obtained by applying the reduction operation kind 1330 /// associated with a binary AtomicRMWKind op to `lhs` and `rhs`. 1331 Value mlir::arith::getReductionOp(AtomicRMWKind op, OpBuilder &builder, 1332 Location loc, Value lhs, Value rhs) { 1333 switch (op) { 1334 case AtomicRMWKind::addf: 1335 return builder.create<arith::AddFOp>(loc, lhs, rhs); 1336 case AtomicRMWKind::addi: 1337 return builder.create<arith::AddIOp>(loc, lhs, rhs); 1338 case AtomicRMWKind::mulf: 1339 return builder.create<arith::MulFOp>(loc, lhs, rhs); 1340 case AtomicRMWKind::muli: 1341 return builder.create<arith::MulIOp>(loc, lhs, rhs); 1342 case AtomicRMWKind::maxf: 1343 return builder.create<arith::MaxFOp>(loc, lhs, rhs); 1344 case AtomicRMWKind::minf: 1345 return builder.create<arith::MinFOp>(loc, lhs, rhs); 1346 case AtomicRMWKind::maxs: 1347 return builder.create<arith::MaxSIOp>(loc, lhs, rhs); 1348 case AtomicRMWKind::mins: 1349 return builder.create<arith::MinSIOp>(loc, lhs, rhs); 1350 case AtomicRMWKind::maxu: 1351 return builder.create<arith::MaxUIOp>(loc, lhs, rhs); 1352 case AtomicRMWKind::minu: 1353 return builder.create<arith::MinUIOp>(loc, lhs, rhs); 1354 case AtomicRMWKind::ori: 1355 return builder.create<arith::OrIOp>(loc, lhs, rhs); 1356 case AtomicRMWKind::andi: 1357 return builder.create<arith::AndIOp>(loc, lhs, rhs); 1358 // TODO: Add remaining reduction operations. 1359 default: 1360 (void)emitOptionalError(loc, "Reduction operation type not supported"); 1361 break; 1362 } 1363 return nullptr; 1364 } 1365 1366 //===----------------------------------------------------------------------===// 1367 // TableGen'd op method definitions 1368 //===----------------------------------------------------------------------===// 1369 1370 #define GET_OP_CLASSES 1371 #include "mlir/Dialect/Arithmetic/IR/ArithmeticOps.cpp.inc" 1372 1373 //===----------------------------------------------------------------------===// 1374 // TableGen'd enum attribute definitions 1375 //===----------------------------------------------------------------------===// 1376 1377 #include "mlir/Dialect/Arithmetic/IR/ArithmeticOpsEnums.cpp.inc" 1378