1 //===- ArithmeticOps.cpp - MLIR Arithmetic dialect ops implementation -----===// 2 // 3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. 4 // See https://llvm.org/LICENSE.txt for license information. 5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception 6 // 7 //===----------------------------------------------------------------------===// 8 9 #include <utility> 10 11 #include "mlir/Dialect/Arithmetic/IR/Arithmetic.h" 12 #include "mlir/Dialect/CommonFolders.h" 13 #include "mlir/IR/Builders.h" 14 #include "mlir/IR/Matchers.h" 15 #include "mlir/IR/OpImplementation.h" 16 #include "mlir/IR/PatternMatch.h" 17 #include "mlir/IR/TypeUtilities.h" 18 19 #include "llvm/ADT/APSInt.h" 20 21 using namespace mlir; 22 using namespace mlir::arith; 23 24 //===----------------------------------------------------------------------===// 25 // Pattern helpers 26 //===----------------------------------------------------------------------===// 27 28 static IntegerAttr addIntegerAttrs(PatternRewriter &builder, Value res, 29 Attribute lhs, Attribute rhs) { 30 return builder.getIntegerAttr(res.getType(), 31 lhs.cast<IntegerAttr>().getInt() + 32 rhs.cast<IntegerAttr>().getInt()); 33 } 34 35 static IntegerAttr subIntegerAttrs(PatternRewriter &builder, Value res, 36 Attribute lhs, Attribute rhs) { 37 return builder.getIntegerAttr(res.getType(), 38 lhs.cast<IntegerAttr>().getInt() - 39 rhs.cast<IntegerAttr>().getInt()); 40 } 41 42 /// Invert an integer comparison predicate. 43 arith::CmpIPredicate arith::invertPredicate(arith::CmpIPredicate pred) { 44 switch (pred) { 45 case arith::CmpIPredicate::eq: 46 return arith::CmpIPredicate::ne; 47 case arith::CmpIPredicate::ne: 48 return arith::CmpIPredicate::eq; 49 case arith::CmpIPredicate::slt: 50 return arith::CmpIPredicate::sge; 51 case arith::CmpIPredicate::sle: 52 return arith::CmpIPredicate::sgt; 53 case arith::CmpIPredicate::sgt: 54 return arith::CmpIPredicate::sle; 55 case arith::CmpIPredicate::sge: 56 return arith::CmpIPredicate::slt; 57 case arith::CmpIPredicate::ult: 58 return arith::CmpIPredicate::uge; 59 case arith::CmpIPredicate::ule: 60 return arith::CmpIPredicate::ugt; 61 case arith::CmpIPredicate::ugt: 62 return arith::CmpIPredicate::ule; 63 case arith::CmpIPredicate::uge: 64 return arith::CmpIPredicate::ult; 65 } 66 llvm_unreachable("unknown cmpi predicate kind"); 67 } 68 69 static arith::CmpIPredicateAttr invertPredicate(arith::CmpIPredicateAttr pred) { 70 return arith::CmpIPredicateAttr::get(pred.getContext(), 71 invertPredicate(pred.getValue())); 72 } 73 74 //===----------------------------------------------------------------------===// 75 // TableGen'd canonicalization patterns 76 //===----------------------------------------------------------------------===// 77 78 namespace { 79 #include "ArithmeticCanonicalization.inc" 80 } // namespace 81 82 //===----------------------------------------------------------------------===// 83 // ConstantOp 84 //===----------------------------------------------------------------------===// 85 86 void arith::ConstantOp::getAsmResultNames( 87 function_ref<void(Value, StringRef)> setNameFn) { 88 auto type = getType(); 89 if (auto intCst = getValue().dyn_cast<IntegerAttr>()) { 90 auto intType = type.dyn_cast<IntegerType>(); 91 92 // Sugar i1 constants with 'true' and 'false'. 93 if (intType && intType.getWidth() == 1) 94 return setNameFn(getResult(), (intCst.getInt() ? "true" : "false")); 95 96 // Otherwise, build a compex name with the value and type. 97 SmallString<32> specialNameBuffer; 98 llvm::raw_svector_ostream specialName(specialNameBuffer); 99 specialName << 'c' << intCst.getInt(); 100 if (intType) 101 specialName << '_' << type; 102 setNameFn(getResult(), specialName.str()); 103 } else { 104 setNameFn(getResult(), "cst"); 105 } 106 } 107 108 /// TODO: disallow arith.constant to return anything other than signless integer 109 /// or float like. 110 static LogicalResult verify(arith::ConstantOp op) { 111 auto type = op.getType(); 112 // The value's type must match the return type. 113 if (op.getValue().getType() != type) { 114 return op.emitOpError() << "value type " << op.getValue().getType() 115 << " must match return type: " << type; 116 } 117 // Integer values must be signless. 118 if (type.isa<IntegerType>() && !type.cast<IntegerType>().isSignless()) 119 return op.emitOpError("integer return type must be signless"); 120 // Any float or elements attribute are acceptable. 121 if (!op.getValue().isa<IntegerAttr, FloatAttr, ElementsAttr>()) { 122 return op.emitOpError( 123 "value must be an integer, float, or elements attribute"); 124 } 125 return success(); 126 } 127 128 bool arith::ConstantOp::isBuildableWith(Attribute value, Type type) { 129 // The value's type must be the same as the provided type. 130 if (value.getType() != type) 131 return false; 132 // Integer values must be signless. 133 if (type.isa<IntegerType>() && !type.cast<IntegerType>().isSignless()) 134 return false; 135 // Integer, float, and element attributes are buildable. 136 return value.isa<IntegerAttr, FloatAttr, ElementsAttr>(); 137 } 138 139 OpFoldResult arith::ConstantOp::fold(ArrayRef<Attribute> operands) { 140 return getValue(); 141 } 142 143 void arith::ConstantIntOp::build(OpBuilder &builder, OperationState &result, 144 int64_t value, unsigned width) { 145 auto type = builder.getIntegerType(width); 146 arith::ConstantOp::build(builder, result, type, 147 builder.getIntegerAttr(type, value)); 148 } 149 150 void arith::ConstantIntOp::build(OpBuilder &builder, OperationState &result, 151 int64_t value, Type type) { 152 assert(type.isSignlessInteger() && 153 "ConstantIntOp can only have signless integer type values"); 154 arith::ConstantOp::build(builder, result, type, 155 builder.getIntegerAttr(type, value)); 156 } 157 158 bool arith::ConstantIntOp::classof(Operation *op) { 159 if (auto constOp = dyn_cast_or_null<arith::ConstantOp>(op)) 160 return constOp.getType().isSignlessInteger(); 161 return false; 162 } 163 164 void arith::ConstantFloatOp::build(OpBuilder &builder, OperationState &result, 165 const APFloat &value, FloatType type) { 166 arith::ConstantOp::build(builder, result, type, 167 builder.getFloatAttr(type, value)); 168 } 169 170 bool arith::ConstantFloatOp::classof(Operation *op) { 171 if (auto constOp = dyn_cast_or_null<arith::ConstantOp>(op)) 172 return constOp.getType().isa<FloatType>(); 173 return false; 174 } 175 176 void arith::ConstantIndexOp::build(OpBuilder &builder, OperationState &result, 177 int64_t value) { 178 arith::ConstantOp::build(builder, result, builder.getIndexType(), 179 builder.getIndexAttr(value)); 180 } 181 182 bool arith::ConstantIndexOp::classof(Operation *op) { 183 if (auto constOp = dyn_cast_or_null<arith::ConstantOp>(op)) 184 return constOp.getType().isIndex(); 185 return false; 186 } 187 188 //===----------------------------------------------------------------------===// 189 // AddIOp 190 //===----------------------------------------------------------------------===// 191 192 OpFoldResult arith::AddIOp::fold(ArrayRef<Attribute> operands) { 193 // addi(x, 0) -> x 194 if (matchPattern(getRhs(), m_Zero())) 195 return getLhs(); 196 197 // addi(subi(a, b), b) -> a 198 if (auto sub = getLhs().getDefiningOp<SubIOp>()) 199 if (getRhs() == sub.getRhs()) 200 return sub.getLhs(); 201 202 // addi(b, subi(a, b)) -> a 203 if (auto sub = getRhs().getDefiningOp<SubIOp>()) 204 if (getLhs() == sub.getRhs()) 205 return sub.getLhs(); 206 207 return constFoldBinaryOp<IntegerAttr>( 208 operands, [](APInt a, const APInt &b) { return std::move(a) + b; }); 209 } 210 211 void arith::AddIOp::getCanonicalizationPatterns( 212 RewritePatternSet &patterns, MLIRContext *context) { 213 patterns.insert<AddIAddConstant, AddISubConstantRHS, AddISubConstantLHS>( 214 context); 215 } 216 217 //===----------------------------------------------------------------------===// 218 // SubIOp 219 //===----------------------------------------------------------------------===// 220 221 OpFoldResult arith::SubIOp::fold(ArrayRef<Attribute> operands) { 222 // subi(x,x) -> 0 223 if (getOperand(0) == getOperand(1)) 224 return Builder(getContext()).getZeroAttr(getType()); 225 // subi(x,0) -> x 226 if (matchPattern(getRhs(), m_Zero())) 227 return getLhs(); 228 229 return constFoldBinaryOp<IntegerAttr>( 230 operands, [](APInt a, const APInt &b) { return std::move(a) - b; }); 231 } 232 233 void arith::SubIOp::getCanonicalizationPatterns( 234 RewritePatternSet &patterns, MLIRContext *context) { 235 patterns.insert<SubIRHSAddConstant, SubILHSAddConstant, SubIRHSSubConstantRHS, 236 SubIRHSSubConstantLHS, SubILHSSubConstantRHS, 237 SubILHSSubConstantLHS>(context); 238 } 239 240 //===----------------------------------------------------------------------===// 241 // MulIOp 242 //===----------------------------------------------------------------------===// 243 244 OpFoldResult arith::MulIOp::fold(ArrayRef<Attribute> operands) { 245 // muli(x, 0) -> 0 246 if (matchPattern(getRhs(), m_Zero())) 247 return getRhs(); 248 // muli(x, 1) -> x 249 if (matchPattern(getRhs(), m_One())) 250 return getOperand(0); 251 // TODO: Handle the overflow case. 252 253 // default folder 254 return constFoldBinaryOp<IntegerAttr>( 255 operands, [](const APInt &a, const APInt &b) { return a * b; }); 256 } 257 258 //===----------------------------------------------------------------------===// 259 // DivUIOp 260 //===----------------------------------------------------------------------===// 261 262 OpFoldResult arith::DivUIOp::fold(ArrayRef<Attribute> operands) { 263 // Don't fold if it would require a division by zero. 264 bool div0 = false; 265 auto result = 266 constFoldBinaryOp<IntegerAttr>(operands, [&](APInt a, const APInt &b) { 267 if (div0 || !b) { 268 div0 = true; 269 return a; 270 } 271 return a.udiv(b); 272 }); 273 274 // Fold out division by one. Assumes all tensors of all ones are splats. 275 if (auto rhs = operands[1].dyn_cast_or_null<IntegerAttr>()) { 276 if (rhs.getValue() == 1) 277 return getLhs(); 278 } else if (auto rhs = operands[1].dyn_cast_or_null<SplatElementsAttr>()) { 279 if (rhs.getSplatValue<IntegerAttr>().getValue() == 1) 280 return getLhs(); 281 } 282 283 return div0 ? Attribute() : result; 284 } 285 286 //===----------------------------------------------------------------------===// 287 // DivSIOp 288 //===----------------------------------------------------------------------===// 289 290 OpFoldResult arith::DivSIOp::fold(ArrayRef<Attribute> operands) { 291 // Don't fold if it would overflow or if it requires a division by zero. 292 bool overflowOrDiv0 = false; 293 auto result = 294 constFoldBinaryOp<IntegerAttr>(operands, [&](APInt a, const APInt &b) { 295 if (overflowOrDiv0 || !b) { 296 overflowOrDiv0 = true; 297 return a; 298 } 299 return a.sdiv_ov(b, overflowOrDiv0); 300 }); 301 302 // Fold out division by one. Assumes all tensors of all ones are splats. 303 if (auto rhs = operands[1].dyn_cast_or_null<IntegerAttr>()) { 304 if (rhs.getValue() == 1) 305 return getLhs(); 306 } else if (auto rhs = operands[1].dyn_cast_or_null<SplatElementsAttr>()) { 307 if (rhs.getSplatValue<IntegerAttr>().getValue() == 1) 308 return getLhs(); 309 } 310 311 return overflowOrDiv0 ? Attribute() : result; 312 } 313 314 //===----------------------------------------------------------------------===// 315 // Ceil and floor division folding helpers 316 //===----------------------------------------------------------------------===// 317 318 static APInt signedCeilNonnegInputs(const APInt &a, const APInt &b, 319 bool &overflow) { 320 // Returns (a-1)/b + 1 321 APInt one(a.getBitWidth(), 1, true); // Signed value 1. 322 APInt val = a.ssub_ov(one, overflow).sdiv_ov(b, overflow); 323 return val.sadd_ov(one, overflow); 324 } 325 326 //===----------------------------------------------------------------------===// 327 // CeilDivUIOp 328 //===----------------------------------------------------------------------===// 329 330 OpFoldResult arith::CeilDivUIOp::fold(ArrayRef<Attribute> operands) { 331 bool overflowOrDiv0 = false; 332 auto result = 333 constFoldBinaryOp<IntegerAttr>(operands, [&](APInt a, const APInt &b) { 334 if (overflowOrDiv0 || !b) { 335 overflowOrDiv0 = true; 336 return a; 337 } 338 APInt quotient = a.udiv(b); 339 if (!a.urem(b)) 340 return quotient; 341 APInt one(a.getBitWidth(), 1, true); 342 return quotient.uadd_ov(one, overflowOrDiv0); 343 }); 344 // Fold out ceil division by one. Assumes all tensors of all ones are 345 // splats. 346 if (auto rhs = operands[1].dyn_cast_or_null<IntegerAttr>()) { 347 if (rhs.getValue() == 1) 348 return getLhs(); 349 } else if (auto rhs = operands[1].dyn_cast_or_null<SplatElementsAttr>()) { 350 if (rhs.getSplatValue<IntegerAttr>().getValue() == 1) 351 return getLhs(); 352 } 353 354 return overflowOrDiv0 ? Attribute() : result; 355 } 356 357 //===----------------------------------------------------------------------===// 358 // CeilDivSIOp 359 //===----------------------------------------------------------------------===// 360 361 OpFoldResult arith::CeilDivSIOp::fold(ArrayRef<Attribute> operands) { 362 // Don't fold if it would overflow or if it requires a division by zero. 363 bool overflowOrDiv0 = false; 364 auto result = 365 constFoldBinaryOp<IntegerAttr>(operands, [&](APInt a, const APInt &b) { 366 if (overflowOrDiv0 || !b) { 367 overflowOrDiv0 = true; 368 return a; 369 } 370 if (!a) 371 return a; 372 // After this point we know that neither a or b are zero. 373 unsigned bits = a.getBitWidth(); 374 APInt zero = APInt::getZero(bits); 375 bool aGtZero = a.sgt(zero); 376 bool bGtZero = b.sgt(zero); 377 if (aGtZero && bGtZero) { 378 // Both positive, return ceil(a, b). 379 return signedCeilNonnegInputs(a, b, overflowOrDiv0); 380 } 381 if (!aGtZero && !bGtZero) { 382 // Both negative, return ceil(-a, -b). 383 APInt posA = zero.ssub_ov(a, overflowOrDiv0); 384 APInt posB = zero.ssub_ov(b, overflowOrDiv0); 385 return signedCeilNonnegInputs(posA, posB, overflowOrDiv0); 386 } 387 if (!aGtZero && bGtZero) { 388 // A is negative, b is positive, return - ( -a / b). 389 APInt posA = zero.ssub_ov(a, overflowOrDiv0); 390 APInt div = posA.sdiv_ov(b, overflowOrDiv0); 391 return zero.ssub_ov(div, overflowOrDiv0); 392 } 393 // A is positive, b is negative, return - (a / -b). 394 APInt posB = zero.ssub_ov(b, overflowOrDiv0); 395 APInt div = a.sdiv_ov(posB, overflowOrDiv0); 396 return zero.ssub_ov(div, overflowOrDiv0); 397 }); 398 399 // Fold out ceil division by one. Assumes all tensors of all ones are 400 // splats. 401 if (auto rhs = operands[1].dyn_cast_or_null<IntegerAttr>()) { 402 if (rhs.getValue() == 1) 403 return getLhs(); 404 } else if (auto rhs = operands[1].dyn_cast_or_null<SplatElementsAttr>()) { 405 if (rhs.getSplatValue<IntegerAttr>().getValue() == 1) 406 return getLhs(); 407 } 408 409 return overflowOrDiv0 ? Attribute() : result; 410 } 411 412 //===----------------------------------------------------------------------===// 413 // FloorDivSIOp 414 //===----------------------------------------------------------------------===// 415 416 OpFoldResult arith::FloorDivSIOp::fold(ArrayRef<Attribute> operands) { 417 // Don't fold if it would overflow or if it requires a division by zero. 418 bool overflowOrDiv0 = false; 419 auto result = 420 constFoldBinaryOp<IntegerAttr>(operands, [&](APInt a, const APInt &b) { 421 if (overflowOrDiv0 || !b) { 422 overflowOrDiv0 = true; 423 return a; 424 } 425 if (!a) 426 return a; 427 // After this point we know that neither a or b are zero. 428 unsigned bits = a.getBitWidth(); 429 APInt zero = APInt::getZero(bits); 430 bool aGtZero = a.sgt(zero); 431 bool bGtZero = b.sgt(zero); 432 if (aGtZero && bGtZero) { 433 // Both positive, return a / b. 434 return a.sdiv_ov(b, overflowOrDiv0); 435 } 436 if (!aGtZero && !bGtZero) { 437 // Both negative, return -a / -b. 438 APInt posA = zero.ssub_ov(a, overflowOrDiv0); 439 APInt posB = zero.ssub_ov(b, overflowOrDiv0); 440 return posA.sdiv_ov(posB, overflowOrDiv0); 441 } 442 if (!aGtZero && bGtZero) { 443 // A is negative, b is positive, return - ceil(-a, b). 444 APInt posA = zero.ssub_ov(a, overflowOrDiv0); 445 APInt ceil = signedCeilNonnegInputs(posA, b, overflowOrDiv0); 446 return zero.ssub_ov(ceil, overflowOrDiv0); 447 } 448 // A is positive, b is negative, return - ceil(a, -b). 449 APInt posB = zero.ssub_ov(b, overflowOrDiv0); 450 APInt ceil = signedCeilNonnegInputs(a, posB, overflowOrDiv0); 451 return zero.ssub_ov(ceil, overflowOrDiv0); 452 }); 453 454 // Fold out floor division by one. Assumes all tensors of all ones are 455 // splats. 456 if (auto rhs = operands[1].dyn_cast_or_null<IntegerAttr>()) { 457 if (rhs.getValue() == 1) 458 return getLhs(); 459 } else if (auto rhs = operands[1].dyn_cast_or_null<SplatElementsAttr>()) { 460 if (rhs.getSplatValue<IntegerAttr>().getValue() == 1) 461 return getLhs(); 462 } 463 464 return overflowOrDiv0 ? Attribute() : result; 465 } 466 467 //===----------------------------------------------------------------------===// 468 // RemUIOp 469 //===----------------------------------------------------------------------===// 470 471 OpFoldResult arith::RemUIOp::fold(ArrayRef<Attribute> operands) { 472 auto rhs = operands.back().dyn_cast_or_null<IntegerAttr>(); 473 if (!rhs) 474 return {}; 475 auto rhsValue = rhs.getValue(); 476 477 // x % 1 = 0 478 if (rhsValue.isOneValue()) 479 return IntegerAttr::get(rhs.getType(), APInt(rhsValue.getBitWidth(), 0)); 480 481 // Don't fold if it requires division by zero. 482 if (rhsValue.isNullValue()) 483 return {}; 484 485 auto lhs = operands.front().dyn_cast_or_null<IntegerAttr>(); 486 if (!lhs) 487 return {}; 488 return IntegerAttr::get(lhs.getType(), lhs.getValue().urem(rhsValue)); 489 } 490 491 //===----------------------------------------------------------------------===// 492 // RemSIOp 493 //===----------------------------------------------------------------------===// 494 495 OpFoldResult arith::RemSIOp::fold(ArrayRef<Attribute> operands) { 496 auto rhs = operands.back().dyn_cast_or_null<IntegerAttr>(); 497 if (!rhs) 498 return {}; 499 auto rhsValue = rhs.getValue(); 500 501 // x % 1 = 0 502 if (rhsValue.isOneValue()) 503 return IntegerAttr::get(rhs.getType(), APInt(rhsValue.getBitWidth(), 0)); 504 505 // Don't fold if it requires division by zero. 506 if (rhsValue.isNullValue()) 507 return {}; 508 509 auto lhs = operands.front().dyn_cast_or_null<IntegerAttr>(); 510 if (!lhs) 511 return {}; 512 return IntegerAttr::get(lhs.getType(), lhs.getValue().srem(rhsValue)); 513 } 514 515 //===----------------------------------------------------------------------===// 516 // AndIOp 517 //===----------------------------------------------------------------------===// 518 519 OpFoldResult arith::AndIOp::fold(ArrayRef<Attribute> operands) { 520 /// and(x, 0) -> 0 521 if (matchPattern(getRhs(), m_Zero())) 522 return getRhs(); 523 /// and(x, allOnes) -> x 524 APInt intValue; 525 if (matchPattern(getRhs(), m_ConstantInt(&intValue)) && intValue.isAllOnes()) 526 return getLhs(); 527 528 return constFoldBinaryOp<IntegerAttr>( 529 operands, [](APInt a, const APInt &b) { return std::move(a) & b; }); 530 } 531 532 //===----------------------------------------------------------------------===// 533 // OrIOp 534 //===----------------------------------------------------------------------===// 535 536 OpFoldResult arith::OrIOp::fold(ArrayRef<Attribute> operands) { 537 /// or(x, 0) -> x 538 if (matchPattern(getRhs(), m_Zero())) 539 return getLhs(); 540 /// or(x, <all ones>) -> <all ones> 541 if (auto rhsAttr = operands[1].dyn_cast_or_null<IntegerAttr>()) 542 if (rhsAttr.getValue().isAllOnes()) 543 return rhsAttr; 544 545 return constFoldBinaryOp<IntegerAttr>( 546 operands, [](APInt a, const APInt &b) { return std::move(a) | b; }); 547 } 548 549 //===----------------------------------------------------------------------===// 550 // XOrIOp 551 //===----------------------------------------------------------------------===// 552 553 OpFoldResult arith::XOrIOp::fold(ArrayRef<Attribute> operands) { 554 /// xor(x, 0) -> x 555 if (matchPattern(getRhs(), m_Zero())) 556 return getLhs(); 557 /// xor(x, x) -> 0 558 if (getLhs() == getRhs()) 559 return Builder(getContext()).getZeroAttr(getType()); 560 /// xor(xor(x, a), a) -> x 561 if (arith::XOrIOp prev = getLhs().getDefiningOp<arith::XOrIOp>()) 562 if (prev.getRhs() == getRhs()) 563 return prev.getLhs(); 564 565 return constFoldBinaryOp<IntegerAttr>( 566 operands, [](APInt a, const APInt &b) { return std::move(a) ^ b; }); 567 } 568 569 void arith::XOrIOp::getCanonicalizationPatterns( 570 RewritePatternSet &patterns, MLIRContext *context) { 571 patterns.insert<XOrINotCmpI>(context); 572 } 573 574 //===----------------------------------------------------------------------===// 575 // AddFOp 576 //===----------------------------------------------------------------------===// 577 578 OpFoldResult arith::AddFOp::fold(ArrayRef<Attribute> operands) { 579 // addf(x, -0) -> x 580 if (matchPattern(getRhs(), m_NegZeroFloat())) 581 return getLhs(); 582 583 return constFoldBinaryOp<FloatAttr>( 584 operands, [](const APFloat &a, const APFloat &b) { return a + b; }); 585 } 586 587 //===----------------------------------------------------------------------===// 588 // SubFOp 589 //===----------------------------------------------------------------------===// 590 591 OpFoldResult arith::SubFOp::fold(ArrayRef<Attribute> operands) { 592 // subf(x, +0) -> x 593 if (matchPattern(getRhs(), m_PosZeroFloat())) 594 return getLhs(); 595 596 return constFoldBinaryOp<FloatAttr>( 597 operands, [](const APFloat &a, const APFloat &b) { return a - b; }); 598 } 599 600 //===----------------------------------------------------------------------===// 601 // MaxFOp 602 //===----------------------------------------------------------------------===// 603 604 OpFoldResult arith::MaxFOp::fold(ArrayRef<Attribute> operands) { 605 assert(operands.size() == 2 && "maxf takes two operands"); 606 607 // maxf(x,x) -> x 608 if (getLhs() == getRhs()) 609 return getRhs(); 610 611 // maxf(x, -inf) -> x 612 if (matchPattern(getRhs(), m_NegInfFloat())) 613 return getLhs(); 614 615 return constFoldBinaryOp<FloatAttr>( 616 operands, 617 [](const APFloat &a, const APFloat &b) { return llvm::maximum(a, b); }); 618 } 619 620 //===----------------------------------------------------------------------===// 621 // MaxSIOp 622 //===----------------------------------------------------------------------===// 623 624 OpFoldResult MaxSIOp::fold(ArrayRef<Attribute> operands) { 625 assert(operands.size() == 2 && "binary operation takes two operands"); 626 627 // maxsi(x,x) -> x 628 if (getLhs() == getRhs()) 629 return getRhs(); 630 631 APInt intValue; 632 // maxsi(x,MAX_INT) -> MAX_INT 633 if (matchPattern(getRhs(), m_ConstantInt(&intValue)) && 634 intValue.isMaxSignedValue()) 635 return getRhs(); 636 637 // maxsi(x, MIN_INT) -> x 638 if (matchPattern(getRhs(), m_ConstantInt(&intValue)) && 639 intValue.isMinSignedValue()) 640 return getLhs(); 641 642 return constFoldBinaryOp<IntegerAttr>(operands, 643 [](const APInt &a, const APInt &b) { 644 return llvm::APIntOps::smax(a, b); 645 }); 646 } 647 648 //===----------------------------------------------------------------------===// 649 // MaxUIOp 650 //===----------------------------------------------------------------------===// 651 652 OpFoldResult MaxUIOp::fold(ArrayRef<Attribute> operands) { 653 assert(operands.size() == 2 && "binary operation takes two operands"); 654 655 // maxui(x,x) -> x 656 if (getLhs() == getRhs()) 657 return getRhs(); 658 659 APInt intValue; 660 // maxui(x,MAX_INT) -> MAX_INT 661 if (matchPattern(getRhs(), m_ConstantInt(&intValue)) && intValue.isMaxValue()) 662 return getRhs(); 663 664 // maxui(x, MIN_INT) -> x 665 if (matchPattern(getRhs(), m_ConstantInt(&intValue)) && intValue.isMinValue()) 666 return getLhs(); 667 668 return constFoldBinaryOp<IntegerAttr>(operands, 669 [](const APInt &a, const APInt &b) { 670 return llvm::APIntOps::umax(a, b); 671 }); 672 } 673 674 //===----------------------------------------------------------------------===// 675 // MinFOp 676 //===----------------------------------------------------------------------===// 677 678 OpFoldResult arith::MinFOp::fold(ArrayRef<Attribute> operands) { 679 assert(operands.size() == 2 && "minf takes two operands"); 680 681 // minf(x,x) -> x 682 if (getLhs() == getRhs()) 683 return getRhs(); 684 685 // minf(x, +inf) -> x 686 if (matchPattern(getRhs(), m_PosInfFloat())) 687 return getLhs(); 688 689 return constFoldBinaryOp<FloatAttr>( 690 operands, 691 [](const APFloat &a, const APFloat &b) { return llvm::minimum(a, b); }); 692 } 693 694 //===----------------------------------------------------------------------===// 695 // MinSIOp 696 //===----------------------------------------------------------------------===// 697 698 OpFoldResult MinSIOp::fold(ArrayRef<Attribute> operands) { 699 assert(operands.size() == 2 && "binary operation takes two operands"); 700 701 // minsi(x,x) -> x 702 if (getLhs() == getRhs()) 703 return getRhs(); 704 705 APInt intValue; 706 // minsi(x,MIN_INT) -> MIN_INT 707 if (matchPattern(getRhs(), m_ConstantInt(&intValue)) && 708 intValue.isMinSignedValue()) 709 return getRhs(); 710 711 // minsi(x, MAX_INT) -> x 712 if (matchPattern(getRhs(), m_ConstantInt(&intValue)) && 713 intValue.isMaxSignedValue()) 714 return getLhs(); 715 716 return constFoldBinaryOp<IntegerAttr>(operands, 717 [](const APInt &a, const APInt &b) { 718 return llvm::APIntOps::smin(a, b); 719 }); 720 } 721 722 //===----------------------------------------------------------------------===// 723 // MinUIOp 724 //===----------------------------------------------------------------------===// 725 726 OpFoldResult MinUIOp::fold(ArrayRef<Attribute> operands) { 727 assert(operands.size() == 2 && "binary operation takes two operands"); 728 729 // minui(x,x) -> x 730 if (getLhs() == getRhs()) 731 return getRhs(); 732 733 APInt intValue; 734 // minui(x,MIN_INT) -> MIN_INT 735 if (matchPattern(getRhs(), m_ConstantInt(&intValue)) && intValue.isMinValue()) 736 return getRhs(); 737 738 // minui(x, MAX_INT) -> x 739 if (matchPattern(getRhs(), m_ConstantInt(&intValue)) && intValue.isMaxValue()) 740 return getLhs(); 741 742 return constFoldBinaryOp<IntegerAttr>(operands, 743 [](const APInt &a, const APInt &b) { 744 return llvm::APIntOps::umin(a, b); 745 }); 746 } 747 748 //===----------------------------------------------------------------------===// 749 // MulFOp 750 //===----------------------------------------------------------------------===// 751 752 OpFoldResult arith::MulFOp::fold(ArrayRef<Attribute> operands) { 753 APFloat floatValue(0.0f), inverseValue(0.0f); 754 // mulf(x, 1) -> x 755 if (matchPattern(getRhs(), m_OneFloat())) 756 return getLhs(); 757 758 // mulf(1, x) -> x 759 if (matchPattern(getLhs(), m_OneFloat())) 760 return getRhs(); 761 762 return constFoldBinaryOp<FloatAttr>( 763 operands, [](const APFloat &a, const APFloat &b) { return a * b; }); 764 } 765 766 //===----------------------------------------------------------------------===// 767 // DivFOp 768 //===----------------------------------------------------------------------===// 769 770 OpFoldResult arith::DivFOp::fold(ArrayRef<Attribute> operands) { 771 APFloat floatValue(0.0f), inverseValue(0.0f); 772 // divf(x, 1) -> x 773 if (matchPattern(getRhs(), m_OneFloat())) 774 return getLhs(); 775 776 return constFoldBinaryOp<FloatAttr>( 777 operands, [](const APFloat &a, const APFloat &b) { return a / b; }); 778 } 779 780 //===----------------------------------------------------------------------===// 781 // Utility functions for verifying cast ops 782 //===----------------------------------------------------------------------===// 783 784 template <typename... Types> 785 using type_list = std::tuple<Types...> *; 786 787 /// Returns a non-null type only if the provided type is one of the allowed 788 /// types or one of the allowed shaped types of the allowed types. Returns the 789 /// element type if a valid shaped type is provided. 790 template <typename... ShapedTypes, typename... ElementTypes> 791 static Type getUnderlyingType(Type type, type_list<ShapedTypes...>, 792 type_list<ElementTypes...>) { 793 if (type.isa<ShapedType>() && !type.isa<ShapedTypes...>()) 794 return {}; 795 796 auto underlyingType = getElementTypeOrSelf(type); 797 if (!underlyingType.isa<ElementTypes...>()) 798 return {}; 799 800 return underlyingType; 801 } 802 803 /// Get allowed underlying types for vectors and tensors. 804 template <typename... ElementTypes> 805 static Type getTypeIfLike(Type type) { 806 return getUnderlyingType(type, type_list<VectorType, TensorType>(), 807 type_list<ElementTypes...>()); 808 } 809 810 /// Get allowed underlying types for vectors, tensors, and memrefs. 811 template <typename... ElementTypes> 812 static Type getTypeIfLikeOrMemRef(Type type) { 813 return getUnderlyingType(type, 814 type_list<VectorType, TensorType, MemRefType>(), 815 type_list<ElementTypes...>()); 816 } 817 818 static bool areValidCastInputsAndOutputs(TypeRange inputs, TypeRange outputs) { 819 return inputs.size() == 1 && outputs.size() == 1 && 820 succeeded(verifyCompatibleShapes(inputs.front(), outputs.front())); 821 } 822 823 //===----------------------------------------------------------------------===// 824 // Verifiers for integer and floating point extension/truncation ops 825 //===----------------------------------------------------------------------===// 826 827 // Extend ops can only extend to a wider type. 828 template <typename ValType, typename Op> 829 static LogicalResult verifyExtOp(Op op) { 830 Type srcType = getElementTypeOrSelf(op.getIn().getType()); 831 Type dstType = getElementTypeOrSelf(op.getType()); 832 833 if (srcType.cast<ValType>().getWidth() >= dstType.cast<ValType>().getWidth()) 834 return op.emitError("result type ") 835 << dstType << " must be wider than operand type " << srcType; 836 837 return success(); 838 } 839 840 // Truncate ops can only truncate to a shorter type. 841 template <typename ValType, typename Op> 842 static LogicalResult verifyTruncateOp(Op op) { 843 Type srcType = getElementTypeOrSelf(op.getIn().getType()); 844 Type dstType = getElementTypeOrSelf(op.getType()); 845 846 if (srcType.cast<ValType>().getWidth() <= dstType.cast<ValType>().getWidth()) 847 return op.emitError("result type ") 848 << dstType << " must be shorter than operand type " << srcType; 849 850 return success(); 851 } 852 853 /// Validate a cast that changes the width of a type. 854 template <template <typename> class WidthComparator, typename... ElementTypes> 855 static bool checkWidthChangeCast(TypeRange inputs, TypeRange outputs) { 856 if (!areValidCastInputsAndOutputs(inputs, outputs)) 857 return false; 858 859 auto srcType = getTypeIfLike<ElementTypes...>(inputs.front()); 860 auto dstType = getTypeIfLike<ElementTypes...>(outputs.front()); 861 if (!srcType || !dstType) 862 return false; 863 864 return WidthComparator<unsigned>()(dstType.getIntOrFloatBitWidth(), 865 srcType.getIntOrFloatBitWidth()); 866 } 867 868 //===----------------------------------------------------------------------===// 869 // ExtUIOp 870 //===----------------------------------------------------------------------===// 871 872 OpFoldResult arith::ExtUIOp::fold(ArrayRef<Attribute> operands) { 873 if (auto lhs = operands[0].dyn_cast_or_null<IntegerAttr>()) 874 return IntegerAttr::get( 875 getType(), lhs.getValue().zext(getType().getIntOrFloatBitWidth())); 876 877 if (auto lhs = getIn().getDefiningOp<ExtUIOp>()) { 878 getInMutable().assign(lhs.getIn()); 879 return getResult(); 880 } 881 882 return {}; 883 } 884 885 bool arith::ExtUIOp::areCastCompatible(TypeRange inputs, TypeRange outputs) { 886 return checkWidthChangeCast<std::greater, IntegerType>(inputs, outputs); 887 } 888 889 //===----------------------------------------------------------------------===// 890 // ExtSIOp 891 //===----------------------------------------------------------------------===// 892 893 OpFoldResult arith::ExtSIOp::fold(ArrayRef<Attribute> operands) { 894 if (auto lhs = operands[0].dyn_cast_or_null<IntegerAttr>()) 895 return IntegerAttr::get( 896 getType(), lhs.getValue().sext(getType().getIntOrFloatBitWidth())); 897 898 if (auto lhs = getIn().getDefiningOp<ExtSIOp>()) { 899 getInMutable().assign(lhs.getIn()); 900 return getResult(); 901 } 902 903 return {}; 904 } 905 906 bool arith::ExtSIOp::areCastCompatible(TypeRange inputs, TypeRange outputs) { 907 return checkWidthChangeCast<std::greater, IntegerType>(inputs, outputs); 908 } 909 910 void arith::ExtSIOp::getCanonicalizationPatterns( 911 RewritePatternSet &patterns, MLIRContext *context) { 912 patterns.insert<ExtSIOfExtUI>(context); 913 } 914 915 //===----------------------------------------------------------------------===// 916 // ExtFOp 917 //===----------------------------------------------------------------------===// 918 919 bool arith::ExtFOp::areCastCompatible(TypeRange inputs, TypeRange outputs) { 920 return checkWidthChangeCast<std::greater, FloatType>(inputs, outputs); 921 } 922 923 //===----------------------------------------------------------------------===// 924 // TruncIOp 925 //===----------------------------------------------------------------------===// 926 927 OpFoldResult arith::TruncIOp::fold(ArrayRef<Attribute> operands) { 928 assert(operands.size() == 1 && "unary operation takes one operand"); 929 930 // trunci(zexti(a)) -> a 931 // trunci(sexti(a)) -> a 932 if (matchPattern(getOperand(), m_Op<arith::ExtUIOp>()) || 933 matchPattern(getOperand(), m_Op<arith::ExtSIOp>())) 934 return getOperand().getDefiningOp()->getOperand(0); 935 936 // trunci(trunci(a)) -> trunci(a)) 937 if (matchPattern(getOperand(), m_Op<arith::TruncIOp>())) { 938 setOperand(getOperand().getDefiningOp()->getOperand(0)); 939 return getResult(); 940 } 941 942 if (!operands[0]) 943 return {}; 944 945 if (auto lhs = operands[0].dyn_cast<IntegerAttr>()) { 946 return IntegerAttr::get( 947 getType(), lhs.getValue().trunc(getType().getIntOrFloatBitWidth())); 948 } 949 950 return {}; 951 } 952 953 bool arith::TruncIOp::areCastCompatible(TypeRange inputs, TypeRange outputs) { 954 return checkWidthChangeCast<std::less, IntegerType>(inputs, outputs); 955 } 956 957 //===----------------------------------------------------------------------===// 958 // TruncFOp 959 //===----------------------------------------------------------------------===// 960 961 /// Perform safe const propagation for truncf, i.e. only propagate if FP value 962 /// can be represented without precision loss or rounding. 963 OpFoldResult arith::TruncFOp::fold(ArrayRef<Attribute> operands) { 964 assert(operands.size() == 1 && "unary operation takes one operand"); 965 966 auto constOperand = operands.front(); 967 if (!constOperand || !constOperand.isa<FloatAttr>()) 968 return {}; 969 970 // Convert to target type via 'double'. 971 double sourceValue = 972 constOperand.dyn_cast<FloatAttr>().getValue().convertToDouble(); 973 auto targetAttr = FloatAttr::get(getType(), sourceValue); 974 975 // Propagate if constant's value does not change after truncation. 976 if (sourceValue == targetAttr.getValue().convertToDouble()) 977 return targetAttr; 978 979 return {}; 980 } 981 982 bool arith::TruncFOp::areCastCompatible(TypeRange inputs, TypeRange outputs) { 983 return checkWidthChangeCast<std::less, FloatType>(inputs, outputs); 984 } 985 986 //===----------------------------------------------------------------------===// 987 // AndIOp 988 //===----------------------------------------------------------------------===// 989 990 void arith::AndIOp::getCanonicalizationPatterns( 991 RewritePatternSet &patterns, MLIRContext *context) { 992 patterns.insert<AndOfExtUI, AndOfExtSI>(context); 993 } 994 995 //===----------------------------------------------------------------------===// 996 // OrIOp 997 //===----------------------------------------------------------------------===// 998 999 void arith::OrIOp::getCanonicalizationPatterns( 1000 RewritePatternSet &patterns, MLIRContext *context) { 1001 patterns.insert<OrOfExtUI, OrOfExtSI>(context); 1002 } 1003 1004 //===----------------------------------------------------------------------===// 1005 // Verifiers for casts between integers and floats. 1006 //===----------------------------------------------------------------------===// 1007 1008 template <typename From, typename To> 1009 static bool checkIntFloatCast(TypeRange inputs, TypeRange outputs) { 1010 if (!areValidCastInputsAndOutputs(inputs, outputs)) 1011 return false; 1012 1013 auto srcType = getTypeIfLike<From>(inputs.front()); 1014 auto dstType = getTypeIfLike<To>(outputs.back()); 1015 1016 return srcType && dstType; 1017 } 1018 1019 //===----------------------------------------------------------------------===// 1020 // UIToFPOp 1021 //===----------------------------------------------------------------------===// 1022 1023 bool arith::UIToFPOp::areCastCompatible(TypeRange inputs, TypeRange outputs) { 1024 return checkIntFloatCast<IntegerType, FloatType>(inputs, outputs); 1025 } 1026 1027 OpFoldResult arith::UIToFPOp::fold(ArrayRef<Attribute> operands) { 1028 if (auto lhs = operands[0].dyn_cast_or_null<IntegerAttr>()) { 1029 const APInt &api = lhs.getValue(); 1030 FloatType floatTy = getType().cast<FloatType>(); 1031 APFloat apf(floatTy.getFloatSemantics(), 1032 APInt::getZero(floatTy.getWidth())); 1033 apf.convertFromAPInt(api, /*IsSigned=*/false, APFloat::rmNearestTiesToEven); 1034 return FloatAttr::get(floatTy, apf); 1035 } 1036 return {}; 1037 } 1038 1039 //===----------------------------------------------------------------------===// 1040 // SIToFPOp 1041 //===----------------------------------------------------------------------===// 1042 1043 bool arith::SIToFPOp::areCastCompatible(TypeRange inputs, TypeRange outputs) { 1044 return checkIntFloatCast<IntegerType, FloatType>(inputs, outputs); 1045 } 1046 1047 OpFoldResult arith::SIToFPOp::fold(ArrayRef<Attribute> operands) { 1048 if (auto lhs = operands[0].dyn_cast_or_null<IntegerAttr>()) { 1049 const APInt &api = lhs.getValue(); 1050 FloatType floatTy = getType().cast<FloatType>(); 1051 APFloat apf(floatTy.getFloatSemantics(), 1052 APInt::getZero(floatTy.getWidth())); 1053 apf.convertFromAPInt(api, /*IsSigned=*/true, APFloat::rmNearestTiesToEven); 1054 return FloatAttr::get(floatTy, apf); 1055 } 1056 return {}; 1057 } 1058 //===----------------------------------------------------------------------===// 1059 // FPToUIOp 1060 //===----------------------------------------------------------------------===// 1061 1062 bool arith::FPToUIOp::areCastCompatible(TypeRange inputs, TypeRange outputs) { 1063 return checkIntFloatCast<FloatType, IntegerType>(inputs, outputs); 1064 } 1065 1066 OpFoldResult arith::FPToUIOp::fold(ArrayRef<Attribute> operands) { 1067 if (auto lhs = operands[0].dyn_cast_or_null<FloatAttr>()) { 1068 const APFloat &apf = lhs.getValue(); 1069 IntegerType intTy = getType().cast<IntegerType>(); 1070 bool ignored; 1071 APSInt api(intTy.getWidth(), /*isUnsigned=*/true); 1072 if (APFloat::opInvalidOp == 1073 apf.convertToInteger(api, APFloat::rmTowardZero, &ignored)) { 1074 // Undefined behavior invoked - the destination type can't represent 1075 // the input constant. 1076 return {}; 1077 } 1078 return IntegerAttr::get(getType(), api); 1079 } 1080 1081 return {}; 1082 } 1083 1084 //===----------------------------------------------------------------------===// 1085 // FPToSIOp 1086 //===----------------------------------------------------------------------===// 1087 1088 bool arith::FPToSIOp::areCastCompatible(TypeRange inputs, TypeRange outputs) { 1089 return checkIntFloatCast<FloatType, IntegerType>(inputs, outputs); 1090 } 1091 1092 OpFoldResult arith::FPToSIOp::fold(ArrayRef<Attribute> operands) { 1093 if (auto lhs = operands[0].dyn_cast_or_null<FloatAttr>()) { 1094 const APFloat &apf = lhs.getValue(); 1095 IntegerType intTy = getType().cast<IntegerType>(); 1096 bool ignored; 1097 APSInt api(intTy.getWidth(), /*isUnsigned=*/false); 1098 if (APFloat::opInvalidOp == 1099 apf.convertToInteger(api, APFloat::rmTowardZero, &ignored)) { 1100 // Undefined behavior invoked - the destination type can't represent 1101 // the input constant. 1102 return {}; 1103 } 1104 return IntegerAttr::get(getType(), api); 1105 } 1106 1107 return {}; 1108 } 1109 1110 //===----------------------------------------------------------------------===// 1111 // IndexCastOp 1112 //===----------------------------------------------------------------------===// 1113 1114 bool arith::IndexCastOp::areCastCompatible(TypeRange inputs, 1115 TypeRange outputs) { 1116 if (!areValidCastInputsAndOutputs(inputs, outputs)) 1117 return false; 1118 1119 auto srcType = getTypeIfLikeOrMemRef<IntegerType, IndexType>(inputs.front()); 1120 auto dstType = getTypeIfLikeOrMemRef<IntegerType, IndexType>(outputs.front()); 1121 if (!srcType || !dstType) 1122 return false; 1123 1124 return (srcType.isIndex() && dstType.isSignlessInteger()) || 1125 (srcType.isSignlessInteger() && dstType.isIndex()); 1126 } 1127 1128 OpFoldResult arith::IndexCastOp::fold(ArrayRef<Attribute> operands) { 1129 // index_cast(constant) -> constant 1130 // A little hack because we go through int. Otherwise, the size of the 1131 // constant might need to change. 1132 if (auto value = operands[0].dyn_cast_or_null<IntegerAttr>()) 1133 return IntegerAttr::get(getType(), value.getInt()); 1134 1135 return {}; 1136 } 1137 1138 void arith::IndexCastOp::getCanonicalizationPatterns( 1139 RewritePatternSet &patterns, MLIRContext *context) { 1140 patterns.insert<IndexCastOfIndexCast, IndexCastOfExtSI>(context); 1141 } 1142 1143 //===----------------------------------------------------------------------===// 1144 // BitcastOp 1145 //===----------------------------------------------------------------------===// 1146 1147 bool arith::BitcastOp::areCastCompatible(TypeRange inputs, TypeRange outputs) { 1148 if (!areValidCastInputsAndOutputs(inputs, outputs)) 1149 return false; 1150 1151 auto srcType = 1152 getTypeIfLikeOrMemRef<IntegerType, IndexType, FloatType>(inputs.front()); 1153 auto dstType = 1154 getTypeIfLikeOrMemRef<IntegerType, IndexType, FloatType>(outputs.front()); 1155 if (!srcType || !dstType) 1156 return false; 1157 1158 return srcType.getIntOrFloatBitWidth() == dstType.getIntOrFloatBitWidth(); 1159 } 1160 1161 OpFoldResult arith::BitcastOp::fold(ArrayRef<Attribute> operands) { 1162 assert(operands.size() == 1 && "bitcast op expects 1 operand"); 1163 1164 auto resType = getType(); 1165 auto operand = operands[0]; 1166 if (!operand) 1167 return {}; 1168 1169 /// Bitcast dense elements. 1170 if (auto denseAttr = operand.dyn_cast_or_null<DenseElementsAttr>()) 1171 return denseAttr.bitcast(resType.cast<ShapedType>().getElementType()); 1172 /// Other shaped types unhandled. 1173 if (resType.isa<ShapedType>()) 1174 return {}; 1175 1176 /// Bitcast integer or float to integer or float. 1177 APInt bits = operand.isa<FloatAttr>() 1178 ? operand.cast<FloatAttr>().getValue().bitcastToAPInt() 1179 : operand.cast<IntegerAttr>().getValue(); 1180 1181 if (auto resFloatType = resType.dyn_cast<FloatType>()) 1182 return FloatAttr::get(resType, 1183 APFloat(resFloatType.getFloatSemantics(), bits)); 1184 return IntegerAttr::get(resType, bits); 1185 } 1186 1187 void arith::BitcastOp::getCanonicalizationPatterns( 1188 RewritePatternSet &patterns, MLIRContext *context) { 1189 patterns.insert<BitcastOfBitcast>(context); 1190 } 1191 1192 //===----------------------------------------------------------------------===// 1193 // Helpers for compare ops 1194 //===----------------------------------------------------------------------===// 1195 1196 /// Return the type of the same shape (scalar, vector or tensor) containing i1. 1197 static Type getI1SameShape(Type type) { 1198 auto i1Type = IntegerType::get(type.getContext(), 1); 1199 if (auto tensorType = type.dyn_cast<RankedTensorType>()) 1200 return RankedTensorType::get(tensorType.getShape(), i1Type); 1201 if (type.isa<UnrankedTensorType>()) 1202 return UnrankedTensorType::get(i1Type); 1203 if (auto vectorType = type.dyn_cast<VectorType>()) 1204 return VectorType::get(vectorType.getShape(), i1Type, 1205 vectorType.getNumScalableDims()); 1206 return i1Type; 1207 } 1208 1209 //===----------------------------------------------------------------------===// 1210 // CmpIOp 1211 //===----------------------------------------------------------------------===// 1212 1213 /// Compute `lhs` `pred` `rhs`, where `pred` is one of the known integer 1214 /// comparison predicates. 1215 bool mlir::arith::applyCmpPredicate(arith::CmpIPredicate predicate, 1216 const APInt &lhs, const APInt &rhs) { 1217 switch (predicate) { 1218 case arith::CmpIPredicate::eq: 1219 return lhs.eq(rhs); 1220 case arith::CmpIPredicate::ne: 1221 return lhs.ne(rhs); 1222 case arith::CmpIPredicate::slt: 1223 return lhs.slt(rhs); 1224 case arith::CmpIPredicate::sle: 1225 return lhs.sle(rhs); 1226 case arith::CmpIPredicate::sgt: 1227 return lhs.sgt(rhs); 1228 case arith::CmpIPredicate::sge: 1229 return lhs.sge(rhs); 1230 case arith::CmpIPredicate::ult: 1231 return lhs.ult(rhs); 1232 case arith::CmpIPredicate::ule: 1233 return lhs.ule(rhs); 1234 case arith::CmpIPredicate::ugt: 1235 return lhs.ugt(rhs); 1236 case arith::CmpIPredicate::uge: 1237 return lhs.uge(rhs); 1238 } 1239 llvm_unreachable("unknown cmpi predicate kind"); 1240 } 1241 1242 /// Returns true if the predicate is true for two equal operands. 1243 static bool applyCmpPredicateToEqualOperands(arith::CmpIPredicate predicate) { 1244 switch (predicate) { 1245 case arith::CmpIPredicate::eq: 1246 case arith::CmpIPredicate::sle: 1247 case arith::CmpIPredicate::sge: 1248 case arith::CmpIPredicate::ule: 1249 case arith::CmpIPredicate::uge: 1250 return true; 1251 case arith::CmpIPredicate::ne: 1252 case arith::CmpIPredicate::slt: 1253 case arith::CmpIPredicate::sgt: 1254 case arith::CmpIPredicate::ult: 1255 case arith::CmpIPredicate::ugt: 1256 return false; 1257 } 1258 llvm_unreachable("unknown cmpi predicate kind"); 1259 } 1260 1261 static Attribute getBoolAttribute(Type type, MLIRContext *ctx, bool value) { 1262 auto boolAttr = BoolAttr::get(ctx, value); 1263 ShapedType shapedType = type.dyn_cast_or_null<ShapedType>(); 1264 if (!shapedType) 1265 return boolAttr; 1266 return DenseElementsAttr::get(shapedType, boolAttr); 1267 } 1268 1269 OpFoldResult arith::CmpIOp::fold(ArrayRef<Attribute> operands) { 1270 assert(operands.size() == 2 && "cmpi takes two operands"); 1271 1272 // cmpi(pred, x, x) 1273 if (getLhs() == getRhs()) { 1274 auto val = applyCmpPredicateToEqualOperands(getPredicate()); 1275 return getBoolAttribute(getType(), getContext(), val); 1276 } 1277 1278 if (matchPattern(getRhs(), m_Zero())) { 1279 if (auto extOp = getLhs().getDefiningOp<ExtSIOp>()) { 1280 if (extOp.getOperand().getType().cast<IntegerType>().getWidth() == 1) { 1281 // extsi(%x : i1 -> iN) != 0 -> %x 1282 if (getPredicate() == arith::CmpIPredicate::ne) { 1283 return extOp.getOperand(); 1284 } 1285 } 1286 } 1287 if (auto extOp = getLhs().getDefiningOp<ExtUIOp>()) { 1288 if (extOp.getOperand().getType().cast<IntegerType>().getWidth() == 1) { 1289 // extui(%x : i1 -> iN) != 0 -> %x 1290 if (getPredicate() == arith::CmpIPredicate::ne) { 1291 return extOp.getOperand(); 1292 } 1293 } 1294 } 1295 } 1296 1297 auto lhs = operands.front().dyn_cast_or_null<IntegerAttr>(); 1298 auto rhs = operands.back().dyn_cast_or_null<IntegerAttr>(); 1299 if (!lhs || !rhs) 1300 return {}; 1301 1302 auto val = applyCmpPredicate(getPredicate(), lhs.getValue(), rhs.getValue()); 1303 return BoolAttr::get(getContext(), val); 1304 } 1305 1306 //===----------------------------------------------------------------------===// 1307 // CmpFOp 1308 //===----------------------------------------------------------------------===// 1309 1310 /// Compute `lhs` `pred` `rhs`, where `pred` is one of the known floating point 1311 /// comparison predicates. 1312 bool mlir::arith::applyCmpPredicate(arith::CmpFPredicate predicate, 1313 const APFloat &lhs, const APFloat &rhs) { 1314 auto cmpResult = lhs.compare(rhs); 1315 switch (predicate) { 1316 case arith::CmpFPredicate::AlwaysFalse: 1317 return false; 1318 case arith::CmpFPredicate::OEQ: 1319 return cmpResult == APFloat::cmpEqual; 1320 case arith::CmpFPredicate::OGT: 1321 return cmpResult == APFloat::cmpGreaterThan; 1322 case arith::CmpFPredicate::OGE: 1323 return cmpResult == APFloat::cmpGreaterThan || 1324 cmpResult == APFloat::cmpEqual; 1325 case arith::CmpFPredicate::OLT: 1326 return cmpResult == APFloat::cmpLessThan; 1327 case arith::CmpFPredicate::OLE: 1328 return cmpResult == APFloat::cmpLessThan || cmpResult == APFloat::cmpEqual; 1329 case arith::CmpFPredicate::ONE: 1330 return cmpResult != APFloat::cmpUnordered && cmpResult != APFloat::cmpEqual; 1331 case arith::CmpFPredicate::ORD: 1332 return cmpResult != APFloat::cmpUnordered; 1333 case arith::CmpFPredicate::UEQ: 1334 return cmpResult == APFloat::cmpUnordered || cmpResult == APFloat::cmpEqual; 1335 case arith::CmpFPredicate::UGT: 1336 return cmpResult == APFloat::cmpUnordered || 1337 cmpResult == APFloat::cmpGreaterThan; 1338 case arith::CmpFPredicate::UGE: 1339 return cmpResult == APFloat::cmpUnordered || 1340 cmpResult == APFloat::cmpGreaterThan || 1341 cmpResult == APFloat::cmpEqual; 1342 case arith::CmpFPredicate::ULT: 1343 return cmpResult == APFloat::cmpUnordered || 1344 cmpResult == APFloat::cmpLessThan; 1345 case arith::CmpFPredicate::ULE: 1346 return cmpResult == APFloat::cmpUnordered || 1347 cmpResult == APFloat::cmpLessThan || cmpResult == APFloat::cmpEqual; 1348 case arith::CmpFPredicate::UNE: 1349 return cmpResult != APFloat::cmpEqual; 1350 case arith::CmpFPredicate::UNO: 1351 return cmpResult == APFloat::cmpUnordered; 1352 case arith::CmpFPredicate::AlwaysTrue: 1353 return true; 1354 } 1355 llvm_unreachable("unknown cmpf predicate kind"); 1356 } 1357 1358 OpFoldResult arith::CmpFOp::fold(ArrayRef<Attribute> operands) { 1359 assert(operands.size() == 2 && "cmpf takes two operands"); 1360 1361 auto lhs = operands.front().dyn_cast_or_null<FloatAttr>(); 1362 auto rhs = operands.back().dyn_cast_or_null<FloatAttr>(); 1363 1364 // If one operand is NaN, making them both NaN does not change the result. 1365 if (lhs && lhs.getValue().isNaN()) 1366 rhs = lhs; 1367 if (rhs && rhs.getValue().isNaN()) 1368 lhs = rhs; 1369 1370 if (!lhs || !rhs) 1371 return {}; 1372 1373 auto val = applyCmpPredicate(getPredicate(), lhs.getValue(), rhs.getValue()); 1374 return BoolAttr::get(getContext(), val); 1375 } 1376 1377 //===----------------------------------------------------------------------===// 1378 // Atomic Enum 1379 //===----------------------------------------------------------------------===// 1380 1381 /// Returns the identity value attribute associated with an AtomicRMWKind op. 1382 Attribute mlir::arith::getIdentityValueAttr(AtomicRMWKind kind, Type resultType, 1383 OpBuilder &builder, Location loc) { 1384 switch (kind) { 1385 case AtomicRMWKind::maxf: 1386 return builder.getFloatAttr( 1387 resultType, 1388 APFloat::getInf(resultType.cast<FloatType>().getFloatSemantics(), 1389 /*Negative=*/true)); 1390 case AtomicRMWKind::addf: 1391 case AtomicRMWKind::addi: 1392 case AtomicRMWKind::maxu: 1393 case AtomicRMWKind::ori: 1394 return builder.getZeroAttr(resultType); 1395 case AtomicRMWKind::andi: 1396 return builder.getIntegerAttr( 1397 resultType, 1398 APInt::getAllOnes(resultType.cast<IntegerType>().getWidth())); 1399 case AtomicRMWKind::maxs: 1400 return builder.getIntegerAttr( 1401 resultType, 1402 APInt::getSignedMinValue(resultType.cast<IntegerType>().getWidth())); 1403 case AtomicRMWKind::minf: 1404 return builder.getFloatAttr( 1405 resultType, 1406 APFloat::getInf(resultType.cast<FloatType>().getFloatSemantics(), 1407 /*Negative=*/false)); 1408 case AtomicRMWKind::mins: 1409 return builder.getIntegerAttr( 1410 resultType, 1411 APInt::getSignedMaxValue(resultType.cast<IntegerType>().getWidth())); 1412 case AtomicRMWKind::minu: 1413 return builder.getIntegerAttr( 1414 resultType, 1415 APInt::getMaxValue(resultType.cast<IntegerType>().getWidth())); 1416 case AtomicRMWKind::muli: 1417 return builder.getIntegerAttr(resultType, 1); 1418 case AtomicRMWKind::mulf: 1419 return builder.getFloatAttr(resultType, 1); 1420 // TODO: Add remaining reduction operations. 1421 default: 1422 (void)emitOptionalError(loc, "Reduction operation type not supported"); 1423 break; 1424 } 1425 return nullptr; 1426 } 1427 1428 /// Returns the identity value associated with an AtomicRMWKind op. 1429 Value mlir::arith::getIdentityValue(AtomicRMWKind op, Type resultType, 1430 OpBuilder &builder, Location loc) { 1431 Attribute attr = getIdentityValueAttr(op, resultType, builder, loc); 1432 return builder.create<arith::ConstantOp>(loc, attr); 1433 } 1434 1435 /// Return the value obtained by applying the reduction operation kind 1436 /// associated with a binary AtomicRMWKind op to `lhs` and `rhs`. 1437 Value mlir::arith::getReductionOp(AtomicRMWKind op, OpBuilder &builder, 1438 Location loc, Value lhs, Value rhs) { 1439 switch (op) { 1440 case AtomicRMWKind::addf: 1441 return builder.create<arith::AddFOp>(loc, lhs, rhs); 1442 case AtomicRMWKind::addi: 1443 return builder.create<arith::AddIOp>(loc, lhs, rhs); 1444 case AtomicRMWKind::mulf: 1445 return builder.create<arith::MulFOp>(loc, lhs, rhs); 1446 case AtomicRMWKind::muli: 1447 return builder.create<arith::MulIOp>(loc, lhs, rhs); 1448 case AtomicRMWKind::maxf: 1449 return builder.create<arith::MaxFOp>(loc, lhs, rhs); 1450 case AtomicRMWKind::minf: 1451 return builder.create<arith::MinFOp>(loc, lhs, rhs); 1452 case AtomicRMWKind::maxs: 1453 return builder.create<arith::MaxSIOp>(loc, lhs, rhs); 1454 case AtomicRMWKind::mins: 1455 return builder.create<arith::MinSIOp>(loc, lhs, rhs); 1456 case AtomicRMWKind::maxu: 1457 return builder.create<arith::MaxUIOp>(loc, lhs, rhs); 1458 case AtomicRMWKind::minu: 1459 return builder.create<arith::MinUIOp>(loc, lhs, rhs); 1460 case AtomicRMWKind::ori: 1461 return builder.create<arith::OrIOp>(loc, lhs, rhs); 1462 case AtomicRMWKind::andi: 1463 return builder.create<arith::AndIOp>(loc, lhs, rhs); 1464 // TODO: Add remaining reduction operations. 1465 default: 1466 (void)emitOptionalError(loc, "Reduction operation type not supported"); 1467 break; 1468 } 1469 return nullptr; 1470 } 1471 1472 //===----------------------------------------------------------------------===// 1473 // TableGen'd op method definitions 1474 //===----------------------------------------------------------------------===// 1475 1476 #define GET_OP_CLASSES 1477 #include "mlir/Dialect/Arithmetic/IR/ArithmeticOps.cpp.inc" 1478 1479 //===----------------------------------------------------------------------===// 1480 // TableGen'd enum attribute definitions 1481 //===----------------------------------------------------------------------===// 1482 1483 #include "mlir/Dialect/Arithmetic/IR/ArithmeticOpsEnums.cpp.inc" 1484