1 //===- ArithmeticOps.cpp - MLIR Arithmetic dialect ops implementation -----===// 2 // 3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. 4 // See https://llvm.org/LICENSE.txt for license information. 5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception 6 // 7 //===----------------------------------------------------------------------===// 8 9 #include <utility> 10 11 #include "mlir/Dialect/Arithmetic/IR/Arithmetic.h" 12 #include "mlir/Dialect/CommonFolders.h" 13 #include "mlir/Dialect/Utils/StaticValueUtils.h" 14 #include "mlir/IR/Builders.h" 15 #include "mlir/IR/Matchers.h" 16 #include "mlir/IR/OpImplementation.h" 17 #include "mlir/IR/PatternMatch.h" 18 #include "mlir/IR/TypeUtilities.h" 19 #include "llvm/ADT/SmallString.h" 20 21 #include "llvm/ADT/APSInt.h" 22 23 using namespace mlir; 24 using namespace mlir::arith; 25 26 //===----------------------------------------------------------------------===// 27 // Pattern helpers 28 //===----------------------------------------------------------------------===// 29 30 static IntegerAttr addIntegerAttrs(PatternRewriter &builder, Value res, 31 Attribute lhs, Attribute rhs) { 32 return builder.getIntegerAttr(res.getType(), 33 lhs.cast<IntegerAttr>().getInt() + 34 rhs.cast<IntegerAttr>().getInt()); 35 } 36 37 static IntegerAttr subIntegerAttrs(PatternRewriter &builder, Value res, 38 Attribute lhs, Attribute rhs) { 39 return builder.getIntegerAttr(res.getType(), 40 lhs.cast<IntegerAttr>().getInt() - 41 rhs.cast<IntegerAttr>().getInt()); 42 } 43 44 /// Invert an integer comparison predicate. 45 arith::CmpIPredicate arith::invertPredicate(arith::CmpIPredicate pred) { 46 switch (pred) { 47 case arith::CmpIPredicate::eq: 48 return arith::CmpIPredicate::ne; 49 case arith::CmpIPredicate::ne: 50 return arith::CmpIPredicate::eq; 51 case arith::CmpIPredicate::slt: 52 return arith::CmpIPredicate::sge; 53 case arith::CmpIPredicate::sle: 54 return arith::CmpIPredicate::sgt; 55 case arith::CmpIPredicate::sgt: 56 return arith::CmpIPredicate::sle; 57 case arith::CmpIPredicate::sge: 58 return arith::CmpIPredicate::slt; 59 case arith::CmpIPredicate::ult: 60 return arith::CmpIPredicate::uge; 61 case arith::CmpIPredicate::ule: 62 return arith::CmpIPredicate::ugt; 63 case arith::CmpIPredicate::ugt: 64 return arith::CmpIPredicate::ule; 65 case arith::CmpIPredicate::uge: 66 return arith::CmpIPredicate::ult; 67 } 68 llvm_unreachable("unknown cmpi predicate kind"); 69 } 70 71 static arith::CmpIPredicateAttr invertPredicate(arith::CmpIPredicateAttr pred) { 72 return arith::CmpIPredicateAttr::get(pred.getContext(), 73 invertPredicate(pred.getValue())); 74 } 75 76 //===----------------------------------------------------------------------===// 77 // TableGen'd canonicalization patterns 78 //===----------------------------------------------------------------------===// 79 80 namespace { 81 #include "ArithmeticCanonicalization.inc" 82 } // namespace 83 84 //===----------------------------------------------------------------------===// 85 // ConstantOp 86 //===----------------------------------------------------------------------===// 87 88 void arith::ConstantOp::getAsmResultNames( 89 function_ref<void(Value, StringRef)> setNameFn) { 90 auto type = getType(); 91 if (auto intCst = getValue().dyn_cast<IntegerAttr>()) { 92 auto intType = type.dyn_cast<IntegerType>(); 93 94 // Sugar i1 constants with 'true' and 'false'. 95 if (intType && intType.getWidth() == 1) 96 return setNameFn(getResult(), (intCst.getInt() ? "true" : "false")); 97 98 // Otherwise, build a compex name with the value and type. 99 SmallString<32> specialNameBuffer; 100 llvm::raw_svector_ostream specialName(specialNameBuffer); 101 specialName << 'c' << intCst.getInt(); 102 if (intType) 103 specialName << '_' << type; 104 setNameFn(getResult(), specialName.str()); 105 } else { 106 setNameFn(getResult(), "cst"); 107 } 108 } 109 110 /// TODO: disallow arith.constant to return anything other than signless integer 111 /// or float like. 112 LogicalResult arith::ConstantOp::verify() { 113 auto type = getType(); 114 // The value's type must match the return type. 115 if (getValue().getType() != type) { 116 return emitOpError() << "value type " << getValue().getType() 117 << " must match return type: " << type; 118 } 119 // Integer values must be signless. 120 if (type.isa<IntegerType>() && !type.cast<IntegerType>().isSignless()) 121 return emitOpError("integer return type must be signless"); 122 // Any float or elements attribute are acceptable. 123 if (!getValue().isa<IntegerAttr, FloatAttr, ElementsAttr>()) { 124 return emitOpError( 125 "value must be an integer, float, or elements attribute"); 126 } 127 return success(); 128 } 129 130 bool arith::ConstantOp::isBuildableWith(Attribute value, Type type) { 131 // The value's type must be the same as the provided type. 132 if (value.getType() != type) 133 return false; 134 // Integer values must be signless. 135 if (type.isa<IntegerType>() && !type.cast<IntegerType>().isSignless()) 136 return false; 137 // Integer, float, and element attributes are buildable. 138 return value.isa<IntegerAttr, FloatAttr, ElementsAttr>(); 139 } 140 141 OpFoldResult arith::ConstantOp::fold(ArrayRef<Attribute> operands) { 142 return getValue(); 143 } 144 145 void arith::ConstantIntOp::build(OpBuilder &builder, OperationState &result, 146 int64_t value, unsigned width) { 147 auto type = builder.getIntegerType(width); 148 arith::ConstantOp::build(builder, result, type, 149 builder.getIntegerAttr(type, value)); 150 } 151 152 void arith::ConstantIntOp::build(OpBuilder &builder, OperationState &result, 153 int64_t value, Type type) { 154 assert(type.isSignlessInteger() && 155 "ConstantIntOp can only have signless integer type values"); 156 arith::ConstantOp::build(builder, result, type, 157 builder.getIntegerAttr(type, value)); 158 } 159 160 bool arith::ConstantIntOp::classof(Operation *op) { 161 if (auto constOp = dyn_cast_or_null<arith::ConstantOp>(op)) 162 return constOp.getType().isSignlessInteger(); 163 return false; 164 } 165 166 void arith::ConstantFloatOp::build(OpBuilder &builder, OperationState &result, 167 const APFloat &value, FloatType type) { 168 arith::ConstantOp::build(builder, result, type, 169 builder.getFloatAttr(type, value)); 170 } 171 172 bool arith::ConstantFloatOp::classof(Operation *op) { 173 if (auto constOp = dyn_cast_or_null<arith::ConstantOp>(op)) 174 return constOp.getType().isa<FloatType>(); 175 return false; 176 } 177 178 void arith::ConstantIndexOp::build(OpBuilder &builder, OperationState &result, 179 int64_t value) { 180 arith::ConstantOp::build(builder, result, builder.getIndexType(), 181 builder.getIndexAttr(value)); 182 } 183 184 bool arith::ConstantIndexOp::classof(Operation *op) { 185 if (auto constOp = dyn_cast_or_null<arith::ConstantOp>(op)) 186 return constOp.getType().isIndex(); 187 return false; 188 } 189 190 //===----------------------------------------------------------------------===// 191 // AddIOp 192 //===----------------------------------------------------------------------===// 193 194 OpFoldResult arith::AddIOp::fold(ArrayRef<Attribute> operands) { 195 // addi(x, 0) -> x 196 if (matchPattern(getRhs(), m_Zero())) 197 return getLhs(); 198 199 // addi(subi(a, b), b) -> a 200 if (auto sub = getLhs().getDefiningOp<SubIOp>()) 201 if (getRhs() == sub.getRhs()) 202 return sub.getLhs(); 203 204 // addi(b, subi(a, b)) -> a 205 if (auto sub = getRhs().getDefiningOp<SubIOp>()) 206 if (getLhs() == sub.getRhs()) 207 return sub.getLhs(); 208 209 return constFoldBinaryOp<IntegerAttr>( 210 operands, [](APInt a, const APInt &b) { return std::move(a) + b; }); 211 } 212 213 void arith::AddIOp::getCanonicalizationPatterns(RewritePatternSet &patterns, 214 MLIRContext *context) { 215 patterns.add<AddIAddConstant, AddISubConstantRHS, AddISubConstantLHS>( 216 context); 217 } 218 219 //===----------------------------------------------------------------------===// 220 // SubIOp 221 //===----------------------------------------------------------------------===// 222 223 OpFoldResult arith::SubIOp::fold(ArrayRef<Attribute> operands) { 224 // subi(x,x) -> 0 225 if (getOperand(0) == getOperand(1)) 226 return Builder(getContext()).getZeroAttr(getType()); 227 // subi(x,0) -> x 228 if (matchPattern(getRhs(), m_Zero())) 229 return getLhs(); 230 231 return constFoldBinaryOp<IntegerAttr>( 232 operands, [](APInt a, const APInt &b) { return std::move(a) - b; }); 233 } 234 235 void arith::SubIOp::getCanonicalizationPatterns(RewritePatternSet &patterns, 236 MLIRContext *context) { 237 patterns 238 .add<SubIRHSAddConstant, SubILHSAddConstant, SubIRHSSubConstantRHS, 239 SubIRHSSubConstantLHS, SubILHSSubConstantRHS, SubILHSSubConstantLHS>( 240 context); 241 } 242 243 //===----------------------------------------------------------------------===// 244 // MulIOp 245 //===----------------------------------------------------------------------===// 246 247 OpFoldResult arith::MulIOp::fold(ArrayRef<Attribute> operands) { 248 // muli(x, 0) -> 0 249 if (matchPattern(getRhs(), m_Zero())) 250 return getRhs(); 251 // muli(x, 1) -> x 252 if (matchPattern(getRhs(), m_One())) 253 return getOperand(0); 254 // TODO: Handle the overflow case. 255 256 // default folder 257 return constFoldBinaryOp<IntegerAttr>( 258 operands, [](const APInt &a, const APInt &b) { return a * b; }); 259 } 260 261 //===----------------------------------------------------------------------===// 262 // DivUIOp 263 //===----------------------------------------------------------------------===// 264 265 OpFoldResult arith::DivUIOp::fold(ArrayRef<Attribute> operands) { 266 // divui (x, 1) -> x. 267 if (matchPattern(getRhs(), m_One())) 268 return getLhs(); 269 270 // Don't fold if it would require a division by zero. 271 bool div0 = false; 272 auto result = 273 constFoldBinaryOp<IntegerAttr>(operands, [&](APInt a, const APInt &b) { 274 if (div0 || !b) { 275 div0 = true; 276 return a; 277 } 278 return a.udiv(b); 279 }); 280 281 return div0 ? Attribute() : result; 282 } 283 284 //===----------------------------------------------------------------------===// 285 // DivSIOp 286 //===----------------------------------------------------------------------===// 287 288 OpFoldResult arith::DivSIOp::fold(ArrayRef<Attribute> operands) { 289 // divsi (x, 1) -> x. 290 if (matchPattern(getRhs(), m_One())) 291 return getLhs(); 292 293 // Don't fold if it would overflow or if it requires a division by zero. 294 bool overflowOrDiv0 = false; 295 auto result = 296 constFoldBinaryOp<IntegerAttr>(operands, [&](APInt a, const APInt &b) { 297 if (overflowOrDiv0 || !b) { 298 overflowOrDiv0 = true; 299 return a; 300 } 301 return a.sdiv_ov(b, overflowOrDiv0); 302 }); 303 304 return overflowOrDiv0 ? Attribute() : result; 305 } 306 307 //===----------------------------------------------------------------------===// 308 // Ceil and floor division folding helpers 309 //===----------------------------------------------------------------------===// 310 311 static APInt signedCeilNonnegInputs(const APInt &a, const APInt &b, 312 bool &overflow) { 313 // Returns (a-1)/b + 1 314 APInt one(a.getBitWidth(), 1, true); // Signed value 1. 315 APInt val = a.ssub_ov(one, overflow).sdiv_ov(b, overflow); 316 return val.sadd_ov(one, overflow); 317 } 318 319 //===----------------------------------------------------------------------===// 320 // CeilDivUIOp 321 //===----------------------------------------------------------------------===// 322 323 OpFoldResult arith::CeilDivUIOp::fold(ArrayRef<Attribute> operands) { 324 // ceildivui (x, 1) -> x. 325 if (matchPattern(getRhs(), m_One())) 326 return getLhs(); 327 328 bool overflowOrDiv0 = false; 329 auto result = 330 constFoldBinaryOp<IntegerAttr>(operands, [&](APInt a, const APInt &b) { 331 if (overflowOrDiv0 || !b) { 332 overflowOrDiv0 = true; 333 return a; 334 } 335 APInt quotient = a.udiv(b); 336 if (!a.urem(b)) 337 return quotient; 338 APInt one(a.getBitWidth(), 1, true); 339 return quotient.uadd_ov(one, overflowOrDiv0); 340 }); 341 342 return overflowOrDiv0 ? Attribute() : result; 343 } 344 345 //===----------------------------------------------------------------------===// 346 // CeilDivSIOp 347 //===----------------------------------------------------------------------===// 348 349 OpFoldResult arith::CeilDivSIOp::fold(ArrayRef<Attribute> operands) { 350 // ceildivsi (x, 1) -> x. 351 if (matchPattern(getRhs(), m_One())) 352 return getLhs(); 353 354 // Don't fold if it would overflow or if it requires a division by zero. 355 bool overflowOrDiv0 = false; 356 auto result = 357 constFoldBinaryOp<IntegerAttr>(operands, [&](APInt a, const APInt &b) { 358 if (overflowOrDiv0 || !b) { 359 overflowOrDiv0 = true; 360 return a; 361 } 362 if (!a) 363 return a; 364 // After this point we know that neither a or b are zero. 365 unsigned bits = a.getBitWidth(); 366 APInt zero = APInt::getZero(bits); 367 bool aGtZero = a.sgt(zero); 368 bool bGtZero = b.sgt(zero); 369 if (aGtZero && bGtZero) { 370 // Both positive, return ceil(a, b). 371 return signedCeilNonnegInputs(a, b, overflowOrDiv0); 372 } 373 if (!aGtZero && !bGtZero) { 374 // Both negative, return ceil(-a, -b). 375 APInt posA = zero.ssub_ov(a, overflowOrDiv0); 376 APInt posB = zero.ssub_ov(b, overflowOrDiv0); 377 return signedCeilNonnegInputs(posA, posB, overflowOrDiv0); 378 } 379 if (!aGtZero && bGtZero) { 380 // A is negative, b is positive, return - ( -a / b). 381 APInt posA = zero.ssub_ov(a, overflowOrDiv0); 382 APInt div = posA.sdiv_ov(b, overflowOrDiv0); 383 return zero.ssub_ov(div, overflowOrDiv0); 384 } 385 // A is positive, b is negative, return - (a / -b). 386 APInt posB = zero.ssub_ov(b, overflowOrDiv0); 387 APInt div = a.sdiv_ov(posB, overflowOrDiv0); 388 return zero.ssub_ov(div, overflowOrDiv0); 389 }); 390 391 return overflowOrDiv0 ? Attribute() : result; 392 } 393 394 //===----------------------------------------------------------------------===// 395 // FloorDivSIOp 396 //===----------------------------------------------------------------------===// 397 398 OpFoldResult arith::FloorDivSIOp::fold(ArrayRef<Attribute> operands) { 399 // floordivsi (x, 1) -> x. 400 if (matchPattern(getRhs(), m_One())) 401 return getLhs(); 402 403 // Don't fold if it would overflow or if it requires a division by zero. 404 bool overflowOrDiv0 = false; 405 auto result = 406 constFoldBinaryOp<IntegerAttr>(operands, [&](APInt a, const APInt &b) { 407 if (overflowOrDiv0 || !b) { 408 overflowOrDiv0 = true; 409 return a; 410 } 411 if (!a) 412 return a; 413 // After this point we know that neither a or b are zero. 414 unsigned bits = a.getBitWidth(); 415 APInt zero = APInt::getZero(bits); 416 bool aGtZero = a.sgt(zero); 417 bool bGtZero = b.sgt(zero); 418 if (aGtZero && bGtZero) { 419 // Both positive, return a / b. 420 return a.sdiv_ov(b, overflowOrDiv0); 421 } 422 if (!aGtZero && !bGtZero) { 423 // Both negative, return -a / -b. 424 APInt posA = zero.ssub_ov(a, overflowOrDiv0); 425 APInt posB = zero.ssub_ov(b, overflowOrDiv0); 426 return posA.sdiv_ov(posB, overflowOrDiv0); 427 } 428 if (!aGtZero && bGtZero) { 429 // A is negative, b is positive, return - ceil(-a, b). 430 APInt posA = zero.ssub_ov(a, overflowOrDiv0); 431 APInt ceil = signedCeilNonnegInputs(posA, b, overflowOrDiv0); 432 return zero.ssub_ov(ceil, overflowOrDiv0); 433 } 434 // A is positive, b is negative, return - ceil(a, -b). 435 APInt posB = zero.ssub_ov(b, overflowOrDiv0); 436 APInt ceil = signedCeilNonnegInputs(a, posB, overflowOrDiv0); 437 return zero.ssub_ov(ceil, overflowOrDiv0); 438 }); 439 440 return overflowOrDiv0 ? Attribute() : result; 441 } 442 443 //===----------------------------------------------------------------------===// 444 // RemUIOp 445 //===----------------------------------------------------------------------===// 446 447 OpFoldResult arith::RemUIOp::fold(ArrayRef<Attribute> operands) { 448 // remui (x, 1) -> 0. 449 if (matchPattern(getRhs(), m_One())) 450 return Builder(getContext()).getZeroAttr(getType()); 451 452 // Don't fold if it would require a division by zero. 453 bool div0 = false; 454 auto result = 455 constFoldBinaryOp<IntegerAttr>(operands, [&](APInt a, const APInt &b) { 456 if (div0 || b.isNullValue()) { 457 div0 = true; 458 return a; 459 } 460 return a.urem(b); 461 }); 462 463 return div0 ? Attribute() : result; 464 } 465 466 //===----------------------------------------------------------------------===// 467 // RemSIOp 468 //===----------------------------------------------------------------------===// 469 470 OpFoldResult arith::RemSIOp::fold(ArrayRef<Attribute> operands) { 471 // remsi (x, 1) -> 0. 472 if (matchPattern(getRhs(), m_One())) 473 return Builder(getContext()).getZeroAttr(getType()); 474 475 // Don't fold if it would require a division by zero. 476 bool div0 = false; 477 auto result = 478 constFoldBinaryOp<IntegerAttr>(operands, [&](APInt a, const APInt &b) { 479 if (div0 || b.isNullValue()) { 480 div0 = true; 481 return a; 482 } 483 return a.srem(b); 484 }); 485 486 return div0 ? Attribute() : result; 487 } 488 489 //===----------------------------------------------------------------------===// 490 // AndIOp 491 //===----------------------------------------------------------------------===// 492 493 OpFoldResult arith::AndIOp::fold(ArrayRef<Attribute> operands) { 494 /// and(x, 0) -> 0 495 if (matchPattern(getRhs(), m_Zero())) 496 return getRhs(); 497 /// and(x, allOnes) -> x 498 APInt intValue; 499 if (matchPattern(getRhs(), m_ConstantInt(&intValue)) && intValue.isAllOnes()) 500 return getLhs(); 501 502 return constFoldBinaryOp<IntegerAttr>( 503 operands, [](APInt a, const APInt &b) { return std::move(a) & b; }); 504 } 505 506 //===----------------------------------------------------------------------===// 507 // OrIOp 508 //===----------------------------------------------------------------------===// 509 510 OpFoldResult arith::OrIOp::fold(ArrayRef<Attribute> operands) { 511 /// or(x, 0) -> x 512 if (matchPattern(getRhs(), m_Zero())) 513 return getLhs(); 514 /// or(x, <all ones>) -> <all ones> 515 if (auto rhsAttr = operands[1].dyn_cast_or_null<IntegerAttr>()) 516 if (rhsAttr.getValue().isAllOnes()) 517 return rhsAttr; 518 519 return constFoldBinaryOp<IntegerAttr>( 520 operands, [](APInt a, const APInt &b) { return std::move(a) | b; }); 521 } 522 523 //===----------------------------------------------------------------------===// 524 // XOrIOp 525 //===----------------------------------------------------------------------===// 526 527 OpFoldResult arith::XOrIOp::fold(ArrayRef<Attribute> operands) { 528 /// xor(x, 0) -> x 529 if (matchPattern(getRhs(), m_Zero())) 530 return getLhs(); 531 /// xor(x, x) -> 0 532 if (getLhs() == getRhs()) 533 return Builder(getContext()).getZeroAttr(getType()); 534 /// xor(xor(x, a), a) -> x 535 if (arith::XOrIOp prev = getLhs().getDefiningOp<arith::XOrIOp>()) 536 if (prev.getRhs() == getRhs()) 537 return prev.getLhs(); 538 539 return constFoldBinaryOp<IntegerAttr>( 540 operands, [](APInt a, const APInt &b) { return std::move(a) ^ b; }); 541 } 542 543 void arith::XOrIOp::getCanonicalizationPatterns(RewritePatternSet &patterns, 544 MLIRContext *context) { 545 patterns.add<XOrINotCmpI>(context); 546 } 547 548 //===----------------------------------------------------------------------===// 549 // NegFOp 550 //===----------------------------------------------------------------------===// 551 552 OpFoldResult arith::NegFOp::fold(ArrayRef<Attribute> operands) { 553 /// negf(negf(x)) -> x 554 if (auto op = this->getOperand().getDefiningOp<arith::NegFOp>()) 555 return op.getOperand(); 556 return constFoldUnaryOp<FloatAttr>(operands, 557 [](const APFloat &a) { return -a; }); 558 } 559 560 //===----------------------------------------------------------------------===// 561 // AddFOp 562 //===----------------------------------------------------------------------===// 563 564 OpFoldResult arith::AddFOp::fold(ArrayRef<Attribute> operands) { 565 // addf(x, -0) -> x 566 if (matchPattern(getRhs(), m_NegZeroFloat())) 567 return getLhs(); 568 569 return constFoldBinaryOp<FloatAttr>( 570 operands, [](const APFloat &a, const APFloat &b) { return a + b; }); 571 } 572 573 //===----------------------------------------------------------------------===// 574 // SubFOp 575 //===----------------------------------------------------------------------===// 576 577 OpFoldResult arith::SubFOp::fold(ArrayRef<Attribute> operands) { 578 // subf(x, +0) -> x 579 if (matchPattern(getRhs(), m_PosZeroFloat())) 580 return getLhs(); 581 582 return constFoldBinaryOp<FloatAttr>( 583 operands, [](const APFloat &a, const APFloat &b) { return a - b; }); 584 } 585 586 //===----------------------------------------------------------------------===// 587 // MaxFOp 588 //===----------------------------------------------------------------------===// 589 590 OpFoldResult arith::MaxFOp::fold(ArrayRef<Attribute> operands) { 591 assert(operands.size() == 2 && "maxf takes two operands"); 592 593 // maxf(x,x) -> x 594 if (getLhs() == getRhs()) 595 return getRhs(); 596 597 // maxf(x, -inf) -> x 598 if (matchPattern(getRhs(), m_NegInfFloat())) 599 return getLhs(); 600 601 return constFoldBinaryOp<FloatAttr>( 602 operands, 603 [](const APFloat &a, const APFloat &b) { return llvm::maximum(a, b); }); 604 } 605 606 //===----------------------------------------------------------------------===// 607 // MaxSIOp 608 //===----------------------------------------------------------------------===// 609 610 OpFoldResult MaxSIOp::fold(ArrayRef<Attribute> operands) { 611 assert(operands.size() == 2 && "binary operation takes two operands"); 612 613 // maxsi(x,x) -> x 614 if (getLhs() == getRhs()) 615 return getRhs(); 616 617 APInt intValue; 618 // maxsi(x,MAX_INT) -> MAX_INT 619 if (matchPattern(getRhs(), m_ConstantInt(&intValue)) && 620 intValue.isMaxSignedValue()) 621 return getRhs(); 622 623 // maxsi(x, MIN_INT) -> x 624 if (matchPattern(getRhs(), m_ConstantInt(&intValue)) && 625 intValue.isMinSignedValue()) 626 return getLhs(); 627 628 return constFoldBinaryOp<IntegerAttr>(operands, 629 [](const APInt &a, const APInt &b) { 630 return llvm::APIntOps::smax(a, b); 631 }); 632 } 633 634 //===----------------------------------------------------------------------===// 635 // MaxUIOp 636 //===----------------------------------------------------------------------===// 637 638 OpFoldResult MaxUIOp::fold(ArrayRef<Attribute> operands) { 639 assert(operands.size() == 2 && "binary operation takes two operands"); 640 641 // maxui(x,x) -> x 642 if (getLhs() == getRhs()) 643 return getRhs(); 644 645 APInt intValue; 646 // maxui(x,MAX_INT) -> MAX_INT 647 if (matchPattern(getRhs(), m_ConstantInt(&intValue)) && intValue.isMaxValue()) 648 return getRhs(); 649 650 // maxui(x, MIN_INT) -> x 651 if (matchPattern(getRhs(), m_ConstantInt(&intValue)) && intValue.isMinValue()) 652 return getLhs(); 653 654 return constFoldBinaryOp<IntegerAttr>(operands, 655 [](const APInt &a, const APInt &b) { 656 return llvm::APIntOps::umax(a, b); 657 }); 658 } 659 660 //===----------------------------------------------------------------------===// 661 // MinFOp 662 //===----------------------------------------------------------------------===// 663 664 OpFoldResult arith::MinFOp::fold(ArrayRef<Attribute> operands) { 665 assert(operands.size() == 2 && "minf takes two operands"); 666 667 // minf(x,x) -> x 668 if (getLhs() == getRhs()) 669 return getRhs(); 670 671 // minf(x, +inf) -> x 672 if (matchPattern(getRhs(), m_PosInfFloat())) 673 return getLhs(); 674 675 return constFoldBinaryOp<FloatAttr>( 676 operands, 677 [](const APFloat &a, const APFloat &b) { return llvm::minimum(a, b); }); 678 } 679 680 //===----------------------------------------------------------------------===// 681 // MinSIOp 682 //===----------------------------------------------------------------------===// 683 684 OpFoldResult MinSIOp::fold(ArrayRef<Attribute> operands) { 685 assert(operands.size() == 2 && "binary operation takes two operands"); 686 687 // minsi(x,x) -> x 688 if (getLhs() == getRhs()) 689 return getRhs(); 690 691 APInt intValue; 692 // minsi(x,MIN_INT) -> MIN_INT 693 if (matchPattern(getRhs(), m_ConstantInt(&intValue)) && 694 intValue.isMinSignedValue()) 695 return getRhs(); 696 697 // minsi(x, MAX_INT) -> x 698 if (matchPattern(getRhs(), m_ConstantInt(&intValue)) && 699 intValue.isMaxSignedValue()) 700 return getLhs(); 701 702 return constFoldBinaryOp<IntegerAttr>(operands, 703 [](const APInt &a, const APInt &b) { 704 return llvm::APIntOps::smin(a, b); 705 }); 706 } 707 708 //===----------------------------------------------------------------------===// 709 // MinUIOp 710 //===----------------------------------------------------------------------===// 711 712 OpFoldResult MinUIOp::fold(ArrayRef<Attribute> operands) { 713 assert(operands.size() == 2 && "binary operation takes two operands"); 714 715 // minui(x,x) -> x 716 if (getLhs() == getRhs()) 717 return getRhs(); 718 719 APInt intValue; 720 // minui(x,MIN_INT) -> MIN_INT 721 if (matchPattern(getRhs(), m_ConstantInt(&intValue)) && intValue.isMinValue()) 722 return getRhs(); 723 724 // minui(x, MAX_INT) -> x 725 if (matchPattern(getRhs(), m_ConstantInt(&intValue)) && intValue.isMaxValue()) 726 return getLhs(); 727 728 return constFoldBinaryOp<IntegerAttr>(operands, 729 [](const APInt &a, const APInt &b) { 730 return llvm::APIntOps::umin(a, b); 731 }); 732 } 733 734 //===----------------------------------------------------------------------===// 735 // MulFOp 736 //===----------------------------------------------------------------------===// 737 738 OpFoldResult arith::MulFOp::fold(ArrayRef<Attribute> operands) { 739 // mulf(x, 1) -> x 740 if (matchPattern(getRhs(), m_OneFloat())) 741 return getLhs(); 742 743 return constFoldBinaryOp<FloatAttr>( 744 operands, [](const APFloat &a, const APFloat &b) { return a * b; }); 745 } 746 747 void arith::MulFOp::getCanonicalizationPatterns(RewritePatternSet &patterns, 748 MLIRContext *context) { 749 patterns.add<MulFOfNegF>(context); 750 } 751 752 //===----------------------------------------------------------------------===// 753 // DivFOp 754 //===----------------------------------------------------------------------===// 755 756 OpFoldResult arith::DivFOp::fold(ArrayRef<Attribute> operands) { 757 // divf(x, 1) -> x 758 if (matchPattern(getRhs(), m_OneFloat())) 759 return getLhs(); 760 761 return constFoldBinaryOp<FloatAttr>( 762 operands, [](const APFloat &a, const APFloat &b) { return a / b; }); 763 } 764 765 void arith::DivFOp::getCanonicalizationPatterns(RewritePatternSet &patterns, 766 MLIRContext *context) { 767 patterns.add<DivFOfNegF>(context); 768 } 769 770 //===----------------------------------------------------------------------===// 771 // RemFOp 772 //===----------------------------------------------------------------------===// 773 774 OpFoldResult arith::RemFOp::fold(ArrayRef<Attribute> operands) { 775 return constFoldBinaryOp<FloatAttr>(operands, 776 [](const APFloat &a, const APFloat &b) { 777 APFloat result(a); 778 (void)result.remainder(b); 779 return result; 780 }); 781 } 782 783 //===----------------------------------------------------------------------===// 784 // Utility functions for verifying cast ops 785 //===----------------------------------------------------------------------===// 786 787 template <typename... Types> 788 using type_list = std::tuple<Types...> *; 789 790 /// Returns a non-null type only if the provided type is one of the allowed 791 /// types or one of the allowed shaped types of the allowed types. Returns the 792 /// element type if a valid shaped type is provided. 793 template <typename... ShapedTypes, typename... ElementTypes> 794 static Type getUnderlyingType(Type type, type_list<ShapedTypes...>, 795 type_list<ElementTypes...>) { 796 if (type.isa<ShapedType>() && !type.isa<ShapedTypes...>()) 797 return {}; 798 799 auto underlyingType = getElementTypeOrSelf(type); 800 if (!underlyingType.isa<ElementTypes...>()) 801 return {}; 802 803 return underlyingType; 804 } 805 806 /// Get allowed underlying types for vectors and tensors. 807 template <typename... ElementTypes> 808 static Type getTypeIfLike(Type type) { 809 return getUnderlyingType(type, type_list<VectorType, TensorType>(), 810 type_list<ElementTypes...>()); 811 } 812 813 /// Get allowed underlying types for vectors, tensors, and memrefs. 814 template <typename... ElementTypes> 815 static Type getTypeIfLikeOrMemRef(Type type) { 816 return getUnderlyingType(type, 817 type_list<VectorType, TensorType, MemRefType>(), 818 type_list<ElementTypes...>()); 819 } 820 821 static bool areValidCastInputsAndOutputs(TypeRange inputs, TypeRange outputs) { 822 return inputs.size() == 1 && outputs.size() == 1 && 823 succeeded(verifyCompatibleShapes(inputs.front(), outputs.front())); 824 } 825 826 //===----------------------------------------------------------------------===// 827 // Verifiers for integer and floating point extension/truncation ops 828 //===----------------------------------------------------------------------===// 829 830 // Extend ops can only extend to a wider type. 831 template <typename ValType, typename Op> 832 static LogicalResult verifyExtOp(Op op) { 833 Type srcType = getElementTypeOrSelf(op.getIn().getType()); 834 Type dstType = getElementTypeOrSelf(op.getType()); 835 836 if (srcType.cast<ValType>().getWidth() >= dstType.cast<ValType>().getWidth()) 837 return op.emitError("result type ") 838 << dstType << " must be wider than operand type " << srcType; 839 840 return success(); 841 } 842 843 // Truncate ops can only truncate to a shorter type. 844 template <typename ValType, typename Op> 845 static LogicalResult verifyTruncateOp(Op op) { 846 Type srcType = getElementTypeOrSelf(op.getIn().getType()); 847 Type dstType = getElementTypeOrSelf(op.getType()); 848 849 if (srcType.cast<ValType>().getWidth() <= dstType.cast<ValType>().getWidth()) 850 return op.emitError("result type ") 851 << dstType << " must be shorter than operand type " << srcType; 852 853 return success(); 854 } 855 856 /// Validate a cast that changes the width of a type. 857 template <template <typename> class WidthComparator, typename... ElementTypes> 858 static bool checkWidthChangeCast(TypeRange inputs, TypeRange outputs) { 859 if (!areValidCastInputsAndOutputs(inputs, outputs)) 860 return false; 861 862 auto srcType = getTypeIfLike<ElementTypes...>(inputs.front()); 863 auto dstType = getTypeIfLike<ElementTypes...>(outputs.front()); 864 if (!srcType || !dstType) 865 return false; 866 867 return WidthComparator<unsigned>()(dstType.getIntOrFloatBitWidth(), 868 srcType.getIntOrFloatBitWidth()); 869 } 870 871 //===----------------------------------------------------------------------===// 872 // ExtUIOp 873 //===----------------------------------------------------------------------===// 874 875 OpFoldResult arith::ExtUIOp::fold(ArrayRef<Attribute> operands) { 876 if (auto lhs = getIn().getDefiningOp<ExtUIOp>()) { 877 getInMutable().assign(lhs.getIn()); 878 return getResult(); 879 } 880 Type resType = getType(); 881 unsigned bitWidth; 882 if (auto shapedType = resType.dyn_cast<ShapedType>()) 883 bitWidth = shapedType.getElementTypeBitWidth(); 884 else 885 bitWidth = resType.getIntOrFloatBitWidth(); 886 return constFoldCastOp<IntegerAttr, IntegerAttr>( 887 operands, getType(), [bitWidth](const APInt &a, bool &castStatus) { 888 return a.zext(bitWidth); 889 }); 890 } 891 892 bool arith::ExtUIOp::areCastCompatible(TypeRange inputs, TypeRange outputs) { 893 return checkWidthChangeCast<std::greater, IntegerType>(inputs, outputs); 894 } 895 896 LogicalResult arith::ExtUIOp::verify() { 897 return verifyExtOp<IntegerType>(*this); 898 } 899 900 //===----------------------------------------------------------------------===// 901 // ExtSIOp 902 //===----------------------------------------------------------------------===// 903 904 OpFoldResult arith::ExtSIOp::fold(ArrayRef<Attribute> operands) { 905 if (auto lhs = getIn().getDefiningOp<ExtSIOp>()) { 906 getInMutable().assign(lhs.getIn()); 907 return getResult(); 908 } 909 Type resType = getType(); 910 unsigned bitWidth; 911 if (auto shapedType = resType.dyn_cast<ShapedType>()) 912 bitWidth = shapedType.getElementTypeBitWidth(); 913 else 914 bitWidth = resType.getIntOrFloatBitWidth(); 915 return constFoldCastOp<IntegerAttr, IntegerAttr>( 916 operands, getType(), [bitWidth](const APInt &a, bool &castStatus) { 917 return a.sext(bitWidth); 918 }); 919 } 920 921 bool arith::ExtSIOp::areCastCompatible(TypeRange inputs, TypeRange outputs) { 922 return checkWidthChangeCast<std::greater, IntegerType>(inputs, outputs); 923 } 924 925 void arith::ExtSIOp::getCanonicalizationPatterns(RewritePatternSet &patterns, 926 MLIRContext *context) { 927 patterns.add<ExtSIOfExtUI>(context); 928 } 929 930 LogicalResult arith::ExtSIOp::verify() { 931 return verifyExtOp<IntegerType>(*this); 932 } 933 934 //===----------------------------------------------------------------------===// 935 // ExtFOp 936 //===----------------------------------------------------------------------===// 937 938 bool arith::ExtFOp::areCastCompatible(TypeRange inputs, TypeRange outputs) { 939 return checkWidthChangeCast<std::greater, FloatType>(inputs, outputs); 940 } 941 942 LogicalResult arith::ExtFOp::verify() { return verifyExtOp<FloatType>(*this); } 943 944 //===----------------------------------------------------------------------===// 945 // TruncIOp 946 //===----------------------------------------------------------------------===// 947 948 OpFoldResult arith::TruncIOp::fold(ArrayRef<Attribute> operands) { 949 assert(operands.size() == 1 && "unary operation takes one operand"); 950 951 // trunci(zexti(a)) -> a 952 // trunci(sexti(a)) -> a 953 if (matchPattern(getOperand(), m_Op<arith::ExtUIOp>()) || 954 matchPattern(getOperand(), m_Op<arith::ExtSIOp>())) 955 return getOperand().getDefiningOp()->getOperand(0); 956 957 // trunci(trunci(a)) -> trunci(a)) 958 if (matchPattern(getOperand(), m_Op<arith::TruncIOp>())) { 959 setOperand(getOperand().getDefiningOp()->getOperand(0)); 960 return getResult(); 961 } 962 963 Type resType = getType(); 964 unsigned bitWidth; 965 if (auto shapedType = resType.dyn_cast<ShapedType>()) 966 bitWidth = shapedType.getElementTypeBitWidth(); 967 else 968 bitWidth = resType.getIntOrFloatBitWidth(); 969 970 return constFoldCastOp<IntegerAttr, IntegerAttr>( 971 operands, getType(), [bitWidth](const APInt &a, bool &castStatus) { 972 return a.trunc(bitWidth); 973 }); 974 } 975 976 bool arith::TruncIOp::areCastCompatible(TypeRange inputs, TypeRange outputs) { 977 return checkWidthChangeCast<std::less, IntegerType>(inputs, outputs); 978 } 979 980 LogicalResult arith::TruncIOp::verify() { 981 return verifyTruncateOp<IntegerType>(*this); 982 } 983 984 //===----------------------------------------------------------------------===// 985 // TruncFOp 986 //===----------------------------------------------------------------------===// 987 988 /// Perform safe const propagation for truncf, i.e. only propagate if FP value 989 /// can be represented without precision loss or rounding. 990 OpFoldResult arith::TruncFOp::fold(ArrayRef<Attribute> operands) { 991 assert(operands.size() == 1 && "unary operation takes one operand"); 992 993 auto constOperand = operands.front(); 994 if (!constOperand || !constOperand.isa<FloatAttr>()) 995 return {}; 996 997 // Convert to target type via 'double'. 998 double sourceValue = 999 constOperand.dyn_cast<FloatAttr>().getValue().convertToDouble(); 1000 auto targetAttr = FloatAttr::get(getType(), sourceValue); 1001 1002 // Propagate if constant's value does not change after truncation. 1003 if (sourceValue == targetAttr.getValue().convertToDouble()) 1004 return targetAttr; 1005 1006 return {}; 1007 } 1008 1009 bool arith::TruncFOp::areCastCompatible(TypeRange inputs, TypeRange outputs) { 1010 return checkWidthChangeCast<std::less, FloatType>(inputs, outputs); 1011 } 1012 1013 LogicalResult arith::TruncFOp::verify() { 1014 return verifyTruncateOp<FloatType>(*this); 1015 } 1016 1017 //===----------------------------------------------------------------------===// 1018 // AndIOp 1019 //===----------------------------------------------------------------------===// 1020 1021 void arith::AndIOp::getCanonicalizationPatterns(RewritePatternSet &patterns, 1022 MLIRContext *context) { 1023 patterns.add<AndOfExtUI, AndOfExtSI>(context); 1024 } 1025 1026 //===----------------------------------------------------------------------===// 1027 // OrIOp 1028 //===----------------------------------------------------------------------===// 1029 1030 void arith::OrIOp::getCanonicalizationPatterns(RewritePatternSet &patterns, 1031 MLIRContext *context) { 1032 patterns.add<OrOfExtUI, OrOfExtSI>(context); 1033 } 1034 1035 //===----------------------------------------------------------------------===// 1036 // Verifiers for casts between integers and floats. 1037 //===----------------------------------------------------------------------===// 1038 1039 template <typename From, typename To> 1040 static bool checkIntFloatCast(TypeRange inputs, TypeRange outputs) { 1041 if (!areValidCastInputsAndOutputs(inputs, outputs)) 1042 return false; 1043 1044 auto srcType = getTypeIfLike<From>(inputs.front()); 1045 auto dstType = getTypeIfLike<To>(outputs.back()); 1046 1047 return srcType && dstType; 1048 } 1049 1050 //===----------------------------------------------------------------------===// 1051 // UIToFPOp 1052 //===----------------------------------------------------------------------===// 1053 1054 bool arith::UIToFPOp::areCastCompatible(TypeRange inputs, TypeRange outputs) { 1055 return checkIntFloatCast<IntegerType, FloatType>(inputs, outputs); 1056 } 1057 1058 OpFoldResult arith::UIToFPOp::fold(ArrayRef<Attribute> operands) { 1059 Type resType = getType(); 1060 Type resEleType; 1061 if (auto shapedType = resType.dyn_cast<ShapedType>()) 1062 resEleType = shapedType.getElementType(); 1063 else 1064 resEleType = resType; 1065 return constFoldCastOp<IntegerAttr, FloatAttr>( 1066 operands, getType(), [&resEleType](const APInt &a, bool &castStatus) { 1067 FloatType floatTy = resEleType.cast<FloatType>(); 1068 APFloat apf(floatTy.getFloatSemantics(), 1069 APInt::getZero(floatTy.getWidth())); 1070 apf.convertFromAPInt(a, /*IsSigned=*/false, 1071 APFloat::rmNearestTiesToEven); 1072 return apf; 1073 }); 1074 } 1075 1076 //===----------------------------------------------------------------------===// 1077 // SIToFPOp 1078 //===----------------------------------------------------------------------===// 1079 1080 bool arith::SIToFPOp::areCastCompatible(TypeRange inputs, TypeRange outputs) { 1081 return checkIntFloatCast<IntegerType, FloatType>(inputs, outputs); 1082 } 1083 1084 OpFoldResult arith::SIToFPOp::fold(ArrayRef<Attribute> operands) { 1085 Type resType = getType(); 1086 Type resEleType; 1087 if (auto shapedType = resType.dyn_cast<ShapedType>()) 1088 resEleType = shapedType.getElementType(); 1089 else 1090 resEleType = resType; 1091 return constFoldCastOp<IntegerAttr, FloatAttr>( 1092 operands, getType(), [&resEleType](const APInt &a, bool &castStatus) { 1093 FloatType floatTy = resEleType.cast<FloatType>(); 1094 APFloat apf(floatTy.getFloatSemantics(), 1095 APInt::getZero(floatTy.getWidth())); 1096 apf.convertFromAPInt(a, /*IsSigned=*/true, 1097 APFloat::rmNearestTiesToEven); 1098 return apf; 1099 }); 1100 } 1101 //===----------------------------------------------------------------------===// 1102 // FPToUIOp 1103 //===----------------------------------------------------------------------===// 1104 1105 bool arith::FPToUIOp::areCastCompatible(TypeRange inputs, TypeRange outputs) { 1106 return checkIntFloatCast<FloatType, IntegerType>(inputs, outputs); 1107 } 1108 1109 OpFoldResult arith::FPToUIOp::fold(ArrayRef<Attribute> operands) { 1110 Type resType = getType(); 1111 Type resEleType; 1112 if (auto shapedType = resType.dyn_cast<ShapedType>()) 1113 resEleType = shapedType.getElementType(); 1114 else 1115 resEleType = resType; 1116 return constFoldCastOp<FloatAttr, IntegerAttr>( 1117 operands, getType(), [&resEleType](const APFloat &a, bool &castStatus) { 1118 IntegerType intTy = resEleType.cast<IntegerType>(); 1119 bool ignored; 1120 APSInt api(intTy.getWidth(), /*isUnsigned=*/true); 1121 castStatus = APFloat::opInvalidOp != 1122 a.convertToInteger(api, APFloat::rmTowardZero, &ignored); 1123 return api; 1124 }); 1125 } 1126 1127 //===----------------------------------------------------------------------===// 1128 // FPToSIOp 1129 //===----------------------------------------------------------------------===// 1130 1131 bool arith::FPToSIOp::areCastCompatible(TypeRange inputs, TypeRange outputs) { 1132 return checkIntFloatCast<FloatType, IntegerType>(inputs, outputs); 1133 } 1134 1135 OpFoldResult arith::FPToSIOp::fold(ArrayRef<Attribute> operands) { 1136 Type resType = getType(); 1137 Type resEleType; 1138 if (auto shapedType = resType.dyn_cast<ShapedType>()) 1139 resEleType = shapedType.getElementType(); 1140 else 1141 resEleType = resType; 1142 return constFoldCastOp<FloatAttr, IntegerAttr>( 1143 operands, getType(), [&resEleType](const APFloat &a, bool &castStatus) { 1144 IntegerType intTy = resEleType.cast<IntegerType>(); 1145 bool ignored; 1146 APSInt api(intTy.getWidth(), /*isUnsigned=*/false); 1147 castStatus = APFloat::opInvalidOp != 1148 a.convertToInteger(api, APFloat::rmTowardZero, &ignored); 1149 return api; 1150 }); 1151 } 1152 1153 //===----------------------------------------------------------------------===// 1154 // IndexCastOp 1155 //===----------------------------------------------------------------------===// 1156 1157 bool arith::IndexCastOp::areCastCompatible(TypeRange inputs, 1158 TypeRange outputs) { 1159 if (!areValidCastInputsAndOutputs(inputs, outputs)) 1160 return false; 1161 1162 auto srcType = getTypeIfLikeOrMemRef<IntegerType, IndexType>(inputs.front()); 1163 auto dstType = getTypeIfLikeOrMemRef<IntegerType, IndexType>(outputs.front()); 1164 if (!srcType || !dstType) 1165 return false; 1166 1167 return (srcType.isIndex() && dstType.isSignlessInteger()) || 1168 (srcType.isSignlessInteger() && dstType.isIndex()); 1169 } 1170 1171 OpFoldResult arith::IndexCastOp::fold(ArrayRef<Attribute> operands) { 1172 // index_cast(constant) -> constant 1173 // A little hack because we go through int. Otherwise, the size of the 1174 // constant might need to change. 1175 if (auto value = operands[0].dyn_cast_or_null<IntegerAttr>()) 1176 return IntegerAttr::get(getType(), value.getInt()); 1177 1178 return {}; 1179 } 1180 1181 void arith::IndexCastOp::getCanonicalizationPatterns( 1182 RewritePatternSet &patterns, MLIRContext *context) { 1183 patterns.add<IndexCastOfIndexCast, IndexCastOfExtSI>(context); 1184 } 1185 1186 //===----------------------------------------------------------------------===// 1187 // BitcastOp 1188 //===----------------------------------------------------------------------===// 1189 1190 bool arith::BitcastOp::areCastCompatible(TypeRange inputs, TypeRange outputs) { 1191 if (!areValidCastInputsAndOutputs(inputs, outputs)) 1192 return false; 1193 1194 auto srcType = 1195 getTypeIfLikeOrMemRef<IntegerType, IndexType, FloatType>(inputs.front()); 1196 auto dstType = 1197 getTypeIfLikeOrMemRef<IntegerType, IndexType, FloatType>(outputs.front()); 1198 if (!srcType || !dstType) 1199 return false; 1200 1201 return srcType.getIntOrFloatBitWidth() == dstType.getIntOrFloatBitWidth(); 1202 } 1203 1204 OpFoldResult arith::BitcastOp::fold(ArrayRef<Attribute> operands) { 1205 assert(operands.size() == 1 && "bitcast op expects 1 operand"); 1206 1207 auto resType = getType(); 1208 auto operand = operands[0]; 1209 if (!operand) 1210 return {}; 1211 1212 /// Bitcast dense elements. 1213 if (auto denseAttr = operand.dyn_cast_or_null<DenseElementsAttr>()) 1214 return denseAttr.bitcast(resType.cast<ShapedType>().getElementType()); 1215 /// Other shaped types unhandled. 1216 if (resType.isa<ShapedType>()) 1217 return {}; 1218 1219 /// Bitcast integer or float to integer or float. 1220 APInt bits = operand.isa<FloatAttr>() 1221 ? operand.cast<FloatAttr>().getValue().bitcastToAPInt() 1222 : operand.cast<IntegerAttr>().getValue(); 1223 1224 if (auto resFloatType = resType.dyn_cast<FloatType>()) 1225 return FloatAttr::get(resType, 1226 APFloat(resFloatType.getFloatSemantics(), bits)); 1227 return IntegerAttr::get(resType, bits); 1228 } 1229 1230 void arith::BitcastOp::getCanonicalizationPatterns(RewritePatternSet &patterns, 1231 MLIRContext *context) { 1232 patterns.add<BitcastOfBitcast>(context); 1233 } 1234 1235 //===----------------------------------------------------------------------===// 1236 // Helpers for compare ops 1237 //===----------------------------------------------------------------------===// 1238 1239 /// Return the type of the same shape (scalar, vector or tensor) containing i1. 1240 static Type getI1SameShape(Type type) { 1241 auto i1Type = IntegerType::get(type.getContext(), 1); 1242 if (auto tensorType = type.dyn_cast<RankedTensorType>()) 1243 return RankedTensorType::get(tensorType.getShape(), i1Type); 1244 if (type.isa<UnrankedTensorType>()) 1245 return UnrankedTensorType::get(i1Type); 1246 if (auto vectorType = type.dyn_cast<VectorType>()) 1247 return VectorType::get(vectorType.getShape(), i1Type, 1248 vectorType.getNumScalableDims()); 1249 return i1Type; 1250 } 1251 1252 //===----------------------------------------------------------------------===// 1253 // CmpIOp 1254 //===----------------------------------------------------------------------===// 1255 1256 /// Compute `lhs` `pred` `rhs`, where `pred` is one of the known integer 1257 /// comparison predicates. 1258 bool mlir::arith::applyCmpPredicate(arith::CmpIPredicate predicate, 1259 const APInt &lhs, const APInt &rhs) { 1260 switch (predicate) { 1261 case arith::CmpIPredicate::eq: 1262 return lhs.eq(rhs); 1263 case arith::CmpIPredicate::ne: 1264 return lhs.ne(rhs); 1265 case arith::CmpIPredicate::slt: 1266 return lhs.slt(rhs); 1267 case arith::CmpIPredicate::sle: 1268 return lhs.sle(rhs); 1269 case arith::CmpIPredicate::sgt: 1270 return lhs.sgt(rhs); 1271 case arith::CmpIPredicate::sge: 1272 return lhs.sge(rhs); 1273 case arith::CmpIPredicate::ult: 1274 return lhs.ult(rhs); 1275 case arith::CmpIPredicate::ule: 1276 return lhs.ule(rhs); 1277 case arith::CmpIPredicate::ugt: 1278 return lhs.ugt(rhs); 1279 case arith::CmpIPredicate::uge: 1280 return lhs.uge(rhs); 1281 } 1282 llvm_unreachable("unknown cmpi predicate kind"); 1283 } 1284 1285 /// Returns true if the predicate is true for two equal operands. 1286 static bool applyCmpPredicateToEqualOperands(arith::CmpIPredicate predicate) { 1287 switch (predicate) { 1288 case arith::CmpIPredicate::eq: 1289 case arith::CmpIPredicate::sle: 1290 case arith::CmpIPredicate::sge: 1291 case arith::CmpIPredicate::ule: 1292 case arith::CmpIPredicate::uge: 1293 return true; 1294 case arith::CmpIPredicate::ne: 1295 case arith::CmpIPredicate::slt: 1296 case arith::CmpIPredicate::sgt: 1297 case arith::CmpIPredicate::ult: 1298 case arith::CmpIPredicate::ugt: 1299 return false; 1300 } 1301 llvm_unreachable("unknown cmpi predicate kind"); 1302 } 1303 1304 static Attribute getBoolAttribute(Type type, MLIRContext *ctx, bool value) { 1305 auto boolAttr = BoolAttr::get(ctx, value); 1306 ShapedType shapedType = type.dyn_cast_or_null<ShapedType>(); 1307 if (!shapedType) 1308 return boolAttr; 1309 return DenseElementsAttr::get(shapedType, boolAttr); 1310 } 1311 1312 OpFoldResult arith::CmpIOp::fold(ArrayRef<Attribute> operands) { 1313 assert(operands.size() == 2 && "cmpi takes two operands"); 1314 1315 // cmpi(pred, x, x) 1316 if (getLhs() == getRhs()) { 1317 auto val = applyCmpPredicateToEqualOperands(getPredicate()); 1318 return getBoolAttribute(getType(), getContext(), val); 1319 } 1320 1321 if (matchPattern(getRhs(), m_Zero())) { 1322 if (auto extOp = getLhs().getDefiningOp<ExtSIOp>()) { 1323 // extsi(%x : i1 -> iN) != 0 -> %x 1324 if (extOp.getOperand().getType().cast<IntegerType>().getWidth() == 1 && 1325 getPredicate() == arith::CmpIPredicate::ne) 1326 return extOp.getOperand(); 1327 } 1328 if (auto extOp = getLhs().getDefiningOp<ExtUIOp>()) { 1329 // extui(%x : i1 -> iN) != 0 -> %x 1330 if (extOp.getOperand().getType().cast<IntegerType>().getWidth() == 1 && 1331 getPredicate() == arith::CmpIPredicate::ne) 1332 return extOp.getOperand(); 1333 } 1334 } 1335 1336 // Move constant to the right side. 1337 if (operands[0] && !operands[1]) { 1338 // Do not use invertPredicate, as it will change eq to ne and vice versa. 1339 using Pred = CmpIPredicate; 1340 const std::pair<Pred, Pred> invPreds[] = { 1341 {Pred::slt, Pred::sgt}, {Pred::sgt, Pred::slt}, {Pred::sle, Pred::sge}, 1342 {Pred::sge, Pred::sle}, {Pred::ult, Pred::ugt}, {Pred::ugt, Pred::ult}, 1343 {Pred::ule, Pred::uge}, {Pred::uge, Pred::ule}, {Pred::eq, Pred::eq}, 1344 {Pred::ne, Pred::ne}, 1345 }; 1346 Pred origPred = getPredicate(); 1347 for (auto pred : invPreds) { 1348 if (origPred == pred.first) { 1349 setPredicateAttr(CmpIPredicateAttr::get(getContext(), pred.second)); 1350 Value lhs = getLhs(); 1351 Value rhs = getRhs(); 1352 getLhsMutable().assign(rhs); 1353 getRhsMutable().assign(lhs); 1354 return getResult(); 1355 } 1356 } 1357 llvm_unreachable("unknown cmpi predicate kind"); 1358 } 1359 1360 auto lhs = operands.front().dyn_cast_or_null<IntegerAttr>(); 1361 if (!lhs) 1362 return {}; 1363 1364 // We are moving constants to the right side; So if lhs is constant rhs is 1365 // guaranteed to be a constant. 1366 auto rhs = operands.back().cast<IntegerAttr>(); 1367 1368 auto val = applyCmpPredicate(getPredicate(), lhs.getValue(), rhs.getValue()); 1369 return BoolAttr::get(getContext(), val); 1370 } 1371 1372 void arith::CmpIOp::getCanonicalizationPatterns(RewritePatternSet &patterns, 1373 MLIRContext *context) { 1374 patterns.insert<CmpIExtSI, CmpIExtUI>(context); 1375 } 1376 1377 //===----------------------------------------------------------------------===// 1378 // CmpFOp 1379 //===----------------------------------------------------------------------===// 1380 1381 /// Compute `lhs` `pred` `rhs`, where `pred` is one of the known floating point 1382 /// comparison predicates. 1383 bool mlir::arith::applyCmpPredicate(arith::CmpFPredicate predicate, 1384 const APFloat &lhs, const APFloat &rhs) { 1385 auto cmpResult = lhs.compare(rhs); 1386 switch (predicate) { 1387 case arith::CmpFPredicate::AlwaysFalse: 1388 return false; 1389 case arith::CmpFPredicate::OEQ: 1390 return cmpResult == APFloat::cmpEqual; 1391 case arith::CmpFPredicate::OGT: 1392 return cmpResult == APFloat::cmpGreaterThan; 1393 case arith::CmpFPredicate::OGE: 1394 return cmpResult == APFloat::cmpGreaterThan || 1395 cmpResult == APFloat::cmpEqual; 1396 case arith::CmpFPredicate::OLT: 1397 return cmpResult == APFloat::cmpLessThan; 1398 case arith::CmpFPredicate::OLE: 1399 return cmpResult == APFloat::cmpLessThan || cmpResult == APFloat::cmpEqual; 1400 case arith::CmpFPredicate::ONE: 1401 return cmpResult != APFloat::cmpUnordered && cmpResult != APFloat::cmpEqual; 1402 case arith::CmpFPredicate::ORD: 1403 return cmpResult != APFloat::cmpUnordered; 1404 case arith::CmpFPredicate::UEQ: 1405 return cmpResult == APFloat::cmpUnordered || cmpResult == APFloat::cmpEqual; 1406 case arith::CmpFPredicate::UGT: 1407 return cmpResult == APFloat::cmpUnordered || 1408 cmpResult == APFloat::cmpGreaterThan; 1409 case arith::CmpFPredicate::UGE: 1410 return cmpResult == APFloat::cmpUnordered || 1411 cmpResult == APFloat::cmpGreaterThan || 1412 cmpResult == APFloat::cmpEqual; 1413 case arith::CmpFPredicate::ULT: 1414 return cmpResult == APFloat::cmpUnordered || 1415 cmpResult == APFloat::cmpLessThan; 1416 case arith::CmpFPredicate::ULE: 1417 return cmpResult == APFloat::cmpUnordered || 1418 cmpResult == APFloat::cmpLessThan || cmpResult == APFloat::cmpEqual; 1419 case arith::CmpFPredicate::UNE: 1420 return cmpResult != APFloat::cmpEqual; 1421 case arith::CmpFPredicate::UNO: 1422 return cmpResult == APFloat::cmpUnordered; 1423 case arith::CmpFPredicate::AlwaysTrue: 1424 return true; 1425 } 1426 llvm_unreachable("unknown cmpf predicate kind"); 1427 } 1428 1429 OpFoldResult arith::CmpFOp::fold(ArrayRef<Attribute> operands) { 1430 assert(operands.size() == 2 && "cmpf takes two operands"); 1431 1432 auto lhs = operands.front().dyn_cast_or_null<FloatAttr>(); 1433 auto rhs = operands.back().dyn_cast_or_null<FloatAttr>(); 1434 1435 // If one operand is NaN, making them both NaN does not change the result. 1436 if (lhs && lhs.getValue().isNaN()) 1437 rhs = lhs; 1438 if (rhs && rhs.getValue().isNaN()) 1439 lhs = rhs; 1440 1441 if (!lhs || !rhs) 1442 return {}; 1443 1444 auto val = applyCmpPredicate(getPredicate(), lhs.getValue(), rhs.getValue()); 1445 return BoolAttr::get(getContext(), val); 1446 } 1447 1448 class CmpFIntToFPConst final : public OpRewritePattern<CmpFOp> { 1449 public: 1450 using OpRewritePattern<CmpFOp>::OpRewritePattern; 1451 1452 static CmpIPredicate convertToIntegerPredicate(CmpFPredicate pred, 1453 bool isUnsigned) { 1454 using namespace arith; 1455 switch (pred) { 1456 case CmpFPredicate::UEQ: 1457 case CmpFPredicate::OEQ: 1458 return CmpIPredicate::eq; 1459 case CmpFPredicate::UGT: 1460 case CmpFPredicate::OGT: 1461 return isUnsigned ? CmpIPredicate::ugt : CmpIPredicate::sgt; 1462 case CmpFPredicate::UGE: 1463 case CmpFPredicate::OGE: 1464 return isUnsigned ? CmpIPredicate::uge : CmpIPredicate::sge; 1465 case CmpFPredicate::ULT: 1466 case CmpFPredicate::OLT: 1467 return isUnsigned ? CmpIPredicate::ult : CmpIPredicate::slt; 1468 case CmpFPredicate::ULE: 1469 case CmpFPredicate::OLE: 1470 return isUnsigned ? CmpIPredicate::ule : CmpIPredicate::sle; 1471 case CmpFPredicate::UNE: 1472 case CmpFPredicate::ONE: 1473 return CmpIPredicate::ne; 1474 default: 1475 llvm_unreachable("Unexpected predicate!"); 1476 } 1477 } 1478 1479 LogicalResult matchAndRewrite(CmpFOp op, 1480 PatternRewriter &rewriter) const override { 1481 FloatAttr flt; 1482 if (!matchPattern(op.getRhs(), m_Constant(&flt))) 1483 return failure(); 1484 1485 const APFloat &rhs = flt.getValue(); 1486 1487 // Don't attempt to fold a nan. 1488 if (rhs.isNaN()) 1489 return failure(); 1490 1491 // Get the width of the mantissa. We don't want to hack on conversions that 1492 // might lose information from the integer, e.g. "i64 -> float" 1493 FloatType floatTy = op.getRhs().getType().cast<FloatType>(); 1494 int mantissaWidth = floatTy.getFPMantissaWidth(); 1495 if (mantissaWidth <= 0) 1496 return failure(); 1497 1498 bool isUnsigned; 1499 Value intVal; 1500 1501 if (auto si = op.getLhs().getDefiningOp<SIToFPOp>()) { 1502 isUnsigned = false; 1503 intVal = si.getIn(); 1504 } else if (auto ui = op.getLhs().getDefiningOp<UIToFPOp>()) { 1505 isUnsigned = true; 1506 intVal = ui.getIn(); 1507 } else { 1508 return failure(); 1509 } 1510 1511 // Check to see that the input is converted from an integer type that is 1512 // small enough that preserves all bits. 1513 auto intTy = intVal.getType().cast<IntegerType>(); 1514 auto intWidth = intTy.getWidth(); 1515 1516 // Number of bits representing values, as opposed to the sign 1517 auto valueBits = isUnsigned ? intWidth : (intWidth - 1); 1518 1519 // Following test does NOT adjust intWidth downwards for signed inputs, 1520 // because the most negative value still requires all the mantissa bits 1521 // to distinguish it from one less than that value. 1522 if ((int)intWidth > mantissaWidth) { 1523 // Conversion would lose accuracy. Check if loss can impact comparison. 1524 int exponent = ilogb(rhs); 1525 if (exponent == APFloat::IEK_Inf) { 1526 int maxExponent = ilogb(APFloat::getLargest(rhs.getSemantics())); 1527 if (maxExponent < (int)valueBits) { 1528 // Conversion could create infinity. 1529 return failure(); 1530 } 1531 } else { 1532 // Note that if rhs is zero or NaN, then Exp is negative 1533 // and first condition is trivially false. 1534 if (mantissaWidth <= exponent && exponent <= (int)valueBits) { 1535 // Conversion could affect comparison. 1536 return failure(); 1537 } 1538 } 1539 } 1540 1541 // Convert to equivalent cmpi predicate 1542 CmpIPredicate pred; 1543 switch (op.getPredicate()) { 1544 case CmpFPredicate::ORD: 1545 // Int to fp conversion doesn't create a nan (ord checks neither is a nan) 1546 rewriter.replaceOpWithNewOp<ConstantIntOp>(op, /*value=*/true, 1547 /*width=*/1); 1548 return success(); 1549 case CmpFPredicate::UNO: 1550 // Int to fp conversion doesn't create a nan (uno checks either is a nan) 1551 rewriter.replaceOpWithNewOp<ConstantIntOp>(op, /*value=*/false, 1552 /*width=*/1); 1553 return success(); 1554 default: 1555 pred = convertToIntegerPredicate(op.getPredicate(), isUnsigned); 1556 break; 1557 } 1558 1559 if (!isUnsigned) { 1560 // If the rhs value is > SignedMax, fold the comparison. This handles 1561 // +INF and large values. 1562 APFloat signedMax(rhs.getSemantics()); 1563 signedMax.convertFromAPInt(APInt::getSignedMaxValue(intWidth), true, 1564 APFloat::rmNearestTiesToEven); 1565 if (signedMax < rhs) { // smax < 13123.0 1566 if (pred == CmpIPredicate::ne || pred == CmpIPredicate::slt || 1567 pred == CmpIPredicate::sle) 1568 rewriter.replaceOpWithNewOp<ConstantIntOp>(op, /*value=*/true, 1569 /*width=*/1); 1570 else 1571 rewriter.replaceOpWithNewOp<ConstantIntOp>(op, /*value=*/false, 1572 /*width=*/1); 1573 return success(); 1574 } 1575 } else { 1576 // If the rhs value is > UnsignedMax, fold the comparison. This handles 1577 // +INF and large values. 1578 APFloat unsignedMax(rhs.getSemantics()); 1579 unsignedMax.convertFromAPInt(APInt::getMaxValue(intWidth), false, 1580 APFloat::rmNearestTiesToEven); 1581 if (unsignedMax < rhs) { // umax < 13123.0 1582 if (pred == CmpIPredicate::ne || pred == CmpIPredicate::ult || 1583 pred == CmpIPredicate::ule) 1584 rewriter.replaceOpWithNewOp<ConstantIntOp>(op, /*value=*/true, 1585 /*width=*/1); 1586 else 1587 rewriter.replaceOpWithNewOp<ConstantIntOp>(op, /*value=*/false, 1588 /*width=*/1); 1589 return success(); 1590 } 1591 } 1592 1593 if (!isUnsigned) { 1594 // See if the rhs value is < SignedMin. 1595 APFloat signedMin(rhs.getSemantics()); 1596 signedMin.convertFromAPInt(APInt::getSignedMinValue(intWidth), true, 1597 APFloat::rmNearestTiesToEven); 1598 if (signedMin > rhs) { // smin > 12312.0 1599 if (pred == CmpIPredicate::ne || pred == CmpIPredicate::sgt || 1600 pred == CmpIPredicate::sge) 1601 rewriter.replaceOpWithNewOp<ConstantIntOp>(op, /*value=*/true, 1602 /*width=*/1); 1603 else 1604 rewriter.replaceOpWithNewOp<ConstantIntOp>(op, /*value=*/false, 1605 /*width=*/1); 1606 return success(); 1607 } 1608 } else { 1609 // See if the rhs value is < UnsignedMin. 1610 APFloat unsignedMin(rhs.getSemantics()); 1611 unsignedMin.convertFromAPInt(APInt::getMinValue(intWidth), false, 1612 APFloat::rmNearestTiesToEven); 1613 if (unsignedMin > rhs) { // umin > 12312.0 1614 if (pred == CmpIPredicate::ne || pred == CmpIPredicate::ugt || 1615 pred == CmpIPredicate::uge) 1616 rewriter.replaceOpWithNewOp<ConstantIntOp>(op, /*value=*/true, 1617 /*width=*/1); 1618 else 1619 rewriter.replaceOpWithNewOp<ConstantIntOp>(op, /*value=*/false, 1620 /*width=*/1); 1621 return success(); 1622 } 1623 } 1624 1625 // Okay, now we know that the FP constant fits in the range [SMIN, SMAX] or 1626 // [0, UMAX], but it may still be fractional. See if it is fractional by 1627 // casting the FP value to the integer value and back, checking for 1628 // equality. Don't do this for zero, because -0.0 is not fractional. 1629 bool ignored; 1630 APSInt rhsInt(intWidth, isUnsigned); 1631 if (APFloat::opInvalidOp == 1632 rhs.convertToInteger(rhsInt, APFloat::rmTowardZero, &ignored)) { 1633 // Undefined behavior invoked - the destination type can't represent 1634 // the input constant. 1635 return failure(); 1636 } 1637 1638 if (!rhs.isZero()) { 1639 APFloat apf(floatTy.getFloatSemantics(), 1640 APInt::getZero(floatTy.getWidth())); 1641 apf.convertFromAPInt(rhsInt, !isUnsigned, APFloat::rmNearestTiesToEven); 1642 1643 bool equal = apf == rhs; 1644 if (!equal) { 1645 // If we had a comparison against a fractional value, we have to adjust 1646 // the compare predicate and sometimes the value. rhsInt is rounded 1647 // towards zero at this point. 1648 switch (pred) { 1649 case CmpIPredicate::ne: // (float)int != 4.4 --> true 1650 rewriter.replaceOpWithNewOp<ConstantIntOp>(op, /*value=*/true, 1651 /*width=*/1); 1652 return success(); 1653 case CmpIPredicate::eq: // (float)int == 4.4 --> false 1654 rewriter.replaceOpWithNewOp<ConstantIntOp>(op, /*value=*/false, 1655 /*width=*/1); 1656 return success(); 1657 case CmpIPredicate::ule: 1658 // (float)int <= 4.4 --> int <= 4 1659 // (float)int <= -4.4 --> false 1660 if (rhs.isNegative()) { 1661 rewriter.replaceOpWithNewOp<ConstantIntOp>(op, /*value=*/false, 1662 /*width=*/1); 1663 return success(); 1664 } 1665 break; 1666 case CmpIPredicate::sle: 1667 // (float)int <= 4.4 --> int <= 4 1668 // (float)int <= -4.4 --> int < -4 1669 if (rhs.isNegative()) 1670 pred = CmpIPredicate::slt; 1671 break; 1672 case CmpIPredicate::ult: 1673 // (float)int < -4.4 --> false 1674 // (float)int < 4.4 --> int <= 4 1675 if (rhs.isNegative()) { 1676 rewriter.replaceOpWithNewOp<ConstantIntOp>(op, /*value=*/false, 1677 /*width=*/1); 1678 return success(); 1679 } 1680 pred = CmpIPredicate::ule; 1681 break; 1682 case CmpIPredicate::slt: 1683 // (float)int < -4.4 --> int < -4 1684 // (float)int < 4.4 --> int <= 4 1685 if (!rhs.isNegative()) 1686 pred = CmpIPredicate::sle; 1687 break; 1688 case CmpIPredicate::ugt: 1689 // (float)int > 4.4 --> int > 4 1690 // (float)int > -4.4 --> true 1691 if (rhs.isNegative()) { 1692 rewriter.replaceOpWithNewOp<ConstantIntOp>(op, /*value=*/true, 1693 /*width=*/1); 1694 return success(); 1695 } 1696 break; 1697 case CmpIPredicate::sgt: 1698 // (float)int > 4.4 --> int > 4 1699 // (float)int > -4.4 --> int >= -4 1700 if (rhs.isNegative()) 1701 pred = CmpIPredicate::sge; 1702 break; 1703 case CmpIPredicate::uge: 1704 // (float)int >= -4.4 --> true 1705 // (float)int >= 4.4 --> int > 4 1706 if (rhs.isNegative()) { 1707 rewriter.replaceOpWithNewOp<ConstantIntOp>(op, /*value=*/true, 1708 /*width=*/1); 1709 return success(); 1710 } 1711 pred = CmpIPredicate::ugt; 1712 break; 1713 case CmpIPredicate::sge: 1714 // (float)int >= -4.4 --> int >= -4 1715 // (float)int >= 4.4 --> int > 4 1716 if (!rhs.isNegative()) 1717 pred = CmpIPredicate::sgt; 1718 break; 1719 } 1720 } 1721 } 1722 1723 // Lower this FP comparison into an appropriate integer version of the 1724 // comparison. 1725 rewriter.replaceOpWithNewOp<CmpIOp>( 1726 op, pred, intVal, 1727 rewriter.create<ConstantOp>( 1728 op.getLoc(), intVal.getType(), 1729 rewriter.getIntegerAttr(intVal.getType(), rhsInt))); 1730 return success(); 1731 } 1732 }; 1733 1734 void arith::CmpFOp::getCanonicalizationPatterns(RewritePatternSet &patterns, 1735 MLIRContext *context) { 1736 patterns.insert<CmpFIntToFPConst>(context); 1737 } 1738 1739 //===----------------------------------------------------------------------===// 1740 // SelectOp 1741 //===----------------------------------------------------------------------===// 1742 1743 // Transforms a select of a boolean to arithmetic operations 1744 // 1745 // arith.select %arg, %x, %y : i1 1746 // 1747 // becomes 1748 // 1749 // and(%arg, %x) or and(!%arg, %y) 1750 struct SelectI1Simplify : public OpRewritePattern<arith::SelectOp> { 1751 using OpRewritePattern<arith::SelectOp>::OpRewritePattern; 1752 1753 LogicalResult matchAndRewrite(arith::SelectOp op, 1754 PatternRewriter &rewriter) const override { 1755 if (!op.getType().isInteger(1)) 1756 return failure(); 1757 1758 Value falseConstant = 1759 rewriter.create<arith::ConstantIntOp>(op.getLoc(), true, 1); 1760 Value notCondition = rewriter.create<arith::XOrIOp>( 1761 op.getLoc(), op.getCondition(), falseConstant); 1762 1763 Value trueVal = rewriter.create<arith::AndIOp>( 1764 op.getLoc(), op.getCondition(), op.getTrueValue()); 1765 Value falseVal = rewriter.create<arith::AndIOp>(op.getLoc(), notCondition, 1766 op.getFalseValue()); 1767 rewriter.replaceOpWithNewOp<arith::OrIOp>(op, trueVal, falseVal); 1768 return success(); 1769 } 1770 }; 1771 1772 // select %arg, %c1, %c0 => extui %arg 1773 struct SelectToExtUI : public OpRewritePattern<arith::SelectOp> { 1774 using OpRewritePattern<arith::SelectOp>::OpRewritePattern; 1775 1776 LogicalResult matchAndRewrite(arith::SelectOp op, 1777 PatternRewriter &rewriter) const override { 1778 // Cannot extui i1 to i1, or i1 to f32 1779 if (!op.getType().isa<IntegerType>() || op.getType().isInteger(1)) 1780 return failure(); 1781 1782 // select %x, c1, %c0 => extui %arg 1783 if (matchPattern(op.getTrueValue(), m_One()) && 1784 matchPattern(op.getFalseValue(), m_Zero())) { 1785 rewriter.replaceOpWithNewOp<arith::ExtUIOp>(op, op.getType(), 1786 op.getCondition()); 1787 return success(); 1788 } 1789 1790 // select %x, c0, %c1 => extui (xor %arg, true) 1791 if (matchPattern(op.getTrueValue(), m_Zero()) && 1792 matchPattern(op.getFalseValue(), m_One())) { 1793 rewriter.replaceOpWithNewOp<arith::ExtUIOp>( 1794 op, op.getType(), 1795 rewriter.create<arith::XOrIOp>( 1796 op.getLoc(), op.getCondition(), 1797 rewriter.create<arith::ConstantIntOp>( 1798 op.getLoc(), 1, op.getCondition().getType()))); 1799 return success(); 1800 } 1801 1802 return failure(); 1803 } 1804 }; 1805 1806 void arith::SelectOp::getCanonicalizationPatterns(RewritePatternSet &results, 1807 MLIRContext *context) { 1808 results.add<SelectI1Simplify, SelectToExtUI>(context); 1809 } 1810 1811 OpFoldResult arith::SelectOp::fold(ArrayRef<Attribute> operands) { 1812 Value trueVal = getTrueValue(); 1813 Value falseVal = getFalseValue(); 1814 if (trueVal == falseVal) 1815 return trueVal; 1816 1817 Value condition = getCondition(); 1818 1819 // select true, %0, %1 => %0 1820 if (matchPattern(condition, m_One())) 1821 return trueVal; 1822 1823 // select false, %0, %1 => %1 1824 if (matchPattern(condition, m_Zero())) 1825 return falseVal; 1826 1827 // select %x, true, false => %x 1828 if (getType().isInteger(1) && matchPattern(getTrueValue(), m_One()) && 1829 matchPattern(getFalseValue(), m_Zero())) 1830 return condition; 1831 1832 if (auto cmp = dyn_cast_or_null<arith::CmpIOp>(condition.getDefiningOp())) { 1833 auto pred = cmp.getPredicate(); 1834 if (pred == arith::CmpIPredicate::eq || pred == arith::CmpIPredicate::ne) { 1835 auto cmpLhs = cmp.getLhs(); 1836 auto cmpRhs = cmp.getRhs(); 1837 1838 // %0 = arith.cmpi eq, %arg0, %arg1 1839 // %1 = arith.select %0, %arg0, %arg1 => %arg1 1840 1841 // %0 = arith.cmpi ne, %arg0, %arg1 1842 // %1 = arith.select %0, %arg0, %arg1 => %arg0 1843 1844 if ((cmpLhs == trueVal && cmpRhs == falseVal) || 1845 (cmpRhs == trueVal && cmpLhs == falseVal)) 1846 return pred == arith::CmpIPredicate::ne ? trueVal : falseVal; 1847 } 1848 } 1849 return nullptr; 1850 } 1851 1852 ParseResult SelectOp::parse(OpAsmParser &parser, OperationState &result) { 1853 Type conditionType, resultType; 1854 SmallVector<OpAsmParser::UnresolvedOperand, 3> operands; 1855 if (parser.parseOperandList(operands, /*requiredOperandCount=*/3) || 1856 parser.parseOptionalAttrDict(result.attributes) || 1857 parser.parseColonType(resultType)) 1858 return failure(); 1859 1860 // Check for the explicit condition type if this is a masked tensor or vector. 1861 if (succeeded(parser.parseOptionalComma())) { 1862 conditionType = resultType; 1863 if (parser.parseType(resultType)) 1864 return failure(); 1865 } else { 1866 conditionType = parser.getBuilder().getI1Type(); 1867 } 1868 1869 result.addTypes(resultType); 1870 return parser.resolveOperands(operands, 1871 {conditionType, resultType, resultType}, 1872 parser.getNameLoc(), result.operands); 1873 } 1874 1875 void arith::SelectOp::print(OpAsmPrinter &p) { 1876 p << " " << getOperands(); 1877 p.printOptionalAttrDict((*this)->getAttrs()); 1878 p << " : "; 1879 if (ShapedType condType = getCondition().getType().dyn_cast<ShapedType>()) 1880 p << condType << ", "; 1881 p << getType(); 1882 } 1883 1884 LogicalResult arith::SelectOp::verify() { 1885 Type conditionType = getCondition().getType(); 1886 if (conditionType.isSignlessInteger(1)) 1887 return success(); 1888 1889 // If the result type is a vector or tensor, the type can be a mask with the 1890 // same elements. 1891 Type resultType = getType(); 1892 if (!resultType.isa<TensorType, VectorType>()) 1893 return emitOpError() << "expected condition to be a signless i1, but got " 1894 << conditionType; 1895 Type shapedConditionType = getI1SameShape(resultType); 1896 if (conditionType != shapedConditionType) { 1897 return emitOpError() << "expected condition type to have the same shape " 1898 "as the result type, expected " 1899 << shapedConditionType << ", but got " 1900 << conditionType; 1901 } 1902 return success(); 1903 } 1904 //===----------------------------------------------------------------------===// 1905 // ShLIOp 1906 //===----------------------------------------------------------------------===// 1907 1908 OpFoldResult arith::ShLIOp::fold(ArrayRef<Attribute> operands) { 1909 // Don't fold if shifting more than the bit width. 1910 bool bounded = false; 1911 auto result = constFoldBinaryOp<IntegerAttr>( 1912 operands, [&](const APInt &a, const APInt &b) { 1913 bounded = b.ule(b.getBitWidth()); 1914 return a.shl(b); 1915 }); 1916 return bounded ? result : Attribute(); 1917 } 1918 1919 //===----------------------------------------------------------------------===// 1920 // ShRUIOp 1921 //===----------------------------------------------------------------------===// 1922 1923 OpFoldResult arith::ShRUIOp::fold(ArrayRef<Attribute> operands) { 1924 // Don't fold if shifting more than the bit width. 1925 bool bounded = false; 1926 auto result = constFoldBinaryOp<IntegerAttr>( 1927 operands, [&](const APInt &a, const APInt &b) { 1928 bounded = b.ule(b.getBitWidth()); 1929 return a.lshr(b); 1930 }); 1931 return bounded ? result : Attribute(); 1932 } 1933 1934 //===----------------------------------------------------------------------===// 1935 // ShRSIOp 1936 //===----------------------------------------------------------------------===// 1937 1938 OpFoldResult arith::ShRSIOp::fold(ArrayRef<Attribute> operands) { 1939 // Don't fold if shifting more than the bit width. 1940 bool bounded = false; 1941 auto result = constFoldBinaryOp<IntegerAttr>( 1942 operands, [&](const APInt &a, const APInt &b) { 1943 bounded = b.ule(b.getBitWidth()); 1944 return a.ashr(b); 1945 }); 1946 return bounded ? result : Attribute(); 1947 } 1948 1949 //===----------------------------------------------------------------------===// 1950 // Atomic Enum 1951 //===----------------------------------------------------------------------===// 1952 1953 /// Returns the identity value attribute associated with an AtomicRMWKind op. 1954 Attribute mlir::arith::getIdentityValueAttr(AtomicRMWKind kind, Type resultType, 1955 OpBuilder &builder, Location loc) { 1956 switch (kind) { 1957 case AtomicRMWKind::maxf: 1958 return builder.getFloatAttr( 1959 resultType, 1960 APFloat::getInf(resultType.cast<FloatType>().getFloatSemantics(), 1961 /*Negative=*/true)); 1962 case AtomicRMWKind::addf: 1963 case AtomicRMWKind::addi: 1964 case AtomicRMWKind::maxu: 1965 case AtomicRMWKind::ori: 1966 return builder.getZeroAttr(resultType); 1967 case AtomicRMWKind::andi: 1968 return builder.getIntegerAttr( 1969 resultType, 1970 APInt::getAllOnes(resultType.cast<IntegerType>().getWidth())); 1971 case AtomicRMWKind::maxs: 1972 return builder.getIntegerAttr( 1973 resultType, 1974 APInt::getSignedMinValue(resultType.cast<IntegerType>().getWidth())); 1975 case AtomicRMWKind::minf: 1976 return builder.getFloatAttr( 1977 resultType, 1978 APFloat::getInf(resultType.cast<FloatType>().getFloatSemantics(), 1979 /*Negative=*/false)); 1980 case AtomicRMWKind::mins: 1981 return builder.getIntegerAttr( 1982 resultType, 1983 APInt::getSignedMaxValue(resultType.cast<IntegerType>().getWidth())); 1984 case AtomicRMWKind::minu: 1985 return builder.getIntegerAttr( 1986 resultType, 1987 APInt::getMaxValue(resultType.cast<IntegerType>().getWidth())); 1988 case AtomicRMWKind::muli: 1989 return builder.getIntegerAttr(resultType, 1); 1990 case AtomicRMWKind::mulf: 1991 return builder.getFloatAttr(resultType, 1); 1992 // TODO: Add remaining reduction operations. 1993 default: 1994 (void)emitOptionalError(loc, "Reduction operation type not supported"); 1995 break; 1996 } 1997 return nullptr; 1998 } 1999 2000 /// Returns the identity value associated with an AtomicRMWKind op. 2001 Value mlir::arith::getIdentityValue(AtomicRMWKind op, Type resultType, 2002 OpBuilder &builder, Location loc) { 2003 Attribute attr = getIdentityValueAttr(op, resultType, builder, loc); 2004 return builder.create<arith::ConstantOp>(loc, attr); 2005 } 2006 2007 /// Return the value obtained by applying the reduction operation kind 2008 /// associated with a binary AtomicRMWKind op to `lhs` and `rhs`. 2009 Value mlir::arith::getReductionOp(AtomicRMWKind op, OpBuilder &builder, 2010 Location loc, Value lhs, Value rhs) { 2011 switch (op) { 2012 case AtomicRMWKind::addf: 2013 return builder.create<arith::AddFOp>(loc, lhs, rhs); 2014 case AtomicRMWKind::addi: 2015 return builder.create<arith::AddIOp>(loc, lhs, rhs); 2016 case AtomicRMWKind::mulf: 2017 return builder.create<arith::MulFOp>(loc, lhs, rhs); 2018 case AtomicRMWKind::muli: 2019 return builder.create<arith::MulIOp>(loc, lhs, rhs); 2020 case AtomicRMWKind::maxf: 2021 return builder.create<arith::MaxFOp>(loc, lhs, rhs); 2022 case AtomicRMWKind::minf: 2023 return builder.create<arith::MinFOp>(loc, lhs, rhs); 2024 case AtomicRMWKind::maxs: 2025 return builder.create<arith::MaxSIOp>(loc, lhs, rhs); 2026 case AtomicRMWKind::mins: 2027 return builder.create<arith::MinSIOp>(loc, lhs, rhs); 2028 case AtomicRMWKind::maxu: 2029 return builder.create<arith::MaxUIOp>(loc, lhs, rhs); 2030 case AtomicRMWKind::minu: 2031 return builder.create<arith::MinUIOp>(loc, lhs, rhs); 2032 case AtomicRMWKind::ori: 2033 return builder.create<arith::OrIOp>(loc, lhs, rhs); 2034 case AtomicRMWKind::andi: 2035 return builder.create<arith::AndIOp>(loc, lhs, rhs); 2036 // TODO: Add remaining reduction operations. 2037 default: 2038 (void)emitOptionalError(loc, "Reduction operation type not supported"); 2039 break; 2040 } 2041 return nullptr; 2042 } 2043 2044 //===----------------------------------------------------------------------===// 2045 // DelinearizeIndexOp 2046 //===----------------------------------------------------------------------===// 2047 2048 void arith::DelinearizeIndexOp::build(OpBuilder &builder, 2049 OperationState &result, 2050 Value linear_index, 2051 ArrayRef<OpFoldResult> basis) { 2052 result.addTypes(SmallVector<Type>(basis.size(), builder.getIndexType())); 2053 result.addOperands(linear_index); 2054 SmallVector<Value> basisValues = 2055 llvm::to_vector(llvm::map_range(basis, [&](OpFoldResult ofr) -> Value { 2056 Optional<int64_t> staticDim = getConstantIntValue(ofr); 2057 if (staticDim.has_value()) 2058 return builder.create<arith::ConstantIndexOp>(result.location, 2059 *staticDim); 2060 return ofr.dyn_cast<Value>(); 2061 })); 2062 result.addOperands(basisValues); 2063 } 2064 2065 LogicalResult arith::DelinearizeIndexOp::verify() { 2066 if (getBasis().empty()) 2067 return emitOpError("basis should not be empty"); 2068 if (getNumResults() != getBasis().size()) 2069 return emitOpError("should return an index for each basis element"); 2070 return success(); 2071 } 2072 2073 //===----------------------------------------------------------------------===// 2074 // TableGen'd op method definitions 2075 //===----------------------------------------------------------------------===// 2076 2077 #define GET_OP_CLASSES 2078 #include "mlir/Dialect/Arithmetic/IR/ArithmeticOps.cpp.inc" 2079 2080 //===----------------------------------------------------------------------===// 2081 // TableGen'd enum attribute definitions 2082 //===----------------------------------------------------------------------===// 2083 2084 #include "mlir/Dialect/Arithmetic/IR/ArithmeticOpsEnums.cpp.inc" 2085