1 //===- ArithmeticOps.cpp - MLIR Arithmetic dialect ops implementation -----===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 
9 #include <utility>
10 
11 #include "mlir/Dialect/Arithmetic/IR/Arithmetic.h"
12 #include "mlir/Dialect/CommonFolders.h"
13 #include "mlir/IR/Builders.h"
14 #include "mlir/IR/Matchers.h"
15 #include "mlir/IR/OpImplementation.h"
16 #include "mlir/IR/PatternMatch.h"
17 #include "mlir/IR/TypeUtilities.h"
18 #include "llvm/ADT/SmallString.h"
19 
20 #include "llvm/ADT/APSInt.h"
21 
22 using namespace mlir;
23 using namespace mlir::arith;
24 
25 //===----------------------------------------------------------------------===//
26 // Pattern helpers
27 //===----------------------------------------------------------------------===//
28 
29 static IntegerAttr addIntegerAttrs(PatternRewriter &builder, Value res,
30                                    Attribute lhs, Attribute rhs) {
31   return builder.getIntegerAttr(res.getType(),
32                                 lhs.cast<IntegerAttr>().getInt() +
33                                     rhs.cast<IntegerAttr>().getInt());
34 }
35 
36 static IntegerAttr subIntegerAttrs(PatternRewriter &builder, Value res,
37                                    Attribute lhs, Attribute rhs) {
38   return builder.getIntegerAttr(res.getType(),
39                                 lhs.cast<IntegerAttr>().getInt() -
40                                     rhs.cast<IntegerAttr>().getInt());
41 }
42 
43 /// Invert an integer comparison predicate.
44 arith::CmpIPredicate arith::invertPredicate(arith::CmpIPredicate pred) {
45   switch (pred) {
46   case arith::CmpIPredicate::eq:
47     return arith::CmpIPredicate::ne;
48   case arith::CmpIPredicate::ne:
49     return arith::CmpIPredicate::eq;
50   case arith::CmpIPredicate::slt:
51     return arith::CmpIPredicate::sge;
52   case arith::CmpIPredicate::sle:
53     return arith::CmpIPredicate::sgt;
54   case arith::CmpIPredicate::sgt:
55     return arith::CmpIPredicate::sle;
56   case arith::CmpIPredicate::sge:
57     return arith::CmpIPredicate::slt;
58   case arith::CmpIPredicate::ult:
59     return arith::CmpIPredicate::uge;
60   case arith::CmpIPredicate::ule:
61     return arith::CmpIPredicate::ugt;
62   case arith::CmpIPredicate::ugt:
63     return arith::CmpIPredicate::ule;
64   case arith::CmpIPredicate::uge:
65     return arith::CmpIPredicate::ult;
66   }
67   llvm_unreachable("unknown cmpi predicate kind");
68 }
69 
70 static arith::CmpIPredicateAttr invertPredicate(arith::CmpIPredicateAttr pred) {
71   return arith::CmpIPredicateAttr::get(pred.getContext(),
72                                        invertPredicate(pred.getValue()));
73 }
74 
75 //===----------------------------------------------------------------------===//
76 // TableGen'd canonicalization patterns
77 //===----------------------------------------------------------------------===//
78 
// The generated file is included inside an anonymous namespace so that
// everything it declares has internal linkage in this translation unit.
namespace {
#include "ArithmeticCanonicalization.inc"
} // namespace
82 
83 //===----------------------------------------------------------------------===//
84 // ConstantOp
85 //===----------------------------------------------------------------------===//
86 
87 void arith::ConstantOp::getAsmResultNames(
88     function_ref<void(Value, StringRef)> setNameFn) {
89   auto type = getType();
90   if (auto intCst = getValue().dyn_cast<IntegerAttr>()) {
91     auto intType = type.dyn_cast<IntegerType>();
92 
93     // Sugar i1 constants with 'true' and 'false'.
94     if (intType && intType.getWidth() == 1)
95       return setNameFn(getResult(), (intCst.getInt() ? "true" : "false"));
96 
97     // Otherwise, build a compex name with the value and type.
98     SmallString<32> specialNameBuffer;
99     llvm::raw_svector_ostream specialName(specialNameBuffer);
100     specialName << 'c' << intCst.getInt();
101     if (intType)
102       specialName << '_' << type;
103     setNameFn(getResult(), specialName.str());
104   } else {
105     setNameFn(getResult(), "cst");
106   }
107 }
108 
109 /// TODO: disallow arith.constant to return anything other than signless integer
110 /// or float like.
111 LogicalResult arith::ConstantOp::verify() {
112   auto type = getType();
113   // The value's type must match the return type.
114   if (getValue().getType() != type) {
115     return emitOpError() << "value type " << getValue().getType()
116                          << " must match return type: " << type;
117   }
118   // Integer values must be signless.
119   if (type.isa<IntegerType>() && !type.cast<IntegerType>().isSignless())
120     return emitOpError("integer return type must be signless");
121   // Any float or elements attribute are acceptable.
122   if (!getValue().isa<IntegerAttr, FloatAttr, ElementsAttr>()) {
123     return emitOpError(
124         "value must be an integer, float, or elements attribute");
125   }
126   return success();
127 }
128 
129 bool arith::ConstantOp::isBuildableWith(Attribute value, Type type) {
130   // The value's type must be the same as the provided type.
131   if (value.getType() != type)
132     return false;
133   // Integer values must be signless.
134   if (type.isa<IntegerType>() && !type.cast<IntegerType>().isSignless())
135     return false;
136   // Integer, float, and element attributes are buildable.
137   return value.isa<IntegerAttr, FloatAttr, ElementsAttr>();
138 }
139 
140 OpFoldResult arith::ConstantOp::fold(ArrayRef<Attribute> operands) {
141   return getValue();
142 }
143 
144 void arith::ConstantIntOp::build(OpBuilder &builder, OperationState &result,
145                                  int64_t value, unsigned width) {
146   auto type = builder.getIntegerType(width);
147   arith::ConstantOp::build(builder, result, type,
148                            builder.getIntegerAttr(type, value));
149 }
150 
151 void arith::ConstantIntOp::build(OpBuilder &builder, OperationState &result,
152                                  int64_t value, Type type) {
153   assert(type.isSignlessInteger() &&
154          "ConstantIntOp can only have signless integer type values");
155   arith::ConstantOp::build(builder, result, type,
156                            builder.getIntegerAttr(type, value));
157 }
158 
159 bool arith::ConstantIntOp::classof(Operation *op) {
160   if (auto constOp = dyn_cast_or_null<arith::ConstantOp>(op))
161     return constOp.getType().isSignlessInteger();
162   return false;
163 }
164 
165 void arith::ConstantFloatOp::build(OpBuilder &builder, OperationState &result,
166                                    const APFloat &value, FloatType type) {
167   arith::ConstantOp::build(builder, result, type,
168                            builder.getFloatAttr(type, value));
169 }
170 
171 bool arith::ConstantFloatOp::classof(Operation *op) {
172   if (auto constOp = dyn_cast_or_null<arith::ConstantOp>(op))
173     return constOp.getType().isa<FloatType>();
174   return false;
175 }
176 
177 void arith::ConstantIndexOp::build(OpBuilder &builder, OperationState &result,
178                                    int64_t value) {
179   arith::ConstantOp::build(builder, result, builder.getIndexType(),
180                            builder.getIndexAttr(value));
181 }
182 
183 bool arith::ConstantIndexOp::classof(Operation *op) {
184   if (auto constOp = dyn_cast_or_null<arith::ConstantOp>(op))
185     return constOp.getType().isIndex();
186   return false;
187 }
188 
189 //===----------------------------------------------------------------------===//
190 // AddIOp
191 //===----------------------------------------------------------------------===//
192 
193 OpFoldResult arith::AddIOp::fold(ArrayRef<Attribute> operands) {
194   // addi(x, 0) -> x
195   if (matchPattern(getRhs(), m_Zero()))
196     return getLhs();
197 
198   // addi(subi(a, b), b) -> a
199   if (auto sub = getLhs().getDefiningOp<SubIOp>())
200     if (getRhs() == sub.getRhs())
201       return sub.getLhs();
202 
203   // addi(b, subi(a, b)) -> a
204   if (auto sub = getRhs().getDefiningOp<SubIOp>())
205     if (getLhs() == sub.getRhs())
206       return sub.getLhs();
207 
208   return constFoldBinaryOp<IntegerAttr>(
209       operands, [](APInt a, const APInt &b) { return std::move(a) + b; });
210 }
211 
212 void arith::AddIOp::getCanonicalizationPatterns(
213     RewritePatternSet &patterns, MLIRContext *context) {
214   patterns.add<AddIAddConstant, AddISubConstantRHS, AddISubConstantLHS>(
215       context);
216 }
217 
218 //===----------------------------------------------------------------------===//
219 // SubIOp
220 //===----------------------------------------------------------------------===//
221 
222 OpFoldResult arith::SubIOp::fold(ArrayRef<Attribute> operands) {
223   // subi(x,x) -> 0
224   if (getOperand(0) == getOperand(1))
225     return Builder(getContext()).getZeroAttr(getType());
226   // subi(x,0) -> x
227   if (matchPattern(getRhs(), m_Zero()))
228     return getLhs();
229 
230   return constFoldBinaryOp<IntegerAttr>(
231       operands, [](APInt a, const APInt &b) { return std::move(a) - b; });
232 }
233 
234 void arith::SubIOp::getCanonicalizationPatterns(
235     RewritePatternSet &patterns, MLIRContext *context) {
236   patterns
237       .add<SubIRHSAddConstant, SubILHSAddConstant, SubIRHSSubConstantRHS,
238            SubIRHSSubConstantLHS, SubILHSSubConstantRHS, SubILHSSubConstantLHS>(
239           context);
240 }
241 
242 //===----------------------------------------------------------------------===//
243 // MulIOp
244 //===----------------------------------------------------------------------===//
245 
246 OpFoldResult arith::MulIOp::fold(ArrayRef<Attribute> operands) {
247   // muli(x, 0) -> 0
248   if (matchPattern(getRhs(), m_Zero()))
249     return getRhs();
250   // muli(x, 1) -> x
251   if (matchPattern(getRhs(), m_One()))
252     return getOperand(0);
253   // TODO: Handle the overflow case.
254 
255   // default folder
256   return constFoldBinaryOp<IntegerAttr>(
257       operands, [](const APInt &a, const APInt &b) { return a * b; });
258 }
259 
260 //===----------------------------------------------------------------------===//
261 // DivUIOp
262 //===----------------------------------------------------------------------===//
263 
264 OpFoldResult arith::DivUIOp::fold(ArrayRef<Attribute> operands) {
265   // divui (x, 1) -> x.
266   if (matchPattern(getRhs(), m_One()))
267     return getLhs();
268 
269   // Don't fold if it would require a division by zero.
270   bool div0 = false;
271   auto result =
272       constFoldBinaryOp<IntegerAttr>(operands, [&](APInt a, const APInt &b) {
273         if (div0 || !b) {
274           div0 = true;
275           return a;
276         }
277         return a.udiv(b);
278       });
279 
280   return div0 ? Attribute() : result;
281 }
282 
283 //===----------------------------------------------------------------------===//
284 // DivSIOp
285 //===----------------------------------------------------------------------===//
286 
287 OpFoldResult arith::DivSIOp::fold(ArrayRef<Attribute> operands) {
288   // divsi (x, 1) -> x.
289   if (matchPattern(getRhs(), m_One()))
290     return getLhs();
291 
292   // Don't fold if it would overflow or if it requires a division by zero.
293   bool overflowOrDiv0 = false;
294   auto result =
295       constFoldBinaryOp<IntegerAttr>(operands, [&](APInt a, const APInt &b) {
296         if (overflowOrDiv0 || !b) {
297           overflowOrDiv0 = true;
298           return a;
299         }
300         return a.sdiv_ov(b, overflowOrDiv0);
301       });
302 
303   return overflowOrDiv0 ? Attribute() : result;
304 }
305 
306 //===----------------------------------------------------------------------===//
307 // Ceil and floor division folding helpers
308 //===----------------------------------------------------------------------===//
309 
310 static APInt signedCeilNonnegInputs(const APInt &a, const APInt &b,
311                                     bool &overflow) {
312   // Returns (a-1)/b + 1
313   APInt one(a.getBitWidth(), 1, true); // Signed value 1.
314   APInt val = a.ssub_ov(one, overflow).sdiv_ov(b, overflow);
315   return val.sadd_ov(one, overflow);
316 }
317 
318 //===----------------------------------------------------------------------===//
319 // CeilDivUIOp
320 //===----------------------------------------------------------------------===//
321 
322 OpFoldResult arith::CeilDivUIOp::fold(ArrayRef<Attribute> operands) {
323   // ceildivui (x, 1) -> x.
324   if (matchPattern(getRhs(), m_One()))
325     return getLhs();
326 
327   bool overflowOrDiv0 = false;
328   auto result =
329       constFoldBinaryOp<IntegerAttr>(operands, [&](APInt a, const APInt &b) {
330         if (overflowOrDiv0 || !b) {
331           overflowOrDiv0 = true;
332           return a;
333         }
334         APInt quotient = a.udiv(b);
335         if (!a.urem(b))
336           return quotient;
337         APInt one(a.getBitWidth(), 1, true);
338         return quotient.uadd_ov(one, overflowOrDiv0);
339       });
340 
341   return overflowOrDiv0 ? Attribute() : result;
342 }
343 
344 //===----------------------------------------------------------------------===//
345 // CeilDivSIOp
346 //===----------------------------------------------------------------------===//
347 
348 OpFoldResult arith::CeilDivSIOp::fold(ArrayRef<Attribute> operands) {
349   // ceildivsi (x, 1) -> x.
350   if (matchPattern(getRhs(), m_One()))
351     return getLhs();
352 
353   // Don't fold if it would overflow or if it requires a division by zero.
354   bool overflowOrDiv0 = false;
355   auto result =
356       constFoldBinaryOp<IntegerAttr>(operands, [&](APInt a, const APInt &b) {
357         if (overflowOrDiv0 || !b) {
358           overflowOrDiv0 = true;
359           return a;
360         }
361         if (!a)
362           return a;
363         // After this point we know that neither a or b are zero.
364         unsigned bits = a.getBitWidth();
365         APInt zero = APInt::getZero(bits);
366         bool aGtZero = a.sgt(zero);
367         bool bGtZero = b.sgt(zero);
368         if (aGtZero && bGtZero) {
369           // Both positive, return ceil(a, b).
370           return signedCeilNonnegInputs(a, b, overflowOrDiv0);
371         }
372         if (!aGtZero && !bGtZero) {
373           // Both negative, return ceil(-a, -b).
374           APInt posA = zero.ssub_ov(a, overflowOrDiv0);
375           APInt posB = zero.ssub_ov(b, overflowOrDiv0);
376           return signedCeilNonnegInputs(posA, posB, overflowOrDiv0);
377         }
378         if (!aGtZero && bGtZero) {
379           // A is negative, b is positive, return - ( -a / b).
380           APInt posA = zero.ssub_ov(a, overflowOrDiv0);
381           APInt div = posA.sdiv_ov(b, overflowOrDiv0);
382           return zero.ssub_ov(div, overflowOrDiv0);
383         }
384         // A is positive, b is negative, return - (a / -b).
385         APInt posB = zero.ssub_ov(b, overflowOrDiv0);
386         APInt div = a.sdiv_ov(posB, overflowOrDiv0);
387         return zero.ssub_ov(div, overflowOrDiv0);
388       });
389 
390   return overflowOrDiv0 ? Attribute() : result;
391 }
392 
393 //===----------------------------------------------------------------------===//
394 // FloorDivSIOp
395 //===----------------------------------------------------------------------===//
396 
397 OpFoldResult arith::FloorDivSIOp::fold(ArrayRef<Attribute> operands) {
398   // floordivsi (x, 1) -> x.
399   if (matchPattern(getRhs(), m_One()))
400     return getLhs();
401 
402   // Don't fold if it would overflow or if it requires a division by zero.
403   bool overflowOrDiv0 = false;
404   auto result =
405       constFoldBinaryOp<IntegerAttr>(operands, [&](APInt a, const APInt &b) {
406         if (overflowOrDiv0 || !b) {
407           overflowOrDiv0 = true;
408           return a;
409         }
410         if (!a)
411           return a;
412         // After this point we know that neither a or b are zero.
413         unsigned bits = a.getBitWidth();
414         APInt zero = APInt::getZero(bits);
415         bool aGtZero = a.sgt(zero);
416         bool bGtZero = b.sgt(zero);
417         if (aGtZero && bGtZero) {
418           // Both positive, return a / b.
419           return a.sdiv_ov(b, overflowOrDiv0);
420         }
421         if (!aGtZero && !bGtZero) {
422           // Both negative, return -a / -b.
423           APInt posA = zero.ssub_ov(a, overflowOrDiv0);
424           APInt posB = zero.ssub_ov(b, overflowOrDiv0);
425           return posA.sdiv_ov(posB, overflowOrDiv0);
426         }
427         if (!aGtZero && bGtZero) {
428           // A is negative, b is positive, return - ceil(-a, b).
429           APInt posA = zero.ssub_ov(a, overflowOrDiv0);
430           APInt ceil = signedCeilNonnegInputs(posA, b, overflowOrDiv0);
431           return zero.ssub_ov(ceil, overflowOrDiv0);
432         }
433         // A is positive, b is negative, return - ceil(a, -b).
434         APInt posB = zero.ssub_ov(b, overflowOrDiv0);
435         APInt ceil = signedCeilNonnegInputs(a, posB, overflowOrDiv0);
436         return zero.ssub_ov(ceil, overflowOrDiv0);
437       });
438 
439   return overflowOrDiv0 ? Attribute() : result;
440 }
441 
442 //===----------------------------------------------------------------------===//
443 // RemUIOp
444 //===----------------------------------------------------------------------===//
445 
446 OpFoldResult arith::RemUIOp::fold(ArrayRef<Attribute> operands) {
447   // remui (x, 1) -> 0.
448   if (matchPattern(getRhs(), m_One()))
449     return Builder(getContext()).getZeroAttr(getType());
450 
451   // Don't fold if it would require a division by zero.
452   bool div0 = false;
453   auto result =
454       constFoldBinaryOp<IntegerAttr>(operands, [&](APInt a, const APInt &b) {
455         if (div0 || b.isNullValue()) {
456           div0 = true;
457           return a;
458         }
459         return a.urem(b);
460       });
461 
462   return div0 ? Attribute() : result;
463 }
464 
465 //===----------------------------------------------------------------------===//
466 // RemSIOp
467 //===----------------------------------------------------------------------===//
468 
469 OpFoldResult arith::RemSIOp::fold(ArrayRef<Attribute> operands) {
470   // remsi (x, 1) -> 0.
471   if (matchPattern(getRhs(), m_One()))
472     return Builder(getContext()).getZeroAttr(getType());
473 
474   // Don't fold if it would require a division by zero.
475   bool div0 = false;
476   auto result =
477       constFoldBinaryOp<IntegerAttr>(operands, [&](APInt a, const APInt &b) {
478         if (div0 || b.isNullValue()) {
479           div0 = true;
480           return a;
481         }
482         return a.srem(b);
483       });
484 
485   return div0 ? Attribute() : result;
486 }
487 
488 //===----------------------------------------------------------------------===//
489 // AndIOp
490 //===----------------------------------------------------------------------===//
491 
492 OpFoldResult arith::AndIOp::fold(ArrayRef<Attribute> operands) {
493   /// and(x, 0) -> 0
494   if (matchPattern(getRhs(), m_Zero()))
495     return getRhs();
496   /// and(x, allOnes) -> x
497   APInt intValue;
498   if (matchPattern(getRhs(), m_ConstantInt(&intValue)) && intValue.isAllOnes())
499     return getLhs();
500 
501   return constFoldBinaryOp<IntegerAttr>(
502       operands, [](APInt a, const APInt &b) { return std::move(a) & b; });
503 }
504 
505 //===----------------------------------------------------------------------===//
506 // OrIOp
507 //===----------------------------------------------------------------------===//
508 
509 OpFoldResult arith::OrIOp::fold(ArrayRef<Attribute> operands) {
510   /// or(x, 0) -> x
511   if (matchPattern(getRhs(), m_Zero()))
512     return getLhs();
513   /// or(x, <all ones>) -> <all ones>
514   if (auto rhsAttr = operands[1].dyn_cast_or_null<IntegerAttr>())
515     if (rhsAttr.getValue().isAllOnes())
516       return rhsAttr;
517 
518   return constFoldBinaryOp<IntegerAttr>(
519       operands, [](APInt a, const APInt &b) { return std::move(a) | b; });
520 }
521 
522 //===----------------------------------------------------------------------===//
523 // XOrIOp
524 //===----------------------------------------------------------------------===//
525 
526 OpFoldResult arith::XOrIOp::fold(ArrayRef<Attribute> operands) {
527   /// xor(x, 0) -> x
528   if (matchPattern(getRhs(), m_Zero()))
529     return getLhs();
530   /// xor(x, x) -> 0
531   if (getLhs() == getRhs())
532     return Builder(getContext()).getZeroAttr(getType());
533   /// xor(xor(x, a), a) -> x
534   if (arith::XOrIOp prev = getLhs().getDefiningOp<arith::XOrIOp>())
535     if (prev.getRhs() == getRhs())
536       return prev.getLhs();
537 
538   return constFoldBinaryOp<IntegerAttr>(
539       operands, [](APInt a, const APInt &b) { return std::move(a) ^ b; });
540 }
541 
542 void arith::XOrIOp::getCanonicalizationPatterns(
543     RewritePatternSet &patterns, MLIRContext *context) {
544   patterns.add<XOrINotCmpI>(context);
545 }
546 
547 //===----------------------------------------------------------------------===//
548 // NegFOp
549 //===----------------------------------------------------------------------===//
550 
551 OpFoldResult arith::NegFOp::fold(ArrayRef<Attribute> operands) {
552   /// negf(negf(x)) -> x
553   if (auto op = this->getOperand().getDefiningOp<arith::NegFOp>())
554     return op.getOperand();
555   return constFoldUnaryOp<FloatAttr>(operands,
556                                      [](const APFloat &a) { return -a; });
557 }
558 
559 //===----------------------------------------------------------------------===//
560 // AddFOp
561 //===----------------------------------------------------------------------===//
562 
563 OpFoldResult arith::AddFOp::fold(ArrayRef<Attribute> operands) {
564   // addf(x, -0) -> x
565   if (matchPattern(getRhs(), m_NegZeroFloat()))
566     return getLhs();
567 
568   return constFoldBinaryOp<FloatAttr>(
569       operands, [](const APFloat &a, const APFloat &b) { return a + b; });
570 }
571 
572 //===----------------------------------------------------------------------===//
573 // SubFOp
574 //===----------------------------------------------------------------------===//
575 
576 OpFoldResult arith::SubFOp::fold(ArrayRef<Attribute> operands) {
577   // subf(x, +0) -> x
578   if (matchPattern(getRhs(), m_PosZeroFloat()))
579     return getLhs();
580 
581   return constFoldBinaryOp<FloatAttr>(
582       operands, [](const APFloat &a, const APFloat &b) { return a - b; });
583 }
584 
585 //===----------------------------------------------------------------------===//
586 // MaxFOp
587 //===----------------------------------------------------------------------===//
588 
589 OpFoldResult arith::MaxFOp::fold(ArrayRef<Attribute> operands) {
590   assert(operands.size() == 2 && "maxf takes two operands");
591 
592   // maxf(x,x) -> x
593   if (getLhs() == getRhs())
594     return getRhs();
595 
596   // maxf(x, -inf) -> x
597   if (matchPattern(getRhs(), m_NegInfFloat()))
598     return getLhs();
599 
600   return constFoldBinaryOp<FloatAttr>(
601       operands,
602       [](const APFloat &a, const APFloat &b) { return llvm::maximum(a, b); });
603 }
604 
605 //===----------------------------------------------------------------------===//
606 // MaxSIOp
607 //===----------------------------------------------------------------------===//
608 
609 OpFoldResult MaxSIOp::fold(ArrayRef<Attribute> operands) {
610   assert(operands.size() == 2 && "binary operation takes two operands");
611 
612   // maxsi(x,x) -> x
613   if (getLhs() == getRhs())
614     return getRhs();
615 
616   APInt intValue;
617   // maxsi(x,MAX_INT) -> MAX_INT
618   if (matchPattern(getRhs(), m_ConstantInt(&intValue)) &&
619       intValue.isMaxSignedValue())
620     return getRhs();
621 
622   // maxsi(x, MIN_INT) -> x
623   if (matchPattern(getRhs(), m_ConstantInt(&intValue)) &&
624       intValue.isMinSignedValue())
625     return getLhs();
626 
627   return constFoldBinaryOp<IntegerAttr>(operands,
628                                         [](const APInt &a, const APInt &b) {
629                                           return llvm::APIntOps::smax(a, b);
630                                         });
631 }
632 
633 //===----------------------------------------------------------------------===//
634 // MaxUIOp
635 //===----------------------------------------------------------------------===//
636 
637 OpFoldResult MaxUIOp::fold(ArrayRef<Attribute> operands) {
638   assert(operands.size() == 2 && "binary operation takes two operands");
639 
640   // maxui(x,x) -> x
641   if (getLhs() == getRhs())
642     return getRhs();
643 
644   APInt intValue;
645   // maxui(x,MAX_INT) -> MAX_INT
646   if (matchPattern(getRhs(), m_ConstantInt(&intValue)) && intValue.isMaxValue())
647     return getRhs();
648 
649   // maxui(x, MIN_INT) -> x
650   if (matchPattern(getRhs(), m_ConstantInt(&intValue)) && intValue.isMinValue())
651     return getLhs();
652 
653   return constFoldBinaryOp<IntegerAttr>(operands,
654                                         [](const APInt &a, const APInt &b) {
655                                           return llvm::APIntOps::umax(a, b);
656                                         });
657 }
658 
659 //===----------------------------------------------------------------------===//
660 // MinFOp
661 //===----------------------------------------------------------------------===//
662 
663 OpFoldResult arith::MinFOp::fold(ArrayRef<Attribute> operands) {
664   assert(operands.size() == 2 && "minf takes two operands");
665 
666   // minf(x,x) -> x
667   if (getLhs() == getRhs())
668     return getRhs();
669 
670   // minf(x, +inf) -> x
671   if (matchPattern(getRhs(), m_PosInfFloat()))
672     return getLhs();
673 
674   return constFoldBinaryOp<FloatAttr>(
675       operands,
676       [](const APFloat &a, const APFloat &b) { return llvm::minimum(a, b); });
677 }
678 
679 //===----------------------------------------------------------------------===//
680 // MinSIOp
681 //===----------------------------------------------------------------------===//
682 
683 OpFoldResult MinSIOp::fold(ArrayRef<Attribute> operands) {
684   assert(operands.size() == 2 && "binary operation takes two operands");
685 
686   // minsi(x,x) -> x
687   if (getLhs() == getRhs())
688     return getRhs();
689 
690   APInt intValue;
691   // minsi(x,MIN_INT) -> MIN_INT
692   if (matchPattern(getRhs(), m_ConstantInt(&intValue)) &&
693       intValue.isMinSignedValue())
694     return getRhs();
695 
696   // minsi(x, MAX_INT) -> x
697   if (matchPattern(getRhs(), m_ConstantInt(&intValue)) &&
698       intValue.isMaxSignedValue())
699     return getLhs();
700 
701   return constFoldBinaryOp<IntegerAttr>(operands,
702                                         [](const APInt &a, const APInt &b) {
703                                           return llvm::APIntOps::smin(a, b);
704                                         });
705 }
706 
707 //===----------------------------------------------------------------------===//
708 // MinUIOp
709 //===----------------------------------------------------------------------===//
710 
711 OpFoldResult MinUIOp::fold(ArrayRef<Attribute> operands) {
712   assert(operands.size() == 2 && "binary operation takes two operands");
713 
714   // minui(x,x) -> x
715   if (getLhs() == getRhs())
716     return getRhs();
717 
718   APInt intValue;
719   // minui(x,MIN_INT) -> MIN_INT
720   if (matchPattern(getRhs(), m_ConstantInt(&intValue)) && intValue.isMinValue())
721     return getRhs();
722 
723   // minui(x, MAX_INT) -> x
724   if (matchPattern(getRhs(), m_ConstantInt(&intValue)) && intValue.isMaxValue())
725     return getLhs();
726 
727   return constFoldBinaryOp<IntegerAttr>(operands,
728                                         [](const APInt &a, const APInt &b) {
729                                           return llvm::APIntOps::umin(a, b);
730                                         });
731 }
732 
733 //===----------------------------------------------------------------------===//
734 // MulFOp
735 //===----------------------------------------------------------------------===//
736 
737 OpFoldResult arith::MulFOp::fold(ArrayRef<Attribute> operands) {
738   // mulf(x, 1) -> x
739   if (matchPattern(getRhs(), m_OneFloat()))
740     return getLhs();
741 
742   return constFoldBinaryOp<FloatAttr>(
743       operands, [](const APFloat &a, const APFloat &b) { return a * b; });
744 }
745 
746 //===----------------------------------------------------------------------===//
747 // DivFOp
748 //===----------------------------------------------------------------------===//
749 
750 OpFoldResult arith::DivFOp::fold(ArrayRef<Attribute> operands) {
751   // divf(x, 1) -> x
752   if (matchPattern(getRhs(), m_OneFloat()))
753     return getLhs();
754 
755   return constFoldBinaryOp<FloatAttr>(
756       operands, [](const APFloat &a, const APFloat &b) { return a / b; });
757 }
758 
759 //===----------------------------------------------------------------------===//
760 // RemFOp
761 //===----------------------------------------------------------------------===//
762 
763 OpFoldResult arith::RemFOp::fold(ArrayRef<Attribute> operands) {
764   return constFoldBinaryOp<FloatAttr>(operands,
765                                       [](const APFloat &a, const APFloat &b) {
766                                         APFloat Result(a);
767                                         (void)Result.remainder(b);
768                                         return Result;
769                                       });
770 }
771 
772 //===----------------------------------------------------------------------===//
773 // Utility functions for verifying cast ops
774 //===----------------------------------------------------------------------===//
775 
/// Tag type that carries a pack of types as a pointer to a tuple of them. The
/// helpers below take it as an unnamed parameter, so only its type matters
/// (for template argument deduction); callers pass a value-initialized, i.e.
/// null, instance.
template <typename... Types>
using type_list = std::tuple<Types...> *;
778 
779 /// Returns a non-null type only if the provided type is one of the allowed
780 /// types or one of the allowed shaped types of the allowed types. Returns the
781 /// element type if a valid shaped type is provided.
782 template <typename... ShapedTypes, typename... ElementTypes>
783 static Type getUnderlyingType(Type type, type_list<ShapedTypes...>,
784                               type_list<ElementTypes...>) {
785   if (type.isa<ShapedType>() && !type.isa<ShapedTypes...>())
786     return {};
787 
788   auto underlyingType = getElementTypeOrSelf(type);
789   if (!underlyingType.isa<ElementTypes...>())
790     return {};
791 
792   return underlyingType;
793 }
794 
795 /// Get allowed underlying types for vectors and tensors.
796 template <typename... ElementTypes>
797 static Type getTypeIfLike(Type type) {
798   return getUnderlyingType(type, type_list<VectorType, TensorType>(),
799                            type_list<ElementTypes...>());
800 }
801 
802 /// Get allowed underlying types for vectors, tensors, and memrefs.
803 template <typename... ElementTypes>
804 static Type getTypeIfLikeOrMemRef(Type type) {
805   return getUnderlyingType(type,
806                            type_list<VectorType, TensorType, MemRefType>(),
807                            type_list<ElementTypes...>());
808 }
809 
810 static bool areValidCastInputsAndOutputs(TypeRange inputs, TypeRange outputs) {
811   return inputs.size() == 1 && outputs.size() == 1 &&
812          succeeded(verifyCompatibleShapes(inputs.front(), outputs.front()));
813 }
814 
815 //===----------------------------------------------------------------------===//
816 // Verifiers for integer and floating point extension/truncation ops
817 //===----------------------------------------------------------------------===//
818 
819 // Extend ops can only extend to a wider type.
820 template <typename ValType, typename Op>
821 static LogicalResult verifyExtOp(Op op) {
822   Type srcType = getElementTypeOrSelf(op.getIn().getType());
823   Type dstType = getElementTypeOrSelf(op.getType());
824 
825   if (srcType.cast<ValType>().getWidth() >= dstType.cast<ValType>().getWidth())
826     return op.emitError("result type ")
827            << dstType << " must be wider than operand type " << srcType;
828 
829   return success();
830 }
831 
832 // Truncate ops can only truncate to a shorter type.
833 template <typename ValType, typename Op>
834 static LogicalResult verifyTruncateOp(Op op) {
835   Type srcType = getElementTypeOrSelf(op.getIn().getType());
836   Type dstType = getElementTypeOrSelf(op.getType());
837 
838   if (srcType.cast<ValType>().getWidth() <= dstType.cast<ValType>().getWidth())
839     return op.emitError("result type ")
840            << dstType << " must be shorter than operand type " << srcType;
841 
842   return success();
843 }
844 
845 /// Validate a cast that changes the width of a type.
846 template <template <typename> class WidthComparator, typename... ElementTypes>
847 static bool checkWidthChangeCast(TypeRange inputs, TypeRange outputs) {
848   if (!areValidCastInputsAndOutputs(inputs, outputs))
849     return false;
850 
851   auto srcType = getTypeIfLike<ElementTypes...>(inputs.front());
852   auto dstType = getTypeIfLike<ElementTypes...>(outputs.front());
853   if (!srcType || !dstType)
854     return false;
855 
856   return WidthComparator<unsigned>()(dstType.getIntOrFloatBitWidth(),
857                                      srcType.getIntOrFloatBitWidth());
858 }
859 
860 //===----------------------------------------------------------------------===//
861 // ExtUIOp
862 //===----------------------------------------------------------------------===//
863 
864 OpFoldResult arith::ExtUIOp::fold(ArrayRef<Attribute> operands) {
865   if (auto lhs = getIn().getDefiningOp<ExtUIOp>()) {
866     getInMutable().assign(lhs.getIn());
867     return getResult();
868   }
869   Type resType = getType();
870   unsigned bitWidth;
871   if (auto shapedType = resType.dyn_cast<ShapedType>())
872     bitWidth = shapedType.getElementTypeBitWidth();
873   else
874     bitWidth = resType.getIntOrFloatBitWidth();
875   return constFoldCastOp<IntegerAttr, IntegerAttr>(
876       operands, getType(), [bitWidth](const APInt &a, bool &castStatus) {
877         return a.zext(bitWidth);
878       });
879 }
880 
/// A cast is compatible when `checkWidthChangeCast` accepts it for integer
/// element types with `std::greater`, i.e. the output bit width must be
/// strictly greater than the input bit width.
bool arith::ExtUIOp::areCastCompatible(TypeRange inputs, TypeRange outputs) {
  return checkWidthChangeCast<std::greater, IntegerType>(inputs, outputs);
}
884 
/// Verifies the op through `verifyExtOp<IntegerType>`, which emits an error
/// unless the result element type is wider than the operand element type.
LogicalResult arith::ExtUIOp::verify() {
  return verifyExtOp<IntegerType>(*this);
}
888 
889 //===----------------------------------------------------------------------===//
890 // ExtSIOp
891 //===----------------------------------------------------------------------===//
892 
893 OpFoldResult arith::ExtSIOp::fold(ArrayRef<Attribute> operands) {
894   if (auto lhs = getIn().getDefiningOp<ExtSIOp>()) {
895     getInMutable().assign(lhs.getIn());
896     return getResult();
897   }
898   Type resType = getType();
899   unsigned bitWidth;
900   if (auto shapedType = resType.dyn_cast<ShapedType>())
901     bitWidth = shapedType.getElementTypeBitWidth();
902   else
903     bitWidth = resType.getIntOrFloatBitWidth();
904   return constFoldCastOp<IntegerAttr, IntegerAttr>(
905       operands, getType(), [bitWidth](const APInt &a, bool &castStatus) {
906         return a.sext(bitWidth);
907       });
908 }
909 
/// A cast is compatible when `checkWidthChangeCast` accepts it for integer
/// element types with `std::greater`, i.e. the output bit width must be
/// strictly greater than the input bit width.
bool arith::ExtSIOp::areCastCompatible(TypeRange inputs, TypeRange outputs) {
  return checkWidthChangeCast<std::greater, IntegerType>(inputs, outputs);
}
913 
/// Registers the canonicalization patterns rooted at `arith.extsi`: adds the
/// `ExtSIOfExtUI` pattern (defined outside this section) to `patterns`.
void arith::ExtSIOp::getCanonicalizationPatterns(
    RewritePatternSet &patterns, MLIRContext *context) {
  patterns.add<ExtSIOfExtUI>(context);
}
918 
/// Verifies the op through `verifyExtOp<IntegerType>`, which emits an error
/// unless the result element type is wider than the operand element type.
LogicalResult arith::ExtSIOp::verify() {
  return verifyExtOp<IntegerType>(*this);
}
922 
923 //===----------------------------------------------------------------------===//
924 // ExtFOp
925 //===----------------------------------------------------------------------===//
926 
/// A cast is compatible when `checkWidthChangeCast` accepts it for float
/// element types with `std::greater`, i.e. the output bit width must be
/// strictly greater than the input bit width.
bool arith::ExtFOp::areCastCompatible(TypeRange inputs, TypeRange outputs) {
  return checkWidthChangeCast<std::greater, FloatType>(inputs, outputs);
}
930 
/// Verifies the op through `verifyExtOp<FloatType>`, which emits an error
/// unless the result element type is wider than the operand element type.
LogicalResult arith::ExtFOp::verify() { return verifyExtOp<FloatType>(*this); }
932 
933 //===----------------------------------------------------------------------===//
934 // TruncIOp
935 //===----------------------------------------------------------------------===//
936 
937 OpFoldResult arith::TruncIOp::fold(ArrayRef<Attribute> operands) {
938   assert(operands.size() == 1 && "unary operation takes one operand");
939 
940   // trunci(zexti(a)) -> a
941   // trunci(sexti(a)) -> a
942   if (matchPattern(getOperand(), m_Op<arith::ExtUIOp>()) ||
943       matchPattern(getOperand(), m_Op<arith::ExtSIOp>()))
944     return getOperand().getDefiningOp()->getOperand(0);
945 
946   // trunci(trunci(a)) -> trunci(a))
947   if (matchPattern(getOperand(), m_Op<arith::TruncIOp>())) {
948     setOperand(getOperand().getDefiningOp()->getOperand(0));
949     return getResult();
950   }
951 
952   Type resType = getType();
953   unsigned bitWidth;
954   if (auto shapedType = resType.dyn_cast<ShapedType>())
955     bitWidth = shapedType.getElementTypeBitWidth();
956   else
957     bitWidth = resType.getIntOrFloatBitWidth();
958 
959   return constFoldCastOp<IntegerAttr, IntegerAttr>(
960       operands, getType(), [bitWidth](const APInt &a, bool &castStatus) {
961         return a.trunc(bitWidth);
962       });
963 }
964 
/// A cast is compatible when `checkWidthChangeCast` accepts it for integer
/// element types with `std::less`, i.e. the output bit width must be strictly
/// less than the input bit width.
bool arith::TruncIOp::areCastCompatible(TypeRange inputs, TypeRange outputs) {
  return checkWidthChangeCast<std::less, IntegerType>(inputs, outputs);
}
968 
/// Verifies the op through `verifyTruncateOp<IntegerType>`, which emits an
/// error unless the result element type is shorter than the operand element
/// type.
LogicalResult arith::TruncIOp::verify() {
  return verifyTruncateOp<IntegerType>(*this);
}
972 
973 //===----------------------------------------------------------------------===//
974 // TruncFOp
975 //===----------------------------------------------------------------------===//
976 
977 /// Perform safe const propagation for truncf, i.e. only propagate if FP value
978 /// can be represented without precision loss or rounding.
979 OpFoldResult arith::TruncFOp::fold(ArrayRef<Attribute> operands) {
980   assert(operands.size() == 1 && "unary operation takes one operand");
981 
982   auto constOperand = operands.front();
983   if (!constOperand || !constOperand.isa<FloatAttr>())
984     return {};
985 
986   // Convert to target type via 'double'.
987   double sourceValue =
988       constOperand.dyn_cast<FloatAttr>().getValue().convertToDouble();
989   auto targetAttr = FloatAttr::get(getType(), sourceValue);
990 
991   // Propagate if constant's value does not change after truncation.
992   if (sourceValue == targetAttr.getValue().convertToDouble())
993     return targetAttr;
994 
995   return {};
996 }
997 
/// A cast is compatible when `checkWidthChangeCast` accepts it for float
/// element types with `std::less`, i.e. the output bit width must be strictly
/// less than the input bit width.
bool arith::TruncFOp::areCastCompatible(TypeRange inputs, TypeRange outputs) {
  return checkWidthChangeCast<std::less, FloatType>(inputs, outputs);
}
1001 
/// Verifies the op through `verifyTruncateOp<FloatType>`, which emits an error
/// unless the result element type is shorter than the operand element type.
LogicalResult arith::TruncFOp::verify() {
  return verifyTruncateOp<FloatType>(*this);
}
1005 
1006 //===----------------------------------------------------------------------===//
1007 // AndIOp
1008 //===----------------------------------------------------------------------===//
1009 
/// Registers the canonicalization patterns rooted at `arith.andi`: adds the
/// `AndOfExtUI` and `AndOfExtSI` patterns (defined outside this section) to
/// `patterns`.
void arith::AndIOp::getCanonicalizationPatterns(
    RewritePatternSet &patterns, MLIRContext *context) {
  patterns.add<AndOfExtUI, AndOfExtSI>(context);
}
1014 
1015 //===----------------------------------------------------------------------===//
1016 // OrIOp
1017 //===----------------------------------------------------------------------===//
1018 
/// Registers the canonicalization patterns rooted at `arith.ori`: adds the
/// `OrOfExtUI` and `OrOfExtSI` patterns (defined outside this section) to
/// `patterns`.
void arith::OrIOp::getCanonicalizationPatterns(
    RewritePatternSet &patterns, MLIRContext *context) {
  patterns.add<OrOfExtUI, OrOfExtSI>(context);
}
1023 
1024 //===----------------------------------------------------------------------===//
1025 // Verifiers for casts between integers and floats.
1026 //===----------------------------------------------------------------------===//
1027 
1028 template <typename From, typename To>
1029 static bool checkIntFloatCast(TypeRange inputs, TypeRange outputs) {
1030   if (!areValidCastInputsAndOutputs(inputs, outputs))
1031     return false;
1032 
1033   auto srcType = getTypeIfLike<From>(inputs.front());
1034   auto dstType = getTypeIfLike<To>(outputs.back());
1035 
1036   return srcType && dstType;
1037 }
1038 
1039 //===----------------------------------------------------------------------===//
1040 // UIToFPOp
1041 //===----------------------------------------------------------------------===//
1042 
/// A cast is compatible when `checkIntFloatCast` accepts an integer-like
/// input and a float-like output.
bool arith::UIToFPOp::areCastCompatible(TypeRange inputs, TypeRange outputs) {
  return checkIntFloatCast<IntegerType, FloatType>(inputs, outputs);
}
1046 
1047 OpFoldResult arith::UIToFPOp::fold(ArrayRef<Attribute> operands) {
1048   Type resType = getType();
1049   Type resEleType;
1050   if (auto shapedType = resType.dyn_cast<ShapedType>())
1051     resEleType = shapedType.getElementType();
1052   else
1053     resEleType = resType;
1054   return constFoldCastOp<IntegerAttr, FloatAttr>(
1055       operands, getType(), [&resEleType](const APInt &a, bool &castStatus) {
1056         FloatType floatTy = resEleType.cast<FloatType>();
1057         APFloat apf(floatTy.getFloatSemantics(),
1058                     APInt::getZero(floatTy.getWidth()));
1059         apf.convertFromAPInt(a, /*IsSigned=*/false,
1060                              APFloat::rmNearestTiesToEven);
1061         return apf;
1062       });
1063 }
1064 
1065 //===----------------------------------------------------------------------===//
1066 // SIToFPOp
1067 //===----------------------------------------------------------------------===//
1068 
/// A cast is compatible when `checkIntFloatCast` accepts an integer-like
/// input and a float-like output.
bool arith::SIToFPOp::areCastCompatible(TypeRange inputs, TypeRange outputs) {
  return checkIntFloatCast<IntegerType, FloatType>(inputs, outputs);
}
1072 
1073 OpFoldResult arith::SIToFPOp::fold(ArrayRef<Attribute> operands) {
1074   Type resType = getType();
1075   Type resEleType;
1076   if (auto shapedType = resType.dyn_cast<ShapedType>())
1077     resEleType = shapedType.getElementType();
1078   else
1079     resEleType = resType;
1080   return constFoldCastOp<IntegerAttr, FloatAttr>(
1081       operands, getType(), [&resEleType](const APInt &a, bool &castStatus) {
1082         FloatType floatTy = resEleType.cast<FloatType>();
1083         APFloat apf(floatTy.getFloatSemantics(),
1084                     APInt::getZero(floatTy.getWidth()));
1085         apf.convertFromAPInt(a, /*IsSigned=*/true,
1086                              APFloat::rmNearestTiesToEven);
1087         return apf;
1088       });
1089 }
1090 //===----------------------------------------------------------------------===//
1091 // FPToUIOp
1092 //===----------------------------------------------------------------------===//
1093 
/// A cast is compatible when `checkIntFloatCast` accepts a float-like input
/// and an integer-like output.
bool arith::FPToUIOp::areCastCompatible(TypeRange inputs, TypeRange outputs) {
  return checkIntFloatCast<FloatType, IntegerType>(inputs, outputs);
}
1097 
1098 OpFoldResult arith::FPToUIOp::fold(ArrayRef<Attribute> operands) {
1099   Type resType = getType();
1100   Type resEleType;
1101   if (auto shapedType = resType.dyn_cast<ShapedType>())
1102     resEleType = shapedType.getElementType();
1103   else
1104     resEleType = resType;
1105   return constFoldCastOp<FloatAttr, IntegerAttr>(
1106       operands, getType(), [&resEleType](const APFloat &a, bool &castStatus) {
1107         IntegerType intTy = resEleType.cast<IntegerType>();
1108         bool ignored;
1109         APSInt api(intTy.getWidth(), /*isUnsigned=*/true);
1110         castStatus = APFloat::opInvalidOp !=
1111                      a.convertToInteger(api, APFloat::rmTowardZero, &ignored);
1112         return api;
1113       });
1114 }
1115 
1116 //===----------------------------------------------------------------------===//
1117 // FPToSIOp
1118 //===----------------------------------------------------------------------===//
1119 
/// A cast is compatible when `checkIntFloatCast` accepts a float-like input
/// and an integer-like output.
bool arith::FPToSIOp::areCastCompatible(TypeRange inputs, TypeRange outputs) {
  return checkIntFloatCast<FloatType, IntegerType>(inputs, outputs);
}
1123 
1124 OpFoldResult arith::FPToSIOp::fold(ArrayRef<Attribute> operands) {
1125   Type resType = getType();
1126   Type resEleType;
1127   if (auto shapedType = resType.dyn_cast<ShapedType>())
1128     resEleType = shapedType.getElementType();
1129   else
1130     resEleType = resType;
1131   return constFoldCastOp<FloatAttr, IntegerAttr>(
1132       operands, getType(), [&resEleType](const APFloat &a, bool &castStatus) {
1133         IntegerType intTy = resEleType.cast<IntegerType>();
1134         bool ignored;
1135         APSInt api(intTy.getWidth(), /*isUnsigned=*/false);
1136         castStatus = APFloat::opInvalidOp !=
1137                      a.convertToInteger(api, APFloat::rmTowardZero, &ignored);
1138         return api;
1139       });
1140 }
1141 
1142 //===----------------------------------------------------------------------===//
1143 // IndexCastOp
1144 //===----------------------------------------------------------------------===//
1145 
1146 bool arith::IndexCastOp::areCastCompatible(TypeRange inputs,
1147                                            TypeRange outputs) {
1148   if (!areValidCastInputsAndOutputs(inputs, outputs))
1149     return false;
1150 
1151   auto srcType = getTypeIfLikeOrMemRef<IntegerType, IndexType>(inputs.front());
1152   auto dstType = getTypeIfLikeOrMemRef<IntegerType, IndexType>(outputs.front());
1153   if (!srcType || !dstType)
1154     return false;
1155 
1156   return (srcType.isIndex() && dstType.isSignlessInteger()) ||
1157          (srcType.isSignlessInteger() && dstType.isIndex());
1158 }
1159 
1160 OpFoldResult arith::IndexCastOp::fold(ArrayRef<Attribute> operands) {
1161   // index_cast(constant) -> constant
1162   // A little hack because we go through int. Otherwise, the size of the
1163   // constant might need to change.
1164   if (auto value = operands[0].dyn_cast_or_null<IntegerAttr>())
1165     return IntegerAttr::get(getType(), value.getInt());
1166 
1167   return {};
1168 }
1169 
/// Registers the canonicalization patterns rooted at `arith.index_cast`: adds
/// the `IndexCastOfIndexCast` and `IndexCastOfExtSI` patterns (defined outside
/// this section) to `patterns`.
void arith::IndexCastOp::getCanonicalizationPatterns(
    RewritePatternSet &patterns, MLIRContext *context) {
  patterns.add<IndexCastOfIndexCast, IndexCastOfExtSI>(context);
}
1174 
1175 //===----------------------------------------------------------------------===//
1176 // BitcastOp
1177 //===----------------------------------------------------------------------===//
1178 
1179 bool arith::BitcastOp::areCastCompatible(TypeRange inputs, TypeRange outputs) {
1180   if (!areValidCastInputsAndOutputs(inputs, outputs))
1181     return false;
1182 
1183   auto srcType =
1184       getTypeIfLikeOrMemRef<IntegerType, IndexType, FloatType>(inputs.front());
1185   auto dstType =
1186       getTypeIfLikeOrMemRef<IntegerType, IndexType, FloatType>(outputs.front());
1187   if (!srcType || !dstType)
1188     return false;
1189 
1190   return srcType.getIntOrFloatBitWidth() == dstType.getIntOrFloatBitWidth();
1191 }
1192 
1193 OpFoldResult arith::BitcastOp::fold(ArrayRef<Attribute> operands) {
1194   assert(operands.size() == 1 && "bitcast op expects 1 operand");
1195 
1196   auto resType = getType();
1197   auto operand = operands[0];
1198   if (!operand)
1199     return {};
1200 
1201   /// Bitcast dense elements.
1202   if (auto denseAttr = operand.dyn_cast_or_null<DenseElementsAttr>())
1203     return denseAttr.bitcast(resType.cast<ShapedType>().getElementType());
1204   /// Other shaped types unhandled.
1205   if (resType.isa<ShapedType>())
1206     return {};
1207 
1208   /// Bitcast integer or float to integer or float.
1209   APInt bits = operand.isa<FloatAttr>()
1210                    ? operand.cast<FloatAttr>().getValue().bitcastToAPInt()
1211                    : operand.cast<IntegerAttr>().getValue();
1212 
1213   if (auto resFloatType = resType.dyn_cast<FloatType>())
1214     return FloatAttr::get(resType,
1215                           APFloat(resFloatType.getFloatSemantics(), bits));
1216   return IntegerAttr::get(resType, bits);
1217 }
1218 
/// Registers the canonicalization patterns rooted at `arith.bitcast`: adds the
/// `BitcastOfBitcast` pattern (defined outside this section) to `patterns`.
void arith::BitcastOp::getCanonicalizationPatterns(
    RewritePatternSet &patterns, MLIRContext *context) {
  patterns.add<BitcastOfBitcast>(context);
}
1223 
1224 //===----------------------------------------------------------------------===//
1225 // Helpers for compare ops
1226 //===----------------------------------------------------------------------===//
1227 
1228 /// Return the type of the same shape (scalar, vector or tensor) containing i1.
1229 static Type getI1SameShape(Type type) {
1230   auto i1Type = IntegerType::get(type.getContext(), 1);
1231   if (auto tensorType = type.dyn_cast<RankedTensorType>())
1232     return RankedTensorType::get(tensorType.getShape(), i1Type);
1233   if (type.isa<UnrankedTensorType>())
1234     return UnrankedTensorType::get(i1Type);
1235   if (auto vectorType = type.dyn_cast<VectorType>())
1236     return VectorType::get(vectorType.getShape(), i1Type,
1237                            vectorType.getNumScalableDims());
1238   return i1Type;
1239 }
1240 
1241 //===----------------------------------------------------------------------===//
1242 // CmpIOp
1243 //===----------------------------------------------------------------------===//
1244 
1245 /// Compute `lhs` `pred` `rhs`, where `pred` is one of the known integer
1246 /// comparison predicates.
1247 bool mlir::arith::applyCmpPredicate(arith::CmpIPredicate predicate,
1248                                     const APInt &lhs, const APInt &rhs) {
1249   switch (predicate) {
1250   case arith::CmpIPredicate::eq:
1251     return lhs.eq(rhs);
1252   case arith::CmpIPredicate::ne:
1253     return lhs.ne(rhs);
1254   case arith::CmpIPredicate::slt:
1255     return lhs.slt(rhs);
1256   case arith::CmpIPredicate::sle:
1257     return lhs.sle(rhs);
1258   case arith::CmpIPredicate::sgt:
1259     return lhs.sgt(rhs);
1260   case arith::CmpIPredicate::sge:
1261     return lhs.sge(rhs);
1262   case arith::CmpIPredicate::ult:
1263     return lhs.ult(rhs);
1264   case arith::CmpIPredicate::ule:
1265     return lhs.ule(rhs);
1266   case arith::CmpIPredicate::ugt:
1267     return lhs.ugt(rhs);
1268   case arith::CmpIPredicate::uge:
1269     return lhs.uge(rhs);
1270   }
1271   llvm_unreachable("unknown cmpi predicate kind");
1272 }
1273 
1274 /// Returns true if the predicate is true for two equal operands.
1275 static bool applyCmpPredicateToEqualOperands(arith::CmpIPredicate predicate) {
1276   switch (predicate) {
1277   case arith::CmpIPredicate::eq:
1278   case arith::CmpIPredicate::sle:
1279   case arith::CmpIPredicate::sge:
1280   case arith::CmpIPredicate::ule:
1281   case arith::CmpIPredicate::uge:
1282     return true;
1283   case arith::CmpIPredicate::ne:
1284   case arith::CmpIPredicate::slt:
1285   case arith::CmpIPredicate::sgt:
1286   case arith::CmpIPredicate::ult:
1287   case arith::CmpIPredicate::ugt:
1288     return false;
1289   }
1290   llvm_unreachable("unknown cmpi predicate kind");
1291 }
1292 
1293 static Attribute getBoolAttribute(Type type, MLIRContext *ctx, bool value) {
1294   auto boolAttr = BoolAttr::get(ctx, value);
1295   ShapedType shapedType = type.dyn_cast_or_null<ShapedType>();
1296   if (!shapedType)
1297     return boolAttr;
1298   return DenseElementsAttr::get(shapedType, boolAttr);
1299 }
1300 
1301 OpFoldResult arith::CmpIOp::fold(ArrayRef<Attribute> operands) {
1302   assert(operands.size() == 2 && "cmpi takes two operands");
1303 
1304   // cmpi(pred, x, x)
1305   if (getLhs() == getRhs()) {
1306     auto val = applyCmpPredicateToEqualOperands(getPredicate());
1307     return getBoolAttribute(getType(), getContext(), val);
1308   }
1309 
1310   if (matchPattern(getRhs(), m_Zero())) {
1311     if (auto extOp = getLhs().getDefiningOp<ExtSIOp>()) {
1312       // extsi(%x : i1 -> iN) != 0  ->  %x
1313       if (extOp.getOperand().getType().cast<IntegerType>().getWidth() == 1 &&
1314           getPredicate() == arith::CmpIPredicate::ne)
1315         return extOp.getOperand();
1316     }
1317     if (auto extOp = getLhs().getDefiningOp<ExtUIOp>()) {
1318       // extui(%x : i1 -> iN) != 0  ->  %x
1319       if (extOp.getOperand().getType().cast<IntegerType>().getWidth() == 1 &&
1320           getPredicate() == arith::CmpIPredicate::ne)
1321         return extOp.getOperand();
1322     }
1323   }
1324 
1325   auto lhs = operands.front().dyn_cast_or_null<IntegerAttr>();
1326   auto rhs = operands.back().dyn_cast_or_null<IntegerAttr>();
1327   if (!lhs || !rhs)
1328     return {};
1329 
1330   auto val = applyCmpPredicate(getPredicate(), lhs.getValue(), rhs.getValue());
1331   return BoolAttr::get(getContext(), val);
1332 }
1333 
1334 void arith::CmpIOp::getCanonicalizationPatterns(RewritePatternSet &patterns,
1335                                                 MLIRContext *context) {
1336   patterns.insert<CmpIExtSI, CmpIExtUI>(context);
1337 }
1338 
1339 //===----------------------------------------------------------------------===//
1340 // CmpFOp
1341 //===----------------------------------------------------------------------===//
1342 
1343 /// Compute `lhs` `pred` `rhs`, where `pred` is one of the known floating point
1344 /// comparison predicates.
1345 bool mlir::arith::applyCmpPredicate(arith::CmpFPredicate predicate,
1346                                     const APFloat &lhs, const APFloat &rhs) {
1347   auto cmpResult = lhs.compare(rhs);
1348   switch (predicate) {
1349   case arith::CmpFPredicate::AlwaysFalse:
1350     return false;
1351   case arith::CmpFPredicate::OEQ:
1352     return cmpResult == APFloat::cmpEqual;
1353   case arith::CmpFPredicate::OGT:
1354     return cmpResult == APFloat::cmpGreaterThan;
1355   case arith::CmpFPredicate::OGE:
1356     return cmpResult == APFloat::cmpGreaterThan ||
1357            cmpResult == APFloat::cmpEqual;
1358   case arith::CmpFPredicate::OLT:
1359     return cmpResult == APFloat::cmpLessThan;
1360   case arith::CmpFPredicate::OLE:
1361     return cmpResult == APFloat::cmpLessThan || cmpResult == APFloat::cmpEqual;
1362   case arith::CmpFPredicate::ONE:
1363     return cmpResult != APFloat::cmpUnordered && cmpResult != APFloat::cmpEqual;
1364   case arith::CmpFPredicate::ORD:
1365     return cmpResult != APFloat::cmpUnordered;
1366   case arith::CmpFPredicate::UEQ:
1367     return cmpResult == APFloat::cmpUnordered || cmpResult == APFloat::cmpEqual;
1368   case arith::CmpFPredicate::UGT:
1369     return cmpResult == APFloat::cmpUnordered ||
1370            cmpResult == APFloat::cmpGreaterThan;
1371   case arith::CmpFPredicate::UGE:
1372     return cmpResult == APFloat::cmpUnordered ||
1373            cmpResult == APFloat::cmpGreaterThan ||
1374            cmpResult == APFloat::cmpEqual;
1375   case arith::CmpFPredicate::ULT:
1376     return cmpResult == APFloat::cmpUnordered ||
1377            cmpResult == APFloat::cmpLessThan;
1378   case arith::CmpFPredicate::ULE:
1379     return cmpResult == APFloat::cmpUnordered ||
1380            cmpResult == APFloat::cmpLessThan || cmpResult == APFloat::cmpEqual;
1381   case arith::CmpFPredicate::UNE:
1382     return cmpResult != APFloat::cmpEqual;
1383   case arith::CmpFPredicate::UNO:
1384     return cmpResult == APFloat::cmpUnordered;
1385   case arith::CmpFPredicate::AlwaysTrue:
1386     return true;
1387   }
1388   llvm_unreachable("unknown cmpf predicate kind");
1389 }
1390 
1391 OpFoldResult arith::CmpFOp::fold(ArrayRef<Attribute> operands) {
1392   assert(operands.size() == 2 && "cmpf takes two operands");
1393 
1394   auto lhs = operands.front().dyn_cast_or_null<FloatAttr>();
1395   auto rhs = operands.back().dyn_cast_or_null<FloatAttr>();
1396 
1397   // If one operand is NaN, making them both NaN does not change the result.
1398   if (lhs && lhs.getValue().isNaN())
1399     rhs = lhs;
1400   if (rhs && rhs.getValue().isNaN())
1401     lhs = rhs;
1402 
1403   if (!lhs || !rhs)
1404     return {};
1405 
1406   auto val = applyCmpPredicate(getPredicate(), lhs.getValue(), rhs.getValue());
1407   return BoolAttr::get(getContext(), val);
1408 }
1409 
1410 class CmpFIntToFPConst final : public OpRewritePattern<CmpFOp> {
1411 public:
1412   using OpRewritePattern<CmpFOp>::OpRewritePattern;
1413 
1414   static CmpIPredicate convertToIntegerPredicate(CmpFPredicate pred,
1415                                                  bool isUnsigned) {
1416     using namespace arith;
1417     switch (pred) {
1418     case CmpFPredicate::UEQ:
1419     case CmpFPredicate::OEQ:
1420       return CmpIPredicate::eq;
1421     case CmpFPredicate::UGT:
1422     case CmpFPredicate::OGT:
1423       return isUnsigned ? CmpIPredicate::ugt : CmpIPredicate::sgt;
1424     case CmpFPredicate::UGE:
1425     case CmpFPredicate::OGE:
1426       return isUnsigned ? CmpIPredicate::uge : CmpIPredicate::sge;
1427     case CmpFPredicate::ULT:
1428     case CmpFPredicate::OLT:
1429       return isUnsigned ? CmpIPredicate::ult : CmpIPredicate::slt;
1430     case CmpFPredicate::ULE:
1431     case CmpFPredicate::OLE:
1432       return isUnsigned ? CmpIPredicate::ule : CmpIPredicate::sle;
1433     case CmpFPredicate::UNE:
1434     case CmpFPredicate::ONE:
1435       return CmpIPredicate::ne;
1436     default:
1437       llvm_unreachable("Unexpected predicate!");
1438     }
1439   }
1440 
1441   LogicalResult matchAndRewrite(CmpFOp op,
1442                                 PatternRewriter &rewriter) const override {
1443     FloatAttr flt;
1444     if (!matchPattern(op.getRhs(), m_Constant(&flt)))
1445       return failure();
1446 
1447     const APFloat &rhs = flt.getValue();
1448 
1449     // Don't attempt to fold a nan.
1450     if (rhs.isNaN())
1451       return failure();
1452 
1453     // Get the width of the mantissa.  We don't want to hack on conversions that
1454     // might lose information from the integer, e.g. "i64 -> float"
1455     FloatType floatTy = op.getRhs().getType().cast<FloatType>();
1456     int mantissaWidth = floatTy.getFPMantissaWidth();
1457     if (mantissaWidth <= 0)
1458       return failure();
1459 
1460     bool isUnsigned;
1461     Value intVal;
1462 
1463     if (auto si = op.getLhs().getDefiningOp<SIToFPOp>()) {
1464       isUnsigned = false;
1465       intVal = si.getIn();
1466     } else if (auto ui = op.getLhs().getDefiningOp<UIToFPOp>()) {
1467       isUnsigned = true;
1468       intVal = ui.getIn();
1469     } else {
1470       return failure();
1471     }
1472 
1473     // Check to see that the input is converted from an integer type that is
1474     // small enough that preserves all bits.
1475     auto intTy = intVal.getType().cast<IntegerType>();
1476     auto intWidth = intTy.getWidth();
1477 
1478     // Number of bits representing values, as opposed to the sign
1479     auto valueBits = isUnsigned ? intWidth : (intWidth - 1);
1480 
1481     // Following test does NOT adjust intWidth downwards for signed inputs,
1482     // because the most negative value still requires all the mantissa bits
1483     // to distinguish it from one less than that value.
1484     if ((int)intWidth > mantissaWidth) {
1485       // Conversion would lose accuracy. Check if loss can impact comparison.
1486       int exponent = ilogb(rhs);
1487       if (exponent == APFloat::IEK_Inf) {
1488         int maxExponent = ilogb(APFloat::getLargest(rhs.getSemantics()));
1489         if (maxExponent < (int)valueBits) {
1490           // Conversion could create infinity.
1491           return failure();
1492         }
1493       } else {
1494         // Note that if rhs is zero or NaN, then Exp is negative
1495         // and first condition is trivially false.
1496         if (mantissaWidth <= exponent && exponent <= (int)valueBits) {
1497           // Conversion could affect comparison.
1498           return failure();
1499         }
1500       }
1501     }
1502 
1503     // Convert to equivalent cmpi predicate
1504     CmpIPredicate pred;
1505     switch (op.getPredicate()) {
1506     case CmpFPredicate::ORD:
1507       // Int to fp conversion doesn't create a nan (ord checks neither is a nan)
1508       rewriter.replaceOpWithNewOp<ConstantIntOp>(op, /*value=*/true,
1509                                                  /*width=*/1);
1510       return success();
1511     case CmpFPredicate::UNO:
1512       // Int to fp conversion doesn't create a nan (uno checks either is a nan)
1513       rewriter.replaceOpWithNewOp<ConstantIntOp>(op, /*value=*/false,
1514                                                  /*width=*/1);
1515       return success();
1516     default:
1517       pred = convertToIntegerPredicate(op.getPredicate(), isUnsigned);
1518       break;
1519     }
1520 
1521     if (!isUnsigned) {
1522       // If the rhs value is > SignedMax, fold the comparison.  This handles
1523       // +INF and large values.
1524       APFloat signedMax(rhs.getSemantics());
1525       signedMax.convertFromAPInt(APInt::getSignedMaxValue(intWidth), true,
1526                                  APFloat::rmNearestTiesToEven);
1527       if (signedMax < rhs) { // smax < 13123.0
1528         if (pred == CmpIPredicate::ne || pred == CmpIPredicate::slt ||
1529             pred == CmpIPredicate::sle)
1530           rewriter.replaceOpWithNewOp<ConstantIntOp>(op, /*value=*/true,
1531                                                      /*width=*/1);
1532         else
1533           rewriter.replaceOpWithNewOp<ConstantIntOp>(op, /*value=*/false,
1534                                                      /*width=*/1);
1535         return success();
1536       }
1537     } else {
1538       // If the rhs value is > UnsignedMax, fold the comparison. This handles
1539       // +INF and large values.
1540       APFloat unsignedMax(rhs.getSemantics());
1541       unsignedMax.convertFromAPInt(APInt::getMaxValue(intWidth), false,
1542                                    APFloat::rmNearestTiesToEven);
1543       if (unsignedMax < rhs) { // umax < 13123.0
1544         if (pred == CmpIPredicate::ne || pred == CmpIPredicate::ult ||
1545             pred == CmpIPredicate::ule)
1546           rewriter.replaceOpWithNewOp<ConstantIntOp>(op, /*value=*/true,
1547                                                      /*width=*/1);
1548         else
1549           rewriter.replaceOpWithNewOp<ConstantIntOp>(op, /*value=*/false,
1550                                                      /*width=*/1);
1551         return success();
1552       }
1553     }
1554 
1555     if (!isUnsigned) {
1556       // See if the rhs value is < SignedMin.
1557       APFloat signedMin(rhs.getSemantics());
1558       signedMin.convertFromAPInt(APInt::getSignedMinValue(intWidth), true,
1559                                  APFloat::rmNearestTiesToEven);
1560       if (signedMin > rhs) { // smin > 12312.0
1561         if (pred == CmpIPredicate::ne || pred == CmpIPredicate::sgt ||
1562             pred == CmpIPredicate::sge)
1563           rewriter.replaceOpWithNewOp<ConstantIntOp>(op, /*value=*/true,
1564                                                      /*width=*/1);
1565         else
1566           rewriter.replaceOpWithNewOp<ConstantIntOp>(op, /*value=*/false,
1567                                                      /*width=*/1);
1568         return success();
1569       }
1570     } else {
1571       // See if the rhs value is < UnsignedMin.
1572       APFloat unsignedMin(rhs.getSemantics());
1573       unsignedMin.convertFromAPInt(APInt::getMinValue(intWidth), false,
1574                                    APFloat::rmNearestTiesToEven);
1575       if (unsignedMin > rhs) { // umin > 12312.0
1576         if (pred == CmpIPredicate::ne || pred == CmpIPredicate::ugt ||
1577             pred == CmpIPredicate::uge)
1578           rewriter.replaceOpWithNewOp<ConstantIntOp>(op, /*value=*/true,
1579                                                      /*width=*/1);
1580         else
1581           rewriter.replaceOpWithNewOp<ConstantIntOp>(op, /*value=*/false,
1582                                                      /*width=*/1);
1583         return success();
1584       }
1585     }
1586 
1587     // Okay, now we know that the FP constant fits in the range [SMIN, SMAX] or
1588     // [0, UMAX], but it may still be fractional.  See if it is fractional by
1589     // casting the FP value to the integer value and back, checking for
1590     // equality. Don't do this for zero, because -0.0 is not fractional.
1591     bool ignored;
1592     APSInt rhsInt(intWidth, isUnsigned);
1593     if (APFloat::opInvalidOp ==
1594         rhs.convertToInteger(rhsInt, APFloat::rmTowardZero, &ignored)) {
1595       // Undefined behavior invoked - the destination type can't represent
1596       // the input constant.
1597       return failure();
1598     }
1599 
1600     if (!rhs.isZero()) {
1601       APFloat apf(floatTy.getFloatSemantics(),
1602                   APInt::getZero(floatTy.getWidth()));
1603       apf.convertFromAPInt(rhsInt, !isUnsigned, APFloat::rmNearestTiesToEven);
1604 
1605       bool equal = apf == rhs;
1606       if (!equal) {
1607         // If we had a comparison against a fractional value, we have to adjust
1608         // the compare predicate and sometimes the value.  rhsInt is rounded
1609         // towards zero at this point.
1610         switch (pred) {
1611         case CmpIPredicate::ne: // (float)int != 4.4   --> true
1612           rewriter.replaceOpWithNewOp<ConstantIntOp>(op, /*value=*/true,
1613                                                      /*width=*/1);
1614           return success();
1615         case CmpIPredicate::eq: // (float)int == 4.4   --> false
1616           rewriter.replaceOpWithNewOp<ConstantIntOp>(op, /*value=*/false,
1617                                                      /*width=*/1);
1618           return success();
1619         case CmpIPredicate::ule:
1620           // (float)int <= 4.4   --> int <= 4
1621           // (float)int <= -4.4  --> false
1622           if (rhs.isNegative()) {
1623             rewriter.replaceOpWithNewOp<ConstantIntOp>(op, /*value=*/false,
1624                                                        /*width=*/1);
1625             return success();
1626           }
1627           break;
1628         case CmpIPredicate::sle:
1629           // (float)int <= 4.4   --> int <= 4
1630           // (float)int <= -4.4  --> int < -4
1631           if (rhs.isNegative())
1632             pred = CmpIPredicate::slt;
1633           break;
1634         case CmpIPredicate::ult:
1635           // (float)int < -4.4   --> false
1636           // (float)int < 4.4    --> int <= 4
1637           if (rhs.isNegative()) {
1638             rewriter.replaceOpWithNewOp<ConstantIntOp>(op, /*value=*/false,
1639                                                        /*width=*/1);
1640             return success();
1641           }
1642           pred = CmpIPredicate::ule;
1643           break;
1644         case CmpIPredicate::slt:
1645           // (float)int < -4.4   --> int < -4
1646           // (float)int < 4.4    --> int <= 4
1647           if (!rhs.isNegative())
1648             pred = CmpIPredicate::sle;
1649           break;
1650         case CmpIPredicate::ugt:
1651           // (float)int > 4.4    --> int > 4
1652           // (float)int > -4.4   --> true
1653           if (rhs.isNegative()) {
1654             rewriter.replaceOpWithNewOp<ConstantIntOp>(op, /*value=*/true,
1655                                                        /*width=*/1);
1656             return success();
1657           }
1658           break;
1659         case CmpIPredicate::sgt:
1660           // (float)int > 4.4    --> int > 4
1661           // (float)int > -4.4   --> int >= -4
1662           if (rhs.isNegative())
1663             pred = CmpIPredicate::sge;
1664           break;
1665         case CmpIPredicate::uge:
1666           // (float)int >= -4.4   --> true
1667           // (float)int >= 4.4    --> int > 4
1668           if (rhs.isNegative()) {
1669             rewriter.replaceOpWithNewOp<ConstantIntOp>(op, /*value=*/true,
1670                                                        /*width=*/1);
1671             return success();
1672           }
1673           pred = CmpIPredicate::ugt;
1674           break;
1675         case CmpIPredicate::sge:
1676           // (float)int >= -4.4   --> int >= -4
1677           // (float)int >= 4.4    --> int > 4
1678           if (!rhs.isNegative())
1679             pred = CmpIPredicate::sgt;
1680           break;
1681         }
1682       }
1683     }
1684 
1685     // Lower this FP comparison into an appropriate integer version of the
1686     // comparison.
1687     rewriter.replaceOpWithNewOp<CmpIOp>(
1688         op, pred, intVal,
1689         rewriter.create<ConstantOp>(
1690             op.getLoc(), intVal.getType(),
1691             rewriter.getIntegerAttr(intVal.getType(), rhsInt)));
1692     return success();
1693   }
1694 };
1695 
1696 void arith::CmpFOp::getCanonicalizationPatterns(RewritePatternSet &patterns,
1697                                                 MLIRContext *context) {
1698   patterns.insert<CmpFIntToFPConst>(context);
1699 }
1700 
1701 //===----------------------------------------------------------------------===//
1702 // SelectOp
1703 //===----------------------------------------------------------------------===//
1704 
1705 // Transforms a select of a boolean to arithmetic operations
1706 //
1707 //  arith.select %arg, %x, %y : i1
1708 //
1709 //  becomes
1710 //
1711 //  and(%arg, %x) or and(!%arg, %y)
1712 struct SelectI1Simplify : public OpRewritePattern<arith::SelectOp> {
1713   using OpRewritePattern<arith::SelectOp>::OpRewritePattern;
1714 
1715   LogicalResult matchAndRewrite(arith::SelectOp op,
1716                                 PatternRewriter &rewriter) const override {
1717     if (!op.getType().isInteger(1))
1718       return failure();
1719 
1720     Value falseConstant =
1721         rewriter.create<arith::ConstantIntOp>(op.getLoc(), true, 1);
1722     Value notCondition = rewriter.create<arith::XOrIOp>(
1723         op.getLoc(), op.getCondition(), falseConstant);
1724 
1725     Value trueVal = rewriter.create<arith::AndIOp>(
1726         op.getLoc(), op.getCondition(), op.getTrueValue());
1727     Value falseVal = rewriter.create<arith::AndIOp>(op.getLoc(), notCondition,
1728                                                     op.getFalseValue());
1729     rewriter.replaceOpWithNewOp<arith::OrIOp>(op, trueVal, falseVal);
1730     return success();
1731   }
1732 };
1733 
1734 //  select %arg, %c1, %c0 => extui %arg
1735 struct SelectToExtUI : public OpRewritePattern<arith::SelectOp> {
1736   using OpRewritePattern<arith::SelectOp>::OpRewritePattern;
1737 
1738   LogicalResult matchAndRewrite(arith::SelectOp op,
1739                                 PatternRewriter &rewriter) const override {
1740     // Cannot extui i1 to i1, or i1 to f32
1741     if (!op.getType().isa<IntegerType>() || op.getType().isInteger(1))
1742       return failure();
1743 
1744     // select %x, c1, %c0 => extui %arg
1745     if (matchPattern(op.getTrueValue(), m_One()) &&
1746         matchPattern(op.getFalseValue(), m_Zero())) {
1747       rewriter.replaceOpWithNewOp<arith::ExtUIOp>(op, op.getType(),
1748                                                   op.getCondition());
1749       return success();
1750     }
1751 
1752     // select %x, c0, %c1 => extui (xor %arg, true)
1753     if (matchPattern(op.getTrueValue(), m_Zero()) &&
1754         matchPattern(op.getFalseValue(), m_One())) {
1755       rewriter.replaceOpWithNewOp<arith::ExtUIOp>(
1756           op, op.getType(),
1757           rewriter.create<arith::XOrIOp>(
1758               op.getLoc(), op.getCondition(),
1759               rewriter.create<arith::ConstantIntOp>(
1760                   op.getLoc(), 1, op.getCondition().getType())));
1761       return success();
1762     }
1763 
1764     return failure();
1765   }
1766 };
1767 
1768 void arith::SelectOp::getCanonicalizationPatterns(RewritePatternSet &results,
1769                                                   MLIRContext *context) {
1770   results.add<SelectI1Simplify, SelectToExtUI>(context);
1771 }
1772 
1773 OpFoldResult arith::SelectOp::fold(ArrayRef<Attribute> operands) {
1774   Value trueVal = getTrueValue();
1775   Value falseVal = getFalseValue();
1776   if (trueVal == falseVal)
1777     return trueVal;
1778 
1779   Value condition = getCondition();
1780 
1781   // select true, %0, %1 => %0
1782   if (matchPattern(condition, m_One()))
1783     return trueVal;
1784 
1785   // select false, %0, %1 => %1
1786   if (matchPattern(condition, m_Zero()))
1787     return falseVal;
1788 
1789   // select %x, true, false => %x
1790   if (getType().isInteger(1) && matchPattern(getTrueValue(), m_One()) &&
1791       matchPattern(getFalseValue(), m_Zero()))
1792     return condition;
1793 
1794   if (auto cmp = dyn_cast_or_null<arith::CmpIOp>(condition.getDefiningOp())) {
1795     auto pred = cmp.getPredicate();
1796     if (pred == arith::CmpIPredicate::eq || pred == arith::CmpIPredicate::ne) {
1797       auto cmpLhs = cmp.getLhs();
1798       auto cmpRhs = cmp.getRhs();
1799 
1800       // %0 = arith.cmpi eq, %arg0, %arg1
1801       // %1 = arith.select %0, %arg0, %arg1 => %arg1
1802 
1803       // %0 = arith.cmpi ne, %arg0, %arg1
1804       // %1 = arith.select %0, %arg0, %arg1 => %arg0
1805 
1806       if ((cmpLhs == trueVal && cmpRhs == falseVal) ||
1807           (cmpRhs == trueVal && cmpLhs == falseVal))
1808         return pred == arith::CmpIPredicate::ne ? trueVal : falseVal;
1809     }
1810   }
1811   return nullptr;
1812 }
1813 
1814 ParseResult SelectOp::parse(OpAsmParser &parser, OperationState &result) {
1815   Type conditionType, resultType;
1816   SmallVector<OpAsmParser::UnresolvedOperand, 3> operands;
1817   if (parser.parseOperandList(operands, /*requiredOperandCount=*/3) ||
1818       parser.parseOptionalAttrDict(result.attributes) ||
1819       parser.parseColonType(resultType))
1820     return failure();
1821 
1822   // Check for the explicit condition type if this is a masked tensor or vector.
1823   if (succeeded(parser.parseOptionalComma())) {
1824     conditionType = resultType;
1825     if (parser.parseType(resultType))
1826       return failure();
1827   } else {
1828     conditionType = parser.getBuilder().getI1Type();
1829   }
1830 
1831   result.addTypes(resultType);
1832   return parser.resolveOperands(operands,
1833                                 {conditionType, resultType, resultType},
1834                                 parser.getNameLoc(), result.operands);
1835 }
1836 
1837 void arith::SelectOp::print(OpAsmPrinter &p) {
1838   p << " " << getOperands();
1839   p.printOptionalAttrDict((*this)->getAttrs());
1840   p << " : ";
1841   if (ShapedType condType = getCondition().getType().dyn_cast<ShapedType>())
1842     p << condType << ", ";
1843   p << getType();
1844 }
1845 
1846 LogicalResult arith::SelectOp::verify() {
1847   Type conditionType = getCondition().getType();
1848   if (conditionType.isSignlessInteger(1))
1849     return success();
1850 
1851   // If the result type is a vector or tensor, the type can be a mask with the
1852   // same elements.
1853   Type resultType = getType();
1854   if (!resultType.isa<TensorType, VectorType>())
1855     return emitOpError() << "expected condition to be a signless i1, but got "
1856                          << conditionType;
1857   Type shapedConditionType = getI1SameShape(resultType);
1858   if (conditionType != shapedConditionType) {
1859     return emitOpError() << "expected condition type to have the same shape "
1860                             "as the result type, expected "
1861                          << shapedConditionType << ", but got "
1862                          << conditionType;
1863   }
1864   return success();
1865 }
1866 //===----------------------------------------------------------------------===//
1867 // ShLIOp
1868 //===----------------------------------------------------------------------===//
1869 
1870 OpFoldResult arith::ShLIOp::fold(ArrayRef<Attribute> operands) {
1871   // Don't fold if shifting more than the bit width.
1872   bool bounded = false;
1873   auto result = constFoldBinaryOp<IntegerAttr>(
1874       operands, [&](const APInt &a, const APInt &b) {
1875         bounded = b.ule(b.getBitWidth());
1876         return a.shl(b);
1877       });
1878   return bounded ? result : Attribute();
1879 }
1880 
1881 //===----------------------------------------------------------------------===//
1882 // ShRUIOp
1883 //===----------------------------------------------------------------------===//
1884 
1885 OpFoldResult arith::ShRUIOp::fold(ArrayRef<Attribute> operands) {
1886   // Don't fold if shifting more than the bit width.
1887   bool bounded = false;
1888   auto result = constFoldBinaryOp<IntegerAttr>(
1889       operands, [&](const APInt &a, const APInt &b) {
1890         bounded = b.ule(b.getBitWidth());
1891         return a.lshr(b);
1892       });
1893   return bounded ? result : Attribute();
1894 }
1895 
1896 //===----------------------------------------------------------------------===//
1897 // ShRSIOp
1898 //===----------------------------------------------------------------------===//
1899 
1900 OpFoldResult arith::ShRSIOp::fold(ArrayRef<Attribute> operands) {
1901   // Don't fold if shifting more than the bit width.
1902   bool bounded = false;
1903   auto result = constFoldBinaryOp<IntegerAttr>(
1904       operands, [&](const APInt &a, const APInt &b) {
1905         bounded = b.ule(b.getBitWidth());
1906         return a.ashr(b);
1907       });
1908   return bounded ? result : Attribute();
1909 }
1910 
1911 //===----------------------------------------------------------------------===//
1912 // Atomic Enum
1913 //===----------------------------------------------------------------------===//
1914 
1915 /// Returns the identity value attribute associated with an AtomicRMWKind op.
1916 Attribute mlir::arith::getIdentityValueAttr(AtomicRMWKind kind, Type resultType,
1917                                             OpBuilder &builder, Location loc) {
1918   switch (kind) {
1919   case AtomicRMWKind::maxf:
1920     return builder.getFloatAttr(
1921         resultType,
1922         APFloat::getInf(resultType.cast<FloatType>().getFloatSemantics(),
1923                         /*Negative=*/true));
1924   case AtomicRMWKind::addf:
1925   case AtomicRMWKind::addi:
1926   case AtomicRMWKind::maxu:
1927   case AtomicRMWKind::ori:
1928     return builder.getZeroAttr(resultType);
1929   case AtomicRMWKind::andi:
1930     return builder.getIntegerAttr(
1931         resultType,
1932         APInt::getAllOnes(resultType.cast<IntegerType>().getWidth()));
1933   case AtomicRMWKind::maxs:
1934     return builder.getIntegerAttr(
1935         resultType,
1936         APInt::getSignedMinValue(resultType.cast<IntegerType>().getWidth()));
1937   case AtomicRMWKind::minf:
1938     return builder.getFloatAttr(
1939         resultType,
1940         APFloat::getInf(resultType.cast<FloatType>().getFloatSemantics(),
1941                         /*Negative=*/false));
1942   case AtomicRMWKind::mins:
1943     return builder.getIntegerAttr(
1944         resultType,
1945         APInt::getSignedMaxValue(resultType.cast<IntegerType>().getWidth()));
1946   case AtomicRMWKind::minu:
1947     return builder.getIntegerAttr(
1948         resultType,
1949         APInt::getMaxValue(resultType.cast<IntegerType>().getWidth()));
1950   case AtomicRMWKind::muli:
1951     return builder.getIntegerAttr(resultType, 1);
1952   case AtomicRMWKind::mulf:
1953     return builder.getFloatAttr(resultType, 1);
1954   // TODO: Add remaining reduction operations.
1955   default:
1956     (void)emitOptionalError(loc, "Reduction operation type not supported");
1957     break;
1958   }
1959   return nullptr;
1960 }
1961 
1962 /// Returns the identity value associated with an AtomicRMWKind op.
1963 Value mlir::arith::getIdentityValue(AtomicRMWKind op, Type resultType,
1964                                     OpBuilder &builder, Location loc) {
1965   Attribute attr = getIdentityValueAttr(op, resultType, builder, loc);
1966   return builder.create<arith::ConstantOp>(loc, attr);
1967 }
1968 
1969 /// Return the value obtained by applying the reduction operation kind
1970 /// associated with a binary AtomicRMWKind op to `lhs` and `rhs`.
1971 Value mlir::arith::getReductionOp(AtomicRMWKind op, OpBuilder &builder,
1972                                   Location loc, Value lhs, Value rhs) {
1973   switch (op) {
1974   case AtomicRMWKind::addf:
1975     return builder.create<arith::AddFOp>(loc, lhs, rhs);
1976   case AtomicRMWKind::addi:
1977     return builder.create<arith::AddIOp>(loc, lhs, rhs);
1978   case AtomicRMWKind::mulf:
1979     return builder.create<arith::MulFOp>(loc, lhs, rhs);
1980   case AtomicRMWKind::muli:
1981     return builder.create<arith::MulIOp>(loc, lhs, rhs);
1982   case AtomicRMWKind::maxf:
1983     return builder.create<arith::MaxFOp>(loc, lhs, rhs);
1984   case AtomicRMWKind::minf:
1985     return builder.create<arith::MinFOp>(loc, lhs, rhs);
1986   case AtomicRMWKind::maxs:
1987     return builder.create<arith::MaxSIOp>(loc, lhs, rhs);
1988   case AtomicRMWKind::mins:
1989     return builder.create<arith::MinSIOp>(loc, lhs, rhs);
1990   case AtomicRMWKind::maxu:
1991     return builder.create<arith::MaxUIOp>(loc, lhs, rhs);
1992   case AtomicRMWKind::minu:
1993     return builder.create<arith::MinUIOp>(loc, lhs, rhs);
1994   case AtomicRMWKind::ori:
1995     return builder.create<arith::OrIOp>(loc, lhs, rhs);
1996   case AtomicRMWKind::andi:
1997     return builder.create<arith::AndIOp>(loc, lhs, rhs);
1998   // TODO: Add remaining reduction operations.
1999   default:
2000     (void)emitOptionalError(loc, "Reduction operation type not supported");
2001     break;
2002   }
2003   return nullptr;
2004 }
2005 
2006 //===----------------------------------------------------------------------===//
2007 // TableGen'd op method definitions
2008 //===----------------------------------------------------------------------===//
2009 
2010 #define GET_OP_CLASSES
2011 #include "mlir/Dialect/Arithmetic/IR/ArithmeticOps.cpp.inc"
2012 
2013 //===----------------------------------------------------------------------===//
2014 // TableGen'd enum attribute definitions
2015 //===----------------------------------------------------------------------===//
2016 
2017 #include "mlir/Dialect/Arithmetic/IR/ArithmeticOpsEnums.cpp.inc"
2018