1 //===- ArithmeticOps.cpp - MLIR Arithmetic dialect ops implementation -----===// 2 // 3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. 4 // See https://llvm.org/LICENSE.txt for license information. 5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception 6 // 7 //===----------------------------------------------------------------------===// 8 9 #include "mlir/Dialect/Arithmetic/IR/Arithmetic.h" 10 #include "mlir/Dialect/CommonFolders.h" 11 #include "mlir/IR/Builders.h" 12 #include "mlir/IR/Matchers.h" 13 #include "mlir/IR/OpImplementation.h" 14 #include "mlir/IR/PatternMatch.h" 15 #include "mlir/IR/TypeUtilities.h" 16 17 using namespace mlir; 18 using namespace mlir::arith; 19 20 //===----------------------------------------------------------------------===// 21 // Pattern helpers 22 //===----------------------------------------------------------------------===// 23 24 static IntegerAttr addIntegerAttrs(PatternRewriter &builder, Value res, 25 Attribute lhs, Attribute rhs) { 26 return builder.getIntegerAttr(res.getType(), 27 lhs.cast<IntegerAttr>().getInt() + 28 rhs.cast<IntegerAttr>().getInt()); 29 } 30 31 static IntegerAttr subIntegerAttrs(PatternRewriter &builder, Value res, 32 Attribute lhs, Attribute rhs) { 33 return builder.getIntegerAttr(res.getType(), 34 lhs.cast<IntegerAttr>().getInt() - 35 rhs.cast<IntegerAttr>().getInt()); 36 } 37 38 /// Invert an integer comparison predicate. 39 static arith::CmpIPredicate invertPredicate(arith::CmpIPredicate pred) { 40 switch (pred) { 41 case arith::CmpIPredicate::eq: 42 return arith::CmpIPredicate::ne; 43 case arith::CmpIPredicate::ne: 44 return arith::CmpIPredicate::eq; 45 case arith::CmpIPredicate::slt: 46 return arith::CmpIPredicate::sge; 47 case arith::CmpIPredicate::sle: 48 return arith::CmpIPredicate::sgt; 49 case arith::CmpIPredicate::sgt: 50 return arith::CmpIPredicate::sle; 51 case arith::CmpIPredicate::sge: 52 return arith::CmpIPredicate::slt; 53 case arith::CmpIPredicate::ult: 54 return arith::CmpIPredicate::uge; 55 case arith::CmpIPredicate::ule: 56 return arith::CmpIPredicate::ugt; 57 case arith::CmpIPredicate::ugt: 58 return arith::CmpIPredicate::ule; 59 case arith::CmpIPredicate::uge: 60 return arith::CmpIPredicate::ult; 61 } 62 llvm_unreachable("unknown cmpi predicate kind"); 63 } 64 65 static arith::CmpIPredicateAttr invertPredicate(arith::CmpIPredicateAttr pred) { 66 return arith::CmpIPredicateAttr::get(pred.getContext(), 67 invertPredicate(pred.getValue())); 68 } 69 70 //===----------------------------------------------------------------------===// 71 // TableGen'd canonicalization patterns 72 //===----------------------------------------------------------------------===// 73 74 namespace { 75 #include "ArithmeticCanonicalization.inc" 76 } // end anonymous namespace 77 78 //===----------------------------------------------------------------------===// 79 // ConstantOp 80 //===----------------------------------------------------------------------===// 81 82 void arith::ConstantOp::getAsmResultNames( 83 function_ref<void(Value, StringRef)> setNameFn) { 84 auto type = getType(); 85 if (auto intCst = getValue().dyn_cast<IntegerAttr>()) { 86 auto intType = type.dyn_cast<IntegerType>(); 87 88 // Sugar i1 constants with 'true' and 'false'. 89 if (intType && intType.getWidth() == 1) 90 return setNameFn(getResult(), (intCst.getInt() ? "true" : "false")); 91 92 // Otherwise, build a compex name with the value and type. 93 SmallString<32> specialNameBuffer; 94 llvm::raw_svector_ostream specialName(specialNameBuffer); 95 specialName << 'c' << intCst.getInt(); 96 if (intType) 97 specialName << '_' << type; 98 setNameFn(getResult(), specialName.str()); 99 } else { 100 setNameFn(getResult(), "cst"); 101 } 102 } 103 104 /// TODO: disallow arith.constant to return anything other than signless integer 105 /// or float like. 106 static LogicalResult verify(arith::ConstantOp op) { 107 auto type = op.getType(); 108 // The value's type must match the return type. 109 if (op.getValue().getType() != type) { 110 return op.emitOpError() << "value type " << op.getValue().getType() 111 << " must match return type: " << type; 112 } 113 // Integer values must be signless. 114 if (type.isa<IntegerType>() && !type.cast<IntegerType>().isSignless()) 115 return op.emitOpError("integer return type must be signless"); 116 // Any float or elements attribute are acceptable. 117 if (!op.getValue().isa<IntegerAttr, FloatAttr, ElementsAttr>()) { 118 return op.emitOpError( 119 "value must be an integer, float, or elements attribute"); 120 } 121 return success(); 122 } 123 124 bool arith::ConstantOp::isBuildableWith(Attribute value, Type type) { 125 // The value's type must be the same as the provided type. 126 if (value.getType() != type) 127 return false; 128 // Integer values must be signless. 129 if (type.isa<IntegerType>() && !type.cast<IntegerType>().isSignless()) 130 return false; 131 // Integer, float, and element attributes are buildable. 132 return value.isa<IntegerAttr, FloatAttr, ElementsAttr>(); 133 } 134 135 OpFoldResult arith::ConstantOp::fold(ArrayRef<Attribute> operands) { 136 return getValue(); 137 } 138 139 void arith::ConstantIntOp::build(OpBuilder &builder, OperationState &result, 140 int64_t value, unsigned width) { 141 auto type = builder.getIntegerType(width); 142 arith::ConstantOp::build(builder, result, type, 143 builder.getIntegerAttr(type, value)); 144 } 145 146 void arith::ConstantIntOp::build(OpBuilder &builder, OperationState &result, 147 int64_t value, Type type) { 148 assert(type.isSignlessInteger() && 149 "ConstantIntOp can only have signless integer type values"); 150 arith::ConstantOp::build(builder, result, type, 151 builder.getIntegerAttr(type, value)); 152 } 153 154 bool arith::ConstantIntOp::classof(Operation *op) { 155 if (auto constOp = dyn_cast_or_null<arith::ConstantOp>(op)) 156 return constOp.getType().isSignlessInteger(); 157 return false; 158 } 159 160 void arith::ConstantFloatOp::build(OpBuilder &builder, OperationState &result, 161 const APFloat &value, FloatType type) { 162 arith::ConstantOp::build(builder, result, type, 163 builder.getFloatAttr(type, value)); 164 } 165 166 bool arith::ConstantFloatOp::classof(Operation *op) { 167 if (auto constOp = dyn_cast_or_null<arith::ConstantOp>(op)) 168 return constOp.getType().isa<FloatType>(); 169 return false; 170 } 171 172 void arith::ConstantIndexOp::build(OpBuilder &builder, OperationState &result, 173 int64_t value) { 174 arith::ConstantOp::build(builder, result, builder.getIndexType(), 175 builder.getIndexAttr(value)); 176 } 177 178 bool arith::ConstantIndexOp::classof(Operation *op) { 179 if (auto constOp = dyn_cast_or_null<arith::ConstantOp>(op)) 180 return constOp.getType().isIndex(); 181 return false; 182 } 183 184 //===----------------------------------------------------------------------===// 185 // AddIOp 186 //===----------------------------------------------------------------------===// 187 188 OpFoldResult arith::AddIOp::fold(ArrayRef<Attribute> operands) { 189 // addi(x, 0) -> x 190 if (matchPattern(getRhs(), m_Zero())) 191 return getLhs(); 192 193 return constFoldBinaryOp<IntegerAttr>(operands, 194 [](APInt a, APInt b) { return a + b; }); 195 } 196 197 void arith::AddIOp::getCanonicalizationPatterns( 198 OwningRewritePatternList &patterns, MLIRContext *context) { 199 patterns.insert<AddIAddConstant, AddISubConstantRHS, AddISubConstantLHS>( 200 context); 201 } 202 203 //===----------------------------------------------------------------------===// 204 // SubIOp 205 //===----------------------------------------------------------------------===// 206 207 OpFoldResult arith::SubIOp::fold(ArrayRef<Attribute> operands) { 208 // subi(x,x) -> 0 209 if (getOperand(0) == getOperand(1)) 210 return Builder(getContext()).getZeroAttr(getType()); 211 // subi(x,0) -> x 212 if (matchPattern(getRhs(), m_Zero())) 213 return getLhs(); 214 215 return constFoldBinaryOp<IntegerAttr>(operands, 216 [](APInt a, APInt b) { return a - b; }); 217 } 218 219 void arith::SubIOp::getCanonicalizationPatterns( 220 OwningRewritePatternList &patterns, MLIRContext *context) { 221 patterns.insert<SubIRHSAddConstant, SubILHSAddConstant, SubIRHSSubConstantRHS, 222 SubIRHSSubConstantLHS, SubILHSSubConstantRHS, 223 SubILHSSubConstantLHS>(context); 224 } 225 226 //===----------------------------------------------------------------------===// 227 // MulIOp 228 //===----------------------------------------------------------------------===// 229 230 OpFoldResult arith::MulIOp::fold(ArrayRef<Attribute> operands) { 231 // muli(x, 0) -> 0 232 if (matchPattern(getRhs(), m_Zero())) 233 return getRhs(); 234 // muli(x, 1) -> x 235 if (matchPattern(getRhs(), m_One())) 236 return getOperand(0); 237 // TODO: Handle the overflow case. 238 239 // default folder 240 return constFoldBinaryOp<IntegerAttr>(operands, 241 [](APInt a, APInt b) { return a * b; }); 242 } 243 244 //===----------------------------------------------------------------------===// 245 // DivUIOp 246 //===----------------------------------------------------------------------===// 247 248 OpFoldResult arith::DivUIOp::fold(ArrayRef<Attribute> operands) { 249 // Don't fold if it would require a division by zero. 250 bool div0 = false; 251 auto result = constFoldBinaryOp<IntegerAttr>(operands, [&](APInt a, APInt b) { 252 if (div0 || !b) { 253 div0 = true; 254 return a; 255 } 256 return a.udiv(b); 257 }); 258 259 // Fold out division by one. Assumes all tensors of all ones are splats. 260 if (auto rhs = operands[1].dyn_cast_or_null<IntegerAttr>()) { 261 if (rhs.getValue() == 1) 262 return getLhs(); 263 } else if (auto rhs = operands[1].dyn_cast_or_null<SplatElementsAttr>()) { 264 if (rhs.getSplatValue<IntegerAttr>().getValue() == 1) 265 return getLhs(); 266 } 267 268 return div0 ? Attribute() : result; 269 } 270 271 //===----------------------------------------------------------------------===// 272 // DivSIOp 273 //===----------------------------------------------------------------------===// 274 275 OpFoldResult arith::DivSIOp::fold(ArrayRef<Attribute> operands) { 276 // Don't fold if it would overflow or if it requires a division by zero. 277 bool overflowOrDiv0 = false; 278 auto result = constFoldBinaryOp<IntegerAttr>(operands, [&](APInt a, APInt b) { 279 if (overflowOrDiv0 || !b) { 280 overflowOrDiv0 = true; 281 return a; 282 } 283 return a.sdiv_ov(b, overflowOrDiv0); 284 }); 285 286 // Fold out division by one. Assumes all tensors of all ones are splats. 287 if (auto rhs = operands[1].dyn_cast_or_null<IntegerAttr>()) { 288 if (rhs.getValue() == 1) 289 return getLhs(); 290 } else if (auto rhs = operands[1].dyn_cast_or_null<SplatElementsAttr>()) { 291 if (rhs.getSplatValue<IntegerAttr>().getValue() == 1) 292 return getLhs(); 293 } 294 295 return overflowOrDiv0 ? Attribute() : result; 296 } 297 298 //===----------------------------------------------------------------------===// 299 // Ceil and floor division folding helpers 300 //===----------------------------------------------------------------------===// 301 302 static APInt signedCeilNonnegInputs(APInt a, APInt b, bool &overflow) { 303 // Returns (a-1)/b + 1 304 APInt one(a.getBitWidth(), 1, true); // Signed value 1. 305 APInt val = a.ssub_ov(one, overflow).sdiv_ov(b, overflow); 306 return val.sadd_ov(one, overflow); 307 } 308 309 //===----------------------------------------------------------------------===// 310 // CeilDivUIOp 311 //===----------------------------------------------------------------------===// 312 313 OpFoldResult arith::CeilDivUIOp::fold(ArrayRef<Attribute> operands) { 314 bool overflowOrDiv0 = false; 315 auto result = constFoldBinaryOp<IntegerAttr>(operands, [&](APInt a, APInt b) { 316 if (overflowOrDiv0 || !b) { 317 overflowOrDiv0 = true; 318 return a; 319 } 320 APInt quotient = a.udiv(b); 321 if (!a.urem(b)) 322 return quotient; 323 APInt one(a.getBitWidth(), 1, true); 324 return quotient.uadd_ov(one, overflowOrDiv0); 325 }); 326 // Fold out ceil division by one. Assumes all tensors of all ones are 327 // splats. 328 if (auto rhs = operands[1].dyn_cast_or_null<IntegerAttr>()) { 329 if (rhs.getValue() == 1) 330 return getLhs(); 331 } else if (auto rhs = operands[1].dyn_cast_or_null<SplatElementsAttr>()) { 332 if (rhs.getSplatValue<IntegerAttr>().getValue() == 1) 333 return getLhs(); 334 } 335 336 return overflowOrDiv0 ? Attribute() : result; 337 } 338 339 //===----------------------------------------------------------------------===// 340 // CeilDivSIOp 341 //===----------------------------------------------------------------------===// 342 343 OpFoldResult arith::CeilDivSIOp::fold(ArrayRef<Attribute> operands) { 344 // Don't fold if it would overflow or if it requires a division by zero. 345 bool overflowOrDiv0 = false; 346 auto result = constFoldBinaryOp<IntegerAttr>(operands, [&](APInt a, APInt b) { 347 if (overflowOrDiv0 || !b) { 348 overflowOrDiv0 = true; 349 return a; 350 } 351 unsigned bits = a.getBitWidth(); 352 APInt zero = APInt::getZero(bits); 353 if (a.sgt(zero) && b.sgt(zero)) { 354 // Both positive, return ceil(a, b). 355 return signedCeilNonnegInputs(a, b, overflowOrDiv0); 356 } 357 if (a.slt(zero) && b.slt(zero)) { 358 // Both negative, return ceil(-a, -b). 359 APInt posA = zero.ssub_ov(a, overflowOrDiv0); 360 APInt posB = zero.ssub_ov(b, overflowOrDiv0); 361 return signedCeilNonnegInputs(posA, posB, overflowOrDiv0); 362 } 363 if (a.slt(zero) && b.sgt(zero)) { 364 // A is negative, b is positive, return - ( -a / b). 365 APInt posA = zero.ssub_ov(a, overflowOrDiv0); 366 APInt div = posA.sdiv_ov(b, overflowOrDiv0); 367 return zero.ssub_ov(div, overflowOrDiv0); 368 } 369 // A is positive (or zero), b is negative, return - (a / -b). 370 APInt posB = zero.ssub_ov(b, overflowOrDiv0); 371 APInt div = a.sdiv_ov(posB, overflowOrDiv0); 372 return zero.ssub_ov(div, overflowOrDiv0); 373 }); 374 375 // Fold out ceil division by one. Assumes all tensors of all ones are 376 // splats. 377 if (auto rhs = operands[1].dyn_cast_or_null<IntegerAttr>()) { 378 if (rhs.getValue() == 1) 379 return getLhs(); 380 } else if (auto rhs = operands[1].dyn_cast_or_null<SplatElementsAttr>()) { 381 if (rhs.getSplatValue<IntegerAttr>().getValue() == 1) 382 return getLhs(); 383 } 384 385 return overflowOrDiv0 ? Attribute() : result; 386 } 387 388 //===----------------------------------------------------------------------===// 389 // FloorDivSIOp 390 //===----------------------------------------------------------------------===// 391 392 OpFoldResult arith::FloorDivSIOp::fold(ArrayRef<Attribute> operands) { 393 // Don't fold if it would overflow or if it requires a division by zero. 394 bool overflowOrDiv0 = false; 395 auto result = constFoldBinaryOp<IntegerAttr>(operands, [&](APInt a, APInt b) { 396 if (overflowOrDiv0 || !b) { 397 overflowOrDiv0 = true; 398 return a; 399 } 400 unsigned bits = a.getBitWidth(); 401 APInt zero = APInt::getZero(bits); 402 if (a.sge(zero) && b.sgt(zero)) { 403 // Both positive (or a is zero), return a / b. 404 return a.sdiv_ov(b, overflowOrDiv0); 405 } 406 if (a.sle(zero) && b.slt(zero)) { 407 // Both negative (or a is zero), return -a / -b. 408 APInt posA = zero.ssub_ov(a, overflowOrDiv0); 409 APInt posB = zero.ssub_ov(b, overflowOrDiv0); 410 return posA.sdiv_ov(posB, overflowOrDiv0); 411 } 412 if (a.slt(zero) && b.sgt(zero)) { 413 // A is negative, b is positive, return - ceil(-a, b). 414 APInt posA = zero.ssub_ov(a, overflowOrDiv0); 415 APInt ceil = signedCeilNonnegInputs(posA, b, overflowOrDiv0); 416 return zero.ssub_ov(ceil, overflowOrDiv0); 417 } 418 // A is positive, b is negative, return - ceil(a, -b). 419 APInt posB = zero.ssub_ov(b, overflowOrDiv0); 420 APInt ceil = signedCeilNonnegInputs(a, posB, overflowOrDiv0); 421 return zero.ssub_ov(ceil, overflowOrDiv0); 422 }); 423 424 // Fold out floor division by one. Assumes all tensors of all ones are 425 // splats. 426 if (auto rhs = operands[1].dyn_cast_or_null<IntegerAttr>()) { 427 if (rhs.getValue() == 1) 428 return getLhs(); 429 } else if (auto rhs = operands[1].dyn_cast_or_null<SplatElementsAttr>()) { 430 if (rhs.getSplatValue<IntegerAttr>().getValue() == 1) 431 return getLhs(); 432 } 433 434 return overflowOrDiv0 ? Attribute() : result; 435 } 436 437 //===----------------------------------------------------------------------===// 438 // RemUIOp 439 //===----------------------------------------------------------------------===// 440 441 OpFoldResult arith::RemUIOp::fold(ArrayRef<Attribute> operands) { 442 auto rhs = operands.back().dyn_cast_or_null<IntegerAttr>(); 443 if (!rhs) 444 return {}; 445 auto rhsValue = rhs.getValue(); 446 447 // x % 1 = 0 448 if (rhsValue.isOneValue()) 449 return IntegerAttr::get(rhs.getType(), APInt(rhsValue.getBitWidth(), 0)); 450 451 // Don't fold if it requires division by zero. 452 if (rhsValue.isNullValue()) 453 return {}; 454 455 auto lhs = operands.front().dyn_cast_or_null<IntegerAttr>(); 456 if (!lhs) 457 return {}; 458 return IntegerAttr::get(lhs.getType(), lhs.getValue().urem(rhsValue)); 459 } 460 461 //===----------------------------------------------------------------------===// 462 // RemSIOp 463 //===----------------------------------------------------------------------===// 464 465 OpFoldResult arith::RemSIOp::fold(ArrayRef<Attribute> operands) { 466 auto rhs = operands.back().dyn_cast_or_null<IntegerAttr>(); 467 if (!rhs) 468 return {}; 469 auto rhsValue = rhs.getValue(); 470 471 // x % 1 = 0 472 if (rhsValue.isOneValue()) 473 return IntegerAttr::get(rhs.getType(), APInt(rhsValue.getBitWidth(), 0)); 474 475 // Don't fold if it requires division by zero. 476 if (rhsValue.isNullValue()) 477 return {}; 478 479 auto lhs = operands.front().dyn_cast_or_null<IntegerAttr>(); 480 if (!lhs) 481 return {}; 482 return IntegerAttr::get(lhs.getType(), lhs.getValue().srem(rhsValue)); 483 } 484 485 //===----------------------------------------------------------------------===// 486 // AndIOp 487 //===----------------------------------------------------------------------===// 488 489 OpFoldResult arith::AndIOp::fold(ArrayRef<Attribute> operands) { 490 /// and(x, 0) -> 0 491 if (matchPattern(getRhs(), m_Zero())) 492 return getRhs(); 493 /// and(x, allOnes) -> x 494 APInt intValue; 495 if (matchPattern(getRhs(), m_ConstantInt(&intValue)) && intValue.isAllOnes()) 496 return getLhs(); 497 /// and(x, x) -> x 498 if (getLhs() == getRhs()) 499 return getRhs(); 500 501 return constFoldBinaryOp<IntegerAttr>(operands, 502 [](APInt a, APInt b) { return a & b; }); 503 } 504 505 //===----------------------------------------------------------------------===// 506 // OrIOp 507 //===----------------------------------------------------------------------===// 508 509 OpFoldResult arith::OrIOp::fold(ArrayRef<Attribute> operands) { 510 /// or(x, 0) -> x 511 if (matchPattern(getRhs(), m_Zero())) 512 return getLhs(); 513 /// or(x, x) -> x 514 if (getLhs() == getRhs()) 515 return getRhs(); 516 /// or(x, <all ones>) -> <all ones> 517 if (auto rhsAttr = operands[1].dyn_cast_or_null<IntegerAttr>()) 518 if (rhsAttr.getValue().isAllOnes()) 519 return rhsAttr; 520 521 return constFoldBinaryOp<IntegerAttr>(operands, 522 [](APInt a, APInt b) { return a | b; }); 523 } 524 525 //===----------------------------------------------------------------------===// 526 // XOrIOp 527 //===----------------------------------------------------------------------===// 528 529 OpFoldResult arith::XOrIOp::fold(ArrayRef<Attribute> operands) { 530 /// xor(x, 0) -> x 531 if (matchPattern(getRhs(), m_Zero())) 532 return getLhs(); 533 /// xor(x, x) -> 0 534 if (getLhs() == getRhs()) 535 return Builder(getContext()).getZeroAttr(getType()); 536 537 return constFoldBinaryOp<IntegerAttr>(operands, 538 [](APInt a, APInt b) { return a ^ b; }); 539 } 540 541 void arith::XOrIOp::getCanonicalizationPatterns( 542 OwningRewritePatternList &patterns, MLIRContext *context) { 543 patterns.insert<XOrINotCmpI>(context); 544 } 545 546 //===----------------------------------------------------------------------===// 547 // AddFOp 548 //===----------------------------------------------------------------------===// 549 550 OpFoldResult arith::AddFOp::fold(ArrayRef<Attribute> operands) { 551 return constFoldBinaryOp<FloatAttr>( 552 operands, [](APFloat a, APFloat b) { return a + b; }); 553 } 554 555 //===----------------------------------------------------------------------===// 556 // SubFOp 557 //===----------------------------------------------------------------------===// 558 559 OpFoldResult arith::SubFOp::fold(ArrayRef<Attribute> operands) { 560 return constFoldBinaryOp<FloatAttr>( 561 operands, [](APFloat a, APFloat b) { return a - b; }); 562 } 563 564 //===----------------------------------------------------------------------===// 565 // MulFOp 566 //===----------------------------------------------------------------------===// 567 568 OpFoldResult arith::MulFOp::fold(ArrayRef<Attribute> operands) { 569 return constFoldBinaryOp<FloatAttr>( 570 operands, [](APFloat a, APFloat b) { return a * b; }); 571 } 572 573 //===----------------------------------------------------------------------===// 574 // DivFOp 575 //===----------------------------------------------------------------------===// 576 577 OpFoldResult arith::DivFOp::fold(ArrayRef<Attribute> operands) { 578 return constFoldBinaryOp<FloatAttr>( 579 operands, [](APFloat a, APFloat b) { return a / b; }); 580 } 581 582 //===----------------------------------------------------------------------===// 583 // Utility functions for verifying cast ops 584 //===----------------------------------------------------------------------===// 585 586 template <typename... Types> 587 using type_list = std::tuple<Types...> *; 588 589 /// Returns a non-null type only if the provided type is one of the allowed 590 /// types or one of the allowed shaped types of the allowed types. Returns the 591 /// element type if a valid shaped type is provided. 592 template <typename... ShapedTypes, typename... ElementTypes> 593 static Type getUnderlyingType(Type type, type_list<ShapedTypes...>, 594 type_list<ElementTypes...>) { 595 if (type.isa<ShapedType>() && !type.isa<ShapedTypes...>()) 596 return {}; 597 598 auto underlyingType = getElementTypeOrSelf(type); 599 if (!underlyingType.isa<ElementTypes...>()) 600 return {}; 601 602 return underlyingType; 603 } 604 605 /// Get allowed underlying types for vectors and tensors. 606 template <typename... ElementTypes> 607 static Type getTypeIfLike(Type type) { 608 return getUnderlyingType(type, type_list<VectorType, TensorType>(), 609 type_list<ElementTypes...>()); 610 } 611 612 /// Get allowed underlying types for vectors, tensors, and memrefs. 613 template <typename... ElementTypes> 614 static Type getTypeIfLikeOrMemRef(Type type) { 615 return getUnderlyingType(type, 616 type_list<VectorType, TensorType, MemRefType>(), 617 type_list<ElementTypes...>()); 618 } 619 620 static bool areValidCastInputsAndOutputs(TypeRange inputs, TypeRange outputs) { 621 return inputs.size() == 1 && outputs.size() == 1 && 622 succeeded(verifyCompatibleShapes(inputs.front(), outputs.front())); 623 } 624 625 //===----------------------------------------------------------------------===// 626 // Verifiers for integer and floating point extension/truncation ops 627 //===----------------------------------------------------------------------===// 628 629 // Extend ops can only extend to a wider type. 630 template <typename ValType, typename Op> 631 static LogicalResult verifyExtOp(Op op) { 632 Type srcType = getElementTypeOrSelf(op.getIn().getType()); 633 Type dstType = getElementTypeOrSelf(op.getType()); 634 635 if (srcType.cast<ValType>().getWidth() >= dstType.cast<ValType>().getWidth()) 636 return op.emitError("result type ") 637 << dstType << " must be wider than operand type " << srcType; 638 639 return success(); 640 } 641 642 // Truncate ops can only truncate to a shorter type. 643 template <typename ValType, typename Op> 644 static LogicalResult verifyTruncateOp(Op op) { 645 Type srcType = getElementTypeOrSelf(op.getIn().getType()); 646 Type dstType = getElementTypeOrSelf(op.getType()); 647 648 if (srcType.cast<ValType>().getWidth() <= dstType.cast<ValType>().getWidth()) 649 return op.emitError("result type ") 650 << dstType << " must be shorter than operand type " << srcType; 651 652 return success(); 653 } 654 655 /// Validate a cast that changes the width of a type. 656 template <template <typename> class WidthComparator, typename... ElementTypes> 657 static bool checkWidthChangeCast(TypeRange inputs, TypeRange outputs) { 658 if (!areValidCastInputsAndOutputs(inputs, outputs)) 659 return false; 660 661 auto srcType = getTypeIfLike<ElementTypes...>(inputs.front()); 662 auto dstType = getTypeIfLike<ElementTypes...>(outputs.front()); 663 if (!srcType || !dstType) 664 return false; 665 666 return WidthComparator<unsigned>()(dstType.getIntOrFloatBitWidth(), 667 srcType.getIntOrFloatBitWidth()); 668 } 669 670 //===----------------------------------------------------------------------===// 671 // ExtUIOp 672 //===----------------------------------------------------------------------===// 673 674 OpFoldResult arith::ExtUIOp::fold(ArrayRef<Attribute> operands) { 675 if (auto lhs = operands[0].dyn_cast_or_null<IntegerAttr>()) 676 return IntegerAttr::get( 677 getType(), lhs.getValue().zext(getType().getIntOrFloatBitWidth())); 678 679 return {}; 680 } 681 682 bool arith::ExtUIOp::areCastCompatible(TypeRange inputs, TypeRange outputs) { 683 return checkWidthChangeCast<std::greater, IntegerType>(inputs, outputs); 684 } 685 686 //===----------------------------------------------------------------------===// 687 // ExtSIOp 688 //===----------------------------------------------------------------------===// 689 690 OpFoldResult arith::ExtSIOp::fold(ArrayRef<Attribute> operands) { 691 if (auto lhs = operands[0].dyn_cast_or_null<IntegerAttr>()) 692 return IntegerAttr::get( 693 getType(), lhs.getValue().sext(getType().getIntOrFloatBitWidth())); 694 695 return {}; 696 } 697 698 bool arith::ExtSIOp::areCastCompatible(TypeRange inputs, TypeRange outputs) { 699 return checkWidthChangeCast<std::greater, IntegerType>(inputs, outputs); 700 } 701 702 //===----------------------------------------------------------------------===// 703 // ExtFOp 704 //===----------------------------------------------------------------------===// 705 706 bool arith::ExtFOp::areCastCompatible(TypeRange inputs, TypeRange outputs) { 707 return checkWidthChangeCast<std::greater, FloatType>(inputs, outputs); 708 } 709 710 //===----------------------------------------------------------------------===// 711 // TruncIOp 712 //===----------------------------------------------------------------------===// 713 714 OpFoldResult arith::TruncIOp::fold(ArrayRef<Attribute> operands) { 715 // trunci(zexti(a)) -> a 716 // trunci(sexti(a)) -> a 717 if (matchPattern(getOperand(), m_Op<arith::ExtUIOp>()) || 718 matchPattern(getOperand(), m_Op<arith::ExtSIOp>())) 719 return getOperand().getDefiningOp()->getOperand(0); 720 721 assert(operands.size() == 1 && "unary operation takes one operand"); 722 723 if (!operands[0]) 724 return {}; 725 726 if (auto lhs = operands[0].dyn_cast<IntegerAttr>()) { 727 return IntegerAttr::get( 728 getType(), lhs.getValue().trunc(getType().getIntOrFloatBitWidth())); 729 } 730 731 return {}; 732 } 733 734 bool arith::TruncIOp::areCastCompatible(TypeRange inputs, TypeRange outputs) { 735 return checkWidthChangeCast<std::less, IntegerType>(inputs, outputs); 736 } 737 738 //===----------------------------------------------------------------------===// 739 // TruncFOp 740 //===----------------------------------------------------------------------===// 741 742 /// Perform safe const propagation for truncf, i.e. only propagate if FP value 743 /// can be represented without precision loss or rounding. 744 OpFoldResult arith::TruncFOp::fold(ArrayRef<Attribute> operands) { 745 assert(operands.size() == 1 && "unary operation takes one operand"); 746 747 auto constOperand = operands.front(); 748 if (!constOperand || !constOperand.isa<FloatAttr>()) 749 return {}; 750 751 // Convert to target type via 'double'. 752 double sourceValue = 753 constOperand.dyn_cast<FloatAttr>().getValue().convertToDouble(); 754 auto targetAttr = FloatAttr::get(getType(), sourceValue); 755 756 // Propagate if constant's value does not change after truncation. 757 if (sourceValue == targetAttr.getValue().convertToDouble()) 758 return targetAttr; 759 760 return {}; 761 } 762 763 bool arith::TruncFOp::areCastCompatible(TypeRange inputs, TypeRange outputs) { 764 return checkWidthChangeCast<std::less, FloatType>(inputs, outputs); 765 } 766 767 //===----------------------------------------------------------------------===// 768 // Verifiers for casts between integers and floats. 769 //===----------------------------------------------------------------------===// 770 771 template <typename From, typename To> 772 static bool checkIntFloatCast(TypeRange inputs, TypeRange outputs) { 773 if (!areValidCastInputsAndOutputs(inputs, outputs)) 774 return false; 775 776 auto srcType = getTypeIfLike<From>(inputs.front()); 777 auto dstType = getTypeIfLike<To>(outputs.back()); 778 779 return srcType && dstType; 780 } 781 782 //===----------------------------------------------------------------------===// 783 // UIToFPOp 784 //===----------------------------------------------------------------------===// 785 786 bool arith::UIToFPOp::areCastCompatible(TypeRange inputs, TypeRange outputs) { 787 return checkIntFloatCast<IntegerType, FloatType>(inputs, outputs); 788 } 789 790 //===----------------------------------------------------------------------===// 791 // SIToFPOp 792 //===----------------------------------------------------------------------===// 793 794 bool arith::SIToFPOp::areCastCompatible(TypeRange inputs, TypeRange outputs) { 795 return checkIntFloatCast<IntegerType, FloatType>(inputs, outputs); 796 } 797 798 //===----------------------------------------------------------------------===// 799 // FPToUIOp 800 //===----------------------------------------------------------------------===// 801 802 bool arith::FPToUIOp::areCastCompatible(TypeRange inputs, TypeRange outputs) { 803 return checkIntFloatCast<FloatType, IntegerType>(inputs, outputs); 804 } 805 806 //===----------------------------------------------------------------------===// 807 // FPToSIOp 808 //===----------------------------------------------------------------------===// 809 810 bool arith::FPToSIOp::areCastCompatible(TypeRange inputs, TypeRange outputs) { 811 return checkIntFloatCast<FloatType, IntegerType>(inputs, outputs); 812 } 813 814 //===----------------------------------------------------------------------===// 815 // IndexCastOp 816 //===----------------------------------------------------------------------===// 817 818 bool arith::IndexCastOp::areCastCompatible(TypeRange inputs, 819 TypeRange outputs) { 820 if (!areValidCastInputsAndOutputs(inputs, outputs)) 821 return false; 822 823 auto srcType = getTypeIfLikeOrMemRef<IntegerType, IndexType>(inputs.front()); 824 auto dstType = getTypeIfLikeOrMemRef<IntegerType, IndexType>(outputs.front()); 825 if (!srcType || !dstType) 826 return false; 827 828 return (srcType.isIndex() && dstType.isSignlessInteger()) || 829 (srcType.isSignlessInteger() && dstType.isIndex()); 830 } 831 832 OpFoldResult arith::IndexCastOp::fold(ArrayRef<Attribute> operands) { 833 // index_cast(constant) -> constant 834 // A little hack because we go through int. Otherwise, the size of the 835 // constant might need to change. 836 if (auto value = operands[0].dyn_cast_or_null<IntegerAttr>()) 837 return IntegerAttr::get(getType(), value.getInt()); 838 839 return {}; 840 } 841 842 void arith::IndexCastOp::getCanonicalizationPatterns( 843 OwningRewritePatternList &patterns, MLIRContext *context) { 844 patterns.insert<IndexCastOfIndexCast, IndexCastOfExtSI>(context); 845 } 846 847 //===----------------------------------------------------------------------===// 848 // BitcastOp 849 //===----------------------------------------------------------------------===// 850 851 bool arith::BitcastOp::areCastCompatible(TypeRange inputs, TypeRange outputs) { 852 if (!areValidCastInputsAndOutputs(inputs, outputs)) 853 return false; 854 855 auto srcType = 856 getTypeIfLikeOrMemRef<IntegerType, IndexType, FloatType>(inputs.front()); 857 auto dstType = 858 getTypeIfLikeOrMemRef<IntegerType, IndexType, FloatType>(outputs.front()); 859 if (!srcType || !dstType) 860 return false; 861 862 return srcType.getIntOrFloatBitWidth() == dstType.getIntOrFloatBitWidth(); 863 } 864 865 OpFoldResult arith::BitcastOp::fold(ArrayRef<Attribute> operands) { 866 assert(operands.size() == 1 && "bitcast op expects 1 operand"); 867 868 auto resType = getType(); 869 auto operand = operands[0]; 870 if (!operand) 871 return {}; 872 873 /// Bitcast dense elements. 874 if (auto denseAttr = operand.dyn_cast_or_null<DenseElementsAttr>()) 875 return denseAttr.bitcast(resType.cast<ShapedType>().getElementType()); 876 /// Other shaped types unhandled. 877 if (resType.isa<ShapedType>()) 878 return {}; 879 880 /// Bitcast integer or float to integer or float. 881 APInt bits = operand.isa<FloatAttr>() 882 ? operand.cast<FloatAttr>().getValue().bitcastToAPInt() 883 : operand.cast<IntegerAttr>().getValue(); 884 885 if (auto resFloatType = resType.dyn_cast<FloatType>()) 886 return FloatAttr::get(resType, 887 APFloat(resFloatType.getFloatSemantics(), bits)); 888 return IntegerAttr::get(resType, bits); 889 } 890 891 void arith::BitcastOp::getCanonicalizationPatterns( 892 OwningRewritePatternList &patterns, MLIRContext *context) { 893 patterns.insert<BitcastOfBitcast>(context); 894 } 895 896 //===----------------------------------------------------------------------===// 897 // Helpers for compare ops 898 //===----------------------------------------------------------------------===// 899 900 /// Return the type of the same shape (scalar, vector or tensor) containing i1. 901 static Type getI1SameShape(Type type) { 902 auto i1Type = IntegerType::get(type.getContext(), 1); 903 if (auto tensorType = type.dyn_cast<RankedTensorType>()) 904 return RankedTensorType::get(tensorType.getShape(), i1Type); 905 if (type.isa<UnrankedTensorType>()) 906 return UnrankedTensorType::get(i1Type); 907 if (auto vectorType = type.dyn_cast<VectorType>()) 908 return VectorType::get(vectorType.getShape(), i1Type); 909 return i1Type; 910 } 911 912 //===----------------------------------------------------------------------===// 913 // CmpIOp 914 //===----------------------------------------------------------------------===// 915 916 /// Compute `lhs` `pred` `rhs`, where `pred` is one of the known integer 917 /// comparison predicates. 918 bool mlir::arith::applyCmpPredicate(arith::CmpIPredicate predicate, 919 const APInt &lhs, const APInt &rhs) { 920 switch (predicate) { 921 case arith::CmpIPredicate::eq: 922 return lhs.eq(rhs); 923 case arith::CmpIPredicate::ne: 924 return lhs.ne(rhs); 925 case arith::CmpIPredicate::slt: 926 return lhs.slt(rhs); 927 case arith::CmpIPredicate::sle: 928 return lhs.sle(rhs); 929 case arith::CmpIPredicate::sgt: 930 return lhs.sgt(rhs); 931 case arith::CmpIPredicate::sge: 932 return lhs.sge(rhs); 933 case arith::CmpIPredicate::ult: 934 return lhs.ult(rhs); 935 case arith::CmpIPredicate::ule: 936 return lhs.ule(rhs); 937 case arith::CmpIPredicate::ugt: 938 return lhs.ugt(rhs); 939 case arith::CmpIPredicate::uge: 940 return lhs.uge(rhs); 941 } 942 llvm_unreachable("unknown cmpi predicate kind"); 943 } 944 945 /// Returns true if the predicate is true for two equal operands. 946 static bool applyCmpPredicateToEqualOperands(arith::CmpIPredicate predicate) { 947 switch (predicate) { 948 case arith::CmpIPredicate::eq: 949 case arith::CmpIPredicate::sle: 950 case arith::CmpIPredicate::sge: 951 case arith::CmpIPredicate::ule: 952 case arith::CmpIPredicate::uge: 953 return true; 954 case arith::CmpIPredicate::ne: 955 case arith::CmpIPredicate::slt: 956 case arith::CmpIPredicate::sgt: 957 case arith::CmpIPredicate::ult: 958 case arith::CmpIPredicate::ugt: 959 return false; 960 } 961 llvm_unreachable("unknown cmpi predicate kind"); 962 } 963 964 OpFoldResult arith::CmpIOp::fold(ArrayRef<Attribute> operands) { 965 assert(operands.size() == 2 && "cmpi takes two operands"); 966 967 // cmpi(pred, x, x) 968 if (getLhs() == getRhs()) { 969 auto val = applyCmpPredicateToEqualOperands(getPredicate()); 970 return BoolAttr::get(getContext(), val); 971 } 972 973 auto lhs = operands.front().dyn_cast_or_null<IntegerAttr>(); 974 auto rhs = operands.back().dyn_cast_or_null<IntegerAttr>(); 975 if (!lhs || !rhs) 976 return {}; 977 978 auto val = applyCmpPredicate(getPredicate(), lhs.getValue(), rhs.getValue()); 979 return BoolAttr::get(getContext(), val); 980 } 981 982 //===----------------------------------------------------------------------===// 983 // CmpFOp 984 //===----------------------------------------------------------------------===// 985 986 /// Compute `lhs` `pred` `rhs`, where `pred` is one of the known floating point 987 /// comparison predicates. 988 bool mlir::arith::applyCmpPredicate(arith::CmpFPredicate predicate, 989 const APFloat &lhs, const APFloat &rhs) { 990 auto cmpResult = lhs.compare(rhs); 991 switch (predicate) { 992 case arith::CmpFPredicate::AlwaysFalse: 993 return false; 994 case arith::CmpFPredicate::OEQ: 995 return cmpResult == APFloat::cmpEqual; 996 case arith::CmpFPredicate::OGT: 997 return cmpResult == APFloat::cmpGreaterThan; 998 case arith::CmpFPredicate::OGE: 999 return cmpResult == APFloat::cmpGreaterThan || 1000 cmpResult == APFloat::cmpEqual; 1001 case arith::CmpFPredicate::OLT: 1002 return cmpResult == APFloat::cmpLessThan; 1003 case arith::CmpFPredicate::OLE: 1004 return cmpResult == APFloat::cmpLessThan || cmpResult == APFloat::cmpEqual; 1005 case arith::CmpFPredicate::ONE: 1006 return cmpResult != APFloat::cmpUnordered && cmpResult != APFloat::cmpEqual; 1007 case arith::CmpFPredicate::ORD: 1008 return cmpResult != APFloat::cmpUnordered; 1009 case arith::CmpFPredicate::UEQ: 1010 return cmpResult == APFloat::cmpUnordered || cmpResult == APFloat::cmpEqual; 1011 case arith::CmpFPredicate::UGT: 1012 return cmpResult == APFloat::cmpUnordered || 1013 cmpResult == APFloat::cmpGreaterThan; 1014 case arith::CmpFPredicate::UGE: 1015 return cmpResult == APFloat::cmpUnordered || 1016 cmpResult == APFloat::cmpGreaterThan || 1017 cmpResult == APFloat::cmpEqual; 1018 case arith::CmpFPredicate::ULT: 1019 return cmpResult == APFloat::cmpUnordered || 1020 cmpResult == APFloat::cmpLessThan; 1021 case arith::CmpFPredicate::ULE: 1022 return cmpResult == APFloat::cmpUnordered || 1023 cmpResult == APFloat::cmpLessThan || cmpResult == APFloat::cmpEqual; 1024 case arith::CmpFPredicate::UNE: 1025 return cmpResult != APFloat::cmpEqual; 1026 case arith::CmpFPredicate::UNO: 1027 return cmpResult == APFloat::cmpUnordered; 1028 case arith::CmpFPredicate::AlwaysTrue: 1029 return true; 1030 } 1031 llvm_unreachable("unknown cmpf predicate kind"); 1032 } 1033 1034 OpFoldResult arith::CmpFOp::fold(ArrayRef<Attribute> operands) { 1035 assert(operands.size() == 2 && "cmpf takes two operands"); 1036 1037 auto lhs = operands.front().dyn_cast_or_null<FloatAttr>(); 1038 auto rhs = operands.back().dyn_cast_or_null<FloatAttr>(); 1039 1040 if (!lhs || !rhs) 1041 return {}; 1042 1043 auto val = applyCmpPredicate(getPredicate(), lhs.getValue(), rhs.getValue()); 1044 return BoolAttr::get(getContext(), val); 1045 } 1046 1047 //===----------------------------------------------------------------------===// 1048 // TableGen'd op method definitions 1049 //===----------------------------------------------------------------------===// 1050 1051 #define GET_OP_CLASSES 1052 #include "mlir/Dialect/Arithmetic/IR/ArithmeticOps.cpp.inc" 1053 1054 //===----------------------------------------------------------------------===// 1055 // TableGen'd enum attribute definitions 1056 //===----------------------------------------------------------------------===// 1057 1058 #include "mlir/Dialect/Arithmetic/IR/ArithmeticOpsEnums.cpp.inc" 1059