1 //===- ArithmeticOps.cpp - MLIR Arithmetic dialect ops implementation -----===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 
9 #include <utility>
10 
11 #include "mlir/Dialect/Arithmetic/IR/Arithmetic.h"
12 #include "mlir/Dialect/CommonFolders.h"
13 #include "mlir/IR/Builders.h"
14 #include "mlir/IR/Matchers.h"
15 #include "mlir/IR/OpImplementation.h"
16 #include "mlir/IR/PatternMatch.h"
17 #include "mlir/IR/TypeUtilities.h"
18 #include "llvm/ADT/SmallString.h"
19 
20 #include "llvm/ADT/APSInt.h"
21 
22 using namespace mlir;
23 using namespace mlir::arith;
24 
25 //===----------------------------------------------------------------------===//
26 // Pattern helpers
27 //===----------------------------------------------------------------------===//
28 
29 static IntegerAttr addIntegerAttrs(PatternRewriter &builder, Value res,
30                                    Attribute lhs, Attribute rhs) {
31   return builder.getIntegerAttr(res.getType(),
32                                 lhs.cast<IntegerAttr>().getInt() +
33                                     rhs.cast<IntegerAttr>().getInt());
34 }
35 
36 static IntegerAttr subIntegerAttrs(PatternRewriter &builder, Value res,
37                                    Attribute lhs, Attribute rhs) {
38   return builder.getIntegerAttr(res.getType(),
39                                 lhs.cast<IntegerAttr>().getInt() -
40                                     rhs.cast<IntegerAttr>().getInt());
41 }
42 
43 /// Invert an integer comparison predicate.
44 arith::CmpIPredicate arith::invertPredicate(arith::CmpIPredicate pred) {
45   switch (pred) {
46   case arith::CmpIPredicate::eq:
47     return arith::CmpIPredicate::ne;
48   case arith::CmpIPredicate::ne:
49     return arith::CmpIPredicate::eq;
50   case arith::CmpIPredicate::slt:
51     return arith::CmpIPredicate::sge;
52   case arith::CmpIPredicate::sle:
53     return arith::CmpIPredicate::sgt;
54   case arith::CmpIPredicate::sgt:
55     return arith::CmpIPredicate::sle;
56   case arith::CmpIPredicate::sge:
57     return arith::CmpIPredicate::slt;
58   case arith::CmpIPredicate::ult:
59     return arith::CmpIPredicate::uge;
60   case arith::CmpIPredicate::ule:
61     return arith::CmpIPredicate::ugt;
62   case arith::CmpIPredicate::ugt:
63     return arith::CmpIPredicate::ule;
64   case arith::CmpIPredicate::uge:
65     return arith::CmpIPredicate::ult;
66   }
67   llvm_unreachable("unknown cmpi predicate kind");
68 }
69 
70 static arith::CmpIPredicateAttr invertPredicate(arith::CmpIPredicateAttr pred) {
71   return arith::CmpIPredicateAttr::get(pred.getContext(),
72                                        invertPredicate(pred.getValue()));
73 }
74 
75 //===----------------------------------------------------------------------===//
76 // TableGen'd canonicalization patterns
77 //===----------------------------------------------------------------------===//
78 
// Pull the TableGen-generated declarative rewrite patterns (e.g.
// AddIAddConstant, XOrINotCmpI, registered further below) into an anonymous
// namespace so they get internal linkage in this translation unit.
// NOTE(review): the generated code presumably calls the static helpers defined
// above (addIntegerAttrs, subIntegerAttrs, invertPredicate); if so, they must
// stay declared before this include.
namespace {
#include "ArithmeticCanonicalization.inc"
} // namespace
82 
83 //===----------------------------------------------------------------------===//
84 // ConstantOp
85 //===----------------------------------------------------------------------===//
86 
87 void arith::ConstantOp::getAsmResultNames(
88     function_ref<void(Value, StringRef)> setNameFn) {
89   auto type = getType();
90   if (auto intCst = getValue().dyn_cast<IntegerAttr>()) {
91     auto intType = type.dyn_cast<IntegerType>();
92 
93     // Sugar i1 constants with 'true' and 'false'.
94     if (intType && intType.getWidth() == 1)
95       return setNameFn(getResult(), (intCst.getInt() ? "true" : "false"));
96 
97     // Otherwise, build a compex name with the value and type.
98     SmallString<32> specialNameBuffer;
99     llvm::raw_svector_ostream specialName(specialNameBuffer);
100     specialName << 'c' << intCst.getInt();
101     if (intType)
102       specialName << '_' << type;
103     setNameFn(getResult(), specialName.str());
104   } else {
105     setNameFn(getResult(), "cst");
106   }
107 }
108 
109 /// TODO: disallow arith.constant to return anything other than signless integer
110 /// or float like.
111 LogicalResult arith::ConstantOp::verify() {
112   auto type = getType();
113   // The value's type must match the return type.
114   if (getValue().getType() != type) {
115     return emitOpError() << "value type " << getValue().getType()
116                          << " must match return type: " << type;
117   }
118   // Integer values must be signless.
119   if (type.isa<IntegerType>() && !type.cast<IntegerType>().isSignless())
120     return emitOpError("integer return type must be signless");
121   // Any float or elements attribute are acceptable.
122   if (!getValue().isa<IntegerAttr, FloatAttr, ElementsAttr>()) {
123     return emitOpError(
124         "value must be an integer, float, or elements attribute");
125   }
126   return success();
127 }
128 
129 bool arith::ConstantOp::isBuildableWith(Attribute value, Type type) {
130   // The value's type must be the same as the provided type.
131   if (value.getType() != type)
132     return false;
133   // Integer values must be signless.
134   if (type.isa<IntegerType>() && !type.cast<IntegerType>().isSignless())
135     return false;
136   // Integer, float, and element attributes are buildable.
137   return value.isa<IntegerAttr, FloatAttr, ElementsAttr>();
138 }
139 
140 OpFoldResult arith::ConstantOp::fold(ArrayRef<Attribute> operands) {
141   return getValue();
142 }
143 
144 void arith::ConstantIntOp::build(OpBuilder &builder, OperationState &result,
145                                  int64_t value, unsigned width) {
146   auto type = builder.getIntegerType(width);
147   arith::ConstantOp::build(builder, result, type,
148                            builder.getIntegerAttr(type, value));
149 }
150 
151 void arith::ConstantIntOp::build(OpBuilder &builder, OperationState &result,
152                                  int64_t value, Type type) {
153   assert(type.isSignlessInteger() &&
154          "ConstantIntOp can only have signless integer type values");
155   arith::ConstantOp::build(builder, result, type,
156                            builder.getIntegerAttr(type, value));
157 }
158 
159 bool arith::ConstantIntOp::classof(Operation *op) {
160   if (auto constOp = dyn_cast_or_null<arith::ConstantOp>(op))
161     return constOp.getType().isSignlessInteger();
162   return false;
163 }
164 
165 void arith::ConstantFloatOp::build(OpBuilder &builder, OperationState &result,
166                                    const APFloat &value, FloatType type) {
167   arith::ConstantOp::build(builder, result, type,
168                            builder.getFloatAttr(type, value));
169 }
170 
171 bool arith::ConstantFloatOp::classof(Operation *op) {
172   if (auto constOp = dyn_cast_or_null<arith::ConstantOp>(op))
173     return constOp.getType().isa<FloatType>();
174   return false;
175 }
176 
177 void arith::ConstantIndexOp::build(OpBuilder &builder, OperationState &result,
178                                    int64_t value) {
179   arith::ConstantOp::build(builder, result, builder.getIndexType(),
180                            builder.getIndexAttr(value));
181 }
182 
183 bool arith::ConstantIndexOp::classof(Operation *op) {
184   if (auto constOp = dyn_cast_or_null<arith::ConstantOp>(op))
185     return constOp.getType().isIndex();
186   return false;
187 }
188 
189 //===----------------------------------------------------------------------===//
190 // AddIOp
191 //===----------------------------------------------------------------------===//
192 
193 OpFoldResult arith::AddIOp::fold(ArrayRef<Attribute> operands) {
194   // addi(x, 0) -> x
195   if (matchPattern(getRhs(), m_Zero()))
196     return getLhs();
197 
198   // addi(subi(a, b), b) -> a
199   if (auto sub = getLhs().getDefiningOp<SubIOp>())
200     if (getRhs() == sub.getRhs())
201       return sub.getLhs();
202 
203   // addi(b, subi(a, b)) -> a
204   if (auto sub = getRhs().getDefiningOp<SubIOp>())
205     if (getLhs() == sub.getRhs())
206       return sub.getLhs();
207 
208   return constFoldBinaryOp<IntegerAttr>(
209       operands, [](APInt a, const APInt &b) { return std::move(a) + b; });
210 }
211 
212 void arith::AddIOp::getCanonicalizationPatterns(
213     RewritePatternSet &patterns, MLIRContext *context) {
214   patterns.add<AddIAddConstant, AddISubConstantRHS, AddISubConstantLHS>(
215       context);
216 }
217 
218 //===----------------------------------------------------------------------===//
219 // SubIOp
220 //===----------------------------------------------------------------------===//
221 
222 OpFoldResult arith::SubIOp::fold(ArrayRef<Attribute> operands) {
223   // subi(x,x) -> 0
224   if (getOperand(0) == getOperand(1))
225     return Builder(getContext()).getZeroAttr(getType());
226   // subi(x,0) -> x
227   if (matchPattern(getRhs(), m_Zero()))
228     return getLhs();
229 
230   return constFoldBinaryOp<IntegerAttr>(
231       operands, [](APInt a, const APInt &b) { return std::move(a) - b; });
232 }
233 
234 void arith::SubIOp::getCanonicalizationPatterns(
235     RewritePatternSet &patterns, MLIRContext *context) {
236   patterns
237       .add<SubIRHSAddConstant, SubILHSAddConstant, SubIRHSSubConstantRHS,
238            SubIRHSSubConstantLHS, SubILHSSubConstantRHS, SubILHSSubConstantLHS>(
239           context);
240 }
241 
242 //===----------------------------------------------------------------------===//
243 // MulIOp
244 //===----------------------------------------------------------------------===//
245 
246 OpFoldResult arith::MulIOp::fold(ArrayRef<Attribute> operands) {
247   // muli(x, 0) -> 0
248   if (matchPattern(getRhs(), m_Zero()))
249     return getRhs();
250   // muli(x, 1) -> x
251   if (matchPattern(getRhs(), m_One()))
252     return getOperand(0);
253   // TODO: Handle the overflow case.
254 
255   // default folder
256   return constFoldBinaryOp<IntegerAttr>(
257       operands, [](const APInt &a, const APInt &b) { return a * b; });
258 }
259 
260 //===----------------------------------------------------------------------===//
261 // DivUIOp
262 //===----------------------------------------------------------------------===//
263 
264 OpFoldResult arith::DivUIOp::fold(ArrayRef<Attribute> operands) {
265   // Don't fold if it would require a division by zero.
266   bool div0 = false;
267   auto result =
268       constFoldBinaryOp<IntegerAttr>(operands, [&](APInt a, const APInt &b) {
269         if (div0 || !b) {
270           div0 = true;
271           return a;
272         }
273         return a.udiv(b);
274       });
275 
276   // Fold out division by one. Assumes all tensors of all ones are splats.
277   if (auto rhs = operands[1].dyn_cast_or_null<IntegerAttr>()) {
278     if (rhs.getValue() == 1)
279       return getLhs();
280   } else if (auto rhs = operands[1].dyn_cast_or_null<SplatElementsAttr>()) {
281     if (rhs.getSplatValue<IntegerAttr>().getValue() == 1)
282       return getLhs();
283   }
284 
285   return div0 ? Attribute() : result;
286 }
287 
288 //===----------------------------------------------------------------------===//
289 // DivSIOp
290 //===----------------------------------------------------------------------===//
291 
292 OpFoldResult arith::DivSIOp::fold(ArrayRef<Attribute> operands) {
293   // Don't fold if it would overflow or if it requires a division by zero.
294   bool overflowOrDiv0 = false;
295   auto result =
296       constFoldBinaryOp<IntegerAttr>(operands, [&](APInt a, const APInt &b) {
297         if (overflowOrDiv0 || !b) {
298           overflowOrDiv0 = true;
299           return a;
300         }
301         return a.sdiv_ov(b, overflowOrDiv0);
302       });
303 
304   // Fold out division by one. Assumes all tensors of all ones are splats.
305   if (auto rhs = operands[1].dyn_cast_or_null<IntegerAttr>()) {
306     if (rhs.getValue() == 1)
307       return getLhs();
308   } else if (auto rhs = operands[1].dyn_cast_or_null<SplatElementsAttr>()) {
309     if (rhs.getSplatValue<IntegerAttr>().getValue() == 1)
310       return getLhs();
311   }
312 
313   return overflowOrDiv0 ? Attribute() : result;
314 }
315 
316 //===----------------------------------------------------------------------===//
317 // Ceil and floor division folding helpers
318 //===----------------------------------------------------------------------===//
319 
320 static APInt signedCeilNonnegInputs(const APInt &a, const APInt &b,
321                                     bool &overflow) {
322   // Returns (a-1)/b + 1
323   APInt one(a.getBitWidth(), 1, true); // Signed value 1.
324   APInt val = a.ssub_ov(one, overflow).sdiv_ov(b, overflow);
325   return val.sadd_ov(one, overflow);
326 }
327 
328 //===----------------------------------------------------------------------===//
329 // CeilDivUIOp
330 //===----------------------------------------------------------------------===//
331 
332 OpFoldResult arith::CeilDivUIOp::fold(ArrayRef<Attribute> operands) {
333   bool overflowOrDiv0 = false;
334   auto result =
335       constFoldBinaryOp<IntegerAttr>(operands, [&](APInt a, const APInt &b) {
336         if (overflowOrDiv0 || !b) {
337           overflowOrDiv0 = true;
338           return a;
339         }
340         APInt quotient = a.udiv(b);
341         if (!a.urem(b))
342           return quotient;
343         APInt one(a.getBitWidth(), 1, true);
344         return quotient.uadd_ov(one, overflowOrDiv0);
345       });
346   // Fold out ceil division by one. Assumes all tensors of all ones are
347   // splats.
348   if (auto rhs = operands[1].dyn_cast_or_null<IntegerAttr>()) {
349     if (rhs.getValue() == 1)
350       return getLhs();
351   } else if (auto rhs = operands[1].dyn_cast_or_null<SplatElementsAttr>()) {
352     if (rhs.getSplatValue<IntegerAttr>().getValue() == 1)
353       return getLhs();
354   }
355 
356   return overflowOrDiv0 ? Attribute() : result;
357 }
358 
359 //===----------------------------------------------------------------------===//
360 // CeilDivSIOp
361 //===----------------------------------------------------------------------===//
362 
363 OpFoldResult arith::CeilDivSIOp::fold(ArrayRef<Attribute> operands) {
364   // Don't fold if it would overflow or if it requires a division by zero.
365   bool overflowOrDiv0 = false;
366   auto result =
367       constFoldBinaryOp<IntegerAttr>(operands, [&](APInt a, const APInt &b) {
368         if (overflowOrDiv0 || !b) {
369           overflowOrDiv0 = true;
370           return a;
371         }
372         if (!a)
373           return a;
374         // After this point we know that neither a or b are zero.
375         unsigned bits = a.getBitWidth();
376         APInt zero = APInt::getZero(bits);
377         bool aGtZero = a.sgt(zero);
378         bool bGtZero = b.sgt(zero);
379         if (aGtZero && bGtZero) {
380           // Both positive, return ceil(a, b).
381           return signedCeilNonnegInputs(a, b, overflowOrDiv0);
382         }
383         if (!aGtZero && !bGtZero) {
384           // Both negative, return ceil(-a, -b).
385           APInt posA = zero.ssub_ov(a, overflowOrDiv0);
386           APInt posB = zero.ssub_ov(b, overflowOrDiv0);
387           return signedCeilNonnegInputs(posA, posB, overflowOrDiv0);
388         }
389         if (!aGtZero && bGtZero) {
390           // A is negative, b is positive, return - ( -a / b).
391           APInt posA = zero.ssub_ov(a, overflowOrDiv0);
392           APInt div = posA.sdiv_ov(b, overflowOrDiv0);
393           return zero.ssub_ov(div, overflowOrDiv0);
394         }
395         // A is positive, b is negative, return - (a / -b).
396         APInt posB = zero.ssub_ov(b, overflowOrDiv0);
397         APInt div = a.sdiv_ov(posB, overflowOrDiv0);
398         return zero.ssub_ov(div, overflowOrDiv0);
399       });
400 
401   // Fold out ceil division by one. Assumes all tensors of all ones are
402   // splats.
403   if (auto rhs = operands[1].dyn_cast_or_null<IntegerAttr>()) {
404     if (rhs.getValue() == 1)
405       return getLhs();
406   } else if (auto rhs = operands[1].dyn_cast_or_null<SplatElementsAttr>()) {
407     if (rhs.getSplatValue<IntegerAttr>().getValue() == 1)
408       return getLhs();
409   }
410 
411   return overflowOrDiv0 ? Attribute() : result;
412 }
413 
414 //===----------------------------------------------------------------------===//
415 // FloorDivSIOp
416 //===----------------------------------------------------------------------===//
417 
418 OpFoldResult arith::FloorDivSIOp::fold(ArrayRef<Attribute> operands) {
419   // Don't fold if it would overflow or if it requires a division by zero.
420   bool overflowOrDiv0 = false;
421   auto result =
422       constFoldBinaryOp<IntegerAttr>(operands, [&](APInt a, const APInt &b) {
423         if (overflowOrDiv0 || !b) {
424           overflowOrDiv0 = true;
425           return a;
426         }
427         if (!a)
428           return a;
429         // After this point we know that neither a or b are zero.
430         unsigned bits = a.getBitWidth();
431         APInt zero = APInt::getZero(bits);
432         bool aGtZero = a.sgt(zero);
433         bool bGtZero = b.sgt(zero);
434         if (aGtZero && bGtZero) {
435           // Both positive, return a / b.
436           return a.sdiv_ov(b, overflowOrDiv0);
437         }
438         if (!aGtZero && !bGtZero) {
439           // Both negative, return -a / -b.
440           APInt posA = zero.ssub_ov(a, overflowOrDiv0);
441           APInt posB = zero.ssub_ov(b, overflowOrDiv0);
442           return posA.sdiv_ov(posB, overflowOrDiv0);
443         }
444         if (!aGtZero && bGtZero) {
445           // A is negative, b is positive, return - ceil(-a, b).
446           APInt posA = zero.ssub_ov(a, overflowOrDiv0);
447           APInt ceil = signedCeilNonnegInputs(posA, b, overflowOrDiv0);
448           return zero.ssub_ov(ceil, overflowOrDiv0);
449         }
450         // A is positive, b is negative, return - ceil(a, -b).
451         APInt posB = zero.ssub_ov(b, overflowOrDiv0);
452         APInt ceil = signedCeilNonnegInputs(a, posB, overflowOrDiv0);
453         return zero.ssub_ov(ceil, overflowOrDiv0);
454       });
455 
456   // Fold out floor division by one. Assumes all tensors of all ones are
457   // splats.
458   if (auto rhs = operands[1].dyn_cast_or_null<IntegerAttr>()) {
459     if (rhs.getValue() == 1)
460       return getLhs();
461   } else if (auto rhs = operands[1].dyn_cast_or_null<SplatElementsAttr>()) {
462     if (rhs.getSplatValue<IntegerAttr>().getValue() == 1)
463       return getLhs();
464   }
465 
466   return overflowOrDiv0 ? Attribute() : result;
467 }
468 
469 //===----------------------------------------------------------------------===//
470 // RemUIOp
471 //===----------------------------------------------------------------------===//
472 
473 OpFoldResult arith::RemUIOp::fold(ArrayRef<Attribute> operands) {
474   auto rhs = operands.back().dyn_cast_or_null<IntegerAttr>();
475   if (!rhs)
476     return {};
477   auto rhsValue = rhs.getValue();
478 
479   // x % 1 = 0
480   if (rhsValue.isOneValue())
481     return IntegerAttr::get(rhs.getType(), APInt(rhsValue.getBitWidth(), 0));
482 
483   // Don't fold if it requires division by zero.
484   if (rhsValue.isNullValue())
485     return {};
486 
487   auto lhs = operands.front().dyn_cast_or_null<IntegerAttr>();
488   if (!lhs)
489     return {};
490   return IntegerAttr::get(lhs.getType(), lhs.getValue().urem(rhsValue));
491 }
492 
493 //===----------------------------------------------------------------------===//
494 // RemSIOp
495 //===----------------------------------------------------------------------===//
496 
497 OpFoldResult arith::RemSIOp::fold(ArrayRef<Attribute> operands) {
498   auto rhs = operands.back().dyn_cast_or_null<IntegerAttr>();
499   if (!rhs)
500     return {};
501   auto rhsValue = rhs.getValue();
502 
503   // x % 1 = 0
504   if (rhsValue.isOneValue())
505     return IntegerAttr::get(rhs.getType(), APInt(rhsValue.getBitWidth(), 0));
506 
507   // Don't fold if it requires division by zero.
508   if (rhsValue.isNullValue())
509     return {};
510 
511   auto lhs = operands.front().dyn_cast_or_null<IntegerAttr>();
512   if (!lhs)
513     return {};
514   return IntegerAttr::get(lhs.getType(), lhs.getValue().srem(rhsValue));
515 }
516 
517 //===----------------------------------------------------------------------===//
518 // AndIOp
519 //===----------------------------------------------------------------------===//
520 
521 OpFoldResult arith::AndIOp::fold(ArrayRef<Attribute> operands) {
522   /// and(x, 0) -> 0
523   if (matchPattern(getRhs(), m_Zero()))
524     return getRhs();
525   /// and(x, allOnes) -> x
526   APInt intValue;
527   if (matchPattern(getRhs(), m_ConstantInt(&intValue)) && intValue.isAllOnes())
528     return getLhs();
529 
530   return constFoldBinaryOp<IntegerAttr>(
531       operands, [](APInt a, const APInt &b) { return std::move(a) & b; });
532 }
533 
534 //===----------------------------------------------------------------------===//
535 // OrIOp
536 //===----------------------------------------------------------------------===//
537 
538 OpFoldResult arith::OrIOp::fold(ArrayRef<Attribute> operands) {
539   /// or(x, 0) -> x
540   if (matchPattern(getRhs(), m_Zero()))
541     return getLhs();
542   /// or(x, <all ones>) -> <all ones>
543   if (auto rhsAttr = operands[1].dyn_cast_or_null<IntegerAttr>())
544     if (rhsAttr.getValue().isAllOnes())
545       return rhsAttr;
546 
547   return constFoldBinaryOp<IntegerAttr>(
548       operands, [](APInt a, const APInt &b) { return std::move(a) | b; });
549 }
550 
551 //===----------------------------------------------------------------------===//
552 // XOrIOp
553 //===----------------------------------------------------------------------===//
554 
555 OpFoldResult arith::XOrIOp::fold(ArrayRef<Attribute> operands) {
556   /// xor(x, 0) -> x
557   if (matchPattern(getRhs(), m_Zero()))
558     return getLhs();
559   /// xor(x, x) -> 0
560   if (getLhs() == getRhs())
561     return Builder(getContext()).getZeroAttr(getType());
562   /// xor(xor(x, a), a) -> x
563   if (arith::XOrIOp prev = getLhs().getDefiningOp<arith::XOrIOp>())
564     if (prev.getRhs() == getRhs())
565       return prev.getLhs();
566 
567   return constFoldBinaryOp<IntegerAttr>(
568       operands, [](APInt a, const APInt &b) { return std::move(a) ^ b; });
569 }
570 
571 void arith::XOrIOp::getCanonicalizationPatterns(
572     RewritePatternSet &patterns, MLIRContext *context) {
573   patterns.add<XOrINotCmpI>(context);
574 }
575 
576 //===----------------------------------------------------------------------===//
577 // NegFOp
578 //===----------------------------------------------------------------------===//
579 
580 OpFoldResult arith::NegFOp::fold(ArrayRef<Attribute> operands) {
581   return constFoldUnaryOp<FloatAttr>(operands,
582                                      [](const APFloat &a) { return -a; });
583 }
584 
585 //===----------------------------------------------------------------------===//
586 // AddFOp
587 //===----------------------------------------------------------------------===//
588 
589 OpFoldResult arith::AddFOp::fold(ArrayRef<Attribute> operands) {
590   // addf(x, -0) -> x
591   if (matchPattern(getRhs(), m_NegZeroFloat()))
592     return getLhs();
593 
594   return constFoldBinaryOp<FloatAttr>(
595       operands, [](const APFloat &a, const APFloat &b) { return a + b; });
596 }
597 
598 //===----------------------------------------------------------------------===//
599 // SubFOp
600 //===----------------------------------------------------------------------===//
601 
602 OpFoldResult arith::SubFOp::fold(ArrayRef<Attribute> operands) {
603   // subf(x, +0) -> x
604   if (matchPattern(getRhs(), m_PosZeroFloat()))
605     return getLhs();
606 
607   return constFoldBinaryOp<FloatAttr>(
608       operands, [](const APFloat &a, const APFloat &b) { return a - b; });
609 }
610 
611 //===----------------------------------------------------------------------===//
612 // MaxFOp
613 //===----------------------------------------------------------------------===//
614 
615 OpFoldResult arith::MaxFOp::fold(ArrayRef<Attribute> operands) {
616   assert(operands.size() == 2 && "maxf takes two operands");
617 
618   // maxf(x,x) -> x
619   if (getLhs() == getRhs())
620     return getRhs();
621 
622   // maxf(x, -inf) -> x
623   if (matchPattern(getRhs(), m_NegInfFloat()))
624     return getLhs();
625 
626   return constFoldBinaryOp<FloatAttr>(
627       operands,
628       [](const APFloat &a, const APFloat &b) { return llvm::maximum(a, b); });
629 }
630 
631 //===----------------------------------------------------------------------===//
632 // MaxSIOp
633 //===----------------------------------------------------------------------===//
634 
635 OpFoldResult MaxSIOp::fold(ArrayRef<Attribute> operands) {
636   assert(operands.size() == 2 && "binary operation takes two operands");
637 
638   // maxsi(x,x) -> x
639   if (getLhs() == getRhs())
640     return getRhs();
641 
642   APInt intValue;
643   // maxsi(x,MAX_INT) -> MAX_INT
644   if (matchPattern(getRhs(), m_ConstantInt(&intValue)) &&
645       intValue.isMaxSignedValue())
646     return getRhs();
647 
648   // maxsi(x, MIN_INT) -> x
649   if (matchPattern(getRhs(), m_ConstantInt(&intValue)) &&
650       intValue.isMinSignedValue())
651     return getLhs();
652 
653   return constFoldBinaryOp<IntegerAttr>(operands,
654                                         [](const APInt &a, const APInt &b) {
655                                           return llvm::APIntOps::smax(a, b);
656                                         });
657 }
658 
659 //===----------------------------------------------------------------------===//
660 // MaxUIOp
661 //===----------------------------------------------------------------------===//
662 
663 OpFoldResult MaxUIOp::fold(ArrayRef<Attribute> operands) {
664   assert(operands.size() == 2 && "binary operation takes two operands");
665 
666   // maxui(x,x) -> x
667   if (getLhs() == getRhs())
668     return getRhs();
669 
670   APInt intValue;
671   // maxui(x,MAX_INT) -> MAX_INT
672   if (matchPattern(getRhs(), m_ConstantInt(&intValue)) && intValue.isMaxValue())
673     return getRhs();
674 
675   // maxui(x, MIN_INT) -> x
676   if (matchPattern(getRhs(), m_ConstantInt(&intValue)) && intValue.isMinValue())
677     return getLhs();
678 
679   return constFoldBinaryOp<IntegerAttr>(operands,
680                                         [](const APInt &a, const APInt &b) {
681                                           return llvm::APIntOps::umax(a, b);
682                                         });
683 }
684 
685 //===----------------------------------------------------------------------===//
686 // MinFOp
687 //===----------------------------------------------------------------------===//
688 
689 OpFoldResult arith::MinFOp::fold(ArrayRef<Attribute> operands) {
690   assert(operands.size() == 2 && "minf takes two operands");
691 
692   // minf(x,x) -> x
693   if (getLhs() == getRhs())
694     return getRhs();
695 
696   // minf(x, +inf) -> x
697   if (matchPattern(getRhs(), m_PosInfFloat()))
698     return getLhs();
699 
700   return constFoldBinaryOp<FloatAttr>(
701       operands,
702       [](const APFloat &a, const APFloat &b) { return llvm::minimum(a, b); });
703 }
704 
705 //===----------------------------------------------------------------------===//
706 // MinSIOp
707 //===----------------------------------------------------------------------===//
708 
709 OpFoldResult MinSIOp::fold(ArrayRef<Attribute> operands) {
710   assert(operands.size() == 2 && "binary operation takes two operands");
711 
712   // minsi(x,x) -> x
713   if (getLhs() == getRhs())
714     return getRhs();
715 
716   APInt intValue;
717   // minsi(x,MIN_INT) -> MIN_INT
718   if (matchPattern(getRhs(), m_ConstantInt(&intValue)) &&
719       intValue.isMinSignedValue())
720     return getRhs();
721 
722   // minsi(x, MAX_INT) -> x
723   if (matchPattern(getRhs(), m_ConstantInt(&intValue)) &&
724       intValue.isMaxSignedValue())
725     return getLhs();
726 
727   return constFoldBinaryOp<IntegerAttr>(operands,
728                                         [](const APInt &a, const APInt &b) {
729                                           return llvm::APIntOps::smin(a, b);
730                                         });
731 }
732 
733 //===----------------------------------------------------------------------===//
734 // MinUIOp
735 //===----------------------------------------------------------------------===//
736 
737 OpFoldResult MinUIOp::fold(ArrayRef<Attribute> operands) {
738   assert(operands.size() == 2 && "binary operation takes two operands");
739 
740   // minui(x,x) -> x
741   if (getLhs() == getRhs())
742     return getRhs();
743 
744   APInt intValue;
745   // minui(x,MIN_INT) -> MIN_INT
746   if (matchPattern(getRhs(), m_ConstantInt(&intValue)) && intValue.isMinValue())
747     return getRhs();
748 
749   // minui(x, MAX_INT) -> x
750   if (matchPattern(getRhs(), m_ConstantInt(&intValue)) && intValue.isMaxValue())
751     return getLhs();
752 
753   return constFoldBinaryOp<IntegerAttr>(operands,
754                                         [](const APInt &a, const APInt &b) {
755                                           return llvm::APIntOps::umin(a, b);
756                                         });
757 }
758 
759 //===----------------------------------------------------------------------===//
760 // MulFOp
761 //===----------------------------------------------------------------------===//
762 
763 OpFoldResult arith::MulFOp::fold(ArrayRef<Attribute> operands) {
764   // mulf(x, 1) -> x
765   if (matchPattern(getRhs(), m_OneFloat()))
766     return getLhs();
767 
768   return constFoldBinaryOp<FloatAttr>(
769       operands, [](const APFloat &a, const APFloat &b) { return a * b; });
770 }
771 
772 //===----------------------------------------------------------------------===//
773 // DivFOp
774 //===----------------------------------------------------------------------===//
775 
776 OpFoldResult arith::DivFOp::fold(ArrayRef<Attribute> operands) {
777   // divf(x, 1) -> x
778   if (matchPattern(getRhs(), m_OneFloat()))
779     return getLhs();
780 
781   return constFoldBinaryOp<FloatAttr>(
782       operands, [](const APFloat &a, const APFloat &b) { return a / b; });
783 }
784 
785 //===----------------------------------------------------------------------===//
786 // Utility functions for verifying cast ops
787 //===----------------------------------------------------------------------===//
788 
/// A compile-time list of types, represented as a pointer-to-tuple type so that
/// a (value-initialized, i.e. null) instance can be passed as a function
/// argument and the `Types...` pack deduced from it. For example,
/// `type_list<VectorType, TensorType>()` is passed to getUnderlyingType.
template <typename... Types>
using type_list = std::tuple<Types...> *;
791 
792 /// Returns a non-null type only if the provided type is one of the allowed
793 /// types or one of the allowed shaped types of the allowed types. Returns the
794 /// element type if a valid shaped type is provided.
795 template <typename... ShapedTypes, typename... ElementTypes>
796 static Type getUnderlyingType(Type type, type_list<ShapedTypes...>,
797                               type_list<ElementTypes...>) {
798   if (type.isa<ShapedType>() && !type.isa<ShapedTypes...>())
799     return {};
800 
801   auto underlyingType = getElementTypeOrSelf(type);
802   if (!underlyingType.isa<ElementTypes...>())
803     return {};
804 
805   return underlyingType;
806 }
807 
808 /// Get allowed underlying types for vectors and tensors.
809 template <typename... ElementTypes>
810 static Type getTypeIfLike(Type type) {
811   return getUnderlyingType(type, type_list<VectorType, TensorType>(),
812                            type_list<ElementTypes...>());
813 }
814 
815 /// Get allowed underlying types for vectors, tensors, and memrefs.
816 template <typename... ElementTypes>
817 static Type getTypeIfLikeOrMemRef(Type type) {
818   return getUnderlyingType(type,
819                            type_list<VectorType, TensorType, MemRefType>(),
820                            type_list<ElementTypes...>());
821 }
822 
823 static bool areValidCastInputsAndOutputs(TypeRange inputs, TypeRange outputs) {
824   return inputs.size() == 1 && outputs.size() == 1 &&
825          succeeded(verifyCompatibleShapes(inputs.front(), outputs.front()));
826 }
827 
828 //===----------------------------------------------------------------------===//
829 // Verifiers for integer and floating point extension/truncation ops
830 //===----------------------------------------------------------------------===//
831 
832 // Extend ops can only extend to a wider type.
833 template <typename ValType, typename Op>
834 static LogicalResult verifyExtOp(Op op) {
835   Type srcType = getElementTypeOrSelf(op.getIn().getType());
836   Type dstType = getElementTypeOrSelf(op.getType());
837 
838   if (srcType.cast<ValType>().getWidth() >= dstType.cast<ValType>().getWidth())
839     return op.emitError("result type ")
840            << dstType << " must be wider than operand type " << srcType;
841 
842   return success();
843 }
844 
845 // Truncate ops can only truncate to a shorter type.
846 template <typename ValType, typename Op>
847 static LogicalResult verifyTruncateOp(Op op) {
848   Type srcType = getElementTypeOrSelf(op.getIn().getType());
849   Type dstType = getElementTypeOrSelf(op.getType());
850 
851   if (srcType.cast<ValType>().getWidth() <= dstType.cast<ValType>().getWidth())
852     return op.emitError("result type ")
853            << dstType << " must be shorter than operand type " << srcType;
854 
855   return success();
856 }
857 
858 /// Validate a cast that changes the width of a type.
859 template <template <typename> class WidthComparator, typename... ElementTypes>
860 static bool checkWidthChangeCast(TypeRange inputs, TypeRange outputs) {
861   if (!areValidCastInputsAndOutputs(inputs, outputs))
862     return false;
863 
864   auto srcType = getTypeIfLike<ElementTypes...>(inputs.front());
865   auto dstType = getTypeIfLike<ElementTypes...>(outputs.front());
866   if (!srcType || !dstType)
867     return false;
868 
869   return WidthComparator<unsigned>()(dstType.getIntOrFloatBitWidth(),
870                                      srcType.getIntOrFloatBitWidth());
871 }
872 
873 //===----------------------------------------------------------------------===//
874 // ExtUIOp
875 //===----------------------------------------------------------------------===//
876 
877 OpFoldResult arith::ExtUIOp::fold(ArrayRef<Attribute> operands) {
878   if (auto lhs = operands[0].dyn_cast_or_null<IntegerAttr>())
879     return IntegerAttr::get(
880         getType(), lhs.getValue().zext(getType().getIntOrFloatBitWidth()));
881 
882   if (auto lhs = getIn().getDefiningOp<ExtUIOp>()) {
883     getInMutable().assign(lhs.getIn());
884     return getResult();
885   }
886 
887   return {};
888 }
889 
890 bool arith::ExtUIOp::areCastCompatible(TypeRange inputs, TypeRange outputs) {
891   return checkWidthChangeCast<std::greater, IntegerType>(inputs, outputs);
892 }
893 
894 LogicalResult arith::ExtUIOp::verify() {
895   return verifyExtOp<IntegerType>(*this);
896 }
897 
898 //===----------------------------------------------------------------------===//
899 // ExtSIOp
900 //===----------------------------------------------------------------------===//
901 
902 OpFoldResult arith::ExtSIOp::fold(ArrayRef<Attribute> operands) {
903   if (auto lhs = operands[0].dyn_cast_or_null<IntegerAttr>())
904     return IntegerAttr::get(
905         getType(), lhs.getValue().sext(getType().getIntOrFloatBitWidth()));
906 
907   if (auto lhs = getIn().getDefiningOp<ExtSIOp>()) {
908     getInMutable().assign(lhs.getIn());
909     return getResult();
910   }
911 
912   return {};
913 }
914 
915 bool arith::ExtSIOp::areCastCompatible(TypeRange inputs, TypeRange outputs) {
916   return checkWidthChangeCast<std::greater, IntegerType>(inputs, outputs);
917 }
918 
919 void arith::ExtSIOp::getCanonicalizationPatterns(
920     RewritePatternSet &patterns, MLIRContext *context) {
921   patterns.add<ExtSIOfExtUI>(context);
922 }
923 
924 LogicalResult arith::ExtSIOp::verify() {
925   return verifyExtOp<IntegerType>(*this);
926 }
927 
928 //===----------------------------------------------------------------------===//
929 // ExtFOp
930 //===----------------------------------------------------------------------===//
931 
932 bool arith::ExtFOp::areCastCompatible(TypeRange inputs, TypeRange outputs) {
933   return checkWidthChangeCast<std::greater, FloatType>(inputs, outputs);
934 }
935 
936 LogicalResult arith::ExtFOp::verify() { return verifyExtOp<FloatType>(*this); }
937 
938 //===----------------------------------------------------------------------===//
939 // TruncIOp
940 //===----------------------------------------------------------------------===//
941 
942 OpFoldResult arith::TruncIOp::fold(ArrayRef<Attribute> operands) {
943   assert(operands.size() == 1 && "unary operation takes one operand");
944 
945   // trunci(zexti(a)) -> a
946   // trunci(sexti(a)) -> a
947   if (matchPattern(getOperand(), m_Op<arith::ExtUIOp>()) ||
948       matchPattern(getOperand(), m_Op<arith::ExtSIOp>()))
949     return getOperand().getDefiningOp()->getOperand(0);
950 
951   // trunci(trunci(a)) -> trunci(a))
952   if (matchPattern(getOperand(), m_Op<arith::TruncIOp>())) {
953     setOperand(getOperand().getDefiningOp()->getOperand(0));
954     return getResult();
955   }
956 
957   if (!operands[0])
958     return {};
959 
960   if (auto lhs = operands[0].dyn_cast<IntegerAttr>()) {
961     return IntegerAttr::get(
962         getType(), lhs.getValue().trunc(getType().getIntOrFloatBitWidth()));
963   }
964 
965   return {};
966 }
967 
968 bool arith::TruncIOp::areCastCompatible(TypeRange inputs, TypeRange outputs) {
969   return checkWidthChangeCast<std::less, IntegerType>(inputs, outputs);
970 }
971 
972 LogicalResult arith::TruncIOp::verify() {
973   return verifyTruncateOp<IntegerType>(*this);
974 }
975 
976 //===----------------------------------------------------------------------===//
977 // TruncFOp
978 //===----------------------------------------------------------------------===//
979 
980 /// Perform safe const propagation for truncf, i.e. only propagate if FP value
981 /// can be represented without precision loss or rounding.
982 OpFoldResult arith::TruncFOp::fold(ArrayRef<Attribute> operands) {
983   assert(operands.size() == 1 && "unary operation takes one operand");
984 
985   auto constOperand = operands.front();
986   if (!constOperand || !constOperand.isa<FloatAttr>())
987     return {};
988 
989   // Convert to target type via 'double'.
990   double sourceValue =
991       constOperand.dyn_cast<FloatAttr>().getValue().convertToDouble();
992   auto targetAttr = FloatAttr::get(getType(), sourceValue);
993 
994   // Propagate if constant's value does not change after truncation.
995   if (sourceValue == targetAttr.getValue().convertToDouble())
996     return targetAttr;
997 
998   return {};
999 }
1000 
1001 bool arith::TruncFOp::areCastCompatible(TypeRange inputs, TypeRange outputs) {
1002   return checkWidthChangeCast<std::less, FloatType>(inputs, outputs);
1003 }
1004 
1005 LogicalResult arith::TruncFOp::verify() {
1006   return verifyTruncateOp<FloatType>(*this);
1007 }
1008 
1009 //===----------------------------------------------------------------------===//
1010 // AndIOp
1011 //===----------------------------------------------------------------------===//
1012 
1013 void arith::AndIOp::getCanonicalizationPatterns(
1014     RewritePatternSet &patterns, MLIRContext *context) {
1015   patterns.add<AndOfExtUI, AndOfExtSI>(context);
1016 }
1017 
1018 //===----------------------------------------------------------------------===//
1019 // OrIOp
1020 //===----------------------------------------------------------------------===//
1021 
1022 void arith::OrIOp::getCanonicalizationPatterns(
1023     RewritePatternSet &patterns, MLIRContext *context) {
1024   patterns.add<OrOfExtUI, OrOfExtSI>(context);
1025 }
1026 
1027 //===----------------------------------------------------------------------===//
1028 // Verifiers for casts between integers and floats.
1029 //===----------------------------------------------------------------------===//
1030 
1031 template <typename From, typename To>
1032 static bool checkIntFloatCast(TypeRange inputs, TypeRange outputs) {
1033   if (!areValidCastInputsAndOutputs(inputs, outputs))
1034     return false;
1035 
1036   auto srcType = getTypeIfLike<From>(inputs.front());
1037   auto dstType = getTypeIfLike<To>(outputs.back());
1038 
1039   return srcType && dstType;
1040 }
1041 
1042 //===----------------------------------------------------------------------===//
1043 // UIToFPOp
1044 //===----------------------------------------------------------------------===//
1045 
1046 bool arith::UIToFPOp::areCastCompatible(TypeRange inputs, TypeRange outputs) {
1047   return checkIntFloatCast<IntegerType, FloatType>(inputs, outputs);
1048 }
1049 
1050 OpFoldResult arith::UIToFPOp::fold(ArrayRef<Attribute> operands) {
1051   if (auto lhs = operands[0].dyn_cast_or_null<IntegerAttr>()) {
1052     const APInt &api = lhs.getValue();
1053     FloatType floatTy = getType().cast<FloatType>();
1054     APFloat apf(floatTy.getFloatSemantics(),
1055                 APInt::getZero(floatTy.getWidth()));
1056     apf.convertFromAPInt(api, /*IsSigned=*/false, APFloat::rmNearestTiesToEven);
1057     return FloatAttr::get(floatTy, apf);
1058   }
1059   return {};
1060 }
1061 
1062 //===----------------------------------------------------------------------===//
1063 // SIToFPOp
1064 //===----------------------------------------------------------------------===//
1065 
1066 bool arith::SIToFPOp::areCastCompatible(TypeRange inputs, TypeRange outputs) {
1067   return checkIntFloatCast<IntegerType, FloatType>(inputs, outputs);
1068 }
1069 
1070 OpFoldResult arith::SIToFPOp::fold(ArrayRef<Attribute> operands) {
1071   if (auto lhs = operands[0].dyn_cast_or_null<IntegerAttr>()) {
1072     const APInt &api = lhs.getValue();
1073     FloatType floatTy = getType().cast<FloatType>();
1074     APFloat apf(floatTy.getFloatSemantics(),
1075                 APInt::getZero(floatTy.getWidth()));
1076     apf.convertFromAPInt(api, /*IsSigned=*/true, APFloat::rmNearestTiesToEven);
1077     return FloatAttr::get(floatTy, apf);
1078   }
1079   return {};
1080 }
1081 //===----------------------------------------------------------------------===//
1082 // FPToUIOp
1083 //===----------------------------------------------------------------------===//
1084 
1085 bool arith::FPToUIOp::areCastCompatible(TypeRange inputs, TypeRange outputs) {
1086   return checkIntFloatCast<FloatType, IntegerType>(inputs, outputs);
1087 }
1088 
1089 OpFoldResult arith::FPToUIOp::fold(ArrayRef<Attribute> operands) {
1090   if (auto lhs = operands[0].dyn_cast_or_null<FloatAttr>()) {
1091     const APFloat &apf = lhs.getValue();
1092     IntegerType intTy = getType().cast<IntegerType>();
1093     bool ignored;
1094     APSInt api(intTy.getWidth(), /*isUnsigned=*/true);
1095     if (APFloat::opInvalidOp ==
1096         apf.convertToInteger(api, APFloat::rmTowardZero, &ignored)) {
1097       // Undefined behavior invoked - the destination type can't represent
1098       // the input constant.
1099       return {};
1100     }
1101     return IntegerAttr::get(getType(), api);
1102   }
1103 
1104   return {};
1105 }
1106 
1107 //===----------------------------------------------------------------------===//
1108 // FPToSIOp
1109 //===----------------------------------------------------------------------===//
1110 
1111 bool arith::FPToSIOp::areCastCompatible(TypeRange inputs, TypeRange outputs) {
1112   return checkIntFloatCast<FloatType, IntegerType>(inputs, outputs);
1113 }
1114 
1115 OpFoldResult arith::FPToSIOp::fold(ArrayRef<Attribute> operands) {
1116   if (auto lhs = operands[0].dyn_cast_or_null<FloatAttr>()) {
1117     const APFloat &apf = lhs.getValue();
1118     IntegerType intTy = getType().cast<IntegerType>();
1119     bool ignored;
1120     APSInt api(intTy.getWidth(), /*isUnsigned=*/false);
1121     if (APFloat::opInvalidOp ==
1122         apf.convertToInteger(api, APFloat::rmTowardZero, &ignored)) {
1123       // Undefined behavior invoked - the destination type can't represent
1124       // the input constant.
1125       return {};
1126     }
1127     return IntegerAttr::get(getType(), api);
1128   }
1129 
1130   return {};
1131 }
1132 
1133 //===----------------------------------------------------------------------===//
1134 // IndexCastOp
1135 //===----------------------------------------------------------------------===//
1136 
1137 bool arith::IndexCastOp::areCastCompatible(TypeRange inputs,
1138                                            TypeRange outputs) {
1139   if (!areValidCastInputsAndOutputs(inputs, outputs))
1140     return false;
1141 
1142   auto srcType = getTypeIfLikeOrMemRef<IntegerType, IndexType>(inputs.front());
1143   auto dstType = getTypeIfLikeOrMemRef<IntegerType, IndexType>(outputs.front());
1144   if (!srcType || !dstType)
1145     return false;
1146 
1147   return (srcType.isIndex() && dstType.isSignlessInteger()) ||
1148          (srcType.isSignlessInteger() && dstType.isIndex());
1149 }
1150 
1151 OpFoldResult arith::IndexCastOp::fold(ArrayRef<Attribute> operands) {
1152   // index_cast(constant) -> constant
1153   // A little hack because we go through int. Otherwise, the size of the
1154   // constant might need to change.
1155   if (auto value = operands[0].dyn_cast_or_null<IntegerAttr>())
1156     return IntegerAttr::get(getType(), value.getInt());
1157 
1158   return {};
1159 }
1160 
1161 void arith::IndexCastOp::getCanonicalizationPatterns(
1162     RewritePatternSet &patterns, MLIRContext *context) {
1163   patterns.add<IndexCastOfIndexCast, IndexCastOfExtSI>(context);
1164 }
1165 
1166 //===----------------------------------------------------------------------===//
1167 // BitcastOp
1168 //===----------------------------------------------------------------------===//
1169 
1170 bool arith::BitcastOp::areCastCompatible(TypeRange inputs, TypeRange outputs) {
1171   if (!areValidCastInputsAndOutputs(inputs, outputs))
1172     return false;
1173 
1174   auto srcType =
1175       getTypeIfLikeOrMemRef<IntegerType, IndexType, FloatType>(inputs.front());
1176   auto dstType =
1177       getTypeIfLikeOrMemRef<IntegerType, IndexType, FloatType>(outputs.front());
1178   if (!srcType || !dstType)
1179     return false;
1180 
1181   return srcType.getIntOrFloatBitWidth() == dstType.getIntOrFloatBitWidth();
1182 }
1183 
1184 OpFoldResult arith::BitcastOp::fold(ArrayRef<Attribute> operands) {
1185   assert(operands.size() == 1 && "bitcast op expects 1 operand");
1186 
1187   auto resType = getType();
1188   auto operand = operands[0];
1189   if (!operand)
1190     return {};
1191 
1192   /// Bitcast dense elements.
1193   if (auto denseAttr = operand.dyn_cast_or_null<DenseElementsAttr>())
1194     return denseAttr.bitcast(resType.cast<ShapedType>().getElementType());
1195   /// Other shaped types unhandled.
1196   if (resType.isa<ShapedType>())
1197     return {};
1198 
1199   /// Bitcast integer or float to integer or float.
1200   APInt bits = operand.isa<FloatAttr>()
1201                    ? operand.cast<FloatAttr>().getValue().bitcastToAPInt()
1202                    : operand.cast<IntegerAttr>().getValue();
1203 
1204   if (auto resFloatType = resType.dyn_cast<FloatType>())
1205     return FloatAttr::get(resType,
1206                           APFloat(resFloatType.getFloatSemantics(), bits));
1207   return IntegerAttr::get(resType, bits);
1208 }
1209 
1210 void arith::BitcastOp::getCanonicalizationPatterns(
1211     RewritePatternSet &patterns, MLIRContext *context) {
1212   patterns.add<BitcastOfBitcast>(context);
1213 }
1214 
1215 //===----------------------------------------------------------------------===//
1216 // Helpers for compare ops
1217 //===----------------------------------------------------------------------===//
1218 
1219 /// Return the type of the same shape (scalar, vector or tensor) containing i1.
1220 static Type getI1SameShape(Type type) {
1221   auto i1Type = IntegerType::get(type.getContext(), 1);
1222   if (auto tensorType = type.dyn_cast<RankedTensorType>())
1223     return RankedTensorType::get(tensorType.getShape(), i1Type);
1224   if (type.isa<UnrankedTensorType>())
1225     return UnrankedTensorType::get(i1Type);
1226   if (auto vectorType = type.dyn_cast<VectorType>())
1227     return VectorType::get(vectorType.getShape(), i1Type,
1228                            vectorType.getNumScalableDims());
1229   return i1Type;
1230 }
1231 
1232 //===----------------------------------------------------------------------===//
1233 // CmpIOp
1234 //===----------------------------------------------------------------------===//
1235 
1236 /// Compute `lhs` `pred` `rhs`, where `pred` is one of the known integer
1237 /// comparison predicates.
1238 bool mlir::arith::applyCmpPredicate(arith::CmpIPredicate predicate,
1239                                     const APInt &lhs, const APInt &rhs) {
1240   switch (predicate) {
1241   case arith::CmpIPredicate::eq:
1242     return lhs.eq(rhs);
1243   case arith::CmpIPredicate::ne:
1244     return lhs.ne(rhs);
1245   case arith::CmpIPredicate::slt:
1246     return lhs.slt(rhs);
1247   case arith::CmpIPredicate::sle:
1248     return lhs.sle(rhs);
1249   case arith::CmpIPredicate::sgt:
1250     return lhs.sgt(rhs);
1251   case arith::CmpIPredicate::sge:
1252     return lhs.sge(rhs);
1253   case arith::CmpIPredicate::ult:
1254     return lhs.ult(rhs);
1255   case arith::CmpIPredicate::ule:
1256     return lhs.ule(rhs);
1257   case arith::CmpIPredicate::ugt:
1258     return lhs.ugt(rhs);
1259   case arith::CmpIPredicate::uge:
1260     return lhs.uge(rhs);
1261   }
1262   llvm_unreachable("unknown cmpi predicate kind");
1263 }
1264 
1265 /// Returns true if the predicate is true for two equal operands.
1266 static bool applyCmpPredicateToEqualOperands(arith::CmpIPredicate predicate) {
1267   switch (predicate) {
1268   case arith::CmpIPredicate::eq:
1269   case arith::CmpIPredicate::sle:
1270   case arith::CmpIPredicate::sge:
1271   case arith::CmpIPredicate::ule:
1272   case arith::CmpIPredicate::uge:
1273     return true;
1274   case arith::CmpIPredicate::ne:
1275   case arith::CmpIPredicate::slt:
1276   case arith::CmpIPredicate::sgt:
1277   case arith::CmpIPredicate::ult:
1278   case arith::CmpIPredicate::ugt:
1279     return false;
1280   }
1281   llvm_unreachable("unknown cmpi predicate kind");
1282 }
1283 
1284 static Attribute getBoolAttribute(Type type, MLIRContext *ctx, bool value) {
1285   auto boolAttr = BoolAttr::get(ctx, value);
1286   ShapedType shapedType = type.dyn_cast_or_null<ShapedType>();
1287   if (!shapedType)
1288     return boolAttr;
1289   return DenseElementsAttr::get(shapedType, boolAttr);
1290 }
1291 
1292 OpFoldResult arith::CmpIOp::fold(ArrayRef<Attribute> operands) {
1293   assert(operands.size() == 2 && "cmpi takes two operands");
1294 
1295   // cmpi(pred, x, x)
1296   if (getLhs() == getRhs()) {
1297     auto val = applyCmpPredicateToEqualOperands(getPredicate());
1298     return getBoolAttribute(getType(), getContext(), val);
1299   }
1300 
1301   if (matchPattern(getRhs(), m_Zero())) {
1302     if (auto extOp = getLhs().getDefiningOp<ExtSIOp>()) {
1303       if (extOp.getOperand().getType().cast<IntegerType>().getWidth() == 1) {
1304         // extsi(%x : i1 -> iN) != 0  ->  %x
1305         if (getPredicate() == arith::CmpIPredicate::ne) {
1306           return extOp.getOperand();
1307         }
1308       }
1309     }
1310     if (auto extOp = getLhs().getDefiningOp<ExtUIOp>()) {
1311       if (extOp.getOperand().getType().cast<IntegerType>().getWidth() == 1) {
1312         // extui(%x : i1 -> iN) != 0  ->  %x
1313         if (getPredicate() == arith::CmpIPredicate::ne) {
1314           return extOp.getOperand();
1315         }
1316       }
1317     }
1318   }
1319 
1320   auto lhs = operands.front().dyn_cast_or_null<IntegerAttr>();
1321   auto rhs = operands.back().dyn_cast_or_null<IntegerAttr>();
1322   if (!lhs || !rhs)
1323     return {};
1324 
1325   auto val = applyCmpPredicate(getPredicate(), lhs.getValue(), rhs.getValue());
1326   return BoolAttr::get(getContext(), val);
1327 }
1328 
1329 void arith::CmpIOp::getCanonicalizationPatterns(RewritePatternSet &patterns,
1330                                                 MLIRContext *context) {
1331   patterns.insert<CmpIExtSI, CmpIExtUI>(context);
1332 }
1333 
1334 //===----------------------------------------------------------------------===//
1335 // CmpFOp
1336 //===----------------------------------------------------------------------===//
1337 
1338 /// Compute `lhs` `pred` `rhs`, where `pred` is one of the known floating point
1339 /// comparison predicates.
1340 bool mlir::arith::applyCmpPredicate(arith::CmpFPredicate predicate,
1341                                     const APFloat &lhs, const APFloat &rhs) {
1342   auto cmpResult = lhs.compare(rhs);
1343   switch (predicate) {
1344   case arith::CmpFPredicate::AlwaysFalse:
1345     return false;
1346   case arith::CmpFPredicate::OEQ:
1347     return cmpResult == APFloat::cmpEqual;
1348   case arith::CmpFPredicate::OGT:
1349     return cmpResult == APFloat::cmpGreaterThan;
1350   case arith::CmpFPredicate::OGE:
1351     return cmpResult == APFloat::cmpGreaterThan ||
1352            cmpResult == APFloat::cmpEqual;
1353   case arith::CmpFPredicate::OLT:
1354     return cmpResult == APFloat::cmpLessThan;
1355   case arith::CmpFPredicate::OLE:
1356     return cmpResult == APFloat::cmpLessThan || cmpResult == APFloat::cmpEqual;
1357   case arith::CmpFPredicate::ONE:
1358     return cmpResult != APFloat::cmpUnordered && cmpResult != APFloat::cmpEqual;
1359   case arith::CmpFPredicate::ORD:
1360     return cmpResult != APFloat::cmpUnordered;
1361   case arith::CmpFPredicate::UEQ:
1362     return cmpResult == APFloat::cmpUnordered || cmpResult == APFloat::cmpEqual;
1363   case arith::CmpFPredicate::UGT:
1364     return cmpResult == APFloat::cmpUnordered ||
1365            cmpResult == APFloat::cmpGreaterThan;
1366   case arith::CmpFPredicate::UGE:
1367     return cmpResult == APFloat::cmpUnordered ||
1368            cmpResult == APFloat::cmpGreaterThan ||
1369            cmpResult == APFloat::cmpEqual;
1370   case arith::CmpFPredicate::ULT:
1371     return cmpResult == APFloat::cmpUnordered ||
1372            cmpResult == APFloat::cmpLessThan;
1373   case arith::CmpFPredicate::ULE:
1374     return cmpResult == APFloat::cmpUnordered ||
1375            cmpResult == APFloat::cmpLessThan || cmpResult == APFloat::cmpEqual;
1376   case arith::CmpFPredicate::UNE:
1377     return cmpResult != APFloat::cmpEqual;
1378   case arith::CmpFPredicate::UNO:
1379     return cmpResult == APFloat::cmpUnordered;
1380   case arith::CmpFPredicate::AlwaysTrue:
1381     return true;
1382   }
1383   llvm_unreachable("unknown cmpf predicate kind");
1384 }
1385 
1386 OpFoldResult arith::CmpFOp::fold(ArrayRef<Attribute> operands) {
1387   assert(operands.size() == 2 && "cmpf takes two operands");
1388 
1389   auto lhs = operands.front().dyn_cast_or_null<FloatAttr>();
1390   auto rhs = operands.back().dyn_cast_or_null<FloatAttr>();
1391 
1392   // If one operand is NaN, making them both NaN does not change the result.
1393   if (lhs && lhs.getValue().isNaN())
1394     rhs = lhs;
1395   if (rhs && rhs.getValue().isNaN())
1396     lhs = rhs;
1397 
1398   if (!lhs || !rhs)
1399     return {};
1400 
1401   auto val = applyCmpPredicate(getPredicate(), lhs.getValue(), rhs.getValue());
1402   return BoolAttr::get(getContext(), val);
1403 }
1404 
1405 class CmpFIntToFPConst final : public OpRewritePattern<CmpFOp> {
1406 public:
1407   using OpRewritePattern<CmpFOp>::OpRewritePattern;
1408 
1409   static CmpIPredicate convertToIntegerPredicate(CmpFPredicate pred,
1410                                                  bool isUnsigned) {
1411     using namespace arith;
1412     switch (pred) {
1413     case CmpFPredicate::UEQ:
1414     case CmpFPredicate::OEQ:
1415       return CmpIPredicate::eq;
1416     case CmpFPredicate::UGT:
1417     case CmpFPredicate::OGT:
1418       return isUnsigned ? CmpIPredicate::ugt : CmpIPredicate::sgt;
1419     case CmpFPredicate::UGE:
1420     case CmpFPredicate::OGE:
1421       return isUnsigned ? CmpIPredicate::uge : CmpIPredicate::sge;
1422     case CmpFPredicate::ULT:
1423     case CmpFPredicate::OLT:
1424       return isUnsigned ? CmpIPredicate::ult : CmpIPredicate::slt;
1425     case CmpFPredicate::ULE:
1426     case CmpFPredicate::OLE:
1427       return isUnsigned ? CmpIPredicate::ule : CmpIPredicate::sle;
1428     case CmpFPredicate::UNE:
1429     case CmpFPredicate::ONE:
1430       return CmpIPredicate::ne;
1431     default:
1432       llvm_unreachable("Unexpected predicate!");
1433     }
1434   }
1435 
1436   LogicalResult matchAndRewrite(CmpFOp op,
1437                                 PatternRewriter &rewriter) const override {
1438     FloatAttr flt;
1439     if (!matchPattern(op.getRhs(), m_Constant(&flt)))
1440       return failure();
1441 
1442     const APFloat &rhs = flt.getValue();
1443 
1444     // Don't attempt to fold a nan.
1445     if (rhs.isNaN())
1446       return failure();
1447 
1448     // Get the width of the mantissa.  We don't want to hack on conversions that
1449     // might lose information from the integer, e.g. "i64 -> float"
1450     FloatType floatTy = op.getRhs().getType().cast<FloatType>();
1451     int mantissaWidth = floatTy.getFPMantissaWidth();
1452     if (mantissaWidth <= 0)
1453       return failure();
1454 
1455     bool isUnsigned;
1456     Value intVal;
1457 
1458     if (auto si = op.getLhs().getDefiningOp<SIToFPOp>()) {
1459       isUnsigned = false;
1460       intVal = si.getIn();
1461     } else if (auto ui = op.getLhs().getDefiningOp<UIToFPOp>()) {
1462       isUnsigned = true;
1463       intVal = ui.getIn();
1464     } else {
1465       return failure();
1466     }
1467 
1468     // Check to see that the input is converted from an integer type that is
1469     // small enough that preserves all bits.
1470     auto intTy = intVal.getType().cast<IntegerType>();
1471     auto intWidth = intTy.getWidth();
1472 
1473     // Number of bits representing values, as opposed to the sign
1474     auto valueBits = isUnsigned ? intWidth : (intWidth - 1);
1475 
1476     // Following test does NOT adjust intWidth downwards for signed inputs,
1477     // because the most negative value still requires all the mantissa bits
1478     // to distinguish it from one less than that value.
1479     if ((int)intWidth > mantissaWidth) {
1480       // Conversion would lose accuracy. Check if loss can impact comparison.
1481       int exponent = ilogb(rhs);
1482       if (exponent == APFloat::IEK_Inf) {
1483         int maxExponent = ilogb(APFloat::getLargest(rhs.getSemantics()));
1484         if (maxExponent < (int)valueBits) {
1485           // Conversion could create infinity.
1486           return failure();
1487         }
1488       } else {
1489         // Note that if rhs is zero or NaN, then Exp is negative
1490         // and first condition is trivially false.
1491         if (mantissaWidth <= exponent && exponent <= (int)valueBits) {
1492           // Conversion could affect comparison.
1493           return failure();
1494         }
1495       }
1496     }
1497 
1498     // Convert to equivalent cmpi predicate
1499     CmpIPredicate pred;
1500     switch (op.getPredicate()) {
1501     case CmpFPredicate::ORD:
1502       // Int to fp conversion doesn't create a nan (ord checks neither is a nan)
1503       rewriter.replaceOpWithNewOp<ConstantIntOp>(op, /*value=*/true,
1504                                                  /*width=*/1);
1505       return success();
1506     case CmpFPredicate::UNO:
1507       // Int to fp conversion doesn't create a nan (uno checks either is a nan)
1508       rewriter.replaceOpWithNewOp<ConstantIntOp>(op, /*value=*/false,
1509                                                  /*width=*/1);
1510       return success();
1511     default:
1512       pred = convertToIntegerPredicate(op.getPredicate(), isUnsigned);
1513       break;
1514     }
1515 
1516     if (!isUnsigned) {
1517       // If the rhs value is > SignedMax, fold the comparison.  This handles
1518       // +INF and large values.
1519       APFloat signedMax(rhs.getSemantics());
1520       signedMax.convertFromAPInt(APInt::getSignedMaxValue(intWidth), true,
1521                                  APFloat::rmNearestTiesToEven);
1522       if (signedMax < rhs) { // smax < 13123.0
1523         if (pred == CmpIPredicate::ne || pred == CmpIPredicate::slt ||
1524             pred == CmpIPredicate::sle)
1525           rewriter.replaceOpWithNewOp<ConstantIntOp>(op, /*value=*/true,
1526                                                      /*width=*/1);
1527         else
1528           rewriter.replaceOpWithNewOp<ConstantIntOp>(op, /*value=*/false,
1529                                                      /*width=*/1);
1530         return success();
1531       }
1532     } else {
1533       // If the rhs value is > UnsignedMax, fold the comparison. This handles
1534       // +INF and large values.
1535       APFloat unsignedMax(rhs.getSemantics());
1536       unsignedMax.convertFromAPInt(APInt::getMaxValue(intWidth), false,
1537                                    APFloat::rmNearestTiesToEven);
1538       if (unsignedMax < rhs) { // umax < 13123.0
1539         if (pred == CmpIPredicate::ne || pred == CmpIPredicate::ult ||
1540             pred == CmpIPredicate::ule)
1541           rewriter.replaceOpWithNewOp<ConstantIntOp>(op, /*value=*/true,
1542                                                      /*width=*/1);
1543         else
1544           rewriter.replaceOpWithNewOp<ConstantIntOp>(op, /*value=*/false,
1545                                                      /*width=*/1);
1546         return success();
1547       }
1548     }
1549 
1550     if (!isUnsigned) {
1551       // See if the rhs value is < SignedMin.
1552       APFloat signedMin(rhs.getSemantics());
1553       signedMin.convertFromAPInt(APInt::getSignedMinValue(intWidth), true,
1554                                  APFloat::rmNearestTiesToEven);
1555       if (signedMin > rhs) { // smin > 12312.0
1556         if (pred == CmpIPredicate::ne || pred == CmpIPredicate::sgt ||
1557             pred == CmpIPredicate::sge)
1558           rewriter.replaceOpWithNewOp<ConstantIntOp>(op, /*value=*/true,
1559                                                      /*width=*/1);
1560         else
1561           rewriter.replaceOpWithNewOp<ConstantIntOp>(op, /*value=*/false,
1562                                                      /*width=*/1);
1563         return success();
1564       }
1565     } else {
1566       // See if the rhs value is < UnsignedMin.
1567       APFloat unsignedMin(rhs.getSemantics());
1568       unsignedMin.convertFromAPInt(APInt::getMinValue(intWidth), false,
1569                                    APFloat::rmNearestTiesToEven);
1570       if (unsignedMin > rhs) { // umin > 12312.0
1571         if (pred == CmpIPredicate::ne || pred == CmpIPredicate::ugt ||
1572             pred == CmpIPredicate::uge)
1573           rewriter.replaceOpWithNewOp<ConstantIntOp>(op, /*value=*/true,
1574                                                      /*width=*/1);
1575         else
1576           rewriter.replaceOpWithNewOp<ConstantIntOp>(op, /*value=*/false,
1577                                                      /*width=*/1);
1578         return success();
1579       }
1580     }
1581 
1582     // Okay, now we know that the FP constant fits in the range [SMIN, SMAX] or
1583     // [0, UMAX], but it may still be fractional.  See if it is fractional by
1584     // casting the FP value to the integer value and back, checking for
1585     // equality. Don't do this for zero, because -0.0 is not fractional.
1586     bool ignored;
1587     APSInt rhsInt(intWidth, isUnsigned);
1588     if (APFloat::opInvalidOp ==
1589         rhs.convertToInteger(rhsInt, APFloat::rmTowardZero, &ignored)) {
1590       // Undefined behavior invoked - the destination type can't represent
1591       // the input constant.
1592       return failure();
1593     }
1594 
1595     if (!rhs.isZero()) {
1596       APFloat apf(floatTy.getFloatSemantics(),
1597                   APInt::getZero(floatTy.getWidth()));
1598       apf.convertFromAPInt(rhsInt, !isUnsigned, APFloat::rmNearestTiesToEven);
1599 
1600       bool equal = apf == rhs;
1601       if (!equal) {
1602         // If we had a comparison against a fractional value, we have to adjust
1603         // the compare predicate and sometimes the value.  rhsInt is rounded
1604         // towards zero at this point.
1605         switch (pred) {
1606         case CmpIPredicate::ne: // (float)int != 4.4   --> true
1607           rewriter.replaceOpWithNewOp<ConstantIntOp>(op, /*value=*/true,
1608                                                      /*width=*/1);
1609           return success();
1610         case CmpIPredicate::eq: // (float)int == 4.4   --> false
1611           rewriter.replaceOpWithNewOp<ConstantIntOp>(op, /*value=*/false,
1612                                                      /*width=*/1);
1613           return success();
1614         case CmpIPredicate::ule:
1615           // (float)int <= 4.4   --> int <= 4
1616           // (float)int <= -4.4  --> false
1617           if (rhs.isNegative()) {
1618             rewriter.replaceOpWithNewOp<ConstantIntOp>(op, /*value=*/false,
1619                                                        /*width=*/1);
1620             return success();
1621           }
1622           break;
1623         case CmpIPredicate::sle:
1624           // (float)int <= 4.4   --> int <= 4
1625           // (float)int <= -4.4  --> int < -4
1626           if (rhs.isNegative())
1627             pred = CmpIPredicate::slt;
1628           break;
1629         case CmpIPredicate::ult:
1630           // (float)int < -4.4   --> false
1631           // (float)int < 4.4    --> int <= 4
1632           if (rhs.isNegative()) {
1633             rewriter.replaceOpWithNewOp<ConstantIntOp>(op, /*value=*/false,
1634                                                        /*width=*/1);
1635             return success();
1636           }
1637           pred = CmpIPredicate::ule;
1638           break;
1639         case CmpIPredicate::slt:
1640           // (float)int < -4.4   --> int < -4
1641           // (float)int < 4.4    --> int <= 4
1642           if (!rhs.isNegative())
1643             pred = CmpIPredicate::sle;
1644           break;
1645         case CmpIPredicate::ugt:
1646           // (float)int > 4.4    --> int > 4
1647           // (float)int > -4.4   --> true
1648           if (rhs.isNegative()) {
1649             rewriter.replaceOpWithNewOp<ConstantIntOp>(op, /*value=*/true,
1650                                                        /*width=*/1);
1651             return success();
1652           }
1653           break;
1654         case CmpIPredicate::sgt:
1655           // (float)int > 4.4    --> int > 4
1656           // (float)int > -4.4   --> int >= -4
1657           if (rhs.isNegative())
1658             pred = CmpIPredicate::sge;
1659           break;
1660         case CmpIPredicate::uge:
1661           // (float)int >= -4.4   --> true
1662           // (float)int >= 4.4    --> int > 4
1663           if (rhs.isNegative()) {
1664             rewriter.replaceOpWithNewOp<ConstantIntOp>(op, /*value=*/true,
1665                                                        /*width=*/1);
1666             return success();
1667           }
1668           pred = CmpIPredicate::ugt;
1669           break;
1670         case CmpIPredicate::sge:
1671           // (float)int >= -4.4   --> int >= -4
1672           // (float)int >= 4.4    --> int > 4
1673           if (!rhs.isNegative())
1674             pred = CmpIPredicate::sgt;
1675           break;
1676         }
1677       }
1678     }
1679 
1680     // Lower this FP comparison into an appropriate integer version of the
1681     // comparison.
1682     rewriter.replaceOpWithNewOp<CmpIOp>(
1683         op, pred, intVal,
1684         rewriter.create<ConstantOp>(
1685             op.getLoc(), intVal.getType(),
1686             rewriter.getIntegerAttr(intVal.getType(), rhsInt)));
1687     return success();
1688   }
1689 };
1690 
1691 void arith::CmpFOp::getCanonicalizationPatterns(RewritePatternSet &patterns,
1692                                                 MLIRContext *context) {
1693   patterns.insert<CmpFIntToFPConst>(context);
1694 }
1695 
1696 //===----------------------------------------------------------------------===//
1697 // SelectOp
1698 //===----------------------------------------------------------------------===//
1699 
1700 // Transforms a select of a boolean to arithmetic operations
1701 //
1702 //  arith.select %arg, %x, %y : i1
1703 //
1704 //  becomes
1705 //
1706 //  and(%arg, %x) or and(!%arg, %y)
1707 struct SelectI1Simplify : public OpRewritePattern<arith::SelectOp> {
1708   using OpRewritePattern<arith::SelectOp>::OpRewritePattern;
1709 
1710   LogicalResult matchAndRewrite(arith::SelectOp op,
1711                                 PatternRewriter &rewriter) const override {
1712     if (!op.getType().isInteger(1))
1713       return failure();
1714 
1715     Value falseConstant =
1716         rewriter.create<arith::ConstantIntOp>(op.getLoc(), true, 1);
1717     Value notCondition = rewriter.create<arith::XOrIOp>(
1718         op.getLoc(), op.getCondition(), falseConstant);
1719 
1720     Value trueVal = rewriter.create<arith::AndIOp>(
1721         op.getLoc(), op.getCondition(), op.getTrueValue());
1722     Value falseVal = rewriter.create<arith::AndIOp>(op.getLoc(), notCondition,
1723                                                     op.getFalseValue());
1724     rewriter.replaceOpWithNewOp<arith::OrIOp>(op, trueVal, falseVal);
1725     return success();
1726   }
1727 };
1728 
1729 //  select %arg, %c1, %c0 => extui %arg
1730 struct SelectToExtUI : public OpRewritePattern<arith::SelectOp> {
1731   using OpRewritePattern<arith::SelectOp>::OpRewritePattern;
1732 
1733   LogicalResult matchAndRewrite(arith::SelectOp op,
1734                                 PatternRewriter &rewriter) const override {
1735     // Cannot extui i1 to i1, or i1 to f32
1736     if (!op.getType().isa<IntegerType>() || op.getType().isInteger(1))
1737       return failure();
1738 
1739     // select %x, c1, %c0 => extui %arg
1740     if (matchPattern(op.getTrueValue(), m_One()))
1741       if (matchPattern(op.getFalseValue(), m_Zero())) {
1742         rewriter.replaceOpWithNewOp<arith::ExtUIOp>(op, op.getType(),
1743                                                     op.getCondition());
1744         return success();
1745       }
1746 
1747     // select %x, c0, %c1 => extui (xor %arg, true)
1748     if (matchPattern(op.getTrueValue(), m_Zero()))
1749       if (matchPattern(op.getFalseValue(), m_One())) {
1750         rewriter.replaceOpWithNewOp<arith::ExtUIOp>(
1751             op, op.getType(),
1752             rewriter.create<arith::XOrIOp>(
1753                 op.getLoc(), op.getCondition(),
1754                 rewriter.create<arith::ConstantIntOp>(
1755                     op.getLoc(), 1, op.getCondition().getType())));
1756         return success();
1757       }
1758 
1759     return failure();
1760   }
1761 };
1762 
1763 void arith::SelectOp::getCanonicalizationPatterns(RewritePatternSet &results,
1764                                                   MLIRContext *context) {
1765   results.add<SelectI1Simplify, SelectToExtUI>(context);
1766 }
1767 
1768 OpFoldResult arith::SelectOp::fold(ArrayRef<Attribute> operands) {
1769   Value trueVal = getTrueValue();
1770   Value falseVal = getFalseValue();
1771   if (trueVal == falseVal)
1772     return trueVal;
1773 
1774   Value condition = getCondition();
1775 
1776   // select true, %0, %1 => %0
1777   if (matchPattern(condition, m_One()))
1778     return trueVal;
1779 
1780   // select false, %0, %1 => %1
1781   if (matchPattern(condition, m_Zero()))
1782     return falseVal;
1783 
1784   // select %x, true, false => %x
1785   if (getType().isInteger(1))
1786     if (matchPattern(getTrueValue(), m_One()))
1787       if (matchPattern(getFalseValue(), m_Zero()))
1788         return condition;
1789 
1790   if (auto cmp = dyn_cast_or_null<arith::CmpIOp>(condition.getDefiningOp())) {
1791     auto pred = cmp.getPredicate();
1792     if (pred == arith::CmpIPredicate::eq || pred == arith::CmpIPredicate::ne) {
1793       auto cmpLhs = cmp.getLhs();
1794       auto cmpRhs = cmp.getRhs();
1795 
1796       // %0 = arith.cmpi eq, %arg0, %arg1
1797       // %1 = arith.select %0, %arg0, %arg1 => %arg1
1798 
1799       // %0 = arith.cmpi ne, %arg0, %arg1
1800       // %1 = arith.select %0, %arg0, %arg1 => %arg0
1801 
1802       if ((cmpLhs == trueVal && cmpRhs == falseVal) ||
1803           (cmpRhs == trueVal && cmpLhs == falseVal))
1804         return pred == arith::CmpIPredicate::ne ? trueVal : falseVal;
1805     }
1806   }
1807   return nullptr;
1808 }
1809 
1810 ParseResult SelectOp::parse(OpAsmParser &parser, OperationState &result) {
1811   Type conditionType, resultType;
1812   SmallVector<OpAsmParser::UnresolvedOperand, 3> operands;
1813   if (parser.parseOperandList(operands, /*requiredOperandCount=*/3) ||
1814       parser.parseOptionalAttrDict(result.attributes) ||
1815       parser.parseColonType(resultType))
1816     return failure();
1817 
1818   // Check for the explicit condition type if this is a masked tensor or vector.
1819   if (succeeded(parser.parseOptionalComma())) {
1820     conditionType = resultType;
1821     if (parser.parseType(resultType))
1822       return failure();
1823   } else {
1824     conditionType = parser.getBuilder().getI1Type();
1825   }
1826 
1827   result.addTypes(resultType);
1828   return parser.resolveOperands(operands,
1829                                 {conditionType, resultType, resultType},
1830                                 parser.getNameLoc(), result.operands);
1831 }
1832 
1833 void arith::SelectOp::print(OpAsmPrinter &p) {
1834   p << " " << getOperands();
1835   p.printOptionalAttrDict((*this)->getAttrs());
1836   p << " : ";
1837   if (ShapedType condType = getCondition().getType().dyn_cast<ShapedType>())
1838     p << condType << ", ";
1839   p << getType();
1840 }
1841 
1842 LogicalResult arith::SelectOp::verify() {
1843   Type conditionType = getCondition().getType();
1844   if (conditionType.isSignlessInteger(1))
1845     return success();
1846 
1847   // If the result type is a vector or tensor, the type can be a mask with the
1848   // same elements.
1849   Type resultType = getType();
1850   if (!resultType.isa<TensorType, VectorType>())
1851     return emitOpError() << "expected condition to be a signless i1, but got "
1852                          << conditionType;
1853   Type shapedConditionType = getI1SameShape(resultType);
1854   if (conditionType != shapedConditionType) {
1855     return emitOpError() << "expected condition type to have the same shape "
1856                             "as the result type, expected "
1857                          << shapedConditionType << ", but got "
1858                          << conditionType;
1859   }
1860   return success();
1861 }
1862 //===----------------------------------------------------------------------===//
1863 // ShLIOp
1864 //===----------------------------------------------------------------------===//
1865 
1866 OpFoldResult arith::ShLIOp::fold(ArrayRef<Attribute> operands) {
1867   // Don't fold if shifting more than the bit width.
1868   bool bounded = false;
1869   auto result = constFoldBinaryOp<IntegerAttr>(
1870       operands, [&](const APInt &a, const APInt &b) {
1871         bounded = b.ule(b.getBitWidth());
1872         return a.shl(b);
1873       });
1874   return bounded ? result : Attribute();
1875 }
1876 
1877 //===----------------------------------------------------------------------===//
1878 // ShRUIOp
1879 //===----------------------------------------------------------------------===//
1880 
1881 OpFoldResult arith::ShRUIOp::fold(ArrayRef<Attribute> operands) {
1882   // Don't fold if shifting more than the bit width.
1883   bool bounded = false;
1884   auto result = constFoldBinaryOp<IntegerAttr>(
1885       operands, [&](const APInt &a, const APInt &b) {
1886         bounded = b.ule(b.getBitWidth());
1887         return a.lshr(b);
1888       });
1889   return bounded ? result : Attribute();
1890 }
1891 
1892 //===----------------------------------------------------------------------===//
1893 // ShRSIOp
1894 //===----------------------------------------------------------------------===//
1895 
1896 OpFoldResult arith::ShRSIOp::fold(ArrayRef<Attribute> operands) {
1897   // Don't fold if shifting more than the bit width.
1898   bool bounded = false;
1899   auto result = constFoldBinaryOp<IntegerAttr>(
1900       operands, [&](const APInt &a, const APInt &b) {
1901         bounded = b.ule(b.getBitWidth());
1902         return a.ashr(b);
1903       });
1904   return bounded ? result : Attribute();
1905 }
1906 
1907 //===----------------------------------------------------------------------===//
1908 // Atomic Enum
1909 //===----------------------------------------------------------------------===//
1910 
1911 /// Returns the identity value attribute associated with an AtomicRMWKind op.
1912 Attribute mlir::arith::getIdentityValueAttr(AtomicRMWKind kind, Type resultType,
1913                                             OpBuilder &builder, Location loc) {
1914   switch (kind) {
1915   case AtomicRMWKind::maxf:
1916     return builder.getFloatAttr(
1917         resultType,
1918         APFloat::getInf(resultType.cast<FloatType>().getFloatSemantics(),
1919                         /*Negative=*/true));
1920   case AtomicRMWKind::addf:
1921   case AtomicRMWKind::addi:
1922   case AtomicRMWKind::maxu:
1923   case AtomicRMWKind::ori:
1924     return builder.getZeroAttr(resultType);
1925   case AtomicRMWKind::andi:
1926     return builder.getIntegerAttr(
1927         resultType,
1928         APInt::getAllOnes(resultType.cast<IntegerType>().getWidth()));
1929   case AtomicRMWKind::maxs:
1930     return builder.getIntegerAttr(
1931         resultType,
1932         APInt::getSignedMinValue(resultType.cast<IntegerType>().getWidth()));
1933   case AtomicRMWKind::minf:
1934     return builder.getFloatAttr(
1935         resultType,
1936         APFloat::getInf(resultType.cast<FloatType>().getFloatSemantics(),
1937                         /*Negative=*/false));
1938   case AtomicRMWKind::mins:
1939     return builder.getIntegerAttr(
1940         resultType,
1941         APInt::getSignedMaxValue(resultType.cast<IntegerType>().getWidth()));
1942   case AtomicRMWKind::minu:
1943     return builder.getIntegerAttr(
1944         resultType,
1945         APInt::getMaxValue(resultType.cast<IntegerType>().getWidth()));
1946   case AtomicRMWKind::muli:
1947     return builder.getIntegerAttr(resultType, 1);
1948   case AtomicRMWKind::mulf:
1949     return builder.getFloatAttr(resultType, 1);
1950   // TODO: Add remaining reduction operations.
1951   default:
1952     (void)emitOptionalError(loc, "Reduction operation type not supported");
1953     break;
1954   }
1955   return nullptr;
1956 }
1957 
1958 /// Returns the identity value associated with an AtomicRMWKind op.
1959 Value mlir::arith::getIdentityValue(AtomicRMWKind op, Type resultType,
1960                                     OpBuilder &builder, Location loc) {
1961   Attribute attr = getIdentityValueAttr(op, resultType, builder, loc);
1962   return builder.create<arith::ConstantOp>(loc, attr);
1963 }
1964 
1965 /// Return the value obtained by applying the reduction operation kind
1966 /// associated with a binary AtomicRMWKind op to `lhs` and `rhs`.
1967 Value mlir::arith::getReductionOp(AtomicRMWKind op, OpBuilder &builder,
1968                                   Location loc, Value lhs, Value rhs) {
1969   switch (op) {
1970   case AtomicRMWKind::addf:
1971     return builder.create<arith::AddFOp>(loc, lhs, rhs);
1972   case AtomicRMWKind::addi:
1973     return builder.create<arith::AddIOp>(loc, lhs, rhs);
1974   case AtomicRMWKind::mulf:
1975     return builder.create<arith::MulFOp>(loc, lhs, rhs);
1976   case AtomicRMWKind::muli:
1977     return builder.create<arith::MulIOp>(loc, lhs, rhs);
1978   case AtomicRMWKind::maxf:
1979     return builder.create<arith::MaxFOp>(loc, lhs, rhs);
1980   case AtomicRMWKind::minf:
1981     return builder.create<arith::MinFOp>(loc, lhs, rhs);
1982   case AtomicRMWKind::maxs:
1983     return builder.create<arith::MaxSIOp>(loc, lhs, rhs);
1984   case AtomicRMWKind::mins:
1985     return builder.create<arith::MinSIOp>(loc, lhs, rhs);
1986   case AtomicRMWKind::maxu:
1987     return builder.create<arith::MaxUIOp>(loc, lhs, rhs);
1988   case AtomicRMWKind::minu:
1989     return builder.create<arith::MinUIOp>(loc, lhs, rhs);
1990   case AtomicRMWKind::ori:
1991     return builder.create<arith::OrIOp>(loc, lhs, rhs);
1992   case AtomicRMWKind::andi:
1993     return builder.create<arith::AndIOp>(loc, lhs, rhs);
1994   // TODO: Add remaining reduction operations.
1995   default:
1996     (void)emitOptionalError(loc, "Reduction operation type not supported");
1997     break;
1998   }
1999   return nullptr;
2000 }
2001 
2002 //===----------------------------------------------------------------------===//
2003 // TableGen'd op method definitions
2004 //===----------------------------------------------------------------------===//
2005 
2006 #define GET_OP_CLASSES
2007 #include "mlir/Dialect/Arithmetic/IR/ArithmeticOps.cpp.inc"
2008 
2009 //===----------------------------------------------------------------------===//
2010 // TableGen'd enum attribute definitions
2011 //===----------------------------------------------------------------------===//
2012 
2013 #include "mlir/Dialect/Arithmetic/IR/ArithmeticOpsEnums.cpp.inc"
2014