1 //===- ArithmeticOps.cpp - MLIR Arithmetic dialect ops implementation -----===// 2 // 3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. 4 // See https://llvm.org/LICENSE.txt for license information. 5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception 6 // 7 //===----------------------------------------------------------------------===// 8 9 #include "mlir/Dialect/Arithmetic/IR/Arithmetic.h" 10 #include "mlir/Dialect/CommonFolders.h" 11 #include "mlir/IR/Builders.h" 12 #include "mlir/IR/Matchers.h" 13 #include "mlir/IR/OpImplementation.h" 14 #include "mlir/IR/PatternMatch.h" 15 #include "mlir/IR/TypeUtilities.h" 16 17 using namespace mlir; 18 using namespace mlir::arith; 19 20 //===----------------------------------------------------------------------===// 21 // Pattern helpers 22 //===----------------------------------------------------------------------===// 23 24 static IntegerAttr addIntegerAttrs(PatternRewriter &builder, Value res, 25 Attribute lhs, Attribute rhs) { 26 return builder.getIntegerAttr(res.getType(), 27 lhs.cast<IntegerAttr>().getInt() + 28 rhs.cast<IntegerAttr>().getInt()); 29 } 30 31 static IntegerAttr subIntegerAttrs(PatternRewriter &builder, Value res, 32 Attribute lhs, Attribute rhs) { 33 return builder.getIntegerAttr(res.getType(), 34 lhs.cast<IntegerAttr>().getInt() - 35 rhs.cast<IntegerAttr>().getInt()); 36 } 37 38 /// Invert an integer comparison predicate. 39 static arith::CmpIPredicate invertPredicate(arith::CmpIPredicate pred) { 40 switch (pred) { 41 case arith::CmpIPredicate::eq: 42 return arith::CmpIPredicate::ne; 43 case arith::CmpIPredicate::ne: 44 return arith::CmpIPredicate::eq; 45 case arith::CmpIPredicate::slt: 46 return arith::CmpIPredicate::sge; 47 case arith::CmpIPredicate::sle: 48 return arith::CmpIPredicate::sgt; 49 case arith::CmpIPredicate::sgt: 50 return arith::CmpIPredicate::sle; 51 case arith::CmpIPredicate::sge: 52 return arith::CmpIPredicate::slt; 53 case arith::CmpIPredicate::ult: 54 return arith::CmpIPredicate::uge; 55 case arith::CmpIPredicate::ule: 56 return arith::CmpIPredicate::ugt; 57 case arith::CmpIPredicate::ugt: 58 return arith::CmpIPredicate::ule; 59 case arith::CmpIPredicate::uge: 60 return arith::CmpIPredicate::ult; 61 } 62 llvm_unreachable("unknown cmpi predicate kind"); 63 } 64 65 static arith::CmpIPredicateAttr invertPredicate(arith::CmpIPredicateAttr pred) { 66 return arith::CmpIPredicateAttr::get(pred.getContext(), 67 invertPredicate(pred.getValue())); 68 } 69 70 //===----------------------------------------------------------------------===// 71 // TableGen'd canonicalization patterns 72 //===----------------------------------------------------------------------===// 73 74 namespace { 75 #include "ArithmeticCanonicalization.inc" 76 } // end anonymous namespace 77 78 //===----------------------------------------------------------------------===// 79 // ConstantOp 80 //===----------------------------------------------------------------------===// 81 82 void arith::ConstantOp::getAsmResultNames( 83 function_ref<void(Value, StringRef)> setNameFn) { 84 auto type = getType(); 85 if (auto intCst = getValue().dyn_cast<IntegerAttr>()) { 86 auto intType = type.dyn_cast<IntegerType>(); 87 88 // Sugar i1 constants with 'true' and 'false'. 89 if (intType && intType.getWidth() == 1) 90 return setNameFn(getResult(), (intCst.getInt() ? "true" : "false")); 91 92 // Otherwise, build a compex name with the value and type. 93 SmallString<32> specialNameBuffer; 94 llvm::raw_svector_ostream specialName(specialNameBuffer); 95 specialName << 'c' << intCst.getInt(); 96 if (intType) 97 specialName << '_' << type; 98 setNameFn(getResult(), specialName.str()); 99 } else { 100 setNameFn(getResult(), "cst"); 101 } 102 } 103 104 /// TODO: disallow arith.constant to return anything other than signless integer 105 /// or float like. 106 static LogicalResult verify(arith::ConstantOp op) { 107 auto type = op.getType(); 108 // The value's type must match the return type. 109 if (op.getValue().getType() != type) { 110 return op.emitOpError() << "value type " << op.getValue().getType() 111 << " must match return type: " << type; 112 } 113 // Integer values must be signless. 114 if (type.isa<IntegerType>() && !type.cast<IntegerType>().isSignless()) 115 return op.emitOpError("integer return type must be signless"); 116 // Any float or elements attribute are acceptable. 117 if (!op.getValue().isa<IntegerAttr, FloatAttr, ElementsAttr>()) { 118 return op.emitOpError( 119 "value must be an integer, float, or elements attribute"); 120 } 121 return success(); 122 } 123 124 bool arith::ConstantOp::isBuildableWith(Attribute value, Type type) { 125 // The value's type must be the same as the provided type. 126 if (value.getType() != type) 127 return false; 128 // Integer values must be signless. 129 if (type.isa<IntegerType>() && !type.cast<IntegerType>().isSignless()) 130 return false; 131 // Integer, float, and element attributes are buildable. 132 return value.isa<IntegerAttr, FloatAttr, ElementsAttr>(); 133 } 134 135 OpFoldResult arith::ConstantOp::fold(ArrayRef<Attribute> operands) { 136 return getValue(); 137 } 138 139 void arith::ConstantIntOp::build(OpBuilder &builder, OperationState &result, 140 int64_t value, unsigned width) { 141 auto type = builder.getIntegerType(width); 142 arith::ConstantOp::build(builder, result, type, 143 builder.getIntegerAttr(type, value)); 144 } 145 146 void arith::ConstantIntOp::build(OpBuilder &builder, OperationState &result, 147 int64_t value, Type type) { 148 assert(type.isSignlessInteger() && 149 "ConstantIntOp can only have signless integer type values"); 150 arith::ConstantOp::build(builder, result, type, 151 builder.getIntegerAttr(type, value)); 152 } 153 154 bool arith::ConstantIntOp::classof(Operation *op) { 155 if (auto constOp = dyn_cast_or_null<arith::ConstantOp>(op)) 156 return constOp.getType().isSignlessInteger(); 157 return false; 158 } 159 160 void arith::ConstantFloatOp::build(OpBuilder &builder, OperationState &result, 161 const APFloat &value, FloatType type) { 162 arith::ConstantOp::build(builder, result, type, 163 builder.getFloatAttr(type, value)); 164 } 165 166 bool arith::ConstantFloatOp::classof(Operation *op) { 167 if (auto constOp = dyn_cast_or_null<arith::ConstantOp>(op)) 168 return constOp.getType().isa<FloatType>(); 169 return false; 170 } 171 172 void arith::ConstantIndexOp::build(OpBuilder &builder, OperationState &result, 173 int64_t value) { 174 arith::ConstantOp::build(builder, result, builder.getIndexType(), 175 builder.getIndexAttr(value)); 176 } 177 178 bool arith::ConstantIndexOp::classof(Operation *op) { 179 if (auto constOp = dyn_cast_or_null<arith::ConstantOp>(op)) 180 return constOp.getType().isIndex(); 181 return false; 182 } 183 184 //===----------------------------------------------------------------------===// 185 // AddIOp 186 //===----------------------------------------------------------------------===// 187 188 OpFoldResult arith::AddIOp::fold(ArrayRef<Attribute> operands) { 189 // addi(x, 0) -> x 190 if (matchPattern(getRhs(), m_Zero())) 191 return getLhs(); 192 193 return constFoldBinaryOp<IntegerAttr>(operands, 194 [](APInt a, APInt b) { return a + b; }); 195 } 196 197 void arith::AddIOp::getCanonicalizationPatterns( 198 OwningRewritePatternList &patterns, MLIRContext *context) { 199 patterns.insert<AddIAddConstant, AddISubConstantRHS, AddISubConstantLHS>( 200 context); 201 } 202 203 //===----------------------------------------------------------------------===// 204 // SubIOp 205 //===----------------------------------------------------------------------===// 206 207 OpFoldResult arith::SubIOp::fold(ArrayRef<Attribute> operands) { 208 // subi(x,x) -> 0 209 if (getOperand(0) == getOperand(1)) 210 return Builder(getContext()).getZeroAttr(getType()); 211 // subi(x,0) -> x 212 if (matchPattern(getRhs(), m_Zero())) 213 return getLhs(); 214 215 return constFoldBinaryOp<IntegerAttr>(operands, 216 [](APInt a, APInt b) { return a - b; }); 217 } 218 219 void arith::SubIOp::getCanonicalizationPatterns( 220 OwningRewritePatternList &patterns, MLIRContext *context) { 221 patterns.insert<SubIRHSAddConstant, SubILHSAddConstant, SubIRHSSubConstantRHS, 222 SubIRHSSubConstantLHS, SubILHSSubConstantRHS, 223 SubILHSSubConstantLHS>(context); 224 } 225 226 //===----------------------------------------------------------------------===// 227 // MulIOp 228 //===----------------------------------------------------------------------===// 229 230 OpFoldResult arith::MulIOp::fold(ArrayRef<Attribute> operands) { 231 // muli(x, 0) -> 0 232 if (matchPattern(getRhs(), m_Zero())) 233 return getRhs(); 234 // muli(x, 1) -> x 235 if (matchPattern(getRhs(), m_One())) 236 return getOperand(0); 237 // TODO: Handle the overflow case. 238 239 // default folder 240 return constFoldBinaryOp<IntegerAttr>(operands, 241 [](APInt a, APInt b) { return a * b; }); 242 } 243 244 //===----------------------------------------------------------------------===// 245 // DivUIOp 246 //===----------------------------------------------------------------------===// 247 248 OpFoldResult arith::DivUIOp::fold(ArrayRef<Attribute> operands) { 249 // Don't fold if it would require a division by zero. 250 bool div0 = false; 251 auto result = constFoldBinaryOp<IntegerAttr>(operands, [&](APInt a, APInt b) { 252 if (div0 || !b) { 253 div0 = true; 254 return a; 255 } 256 return a.udiv(b); 257 }); 258 259 // Fold out division by one. Assumes all tensors of all ones are splats. 260 if (auto rhs = operands[1].dyn_cast_or_null<IntegerAttr>()) { 261 if (rhs.getValue() == 1) 262 return getLhs(); 263 } else if (auto rhs = operands[1].dyn_cast_or_null<SplatElementsAttr>()) { 264 if (rhs.getSplatValue<IntegerAttr>().getValue() == 1) 265 return getLhs(); 266 } 267 268 return div0 ? Attribute() : result; 269 } 270 271 //===----------------------------------------------------------------------===// 272 // DivSIOp 273 //===----------------------------------------------------------------------===// 274 275 OpFoldResult arith::DivSIOp::fold(ArrayRef<Attribute> operands) { 276 // Don't fold if it would overflow or if it requires a division by zero. 277 bool overflowOrDiv0 = false; 278 auto result = constFoldBinaryOp<IntegerAttr>(operands, [&](APInt a, APInt b) { 279 if (overflowOrDiv0 || !b) { 280 overflowOrDiv0 = true; 281 return a; 282 } 283 return a.sdiv_ov(b, overflowOrDiv0); 284 }); 285 286 // Fold out division by one. Assumes all tensors of all ones are splats. 287 if (auto rhs = operands[1].dyn_cast_or_null<IntegerAttr>()) { 288 if (rhs.getValue() == 1) 289 return getLhs(); 290 } else if (auto rhs = operands[1].dyn_cast_or_null<SplatElementsAttr>()) { 291 if (rhs.getSplatValue<IntegerAttr>().getValue() == 1) 292 return getLhs(); 293 } 294 295 return overflowOrDiv0 ? Attribute() : result; 296 } 297 298 //===----------------------------------------------------------------------===// 299 // Ceil and floor division folding helpers 300 //===----------------------------------------------------------------------===// 301 302 static APInt signedCeilNonnegInputs(APInt a, APInt b, bool &overflow) { 303 // Returns (a-1)/b + 1 304 APInt one(a.getBitWidth(), 1, true); // Signed value 1. 305 APInt val = a.ssub_ov(one, overflow).sdiv_ov(b, overflow); 306 return val.sadd_ov(one, overflow); 307 } 308 309 //===----------------------------------------------------------------------===// 310 // CeilDivSIOp 311 //===----------------------------------------------------------------------===// 312 313 OpFoldResult arith::CeilDivSIOp::fold(ArrayRef<Attribute> operands) { 314 // Don't fold if it would overflow or if it requires a division by zero. 315 bool overflowOrDiv0 = false; 316 auto result = constFoldBinaryOp<IntegerAttr>(operands, [&](APInt a, APInt b) { 317 if (overflowOrDiv0 || !b) { 318 overflowOrDiv0 = true; 319 return a; 320 } 321 unsigned bits = a.getBitWidth(); 322 APInt zero = APInt::getZero(bits); 323 if (a.sgt(zero) && b.sgt(zero)) { 324 // Both positive, return ceil(a, b). 325 return signedCeilNonnegInputs(a, b, overflowOrDiv0); 326 } 327 if (a.slt(zero) && b.slt(zero)) { 328 // Both negative, return ceil(-a, -b). 329 APInt posA = zero.ssub_ov(a, overflowOrDiv0); 330 APInt posB = zero.ssub_ov(b, overflowOrDiv0); 331 return signedCeilNonnegInputs(posA, posB, overflowOrDiv0); 332 } 333 if (a.slt(zero) && b.sgt(zero)) { 334 // A is negative, b is positive, return - ( -a / b). 335 APInt posA = zero.ssub_ov(a, overflowOrDiv0); 336 APInt div = posA.sdiv_ov(b, overflowOrDiv0); 337 return zero.ssub_ov(div, overflowOrDiv0); 338 } 339 // A is positive (or zero), b is negative, return - (a / -b). 340 APInt posB = zero.ssub_ov(b, overflowOrDiv0); 341 APInt div = a.sdiv_ov(posB, overflowOrDiv0); 342 return zero.ssub_ov(div, overflowOrDiv0); 343 }); 344 345 // Fold out floor division by one. Assumes all tensors of all ones are 346 // splats. 347 if (auto rhs = operands[1].dyn_cast_or_null<IntegerAttr>()) { 348 if (rhs.getValue() == 1) 349 return getLhs(); 350 } else if (auto rhs = operands[1].dyn_cast_or_null<SplatElementsAttr>()) { 351 if (rhs.getSplatValue<IntegerAttr>().getValue() == 1) 352 return getLhs(); 353 } 354 355 return overflowOrDiv0 ? Attribute() : result; 356 } 357 358 //===----------------------------------------------------------------------===// 359 // FloorDivSIOp 360 //===----------------------------------------------------------------------===// 361 362 OpFoldResult arith::FloorDivSIOp::fold(ArrayRef<Attribute> operands) { 363 // Don't fold if it would overflow or if it requires a division by zero. 364 bool overflowOrDiv0 = false; 365 auto result = constFoldBinaryOp<IntegerAttr>(operands, [&](APInt a, APInt b) { 366 if (overflowOrDiv0 || !b) { 367 overflowOrDiv0 = true; 368 return a; 369 } 370 unsigned bits = a.getBitWidth(); 371 APInt zero = APInt::getZero(bits); 372 if (a.sge(zero) && b.sgt(zero)) { 373 // Both positive (or a is zero), return a / b. 374 return a.sdiv_ov(b, overflowOrDiv0); 375 } 376 if (a.sle(zero) && b.slt(zero)) { 377 // Both negative (or a is zero), return -a / -b. 378 APInt posA = zero.ssub_ov(a, overflowOrDiv0); 379 APInt posB = zero.ssub_ov(b, overflowOrDiv0); 380 return posA.sdiv_ov(posB, overflowOrDiv0); 381 } 382 if (a.slt(zero) && b.sgt(zero)) { 383 // A is negative, b is positive, return - ceil(-a, b). 384 APInt posA = zero.ssub_ov(a, overflowOrDiv0); 385 APInt ceil = signedCeilNonnegInputs(posA, b, overflowOrDiv0); 386 return zero.ssub_ov(ceil, overflowOrDiv0); 387 } 388 // A is positive, b is negative, return - ceil(a, -b). 389 APInt posB = zero.ssub_ov(b, overflowOrDiv0); 390 APInt ceil = signedCeilNonnegInputs(a, posB, overflowOrDiv0); 391 return zero.ssub_ov(ceil, overflowOrDiv0); 392 }); 393 394 // Fold out floor division by one. Assumes all tensors of all ones are 395 // splats. 396 if (auto rhs = operands[1].dyn_cast_or_null<IntegerAttr>()) { 397 if (rhs.getValue() == 1) 398 return getLhs(); 399 } else if (auto rhs = operands[1].dyn_cast_or_null<SplatElementsAttr>()) { 400 if (rhs.getSplatValue<IntegerAttr>().getValue() == 1) 401 return getLhs(); 402 } 403 404 return overflowOrDiv0 ? Attribute() : result; 405 } 406 407 //===----------------------------------------------------------------------===// 408 // RemUIOp 409 //===----------------------------------------------------------------------===// 410 411 OpFoldResult arith::RemUIOp::fold(ArrayRef<Attribute> operands) { 412 auto rhs = operands.back().dyn_cast_or_null<IntegerAttr>(); 413 if (!rhs) 414 return {}; 415 auto rhsValue = rhs.getValue(); 416 417 // x % 1 = 0 418 if (rhsValue.isOneValue()) 419 return IntegerAttr::get(rhs.getType(), APInt(rhsValue.getBitWidth(), 0)); 420 421 // Don't fold if it requires division by zero. 422 if (rhsValue.isNullValue()) 423 return {}; 424 425 auto lhs = operands.front().dyn_cast_or_null<IntegerAttr>(); 426 if (!lhs) 427 return {}; 428 return IntegerAttr::get(lhs.getType(), lhs.getValue().urem(rhsValue)); 429 } 430 431 //===----------------------------------------------------------------------===// 432 // RemSIOp 433 //===----------------------------------------------------------------------===// 434 435 OpFoldResult arith::RemSIOp::fold(ArrayRef<Attribute> operands) { 436 auto rhs = operands.back().dyn_cast_or_null<IntegerAttr>(); 437 if (!rhs) 438 return {}; 439 auto rhsValue = rhs.getValue(); 440 441 // x % 1 = 0 442 if (rhsValue.isOneValue()) 443 return IntegerAttr::get(rhs.getType(), APInt(rhsValue.getBitWidth(), 0)); 444 445 // Don't fold if it requires division by zero. 446 if (rhsValue.isNullValue()) 447 return {}; 448 449 auto lhs = operands.front().dyn_cast_or_null<IntegerAttr>(); 450 if (!lhs) 451 return {}; 452 return IntegerAttr::get(lhs.getType(), lhs.getValue().srem(rhsValue)); 453 } 454 455 //===----------------------------------------------------------------------===// 456 // AndIOp 457 //===----------------------------------------------------------------------===// 458 459 OpFoldResult arith::AndIOp::fold(ArrayRef<Attribute> operands) { 460 /// and(x, 0) -> 0 461 if (matchPattern(getRhs(), m_Zero())) 462 return getRhs(); 463 /// and(x, allOnes) -> x 464 APInt intValue; 465 if (matchPattern(getRhs(), m_ConstantInt(&intValue)) && intValue.isAllOnes()) 466 return getLhs(); 467 /// and(x, x) -> x 468 if (getLhs() == getRhs()) 469 return getRhs(); 470 471 return constFoldBinaryOp<IntegerAttr>(operands, 472 [](APInt a, APInt b) { return a & b; }); 473 } 474 475 //===----------------------------------------------------------------------===// 476 // OrIOp 477 //===----------------------------------------------------------------------===// 478 479 OpFoldResult arith::OrIOp::fold(ArrayRef<Attribute> operands) { 480 /// or(x, 0) -> x 481 if (matchPattern(getRhs(), m_Zero())) 482 return getLhs(); 483 /// or(x, x) -> x 484 if (getLhs() == getRhs()) 485 return getRhs(); 486 /// or(x, <all ones>) -> <all ones> 487 if (auto rhsAttr = operands[1].dyn_cast_or_null<IntegerAttr>()) 488 if (rhsAttr.getValue().isAllOnes()) 489 return rhsAttr; 490 491 return constFoldBinaryOp<IntegerAttr>(operands, 492 [](APInt a, APInt b) { return a | b; }); 493 } 494 495 //===----------------------------------------------------------------------===// 496 // XOrIOp 497 //===----------------------------------------------------------------------===// 498 499 OpFoldResult arith::XOrIOp::fold(ArrayRef<Attribute> operands) { 500 /// xor(x, 0) -> x 501 if (matchPattern(getRhs(), m_Zero())) 502 return getLhs(); 503 /// xor(x, x) -> 0 504 if (getLhs() == getRhs()) 505 return Builder(getContext()).getZeroAttr(getType()); 506 507 return constFoldBinaryOp<IntegerAttr>(operands, 508 [](APInt a, APInt b) { return a ^ b; }); 509 } 510 511 void arith::XOrIOp::getCanonicalizationPatterns( 512 OwningRewritePatternList &patterns, MLIRContext *context) { 513 patterns.insert<XOrINotCmpI>(context); 514 } 515 516 //===----------------------------------------------------------------------===// 517 // AddFOp 518 //===----------------------------------------------------------------------===// 519 520 OpFoldResult arith::AddFOp::fold(ArrayRef<Attribute> operands) { 521 return constFoldBinaryOp<FloatAttr>( 522 operands, [](APFloat a, APFloat b) { return a + b; }); 523 } 524 525 //===----------------------------------------------------------------------===// 526 // SubFOp 527 //===----------------------------------------------------------------------===// 528 529 OpFoldResult arith::SubFOp::fold(ArrayRef<Attribute> operands) { 530 return constFoldBinaryOp<FloatAttr>( 531 operands, [](APFloat a, APFloat b) { return a - b; }); 532 } 533 534 //===----------------------------------------------------------------------===// 535 // MulFOp 536 //===----------------------------------------------------------------------===// 537 538 OpFoldResult arith::MulFOp::fold(ArrayRef<Attribute> operands) { 539 return constFoldBinaryOp<FloatAttr>( 540 operands, [](APFloat a, APFloat b) { return a * b; }); 541 } 542 543 //===----------------------------------------------------------------------===// 544 // DivFOp 545 //===----------------------------------------------------------------------===// 546 547 OpFoldResult arith::DivFOp::fold(ArrayRef<Attribute> operands) { 548 return constFoldBinaryOp<FloatAttr>( 549 operands, [](APFloat a, APFloat b) { return a / b; }); 550 } 551 552 //===----------------------------------------------------------------------===// 553 // Utility functions for verifying cast ops 554 //===----------------------------------------------------------------------===// 555 556 template <typename... Types> 557 using type_list = std::tuple<Types...> *; 558 559 /// Returns a non-null type only if the provided type is one of the allowed 560 /// types or one of the allowed shaped types of the allowed types. Returns the 561 /// element type if a valid shaped type is provided. 562 template <typename... ShapedTypes, typename... ElementTypes> 563 static Type getUnderlyingType(Type type, type_list<ShapedTypes...>, 564 type_list<ElementTypes...>) { 565 if (type.isa<ShapedType>() && !type.isa<ShapedTypes...>()) 566 return {}; 567 568 auto underlyingType = getElementTypeOrSelf(type); 569 if (!underlyingType.isa<ElementTypes...>()) 570 return {}; 571 572 return underlyingType; 573 } 574 575 /// Get allowed underlying types for vectors and tensors. 576 template <typename... ElementTypes> 577 static Type getTypeIfLike(Type type) { 578 return getUnderlyingType(type, type_list<VectorType, TensorType>(), 579 type_list<ElementTypes...>()); 580 } 581 582 /// Get allowed underlying types for vectors, tensors, and memrefs. 583 template <typename... ElementTypes> 584 static Type getTypeIfLikeOrMemRef(Type type) { 585 return getUnderlyingType(type, 586 type_list<VectorType, TensorType, MemRefType>(), 587 type_list<ElementTypes...>()); 588 } 589 590 static bool areValidCastInputsAndOutputs(TypeRange inputs, TypeRange outputs) { 591 return inputs.size() == 1 && outputs.size() == 1 && 592 succeeded(verifyCompatibleShapes(inputs.front(), outputs.front())); 593 } 594 595 //===----------------------------------------------------------------------===// 596 // Verifiers for integer and floating point extension/truncation ops 597 //===----------------------------------------------------------------------===// 598 599 // Extend ops can only extend to a wider type. 600 template <typename ValType, typename Op> 601 static LogicalResult verifyExtOp(Op op) { 602 Type srcType = getElementTypeOrSelf(op.getIn().getType()); 603 Type dstType = getElementTypeOrSelf(op.getType()); 604 605 if (srcType.cast<ValType>().getWidth() >= dstType.cast<ValType>().getWidth()) 606 return op.emitError("result type ") 607 << dstType << " must be wider than operand type " << srcType; 608 609 return success(); 610 } 611 612 // Truncate ops can only truncate to a shorter type. 613 template <typename ValType, typename Op> 614 static LogicalResult verifyTruncateOp(Op op) { 615 Type srcType = getElementTypeOrSelf(op.getIn().getType()); 616 Type dstType = getElementTypeOrSelf(op.getType()); 617 618 if (srcType.cast<ValType>().getWidth() <= dstType.cast<ValType>().getWidth()) 619 return op.emitError("result type ") 620 << dstType << " must be shorter than operand type " << srcType; 621 622 return success(); 623 } 624 625 /// Validate a cast that changes the width of a type. 626 template <template <typename> class WidthComparator, typename... ElementTypes> 627 static bool checkWidthChangeCast(TypeRange inputs, TypeRange outputs) { 628 if (!areValidCastInputsAndOutputs(inputs, outputs)) 629 return false; 630 631 auto srcType = getTypeIfLike<ElementTypes...>(inputs.front()); 632 auto dstType = getTypeIfLike<ElementTypes...>(outputs.front()); 633 if (!srcType || !dstType) 634 return false; 635 636 return WidthComparator<unsigned>()(dstType.getIntOrFloatBitWidth(), 637 srcType.getIntOrFloatBitWidth()); 638 } 639 640 //===----------------------------------------------------------------------===// 641 // ExtUIOp 642 //===----------------------------------------------------------------------===// 643 644 OpFoldResult arith::ExtUIOp::fold(ArrayRef<Attribute> operands) { 645 if (auto lhs = operands[0].dyn_cast_or_null<IntegerAttr>()) 646 return IntegerAttr::get( 647 getType(), lhs.getValue().zext(getType().getIntOrFloatBitWidth())); 648 649 return {}; 650 } 651 652 bool arith::ExtUIOp::areCastCompatible(TypeRange inputs, TypeRange outputs) { 653 return checkWidthChangeCast<std::greater, IntegerType>(inputs, outputs); 654 } 655 656 //===----------------------------------------------------------------------===// 657 // ExtSIOp 658 //===----------------------------------------------------------------------===// 659 660 OpFoldResult arith::ExtSIOp::fold(ArrayRef<Attribute> operands) { 661 if (auto lhs = operands[0].dyn_cast_or_null<IntegerAttr>()) 662 return IntegerAttr::get( 663 getType(), lhs.getValue().sext(getType().getIntOrFloatBitWidth())); 664 665 return {}; 666 } 667 668 bool arith::ExtSIOp::areCastCompatible(TypeRange inputs, TypeRange outputs) { 669 return checkWidthChangeCast<std::greater, IntegerType>(inputs, outputs); 670 } 671 672 //===----------------------------------------------------------------------===// 673 // ExtFOp 674 //===----------------------------------------------------------------------===// 675 676 bool arith::ExtFOp::areCastCompatible(TypeRange inputs, TypeRange outputs) { 677 return checkWidthChangeCast<std::greater, FloatType>(inputs, outputs); 678 } 679 680 //===----------------------------------------------------------------------===// 681 // TruncIOp 682 //===----------------------------------------------------------------------===// 683 684 OpFoldResult arith::TruncIOp::fold(ArrayRef<Attribute> operands) { 685 // trunci(zexti(a)) -> a 686 // trunci(sexti(a)) -> a 687 if (matchPattern(getOperand(), m_Op<arith::ExtUIOp>()) || 688 matchPattern(getOperand(), m_Op<arith::ExtSIOp>())) 689 return getOperand().getDefiningOp()->getOperand(0); 690 691 assert(operands.size() == 1 && "unary operation takes one operand"); 692 693 if (!operands[0]) 694 return {}; 695 696 if (auto lhs = operands[0].dyn_cast<IntegerAttr>()) { 697 return IntegerAttr::get( 698 getType(), lhs.getValue().trunc(getType().getIntOrFloatBitWidth())); 699 } 700 701 return {}; 702 } 703 704 bool arith::TruncIOp::areCastCompatible(TypeRange inputs, TypeRange outputs) { 705 return checkWidthChangeCast<std::less, IntegerType>(inputs, outputs); 706 } 707 708 //===----------------------------------------------------------------------===// 709 // TruncFOp 710 //===----------------------------------------------------------------------===// 711 712 /// Perform safe const propagation for truncf, i.e. only propagate if FP value 713 /// can be represented without precision loss or rounding. 714 OpFoldResult arith::TruncFOp::fold(ArrayRef<Attribute> operands) { 715 assert(operands.size() == 1 && "unary operation takes one operand"); 716 717 auto constOperand = operands.front(); 718 if (!constOperand || !constOperand.isa<FloatAttr>()) 719 return {}; 720 721 // Convert to target type via 'double'. 722 double sourceValue = 723 constOperand.dyn_cast<FloatAttr>().getValue().convertToDouble(); 724 auto targetAttr = FloatAttr::get(getType(), sourceValue); 725 726 // Propagate if constant's value does not change after truncation. 727 if (sourceValue == targetAttr.getValue().convertToDouble()) 728 return targetAttr; 729 730 return {}; 731 } 732 733 bool arith::TruncFOp::areCastCompatible(TypeRange inputs, TypeRange outputs) { 734 return checkWidthChangeCast<std::less, FloatType>(inputs, outputs); 735 } 736 737 //===----------------------------------------------------------------------===// 738 // Verifiers for casts between integers and floats. 739 //===----------------------------------------------------------------------===// 740 741 template <typename From, typename To> 742 static bool checkIntFloatCast(TypeRange inputs, TypeRange outputs) { 743 if (!areValidCastInputsAndOutputs(inputs, outputs)) 744 return false; 745 746 auto srcType = getTypeIfLike<From>(inputs.front()); 747 auto dstType = getTypeIfLike<To>(outputs.back()); 748 749 return srcType && dstType; 750 } 751 752 //===----------------------------------------------------------------------===// 753 // UIToFPOp 754 //===----------------------------------------------------------------------===// 755 756 bool arith::UIToFPOp::areCastCompatible(TypeRange inputs, TypeRange outputs) { 757 return checkIntFloatCast<IntegerType, FloatType>(inputs, outputs); 758 } 759 760 //===----------------------------------------------------------------------===// 761 // SIToFPOp 762 //===----------------------------------------------------------------------===// 763 764 bool arith::SIToFPOp::areCastCompatible(TypeRange inputs, TypeRange outputs) { 765 return checkIntFloatCast<IntegerType, FloatType>(inputs, outputs); 766 } 767 768 //===----------------------------------------------------------------------===// 769 // FPToUIOp 770 //===----------------------------------------------------------------------===// 771 772 bool arith::FPToUIOp::areCastCompatible(TypeRange inputs, TypeRange outputs) { 773 return checkIntFloatCast<FloatType, IntegerType>(inputs, outputs); 774 } 775 776 //===----------------------------------------------------------------------===// 777 // FPToSIOp 778 //===----------------------------------------------------------------------===// 779 780 bool arith::FPToSIOp::areCastCompatible(TypeRange inputs, TypeRange outputs) { 781 return checkIntFloatCast<FloatType, IntegerType>(inputs, outputs); 782 } 783 784 //===----------------------------------------------------------------------===// 785 // IndexCastOp 786 //===----------------------------------------------------------------------===// 787 788 bool arith::IndexCastOp::areCastCompatible(TypeRange inputs, 789 TypeRange outputs) { 790 if (!areValidCastInputsAndOutputs(inputs, outputs)) 791 return false; 792 793 auto srcType = getTypeIfLikeOrMemRef<IntegerType, IndexType>(inputs.front()); 794 auto dstType = getTypeIfLikeOrMemRef<IntegerType, IndexType>(outputs.front()); 795 if (!srcType || !dstType) 796 return false; 797 798 return (srcType.isIndex() && dstType.isSignlessInteger()) || 799 (srcType.isSignlessInteger() && dstType.isIndex()); 800 } 801 802 OpFoldResult arith::IndexCastOp::fold(ArrayRef<Attribute> operands) { 803 // index_cast(constant) -> constant 804 // A little hack because we go through int. Otherwise, the size of the 805 // constant might need to change. 806 if (auto value = operands[0].dyn_cast_or_null<IntegerAttr>()) 807 return IntegerAttr::get(getType(), value.getInt()); 808 809 return {}; 810 } 811 812 void arith::IndexCastOp::getCanonicalizationPatterns( 813 OwningRewritePatternList &patterns, MLIRContext *context) { 814 patterns.insert<IndexCastOfIndexCast, IndexCastOfExtSI>(context); 815 } 816 817 //===----------------------------------------------------------------------===// 818 // BitcastOp 819 //===----------------------------------------------------------------------===// 820 821 bool arith::BitcastOp::areCastCompatible(TypeRange inputs, TypeRange outputs) { 822 if (!areValidCastInputsAndOutputs(inputs, outputs)) 823 return false; 824 825 auto srcType = 826 getTypeIfLikeOrMemRef<IntegerType, IndexType, FloatType>(inputs.front()); 827 auto dstType = 828 getTypeIfLikeOrMemRef<IntegerType, IndexType, FloatType>(outputs.front()); 829 if (!srcType || !dstType) 830 return false; 831 832 return srcType.getIntOrFloatBitWidth() == dstType.getIntOrFloatBitWidth(); 833 } 834 835 OpFoldResult arith::BitcastOp::fold(ArrayRef<Attribute> operands) { 836 assert(operands.size() == 1 && "bitcast op expects 1 operand"); 837 838 auto resType = getType(); 839 auto operand = operands[0]; 840 if (!operand) 841 return {}; 842 843 /// Bitcast dense elements. 844 if (auto denseAttr = operand.dyn_cast_or_null<DenseElementsAttr>()) 845 return denseAttr.bitcast(resType.cast<ShapedType>().getElementType()); 846 /// Other shaped types unhandled. 847 if (resType.isa<ShapedType>()) 848 return {}; 849 850 /// Bitcast integer or float to integer or float. 851 APInt bits = operand.isa<FloatAttr>() 852 ? operand.cast<FloatAttr>().getValue().bitcastToAPInt() 853 : operand.cast<IntegerAttr>().getValue(); 854 855 if (auto resFloatType = resType.dyn_cast<FloatType>()) 856 return FloatAttr::get(resType, 857 APFloat(resFloatType.getFloatSemantics(), bits)); 858 return IntegerAttr::get(resType, bits); 859 } 860 861 void arith::BitcastOp::getCanonicalizationPatterns( 862 OwningRewritePatternList &patterns, MLIRContext *context) { 863 patterns.insert<BitcastOfBitcast>(context); 864 } 865 866 //===----------------------------------------------------------------------===// 867 // Helpers for compare ops 868 //===----------------------------------------------------------------------===// 869 870 /// Return the type of the same shape (scalar, vector or tensor) containing i1. 871 static Type getI1SameShape(Type type) { 872 auto i1Type = IntegerType::get(type.getContext(), 1); 873 if (auto tensorType = type.dyn_cast<RankedTensorType>()) 874 return RankedTensorType::get(tensorType.getShape(), i1Type); 875 if (type.isa<UnrankedTensorType>()) 876 return UnrankedTensorType::get(i1Type); 877 if (auto vectorType = type.dyn_cast<VectorType>()) 878 return VectorType::get(vectorType.getShape(), i1Type); 879 return i1Type; 880 } 881 882 //===----------------------------------------------------------------------===// 883 // CmpIOp 884 //===----------------------------------------------------------------------===// 885 886 /// Compute `lhs` `pred` `rhs`, where `pred` is one of the known integer 887 /// comparison predicates. 888 bool mlir::arith::applyCmpPredicate(arith::CmpIPredicate predicate, 889 const APInt &lhs, const APInt &rhs) { 890 switch (predicate) { 891 case arith::CmpIPredicate::eq: 892 return lhs.eq(rhs); 893 case arith::CmpIPredicate::ne: 894 return lhs.ne(rhs); 895 case arith::CmpIPredicate::slt: 896 return lhs.slt(rhs); 897 case arith::CmpIPredicate::sle: 898 return lhs.sle(rhs); 899 case arith::CmpIPredicate::sgt: 900 return lhs.sgt(rhs); 901 case arith::CmpIPredicate::sge: 902 return lhs.sge(rhs); 903 case arith::CmpIPredicate::ult: 904 return lhs.ult(rhs); 905 case arith::CmpIPredicate::ule: 906 return lhs.ule(rhs); 907 case arith::CmpIPredicate::ugt: 908 return lhs.ugt(rhs); 909 case arith::CmpIPredicate::uge: 910 return lhs.uge(rhs); 911 } 912 llvm_unreachable("unknown cmpi predicate kind"); 913 } 914 915 /// Returns true if the predicate is true for two equal operands. 916 static bool applyCmpPredicateToEqualOperands(arith::CmpIPredicate predicate) { 917 switch (predicate) { 918 case arith::CmpIPredicate::eq: 919 case arith::CmpIPredicate::sle: 920 case arith::CmpIPredicate::sge: 921 case arith::CmpIPredicate::ule: 922 case arith::CmpIPredicate::uge: 923 return true; 924 case arith::CmpIPredicate::ne: 925 case arith::CmpIPredicate::slt: 926 case arith::CmpIPredicate::sgt: 927 case arith::CmpIPredicate::ult: 928 case arith::CmpIPredicate::ugt: 929 return false; 930 } 931 llvm_unreachable("unknown cmpi predicate kind"); 932 } 933 934 OpFoldResult arith::CmpIOp::fold(ArrayRef<Attribute> operands) { 935 assert(operands.size() == 2 && "cmpi takes two operands"); 936 937 // cmpi(pred, x, x) 938 if (getLhs() == getRhs()) { 939 auto val = applyCmpPredicateToEqualOperands(getPredicate()); 940 return BoolAttr::get(getContext(), val); 941 } 942 943 auto lhs = operands.front().dyn_cast_or_null<IntegerAttr>(); 944 auto rhs = operands.back().dyn_cast_or_null<IntegerAttr>(); 945 if (!lhs || !rhs) 946 return {}; 947 948 auto val = applyCmpPredicate(getPredicate(), lhs.getValue(), rhs.getValue()); 949 return BoolAttr::get(getContext(), val); 950 } 951 952 //===----------------------------------------------------------------------===// 953 // CmpFOp 954 //===----------------------------------------------------------------------===// 955 956 /// Compute `lhs` `pred` `rhs`, where `pred` is one of the known floating point 957 /// comparison predicates. 958 bool mlir::arith::applyCmpPredicate(arith::CmpFPredicate predicate, 959 const APFloat &lhs, const APFloat &rhs) { 960 auto cmpResult = lhs.compare(rhs); 961 switch (predicate) { 962 case arith::CmpFPredicate::AlwaysFalse: 963 return false; 964 case arith::CmpFPredicate::OEQ: 965 return cmpResult == APFloat::cmpEqual; 966 case arith::CmpFPredicate::OGT: 967 return cmpResult == APFloat::cmpGreaterThan; 968 case arith::CmpFPredicate::OGE: 969 return cmpResult == APFloat::cmpGreaterThan || 970 cmpResult == APFloat::cmpEqual; 971 case arith::CmpFPredicate::OLT: 972 return cmpResult == APFloat::cmpLessThan; 973 case arith::CmpFPredicate::OLE: 974 return cmpResult == APFloat::cmpLessThan || cmpResult == APFloat::cmpEqual; 975 case arith::CmpFPredicate::ONE: 976 return cmpResult != APFloat::cmpUnordered && cmpResult != APFloat::cmpEqual; 977 case arith::CmpFPredicate::ORD: 978 return cmpResult != APFloat::cmpUnordered; 979 case arith::CmpFPredicate::UEQ: 980 return cmpResult == APFloat::cmpUnordered || cmpResult == APFloat::cmpEqual; 981 case arith::CmpFPredicate::UGT: 982 return cmpResult == APFloat::cmpUnordered || 983 cmpResult == APFloat::cmpGreaterThan; 984 case arith::CmpFPredicate::UGE: 985 return cmpResult == APFloat::cmpUnordered || 986 cmpResult == APFloat::cmpGreaterThan || 987 cmpResult == APFloat::cmpEqual; 988 case arith::CmpFPredicate::ULT: 989 return cmpResult == APFloat::cmpUnordered || 990 cmpResult == APFloat::cmpLessThan; 991 case arith::CmpFPredicate::ULE: 992 return cmpResult == APFloat::cmpUnordered || 993 cmpResult == APFloat::cmpLessThan || cmpResult == APFloat::cmpEqual; 994 case arith::CmpFPredicate::UNE: 995 return cmpResult != APFloat::cmpEqual; 996 case arith::CmpFPredicate::UNO: 997 return cmpResult == APFloat::cmpUnordered; 998 case arith::CmpFPredicate::AlwaysTrue: 999 return true; 1000 } 1001 llvm_unreachable("unknown cmpf predicate kind"); 1002 } 1003 1004 OpFoldResult arith::CmpFOp::fold(ArrayRef<Attribute> operands) { 1005 assert(operands.size() == 2 && "cmpf takes two operands"); 1006 1007 auto lhs = operands.front().dyn_cast_or_null<FloatAttr>(); 1008 auto rhs = operands.back().dyn_cast_or_null<FloatAttr>(); 1009 1010 if (!lhs || !rhs) 1011 return {}; 1012 1013 auto val = applyCmpPredicate(getPredicate(), lhs.getValue(), rhs.getValue()); 1014 return BoolAttr::get(getContext(), val); 1015 } 1016 1017 //===----------------------------------------------------------------------===// 1018 // TableGen'd op method definitions 1019 //===----------------------------------------------------------------------===// 1020 1021 #define GET_OP_CLASSES 1022 #include "mlir/Dialect/Arithmetic/IR/ArithmeticOps.cpp.inc" 1023 1024 //===----------------------------------------------------------------------===// 1025 // TableGen'd enum attribute definitions 1026 //===----------------------------------------------------------------------===// 1027 1028 #include "mlir/Dialect/Arithmetic/IR/ArithmeticOpsEnums.cpp.inc" 1029