1 //===- ArithmeticOps.cpp - MLIR Arithmetic dialect ops implementation -----===// 2 // 3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. 4 // See https://llvm.org/LICENSE.txt for license information. 5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception 6 // 7 //===----------------------------------------------------------------------===// 8 9 #include <utility> 10 11 #include "mlir/Dialect/Arithmetic/IR/Arithmetic.h" 12 #include "mlir/Dialect/CommonFolders.h" 13 #include "mlir/IR/Builders.h" 14 #include "mlir/IR/Matchers.h" 15 #include "mlir/IR/OpImplementation.h" 16 #include "mlir/IR/PatternMatch.h" 17 #include "mlir/IR/TypeUtilities.h" 18 19 #include "llvm/ADT/APSInt.h" 20 21 using namespace mlir; 22 using namespace mlir::arith; 23 24 //===----------------------------------------------------------------------===// 25 // Pattern helpers 26 //===----------------------------------------------------------------------===// 27 28 static IntegerAttr addIntegerAttrs(PatternRewriter &builder, Value res, 29 Attribute lhs, Attribute rhs) { 30 return builder.getIntegerAttr(res.getType(), 31 lhs.cast<IntegerAttr>().getInt() + 32 rhs.cast<IntegerAttr>().getInt()); 33 } 34 35 static IntegerAttr subIntegerAttrs(PatternRewriter &builder, Value res, 36 Attribute lhs, Attribute rhs) { 37 return builder.getIntegerAttr(res.getType(), 38 lhs.cast<IntegerAttr>().getInt() - 39 rhs.cast<IntegerAttr>().getInt()); 40 } 41 42 /// Invert an integer comparison predicate. 43 arith::CmpIPredicate arith::invertPredicate(arith::CmpIPredicate pred) { 44 switch (pred) { 45 case arith::CmpIPredicate::eq: 46 return arith::CmpIPredicate::ne; 47 case arith::CmpIPredicate::ne: 48 return arith::CmpIPredicate::eq; 49 case arith::CmpIPredicate::slt: 50 return arith::CmpIPredicate::sge; 51 case arith::CmpIPredicate::sle: 52 return arith::CmpIPredicate::sgt; 53 case arith::CmpIPredicate::sgt: 54 return arith::CmpIPredicate::sle; 55 case arith::CmpIPredicate::sge: 56 return arith::CmpIPredicate::slt; 57 case arith::CmpIPredicate::ult: 58 return arith::CmpIPredicate::uge; 59 case arith::CmpIPredicate::ule: 60 return arith::CmpIPredicate::ugt; 61 case arith::CmpIPredicate::ugt: 62 return arith::CmpIPredicate::ule; 63 case arith::CmpIPredicate::uge: 64 return arith::CmpIPredicate::ult; 65 } 66 llvm_unreachable("unknown cmpi predicate kind"); 67 } 68 69 static arith::CmpIPredicateAttr invertPredicate(arith::CmpIPredicateAttr pred) { 70 return arith::CmpIPredicateAttr::get(pred.getContext(), 71 invertPredicate(pred.getValue())); 72 } 73 74 //===----------------------------------------------------------------------===// 75 // TableGen'd canonicalization patterns 76 //===----------------------------------------------------------------------===// 77 78 namespace { 79 #include "ArithmeticCanonicalization.inc" 80 } // namespace 81 82 //===----------------------------------------------------------------------===// 83 // ConstantOp 84 //===----------------------------------------------------------------------===// 85 86 void arith::ConstantOp::getAsmResultNames( 87 function_ref<void(Value, StringRef)> setNameFn) { 88 auto type = getType(); 89 if (auto intCst = getValue().dyn_cast<IntegerAttr>()) { 90 auto intType = type.dyn_cast<IntegerType>(); 91 92 // Sugar i1 constants with 'true' and 'false'. 93 if (intType && intType.getWidth() == 1) 94 return setNameFn(getResult(), (intCst.getInt() ? "true" : "false")); 95 96 // Otherwise, build a compex name with the value and type. 97 SmallString<32> specialNameBuffer; 98 llvm::raw_svector_ostream specialName(specialNameBuffer); 99 specialName << 'c' << intCst.getInt(); 100 if (intType) 101 specialName << '_' << type; 102 setNameFn(getResult(), specialName.str()); 103 } else { 104 setNameFn(getResult(), "cst"); 105 } 106 } 107 108 /// TODO: disallow arith.constant to return anything other than signless integer 109 /// or float like. 110 LogicalResult arith::ConstantOp::verify() { 111 auto type = getType(); 112 // The value's type must match the return type. 113 if (getValue().getType() != type) { 114 return emitOpError() << "value type " << getValue().getType() 115 << " must match return type: " << type; 116 } 117 // Integer values must be signless. 118 if (type.isa<IntegerType>() && !type.cast<IntegerType>().isSignless()) 119 return emitOpError("integer return type must be signless"); 120 // Any float or elements attribute are acceptable. 121 if (!getValue().isa<IntegerAttr, FloatAttr, ElementsAttr>()) { 122 return emitOpError( 123 "value must be an integer, float, or elements attribute"); 124 } 125 return success(); 126 } 127 128 bool arith::ConstantOp::isBuildableWith(Attribute value, Type type) { 129 // The value's type must be the same as the provided type. 130 if (value.getType() != type) 131 return false; 132 // Integer values must be signless. 133 if (type.isa<IntegerType>() && !type.cast<IntegerType>().isSignless()) 134 return false; 135 // Integer, float, and element attributes are buildable. 136 return value.isa<IntegerAttr, FloatAttr, ElementsAttr>(); 137 } 138 139 OpFoldResult arith::ConstantOp::fold(ArrayRef<Attribute> operands) { 140 return getValue(); 141 } 142 143 void arith::ConstantIntOp::build(OpBuilder &builder, OperationState &result, 144 int64_t value, unsigned width) { 145 auto type = builder.getIntegerType(width); 146 arith::ConstantOp::build(builder, result, type, 147 builder.getIntegerAttr(type, value)); 148 } 149 150 void arith::ConstantIntOp::build(OpBuilder &builder, OperationState &result, 151 int64_t value, Type type) { 152 assert(type.isSignlessInteger() && 153 "ConstantIntOp can only have signless integer type values"); 154 arith::ConstantOp::build(builder, result, type, 155 builder.getIntegerAttr(type, value)); 156 } 157 158 bool arith::ConstantIntOp::classof(Operation *op) { 159 if (auto constOp = dyn_cast_or_null<arith::ConstantOp>(op)) 160 return constOp.getType().isSignlessInteger(); 161 return false; 162 } 163 164 void arith::ConstantFloatOp::build(OpBuilder &builder, OperationState &result, 165 const APFloat &value, FloatType type) { 166 arith::ConstantOp::build(builder, result, type, 167 builder.getFloatAttr(type, value)); 168 } 169 170 bool arith::ConstantFloatOp::classof(Operation *op) { 171 if (auto constOp = dyn_cast_or_null<arith::ConstantOp>(op)) 172 return constOp.getType().isa<FloatType>(); 173 return false; 174 } 175 176 void arith::ConstantIndexOp::build(OpBuilder &builder, OperationState &result, 177 int64_t value) { 178 arith::ConstantOp::build(builder, result, builder.getIndexType(), 179 builder.getIndexAttr(value)); 180 } 181 182 bool arith::ConstantIndexOp::classof(Operation *op) { 183 if (auto constOp = dyn_cast_or_null<arith::ConstantOp>(op)) 184 return constOp.getType().isIndex(); 185 return false; 186 } 187 188 //===----------------------------------------------------------------------===// 189 // AddIOp 190 //===----------------------------------------------------------------------===// 191 192 OpFoldResult arith::AddIOp::fold(ArrayRef<Attribute> operands) { 193 // addi(x, 0) -> x 194 if (matchPattern(getRhs(), m_Zero())) 195 return getLhs(); 196 197 // addi(subi(a, b), b) -> a 198 if (auto sub = getLhs().getDefiningOp<SubIOp>()) 199 if (getRhs() == sub.getRhs()) 200 return sub.getLhs(); 201 202 // addi(b, subi(a, b)) -> a 203 if (auto sub = getRhs().getDefiningOp<SubIOp>()) 204 if (getLhs() == sub.getRhs()) 205 return sub.getLhs(); 206 207 return constFoldBinaryOp<IntegerAttr>( 208 operands, [](APInt a, const APInt &b) { return std::move(a) + b; }); 209 } 210 211 void arith::AddIOp::getCanonicalizationPatterns( 212 RewritePatternSet &patterns, MLIRContext *context) { 213 patterns.add<AddIAddConstant, AddISubConstantRHS, AddISubConstantLHS>( 214 context); 215 } 216 217 //===----------------------------------------------------------------------===// 218 // SubIOp 219 //===----------------------------------------------------------------------===// 220 221 OpFoldResult arith::SubIOp::fold(ArrayRef<Attribute> operands) { 222 // subi(x,x) -> 0 223 if (getOperand(0) == getOperand(1)) 224 return Builder(getContext()).getZeroAttr(getType()); 225 // subi(x,0) -> x 226 if (matchPattern(getRhs(), m_Zero())) 227 return getLhs(); 228 229 return constFoldBinaryOp<IntegerAttr>( 230 operands, [](APInt a, const APInt &b) { return std::move(a) - b; }); 231 } 232 233 void arith::SubIOp::getCanonicalizationPatterns( 234 RewritePatternSet &patterns, MLIRContext *context) { 235 patterns 236 .add<SubIRHSAddConstant, SubILHSAddConstant, SubIRHSSubConstantRHS, 237 SubIRHSSubConstantLHS, SubILHSSubConstantRHS, SubILHSSubConstantLHS>( 238 context); 239 } 240 241 //===----------------------------------------------------------------------===// 242 // MulIOp 243 //===----------------------------------------------------------------------===// 244 245 OpFoldResult arith::MulIOp::fold(ArrayRef<Attribute> operands) { 246 // muli(x, 0) -> 0 247 if (matchPattern(getRhs(), m_Zero())) 248 return getRhs(); 249 // muli(x, 1) -> x 250 if (matchPattern(getRhs(), m_One())) 251 return getOperand(0); 252 // TODO: Handle the overflow case. 253 254 // default folder 255 return constFoldBinaryOp<IntegerAttr>( 256 operands, [](const APInt &a, const APInt &b) { return a * b; }); 257 } 258 259 //===----------------------------------------------------------------------===// 260 // DivUIOp 261 //===----------------------------------------------------------------------===// 262 263 OpFoldResult arith::DivUIOp::fold(ArrayRef<Attribute> operands) { 264 // Don't fold if it would require a division by zero. 265 bool div0 = false; 266 auto result = 267 constFoldBinaryOp<IntegerAttr>(operands, [&](APInt a, const APInt &b) { 268 if (div0 || !b) { 269 div0 = true; 270 return a; 271 } 272 return a.udiv(b); 273 }); 274 275 // Fold out division by one. Assumes all tensors of all ones are splats. 276 if (auto rhs = operands[1].dyn_cast_or_null<IntegerAttr>()) { 277 if (rhs.getValue() == 1) 278 return getLhs(); 279 } else if (auto rhs = operands[1].dyn_cast_or_null<SplatElementsAttr>()) { 280 if (rhs.getSplatValue<IntegerAttr>().getValue() == 1) 281 return getLhs(); 282 } 283 284 return div0 ? Attribute() : result; 285 } 286 287 //===----------------------------------------------------------------------===// 288 // DivSIOp 289 //===----------------------------------------------------------------------===// 290 291 OpFoldResult arith::DivSIOp::fold(ArrayRef<Attribute> operands) { 292 // Don't fold if it would overflow or if it requires a division by zero. 293 bool overflowOrDiv0 = false; 294 auto result = 295 constFoldBinaryOp<IntegerAttr>(operands, [&](APInt a, const APInt &b) { 296 if (overflowOrDiv0 || !b) { 297 overflowOrDiv0 = true; 298 return a; 299 } 300 return a.sdiv_ov(b, overflowOrDiv0); 301 }); 302 303 // Fold out division by one. Assumes all tensors of all ones are splats. 304 if (auto rhs = operands[1].dyn_cast_or_null<IntegerAttr>()) { 305 if (rhs.getValue() == 1) 306 return getLhs(); 307 } else if (auto rhs = operands[1].dyn_cast_or_null<SplatElementsAttr>()) { 308 if (rhs.getSplatValue<IntegerAttr>().getValue() == 1) 309 return getLhs(); 310 } 311 312 return overflowOrDiv0 ? Attribute() : result; 313 } 314 315 //===----------------------------------------------------------------------===// 316 // Ceil and floor division folding helpers 317 //===----------------------------------------------------------------------===// 318 319 static APInt signedCeilNonnegInputs(const APInt &a, const APInt &b, 320 bool &overflow) { 321 // Returns (a-1)/b + 1 322 APInt one(a.getBitWidth(), 1, true); // Signed value 1. 323 APInt val = a.ssub_ov(one, overflow).sdiv_ov(b, overflow); 324 return val.sadd_ov(one, overflow); 325 } 326 327 //===----------------------------------------------------------------------===// 328 // CeilDivUIOp 329 //===----------------------------------------------------------------------===// 330 331 OpFoldResult arith::CeilDivUIOp::fold(ArrayRef<Attribute> operands) { 332 bool overflowOrDiv0 = false; 333 auto result = 334 constFoldBinaryOp<IntegerAttr>(operands, [&](APInt a, const APInt &b) { 335 if (overflowOrDiv0 || !b) { 336 overflowOrDiv0 = true; 337 return a; 338 } 339 APInt quotient = a.udiv(b); 340 if (!a.urem(b)) 341 return quotient; 342 APInt one(a.getBitWidth(), 1, true); 343 return quotient.uadd_ov(one, overflowOrDiv0); 344 }); 345 // Fold out ceil division by one. Assumes all tensors of all ones are 346 // splats. 347 if (auto rhs = operands[1].dyn_cast_or_null<IntegerAttr>()) { 348 if (rhs.getValue() == 1) 349 return getLhs(); 350 } else if (auto rhs = operands[1].dyn_cast_or_null<SplatElementsAttr>()) { 351 if (rhs.getSplatValue<IntegerAttr>().getValue() == 1) 352 return getLhs(); 353 } 354 355 return overflowOrDiv0 ? Attribute() : result; 356 } 357 358 //===----------------------------------------------------------------------===// 359 // CeilDivSIOp 360 //===----------------------------------------------------------------------===// 361 362 OpFoldResult arith::CeilDivSIOp::fold(ArrayRef<Attribute> operands) { 363 // Don't fold if it would overflow or if it requires a division by zero. 364 bool overflowOrDiv0 = false; 365 auto result = 366 constFoldBinaryOp<IntegerAttr>(operands, [&](APInt a, const APInt &b) { 367 if (overflowOrDiv0 || !b) { 368 overflowOrDiv0 = true; 369 return a; 370 } 371 if (!a) 372 return a; 373 // After this point we know that neither a or b are zero. 374 unsigned bits = a.getBitWidth(); 375 APInt zero = APInt::getZero(bits); 376 bool aGtZero = a.sgt(zero); 377 bool bGtZero = b.sgt(zero); 378 if (aGtZero && bGtZero) { 379 // Both positive, return ceil(a, b). 380 return signedCeilNonnegInputs(a, b, overflowOrDiv0); 381 } 382 if (!aGtZero && !bGtZero) { 383 // Both negative, return ceil(-a, -b). 384 APInt posA = zero.ssub_ov(a, overflowOrDiv0); 385 APInt posB = zero.ssub_ov(b, overflowOrDiv0); 386 return signedCeilNonnegInputs(posA, posB, overflowOrDiv0); 387 } 388 if (!aGtZero && bGtZero) { 389 // A is negative, b is positive, return - ( -a / b). 390 APInt posA = zero.ssub_ov(a, overflowOrDiv0); 391 APInt div = posA.sdiv_ov(b, overflowOrDiv0); 392 return zero.ssub_ov(div, overflowOrDiv0); 393 } 394 // A is positive, b is negative, return - (a / -b). 395 APInt posB = zero.ssub_ov(b, overflowOrDiv0); 396 APInt div = a.sdiv_ov(posB, overflowOrDiv0); 397 return zero.ssub_ov(div, overflowOrDiv0); 398 }); 399 400 // Fold out ceil division by one. Assumes all tensors of all ones are 401 // splats. 402 if (auto rhs = operands[1].dyn_cast_or_null<IntegerAttr>()) { 403 if (rhs.getValue() == 1) 404 return getLhs(); 405 } else if (auto rhs = operands[1].dyn_cast_or_null<SplatElementsAttr>()) { 406 if (rhs.getSplatValue<IntegerAttr>().getValue() == 1) 407 return getLhs(); 408 } 409 410 return overflowOrDiv0 ? Attribute() : result; 411 } 412 413 //===----------------------------------------------------------------------===// 414 // FloorDivSIOp 415 //===----------------------------------------------------------------------===// 416 417 OpFoldResult arith::FloorDivSIOp::fold(ArrayRef<Attribute> operands) { 418 // Don't fold if it would overflow or if it requires a division by zero. 419 bool overflowOrDiv0 = false; 420 auto result = 421 constFoldBinaryOp<IntegerAttr>(operands, [&](APInt a, const APInt &b) { 422 if (overflowOrDiv0 || !b) { 423 overflowOrDiv0 = true; 424 return a; 425 } 426 if (!a) 427 return a; 428 // After this point we know that neither a or b are zero. 429 unsigned bits = a.getBitWidth(); 430 APInt zero = APInt::getZero(bits); 431 bool aGtZero = a.sgt(zero); 432 bool bGtZero = b.sgt(zero); 433 if (aGtZero && bGtZero) { 434 // Both positive, return a / b. 435 return a.sdiv_ov(b, overflowOrDiv0); 436 } 437 if (!aGtZero && !bGtZero) { 438 // Both negative, return -a / -b. 439 APInt posA = zero.ssub_ov(a, overflowOrDiv0); 440 APInt posB = zero.ssub_ov(b, overflowOrDiv0); 441 return posA.sdiv_ov(posB, overflowOrDiv0); 442 } 443 if (!aGtZero && bGtZero) { 444 // A is negative, b is positive, return - ceil(-a, b). 445 APInt posA = zero.ssub_ov(a, overflowOrDiv0); 446 APInt ceil = signedCeilNonnegInputs(posA, b, overflowOrDiv0); 447 return zero.ssub_ov(ceil, overflowOrDiv0); 448 } 449 // A is positive, b is negative, return - ceil(a, -b). 450 APInt posB = zero.ssub_ov(b, overflowOrDiv0); 451 APInt ceil = signedCeilNonnegInputs(a, posB, overflowOrDiv0); 452 return zero.ssub_ov(ceil, overflowOrDiv0); 453 }); 454 455 // Fold out floor division by one. Assumes all tensors of all ones are 456 // splats. 457 if (auto rhs = operands[1].dyn_cast_or_null<IntegerAttr>()) { 458 if (rhs.getValue() == 1) 459 return getLhs(); 460 } else if (auto rhs = operands[1].dyn_cast_or_null<SplatElementsAttr>()) { 461 if (rhs.getSplatValue<IntegerAttr>().getValue() == 1) 462 return getLhs(); 463 } 464 465 return overflowOrDiv0 ? Attribute() : result; 466 } 467 468 //===----------------------------------------------------------------------===// 469 // RemUIOp 470 //===----------------------------------------------------------------------===// 471 472 OpFoldResult arith::RemUIOp::fold(ArrayRef<Attribute> operands) { 473 auto rhs = operands.back().dyn_cast_or_null<IntegerAttr>(); 474 if (!rhs) 475 return {}; 476 auto rhsValue = rhs.getValue(); 477 478 // x % 1 = 0 479 if (rhsValue.isOneValue()) 480 return IntegerAttr::get(rhs.getType(), APInt(rhsValue.getBitWidth(), 0)); 481 482 // Don't fold if it requires division by zero. 483 if (rhsValue.isNullValue()) 484 return {}; 485 486 auto lhs = operands.front().dyn_cast_or_null<IntegerAttr>(); 487 if (!lhs) 488 return {}; 489 return IntegerAttr::get(lhs.getType(), lhs.getValue().urem(rhsValue)); 490 } 491 492 //===----------------------------------------------------------------------===// 493 // RemSIOp 494 //===----------------------------------------------------------------------===// 495 496 OpFoldResult arith::RemSIOp::fold(ArrayRef<Attribute> operands) { 497 auto rhs = operands.back().dyn_cast_or_null<IntegerAttr>(); 498 if (!rhs) 499 return {}; 500 auto rhsValue = rhs.getValue(); 501 502 // x % 1 = 0 503 if (rhsValue.isOneValue()) 504 return IntegerAttr::get(rhs.getType(), APInt(rhsValue.getBitWidth(), 0)); 505 506 // Don't fold if it requires division by zero. 507 if (rhsValue.isNullValue()) 508 return {}; 509 510 auto lhs = operands.front().dyn_cast_or_null<IntegerAttr>(); 511 if (!lhs) 512 return {}; 513 return IntegerAttr::get(lhs.getType(), lhs.getValue().srem(rhsValue)); 514 } 515 516 //===----------------------------------------------------------------------===// 517 // AndIOp 518 //===----------------------------------------------------------------------===// 519 520 OpFoldResult arith::AndIOp::fold(ArrayRef<Attribute> operands) { 521 /// and(x, 0) -> 0 522 if (matchPattern(getRhs(), m_Zero())) 523 return getRhs(); 524 /// and(x, allOnes) -> x 525 APInt intValue; 526 if (matchPattern(getRhs(), m_ConstantInt(&intValue)) && intValue.isAllOnes()) 527 return getLhs(); 528 529 return constFoldBinaryOp<IntegerAttr>( 530 operands, [](APInt a, const APInt &b) { return std::move(a) & b; }); 531 } 532 533 //===----------------------------------------------------------------------===// 534 // OrIOp 535 //===----------------------------------------------------------------------===// 536 537 OpFoldResult arith::OrIOp::fold(ArrayRef<Attribute> operands) { 538 /// or(x, 0) -> x 539 if (matchPattern(getRhs(), m_Zero())) 540 return getLhs(); 541 /// or(x, <all ones>) -> <all ones> 542 if (auto rhsAttr = operands[1].dyn_cast_or_null<IntegerAttr>()) 543 if (rhsAttr.getValue().isAllOnes()) 544 return rhsAttr; 545 546 return constFoldBinaryOp<IntegerAttr>( 547 operands, [](APInt a, const APInt &b) { return std::move(a) | b; }); 548 } 549 550 //===----------------------------------------------------------------------===// 551 // XOrIOp 552 //===----------------------------------------------------------------------===// 553 554 OpFoldResult arith::XOrIOp::fold(ArrayRef<Attribute> operands) { 555 /// xor(x, 0) -> x 556 if (matchPattern(getRhs(), m_Zero())) 557 return getLhs(); 558 /// xor(x, x) -> 0 559 if (getLhs() == getRhs()) 560 return Builder(getContext()).getZeroAttr(getType()); 561 /// xor(xor(x, a), a) -> x 562 if (arith::XOrIOp prev = getLhs().getDefiningOp<arith::XOrIOp>()) 563 if (prev.getRhs() == getRhs()) 564 return prev.getLhs(); 565 566 return constFoldBinaryOp<IntegerAttr>( 567 operands, [](APInt a, const APInt &b) { return std::move(a) ^ b; }); 568 } 569 570 void arith::XOrIOp::getCanonicalizationPatterns( 571 RewritePatternSet &patterns, MLIRContext *context) { 572 patterns.add<XOrINotCmpI>(context); 573 } 574 575 //===----------------------------------------------------------------------===// 576 // AddFOp 577 //===----------------------------------------------------------------------===// 578 579 OpFoldResult arith::AddFOp::fold(ArrayRef<Attribute> operands) { 580 // addf(x, -0) -> x 581 if (matchPattern(getRhs(), m_NegZeroFloat())) 582 return getLhs(); 583 584 return constFoldBinaryOp<FloatAttr>( 585 operands, [](const APFloat &a, const APFloat &b) { return a + b; }); 586 } 587 588 //===----------------------------------------------------------------------===// 589 // SubFOp 590 //===----------------------------------------------------------------------===// 591 592 OpFoldResult arith::SubFOp::fold(ArrayRef<Attribute> operands) { 593 // subf(x, +0) -> x 594 if (matchPattern(getRhs(), m_PosZeroFloat())) 595 return getLhs(); 596 597 return constFoldBinaryOp<FloatAttr>( 598 operands, [](const APFloat &a, const APFloat &b) { return a - b; }); 599 } 600 601 //===----------------------------------------------------------------------===// 602 // MaxFOp 603 //===----------------------------------------------------------------------===// 604 605 OpFoldResult arith::MaxFOp::fold(ArrayRef<Attribute> operands) { 606 assert(operands.size() == 2 && "maxf takes two operands"); 607 608 // maxf(x,x) -> x 609 if (getLhs() == getRhs()) 610 return getRhs(); 611 612 // maxf(x, -inf) -> x 613 if (matchPattern(getRhs(), m_NegInfFloat())) 614 return getLhs(); 615 616 return constFoldBinaryOp<FloatAttr>( 617 operands, 618 [](const APFloat &a, const APFloat &b) { return llvm::maximum(a, b); }); 619 } 620 621 //===----------------------------------------------------------------------===// 622 // MaxSIOp 623 //===----------------------------------------------------------------------===// 624 625 OpFoldResult MaxSIOp::fold(ArrayRef<Attribute> operands) { 626 assert(operands.size() == 2 && "binary operation takes two operands"); 627 628 // maxsi(x,x) -> x 629 if (getLhs() == getRhs()) 630 return getRhs(); 631 632 APInt intValue; 633 // maxsi(x,MAX_INT) -> MAX_INT 634 if (matchPattern(getRhs(), m_ConstantInt(&intValue)) && 635 intValue.isMaxSignedValue()) 636 return getRhs(); 637 638 // maxsi(x, MIN_INT) -> x 639 if (matchPattern(getRhs(), m_ConstantInt(&intValue)) && 640 intValue.isMinSignedValue()) 641 return getLhs(); 642 643 return constFoldBinaryOp<IntegerAttr>(operands, 644 [](const APInt &a, const APInt &b) { 645 return llvm::APIntOps::smax(a, b); 646 }); 647 } 648 649 //===----------------------------------------------------------------------===// 650 // MaxUIOp 651 //===----------------------------------------------------------------------===// 652 653 OpFoldResult MaxUIOp::fold(ArrayRef<Attribute> operands) { 654 assert(operands.size() == 2 && "binary operation takes two operands"); 655 656 // maxui(x,x) -> x 657 if (getLhs() == getRhs()) 658 return getRhs(); 659 660 APInt intValue; 661 // maxui(x,MAX_INT) -> MAX_INT 662 if (matchPattern(getRhs(), m_ConstantInt(&intValue)) && intValue.isMaxValue()) 663 return getRhs(); 664 665 // maxui(x, MIN_INT) -> x 666 if (matchPattern(getRhs(), m_ConstantInt(&intValue)) && intValue.isMinValue()) 667 return getLhs(); 668 669 return constFoldBinaryOp<IntegerAttr>(operands, 670 [](const APInt &a, const APInt &b) { 671 return llvm::APIntOps::umax(a, b); 672 }); 673 } 674 675 //===----------------------------------------------------------------------===// 676 // MinFOp 677 //===----------------------------------------------------------------------===// 678 679 OpFoldResult arith::MinFOp::fold(ArrayRef<Attribute> operands) { 680 assert(operands.size() == 2 && "minf takes two operands"); 681 682 // minf(x,x) -> x 683 if (getLhs() == getRhs()) 684 return getRhs(); 685 686 // minf(x, +inf) -> x 687 if (matchPattern(getRhs(), m_PosInfFloat())) 688 return getLhs(); 689 690 return constFoldBinaryOp<FloatAttr>( 691 operands, 692 [](const APFloat &a, const APFloat &b) { return llvm::minimum(a, b); }); 693 } 694 695 //===----------------------------------------------------------------------===// 696 // MinSIOp 697 //===----------------------------------------------------------------------===// 698 699 OpFoldResult MinSIOp::fold(ArrayRef<Attribute> operands) { 700 assert(operands.size() == 2 && "binary operation takes two operands"); 701 702 // minsi(x,x) -> x 703 if (getLhs() == getRhs()) 704 return getRhs(); 705 706 APInt intValue; 707 // minsi(x,MIN_INT) -> MIN_INT 708 if (matchPattern(getRhs(), m_ConstantInt(&intValue)) && 709 intValue.isMinSignedValue()) 710 return getRhs(); 711 712 // minsi(x, MAX_INT) -> x 713 if (matchPattern(getRhs(), m_ConstantInt(&intValue)) && 714 intValue.isMaxSignedValue()) 715 return getLhs(); 716 717 return constFoldBinaryOp<IntegerAttr>(operands, 718 [](const APInt &a, const APInt &b) { 719 return llvm::APIntOps::smin(a, b); 720 }); 721 } 722 723 //===----------------------------------------------------------------------===// 724 // MinUIOp 725 //===----------------------------------------------------------------------===// 726 727 OpFoldResult MinUIOp::fold(ArrayRef<Attribute> operands) { 728 assert(operands.size() == 2 && "binary operation takes two operands"); 729 730 // minui(x,x) -> x 731 if (getLhs() == getRhs()) 732 return getRhs(); 733 734 APInt intValue; 735 // minui(x,MIN_INT) -> MIN_INT 736 if (matchPattern(getRhs(), m_ConstantInt(&intValue)) && intValue.isMinValue()) 737 return getRhs(); 738 739 // minui(x, MAX_INT) -> x 740 if (matchPattern(getRhs(), m_ConstantInt(&intValue)) && intValue.isMaxValue()) 741 return getLhs(); 742 743 return constFoldBinaryOp<IntegerAttr>(operands, 744 [](const APInt &a, const APInt &b) { 745 return llvm::APIntOps::umin(a, b); 746 }); 747 } 748 749 //===----------------------------------------------------------------------===// 750 // MulFOp 751 //===----------------------------------------------------------------------===// 752 753 OpFoldResult arith::MulFOp::fold(ArrayRef<Attribute> operands) { 754 APFloat floatValue(0.0f), inverseValue(0.0f); 755 // mulf(x, 1) -> x 756 if (matchPattern(getRhs(), m_OneFloat())) 757 return getLhs(); 758 759 // mulf(1, x) -> x 760 if (matchPattern(getLhs(), m_OneFloat())) 761 return getRhs(); 762 763 return constFoldBinaryOp<FloatAttr>( 764 operands, [](const APFloat &a, const APFloat &b) { return a * b; }); 765 } 766 767 //===----------------------------------------------------------------------===// 768 // DivFOp 769 //===----------------------------------------------------------------------===// 770 771 OpFoldResult arith::DivFOp::fold(ArrayRef<Attribute> operands) { 772 APFloat floatValue(0.0f), inverseValue(0.0f); 773 // divf(x, 1) -> x 774 if (matchPattern(getRhs(), m_OneFloat())) 775 return getLhs(); 776 777 return constFoldBinaryOp<FloatAttr>( 778 operands, [](const APFloat &a, const APFloat &b) { return a / b; }); 779 } 780 781 //===----------------------------------------------------------------------===// 782 // Utility functions for verifying cast ops 783 //===----------------------------------------------------------------------===// 784 785 template <typename... Types> 786 using type_list = std::tuple<Types...> *; 787 788 /// Returns a non-null type only if the provided type is one of the allowed 789 /// types or one of the allowed shaped types of the allowed types. Returns the 790 /// element type if a valid shaped type is provided. 791 template <typename... ShapedTypes, typename... ElementTypes> 792 static Type getUnderlyingType(Type type, type_list<ShapedTypes...>, 793 type_list<ElementTypes...>) { 794 if (type.isa<ShapedType>() && !type.isa<ShapedTypes...>()) 795 return {}; 796 797 auto underlyingType = getElementTypeOrSelf(type); 798 if (!underlyingType.isa<ElementTypes...>()) 799 return {}; 800 801 return underlyingType; 802 } 803 804 /// Get allowed underlying types for vectors and tensors. 805 template <typename... ElementTypes> 806 static Type getTypeIfLike(Type type) { 807 return getUnderlyingType(type, type_list<VectorType, TensorType>(), 808 type_list<ElementTypes...>()); 809 } 810 811 /// Get allowed underlying types for vectors, tensors, and memrefs. 812 template <typename... ElementTypes> 813 static Type getTypeIfLikeOrMemRef(Type type) { 814 return getUnderlyingType(type, 815 type_list<VectorType, TensorType, MemRefType>(), 816 type_list<ElementTypes...>()); 817 } 818 819 static bool areValidCastInputsAndOutputs(TypeRange inputs, TypeRange outputs) { 820 return inputs.size() == 1 && outputs.size() == 1 && 821 succeeded(verifyCompatibleShapes(inputs.front(), outputs.front())); 822 } 823 824 //===----------------------------------------------------------------------===// 825 // Verifiers for integer and floating point extension/truncation ops 826 //===----------------------------------------------------------------------===// 827 828 // Extend ops can only extend to a wider type. 829 template <typename ValType, typename Op> 830 static LogicalResult verifyExtOp(Op op) { 831 Type srcType = getElementTypeOrSelf(op.getIn().getType()); 832 Type dstType = getElementTypeOrSelf(op.getType()); 833 834 if (srcType.cast<ValType>().getWidth() >= dstType.cast<ValType>().getWidth()) 835 return op.emitError("result type ") 836 << dstType << " must be wider than operand type " << srcType; 837 838 return success(); 839 } 840 841 // Truncate ops can only truncate to a shorter type. 842 template <typename ValType, typename Op> 843 static LogicalResult verifyTruncateOp(Op op) { 844 Type srcType = getElementTypeOrSelf(op.getIn().getType()); 845 Type dstType = getElementTypeOrSelf(op.getType()); 846 847 if (srcType.cast<ValType>().getWidth() <= dstType.cast<ValType>().getWidth()) 848 return op.emitError("result type ") 849 << dstType << " must be shorter than operand type " << srcType; 850 851 return success(); 852 } 853 854 /// Validate a cast that changes the width of a type. 855 template <template <typename> class WidthComparator, typename... ElementTypes> 856 static bool checkWidthChangeCast(TypeRange inputs, TypeRange outputs) { 857 if (!areValidCastInputsAndOutputs(inputs, outputs)) 858 return false; 859 860 auto srcType = getTypeIfLike<ElementTypes...>(inputs.front()); 861 auto dstType = getTypeIfLike<ElementTypes...>(outputs.front()); 862 if (!srcType || !dstType) 863 return false; 864 865 return WidthComparator<unsigned>()(dstType.getIntOrFloatBitWidth(), 866 srcType.getIntOrFloatBitWidth()); 867 } 868 869 //===----------------------------------------------------------------------===// 870 // ExtUIOp 871 //===----------------------------------------------------------------------===// 872 873 OpFoldResult arith::ExtUIOp::fold(ArrayRef<Attribute> operands) { 874 if (auto lhs = operands[0].dyn_cast_or_null<IntegerAttr>()) 875 return IntegerAttr::get( 876 getType(), lhs.getValue().zext(getType().getIntOrFloatBitWidth())); 877 878 if (auto lhs = getIn().getDefiningOp<ExtUIOp>()) { 879 getInMutable().assign(lhs.getIn()); 880 return getResult(); 881 } 882 883 return {}; 884 } 885 886 bool arith::ExtUIOp::areCastCompatible(TypeRange inputs, TypeRange outputs) { 887 return checkWidthChangeCast<std::greater, IntegerType>(inputs, outputs); 888 } 889 890 LogicalResult arith::ExtUIOp::verify() { 891 return verifyExtOp<IntegerType>(*this); 892 } 893 894 //===----------------------------------------------------------------------===// 895 // ExtSIOp 896 //===----------------------------------------------------------------------===// 897 898 OpFoldResult arith::ExtSIOp::fold(ArrayRef<Attribute> operands) { 899 if (auto lhs = operands[0].dyn_cast_or_null<IntegerAttr>()) 900 return IntegerAttr::get( 901 getType(), lhs.getValue().sext(getType().getIntOrFloatBitWidth())); 902 903 if (auto lhs = getIn().getDefiningOp<ExtSIOp>()) { 904 getInMutable().assign(lhs.getIn()); 905 return getResult(); 906 } 907 908 return {}; 909 } 910 911 bool arith::ExtSIOp::areCastCompatible(TypeRange inputs, TypeRange outputs) { 912 return checkWidthChangeCast<std::greater, IntegerType>(inputs, outputs); 913 } 914 915 void arith::ExtSIOp::getCanonicalizationPatterns( 916 RewritePatternSet &patterns, MLIRContext *context) { 917 patterns.add<ExtSIOfExtUI>(context); 918 } 919 920 LogicalResult arith::ExtSIOp::verify() { 921 return verifyExtOp<IntegerType>(*this); 922 } 923 924 //===----------------------------------------------------------------------===// 925 // ExtFOp 926 //===----------------------------------------------------------------------===// 927 928 bool arith::ExtFOp::areCastCompatible(TypeRange inputs, TypeRange outputs) { 929 return checkWidthChangeCast<std::greater, FloatType>(inputs, outputs); 930 } 931 932 LogicalResult arith::ExtFOp::verify() { return verifyExtOp<FloatType>(*this); } 933 934 //===----------------------------------------------------------------------===// 935 // TruncIOp 936 //===----------------------------------------------------------------------===// 937 938 OpFoldResult arith::TruncIOp::fold(ArrayRef<Attribute> operands) { 939 assert(operands.size() == 1 && "unary operation takes one operand"); 940 941 // trunci(zexti(a)) -> a 942 // trunci(sexti(a)) -> a 943 if (matchPattern(getOperand(), m_Op<arith::ExtUIOp>()) || 944 matchPattern(getOperand(), m_Op<arith::ExtSIOp>())) 945 return getOperand().getDefiningOp()->getOperand(0); 946 947 // trunci(trunci(a)) -> trunci(a)) 948 if (matchPattern(getOperand(), m_Op<arith::TruncIOp>())) { 949 setOperand(getOperand().getDefiningOp()->getOperand(0)); 950 return getResult(); 951 } 952 953 if (!operands[0]) 954 return {}; 955 956 if (auto lhs = operands[0].dyn_cast<IntegerAttr>()) { 957 return IntegerAttr::get( 958 getType(), lhs.getValue().trunc(getType().getIntOrFloatBitWidth())); 959 } 960 961 return {}; 962 } 963 964 bool arith::TruncIOp::areCastCompatible(TypeRange inputs, TypeRange outputs) { 965 return checkWidthChangeCast<std::less, IntegerType>(inputs, outputs); 966 } 967 968 LogicalResult arith::TruncIOp::verify() { 969 return verifyTruncateOp<IntegerType>(*this); 970 } 971 972 //===----------------------------------------------------------------------===// 973 // TruncFOp 974 //===----------------------------------------------------------------------===// 975 976 /// Perform safe const propagation for truncf, i.e. only propagate if FP value 977 /// can be represented without precision loss or rounding. 978 OpFoldResult arith::TruncFOp::fold(ArrayRef<Attribute> operands) { 979 assert(operands.size() == 1 && "unary operation takes one operand"); 980 981 auto constOperand = operands.front(); 982 if (!constOperand || !constOperand.isa<FloatAttr>()) 983 return {}; 984 985 // Convert to target type via 'double'. 986 double sourceValue = 987 constOperand.dyn_cast<FloatAttr>().getValue().convertToDouble(); 988 auto targetAttr = FloatAttr::get(getType(), sourceValue); 989 990 // Propagate if constant's value does not change after truncation. 991 if (sourceValue == targetAttr.getValue().convertToDouble()) 992 return targetAttr; 993 994 return {}; 995 } 996 997 bool arith::TruncFOp::areCastCompatible(TypeRange inputs, TypeRange outputs) { 998 return checkWidthChangeCast<std::less, FloatType>(inputs, outputs); 999 } 1000 1001 LogicalResult arith::TruncFOp::verify() { 1002 return verifyTruncateOp<FloatType>(*this); 1003 } 1004 1005 //===----------------------------------------------------------------------===// 1006 // AndIOp 1007 //===----------------------------------------------------------------------===// 1008 1009 void arith::AndIOp::getCanonicalizationPatterns( 1010 RewritePatternSet &patterns, MLIRContext *context) { 1011 patterns.add<AndOfExtUI, AndOfExtSI>(context); 1012 } 1013 1014 //===----------------------------------------------------------------------===// 1015 // OrIOp 1016 //===----------------------------------------------------------------------===// 1017 1018 void arith::OrIOp::getCanonicalizationPatterns( 1019 RewritePatternSet &patterns, MLIRContext *context) { 1020 patterns.add<OrOfExtUI, OrOfExtSI>(context); 1021 } 1022 1023 //===----------------------------------------------------------------------===// 1024 // Verifiers for casts between integers and floats. 1025 //===----------------------------------------------------------------------===// 1026 1027 template <typename From, typename To> 1028 static bool checkIntFloatCast(TypeRange inputs, TypeRange outputs) { 1029 if (!areValidCastInputsAndOutputs(inputs, outputs)) 1030 return false; 1031 1032 auto srcType = getTypeIfLike<From>(inputs.front()); 1033 auto dstType = getTypeIfLike<To>(outputs.back()); 1034 1035 return srcType && dstType; 1036 } 1037 1038 //===----------------------------------------------------------------------===// 1039 // UIToFPOp 1040 //===----------------------------------------------------------------------===// 1041 1042 bool arith::UIToFPOp::areCastCompatible(TypeRange inputs, TypeRange outputs) { 1043 return checkIntFloatCast<IntegerType, FloatType>(inputs, outputs); 1044 } 1045 1046 OpFoldResult arith::UIToFPOp::fold(ArrayRef<Attribute> operands) { 1047 if (auto lhs = operands[0].dyn_cast_or_null<IntegerAttr>()) { 1048 const APInt &api = lhs.getValue(); 1049 FloatType floatTy = getType().cast<FloatType>(); 1050 APFloat apf(floatTy.getFloatSemantics(), 1051 APInt::getZero(floatTy.getWidth())); 1052 apf.convertFromAPInt(api, /*IsSigned=*/false, APFloat::rmNearestTiesToEven); 1053 return FloatAttr::get(floatTy, apf); 1054 } 1055 return {}; 1056 } 1057 1058 //===----------------------------------------------------------------------===// 1059 // SIToFPOp 1060 //===----------------------------------------------------------------------===// 1061 1062 bool arith::SIToFPOp::areCastCompatible(TypeRange inputs, TypeRange outputs) { 1063 return checkIntFloatCast<IntegerType, FloatType>(inputs, outputs); 1064 } 1065 1066 OpFoldResult arith::SIToFPOp::fold(ArrayRef<Attribute> operands) { 1067 if (auto lhs = operands[0].dyn_cast_or_null<IntegerAttr>()) { 1068 const APInt &api = lhs.getValue(); 1069 FloatType floatTy = getType().cast<FloatType>(); 1070 APFloat apf(floatTy.getFloatSemantics(), 1071 APInt::getZero(floatTy.getWidth())); 1072 apf.convertFromAPInt(api, /*IsSigned=*/true, APFloat::rmNearestTiesToEven); 1073 return FloatAttr::get(floatTy, apf); 1074 } 1075 return {}; 1076 } 1077 //===----------------------------------------------------------------------===// 1078 // FPToUIOp 1079 //===----------------------------------------------------------------------===// 1080 1081 bool arith::FPToUIOp::areCastCompatible(TypeRange inputs, TypeRange outputs) { 1082 return checkIntFloatCast<FloatType, IntegerType>(inputs, outputs); 1083 } 1084 1085 OpFoldResult arith::FPToUIOp::fold(ArrayRef<Attribute> operands) { 1086 if (auto lhs = operands[0].dyn_cast_or_null<FloatAttr>()) { 1087 const APFloat &apf = lhs.getValue(); 1088 IntegerType intTy = getType().cast<IntegerType>(); 1089 bool ignored; 1090 APSInt api(intTy.getWidth(), /*isUnsigned=*/true); 1091 if (APFloat::opInvalidOp == 1092 apf.convertToInteger(api, APFloat::rmTowardZero, &ignored)) { 1093 // Undefined behavior invoked - the destination type can't represent 1094 // the input constant. 1095 return {}; 1096 } 1097 return IntegerAttr::get(getType(), api); 1098 } 1099 1100 return {}; 1101 } 1102 1103 //===----------------------------------------------------------------------===// 1104 // FPToSIOp 1105 //===----------------------------------------------------------------------===// 1106 1107 bool arith::FPToSIOp::areCastCompatible(TypeRange inputs, TypeRange outputs) { 1108 return checkIntFloatCast<FloatType, IntegerType>(inputs, outputs); 1109 } 1110 1111 OpFoldResult arith::FPToSIOp::fold(ArrayRef<Attribute> operands) { 1112 if (auto lhs = operands[0].dyn_cast_or_null<FloatAttr>()) { 1113 const APFloat &apf = lhs.getValue(); 1114 IntegerType intTy = getType().cast<IntegerType>(); 1115 bool ignored; 1116 APSInt api(intTy.getWidth(), /*isUnsigned=*/false); 1117 if (APFloat::opInvalidOp == 1118 apf.convertToInteger(api, APFloat::rmTowardZero, &ignored)) { 1119 // Undefined behavior invoked - the destination type can't represent 1120 // the input constant. 1121 return {}; 1122 } 1123 return IntegerAttr::get(getType(), api); 1124 } 1125 1126 return {}; 1127 } 1128 1129 //===----------------------------------------------------------------------===// 1130 // IndexCastOp 1131 //===----------------------------------------------------------------------===// 1132 1133 bool arith::IndexCastOp::areCastCompatible(TypeRange inputs, 1134 TypeRange outputs) { 1135 if (!areValidCastInputsAndOutputs(inputs, outputs)) 1136 return false; 1137 1138 auto srcType = getTypeIfLikeOrMemRef<IntegerType, IndexType>(inputs.front()); 1139 auto dstType = getTypeIfLikeOrMemRef<IntegerType, IndexType>(outputs.front()); 1140 if (!srcType || !dstType) 1141 return false; 1142 1143 return (srcType.isIndex() && dstType.isSignlessInteger()) || 1144 (srcType.isSignlessInteger() && dstType.isIndex()); 1145 } 1146 1147 OpFoldResult arith::IndexCastOp::fold(ArrayRef<Attribute> operands) { 1148 // index_cast(constant) -> constant 1149 // A little hack because we go through int. Otherwise, the size of the 1150 // constant might need to change. 1151 if (auto value = operands[0].dyn_cast_or_null<IntegerAttr>()) 1152 return IntegerAttr::get(getType(), value.getInt()); 1153 1154 return {}; 1155 } 1156 1157 void arith::IndexCastOp::getCanonicalizationPatterns( 1158 RewritePatternSet &patterns, MLIRContext *context) { 1159 patterns.add<IndexCastOfIndexCast, IndexCastOfExtSI>(context); 1160 } 1161 1162 //===----------------------------------------------------------------------===// 1163 // BitcastOp 1164 //===----------------------------------------------------------------------===// 1165 1166 bool arith::BitcastOp::areCastCompatible(TypeRange inputs, TypeRange outputs) { 1167 if (!areValidCastInputsAndOutputs(inputs, outputs)) 1168 return false; 1169 1170 auto srcType = 1171 getTypeIfLikeOrMemRef<IntegerType, IndexType, FloatType>(inputs.front()); 1172 auto dstType = 1173 getTypeIfLikeOrMemRef<IntegerType, IndexType, FloatType>(outputs.front()); 1174 if (!srcType || !dstType) 1175 return false; 1176 1177 return srcType.getIntOrFloatBitWidth() == dstType.getIntOrFloatBitWidth(); 1178 } 1179 1180 OpFoldResult arith::BitcastOp::fold(ArrayRef<Attribute> operands) { 1181 assert(operands.size() == 1 && "bitcast op expects 1 operand"); 1182 1183 auto resType = getType(); 1184 auto operand = operands[0]; 1185 if (!operand) 1186 return {}; 1187 1188 /// Bitcast dense elements. 1189 if (auto denseAttr = operand.dyn_cast_or_null<DenseElementsAttr>()) 1190 return denseAttr.bitcast(resType.cast<ShapedType>().getElementType()); 1191 /// Other shaped types unhandled. 1192 if (resType.isa<ShapedType>()) 1193 return {}; 1194 1195 /// Bitcast integer or float to integer or float. 1196 APInt bits = operand.isa<FloatAttr>() 1197 ? operand.cast<FloatAttr>().getValue().bitcastToAPInt() 1198 : operand.cast<IntegerAttr>().getValue(); 1199 1200 if (auto resFloatType = resType.dyn_cast<FloatType>()) 1201 return FloatAttr::get(resType, 1202 APFloat(resFloatType.getFloatSemantics(), bits)); 1203 return IntegerAttr::get(resType, bits); 1204 } 1205 1206 void arith::BitcastOp::getCanonicalizationPatterns( 1207 RewritePatternSet &patterns, MLIRContext *context) { 1208 patterns.add<BitcastOfBitcast>(context); 1209 } 1210 1211 //===----------------------------------------------------------------------===// 1212 // Helpers for compare ops 1213 //===----------------------------------------------------------------------===// 1214 1215 /// Return the type of the same shape (scalar, vector or tensor) containing i1. 1216 static Type getI1SameShape(Type type) { 1217 auto i1Type = IntegerType::get(type.getContext(), 1); 1218 if (auto tensorType = type.dyn_cast<RankedTensorType>()) 1219 return RankedTensorType::get(tensorType.getShape(), i1Type); 1220 if (type.isa<UnrankedTensorType>()) 1221 return UnrankedTensorType::get(i1Type); 1222 if (auto vectorType = type.dyn_cast<VectorType>()) 1223 return VectorType::get(vectorType.getShape(), i1Type, 1224 vectorType.getNumScalableDims()); 1225 return i1Type; 1226 } 1227 1228 //===----------------------------------------------------------------------===// 1229 // CmpIOp 1230 //===----------------------------------------------------------------------===// 1231 1232 /// Compute `lhs` `pred` `rhs`, where `pred` is one of the known integer 1233 /// comparison predicates. 1234 bool mlir::arith::applyCmpPredicate(arith::CmpIPredicate predicate, 1235 const APInt &lhs, const APInt &rhs) { 1236 switch (predicate) { 1237 case arith::CmpIPredicate::eq: 1238 return lhs.eq(rhs); 1239 case arith::CmpIPredicate::ne: 1240 return lhs.ne(rhs); 1241 case arith::CmpIPredicate::slt: 1242 return lhs.slt(rhs); 1243 case arith::CmpIPredicate::sle: 1244 return lhs.sle(rhs); 1245 case arith::CmpIPredicate::sgt: 1246 return lhs.sgt(rhs); 1247 case arith::CmpIPredicate::sge: 1248 return lhs.sge(rhs); 1249 case arith::CmpIPredicate::ult: 1250 return lhs.ult(rhs); 1251 case arith::CmpIPredicate::ule: 1252 return lhs.ule(rhs); 1253 case arith::CmpIPredicate::ugt: 1254 return lhs.ugt(rhs); 1255 case arith::CmpIPredicate::uge: 1256 return lhs.uge(rhs); 1257 } 1258 llvm_unreachable("unknown cmpi predicate kind"); 1259 } 1260 1261 /// Returns true if the predicate is true for two equal operands. 1262 static bool applyCmpPredicateToEqualOperands(arith::CmpIPredicate predicate) { 1263 switch (predicate) { 1264 case arith::CmpIPredicate::eq: 1265 case arith::CmpIPredicate::sle: 1266 case arith::CmpIPredicate::sge: 1267 case arith::CmpIPredicate::ule: 1268 case arith::CmpIPredicate::uge: 1269 return true; 1270 case arith::CmpIPredicate::ne: 1271 case arith::CmpIPredicate::slt: 1272 case arith::CmpIPredicate::sgt: 1273 case arith::CmpIPredicate::ult: 1274 case arith::CmpIPredicate::ugt: 1275 return false; 1276 } 1277 llvm_unreachable("unknown cmpi predicate kind"); 1278 } 1279 1280 static Attribute getBoolAttribute(Type type, MLIRContext *ctx, bool value) { 1281 auto boolAttr = BoolAttr::get(ctx, value); 1282 ShapedType shapedType = type.dyn_cast_or_null<ShapedType>(); 1283 if (!shapedType) 1284 return boolAttr; 1285 return DenseElementsAttr::get(shapedType, boolAttr); 1286 } 1287 1288 OpFoldResult arith::CmpIOp::fold(ArrayRef<Attribute> operands) { 1289 assert(operands.size() == 2 && "cmpi takes two operands"); 1290 1291 // cmpi(pred, x, x) 1292 if (getLhs() == getRhs()) { 1293 auto val = applyCmpPredicateToEqualOperands(getPredicate()); 1294 return getBoolAttribute(getType(), getContext(), val); 1295 } 1296 1297 if (matchPattern(getRhs(), m_Zero())) { 1298 if (auto extOp = getLhs().getDefiningOp<ExtSIOp>()) { 1299 if (extOp.getOperand().getType().cast<IntegerType>().getWidth() == 1) { 1300 // extsi(%x : i1 -> iN) != 0 -> %x 1301 if (getPredicate() == arith::CmpIPredicate::ne) { 1302 return extOp.getOperand(); 1303 } 1304 } 1305 } 1306 if (auto extOp = getLhs().getDefiningOp<ExtUIOp>()) { 1307 if (extOp.getOperand().getType().cast<IntegerType>().getWidth() == 1) { 1308 // extui(%x : i1 -> iN) != 0 -> %x 1309 if (getPredicate() == arith::CmpIPredicate::ne) { 1310 return extOp.getOperand(); 1311 } 1312 } 1313 } 1314 } 1315 1316 auto lhs = operands.front().dyn_cast_or_null<IntegerAttr>(); 1317 auto rhs = operands.back().dyn_cast_or_null<IntegerAttr>(); 1318 if (!lhs || !rhs) 1319 return {}; 1320 1321 auto val = applyCmpPredicate(getPredicate(), lhs.getValue(), rhs.getValue()); 1322 return BoolAttr::get(getContext(), val); 1323 } 1324 1325 //===----------------------------------------------------------------------===// 1326 // CmpFOp 1327 //===----------------------------------------------------------------------===// 1328 1329 /// Compute `lhs` `pred` `rhs`, where `pred` is one of the known floating point 1330 /// comparison predicates. 1331 bool mlir::arith::applyCmpPredicate(arith::CmpFPredicate predicate, 1332 const APFloat &lhs, const APFloat &rhs) { 1333 auto cmpResult = lhs.compare(rhs); 1334 switch (predicate) { 1335 case arith::CmpFPredicate::AlwaysFalse: 1336 return false; 1337 case arith::CmpFPredicate::OEQ: 1338 return cmpResult == APFloat::cmpEqual; 1339 case arith::CmpFPredicate::OGT: 1340 return cmpResult == APFloat::cmpGreaterThan; 1341 case arith::CmpFPredicate::OGE: 1342 return cmpResult == APFloat::cmpGreaterThan || 1343 cmpResult == APFloat::cmpEqual; 1344 case arith::CmpFPredicate::OLT: 1345 return cmpResult == APFloat::cmpLessThan; 1346 case arith::CmpFPredicate::OLE: 1347 return cmpResult == APFloat::cmpLessThan || cmpResult == APFloat::cmpEqual; 1348 case arith::CmpFPredicate::ONE: 1349 return cmpResult != APFloat::cmpUnordered && cmpResult != APFloat::cmpEqual; 1350 case arith::CmpFPredicate::ORD: 1351 return cmpResult != APFloat::cmpUnordered; 1352 case arith::CmpFPredicate::UEQ: 1353 return cmpResult == APFloat::cmpUnordered || cmpResult == APFloat::cmpEqual; 1354 case arith::CmpFPredicate::UGT: 1355 return cmpResult == APFloat::cmpUnordered || 1356 cmpResult == APFloat::cmpGreaterThan; 1357 case arith::CmpFPredicate::UGE: 1358 return cmpResult == APFloat::cmpUnordered || 1359 cmpResult == APFloat::cmpGreaterThan || 1360 cmpResult == APFloat::cmpEqual; 1361 case arith::CmpFPredicate::ULT: 1362 return cmpResult == APFloat::cmpUnordered || 1363 cmpResult == APFloat::cmpLessThan; 1364 case arith::CmpFPredicate::ULE: 1365 return cmpResult == APFloat::cmpUnordered || 1366 cmpResult == APFloat::cmpLessThan || cmpResult == APFloat::cmpEqual; 1367 case arith::CmpFPredicate::UNE: 1368 return cmpResult != APFloat::cmpEqual; 1369 case arith::CmpFPredicate::UNO: 1370 return cmpResult == APFloat::cmpUnordered; 1371 case arith::CmpFPredicate::AlwaysTrue: 1372 return true; 1373 } 1374 llvm_unreachable("unknown cmpf predicate kind"); 1375 } 1376 1377 OpFoldResult arith::CmpFOp::fold(ArrayRef<Attribute> operands) { 1378 assert(operands.size() == 2 && "cmpf takes two operands"); 1379 1380 auto lhs = operands.front().dyn_cast_or_null<FloatAttr>(); 1381 auto rhs = operands.back().dyn_cast_or_null<FloatAttr>(); 1382 1383 // If one operand is NaN, making them both NaN does not change the result. 1384 if (lhs && lhs.getValue().isNaN()) 1385 rhs = lhs; 1386 if (rhs && rhs.getValue().isNaN()) 1387 lhs = rhs; 1388 1389 if (!lhs || !rhs) 1390 return {}; 1391 1392 auto val = applyCmpPredicate(getPredicate(), lhs.getValue(), rhs.getValue()); 1393 return BoolAttr::get(getContext(), val); 1394 } 1395 1396 class CmpFIntToFPConst final : public OpRewritePattern<CmpFOp> { 1397 public: 1398 using OpRewritePattern<CmpFOp>::OpRewritePattern; 1399 1400 static CmpIPredicate convertToIntegerPredicate(CmpFPredicate pred, 1401 bool isUnsigned) { 1402 using namespace arith; 1403 switch (pred) { 1404 case CmpFPredicate::UEQ: 1405 case CmpFPredicate::OEQ: 1406 return CmpIPredicate::eq; 1407 case CmpFPredicate::UGT: 1408 case CmpFPredicate::OGT: 1409 return isUnsigned ? CmpIPredicate::ugt : CmpIPredicate::sgt; 1410 case CmpFPredicate::UGE: 1411 case CmpFPredicate::OGE: 1412 return isUnsigned ? CmpIPredicate::uge : CmpIPredicate::sge; 1413 case CmpFPredicate::ULT: 1414 case CmpFPredicate::OLT: 1415 return isUnsigned ? CmpIPredicate::ult : CmpIPredicate::slt; 1416 case CmpFPredicate::ULE: 1417 case CmpFPredicate::OLE: 1418 return isUnsigned ? CmpIPredicate::ule : CmpIPredicate::sle; 1419 case CmpFPredicate::UNE: 1420 case CmpFPredicate::ONE: 1421 return CmpIPredicate::ne; 1422 default: 1423 llvm_unreachable("Unexpected predicate!"); 1424 } 1425 } 1426 1427 LogicalResult matchAndRewrite(CmpFOp op, 1428 PatternRewriter &rewriter) const override { 1429 FloatAttr flt; 1430 if (!matchPattern(op.getRhs(), m_Constant(&flt))) 1431 return failure(); 1432 1433 const APFloat &rhs = flt.getValue(); 1434 1435 // Don't attempt to fold a nan. 1436 if (rhs.isNaN()) 1437 return failure(); 1438 1439 // Get the width of the mantissa. We don't want to hack on conversions that 1440 // might lose information from the integer, e.g. "i64 -> float" 1441 FloatType floatTy = op.getRhs().getType().cast<FloatType>(); 1442 int mantissaWidth = floatTy.getFPMantissaWidth(); 1443 if (mantissaWidth <= 0) 1444 return failure(); 1445 1446 bool isUnsigned; 1447 Value intVal; 1448 1449 if (auto si = op.getLhs().getDefiningOp<SIToFPOp>()) { 1450 isUnsigned = false; 1451 intVal = si.getIn(); 1452 } else if (auto ui = op.getLhs().getDefiningOp<UIToFPOp>()) { 1453 isUnsigned = true; 1454 intVal = ui.getIn(); 1455 } else { 1456 return failure(); 1457 } 1458 1459 // Check to see that the input is converted from an integer type that is 1460 // small enough that preserves all bits. 1461 auto intTy = intVal.getType().cast<IntegerType>(); 1462 auto intWidth = intTy.getWidth(); 1463 1464 // Number of bits representing values, as opposed to the sign 1465 auto valueBits = isUnsigned ? intWidth : (intWidth - 1); 1466 1467 // Following test does NOT adjust intWidth downwards for signed inputs, 1468 // because the most negative value still requires all the mantissa bits 1469 // to distinguish it from one less than that value. 1470 if ((int)intWidth > mantissaWidth) { 1471 // Conversion would lose accuracy. Check if loss can impact comparison. 1472 int exponent = ilogb(rhs); 1473 if (exponent == APFloat::IEK_Inf) { 1474 int maxExponent = ilogb(APFloat::getLargest(rhs.getSemantics())); 1475 if (maxExponent < (int)valueBits) { 1476 // Conversion could create infinity. 1477 return failure(); 1478 } 1479 } else { 1480 // Note that if rhs is zero or NaN, then Exp is negative 1481 // and first condition is trivially false. 1482 if (mantissaWidth <= exponent && exponent <= (int)valueBits) { 1483 // Conversion could affect comparison. 1484 return failure(); 1485 } 1486 } 1487 } 1488 1489 // Convert to equivalent cmpi predicate 1490 CmpIPredicate pred; 1491 switch (op.getPredicate()) { 1492 case CmpFPredicate::ORD: 1493 // Int to fp conversion doesn't create a nan (ord checks neither is a nan) 1494 rewriter.replaceOpWithNewOp<ConstantIntOp>(op, /*value=*/true, 1495 /*width=*/1); 1496 return success(); 1497 case CmpFPredicate::UNO: 1498 // Int to fp conversion doesn't create a nan (uno checks either is a nan) 1499 rewriter.replaceOpWithNewOp<ConstantIntOp>(op, /*value=*/false, 1500 /*width=*/1); 1501 return success(); 1502 default: 1503 pred = convertToIntegerPredicate(op.getPredicate(), isUnsigned); 1504 break; 1505 } 1506 1507 if (!isUnsigned) { 1508 // If the rhs value is > SignedMax, fold the comparison. This handles 1509 // +INF and large values. 1510 APFloat signedMax(rhs.getSemantics()); 1511 signedMax.convertFromAPInt(APInt::getSignedMaxValue(intWidth), true, 1512 APFloat::rmNearestTiesToEven); 1513 if (signedMax < rhs) { // smax < 13123.0 1514 if (pred == CmpIPredicate::ne || pred == CmpIPredicate::slt || 1515 pred == CmpIPredicate::sle) 1516 rewriter.replaceOpWithNewOp<ConstantIntOp>(op, /*value=*/true, 1517 /*width=*/1); 1518 else 1519 rewriter.replaceOpWithNewOp<ConstantIntOp>(op, /*value=*/false, 1520 /*width=*/1); 1521 return success(); 1522 } 1523 } else { 1524 // If the rhs value is > UnsignedMax, fold the comparison. This handles 1525 // +INF and large values. 1526 APFloat unsignedMax(rhs.getSemantics()); 1527 unsignedMax.convertFromAPInt(APInt::getMaxValue(intWidth), false, 1528 APFloat::rmNearestTiesToEven); 1529 if (unsignedMax < rhs) { // umax < 13123.0 1530 if (pred == CmpIPredicate::ne || pred == CmpIPredicate::ult || 1531 pred == CmpIPredicate::ule) 1532 rewriter.replaceOpWithNewOp<ConstantIntOp>(op, /*value=*/true, 1533 /*width=*/1); 1534 else 1535 rewriter.replaceOpWithNewOp<ConstantIntOp>(op, /*value=*/false, 1536 /*width=*/1); 1537 return success(); 1538 } 1539 } 1540 1541 if (!isUnsigned) { 1542 // See if the rhs value is < SignedMin. 1543 APFloat signedMin(rhs.getSemantics()); 1544 signedMin.convertFromAPInt(APInt::getSignedMinValue(intWidth), true, 1545 APFloat::rmNearestTiesToEven); 1546 if (signedMin > rhs) { // smin > 12312.0 1547 if (pred == CmpIPredicate::ne || pred == CmpIPredicate::sgt || 1548 pred == CmpIPredicate::sge) 1549 rewriter.replaceOpWithNewOp<ConstantIntOp>(op, /*value=*/true, 1550 /*width=*/1); 1551 else 1552 rewriter.replaceOpWithNewOp<ConstantIntOp>(op, /*value=*/false, 1553 /*width=*/1); 1554 return success(); 1555 } 1556 } else { 1557 // See if the rhs value is < UnsignedMin. 1558 APFloat unsignedMin(rhs.getSemantics()); 1559 unsignedMin.convertFromAPInt(APInt::getMinValue(intWidth), false, 1560 APFloat::rmNearestTiesToEven); 1561 if (unsignedMin > rhs) { // umin > 12312.0 1562 if (pred == CmpIPredicate::ne || pred == CmpIPredicate::ugt || 1563 pred == CmpIPredicate::uge) 1564 rewriter.replaceOpWithNewOp<ConstantIntOp>(op, /*value=*/true, 1565 /*width=*/1); 1566 else 1567 rewriter.replaceOpWithNewOp<ConstantIntOp>(op, /*value=*/false, 1568 /*width=*/1); 1569 return success(); 1570 } 1571 } 1572 1573 // Okay, now we know that the FP constant fits in the range [SMIN, SMAX] or 1574 // [0, UMAX], but it may still be fractional. See if it is fractional by 1575 // casting the FP value to the integer value and back, checking for 1576 // equality. Don't do this for zero, because -0.0 is not fractional. 1577 bool ignored; 1578 APSInt rhsInt(intWidth, isUnsigned); 1579 if (APFloat::opInvalidOp == 1580 rhs.convertToInteger(rhsInt, APFloat::rmTowardZero, &ignored)) { 1581 // Undefined behavior invoked - the destination type can't represent 1582 // the input constant. 1583 return failure(); 1584 } 1585 1586 if (!rhs.isZero()) { 1587 APFloat apf(floatTy.getFloatSemantics(), 1588 APInt::getZero(floatTy.getWidth())); 1589 apf.convertFromAPInt(rhsInt, !isUnsigned, APFloat::rmNearestTiesToEven); 1590 1591 bool equal = apf == rhs; 1592 if (!equal) { 1593 // If we had a comparison against a fractional value, we have to adjust 1594 // the compare predicate and sometimes the value. rhsInt is rounded 1595 // towards zero at this point. 1596 switch (pred) { 1597 case CmpIPredicate::ne: // (float)int != 4.4 --> true 1598 rewriter.replaceOpWithNewOp<ConstantIntOp>(op, /*value=*/true, 1599 /*width=*/1); 1600 return success(); 1601 case CmpIPredicate::eq: // (float)int == 4.4 --> false 1602 rewriter.replaceOpWithNewOp<ConstantIntOp>(op, /*value=*/false, 1603 /*width=*/1); 1604 return success(); 1605 case CmpIPredicate::ule: 1606 // (float)int <= 4.4 --> int <= 4 1607 // (float)int <= -4.4 --> false 1608 if (rhs.isNegative()) { 1609 rewriter.replaceOpWithNewOp<ConstantIntOp>(op, /*value=*/false, 1610 /*width=*/1); 1611 return success(); 1612 } 1613 break; 1614 case CmpIPredicate::sle: 1615 // (float)int <= 4.4 --> int <= 4 1616 // (float)int <= -4.4 --> int < -4 1617 if (rhs.isNegative()) 1618 pred = CmpIPredicate::slt; 1619 break; 1620 case CmpIPredicate::ult: 1621 // (float)int < -4.4 --> false 1622 // (float)int < 4.4 --> int <= 4 1623 if (rhs.isNegative()) { 1624 rewriter.replaceOpWithNewOp<ConstantIntOp>(op, /*value=*/false, 1625 /*width=*/1); 1626 return success(); 1627 } 1628 pred = CmpIPredicate::ule; 1629 break; 1630 case CmpIPredicate::slt: 1631 // (float)int < -4.4 --> int < -4 1632 // (float)int < 4.4 --> int <= 4 1633 if (!rhs.isNegative()) 1634 pred = CmpIPredicate::sle; 1635 break; 1636 case CmpIPredicate::ugt: 1637 // (float)int > 4.4 --> int > 4 1638 // (float)int > -4.4 --> true 1639 if (rhs.isNegative()) { 1640 rewriter.replaceOpWithNewOp<ConstantIntOp>(op, /*value=*/true, 1641 /*width=*/1); 1642 return success(); 1643 } 1644 break; 1645 case CmpIPredicate::sgt: 1646 // (float)int > 4.4 --> int > 4 1647 // (float)int > -4.4 --> int >= -4 1648 if (rhs.isNegative()) 1649 pred = CmpIPredicate::sge; 1650 break; 1651 case CmpIPredicate::uge: 1652 // (float)int >= -4.4 --> true 1653 // (float)int >= 4.4 --> int > 4 1654 if (rhs.isNegative()) { 1655 rewriter.replaceOpWithNewOp<ConstantIntOp>(op, /*value=*/true, 1656 /*width=*/1); 1657 return success(); 1658 } 1659 pred = CmpIPredicate::ugt; 1660 break; 1661 case CmpIPredicate::sge: 1662 // (float)int >= -4.4 --> int >= -4 1663 // (float)int >= 4.4 --> int > 4 1664 if (!rhs.isNegative()) 1665 pred = CmpIPredicate::sgt; 1666 break; 1667 } 1668 } 1669 } 1670 1671 // Lower this FP comparison into an appropriate integer version of the 1672 // comparison. 1673 rewriter.replaceOpWithNewOp<CmpIOp>( 1674 op, pred, intVal, 1675 rewriter.create<ConstantOp>( 1676 op.getLoc(), intVal.getType(), 1677 rewriter.getIntegerAttr(intVal.getType(), rhsInt))); 1678 return success(); 1679 } 1680 }; 1681 1682 void arith::CmpFOp::getCanonicalizationPatterns(RewritePatternSet &patterns, 1683 MLIRContext *context) { 1684 patterns.insert<CmpFIntToFPConst>(context); 1685 } 1686 1687 //===----------------------------------------------------------------------===// 1688 // SelectOp 1689 //===----------------------------------------------------------------------===// 1690 1691 // Transforms a select of a boolean to arithmetic operations 1692 // 1693 // arith.select %arg, %x, %y : i1 1694 // 1695 // becomes 1696 // 1697 // and(%arg, %x) or and(!%arg, %y) 1698 struct SelectI1Simplify : public OpRewritePattern<arith::SelectOp> { 1699 using OpRewritePattern<arith::SelectOp>::OpRewritePattern; 1700 1701 LogicalResult matchAndRewrite(arith::SelectOp op, 1702 PatternRewriter &rewriter) const override { 1703 if (!op.getType().isInteger(1)) 1704 return failure(); 1705 1706 Value falseConstant = 1707 rewriter.create<arith::ConstantIntOp>(op.getLoc(), true, 1); 1708 Value notCondition = rewriter.create<arith::XOrIOp>( 1709 op.getLoc(), op.getCondition(), falseConstant); 1710 1711 Value trueVal = rewriter.create<arith::AndIOp>( 1712 op.getLoc(), op.getCondition(), op.getTrueValue()); 1713 Value falseVal = rewriter.create<arith::AndIOp>(op.getLoc(), notCondition, 1714 op.getFalseValue()); 1715 rewriter.replaceOpWithNewOp<arith::OrIOp>(op, trueVal, falseVal); 1716 return success(); 1717 } 1718 }; 1719 1720 // select %arg, %c1, %c0 => extui %arg 1721 struct SelectToExtUI : public OpRewritePattern<arith::SelectOp> { 1722 using OpRewritePattern<arith::SelectOp>::OpRewritePattern; 1723 1724 LogicalResult matchAndRewrite(arith::SelectOp op, 1725 PatternRewriter &rewriter) const override { 1726 // Cannot extui i1 to i1, or i1 to f32 1727 if (!op.getType().isa<IntegerType>() || op.getType().isInteger(1)) 1728 return failure(); 1729 1730 // select %x, c1, %c0 => extui %arg 1731 if (matchPattern(op.getTrueValue(), m_One())) 1732 if (matchPattern(op.getFalseValue(), m_Zero())) { 1733 rewriter.replaceOpWithNewOp<arith::ExtUIOp>(op, op.getType(), 1734 op.getCondition()); 1735 return success(); 1736 } 1737 1738 // select %x, c0, %c1 => extui (xor %arg, true) 1739 if (matchPattern(op.getTrueValue(), m_Zero())) 1740 if (matchPattern(op.getFalseValue(), m_One())) { 1741 rewriter.replaceOpWithNewOp<arith::ExtUIOp>( 1742 op, op.getType(), 1743 rewriter.create<arith::XOrIOp>( 1744 op.getLoc(), op.getCondition(), 1745 rewriter.create<arith::ConstantIntOp>( 1746 op.getLoc(), 1, op.getCondition().getType()))); 1747 return success(); 1748 } 1749 1750 return failure(); 1751 } 1752 }; 1753 1754 void arith::SelectOp::getCanonicalizationPatterns(RewritePatternSet &results, 1755 MLIRContext *context) { 1756 results.add<SelectI1Simplify, SelectToExtUI>(context); 1757 } 1758 1759 OpFoldResult arith::SelectOp::fold(ArrayRef<Attribute> operands) { 1760 Value trueVal = getTrueValue(); 1761 Value falseVal = getFalseValue(); 1762 if (trueVal == falseVal) 1763 return trueVal; 1764 1765 Value condition = getCondition(); 1766 1767 // select true, %0, %1 => %0 1768 if (matchPattern(condition, m_One())) 1769 return trueVal; 1770 1771 // select false, %0, %1 => %1 1772 if (matchPattern(condition, m_Zero())) 1773 return falseVal; 1774 1775 // select %x, true, false => %x 1776 if (getType().isInteger(1)) 1777 if (matchPattern(getTrueValue(), m_One())) 1778 if (matchPattern(getFalseValue(), m_Zero())) 1779 return condition; 1780 1781 if (auto cmp = dyn_cast_or_null<arith::CmpIOp>(condition.getDefiningOp())) { 1782 auto pred = cmp.getPredicate(); 1783 if (pred == arith::CmpIPredicate::eq || pred == arith::CmpIPredicate::ne) { 1784 auto cmpLhs = cmp.getLhs(); 1785 auto cmpRhs = cmp.getRhs(); 1786 1787 // %0 = arith.cmpi eq, %arg0, %arg1 1788 // %1 = arith.select %0, %arg0, %arg1 => %arg1 1789 1790 // %0 = arith.cmpi ne, %arg0, %arg1 1791 // %1 = arith.select %0, %arg0, %arg1 => %arg0 1792 1793 if ((cmpLhs == trueVal && cmpRhs == falseVal) || 1794 (cmpRhs == trueVal && cmpLhs == falseVal)) 1795 return pred == arith::CmpIPredicate::ne ? trueVal : falseVal; 1796 } 1797 } 1798 return nullptr; 1799 } 1800 1801 ParseResult SelectOp::parse(OpAsmParser &parser, OperationState &result) { 1802 Type conditionType, resultType; 1803 SmallVector<OpAsmParser::OperandType, 3> operands; 1804 if (parser.parseOperandList(operands, /*requiredOperandCount=*/3) || 1805 parser.parseOptionalAttrDict(result.attributes) || 1806 parser.parseColonType(resultType)) 1807 return failure(); 1808 1809 // Check for the explicit condition type if this is a masked tensor or vector. 1810 if (succeeded(parser.parseOptionalComma())) { 1811 conditionType = resultType; 1812 if (parser.parseType(resultType)) 1813 return failure(); 1814 } else { 1815 conditionType = parser.getBuilder().getI1Type(); 1816 } 1817 1818 result.addTypes(resultType); 1819 return parser.resolveOperands(operands, 1820 {conditionType, resultType, resultType}, 1821 parser.getNameLoc(), result.operands); 1822 } 1823 1824 void arith::SelectOp::print(OpAsmPrinter &p) { 1825 p << " " << getOperands(); 1826 p.printOptionalAttrDict((*this)->getAttrs()); 1827 p << " : "; 1828 if (ShapedType condType = getCondition().getType().dyn_cast<ShapedType>()) 1829 p << condType << ", "; 1830 p << getType(); 1831 } 1832 1833 LogicalResult arith::SelectOp::verify() { 1834 Type conditionType = getCondition().getType(); 1835 if (conditionType.isSignlessInteger(1)) 1836 return success(); 1837 1838 // If the result type is a vector or tensor, the type can be a mask with the 1839 // same elements. 1840 Type resultType = getType(); 1841 if (!resultType.isa<TensorType, VectorType>()) 1842 return emitOpError() << "expected condition to be a signless i1, but got " 1843 << conditionType; 1844 Type shapedConditionType = getI1SameShape(resultType); 1845 if (conditionType != shapedConditionType) { 1846 return emitOpError() << "expected condition type to have the same shape " 1847 "as the result type, expected " 1848 << shapedConditionType << ", but got " 1849 << conditionType; 1850 } 1851 return success(); 1852 } 1853 1854 //===----------------------------------------------------------------------===// 1855 // Atomic Enum 1856 //===----------------------------------------------------------------------===// 1857 1858 /// Returns the identity value attribute associated with an AtomicRMWKind op. 1859 Attribute mlir::arith::getIdentityValueAttr(AtomicRMWKind kind, Type resultType, 1860 OpBuilder &builder, Location loc) { 1861 switch (kind) { 1862 case AtomicRMWKind::maxf: 1863 return builder.getFloatAttr( 1864 resultType, 1865 APFloat::getInf(resultType.cast<FloatType>().getFloatSemantics(), 1866 /*Negative=*/true)); 1867 case AtomicRMWKind::addf: 1868 case AtomicRMWKind::addi: 1869 case AtomicRMWKind::maxu: 1870 case AtomicRMWKind::ori: 1871 return builder.getZeroAttr(resultType); 1872 case AtomicRMWKind::andi: 1873 return builder.getIntegerAttr( 1874 resultType, 1875 APInt::getAllOnes(resultType.cast<IntegerType>().getWidth())); 1876 case AtomicRMWKind::maxs: 1877 return builder.getIntegerAttr( 1878 resultType, 1879 APInt::getSignedMinValue(resultType.cast<IntegerType>().getWidth())); 1880 case AtomicRMWKind::minf: 1881 return builder.getFloatAttr( 1882 resultType, 1883 APFloat::getInf(resultType.cast<FloatType>().getFloatSemantics(), 1884 /*Negative=*/false)); 1885 case AtomicRMWKind::mins: 1886 return builder.getIntegerAttr( 1887 resultType, 1888 APInt::getSignedMaxValue(resultType.cast<IntegerType>().getWidth())); 1889 case AtomicRMWKind::minu: 1890 return builder.getIntegerAttr( 1891 resultType, 1892 APInt::getMaxValue(resultType.cast<IntegerType>().getWidth())); 1893 case AtomicRMWKind::muli: 1894 return builder.getIntegerAttr(resultType, 1); 1895 case AtomicRMWKind::mulf: 1896 return builder.getFloatAttr(resultType, 1); 1897 // TODO: Add remaining reduction operations. 1898 default: 1899 (void)emitOptionalError(loc, "Reduction operation type not supported"); 1900 break; 1901 } 1902 return nullptr; 1903 } 1904 1905 /// Returns the identity value associated with an AtomicRMWKind op. 1906 Value mlir::arith::getIdentityValue(AtomicRMWKind op, Type resultType, 1907 OpBuilder &builder, Location loc) { 1908 Attribute attr = getIdentityValueAttr(op, resultType, builder, loc); 1909 return builder.create<arith::ConstantOp>(loc, attr); 1910 } 1911 1912 /// Return the value obtained by applying the reduction operation kind 1913 /// associated with a binary AtomicRMWKind op to `lhs` and `rhs`. 1914 Value mlir::arith::getReductionOp(AtomicRMWKind op, OpBuilder &builder, 1915 Location loc, Value lhs, Value rhs) { 1916 switch (op) { 1917 case AtomicRMWKind::addf: 1918 return builder.create<arith::AddFOp>(loc, lhs, rhs); 1919 case AtomicRMWKind::addi: 1920 return builder.create<arith::AddIOp>(loc, lhs, rhs); 1921 case AtomicRMWKind::mulf: 1922 return builder.create<arith::MulFOp>(loc, lhs, rhs); 1923 case AtomicRMWKind::muli: 1924 return builder.create<arith::MulIOp>(loc, lhs, rhs); 1925 case AtomicRMWKind::maxf: 1926 return builder.create<arith::MaxFOp>(loc, lhs, rhs); 1927 case AtomicRMWKind::minf: 1928 return builder.create<arith::MinFOp>(loc, lhs, rhs); 1929 case AtomicRMWKind::maxs: 1930 return builder.create<arith::MaxSIOp>(loc, lhs, rhs); 1931 case AtomicRMWKind::mins: 1932 return builder.create<arith::MinSIOp>(loc, lhs, rhs); 1933 case AtomicRMWKind::maxu: 1934 return builder.create<arith::MaxUIOp>(loc, lhs, rhs); 1935 case AtomicRMWKind::minu: 1936 return builder.create<arith::MinUIOp>(loc, lhs, rhs); 1937 case AtomicRMWKind::ori: 1938 return builder.create<arith::OrIOp>(loc, lhs, rhs); 1939 case AtomicRMWKind::andi: 1940 return builder.create<arith::AndIOp>(loc, lhs, rhs); 1941 // TODO: Add remaining reduction operations. 1942 default: 1943 (void)emitOptionalError(loc, "Reduction operation type not supported"); 1944 break; 1945 } 1946 return nullptr; 1947 } 1948 1949 //===----------------------------------------------------------------------===// 1950 // TableGen'd op method definitions 1951 //===----------------------------------------------------------------------===// 1952 1953 #define GET_OP_CLASSES 1954 #include "mlir/Dialect/Arithmetic/IR/ArithmeticOps.cpp.inc" 1955 1956 //===----------------------------------------------------------------------===// 1957 // TableGen'd enum attribute definitions 1958 //===----------------------------------------------------------------------===// 1959 1960 #include "mlir/Dialect/Arithmetic/IR/ArithmeticOpsEnums.cpp.inc" 1961