1 //===- ArithmeticOps.cpp - MLIR Arithmetic dialect ops implementation -----===// 2 // 3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. 4 // See https://llvm.org/LICENSE.txt for license information. 5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception 6 // 7 //===----------------------------------------------------------------------===// 8 9 #include <utility> 10 11 #include "mlir/Dialect/Arithmetic/IR/Arithmetic.h" 12 #include "mlir/Dialect/CommonFolders.h" 13 #include "mlir/IR/Builders.h" 14 #include "mlir/IR/Matchers.h" 15 #include "mlir/IR/OpImplementation.h" 16 #include "mlir/IR/PatternMatch.h" 17 #include "mlir/IR/TypeUtilities.h" 18 19 #include "llvm/ADT/APSInt.h" 20 21 using namespace mlir; 22 using namespace mlir::arith; 23 24 //===----------------------------------------------------------------------===// 25 // Pattern helpers 26 //===----------------------------------------------------------------------===// 27 28 static IntegerAttr addIntegerAttrs(PatternRewriter &builder, Value res, 29 Attribute lhs, Attribute rhs) { 30 return builder.getIntegerAttr(res.getType(), 31 lhs.cast<IntegerAttr>().getInt() + 32 rhs.cast<IntegerAttr>().getInt()); 33 } 34 35 static IntegerAttr subIntegerAttrs(PatternRewriter &builder, Value res, 36 Attribute lhs, Attribute rhs) { 37 return builder.getIntegerAttr(res.getType(), 38 lhs.cast<IntegerAttr>().getInt() - 39 rhs.cast<IntegerAttr>().getInt()); 40 } 41 42 /// Invert an integer comparison predicate. 43 arith::CmpIPredicate arith::invertPredicate(arith::CmpIPredicate pred) { 44 switch (pred) { 45 case arith::CmpIPredicate::eq: 46 return arith::CmpIPredicate::ne; 47 case arith::CmpIPredicate::ne: 48 return arith::CmpIPredicate::eq; 49 case arith::CmpIPredicate::slt: 50 return arith::CmpIPredicate::sge; 51 case arith::CmpIPredicate::sle: 52 return arith::CmpIPredicate::sgt; 53 case arith::CmpIPredicate::sgt: 54 return arith::CmpIPredicate::sle; 55 case arith::CmpIPredicate::sge: 56 return arith::CmpIPredicate::slt; 57 case arith::CmpIPredicate::ult: 58 return arith::CmpIPredicate::uge; 59 case arith::CmpIPredicate::ule: 60 return arith::CmpIPredicate::ugt; 61 case arith::CmpIPredicate::ugt: 62 return arith::CmpIPredicate::ule; 63 case arith::CmpIPredicate::uge: 64 return arith::CmpIPredicate::ult; 65 } 66 llvm_unreachable("unknown cmpi predicate kind"); 67 } 68 69 static arith::CmpIPredicateAttr invertPredicate(arith::CmpIPredicateAttr pred) { 70 return arith::CmpIPredicateAttr::get(pred.getContext(), 71 invertPredicate(pred.getValue())); 72 } 73 74 //===----------------------------------------------------------------------===// 75 // TableGen'd canonicalization patterns 76 //===----------------------------------------------------------------------===// 77 78 namespace { 79 #include "ArithmeticCanonicalization.inc" 80 } // namespace 81 82 //===----------------------------------------------------------------------===// 83 // ConstantOp 84 //===----------------------------------------------------------------------===// 85 86 void arith::ConstantOp::getAsmResultNames( 87 function_ref<void(Value, StringRef)> setNameFn) { 88 auto type = getType(); 89 if (auto intCst = getValue().dyn_cast<IntegerAttr>()) { 90 auto intType = type.dyn_cast<IntegerType>(); 91 92 // Sugar i1 constants with 'true' and 'false'. 93 if (intType && intType.getWidth() == 1) 94 return setNameFn(getResult(), (intCst.getInt() ? "true" : "false")); 95 96 // Otherwise, build a compex name with the value and type. 97 SmallString<32> specialNameBuffer; 98 llvm::raw_svector_ostream specialName(specialNameBuffer); 99 specialName << 'c' << intCst.getInt(); 100 if (intType) 101 specialName << '_' << type; 102 setNameFn(getResult(), specialName.str()); 103 } else { 104 setNameFn(getResult(), "cst"); 105 } 106 } 107 108 /// TODO: disallow arith.constant to return anything other than signless integer 109 /// or float like. 110 LogicalResult arith::ConstantOp::verify() { 111 auto type = getType(); 112 // The value's type must match the return type. 113 if (getValue().getType() != type) { 114 return emitOpError() << "value type " << getValue().getType() 115 << " must match return type: " << type; 116 } 117 // Integer values must be signless. 118 if (type.isa<IntegerType>() && !type.cast<IntegerType>().isSignless()) 119 return emitOpError("integer return type must be signless"); 120 // Any float or elements attribute are acceptable. 121 if (!getValue().isa<IntegerAttr, FloatAttr, ElementsAttr>()) { 122 return emitOpError( 123 "value must be an integer, float, or elements attribute"); 124 } 125 return success(); 126 } 127 128 bool arith::ConstantOp::isBuildableWith(Attribute value, Type type) { 129 // The value's type must be the same as the provided type. 130 if (value.getType() != type) 131 return false; 132 // Integer values must be signless. 133 if (type.isa<IntegerType>() && !type.cast<IntegerType>().isSignless()) 134 return false; 135 // Integer, float, and element attributes are buildable. 136 return value.isa<IntegerAttr, FloatAttr, ElementsAttr>(); 137 } 138 139 OpFoldResult arith::ConstantOp::fold(ArrayRef<Attribute> operands) { 140 return getValue(); 141 } 142 143 void arith::ConstantIntOp::build(OpBuilder &builder, OperationState &result, 144 int64_t value, unsigned width) { 145 auto type = builder.getIntegerType(width); 146 arith::ConstantOp::build(builder, result, type, 147 builder.getIntegerAttr(type, value)); 148 } 149 150 void arith::ConstantIntOp::build(OpBuilder &builder, OperationState &result, 151 int64_t value, Type type) { 152 assert(type.isSignlessInteger() && 153 "ConstantIntOp can only have signless integer type values"); 154 arith::ConstantOp::build(builder, result, type, 155 builder.getIntegerAttr(type, value)); 156 } 157 158 bool arith::ConstantIntOp::classof(Operation *op) { 159 if (auto constOp = dyn_cast_or_null<arith::ConstantOp>(op)) 160 return constOp.getType().isSignlessInteger(); 161 return false; 162 } 163 164 void arith::ConstantFloatOp::build(OpBuilder &builder, OperationState &result, 165 const APFloat &value, FloatType type) { 166 arith::ConstantOp::build(builder, result, type, 167 builder.getFloatAttr(type, value)); 168 } 169 170 bool arith::ConstantFloatOp::classof(Operation *op) { 171 if (auto constOp = dyn_cast_or_null<arith::ConstantOp>(op)) 172 return constOp.getType().isa<FloatType>(); 173 return false; 174 } 175 176 void arith::ConstantIndexOp::build(OpBuilder &builder, OperationState &result, 177 int64_t value) { 178 arith::ConstantOp::build(builder, result, builder.getIndexType(), 179 builder.getIndexAttr(value)); 180 } 181 182 bool arith::ConstantIndexOp::classof(Operation *op) { 183 if (auto constOp = dyn_cast_or_null<arith::ConstantOp>(op)) 184 return constOp.getType().isIndex(); 185 return false; 186 } 187 188 //===----------------------------------------------------------------------===// 189 // AddIOp 190 //===----------------------------------------------------------------------===// 191 192 OpFoldResult arith::AddIOp::fold(ArrayRef<Attribute> operands) { 193 // addi(x, 0) -> x 194 if (matchPattern(getRhs(), m_Zero())) 195 return getLhs(); 196 197 // addi(subi(a, b), b) -> a 198 if (auto sub = getLhs().getDefiningOp<SubIOp>()) 199 if (getRhs() == sub.getRhs()) 200 return sub.getLhs(); 201 202 // addi(b, subi(a, b)) -> a 203 if (auto sub = getRhs().getDefiningOp<SubIOp>()) 204 if (getLhs() == sub.getRhs()) 205 return sub.getLhs(); 206 207 return constFoldBinaryOp<IntegerAttr>( 208 operands, [](APInt a, const APInt &b) { return std::move(a) + b; }); 209 } 210 211 void arith::AddIOp::getCanonicalizationPatterns( 212 RewritePatternSet &patterns, MLIRContext *context) { 213 patterns.insert<AddIAddConstant, AddISubConstantRHS, AddISubConstantLHS>( 214 context); 215 } 216 217 //===----------------------------------------------------------------------===// 218 // SubIOp 219 //===----------------------------------------------------------------------===// 220 221 OpFoldResult arith::SubIOp::fold(ArrayRef<Attribute> operands) { 222 // subi(x,x) -> 0 223 if (getOperand(0) == getOperand(1)) 224 return Builder(getContext()).getZeroAttr(getType()); 225 // subi(x,0) -> x 226 if (matchPattern(getRhs(), m_Zero())) 227 return getLhs(); 228 229 return constFoldBinaryOp<IntegerAttr>( 230 operands, [](APInt a, const APInt &b) { return std::move(a) - b; }); 231 } 232 233 void arith::SubIOp::getCanonicalizationPatterns( 234 RewritePatternSet &patterns, MLIRContext *context) { 235 patterns.insert<SubIRHSAddConstant, SubILHSAddConstant, SubIRHSSubConstantRHS, 236 SubIRHSSubConstantLHS, SubILHSSubConstantRHS, 237 SubILHSSubConstantLHS>(context); 238 } 239 240 //===----------------------------------------------------------------------===// 241 // MulIOp 242 //===----------------------------------------------------------------------===// 243 244 OpFoldResult arith::MulIOp::fold(ArrayRef<Attribute> operands) { 245 // muli(x, 0) -> 0 246 if (matchPattern(getRhs(), m_Zero())) 247 return getRhs(); 248 // muli(x, 1) -> x 249 if (matchPattern(getRhs(), m_One())) 250 return getOperand(0); 251 // TODO: Handle the overflow case. 252 253 // default folder 254 return constFoldBinaryOp<IntegerAttr>( 255 operands, [](const APInt &a, const APInt &b) { return a * b; }); 256 } 257 258 //===----------------------------------------------------------------------===// 259 // DivUIOp 260 //===----------------------------------------------------------------------===// 261 262 OpFoldResult arith::DivUIOp::fold(ArrayRef<Attribute> operands) { 263 // Don't fold if it would require a division by zero. 264 bool div0 = false; 265 auto result = 266 constFoldBinaryOp<IntegerAttr>(operands, [&](APInt a, const APInt &b) { 267 if (div0 || !b) { 268 div0 = true; 269 return a; 270 } 271 return a.udiv(b); 272 }); 273 274 // Fold out division by one. Assumes all tensors of all ones are splats. 275 if (auto rhs = operands[1].dyn_cast_or_null<IntegerAttr>()) { 276 if (rhs.getValue() == 1) 277 return getLhs(); 278 } else if (auto rhs = operands[1].dyn_cast_or_null<SplatElementsAttr>()) { 279 if (rhs.getSplatValue<IntegerAttr>().getValue() == 1) 280 return getLhs(); 281 } 282 283 return div0 ? Attribute() : result; 284 } 285 286 //===----------------------------------------------------------------------===// 287 // DivSIOp 288 //===----------------------------------------------------------------------===// 289 290 OpFoldResult arith::DivSIOp::fold(ArrayRef<Attribute> operands) { 291 // Don't fold if it would overflow or if it requires a division by zero. 292 bool overflowOrDiv0 = false; 293 auto result = 294 constFoldBinaryOp<IntegerAttr>(operands, [&](APInt a, const APInt &b) { 295 if (overflowOrDiv0 || !b) { 296 overflowOrDiv0 = true; 297 return a; 298 } 299 return a.sdiv_ov(b, overflowOrDiv0); 300 }); 301 302 // Fold out division by one. Assumes all tensors of all ones are splats. 303 if (auto rhs = operands[1].dyn_cast_or_null<IntegerAttr>()) { 304 if (rhs.getValue() == 1) 305 return getLhs(); 306 } else if (auto rhs = operands[1].dyn_cast_or_null<SplatElementsAttr>()) { 307 if (rhs.getSplatValue<IntegerAttr>().getValue() == 1) 308 return getLhs(); 309 } 310 311 return overflowOrDiv0 ? Attribute() : result; 312 } 313 314 //===----------------------------------------------------------------------===// 315 // Ceil and floor division folding helpers 316 //===----------------------------------------------------------------------===// 317 318 static APInt signedCeilNonnegInputs(const APInt &a, const APInt &b, 319 bool &overflow) { 320 // Returns (a-1)/b + 1 321 APInt one(a.getBitWidth(), 1, true); // Signed value 1. 322 APInt val = a.ssub_ov(one, overflow).sdiv_ov(b, overflow); 323 return val.sadd_ov(one, overflow); 324 } 325 326 //===----------------------------------------------------------------------===// 327 // CeilDivUIOp 328 //===----------------------------------------------------------------------===// 329 330 OpFoldResult arith::CeilDivUIOp::fold(ArrayRef<Attribute> operands) { 331 bool overflowOrDiv0 = false; 332 auto result = 333 constFoldBinaryOp<IntegerAttr>(operands, [&](APInt a, const APInt &b) { 334 if (overflowOrDiv0 || !b) { 335 overflowOrDiv0 = true; 336 return a; 337 } 338 APInt quotient = a.udiv(b); 339 if (!a.urem(b)) 340 return quotient; 341 APInt one(a.getBitWidth(), 1, true); 342 return quotient.uadd_ov(one, overflowOrDiv0); 343 }); 344 // Fold out ceil division by one. Assumes all tensors of all ones are 345 // splats. 346 if (auto rhs = operands[1].dyn_cast_or_null<IntegerAttr>()) { 347 if (rhs.getValue() == 1) 348 return getLhs(); 349 } else if (auto rhs = operands[1].dyn_cast_or_null<SplatElementsAttr>()) { 350 if (rhs.getSplatValue<IntegerAttr>().getValue() == 1) 351 return getLhs(); 352 } 353 354 return overflowOrDiv0 ? Attribute() : result; 355 } 356 357 //===----------------------------------------------------------------------===// 358 // CeilDivSIOp 359 //===----------------------------------------------------------------------===// 360 361 OpFoldResult arith::CeilDivSIOp::fold(ArrayRef<Attribute> operands) { 362 // Don't fold if it would overflow or if it requires a division by zero. 363 bool overflowOrDiv0 = false; 364 auto result = 365 constFoldBinaryOp<IntegerAttr>(operands, [&](APInt a, const APInt &b) { 366 if (overflowOrDiv0 || !b) { 367 overflowOrDiv0 = true; 368 return a; 369 } 370 if (!a) 371 return a; 372 // After this point we know that neither a or b are zero. 373 unsigned bits = a.getBitWidth(); 374 APInt zero = APInt::getZero(bits); 375 bool aGtZero = a.sgt(zero); 376 bool bGtZero = b.sgt(zero); 377 if (aGtZero && bGtZero) { 378 // Both positive, return ceil(a, b). 379 return signedCeilNonnegInputs(a, b, overflowOrDiv0); 380 } 381 if (!aGtZero && !bGtZero) { 382 // Both negative, return ceil(-a, -b). 383 APInt posA = zero.ssub_ov(a, overflowOrDiv0); 384 APInt posB = zero.ssub_ov(b, overflowOrDiv0); 385 return signedCeilNonnegInputs(posA, posB, overflowOrDiv0); 386 } 387 if (!aGtZero && bGtZero) { 388 // A is negative, b is positive, return - ( -a / b). 389 APInt posA = zero.ssub_ov(a, overflowOrDiv0); 390 APInt div = posA.sdiv_ov(b, overflowOrDiv0); 391 return zero.ssub_ov(div, overflowOrDiv0); 392 } 393 // A is positive, b is negative, return - (a / -b). 394 APInt posB = zero.ssub_ov(b, overflowOrDiv0); 395 APInt div = a.sdiv_ov(posB, overflowOrDiv0); 396 return zero.ssub_ov(div, overflowOrDiv0); 397 }); 398 399 // Fold out ceil division by one. Assumes all tensors of all ones are 400 // splats. 401 if (auto rhs = operands[1].dyn_cast_or_null<IntegerAttr>()) { 402 if (rhs.getValue() == 1) 403 return getLhs(); 404 } else if (auto rhs = operands[1].dyn_cast_or_null<SplatElementsAttr>()) { 405 if (rhs.getSplatValue<IntegerAttr>().getValue() == 1) 406 return getLhs(); 407 } 408 409 return overflowOrDiv0 ? Attribute() : result; 410 } 411 412 //===----------------------------------------------------------------------===// 413 // FloorDivSIOp 414 //===----------------------------------------------------------------------===// 415 416 OpFoldResult arith::FloorDivSIOp::fold(ArrayRef<Attribute> operands) { 417 // Don't fold if it would overflow or if it requires a division by zero. 418 bool overflowOrDiv0 = false; 419 auto result = 420 constFoldBinaryOp<IntegerAttr>(operands, [&](APInt a, const APInt &b) { 421 if (overflowOrDiv0 || !b) { 422 overflowOrDiv0 = true; 423 return a; 424 } 425 if (!a) 426 return a; 427 // After this point we know that neither a or b are zero. 428 unsigned bits = a.getBitWidth(); 429 APInt zero = APInt::getZero(bits); 430 bool aGtZero = a.sgt(zero); 431 bool bGtZero = b.sgt(zero); 432 if (aGtZero && bGtZero) { 433 // Both positive, return a / b. 434 return a.sdiv_ov(b, overflowOrDiv0); 435 } 436 if (!aGtZero && !bGtZero) { 437 // Both negative, return -a / -b. 438 APInt posA = zero.ssub_ov(a, overflowOrDiv0); 439 APInt posB = zero.ssub_ov(b, overflowOrDiv0); 440 return posA.sdiv_ov(posB, overflowOrDiv0); 441 } 442 if (!aGtZero && bGtZero) { 443 // A is negative, b is positive, return - ceil(-a, b). 444 APInt posA = zero.ssub_ov(a, overflowOrDiv0); 445 APInt ceil = signedCeilNonnegInputs(posA, b, overflowOrDiv0); 446 return zero.ssub_ov(ceil, overflowOrDiv0); 447 } 448 // A is positive, b is negative, return - ceil(a, -b). 449 APInt posB = zero.ssub_ov(b, overflowOrDiv0); 450 APInt ceil = signedCeilNonnegInputs(a, posB, overflowOrDiv0); 451 return zero.ssub_ov(ceil, overflowOrDiv0); 452 }); 453 454 // Fold out floor division by one. Assumes all tensors of all ones are 455 // splats. 456 if (auto rhs = operands[1].dyn_cast_or_null<IntegerAttr>()) { 457 if (rhs.getValue() == 1) 458 return getLhs(); 459 } else if (auto rhs = operands[1].dyn_cast_or_null<SplatElementsAttr>()) { 460 if (rhs.getSplatValue<IntegerAttr>().getValue() == 1) 461 return getLhs(); 462 } 463 464 return overflowOrDiv0 ? Attribute() : result; 465 } 466 467 //===----------------------------------------------------------------------===// 468 // RemUIOp 469 //===----------------------------------------------------------------------===// 470 471 OpFoldResult arith::RemUIOp::fold(ArrayRef<Attribute> operands) { 472 auto rhs = operands.back().dyn_cast_or_null<IntegerAttr>(); 473 if (!rhs) 474 return {}; 475 auto rhsValue = rhs.getValue(); 476 477 // x % 1 = 0 478 if (rhsValue.isOneValue()) 479 return IntegerAttr::get(rhs.getType(), APInt(rhsValue.getBitWidth(), 0)); 480 481 // Don't fold if it requires division by zero. 482 if (rhsValue.isNullValue()) 483 return {}; 484 485 auto lhs = operands.front().dyn_cast_or_null<IntegerAttr>(); 486 if (!lhs) 487 return {}; 488 return IntegerAttr::get(lhs.getType(), lhs.getValue().urem(rhsValue)); 489 } 490 491 //===----------------------------------------------------------------------===// 492 // RemSIOp 493 //===----------------------------------------------------------------------===// 494 495 OpFoldResult arith::RemSIOp::fold(ArrayRef<Attribute> operands) { 496 auto rhs = operands.back().dyn_cast_or_null<IntegerAttr>(); 497 if (!rhs) 498 return {}; 499 auto rhsValue = rhs.getValue(); 500 501 // x % 1 = 0 502 if (rhsValue.isOneValue()) 503 return IntegerAttr::get(rhs.getType(), APInt(rhsValue.getBitWidth(), 0)); 504 505 // Don't fold if it requires division by zero. 506 if (rhsValue.isNullValue()) 507 return {}; 508 509 auto lhs = operands.front().dyn_cast_or_null<IntegerAttr>(); 510 if (!lhs) 511 return {}; 512 return IntegerAttr::get(lhs.getType(), lhs.getValue().srem(rhsValue)); 513 } 514 515 //===----------------------------------------------------------------------===// 516 // AndIOp 517 //===----------------------------------------------------------------------===// 518 519 OpFoldResult arith::AndIOp::fold(ArrayRef<Attribute> operands) { 520 /// and(x, 0) -> 0 521 if (matchPattern(getRhs(), m_Zero())) 522 return getRhs(); 523 /// and(x, allOnes) -> x 524 APInt intValue; 525 if (matchPattern(getRhs(), m_ConstantInt(&intValue)) && intValue.isAllOnes()) 526 return getLhs(); 527 528 return constFoldBinaryOp<IntegerAttr>( 529 operands, [](APInt a, const APInt &b) { return std::move(a) & b; }); 530 } 531 532 //===----------------------------------------------------------------------===// 533 // OrIOp 534 //===----------------------------------------------------------------------===// 535 536 OpFoldResult arith::OrIOp::fold(ArrayRef<Attribute> operands) { 537 /// or(x, 0) -> x 538 if (matchPattern(getRhs(), m_Zero())) 539 return getLhs(); 540 /// or(x, <all ones>) -> <all ones> 541 if (auto rhsAttr = operands[1].dyn_cast_or_null<IntegerAttr>()) 542 if (rhsAttr.getValue().isAllOnes()) 543 return rhsAttr; 544 545 return constFoldBinaryOp<IntegerAttr>( 546 operands, [](APInt a, const APInt &b) { return std::move(a) | b; }); 547 } 548 549 //===----------------------------------------------------------------------===// 550 // XOrIOp 551 //===----------------------------------------------------------------------===// 552 553 OpFoldResult arith::XOrIOp::fold(ArrayRef<Attribute> operands) { 554 /// xor(x, 0) -> x 555 if (matchPattern(getRhs(), m_Zero())) 556 return getLhs(); 557 /// xor(x, x) -> 0 558 if (getLhs() == getRhs()) 559 return Builder(getContext()).getZeroAttr(getType()); 560 /// xor(xor(x, a), a) -> x 561 if (arith::XOrIOp prev = getLhs().getDefiningOp<arith::XOrIOp>()) 562 if (prev.getRhs() == getRhs()) 563 return prev.getLhs(); 564 565 return constFoldBinaryOp<IntegerAttr>( 566 operands, [](APInt a, const APInt &b) { return std::move(a) ^ b; }); 567 } 568 569 void arith::XOrIOp::getCanonicalizationPatterns( 570 RewritePatternSet &patterns, MLIRContext *context) { 571 patterns.insert<XOrINotCmpI>(context); 572 } 573 574 //===----------------------------------------------------------------------===// 575 // AddFOp 576 //===----------------------------------------------------------------------===// 577 578 OpFoldResult arith::AddFOp::fold(ArrayRef<Attribute> operands) { 579 // addf(x, -0) -> x 580 if (matchPattern(getRhs(), m_NegZeroFloat())) 581 return getLhs(); 582 583 return constFoldBinaryOp<FloatAttr>( 584 operands, [](const APFloat &a, const APFloat &b) { return a + b; }); 585 } 586 587 //===----------------------------------------------------------------------===// 588 // SubFOp 589 //===----------------------------------------------------------------------===// 590 591 OpFoldResult arith::SubFOp::fold(ArrayRef<Attribute> operands) { 592 // subf(x, +0) -> x 593 if (matchPattern(getRhs(), m_PosZeroFloat())) 594 return getLhs(); 595 596 return constFoldBinaryOp<FloatAttr>( 597 operands, [](const APFloat &a, const APFloat &b) { return a - b; }); 598 } 599 600 //===----------------------------------------------------------------------===// 601 // MaxFOp 602 //===----------------------------------------------------------------------===// 603 604 OpFoldResult arith::MaxFOp::fold(ArrayRef<Attribute> operands) { 605 assert(operands.size() == 2 && "maxf takes two operands"); 606 607 // maxf(x,x) -> x 608 if (getLhs() == getRhs()) 609 return getRhs(); 610 611 // maxf(x, -inf) -> x 612 if (matchPattern(getRhs(), m_NegInfFloat())) 613 return getLhs(); 614 615 return constFoldBinaryOp<FloatAttr>( 616 operands, 617 [](const APFloat &a, const APFloat &b) { return llvm::maximum(a, b); }); 618 } 619 620 //===----------------------------------------------------------------------===// 621 // MaxSIOp 622 //===----------------------------------------------------------------------===// 623 624 OpFoldResult MaxSIOp::fold(ArrayRef<Attribute> operands) { 625 assert(operands.size() == 2 && "binary operation takes two operands"); 626 627 // maxsi(x,x) -> x 628 if (getLhs() == getRhs()) 629 return getRhs(); 630 631 APInt intValue; 632 // maxsi(x,MAX_INT) -> MAX_INT 633 if (matchPattern(getRhs(), m_ConstantInt(&intValue)) && 634 intValue.isMaxSignedValue()) 635 return getRhs(); 636 637 // maxsi(x, MIN_INT) -> x 638 if (matchPattern(getRhs(), m_ConstantInt(&intValue)) && 639 intValue.isMinSignedValue()) 640 return getLhs(); 641 642 return constFoldBinaryOp<IntegerAttr>(operands, 643 [](const APInt &a, const APInt &b) { 644 return llvm::APIntOps::smax(a, b); 645 }); 646 } 647 648 //===----------------------------------------------------------------------===// 649 // MaxUIOp 650 //===----------------------------------------------------------------------===// 651 652 OpFoldResult MaxUIOp::fold(ArrayRef<Attribute> operands) { 653 assert(operands.size() == 2 && "binary operation takes two operands"); 654 655 // maxui(x,x) -> x 656 if (getLhs() == getRhs()) 657 return getRhs(); 658 659 APInt intValue; 660 // maxui(x,MAX_INT) -> MAX_INT 661 if (matchPattern(getRhs(), m_ConstantInt(&intValue)) && intValue.isMaxValue()) 662 return getRhs(); 663 664 // maxui(x, MIN_INT) -> x 665 if (matchPattern(getRhs(), m_ConstantInt(&intValue)) && intValue.isMinValue()) 666 return getLhs(); 667 668 return constFoldBinaryOp<IntegerAttr>(operands, 669 [](const APInt &a, const APInt &b) { 670 return llvm::APIntOps::umax(a, b); 671 }); 672 } 673 674 //===----------------------------------------------------------------------===// 675 // MinFOp 676 //===----------------------------------------------------------------------===// 677 678 OpFoldResult arith::MinFOp::fold(ArrayRef<Attribute> operands) { 679 assert(operands.size() == 2 && "minf takes two operands"); 680 681 // minf(x,x) -> x 682 if (getLhs() == getRhs()) 683 return getRhs(); 684 685 // minf(x, +inf) -> x 686 if (matchPattern(getRhs(), m_PosInfFloat())) 687 return getLhs(); 688 689 return constFoldBinaryOp<FloatAttr>( 690 operands, 691 [](const APFloat &a, const APFloat &b) { return llvm::minimum(a, b); }); 692 } 693 694 //===----------------------------------------------------------------------===// 695 // MinSIOp 696 //===----------------------------------------------------------------------===// 697 698 OpFoldResult MinSIOp::fold(ArrayRef<Attribute> operands) { 699 assert(operands.size() == 2 && "binary operation takes two operands"); 700 701 // minsi(x,x) -> x 702 if (getLhs() == getRhs()) 703 return getRhs(); 704 705 APInt intValue; 706 // minsi(x,MIN_INT) -> MIN_INT 707 if (matchPattern(getRhs(), m_ConstantInt(&intValue)) && 708 intValue.isMinSignedValue()) 709 return getRhs(); 710 711 // minsi(x, MAX_INT) -> x 712 if (matchPattern(getRhs(), m_ConstantInt(&intValue)) && 713 intValue.isMaxSignedValue()) 714 return getLhs(); 715 716 return constFoldBinaryOp<IntegerAttr>(operands, 717 [](const APInt &a, const APInt &b) { 718 return llvm::APIntOps::smin(a, b); 719 }); 720 } 721 722 //===----------------------------------------------------------------------===// 723 // MinUIOp 724 //===----------------------------------------------------------------------===// 725 726 OpFoldResult MinUIOp::fold(ArrayRef<Attribute> operands) { 727 assert(operands.size() == 2 && "binary operation takes two operands"); 728 729 // minui(x,x) -> x 730 if (getLhs() == getRhs()) 731 return getRhs(); 732 733 APInt intValue; 734 // minui(x,MIN_INT) -> MIN_INT 735 if (matchPattern(getRhs(), m_ConstantInt(&intValue)) && intValue.isMinValue()) 736 return getRhs(); 737 738 // minui(x, MAX_INT) -> x 739 if (matchPattern(getRhs(), m_ConstantInt(&intValue)) && intValue.isMaxValue()) 740 return getLhs(); 741 742 return constFoldBinaryOp<IntegerAttr>(operands, 743 [](const APInt &a, const APInt &b) { 744 return llvm::APIntOps::umin(a, b); 745 }); 746 } 747 748 //===----------------------------------------------------------------------===// 749 // MulFOp 750 //===----------------------------------------------------------------------===// 751 752 OpFoldResult arith::MulFOp::fold(ArrayRef<Attribute> operands) { 753 APFloat floatValue(0.0f), inverseValue(0.0f); 754 // mulf(x, 1) -> x 755 if (matchPattern(getRhs(), m_OneFloat())) 756 return getLhs(); 757 758 // mulf(1, x) -> x 759 if (matchPattern(getLhs(), m_OneFloat())) 760 return getRhs(); 761 762 return constFoldBinaryOp<FloatAttr>( 763 operands, [](const APFloat &a, const APFloat &b) { return a * b; }); 764 } 765 766 //===----------------------------------------------------------------------===// 767 // DivFOp 768 //===----------------------------------------------------------------------===// 769 770 OpFoldResult arith::DivFOp::fold(ArrayRef<Attribute> operands) { 771 APFloat floatValue(0.0f), inverseValue(0.0f); 772 // divf(x, 1) -> x 773 if (matchPattern(getRhs(), m_OneFloat())) 774 return getLhs(); 775 776 return constFoldBinaryOp<FloatAttr>( 777 operands, [](const APFloat &a, const APFloat &b) { return a / b; }); 778 } 779 780 //===----------------------------------------------------------------------===// 781 // Utility functions for verifying cast ops 782 //===----------------------------------------------------------------------===// 783 784 template <typename... Types> 785 using type_list = std::tuple<Types...> *; 786 787 /// Returns a non-null type only if the provided type is one of the allowed 788 /// types or one of the allowed shaped types of the allowed types. Returns the 789 /// element type if a valid shaped type is provided. 790 template <typename... ShapedTypes, typename... ElementTypes> 791 static Type getUnderlyingType(Type type, type_list<ShapedTypes...>, 792 type_list<ElementTypes...>) { 793 if (type.isa<ShapedType>() && !type.isa<ShapedTypes...>()) 794 return {}; 795 796 auto underlyingType = getElementTypeOrSelf(type); 797 if (!underlyingType.isa<ElementTypes...>()) 798 return {}; 799 800 return underlyingType; 801 } 802 803 /// Get allowed underlying types for vectors and tensors. 804 template <typename... ElementTypes> 805 static Type getTypeIfLike(Type type) { 806 return getUnderlyingType(type, type_list<VectorType, TensorType>(), 807 type_list<ElementTypes...>()); 808 } 809 810 /// Get allowed underlying types for vectors, tensors, and memrefs. 811 template <typename... ElementTypes> 812 static Type getTypeIfLikeOrMemRef(Type type) { 813 return getUnderlyingType(type, 814 type_list<VectorType, TensorType, MemRefType>(), 815 type_list<ElementTypes...>()); 816 } 817 818 static bool areValidCastInputsAndOutputs(TypeRange inputs, TypeRange outputs) { 819 return inputs.size() == 1 && outputs.size() == 1 && 820 succeeded(verifyCompatibleShapes(inputs.front(), outputs.front())); 821 } 822 823 //===----------------------------------------------------------------------===// 824 // Verifiers for integer and floating point extension/truncation ops 825 //===----------------------------------------------------------------------===// 826 827 // Extend ops can only extend to a wider type. 828 template <typename ValType, typename Op> 829 static LogicalResult verifyExtOp(Op op) { 830 Type srcType = getElementTypeOrSelf(op.getIn().getType()); 831 Type dstType = getElementTypeOrSelf(op.getType()); 832 833 if (srcType.cast<ValType>().getWidth() >= dstType.cast<ValType>().getWidth()) 834 return op.emitError("result type ") 835 << dstType << " must be wider than operand type " << srcType; 836 837 return success(); 838 } 839 840 // Truncate ops can only truncate to a shorter type. 841 template <typename ValType, typename Op> 842 static LogicalResult verifyTruncateOp(Op op) { 843 Type srcType = getElementTypeOrSelf(op.getIn().getType()); 844 Type dstType = getElementTypeOrSelf(op.getType()); 845 846 if (srcType.cast<ValType>().getWidth() <= dstType.cast<ValType>().getWidth()) 847 return op.emitError("result type ") 848 << dstType << " must be shorter than operand type " << srcType; 849 850 return success(); 851 } 852 853 /// Validate a cast that changes the width of a type. 854 template <template <typename> class WidthComparator, typename... ElementTypes> 855 static bool checkWidthChangeCast(TypeRange inputs, TypeRange outputs) { 856 if (!areValidCastInputsAndOutputs(inputs, outputs)) 857 return false; 858 859 auto srcType = getTypeIfLike<ElementTypes...>(inputs.front()); 860 auto dstType = getTypeIfLike<ElementTypes...>(outputs.front()); 861 if (!srcType || !dstType) 862 return false; 863 864 return WidthComparator<unsigned>()(dstType.getIntOrFloatBitWidth(), 865 srcType.getIntOrFloatBitWidth()); 866 } 867 868 //===----------------------------------------------------------------------===// 869 // ExtUIOp 870 //===----------------------------------------------------------------------===// 871 872 OpFoldResult arith::ExtUIOp::fold(ArrayRef<Attribute> operands) { 873 if (auto lhs = operands[0].dyn_cast_or_null<IntegerAttr>()) 874 return IntegerAttr::get( 875 getType(), lhs.getValue().zext(getType().getIntOrFloatBitWidth())); 876 877 if (auto lhs = getIn().getDefiningOp<ExtUIOp>()) { 878 getInMutable().assign(lhs.getIn()); 879 return getResult(); 880 } 881 882 return {}; 883 } 884 885 bool arith::ExtUIOp::areCastCompatible(TypeRange inputs, TypeRange outputs) { 886 return checkWidthChangeCast<std::greater, IntegerType>(inputs, outputs); 887 } 888 889 LogicalResult arith::ExtUIOp::verify() { 890 return verifyExtOp<IntegerType>(*this); 891 } 892 893 //===----------------------------------------------------------------------===// 894 // ExtSIOp 895 //===----------------------------------------------------------------------===// 896 897 OpFoldResult arith::ExtSIOp::fold(ArrayRef<Attribute> operands) { 898 if (auto lhs = operands[0].dyn_cast_or_null<IntegerAttr>()) 899 return IntegerAttr::get( 900 getType(), lhs.getValue().sext(getType().getIntOrFloatBitWidth())); 901 902 if (auto lhs = getIn().getDefiningOp<ExtSIOp>()) { 903 getInMutable().assign(lhs.getIn()); 904 return getResult(); 905 } 906 907 return {}; 908 } 909 910 bool arith::ExtSIOp::areCastCompatible(TypeRange inputs, TypeRange outputs) { 911 return checkWidthChangeCast<std::greater, IntegerType>(inputs, outputs); 912 } 913 914 void arith::ExtSIOp::getCanonicalizationPatterns( 915 RewritePatternSet &patterns, MLIRContext *context) { 916 patterns.insert<ExtSIOfExtUI>(context); 917 } 918 919 LogicalResult arith::ExtSIOp::verify() { 920 return verifyExtOp<IntegerType>(*this); 921 } 922 923 //===----------------------------------------------------------------------===// 924 // ExtFOp 925 //===----------------------------------------------------------------------===// 926 927 bool arith::ExtFOp::areCastCompatible(TypeRange inputs, TypeRange outputs) { 928 return checkWidthChangeCast<std::greater, FloatType>(inputs, outputs); 929 } 930 931 LogicalResult arith::ExtFOp::verify() { return verifyExtOp<FloatType>(*this); } 932 933 //===----------------------------------------------------------------------===// 934 // TruncIOp 935 //===----------------------------------------------------------------------===// 936 937 OpFoldResult arith::TruncIOp::fold(ArrayRef<Attribute> operands) { 938 assert(operands.size() == 1 && "unary operation takes one operand"); 939 940 // trunci(zexti(a)) -> a 941 // trunci(sexti(a)) -> a 942 if (matchPattern(getOperand(), m_Op<arith::ExtUIOp>()) || 943 matchPattern(getOperand(), m_Op<arith::ExtSIOp>())) 944 return getOperand().getDefiningOp()->getOperand(0); 945 946 // trunci(trunci(a)) -> trunci(a)) 947 if (matchPattern(getOperand(), m_Op<arith::TruncIOp>())) { 948 setOperand(getOperand().getDefiningOp()->getOperand(0)); 949 return getResult(); 950 } 951 952 if (!operands[0]) 953 return {}; 954 955 if (auto lhs = operands[0].dyn_cast<IntegerAttr>()) { 956 return IntegerAttr::get( 957 getType(), lhs.getValue().trunc(getType().getIntOrFloatBitWidth())); 958 } 959 960 return {}; 961 } 962 963 bool arith::TruncIOp::areCastCompatible(TypeRange inputs, TypeRange outputs) { 964 return checkWidthChangeCast<std::less, IntegerType>(inputs, outputs); 965 } 966 967 LogicalResult arith::TruncIOp::verify() { 968 return verifyTruncateOp<IntegerType>(*this); 969 } 970 971 //===----------------------------------------------------------------------===// 972 // TruncFOp 973 //===----------------------------------------------------------------------===// 974 975 /// Perform safe const propagation for truncf, i.e. only propagate if FP value 976 /// can be represented without precision loss or rounding. 977 OpFoldResult arith::TruncFOp::fold(ArrayRef<Attribute> operands) { 978 assert(operands.size() == 1 && "unary operation takes one operand"); 979 980 auto constOperand = operands.front(); 981 if (!constOperand || !constOperand.isa<FloatAttr>()) 982 return {}; 983 984 // Convert to target type via 'double'. 985 double sourceValue = 986 constOperand.dyn_cast<FloatAttr>().getValue().convertToDouble(); 987 auto targetAttr = FloatAttr::get(getType(), sourceValue); 988 989 // Propagate if constant's value does not change after truncation. 990 if (sourceValue == targetAttr.getValue().convertToDouble()) 991 return targetAttr; 992 993 return {}; 994 } 995 996 bool arith::TruncFOp::areCastCompatible(TypeRange inputs, TypeRange outputs) { 997 return checkWidthChangeCast<std::less, FloatType>(inputs, outputs); 998 } 999 1000 LogicalResult arith::TruncFOp::verify() { 1001 return verifyTruncateOp<FloatType>(*this); 1002 } 1003 1004 //===----------------------------------------------------------------------===// 1005 // AndIOp 1006 //===----------------------------------------------------------------------===// 1007 1008 void arith::AndIOp::getCanonicalizationPatterns( 1009 RewritePatternSet &patterns, MLIRContext *context) { 1010 patterns.insert<AndOfExtUI, AndOfExtSI>(context); 1011 } 1012 1013 //===----------------------------------------------------------------------===// 1014 // OrIOp 1015 //===----------------------------------------------------------------------===// 1016 1017 void arith::OrIOp::getCanonicalizationPatterns( 1018 RewritePatternSet &patterns, MLIRContext *context) { 1019 patterns.insert<OrOfExtUI, OrOfExtSI>(context); 1020 } 1021 1022 //===----------------------------------------------------------------------===// 1023 // Verifiers for casts between integers and floats. 1024 //===----------------------------------------------------------------------===// 1025 1026 template <typename From, typename To> 1027 static bool checkIntFloatCast(TypeRange inputs, TypeRange outputs) { 1028 if (!areValidCastInputsAndOutputs(inputs, outputs)) 1029 return false; 1030 1031 auto srcType = getTypeIfLike<From>(inputs.front()); 1032 auto dstType = getTypeIfLike<To>(outputs.back()); 1033 1034 return srcType && dstType; 1035 } 1036 1037 //===----------------------------------------------------------------------===// 1038 // UIToFPOp 1039 //===----------------------------------------------------------------------===// 1040 1041 bool arith::UIToFPOp::areCastCompatible(TypeRange inputs, TypeRange outputs) { 1042 return checkIntFloatCast<IntegerType, FloatType>(inputs, outputs); 1043 } 1044 1045 OpFoldResult arith::UIToFPOp::fold(ArrayRef<Attribute> operands) { 1046 if (auto lhs = operands[0].dyn_cast_or_null<IntegerAttr>()) { 1047 const APInt &api = lhs.getValue(); 1048 FloatType floatTy = getType().cast<FloatType>(); 1049 APFloat apf(floatTy.getFloatSemantics(), 1050 APInt::getZero(floatTy.getWidth())); 1051 apf.convertFromAPInt(api, /*IsSigned=*/false, APFloat::rmNearestTiesToEven); 1052 return FloatAttr::get(floatTy, apf); 1053 } 1054 return {}; 1055 } 1056 1057 //===----------------------------------------------------------------------===// 1058 // SIToFPOp 1059 //===----------------------------------------------------------------------===// 1060 1061 bool arith::SIToFPOp::areCastCompatible(TypeRange inputs, TypeRange outputs) { 1062 return checkIntFloatCast<IntegerType, FloatType>(inputs, outputs); 1063 } 1064 1065 OpFoldResult arith::SIToFPOp::fold(ArrayRef<Attribute> operands) { 1066 if (auto lhs = operands[0].dyn_cast_or_null<IntegerAttr>()) { 1067 const APInt &api = lhs.getValue(); 1068 FloatType floatTy = getType().cast<FloatType>(); 1069 APFloat apf(floatTy.getFloatSemantics(), 1070 APInt::getZero(floatTy.getWidth())); 1071 apf.convertFromAPInt(api, /*IsSigned=*/true, APFloat::rmNearestTiesToEven); 1072 return FloatAttr::get(floatTy, apf); 1073 } 1074 return {}; 1075 } 1076 //===----------------------------------------------------------------------===// 1077 // FPToUIOp 1078 //===----------------------------------------------------------------------===// 1079 1080 bool arith::FPToUIOp::areCastCompatible(TypeRange inputs, TypeRange outputs) { 1081 return checkIntFloatCast<FloatType, IntegerType>(inputs, outputs); 1082 } 1083 1084 OpFoldResult arith::FPToUIOp::fold(ArrayRef<Attribute> operands) { 1085 if (auto lhs = operands[0].dyn_cast_or_null<FloatAttr>()) { 1086 const APFloat &apf = lhs.getValue(); 1087 IntegerType intTy = getType().cast<IntegerType>(); 1088 bool ignored; 1089 APSInt api(intTy.getWidth(), /*isUnsigned=*/true); 1090 if (APFloat::opInvalidOp == 1091 apf.convertToInteger(api, APFloat::rmTowardZero, &ignored)) { 1092 // Undefined behavior invoked - the destination type can't represent 1093 // the input constant. 1094 return {}; 1095 } 1096 return IntegerAttr::get(getType(), api); 1097 } 1098 1099 return {}; 1100 } 1101 1102 //===----------------------------------------------------------------------===// 1103 // FPToSIOp 1104 //===----------------------------------------------------------------------===// 1105 1106 bool arith::FPToSIOp::areCastCompatible(TypeRange inputs, TypeRange outputs) { 1107 return checkIntFloatCast<FloatType, IntegerType>(inputs, outputs); 1108 } 1109 1110 OpFoldResult arith::FPToSIOp::fold(ArrayRef<Attribute> operands) { 1111 if (auto lhs = operands[0].dyn_cast_or_null<FloatAttr>()) { 1112 const APFloat &apf = lhs.getValue(); 1113 IntegerType intTy = getType().cast<IntegerType>(); 1114 bool ignored; 1115 APSInt api(intTy.getWidth(), /*isUnsigned=*/false); 1116 if (APFloat::opInvalidOp == 1117 apf.convertToInteger(api, APFloat::rmTowardZero, &ignored)) { 1118 // Undefined behavior invoked - the destination type can't represent 1119 // the input constant. 1120 return {}; 1121 } 1122 return IntegerAttr::get(getType(), api); 1123 } 1124 1125 return {}; 1126 } 1127 1128 //===----------------------------------------------------------------------===// 1129 // IndexCastOp 1130 //===----------------------------------------------------------------------===// 1131 1132 bool arith::IndexCastOp::areCastCompatible(TypeRange inputs, 1133 TypeRange outputs) { 1134 if (!areValidCastInputsAndOutputs(inputs, outputs)) 1135 return false; 1136 1137 auto srcType = getTypeIfLikeOrMemRef<IntegerType, IndexType>(inputs.front()); 1138 auto dstType = getTypeIfLikeOrMemRef<IntegerType, IndexType>(outputs.front()); 1139 if (!srcType || !dstType) 1140 return false; 1141 1142 return (srcType.isIndex() && dstType.isSignlessInteger()) || 1143 (srcType.isSignlessInteger() && dstType.isIndex()); 1144 } 1145 1146 OpFoldResult arith::IndexCastOp::fold(ArrayRef<Attribute> operands) { 1147 // index_cast(constant) -> constant 1148 // A little hack because we go through int. Otherwise, the size of the 1149 // constant might need to change. 1150 if (auto value = operands[0].dyn_cast_or_null<IntegerAttr>()) 1151 return IntegerAttr::get(getType(), value.getInt()); 1152 1153 return {}; 1154 } 1155 1156 void arith::IndexCastOp::getCanonicalizationPatterns( 1157 RewritePatternSet &patterns, MLIRContext *context) { 1158 patterns.insert<IndexCastOfIndexCast, IndexCastOfExtSI>(context); 1159 } 1160 1161 //===----------------------------------------------------------------------===// 1162 // BitcastOp 1163 //===----------------------------------------------------------------------===// 1164 1165 bool arith::BitcastOp::areCastCompatible(TypeRange inputs, TypeRange outputs) { 1166 if (!areValidCastInputsAndOutputs(inputs, outputs)) 1167 return false; 1168 1169 auto srcType = 1170 getTypeIfLikeOrMemRef<IntegerType, IndexType, FloatType>(inputs.front()); 1171 auto dstType = 1172 getTypeIfLikeOrMemRef<IntegerType, IndexType, FloatType>(outputs.front()); 1173 if (!srcType || !dstType) 1174 return false; 1175 1176 return srcType.getIntOrFloatBitWidth() == dstType.getIntOrFloatBitWidth(); 1177 } 1178 1179 OpFoldResult arith::BitcastOp::fold(ArrayRef<Attribute> operands) { 1180 assert(operands.size() == 1 && "bitcast op expects 1 operand"); 1181 1182 auto resType = getType(); 1183 auto operand = operands[0]; 1184 if (!operand) 1185 return {}; 1186 1187 /// Bitcast dense elements. 1188 if (auto denseAttr = operand.dyn_cast_or_null<DenseElementsAttr>()) 1189 return denseAttr.bitcast(resType.cast<ShapedType>().getElementType()); 1190 /// Other shaped types unhandled. 1191 if (resType.isa<ShapedType>()) 1192 return {}; 1193 1194 /// Bitcast integer or float to integer or float. 1195 APInt bits = operand.isa<FloatAttr>() 1196 ? operand.cast<FloatAttr>().getValue().bitcastToAPInt() 1197 : operand.cast<IntegerAttr>().getValue(); 1198 1199 if (auto resFloatType = resType.dyn_cast<FloatType>()) 1200 return FloatAttr::get(resType, 1201 APFloat(resFloatType.getFloatSemantics(), bits)); 1202 return IntegerAttr::get(resType, bits); 1203 } 1204 1205 void arith::BitcastOp::getCanonicalizationPatterns( 1206 RewritePatternSet &patterns, MLIRContext *context) { 1207 patterns.insert<BitcastOfBitcast>(context); 1208 } 1209 1210 //===----------------------------------------------------------------------===// 1211 // Helpers for compare ops 1212 //===----------------------------------------------------------------------===// 1213 1214 /// Return the type of the same shape (scalar, vector or tensor) containing i1. 1215 static Type getI1SameShape(Type type) { 1216 auto i1Type = IntegerType::get(type.getContext(), 1); 1217 if (auto tensorType = type.dyn_cast<RankedTensorType>()) 1218 return RankedTensorType::get(tensorType.getShape(), i1Type); 1219 if (type.isa<UnrankedTensorType>()) 1220 return UnrankedTensorType::get(i1Type); 1221 if (auto vectorType = type.dyn_cast<VectorType>()) 1222 return VectorType::get(vectorType.getShape(), i1Type, 1223 vectorType.getNumScalableDims()); 1224 return i1Type; 1225 } 1226 1227 //===----------------------------------------------------------------------===// 1228 // CmpIOp 1229 //===----------------------------------------------------------------------===// 1230 1231 /// Compute `lhs` `pred` `rhs`, where `pred` is one of the known integer 1232 /// comparison predicates. 1233 bool mlir::arith::applyCmpPredicate(arith::CmpIPredicate predicate, 1234 const APInt &lhs, const APInt &rhs) { 1235 switch (predicate) { 1236 case arith::CmpIPredicate::eq: 1237 return lhs.eq(rhs); 1238 case arith::CmpIPredicate::ne: 1239 return lhs.ne(rhs); 1240 case arith::CmpIPredicate::slt: 1241 return lhs.slt(rhs); 1242 case arith::CmpIPredicate::sle: 1243 return lhs.sle(rhs); 1244 case arith::CmpIPredicate::sgt: 1245 return lhs.sgt(rhs); 1246 case arith::CmpIPredicate::sge: 1247 return lhs.sge(rhs); 1248 case arith::CmpIPredicate::ult: 1249 return lhs.ult(rhs); 1250 case arith::CmpIPredicate::ule: 1251 return lhs.ule(rhs); 1252 case arith::CmpIPredicate::ugt: 1253 return lhs.ugt(rhs); 1254 case arith::CmpIPredicate::uge: 1255 return lhs.uge(rhs); 1256 } 1257 llvm_unreachable("unknown cmpi predicate kind"); 1258 } 1259 1260 /// Returns true if the predicate is true for two equal operands. 1261 static bool applyCmpPredicateToEqualOperands(arith::CmpIPredicate predicate) { 1262 switch (predicate) { 1263 case arith::CmpIPredicate::eq: 1264 case arith::CmpIPredicate::sle: 1265 case arith::CmpIPredicate::sge: 1266 case arith::CmpIPredicate::ule: 1267 case arith::CmpIPredicate::uge: 1268 return true; 1269 case arith::CmpIPredicate::ne: 1270 case arith::CmpIPredicate::slt: 1271 case arith::CmpIPredicate::sgt: 1272 case arith::CmpIPredicate::ult: 1273 case arith::CmpIPredicate::ugt: 1274 return false; 1275 } 1276 llvm_unreachable("unknown cmpi predicate kind"); 1277 } 1278 1279 static Attribute getBoolAttribute(Type type, MLIRContext *ctx, bool value) { 1280 auto boolAttr = BoolAttr::get(ctx, value); 1281 ShapedType shapedType = type.dyn_cast_or_null<ShapedType>(); 1282 if (!shapedType) 1283 return boolAttr; 1284 return DenseElementsAttr::get(shapedType, boolAttr); 1285 } 1286 1287 OpFoldResult arith::CmpIOp::fold(ArrayRef<Attribute> operands) { 1288 assert(operands.size() == 2 && "cmpi takes two operands"); 1289 1290 // cmpi(pred, x, x) 1291 if (getLhs() == getRhs()) { 1292 auto val = applyCmpPredicateToEqualOperands(getPredicate()); 1293 return getBoolAttribute(getType(), getContext(), val); 1294 } 1295 1296 if (matchPattern(getRhs(), m_Zero())) { 1297 if (auto extOp = getLhs().getDefiningOp<ExtSIOp>()) { 1298 if (extOp.getOperand().getType().cast<IntegerType>().getWidth() == 1) { 1299 // extsi(%x : i1 -> iN) != 0 -> %x 1300 if (getPredicate() == arith::CmpIPredicate::ne) { 1301 return extOp.getOperand(); 1302 } 1303 } 1304 } 1305 if (auto extOp = getLhs().getDefiningOp<ExtUIOp>()) { 1306 if (extOp.getOperand().getType().cast<IntegerType>().getWidth() == 1) { 1307 // extui(%x : i1 -> iN) != 0 -> %x 1308 if (getPredicate() == arith::CmpIPredicate::ne) { 1309 return extOp.getOperand(); 1310 } 1311 } 1312 } 1313 } 1314 1315 auto lhs = operands.front().dyn_cast_or_null<IntegerAttr>(); 1316 auto rhs = operands.back().dyn_cast_or_null<IntegerAttr>(); 1317 if (!lhs || !rhs) 1318 return {}; 1319 1320 auto val = applyCmpPredicate(getPredicate(), lhs.getValue(), rhs.getValue()); 1321 return BoolAttr::get(getContext(), val); 1322 } 1323 1324 //===----------------------------------------------------------------------===// 1325 // CmpFOp 1326 //===----------------------------------------------------------------------===// 1327 1328 /// Compute `lhs` `pred` `rhs`, where `pred` is one of the known floating point 1329 /// comparison predicates. 1330 bool mlir::arith::applyCmpPredicate(arith::CmpFPredicate predicate, 1331 const APFloat &lhs, const APFloat &rhs) { 1332 auto cmpResult = lhs.compare(rhs); 1333 switch (predicate) { 1334 case arith::CmpFPredicate::AlwaysFalse: 1335 return false; 1336 case arith::CmpFPredicate::OEQ: 1337 return cmpResult == APFloat::cmpEqual; 1338 case arith::CmpFPredicate::OGT: 1339 return cmpResult == APFloat::cmpGreaterThan; 1340 case arith::CmpFPredicate::OGE: 1341 return cmpResult == APFloat::cmpGreaterThan || 1342 cmpResult == APFloat::cmpEqual; 1343 case arith::CmpFPredicate::OLT: 1344 return cmpResult == APFloat::cmpLessThan; 1345 case arith::CmpFPredicate::OLE: 1346 return cmpResult == APFloat::cmpLessThan || cmpResult == APFloat::cmpEqual; 1347 case arith::CmpFPredicate::ONE: 1348 return cmpResult != APFloat::cmpUnordered && cmpResult != APFloat::cmpEqual; 1349 case arith::CmpFPredicate::ORD: 1350 return cmpResult != APFloat::cmpUnordered; 1351 case arith::CmpFPredicate::UEQ: 1352 return cmpResult == APFloat::cmpUnordered || cmpResult == APFloat::cmpEqual; 1353 case arith::CmpFPredicate::UGT: 1354 return cmpResult == APFloat::cmpUnordered || 1355 cmpResult == APFloat::cmpGreaterThan; 1356 case arith::CmpFPredicate::UGE: 1357 return cmpResult == APFloat::cmpUnordered || 1358 cmpResult == APFloat::cmpGreaterThan || 1359 cmpResult == APFloat::cmpEqual; 1360 case arith::CmpFPredicate::ULT: 1361 return cmpResult == APFloat::cmpUnordered || 1362 cmpResult == APFloat::cmpLessThan; 1363 case arith::CmpFPredicate::ULE: 1364 return cmpResult == APFloat::cmpUnordered || 1365 cmpResult == APFloat::cmpLessThan || cmpResult == APFloat::cmpEqual; 1366 case arith::CmpFPredicate::UNE: 1367 return cmpResult != APFloat::cmpEqual; 1368 case arith::CmpFPredicate::UNO: 1369 return cmpResult == APFloat::cmpUnordered; 1370 case arith::CmpFPredicate::AlwaysTrue: 1371 return true; 1372 } 1373 llvm_unreachable("unknown cmpf predicate kind"); 1374 } 1375 1376 OpFoldResult arith::CmpFOp::fold(ArrayRef<Attribute> operands) { 1377 assert(operands.size() == 2 && "cmpf takes two operands"); 1378 1379 auto lhs = operands.front().dyn_cast_or_null<FloatAttr>(); 1380 auto rhs = operands.back().dyn_cast_or_null<FloatAttr>(); 1381 1382 // If one operand is NaN, making them both NaN does not change the result. 1383 if (lhs && lhs.getValue().isNaN()) 1384 rhs = lhs; 1385 if (rhs && rhs.getValue().isNaN()) 1386 lhs = rhs; 1387 1388 if (!lhs || !rhs) 1389 return {}; 1390 1391 auto val = applyCmpPredicate(getPredicate(), lhs.getValue(), rhs.getValue()); 1392 return BoolAttr::get(getContext(), val); 1393 } 1394 1395 //===----------------------------------------------------------------------===// 1396 // SelectOp 1397 //===----------------------------------------------------------------------===// 1398 1399 // Transforms a select of a boolean to arithmetic operations 1400 // 1401 // arith.select %arg, %x, %y : i1 1402 // 1403 // becomes 1404 // 1405 // and(%arg, %x) or and(!%arg, %y) 1406 struct SelectI1Simplify : public OpRewritePattern<arith::SelectOp> { 1407 using OpRewritePattern<arith::SelectOp>::OpRewritePattern; 1408 1409 LogicalResult matchAndRewrite(arith::SelectOp op, 1410 PatternRewriter &rewriter) const override { 1411 if (!op.getType().isInteger(1)) 1412 return failure(); 1413 1414 Value falseConstant = 1415 rewriter.create<arith::ConstantIntOp>(op.getLoc(), true, 1); 1416 Value notCondition = rewriter.create<arith::XOrIOp>( 1417 op.getLoc(), op.getCondition(), falseConstant); 1418 1419 Value trueVal = rewriter.create<arith::AndIOp>( 1420 op.getLoc(), op.getCondition(), op.getTrueValue()); 1421 Value falseVal = rewriter.create<arith::AndIOp>(op.getLoc(), notCondition, 1422 op.getFalseValue()); 1423 rewriter.replaceOpWithNewOp<arith::OrIOp>(op, trueVal, falseVal); 1424 return success(); 1425 } 1426 }; 1427 1428 // select %arg, %c1, %c0 => extui %arg 1429 struct SelectToExtUI : public OpRewritePattern<arith::SelectOp> { 1430 using OpRewritePattern<arith::SelectOp>::OpRewritePattern; 1431 1432 LogicalResult matchAndRewrite(arith::SelectOp op, 1433 PatternRewriter &rewriter) const override { 1434 // Cannot extui i1 to i1, or i1 to f32 1435 if (!op.getType().isa<IntegerType>() || op.getType().isInteger(1)) 1436 return failure(); 1437 1438 // select %x, c1, %c0 => extui %arg 1439 if (matchPattern(op.getTrueValue(), m_One())) 1440 if (matchPattern(op.getFalseValue(), m_Zero())) { 1441 rewriter.replaceOpWithNewOp<arith::ExtUIOp>(op, op.getType(), 1442 op.getCondition()); 1443 return success(); 1444 } 1445 1446 // select %x, c0, %c1 => extui (xor %arg, true) 1447 if (matchPattern(op.getTrueValue(), m_Zero())) 1448 if (matchPattern(op.getFalseValue(), m_One())) { 1449 rewriter.replaceOpWithNewOp<arith::ExtUIOp>( 1450 op, op.getType(), 1451 rewriter.create<arith::XOrIOp>( 1452 op.getLoc(), op.getCondition(), 1453 rewriter.create<arith::ConstantIntOp>( 1454 op.getLoc(), 1, op.getCondition().getType()))); 1455 return success(); 1456 } 1457 1458 return failure(); 1459 } 1460 }; 1461 1462 void arith::SelectOp::getCanonicalizationPatterns(RewritePatternSet &results, 1463 MLIRContext *context) { 1464 results.insert<SelectI1Simplify, SelectToExtUI>(context); 1465 } 1466 1467 OpFoldResult arith::SelectOp::fold(ArrayRef<Attribute> operands) { 1468 Value trueVal = getTrueValue(); 1469 Value falseVal = getFalseValue(); 1470 if (trueVal == falseVal) 1471 return trueVal; 1472 1473 Value condition = getCondition(); 1474 1475 // select true, %0, %1 => %0 1476 if (matchPattern(condition, m_One())) 1477 return trueVal; 1478 1479 // select false, %0, %1 => %1 1480 if (matchPattern(condition, m_Zero())) 1481 return falseVal; 1482 1483 // select %x, true, false => %x 1484 if (getType().isInteger(1)) 1485 if (matchPattern(getTrueValue(), m_One())) 1486 if (matchPattern(getFalseValue(), m_Zero())) 1487 return condition; 1488 1489 if (auto cmp = dyn_cast_or_null<arith::CmpIOp>(condition.getDefiningOp())) { 1490 auto pred = cmp.getPredicate(); 1491 if (pred == arith::CmpIPredicate::eq || pred == arith::CmpIPredicate::ne) { 1492 auto cmpLhs = cmp.getLhs(); 1493 auto cmpRhs = cmp.getRhs(); 1494 1495 // %0 = arith.cmpi eq, %arg0, %arg1 1496 // %1 = arith.select %0, %arg0, %arg1 => %arg1 1497 1498 // %0 = arith.cmpi ne, %arg0, %arg1 1499 // %1 = arith.select %0, %arg0, %arg1 => %arg0 1500 1501 if ((cmpLhs == trueVal && cmpRhs == falseVal) || 1502 (cmpRhs == trueVal && cmpLhs == falseVal)) 1503 return pred == arith::CmpIPredicate::ne ? trueVal : falseVal; 1504 } 1505 } 1506 return nullptr; 1507 } 1508 1509 ParseResult SelectOp::parse(OpAsmParser &parser, OperationState &result) { 1510 Type conditionType, resultType; 1511 SmallVector<OpAsmParser::OperandType, 3> operands; 1512 if (parser.parseOperandList(operands, /*requiredOperandCount=*/3) || 1513 parser.parseOptionalAttrDict(result.attributes) || 1514 parser.parseColonType(resultType)) 1515 return failure(); 1516 1517 // Check for the explicit condition type if this is a masked tensor or vector. 1518 if (succeeded(parser.parseOptionalComma())) { 1519 conditionType = resultType; 1520 if (parser.parseType(resultType)) 1521 return failure(); 1522 } else { 1523 conditionType = parser.getBuilder().getI1Type(); 1524 } 1525 1526 result.addTypes(resultType); 1527 return parser.resolveOperands(operands, 1528 {conditionType, resultType, resultType}, 1529 parser.getNameLoc(), result.operands); 1530 } 1531 1532 void arith::SelectOp::print(OpAsmPrinter &p) { 1533 p << " " << getOperands(); 1534 p.printOptionalAttrDict((*this)->getAttrs()); 1535 p << " : "; 1536 if (ShapedType condType = getCondition().getType().dyn_cast<ShapedType>()) 1537 p << condType << ", "; 1538 p << getType(); 1539 } 1540 1541 LogicalResult arith::SelectOp::verify() { 1542 Type conditionType = getCondition().getType(); 1543 if (conditionType.isSignlessInteger(1)) 1544 return success(); 1545 1546 // If the result type is a vector or tensor, the type can be a mask with the 1547 // same elements. 1548 Type resultType = getType(); 1549 if (!resultType.isa<TensorType, VectorType>()) 1550 return emitOpError() << "expected condition to be a signless i1, but got " 1551 << conditionType; 1552 Type shapedConditionType = getI1SameShape(resultType); 1553 if (conditionType != shapedConditionType) { 1554 return emitOpError() << "expected condition type to have the same shape " 1555 "as the result type, expected " 1556 << shapedConditionType << ", but got " 1557 << conditionType; 1558 } 1559 return success(); 1560 } 1561 1562 //===----------------------------------------------------------------------===// 1563 // Atomic Enum 1564 //===----------------------------------------------------------------------===// 1565 1566 /// Returns the identity value attribute associated with an AtomicRMWKind op. 1567 Attribute mlir::arith::getIdentityValueAttr(AtomicRMWKind kind, Type resultType, 1568 OpBuilder &builder, Location loc) { 1569 switch (kind) { 1570 case AtomicRMWKind::maxf: 1571 return builder.getFloatAttr( 1572 resultType, 1573 APFloat::getInf(resultType.cast<FloatType>().getFloatSemantics(), 1574 /*Negative=*/true)); 1575 case AtomicRMWKind::addf: 1576 case AtomicRMWKind::addi: 1577 case AtomicRMWKind::maxu: 1578 case AtomicRMWKind::ori: 1579 return builder.getZeroAttr(resultType); 1580 case AtomicRMWKind::andi: 1581 return builder.getIntegerAttr( 1582 resultType, 1583 APInt::getAllOnes(resultType.cast<IntegerType>().getWidth())); 1584 case AtomicRMWKind::maxs: 1585 return builder.getIntegerAttr( 1586 resultType, 1587 APInt::getSignedMinValue(resultType.cast<IntegerType>().getWidth())); 1588 case AtomicRMWKind::minf: 1589 return builder.getFloatAttr( 1590 resultType, 1591 APFloat::getInf(resultType.cast<FloatType>().getFloatSemantics(), 1592 /*Negative=*/false)); 1593 case AtomicRMWKind::mins: 1594 return builder.getIntegerAttr( 1595 resultType, 1596 APInt::getSignedMaxValue(resultType.cast<IntegerType>().getWidth())); 1597 case AtomicRMWKind::minu: 1598 return builder.getIntegerAttr( 1599 resultType, 1600 APInt::getMaxValue(resultType.cast<IntegerType>().getWidth())); 1601 case AtomicRMWKind::muli: 1602 return builder.getIntegerAttr(resultType, 1); 1603 case AtomicRMWKind::mulf: 1604 return builder.getFloatAttr(resultType, 1); 1605 // TODO: Add remaining reduction operations. 1606 default: 1607 (void)emitOptionalError(loc, "Reduction operation type not supported"); 1608 break; 1609 } 1610 return nullptr; 1611 } 1612 1613 /// Returns the identity value associated with an AtomicRMWKind op. 1614 Value mlir::arith::getIdentityValue(AtomicRMWKind op, Type resultType, 1615 OpBuilder &builder, Location loc) { 1616 Attribute attr = getIdentityValueAttr(op, resultType, builder, loc); 1617 return builder.create<arith::ConstantOp>(loc, attr); 1618 } 1619 1620 /// Return the value obtained by applying the reduction operation kind 1621 /// associated with a binary AtomicRMWKind op to `lhs` and `rhs`. 1622 Value mlir::arith::getReductionOp(AtomicRMWKind op, OpBuilder &builder, 1623 Location loc, Value lhs, Value rhs) { 1624 switch (op) { 1625 case AtomicRMWKind::addf: 1626 return builder.create<arith::AddFOp>(loc, lhs, rhs); 1627 case AtomicRMWKind::addi: 1628 return builder.create<arith::AddIOp>(loc, lhs, rhs); 1629 case AtomicRMWKind::mulf: 1630 return builder.create<arith::MulFOp>(loc, lhs, rhs); 1631 case AtomicRMWKind::muli: 1632 return builder.create<arith::MulIOp>(loc, lhs, rhs); 1633 case AtomicRMWKind::maxf: 1634 return builder.create<arith::MaxFOp>(loc, lhs, rhs); 1635 case AtomicRMWKind::minf: 1636 return builder.create<arith::MinFOp>(loc, lhs, rhs); 1637 case AtomicRMWKind::maxs: 1638 return builder.create<arith::MaxSIOp>(loc, lhs, rhs); 1639 case AtomicRMWKind::mins: 1640 return builder.create<arith::MinSIOp>(loc, lhs, rhs); 1641 case AtomicRMWKind::maxu: 1642 return builder.create<arith::MaxUIOp>(loc, lhs, rhs); 1643 case AtomicRMWKind::minu: 1644 return builder.create<arith::MinUIOp>(loc, lhs, rhs); 1645 case AtomicRMWKind::ori: 1646 return builder.create<arith::OrIOp>(loc, lhs, rhs); 1647 case AtomicRMWKind::andi: 1648 return builder.create<arith::AndIOp>(loc, lhs, rhs); 1649 // TODO: Add remaining reduction operations. 1650 default: 1651 (void)emitOptionalError(loc, "Reduction operation type not supported"); 1652 break; 1653 } 1654 return nullptr; 1655 } 1656 1657 //===----------------------------------------------------------------------===// 1658 // TableGen'd op method definitions 1659 //===----------------------------------------------------------------------===// 1660 1661 #define GET_OP_CLASSES 1662 #include "mlir/Dialect/Arithmetic/IR/ArithmeticOps.cpp.inc" 1663 1664 //===----------------------------------------------------------------------===// 1665 // TableGen'd enum attribute definitions 1666 //===----------------------------------------------------------------------===// 1667 1668 #include "mlir/Dialect/Arithmetic/IR/ArithmeticOpsEnums.cpp.inc" 1669