//===- ArithmeticOps.cpp - MLIR Arithmetic dialect ops implementation -----===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//

#include <utility>

#include "mlir/Dialect/Arithmetic/IR/Arithmetic.h"
#include "mlir/Dialect/CommonFolders.h"
#include "mlir/IR/Builders.h"
#include "mlir/IR/Matchers.h"
#include "mlir/IR/OpImplementation.h"
#include "mlir/IR/PatternMatch.h"
#include "mlir/IR/TypeUtilities.h"
#include "llvm/ADT/SmallString.h"

#include "llvm/ADT/APSInt.h"

using namespace mlir;
using namespace mlir::arith;

//===----------------------------------------------------------------------===//
// Pattern helpers
//===----------------------------------------------------------------------===//

/// Builds an IntegerAttr of `res`'s type holding the sum of the two integer
/// attributes. Used by the TableGen'd canonicalization patterns included
/// below; both `lhs` and `rhs` must be IntegerAttrs.
static IntegerAttr addIntegerAttrs(PatternRewriter &builder, Value res,
                                   Attribute lhs, Attribute rhs) {
  return builder.getIntegerAttr(res.getType(),
                                lhs.cast<IntegerAttr>().getInt() +
                                    rhs.cast<IntegerAttr>().getInt());
}

/// Builds an IntegerAttr of `res`'s type holding the difference of the two
/// integer attributes. Used by the TableGen'd canonicalization patterns
/// included below; both `lhs` and `rhs` must be IntegerAttrs.
static IntegerAttr subIntegerAttrs(PatternRewriter &builder, Value res,
                                   Attribute lhs, Attribute rhs) {
  return builder.getIntegerAttr(res.getType(),
                                lhs.cast<IntegerAttr>().getInt() -
                                    rhs.cast<IntegerAttr>().getInt());
}

/// Invert an integer comparison predicate.
arith::CmpIPredicate arith::invertPredicate(arith::CmpIPredicate pred) {
  switch (pred) {
  case arith::CmpIPredicate::eq:
    return arith::CmpIPredicate::ne;
  case arith::CmpIPredicate::ne:
    return arith::CmpIPredicate::eq;
  case arith::CmpIPredicate::slt:
    return arith::CmpIPredicate::sge;
  case arith::CmpIPredicate::sle:
    return arith::CmpIPredicate::sgt;
  case arith::CmpIPredicate::sgt:
    return arith::CmpIPredicate::sle;
  case arith::CmpIPredicate::sge:
    return arith::CmpIPredicate::slt;
  case arith::CmpIPredicate::ult:
    return arith::CmpIPredicate::uge;
  case arith::CmpIPredicate::ule:
    return arith::CmpIPredicate::ugt;
  case arith::CmpIPredicate::ugt:
    return arith::CmpIPredicate::ule;
  case arith::CmpIPredicate::uge:
    return arith::CmpIPredicate::ult;
  }
  llvm_unreachable("unknown cmpi predicate kind");
}

/// Attribute-wrapped variant of `invertPredicate` above.
static arith::CmpIPredicateAttr invertPredicate(arith::CmpIPredicateAttr pred) {
  return arith::CmpIPredicateAttr::get(pred.getContext(),
                                       invertPredicate(pred.getValue()));
}

//===----------------------------------------------------------------------===//
// TableGen'd canonicalization patterns
//===----------------------------------------------------------------------===//

namespace {
#include "ArithmeticCanonicalization.inc"
} // namespace

//===----------------------------------------------------------------------===//
// ConstantOp
//===----------------------------------------------------------------------===//

/// Gives constant results readable SSA names: `true`/`false` for i1 values,
/// `c<value>_<type>` for other integer constants, `cst` for everything else.
void arith::ConstantOp::getAsmResultNames(
    function_ref<void(Value, StringRef)> setNameFn) {
  auto type = getType();
  if (auto intCst = getValue().dyn_cast<IntegerAttr>()) {
    auto intType = type.dyn_cast<IntegerType>();

    // Sugar i1 constants with 'true' and 'false'.
    if (intType && intType.getWidth() == 1)
      return setNameFn(getResult(), (intCst.getInt() ? "true" : "false"));

    // Otherwise, build a complex name with the value and type.
    SmallString<32> specialNameBuffer;
    llvm::raw_svector_ostream specialName(specialNameBuffer);
    specialName << 'c' << intCst.getInt();
    if (intType)
      specialName << '_' << type;
    setNameFn(getResult(), specialName.str());
  } else {
    setNameFn(getResult(), "cst");
  }
}

/// TODO: disallow arith.constant to return anything other than signless integer
/// or float like.
LogicalResult arith::ConstantOp::verify() {
  auto type = getType();
  // The value's type must match the return type.
  if (getValue().getType() != type) {
    return emitOpError() << "value type " << getValue().getType()
                         << " must match return type: " << type;
  }
  // Integer values must be signless.
  if (type.isa<IntegerType>() && !type.cast<IntegerType>().isSignless())
    return emitOpError("integer return type must be signless");
  // Any float or elements attribute are acceptable.
  if (!getValue().isa<IntegerAttr, FloatAttr, ElementsAttr>()) {
    return emitOpError(
        "value must be an integer, float, or elements attribute");
  }
  return success();
}

/// Returns true if a constant with attribute `value` and result type `type`
/// can be materialized by this dialect. Mirrors the checks in verify().
bool arith::ConstantOp::isBuildableWith(Attribute value, Type type) {
  // The value's type must be the same as the provided type.
  if (value.getType() != type)
    return false;
  // Integer values must be signless.
  if (type.isa<IntegerType>() && !type.cast<IntegerType>().isSignless())
    return false;
  // Integer, float, and element attributes are buildable.
  return value.isa<IntegerAttr, FloatAttr, ElementsAttr>();
}

/// A constant op trivially folds to its own value attribute.
OpFoldResult arith::ConstantOp::fold(ArrayRef<Attribute> operands) {
  return getValue();
}

/// Builds an integer constant of the signless integer type of width `width`.
void arith::ConstantIntOp::build(OpBuilder &builder, OperationState &result,
                                 int64_t value, unsigned width) {
  auto type = builder.getIntegerType(width);
  arith::ConstantOp::build(builder, result, type,
                           builder.getIntegerAttr(type, value));
}

/// Builds an integer constant of the given (signless) integer `type`.
void arith::ConstantIntOp::build(OpBuilder &builder, OperationState &result,
                                 int64_t value, Type type) {
  assert(type.isSignlessInteger() &&
         "ConstantIntOp can only have signless integer type values");
  arith::ConstantOp::build(builder, result, type,
                           builder.getIntegerAttr(type, value));
}

/// ConstantIntOp is any ConstantOp whose result is a signless integer.
bool arith::ConstantIntOp::classof(Operation *op) {
  if (auto constOp = dyn_cast_or_null<arith::ConstantOp>(op))
    return constOp.getType().isSignlessInteger();
  return false;
}

/// Builds a floating-point constant of the given float `type`.
void arith::ConstantFloatOp::build(OpBuilder &builder, OperationState &result,
                                   const APFloat &value, FloatType type) {
  arith::ConstantOp::build(builder, result, type,
                           builder.getFloatAttr(type, value));
}

/// ConstantFloatOp is any ConstantOp whose result is a float type.
bool arith::ConstantFloatOp::classof(Operation *op) {
  if (auto constOp = dyn_cast_or_null<arith::ConstantOp>(op))
    return constOp.getType().isa<FloatType>();
  return false;
}

/// Builds an index-typed constant.
void arith::ConstantIndexOp::build(OpBuilder &builder, OperationState &result,
                                   int64_t value) {
  arith::ConstantOp::build(builder, result, builder.getIndexType(),
                           builder.getIndexAttr(value));
}

/// ConstantIndexOp is any ConstantOp whose result is of index type.
bool arith::ConstantIndexOp::classof(Operation *op) {
  if (auto constOp = dyn_cast_or_null<arith::ConstantOp>(op))
    return constOp.getType().isIndex();
  return false;
}

//===----------------------------------------------------------------------===//
// AddIOp
//===----------------------------------------------------------------------===//

OpFoldResult arith::AddIOp::fold(ArrayRef<Attribute> operands) { 194 // addi(x, 0) -> x 195 if (matchPattern(getRhs(), m_Zero())) 196 return getLhs(); 197 198 // addi(subi(a, b), b) -> a 199 if (auto sub = getLhs().getDefiningOp<SubIOp>()) 200 if (getRhs() == sub.getRhs()) 201 return sub.getLhs(); 202 203 // addi(b, subi(a, b)) -> a 204 if (auto sub = getRhs().getDefiningOp<SubIOp>()) 205 if (getLhs() == sub.getRhs()) 206 return sub.getLhs(); 207 208 return constFoldBinaryOp<IntegerAttr>( 209 operands, [](APInt a, const APInt &b) { return std::move(a) + b; }); 210 } 211 212 void arith::AddIOp::getCanonicalizationPatterns( 213 RewritePatternSet &patterns, MLIRContext *context) { 214 patterns.add<AddIAddConstant, AddISubConstantRHS, AddISubConstantLHS>( 215 context); 216 } 217 218 //===----------------------------------------------------------------------===// 219 // SubIOp 220 //===----------------------------------------------------------------------===// 221 222 OpFoldResult arith::SubIOp::fold(ArrayRef<Attribute> operands) { 223 // subi(x,x) -> 0 224 if (getOperand(0) == getOperand(1)) 225 return Builder(getContext()).getZeroAttr(getType()); 226 // subi(x,0) -> x 227 if (matchPattern(getRhs(), m_Zero())) 228 return getLhs(); 229 230 return constFoldBinaryOp<IntegerAttr>( 231 operands, [](APInt a, const APInt &b) { return std::move(a) - b; }); 232 } 233 234 void arith::SubIOp::getCanonicalizationPatterns( 235 RewritePatternSet &patterns, MLIRContext *context) { 236 patterns 237 .add<SubIRHSAddConstant, SubILHSAddConstant, SubIRHSSubConstantRHS, 238 SubIRHSSubConstantLHS, SubILHSSubConstantRHS, SubILHSSubConstantLHS>( 239 context); 240 } 241 242 //===----------------------------------------------------------------------===// 243 // MulIOp 244 //===----------------------------------------------------------------------===// 245 246 OpFoldResult arith::MulIOp::fold(ArrayRef<Attribute> operands) { 247 // muli(x, 0) -> 0 248 if (matchPattern(getRhs(), m_Zero())) 249 
return getRhs(); 250 // muli(x, 1) -> x 251 if (matchPattern(getRhs(), m_One())) 252 return getOperand(0); 253 // TODO: Handle the overflow case. 254 255 // default folder 256 return constFoldBinaryOp<IntegerAttr>( 257 operands, [](const APInt &a, const APInt &b) { return a * b; }); 258 } 259 260 //===----------------------------------------------------------------------===// 261 // DivUIOp 262 //===----------------------------------------------------------------------===// 263 264 OpFoldResult arith::DivUIOp::fold(ArrayRef<Attribute> operands) { 265 // Don't fold if it would require a division by zero. 266 bool div0 = false; 267 auto result = 268 constFoldBinaryOp<IntegerAttr>(operands, [&](APInt a, const APInt &b) { 269 if (div0 || !b) { 270 div0 = true; 271 return a; 272 } 273 return a.udiv(b); 274 }); 275 276 // Fold out division by one. Assumes all tensors of all ones are splats. 277 if (auto rhs = operands[1].dyn_cast_or_null<IntegerAttr>()) { 278 if (rhs.getValue() == 1) 279 return getLhs(); 280 } else if (auto rhs = operands[1].dyn_cast_or_null<SplatElementsAttr>()) { 281 if (rhs.getSplatValue<IntegerAttr>().getValue() == 1) 282 return getLhs(); 283 } 284 285 return div0 ? Attribute() : result; 286 } 287 288 //===----------------------------------------------------------------------===// 289 // DivSIOp 290 //===----------------------------------------------------------------------===// 291 292 OpFoldResult arith::DivSIOp::fold(ArrayRef<Attribute> operands) { 293 // Don't fold if it would overflow or if it requires a division by zero. 294 bool overflowOrDiv0 = false; 295 auto result = 296 constFoldBinaryOp<IntegerAttr>(operands, [&](APInt a, const APInt &b) { 297 if (overflowOrDiv0 || !b) { 298 overflowOrDiv0 = true; 299 return a; 300 } 301 return a.sdiv_ov(b, overflowOrDiv0); 302 }); 303 304 // Fold out division by one. Assumes all tensors of all ones are splats. 
305 if (auto rhs = operands[1].dyn_cast_or_null<IntegerAttr>()) { 306 if (rhs.getValue() == 1) 307 return getLhs(); 308 } else if (auto rhs = operands[1].dyn_cast_or_null<SplatElementsAttr>()) { 309 if (rhs.getSplatValue<IntegerAttr>().getValue() == 1) 310 return getLhs(); 311 } 312 313 return overflowOrDiv0 ? Attribute() : result; 314 } 315 316 //===----------------------------------------------------------------------===// 317 // Ceil and floor division folding helpers 318 //===----------------------------------------------------------------------===// 319 320 static APInt signedCeilNonnegInputs(const APInt &a, const APInt &b, 321 bool &overflow) { 322 // Returns (a-1)/b + 1 323 APInt one(a.getBitWidth(), 1, true); // Signed value 1. 324 APInt val = a.ssub_ov(one, overflow).sdiv_ov(b, overflow); 325 return val.sadd_ov(one, overflow); 326 } 327 328 //===----------------------------------------------------------------------===// 329 // CeilDivUIOp 330 //===----------------------------------------------------------------------===// 331 332 OpFoldResult arith::CeilDivUIOp::fold(ArrayRef<Attribute> operands) { 333 bool overflowOrDiv0 = false; 334 auto result = 335 constFoldBinaryOp<IntegerAttr>(operands, [&](APInt a, const APInt &b) { 336 if (overflowOrDiv0 || !b) { 337 overflowOrDiv0 = true; 338 return a; 339 } 340 APInt quotient = a.udiv(b); 341 if (!a.urem(b)) 342 return quotient; 343 APInt one(a.getBitWidth(), 1, true); 344 return quotient.uadd_ov(one, overflowOrDiv0); 345 }); 346 // Fold out ceil division by one. Assumes all tensors of all ones are 347 // splats. 348 if (auto rhs = operands[1].dyn_cast_or_null<IntegerAttr>()) { 349 if (rhs.getValue() == 1) 350 return getLhs(); 351 } else if (auto rhs = operands[1].dyn_cast_or_null<SplatElementsAttr>()) { 352 if (rhs.getSplatValue<IntegerAttr>().getValue() == 1) 353 return getLhs(); 354 } 355 356 return overflowOrDiv0 ? 
Attribute() : result; 357 } 358 359 //===----------------------------------------------------------------------===// 360 // CeilDivSIOp 361 //===----------------------------------------------------------------------===// 362 363 OpFoldResult arith::CeilDivSIOp::fold(ArrayRef<Attribute> operands) { 364 // Don't fold if it would overflow or if it requires a division by zero. 365 bool overflowOrDiv0 = false; 366 auto result = 367 constFoldBinaryOp<IntegerAttr>(operands, [&](APInt a, const APInt &b) { 368 if (overflowOrDiv0 || !b) { 369 overflowOrDiv0 = true; 370 return a; 371 } 372 if (!a) 373 return a; 374 // After this point we know that neither a or b are zero. 375 unsigned bits = a.getBitWidth(); 376 APInt zero = APInt::getZero(bits); 377 bool aGtZero = a.sgt(zero); 378 bool bGtZero = b.sgt(zero); 379 if (aGtZero && bGtZero) { 380 // Both positive, return ceil(a, b). 381 return signedCeilNonnegInputs(a, b, overflowOrDiv0); 382 } 383 if (!aGtZero && !bGtZero) { 384 // Both negative, return ceil(-a, -b). 385 APInt posA = zero.ssub_ov(a, overflowOrDiv0); 386 APInt posB = zero.ssub_ov(b, overflowOrDiv0); 387 return signedCeilNonnegInputs(posA, posB, overflowOrDiv0); 388 } 389 if (!aGtZero && bGtZero) { 390 // A is negative, b is positive, return - ( -a / b). 391 APInt posA = zero.ssub_ov(a, overflowOrDiv0); 392 APInt div = posA.sdiv_ov(b, overflowOrDiv0); 393 return zero.ssub_ov(div, overflowOrDiv0); 394 } 395 // A is positive, b is negative, return - (a / -b). 396 APInt posB = zero.ssub_ov(b, overflowOrDiv0); 397 APInt div = a.sdiv_ov(posB, overflowOrDiv0); 398 return zero.ssub_ov(div, overflowOrDiv0); 399 }); 400 401 // Fold out ceil division by one. Assumes all tensors of all ones are 402 // splats. 
403 if (auto rhs = operands[1].dyn_cast_or_null<IntegerAttr>()) { 404 if (rhs.getValue() == 1) 405 return getLhs(); 406 } else if (auto rhs = operands[1].dyn_cast_or_null<SplatElementsAttr>()) { 407 if (rhs.getSplatValue<IntegerAttr>().getValue() == 1) 408 return getLhs(); 409 } 410 411 return overflowOrDiv0 ? Attribute() : result; 412 } 413 414 //===----------------------------------------------------------------------===// 415 // FloorDivSIOp 416 //===----------------------------------------------------------------------===// 417 418 OpFoldResult arith::FloorDivSIOp::fold(ArrayRef<Attribute> operands) { 419 // Don't fold if it would overflow or if it requires a division by zero. 420 bool overflowOrDiv0 = false; 421 auto result = 422 constFoldBinaryOp<IntegerAttr>(operands, [&](APInt a, const APInt &b) { 423 if (overflowOrDiv0 || !b) { 424 overflowOrDiv0 = true; 425 return a; 426 } 427 if (!a) 428 return a; 429 // After this point we know that neither a or b are zero. 430 unsigned bits = a.getBitWidth(); 431 APInt zero = APInt::getZero(bits); 432 bool aGtZero = a.sgt(zero); 433 bool bGtZero = b.sgt(zero); 434 if (aGtZero && bGtZero) { 435 // Both positive, return a / b. 436 return a.sdiv_ov(b, overflowOrDiv0); 437 } 438 if (!aGtZero && !bGtZero) { 439 // Both negative, return -a / -b. 440 APInt posA = zero.ssub_ov(a, overflowOrDiv0); 441 APInt posB = zero.ssub_ov(b, overflowOrDiv0); 442 return posA.sdiv_ov(posB, overflowOrDiv0); 443 } 444 if (!aGtZero && bGtZero) { 445 // A is negative, b is positive, return - ceil(-a, b). 446 APInt posA = zero.ssub_ov(a, overflowOrDiv0); 447 APInt ceil = signedCeilNonnegInputs(posA, b, overflowOrDiv0); 448 return zero.ssub_ov(ceil, overflowOrDiv0); 449 } 450 // A is positive, b is negative, return - ceil(a, -b). 451 APInt posB = zero.ssub_ov(b, overflowOrDiv0); 452 APInt ceil = signedCeilNonnegInputs(a, posB, overflowOrDiv0); 453 return zero.ssub_ov(ceil, overflowOrDiv0); 454 }); 455 456 // Fold out floor division by one. 
Assumes all tensors of all ones are 457 // splats. 458 if (auto rhs = operands[1].dyn_cast_or_null<IntegerAttr>()) { 459 if (rhs.getValue() == 1) 460 return getLhs(); 461 } else if (auto rhs = operands[1].dyn_cast_or_null<SplatElementsAttr>()) { 462 if (rhs.getSplatValue<IntegerAttr>().getValue() == 1) 463 return getLhs(); 464 } 465 466 return overflowOrDiv0 ? Attribute() : result; 467 } 468 469 //===----------------------------------------------------------------------===// 470 // RemUIOp 471 //===----------------------------------------------------------------------===// 472 473 OpFoldResult arith::RemUIOp::fold(ArrayRef<Attribute> operands) { 474 auto rhs = operands.back().dyn_cast_or_null<IntegerAttr>(); 475 if (!rhs) 476 return {}; 477 auto rhsValue = rhs.getValue(); 478 479 // x % 1 = 0 480 if (rhsValue.isOneValue()) 481 return IntegerAttr::get(rhs.getType(), APInt(rhsValue.getBitWidth(), 0)); 482 483 // Don't fold if it requires division by zero. 484 if (rhsValue.isNullValue()) 485 return {}; 486 487 auto lhs = operands.front().dyn_cast_or_null<IntegerAttr>(); 488 if (!lhs) 489 return {}; 490 return IntegerAttr::get(lhs.getType(), lhs.getValue().urem(rhsValue)); 491 } 492 493 //===----------------------------------------------------------------------===// 494 // RemSIOp 495 //===----------------------------------------------------------------------===// 496 497 OpFoldResult arith::RemSIOp::fold(ArrayRef<Attribute> operands) { 498 auto rhs = operands.back().dyn_cast_or_null<IntegerAttr>(); 499 if (!rhs) 500 return {}; 501 auto rhsValue = rhs.getValue(); 502 503 // x % 1 = 0 504 if (rhsValue.isOneValue()) 505 return IntegerAttr::get(rhs.getType(), APInt(rhsValue.getBitWidth(), 0)); 506 507 // Don't fold if it requires division by zero. 
508 if (rhsValue.isNullValue()) 509 return {}; 510 511 auto lhs = operands.front().dyn_cast_or_null<IntegerAttr>(); 512 if (!lhs) 513 return {}; 514 return IntegerAttr::get(lhs.getType(), lhs.getValue().srem(rhsValue)); 515 } 516 517 //===----------------------------------------------------------------------===// 518 // AndIOp 519 //===----------------------------------------------------------------------===// 520 521 OpFoldResult arith::AndIOp::fold(ArrayRef<Attribute> operands) { 522 /// and(x, 0) -> 0 523 if (matchPattern(getRhs(), m_Zero())) 524 return getRhs(); 525 /// and(x, allOnes) -> x 526 APInt intValue; 527 if (matchPattern(getRhs(), m_ConstantInt(&intValue)) && intValue.isAllOnes()) 528 return getLhs(); 529 530 return constFoldBinaryOp<IntegerAttr>( 531 operands, [](APInt a, const APInt &b) { return std::move(a) & b; }); 532 } 533 534 //===----------------------------------------------------------------------===// 535 // OrIOp 536 //===----------------------------------------------------------------------===// 537 538 OpFoldResult arith::OrIOp::fold(ArrayRef<Attribute> operands) { 539 /// or(x, 0) -> x 540 if (matchPattern(getRhs(), m_Zero())) 541 return getLhs(); 542 /// or(x, <all ones>) -> <all ones> 543 if (auto rhsAttr = operands[1].dyn_cast_or_null<IntegerAttr>()) 544 if (rhsAttr.getValue().isAllOnes()) 545 return rhsAttr; 546 547 return constFoldBinaryOp<IntegerAttr>( 548 operands, [](APInt a, const APInt &b) { return std::move(a) | b; }); 549 } 550 551 //===----------------------------------------------------------------------===// 552 // XOrIOp 553 //===----------------------------------------------------------------------===// 554 555 OpFoldResult arith::XOrIOp::fold(ArrayRef<Attribute> operands) { 556 /// xor(x, 0) -> x 557 if (matchPattern(getRhs(), m_Zero())) 558 return getLhs(); 559 /// xor(x, x) -> 0 560 if (getLhs() == getRhs()) 561 return Builder(getContext()).getZeroAttr(getType()); 562 /// xor(xor(x, a), a) -> x 563 if (arith::XOrIOp 
prev = getLhs().getDefiningOp<arith::XOrIOp>()) 564 if (prev.getRhs() == getRhs()) 565 return prev.getLhs(); 566 567 return constFoldBinaryOp<IntegerAttr>( 568 operands, [](APInt a, const APInt &b) { return std::move(a) ^ b; }); 569 } 570 571 void arith::XOrIOp::getCanonicalizationPatterns( 572 RewritePatternSet &patterns, MLIRContext *context) { 573 patterns.add<XOrINotCmpI>(context); 574 } 575 576 //===----------------------------------------------------------------------===// 577 // NegFOp 578 //===----------------------------------------------------------------------===// 579 580 OpFoldResult arith::NegFOp::fold(ArrayRef<Attribute> operands) { 581 return constFoldUnaryOp<FloatAttr>(operands, 582 [](const APFloat &a) { return -a; }); 583 } 584 585 //===----------------------------------------------------------------------===// 586 // AddFOp 587 //===----------------------------------------------------------------------===// 588 589 OpFoldResult arith::AddFOp::fold(ArrayRef<Attribute> operands) { 590 // addf(x, -0) -> x 591 if (matchPattern(getRhs(), m_NegZeroFloat())) 592 return getLhs(); 593 594 return constFoldBinaryOp<FloatAttr>( 595 operands, [](const APFloat &a, const APFloat &b) { return a + b; }); 596 } 597 598 //===----------------------------------------------------------------------===// 599 // SubFOp 600 //===----------------------------------------------------------------------===// 601 602 OpFoldResult arith::SubFOp::fold(ArrayRef<Attribute> operands) { 603 // subf(x, +0) -> x 604 if (matchPattern(getRhs(), m_PosZeroFloat())) 605 return getLhs(); 606 607 return constFoldBinaryOp<FloatAttr>( 608 operands, [](const APFloat &a, const APFloat &b) { return a - b; }); 609 } 610 611 //===----------------------------------------------------------------------===// 612 // MaxFOp 613 //===----------------------------------------------------------------------===// 614 615 OpFoldResult arith::MaxFOp::fold(ArrayRef<Attribute> operands) { 616 
assert(operands.size() == 2 && "maxf takes two operands"); 617 618 // maxf(x,x) -> x 619 if (getLhs() == getRhs()) 620 return getRhs(); 621 622 // maxf(x, -inf) -> x 623 if (matchPattern(getRhs(), m_NegInfFloat())) 624 return getLhs(); 625 626 return constFoldBinaryOp<FloatAttr>( 627 operands, 628 [](const APFloat &a, const APFloat &b) { return llvm::maximum(a, b); }); 629 } 630 631 //===----------------------------------------------------------------------===// 632 // MaxSIOp 633 //===----------------------------------------------------------------------===// 634 635 OpFoldResult MaxSIOp::fold(ArrayRef<Attribute> operands) { 636 assert(operands.size() == 2 && "binary operation takes two operands"); 637 638 // maxsi(x,x) -> x 639 if (getLhs() == getRhs()) 640 return getRhs(); 641 642 APInt intValue; 643 // maxsi(x,MAX_INT) -> MAX_INT 644 if (matchPattern(getRhs(), m_ConstantInt(&intValue)) && 645 intValue.isMaxSignedValue()) 646 return getRhs(); 647 648 // maxsi(x, MIN_INT) -> x 649 if (matchPattern(getRhs(), m_ConstantInt(&intValue)) && 650 intValue.isMinSignedValue()) 651 return getLhs(); 652 653 return constFoldBinaryOp<IntegerAttr>(operands, 654 [](const APInt &a, const APInt &b) { 655 return llvm::APIntOps::smax(a, b); 656 }); 657 } 658 659 //===----------------------------------------------------------------------===// 660 // MaxUIOp 661 //===----------------------------------------------------------------------===// 662 663 OpFoldResult MaxUIOp::fold(ArrayRef<Attribute> operands) { 664 assert(operands.size() == 2 && "binary operation takes two operands"); 665 666 // maxui(x,x) -> x 667 if (getLhs() == getRhs()) 668 return getRhs(); 669 670 APInt intValue; 671 // maxui(x,MAX_INT) -> MAX_INT 672 if (matchPattern(getRhs(), m_ConstantInt(&intValue)) && intValue.isMaxValue()) 673 return getRhs(); 674 675 // maxui(x, MIN_INT) -> x 676 if (matchPattern(getRhs(), m_ConstantInt(&intValue)) && intValue.isMinValue()) 677 return getLhs(); 678 679 return 
constFoldBinaryOp<IntegerAttr>(operands, 680 [](const APInt &a, const APInt &b) { 681 return llvm::APIntOps::umax(a, b); 682 }); 683 } 684 685 //===----------------------------------------------------------------------===// 686 // MinFOp 687 //===----------------------------------------------------------------------===// 688 689 OpFoldResult arith::MinFOp::fold(ArrayRef<Attribute> operands) { 690 assert(operands.size() == 2 && "minf takes two operands"); 691 692 // minf(x,x) -> x 693 if (getLhs() == getRhs()) 694 return getRhs(); 695 696 // minf(x, +inf) -> x 697 if (matchPattern(getRhs(), m_PosInfFloat())) 698 return getLhs(); 699 700 return constFoldBinaryOp<FloatAttr>( 701 operands, 702 [](const APFloat &a, const APFloat &b) { return llvm::minimum(a, b); }); 703 } 704 705 //===----------------------------------------------------------------------===// 706 // MinSIOp 707 //===----------------------------------------------------------------------===// 708 709 OpFoldResult MinSIOp::fold(ArrayRef<Attribute> operands) { 710 assert(operands.size() == 2 && "binary operation takes two operands"); 711 712 // minsi(x,x) -> x 713 if (getLhs() == getRhs()) 714 return getRhs(); 715 716 APInt intValue; 717 // minsi(x,MIN_INT) -> MIN_INT 718 if (matchPattern(getRhs(), m_ConstantInt(&intValue)) && 719 intValue.isMinSignedValue()) 720 return getRhs(); 721 722 // minsi(x, MAX_INT) -> x 723 if (matchPattern(getRhs(), m_ConstantInt(&intValue)) && 724 intValue.isMaxSignedValue()) 725 return getLhs(); 726 727 return constFoldBinaryOp<IntegerAttr>(operands, 728 [](const APInt &a, const APInt &b) { 729 return llvm::APIntOps::smin(a, b); 730 }); 731 } 732 733 //===----------------------------------------------------------------------===// 734 // MinUIOp 735 //===----------------------------------------------------------------------===// 736 737 OpFoldResult MinUIOp::fold(ArrayRef<Attribute> operands) { 738 assert(operands.size() == 2 && "binary operation takes two operands"); 739 740 // 
minui(x,x) -> x 741 if (getLhs() == getRhs()) 742 return getRhs(); 743 744 APInt intValue; 745 // minui(x,MIN_INT) -> MIN_INT 746 if (matchPattern(getRhs(), m_ConstantInt(&intValue)) && intValue.isMinValue()) 747 return getRhs(); 748 749 // minui(x, MAX_INT) -> x 750 if (matchPattern(getRhs(), m_ConstantInt(&intValue)) && intValue.isMaxValue()) 751 return getLhs(); 752 753 return constFoldBinaryOp<IntegerAttr>(operands, 754 [](const APInt &a, const APInt &b) { 755 return llvm::APIntOps::umin(a, b); 756 }); 757 } 758 759 //===----------------------------------------------------------------------===// 760 // MulFOp 761 //===----------------------------------------------------------------------===// 762 763 OpFoldResult arith::MulFOp::fold(ArrayRef<Attribute> operands) { 764 // mulf(x, 1) -> x 765 if (matchPattern(getRhs(), m_OneFloat())) 766 return getLhs(); 767 768 return constFoldBinaryOp<FloatAttr>( 769 operands, [](const APFloat &a, const APFloat &b) { return a * b; }); 770 } 771 772 //===----------------------------------------------------------------------===// 773 // DivFOp 774 //===----------------------------------------------------------------------===// 775 776 OpFoldResult arith::DivFOp::fold(ArrayRef<Attribute> operands) { 777 // divf(x, 1) -> x 778 if (matchPattern(getRhs(), m_OneFloat())) 779 return getLhs(); 780 781 return constFoldBinaryOp<FloatAttr>( 782 operands, [](const APFloat &a, const APFloat &b) { return a / b; }); 783 } 784 785 //===----------------------------------------------------------------------===// 786 // Utility functions for verifying cast ops 787 //===----------------------------------------------------------------------===// 788 789 template <typename... Types> 790 using type_list = std::tuple<Types...> *; 791 792 /// Returns a non-null type only if the provided type is one of the allowed 793 /// types or one of the allowed shaped types of the allowed types. Returns the 794 /// element type if a valid shaped type is provided. 
795 template <typename... ShapedTypes, typename... ElementTypes> 796 static Type getUnderlyingType(Type type, type_list<ShapedTypes...>, 797 type_list<ElementTypes...>) { 798 if (type.isa<ShapedType>() && !type.isa<ShapedTypes...>()) 799 return {}; 800 801 auto underlyingType = getElementTypeOrSelf(type); 802 if (!underlyingType.isa<ElementTypes...>()) 803 return {}; 804 805 return underlyingType; 806 } 807 808 /// Get allowed underlying types for vectors and tensors. 809 template <typename... ElementTypes> 810 static Type getTypeIfLike(Type type) { 811 return getUnderlyingType(type, type_list<VectorType, TensorType>(), 812 type_list<ElementTypes...>()); 813 } 814 815 /// Get allowed underlying types for vectors, tensors, and memrefs. 816 template <typename... ElementTypes> 817 static Type getTypeIfLikeOrMemRef(Type type) { 818 return getUnderlyingType(type, 819 type_list<VectorType, TensorType, MemRefType>(), 820 type_list<ElementTypes...>()); 821 } 822 823 static bool areValidCastInputsAndOutputs(TypeRange inputs, TypeRange outputs) { 824 return inputs.size() == 1 && outputs.size() == 1 && 825 succeeded(verifyCompatibleShapes(inputs.front(), outputs.front())); 826 } 827 828 //===----------------------------------------------------------------------===// 829 // Verifiers for integer and floating point extension/truncation ops 830 //===----------------------------------------------------------------------===// 831 832 // Extend ops can only extend to a wider type. 833 template <typename ValType, typename Op> 834 static LogicalResult verifyExtOp(Op op) { 835 Type srcType = getElementTypeOrSelf(op.getIn().getType()); 836 Type dstType = getElementTypeOrSelf(op.getType()); 837 838 if (srcType.cast<ValType>().getWidth() >= dstType.cast<ValType>().getWidth()) 839 return op.emitError("result type ") 840 << dstType << " must be wider than operand type " << srcType; 841 842 return success(); 843 } 844 845 // Truncate ops can only truncate to a shorter type. 
846 template <typename ValType, typename Op> 847 static LogicalResult verifyTruncateOp(Op op) { 848 Type srcType = getElementTypeOrSelf(op.getIn().getType()); 849 Type dstType = getElementTypeOrSelf(op.getType()); 850 851 if (srcType.cast<ValType>().getWidth() <= dstType.cast<ValType>().getWidth()) 852 return op.emitError("result type ") 853 << dstType << " must be shorter than operand type " << srcType; 854 855 return success(); 856 } 857 858 /// Validate a cast that changes the width of a type. 859 template <template <typename> class WidthComparator, typename... ElementTypes> 860 static bool checkWidthChangeCast(TypeRange inputs, TypeRange outputs) { 861 if (!areValidCastInputsAndOutputs(inputs, outputs)) 862 return false; 863 864 auto srcType = getTypeIfLike<ElementTypes...>(inputs.front()); 865 auto dstType = getTypeIfLike<ElementTypes...>(outputs.front()); 866 if (!srcType || !dstType) 867 return false; 868 869 return WidthComparator<unsigned>()(dstType.getIntOrFloatBitWidth(), 870 srcType.getIntOrFloatBitWidth()); 871 } 872 873 //===----------------------------------------------------------------------===// 874 // ExtUIOp 875 //===----------------------------------------------------------------------===// 876 877 OpFoldResult arith::ExtUIOp::fold(ArrayRef<Attribute> operands) { 878 if (auto lhs = getIn().getDefiningOp<ExtUIOp>()) { 879 getInMutable().assign(lhs.getIn()); 880 return getResult(); 881 } 882 Type resType = getType(); 883 unsigned bitWidth; 884 if (auto shapedType = resType.dyn_cast<ShapedType>()) 885 bitWidth = shapedType.getElementTypeBitWidth(); 886 else 887 bitWidth = resType.getIntOrFloatBitWidth(); 888 return constFoldCastOp<IntegerAttr, IntegerAttr>( 889 operands, getType(), [bitWidth](const APInt &a, bool &castStatus) { 890 return a.zext(bitWidth); 891 }); 892 } 893 894 bool arith::ExtUIOp::areCastCompatible(TypeRange inputs, TypeRange outputs) { 895 return checkWidthChangeCast<std::greater, IntegerType>(inputs, outputs); 896 } 897 898 
LogicalResult arith::ExtUIOp::verify() { 899 return verifyExtOp<IntegerType>(*this); 900 } 901 902 //===----------------------------------------------------------------------===// 903 // ExtSIOp 904 //===----------------------------------------------------------------------===// 905 906 OpFoldResult arith::ExtSIOp::fold(ArrayRef<Attribute> operands) { 907 if (auto lhs = getIn().getDefiningOp<ExtSIOp>()) { 908 getInMutable().assign(lhs.getIn()); 909 return getResult(); 910 } 911 Type resType = getType(); 912 unsigned bitWidth; 913 if (auto shapedType = resType.dyn_cast<ShapedType>()) 914 bitWidth = shapedType.getElementTypeBitWidth(); 915 else 916 bitWidth = resType.getIntOrFloatBitWidth(); 917 return constFoldCastOp<IntegerAttr, IntegerAttr>( 918 operands, getType(), [bitWidth](const APInt &a, bool &castStatus) { 919 return a.sext(bitWidth); 920 }); 921 } 922 923 bool arith::ExtSIOp::areCastCompatible(TypeRange inputs, TypeRange outputs) { 924 return checkWidthChangeCast<std::greater, IntegerType>(inputs, outputs); 925 } 926 927 void arith::ExtSIOp::getCanonicalizationPatterns( 928 RewritePatternSet &patterns, MLIRContext *context) { 929 patterns.add<ExtSIOfExtUI>(context); 930 } 931 932 LogicalResult arith::ExtSIOp::verify() { 933 return verifyExtOp<IntegerType>(*this); 934 } 935 936 //===----------------------------------------------------------------------===// 937 // ExtFOp 938 //===----------------------------------------------------------------------===// 939 940 bool arith::ExtFOp::areCastCompatible(TypeRange inputs, TypeRange outputs) { 941 return checkWidthChangeCast<std::greater, FloatType>(inputs, outputs); 942 } 943 944 LogicalResult arith::ExtFOp::verify() { return verifyExtOp<FloatType>(*this); } 945 946 //===----------------------------------------------------------------------===// 947 // TruncIOp 948 //===----------------------------------------------------------------------===// 949 950 OpFoldResult arith::TruncIOp::fold(ArrayRef<Attribute> 
operands) { 951 assert(operands.size() == 1 && "unary operation takes one operand"); 952 953 // trunci(zexti(a)) -> a 954 // trunci(sexti(a)) -> a 955 if (matchPattern(getOperand(), m_Op<arith::ExtUIOp>()) || 956 matchPattern(getOperand(), m_Op<arith::ExtSIOp>())) 957 return getOperand().getDefiningOp()->getOperand(0); 958 959 // trunci(trunci(a)) -> trunci(a)) 960 if (matchPattern(getOperand(), m_Op<arith::TruncIOp>())) { 961 setOperand(getOperand().getDefiningOp()->getOperand(0)); 962 return getResult(); 963 } 964 965 Type resType = getType(); 966 unsigned bitWidth; 967 if (auto shapedType = resType.dyn_cast<ShapedType>()) 968 bitWidth = shapedType.getElementTypeBitWidth(); 969 else 970 bitWidth = resType.getIntOrFloatBitWidth(); 971 972 return constFoldCastOp<IntegerAttr, IntegerAttr>( 973 operands, getType(), [bitWidth](const APInt &a, bool &castStatus) { 974 return a.trunc(bitWidth); 975 }); 976 } 977 978 bool arith::TruncIOp::areCastCompatible(TypeRange inputs, TypeRange outputs) { 979 return checkWidthChangeCast<std::less, IntegerType>(inputs, outputs); 980 } 981 982 LogicalResult arith::TruncIOp::verify() { 983 return verifyTruncateOp<IntegerType>(*this); 984 } 985 986 //===----------------------------------------------------------------------===// 987 // TruncFOp 988 //===----------------------------------------------------------------------===// 989 990 /// Perform safe const propagation for truncf, i.e. only propagate if FP value 991 /// can be represented without precision loss or rounding. 992 OpFoldResult arith::TruncFOp::fold(ArrayRef<Attribute> operands) { 993 assert(operands.size() == 1 && "unary operation takes one operand"); 994 995 auto constOperand = operands.front(); 996 if (!constOperand || !constOperand.isa<FloatAttr>()) 997 return {}; 998 999 // Convert to target type via 'double'. 
1000 double sourceValue = 1001 constOperand.dyn_cast<FloatAttr>().getValue().convertToDouble(); 1002 auto targetAttr = FloatAttr::get(getType(), sourceValue); 1003 1004 // Propagate if constant's value does not change after truncation. 1005 if (sourceValue == targetAttr.getValue().convertToDouble()) 1006 return targetAttr; 1007 1008 return {}; 1009 } 1010 1011 bool arith::TruncFOp::areCastCompatible(TypeRange inputs, TypeRange outputs) { 1012 return checkWidthChangeCast<std::less, FloatType>(inputs, outputs); 1013 } 1014 1015 LogicalResult arith::TruncFOp::verify() { 1016 return verifyTruncateOp<FloatType>(*this); 1017 } 1018 1019 //===----------------------------------------------------------------------===// 1020 // AndIOp 1021 //===----------------------------------------------------------------------===// 1022 1023 void arith::AndIOp::getCanonicalizationPatterns( 1024 RewritePatternSet &patterns, MLIRContext *context) { 1025 patterns.add<AndOfExtUI, AndOfExtSI>(context); 1026 } 1027 1028 //===----------------------------------------------------------------------===// 1029 // OrIOp 1030 //===----------------------------------------------------------------------===// 1031 1032 void arith::OrIOp::getCanonicalizationPatterns( 1033 RewritePatternSet &patterns, MLIRContext *context) { 1034 patterns.add<OrOfExtUI, OrOfExtSI>(context); 1035 } 1036 1037 //===----------------------------------------------------------------------===// 1038 // Verifiers for casts between integers and floats. 
1039 //===----------------------------------------------------------------------===// 1040 1041 template <typename From, typename To> 1042 static bool checkIntFloatCast(TypeRange inputs, TypeRange outputs) { 1043 if (!areValidCastInputsAndOutputs(inputs, outputs)) 1044 return false; 1045 1046 auto srcType = getTypeIfLike<From>(inputs.front()); 1047 auto dstType = getTypeIfLike<To>(outputs.back()); 1048 1049 return srcType && dstType; 1050 } 1051 1052 //===----------------------------------------------------------------------===// 1053 // UIToFPOp 1054 //===----------------------------------------------------------------------===// 1055 1056 bool arith::UIToFPOp::areCastCompatible(TypeRange inputs, TypeRange outputs) { 1057 return checkIntFloatCast<IntegerType, FloatType>(inputs, outputs); 1058 } 1059 1060 OpFoldResult arith::UIToFPOp::fold(ArrayRef<Attribute> operands) { 1061 Type resType = getType(); 1062 Type resEleType; 1063 if (auto shapedType = resType.dyn_cast<ShapedType>()) 1064 resEleType = shapedType.getElementType(); 1065 else 1066 resEleType = resType; 1067 return constFoldCastOp<IntegerAttr, FloatAttr>( 1068 operands, getType(), [&resEleType](const APInt &a, bool &castStatus) { 1069 FloatType floatTy = resEleType.cast<FloatType>(); 1070 APFloat apf(floatTy.getFloatSemantics(), 1071 APInt::getZero(floatTy.getWidth())); 1072 apf.convertFromAPInt(a, /*IsSigned=*/false, 1073 APFloat::rmNearestTiesToEven); 1074 return apf; 1075 }); 1076 } 1077 1078 //===----------------------------------------------------------------------===// 1079 // SIToFPOp 1080 //===----------------------------------------------------------------------===// 1081 1082 bool arith::SIToFPOp::areCastCompatible(TypeRange inputs, TypeRange outputs) { 1083 return checkIntFloatCast<IntegerType, FloatType>(inputs, outputs); 1084 } 1085 1086 OpFoldResult arith::SIToFPOp::fold(ArrayRef<Attribute> operands) { 1087 Type resType = getType(); 1088 Type resEleType; 1089 if (auto shapedType = 
resType.dyn_cast<ShapedType>()) 1090 resEleType = shapedType.getElementType(); 1091 else 1092 resEleType = resType; 1093 return constFoldCastOp<IntegerAttr, FloatAttr>( 1094 operands, getType(), [&resEleType](const APInt &a, bool &castStatus) { 1095 FloatType floatTy = resEleType.cast<FloatType>(); 1096 APFloat apf(floatTy.getFloatSemantics(), 1097 APInt::getZero(floatTy.getWidth())); 1098 apf.convertFromAPInt(a, /*IsSigned=*/true, 1099 APFloat::rmNearestTiesToEven); 1100 return apf; 1101 }); 1102 } 1103 //===----------------------------------------------------------------------===// 1104 // FPToUIOp 1105 //===----------------------------------------------------------------------===// 1106 1107 bool arith::FPToUIOp::areCastCompatible(TypeRange inputs, TypeRange outputs) { 1108 return checkIntFloatCast<FloatType, IntegerType>(inputs, outputs); 1109 } 1110 1111 OpFoldResult arith::FPToUIOp::fold(ArrayRef<Attribute> operands) { 1112 Type resType = getType(); 1113 Type resEleType; 1114 if (auto shapedType = resType.dyn_cast<ShapedType>()) 1115 resEleType = shapedType.getElementType(); 1116 else 1117 resEleType = resType; 1118 return constFoldCastOp<FloatAttr, IntegerAttr>( 1119 operands, getType(), [&resEleType](const APFloat &a, bool &castStatus) { 1120 IntegerType intTy = resEleType.cast<IntegerType>(); 1121 bool ignored; 1122 APSInt api(intTy.getWidth(), /*isUnsigned=*/true); 1123 castStatus = APFloat::opInvalidOp != 1124 a.convertToInteger(api, APFloat::rmTowardZero, &ignored); 1125 return api; 1126 }); 1127 } 1128 1129 //===----------------------------------------------------------------------===// 1130 // FPToSIOp 1131 //===----------------------------------------------------------------------===// 1132 1133 bool arith::FPToSIOp::areCastCompatible(TypeRange inputs, TypeRange outputs) { 1134 return checkIntFloatCast<FloatType, IntegerType>(inputs, outputs); 1135 } 1136 1137 OpFoldResult arith::FPToSIOp::fold(ArrayRef<Attribute> operands) { 1138 Type resType = 
getType(); 1139 Type resEleType; 1140 if (auto shapedType = resType.dyn_cast<ShapedType>()) 1141 resEleType = shapedType.getElementType(); 1142 else 1143 resEleType = resType; 1144 return constFoldCastOp<FloatAttr, IntegerAttr>( 1145 operands, getType(), [&resEleType](const APFloat &a, bool &castStatus) { 1146 IntegerType intTy = resEleType.cast<IntegerType>(); 1147 bool ignored; 1148 APSInt api(intTy.getWidth(), /*isUnsigned=*/false); 1149 castStatus = APFloat::opInvalidOp != 1150 a.convertToInteger(api, APFloat::rmTowardZero, &ignored); 1151 return api; 1152 }); 1153 } 1154 1155 //===----------------------------------------------------------------------===// 1156 // IndexCastOp 1157 //===----------------------------------------------------------------------===// 1158 1159 bool arith::IndexCastOp::areCastCompatible(TypeRange inputs, 1160 TypeRange outputs) { 1161 if (!areValidCastInputsAndOutputs(inputs, outputs)) 1162 return false; 1163 1164 auto srcType = getTypeIfLikeOrMemRef<IntegerType, IndexType>(inputs.front()); 1165 auto dstType = getTypeIfLikeOrMemRef<IntegerType, IndexType>(outputs.front()); 1166 if (!srcType || !dstType) 1167 return false; 1168 1169 return (srcType.isIndex() && dstType.isSignlessInteger()) || 1170 (srcType.isSignlessInteger() && dstType.isIndex()); 1171 } 1172 1173 OpFoldResult arith::IndexCastOp::fold(ArrayRef<Attribute> operands) { 1174 // index_cast(constant) -> constant 1175 // A little hack because we go through int. Otherwise, the size of the 1176 // constant might need to change. 
1177 if (auto value = operands[0].dyn_cast_or_null<IntegerAttr>()) 1178 return IntegerAttr::get(getType(), value.getInt()); 1179 1180 return {}; 1181 } 1182 1183 void arith::IndexCastOp::getCanonicalizationPatterns( 1184 RewritePatternSet &patterns, MLIRContext *context) { 1185 patterns.add<IndexCastOfIndexCast, IndexCastOfExtSI>(context); 1186 } 1187 1188 //===----------------------------------------------------------------------===// 1189 // BitcastOp 1190 //===----------------------------------------------------------------------===// 1191 1192 bool arith::BitcastOp::areCastCompatible(TypeRange inputs, TypeRange outputs) { 1193 if (!areValidCastInputsAndOutputs(inputs, outputs)) 1194 return false; 1195 1196 auto srcType = 1197 getTypeIfLikeOrMemRef<IntegerType, IndexType, FloatType>(inputs.front()); 1198 auto dstType = 1199 getTypeIfLikeOrMemRef<IntegerType, IndexType, FloatType>(outputs.front()); 1200 if (!srcType || !dstType) 1201 return false; 1202 1203 return srcType.getIntOrFloatBitWidth() == dstType.getIntOrFloatBitWidth(); 1204 } 1205 1206 OpFoldResult arith::BitcastOp::fold(ArrayRef<Attribute> operands) { 1207 assert(operands.size() == 1 && "bitcast op expects 1 operand"); 1208 1209 auto resType = getType(); 1210 auto operand = operands[0]; 1211 if (!operand) 1212 return {}; 1213 1214 /// Bitcast dense elements. 1215 if (auto denseAttr = operand.dyn_cast_or_null<DenseElementsAttr>()) 1216 return denseAttr.bitcast(resType.cast<ShapedType>().getElementType()); 1217 /// Other shaped types unhandled. 1218 if (resType.isa<ShapedType>()) 1219 return {}; 1220 1221 /// Bitcast integer or float to integer or float. 1222 APInt bits = operand.isa<FloatAttr>() 1223 ? 
operand.cast<FloatAttr>().getValue().bitcastToAPInt() 1224 : operand.cast<IntegerAttr>().getValue(); 1225 1226 if (auto resFloatType = resType.dyn_cast<FloatType>()) 1227 return FloatAttr::get(resType, 1228 APFloat(resFloatType.getFloatSemantics(), bits)); 1229 return IntegerAttr::get(resType, bits); 1230 } 1231 1232 void arith::BitcastOp::getCanonicalizationPatterns( 1233 RewritePatternSet &patterns, MLIRContext *context) { 1234 patterns.add<BitcastOfBitcast>(context); 1235 } 1236 1237 //===----------------------------------------------------------------------===// 1238 // Helpers for compare ops 1239 //===----------------------------------------------------------------------===// 1240 1241 /// Return the type of the same shape (scalar, vector or tensor) containing i1. 1242 static Type getI1SameShape(Type type) { 1243 auto i1Type = IntegerType::get(type.getContext(), 1); 1244 if (auto tensorType = type.dyn_cast<RankedTensorType>()) 1245 return RankedTensorType::get(tensorType.getShape(), i1Type); 1246 if (type.isa<UnrankedTensorType>()) 1247 return UnrankedTensorType::get(i1Type); 1248 if (auto vectorType = type.dyn_cast<VectorType>()) 1249 return VectorType::get(vectorType.getShape(), i1Type, 1250 vectorType.getNumScalableDims()); 1251 return i1Type; 1252 } 1253 1254 //===----------------------------------------------------------------------===// 1255 // CmpIOp 1256 //===----------------------------------------------------------------------===// 1257 1258 /// Compute `lhs` `pred` `rhs`, where `pred` is one of the known integer 1259 /// comparison predicates. 
1260 bool mlir::arith::applyCmpPredicate(arith::CmpIPredicate predicate, 1261 const APInt &lhs, const APInt &rhs) { 1262 switch (predicate) { 1263 case arith::CmpIPredicate::eq: 1264 return lhs.eq(rhs); 1265 case arith::CmpIPredicate::ne: 1266 return lhs.ne(rhs); 1267 case arith::CmpIPredicate::slt: 1268 return lhs.slt(rhs); 1269 case arith::CmpIPredicate::sle: 1270 return lhs.sle(rhs); 1271 case arith::CmpIPredicate::sgt: 1272 return lhs.sgt(rhs); 1273 case arith::CmpIPredicate::sge: 1274 return lhs.sge(rhs); 1275 case arith::CmpIPredicate::ult: 1276 return lhs.ult(rhs); 1277 case arith::CmpIPredicate::ule: 1278 return lhs.ule(rhs); 1279 case arith::CmpIPredicate::ugt: 1280 return lhs.ugt(rhs); 1281 case arith::CmpIPredicate::uge: 1282 return lhs.uge(rhs); 1283 } 1284 llvm_unreachable("unknown cmpi predicate kind"); 1285 } 1286 1287 /// Returns true if the predicate is true for two equal operands. 1288 static bool applyCmpPredicateToEqualOperands(arith::CmpIPredicate predicate) { 1289 switch (predicate) { 1290 case arith::CmpIPredicate::eq: 1291 case arith::CmpIPredicate::sle: 1292 case arith::CmpIPredicate::sge: 1293 case arith::CmpIPredicate::ule: 1294 case arith::CmpIPredicate::uge: 1295 return true; 1296 case arith::CmpIPredicate::ne: 1297 case arith::CmpIPredicate::slt: 1298 case arith::CmpIPredicate::sgt: 1299 case arith::CmpIPredicate::ult: 1300 case arith::CmpIPredicate::ugt: 1301 return false; 1302 } 1303 llvm_unreachable("unknown cmpi predicate kind"); 1304 } 1305 1306 static Attribute getBoolAttribute(Type type, MLIRContext *ctx, bool value) { 1307 auto boolAttr = BoolAttr::get(ctx, value); 1308 ShapedType shapedType = type.dyn_cast_or_null<ShapedType>(); 1309 if (!shapedType) 1310 return boolAttr; 1311 return DenseElementsAttr::get(shapedType, boolAttr); 1312 } 1313 1314 OpFoldResult arith::CmpIOp::fold(ArrayRef<Attribute> operands) { 1315 assert(operands.size() == 2 && "cmpi takes two operands"); 1316 1317 // cmpi(pred, x, x) 1318 if (getLhs() == 
getRhs()) { 1319 auto val = applyCmpPredicateToEqualOperands(getPredicate()); 1320 return getBoolAttribute(getType(), getContext(), val); 1321 } 1322 1323 if (matchPattern(getRhs(), m_Zero())) { 1324 if (auto extOp = getLhs().getDefiningOp<ExtSIOp>()) { 1325 if (extOp.getOperand().getType().cast<IntegerType>().getWidth() == 1) { 1326 // extsi(%x : i1 -> iN) != 0 -> %x 1327 if (getPredicate() == arith::CmpIPredicate::ne) { 1328 return extOp.getOperand(); 1329 } 1330 } 1331 } 1332 if (auto extOp = getLhs().getDefiningOp<ExtUIOp>()) { 1333 if (extOp.getOperand().getType().cast<IntegerType>().getWidth() == 1) { 1334 // extui(%x : i1 -> iN) != 0 -> %x 1335 if (getPredicate() == arith::CmpIPredicate::ne) { 1336 return extOp.getOperand(); 1337 } 1338 } 1339 } 1340 } 1341 1342 auto lhs = operands.front().dyn_cast_or_null<IntegerAttr>(); 1343 auto rhs = operands.back().dyn_cast_or_null<IntegerAttr>(); 1344 if (!lhs || !rhs) 1345 return {}; 1346 1347 auto val = applyCmpPredicate(getPredicate(), lhs.getValue(), rhs.getValue()); 1348 return BoolAttr::get(getContext(), val); 1349 } 1350 1351 void arith::CmpIOp::getCanonicalizationPatterns(RewritePatternSet &patterns, 1352 MLIRContext *context) { 1353 patterns.insert<CmpIExtSI, CmpIExtUI>(context); 1354 } 1355 1356 //===----------------------------------------------------------------------===// 1357 // CmpFOp 1358 //===----------------------------------------------------------------------===// 1359 1360 /// Compute `lhs` `pred` `rhs`, where `pred` is one of the known floating point 1361 /// comparison predicates. 
1362 bool mlir::arith::applyCmpPredicate(arith::CmpFPredicate predicate, 1363 const APFloat &lhs, const APFloat &rhs) { 1364 auto cmpResult = lhs.compare(rhs); 1365 switch (predicate) { 1366 case arith::CmpFPredicate::AlwaysFalse: 1367 return false; 1368 case arith::CmpFPredicate::OEQ: 1369 return cmpResult == APFloat::cmpEqual; 1370 case arith::CmpFPredicate::OGT: 1371 return cmpResult == APFloat::cmpGreaterThan; 1372 case arith::CmpFPredicate::OGE: 1373 return cmpResult == APFloat::cmpGreaterThan || 1374 cmpResult == APFloat::cmpEqual; 1375 case arith::CmpFPredicate::OLT: 1376 return cmpResult == APFloat::cmpLessThan; 1377 case arith::CmpFPredicate::OLE: 1378 return cmpResult == APFloat::cmpLessThan || cmpResult == APFloat::cmpEqual; 1379 case arith::CmpFPredicate::ONE: 1380 return cmpResult != APFloat::cmpUnordered && cmpResult != APFloat::cmpEqual; 1381 case arith::CmpFPredicate::ORD: 1382 return cmpResult != APFloat::cmpUnordered; 1383 case arith::CmpFPredicate::UEQ: 1384 return cmpResult == APFloat::cmpUnordered || cmpResult == APFloat::cmpEqual; 1385 case arith::CmpFPredicate::UGT: 1386 return cmpResult == APFloat::cmpUnordered || 1387 cmpResult == APFloat::cmpGreaterThan; 1388 case arith::CmpFPredicate::UGE: 1389 return cmpResult == APFloat::cmpUnordered || 1390 cmpResult == APFloat::cmpGreaterThan || 1391 cmpResult == APFloat::cmpEqual; 1392 case arith::CmpFPredicate::ULT: 1393 return cmpResult == APFloat::cmpUnordered || 1394 cmpResult == APFloat::cmpLessThan; 1395 case arith::CmpFPredicate::ULE: 1396 return cmpResult == APFloat::cmpUnordered || 1397 cmpResult == APFloat::cmpLessThan || cmpResult == APFloat::cmpEqual; 1398 case arith::CmpFPredicate::UNE: 1399 return cmpResult != APFloat::cmpEqual; 1400 case arith::CmpFPredicate::UNO: 1401 return cmpResult == APFloat::cmpUnordered; 1402 case arith::CmpFPredicate::AlwaysTrue: 1403 return true; 1404 } 1405 llvm_unreachable("unknown cmpf predicate kind"); 1406 } 1407 1408 OpFoldResult 
arith::CmpFOp::fold(ArrayRef<Attribute> operands) { 1409 assert(operands.size() == 2 && "cmpf takes two operands"); 1410 1411 auto lhs = operands.front().dyn_cast_or_null<FloatAttr>(); 1412 auto rhs = operands.back().dyn_cast_or_null<FloatAttr>(); 1413 1414 // If one operand is NaN, making them both NaN does not change the result. 1415 if (lhs && lhs.getValue().isNaN()) 1416 rhs = lhs; 1417 if (rhs && rhs.getValue().isNaN()) 1418 lhs = rhs; 1419 1420 if (!lhs || !rhs) 1421 return {}; 1422 1423 auto val = applyCmpPredicate(getPredicate(), lhs.getValue(), rhs.getValue()); 1424 return BoolAttr::get(getContext(), val); 1425 } 1426 1427 class CmpFIntToFPConst final : public OpRewritePattern<CmpFOp> { 1428 public: 1429 using OpRewritePattern<CmpFOp>::OpRewritePattern; 1430 1431 static CmpIPredicate convertToIntegerPredicate(CmpFPredicate pred, 1432 bool isUnsigned) { 1433 using namespace arith; 1434 switch (pred) { 1435 case CmpFPredicate::UEQ: 1436 case CmpFPredicate::OEQ: 1437 return CmpIPredicate::eq; 1438 case CmpFPredicate::UGT: 1439 case CmpFPredicate::OGT: 1440 return isUnsigned ? CmpIPredicate::ugt : CmpIPredicate::sgt; 1441 case CmpFPredicate::UGE: 1442 case CmpFPredicate::OGE: 1443 return isUnsigned ? CmpIPredicate::uge : CmpIPredicate::sge; 1444 case CmpFPredicate::ULT: 1445 case CmpFPredicate::OLT: 1446 return isUnsigned ? CmpIPredicate::ult : CmpIPredicate::slt; 1447 case CmpFPredicate::ULE: 1448 case CmpFPredicate::OLE: 1449 return isUnsigned ? CmpIPredicate::ule : CmpIPredicate::sle; 1450 case CmpFPredicate::UNE: 1451 case CmpFPredicate::ONE: 1452 return CmpIPredicate::ne; 1453 default: 1454 llvm_unreachable("Unexpected predicate!"); 1455 } 1456 } 1457 1458 LogicalResult matchAndRewrite(CmpFOp op, 1459 PatternRewriter &rewriter) const override { 1460 FloatAttr flt; 1461 if (!matchPattern(op.getRhs(), m_Constant(&flt))) 1462 return failure(); 1463 1464 const APFloat &rhs = flt.getValue(); 1465 1466 // Don't attempt to fold a nan. 
1467 if (rhs.isNaN()) 1468 return failure(); 1469 1470 // Get the width of the mantissa. We don't want to hack on conversions that 1471 // might lose information from the integer, e.g. "i64 -> float" 1472 FloatType floatTy = op.getRhs().getType().cast<FloatType>(); 1473 int mantissaWidth = floatTy.getFPMantissaWidth(); 1474 if (mantissaWidth <= 0) 1475 return failure(); 1476 1477 bool isUnsigned; 1478 Value intVal; 1479 1480 if (auto si = op.getLhs().getDefiningOp<SIToFPOp>()) { 1481 isUnsigned = false; 1482 intVal = si.getIn(); 1483 } else if (auto ui = op.getLhs().getDefiningOp<UIToFPOp>()) { 1484 isUnsigned = true; 1485 intVal = ui.getIn(); 1486 } else { 1487 return failure(); 1488 } 1489 1490 // Check to see that the input is converted from an integer type that is 1491 // small enough that preserves all bits. 1492 auto intTy = intVal.getType().cast<IntegerType>(); 1493 auto intWidth = intTy.getWidth(); 1494 1495 // Number of bits representing values, as opposed to the sign 1496 auto valueBits = isUnsigned ? intWidth : (intWidth - 1); 1497 1498 // Following test does NOT adjust intWidth downwards for signed inputs, 1499 // because the most negative value still requires all the mantissa bits 1500 // to distinguish it from one less than that value. 1501 if ((int)intWidth > mantissaWidth) { 1502 // Conversion would lose accuracy. Check if loss can impact comparison. 1503 int exponent = ilogb(rhs); 1504 if (exponent == APFloat::IEK_Inf) { 1505 int maxExponent = ilogb(APFloat::getLargest(rhs.getSemantics())); 1506 if (maxExponent < (int)valueBits) { 1507 // Conversion could create infinity. 1508 return failure(); 1509 } 1510 } else { 1511 // Note that if rhs is zero or NaN, then Exp is negative 1512 // and first condition is trivially false. 1513 if (mantissaWidth <= exponent && exponent <= (int)valueBits) { 1514 // Conversion could affect comparison. 
1515 return failure(); 1516 } 1517 } 1518 } 1519 1520 // Convert to equivalent cmpi predicate 1521 CmpIPredicate pred; 1522 switch (op.getPredicate()) { 1523 case CmpFPredicate::ORD: 1524 // Int to fp conversion doesn't create a nan (ord checks neither is a nan) 1525 rewriter.replaceOpWithNewOp<ConstantIntOp>(op, /*value=*/true, 1526 /*width=*/1); 1527 return success(); 1528 case CmpFPredicate::UNO: 1529 // Int to fp conversion doesn't create a nan (uno checks either is a nan) 1530 rewriter.replaceOpWithNewOp<ConstantIntOp>(op, /*value=*/false, 1531 /*width=*/1); 1532 return success(); 1533 default: 1534 pred = convertToIntegerPredicate(op.getPredicate(), isUnsigned); 1535 break; 1536 } 1537 1538 if (!isUnsigned) { 1539 // If the rhs value is > SignedMax, fold the comparison. This handles 1540 // +INF and large values. 1541 APFloat signedMax(rhs.getSemantics()); 1542 signedMax.convertFromAPInt(APInt::getSignedMaxValue(intWidth), true, 1543 APFloat::rmNearestTiesToEven); 1544 if (signedMax < rhs) { // smax < 13123.0 1545 if (pred == CmpIPredicate::ne || pred == CmpIPredicate::slt || 1546 pred == CmpIPredicate::sle) 1547 rewriter.replaceOpWithNewOp<ConstantIntOp>(op, /*value=*/true, 1548 /*width=*/1); 1549 else 1550 rewriter.replaceOpWithNewOp<ConstantIntOp>(op, /*value=*/false, 1551 /*width=*/1); 1552 return success(); 1553 } 1554 } else { 1555 // If the rhs value is > UnsignedMax, fold the comparison. This handles 1556 // +INF and large values. 
1557 APFloat unsignedMax(rhs.getSemantics()); 1558 unsignedMax.convertFromAPInt(APInt::getMaxValue(intWidth), false, 1559 APFloat::rmNearestTiesToEven); 1560 if (unsignedMax < rhs) { // umax < 13123.0 1561 if (pred == CmpIPredicate::ne || pred == CmpIPredicate::ult || 1562 pred == CmpIPredicate::ule) 1563 rewriter.replaceOpWithNewOp<ConstantIntOp>(op, /*value=*/true, 1564 /*width=*/1); 1565 else 1566 rewriter.replaceOpWithNewOp<ConstantIntOp>(op, /*value=*/false, 1567 /*width=*/1); 1568 return success(); 1569 } 1570 } 1571 1572 if (!isUnsigned) { 1573 // See if the rhs value is < SignedMin. 1574 APFloat signedMin(rhs.getSemantics()); 1575 signedMin.convertFromAPInt(APInt::getSignedMinValue(intWidth), true, 1576 APFloat::rmNearestTiesToEven); 1577 if (signedMin > rhs) { // smin > 12312.0 1578 if (pred == CmpIPredicate::ne || pred == CmpIPredicate::sgt || 1579 pred == CmpIPredicate::sge) 1580 rewriter.replaceOpWithNewOp<ConstantIntOp>(op, /*value=*/true, 1581 /*width=*/1); 1582 else 1583 rewriter.replaceOpWithNewOp<ConstantIntOp>(op, /*value=*/false, 1584 /*width=*/1); 1585 return success(); 1586 } 1587 } else { 1588 // See if the rhs value is < UnsignedMin. 1589 APFloat unsignedMin(rhs.getSemantics()); 1590 unsignedMin.convertFromAPInt(APInt::getMinValue(intWidth), false, 1591 APFloat::rmNearestTiesToEven); 1592 if (unsignedMin > rhs) { // umin > 12312.0 1593 if (pred == CmpIPredicate::ne || pred == CmpIPredicate::ugt || 1594 pred == CmpIPredicate::uge) 1595 rewriter.replaceOpWithNewOp<ConstantIntOp>(op, /*value=*/true, 1596 /*width=*/1); 1597 else 1598 rewriter.replaceOpWithNewOp<ConstantIntOp>(op, /*value=*/false, 1599 /*width=*/1); 1600 return success(); 1601 } 1602 } 1603 1604 // Okay, now we know that the FP constant fits in the range [SMIN, SMAX] or 1605 // [0, UMAX], but it may still be fractional. See if it is fractional by 1606 // casting the FP value to the integer value and back, checking for 1607 // equality. 
Don't do this for zero, because -0.0 is not fractional. 1608 bool ignored; 1609 APSInt rhsInt(intWidth, isUnsigned); 1610 if (APFloat::opInvalidOp == 1611 rhs.convertToInteger(rhsInt, APFloat::rmTowardZero, &ignored)) { 1612 // Undefined behavior invoked - the destination type can't represent 1613 // the input constant. 1614 return failure(); 1615 } 1616 1617 if (!rhs.isZero()) { 1618 APFloat apf(floatTy.getFloatSemantics(), 1619 APInt::getZero(floatTy.getWidth())); 1620 apf.convertFromAPInt(rhsInt, !isUnsigned, APFloat::rmNearestTiesToEven); 1621 1622 bool equal = apf == rhs; 1623 if (!equal) { 1624 // If we had a comparison against a fractional value, we have to adjust 1625 // the compare predicate and sometimes the value. rhsInt is rounded 1626 // towards zero at this point. 1627 switch (pred) { 1628 case CmpIPredicate::ne: // (float)int != 4.4 --> true 1629 rewriter.replaceOpWithNewOp<ConstantIntOp>(op, /*value=*/true, 1630 /*width=*/1); 1631 return success(); 1632 case CmpIPredicate::eq: // (float)int == 4.4 --> false 1633 rewriter.replaceOpWithNewOp<ConstantIntOp>(op, /*value=*/false, 1634 /*width=*/1); 1635 return success(); 1636 case CmpIPredicate::ule: 1637 // (float)int <= 4.4 --> int <= 4 1638 // (float)int <= -4.4 --> false 1639 if (rhs.isNegative()) { 1640 rewriter.replaceOpWithNewOp<ConstantIntOp>(op, /*value=*/false, 1641 /*width=*/1); 1642 return success(); 1643 } 1644 break; 1645 case CmpIPredicate::sle: 1646 // (float)int <= 4.4 --> int <= 4 1647 // (float)int <= -4.4 --> int < -4 1648 if (rhs.isNegative()) 1649 pred = CmpIPredicate::slt; 1650 break; 1651 case CmpIPredicate::ult: 1652 // (float)int < -4.4 --> false 1653 // (float)int < 4.4 --> int <= 4 1654 if (rhs.isNegative()) { 1655 rewriter.replaceOpWithNewOp<ConstantIntOp>(op, /*value=*/false, 1656 /*width=*/1); 1657 return success(); 1658 } 1659 pred = CmpIPredicate::ule; 1660 break; 1661 case CmpIPredicate::slt: 1662 // (float)int < -4.4 --> int < -4 1663 // (float)int < 4.4 --> int <= 4 
1664 if (!rhs.isNegative()) 1665 pred = CmpIPredicate::sle; 1666 break; 1667 case CmpIPredicate::ugt: 1668 // (float)int > 4.4 --> int > 4 1669 // (float)int > -4.4 --> true 1670 if (rhs.isNegative()) { 1671 rewriter.replaceOpWithNewOp<ConstantIntOp>(op, /*value=*/true, 1672 /*width=*/1); 1673 return success(); 1674 } 1675 break; 1676 case CmpIPredicate::sgt: 1677 // (float)int > 4.4 --> int > 4 1678 // (float)int > -4.4 --> int >= -4 1679 if (rhs.isNegative()) 1680 pred = CmpIPredicate::sge; 1681 break; 1682 case CmpIPredicate::uge: 1683 // (float)int >= -4.4 --> true 1684 // (float)int >= 4.4 --> int > 4 1685 if (rhs.isNegative()) { 1686 rewriter.replaceOpWithNewOp<ConstantIntOp>(op, /*value=*/true, 1687 /*width=*/1); 1688 return success(); 1689 } 1690 pred = CmpIPredicate::ugt; 1691 break; 1692 case CmpIPredicate::sge: 1693 // (float)int >= -4.4 --> int >= -4 1694 // (float)int >= 4.4 --> int > 4 1695 if (!rhs.isNegative()) 1696 pred = CmpIPredicate::sgt; 1697 break; 1698 } 1699 } 1700 } 1701 1702 // Lower this FP comparison into an appropriate integer version of the 1703 // comparison. 
    // Lower this FP comparison into an appropriate integer version of the
    // comparison.
    rewriter.replaceOpWithNewOp<CmpIOp>(
        op, pred, intVal,
        rewriter.create<ConstantOp>(
            op.getLoc(), intVal.getType(),
            rewriter.getIntegerAttr(intVal.getType(), rhsInt)));
    return success();
  }
};

/// Register CmpFOp canonicalizations: rewrite comparisons of an
/// integer-to-float conversion against an FP constant into pure integer
/// comparisons where possible (see CmpFIntToFPConst above).
void arith::CmpFOp::getCanonicalizationPatterns(RewritePatternSet &patterns,
                                                MLIRContext *context) {
  patterns.insert<CmpFIntToFPConst>(context);
}

//===----------------------------------------------------------------------===//
// SelectOp
//===----------------------------------------------------------------------===//

// Transforms a select of a boolean to arithmetic operations
//
//  arith.select %arg, %x, %y : i1
//
// becomes
//
//  and(%arg, %x) or and(!%arg, %y)
struct SelectI1Simplify : public OpRewritePattern<arith::SelectOp> {
  using OpRewritePattern<arith::SelectOp>::OpRewritePattern;

  LogicalResult matchAndRewrite(arith::SelectOp op,
                                PatternRewriter &rewriter) const override {
    // Only applies when the select produces a plain i1.
    if (!op.getType().isInteger(1))
      return failure();

    // NOTE(review): despite its name, `falseConstant` holds the i1 constant
    // `true`; xor-ing the condition with `true` computes its negation.
    Value falseConstant =
        rewriter.create<arith::ConstantIntOp>(op.getLoc(), true, 1);
    Value notCondition = rewriter.create<arith::XOrIOp>(
        op.getLoc(), op.getCondition(), falseConstant);

    // and(%arg, %x) keeps the true branch; and(!%arg, %y) keeps the false
    // branch; or-ing them yields the selected value.
    Value trueVal = rewriter.create<arith::AndIOp>(
        op.getLoc(), op.getCondition(), op.getTrueValue());
    Value falseVal = rewriter.create<arith::AndIOp>(op.getLoc(), notCondition,
                                                    op.getFalseValue());
    rewriter.replaceOpWithNewOp<arith::OrIOp>(op, trueVal, falseVal);
    return success();
  }
};

// select %arg, %c1, %c0 => extui %arg
struct SelectToExtUI : public OpRewritePattern<arith::SelectOp> {
  using OpRewritePattern<arith::SelectOp>::OpRewritePattern;

  LogicalResult matchAndRewrite(arith::SelectOp op,
                                PatternRewriter &rewriter) const override {
    // Cannot extui i1 to i1, or i1 to f32
    if (!op.getType().isa<IntegerType>() || op.getType().isInteger(1))
      return failure();

    // select %x, c1, %c0 => extui %arg
    if (matchPattern(op.getTrueValue(), m_One()))
      if (matchPattern(op.getFalseValue(), m_Zero())) {
        rewriter.replaceOpWithNewOp<arith::ExtUIOp>(op, op.getType(),
                                                    op.getCondition());
        return success();
      }

    // select %x, c0, %c1 => extui (xor %arg, true)
    if (matchPattern(op.getTrueValue(), m_Zero()))
      if (matchPattern(op.getFalseValue(), m_One())) {
        rewriter.replaceOpWithNewOp<arith::ExtUIOp>(
            op, op.getType(),
            rewriter.create<arith::XOrIOp>(
                op.getLoc(), op.getCondition(),
                rewriter.create<arith::ConstantIntOp>(
                    op.getLoc(), 1, op.getCondition().getType())));
        return success();
      }

    return failure();
  }
};

/// Register SelectOp canonicalizations.
void arith::SelectOp::getCanonicalizationPatterns(RewritePatternSet &results,
                                                  MLIRContext *context) {
  results.add<SelectI1Simplify, SelectToExtUI>(context);
}

OpFoldResult arith::SelectOp::fold(ArrayRef<Attribute> operands) {
  Value trueVal = getTrueValue();
  Value falseVal = getFalseValue();
  // select %c, %x, %x => %x
  if (trueVal == falseVal)
    return trueVal;

  Value condition = getCondition();

  // select true, %0, %1 => %0
  if (matchPattern(condition, m_One()))
    return trueVal;

  // select false, %0, %1 => %1
  if (matchPattern(condition, m_Zero()))
    return falseVal;

  // select %x, true, false => %x
  if (getType().isInteger(1))
    if (matchPattern(getTrueValue(), m_One()))
      if (matchPattern(getFalseValue(), m_Zero()))
        return condition;

  // If the condition is an eq/ne comparison of the two selected values, the
  // select is redundant and one operand can be returned directly.
  if (auto cmp = dyn_cast_or_null<arith::CmpIOp>(condition.getDefiningOp())) {
    auto pred = cmp.getPredicate();
    if (pred == arith::CmpIPredicate::eq || pred == arith::CmpIPredicate::ne) {
      auto cmpLhs = cmp.getLhs();
      auto cmpRhs = cmp.getRhs();

      // %0 = arith.cmpi eq, %arg0, %arg1
      // %1 = arith.select %0, %arg0, %arg1 => %arg1

      // %0 = arith.cmpi ne, %arg0, %arg1
      // %1 = arith.select %0, %arg0, %arg1 => %arg0

      // Either operand order of the compare matches the select operands:
      // for `ne` the equal-case (falseVal) is only taken when the values are
      // equal anyway, so trueVal is always the result, and vice versa for
      // `eq`.
      if ((cmpLhs == trueVal && cmpRhs == falseVal) ||
          (cmpRhs == trueVal && cmpLhs == falseVal))
        return pred == arith::CmpIPredicate::ne ? trueVal : falseVal;
    }
  }
  return nullptr;
}

/// Custom parser for `arith.select`. Syntax is
///   `select` cond, trueVal, falseVal (attr-dict)? `:` (condType `,`)? type
/// The optional leading type is the shaped (vector/tensor) condition type for
/// masked selects; otherwise the condition is implicitly i1.
ParseResult SelectOp::parse(OpAsmParser &parser, OperationState &result) {
  Type conditionType, resultType;
  SmallVector<OpAsmParser::UnresolvedOperand, 3> operands;
  if (parser.parseOperandList(operands, /*requiredOperandCount=*/3) ||
      parser.parseOptionalAttrDict(result.attributes) ||
      parser.parseColonType(resultType))
    return failure();

  // Check for the explicit condition type if this is a masked tensor or
  // vector. The first parsed type was actually the condition type in that
  // case; the result type follows the comma.
  if (succeeded(parser.parseOptionalComma())) {
    conditionType = resultType;
    if (parser.parseType(resultType))
      return failure();
  } else {
    conditionType = parser.getBuilder().getI1Type();
  }

  result.addTypes(resultType);
  return parser.resolveOperands(operands,
                                {conditionType, resultType, resultType},
                                parser.getNameLoc(), result.operands);
}

/// Custom printer mirroring the parser above: the condition type is printed
/// only when it is shaped (a mask), followed by the result type.
void arith::SelectOp::print(OpAsmPrinter &p) {
  p << " " << getOperands();
  p.printOptionalAttrDict((*this)->getAttrs());
  p << " : ";
  if (ShapedType condType = getCondition().getType().dyn_cast<ShapedType>())
    p << condType << ", ";
  p << getType();
}

LogicalResult arith::SelectOp::verify() {
  Type conditionType = getCondition().getType();
  // A scalar signless i1 condition is always valid.
  if (conditionType.isSignlessInteger(1))
    return success();

  // If the result type is a vector or tensor, the type can be a mask with the
  // same elements.
1871 Type resultType = getType(); 1872 if (!resultType.isa<TensorType, VectorType>()) 1873 return emitOpError() << "expected condition to be a signless i1, but got " 1874 << conditionType; 1875 Type shapedConditionType = getI1SameShape(resultType); 1876 if (conditionType != shapedConditionType) { 1877 return emitOpError() << "expected condition type to have the same shape " 1878 "as the result type, expected " 1879 << shapedConditionType << ", but got " 1880 << conditionType; 1881 } 1882 return success(); 1883 } 1884 //===----------------------------------------------------------------------===// 1885 // ShLIOp 1886 //===----------------------------------------------------------------------===// 1887 1888 OpFoldResult arith::ShLIOp::fold(ArrayRef<Attribute> operands) { 1889 // Don't fold if shifting more than the bit width. 1890 bool bounded = false; 1891 auto result = constFoldBinaryOp<IntegerAttr>( 1892 operands, [&](const APInt &a, const APInt &b) { 1893 bounded = b.ule(b.getBitWidth()); 1894 return a.shl(b); 1895 }); 1896 return bounded ? result : Attribute(); 1897 } 1898 1899 //===----------------------------------------------------------------------===// 1900 // ShRUIOp 1901 //===----------------------------------------------------------------------===// 1902 1903 OpFoldResult arith::ShRUIOp::fold(ArrayRef<Attribute> operands) { 1904 // Don't fold if shifting more than the bit width. 1905 bool bounded = false; 1906 auto result = constFoldBinaryOp<IntegerAttr>( 1907 operands, [&](const APInt &a, const APInt &b) { 1908 bounded = b.ule(b.getBitWidth()); 1909 return a.lshr(b); 1910 }); 1911 return bounded ? result : Attribute(); 1912 } 1913 1914 //===----------------------------------------------------------------------===// 1915 // ShRSIOp 1916 //===----------------------------------------------------------------------===// 1917 1918 OpFoldResult arith::ShRSIOp::fold(ArrayRef<Attribute> operands) { 1919 // Don't fold if shifting more than the bit width. 
1920 bool bounded = false; 1921 auto result = constFoldBinaryOp<IntegerAttr>( 1922 operands, [&](const APInt &a, const APInt &b) { 1923 bounded = b.ule(b.getBitWidth()); 1924 return a.ashr(b); 1925 }); 1926 return bounded ? result : Attribute(); 1927 } 1928 1929 //===----------------------------------------------------------------------===// 1930 // Atomic Enum 1931 //===----------------------------------------------------------------------===// 1932 1933 /// Returns the identity value attribute associated with an AtomicRMWKind op. 1934 Attribute mlir::arith::getIdentityValueAttr(AtomicRMWKind kind, Type resultType, 1935 OpBuilder &builder, Location loc) { 1936 switch (kind) { 1937 case AtomicRMWKind::maxf: 1938 return builder.getFloatAttr( 1939 resultType, 1940 APFloat::getInf(resultType.cast<FloatType>().getFloatSemantics(), 1941 /*Negative=*/true)); 1942 case AtomicRMWKind::addf: 1943 case AtomicRMWKind::addi: 1944 case AtomicRMWKind::maxu: 1945 case AtomicRMWKind::ori: 1946 return builder.getZeroAttr(resultType); 1947 case AtomicRMWKind::andi: 1948 return builder.getIntegerAttr( 1949 resultType, 1950 APInt::getAllOnes(resultType.cast<IntegerType>().getWidth())); 1951 case AtomicRMWKind::maxs: 1952 return builder.getIntegerAttr( 1953 resultType, 1954 APInt::getSignedMinValue(resultType.cast<IntegerType>().getWidth())); 1955 case AtomicRMWKind::minf: 1956 return builder.getFloatAttr( 1957 resultType, 1958 APFloat::getInf(resultType.cast<FloatType>().getFloatSemantics(), 1959 /*Negative=*/false)); 1960 case AtomicRMWKind::mins: 1961 return builder.getIntegerAttr( 1962 resultType, 1963 APInt::getSignedMaxValue(resultType.cast<IntegerType>().getWidth())); 1964 case AtomicRMWKind::minu: 1965 return builder.getIntegerAttr( 1966 resultType, 1967 APInt::getMaxValue(resultType.cast<IntegerType>().getWidth())); 1968 case AtomicRMWKind::muli: 1969 return builder.getIntegerAttr(resultType, 1); 1970 case AtomicRMWKind::mulf: 1971 return builder.getFloatAttr(resultType, 1); 
  // TODO: Add remaining reduction operations.
  default:
    // Unsupported kinds report an error (when a valid location is given) and
    // produce a null attribute the caller must handle.
    (void)emitOptionalError(loc, "Reduction operation type not supported");
    break;
  }
  return nullptr;
}

/// Returns the identity value associated with an AtomicRMWKind op.
/// Materializes the identity attribute from getIdentityValueAttr as an
/// arith.constant at `loc`.
Value mlir::arith::getIdentityValue(AtomicRMWKind op, Type resultType,
                                    OpBuilder &builder, Location loc) {
  Attribute attr = getIdentityValueAttr(op, resultType, builder, loc);
  return builder.create<arith::ConstantOp>(loc, attr);
}

/// Return the value obtained by applying the reduction operation kind
/// associated with a binary AtomicRMWKind op to `lhs` and `rhs`.
Value mlir::arith::getReductionOp(AtomicRMWKind op, OpBuilder &builder,
                                  Location loc, Value lhs, Value rhs) {
  // Each supported kind maps one-to-one onto an arith binary op.
  switch (op) {
  case AtomicRMWKind::addf:
    return builder.create<arith::AddFOp>(loc, lhs, rhs);
  case AtomicRMWKind::addi:
    return builder.create<arith::AddIOp>(loc, lhs, rhs);
  case AtomicRMWKind::mulf:
    return builder.create<arith::MulFOp>(loc, lhs, rhs);
  case AtomicRMWKind::muli:
    return builder.create<arith::MulIOp>(loc, lhs, rhs);
  case AtomicRMWKind::maxf:
    return builder.create<arith::MaxFOp>(loc, lhs, rhs);
  case AtomicRMWKind::minf:
    return builder.create<arith::MinFOp>(loc, lhs, rhs);
  case AtomicRMWKind::maxs:
    return builder.create<arith::MaxSIOp>(loc, lhs, rhs);
  case AtomicRMWKind::mins:
    return builder.create<arith::MinSIOp>(loc, lhs, rhs);
  case AtomicRMWKind::maxu:
    return builder.create<arith::MaxUIOp>(loc, lhs, rhs);
  case AtomicRMWKind::minu:
    return builder.create<arith::MinUIOp>(loc, lhs, rhs);
  case AtomicRMWKind::ori:
    return builder.create<arith::OrIOp>(loc, lhs, rhs);
  case AtomicRMWKind::andi:
    return builder.create<arith::AndIOp>(loc, lhs, rhs);
  // TODO: Add remaining reduction operations.
  default:
    // Unsupported kinds report an error (when a valid location is given) and
    // return a null Value the caller must handle.
    (void)emitOptionalError(loc, "Reduction operation type not supported");
    break;
  }
  return nullptr;
}

//===----------------------------------------------------------------------===//
// TableGen'd op method definitions
//===----------------------------------------------------------------------===//

#define GET_OP_CLASSES
#include "mlir/Dialect/Arithmetic/IR/ArithmeticOps.cpp.inc"

//===----------------------------------------------------------------------===//
// TableGen'd enum attribute definitions
//===----------------------------------------------------------------------===//

#include "mlir/Dialect/Arithmetic/IR/ArithmeticOpsEnums.cpp.inc"