1 //===- ArithmeticOps.cpp - MLIR Arithmetic dialect ops implementation -----===// 2 // 3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. 4 // See https://llvm.org/LICENSE.txt for license information. 5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception 6 // 7 //===----------------------------------------------------------------------===// 8 9 #include "mlir/Dialect/Arithmetic/IR/Arithmetic.h" 10 #include "mlir/Dialect/CommonFolders.h" 11 #include "mlir/IR/Builders.h" 12 #include "mlir/IR/Matchers.h" 13 #include "mlir/IR/OpImplementation.h" 14 #include "mlir/IR/PatternMatch.h" 15 #include "mlir/IR/TypeUtilities.h" 16 17 using namespace mlir; 18 using namespace mlir::arith; 19 20 //===----------------------------------------------------------------------===// 21 // Pattern helpers 22 //===----------------------------------------------------------------------===// 23 24 static IntegerAttr addIntegerAttrs(PatternRewriter &builder, Value res, 25 Attribute lhs, Attribute rhs) { 26 return builder.getIntegerAttr(res.getType(), 27 lhs.cast<IntegerAttr>().getInt() + 28 rhs.cast<IntegerAttr>().getInt()); 29 } 30 31 static IntegerAttr subIntegerAttrs(PatternRewriter &builder, Value res, 32 Attribute lhs, Attribute rhs) { 33 return builder.getIntegerAttr(res.getType(), 34 lhs.cast<IntegerAttr>().getInt() - 35 rhs.cast<IntegerAttr>().getInt()); 36 } 37 38 /// Invert an integer comparison predicate. 39 static arith::CmpIPredicate invertPredicate(arith::CmpIPredicate pred) { 40 switch (pred) { 41 case arith::CmpIPredicate::eq: 42 return arith::CmpIPredicate::ne; 43 case arith::CmpIPredicate::ne: 44 return arith::CmpIPredicate::eq; 45 case arith::CmpIPredicate::slt: 46 return arith::CmpIPredicate::sge; 47 case arith::CmpIPredicate::sle: 48 return arith::CmpIPredicate::sgt; 49 case arith::CmpIPredicate::sgt: 50 return arith::CmpIPredicate::sle; 51 case arith::CmpIPredicate::sge: 52 return arith::CmpIPredicate::slt; 53 case arith::CmpIPredicate::ult: 54 return arith::CmpIPredicate::uge; 55 case arith::CmpIPredicate::ule: 56 return arith::CmpIPredicate::ugt; 57 case arith::CmpIPredicate::ugt: 58 return arith::CmpIPredicate::ule; 59 case arith::CmpIPredicate::uge: 60 return arith::CmpIPredicate::ult; 61 } 62 llvm_unreachable("unknown cmpi predicate kind"); 63 } 64 65 static arith::CmpIPredicateAttr invertPredicate(arith::CmpIPredicateAttr pred) { 66 return arith::CmpIPredicateAttr::get(pred.getContext(), 67 invertPredicate(pred.getValue())); 68 } 69 70 //===----------------------------------------------------------------------===// 71 // TableGen'd canonicalization patterns 72 //===----------------------------------------------------------------------===// 73 74 namespace { 75 #include "ArithmeticCanonicalization.inc" 76 } // end anonymous namespace 77 78 //===----------------------------------------------------------------------===// 79 // ConstantOp 80 //===----------------------------------------------------------------------===// 81 82 void arith::ConstantOp::getAsmResultNames( 83 function_ref<void(Value, StringRef)> setNameFn) { 84 auto type = getType(); 85 if (auto intCst = getValue().dyn_cast<IntegerAttr>()) { 86 auto intType = type.dyn_cast<IntegerType>(); 87 88 // Sugar i1 constants with 'true' and 'false'. 89 if (intType && intType.getWidth() == 1) 90 return setNameFn(getResult(), (intCst.getInt() ? "true" : "false")); 91 92 // Otherwise, build a compex name with the value and type. 93 SmallString<32> specialNameBuffer; 94 llvm::raw_svector_ostream specialName(specialNameBuffer); 95 specialName << 'c' << intCst.getInt(); 96 if (intType) 97 specialName << '_' << type; 98 setNameFn(getResult(), specialName.str()); 99 } else { 100 setNameFn(getResult(), "cst"); 101 } 102 } 103 104 /// TODO: disallow arith.constant to return anything other than signless integer 105 /// or float like. 106 static LogicalResult verify(arith::ConstantOp op) { 107 auto type = op.getType(); 108 // The value's type must match the return type. 109 if (op.getValue().getType() != type) { 110 return op.emitOpError() << "value type " << op.getValue().getType() 111 << " must match return type: " << type; 112 } 113 // Integer values must be signless. 114 if (type.isa<IntegerType>() && !type.cast<IntegerType>().isSignless()) 115 return op.emitOpError("integer return type must be signless"); 116 // Any float or elements attribute are acceptable. 117 if (!op.getValue().isa<IntegerAttr, FloatAttr, ElementsAttr>()) { 118 return op.emitOpError( 119 "value must be an integer, float, or elements attribute"); 120 } 121 return success(); 122 } 123 124 bool arith::ConstantOp::isBuildableWith(Attribute value, Type type) { 125 // The value's type must be the same as the provided type. 126 if (value.getType() != type) 127 return false; 128 // Integer values must be signless. 129 if (type.isa<IntegerType>() && !type.cast<IntegerType>().isSignless()) 130 return false; 131 // Integer, float, and element attributes are buildable. 132 return value.isa<IntegerAttr, FloatAttr, ElementsAttr>(); 133 } 134 135 OpFoldResult arith::ConstantOp::fold(ArrayRef<Attribute> operands) { 136 return getValue(); 137 } 138 139 void arith::ConstantIntOp::build(OpBuilder &builder, OperationState &result, 140 int64_t value, unsigned width) { 141 auto type = builder.getIntegerType(width); 142 arith::ConstantOp::build(builder, result, type, 143 builder.getIntegerAttr(type, value)); 144 } 145 146 void arith::ConstantIntOp::build(OpBuilder &builder, OperationState &result, 147 int64_t value, Type type) { 148 assert(type.isSignlessInteger() && 149 "ConstantIntOp can only have signless integer type values"); 150 arith::ConstantOp::build(builder, result, type, 151 builder.getIntegerAttr(type, value)); 152 } 153 154 bool arith::ConstantIntOp::classof(Operation *op) { 155 if (auto constOp = dyn_cast_or_null<arith::ConstantOp>(op)) 156 return constOp.getType().isSignlessInteger(); 157 return false; 158 } 159 160 void arith::ConstantFloatOp::build(OpBuilder &builder, OperationState &result, 161 const APFloat &value, FloatType type) { 162 arith::ConstantOp::build(builder, result, type, 163 builder.getFloatAttr(type, value)); 164 } 165 166 bool arith::ConstantFloatOp::classof(Operation *op) { 167 if (auto constOp = dyn_cast_or_null<arith::ConstantOp>(op)) 168 return constOp.getType().isa<FloatType>(); 169 return false; 170 } 171 172 void arith::ConstantIndexOp::build(OpBuilder &builder, OperationState &result, 173 int64_t value) { 174 arith::ConstantOp::build(builder, result, builder.getIndexType(), 175 builder.getIndexAttr(value)); 176 } 177 178 bool arith::ConstantIndexOp::classof(Operation *op) { 179 if (auto constOp = dyn_cast_or_null<arith::ConstantOp>(op)) 180 return constOp.getType().isIndex(); 181 return false; 182 } 183 184 //===----------------------------------------------------------------------===// 185 // AddIOp 186 //===----------------------------------------------------------------------===// 187 188 OpFoldResult arith::AddIOp::fold(ArrayRef<Attribute> operands) { 189 // addi(x, 0) -> x 190 if (matchPattern(getRhs(), m_Zero())) 191 return getLhs(); 192 193 return constFoldBinaryOp<IntegerAttr>(operands, 194 [](APInt a, APInt b) { return a + b; }); 195 } 196 197 void arith::AddIOp::getCanonicalizationPatterns( 198 OwningRewritePatternList &patterns, MLIRContext *context) { 199 patterns.insert<AddIAddConstant, AddISubConstantRHS, AddISubConstantLHS>( 200 context); 201 } 202 203 //===----------------------------------------------------------------------===// 204 // SubIOp 205 //===----------------------------------------------------------------------===// 206 207 OpFoldResult arith::SubIOp::fold(ArrayRef<Attribute> operands) { 208 // subi(x,x) -> 0 209 if (getOperand(0) == getOperand(1)) 210 return Builder(getContext()).getZeroAttr(getType()); 211 // subi(x,0) -> x 212 if (matchPattern(getRhs(), m_Zero())) 213 return getLhs(); 214 215 return constFoldBinaryOp<IntegerAttr>(operands, 216 [](APInt a, APInt b) { return a - b; }); 217 } 218 219 void arith::SubIOp::getCanonicalizationPatterns( 220 OwningRewritePatternList &patterns, MLIRContext *context) { 221 patterns.insert<SubIRHSAddConstant, SubILHSAddConstant, SubIRHSSubConstantRHS, 222 SubIRHSSubConstantLHS, SubILHSSubConstantRHS, 223 SubILHSSubConstantLHS>(context); 224 } 225 226 //===----------------------------------------------------------------------===// 227 // MulIOp 228 //===----------------------------------------------------------------------===// 229 230 OpFoldResult arith::MulIOp::fold(ArrayRef<Attribute> operands) { 231 // muli(x, 0) -> 0 232 if (matchPattern(getRhs(), m_Zero())) 233 return getRhs(); 234 // muli(x, 1) -> x 235 if (matchPattern(getRhs(), m_One())) 236 return getOperand(0); 237 // TODO: Handle the overflow case. 238 239 // default folder 240 return constFoldBinaryOp<IntegerAttr>(operands, 241 [](APInt a, APInt b) { return a * b; }); 242 } 243 244 //===----------------------------------------------------------------------===// 245 // DivUIOp 246 //===----------------------------------------------------------------------===// 247 248 OpFoldResult arith::DivUIOp::fold(ArrayRef<Attribute> operands) { 249 // Don't fold if it would require a division by zero. 250 bool div0 = false; 251 auto result = constFoldBinaryOp<IntegerAttr>(operands, [&](APInt a, APInt b) { 252 if (div0 || !b) { 253 div0 = true; 254 return a; 255 } 256 return a.udiv(b); 257 }); 258 259 // Fold out division by one. Assumes all tensors of all ones are splats. 260 if (auto rhs = operands[1].dyn_cast_or_null<IntegerAttr>()) { 261 if (rhs.getValue() == 1) 262 return getLhs(); 263 } else if (auto rhs = operands[1].dyn_cast_or_null<SplatElementsAttr>()) { 264 if (rhs.getSplatValue<IntegerAttr>().getValue() == 1) 265 return getLhs(); 266 } 267 268 return div0 ? Attribute() : result; 269 } 270 271 //===----------------------------------------------------------------------===// 272 // DivSIOp 273 //===----------------------------------------------------------------------===// 274 275 OpFoldResult arith::DivSIOp::fold(ArrayRef<Attribute> operands) { 276 // Don't fold if it would overflow or if it requires a division by zero. 277 bool overflowOrDiv0 = false; 278 auto result = constFoldBinaryOp<IntegerAttr>(operands, [&](APInt a, APInt b) { 279 if (overflowOrDiv0 || !b) { 280 overflowOrDiv0 = true; 281 return a; 282 } 283 return a.sdiv_ov(b, overflowOrDiv0); 284 }); 285 286 // Fold out division by one. Assumes all tensors of all ones are splats. 287 if (auto rhs = operands[1].dyn_cast_or_null<IntegerAttr>()) { 288 if (rhs.getValue() == 1) 289 return getLhs(); 290 } else if (auto rhs = operands[1].dyn_cast_or_null<SplatElementsAttr>()) { 291 if (rhs.getSplatValue<IntegerAttr>().getValue() == 1) 292 return getLhs(); 293 } 294 295 return overflowOrDiv0 ? Attribute() : result; 296 } 297 298 //===----------------------------------------------------------------------===// 299 // Ceil and floor division folding helpers 300 //===----------------------------------------------------------------------===// 301 302 static APInt signedCeilNonnegInputs(APInt a, APInt b, bool &overflow) { 303 // Returns (a-1)/b + 1 304 APInt one(a.getBitWidth(), 1, true); // Signed value 1. 305 APInt val = a.ssub_ov(one, overflow).sdiv_ov(b, overflow); 306 return val.sadd_ov(one, overflow); 307 } 308 309 //===----------------------------------------------------------------------===// 310 // CeilDivUIOp 311 //===----------------------------------------------------------------------===// 312 313 OpFoldResult arith::CeilDivUIOp::fold(ArrayRef<Attribute> operands) { 314 bool overflowOrDiv0 = false; 315 auto result = constFoldBinaryOp<IntegerAttr>(operands, [&](APInt a, APInt b) { 316 if (overflowOrDiv0 || !b) { 317 overflowOrDiv0 = true; 318 return a; 319 } 320 APInt quotient = a.udiv(b); 321 if (!a.urem(b)) 322 return quotient; 323 APInt one(a.getBitWidth(), 1, true); 324 return quotient.uadd_ov(one, overflowOrDiv0); 325 }); 326 // Fold out ceil division by one. Assumes all tensors of all ones are 327 // splats. 328 if (auto rhs = operands[1].dyn_cast_or_null<IntegerAttr>()) { 329 if (rhs.getValue() == 1) 330 return getLhs(); 331 } else if (auto rhs = operands[1].dyn_cast_or_null<SplatElementsAttr>()) { 332 if (rhs.getSplatValue<IntegerAttr>().getValue() == 1) 333 return getLhs(); 334 } 335 336 return overflowOrDiv0 ? Attribute() : result; 337 } 338 339 //===----------------------------------------------------------------------===// 340 // CeilDivSIOp 341 //===----------------------------------------------------------------------===// 342 343 OpFoldResult arith::CeilDivSIOp::fold(ArrayRef<Attribute> operands) { 344 // Don't fold if it would overflow or if it requires a division by zero. 345 bool overflowOrDiv0 = false; 346 auto result = constFoldBinaryOp<IntegerAttr>(operands, [&](APInt a, APInt b) { 347 if (overflowOrDiv0 || !b) { 348 overflowOrDiv0 = true; 349 return a; 350 } 351 unsigned bits = a.getBitWidth(); 352 APInt zero = APInt::getZero(bits); 353 if (a.sgt(zero) && b.sgt(zero)) { 354 // Both positive, return ceil(a, b). 355 return signedCeilNonnegInputs(a, b, overflowOrDiv0); 356 } 357 if (a.slt(zero) && b.slt(zero)) { 358 // Both negative, return ceil(-a, -b). 359 APInt posA = zero.ssub_ov(a, overflowOrDiv0); 360 APInt posB = zero.ssub_ov(b, overflowOrDiv0); 361 return signedCeilNonnegInputs(posA, posB, overflowOrDiv0); 362 } 363 if (a.slt(zero) && b.sgt(zero)) { 364 // A is negative, b is positive, return - ( -a / b). 365 APInt posA = zero.ssub_ov(a, overflowOrDiv0); 366 APInt div = posA.sdiv_ov(b, overflowOrDiv0); 367 return zero.ssub_ov(div, overflowOrDiv0); 368 } 369 // A is positive (or zero), b is negative, return - (a / -b). 370 APInt posB = zero.ssub_ov(b, overflowOrDiv0); 371 APInt div = a.sdiv_ov(posB, overflowOrDiv0); 372 return zero.ssub_ov(div, overflowOrDiv0); 373 }); 374 375 // Fold out ceil division by one. Assumes all tensors of all ones are 376 // splats. 377 if (auto rhs = operands[1].dyn_cast_or_null<IntegerAttr>()) { 378 if (rhs.getValue() == 1) 379 return getLhs(); 380 } else if (auto rhs = operands[1].dyn_cast_or_null<SplatElementsAttr>()) { 381 if (rhs.getSplatValue<IntegerAttr>().getValue() == 1) 382 return getLhs(); 383 } 384 385 return overflowOrDiv0 ? Attribute() : result; 386 } 387 388 //===----------------------------------------------------------------------===// 389 // FloorDivSIOp 390 //===----------------------------------------------------------------------===// 391 392 OpFoldResult arith::FloorDivSIOp::fold(ArrayRef<Attribute> operands) { 393 // Don't fold if it would overflow or if it requires a division by zero. 394 bool overflowOrDiv0 = false; 395 auto result = constFoldBinaryOp<IntegerAttr>(operands, [&](APInt a, APInt b) { 396 if (overflowOrDiv0 || !b) { 397 overflowOrDiv0 = true; 398 return a; 399 } 400 unsigned bits = a.getBitWidth(); 401 APInt zero = APInt::getZero(bits); 402 if (a.sge(zero) && b.sgt(zero)) { 403 // Both positive (or a is zero), return a / b. 404 return a.sdiv_ov(b, overflowOrDiv0); 405 } 406 if (a.sle(zero) && b.slt(zero)) { 407 // Both negative (or a is zero), return -a / -b. 408 APInt posA = zero.ssub_ov(a, overflowOrDiv0); 409 APInt posB = zero.ssub_ov(b, overflowOrDiv0); 410 return posA.sdiv_ov(posB, overflowOrDiv0); 411 } 412 if (a.slt(zero) && b.sgt(zero)) { 413 // A is negative, b is positive, return - ceil(-a, b). 414 APInt posA = zero.ssub_ov(a, overflowOrDiv0); 415 APInt ceil = signedCeilNonnegInputs(posA, b, overflowOrDiv0); 416 return zero.ssub_ov(ceil, overflowOrDiv0); 417 } 418 // A is positive, b is negative, return - ceil(a, -b). 419 APInt posB = zero.ssub_ov(b, overflowOrDiv0); 420 APInt ceil = signedCeilNonnegInputs(a, posB, overflowOrDiv0); 421 return zero.ssub_ov(ceil, overflowOrDiv0); 422 }); 423 424 // Fold out floor division by one. Assumes all tensors of all ones are 425 // splats. 426 if (auto rhs = operands[1].dyn_cast_or_null<IntegerAttr>()) { 427 if (rhs.getValue() == 1) 428 return getLhs(); 429 } else if (auto rhs = operands[1].dyn_cast_or_null<SplatElementsAttr>()) { 430 if (rhs.getSplatValue<IntegerAttr>().getValue() == 1) 431 return getLhs(); 432 } 433 434 return overflowOrDiv0 ? Attribute() : result; 435 } 436 437 //===----------------------------------------------------------------------===// 438 // RemUIOp 439 //===----------------------------------------------------------------------===// 440 441 OpFoldResult arith::RemUIOp::fold(ArrayRef<Attribute> operands) { 442 auto rhs = operands.back().dyn_cast_or_null<IntegerAttr>(); 443 if (!rhs) 444 return {}; 445 auto rhsValue = rhs.getValue(); 446 447 // x % 1 = 0 448 if (rhsValue.isOneValue()) 449 return IntegerAttr::get(rhs.getType(), APInt(rhsValue.getBitWidth(), 0)); 450 451 // Don't fold if it requires division by zero. 452 if (rhsValue.isNullValue()) 453 return {}; 454 455 auto lhs = operands.front().dyn_cast_or_null<IntegerAttr>(); 456 if (!lhs) 457 return {}; 458 return IntegerAttr::get(lhs.getType(), lhs.getValue().urem(rhsValue)); 459 } 460 461 //===----------------------------------------------------------------------===// 462 // RemSIOp 463 //===----------------------------------------------------------------------===// 464 465 OpFoldResult arith::RemSIOp::fold(ArrayRef<Attribute> operands) { 466 auto rhs = operands.back().dyn_cast_or_null<IntegerAttr>(); 467 if (!rhs) 468 return {}; 469 auto rhsValue = rhs.getValue(); 470 471 // x % 1 = 0 472 if (rhsValue.isOneValue()) 473 return IntegerAttr::get(rhs.getType(), APInt(rhsValue.getBitWidth(), 0)); 474 475 // Don't fold if it requires division by zero. 476 if (rhsValue.isNullValue()) 477 return {}; 478 479 auto lhs = operands.front().dyn_cast_or_null<IntegerAttr>(); 480 if (!lhs) 481 return {}; 482 return IntegerAttr::get(lhs.getType(), lhs.getValue().srem(rhsValue)); 483 } 484 485 //===----------------------------------------------------------------------===// 486 // AndIOp 487 //===----------------------------------------------------------------------===// 488 489 OpFoldResult arith::AndIOp::fold(ArrayRef<Attribute> operands) { 490 /// and(x, 0) -> 0 491 if (matchPattern(getRhs(), m_Zero())) 492 return getRhs(); 493 /// and(x, allOnes) -> x 494 APInt intValue; 495 if (matchPattern(getRhs(), m_ConstantInt(&intValue)) && intValue.isAllOnes()) 496 return getLhs(); 497 498 return constFoldBinaryOp<IntegerAttr>(operands, 499 [](APInt a, APInt b) { return a & b; }); 500 } 501 502 //===----------------------------------------------------------------------===// 503 // OrIOp 504 //===----------------------------------------------------------------------===// 505 506 OpFoldResult arith::OrIOp::fold(ArrayRef<Attribute> operands) { 507 /// or(x, 0) -> x 508 if (matchPattern(getRhs(), m_Zero())) 509 return getLhs(); 510 /// or(x, <all ones>) -> <all ones> 511 if (auto rhsAttr = operands[1].dyn_cast_or_null<IntegerAttr>()) 512 if (rhsAttr.getValue().isAllOnes()) 513 return rhsAttr; 514 515 return constFoldBinaryOp<IntegerAttr>(operands, 516 [](APInt a, APInt b) { return a | b; }); 517 } 518 519 //===----------------------------------------------------------------------===// 520 // XOrIOp 521 //===----------------------------------------------------------------------===// 522 523 OpFoldResult arith::XOrIOp::fold(ArrayRef<Attribute> operands) { 524 /// xor(x, 0) -> x 525 if (matchPattern(getRhs(), m_Zero())) 526 return getLhs(); 527 /// xor(x, x) -> 0 528 if (getLhs() == getRhs()) 529 return Builder(getContext()).getZeroAttr(getType()); 530 531 return constFoldBinaryOp<IntegerAttr>(operands, 532 [](APInt a, APInt b) { return a ^ b; }); 533 } 534 535 void arith::XOrIOp::getCanonicalizationPatterns( 536 OwningRewritePatternList &patterns, MLIRContext *context) { 537 patterns.insert<XOrINotCmpI>(context); 538 } 539 540 //===----------------------------------------------------------------------===// 541 // AddFOp 542 //===----------------------------------------------------------------------===// 543 544 OpFoldResult arith::AddFOp::fold(ArrayRef<Attribute> operands) { 545 return constFoldBinaryOp<FloatAttr>( 546 operands, [](APFloat a, APFloat b) { return a + b; }); 547 } 548 549 //===----------------------------------------------------------------------===// 550 // SubFOp 551 //===----------------------------------------------------------------------===// 552 553 OpFoldResult arith::SubFOp::fold(ArrayRef<Attribute> operands) { 554 return constFoldBinaryOp<FloatAttr>( 555 operands, [](APFloat a, APFloat b) { return a - b; }); 556 } 557 558 //===----------------------------------------------------------------------===// 559 // MaxSIOp 560 //===----------------------------------------------------------------------===// 561 562 OpFoldResult MaxSIOp::fold(ArrayRef<Attribute> operands) { 563 assert(operands.size() == 2 && "binary operation takes two operands"); 564 565 // maxsi(x,x) -> x 566 if (getLhs() == getRhs()) 567 return getRhs(); 568 569 APInt intValue; 570 // maxsi(x,MAX_INT) -> MAX_INT 571 if (matchPattern(getRhs(), m_ConstantInt(&intValue)) && 572 intValue.isMaxSignedValue()) 573 return getRhs(); 574 575 // maxsi(x, MIN_INT) -> x 576 if (matchPattern(getRhs(), m_ConstantInt(&intValue)) && 577 intValue.isMinSignedValue()) 578 return getLhs(); 579 580 return constFoldBinaryOp<IntegerAttr>( 581 operands, [](APInt a, APInt b) { return llvm::APIntOps::smax(a, b); }); 582 } 583 584 //===----------------------------------------------------------------------===// 585 // MaxUIOp 586 //===----------------------------------------------------------------------===// 587 588 OpFoldResult MaxUIOp::fold(ArrayRef<Attribute> operands) { 589 assert(operands.size() == 2 && "binary operation takes two operands"); 590 591 // maxui(x,x) -> x 592 if (getLhs() == getRhs()) 593 return getRhs(); 594 595 APInt intValue; 596 // maxui(x,MAX_INT) -> MAX_INT 597 if (matchPattern(getRhs(), m_ConstantInt(&intValue)) && intValue.isMaxValue()) 598 return getRhs(); 599 600 // maxui(x, MIN_INT) -> x 601 if (matchPattern(getRhs(), m_ConstantInt(&intValue)) && intValue.isMinValue()) 602 return getLhs(); 603 604 return constFoldBinaryOp<IntegerAttr>( 605 operands, [](APInt a, APInt b) { return llvm::APIntOps::umax(a, b); }); 606 } 607 608 //===----------------------------------------------------------------------===// 609 // MinSIOp 610 //===----------------------------------------------------------------------===// 611 612 OpFoldResult MinSIOp::fold(ArrayRef<Attribute> operands) { 613 assert(operands.size() == 2 && "binary operation takes two operands"); 614 615 // minsi(x,x) -> x 616 if (getLhs() == getRhs()) 617 return getRhs(); 618 619 APInt intValue; 620 // minsi(x,MIN_INT) -> MIN_INT 621 if (matchPattern(getRhs(), m_ConstantInt(&intValue)) && 622 intValue.isMinSignedValue()) 623 return getRhs(); 624 625 // minsi(x, MAX_INT) -> x 626 if (matchPattern(getRhs(), m_ConstantInt(&intValue)) && 627 intValue.isMaxSignedValue()) 628 return getLhs(); 629 630 return constFoldBinaryOp<IntegerAttr>( 631 operands, [](APInt a, APInt b) { return llvm::APIntOps::smin(a, b); }); 632 } 633 634 //===----------------------------------------------------------------------===// 635 // MinUIOp 636 //===----------------------------------------------------------------------===// 637 638 OpFoldResult MinUIOp::fold(ArrayRef<Attribute> operands) { 639 assert(operands.size() == 2 && "binary operation takes two operands"); 640 641 // minui(x,x) -> x 642 if (getLhs() == getRhs()) 643 return getRhs(); 644 645 APInt intValue; 646 // minui(x,MIN_INT) -> MIN_INT 647 if (matchPattern(getRhs(), m_ConstantInt(&intValue)) && intValue.isMinValue()) 648 return getRhs(); 649 650 // minui(x, MAX_INT) -> x 651 if (matchPattern(getRhs(), m_ConstantInt(&intValue)) && intValue.isMaxValue()) 652 return getLhs(); 653 654 return constFoldBinaryOp<IntegerAttr>( 655 operands, [](APInt a, APInt b) { return llvm::APIntOps::umin(a, b); }); 656 } 657 658 //===----------------------------------------------------------------------===// 659 // MulFOp 660 //===----------------------------------------------------------------------===// 661 662 OpFoldResult arith::MulFOp::fold(ArrayRef<Attribute> operands) { 663 return constFoldBinaryOp<FloatAttr>( 664 operands, [](APFloat a, APFloat b) { return a * b; }); 665 } 666 667 //===----------------------------------------------------------------------===// 668 // DivFOp 669 //===----------------------------------------------------------------------===// 670 671 OpFoldResult arith::DivFOp::fold(ArrayRef<Attribute> operands) { 672 return constFoldBinaryOp<FloatAttr>( 673 operands, [](APFloat a, APFloat b) { return a / b; }); 674 } 675 676 //===----------------------------------------------------------------------===// 677 // Utility functions for verifying cast ops 678 //===----------------------------------------------------------------------===// 679 680 template <typename... Types> 681 using type_list = std::tuple<Types...> *; 682 683 /// Returns a non-null type only if the provided type is one of the allowed 684 /// types or one of the allowed shaped types of the allowed types. Returns the 685 /// element type if a valid shaped type is provided. 686 template <typename... ShapedTypes, typename... ElementTypes> 687 static Type getUnderlyingType(Type type, type_list<ShapedTypes...>, 688 type_list<ElementTypes...>) { 689 if (type.isa<ShapedType>() && !type.isa<ShapedTypes...>()) 690 return {}; 691 692 auto underlyingType = getElementTypeOrSelf(type); 693 if (!underlyingType.isa<ElementTypes...>()) 694 return {}; 695 696 return underlyingType; 697 } 698 699 /// Get allowed underlying types for vectors and tensors. 700 template <typename... ElementTypes> 701 static Type getTypeIfLike(Type type) { 702 return getUnderlyingType(type, type_list<VectorType, TensorType>(), 703 type_list<ElementTypes...>()); 704 } 705 706 /// Get allowed underlying types for vectors, tensors, and memrefs. 707 template <typename... ElementTypes> 708 static Type getTypeIfLikeOrMemRef(Type type) { 709 return getUnderlyingType(type, 710 type_list<VectorType, TensorType, MemRefType>(), 711 type_list<ElementTypes...>()); 712 } 713 714 static bool areValidCastInputsAndOutputs(TypeRange inputs, TypeRange outputs) { 715 return inputs.size() == 1 && outputs.size() == 1 && 716 succeeded(verifyCompatibleShapes(inputs.front(), outputs.front())); 717 } 718 719 //===----------------------------------------------------------------------===// 720 // Verifiers for integer and floating point extension/truncation ops 721 //===----------------------------------------------------------------------===// 722 723 // Extend ops can only extend to a wider type. 724 template <typename ValType, typename Op> 725 static LogicalResult verifyExtOp(Op op) { 726 Type srcType = getElementTypeOrSelf(op.getIn().getType()); 727 Type dstType = getElementTypeOrSelf(op.getType()); 728 729 if (srcType.cast<ValType>().getWidth() >= dstType.cast<ValType>().getWidth()) 730 return op.emitError("result type ") 731 << dstType << " must be wider than operand type " << srcType; 732 733 return success(); 734 } 735 736 // Truncate ops can only truncate to a shorter type. 737 template <typename ValType, typename Op> 738 static LogicalResult verifyTruncateOp(Op op) { 739 Type srcType = getElementTypeOrSelf(op.getIn().getType()); 740 Type dstType = getElementTypeOrSelf(op.getType()); 741 742 if (srcType.cast<ValType>().getWidth() <= dstType.cast<ValType>().getWidth()) 743 return op.emitError("result type ") 744 << dstType << " must be shorter than operand type " << srcType; 745 746 return success(); 747 } 748 749 /// Validate a cast that changes the width of a type. 750 template <template <typename> class WidthComparator, typename... ElementTypes> 751 static bool checkWidthChangeCast(TypeRange inputs, TypeRange outputs) { 752 if (!areValidCastInputsAndOutputs(inputs, outputs)) 753 return false; 754 755 auto srcType = getTypeIfLike<ElementTypes...>(inputs.front()); 756 auto dstType = getTypeIfLike<ElementTypes...>(outputs.front()); 757 if (!srcType || !dstType) 758 return false; 759 760 return WidthComparator<unsigned>()(dstType.getIntOrFloatBitWidth(), 761 srcType.getIntOrFloatBitWidth()); 762 } 763 764 //===----------------------------------------------------------------------===// 765 // ExtUIOp 766 //===----------------------------------------------------------------------===// 767 768 OpFoldResult arith::ExtUIOp::fold(ArrayRef<Attribute> operands) { 769 if (auto lhs = operands[0].dyn_cast_or_null<IntegerAttr>()) 770 return IntegerAttr::get( 771 getType(), lhs.getValue().zext(getType().getIntOrFloatBitWidth())); 772 773 return {}; 774 } 775 776 bool arith::ExtUIOp::areCastCompatible(TypeRange inputs, TypeRange outputs) { 777 return checkWidthChangeCast<std::greater, IntegerType>(inputs, outputs); 778 } 779 780 //===----------------------------------------------------------------------===// 781 // ExtSIOp 782 //===----------------------------------------------------------------------===// 783 784 OpFoldResult arith::ExtSIOp::fold(ArrayRef<Attribute> operands) { 785 if (auto lhs = operands[0].dyn_cast_or_null<IntegerAttr>()) 786 return IntegerAttr::get( 787 getType(), lhs.getValue().sext(getType().getIntOrFloatBitWidth())); 788 789 return {}; 790 } 791 792 bool arith::ExtSIOp::areCastCompatible(TypeRange inputs, TypeRange outputs) { 793 return checkWidthChangeCast<std::greater, IntegerType>(inputs, outputs); 794 } 795 796 //===----------------------------------------------------------------------===// 797 // ExtFOp 798 //===----------------------------------------------------------------------===// 799 800 bool arith::ExtFOp::areCastCompatible(TypeRange inputs, TypeRange outputs) { 801 return checkWidthChangeCast<std::greater, FloatType>(inputs, outputs); 802 } 803 804 //===----------------------------------------------------------------------===// 805 // TruncIOp 806 //===----------------------------------------------------------------------===// 807 808 OpFoldResult arith::TruncIOp::fold(ArrayRef<Attribute> operands) { 809 // trunci(zexti(a)) -> a 810 // trunci(sexti(a)) -> a 811 if (matchPattern(getOperand(), m_Op<arith::ExtUIOp>()) || 812 matchPattern(getOperand(), m_Op<arith::ExtSIOp>())) 813 return getOperand().getDefiningOp()->getOperand(0); 814 815 assert(operands.size() == 1 && "unary operation takes one operand"); 816 817 if (!operands[0]) 818 return {}; 819 820 if (auto lhs = operands[0].dyn_cast<IntegerAttr>()) { 821 return IntegerAttr::get( 822 getType(), lhs.getValue().trunc(getType().getIntOrFloatBitWidth())); 823 } 824 825 return {}; 826 } 827 828 bool arith::TruncIOp::areCastCompatible(TypeRange inputs, TypeRange outputs) { 829 return checkWidthChangeCast<std::less, IntegerType>(inputs, outputs); 830 } 831 832 //===----------------------------------------------------------------------===// 833 // TruncFOp 834 //===----------------------------------------------------------------------===// 835 836 /// Perform safe const propagation for truncf, i.e. only propagate if FP value 837 /// can be represented without precision loss or rounding. 838 OpFoldResult arith::TruncFOp::fold(ArrayRef<Attribute> operands) { 839 assert(operands.size() == 1 && "unary operation takes one operand"); 840 841 auto constOperand = operands.front(); 842 if (!constOperand || !constOperand.isa<FloatAttr>()) 843 return {}; 844 845 // Convert to target type via 'double'. 846 double sourceValue = 847 constOperand.dyn_cast<FloatAttr>().getValue().convertToDouble(); 848 auto targetAttr = FloatAttr::get(getType(), sourceValue); 849 850 // Propagate if constant's value does not change after truncation. 851 if (sourceValue == targetAttr.getValue().convertToDouble()) 852 return targetAttr; 853 854 return {}; 855 } 856 857 bool arith::TruncFOp::areCastCompatible(TypeRange inputs, TypeRange outputs) { 858 return checkWidthChangeCast<std::less, FloatType>(inputs, outputs); 859 } 860 861 //===----------------------------------------------------------------------===// 862 // Verifiers for casts between integers and floats. 863 //===----------------------------------------------------------------------===// 864 865 template <typename From, typename To> 866 static bool checkIntFloatCast(TypeRange inputs, TypeRange outputs) { 867 if (!areValidCastInputsAndOutputs(inputs, outputs)) 868 return false; 869 870 auto srcType = getTypeIfLike<From>(inputs.front()); 871 auto dstType = getTypeIfLike<To>(outputs.back()); 872 873 return srcType && dstType; 874 } 875 876 //===----------------------------------------------------------------------===// 877 // UIToFPOp 878 //===----------------------------------------------------------------------===// 879 880 bool arith::UIToFPOp::areCastCompatible(TypeRange inputs, TypeRange outputs) { 881 return checkIntFloatCast<IntegerType, FloatType>(inputs, outputs); 882 } 883 884 //===----------------------------------------------------------------------===// 885 // SIToFPOp 886 //===----------------------------------------------------------------------===// 887 888 bool arith::SIToFPOp::areCastCompatible(TypeRange inputs, TypeRange outputs) { 889 return checkIntFloatCast<IntegerType, FloatType>(inputs, outputs); 890 } 891 892 //===----------------------------------------------------------------------===// 893 // FPToUIOp 894 //===----------------------------------------------------------------------===// 895 896 bool arith::FPToUIOp::areCastCompatible(TypeRange inputs, TypeRange outputs) { 897 return checkIntFloatCast<FloatType, IntegerType>(inputs, outputs); 898 } 899 900 //===----------------------------------------------------------------------===// 901 // FPToSIOp 902 //===----------------------------------------------------------------------===// 903 904 bool arith::FPToSIOp::areCastCompatible(TypeRange inputs, TypeRange outputs) { 905 return checkIntFloatCast<FloatType, IntegerType>(inputs, outputs); 906 } 907 908 //===----------------------------------------------------------------------===// 909 // IndexCastOp 910 //===----------------------------------------------------------------------===// 911 912 bool arith::IndexCastOp::areCastCompatible(TypeRange inputs, 913 TypeRange outputs) { 914 if (!areValidCastInputsAndOutputs(inputs, outputs)) 915 return false; 916 917 auto srcType = getTypeIfLikeOrMemRef<IntegerType, IndexType>(inputs.front()); 918 auto dstType = getTypeIfLikeOrMemRef<IntegerType, IndexType>(outputs.front()); 919 if (!srcType || !dstType) 920 return false; 921 922 return (srcType.isIndex() && dstType.isSignlessInteger()) || 923 (srcType.isSignlessInteger() && dstType.isIndex()); 924 } 925 926 OpFoldResult arith::IndexCastOp::fold(ArrayRef<Attribute> operands) { 927 // index_cast(constant) -> constant 928 // A little hack because we go through int. Otherwise, the size of the 929 // constant might need to change. 930 if (auto value = operands[0].dyn_cast_or_null<IntegerAttr>()) 931 return IntegerAttr::get(getType(), value.getInt()); 932 933 return {}; 934 } 935 936 void arith::IndexCastOp::getCanonicalizationPatterns( 937 OwningRewritePatternList &patterns, MLIRContext *context) { 938 patterns.insert<IndexCastOfIndexCast, IndexCastOfExtSI>(context); 939 } 940 941 //===----------------------------------------------------------------------===// 942 // BitcastOp 943 //===----------------------------------------------------------------------===// 944 945 bool arith::BitcastOp::areCastCompatible(TypeRange inputs, TypeRange outputs) { 946 if (!areValidCastInputsAndOutputs(inputs, outputs)) 947 return false; 948 949 auto srcType = 950 getTypeIfLikeOrMemRef<IntegerType, IndexType, FloatType>(inputs.front()); 951 auto dstType = 952 getTypeIfLikeOrMemRef<IntegerType, IndexType, FloatType>(outputs.front()); 953 if (!srcType || !dstType) 954 return false; 955 956 return srcType.getIntOrFloatBitWidth() == dstType.getIntOrFloatBitWidth(); 957 } 958 959 OpFoldResult arith::BitcastOp::fold(ArrayRef<Attribute> operands) { 960 assert(operands.size() == 1 && "bitcast op expects 1 operand"); 961 962 auto resType = getType(); 963 auto operand = operands[0]; 964 if (!operand) 965 return {}; 966 967 /// Bitcast dense elements. 968 if (auto denseAttr = operand.dyn_cast_or_null<DenseElementsAttr>()) 969 return denseAttr.bitcast(resType.cast<ShapedType>().getElementType()); 970 /// Other shaped types unhandled. 971 if (resType.isa<ShapedType>()) 972 return {}; 973 974 /// Bitcast integer or float to integer or float. 975 APInt bits = operand.isa<FloatAttr>() 976 ? operand.cast<FloatAttr>().getValue().bitcastToAPInt() 977 : operand.cast<IntegerAttr>().getValue(); 978 979 if (auto resFloatType = resType.dyn_cast<FloatType>()) 980 return FloatAttr::get(resType, 981 APFloat(resFloatType.getFloatSemantics(), bits)); 982 return IntegerAttr::get(resType, bits); 983 } 984 985 void arith::BitcastOp::getCanonicalizationPatterns( 986 OwningRewritePatternList &patterns, MLIRContext *context) { 987 patterns.insert<BitcastOfBitcast>(context); 988 } 989 990 //===----------------------------------------------------------------------===// 991 // Helpers for compare ops 992 //===----------------------------------------------------------------------===// 993 994 /// Return the type of the same shape (scalar, vector or tensor) containing i1. 995 static Type getI1SameShape(Type type) { 996 auto i1Type = IntegerType::get(type.getContext(), 1); 997 if (auto tensorType = type.dyn_cast<RankedTensorType>()) 998 return RankedTensorType::get(tensorType.getShape(), i1Type); 999 if (type.isa<UnrankedTensorType>()) 1000 return UnrankedTensorType::get(i1Type); 1001 if (auto vectorType = type.dyn_cast<VectorType>()) 1002 return VectorType::get(vectorType.getShape(), i1Type); 1003 return i1Type; 1004 } 1005 1006 //===----------------------------------------------------------------------===// 1007 // CmpIOp 1008 //===----------------------------------------------------------------------===// 1009 1010 /// Compute `lhs` `pred` `rhs`, where `pred` is one of the known integer 1011 /// comparison predicates. 1012 bool mlir::arith::applyCmpPredicate(arith::CmpIPredicate predicate, 1013 const APInt &lhs, const APInt &rhs) { 1014 switch (predicate) { 1015 case arith::CmpIPredicate::eq: 1016 return lhs.eq(rhs); 1017 case arith::CmpIPredicate::ne: 1018 return lhs.ne(rhs); 1019 case arith::CmpIPredicate::slt: 1020 return lhs.slt(rhs); 1021 case arith::CmpIPredicate::sle: 1022 return lhs.sle(rhs); 1023 case arith::CmpIPredicate::sgt: 1024 return lhs.sgt(rhs); 1025 case arith::CmpIPredicate::sge: 1026 return lhs.sge(rhs); 1027 case arith::CmpIPredicate::ult: 1028 return lhs.ult(rhs); 1029 case arith::CmpIPredicate::ule: 1030 return lhs.ule(rhs); 1031 case arith::CmpIPredicate::ugt: 1032 return lhs.ugt(rhs); 1033 case arith::CmpIPredicate::uge: 1034 return lhs.uge(rhs); 1035 } 1036 llvm_unreachable("unknown cmpi predicate kind"); 1037 } 1038 1039 /// Returns true if the predicate is true for two equal operands. 1040 static bool applyCmpPredicateToEqualOperands(arith::CmpIPredicate predicate) { 1041 switch (predicate) { 1042 case arith::CmpIPredicate::eq: 1043 case arith::CmpIPredicate::sle: 1044 case arith::CmpIPredicate::sge: 1045 case arith::CmpIPredicate::ule: 1046 case arith::CmpIPredicate::uge: 1047 return true; 1048 case arith::CmpIPredicate::ne: 1049 case arith::CmpIPredicate::slt: 1050 case arith::CmpIPredicate::sgt: 1051 case arith::CmpIPredicate::ult: 1052 case arith::CmpIPredicate::ugt: 1053 return false; 1054 } 1055 llvm_unreachable("unknown cmpi predicate kind"); 1056 } 1057 1058 OpFoldResult arith::CmpIOp::fold(ArrayRef<Attribute> operands) { 1059 assert(operands.size() == 2 && "cmpi takes two operands"); 1060 1061 // cmpi(pred, x, x) 1062 if (getLhs() == getRhs()) { 1063 auto val = applyCmpPredicateToEqualOperands(getPredicate()); 1064 return BoolAttr::get(getContext(), val); 1065 } 1066 1067 auto lhs = operands.front().dyn_cast_or_null<IntegerAttr>(); 1068 auto rhs = operands.back().dyn_cast_or_null<IntegerAttr>(); 1069 if (!lhs || !rhs) 1070 return {}; 1071 1072 auto val = applyCmpPredicate(getPredicate(), lhs.getValue(), rhs.getValue()); 1073 return BoolAttr::get(getContext(), val); 1074 } 1075 1076 //===----------------------------------------------------------------------===// 1077 // CmpFOp 1078 //===----------------------------------------------------------------------===// 1079 1080 /// Compute `lhs` `pred` `rhs`, where `pred` is one of the known floating point 1081 /// comparison predicates. 1082 bool mlir::arith::applyCmpPredicate(arith::CmpFPredicate predicate, 1083 const APFloat &lhs, const APFloat &rhs) { 1084 auto cmpResult = lhs.compare(rhs); 1085 switch (predicate) { 1086 case arith::CmpFPredicate::AlwaysFalse: 1087 return false; 1088 case arith::CmpFPredicate::OEQ: 1089 return cmpResult == APFloat::cmpEqual; 1090 case arith::CmpFPredicate::OGT: 1091 return cmpResult == APFloat::cmpGreaterThan; 1092 case arith::CmpFPredicate::OGE: 1093 return cmpResult == APFloat::cmpGreaterThan || 1094 cmpResult == APFloat::cmpEqual; 1095 case arith::CmpFPredicate::OLT: 1096 return cmpResult == APFloat::cmpLessThan; 1097 case arith::CmpFPredicate::OLE: 1098 return cmpResult == APFloat::cmpLessThan || cmpResult == APFloat::cmpEqual; 1099 case arith::CmpFPredicate::ONE: 1100 return cmpResult != APFloat::cmpUnordered && cmpResult != APFloat::cmpEqual; 1101 case arith::CmpFPredicate::ORD: 1102 return cmpResult != APFloat::cmpUnordered; 1103 case arith::CmpFPredicate::UEQ: 1104 return cmpResult == APFloat::cmpUnordered || cmpResult == APFloat::cmpEqual; 1105 case arith::CmpFPredicate::UGT: 1106 return cmpResult == APFloat::cmpUnordered || 1107 cmpResult == APFloat::cmpGreaterThan; 1108 case arith::CmpFPredicate::UGE: 1109 return cmpResult == APFloat::cmpUnordered || 1110 cmpResult == APFloat::cmpGreaterThan || 1111 cmpResult == APFloat::cmpEqual; 1112 case arith::CmpFPredicate::ULT: 1113 return cmpResult == APFloat::cmpUnordered || 1114 cmpResult == APFloat::cmpLessThan; 1115 case arith::CmpFPredicate::ULE: 1116 return cmpResult == APFloat::cmpUnordered || 1117 cmpResult == APFloat::cmpLessThan || cmpResult == APFloat::cmpEqual; 1118 case arith::CmpFPredicate::UNE: 1119 return cmpResult != APFloat::cmpEqual; 1120 case arith::CmpFPredicate::UNO: 1121 return cmpResult == APFloat::cmpUnordered; 1122 case arith::CmpFPredicate::AlwaysTrue: 1123 return true; 1124 } 1125 llvm_unreachable("unknown cmpf predicate kind"); 1126 } 1127 1128 OpFoldResult arith::CmpFOp::fold(ArrayRef<Attribute> operands) { 1129 assert(operands.size() == 2 && "cmpf takes two operands"); 1130 1131 auto lhs = operands.front().dyn_cast_or_null<FloatAttr>(); 1132 auto rhs = operands.back().dyn_cast_or_null<FloatAttr>(); 1133 1134 if (!lhs || !rhs) 1135 return {}; 1136 1137 auto val = applyCmpPredicate(getPredicate(), lhs.getValue(), rhs.getValue()); 1138 return BoolAttr::get(getContext(), val); 1139 } 1140 1141 //===----------------------------------------------------------------------===// 1142 // TableGen'd op method definitions 1143 //===----------------------------------------------------------------------===// 1144 1145 #define GET_OP_CLASSES 1146 #include "mlir/Dialect/Arithmetic/IR/ArithmeticOps.cpp.inc" 1147 1148 //===----------------------------------------------------------------------===// 1149 // TableGen'd enum attribute definitions 1150 //===----------------------------------------------------------------------===// 1151 1152 #include "mlir/Dialect/Arithmetic/IR/ArithmeticOpsEnums.cpp.inc" 1153