1 //===- ArithmeticOps.cpp - MLIR Arithmetic dialect ops implementation -----===// 2 // 3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. 4 // See https://llvm.org/LICENSE.txt for license information. 5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception 6 // 7 //===----------------------------------------------------------------------===// 8 9 #include "mlir/Dialect/Arithmetic/IR/Arithmetic.h" 10 #include "mlir/Dialect/CommonFolders.h" 11 #include "mlir/IR/Builders.h" 12 #include "mlir/IR/Matchers.h" 13 #include "mlir/IR/OpImplementation.h" 14 #include "mlir/IR/PatternMatch.h" 15 #include "mlir/IR/TypeUtilities.h" 16 17 using namespace mlir; 18 using namespace mlir::arith; 19 20 //===----------------------------------------------------------------------===// 21 // Pattern helpers 22 //===----------------------------------------------------------------------===// 23 24 static IntegerAttr addIntegerAttrs(PatternRewriter &builder, Value res, 25 Attribute lhs, Attribute rhs) { 26 return builder.getIntegerAttr(res.getType(), 27 lhs.cast<IntegerAttr>().getInt() + 28 rhs.cast<IntegerAttr>().getInt()); 29 } 30 31 static IntegerAttr subIntegerAttrs(PatternRewriter &builder, Value res, 32 Attribute lhs, Attribute rhs) { 33 return builder.getIntegerAttr(res.getType(), 34 lhs.cast<IntegerAttr>().getInt() - 35 rhs.cast<IntegerAttr>().getInt()); 36 } 37 38 /// Invert an integer comparison predicate. 39 static arith::CmpIPredicate invertPredicate(arith::CmpIPredicate pred) { 40 switch (pred) { 41 case arith::CmpIPredicate::eq: 42 return arith::CmpIPredicate::ne; 43 case arith::CmpIPredicate::ne: 44 return arith::CmpIPredicate::eq; 45 case arith::CmpIPredicate::slt: 46 return arith::CmpIPredicate::sge; 47 case arith::CmpIPredicate::sle: 48 return arith::CmpIPredicate::sgt; 49 case arith::CmpIPredicate::sgt: 50 return arith::CmpIPredicate::sle; 51 case arith::CmpIPredicate::sge: 52 return arith::CmpIPredicate::slt; 53 case arith::CmpIPredicate::ult: 54 return arith::CmpIPredicate::uge; 55 case arith::CmpIPredicate::ule: 56 return arith::CmpIPredicate::ugt; 57 case arith::CmpIPredicate::ugt: 58 return arith::CmpIPredicate::ule; 59 case arith::CmpIPredicate::uge: 60 return arith::CmpIPredicate::ult; 61 } 62 llvm_unreachable("unknown cmpi predicate kind"); 63 } 64 65 static arith::CmpIPredicateAttr invertPredicate(arith::CmpIPredicateAttr pred) { 66 return arith::CmpIPredicateAttr::get(pred.getContext(), 67 invertPredicate(pred.getValue())); 68 } 69 70 //===----------------------------------------------------------------------===// 71 // TableGen'd canonicalization patterns 72 //===----------------------------------------------------------------------===// 73 74 namespace { 75 #include "ArithmeticCanonicalization.inc" 76 } // end anonymous namespace 77 78 //===----------------------------------------------------------------------===// 79 // AddIOp 80 //===----------------------------------------------------------------------===// 81 82 OpFoldResult arith::AddIOp::fold(ArrayRef<Attribute> operands) { 83 // addi(x, 0) -> x 84 if (matchPattern(rhs(), m_Zero())) 85 return lhs(); 86 87 return constFoldBinaryOp<IntegerAttr>(operands, 88 [](APInt a, APInt b) { return a + b; }); 89 } 90 91 void arith::AddIOp::getCanonicalizationPatterns( 92 OwningRewritePatternList &patterns, MLIRContext *context) { 93 patterns.insert<AddIAddConstant, AddISubConstantRHS, AddISubConstantLHS>( 94 context); 95 } 96 97 //===----------------------------------------------------------------------===// 98 // SubIOp 99 //===----------------------------------------------------------------------===// 100 101 OpFoldResult arith::SubIOp::fold(ArrayRef<Attribute> operands) { 102 // subi(x,x) -> 0 103 if (getOperand(0) == getOperand(1)) 104 return Builder(getContext()).getZeroAttr(getType()); 105 // subi(x,0) -> x 106 if (matchPattern(rhs(), m_Zero())) 107 return lhs(); 108 109 return constFoldBinaryOp<IntegerAttr>(operands, 110 [](APInt a, APInt b) { return a - b; }); 111 } 112 113 void arith::SubIOp::getCanonicalizationPatterns( 114 OwningRewritePatternList &patterns, MLIRContext *context) { 115 patterns.insert<SubIRHSAddConstant, SubILHSAddConstant, SubIRHSSubConstantRHS, 116 SubIRHSSubConstantLHS, SubILHSSubConstantRHS, 117 SubILHSSubConstantLHS>(context); 118 } 119 120 //===----------------------------------------------------------------------===// 121 // MulIOp 122 //===----------------------------------------------------------------------===// 123 124 OpFoldResult arith::MulIOp::fold(ArrayRef<Attribute> operands) { 125 // muli(x, 0) -> 0 126 if (matchPattern(rhs(), m_Zero())) 127 return rhs(); 128 // muli(x, 1) -> x 129 if (matchPattern(rhs(), m_One())) 130 return getOperand(0); 131 // TODO: Handle the overflow case. 132 133 // default folder 134 return constFoldBinaryOp<IntegerAttr>(operands, 135 [](APInt a, APInt b) { return a * b; }); 136 } 137 138 //===----------------------------------------------------------------------===// 139 // DivUIOp 140 //===----------------------------------------------------------------------===// 141 142 OpFoldResult arith::DivUIOp::fold(ArrayRef<Attribute> operands) { 143 // Don't fold if it would require a division by zero. 144 bool div0 = false; 145 auto result = constFoldBinaryOp<IntegerAttr>(operands, [&](APInt a, APInt b) { 146 if (div0 || !b) { 147 div0 = true; 148 return a; 149 } 150 return a.udiv(b); 151 }); 152 153 // Fold out division by one. Assumes all tensors of all ones are splats. 154 if (auto rhs = operands[1].dyn_cast_or_null<IntegerAttr>()) { 155 if (rhs.getValue() == 1) 156 return lhs(); 157 } else if (auto rhs = operands[1].dyn_cast_or_null<SplatElementsAttr>()) { 158 if (rhs.getSplatValue<IntegerAttr>().getValue() == 1) 159 return lhs(); 160 } 161 162 return div0 ? Attribute() : result; 163 } 164 165 //===----------------------------------------------------------------------===// 166 // DivSIOp 167 //===----------------------------------------------------------------------===// 168 169 OpFoldResult arith::DivSIOp::fold(ArrayRef<Attribute> operands) { 170 // Don't fold if it would overflow or if it requires a division by zero. 171 bool overflowOrDiv0 = false; 172 auto result = constFoldBinaryOp<IntegerAttr>(operands, [&](APInt a, APInt b) { 173 if (overflowOrDiv0 || !b) { 174 overflowOrDiv0 = true; 175 return a; 176 } 177 return a.sdiv_ov(b, overflowOrDiv0); 178 }); 179 180 // Fold out division by one. Assumes all tensors of all ones are splats. 181 if (auto rhs = operands[1].dyn_cast_or_null<IntegerAttr>()) { 182 if (rhs.getValue() == 1) 183 return lhs(); 184 } else if (auto rhs = operands[1].dyn_cast_or_null<SplatElementsAttr>()) { 185 if (rhs.getSplatValue<IntegerAttr>().getValue() == 1) 186 return lhs(); 187 } 188 189 return overflowOrDiv0 ? Attribute() : result; 190 } 191 192 //===----------------------------------------------------------------------===// 193 // Ceil and floor division folding helpers 194 //===----------------------------------------------------------------------===// 195 196 static APInt signedCeilNonnegInputs(APInt a, APInt b, bool &overflow) { 197 // Returns (a-1)/b + 1 198 APInt one(a.getBitWidth(), 1, true); // Signed value 1. 199 APInt val = a.ssub_ov(one, overflow).sdiv_ov(b, overflow); 200 return val.sadd_ov(one, overflow); 201 } 202 203 //===----------------------------------------------------------------------===// 204 // CeilDivSIOp 205 //===----------------------------------------------------------------------===// 206 207 OpFoldResult arith::CeilDivSIOp::fold(ArrayRef<Attribute> operands) { 208 // Don't fold if it would overflow or if it requires a division by zero. 209 bool overflowOrDiv0 = false; 210 auto result = constFoldBinaryOp<IntegerAttr>(operands, [&](APInt a, APInt b) { 211 if (overflowOrDiv0 || !b) { 212 overflowOrDiv0 = true; 213 return a; 214 } 215 unsigned bits = a.getBitWidth(); 216 APInt zero = APInt::getZero(bits); 217 if (a.sgt(zero) && b.sgt(zero)) { 218 // Both positive, return ceil(a, b). 219 return signedCeilNonnegInputs(a, b, overflowOrDiv0); 220 } 221 if (a.slt(zero) && b.slt(zero)) { 222 // Both negative, return ceil(-a, -b). 223 APInt posA = zero.ssub_ov(a, overflowOrDiv0); 224 APInt posB = zero.ssub_ov(b, overflowOrDiv0); 225 return signedCeilNonnegInputs(posA, posB, overflowOrDiv0); 226 } 227 if (a.slt(zero) && b.sgt(zero)) { 228 // A is negative, b is positive, return - ( -a / b). 229 APInt posA = zero.ssub_ov(a, overflowOrDiv0); 230 APInt div = posA.sdiv_ov(b, overflowOrDiv0); 231 return zero.ssub_ov(div, overflowOrDiv0); 232 } 233 // A is positive (or zero), b is negative, return - (a / -b). 234 APInt posB = zero.ssub_ov(b, overflowOrDiv0); 235 APInt div = a.sdiv_ov(posB, overflowOrDiv0); 236 return zero.ssub_ov(div, overflowOrDiv0); 237 }); 238 239 // Fold out floor division by one. Assumes all tensors of all ones are 240 // splats. 241 if (auto rhs = operands[1].dyn_cast_or_null<IntegerAttr>()) { 242 if (rhs.getValue() == 1) 243 return lhs(); 244 } else if (auto rhs = operands[1].dyn_cast_or_null<SplatElementsAttr>()) { 245 if (rhs.getSplatValue<IntegerAttr>().getValue() == 1) 246 return lhs(); 247 } 248 249 return overflowOrDiv0 ? Attribute() : result; 250 } 251 252 //===----------------------------------------------------------------------===// 253 // FloorDivSIOp 254 //===----------------------------------------------------------------------===// 255 256 OpFoldResult arith::FloorDivSIOp::fold(ArrayRef<Attribute> operands) { 257 // Don't fold if it would overflow or if it requires a division by zero. 258 bool overflowOrDiv0 = false; 259 auto result = constFoldBinaryOp<IntegerAttr>(operands, [&](APInt a, APInt b) { 260 if (overflowOrDiv0 || !b) { 261 overflowOrDiv0 = true; 262 return a; 263 } 264 unsigned bits = a.getBitWidth(); 265 APInt zero = APInt::getZero(bits); 266 if (a.sge(zero) && b.sgt(zero)) { 267 // Both positive (or a is zero), return a / b. 268 return a.sdiv_ov(b, overflowOrDiv0); 269 } 270 if (a.sle(zero) && b.slt(zero)) { 271 // Both negative (or a is zero), return -a / -b. 272 APInt posA = zero.ssub_ov(a, overflowOrDiv0); 273 APInt posB = zero.ssub_ov(b, overflowOrDiv0); 274 return posA.sdiv_ov(posB, overflowOrDiv0); 275 } 276 if (a.slt(zero) && b.sgt(zero)) { 277 // A is negative, b is positive, return - ceil(-a, b). 278 APInt posA = zero.ssub_ov(a, overflowOrDiv0); 279 APInt ceil = signedCeilNonnegInputs(posA, b, overflowOrDiv0); 280 return zero.ssub_ov(ceil, overflowOrDiv0); 281 } 282 // A is positive, b is negative, return - ceil(a, -b). 283 APInt posB = zero.ssub_ov(b, overflowOrDiv0); 284 APInt ceil = signedCeilNonnegInputs(a, posB, overflowOrDiv0); 285 return zero.ssub_ov(ceil, overflowOrDiv0); 286 }); 287 288 // Fold out floor division by one. Assumes all tensors of all ones are 289 // splats. 290 if (auto rhs = operands[1].dyn_cast_or_null<IntegerAttr>()) { 291 if (rhs.getValue() == 1) 292 return lhs(); 293 } else if (auto rhs = operands[1].dyn_cast_or_null<SplatElementsAttr>()) { 294 if (rhs.getSplatValue<IntegerAttr>().getValue() == 1) 295 return lhs(); 296 } 297 298 return overflowOrDiv0 ? Attribute() : result; 299 } 300 301 //===----------------------------------------------------------------------===// 302 // RemUIOp 303 //===----------------------------------------------------------------------===// 304 305 OpFoldResult arith::RemUIOp::fold(ArrayRef<Attribute> operands) { 306 auto rhs = operands.back().dyn_cast_or_null<IntegerAttr>(); 307 if (!rhs) 308 return {}; 309 auto rhsValue = rhs.getValue(); 310 311 // x % 1 = 0 312 if (rhsValue.isOneValue()) 313 return IntegerAttr::get(rhs.getType(), APInt(rhsValue.getBitWidth(), 0)); 314 315 // Don't fold if it requires division by zero. 316 if (rhsValue.isNullValue()) 317 return {}; 318 319 auto lhs = operands.front().dyn_cast_or_null<IntegerAttr>(); 320 if (!lhs) 321 return {}; 322 return IntegerAttr::get(lhs.getType(), lhs.getValue().urem(rhsValue)); 323 } 324 325 //===----------------------------------------------------------------------===// 326 // RemSIOp 327 //===----------------------------------------------------------------------===// 328 329 OpFoldResult arith::RemSIOp::fold(ArrayRef<Attribute> operands) { 330 auto rhs = operands.back().dyn_cast_or_null<IntegerAttr>(); 331 if (!rhs) 332 return {}; 333 auto rhsValue = rhs.getValue(); 334 335 // x % 1 = 0 336 if (rhsValue.isOneValue()) 337 return IntegerAttr::get(rhs.getType(), APInt(rhsValue.getBitWidth(), 0)); 338 339 // Don't fold if it requires division by zero. 340 if (rhsValue.isNullValue()) 341 return {}; 342 343 auto lhs = operands.front().dyn_cast_or_null<IntegerAttr>(); 344 if (!lhs) 345 return {}; 346 return IntegerAttr::get(lhs.getType(), lhs.getValue().srem(rhsValue)); 347 } 348 349 //===----------------------------------------------------------------------===// 350 // AndIOp 351 //===----------------------------------------------------------------------===// 352 353 OpFoldResult arith::AndIOp::fold(ArrayRef<Attribute> operands) { 354 /// and(x, 0) -> 0 355 if (matchPattern(rhs(), m_Zero())) 356 return rhs(); 357 /// and(x, allOnes) -> x 358 APInt intValue; 359 if (matchPattern(rhs(), m_ConstantInt(&intValue)) && intValue.isAllOnes()) 360 return lhs(); 361 /// and(x, x) -> x 362 if (lhs() == rhs()) 363 return rhs(); 364 365 return constFoldBinaryOp<IntegerAttr>(operands, 366 [](APInt a, APInt b) { return a & b; }); 367 } 368 369 //===----------------------------------------------------------------------===// 370 // OrIOp 371 //===----------------------------------------------------------------------===// 372 373 OpFoldResult arith::OrIOp::fold(ArrayRef<Attribute> operands) { 374 /// or(x, 0) -> x 375 if (matchPattern(rhs(), m_Zero())) 376 return lhs(); 377 /// or(x, x) -> x 378 if (lhs() == rhs()) 379 return rhs(); 380 381 return constFoldBinaryOp<IntegerAttr>(operands, 382 [](APInt a, APInt b) { return a | b; }); 383 } 384 385 //===----------------------------------------------------------------------===// 386 // XOrIOp 387 //===----------------------------------------------------------------------===// 388 389 OpFoldResult arith::XOrIOp::fold(ArrayRef<Attribute> operands) { 390 /// xor(x, 0) -> x 391 if (matchPattern(rhs(), m_Zero())) 392 return lhs(); 393 /// xor(x, x) -> 0 394 if (lhs() == rhs()) 395 return Builder(getContext()).getZeroAttr(getType()); 396 397 return constFoldBinaryOp<IntegerAttr>(operands, 398 [](APInt a, APInt b) { return a ^ b; }); 399 } 400 401 void arith::XOrIOp::getCanonicalizationPatterns( 402 OwningRewritePatternList &patterns, MLIRContext *context) { 403 patterns.insert<XOrINotCmpI>(context); 404 } 405 406 //===----------------------------------------------------------------------===// 407 // AddFOp 408 //===----------------------------------------------------------------------===// 409 410 OpFoldResult arith::AddFOp::fold(ArrayRef<Attribute> operands) { 411 return constFoldBinaryOp<FloatAttr>( 412 operands, [](APFloat a, APFloat b) { return a + b; }); 413 } 414 415 //===----------------------------------------------------------------------===// 416 // SubFOp 417 //===----------------------------------------------------------------------===// 418 419 OpFoldResult arith::SubFOp::fold(ArrayRef<Attribute> operands) { 420 return constFoldBinaryOp<FloatAttr>( 421 operands, [](APFloat a, APFloat b) { return a - b; }); 422 } 423 424 //===----------------------------------------------------------------------===// 425 // MulFOp 426 //===----------------------------------------------------------------------===// 427 428 OpFoldResult arith::MulFOp::fold(ArrayRef<Attribute> operands) { 429 return constFoldBinaryOp<FloatAttr>( 430 operands, [](APFloat a, APFloat b) { return a * b; }); 431 } 432 433 //===----------------------------------------------------------------------===// 434 // DivFOp 435 //===----------------------------------------------------------------------===// 436 437 OpFoldResult arith::DivFOp::fold(ArrayRef<Attribute> operands) { 438 return constFoldBinaryOp<FloatAttr>( 439 operands, [](APFloat a, APFloat b) { return a / b; }); 440 } 441 442 //===----------------------------------------------------------------------===// 443 // Verifiers for integer and floating point extension/truncation ops 444 //===----------------------------------------------------------------------===// 445 446 // Extend ops can only extend to a wider type. 447 template <typename ValType, typename Op> 448 static LogicalResult verifyExtOp(Op op) { 449 Type srcType = getElementTypeOrSelf(op.in().getType()); 450 Type dstType = getElementTypeOrSelf(op.getType()); 451 452 if (srcType.cast<ValType>().getWidth() >= dstType.cast<ValType>().getWidth()) 453 return op.emitError("result type ") 454 << dstType << " must be wider than operand type " << srcType; 455 456 return success(); 457 } 458 459 // Truncate ops can only truncate to a shorter type. 460 template <typename ValType, typename Op> 461 static LogicalResult verifyTruncateOp(Op op) { 462 Type srcType = getElementTypeOrSelf(op.in().getType()); 463 Type dstType = getElementTypeOrSelf(op.getType()); 464 465 if (srcType.cast<ValType>().getWidth() <= dstType.cast<ValType>().getWidth()) 466 return op.emitError("result type ") 467 << dstType << " must be shorter than operand type " << srcType; 468 469 return success(); 470 } 471 472 //===----------------------------------------------------------------------===// 473 // ExtUIOp 474 //===----------------------------------------------------------------------===// 475 476 OpFoldResult arith::ExtUIOp::fold(ArrayRef<Attribute> operands) { 477 if (auto lhs = operands[0].dyn_cast_or_null<IntegerAttr>()) 478 return IntegerAttr::get( 479 getType(), lhs.getValue().zext(getType().getIntOrFloatBitWidth())); 480 481 return {}; 482 } 483 484 //===----------------------------------------------------------------------===// 485 // ExtSIOp 486 //===----------------------------------------------------------------------===// 487 488 OpFoldResult arith::ExtSIOp::fold(ArrayRef<Attribute> operands) { 489 if (auto lhs = operands[0].dyn_cast_or_null<IntegerAttr>()) 490 return IntegerAttr::get( 491 getType(), lhs.getValue().sext(getType().getIntOrFloatBitWidth())); 492 493 return {}; 494 } 495 496 // TODO temporary fixes until second patch is in 497 OpFoldResult arith::TruncFOp::fold(ArrayRef<Attribute> operands) { 498 return {}; 499 } 500 501 bool arith::TruncFOp::areCastCompatible(TypeRange inputs, TypeRange outputs) { 502 return true; 503 } 504 505 OpFoldResult arith::TruncIOp::fold(ArrayRef<Attribute> operands) { 506 return {}; 507 } 508 509 bool arith::TruncIOp::areCastCompatible(TypeRange inputs, TypeRange outputs) { 510 return true; 511 } 512 513 bool arith::ExtUIOp::areCastCompatible(TypeRange inputs, TypeRange outputs) { 514 return true; 515 } 516 517 bool arith::ExtSIOp::areCastCompatible(TypeRange inputs, TypeRange outputs) { 518 return true; 519 } 520 521 bool arith::ExtFOp::areCastCompatible(TypeRange inputs, TypeRange outputs) { 522 return true; 523 } 524 525 OpFoldResult arith::ConstantOp::fold(ArrayRef<Attribute> operands) { 526 return {}; 527 } 528 529 bool arith::SIToFPOp::areCastCompatible(TypeRange inputs, TypeRange outputs) { 530 return true; 531 } 532 533 bool arith::UIToFPOp::areCastCompatible(TypeRange inputs, TypeRange outputs) { 534 return true; 535 } 536 537 bool arith::FPToSIOp::areCastCompatible(TypeRange inputs, TypeRange outputs) { 538 return true; 539 } 540 541 bool arith::FPToUIOp::areCastCompatible(TypeRange inputs, TypeRange outputs) { 542 return true; 543 } 544 545 //===----------------------------------------------------------------------===// 546 // IndexCastOp 547 //===----------------------------------------------------------------------===// 548 549 bool arith::IndexCastOp::areCastCompatible(TypeRange inputs, 550 TypeRange outputs) { 551 assert(inputs.size() == 1 && outputs.size() == 1 && 552 "index_cast op expects one result and one result"); 553 554 // Shape equivalence is guaranteed by op traits. 555 auto srcType = getElementTypeOrSelf(inputs.front()); 556 auto dstType = getElementTypeOrSelf(outputs.front()); 557 558 return (srcType.isIndex() && dstType.isSignlessInteger()) || 559 (srcType.isSignlessInteger() && dstType.isIndex()); 560 } 561 562 OpFoldResult arith::IndexCastOp::fold(ArrayRef<Attribute> operands) { 563 // index_cast(constant) -> constant 564 // A little hack because we go through int. Otherwise, the size of the 565 // constant might need to change. 566 if (auto value = operands[0].dyn_cast_or_null<IntegerAttr>()) 567 return IntegerAttr::get(getType(), value.getInt()); 568 569 return {}; 570 } 571 572 void arith::IndexCastOp::getCanonicalizationPatterns( 573 OwningRewritePatternList &patterns, MLIRContext *context) { 574 patterns.insert<IndexCastOfIndexCast, IndexCastOfExtSI>(context); 575 } 576 577 //===----------------------------------------------------------------------===// 578 // BitcastOp 579 //===----------------------------------------------------------------------===// 580 581 bool arith::BitcastOp::areCastCompatible(TypeRange inputs, TypeRange outputs) { 582 assert(inputs.size() == 1 && outputs.size() == 1 && 583 "bitcast op expects one operand and one result"); 584 585 // Shape equivalence is guaranteed by op traits. 586 auto srcType = getElementTypeOrSelf(inputs.front()); 587 auto dstType = getElementTypeOrSelf(outputs.front()); 588 589 // Types are guarnateed to be integers or floats by constraints. 590 return srcType.getIntOrFloatBitWidth() == dstType.getIntOrFloatBitWidth(); 591 } 592 593 OpFoldResult arith::BitcastOp::fold(ArrayRef<Attribute> operands) { 594 assert(operands.size() == 1 && "bitcast op expects 1 operand"); 595 596 auto resType = getType(); 597 auto operand = operands[0]; 598 if (!operand) 599 return {}; 600 601 /// Bitcast dense elements. 602 if (auto denseAttr = operand.dyn_cast_or_null<DenseElementsAttr>()) 603 return denseAttr.bitcast(resType.cast<ShapedType>().getElementType()); 604 /// Other shaped types unhandled. 605 if (resType.isa<ShapedType>()) 606 return {}; 607 608 /// Bitcast integer or float to integer or float. 609 APInt bits = operand.isa<FloatAttr>() 610 ? operand.cast<FloatAttr>().getValue().bitcastToAPInt() 611 : operand.cast<IntegerAttr>().getValue(); 612 613 if (auto resFloatType = resType.dyn_cast<FloatType>()) 614 return FloatAttr::get(resType, 615 APFloat(resFloatType.getFloatSemantics(), bits)); 616 return IntegerAttr::get(resType, bits); 617 } 618 619 void arith::BitcastOp::getCanonicalizationPatterns( 620 OwningRewritePatternList &patterns, MLIRContext *context) { 621 patterns.insert<BitcastOfBitcast>(context); 622 } 623 624 //===----------------------------------------------------------------------===// 625 // Helpers for compare ops 626 //===----------------------------------------------------------------------===// 627 628 /// Return the type of the same shape (scalar, vector or tensor) containing i1. 629 static Type getI1SameShape(Type type) { 630 auto i1Type = IntegerType::get(type.getContext(), 1); 631 if (auto tensorType = type.dyn_cast<RankedTensorType>()) 632 return RankedTensorType::get(tensorType.getShape(), i1Type); 633 if (type.isa<UnrankedTensorType>()) 634 return UnrankedTensorType::get(i1Type); 635 if (auto vectorType = type.dyn_cast<VectorType>()) 636 return VectorType::get(vectorType.getShape(), i1Type); 637 return i1Type; 638 } 639 640 //===----------------------------------------------------------------------===// 641 // CmpIOp 642 //===----------------------------------------------------------------------===// 643 644 /// Compute `lhs` `pred` `rhs`, where `pred` is one of the known integer 645 /// comparison predicates. 646 bool mlir::arith::applyCmpPredicate(arith::CmpIPredicate predicate, 647 const APInt &lhs, const APInt &rhs) { 648 switch (predicate) { 649 case arith::CmpIPredicate::eq: 650 return lhs.eq(rhs); 651 case arith::CmpIPredicate::ne: 652 return lhs.ne(rhs); 653 case arith::CmpIPredicate::slt: 654 return lhs.slt(rhs); 655 case arith::CmpIPredicate::sle: 656 return lhs.sle(rhs); 657 case arith::CmpIPredicate::sgt: 658 return lhs.sgt(rhs); 659 case arith::CmpIPredicate::sge: 660 return lhs.sge(rhs); 661 case arith::CmpIPredicate::ult: 662 return lhs.ult(rhs); 663 case arith::CmpIPredicate::ule: 664 return lhs.ule(rhs); 665 case arith::CmpIPredicate::ugt: 666 return lhs.ugt(rhs); 667 case arith::CmpIPredicate::uge: 668 return lhs.uge(rhs); 669 } 670 llvm_unreachable("unknown cmpi predicate kind"); 671 } 672 673 /// Returns true if the predicate is true for two equal operands. 674 static bool applyCmpPredicateToEqualOperands(arith::CmpIPredicate predicate) { 675 switch (predicate) { 676 case arith::CmpIPredicate::eq: 677 case arith::CmpIPredicate::sle: 678 case arith::CmpIPredicate::sge: 679 case arith::CmpIPredicate::ule: 680 case arith::CmpIPredicate::uge: 681 return true; 682 case arith::CmpIPredicate::ne: 683 case arith::CmpIPredicate::slt: 684 case arith::CmpIPredicate::sgt: 685 case arith::CmpIPredicate::ult: 686 case arith::CmpIPredicate::ugt: 687 return false; 688 } 689 llvm_unreachable("unknown cmpi predicate kind"); 690 } 691 692 OpFoldResult arith::CmpIOp::fold(ArrayRef<Attribute> operands) { 693 assert(operands.size() == 2 && "cmpi takes two operands"); 694 695 // cmpi(pred, x, x) 696 if (lhs() == rhs()) { 697 auto val = applyCmpPredicateToEqualOperands(getPredicate()); 698 return BoolAttr::get(getContext(), val); 699 } 700 701 auto lhs = operands.front().dyn_cast_or_null<IntegerAttr>(); 702 auto rhs = operands.back().dyn_cast_or_null<IntegerAttr>(); 703 if (!lhs || !rhs) 704 return {}; 705 706 auto val = applyCmpPredicate(getPredicate(), lhs.getValue(), rhs.getValue()); 707 return BoolAttr::get(getContext(), val); 708 } 709 710 //===----------------------------------------------------------------------===// 711 // CmpFOp 712 //===----------------------------------------------------------------------===// 713 714 /// Compute `lhs` `pred` `rhs`, where `pred` is one of the known floating point 715 /// comparison predicates. 716 bool mlir::arith::applyCmpPredicate(arith::CmpFPredicate predicate, 717 const APFloat &lhs, const APFloat &rhs) { 718 auto cmpResult = lhs.compare(rhs); 719 switch (predicate) { 720 case arith::CmpFPredicate::AlwaysFalse: 721 return false; 722 case arith::CmpFPredicate::OEQ: 723 return cmpResult == APFloat::cmpEqual; 724 case arith::CmpFPredicate::OGT: 725 return cmpResult == APFloat::cmpGreaterThan; 726 case arith::CmpFPredicate::OGE: 727 return cmpResult == APFloat::cmpGreaterThan || 728 cmpResult == APFloat::cmpEqual; 729 case arith::CmpFPredicate::OLT: 730 return cmpResult == APFloat::cmpLessThan; 731 case arith::CmpFPredicate::OLE: 732 return cmpResult == APFloat::cmpLessThan || cmpResult == APFloat::cmpEqual; 733 case arith::CmpFPredicate::ONE: 734 return cmpResult != APFloat::cmpUnordered && cmpResult != APFloat::cmpEqual; 735 case arith::CmpFPredicate::ORD: 736 return cmpResult != APFloat::cmpUnordered; 737 case arith::CmpFPredicate::UEQ: 738 return cmpResult == APFloat::cmpUnordered || cmpResult == APFloat::cmpEqual; 739 case arith::CmpFPredicate::UGT: 740 return cmpResult == APFloat::cmpUnordered || 741 cmpResult == APFloat::cmpGreaterThan; 742 case arith::CmpFPredicate::UGE: 743 return cmpResult == APFloat::cmpUnordered || 744 cmpResult == APFloat::cmpGreaterThan || 745 cmpResult == APFloat::cmpEqual; 746 case arith::CmpFPredicate::ULT: 747 return cmpResult == APFloat::cmpUnordered || 748 cmpResult == APFloat::cmpLessThan; 749 case arith::CmpFPredicate::ULE: 750 return cmpResult == APFloat::cmpUnordered || 751 cmpResult == APFloat::cmpLessThan || cmpResult == APFloat::cmpEqual; 752 case arith::CmpFPredicate::UNE: 753 return cmpResult != APFloat::cmpEqual; 754 case arith::CmpFPredicate::UNO: 755 return cmpResult == APFloat::cmpUnordered; 756 case arith::CmpFPredicate::AlwaysTrue: 757 return true; 758 } 759 llvm_unreachable("unknown cmpf predicate kind"); 760 } 761 762 OpFoldResult arith::CmpFOp::fold(ArrayRef<Attribute> operands) { 763 assert(operands.size() == 2 && "cmpf takes two operands"); 764 765 auto lhs = operands.front().dyn_cast_or_null<FloatAttr>(); 766 auto rhs = operands.back().dyn_cast_or_null<FloatAttr>(); 767 768 if (!lhs || !rhs) 769 return {}; 770 771 auto val = applyCmpPredicate(getPredicate(), lhs.getValue(), rhs.getValue()); 772 return BoolAttr::get(getContext(), val); 773 } 774 775 //===----------------------------------------------------------------------===// 776 // TableGen'd op method definitions 777 //===----------------------------------------------------------------------===// 778 779 #define GET_OP_CLASSES 780 #include "mlir/Dialect/Arithmetic/IR/ArithmeticOps.cpp.inc" 781 782 //===----------------------------------------------------------------------===// 783 // TableGen'd enum attribute definitions 784 //===----------------------------------------------------------------------===// 785 786 #include "mlir/Dialect/Arithmetic/IR/ArithmeticOpsEnums.cpp.inc" 787