1 //===- ArithmeticOps.cpp - MLIR Arithmetic dialect ops implementation -----===// 2 // 3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. 4 // See https://llvm.org/LICENSE.txt for license information. 5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception 6 // 7 //===----------------------------------------------------------------------===// 8 9 #include <utility> 10 11 #include "mlir/Dialect/Arithmetic/IR/Arithmetic.h" 12 #include "mlir/Dialect/CommonFolders.h" 13 #include "mlir/IR/Builders.h" 14 #include "mlir/IR/Matchers.h" 15 #include "mlir/IR/OpImplementation.h" 16 #include "mlir/IR/PatternMatch.h" 17 #include "mlir/IR/TypeUtilities.h" 18 #include "llvm/ADT/SmallString.h" 19 20 #include "llvm/ADT/APSInt.h" 21 22 using namespace mlir; 23 using namespace mlir::arith; 24 25 //===----------------------------------------------------------------------===// 26 // Pattern helpers 27 //===----------------------------------------------------------------------===// 28 29 static IntegerAttr addIntegerAttrs(PatternRewriter &builder, Value res, 30 Attribute lhs, Attribute rhs) { 31 return builder.getIntegerAttr(res.getType(), 32 lhs.cast<IntegerAttr>().getInt() + 33 rhs.cast<IntegerAttr>().getInt()); 34 } 35 36 static IntegerAttr subIntegerAttrs(PatternRewriter &builder, Value res, 37 Attribute lhs, Attribute rhs) { 38 return builder.getIntegerAttr(res.getType(), 39 lhs.cast<IntegerAttr>().getInt() - 40 rhs.cast<IntegerAttr>().getInt()); 41 } 42 43 /// Invert an integer comparison predicate. 44 arith::CmpIPredicate arith::invertPredicate(arith::CmpIPredicate pred) { 45 switch (pred) { 46 case arith::CmpIPredicate::eq: 47 return arith::CmpIPredicate::ne; 48 case arith::CmpIPredicate::ne: 49 return arith::CmpIPredicate::eq; 50 case arith::CmpIPredicate::slt: 51 return arith::CmpIPredicate::sge; 52 case arith::CmpIPredicate::sle: 53 return arith::CmpIPredicate::sgt; 54 case arith::CmpIPredicate::sgt: 55 return arith::CmpIPredicate::sle; 56 case arith::CmpIPredicate::sge: 57 return arith::CmpIPredicate::slt; 58 case arith::CmpIPredicate::ult: 59 return arith::CmpIPredicate::uge; 60 case arith::CmpIPredicate::ule: 61 return arith::CmpIPredicate::ugt; 62 case arith::CmpIPredicate::ugt: 63 return arith::CmpIPredicate::ule; 64 case arith::CmpIPredicate::uge: 65 return arith::CmpIPredicate::ult; 66 } 67 llvm_unreachable("unknown cmpi predicate kind"); 68 } 69 70 static arith::CmpIPredicateAttr invertPredicate(arith::CmpIPredicateAttr pred) { 71 return arith::CmpIPredicateAttr::get(pred.getContext(), 72 invertPredicate(pred.getValue())); 73 } 74 75 //===----------------------------------------------------------------------===// 76 // TableGen'd canonicalization patterns 77 //===----------------------------------------------------------------------===// 78 79 namespace { 80 #include "ArithmeticCanonicalization.inc" 81 } // namespace 82 83 //===----------------------------------------------------------------------===// 84 // ConstantOp 85 //===----------------------------------------------------------------------===// 86 87 void arith::ConstantOp::getAsmResultNames( 88 function_ref<void(Value, StringRef)> setNameFn) { 89 auto type = getType(); 90 if (auto intCst = getValue().dyn_cast<IntegerAttr>()) { 91 auto intType = type.dyn_cast<IntegerType>(); 92 93 // Sugar i1 constants with 'true' and 'false'. 94 if (intType && intType.getWidth() == 1) 95 return setNameFn(getResult(), (intCst.getInt() ? 
"true" : "false")); 96 97 // Otherwise, build a complex name with the value and type. 98 SmallString<32> specialNameBuffer; 99 llvm::raw_svector_ostream specialName(specialNameBuffer); 100 specialName << 'c' << intCst.getValue(); 101 if (intType) 102 specialName << '_' << type; 103 setNameFn(getResult(), specialName.str()); 104 } else { 105 setNameFn(getResult(), "cst"); 106 } 107 } 108 109 /// TODO: disallow arith.constant to return anything other than signless integer 110 /// or float like. 111 LogicalResult arith::ConstantOp::verify() { 112 auto type = getType(); 113 // The value's type must match the return type. 114 if (getValue().getType() != type) { 115 return emitOpError() << "value type " << getValue().getType() 116 << " must match return type: " << type; 117 } 118 // Integer values must be signless. 119 if (type.isa<IntegerType>() && !type.cast<IntegerType>().isSignless()) 120 return emitOpError("integer return type must be signless"); 121 // Any float or elements attribute are acceptable. 122 if (!getValue().isa<IntegerAttr, FloatAttr, ElementsAttr>()) { 123 return emitOpError( 124 "value must be an integer, float, or elements attribute"); 125 } 126 return success(); 127 } 128 129 bool arith::ConstantOp::isBuildableWith(Attribute value, Type type) { 130 // The value's type must be the same as the provided type. 131 if (value.getType() != type) 132 return false; 133 // Integer values must be signless. 134 if (type.isa<IntegerType>() && !type.cast<IntegerType>().isSignless()) 135 return false; 136 // Integer, float, and element attributes are buildable. 137 return value.isa<IntegerAttr, FloatAttr, ElementsAttr>(); 138 } 139 140 OpFoldResult arith::ConstantOp::fold(ArrayRef<Attribute> operands) { 141 return getValue(); 142 } 143 144 void arith::ConstantIntOp::build(OpBuilder &builder, OperationState &result, 145 int64_t value, unsigned width) { 146 auto type = builder.getIntegerType(width); 147 arith::ConstantOp::build(builder, result, type, 148 builder.getIntegerAttr(type, value)); 149 } 150 151 void arith::ConstantIntOp::build(OpBuilder &builder, OperationState &result, 152 int64_t value, Type type) { 153 assert(type.isSignlessInteger() && 154 "ConstantIntOp can only have signless integer type values"); 155 arith::ConstantOp::build(builder, result, type, 156 builder.getIntegerAttr(type, value)); 157 } 158 159 bool arith::ConstantIntOp::classof(Operation *op) { 160 if (auto constOp = dyn_cast_or_null<arith::ConstantOp>(op)) 161 return constOp.getType().isSignlessInteger(); 162 return false; 163 } 164 165 void arith::ConstantFloatOp::build(OpBuilder &builder, OperationState &result, 166 const APFloat &value, FloatType type) { 167 arith::ConstantOp::build(builder, result, type, 168 builder.getFloatAttr(type, value)); 169 } 170 171 bool arith::ConstantFloatOp::classof(Operation *op) { 172 if (auto constOp = dyn_cast_or_null<arith::ConstantOp>(op)) 173 return constOp.getType().isa<FloatType>(); 174 return false; 175 } 176 177 void arith::ConstantIndexOp::build(OpBuilder &builder, OperationState &result, 178 int64_t value) { 179 arith::ConstantOp::build(builder, result, builder.getIndexType(), 180 builder.getIndexAttr(value)); 181 } 182 183 bool arith::ConstantIndexOp::classof(Operation *op) { 184 if (auto constOp = dyn_cast_or_null<arith::ConstantOp>(op)) 185 return constOp.getType().isIndex(); 186 return false; 187 } 188 189 //===----------------------------------------------------------------------===// 190 // AddIOp 191 
//===----------------------------------------------------------------------===// 192 193 OpFoldResult arith::AddIOp::fold(ArrayRef<Attribute> operands) { 194 // addi(x, 0) -> x 195 if (matchPattern(getRhs(), m_Zero())) 196 return getLhs(); 197 198 // addi(subi(a, b), b) -> a 199 if (auto sub = getLhs().getDefiningOp<SubIOp>()) 200 if (getRhs() == sub.getRhs()) 201 return sub.getLhs(); 202 203 // addi(b, subi(a, b)) -> a 204 if (auto sub = getRhs().getDefiningOp<SubIOp>()) 205 if (getLhs() == sub.getRhs()) 206 return sub.getLhs(); 207 208 return constFoldBinaryOp<IntegerAttr>( 209 operands, [](APInt a, const APInt &b) { return std::move(a) + b; }); 210 } 211 212 void arith::AddIOp::getCanonicalizationPatterns(RewritePatternSet &patterns, 213 MLIRContext *context) { 214 patterns.add<AddIAddConstant, AddISubConstantRHS, AddISubConstantLHS>( 215 context); 216 } 217 218 //===----------------------------------------------------------------------===// 219 // SubIOp 220 //===----------------------------------------------------------------------===// 221 222 OpFoldResult arith::SubIOp::fold(ArrayRef<Attribute> operands) { 223 // subi(x,x) -> 0 224 if (getOperand(0) == getOperand(1)) 225 return Builder(getContext()).getZeroAttr(getType()); 226 // subi(x,0) -> x 227 if (matchPattern(getRhs(), m_Zero())) 228 return getLhs(); 229 230 return constFoldBinaryOp<IntegerAttr>( 231 operands, [](APInt a, const APInt &b) { return std::move(a) - b; }); 232 } 233 234 void arith::SubIOp::getCanonicalizationPatterns(RewritePatternSet &patterns, 235 MLIRContext *context) { 236 patterns 237 .add<SubIRHSAddConstant, SubILHSAddConstant, SubIRHSSubConstantRHS, 238 SubIRHSSubConstantLHS, SubILHSSubConstantRHS, SubILHSSubConstantLHS>( 239 context); 240 } 241 242 //===----------------------------------------------------------------------===// 243 // MulIOp 244 //===----------------------------------------------------------------------===// 245 246 OpFoldResult arith::MulIOp::fold(ArrayRef<Attribute> operands) { 247 // muli(x, 0) -> 0 248 if (matchPattern(getRhs(), m_Zero())) 249 return getRhs(); 250 // muli(x, 1) -> x 251 if (matchPattern(getRhs(), m_One())) 252 return getOperand(0); 253 // TODO: Handle the overflow case. 254 255 // default folder 256 return constFoldBinaryOp<IntegerAttr>( 257 operands, [](const APInt &a, const APInt &b) { return a * b; }); 258 } 259 260 //===----------------------------------------------------------------------===// 261 // DivUIOp 262 //===----------------------------------------------------------------------===// 263 264 OpFoldResult arith::DivUIOp::fold(ArrayRef<Attribute> operands) { 265 // divui (x, 1) -> x. 266 if (matchPattern(getRhs(), m_One())) 267 return getLhs(); 268 269 // Don't fold if it would require a division by zero. 270 bool div0 = false; 271 auto result = 272 constFoldBinaryOp<IntegerAttr>(operands, [&](APInt a, const APInt &b) { 273 if (div0 || !b) { 274 div0 = true; 275 return a; 276 } 277 return a.udiv(b); 278 }); 279 280 return div0 ? Attribute() : result; 281 } 282 283 //===----------------------------------------------------------------------===// 284 // DivSIOp 285 //===----------------------------------------------------------------------===// 286 287 OpFoldResult arith::DivSIOp::fold(ArrayRef<Attribute> operands) { 288 // divsi (x, 1) -> x. 289 if (matchPattern(getRhs(), m_One())) 290 return getLhs(); 291 292 // Don't fold if it would overflow or if it requires a division by zero. 
293 bool overflowOrDiv0 = false; 294 auto result = 295 constFoldBinaryOp<IntegerAttr>(operands, [&](APInt a, const APInt &b) { 296 if (overflowOrDiv0 || !b) { 297 overflowOrDiv0 = true; 298 return a; 299 } 300 return a.sdiv_ov(b, overflowOrDiv0); 301 }); 302 303 return overflowOrDiv0 ? Attribute() : result; 304 } 305 306 //===----------------------------------------------------------------------===// 307 // Ceil and floor division folding helpers 308 //===----------------------------------------------------------------------===// 309 310 static APInt signedCeilNonnegInputs(const APInt &a, const APInt &b, 311 bool &overflow) { 312 // Returns (a-1)/b + 1 313 APInt one(a.getBitWidth(), 1, true); // Signed value 1. 314 APInt val = a.ssub_ov(one, overflow).sdiv_ov(b, overflow); 315 return val.sadd_ov(one, overflow); 316 } 317 318 //===----------------------------------------------------------------------===// 319 // CeilDivUIOp 320 //===----------------------------------------------------------------------===// 321 322 OpFoldResult arith::CeilDivUIOp::fold(ArrayRef<Attribute> operands) { 323 // ceildivui (x, 1) -> x. 324 if (matchPattern(getRhs(), m_One())) 325 return getLhs(); 326 327 bool overflowOrDiv0 = false; 328 auto result = 329 constFoldBinaryOp<IntegerAttr>(operands, [&](APInt a, const APInt &b) { 330 if (overflowOrDiv0 || !b) { 331 overflowOrDiv0 = true; 332 return a; 333 } 334 APInt quotient = a.udiv(b); 335 if (!a.urem(b)) 336 return quotient; 337 APInt one(a.getBitWidth(), 1, true); 338 return quotient.uadd_ov(one, overflowOrDiv0); 339 }); 340 341 return overflowOrDiv0 ? Attribute() : result; 342 } 343 344 //===----------------------------------------------------------------------===// 345 // CeilDivSIOp 346 //===----------------------------------------------------------------------===// 347 348 OpFoldResult arith::CeilDivSIOp::fold(ArrayRef<Attribute> operands) { 349 // ceildivsi (x, 1) -> x. 350 if (matchPattern(getRhs(), m_One())) 351 return getLhs(); 352 353 // Don't fold if it would overflow or if it requires a division by zero. 354 bool overflowOrDiv0 = false; 355 auto result = 356 constFoldBinaryOp<IntegerAttr>(operands, [&](APInt a, const APInt &b) { 357 if (overflowOrDiv0 || !b) { 358 overflowOrDiv0 = true; 359 return a; 360 } 361 if (!a) 362 return a; 363 // After this point we know that neither a or b are zero. 364 unsigned bits = a.getBitWidth(); 365 APInt zero = APInt::getZero(bits); 366 bool aGtZero = a.sgt(zero); 367 bool bGtZero = b.sgt(zero); 368 if (aGtZero && bGtZero) { 369 // Both positive, return ceil(a, b). 370 return signedCeilNonnegInputs(a, b, overflowOrDiv0); 371 } 372 if (!aGtZero && !bGtZero) { 373 // Both negative, return ceil(-a, -b). 374 APInt posA = zero.ssub_ov(a, overflowOrDiv0); 375 APInt posB = zero.ssub_ov(b, overflowOrDiv0); 376 return signedCeilNonnegInputs(posA, posB, overflowOrDiv0); 377 } 378 if (!aGtZero && bGtZero) { 379 // A is negative, b is positive, return - ( -a / b). 380 APInt posA = zero.ssub_ov(a, overflowOrDiv0); 381 APInt div = posA.sdiv_ov(b, overflowOrDiv0); 382 return zero.ssub_ov(div, overflowOrDiv0); 383 } 384 // A is positive, b is negative, return - (a / -b). 385 APInt posB = zero.ssub_ov(b, overflowOrDiv0); 386 APInt div = a.sdiv_ov(posB, overflowOrDiv0); 387 return zero.ssub_ov(div, overflowOrDiv0); 388 }); 389 390 return overflowOrDiv0 ? 
Attribute() : result; 391 } 392 393 //===----------------------------------------------------------------------===// 394 // FloorDivSIOp 395 //===----------------------------------------------------------------------===// 396 397 OpFoldResult arith::FloorDivSIOp::fold(ArrayRef<Attribute> operands) { 398 // floordivsi (x, 1) -> x. 399 if (matchPattern(getRhs(), m_One())) 400 return getLhs(); 401 402 // Don't fold if it would overflow or if it requires a division by zero. 403 bool overflowOrDiv0 = false; 404 auto result = 405 constFoldBinaryOp<IntegerAttr>(operands, [&](APInt a, const APInt &b) { 406 if (overflowOrDiv0 || !b) { 407 overflowOrDiv0 = true; 408 return a; 409 } 410 if (!a) 411 return a; 412 // After this point we know that neither a or b are zero. 413 unsigned bits = a.getBitWidth(); 414 APInt zero = APInt::getZero(bits); 415 bool aGtZero = a.sgt(zero); 416 bool bGtZero = b.sgt(zero); 417 if (aGtZero && bGtZero) { 418 // Both positive, return a / b. 419 return a.sdiv_ov(b, overflowOrDiv0); 420 } 421 if (!aGtZero && !bGtZero) { 422 // Both negative, return -a / -b. 423 APInt posA = zero.ssub_ov(a, overflowOrDiv0); 424 APInt posB = zero.ssub_ov(b, overflowOrDiv0); 425 return posA.sdiv_ov(posB, overflowOrDiv0); 426 } 427 if (!aGtZero && bGtZero) { 428 // A is negative, b is positive, return - ceil(-a, b). 429 APInt posA = zero.ssub_ov(a, overflowOrDiv0); 430 APInt ceil = signedCeilNonnegInputs(posA, b, overflowOrDiv0); 431 return zero.ssub_ov(ceil, overflowOrDiv0); 432 } 433 // A is positive, b is negative, return - ceil(a, -b). 434 APInt posB = zero.ssub_ov(b, overflowOrDiv0); 435 APInt ceil = signedCeilNonnegInputs(a, posB, overflowOrDiv0); 436 return zero.ssub_ov(ceil, overflowOrDiv0); 437 }); 438 439 return overflowOrDiv0 ? Attribute() : result; 440 } 441 442 //===----------------------------------------------------------------------===// 443 // RemUIOp 444 //===----------------------------------------------------------------------===// 445 446 OpFoldResult arith::RemUIOp::fold(ArrayRef<Attribute> operands) { 447 // remui (x, 1) -> 0. 448 if (matchPattern(getRhs(), m_One())) 449 return Builder(getContext()).getZeroAttr(getType()); 450 451 // Don't fold if it would require a division by zero. 452 bool div0 = false; 453 auto result = 454 constFoldBinaryOp<IntegerAttr>(operands, [&](APInt a, const APInt &b) { 455 if (div0 || b.isNullValue()) { 456 div0 = true; 457 return a; 458 } 459 return a.urem(b); 460 }); 461 462 return div0 ? Attribute() : result; 463 } 464 465 //===----------------------------------------------------------------------===// 466 // RemSIOp 467 //===----------------------------------------------------------------------===// 468 469 OpFoldResult arith::RemSIOp::fold(ArrayRef<Attribute> operands) { 470 // remsi (x, 1) -> 0. 471 if (matchPattern(getRhs(), m_One())) 472 return Builder(getContext()).getZeroAttr(getType()); 473 474 // Don't fold if it would require a division by zero. 475 bool div0 = false; 476 auto result = 477 constFoldBinaryOp<IntegerAttr>(operands, [&](APInt a, const APInt &b) { 478 if (div0 || b.isNullValue()) { 479 div0 = true; 480 return a; 481 } 482 return a.srem(b); 483 }); 484 485 return div0 ? 
Attribute() : result; 486 } 487 488 //===----------------------------------------------------------------------===// 489 // AndIOp 490 //===----------------------------------------------------------------------===// 491 492 OpFoldResult arith::AndIOp::fold(ArrayRef<Attribute> operands) { 493 /// and(x, 0) -> 0 494 if (matchPattern(getRhs(), m_Zero())) 495 return getRhs(); 496 /// and(x, allOnes) -> x 497 APInt intValue; 498 if (matchPattern(getRhs(), m_ConstantInt(&intValue)) && intValue.isAllOnes()) 499 return getLhs(); 500 501 return constFoldBinaryOp<IntegerAttr>( 502 operands, [](APInt a, const APInt &b) { return std::move(a) & b; }); 503 } 504 505 //===----------------------------------------------------------------------===// 506 // OrIOp 507 //===----------------------------------------------------------------------===// 508 509 OpFoldResult arith::OrIOp::fold(ArrayRef<Attribute> operands) { 510 /// or(x, 0) -> x 511 if (matchPattern(getRhs(), m_Zero())) 512 return getLhs(); 513 /// or(x, <all ones>) -> <all ones> 514 if (auto rhsAttr = operands[1].dyn_cast_or_null<IntegerAttr>()) 515 if (rhsAttr.getValue().isAllOnes()) 516 return rhsAttr; 517 518 return constFoldBinaryOp<IntegerAttr>( 519 operands, [](APInt a, const APInt &b) { return std::move(a) | b; }); 520 } 521 522 //===----------------------------------------------------------------------===// 523 // XOrIOp 524 //===----------------------------------------------------------------------===// 525 526 OpFoldResult arith::XOrIOp::fold(ArrayRef<Attribute> operands) { 527 /// xor(x, 0) -> x 528 if (matchPattern(getRhs(), m_Zero())) 529 return getLhs(); 530 /// xor(x, x) -> 0 531 if (getLhs() == getRhs()) 532 return Builder(getContext()).getZeroAttr(getType()); 533 /// xor(xor(x, a), a) -> x 534 if (arith::XOrIOp prev = getLhs().getDefiningOp<arith::XOrIOp>()) 535 if (prev.getRhs() == getRhs()) 536 return prev.getLhs(); 537 538 return constFoldBinaryOp<IntegerAttr>( 539 operands, [](APInt a, const APInt &b) { return std::move(a) ^ b; }); 540 } 541 542 void arith::XOrIOp::getCanonicalizationPatterns(RewritePatternSet &patterns, 543 MLIRContext *context) { 544 patterns.add<XOrINotCmpI>(context); 545 } 546 547 //===----------------------------------------------------------------------===// 548 // NegFOp 549 //===----------------------------------------------------------------------===// 550 551 OpFoldResult arith::NegFOp::fold(ArrayRef<Attribute> operands) { 552 /// negf(negf(x)) -> x 553 if (auto op = this->getOperand().getDefiningOp<arith::NegFOp>()) 554 return op.getOperand(); 555 return constFoldUnaryOp<FloatAttr>(operands, 556 [](const APFloat &a) { return -a; }); 557 } 558 559 //===----------------------------------------------------------------------===// 560 // AddFOp 561 //===----------------------------------------------------------------------===// 562 563 OpFoldResult arith::AddFOp::fold(ArrayRef<Attribute> operands) { 564 // addf(x, -0) -> x 565 if (matchPattern(getRhs(), m_NegZeroFloat())) 566 return getLhs(); 567 568 return constFoldBinaryOp<FloatAttr>( 569 operands, [](const APFloat &a, const APFloat &b) { return a + b; }); 570 } 571 572 //===----------------------------------------------------------------------===// 573 // SubFOp 574 //===----------------------------------------------------------------------===// 575 576 OpFoldResult arith::SubFOp::fold(ArrayRef<Attribute> operands) { 577 // subf(x, +0) -> x 578 if (matchPattern(getRhs(), m_PosZeroFloat())) 579 return getLhs(); 580 581 return 
constFoldBinaryOp<FloatAttr>( 582 operands, [](const APFloat &a, const APFloat &b) { return a - b; }); 583 } 584 585 //===----------------------------------------------------------------------===// 586 // MaxFOp 587 //===----------------------------------------------------------------------===// 588 589 OpFoldResult arith::MaxFOp::fold(ArrayRef<Attribute> operands) { 590 assert(operands.size() == 2 && "maxf takes two operands"); 591 592 // maxf(x,x) -> x 593 if (getLhs() == getRhs()) 594 return getRhs(); 595 596 // maxf(x, -inf) -> x 597 if (matchPattern(getRhs(), m_NegInfFloat())) 598 return getLhs(); 599 600 return constFoldBinaryOp<FloatAttr>( 601 operands, 602 [](const APFloat &a, const APFloat &b) { return llvm::maximum(a, b); }); 603 } 604 605 //===----------------------------------------------------------------------===// 606 // MaxSIOp 607 //===----------------------------------------------------------------------===// 608 609 OpFoldResult MaxSIOp::fold(ArrayRef<Attribute> operands) { 610 assert(operands.size() == 2 && "binary operation takes two operands"); 611 612 // maxsi(x,x) -> x 613 if (getLhs() == getRhs()) 614 return getRhs(); 615 616 APInt intValue; 617 // maxsi(x,MAX_INT) -> MAX_INT 618 if (matchPattern(getRhs(), m_ConstantInt(&intValue)) && 619 intValue.isMaxSignedValue()) 620 return getRhs(); 621 622 // maxsi(x, MIN_INT) -> x 623 if (matchPattern(getRhs(), m_ConstantInt(&intValue)) && 624 intValue.isMinSignedValue()) 625 return getLhs(); 626 627 return constFoldBinaryOp<IntegerAttr>(operands, 628 [](const APInt &a, const APInt &b) { 629 return llvm::APIntOps::smax(a, b); 630 }); 631 } 632 633 //===----------------------------------------------------------------------===// 634 // MaxUIOp 635 //===----------------------------------------------------------------------===// 636 637 OpFoldResult MaxUIOp::fold(ArrayRef<Attribute> operands) { 638 assert(operands.size() == 2 && "binary operation takes two operands"); 639 640 // maxui(x,x) -> x 641 if (getLhs() == getRhs()) 642 return getRhs(); 643 644 APInt intValue; 645 // maxui(x,MAX_INT) -> MAX_INT 646 if (matchPattern(getRhs(), m_ConstantInt(&intValue)) && intValue.isMaxValue()) 647 return getRhs(); 648 649 // maxui(x, MIN_INT) -> x 650 if (matchPattern(getRhs(), m_ConstantInt(&intValue)) && intValue.isMinValue()) 651 return getLhs(); 652 653 return constFoldBinaryOp<IntegerAttr>(operands, 654 [](const APInt &a, const APInt &b) { 655 return llvm::APIntOps::umax(a, b); 656 }); 657 } 658 659 //===----------------------------------------------------------------------===// 660 // MinFOp 661 //===----------------------------------------------------------------------===// 662 663 OpFoldResult arith::MinFOp::fold(ArrayRef<Attribute> operands) { 664 assert(operands.size() == 2 && "minf takes two operands"); 665 666 // minf(x,x) -> x 667 if (getLhs() == getRhs()) 668 return getRhs(); 669 670 // minf(x, +inf) -> x 671 if (matchPattern(getRhs(), m_PosInfFloat())) 672 return getLhs(); 673 674 return constFoldBinaryOp<FloatAttr>( 675 operands, 676 [](const APFloat &a, const APFloat &b) { return llvm::minimum(a, b); }); 677 } 678 679 //===----------------------------------------------------------------------===// 680 // MinSIOp 681 //===----------------------------------------------------------------------===// 682 683 OpFoldResult MinSIOp::fold(ArrayRef<Attribute> operands) { 684 assert(operands.size() == 2 && "binary operation takes two operands"); 685 686 // minsi(x,x) -> x 687 if (getLhs() == getRhs()) 688 return getRhs(); 689 690 
APInt intValue; 691 // minsi(x,MIN_INT) -> MIN_INT 692 if (matchPattern(getRhs(), m_ConstantInt(&intValue)) && 693 intValue.isMinSignedValue()) 694 return getRhs(); 695 696 // minsi(x, MAX_INT) -> x 697 if (matchPattern(getRhs(), m_ConstantInt(&intValue)) && 698 intValue.isMaxSignedValue()) 699 return getLhs(); 700 701 return constFoldBinaryOp<IntegerAttr>(operands, 702 [](const APInt &a, const APInt &b) { 703 return llvm::APIntOps::smin(a, b); 704 }); 705 } 706 707 //===----------------------------------------------------------------------===// 708 // MinUIOp 709 //===----------------------------------------------------------------------===// 710 711 OpFoldResult MinUIOp::fold(ArrayRef<Attribute> operands) { 712 assert(operands.size() == 2 && "binary operation takes two operands"); 713 714 // minui(x,x) -> x 715 if (getLhs() == getRhs()) 716 return getRhs(); 717 718 APInt intValue; 719 // minui(x,MIN_INT) -> MIN_INT 720 if (matchPattern(getRhs(), m_ConstantInt(&intValue)) && intValue.isMinValue()) 721 return getRhs(); 722 723 // minui(x, MAX_INT) -> x 724 if (matchPattern(getRhs(), m_ConstantInt(&intValue)) && intValue.isMaxValue()) 725 return getLhs(); 726 727 return constFoldBinaryOp<IntegerAttr>(operands, 728 [](const APInt &a, const APInt &b) { 729 return llvm::APIntOps::umin(a, b); 730 }); 731 } 732 733 //===----------------------------------------------------------------------===// 734 // MulFOp 735 //===----------------------------------------------------------------------===// 736 737 OpFoldResult arith::MulFOp::fold(ArrayRef<Attribute> operands) { 738 // mulf(x, 1) -> x 739 if (matchPattern(getRhs(), m_OneFloat())) 740 return getLhs(); 741 742 return constFoldBinaryOp<FloatAttr>( 743 operands, [](const APFloat &a, const APFloat &b) { return a * b; }); 744 } 745 746 void arith::MulFOp::getCanonicalizationPatterns(RewritePatternSet &patterns, 747 MLIRContext *context) { 748 patterns.add<MulFOfNegF>(context); 749 } 750 751 //===----------------------------------------------------------------------===// 752 // DivFOp 753 //===----------------------------------------------------------------------===// 754 755 OpFoldResult arith::DivFOp::fold(ArrayRef<Attribute> operands) { 756 // divf(x, 1) -> x 757 if (matchPattern(getRhs(), m_OneFloat())) 758 return getLhs(); 759 760 return constFoldBinaryOp<FloatAttr>( 761 operands, [](const APFloat &a, const APFloat &b) { return a / b; }); 762 } 763 764 void arith::DivFOp::getCanonicalizationPatterns(RewritePatternSet &patterns, 765 MLIRContext *context) { 766 patterns.add<DivFOfNegF>(context); 767 } 768 769 //===----------------------------------------------------------------------===// 770 // RemFOp 771 //===----------------------------------------------------------------------===// 772 773 OpFoldResult arith::RemFOp::fold(ArrayRef<Attribute> operands) { 774 return constFoldBinaryOp<FloatAttr>(operands, 775 [](const APFloat &a, const APFloat &b) { 776 APFloat result(a); 777 (void)result.remainder(b); 778 return result; 779 }); 780 } 781 782 //===----------------------------------------------------------------------===// 783 // Utility functions for verifying cast ops 784 //===----------------------------------------------------------------------===// 785 786 template <typename... Types> 787 using type_list = std::tuple<Types...> *; 788 789 /// Returns a non-null type only if the provided type is one of the allowed 790 /// types or one of the allowed shaped types of the allowed types. 
Returns the 791 /// element type if a valid shaped type is provided. 792 template <typename... ShapedTypes, typename... ElementTypes> 793 static Type getUnderlyingType(Type type, type_list<ShapedTypes...>, 794 type_list<ElementTypes...>) { 795 if (type.isa<ShapedType>() && !type.isa<ShapedTypes...>()) 796 return {}; 797 798 auto underlyingType = getElementTypeOrSelf(type); 799 if (!underlyingType.isa<ElementTypes...>()) 800 return {}; 801 802 return underlyingType; 803 } 804 805 /// Get allowed underlying types for vectors and tensors. 806 template <typename... ElementTypes> 807 static Type getTypeIfLike(Type type) { 808 return getUnderlyingType(type, type_list<VectorType, TensorType>(), 809 type_list<ElementTypes...>()); 810 } 811 812 /// Get allowed underlying types for vectors, tensors, and memrefs. 813 template <typename... ElementTypes> 814 static Type getTypeIfLikeOrMemRef(Type type) { 815 return getUnderlyingType(type, 816 type_list<VectorType, TensorType, MemRefType>(), 817 type_list<ElementTypes...>()); 818 } 819 820 static bool areValidCastInputsAndOutputs(TypeRange inputs, TypeRange outputs) { 821 return inputs.size() == 1 && outputs.size() == 1 && 822 succeeded(verifyCompatibleShapes(inputs.front(), outputs.front())); 823 } 824 825 //===----------------------------------------------------------------------===// 826 // Verifiers for integer and floating point extension/truncation ops 827 //===----------------------------------------------------------------------===// 828 829 // Extend ops can only extend to a wider type. 830 template <typename ValType, typename Op> 831 static LogicalResult verifyExtOp(Op op) { 832 Type srcType = getElementTypeOrSelf(op.getIn().getType()); 833 Type dstType = getElementTypeOrSelf(op.getType()); 834 835 if (srcType.cast<ValType>().getWidth() >= dstType.cast<ValType>().getWidth()) 836 return op.emitError("result type ") 837 << dstType << " must be wider than operand type " << srcType; 838 839 return success(); 840 } 841 842 // Truncate ops can only truncate to a shorter type. 843 template <typename ValType, typename Op> 844 static LogicalResult verifyTruncateOp(Op op) { 845 Type srcType = getElementTypeOrSelf(op.getIn().getType()); 846 Type dstType = getElementTypeOrSelf(op.getType()); 847 848 if (srcType.cast<ValType>().getWidth() <= dstType.cast<ValType>().getWidth()) 849 return op.emitError("result type ") 850 << dstType << " must be shorter than operand type " << srcType; 851 852 return success(); 853 } 854 855 /// Validate a cast that changes the width of a type. 856 template <template <typename> class WidthComparator, typename... 
ElementTypes> 857 static bool checkWidthChangeCast(TypeRange inputs, TypeRange outputs) { 858 if (!areValidCastInputsAndOutputs(inputs, outputs)) 859 return false; 860 861 auto srcType = getTypeIfLike<ElementTypes...>(inputs.front()); 862 auto dstType = getTypeIfLike<ElementTypes...>(outputs.front()); 863 if (!srcType || !dstType) 864 return false; 865 866 return WidthComparator<unsigned>()(dstType.getIntOrFloatBitWidth(), 867 srcType.getIntOrFloatBitWidth()); 868 } 869 870 //===----------------------------------------------------------------------===// 871 // ExtUIOp 872 //===----------------------------------------------------------------------===// 873 874 OpFoldResult arith::ExtUIOp::fold(ArrayRef<Attribute> operands) { 875 if (auto lhs = getIn().getDefiningOp<ExtUIOp>()) { 876 getInMutable().assign(lhs.getIn()); 877 return getResult(); 878 } 879 Type resType = getType(); 880 unsigned bitWidth; 881 if (auto shapedType = resType.dyn_cast<ShapedType>()) 882 bitWidth = shapedType.getElementTypeBitWidth(); 883 else 884 bitWidth = resType.getIntOrFloatBitWidth(); 885 return constFoldCastOp<IntegerAttr, IntegerAttr>( 886 operands, getType(), [bitWidth](const APInt &a, bool &castStatus) { 887 return a.zext(bitWidth); 888 }); 889 } 890 891 bool arith::ExtUIOp::areCastCompatible(TypeRange inputs, TypeRange outputs) { 892 return checkWidthChangeCast<std::greater, IntegerType>(inputs, outputs); 893 } 894 895 LogicalResult arith::ExtUIOp::verify() { 896 return verifyExtOp<IntegerType>(*this); 897 } 898 899 //===----------------------------------------------------------------------===// 900 // ExtSIOp 901 //===----------------------------------------------------------------------===// 902 903 OpFoldResult arith::ExtSIOp::fold(ArrayRef<Attribute> operands) { 904 if (auto lhs = getIn().getDefiningOp<ExtSIOp>()) { 905 getInMutable().assign(lhs.getIn()); 906 return getResult(); 907 } 908 Type resType = getType(); 909 unsigned bitWidth; 910 if (auto shapedType = resType.dyn_cast<ShapedType>()) 911 bitWidth = shapedType.getElementTypeBitWidth(); 912 else 913 bitWidth = resType.getIntOrFloatBitWidth(); 914 return constFoldCastOp<IntegerAttr, IntegerAttr>( 915 operands, getType(), [bitWidth](const APInt &a, bool &castStatus) { 916 return a.sext(bitWidth); 917 }); 918 } 919 920 bool arith::ExtSIOp::areCastCompatible(TypeRange inputs, TypeRange outputs) { 921 return checkWidthChangeCast<std::greater, IntegerType>(inputs, outputs); 922 } 923 924 void arith::ExtSIOp::getCanonicalizationPatterns(RewritePatternSet &patterns, 925 MLIRContext *context) { 926 patterns.add<ExtSIOfExtUI>(context); 927 } 928 929 LogicalResult arith::ExtSIOp::verify() { 930 return verifyExtOp<IntegerType>(*this); 931 } 932 933 //===----------------------------------------------------------------------===// 934 // ExtFOp 935 //===----------------------------------------------------------------------===// 936 937 bool arith::ExtFOp::areCastCompatible(TypeRange inputs, TypeRange outputs) { 938 return checkWidthChangeCast<std::greater, FloatType>(inputs, outputs); 939 } 940 941 LogicalResult arith::ExtFOp::verify() { return verifyExtOp<FloatType>(*this); } 942 943 //===----------------------------------------------------------------------===// 944 // TruncIOp 945 //===----------------------------------------------------------------------===// 946 947 OpFoldResult arith::TruncIOp::fold(ArrayRef<Attribute> operands) { 948 assert(operands.size() == 1 && "unary operation takes one operand"); 949 950 // trunci(zexti(a)) -> a 951 // 
trunci(sexti(a)) -> a 952 if (matchPattern(getOperand(), m_Op<arith::ExtUIOp>()) || 953 matchPattern(getOperand(), m_Op<arith::ExtSIOp>())) 954 return getOperand().getDefiningOp()->getOperand(0); 955 956 // trunci(trunci(a)) -> trunci(a)) 957 if (matchPattern(getOperand(), m_Op<arith::TruncIOp>())) { 958 setOperand(getOperand().getDefiningOp()->getOperand(0)); 959 return getResult(); 960 } 961 962 Type resType = getType(); 963 unsigned bitWidth; 964 if (auto shapedType = resType.dyn_cast<ShapedType>()) 965 bitWidth = shapedType.getElementTypeBitWidth(); 966 else 967 bitWidth = resType.getIntOrFloatBitWidth(); 968 969 return constFoldCastOp<IntegerAttr, IntegerAttr>( 970 operands, getType(), [bitWidth](const APInt &a, bool &castStatus) { 971 return a.trunc(bitWidth); 972 }); 973 } 974 975 bool arith::TruncIOp::areCastCompatible(TypeRange inputs, TypeRange outputs) { 976 return checkWidthChangeCast<std::less, IntegerType>(inputs, outputs); 977 } 978 979 LogicalResult arith::TruncIOp::verify() { 980 return verifyTruncateOp<IntegerType>(*this); 981 } 982 983 //===----------------------------------------------------------------------===// 984 // TruncFOp 985 //===----------------------------------------------------------------------===// 986 987 /// Perform safe const propagation for truncf, i.e. only propagate if FP value 988 /// can be represented without precision loss or rounding. 989 OpFoldResult arith::TruncFOp::fold(ArrayRef<Attribute> operands) { 990 assert(operands.size() == 1 && "unary operation takes one operand"); 991 992 auto constOperand = operands.front(); 993 if (!constOperand || !constOperand.isa<FloatAttr>()) 994 return {}; 995 996 // Convert to target type via 'double'. 997 double sourceValue = 998 constOperand.dyn_cast<FloatAttr>().getValue().convertToDouble(); 999 auto targetAttr = FloatAttr::get(getType(), sourceValue); 1000 1001 // Propagate if constant's value does not change after truncation. 1002 if (sourceValue == targetAttr.getValue().convertToDouble()) 1003 return targetAttr; 1004 1005 return {}; 1006 } 1007 1008 bool arith::TruncFOp::areCastCompatible(TypeRange inputs, TypeRange outputs) { 1009 return checkWidthChangeCast<std::less, FloatType>(inputs, outputs); 1010 } 1011 1012 LogicalResult arith::TruncFOp::verify() { 1013 return verifyTruncateOp<FloatType>(*this); 1014 } 1015 1016 //===----------------------------------------------------------------------===// 1017 // AndIOp 1018 //===----------------------------------------------------------------------===// 1019 1020 void arith::AndIOp::getCanonicalizationPatterns(RewritePatternSet &patterns, 1021 MLIRContext *context) { 1022 patterns.add<AndOfExtUI, AndOfExtSI>(context); 1023 } 1024 1025 //===----------------------------------------------------------------------===// 1026 // OrIOp 1027 //===----------------------------------------------------------------------===// 1028 1029 void arith::OrIOp::getCanonicalizationPatterns(RewritePatternSet &patterns, 1030 MLIRContext *context) { 1031 patterns.add<OrOfExtUI, OrOfExtSI>(context); 1032 } 1033 1034 //===----------------------------------------------------------------------===// 1035 // Verifiers for casts between integers and floats. 
1036 //===----------------------------------------------------------------------===// 1037 1038 template <typename From, typename To> 1039 static bool checkIntFloatCast(TypeRange inputs, TypeRange outputs) { 1040 if (!areValidCastInputsAndOutputs(inputs, outputs)) 1041 return false; 1042 1043 auto srcType = getTypeIfLike<From>(inputs.front()); 1044 auto dstType = getTypeIfLike<To>(outputs.back()); 1045 1046 return srcType && dstType; 1047 } 1048 1049 //===----------------------------------------------------------------------===// 1050 // UIToFPOp 1051 //===----------------------------------------------------------------------===// 1052 1053 bool arith::UIToFPOp::areCastCompatible(TypeRange inputs, TypeRange outputs) { 1054 return checkIntFloatCast<IntegerType, FloatType>(inputs, outputs); 1055 } 1056 1057 OpFoldResult arith::UIToFPOp::fold(ArrayRef<Attribute> operands) { 1058 Type resType = getType(); 1059 Type resEleType; 1060 if (auto shapedType = resType.dyn_cast<ShapedType>()) 1061 resEleType = shapedType.getElementType(); 1062 else 1063 resEleType = resType; 1064 return constFoldCastOp<IntegerAttr, FloatAttr>( 1065 operands, getType(), [&resEleType](const APInt &a, bool &castStatus) { 1066 FloatType floatTy = resEleType.cast<FloatType>(); 1067 APFloat apf(floatTy.getFloatSemantics(), 1068 APInt::getZero(floatTy.getWidth())); 1069 apf.convertFromAPInt(a, /*IsSigned=*/false, 1070 APFloat::rmNearestTiesToEven); 1071 return apf; 1072 }); 1073 } 1074 1075 //===----------------------------------------------------------------------===// 1076 // SIToFPOp 1077 //===----------------------------------------------------------------------===// 1078 1079 bool arith::SIToFPOp::areCastCompatible(TypeRange inputs, TypeRange outputs) { 1080 return checkIntFloatCast<IntegerType, FloatType>(inputs, outputs); 1081 } 1082 1083 OpFoldResult arith::SIToFPOp::fold(ArrayRef<Attribute> operands) { 1084 Type resType = getType(); 1085 Type resEleType; 1086 if (auto shapedType = resType.dyn_cast<ShapedType>()) 1087 resEleType = shapedType.getElementType(); 1088 else 1089 resEleType = resType; 1090 return constFoldCastOp<IntegerAttr, FloatAttr>( 1091 operands, getType(), [&resEleType](const APInt &a, bool &castStatus) { 1092 FloatType floatTy = resEleType.cast<FloatType>(); 1093 APFloat apf(floatTy.getFloatSemantics(), 1094 APInt::getZero(floatTy.getWidth())); 1095 apf.convertFromAPInt(a, /*IsSigned=*/true, 1096 APFloat::rmNearestTiesToEven); 1097 return apf; 1098 }); 1099 } 1100 //===----------------------------------------------------------------------===// 1101 // FPToUIOp 1102 //===----------------------------------------------------------------------===// 1103 1104 bool arith::FPToUIOp::areCastCompatible(TypeRange inputs, TypeRange outputs) { 1105 return checkIntFloatCast<FloatType, IntegerType>(inputs, outputs); 1106 } 1107 1108 OpFoldResult arith::FPToUIOp::fold(ArrayRef<Attribute> operands) { 1109 Type resType = getType(); 1110 Type resEleType; 1111 if (auto shapedType = resType.dyn_cast<ShapedType>()) 1112 resEleType = shapedType.getElementType(); 1113 else 1114 resEleType = resType; 1115 return constFoldCastOp<FloatAttr, IntegerAttr>( 1116 operands, getType(), [&resEleType](const APFloat &a, bool &castStatus) { 1117 IntegerType intTy = resEleType.cast<IntegerType>(); 1118 bool ignored; 1119 APSInt api(intTy.getWidth(), /*isUnsigned=*/true); 1120 castStatus = APFloat::opInvalidOp != 1121 a.convertToInteger(api, APFloat::rmTowardZero, &ignored); 1122 return api; 1123 }); 1124 } 1125 1126 
//===----------------------------------------------------------------------===// 1127 // FPToSIOp 1128 //===----------------------------------------------------------------------===// 1129 1130 bool arith::FPToSIOp::areCastCompatible(TypeRange inputs, TypeRange outputs) { 1131 return checkIntFloatCast<FloatType, IntegerType>(inputs, outputs); 1132 } 1133 1134 OpFoldResult arith::FPToSIOp::fold(ArrayRef<Attribute> operands) { 1135 Type resType = getType(); 1136 Type resEleType; 1137 if (auto shapedType = resType.dyn_cast<ShapedType>()) 1138 resEleType = shapedType.getElementType(); 1139 else 1140 resEleType = resType; 1141 return constFoldCastOp<FloatAttr, IntegerAttr>( 1142 operands, getType(), [&resEleType](const APFloat &a, bool &castStatus) { 1143 IntegerType intTy = resEleType.cast<IntegerType>(); 1144 bool ignored; 1145 APSInt api(intTy.getWidth(), /*isUnsigned=*/false); 1146 castStatus = APFloat::opInvalidOp != 1147 a.convertToInteger(api, APFloat::rmTowardZero, &ignored); 1148 return api; 1149 }); 1150 } 1151 1152 //===----------------------------------------------------------------------===// 1153 // IndexCastOp 1154 //===----------------------------------------------------------------------===// 1155 1156 bool arith::IndexCastOp::areCastCompatible(TypeRange inputs, 1157 TypeRange outputs) { 1158 if (!areValidCastInputsAndOutputs(inputs, outputs)) 1159 return false; 1160 1161 auto srcType = getTypeIfLikeOrMemRef<IntegerType, IndexType>(inputs.front()); 1162 auto dstType = getTypeIfLikeOrMemRef<IntegerType, IndexType>(outputs.front()); 1163 if (!srcType || !dstType) 1164 return false; 1165 1166 return (srcType.isIndex() && dstType.isSignlessInteger()) || 1167 (srcType.isSignlessInteger() && dstType.isIndex()); 1168 } 1169 1170 OpFoldResult arith::IndexCastOp::fold(ArrayRef<Attribute> operands) { 1171 // index_cast(constant) -> constant 1172 // A little hack because we go through int. Otherwise, the size of the 1173 // constant might need to change. 1174 if (auto value = operands[0].dyn_cast_or_null<IntegerAttr>()) 1175 return IntegerAttr::get(getType(), value.getInt()); 1176 1177 return {}; 1178 } 1179 1180 void arith::IndexCastOp::getCanonicalizationPatterns( 1181 RewritePatternSet &patterns, MLIRContext *context) { 1182 patterns.add<IndexCastOfIndexCast, IndexCastOfExtSI>(context); 1183 } 1184 1185 //===----------------------------------------------------------------------===// 1186 // BitcastOp 1187 //===----------------------------------------------------------------------===// 1188 1189 bool arith::BitcastOp::areCastCompatible(TypeRange inputs, TypeRange outputs) { 1190 if (!areValidCastInputsAndOutputs(inputs, outputs)) 1191 return false; 1192 1193 auto srcType = 1194 getTypeIfLikeOrMemRef<IntegerType, IndexType, FloatType>(inputs.front()); 1195 auto dstType = 1196 getTypeIfLikeOrMemRef<IntegerType, IndexType, FloatType>(outputs.front()); 1197 if (!srcType || !dstType) 1198 return false; 1199 1200 return srcType.getIntOrFloatBitWidth() == dstType.getIntOrFloatBitWidth(); 1201 } 1202 1203 OpFoldResult arith::BitcastOp::fold(ArrayRef<Attribute> operands) { 1204 assert(operands.size() == 1 && "bitcast op expects 1 operand"); 1205 1206 auto resType = getType(); 1207 auto operand = operands[0]; 1208 if (!operand) 1209 return {}; 1210 1211 /// Bitcast dense elements. 1212 if (auto denseAttr = operand.dyn_cast_or_null<DenseElementsAttr>()) 1213 return denseAttr.bitcast(resType.cast<ShapedType>().getElementType()); 1214 /// Other shaped types unhandled. 
1215 if (resType.isa<ShapedType>()) 1216 return {}; 1217 1218 /// Bitcast integer or float to integer or float. 1219 APInt bits = operand.isa<FloatAttr>() 1220 ? operand.cast<FloatAttr>().getValue().bitcastToAPInt() 1221 : operand.cast<IntegerAttr>().getValue(); 1222 1223 if (auto resFloatType = resType.dyn_cast<FloatType>()) 1224 return FloatAttr::get(resType, 1225 APFloat(resFloatType.getFloatSemantics(), bits)); 1226 return IntegerAttr::get(resType, bits); 1227 } 1228 1229 void arith::BitcastOp::getCanonicalizationPatterns(RewritePatternSet &patterns, 1230 MLIRContext *context) { 1231 patterns.add<BitcastOfBitcast>(context); 1232 } 1233 1234 //===----------------------------------------------------------------------===// 1235 // Helpers for compare ops 1236 //===----------------------------------------------------------------------===// 1237 1238 /// Return the type of the same shape (scalar, vector or tensor) containing i1. 1239 static Type getI1SameShape(Type type) { 1240 auto i1Type = IntegerType::get(type.getContext(), 1); 1241 if (auto tensorType = type.dyn_cast<RankedTensorType>()) 1242 return RankedTensorType::get(tensorType.getShape(), i1Type); 1243 if (type.isa<UnrankedTensorType>()) 1244 return UnrankedTensorType::get(i1Type); 1245 if (auto vectorType = type.dyn_cast<VectorType>()) 1246 return VectorType::get(vectorType.getShape(), i1Type, 1247 vectorType.getNumScalableDims()); 1248 return i1Type; 1249 } 1250 1251 //===----------------------------------------------------------------------===// 1252 // CmpIOp 1253 //===----------------------------------------------------------------------===// 1254 1255 /// Compute `lhs` `pred` `rhs`, where `pred` is one of the known integer 1256 /// comparison predicates. 1257 bool mlir::arith::applyCmpPredicate(arith::CmpIPredicate predicate, 1258 const APInt &lhs, const APInt &rhs) { 1259 switch (predicate) { 1260 case arith::CmpIPredicate::eq: 1261 return lhs.eq(rhs); 1262 case arith::CmpIPredicate::ne: 1263 return lhs.ne(rhs); 1264 case arith::CmpIPredicate::slt: 1265 return lhs.slt(rhs); 1266 case arith::CmpIPredicate::sle: 1267 return lhs.sle(rhs); 1268 case arith::CmpIPredicate::sgt: 1269 return lhs.sgt(rhs); 1270 case arith::CmpIPredicate::sge: 1271 return lhs.sge(rhs); 1272 case arith::CmpIPredicate::ult: 1273 return lhs.ult(rhs); 1274 case arith::CmpIPredicate::ule: 1275 return lhs.ule(rhs); 1276 case arith::CmpIPredicate::ugt: 1277 return lhs.ugt(rhs); 1278 case arith::CmpIPredicate::uge: 1279 return lhs.uge(rhs); 1280 } 1281 llvm_unreachable("unknown cmpi predicate kind"); 1282 } 1283 1284 /// Returns true if the predicate is true for two equal operands. 
1285 static bool applyCmpPredicateToEqualOperands(arith::CmpIPredicate predicate) { 1286 switch (predicate) { 1287 case arith::CmpIPredicate::eq: 1288 case arith::CmpIPredicate::sle: 1289 case arith::CmpIPredicate::sge: 1290 case arith::CmpIPredicate::ule: 1291 case arith::CmpIPredicate::uge: 1292 return true; 1293 case arith::CmpIPredicate::ne: 1294 case arith::CmpIPredicate::slt: 1295 case arith::CmpIPredicate::sgt: 1296 case arith::CmpIPredicate::ult: 1297 case arith::CmpIPredicate::ugt: 1298 return false; 1299 } 1300 llvm_unreachable("unknown cmpi predicate kind"); 1301 } 1302 1303 static Attribute getBoolAttribute(Type type, MLIRContext *ctx, bool value) { 1304 auto boolAttr = BoolAttr::get(ctx, value); 1305 ShapedType shapedType = type.dyn_cast_or_null<ShapedType>(); 1306 if (!shapedType) 1307 return boolAttr; 1308 return DenseElementsAttr::get(shapedType, boolAttr); 1309 } 1310 1311 OpFoldResult arith::CmpIOp::fold(ArrayRef<Attribute> operands) { 1312 assert(operands.size() == 2 && "cmpi takes two operands"); 1313 1314 // cmpi(pred, x, x) 1315 if (getLhs() == getRhs()) { 1316 auto val = applyCmpPredicateToEqualOperands(getPredicate()); 1317 return getBoolAttribute(getType(), getContext(), val); 1318 } 1319 1320 if (matchPattern(getRhs(), m_Zero())) { 1321 if (auto extOp = getLhs().getDefiningOp<ExtSIOp>()) { 1322 // extsi(%x : i1 -> iN) != 0 -> %x 1323 if (extOp.getOperand().getType().cast<IntegerType>().getWidth() == 1 && 1324 getPredicate() == arith::CmpIPredicate::ne) 1325 return extOp.getOperand(); 1326 } 1327 if (auto extOp = getLhs().getDefiningOp<ExtUIOp>()) { 1328 // extui(%x : i1 -> iN) != 0 -> %x 1329 if (extOp.getOperand().getType().cast<IntegerType>().getWidth() == 1 && 1330 getPredicate() == arith::CmpIPredicate::ne) 1331 return extOp.getOperand(); 1332 } 1333 } 1334 1335 // Move constant to the right side. 1336 if (operands[0] && !operands[1]) { 1337 // Do not use invertPredicate, as it will change eq to ne and vice versa. 1338 using Pred = CmpIPredicate; 1339 const std::pair<Pred, Pred> invPreds[] = { 1340 {Pred::slt, Pred::sgt}, {Pred::sgt, Pred::slt}, {Pred::sle, Pred::sge}, 1341 {Pred::sge, Pred::sle}, {Pred::ult, Pred::ugt}, {Pred::ugt, Pred::ult}, 1342 {Pred::ule, Pred::uge}, {Pred::uge, Pred::ule}, {Pred::eq, Pred::eq}, 1343 {Pred::ne, Pred::ne}, 1344 }; 1345 Pred origPred = getPredicate(); 1346 for (auto pred : invPreds) { 1347 if (origPred == pred.first) { 1348 setPredicateAttr(CmpIPredicateAttr::get(getContext(), pred.second)); 1349 Value lhs = getLhs(); 1350 Value rhs = getRhs(); 1351 getLhsMutable().assign(rhs); 1352 getRhsMutable().assign(lhs); 1353 return getResult(); 1354 } 1355 } 1356 llvm_unreachable("unknown cmpi predicate kind"); 1357 } 1358 1359 auto lhs = operands.front().dyn_cast_or_null<IntegerAttr>(); 1360 if (!lhs) 1361 return {}; 1362 1363 // We are moving constants to the right side; So if lhs is constant rhs is 1364 // guaranteed to be a constant. 
1365 auto rhs = operands.back().cast<IntegerAttr>(); 1366 1367 auto val = applyCmpPredicate(getPredicate(), lhs.getValue(), rhs.getValue()); 1368 return BoolAttr::get(getContext(), val); 1369 } 1370 1371 void arith::CmpIOp::getCanonicalizationPatterns(RewritePatternSet &patterns, 1372 MLIRContext *context) { 1373 patterns.insert<CmpIExtSI, CmpIExtUI>(context); 1374 } 1375 1376 //===----------------------------------------------------------------------===// 1377 // CmpFOp 1378 //===----------------------------------------------------------------------===// 1379 1380 /// Compute `lhs` `pred` `rhs`, where `pred` is one of the known floating point 1381 /// comparison predicates. 1382 bool mlir::arith::applyCmpPredicate(arith::CmpFPredicate predicate, 1383 const APFloat &lhs, const APFloat &rhs) { 1384 auto cmpResult = lhs.compare(rhs); 1385 switch (predicate) { 1386 case arith::CmpFPredicate::AlwaysFalse: 1387 return false; 1388 case arith::CmpFPredicate::OEQ: 1389 return cmpResult == APFloat::cmpEqual; 1390 case arith::CmpFPredicate::OGT: 1391 return cmpResult == APFloat::cmpGreaterThan; 1392 case arith::CmpFPredicate::OGE: 1393 return cmpResult == APFloat::cmpGreaterThan || 1394 cmpResult == APFloat::cmpEqual; 1395 case arith::CmpFPredicate::OLT: 1396 return cmpResult == APFloat::cmpLessThan; 1397 case arith::CmpFPredicate::OLE: 1398 return cmpResult == APFloat::cmpLessThan || cmpResult == APFloat::cmpEqual; 1399 case arith::CmpFPredicate::ONE: 1400 return cmpResult != APFloat::cmpUnordered && cmpResult != APFloat::cmpEqual; 1401 case arith::CmpFPredicate::ORD: 1402 return cmpResult != APFloat::cmpUnordered; 1403 case arith::CmpFPredicate::UEQ: 1404 return cmpResult == APFloat::cmpUnordered || cmpResult == APFloat::cmpEqual; 1405 case arith::CmpFPredicate::UGT: 1406 return cmpResult == APFloat::cmpUnordered || 1407 cmpResult == APFloat::cmpGreaterThan; 1408 case arith::CmpFPredicate::UGE: 1409 return cmpResult == APFloat::cmpUnordered || 1410 cmpResult == APFloat::cmpGreaterThan || 1411 cmpResult == APFloat::cmpEqual; 1412 case arith::CmpFPredicate::ULT: 1413 return cmpResult == APFloat::cmpUnordered || 1414 cmpResult == APFloat::cmpLessThan; 1415 case arith::CmpFPredicate::ULE: 1416 return cmpResult == APFloat::cmpUnordered || 1417 cmpResult == APFloat::cmpLessThan || cmpResult == APFloat::cmpEqual; 1418 case arith::CmpFPredicate::UNE: 1419 return cmpResult != APFloat::cmpEqual; 1420 case arith::CmpFPredicate::UNO: 1421 return cmpResult == APFloat::cmpUnordered; 1422 case arith::CmpFPredicate::AlwaysTrue: 1423 return true; 1424 } 1425 llvm_unreachable("unknown cmpf predicate kind"); 1426 } 1427 1428 OpFoldResult arith::CmpFOp::fold(ArrayRef<Attribute> operands) { 1429 assert(operands.size() == 2 && "cmpf takes two operands"); 1430 1431 auto lhs = operands.front().dyn_cast_or_null<FloatAttr>(); 1432 auto rhs = operands.back().dyn_cast_or_null<FloatAttr>(); 1433 1434 // If one operand is NaN, making them both NaN does not change the result. 
1435 if (lhs && lhs.getValue().isNaN()) 1436 rhs = lhs; 1437 if (rhs && rhs.getValue().isNaN()) 1438 lhs = rhs; 1439 1440 if (!lhs || !rhs) 1441 return {}; 1442 1443 auto val = applyCmpPredicate(getPredicate(), lhs.getValue(), rhs.getValue()); 1444 return BoolAttr::get(getContext(), val); 1445 } 1446 1447 class CmpFIntToFPConst final : public OpRewritePattern<CmpFOp> { 1448 public: 1449 using OpRewritePattern<CmpFOp>::OpRewritePattern; 1450 1451 static CmpIPredicate convertToIntegerPredicate(CmpFPredicate pred, 1452 bool isUnsigned) { 1453 using namespace arith; 1454 switch (pred) { 1455 case CmpFPredicate::UEQ: 1456 case CmpFPredicate::OEQ: 1457 return CmpIPredicate::eq; 1458 case CmpFPredicate::UGT: 1459 case CmpFPredicate::OGT: 1460 return isUnsigned ? CmpIPredicate::ugt : CmpIPredicate::sgt; 1461 case CmpFPredicate::UGE: 1462 case CmpFPredicate::OGE: 1463 return isUnsigned ? CmpIPredicate::uge : CmpIPredicate::sge; 1464 case CmpFPredicate::ULT: 1465 case CmpFPredicate::OLT: 1466 return isUnsigned ? CmpIPredicate::ult : CmpIPredicate::slt; 1467 case CmpFPredicate::ULE: 1468 case CmpFPredicate::OLE: 1469 return isUnsigned ? CmpIPredicate::ule : CmpIPredicate::sle; 1470 case CmpFPredicate::UNE: 1471 case CmpFPredicate::ONE: 1472 return CmpIPredicate::ne; 1473 default: 1474 llvm_unreachable("Unexpected predicate!"); 1475 } 1476 } 1477 1478 LogicalResult matchAndRewrite(CmpFOp op, 1479 PatternRewriter &rewriter) const override { 1480 FloatAttr flt; 1481 if (!matchPattern(op.getRhs(), m_Constant(&flt))) 1482 return failure(); 1483 1484 const APFloat &rhs = flt.getValue(); 1485 1486 // Don't attempt to fold a nan. 1487 if (rhs.isNaN()) 1488 return failure(); 1489 1490 // Get the width of the mantissa. We don't want to hack on conversions that 1491 // might lose information from the integer, e.g. "i64 -> float" 1492 FloatType floatTy = op.getRhs().getType().cast<FloatType>(); 1493 int mantissaWidth = floatTy.getFPMantissaWidth(); 1494 if (mantissaWidth <= 0) 1495 return failure(); 1496 1497 bool isUnsigned; 1498 Value intVal; 1499 1500 if (auto si = op.getLhs().getDefiningOp<SIToFPOp>()) { 1501 isUnsigned = false; 1502 intVal = si.getIn(); 1503 } else if (auto ui = op.getLhs().getDefiningOp<UIToFPOp>()) { 1504 isUnsigned = true; 1505 intVal = ui.getIn(); 1506 } else { 1507 return failure(); 1508 } 1509 1510 // Check to see that the input is converted from an integer type that is 1511 // small enough that preserves all bits. 1512 auto intTy = intVal.getType().cast<IntegerType>(); 1513 auto intWidth = intTy.getWidth(); 1514 1515 // Number of bits representing values, as opposed to the sign 1516 auto valueBits = isUnsigned ? intWidth : (intWidth - 1); 1517 1518 // Following test does NOT adjust intWidth downwards for signed inputs, 1519 // because the most negative value still requires all the mantissa bits 1520 // to distinguish it from one less than that value. 1521 if ((int)intWidth > mantissaWidth) { 1522 // Conversion would lose accuracy. Check if loss can impact comparison. 1523 int exponent = ilogb(rhs); 1524 if (exponent == APFloat::IEK_Inf) { 1525 int maxExponent = ilogb(APFloat::getLargest(rhs.getSemantics())); 1526 if (maxExponent < (int)valueBits) { 1527 // Conversion could create infinity. 1528 return failure(); 1529 } 1530 } else { 1531 // Note that if rhs is zero or NaN, then Exp is negative 1532 // and first condition is trivially false. 1533 if (mantissaWidth <= exponent && exponent <= (int)valueBits) { 1534 // Conversion could affect comparison. 
1535 return failure(); 1536 } 1537 } 1538 } 1539 1540 // Convert to equivalent cmpi predicate 1541 CmpIPredicate pred; 1542 switch (op.getPredicate()) { 1543 case CmpFPredicate::ORD: 1544 // Int to fp conversion doesn't create a nan (ord checks neither is a nan) 1545 rewriter.replaceOpWithNewOp<ConstantIntOp>(op, /*value=*/true, 1546 /*width=*/1); 1547 return success(); 1548 case CmpFPredicate::UNO: 1549 // Int to fp conversion doesn't create a nan (uno checks either is a nan) 1550 rewriter.replaceOpWithNewOp<ConstantIntOp>(op, /*value=*/false, 1551 /*width=*/1); 1552 return success(); 1553 default: 1554 pred = convertToIntegerPredicate(op.getPredicate(), isUnsigned); 1555 break; 1556 } 1557 1558 if (!isUnsigned) { 1559 // If the rhs value is > SignedMax, fold the comparison. This handles 1560 // +INF and large values. 1561 APFloat signedMax(rhs.getSemantics()); 1562 signedMax.convertFromAPInt(APInt::getSignedMaxValue(intWidth), true, 1563 APFloat::rmNearestTiesToEven); 1564 if (signedMax < rhs) { // smax < 13123.0 1565 if (pred == CmpIPredicate::ne || pred == CmpIPredicate::slt || 1566 pred == CmpIPredicate::sle) 1567 rewriter.replaceOpWithNewOp<ConstantIntOp>(op, /*value=*/true, 1568 /*width=*/1); 1569 else 1570 rewriter.replaceOpWithNewOp<ConstantIntOp>(op, /*value=*/false, 1571 /*width=*/1); 1572 return success(); 1573 } 1574 } else { 1575 // If the rhs value is > UnsignedMax, fold the comparison. This handles 1576 // +INF and large values. 1577 APFloat unsignedMax(rhs.getSemantics()); 1578 unsignedMax.convertFromAPInt(APInt::getMaxValue(intWidth), false, 1579 APFloat::rmNearestTiesToEven); 1580 if (unsignedMax < rhs) { // umax < 13123.0 1581 if (pred == CmpIPredicate::ne || pred == CmpIPredicate::ult || 1582 pred == CmpIPredicate::ule) 1583 rewriter.replaceOpWithNewOp<ConstantIntOp>(op, /*value=*/true, 1584 /*width=*/1); 1585 else 1586 rewriter.replaceOpWithNewOp<ConstantIntOp>(op, /*value=*/false, 1587 /*width=*/1); 1588 return success(); 1589 } 1590 } 1591 1592 if (!isUnsigned) { 1593 // See if the rhs value is < SignedMin. 1594 APFloat signedMin(rhs.getSemantics()); 1595 signedMin.convertFromAPInt(APInt::getSignedMinValue(intWidth), true, 1596 APFloat::rmNearestTiesToEven); 1597 if (signedMin > rhs) { // smin > 12312.0 1598 if (pred == CmpIPredicate::ne || pred == CmpIPredicate::sgt || 1599 pred == CmpIPredicate::sge) 1600 rewriter.replaceOpWithNewOp<ConstantIntOp>(op, /*value=*/true, 1601 /*width=*/1); 1602 else 1603 rewriter.replaceOpWithNewOp<ConstantIntOp>(op, /*value=*/false, 1604 /*width=*/1); 1605 return success(); 1606 } 1607 } else { 1608 // See if the rhs value is < UnsignedMin. 1609 APFloat unsignedMin(rhs.getSemantics()); 1610 unsignedMin.convertFromAPInt(APInt::getMinValue(intWidth), false, 1611 APFloat::rmNearestTiesToEven); 1612 if (unsignedMin > rhs) { // umin > 12312.0 1613 if (pred == CmpIPredicate::ne || pred == CmpIPredicate::ugt || 1614 pred == CmpIPredicate::uge) 1615 rewriter.replaceOpWithNewOp<ConstantIntOp>(op, /*value=*/true, 1616 /*width=*/1); 1617 else 1618 rewriter.replaceOpWithNewOp<ConstantIntOp>(op, /*value=*/false, 1619 /*width=*/1); 1620 return success(); 1621 } 1622 } 1623 1624 // Okay, now we know that the FP constant fits in the range [SMIN, SMAX] or 1625 // [0, UMAX], but it may still be fractional. See if it is fractional by 1626 // casting the FP value to the integer value and back, checking for 1627 // equality. Don't do this for zero, because -0.0 is not fractional. 

    // Okay, now we know that the FP constant fits in the range [SMIN, SMAX] or
    // [0, UMAX], but it may still be fractional. See if it is fractional by
    // casting the FP value to the integer value and back, checking for
    // equality. Don't do this for zero, because -0.0 is not fractional.
    bool ignored;
    APSInt rhsInt(intWidth, isUnsigned);
    if (APFloat::opInvalidOp ==
        rhs.convertToInteger(rhsInt, APFloat::rmTowardZero, &ignored)) {
      // Undefined behavior invoked - the destination type can't represent
      // the input constant.
      return failure();
    }

    if (!rhs.isZero()) {
      APFloat apf(floatTy.getFloatSemantics(),
                  APInt::getZero(floatTy.getWidth()));
      apf.convertFromAPInt(rhsInt, !isUnsigned, APFloat::rmNearestTiesToEven);

      bool equal = apf == rhs;
      if (!equal) {
        // If we had a comparison against a fractional value, we have to adjust
        // the compare predicate and sometimes the value. rhsInt is rounded
        // towards zero at this point.
        switch (pred) {
        case CmpIPredicate::ne: // (float)int != 4.4 --> true
          rewriter.replaceOpWithNewOp<ConstantIntOp>(op, /*value=*/true,
                                                     /*width=*/1);
          return success();
        case CmpIPredicate::eq: // (float)int == 4.4 --> false
          rewriter.replaceOpWithNewOp<ConstantIntOp>(op, /*value=*/false,
                                                     /*width=*/1);
          return success();
        case CmpIPredicate::ule:
          // (float)int <= 4.4  --> int <= 4
          // (float)int <= -4.4 --> false
          if (rhs.isNegative()) {
            rewriter.replaceOpWithNewOp<ConstantIntOp>(op, /*value=*/false,
                                                       /*width=*/1);
            return success();
          }
          break;
        case CmpIPredicate::sle:
          // (float)int <= 4.4  --> int <= 4
          // (float)int <= -4.4 --> int < -4
          if (rhs.isNegative())
            pred = CmpIPredicate::slt;
          break;
        case CmpIPredicate::ult:
          // (float)int < -4.4 --> false
          // (float)int < 4.4  --> int <= 4
          if (rhs.isNegative()) {
            rewriter.replaceOpWithNewOp<ConstantIntOp>(op, /*value=*/false,
                                                       /*width=*/1);
            return success();
          }
          pred = CmpIPredicate::ule;
          break;
        case CmpIPredicate::slt:
          // (float)int < -4.4 --> int < -4
          // (float)int < 4.4  --> int <= 4
          if (!rhs.isNegative())
            pred = CmpIPredicate::sle;
          break;
        case CmpIPredicate::ugt:
          // (float)int > 4.4  --> int > 4
          // (float)int > -4.4 --> true
          if (rhs.isNegative()) {
            rewriter.replaceOpWithNewOp<ConstantIntOp>(op, /*value=*/true,
                                                       /*width=*/1);
            return success();
          }
          break;
        case CmpIPredicate::sgt:
          // (float)int > 4.4  --> int > 4
          // (float)int > -4.4 --> int >= -4
          if (rhs.isNegative())
            pred = CmpIPredicate::sge;
          break;
        case CmpIPredicate::uge:
          // (float)int >= -4.4 --> true
          // (float)int >= 4.4  --> int > 4
          if (rhs.isNegative()) {
            rewriter.replaceOpWithNewOp<ConstantIntOp>(op, /*value=*/true,
                                                       /*width=*/1);
            return success();
          }
          pred = CmpIPredicate::ugt;
          break;
        case CmpIPredicate::sge:
          // (float)int >= -4.4 --> int >= -4
          // (float)int >= 4.4  --> int > 4
          if (!rhs.isNegative())
            pred = CmpIPredicate::sgt;
          break;
        }
      }
    }
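
    // Illustrative end-to-end example (SSA names and concrete types assumed):
    //
    //   %f = arith.sitofp %x : i32 to f32
    //   %r = arith.cmpf olt, %f, %cst : f32    // %cst holds 4.4
    //
    // reaches this point with pred adjusted from slt to sle and rhsInt == 4,
    // and is rewritten below into
    //
    //   %c4 = arith.constant 4 : i32
    //   %r = arith.cmpi sle, %x, %c4 : i32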

    // Lower this FP comparison into an appropriate integer version of the
    // comparison.
    rewriter.replaceOpWithNewOp<CmpIOp>(
        op, pred, intVal,
        rewriter.create<ConstantOp>(
            op.getLoc(), intVal.getType(),
            rewriter.getIntegerAttr(intVal.getType(), rhsInt)));
    return success();
  }
};

void arith::CmpFOp::getCanonicalizationPatterns(RewritePatternSet &patterns,
                                                MLIRContext *context) {
  patterns.insert<CmpFIntToFPConst>(context);
}

//===----------------------------------------------------------------------===//
// SelectOp
//===----------------------------------------------------------------------===//

// Transforms a select of a boolean into arithmetic operations:
//
//   arith.select %arg, %x, %y : i1
//
// becomes
//
//   and(%arg, %x) or and(!%arg, %y)
struct SelectI1Simplify : public OpRewritePattern<arith::SelectOp> {
  using OpRewritePattern<arith::SelectOp>::OpRewritePattern;

  LogicalResult matchAndRewrite(arith::SelectOp op,
                                PatternRewriter &rewriter) const override {
    if (!op.getType().isInteger(1))
      return failure();

    // Negate the condition by xor'ing it with the i1 constant 'true'.
    Value trueConstant =
        rewriter.create<arith::ConstantIntOp>(op.getLoc(), true, 1);
    Value notCondition = rewriter.create<arith::XOrIOp>(
        op.getLoc(), op.getCondition(), trueConstant);

    Value trueVal = rewriter.create<arith::AndIOp>(
        op.getLoc(), op.getCondition(), op.getTrueValue());
    Value falseVal = rewriter.create<arith::AndIOp>(op.getLoc(), notCondition,
                                                    op.getFalseValue());
    rewriter.replaceOpWithNewOp<arith::OrIOp>(op, trueVal, falseVal);
    return success();
  }
};

// select %arg, %c1, %c0 => extui %arg
struct SelectToExtUI : public OpRewritePattern<arith::SelectOp> {
  using OpRewritePattern<arith::SelectOp>::OpRewritePattern;

  LogicalResult matchAndRewrite(arith::SelectOp op,
                                PatternRewriter &rewriter) const override {
    // Cannot extui i1 to i1, or i1 to f32.
    if (!op.getType().isa<IntegerType>() || op.getType().isInteger(1))
      return failure();

    // select %x, %c1, %c0 => extui %x
    if (matchPattern(op.getTrueValue(), m_One()) &&
        matchPattern(op.getFalseValue(), m_Zero())) {
      rewriter.replaceOpWithNewOp<arith::ExtUIOp>(op, op.getType(),
                                                  op.getCondition());
      return success();
    }

    // select %x, %c0, %c1 => extui (xor %x, true)
    if (matchPattern(op.getTrueValue(), m_Zero()) &&
        matchPattern(op.getFalseValue(), m_One())) {
      rewriter.replaceOpWithNewOp<arith::ExtUIOp>(
          op, op.getType(),
          rewriter.create<arith::XOrIOp>(
              op.getLoc(), op.getCondition(),
              rewriter.create<arith::ConstantIntOp>(
                  op.getLoc(), 1, op.getCondition().getType())));
      return success();
    }

    return failure();
  }
};

void arith::SelectOp::getCanonicalizationPatterns(RewritePatternSet &results,
                                                  MLIRContext *context) {
  results.add<SelectI1Simplify, SelectToExtUI>(context);
}
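
// Illustrative example of SelectToExtUI (an i8 result type is assumed):
//
//   %0 = arith.select %cond, %c1_i8, %c0_i8 : i8
//
// canonicalizes to
//
//   %0 = arith.extui %cond : i1 to i8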

OpFoldResult arith::SelectOp::fold(ArrayRef<Attribute> operands) {
  Value trueVal = getTrueValue();
  Value falseVal = getFalseValue();
  if (trueVal == falseVal)
    return trueVal;

  Value condition = getCondition();

  // select true, %0, %1 => %0
  if (matchPattern(condition, m_One()))
    return trueVal;

  // select false, %0, %1 => %1
  if (matchPattern(condition, m_Zero()))
    return falseVal;

  // select %x, true, false => %x
  if (getType().isInteger(1) && matchPattern(getTrueValue(), m_One()) &&
      matchPattern(getFalseValue(), m_Zero()))
    return condition;

  if (auto cmp = dyn_cast_or_null<arith::CmpIOp>(condition.getDefiningOp())) {
    auto pred = cmp.getPredicate();
    if (pred == arith::CmpIPredicate::eq || pred == arith::CmpIPredicate::ne) {
      auto cmpLhs = cmp.getLhs();
      auto cmpRhs = cmp.getRhs();

      // %0 = arith.cmpi eq, %arg0, %arg1
      // %1 = arith.select %0, %arg0, %arg1 => %arg1

      // %0 = arith.cmpi ne, %arg0, %arg1
      // %1 = arith.select %0, %arg0, %arg1 => %arg0

      if ((cmpLhs == trueVal && cmpRhs == falseVal) ||
          (cmpRhs == trueVal && cmpLhs == falseVal))
        return pred == arith::CmpIPredicate::ne ? trueVal : falseVal;
    }
  }
  return nullptr;
}

ParseResult SelectOp::parse(OpAsmParser &parser, OperationState &result) {
  Type conditionType, resultType;
  SmallVector<OpAsmParser::UnresolvedOperand, 3> operands;
  if (parser.parseOperandList(operands, /*requiredOperandCount=*/3) ||
      parser.parseOptionalAttrDict(result.attributes) ||
      parser.parseColonType(resultType))
    return failure();

  // Check for the explicit condition type if this is a masked tensor or
  // vector.
  if (succeeded(parser.parseOptionalComma())) {
    conditionType = resultType;
    if (parser.parseType(resultType))
      return failure();
  } else {
    conditionType = parser.getBuilder().getI1Type();
  }

  result.addTypes(resultType);
  return parser.resolveOperands(operands,
                                {conditionType, resultType, resultType},
                                parser.getNameLoc(), result.operands);
}

void arith::SelectOp::print(OpAsmPrinter &p) {
  p << " " << getOperands();
  p.printOptionalAttrDict((*this)->getAttrs());
  p << " : ";
  if (ShapedType condType = getCondition().getType().dyn_cast<ShapedType>())
    p << condType << ", ";
  p << getType();
}

LogicalResult arith::SelectOp::verify() {
  Type conditionType = getCondition().getType();
  if (conditionType.isSignlessInteger(1))
    return success();

  // If the result type is a vector or tensor, the condition can be an i1 mask
  // with the same shape.
  Type resultType = getType();
  if (!resultType.isa<TensorType, VectorType>())
    return emitOpError() << "expected condition to be a signless i1, but got "
                         << conditionType;
  Type shapedConditionType = getI1SameShape(resultType);
  if (conditionType != shapedConditionType) {
    return emitOpError() << "expected condition type to have the same shape "
                            "as the result type, expected "
                         << shapedConditionType << ", but got "
                         << conditionType;
  }
  return success();
}

//===----------------------------------------------------------------------===//
// ShLIOp
//===----------------------------------------------------------------------===//

OpFoldResult arith::ShLIOp::fold(ArrayRef<Attribute> operands) {
  // Don't fold if shifting more than the bit width.
  bool bounded = false;
  auto result = constFoldBinaryOp<IntegerAttr>(
      operands, [&](const APInt &a, const APInt &b) {
        bounded = b.ule(b.getBitWidth());
        return a.shl(b);
      });
  return bounded ? result : Attribute();
}
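
// Illustrative example of the fold above (i8 constants assumed):
//
//   %c2 = arith.constant 2 : i8
//   %c3 = arith.constant 3 : i8
//   %0 = arith.shli %c2, %c3 : i8
//
// folds to an arith.constant 16 : i8, whereas a shift amount exceeding the bit
// width leaves the op unfolded.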

//===----------------------------------------------------------------------===//
// ShRUIOp
//===----------------------------------------------------------------------===//

OpFoldResult arith::ShRUIOp::fold(ArrayRef<Attribute> operands) {
  // Don't fold if shifting more than the bit width.
  bool bounded = false;
  auto result = constFoldBinaryOp<IntegerAttr>(
      operands, [&](const APInt &a, const APInt &b) {
        bounded = b.ule(b.getBitWidth());
        return a.lshr(b);
      });
  return bounded ? result : Attribute();
}

//===----------------------------------------------------------------------===//
// ShRSIOp
//===----------------------------------------------------------------------===//

OpFoldResult arith::ShRSIOp::fold(ArrayRef<Attribute> operands) {
  // Don't fold if shifting more than the bit width.
  bool bounded = false;
  auto result = constFoldBinaryOp<IntegerAttr>(
      operands, [&](const APInt &a, const APInt &b) {
        bounded = b.ule(b.getBitWidth());
        return a.ashr(b);
      });
  return bounded ? result : Attribute();
}

//===----------------------------------------------------------------------===//
// Atomic Enum
//===----------------------------------------------------------------------===//

/// Returns the identity value attribute associated with an AtomicRMWKind op.
Attribute mlir::arith::getIdentityValueAttr(AtomicRMWKind kind, Type resultType,
                                            OpBuilder &builder, Location loc) {
  switch (kind) {
  case AtomicRMWKind::maxf:
    return builder.getFloatAttr(
        resultType,
        APFloat::getInf(resultType.cast<FloatType>().getFloatSemantics(),
                        /*Negative=*/true));
  case AtomicRMWKind::addf:
  case AtomicRMWKind::addi:
  case AtomicRMWKind::maxu:
  case AtomicRMWKind::ori:
    return builder.getZeroAttr(resultType);
  case AtomicRMWKind::andi:
    return builder.getIntegerAttr(
        resultType,
        APInt::getAllOnes(resultType.cast<IntegerType>().getWidth()));
  case AtomicRMWKind::maxs:
    return builder.getIntegerAttr(
        resultType,
        APInt::getSignedMinValue(resultType.cast<IntegerType>().getWidth()));
  case AtomicRMWKind::minf:
    return builder.getFloatAttr(
        resultType,
        APFloat::getInf(resultType.cast<FloatType>().getFloatSemantics(),
                        /*Negative=*/false));
  case AtomicRMWKind::mins:
    return builder.getIntegerAttr(
        resultType,
        APInt::getSignedMaxValue(resultType.cast<IntegerType>().getWidth()));
  case AtomicRMWKind::minu:
    return builder.getIntegerAttr(
        resultType,
        APInt::getMaxValue(resultType.cast<IntegerType>().getWidth()));
  case AtomicRMWKind::muli:
    return builder.getIntegerAttr(resultType, 1);
  case AtomicRMWKind::mulf:
    return builder.getFloatAttr(resultType, 1);
  // TODO: Add remaining reduction operations.
  default:
    (void)emitOptionalError(loc, "Reduction operation type not supported");
    break;
  }
  return nullptr;
}
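
// For example (illustrative): for AtomicRMWKind::addf on f32 the identity
// attribute above is 0.0, for AtomicRMWKind::maxf it is -inf, and for
// AtomicRMWKind::andi on i32 it is the all-ones value 0xFFFFFFFF.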

/// Returns the identity value associated with an AtomicRMWKind op.
Value mlir::arith::getIdentityValue(AtomicRMWKind op, Type resultType,
                                    OpBuilder &builder, Location loc) {
  Attribute attr = getIdentityValueAttr(op, resultType, builder, loc);
  return builder.create<arith::ConstantOp>(loc, attr);
}

/// Return the value obtained by applying the reduction operation kind
/// associated with a binary AtomicRMWKind op to `lhs` and `rhs`.
Value mlir::arith::getReductionOp(AtomicRMWKind op, OpBuilder &builder,
                                  Location loc, Value lhs, Value rhs) {
  switch (op) {
  case AtomicRMWKind::addf:
    return builder.create<arith::AddFOp>(loc, lhs, rhs);
  case AtomicRMWKind::addi:
    return builder.create<arith::AddIOp>(loc, lhs, rhs);
  case AtomicRMWKind::mulf:
    return builder.create<arith::MulFOp>(loc, lhs, rhs);
  case AtomicRMWKind::muli:
    return builder.create<arith::MulIOp>(loc, lhs, rhs);
  case AtomicRMWKind::maxf:
    return builder.create<arith::MaxFOp>(loc, lhs, rhs);
  case AtomicRMWKind::minf:
    return builder.create<arith::MinFOp>(loc, lhs, rhs);
  case AtomicRMWKind::maxs:
    return builder.create<arith::MaxSIOp>(loc, lhs, rhs);
  case AtomicRMWKind::mins:
    return builder.create<arith::MinSIOp>(loc, lhs, rhs);
  case AtomicRMWKind::maxu:
    return builder.create<arith::MaxUIOp>(loc, lhs, rhs);
  case AtomicRMWKind::minu:
    return builder.create<arith::MinUIOp>(loc, lhs, rhs);
  case AtomicRMWKind::ori:
    return builder.create<arith::OrIOp>(loc, lhs, rhs);
  case AtomicRMWKind::andi:
    return builder.create<arith::AndIOp>(loc, lhs, rhs);
  // TODO: Add remaining reduction operations.
  default:
    (void)emitOptionalError(loc, "Reduction operation type not supported");
    break;
  }
  return nullptr;
}

//===----------------------------------------------------------------------===//
// TableGen'd op method definitions
//===----------------------------------------------------------------------===//

#define GET_OP_CLASSES
#include "mlir/Dialect/Arithmetic/IR/ArithmeticOps.cpp.inc"

//===----------------------------------------------------------------------===//
// TableGen'd enum attribute definitions
//===----------------------------------------------------------------------===//

#include "mlir/Dialect/Arithmetic/IR/ArithmeticOpsEnums.cpp.inc"