1 //===- ArithmeticOps.cpp - MLIR Arithmetic dialect ops implementation -----===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 
9 #include <utility>
10 
11 #include "mlir/Dialect/Arithmetic/IR/Arithmetic.h"
12 #include "mlir/Dialect/CommonFolders.h"
13 #include "mlir/IR/Builders.h"
14 #include "mlir/IR/Matchers.h"
15 #include "mlir/IR/OpImplementation.h"
16 #include "mlir/IR/PatternMatch.h"
17 #include "mlir/IR/TypeUtilities.h"
18 #include "llvm/ADT/SmallString.h"
19 
20 #include "llvm/ADT/APSInt.h"
21 
22 using namespace mlir;
23 using namespace mlir::arith;
24 
25 //===----------------------------------------------------------------------===//
26 // Pattern helpers
27 //===----------------------------------------------------------------------===//
28 
29 static IntegerAttr addIntegerAttrs(PatternRewriter &builder, Value res,
30                                    Attribute lhs, Attribute rhs) {
31   return builder.getIntegerAttr(res.getType(),
32                                 lhs.cast<IntegerAttr>().getInt() +
33                                     rhs.cast<IntegerAttr>().getInt());
34 }
35 
36 static IntegerAttr subIntegerAttrs(PatternRewriter &builder, Value res,
37                                    Attribute lhs, Attribute rhs) {
38   return builder.getIntegerAttr(res.getType(),
39                                 lhs.cast<IntegerAttr>().getInt() -
40                                     rhs.cast<IntegerAttr>().getInt());
41 }
42 
43 /// Invert an integer comparison predicate.
44 arith::CmpIPredicate arith::invertPredicate(arith::CmpIPredicate pred) {
45   switch (pred) {
46   case arith::CmpIPredicate::eq:
47     return arith::CmpIPredicate::ne;
48   case arith::CmpIPredicate::ne:
49     return arith::CmpIPredicate::eq;
50   case arith::CmpIPredicate::slt:
51     return arith::CmpIPredicate::sge;
52   case arith::CmpIPredicate::sle:
53     return arith::CmpIPredicate::sgt;
54   case arith::CmpIPredicate::sgt:
55     return arith::CmpIPredicate::sle;
56   case arith::CmpIPredicate::sge:
57     return arith::CmpIPredicate::slt;
58   case arith::CmpIPredicate::ult:
59     return arith::CmpIPredicate::uge;
60   case arith::CmpIPredicate::ule:
61     return arith::CmpIPredicate::ugt;
62   case arith::CmpIPredicate::ugt:
63     return arith::CmpIPredicate::ule;
64   case arith::CmpIPredicate::uge:
65     return arith::CmpIPredicate::ult;
66   }
67   llvm_unreachable("unknown cmpi predicate kind");
68 }
69 
70 static arith::CmpIPredicateAttr invertPredicate(arith::CmpIPredicateAttr pred) {
71   return arith::CmpIPredicateAttr::get(pred.getContext(),
72                                        invertPredicate(pred.getValue()));
73 }
74 
75 //===----------------------------------------------------------------------===//
76 // TableGen'd canonicalization patterns
77 //===----------------------------------------------------------------------===//
78 
namespace {
// TableGen-generated canonicalization patterns, i.e. the pattern classes that
// the getCanonicalizationPatterns hooks below register (e.g. AddIAddConstant,
// SubIRHSAddConstant, XOrINotCmpI). The anonymous namespace gives the generated
// classes internal linkage.
#include "ArithmeticCanonicalization.inc"
} // namespace
82 
83 //===----------------------------------------------------------------------===//
84 // ConstantOp
85 //===----------------------------------------------------------------------===//
86 
87 void arith::ConstantOp::getAsmResultNames(
88     function_ref<void(Value, StringRef)> setNameFn) {
89   auto type = getType();
90   if (auto intCst = getValue().dyn_cast<IntegerAttr>()) {
91     auto intType = type.dyn_cast<IntegerType>();
92 
93     // Sugar i1 constants with 'true' and 'false'.
94     if (intType && intType.getWidth() == 1)
95       return setNameFn(getResult(), (intCst.getInt() ? "true" : "false"));
96 
97     // Otherwise, build a compex name with the value and type.
98     SmallString<32> specialNameBuffer;
99     llvm::raw_svector_ostream specialName(specialNameBuffer);
100     specialName << 'c' << intCst.getInt();
101     if (intType)
102       specialName << '_' << type;
103     setNameFn(getResult(), specialName.str());
104   } else {
105     setNameFn(getResult(), "cst");
106   }
107 }
108 
109 /// TODO: disallow arith.constant to return anything other than signless integer
110 /// or float like.
111 LogicalResult arith::ConstantOp::verify() {
112   auto type = getType();
113   // The value's type must match the return type.
114   if (getValue().getType() != type) {
115     return emitOpError() << "value type " << getValue().getType()
116                          << " must match return type: " << type;
117   }
118   // Integer values must be signless.
119   if (type.isa<IntegerType>() && !type.cast<IntegerType>().isSignless())
120     return emitOpError("integer return type must be signless");
121   // Any float or elements attribute are acceptable.
122   if (!getValue().isa<IntegerAttr, FloatAttr, ElementsAttr>()) {
123     return emitOpError(
124         "value must be an integer, float, or elements attribute");
125   }
126   return success();
127 }
128 
129 bool arith::ConstantOp::isBuildableWith(Attribute value, Type type) {
130   // The value's type must be the same as the provided type.
131   if (value.getType() != type)
132     return false;
133   // Integer values must be signless.
134   if (type.isa<IntegerType>() && !type.cast<IntegerType>().isSignless())
135     return false;
136   // Integer, float, and element attributes are buildable.
137   return value.isa<IntegerAttr, FloatAttr, ElementsAttr>();
138 }
139 
140 OpFoldResult arith::ConstantOp::fold(ArrayRef<Attribute> operands) {
141   return getValue();
142 }
143 
144 void arith::ConstantIntOp::build(OpBuilder &builder, OperationState &result,
145                                  int64_t value, unsigned width) {
146   auto type = builder.getIntegerType(width);
147   arith::ConstantOp::build(builder, result, type,
148                            builder.getIntegerAttr(type, value));
149 }
150 
151 void arith::ConstantIntOp::build(OpBuilder &builder, OperationState &result,
152                                  int64_t value, Type type) {
153   assert(type.isSignlessInteger() &&
154          "ConstantIntOp can only have signless integer type values");
155   arith::ConstantOp::build(builder, result, type,
156                            builder.getIntegerAttr(type, value));
157 }
158 
159 bool arith::ConstantIntOp::classof(Operation *op) {
160   if (auto constOp = dyn_cast_or_null<arith::ConstantOp>(op))
161     return constOp.getType().isSignlessInteger();
162   return false;
163 }
164 
165 void arith::ConstantFloatOp::build(OpBuilder &builder, OperationState &result,
166                                    const APFloat &value, FloatType type) {
167   arith::ConstantOp::build(builder, result, type,
168                            builder.getFloatAttr(type, value));
169 }
170 
171 bool arith::ConstantFloatOp::classof(Operation *op) {
172   if (auto constOp = dyn_cast_or_null<arith::ConstantOp>(op))
173     return constOp.getType().isa<FloatType>();
174   return false;
175 }
176 
177 void arith::ConstantIndexOp::build(OpBuilder &builder, OperationState &result,
178                                    int64_t value) {
179   arith::ConstantOp::build(builder, result, builder.getIndexType(),
180                            builder.getIndexAttr(value));
181 }
182 
183 bool arith::ConstantIndexOp::classof(Operation *op) {
184   if (auto constOp = dyn_cast_or_null<arith::ConstantOp>(op))
185     return constOp.getType().isIndex();
186   return false;
187 }
188 
189 //===----------------------------------------------------------------------===//
190 // AddIOp
191 //===----------------------------------------------------------------------===//
192 
193 OpFoldResult arith::AddIOp::fold(ArrayRef<Attribute> operands) {
194   // addi(x, 0) -> x
195   if (matchPattern(getRhs(), m_Zero()))
196     return getLhs();
197 
198   // addi(subi(a, b), b) -> a
199   if (auto sub = getLhs().getDefiningOp<SubIOp>())
200     if (getRhs() == sub.getRhs())
201       return sub.getLhs();
202 
203   // addi(b, subi(a, b)) -> a
204   if (auto sub = getRhs().getDefiningOp<SubIOp>())
205     if (getLhs() == sub.getRhs())
206       return sub.getLhs();
207 
208   return constFoldBinaryOp<IntegerAttr>(
209       operands, [](APInt a, const APInt &b) { return std::move(a) + b; });
210 }
211 
212 void arith::AddIOp::getCanonicalizationPatterns(
213     RewritePatternSet &patterns, MLIRContext *context) {
214   patterns.add<AddIAddConstant, AddISubConstantRHS, AddISubConstantLHS>(
215       context);
216 }
217 
218 //===----------------------------------------------------------------------===//
219 // SubIOp
220 //===----------------------------------------------------------------------===//
221 
222 OpFoldResult arith::SubIOp::fold(ArrayRef<Attribute> operands) {
223   // subi(x,x) -> 0
224   if (getOperand(0) == getOperand(1))
225     return Builder(getContext()).getZeroAttr(getType());
226   // subi(x,0) -> x
227   if (matchPattern(getRhs(), m_Zero()))
228     return getLhs();
229 
230   return constFoldBinaryOp<IntegerAttr>(
231       operands, [](APInt a, const APInt &b) { return std::move(a) - b; });
232 }
233 
234 void arith::SubIOp::getCanonicalizationPatterns(
235     RewritePatternSet &patterns, MLIRContext *context) {
236   patterns
237       .add<SubIRHSAddConstant, SubILHSAddConstant, SubIRHSSubConstantRHS,
238            SubIRHSSubConstantLHS, SubILHSSubConstantRHS, SubILHSSubConstantLHS>(
239           context);
240 }
241 
242 //===----------------------------------------------------------------------===//
243 // MulIOp
244 //===----------------------------------------------------------------------===//
245 
246 OpFoldResult arith::MulIOp::fold(ArrayRef<Attribute> operands) {
247   // muli(x, 0) -> 0
248   if (matchPattern(getRhs(), m_Zero()))
249     return getRhs();
250   // muli(x, 1) -> x
251   if (matchPattern(getRhs(), m_One()))
252     return getOperand(0);
253   // TODO: Handle the overflow case.
254 
255   // default folder
256   return constFoldBinaryOp<IntegerAttr>(
257       operands, [](const APInt &a, const APInt &b) { return a * b; });
258 }
259 
260 //===----------------------------------------------------------------------===//
261 // DivUIOp
262 //===----------------------------------------------------------------------===//
263 
264 OpFoldResult arith::DivUIOp::fold(ArrayRef<Attribute> operands) {
265   // divui (x, 1) -> x.
266   if (matchPattern(getRhs(), m_One()))
267     return getLhs();
268 
269   // Don't fold if it would require a division by zero.
270   bool div0 = false;
271   auto result =
272       constFoldBinaryOp<IntegerAttr>(operands, [&](APInt a, const APInt &b) {
273         if (div0 || !b) {
274           div0 = true;
275           return a;
276         }
277         return a.udiv(b);
278       });
279 
280   return div0 ? Attribute() : result;
281 }
282 
283 //===----------------------------------------------------------------------===//
284 // DivSIOp
285 //===----------------------------------------------------------------------===//
286 
287 OpFoldResult arith::DivSIOp::fold(ArrayRef<Attribute> operands) {
288   // divsi (x, 1) -> x.
289   if (matchPattern(getRhs(), m_One()))
290     return getLhs();
291 
292   // Don't fold if it would overflow or if it requires a division by zero.
293   bool overflowOrDiv0 = false;
294   auto result =
295       constFoldBinaryOp<IntegerAttr>(operands, [&](APInt a, const APInt &b) {
296         if (overflowOrDiv0 || !b) {
297           overflowOrDiv0 = true;
298           return a;
299         }
300         return a.sdiv_ov(b, overflowOrDiv0);
301       });
302 
303   return overflowOrDiv0 ? Attribute() : result;
304 }
305 
306 //===----------------------------------------------------------------------===//
307 // Ceil and floor division folding helpers
308 //===----------------------------------------------------------------------===//
309 
310 static APInt signedCeilNonnegInputs(const APInt &a, const APInt &b,
311                                     bool &overflow) {
312   // Returns (a-1)/b + 1
313   APInt one(a.getBitWidth(), 1, true); // Signed value 1.
314   APInt val = a.ssub_ov(one, overflow).sdiv_ov(b, overflow);
315   return val.sadd_ov(one, overflow);
316 }
317 
318 //===----------------------------------------------------------------------===//
319 // CeilDivUIOp
320 //===----------------------------------------------------------------------===//
321 
322 OpFoldResult arith::CeilDivUIOp::fold(ArrayRef<Attribute> operands) {
323   // ceildivui (x, 1) -> x.
324   if (matchPattern(getRhs(), m_One()))
325     return getLhs();
326 
327   bool overflowOrDiv0 = false;
328   auto result =
329       constFoldBinaryOp<IntegerAttr>(operands, [&](APInt a, const APInt &b) {
330         if (overflowOrDiv0 || !b) {
331           overflowOrDiv0 = true;
332           return a;
333         }
334         APInt quotient = a.udiv(b);
335         if (!a.urem(b))
336           return quotient;
337         APInt one(a.getBitWidth(), 1, true);
338         return quotient.uadd_ov(one, overflowOrDiv0);
339       });
340 
341   return overflowOrDiv0 ? Attribute() : result;
342 }
343 
344 //===----------------------------------------------------------------------===//
345 // CeilDivSIOp
346 //===----------------------------------------------------------------------===//
347 
348 OpFoldResult arith::CeilDivSIOp::fold(ArrayRef<Attribute> operands) {
349   // ceildivsi (x, 1) -> x.
350   if (matchPattern(getRhs(), m_One()))
351     return getLhs();
352 
353   // Don't fold if it would overflow or if it requires a division by zero.
354   bool overflowOrDiv0 = false;
355   auto result =
356       constFoldBinaryOp<IntegerAttr>(operands, [&](APInt a, const APInt &b) {
357         if (overflowOrDiv0 || !b) {
358           overflowOrDiv0 = true;
359           return a;
360         }
361         if (!a)
362           return a;
363         // After this point we know that neither a or b are zero.
364         unsigned bits = a.getBitWidth();
365         APInt zero = APInt::getZero(bits);
366         bool aGtZero = a.sgt(zero);
367         bool bGtZero = b.sgt(zero);
368         if (aGtZero && bGtZero) {
369           // Both positive, return ceil(a, b).
370           return signedCeilNonnegInputs(a, b, overflowOrDiv0);
371         }
372         if (!aGtZero && !bGtZero) {
373           // Both negative, return ceil(-a, -b).
374           APInt posA = zero.ssub_ov(a, overflowOrDiv0);
375           APInt posB = zero.ssub_ov(b, overflowOrDiv0);
376           return signedCeilNonnegInputs(posA, posB, overflowOrDiv0);
377         }
378         if (!aGtZero && bGtZero) {
379           // A is negative, b is positive, return - ( -a / b).
380           APInt posA = zero.ssub_ov(a, overflowOrDiv0);
381           APInt div = posA.sdiv_ov(b, overflowOrDiv0);
382           return zero.ssub_ov(div, overflowOrDiv0);
383         }
384         // A is positive, b is negative, return - (a / -b).
385         APInt posB = zero.ssub_ov(b, overflowOrDiv0);
386         APInt div = a.sdiv_ov(posB, overflowOrDiv0);
387         return zero.ssub_ov(div, overflowOrDiv0);
388       });
389 
390   return overflowOrDiv0 ? Attribute() : result;
391 }
392 
393 //===----------------------------------------------------------------------===//
394 // FloorDivSIOp
395 //===----------------------------------------------------------------------===//
396 
397 OpFoldResult arith::FloorDivSIOp::fold(ArrayRef<Attribute> operands) {
398   // floordivsi (x, 1) -> x.
399   if (matchPattern(getRhs(), m_One()))
400     return getLhs();
401 
402   // Don't fold if it would overflow or if it requires a division by zero.
403   bool overflowOrDiv0 = false;
404   auto result =
405       constFoldBinaryOp<IntegerAttr>(operands, [&](APInt a, const APInt &b) {
406         if (overflowOrDiv0 || !b) {
407           overflowOrDiv0 = true;
408           return a;
409         }
410         if (!a)
411           return a;
412         // After this point we know that neither a or b are zero.
413         unsigned bits = a.getBitWidth();
414         APInt zero = APInt::getZero(bits);
415         bool aGtZero = a.sgt(zero);
416         bool bGtZero = b.sgt(zero);
417         if (aGtZero && bGtZero) {
418           // Both positive, return a / b.
419           return a.sdiv_ov(b, overflowOrDiv0);
420         }
421         if (!aGtZero && !bGtZero) {
422           // Both negative, return -a / -b.
423           APInt posA = zero.ssub_ov(a, overflowOrDiv0);
424           APInt posB = zero.ssub_ov(b, overflowOrDiv0);
425           return posA.sdiv_ov(posB, overflowOrDiv0);
426         }
427         if (!aGtZero && bGtZero) {
428           // A is negative, b is positive, return - ceil(-a, b).
429           APInt posA = zero.ssub_ov(a, overflowOrDiv0);
430           APInt ceil = signedCeilNonnegInputs(posA, b, overflowOrDiv0);
431           return zero.ssub_ov(ceil, overflowOrDiv0);
432         }
433         // A is positive, b is negative, return - ceil(a, -b).
434         APInt posB = zero.ssub_ov(b, overflowOrDiv0);
435         APInt ceil = signedCeilNonnegInputs(a, posB, overflowOrDiv0);
436         return zero.ssub_ov(ceil, overflowOrDiv0);
437       });
438 
439   return overflowOrDiv0 ? Attribute() : result;
440 }
441 
442 //===----------------------------------------------------------------------===//
443 // RemUIOp
444 //===----------------------------------------------------------------------===//
445 
446 OpFoldResult arith::RemUIOp::fold(ArrayRef<Attribute> operands) {
447   // remui (x, 1) -> 0.
448   if (matchPattern(getRhs(), m_One()))
449     return Builder(getContext()).getZeroAttr(getType());
450 
451   // Don't fold if it would require a division by zero.
452   bool div0 = false;
453   auto result =
454       constFoldBinaryOp<IntegerAttr>(operands, [&](APInt a, const APInt &b) {
455         if (div0 || b.isNullValue()) {
456           div0 = true;
457           return a;
458         }
459         return a.urem(b);
460       });
461 
462   return div0 ? Attribute() : result;
463 }
464 
465 //===----------------------------------------------------------------------===//
466 // RemSIOp
467 //===----------------------------------------------------------------------===//
468 
469 OpFoldResult arith::RemSIOp::fold(ArrayRef<Attribute> operands) {
470   // remsi (x, 1) -> 0.
471   if (matchPattern(getRhs(), m_One()))
472     return Builder(getContext()).getZeroAttr(getType());
473 
474   // Don't fold if it would require a division by zero.
475   bool div0 = false;
476   auto result =
477       constFoldBinaryOp<IntegerAttr>(operands, [&](APInt a, const APInt &b) {
478         if (div0 || b.isNullValue()) {
479           div0 = true;
480           return a;
481         }
482         return a.srem(b);
483       });
484 
485   return div0 ? Attribute() : result;
486 }
487 
488 //===----------------------------------------------------------------------===//
489 // AndIOp
490 //===----------------------------------------------------------------------===//
491 
492 OpFoldResult arith::AndIOp::fold(ArrayRef<Attribute> operands) {
493   /// and(x, 0) -> 0
494   if (matchPattern(getRhs(), m_Zero()))
495     return getRhs();
496   /// and(x, allOnes) -> x
497   APInt intValue;
498   if (matchPattern(getRhs(), m_ConstantInt(&intValue)) && intValue.isAllOnes())
499     return getLhs();
500 
501   return constFoldBinaryOp<IntegerAttr>(
502       operands, [](APInt a, const APInt &b) { return std::move(a) & b; });
503 }
504 
505 //===----------------------------------------------------------------------===//
506 // OrIOp
507 //===----------------------------------------------------------------------===//
508 
509 OpFoldResult arith::OrIOp::fold(ArrayRef<Attribute> operands) {
510   /// or(x, 0) -> x
511   if (matchPattern(getRhs(), m_Zero()))
512     return getLhs();
513   /// or(x, <all ones>) -> <all ones>
514   if (auto rhsAttr = operands[1].dyn_cast_or_null<IntegerAttr>())
515     if (rhsAttr.getValue().isAllOnes())
516       return rhsAttr;
517 
518   return constFoldBinaryOp<IntegerAttr>(
519       operands, [](APInt a, const APInt &b) { return std::move(a) | b; });
520 }
521 
522 //===----------------------------------------------------------------------===//
523 // XOrIOp
524 //===----------------------------------------------------------------------===//
525 
526 OpFoldResult arith::XOrIOp::fold(ArrayRef<Attribute> operands) {
527   /// xor(x, 0) -> x
528   if (matchPattern(getRhs(), m_Zero()))
529     return getLhs();
530   /// xor(x, x) -> 0
531   if (getLhs() == getRhs())
532     return Builder(getContext()).getZeroAttr(getType());
533   /// xor(xor(x, a), a) -> x
534   if (arith::XOrIOp prev = getLhs().getDefiningOp<arith::XOrIOp>())
535     if (prev.getRhs() == getRhs())
536       return prev.getLhs();
537 
538   return constFoldBinaryOp<IntegerAttr>(
539       operands, [](APInt a, const APInt &b) { return std::move(a) ^ b; });
540 }
541 
542 void arith::XOrIOp::getCanonicalizationPatterns(
543     RewritePatternSet &patterns, MLIRContext *context) {
544   patterns.add<XOrINotCmpI>(context);
545 }
546 
547 //===----------------------------------------------------------------------===//
548 // NegFOp
549 //===----------------------------------------------------------------------===//
550 
551 OpFoldResult arith::NegFOp::fold(ArrayRef<Attribute> operands) {
552   /// negf(negf(x)) -> x
553   if (auto op = this->getOperand().getDefiningOp<arith::NegFOp>())
554     return op.getOperand();
555   return constFoldUnaryOp<FloatAttr>(operands,
556                                      [](const APFloat &a) { return -a; });
557 }
558 
559 //===----------------------------------------------------------------------===//
560 // AddFOp
561 //===----------------------------------------------------------------------===//
562 
563 OpFoldResult arith::AddFOp::fold(ArrayRef<Attribute> operands) {
564   // addf(x, -0) -> x
565   if (matchPattern(getRhs(), m_NegZeroFloat()))
566     return getLhs();
567 
568   return constFoldBinaryOp<FloatAttr>(
569       operands, [](const APFloat &a, const APFloat &b) { return a + b; });
570 }
571 
572 //===----------------------------------------------------------------------===//
573 // SubFOp
574 //===----------------------------------------------------------------------===//
575 
576 OpFoldResult arith::SubFOp::fold(ArrayRef<Attribute> operands) {
577   // subf(x, +0) -> x
578   if (matchPattern(getRhs(), m_PosZeroFloat()))
579     return getLhs();
580 
581   return constFoldBinaryOp<FloatAttr>(
582       operands, [](const APFloat &a, const APFloat &b) { return a - b; });
583 }
584 
585 //===----------------------------------------------------------------------===//
586 // MaxFOp
587 //===----------------------------------------------------------------------===//
588 
589 OpFoldResult arith::MaxFOp::fold(ArrayRef<Attribute> operands) {
590   assert(operands.size() == 2 && "maxf takes two operands");
591 
592   // maxf(x,x) -> x
593   if (getLhs() == getRhs())
594     return getRhs();
595 
596   // maxf(x, -inf) -> x
597   if (matchPattern(getRhs(), m_NegInfFloat()))
598     return getLhs();
599 
600   return constFoldBinaryOp<FloatAttr>(
601       operands,
602       [](const APFloat &a, const APFloat &b) { return llvm::maximum(a, b); });
603 }
604 
605 //===----------------------------------------------------------------------===//
606 // MaxSIOp
607 //===----------------------------------------------------------------------===//
608 
609 OpFoldResult MaxSIOp::fold(ArrayRef<Attribute> operands) {
610   assert(operands.size() == 2 && "binary operation takes two operands");
611 
612   // maxsi(x,x) -> x
613   if (getLhs() == getRhs())
614     return getRhs();
615 
616   APInt intValue;
617   // maxsi(x,MAX_INT) -> MAX_INT
618   if (matchPattern(getRhs(), m_ConstantInt(&intValue)) &&
619       intValue.isMaxSignedValue())
620     return getRhs();
621 
622   // maxsi(x, MIN_INT) -> x
623   if (matchPattern(getRhs(), m_ConstantInt(&intValue)) &&
624       intValue.isMinSignedValue())
625     return getLhs();
626 
627   return constFoldBinaryOp<IntegerAttr>(operands,
628                                         [](const APInt &a, const APInt &b) {
629                                           return llvm::APIntOps::smax(a, b);
630                                         });
631 }
632 
633 //===----------------------------------------------------------------------===//
634 // MaxUIOp
635 //===----------------------------------------------------------------------===//
636 
637 OpFoldResult MaxUIOp::fold(ArrayRef<Attribute> operands) {
638   assert(operands.size() == 2 && "binary operation takes two operands");
639 
640   // maxui(x,x) -> x
641   if (getLhs() == getRhs())
642     return getRhs();
643 
644   APInt intValue;
645   // maxui(x,MAX_INT) -> MAX_INT
646   if (matchPattern(getRhs(), m_ConstantInt(&intValue)) && intValue.isMaxValue())
647     return getRhs();
648 
649   // maxui(x, MIN_INT) -> x
650   if (matchPattern(getRhs(), m_ConstantInt(&intValue)) && intValue.isMinValue())
651     return getLhs();
652 
653   return constFoldBinaryOp<IntegerAttr>(operands,
654                                         [](const APInt &a, const APInt &b) {
655                                           return llvm::APIntOps::umax(a, b);
656                                         });
657 }
658 
659 //===----------------------------------------------------------------------===//
660 // MinFOp
661 //===----------------------------------------------------------------------===//
662 
663 OpFoldResult arith::MinFOp::fold(ArrayRef<Attribute> operands) {
664   assert(operands.size() == 2 && "minf takes two operands");
665 
666   // minf(x,x) -> x
667   if (getLhs() == getRhs())
668     return getRhs();
669 
670   // minf(x, +inf) -> x
671   if (matchPattern(getRhs(), m_PosInfFloat()))
672     return getLhs();
673 
674   return constFoldBinaryOp<FloatAttr>(
675       operands,
676       [](const APFloat &a, const APFloat &b) { return llvm::minimum(a, b); });
677 }
678 
679 //===----------------------------------------------------------------------===//
680 // MinSIOp
681 //===----------------------------------------------------------------------===//
682 
683 OpFoldResult MinSIOp::fold(ArrayRef<Attribute> operands) {
684   assert(operands.size() == 2 && "binary operation takes two operands");
685 
686   // minsi(x,x) -> x
687   if (getLhs() == getRhs())
688     return getRhs();
689 
690   APInt intValue;
691   // minsi(x,MIN_INT) -> MIN_INT
692   if (matchPattern(getRhs(), m_ConstantInt(&intValue)) &&
693       intValue.isMinSignedValue())
694     return getRhs();
695 
696   // minsi(x, MAX_INT) -> x
697   if (matchPattern(getRhs(), m_ConstantInt(&intValue)) &&
698       intValue.isMaxSignedValue())
699     return getLhs();
700 
701   return constFoldBinaryOp<IntegerAttr>(operands,
702                                         [](const APInt &a, const APInt &b) {
703                                           return llvm::APIntOps::smin(a, b);
704                                         });
705 }
706 
707 //===----------------------------------------------------------------------===//
708 // MinUIOp
709 //===----------------------------------------------------------------------===//
710 
711 OpFoldResult MinUIOp::fold(ArrayRef<Attribute> operands) {
712   assert(operands.size() == 2 && "binary operation takes two operands");
713 
714   // minui(x,x) -> x
715   if (getLhs() == getRhs())
716     return getRhs();
717 
718   APInt intValue;
719   // minui(x,MIN_INT) -> MIN_INT
720   if (matchPattern(getRhs(), m_ConstantInt(&intValue)) && intValue.isMinValue())
721     return getRhs();
722 
723   // minui(x, MAX_INT) -> x
724   if (matchPattern(getRhs(), m_ConstantInt(&intValue)) && intValue.isMaxValue())
725     return getLhs();
726 
727   return constFoldBinaryOp<IntegerAttr>(operands,
728                                         [](const APInt &a, const APInt &b) {
729                                           return llvm::APIntOps::umin(a, b);
730                                         });
731 }
732 
733 //===----------------------------------------------------------------------===//
734 // MulFOp
735 //===----------------------------------------------------------------------===//
736 
737 OpFoldResult arith::MulFOp::fold(ArrayRef<Attribute> operands) {
738   // mulf(x, 1) -> x
739   if (matchPattern(getRhs(), m_OneFloat()))
740     return getLhs();
741 
742   return constFoldBinaryOp<FloatAttr>(
743       operands, [](const APFloat &a, const APFloat &b) { return a * b; });
744 }
745 
746 //===----------------------------------------------------------------------===//
747 // DivFOp
748 //===----------------------------------------------------------------------===//
749 
750 OpFoldResult arith::DivFOp::fold(ArrayRef<Attribute> operands) {
751   // divf(x, 1) -> x
752   if (matchPattern(getRhs(), m_OneFloat()))
753     return getLhs();
754 
755   return constFoldBinaryOp<FloatAttr>(
756       operands, [](const APFloat &a, const APFloat &b) { return a / b; });
757 }
758 
759 //===----------------------------------------------------------------------===//
760 // Utility functions for verifying cast ops
761 //===----------------------------------------------------------------------===//
762 
/// Tag type that carries a pack of types as an ordinary function argument (a
/// value-initialized, i.e. null, pointer to a tuple of the types). This lets
/// one function template deduce two independent type packs.
template <typename... Ts>
using type_list = std::tuple<Ts...> *;
765 
766 /// Returns a non-null type only if the provided type is one of the allowed
767 /// types or one of the allowed shaped types of the allowed types. Returns the
768 /// element type if a valid shaped type is provided.
769 template <typename... ShapedTypes, typename... ElementTypes>
770 static Type getUnderlyingType(Type type, type_list<ShapedTypes...>,
771                               type_list<ElementTypes...>) {
772   if (type.isa<ShapedType>() && !type.isa<ShapedTypes...>())
773     return {};
774 
775   auto underlyingType = getElementTypeOrSelf(type);
776   if (!underlyingType.isa<ElementTypes...>())
777     return {};
778 
779   return underlyingType;
780 }
781 
782 /// Get allowed underlying types for vectors and tensors.
783 template <typename... ElementTypes>
784 static Type getTypeIfLike(Type type) {
785   return getUnderlyingType(type, type_list<VectorType, TensorType>(),
786                            type_list<ElementTypes...>());
787 }
788 
789 /// Get allowed underlying types for vectors, tensors, and memrefs.
790 template <typename... ElementTypes>
791 static Type getTypeIfLikeOrMemRef(Type type) {
792   return getUnderlyingType(type,
793                            type_list<VectorType, TensorType, MemRefType>(),
794                            type_list<ElementTypes...>());
795 }
796 
797 static bool areValidCastInputsAndOutputs(TypeRange inputs, TypeRange outputs) {
798   return inputs.size() == 1 && outputs.size() == 1 &&
799          succeeded(verifyCompatibleShapes(inputs.front(), outputs.front()));
800 }
801 
802 //===----------------------------------------------------------------------===//
803 // Verifiers for integer and floating point extension/truncation ops
804 //===----------------------------------------------------------------------===//
805 
806 // Extend ops can only extend to a wider type.
807 template <typename ValType, typename Op>
808 static LogicalResult verifyExtOp(Op op) {
809   Type srcType = getElementTypeOrSelf(op.getIn().getType());
810   Type dstType = getElementTypeOrSelf(op.getType());
811 
812   if (srcType.cast<ValType>().getWidth() >= dstType.cast<ValType>().getWidth())
813     return op.emitError("result type ")
814            << dstType << " must be wider than operand type " << srcType;
815 
816   return success();
817 }
818 
819 // Truncate ops can only truncate to a shorter type.
820 template <typename ValType, typename Op>
821 static LogicalResult verifyTruncateOp(Op op) {
822   Type srcType = getElementTypeOrSelf(op.getIn().getType());
823   Type dstType = getElementTypeOrSelf(op.getType());
824 
825   if (srcType.cast<ValType>().getWidth() <= dstType.cast<ValType>().getWidth())
826     return op.emitError("result type ")
827            << dstType << " must be shorter than operand type " << srcType;
828 
829   return success();
830 }
831 
832 /// Validate a cast that changes the width of a type.
833 template <template <typename> class WidthComparator, typename... ElementTypes>
834 static bool checkWidthChangeCast(TypeRange inputs, TypeRange outputs) {
835   if (!areValidCastInputsAndOutputs(inputs, outputs))
836     return false;
837 
838   auto srcType = getTypeIfLike<ElementTypes...>(inputs.front());
839   auto dstType = getTypeIfLike<ElementTypes...>(outputs.front());
840   if (!srcType || !dstType)
841     return false;
842 
843   return WidthComparator<unsigned>()(dstType.getIntOrFloatBitWidth(),
844                                      srcType.getIntOrFloatBitWidth());
845 }
846 
847 //===----------------------------------------------------------------------===//
848 // ExtUIOp
849 //===----------------------------------------------------------------------===//
850 
851 OpFoldResult arith::ExtUIOp::fold(ArrayRef<Attribute> operands) {
852   if (auto lhs = getIn().getDefiningOp<ExtUIOp>()) {
853     getInMutable().assign(lhs.getIn());
854     return getResult();
855   }
856   Type resType = getType();
857   unsigned bitWidth;
858   if (auto shapedType = resType.dyn_cast<ShapedType>())
859     bitWidth = shapedType.getElementTypeBitWidth();
860   else
861     bitWidth = resType.getIntOrFloatBitWidth();
862   return constFoldCastOp<IntegerAttr, IntegerAttr>(
863       operands, getType(), [bitWidth](const APInt &a, bool &castStatus) {
864         return a.zext(bitWidth);
865       });
866 }
867 
868 bool arith::ExtUIOp::areCastCompatible(TypeRange inputs, TypeRange outputs) {
869   return checkWidthChangeCast<std::greater, IntegerType>(inputs, outputs);
870 }
871 
872 LogicalResult arith::ExtUIOp::verify() {
873   return verifyExtOp<IntegerType>(*this);
874 }
875 
876 //===----------------------------------------------------------------------===//
877 // ExtSIOp
878 //===----------------------------------------------------------------------===//
879 
880 OpFoldResult arith::ExtSIOp::fold(ArrayRef<Attribute> operands) {
881   if (auto lhs = getIn().getDefiningOp<ExtSIOp>()) {
882     getInMutable().assign(lhs.getIn());
883     return getResult();
884   }
885   Type resType = getType();
886   unsigned bitWidth;
887   if (auto shapedType = resType.dyn_cast<ShapedType>())
888     bitWidth = shapedType.getElementTypeBitWidth();
889   else
890     bitWidth = resType.getIntOrFloatBitWidth();
891   return constFoldCastOp<IntegerAttr, IntegerAttr>(
892       operands, getType(), [bitWidth](const APInt &a, bool &castStatus) {
893         return a.sext(bitWidth);
894       });
895 }
896 
897 bool arith::ExtSIOp::areCastCompatible(TypeRange inputs, TypeRange outputs) {
898   return checkWidthChangeCast<std::greater, IntegerType>(inputs, outputs);
899 }
900 
901 void arith::ExtSIOp::getCanonicalizationPatterns(
902     RewritePatternSet &patterns, MLIRContext *context) {
903   patterns.add<ExtSIOfExtUI>(context);
904 }
905 
906 LogicalResult arith::ExtSIOp::verify() {
907   return verifyExtOp<IntegerType>(*this);
908 }
909 
910 //===----------------------------------------------------------------------===//
911 // ExtFOp
912 //===----------------------------------------------------------------------===//
913 
914 bool arith::ExtFOp::areCastCompatible(TypeRange inputs, TypeRange outputs) {
915   return checkWidthChangeCast<std::greater, FloatType>(inputs, outputs);
916 }
917 
918 LogicalResult arith::ExtFOp::verify() { return verifyExtOp<FloatType>(*this); }
919 
920 //===----------------------------------------------------------------------===//
921 // TruncIOp
922 //===----------------------------------------------------------------------===//
923 
924 OpFoldResult arith::TruncIOp::fold(ArrayRef<Attribute> operands) {
925   assert(operands.size() == 1 && "unary operation takes one operand");
926 
927   // trunci(zexti(a)) -> a
928   // trunci(sexti(a)) -> a
929   if (matchPattern(getOperand(), m_Op<arith::ExtUIOp>()) ||
930       matchPattern(getOperand(), m_Op<arith::ExtSIOp>()))
931     return getOperand().getDefiningOp()->getOperand(0);
932 
933   // trunci(trunci(a)) -> trunci(a))
934   if (matchPattern(getOperand(), m_Op<arith::TruncIOp>())) {
935     setOperand(getOperand().getDefiningOp()->getOperand(0));
936     return getResult();
937   }
938 
939   Type resType = getType();
940   unsigned bitWidth;
941   if (auto shapedType = resType.dyn_cast<ShapedType>())
942     bitWidth = shapedType.getElementTypeBitWidth();
943   else
944     bitWidth = resType.getIntOrFloatBitWidth();
945 
946   return constFoldCastOp<IntegerAttr, IntegerAttr>(
947       operands, getType(), [bitWidth](const APInt &a, bool &castStatus) {
948         return a.trunc(bitWidth);
949       });
950 }
951 
952 bool arith::TruncIOp::areCastCompatible(TypeRange inputs, TypeRange outputs) {
953   return checkWidthChangeCast<std::less, IntegerType>(inputs, outputs);
954 }
955 
956 LogicalResult arith::TruncIOp::verify() {
957   return verifyTruncateOp<IntegerType>(*this);
958 }
959 
960 //===----------------------------------------------------------------------===//
961 // TruncFOp
962 //===----------------------------------------------------------------------===//
963 
964 /// Perform safe const propagation for truncf, i.e. only propagate if FP value
965 /// can be represented without precision loss or rounding.
966 OpFoldResult arith::TruncFOp::fold(ArrayRef<Attribute> operands) {
967   assert(operands.size() == 1 && "unary operation takes one operand");
968 
969   auto constOperand = operands.front();
970   if (!constOperand || !constOperand.isa<FloatAttr>())
971     return {};
972 
973   // Convert to target type via 'double'.
974   double sourceValue =
975       constOperand.dyn_cast<FloatAttr>().getValue().convertToDouble();
976   auto targetAttr = FloatAttr::get(getType(), sourceValue);
977 
978   // Propagate if constant's value does not change after truncation.
979   if (sourceValue == targetAttr.getValue().convertToDouble())
980     return targetAttr;
981 
982   return {};
983 }
984 
985 bool arith::TruncFOp::areCastCompatible(TypeRange inputs, TypeRange outputs) {
986   return checkWidthChangeCast<std::less, FloatType>(inputs, outputs);
987 }
988 
989 LogicalResult arith::TruncFOp::verify() {
990   return verifyTruncateOp<FloatType>(*this);
991 }
992 
993 //===----------------------------------------------------------------------===//
994 // AndIOp
995 //===----------------------------------------------------------------------===//
996 
997 void arith::AndIOp::getCanonicalizationPatterns(
998     RewritePatternSet &patterns, MLIRContext *context) {
999   patterns.add<AndOfExtUI, AndOfExtSI>(context);
1000 }
1001 
1002 //===----------------------------------------------------------------------===//
1003 // OrIOp
1004 //===----------------------------------------------------------------------===//
1005 
1006 void arith::OrIOp::getCanonicalizationPatterns(
1007     RewritePatternSet &patterns, MLIRContext *context) {
1008   patterns.add<OrOfExtUI, OrOfExtSI>(context);
1009 }
1010 
1011 //===----------------------------------------------------------------------===//
1012 // Verifiers for casts between integers and floats.
1013 //===----------------------------------------------------------------------===//
1014 
1015 template <typename From, typename To>
1016 static bool checkIntFloatCast(TypeRange inputs, TypeRange outputs) {
1017   if (!areValidCastInputsAndOutputs(inputs, outputs))
1018     return false;
1019 
1020   auto srcType = getTypeIfLike<From>(inputs.front());
1021   auto dstType = getTypeIfLike<To>(outputs.back());
1022 
1023   return srcType && dstType;
1024 }
1025 
1026 //===----------------------------------------------------------------------===//
1027 // UIToFPOp
1028 //===----------------------------------------------------------------------===//
1029 
1030 bool arith::UIToFPOp::areCastCompatible(TypeRange inputs, TypeRange outputs) {
1031   return checkIntFloatCast<IntegerType, FloatType>(inputs, outputs);
1032 }
1033 
1034 OpFoldResult arith::UIToFPOp::fold(ArrayRef<Attribute> operands) {
1035   Type resType = getType();
1036   Type resEleType;
1037   if (auto shapedType = resType.dyn_cast<ShapedType>())
1038     resEleType = shapedType.getElementType();
1039   else
1040     resEleType = resType;
1041   return constFoldCastOp<IntegerAttr, FloatAttr>(
1042       operands, getType(), [&resEleType](const APInt &a, bool &castStatus) {
1043         FloatType floatTy = resEleType.cast<FloatType>();
1044         APFloat apf(floatTy.getFloatSemantics(),
1045                     APInt::getZero(floatTy.getWidth()));
1046         apf.convertFromAPInt(a, /*IsSigned=*/false,
1047                              APFloat::rmNearestTiesToEven);
1048         return apf;
1049       });
1050 }
1051 
1052 //===----------------------------------------------------------------------===//
1053 // SIToFPOp
1054 //===----------------------------------------------------------------------===//
1055 
1056 bool arith::SIToFPOp::areCastCompatible(TypeRange inputs, TypeRange outputs) {
1057   return checkIntFloatCast<IntegerType, FloatType>(inputs, outputs);
1058 }
1059 
1060 OpFoldResult arith::SIToFPOp::fold(ArrayRef<Attribute> operands) {
1061   Type resType = getType();
1062   Type resEleType;
1063   if (auto shapedType = resType.dyn_cast<ShapedType>())
1064     resEleType = shapedType.getElementType();
1065   else
1066     resEleType = resType;
1067   return constFoldCastOp<IntegerAttr, FloatAttr>(
1068       operands, getType(), [&resEleType](const APInt &a, bool &castStatus) {
1069         FloatType floatTy = resEleType.cast<FloatType>();
1070         APFloat apf(floatTy.getFloatSemantics(),
1071                     APInt::getZero(floatTy.getWidth()));
1072         apf.convertFromAPInt(a, /*IsSigned=*/true,
1073                              APFloat::rmNearestTiesToEven);
1074         return apf;
1075       });
1076 }
1077 //===----------------------------------------------------------------------===//
1078 // FPToUIOp
1079 //===----------------------------------------------------------------------===//
1080 
1081 bool arith::FPToUIOp::areCastCompatible(TypeRange inputs, TypeRange outputs) {
1082   return checkIntFloatCast<FloatType, IntegerType>(inputs, outputs);
1083 }
1084 
1085 OpFoldResult arith::FPToUIOp::fold(ArrayRef<Attribute> operands) {
1086   Type resType = getType();
1087   Type resEleType;
1088   if (auto shapedType = resType.dyn_cast<ShapedType>())
1089     resEleType = shapedType.getElementType();
1090   else
1091     resEleType = resType;
1092   return constFoldCastOp<FloatAttr, IntegerAttr>(
1093       operands, getType(), [&resEleType](const APFloat &a, bool &castStatus) {
1094         IntegerType intTy = resEleType.cast<IntegerType>();
1095         bool ignored;
1096         APSInt api(intTy.getWidth(), /*isUnsigned=*/true);
1097         castStatus = APFloat::opInvalidOp !=
1098                      a.convertToInteger(api, APFloat::rmTowardZero, &ignored);
1099         return api;
1100       });
1101 }
1102 
1103 //===----------------------------------------------------------------------===//
1104 // FPToSIOp
1105 //===----------------------------------------------------------------------===//
1106 
1107 bool arith::FPToSIOp::areCastCompatible(TypeRange inputs, TypeRange outputs) {
1108   return checkIntFloatCast<FloatType, IntegerType>(inputs, outputs);
1109 }
1110 
1111 OpFoldResult arith::FPToSIOp::fold(ArrayRef<Attribute> operands) {
1112   Type resType = getType();
1113   Type resEleType;
1114   if (auto shapedType = resType.dyn_cast<ShapedType>())
1115     resEleType = shapedType.getElementType();
1116   else
1117     resEleType = resType;
1118   return constFoldCastOp<FloatAttr, IntegerAttr>(
1119       operands, getType(), [&resEleType](const APFloat &a, bool &castStatus) {
1120         IntegerType intTy = resEleType.cast<IntegerType>();
1121         bool ignored;
1122         APSInt api(intTy.getWidth(), /*isUnsigned=*/false);
1123         castStatus = APFloat::opInvalidOp !=
1124                      a.convertToInteger(api, APFloat::rmTowardZero, &ignored);
1125         return api;
1126       });
1127 }
1128 
1129 //===----------------------------------------------------------------------===//
1130 // IndexCastOp
1131 //===----------------------------------------------------------------------===//
1132 
1133 bool arith::IndexCastOp::areCastCompatible(TypeRange inputs,
1134                                            TypeRange outputs) {
1135   if (!areValidCastInputsAndOutputs(inputs, outputs))
1136     return false;
1137 
1138   auto srcType = getTypeIfLikeOrMemRef<IntegerType, IndexType>(inputs.front());
1139   auto dstType = getTypeIfLikeOrMemRef<IntegerType, IndexType>(outputs.front());
1140   if (!srcType || !dstType)
1141     return false;
1142 
1143   return (srcType.isIndex() && dstType.isSignlessInteger()) ||
1144          (srcType.isSignlessInteger() && dstType.isIndex());
1145 }
1146 
1147 OpFoldResult arith::IndexCastOp::fold(ArrayRef<Attribute> operands) {
1148   // index_cast(constant) -> constant
1149   // A little hack because we go through int. Otherwise, the size of the
1150   // constant might need to change.
1151   if (auto value = operands[0].dyn_cast_or_null<IntegerAttr>())
1152     return IntegerAttr::get(getType(), value.getInt());
1153 
1154   return {};
1155 }
1156 
1157 void arith::IndexCastOp::getCanonicalizationPatterns(
1158     RewritePatternSet &patterns, MLIRContext *context) {
1159   patterns.add<IndexCastOfIndexCast, IndexCastOfExtSI>(context);
1160 }
1161 
1162 //===----------------------------------------------------------------------===//
1163 // BitcastOp
1164 //===----------------------------------------------------------------------===//
1165 
1166 bool arith::BitcastOp::areCastCompatible(TypeRange inputs, TypeRange outputs) {
1167   if (!areValidCastInputsAndOutputs(inputs, outputs))
1168     return false;
1169 
1170   auto srcType =
1171       getTypeIfLikeOrMemRef<IntegerType, IndexType, FloatType>(inputs.front());
1172   auto dstType =
1173       getTypeIfLikeOrMemRef<IntegerType, IndexType, FloatType>(outputs.front());
1174   if (!srcType || !dstType)
1175     return false;
1176 
1177   return srcType.getIntOrFloatBitWidth() == dstType.getIntOrFloatBitWidth();
1178 }
1179 
1180 OpFoldResult arith::BitcastOp::fold(ArrayRef<Attribute> operands) {
1181   assert(operands.size() == 1 && "bitcast op expects 1 operand");
1182 
1183   auto resType = getType();
1184   auto operand = operands[0];
1185   if (!operand)
1186     return {};
1187 
1188   /// Bitcast dense elements.
1189   if (auto denseAttr = operand.dyn_cast_or_null<DenseElementsAttr>())
1190     return denseAttr.bitcast(resType.cast<ShapedType>().getElementType());
1191   /// Other shaped types unhandled.
1192   if (resType.isa<ShapedType>())
1193     return {};
1194 
1195   /// Bitcast integer or float to integer or float.
1196   APInt bits = operand.isa<FloatAttr>()
1197                    ? operand.cast<FloatAttr>().getValue().bitcastToAPInt()
1198                    : operand.cast<IntegerAttr>().getValue();
1199 
1200   if (auto resFloatType = resType.dyn_cast<FloatType>())
1201     return FloatAttr::get(resType,
1202                           APFloat(resFloatType.getFloatSemantics(), bits));
1203   return IntegerAttr::get(resType, bits);
1204 }
1205 
1206 void arith::BitcastOp::getCanonicalizationPatterns(
1207     RewritePatternSet &patterns, MLIRContext *context) {
1208   patterns.add<BitcastOfBitcast>(context);
1209 }
1210 
1211 //===----------------------------------------------------------------------===//
1212 // Helpers for compare ops
1213 //===----------------------------------------------------------------------===//
1214 
1215 /// Return the type of the same shape (scalar, vector or tensor) containing i1.
1216 static Type getI1SameShape(Type type) {
1217   auto i1Type = IntegerType::get(type.getContext(), 1);
1218   if (auto tensorType = type.dyn_cast<RankedTensorType>())
1219     return RankedTensorType::get(tensorType.getShape(), i1Type);
1220   if (type.isa<UnrankedTensorType>())
1221     return UnrankedTensorType::get(i1Type);
1222   if (auto vectorType = type.dyn_cast<VectorType>())
1223     return VectorType::get(vectorType.getShape(), i1Type,
1224                            vectorType.getNumScalableDims());
1225   return i1Type;
1226 }
1227 
1228 //===----------------------------------------------------------------------===//
1229 // CmpIOp
1230 //===----------------------------------------------------------------------===//
1231 
1232 /// Compute `lhs` `pred` `rhs`, where `pred` is one of the known integer
1233 /// comparison predicates.
1234 bool mlir::arith::applyCmpPredicate(arith::CmpIPredicate predicate,
1235                                     const APInt &lhs, const APInt &rhs) {
1236   switch (predicate) {
1237   case arith::CmpIPredicate::eq:
1238     return lhs.eq(rhs);
1239   case arith::CmpIPredicate::ne:
1240     return lhs.ne(rhs);
1241   case arith::CmpIPredicate::slt:
1242     return lhs.slt(rhs);
1243   case arith::CmpIPredicate::sle:
1244     return lhs.sle(rhs);
1245   case arith::CmpIPredicate::sgt:
1246     return lhs.sgt(rhs);
1247   case arith::CmpIPredicate::sge:
1248     return lhs.sge(rhs);
1249   case arith::CmpIPredicate::ult:
1250     return lhs.ult(rhs);
1251   case arith::CmpIPredicate::ule:
1252     return lhs.ule(rhs);
1253   case arith::CmpIPredicate::ugt:
1254     return lhs.ugt(rhs);
1255   case arith::CmpIPredicate::uge:
1256     return lhs.uge(rhs);
1257   }
1258   llvm_unreachable("unknown cmpi predicate kind");
1259 }
1260 
1261 /// Returns true if the predicate is true for two equal operands.
1262 static bool applyCmpPredicateToEqualOperands(arith::CmpIPredicate predicate) {
1263   switch (predicate) {
1264   case arith::CmpIPredicate::eq:
1265   case arith::CmpIPredicate::sle:
1266   case arith::CmpIPredicate::sge:
1267   case arith::CmpIPredicate::ule:
1268   case arith::CmpIPredicate::uge:
1269     return true;
1270   case arith::CmpIPredicate::ne:
1271   case arith::CmpIPredicate::slt:
1272   case arith::CmpIPredicate::sgt:
1273   case arith::CmpIPredicate::ult:
1274   case arith::CmpIPredicate::ugt:
1275     return false;
1276   }
1277   llvm_unreachable("unknown cmpi predicate kind");
1278 }
1279 
1280 static Attribute getBoolAttribute(Type type, MLIRContext *ctx, bool value) {
1281   auto boolAttr = BoolAttr::get(ctx, value);
1282   ShapedType shapedType = type.dyn_cast_or_null<ShapedType>();
1283   if (!shapedType)
1284     return boolAttr;
1285   return DenseElementsAttr::get(shapedType, boolAttr);
1286 }
1287 
1288 OpFoldResult arith::CmpIOp::fold(ArrayRef<Attribute> operands) {
1289   assert(operands.size() == 2 && "cmpi takes two operands");
1290 
1291   // cmpi(pred, x, x)
1292   if (getLhs() == getRhs()) {
1293     auto val = applyCmpPredicateToEqualOperands(getPredicate());
1294     return getBoolAttribute(getType(), getContext(), val);
1295   }
1296 
1297   if (matchPattern(getRhs(), m_Zero())) {
1298     if (auto extOp = getLhs().getDefiningOp<ExtSIOp>()) {
1299       // extsi(%x : i1 -> iN) != 0  ->  %x
1300       if (extOp.getOperand().getType().cast<IntegerType>().getWidth() == 1 &&
1301           getPredicate() == arith::CmpIPredicate::ne)
1302         return extOp.getOperand();
1303     }
1304     if (auto extOp = getLhs().getDefiningOp<ExtUIOp>()) {
1305       // extui(%x : i1 -> iN) != 0  ->  %x
1306       if (extOp.getOperand().getType().cast<IntegerType>().getWidth() == 1 &&
1307           getPredicate() == arith::CmpIPredicate::ne)
1308         return extOp.getOperand();
1309     }
1310   }
1311 
1312   auto lhs = operands.front().dyn_cast_or_null<IntegerAttr>();
1313   auto rhs = operands.back().dyn_cast_or_null<IntegerAttr>();
1314   if (!lhs || !rhs)
1315     return {};
1316 
1317   auto val = applyCmpPredicate(getPredicate(), lhs.getValue(), rhs.getValue());
1318   return BoolAttr::get(getContext(), val);
1319 }
1320 
1321 void arith::CmpIOp::getCanonicalizationPatterns(RewritePatternSet &patterns,
1322                                                 MLIRContext *context) {
1323   patterns.insert<CmpIExtSI, CmpIExtUI>(context);
1324 }
1325 
1326 //===----------------------------------------------------------------------===//
1327 // CmpFOp
1328 //===----------------------------------------------------------------------===//
1329 
1330 /// Compute `lhs` `pred` `rhs`, where `pred` is one of the known floating point
1331 /// comparison predicates.
1332 bool mlir::arith::applyCmpPredicate(arith::CmpFPredicate predicate,
1333                                     const APFloat &lhs, const APFloat &rhs) {
1334   auto cmpResult = lhs.compare(rhs);
1335   switch (predicate) {
1336   case arith::CmpFPredicate::AlwaysFalse:
1337     return false;
1338   case arith::CmpFPredicate::OEQ:
1339     return cmpResult == APFloat::cmpEqual;
1340   case arith::CmpFPredicate::OGT:
1341     return cmpResult == APFloat::cmpGreaterThan;
1342   case arith::CmpFPredicate::OGE:
1343     return cmpResult == APFloat::cmpGreaterThan ||
1344            cmpResult == APFloat::cmpEqual;
1345   case arith::CmpFPredicate::OLT:
1346     return cmpResult == APFloat::cmpLessThan;
1347   case arith::CmpFPredicate::OLE:
1348     return cmpResult == APFloat::cmpLessThan || cmpResult == APFloat::cmpEqual;
1349   case arith::CmpFPredicate::ONE:
1350     return cmpResult != APFloat::cmpUnordered && cmpResult != APFloat::cmpEqual;
1351   case arith::CmpFPredicate::ORD:
1352     return cmpResult != APFloat::cmpUnordered;
1353   case arith::CmpFPredicate::UEQ:
1354     return cmpResult == APFloat::cmpUnordered || cmpResult == APFloat::cmpEqual;
1355   case arith::CmpFPredicate::UGT:
1356     return cmpResult == APFloat::cmpUnordered ||
1357            cmpResult == APFloat::cmpGreaterThan;
1358   case arith::CmpFPredicate::UGE:
1359     return cmpResult == APFloat::cmpUnordered ||
1360            cmpResult == APFloat::cmpGreaterThan ||
1361            cmpResult == APFloat::cmpEqual;
1362   case arith::CmpFPredicate::ULT:
1363     return cmpResult == APFloat::cmpUnordered ||
1364            cmpResult == APFloat::cmpLessThan;
1365   case arith::CmpFPredicate::ULE:
1366     return cmpResult == APFloat::cmpUnordered ||
1367            cmpResult == APFloat::cmpLessThan || cmpResult == APFloat::cmpEqual;
1368   case arith::CmpFPredicate::UNE:
1369     return cmpResult != APFloat::cmpEqual;
1370   case arith::CmpFPredicate::UNO:
1371     return cmpResult == APFloat::cmpUnordered;
1372   case arith::CmpFPredicate::AlwaysTrue:
1373     return true;
1374   }
1375   llvm_unreachable("unknown cmpf predicate kind");
1376 }
1377 
1378 OpFoldResult arith::CmpFOp::fold(ArrayRef<Attribute> operands) {
1379   assert(operands.size() == 2 && "cmpf takes two operands");
1380 
1381   auto lhs = operands.front().dyn_cast_or_null<FloatAttr>();
1382   auto rhs = operands.back().dyn_cast_or_null<FloatAttr>();
1383 
1384   // If one operand is NaN, making them both NaN does not change the result.
1385   if (lhs && lhs.getValue().isNaN())
1386     rhs = lhs;
1387   if (rhs && rhs.getValue().isNaN())
1388     lhs = rhs;
1389 
1390   if (!lhs || !rhs)
1391     return {};
1392 
1393   auto val = applyCmpPredicate(getPredicate(), lhs.getValue(), rhs.getValue());
1394   return BoolAttr::get(getContext(), val);
1395 }
1396 
1397 class CmpFIntToFPConst final : public OpRewritePattern<CmpFOp> {
1398 public:
1399   using OpRewritePattern<CmpFOp>::OpRewritePattern;
1400 
1401   static CmpIPredicate convertToIntegerPredicate(CmpFPredicate pred,
1402                                                  bool isUnsigned) {
1403     using namespace arith;
1404     switch (pred) {
1405     case CmpFPredicate::UEQ:
1406     case CmpFPredicate::OEQ:
1407       return CmpIPredicate::eq;
1408     case CmpFPredicate::UGT:
1409     case CmpFPredicate::OGT:
1410       return isUnsigned ? CmpIPredicate::ugt : CmpIPredicate::sgt;
1411     case CmpFPredicate::UGE:
1412     case CmpFPredicate::OGE:
1413       return isUnsigned ? CmpIPredicate::uge : CmpIPredicate::sge;
1414     case CmpFPredicate::ULT:
1415     case CmpFPredicate::OLT:
1416       return isUnsigned ? CmpIPredicate::ult : CmpIPredicate::slt;
1417     case CmpFPredicate::ULE:
1418     case CmpFPredicate::OLE:
1419       return isUnsigned ? CmpIPredicate::ule : CmpIPredicate::sle;
1420     case CmpFPredicate::UNE:
1421     case CmpFPredicate::ONE:
1422       return CmpIPredicate::ne;
1423     default:
1424       llvm_unreachable("Unexpected predicate!");
1425     }
1426   }
1427 
1428   LogicalResult matchAndRewrite(CmpFOp op,
1429                                 PatternRewriter &rewriter) const override {
1430     FloatAttr flt;
1431     if (!matchPattern(op.getRhs(), m_Constant(&flt)))
1432       return failure();
1433 
1434     const APFloat &rhs = flt.getValue();
1435 
1436     // Don't attempt to fold a nan.
1437     if (rhs.isNaN())
1438       return failure();
1439 
1440     // Get the width of the mantissa.  We don't want to hack on conversions that
1441     // might lose information from the integer, e.g. "i64 -> float"
1442     FloatType floatTy = op.getRhs().getType().cast<FloatType>();
1443     int mantissaWidth = floatTy.getFPMantissaWidth();
1444     if (mantissaWidth <= 0)
1445       return failure();
1446 
1447     bool isUnsigned;
1448     Value intVal;
1449 
1450     if (auto si = op.getLhs().getDefiningOp<SIToFPOp>()) {
1451       isUnsigned = false;
1452       intVal = si.getIn();
1453     } else if (auto ui = op.getLhs().getDefiningOp<UIToFPOp>()) {
1454       isUnsigned = true;
1455       intVal = ui.getIn();
1456     } else {
1457       return failure();
1458     }
1459 
1460     // Check to see that the input is converted from an integer type that is
1461     // small enough that preserves all bits.
1462     auto intTy = intVal.getType().cast<IntegerType>();
1463     auto intWidth = intTy.getWidth();
1464 
1465     // Number of bits representing values, as opposed to the sign
1466     auto valueBits = isUnsigned ? intWidth : (intWidth - 1);
1467 
1468     // Following test does NOT adjust intWidth downwards for signed inputs,
1469     // because the most negative value still requires all the mantissa bits
1470     // to distinguish it from one less than that value.
1471     if ((int)intWidth > mantissaWidth) {
1472       // Conversion would lose accuracy. Check if loss can impact comparison.
1473       int exponent = ilogb(rhs);
1474       if (exponent == APFloat::IEK_Inf) {
1475         int maxExponent = ilogb(APFloat::getLargest(rhs.getSemantics()));
1476         if (maxExponent < (int)valueBits) {
1477           // Conversion could create infinity.
1478           return failure();
1479         }
1480       } else {
1481         // Note that if rhs is zero or NaN, then Exp is negative
1482         // and first condition is trivially false.
1483         if (mantissaWidth <= exponent && exponent <= (int)valueBits) {
1484           // Conversion could affect comparison.
1485           return failure();
1486         }
1487       }
1488     }
1489 
1490     // Convert to equivalent cmpi predicate
1491     CmpIPredicate pred;
1492     switch (op.getPredicate()) {
1493     case CmpFPredicate::ORD:
1494       // Int to fp conversion doesn't create a nan (ord checks neither is a nan)
1495       rewriter.replaceOpWithNewOp<ConstantIntOp>(op, /*value=*/true,
1496                                                  /*width=*/1);
1497       return success();
1498     case CmpFPredicate::UNO:
1499       // Int to fp conversion doesn't create a nan (uno checks either is a nan)
1500       rewriter.replaceOpWithNewOp<ConstantIntOp>(op, /*value=*/false,
1501                                                  /*width=*/1);
1502       return success();
1503     default:
1504       pred = convertToIntegerPredicate(op.getPredicate(), isUnsigned);
1505       break;
1506     }
1507 
1508     if (!isUnsigned) {
1509       // If the rhs value is > SignedMax, fold the comparison.  This handles
1510       // +INF and large values.
1511       APFloat signedMax(rhs.getSemantics());
1512       signedMax.convertFromAPInt(APInt::getSignedMaxValue(intWidth), true,
1513                                  APFloat::rmNearestTiesToEven);
1514       if (signedMax < rhs) { // smax < 13123.0
1515         if (pred == CmpIPredicate::ne || pred == CmpIPredicate::slt ||
1516             pred == CmpIPredicate::sle)
1517           rewriter.replaceOpWithNewOp<ConstantIntOp>(op, /*value=*/true,
1518                                                      /*width=*/1);
1519         else
1520           rewriter.replaceOpWithNewOp<ConstantIntOp>(op, /*value=*/false,
1521                                                      /*width=*/1);
1522         return success();
1523       }
1524     } else {
1525       // If the rhs value is > UnsignedMax, fold the comparison. This handles
1526       // +INF and large values.
1527       APFloat unsignedMax(rhs.getSemantics());
1528       unsignedMax.convertFromAPInt(APInt::getMaxValue(intWidth), false,
1529                                    APFloat::rmNearestTiesToEven);
1530       if (unsignedMax < rhs) { // umax < 13123.0
1531         if (pred == CmpIPredicate::ne || pred == CmpIPredicate::ult ||
1532             pred == CmpIPredicate::ule)
1533           rewriter.replaceOpWithNewOp<ConstantIntOp>(op, /*value=*/true,
1534                                                      /*width=*/1);
1535         else
1536           rewriter.replaceOpWithNewOp<ConstantIntOp>(op, /*value=*/false,
1537                                                      /*width=*/1);
1538         return success();
1539       }
1540     }
1541 
1542     if (!isUnsigned) {
1543       // See if the rhs value is < SignedMin.
1544       APFloat signedMin(rhs.getSemantics());
1545       signedMin.convertFromAPInt(APInt::getSignedMinValue(intWidth), true,
1546                                  APFloat::rmNearestTiesToEven);
1547       if (signedMin > rhs) { // smin > 12312.0
1548         if (pred == CmpIPredicate::ne || pred == CmpIPredicate::sgt ||
1549             pred == CmpIPredicate::sge)
1550           rewriter.replaceOpWithNewOp<ConstantIntOp>(op, /*value=*/true,
1551                                                      /*width=*/1);
1552         else
1553           rewriter.replaceOpWithNewOp<ConstantIntOp>(op, /*value=*/false,
1554                                                      /*width=*/1);
1555         return success();
1556       }
1557     } else {
1558       // See if the rhs value is < UnsignedMin.
1559       APFloat unsignedMin(rhs.getSemantics());
1560       unsignedMin.convertFromAPInt(APInt::getMinValue(intWidth), false,
1561                                    APFloat::rmNearestTiesToEven);
1562       if (unsignedMin > rhs) { // umin > 12312.0
1563         if (pred == CmpIPredicate::ne || pred == CmpIPredicate::ugt ||
1564             pred == CmpIPredicate::uge)
1565           rewriter.replaceOpWithNewOp<ConstantIntOp>(op, /*value=*/true,
1566                                                      /*width=*/1);
1567         else
1568           rewriter.replaceOpWithNewOp<ConstantIntOp>(op, /*value=*/false,
1569                                                      /*width=*/1);
1570         return success();
1571       }
1572     }
1573 
1574     // Okay, now we know that the FP constant fits in the range [SMIN, SMAX] or
1575     // [0, UMAX], but it may still be fractional.  See if it is fractional by
1576     // casting the FP value to the integer value and back, checking for
1577     // equality. Don't do this for zero, because -0.0 is not fractional.
1578     bool ignored;
1579     APSInt rhsInt(intWidth, isUnsigned);
1580     if (APFloat::opInvalidOp ==
1581         rhs.convertToInteger(rhsInt, APFloat::rmTowardZero, &ignored)) {
1582       // Undefined behavior invoked - the destination type can't represent
1583       // the input constant.
1584       return failure();
1585     }
1586 
1587     if (!rhs.isZero()) {
1588       APFloat apf(floatTy.getFloatSemantics(),
1589                   APInt::getZero(floatTy.getWidth()));
1590       apf.convertFromAPInt(rhsInt, !isUnsigned, APFloat::rmNearestTiesToEven);
1591 
1592       bool equal = apf == rhs;
1593       if (!equal) {
1594         // If we had a comparison against a fractional value, we have to adjust
1595         // the compare predicate and sometimes the value.  rhsInt is rounded
1596         // towards zero at this point.
1597         switch (pred) {
1598         case CmpIPredicate::ne: // (float)int != 4.4   --> true
1599           rewriter.replaceOpWithNewOp<ConstantIntOp>(op, /*value=*/true,
1600                                                      /*width=*/1);
1601           return success();
1602         case CmpIPredicate::eq: // (float)int == 4.4   --> false
1603           rewriter.replaceOpWithNewOp<ConstantIntOp>(op, /*value=*/false,
1604                                                      /*width=*/1);
1605           return success();
1606         case CmpIPredicate::ule:
1607           // (float)int <= 4.4   --> int <= 4
1608           // (float)int <= -4.4  --> false
1609           if (rhs.isNegative()) {
1610             rewriter.replaceOpWithNewOp<ConstantIntOp>(op, /*value=*/false,
1611                                                        /*width=*/1);
1612             return success();
1613           }
1614           break;
1615         case CmpIPredicate::sle:
1616           // (float)int <= 4.4   --> int <= 4
1617           // (float)int <= -4.4  --> int < -4
1618           if (rhs.isNegative())
1619             pred = CmpIPredicate::slt;
1620           break;
1621         case CmpIPredicate::ult:
1622           // (float)int < -4.4   --> false
1623           // (float)int < 4.4    --> int <= 4
1624           if (rhs.isNegative()) {
1625             rewriter.replaceOpWithNewOp<ConstantIntOp>(op, /*value=*/false,
1626                                                        /*width=*/1);
1627             return success();
1628           }
1629           pred = CmpIPredicate::ule;
1630           break;
1631         case CmpIPredicate::slt:
1632           // (float)int < -4.4   --> int < -4
1633           // (float)int < 4.4    --> int <= 4
1634           if (!rhs.isNegative())
1635             pred = CmpIPredicate::sle;
1636           break;
1637         case CmpIPredicate::ugt:
1638           // (float)int > 4.4    --> int > 4
1639           // (float)int > -4.4   --> true
1640           if (rhs.isNegative()) {
1641             rewriter.replaceOpWithNewOp<ConstantIntOp>(op, /*value=*/true,
1642                                                        /*width=*/1);
1643             return success();
1644           }
1645           break;
1646         case CmpIPredicate::sgt:
1647           // (float)int > 4.4    --> int > 4
1648           // (float)int > -4.4   --> int >= -4
1649           if (rhs.isNegative())
1650             pred = CmpIPredicate::sge;
1651           break;
1652         case CmpIPredicate::uge:
1653           // (float)int >= -4.4   --> true
1654           // (float)int >= 4.4    --> int > 4
1655           if (rhs.isNegative()) {
1656             rewriter.replaceOpWithNewOp<ConstantIntOp>(op, /*value=*/true,
1657                                                        /*width=*/1);
1658             return success();
1659           }
1660           pred = CmpIPredicate::ugt;
1661           break;
1662         case CmpIPredicate::sge:
1663           // (float)int >= -4.4   --> int >= -4
1664           // (float)int >= 4.4    --> int > 4
1665           if (!rhs.isNegative())
1666             pred = CmpIPredicate::sgt;
1667           break;
1668         }
1669       }
1670     }
1671 
1672     // Lower this FP comparison into an appropriate integer version of the
1673     // comparison.
1674     rewriter.replaceOpWithNewOp<CmpIOp>(
1675         op, pred, intVal,
1676         rewriter.create<ConstantOp>(
1677             op.getLoc(), intVal.getType(),
1678             rewriter.getIntegerAttr(intVal.getType(), rhsInt)));
1679     return success();
1680   }
1681 };
1682 
1683 void arith::CmpFOp::getCanonicalizationPatterns(RewritePatternSet &patterns,
1684                                                 MLIRContext *context) {
1685   patterns.insert<CmpFIntToFPConst>(context);
1686 }
1687 
1688 //===----------------------------------------------------------------------===//
1689 // SelectOp
1690 //===----------------------------------------------------------------------===//
1691 
1692 // Transforms a select of a boolean to arithmetic operations
1693 //
1694 //  arith.select %arg, %x, %y : i1
1695 //
1696 //  becomes
1697 //
1698 //  and(%arg, %x) or and(!%arg, %y)
1699 struct SelectI1Simplify : public OpRewritePattern<arith::SelectOp> {
1700   using OpRewritePattern<arith::SelectOp>::OpRewritePattern;
1701 
1702   LogicalResult matchAndRewrite(arith::SelectOp op,
1703                                 PatternRewriter &rewriter) const override {
1704     if (!op.getType().isInteger(1))
1705       return failure();
1706 
1707     Value falseConstant =
1708         rewriter.create<arith::ConstantIntOp>(op.getLoc(), true, 1);
1709     Value notCondition = rewriter.create<arith::XOrIOp>(
1710         op.getLoc(), op.getCondition(), falseConstant);
1711 
1712     Value trueVal = rewriter.create<arith::AndIOp>(
1713         op.getLoc(), op.getCondition(), op.getTrueValue());
1714     Value falseVal = rewriter.create<arith::AndIOp>(op.getLoc(), notCondition,
1715                                                     op.getFalseValue());
1716     rewriter.replaceOpWithNewOp<arith::OrIOp>(op, trueVal, falseVal);
1717     return success();
1718   }
1719 };
1720 
1721 //  select %arg, %c1, %c0 => extui %arg
1722 struct SelectToExtUI : public OpRewritePattern<arith::SelectOp> {
1723   using OpRewritePattern<arith::SelectOp>::OpRewritePattern;
1724 
1725   LogicalResult matchAndRewrite(arith::SelectOp op,
1726                                 PatternRewriter &rewriter) const override {
1727     // Cannot extui i1 to i1, or i1 to f32
1728     if (!op.getType().isa<IntegerType>() || op.getType().isInteger(1))
1729       return failure();
1730 
1731     // select %x, c1, %c0 => extui %arg
1732     if (matchPattern(op.getTrueValue(), m_One()) &&
1733         matchPattern(op.getFalseValue(), m_Zero())) {
1734       rewriter.replaceOpWithNewOp<arith::ExtUIOp>(op, op.getType(),
1735                                                   op.getCondition());
1736       return success();
1737     }
1738 
1739     // select %x, c0, %c1 => extui (xor %arg, true)
1740     if (matchPattern(op.getTrueValue(), m_Zero()) &&
1741         matchPattern(op.getFalseValue(), m_One())) {
1742       rewriter.replaceOpWithNewOp<arith::ExtUIOp>(
1743           op, op.getType(),
1744           rewriter.create<arith::XOrIOp>(
1745               op.getLoc(), op.getCondition(),
1746               rewriter.create<arith::ConstantIntOp>(
1747                   op.getLoc(), 1, op.getCondition().getType())));
1748       return success();
1749     }
1750 
1751     return failure();
1752   }
1753 };
1754 
1755 void arith::SelectOp::getCanonicalizationPatterns(RewritePatternSet &results,
1756                                                   MLIRContext *context) {
1757   results.add<SelectI1Simplify, SelectToExtUI>(context);
1758 }
1759 
1760 OpFoldResult arith::SelectOp::fold(ArrayRef<Attribute> operands) {
1761   Value trueVal = getTrueValue();
1762   Value falseVal = getFalseValue();
1763   if (trueVal == falseVal)
1764     return trueVal;
1765 
1766   Value condition = getCondition();
1767 
1768   // select true, %0, %1 => %0
1769   if (matchPattern(condition, m_One()))
1770     return trueVal;
1771 
1772   // select false, %0, %1 => %1
1773   if (matchPattern(condition, m_Zero()))
1774     return falseVal;
1775 
1776   // select %x, true, false => %x
1777   if (getType().isInteger(1) && matchPattern(getTrueValue(), m_One()) &&
1778       matchPattern(getFalseValue(), m_Zero()))
1779     return condition;
1780 
1781   if (auto cmp = dyn_cast_or_null<arith::CmpIOp>(condition.getDefiningOp())) {
1782     auto pred = cmp.getPredicate();
1783     if (pred == arith::CmpIPredicate::eq || pred == arith::CmpIPredicate::ne) {
1784       auto cmpLhs = cmp.getLhs();
1785       auto cmpRhs = cmp.getRhs();
1786 
1787       // %0 = arith.cmpi eq, %arg0, %arg1
1788       // %1 = arith.select %0, %arg0, %arg1 => %arg1
1789 
1790       // %0 = arith.cmpi ne, %arg0, %arg1
1791       // %1 = arith.select %0, %arg0, %arg1 => %arg0
1792 
1793       if ((cmpLhs == trueVal && cmpRhs == falseVal) ||
1794           (cmpRhs == trueVal && cmpLhs == falseVal))
1795         return pred == arith::CmpIPredicate::ne ? trueVal : falseVal;
1796     }
1797   }
1798   return nullptr;
1799 }
1800 
1801 ParseResult SelectOp::parse(OpAsmParser &parser, OperationState &result) {
1802   Type conditionType, resultType;
1803   SmallVector<OpAsmParser::UnresolvedOperand, 3> operands;
1804   if (parser.parseOperandList(operands, /*requiredOperandCount=*/3) ||
1805       parser.parseOptionalAttrDict(result.attributes) ||
1806       parser.parseColonType(resultType))
1807     return failure();
1808 
1809   // Check for the explicit condition type if this is a masked tensor or vector.
1810   if (succeeded(parser.parseOptionalComma())) {
1811     conditionType = resultType;
1812     if (parser.parseType(resultType))
1813       return failure();
1814   } else {
1815     conditionType = parser.getBuilder().getI1Type();
1816   }
1817 
1818   result.addTypes(resultType);
1819   return parser.resolveOperands(operands,
1820                                 {conditionType, resultType, resultType},
1821                                 parser.getNameLoc(), result.operands);
1822 }
1823 
1824 void arith::SelectOp::print(OpAsmPrinter &p) {
1825   p << " " << getOperands();
1826   p.printOptionalAttrDict((*this)->getAttrs());
1827   p << " : ";
1828   if (ShapedType condType = getCondition().getType().dyn_cast<ShapedType>())
1829     p << condType << ", ";
1830   p << getType();
1831 }
1832 
1833 LogicalResult arith::SelectOp::verify() {
1834   Type conditionType = getCondition().getType();
1835   if (conditionType.isSignlessInteger(1))
1836     return success();
1837 
1838   // If the result type is a vector or tensor, the type can be a mask with the
1839   // same elements.
1840   Type resultType = getType();
1841   if (!resultType.isa<TensorType, VectorType>())
1842     return emitOpError() << "expected condition to be a signless i1, but got "
1843                          << conditionType;
1844   Type shapedConditionType = getI1SameShape(resultType);
1845   if (conditionType != shapedConditionType) {
1846     return emitOpError() << "expected condition type to have the same shape "
1847                             "as the result type, expected "
1848                          << shapedConditionType << ", but got "
1849                          << conditionType;
1850   }
1851   return success();
1852 }
1853 //===----------------------------------------------------------------------===//
1854 // ShLIOp
1855 //===----------------------------------------------------------------------===//
1856 
1857 OpFoldResult arith::ShLIOp::fold(ArrayRef<Attribute> operands) {
1858   // Don't fold if shifting more than the bit width.
1859   bool bounded = false;
1860   auto result = constFoldBinaryOp<IntegerAttr>(
1861       operands, [&](const APInt &a, const APInt &b) {
1862         bounded = b.ule(b.getBitWidth());
1863         return a.shl(b);
1864       });
1865   return bounded ? result : Attribute();
1866 }
1867 
1868 //===----------------------------------------------------------------------===//
1869 // ShRUIOp
1870 //===----------------------------------------------------------------------===//
1871 
1872 OpFoldResult arith::ShRUIOp::fold(ArrayRef<Attribute> operands) {
1873   // Don't fold if shifting more than the bit width.
1874   bool bounded = false;
1875   auto result = constFoldBinaryOp<IntegerAttr>(
1876       operands, [&](const APInt &a, const APInt &b) {
1877         bounded = b.ule(b.getBitWidth());
1878         return a.lshr(b);
1879       });
1880   return bounded ? result : Attribute();
1881 }
1882 
1883 //===----------------------------------------------------------------------===//
1884 // ShRSIOp
1885 //===----------------------------------------------------------------------===//
1886 
1887 OpFoldResult arith::ShRSIOp::fold(ArrayRef<Attribute> operands) {
1888   // Don't fold if shifting more than the bit width.
1889   bool bounded = false;
1890   auto result = constFoldBinaryOp<IntegerAttr>(
1891       operands, [&](const APInt &a, const APInt &b) {
1892         bounded = b.ule(b.getBitWidth());
1893         return a.ashr(b);
1894       });
1895   return bounded ? result : Attribute();
1896 }
1897 
1898 //===----------------------------------------------------------------------===//
1899 // Atomic Enum
1900 //===----------------------------------------------------------------------===//
1901 
1902 /// Returns the identity value attribute associated with an AtomicRMWKind op.
1903 Attribute mlir::arith::getIdentityValueAttr(AtomicRMWKind kind, Type resultType,
1904                                             OpBuilder &builder, Location loc) {
1905   switch (kind) {
1906   case AtomicRMWKind::maxf:
1907     return builder.getFloatAttr(
1908         resultType,
1909         APFloat::getInf(resultType.cast<FloatType>().getFloatSemantics(),
1910                         /*Negative=*/true));
1911   case AtomicRMWKind::addf:
1912   case AtomicRMWKind::addi:
1913   case AtomicRMWKind::maxu:
1914   case AtomicRMWKind::ori:
1915     return builder.getZeroAttr(resultType);
1916   case AtomicRMWKind::andi:
1917     return builder.getIntegerAttr(
1918         resultType,
1919         APInt::getAllOnes(resultType.cast<IntegerType>().getWidth()));
1920   case AtomicRMWKind::maxs:
1921     return builder.getIntegerAttr(
1922         resultType,
1923         APInt::getSignedMinValue(resultType.cast<IntegerType>().getWidth()));
1924   case AtomicRMWKind::minf:
1925     return builder.getFloatAttr(
1926         resultType,
1927         APFloat::getInf(resultType.cast<FloatType>().getFloatSemantics(),
1928                         /*Negative=*/false));
1929   case AtomicRMWKind::mins:
1930     return builder.getIntegerAttr(
1931         resultType,
1932         APInt::getSignedMaxValue(resultType.cast<IntegerType>().getWidth()));
1933   case AtomicRMWKind::minu:
1934     return builder.getIntegerAttr(
1935         resultType,
1936         APInt::getMaxValue(resultType.cast<IntegerType>().getWidth()));
1937   case AtomicRMWKind::muli:
1938     return builder.getIntegerAttr(resultType, 1);
1939   case AtomicRMWKind::mulf:
1940     return builder.getFloatAttr(resultType, 1);
1941   // TODO: Add remaining reduction operations.
1942   default:
1943     (void)emitOptionalError(loc, "Reduction operation type not supported");
1944     break;
1945   }
1946   return nullptr;
1947 }
1948 
1949 /// Returns the identity value associated with an AtomicRMWKind op.
1950 Value mlir::arith::getIdentityValue(AtomicRMWKind op, Type resultType,
1951                                     OpBuilder &builder, Location loc) {
1952   Attribute attr = getIdentityValueAttr(op, resultType, builder, loc);
1953   return builder.create<arith::ConstantOp>(loc, attr);
1954 }
1955 
1956 /// Return the value obtained by applying the reduction operation kind
1957 /// associated with a binary AtomicRMWKind op to `lhs` and `rhs`.
1958 Value mlir::arith::getReductionOp(AtomicRMWKind op, OpBuilder &builder,
1959                                   Location loc, Value lhs, Value rhs) {
1960   switch (op) {
1961   case AtomicRMWKind::addf:
1962     return builder.create<arith::AddFOp>(loc, lhs, rhs);
1963   case AtomicRMWKind::addi:
1964     return builder.create<arith::AddIOp>(loc, lhs, rhs);
1965   case AtomicRMWKind::mulf:
1966     return builder.create<arith::MulFOp>(loc, lhs, rhs);
1967   case AtomicRMWKind::muli:
1968     return builder.create<arith::MulIOp>(loc, lhs, rhs);
1969   case AtomicRMWKind::maxf:
1970     return builder.create<arith::MaxFOp>(loc, lhs, rhs);
1971   case AtomicRMWKind::minf:
1972     return builder.create<arith::MinFOp>(loc, lhs, rhs);
1973   case AtomicRMWKind::maxs:
1974     return builder.create<arith::MaxSIOp>(loc, lhs, rhs);
1975   case AtomicRMWKind::mins:
1976     return builder.create<arith::MinSIOp>(loc, lhs, rhs);
1977   case AtomicRMWKind::maxu:
1978     return builder.create<arith::MaxUIOp>(loc, lhs, rhs);
1979   case AtomicRMWKind::minu:
1980     return builder.create<arith::MinUIOp>(loc, lhs, rhs);
1981   case AtomicRMWKind::ori:
1982     return builder.create<arith::OrIOp>(loc, lhs, rhs);
1983   case AtomicRMWKind::andi:
1984     return builder.create<arith::AndIOp>(loc, lhs, rhs);
1985   // TODO: Add remaining reduction operations.
1986   default:
1987     (void)emitOptionalError(loc, "Reduction operation type not supported");
1988     break;
1989   }
1990   return nullptr;
1991 }
1992 
1993 //===----------------------------------------------------------------------===//
1994 // TableGen'd op method definitions
1995 //===----------------------------------------------------------------------===//
1996 
1997 #define GET_OP_CLASSES
1998 #include "mlir/Dialect/Arithmetic/IR/ArithmeticOps.cpp.inc"
1999 
2000 //===----------------------------------------------------------------------===//
2001 // TableGen'd enum attribute definitions
2002 //===----------------------------------------------------------------------===//
2003 
2004 #include "mlir/Dialect/Arithmetic/IR/ArithmeticOpsEnums.cpp.inc"
2005