1 //===- ArithmeticOps.cpp - MLIR Arithmetic dialect ops implementation -----===// 2 // 3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. 4 // See https://llvm.org/LICENSE.txt for license information. 5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception 6 // 7 //===----------------------------------------------------------------------===// 8 9 #include "mlir/Dialect/Arithmetic/IR/Arithmetic.h" 10 #include "mlir/Dialect/CommonFolders.h" 11 #include "mlir/IR/Builders.h" 12 #include "mlir/IR/Matchers.h" 13 #include "mlir/IR/OpImplementation.h" 14 #include "mlir/IR/PatternMatch.h" 15 #include "mlir/IR/TypeUtilities.h" 16 17 #include "llvm/ADT/APSInt.h" 18 19 using namespace mlir; 20 using namespace mlir::arith; 21 22 //===----------------------------------------------------------------------===// 23 // Pattern helpers 24 //===----------------------------------------------------------------------===// 25 26 static IntegerAttr addIntegerAttrs(PatternRewriter &builder, Value res, 27 Attribute lhs, Attribute rhs) { 28 return builder.getIntegerAttr(res.getType(), 29 lhs.cast<IntegerAttr>().getInt() + 30 rhs.cast<IntegerAttr>().getInt()); 31 } 32 33 static IntegerAttr subIntegerAttrs(PatternRewriter &builder, Value res, 34 Attribute lhs, Attribute rhs) { 35 return builder.getIntegerAttr(res.getType(), 36 lhs.cast<IntegerAttr>().getInt() - 37 rhs.cast<IntegerAttr>().getInt()); 38 } 39 40 /// Invert an integer comparison predicate. 41 static arith::CmpIPredicate invertPredicate(arith::CmpIPredicate pred) { 42 switch (pred) { 43 case arith::CmpIPredicate::eq: 44 return arith::CmpIPredicate::ne; 45 case arith::CmpIPredicate::ne: 46 return arith::CmpIPredicate::eq; 47 case arith::CmpIPredicate::slt: 48 return arith::CmpIPredicate::sge; 49 case arith::CmpIPredicate::sle: 50 return arith::CmpIPredicate::sgt; 51 case arith::CmpIPredicate::sgt: 52 return arith::CmpIPredicate::sle; 53 case arith::CmpIPredicate::sge: 54 return arith::CmpIPredicate::slt; 55 case arith::CmpIPredicate::ult: 56 return arith::CmpIPredicate::uge; 57 case arith::CmpIPredicate::ule: 58 return arith::CmpIPredicate::ugt; 59 case arith::CmpIPredicate::ugt: 60 return arith::CmpIPredicate::ule; 61 case arith::CmpIPredicate::uge: 62 return arith::CmpIPredicate::ult; 63 } 64 llvm_unreachable("unknown cmpi predicate kind"); 65 } 66 67 static arith::CmpIPredicateAttr invertPredicate(arith::CmpIPredicateAttr pred) { 68 return arith::CmpIPredicateAttr::get(pred.getContext(), 69 invertPredicate(pred.getValue())); 70 } 71 72 //===----------------------------------------------------------------------===// 73 // TableGen'd canonicalization patterns 74 //===----------------------------------------------------------------------===// 75 76 namespace { 77 #include "ArithmeticCanonicalization.inc" 78 } // namespace 79 80 //===----------------------------------------------------------------------===// 81 // ConstantOp 82 //===----------------------------------------------------------------------===// 83 84 void arith::ConstantOp::getAsmResultNames( 85 function_ref<void(Value, StringRef)> setNameFn) { 86 auto type = getType(); 87 if (auto intCst = getValue().dyn_cast<IntegerAttr>()) { 88 auto intType = type.dyn_cast<IntegerType>(); 89 90 // Sugar i1 constants with 'true' and 'false'. 91 if (intType && intType.getWidth() == 1) 92 return setNameFn(getResult(), (intCst.getInt() ? "true" : "false")); 93 94 // Otherwise, build a compex name with the value and type. 95 SmallString<32> specialNameBuffer; 96 llvm::raw_svector_ostream specialName(specialNameBuffer); 97 specialName << 'c' << intCst.getInt(); 98 if (intType) 99 specialName << '_' << type; 100 setNameFn(getResult(), specialName.str()); 101 } else { 102 setNameFn(getResult(), "cst"); 103 } 104 } 105 106 /// TODO: disallow arith.constant to return anything other than signless integer 107 /// or float like. 108 static LogicalResult verify(arith::ConstantOp op) { 109 auto type = op.getType(); 110 // The value's type must match the return type. 111 if (op.getValue().getType() != type) { 112 return op.emitOpError() << "value type " << op.getValue().getType() 113 << " must match return type: " << type; 114 } 115 // Integer values must be signless. 116 if (type.isa<IntegerType>() && !type.cast<IntegerType>().isSignless()) 117 return op.emitOpError("integer return type must be signless"); 118 // Any float or elements attribute are acceptable. 119 if (!op.getValue().isa<IntegerAttr, FloatAttr, ElementsAttr>()) { 120 return op.emitOpError( 121 "value must be an integer, float, or elements attribute"); 122 } 123 return success(); 124 } 125 126 bool arith::ConstantOp::isBuildableWith(Attribute value, Type type) { 127 // The value's type must be the same as the provided type. 128 if (value.getType() != type) 129 return false; 130 // Integer values must be signless. 131 if (type.isa<IntegerType>() && !type.cast<IntegerType>().isSignless()) 132 return false; 133 // Integer, float, and element attributes are buildable. 134 return value.isa<IntegerAttr, FloatAttr, ElementsAttr>(); 135 } 136 137 OpFoldResult arith::ConstantOp::fold(ArrayRef<Attribute> operands) { 138 return getValue(); 139 } 140 141 void arith::ConstantIntOp::build(OpBuilder &builder, OperationState &result, 142 int64_t value, unsigned width) { 143 auto type = builder.getIntegerType(width); 144 arith::ConstantOp::build(builder, result, type, 145 builder.getIntegerAttr(type, value)); 146 } 147 148 void arith::ConstantIntOp::build(OpBuilder &builder, OperationState &result, 149 int64_t value, Type type) { 150 assert(type.isSignlessInteger() && 151 "ConstantIntOp can only have signless integer type values"); 152 arith::ConstantOp::build(builder, result, type, 153 builder.getIntegerAttr(type, value)); 154 } 155 156 bool arith::ConstantIntOp::classof(Operation *op) { 157 if (auto constOp = dyn_cast_or_null<arith::ConstantOp>(op)) 158 return constOp.getType().isSignlessInteger(); 159 return false; 160 } 161 162 void arith::ConstantFloatOp::build(OpBuilder &builder, OperationState &result, 163 const APFloat &value, FloatType type) { 164 arith::ConstantOp::build(builder, result, type, 165 builder.getFloatAttr(type, value)); 166 } 167 168 bool arith::ConstantFloatOp::classof(Operation *op) { 169 if (auto constOp = dyn_cast_or_null<arith::ConstantOp>(op)) 170 return constOp.getType().isa<FloatType>(); 171 return false; 172 } 173 174 void arith::ConstantIndexOp::build(OpBuilder &builder, OperationState &result, 175 int64_t value) { 176 arith::ConstantOp::build(builder, result, builder.getIndexType(), 177 builder.getIndexAttr(value)); 178 } 179 180 bool arith::ConstantIndexOp::classof(Operation *op) { 181 if (auto constOp = dyn_cast_or_null<arith::ConstantOp>(op)) 182 return constOp.getType().isIndex(); 183 return false; 184 } 185 186 //===----------------------------------------------------------------------===// 187 // AddIOp 188 //===----------------------------------------------------------------------===// 189 190 OpFoldResult arith::AddIOp::fold(ArrayRef<Attribute> operands) { 191 // addi(x, 0) -> x 192 if (matchPattern(getRhs(), m_Zero())) 193 return getLhs(); 194 195 return constFoldBinaryOp<IntegerAttr>(operands, 196 [](APInt a, APInt b) { return a + b; }); 197 } 198 199 void arith::AddIOp::getCanonicalizationPatterns( 200 OwningRewritePatternList &patterns, MLIRContext *context) { 201 patterns.insert<AddIAddConstant, AddISubConstantRHS, AddISubConstantLHS>( 202 context); 203 } 204 205 //===----------------------------------------------------------------------===// 206 // SubIOp 207 //===----------------------------------------------------------------------===// 208 209 OpFoldResult arith::SubIOp::fold(ArrayRef<Attribute> operands) { 210 // subi(x,x) -> 0 211 if (getOperand(0) == getOperand(1)) 212 return Builder(getContext()).getZeroAttr(getType()); 213 // subi(x,0) -> x 214 if (matchPattern(getRhs(), m_Zero())) 215 return getLhs(); 216 217 return constFoldBinaryOp<IntegerAttr>(operands, 218 [](APInt a, APInt b) { return a - b; }); 219 } 220 221 void arith::SubIOp::getCanonicalizationPatterns( 222 OwningRewritePatternList &patterns, MLIRContext *context) { 223 patterns.insert<SubIRHSAddConstant, SubILHSAddConstant, SubIRHSSubConstantRHS, 224 SubIRHSSubConstantLHS, SubILHSSubConstantRHS, 225 SubILHSSubConstantLHS>(context); 226 } 227 228 //===----------------------------------------------------------------------===// 229 // MulIOp 230 //===----------------------------------------------------------------------===// 231 232 OpFoldResult arith::MulIOp::fold(ArrayRef<Attribute> operands) { 233 // muli(x, 0) -> 0 234 if (matchPattern(getRhs(), m_Zero())) 235 return getRhs(); 236 // muli(x, 1) -> x 237 if (matchPattern(getRhs(), m_One())) 238 return getOperand(0); 239 // TODO: Handle the overflow case. 240 241 // default folder 242 return constFoldBinaryOp<IntegerAttr>(operands, 243 [](APInt a, APInt b) { return a * b; }); 244 } 245 246 //===----------------------------------------------------------------------===// 247 // DivUIOp 248 //===----------------------------------------------------------------------===// 249 250 OpFoldResult arith::DivUIOp::fold(ArrayRef<Attribute> operands) { 251 // Don't fold if it would require a division by zero. 252 bool div0 = false; 253 auto result = constFoldBinaryOp<IntegerAttr>(operands, [&](APInt a, APInt b) { 254 if (div0 || !b) { 255 div0 = true; 256 return a; 257 } 258 return a.udiv(b); 259 }); 260 261 // Fold out division by one. Assumes all tensors of all ones are splats. 262 if (auto rhs = operands[1].dyn_cast_or_null<IntegerAttr>()) { 263 if (rhs.getValue() == 1) 264 return getLhs(); 265 } else if (auto rhs = operands[1].dyn_cast_or_null<SplatElementsAttr>()) { 266 if (rhs.getSplatValue<IntegerAttr>().getValue() == 1) 267 return getLhs(); 268 } 269 270 return div0 ? Attribute() : result; 271 } 272 273 //===----------------------------------------------------------------------===// 274 // DivSIOp 275 //===----------------------------------------------------------------------===// 276 277 OpFoldResult arith::DivSIOp::fold(ArrayRef<Attribute> operands) { 278 // Don't fold if it would overflow or if it requires a division by zero. 279 bool overflowOrDiv0 = false; 280 auto result = constFoldBinaryOp<IntegerAttr>(operands, [&](APInt a, APInt b) { 281 if (overflowOrDiv0 || !b) { 282 overflowOrDiv0 = true; 283 return a; 284 } 285 return a.sdiv_ov(b, overflowOrDiv0); 286 }); 287 288 // Fold out division by one. Assumes all tensors of all ones are splats. 289 if (auto rhs = operands[1].dyn_cast_or_null<IntegerAttr>()) { 290 if (rhs.getValue() == 1) 291 return getLhs(); 292 } else if (auto rhs = operands[1].dyn_cast_or_null<SplatElementsAttr>()) { 293 if (rhs.getSplatValue<IntegerAttr>().getValue() == 1) 294 return getLhs(); 295 } 296 297 return overflowOrDiv0 ? Attribute() : result; 298 } 299 300 //===----------------------------------------------------------------------===// 301 // Ceil and floor division folding helpers 302 //===----------------------------------------------------------------------===// 303 304 static APInt signedCeilNonnegInputs(APInt a, APInt b, bool &overflow) { 305 // Returns (a-1)/b + 1 306 APInt one(a.getBitWidth(), 1, true); // Signed value 1. 307 APInt val = a.ssub_ov(one, overflow).sdiv_ov(b, overflow); 308 return val.sadd_ov(one, overflow); 309 } 310 311 //===----------------------------------------------------------------------===// 312 // CeilDivUIOp 313 //===----------------------------------------------------------------------===// 314 315 OpFoldResult arith::CeilDivUIOp::fold(ArrayRef<Attribute> operands) { 316 bool overflowOrDiv0 = false; 317 auto result = constFoldBinaryOp<IntegerAttr>(operands, [&](APInt a, APInt b) { 318 if (overflowOrDiv0 || !b) { 319 overflowOrDiv0 = true; 320 return a; 321 } 322 APInt quotient = a.udiv(b); 323 if (!a.urem(b)) 324 return quotient; 325 APInt one(a.getBitWidth(), 1, true); 326 return quotient.uadd_ov(one, overflowOrDiv0); 327 }); 328 // Fold out ceil division by one. Assumes all tensors of all ones are 329 // splats. 330 if (auto rhs = operands[1].dyn_cast_or_null<IntegerAttr>()) { 331 if (rhs.getValue() == 1) 332 return getLhs(); 333 } else if (auto rhs = operands[1].dyn_cast_or_null<SplatElementsAttr>()) { 334 if (rhs.getSplatValue<IntegerAttr>().getValue() == 1) 335 return getLhs(); 336 } 337 338 return overflowOrDiv0 ? Attribute() : result; 339 } 340 341 //===----------------------------------------------------------------------===// 342 // CeilDivSIOp 343 //===----------------------------------------------------------------------===// 344 345 OpFoldResult arith::CeilDivSIOp::fold(ArrayRef<Attribute> operands) { 346 // Don't fold if it would overflow or if it requires a division by zero. 347 bool overflowOrDiv0 = false; 348 auto result = constFoldBinaryOp<IntegerAttr>(operands, [&](APInt a, APInt b) { 349 if (overflowOrDiv0 || !b) { 350 overflowOrDiv0 = true; 351 return a; 352 } 353 unsigned bits = a.getBitWidth(); 354 APInt zero = APInt::getZero(bits); 355 if (a.sgt(zero) && b.sgt(zero)) { 356 // Both positive, return ceil(a, b). 357 return signedCeilNonnegInputs(a, b, overflowOrDiv0); 358 } 359 if (a.slt(zero) && b.slt(zero)) { 360 // Both negative, return ceil(-a, -b). 361 APInt posA = zero.ssub_ov(a, overflowOrDiv0); 362 APInt posB = zero.ssub_ov(b, overflowOrDiv0); 363 return signedCeilNonnegInputs(posA, posB, overflowOrDiv0); 364 } 365 if (a.slt(zero) && b.sgt(zero)) { 366 // A is negative, b is positive, return - ( -a / b). 367 APInt posA = zero.ssub_ov(a, overflowOrDiv0); 368 APInt div = posA.sdiv_ov(b, overflowOrDiv0); 369 return zero.ssub_ov(div, overflowOrDiv0); 370 } 371 // A is positive (or zero), b is negative, return - (a / -b). 372 APInt posB = zero.ssub_ov(b, overflowOrDiv0); 373 APInt div = a.sdiv_ov(posB, overflowOrDiv0); 374 return zero.ssub_ov(div, overflowOrDiv0); 375 }); 376 377 // Fold out ceil division by one. Assumes all tensors of all ones are 378 // splats. 379 if (auto rhs = operands[1].dyn_cast_or_null<IntegerAttr>()) { 380 if (rhs.getValue() == 1) 381 return getLhs(); 382 } else if (auto rhs = operands[1].dyn_cast_or_null<SplatElementsAttr>()) { 383 if (rhs.getSplatValue<IntegerAttr>().getValue() == 1) 384 return getLhs(); 385 } 386 387 return overflowOrDiv0 ? Attribute() : result; 388 } 389 390 //===----------------------------------------------------------------------===// 391 // FloorDivSIOp 392 //===----------------------------------------------------------------------===// 393 394 OpFoldResult arith::FloorDivSIOp::fold(ArrayRef<Attribute> operands) { 395 // Don't fold if it would overflow or if it requires a division by zero. 396 bool overflowOrDiv0 = false; 397 auto result = constFoldBinaryOp<IntegerAttr>(operands, [&](APInt a, APInt b) { 398 if (overflowOrDiv0 || !b) { 399 overflowOrDiv0 = true; 400 return a; 401 } 402 unsigned bits = a.getBitWidth(); 403 APInt zero = APInt::getZero(bits); 404 if (a.sge(zero) && b.sgt(zero)) { 405 // Both positive (or a is zero), return a / b. 406 return a.sdiv_ov(b, overflowOrDiv0); 407 } 408 if (a.sle(zero) && b.slt(zero)) { 409 // Both negative (or a is zero), return -a / -b. 410 APInt posA = zero.ssub_ov(a, overflowOrDiv0); 411 APInt posB = zero.ssub_ov(b, overflowOrDiv0); 412 return posA.sdiv_ov(posB, overflowOrDiv0); 413 } 414 if (a.slt(zero) && b.sgt(zero)) { 415 // A is negative, b is positive, return - ceil(-a, b). 416 APInt posA = zero.ssub_ov(a, overflowOrDiv0); 417 APInt ceil = signedCeilNonnegInputs(posA, b, overflowOrDiv0); 418 return zero.ssub_ov(ceil, overflowOrDiv0); 419 } 420 // A is positive, b is negative, return - ceil(a, -b). 421 APInt posB = zero.ssub_ov(b, overflowOrDiv0); 422 APInt ceil = signedCeilNonnegInputs(a, posB, overflowOrDiv0); 423 return zero.ssub_ov(ceil, overflowOrDiv0); 424 }); 425 426 // Fold out floor division by one. Assumes all tensors of all ones are 427 // splats. 428 if (auto rhs = operands[1].dyn_cast_or_null<IntegerAttr>()) { 429 if (rhs.getValue() == 1) 430 return getLhs(); 431 } else if (auto rhs = operands[1].dyn_cast_or_null<SplatElementsAttr>()) { 432 if (rhs.getSplatValue<IntegerAttr>().getValue() == 1) 433 return getLhs(); 434 } 435 436 return overflowOrDiv0 ? Attribute() : result; 437 } 438 439 //===----------------------------------------------------------------------===// 440 // RemUIOp 441 //===----------------------------------------------------------------------===// 442 443 OpFoldResult arith::RemUIOp::fold(ArrayRef<Attribute> operands) { 444 auto rhs = operands.back().dyn_cast_or_null<IntegerAttr>(); 445 if (!rhs) 446 return {}; 447 auto rhsValue = rhs.getValue(); 448 449 // x % 1 = 0 450 if (rhsValue.isOneValue()) 451 return IntegerAttr::get(rhs.getType(), APInt(rhsValue.getBitWidth(), 0)); 452 453 // Don't fold if it requires division by zero. 454 if (rhsValue.isNullValue()) 455 return {}; 456 457 auto lhs = operands.front().dyn_cast_or_null<IntegerAttr>(); 458 if (!lhs) 459 return {}; 460 return IntegerAttr::get(lhs.getType(), lhs.getValue().urem(rhsValue)); 461 } 462 463 //===----------------------------------------------------------------------===// 464 // RemSIOp 465 //===----------------------------------------------------------------------===// 466 467 OpFoldResult arith::RemSIOp::fold(ArrayRef<Attribute> operands) { 468 auto rhs = operands.back().dyn_cast_or_null<IntegerAttr>(); 469 if (!rhs) 470 return {}; 471 auto rhsValue = rhs.getValue(); 472 473 // x % 1 = 0 474 if (rhsValue.isOneValue()) 475 return IntegerAttr::get(rhs.getType(), APInt(rhsValue.getBitWidth(), 0)); 476 477 // Don't fold if it requires division by zero. 478 if (rhsValue.isNullValue()) 479 return {}; 480 481 auto lhs = operands.front().dyn_cast_or_null<IntegerAttr>(); 482 if (!lhs) 483 return {}; 484 return IntegerAttr::get(lhs.getType(), lhs.getValue().srem(rhsValue)); 485 } 486 487 //===----------------------------------------------------------------------===// 488 // AndIOp 489 //===----------------------------------------------------------------------===// 490 491 OpFoldResult arith::AndIOp::fold(ArrayRef<Attribute> operands) { 492 /// and(x, 0) -> 0 493 if (matchPattern(getRhs(), m_Zero())) 494 return getRhs(); 495 /// and(x, allOnes) -> x 496 APInt intValue; 497 if (matchPattern(getRhs(), m_ConstantInt(&intValue)) && intValue.isAllOnes()) 498 return getLhs(); 499 500 return constFoldBinaryOp<IntegerAttr>(operands, 501 [](APInt a, APInt b) { return a & b; }); 502 } 503 504 //===----------------------------------------------------------------------===// 505 // OrIOp 506 //===----------------------------------------------------------------------===// 507 508 OpFoldResult arith::OrIOp::fold(ArrayRef<Attribute> operands) { 509 /// or(x, 0) -> x 510 if (matchPattern(getRhs(), m_Zero())) 511 return getLhs(); 512 /// or(x, <all ones>) -> <all ones> 513 if (auto rhsAttr = operands[1].dyn_cast_or_null<IntegerAttr>()) 514 if (rhsAttr.getValue().isAllOnes()) 515 return rhsAttr; 516 517 return constFoldBinaryOp<IntegerAttr>(operands, 518 [](APInt a, APInt b) { return a | b; }); 519 } 520 521 //===----------------------------------------------------------------------===// 522 // XOrIOp 523 //===----------------------------------------------------------------------===// 524 525 OpFoldResult arith::XOrIOp::fold(ArrayRef<Attribute> operands) { 526 /// xor(x, 0) -> x 527 if (matchPattern(getRhs(), m_Zero())) 528 return getLhs(); 529 /// xor(x, x) -> 0 530 if (getLhs() == getRhs()) 531 return Builder(getContext()).getZeroAttr(getType()); 532 533 return constFoldBinaryOp<IntegerAttr>(operands, 534 [](APInt a, APInt b) { return a ^ b; }); 535 } 536 537 void arith::XOrIOp::getCanonicalizationPatterns( 538 OwningRewritePatternList &patterns, MLIRContext *context) { 539 patterns.insert<XOrINotCmpI>(context); 540 } 541 542 //===----------------------------------------------------------------------===// 543 // AddFOp 544 //===----------------------------------------------------------------------===// 545 546 OpFoldResult arith::AddFOp::fold(ArrayRef<Attribute> operands) { 547 return constFoldBinaryOp<FloatAttr>( 548 operands, [](APFloat a, APFloat b) { return a + b; }); 549 } 550 551 //===----------------------------------------------------------------------===// 552 // SubFOp 553 //===----------------------------------------------------------------------===// 554 555 OpFoldResult arith::SubFOp::fold(ArrayRef<Attribute> operands) { 556 return constFoldBinaryOp<FloatAttr>( 557 operands, [](APFloat a, APFloat b) { return a - b; }); 558 } 559 560 //===----------------------------------------------------------------------===// 561 // MaxSIOp 562 //===----------------------------------------------------------------------===// 563 564 OpFoldResult MaxSIOp::fold(ArrayRef<Attribute> operands) { 565 assert(operands.size() == 2 && "binary operation takes two operands"); 566 567 // maxsi(x,x) -> x 568 if (getLhs() == getRhs()) 569 return getRhs(); 570 571 APInt intValue; 572 // maxsi(x,MAX_INT) -> MAX_INT 573 if (matchPattern(getRhs(), m_ConstantInt(&intValue)) && 574 intValue.isMaxSignedValue()) 575 return getRhs(); 576 577 // maxsi(x, MIN_INT) -> x 578 if (matchPattern(getRhs(), m_ConstantInt(&intValue)) && 579 intValue.isMinSignedValue()) 580 return getLhs(); 581 582 return constFoldBinaryOp<IntegerAttr>( 583 operands, [](APInt a, APInt b) { return llvm::APIntOps::smax(a, b); }); 584 } 585 586 //===----------------------------------------------------------------------===// 587 // MaxUIOp 588 //===----------------------------------------------------------------------===// 589 590 OpFoldResult MaxUIOp::fold(ArrayRef<Attribute> operands) { 591 assert(operands.size() == 2 && "binary operation takes two operands"); 592 593 // maxui(x,x) -> x 594 if (getLhs() == getRhs()) 595 return getRhs(); 596 597 APInt intValue; 598 // maxui(x,MAX_INT) -> MAX_INT 599 if (matchPattern(getRhs(), m_ConstantInt(&intValue)) && intValue.isMaxValue()) 600 return getRhs(); 601 602 // maxui(x, MIN_INT) -> x 603 if (matchPattern(getRhs(), m_ConstantInt(&intValue)) && intValue.isMinValue()) 604 return getLhs(); 605 606 return constFoldBinaryOp<IntegerAttr>( 607 operands, [](APInt a, APInt b) { return llvm::APIntOps::umax(a, b); }); 608 } 609 610 //===----------------------------------------------------------------------===// 611 // MinSIOp 612 //===----------------------------------------------------------------------===// 613 614 OpFoldResult MinSIOp::fold(ArrayRef<Attribute> operands) { 615 assert(operands.size() == 2 && "binary operation takes two operands"); 616 617 // minsi(x,x) -> x 618 if (getLhs() == getRhs()) 619 return getRhs(); 620 621 APInt intValue; 622 // minsi(x,MIN_INT) -> MIN_INT 623 if (matchPattern(getRhs(), m_ConstantInt(&intValue)) && 624 intValue.isMinSignedValue()) 625 return getRhs(); 626 627 // minsi(x, MAX_INT) -> x 628 if (matchPattern(getRhs(), m_ConstantInt(&intValue)) && 629 intValue.isMaxSignedValue()) 630 return getLhs(); 631 632 return constFoldBinaryOp<IntegerAttr>( 633 operands, [](APInt a, APInt b) { return llvm::APIntOps::smin(a, b); }); 634 } 635 636 //===----------------------------------------------------------------------===// 637 // MinUIOp 638 //===----------------------------------------------------------------------===// 639 640 OpFoldResult MinUIOp::fold(ArrayRef<Attribute> operands) { 641 assert(operands.size() == 2 && "binary operation takes two operands"); 642 643 // minui(x,x) -> x 644 if (getLhs() == getRhs()) 645 return getRhs(); 646 647 APInt intValue; 648 // minui(x,MIN_INT) -> MIN_INT 649 if (matchPattern(getRhs(), m_ConstantInt(&intValue)) && intValue.isMinValue()) 650 return getRhs(); 651 652 // minui(x, MAX_INT) -> x 653 if (matchPattern(getRhs(), m_ConstantInt(&intValue)) && intValue.isMaxValue()) 654 return getLhs(); 655 656 return constFoldBinaryOp<IntegerAttr>( 657 operands, [](APInt a, APInt b) { return llvm::APIntOps::umin(a, b); }); 658 } 659 660 //===----------------------------------------------------------------------===// 661 // MulFOp 662 //===----------------------------------------------------------------------===// 663 664 OpFoldResult arith::MulFOp::fold(ArrayRef<Attribute> operands) { 665 return constFoldBinaryOp<FloatAttr>( 666 operands, [](APFloat a, APFloat b) { return a * b; }); 667 } 668 669 //===----------------------------------------------------------------------===// 670 // DivFOp 671 //===----------------------------------------------------------------------===// 672 673 OpFoldResult arith::DivFOp::fold(ArrayRef<Attribute> operands) { 674 return constFoldBinaryOp<FloatAttr>( 675 operands, [](APFloat a, APFloat b) { return a / b; }); 676 } 677 678 //===----------------------------------------------------------------------===// 679 // Utility functions for verifying cast ops 680 //===----------------------------------------------------------------------===// 681 682 template <typename... Types> 683 using type_list = std::tuple<Types...> *; 684 685 /// Returns a non-null type only if the provided type is one of the allowed 686 /// types or one of the allowed shaped types of the allowed types. Returns the 687 /// element type if a valid shaped type is provided. 688 template <typename... ShapedTypes, typename... ElementTypes> 689 static Type getUnderlyingType(Type type, type_list<ShapedTypes...>, 690 type_list<ElementTypes...>) { 691 if (type.isa<ShapedType>() && !type.isa<ShapedTypes...>()) 692 return {}; 693 694 auto underlyingType = getElementTypeOrSelf(type); 695 if (!underlyingType.isa<ElementTypes...>()) 696 return {}; 697 698 return underlyingType; 699 } 700 701 /// Get allowed underlying types for vectors and tensors. 702 template <typename... ElementTypes> 703 static Type getTypeIfLike(Type type) { 704 return getUnderlyingType(type, type_list<VectorType, TensorType>(), 705 type_list<ElementTypes...>()); 706 } 707 708 /// Get allowed underlying types for vectors, tensors, and memrefs. 709 template <typename... ElementTypes> 710 static Type getTypeIfLikeOrMemRef(Type type) { 711 return getUnderlyingType(type, 712 type_list<VectorType, TensorType, MemRefType>(), 713 type_list<ElementTypes...>()); 714 } 715 716 static bool areValidCastInputsAndOutputs(TypeRange inputs, TypeRange outputs) { 717 return inputs.size() == 1 && outputs.size() == 1 && 718 succeeded(verifyCompatibleShapes(inputs.front(), outputs.front())); 719 } 720 721 //===----------------------------------------------------------------------===// 722 // Verifiers for integer and floating point extension/truncation ops 723 //===----------------------------------------------------------------------===// 724 725 // Extend ops can only extend to a wider type. 726 template <typename ValType, typename Op> 727 static LogicalResult verifyExtOp(Op op) { 728 Type srcType = getElementTypeOrSelf(op.getIn().getType()); 729 Type dstType = getElementTypeOrSelf(op.getType()); 730 731 if (srcType.cast<ValType>().getWidth() >= dstType.cast<ValType>().getWidth()) 732 return op.emitError("result type ") 733 << dstType << " must be wider than operand type " << srcType; 734 735 return success(); 736 } 737 738 // Truncate ops can only truncate to a shorter type. 739 template <typename ValType, typename Op> 740 static LogicalResult verifyTruncateOp(Op op) { 741 Type srcType = getElementTypeOrSelf(op.getIn().getType()); 742 Type dstType = getElementTypeOrSelf(op.getType()); 743 744 if (srcType.cast<ValType>().getWidth() <= dstType.cast<ValType>().getWidth()) 745 return op.emitError("result type ") 746 << dstType << " must be shorter than operand type " << srcType; 747 748 return success(); 749 } 750 751 /// Validate a cast that changes the width of a type. 752 template <template <typename> class WidthComparator, typename... ElementTypes> 753 static bool checkWidthChangeCast(TypeRange inputs, TypeRange outputs) { 754 if (!areValidCastInputsAndOutputs(inputs, outputs)) 755 return false; 756 757 auto srcType = getTypeIfLike<ElementTypes...>(inputs.front()); 758 auto dstType = getTypeIfLike<ElementTypes...>(outputs.front()); 759 if (!srcType || !dstType) 760 return false; 761 762 return WidthComparator<unsigned>()(dstType.getIntOrFloatBitWidth(), 763 srcType.getIntOrFloatBitWidth()); 764 } 765 766 //===----------------------------------------------------------------------===// 767 // ExtUIOp 768 //===----------------------------------------------------------------------===// 769 770 OpFoldResult arith::ExtUIOp::fold(ArrayRef<Attribute> operands) { 771 if (auto lhs = operands[0].dyn_cast_or_null<IntegerAttr>()) 772 return IntegerAttr::get( 773 getType(), lhs.getValue().zext(getType().getIntOrFloatBitWidth())); 774 775 return {}; 776 } 777 778 bool arith::ExtUIOp::areCastCompatible(TypeRange inputs, TypeRange outputs) { 779 return checkWidthChangeCast<std::greater, IntegerType>(inputs, outputs); 780 } 781 782 //===----------------------------------------------------------------------===// 783 // ExtSIOp 784 //===----------------------------------------------------------------------===// 785 786 OpFoldResult arith::ExtSIOp::fold(ArrayRef<Attribute> operands) { 787 if (auto lhs = operands[0].dyn_cast_or_null<IntegerAttr>()) 788 return IntegerAttr::get( 789 getType(), lhs.getValue().sext(getType().getIntOrFloatBitWidth())); 790 791 return {}; 792 } 793 794 bool arith::ExtSIOp::areCastCompatible(TypeRange inputs, TypeRange outputs) { 795 return checkWidthChangeCast<std::greater, IntegerType>(inputs, outputs); 796 } 797 798 //===----------------------------------------------------------------------===// 799 // ExtFOp 800 //===----------------------------------------------------------------------===// 801 802 bool arith::ExtFOp::areCastCompatible(TypeRange inputs, TypeRange outputs) { 803 return checkWidthChangeCast<std::greater, FloatType>(inputs, outputs); 804 } 805 806 //===----------------------------------------------------------------------===// 807 // TruncIOp 808 //===----------------------------------------------------------------------===// 809 810 OpFoldResult arith::TruncIOp::fold(ArrayRef<Attribute> operands) { 811 // trunci(zexti(a)) -> a 812 // trunci(sexti(a)) -> a 813 if (matchPattern(getOperand(), m_Op<arith::ExtUIOp>()) || 814 matchPattern(getOperand(), m_Op<arith::ExtSIOp>())) 815 return getOperand().getDefiningOp()->getOperand(0); 816 817 assert(operands.size() == 1 && "unary operation takes one operand"); 818 819 if (!operands[0]) 820 return {}; 821 822 if (auto lhs = operands[0].dyn_cast<IntegerAttr>()) { 823 return IntegerAttr::get( 824 getType(), lhs.getValue().trunc(getType().getIntOrFloatBitWidth())); 825 } 826 827 return {}; 828 } 829 830 bool arith::TruncIOp::areCastCompatible(TypeRange inputs, TypeRange outputs) { 831 return checkWidthChangeCast<std::less, IntegerType>(inputs, outputs); 832 } 833 834 //===----------------------------------------------------------------------===// 835 // TruncFOp 836 //===----------------------------------------------------------------------===// 837 838 /// Perform safe const propagation for truncf, i.e. only propagate if FP value 839 /// can be represented without precision loss or rounding. 840 OpFoldResult arith::TruncFOp::fold(ArrayRef<Attribute> operands) { 841 assert(operands.size() == 1 && "unary operation takes one operand"); 842 843 auto constOperand = operands.front(); 844 if (!constOperand || !constOperand.isa<FloatAttr>()) 845 return {}; 846 847 // Convert to target type via 'double'. 848 double sourceValue = 849 constOperand.dyn_cast<FloatAttr>().getValue().convertToDouble(); 850 auto targetAttr = FloatAttr::get(getType(), sourceValue); 851 852 // Propagate if constant's value does not change after truncation. 853 if (sourceValue == targetAttr.getValue().convertToDouble()) 854 return targetAttr; 855 856 return {}; 857 } 858 859 bool arith::TruncFOp::areCastCompatible(TypeRange inputs, TypeRange outputs) { 860 return checkWidthChangeCast<std::less, FloatType>(inputs, outputs); 861 } 862 863 //===----------------------------------------------------------------------===// 864 // Verifiers for casts between integers and floats. 865 //===----------------------------------------------------------------------===// 866 867 template <typename From, typename To> 868 static bool checkIntFloatCast(TypeRange inputs, TypeRange outputs) { 869 if (!areValidCastInputsAndOutputs(inputs, outputs)) 870 return false; 871 872 auto srcType = getTypeIfLike<From>(inputs.front()); 873 auto dstType = getTypeIfLike<To>(outputs.back()); 874 875 return srcType && dstType; 876 } 877 878 //===----------------------------------------------------------------------===// 879 // UIToFPOp 880 //===----------------------------------------------------------------------===// 881 882 bool arith::UIToFPOp::areCastCompatible(TypeRange inputs, TypeRange outputs) { 883 return checkIntFloatCast<IntegerType, FloatType>(inputs, outputs); 884 } 885 886 OpFoldResult arith::UIToFPOp::fold(ArrayRef<Attribute> operands) { 887 if (auto lhs = operands[0].dyn_cast_or_null<IntegerAttr>()) { 888 const APInt &api = lhs.getValue(); 889 FloatType floatTy = getType().cast<FloatType>(); 890 APFloat apf(floatTy.getFloatSemantics(), 891 APInt::getZero(floatTy.getWidth())); 892 apf.convertFromAPInt(api, /*signed=*/false, APFloat::rmNearestTiesToEven); 893 return FloatAttr::get(floatTy, apf); 894 } 895 return {}; 896 } 897 898 //===----------------------------------------------------------------------===// 899 // SIToFPOp 900 //===----------------------------------------------------------------------===// 901 902 bool arith::SIToFPOp::areCastCompatible(TypeRange inputs, TypeRange outputs) { 903 return checkIntFloatCast<IntegerType, FloatType>(inputs, outputs); 904 } 905 906 OpFoldResult arith::SIToFPOp::fold(ArrayRef<Attribute> operands) { 907 if (auto lhs = operands[0].dyn_cast_or_null<IntegerAttr>()) { 908 const APInt &api = lhs.getValue(); 909 FloatType floatTy = getType().cast<FloatType>(); 910 APFloat apf(floatTy.getFloatSemantics(), 911 APInt::getZero(floatTy.getWidth())); 912 apf.convertFromAPInt(api, /*signed=*/true, APFloat::rmNearestTiesToEven); 913 return FloatAttr::get(floatTy, apf); 914 } 915 return {}; 916 } 917 //===----------------------------------------------------------------------===// 918 // FPToUIOp 919 //===----------------------------------------------------------------------===// 920 921 bool arith::FPToUIOp::areCastCompatible(TypeRange inputs, TypeRange outputs) { 922 return checkIntFloatCast<FloatType, IntegerType>(inputs, outputs); 923 } 924 925 OpFoldResult arith::FPToUIOp::fold(ArrayRef<Attribute> operands) { 926 if (auto lhs = operands[0].dyn_cast_or_null<FloatAttr>()) { 927 const APFloat &apf = lhs.getValue(); 928 IntegerType intTy = getType().cast<IntegerType>(); 929 bool ignored; 930 APSInt api(intTy.getWidth(), /*unsigned=*/true); 931 if (APFloat::opInvalidOp == 932 apf.convertToInteger(api, APFloat::rmTowardZero, &ignored)) { 933 // Undefined behavior invoked - the destination type can't represent 934 // the input constant. 935 return {}; 936 } 937 return IntegerAttr::get(getType(), api); 938 } 939 940 return {}; 941 } 942 943 //===----------------------------------------------------------------------===// 944 // FPToSIOp 945 //===----------------------------------------------------------------------===// 946 947 bool arith::FPToSIOp::areCastCompatible(TypeRange inputs, TypeRange outputs) { 948 return checkIntFloatCast<FloatType, IntegerType>(inputs, outputs); 949 } 950 951 OpFoldResult arith::FPToSIOp::fold(ArrayRef<Attribute> operands) { 952 if (auto lhs = operands[0].dyn_cast_or_null<FloatAttr>()) { 953 const APFloat &apf = lhs.getValue(); 954 IntegerType intTy = getType().cast<IntegerType>(); 955 bool ignored; 956 APSInt api(intTy.getWidth(), /*unsigned=*/false); 957 if (APFloat::opInvalidOp == 958 apf.convertToInteger(api, APFloat::rmTowardZero, &ignored)) { 959 // Undefined behavior invoked - the destination type can't represent 960 // the input constant. 961 return {}; 962 } 963 return IntegerAttr::get(getType(), api); 964 } 965 966 return {}; 967 } 968 969 //===----------------------------------------------------------------------===// 970 // IndexCastOp 971 //===----------------------------------------------------------------------===// 972 973 bool arith::IndexCastOp::areCastCompatible(TypeRange inputs, 974 TypeRange outputs) { 975 if (!areValidCastInputsAndOutputs(inputs, outputs)) 976 return false; 977 978 auto srcType = getTypeIfLikeOrMemRef<IntegerType, IndexType>(inputs.front()); 979 auto dstType = getTypeIfLikeOrMemRef<IntegerType, IndexType>(outputs.front()); 980 if (!srcType || !dstType) 981 return false; 982 983 return (srcType.isIndex() && dstType.isSignlessInteger()) || 984 (srcType.isSignlessInteger() && dstType.isIndex()); 985 } 986 987 OpFoldResult arith::IndexCastOp::fold(ArrayRef<Attribute> operands) { 988 // index_cast(constant) -> constant 989 // A little hack because we go through int. Otherwise, the size of the 990 // constant might need to change. 991 if (auto value = operands[0].dyn_cast_or_null<IntegerAttr>()) 992 return IntegerAttr::get(getType(), value.getInt()); 993 994 return {}; 995 } 996 997 void arith::IndexCastOp::getCanonicalizationPatterns( 998 OwningRewritePatternList &patterns, MLIRContext *context) { 999 patterns.insert<IndexCastOfIndexCast, IndexCastOfExtSI>(context); 1000 } 1001 1002 //===----------------------------------------------------------------------===// 1003 // BitcastOp 1004 //===----------------------------------------------------------------------===// 1005 1006 bool arith::BitcastOp::areCastCompatible(TypeRange inputs, TypeRange outputs) { 1007 if (!areValidCastInputsAndOutputs(inputs, outputs)) 1008 return false; 1009 1010 auto srcType = 1011 getTypeIfLikeOrMemRef<IntegerType, IndexType, FloatType>(inputs.front()); 1012 auto dstType = 1013 getTypeIfLikeOrMemRef<IntegerType, IndexType, FloatType>(outputs.front()); 1014 if (!srcType || !dstType) 1015 return false; 1016 1017 return srcType.getIntOrFloatBitWidth() == dstType.getIntOrFloatBitWidth(); 1018 } 1019 1020 OpFoldResult arith::BitcastOp::fold(ArrayRef<Attribute> operands) { 1021 assert(operands.size() == 1 && "bitcast op expects 1 operand"); 1022 1023 auto resType = getType(); 1024 auto operand = operands[0]; 1025 if (!operand) 1026 return {}; 1027 1028 /// Bitcast dense elements. 1029 if (auto denseAttr = operand.dyn_cast_or_null<DenseElementsAttr>()) 1030 return denseAttr.bitcast(resType.cast<ShapedType>().getElementType()); 1031 /// Other shaped types unhandled. 1032 if (resType.isa<ShapedType>()) 1033 return {}; 1034 1035 /// Bitcast integer or float to integer or float. 1036 APInt bits = operand.isa<FloatAttr>() 1037 ? operand.cast<FloatAttr>().getValue().bitcastToAPInt() 1038 : operand.cast<IntegerAttr>().getValue(); 1039 1040 if (auto resFloatType = resType.dyn_cast<FloatType>()) 1041 return FloatAttr::get(resType, 1042 APFloat(resFloatType.getFloatSemantics(), bits)); 1043 return IntegerAttr::get(resType, bits); 1044 } 1045 1046 void arith::BitcastOp::getCanonicalizationPatterns( 1047 OwningRewritePatternList &patterns, MLIRContext *context) { 1048 patterns.insert<BitcastOfBitcast>(context); 1049 } 1050 1051 //===----------------------------------------------------------------------===// 1052 // Helpers for compare ops 1053 //===----------------------------------------------------------------------===// 1054 1055 /// Return the type of the same shape (scalar, vector or tensor) containing i1. 1056 static Type getI1SameShape(Type type) { 1057 auto i1Type = IntegerType::get(type.getContext(), 1); 1058 if (auto tensorType = type.dyn_cast<RankedTensorType>()) 1059 return RankedTensorType::get(tensorType.getShape(), i1Type); 1060 if (type.isa<UnrankedTensorType>()) 1061 return UnrankedTensorType::get(i1Type); 1062 if (auto vectorType = type.dyn_cast<VectorType>()) 1063 return VectorType::get(vectorType.getShape(), i1Type, 1064 vectorType.getNumScalableDims()); 1065 return i1Type; 1066 } 1067 1068 //===----------------------------------------------------------------------===// 1069 // CmpIOp 1070 //===----------------------------------------------------------------------===// 1071 1072 /// Compute `lhs` `pred` `rhs`, where `pred` is one of the known integer 1073 /// comparison predicates. 1074 bool mlir::arith::applyCmpPredicate(arith::CmpIPredicate predicate, 1075 const APInt &lhs, const APInt &rhs) { 1076 switch (predicate) { 1077 case arith::CmpIPredicate::eq: 1078 return lhs.eq(rhs); 1079 case arith::CmpIPredicate::ne: 1080 return lhs.ne(rhs); 1081 case arith::CmpIPredicate::slt: 1082 return lhs.slt(rhs); 1083 case arith::CmpIPredicate::sle: 1084 return lhs.sle(rhs); 1085 case arith::CmpIPredicate::sgt: 1086 return lhs.sgt(rhs); 1087 case arith::CmpIPredicate::sge: 1088 return lhs.sge(rhs); 1089 case arith::CmpIPredicate::ult: 1090 return lhs.ult(rhs); 1091 case arith::CmpIPredicate::ule: 1092 return lhs.ule(rhs); 1093 case arith::CmpIPredicate::ugt: 1094 return lhs.ugt(rhs); 1095 case arith::CmpIPredicate::uge: 1096 return lhs.uge(rhs); 1097 } 1098 llvm_unreachable("unknown cmpi predicate kind"); 1099 } 1100 1101 /// Returns true if the predicate is true for two equal operands. 1102 static bool applyCmpPredicateToEqualOperands(arith::CmpIPredicate predicate) { 1103 switch (predicate) { 1104 case arith::CmpIPredicate::eq: 1105 case arith::CmpIPredicate::sle: 1106 case arith::CmpIPredicate::sge: 1107 case arith::CmpIPredicate::ule: 1108 case arith::CmpIPredicate::uge: 1109 return true; 1110 case arith::CmpIPredicate::ne: 1111 case arith::CmpIPredicate::slt: 1112 case arith::CmpIPredicate::sgt: 1113 case arith::CmpIPredicate::ult: 1114 case arith::CmpIPredicate::ugt: 1115 return false; 1116 } 1117 llvm_unreachable("unknown cmpi predicate kind"); 1118 } 1119 1120 static Attribute getBoolAttribute(Type type, MLIRContext *ctx, bool value) { 1121 auto boolAttr = BoolAttr::get(ctx, value); 1122 ShapedType shapedType = type.dyn_cast_or_null<ShapedType>(); 1123 if (!shapedType) 1124 return boolAttr; 1125 return DenseElementsAttr::get(shapedType, boolAttr); 1126 } 1127 1128 OpFoldResult arith::CmpIOp::fold(ArrayRef<Attribute> operands) { 1129 assert(operands.size() == 2 && "cmpi takes two operands"); 1130 1131 // cmpi(pred, x, x) 1132 if (getLhs() == getRhs()) { 1133 auto val = applyCmpPredicateToEqualOperands(getPredicate()); 1134 return getBoolAttribute(getType(), getContext(), val); 1135 } 1136 1137 auto lhs = operands.front().dyn_cast_or_null<IntegerAttr>(); 1138 auto rhs = operands.back().dyn_cast_or_null<IntegerAttr>(); 1139 if (!lhs || !rhs) 1140 return {}; 1141 1142 auto val = applyCmpPredicate(getPredicate(), lhs.getValue(), rhs.getValue()); 1143 return BoolAttr::get(getContext(), val); 1144 } 1145 1146 //===----------------------------------------------------------------------===// 1147 // CmpFOp 1148 //===----------------------------------------------------------------------===// 1149 1150 /// Compute `lhs` `pred` `rhs`, where `pred` is one of the known floating point 1151 /// comparison predicates. 1152 bool mlir::arith::applyCmpPredicate(arith::CmpFPredicate predicate, 1153 const APFloat &lhs, const APFloat &rhs) { 1154 auto cmpResult = lhs.compare(rhs); 1155 switch (predicate) { 1156 case arith::CmpFPredicate::AlwaysFalse: 1157 return false; 1158 case arith::CmpFPredicate::OEQ: 1159 return cmpResult == APFloat::cmpEqual; 1160 case arith::CmpFPredicate::OGT: 1161 return cmpResult == APFloat::cmpGreaterThan; 1162 case arith::CmpFPredicate::OGE: 1163 return cmpResult == APFloat::cmpGreaterThan || 1164 cmpResult == APFloat::cmpEqual; 1165 case arith::CmpFPredicate::OLT: 1166 return cmpResult == APFloat::cmpLessThan; 1167 case arith::CmpFPredicate::OLE: 1168 return cmpResult == APFloat::cmpLessThan || cmpResult == APFloat::cmpEqual; 1169 case arith::CmpFPredicate::ONE: 1170 return cmpResult != APFloat::cmpUnordered && cmpResult != APFloat::cmpEqual; 1171 case arith::CmpFPredicate::ORD: 1172 return cmpResult != APFloat::cmpUnordered; 1173 case arith::CmpFPredicate::UEQ: 1174 return cmpResult == APFloat::cmpUnordered || cmpResult == APFloat::cmpEqual; 1175 case arith::CmpFPredicate::UGT: 1176 return cmpResult == APFloat::cmpUnordered || 1177 cmpResult == APFloat::cmpGreaterThan; 1178 case arith::CmpFPredicate::UGE: 1179 return cmpResult == APFloat::cmpUnordered || 1180 cmpResult == APFloat::cmpGreaterThan || 1181 cmpResult == APFloat::cmpEqual; 1182 case arith::CmpFPredicate::ULT: 1183 return cmpResult == APFloat::cmpUnordered || 1184 cmpResult == APFloat::cmpLessThan; 1185 case arith::CmpFPredicate::ULE: 1186 return cmpResult == APFloat::cmpUnordered || 1187 cmpResult == APFloat::cmpLessThan || cmpResult == APFloat::cmpEqual; 1188 case arith::CmpFPredicate::UNE: 1189 return cmpResult != APFloat::cmpEqual; 1190 case arith::CmpFPredicate::UNO: 1191 return cmpResult == APFloat::cmpUnordered; 1192 case arith::CmpFPredicate::AlwaysTrue: 1193 return true; 1194 } 1195 llvm_unreachable("unknown cmpf predicate kind"); 1196 } 1197 1198 OpFoldResult arith::CmpFOp::fold(ArrayRef<Attribute> operands) { 1199 assert(operands.size() == 2 && "cmpf takes two operands"); 1200 1201 auto lhs = operands.front().dyn_cast_or_null<FloatAttr>(); 1202 auto rhs = operands.back().dyn_cast_or_null<FloatAttr>(); 1203 1204 if (!lhs || !rhs) 1205 return {}; 1206 1207 auto val = applyCmpPredicate(getPredicate(), lhs.getValue(), rhs.getValue()); 1208 return BoolAttr::get(getContext(), val); 1209 } 1210 1211 //===----------------------------------------------------------------------===// 1212 // Atomic Enum 1213 //===----------------------------------------------------------------------===// 1214 1215 /// Returns the identity value attribute associated with an AtomicRMWKind op. 1216 Attribute mlir::arith::getIdentityValueAttr(AtomicRMWKind kind, Type resultType, 1217 OpBuilder &builder, Location loc) { 1218 switch (kind) { 1219 case AtomicRMWKind::maxf: 1220 return builder.getFloatAttr( 1221 resultType, 1222 APFloat::getInf(resultType.cast<FloatType>().getFloatSemantics(), 1223 /*Negative=*/true)); 1224 case AtomicRMWKind::addf: 1225 case AtomicRMWKind::addi: 1226 case AtomicRMWKind::maxu: 1227 case AtomicRMWKind::ori: 1228 return builder.getZeroAttr(resultType); 1229 case AtomicRMWKind::andi: 1230 return builder.getIntegerAttr( 1231 resultType, 1232 APInt::getAllOnes(resultType.cast<IntegerType>().getWidth())); 1233 case AtomicRMWKind::maxs: 1234 return builder.getIntegerAttr( 1235 resultType, 1236 APInt::getSignedMinValue(resultType.cast<IntegerType>().getWidth())); 1237 case AtomicRMWKind::minf: 1238 return builder.getFloatAttr( 1239 resultType, 1240 APFloat::getInf(resultType.cast<FloatType>().getFloatSemantics(), 1241 /*Negative=*/false)); 1242 case AtomicRMWKind::mins: 1243 return builder.getIntegerAttr( 1244 resultType, 1245 APInt::getSignedMaxValue(resultType.cast<IntegerType>().getWidth())); 1246 case AtomicRMWKind::minu: 1247 return builder.getIntegerAttr( 1248 resultType, 1249 APInt::getMaxValue(resultType.cast<IntegerType>().getWidth())); 1250 case AtomicRMWKind::muli: 1251 return builder.getIntegerAttr(resultType, 1); 1252 case AtomicRMWKind::mulf: 1253 return builder.getFloatAttr(resultType, 1); 1254 // TODO: Add remaining reduction operations. 1255 default: 1256 (void)emitOptionalError(loc, "Reduction operation type not supported"); 1257 break; 1258 } 1259 return nullptr; 1260 } 1261 1262 /// Returns the identity value associated with an AtomicRMWKind op. 1263 Value mlir::arith::getIdentityValue(AtomicRMWKind op, Type resultType, 1264 OpBuilder &builder, Location loc) { 1265 Attribute attr = getIdentityValueAttr(op, resultType, builder, loc); 1266 return builder.create<arith::ConstantOp>(loc, attr); 1267 } 1268 1269 /// Return the value obtained by applying the reduction operation kind 1270 /// associated with a binary AtomicRMWKind op to `lhs` and `rhs`. 1271 Value mlir::arith::getReductionOp(AtomicRMWKind op, OpBuilder &builder, 1272 Location loc, Value lhs, Value rhs) { 1273 switch (op) { 1274 case AtomicRMWKind::addf: 1275 return builder.create<arith::AddFOp>(loc, lhs, rhs); 1276 case AtomicRMWKind::addi: 1277 return builder.create<arith::AddIOp>(loc, lhs, rhs); 1278 case AtomicRMWKind::mulf: 1279 return builder.create<arith::MulFOp>(loc, lhs, rhs); 1280 case AtomicRMWKind::muli: 1281 return builder.create<arith::MulIOp>(loc, lhs, rhs); 1282 case AtomicRMWKind::maxf: 1283 return builder.create<arith::MaxFOp>(loc, lhs, rhs); 1284 case AtomicRMWKind::minf: 1285 return builder.create<arith::MinFOp>(loc, lhs, rhs); 1286 case AtomicRMWKind::maxs: 1287 return builder.create<arith::MaxSIOp>(loc, lhs, rhs); 1288 case AtomicRMWKind::mins: 1289 return builder.create<arith::MinSIOp>(loc, lhs, rhs); 1290 case AtomicRMWKind::maxu: 1291 return builder.create<arith::MaxUIOp>(loc, lhs, rhs); 1292 case AtomicRMWKind::minu: 1293 return builder.create<arith::MinUIOp>(loc, lhs, rhs); 1294 case AtomicRMWKind::ori: 1295 return builder.create<arith::OrIOp>(loc, lhs, rhs); 1296 case AtomicRMWKind::andi: 1297 return builder.create<arith::AndIOp>(loc, lhs, rhs); 1298 // TODO: Add remaining reduction operations. 1299 default: 1300 (void)emitOptionalError(loc, "Reduction operation type not supported"); 1301 break; 1302 } 1303 return nullptr; 1304 } 1305 1306 //===----------------------------------------------------------------------===// 1307 // TableGen'd op method definitions 1308 //===----------------------------------------------------------------------===// 1309 1310 #define GET_OP_CLASSES 1311 #include "mlir/Dialect/Arithmetic/IR/ArithmeticOps.cpp.inc" 1312 1313 //===----------------------------------------------------------------------===// 1314 // TableGen'd enum attribute definitions 1315 //===----------------------------------------------------------------------===// 1316 1317 #include "mlir/Dialect/Arithmetic/IR/ArithmeticOpsEnums.cpp.inc" 1318