1 //===- ArithmeticOps.cpp - MLIR Arithmetic dialect ops implementation -----===// 2 // 3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. 4 // See https://llvm.org/LICENSE.txt for license information. 5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception 6 // 7 //===----------------------------------------------------------------------===// 8 9 #include <utility> 10 11 #include "mlir/Dialect/Arithmetic/IR/Arithmetic.h" 12 #include "mlir/Dialect/CommonFolders.h" 13 #include "mlir/IR/Builders.h" 14 #include "mlir/IR/Matchers.h" 15 #include "mlir/IR/OpImplementation.h" 16 #include "mlir/IR/PatternMatch.h" 17 #include "mlir/IR/TypeUtilities.h" 18 19 #include "llvm/ADT/APSInt.h" 20 21 using namespace mlir; 22 using namespace mlir::arith; 23 24 //===----------------------------------------------------------------------===// 25 // Pattern helpers 26 //===----------------------------------------------------------------------===// 27 28 static IntegerAttr addIntegerAttrs(PatternRewriter &builder, Value res, 29 Attribute lhs, Attribute rhs) { 30 return builder.getIntegerAttr(res.getType(), 31 lhs.cast<IntegerAttr>().getInt() + 32 rhs.cast<IntegerAttr>().getInt()); 33 } 34 35 static IntegerAttr subIntegerAttrs(PatternRewriter &builder, Value res, 36 Attribute lhs, Attribute rhs) { 37 return builder.getIntegerAttr(res.getType(), 38 lhs.cast<IntegerAttr>().getInt() - 39 rhs.cast<IntegerAttr>().getInt()); 40 } 41 42 /// Invert an integer comparison predicate. 43 static arith::CmpIPredicate invertPredicate(arith::CmpIPredicate pred) { 44 switch (pred) { 45 case arith::CmpIPredicate::eq: 46 return arith::CmpIPredicate::ne; 47 case arith::CmpIPredicate::ne: 48 return arith::CmpIPredicate::eq; 49 case arith::CmpIPredicate::slt: 50 return arith::CmpIPredicate::sge; 51 case arith::CmpIPredicate::sle: 52 return arith::CmpIPredicate::sgt; 53 case arith::CmpIPredicate::sgt: 54 return arith::CmpIPredicate::sle; 55 case arith::CmpIPredicate::sge: 56 return arith::CmpIPredicate::slt; 57 case arith::CmpIPredicate::ult: 58 return arith::CmpIPredicate::uge; 59 case arith::CmpIPredicate::ule: 60 return arith::CmpIPredicate::ugt; 61 case arith::CmpIPredicate::ugt: 62 return arith::CmpIPredicate::ule; 63 case arith::CmpIPredicate::uge: 64 return arith::CmpIPredicate::ult; 65 } 66 llvm_unreachable("unknown cmpi predicate kind"); 67 } 68 69 static arith::CmpIPredicateAttr invertPredicate(arith::CmpIPredicateAttr pred) { 70 return arith::CmpIPredicateAttr::get(pred.getContext(), 71 invertPredicate(pred.getValue())); 72 } 73 74 //===----------------------------------------------------------------------===// 75 // TableGen'd canonicalization patterns 76 //===----------------------------------------------------------------------===// 77 78 namespace { 79 #include "ArithmeticCanonicalization.inc" 80 } // namespace 81 82 //===----------------------------------------------------------------------===// 83 // ConstantOp 84 //===----------------------------------------------------------------------===// 85 86 void arith::ConstantOp::getAsmResultNames( 87 function_ref<void(Value, StringRef)> setNameFn) { 88 auto type = getType(); 89 if (auto intCst = getValue().dyn_cast<IntegerAttr>()) { 90 auto intType = type.dyn_cast<IntegerType>(); 91 92 // Sugar i1 constants with 'true' and 'false'. 93 if (intType && intType.getWidth() == 1) 94 return setNameFn(getResult(), (intCst.getInt() ? "true" : "false")); 95 96 // Otherwise, build a compex name with the value and type. 97 SmallString<32> specialNameBuffer; 98 llvm::raw_svector_ostream specialName(specialNameBuffer); 99 specialName << 'c' << intCst.getInt(); 100 if (intType) 101 specialName << '_' << type; 102 setNameFn(getResult(), specialName.str()); 103 } else { 104 setNameFn(getResult(), "cst"); 105 } 106 } 107 108 /// TODO: disallow arith.constant to return anything other than signless integer 109 /// or float like. 110 static LogicalResult verify(arith::ConstantOp op) { 111 auto type = op.getType(); 112 // The value's type must match the return type. 113 if (op.getValue().getType() != type) { 114 return op.emitOpError() << "value type " << op.getValue().getType() 115 << " must match return type: " << type; 116 } 117 // Integer values must be signless. 118 if (type.isa<IntegerType>() && !type.cast<IntegerType>().isSignless()) 119 return op.emitOpError("integer return type must be signless"); 120 // Any float or elements attribute are acceptable. 121 if (!op.getValue().isa<IntegerAttr, FloatAttr, ElementsAttr>()) { 122 return op.emitOpError( 123 "value must be an integer, float, or elements attribute"); 124 } 125 return success(); 126 } 127 128 bool arith::ConstantOp::isBuildableWith(Attribute value, Type type) { 129 // The value's type must be the same as the provided type. 130 if (value.getType() != type) 131 return false; 132 // Integer values must be signless. 133 if (type.isa<IntegerType>() && !type.cast<IntegerType>().isSignless()) 134 return false; 135 // Integer, float, and element attributes are buildable. 136 return value.isa<IntegerAttr, FloatAttr, ElementsAttr>(); 137 } 138 139 OpFoldResult arith::ConstantOp::fold(ArrayRef<Attribute> operands) { 140 return getValue(); 141 } 142 143 void arith::ConstantIntOp::build(OpBuilder &builder, OperationState &result, 144 int64_t value, unsigned width) { 145 auto type = builder.getIntegerType(width); 146 arith::ConstantOp::build(builder, result, type, 147 builder.getIntegerAttr(type, value)); 148 } 149 150 void arith::ConstantIntOp::build(OpBuilder &builder, OperationState &result, 151 int64_t value, Type type) { 152 assert(type.isSignlessInteger() && 153 "ConstantIntOp can only have signless integer type values"); 154 arith::ConstantOp::build(builder, result, type, 155 builder.getIntegerAttr(type, value)); 156 } 157 158 bool arith::ConstantIntOp::classof(Operation *op) { 159 if (auto constOp = dyn_cast_or_null<arith::ConstantOp>(op)) 160 return constOp.getType().isSignlessInteger(); 161 return false; 162 } 163 164 void arith::ConstantFloatOp::build(OpBuilder &builder, OperationState &result, 165 const APFloat &value, FloatType type) { 166 arith::ConstantOp::build(builder, result, type, 167 builder.getFloatAttr(type, value)); 168 } 169 170 bool arith::ConstantFloatOp::classof(Operation *op) { 171 if (auto constOp = dyn_cast_or_null<arith::ConstantOp>(op)) 172 return constOp.getType().isa<FloatType>(); 173 return false; 174 } 175 176 void arith::ConstantIndexOp::build(OpBuilder &builder, OperationState &result, 177 int64_t value) { 178 arith::ConstantOp::build(builder, result, builder.getIndexType(), 179 builder.getIndexAttr(value)); 180 } 181 182 bool arith::ConstantIndexOp::classof(Operation *op) { 183 if (auto constOp = dyn_cast_or_null<arith::ConstantOp>(op)) 184 return constOp.getType().isIndex(); 185 return false; 186 } 187 188 //===----------------------------------------------------------------------===// 189 // AddIOp 190 //===----------------------------------------------------------------------===// 191 192 OpFoldResult arith::AddIOp::fold(ArrayRef<Attribute> operands) { 193 // addi(x, 0) -> x 194 if (matchPattern(getRhs(), m_Zero())) 195 return getLhs(); 196 197 // add(sub(a, b), b) -> a 198 if (auto sub = getLhs().getDefiningOp<SubIOp>()) 199 if (getRhs() == sub.getRhs()) 200 return sub.getLhs(); 201 202 // add(b, sub(a, b)) -> a 203 if (auto sub = getRhs().getDefiningOp<SubIOp>()) 204 if (getLhs() == sub.getRhs()) 205 return sub.getLhs(); 206 207 return constFoldBinaryOp<IntegerAttr>( 208 operands, [](APInt a, const APInt &b) { return std::move(a) + b; }); 209 } 210 211 void arith::AddIOp::getCanonicalizationPatterns( 212 OwningRewritePatternList &patterns, MLIRContext *context) { 213 patterns.insert<AddIAddConstant, AddISubConstantRHS, AddISubConstantLHS>( 214 context); 215 } 216 217 //===----------------------------------------------------------------------===// 218 // SubIOp 219 //===----------------------------------------------------------------------===// 220 221 OpFoldResult arith::SubIOp::fold(ArrayRef<Attribute> operands) { 222 // subi(x,x) -> 0 223 if (getOperand(0) == getOperand(1)) 224 return Builder(getContext()).getZeroAttr(getType()); 225 // subi(x,0) -> x 226 if (matchPattern(getRhs(), m_Zero())) 227 return getLhs(); 228 229 return constFoldBinaryOp<IntegerAttr>( 230 operands, [](APInt a, const APInt &b) { return std::move(a) - b; }); 231 } 232 233 void arith::SubIOp::getCanonicalizationPatterns( 234 OwningRewritePatternList &patterns, MLIRContext *context) { 235 patterns.insert<SubIRHSAddConstant, SubILHSAddConstant, SubIRHSSubConstantRHS, 236 SubIRHSSubConstantLHS, SubILHSSubConstantRHS, 237 SubILHSSubConstantLHS>(context); 238 } 239 240 //===----------------------------------------------------------------------===// 241 // MulIOp 242 //===----------------------------------------------------------------------===// 243 244 OpFoldResult arith::MulIOp::fold(ArrayRef<Attribute> operands) { 245 // muli(x, 0) -> 0 246 if (matchPattern(getRhs(), m_Zero())) 247 return getRhs(); 248 // muli(x, 1) -> x 249 if (matchPattern(getRhs(), m_One())) 250 return getOperand(0); 251 // TODO: Handle the overflow case. 252 253 // default folder 254 return constFoldBinaryOp<IntegerAttr>( 255 operands, [](const APInt &a, const APInt &b) { return a * b; }); 256 } 257 258 //===----------------------------------------------------------------------===// 259 // DivUIOp 260 //===----------------------------------------------------------------------===// 261 262 OpFoldResult arith::DivUIOp::fold(ArrayRef<Attribute> operands) { 263 // Don't fold if it would require a division by zero. 264 bool div0 = false; 265 auto result = 266 constFoldBinaryOp<IntegerAttr>(operands, [&](APInt a, const APInt &b) { 267 if (div0 || !b) { 268 div0 = true; 269 return a; 270 } 271 return a.udiv(b); 272 }); 273 274 // Fold out division by one. Assumes all tensors of all ones are splats. 275 if (auto rhs = operands[1].dyn_cast_or_null<IntegerAttr>()) { 276 if (rhs.getValue() == 1) 277 return getLhs(); 278 } else if (auto rhs = operands[1].dyn_cast_or_null<SplatElementsAttr>()) { 279 if (rhs.getSplatValue<IntegerAttr>().getValue() == 1) 280 return getLhs(); 281 } 282 283 return div0 ? Attribute() : result; 284 } 285 286 //===----------------------------------------------------------------------===// 287 // DivSIOp 288 //===----------------------------------------------------------------------===// 289 290 OpFoldResult arith::DivSIOp::fold(ArrayRef<Attribute> operands) { 291 // Don't fold if it would overflow or if it requires a division by zero. 292 bool overflowOrDiv0 = false; 293 auto result = 294 constFoldBinaryOp<IntegerAttr>(operands, [&](APInt a, const APInt &b) { 295 if (overflowOrDiv0 || !b) { 296 overflowOrDiv0 = true; 297 return a; 298 } 299 return a.sdiv_ov(b, overflowOrDiv0); 300 }); 301 302 // Fold out division by one. Assumes all tensors of all ones are splats. 303 if (auto rhs = operands[1].dyn_cast_or_null<IntegerAttr>()) { 304 if (rhs.getValue() == 1) 305 return getLhs(); 306 } else if (auto rhs = operands[1].dyn_cast_or_null<SplatElementsAttr>()) { 307 if (rhs.getSplatValue<IntegerAttr>().getValue() == 1) 308 return getLhs(); 309 } 310 311 return overflowOrDiv0 ? Attribute() : result; 312 } 313 314 //===----------------------------------------------------------------------===// 315 // Ceil and floor division folding helpers 316 //===----------------------------------------------------------------------===// 317 318 static APInt signedCeilNonnegInputs(const APInt &a, const APInt &b, 319 bool &overflow) { 320 // Returns (a-1)/b + 1 321 APInt one(a.getBitWidth(), 1, true); // Signed value 1. 322 APInt val = a.ssub_ov(one, overflow).sdiv_ov(b, overflow); 323 return val.sadd_ov(one, overflow); 324 } 325 326 //===----------------------------------------------------------------------===// 327 // CeilDivUIOp 328 //===----------------------------------------------------------------------===// 329 330 OpFoldResult arith::CeilDivUIOp::fold(ArrayRef<Attribute> operands) { 331 bool overflowOrDiv0 = false; 332 auto result = 333 constFoldBinaryOp<IntegerAttr>(operands, [&](APInt a, const APInt &b) { 334 if (overflowOrDiv0 || !b) { 335 overflowOrDiv0 = true; 336 return a; 337 } 338 APInt quotient = a.udiv(b); 339 if (!a.urem(b)) 340 return quotient; 341 APInt one(a.getBitWidth(), 1, true); 342 return quotient.uadd_ov(one, overflowOrDiv0); 343 }); 344 // Fold out ceil division by one. Assumes all tensors of all ones are 345 // splats. 346 if (auto rhs = operands[1].dyn_cast_or_null<IntegerAttr>()) { 347 if (rhs.getValue() == 1) 348 return getLhs(); 349 } else if (auto rhs = operands[1].dyn_cast_or_null<SplatElementsAttr>()) { 350 if (rhs.getSplatValue<IntegerAttr>().getValue() == 1) 351 return getLhs(); 352 } 353 354 return overflowOrDiv0 ? Attribute() : result; 355 } 356 357 //===----------------------------------------------------------------------===// 358 // CeilDivSIOp 359 //===----------------------------------------------------------------------===// 360 361 OpFoldResult arith::CeilDivSIOp::fold(ArrayRef<Attribute> operands) { 362 // Don't fold if it would overflow or if it requires a division by zero. 363 bool overflowOrDiv0 = false; 364 auto result = 365 constFoldBinaryOp<IntegerAttr>(operands, [&](APInt a, const APInt &b) { 366 if (overflowOrDiv0 || !b) { 367 overflowOrDiv0 = true; 368 return a; 369 } 370 if (!a) 371 return a; 372 // After this point we know that neither a or b are zero. 373 unsigned bits = a.getBitWidth(); 374 APInt zero = APInt::getZero(bits); 375 bool aGtZero = a.sgt(zero); 376 bool bGtZero = b.sgt(zero); 377 if (aGtZero && bGtZero) { 378 // Both positive, return ceil(a, b). 379 return signedCeilNonnegInputs(a, b, overflowOrDiv0); 380 } 381 if (!aGtZero && !bGtZero) { 382 // Both negative, return ceil(-a, -b). 383 APInt posA = zero.ssub_ov(a, overflowOrDiv0); 384 APInt posB = zero.ssub_ov(b, overflowOrDiv0); 385 return signedCeilNonnegInputs(posA, posB, overflowOrDiv0); 386 } 387 if (!aGtZero && bGtZero) { 388 // A is negative, b is positive, return - ( -a / b). 389 APInt posA = zero.ssub_ov(a, overflowOrDiv0); 390 APInt div = posA.sdiv_ov(b, overflowOrDiv0); 391 return zero.ssub_ov(div, overflowOrDiv0); 392 } 393 // A is positive, b is negative, return - (a / -b). 394 APInt posB = zero.ssub_ov(b, overflowOrDiv0); 395 APInt div = a.sdiv_ov(posB, overflowOrDiv0); 396 return zero.ssub_ov(div, overflowOrDiv0); 397 }); 398 399 // Fold out ceil division by one. Assumes all tensors of all ones are 400 // splats. 401 if (auto rhs = operands[1].dyn_cast_or_null<IntegerAttr>()) { 402 if (rhs.getValue() == 1) 403 return getLhs(); 404 } else if (auto rhs = operands[1].dyn_cast_or_null<SplatElementsAttr>()) { 405 if (rhs.getSplatValue<IntegerAttr>().getValue() == 1) 406 return getLhs(); 407 } 408 409 return overflowOrDiv0 ? Attribute() : result; 410 } 411 412 //===----------------------------------------------------------------------===// 413 // FloorDivSIOp 414 //===----------------------------------------------------------------------===// 415 416 OpFoldResult arith::FloorDivSIOp::fold(ArrayRef<Attribute> operands) { 417 // Don't fold if it would overflow or if it requires a division by zero. 418 bool overflowOrDiv0 = false; 419 auto result = 420 constFoldBinaryOp<IntegerAttr>(operands, [&](APInt a, const APInt &b) { 421 if (overflowOrDiv0 || !b) { 422 overflowOrDiv0 = true; 423 return a; 424 } 425 if (!a) 426 return a; 427 // After this point we know that neither a or b are zero. 428 unsigned bits = a.getBitWidth(); 429 APInt zero = APInt::getZero(bits); 430 bool aGtZero = a.sgt(zero); 431 bool bGtZero = b.sgt(zero); 432 if (aGtZero && bGtZero) { 433 // Both positive, return a / b. 434 return a.sdiv_ov(b, overflowOrDiv0); 435 } 436 if (!aGtZero && !bGtZero) { 437 // Both negative, return -a / -b. 438 APInt posA = zero.ssub_ov(a, overflowOrDiv0); 439 APInt posB = zero.ssub_ov(b, overflowOrDiv0); 440 return posA.sdiv_ov(posB, overflowOrDiv0); 441 } 442 if (!aGtZero && bGtZero) { 443 // A is negative, b is positive, return - ceil(-a, b). 444 APInt posA = zero.ssub_ov(a, overflowOrDiv0); 445 APInt ceil = signedCeilNonnegInputs(posA, b, overflowOrDiv0); 446 return zero.ssub_ov(ceil, overflowOrDiv0); 447 } 448 // A is positive, b is negative, return - ceil(a, -b). 449 APInt posB = zero.ssub_ov(b, overflowOrDiv0); 450 APInt ceil = signedCeilNonnegInputs(a, posB, overflowOrDiv0); 451 return zero.ssub_ov(ceil, overflowOrDiv0); 452 }); 453 454 // Fold out floor division by one. Assumes all tensors of all ones are 455 // splats. 456 if (auto rhs = operands[1].dyn_cast_or_null<IntegerAttr>()) { 457 if (rhs.getValue() == 1) 458 return getLhs(); 459 } else if (auto rhs = operands[1].dyn_cast_or_null<SplatElementsAttr>()) { 460 if (rhs.getSplatValue<IntegerAttr>().getValue() == 1) 461 return getLhs(); 462 } 463 464 return overflowOrDiv0 ? Attribute() : result; 465 } 466 467 //===----------------------------------------------------------------------===// 468 // RemUIOp 469 //===----------------------------------------------------------------------===// 470 471 OpFoldResult arith::RemUIOp::fold(ArrayRef<Attribute> operands) { 472 auto rhs = operands.back().dyn_cast_or_null<IntegerAttr>(); 473 if (!rhs) 474 return {}; 475 auto rhsValue = rhs.getValue(); 476 477 // x % 1 = 0 478 if (rhsValue.isOneValue()) 479 return IntegerAttr::get(rhs.getType(), APInt(rhsValue.getBitWidth(), 0)); 480 481 // Don't fold if it requires division by zero. 482 if (rhsValue.isNullValue()) 483 return {}; 484 485 auto lhs = operands.front().dyn_cast_or_null<IntegerAttr>(); 486 if (!lhs) 487 return {}; 488 return IntegerAttr::get(lhs.getType(), lhs.getValue().urem(rhsValue)); 489 } 490 491 //===----------------------------------------------------------------------===// 492 // RemSIOp 493 //===----------------------------------------------------------------------===// 494 495 OpFoldResult arith::RemSIOp::fold(ArrayRef<Attribute> operands) { 496 auto rhs = operands.back().dyn_cast_or_null<IntegerAttr>(); 497 if (!rhs) 498 return {}; 499 auto rhsValue = rhs.getValue(); 500 501 // x % 1 = 0 502 if (rhsValue.isOneValue()) 503 return IntegerAttr::get(rhs.getType(), APInt(rhsValue.getBitWidth(), 0)); 504 505 // Don't fold if it requires division by zero. 506 if (rhsValue.isNullValue()) 507 return {}; 508 509 auto lhs = operands.front().dyn_cast_or_null<IntegerAttr>(); 510 if (!lhs) 511 return {}; 512 return IntegerAttr::get(lhs.getType(), lhs.getValue().srem(rhsValue)); 513 } 514 515 //===----------------------------------------------------------------------===// 516 // AndIOp 517 //===----------------------------------------------------------------------===// 518 519 OpFoldResult arith::AndIOp::fold(ArrayRef<Attribute> operands) { 520 /// and(x, 0) -> 0 521 if (matchPattern(getRhs(), m_Zero())) 522 return getRhs(); 523 /// and(x, allOnes) -> x 524 APInt intValue; 525 if (matchPattern(getRhs(), m_ConstantInt(&intValue)) && intValue.isAllOnes()) 526 return getLhs(); 527 528 return constFoldBinaryOp<IntegerAttr>( 529 operands, [](APInt a, const APInt &b) { return std::move(a) & b; }); 530 } 531 532 //===----------------------------------------------------------------------===// 533 // OrIOp 534 //===----------------------------------------------------------------------===// 535 536 OpFoldResult arith::OrIOp::fold(ArrayRef<Attribute> operands) { 537 /// or(x, 0) -> x 538 if (matchPattern(getRhs(), m_Zero())) 539 return getLhs(); 540 /// or(x, <all ones>) -> <all ones> 541 if (auto rhsAttr = operands[1].dyn_cast_or_null<IntegerAttr>()) 542 if (rhsAttr.getValue().isAllOnes()) 543 return rhsAttr; 544 545 return constFoldBinaryOp<IntegerAttr>( 546 operands, [](APInt a, const APInt &b) { return std::move(a) | b; }); 547 } 548 549 //===----------------------------------------------------------------------===// 550 // XOrIOp 551 //===----------------------------------------------------------------------===// 552 553 OpFoldResult arith::XOrIOp::fold(ArrayRef<Attribute> operands) { 554 /// xor(x, 0) -> x 555 if (matchPattern(getRhs(), m_Zero())) 556 return getLhs(); 557 /// xor(x, x) -> 0 558 if (getLhs() == getRhs()) 559 return Builder(getContext()).getZeroAttr(getType()); 560 561 return constFoldBinaryOp<IntegerAttr>( 562 operands, [](APInt a, const APInt &b) { return std::move(a) ^ b; }); 563 } 564 565 void arith::XOrIOp::getCanonicalizationPatterns( 566 OwningRewritePatternList &patterns, MLIRContext *context) { 567 patterns.insert<XOrINotCmpI>(context); 568 } 569 570 //===----------------------------------------------------------------------===// 571 // AddFOp 572 //===----------------------------------------------------------------------===// 573 574 OpFoldResult arith::AddFOp::fold(ArrayRef<Attribute> operands) { 575 return constFoldBinaryOp<FloatAttr>( 576 operands, [](const APFloat &a, const APFloat &b) { return a + b; }); 577 } 578 579 //===----------------------------------------------------------------------===// 580 // SubFOp 581 //===----------------------------------------------------------------------===// 582 583 OpFoldResult arith::SubFOp::fold(ArrayRef<Attribute> operands) { 584 return constFoldBinaryOp<FloatAttr>( 585 operands, [](const APFloat &a, const APFloat &b) { return a - b; }); 586 } 587 588 //===----------------------------------------------------------------------===// 589 // MaxSIOp 590 //===----------------------------------------------------------------------===// 591 592 OpFoldResult MaxSIOp::fold(ArrayRef<Attribute> operands) { 593 assert(operands.size() == 2 && "binary operation takes two operands"); 594 595 // maxsi(x,x) -> x 596 if (getLhs() == getRhs()) 597 return getRhs(); 598 599 APInt intValue; 600 // maxsi(x,MAX_INT) -> MAX_INT 601 if (matchPattern(getRhs(), m_ConstantInt(&intValue)) && 602 intValue.isMaxSignedValue()) 603 return getRhs(); 604 605 // maxsi(x, MIN_INT) -> x 606 if (matchPattern(getRhs(), m_ConstantInt(&intValue)) && 607 intValue.isMinSignedValue()) 608 return getLhs(); 609 610 return constFoldBinaryOp<IntegerAttr>(operands, 611 [](const APInt &a, const APInt &b) { 612 return llvm::APIntOps::smax(a, b); 613 }); 614 } 615 616 //===----------------------------------------------------------------------===// 617 // MaxUIOp 618 //===----------------------------------------------------------------------===// 619 620 OpFoldResult MaxUIOp::fold(ArrayRef<Attribute> operands) { 621 assert(operands.size() == 2 && "binary operation takes two operands"); 622 623 // maxui(x,x) -> x 624 if (getLhs() == getRhs()) 625 return getRhs(); 626 627 APInt intValue; 628 // maxui(x,MAX_INT) -> MAX_INT 629 if (matchPattern(getRhs(), m_ConstantInt(&intValue)) && intValue.isMaxValue()) 630 return getRhs(); 631 632 // maxui(x, MIN_INT) -> x 633 if (matchPattern(getRhs(), m_ConstantInt(&intValue)) && intValue.isMinValue()) 634 return getLhs(); 635 636 return constFoldBinaryOp<IntegerAttr>(operands, 637 [](const APInt &a, const APInt &b) { 638 return llvm::APIntOps::umax(a, b); 639 }); 640 } 641 642 //===----------------------------------------------------------------------===// 643 // MinSIOp 644 //===----------------------------------------------------------------------===// 645 646 OpFoldResult MinSIOp::fold(ArrayRef<Attribute> operands) { 647 assert(operands.size() == 2 && "binary operation takes two operands"); 648 649 // minsi(x,x) -> x 650 if (getLhs() == getRhs()) 651 return getRhs(); 652 653 APInt intValue; 654 // minsi(x,MIN_INT) -> MIN_INT 655 if (matchPattern(getRhs(), m_ConstantInt(&intValue)) && 656 intValue.isMinSignedValue()) 657 return getRhs(); 658 659 // minsi(x, MAX_INT) -> x 660 if (matchPattern(getRhs(), m_ConstantInt(&intValue)) && 661 intValue.isMaxSignedValue()) 662 return getLhs(); 663 664 return constFoldBinaryOp<IntegerAttr>(operands, 665 [](const APInt &a, const APInt &b) { 666 return llvm::APIntOps::smin(a, b); 667 }); 668 } 669 670 //===----------------------------------------------------------------------===// 671 // MinUIOp 672 //===----------------------------------------------------------------------===// 673 674 OpFoldResult MinUIOp::fold(ArrayRef<Attribute> operands) { 675 assert(operands.size() == 2 && "binary operation takes two operands"); 676 677 // minui(x,x) -> x 678 if (getLhs() == getRhs()) 679 return getRhs(); 680 681 APInt intValue; 682 // minui(x,MIN_INT) -> MIN_INT 683 if (matchPattern(getRhs(), m_ConstantInt(&intValue)) && intValue.isMinValue()) 684 return getRhs(); 685 686 // minui(x, MAX_INT) -> x 687 if (matchPattern(getRhs(), m_ConstantInt(&intValue)) && intValue.isMaxValue()) 688 return getLhs(); 689 690 return constFoldBinaryOp<IntegerAttr>(operands, 691 [](const APInt &a, const APInt &b) { 692 return llvm::APIntOps::umin(a, b); 693 }); 694 } 695 696 //===----------------------------------------------------------------------===// 697 // MulFOp 698 //===----------------------------------------------------------------------===// 699 700 OpFoldResult arith::MulFOp::fold(ArrayRef<Attribute> operands) { 701 return constFoldBinaryOp<FloatAttr>( 702 operands, [](const APFloat &a, const APFloat &b) { return a * b; }); 703 } 704 705 //===----------------------------------------------------------------------===// 706 // DivFOp 707 //===----------------------------------------------------------------------===// 708 709 OpFoldResult arith::DivFOp::fold(ArrayRef<Attribute> operands) { 710 return constFoldBinaryOp<FloatAttr>( 711 operands, [](const APFloat &a, const APFloat &b) { return a / b; }); 712 } 713 714 //===----------------------------------------------------------------------===// 715 // Utility functions for verifying cast ops 716 //===----------------------------------------------------------------------===// 717 718 template <typename... Types> 719 using type_list = std::tuple<Types...> *; 720 721 /// Returns a non-null type only if the provided type is one of the allowed 722 /// types or one of the allowed shaped types of the allowed types. Returns the 723 /// element type if a valid shaped type is provided. 724 template <typename... ShapedTypes, typename... ElementTypes> 725 static Type getUnderlyingType(Type type, type_list<ShapedTypes...>, 726 type_list<ElementTypes...>) { 727 if (type.isa<ShapedType>() && !type.isa<ShapedTypes...>()) 728 return {}; 729 730 auto underlyingType = getElementTypeOrSelf(type); 731 if (!underlyingType.isa<ElementTypes...>()) 732 return {}; 733 734 return underlyingType; 735 } 736 737 /// Get allowed underlying types for vectors and tensors. 738 template <typename... ElementTypes> 739 static Type getTypeIfLike(Type type) { 740 return getUnderlyingType(type, type_list<VectorType, TensorType>(), 741 type_list<ElementTypes...>()); 742 } 743 744 /// Get allowed underlying types for vectors, tensors, and memrefs. 745 template <typename... ElementTypes> 746 static Type getTypeIfLikeOrMemRef(Type type) { 747 return getUnderlyingType(type, 748 type_list<VectorType, TensorType, MemRefType>(), 749 type_list<ElementTypes...>()); 750 } 751 752 static bool areValidCastInputsAndOutputs(TypeRange inputs, TypeRange outputs) { 753 return inputs.size() == 1 && outputs.size() == 1 && 754 succeeded(verifyCompatibleShapes(inputs.front(), outputs.front())); 755 } 756 757 //===----------------------------------------------------------------------===// 758 // Verifiers for integer and floating point extension/truncation ops 759 //===----------------------------------------------------------------------===// 760 761 // Extend ops can only extend to a wider type. 762 template <typename ValType, typename Op> 763 static LogicalResult verifyExtOp(Op op) { 764 Type srcType = getElementTypeOrSelf(op.getIn().getType()); 765 Type dstType = getElementTypeOrSelf(op.getType()); 766 767 if (srcType.cast<ValType>().getWidth() >= dstType.cast<ValType>().getWidth()) 768 return op.emitError("result type ") 769 << dstType << " must be wider than operand type " << srcType; 770 771 return success(); 772 } 773 774 // Truncate ops can only truncate to a shorter type. 775 template <typename ValType, typename Op> 776 static LogicalResult verifyTruncateOp(Op op) { 777 Type srcType = getElementTypeOrSelf(op.getIn().getType()); 778 Type dstType = getElementTypeOrSelf(op.getType()); 779 780 if (srcType.cast<ValType>().getWidth() <= dstType.cast<ValType>().getWidth()) 781 return op.emitError("result type ") 782 << dstType << " must be shorter than operand type " << srcType; 783 784 return success(); 785 } 786 787 /// Validate a cast that changes the width of a type. 788 template <template <typename> class WidthComparator, typename... ElementTypes> 789 static bool checkWidthChangeCast(TypeRange inputs, TypeRange outputs) { 790 if (!areValidCastInputsAndOutputs(inputs, outputs)) 791 return false; 792 793 auto srcType = getTypeIfLike<ElementTypes...>(inputs.front()); 794 auto dstType = getTypeIfLike<ElementTypes...>(outputs.front()); 795 if (!srcType || !dstType) 796 return false; 797 798 return WidthComparator<unsigned>()(dstType.getIntOrFloatBitWidth(), 799 srcType.getIntOrFloatBitWidth()); 800 } 801 802 //===----------------------------------------------------------------------===// 803 // ExtUIOp 804 //===----------------------------------------------------------------------===// 805 806 OpFoldResult arith::ExtUIOp::fold(ArrayRef<Attribute> operands) { 807 if (auto lhs = operands[0].dyn_cast_or_null<IntegerAttr>()) 808 return IntegerAttr::get( 809 getType(), lhs.getValue().zext(getType().getIntOrFloatBitWidth())); 810 811 if (auto lhs = getIn().getDefiningOp<ExtUIOp>()) { 812 getInMutable().assign(lhs.getIn()); 813 return getResult(); 814 } 815 816 return {}; 817 } 818 819 bool arith::ExtUIOp::areCastCompatible(TypeRange inputs, TypeRange outputs) { 820 return checkWidthChangeCast<std::greater, IntegerType>(inputs, outputs); 821 } 822 823 //===----------------------------------------------------------------------===// 824 // ExtSIOp 825 //===----------------------------------------------------------------------===// 826 827 OpFoldResult arith::ExtSIOp::fold(ArrayRef<Attribute> operands) { 828 if (auto lhs = operands[0].dyn_cast_or_null<IntegerAttr>()) 829 return IntegerAttr::get( 830 getType(), lhs.getValue().sext(getType().getIntOrFloatBitWidth())); 831 832 if (auto lhs = getIn().getDefiningOp<ExtSIOp>()) { 833 getInMutable().assign(lhs.getIn()); 834 return getResult(); 835 } 836 837 return {}; 838 } 839 840 bool arith::ExtSIOp::areCastCompatible(TypeRange inputs, TypeRange outputs) { 841 return checkWidthChangeCast<std::greater, IntegerType>(inputs, outputs); 842 } 843 844 void arith::ExtSIOp::getCanonicalizationPatterns( 845 OwningRewritePatternList &patterns, MLIRContext *context) { 846 patterns.insert<ExtSIOfExtUI>(context); 847 } 848 849 //===----------------------------------------------------------------------===// 850 // ExtFOp 851 //===----------------------------------------------------------------------===// 852 853 bool arith::ExtFOp::areCastCompatible(TypeRange inputs, TypeRange outputs) { 854 return checkWidthChangeCast<std::greater, FloatType>(inputs, outputs); 855 } 856 857 //===----------------------------------------------------------------------===// 858 // TruncIOp 859 //===----------------------------------------------------------------------===// 860 861 OpFoldResult arith::TruncIOp::fold(ArrayRef<Attribute> operands) { 862 // trunci(zexti(a)) -> a 863 // trunci(sexti(a)) -> a 864 if (matchPattern(getOperand(), m_Op<arith::ExtUIOp>()) || 865 matchPattern(getOperand(), m_Op<arith::ExtSIOp>())) 866 return getOperand().getDefiningOp()->getOperand(0); 867 868 assert(operands.size() == 1 && "unary operation takes one operand"); 869 870 if (!operands[0]) 871 return {}; 872 873 if (auto lhs = operands[0].dyn_cast<IntegerAttr>()) { 874 return IntegerAttr::get( 875 getType(), lhs.getValue().trunc(getType().getIntOrFloatBitWidth())); 876 } 877 878 return {}; 879 } 880 881 bool arith::TruncIOp::areCastCompatible(TypeRange inputs, TypeRange outputs) { 882 return checkWidthChangeCast<std::less, IntegerType>(inputs, outputs); 883 } 884 885 //===----------------------------------------------------------------------===// 886 // TruncFOp 887 //===----------------------------------------------------------------------===// 888 889 /// Perform safe const propagation for truncf, i.e. only propagate if FP value 890 /// can be represented without precision loss or rounding. 891 OpFoldResult arith::TruncFOp::fold(ArrayRef<Attribute> operands) { 892 assert(operands.size() == 1 && "unary operation takes one operand"); 893 894 auto constOperand = operands.front(); 895 if (!constOperand || !constOperand.isa<FloatAttr>()) 896 return {}; 897 898 // Convert to target type via 'double'. 899 double sourceValue = 900 constOperand.dyn_cast<FloatAttr>().getValue().convertToDouble(); 901 auto targetAttr = FloatAttr::get(getType(), sourceValue); 902 903 // Propagate if constant's value does not change after truncation. 904 if (sourceValue == targetAttr.getValue().convertToDouble()) 905 return targetAttr; 906 907 return {}; 908 } 909 910 bool arith::TruncFOp::areCastCompatible(TypeRange inputs, TypeRange outputs) { 911 return checkWidthChangeCast<std::less, FloatType>(inputs, outputs); 912 } 913 914 //===----------------------------------------------------------------------===// 915 // AndIOp 916 //===----------------------------------------------------------------------===// 917 918 void arith::AndIOp::getCanonicalizationPatterns( 919 OwningRewritePatternList &patterns, MLIRContext *context) { 920 patterns.insert<AndOfExtUI, AndOfExtSI>(context); 921 } 922 923 //===----------------------------------------------------------------------===// 924 // OrIOp 925 //===----------------------------------------------------------------------===// 926 927 void arith::OrIOp::getCanonicalizationPatterns( 928 OwningRewritePatternList &patterns, MLIRContext *context) { 929 patterns.insert<OrOfExtUI, OrOfExtSI>(context); 930 } 931 932 //===----------------------------------------------------------------------===// 933 // Verifiers for casts between integers and floats. 934 //===----------------------------------------------------------------------===// 935 936 template <typename From, typename To> 937 static bool checkIntFloatCast(TypeRange inputs, TypeRange outputs) { 938 if (!areValidCastInputsAndOutputs(inputs, outputs)) 939 return false; 940 941 auto srcType = getTypeIfLike<From>(inputs.front()); 942 auto dstType = getTypeIfLike<To>(outputs.back()); 943 944 return srcType && dstType; 945 } 946 947 //===----------------------------------------------------------------------===// 948 // UIToFPOp 949 //===----------------------------------------------------------------------===// 950 951 bool arith::UIToFPOp::areCastCompatible(TypeRange inputs, TypeRange outputs) { 952 return checkIntFloatCast<IntegerType, FloatType>(inputs, outputs); 953 } 954 955 OpFoldResult arith::UIToFPOp::fold(ArrayRef<Attribute> operands) { 956 if (auto lhs = operands[0].dyn_cast_or_null<IntegerAttr>()) { 957 const APInt &api = lhs.getValue(); 958 FloatType floatTy = getType().cast<FloatType>(); 959 APFloat apf(floatTy.getFloatSemantics(), 960 APInt::getZero(floatTy.getWidth())); 961 apf.convertFromAPInt(api, /*IsSigned=*/false, APFloat::rmNearestTiesToEven); 962 return FloatAttr::get(floatTy, apf); 963 } 964 return {}; 965 } 966 967 //===----------------------------------------------------------------------===// 968 // SIToFPOp 969 //===----------------------------------------------------------------------===// 970 971 bool arith::SIToFPOp::areCastCompatible(TypeRange inputs, TypeRange outputs) { 972 return checkIntFloatCast<IntegerType, FloatType>(inputs, outputs); 973 } 974 975 OpFoldResult arith::SIToFPOp::fold(ArrayRef<Attribute> operands) { 976 if (auto lhs = operands[0].dyn_cast_or_null<IntegerAttr>()) { 977 const APInt &api = lhs.getValue(); 978 FloatType floatTy = getType().cast<FloatType>(); 979 APFloat apf(floatTy.getFloatSemantics(), 980 APInt::getZero(floatTy.getWidth())); 981 apf.convertFromAPInt(api, /*IsSigned=*/true, APFloat::rmNearestTiesToEven); 982 return FloatAttr::get(floatTy, apf); 983 } 984 return {}; 985 } 986 //===----------------------------------------------------------------------===// 987 // FPToUIOp 988 //===----------------------------------------------------------------------===// 989 990 bool arith::FPToUIOp::areCastCompatible(TypeRange inputs, TypeRange outputs) { 991 return checkIntFloatCast<FloatType, IntegerType>(inputs, outputs); 992 } 993 994 OpFoldResult arith::FPToUIOp::fold(ArrayRef<Attribute> operands) { 995 if (auto lhs = operands[0].dyn_cast_or_null<FloatAttr>()) { 996 const APFloat &apf = lhs.getValue(); 997 IntegerType intTy = getType().cast<IntegerType>(); 998 bool ignored; 999 APSInt api(intTy.getWidth(), /*isUnsigned=*/true); 1000 if (APFloat::opInvalidOp == 1001 apf.convertToInteger(api, APFloat::rmTowardZero, &ignored)) { 1002 // Undefined behavior invoked - the destination type can't represent 1003 // the input constant. 1004 return {}; 1005 } 1006 return IntegerAttr::get(getType(), api); 1007 } 1008 1009 return {}; 1010 } 1011 1012 //===----------------------------------------------------------------------===// 1013 // FPToSIOp 1014 //===----------------------------------------------------------------------===// 1015 1016 bool arith::FPToSIOp::areCastCompatible(TypeRange inputs, TypeRange outputs) { 1017 return checkIntFloatCast<FloatType, IntegerType>(inputs, outputs); 1018 } 1019 1020 OpFoldResult arith::FPToSIOp::fold(ArrayRef<Attribute> operands) { 1021 if (auto lhs = operands[0].dyn_cast_or_null<FloatAttr>()) { 1022 const APFloat &apf = lhs.getValue(); 1023 IntegerType intTy = getType().cast<IntegerType>(); 1024 bool ignored; 1025 APSInt api(intTy.getWidth(), /*isUnsigned=*/false); 1026 if (APFloat::opInvalidOp == 1027 apf.convertToInteger(api, APFloat::rmTowardZero, &ignored)) { 1028 // Undefined behavior invoked - the destination type can't represent 1029 // the input constant. 1030 return {}; 1031 } 1032 return IntegerAttr::get(getType(), api); 1033 } 1034 1035 return {}; 1036 } 1037 1038 //===----------------------------------------------------------------------===// 1039 // IndexCastOp 1040 //===----------------------------------------------------------------------===// 1041 1042 bool arith::IndexCastOp::areCastCompatible(TypeRange inputs, 1043 TypeRange outputs) { 1044 if (!areValidCastInputsAndOutputs(inputs, outputs)) 1045 return false; 1046 1047 auto srcType = getTypeIfLikeOrMemRef<IntegerType, IndexType>(inputs.front()); 1048 auto dstType = getTypeIfLikeOrMemRef<IntegerType, IndexType>(outputs.front()); 1049 if (!srcType || !dstType) 1050 return false; 1051 1052 return (srcType.isIndex() && dstType.isSignlessInteger()) || 1053 (srcType.isSignlessInteger() && dstType.isIndex()); 1054 } 1055 1056 OpFoldResult arith::IndexCastOp::fold(ArrayRef<Attribute> operands) { 1057 // index_cast(constant) -> constant 1058 // A little hack because we go through int. Otherwise, the size of the 1059 // constant might need to change. 1060 if (auto value = operands[0].dyn_cast_or_null<IntegerAttr>()) 1061 return IntegerAttr::get(getType(), value.getInt()); 1062 1063 return {}; 1064 } 1065 1066 void arith::IndexCastOp::getCanonicalizationPatterns( 1067 OwningRewritePatternList &patterns, MLIRContext *context) { 1068 patterns.insert<IndexCastOfIndexCast, IndexCastOfExtSI>(context); 1069 } 1070 1071 //===----------------------------------------------------------------------===// 1072 // BitcastOp 1073 //===----------------------------------------------------------------------===// 1074 1075 bool arith::BitcastOp::areCastCompatible(TypeRange inputs, TypeRange outputs) { 1076 if (!areValidCastInputsAndOutputs(inputs, outputs)) 1077 return false; 1078 1079 auto srcType = 1080 getTypeIfLikeOrMemRef<IntegerType, IndexType, FloatType>(inputs.front()); 1081 auto dstType = 1082 getTypeIfLikeOrMemRef<IntegerType, IndexType, FloatType>(outputs.front()); 1083 if (!srcType || !dstType) 1084 return false; 1085 1086 return srcType.getIntOrFloatBitWidth() == dstType.getIntOrFloatBitWidth(); 1087 } 1088 1089 OpFoldResult arith::BitcastOp::fold(ArrayRef<Attribute> operands) { 1090 assert(operands.size() == 1 && "bitcast op expects 1 operand"); 1091 1092 auto resType = getType(); 1093 auto operand = operands[0]; 1094 if (!operand) 1095 return {}; 1096 1097 /// Bitcast dense elements. 1098 if (auto denseAttr = operand.dyn_cast_or_null<DenseElementsAttr>()) 1099 return denseAttr.bitcast(resType.cast<ShapedType>().getElementType()); 1100 /// Other shaped types unhandled. 1101 if (resType.isa<ShapedType>()) 1102 return {}; 1103 1104 /// Bitcast integer or float to integer or float. 1105 APInt bits = operand.isa<FloatAttr>() 1106 ? operand.cast<FloatAttr>().getValue().bitcastToAPInt() 1107 : operand.cast<IntegerAttr>().getValue(); 1108 1109 if (auto resFloatType = resType.dyn_cast<FloatType>()) 1110 return FloatAttr::get(resType, 1111 APFloat(resFloatType.getFloatSemantics(), bits)); 1112 return IntegerAttr::get(resType, bits); 1113 } 1114 1115 void arith::BitcastOp::getCanonicalizationPatterns( 1116 OwningRewritePatternList &patterns, MLIRContext *context) { 1117 patterns.insert<BitcastOfBitcast>(context); 1118 } 1119 1120 //===----------------------------------------------------------------------===// 1121 // Helpers for compare ops 1122 //===----------------------------------------------------------------------===// 1123 1124 /// Return the type of the same shape (scalar, vector or tensor) containing i1. 1125 static Type getI1SameShape(Type type) { 1126 auto i1Type = IntegerType::get(type.getContext(), 1); 1127 if (auto tensorType = type.dyn_cast<RankedTensorType>()) 1128 return RankedTensorType::get(tensorType.getShape(), i1Type); 1129 if (type.isa<UnrankedTensorType>()) 1130 return UnrankedTensorType::get(i1Type); 1131 if (auto vectorType = type.dyn_cast<VectorType>()) 1132 return VectorType::get(vectorType.getShape(), i1Type, 1133 vectorType.getNumScalableDims()); 1134 return i1Type; 1135 } 1136 1137 //===----------------------------------------------------------------------===// 1138 // CmpIOp 1139 //===----------------------------------------------------------------------===// 1140 1141 /// Compute `lhs` `pred` `rhs`, where `pred` is one of the known integer 1142 /// comparison predicates. 1143 bool mlir::arith::applyCmpPredicate(arith::CmpIPredicate predicate, 1144 const APInt &lhs, const APInt &rhs) { 1145 switch (predicate) { 1146 case arith::CmpIPredicate::eq: 1147 return lhs.eq(rhs); 1148 case arith::CmpIPredicate::ne: 1149 return lhs.ne(rhs); 1150 case arith::CmpIPredicate::slt: 1151 return lhs.slt(rhs); 1152 case arith::CmpIPredicate::sle: 1153 return lhs.sle(rhs); 1154 case arith::CmpIPredicate::sgt: 1155 return lhs.sgt(rhs); 1156 case arith::CmpIPredicate::sge: 1157 return lhs.sge(rhs); 1158 case arith::CmpIPredicate::ult: 1159 return lhs.ult(rhs); 1160 case arith::CmpIPredicate::ule: 1161 return lhs.ule(rhs); 1162 case arith::CmpIPredicate::ugt: 1163 return lhs.ugt(rhs); 1164 case arith::CmpIPredicate::uge: 1165 return lhs.uge(rhs); 1166 } 1167 llvm_unreachable("unknown cmpi predicate kind"); 1168 } 1169 1170 /// Returns true if the predicate is true for two equal operands. 1171 static bool applyCmpPredicateToEqualOperands(arith::CmpIPredicate predicate) { 1172 switch (predicate) { 1173 case arith::CmpIPredicate::eq: 1174 case arith::CmpIPredicate::sle: 1175 case arith::CmpIPredicate::sge: 1176 case arith::CmpIPredicate::ule: 1177 case arith::CmpIPredicate::uge: 1178 return true; 1179 case arith::CmpIPredicate::ne: 1180 case arith::CmpIPredicate::slt: 1181 case arith::CmpIPredicate::sgt: 1182 case arith::CmpIPredicate::ult: 1183 case arith::CmpIPredicate::ugt: 1184 return false; 1185 } 1186 llvm_unreachable("unknown cmpi predicate kind"); 1187 } 1188 1189 static Attribute getBoolAttribute(Type type, MLIRContext *ctx, bool value) { 1190 auto boolAttr = BoolAttr::get(ctx, value); 1191 ShapedType shapedType = type.dyn_cast_or_null<ShapedType>(); 1192 if (!shapedType) 1193 return boolAttr; 1194 return DenseElementsAttr::get(shapedType, boolAttr); 1195 } 1196 1197 OpFoldResult arith::CmpIOp::fold(ArrayRef<Attribute> operands) { 1198 assert(operands.size() == 2 && "cmpi takes two operands"); 1199 1200 // cmpi(pred, x, x) 1201 if (getLhs() == getRhs()) { 1202 auto val = applyCmpPredicateToEqualOperands(getPredicate()); 1203 return getBoolAttribute(getType(), getContext(), val); 1204 } 1205 1206 if (matchPattern(getRhs(), m_Zero())) { 1207 if (auto extOp = getLhs().getDefiningOp<ExtSIOp>()) { 1208 if (extOp.getOperand().getType().cast<IntegerType>().getWidth() == 1) { 1209 // extsi(%x : i1 -> iN) != 0 -> %x 1210 if (getPredicate() == arith::CmpIPredicate::ne) { 1211 return extOp.getOperand(); 1212 } 1213 } 1214 } 1215 if (auto extOp = getLhs().getDefiningOp<ExtUIOp>()) { 1216 if (extOp.getOperand().getType().cast<IntegerType>().getWidth() == 1) { 1217 // extui(%x : i1 -> iN) != 0 -> %x 1218 if (getPredicate() == arith::CmpIPredicate::ne) { 1219 return extOp.getOperand(); 1220 } 1221 } 1222 } 1223 } 1224 1225 auto lhs = operands.front().dyn_cast_or_null<IntegerAttr>(); 1226 auto rhs = operands.back().dyn_cast_or_null<IntegerAttr>(); 1227 if (!lhs || !rhs) 1228 return {}; 1229 1230 auto val = applyCmpPredicate(getPredicate(), lhs.getValue(), rhs.getValue()); 1231 return BoolAttr::get(getContext(), val); 1232 } 1233 1234 //===----------------------------------------------------------------------===// 1235 // CmpFOp 1236 //===----------------------------------------------------------------------===// 1237 1238 /// Compute `lhs` `pred` `rhs`, where `pred` is one of the known floating point 1239 /// comparison predicates. 1240 bool mlir::arith::applyCmpPredicate(arith::CmpFPredicate predicate, 1241 const APFloat &lhs, const APFloat &rhs) { 1242 auto cmpResult = lhs.compare(rhs); 1243 switch (predicate) { 1244 case arith::CmpFPredicate::AlwaysFalse: 1245 return false; 1246 case arith::CmpFPredicate::OEQ: 1247 return cmpResult == APFloat::cmpEqual; 1248 case arith::CmpFPredicate::OGT: 1249 return cmpResult == APFloat::cmpGreaterThan; 1250 case arith::CmpFPredicate::OGE: 1251 return cmpResult == APFloat::cmpGreaterThan || 1252 cmpResult == APFloat::cmpEqual; 1253 case arith::CmpFPredicate::OLT: 1254 return cmpResult == APFloat::cmpLessThan; 1255 case arith::CmpFPredicate::OLE: 1256 return cmpResult == APFloat::cmpLessThan || cmpResult == APFloat::cmpEqual; 1257 case arith::CmpFPredicate::ONE: 1258 return cmpResult != APFloat::cmpUnordered && cmpResult != APFloat::cmpEqual; 1259 case arith::CmpFPredicate::ORD: 1260 return cmpResult != APFloat::cmpUnordered; 1261 case arith::CmpFPredicate::UEQ: 1262 return cmpResult == APFloat::cmpUnordered || cmpResult == APFloat::cmpEqual; 1263 case arith::CmpFPredicate::UGT: 1264 return cmpResult == APFloat::cmpUnordered || 1265 cmpResult == APFloat::cmpGreaterThan; 1266 case arith::CmpFPredicate::UGE: 1267 return cmpResult == APFloat::cmpUnordered || 1268 cmpResult == APFloat::cmpGreaterThan || 1269 cmpResult == APFloat::cmpEqual; 1270 case arith::CmpFPredicate::ULT: 1271 return cmpResult == APFloat::cmpUnordered || 1272 cmpResult == APFloat::cmpLessThan; 1273 case arith::CmpFPredicate::ULE: 1274 return cmpResult == APFloat::cmpUnordered || 1275 cmpResult == APFloat::cmpLessThan || cmpResult == APFloat::cmpEqual; 1276 case arith::CmpFPredicate::UNE: 1277 return cmpResult != APFloat::cmpEqual; 1278 case arith::CmpFPredicate::UNO: 1279 return cmpResult == APFloat::cmpUnordered; 1280 case arith::CmpFPredicate::AlwaysTrue: 1281 return true; 1282 } 1283 llvm_unreachable("unknown cmpf predicate kind"); 1284 } 1285 1286 OpFoldResult arith::CmpFOp::fold(ArrayRef<Attribute> operands) { 1287 assert(operands.size() == 2 && "cmpf takes two operands"); 1288 1289 auto lhs = operands.front().dyn_cast_or_null<FloatAttr>(); 1290 auto rhs = operands.back().dyn_cast_or_null<FloatAttr>(); 1291 1292 if (!lhs || !rhs) 1293 return {}; 1294 1295 auto val = applyCmpPredicate(getPredicate(), lhs.getValue(), rhs.getValue()); 1296 return BoolAttr::get(getContext(), val); 1297 } 1298 1299 //===----------------------------------------------------------------------===// 1300 // Atomic Enum 1301 //===----------------------------------------------------------------------===// 1302 1303 /// Returns the identity value attribute associated with an AtomicRMWKind op. 1304 Attribute mlir::arith::getIdentityValueAttr(AtomicRMWKind kind, Type resultType, 1305 OpBuilder &builder, Location loc) { 1306 switch (kind) { 1307 case AtomicRMWKind::maxf: 1308 return builder.getFloatAttr( 1309 resultType, 1310 APFloat::getInf(resultType.cast<FloatType>().getFloatSemantics(), 1311 /*Negative=*/true)); 1312 case AtomicRMWKind::addf: 1313 case AtomicRMWKind::addi: 1314 case AtomicRMWKind::maxu: 1315 case AtomicRMWKind::ori: 1316 return builder.getZeroAttr(resultType); 1317 case AtomicRMWKind::andi: 1318 return builder.getIntegerAttr( 1319 resultType, 1320 APInt::getAllOnes(resultType.cast<IntegerType>().getWidth())); 1321 case AtomicRMWKind::maxs: 1322 return builder.getIntegerAttr( 1323 resultType, 1324 APInt::getSignedMinValue(resultType.cast<IntegerType>().getWidth())); 1325 case AtomicRMWKind::minf: 1326 return builder.getFloatAttr( 1327 resultType, 1328 APFloat::getInf(resultType.cast<FloatType>().getFloatSemantics(), 1329 /*Negative=*/false)); 1330 case AtomicRMWKind::mins: 1331 return builder.getIntegerAttr( 1332 resultType, 1333 APInt::getSignedMaxValue(resultType.cast<IntegerType>().getWidth())); 1334 case AtomicRMWKind::minu: 1335 return builder.getIntegerAttr( 1336 resultType, 1337 APInt::getMaxValue(resultType.cast<IntegerType>().getWidth())); 1338 case AtomicRMWKind::muli: 1339 return builder.getIntegerAttr(resultType, 1); 1340 case AtomicRMWKind::mulf: 1341 return builder.getFloatAttr(resultType, 1); 1342 // TODO: Add remaining reduction operations. 1343 default: 1344 (void)emitOptionalError(loc, "Reduction operation type not supported"); 1345 break; 1346 } 1347 return nullptr; 1348 } 1349 1350 /// Returns the identity value associated with an AtomicRMWKind op. 1351 Value mlir::arith::getIdentityValue(AtomicRMWKind op, Type resultType, 1352 OpBuilder &builder, Location loc) { 1353 Attribute attr = getIdentityValueAttr(op, resultType, builder, loc); 1354 return builder.create<arith::ConstantOp>(loc, attr); 1355 } 1356 1357 /// Return the value obtained by applying the reduction operation kind 1358 /// associated with a binary AtomicRMWKind op to `lhs` and `rhs`. 1359 Value mlir::arith::getReductionOp(AtomicRMWKind op, OpBuilder &builder, 1360 Location loc, Value lhs, Value rhs) { 1361 switch (op) { 1362 case AtomicRMWKind::addf: 1363 return builder.create<arith::AddFOp>(loc, lhs, rhs); 1364 case AtomicRMWKind::addi: 1365 return builder.create<arith::AddIOp>(loc, lhs, rhs); 1366 case AtomicRMWKind::mulf: 1367 return builder.create<arith::MulFOp>(loc, lhs, rhs); 1368 case AtomicRMWKind::muli: 1369 return builder.create<arith::MulIOp>(loc, lhs, rhs); 1370 case AtomicRMWKind::maxf: 1371 return builder.create<arith::MaxFOp>(loc, lhs, rhs); 1372 case AtomicRMWKind::minf: 1373 return builder.create<arith::MinFOp>(loc, lhs, rhs); 1374 case AtomicRMWKind::maxs: 1375 return builder.create<arith::MaxSIOp>(loc, lhs, rhs); 1376 case AtomicRMWKind::mins: 1377 return builder.create<arith::MinSIOp>(loc, lhs, rhs); 1378 case AtomicRMWKind::maxu: 1379 return builder.create<arith::MaxUIOp>(loc, lhs, rhs); 1380 case AtomicRMWKind::minu: 1381 return builder.create<arith::MinUIOp>(loc, lhs, rhs); 1382 case AtomicRMWKind::ori: 1383 return builder.create<arith::OrIOp>(loc, lhs, rhs); 1384 case AtomicRMWKind::andi: 1385 return builder.create<arith::AndIOp>(loc, lhs, rhs); 1386 // TODO: Add remaining reduction operations. 1387 default: 1388 (void)emitOptionalError(loc, "Reduction operation type not supported"); 1389 break; 1390 } 1391 return nullptr; 1392 } 1393 1394 //===----------------------------------------------------------------------===// 1395 // TableGen'd op method definitions 1396 //===----------------------------------------------------------------------===// 1397 1398 #define GET_OP_CLASSES 1399 #include "mlir/Dialect/Arithmetic/IR/ArithmeticOps.cpp.inc" 1400 1401 //===----------------------------------------------------------------------===// 1402 // TableGen'd enum attribute definitions 1403 //===----------------------------------------------------------------------===// 1404 1405 #include "mlir/Dialect/Arithmetic/IR/ArithmeticOpsEnums.cpp.inc" 1406