1 //===- ArithmeticOps.cpp - MLIR Arithmetic dialect ops implementation -----===// 2 // 3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. 4 // See https://llvm.org/LICENSE.txt for license information. 5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception 6 // 7 //===----------------------------------------------------------------------===// 8 9 #include <utility> 10 11 #include "mlir/Dialect/Arithmetic/IR/Arithmetic.h" 12 #include "mlir/Dialect/CommonFolders.h" 13 #include "mlir/IR/Builders.h" 14 #include "mlir/IR/Matchers.h" 15 #include "mlir/IR/OpImplementation.h" 16 #include "mlir/IR/PatternMatch.h" 17 #include "mlir/IR/TypeUtilities.h" 18 #include "llvm/ADT/SmallString.h" 19 20 #include "llvm/ADT/APSInt.h" 21 22 using namespace mlir; 23 using namespace mlir::arith; 24 25 //===----------------------------------------------------------------------===// 26 // Pattern helpers 27 //===----------------------------------------------------------------------===// 28 29 static IntegerAttr addIntegerAttrs(PatternRewriter &builder, Value res, 30 Attribute lhs, Attribute rhs) { 31 return builder.getIntegerAttr(res.getType(), 32 lhs.cast<IntegerAttr>().getInt() + 33 rhs.cast<IntegerAttr>().getInt()); 34 } 35 36 static IntegerAttr subIntegerAttrs(PatternRewriter &builder, Value res, 37 Attribute lhs, Attribute rhs) { 38 return builder.getIntegerAttr(res.getType(), 39 lhs.cast<IntegerAttr>().getInt() - 40 rhs.cast<IntegerAttr>().getInt()); 41 } 42 43 /// Invert an integer comparison predicate. 44 arith::CmpIPredicate arith::invertPredicate(arith::CmpIPredicate pred) { 45 switch (pred) { 46 case arith::CmpIPredicate::eq: 47 return arith::CmpIPredicate::ne; 48 case arith::CmpIPredicate::ne: 49 return arith::CmpIPredicate::eq; 50 case arith::CmpIPredicate::slt: 51 return arith::CmpIPredicate::sge; 52 case arith::CmpIPredicate::sle: 53 return arith::CmpIPredicate::sgt; 54 case arith::CmpIPredicate::sgt: 55 return arith::CmpIPredicate::sle; 56 case arith::CmpIPredicate::sge: 57 return arith::CmpIPredicate::slt; 58 case arith::CmpIPredicate::ult: 59 return arith::CmpIPredicate::uge; 60 case arith::CmpIPredicate::ule: 61 return arith::CmpIPredicate::ugt; 62 case arith::CmpIPredicate::ugt: 63 return arith::CmpIPredicate::ule; 64 case arith::CmpIPredicate::uge: 65 return arith::CmpIPredicate::ult; 66 } 67 llvm_unreachable("unknown cmpi predicate kind"); 68 } 69 70 static arith::CmpIPredicateAttr invertPredicate(arith::CmpIPredicateAttr pred) { 71 return arith::CmpIPredicateAttr::get(pred.getContext(), 72 invertPredicate(pred.getValue())); 73 } 74 75 //===----------------------------------------------------------------------===// 76 // TableGen'd canonicalization patterns 77 //===----------------------------------------------------------------------===// 78 79 namespace { 80 #include "ArithmeticCanonicalization.inc" 81 } // namespace 82 83 //===----------------------------------------------------------------------===// 84 // ConstantOp 85 //===----------------------------------------------------------------------===// 86 87 void arith::ConstantOp::getAsmResultNames( 88 function_ref<void(Value, StringRef)> setNameFn) { 89 auto type = getType(); 90 if (auto intCst = getValue().dyn_cast<IntegerAttr>()) { 91 auto intType = type.dyn_cast<IntegerType>(); 92 93 // Sugar i1 constants with 'true' and 'false'. 94 if (intType && intType.getWidth() == 1) 95 return setNameFn(getResult(), (intCst.getInt() ? "true" : "false")); 96 97 // Otherwise, build a compex name with the value and type. 98 SmallString<32> specialNameBuffer; 99 llvm::raw_svector_ostream specialName(specialNameBuffer); 100 specialName << 'c' << intCst.getInt(); 101 if (intType) 102 specialName << '_' << type; 103 setNameFn(getResult(), specialName.str()); 104 } else { 105 setNameFn(getResult(), "cst"); 106 } 107 } 108 109 /// TODO: disallow arith.constant to return anything other than signless integer 110 /// or float like. 111 LogicalResult arith::ConstantOp::verify() { 112 auto type = getType(); 113 // The value's type must match the return type. 114 if (getValue().getType() != type) { 115 return emitOpError() << "value type " << getValue().getType() 116 << " must match return type: " << type; 117 } 118 // Integer values must be signless. 119 if (type.isa<IntegerType>() && !type.cast<IntegerType>().isSignless()) 120 return emitOpError("integer return type must be signless"); 121 // Any float or elements attribute are acceptable. 122 if (!getValue().isa<IntegerAttr, FloatAttr, ElementsAttr>()) { 123 return emitOpError( 124 "value must be an integer, float, or elements attribute"); 125 } 126 return success(); 127 } 128 129 bool arith::ConstantOp::isBuildableWith(Attribute value, Type type) { 130 // The value's type must be the same as the provided type. 131 if (value.getType() != type) 132 return false; 133 // Integer values must be signless. 134 if (type.isa<IntegerType>() && !type.cast<IntegerType>().isSignless()) 135 return false; 136 // Integer, float, and element attributes are buildable. 137 return value.isa<IntegerAttr, FloatAttr, ElementsAttr>(); 138 } 139 140 OpFoldResult arith::ConstantOp::fold(ArrayRef<Attribute> operands) { 141 return getValue(); 142 } 143 144 void arith::ConstantIntOp::build(OpBuilder &builder, OperationState &result, 145 int64_t value, unsigned width) { 146 auto type = builder.getIntegerType(width); 147 arith::ConstantOp::build(builder, result, type, 148 builder.getIntegerAttr(type, value)); 149 } 150 151 void arith::ConstantIntOp::build(OpBuilder &builder, OperationState &result, 152 int64_t value, Type type) { 153 assert(type.isSignlessInteger() && 154 "ConstantIntOp can only have signless integer type values"); 155 arith::ConstantOp::build(builder, result, type, 156 builder.getIntegerAttr(type, value)); 157 } 158 159 bool arith::ConstantIntOp::classof(Operation *op) { 160 if (auto constOp = dyn_cast_or_null<arith::ConstantOp>(op)) 161 return constOp.getType().isSignlessInteger(); 162 return false; 163 } 164 165 void arith::ConstantFloatOp::build(OpBuilder &builder, OperationState &result, 166 const APFloat &value, FloatType type) { 167 arith::ConstantOp::build(builder, result, type, 168 builder.getFloatAttr(type, value)); 169 } 170 171 bool arith::ConstantFloatOp::classof(Operation *op) { 172 if (auto constOp = dyn_cast_or_null<arith::ConstantOp>(op)) 173 return constOp.getType().isa<FloatType>(); 174 return false; 175 } 176 177 void arith::ConstantIndexOp::build(OpBuilder &builder, OperationState &result, 178 int64_t value) { 179 arith::ConstantOp::build(builder, result, builder.getIndexType(), 180 builder.getIndexAttr(value)); 181 } 182 183 bool arith::ConstantIndexOp::classof(Operation *op) { 184 if (auto constOp = dyn_cast_or_null<arith::ConstantOp>(op)) 185 return constOp.getType().isIndex(); 186 return false; 187 } 188 189 //===----------------------------------------------------------------------===// 190 // AddIOp 191 //===----------------------------------------------------------------------===// 192 193 OpFoldResult arith::AddIOp::fold(ArrayRef<Attribute> operands) { 194 // addi(x, 0) -> x 195 if (matchPattern(getRhs(), m_Zero())) 196 return getLhs(); 197 198 // addi(subi(a, b), b) -> a 199 if (auto sub = getLhs().getDefiningOp<SubIOp>()) 200 if (getRhs() == sub.getRhs()) 201 return sub.getLhs(); 202 203 // addi(b, subi(a, b)) -> a 204 if (auto sub = getRhs().getDefiningOp<SubIOp>()) 205 if (getLhs() == sub.getRhs()) 206 return sub.getLhs(); 207 208 return constFoldBinaryOp<IntegerAttr>( 209 operands, [](APInt a, const APInt &b) { return std::move(a) + b; }); 210 } 211 212 void arith::AddIOp::getCanonicalizationPatterns( 213 RewritePatternSet &patterns, MLIRContext *context) { 214 patterns.add<AddIAddConstant, AddISubConstantRHS, AddISubConstantLHS>( 215 context); 216 } 217 218 //===----------------------------------------------------------------------===// 219 // SubIOp 220 //===----------------------------------------------------------------------===// 221 222 OpFoldResult arith::SubIOp::fold(ArrayRef<Attribute> operands) { 223 // subi(x,x) -> 0 224 if (getOperand(0) == getOperand(1)) 225 return Builder(getContext()).getZeroAttr(getType()); 226 // subi(x,0) -> x 227 if (matchPattern(getRhs(), m_Zero())) 228 return getLhs(); 229 230 return constFoldBinaryOp<IntegerAttr>( 231 operands, [](APInt a, const APInt &b) { return std::move(a) - b; }); 232 } 233 234 void arith::SubIOp::getCanonicalizationPatterns( 235 RewritePatternSet &patterns, MLIRContext *context) { 236 patterns 237 .add<SubIRHSAddConstant, SubILHSAddConstant, SubIRHSSubConstantRHS, 238 SubIRHSSubConstantLHS, SubILHSSubConstantRHS, SubILHSSubConstantLHS>( 239 context); 240 } 241 242 //===----------------------------------------------------------------------===// 243 // MulIOp 244 //===----------------------------------------------------------------------===// 245 246 OpFoldResult arith::MulIOp::fold(ArrayRef<Attribute> operands) { 247 // muli(x, 0) -> 0 248 if (matchPattern(getRhs(), m_Zero())) 249 return getRhs(); 250 // muli(x, 1) -> x 251 if (matchPattern(getRhs(), m_One())) 252 return getOperand(0); 253 // TODO: Handle the overflow case. 254 255 // default folder 256 return constFoldBinaryOp<IntegerAttr>( 257 operands, [](const APInt &a, const APInt &b) { return a * b; }); 258 } 259 260 //===----------------------------------------------------------------------===// 261 // DivUIOp 262 //===----------------------------------------------------------------------===// 263 264 OpFoldResult arith::DivUIOp::fold(ArrayRef<Attribute> operands) { 265 // divui (x, 1) -> x. 266 if (matchPattern(getRhs(), m_One())) 267 return getLhs(); 268 269 // Don't fold if it would require a division by zero. 270 bool div0 = false; 271 auto result = 272 constFoldBinaryOp<IntegerAttr>(operands, [&](APInt a, const APInt &b) { 273 if (div0 || !b) { 274 div0 = true; 275 return a; 276 } 277 return a.udiv(b); 278 }); 279 280 return div0 ? Attribute() : result; 281 } 282 283 //===----------------------------------------------------------------------===// 284 // DivSIOp 285 //===----------------------------------------------------------------------===// 286 287 OpFoldResult arith::DivSIOp::fold(ArrayRef<Attribute> operands) { 288 // divsi (x, 1) -> x. 289 if (matchPattern(getRhs(), m_One())) 290 return getLhs(); 291 292 // Don't fold if it would overflow or if it requires a division by zero. 293 bool overflowOrDiv0 = false; 294 auto result = 295 constFoldBinaryOp<IntegerAttr>(operands, [&](APInt a, const APInt &b) { 296 if (overflowOrDiv0 || !b) { 297 overflowOrDiv0 = true; 298 return a; 299 } 300 return a.sdiv_ov(b, overflowOrDiv0); 301 }); 302 303 return overflowOrDiv0 ? Attribute() : result; 304 } 305 306 //===----------------------------------------------------------------------===// 307 // Ceil and floor division folding helpers 308 //===----------------------------------------------------------------------===// 309 310 static APInt signedCeilNonnegInputs(const APInt &a, const APInt &b, 311 bool &overflow) { 312 // Returns (a-1)/b + 1 313 APInt one(a.getBitWidth(), 1, true); // Signed value 1. 314 APInt val = a.ssub_ov(one, overflow).sdiv_ov(b, overflow); 315 return val.sadd_ov(one, overflow); 316 } 317 318 //===----------------------------------------------------------------------===// 319 // CeilDivUIOp 320 //===----------------------------------------------------------------------===// 321 322 OpFoldResult arith::CeilDivUIOp::fold(ArrayRef<Attribute> operands) { 323 // ceildivui (x, 1) -> x. 324 if (matchPattern(getRhs(), m_One())) 325 return getLhs(); 326 327 bool overflowOrDiv0 = false; 328 auto result = 329 constFoldBinaryOp<IntegerAttr>(operands, [&](APInt a, const APInt &b) { 330 if (overflowOrDiv0 || !b) { 331 overflowOrDiv0 = true; 332 return a; 333 } 334 APInt quotient = a.udiv(b); 335 if (!a.urem(b)) 336 return quotient; 337 APInt one(a.getBitWidth(), 1, true); 338 return quotient.uadd_ov(one, overflowOrDiv0); 339 }); 340 341 return overflowOrDiv0 ? Attribute() : result; 342 } 343 344 //===----------------------------------------------------------------------===// 345 // CeilDivSIOp 346 //===----------------------------------------------------------------------===// 347 348 OpFoldResult arith::CeilDivSIOp::fold(ArrayRef<Attribute> operands) { 349 // ceildivsi (x, 1) -> x. 350 if (matchPattern(getRhs(), m_One())) 351 return getLhs(); 352 353 // Don't fold if it would overflow or if it requires a division by zero. 354 bool overflowOrDiv0 = false; 355 auto result = 356 constFoldBinaryOp<IntegerAttr>(operands, [&](APInt a, const APInt &b) { 357 if (overflowOrDiv0 || !b) { 358 overflowOrDiv0 = true; 359 return a; 360 } 361 if (!a) 362 return a; 363 // After this point we know that neither a or b are zero. 364 unsigned bits = a.getBitWidth(); 365 APInt zero = APInt::getZero(bits); 366 bool aGtZero = a.sgt(zero); 367 bool bGtZero = b.sgt(zero); 368 if (aGtZero && bGtZero) { 369 // Both positive, return ceil(a, b). 370 return signedCeilNonnegInputs(a, b, overflowOrDiv0); 371 } 372 if (!aGtZero && !bGtZero) { 373 // Both negative, return ceil(-a, -b). 374 APInt posA = zero.ssub_ov(a, overflowOrDiv0); 375 APInt posB = zero.ssub_ov(b, overflowOrDiv0); 376 return signedCeilNonnegInputs(posA, posB, overflowOrDiv0); 377 } 378 if (!aGtZero && bGtZero) { 379 // A is negative, b is positive, return - ( -a / b). 380 APInt posA = zero.ssub_ov(a, overflowOrDiv0); 381 APInt div = posA.sdiv_ov(b, overflowOrDiv0); 382 return zero.ssub_ov(div, overflowOrDiv0); 383 } 384 // A is positive, b is negative, return - (a / -b). 385 APInt posB = zero.ssub_ov(b, overflowOrDiv0); 386 APInt div = a.sdiv_ov(posB, overflowOrDiv0); 387 return zero.ssub_ov(div, overflowOrDiv0); 388 }); 389 390 return overflowOrDiv0 ? Attribute() : result; 391 } 392 393 //===----------------------------------------------------------------------===// 394 // FloorDivSIOp 395 //===----------------------------------------------------------------------===// 396 397 OpFoldResult arith::FloorDivSIOp::fold(ArrayRef<Attribute> operands) { 398 // floordivsi (x, 1) -> x. 399 if (matchPattern(getRhs(), m_One())) 400 return getLhs(); 401 402 // Don't fold if it would overflow or if it requires a division by zero. 403 bool overflowOrDiv0 = false; 404 auto result = 405 constFoldBinaryOp<IntegerAttr>(operands, [&](APInt a, const APInt &b) { 406 if (overflowOrDiv0 || !b) { 407 overflowOrDiv0 = true; 408 return a; 409 } 410 if (!a) 411 return a; 412 // After this point we know that neither a or b are zero. 413 unsigned bits = a.getBitWidth(); 414 APInt zero = APInt::getZero(bits); 415 bool aGtZero = a.sgt(zero); 416 bool bGtZero = b.sgt(zero); 417 if (aGtZero && bGtZero) { 418 // Both positive, return a / b. 419 return a.sdiv_ov(b, overflowOrDiv0); 420 } 421 if (!aGtZero && !bGtZero) { 422 // Both negative, return -a / -b. 423 APInt posA = zero.ssub_ov(a, overflowOrDiv0); 424 APInt posB = zero.ssub_ov(b, overflowOrDiv0); 425 return posA.sdiv_ov(posB, overflowOrDiv0); 426 } 427 if (!aGtZero && bGtZero) { 428 // A is negative, b is positive, return - ceil(-a, b). 429 APInt posA = zero.ssub_ov(a, overflowOrDiv0); 430 APInt ceil = signedCeilNonnegInputs(posA, b, overflowOrDiv0); 431 return zero.ssub_ov(ceil, overflowOrDiv0); 432 } 433 // A is positive, b is negative, return - ceil(a, -b). 434 APInt posB = zero.ssub_ov(b, overflowOrDiv0); 435 APInt ceil = signedCeilNonnegInputs(a, posB, overflowOrDiv0); 436 return zero.ssub_ov(ceil, overflowOrDiv0); 437 }); 438 439 return overflowOrDiv0 ? Attribute() : result; 440 } 441 442 //===----------------------------------------------------------------------===// 443 // RemUIOp 444 //===----------------------------------------------------------------------===// 445 446 OpFoldResult arith::RemUIOp::fold(ArrayRef<Attribute> operands) { 447 // remui (x, 1) -> 0. 448 if (matchPattern(getRhs(), m_One())) 449 return Builder(getContext()).getZeroAttr(getType()); 450 451 // Don't fold if it would require a division by zero. 452 bool div0 = false; 453 auto result = 454 constFoldBinaryOp<IntegerAttr>(operands, [&](APInt a, const APInt &b) { 455 if (div0 || b.isNullValue()) { 456 div0 = true; 457 return a; 458 } 459 return a.urem(b); 460 }); 461 462 return div0 ? Attribute() : result; 463 } 464 465 //===----------------------------------------------------------------------===// 466 // RemSIOp 467 //===----------------------------------------------------------------------===// 468 469 OpFoldResult arith::RemSIOp::fold(ArrayRef<Attribute> operands) { 470 // remsi (x, 1) -> 0. 471 if (matchPattern(getRhs(), m_One())) 472 return Builder(getContext()).getZeroAttr(getType()); 473 474 // Don't fold if it would require a division by zero. 475 bool div0 = false; 476 auto result = 477 constFoldBinaryOp<IntegerAttr>(operands, [&](APInt a, const APInt &b) { 478 if (div0 || b.isNullValue()) { 479 div0 = true; 480 return a; 481 } 482 return a.srem(b); 483 }); 484 485 return div0 ? Attribute() : result; 486 } 487 488 //===----------------------------------------------------------------------===// 489 // AndIOp 490 //===----------------------------------------------------------------------===// 491 492 OpFoldResult arith::AndIOp::fold(ArrayRef<Attribute> operands) { 493 /// and(x, 0) -> 0 494 if (matchPattern(getRhs(), m_Zero())) 495 return getRhs(); 496 /// and(x, allOnes) -> x 497 APInt intValue; 498 if (matchPattern(getRhs(), m_ConstantInt(&intValue)) && intValue.isAllOnes()) 499 return getLhs(); 500 501 return constFoldBinaryOp<IntegerAttr>( 502 operands, [](APInt a, const APInt &b) { return std::move(a) & b; }); 503 } 504 505 //===----------------------------------------------------------------------===// 506 // OrIOp 507 //===----------------------------------------------------------------------===// 508 509 OpFoldResult arith::OrIOp::fold(ArrayRef<Attribute> operands) { 510 /// or(x, 0) -> x 511 if (matchPattern(getRhs(), m_Zero())) 512 return getLhs(); 513 /// or(x, <all ones>) -> <all ones> 514 if (auto rhsAttr = operands[1].dyn_cast_or_null<IntegerAttr>()) 515 if (rhsAttr.getValue().isAllOnes()) 516 return rhsAttr; 517 518 return constFoldBinaryOp<IntegerAttr>( 519 operands, [](APInt a, const APInt &b) { return std::move(a) | b; }); 520 } 521 522 //===----------------------------------------------------------------------===// 523 // XOrIOp 524 //===----------------------------------------------------------------------===// 525 526 OpFoldResult arith::XOrIOp::fold(ArrayRef<Attribute> operands) { 527 /// xor(x, 0) -> x 528 if (matchPattern(getRhs(), m_Zero())) 529 return getLhs(); 530 /// xor(x, x) -> 0 531 if (getLhs() == getRhs()) 532 return Builder(getContext()).getZeroAttr(getType()); 533 /// xor(xor(x, a), a) -> x 534 if (arith::XOrIOp prev = getLhs().getDefiningOp<arith::XOrIOp>()) 535 if (prev.getRhs() == getRhs()) 536 return prev.getLhs(); 537 538 return constFoldBinaryOp<IntegerAttr>( 539 operands, [](APInt a, const APInt &b) { return std::move(a) ^ b; }); 540 } 541 542 void arith::XOrIOp::getCanonicalizationPatterns( 543 RewritePatternSet &patterns, MLIRContext *context) { 544 patterns.add<XOrINotCmpI>(context); 545 } 546 547 //===----------------------------------------------------------------------===// 548 // NegFOp 549 //===----------------------------------------------------------------------===// 550 551 OpFoldResult arith::NegFOp::fold(ArrayRef<Attribute> operands) { 552 return constFoldUnaryOp<FloatAttr>(operands, 553 [](const APFloat &a) { return -a; }); 554 } 555 556 //===----------------------------------------------------------------------===// 557 // AddFOp 558 //===----------------------------------------------------------------------===// 559 560 OpFoldResult arith::AddFOp::fold(ArrayRef<Attribute> operands) { 561 // addf(x, -0) -> x 562 if (matchPattern(getRhs(), m_NegZeroFloat())) 563 return getLhs(); 564 565 return constFoldBinaryOp<FloatAttr>( 566 operands, [](const APFloat &a, const APFloat &b) { return a + b; }); 567 } 568 569 //===----------------------------------------------------------------------===// 570 // SubFOp 571 //===----------------------------------------------------------------------===// 572 573 OpFoldResult arith::SubFOp::fold(ArrayRef<Attribute> operands) { 574 // subf(x, +0) -> x 575 if (matchPattern(getRhs(), m_PosZeroFloat())) 576 return getLhs(); 577 578 return constFoldBinaryOp<FloatAttr>( 579 operands, [](const APFloat &a, const APFloat &b) { return a - b; }); 580 } 581 582 //===----------------------------------------------------------------------===// 583 // MaxFOp 584 //===----------------------------------------------------------------------===// 585 586 OpFoldResult arith::MaxFOp::fold(ArrayRef<Attribute> operands) { 587 assert(operands.size() == 2 && "maxf takes two operands"); 588 589 // maxf(x,x) -> x 590 if (getLhs() == getRhs()) 591 return getRhs(); 592 593 // maxf(x, -inf) -> x 594 if (matchPattern(getRhs(), m_NegInfFloat())) 595 return getLhs(); 596 597 return constFoldBinaryOp<FloatAttr>( 598 operands, 599 [](const APFloat &a, const APFloat &b) { return llvm::maximum(a, b); }); 600 } 601 602 //===----------------------------------------------------------------------===// 603 // MaxSIOp 604 //===----------------------------------------------------------------------===// 605 606 OpFoldResult MaxSIOp::fold(ArrayRef<Attribute> operands) { 607 assert(operands.size() == 2 && "binary operation takes two operands"); 608 609 // maxsi(x,x) -> x 610 if (getLhs() == getRhs()) 611 return getRhs(); 612 613 APInt intValue; 614 // maxsi(x,MAX_INT) -> MAX_INT 615 if (matchPattern(getRhs(), m_ConstantInt(&intValue)) && 616 intValue.isMaxSignedValue()) 617 return getRhs(); 618 619 // maxsi(x, MIN_INT) -> x 620 if (matchPattern(getRhs(), m_ConstantInt(&intValue)) && 621 intValue.isMinSignedValue()) 622 return getLhs(); 623 624 return constFoldBinaryOp<IntegerAttr>(operands, 625 [](const APInt &a, const APInt &b) { 626 return llvm::APIntOps::smax(a, b); 627 }); 628 } 629 630 //===----------------------------------------------------------------------===// 631 // MaxUIOp 632 //===----------------------------------------------------------------------===// 633 634 OpFoldResult MaxUIOp::fold(ArrayRef<Attribute> operands) { 635 assert(operands.size() == 2 && "binary operation takes two operands"); 636 637 // maxui(x,x) -> x 638 if (getLhs() == getRhs()) 639 return getRhs(); 640 641 APInt intValue; 642 // maxui(x,MAX_INT) -> MAX_INT 643 if (matchPattern(getRhs(), m_ConstantInt(&intValue)) && intValue.isMaxValue()) 644 return getRhs(); 645 646 // maxui(x, MIN_INT) -> x 647 if (matchPattern(getRhs(), m_ConstantInt(&intValue)) && intValue.isMinValue()) 648 return getLhs(); 649 650 return constFoldBinaryOp<IntegerAttr>(operands, 651 [](const APInt &a, const APInt &b) { 652 return llvm::APIntOps::umax(a, b); 653 }); 654 } 655 656 //===----------------------------------------------------------------------===// 657 // MinFOp 658 //===----------------------------------------------------------------------===// 659 660 OpFoldResult arith::MinFOp::fold(ArrayRef<Attribute> operands) { 661 assert(operands.size() == 2 && "minf takes two operands"); 662 663 // minf(x,x) -> x 664 if (getLhs() == getRhs()) 665 return getRhs(); 666 667 // minf(x, +inf) -> x 668 if (matchPattern(getRhs(), m_PosInfFloat())) 669 return getLhs(); 670 671 return constFoldBinaryOp<FloatAttr>( 672 operands, 673 [](const APFloat &a, const APFloat &b) { return llvm::minimum(a, b); }); 674 } 675 676 //===----------------------------------------------------------------------===// 677 // MinSIOp 678 //===----------------------------------------------------------------------===// 679 680 OpFoldResult MinSIOp::fold(ArrayRef<Attribute> operands) { 681 assert(operands.size() == 2 && "binary operation takes two operands"); 682 683 // minsi(x,x) -> x 684 if (getLhs() == getRhs()) 685 return getRhs(); 686 687 APInt intValue; 688 // minsi(x,MIN_INT) -> MIN_INT 689 if (matchPattern(getRhs(), m_ConstantInt(&intValue)) && 690 intValue.isMinSignedValue()) 691 return getRhs(); 692 693 // minsi(x, MAX_INT) -> x 694 if (matchPattern(getRhs(), m_ConstantInt(&intValue)) && 695 intValue.isMaxSignedValue()) 696 return getLhs(); 697 698 return constFoldBinaryOp<IntegerAttr>(operands, 699 [](const APInt &a, const APInt &b) { 700 return llvm::APIntOps::smin(a, b); 701 }); 702 } 703 704 //===----------------------------------------------------------------------===// 705 // MinUIOp 706 //===----------------------------------------------------------------------===// 707 708 OpFoldResult MinUIOp::fold(ArrayRef<Attribute> operands) { 709 assert(operands.size() == 2 && "binary operation takes two operands"); 710 711 // minui(x,x) -> x 712 if (getLhs() == getRhs()) 713 return getRhs(); 714 715 APInt intValue; 716 // minui(x,MIN_INT) -> MIN_INT 717 if (matchPattern(getRhs(), m_ConstantInt(&intValue)) && intValue.isMinValue()) 718 return getRhs(); 719 720 // minui(x, MAX_INT) -> x 721 if (matchPattern(getRhs(), m_ConstantInt(&intValue)) && intValue.isMaxValue()) 722 return getLhs(); 723 724 return constFoldBinaryOp<IntegerAttr>(operands, 725 [](const APInt &a, const APInt &b) { 726 return llvm::APIntOps::umin(a, b); 727 }); 728 } 729 730 //===----------------------------------------------------------------------===// 731 // MulFOp 732 //===----------------------------------------------------------------------===// 733 734 OpFoldResult arith::MulFOp::fold(ArrayRef<Attribute> operands) { 735 // mulf(x, 1) -> x 736 if (matchPattern(getRhs(), m_OneFloat())) 737 return getLhs(); 738 739 return constFoldBinaryOp<FloatAttr>( 740 operands, [](const APFloat &a, const APFloat &b) { return a * b; }); 741 } 742 743 //===----------------------------------------------------------------------===// 744 // DivFOp 745 //===----------------------------------------------------------------------===// 746 747 OpFoldResult arith::DivFOp::fold(ArrayRef<Attribute> operands) { 748 // divf(x, 1) -> x 749 if (matchPattern(getRhs(), m_OneFloat())) 750 return getLhs(); 751 752 return constFoldBinaryOp<FloatAttr>( 753 operands, [](const APFloat &a, const APFloat &b) { return a / b; }); 754 } 755 756 //===----------------------------------------------------------------------===// 757 // Utility functions for verifying cast ops 758 //===----------------------------------------------------------------------===// 759 760 template <typename... Types> 761 using type_list = std::tuple<Types...> *; 762 763 /// Returns a non-null type only if the provided type is one of the allowed 764 /// types or one of the allowed shaped types of the allowed types. Returns the 765 /// element type if a valid shaped type is provided. 766 template <typename... ShapedTypes, typename... ElementTypes> 767 static Type getUnderlyingType(Type type, type_list<ShapedTypes...>, 768 type_list<ElementTypes...>) { 769 if (type.isa<ShapedType>() && !type.isa<ShapedTypes...>()) 770 return {}; 771 772 auto underlyingType = getElementTypeOrSelf(type); 773 if (!underlyingType.isa<ElementTypes...>()) 774 return {}; 775 776 return underlyingType; 777 } 778 779 /// Get allowed underlying types for vectors and tensors. 780 template <typename... ElementTypes> 781 static Type getTypeIfLike(Type type) { 782 return getUnderlyingType(type, type_list<VectorType, TensorType>(), 783 type_list<ElementTypes...>()); 784 } 785 786 /// Get allowed underlying types for vectors, tensors, and memrefs. 787 template <typename... ElementTypes> 788 static Type getTypeIfLikeOrMemRef(Type type) { 789 return getUnderlyingType(type, 790 type_list<VectorType, TensorType, MemRefType>(), 791 type_list<ElementTypes...>()); 792 } 793 794 static bool areValidCastInputsAndOutputs(TypeRange inputs, TypeRange outputs) { 795 return inputs.size() == 1 && outputs.size() == 1 && 796 succeeded(verifyCompatibleShapes(inputs.front(), outputs.front())); 797 } 798 799 //===----------------------------------------------------------------------===// 800 // Verifiers for integer and floating point extension/truncation ops 801 //===----------------------------------------------------------------------===// 802 803 // Extend ops can only extend to a wider type. 804 template <typename ValType, typename Op> 805 static LogicalResult verifyExtOp(Op op) { 806 Type srcType = getElementTypeOrSelf(op.getIn().getType()); 807 Type dstType = getElementTypeOrSelf(op.getType()); 808 809 if (srcType.cast<ValType>().getWidth() >= dstType.cast<ValType>().getWidth()) 810 return op.emitError("result type ") 811 << dstType << " must be wider than operand type " << srcType; 812 813 return success(); 814 } 815 816 // Truncate ops can only truncate to a shorter type. 817 template <typename ValType, typename Op> 818 static LogicalResult verifyTruncateOp(Op op) { 819 Type srcType = getElementTypeOrSelf(op.getIn().getType()); 820 Type dstType = getElementTypeOrSelf(op.getType()); 821 822 if (srcType.cast<ValType>().getWidth() <= dstType.cast<ValType>().getWidth()) 823 return op.emitError("result type ") 824 << dstType << " must be shorter than operand type " << srcType; 825 826 return success(); 827 } 828 829 /// Validate a cast that changes the width of a type. 830 template <template <typename> class WidthComparator, typename... ElementTypes> 831 static bool checkWidthChangeCast(TypeRange inputs, TypeRange outputs) { 832 if (!areValidCastInputsAndOutputs(inputs, outputs)) 833 return false; 834 835 auto srcType = getTypeIfLike<ElementTypes...>(inputs.front()); 836 auto dstType = getTypeIfLike<ElementTypes...>(outputs.front()); 837 if (!srcType || !dstType) 838 return false; 839 840 return WidthComparator<unsigned>()(dstType.getIntOrFloatBitWidth(), 841 srcType.getIntOrFloatBitWidth()); 842 } 843 844 //===----------------------------------------------------------------------===// 845 // ExtUIOp 846 //===----------------------------------------------------------------------===// 847 848 OpFoldResult arith::ExtUIOp::fold(ArrayRef<Attribute> operands) { 849 if (auto lhs = getIn().getDefiningOp<ExtUIOp>()) { 850 getInMutable().assign(lhs.getIn()); 851 return getResult(); 852 } 853 Type resType = getType(); 854 unsigned bitWidth; 855 if (auto shapedType = resType.dyn_cast<ShapedType>()) 856 bitWidth = shapedType.getElementTypeBitWidth(); 857 else 858 bitWidth = resType.getIntOrFloatBitWidth(); 859 return constFoldCastOp<IntegerAttr, IntegerAttr>( 860 operands, getType(), [bitWidth](const APInt &a, bool &castStatus) { 861 return a.zext(bitWidth); 862 }); 863 } 864 865 bool arith::ExtUIOp::areCastCompatible(TypeRange inputs, TypeRange outputs) { 866 return checkWidthChangeCast<std::greater, IntegerType>(inputs, outputs); 867 } 868 869 LogicalResult arith::ExtUIOp::verify() { 870 return verifyExtOp<IntegerType>(*this); 871 } 872 873 //===----------------------------------------------------------------------===// 874 // ExtSIOp 875 //===----------------------------------------------------------------------===// 876 877 OpFoldResult arith::ExtSIOp::fold(ArrayRef<Attribute> operands) { 878 if (auto lhs = getIn().getDefiningOp<ExtSIOp>()) { 879 getInMutable().assign(lhs.getIn()); 880 return getResult(); 881 } 882 Type resType = getType(); 883 unsigned bitWidth; 884 if (auto shapedType = resType.dyn_cast<ShapedType>()) 885 bitWidth = shapedType.getElementTypeBitWidth(); 886 else 887 bitWidth = resType.getIntOrFloatBitWidth(); 888 return constFoldCastOp<IntegerAttr, IntegerAttr>( 889 operands, getType(), [bitWidth](const APInt &a, bool &castStatus) { 890 return a.sext(bitWidth); 891 }); 892 } 893 894 bool arith::ExtSIOp::areCastCompatible(TypeRange inputs, TypeRange outputs) { 895 return checkWidthChangeCast<std::greater, IntegerType>(inputs, outputs); 896 } 897 898 void arith::ExtSIOp::getCanonicalizationPatterns( 899 RewritePatternSet &patterns, MLIRContext *context) { 900 patterns.add<ExtSIOfExtUI>(context); 901 } 902 903 LogicalResult arith::ExtSIOp::verify() { 904 return verifyExtOp<IntegerType>(*this); 905 } 906 907 //===----------------------------------------------------------------------===// 908 // ExtFOp 909 //===----------------------------------------------------------------------===// 910 911 bool arith::ExtFOp::areCastCompatible(TypeRange inputs, TypeRange outputs) { 912 return checkWidthChangeCast<std::greater, FloatType>(inputs, outputs); 913 } 914 915 LogicalResult arith::ExtFOp::verify() { return verifyExtOp<FloatType>(*this); } 916 917 //===----------------------------------------------------------------------===// 918 // TruncIOp 919 //===----------------------------------------------------------------------===// 920 921 OpFoldResult arith::TruncIOp::fold(ArrayRef<Attribute> operands) { 922 assert(operands.size() == 1 && "unary operation takes one operand"); 923 924 // trunci(zexti(a)) -> a 925 // trunci(sexti(a)) -> a 926 if (matchPattern(getOperand(), m_Op<arith::ExtUIOp>()) || 927 matchPattern(getOperand(), m_Op<arith::ExtSIOp>())) 928 return getOperand().getDefiningOp()->getOperand(0); 929 930 // trunci(trunci(a)) -> trunci(a)) 931 if (matchPattern(getOperand(), m_Op<arith::TruncIOp>())) { 932 setOperand(getOperand().getDefiningOp()->getOperand(0)); 933 return getResult(); 934 } 935 936 Type resType = getType(); 937 unsigned bitWidth; 938 if (auto shapedType = resType.dyn_cast<ShapedType>()) 939 bitWidth = shapedType.getElementTypeBitWidth(); 940 else 941 bitWidth = resType.getIntOrFloatBitWidth(); 942 943 return constFoldCastOp<IntegerAttr, IntegerAttr>( 944 operands, getType(), [bitWidth](const APInt &a, bool &castStatus) { 945 return a.trunc(bitWidth); 946 }); 947 } 948 949 bool arith::TruncIOp::areCastCompatible(TypeRange inputs, TypeRange outputs) { 950 return checkWidthChangeCast<std::less, IntegerType>(inputs, outputs); 951 } 952 953 LogicalResult arith::TruncIOp::verify() { 954 return verifyTruncateOp<IntegerType>(*this); 955 } 956 957 //===----------------------------------------------------------------------===// 958 // TruncFOp 959 //===----------------------------------------------------------------------===// 960 961 /// Perform safe const propagation for truncf, i.e. only propagate if FP value 962 /// can be represented without precision loss or rounding. 963 OpFoldResult arith::TruncFOp::fold(ArrayRef<Attribute> operands) { 964 assert(operands.size() == 1 && "unary operation takes one operand"); 965 966 auto constOperand = operands.front(); 967 if (!constOperand || !constOperand.isa<FloatAttr>()) 968 return {}; 969 970 // Convert to target type via 'double'. 971 double sourceValue = 972 constOperand.dyn_cast<FloatAttr>().getValue().convertToDouble(); 973 auto targetAttr = FloatAttr::get(getType(), sourceValue); 974 975 // Propagate if constant's value does not change after truncation. 976 if (sourceValue == targetAttr.getValue().convertToDouble()) 977 return targetAttr; 978 979 return {}; 980 } 981 982 bool arith::TruncFOp::areCastCompatible(TypeRange inputs, TypeRange outputs) { 983 return checkWidthChangeCast<std::less, FloatType>(inputs, outputs); 984 } 985 986 LogicalResult arith::TruncFOp::verify() { 987 return verifyTruncateOp<FloatType>(*this); 988 } 989 990 //===----------------------------------------------------------------------===// 991 // AndIOp 992 //===----------------------------------------------------------------------===// 993 994 void arith::AndIOp::getCanonicalizationPatterns( 995 RewritePatternSet &patterns, MLIRContext *context) { 996 patterns.add<AndOfExtUI, AndOfExtSI>(context); 997 } 998 999 //===----------------------------------------------------------------------===// 1000 // OrIOp 1001 //===----------------------------------------------------------------------===// 1002 1003 void arith::OrIOp::getCanonicalizationPatterns( 1004 RewritePatternSet &patterns, MLIRContext *context) { 1005 patterns.add<OrOfExtUI, OrOfExtSI>(context); 1006 } 1007 1008 //===----------------------------------------------------------------------===// 1009 // Verifiers for casts between integers and floats. 1010 //===----------------------------------------------------------------------===// 1011 1012 template <typename From, typename To> 1013 static bool checkIntFloatCast(TypeRange inputs, TypeRange outputs) { 1014 if (!areValidCastInputsAndOutputs(inputs, outputs)) 1015 return false; 1016 1017 auto srcType = getTypeIfLike<From>(inputs.front()); 1018 auto dstType = getTypeIfLike<To>(outputs.back()); 1019 1020 return srcType && dstType; 1021 } 1022 1023 //===----------------------------------------------------------------------===// 1024 // UIToFPOp 1025 //===----------------------------------------------------------------------===// 1026 1027 bool arith::UIToFPOp::areCastCompatible(TypeRange inputs, TypeRange outputs) { 1028 return checkIntFloatCast<IntegerType, FloatType>(inputs, outputs); 1029 } 1030 1031 OpFoldResult arith::UIToFPOp::fold(ArrayRef<Attribute> operands) { 1032 Type resType = getType(); 1033 Type resEleType; 1034 if (auto shapedType = resType.dyn_cast<ShapedType>()) 1035 resEleType = shapedType.getElementType(); 1036 else 1037 resEleType = resType; 1038 return constFoldCastOp<IntegerAttr, FloatAttr>( 1039 operands, getType(), [&resEleType](const APInt &a, bool &castStatus) { 1040 FloatType floatTy = resEleType.cast<FloatType>(); 1041 APFloat apf(floatTy.getFloatSemantics(), 1042 APInt::getZero(floatTy.getWidth())); 1043 apf.convertFromAPInt(a, /*IsSigned=*/false, 1044 APFloat::rmNearestTiesToEven); 1045 return apf; 1046 }); 1047 } 1048 1049 //===----------------------------------------------------------------------===// 1050 // SIToFPOp 1051 //===----------------------------------------------------------------------===// 1052 1053 bool arith::SIToFPOp::areCastCompatible(TypeRange inputs, TypeRange outputs) { 1054 return checkIntFloatCast<IntegerType, FloatType>(inputs, outputs); 1055 } 1056 1057 OpFoldResult arith::SIToFPOp::fold(ArrayRef<Attribute> operands) { 1058 Type resType = getType(); 1059 Type resEleType; 1060 if (auto shapedType = resType.dyn_cast<ShapedType>()) 1061 resEleType = shapedType.getElementType(); 1062 else 1063 resEleType = resType; 1064 return constFoldCastOp<IntegerAttr, FloatAttr>( 1065 operands, getType(), [&resEleType](const APInt &a, bool &castStatus) { 1066 FloatType floatTy = resEleType.cast<FloatType>(); 1067 APFloat apf(floatTy.getFloatSemantics(), 1068 APInt::getZero(floatTy.getWidth())); 1069 apf.convertFromAPInt(a, /*IsSigned=*/true, 1070 APFloat::rmNearestTiesToEven); 1071 return apf; 1072 }); 1073 } 1074 //===----------------------------------------------------------------------===// 1075 // FPToUIOp 1076 //===----------------------------------------------------------------------===// 1077 1078 bool arith::FPToUIOp::areCastCompatible(TypeRange inputs, TypeRange outputs) { 1079 return checkIntFloatCast<FloatType, IntegerType>(inputs, outputs); 1080 } 1081 1082 OpFoldResult arith::FPToUIOp::fold(ArrayRef<Attribute> operands) { 1083 Type resType = getType(); 1084 Type resEleType; 1085 if (auto shapedType = resType.dyn_cast<ShapedType>()) 1086 resEleType = shapedType.getElementType(); 1087 else 1088 resEleType = resType; 1089 return constFoldCastOp<FloatAttr, IntegerAttr>( 1090 operands, getType(), [&resEleType](const APFloat &a, bool &castStatus) { 1091 IntegerType intTy = resEleType.cast<IntegerType>(); 1092 bool ignored; 1093 APSInt api(intTy.getWidth(), /*isUnsigned=*/true); 1094 castStatus = APFloat::opInvalidOp != 1095 a.convertToInteger(api, APFloat::rmTowardZero, &ignored); 1096 return api; 1097 }); 1098 } 1099 1100 //===----------------------------------------------------------------------===// 1101 // FPToSIOp 1102 //===----------------------------------------------------------------------===// 1103 1104 bool arith::FPToSIOp::areCastCompatible(TypeRange inputs, TypeRange outputs) { 1105 return checkIntFloatCast<FloatType, IntegerType>(inputs, outputs); 1106 } 1107 1108 OpFoldResult arith::FPToSIOp::fold(ArrayRef<Attribute> operands) { 1109 Type resType = getType(); 1110 Type resEleType; 1111 if (auto shapedType = resType.dyn_cast<ShapedType>()) 1112 resEleType = shapedType.getElementType(); 1113 else 1114 resEleType = resType; 1115 return constFoldCastOp<FloatAttr, IntegerAttr>( 1116 operands, getType(), [&resEleType](const APFloat &a, bool &castStatus) { 1117 IntegerType intTy = resEleType.cast<IntegerType>(); 1118 bool ignored; 1119 APSInt api(intTy.getWidth(), /*isUnsigned=*/false); 1120 castStatus = APFloat::opInvalidOp != 1121 a.convertToInteger(api, APFloat::rmTowardZero, &ignored); 1122 return api; 1123 }); 1124 } 1125 1126 //===----------------------------------------------------------------------===// 1127 // IndexCastOp 1128 //===----------------------------------------------------------------------===// 1129 1130 bool arith::IndexCastOp::areCastCompatible(TypeRange inputs, 1131 TypeRange outputs) { 1132 if (!areValidCastInputsAndOutputs(inputs, outputs)) 1133 return false; 1134 1135 auto srcType = getTypeIfLikeOrMemRef<IntegerType, IndexType>(inputs.front()); 1136 auto dstType = getTypeIfLikeOrMemRef<IntegerType, IndexType>(outputs.front()); 1137 if (!srcType || !dstType) 1138 return false; 1139 1140 return (srcType.isIndex() && dstType.isSignlessInteger()) || 1141 (srcType.isSignlessInteger() && dstType.isIndex()); 1142 } 1143 1144 OpFoldResult arith::IndexCastOp::fold(ArrayRef<Attribute> operands) { 1145 // index_cast(constant) -> constant 1146 // A little hack because we go through int. Otherwise, the size of the 1147 // constant might need to change. 1148 if (auto value = operands[0].dyn_cast_or_null<IntegerAttr>()) 1149 return IntegerAttr::get(getType(), value.getInt()); 1150 1151 return {}; 1152 } 1153 1154 void arith::IndexCastOp::getCanonicalizationPatterns( 1155 RewritePatternSet &patterns, MLIRContext *context) { 1156 patterns.add<IndexCastOfIndexCast, IndexCastOfExtSI>(context); 1157 } 1158 1159 //===----------------------------------------------------------------------===// 1160 // BitcastOp 1161 //===----------------------------------------------------------------------===// 1162 1163 bool arith::BitcastOp::areCastCompatible(TypeRange inputs, TypeRange outputs) { 1164 if (!areValidCastInputsAndOutputs(inputs, outputs)) 1165 return false; 1166 1167 auto srcType = 1168 getTypeIfLikeOrMemRef<IntegerType, IndexType, FloatType>(inputs.front()); 1169 auto dstType = 1170 getTypeIfLikeOrMemRef<IntegerType, IndexType, FloatType>(outputs.front()); 1171 if (!srcType || !dstType) 1172 return false; 1173 1174 return srcType.getIntOrFloatBitWidth() == dstType.getIntOrFloatBitWidth(); 1175 } 1176 1177 OpFoldResult arith::BitcastOp::fold(ArrayRef<Attribute> operands) { 1178 assert(operands.size() == 1 && "bitcast op expects 1 operand"); 1179 1180 auto resType = getType(); 1181 auto operand = operands[0]; 1182 if (!operand) 1183 return {}; 1184 1185 /// Bitcast dense elements. 1186 if (auto denseAttr = operand.dyn_cast_or_null<DenseElementsAttr>()) 1187 return denseAttr.bitcast(resType.cast<ShapedType>().getElementType()); 1188 /// Other shaped types unhandled. 1189 if (resType.isa<ShapedType>()) 1190 return {}; 1191 1192 /// Bitcast integer or float to integer or float. 1193 APInt bits = operand.isa<FloatAttr>() 1194 ? operand.cast<FloatAttr>().getValue().bitcastToAPInt() 1195 : operand.cast<IntegerAttr>().getValue(); 1196 1197 if (auto resFloatType = resType.dyn_cast<FloatType>()) 1198 return FloatAttr::get(resType, 1199 APFloat(resFloatType.getFloatSemantics(), bits)); 1200 return IntegerAttr::get(resType, bits); 1201 } 1202 1203 void arith::BitcastOp::getCanonicalizationPatterns( 1204 RewritePatternSet &patterns, MLIRContext *context) { 1205 patterns.add<BitcastOfBitcast>(context); 1206 } 1207 1208 //===----------------------------------------------------------------------===// 1209 // Helpers for compare ops 1210 //===----------------------------------------------------------------------===// 1211 1212 /// Return the type of the same shape (scalar, vector or tensor) containing i1. 1213 static Type getI1SameShape(Type type) { 1214 auto i1Type = IntegerType::get(type.getContext(), 1); 1215 if (auto tensorType = type.dyn_cast<RankedTensorType>()) 1216 return RankedTensorType::get(tensorType.getShape(), i1Type); 1217 if (type.isa<UnrankedTensorType>()) 1218 return UnrankedTensorType::get(i1Type); 1219 if (auto vectorType = type.dyn_cast<VectorType>()) 1220 return VectorType::get(vectorType.getShape(), i1Type, 1221 vectorType.getNumScalableDims()); 1222 return i1Type; 1223 } 1224 1225 //===----------------------------------------------------------------------===// 1226 // CmpIOp 1227 //===----------------------------------------------------------------------===// 1228 1229 /// Compute `lhs` `pred` `rhs`, where `pred` is one of the known integer 1230 /// comparison predicates. 1231 bool mlir::arith::applyCmpPredicate(arith::CmpIPredicate predicate, 1232 const APInt &lhs, const APInt &rhs) { 1233 switch (predicate) { 1234 case arith::CmpIPredicate::eq: 1235 return lhs.eq(rhs); 1236 case arith::CmpIPredicate::ne: 1237 return lhs.ne(rhs); 1238 case arith::CmpIPredicate::slt: 1239 return lhs.slt(rhs); 1240 case arith::CmpIPredicate::sle: 1241 return lhs.sle(rhs); 1242 case arith::CmpIPredicate::sgt: 1243 return lhs.sgt(rhs); 1244 case arith::CmpIPredicate::sge: 1245 return lhs.sge(rhs); 1246 case arith::CmpIPredicate::ult: 1247 return lhs.ult(rhs); 1248 case arith::CmpIPredicate::ule: 1249 return lhs.ule(rhs); 1250 case arith::CmpIPredicate::ugt: 1251 return lhs.ugt(rhs); 1252 case arith::CmpIPredicate::uge: 1253 return lhs.uge(rhs); 1254 } 1255 llvm_unreachable("unknown cmpi predicate kind"); 1256 } 1257 1258 /// Returns true if the predicate is true for two equal operands. 1259 static bool applyCmpPredicateToEqualOperands(arith::CmpIPredicate predicate) { 1260 switch (predicate) { 1261 case arith::CmpIPredicate::eq: 1262 case arith::CmpIPredicate::sle: 1263 case arith::CmpIPredicate::sge: 1264 case arith::CmpIPredicate::ule: 1265 case arith::CmpIPredicate::uge: 1266 return true; 1267 case arith::CmpIPredicate::ne: 1268 case arith::CmpIPredicate::slt: 1269 case arith::CmpIPredicate::sgt: 1270 case arith::CmpIPredicate::ult: 1271 case arith::CmpIPredicate::ugt: 1272 return false; 1273 } 1274 llvm_unreachable("unknown cmpi predicate kind"); 1275 } 1276 1277 static Attribute getBoolAttribute(Type type, MLIRContext *ctx, bool value) { 1278 auto boolAttr = BoolAttr::get(ctx, value); 1279 ShapedType shapedType = type.dyn_cast_or_null<ShapedType>(); 1280 if (!shapedType) 1281 return boolAttr; 1282 return DenseElementsAttr::get(shapedType, boolAttr); 1283 } 1284 1285 OpFoldResult arith::CmpIOp::fold(ArrayRef<Attribute> operands) { 1286 assert(operands.size() == 2 && "cmpi takes two operands"); 1287 1288 // cmpi(pred, x, x) 1289 if (getLhs() == getRhs()) { 1290 auto val = applyCmpPredicateToEqualOperands(getPredicate()); 1291 return getBoolAttribute(getType(), getContext(), val); 1292 } 1293 1294 if (matchPattern(getRhs(), m_Zero())) { 1295 if (auto extOp = getLhs().getDefiningOp<ExtSIOp>()) { 1296 if (extOp.getOperand().getType().cast<IntegerType>().getWidth() == 1) { 1297 // extsi(%x : i1 -> iN) != 0 -> %x 1298 if (getPredicate() == arith::CmpIPredicate::ne) { 1299 return extOp.getOperand(); 1300 } 1301 } 1302 } 1303 if (auto extOp = getLhs().getDefiningOp<ExtUIOp>()) { 1304 if (extOp.getOperand().getType().cast<IntegerType>().getWidth() == 1) { 1305 // extui(%x : i1 -> iN) != 0 -> %x 1306 if (getPredicate() == arith::CmpIPredicate::ne) { 1307 return extOp.getOperand(); 1308 } 1309 } 1310 } 1311 } 1312 1313 auto lhs = operands.front().dyn_cast_or_null<IntegerAttr>(); 1314 auto rhs = operands.back().dyn_cast_or_null<IntegerAttr>(); 1315 if (!lhs || !rhs) 1316 return {}; 1317 1318 auto val = applyCmpPredicate(getPredicate(), lhs.getValue(), rhs.getValue()); 1319 return BoolAttr::get(getContext(), val); 1320 } 1321 1322 void arith::CmpIOp::getCanonicalizationPatterns(RewritePatternSet &patterns, 1323 MLIRContext *context) { 1324 patterns.insert<CmpIExtSI, CmpIExtUI>(context); 1325 } 1326 1327 //===----------------------------------------------------------------------===// 1328 // CmpFOp 1329 //===----------------------------------------------------------------------===// 1330 1331 /// Compute `lhs` `pred` `rhs`, where `pred` is one of the known floating point 1332 /// comparison predicates. 1333 bool mlir::arith::applyCmpPredicate(arith::CmpFPredicate predicate, 1334 const APFloat &lhs, const APFloat &rhs) { 1335 auto cmpResult = lhs.compare(rhs); 1336 switch (predicate) { 1337 case arith::CmpFPredicate::AlwaysFalse: 1338 return false; 1339 case arith::CmpFPredicate::OEQ: 1340 return cmpResult == APFloat::cmpEqual; 1341 case arith::CmpFPredicate::OGT: 1342 return cmpResult == APFloat::cmpGreaterThan; 1343 case arith::CmpFPredicate::OGE: 1344 return cmpResult == APFloat::cmpGreaterThan || 1345 cmpResult == APFloat::cmpEqual; 1346 case arith::CmpFPredicate::OLT: 1347 return cmpResult == APFloat::cmpLessThan; 1348 case arith::CmpFPredicate::OLE: 1349 return cmpResult == APFloat::cmpLessThan || cmpResult == APFloat::cmpEqual; 1350 case arith::CmpFPredicate::ONE: 1351 return cmpResult != APFloat::cmpUnordered && cmpResult != APFloat::cmpEqual; 1352 case arith::CmpFPredicate::ORD: 1353 return cmpResult != APFloat::cmpUnordered; 1354 case arith::CmpFPredicate::UEQ: 1355 return cmpResult == APFloat::cmpUnordered || cmpResult == APFloat::cmpEqual; 1356 case arith::CmpFPredicate::UGT: 1357 return cmpResult == APFloat::cmpUnordered || 1358 cmpResult == APFloat::cmpGreaterThan; 1359 case arith::CmpFPredicate::UGE: 1360 return cmpResult == APFloat::cmpUnordered || 1361 cmpResult == APFloat::cmpGreaterThan || 1362 cmpResult == APFloat::cmpEqual; 1363 case arith::CmpFPredicate::ULT: 1364 return cmpResult == APFloat::cmpUnordered || 1365 cmpResult == APFloat::cmpLessThan; 1366 case arith::CmpFPredicate::ULE: 1367 return cmpResult == APFloat::cmpUnordered || 1368 cmpResult == APFloat::cmpLessThan || cmpResult == APFloat::cmpEqual; 1369 case arith::CmpFPredicate::UNE: 1370 return cmpResult != APFloat::cmpEqual; 1371 case arith::CmpFPredicate::UNO: 1372 return cmpResult == APFloat::cmpUnordered; 1373 case arith::CmpFPredicate::AlwaysTrue: 1374 return true; 1375 } 1376 llvm_unreachable("unknown cmpf predicate kind"); 1377 } 1378 1379 OpFoldResult arith::CmpFOp::fold(ArrayRef<Attribute> operands) { 1380 assert(operands.size() == 2 && "cmpf takes two operands"); 1381 1382 auto lhs = operands.front().dyn_cast_or_null<FloatAttr>(); 1383 auto rhs = operands.back().dyn_cast_or_null<FloatAttr>(); 1384 1385 // If one operand is NaN, making them both NaN does not change the result. 1386 if (lhs && lhs.getValue().isNaN()) 1387 rhs = lhs; 1388 if (rhs && rhs.getValue().isNaN()) 1389 lhs = rhs; 1390 1391 if (!lhs || !rhs) 1392 return {}; 1393 1394 auto val = applyCmpPredicate(getPredicate(), lhs.getValue(), rhs.getValue()); 1395 return BoolAttr::get(getContext(), val); 1396 } 1397 1398 class CmpFIntToFPConst final : public OpRewritePattern<CmpFOp> { 1399 public: 1400 using OpRewritePattern<CmpFOp>::OpRewritePattern; 1401 1402 static CmpIPredicate convertToIntegerPredicate(CmpFPredicate pred, 1403 bool isUnsigned) { 1404 using namespace arith; 1405 switch (pred) { 1406 case CmpFPredicate::UEQ: 1407 case CmpFPredicate::OEQ: 1408 return CmpIPredicate::eq; 1409 case CmpFPredicate::UGT: 1410 case CmpFPredicate::OGT: 1411 return isUnsigned ? CmpIPredicate::ugt : CmpIPredicate::sgt; 1412 case CmpFPredicate::UGE: 1413 case CmpFPredicate::OGE: 1414 return isUnsigned ? CmpIPredicate::uge : CmpIPredicate::sge; 1415 case CmpFPredicate::ULT: 1416 case CmpFPredicate::OLT: 1417 return isUnsigned ? CmpIPredicate::ult : CmpIPredicate::slt; 1418 case CmpFPredicate::ULE: 1419 case CmpFPredicate::OLE: 1420 return isUnsigned ? CmpIPredicate::ule : CmpIPredicate::sle; 1421 case CmpFPredicate::UNE: 1422 case CmpFPredicate::ONE: 1423 return CmpIPredicate::ne; 1424 default: 1425 llvm_unreachable("Unexpected predicate!"); 1426 } 1427 } 1428 1429 LogicalResult matchAndRewrite(CmpFOp op, 1430 PatternRewriter &rewriter) const override { 1431 FloatAttr flt; 1432 if (!matchPattern(op.getRhs(), m_Constant(&flt))) 1433 return failure(); 1434 1435 const APFloat &rhs = flt.getValue(); 1436 1437 // Don't attempt to fold a nan. 1438 if (rhs.isNaN()) 1439 return failure(); 1440 1441 // Get the width of the mantissa. We don't want to hack on conversions that 1442 // might lose information from the integer, e.g. "i64 -> float" 1443 FloatType floatTy = op.getRhs().getType().cast<FloatType>(); 1444 int mantissaWidth = floatTy.getFPMantissaWidth(); 1445 if (mantissaWidth <= 0) 1446 return failure(); 1447 1448 bool isUnsigned; 1449 Value intVal; 1450 1451 if (auto si = op.getLhs().getDefiningOp<SIToFPOp>()) { 1452 isUnsigned = false; 1453 intVal = si.getIn(); 1454 } else if (auto ui = op.getLhs().getDefiningOp<UIToFPOp>()) { 1455 isUnsigned = true; 1456 intVal = ui.getIn(); 1457 } else { 1458 return failure(); 1459 } 1460 1461 // Check to see that the input is converted from an integer type that is 1462 // small enough that preserves all bits. 1463 auto intTy = intVal.getType().cast<IntegerType>(); 1464 auto intWidth = intTy.getWidth(); 1465 1466 // Number of bits representing values, as opposed to the sign 1467 auto valueBits = isUnsigned ? intWidth : (intWidth - 1); 1468 1469 // Following test does NOT adjust intWidth downwards for signed inputs, 1470 // because the most negative value still requires all the mantissa bits 1471 // to distinguish it from one less than that value. 1472 if ((int)intWidth > mantissaWidth) { 1473 // Conversion would lose accuracy. Check if loss can impact comparison. 1474 int exponent = ilogb(rhs); 1475 if (exponent == APFloat::IEK_Inf) { 1476 int maxExponent = ilogb(APFloat::getLargest(rhs.getSemantics())); 1477 if (maxExponent < (int)valueBits) { 1478 // Conversion could create infinity. 1479 return failure(); 1480 } 1481 } else { 1482 // Note that if rhs is zero or NaN, then Exp is negative 1483 // and first condition is trivially false. 1484 if (mantissaWidth <= exponent && exponent <= (int)valueBits) { 1485 // Conversion could affect comparison. 1486 return failure(); 1487 } 1488 } 1489 } 1490 1491 // Convert to equivalent cmpi predicate 1492 CmpIPredicate pred; 1493 switch (op.getPredicate()) { 1494 case CmpFPredicate::ORD: 1495 // Int to fp conversion doesn't create a nan (ord checks neither is a nan) 1496 rewriter.replaceOpWithNewOp<ConstantIntOp>(op, /*value=*/true, 1497 /*width=*/1); 1498 return success(); 1499 case CmpFPredicate::UNO: 1500 // Int to fp conversion doesn't create a nan (uno checks either is a nan) 1501 rewriter.replaceOpWithNewOp<ConstantIntOp>(op, /*value=*/false, 1502 /*width=*/1); 1503 return success(); 1504 default: 1505 pred = convertToIntegerPredicate(op.getPredicate(), isUnsigned); 1506 break; 1507 } 1508 1509 if (!isUnsigned) { 1510 // If the rhs value is > SignedMax, fold the comparison. This handles 1511 // +INF and large values. 1512 APFloat signedMax(rhs.getSemantics()); 1513 signedMax.convertFromAPInt(APInt::getSignedMaxValue(intWidth), true, 1514 APFloat::rmNearestTiesToEven); 1515 if (signedMax < rhs) { // smax < 13123.0 1516 if (pred == CmpIPredicate::ne || pred == CmpIPredicate::slt || 1517 pred == CmpIPredicate::sle) 1518 rewriter.replaceOpWithNewOp<ConstantIntOp>(op, /*value=*/true, 1519 /*width=*/1); 1520 else 1521 rewriter.replaceOpWithNewOp<ConstantIntOp>(op, /*value=*/false, 1522 /*width=*/1); 1523 return success(); 1524 } 1525 } else { 1526 // If the rhs value is > UnsignedMax, fold the comparison. This handles 1527 // +INF and large values. 1528 APFloat unsignedMax(rhs.getSemantics()); 1529 unsignedMax.convertFromAPInt(APInt::getMaxValue(intWidth), false, 1530 APFloat::rmNearestTiesToEven); 1531 if (unsignedMax < rhs) { // umax < 13123.0 1532 if (pred == CmpIPredicate::ne || pred == CmpIPredicate::ult || 1533 pred == CmpIPredicate::ule) 1534 rewriter.replaceOpWithNewOp<ConstantIntOp>(op, /*value=*/true, 1535 /*width=*/1); 1536 else 1537 rewriter.replaceOpWithNewOp<ConstantIntOp>(op, /*value=*/false, 1538 /*width=*/1); 1539 return success(); 1540 } 1541 } 1542 1543 if (!isUnsigned) { 1544 // See if the rhs value is < SignedMin. 1545 APFloat signedMin(rhs.getSemantics()); 1546 signedMin.convertFromAPInt(APInt::getSignedMinValue(intWidth), true, 1547 APFloat::rmNearestTiesToEven); 1548 if (signedMin > rhs) { // smin > 12312.0 1549 if (pred == CmpIPredicate::ne || pred == CmpIPredicate::sgt || 1550 pred == CmpIPredicate::sge) 1551 rewriter.replaceOpWithNewOp<ConstantIntOp>(op, /*value=*/true, 1552 /*width=*/1); 1553 else 1554 rewriter.replaceOpWithNewOp<ConstantIntOp>(op, /*value=*/false, 1555 /*width=*/1); 1556 return success(); 1557 } 1558 } else { 1559 // See if the rhs value is < UnsignedMin. 1560 APFloat unsignedMin(rhs.getSemantics()); 1561 unsignedMin.convertFromAPInt(APInt::getMinValue(intWidth), false, 1562 APFloat::rmNearestTiesToEven); 1563 if (unsignedMin > rhs) { // umin > 12312.0 1564 if (pred == CmpIPredicate::ne || pred == CmpIPredicate::ugt || 1565 pred == CmpIPredicate::uge) 1566 rewriter.replaceOpWithNewOp<ConstantIntOp>(op, /*value=*/true, 1567 /*width=*/1); 1568 else 1569 rewriter.replaceOpWithNewOp<ConstantIntOp>(op, /*value=*/false, 1570 /*width=*/1); 1571 return success(); 1572 } 1573 } 1574 1575 // Okay, now we know that the FP constant fits in the range [SMIN, SMAX] or 1576 // [0, UMAX], but it may still be fractional. See if it is fractional by 1577 // casting the FP value to the integer value and back, checking for 1578 // equality. Don't do this for zero, because -0.0 is not fractional. 1579 bool ignored; 1580 APSInt rhsInt(intWidth, isUnsigned); 1581 if (APFloat::opInvalidOp == 1582 rhs.convertToInteger(rhsInt, APFloat::rmTowardZero, &ignored)) { 1583 // Undefined behavior invoked - the destination type can't represent 1584 // the input constant. 1585 return failure(); 1586 } 1587 1588 if (!rhs.isZero()) { 1589 APFloat apf(floatTy.getFloatSemantics(), 1590 APInt::getZero(floatTy.getWidth())); 1591 apf.convertFromAPInt(rhsInt, !isUnsigned, APFloat::rmNearestTiesToEven); 1592 1593 bool equal = apf == rhs; 1594 if (!equal) { 1595 // If we had a comparison against a fractional value, we have to adjust 1596 // the compare predicate and sometimes the value. rhsInt is rounded 1597 // towards zero at this point. 1598 switch (pred) { 1599 case CmpIPredicate::ne: // (float)int != 4.4 --> true 1600 rewriter.replaceOpWithNewOp<ConstantIntOp>(op, /*value=*/true, 1601 /*width=*/1); 1602 return success(); 1603 case CmpIPredicate::eq: // (float)int == 4.4 --> false 1604 rewriter.replaceOpWithNewOp<ConstantIntOp>(op, /*value=*/false, 1605 /*width=*/1); 1606 return success(); 1607 case CmpIPredicate::ule: 1608 // (float)int <= 4.4 --> int <= 4 1609 // (float)int <= -4.4 --> false 1610 if (rhs.isNegative()) { 1611 rewriter.replaceOpWithNewOp<ConstantIntOp>(op, /*value=*/false, 1612 /*width=*/1); 1613 return success(); 1614 } 1615 break; 1616 case CmpIPredicate::sle: 1617 // (float)int <= 4.4 --> int <= 4 1618 // (float)int <= -4.4 --> int < -4 1619 if (rhs.isNegative()) 1620 pred = CmpIPredicate::slt; 1621 break; 1622 case CmpIPredicate::ult: 1623 // (float)int < -4.4 --> false 1624 // (float)int < 4.4 --> int <= 4 1625 if (rhs.isNegative()) { 1626 rewriter.replaceOpWithNewOp<ConstantIntOp>(op, /*value=*/false, 1627 /*width=*/1); 1628 return success(); 1629 } 1630 pred = CmpIPredicate::ule; 1631 break; 1632 case CmpIPredicate::slt: 1633 // (float)int < -4.4 --> int < -4 1634 // (float)int < 4.4 --> int <= 4 1635 if (!rhs.isNegative()) 1636 pred = CmpIPredicate::sle; 1637 break; 1638 case CmpIPredicate::ugt: 1639 // (float)int > 4.4 --> int > 4 1640 // (float)int > -4.4 --> true 1641 if (rhs.isNegative()) { 1642 rewriter.replaceOpWithNewOp<ConstantIntOp>(op, /*value=*/true, 1643 /*width=*/1); 1644 return success(); 1645 } 1646 break; 1647 case CmpIPredicate::sgt: 1648 // (float)int > 4.4 --> int > 4 1649 // (float)int > -4.4 --> int >= -4 1650 if (rhs.isNegative()) 1651 pred = CmpIPredicate::sge; 1652 break; 1653 case CmpIPredicate::uge: 1654 // (float)int >= -4.4 --> true 1655 // (float)int >= 4.4 --> int > 4 1656 if (rhs.isNegative()) { 1657 rewriter.replaceOpWithNewOp<ConstantIntOp>(op, /*value=*/true, 1658 /*width=*/1); 1659 return success(); 1660 } 1661 pred = CmpIPredicate::ugt; 1662 break; 1663 case CmpIPredicate::sge: 1664 // (float)int >= -4.4 --> int >= -4 1665 // (float)int >= 4.4 --> int > 4 1666 if (!rhs.isNegative()) 1667 pred = CmpIPredicate::sgt; 1668 break; 1669 } 1670 } 1671 } 1672 1673 // Lower this FP comparison into an appropriate integer version of the 1674 // comparison. 1675 rewriter.replaceOpWithNewOp<CmpIOp>( 1676 op, pred, intVal, 1677 rewriter.create<ConstantOp>( 1678 op.getLoc(), intVal.getType(), 1679 rewriter.getIntegerAttr(intVal.getType(), rhsInt))); 1680 return success(); 1681 } 1682 }; 1683 1684 void arith::CmpFOp::getCanonicalizationPatterns(RewritePatternSet &patterns, 1685 MLIRContext *context) { 1686 patterns.insert<CmpFIntToFPConst>(context); 1687 } 1688 1689 //===----------------------------------------------------------------------===// 1690 // SelectOp 1691 //===----------------------------------------------------------------------===// 1692 1693 // Transforms a select of a boolean to arithmetic operations 1694 // 1695 // arith.select %arg, %x, %y : i1 1696 // 1697 // becomes 1698 // 1699 // and(%arg, %x) or and(!%arg, %y) 1700 struct SelectI1Simplify : public OpRewritePattern<arith::SelectOp> { 1701 using OpRewritePattern<arith::SelectOp>::OpRewritePattern; 1702 1703 LogicalResult matchAndRewrite(arith::SelectOp op, 1704 PatternRewriter &rewriter) const override { 1705 if (!op.getType().isInteger(1)) 1706 return failure(); 1707 1708 Value falseConstant = 1709 rewriter.create<arith::ConstantIntOp>(op.getLoc(), true, 1); 1710 Value notCondition = rewriter.create<arith::XOrIOp>( 1711 op.getLoc(), op.getCondition(), falseConstant); 1712 1713 Value trueVal = rewriter.create<arith::AndIOp>( 1714 op.getLoc(), op.getCondition(), op.getTrueValue()); 1715 Value falseVal = rewriter.create<arith::AndIOp>(op.getLoc(), notCondition, 1716 op.getFalseValue()); 1717 rewriter.replaceOpWithNewOp<arith::OrIOp>(op, trueVal, falseVal); 1718 return success(); 1719 } 1720 }; 1721 1722 // select %arg, %c1, %c0 => extui %arg 1723 struct SelectToExtUI : public OpRewritePattern<arith::SelectOp> { 1724 using OpRewritePattern<arith::SelectOp>::OpRewritePattern; 1725 1726 LogicalResult matchAndRewrite(arith::SelectOp op, 1727 PatternRewriter &rewriter) const override { 1728 // Cannot extui i1 to i1, or i1 to f32 1729 if (!op.getType().isa<IntegerType>() || op.getType().isInteger(1)) 1730 return failure(); 1731 1732 // select %x, c1, %c0 => extui %arg 1733 if (matchPattern(op.getTrueValue(), m_One())) 1734 if (matchPattern(op.getFalseValue(), m_Zero())) { 1735 rewriter.replaceOpWithNewOp<arith::ExtUIOp>(op, op.getType(), 1736 op.getCondition()); 1737 return success(); 1738 } 1739 1740 // select %x, c0, %c1 => extui (xor %arg, true) 1741 if (matchPattern(op.getTrueValue(), m_Zero())) 1742 if (matchPattern(op.getFalseValue(), m_One())) { 1743 rewriter.replaceOpWithNewOp<arith::ExtUIOp>( 1744 op, op.getType(), 1745 rewriter.create<arith::XOrIOp>( 1746 op.getLoc(), op.getCondition(), 1747 rewriter.create<arith::ConstantIntOp>( 1748 op.getLoc(), 1, op.getCondition().getType()))); 1749 return success(); 1750 } 1751 1752 return failure(); 1753 } 1754 }; 1755 1756 void arith::SelectOp::getCanonicalizationPatterns(RewritePatternSet &results, 1757 MLIRContext *context) { 1758 results.add<SelectI1Simplify, SelectToExtUI>(context); 1759 } 1760 1761 OpFoldResult arith::SelectOp::fold(ArrayRef<Attribute> operands) { 1762 Value trueVal = getTrueValue(); 1763 Value falseVal = getFalseValue(); 1764 if (trueVal == falseVal) 1765 return trueVal; 1766 1767 Value condition = getCondition(); 1768 1769 // select true, %0, %1 => %0 1770 if (matchPattern(condition, m_One())) 1771 return trueVal; 1772 1773 // select false, %0, %1 => %1 1774 if (matchPattern(condition, m_Zero())) 1775 return falseVal; 1776 1777 // select %x, true, false => %x 1778 if (getType().isInteger(1)) 1779 if (matchPattern(getTrueValue(), m_One())) 1780 if (matchPattern(getFalseValue(), m_Zero())) 1781 return condition; 1782 1783 if (auto cmp = dyn_cast_or_null<arith::CmpIOp>(condition.getDefiningOp())) { 1784 auto pred = cmp.getPredicate(); 1785 if (pred == arith::CmpIPredicate::eq || pred == arith::CmpIPredicate::ne) { 1786 auto cmpLhs = cmp.getLhs(); 1787 auto cmpRhs = cmp.getRhs(); 1788 1789 // %0 = arith.cmpi eq, %arg0, %arg1 1790 // %1 = arith.select %0, %arg0, %arg1 => %arg1 1791 1792 // %0 = arith.cmpi ne, %arg0, %arg1 1793 // %1 = arith.select %0, %arg0, %arg1 => %arg0 1794 1795 if ((cmpLhs == trueVal && cmpRhs == falseVal) || 1796 (cmpRhs == trueVal && cmpLhs == falseVal)) 1797 return pred == arith::CmpIPredicate::ne ? trueVal : falseVal; 1798 } 1799 } 1800 return nullptr; 1801 } 1802 1803 ParseResult SelectOp::parse(OpAsmParser &parser, OperationState &result) { 1804 Type conditionType, resultType; 1805 SmallVector<OpAsmParser::UnresolvedOperand, 3> operands; 1806 if (parser.parseOperandList(operands, /*requiredOperandCount=*/3) || 1807 parser.parseOptionalAttrDict(result.attributes) || 1808 parser.parseColonType(resultType)) 1809 return failure(); 1810 1811 // Check for the explicit condition type if this is a masked tensor or vector. 1812 if (succeeded(parser.parseOptionalComma())) { 1813 conditionType = resultType; 1814 if (parser.parseType(resultType)) 1815 return failure(); 1816 } else { 1817 conditionType = parser.getBuilder().getI1Type(); 1818 } 1819 1820 result.addTypes(resultType); 1821 return parser.resolveOperands(operands, 1822 {conditionType, resultType, resultType}, 1823 parser.getNameLoc(), result.operands); 1824 } 1825 1826 void arith::SelectOp::print(OpAsmPrinter &p) { 1827 p << " " << getOperands(); 1828 p.printOptionalAttrDict((*this)->getAttrs()); 1829 p << " : "; 1830 if (ShapedType condType = getCondition().getType().dyn_cast<ShapedType>()) 1831 p << condType << ", "; 1832 p << getType(); 1833 } 1834 1835 LogicalResult arith::SelectOp::verify() { 1836 Type conditionType = getCondition().getType(); 1837 if (conditionType.isSignlessInteger(1)) 1838 return success(); 1839 1840 // If the result type is a vector or tensor, the type can be a mask with the 1841 // same elements. 1842 Type resultType = getType(); 1843 if (!resultType.isa<TensorType, VectorType>()) 1844 return emitOpError() << "expected condition to be a signless i1, but got " 1845 << conditionType; 1846 Type shapedConditionType = getI1SameShape(resultType); 1847 if (conditionType != shapedConditionType) { 1848 return emitOpError() << "expected condition type to have the same shape " 1849 "as the result type, expected " 1850 << shapedConditionType << ", but got " 1851 << conditionType; 1852 } 1853 return success(); 1854 } 1855 //===----------------------------------------------------------------------===// 1856 // ShLIOp 1857 //===----------------------------------------------------------------------===// 1858 1859 OpFoldResult arith::ShLIOp::fold(ArrayRef<Attribute> operands) { 1860 // Don't fold if shifting more than the bit width. 1861 bool bounded = false; 1862 auto result = constFoldBinaryOp<IntegerAttr>( 1863 operands, [&](const APInt &a, const APInt &b) { 1864 bounded = b.ule(b.getBitWidth()); 1865 return a.shl(b); 1866 }); 1867 return bounded ? result : Attribute(); 1868 } 1869 1870 //===----------------------------------------------------------------------===// 1871 // ShRUIOp 1872 //===----------------------------------------------------------------------===// 1873 1874 OpFoldResult arith::ShRUIOp::fold(ArrayRef<Attribute> operands) { 1875 // Don't fold if shifting more than the bit width. 1876 bool bounded = false; 1877 auto result = constFoldBinaryOp<IntegerAttr>( 1878 operands, [&](const APInt &a, const APInt &b) { 1879 bounded = b.ule(b.getBitWidth()); 1880 return a.lshr(b); 1881 }); 1882 return bounded ? result : Attribute(); 1883 } 1884 1885 //===----------------------------------------------------------------------===// 1886 // ShRSIOp 1887 //===----------------------------------------------------------------------===// 1888 1889 OpFoldResult arith::ShRSIOp::fold(ArrayRef<Attribute> operands) { 1890 // Don't fold if shifting more than the bit width. 1891 bool bounded = false; 1892 auto result = constFoldBinaryOp<IntegerAttr>( 1893 operands, [&](const APInt &a, const APInt &b) { 1894 bounded = b.ule(b.getBitWidth()); 1895 return a.ashr(b); 1896 }); 1897 return bounded ? result : Attribute(); 1898 } 1899 1900 //===----------------------------------------------------------------------===// 1901 // Atomic Enum 1902 //===----------------------------------------------------------------------===// 1903 1904 /// Returns the identity value attribute associated with an AtomicRMWKind op. 1905 Attribute mlir::arith::getIdentityValueAttr(AtomicRMWKind kind, Type resultType, 1906 OpBuilder &builder, Location loc) { 1907 switch (kind) { 1908 case AtomicRMWKind::maxf: 1909 return builder.getFloatAttr( 1910 resultType, 1911 APFloat::getInf(resultType.cast<FloatType>().getFloatSemantics(), 1912 /*Negative=*/true)); 1913 case AtomicRMWKind::addf: 1914 case AtomicRMWKind::addi: 1915 case AtomicRMWKind::maxu: 1916 case AtomicRMWKind::ori: 1917 return builder.getZeroAttr(resultType); 1918 case AtomicRMWKind::andi: 1919 return builder.getIntegerAttr( 1920 resultType, 1921 APInt::getAllOnes(resultType.cast<IntegerType>().getWidth())); 1922 case AtomicRMWKind::maxs: 1923 return builder.getIntegerAttr( 1924 resultType, 1925 APInt::getSignedMinValue(resultType.cast<IntegerType>().getWidth())); 1926 case AtomicRMWKind::minf: 1927 return builder.getFloatAttr( 1928 resultType, 1929 APFloat::getInf(resultType.cast<FloatType>().getFloatSemantics(), 1930 /*Negative=*/false)); 1931 case AtomicRMWKind::mins: 1932 return builder.getIntegerAttr( 1933 resultType, 1934 APInt::getSignedMaxValue(resultType.cast<IntegerType>().getWidth())); 1935 case AtomicRMWKind::minu: 1936 return builder.getIntegerAttr( 1937 resultType, 1938 APInt::getMaxValue(resultType.cast<IntegerType>().getWidth())); 1939 case AtomicRMWKind::muli: 1940 return builder.getIntegerAttr(resultType, 1); 1941 case AtomicRMWKind::mulf: 1942 return builder.getFloatAttr(resultType, 1); 1943 // TODO: Add remaining reduction operations. 1944 default: 1945 (void)emitOptionalError(loc, "Reduction operation type not supported"); 1946 break; 1947 } 1948 return nullptr; 1949 } 1950 1951 /// Returns the identity value associated with an AtomicRMWKind op. 1952 Value mlir::arith::getIdentityValue(AtomicRMWKind op, Type resultType, 1953 OpBuilder &builder, Location loc) { 1954 Attribute attr = getIdentityValueAttr(op, resultType, builder, loc); 1955 return builder.create<arith::ConstantOp>(loc, attr); 1956 } 1957 1958 /// Return the value obtained by applying the reduction operation kind 1959 /// associated with a binary AtomicRMWKind op to `lhs` and `rhs`. 1960 Value mlir::arith::getReductionOp(AtomicRMWKind op, OpBuilder &builder, 1961 Location loc, Value lhs, Value rhs) { 1962 switch (op) { 1963 case AtomicRMWKind::addf: 1964 return builder.create<arith::AddFOp>(loc, lhs, rhs); 1965 case AtomicRMWKind::addi: 1966 return builder.create<arith::AddIOp>(loc, lhs, rhs); 1967 case AtomicRMWKind::mulf: 1968 return builder.create<arith::MulFOp>(loc, lhs, rhs); 1969 case AtomicRMWKind::muli: 1970 return builder.create<arith::MulIOp>(loc, lhs, rhs); 1971 case AtomicRMWKind::maxf: 1972 return builder.create<arith::MaxFOp>(loc, lhs, rhs); 1973 case AtomicRMWKind::minf: 1974 return builder.create<arith::MinFOp>(loc, lhs, rhs); 1975 case AtomicRMWKind::maxs: 1976 return builder.create<arith::MaxSIOp>(loc, lhs, rhs); 1977 case AtomicRMWKind::mins: 1978 return builder.create<arith::MinSIOp>(loc, lhs, rhs); 1979 case AtomicRMWKind::maxu: 1980 return builder.create<arith::MaxUIOp>(loc, lhs, rhs); 1981 case AtomicRMWKind::minu: 1982 return builder.create<arith::MinUIOp>(loc, lhs, rhs); 1983 case AtomicRMWKind::ori: 1984 return builder.create<arith::OrIOp>(loc, lhs, rhs); 1985 case AtomicRMWKind::andi: 1986 return builder.create<arith::AndIOp>(loc, lhs, rhs); 1987 // TODO: Add remaining reduction operations. 1988 default: 1989 (void)emitOptionalError(loc, "Reduction operation type not supported"); 1990 break; 1991 } 1992 return nullptr; 1993 } 1994 1995 //===----------------------------------------------------------------------===// 1996 // TableGen'd op method definitions 1997 //===----------------------------------------------------------------------===// 1998 1999 #define GET_OP_CLASSES 2000 #include "mlir/Dialect/Arithmetic/IR/ArithmeticOps.cpp.inc" 2001 2002 //===----------------------------------------------------------------------===// 2003 // TableGen'd enum attribute definitions 2004 //===----------------------------------------------------------------------===// 2005 2006 #include "mlir/Dialect/Arithmetic/IR/ArithmeticOpsEnums.cpp.inc" 2007