1 //===- ArithmeticOps.cpp - MLIR Arithmetic dialect ops implementation -----===// 2 // 3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. 4 // See https://llvm.org/LICENSE.txt for license information. 5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception 6 // 7 //===----------------------------------------------------------------------===// 8 9 #include <utility> 10 11 #include "mlir/Dialect/Arithmetic/IR/Arithmetic.h" 12 #include "mlir/Dialect/CommonFolders.h" 13 #include "mlir/IR/Builders.h" 14 #include "mlir/IR/Matchers.h" 15 #include "mlir/IR/OpImplementation.h" 16 #include "mlir/IR/PatternMatch.h" 17 #include "mlir/IR/TypeUtilities.h" 18 19 #include "llvm/ADT/APSInt.h" 20 21 using namespace mlir; 22 using namespace mlir::arith; 23 24 //===----------------------------------------------------------------------===// 25 // Pattern helpers 26 //===----------------------------------------------------------------------===// 27 28 static IntegerAttr addIntegerAttrs(PatternRewriter &builder, Value res, 29 Attribute lhs, Attribute rhs) { 30 return builder.getIntegerAttr(res.getType(), 31 lhs.cast<IntegerAttr>().getInt() + 32 rhs.cast<IntegerAttr>().getInt()); 33 } 34 35 static IntegerAttr subIntegerAttrs(PatternRewriter &builder, Value res, 36 Attribute lhs, Attribute rhs) { 37 return builder.getIntegerAttr(res.getType(), 38 lhs.cast<IntegerAttr>().getInt() - 39 rhs.cast<IntegerAttr>().getInt()); 40 } 41 42 /// Invert an integer comparison predicate. 43 arith::CmpIPredicate arith::invertPredicate(arith::CmpIPredicate pred) { 44 switch (pred) { 45 case arith::CmpIPredicate::eq: 46 return arith::CmpIPredicate::ne; 47 case arith::CmpIPredicate::ne: 48 return arith::CmpIPredicate::eq; 49 case arith::CmpIPredicate::slt: 50 return arith::CmpIPredicate::sge; 51 case arith::CmpIPredicate::sle: 52 return arith::CmpIPredicate::sgt; 53 case arith::CmpIPredicate::sgt: 54 return arith::CmpIPredicate::sle; 55 case arith::CmpIPredicate::sge: 56 return arith::CmpIPredicate::slt; 57 case arith::CmpIPredicate::ult: 58 return arith::CmpIPredicate::uge; 59 case arith::CmpIPredicate::ule: 60 return arith::CmpIPredicate::ugt; 61 case arith::CmpIPredicate::ugt: 62 return arith::CmpIPredicate::ule; 63 case arith::CmpIPredicate::uge: 64 return arith::CmpIPredicate::ult; 65 } 66 llvm_unreachable("unknown cmpi predicate kind"); 67 } 68 69 static arith::CmpIPredicateAttr invertPredicate(arith::CmpIPredicateAttr pred) { 70 return arith::CmpIPredicateAttr::get(pred.getContext(), 71 invertPredicate(pred.getValue())); 72 } 73 74 //===----------------------------------------------------------------------===// 75 // TableGen'd canonicalization patterns 76 //===----------------------------------------------------------------------===// 77 78 namespace { 79 #include "ArithmeticCanonicalization.inc" 80 } // namespace 81 82 //===----------------------------------------------------------------------===// 83 // ConstantOp 84 //===----------------------------------------------------------------------===// 85 86 void arith::ConstantOp::getAsmResultNames( 87 function_ref<void(Value, StringRef)> setNameFn) { 88 auto type = getType(); 89 if (auto intCst = getValue().dyn_cast<IntegerAttr>()) { 90 auto intType = type.dyn_cast<IntegerType>(); 91 92 // Sugar i1 constants with 'true' and 'false'. 93 if (intType && intType.getWidth() == 1) 94 return setNameFn(getResult(), (intCst.getInt() ? "true" : "false")); 95 96 // Otherwise, build a compex name with the value and type. 97 SmallString<32> specialNameBuffer; 98 llvm::raw_svector_ostream specialName(specialNameBuffer); 99 specialName << 'c' << intCst.getInt(); 100 if (intType) 101 specialName << '_' << type; 102 setNameFn(getResult(), specialName.str()); 103 } else { 104 setNameFn(getResult(), "cst"); 105 } 106 } 107 108 /// TODO: disallow arith.constant to return anything other than signless integer 109 /// or float like. 110 static LogicalResult verify(arith::ConstantOp op) { 111 auto type = op.getType(); 112 // The value's type must match the return type. 113 if (op.getValue().getType() != type) { 114 return op.emitOpError() << "value type " << op.getValue().getType() 115 << " must match return type: " << type; 116 } 117 // Integer values must be signless. 118 if (type.isa<IntegerType>() && !type.cast<IntegerType>().isSignless()) 119 return op.emitOpError("integer return type must be signless"); 120 // Any float or elements attribute are acceptable. 121 if (!op.getValue().isa<IntegerAttr, FloatAttr, ElementsAttr>()) { 122 return op.emitOpError( 123 "value must be an integer, float, or elements attribute"); 124 } 125 return success(); 126 } 127 128 bool arith::ConstantOp::isBuildableWith(Attribute value, Type type) { 129 // The value's type must be the same as the provided type. 130 if (value.getType() != type) 131 return false; 132 // Integer values must be signless. 133 if (type.isa<IntegerType>() && !type.cast<IntegerType>().isSignless()) 134 return false; 135 // Integer, float, and element attributes are buildable. 136 return value.isa<IntegerAttr, FloatAttr, ElementsAttr>(); 137 } 138 139 OpFoldResult arith::ConstantOp::fold(ArrayRef<Attribute> operands) { 140 return getValue(); 141 } 142 143 void arith::ConstantIntOp::build(OpBuilder &builder, OperationState &result, 144 int64_t value, unsigned width) { 145 auto type = builder.getIntegerType(width); 146 arith::ConstantOp::build(builder, result, type, 147 builder.getIntegerAttr(type, value)); 148 } 149 150 void arith::ConstantIntOp::build(OpBuilder &builder, OperationState &result, 151 int64_t value, Type type) { 152 assert(type.isSignlessInteger() && 153 "ConstantIntOp can only have signless integer type values"); 154 arith::ConstantOp::build(builder, result, type, 155 builder.getIntegerAttr(type, value)); 156 } 157 158 bool arith::ConstantIntOp::classof(Operation *op) { 159 if (auto constOp = dyn_cast_or_null<arith::ConstantOp>(op)) 160 return constOp.getType().isSignlessInteger(); 161 return false; 162 } 163 164 void arith::ConstantFloatOp::build(OpBuilder &builder, OperationState &result, 165 const APFloat &value, FloatType type) { 166 arith::ConstantOp::build(builder, result, type, 167 builder.getFloatAttr(type, value)); 168 } 169 170 bool arith::ConstantFloatOp::classof(Operation *op) { 171 if (auto constOp = dyn_cast_or_null<arith::ConstantOp>(op)) 172 return constOp.getType().isa<FloatType>(); 173 return false; 174 } 175 176 void arith::ConstantIndexOp::build(OpBuilder &builder, OperationState &result, 177 int64_t value) { 178 arith::ConstantOp::build(builder, result, builder.getIndexType(), 179 builder.getIndexAttr(value)); 180 } 181 182 bool arith::ConstantIndexOp::classof(Operation *op) { 183 if (auto constOp = dyn_cast_or_null<arith::ConstantOp>(op)) 184 return constOp.getType().isIndex(); 185 return false; 186 } 187 188 //===----------------------------------------------------------------------===// 189 // AddIOp 190 //===----------------------------------------------------------------------===// 191 192 OpFoldResult arith::AddIOp::fold(ArrayRef<Attribute> operands) { 193 // addi(x, 0) -> x 194 if (matchPattern(getRhs(), m_Zero())) 195 return getLhs(); 196 197 // add(sub(a, b), b) -> a 198 if (auto sub = getLhs().getDefiningOp<SubIOp>()) 199 if (getRhs() == sub.getRhs()) 200 return sub.getLhs(); 201 202 // add(b, sub(a, b)) -> a 203 if (auto sub = getRhs().getDefiningOp<SubIOp>()) 204 if (getLhs() == sub.getRhs()) 205 return sub.getLhs(); 206 207 return constFoldBinaryOp<IntegerAttr>( 208 operands, [](APInt a, const APInt &b) { return std::move(a) + b; }); 209 } 210 211 void arith::AddIOp::getCanonicalizationPatterns( 212 OwningRewritePatternList &patterns, MLIRContext *context) { 213 patterns.insert<AddIAddConstant, AddISubConstantRHS, AddISubConstantLHS>( 214 context); 215 } 216 217 //===----------------------------------------------------------------------===// 218 // SubIOp 219 //===----------------------------------------------------------------------===// 220 221 OpFoldResult arith::SubIOp::fold(ArrayRef<Attribute> operands) { 222 // subi(x,x) -> 0 223 if (getOperand(0) == getOperand(1)) 224 return Builder(getContext()).getZeroAttr(getType()); 225 // subi(x,0) -> x 226 if (matchPattern(getRhs(), m_Zero())) 227 return getLhs(); 228 229 return constFoldBinaryOp<IntegerAttr>( 230 operands, [](APInt a, const APInt &b) { return std::move(a) - b; }); 231 } 232 233 void arith::SubIOp::getCanonicalizationPatterns( 234 OwningRewritePatternList &patterns, MLIRContext *context) { 235 patterns.insert<SubIRHSAddConstant, SubILHSAddConstant, SubIRHSSubConstantRHS, 236 SubIRHSSubConstantLHS, SubILHSSubConstantRHS, 237 SubILHSSubConstantLHS>(context); 238 } 239 240 //===----------------------------------------------------------------------===// 241 // MulIOp 242 //===----------------------------------------------------------------------===// 243 244 OpFoldResult arith::MulIOp::fold(ArrayRef<Attribute> operands) { 245 // muli(x, 0) -> 0 246 if (matchPattern(getRhs(), m_Zero())) 247 return getRhs(); 248 // muli(x, 1) -> x 249 if (matchPattern(getRhs(), m_One())) 250 return getOperand(0); 251 // TODO: Handle the overflow case. 252 253 // default folder 254 return constFoldBinaryOp<IntegerAttr>( 255 operands, [](const APInt &a, const APInt &b) { return a * b; }); 256 } 257 258 //===----------------------------------------------------------------------===// 259 // DivUIOp 260 //===----------------------------------------------------------------------===// 261 262 OpFoldResult arith::DivUIOp::fold(ArrayRef<Attribute> operands) { 263 // Don't fold if it would require a division by zero. 264 bool div0 = false; 265 auto result = 266 constFoldBinaryOp<IntegerAttr>(operands, [&](APInt a, const APInt &b) { 267 if (div0 || !b) { 268 div0 = true; 269 return a; 270 } 271 return a.udiv(b); 272 }); 273 274 // Fold out division by one. Assumes all tensors of all ones are splats. 275 if (auto rhs = operands[1].dyn_cast_or_null<IntegerAttr>()) { 276 if (rhs.getValue() == 1) 277 return getLhs(); 278 } else if (auto rhs = operands[1].dyn_cast_or_null<SplatElementsAttr>()) { 279 if (rhs.getSplatValue<IntegerAttr>().getValue() == 1) 280 return getLhs(); 281 } 282 283 return div0 ? Attribute() : result; 284 } 285 286 //===----------------------------------------------------------------------===// 287 // DivSIOp 288 //===----------------------------------------------------------------------===// 289 290 OpFoldResult arith::DivSIOp::fold(ArrayRef<Attribute> operands) { 291 // Don't fold if it would overflow or if it requires a division by zero. 292 bool overflowOrDiv0 = false; 293 auto result = 294 constFoldBinaryOp<IntegerAttr>(operands, [&](APInt a, const APInt &b) { 295 if (overflowOrDiv0 || !b) { 296 overflowOrDiv0 = true; 297 return a; 298 } 299 return a.sdiv_ov(b, overflowOrDiv0); 300 }); 301 302 // Fold out division by one. Assumes all tensors of all ones are splats. 303 if (auto rhs = operands[1].dyn_cast_or_null<IntegerAttr>()) { 304 if (rhs.getValue() == 1) 305 return getLhs(); 306 } else if (auto rhs = operands[1].dyn_cast_or_null<SplatElementsAttr>()) { 307 if (rhs.getSplatValue<IntegerAttr>().getValue() == 1) 308 return getLhs(); 309 } 310 311 return overflowOrDiv0 ? Attribute() : result; 312 } 313 314 //===----------------------------------------------------------------------===// 315 // Ceil and floor division folding helpers 316 //===----------------------------------------------------------------------===// 317 318 static APInt signedCeilNonnegInputs(const APInt &a, const APInt &b, 319 bool &overflow) { 320 // Returns (a-1)/b + 1 321 APInt one(a.getBitWidth(), 1, true); // Signed value 1. 322 APInt val = a.ssub_ov(one, overflow).sdiv_ov(b, overflow); 323 return val.sadd_ov(one, overflow); 324 } 325 326 //===----------------------------------------------------------------------===// 327 // CeilDivUIOp 328 //===----------------------------------------------------------------------===// 329 330 OpFoldResult arith::CeilDivUIOp::fold(ArrayRef<Attribute> operands) { 331 bool overflowOrDiv0 = false; 332 auto result = 333 constFoldBinaryOp<IntegerAttr>(operands, [&](APInt a, const APInt &b) { 334 if (overflowOrDiv0 || !b) { 335 overflowOrDiv0 = true; 336 return a; 337 } 338 APInt quotient = a.udiv(b); 339 if (!a.urem(b)) 340 return quotient; 341 APInt one(a.getBitWidth(), 1, true); 342 return quotient.uadd_ov(one, overflowOrDiv0); 343 }); 344 // Fold out ceil division by one. Assumes all tensors of all ones are 345 // splats. 346 if (auto rhs = operands[1].dyn_cast_or_null<IntegerAttr>()) { 347 if (rhs.getValue() == 1) 348 return getLhs(); 349 } else if (auto rhs = operands[1].dyn_cast_or_null<SplatElementsAttr>()) { 350 if (rhs.getSplatValue<IntegerAttr>().getValue() == 1) 351 return getLhs(); 352 } 353 354 return overflowOrDiv0 ? Attribute() : result; 355 } 356 357 //===----------------------------------------------------------------------===// 358 // CeilDivSIOp 359 //===----------------------------------------------------------------------===// 360 361 OpFoldResult arith::CeilDivSIOp::fold(ArrayRef<Attribute> operands) { 362 // Don't fold if it would overflow or if it requires a division by zero. 363 bool overflowOrDiv0 = false; 364 auto result = 365 constFoldBinaryOp<IntegerAttr>(operands, [&](APInt a, const APInt &b) { 366 if (overflowOrDiv0 || !b) { 367 overflowOrDiv0 = true; 368 return a; 369 } 370 if (!a) 371 return a; 372 // After this point we know that neither a or b are zero. 373 unsigned bits = a.getBitWidth(); 374 APInt zero = APInt::getZero(bits); 375 bool aGtZero = a.sgt(zero); 376 bool bGtZero = b.sgt(zero); 377 if (aGtZero && bGtZero) { 378 // Both positive, return ceil(a, b). 379 return signedCeilNonnegInputs(a, b, overflowOrDiv0); 380 } 381 if (!aGtZero && !bGtZero) { 382 // Both negative, return ceil(-a, -b). 383 APInt posA = zero.ssub_ov(a, overflowOrDiv0); 384 APInt posB = zero.ssub_ov(b, overflowOrDiv0); 385 return signedCeilNonnegInputs(posA, posB, overflowOrDiv0); 386 } 387 if (!aGtZero && bGtZero) { 388 // A is negative, b is positive, return - ( -a / b). 389 APInt posA = zero.ssub_ov(a, overflowOrDiv0); 390 APInt div = posA.sdiv_ov(b, overflowOrDiv0); 391 return zero.ssub_ov(div, overflowOrDiv0); 392 } 393 // A is positive, b is negative, return - (a / -b). 394 APInt posB = zero.ssub_ov(b, overflowOrDiv0); 395 APInt div = a.sdiv_ov(posB, overflowOrDiv0); 396 return zero.ssub_ov(div, overflowOrDiv0); 397 }); 398 399 // Fold out ceil division by one. Assumes all tensors of all ones are 400 // splats. 401 if (auto rhs = operands[1].dyn_cast_or_null<IntegerAttr>()) { 402 if (rhs.getValue() == 1) 403 return getLhs(); 404 } else if (auto rhs = operands[1].dyn_cast_or_null<SplatElementsAttr>()) { 405 if (rhs.getSplatValue<IntegerAttr>().getValue() == 1) 406 return getLhs(); 407 } 408 409 return overflowOrDiv0 ? Attribute() : result; 410 } 411 412 //===----------------------------------------------------------------------===// 413 // FloorDivSIOp 414 //===----------------------------------------------------------------------===// 415 416 OpFoldResult arith::FloorDivSIOp::fold(ArrayRef<Attribute> operands) { 417 // Don't fold if it would overflow or if it requires a division by zero. 418 bool overflowOrDiv0 = false; 419 auto result = 420 constFoldBinaryOp<IntegerAttr>(operands, [&](APInt a, const APInt &b) { 421 if (overflowOrDiv0 || !b) { 422 overflowOrDiv0 = true; 423 return a; 424 } 425 if (!a) 426 return a; 427 // After this point we know that neither a or b are zero. 428 unsigned bits = a.getBitWidth(); 429 APInt zero = APInt::getZero(bits); 430 bool aGtZero = a.sgt(zero); 431 bool bGtZero = b.sgt(zero); 432 if (aGtZero && bGtZero) { 433 // Both positive, return a / b. 434 return a.sdiv_ov(b, overflowOrDiv0); 435 } 436 if (!aGtZero && !bGtZero) { 437 // Both negative, return -a / -b. 438 APInt posA = zero.ssub_ov(a, overflowOrDiv0); 439 APInt posB = zero.ssub_ov(b, overflowOrDiv0); 440 return posA.sdiv_ov(posB, overflowOrDiv0); 441 } 442 if (!aGtZero && bGtZero) { 443 // A is negative, b is positive, return - ceil(-a, b). 444 APInt posA = zero.ssub_ov(a, overflowOrDiv0); 445 APInt ceil = signedCeilNonnegInputs(posA, b, overflowOrDiv0); 446 return zero.ssub_ov(ceil, overflowOrDiv0); 447 } 448 // A is positive, b is negative, return - ceil(a, -b). 449 APInt posB = zero.ssub_ov(b, overflowOrDiv0); 450 APInt ceil = signedCeilNonnegInputs(a, posB, overflowOrDiv0); 451 return zero.ssub_ov(ceil, overflowOrDiv0); 452 }); 453 454 // Fold out floor division by one. Assumes all tensors of all ones are 455 // splats. 456 if (auto rhs = operands[1].dyn_cast_or_null<IntegerAttr>()) { 457 if (rhs.getValue() == 1) 458 return getLhs(); 459 } else if (auto rhs = operands[1].dyn_cast_or_null<SplatElementsAttr>()) { 460 if (rhs.getSplatValue<IntegerAttr>().getValue() == 1) 461 return getLhs(); 462 } 463 464 return overflowOrDiv0 ? Attribute() : result; 465 } 466 467 //===----------------------------------------------------------------------===// 468 // RemUIOp 469 //===----------------------------------------------------------------------===// 470 471 OpFoldResult arith::RemUIOp::fold(ArrayRef<Attribute> operands) { 472 auto rhs = operands.back().dyn_cast_or_null<IntegerAttr>(); 473 if (!rhs) 474 return {}; 475 auto rhsValue = rhs.getValue(); 476 477 // x % 1 = 0 478 if (rhsValue.isOneValue()) 479 return IntegerAttr::get(rhs.getType(), APInt(rhsValue.getBitWidth(), 0)); 480 481 // Don't fold if it requires division by zero. 482 if (rhsValue.isNullValue()) 483 return {}; 484 485 auto lhs = operands.front().dyn_cast_or_null<IntegerAttr>(); 486 if (!lhs) 487 return {}; 488 return IntegerAttr::get(lhs.getType(), lhs.getValue().urem(rhsValue)); 489 } 490 491 //===----------------------------------------------------------------------===// 492 // RemSIOp 493 //===----------------------------------------------------------------------===// 494 495 OpFoldResult arith::RemSIOp::fold(ArrayRef<Attribute> operands) { 496 auto rhs = operands.back().dyn_cast_or_null<IntegerAttr>(); 497 if (!rhs) 498 return {}; 499 auto rhsValue = rhs.getValue(); 500 501 // x % 1 = 0 502 if (rhsValue.isOneValue()) 503 return IntegerAttr::get(rhs.getType(), APInt(rhsValue.getBitWidth(), 0)); 504 505 // Don't fold if it requires division by zero. 506 if (rhsValue.isNullValue()) 507 return {}; 508 509 auto lhs = operands.front().dyn_cast_or_null<IntegerAttr>(); 510 if (!lhs) 511 return {}; 512 return IntegerAttr::get(lhs.getType(), lhs.getValue().srem(rhsValue)); 513 } 514 515 //===----------------------------------------------------------------------===// 516 // AndIOp 517 //===----------------------------------------------------------------------===// 518 519 OpFoldResult arith::AndIOp::fold(ArrayRef<Attribute> operands) { 520 /// and(x, 0) -> 0 521 if (matchPattern(getRhs(), m_Zero())) 522 return getRhs(); 523 /// and(x, allOnes) -> x 524 APInt intValue; 525 if (matchPattern(getRhs(), m_ConstantInt(&intValue)) && intValue.isAllOnes()) 526 return getLhs(); 527 528 return constFoldBinaryOp<IntegerAttr>( 529 operands, [](APInt a, const APInt &b) { return std::move(a) & b; }); 530 } 531 532 //===----------------------------------------------------------------------===// 533 // OrIOp 534 //===----------------------------------------------------------------------===// 535 536 OpFoldResult arith::OrIOp::fold(ArrayRef<Attribute> operands) { 537 /// or(x, 0) -> x 538 if (matchPattern(getRhs(), m_Zero())) 539 return getLhs(); 540 /// or(x, <all ones>) -> <all ones> 541 if (auto rhsAttr = operands[1].dyn_cast_or_null<IntegerAttr>()) 542 if (rhsAttr.getValue().isAllOnes()) 543 return rhsAttr; 544 545 return constFoldBinaryOp<IntegerAttr>( 546 operands, [](APInt a, const APInt &b) { return std::move(a) | b; }); 547 } 548 549 //===----------------------------------------------------------------------===// 550 // XOrIOp 551 //===----------------------------------------------------------------------===// 552 553 OpFoldResult arith::XOrIOp::fold(ArrayRef<Attribute> operands) { 554 /// xor(x, 0) -> x 555 if (matchPattern(getRhs(), m_Zero())) 556 return getLhs(); 557 /// xor(x, x) -> 0 558 if (getLhs() == getRhs()) 559 return Builder(getContext()).getZeroAttr(getType()); 560 /// xor(xor(x, a), a) -> x 561 if (arith::XOrIOp prev = getLhs().getDefiningOp<arith::XOrIOp>()) 562 if (prev.getRhs() == getRhs()) 563 return prev.getLhs(); 564 565 return constFoldBinaryOp<IntegerAttr>( 566 operands, [](APInt a, const APInt &b) { return std::move(a) ^ b; }); 567 } 568 569 void arith::XOrIOp::getCanonicalizationPatterns( 570 OwningRewritePatternList &patterns, MLIRContext *context) { 571 patterns.insert<XOrINotCmpI>(context); 572 } 573 574 //===----------------------------------------------------------------------===// 575 // AddFOp 576 //===----------------------------------------------------------------------===// 577 578 OpFoldResult arith::AddFOp::fold(ArrayRef<Attribute> operands) { 579 return constFoldBinaryOp<FloatAttr>( 580 operands, [](const APFloat &a, const APFloat &b) { return a + b; }); 581 } 582 583 //===----------------------------------------------------------------------===// 584 // SubFOp 585 //===----------------------------------------------------------------------===// 586 587 OpFoldResult arith::SubFOp::fold(ArrayRef<Attribute> operands) { 588 return constFoldBinaryOp<FloatAttr>( 589 operands, [](const APFloat &a, const APFloat &b) { return a - b; }); 590 } 591 592 //===----------------------------------------------------------------------===// 593 // MaxSIOp 594 //===----------------------------------------------------------------------===// 595 596 OpFoldResult MaxSIOp::fold(ArrayRef<Attribute> operands) { 597 assert(operands.size() == 2 && "binary operation takes two operands"); 598 599 // maxsi(x,x) -> x 600 if (getLhs() == getRhs()) 601 return getRhs(); 602 603 APInt intValue; 604 // maxsi(x,MAX_INT) -> MAX_INT 605 if (matchPattern(getRhs(), m_ConstantInt(&intValue)) && 606 intValue.isMaxSignedValue()) 607 return getRhs(); 608 609 // maxsi(x, MIN_INT) -> x 610 if (matchPattern(getRhs(), m_ConstantInt(&intValue)) && 611 intValue.isMinSignedValue()) 612 return getLhs(); 613 614 return constFoldBinaryOp<IntegerAttr>(operands, 615 [](const APInt &a, const APInt &b) { 616 return llvm::APIntOps::smax(a, b); 617 }); 618 } 619 620 //===----------------------------------------------------------------------===// 621 // MaxUIOp 622 //===----------------------------------------------------------------------===// 623 624 OpFoldResult MaxUIOp::fold(ArrayRef<Attribute> operands) { 625 assert(operands.size() == 2 && "binary operation takes two operands"); 626 627 // maxui(x,x) -> x 628 if (getLhs() == getRhs()) 629 return getRhs(); 630 631 APInt intValue; 632 // maxui(x,MAX_INT) -> MAX_INT 633 if (matchPattern(getRhs(), m_ConstantInt(&intValue)) && intValue.isMaxValue()) 634 return getRhs(); 635 636 // maxui(x, MIN_INT) -> x 637 if (matchPattern(getRhs(), m_ConstantInt(&intValue)) && intValue.isMinValue()) 638 return getLhs(); 639 640 return constFoldBinaryOp<IntegerAttr>(operands, 641 [](const APInt &a, const APInt &b) { 642 return llvm::APIntOps::umax(a, b); 643 }); 644 } 645 646 //===----------------------------------------------------------------------===// 647 // MinSIOp 648 //===----------------------------------------------------------------------===// 649 650 OpFoldResult MinSIOp::fold(ArrayRef<Attribute> operands) { 651 assert(operands.size() == 2 && "binary operation takes two operands"); 652 653 // minsi(x,x) -> x 654 if (getLhs() == getRhs()) 655 return getRhs(); 656 657 APInt intValue; 658 // minsi(x,MIN_INT) -> MIN_INT 659 if (matchPattern(getRhs(), m_ConstantInt(&intValue)) && 660 intValue.isMinSignedValue()) 661 return getRhs(); 662 663 // minsi(x, MAX_INT) -> x 664 if (matchPattern(getRhs(), m_ConstantInt(&intValue)) && 665 intValue.isMaxSignedValue()) 666 return getLhs(); 667 668 return constFoldBinaryOp<IntegerAttr>(operands, 669 [](const APInt &a, const APInt &b) { 670 return llvm::APIntOps::smin(a, b); 671 }); 672 } 673 674 //===----------------------------------------------------------------------===// 675 // MinUIOp 676 //===----------------------------------------------------------------------===// 677 678 OpFoldResult MinUIOp::fold(ArrayRef<Attribute> operands) { 679 assert(operands.size() == 2 && "binary operation takes two operands"); 680 681 // minui(x,x) -> x 682 if (getLhs() == getRhs()) 683 return getRhs(); 684 685 APInt intValue; 686 // minui(x,MIN_INT) -> MIN_INT 687 if (matchPattern(getRhs(), m_ConstantInt(&intValue)) && intValue.isMinValue()) 688 return getRhs(); 689 690 // minui(x, MAX_INT) -> x 691 if (matchPattern(getRhs(), m_ConstantInt(&intValue)) && intValue.isMaxValue()) 692 return getLhs(); 693 694 return constFoldBinaryOp<IntegerAttr>(operands, 695 [](const APInt &a, const APInt &b) { 696 return llvm::APIntOps::umin(a, b); 697 }); 698 } 699 700 //===----------------------------------------------------------------------===// 701 // MulFOp 702 //===----------------------------------------------------------------------===// 703 704 OpFoldResult arith::MulFOp::fold(ArrayRef<Attribute> operands) { 705 return constFoldBinaryOp<FloatAttr>( 706 operands, [](const APFloat &a, const APFloat &b) { return a * b; }); 707 } 708 709 //===----------------------------------------------------------------------===// 710 // DivFOp 711 //===----------------------------------------------------------------------===// 712 713 OpFoldResult arith::DivFOp::fold(ArrayRef<Attribute> operands) { 714 return constFoldBinaryOp<FloatAttr>( 715 operands, [](const APFloat &a, const APFloat &b) { return a / b; }); 716 } 717 718 //===----------------------------------------------------------------------===// 719 // Utility functions for verifying cast ops 720 //===----------------------------------------------------------------------===// 721 722 template <typename... Types> 723 using type_list = std::tuple<Types...> *; 724 725 /// Returns a non-null type only if the provided type is one of the allowed 726 /// types or one of the allowed shaped types of the allowed types. Returns the 727 /// element type if a valid shaped type is provided. 728 template <typename... ShapedTypes, typename... ElementTypes> 729 static Type getUnderlyingType(Type type, type_list<ShapedTypes...>, 730 type_list<ElementTypes...>) { 731 if (type.isa<ShapedType>() && !type.isa<ShapedTypes...>()) 732 return {}; 733 734 auto underlyingType = getElementTypeOrSelf(type); 735 if (!underlyingType.isa<ElementTypes...>()) 736 return {}; 737 738 return underlyingType; 739 } 740 741 /// Get allowed underlying types for vectors and tensors. 742 template <typename... ElementTypes> 743 static Type getTypeIfLike(Type type) { 744 return getUnderlyingType(type, type_list<VectorType, TensorType>(), 745 type_list<ElementTypes...>()); 746 } 747 748 /// Get allowed underlying types for vectors, tensors, and memrefs. 749 template <typename... ElementTypes> 750 static Type getTypeIfLikeOrMemRef(Type type) { 751 return getUnderlyingType(type, 752 type_list<VectorType, TensorType, MemRefType>(), 753 type_list<ElementTypes...>()); 754 } 755 756 static bool areValidCastInputsAndOutputs(TypeRange inputs, TypeRange outputs) { 757 return inputs.size() == 1 && outputs.size() == 1 && 758 succeeded(verifyCompatibleShapes(inputs.front(), outputs.front())); 759 } 760 761 //===----------------------------------------------------------------------===// 762 // Verifiers for integer and floating point extension/truncation ops 763 //===----------------------------------------------------------------------===// 764 765 // Extend ops can only extend to a wider type. 766 template <typename ValType, typename Op> 767 static LogicalResult verifyExtOp(Op op) { 768 Type srcType = getElementTypeOrSelf(op.getIn().getType()); 769 Type dstType = getElementTypeOrSelf(op.getType()); 770 771 if (srcType.cast<ValType>().getWidth() >= dstType.cast<ValType>().getWidth()) 772 return op.emitError("result type ") 773 << dstType << " must be wider than operand type " << srcType; 774 775 return success(); 776 } 777 778 // Truncate ops can only truncate to a shorter type. 779 template <typename ValType, typename Op> 780 static LogicalResult verifyTruncateOp(Op op) { 781 Type srcType = getElementTypeOrSelf(op.getIn().getType()); 782 Type dstType = getElementTypeOrSelf(op.getType()); 783 784 if (srcType.cast<ValType>().getWidth() <= dstType.cast<ValType>().getWidth()) 785 return op.emitError("result type ") 786 << dstType << " must be shorter than operand type " << srcType; 787 788 return success(); 789 } 790 791 /// Validate a cast that changes the width of a type. 792 template <template <typename> class WidthComparator, typename... ElementTypes> 793 static bool checkWidthChangeCast(TypeRange inputs, TypeRange outputs) { 794 if (!areValidCastInputsAndOutputs(inputs, outputs)) 795 return false; 796 797 auto srcType = getTypeIfLike<ElementTypes...>(inputs.front()); 798 auto dstType = getTypeIfLike<ElementTypes...>(outputs.front()); 799 if (!srcType || !dstType) 800 return false; 801 802 return WidthComparator<unsigned>()(dstType.getIntOrFloatBitWidth(), 803 srcType.getIntOrFloatBitWidth()); 804 } 805 806 //===----------------------------------------------------------------------===// 807 // ExtUIOp 808 //===----------------------------------------------------------------------===// 809 810 OpFoldResult arith::ExtUIOp::fold(ArrayRef<Attribute> operands) { 811 if (auto lhs = operands[0].dyn_cast_or_null<IntegerAttr>()) 812 return IntegerAttr::get( 813 getType(), lhs.getValue().zext(getType().getIntOrFloatBitWidth())); 814 815 if (auto lhs = getIn().getDefiningOp<ExtUIOp>()) { 816 getInMutable().assign(lhs.getIn()); 817 return getResult(); 818 } 819 820 return {}; 821 } 822 823 bool arith::ExtUIOp::areCastCompatible(TypeRange inputs, TypeRange outputs) { 824 return checkWidthChangeCast<std::greater, IntegerType>(inputs, outputs); 825 } 826 827 //===----------------------------------------------------------------------===// 828 // ExtSIOp 829 //===----------------------------------------------------------------------===// 830 831 OpFoldResult arith::ExtSIOp::fold(ArrayRef<Attribute> operands) { 832 if (auto lhs = operands[0].dyn_cast_or_null<IntegerAttr>()) 833 return IntegerAttr::get( 834 getType(), lhs.getValue().sext(getType().getIntOrFloatBitWidth())); 835 836 if (auto lhs = getIn().getDefiningOp<ExtSIOp>()) { 837 getInMutable().assign(lhs.getIn()); 838 return getResult(); 839 } 840 841 return {}; 842 } 843 844 bool arith::ExtSIOp::areCastCompatible(TypeRange inputs, TypeRange outputs) { 845 return checkWidthChangeCast<std::greater, IntegerType>(inputs, outputs); 846 } 847 848 void arith::ExtSIOp::getCanonicalizationPatterns( 849 OwningRewritePatternList &patterns, MLIRContext *context) { 850 patterns.insert<ExtSIOfExtUI>(context); 851 } 852 853 //===----------------------------------------------------------------------===// 854 // ExtFOp 855 //===----------------------------------------------------------------------===// 856 857 bool arith::ExtFOp::areCastCompatible(TypeRange inputs, TypeRange outputs) { 858 return checkWidthChangeCast<std::greater, FloatType>(inputs, outputs); 859 } 860 861 //===----------------------------------------------------------------------===// 862 // TruncIOp 863 //===----------------------------------------------------------------------===// 864 865 OpFoldResult arith::TruncIOp::fold(ArrayRef<Attribute> operands) { 866 assert(operands.size() == 1 && "unary operation takes one operand"); 867 868 // trunci(zexti(a)) -> a 869 // trunci(sexti(a)) -> a 870 if (matchPattern(getOperand(), m_Op<arith::ExtUIOp>()) || 871 matchPattern(getOperand(), m_Op<arith::ExtSIOp>())) 872 return getOperand().getDefiningOp()->getOperand(0); 873 874 // trunci(trunci(a)) -> trunci(a)) 875 if (matchPattern(getOperand(), m_Op<arith::TruncIOp>())) { 876 setOperand(getOperand().getDefiningOp()->getOperand(0)); 877 return getResult(); 878 } 879 880 if (!operands[0]) 881 return {}; 882 883 if (auto lhs = operands[0].dyn_cast<IntegerAttr>()) { 884 return IntegerAttr::get( 885 getType(), lhs.getValue().trunc(getType().getIntOrFloatBitWidth())); 886 } 887 888 return {}; 889 } 890 891 bool arith::TruncIOp::areCastCompatible(TypeRange inputs, TypeRange outputs) { 892 return checkWidthChangeCast<std::less, IntegerType>(inputs, outputs); 893 } 894 895 //===----------------------------------------------------------------------===// 896 // TruncFOp 897 //===----------------------------------------------------------------------===// 898 899 /// Perform safe const propagation for truncf, i.e. only propagate if FP value 900 /// can be represented without precision loss or rounding. 901 OpFoldResult arith::TruncFOp::fold(ArrayRef<Attribute> operands) { 902 assert(operands.size() == 1 && "unary operation takes one operand"); 903 904 auto constOperand = operands.front(); 905 if (!constOperand || !constOperand.isa<FloatAttr>()) 906 return {}; 907 908 // Convert to target type via 'double'. 909 double sourceValue = 910 constOperand.dyn_cast<FloatAttr>().getValue().convertToDouble(); 911 auto targetAttr = FloatAttr::get(getType(), sourceValue); 912 913 // Propagate if constant's value does not change after truncation. 914 if (sourceValue == targetAttr.getValue().convertToDouble()) 915 return targetAttr; 916 917 return {}; 918 } 919 920 bool arith::TruncFOp::areCastCompatible(TypeRange inputs, TypeRange outputs) { 921 return checkWidthChangeCast<std::less, FloatType>(inputs, outputs); 922 } 923 924 //===----------------------------------------------------------------------===// 925 // AndIOp 926 //===----------------------------------------------------------------------===// 927 928 void arith::AndIOp::getCanonicalizationPatterns( 929 OwningRewritePatternList &patterns, MLIRContext *context) { 930 patterns.insert<AndOfExtUI, AndOfExtSI>(context); 931 } 932 933 //===----------------------------------------------------------------------===// 934 // OrIOp 935 //===----------------------------------------------------------------------===// 936 937 void arith::OrIOp::getCanonicalizationPatterns( 938 OwningRewritePatternList &patterns, MLIRContext *context) { 939 patterns.insert<OrOfExtUI, OrOfExtSI>(context); 940 } 941 942 //===----------------------------------------------------------------------===// 943 // Verifiers for casts between integers and floats. 944 //===----------------------------------------------------------------------===// 945 946 template <typename From, typename To> 947 static bool checkIntFloatCast(TypeRange inputs, TypeRange outputs) { 948 if (!areValidCastInputsAndOutputs(inputs, outputs)) 949 return false; 950 951 auto srcType = getTypeIfLike<From>(inputs.front()); 952 auto dstType = getTypeIfLike<To>(outputs.back()); 953 954 return srcType && dstType; 955 } 956 957 //===----------------------------------------------------------------------===// 958 // UIToFPOp 959 //===----------------------------------------------------------------------===// 960 961 bool arith::UIToFPOp::areCastCompatible(TypeRange inputs, TypeRange outputs) { 962 return checkIntFloatCast<IntegerType, FloatType>(inputs, outputs); 963 } 964 965 OpFoldResult arith::UIToFPOp::fold(ArrayRef<Attribute> operands) { 966 if (auto lhs = operands[0].dyn_cast_or_null<IntegerAttr>()) { 967 const APInt &api = lhs.getValue(); 968 FloatType floatTy = getType().cast<FloatType>(); 969 APFloat apf(floatTy.getFloatSemantics(), 970 APInt::getZero(floatTy.getWidth())); 971 apf.convertFromAPInt(api, /*IsSigned=*/false, APFloat::rmNearestTiesToEven); 972 return FloatAttr::get(floatTy, apf); 973 } 974 return {}; 975 } 976 977 //===----------------------------------------------------------------------===// 978 // SIToFPOp 979 //===----------------------------------------------------------------------===// 980 981 bool arith::SIToFPOp::areCastCompatible(TypeRange inputs, TypeRange outputs) { 982 return checkIntFloatCast<IntegerType, FloatType>(inputs, outputs); 983 } 984 985 OpFoldResult arith::SIToFPOp::fold(ArrayRef<Attribute> operands) { 986 if (auto lhs = operands[0].dyn_cast_or_null<IntegerAttr>()) { 987 const APInt &api = lhs.getValue(); 988 FloatType floatTy = getType().cast<FloatType>(); 989 APFloat apf(floatTy.getFloatSemantics(), 990 APInt::getZero(floatTy.getWidth())); 991 apf.convertFromAPInt(api, /*IsSigned=*/true, APFloat::rmNearestTiesToEven); 992 return FloatAttr::get(floatTy, apf); 993 } 994 return {}; 995 } 996 //===----------------------------------------------------------------------===// 997 // FPToUIOp 998 //===----------------------------------------------------------------------===// 999 1000 bool arith::FPToUIOp::areCastCompatible(TypeRange inputs, TypeRange outputs) { 1001 return checkIntFloatCast<FloatType, IntegerType>(inputs, outputs); 1002 } 1003 1004 OpFoldResult arith::FPToUIOp::fold(ArrayRef<Attribute> operands) { 1005 if (auto lhs = operands[0].dyn_cast_or_null<FloatAttr>()) { 1006 const APFloat &apf = lhs.getValue(); 1007 IntegerType intTy = getType().cast<IntegerType>(); 1008 bool ignored; 1009 APSInt api(intTy.getWidth(), /*isUnsigned=*/true); 1010 if (APFloat::opInvalidOp == 1011 apf.convertToInteger(api, APFloat::rmTowardZero, &ignored)) { 1012 // Undefined behavior invoked - the destination type can't represent 1013 // the input constant. 1014 return {}; 1015 } 1016 return IntegerAttr::get(getType(), api); 1017 } 1018 1019 return {}; 1020 } 1021 1022 //===----------------------------------------------------------------------===// 1023 // FPToSIOp 1024 //===----------------------------------------------------------------------===// 1025 1026 bool arith::FPToSIOp::areCastCompatible(TypeRange inputs, TypeRange outputs) { 1027 return checkIntFloatCast<FloatType, IntegerType>(inputs, outputs); 1028 } 1029 1030 OpFoldResult arith::FPToSIOp::fold(ArrayRef<Attribute> operands) { 1031 if (auto lhs = operands[0].dyn_cast_or_null<FloatAttr>()) { 1032 const APFloat &apf = lhs.getValue(); 1033 IntegerType intTy = getType().cast<IntegerType>(); 1034 bool ignored; 1035 APSInt api(intTy.getWidth(), /*isUnsigned=*/false); 1036 if (APFloat::opInvalidOp == 1037 apf.convertToInteger(api, APFloat::rmTowardZero, &ignored)) { 1038 // Undefined behavior invoked - the destination type can't represent 1039 // the input constant. 1040 return {}; 1041 } 1042 return IntegerAttr::get(getType(), api); 1043 } 1044 1045 return {}; 1046 } 1047 1048 //===----------------------------------------------------------------------===// 1049 // IndexCastOp 1050 //===----------------------------------------------------------------------===// 1051 1052 bool arith::IndexCastOp::areCastCompatible(TypeRange inputs, 1053 TypeRange outputs) { 1054 if (!areValidCastInputsAndOutputs(inputs, outputs)) 1055 return false; 1056 1057 auto srcType = getTypeIfLikeOrMemRef<IntegerType, IndexType>(inputs.front()); 1058 auto dstType = getTypeIfLikeOrMemRef<IntegerType, IndexType>(outputs.front()); 1059 if (!srcType || !dstType) 1060 return false; 1061 1062 return (srcType.isIndex() && dstType.isSignlessInteger()) || 1063 (srcType.isSignlessInteger() && dstType.isIndex()); 1064 } 1065 1066 OpFoldResult arith::IndexCastOp::fold(ArrayRef<Attribute> operands) { 1067 // index_cast(constant) -> constant 1068 // A little hack because we go through int. Otherwise, the size of the 1069 // constant might need to change. 1070 if (auto value = operands[0].dyn_cast_or_null<IntegerAttr>()) 1071 return IntegerAttr::get(getType(), value.getInt()); 1072 1073 return {}; 1074 } 1075 1076 void arith::IndexCastOp::getCanonicalizationPatterns( 1077 OwningRewritePatternList &patterns, MLIRContext *context) { 1078 patterns.insert<IndexCastOfIndexCast, IndexCastOfExtSI>(context); 1079 } 1080 1081 //===----------------------------------------------------------------------===// 1082 // BitcastOp 1083 //===----------------------------------------------------------------------===// 1084 1085 bool arith::BitcastOp::areCastCompatible(TypeRange inputs, TypeRange outputs) { 1086 if (!areValidCastInputsAndOutputs(inputs, outputs)) 1087 return false; 1088 1089 auto srcType = 1090 getTypeIfLikeOrMemRef<IntegerType, IndexType, FloatType>(inputs.front()); 1091 auto dstType = 1092 getTypeIfLikeOrMemRef<IntegerType, IndexType, FloatType>(outputs.front()); 1093 if (!srcType || !dstType) 1094 return false; 1095 1096 return srcType.getIntOrFloatBitWidth() == dstType.getIntOrFloatBitWidth(); 1097 } 1098 1099 OpFoldResult arith::BitcastOp::fold(ArrayRef<Attribute> operands) { 1100 assert(operands.size() == 1 && "bitcast op expects 1 operand"); 1101 1102 auto resType = getType(); 1103 auto operand = operands[0]; 1104 if (!operand) 1105 return {}; 1106 1107 /// Bitcast dense elements. 1108 if (auto denseAttr = operand.dyn_cast_or_null<DenseElementsAttr>()) 1109 return denseAttr.bitcast(resType.cast<ShapedType>().getElementType()); 1110 /// Other shaped types unhandled. 1111 if (resType.isa<ShapedType>()) 1112 return {}; 1113 1114 /// Bitcast integer or float to integer or float. 1115 APInt bits = operand.isa<FloatAttr>() 1116 ? operand.cast<FloatAttr>().getValue().bitcastToAPInt() 1117 : operand.cast<IntegerAttr>().getValue(); 1118 1119 if (auto resFloatType = resType.dyn_cast<FloatType>()) 1120 return FloatAttr::get(resType, 1121 APFloat(resFloatType.getFloatSemantics(), bits)); 1122 return IntegerAttr::get(resType, bits); 1123 } 1124 1125 void arith::BitcastOp::getCanonicalizationPatterns( 1126 OwningRewritePatternList &patterns, MLIRContext *context) { 1127 patterns.insert<BitcastOfBitcast>(context); 1128 } 1129 1130 //===----------------------------------------------------------------------===// 1131 // Helpers for compare ops 1132 //===----------------------------------------------------------------------===// 1133 1134 /// Return the type of the same shape (scalar, vector or tensor) containing i1. 1135 static Type getI1SameShape(Type type) { 1136 auto i1Type = IntegerType::get(type.getContext(), 1); 1137 if (auto tensorType = type.dyn_cast<RankedTensorType>()) 1138 return RankedTensorType::get(tensorType.getShape(), i1Type); 1139 if (type.isa<UnrankedTensorType>()) 1140 return UnrankedTensorType::get(i1Type); 1141 if (auto vectorType = type.dyn_cast<VectorType>()) 1142 return VectorType::get(vectorType.getShape(), i1Type, 1143 vectorType.getNumScalableDims()); 1144 return i1Type; 1145 } 1146 1147 //===----------------------------------------------------------------------===// 1148 // CmpIOp 1149 //===----------------------------------------------------------------------===// 1150 1151 /// Compute `lhs` `pred` `rhs`, where `pred` is one of the known integer 1152 /// comparison predicates. 1153 bool mlir::arith::applyCmpPredicate(arith::CmpIPredicate predicate, 1154 const APInt &lhs, const APInt &rhs) { 1155 switch (predicate) { 1156 case arith::CmpIPredicate::eq: 1157 return lhs.eq(rhs); 1158 case arith::CmpIPredicate::ne: 1159 return lhs.ne(rhs); 1160 case arith::CmpIPredicate::slt: 1161 return lhs.slt(rhs); 1162 case arith::CmpIPredicate::sle: 1163 return lhs.sle(rhs); 1164 case arith::CmpIPredicate::sgt: 1165 return lhs.sgt(rhs); 1166 case arith::CmpIPredicate::sge: 1167 return lhs.sge(rhs); 1168 case arith::CmpIPredicate::ult: 1169 return lhs.ult(rhs); 1170 case arith::CmpIPredicate::ule: 1171 return lhs.ule(rhs); 1172 case arith::CmpIPredicate::ugt: 1173 return lhs.ugt(rhs); 1174 case arith::CmpIPredicate::uge: 1175 return lhs.uge(rhs); 1176 } 1177 llvm_unreachable("unknown cmpi predicate kind"); 1178 } 1179 1180 /// Returns true if the predicate is true for two equal operands. 1181 static bool applyCmpPredicateToEqualOperands(arith::CmpIPredicate predicate) { 1182 switch (predicate) { 1183 case arith::CmpIPredicate::eq: 1184 case arith::CmpIPredicate::sle: 1185 case arith::CmpIPredicate::sge: 1186 case arith::CmpIPredicate::ule: 1187 case arith::CmpIPredicate::uge: 1188 return true; 1189 case arith::CmpIPredicate::ne: 1190 case arith::CmpIPredicate::slt: 1191 case arith::CmpIPredicate::sgt: 1192 case arith::CmpIPredicate::ult: 1193 case arith::CmpIPredicate::ugt: 1194 return false; 1195 } 1196 llvm_unreachable("unknown cmpi predicate kind"); 1197 } 1198 1199 static Attribute getBoolAttribute(Type type, MLIRContext *ctx, bool value) { 1200 auto boolAttr = BoolAttr::get(ctx, value); 1201 ShapedType shapedType = type.dyn_cast_or_null<ShapedType>(); 1202 if (!shapedType) 1203 return boolAttr; 1204 return DenseElementsAttr::get(shapedType, boolAttr); 1205 } 1206 1207 OpFoldResult arith::CmpIOp::fold(ArrayRef<Attribute> operands) { 1208 assert(operands.size() == 2 && "cmpi takes two operands"); 1209 1210 // cmpi(pred, x, x) 1211 if (getLhs() == getRhs()) { 1212 auto val = applyCmpPredicateToEqualOperands(getPredicate()); 1213 return getBoolAttribute(getType(), getContext(), val); 1214 } 1215 1216 if (matchPattern(getRhs(), m_Zero())) { 1217 if (auto extOp = getLhs().getDefiningOp<ExtSIOp>()) { 1218 if (extOp.getOperand().getType().cast<IntegerType>().getWidth() == 1) { 1219 // extsi(%x : i1 -> iN) != 0 -> %x 1220 if (getPredicate() == arith::CmpIPredicate::ne) { 1221 return extOp.getOperand(); 1222 } 1223 } 1224 } 1225 if (auto extOp = getLhs().getDefiningOp<ExtUIOp>()) { 1226 if (extOp.getOperand().getType().cast<IntegerType>().getWidth() == 1) { 1227 // extui(%x : i1 -> iN) != 0 -> %x 1228 if (getPredicate() == arith::CmpIPredicate::ne) { 1229 return extOp.getOperand(); 1230 } 1231 } 1232 } 1233 } 1234 1235 auto lhs = operands.front().dyn_cast_or_null<IntegerAttr>(); 1236 auto rhs = operands.back().dyn_cast_or_null<IntegerAttr>(); 1237 if (!lhs || !rhs) 1238 return {}; 1239 1240 auto val = applyCmpPredicate(getPredicate(), lhs.getValue(), rhs.getValue()); 1241 return BoolAttr::get(getContext(), val); 1242 } 1243 1244 //===----------------------------------------------------------------------===// 1245 // CmpFOp 1246 //===----------------------------------------------------------------------===// 1247 1248 /// Compute `lhs` `pred` `rhs`, where `pred` is one of the known floating point 1249 /// comparison predicates. 1250 bool mlir::arith::applyCmpPredicate(arith::CmpFPredicate predicate, 1251 const APFloat &lhs, const APFloat &rhs) { 1252 auto cmpResult = lhs.compare(rhs); 1253 switch (predicate) { 1254 case arith::CmpFPredicate::AlwaysFalse: 1255 return false; 1256 case arith::CmpFPredicate::OEQ: 1257 return cmpResult == APFloat::cmpEqual; 1258 case arith::CmpFPredicate::OGT: 1259 return cmpResult == APFloat::cmpGreaterThan; 1260 case arith::CmpFPredicate::OGE: 1261 return cmpResult == APFloat::cmpGreaterThan || 1262 cmpResult == APFloat::cmpEqual; 1263 case arith::CmpFPredicate::OLT: 1264 return cmpResult == APFloat::cmpLessThan; 1265 case arith::CmpFPredicate::OLE: 1266 return cmpResult == APFloat::cmpLessThan || cmpResult == APFloat::cmpEqual; 1267 case arith::CmpFPredicate::ONE: 1268 return cmpResult != APFloat::cmpUnordered && cmpResult != APFloat::cmpEqual; 1269 case arith::CmpFPredicate::ORD: 1270 return cmpResult != APFloat::cmpUnordered; 1271 case arith::CmpFPredicate::UEQ: 1272 return cmpResult == APFloat::cmpUnordered || cmpResult == APFloat::cmpEqual; 1273 case arith::CmpFPredicate::UGT: 1274 return cmpResult == APFloat::cmpUnordered || 1275 cmpResult == APFloat::cmpGreaterThan; 1276 case arith::CmpFPredicate::UGE: 1277 return cmpResult == APFloat::cmpUnordered || 1278 cmpResult == APFloat::cmpGreaterThan || 1279 cmpResult == APFloat::cmpEqual; 1280 case arith::CmpFPredicate::ULT: 1281 return cmpResult == APFloat::cmpUnordered || 1282 cmpResult == APFloat::cmpLessThan; 1283 case arith::CmpFPredicate::ULE: 1284 return cmpResult == APFloat::cmpUnordered || 1285 cmpResult == APFloat::cmpLessThan || cmpResult == APFloat::cmpEqual; 1286 case arith::CmpFPredicate::UNE: 1287 return cmpResult != APFloat::cmpEqual; 1288 case arith::CmpFPredicate::UNO: 1289 return cmpResult == APFloat::cmpUnordered; 1290 case arith::CmpFPredicate::AlwaysTrue: 1291 return true; 1292 } 1293 llvm_unreachable("unknown cmpf predicate kind"); 1294 } 1295 1296 OpFoldResult arith::CmpFOp::fold(ArrayRef<Attribute> operands) { 1297 assert(operands.size() == 2 && "cmpf takes two operands"); 1298 1299 auto lhs = operands.front().dyn_cast_or_null<FloatAttr>(); 1300 auto rhs = operands.back().dyn_cast_or_null<FloatAttr>(); 1301 1302 // If one operand is NaN, making them both NaN does not change the result. 1303 if (lhs && lhs.getValue().isNaN()) 1304 rhs = lhs; 1305 if (rhs && rhs.getValue().isNaN()) 1306 lhs = rhs; 1307 1308 if (!lhs || !rhs) 1309 return {}; 1310 1311 auto val = applyCmpPredicate(getPredicate(), lhs.getValue(), rhs.getValue()); 1312 return BoolAttr::get(getContext(), val); 1313 } 1314 1315 //===----------------------------------------------------------------------===// 1316 // Atomic Enum 1317 //===----------------------------------------------------------------------===// 1318 1319 /// Returns the identity value attribute associated with an AtomicRMWKind op. 1320 Attribute mlir::arith::getIdentityValueAttr(AtomicRMWKind kind, Type resultType, 1321 OpBuilder &builder, Location loc) { 1322 switch (kind) { 1323 case AtomicRMWKind::maxf: 1324 return builder.getFloatAttr( 1325 resultType, 1326 APFloat::getInf(resultType.cast<FloatType>().getFloatSemantics(), 1327 /*Negative=*/true)); 1328 case AtomicRMWKind::addf: 1329 case AtomicRMWKind::addi: 1330 case AtomicRMWKind::maxu: 1331 case AtomicRMWKind::ori: 1332 return builder.getZeroAttr(resultType); 1333 case AtomicRMWKind::andi: 1334 return builder.getIntegerAttr( 1335 resultType, 1336 APInt::getAllOnes(resultType.cast<IntegerType>().getWidth())); 1337 case AtomicRMWKind::maxs: 1338 return builder.getIntegerAttr( 1339 resultType, 1340 APInt::getSignedMinValue(resultType.cast<IntegerType>().getWidth())); 1341 case AtomicRMWKind::minf: 1342 return builder.getFloatAttr( 1343 resultType, 1344 APFloat::getInf(resultType.cast<FloatType>().getFloatSemantics(), 1345 /*Negative=*/false)); 1346 case AtomicRMWKind::mins: 1347 return builder.getIntegerAttr( 1348 resultType, 1349 APInt::getSignedMaxValue(resultType.cast<IntegerType>().getWidth())); 1350 case AtomicRMWKind::minu: 1351 return builder.getIntegerAttr( 1352 resultType, 1353 APInt::getMaxValue(resultType.cast<IntegerType>().getWidth())); 1354 case AtomicRMWKind::muli: 1355 return builder.getIntegerAttr(resultType, 1); 1356 case AtomicRMWKind::mulf: 1357 return builder.getFloatAttr(resultType, 1); 1358 // TODO: Add remaining reduction operations. 1359 default: 1360 (void)emitOptionalError(loc, "Reduction operation type not supported"); 1361 break; 1362 } 1363 return nullptr; 1364 } 1365 1366 /// Returns the identity value associated with an AtomicRMWKind op. 1367 Value mlir::arith::getIdentityValue(AtomicRMWKind op, Type resultType, 1368 OpBuilder &builder, Location loc) { 1369 Attribute attr = getIdentityValueAttr(op, resultType, builder, loc); 1370 return builder.create<arith::ConstantOp>(loc, attr); 1371 } 1372 1373 /// Return the value obtained by applying the reduction operation kind 1374 /// associated with a binary AtomicRMWKind op to `lhs` and `rhs`. 1375 Value mlir::arith::getReductionOp(AtomicRMWKind op, OpBuilder &builder, 1376 Location loc, Value lhs, Value rhs) { 1377 switch (op) { 1378 case AtomicRMWKind::addf: 1379 return builder.create<arith::AddFOp>(loc, lhs, rhs); 1380 case AtomicRMWKind::addi: 1381 return builder.create<arith::AddIOp>(loc, lhs, rhs); 1382 case AtomicRMWKind::mulf: 1383 return builder.create<arith::MulFOp>(loc, lhs, rhs); 1384 case AtomicRMWKind::muli: 1385 return builder.create<arith::MulIOp>(loc, lhs, rhs); 1386 case AtomicRMWKind::maxf: 1387 return builder.create<arith::MaxFOp>(loc, lhs, rhs); 1388 case AtomicRMWKind::minf: 1389 return builder.create<arith::MinFOp>(loc, lhs, rhs); 1390 case AtomicRMWKind::maxs: 1391 return builder.create<arith::MaxSIOp>(loc, lhs, rhs); 1392 case AtomicRMWKind::mins: 1393 return builder.create<arith::MinSIOp>(loc, lhs, rhs); 1394 case AtomicRMWKind::maxu: 1395 return builder.create<arith::MaxUIOp>(loc, lhs, rhs); 1396 case AtomicRMWKind::minu: 1397 return builder.create<arith::MinUIOp>(loc, lhs, rhs); 1398 case AtomicRMWKind::ori: 1399 return builder.create<arith::OrIOp>(loc, lhs, rhs); 1400 case AtomicRMWKind::andi: 1401 return builder.create<arith::AndIOp>(loc, lhs, rhs); 1402 // TODO: Add remaining reduction operations. 1403 default: 1404 (void)emitOptionalError(loc, "Reduction operation type not supported"); 1405 break; 1406 } 1407 return nullptr; 1408 } 1409 1410 //===----------------------------------------------------------------------===// 1411 // TableGen'd op method definitions 1412 //===----------------------------------------------------------------------===// 1413 1414 #define GET_OP_CLASSES 1415 #include "mlir/Dialect/Arithmetic/IR/ArithmeticOps.cpp.inc" 1416 1417 //===----------------------------------------------------------------------===// 1418 // TableGen'd enum attribute definitions 1419 //===----------------------------------------------------------------------===// 1420 1421 #include "mlir/Dialect/Arithmetic/IR/ArithmeticOpsEnums.cpp.inc" 1422