1 //===- ArithmeticOps.cpp - MLIR Arithmetic dialect ops implementation -----===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 
9 #include <utility>
10 
11 #include "mlir/Dialect/Arithmetic/IR/Arithmetic.h"
12 #include "mlir/Dialect/CommonFolders.h"
13 #include "mlir/IR/Builders.h"
14 #include "mlir/IR/Matchers.h"
15 #include "mlir/IR/OpImplementation.h"
16 #include "mlir/IR/PatternMatch.h"
17 #include "mlir/IR/TypeUtilities.h"
18 #include "llvm/ADT/SmallString.h"
19 
20 #include "llvm/ADT/APSInt.h"
21 
22 using namespace mlir;
23 using namespace mlir::arith;
24 
25 //===----------------------------------------------------------------------===//
26 // Pattern helpers
27 //===----------------------------------------------------------------------===//
28 
29 static IntegerAttr addIntegerAttrs(PatternRewriter &builder, Value res,
30                                    Attribute lhs, Attribute rhs) {
31   return builder.getIntegerAttr(res.getType(),
32                                 lhs.cast<IntegerAttr>().getInt() +
33                                     rhs.cast<IntegerAttr>().getInt());
34 }
35 
36 static IntegerAttr subIntegerAttrs(PatternRewriter &builder, Value res,
37                                    Attribute lhs, Attribute rhs) {
38   return builder.getIntegerAttr(res.getType(),
39                                 lhs.cast<IntegerAttr>().getInt() -
40                                     rhs.cast<IntegerAttr>().getInt());
41 }
42 
43 /// Invert an integer comparison predicate.
44 arith::CmpIPredicate arith::invertPredicate(arith::CmpIPredicate pred) {
45   switch (pred) {
46   case arith::CmpIPredicate::eq:
47     return arith::CmpIPredicate::ne;
48   case arith::CmpIPredicate::ne:
49     return arith::CmpIPredicate::eq;
50   case arith::CmpIPredicate::slt:
51     return arith::CmpIPredicate::sge;
52   case arith::CmpIPredicate::sle:
53     return arith::CmpIPredicate::sgt;
54   case arith::CmpIPredicate::sgt:
55     return arith::CmpIPredicate::sle;
56   case arith::CmpIPredicate::sge:
57     return arith::CmpIPredicate::slt;
58   case arith::CmpIPredicate::ult:
59     return arith::CmpIPredicate::uge;
60   case arith::CmpIPredicate::ule:
61     return arith::CmpIPredicate::ugt;
62   case arith::CmpIPredicate::ugt:
63     return arith::CmpIPredicate::ule;
64   case arith::CmpIPredicate::uge:
65     return arith::CmpIPredicate::ult;
66   }
67   llvm_unreachable("unknown cmpi predicate kind");
68 }
69 
70 static arith::CmpIPredicateAttr invertPredicate(arith::CmpIPredicateAttr pred) {
71   return arith::CmpIPredicateAttr::get(pred.getContext(),
72                                        invertPredicate(pred.getValue()));
73 }
74 
75 //===----------------------------------------------------------------------===//
76 // TableGen'd canonicalization patterns
77 //===----------------------------------------------------------------------===//
78 
79 namespace {
80 #include "ArithmeticCanonicalization.inc"
81 } // namespace
82 
83 //===----------------------------------------------------------------------===//
84 // ConstantOp
85 //===----------------------------------------------------------------------===//
86 
87 void arith::ConstantOp::getAsmResultNames(
88     function_ref<void(Value, StringRef)> setNameFn) {
89   auto type = getType();
90   if (auto intCst = getValue().dyn_cast<IntegerAttr>()) {
91     auto intType = type.dyn_cast<IntegerType>();
92 
93     // Sugar i1 constants with 'true' and 'false'.
94     if (intType && intType.getWidth() == 1)
95       return setNameFn(getResult(), (intCst.getInt() ? "true" : "false"));
96 
97     // Otherwise, build a complex name with the value and type.
98     SmallString<32> specialNameBuffer;
99     llvm::raw_svector_ostream specialName(specialNameBuffer);
100     specialName << 'c' << intCst.getValue();
101     if (intType)
102       specialName << '_' << type;
103     setNameFn(getResult(), specialName.str());
104   } else {
105     setNameFn(getResult(), "cst");
106   }
107 }
108 
109 /// TODO: disallow arith.constant to return anything other than signless integer
110 /// or float like.
111 LogicalResult arith::ConstantOp::verify() {
112   auto type = getType();
113   // The value's type must match the return type.
114   if (getValue().getType() != type) {
115     return emitOpError() << "value type " << getValue().getType()
116                          << " must match return type: " << type;
117   }
118   // Integer values must be signless.
119   if (type.isa<IntegerType>() && !type.cast<IntegerType>().isSignless())
120     return emitOpError("integer return type must be signless");
121   // Any float or elements attribute are acceptable.
122   if (!getValue().isa<IntegerAttr, FloatAttr, ElementsAttr>()) {
123     return emitOpError(
124         "value must be an integer, float, or elements attribute");
125   }
126   return success();
127 }
128 
129 bool arith::ConstantOp::isBuildableWith(Attribute value, Type type) {
130   // The value's type must be the same as the provided type.
131   if (value.getType() != type)
132     return false;
133   // Integer values must be signless.
134   if (type.isa<IntegerType>() && !type.cast<IntegerType>().isSignless())
135     return false;
136   // Integer, float, and element attributes are buildable.
137   return value.isa<IntegerAttr, FloatAttr, ElementsAttr>();
138 }
139 
140 OpFoldResult arith::ConstantOp::fold(ArrayRef<Attribute> operands) {
141   return getValue();
142 }
143 
144 void arith::ConstantIntOp::build(OpBuilder &builder, OperationState &result,
145                                  int64_t value, unsigned width) {
146   auto type = builder.getIntegerType(width);
147   arith::ConstantOp::build(builder, result, type,
148                            builder.getIntegerAttr(type, value));
149 }
150 
151 void arith::ConstantIntOp::build(OpBuilder &builder, OperationState &result,
152                                  int64_t value, Type type) {
153   assert(type.isSignlessInteger() &&
154          "ConstantIntOp can only have signless integer type values");
155   arith::ConstantOp::build(builder, result, type,
156                            builder.getIntegerAttr(type, value));
157 }
158 
159 bool arith::ConstantIntOp::classof(Operation *op) {
160   if (auto constOp = dyn_cast_or_null<arith::ConstantOp>(op))
161     return constOp.getType().isSignlessInteger();
162   return false;
163 }
164 
165 void arith::ConstantFloatOp::build(OpBuilder &builder, OperationState &result,
166                                    const APFloat &value, FloatType type) {
167   arith::ConstantOp::build(builder, result, type,
168                            builder.getFloatAttr(type, value));
169 }
170 
171 bool arith::ConstantFloatOp::classof(Operation *op) {
172   if (auto constOp = dyn_cast_or_null<arith::ConstantOp>(op))
173     return constOp.getType().isa<FloatType>();
174   return false;
175 }
176 
177 void arith::ConstantIndexOp::build(OpBuilder &builder, OperationState &result,
178                                    int64_t value) {
179   arith::ConstantOp::build(builder, result, builder.getIndexType(),
180                            builder.getIndexAttr(value));
181 }
182 
183 bool arith::ConstantIndexOp::classof(Operation *op) {
184   if (auto constOp = dyn_cast_or_null<arith::ConstantOp>(op))
185     return constOp.getType().isIndex();
186   return false;
187 }
188 
189 //===----------------------------------------------------------------------===//
190 // AddIOp
191 //===----------------------------------------------------------------------===//
192 
193 OpFoldResult arith::AddIOp::fold(ArrayRef<Attribute> operands) {
194   // addi(x, 0) -> x
195   if (matchPattern(getRhs(), m_Zero()))
196     return getLhs();
197 
198   // addi(subi(a, b), b) -> a
199   if (auto sub = getLhs().getDefiningOp<SubIOp>())
200     if (getRhs() == sub.getRhs())
201       return sub.getLhs();
202 
203   // addi(b, subi(a, b)) -> a
204   if (auto sub = getRhs().getDefiningOp<SubIOp>())
205     if (getLhs() == sub.getRhs())
206       return sub.getLhs();
207 
208   return constFoldBinaryOp<IntegerAttr>(
209       operands, [](APInt a, const APInt &b) { return std::move(a) + b; });
210 }
211 
212 void arith::AddIOp::getCanonicalizationPatterns(RewritePatternSet &patterns,
213                                                 MLIRContext *context) {
214   patterns.add<AddIAddConstant, AddISubConstantRHS, AddISubConstantLHS>(
215       context);
216 }
217 
218 //===----------------------------------------------------------------------===//
219 // SubIOp
220 //===----------------------------------------------------------------------===//
221 
222 OpFoldResult arith::SubIOp::fold(ArrayRef<Attribute> operands) {
223   // subi(x,x) -> 0
224   if (getOperand(0) == getOperand(1))
225     return Builder(getContext()).getZeroAttr(getType());
226   // subi(x,0) -> x
227   if (matchPattern(getRhs(), m_Zero()))
228     return getLhs();
229 
230   return constFoldBinaryOp<IntegerAttr>(
231       operands, [](APInt a, const APInt &b) { return std::move(a) - b; });
232 }
233 
234 void arith::SubIOp::getCanonicalizationPatterns(RewritePatternSet &patterns,
235                                                 MLIRContext *context) {
236   patterns
237       .add<SubIRHSAddConstant, SubILHSAddConstant, SubIRHSSubConstantRHS,
238            SubIRHSSubConstantLHS, SubILHSSubConstantRHS, SubILHSSubConstantLHS>(
239           context);
240 }
241 
242 //===----------------------------------------------------------------------===//
243 // MulIOp
244 //===----------------------------------------------------------------------===//
245 
246 OpFoldResult arith::MulIOp::fold(ArrayRef<Attribute> operands) {
247   // muli(x, 0) -> 0
248   if (matchPattern(getRhs(), m_Zero()))
249     return getRhs();
250   // muli(x, 1) -> x
251   if (matchPattern(getRhs(), m_One()))
252     return getOperand(0);
253   // TODO: Handle the overflow case.
254 
255   // default folder
256   return constFoldBinaryOp<IntegerAttr>(
257       operands, [](const APInt &a, const APInt &b) { return a * b; });
258 }
259 
260 //===----------------------------------------------------------------------===//
261 // DivUIOp
262 //===----------------------------------------------------------------------===//
263 
264 OpFoldResult arith::DivUIOp::fold(ArrayRef<Attribute> operands) {
265   // divui (x, 1) -> x.
266   if (matchPattern(getRhs(), m_One()))
267     return getLhs();
268 
269   // Don't fold if it would require a division by zero.
270   bool div0 = false;
271   auto result =
272       constFoldBinaryOp<IntegerAttr>(operands, [&](APInt a, const APInt &b) {
273         if (div0 || !b) {
274           div0 = true;
275           return a;
276         }
277         return a.udiv(b);
278       });
279 
280   return div0 ? Attribute() : result;
281 }
282 
283 //===----------------------------------------------------------------------===//
284 // DivSIOp
285 //===----------------------------------------------------------------------===//
286 
287 OpFoldResult arith::DivSIOp::fold(ArrayRef<Attribute> operands) {
288   // divsi (x, 1) -> x.
289   if (matchPattern(getRhs(), m_One()))
290     return getLhs();
291 
292   // Don't fold if it would overflow or if it requires a division by zero.
293   bool overflowOrDiv0 = false;
294   auto result =
295       constFoldBinaryOp<IntegerAttr>(operands, [&](APInt a, const APInt &b) {
296         if (overflowOrDiv0 || !b) {
297           overflowOrDiv0 = true;
298           return a;
299         }
300         return a.sdiv_ov(b, overflowOrDiv0);
301       });
302 
303   return overflowOrDiv0 ? Attribute() : result;
304 }
305 
306 //===----------------------------------------------------------------------===//
307 // Ceil and floor division folding helpers
308 //===----------------------------------------------------------------------===//
309 
310 static APInt signedCeilNonnegInputs(const APInt &a, const APInt &b,
311                                     bool &overflow) {
312   // Returns (a-1)/b + 1
313   APInt one(a.getBitWidth(), 1, true); // Signed value 1.
314   APInt val = a.ssub_ov(one, overflow).sdiv_ov(b, overflow);
315   return val.sadd_ov(one, overflow);
316 }
317 
318 //===----------------------------------------------------------------------===//
319 // CeilDivUIOp
320 //===----------------------------------------------------------------------===//
321 
322 OpFoldResult arith::CeilDivUIOp::fold(ArrayRef<Attribute> operands) {
323   // ceildivui (x, 1) -> x.
324   if (matchPattern(getRhs(), m_One()))
325     return getLhs();
326 
327   bool overflowOrDiv0 = false;
328   auto result =
329       constFoldBinaryOp<IntegerAttr>(operands, [&](APInt a, const APInt &b) {
330         if (overflowOrDiv0 || !b) {
331           overflowOrDiv0 = true;
332           return a;
333         }
334         APInt quotient = a.udiv(b);
335         if (!a.urem(b))
336           return quotient;
337         APInt one(a.getBitWidth(), 1, true);
338         return quotient.uadd_ov(one, overflowOrDiv0);
339       });
340 
341   return overflowOrDiv0 ? Attribute() : result;
342 }
343 
344 //===----------------------------------------------------------------------===//
345 // CeilDivSIOp
346 //===----------------------------------------------------------------------===//
347 
348 OpFoldResult arith::CeilDivSIOp::fold(ArrayRef<Attribute> operands) {
349   // ceildivsi (x, 1) -> x.
350   if (matchPattern(getRhs(), m_One()))
351     return getLhs();
352 
353   // Don't fold if it would overflow or if it requires a division by zero.
354   bool overflowOrDiv0 = false;
355   auto result =
356       constFoldBinaryOp<IntegerAttr>(operands, [&](APInt a, const APInt &b) {
357         if (overflowOrDiv0 || !b) {
358           overflowOrDiv0 = true;
359           return a;
360         }
361         if (!a)
362           return a;
363         // After this point we know that neither a or b are zero.
364         unsigned bits = a.getBitWidth();
365         APInt zero = APInt::getZero(bits);
366         bool aGtZero = a.sgt(zero);
367         bool bGtZero = b.sgt(zero);
368         if (aGtZero && bGtZero) {
369           // Both positive, return ceil(a, b).
370           return signedCeilNonnegInputs(a, b, overflowOrDiv0);
371         }
372         if (!aGtZero && !bGtZero) {
373           // Both negative, return ceil(-a, -b).
374           APInt posA = zero.ssub_ov(a, overflowOrDiv0);
375           APInt posB = zero.ssub_ov(b, overflowOrDiv0);
376           return signedCeilNonnegInputs(posA, posB, overflowOrDiv0);
377         }
378         if (!aGtZero && bGtZero) {
379           // A is negative, b is positive, return - ( -a / b).
380           APInt posA = zero.ssub_ov(a, overflowOrDiv0);
381           APInt div = posA.sdiv_ov(b, overflowOrDiv0);
382           return zero.ssub_ov(div, overflowOrDiv0);
383         }
384         // A is positive, b is negative, return - (a / -b).
385         APInt posB = zero.ssub_ov(b, overflowOrDiv0);
386         APInt div = a.sdiv_ov(posB, overflowOrDiv0);
387         return zero.ssub_ov(div, overflowOrDiv0);
388       });
389 
390   return overflowOrDiv0 ? Attribute() : result;
391 }
392 
393 //===----------------------------------------------------------------------===//
394 // FloorDivSIOp
395 //===----------------------------------------------------------------------===//
396 
397 OpFoldResult arith::FloorDivSIOp::fold(ArrayRef<Attribute> operands) {
398   // floordivsi (x, 1) -> x.
399   if (matchPattern(getRhs(), m_One()))
400     return getLhs();
401 
402   // Don't fold if it would overflow or if it requires a division by zero.
403   bool overflowOrDiv0 = false;
404   auto result =
405       constFoldBinaryOp<IntegerAttr>(operands, [&](APInt a, const APInt &b) {
406         if (overflowOrDiv0 || !b) {
407           overflowOrDiv0 = true;
408           return a;
409         }
410         if (!a)
411           return a;
412         // After this point we know that neither a or b are zero.
413         unsigned bits = a.getBitWidth();
414         APInt zero = APInt::getZero(bits);
415         bool aGtZero = a.sgt(zero);
416         bool bGtZero = b.sgt(zero);
417         if (aGtZero && bGtZero) {
418           // Both positive, return a / b.
419           return a.sdiv_ov(b, overflowOrDiv0);
420         }
421         if (!aGtZero && !bGtZero) {
422           // Both negative, return -a / -b.
423           APInt posA = zero.ssub_ov(a, overflowOrDiv0);
424           APInt posB = zero.ssub_ov(b, overflowOrDiv0);
425           return posA.sdiv_ov(posB, overflowOrDiv0);
426         }
427         if (!aGtZero && bGtZero) {
428           // A is negative, b is positive, return - ceil(-a, b).
429           APInt posA = zero.ssub_ov(a, overflowOrDiv0);
430           APInt ceil = signedCeilNonnegInputs(posA, b, overflowOrDiv0);
431           return zero.ssub_ov(ceil, overflowOrDiv0);
432         }
433         // A is positive, b is negative, return - ceil(a, -b).
434         APInt posB = zero.ssub_ov(b, overflowOrDiv0);
435         APInt ceil = signedCeilNonnegInputs(a, posB, overflowOrDiv0);
436         return zero.ssub_ov(ceil, overflowOrDiv0);
437       });
438 
439   return overflowOrDiv0 ? Attribute() : result;
440 }
441 
442 //===----------------------------------------------------------------------===//
443 // RemUIOp
444 //===----------------------------------------------------------------------===//
445 
446 OpFoldResult arith::RemUIOp::fold(ArrayRef<Attribute> operands) {
447   // remui (x, 1) -> 0.
448   if (matchPattern(getRhs(), m_One()))
449     return Builder(getContext()).getZeroAttr(getType());
450 
451   // Don't fold if it would require a division by zero.
452   bool div0 = false;
453   auto result =
454       constFoldBinaryOp<IntegerAttr>(operands, [&](APInt a, const APInt &b) {
455         if (div0 || b.isNullValue()) {
456           div0 = true;
457           return a;
458         }
459         return a.urem(b);
460       });
461 
462   return div0 ? Attribute() : result;
463 }
464 
465 //===----------------------------------------------------------------------===//
466 // RemSIOp
467 //===----------------------------------------------------------------------===//
468 
469 OpFoldResult arith::RemSIOp::fold(ArrayRef<Attribute> operands) {
470   // remsi (x, 1) -> 0.
471   if (matchPattern(getRhs(), m_One()))
472     return Builder(getContext()).getZeroAttr(getType());
473 
474   // Don't fold if it would require a division by zero.
475   bool div0 = false;
476   auto result =
477       constFoldBinaryOp<IntegerAttr>(operands, [&](APInt a, const APInt &b) {
478         if (div0 || b.isNullValue()) {
479           div0 = true;
480           return a;
481         }
482         return a.srem(b);
483       });
484 
485   return div0 ? Attribute() : result;
486 }
487 
488 //===----------------------------------------------------------------------===//
489 // AndIOp
490 //===----------------------------------------------------------------------===//
491 
492 OpFoldResult arith::AndIOp::fold(ArrayRef<Attribute> operands) {
493   /// and(x, 0) -> 0
494   if (matchPattern(getRhs(), m_Zero()))
495     return getRhs();
496   /// and(x, allOnes) -> x
497   APInt intValue;
498   if (matchPattern(getRhs(), m_ConstantInt(&intValue)) && intValue.isAllOnes())
499     return getLhs();
500 
501   return constFoldBinaryOp<IntegerAttr>(
502       operands, [](APInt a, const APInt &b) { return std::move(a) & b; });
503 }
504 
505 //===----------------------------------------------------------------------===//
506 // OrIOp
507 //===----------------------------------------------------------------------===//
508 
509 OpFoldResult arith::OrIOp::fold(ArrayRef<Attribute> operands) {
510   /// or(x, 0) -> x
511   if (matchPattern(getRhs(), m_Zero()))
512     return getLhs();
513   /// or(x, <all ones>) -> <all ones>
514   if (auto rhsAttr = operands[1].dyn_cast_or_null<IntegerAttr>())
515     if (rhsAttr.getValue().isAllOnes())
516       return rhsAttr;
517 
518   return constFoldBinaryOp<IntegerAttr>(
519       operands, [](APInt a, const APInt &b) { return std::move(a) | b; });
520 }
521 
522 //===----------------------------------------------------------------------===//
523 // XOrIOp
524 //===----------------------------------------------------------------------===//
525 
526 OpFoldResult arith::XOrIOp::fold(ArrayRef<Attribute> operands) {
527   /// xor(x, 0) -> x
528   if (matchPattern(getRhs(), m_Zero()))
529     return getLhs();
530   /// xor(x, x) -> 0
531   if (getLhs() == getRhs())
532     return Builder(getContext()).getZeroAttr(getType());
533   /// xor(xor(x, a), a) -> x
534   if (arith::XOrIOp prev = getLhs().getDefiningOp<arith::XOrIOp>())
535     if (prev.getRhs() == getRhs())
536       return prev.getLhs();
537 
538   return constFoldBinaryOp<IntegerAttr>(
539       operands, [](APInt a, const APInt &b) { return std::move(a) ^ b; });
540 }
541 
542 void arith::XOrIOp::getCanonicalizationPatterns(RewritePatternSet &patterns,
543                                                 MLIRContext *context) {
544   patterns.add<XOrINotCmpI>(context);
545 }
546 
547 //===----------------------------------------------------------------------===//
548 // NegFOp
549 //===----------------------------------------------------------------------===//
550 
551 OpFoldResult arith::NegFOp::fold(ArrayRef<Attribute> operands) {
552   /// negf(negf(x)) -> x
553   if (auto op = this->getOperand().getDefiningOp<arith::NegFOp>())
554     return op.getOperand();
555   return constFoldUnaryOp<FloatAttr>(operands,
556                                      [](const APFloat &a) { return -a; });
557 }
558 
559 //===----------------------------------------------------------------------===//
560 // AddFOp
561 //===----------------------------------------------------------------------===//
562 
563 OpFoldResult arith::AddFOp::fold(ArrayRef<Attribute> operands) {
564   // addf(x, -0) -> x
565   if (matchPattern(getRhs(), m_NegZeroFloat()))
566     return getLhs();
567 
568   return constFoldBinaryOp<FloatAttr>(
569       operands, [](const APFloat &a, const APFloat &b) { return a + b; });
570 }
571 
572 //===----------------------------------------------------------------------===//
573 // SubFOp
574 //===----------------------------------------------------------------------===//
575 
576 OpFoldResult arith::SubFOp::fold(ArrayRef<Attribute> operands) {
577   // subf(x, +0) -> x
578   if (matchPattern(getRhs(), m_PosZeroFloat()))
579     return getLhs();
580 
581   return constFoldBinaryOp<FloatAttr>(
582       operands, [](const APFloat &a, const APFloat &b) { return a - b; });
583 }
584 
585 //===----------------------------------------------------------------------===//
586 // MaxFOp
587 //===----------------------------------------------------------------------===//
588 
589 OpFoldResult arith::MaxFOp::fold(ArrayRef<Attribute> operands) {
590   assert(operands.size() == 2 && "maxf takes two operands");
591 
592   // maxf(x,x) -> x
593   if (getLhs() == getRhs())
594     return getRhs();
595 
596   // maxf(x, -inf) -> x
597   if (matchPattern(getRhs(), m_NegInfFloat()))
598     return getLhs();
599 
600   return constFoldBinaryOp<FloatAttr>(
601       operands,
602       [](const APFloat &a, const APFloat &b) { return llvm::maximum(a, b); });
603 }
604 
605 //===----------------------------------------------------------------------===//
606 // MaxSIOp
607 //===----------------------------------------------------------------------===//
608 
609 OpFoldResult MaxSIOp::fold(ArrayRef<Attribute> operands) {
610   assert(operands.size() == 2 && "binary operation takes two operands");
611 
612   // maxsi(x,x) -> x
613   if (getLhs() == getRhs())
614     return getRhs();
615 
616   APInt intValue;
617   // maxsi(x,MAX_INT) -> MAX_INT
618   if (matchPattern(getRhs(), m_ConstantInt(&intValue)) &&
619       intValue.isMaxSignedValue())
620     return getRhs();
621 
622   // maxsi(x, MIN_INT) -> x
623   if (matchPattern(getRhs(), m_ConstantInt(&intValue)) &&
624       intValue.isMinSignedValue())
625     return getLhs();
626 
627   return constFoldBinaryOp<IntegerAttr>(operands,
628                                         [](const APInt &a, const APInt &b) {
629                                           return llvm::APIntOps::smax(a, b);
630                                         });
631 }
632 
633 //===----------------------------------------------------------------------===//
634 // MaxUIOp
635 //===----------------------------------------------------------------------===//
636 
637 OpFoldResult MaxUIOp::fold(ArrayRef<Attribute> operands) {
638   assert(operands.size() == 2 && "binary operation takes two operands");
639 
640   // maxui(x,x) -> x
641   if (getLhs() == getRhs())
642     return getRhs();
643 
644   APInt intValue;
645   // maxui(x,MAX_INT) -> MAX_INT
646   if (matchPattern(getRhs(), m_ConstantInt(&intValue)) && intValue.isMaxValue())
647     return getRhs();
648 
649   // maxui(x, MIN_INT) -> x
650   if (matchPattern(getRhs(), m_ConstantInt(&intValue)) && intValue.isMinValue())
651     return getLhs();
652 
653   return constFoldBinaryOp<IntegerAttr>(operands,
654                                         [](const APInt &a, const APInt &b) {
655                                           return llvm::APIntOps::umax(a, b);
656                                         });
657 }
658 
659 //===----------------------------------------------------------------------===//
660 // MinFOp
661 //===----------------------------------------------------------------------===//
662 
663 OpFoldResult arith::MinFOp::fold(ArrayRef<Attribute> operands) {
664   assert(operands.size() == 2 && "minf takes two operands");
665 
666   // minf(x,x) -> x
667   if (getLhs() == getRhs())
668     return getRhs();
669 
670   // minf(x, +inf) -> x
671   if (matchPattern(getRhs(), m_PosInfFloat()))
672     return getLhs();
673 
674   return constFoldBinaryOp<FloatAttr>(
675       operands,
676       [](const APFloat &a, const APFloat &b) { return llvm::minimum(a, b); });
677 }
678 
679 //===----------------------------------------------------------------------===//
680 // MinSIOp
681 //===----------------------------------------------------------------------===//
682 
683 OpFoldResult MinSIOp::fold(ArrayRef<Attribute> operands) {
684   assert(operands.size() == 2 && "binary operation takes two operands");
685 
686   // minsi(x,x) -> x
687   if (getLhs() == getRhs())
688     return getRhs();
689 
690   APInt intValue;
691   // minsi(x,MIN_INT) -> MIN_INT
692   if (matchPattern(getRhs(), m_ConstantInt(&intValue)) &&
693       intValue.isMinSignedValue())
694     return getRhs();
695 
696   // minsi(x, MAX_INT) -> x
697   if (matchPattern(getRhs(), m_ConstantInt(&intValue)) &&
698       intValue.isMaxSignedValue())
699     return getLhs();
700 
701   return constFoldBinaryOp<IntegerAttr>(operands,
702                                         [](const APInt &a, const APInt &b) {
703                                           return llvm::APIntOps::smin(a, b);
704                                         });
705 }
706 
707 //===----------------------------------------------------------------------===//
708 // MinUIOp
709 //===----------------------------------------------------------------------===//
710 
711 OpFoldResult MinUIOp::fold(ArrayRef<Attribute> operands) {
712   assert(operands.size() == 2 && "binary operation takes two operands");
713 
714   // minui(x,x) -> x
715   if (getLhs() == getRhs())
716     return getRhs();
717 
718   APInt intValue;
719   // minui(x,MIN_INT) -> MIN_INT
720   if (matchPattern(getRhs(), m_ConstantInt(&intValue)) && intValue.isMinValue())
721     return getRhs();
722 
723   // minui(x, MAX_INT) -> x
724   if (matchPattern(getRhs(), m_ConstantInt(&intValue)) && intValue.isMaxValue())
725     return getLhs();
726 
727   return constFoldBinaryOp<IntegerAttr>(operands,
728                                         [](const APInt &a, const APInt &b) {
729                                           return llvm::APIntOps::umin(a, b);
730                                         });
731 }
732 
733 //===----------------------------------------------------------------------===//
734 // MulFOp
735 //===----------------------------------------------------------------------===//
736 
737 OpFoldResult arith::MulFOp::fold(ArrayRef<Attribute> operands) {
738   // mulf(x, 1) -> x
739   if (matchPattern(getRhs(), m_OneFloat()))
740     return getLhs();
741 
742   return constFoldBinaryOp<FloatAttr>(
743       operands, [](const APFloat &a, const APFloat &b) { return a * b; });
744 }
745 
746 void arith::MulFOp::getCanonicalizationPatterns(RewritePatternSet &patterns,
747                                                 MLIRContext *context) {
748   patterns.add<MulFOfNegF>(context);
749 }
750 
751 //===----------------------------------------------------------------------===//
752 // DivFOp
753 //===----------------------------------------------------------------------===//
754 
755 OpFoldResult arith::DivFOp::fold(ArrayRef<Attribute> operands) {
756   // divf(x, 1) -> x
757   if (matchPattern(getRhs(), m_OneFloat()))
758     return getLhs();
759 
760   return constFoldBinaryOp<FloatAttr>(
761       operands, [](const APFloat &a, const APFloat &b) { return a / b; });
762 }
763 
764 void arith::DivFOp::getCanonicalizationPatterns(RewritePatternSet &patterns,
765                                                 MLIRContext *context) {
766   patterns.add<DivFOfNegF>(context);
767 }
768 
769 //===----------------------------------------------------------------------===//
770 // RemFOp
771 //===----------------------------------------------------------------------===//
772 
773 OpFoldResult arith::RemFOp::fold(ArrayRef<Attribute> operands) {
774   return constFoldBinaryOp<FloatAttr>(operands,
775                                       [](const APFloat &a, const APFloat &b) {
776                                         APFloat result(a);
777                                         (void)result.remainder(b);
778                                         return result;
779                                       });
780 }
781 
782 //===----------------------------------------------------------------------===//
783 // Utility functions for verifying cast ops
784 //===----------------------------------------------------------------------===//
785 
/// Pointer-to-tuple alias used as a lightweight tag: a null value of this type
/// carries a pack of types through a function argument for deduction.
template <typename... Ts>
using type_list = std::tuple<Ts...> *;
788 
789 /// Returns a non-null type only if the provided type is one of the allowed
790 /// types or one of the allowed shaped types of the allowed types. Returns the
791 /// element type if a valid shaped type is provided.
792 template <typename... ShapedTypes, typename... ElementTypes>
793 static Type getUnderlyingType(Type type, type_list<ShapedTypes...>,
794                               type_list<ElementTypes...>) {
795   if (type.isa<ShapedType>() && !type.isa<ShapedTypes...>())
796     return {};
797 
798   auto underlyingType = getElementTypeOrSelf(type);
799   if (!underlyingType.isa<ElementTypes...>())
800     return {};
801 
802   return underlyingType;
803 }
804 
805 /// Get allowed underlying types for vectors and tensors.
806 template <typename... ElementTypes>
807 static Type getTypeIfLike(Type type) {
808   return getUnderlyingType(type, type_list<VectorType, TensorType>(),
809                            type_list<ElementTypes...>());
810 }
811 
812 /// Get allowed underlying types for vectors, tensors, and memrefs.
813 template <typename... ElementTypes>
814 static Type getTypeIfLikeOrMemRef(Type type) {
815   return getUnderlyingType(type,
816                            type_list<VectorType, TensorType, MemRefType>(),
817                            type_list<ElementTypes...>());
818 }
819 
820 static bool areValidCastInputsAndOutputs(TypeRange inputs, TypeRange outputs) {
821   return inputs.size() == 1 && outputs.size() == 1 &&
822          succeeded(verifyCompatibleShapes(inputs.front(), outputs.front()));
823 }
824 
825 //===----------------------------------------------------------------------===//
826 // Verifiers for integer and floating point extension/truncation ops
827 //===----------------------------------------------------------------------===//
828 
829 // Extend ops can only extend to a wider type.
830 template <typename ValType, typename Op>
831 static LogicalResult verifyExtOp(Op op) {
832   Type srcType = getElementTypeOrSelf(op.getIn().getType());
833   Type dstType = getElementTypeOrSelf(op.getType());
834 
835   if (srcType.cast<ValType>().getWidth() >= dstType.cast<ValType>().getWidth())
836     return op.emitError("result type ")
837            << dstType << " must be wider than operand type " << srcType;
838 
839   return success();
840 }
841 
842 // Truncate ops can only truncate to a shorter type.
843 template <typename ValType, typename Op>
844 static LogicalResult verifyTruncateOp(Op op) {
845   Type srcType = getElementTypeOrSelf(op.getIn().getType());
846   Type dstType = getElementTypeOrSelf(op.getType());
847 
848   if (srcType.cast<ValType>().getWidth() <= dstType.cast<ValType>().getWidth())
849     return op.emitError("result type ")
850            << dstType << " must be shorter than operand type " << srcType;
851 
852   return success();
853 }
854 
855 /// Validate a cast that changes the width of a type.
856 template <template <typename> class WidthComparator, typename... ElementTypes>
857 static bool checkWidthChangeCast(TypeRange inputs, TypeRange outputs) {
858   if (!areValidCastInputsAndOutputs(inputs, outputs))
859     return false;
860 
861   auto srcType = getTypeIfLike<ElementTypes...>(inputs.front());
862   auto dstType = getTypeIfLike<ElementTypes...>(outputs.front());
863   if (!srcType || !dstType)
864     return false;
865 
866   return WidthComparator<unsigned>()(dstType.getIntOrFloatBitWidth(),
867                                      srcType.getIntOrFloatBitWidth());
868 }
869 
870 //===----------------------------------------------------------------------===//
871 // ExtUIOp
872 //===----------------------------------------------------------------------===//
873 
874 OpFoldResult arith::ExtUIOp::fold(ArrayRef<Attribute> operands) {
875   if (auto lhs = getIn().getDefiningOp<ExtUIOp>()) {
876     getInMutable().assign(lhs.getIn());
877     return getResult();
878   }
879   Type resType = getType();
880   unsigned bitWidth;
881   if (auto shapedType = resType.dyn_cast<ShapedType>())
882     bitWidth = shapedType.getElementTypeBitWidth();
883   else
884     bitWidth = resType.getIntOrFloatBitWidth();
885   return constFoldCastOp<IntegerAttr, IntegerAttr>(
886       operands, getType(), [bitWidth](const APInt &a, bool &castStatus) {
887         return a.zext(bitWidth);
888       });
889 }
890 
891 bool arith::ExtUIOp::areCastCompatible(TypeRange inputs, TypeRange outputs) {
892   return checkWidthChangeCast<std::greater, IntegerType>(inputs, outputs);
893 }
894 
895 LogicalResult arith::ExtUIOp::verify() {
896   return verifyExtOp<IntegerType>(*this);
897 }
898 
899 //===----------------------------------------------------------------------===//
900 // ExtSIOp
901 //===----------------------------------------------------------------------===//
902 
903 OpFoldResult arith::ExtSIOp::fold(ArrayRef<Attribute> operands) {
904   if (auto lhs = getIn().getDefiningOp<ExtSIOp>()) {
905     getInMutable().assign(lhs.getIn());
906     return getResult();
907   }
908   Type resType = getType();
909   unsigned bitWidth;
910   if (auto shapedType = resType.dyn_cast<ShapedType>())
911     bitWidth = shapedType.getElementTypeBitWidth();
912   else
913     bitWidth = resType.getIntOrFloatBitWidth();
914   return constFoldCastOp<IntegerAttr, IntegerAttr>(
915       operands, getType(), [bitWidth](const APInt &a, bool &castStatus) {
916         return a.sext(bitWidth);
917       });
918 }
919 
920 bool arith::ExtSIOp::areCastCompatible(TypeRange inputs, TypeRange outputs) {
921   return checkWidthChangeCast<std::greater, IntegerType>(inputs, outputs);
922 }
923 
924 void arith::ExtSIOp::getCanonicalizationPatterns(RewritePatternSet &patterns,
925                                                  MLIRContext *context) {
926   patterns.add<ExtSIOfExtUI>(context);
927 }
928 
929 LogicalResult arith::ExtSIOp::verify() {
930   return verifyExtOp<IntegerType>(*this);
931 }
932 
933 //===----------------------------------------------------------------------===//
934 // ExtFOp
935 //===----------------------------------------------------------------------===//
936 
937 bool arith::ExtFOp::areCastCompatible(TypeRange inputs, TypeRange outputs) {
938   return checkWidthChangeCast<std::greater, FloatType>(inputs, outputs);
939 }
940 
941 LogicalResult arith::ExtFOp::verify() { return verifyExtOp<FloatType>(*this); }
942 
943 //===----------------------------------------------------------------------===//
944 // TruncIOp
945 //===----------------------------------------------------------------------===//
946 
947 OpFoldResult arith::TruncIOp::fold(ArrayRef<Attribute> operands) {
948   assert(operands.size() == 1 && "unary operation takes one operand");
949 
950   // trunci(zexti(a)) -> a
951   // trunci(sexti(a)) -> a
952   if (matchPattern(getOperand(), m_Op<arith::ExtUIOp>()) ||
953       matchPattern(getOperand(), m_Op<arith::ExtSIOp>()))
954     return getOperand().getDefiningOp()->getOperand(0);
955 
956   // trunci(trunci(a)) -> trunci(a))
957   if (matchPattern(getOperand(), m_Op<arith::TruncIOp>())) {
958     setOperand(getOperand().getDefiningOp()->getOperand(0));
959     return getResult();
960   }
961 
962   Type resType = getType();
963   unsigned bitWidth;
964   if (auto shapedType = resType.dyn_cast<ShapedType>())
965     bitWidth = shapedType.getElementTypeBitWidth();
966   else
967     bitWidth = resType.getIntOrFloatBitWidth();
968 
969   return constFoldCastOp<IntegerAttr, IntegerAttr>(
970       operands, getType(), [bitWidth](const APInt &a, bool &castStatus) {
971         return a.trunc(bitWidth);
972       });
973 }
974 
975 bool arith::TruncIOp::areCastCompatible(TypeRange inputs, TypeRange outputs) {
976   return checkWidthChangeCast<std::less, IntegerType>(inputs, outputs);
977 }
978 
979 LogicalResult arith::TruncIOp::verify() {
980   return verifyTruncateOp<IntegerType>(*this);
981 }
982 
983 //===----------------------------------------------------------------------===//
984 // TruncFOp
985 //===----------------------------------------------------------------------===//
986 
987 /// Perform safe const propagation for truncf, i.e. only propagate if FP value
988 /// can be represented without precision loss or rounding.
989 OpFoldResult arith::TruncFOp::fold(ArrayRef<Attribute> operands) {
990   assert(operands.size() == 1 && "unary operation takes one operand");
991 
992   auto constOperand = operands.front();
993   if (!constOperand || !constOperand.isa<FloatAttr>())
994     return {};
995 
996   // Convert to target type via 'double'.
997   double sourceValue =
998       constOperand.dyn_cast<FloatAttr>().getValue().convertToDouble();
999   auto targetAttr = FloatAttr::get(getType(), sourceValue);
1000 
1001   // Propagate if constant's value does not change after truncation.
1002   if (sourceValue == targetAttr.getValue().convertToDouble())
1003     return targetAttr;
1004 
1005   return {};
1006 }
1007 
1008 bool arith::TruncFOp::areCastCompatible(TypeRange inputs, TypeRange outputs) {
1009   return checkWidthChangeCast<std::less, FloatType>(inputs, outputs);
1010 }
1011 
1012 LogicalResult arith::TruncFOp::verify() {
1013   return verifyTruncateOp<FloatType>(*this);
1014 }
1015 
1016 //===----------------------------------------------------------------------===//
1017 // AndIOp
1018 //===----------------------------------------------------------------------===//
1019 
1020 void arith::AndIOp::getCanonicalizationPatterns(RewritePatternSet &patterns,
1021                                                 MLIRContext *context) {
1022   patterns.add<AndOfExtUI, AndOfExtSI>(context);
1023 }
1024 
1025 //===----------------------------------------------------------------------===//
1026 // OrIOp
1027 //===----------------------------------------------------------------------===//
1028 
1029 void arith::OrIOp::getCanonicalizationPatterns(RewritePatternSet &patterns,
1030                                                MLIRContext *context) {
1031   patterns.add<OrOfExtUI, OrOfExtSI>(context);
1032 }
1033 
1034 //===----------------------------------------------------------------------===//
1035 // Verifiers for casts between integers and floats.
1036 //===----------------------------------------------------------------------===//
1037 
1038 template <typename From, typename To>
1039 static bool checkIntFloatCast(TypeRange inputs, TypeRange outputs) {
1040   if (!areValidCastInputsAndOutputs(inputs, outputs))
1041     return false;
1042 
1043   auto srcType = getTypeIfLike<From>(inputs.front());
1044   auto dstType = getTypeIfLike<To>(outputs.back());
1045 
1046   return srcType && dstType;
1047 }
1048 
1049 //===----------------------------------------------------------------------===//
1050 // UIToFPOp
1051 //===----------------------------------------------------------------------===//
1052 
1053 bool arith::UIToFPOp::areCastCompatible(TypeRange inputs, TypeRange outputs) {
1054   return checkIntFloatCast<IntegerType, FloatType>(inputs, outputs);
1055 }
1056 
1057 OpFoldResult arith::UIToFPOp::fold(ArrayRef<Attribute> operands) {
1058   Type resType = getType();
1059   Type resEleType;
1060   if (auto shapedType = resType.dyn_cast<ShapedType>())
1061     resEleType = shapedType.getElementType();
1062   else
1063     resEleType = resType;
1064   return constFoldCastOp<IntegerAttr, FloatAttr>(
1065       operands, getType(), [&resEleType](const APInt &a, bool &castStatus) {
1066         FloatType floatTy = resEleType.cast<FloatType>();
1067         APFloat apf(floatTy.getFloatSemantics(),
1068                     APInt::getZero(floatTy.getWidth()));
1069         apf.convertFromAPInt(a, /*IsSigned=*/false,
1070                              APFloat::rmNearestTiesToEven);
1071         return apf;
1072       });
1073 }
1074 
1075 //===----------------------------------------------------------------------===//
1076 // SIToFPOp
1077 //===----------------------------------------------------------------------===//
1078 
1079 bool arith::SIToFPOp::areCastCompatible(TypeRange inputs, TypeRange outputs) {
1080   return checkIntFloatCast<IntegerType, FloatType>(inputs, outputs);
1081 }
1082 
1083 OpFoldResult arith::SIToFPOp::fold(ArrayRef<Attribute> operands) {
1084   Type resType = getType();
1085   Type resEleType;
1086   if (auto shapedType = resType.dyn_cast<ShapedType>())
1087     resEleType = shapedType.getElementType();
1088   else
1089     resEleType = resType;
1090   return constFoldCastOp<IntegerAttr, FloatAttr>(
1091       operands, getType(), [&resEleType](const APInt &a, bool &castStatus) {
1092         FloatType floatTy = resEleType.cast<FloatType>();
1093         APFloat apf(floatTy.getFloatSemantics(),
1094                     APInt::getZero(floatTy.getWidth()));
1095         apf.convertFromAPInt(a, /*IsSigned=*/true,
1096                              APFloat::rmNearestTiesToEven);
1097         return apf;
1098       });
1099 }
1100 //===----------------------------------------------------------------------===//
1101 // FPToUIOp
1102 //===----------------------------------------------------------------------===//
1103 
1104 bool arith::FPToUIOp::areCastCompatible(TypeRange inputs, TypeRange outputs) {
1105   return checkIntFloatCast<FloatType, IntegerType>(inputs, outputs);
1106 }
1107 
1108 OpFoldResult arith::FPToUIOp::fold(ArrayRef<Attribute> operands) {
1109   Type resType = getType();
1110   Type resEleType;
1111   if (auto shapedType = resType.dyn_cast<ShapedType>())
1112     resEleType = shapedType.getElementType();
1113   else
1114     resEleType = resType;
1115   return constFoldCastOp<FloatAttr, IntegerAttr>(
1116       operands, getType(), [&resEleType](const APFloat &a, bool &castStatus) {
1117         IntegerType intTy = resEleType.cast<IntegerType>();
1118         bool ignored;
1119         APSInt api(intTy.getWidth(), /*isUnsigned=*/true);
1120         castStatus = APFloat::opInvalidOp !=
1121                      a.convertToInteger(api, APFloat::rmTowardZero, &ignored);
1122         return api;
1123       });
1124 }
1125 
1126 //===----------------------------------------------------------------------===//
1127 // FPToSIOp
1128 //===----------------------------------------------------------------------===//
1129 
1130 bool arith::FPToSIOp::areCastCompatible(TypeRange inputs, TypeRange outputs) {
1131   return checkIntFloatCast<FloatType, IntegerType>(inputs, outputs);
1132 }
1133 
1134 OpFoldResult arith::FPToSIOp::fold(ArrayRef<Attribute> operands) {
1135   Type resType = getType();
1136   Type resEleType;
1137   if (auto shapedType = resType.dyn_cast<ShapedType>())
1138     resEleType = shapedType.getElementType();
1139   else
1140     resEleType = resType;
1141   return constFoldCastOp<FloatAttr, IntegerAttr>(
1142       operands, getType(), [&resEleType](const APFloat &a, bool &castStatus) {
1143         IntegerType intTy = resEleType.cast<IntegerType>();
1144         bool ignored;
1145         APSInt api(intTy.getWidth(), /*isUnsigned=*/false);
1146         castStatus = APFloat::opInvalidOp !=
1147                      a.convertToInteger(api, APFloat::rmTowardZero, &ignored);
1148         return api;
1149       });
1150 }
1151 
1152 //===----------------------------------------------------------------------===//
1153 // IndexCastOp
1154 //===----------------------------------------------------------------------===//
1155 
1156 bool arith::IndexCastOp::areCastCompatible(TypeRange inputs,
1157                                            TypeRange outputs) {
1158   if (!areValidCastInputsAndOutputs(inputs, outputs))
1159     return false;
1160 
1161   auto srcType = getTypeIfLikeOrMemRef<IntegerType, IndexType>(inputs.front());
1162   auto dstType = getTypeIfLikeOrMemRef<IntegerType, IndexType>(outputs.front());
1163   if (!srcType || !dstType)
1164     return false;
1165 
1166   return (srcType.isIndex() && dstType.isSignlessInteger()) ||
1167          (srcType.isSignlessInteger() && dstType.isIndex());
1168 }
1169 
1170 OpFoldResult arith::IndexCastOp::fold(ArrayRef<Attribute> operands) {
1171   // index_cast(constant) -> constant
1172   // A little hack because we go through int. Otherwise, the size of the
1173   // constant might need to change.
1174   if (auto value = operands[0].dyn_cast_or_null<IntegerAttr>())
1175     return IntegerAttr::get(getType(), value.getInt());
1176 
1177   return {};
1178 }
1179 
1180 void arith::IndexCastOp::getCanonicalizationPatterns(
1181     RewritePatternSet &patterns, MLIRContext *context) {
1182   patterns.add<IndexCastOfIndexCast, IndexCastOfExtSI>(context);
1183 }
1184 
1185 //===----------------------------------------------------------------------===//
1186 // BitcastOp
1187 //===----------------------------------------------------------------------===//
1188 
1189 bool arith::BitcastOp::areCastCompatible(TypeRange inputs, TypeRange outputs) {
1190   if (!areValidCastInputsAndOutputs(inputs, outputs))
1191     return false;
1192 
1193   auto srcType =
1194       getTypeIfLikeOrMemRef<IntegerType, IndexType, FloatType>(inputs.front());
1195   auto dstType =
1196       getTypeIfLikeOrMemRef<IntegerType, IndexType, FloatType>(outputs.front());
1197   if (!srcType || !dstType)
1198     return false;
1199 
1200   return srcType.getIntOrFloatBitWidth() == dstType.getIntOrFloatBitWidth();
1201 }
1202 
1203 OpFoldResult arith::BitcastOp::fold(ArrayRef<Attribute> operands) {
1204   assert(operands.size() == 1 && "bitcast op expects 1 operand");
1205 
1206   auto resType = getType();
1207   auto operand = operands[0];
1208   if (!operand)
1209     return {};
1210 
1211   /// Bitcast dense elements.
1212   if (auto denseAttr = operand.dyn_cast_or_null<DenseElementsAttr>())
1213     return denseAttr.bitcast(resType.cast<ShapedType>().getElementType());
1214   /// Other shaped types unhandled.
1215   if (resType.isa<ShapedType>())
1216     return {};
1217 
1218   /// Bitcast integer or float to integer or float.
1219   APInt bits = operand.isa<FloatAttr>()
1220                    ? operand.cast<FloatAttr>().getValue().bitcastToAPInt()
1221                    : operand.cast<IntegerAttr>().getValue();
1222 
1223   if (auto resFloatType = resType.dyn_cast<FloatType>())
1224     return FloatAttr::get(resType,
1225                           APFloat(resFloatType.getFloatSemantics(), bits));
1226   return IntegerAttr::get(resType, bits);
1227 }
1228 
1229 void arith::BitcastOp::getCanonicalizationPatterns(RewritePatternSet &patterns,
1230                                                    MLIRContext *context) {
1231   patterns.add<BitcastOfBitcast>(context);
1232 }
1233 
1234 //===----------------------------------------------------------------------===//
1235 // Helpers for compare ops
1236 //===----------------------------------------------------------------------===//
1237 
1238 /// Return the type of the same shape (scalar, vector or tensor) containing i1.
1239 static Type getI1SameShape(Type type) {
1240   auto i1Type = IntegerType::get(type.getContext(), 1);
1241   if (auto tensorType = type.dyn_cast<RankedTensorType>())
1242     return RankedTensorType::get(tensorType.getShape(), i1Type);
1243   if (type.isa<UnrankedTensorType>())
1244     return UnrankedTensorType::get(i1Type);
1245   if (auto vectorType = type.dyn_cast<VectorType>())
1246     return VectorType::get(vectorType.getShape(), i1Type,
1247                            vectorType.getNumScalableDims());
1248   return i1Type;
1249 }
1250 
1251 //===----------------------------------------------------------------------===//
1252 // CmpIOp
1253 //===----------------------------------------------------------------------===//
1254 
1255 /// Compute `lhs` `pred` `rhs`, where `pred` is one of the known integer
1256 /// comparison predicates.
1257 bool mlir::arith::applyCmpPredicate(arith::CmpIPredicate predicate,
1258                                     const APInt &lhs, const APInt &rhs) {
1259   switch (predicate) {
1260   case arith::CmpIPredicate::eq:
1261     return lhs.eq(rhs);
1262   case arith::CmpIPredicate::ne:
1263     return lhs.ne(rhs);
1264   case arith::CmpIPredicate::slt:
1265     return lhs.slt(rhs);
1266   case arith::CmpIPredicate::sle:
1267     return lhs.sle(rhs);
1268   case arith::CmpIPredicate::sgt:
1269     return lhs.sgt(rhs);
1270   case arith::CmpIPredicate::sge:
1271     return lhs.sge(rhs);
1272   case arith::CmpIPredicate::ult:
1273     return lhs.ult(rhs);
1274   case arith::CmpIPredicate::ule:
1275     return lhs.ule(rhs);
1276   case arith::CmpIPredicate::ugt:
1277     return lhs.ugt(rhs);
1278   case arith::CmpIPredicate::uge:
1279     return lhs.uge(rhs);
1280   }
1281   llvm_unreachable("unknown cmpi predicate kind");
1282 }
1283 
1284 /// Returns true if the predicate is true for two equal operands.
1285 static bool applyCmpPredicateToEqualOperands(arith::CmpIPredicate predicate) {
1286   switch (predicate) {
1287   case arith::CmpIPredicate::eq:
1288   case arith::CmpIPredicate::sle:
1289   case arith::CmpIPredicate::sge:
1290   case arith::CmpIPredicate::ule:
1291   case arith::CmpIPredicate::uge:
1292     return true;
1293   case arith::CmpIPredicate::ne:
1294   case arith::CmpIPredicate::slt:
1295   case arith::CmpIPredicate::sgt:
1296   case arith::CmpIPredicate::ult:
1297   case arith::CmpIPredicate::ugt:
1298     return false;
1299   }
1300   llvm_unreachable("unknown cmpi predicate kind");
1301 }
1302 
1303 static Attribute getBoolAttribute(Type type, MLIRContext *ctx, bool value) {
1304   auto boolAttr = BoolAttr::get(ctx, value);
1305   ShapedType shapedType = type.dyn_cast_or_null<ShapedType>();
1306   if (!shapedType)
1307     return boolAttr;
1308   return DenseElementsAttr::get(shapedType, boolAttr);
1309 }
1310 
1311 OpFoldResult arith::CmpIOp::fold(ArrayRef<Attribute> operands) {
1312   assert(operands.size() == 2 && "cmpi takes two operands");
1313 
1314   // cmpi(pred, x, x)
1315   if (getLhs() == getRhs()) {
1316     auto val = applyCmpPredicateToEqualOperands(getPredicate());
1317     return getBoolAttribute(getType(), getContext(), val);
1318   }
1319 
1320   if (matchPattern(getRhs(), m_Zero())) {
1321     if (auto extOp = getLhs().getDefiningOp<ExtSIOp>()) {
1322       // extsi(%x : i1 -> iN) != 0  ->  %x
1323       if (extOp.getOperand().getType().cast<IntegerType>().getWidth() == 1 &&
1324           getPredicate() == arith::CmpIPredicate::ne)
1325         return extOp.getOperand();
1326     }
1327     if (auto extOp = getLhs().getDefiningOp<ExtUIOp>()) {
1328       // extui(%x : i1 -> iN) != 0  ->  %x
1329       if (extOp.getOperand().getType().cast<IntegerType>().getWidth() == 1 &&
1330           getPredicate() == arith::CmpIPredicate::ne)
1331         return extOp.getOperand();
1332     }
1333   }
1334 
1335   // Move constant to the right side.
1336   if (operands[0] && !operands[1]) {
1337     // Do not use invertPredicate, as it will change eq to ne and vice versa.
1338     using Pred = CmpIPredicate;
1339     const std::pair<Pred, Pred> invPreds[] = {
1340         {Pred::slt, Pred::sgt}, {Pred::sgt, Pred::slt}, {Pred::sle, Pred::sge},
1341         {Pred::sge, Pred::sle}, {Pred::ult, Pred::ugt}, {Pred::ugt, Pred::ult},
1342         {Pred::ule, Pred::uge}, {Pred::uge, Pred::ule}, {Pred::eq, Pred::eq},
1343         {Pred::ne, Pred::ne},
1344     };
1345     Pred origPred = getPredicate();
1346     for (auto pred : invPreds) {
1347       if (origPred == pred.first) {
1348         setPredicateAttr(CmpIPredicateAttr::get(getContext(), pred.second));
1349         Value lhs = getLhs();
1350         Value rhs = getRhs();
1351         getLhsMutable().assign(rhs);
1352         getRhsMutable().assign(lhs);
1353         return getResult();
1354       }
1355     }
1356     llvm_unreachable("unknown cmpi predicate kind");
1357   }
1358 
1359   auto lhs = operands.front().dyn_cast_or_null<IntegerAttr>();
1360   if (!lhs)
1361     return {};
1362 
1363   // We are moving constants to the right side; So if lhs is constant rhs is
1364   // guaranteed to be a constant.
1365   auto rhs = operands.back().cast<IntegerAttr>();
1366 
1367   auto val = applyCmpPredicate(getPredicate(), lhs.getValue(), rhs.getValue());
1368   return BoolAttr::get(getContext(), val);
1369 }
1370 
1371 void arith::CmpIOp::getCanonicalizationPatterns(RewritePatternSet &patterns,
1372                                                 MLIRContext *context) {
1373   patterns.insert<CmpIExtSI, CmpIExtUI>(context);
1374 }
1375 
1376 //===----------------------------------------------------------------------===//
1377 // CmpFOp
1378 //===----------------------------------------------------------------------===//
1379 
1380 /// Compute `lhs` `pred` `rhs`, where `pred` is one of the known floating point
1381 /// comparison predicates.
1382 bool mlir::arith::applyCmpPredicate(arith::CmpFPredicate predicate,
1383                                     const APFloat &lhs, const APFloat &rhs) {
1384   auto cmpResult = lhs.compare(rhs);
1385   switch (predicate) {
1386   case arith::CmpFPredicate::AlwaysFalse:
1387     return false;
1388   case arith::CmpFPredicate::OEQ:
1389     return cmpResult == APFloat::cmpEqual;
1390   case arith::CmpFPredicate::OGT:
1391     return cmpResult == APFloat::cmpGreaterThan;
1392   case arith::CmpFPredicate::OGE:
1393     return cmpResult == APFloat::cmpGreaterThan ||
1394            cmpResult == APFloat::cmpEqual;
1395   case arith::CmpFPredicate::OLT:
1396     return cmpResult == APFloat::cmpLessThan;
1397   case arith::CmpFPredicate::OLE:
1398     return cmpResult == APFloat::cmpLessThan || cmpResult == APFloat::cmpEqual;
1399   case arith::CmpFPredicate::ONE:
1400     return cmpResult != APFloat::cmpUnordered && cmpResult != APFloat::cmpEqual;
1401   case arith::CmpFPredicate::ORD:
1402     return cmpResult != APFloat::cmpUnordered;
1403   case arith::CmpFPredicate::UEQ:
1404     return cmpResult == APFloat::cmpUnordered || cmpResult == APFloat::cmpEqual;
1405   case arith::CmpFPredicate::UGT:
1406     return cmpResult == APFloat::cmpUnordered ||
1407            cmpResult == APFloat::cmpGreaterThan;
1408   case arith::CmpFPredicate::UGE:
1409     return cmpResult == APFloat::cmpUnordered ||
1410            cmpResult == APFloat::cmpGreaterThan ||
1411            cmpResult == APFloat::cmpEqual;
1412   case arith::CmpFPredicate::ULT:
1413     return cmpResult == APFloat::cmpUnordered ||
1414            cmpResult == APFloat::cmpLessThan;
1415   case arith::CmpFPredicate::ULE:
1416     return cmpResult == APFloat::cmpUnordered ||
1417            cmpResult == APFloat::cmpLessThan || cmpResult == APFloat::cmpEqual;
1418   case arith::CmpFPredicate::UNE:
1419     return cmpResult != APFloat::cmpEqual;
1420   case arith::CmpFPredicate::UNO:
1421     return cmpResult == APFloat::cmpUnordered;
1422   case arith::CmpFPredicate::AlwaysTrue:
1423     return true;
1424   }
1425   llvm_unreachable("unknown cmpf predicate kind");
1426 }
1427 
1428 OpFoldResult arith::CmpFOp::fold(ArrayRef<Attribute> operands) {
1429   assert(operands.size() == 2 && "cmpf takes two operands");
1430 
1431   auto lhs = operands.front().dyn_cast_or_null<FloatAttr>();
1432   auto rhs = operands.back().dyn_cast_or_null<FloatAttr>();
1433 
1434   // If one operand is NaN, making them both NaN does not change the result.
1435   if (lhs && lhs.getValue().isNaN())
1436     rhs = lhs;
1437   if (rhs && rhs.getValue().isNaN())
1438     lhs = rhs;
1439 
1440   if (!lhs || !rhs)
1441     return {};
1442 
1443   auto val = applyCmpPredicate(getPredicate(), lhs.getValue(), rhs.getValue());
1444   return BoolAttr::get(getContext(), val);
1445 }
1446 
1447 class CmpFIntToFPConst final : public OpRewritePattern<CmpFOp> {
1448 public:
1449   using OpRewritePattern<CmpFOp>::OpRewritePattern;
1450 
1451   static CmpIPredicate convertToIntegerPredicate(CmpFPredicate pred,
1452                                                  bool isUnsigned) {
1453     using namespace arith;
1454     switch (pred) {
1455     case CmpFPredicate::UEQ:
1456     case CmpFPredicate::OEQ:
1457       return CmpIPredicate::eq;
1458     case CmpFPredicate::UGT:
1459     case CmpFPredicate::OGT:
1460       return isUnsigned ? CmpIPredicate::ugt : CmpIPredicate::sgt;
1461     case CmpFPredicate::UGE:
1462     case CmpFPredicate::OGE:
1463       return isUnsigned ? CmpIPredicate::uge : CmpIPredicate::sge;
1464     case CmpFPredicate::ULT:
1465     case CmpFPredicate::OLT:
1466       return isUnsigned ? CmpIPredicate::ult : CmpIPredicate::slt;
1467     case CmpFPredicate::ULE:
1468     case CmpFPredicate::OLE:
1469       return isUnsigned ? CmpIPredicate::ule : CmpIPredicate::sle;
1470     case CmpFPredicate::UNE:
1471     case CmpFPredicate::ONE:
1472       return CmpIPredicate::ne;
1473     default:
1474       llvm_unreachable("Unexpected predicate!");
1475     }
1476   }
1477 
1478   LogicalResult matchAndRewrite(CmpFOp op,
1479                                 PatternRewriter &rewriter) const override {
1480     FloatAttr flt;
1481     if (!matchPattern(op.getRhs(), m_Constant(&flt)))
1482       return failure();
1483 
1484     const APFloat &rhs = flt.getValue();
1485 
1486     // Don't attempt to fold a nan.
1487     if (rhs.isNaN())
1488       return failure();
1489 
1490     // Get the width of the mantissa.  We don't want to hack on conversions that
1491     // might lose information from the integer, e.g. "i64 -> float"
1492     FloatType floatTy = op.getRhs().getType().cast<FloatType>();
1493     int mantissaWidth = floatTy.getFPMantissaWidth();
1494     if (mantissaWidth <= 0)
1495       return failure();
1496 
1497     bool isUnsigned;
1498     Value intVal;
1499 
1500     if (auto si = op.getLhs().getDefiningOp<SIToFPOp>()) {
1501       isUnsigned = false;
1502       intVal = si.getIn();
1503     } else if (auto ui = op.getLhs().getDefiningOp<UIToFPOp>()) {
1504       isUnsigned = true;
1505       intVal = ui.getIn();
1506     } else {
1507       return failure();
1508     }
1509 
1510     // Check to see that the input is converted from an integer type that is
1511     // small enough that preserves all bits.
1512     auto intTy = intVal.getType().cast<IntegerType>();
1513     auto intWidth = intTy.getWidth();
1514 
1515     // Number of bits representing values, as opposed to the sign
1516     auto valueBits = isUnsigned ? intWidth : (intWidth - 1);
1517 
1518     // Following test does NOT adjust intWidth downwards for signed inputs,
1519     // because the most negative value still requires all the mantissa bits
1520     // to distinguish it from one less than that value.
1521     if ((int)intWidth > mantissaWidth) {
1522       // Conversion would lose accuracy. Check if loss can impact comparison.
1523       int exponent = ilogb(rhs);
1524       if (exponent == APFloat::IEK_Inf) {
1525         int maxExponent = ilogb(APFloat::getLargest(rhs.getSemantics()));
1526         if (maxExponent < (int)valueBits) {
1527           // Conversion could create infinity.
1528           return failure();
1529         }
1530       } else {
1531         // Note that if rhs is zero or NaN, then Exp is negative
1532         // and first condition is trivially false.
1533         if (mantissaWidth <= exponent && exponent <= (int)valueBits) {
1534           // Conversion could affect comparison.
1535           return failure();
1536         }
1537       }
1538     }
1539 
1540     // Convert to equivalent cmpi predicate
1541     CmpIPredicate pred;
1542     switch (op.getPredicate()) {
1543     case CmpFPredicate::ORD:
1544       // Int to fp conversion doesn't create a nan (ord checks neither is a nan)
1545       rewriter.replaceOpWithNewOp<ConstantIntOp>(op, /*value=*/true,
1546                                                  /*width=*/1);
1547       return success();
1548     case CmpFPredicate::UNO:
1549       // Int to fp conversion doesn't create a nan (uno checks either is a nan)
1550       rewriter.replaceOpWithNewOp<ConstantIntOp>(op, /*value=*/false,
1551                                                  /*width=*/1);
1552       return success();
1553     default:
1554       pred = convertToIntegerPredicate(op.getPredicate(), isUnsigned);
1555       break;
1556     }
1557 
1558     if (!isUnsigned) {
1559       // If the rhs value is > SignedMax, fold the comparison.  This handles
1560       // +INF and large values.
1561       APFloat signedMax(rhs.getSemantics());
1562       signedMax.convertFromAPInt(APInt::getSignedMaxValue(intWidth), true,
1563                                  APFloat::rmNearestTiesToEven);
1564       if (signedMax < rhs) { // smax < 13123.0
1565         if (pred == CmpIPredicate::ne || pred == CmpIPredicate::slt ||
1566             pred == CmpIPredicate::sle)
1567           rewriter.replaceOpWithNewOp<ConstantIntOp>(op, /*value=*/true,
1568                                                      /*width=*/1);
1569         else
1570           rewriter.replaceOpWithNewOp<ConstantIntOp>(op, /*value=*/false,
1571                                                      /*width=*/1);
1572         return success();
1573       }
1574     } else {
1575       // If the rhs value is > UnsignedMax, fold the comparison. This handles
1576       // +INF and large values.
1577       APFloat unsignedMax(rhs.getSemantics());
1578       unsignedMax.convertFromAPInt(APInt::getMaxValue(intWidth), false,
1579                                    APFloat::rmNearestTiesToEven);
1580       if (unsignedMax < rhs) { // umax < 13123.0
1581         if (pred == CmpIPredicate::ne || pred == CmpIPredicate::ult ||
1582             pred == CmpIPredicate::ule)
1583           rewriter.replaceOpWithNewOp<ConstantIntOp>(op, /*value=*/true,
1584                                                      /*width=*/1);
1585         else
1586           rewriter.replaceOpWithNewOp<ConstantIntOp>(op, /*value=*/false,
1587                                                      /*width=*/1);
1588         return success();
1589       }
1590     }
1591 
1592     if (!isUnsigned) {
1593       // See if the rhs value is < SignedMin.
1594       APFloat signedMin(rhs.getSemantics());
1595       signedMin.convertFromAPInt(APInt::getSignedMinValue(intWidth), true,
1596                                  APFloat::rmNearestTiesToEven);
1597       if (signedMin > rhs) { // smin > 12312.0
1598         if (pred == CmpIPredicate::ne || pred == CmpIPredicate::sgt ||
1599             pred == CmpIPredicate::sge)
1600           rewriter.replaceOpWithNewOp<ConstantIntOp>(op, /*value=*/true,
1601                                                      /*width=*/1);
1602         else
1603           rewriter.replaceOpWithNewOp<ConstantIntOp>(op, /*value=*/false,
1604                                                      /*width=*/1);
1605         return success();
1606       }
1607     } else {
1608       // See if the rhs value is < UnsignedMin.
1609       APFloat unsignedMin(rhs.getSemantics());
1610       unsignedMin.convertFromAPInt(APInt::getMinValue(intWidth), false,
1611                                    APFloat::rmNearestTiesToEven);
1612       if (unsignedMin > rhs) { // umin > 12312.0
1613         if (pred == CmpIPredicate::ne || pred == CmpIPredicate::ugt ||
1614             pred == CmpIPredicate::uge)
1615           rewriter.replaceOpWithNewOp<ConstantIntOp>(op, /*value=*/true,
1616                                                      /*width=*/1);
1617         else
1618           rewriter.replaceOpWithNewOp<ConstantIntOp>(op, /*value=*/false,
1619                                                      /*width=*/1);
1620         return success();
1621       }
1622     }
1623 
1624     // Okay, now we know that the FP constant fits in the range [SMIN, SMAX] or
1625     // [0, UMAX], but it may still be fractional.  See if it is fractional by
1626     // casting the FP value to the integer value and back, checking for
1627     // equality. Don't do this for zero, because -0.0 is not fractional.
1628     bool ignored;
1629     APSInt rhsInt(intWidth, isUnsigned);
1630     if (APFloat::opInvalidOp ==
1631         rhs.convertToInteger(rhsInt, APFloat::rmTowardZero, &ignored)) {
1632       // Undefined behavior invoked - the destination type can't represent
1633       // the input constant.
1634       return failure();
1635     }
1636 
1637     if (!rhs.isZero()) {
1638       APFloat apf(floatTy.getFloatSemantics(),
1639                   APInt::getZero(floatTy.getWidth()));
1640       apf.convertFromAPInt(rhsInt, !isUnsigned, APFloat::rmNearestTiesToEven);
1641 
1642       bool equal = apf == rhs;
1643       if (!equal) {
1644         // If we had a comparison against a fractional value, we have to adjust
1645         // the compare predicate and sometimes the value.  rhsInt is rounded
1646         // towards zero at this point.
1647         switch (pred) {
1648         case CmpIPredicate::ne: // (float)int != 4.4   --> true
1649           rewriter.replaceOpWithNewOp<ConstantIntOp>(op, /*value=*/true,
1650                                                      /*width=*/1);
1651           return success();
1652         case CmpIPredicate::eq: // (float)int == 4.4   --> false
1653           rewriter.replaceOpWithNewOp<ConstantIntOp>(op, /*value=*/false,
1654                                                      /*width=*/1);
1655           return success();
1656         case CmpIPredicate::ule:
1657           // (float)int <= 4.4   --> int <= 4
1658           // (float)int <= -4.4  --> false
1659           if (rhs.isNegative()) {
1660             rewriter.replaceOpWithNewOp<ConstantIntOp>(op, /*value=*/false,
1661                                                        /*width=*/1);
1662             return success();
1663           }
1664           break;
1665         case CmpIPredicate::sle:
1666           // (float)int <= 4.4   --> int <= 4
1667           // (float)int <= -4.4  --> int < -4
1668           if (rhs.isNegative())
1669             pred = CmpIPredicate::slt;
1670           break;
1671         case CmpIPredicate::ult:
1672           // (float)int < -4.4   --> false
1673           // (float)int < 4.4    --> int <= 4
1674           if (rhs.isNegative()) {
1675             rewriter.replaceOpWithNewOp<ConstantIntOp>(op, /*value=*/false,
1676                                                        /*width=*/1);
1677             return success();
1678           }
1679           pred = CmpIPredicate::ule;
1680           break;
1681         case CmpIPredicate::slt:
1682           // (float)int < -4.4   --> int < -4
1683           // (float)int < 4.4    --> int <= 4
1684           if (!rhs.isNegative())
1685             pred = CmpIPredicate::sle;
1686           break;
1687         case CmpIPredicate::ugt:
1688           // (float)int > 4.4    --> int > 4
1689           // (float)int > -4.4   --> true
1690           if (rhs.isNegative()) {
1691             rewriter.replaceOpWithNewOp<ConstantIntOp>(op, /*value=*/true,
1692                                                        /*width=*/1);
1693             return success();
1694           }
1695           break;
1696         case CmpIPredicate::sgt:
1697           // (float)int > 4.4    --> int > 4
1698           // (float)int > -4.4   --> int >= -4
1699           if (rhs.isNegative())
1700             pred = CmpIPredicate::sge;
1701           break;
1702         case CmpIPredicate::uge:
1703           // (float)int >= -4.4   --> true
1704           // (float)int >= 4.4    --> int > 4
1705           if (rhs.isNegative()) {
1706             rewriter.replaceOpWithNewOp<ConstantIntOp>(op, /*value=*/true,
1707                                                        /*width=*/1);
1708             return success();
1709           }
1710           pred = CmpIPredicate::ugt;
1711           break;
1712         case CmpIPredicate::sge:
1713           // (float)int >= -4.4   --> int >= -4
1714           // (float)int >= 4.4    --> int > 4
1715           if (!rhs.isNegative())
1716             pred = CmpIPredicate::sgt;
1717           break;
1718         }
1719       }
1720     }
1721 
1722     // Lower this FP comparison into an appropriate integer version of the
1723     // comparison.
1724     rewriter.replaceOpWithNewOp<CmpIOp>(
1725         op, pred, intVal,
1726         rewriter.create<ConstantOp>(
1727             op.getLoc(), intVal.getType(),
1728             rewriter.getIntegerAttr(intVal.getType(), rhsInt)));
1729     return success();
1730   }
1731 };
1732 
1733 void arith::CmpFOp::getCanonicalizationPatterns(RewritePatternSet &patterns,
1734                                                 MLIRContext *context) {
1735   patterns.insert<CmpFIntToFPConst>(context);
1736 }
1737 
1738 //===----------------------------------------------------------------------===//
1739 // SelectOp
1740 //===----------------------------------------------------------------------===//
1741 
1742 // Transforms a select of a boolean to arithmetic operations
1743 //
1744 //  arith.select %arg, %x, %y : i1
1745 //
1746 //  becomes
1747 //
1748 //  and(%arg, %x) or and(!%arg, %y)
1749 struct SelectI1Simplify : public OpRewritePattern<arith::SelectOp> {
1750   using OpRewritePattern<arith::SelectOp>::OpRewritePattern;
1751 
1752   LogicalResult matchAndRewrite(arith::SelectOp op,
1753                                 PatternRewriter &rewriter) const override {
1754     if (!op.getType().isInteger(1))
1755       return failure();
1756 
1757     Value falseConstant =
1758         rewriter.create<arith::ConstantIntOp>(op.getLoc(), true, 1);
1759     Value notCondition = rewriter.create<arith::XOrIOp>(
1760         op.getLoc(), op.getCondition(), falseConstant);
1761 
1762     Value trueVal = rewriter.create<arith::AndIOp>(
1763         op.getLoc(), op.getCondition(), op.getTrueValue());
1764     Value falseVal = rewriter.create<arith::AndIOp>(op.getLoc(), notCondition,
1765                                                     op.getFalseValue());
1766     rewriter.replaceOpWithNewOp<arith::OrIOp>(op, trueVal, falseVal);
1767     return success();
1768   }
1769 };
1770 
1771 //  select %arg, %c1, %c0 => extui %arg
1772 struct SelectToExtUI : public OpRewritePattern<arith::SelectOp> {
1773   using OpRewritePattern<arith::SelectOp>::OpRewritePattern;
1774 
1775   LogicalResult matchAndRewrite(arith::SelectOp op,
1776                                 PatternRewriter &rewriter) const override {
1777     // Cannot extui i1 to i1, or i1 to f32
1778     if (!op.getType().isa<IntegerType>() || op.getType().isInteger(1))
1779       return failure();
1780 
1781     // select %x, c1, %c0 => extui %arg
1782     if (matchPattern(op.getTrueValue(), m_One()) &&
1783         matchPattern(op.getFalseValue(), m_Zero())) {
1784       rewriter.replaceOpWithNewOp<arith::ExtUIOp>(op, op.getType(),
1785                                                   op.getCondition());
1786       return success();
1787     }
1788 
1789     // select %x, c0, %c1 => extui (xor %arg, true)
1790     if (matchPattern(op.getTrueValue(), m_Zero()) &&
1791         matchPattern(op.getFalseValue(), m_One())) {
1792       rewriter.replaceOpWithNewOp<arith::ExtUIOp>(
1793           op, op.getType(),
1794           rewriter.create<arith::XOrIOp>(
1795               op.getLoc(), op.getCondition(),
1796               rewriter.create<arith::ConstantIntOp>(
1797                   op.getLoc(), 1, op.getCondition().getType())));
1798       return success();
1799     }
1800 
1801     return failure();
1802   }
1803 };
1804 
1805 void arith::SelectOp::getCanonicalizationPatterns(RewritePatternSet &results,
1806                                                   MLIRContext *context) {
1807   results.add<SelectI1Simplify, SelectToExtUI>(context);
1808 }
1809 
1810 OpFoldResult arith::SelectOp::fold(ArrayRef<Attribute> operands) {
1811   Value trueVal = getTrueValue();
1812   Value falseVal = getFalseValue();
1813   if (trueVal == falseVal)
1814     return trueVal;
1815 
1816   Value condition = getCondition();
1817 
1818   // select true, %0, %1 => %0
1819   if (matchPattern(condition, m_One()))
1820     return trueVal;
1821 
1822   // select false, %0, %1 => %1
1823   if (matchPattern(condition, m_Zero()))
1824     return falseVal;
1825 
1826   // select %x, true, false => %x
1827   if (getType().isInteger(1) && matchPattern(getTrueValue(), m_One()) &&
1828       matchPattern(getFalseValue(), m_Zero()))
1829     return condition;
1830 
1831   if (auto cmp = dyn_cast_or_null<arith::CmpIOp>(condition.getDefiningOp())) {
1832     auto pred = cmp.getPredicate();
1833     if (pred == arith::CmpIPredicate::eq || pred == arith::CmpIPredicate::ne) {
1834       auto cmpLhs = cmp.getLhs();
1835       auto cmpRhs = cmp.getRhs();
1836 
1837       // %0 = arith.cmpi eq, %arg0, %arg1
1838       // %1 = arith.select %0, %arg0, %arg1 => %arg1
1839 
1840       // %0 = arith.cmpi ne, %arg0, %arg1
1841       // %1 = arith.select %0, %arg0, %arg1 => %arg0
1842 
1843       if ((cmpLhs == trueVal && cmpRhs == falseVal) ||
1844           (cmpRhs == trueVal && cmpLhs == falseVal))
1845         return pred == arith::CmpIPredicate::ne ? trueVal : falseVal;
1846     }
1847   }
1848   return nullptr;
1849 }
1850 
1851 ParseResult SelectOp::parse(OpAsmParser &parser, OperationState &result) {
1852   Type conditionType, resultType;
1853   SmallVector<OpAsmParser::UnresolvedOperand, 3> operands;
1854   if (parser.parseOperandList(operands, /*requiredOperandCount=*/3) ||
1855       parser.parseOptionalAttrDict(result.attributes) ||
1856       parser.parseColonType(resultType))
1857     return failure();
1858 
1859   // Check for the explicit condition type if this is a masked tensor or vector.
1860   if (succeeded(parser.parseOptionalComma())) {
1861     conditionType = resultType;
1862     if (parser.parseType(resultType))
1863       return failure();
1864   } else {
1865     conditionType = parser.getBuilder().getI1Type();
1866   }
1867 
1868   result.addTypes(resultType);
1869   return parser.resolveOperands(operands,
1870                                 {conditionType, resultType, resultType},
1871                                 parser.getNameLoc(), result.operands);
1872 }
1873 
1874 void arith::SelectOp::print(OpAsmPrinter &p) {
1875   p << " " << getOperands();
1876   p.printOptionalAttrDict((*this)->getAttrs());
1877   p << " : ";
1878   if (ShapedType condType = getCondition().getType().dyn_cast<ShapedType>())
1879     p << condType << ", ";
1880   p << getType();
1881 }
1882 
1883 LogicalResult arith::SelectOp::verify() {
1884   Type conditionType = getCondition().getType();
1885   if (conditionType.isSignlessInteger(1))
1886     return success();
1887 
1888   // If the result type is a vector or tensor, the type can be a mask with the
1889   // same elements.
1890   Type resultType = getType();
1891   if (!resultType.isa<TensorType, VectorType>())
1892     return emitOpError() << "expected condition to be a signless i1, but got "
1893                          << conditionType;
1894   Type shapedConditionType = getI1SameShape(resultType);
1895   if (conditionType != shapedConditionType) {
1896     return emitOpError() << "expected condition type to have the same shape "
1897                             "as the result type, expected "
1898                          << shapedConditionType << ", but got "
1899                          << conditionType;
1900   }
1901   return success();
1902 }
1903 //===----------------------------------------------------------------------===//
1904 // ShLIOp
1905 //===----------------------------------------------------------------------===//
1906 
1907 OpFoldResult arith::ShLIOp::fold(ArrayRef<Attribute> operands) {
1908   // Don't fold if shifting more than the bit width.
1909   bool bounded = false;
1910   auto result = constFoldBinaryOp<IntegerAttr>(
1911       operands, [&](const APInt &a, const APInt &b) {
1912         bounded = b.ule(b.getBitWidth());
1913         return a.shl(b);
1914       });
1915   return bounded ? result : Attribute();
1916 }
1917 
1918 //===----------------------------------------------------------------------===//
1919 // ShRUIOp
1920 //===----------------------------------------------------------------------===//
1921 
1922 OpFoldResult arith::ShRUIOp::fold(ArrayRef<Attribute> operands) {
1923   // Don't fold if shifting more than the bit width.
1924   bool bounded = false;
1925   auto result = constFoldBinaryOp<IntegerAttr>(
1926       operands, [&](const APInt &a, const APInt &b) {
1927         bounded = b.ule(b.getBitWidth());
1928         return a.lshr(b);
1929       });
1930   return bounded ? result : Attribute();
1931 }
1932 
1933 //===----------------------------------------------------------------------===//
1934 // ShRSIOp
1935 //===----------------------------------------------------------------------===//
1936 
1937 OpFoldResult arith::ShRSIOp::fold(ArrayRef<Attribute> operands) {
1938   // Don't fold if shifting more than the bit width.
1939   bool bounded = false;
1940   auto result = constFoldBinaryOp<IntegerAttr>(
1941       operands, [&](const APInt &a, const APInt &b) {
1942         bounded = b.ule(b.getBitWidth());
1943         return a.ashr(b);
1944       });
1945   return bounded ? result : Attribute();
1946 }
1947 
1948 //===----------------------------------------------------------------------===//
1949 // Atomic Enum
1950 //===----------------------------------------------------------------------===//
1951 
1952 /// Returns the identity value attribute associated with an AtomicRMWKind op.
1953 Attribute mlir::arith::getIdentityValueAttr(AtomicRMWKind kind, Type resultType,
1954                                             OpBuilder &builder, Location loc) {
1955   switch (kind) {
1956   case AtomicRMWKind::maxf:
1957     return builder.getFloatAttr(
1958         resultType,
1959         APFloat::getInf(resultType.cast<FloatType>().getFloatSemantics(),
1960                         /*Negative=*/true));
1961   case AtomicRMWKind::addf:
1962   case AtomicRMWKind::addi:
1963   case AtomicRMWKind::maxu:
1964   case AtomicRMWKind::ori:
1965     return builder.getZeroAttr(resultType);
1966   case AtomicRMWKind::andi:
1967     return builder.getIntegerAttr(
1968         resultType,
1969         APInt::getAllOnes(resultType.cast<IntegerType>().getWidth()));
1970   case AtomicRMWKind::maxs:
1971     return builder.getIntegerAttr(
1972         resultType,
1973         APInt::getSignedMinValue(resultType.cast<IntegerType>().getWidth()));
1974   case AtomicRMWKind::minf:
1975     return builder.getFloatAttr(
1976         resultType,
1977         APFloat::getInf(resultType.cast<FloatType>().getFloatSemantics(),
1978                         /*Negative=*/false));
1979   case AtomicRMWKind::mins:
1980     return builder.getIntegerAttr(
1981         resultType,
1982         APInt::getSignedMaxValue(resultType.cast<IntegerType>().getWidth()));
1983   case AtomicRMWKind::minu:
1984     return builder.getIntegerAttr(
1985         resultType,
1986         APInt::getMaxValue(resultType.cast<IntegerType>().getWidth()));
1987   case AtomicRMWKind::muli:
1988     return builder.getIntegerAttr(resultType, 1);
1989   case AtomicRMWKind::mulf:
1990     return builder.getFloatAttr(resultType, 1);
1991   // TODO: Add remaining reduction operations.
1992   default:
1993     (void)emitOptionalError(loc, "Reduction operation type not supported");
1994     break;
1995   }
1996   return nullptr;
1997 }
1998 
1999 /// Returns the identity value associated with an AtomicRMWKind op.
2000 Value mlir::arith::getIdentityValue(AtomicRMWKind op, Type resultType,
2001                                     OpBuilder &builder, Location loc) {
2002   Attribute attr = getIdentityValueAttr(op, resultType, builder, loc);
2003   return builder.create<arith::ConstantOp>(loc, attr);
2004 }
2005 
2006 /// Return the value obtained by applying the reduction operation kind
2007 /// associated with a binary AtomicRMWKind op to `lhs` and `rhs`.
2008 Value mlir::arith::getReductionOp(AtomicRMWKind op, OpBuilder &builder,
2009                                   Location loc, Value lhs, Value rhs) {
2010   switch (op) {
2011   case AtomicRMWKind::addf:
2012     return builder.create<arith::AddFOp>(loc, lhs, rhs);
2013   case AtomicRMWKind::addi:
2014     return builder.create<arith::AddIOp>(loc, lhs, rhs);
2015   case AtomicRMWKind::mulf:
2016     return builder.create<arith::MulFOp>(loc, lhs, rhs);
2017   case AtomicRMWKind::muli:
2018     return builder.create<arith::MulIOp>(loc, lhs, rhs);
2019   case AtomicRMWKind::maxf:
2020     return builder.create<arith::MaxFOp>(loc, lhs, rhs);
2021   case AtomicRMWKind::minf:
2022     return builder.create<arith::MinFOp>(loc, lhs, rhs);
2023   case AtomicRMWKind::maxs:
2024     return builder.create<arith::MaxSIOp>(loc, lhs, rhs);
2025   case AtomicRMWKind::mins:
2026     return builder.create<arith::MinSIOp>(loc, lhs, rhs);
2027   case AtomicRMWKind::maxu:
2028     return builder.create<arith::MaxUIOp>(loc, lhs, rhs);
2029   case AtomicRMWKind::minu:
2030     return builder.create<arith::MinUIOp>(loc, lhs, rhs);
2031   case AtomicRMWKind::ori:
2032     return builder.create<arith::OrIOp>(loc, lhs, rhs);
2033   case AtomicRMWKind::andi:
2034     return builder.create<arith::AndIOp>(loc, lhs, rhs);
2035   // TODO: Add remaining reduction operations.
2036   default:
2037     (void)emitOptionalError(loc, "Reduction operation type not supported");
2038     break;
2039   }
2040   return nullptr;
2041 }
2042 
2043 //===----------------------------------------------------------------------===//
2044 // TableGen'd op method definitions
2045 //===----------------------------------------------------------------------===//
2046 
2047 #define GET_OP_CLASSES
2048 #include "mlir/Dialect/Arithmetic/IR/ArithmeticOps.cpp.inc"
2049 
2050 //===----------------------------------------------------------------------===//
2051 // TableGen'd enum attribute definitions
2052 //===----------------------------------------------------------------------===//
2053 
2054 #include "mlir/Dialect/Arithmetic/IR/ArithmeticOpsEnums.cpp.inc"
2055