1 //===- ArithmeticOps.cpp - MLIR Arithmetic dialect ops implementation -----===// 2 // 3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. 4 // See https://llvm.org/LICENSE.txt for license information. 5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception 6 // 7 //===----------------------------------------------------------------------===// 8 9 #include "mlir/Dialect/Arithmetic/IR/Arithmetic.h" 10 #include "mlir/Dialect/CommonFolders.h" 11 #include "mlir/IR/Builders.h" 12 #include "mlir/IR/Matchers.h" 13 #include "mlir/IR/OpImplementation.h" 14 #include "mlir/IR/PatternMatch.h" 15 #include "mlir/IR/TypeUtilities.h" 16 17 using namespace mlir; 18 using namespace mlir::arith; 19 20 //===----------------------------------------------------------------------===// 21 // Pattern helpers 22 //===----------------------------------------------------------------------===// 23 24 static IntegerAttr addIntegerAttrs(PatternRewriter &builder, Value res, 25 Attribute lhs, Attribute rhs) { 26 return builder.getIntegerAttr(res.getType(), 27 lhs.cast<IntegerAttr>().getInt() + 28 rhs.cast<IntegerAttr>().getInt()); 29 } 30 31 static IntegerAttr subIntegerAttrs(PatternRewriter &builder, Value res, 32 Attribute lhs, Attribute rhs) { 33 return builder.getIntegerAttr(res.getType(), 34 lhs.cast<IntegerAttr>().getInt() - 35 rhs.cast<IntegerAttr>().getInt()); 36 } 37 38 /// Invert an integer comparison predicate. 39 static arith::CmpIPredicate invertPredicate(arith::CmpIPredicate pred) { 40 switch (pred) { 41 case arith::CmpIPredicate::eq: 42 return arith::CmpIPredicate::ne; 43 case arith::CmpIPredicate::ne: 44 return arith::CmpIPredicate::eq; 45 case arith::CmpIPredicate::slt: 46 return arith::CmpIPredicate::sge; 47 case arith::CmpIPredicate::sle: 48 return arith::CmpIPredicate::sgt; 49 case arith::CmpIPredicate::sgt: 50 return arith::CmpIPredicate::sle; 51 case arith::CmpIPredicate::sge: 52 return arith::CmpIPredicate::slt; 53 case arith::CmpIPredicate::ult: 54 return arith::CmpIPredicate::uge; 55 case arith::CmpIPredicate::ule: 56 return arith::CmpIPredicate::ugt; 57 case arith::CmpIPredicate::ugt: 58 return arith::CmpIPredicate::ule; 59 case arith::CmpIPredicate::uge: 60 return arith::CmpIPredicate::ult; 61 } 62 llvm_unreachable("unknown cmpi predicate kind"); 63 } 64 65 static arith::CmpIPredicateAttr invertPredicate(arith::CmpIPredicateAttr pred) { 66 return arith::CmpIPredicateAttr::get(pred.getContext(), 67 invertPredicate(pred.getValue())); 68 } 69 70 //===----------------------------------------------------------------------===// 71 // TableGen'd canonicalization patterns 72 //===----------------------------------------------------------------------===// 73 74 namespace { 75 #include "ArithmeticCanonicalization.inc" 76 } // end anonymous namespace 77 78 //===----------------------------------------------------------------------===// 79 // ConstantOp 80 //===----------------------------------------------------------------------===// 81 82 void arith::ConstantOp::getAsmResultNames( 83 function_ref<void(Value, StringRef)> setNameFn) { 84 auto type = getType(); 85 if (auto intCst = getValue().dyn_cast<IntegerAttr>()) { 86 auto intType = type.dyn_cast<IntegerType>(); 87 88 // Sugar i1 constants with 'true' and 'false'. 89 if (intType && intType.getWidth() == 1) 90 return setNameFn(getResult(), (intCst.getInt() ? "true" : "false")); 91 92 // Otherwise, build a compex name with the value and type. 93 SmallString<32> specialNameBuffer; 94 llvm::raw_svector_ostream specialName(specialNameBuffer); 95 specialName << 'c' << intCst.getInt(); 96 if (intType) 97 specialName << '_' << type; 98 setNameFn(getResult(), specialName.str()); 99 } else { 100 setNameFn(getResult(), "cst"); 101 } 102 } 103 104 /// TODO: disallow arith.constant to return anything other than signless integer 105 /// or float like. 106 static LogicalResult verify(arith::ConstantOp op) { 107 auto type = op.getType(); 108 // The value's type must match the return type. 109 if (op.getValue().getType() != type) { 110 return op.emitOpError() << "value type " << op.getValue().getType() 111 << " must match return type: " << type; 112 } 113 // Integer values must be signless. 114 if (type.isa<IntegerType>() && !type.cast<IntegerType>().isSignless()) 115 return op.emitOpError("integer return type must be signless"); 116 // Any float or elements attribute are acceptable. 117 if (!op.getValue().isa<IntegerAttr, FloatAttr, ElementsAttr>()) { 118 return op.emitOpError( 119 "value must be an integer, float, or elements attribute"); 120 } 121 return success(); 122 } 123 124 bool arith::ConstantOp::isBuildableWith(Attribute value, Type type) { 125 // The value's type must be the same as the provided type. 126 if (value.getType() != type) 127 return false; 128 // Integer values must be signless. 129 if (type.isa<IntegerType>() && !type.cast<IntegerType>().isSignless()) 130 return false; 131 // Integer, float, and element attributes are buildable. 132 return value.isa<IntegerAttr, FloatAttr, ElementsAttr>(); 133 } 134 135 OpFoldResult arith::ConstantOp::fold(ArrayRef<Attribute> operands) { 136 return getValue(); 137 } 138 139 void arith::ConstantIntOp::build(OpBuilder &builder, OperationState &result, 140 int64_t value, unsigned width) { 141 auto type = builder.getIntegerType(width); 142 arith::ConstantOp::build(builder, result, type, 143 builder.getIntegerAttr(type, value)); 144 } 145 146 void arith::ConstantIntOp::build(OpBuilder &builder, OperationState &result, 147 int64_t value, Type type) { 148 assert(type.isSignlessInteger() && 149 "ConstantIntOp can only have signless integer type values"); 150 arith::ConstantOp::build(builder, result, type, 151 builder.getIntegerAttr(type, value)); 152 } 153 154 bool arith::ConstantIntOp::classof(Operation *op) { 155 if (auto constOp = dyn_cast_or_null<arith::ConstantOp>(op)) 156 return constOp.getType().isSignlessInteger(); 157 return false; 158 } 159 160 void arith::ConstantFloatOp::build(OpBuilder &builder, OperationState &result, 161 const APFloat &value, FloatType type) { 162 arith::ConstantOp::build(builder, result, type, 163 builder.getFloatAttr(type, value)); 164 } 165 166 bool arith::ConstantFloatOp::classof(Operation *op) { 167 if (auto constOp = dyn_cast_or_null<arith::ConstantOp>(op)) 168 return constOp.getType().isa<FloatType>(); 169 return false; 170 } 171 172 void arith::ConstantIndexOp::build(OpBuilder &builder, OperationState &result, 173 int64_t value) { 174 arith::ConstantOp::build(builder, result, builder.getIndexType(), 175 builder.getIndexAttr(value)); 176 } 177 178 bool arith::ConstantIndexOp::classof(Operation *op) { 179 if (auto constOp = dyn_cast_or_null<arith::ConstantOp>(op)) 180 return constOp.getType().isIndex(); 181 return false; 182 } 183 184 //===----------------------------------------------------------------------===// 185 // AddIOp 186 //===----------------------------------------------------------------------===// 187 188 OpFoldResult arith::AddIOp::fold(ArrayRef<Attribute> operands) { 189 // addi(x, 0) -> x 190 if (matchPattern(getRhs(), m_Zero())) 191 return getLhs(); 192 193 return constFoldBinaryOp<IntegerAttr>(operands, 194 [](APInt a, APInt b) { return a + b; }); 195 } 196 197 void arith::AddIOp::getCanonicalizationPatterns( 198 OwningRewritePatternList &patterns, MLIRContext *context) { 199 patterns.insert<AddIAddConstant, AddISubConstantRHS, AddISubConstantLHS>( 200 context); 201 } 202 203 //===----------------------------------------------------------------------===// 204 // SubIOp 205 //===----------------------------------------------------------------------===// 206 207 OpFoldResult arith::SubIOp::fold(ArrayRef<Attribute> operands) { 208 // subi(x,x) -> 0 209 if (getOperand(0) == getOperand(1)) 210 return Builder(getContext()).getZeroAttr(getType()); 211 // subi(x,0) -> x 212 if (matchPattern(getRhs(), m_Zero())) 213 return getLhs(); 214 215 return constFoldBinaryOp<IntegerAttr>(operands, 216 [](APInt a, APInt b) { return a - b; }); 217 } 218 219 void arith::SubIOp::getCanonicalizationPatterns( 220 OwningRewritePatternList &patterns, MLIRContext *context) { 221 patterns.insert<SubIRHSAddConstant, SubILHSAddConstant, SubIRHSSubConstantRHS, 222 SubIRHSSubConstantLHS, SubILHSSubConstantRHS, 223 SubILHSSubConstantLHS>(context); 224 } 225 226 //===----------------------------------------------------------------------===// 227 // MulIOp 228 //===----------------------------------------------------------------------===// 229 230 OpFoldResult arith::MulIOp::fold(ArrayRef<Attribute> operands) { 231 // muli(x, 0) -> 0 232 if (matchPattern(getRhs(), m_Zero())) 233 return getRhs(); 234 // muli(x, 1) -> x 235 if (matchPattern(getRhs(), m_One())) 236 return getOperand(0); 237 // TODO: Handle the overflow case. 238 239 // default folder 240 return constFoldBinaryOp<IntegerAttr>(operands, 241 [](APInt a, APInt b) { return a * b; }); 242 } 243 244 //===----------------------------------------------------------------------===// 245 // DivUIOp 246 //===----------------------------------------------------------------------===// 247 248 OpFoldResult arith::DivUIOp::fold(ArrayRef<Attribute> operands) { 249 // Don't fold if it would require a division by zero. 250 bool div0 = false; 251 auto result = constFoldBinaryOp<IntegerAttr>(operands, [&](APInt a, APInt b) { 252 if (div0 || !b) { 253 div0 = true; 254 return a; 255 } 256 return a.udiv(b); 257 }); 258 259 // Fold out division by one. Assumes all tensors of all ones are splats. 260 if (auto rhs = operands[1].dyn_cast_or_null<IntegerAttr>()) { 261 if (rhs.getValue() == 1) 262 return getLhs(); 263 } else if (auto rhs = operands[1].dyn_cast_or_null<SplatElementsAttr>()) { 264 if (rhs.getSplatValue<IntegerAttr>().getValue() == 1) 265 return getLhs(); 266 } 267 268 return div0 ? Attribute() : result; 269 } 270 271 //===----------------------------------------------------------------------===// 272 // DivSIOp 273 //===----------------------------------------------------------------------===// 274 275 OpFoldResult arith::DivSIOp::fold(ArrayRef<Attribute> operands) { 276 // Don't fold if it would overflow or if it requires a division by zero. 277 bool overflowOrDiv0 = false; 278 auto result = constFoldBinaryOp<IntegerAttr>(operands, [&](APInt a, APInt b) { 279 if (overflowOrDiv0 || !b) { 280 overflowOrDiv0 = true; 281 return a; 282 } 283 return a.sdiv_ov(b, overflowOrDiv0); 284 }); 285 286 // Fold out division by one. Assumes all tensors of all ones are splats. 287 if (auto rhs = operands[1].dyn_cast_or_null<IntegerAttr>()) { 288 if (rhs.getValue() == 1) 289 return getLhs(); 290 } else if (auto rhs = operands[1].dyn_cast_or_null<SplatElementsAttr>()) { 291 if (rhs.getSplatValue<IntegerAttr>().getValue() == 1) 292 return getLhs(); 293 } 294 295 return overflowOrDiv0 ? Attribute() : result; 296 } 297 298 //===----------------------------------------------------------------------===// 299 // Ceil and floor division folding helpers 300 //===----------------------------------------------------------------------===// 301 302 static APInt signedCeilNonnegInputs(APInt a, APInt b, bool &overflow) { 303 // Returns (a-1)/b + 1 304 APInt one(a.getBitWidth(), 1, true); // Signed value 1. 305 APInt val = a.ssub_ov(one, overflow).sdiv_ov(b, overflow); 306 return val.sadd_ov(one, overflow); 307 } 308 309 //===----------------------------------------------------------------------===// 310 // CeilDivUIOp 311 //===----------------------------------------------------------------------===// 312 313 OpFoldResult arith::CeilDivUIOp::fold(ArrayRef<Attribute> operands) { 314 bool overflowOrDiv0 = false; 315 auto result = constFoldBinaryOp<IntegerAttr>(operands, [&](APInt a, APInt b) { 316 if (overflowOrDiv0 || !b) { 317 overflowOrDiv0 = true; 318 return a; 319 } 320 APInt quotient = a.udiv(b); 321 if (!a.urem(b)) 322 return quotient; 323 APInt one(a.getBitWidth(), 1, true); 324 return quotient.uadd_ov(one, overflowOrDiv0); 325 }); 326 // Fold out ceil division by one. Assumes all tensors of all ones are 327 // splats. 328 if (auto rhs = operands[1].dyn_cast_or_null<IntegerAttr>()) { 329 if (rhs.getValue() == 1) 330 return getLhs(); 331 } else if (auto rhs = operands[1].dyn_cast_or_null<SplatElementsAttr>()) { 332 if (rhs.getSplatValue<IntegerAttr>().getValue() == 1) 333 return getLhs(); 334 } 335 336 return overflowOrDiv0 ? Attribute() : result; 337 } 338 339 //===----------------------------------------------------------------------===// 340 // CeilDivSIOp 341 //===----------------------------------------------------------------------===// 342 343 OpFoldResult arith::CeilDivSIOp::fold(ArrayRef<Attribute> operands) { 344 // Don't fold if it would overflow or if it requires a division by zero. 345 bool overflowOrDiv0 = false; 346 auto result = constFoldBinaryOp<IntegerAttr>(operands, [&](APInt a, APInt b) { 347 if (overflowOrDiv0 || !b) { 348 overflowOrDiv0 = true; 349 return a; 350 } 351 unsigned bits = a.getBitWidth(); 352 APInt zero = APInt::getZero(bits); 353 if (a.sgt(zero) && b.sgt(zero)) { 354 // Both positive, return ceil(a, b). 355 return signedCeilNonnegInputs(a, b, overflowOrDiv0); 356 } 357 if (a.slt(zero) && b.slt(zero)) { 358 // Both negative, return ceil(-a, -b). 359 APInt posA = zero.ssub_ov(a, overflowOrDiv0); 360 APInt posB = zero.ssub_ov(b, overflowOrDiv0); 361 return signedCeilNonnegInputs(posA, posB, overflowOrDiv0); 362 } 363 if (a.slt(zero) && b.sgt(zero)) { 364 // A is negative, b is positive, return - ( -a / b). 365 APInt posA = zero.ssub_ov(a, overflowOrDiv0); 366 APInt div = posA.sdiv_ov(b, overflowOrDiv0); 367 return zero.ssub_ov(div, overflowOrDiv0); 368 } 369 // A is positive (or zero), b is negative, return - (a / -b). 370 APInt posB = zero.ssub_ov(b, overflowOrDiv0); 371 APInt div = a.sdiv_ov(posB, overflowOrDiv0); 372 return zero.ssub_ov(div, overflowOrDiv0); 373 }); 374 375 // Fold out ceil division by one. Assumes all tensors of all ones are 376 // splats. 377 if (auto rhs = operands[1].dyn_cast_or_null<IntegerAttr>()) { 378 if (rhs.getValue() == 1) 379 return getLhs(); 380 } else if (auto rhs = operands[1].dyn_cast_or_null<SplatElementsAttr>()) { 381 if (rhs.getSplatValue<IntegerAttr>().getValue() == 1) 382 return getLhs(); 383 } 384 385 return overflowOrDiv0 ? Attribute() : result; 386 } 387 388 //===----------------------------------------------------------------------===// 389 // FloorDivSIOp 390 //===----------------------------------------------------------------------===// 391 392 OpFoldResult arith::FloorDivSIOp::fold(ArrayRef<Attribute> operands) { 393 // Don't fold if it would overflow or if it requires a division by zero. 394 bool overflowOrDiv0 = false; 395 auto result = constFoldBinaryOp<IntegerAttr>(operands, [&](APInt a, APInt b) { 396 if (overflowOrDiv0 || !b) { 397 overflowOrDiv0 = true; 398 return a; 399 } 400 unsigned bits = a.getBitWidth(); 401 APInt zero = APInt::getZero(bits); 402 if (a.sge(zero) && b.sgt(zero)) { 403 // Both positive (or a is zero), return a / b. 404 return a.sdiv_ov(b, overflowOrDiv0); 405 } 406 if (a.sle(zero) && b.slt(zero)) { 407 // Both negative (or a is zero), return -a / -b. 408 APInt posA = zero.ssub_ov(a, overflowOrDiv0); 409 APInt posB = zero.ssub_ov(b, overflowOrDiv0); 410 return posA.sdiv_ov(posB, overflowOrDiv0); 411 } 412 if (a.slt(zero) && b.sgt(zero)) { 413 // A is negative, b is positive, return - ceil(-a, b). 414 APInt posA = zero.ssub_ov(a, overflowOrDiv0); 415 APInt ceil = signedCeilNonnegInputs(posA, b, overflowOrDiv0); 416 return zero.ssub_ov(ceil, overflowOrDiv0); 417 } 418 // A is positive, b is negative, return - ceil(a, -b). 419 APInt posB = zero.ssub_ov(b, overflowOrDiv0); 420 APInt ceil = signedCeilNonnegInputs(a, posB, overflowOrDiv0); 421 return zero.ssub_ov(ceil, overflowOrDiv0); 422 }); 423 424 // Fold out floor division by one. Assumes all tensors of all ones are 425 // splats. 426 if (auto rhs = operands[1].dyn_cast_or_null<IntegerAttr>()) { 427 if (rhs.getValue() == 1) 428 return getLhs(); 429 } else if (auto rhs = operands[1].dyn_cast_or_null<SplatElementsAttr>()) { 430 if (rhs.getSplatValue<IntegerAttr>().getValue() == 1) 431 return getLhs(); 432 } 433 434 return overflowOrDiv0 ? Attribute() : result; 435 } 436 437 //===----------------------------------------------------------------------===// 438 // RemUIOp 439 //===----------------------------------------------------------------------===// 440 441 OpFoldResult arith::RemUIOp::fold(ArrayRef<Attribute> operands) { 442 auto rhs = operands.back().dyn_cast_or_null<IntegerAttr>(); 443 if (!rhs) 444 return {}; 445 auto rhsValue = rhs.getValue(); 446 447 // x % 1 = 0 448 if (rhsValue.isOneValue()) 449 return IntegerAttr::get(rhs.getType(), APInt(rhsValue.getBitWidth(), 0)); 450 451 // Don't fold if it requires division by zero. 452 if (rhsValue.isNullValue()) 453 return {}; 454 455 auto lhs = operands.front().dyn_cast_or_null<IntegerAttr>(); 456 if (!lhs) 457 return {}; 458 return IntegerAttr::get(lhs.getType(), lhs.getValue().urem(rhsValue)); 459 } 460 461 //===----------------------------------------------------------------------===// 462 // RemSIOp 463 //===----------------------------------------------------------------------===// 464 465 OpFoldResult arith::RemSIOp::fold(ArrayRef<Attribute> operands) { 466 auto rhs = operands.back().dyn_cast_or_null<IntegerAttr>(); 467 if (!rhs) 468 return {}; 469 auto rhsValue = rhs.getValue(); 470 471 // x % 1 = 0 472 if (rhsValue.isOneValue()) 473 return IntegerAttr::get(rhs.getType(), APInt(rhsValue.getBitWidth(), 0)); 474 475 // Don't fold if it requires division by zero. 476 if (rhsValue.isNullValue()) 477 return {}; 478 479 auto lhs = operands.front().dyn_cast_or_null<IntegerAttr>(); 480 if (!lhs) 481 return {}; 482 return IntegerAttr::get(lhs.getType(), lhs.getValue().srem(rhsValue)); 483 } 484 485 //===----------------------------------------------------------------------===// 486 // AndIOp 487 //===----------------------------------------------------------------------===// 488 489 OpFoldResult arith::AndIOp::fold(ArrayRef<Attribute> operands) { 490 /// and(x, 0) -> 0 491 if (matchPattern(getRhs(), m_Zero())) 492 return getRhs(); 493 /// and(x, allOnes) -> x 494 APInt intValue; 495 if (matchPattern(getRhs(), m_ConstantInt(&intValue)) && intValue.isAllOnes()) 496 return getLhs(); 497 /// and(x, x) -> x 498 if (getLhs() == getRhs()) 499 return getRhs(); 500 501 return constFoldBinaryOp<IntegerAttr>(operands, 502 [](APInt a, APInt b) { return a & b; }); 503 } 504 505 //===----------------------------------------------------------------------===// 506 // OrIOp 507 //===----------------------------------------------------------------------===// 508 509 OpFoldResult arith::OrIOp::fold(ArrayRef<Attribute> operands) { 510 /// or(x, 0) -> x 511 if (matchPattern(getRhs(), m_Zero())) 512 return getLhs(); 513 /// or(x, x) -> x 514 if (getLhs() == getRhs()) 515 return getRhs(); 516 /// or(x, <all ones>) -> <all ones> 517 if (auto rhsAttr = operands[1].dyn_cast_or_null<IntegerAttr>()) 518 if (rhsAttr.getValue().isAllOnes()) 519 return rhsAttr; 520 521 return constFoldBinaryOp<IntegerAttr>(operands, 522 [](APInt a, APInt b) { return a | b; }); 523 } 524 525 //===----------------------------------------------------------------------===// 526 // XOrIOp 527 //===----------------------------------------------------------------------===// 528 529 OpFoldResult arith::XOrIOp::fold(ArrayRef<Attribute> operands) { 530 /// xor(x, 0) -> x 531 if (matchPattern(getRhs(), m_Zero())) 532 return getLhs(); 533 /// xor(x, x) -> 0 534 if (getLhs() == getRhs()) 535 return Builder(getContext()).getZeroAttr(getType()); 536 537 return constFoldBinaryOp<IntegerAttr>(operands, 538 [](APInt a, APInt b) { return a ^ b; }); 539 } 540 541 void arith::XOrIOp::getCanonicalizationPatterns( 542 OwningRewritePatternList &patterns, MLIRContext *context) { 543 patterns.insert<XOrINotCmpI>(context); 544 } 545 546 //===----------------------------------------------------------------------===// 547 // AddFOp 548 //===----------------------------------------------------------------------===// 549 550 OpFoldResult arith::AddFOp::fold(ArrayRef<Attribute> operands) { 551 return constFoldBinaryOp<FloatAttr>( 552 operands, [](APFloat a, APFloat b) { return a + b; }); 553 } 554 555 //===----------------------------------------------------------------------===// 556 // SubFOp 557 //===----------------------------------------------------------------------===// 558 559 OpFoldResult arith::SubFOp::fold(ArrayRef<Attribute> operands) { 560 return constFoldBinaryOp<FloatAttr>( 561 operands, [](APFloat a, APFloat b) { return a - b; }); 562 } 563 564 //===----------------------------------------------------------------------===// 565 // MaxSIOp 566 //===----------------------------------------------------------------------===// 567 568 OpFoldResult MaxSIOp::fold(ArrayRef<Attribute> operands) { 569 assert(operands.size() == 2 && "binary operation takes two operands"); 570 571 // maxsi(x,x) -> x 572 if (getLhs() == getRhs()) 573 return getRhs(); 574 575 APInt intValue; 576 // maxsi(x,MAX_INT) -> MAX_INT 577 if (matchPattern(getRhs(), m_ConstantInt(&intValue)) && 578 intValue.isMaxSignedValue()) 579 return getRhs(); 580 581 // maxsi(x, MIN_INT) -> x 582 if (matchPattern(getRhs(), m_ConstantInt(&intValue)) && 583 intValue.isMinSignedValue()) 584 return getLhs(); 585 586 return constFoldBinaryOp<IntegerAttr>( 587 operands, [](APInt a, APInt b) { return llvm::APIntOps::smax(a, b); }); 588 } 589 590 //===----------------------------------------------------------------------===// 591 // MaxUIOp 592 //===----------------------------------------------------------------------===// 593 594 OpFoldResult MaxUIOp::fold(ArrayRef<Attribute> operands) { 595 assert(operands.size() == 2 && "binary operation takes two operands"); 596 597 // maxui(x,x) -> x 598 if (getLhs() == getRhs()) 599 return getRhs(); 600 601 APInt intValue; 602 // maxui(x,MAX_INT) -> MAX_INT 603 if (matchPattern(getRhs(), m_ConstantInt(&intValue)) && intValue.isMaxValue()) 604 return getRhs(); 605 606 // maxui(x, MIN_INT) -> x 607 if (matchPattern(getRhs(), m_ConstantInt(&intValue)) && intValue.isMinValue()) 608 return getLhs(); 609 610 return constFoldBinaryOp<IntegerAttr>( 611 operands, [](APInt a, APInt b) { return llvm::APIntOps::umax(a, b); }); 612 } 613 614 //===----------------------------------------------------------------------===// 615 // MinSIOp 616 //===----------------------------------------------------------------------===// 617 618 OpFoldResult MinSIOp::fold(ArrayRef<Attribute> operands) { 619 assert(operands.size() == 2 && "binary operation takes two operands"); 620 621 // minsi(x,x) -> x 622 if (getLhs() == getRhs()) 623 return getRhs(); 624 625 APInt intValue; 626 // minsi(x,MIN_INT) -> MIN_INT 627 if (matchPattern(getRhs(), m_ConstantInt(&intValue)) && 628 intValue.isMinSignedValue()) 629 return getRhs(); 630 631 // minsi(x, MAX_INT) -> x 632 if (matchPattern(getRhs(), m_ConstantInt(&intValue)) && 633 intValue.isMaxSignedValue()) 634 return getLhs(); 635 636 return constFoldBinaryOp<IntegerAttr>( 637 operands, [](APInt a, APInt b) { return llvm::APIntOps::smin(a, b); }); 638 } 639 640 //===----------------------------------------------------------------------===// 641 // MinUIOp 642 //===----------------------------------------------------------------------===// 643 644 OpFoldResult MinUIOp::fold(ArrayRef<Attribute> operands) { 645 assert(operands.size() == 2 && "binary operation takes two operands"); 646 647 // minui(x,x) -> x 648 if (getLhs() == getRhs()) 649 return getRhs(); 650 651 APInt intValue; 652 // minui(x,MIN_INT) -> MIN_INT 653 if (matchPattern(getRhs(), m_ConstantInt(&intValue)) && intValue.isMinValue()) 654 return getRhs(); 655 656 // minui(x, MAX_INT) -> x 657 if (matchPattern(getRhs(), m_ConstantInt(&intValue)) && intValue.isMaxValue()) 658 return getLhs(); 659 660 return constFoldBinaryOp<IntegerAttr>( 661 operands, [](APInt a, APInt b) { return llvm::APIntOps::umin(a, b); }); 662 } 663 664 //===----------------------------------------------------------------------===// 665 // MulFOp 666 //===----------------------------------------------------------------------===// 667 668 OpFoldResult arith::MulFOp::fold(ArrayRef<Attribute> operands) { 669 return constFoldBinaryOp<FloatAttr>( 670 operands, [](APFloat a, APFloat b) { return a * b; }); 671 } 672 673 //===----------------------------------------------------------------------===// 674 // DivFOp 675 //===----------------------------------------------------------------------===// 676 677 OpFoldResult arith::DivFOp::fold(ArrayRef<Attribute> operands) { 678 return constFoldBinaryOp<FloatAttr>( 679 operands, [](APFloat a, APFloat b) { return a / b; }); 680 } 681 682 //===----------------------------------------------------------------------===// 683 // Utility functions for verifying cast ops 684 //===----------------------------------------------------------------------===// 685 686 template <typename... Types> 687 using type_list = std::tuple<Types...> *; 688 689 /// Returns a non-null type only if the provided type is one of the allowed 690 /// types or one of the allowed shaped types of the allowed types. Returns the 691 /// element type if a valid shaped type is provided. 692 template <typename... ShapedTypes, typename... ElementTypes> 693 static Type getUnderlyingType(Type type, type_list<ShapedTypes...>, 694 type_list<ElementTypes...>) { 695 if (type.isa<ShapedType>() && !type.isa<ShapedTypes...>()) 696 return {}; 697 698 auto underlyingType = getElementTypeOrSelf(type); 699 if (!underlyingType.isa<ElementTypes...>()) 700 return {}; 701 702 return underlyingType; 703 } 704 705 /// Get allowed underlying types for vectors and tensors. 706 template <typename... ElementTypes> 707 static Type getTypeIfLike(Type type) { 708 return getUnderlyingType(type, type_list<VectorType, TensorType>(), 709 type_list<ElementTypes...>()); 710 } 711 712 /// Get allowed underlying types for vectors, tensors, and memrefs. 713 template <typename... ElementTypes> 714 static Type getTypeIfLikeOrMemRef(Type type) { 715 return getUnderlyingType(type, 716 type_list<VectorType, TensorType, MemRefType>(), 717 type_list<ElementTypes...>()); 718 } 719 720 static bool areValidCastInputsAndOutputs(TypeRange inputs, TypeRange outputs) { 721 return inputs.size() == 1 && outputs.size() == 1 && 722 succeeded(verifyCompatibleShapes(inputs.front(), outputs.front())); 723 } 724 725 //===----------------------------------------------------------------------===// 726 // Verifiers for integer and floating point extension/truncation ops 727 //===----------------------------------------------------------------------===// 728 729 // Extend ops can only extend to a wider type. 730 template <typename ValType, typename Op> 731 static LogicalResult verifyExtOp(Op op) { 732 Type srcType = getElementTypeOrSelf(op.getIn().getType()); 733 Type dstType = getElementTypeOrSelf(op.getType()); 734 735 if (srcType.cast<ValType>().getWidth() >= dstType.cast<ValType>().getWidth()) 736 return op.emitError("result type ") 737 << dstType << " must be wider than operand type " << srcType; 738 739 return success(); 740 } 741 742 // Truncate ops can only truncate to a shorter type. 743 template <typename ValType, typename Op> 744 static LogicalResult verifyTruncateOp(Op op) { 745 Type srcType = getElementTypeOrSelf(op.getIn().getType()); 746 Type dstType = getElementTypeOrSelf(op.getType()); 747 748 if (srcType.cast<ValType>().getWidth() <= dstType.cast<ValType>().getWidth()) 749 return op.emitError("result type ") 750 << dstType << " must be shorter than operand type " << srcType; 751 752 return success(); 753 } 754 755 /// Validate a cast that changes the width of a type. 756 template <template <typename> class WidthComparator, typename... ElementTypes> 757 static bool checkWidthChangeCast(TypeRange inputs, TypeRange outputs) { 758 if (!areValidCastInputsAndOutputs(inputs, outputs)) 759 return false; 760 761 auto srcType = getTypeIfLike<ElementTypes...>(inputs.front()); 762 auto dstType = getTypeIfLike<ElementTypes...>(outputs.front()); 763 if (!srcType || !dstType) 764 return false; 765 766 return WidthComparator<unsigned>()(dstType.getIntOrFloatBitWidth(), 767 srcType.getIntOrFloatBitWidth()); 768 } 769 770 //===----------------------------------------------------------------------===// 771 // ExtUIOp 772 //===----------------------------------------------------------------------===// 773 774 OpFoldResult arith::ExtUIOp::fold(ArrayRef<Attribute> operands) { 775 if (auto lhs = operands[0].dyn_cast_or_null<IntegerAttr>()) 776 return IntegerAttr::get( 777 getType(), lhs.getValue().zext(getType().getIntOrFloatBitWidth())); 778 779 return {}; 780 } 781 782 bool arith::ExtUIOp::areCastCompatible(TypeRange inputs, TypeRange outputs) { 783 return checkWidthChangeCast<std::greater, IntegerType>(inputs, outputs); 784 } 785 786 //===----------------------------------------------------------------------===// 787 // ExtSIOp 788 //===----------------------------------------------------------------------===// 789 790 OpFoldResult arith::ExtSIOp::fold(ArrayRef<Attribute> operands) { 791 if (auto lhs = operands[0].dyn_cast_or_null<IntegerAttr>()) 792 return IntegerAttr::get( 793 getType(), lhs.getValue().sext(getType().getIntOrFloatBitWidth())); 794 795 return {}; 796 } 797 798 bool arith::ExtSIOp::areCastCompatible(TypeRange inputs, TypeRange outputs) { 799 return checkWidthChangeCast<std::greater, IntegerType>(inputs, outputs); 800 } 801 802 //===----------------------------------------------------------------------===// 803 // ExtFOp 804 //===----------------------------------------------------------------------===// 805 806 bool arith::ExtFOp::areCastCompatible(TypeRange inputs, TypeRange outputs) { 807 return checkWidthChangeCast<std::greater, FloatType>(inputs, outputs); 808 } 809 810 //===----------------------------------------------------------------------===// 811 // TruncIOp 812 //===----------------------------------------------------------------------===// 813 814 OpFoldResult arith::TruncIOp::fold(ArrayRef<Attribute> operands) { 815 // trunci(zexti(a)) -> a 816 // trunci(sexti(a)) -> a 817 if (matchPattern(getOperand(), m_Op<arith::ExtUIOp>()) || 818 matchPattern(getOperand(), m_Op<arith::ExtSIOp>())) 819 return getOperand().getDefiningOp()->getOperand(0); 820 821 assert(operands.size() == 1 && "unary operation takes one operand"); 822 823 if (!operands[0]) 824 return {}; 825 826 if (auto lhs = operands[0].dyn_cast<IntegerAttr>()) { 827 return IntegerAttr::get( 828 getType(), lhs.getValue().trunc(getType().getIntOrFloatBitWidth())); 829 } 830 831 return {}; 832 } 833 834 bool arith::TruncIOp::areCastCompatible(TypeRange inputs, TypeRange outputs) { 835 return checkWidthChangeCast<std::less, IntegerType>(inputs, outputs); 836 } 837 838 //===----------------------------------------------------------------------===// 839 // TruncFOp 840 //===----------------------------------------------------------------------===// 841 842 /// Perform safe const propagation for truncf, i.e. only propagate if FP value 843 /// can be represented without precision loss or rounding. 844 OpFoldResult arith::TruncFOp::fold(ArrayRef<Attribute> operands) { 845 assert(operands.size() == 1 && "unary operation takes one operand"); 846 847 auto constOperand = operands.front(); 848 if (!constOperand || !constOperand.isa<FloatAttr>()) 849 return {}; 850 851 // Convert to target type via 'double'. 852 double sourceValue = 853 constOperand.dyn_cast<FloatAttr>().getValue().convertToDouble(); 854 auto targetAttr = FloatAttr::get(getType(), sourceValue); 855 856 // Propagate if constant's value does not change after truncation. 857 if (sourceValue == targetAttr.getValue().convertToDouble()) 858 return targetAttr; 859 860 return {}; 861 } 862 863 bool arith::TruncFOp::areCastCompatible(TypeRange inputs, TypeRange outputs) { 864 return checkWidthChangeCast<std::less, FloatType>(inputs, outputs); 865 } 866 867 //===----------------------------------------------------------------------===// 868 // Verifiers for casts between integers and floats. 869 //===----------------------------------------------------------------------===// 870 871 template <typename From, typename To> 872 static bool checkIntFloatCast(TypeRange inputs, TypeRange outputs) { 873 if (!areValidCastInputsAndOutputs(inputs, outputs)) 874 return false; 875 876 auto srcType = getTypeIfLike<From>(inputs.front()); 877 auto dstType = getTypeIfLike<To>(outputs.back()); 878 879 return srcType && dstType; 880 } 881 882 //===----------------------------------------------------------------------===// 883 // UIToFPOp 884 //===----------------------------------------------------------------------===// 885 886 bool arith::UIToFPOp::areCastCompatible(TypeRange inputs, TypeRange outputs) { 887 return checkIntFloatCast<IntegerType, FloatType>(inputs, outputs); 888 } 889 890 //===----------------------------------------------------------------------===// 891 // SIToFPOp 892 //===----------------------------------------------------------------------===// 893 894 bool arith::SIToFPOp::areCastCompatible(TypeRange inputs, TypeRange outputs) { 895 return checkIntFloatCast<IntegerType, FloatType>(inputs, outputs); 896 } 897 898 //===----------------------------------------------------------------------===// 899 // FPToUIOp 900 //===----------------------------------------------------------------------===// 901 902 bool arith::FPToUIOp::areCastCompatible(TypeRange inputs, TypeRange outputs) { 903 return checkIntFloatCast<FloatType, IntegerType>(inputs, outputs); 904 } 905 906 //===----------------------------------------------------------------------===// 907 // FPToSIOp 908 //===----------------------------------------------------------------------===// 909 910 bool arith::FPToSIOp::areCastCompatible(TypeRange inputs, TypeRange outputs) { 911 return checkIntFloatCast<FloatType, IntegerType>(inputs, outputs); 912 } 913 914 //===----------------------------------------------------------------------===// 915 // IndexCastOp 916 //===----------------------------------------------------------------------===// 917 918 bool arith::IndexCastOp::areCastCompatible(TypeRange inputs, 919 TypeRange outputs) { 920 if (!areValidCastInputsAndOutputs(inputs, outputs)) 921 return false; 922 923 auto srcType = getTypeIfLikeOrMemRef<IntegerType, IndexType>(inputs.front()); 924 auto dstType = getTypeIfLikeOrMemRef<IntegerType, IndexType>(outputs.front()); 925 if (!srcType || !dstType) 926 return false; 927 928 return (srcType.isIndex() && dstType.isSignlessInteger()) || 929 (srcType.isSignlessInteger() && dstType.isIndex()); 930 } 931 932 OpFoldResult arith::IndexCastOp::fold(ArrayRef<Attribute> operands) { 933 // index_cast(constant) -> constant 934 // A little hack because we go through int. Otherwise, the size of the 935 // constant might need to change. 936 if (auto value = operands[0].dyn_cast_or_null<IntegerAttr>()) 937 return IntegerAttr::get(getType(), value.getInt()); 938 939 return {}; 940 } 941 942 void arith::IndexCastOp::getCanonicalizationPatterns( 943 OwningRewritePatternList &patterns, MLIRContext *context) { 944 patterns.insert<IndexCastOfIndexCast, IndexCastOfExtSI>(context); 945 } 946 947 //===----------------------------------------------------------------------===// 948 // BitcastOp 949 //===----------------------------------------------------------------------===// 950 951 bool arith::BitcastOp::areCastCompatible(TypeRange inputs, TypeRange outputs) { 952 if (!areValidCastInputsAndOutputs(inputs, outputs)) 953 return false; 954 955 auto srcType = 956 getTypeIfLikeOrMemRef<IntegerType, IndexType, FloatType>(inputs.front()); 957 auto dstType = 958 getTypeIfLikeOrMemRef<IntegerType, IndexType, FloatType>(outputs.front()); 959 if (!srcType || !dstType) 960 return false; 961 962 return srcType.getIntOrFloatBitWidth() == dstType.getIntOrFloatBitWidth(); 963 } 964 965 OpFoldResult arith::BitcastOp::fold(ArrayRef<Attribute> operands) { 966 assert(operands.size() == 1 && "bitcast op expects 1 operand"); 967 968 auto resType = getType(); 969 auto operand = operands[0]; 970 if (!operand) 971 return {}; 972 973 /// Bitcast dense elements. 974 if (auto denseAttr = operand.dyn_cast_or_null<DenseElementsAttr>()) 975 return denseAttr.bitcast(resType.cast<ShapedType>().getElementType()); 976 /// Other shaped types unhandled. 977 if (resType.isa<ShapedType>()) 978 return {}; 979 980 /// Bitcast integer or float to integer or float. 981 APInt bits = operand.isa<FloatAttr>() 982 ? operand.cast<FloatAttr>().getValue().bitcastToAPInt() 983 : operand.cast<IntegerAttr>().getValue(); 984 985 if (auto resFloatType = resType.dyn_cast<FloatType>()) 986 return FloatAttr::get(resType, 987 APFloat(resFloatType.getFloatSemantics(), bits)); 988 return IntegerAttr::get(resType, bits); 989 } 990 991 void arith::BitcastOp::getCanonicalizationPatterns( 992 OwningRewritePatternList &patterns, MLIRContext *context) { 993 patterns.insert<BitcastOfBitcast>(context); 994 } 995 996 //===----------------------------------------------------------------------===// 997 // Helpers for compare ops 998 //===----------------------------------------------------------------------===// 999 1000 /// Return the type of the same shape (scalar, vector or tensor) containing i1. 1001 static Type getI1SameShape(Type type) { 1002 auto i1Type = IntegerType::get(type.getContext(), 1); 1003 if (auto tensorType = type.dyn_cast<RankedTensorType>()) 1004 return RankedTensorType::get(tensorType.getShape(), i1Type); 1005 if (type.isa<UnrankedTensorType>()) 1006 return UnrankedTensorType::get(i1Type); 1007 if (auto vectorType = type.dyn_cast<VectorType>()) 1008 return VectorType::get(vectorType.getShape(), i1Type); 1009 return i1Type; 1010 } 1011 1012 //===----------------------------------------------------------------------===// 1013 // CmpIOp 1014 //===----------------------------------------------------------------------===// 1015 1016 /// Compute `lhs` `pred` `rhs`, where `pred` is one of the known integer 1017 /// comparison predicates. 1018 bool mlir::arith::applyCmpPredicate(arith::CmpIPredicate predicate, 1019 const APInt &lhs, const APInt &rhs) { 1020 switch (predicate) { 1021 case arith::CmpIPredicate::eq: 1022 return lhs.eq(rhs); 1023 case arith::CmpIPredicate::ne: 1024 return lhs.ne(rhs); 1025 case arith::CmpIPredicate::slt: 1026 return lhs.slt(rhs); 1027 case arith::CmpIPredicate::sle: 1028 return lhs.sle(rhs); 1029 case arith::CmpIPredicate::sgt: 1030 return lhs.sgt(rhs); 1031 case arith::CmpIPredicate::sge: 1032 return lhs.sge(rhs); 1033 case arith::CmpIPredicate::ult: 1034 return lhs.ult(rhs); 1035 case arith::CmpIPredicate::ule: 1036 return lhs.ule(rhs); 1037 case arith::CmpIPredicate::ugt: 1038 return lhs.ugt(rhs); 1039 case arith::CmpIPredicate::uge: 1040 return lhs.uge(rhs); 1041 } 1042 llvm_unreachable("unknown cmpi predicate kind"); 1043 } 1044 1045 /// Returns true if the predicate is true for two equal operands. 1046 static bool applyCmpPredicateToEqualOperands(arith::CmpIPredicate predicate) { 1047 switch (predicate) { 1048 case arith::CmpIPredicate::eq: 1049 case arith::CmpIPredicate::sle: 1050 case arith::CmpIPredicate::sge: 1051 case arith::CmpIPredicate::ule: 1052 case arith::CmpIPredicate::uge: 1053 return true; 1054 case arith::CmpIPredicate::ne: 1055 case arith::CmpIPredicate::slt: 1056 case arith::CmpIPredicate::sgt: 1057 case arith::CmpIPredicate::ult: 1058 case arith::CmpIPredicate::ugt: 1059 return false; 1060 } 1061 llvm_unreachable("unknown cmpi predicate kind"); 1062 } 1063 1064 OpFoldResult arith::CmpIOp::fold(ArrayRef<Attribute> operands) { 1065 assert(operands.size() == 2 && "cmpi takes two operands"); 1066 1067 // cmpi(pred, x, x) 1068 if (getLhs() == getRhs()) { 1069 auto val = applyCmpPredicateToEqualOperands(getPredicate()); 1070 return BoolAttr::get(getContext(), val); 1071 } 1072 1073 auto lhs = operands.front().dyn_cast_or_null<IntegerAttr>(); 1074 auto rhs = operands.back().dyn_cast_or_null<IntegerAttr>(); 1075 if (!lhs || !rhs) 1076 return {}; 1077 1078 auto val = applyCmpPredicate(getPredicate(), lhs.getValue(), rhs.getValue()); 1079 return BoolAttr::get(getContext(), val); 1080 } 1081 1082 //===----------------------------------------------------------------------===// 1083 // CmpFOp 1084 //===----------------------------------------------------------------------===// 1085 1086 /// Compute `lhs` `pred` `rhs`, where `pred` is one of the known floating point 1087 /// comparison predicates. 1088 bool mlir::arith::applyCmpPredicate(arith::CmpFPredicate predicate, 1089 const APFloat &lhs, const APFloat &rhs) { 1090 auto cmpResult = lhs.compare(rhs); 1091 switch (predicate) { 1092 case arith::CmpFPredicate::AlwaysFalse: 1093 return false; 1094 case arith::CmpFPredicate::OEQ: 1095 return cmpResult == APFloat::cmpEqual; 1096 case arith::CmpFPredicate::OGT: 1097 return cmpResult == APFloat::cmpGreaterThan; 1098 case arith::CmpFPredicate::OGE: 1099 return cmpResult == APFloat::cmpGreaterThan || 1100 cmpResult == APFloat::cmpEqual; 1101 case arith::CmpFPredicate::OLT: 1102 return cmpResult == APFloat::cmpLessThan; 1103 case arith::CmpFPredicate::OLE: 1104 return cmpResult == APFloat::cmpLessThan || cmpResult == APFloat::cmpEqual; 1105 case arith::CmpFPredicate::ONE: 1106 return cmpResult != APFloat::cmpUnordered && cmpResult != APFloat::cmpEqual; 1107 case arith::CmpFPredicate::ORD: 1108 return cmpResult != APFloat::cmpUnordered; 1109 case arith::CmpFPredicate::UEQ: 1110 return cmpResult == APFloat::cmpUnordered || cmpResult == APFloat::cmpEqual; 1111 case arith::CmpFPredicate::UGT: 1112 return cmpResult == APFloat::cmpUnordered || 1113 cmpResult == APFloat::cmpGreaterThan; 1114 case arith::CmpFPredicate::UGE: 1115 return cmpResult == APFloat::cmpUnordered || 1116 cmpResult == APFloat::cmpGreaterThan || 1117 cmpResult == APFloat::cmpEqual; 1118 case arith::CmpFPredicate::ULT: 1119 return cmpResult == APFloat::cmpUnordered || 1120 cmpResult == APFloat::cmpLessThan; 1121 case arith::CmpFPredicate::ULE: 1122 return cmpResult == APFloat::cmpUnordered || 1123 cmpResult == APFloat::cmpLessThan || cmpResult == APFloat::cmpEqual; 1124 case arith::CmpFPredicate::UNE: 1125 return cmpResult != APFloat::cmpEqual; 1126 case arith::CmpFPredicate::UNO: 1127 return cmpResult == APFloat::cmpUnordered; 1128 case arith::CmpFPredicate::AlwaysTrue: 1129 return true; 1130 } 1131 llvm_unreachable("unknown cmpf predicate kind"); 1132 } 1133 1134 OpFoldResult arith::CmpFOp::fold(ArrayRef<Attribute> operands) { 1135 assert(operands.size() == 2 && "cmpf takes two operands"); 1136 1137 auto lhs = operands.front().dyn_cast_or_null<FloatAttr>(); 1138 auto rhs = operands.back().dyn_cast_or_null<FloatAttr>(); 1139 1140 if (!lhs || !rhs) 1141 return {}; 1142 1143 auto val = applyCmpPredicate(getPredicate(), lhs.getValue(), rhs.getValue()); 1144 return BoolAttr::get(getContext(), val); 1145 } 1146 1147 //===----------------------------------------------------------------------===// 1148 // TableGen'd op method definitions 1149 //===----------------------------------------------------------------------===// 1150 1151 #define GET_OP_CLASSES 1152 #include "mlir/Dialect/Arithmetic/IR/ArithmeticOps.cpp.inc" 1153 1154 //===----------------------------------------------------------------------===// 1155 // TableGen'd enum attribute definitions 1156 //===----------------------------------------------------------------------===// 1157 1158 #include "mlir/Dialect/Arithmetic/IR/ArithmeticOpsEnums.cpp.inc" 1159