1 //===- ArithmeticOps.cpp - MLIR Arithmetic dialect ops implementation -----===// 2 // 3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. 4 // See https://llvm.org/LICENSE.txt for license information. 5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception 6 // 7 //===----------------------------------------------------------------------===// 8 9 #include <utility> 10 11 #include "mlir/Dialect/Arithmetic/IR/Arithmetic.h" 12 #include "mlir/Dialect/CommonFolders.h" 13 #include "mlir/IR/Builders.h" 14 #include "mlir/IR/Matchers.h" 15 #include "mlir/IR/OpImplementation.h" 16 #include "mlir/IR/PatternMatch.h" 17 #include "mlir/IR/TypeUtilities.h" 18 #include "llvm/ADT/SmallString.h" 19 20 #include "llvm/ADT/APSInt.h" 21 22 using namespace mlir; 23 using namespace mlir::arith; 24 25 //===----------------------------------------------------------------------===// 26 // Pattern helpers 27 //===----------------------------------------------------------------------===// 28 29 static IntegerAttr addIntegerAttrs(PatternRewriter &builder, Value res, 30 Attribute lhs, Attribute rhs) { 31 return builder.getIntegerAttr(res.getType(), 32 lhs.cast<IntegerAttr>().getInt() + 33 rhs.cast<IntegerAttr>().getInt()); 34 } 35 36 static IntegerAttr subIntegerAttrs(PatternRewriter &builder, Value res, 37 Attribute lhs, Attribute rhs) { 38 return builder.getIntegerAttr(res.getType(), 39 lhs.cast<IntegerAttr>().getInt() - 40 rhs.cast<IntegerAttr>().getInt()); 41 } 42 43 /// Invert an integer comparison predicate. 44 arith::CmpIPredicate arith::invertPredicate(arith::CmpIPredicate pred) { 45 switch (pred) { 46 case arith::CmpIPredicate::eq: 47 return arith::CmpIPredicate::ne; 48 case arith::CmpIPredicate::ne: 49 return arith::CmpIPredicate::eq; 50 case arith::CmpIPredicate::slt: 51 return arith::CmpIPredicate::sge; 52 case arith::CmpIPredicate::sle: 53 return arith::CmpIPredicate::sgt; 54 case arith::CmpIPredicate::sgt: 55 return arith::CmpIPredicate::sle; 56 case arith::CmpIPredicate::sge: 57 return arith::CmpIPredicate::slt; 58 case arith::CmpIPredicate::ult: 59 return arith::CmpIPredicate::uge; 60 case arith::CmpIPredicate::ule: 61 return arith::CmpIPredicate::ugt; 62 case arith::CmpIPredicate::ugt: 63 return arith::CmpIPredicate::ule; 64 case arith::CmpIPredicate::uge: 65 return arith::CmpIPredicate::ult; 66 } 67 llvm_unreachable("unknown cmpi predicate kind"); 68 } 69 70 static arith::CmpIPredicateAttr invertPredicate(arith::CmpIPredicateAttr pred) { 71 return arith::CmpIPredicateAttr::get(pred.getContext(), 72 invertPredicate(pred.getValue())); 73 } 74 75 //===----------------------------------------------------------------------===// 76 // TableGen'd canonicalization patterns 77 //===----------------------------------------------------------------------===// 78 79 namespace { 80 #include "ArithmeticCanonicalization.inc" 81 } // namespace 82 83 //===----------------------------------------------------------------------===// 84 // ConstantOp 85 //===----------------------------------------------------------------------===// 86 87 void arith::ConstantOp::getAsmResultNames( 88 function_ref<void(Value, StringRef)> setNameFn) { 89 auto type = getType(); 90 if (auto intCst = getValue().dyn_cast<IntegerAttr>()) { 91 auto intType = type.dyn_cast<IntegerType>(); 92 93 // Sugar i1 constants with 'true' and 'false'. 94 if (intType && intType.getWidth() == 1) 95 return setNameFn(getResult(), (intCst.getInt() ? "true" : "false")); 96 97 // Otherwise, build a compex name with the value and type. 98 SmallString<32> specialNameBuffer; 99 llvm::raw_svector_ostream specialName(specialNameBuffer); 100 specialName << 'c' << intCst.getInt(); 101 if (intType) 102 specialName << '_' << type; 103 setNameFn(getResult(), specialName.str()); 104 } else { 105 setNameFn(getResult(), "cst"); 106 } 107 } 108 109 /// TODO: disallow arith.constant to return anything other than signless integer 110 /// or float like. 111 LogicalResult arith::ConstantOp::verify() { 112 auto type = getType(); 113 // The value's type must match the return type. 114 if (getValue().getType() != type) { 115 return emitOpError() << "value type " << getValue().getType() 116 << " must match return type: " << type; 117 } 118 // Integer values must be signless. 119 if (type.isa<IntegerType>() && !type.cast<IntegerType>().isSignless()) 120 return emitOpError("integer return type must be signless"); 121 // Any float or elements attribute are acceptable. 122 if (!getValue().isa<IntegerAttr, FloatAttr, ElementsAttr>()) { 123 return emitOpError( 124 "value must be an integer, float, or elements attribute"); 125 } 126 return success(); 127 } 128 129 bool arith::ConstantOp::isBuildableWith(Attribute value, Type type) { 130 // The value's type must be the same as the provided type. 131 if (value.getType() != type) 132 return false; 133 // Integer values must be signless. 134 if (type.isa<IntegerType>() && !type.cast<IntegerType>().isSignless()) 135 return false; 136 // Integer, float, and element attributes are buildable. 137 return value.isa<IntegerAttr, FloatAttr, ElementsAttr>(); 138 } 139 140 OpFoldResult arith::ConstantOp::fold(ArrayRef<Attribute> operands) { 141 return getValue(); 142 } 143 144 void arith::ConstantIntOp::build(OpBuilder &builder, OperationState &result, 145 int64_t value, unsigned width) { 146 auto type = builder.getIntegerType(width); 147 arith::ConstantOp::build(builder, result, type, 148 builder.getIntegerAttr(type, value)); 149 } 150 151 void arith::ConstantIntOp::build(OpBuilder &builder, OperationState &result, 152 int64_t value, Type type) { 153 assert(type.isSignlessInteger() && 154 "ConstantIntOp can only have signless integer type values"); 155 arith::ConstantOp::build(builder, result, type, 156 builder.getIntegerAttr(type, value)); 157 } 158 159 bool arith::ConstantIntOp::classof(Operation *op) { 160 if (auto constOp = dyn_cast_or_null<arith::ConstantOp>(op)) 161 return constOp.getType().isSignlessInteger(); 162 return false; 163 } 164 165 void arith::ConstantFloatOp::build(OpBuilder &builder, OperationState &result, 166 const APFloat &value, FloatType type) { 167 arith::ConstantOp::build(builder, result, type, 168 builder.getFloatAttr(type, value)); 169 } 170 171 bool arith::ConstantFloatOp::classof(Operation *op) { 172 if (auto constOp = dyn_cast_or_null<arith::ConstantOp>(op)) 173 return constOp.getType().isa<FloatType>(); 174 return false; 175 } 176 177 void arith::ConstantIndexOp::build(OpBuilder &builder, OperationState &result, 178 int64_t value) { 179 arith::ConstantOp::build(builder, result, builder.getIndexType(), 180 builder.getIndexAttr(value)); 181 } 182 183 bool arith::ConstantIndexOp::classof(Operation *op) { 184 if (auto constOp = dyn_cast_or_null<arith::ConstantOp>(op)) 185 return constOp.getType().isIndex(); 186 return false; 187 } 188 189 //===----------------------------------------------------------------------===// 190 // AddIOp 191 //===----------------------------------------------------------------------===// 192 193 OpFoldResult arith::AddIOp::fold(ArrayRef<Attribute> operands) { 194 // addi(x, 0) -> x 195 if (matchPattern(getRhs(), m_Zero())) 196 return getLhs(); 197 198 // addi(subi(a, b), b) -> a 199 if (auto sub = getLhs().getDefiningOp<SubIOp>()) 200 if (getRhs() == sub.getRhs()) 201 return sub.getLhs(); 202 203 // addi(b, subi(a, b)) -> a 204 if (auto sub = getRhs().getDefiningOp<SubIOp>()) 205 if (getLhs() == sub.getRhs()) 206 return sub.getLhs(); 207 208 return constFoldBinaryOp<IntegerAttr>( 209 operands, [](APInt a, const APInt &b) { return std::move(a) + b; }); 210 } 211 212 void arith::AddIOp::getCanonicalizationPatterns( 213 RewritePatternSet &patterns, MLIRContext *context) { 214 patterns.add<AddIAddConstant, AddISubConstantRHS, AddISubConstantLHS>( 215 context); 216 } 217 218 //===----------------------------------------------------------------------===// 219 // SubIOp 220 //===----------------------------------------------------------------------===// 221 222 OpFoldResult arith::SubIOp::fold(ArrayRef<Attribute> operands) { 223 // subi(x,x) -> 0 224 if (getOperand(0) == getOperand(1)) 225 return Builder(getContext()).getZeroAttr(getType()); 226 // subi(x,0) -> x 227 if (matchPattern(getRhs(), m_Zero())) 228 return getLhs(); 229 230 return constFoldBinaryOp<IntegerAttr>( 231 operands, [](APInt a, const APInt &b) { return std::move(a) - b; }); 232 } 233 234 void arith::SubIOp::getCanonicalizationPatterns( 235 RewritePatternSet &patterns, MLIRContext *context) { 236 patterns 237 .add<SubIRHSAddConstant, SubILHSAddConstant, SubIRHSSubConstantRHS, 238 SubIRHSSubConstantLHS, SubILHSSubConstantRHS, SubILHSSubConstantLHS>( 239 context); 240 } 241 242 //===----------------------------------------------------------------------===// 243 // MulIOp 244 //===----------------------------------------------------------------------===// 245 246 OpFoldResult arith::MulIOp::fold(ArrayRef<Attribute> operands) { 247 // muli(x, 0) -> 0 248 if (matchPattern(getRhs(), m_Zero())) 249 return getRhs(); 250 // muli(x, 1) -> x 251 if (matchPattern(getRhs(), m_One())) 252 return getOperand(0); 253 // TODO: Handle the overflow case. 254 255 // default folder 256 return constFoldBinaryOp<IntegerAttr>( 257 operands, [](const APInt &a, const APInt &b) { return a * b; }); 258 } 259 260 //===----------------------------------------------------------------------===// 261 // DivUIOp 262 //===----------------------------------------------------------------------===// 263 264 OpFoldResult arith::DivUIOp::fold(ArrayRef<Attribute> operands) { 265 // divui (x, 1) -> x. 266 if (matchPattern(getRhs(), m_One())) 267 return getLhs(); 268 269 // Don't fold if it would require a division by zero. 270 bool div0 = false; 271 auto result = 272 constFoldBinaryOp<IntegerAttr>(operands, [&](APInt a, const APInt &b) { 273 if (div0 || !b) { 274 div0 = true; 275 return a; 276 } 277 return a.udiv(b); 278 }); 279 280 return div0 ? Attribute() : result; 281 } 282 283 //===----------------------------------------------------------------------===// 284 // DivSIOp 285 //===----------------------------------------------------------------------===// 286 287 OpFoldResult arith::DivSIOp::fold(ArrayRef<Attribute> operands) { 288 // divsi (x, 1) -> x. 289 if (matchPattern(getRhs(), m_One())) 290 return getLhs(); 291 292 // Don't fold if it would overflow or if it requires a division by zero. 293 bool overflowOrDiv0 = false; 294 auto result = 295 constFoldBinaryOp<IntegerAttr>(operands, [&](APInt a, const APInt &b) { 296 if (overflowOrDiv0 || !b) { 297 overflowOrDiv0 = true; 298 return a; 299 } 300 return a.sdiv_ov(b, overflowOrDiv0); 301 }); 302 303 return overflowOrDiv0 ? Attribute() : result; 304 } 305 306 //===----------------------------------------------------------------------===// 307 // Ceil and floor division folding helpers 308 //===----------------------------------------------------------------------===// 309 310 static APInt signedCeilNonnegInputs(const APInt &a, const APInt &b, 311 bool &overflow) { 312 // Returns (a-1)/b + 1 313 APInt one(a.getBitWidth(), 1, true); // Signed value 1. 314 APInt val = a.ssub_ov(one, overflow).sdiv_ov(b, overflow); 315 return val.sadd_ov(one, overflow); 316 } 317 318 //===----------------------------------------------------------------------===// 319 // CeilDivUIOp 320 //===----------------------------------------------------------------------===// 321 322 OpFoldResult arith::CeilDivUIOp::fold(ArrayRef<Attribute> operands) { 323 // ceildivui (x, 1) -> x. 324 if (matchPattern(getRhs(), m_One())) 325 return getLhs(); 326 327 bool overflowOrDiv0 = false; 328 auto result = 329 constFoldBinaryOp<IntegerAttr>(operands, [&](APInt a, const APInt &b) { 330 if (overflowOrDiv0 || !b) { 331 overflowOrDiv0 = true; 332 return a; 333 } 334 APInt quotient = a.udiv(b); 335 if (!a.urem(b)) 336 return quotient; 337 APInt one(a.getBitWidth(), 1, true); 338 return quotient.uadd_ov(one, overflowOrDiv0); 339 }); 340 341 return overflowOrDiv0 ? Attribute() : result; 342 } 343 344 //===----------------------------------------------------------------------===// 345 // CeilDivSIOp 346 //===----------------------------------------------------------------------===// 347 348 OpFoldResult arith::CeilDivSIOp::fold(ArrayRef<Attribute> operands) { 349 // ceildivsi (x, 1) -> x. 350 if (matchPattern(getRhs(), m_One())) 351 return getLhs(); 352 353 // Don't fold if it would overflow or if it requires a division by zero. 354 bool overflowOrDiv0 = false; 355 auto result = 356 constFoldBinaryOp<IntegerAttr>(operands, [&](APInt a, const APInt &b) { 357 if (overflowOrDiv0 || !b) { 358 overflowOrDiv0 = true; 359 return a; 360 } 361 if (!a) 362 return a; 363 // After this point we know that neither a or b are zero. 364 unsigned bits = a.getBitWidth(); 365 APInt zero = APInt::getZero(bits); 366 bool aGtZero = a.sgt(zero); 367 bool bGtZero = b.sgt(zero); 368 if (aGtZero && bGtZero) { 369 // Both positive, return ceil(a, b). 370 return signedCeilNonnegInputs(a, b, overflowOrDiv0); 371 } 372 if (!aGtZero && !bGtZero) { 373 // Both negative, return ceil(-a, -b). 374 APInt posA = zero.ssub_ov(a, overflowOrDiv0); 375 APInt posB = zero.ssub_ov(b, overflowOrDiv0); 376 return signedCeilNonnegInputs(posA, posB, overflowOrDiv0); 377 } 378 if (!aGtZero && bGtZero) { 379 // A is negative, b is positive, return - ( -a / b). 380 APInt posA = zero.ssub_ov(a, overflowOrDiv0); 381 APInt div = posA.sdiv_ov(b, overflowOrDiv0); 382 return zero.ssub_ov(div, overflowOrDiv0); 383 } 384 // A is positive, b is negative, return - (a / -b). 385 APInt posB = zero.ssub_ov(b, overflowOrDiv0); 386 APInt div = a.sdiv_ov(posB, overflowOrDiv0); 387 return zero.ssub_ov(div, overflowOrDiv0); 388 }); 389 390 return overflowOrDiv0 ? Attribute() : result; 391 } 392 393 //===----------------------------------------------------------------------===// 394 // FloorDivSIOp 395 //===----------------------------------------------------------------------===// 396 397 OpFoldResult arith::FloorDivSIOp::fold(ArrayRef<Attribute> operands) { 398 // floordivsi (x, 1) -> x. 399 if (matchPattern(getRhs(), m_One())) 400 return getLhs(); 401 402 // Don't fold if it would overflow or if it requires a division by zero. 403 bool overflowOrDiv0 = false; 404 auto result = 405 constFoldBinaryOp<IntegerAttr>(operands, [&](APInt a, const APInt &b) { 406 if (overflowOrDiv0 || !b) { 407 overflowOrDiv0 = true; 408 return a; 409 } 410 if (!a) 411 return a; 412 // After this point we know that neither a or b are zero. 413 unsigned bits = a.getBitWidth(); 414 APInt zero = APInt::getZero(bits); 415 bool aGtZero = a.sgt(zero); 416 bool bGtZero = b.sgt(zero); 417 if (aGtZero && bGtZero) { 418 // Both positive, return a / b. 419 return a.sdiv_ov(b, overflowOrDiv0); 420 } 421 if (!aGtZero && !bGtZero) { 422 // Both negative, return -a / -b. 423 APInt posA = zero.ssub_ov(a, overflowOrDiv0); 424 APInt posB = zero.ssub_ov(b, overflowOrDiv0); 425 return posA.sdiv_ov(posB, overflowOrDiv0); 426 } 427 if (!aGtZero && bGtZero) { 428 // A is negative, b is positive, return - ceil(-a, b). 429 APInt posA = zero.ssub_ov(a, overflowOrDiv0); 430 APInt ceil = signedCeilNonnegInputs(posA, b, overflowOrDiv0); 431 return zero.ssub_ov(ceil, overflowOrDiv0); 432 } 433 // A is positive, b is negative, return - ceil(a, -b). 434 APInt posB = zero.ssub_ov(b, overflowOrDiv0); 435 APInt ceil = signedCeilNonnegInputs(a, posB, overflowOrDiv0); 436 return zero.ssub_ov(ceil, overflowOrDiv0); 437 }); 438 439 return overflowOrDiv0 ? Attribute() : result; 440 } 441 442 //===----------------------------------------------------------------------===// 443 // RemUIOp 444 //===----------------------------------------------------------------------===// 445 446 OpFoldResult arith::RemUIOp::fold(ArrayRef<Attribute> operands) { 447 auto rhs = operands.back().dyn_cast_or_null<IntegerAttr>(); 448 if (!rhs) 449 return {}; 450 auto rhsValue = rhs.getValue(); 451 452 // x % 1 = 0 453 if (rhsValue.isOneValue()) 454 return IntegerAttr::get(rhs.getType(), APInt(rhsValue.getBitWidth(), 0)); 455 456 // Don't fold if it requires division by zero. 457 if (rhsValue.isNullValue()) 458 return {}; 459 460 auto lhs = operands.front().dyn_cast_or_null<IntegerAttr>(); 461 if (!lhs) 462 return {}; 463 return IntegerAttr::get(lhs.getType(), lhs.getValue().urem(rhsValue)); 464 } 465 466 //===----------------------------------------------------------------------===// 467 // RemSIOp 468 //===----------------------------------------------------------------------===// 469 470 OpFoldResult arith::RemSIOp::fold(ArrayRef<Attribute> operands) { 471 auto rhs = operands.back().dyn_cast_or_null<IntegerAttr>(); 472 if (!rhs) 473 return {}; 474 auto rhsValue = rhs.getValue(); 475 476 // x % 1 = 0 477 if (rhsValue.isOneValue()) 478 return IntegerAttr::get(rhs.getType(), APInt(rhsValue.getBitWidth(), 0)); 479 480 // Don't fold if it requires division by zero. 481 if (rhsValue.isNullValue()) 482 return {}; 483 484 auto lhs = operands.front().dyn_cast_or_null<IntegerAttr>(); 485 if (!lhs) 486 return {}; 487 return IntegerAttr::get(lhs.getType(), lhs.getValue().srem(rhsValue)); 488 } 489 490 //===----------------------------------------------------------------------===// 491 // AndIOp 492 //===----------------------------------------------------------------------===// 493 494 OpFoldResult arith::AndIOp::fold(ArrayRef<Attribute> operands) { 495 /// and(x, 0) -> 0 496 if (matchPattern(getRhs(), m_Zero())) 497 return getRhs(); 498 /// and(x, allOnes) -> x 499 APInt intValue; 500 if (matchPattern(getRhs(), m_ConstantInt(&intValue)) && intValue.isAllOnes()) 501 return getLhs(); 502 503 return constFoldBinaryOp<IntegerAttr>( 504 operands, [](APInt a, const APInt &b) { return std::move(a) & b; }); 505 } 506 507 //===----------------------------------------------------------------------===// 508 // OrIOp 509 //===----------------------------------------------------------------------===// 510 511 OpFoldResult arith::OrIOp::fold(ArrayRef<Attribute> operands) { 512 /// or(x, 0) -> x 513 if (matchPattern(getRhs(), m_Zero())) 514 return getLhs(); 515 /// or(x, <all ones>) -> <all ones> 516 if (auto rhsAttr = operands[1].dyn_cast_or_null<IntegerAttr>()) 517 if (rhsAttr.getValue().isAllOnes()) 518 return rhsAttr; 519 520 return constFoldBinaryOp<IntegerAttr>( 521 operands, [](APInt a, const APInt &b) { return std::move(a) | b; }); 522 } 523 524 //===----------------------------------------------------------------------===// 525 // XOrIOp 526 //===----------------------------------------------------------------------===// 527 528 OpFoldResult arith::XOrIOp::fold(ArrayRef<Attribute> operands) { 529 /// xor(x, 0) -> x 530 if (matchPattern(getRhs(), m_Zero())) 531 return getLhs(); 532 /// xor(x, x) -> 0 533 if (getLhs() == getRhs()) 534 return Builder(getContext()).getZeroAttr(getType()); 535 /// xor(xor(x, a), a) -> x 536 if (arith::XOrIOp prev = getLhs().getDefiningOp<arith::XOrIOp>()) 537 if (prev.getRhs() == getRhs()) 538 return prev.getLhs(); 539 540 return constFoldBinaryOp<IntegerAttr>( 541 operands, [](APInt a, const APInt &b) { return std::move(a) ^ b; }); 542 } 543 544 void arith::XOrIOp::getCanonicalizationPatterns( 545 RewritePatternSet &patterns, MLIRContext *context) { 546 patterns.add<XOrINotCmpI>(context); 547 } 548 549 //===----------------------------------------------------------------------===// 550 // NegFOp 551 //===----------------------------------------------------------------------===// 552 553 OpFoldResult arith::NegFOp::fold(ArrayRef<Attribute> operands) { 554 return constFoldUnaryOp<FloatAttr>(operands, 555 [](const APFloat &a) { return -a; }); 556 } 557 558 //===----------------------------------------------------------------------===// 559 // AddFOp 560 //===----------------------------------------------------------------------===// 561 562 OpFoldResult arith::AddFOp::fold(ArrayRef<Attribute> operands) { 563 // addf(x, -0) -> x 564 if (matchPattern(getRhs(), m_NegZeroFloat())) 565 return getLhs(); 566 567 return constFoldBinaryOp<FloatAttr>( 568 operands, [](const APFloat &a, const APFloat &b) { return a + b; }); 569 } 570 571 //===----------------------------------------------------------------------===// 572 // SubFOp 573 //===----------------------------------------------------------------------===// 574 575 OpFoldResult arith::SubFOp::fold(ArrayRef<Attribute> operands) { 576 // subf(x, +0) -> x 577 if (matchPattern(getRhs(), m_PosZeroFloat())) 578 return getLhs(); 579 580 return constFoldBinaryOp<FloatAttr>( 581 operands, [](const APFloat &a, const APFloat &b) { return a - b; }); 582 } 583 584 //===----------------------------------------------------------------------===// 585 // MaxFOp 586 //===----------------------------------------------------------------------===// 587 588 OpFoldResult arith::MaxFOp::fold(ArrayRef<Attribute> operands) { 589 assert(operands.size() == 2 && "maxf takes two operands"); 590 591 // maxf(x,x) -> x 592 if (getLhs() == getRhs()) 593 return getRhs(); 594 595 // maxf(x, -inf) -> x 596 if (matchPattern(getRhs(), m_NegInfFloat())) 597 return getLhs(); 598 599 return constFoldBinaryOp<FloatAttr>( 600 operands, 601 [](const APFloat &a, const APFloat &b) { return llvm::maximum(a, b); }); 602 } 603 604 //===----------------------------------------------------------------------===// 605 // MaxSIOp 606 //===----------------------------------------------------------------------===// 607 608 OpFoldResult MaxSIOp::fold(ArrayRef<Attribute> operands) { 609 assert(operands.size() == 2 && "binary operation takes two operands"); 610 611 // maxsi(x,x) -> x 612 if (getLhs() == getRhs()) 613 return getRhs(); 614 615 APInt intValue; 616 // maxsi(x,MAX_INT) -> MAX_INT 617 if (matchPattern(getRhs(), m_ConstantInt(&intValue)) && 618 intValue.isMaxSignedValue()) 619 return getRhs(); 620 621 // maxsi(x, MIN_INT) -> x 622 if (matchPattern(getRhs(), m_ConstantInt(&intValue)) && 623 intValue.isMinSignedValue()) 624 return getLhs(); 625 626 return constFoldBinaryOp<IntegerAttr>(operands, 627 [](const APInt &a, const APInt &b) { 628 return llvm::APIntOps::smax(a, b); 629 }); 630 } 631 632 //===----------------------------------------------------------------------===// 633 // MaxUIOp 634 //===----------------------------------------------------------------------===// 635 636 OpFoldResult MaxUIOp::fold(ArrayRef<Attribute> operands) { 637 assert(operands.size() == 2 && "binary operation takes two operands"); 638 639 // maxui(x,x) -> x 640 if (getLhs() == getRhs()) 641 return getRhs(); 642 643 APInt intValue; 644 // maxui(x,MAX_INT) -> MAX_INT 645 if (matchPattern(getRhs(), m_ConstantInt(&intValue)) && intValue.isMaxValue()) 646 return getRhs(); 647 648 // maxui(x, MIN_INT) -> x 649 if (matchPattern(getRhs(), m_ConstantInt(&intValue)) && intValue.isMinValue()) 650 return getLhs(); 651 652 return constFoldBinaryOp<IntegerAttr>(operands, 653 [](const APInt &a, const APInt &b) { 654 return llvm::APIntOps::umax(a, b); 655 }); 656 } 657 658 //===----------------------------------------------------------------------===// 659 // MinFOp 660 //===----------------------------------------------------------------------===// 661 662 OpFoldResult arith::MinFOp::fold(ArrayRef<Attribute> operands) { 663 assert(operands.size() == 2 && "minf takes two operands"); 664 665 // minf(x,x) -> x 666 if (getLhs() == getRhs()) 667 return getRhs(); 668 669 // minf(x, +inf) -> x 670 if (matchPattern(getRhs(), m_PosInfFloat())) 671 return getLhs(); 672 673 return constFoldBinaryOp<FloatAttr>( 674 operands, 675 [](const APFloat &a, const APFloat &b) { return llvm::minimum(a, b); }); 676 } 677 678 //===----------------------------------------------------------------------===// 679 // MinSIOp 680 //===----------------------------------------------------------------------===// 681 682 OpFoldResult MinSIOp::fold(ArrayRef<Attribute> operands) { 683 assert(operands.size() == 2 && "binary operation takes two operands"); 684 685 // minsi(x,x) -> x 686 if (getLhs() == getRhs()) 687 return getRhs(); 688 689 APInt intValue; 690 // minsi(x,MIN_INT) -> MIN_INT 691 if (matchPattern(getRhs(), m_ConstantInt(&intValue)) && 692 intValue.isMinSignedValue()) 693 return getRhs(); 694 695 // minsi(x, MAX_INT) -> x 696 if (matchPattern(getRhs(), m_ConstantInt(&intValue)) && 697 intValue.isMaxSignedValue()) 698 return getLhs(); 699 700 return constFoldBinaryOp<IntegerAttr>(operands, 701 [](const APInt &a, const APInt &b) { 702 return llvm::APIntOps::smin(a, b); 703 }); 704 } 705 706 //===----------------------------------------------------------------------===// 707 // MinUIOp 708 //===----------------------------------------------------------------------===// 709 710 OpFoldResult MinUIOp::fold(ArrayRef<Attribute> operands) { 711 assert(operands.size() == 2 && "binary operation takes two operands"); 712 713 // minui(x,x) -> x 714 if (getLhs() == getRhs()) 715 return getRhs(); 716 717 APInt intValue; 718 // minui(x,MIN_INT) -> MIN_INT 719 if (matchPattern(getRhs(), m_ConstantInt(&intValue)) && intValue.isMinValue()) 720 return getRhs(); 721 722 // minui(x, MAX_INT) -> x 723 if (matchPattern(getRhs(), m_ConstantInt(&intValue)) && intValue.isMaxValue()) 724 return getLhs(); 725 726 return constFoldBinaryOp<IntegerAttr>(operands, 727 [](const APInt &a, const APInt &b) { 728 return llvm::APIntOps::umin(a, b); 729 }); 730 } 731 732 //===----------------------------------------------------------------------===// 733 // MulFOp 734 //===----------------------------------------------------------------------===// 735 736 OpFoldResult arith::MulFOp::fold(ArrayRef<Attribute> operands) { 737 // mulf(x, 1) -> x 738 if (matchPattern(getRhs(), m_OneFloat())) 739 return getLhs(); 740 741 return constFoldBinaryOp<FloatAttr>( 742 operands, [](const APFloat &a, const APFloat &b) { return a * b; }); 743 } 744 745 //===----------------------------------------------------------------------===// 746 // DivFOp 747 //===----------------------------------------------------------------------===// 748 749 OpFoldResult arith::DivFOp::fold(ArrayRef<Attribute> operands) { 750 // divf(x, 1) -> x 751 if (matchPattern(getRhs(), m_OneFloat())) 752 return getLhs(); 753 754 return constFoldBinaryOp<FloatAttr>( 755 operands, [](const APFloat &a, const APFloat &b) { return a / b; }); 756 } 757 758 //===----------------------------------------------------------------------===// 759 // Utility functions for verifying cast ops 760 //===----------------------------------------------------------------------===// 761 762 template <typename... Types> 763 using type_list = std::tuple<Types...> *; 764 765 /// Returns a non-null type only if the provided type is one of the allowed 766 /// types or one of the allowed shaped types of the allowed types. Returns the 767 /// element type if a valid shaped type is provided. 768 template <typename... ShapedTypes, typename... ElementTypes> 769 static Type getUnderlyingType(Type type, type_list<ShapedTypes...>, 770 type_list<ElementTypes...>) { 771 if (type.isa<ShapedType>() && !type.isa<ShapedTypes...>()) 772 return {}; 773 774 auto underlyingType = getElementTypeOrSelf(type); 775 if (!underlyingType.isa<ElementTypes...>()) 776 return {}; 777 778 return underlyingType; 779 } 780 781 /// Get allowed underlying types for vectors and tensors. 782 template <typename... ElementTypes> 783 static Type getTypeIfLike(Type type) { 784 return getUnderlyingType(type, type_list<VectorType, TensorType>(), 785 type_list<ElementTypes...>()); 786 } 787 788 /// Get allowed underlying types for vectors, tensors, and memrefs. 789 template <typename... ElementTypes> 790 static Type getTypeIfLikeOrMemRef(Type type) { 791 return getUnderlyingType(type, 792 type_list<VectorType, TensorType, MemRefType>(), 793 type_list<ElementTypes...>()); 794 } 795 796 static bool areValidCastInputsAndOutputs(TypeRange inputs, TypeRange outputs) { 797 return inputs.size() == 1 && outputs.size() == 1 && 798 succeeded(verifyCompatibleShapes(inputs.front(), outputs.front())); 799 } 800 801 //===----------------------------------------------------------------------===// 802 // Verifiers for integer and floating point extension/truncation ops 803 //===----------------------------------------------------------------------===// 804 805 // Extend ops can only extend to a wider type. 806 template <typename ValType, typename Op> 807 static LogicalResult verifyExtOp(Op op) { 808 Type srcType = getElementTypeOrSelf(op.getIn().getType()); 809 Type dstType = getElementTypeOrSelf(op.getType()); 810 811 if (srcType.cast<ValType>().getWidth() >= dstType.cast<ValType>().getWidth()) 812 return op.emitError("result type ") 813 << dstType << " must be wider than operand type " << srcType; 814 815 return success(); 816 } 817 818 // Truncate ops can only truncate to a shorter type. 819 template <typename ValType, typename Op> 820 static LogicalResult verifyTruncateOp(Op op) { 821 Type srcType = getElementTypeOrSelf(op.getIn().getType()); 822 Type dstType = getElementTypeOrSelf(op.getType()); 823 824 if (srcType.cast<ValType>().getWidth() <= dstType.cast<ValType>().getWidth()) 825 return op.emitError("result type ") 826 << dstType << " must be shorter than operand type " << srcType; 827 828 return success(); 829 } 830 831 /// Validate a cast that changes the width of a type. 832 template <template <typename> class WidthComparator, typename... ElementTypes> 833 static bool checkWidthChangeCast(TypeRange inputs, TypeRange outputs) { 834 if (!areValidCastInputsAndOutputs(inputs, outputs)) 835 return false; 836 837 auto srcType = getTypeIfLike<ElementTypes...>(inputs.front()); 838 auto dstType = getTypeIfLike<ElementTypes...>(outputs.front()); 839 if (!srcType || !dstType) 840 return false; 841 842 return WidthComparator<unsigned>()(dstType.getIntOrFloatBitWidth(), 843 srcType.getIntOrFloatBitWidth()); 844 } 845 846 //===----------------------------------------------------------------------===// 847 // ExtUIOp 848 //===----------------------------------------------------------------------===// 849 850 OpFoldResult arith::ExtUIOp::fold(ArrayRef<Attribute> operands) { 851 if (auto lhs = getIn().getDefiningOp<ExtUIOp>()) { 852 getInMutable().assign(lhs.getIn()); 853 return getResult(); 854 } 855 Type resType = getType(); 856 unsigned bitWidth; 857 if (auto shapedType = resType.dyn_cast<ShapedType>()) 858 bitWidth = shapedType.getElementTypeBitWidth(); 859 else 860 bitWidth = resType.getIntOrFloatBitWidth(); 861 return constFoldCastOp<IntegerAttr, IntegerAttr>( 862 operands, getType(), [bitWidth](const APInt &a, bool &castStatus) { 863 return a.zext(bitWidth); 864 }); 865 } 866 867 bool arith::ExtUIOp::areCastCompatible(TypeRange inputs, TypeRange outputs) { 868 return checkWidthChangeCast<std::greater, IntegerType>(inputs, outputs); 869 } 870 871 LogicalResult arith::ExtUIOp::verify() { 872 return verifyExtOp<IntegerType>(*this); 873 } 874 875 //===----------------------------------------------------------------------===// 876 // ExtSIOp 877 //===----------------------------------------------------------------------===// 878 879 OpFoldResult arith::ExtSIOp::fold(ArrayRef<Attribute> operands) { 880 if (auto lhs = getIn().getDefiningOp<ExtSIOp>()) { 881 getInMutable().assign(lhs.getIn()); 882 return getResult(); 883 } 884 Type resType = getType(); 885 unsigned bitWidth; 886 if (auto shapedType = resType.dyn_cast<ShapedType>()) 887 bitWidth = shapedType.getElementTypeBitWidth(); 888 else 889 bitWidth = resType.getIntOrFloatBitWidth(); 890 return constFoldCastOp<IntegerAttr, IntegerAttr>( 891 operands, getType(), [bitWidth](const APInt &a, bool &castStatus) { 892 return a.sext(bitWidth); 893 }); 894 } 895 896 bool arith::ExtSIOp::areCastCompatible(TypeRange inputs, TypeRange outputs) { 897 return checkWidthChangeCast<std::greater, IntegerType>(inputs, outputs); 898 } 899 900 void arith::ExtSIOp::getCanonicalizationPatterns( 901 RewritePatternSet &patterns, MLIRContext *context) { 902 patterns.add<ExtSIOfExtUI>(context); 903 } 904 905 LogicalResult arith::ExtSIOp::verify() { 906 return verifyExtOp<IntegerType>(*this); 907 } 908 909 //===----------------------------------------------------------------------===// 910 // ExtFOp 911 //===----------------------------------------------------------------------===// 912 913 bool arith::ExtFOp::areCastCompatible(TypeRange inputs, TypeRange outputs) { 914 return checkWidthChangeCast<std::greater, FloatType>(inputs, outputs); 915 } 916 917 LogicalResult arith::ExtFOp::verify() { return verifyExtOp<FloatType>(*this); } 918 919 //===----------------------------------------------------------------------===// 920 // TruncIOp 921 //===----------------------------------------------------------------------===// 922 923 OpFoldResult arith::TruncIOp::fold(ArrayRef<Attribute> operands) { 924 assert(operands.size() == 1 && "unary operation takes one operand"); 925 926 // trunci(zexti(a)) -> a 927 // trunci(sexti(a)) -> a 928 if (matchPattern(getOperand(), m_Op<arith::ExtUIOp>()) || 929 matchPattern(getOperand(), m_Op<arith::ExtSIOp>())) 930 return getOperand().getDefiningOp()->getOperand(0); 931 932 // trunci(trunci(a)) -> trunci(a)) 933 if (matchPattern(getOperand(), m_Op<arith::TruncIOp>())) { 934 setOperand(getOperand().getDefiningOp()->getOperand(0)); 935 return getResult(); 936 } 937 938 Type resType = getType(); 939 unsigned bitWidth; 940 if (auto shapedType = resType.dyn_cast<ShapedType>()) 941 bitWidth = shapedType.getElementTypeBitWidth(); 942 else 943 bitWidth = resType.getIntOrFloatBitWidth(); 944 945 return constFoldCastOp<IntegerAttr, IntegerAttr>( 946 operands, getType(), [bitWidth](const APInt &a, bool &castStatus) { 947 return a.trunc(bitWidth); 948 }); 949 } 950 951 bool arith::TruncIOp::areCastCompatible(TypeRange inputs, TypeRange outputs) { 952 return checkWidthChangeCast<std::less, IntegerType>(inputs, outputs); 953 } 954 955 LogicalResult arith::TruncIOp::verify() { 956 return verifyTruncateOp<IntegerType>(*this); 957 } 958 959 //===----------------------------------------------------------------------===// 960 // TruncFOp 961 //===----------------------------------------------------------------------===// 962 963 /// Perform safe const propagation for truncf, i.e. only propagate if FP value 964 /// can be represented without precision loss or rounding. 965 OpFoldResult arith::TruncFOp::fold(ArrayRef<Attribute> operands) { 966 assert(operands.size() == 1 && "unary operation takes one operand"); 967 968 auto constOperand = operands.front(); 969 if (!constOperand || !constOperand.isa<FloatAttr>()) 970 return {}; 971 972 // Convert to target type via 'double'. 973 double sourceValue = 974 constOperand.dyn_cast<FloatAttr>().getValue().convertToDouble(); 975 auto targetAttr = FloatAttr::get(getType(), sourceValue); 976 977 // Propagate if constant's value does not change after truncation. 978 if (sourceValue == targetAttr.getValue().convertToDouble()) 979 return targetAttr; 980 981 return {}; 982 } 983 984 bool arith::TruncFOp::areCastCompatible(TypeRange inputs, TypeRange outputs) { 985 return checkWidthChangeCast<std::less, FloatType>(inputs, outputs); 986 } 987 988 LogicalResult arith::TruncFOp::verify() { 989 return verifyTruncateOp<FloatType>(*this); 990 } 991 992 //===----------------------------------------------------------------------===// 993 // AndIOp 994 //===----------------------------------------------------------------------===// 995 996 void arith::AndIOp::getCanonicalizationPatterns( 997 RewritePatternSet &patterns, MLIRContext *context) { 998 patterns.add<AndOfExtUI, AndOfExtSI>(context); 999 } 1000 1001 //===----------------------------------------------------------------------===// 1002 // OrIOp 1003 //===----------------------------------------------------------------------===// 1004 1005 void arith::OrIOp::getCanonicalizationPatterns( 1006 RewritePatternSet &patterns, MLIRContext *context) { 1007 patterns.add<OrOfExtUI, OrOfExtSI>(context); 1008 } 1009 1010 //===----------------------------------------------------------------------===// 1011 // Verifiers for casts between integers and floats. 1012 //===----------------------------------------------------------------------===// 1013 1014 template <typename From, typename To> 1015 static bool checkIntFloatCast(TypeRange inputs, TypeRange outputs) { 1016 if (!areValidCastInputsAndOutputs(inputs, outputs)) 1017 return false; 1018 1019 auto srcType = getTypeIfLike<From>(inputs.front()); 1020 auto dstType = getTypeIfLike<To>(outputs.back()); 1021 1022 return srcType && dstType; 1023 } 1024 1025 //===----------------------------------------------------------------------===// 1026 // UIToFPOp 1027 //===----------------------------------------------------------------------===// 1028 1029 bool arith::UIToFPOp::areCastCompatible(TypeRange inputs, TypeRange outputs) { 1030 return checkIntFloatCast<IntegerType, FloatType>(inputs, outputs); 1031 } 1032 1033 OpFoldResult arith::UIToFPOp::fold(ArrayRef<Attribute> operands) { 1034 Type resType = getType(); 1035 Type resEleType; 1036 if (auto shapedType = resType.dyn_cast<ShapedType>()) 1037 resEleType = shapedType.getElementType(); 1038 else 1039 resEleType = resType; 1040 return constFoldCastOp<IntegerAttr, FloatAttr>( 1041 operands, getType(), [&resEleType](const APInt &a, bool &castStatus) { 1042 FloatType floatTy = resEleType.cast<FloatType>(); 1043 APFloat apf(floatTy.getFloatSemantics(), 1044 APInt::getZero(floatTy.getWidth())); 1045 apf.convertFromAPInt(a, /*IsSigned=*/false, 1046 APFloat::rmNearestTiesToEven); 1047 return apf; 1048 }); 1049 } 1050 1051 //===----------------------------------------------------------------------===// 1052 // SIToFPOp 1053 //===----------------------------------------------------------------------===// 1054 1055 bool arith::SIToFPOp::areCastCompatible(TypeRange inputs, TypeRange outputs) { 1056 return checkIntFloatCast<IntegerType, FloatType>(inputs, outputs); 1057 } 1058 1059 OpFoldResult arith::SIToFPOp::fold(ArrayRef<Attribute> operands) { 1060 Type resType = getType(); 1061 Type resEleType; 1062 if (auto shapedType = resType.dyn_cast<ShapedType>()) 1063 resEleType = shapedType.getElementType(); 1064 else 1065 resEleType = resType; 1066 return constFoldCastOp<IntegerAttr, FloatAttr>( 1067 operands, getType(), [&resEleType](const APInt &a, bool &castStatus) { 1068 FloatType floatTy = resEleType.cast<FloatType>(); 1069 APFloat apf(floatTy.getFloatSemantics(), 1070 APInt::getZero(floatTy.getWidth())); 1071 apf.convertFromAPInt(a, /*IsSigned=*/true, 1072 APFloat::rmNearestTiesToEven); 1073 return apf; 1074 }); 1075 } 1076 //===----------------------------------------------------------------------===// 1077 // FPToUIOp 1078 //===----------------------------------------------------------------------===// 1079 1080 bool arith::FPToUIOp::areCastCompatible(TypeRange inputs, TypeRange outputs) { 1081 return checkIntFloatCast<FloatType, IntegerType>(inputs, outputs); 1082 } 1083 1084 OpFoldResult arith::FPToUIOp::fold(ArrayRef<Attribute> operands) { 1085 Type resType = getType(); 1086 Type resEleType; 1087 if (auto shapedType = resType.dyn_cast<ShapedType>()) 1088 resEleType = shapedType.getElementType(); 1089 else 1090 resEleType = resType; 1091 return constFoldCastOp<FloatAttr, IntegerAttr>( 1092 operands, getType(), [&resEleType](const APFloat &a, bool &castStatus) { 1093 IntegerType intTy = resEleType.cast<IntegerType>(); 1094 bool ignored; 1095 APSInt api(intTy.getWidth(), /*isUnsigned=*/true); 1096 castStatus = APFloat::opInvalidOp != 1097 a.convertToInteger(api, APFloat::rmTowardZero, &ignored); 1098 return api; 1099 }); 1100 } 1101 1102 //===----------------------------------------------------------------------===// 1103 // FPToSIOp 1104 //===----------------------------------------------------------------------===// 1105 1106 bool arith::FPToSIOp::areCastCompatible(TypeRange inputs, TypeRange outputs) { 1107 return checkIntFloatCast<FloatType, IntegerType>(inputs, outputs); 1108 } 1109 1110 OpFoldResult arith::FPToSIOp::fold(ArrayRef<Attribute> operands) { 1111 Type resType = getType(); 1112 Type resEleType; 1113 if (auto shapedType = resType.dyn_cast<ShapedType>()) 1114 resEleType = shapedType.getElementType(); 1115 else 1116 resEleType = resType; 1117 return constFoldCastOp<FloatAttr, IntegerAttr>( 1118 operands, getType(), [&resEleType](const APFloat &a, bool &castStatus) { 1119 IntegerType intTy = resEleType.cast<IntegerType>(); 1120 bool ignored; 1121 APSInt api(intTy.getWidth(), /*isUnsigned=*/false); 1122 castStatus = APFloat::opInvalidOp != 1123 a.convertToInteger(api, APFloat::rmTowardZero, &ignored); 1124 return api; 1125 }); 1126 } 1127 1128 //===----------------------------------------------------------------------===// 1129 // IndexCastOp 1130 //===----------------------------------------------------------------------===// 1131 1132 bool arith::IndexCastOp::areCastCompatible(TypeRange inputs, 1133 TypeRange outputs) { 1134 if (!areValidCastInputsAndOutputs(inputs, outputs)) 1135 return false; 1136 1137 auto srcType = getTypeIfLikeOrMemRef<IntegerType, IndexType>(inputs.front()); 1138 auto dstType = getTypeIfLikeOrMemRef<IntegerType, IndexType>(outputs.front()); 1139 if (!srcType || !dstType) 1140 return false; 1141 1142 return (srcType.isIndex() && dstType.isSignlessInteger()) || 1143 (srcType.isSignlessInteger() && dstType.isIndex()); 1144 } 1145 1146 OpFoldResult arith::IndexCastOp::fold(ArrayRef<Attribute> operands) { 1147 // index_cast(constant) -> constant 1148 // A little hack because we go through int. Otherwise, the size of the 1149 // constant might need to change. 1150 if (auto value = operands[0].dyn_cast_or_null<IntegerAttr>()) 1151 return IntegerAttr::get(getType(), value.getInt()); 1152 1153 return {}; 1154 } 1155 1156 void arith::IndexCastOp::getCanonicalizationPatterns( 1157 RewritePatternSet &patterns, MLIRContext *context) { 1158 patterns.add<IndexCastOfIndexCast, IndexCastOfExtSI>(context); 1159 } 1160 1161 //===----------------------------------------------------------------------===// 1162 // BitcastOp 1163 //===----------------------------------------------------------------------===// 1164 1165 bool arith::BitcastOp::areCastCompatible(TypeRange inputs, TypeRange outputs) { 1166 if (!areValidCastInputsAndOutputs(inputs, outputs)) 1167 return false; 1168 1169 auto srcType = 1170 getTypeIfLikeOrMemRef<IntegerType, IndexType, FloatType>(inputs.front()); 1171 auto dstType = 1172 getTypeIfLikeOrMemRef<IntegerType, IndexType, FloatType>(outputs.front()); 1173 if (!srcType || !dstType) 1174 return false; 1175 1176 return srcType.getIntOrFloatBitWidth() == dstType.getIntOrFloatBitWidth(); 1177 } 1178 1179 OpFoldResult arith::BitcastOp::fold(ArrayRef<Attribute> operands) { 1180 assert(operands.size() == 1 && "bitcast op expects 1 operand"); 1181 1182 auto resType = getType(); 1183 auto operand = operands[0]; 1184 if (!operand) 1185 return {}; 1186 1187 /// Bitcast dense elements. 1188 if (auto denseAttr = operand.dyn_cast_or_null<DenseElementsAttr>()) 1189 return denseAttr.bitcast(resType.cast<ShapedType>().getElementType()); 1190 /// Other shaped types unhandled. 1191 if (resType.isa<ShapedType>()) 1192 return {}; 1193 1194 /// Bitcast integer or float to integer or float. 1195 APInt bits = operand.isa<FloatAttr>() 1196 ? operand.cast<FloatAttr>().getValue().bitcastToAPInt() 1197 : operand.cast<IntegerAttr>().getValue(); 1198 1199 if (auto resFloatType = resType.dyn_cast<FloatType>()) 1200 return FloatAttr::get(resType, 1201 APFloat(resFloatType.getFloatSemantics(), bits)); 1202 return IntegerAttr::get(resType, bits); 1203 } 1204 1205 void arith::BitcastOp::getCanonicalizationPatterns( 1206 RewritePatternSet &patterns, MLIRContext *context) { 1207 patterns.add<BitcastOfBitcast>(context); 1208 } 1209 1210 //===----------------------------------------------------------------------===// 1211 // Helpers for compare ops 1212 //===----------------------------------------------------------------------===// 1213 1214 /// Return the type of the same shape (scalar, vector or tensor) containing i1. 1215 static Type getI1SameShape(Type type) { 1216 auto i1Type = IntegerType::get(type.getContext(), 1); 1217 if (auto tensorType = type.dyn_cast<RankedTensorType>()) 1218 return RankedTensorType::get(tensorType.getShape(), i1Type); 1219 if (type.isa<UnrankedTensorType>()) 1220 return UnrankedTensorType::get(i1Type); 1221 if (auto vectorType = type.dyn_cast<VectorType>()) 1222 return VectorType::get(vectorType.getShape(), i1Type, 1223 vectorType.getNumScalableDims()); 1224 return i1Type; 1225 } 1226 1227 //===----------------------------------------------------------------------===// 1228 // CmpIOp 1229 //===----------------------------------------------------------------------===// 1230 1231 /// Compute `lhs` `pred` `rhs`, where `pred` is one of the known integer 1232 /// comparison predicates. 1233 bool mlir::arith::applyCmpPredicate(arith::CmpIPredicate predicate, 1234 const APInt &lhs, const APInt &rhs) { 1235 switch (predicate) { 1236 case arith::CmpIPredicate::eq: 1237 return lhs.eq(rhs); 1238 case arith::CmpIPredicate::ne: 1239 return lhs.ne(rhs); 1240 case arith::CmpIPredicate::slt: 1241 return lhs.slt(rhs); 1242 case arith::CmpIPredicate::sle: 1243 return lhs.sle(rhs); 1244 case arith::CmpIPredicate::sgt: 1245 return lhs.sgt(rhs); 1246 case arith::CmpIPredicate::sge: 1247 return lhs.sge(rhs); 1248 case arith::CmpIPredicate::ult: 1249 return lhs.ult(rhs); 1250 case arith::CmpIPredicate::ule: 1251 return lhs.ule(rhs); 1252 case arith::CmpIPredicate::ugt: 1253 return lhs.ugt(rhs); 1254 case arith::CmpIPredicate::uge: 1255 return lhs.uge(rhs); 1256 } 1257 llvm_unreachable("unknown cmpi predicate kind"); 1258 } 1259 1260 /// Returns true if the predicate is true for two equal operands. 1261 static bool applyCmpPredicateToEqualOperands(arith::CmpIPredicate predicate) { 1262 switch (predicate) { 1263 case arith::CmpIPredicate::eq: 1264 case arith::CmpIPredicate::sle: 1265 case arith::CmpIPredicate::sge: 1266 case arith::CmpIPredicate::ule: 1267 case arith::CmpIPredicate::uge: 1268 return true; 1269 case arith::CmpIPredicate::ne: 1270 case arith::CmpIPredicate::slt: 1271 case arith::CmpIPredicate::sgt: 1272 case arith::CmpIPredicate::ult: 1273 case arith::CmpIPredicate::ugt: 1274 return false; 1275 } 1276 llvm_unreachable("unknown cmpi predicate kind"); 1277 } 1278 1279 static Attribute getBoolAttribute(Type type, MLIRContext *ctx, bool value) { 1280 auto boolAttr = BoolAttr::get(ctx, value); 1281 ShapedType shapedType = type.dyn_cast_or_null<ShapedType>(); 1282 if (!shapedType) 1283 return boolAttr; 1284 return DenseElementsAttr::get(shapedType, boolAttr); 1285 } 1286 1287 OpFoldResult arith::CmpIOp::fold(ArrayRef<Attribute> operands) { 1288 assert(operands.size() == 2 && "cmpi takes two operands"); 1289 1290 // cmpi(pred, x, x) 1291 if (getLhs() == getRhs()) { 1292 auto val = applyCmpPredicateToEqualOperands(getPredicate()); 1293 return getBoolAttribute(getType(), getContext(), val); 1294 } 1295 1296 if (matchPattern(getRhs(), m_Zero())) { 1297 if (auto extOp = getLhs().getDefiningOp<ExtSIOp>()) { 1298 if (extOp.getOperand().getType().cast<IntegerType>().getWidth() == 1) { 1299 // extsi(%x : i1 -> iN) != 0 -> %x 1300 if (getPredicate() == arith::CmpIPredicate::ne) { 1301 return extOp.getOperand(); 1302 } 1303 } 1304 } 1305 if (auto extOp = getLhs().getDefiningOp<ExtUIOp>()) { 1306 if (extOp.getOperand().getType().cast<IntegerType>().getWidth() == 1) { 1307 // extui(%x : i1 -> iN) != 0 -> %x 1308 if (getPredicate() == arith::CmpIPredicate::ne) { 1309 return extOp.getOperand(); 1310 } 1311 } 1312 } 1313 } 1314 1315 auto lhs = operands.front().dyn_cast_or_null<IntegerAttr>(); 1316 auto rhs = operands.back().dyn_cast_or_null<IntegerAttr>(); 1317 if (!lhs || !rhs) 1318 return {}; 1319 1320 auto val = applyCmpPredicate(getPredicate(), lhs.getValue(), rhs.getValue()); 1321 return BoolAttr::get(getContext(), val); 1322 } 1323 1324 void arith::CmpIOp::getCanonicalizationPatterns(RewritePatternSet &patterns, 1325 MLIRContext *context) { 1326 patterns.insert<CmpIExtSI, CmpIExtUI>(context); 1327 } 1328 1329 //===----------------------------------------------------------------------===// 1330 // CmpFOp 1331 //===----------------------------------------------------------------------===// 1332 1333 /// Compute `lhs` `pred` `rhs`, where `pred` is one of the known floating point 1334 /// comparison predicates. 1335 bool mlir::arith::applyCmpPredicate(arith::CmpFPredicate predicate, 1336 const APFloat &lhs, const APFloat &rhs) { 1337 auto cmpResult = lhs.compare(rhs); 1338 switch (predicate) { 1339 case arith::CmpFPredicate::AlwaysFalse: 1340 return false; 1341 case arith::CmpFPredicate::OEQ: 1342 return cmpResult == APFloat::cmpEqual; 1343 case arith::CmpFPredicate::OGT: 1344 return cmpResult == APFloat::cmpGreaterThan; 1345 case arith::CmpFPredicate::OGE: 1346 return cmpResult == APFloat::cmpGreaterThan || 1347 cmpResult == APFloat::cmpEqual; 1348 case arith::CmpFPredicate::OLT: 1349 return cmpResult == APFloat::cmpLessThan; 1350 case arith::CmpFPredicate::OLE: 1351 return cmpResult == APFloat::cmpLessThan || cmpResult == APFloat::cmpEqual; 1352 case arith::CmpFPredicate::ONE: 1353 return cmpResult != APFloat::cmpUnordered && cmpResult != APFloat::cmpEqual; 1354 case arith::CmpFPredicate::ORD: 1355 return cmpResult != APFloat::cmpUnordered; 1356 case arith::CmpFPredicate::UEQ: 1357 return cmpResult == APFloat::cmpUnordered || cmpResult == APFloat::cmpEqual; 1358 case arith::CmpFPredicate::UGT: 1359 return cmpResult == APFloat::cmpUnordered || 1360 cmpResult == APFloat::cmpGreaterThan; 1361 case arith::CmpFPredicate::UGE: 1362 return cmpResult == APFloat::cmpUnordered || 1363 cmpResult == APFloat::cmpGreaterThan || 1364 cmpResult == APFloat::cmpEqual; 1365 case arith::CmpFPredicate::ULT: 1366 return cmpResult == APFloat::cmpUnordered || 1367 cmpResult == APFloat::cmpLessThan; 1368 case arith::CmpFPredicate::ULE: 1369 return cmpResult == APFloat::cmpUnordered || 1370 cmpResult == APFloat::cmpLessThan || cmpResult == APFloat::cmpEqual; 1371 case arith::CmpFPredicate::UNE: 1372 return cmpResult != APFloat::cmpEqual; 1373 case arith::CmpFPredicate::UNO: 1374 return cmpResult == APFloat::cmpUnordered; 1375 case arith::CmpFPredicate::AlwaysTrue: 1376 return true; 1377 } 1378 llvm_unreachable("unknown cmpf predicate kind"); 1379 } 1380 1381 OpFoldResult arith::CmpFOp::fold(ArrayRef<Attribute> operands) { 1382 assert(operands.size() == 2 && "cmpf takes two operands"); 1383 1384 auto lhs = operands.front().dyn_cast_or_null<FloatAttr>(); 1385 auto rhs = operands.back().dyn_cast_or_null<FloatAttr>(); 1386 1387 // If one operand is NaN, making them both NaN does not change the result. 1388 if (lhs && lhs.getValue().isNaN()) 1389 rhs = lhs; 1390 if (rhs && rhs.getValue().isNaN()) 1391 lhs = rhs; 1392 1393 if (!lhs || !rhs) 1394 return {}; 1395 1396 auto val = applyCmpPredicate(getPredicate(), lhs.getValue(), rhs.getValue()); 1397 return BoolAttr::get(getContext(), val); 1398 } 1399 1400 class CmpFIntToFPConst final : public OpRewritePattern<CmpFOp> { 1401 public: 1402 using OpRewritePattern<CmpFOp>::OpRewritePattern; 1403 1404 static CmpIPredicate convertToIntegerPredicate(CmpFPredicate pred, 1405 bool isUnsigned) { 1406 using namespace arith; 1407 switch (pred) { 1408 case CmpFPredicate::UEQ: 1409 case CmpFPredicate::OEQ: 1410 return CmpIPredicate::eq; 1411 case CmpFPredicate::UGT: 1412 case CmpFPredicate::OGT: 1413 return isUnsigned ? CmpIPredicate::ugt : CmpIPredicate::sgt; 1414 case CmpFPredicate::UGE: 1415 case CmpFPredicate::OGE: 1416 return isUnsigned ? CmpIPredicate::uge : CmpIPredicate::sge; 1417 case CmpFPredicate::ULT: 1418 case CmpFPredicate::OLT: 1419 return isUnsigned ? CmpIPredicate::ult : CmpIPredicate::slt; 1420 case CmpFPredicate::ULE: 1421 case CmpFPredicate::OLE: 1422 return isUnsigned ? CmpIPredicate::ule : CmpIPredicate::sle; 1423 case CmpFPredicate::UNE: 1424 case CmpFPredicate::ONE: 1425 return CmpIPredicate::ne; 1426 default: 1427 llvm_unreachable("Unexpected predicate!"); 1428 } 1429 } 1430 1431 LogicalResult matchAndRewrite(CmpFOp op, 1432 PatternRewriter &rewriter) const override { 1433 FloatAttr flt; 1434 if (!matchPattern(op.getRhs(), m_Constant(&flt))) 1435 return failure(); 1436 1437 const APFloat &rhs = flt.getValue(); 1438 1439 // Don't attempt to fold a nan. 1440 if (rhs.isNaN()) 1441 return failure(); 1442 1443 // Get the width of the mantissa. We don't want to hack on conversions that 1444 // might lose information from the integer, e.g. "i64 -> float" 1445 FloatType floatTy = op.getRhs().getType().cast<FloatType>(); 1446 int mantissaWidth = floatTy.getFPMantissaWidth(); 1447 if (mantissaWidth <= 0) 1448 return failure(); 1449 1450 bool isUnsigned; 1451 Value intVal; 1452 1453 if (auto si = op.getLhs().getDefiningOp<SIToFPOp>()) { 1454 isUnsigned = false; 1455 intVal = si.getIn(); 1456 } else if (auto ui = op.getLhs().getDefiningOp<UIToFPOp>()) { 1457 isUnsigned = true; 1458 intVal = ui.getIn(); 1459 } else { 1460 return failure(); 1461 } 1462 1463 // Check to see that the input is converted from an integer type that is 1464 // small enough that preserves all bits. 1465 auto intTy = intVal.getType().cast<IntegerType>(); 1466 auto intWidth = intTy.getWidth(); 1467 1468 // Number of bits representing values, as opposed to the sign 1469 auto valueBits = isUnsigned ? intWidth : (intWidth - 1); 1470 1471 // Following test does NOT adjust intWidth downwards for signed inputs, 1472 // because the most negative value still requires all the mantissa bits 1473 // to distinguish it from one less than that value. 1474 if ((int)intWidth > mantissaWidth) { 1475 // Conversion would lose accuracy. Check if loss can impact comparison. 1476 int exponent = ilogb(rhs); 1477 if (exponent == APFloat::IEK_Inf) { 1478 int maxExponent = ilogb(APFloat::getLargest(rhs.getSemantics())); 1479 if (maxExponent < (int)valueBits) { 1480 // Conversion could create infinity. 1481 return failure(); 1482 } 1483 } else { 1484 // Note that if rhs is zero or NaN, then Exp is negative 1485 // and first condition is trivially false. 1486 if (mantissaWidth <= exponent && exponent <= (int)valueBits) { 1487 // Conversion could affect comparison. 1488 return failure(); 1489 } 1490 } 1491 } 1492 1493 // Convert to equivalent cmpi predicate 1494 CmpIPredicate pred; 1495 switch (op.getPredicate()) { 1496 case CmpFPredicate::ORD: 1497 // Int to fp conversion doesn't create a nan (ord checks neither is a nan) 1498 rewriter.replaceOpWithNewOp<ConstantIntOp>(op, /*value=*/true, 1499 /*width=*/1); 1500 return success(); 1501 case CmpFPredicate::UNO: 1502 // Int to fp conversion doesn't create a nan (uno checks either is a nan) 1503 rewriter.replaceOpWithNewOp<ConstantIntOp>(op, /*value=*/false, 1504 /*width=*/1); 1505 return success(); 1506 default: 1507 pred = convertToIntegerPredicate(op.getPredicate(), isUnsigned); 1508 break; 1509 } 1510 1511 if (!isUnsigned) { 1512 // If the rhs value is > SignedMax, fold the comparison. This handles 1513 // +INF and large values. 1514 APFloat signedMax(rhs.getSemantics()); 1515 signedMax.convertFromAPInt(APInt::getSignedMaxValue(intWidth), true, 1516 APFloat::rmNearestTiesToEven); 1517 if (signedMax < rhs) { // smax < 13123.0 1518 if (pred == CmpIPredicate::ne || pred == CmpIPredicate::slt || 1519 pred == CmpIPredicate::sle) 1520 rewriter.replaceOpWithNewOp<ConstantIntOp>(op, /*value=*/true, 1521 /*width=*/1); 1522 else 1523 rewriter.replaceOpWithNewOp<ConstantIntOp>(op, /*value=*/false, 1524 /*width=*/1); 1525 return success(); 1526 } 1527 } else { 1528 // If the rhs value is > UnsignedMax, fold the comparison. This handles 1529 // +INF and large values. 1530 APFloat unsignedMax(rhs.getSemantics()); 1531 unsignedMax.convertFromAPInt(APInt::getMaxValue(intWidth), false, 1532 APFloat::rmNearestTiesToEven); 1533 if (unsignedMax < rhs) { // umax < 13123.0 1534 if (pred == CmpIPredicate::ne || pred == CmpIPredicate::ult || 1535 pred == CmpIPredicate::ule) 1536 rewriter.replaceOpWithNewOp<ConstantIntOp>(op, /*value=*/true, 1537 /*width=*/1); 1538 else 1539 rewriter.replaceOpWithNewOp<ConstantIntOp>(op, /*value=*/false, 1540 /*width=*/1); 1541 return success(); 1542 } 1543 } 1544 1545 if (!isUnsigned) { 1546 // See if the rhs value is < SignedMin. 1547 APFloat signedMin(rhs.getSemantics()); 1548 signedMin.convertFromAPInt(APInt::getSignedMinValue(intWidth), true, 1549 APFloat::rmNearestTiesToEven); 1550 if (signedMin > rhs) { // smin > 12312.0 1551 if (pred == CmpIPredicate::ne || pred == CmpIPredicate::sgt || 1552 pred == CmpIPredicate::sge) 1553 rewriter.replaceOpWithNewOp<ConstantIntOp>(op, /*value=*/true, 1554 /*width=*/1); 1555 else 1556 rewriter.replaceOpWithNewOp<ConstantIntOp>(op, /*value=*/false, 1557 /*width=*/1); 1558 return success(); 1559 } 1560 } else { 1561 // See if the rhs value is < UnsignedMin. 1562 APFloat unsignedMin(rhs.getSemantics()); 1563 unsignedMin.convertFromAPInt(APInt::getMinValue(intWidth), false, 1564 APFloat::rmNearestTiesToEven); 1565 if (unsignedMin > rhs) { // umin > 12312.0 1566 if (pred == CmpIPredicate::ne || pred == CmpIPredicate::ugt || 1567 pred == CmpIPredicate::uge) 1568 rewriter.replaceOpWithNewOp<ConstantIntOp>(op, /*value=*/true, 1569 /*width=*/1); 1570 else 1571 rewriter.replaceOpWithNewOp<ConstantIntOp>(op, /*value=*/false, 1572 /*width=*/1); 1573 return success(); 1574 } 1575 } 1576 1577 // Okay, now we know that the FP constant fits in the range [SMIN, SMAX] or 1578 // [0, UMAX], but it may still be fractional. See if it is fractional by 1579 // casting the FP value to the integer value and back, checking for 1580 // equality. Don't do this for zero, because -0.0 is not fractional. 1581 bool ignored; 1582 APSInt rhsInt(intWidth, isUnsigned); 1583 if (APFloat::opInvalidOp == 1584 rhs.convertToInteger(rhsInt, APFloat::rmTowardZero, &ignored)) { 1585 // Undefined behavior invoked - the destination type can't represent 1586 // the input constant. 1587 return failure(); 1588 } 1589 1590 if (!rhs.isZero()) { 1591 APFloat apf(floatTy.getFloatSemantics(), 1592 APInt::getZero(floatTy.getWidth())); 1593 apf.convertFromAPInt(rhsInt, !isUnsigned, APFloat::rmNearestTiesToEven); 1594 1595 bool equal = apf == rhs; 1596 if (!equal) { 1597 // If we had a comparison against a fractional value, we have to adjust 1598 // the compare predicate and sometimes the value. rhsInt is rounded 1599 // towards zero at this point. 1600 switch (pred) { 1601 case CmpIPredicate::ne: // (float)int != 4.4 --> true 1602 rewriter.replaceOpWithNewOp<ConstantIntOp>(op, /*value=*/true, 1603 /*width=*/1); 1604 return success(); 1605 case CmpIPredicate::eq: // (float)int == 4.4 --> false 1606 rewriter.replaceOpWithNewOp<ConstantIntOp>(op, /*value=*/false, 1607 /*width=*/1); 1608 return success(); 1609 case CmpIPredicate::ule: 1610 // (float)int <= 4.4 --> int <= 4 1611 // (float)int <= -4.4 --> false 1612 if (rhs.isNegative()) { 1613 rewriter.replaceOpWithNewOp<ConstantIntOp>(op, /*value=*/false, 1614 /*width=*/1); 1615 return success(); 1616 } 1617 break; 1618 case CmpIPredicate::sle: 1619 // (float)int <= 4.4 --> int <= 4 1620 // (float)int <= -4.4 --> int < -4 1621 if (rhs.isNegative()) 1622 pred = CmpIPredicate::slt; 1623 break; 1624 case CmpIPredicate::ult: 1625 // (float)int < -4.4 --> false 1626 // (float)int < 4.4 --> int <= 4 1627 if (rhs.isNegative()) { 1628 rewriter.replaceOpWithNewOp<ConstantIntOp>(op, /*value=*/false, 1629 /*width=*/1); 1630 return success(); 1631 } 1632 pred = CmpIPredicate::ule; 1633 break; 1634 case CmpIPredicate::slt: 1635 // (float)int < -4.4 --> int < -4 1636 // (float)int < 4.4 --> int <= 4 1637 if (!rhs.isNegative()) 1638 pred = CmpIPredicate::sle; 1639 break; 1640 case CmpIPredicate::ugt: 1641 // (float)int > 4.4 --> int > 4 1642 // (float)int > -4.4 --> true 1643 if (rhs.isNegative()) { 1644 rewriter.replaceOpWithNewOp<ConstantIntOp>(op, /*value=*/true, 1645 /*width=*/1); 1646 return success(); 1647 } 1648 break; 1649 case CmpIPredicate::sgt: 1650 // (float)int > 4.4 --> int > 4 1651 // (float)int > -4.4 --> int >= -4 1652 if (rhs.isNegative()) 1653 pred = CmpIPredicate::sge; 1654 break; 1655 case CmpIPredicate::uge: 1656 // (float)int >= -4.4 --> true 1657 // (float)int >= 4.4 --> int > 4 1658 if (rhs.isNegative()) { 1659 rewriter.replaceOpWithNewOp<ConstantIntOp>(op, /*value=*/true, 1660 /*width=*/1); 1661 return success(); 1662 } 1663 pred = CmpIPredicate::ugt; 1664 break; 1665 case CmpIPredicate::sge: 1666 // (float)int >= -4.4 --> int >= -4 1667 // (float)int >= 4.4 --> int > 4 1668 if (!rhs.isNegative()) 1669 pred = CmpIPredicate::sgt; 1670 break; 1671 } 1672 } 1673 } 1674 1675 // Lower this FP comparison into an appropriate integer version of the 1676 // comparison. 1677 rewriter.replaceOpWithNewOp<CmpIOp>( 1678 op, pred, intVal, 1679 rewriter.create<ConstantOp>( 1680 op.getLoc(), intVal.getType(), 1681 rewriter.getIntegerAttr(intVal.getType(), rhsInt))); 1682 return success(); 1683 } 1684 }; 1685 1686 void arith::CmpFOp::getCanonicalizationPatterns(RewritePatternSet &patterns, 1687 MLIRContext *context) { 1688 patterns.insert<CmpFIntToFPConst>(context); 1689 } 1690 1691 //===----------------------------------------------------------------------===// 1692 // SelectOp 1693 //===----------------------------------------------------------------------===// 1694 1695 // Transforms a select of a boolean to arithmetic operations 1696 // 1697 // arith.select %arg, %x, %y : i1 1698 // 1699 // becomes 1700 // 1701 // and(%arg, %x) or and(!%arg, %y) 1702 struct SelectI1Simplify : public OpRewritePattern<arith::SelectOp> { 1703 using OpRewritePattern<arith::SelectOp>::OpRewritePattern; 1704 1705 LogicalResult matchAndRewrite(arith::SelectOp op, 1706 PatternRewriter &rewriter) const override { 1707 if (!op.getType().isInteger(1)) 1708 return failure(); 1709 1710 Value falseConstant = 1711 rewriter.create<arith::ConstantIntOp>(op.getLoc(), true, 1); 1712 Value notCondition = rewriter.create<arith::XOrIOp>( 1713 op.getLoc(), op.getCondition(), falseConstant); 1714 1715 Value trueVal = rewriter.create<arith::AndIOp>( 1716 op.getLoc(), op.getCondition(), op.getTrueValue()); 1717 Value falseVal = rewriter.create<arith::AndIOp>(op.getLoc(), notCondition, 1718 op.getFalseValue()); 1719 rewriter.replaceOpWithNewOp<arith::OrIOp>(op, trueVal, falseVal); 1720 return success(); 1721 } 1722 }; 1723 1724 // select %arg, %c1, %c0 => extui %arg 1725 struct SelectToExtUI : public OpRewritePattern<arith::SelectOp> { 1726 using OpRewritePattern<arith::SelectOp>::OpRewritePattern; 1727 1728 LogicalResult matchAndRewrite(arith::SelectOp op, 1729 PatternRewriter &rewriter) const override { 1730 // Cannot extui i1 to i1, or i1 to f32 1731 if (!op.getType().isa<IntegerType>() || op.getType().isInteger(1)) 1732 return failure(); 1733 1734 // select %x, c1, %c0 => extui %arg 1735 if (matchPattern(op.getTrueValue(), m_One())) 1736 if (matchPattern(op.getFalseValue(), m_Zero())) { 1737 rewriter.replaceOpWithNewOp<arith::ExtUIOp>(op, op.getType(), 1738 op.getCondition()); 1739 return success(); 1740 } 1741 1742 // select %x, c0, %c1 => extui (xor %arg, true) 1743 if (matchPattern(op.getTrueValue(), m_Zero())) 1744 if (matchPattern(op.getFalseValue(), m_One())) { 1745 rewriter.replaceOpWithNewOp<arith::ExtUIOp>( 1746 op, op.getType(), 1747 rewriter.create<arith::XOrIOp>( 1748 op.getLoc(), op.getCondition(), 1749 rewriter.create<arith::ConstantIntOp>( 1750 op.getLoc(), 1, op.getCondition().getType()))); 1751 return success(); 1752 } 1753 1754 return failure(); 1755 } 1756 }; 1757 1758 void arith::SelectOp::getCanonicalizationPatterns(RewritePatternSet &results, 1759 MLIRContext *context) { 1760 results.add<SelectI1Simplify, SelectToExtUI>(context); 1761 } 1762 1763 OpFoldResult arith::SelectOp::fold(ArrayRef<Attribute> operands) { 1764 Value trueVal = getTrueValue(); 1765 Value falseVal = getFalseValue(); 1766 if (trueVal == falseVal) 1767 return trueVal; 1768 1769 Value condition = getCondition(); 1770 1771 // select true, %0, %1 => %0 1772 if (matchPattern(condition, m_One())) 1773 return trueVal; 1774 1775 // select false, %0, %1 => %1 1776 if (matchPattern(condition, m_Zero())) 1777 return falseVal; 1778 1779 // select %x, true, false => %x 1780 if (getType().isInteger(1)) 1781 if (matchPattern(getTrueValue(), m_One())) 1782 if (matchPattern(getFalseValue(), m_Zero())) 1783 return condition; 1784 1785 if (auto cmp = dyn_cast_or_null<arith::CmpIOp>(condition.getDefiningOp())) { 1786 auto pred = cmp.getPredicate(); 1787 if (pred == arith::CmpIPredicate::eq || pred == arith::CmpIPredicate::ne) { 1788 auto cmpLhs = cmp.getLhs(); 1789 auto cmpRhs = cmp.getRhs(); 1790 1791 // %0 = arith.cmpi eq, %arg0, %arg1 1792 // %1 = arith.select %0, %arg0, %arg1 => %arg1 1793 1794 // %0 = arith.cmpi ne, %arg0, %arg1 1795 // %1 = arith.select %0, %arg0, %arg1 => %arg0 1796 1797 if ((cmpLhs == trueVal && cmpRhs == falseVal) || 1798 (cmpRhs == trueVal && cmpLhs == falseVal)) 1799 return pred == arith::CmpIPredicate::ne ? trueVal : falseVal; 1800 } 1801 } 1802 return nullptr; 1803 } 1804 1805 ParseResult SelectOp::parse(OpAsmParser &parser, OperationState &result) { 1806 Type conditionType, resultType; 1807 SmallVector<OpAsmParser::UnresolvedOperand, 3> operands; 1808 if (parser.parseOperandList(operands, /*requiredOperandCount=*/3) || 1809 parser.parseOptionalAttrDict(result.attributes) || 1810 parser.parseColonType(resultType)) 1811 return failure(); 1812 1813 // Check for the explicit condition type if this is a masked tensor or vector. 1814 if (succeeded(parser.parseOptionalComma())) { 1815 conditionType = resultType; 1816 if (parser.parseType(resultType)) 1817 return failure(); 1818 } else { 1819 conditionType = parser.getBuilder().getI1Type(); 1820 } 1821 1822 result.addTypes(resultType); 1823 return parser.resolveOperands(operands, 1824 {conditionType, resultType, resultType}, 1825 parser.getNameLoc(), result.operands); 1826 } 1827 1828 void arith::SelectOp::print(OpAsmPrinter &p) { 1829 p << " " << getOperands(); 1830 p.printOptionalAttrDict((*this)->getAttrs()); 1831 p << " : "; 1832 if (ShapedType condType = getCondition().getType().dyn_cast<ShapedType>()) 1833 p << condType << ", "; 1834 p << getType(); 1835 } 1836 1837 LogicalResult arith::SelectOp::verify() { 1838 Type conditionType = getCondition().getType(); 1839 if (conditionType.isSignlessInteger(1)) 1840 return success(); 1841 1842 // If the result type is a vector or tensor, the type can be a mask with the 1843 // same elements. 1844 Type resultType = getType(); 1845 if (!resultType.isa<TensorType, VectorType>()) 1846 return emitOpError() << "expected condition to be a signless i1, but got " 1847 << conditionType; 1848 Type shapedConditionType = getI1SameShape(resultType); 1849 if (conditionType != shapedConditionType) { 1850 return emitOpError() << "expected condition type to have the same shape " 1851 "as the result type, expected " 1852 << shapedConditionType << ", but got " 1853 << conditionType; 1854 } 1855 return success(); 1856 } 1857 //===----------------------------------------------------------------------===// 1858 // ShLIOp 1859 //===----------------------------------------------------------------------===// 1860 1861 OpFoldResult arith::ShLIOp::fold(ArrayRef<Attribute> operands) { 1862 // Don't fold if shifting more than the bit width. 1863 bool bounded = false; 1864 auto result = constFoldBinaryOp<IntegerAttr>( 1865 operands, [&](const APInt &a, const APInt &b) { 1866 bounded = b.ule(b.getBitWidth()); 1867 return a.shl(b); 1868 }); 1869 return bounded ? result : Attribute(); 1870 } 1871 1872 //===----------------------------------------------------------------------===// 1873 // ShRUIOp 1874 //===----------------------------------------------------------------------===// 1875 1876 OpFoldResult arith::ShRUIOp::fold(ArrayRef<Attribute> operands) { 1877 // Don't fold if shifting more than the bit width. 1878 bool bounded = false; 1879 auto result = constFoldBinaryOp<IntegerAttr>( 1880 operands, [&](const APInt &a, const APInt &b) { 1881 bounded = b.ule(b.getBitWidth()); 1882 return a.lshr(b); 1883 }); 1884 return bounded ? result : Attribute(); 1885 } 1886 1887 //===----------------------------------------------------------------------===// 1888 // ShRSIOp 1889 //===----------------------------------------------------------------------===// 1890 1891 OpFoldResult arith::ShRSIOp::fold(ArrayRef<Attribute> operands) { 1892 // Don't fold if shifting more than the bit width. 1893 bool bounded = false; 1894 auto result = constFoldBinaryOp<IntegerAttr>( 1895 operands, [&](const APInt &a, const APInt &b) { 1896 bounded = b.ule(b.getBitWidth()); 1897 return a.ashr(b); 1898 }); 1899 return bounded ? result : Attribute(); 1900 } 1901 1902 //===----------------------------------------------------------------------===// 1903 // Atomic Enum 1904 //===----------------------------------------------------------------------===// 1905 1906 /// Returns the identity value attribute associated with an AtomicRMWKind op. 1907 Attribute mlir::arith::getIdentityValueAttr(AtomicRMWKind kind, Type resultType, 1908 OpBuilder &builder, Location loc) { 1909 switch (kind) { 1910 case AtomicRMWKind::maxf: 1911 return builder.getFloatAttr( 1912 resultType, 1913 APFloat::getInf(resultType.cast<FloatType>().getFloatSemantics(), 1914 /*Negative=*/true)); 1915 case AtomicRMWKind::addf: 1916 case AtomicRMWKind::addi: 1917 case AtomicRMWKind::maxu: 1918 case AtomicRMWKind::ori: 1919 return builder.getZeroAttr(resultType); 1920 case AtomicRMWKind::andi: 1921 return builder.getIntegerAttr( 1922 resultType, 1923 APInt::getAllOnes(resultType.cast<IntegerType>().getWidth())); 1924 case AtomicRMWKind::maxs: 1925 return builder.getIntegerAttr( 1926 resultType, 1927 APInt::getSignedMinValue(resultType.cast<IntegerType>().getWidth())); 1928 case AtomicRMWKind::minf: 1929 return builder.getFloatAttr( 1930 resultType, 1931 APFloat::getInf(resultType.cast<FloatType>().getFloatSemantics(), 1932 /*Negative=*/false)); 1933 case AtomicRMWKind::mins: 1934 return builder.getIntegerAttr( 1935 resultType, 1936 APInt::getSignedMaxValue(resultType.cast<IntegerType>().getWidth())); 1937 case AtomicRMWKind::minu: 1938 return builder.getIntegerAttr( 1939 resultType, 1940 APInt::getMaxValue(resultType.cast<IntegerType>().getWidth())); 1941 case AtomicRMWKind::muli: 1942 return builder.getIntegerAttr(resultType, 1); 1943 case AtomicRMWKind::mulf: 1944 return builder.getFloatAttr(resultType, 1); 1945 // TODO: Add remaining reduction operations. 1946 default: 1947 (void)emitOptionalError(loc, "Reduction operation type not supported"); 1948 break; 1949 } 1950 return nullptr; 1951 } 1952 1953 /// Returns the identity value associated with an AtomicRMWKind op. 1954 Value mlir::arith::getIdentityValue(AtomicRMWKind op, Type resultType, 1955 OpBuilder &builder, Location loc) { 1956 Attribute attr = getIdentityValueAttr(op, resultType, builder, loc); 1957 return builder.create<arith::ConstantOp>(loc, attr); 1958 } 1959 1960 /// Return the value obtained by applying the reduction operation kind 1961 /// associated with a binary AtomicRMWKind op to `lhs` and `rhs`. 1962 Value mlir::arith::getReductionOp(AtomicRMWKind op, OpBuilder &builder, 1963 Location loc, Value lhs, Value rhs) { 1964 switch (op) { 1965 case AtomicRMWKind::addf: 1966 return builder.create<arith::AddFOp>(loc, lhs, rhs); 1967 case AtomicRMWKind::addi: 1968 return builder.create<arith::AddIOp>(loc, lhs, rhs); 1969 case AtomicRMWKind::mulf: 1970 return builder.create<arith::MulFOp>(loc, lhs, rhs); 1971 case AtomicRMWKind::muli: 1972 return builder.create<arith::MulIOp>(loc, lhs, rhs); 1973 case AtomicRMWKind::maxf: 1974 return builder.create<arith::MaxFOp>(loc, lhs, rhs); 1975 case AtomicRMWKind::minf: 1976 return builder.create<arith::MinFOp>(loc, lhs, rhs); 1977 case AtomicRMWKind::maxs: 1978 return builder.create<arith::MaxSIOp>(loc, lhs, rhs); 1979 case AtomicRMWKind::mins: 1980 return builder.create<arith::MinSIOp>(loc, lhs, rhs); 1981 case AtomicRMWKind::maxu: 1982 return builder.create<arith::MaxUIOp>(loc, lhs, rhs); 1983 case AtomicRMWKind::minu: 1984 return builder.create<arith::MinUIOp>(loc, lhs, rhs); 1985 case AtomicRMWKind::ori: 1986 return builder.create<arith::OrIOp>(loc, lhs, rhs); 1987 case AtomicRMWKind::andi: 1988 return builder.create<arith::AndIOp>(loc, lhs, rhs); 1989 // TODO: Add remaining reduction operations. 1990 default: 1991 (void)emitOptionalError(loc, "Reduction operation type not supported"); 1992 break; 1993 } 1994 return nullptr; 1995 } 1996 1997 //===----------------------------------------------------------------------===// 1998 // TableGen'd op method definitions 1999 //===----------------------------------------------------------------------===// 2000 2001 #define GET_OP_CLASSES 2002 #include "mlir/Dialect/Arithmetic/IR/ArithmeticOps.cpp.inc" 2003 2004 //===----------------------------------------------------------------------===// 2005 // TableGen'd enum attribute definitions 2006 //===----------------------------------------------------------------------===// 2007 2008 #include "mlir/Dialect/Arithmetic/IR/ArithmeticOpsEnums.cpp.inc" 2009