1 //===- ArithmeticOps.cpp - MLIR Arithmetic dialect ops implementation -----===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 
9 #include <utility>
10 
11 #include "mlir/Dialect/Arithmetic/IR/Arithmetic.h"
12 #include "mlir/Dialect/CommonFolders.h"
13 #include "mlir/IR/Builders.h"
14 #include "mlir/IR/Matchers.h"
15 #include "mlir/IR/OpImplementation.h"
16 #include "mlir/IR/PatternMatch.h"
17 #include "mlir/IR/TypeUtilities.h"
18 #include "llvm/ADT/SmallString.h"
19 
20 #include "llvm/ADT/APSInt.h"
21 
22 using namespace mlir;
23 using namespace mlir::arith;
24 
25 //===----------------------------------------------------------------------===//
26 // Pattern helpers
27 //===----------------------------------------------------------------------===//
28 
29 static IntegerAttr addIntegerAttrs(PatternRewriter &builder, Value res,
30                                    Attribute lhs, Attribute rhs) {
31   return builder.getIntegerAttr(res.getType(),
32                                 lhs.cast<IntegerAttr>().getInt() +
33                                     rhs.cast<IntegerAttr>().getInt());
34 }
35 
36 static IntegerAttr subIntegerAttrs(PatternRewriter &builder, Value res,
37                                    Attribute lhs, Attribute rhs) {
38   return builder.getIntegerAttr(res.getType(),
39                                 lhs.cast<IntegerAttr>().getInt() -
40                                     rhs.cast<IntegerAttr>().getInt());
41 }
42 
43 /// Invert an integer comparison predicate.
44 arith::CmpIPredicate arith::invertPredicate(arith::CmpIPredicate pred) {
45   switch (pred) {
46   case arith::CmpIPredicate::eq:
47     return arith::CmpIPredicate::ne;
48   case arith::CmpIPredicate::ne:
49     return arith::CmpIPredicate::eq;
50   case arith::CmpIPredicate::slt:
51     return arith::CmpIPredicate::sge;
52   case arith::CmpIPredicate::sle:
53     return arith::CmpIPredicate::sgt;
54   case arith::CmpIPredicate::sgt:
55     return arith::CmpIPredicate::sle;
56   case arith::CmpIPredicate::sge:
57     return arith::CmpIPredicate::slt;
58   case arith::CmpIPredicate::ult:
59     return arith::CmpIPredicate::uge;
60   case arith::CmpIPredicate::ule:
61     return arith::CmpIPredicate::ugt;
62   case arith::CmpIPredicate::ugt:
63     return arith::CmpIPredicate::ule;
64   case arith::CmpIPredicate::uge:
65     return arith::CmpIPredicate::ult;
66   }
67   llvm_unreachable("unknown cmpi predicate kind");
68 }
69 
70 static arith::CmpIPredicateAttr invertPredicate(arith::CmpIPredicateAttr pred) {
71   return arith::CmpIPredicateAttr::get(pred.getContext(),
72                                        invertPredicate(pred.getValue()));
73 }
74 
75 //===----------------------------------------------------------------------===//
76 // TableGen'd canonicalization patterns
77 //===----------------------------------------------------------------------===//
78 
// The included file is wrapped in an anonymous namespace so that whatever it
// defines stays local to this translation unit. The pattern names registered
// by the getCanonicalizationPatterns() hooks below (e.g. AddIAddConstant,
// XOrINotCmpI) are presumably defined there, since they are not defined in
// this file.
namespace {
#include "ArithmeticCanonicalization.inc"
} // namespace
82 
83 //===----------------------------------------------------------------------===//
84 // ConstantOp
85 //===----------------------------------------------------------------------===//
86 
87 void arith::ConstantOp::getAsmResultNames(
88     function_ref<void(Value, StringRef)> setNameFn) {
89   auto type = getType();
90   if (auto intCst = getValue().dyn_cast<IntegerAttr>()) {
91     auto intType = type.dyn_cast<IntegerType>();
92 
93     // Sugar i1 constants with 'true' and 'false'.
94     if (intType && intType.getWidth() == 1)
95       return setNameFn(getResult(), (intCst.getInt() ? "true" : "false"));
96 
97     // Otherwise, build a compex name with the value and type.
98     SmallString<32> specialNameBuffer;
99     llvm::raw_svector_ostream specialName(specialNameBuffer);
100     specialName << 'c' << intCst.getInt();
101     if (intType)
102       specialName << '_' << type;
103     setNameFn(getResult(), specialName.str());
104   } else {
105     setNameFn(getResult(), "cst");
106   }
107 }
108 
109 /// TODO: disallow arith.constant to return anything other than signless integer
110 /// or float like.
111 LogicalResult arith::ConstantOp::verify() {
112   auto type = getType();
113   // The value's type must match the return type.
114   if (getValue().getType() != type) {
115     return emitOpError() << "value type " << getValue().getType()
116                          << " must match return type: " << type;
117   }
118   // Integer values must be signless.
119   if (type.isa<IntegerType>() && !type.cast<IntegerType>().isSignless())
120     return emitOpError("integer return type must be signless");
121   // Any float or elements attribute are acceptable.
122   if (!getValue().isa<IntegerAttr, FloatAttr, ElementsAttr>()) {
123     return emitOpError(
124         "value must be an integer, float, or elements attribute");
125   }
126   return success();
127 }
128 
129 bool arith::ConstantOp::isBuildableWith(Attribute value, Type type) {
130   // The value's type must be the same as the provided type.
131   if (value.getType() != type)
132     return false;
133   // Integer values must be signless.
134   if (type.isa<IntegerType>() && !type.cast<IntegerType>().isSignless())
135     return false;
136   // Integer, float, and element attributes are buildable.
137   return value.isa<IntegerAttr, FloatAttr, ElementsAttr>();
138 }
139 
140 OpFoldResult arith::ConstantOp::fold(ArrayRef<Attribute> operands) {
141   return getValue();
142 }
143 
144 void arith::ConstantIntOp::build(OpBuilder &builder, OperationState &result,
145                                  int64_t value, unsigned width) {
146   auto type = builder.getIntegerType(width);
147   arith::ConstantOp::build(builder, result, type,
148                            builder.getIntegerAttr(type, value));
149 }
150 
151 void arith::ConstantIntOp::build(OpBuilder &builder, OperationState &result,
152                                  int64_t value, Type type) {
153   assert(type.isSignlessInteger() &&
154          "ConstantIntOp can only have signless integer type values");
155   arith::ConstantOp::build(builder, result, type,
156                            builder.getIntegerAttr(type, value));
157 }
158 
159 bool arith::ConstantIntOp::classof(Operation *op) {
160   if (auto constOp = dyn_cast_or_null<arith::ConstantOp>(op))
161     return constOp.getType().isSignlessInteger();
162   return false;
163 }
164 
165 void arith::ConstantFloatOp::build(OpBuilder &builder, OperationState &result,
166                                    const APFloat &value, FloatType type) {
167   arith::ConstantOp::build(builder, result, type,
168                            builder.getFloatAttr(type, value));
169 }
170 
171 bool arith::ConstantFloatOp::classof(Operation *op) {
172   if (auto constOp = dyn_cast_or_null<arith::ConstantOp>(op))
173     return constOp.getType().isa<FloatType>();
174   return false;
175 }
176 
177 void arith::ConstantIndexOp::build(OpBuilder &builder, OperationState &result,
178                                    int64_t value) {
179   arith::ConstantOp::build(builder, result, builder.getIndexType(),
180                            builder.getIndexAttr(value));
181 }
182 
183 bool arith::ConstantIndexOp::classof(Operation *op) {
184   if (auto constOp = dyn_cast_or_null<arith::ConstantOp>(op))
185     return constOp.getType().isIndex();
186   return false;
187 }
188 
189 //===----------------------------------------------------------------------===//
190 // AddIOp
191 //===----------------------------------------------------------------------===//
192 
193 OpFoldResult arith::AddIOp::fold(ArrayRef<Attribute> operands) {
194   // addi(x, 0) -> x
195   if (matchPattern(getRhs(), m_Zero()))
196     return getLhs();
197 
198   // addi(subi(a, b), b) -> a
199   if (auto sub = getLhs().getDefiningOp<SubIOp>())
200     if (getRhs() == sub.getRhs())
201       return sub.getLhs();
202 
203   // addi(b, subi(a, b)) -> a
204   if (auto sub = getRhs().getDefiningOp<SubIOp>())
205     if (getLhs() == sub.getRhs())
206       return sub.getLhs();
207 
208   return constFoldBinaryOp<IntegerAttr>(
209       operands, [](APInt a, const APInt &b) { return std::move(a) + b; });
210 }
211 
212 void arith::AddIOp::getCanonicalizationPatterns(
213     RewritePatternSet &patterns, MLIRContext *context) {
214   patterns.add<AddIAddConstant, AddISubConstantRHS, AddISubConstantLHS>(
215       context);
216 }
217 
218 //===----------------------------------------------------------------------===//
219 // SubIOp
220 //===----------------------------------------------------------------------===//
221 
222 OpFoldResult arith::SubIOp::fold(ArrayRef<Attribute> operands) {
223   // subi(x,x) -> 0
224   if (getOperand(0) == getOperand(1))
225     return Builder(getContext()).getZeroAttr(getType());
226   // subi(x,0) -> x
227   if (matchPattern(getRhs(), m_Zero()))
228     return getLhs();
229 
230   return constFoldBinaryOp<IntegerAttr>(
231       operands, [](APInt a, const APInt &b) { return std::move(a) - b; });
232 }
233 
234 void arith::SubIOp::getCanonicalizationPatterns(
235     RewritePatternSet &patterns, MLIRContext *context) {
236   patterns
237       .add<SubIRHSAddConstant, SubILHSAddConstant, SubIRHSSubConstantRHS,
238            SubIRHSSubConstantLHS, SubILHSSubConstantRHS, SubILHSSubConstantLHS>(
239           context);
240 }
241 
242 //===----------------------------------------------------------------------===//
243 // MulIOp
244 //===----------------------------------------------------------------------===//
245 
246 OpFoldResult arith::MulIOp::fold(ArrayRef<Attribute> operands) {
247   // muli(x, 0) -> 0
248   if (matchPattern(getRhs(), m_Zero()))
249     return getRhs();
250   // muli(x, 1) -> x
251   if (matchPattern(getRhs(), m_One()))
252     return getOperand(0);
253   // TODO: Handle the overflow case.
254 
255   // default folder
256   return constFoldBinaryOp<IntegerAttr>(
257       operands, [](const APInt &a, const APInt &b) { return a * b; });
258 }
259 
260 //===----------------------------------------------------------------------===//
261 // DivUIOp
262 //===----------------------------------------------------------------------===//
263 
264 OpFoldResult arith::DivUIOp::fold(ArrayRef<Attribute> operands) {
265   // divui (x, 1) -> x.
266   if (matchPattern(getRhs(), m_One()))
267     return getLhs();
268 
269   // Don't fold if it would require a division by zero.
270   bool div0 = false;
271   auto result =
272       constFoldBinaryOp<IntegerAttr>(operands, [&](APInt a, const APInt &b) {
273         if (div0 || !b) {
274           div0 = true;
275           return a;
276         }
277         return a.udiv(b);
278       });
279 
280   return div0 ? Attribute() : result;
281 }
282 
283 //===----------------------------------------------------------------------===//
284 // DivSIOp
285 //===----------------------------------------------------------------------===//
286 
287 OpFoldResult arith::DivSIOp::fold(ArrayRef<Attribute> operands) {
288   // divsi (x, 1) -> x.
289   if (matchPattern(getRhs(), m_One()))
290     return getLhs();
291 
292   // Don't fold if it would overflow or if it requires a division by zero.
293   bool overflowOrDiv0 = false;
294   auto result =
295       constFoldBinaryOp<IntegerAttr>(operands, [&](APInt a, const APInt &b) {
296         if (overflowOrDiv0 || !b) {
297           overflowOrDiv0 = true;
298           return a;
299         }
300         return a.sdiv_ov(b, overflowOrDiv0);
301       });
302 
303   return overflowOrDiv0 ? Attribute() : result;
304 }
305 
306 //===----------------------------------------------------------------------===//
307 // Ceil and floor division folding helpers
308 //===----------------------------------------------------------------------===//
309 
310 static APInt signedCeilNonnegInputs(const APInt &a, const APInt &b,
311                                     bool &overflow) {
312   // Returns (a-1)/b + 1
313   APInt one(a.getBitWidth(), 1, true); // Signed value 1.
314   APInt val = a.ssub_ov(one, overflow).sdiv_ov(b, overflow);
315   return val.sadd_ov(one, overflow);
316 }
317 
318 //===----------------------------------------------------------------------===//
319 // CeilDivUIOp
320 //===----------------------------------------------------------------------===//
321 
322 OpFoldResult arith::CeilDivUIOp::fold(ArrayRef<Attribute> operands) {
323   // ceildivui (x, 1) -> x.
324   if (matchPattern(getRhs(), m_One()))
325     return getLhs();
326 
327   bool overflowOrDiv0 = false;
328   auto result =
329       constFoldBinaryOp<IntegerAttr>(operands, [&](APInt a, const APInt &b) {
330         if (overflowOrDiv0 || !b) {
331           overflowOrDiv0 = true;
332           return a;
333         }
334         APInt quotient = a.udiv(b);
335         if (!a.urem(b))
336           return quotient;
337         APInt one(a.getBitWidth(), 1, true);
338         return quotient.uadd_ov(one, overflowOrDiv0);
339       });
340 
341   return overflowOrDiv0 ? Attribute() : result;
342 }
343 
344 //===----------------------------------------------------------------------===//
345 // CeilDivSIOp
346 //===----------------------------------------------------------------------===//
347 
/// Folds signed ceiling division.
///
/// ceildivsi(x, 1) folds to x. Otherwise the operands are constant-folded by
/// case analysis on the signs of the dividend `a` and divisor `b`, reducing
/// every case to operations on non-negative values. The fold is abandoned
/// (a null attribute is returned) when the divisor is zero or when
/// `overflowOrDiv0` is set after the computation.
OpFoldResult arith::CeilDivSIOp::fold(ArrayRef<Attribute> operands) {
  // ceildivsi (x, 1) -> x.
  if (matchPattern(getRhs(), m_One()))
    return getLhs();

  // Don't fold if it would overflow or if it requires a division by zero.
  // The flag is captured by reference and shared across all invocations of
  // the callback; once it is set, later invocations return early.
  bool overflowOrDiv0 = false;
  auto result =
      constFoldBinaryOp<IntegerAttr>(operands, [&](APInt a, const APInt &b) {
        if (overflowOrDiv0 || !b) {
          overflowOrDiv0 = true;
          return a;
        }
        // 0 / b is 0 for any non-zero b.
        if (!a)
          return a;
        // After this point we know that neither a or b are zero.
        unsigned bits = a.getBitWidth();
        APInt zero = APInt::getZero(bits);
        bool aGtZero = a.sgt(zero);
        bool bGtZero = b.sgt(zero);
        if (aGtZero && bGtZero) {
          // Both positive, return ceil(a, b).
          return signedCeilNonnegInputs(a, b, overflowOrDiv0);
        }
        if (!aGtZero && !bGtZero) {
          // Both negative, return ceil(-a, -b).
          APInt posA = zero.ssub_ov(a, overflowOrDiv0);
          APInt posB = zero.ssub_ov(b, overflowOrDiv0);
          return signedCeilNonnegInputs(posA, posB, overflowOrDiv0);
        }
        if (!aGtZero && bGtZero) {
          // A is negative, b is positive, return - ( -a / b).
          APInt posA = zero.ssub_ov(a, overflowOrDiv0);
          APInt div = posA.sdiv_ov(b, overflowOrDiv0);
          return zero.ssub_ov(div, overflowOrDiv0);
        }
        // A is positive, b is negative, return - (a / -b).
        APInt posB = zero.ssub_ov(b, overflowOrDiv0);
        APInt div = a.sdiv_ov(posB, overflowOrDiv0);
        return zero.ssub_ov(div, overflowOrDiv0);
      });

  // A null attribute tells the caller that no fold happened.
  return overflowOrDiv0 ? Attribute() : result;
}
392 
393 //===----------------------------------------------------------------------===//
394 // FloorDivSIOp
395 //===----------------------------------------------------------------------===//
396 
/// Folds signed floor division.
///
/// floordivsi(x, 1) folds to x. Otherwise the operands are constant-folded by
/// case analysis on the signs of the dividend `a` and divisor `b`: same-sign
/// operands use a plain division of the magnitudes, mixed-sign operands negate
/// the ceiling division of the magnitudes. The fold is abandoned (a null
/// attribute is returned) when the divisor is zero or when `overflowOrDiv0` is
/// set after the computation.
OpFoldResult arith::FloorDivSIOp::fold(ArrayRef<Attribute> operands) {
  // floordivsi (x, 1) -> x.
  if (matchPattern(getRhs(), m_One()))
    return getLhs();

  // Don't fold if it would overflow or if it requires a division by zero.
  // The flag is captured by reference and shared across all invocations of
  // the callback; once it is set, later invocations return early.
  bool overflowOrDiv0 = false;
  auto result =
      constFoldBinaryOp<IntegerAttr>(operands, [&](APInt a, const APInt &b) {
        if (overflowOrDiv0 || !b) {
          overflowOrDiv0 = true;
          return a;
        }
        // 0 / b is 0 for any non-zero b.
        if (!a)
          return a;
        // After this point we know that neither a or b are zero.
        unsigned bits = a.getBitWidth();
        APInt zero = APInt::getZero(bits);
        bool aGtZero = a.sgt(zero);
        bool bGtZero = b.sgt(zero);
        if (aGtZero && bGtZero) {
          // Both positive, return a / b.
          return a.sdiv_ov(b, overflowOrDiv0);
        }
        if (!aGtZero && !bGtZero) {
          // Both negative, return -a / -b.
          APInt posA = zero.ssub_ov(a, overflowOrDiv0);
          APInt posB = zero.ssub_ov(b, overflowOrDiv0);
          return posA.sdiv_ov(posB, overflowOrDiv0);
        }
        if (!aGtZero && bGtZero) {
          // A is negative, b is positive, return - ceil(-a, b).
          APInt posA = zero.ssub_ov(a, overflowOrDiv0);
          APInt ceil = signedCeilNonnegInputs(posA, b, overflowOrDiv0);
          return zero.ssub_ov(ceil, overflowOrDiv0);
        }
        // A is positive, b is negative, return - ceil(a, -b).
        APInt posB = zero.ssub_ov(b, overflowOrDiv0);
        APInt ceil = signedCeilNonnegInputs(a, posB, overflowOrDiv0);
        return zero.ssub_ov(ceil, overflowOrDiv0);
      });

  // A null attribute tells the caller that no fold happened.
  return overflowOrDiv0 ? Attribute() : result;
}
441 
442 //===----------------------------------------------------------------------===//
443 // RemUIOp
444 //===----------------------------------------------------------------------===//
445 
446 OpFoldResult arith::RemUIOp::fold(ArrayRef<Attribute> operands) {
447   auto rhs = operands.back().dyn_cast_or_null<IntegerAttr>();
448   if (!rhs)
449     return {};
450   auto rhsValue = rhs.getValue();
451 
452   // x % 1 = 0
453   if (rhsValue.isOneValue())
454     return IntegerAttr::get(rhs.getType(), APInt(rhsValue.getBitWidth(), 0));
455 
456   // Don't fold if it requires division by zero.
457   if (rhsValue.isNullValue())
458     return {};
459 
460   auto lhs = operands.front().dyn_cast_or_null<IntegerAttr>();
461   if (!lhs)
462     return {};
463   return IntegerAttr::get(lhs.getType(), lhs.getValue().urem(rhsValue));
464 }
465 
466 //===----------------------------------------------------------------------===//
467 // RemSIOp
468 //===----------------------------------------------------------------------===//
469 
470 OpFoldResult arith::RemSIOp::fold(ArrayRef<Attribute> operands) {
471   auto rhs = operands.back().dyn_cast_or_null<IntegerAttr>();
472   if (!rhs)
473     return {};
474   auto rhsValue = rhs.getValue();
475 
476   // x % 1 = 0
477   if (rhsValue.isOneValue())
478     return IntegerAttr::get(rhs.getType(), APInt(rhsValue.getBitWidth(), 0));
479 
480   // Don't fold if it requires division by zero.
481   if (rhsValue.isNullValue())
482     return {};
483 
484   auto lhs = operands.front().dyn_cast_or_null<IntegerAttr>();
485   if (!lhs)
486     return {};
487   return IntegerAttr::get(lhs.getType(), lhs.getValue().srem(rhsValue));
488 }
489 
490 //===----------------------------------------------------------------------===//
491 // AndIOp
492 //===----------------------------------------------------------------------===//
493 
494 OpFoldResult arith::AndIOp::fold(ArrayRef<Attribute> operands) {
495   /// and(x, 0) -> 0
496   if (matchPattern(getRhs(), m_Zero()))
497     return getRhs();
498   /// and(x, allOnes) -> x
499   APInt intValue;
500   if (matchPattern(getRhs(), m_ConstantInt(&intValue)) && intValue.isAllOnes())
501     return getLhs();
502 
503   return constFoldBinaryOp<IntegerAttr>(
504       operands, [](APInt a, const APInt &b) { return std::move(a) & b; });
505 }
506 
507 //===----------------------------------------------------------------------===//
508 // OrIOp
509 //===----------------------------------------------------------------------===//
510 
511 OpFoldResult arith::OrIOp::fold(ArrayRef<Attribute> operands) {
512   /// or(x, 0) -> x
513   if (matchPattern(getRhs(), m_Zero()))
514     return getLhs();
515   /// or(x, <all ones>) -> <all ones>
516   if (auto rhsAttr = operands[1].dyn_cast_or_null<IntegerAttr>())
517     if (rhsAttr.getValue().isAllOnes())
518       return rhsAttr;
519 
520   return constFoldBinaryOp<IntegerAttr>(
521       operands, [](APInt a, const APInt &b) { return std::move(a) | b; });
522 }
523 
524 //===----------------------------------------------------------------------===//
525 // XOrIOp
526 //===----------------------------------------------------------------------===//
527 
528 OpFoldResult arith::XOrIOp::fold(ArrayRef<Attribute> operands) {
529   /// xor(x, 0) -> x
530   if (matchPattern(getRhs(), m_Zero()))
531     return getLhs();
532   /// xor(x, x) -> 0
533   if (getLhs() == getRhs())
534     return Builder(getContext()).getZeroAttr(getType());
535   /// xor(xor(x, a), a) -> x
536   if (arith::XOrIOp prev = getLhs().getDefiningOp<arith::XOrIOp>())
537     if (prev.getRhs() == getRhs())
538       return prev.getLhs();
539 
540   return constFoldBinaryOp<IntegerAttr>(
541       operands, [](APInt a, const APInt &b) { return std::move(a) ^ b; });
542 }
543 
544 void arith::XOrIOp::getCanonicalizationPatterns(
545     RewritePatternSet &patterns, MLIRContext *context) {
546   patterns.add<XOrINotCmpI>(context);
547 }
548 
549 //===----------------------------------------------------------------------===//
550 // NegFOp
551 //===----------------------------------------------------------------------===//
552 
553 OpFoldResult arith::NegFOp::fold(ArrayRef<Attribute> operands) {
554   return constFoldUnaryOp<FloatAttr>(operands,
555                                      [](const APFloat &a) { return -a; });
556 }
557 
558 //===----------------------------------------------------------------------===//
559 // AddFOp
560 //===----------------------------------------------------------------------===//
561 
562 OpFoldResult arith::AddFOp::fold(ArrayRef<Attribute> operands) {
563   // addf(x, -0) -> x
564   if (matchPattern(getRhs(), m_NegZeroFloat()))
565     return getLhs();
566 
567   return constFoldBinaryOp<FloatAttr>(
568       operands, [](const APFloat &a, const APFloat &b) { return a + b; });
569 }
570 
571 //===----------------------------------------------------------------------===//
572 // SubFOp
573 //===----------------------------------------------------------------------===//
574 
575 OpFoldResult arith::SubFOp::fold(ArrayRef<Attribute> operands) {
576   // subf(x, +0) -> x
577   if (matchPattern(getRhs(), m_PosZeroFloat()))
578     return getLhs();
579 
580   return constFoldBinaryOp<FloatAttr>(
581       operands, [](const APFloat &a, const APFloat &b) { return a - b; });
582 }
583 
584 //===----------------------------------------------------------------------===//
585 // MaxFOp
586 //===----------------------------------------------------------------------===//
587 
588 OpFoldResult arith::MaxFOp::fold(ArrayRef<Attribute> operands) {
589   assert(operands.size() == 2 && "maxf takes two operands");
590 
591   // maxf(x,x) -> x
592   if (getLhs() == getRhs())
593     return getRhs();
594 
595   // maxf(x, -inf) -> x
596   if (matchPattern(getRhs(), m_NegInfFloat()))
597     return getLhs();
598 
599   return constFoldBinaryOp<FloatAttr>(
600       operands,
601       [](const APFloat &a, const APFloat &b) { return llvm::maximum(a, b); });
602 }
603 
604 //===----------------------------------------------------------------------===//
605 // MaxSIOp
606 //===----------------------------------------------------------------------===//
607 
608 OpFoldResult MaxSIOp::fold(ArrayRef<Attribute> operands) {
609   assert(operands.size() == 2 && "binary operation takes two operands");
610 
611   // maxsi(x,x) -> x
612   if (getLhs() == getRhs())
613     return getRhs();
614 
615   APInt intValue;
616   // maxsi(x,MAX_INT) -> MAX_INT
617   if (matchPattern(getRhs(), m_ConstantInt(&intValue)) &&
618       intValue.isMaxSignedValue())
619     return getRhs();
620 
621   // maxsi(x, MIN_INT) -> x
622   if (matchPattern(getRhs(), m_ConstantInt(&intValue)) &&
623       intValue.isMinSignedValue())
624     return getLhs();
625 
626   return constFoldBinaryOp<IntegerAttr>(operands,
627                                         [](const APInt &a, const APInt &b) {
628                                           return llvm::APIntOps::smax(a, b);
629                                         });
630 }
631 
632 //===----------------------------------------------------------------------===//
633 // MaxUIOp
634 //===----------------------------------------------------------------------===//
635 
636 OpFoldResult MaxUIOp::fold(ArrayRef<Attribute> operands) {
637   assert(operands.size() == 2 && "binary operation takes two operands");
638 
639   // maxui(x,x) -> x
640   if (getLhs() == getRhs())
641     return getRhs();
642 
643   APInt intValue;
644   // maxui(x,MAX_INT) -> MAX_INT
645   if (matchPattern(getRhs(), m_ConstantInt(&intValue)) && intValue.isMaxValue())
646     return getRhs();
647 
648   // maxui(x, MIN_INT) -> x
649   if (matchPattern(getRhs(), m_ConstantInt(&intValue)) && intValue.isMinValue())
650     return getLhs();
651 
652   return constFoldBinaryOp<IntegerAttr>(operands,
653                                         [](const APInt &a, const APInt &b) {
654                                           return llvm::APIntOps::umax(a, b);
655                                         });
656 }
657 
658 //===----------------------------------------------------------------------===//
659 // MinFOp
660 //===----------------------------------------------------------------------===//
661 
662 OpFoldResult arith::MinFOp::fold(ArrayRef<Attribute> operands) {
663   assert(operands.size() == 2 && "minf takes two operands");
664 
665   // minf(x,x) -> x
666   if (getLhs() == getRhs())
667     return getRhs();
668 
669   // minf(x, +inf) -> x
670   if (matchPattern(getRhs(), m_PosInfFloat()))
671     return getLhs();
672 
673   return constFoldBinaryOp<FloatAttr>(
674       operands,
675       [](const APFloat &a, const APFloat &b) { return llvm::minimum(a, b); });
676 }
677 
678 //===----------------------------------------------------------------------===//
679 // MinSIOp
680 //===----------------------------------------------------------------------===//
681 
682 OpFoldResult MinSIOp::fold(ArrayRef<Attribute> operands) {
683   assert(operands.size() == 2 && "binary operation takes two operands");
684 
685   // minsi(x,x) -> x
686   if (getLhs() == getRhs())
687     return getRhs();
688 
689   APInt intValue;
690   // minsi(x,MIN_INT) -> MIN_INT
691   if (matchPattern(getRhs(), m_ConstantInt(&intValue)) &&
692       intValue.isMinSignedValue())
693     return getRhs();
694 
695   // minsi(x, MAX_INT) -> x
696   if (matchPattern(getRhs(), m_ConstantInt(&intValue)) &&
697       intValue.isMaxSignedValue())
698     return getLhs();
699 
700   return constFoldBinaryOp<IntegerAttr>(operands,
701                                         [](const APInt &a, const APInt &b) {
702                                           return llvm::APIntOps::smin(a, b);
703                                         });
704 }
705 
706 //===----------------------------------------------------------------------===//
707 // MinUIOp
708 //===----------------------------------------------------------------------===//
709 
710 OpFoldResult MinUIOp::fold(ArrayRef<Attribute> operands) {
711   assert(operands.size() == 2 && "binary operation takes two operands");
712 
713   // minui(x,x) -> x
714   if (getLhs() == getRhs())
715     return getRhs();
716 
717   APInt intValue;
718   // minui(x,MIN_INT) -> MIN_INT
719   if (matchPattern(getRhs(), m_ConstantInt(&intValue)) && intValue.isMinValue())
720     return getRhs();
721 
722   // minui(x, MAX_INT) -> x
723   if (matchPattern(getRhs(), m_ConstantInt(&intValue)) && intValue.isMaxValue())
724     return getLhs();
725 
726   return constFoldBinaryOp<IntegerAttr>(operands,
727                                         [](const APInt &a, const APInt &b) {
728                                           return llvm::APIntOps::umin(a, b);
729                                         });
730 }
731 
732 //===----------------------------------------------------------------------===//
733 // MulFOp
734 //===----------------------------------------------------------------------===//
735 
736 OpFoldResult arith::MulFOp::fold(ArrayRef<Attribute> operands) {
737   // mulf(x, 1) -> x
738   if (matchPattern(getRhs(), m_OneFloat()))
739     return getLhs();
740 
741   return constFoldBinaryOp<FloatAttr>(
742       operands, [](const APFloat &a, const APFloat &b) { return a * b; });
743 }
744 
745 //===----------------------------------------------------------------------===//
746 // DivFOp
747 //===----------------------------------------------------------------------===//
748 
749 OpFoldResult arith::DivFOp::fold(ArrayRef<Attribute> operands) {
750   // divf(x, 1) -> x
751   if (matchPattern(getRhs(), m_OneFloat()))
752     return getLhs();
753 
754   return constFoldBinaryOp<FloatAttr>(
755       operands, [](const APFloat &a, const APFloat &b) { return a / b; });
756 }
757 
758 //===----------------------------------------------------------------------===//
759 // Utility functions for verifying cast ops
760 //===----------------------------------------------------------------------===//
761 
/// Tag used to pass a pack of types as a function argument so that it can be
/// deduced. It is a pointer to a tuple of the types, so `type_list<Ts...>()`
/// is just a null pointer and no tuple object is ever constructed.
template <typename... Types>
using type_list = std::tuple<Types...> *;
764 
765 /// Returns a non-null type only if the provided type is one of the allowed
766 /// types or one of the allowed shaped types of the allowed types. Returns the
767 /// element type if a valid shaped type is provided.
768 template <typename... ShapedTypes, typename... ElementTypes>
769 static Type getUnderlyingType(Type type, type_list<ShapedTypes...>,
770                               type_list<ElementTypes...>) {
771   if (type.isa<ShapedType>() && !type.isa<ShapedTypes...>())
772     return {};
773 
774   auto underlyingType = getElementTypeOrSelf(type);
775   if (!underlyingType.isa<ElementTypes...>())
776     return {};
777 
778   return underlyingType;
779 }
780 
781 /// Get allowed underlying types for vectors and tensors.
782 template <typename... ElementTypes>
783 static Type getTypeIfLike(Type type) {
784   return getUnderlyingType(type, type_list<VectorType, TensorType>(),
785                            type_list<ElementTypes...>());
786 }
787 
788 /// Get allowed underlying types for vectors, tensors, and memrefs.
789 template <typename... ElementTypes>
790 static Type getTypeIfLikeOrMemRef(Type type) {
791   return getUnderlyingType(type,
792                            type_list<VectorType, TensorType, MemRefType>(),
793                            type_list<ElementTypes...>());
794 }
795 
796 static bool areValidCastInputsAndOutputs(TypeRange inputs, TypeRange outputs) {
797   return inputs.size() == 1 && outputs.size() == 1 &&
798          succeeded(verifyCompatibleShapes(inputs.front(), outputs.front()));
799 }
800 
801 //===----------------------------------------------------------------------===//
802 // Verifiers for integer and floating point extension/truncation ops
803 //===----------------------------------------------------------------------===//
804 
805 // Extend ops can only extend to a wider type.
806 template <typename ValType, typename Op>
807 static LogicalResult verifyExtOp(Op op) {
808   Type srcType = getElementTypeOrSelf(op.getIn().getType());
809   Type dstType = getElementTypeOrSelf(op.getType());
810 
811   if (srcType.cast<ValType>().getWidth() >= dstType.cast<ValType>().getWidth())
812     return op.emitError("result type ")
813            << dstType << " must be wider than operand type " << srcType;
814 
815   return success();
816 }
817 
818 // Truncate ops can only truncate to a shorter type.
819 template <typename ValType, typename Op>
820 static LogicalResult verifyTruncateOp(Op op) {
821   Type srcType = getElementTypeOrSelf(op.getIn().getType());
822   Type dstType = getElementTypeOrSelf(op.getType());
823 
824   if (srcType.cast<ValType>().getWidth() <= dstType.cast<ValType>().getWidth())
825     return op.emitError("result type ")
826            << dstType << " must be shorter than operand type " << srcType;
827 
828   return success();
829 }
830 
831 /// Validate a cast that changes the width of a type.
832 template <template <typename> class WidthComparator, typename... ElementTypes>
833 static bool checkWidthChangeCast(TypeRange inputs, TypeRange outputs) {
834   if (!areValidCastInputsAndOutputs(inputs, outputs))
835     return false;
836 
837   auto srcType = getTypeIfLike<ElementTypes...>(inputs.front());
838   auto dstType = getTypeIfLike<ElementTypes...>(outputs.front());
839   if (!srcType || !dstType)
840     return false;
841 
842   return WidthComparator<unsigned>()(dstType.getIntOrFloatBitWidth(),
843                                      srcType.getIntOrFloatBitWidth());
844 }
845 
846 //===----------------------------------------------------------------------===//
847 // ExtUIOp
848 //===----------------------------------------------------------------------===//
849 
850 OpFoldResult arith::ExtUIOp::fold(ArrayRef<Attribute> operands) {
851   if (auto lhs = getIn().getDefiningOp<ExtUIOp>()) {
852     getInMutable().assign(lhs.getIn());
853     return getResult();
854   }
855   Type resType = getType();
856   unsigned bitWidth;
857   if (auto shapedType = resType.dyn_cast<ShapedType>())
858     bitWidth = shapedType.getElementTypeBitWidth();
859   else
860     bitWidth = resType.getIntOrFloatBitWidth();
861   return constFoldCastOp<IntegerAttr, IntegerAttr>(
862       operands, getType(), [bitWidth](const APInt &a, bool &castStatus) {
863         return a.zext(bitWidth);
864       });
865 }
866 
/// Returns true when `inputs` -> `outputs` is a valid extui signature.
/// checkWidthChangeCast is instantiated with std::greater over IntegerType
/// elements, so the result element must be strictly wider than the operand's.
bool arith::ExtUIOp::areCastCompatible(TypeRange inputs, TypeRange outputs) {
  return checkWidthChangeCast<std::greater, IntegerType>(inputs, outputs);
}
870 
/// Verifies that the integer result element type is wider than the operand's;
/// emits an error on the op otherwise (see verifyExtOp).
LogicalResult arith::ExtUIOp::verify() {
  return verifyExtOp<IntegerType>(*this);
}
874 
875 //===----------------------------------------------------------------------===//
876 // ExtSIOp
877 //===----------------------------------------------------------------------===//
878 
879 OpFoldResult arith::ExtSIOp::fold(ArrayRef<Attribute> operands) {
880   if (auto lhs = getIn().getDefiningOp<ExtSIOp>()) {
881     getInMutable().assign(lhs.getIn());
882     return getResult();
883   }
884   Type resType = getType();
885   unsigned bitWidth;
886   if (auto shapedType = resType.dyn_cast<ShapedType>())
887     bitWidth = shapedType.getElementTypeBitWidth();
888   else
889     bitWidth = resType.getIntOrFloatBitWidth();
890   return constFoldCastOp<IntegerAttr, IntegerAttr>(
891       operands, getType(), [bitWidth](const APInt &a, bool &castStatus) {
892         return a.sext(bitWidth);
893       });
894 }
895 
/// Returns true when `inputs` -> `outputs` is a valid extsi signature.
/// checkWidthChangeCast is instantiated with std::greater over IntegerType
/// elements, so the result element must be strictly wider than the operand's.
bool arith::ExtSIOp::areCastCompatible(TypeRange inputs, TypeRange outputs) {
  return checkWidthChangeCast<std::greater, IntegerType>(inputs, outputs);
}
899 
/// Registers the ExtSIOfExtUI canonicalization pattern for extsi.
void arith::ExtSIOp::getCanonicalizationPatterns(
    RewritePatternSet &patterns, MLIRContext *context) {
  patterns.add<ExtSIOfExtUI>(context);
}
904 
/// Verifies that the integer result element type is wider than the operand's;
/// emits an error on the op otherwise (see verifyExtOp).
LogicalResult arith::ExtSIOp::verify() {
  return verifyExtOp<IntegerType>(*this);
}
908 
909 //===----------------------------------------------------------------------===//
910 // ExtFOp
911 //===----------------------------------------------------------------------===//
912 
/// Returns true when `inputs` -> `outputs` is a valid extf signature.
/// checkWidthChangeCast is instantiated with std::greater over FloatType
/// elements, so the result element must be strictly wider than the operand's.
bool arith::ExtFOp::areCastCompatible(TypeRange inputs, TypeRange outputs) {
  return checkWidthChangeCast<std::greater, FloatType>(inputs, outputs);
}
916 
/// Verifies that the float result element type is wider than the operand's;
/// emits an error on the op otherwise (see verifyExtOp).
LogicalResult arith::ExtFOp::verify() { return verifyExtOp<FloatType>(*this); }
918 
919 //===----------------------------------------------------------------------===//
920 // TruncIOp
921 //===----------------------------------------------------------------------===//
922 
923 OpFoldResult arith::TruncIOp::fold(ArrayRef<Attribute> operands) {
924   assert(operands.size() == 1 && "unary operation takes one operand");
925 
926   // trunci(zexti(a)) -> a
927   // trunci(sexti(a)) -> a
928   if (matchPattern(getOperand(), m_Op<arith::ExtUIOp>()) ||
929       matchPattern(getOperand(), m_Op<arith::ExtSIOp>()))
930     return getOperand().getDefiningOp()->getOperand(0);
931 
932   // trunci(trunci(a)) -> trunci(a))
933   if (matchPattern(getOperand(), m_Op<arith::TruncIOp>())) {
934     setOperand(getOperand().getDefiningOp()->getOperand(0));
935     return getResult();
936   }
937 
938   Type resType = getType();
939   unsigned bitWidth;
940   if (auto shapedType = resType.dyn_cast<ShapedType>())
941     bitWidth = shapedType.getElementTypeBitWidth();
942   else
943     bitWidth = resType.getIntOrFloatBitWidth();
944 
945   return constFoldCastOp<IntegerAttr, IntegerAttr>(
946       operands, getType(), [bitWidth](const APInt &a, bool &castStatus) {
947         return a.trunc(bitWidth);
948       });
949 }
950 
/// Returns true when `inputs` -> `outputs` is a valid trunci signature.
/// checkWidthChangeCast is instantiated with std::less over IntegerType
/// elements, so the result element must be strictly narrower than the
/// operand's.
bool arith::TruncIOp::areCastCompatible(TypeRange inputs, TypeRange outputs) {
  return checkWidthChangeCast<std::less, IntegerType>(inputs, outputs);
}
954 
/// Verifies that the integer result element type is narrower than the
/// operand's; emits an error on the op otherwise (see verifyTruncateOp).
LogicalResult arith::TruncIOp::verify() {
  return verifyTruncateOp<IntegerType>(*this);
}
958 
959 //===----------------------------------------------------------------------===//
960 // TruncFOp
961 //===----------------------------------------------------------------------===//
962 
963 /// Perform safe const propagation for truncf, i.e. only propagate if FP value
964 /// can be represented without precision loss or rounding.
965 OpFoldResult arith::TruncFOp::fold(ArrayRef<Attribute> operands) {
966   assert(operands.size() == 1 && "unary operation takes one operand");
967 
968   auto constOperand = operands.front();
969   if (!constOperand || !constOperand.isa<FloatAttr>())
970     return {};
971 
972   // Convert to target type via 'double'.
973   double sourceValue =
974       constOperand.dyn_cast<FloatAttr>().getValue().convertToDouble();
975   auto targetAttr = FloatAttr::get(getType(), sourceValue);
976 
977   // Propagate if constant's value does not change after truncation.
978   if (sourceValue == targetAttr.getValue().convertToDouble())
979     return targetAttr;
980 
981   return {};
982 }
983 
/// Returns true when `inputs` -> `outputs` is a valid truncf signature.
/// checkWidthChangeCast is instantiated with std::less over FloatType
/// elements, so the result element must be strictly narrower than the
/// operand's.
bool arith::TruncFOp::areCastCompatible(TypeRange inputs, TypeRange outputs) {
  return checkWidthChangeCast<std::less, FloatType>(inputs, outputs);
}
987 
/// Verifies that the float result element type is narrower than the
/// operand's; emits an error on the op otherwise (see verifyTruncateOp).
LogicalResult arith::TruncFOp::verify() {
  return verifyTruncateOp<FloatType>(*this);
}
991 
992 //===----------------------------------------------------------------------===//
993 // AndIOp
994 //===----------------------------------------------------------------------===//
995 
/// Registers the AndOfExtUI and AndOfExtSI canonicalization patterns for andi.
void arith::AndIOp::getCanonicalizationPatterns(
    RewritePatternSet &patterns, MLIRContext *context) {
  patterns.add<AndOfExtUI, AndOfExtSI>(context);
}
1000 
1001 //===----------------------------------------------------------------------===//
1002 // OrIOp
1003 //===----------------------------------------------------------------------===//
1004 
/// Registers the OrOfExtUI and OrOfExtSI canonicalization patterns for ori.
void arith::OrIOp::getCanonicalizationPatterns(
    RewritePatternSet &patterns, MLIRContext *context) {
  patterns.add<OrOfExtUI, OrOfExtSI>(context);
}
1009 
1010 //===----------------------------------------------------------------------===//
1011 // Verifiers for casts between integers and floats.
1012 //===----------------------------------------------------------------------===//
1013 
1014 template <typename From, typename To>
1015 static bool checkIntFloatCast(TypeRange inputs, TypeRange outputs) {
1016   if (!areValidCastInputsAndOutputs(inputs, outputs))
1017     return false;
1018 
1019   auto srcType = getTypeIfLike<From>(inputs.front());
1020   auto dstType = getTypeIfLike<To>(outputs.back());
1021 
1022   return srcType && dstType;
1023 }
1024 
1025 //===----------------------------------------------------------------------===//
1026 // UIToFPOp
1027 //===----------------------------------------------------------------------===//
1028 
/// Returns true when the single input has an integer element type and the
/// single output has a float element type (see checkIntFloatCast).
bool arith::UIToFPOp::areCastCompatible(TypeRange inputs, TypeRange outputs) {
  return checkIntFloatCast<IntegerType, FloatType>(inputs, outputs);
}
1032 
1033 OpFoldResult arith::UIToFPOp::fold(ArrayRef<Attribute> operands) {
1034   Type resType = getType();
1035   Type resEleType;
1036   if (auto shapedType = resType.dyn_cast<ShapedType>())
1037     resEleType = shapedType.getElementType();
1038   else
1039     resEleType = resType;
1040   return constFoldCastOp<IntegerAttr, FloatAttr>(
1041       operands, getType(), [&resEleType](const APInt &a, bool &castStatus) {
1042         FloatType floatTy = resEleType.cast<FloatType>();
1043         APFloat apf(floatTy.getFloatSemantics(),
1044                     APInt::getZero(floatTy.getWidth()));
1045         apf.convertFromAPInt(a, /*IsSigned=*/false,
1046                              APFloat::rmNearestTiesToEven);
1047         return apf;
1048       });
1049 }
1050 
1051 //===----------------------------------------------------------------------===//
1052 // SIToFPOp
1053 //===----------------------------------------------------------------------===//
1054 
/// Returns true when the single input has an integer element type and the
/// single output has a float element type (see checkIntFloatCast).
bool arith::SIToFPOp::areCastCompatible(TypeRange inputs, TypeRange outputs) {
  return checkIntFloatCast<IntegerType, FloatType>(inputs, outputs);
}
1058 
1059 OpFoldResult arith::SIToFPOp::fold(ArrayRef<Attribute> operands) {
1060   Type resType = getType();
1061   Type resEleType;
1062   if (auto shapedType = resType.dyn_cast<ShapedType>())
1063     resEleType = shapedType.getElementType();
1064   else
1065     resEleType = resType;
1066   return constFoldCastOp<IntegerAttr, FloatAttr>(
1067       operands, getType(), [&resEleType](const APInt &a, bool &castStatus) {
1068         FloatType floatTy = resEleType.cast<FloatType>();
1069         APFloat apf(floatTy.getFloatSemantics(),
1070                     APInt::getZero(floatTy.getWidth()));
1071         apf.convertFromAPInt(a, /*IsSigned=*/true,
1072                              APFloat::rmNearestTiesToEven);
1073         return apf;
1074       });
1075 }
1076 //===----------------------------------------------------------------------===//
1077 // FPToUIOp
1078 //===----------------------------------------------------------------------===//
1079 
/// Returns true when the single input has a float element type and the single
/// output has an integer element type (see checkIntFloatCast).
bool arith::FPToUIOp::areCastCompatible(TypeRange inputs, TypeRange outputs) {
  return checkIntFloatCast<FloatType, IntegerType>(inputs, outputs);
}
1083 
1084 OpFoldResult arith::FPToUIOp::fold(ArrayRef<Attribute> operands) {
1085   Type resType = getType();
1086   Type resEleType;
1087   if (auto shapedType = resType.dyn_cast<ShapedType>())
1088     resEleType = shapedType.getElementType();
1089   else
1090     resEleType = resType;
1091   return constFoldCastOp<FloatAttr, IntegerAttr>(
1092       operands, getType(), [&resEleType](const APFloat &a, bool &castStatus) {
1093         IntegerType intTy = resEleType.cast<IntegerType>();
1094         bool ignored;
1095         APSInt api(intTy.getWidth(), /*isUnsigned=*/true);
1096         castStatus = APFloat::opInvalidOp !=
1097                      a.convertToInteger(api, APFloat::rmTowardZero, &ignored);
1098         return api;
1099       });
1100 }
1101 
1102 //===----------------------------------------------------------------------===//
1103 // FPToSIOp
1104 //===----------------------------------------------------------------------===//
1105 
/// Returns true when the single input has a float element type and the single
/// output has an integer element type (see checkIntFloatCast).
bool arith::FPToSIOp::areCastCompatible(TypeRange inputs, TypeRange outputs) {
  return checkIntFloatCast<FloatType, IntegerType>(inputs, outputs);
}
1109 
1110 OpFoldResult arith::FPToSIOp::fold(ArrayRef<Attribute> operands) {
1111   Type resType = getType();
1112   Type resEleType;
1113   if (auto shapedType = resType.dyn_cast<ShapedType>())
1114     resEleType = shapedType.getElementType();
1115   else
1116     resEleType = resType;
1117   return constFoldCastOp<FloatAttr, IntegerAttr>(
1118       operands, getType(), [&resEleType](const APFloat &a, bool &castStatus) {
1119         IntegerType intTy = resEleType.cast<IntegerType>();
1120         bool ignored;
1121         APSInt api(intTy.getWidth(), /*isUnsigned=*/false);
1122         castStatus = APFloat::opInvalidOp !=
1123                      a.convertToInteger(api, APFloat::rmTowardZero, &ignored);
1124         return api;
1125       });
1126 }
1127 
1128 //===----------------------------------------------------------------------===//
1129 // IndexCastOp
1130 //===----------------------------------------------------------------------===//
1131 
1132 bool arith::IndexCastOp::areCastCompatible(TypeRange inputs,
1133                                            TypeRange outputs) {
1134   if (!areValidCastInputsAndOutputs(inputs, outputs))
1135     return false;
1136 
1137   auto srcType = getTypeIfLikeOrMemRef<IntegerType, IndexType>(inputs.front());
1138   auto dstType = getTypeIfLikeOrMemRef<IntegerType, IndexType>(outputs.front());
1139   if (!srcType || !dstType)
1140     return false;
1141 
1142   return (srcType.isIndex() && dstType.isSignlessInteger()) ||
1143          (srcType.isSignlessInteger() && dstType.isIndex());
1144 }
1145 
1146 OpFoldResult arith::IndexCastOp::fold(ArrayRef<Attribute> operands) {
1147   // index_cast(constant) -> constant
1148   // A little hack because we go through int. Otherwise, the size of the
1149   // constant might need to change.
1150   if (auto value = operands[0].dyn_cast_or_null<IntegerAttr>())
1151     return IntegerAttr::get(getType(), value.getInt());
1152 
1153   return {};
1154 }
1155 
/// Registers the IndexCastOfIndexCast and IndexCastOfExtSI canonicalization
/// patterns for index_cast.
void arith::IndexCastOp::getCanonicalizationPatterns(
    RewritePatternSet &patterns, MLIRContext *context) {
  patterns.add<IndexCastOfIndexCast, IndexCastOfExtSI>(context);
}
1160 
1161 //===----------------------------------------------------------------------===//
1162 // BitcastOp
1163 //===----------------------------------------------------------------------===//
1164 
1165 bool arith::BitcastOp::areCastCompatible(TypeRange inputs, TypeRange outputs) {
1166   if (!areValidCastInputsAndOutputs(inputs, outputs))
1167     return false;
1168 
1169   auto srcType =
1170       getTypeIfLikeOrMemRef<IntegerType, IndexType, FloatType>(inputs.front());
1171   auto dstType =
1172       getTypeIfLikeOrMemRef<IntegerType, IndexType, FloatType>(outputs.front());
1173   if (!srcType || !dstType)
1174     return false;
1175 
1176   return srcType.getIntOrFloatBitWidth() == dstType.getIntOrFloatBitWidth();
1177 }
1178 
1179 OpFoldResult arith::BitcastOp::fold(ArrayRef<Attribute> operands) {
1180   assert(operands.size() == 1 && "bitcast op expects 1 operand");
1181 
1182   auto resType = getType();
1183   auto operand = operands[0];
1184   if (!operand)
1185     return {};
1186 
1187   /// Bitcast dense elements.
1188   if (auto denseAttr = operand.dyn_cast_or_null<DenseElementsAttr>())
1189     return denseAttr.bitcast(resType.cast<ShapedType>().getElementType());
1190   /// Other shaped types unhandled.
1191   if (resType.isa<ShapedType>())
1192     return {};
1193 
1194   /// Bitcast integer or float to integer or float.
1195   APInt bits = operand.isa<FloatAttr>()
1196                    ? operand.cast<FloatAttr>().getValue().bitcastToAPInt()
1197                    : operand.cast<IntegerAttr>().getValue();
1198 
1199   if (auto resFloatType = resType.dyn_cast<FloatType>())
1200     return FloatAttr::get(resType,
1201                           APFloat(resFloatType.getFloatSemantics(), bits));
1202   return IntegerAttr::get(resType, bits);
1203 }
1204 
/// Registers the BitcastOfBitcast canonicalization pattern for bitcast.
void arith::BitcastOp::getCanonicalizationPatterns(
    RewritePatternSet &patterns, MLIRContext *context) {
  patterns.add<BitcastOfBitcast>(context);
}
1209 
1210 //===----------------------------------------------------------------------===//
1211 // Helpers for compare ops
1212 //===----------------------------------------------------------------------===//
1213 
1214 /// Return the type of the same shape (scalar, vector or tensor) containing i1.
1215 static Type getI1SameShape(Type type) {
1216   auto i1Type = IntegerType::get(type.getContext(), 1);
1217   if (auto tensorType = type.dyn_cast<RankedTensorType>())
1218     return RankedTensorType::get(tensorType.getShape(), i1Type);
1219   if (type.isa<UnrankedTensorType>())
1220     return UnrankedTensorType::get(i1Type);
1221   if (auto vectorType = type.dyn_cast<VectorType>())
1222     return VectorType::get(vectorType.getShape(), i1Type,
1223                            vectorType.getNumScalableDims());
1224   return i1Type;
1225 }
1226 
1227 //===----------------------------------------------------------------------===//
1228 // CmpIOp
1229 //===----------------------------------------------------------------------===//
1230 
1231 /// Compute `lhs` `pred` `rhs`, where `pred` is one of the known integer
1232 /// comparison predicates.
1233 bool mlir::arith::applyCmpPredicate(arith::CmpIPredicate predicate,
1234                                     const APInt &lhs, const APInt &rhs) {
1235   switch (predicate) {
1236   case arith::CmpIPredicate::eq:
1237     return lhs.eq(rhs);
1238   case arith::CmpIPredicate::ne:
1239     return lhs.ne(rhs);
1240   case arith::CmpIPredicate::slt:
1241     return lhs.slt(rhs);
1242   case arith::CmpIPredicate::sle:
1243     return lhs.sle(rhs);
1244   case arith::CmpIPredicate::sgt:
1245     return lhs.sgt(rhs);
1246   case arith::CmpIPredicate::sge:
1247     return lhs.sge(rhs);
1248   case arith::CmpIPredicate::ult:
1249     return lhs.ult(rhs);
1250   case arith::CmpIPredicate::ule:
1251     return lhs.ule(rhs);
1252   case arith::CmpIPredicate::ugt:
1253     return lhs.ugt(rhs);
1254   case arith::CmpIPredicate::uge:
1255     return lhs.uge(rhs);
1256   }
1257   llvm_unreachable("unknown cmpi predicate kind");
1258 }
1259 
1260 /// Returns true if the predicate is true for two equal operands.
1261 static bool applyCmpPredicateToEqualOperands(arith::CmpIPredicate predicate) {
1262   switch (predicate) {
1263   case arith::CmpIPredicate::eq:
1264   case arith::CmpIPredicate::sle:
1265   case arith::CmpIPredicate::sge:
1266   case arith::CmpIPredicate::ule:
1267   case arith::CmpIPredicate::uge:
1268     return true;
1269   case arith::CmpIPredicate::ne:
1270   case arith::CmpIPredicate::slt:
1271   case arith::CmpIPredicate::sgt:
1272   case arith::CmpIPredicate::ult:
1273   case arith::CmpIPredicate::ugt:
1274     return false;
1275   }
1276   llvm_unreachable("unknown cmpi predicate kind");
1277 }
1278 
1279 static Attribute getBoolAttribute(Type type, MLIRContext *ctx, bool value) {
1280   auto boolAttr = BoolAttr::get(ctx, value);
1281   ShapedType shapedType = type.dyn_cast_or_null<ShapedType>();
1282   if (!shapedType)
1283     return boolAttr;
1284   return DenseElementsAttr::get(shapedType, boolAttr);
1285 }
1286 
1287 OpFoldResult arith::CmpIOp::fold(ArrayRef<Attribute> operands) {
1288   assert(operands.size() == 2 && "cmpi takes two operands");
1289 
1290   // cmpi(pred, x, x)
1291   if (getLhs() == getRhs()) {
1292     auto val = applyCmpPredicateToEqualOperands(getPredicate());
1293     return getBoolAttribute(getType(), getContext(), val);
1294   }
1295 
1296   if (matchPattern(getRhs(), m_Zero())) {
1297     if (auto extOp = getLhs().getDefiningOp<ExtSIOp>()) {
1298       if (extOp.getOperand().getType().cast<IntegerType>().getWidth() == 1) {
1299         // extsi(%x : i1 -> iN) != 0  ->  %x
1300         if (getPredicate() == arith::CmpIPredicate::ne) {
1301           return extOp.getOperand();
1302         }
1303       }
1304     }
1305     if (auto extOp = getLhs().getDefiningOp<ExtUIOp>()) {
1306       if (extOp.getOperand().getType().cast<IntegerType>().getWidth() == 1) {
1307         // extui(%x : i1 -> iN) != 0  ->  %x
1308         if (getPredicate() == arith::CmpIPredicate::ne) {
1309           return extOp.getOperand();
1310         }
1311       }
1312     }
1313   }
1314 
1315   auto lhs = operands.front().dyn_cast_or_null<IntegerAttr>();
1316   auto rhs = operands.back().dyn_cast_or_null<IntegerAttr>();
1317   if (!lhs || !rhs)
1318     return {};
1319 
1320   auto val = applyCmpPredicate(getPredicate(), lhs.getValue(), rhs.getValue());
1321   return BoolAttr::get(getContext(), val);
1322 }
1323 
1324 void arith::CmpIOp::getCanonicalizationPatterns(RewritePatternSet &patterns,
1325                                                 MLIRContext *context) {
1326   patterns.insert<CmpIExtSI, CmpIExtUI>(context);
1327 }
1328 
1329 //===----------------------------------------------------------------------===//
1330 // CmpFOp
1331 //===----------------------------------------------------------------------===//
1332 
1333 /// Compute `lhs` `pred` `rhs`, where `pred` is one of the known floating point
1334 /// comparison predicates.
1335 bool mlir::arith::applyCmpPredicate(arith::CmpFPredicate predicate,
1336                                     const APFloat &lhs, const APFloat &rhs) {
1337   auto cmpResult = lhs.compare(rhs);
1338   switch (predicate) {
1339   case arith::CmpFPredicate::AlwaysFalse:
1340     return false;
1341   case arith::CmpFPredicate::OEQ:
1342     return cmpResult == APFloat::cmpEqual;
1343   case arith::CmpFPredicate::OGT:
1344     return cmpResult == APFloat::cmpGreaterThan;
1345   case arith::CmpFPredicate::OGE:
1346     return cmpResult == APFloat::cmpGreaterThan ||
1347            cmpResult == APFloat::cmpEqual;
1348   case arith::CmpFPredicate::OLT:
1349     return cmpResult == APFloat::cmpLessThan;
1350   case arith::CmpFPredicate::OLE:
1351     return cmpResult == APFloat::cmpLessThan || cmpResult == APFloat::cmpEqual;
1352   case arith::CmpFPredicate::ONE:
1353     return cmpResult != APFloat::cmpUnordered && cmpResult != APFloat::cmpEqual;
1354   case arith::CmpFPredicate::ORD:
1355     return cmpResult != APFloat::cmpUnordered;
1356   case arith::CmpFPredicate::UEQ:
1357     return cmpResult == APFloat::cmpUnordered || cmpResult == APFloat::cmpEqual;
1358   case arith::CmpFPredicate::UGT:
1359     return cmpResult == APFloat::cmpUnordered ||
1360            cmpResult == APFloat::cmpGreaterThan;
1361   case arith::CmpFPredicate::UGE:
1362     return cmpResult == APFloat::cmpUnordered ||
1363            cmpResult == APFloat::cmpGreaterThan ||
1364            cmpResult == APFloat::cmpEqual;
1365   case arith::CmpFPredicate::ULT:
1366     return cmpResult == APFloat::cmpUnordered ||
1367            cmpResult == APFloat::cmpLessThan;
1368   case arith::CmpFPredicate::ULE:
1369     return cmpResult == APFloat::cmpUnordered ||
1370            cmpResult == APFloat::cmpLessThan || cmpResult == APFloat::cmpEqual;
1371   case arith::CmpFPredicate::UNE:
1372     return cmpResult != APFloat::cmpEqual;
1373   case arith::CmpFPredicate::UNO:
1374     return cmpResult == APFloat::cmpUnordered;
1375   case arith::CmpFPredicate::AlwaysTrue:
1376     return true;
1377   }
1378   llvm_unreachable("unknown cmpf predicate kind");
1379 }
1380 
1381 OpFoldResult arith::CmpFOp::fold(ArrayRef<Attribute> operands) {
1382   assert(operands.size() == 2 && "cmpf takes two operands");
1383 
1384   auto lhs = operands.front().dyn_cast_or_null<FloatAttr>();
1385   auto rhs = operands.back().dyn_cast_or_null<FloatAttr>();
1386 
1387   // If one operand is NaN, making them both NaN does not change the result.
1388   if (lhs && lhs.getValue().isNaN())
1389     rhs = lhs;
1390   if (rhs && rhs.getValue().isNaN())
1391     lhs = rhs;
1392 
1393   if (!lhs || !rhs)
1394     return {};
1395 
1396   auto val = applyCmpPredicate(getPredicate(), lhs.getValue(), rhs.getValue());
1397   return BoolAttr::get(getContext(), val);
1398 }
1399 
1400 class CmpFIntToFPConst final : public OpRewritePattern<CmpFOp> {
1401 public:
1402   using OpRewritePattern<CmpFOp>::OpRewritePattern;
1403 
1404   static CmpIPredicate convertToIntegerPredicate(CmpFPredicate pred,
1405                                                  bool isUnsigned) {
1406     using namespace arith;
1407     switch (pred) {
1408     case CmpFPredicate::UEQ:
1409     case CmpFPredicate::OEQ:
1410       return CmpIPredicate::eq;
1411     case CmpFPredicate::UGT:
1412     case CmpFPredicate::OGT:
1413       return isUnsigned ? CmpIPredicate::ugt : CmpIPredicate::sgt;
1414     case CmpFPredicate::UGE:
1415     case CmpFPredicate::OGE:
1416       return isUnsigned ? CmpIPredicate::uge : CmpIPredicate::sge;
1417     case CmpFPredicate::ULT:
1418     case CmpFPredicate::OLT:
1419       return isUnsigned ? CmpIPredicate::ult : CmpIPredicate::slt;
1420     case CmpFPredicate::ULE:
1421     case CmpFPredicate::OLE:
1422       return isUnsigned ? CmpIPredicate::ule : CmpIPredicate::sle;
1423     case CmpFPredicate::UNE:
1424     case CmpFPredicate::ONE:
1425       return CmpIPredicate::ne;
1426     default:
1427       llvm_unreachable("Unexpected predicate!");
1428     }
1429   }
1430 
1431   LogicalResult matchAndRewrite(CmpFOp op,
1432                                 PatternRewriter &rewriter) const override {
1433     FloatAttr flt;
1434     if (!matchPattern(op.getRhs(), m_Constant(&flt)))
1435       return failure();
1436 
1437     const APFloat &rhs = flt.getValue();
1438 
1439     // Don't attempt to fold a nan.
1440     if (rhs.isNaN())
1441       return failure();
1442 
1443     // Get the width of the mantissa.  We don't want to hack on conversions that
1444     // might lose information from the integer, e.g. "i64 -> float"
1445     FloatType floatTy = op.getRhs().getType().cast<FloatType>();
1446     int mantissaWidth = floatTy.getFPMantissaWidth();
1447     if (mantissaWidth <= 0)
1448       return failure();
1449 
1450     bool isUnsigned;
1451     Value intVal;
1452 
1453     if (auto si = op.getLhs().getDefiningOp<SIToFPOp>()) {
1454       isUnsigned = false;
1455       intVal = si.getIn();
1456     } else if (auto ui = op.getLhs().getDefiningOp<UIToFPOp>()) {
1457       isUnsigned = true;
1458       intVal = ui.getIn();
1459     } else {
1460       return failure();
1461     }
1462 
1463     // Check to see that the input is converted from an integer type that is
1464     // small enough that preserves all bits.
1465     auto intTy = intVal.getType().cast<IntegerType>();
1466     auto intWidth = intTy.getWidth();
1467 
1468     // Number of bits representing values, as opposed to the sign
1469     auto valueBits = isUnsigned ? intWidth : (intWidth - 1);
1470 
1471     // Following test does NOT adjust intWidth downwards for signed inputs,
1472     // because the most negative value still requires all the mantissa bits
1473     // to distinguish it from one less than that value.
1474     if ((int)intWidth > mantissaWidth) {
1475       // Conversion would lose accuracy. Check if loss can impact comparison.
1476       int exponent = ilogb(rhs);
1477       if (exponent == APFloat::IEK_Inf) {
1478         int maxExponent = ilogb(APFloat::getLargest(rhs.getSemantics()));
1479         if (maxExponent < (int)valueBits) {
1480           // Conversion could create infinity.
1481           return failure();
1482         }
1483       } else {
1484         // Note that if rhs is zero or NaN, then Exp is negative
1485         // and first condition is trivially false.
1486         if (mantissaWidth <= exponent && exponent <= (int)valueBits) {
1487           // Conversion could affect comparison.
1488           return failure();
1489         }
1490       }
1491     }
1492 
1493     // Convert to equivalent cmpi predicate
1494     CmpIPredicate pred;
1495     switch (op.getPredicate()) {
1496     case CmpFPredicate::ORD:
1497       // Int to fp conversion doesn't create a nan (ord checks neither is a nan)
1498       rewriter.replaceOpWithNewOp<ConstantIntOp>(op, /*value=*/true,
1499                                                  /*width=*/1);
1500       return success();
1501     case CmpFPredicate::UNO:
1502       // Int to fp conversion doesn't create a nan (uno checks either is a nan)
1503       rewriter.replaceOpWithNewOp<ConstantIntOp>(op, /*value=*/false,
1504                                                  /*width=*/1);
1505       return success();
1506     default:
1507       pred = convertToIntegerPredicate(op.getPredicate(), isUnsigned);
1508       break;
1509     }
1510 
1511     if (!isUnsigned) {
1512       // If the rhs value is > SignedMax, fold the comparison.  This handles
1513       // +INF and large values.
1514       APFloat signedMax(rhs.getSemantics());
1515       signedMax.convertFromAPInt(APInt::getSignedMaxValue(intWidth), true,
1516                                  APFloat::rmNearestTiesToEven);
1517       if (signedMax < rhs) { // smax < 13123.0
1518         if (pred == CmpIPredicate::ne || pred == CmpIPredicate::slt ||
1519             pred == CmpIPredicate::sle)
1520           rewriter.replaceOpWithNewOp<ConstantIntOp>(op, /*value=*/true,
1521                                                      /*width=*/1);
1522         else
1523           rewriter.replaceOpWithNewOp<ConstantIntOp>(op, /*value=*/false,
1524                                                      /*width=*/1);
1525         return success();
1526       }
1527     } else {
1528       // If the rhs value is > UnsignedMax, fold the comparison. This handles
1529       // +INF and large values.
1530       APFloat unsignedMax(rhs.getSemantics());
1531       unsignedMax.convertFromAPInt(APInt::getMaxValue(intWidth), false,
1532                                    APFloat::rmNearestTiesToEven);
1533       if (unsignedMax < rhs) { // umax < 13123.0
1534         if (pred == CmpIPredicate::ne || pred == CmpIPredicate::ult ||
1535             pred == CmpIPredicate::ule)
1536           rewriter.replaceOpWithNewOp<ConstantIntOp>(op, /*value=*/true,
1537                                                      /*width=*/1);
1538         else
1539           rewriter.replaceOpWithNewOp<ConstantIntOp>(op, /*value=*/false,
1540                                                      /*width=*/1);
1541         return success();
1542       }
1543     }
1544 
1545     if (!isUnsigned) {
1546       // See if the rhs value is < SignedMin.
1547       APFloat signedMin(rhs.getSemantics());
1548       signedMin.convertFromAPInt(APInt::getSignedMinValue(intWidth), true,
1549                                  APFloat::rmNearestTiesToEven);
1550       if (signedMin > rhs) { // smin > 12312.0
1551         if (pred == CmpIPredicate::ne || pred == CmpIPredicate::sgt ||
1552             pred == CmpIPredicate::sge)
1553           rewriter.replaceOpWithNewOp<ConstantIntOp>(op, /*value=*/true,
1554                                                      /*width=*/1);
1555         else
1556           rewriter.replaceOpWithNewOp<ConstantIntOp>(op, /*value=*/false,
1557                                                      /*width=*/1);
1558         return success();
1559       }
1560     } else {
1561       // See if the rhs value is < UnsignedMin.
1562       APFloat unsignedMin(rhs.getSemantics());
1563       unsignedMin.convertFromAPInt(APInt::getMinValue(intWidth), false,
1564                                    APFloat::rmNearestTiesToEven);
1565       if (unsignedMin > rhs) { // umin > 12312.0
1566         if (pred == CmpIPredicate::ne || pred == CmpIPredicate::ugt ||
1567             pred == CmpIPredicate::uge)
1568           rewriter.replaceOpWithNewOp<ConstantIntOp>(op, /*value=*/true,
1569                                                      /*width=*/1);
1570         else
1571           rewriter.replaceOpWithNewOp<ConstantIntOp>(op, /*value=*/false,
1572                                                      /*width=*/1);
1573         return success();
1574       }
1575     }
1576 
1577     // Okay, now we know that the FP constant fits in the range [SMIN, SMAX] or
1578     // [0, UMAX], but it may still be fractional.  See if it is fractional by
1579     // casting the FP value to the integer value and back, checking for
1580     // equality. Don't do this for zero, because -0.0 is not fractional.
1581     bool ignored;
1582     APSInt rhsInt(intWidth, isUnsigned);
1583     if (APFloat::opInvalidOp ==
1584         rhs.convertToInteger(rhsInt, APFloat::rmTowardZero, &ignored)) {
1585       // Undefined behavior invoked - the destination type can't represent
1586       // the input constant.
1587       return failure();
1588     }
1589 
1590     if (!rhs.isZero()) {
1591       APFloat apf(floatTy.getFloatSemantics(),
1592                   APInt::getZero(floatTy.getWidth()));
1593       apf.convertFromAPInt(rhsInt, !isUnsigned, APFloat::rmNearestTiesToEven);
1594 
1595       bool equal = apf == rhs;
1596       if (!equal) {
1597         // If we had a comparison against a fractional value, we have to adjust
1598         // the compare predicate and sometimes the value.  rhsInt is rounded
1599         // towards zero at this point.
1600         switch (pred) {
1601         case CmpIPredicate::ne: // (float)int != 4.4   --> true
1602           rewriter.replaceOpWithNewOp<ConstantIntOp>(op, /*value=*/true,
1603                                                      /*width=*/1);
1604           return success();
1605         case CmpIPredicate::eq: // (float)int == 4.4   --> false
1606           rewriter.replaceOpWithNewOp<ConstantIntOp>(op, /*value=*/false,
1607                                                      /*width=*/1);
1608           return success();
1609         case CmpIPredicate::ule:
1610           // (float)int <= 4.4   --> int <= 4
1611           // (float)int <= -4.4  --> false
1612           if (rhs.isNegative()) {
1613             rewriter.replaceOpWithNewOp<ConstantIntOp>(op, /*value=*/false,
1614                                                        /*width=*/1);
1615             return success();
1616           }
1617           break;
1618         case CmpIPredicate::sle:
1619           // (float)int <= 4.4   --> int <= 4
1620           // (float)int <= -4.4  --> int < -4
1621           if (rhs.isNegative())
1622             pred = CmpIPredicate::slt;
1623           break;
1624         case CmpIPredicate::ult:
1625           // (float)int < -4.4   --> false
1626           // (float)int < 4.4    --> int <= 4
1627           if (rhs.isNegative()) {
1628             rewriter.replaceOpWithNewOp<ConstantIntOp>(op, /*value=*/false,
1629                                                        /*width=*/1);
1630             return success();
1631           }
1632           pred = CmpIPredicate::ule;
1633           break;
1634         case CmpIPredicate::slt:
1635           // (float)int < -4.4   --> int < -4
1636           // (float)int < 4.4    --> int <= 4
1637           if (!rhs.isNegative())
1638             pred = CmpIPredicate::sle;
1639           break;
1640         case CmpIPredicate::ugt:
1641           // (float)int > 4.4    --> int > 4
1642           // (float)int > -4.4   --> true
1643           if (rhs.isNegative()) {
1644             rewriter.replaceOpWithNewOp<ConstantIntOp>(op, /*value=*/true,
1645                                                        /*width=*/1);
1646             return success();
1647           }
1648           break;
1649         case CmpIPredicate::sgt:
1650           // (float)int > 4.4    --> int > 4
1651           // (float)int > -4.4   --> int >= -4
1652           if (rhs.isNegative())
1653             pred = CmpIPredicate::sge;
1654           break;
1655         case CmpIPredicate::uge:
1656           // (float)int >= -4.4   --> true
1657           // (float)int >= 4.4    --> int > 4
1658           if (rhs.isNegative()) {
1659             rewriter.replaceOpWithNewOp<ConstantIntOp>(op, /*value=*/true,
1660                                                        /*width=*/1);
1661             return success();
1662           }
1663           pred = CmpIPredicate::ugt;
1664           break;
1665         case CmpIPredicate::sge:
1666           // (float)int >= -4.4   --> int >= -4
1667           // (float)int >= 4.4    --> int > 4
1668           if (!rhs.isNegative())
1669             pred = CmpIPredicate::sgt;
1670           break;
1671         }
1672       }
1673     }
1674 
1675     // Lower this FP comparison into an appropriate integer version of the
1676     // comparison.
1677     rewriter.replaceOpWithNewOp<CmpIOp>(
1678         op, pred, intVal,
1679         rewriter.create<ConstantOp>(
1680             op.getLoc(), intVal.getType(),
1681             rewriter.getIntegerAttr(intVal.getType(), rhsInt)));
1682     return success();
1683   }
1684 };
1685 
1686 void arith::CmpFOp::getCanonicalizationPatterns(RewritePatternSet &patterns,
1687                                                 MLIRContext *context) {
1688   patterns.insert<CmpFIntToFPConst>(context);
1689 }
1690 
1691 //===----------------------------------------------------------------------===//
1692 // SelectOp
1693 //===----------------------------------------------------------------------===//
1694 
1695 // Transforms a select of a boolean to arithmetic operations
1696 //
1697 //  arith.select %arg, %x, %y : i1
1698 //
1699 //  becomes
1700 //
1701 //  and(%arg, %x) or and(!%arg, %y)
1702 struct SelectI1Simplify : public OpRewritePattern<arith::SelectOp> {
1703   using OpRewritePattern<arith::SelectOp>::OpRewritePattern;
1704 
1705   LogicalResult matchAndRewrite(arith::SelectOp op,
1706                                 PatternRewriter &rewriter) const override {
1707     if (!op.getType().isInteger(1))
1708       return failure();
1709 
1710     Value falseConstant =
1711         rewriter.create<arith::ConstantIntOp>(op.getLoc(), true, 1);
1712     Value notCondition = rewriter.create<arith::XOrIOp>(
1713         op.getLoc(), op.getCondition(), falseConstant);
1714 
1715     Value trueVal = rewriter.create<arith::AndIOp>(
1716         op.getLoc(), op.getCondition(), op.getTrueValue());
1717     Value falseVal = rewriter.create<arith::AndIOp>(op.getLoc(), notCondition,
1718                                                     op.getFalseValue());
1719     rewriter.replaceOpWithNewOp<arith::OrIOp>(op, trueVal, falseVal);
1720     return success();
1721   }
1722 };
1723 
1724 //  select %arg, %c1, %c0 => extui %arg
1725 struct SelectToExtUI : public OpRewritePattern<arith::SelectOp> {
1726   using OpRewritePattern<arith::SelectOp>::OpRewritePattern;
1727 
1728   LogicalResult matchAndRewrite(arith::SelectOp op,
1729                                 PatternRewriter &rewriter) const override {
1730     // Cannot extui i1 to i1, or i1 to f32
1731     if (!op.getType().isa<IntegerType>() || op.getType().isInteger(1))
1732       return failure();
1733 
1734     // select %x, c1, %c0 => extui %arg
1735     if (matchPattern(op.getTrueValue(), m_One()))
1736       if (matchPattern(op.getFalseValue(), m_Zero())) {
1737         rewriter.replaceOpWithNewOp<arith::ExtUIOp>(op, op.getType(),
1738                                                     op.getCondition());
1739         return success();
1740       }
1741 
1742     // select %x, c0, %c1 => extui (xor %arg, true)
1743     if (matchPattern(op.getTrueValue(), m_Zero()))
1744       if (matchPattern(op.getFalseValue(), m_One())) {
1745         rewriter.replaceOpWithNewOp<arith::ExtUIOp>(
1746             op, op.getType(),
1747             rewriter.create<arith::XOrIOp>(
1748                 op.getLoc(), op.getCondition(),
1749                 rewriter.create<arith::ConstantIntOp>(
1750                     op.getLoc(), 1, op.getCondition().getType())));
1751         return success();
1752       }
1753 
1754     return failure();
1755   }
1756 };
1757 
1758 void arith::SelectOp::getCanonicalizationPatterns(RewritePatternSet &results,
1759                                                   MLIRContext *context) {
1760   results.add<SelectI1Simplify, SelectToExtUI>(context);
1761 }
1762 
1763 OpFoldResult arith::SelectOp::fold(ArrayRef<Attribute> operands) {
1764   Value trueVal = getTrueValue();
1765   Value falseVal = getFalseValue();
1766   if (trueVal == falseVal)
1767     return trueVal;
1768 
1769   Value condition = getCondition();
1770 
1771   // select true, %0, %1 => %0
1772   if (matchPattern(condition, m_One()))
1773     return trueVal;
1774 
1775   // select false, %0, %1 => %1
1776   if (matchPattern(condition, m_Zero()))
1777     return falseVal;
1778 
1779   // select %x, true, false => %x
1780   if (getType().isInteger(1))
1781     if (matchPattern(getTrueValue(), m_One()))
1782       if (matchPattern(getFalseValue(), m_Zero()))
1783         return condition;
1784 
1785   if (auto cmp = dyn_cast_or_null<arith::CmpIOp>(condition.getDefiningOp())) {
1786     auto pred = cmp.getPredicate();
1787     if (pred == arith::CmpIPredicate::eq || pred == arith::CmpIPredicate::ne) {
1788       auto cmpLhs = cmp.getLhs();
1789       auto cmpRhs = cmp.getRhs();
1790 
1791       // %0 = arith.cmpi eq, %arg0, %arg1
1792       // %1 = arith.select %0, %arg0, %arg1 => %arg1
1793 
1794       // %0 = arith.cmpi ne, %arg0, %arg1
1795       // %1 = arith.select %0, %arg0, %arg1 => %arg0
1796 
1797       if ((cmpLhs == trueVal && cmpRhs == falseVal) ||
1798           (cmpRhs == trueVal && cmpLhs == falseVal))
1799         return pred == arith::CmpIPredicate::ne ? trueVal : falseVal;
1800     }
1801   }
1802   return nullptr;
1803 }
1804 
1805 ParseResult SelectOp::parse(OpAsmParser &parser, OperationState &result) {
1806   Type conditionType, resultType;
1807   SmallVector<OpAsmParser::UnresolvedOperand, 3> operands;
1808   if (parser.parseOperandList(operands, /*requiredOperandCount=*/3) ||
1809       parser.parseOptionalAttrDict(result.attributes) ||
1810       parser.parseColonType(resultType))
1811     return failure();
1812 
1813   // Check for the explicit condition type if this is a masked tensor or vector.
1814   if (succeeded(parser.parseOptionalComma())) {
1815     conditionType = resultType;
1816     if (parser.parseType(resultType))
1817       return failure();
1818   } else {
1819     conditionType = parser.getBuilder().getI1Type();
1820   }
1821 
1822   result.addTypes(resultType);
1823   return parser.resolveOperands(operands,
1824                                 {conditionType, resultType, resultType},
1825                                 parser.getNameLoc(), result.operands);
1826 }
1827 
1828 void arith::SelectOp::print(OpAsmPrinter &p) {
1829   p << " " << getOperands();
1830   p.printOptionalAttrDict((*this)->getAttrs());
1831   p << " : ";
1832   if (ShapedType condType = getCondition().getType().dyn_cast<ShapedType>())
1833     p << condType << ", ";
1834   p << getType();
1835 }
1836 
1837 LogicalResult arith::SelectOp::verify() {
1838   Type conditionType = getCondition().getType();
1839   if (conditionType.isSignlessInteger(1))
1840     return success();
1841 
1842   // If the result type is a vector or tensor, the type can be a mask with the
1843   // same elements.
1844   Type resultType = getType();
1845   if (!resultType.isa<TensorType, VectorType>())
1846     return emitOpError() << "expected condition to be a signless i1, but got "
1847                          << conditionType;
1848   Type shapedConditionType = getI1SameShape(resultType);
1849   if (conditionType != shapedConditionType) {
1850     return emitOpError() << "expected condition type to have the same shape "
1851                             "as the result type, expected "
1852                          << shapedConditionType << ", but got "
1853                          << conditionType;
1854   }
1855   return success();
1856 }
1857 //===----------------------------------------------------------------------===//
1858 // ShLIOp
1859 //===----------------------------------------------------------------------===//
1860 
1861 OpFoldResult arith::ShLIOp::fold(ArrayRef<Attribute> operands) {
1862   // Don't fold if shifting more than the bit width.
1863   bool bounded = false;
1864   auto result = constFoldBinaryOp<IntegerAttr>(
1865       operands, [&](const APInt &a, const APInt &b) {
1866         bounded = b.ule(b.getBitWidth());
1867         return a.shl(b);
1868       });
1869   return bounded ? result : Attribute();
1870 }
1871 
1872 //===----------------------------------------------------------------------===//
1873 // ShRUIOp
1874 //===----------------------------------------------------------------------===//
1875 
1876 OpFoldResult arith::ShRUIOp::fold(ArrayRef<Attribute> operands) {
1877   // Don't fold if shifting more than the bit width.
1878   bool bounded = false;
1879   auto result = constFoldBinaryOp<IntegerAttr>(
1880       operands, [&](const APInt &a, const APInt &b) {
1881         bounded = b.ule(b.getBitWidth());
1882         return a.lshr(b);
1883       });
1884   return bounded ? result : Attribute();
1885 }
1886 
1887 //===----------------------------------------------------------------------===//
1888 // ShRSIOp
1889 //===----------------------------------------------------------------------===//
1890 
1891 OpFoldResult arith::ShRSIOp::fold(ArrayRef<Attribute> operands) {
1892   // Don't fold if shifting more than the bit width.
1893   bool bounded = false;
1894   auto result = constFoldBinaryOp<IntegerAttr>(
1895       operands, [&](const APInt &a, const APInt &b) {
1896         bounded = b.ule(b.getBitWidth());
1897         return a.ashr(b);
1898       });
1899   return bounded ? result : Attribute();
1900 }
1901 
1902 //===----------------------------------------------------------------------===//
1903 // Atomic Enum
1904 //===----------------------------------------------------------------------===//
1905 
1906 /// Returns the identity value attribute associated with an AtomicRMWKind op.
1907 Attribute mlir::arith::getIdentityValueAttr(AtomicRMWKind kind, Type resultType,
1908                                             OpBuilder &builder, Location loc) {
1909   switch (kind) {
1910   case AtomicRMWKind::maxf:
1911     return builder.getFloatAttr(
1912         resultType,
1913         APFloat::getInf(resultType.cast<FloatType>().getFloatSemantics(),
1914                         /*Negative=*/true));
1915   case AtomicRMWKind::addf:
1916   case AtomicRMWKind::addi:
1917   case AtomicRMWKind::maxu:
1918   case AtomicRMWKind::ori:
1919     return builder.getZeroAttr(resultType);
1920   case AtomicRMWKind::andi:
1921     return builder.getIntegerAttr(
1922         resultType,
1923         APInt::getAllOnes(resultType.cast<IntegerType>().getWidth()));
1924   case AtomicRMWKind::maxs:
1925     return builder.getIntegerAttr(
1926         resultType,
1927         APInt::getSignedMinValue(resultType.cast<IntegerType>().getWidth()));
1928   case AtomicRMWKind::minf:
1929     return builder.getFloatAttr(
1930         resultType,
1931         APFloat::getInf(resultType.cast<FloatType>().getFloatSemantics(),
1932                         /*Negative=*/false));
1933   case AtomicRMWKind::mins:
1934     return builder.getIntegerAttr(
1935         resultType,
1936         APInt::getSignedMaxValue(resultType.cast<IntegerType>().getWidth()));
1937   case AtomicRMWKind::minu:
1938     return builder.getIntegerAttr(
1939         resultType,
1940         APInt::getMaxValue(resultType.cast<IntegerType>().getWidth()));
1941   case AtomicRMWKind::muli:
1942     return builder.getIntegerAttr(resultType, 1);
1943   case AtomicRMWKind::mulf:
1944     return builder.getFloatAttr(resultType, 1);
1945   // TODO: Add remaining reduction operations.
1946   default:
1947     (void)emitOptionalError(loc, "Reduction operation type not supported");
1948     break;
1949   }
1950   return nullptr;
1951 }
1952 
1953 /// Returns the identity value associated with an AtomicRMWKind op.
1954 Value mlir::arith::getIdentityValue(AtomicRMWKind op, Type resultType,
1955                                     OpBuilder &builder, Location loc) {
1956   Attribute attr = getIdentityValueAttr(op, resultType, builder, loc);
1957   return builder.create<arith::ConstantOp>(loc, attr);
1958 }
1959 
1960 /// Return the value obtained by applying the reduction operation kind
1961 /// associated with a binary AtomicRMWKind op to `lhs` and `rhs`.
1962 Value mlir::arith::getReductionOp(AtomicRMWKind op, OpBuilder &builder,
1963                                   Location loc, Value lhs, Value rhs) {
1964   switch (op) {
1965   case AtomicRMWKind::addf:
1966     return builder.create<arith::AddFOp>(loc, lhs, rhs);
1967   case AtomicRMWKind::addi:
1968     return builder.create<arith::AddIOp>(loc, lhs, rhs);
1969   case AtomicRMWKind::mulf:
1970     return builder.create<arith::MulFOp>(loc, lhs, rhs);
1971   case AtomicRMWKind::muli:
1972     return builder.create<arith::MulIOp>(loc, lhs, rhs);
1973   case AtomicRMWKind::maxf:
1974     return builder.create<arith::MaxFOp>(loc, lhs, rhs);
1975   case AtomicRMWKind::minf:
1976     return builder.create<arith::MinFOp>(loc, lhs, rhs);
1977   case AtomicRMWKind::maxs:
1978     return builder.create<arith::MaxSIOp>(loc, lhs, rhs);
1979   case AtomicRMWKind::mins:
1980     return builder.create<arith::MinSIOp>(loc, lhs, rhs);
1981   case AtomicRMWKind::maxu:
1982     return builder.create<arith::MaxUIOp>(loc, lhs, rhs);
1983   case AtomicRMWKind::minu:
1984     return builder.create<arith::MinUIOp>(loc, lhs, rhs);
1985   case AtomicRMWKind::ori:
1986     return builder.create<arith::OrIOp>(loc, lhs, rhs);
1987   case AtomicRMWKind::andi:
1988     return builder.create<arith::AndIOp>(loc, lhs, rhs);
1989   // TODO: Add remaining reduction operations.
1990   default:
1991     (void)emitOptionalError(loc, "Reduction operation type not supported");
1992     break;
1993   }
1994   return nullptr;
1995 }
1996 
1997 //===----------------------------------------------------------------------===//
1998 // TableGen'd op method definitions
1999 //===----------------------------------------------------------------------===//
2000 
2001 #define GET_OP_CLASSES
2002 #include "mlir/Dialect/Arithmetic/IR/ArithmeticOps.cpp.inc"
2003 
2004 //===----------------------------------------------------------------------===//
2005 // TableGen'd enum attribute definitions
2006 //===----------------------------------------------------------------------===//
2007 
2008 #include "mlir/Dialect/Arithmetic/IR/ArithmeticOpsEnums.cpp.inc"
2009