1 //===- ArithmeticOps.cpp - MLIR Arithmetic dialect ops implementation -----===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 
9 #include <utility>
10 
11 #include "mlir/Dialect/Arithmetic/IR/Arithmetic.h"
12 #include "mlir/Dialect/CommonFolders.h"
13 #include "mlir/IR/Builders.h"
14 #include "mlir/IR/Matchers.h"
15 #include "mlir/IR/OpImplementation.h"
16 #include "mlir/IR/PatternMatch.h"
17 #include "mlir/IR/TypeUtilities.h"
18 #include "llvm/ADT/SmallString.h"
19 
20 #include "llvm/ADT/APSInt.h"
21 
22 using namespace mlir;
23 using namespace mlir::arith;
24 
25 //===----------------------------------------------------------------------===//
26 // Pattern helpers
27 //===----------------------------------------------------------------------===//
28 
29 static IntegerAttr addIntegerAttrs(PatternRewriter &builder, Value res,
30                                    Attribute lhs, Attribute rhs) {
31   return builder.getIntegerAttr(res.getType(),
32                                 lhs.cast<IntegerAttr>().getInt() +
33                                     rhs.cast<IntegerAttr>().getInt());
34 }
35 
36 static IntegerAttr subIntegerAttrs(PatternRewriter &builder, Value res,
37                                    Attribute lhs, Attribute rhs) {
38   return builder.getIntegerAttr(res.getType(),
39                                 lhs.cast<IntegerAttr>().getInt() -
40                                     rhs.cast<IntegerAttr>().getInt());
41 }
42 
43 /// Invert an integer comparison predicate.
44 arith::CmpIPredicate arith::invertPredicate(arith::CmpIPredicate pred) {
45   switch (pred) {
46   case arith::CmpIPredicate::eq:
47     return arith::CmpIPredicate::ne;
48   case arith::CmpIPredicate::ne:
49     return arith::CmpIPredicate::eq;
50   case arith::CmpIPredicate::slt:
51     return arith::CmpIPredicate::sge;
52   case arith::CmpIPredicate::sle:
53     return arith::CmpIPredicate::sgt;
54   case arith::CmpIPredicate::sgt:
55     return arith::CmpIPredicate::sle;
56   case arith::CmpIPredicate::sge:
57     return arith::CmpIPredicate::slt;
58   case arith::CmpIPredicate::ult:
59     return arith::CmpIPredicate::uge;
60   case arith::CmpIPredicate::ule:
61     return arith::CmpIPredicate::ugt;
62   case arith::CmpIPredicate::ugt:
63     return arith::CmpIPredicate::ule;
64   case arith::CmpIPredicate::uge:
65     return arith::CmpIPredicate::ult;
66   }
67   llvm_unreachable("unknown cmpi predicate kind");
68 }
69 
70 static arith::CmpIPredicateAttr invertPredicate(arith::CmpIPredicateAttr pred) {
71   return arith::CmpIPredicateAttr::get(pred.getContext(),
72                                        invertPredicate(pred.getValue()));
73 }
74 
75 //===----------------------------------------------------------------------===//
76 // TableGen'd canonicalization patterns
77 //===----------------------------------------------------------------------===//
78 
79 namespace {
80 #include "ArithmeticCanonicalization.inc"
81 } // namespace
82 
83 //===----------------------------------------------------------------------===//
84 // ConstantOp
85 //===----------------------------------------------------------------------===//
86 
87 void arith::ConstantOp::getAsmResultNames(
88     function_ref<void(Value, StringRef)> setNameFn) {
89   auto type = getType();
90   if (auto intCst = getValue().dyn_cast<IntegerAttr>()) {
91     auto intType = type.dyn_cast<IntegerType>();
92 
93     // Sugar i1 constants with 'true' and 'false'.
94     if (intType && intType.getWidth() == 1)
95       return setNameFn(getResult(), (intCst.getInt() ? "true" : "false"));
96 
97     // Otherwise, build a compex name with the value and type.
98     SmallString<32> specialNameBuffer;
99     llvm::raw_svector_ostream specialName(specialNameBuffer);
100     specialName << 'c' << intCst.getInt();
101     if (intType)
102       specialName << '_' << type;
103     setNameFn(getResult(), specialName.str());
104   } else {
105     setNameFn(getResult(), "cst");
106   }
107 }
108 
109 /// TODO: disallow arith.constant to return anything other than signless integer
110 /// or float like.
111 LogicalResult arith::ConstantOp::verify() {
112   auto type = getType();
113   // The value's type must match the return type.
114   if (getValue().getType() != type) {
115     return emitOpError() << "value type " << getValue().getType()
116                          << " must match return type: " << type;
117   }
118   // Integer values must be signless.
119   if (type.isa<IntegerType>() && !type.cast<IntegerType>().isSignless())
120     return emitOpError("integer return type must be signless");
121   // Any float or elements attribute are acceptable.
122   if (!getValue().isa<IntegerAttr, FloatAttr, ElementsAttr>()) {
123     return emitOpError(
124         "value must be an integer, float, or elements attribute");
125   }
126   return success();
127 }
128 
129 bool arith::ConstantOp::isBuildableWith(Attribute value, Type type) {
130   // The value's type must be the same as the provided type.
131   if (value.getType() != type)
132     return false;
133   // Integer values must be signless.
134   if (type.isa<IntegerType>() && !type.cast<IntegerType>().isSignless())
135     return false;
136   // Integer, float, and element attributes are buildable.
137   return value.isa<IntegerAttr, FloatAttr, ElementsAttr>();
138 }
139 
140 OpFoldResult arith::ConstantOp::fold(ArrayRef<Attribute> operands) {
141   return getValue();
142 }
143 
144 void arith::ConstantIntOp::build(OpBuilder &builder, OperationState &result,
145                                  int64_t value, unsigned width) {
146   auto type = builder.getIntegerType(width);
147   arith::ConstantOp::build(builder, result, type,
148                            builder.getIntegerAttr(type, value));
149 }
150 
151 void arith::ConstantIntOp::build(OpBuilder &builder, OperationState &result,
152                                  int64_t value, Type type) {
153   assert(type.isSignlessInteger() &&
154          "ConstantIntOp can only have signless integer type values");
155   arith::ConstantOp::build(builder, result, type,
156                            builder.getIntegerAttr(type, value));
157 }
158 
159 bool arith::ConstantIntOp::classof(Operation *op) {
160   if (auto constOp = dyn_cast_or_null<arith::ConstantOp>(op))
161     return constOp.getType().isSignlessInteger();
162   return false;
163 }
164 
165 void arith::ConstantFloatOp::build(OpBuilder &builder, OperationState &result,
166                                    const APFloat &value, FloatType type) {
167   arith::ConstantOp::build(builder, result, type,
168                            builder.getFloatAttr(type, value));
169 }
170 
171 bool arith::ConstantFloatOp::classof(Operation *op) {
172   if (auto constOp = dyn_cast_or_null<arith::ConstantOp>(op))
173     return constOp.getType().isa<FloatType>();
174   return false;
175 }
176 
177 void arith::ConstantIndexOp::build(OpBuilder &builder, OperationState &result,
178                                    int64_t value) {
179   arith::ConstantOp::build(builder, result, builder.getIndexType(),
180                            builder.getIndexAttr(value));
181 }
182 
183 bool arith::ConstantIndexOp::classof(Operation *op) {
184   if (auto constOp = dyn_cast_or_null<arith::ConstantOp>(op))
185     return constOp.getType().isIndex();
186   return false;
187 }
188 
189 //===----------------------------------------------------------------------===//
190 // AddIOp
191 //===----------------------------------------------------------------------===//
192 
193 OpFoldResult arith::AddIOp::fold(ArrayRef<Attribute> operands) {
194   // addi(x, 0) -> x
195   if (matchPattern(getRhs(), m_Zero()))
196     return getLhs();
197 
198   // addi(subi(a, b), b) -> a
199   if (auto sub = getLhs().getDefiningOp<SubIOp>())
200     if (getRhs() == sub.getRhs())
201       return sub.getLhs();
202 
203   // addi(b, subi(a, b)) -> a
204   if (auto sub = getRhs().getDefiningOp<SubIOp>())
205     if (getLhs() == sub.getRhs())
206       return sub.getLhs();
207 
208   return constFoldBinaryOp<IntegerAttr>(
209       operands, [](APInt a, const APInt &b) { return std::move(a) + b; });
210 }
211 
212 void arith::AddIOp::getCanonicalizationPatterns(
213     RewritePatternSet &patterns, MLIRContext *context) {
214   patterns.add<AddIAddConstant, AddISubConstantRHS, AddISubConstantLHS>(
215       context);
216 }
217 
218 //===----------------------------------------------------------------------===//
219 // SubIOp
220 //===----------------------------------------------------------------------===//
221 
222 OpFoldResult arith::SubIOp::fold(ArrayRef<Attribute> operands) {
223   // subi(x,x) -> 0
224   if (getOperand(0) == getOperand(1))
225     return Builder(getContext()).getZeroAttr(getType());
226   // subi(x,0) -> x
227   if (matchPattern(getRhs(), m_Zero()))
228     return getLhs();
229 
230   return constFoldBinaryOp<IntegerAttr>(
231       operands, [](APInt a, const APInt &b) { return std::move(a) - b; });
232 }
233 
234 void arith::SubIOp::getCanonicalizationPatterns(
235     RewritePatternSet &patterns, MLIRContext *context) {
236   patterns
237       .add<SubIRHSAddConstant, SubILHSAddConstant, SubIRHSSubConstantRHS,
238            SubIRHSSubConstantLHS, SubILHSSubConstantRHS, SubILHSSubConstantLHS>(
239           context);
240 }
241 
242 //===----------------------------------------------------------------------===//
243 // MulIOp
244 //===----------------------------------------------------------------------===//
245 
246 OpFoldResult arith::MulIOp::fold(ArrayRef<Attribute> operands) {
247   // muli(x, 0) -> 0
248   if (matchPattern(getRhs(), m_Zero()))
249     return getRhs();
250   // muli(x, 1) -> x
251   if (matchPattern(getRhs(), m_One()))
252     return getOperand(0);
253   // TODO: Handle the overflow case.
254 
255   // default folder
256   return constFoldBinaryOp<IntegerAttr>(
257       operands, [](const APInt &a, const APInt &b) { return a * b; });
258 }
259 
260 //===----------------------------------------------------------------------===//
261 // DivUIOp
262 //===----------------------------------------------------------------------===//
263 
264 OpFoldResult arith::DivUIOp::fold(ArrayRef<Attribute> operands) {
265   // divui (x, 1) -> x.
266   if (matchPattern(getRhs(), m_One()))
267     return getLhs();
268 
269   // Don't fold if it would require a division by zero.
270   bool div0 = false;
271   auto result =
272       constFoldBinaryOp<IntegerAttr>(operands, [&](APInt a, const APInt &b) {
273         if (div0 || !b) {
274           div0 = true;
275           return a;
276         }
277         return a.udiv(b);
278       });
279 
280   return div0 ? Attribute() : result;
281 }
282 
283 //===----------------------------------------------------------------------===//
284 // DivSIOp
285 //===----------------------------------------------------------------------===//
286 
287 OpFoldResult arith::DivSIOp::fold(ArrayRef<Attribute> operands) {
288   // divsi (x, 1) -> x.
289   if (matchPattern(getRhs(), m_One()))
290     return getLhs();
291 
292   // Don't fold if it would overflow or if it requires a division by zero.
293   bool overflowOrDiv0 = false;
294   auto result =
295       constFoldBinaryOp<IntegerAttr>(operands, [&](APInt a, const APInt &b) {
296         if (overflowOrDiv0 || !b) {
297           overflowOrDiv0 = true;
298           return a;
299         }
300         return a.sdiv_ov(b, overflowOrDiv0);
301       });
302 
303   return overflowOrDiv0 ? Attribute() : result;
304 }
305 
306 //===----------------------------------------------------------------------===//
307 // Ceil and floor division folding helpers
308 //===----------------------------------------------------------------------===//
309 
310 static APInt signedCeilNonnegInputs(const APInt &a, const APInt &b,
311                                     bool &overflow) {
312   // Returns (a-1)/b + 1
313   APInt one(a.getBitWidth(), 1, true); // Signed value 1.
314   APInt val = a.ssub_ov(one, overflow).sdiv_ov(b, overflow);
315   return val.sadd_ov(one, overflow);
316 }
317 
318 //===----------------------------------------------------------------------===//
319 // CeilDivUIOp
320 //===----------------------------------------------------------------------===//
321 
322 OpFoldResult arith::CeilDivUIOp::fold(ArrayRef<Attribute> operands) {
323   // ceildivui (x, 1) -> x.
324   if (matchPattern(getRhs(), m_One()))
325     return getLhs();
326 
327   bool overflowOrDiv0 = false;
328   auto result =
329       constFoldBinaryOp<IntegerAttr>(operands, [&](APInt a, const APInt &b) {
330         if (overflowOrDiv0 || !b) {
331           overflowOrDiv0 = true;
332           return a;
333         }
334         APInt quotient = a.udiv(b);
335         if (!a.urem(b))
336           return quotient;
337         APInt one(a.getBitWidth(), 1, true);
338         return quotient.uadd_ov(one, overflowOrDiv0);
339       });
340 
341   return overflowOrDiv0 ? Attribute() : result;
342 }
343 
344 //===----------------------------------------------------------------------===//
345 // CeilDivSIOp
346 //===----------------------------------------------------------------------===//
347 
348 OpFoldResult arith::CeilDivSIOp::fold(ArrayRef<Attribute> operands) {
349   // ceildivsi (x, 1) -> x.
350   if (matchPattern(getRhs(), m_One()))
351     return getLhs();
352 
353   // Don't fold if it would overflow or if it requires a division by zero.
354   bool overflowOrDiv0 = false;
355   auto result =
356       constFoldBinaryOp<IntegerAttr>(operands, [&](APInt a, const APInt &b) {
357         if (overflowOrDiv0 || !b) {
358           overflowOrDiv0 = true;
359           return a;
360         }
361         if (!a)
362           return a;
363         // After this point we know that neither a or b are zero.
364         unsigned bits = a.getBitWidth();
365         APInt zero = APInt::getZero(bits);
366         bool aGtZero = a.sgt(zero);
367         bool bGtZero = b.sgt(zero);
368         if (aGtZero && bGtZero) {
369           // Both positive, return ceil(a, b).
370           return signedCeilNonnegInputs(a, b, overflowOrDiv0);
371         }
372         if (!aGtZero && !bGtZero) {
373           // Both negative, return ceil(-a, -b).
374           APInt posA = zero.ssub_ov(a, overflowOrDiv0);
375           APInt posB = zero.ssub_ov(b, overflowOrDiv0);
376           return signedCeilNonnegInputs(posA, posB, overflowOrDiv0);
377         }
378         if (!aGtZero && bGtZero) {
379           // A is negative, b is positive, return - ( -a / b).
380           APInt posA = zero.ssub_ov(a, overflowOrDiv0);
381           APInt div = posA.sdiv_ov(b, overflowOrDiv0);
382           return zero.ssub_ov(div, overflowOrDiv0);
383         }
384         // A is positive, b is negative, return - (a / -b).
385         APInt posB = zero.ssub_ov(b, overflowOrDiv0);
386         APInt div = a.sdiv_ov(posB, overflowOrDiv0);
387         return zero.ssub_ov(div, overflowOrDiv0);
388       });
389 
390   return overflowOrDiv0 ? Attribute() : result;
391 }
392 
393 //===----------------------------------------------------------------------===//
394 // FloorDivSIOp
395 //===----------------------------------------------------------------------===//
396 
397 OpFoldResult arith::FloorDivSIOp::fold(ArrayRef<Attribute> operands) {
398   // floordivsi (x, 1) -> x.
399   if (matchPattern(getRhs(), m_One()))
400     return getLhs();
401 
402   // Don't fold if it would overflow or if it requires a division by zero.
403   bool overflowOrDiv0 = false;
404   auto result =
405       constFoldBinaryOp<IntegerAttr>(operands, [&](APInt a, const APInt &b) {
406         if (overflowOrDiv0 || !b) {
407           overflowOrDiv0 = true;
408           return a;
409         }
410         if (!a)
411           return a;
412         // After this point we know that neither a or b are zero.
413         unsigned bits = a.getBitWidth();
414         APInt zero = APInt::getZero(bits);
415         bool aGtZero = a.sgt(zero);
416         bool bGtZero = b.sgt(zero);
417         if (aGtZero && bGtZero) {
418           // Both positive, return a / b.
419           return a.sdiv_ov(b, overflowOrDiv0);
420         }
421         if (!aGtZero && !bGtZero) {
422           // Both negative, return -a / -b.
423           APInt posA = zero.ssub_ov(a, overflowOrDiv0);
424           APInt posB = zero.ssub_ov(b, overflowOrDiv0);
425           return posA.sdiv_ov(posB, overflowOrDiv0);
426         }
427         if (!aGtZero && bGtZero) {
428           // A is negative, b is positive, return - ceil(-a, b).
429           APInt posA = zero.ssub_ov(a, overflowOrDiv0);
430           APInt ceil = signedCeilNonnegInputs(posA, b, overflowOrDiv0);
431           return zero.ssub_ov(ceil, overflowOrDiv0);
432         }
433         // A is positive, b is negative, return - ceil(a, -b).
434         APInt posB = zero.ssub_ov(b, overflowOrDiv0);
435         APInt ceil = signedCeilNonnegInputs(a, posB, overflowOrDiv0);
436         return zero.ssub_ov(ceil, overflowOrDiv0);
437       });
438 
439   return overflowOrDiv0 ? Attribute() : result;
440 }
441 
442 //===----------------------------------------------------------------------===//
443 // RemUIOp
444 //===----------------------------------------------------------------------===//
445 
446 OpFoldResult arith::RemUIOp::fold(ArrayRef<Attribute> operands) {
447   // remui (x, 1) -> 0.
448   if (matchPattern(getRhs(), m_One()))
449     return Builder(getContext()).getZeroAttr(getType());
450 
451   // Don't fold if it would require a division by zero.
452   bool div0 = false;
453   auto result =
454       constFoldBinaryOp<IntegerAttr>(operands, [&](APInt a, const APInt &b) {
455         if (div0 || b.isNullValue()) {
456           div0 = true;
457           return a;
458         }
459         return a.urem(b);
460       });
461 
462   return div0 ? Attribute() : result;
463 }
464 
465 //===----------------------------------------------------------------------===//
466 // RemSIOp
467 //===----------------------------------------------------------------------===//
468 
469 OpFoldResult arith::RemSIOp::fold(ArrayRef<Attribute> operands) {
470   // remsi (x, 1) -> 0.
471   if (matchPattern(getRhs(), m_One()))
472     return Builder(getContext()).getZeroAttr(getType());
473 
474   // Don't fold if it would require a division by zero.
475   bool div0 = false;
476   auto result =
477       constFoldBinaryOp<IntegerAttr>(operands, [&](APInt a, const APInt &b) {
478         if (div0 || b.isNullValue()) {
479           div0 = true;
480           return a;
481         }
482         return a.srem(b);
483       });
484 
485   return div0 ? Attribute() : result;
486 }
487 
488 //===----------------------------------------------------------------------===//
489 // AndIOp
490 //===----------------------------------------------------------------------===//
491 
492 OpFoldResult arith::AndIOp::fold(ArrayRef<Attribute> operands) {
493   /// and(x, 0) -> 0
494   if (matchPattern(getRhs(), m_Zero()))
495     return getRhs();
496   /// and(x, allOnes) -> x
497   APInt intValue;
498   if (matchPattern(getRhs(), m_ConstantInt(&intValue)) && intValue.isAllOnes())
499     return getLhs();
500 
501   return constFoldBinaryOp<IntegerAttr>(
502       operands, [](APInt a, const APInt &b) { return std::move(a) & b; });
503 }
504 
505 //===----------------------------------------------------------------------===//
506 // OrIOp
507 //===----------------------------------------------------------------------===//
508 
509 OpFoldResult arith::OrIOp::fold(ArrayRef<Attribute> operands) {
510   /// or(x, 0) -> x
511   if (matchPattern(getRhs(), m_Zero()))
512     return getLhs();
513   /// or(x, <all ones>) -> <all ones>
514   if (auto rhsAttr = operands[1].dyn_cast_or_null<IntegerAttr>())
515     if (rhsAttr.getValue().isAllOnes())
516       return rhsAttr;
517 
518   return constFoldBinaryOp<IntegerAttr>(
519       operands, [](APInt a, const APInt &b) { return std::move(a) | b; });
520 }
521 
522 //===----------------------------------------------------------------------===//
523 // XOrIOp
524 //===----------------------------------------------------------------------===//
525 
526 OpFoldResult arith::XOrIOp::fold(ArrayRef<Attribute> operands) {
527   /// xor(x, 0) -> x
528   if (matchPattern(getRhs(), m_Zero()))
529     return getLhs();
530   /// xor(x, x) -> 0
531   if (getLhs() == getRhs())
532     return Builder(getContext()).getZeroAttr(getType());
533   /// xor(xor(x, a), a) -> x
534   if (arith::XOrIOp prev = getLhs().getDefiningOp<arith::XOrIOp>())
535     if (prev.getRhs() == getRhs())
536       return prev.getLhs();
537 
538   return constFoldBinaryOp<IntegerAttr>(
539       operands, [](APInt a, const APInt &b) { return std::move(a) ^ b; });
540 }
541 
542 void arith::XOrIOp::getCanonicalizationPatterns(
543     RewritePatternSet &patterns, MLIRContext *context) {
544   patterns.add<XOrINotCmpI>(context);
545 }
546 
547 //===----------------------------------------------------------------------===//
548 // NegFOp
549 //===----------------------------------------------------------------------===//
550 
551 OpFoldResult arith::NegFOp::fold(ArrayRef<Attribute> operands) {
552   /// negf(negf(x)) -> x
553   if (auto op = this->getOperand().getDefiningOp<arith::NegFOp>())
554     return op.getOperand();
555   return constFoldUnaryOp<FloatAttr>(operands,
556                                      [](const APFloat &a) { return -a; });
557 }
558 
559 //===----------------------------------------------------------------------===//
560 // AddFOp
561 //===----------------------------------------------------------------------===//
562 
563 OpFoldResult arith::AddFOp::fold(ArrayRef<Attribute> operands) {
564   // addf(x, -0) -> x
565   if (matchPattern(getRhs(), m_NegZeroFloat()))
566     return getLhs();
567 
568   return constFoldBinaryOp<FloatAttr>(
569       operands, [](const APFloat &a, const APFloat &b) { return a + b; });
570 }
571 
572 //===----------------------------------------------------------------------===//
573 // SubFOp
574 //===----------------------------------------------------------------------===//
575 
576 OpFoldResult arith::SubFOp::fold(ArrayRef<Attribute> operands) {
577   // subf(x, +0) -> x
578   if (matchPattern(getRhs(), m_PosZeroFloat()))
579     return getLhs();
580 
581   return constFoldBinaryOp<FloatAttr>(
582       operands, [](const APFloat &a, const APFloat &b) { return a - b; });
583 }
584 
585 //===----------------------------------------------------------------------===//
586 // MaxFOp
587 //===----------------------------------------------------------------------===//
588 
589 OpFoldResult arith::MaxFOp::fold(ArrayRef<Attribute> operands) {
590   assert(operands.size() == 2 && "maxf takes two operands");
591 
592   // maxf(x,x) -> x
593   if (getLhs() == getRhs())
594     return getRhs();
595 
596   // maxf(x, -inf) -> x
597   if (matchPattern(getRhs(), m_NegInfFloat()))
598     return getLhs();
599 
600   return constFoldBinaryOp<FloatAttr>(
601       operands,
602       [](const APFloat &a, const APFloat &b) { return llvm::maximum(a, b); });
603 }
604 
605 //===----------------------------------------------------------------------===//
606 // MaxSIOp
607 //===----------------------------------------------------------------------===//
608 
609 OpFoldResult MaxSIOp::fold(ArrayRef<Attribute> operands) {
610   assert(operands.size() == 2 && "binary operation takes two operands");
611 
612   // maxsi(x,x) -> x
613   if (getLhs() == getRhs())
614     return getRhs();
615 
616   APInt intValue;
617   // maxsi(x,MAX_INT) -> MAX_INT
618   if (matchPattern(getRhs(), m_ConstantInt(&intValue)) &&
619       intValue.isMaxSignedValue())
620     return getRhs();
621 
622   // maxsi(x, MIN_INT) -> x
623   if (matchPattern(getRhs(), m_ConstantInt(&intValue)) &&
624       intValue.isMinSignedValue())
625     return getLhs();
626 
627   return constFoldBinaryOp<IntegerAttr>(operands,
628                                         [](const APInt &a, const APInt &b) {
629                                           return llvm::APIntOps::smax(a, b);
630                                         });
631 }
632 
633 //===----------------------------------------------------------------------===//
634 // MaxUIOp
635 //===----------------------------------------------------------------------===//
636 
637 OpFoldResult MaxUIOp::fold(ArrayRef<Attribute> operands) {
638   assert(operands.size() == 2 && "binary operation takes two operands");
639 
640   // maxui(x,x) -> x
641   if (getLhs() == getRhs())
642     return getRhs();
643 
644   APInt intValue;
645   // maxui(x,MAX_INT) -> MAX_INT
646   if (matchPattern(getRhs(), m_ConstantInt(&intValue)) && intValue.isMaxValue())
647     return getRhs();
648 
649   // maxui(x, MIN_INT) -> x
650   if (matchPattern(getRhs(), m_ConstantInt(&intValue)) && intValue.isMinValue())
651     return getLhs();
652 
653   return constFoldBinaryOp<IntegerAttr>(operands,
654                                         [](const APInt &a, const APInt &b) {
655                                           return llvm::APIntOps::umax(a, b);
656                                         });
657 }
658 
659 //===----------------------------------------------------------------------===//
660 // MinFOp
661 //===----------------------------------------------------------------------===//
662 
663 OpFoldResult arith::MinFOp::fold(ArrayRef<Attribute> operands) {
664   assert(operands.size() == 2 && "minf takes two operands");
665 
666   // minf(x,x) -> x
667   if (getLhs() == getRhs())
668     return getRhs();
669 
670   // minf(x, +inf) -> x
671   if (matchPattern(getRhs(), m_PosInfFloat()))
672     return getLhs();
673 
674   return constFoldBinaryOp<FloatAttr>(
675       operands,
676       [](const APFloat &a, const APFloat &b) { return llvm::minimum(a, b); });
677 }
678 
679 //===----------------------------------------------------------------------===//
680 // MinSIOp
681 //===----------------------------------------------------------------------===//
682 
683 OpFoldResult MinSIOp::fold(ArrayRef<Attribute> operands) {
684   assert(operands.size() == 2 && "binary operation takes two operands");
685 
686   // minsi(x,x) -> x
687   if (getLhs() == getRhs())
688     return getRhs();
689 
690   APInt intValue;
691   // minsi(x,MIN_INT) -> MIN_INT
692   if (matchPattern(getRhs(), m_ConstantInt(&intValue)) &&
693       intValue.isMinSignedValue())
694     return getRhs();
695 
696   // minsi(x, MAX_INT) -> x
697   if (matchPattern(getRhs(), m_ConstantInt(&intValue)) &&
698       intValue.isMaxSignedValue())
699     return getLhs();
700 
701   return constFoldBinaryOp<IntegerAttr>(operands,
702                                         [](const APInt &a, const APInt &b) {
703                                           return llvm::APIntOps::smin(a, b);
704                                         });
705 }
706 
707 //===----------------------------------------------------------------------===//
708 // MinUIOp
709 //===----------------------------------------------------------------------===//
710 
711 OpFoldResult MinUIOp::fold(ArrayRef<Attribute> operands) {
712   assert(operands.size() == 2 && "binary operation takes two operands");
713 
714   // minui(x,x) -> x
715   if (getLhs() == getRhs())
716     return getRhs();
717 
718   APInt intValue;
719   // minui(x,MIN_INT) -> MIN_INT
720   if (matchPattern(getRhs(), m_ConstantInt(&intValue)) && intValue.isMinValue())
721     return getRhs();
722 
723   // minui(x, MAX_INT) -> x
724   if (matchPattern(getRhs(), m_ConstantInt(&intValue)) && intValue.isMaxValue())
725     return getLhs();
726 
727   return constFoldBinaryOp<IntegerAttr>(operands,
728                                         [](const APInt &a, const APInt &b) {
729                                           return llvm::APIntOps::umin(a, b);
730                                         });
731 }
732 
733 //===----------------------------------------------------------------------===//
734 // MulFOp
735 //===----------------------------------------------------------------------===//
736 
737 OpFoldResult arith::MulFOp::fold(ArrayRef<Attribute> operands) {
738   // mulf(x, 1) -> x
739   if (matchPattern(getRhs(), m_OneFloat()))
740     return getLhs();
741 
742   return constFoldBinaryOp<FloatAttr>(
743       operands, [](const APFloat &a, const APFloat &b) { return a * b; });
744 }
745 
746 void arith::MulFOp::getCanonicalizationPatterns(RewritePatternSet &patterns,
747                                                 MLIRContext *context) {
748   patterns.add<MulFOfNegF>(context);
749 }
750 
751 //===----------------------------------------------------------------------===//
752 // DivFOp
753 //===----------------------------------------------------------------------===//
754 
755 OpFoldResult arith::DivFOp::fold(ArrayRef<Attribute> operands) {
756   // divf(x, 1) -> x
757   if (matchPattern(getRhs(), m_OneFloat()))
758     return getLhs();
759 
760   return constFoldBinaryOp<FloatAttr>(
761       operands, [](const APFloat &a, const APFloat &b) { return a / b; });
762 }
763 
764 void arith::DivFOp::getCanonicalizationPatterns(RewritePatternSet &patterns,
765                                                 MLIRContext *context) {
766   patterns.add<DivFOfNegF>(context);
767 }
768 
769 //===----------------------------------------------------------------------===//
770 // RemFOp
771 //===----------------------------------------------------------------------===//
772 
773 OpFoldResult arith::RemFOp::fold(ArrayRef<Attribute> operands) {
774   return constFoldBinaryOp<FloatAttr>(operands,
775                                       [](const APFloat &a, const APFloat &b) {
776                                         APFloat Result(a);
777                                         (void)Result.remainder(b);
778                                         return Result;
779                                       });
780 }
781 
782 //===----------------------------------------------------------------------===//
783 // Utility functions for verifying cast ops
784 //===----------------------------------------------------------------------===//
785 
/// Tag type that passes a pack of types as an ordinary function argument,
/// e.g. `type_list<VectorType, TensorType>()`. It is a pointer type, so a
/// value-initialized instance is just a null pointer. The helpers below use
/// the argument only to deduce the pack and never read its value.
template <typename... Types>
using type_list = std::tuple<Types...> *;
788 
789 /// Returns a non-null type only if the provided type is one of the allowed
790 /// types or one of the allowed shaped types of the allowed types. Returns the
791 /// element type if a valid shaped type is provided.
792 template <typename... ShapedTypes, typename... ElementTypes>
793 static Type getUnderlyingType(Type type, type_list<ShapedTypes...>,
794                               type_list<ElementTypes...>) {
795   if (type.isa<ShapedType>() && !type.isa<ShapedTypes...>())
796     return {};
797 
798   auto underlyingType = getElementTypeOrSelf(type);
799   if (!underlyingType.isa<ElementTypes...>())
800     return {};
801 
802   return underlyingType;
803 }
804 
805 /// Get allowed underlying types for vectors and tensors.
806 template <typename... ElementTypes>
807 static Type getTypeIfLike(Type type) {
808   return getUnderlyingType(type, type_list<VectorType, TensorType>(),
809                            type_list<ElementTypes...>());
810 }
811 
812 /// Get allowed underlying types for vectors, tensors, and memrefs.
813 template <typename... ElementTypes>
814 static Type getTypeIfLikeOrMemRef(Type type) {
815   return getUnderlyingType(type,
816                            type_list<VectorType, TensorType, MemRefType>(),
817                            type_list<ElementTypes...>());
818 }
819 
820 static bool areValidCastInputsAndOutputs(TypeRange inputs, TypeRange outputs) {
821   return inputs.size() == 1 && outputs.size() == 1 &&
822          succeeded(verifyCompatibleShapes(inputs.front(), outputs.front()));
823 }
824 
825 //===----------------------------------------------------------------------===//
826 // Verifiers for integer and floating point extension/truncation ops
827 //===----------------------------------------------------------------------===//
828 
829 // Extend ops can only extend to a wider type.
830 template <typename ValType, typename Op>
831 static LogicalResult verifyExtOp(Op op) {
832   Type srcType = getElementTypeOrSelf(op.getIn().getType());
833   Type dstType = getElementTypeOrSelf(op.getType());
834 
835   if (srcType.cast<ValType>().getWidth() >= dstType.cast<ValType>().getWidth())
836     return op.emitError("result type ")
837            << dstType << " must be wider than operand type " << srcType;
838 
839   return success();
840 }
841 
842 // Truncate ops can only truncate to a shorter type.
843 template <typename ValType, typename Op>
844 static LogicalResult verifyTruncateOp(Op op) {
845   Type srcType = getElementTypeOrSelf(op.getIn().getType());
846   Type dstType = getElementTypeOrSelf(op.getType());
847 
848   if (srcType.cast<ValType>().getWidth() <= dstType.cast<ValType>().getWidth())
849     return op.emitError("result type ")
850            << dstType << " must be shorter than operand type " << srcType;
851 
852   return success();
853 }
854 
855 /// Validate a cast that changes the width of a type.
856 template <template <typename> class WidthComparator, typename... ElementTypes>
857 static bool checkWidthChangeCast(TypeRange inputs, TypeRange outputs) {
858   if (!areValidCastInputsAndOutputs(inputs, outputs))
859     return false;
860 
861   auto srcType = getTypeIfLike<ElementTypes...>(inputs.front());
862   auto dstType = getTypeIfLike<ElementTypes...>(outputs.front());
863   if (!srcType || !dstType)
864     return false;
865 
866   return WidthComparator<unsigned>()(dstType.getIntOrFloatBitWidth(),
867                                      srcType.getIntOrFloatBitWidth());
868 }
869 
870 //===----------------------------------------------------------------------===//
871 // ExtUIOp
872 //===----------------------------------------------------------------------===//
873 
874 OpFoldResult arith::ExtUIOp::fold(ArrayRef<Attribute> operands) {
875   if (auto lhs = getIn().getDefiningOp<ExtUIOp>()) {
876     getInMutable().assign(lhs.getIn());
877     return getResult();
878   }
879   Type resType = getType();
880   unsigned bitWidth;
881   if (auto shapedType = resType.dyn_cast<ShapedType>())
882     bitWidth = shapedType.getElementTypeBitWidth();
883   else
884     bitWidth = resType.getIntOrFloatBitWidth();
885   return constFoldCastOp<IntegerAttr, IntegerAttr>(
886       operands, getType(), [bitWidth](const APInt &a, bool &castStatus) {
887         return a.zext(bitWidth);
888       });
889 }
890 
891 bool arith::ExtUIOp::areCastCompatible(TypeRange inputs, TypeRange outputs) {
892   return checkWidthChangeCast<std::greater, IntegerType>(inputs, outputs);
893 }
894 
895 LogicalResult arith::ExtUIOp::verify() {
896   return verifyExtOp<IntegerType>(*this);
897 }
898 
899 //===----------------------------------------------------------------------===//
900 // ExtSIOp
901 //===----------------------------------------------------------------------===//
902 
903 OpFoldResult arith::ExtSIOp::fold(ArrayRef<Attribute> operands) {
904   if (auto lhs = getIn().getDefiningOp<ExtSIOp>()) {
905     getInMutable().assign(lhs.getIn());
906     return getResult();
907   }
908   Type resType = getType();
909   unsigned bitWidth;
910   if (auto shapedType = resType.dyn_cast<ShapedType>())
911     bitWidth = shapedType.getElementTypeBitWidth();
912   else
913     bitWidth = resType.getIntOrFloatBitWidth();
914   return constFoldCastOp<IntegerAttr, IntegerAttr>(
915       operands, getType(), [bitWidth](const APInt &a, bool &castStatus) {
916         return a.sext(bitWidth);
917       });
918 }
919 
920 bool arith::ExtSIOp::areCastCompatible(TypeRange inputs, TypeRange outputs) {
921   return checkWidthChangeCast<std::greater, IntegerType>(inputs, outputs);
922 }
923 
924 void arith::ExtSIOp::getCanonicalizationPatterns(
925     RewritePatternSet &patterns, MLIRContext *context) {
926   patterns.add<ExtSIOfExtUI>(context);
927 }
928 
929 LogicalResult arith::ExtSIOp::verify() {
930   return verifyExtOp<IntegerType>(*this);
931 }
932 
933 //===----------------------------------------------------------------------===//
934 // ExtFOp
935 //===----------------------------------------------------------------------===//
936 
937 bool arith::ExtFOp::areCastCompatible(TypeRange inputs, TypeRange outputs) {
938   return checkWidthChangeCast<std::greater, FloatType>(inputs, outputs);
939 }
940 
941 LogicalResult arith::ExtFOp::verify() { return verifyExtOp<FloatType>(*this); }
942 
943 //===----------------------------------------------------------------------===//
944 // TruncIOp
945 //===----------------------------------------------------------------------===//
946 
947 OpFoldResult arith::TruncIOp::fold(ArrayRef<Attribute> operands) {
948   assert(operands.size() == 1 && "unary operation takes one operand");
949 
950   // trunci(zexti(a)) -> a
951   // trunci(sexti(a)) -> a
952   if (matchPattern(getOperand(), m_Op<arith::ExtUIOp>()) ||
953       matchPattern(getOperand(), m_Op<arith::ExtSIOp>()))
954     return getOperand().getDefiningOp()->getOperand(0);
955 
956   // trunci(trunci(a)) -> trunci(a))
957   if (matchPattern(getOperand(), m_Op<arith::TruncIOp>())) {
958     setOperand(getOperand().getDefiningOp()->getOperand(0));
959     return getResult();
960   }
961 
962   Type resType = getType();
963   unsigned bitWidth;
964   if (auto shapedType = resType.dyn_cast<ShapedType>())
965     bitWidth = shapedType.getElementTypeBitWidth();
966   else
967     bitWidth = resType.getIntOrFloatBitWidth();
968 
969   return constFoldCastOp<IntegerAttr, IntegerAttr>(
970       operands, getType(), [bitWidth](const APInt &a, bool &castStatus) {
971         return a.trunc(bitWidth);
972       });
973 }
974 
975 bool arith::TruncIOp::areCastCompatible(TypeRange inputs, TypeRange outputs) {
976   return checkWidthChangeCast<std::less, IntegerType>(inputs, outputs);
977 }
978 
979 LogicalResult arith::TruncIOp::verify() {
980   return verifyTruncateOp<IntegerType>(*this);
981 }
982 
983 //===----------------------------------------------------------------------===//
984 // TruncFOp
985 //===----------------------------------------------------------------------===//
986 
987 /// Perform safe const propagation for truncf, i.e. only propagate if FP value
988 /// can be represented without precision loss or rounding.
989 OpFoldResult arith::TruncFOp::fold(ArrayRef<Attribute> operands) {
990   assert(operands.size() == 1 && "unary operation takes one operand");
991 
992   auto constOperand = operands.front();
993   if (!constOperand || !constOperand.isa<FloatAttr>())
994     return {};
995 
996   // Convert to target type via 'double'.
997   double sourceValue =
998       constOperand.dyn_cast<FloatAttr>().getValue().convertToDouble();
999   auto targetAttr = FloatAttr::get(getType(), sourceValue);
1000 
1001   // Propagate if constant's value does not change after truncation.
1002   if (sourceValue == targetAttr.getValue().convertToDouble())
1003     return targetAttr;
1004 
1005   return {};
1006 }
1007 
1008 bool arith::TruncFOp::areCastCompatible(TypeRange inputs, TypeRange outputs) {
1009   return checkWidthChangeCast<std::less, FloatType>(inputs, outputs);
1010 }
1011 
1012 LogicalResult arith::TruncFOp::verify() {
1013   return verifyTruncateOp<FloatType>(*this);
1014 }
1015 
1016 //===----------------------------------------------------------------------===//
1017 // AndIOp
1018 //===----------------------------------------------------------------------===//
1019 
1020 void arith::AndIOp::getCanonicalizationPatterns(
1021     RewritePatternSet &patterns, MLIRContext *context) {
1022   patterns.add<AndOfExtUI, AndOfExtSI>(context);
1023 }
1024 
1025 //===----------------------------------------------------------------------===//
1026 // OrIOp
1027 //===----------------------------------------------------------------------===//
1028 
1029 void arith::OrIOp::getCanonicalizationPatterns(
1030     RewritePatternSet &patterns, MLIRContext *context) {
1031   patterns.add<OrOfExtUI, OrOfExtSI>(context);
1032 }
1033 
1034 //===----------------------------------------------------------------------===//
1035 // Verifiers for casts between integers and floats.
1036 //===----------------------------------------------------------------------===//
1037 
1038 template <typename From, typename To>
1039 static bool checkIntFloatCast(TypeRange inputs, TypeRange outputs) {
1040   if (!areValidCastInputsAndOutputs(inputs, outputs))
1041     return false;
1042 
1043   auto srcType = getTypeIfLike<From>(inputs.front());
1044   auto dstType = getTypeIfLike<To>(outputs.back());
1045 
1046   return srcType && dstType;
1047 }
1048 
1049 //===----------------------------------------------------------------------===//
1050 // UIToFPOp
1051 //===----------------------------------------------------------------------===//
1052 
1053 bool arith::UIToFPOp::areCastCompatible(TypeRange inputs, TypeRange outputs) {
1054   return checkIntFloatCast<IntegerType, FloatType>(inputs, outputs);
1055 }
1056 
1057 OpFoldResult arith::UIToFPOp::fold(ArrayRef<Attribute> operands) {
1058   Type resType = getType();
1059   Type resEleType;
1060   if (auto shapedType = resType.dyn_cast<ShapedType>())
1061     resEleType = shapedType.getElementType();
1062   else
1063     resEleType = resType;
1064   return constFoldCastOp<IntegerAttr, FloatAttr>(
1065       operands, getType(), [&resEleType](const APInt &a, bool &castStatus) {
1066         FloatType floatTy = resEleType.cast<FloatType>();
1067         APFloat apf(floatTy.getFloatSemantics(),
1068                     APInt::getZero(floatTy.getWidth()));
1069         apf.convertFromAPInt(a, /*IsSigned=*/false,
1070                              APFloat::rmNearestTiesToEven);
1071         return apf;
1072       });
1073 }
1074 
1075 //===----------------------------------------------------------------------===//
1076 // SIToFPOp
1077 //===----------------------------------------------------------------------===//
1078 
1079 bool arith::SIToFPOp::areCastCompatible(TypeRange inputs, TypeRange outputs) {
1080   return checkIntFloatCast<IntegerType, FloatType>(inputs, outputs);
1081 }
1082 
1083 OpFoldResult arith::SIToFPOp::fold(ArrayRef<Attribute> operands) {
1084   Type resType = getType();
1085   Type resEleType;
1086   if (auto shapedType = resType.dyn_cast<ShapedType>())
1087     resEleType = shapedType.getElementType();
1088   else
1089     resEleType = resType;
1090   return constFoldCastOp<IntegerAttr, FloatAttr>(
1091       operands, getType(), [&resEleType](const APInt &a, bool &castStatus) {
1092         FloatType floatTy = resEleType.cast<FloatType>();
1093         APFloat apf(floatTy.getFloatSemantics(),
1094                     APInt::getZero(floatTy.getWidth()));
1095         apf.convertFromAPInt(a, /*IsSigned=*/true,
1096                              APFloat::rmNearestTiesToEven);
1097         return apf;
1098       });
1099 }
1100 //===----------------------------------------------------------------------===//
1101 // FPToUIOp
1102 //===----------------------------------------------------------------------===//
1103 
1104 bool arith::FPToUIOp::areCastCompatible(TypeRange inputs, TypeRange outputs) {
1105   return checkIntFloatCast<FloatType, IntegerType>(inputs, outputs);
1106 }
1107 
1108 OpFoldResult arith::FPToUIOp::fold(ArrayRef<Attribute> operands) {
1109   Type resType = getType();
1110   Type resEleType;
1111   if (auto shapedType = resType.dyn_cast<ShapedType>())
1112     resEleType = shapedType.getElementType();
1113   else
1114     resEleType = resType;
1115   return constFoldCastOp<FloatAttr, IntegerAttr>(
1116       operands, getType(), [&resEleType](const APFloat &a, bool &castStatus) {
1117         IntegerType intTy = resEleType.cast<IntegerType>();
1118         bool ignored;
1119         APSInt api(intTy.getWidth(), /*isUnsigned=*/true);
1120         castStatus = APFloat::opInvalidOp !=
1121                      a.convertToInteger(api, APFloat::rmTowardZero, &ignored);
1122         return api;
1123       });
1124 }
1125 
1126 //===----------------------------------------------------------------------===//
1127 // FPToSIOp
1128 //===----------------------------------------------------------------------===//
1129 
1130 bool arith::FPToSIOp::areCastCompatible(TypeRange inputs, TypeRange outputs) {
1131   return checkIntFloatCast<FloatType, IntegerType>(inputs, outputs);
1132 }
1133 
1134 OpFoldResult arith::FPToSIOp::fold(ArrayRef<Attribute> operands) {
1135   Type resType = getType();
1136   Type resEleType;
1137   if (auto shapedType = resType.dyn_cast<ShapedType>())
1138     resEleType = shapedType.getElementType();
1139   else
1140     resEleType = resType;
1141   return constFoldCastOp<FloatAttr, IntegerAttr>(
1142       operands, getType(), [&resEleType](const APFloat &a, bool &castStatus) {
1143         IntegerType intTy = resEleType.cast<IntegerType>();
1144         bool ignored;
1145         APSInt api(intTy.getWidth(), /*isUnsigned=*/false);
1146         castStatus = APFloat::opInvalidOp !=
1147                      a.convertToInteger(api, APFloat::rmTowardZero, &ignored);
1148         return api;
1149       });
1150 }
1151 
1152 //===----------------------------------------------------------------------===//
1153 // IndexCastOp
1154 //===----------------------------------------------------------------------===//
1155 
1156 bool arith::IndexCastOp::areCastCompatible(TypeRange inputs,
1157                                            TypeRange outputs) {
1158   if (!areValidCastInputsAndOutputs(inputs, outputs))
1159     return false;
1160 
1161   auto srcType = getTypeIfLikeOrMemRef<IntegerType, IndexType>(inputs.front());
1162   auto dstType = getTypeIfLikeOrMemRef<IntegerType, IndexType>(outputs.front());
1163   if (!srcType || !dstType)
1164     return false;
1165 
1166   return (srcType.isIndex() && dstType.isSignlessInteger()) ||
1167          (srcType.isSignlessInteger() && dstType.isIndex());
1168 }
1169 
1170 OpFoldResult arith::IndexCastOp::fold(ArrayRef<Attribute> operands) {
1171   // index_cast(constant) -> constant
1172   // A little hack because we go through int. Otherwise, the size of the
1173   // constant might need to change.
1174   if (auto value = operands[0].dyn_cast_or_null<IntegerAttr>())
1175     return IntegerAttr::get(getType(), value.getInt());
1176 
1177   return {};
1178 }
1179 
1180 void arith::IndexCastOp::getCanonicalizationPatterns(
1181     RewritePatternSet &patterns, MLIRContext *context) {
1182   patterns.add<IndexCastOfIndexCast, IndexCastOfExtSI>(context);
1183 }
1184 
1185 //===----------------------------------------------------------------------===//
1186 // BitcastOp
1187 //===----------------------------------------------------------------------===//
1188 
1189 bool arith::BitcastOp::areCastCompatible(TypeRange inputs, TypeRange outputs) {
1190   if (!areValidCastInputsAndOutputs(inputs, outputs))
1191     return false;
1192 
1193   auto srcType =
1194       getTypeIfLikeOrMemRef<IntegerType, IndexType, FloatType>(inputs.front());
1195   auto dstType =
1196       getTypeIfLikeOrMemRef<IntegerType, IndexType, FloatType>(outputs.front());
1197   if (!srcType || !dstType)
1198     return false;
1199 
1200   return srcType.getIntOrFloatBitWidth() == dstType.getIntOrFloatBitWidth();
1201 }
1202 
1203 OpFoldResult arith::BitcastOp::fold(ArrayRef<Attribute> operands) {
1204   assert(operands.size() == 1 && "bitcast op expects 1 operand");
1205 
1206   auto resType = getType();
1207   auto operand = operands[0];
1208   if (!operand)
1209     return {};
1210 
1211   /// Bitcast dense elements.
1212   if (auto denseAttr = operand.dyn_cast_or_null<DenseElementsAttr>())
1213     return denseAttr.bitcast(resType.cast<ShapedType>().getElementType());
1214   /// Other shaped types unhandled.
1215   if (resType.isa<ShapedType>())
1216     return {};
1217 
1218   /// Bitcast integer or float to integer or float.
1219   APInt bits = operand.isa<FloatAttr>()
1220                    ? operand.cast<FloatAttr>().getValue().bitcastToAPInt()
1221                    : operand.cast<IntegerAttr>().getValue();
1222 
1223   if (auto resFloatType = resType.dyn_cast<FloatType>())
1224     return FloatAttr::get(resType,
1225                           APFloat(resFloatType.getFloatSemantics(), bits));
1226   return IntegerAttr::get(resType, bits);
1227 }
1228 
1229 void arith::BitcastOp::getCanonicalizationPatterns(
1230     RewritePatternSet &patterns, MLIRContext *context) {
1231   patterns.add<BitcastOfBitcast>(context);
1232 }
1233 
1234 //===----------------------------------------------------------------------===//
1235 // Helpers for compare ops
1236 //===----------------------------------------------------------------------===//
1237 
1238 /// Return the type of the same shape (scalar, vector or tensor) containing i1.
1239 static Type getI1SameShape(Type type) {
1240   auto i1Type = IntegerType::get(type.getContext(), 1);
1241   if (auto tensorType = type.dyn_cast<RankedTensorType>())
1242     return RankedTensorType::get(tensorType.getShape(), i1Type);
1243   if (type.isa<UnrankedTensorType>())
1244     return UnrankedTensorType::get(i1Type);
1245   if (auto vectorType = type.dyn_cast<VectorType>())
1246     return VectorType::get(vectorType.getShape(), i1Type,
1247                            vectorType.getNumScalableDims());
1248   return i1Type;
1249 }
1250 
1251 //===----------------------------------------------------------------------===//
1252 // CmpIOp
1253 //===----------------------------------------------------------------------===//
1254 
1255 /// Compute `lhs` `pred` `rhs`, where `pred` is one of the known integer
1256 /// comparison predicates.
1257 bool mlir::arith::applyCmpPredicate(arith::CmpIPredicate predicate,
1258                                     const APInt &lhs, const APInt &rhs) {
1259   switch (predicate) {
1260   case arith::CmpIPredicate::eq:
1261     return lhs.eq(rhs);
1262   case arith::CmpIPredicate::ne:
1263     return lhs.ne(rhs);
1264   case arith::CmpIPredicate::slt:
1265     return lhs.slt(rhs);
1266   case arith::CmpIPredicate::sle:
1267     return lhs.sle(rhs);
1268   case arith::CmpIPredicate::sgt:
1269     return lhs.sgt(rhs);
1270   case arith::CmpIPredicate::sge:
1271     return lhs.sge(rhs);
1272   case arith::CmpIPredicate::ult:
1273     return lhs.ult(rhs);
1274   case arith::CmpIPredicate::ule:
1275     return lhs.ule(rhs);
1276   case arith::CmpIPredicate::ugt:
1277     return lhs.ugt(rhs);
1278   case arith::CmpIPredicate::uge:
1279     return lhs.uge(rhs);
1280   }
1281   llvm_unreachable("unknown cmpi predicate kind");
1282 }
1283 
1284 /// Returns true if the predicate is true for two equal operands.
1285 static bool applyCmpPredicateToEqualOperands(arith::CmpIPredicate predicate) {
1286   switch (predicate) {
1287   case arith::CmpIPredicate::eq:
1288   case arith::CmpIPredicate::sle:
1289   case arith::CmpIPredicate::sge:
1290   case arith::CmpIPredicate::ule:
1291   case arith::CmpIPredicate::uge:
1292     return true;
1293   case arith::CmpIPredicate::ne:
1294   case arith::CmpIPredicate::slt:
1295   case arith::CmpIPredicate::sgt:
1296   case arith::CmpIPredicate::ult:
1297   case arith::CmpIPredicate::ugt:
1298     return false;
1299   }
1300   llvm_unreachable("unknown cmpi predicate kind");
1301 }
1302 
1303 static Attribute getBoolAttribute(Type type, MLIRContext *ctx, bool value) {
1304   auto boolAttr = BoolAttr::get(ctx, value);
1305   ShapedType shapedType = type.dyn_cast_or_null<ShapedType>();
1306   if (!shapedType)
1307     return boolAttr;
1308   return DenseElementsAttr::get(shapedType, boolAttr);
1309 }
1310 
1311 OpFoldResult arith::CmpIOp::fold(ArrayRef<Attribute> operands) {
1312   assert(operands.size() == 2 && "cmpi takes two operands");
1313 
1314   // cmpi(pred, x, x)
1315   if (getLhs() == getRhs()) {
1316     auto val = applyCmpPredicateToEqualOperands(getPredicate());
1317     return getBoolAttribute(getType(), getContext(), val);
1318   }
1319 
1320   if (matchPattern(getRhs(), m_Zero())) {
1321     if (auto extOp = getLhs().getDefiningOp<ExtSIOp>()) {
1322       // extsi(%x : i1 -> iN) != 0  ->  %x
1323       if (extOp.getOperand().getType().cast<IntegerType>().getWidth() == 1 &&
1324           getPredicate() == arith::CmpIPredicate::ne)
1325         return extOp.getOperand();
1326     }
1327     if (auto extOp = getLhs().getDefiningOp<ExtUIOp>()) {
1328       // extui(%x : i1 -> iN) != 0  ->  %x
1329       if (extOp.getOperand().getType().cast<IntegerType>().getWidth() == 1 &&
1330           getPredicate() == arith::CmpIPredicate::ne)
1331         return extOp.getOperand();
1332     }
1333   }
1334 
1335   auto lhs = operands.front().dyn_cast_or_null<IntegerAttr>();
1336   auto rhs = operands.back().dyn_cast_or_null<IntegerAttr>();
1337   if (!lhs || !rhs)
1338     return {};
1339 
1340   auto val = applyCmpPredicate(getPredicate(), lhs.getValue(), rhs.getValue());
1341   return BoolAttr::get(getContext(), val);
1342 }
1343 
1344 void arith::CmpIOp::getCanonicalizationPatterns(RewritePatternSet &patterns,
1345                                                 MLIRContext *context) {
1346   patterns.insert<CmpIExtSI, CmpIExtUI>(context);
1347 }
1348 
1349 //===----------------------------------------------------------------------===//
1350 // CmpFOp
1351 //===----------------------------------------------------------------------===//
1352 
1353 /// Compute `lhs` `pred` `rhs`, where `pred` is one of the known floating point
1354 /// comparison predicates.
1355 bool mlir::arith::applyCmpPredicate(arith::CmpFPredicate predicate,
1356                                     const APFloat &lhs, const APFloat &rhs) {
1357   auto cmpResult = lhs.compare(rhs);
1358   switch (predicate) {
1359   case arith::CmpFPredicate::AlwaysFalse:
1360     return false;
1361   case arith::CmpFPredicate::OEQ:
1362     return cmpResult == APFloat::cmpEqual;
1363   case arith::CmpFPredicate::OGT:
1364     return cmpResult == APFloat::cmpGreaterThan;
1365   case arith::CmpFPredicate::OGE:
1366     return cmpResult == APFloat::cmpGreaterThan ||
1367            cmpResult == APFloat::cmpEqual;
1368   case arith::CmpFPredicate::OLT:
1369     return cmpResult == APFloat::cmpLessThan;
1370   case arith::CmpFPredicate::OLE:
1371     return cmpResult == APFloat::cmpLessThan || cmpResult == APFloat::cmpEqual;
1372   case arith::CmpFPredicate::ONE:
1373     return cmpResult != APFloat::cmpUnordered && cmpResult != APFloat::cmpEqual;
1374   case arith::CmpFPredicate::ORD:
1375     return cmpResult != APFloat::cmpUnordered;
1376   case arith::CmpFPredicate::UEQ:
1377     return cmpResult == APFloat::cmpUnordered || cmpResult == APFloat::cmpEqual;
1378   case arith::CmpFPredicate::UGT:
1379     return cmpResult == APFloat::cmpUnordered ||
1380            cmpResult == APFloat::cmpGreaterThan;
1381   case arith::CmpFPredicate::UGE:
1382     return cmpResult == APFloat::cmpUnordered ||
1383            cmpResult == APFloat::cmpGreaterThan ||
1384            cmpResult == APFloat::cmpEqual;
1385   case arith::CmpFPredicate::ULT:
1386     return cmpResult == APFloat::cmpUnordered ||
1387            cmpResult == APFloat::cmpLessThan;
1388   case arith::CmpFPredicate::ULE:
1389     return cmpResult == APFloat::cmpUnordered ||
1390            cmpResult == APFloat::cmpLessThan || cmpResult == APFloat::cmpEqual;
1391   case arith::CmpFPredicate::UNE:
1392     return cmpResult != APFloat::cmpEqual;
1393   case arith::CmpFPredicate::UNO:
1394     return cmpResult == APFloat::cmpUnordered;
1395   case arith::CmpFPredicate::AlwaysTrue:
1396     return true;
1397   }
1398   llvm_unreachable("unknown cmpf predicate kind");
1399 }
1400 
1401 OpFoldResult arith::CmpFOp::fold(ArrayRef<Attribute> operands) {
1402   assert(operands.size() == 2 && "cmpf takes two operands");
1403 
1404   auto lhs = operands.front().dyn_cast_or_null<FloatAttr>();
1405   auto rhs = operands.back().dyn_cast_or_null<FloatAttr>();
1406 
1407   // If one operand is NaN, making them both NaN does not change the result.
1408   if (lhs && lhs.getValue().isNaN())
1409     rhs = lhs;
1410   if (rhs && rhs.getValue().isNaN())
1411     lhs = rhs;
1412 
1413   if (!lhs || !rhs)
1414     return {};
1415 
1416   auto val = applyCmpPredicate(getPredicate(), lhs.getValue(), rhs.getValue());
1417   return BoolAttr::get(getContext(), val);
1418 }
1419 
1420 class CmpFIntToFPConst final : public OpRewritePattern<CmpFOp> {
1421 public:
1422   using OpRewritePattern<CmpFOp>::OpRewritePattern;
1423 
1424   static CmpIPredicate convertToIntegerPredicate(CmpFPredicate pred,
1425                                                  bool isUnsigned) {
1426     using namespace arith;
1427     switch (pred) {
1428     case CmpFPredicate::UEQ:
1429     case CmpFPredicate::OEQ:
1430       return CmpIPredicate::eq;
1431     case CmpFPredicate::UGT:
1432     case CmpFPredicate::OGT:
1433       return isUnsigned ? CmpIPredicate::ugt : CmpIPredicate::sgt;
1434     case CmpFPredicate::UGE:
1435     case CmpFPredicate::OGE:
1436       return isUnsigned ? CmpIPredicate::uge : CmpIPredicate::sge;
1437     case CmpFPredicate::ULT:
1438     case CmpFPredicate::OLT:
1439       return isUnsigned ? CmpIPredicate::ult : CmpIPredicate::slt;
1440     case CmpFPredicate::ULE:
1441     case CmpFPredicate::OLE:
1442       return isUnsigned ? CmpIPredicate::ule : CmpIPredicate::sle;
1443     case CmpFPredicate::UNE:
1444     case CmpFPredicate::ONE:
1445       return CmpIPredicate::ne;
1446     default:
1447       llvm_unreachable("Unexpected predicate!");
1448     }
1449   }
1450 
1451   LogicalResult matchAndRewrite(CmpFOp op,
1452                                 PatternRewriter &rewriter) const override {
1453     FloatAttr flt;
1454     if (!matchPattern(op.getRhs(), m_Constant(&flt)))
1455       return failure();
1456 
1457     const APFloat &rhs = flt.getValue();
1458 
1459     // Don't attempt to fold a nan.
1460     if (rhs.isNaN())
1461       return failure();
1462 
1463     // Get the width of the mantissa.  We don't want to hack on conversions that
1464     // might lose information from the integer, e.g. "i64 -> float"
1465     FloatType floatTy = op.getRhs().getType().cast<FloatType>();
1466     int mantissaWidth = floatTy.getFPMantissaWidth();
1467     if (mantissaWidth <= 0)
1468       return failure();
1469 
1470     bool isUnsigned;
1471     Value intVal;
1472 
1473     if (auto si = op.getLhs().getDefiningOp<SIToFPOp>()) {
1474       isUnsigned = false;
1475       intVal = si.getIn();
1476     } else if (auto ui = op.getLhs().getDefiningOp<UIToFPOp>()) {
1477       isUnsigned = true;
1478       intVal = ui.getIn();
1479     } else {
1480       return failure();
1481     }
1482 
1483     // Check to see that the input is converted from an integer type that is
1484     // small enough that preserves all bits.
1485     auto intTy = intVal.getType().cast<IntegerType>();
1486     auto intWidth = intTy.getWidth();
1487 
1488     // Number of bits representing values, as opposed to the sign
1489     auto valueBits = isUnsigned ? intWidth : (intWidth - 1);
1490 
1491     // Following test does NOT adjust intWidth downwards for signed inputs,
1492     // because the most negative value still requires all the mantissa bits
1493     // to distinguish it from one less than that value.
1494     if ((int)intWidth > mantissaWidth) {
1495       // Conversion would lose accuracy. Check if loss can impact comparison.
1496       int exponent = ilogb(rhs);
1497       if (exponent == APFloat::IEK_Inf) {
1498         int maxExponent = ilogb(APFloat::getLargest(rhs.getSemantics()));
1499         if (maxExponent < (int)valueBits) {
1500           // Conversion could create infinity.
1501           return failure();
1502         }
1503       } else {
1504         // Note that if rhs is zero or NaN, then Exp is negative
1505         // and first condition is trivially false.
1506         if (mantissaWidth <= exponent && exponent <= (int)valueBits) {
1507           // Conversion could affect comparison.
1508           return failure();
1509         }
1510       }
1511     }
1512 
1513     // Convert to equivalent cmpi predicate
1514     CmpIPredicate pred;
1515     switch (op.getPredicate()) {
1516     case CmpFPredicate::ORD:
1517       // Int to fp conversion doesn't create a nan (ord checks neither is a nan)
1518       rewriter.replaceOpWithNewOp<ConstantIntOp>(op, /*value=*/true,
1519                                                  /*width=*/1);
1520       return success();
1521     case CmpFPredicate::UNO:
1522       // Int to fp conversion doesn't create a nan (uno checks either is a nan)
1523       rewriter.replaceOpWithNewOp<ConstantIntOp>(op, /*value=*/false,
1524                                                  /*width=*/1);
1525       return success();
1526     default:
1527       pred = convertToIntegerPredicate(op.getPredicate(), isUnsigned);
1528       break;
1529     }
1530 
1531     if (!isUnsigned) {
1532       // If the rhs value is > SignedMax, fold the comparison.  This handles
1533       // +INF and large values.
1534       APFloat signedMax(rhs.getSemantics());
1535       signedMax.convertFromAPInt(APInt::getSignedMaxValue(intWidth), true,
1536                                  APFloat::rmNearestTiesToEven);
1537       if (signedMax < rhs) { // smax < 13123.0
1538         if (pred == CmpIPredicate::ne || pred == CmpIPredicate::slt ||
1539             pred == CmpIPredicate::sle)
1540           rewriter.replaceOpWithNewOp<ConstantIntOp>(op, /*value=*/true,
1541                                                      /*width=*/1);
1542         else
1543           rewriter.replaceOpWithNewOp<ConstantIntOp>(op, /*value=*/false,
1544                                                      /*width=*/1);
1545         return success();
1546       }
1547     } else {
1548       // If the rhs value is > UnsignedMax, fold the comparison. This handles
1549       // +INF and large values.
1550       APFloat unsignedMax(rhs.getSemantics());
1551       unsignedMax.convertFromAPInt(APInt::getMaxValue(intWidth), false,
1552                                    APFloat::rmNearestTiesToEven);
1553       if (unsignedMax < rhs) { // umax < 13123.0
1554         if (pred == CmpIPredicate::ne || pred == CmpIPredicate::ult ||
1555             pred == CmpIPredicate::ule)
1556           rewriter.replaceOpWithNewOp<ConstantIntOp>(op, /*value=*/true,
1557                                                      /*width=*/1);
1558         else
1559           rewriter.replaceOpWithNewOp<ConstantIntOp>(op, /*value=*/false,
1560                                                      /*width=*/1);
1561         return success();
1562       }
1563     }
1564 
1565     if (!isUnsigned) {
1566       // See if the rhs value is < SignedMin.
1567       APFloat signedMin(rhs.getSemantics());
1568       signedMin.convertFromAPInt(APInt::getSignedMinValue(intWidth), true,
1569                                  APFloat::rmNearestTiesToEven);
1570       if (signedMin > rhs) { // smin > 12312.0
1571         if (pred == CmpIPredicate::ne || pred == CmpIPredicate::sgt ||
1572             pred == CmpIPredicate::sge)
1573           rewriter.replaceOpWithNewOp<ConstantIntOp>(op, /*value=*/true,
1574                                                      /*width=*/1);
1575         else
1576           rewriter.replaceOpWithNewOp<ConstantIntOp>(op, /*value=*/false,
1577                                                      /*width=*/1);
1578         return success();
1579       }
1580     } else {
1581       // See if the rhs value is < UnsignedMin.
1582       APFloat unsignedMin(rhs.getSemantics());
1583       unsignedMin.convertFromAPInt(APInt::getMinValue(intWidth), false,
1584                                    APFloat::rmNearestTiesToEven);
1585       if (unsignedMin > rhs) { // umin > 12312.0
1586         if (pred == CmpIPredicate::ne || pred == CmpIPredicate::ugt ||
1587             pred == CmpIPredicate::uge)
1588           rewriter.replaceOpWithNewOp<ConstantIntOp>(op, /*value=*/true,
1589                                                      /*width=*/1);
1590         else
1591           rewriter.replaceOpWithNewOp<ConstantIntOp>(op, /*value=*/false,
1592                                                      /*width=*/1);
1593         return success();
1594       }
1595     }
1596 
1597     // Okay, now we know that the FP constant fits in the range [SMIN, SMAX] or
1598     // [0, UMAX], but it may still be fractional.  See if it is fractional by
1599     // casting the FP value to the integer value and back, checking for
1600     // equality. Don't do this for zero, because -0.0 is not fractional.
1601     bool ignored;
1602     APSInt rhsInt(intWidth, isUnsigned);
1603     if (APFloat::opInvalidOp ==
1604         rhs.convertToInteger(rhsInt, APFloat::rmTowardZero, &ignored)) {
1605       // Undefined behavior invoked - the destination type can't represent
1606       // the input constant.
1607       return failure();
1608     }
1609 
1610     if (!rhs.isZero()) {
1611       APFloat apf(floatTy.getFloatSemantics(),
1612                   APInt::getZero(floatTy.getWidth()));
1613       apf.convertFromAPInt(rhsInt, !isUnsigned, APFloat::rmNearestTiesToEven);
1614 
1615       bool equal = apf == rhs;
1616       if (!equal) {
1617         // If we had a comparison against a fractional value, we have to adjust
1618         // the compare predicate and sometimes the value.  rhsInt is rounded
1619         // towards zero at this point.
1620         switch (pred) {
1621         case CmpIPredicate::ne: // (float)int != 4.4   --> true
1622           rewriter.replaceOpWithNewOp<ConstantIntOp>(op, /*value=*/true,
1623                                                      /*width=*/1);
1624           return success();
1625         case CmpIPredicate::eq: // (float)int == 4.4   --> false
1626           rewriter.replaceOpWithNewOp<ConstantIntOp>(op, /*value=*/false,
1627                                                      /*width=*/1);
1628           return success();
1629         case CmpIPredicate::ule:
1630           // (float)int <= 4.4   --> int <= 4
1631           // (float)int <= -4.4  --> false
1632           if (rhs.isNegative()) {
1633             rewriter.replaceOpWithNewOp<ConstantIntOp>(op, /*value=*/false,
1634                                                        /*width=*/1);
1635             return success();
1636           }
1637           break;
1638         case CmpIPredicate::sle:
1639           // (float)int <= 4.4   --> int <= 4
1640           // (float)int <= -4.4  --> int < -4
1641           if (rhs.isNegative())
1642             pred = CmpIPredicate::slt;
1643           break;
1644         case CmpIPredicate::ult:
1645           // (float)int < -4.4   --> false
1646           // (float)int < 4.4    --> int <= 4
1647           if (rhs.isNegative()) {
1648             rewriter.replaceOpWithNewOp<ConstantIntOp>(op, /*value=*/false,
1649                                                        /*width=*/1);
1650             return success();
1651           }
1652           pred = CmpIPredicate::ule;
1653           break;
1654         case CmpIPredicate::slt:
1655           // (float)int < -4.4   --> int < -4
1656           // (float)int < 4.4    --> int <= 4
1657           if (!rhs.isNegative())
1658             pred = CmpIPredicate::sle;
1659           break;
1660         case CmpIPredicate::ugt:
1661           // (float)int > 4.4    --> int > 4
1662           // (float)int > -4.4   --> true
1663           if (rhs.isNegative()) {
1664             rewriter.replaceOpWithNewOp<ConstantIntOp>(op, /*value=*/true,
1665                                                        /*width=*/1);
1666             return success();
1667           }
1668           break;
1669         case CmpIPredicate::sgt:
1670           // (float)int > 4.4    --> int > 4
1671           // (float)int > -4.4   --> int >= -4
1672           if (rhs.isNegative())
1673             pred = CmpIPredicate::sge;
1674           break;
1675         case CmpIPredicate::uge:
1676           // (float)int >= -4.4   --> true
1677           // (float)int >= 4.4    --> int > 4
1678           if (rhs.isNegative()) {
1679             rewriter.replaceOpWithNewOp<ConstantIntOp>(op, /*value=*/true,
1680                                                        /*width=*/1);
1681             return success();
1682           }
1683           pred = CmpIPredicate::ugt;
1684           break;
1685         case CmpIPredicate::sge:
1686           // (float)int >= -4.4   --> int >= -4
1687           // (float)int >= 4.4    --> int > 4
1688           if (!rhs.isNegative())
1689             pred = CmpIPredicate::sgt;
1690           break;
1691         }
1692       }
1693     }
1694 
1695     // Lower this FP comparison into an appropriate integer version of the
1696     // comparison.
1697     rewriter.replaceOpWithNewOp<CmpIOp>(
1698         op, pred, intVal,
1699         rewriter.create<ConstantOp>(
1700             op.getLoc(), intVal.getType(),
1701             rewriter.getIntegerAttr(intVal.getType(), rhsInt)));
1702     return success();
1703   }
1704 };
1705 
1706 void arith::CmpFOp::getCanonicalizationPatterns(RewritePatternSet &patterns,
1707                                                 MLIRContext *context) {
1708   patterns.insert<CmpFIntToFPConst>(context);
1709 }
1710 
1711 //===----------------------------------------------------------------------===//
1712 // SelectOp
1713 //===----------------------------------------------------------------------===//
1714 
1715 // Transforms a select of a boolean to arithmetic operations
1716 //
1717 //  arith.select %arg, %x, %y : i1
1718 //
1719 //  becomes
1720 //
1721 //  and(%arg, %x) or and(!%arg, %y)
1722 struct SelectI1Simplify : public OpRewritePattern<arith::SelectOp> {
1723   using OpRewritePattern<arith::SelectOp>::OpRewritePattern;
1724 
1725   LogicalResult matchAndRewrite(arith::SelectOp op,
1726                                 PatternRewriter &rewriter) const override {
1727     if (!op.getType().isInteger(1))
1728       return failure();
1729 
1730     Value falseConstant =
1731         rewriter.create<arith::ConstantIntOp>(op.getLoc(), true, 1);
1732     Value notCondition = rewriter.create<arith::XOrIOp>(
1733         op.getLoc(), op.getCondition(), falseConstant);
1734 
1735     Value trueVal = rewriter.create<arith::AndIOp>(
1736         op.getLoc(), op.getCondition(), op.getTrueValue());
1737     Value falseVal = rewriter.create<arith::AndIOp>(op.getLoc(), notCondition,
1738                                                     op.getFalseValue());
1739     rewriter.replaceOpWithNewOp<arith::OrIOp>(op, trueVal, falseVal);
1740     return success();
1741   }
1742 };
1743 
1744 //  select %arg, %c1, %c0 => extui %arg
1745 struct SelectToExtUI : public OpRewritePattern<arith::SelectOp> {
1746   using OpRewritePattern<arith::SelectOp>::OpRewritePattern;
1747 
1748   LogicalResult matchAndRewrite(arith::SelectOp op,
1749                                 PatternRewriter &rewriter) const override {
1750     // Cannot extui i1 to i1, or i1 to f32
1751     if (!op.getType().isa<IntegerType>() || op.getType().isInteger(1))
1752       return failure();
1753 
1754     // select %x, c1, %c0 => extui %arg
1755     if (matchPattern(op.getTrueValue(), m_One()) &&
1756         matchPattern(op.getFalseValue(), m_Zero())) {
1757       rewriter.replaceOpWithNewOp<arith::ExtUIOp>(op, op.getType(),
1758                                                   op.getCondition());
1759       return success();
1760     }
1761 
1762     // select %x, c0, %c1 => extui (xor %arg, true)
1763     if (matchPattern(op.getTrueValue(), m_Zero()) &&
1764         matchPattern(op.getFalseValue(), m_One())) {
1765       rewriter.replaceOpWithNewOp<arith::ExtUIOp>(
1766           op, op.getType(),
1767           rewriter.create<arith::XOrIOp>(
1768               op.getLoc(), op.getCondition(),
1769               rewriter.create<arith::ConstantIntOp>(
1770                   op.getLoc(), 1, op.getCondition().getType())));
1771       return success();
1772     }
1773 
1774     return failure();
1775   }
1776 };
1777 
1778 void arith::SelectOp::getCanonicalizationPatterns(RewritePatternSet &results,
1779                                                   MLIRContext *context) {
1780   results.add<SelectI1Simplify, SelectToExtUI>(context);
1781 }
1782 
1783 OpFoldResult arith::SelectOp::fold(ArrayRef<Attribute> operands) {
1784   Value trueVal = getTrueValue();
1785   Value falseVal = getFalseValue();
1786   if (trueVal == falseVal)
1787     return trueVal;
1788 
1789   Value condition = getCondition();
1790 
1791   // select true, %0, %1 => %0
1792   if (matchPattern(condition, m_One()))
1793     return trueVal;
1794 
1795   // select false, %0, %1 => %1
1796   if (matchPattern(condition, m_Zero()))
1797     return falseVal;
1798 
1799   // select %x, true, false => %x
1800   if (getType().isInteger(1) && matchPattern(getTrueValue(), m_One()) &&
1801       matchPattern(getFalseValue(), m_Zero()))
1802     return condition;
1803 
1804   if (auto cmp = dyn_cast_or_null<arith::CmpIOp>(condition.getDefiningOp())) {
1805     auto pred = cmp.getPredicate();
1806     if (pred == arith::CmpIPredicate::eq || pred == arith::CmpIPredicate::ne) {
1807       auto cmpLhs = cmp.getLhs();
1808       auto cmpRhs = cmp.getRhs();
1809 
1810       // %0 = arith.cmpi eq, %arg0, %arg1
1811       // %1 = arith.select %0, %arg0, %arg1 => %arg1
1812 
1813       // %0 = arith.cmpi ne, %arg0, %arg1
1814       // %1 = arith.select %0, %arg0, %arg1 => %arg0
1815 
1816       if ((cmpLhs == trueVal && cmpRhs == falseVal) ||
1817           (cmpRhs == trueVal && cmpLhs == falseVal))
1818         return pred == arith::CmpIPredicate::ne ? trueVal : falseVal;
1819     }
1820   }
1821   return nullptr;
1822 }
1823 
1824 ParseResult SelectOp::parse(OpAsmParser &parser, OperationState &result) {
1825   Type conditionType, resultType;
1826   SmallVector<OpAsmParser::UnresolvedOperand, 3> operands;
1827   if (parser.parseOperandList(operands, /*requiredOperandCount=*/3) ||
1828       parser.parseOptionalAttrDict(result.attributes) ||
1829       parser.parseColonType(resultType))
1830     return failure();
1831 
1832   // Check for the explicit condition type if this is a masked tensor or vector.
1833   if (succeeded(parser.parseOptionalComma())) {
1834     conditionType = resultType;
1835     if (parser.parseType(resultType))
1836       return failure();
1837   } else {
1838     conditionType = parser.getBuilder().getI1Type();
1839   }
1840 
1841   result.addTypes(resultType);
1842   return parser.resolveOperands(operands,
1843                                 {conditionType, resultType, resultType},
1844                                 parser.getNameLoc(), result.operands);
1845 }
1846 
1847 void arith::SelectOp::print(OpAsmPrinter &p) {
1848   p << " " << getOperands();
1849   p.printOptionalAttrDict((*this)->getAttrs());
1850   p << " : ";
1851   if (ShapedType condType = getCondition().getType().dyn_cast<ShapedType>())
1852     p << condType << ", ";
1853   p << getType();
1854 }
1855 
1856 LogicalResult arith::SelectOp::verify() {
1857   Type conditionType = getCondition().getType();
1858   if (conditionType.isSignlessInteger(1))
1859     return success();
1860 
1861   // If the result type is a vector or tensor, the type can be a mask with the
1862   // same elements.
1863   Type resultType = getType();
1864   if (!resultType.isa<TensorType, VectorType>())
1865     return emitOpError() << "expected condition to be a signless i1, but got "
1866                          << conditionType;
1867   Type shapedConditionType = getI1SameShape(resultType);
1868   if (conditionType != shapedConditionType) {
1869     return emitOpError() << "expected condition type to have the same shape "
1870                             "as the result type, expected "
1871                          << shapedConditionType << ", but got "
1872                          << conditionType;
1873   }
1874   return success();
1875 }
1876 //===----------------------------------------------------------------------===//
1877 // ShLIOp
1878 //===----------------------------------------------------------------------===//
1879 
1880 OpFoldResult arith::ShLIOp::fold(ArrayRef<Attribute> operands) {
1881   // Don't fold if shifting more than the bit width.
1882   bool bounded = false;
1883   auto result = constFoldBinaryOp<IntegerAttr>(
1884       operands, [&](const APInt &a, const APInt &b) {
1885         bounded = b.ule(b.getBitWidth());
1886         return a.shl(b);
1887       });
1888   return bounded ? result : Attribute();
1889 }
1890 
1891 //===----------------------------------------------------------------------===//
1892 // ShRUIOp
1893 //===----------------------------------------------------------------------===//
1894 
1895 OpFoldResult arith::ShRUIOp::fold(ArrayRef<Attribute> operands) {
1896   // Don't fold if shifting more than the bit width.
1897   bool bounded = false;
1898   auto result = constFoldBinaryOp<IntegerAttr>(
1899       operands, [&](const APInt &a, const APInt &b) {
1900         bounded = b.ule(b.getBitWidth());
1901         return a.lshr(b);
1902       });
1903   return bounded ? result : Attribute();
1904 }
1905 
1906 //===----------------------------------------------------------------------===//
1907 // ShRSIOp
1908 //===----------------------------------------------------------------------===//
1909 
1910 OpFoldResult arith::ShRSIOp::fold(ArrayRef<Attribute> operands) {
1911   // Don't fold if shifting more than the bit width.
1912   bool bounded = false;
1913   auto result = constFoldBinaryOp<IntegerAttr>(
1914       operands, [&](const APInt &a, const APInt &b) {
1915         bounded = b.ule(b.getBitWidth());
1916         return a.ashr(b);
1917       });
1918   return bounded ? result : Attribute();
1919 }
1920 
1921 //===----------------------------------------------------------------------===//
1922 // Atomic Enum
1923 //===----------------------------------------------------------------------===//
1924 
1925 /// Returns the identity value attribute associated with an AtomicRMWKind op.
1926 Attribute mlir::arith::getIdentityValueAttr(AtomicRMWKind kind, Type resultType,
1927                                             OpBuilder &builder, Location loc) {
1928   switch (kind) {
1929   case AtomicRMWKind::maxf:
1930     return builder.getFloatAttr(
1931         resultType,
1932         APFloat::getInf(resultType.cast<FloatType>().getFloatSemantics(),
1933                         /*Negative=*/true));
1934   case AtomicRMWKind::addf:
1935   case AtomicRMWKind::addi:
1936   case AtomicRMWKind::maxu:
1937   case AtomicRMWKind::ori:
1938     return builder.getZeroAttr(resultType);
1939   case AtomicRMWKind::andi:
1940     return builder.getIntegerAttr(
1941         resultType,
1942         APInt::getAllOnes(resultType.cast<IntegerType>().getWidth()));
1943   case AtomicRMWKind::maxs:
1944     return builder.getIntegerAttr(
1945         resultType,
1946         APInt::getSignedMinValue(resultType.cast<IntegerType>().getWidth()));
1947   case AtomicRMWKind::minf:
1948     return builder.getFloatAttr(
1949         resultType,
1950         APFloat::getInf(resultType.cast<FloatType>().getFloatSemantics(),
1951                         /*Negative=*/false));
1952   case AtomicRMWKind::mins:
1953     return builder.getIntegerAttr(
1954         resultType,
1955         APInt::getSignedMaxValue(resultType.cast<IntegerType>().getWidth()));
1956   case AtomicRMWKind::minu:
1957     return builder.getIntegerAttr(
1958         resultType,
1959         APInt::getMaxValue(resultType.cast<IntegerType>().getWidth()));
1960   case AtomicRMWKind::muli:
1961     return builder.getIntegerAttr(resultType, 1);
1962   case AtomicRMWKind::mulf:
1963     return builder.getFloatAttr(resultType, 1);
1964   // TODO: Add remaining reduction operations.
1965   default:
1966     (void)emitOptionalError(loc, "Reduction operation type not supported");
1967     break;
1968   }
1969   return nullptr;
1970 }
1971 
1972 /// Returns the identity value associated with an AtomicRMWKind op.
1973 Value mlir::arith::getIdentityValue(AtomicRMWKind op, Type resultType,
1974                                     OpBuilder &builder, Location loc) {
1975   Attribute attr = getIdentityValueAttr(op, resultType, builder, loc);
1976   return builder.create<arith::ConstantOp>(loc, attr);
1977 }
1978 
1979 /// Return the value obtained by applying the reduction operation kind
1980 /// associated with a binary AtomicRMWKind op to `lhs` and `rhs`.
1981 Value mlir::arith::getReductionOp(AtomicRMWKind op, OpBuilder &builder,
1982                                   Location loc, Value lhs, Value rhs) {
1983   switch (op) {
1984   case AtomicRMWKind::addf:
1985     return builder.create<arith::AddFOp>(loc, lhs, rhs);
1986   case AtomicRMWKind::addi:
1987     return builder.create<arith::AddIOp>(loc, lhs, rhs);
1988   case AtomicRMWKind::mulf:
1989     return builder.create<arith::MulFOp>(loc, lhs, rhs);
1990   case AtomicRMWKind::muli:
1991     return builder.create<arith::MulIOp>(loc, lhs, rhs);
1992   case AtomicRMWKind::maxf:
1993     return builder.create<arith::MaxFOp>(loc, lhs, rhs);
1994   case AtomicRMWKind::minf:
1995     return builder.create<arith::MinFOp>(loc, lhs, rhs);
1996   case AtomicRMWKind::maxs:
1997     return builder.create<arith::MaxSIOp>(loc, lhs, rhs);
1998   case AtomicRMWKind::mins:
1999     return builder.create<arith::MinSIOp>(loc, lhs, rhs);
2000   case AtomicRMWKind::maxu:
2001     return builder.create<arith::MaxUIOp>(loc, lhs, rhs);
2002   case AtomicRMWKind::minu:
2003     return builder.create<arith::MinUIOp>(loc, lhs, rhs);
2004   case AtomicRMWKind::ori:
2005     return builder.create<arith::OrIOp>(loc, lhs, rhs);
2006   case AtomicRMWKind::andi:
2007     return builder.create<arith::AndIOp>(loc, lhs, rhs);
2008   // TODO: Add remaining reduction operations.
2009   default:
2010     (void)emitOptionalError(loc, "Reduction operation type not supported");
2011     break;
2012   }
2013   return nullptr;
2014 }
2015 
2016 //===----------------------------------------------------------------------===//
2017 // TableGen'd op method definitions
2018 //===----------------------------------------------------------------------===//
2019 
2020 #define GET_OP_CLASSES
2021 #include "mlir/Dialect/Arithmetic/IR/ArithmeticOps.cpp.inc"
2022 
2023 //===----------------------------------------------------------------------===//
2024 // TableGen'd enum attribute definitions
2025 //===----------------------------------------------------------------------===//
2026 
2027 #include "mlir/Dialect/Arithmetic/IR/ArithmeticOpsEnums.cpp.inc"
2028