1 //===- ArithmeticOps.cpp - MLIR Arithmetic dialect ops implementation -----===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 
9 #include <utility>
10 
11 #include "mlir/Dialect/Arithmetic/IR/Arithmetic.h"
12 #include "mlir/Dialect/CommonFolders.h"
13 #include "mlir/IR/Builders.h"
14 #include "mlir/IR/Matchers.h"
15 #include "mlir/IR/OpImplementation.h"
16 #include "mlir/IR/PatternMatch.h"
17 #include "mlir/IR/TypeUtilities.h"
18 #include "llvm/ADT/SmallString.h"
19 
20 #include "llvm/ADT/APSInt.h"
21 
22 using namespace mlir;
23 using namespace mlir::arith;
24 
25 //===----------------------------------------------------------------------===//
26 // Pattern helpers
27 //===----------------------------------------------------------------------===//
28 
29 static IntegerAttr addIntegerAttrs(PatternRewriter &builder, Value res,
30                                    Attribute lhs, Attribute rhs) {
31   return builder.getIntegerAttr(res.getType(),
32                                 lhs.cast<IntegerAttr>().getInt() +
33                                     rhs.cast<IntegerAttr>().getInt());
34 }
35 
36 static IntegerAttr subIntegerAttrs(PatternRewriter &builder, Value res,
37                                    Attribute lhs, Attribute rhs) {
38   return builder.getIntegerAttr(res.getType(),
39                                 lhs.cast<IntegerAttr>().getInt() -
40                                     rhs.cast<IntegerAttr>().getInt());
41 }
42 
43 /// Invert an integer comparison predicate.
44 arith::CmpIPredicate arith::invertPredicate(arith::CmpIPredicate pred) {
45   switch (pred) {
46   case arith::CmpIPredicate::eq:
47     return arith::CmpIPredicate::ne;
48   case arith::CmpIPredicate::ne:
49     return arith::CmpIPredicate::eq;
50   case arith::CmpIPredicate::slt:
51     return arith::CmpIPredicate::sge;
52   case arith::CmpIPredicate::sle:
53     return arith::CmpIPredicate::sgt;
54   case arith::CmpIPredicate::sgt:
55     return arith::CmpIPredicate::sle;
56   case arith::CmpIPredicate::sge:
57     return arith::CmpIPredicate::slt;
58   case arith::CmpIPredicate::ult:
59     return arith::CmpIPredicate::uge;
60   case arith::CmpIPredicate::ule:
61     return arith::CmpIPredicate::ugt;
62   case arith::CmpIPredicate::ugt:
63     return arith::CmpIPredicate::ule;
64   case arith::CmpIPredicate::uge:
65     return arith::CmpIPredicate::ult;
66   }
67   llvm_unreachable("unknown cmpi predicate kind");
68 }
69 
70 static arith::CmpIPredicateAttr invertPredicate(arith::CmpIPredicateAttr pred) {
71   return arith::CmpIPredicateAttr::get(pred.getContext(),
72                                        invertPredicate(pred.getValue()));
73 }
74 
75 //===----------------------------------------------------------------------===//
76 // TableGen'd canonicalization patterns
77 //===----------------------------------------------------------------------===//
78 
79 namespace {
80 #include "ArithmeticCanonicalization.inc"
81 } // namespace
82 
83 //===----------------------------------------------------------------------===//
84 // ConstantOp
85 //===----------------------------------------------------------------------===//
86 
87 void arith::ConstantOp::getAsmResultNames(
88     function_ref<void(Value, StringRef)> setNameFn) {
89   auto type = getType();
90   if (auto intCst = getValue().dyn_cast<IntegerAttr>()) {
91     auto intType = type.dyn_cast<IntegerType>();
92 
93     // Sugar i1 constants with 'true' and 'false'.
94     if (intType && intType.getWidth() == 1)
95       return setNameFn(getResult(), (intCst.getInt() ? "true" : "false"));
96 
97     // Otherwise, build a compex name with the value and type.
98     SmallString<32> specialNameBuffer;
99     llvm::raw_svector_ostream specialName(specialNameBuffer);
100     specialName << 'c' << intCst.getInt();
101     if (intType)
102       specialName << '_' << type;
103     setNameFn(getResult(), specialName.str());
104   } else {
105     setNameFn(getResult(), "cst");
106   }
107 }
108 
109 /// TODO: disallow arith.constant to return anything other than signless integer
110 /// or float like.
111 LogicalResult arith::ConstantOp::verify() {
112   auto type = getType();
113   // The value's type must match the return type.
114   if (getValue().getType() != type) {
115     return emitOpError() << "value type " << getValue().getType()
116                          << " must match return type: " << type;
117   }
118   // Integer values must be signless.
119   if (type.isa<IntegerType>() && !type.cast<IntegerType>().isSignless())
120     return emitOpError("integer return type must be signless");
121   // Any float or elements attribute are acceptable.
122   if (!getValue().isa<IntegerAttr, FloatAttr, ElementsAttr>()) {
123     return emitOpError(
124         "value must be an integer, float, or elements attribute");
125   }
126   return success();
127 }
128 
129 bool arith::ConstantOp::isBuildableWith(Attribute value, Type type) {
130   // The value's type must be the same as the provided type.
131   if (value.getType() != type)
132     return false;
133   // Integer values must be signless.
134   if (type.isa<IntegerType>() && !type.cast<IntegerType>().isSignless())
135     return false;
136   // Integer, float, and element attributes are buildable.
137   return value.isa<IntegerAttr, FloatAttr, ElementsAttr>();
138 }
139 
140 OpFoldResult arith::ConstantOp::fold(ArrayRef<Attribute> operands) {
141   return getValue();
142 }
143 
144 void arith::ConstantIntOp::build(OpBuilder &builder, OperationState &result,
145                                  int64_t value, unsigned width) {
146   auto type = builder.getIntegerType(width);
147   arith::ConstantOp::build(builder, result, type,
148                            builder.getIntegerAttr(type, value));
149 }
150 
151 void arith::ConstantIntOp::build(OpBuilder &builder, OperationState &result,
152                                  int64_t value, Type type) {
153   assert(type.isSignlessInteger() &&
154          "ConstantIntOp can only have signless integer type values");
155   arith::ConstantOp::build(builder, result, type,
156                            builder.getIntegerAttr(type, value));
157 }
158 
159 bool arith::ConstantIntOp::classof(Operation *op) {
160   if (auto constOp = dyn_cast_or_null<arith::ConstantOp>(op))
161     return constOp.getType().isSignlessInteger();
162   return false;
163 }
164 
165 void arith::ConstantFloatOp::build(OpBuilder &builder, OperationState &result,
166                                    const APFloat &value, FloatType type) {
167   arith::ConstantOp::build(builder, result, type,
168                            builder.getFloatAttr(type, value));
169 }
170 
171 bool arith::ConstantFloatOp::classof(Operation *op) {
172   if (auto constOp = dyn_cast_or_null<arith::ConstantOp>(op))
173     return constOp.getType().isa<FloatType>();
174   return false;
175 }
176 
177 void arith::ConstantIndexOp::build(OpBuilder &builder, OperationState &result,
178                                    int64_t value) {
179   arith::ConstantOp::build(builder, result, builder.getIndexType(),
180                            builder.getIndexAttr(value));
181 }
182 
183 bool arith::ConstantIndexOp::classof(Operation *op) {
184   if (auto constOp = dyn_cast_or_null<arith::ConstantOp>(op))
185     return constOp.getType().isIndex();
186   return false;
187 }
188 
189 //===----------------------------------------------------------------------===//
190 // AddIOp
191 //===----------------------------------------------------------------------===//
192 
193 OpFoldResult arith::AddIOp::fold(ArrayRef<Attribute> operands) {
194   // addi(x, 0) -> x
195   if (matchPattern(getRhs(), m_Zero()))
196     return getLhs();
197 
198   // addi(subi(a, b), b) -> a
199   if (auto sub = getLhs().getDefiningOp<SubIOp>())
200     if (getRhs() == sub.getRhs())
201       return sub.getLhs();
202 
203   // addi(b, subi(a, b)) -> a
204   if (auto sub = getRhs().getDefiningOp<SubIOp>())
205     if (getLhs() == sub.getRhs())
206       return sub.getLhs();
207 
208   return constFoldBinaryOp<IntegerAttr>(
209       operands, [](APInt a, const APInt &b) { return std::move(a) + b; });
210 }
211 
212 void arith::AddIOp::getCanonicalizationPatterns(
213     RewritePatternSet &patterns, MLIRContext *context) {
214   patterns.add<AddIAddConstant, AddISubConstantRHS, AddISubConstantLHS>(
215       context);
216 }
217 
218 //===----------------------------------------------------------------------===//
219 // SubIOp
220 //===----------------------------------------------------------------------===//
221 
222 OpFoldResult arith::SubIOp::fold(ArrayRef<Attribute> operands) {
223   // subi(x,x) -> 0
224   if (getOperand(0) == getOperand(1))
225     return Builder(getContext()).getZeroAttr(getType());
226   // subi(x,0) -> x
227   if (matchPattern(getRhs(), m_Zero()))
228     return getLhs();
229 
230   return constFoldBinaryOp<IntegerAttr>(
231       operands, [](APInt a, const APInt &b) { return std::move(a) - b; });
232 }
233 
234 void arith::SubIOp::getCanonicalizationPatterns(
235     RewritePatternSet &patterns, MLIRContext *context) {
236   patterns
237       .add<SubIRHSAddConstant, SubILHSAddConstant, SubIRHSSubConstantRHS,
238            SubIRHSSubConstantLHS, SubILHSSubConstantRHS, SubILHSSubConstantLHS>(
239           context);
240 }
241 
242 //===----------------------------------------------------------------------===//
243 // MulIOp
244 //===----------------------------------------------------------------------===//
245 
246 OpFoldResult arith::MulIOp::fold(ArrayRef<Attribute> operands) {
247   // muli(x, 0) -> 0
248   if (matchPattern(getRhs(), m_Zero()))
249     return getRhs();
250   // muli(x, 1) -> x
251   if (matchPattern(getRhs(), m_One()))
252     return getOperand(0);
253   // TODO: Handle the overflow case.
254 
255   // default folder
256   return constFoldBinaryOp<IntegerAttr>(
257       operands, [](const APInt &a, const APInt &b) { return a * b; });
258 }
259 
260 //===----------------------------------------------------------------------===//
261 // DivUIOp
262 //===----------------------------------------------------------------------===//
263 
264 OpFoldResult arith::DivUIOp::fold(ArrayRef<Attribute> operands) {
265   // Don't fold if it would require a division by zero.
266   bool div0 = false;
267   auto result =
268       constFoldBinaryOp<IntegerAttr>(operands, [&](APInt a, const APInt &b) {
269         if (div0 || !b) {
270           div0 = true;
271           return a;
272         }
273         return a.udiv(b);
274       });
275 
276   // Fold out division by one. Assumes all tensors of all ones are splats.
277   if (auto rhs = operands[1].dyn_cast_or_null<IntegerAttr>()) {
278     if (rhs.getValue() == 1)
279       return getLhs();
280   } else if (auto rhs = operands[1].dyn_cast_or_null<SplatElementsAttr>()) {
281     if (rhs.getSplatValue<IntegerAttr>().getValue() == 1)
282       return getLhs();
283   }
284 
285   return div0 ? Attribute() : result;
286 }
287 
288 //===----------------------------------------------------------------------===//
289 // DivSIOp
290 //===----------------------------------------------------------------------===//
291 
292 OpFoldResult arith::DivSIOp::fold(ArrayRef<Attribute> operands) {
293   // Don't fold if it would overflow or if it requires a division by zero.
294   bool overflowOrDiv0 = false;
295   auto result =
296       constFoldBinaryOp<IntegerAttr>(operands, [&](APInt a, const APInt &b) {
297         if (overflowOrDiv0 || !b) {
298           overflowOrDiv0 = true;
299           return a;
300         }
301         return a.sdiv_ov(b, overflowOrDiv0);
302       });
303 
304   // Fold out division by one. Assumes all tensors of all ones are splats.
305   if (auto rhs = operands[1].dyn_cast_or_null<IntegerAttr>()) {
306     if (rhs.getValue() == 1)
307       return getLhs();
308   } else if (auto rhs = operands[1].dyn_cast_or_null<SplatElementsAttr>()) {
309     if (rhs.getSplatValue<IntegerAttr>().getValue() == 1)
310       return getLhs();
311   }
312 
313   return overflowOrDiv0 ? Attribute() : result;
314 }
315 
316 //===----------------------------------------------------------------------===//
317 // Ceil and floor division folding helpers
318 //===----------------------------------------------------------------------===//
319 
320 static APInt signedCeilNonnegInputs(const APInt &a, const APInt &b,
321                                     bool &overflow) {
322   // Returns (a-1)/b + 1
323   APInt one(a.getBitWidth(), 1, true); // Signed value 1.
324   APInt val = a.ssub_ov(one, overflow).sdiv_ov(b, overflow);
325   return val.sadd_ov(one, overflow);
326 }
327 
328 //===----------------------------------------------------------------------===//
329 // CeilDivUIOp
330 //===----------------------------------------------------------------------===//
331 
332 OpFoldResult arith::CeilDivUIOp::fold(ArrayRef<Attribute> operands) {
333   bool overflowOrDiv0 = false;
334   auto result =
335       constFoldBinaryOp<IntegerAttr>(operands, [&](APInt a, const APInt &b) {
336         if (overflowOrDiv0 || !b) {
337           overflowOrDiv0 = true;
338           return a;
339         }
340         APInt quotient = a.udiv(b);
341         if (!a.urem(b))
342           return quotient;
343         APInt one(a.getBitWidth(), 1, true);
344         return quotient.uadd_ov(one, overflowOrDiv0);
345       });
346   // Fold out ceil division by one. Assumes all tensors of all ones are
347   // splats.
348   if (auto rhs = operands[1].dyn_cast_or_null<IntegerAttr>()) {
349     if (rhs.getValue() == 1)
350       return getLhs();
351   } else if (auto rhs = operands[1].dyn_cast_or_null<SplatElementsAttr>()) {
352     if (rhs.getSplatValue<IntegerAttr>().getValue() == 1)
353       return getLhs();
354   }
355 
356   return overflowOrDiv0 ? Attribute() : result;
357 }
358 
359 //===----------------------------------------------------------------------===//
360 // CeilDivSIOp
361 //===----------------------------------------------------------------------===//
362 
363 OpFoldResult arith::CeilDivSIOp::fold(ArrayRef<Attribute> operands) {
364   // Don't fold if it would overflow or if it requires a division by zero.
365   bool overflowOrDiv0 = false;
366   auto result =
367       constFoldBinaryOp<IntegerAttr>(operands, [&](APInt a, const APInt &b) {
368         if (overflowOrDiv0 || !b) {
369           overflowOrDiv0 = true;
370           return a;
371         }
372         if (!a)
373           return a;
374         // After this point we know that neither a or b are zero.
375         unsigned bits = a.getBitWidth();
376         APInt zero = APInt::getZero(bits);
377         bool aGtZero = a.sgt(zero);
378         bool bGtZero = b.sgt(zero);
379         if (aGtZero && bGtZero) {
380           // Both positive, return ceil(a, b).
381           return signedCeilNonnegInputs(a, b, overflowOrDiv0);
382         }
383         if (!aGtZero && !bGtZero) {
384           // Both negative, return ceil(-a, -b).
385           APInt posA = zero.ssub_ov(a, overflowOrDiv0);
386           APInt posB = zero.ssub_ov(b, overflowOrDiv0);
387           return signedCeilNonnegInputs(posA, posB, overflowOrDiv0);
388         }
389         if (!aGtZero && bGtZero) {
390           // A is negative, b is positive, return - ( -a / b).
391           APInt posA = zero.ssub_ov(a, overflowOrDiv0);
392           APInt div = posA.sdiv_ov(b, overflowOrDiv0);
393           return zero.ssub_ov(div, overflowOrDiv0);
394         }
395         // A is positive, b is negative, return - (a / -b).
396         APInt posB = zero.ssub_ov(b, overflowOrDiv0);
397         APInt div = a.sdiv_ov(posB, overflowOrDiv0);
398         return zero.ssub_ov(div, overflowOrDiv0);
399       });
400 
401   // Fold out ceil division by one. Assumes all tensors of all ones are
402   // splats.
403   if (auto rhs = operands[1].dyn_cast_or_null<IntegerAttr>()) {
404     if (rhs.getValue() == 1)
405       return getLhs();
406   } else if (auto rhs = operands[1].dyn_cast_or_null<SplatElementsAttr>()) {
407     if (rhs.getSplatValue<IntegerAttr>().getValue() == 1)
408       return getLhs();
409   }
410 
411   return overflowOrDiv0 ? Attribute() : result;
412 }
413 
414 //===----------------------------------------------------------------------===//
415 // FloorDivSIOp
416 //===----------------------------------------------------------------------===//
417 
418 OpFoldResult arith::FloorDivSIOp::fold(ArrayRef<Attribute> operands) {
419   // Don't fold if it would overflow or if it requires a division by zero.
420   bool overflowOrDiv0 = false;
421   auto result =
422       constFoldBinaryOp<IntegerAttr>(operands, [&](APInt a, const APInt &b) {
423         if (overflowOrDiv0 || !b) {
424           overflowOrDiv0 = true;
425           return a;
426         }
427         if (!a)
428           return a;
429         // After this point we know that neither a or b are zero.
430         unsigned bits = a.getBitWidth();
431         APInt zero = APInt::getZero(bits);
432         bool aGtZero = a.sgt(zero);
433         bool bGtZero = b.sgt(zero);
434         if (aGtZero && bGtZero) {
435           // Both positive, return a / b.
436           return a.sdiv_ov(b, overflowOrDiv0);
437         }
438         if (!aGtZero && !bGtZero) {
439           // Both negative, return -a / -b.
440           APInt posA = zero.ssub_ov(a, overflowOrDiv0);
441           APInt posB = zero.ssub_ov(b, overflowOrDiv0);
442           return posA.sdiv_ov(posB, overflowOrDiv0);
443         }
444         if (!aGtZero && bGtZero) {
445           // A is negative, b is positive, return - ceil(-a, b).
446           APInt posA = zero.ssub_ov(a, overflowOrDiv0);
447           APInt ceil = signedCeilNonnegInputs(posA, b, overflowOrDiv0);
448           return zero.ssub_ov(ceil, overflowOrDiv0);
449         }
450         // A is positive, b is negative, return - ceil(a, -b).
451         APInt posB = zero.ssub_ov(b, overflowOrDiv0);
452         APInt ceil = signedCeilNonnegInputs(a, posB, overflowOrDiv0);
453         return zero.ssub_ov(ceil, overflowOrDiv0);
454       });
455 
456   // Fold out floor division by one. Assumes all tensors of all ones are
457   // splats.
458   if (auto rhs = operands[1].dyn_cast_or_null<IntegerAttr>()) {
459     if (rhs.getValue() == 1)
460       return getLhs();
461   } else if (auto rhs = operands[1].dyn_cast_or_null<SplatElementsAttr>()) {
462     if (rhs.getSplatValue<IntegerAttr>().getValue() == 1)
463       return getLhs();
464   }
465 
466   return overflowOrDiv0 ? Attribute() : result;
467 }
468 
469 //===----------------------------------------------------------------------===//
470 // RemUIOp
471 //===----------------------------------------------------------------------===//
472 
473 OpFoldResult arith::RemUIOp::fold(ArrayRef<Attribute> operands) {
474   auto rhs = operands.back().dyn_cast_or_null<IntegerAttr>();
475   if (!rhs)
476     return {};
477   auto rhsValue = rhs.getValue();
478 
479   // x % 1 = 0
480   if (rhsValue.isOneValue())
481     return IntegerAttr::get(rhs.getType(), APInt(rhsValue.getBitWidth(), 0));
482 
483   // Don't fold if it requires division by zero.
484   if (rhsValue.isNullValue())
485     return {};
486 
487   auto lhs = operands.front().dyn_cast_or_null<IntegerAttr>();
488   if (!lhs)
489     return {};
490   return IntegerAttr::get(lhs.getType(), lhs.getValue().urem(rhsValue));
491 }
492 
493 //===----------------------------------------------------------------------===//
494 // RemSIOp
495 //===----------------------------------------------------------------------===//
496 
497 OpFoldResult arith::RemSIOp::fold(ArrayRef<Attribute> operands) {
498   auto rhs = operands.back().dyn_cast_or_null<IntegerAttr>();
499   if (!rhs)
500     return {};
501   auto rhsValue = rhs.getValue();
502 
503   // x % 1 = 0
504   if (rhsValue.isOneValue())
505     return IntegerAttr::get(rhs.getType(), APInt(rhsValue.getBitWidth(), 0));
506 
507   // Don't fold if it requires division by zero.
508   if (rhsValue.isNullValue())
509     return {};
510 
511   auto lhs = operands.front().dyn_cast_or_null<IntegerAttr>();
512   if (!lhs)
513     return {};
514   return IntegerAttr::get(lhs.getType(), lhs.getValue().srem(rhsValue));
515 }
516 
517 //===----------------------------------------------------------------------===//
518 // AndIOp
519 //===----------------------------------------------------------------------===//
520 
521 OpFoldResult arith::AndIOp::fold(ArrayRef<Attribute> operands) {
522   /// and(x, 0) -> 0
523   if (matchPattern(getRhs(), m_Zero()))
524     return getRhs();
525   /// and(x, allOnes) -> x
526   APInt intValue;
527   if (matchPattern(getRhs(), m_ConstantInt(&intValue)) && intValue.isAllOnes())
528     return getLhs();
529 
530   return constFoldBinaryOp<IntegerAttr>(
531       operands, [](APInt a, const APInt &b) { return std::move(a) & b; });
532 }
533 
534 //===----------------------------------------------------------------------===//
535 // OrIOp
536 //===----------------------------------------------------------------------===//
537 
538 OpFoldResult arith::OrIOp::fold(ArrayRef<Attribute> operands) {
539   /// or(x, 0) -> x
540   if (matchPattern(getRhs(), m_Zero()))
541     return getLhs();
542   /// or(x, <all ones>) -> <all ones>
543   if (auto rhsAttr = operands[1].dyn_cast_or_null<IntegerAttr>())
544     if (rhsAttr.getValue().isAllOnes())
545       return rhsAttr;
546 
547   return constFoldBinaryOp<IntegerAttr>(
548       operands, [](APInt a, const APInt &b) { return std::move(a) | b; });
549 }
550 
551 //===----------------------------------------------------------------------===//
552 // XOrIOp
553 //===----------------------------------------------------------------------===//
554 
555 OpFoldResult arith::XOrIOp::fold(ArrayRef<Attribute> operands) {
556   /// xor(x, 0) -> x
557   if (matchPattern(getRhs(), m_Zero()))
558     return getLhs();
559   /// xor(x, x) -> 0
560   if (getLhs() == getRhs())
561     return Builder(getContext()).getZeroAttr(getType());
562   /// xor(xor(x, a), a) -> x
563   if (arith::XOrIOp prev = getLhs().getDefiningOp<arith::XOrIOp>())
564     if (prev.getRhs() == getRhs())
565       return prev.getLhs();
566 
567   return constFoldBinaryOp<IntegerAttr>(
568       operands, [](APInt a, const APInt &b) { return std::move(a) ^ b; });
569 }
570 
/// Registers arith.xori's canonicalization pattern XOrINotCmpI (presumably one
/// of the TableGen'd patterns from ArithmeticCanonicalization.inc — confirm).
void arith::XOrIOp::getCanonicalizationPatterns(
    RewritePatternSet &patterns, MLIRContext *context) {
  patterns.add<XOrINotCmpI>(context);
}
575 
576 //===----------------------------------------------------------------------===//
577 // NegFOp
578 //===----------------------------------------------------------------------===//
579 
580 OpFoldResult arith::NegFOp::fold(ArrayRef<Attribute> operands) {
581   return constFoldUnaryOp<FloatAttr>(operands,
582                                      [](const APFloat &a) { return -a; });
583 }
584 
585 //===----------------------------------------------------------------------===//
586 // AddFOp
587 //===----------------------------------------------------------------------===//
588 
589 OpFoldResult arith::AddFOp::fold(ArrayRef<Attribute> operands) {
590   // addf(x, -0) -> x
591   if (matchPattern(getRhs(), m_NegZeroFloat()))
592     return getLhs();
593 
594   return constFoldBinaryOp<FloatAttr>(
595       operands, [](const APFloat &a, const APFloat &b) { return a + b; });
596 }
597 
598 //===----------------------------------------------------------------------===//
599 // SubFOp
600 //===----------------------------------------------------------------------===//
601 
602 OpFoldResult arith::SubFOp::fold(ArrayRef<Attribute> operands) {
603   // subf(x, +0) -> x
604   if (matchPattern(getRhs(), m_PosZeroFloat()))
605     return getLhs();
606 
607   return constFoldBinaryOp<FloatAttr>(
608       operands, [](const APFloat &a, const APFloat &b) { return a - b; });
609 }
610 
611 //===----------------------------------------------------------------------===//
612 // MaxFOp
613 //===----------------------------------------------------------------------===//
614 
615 OpFoldResult arith::MaxFOp::fold(ArrayRef<Attribute> operands) {
616   assert(operands.size() == 2 && "maxf takes two operands");
617 
618   // maxf(x,x) -> x
619   if (getLhs() == getRhs())
620     return getRhs();
621 
622   // maxf(x, -inf) -> x
623   if (matchPattern(getRhs(), m_NegInfFloat()))
624     return getLhs();
625 
626   return constFoldBinaryOp<FloatAttr>(
627       operands,
628       [](const APFloat &a, const APFloat &b) { return llvm::maximum(a, b); });
629 }
630 
631 //===----------------------------------------------------------------------===//
632 // MaxSIOp
633 //===----------------------------------------------------------------------===//
634 
635 OpFoldResult MaxSIOp::fold(ArrayRef<Attribute> operands) {
636   assert(operands.size() == 2 && "binary operation takes two operands");
637 
638   // maxsi(x,x) -> x
639   if (getLhs() == getRhs())
640     return getRhs();
641 
642   APInt intValue;
643   // maxsi(x,MAX_INT) -> MAX_INT
644   if (matchPattern(getRhs(), m_ConstantInt(&intValue)) &&
645       intValue.isMaxSignedValue())
646     return getRhs();
647 
648   // maxsi(x, MIN_INT) -> x
649   if (matchPattern(getRhs(), m_ConstantInt(&intValue)) &&
650       intValue.isMinSignedValue())
651     return getLhs();
652 
653   return constFoldBinaryOp<IntegerAttr>(operands,
654                                         [](const APInt &a, const APInt &b) {
655                                           return llvm::APIntOps::smax(a, b);
656                                         });
657 }
658 
659 //===----------------------------------------------------------------------===//
660 // MaxUIOp
661 //===----------------------------------------------------------------------===//
662 
663 OpFoldResult MaxUIOp::fold(ArrayRef<Attribute> operands) {
664   assert(operands.size() == 2 && "binary operation takes two operands");
665 
666   // maxui(x,x) -> x
667   if (getLhs() == getRhs())
668     return getRhs();
669 
670   APInt intValue;
671   // maxui(x,MAX_INT) -> MAX_INT
672   if (matchPattern(getRhs(), m_ConstantInt(&intValue)) && intValue.isMaxValue())
673     return getRhs();
674 
675   // maxui(x, MIN_INT) -> x
676   if (matchPattern(getRhs(), m_ConstantInt(&intValue)) && intValue.isMinValue())
677     return getLhs();
678 
679   return constFoldBinaryOp<IntegerAttr>(operands,
680                                         [](const APInt &a, const APInt &b) {
681                                           return llvm::APIntOps::umax(a, b);
682                                         });
683 }
684 
685 //===----------------------------------------------------------------------===//
686 // MinFOp
687 //===----------------------------------------------------------------------===//
688 
689 OpFoldResult arith::MinFOp::fold(ArrayRef<Attribute> operands) {
690   assert(operands.size() == 2 && "minf takes two operands");
691 
692   // minf(x,x) -> x
693   if (getLhs() == getRhs())
694     return getRhs();
695 
696   // minf(x, +inf) -> x
697   if (matchPattern(getRhs(), m_PosInfFloat()))
698     return getLhs();
699 
700   return constFoldBinaryOp<FloatAttr>(
701       operands,
702       [](const APFloat &a, const APFloat &b) { return llvm::minimum(a, b); });
703 }
704 
705 //===----------------------------------------------------------------------===//
706 // MinSIOp
707 //===----------------------------------------------------------------------===//
708 
709 OpFoldResult MinSIOp::fold(ArrayRef<Attribute> operands) {
710   assert(operands.size() == 2 && "binary operation takes two operands");
711 
712   // minsi(x,x) -> x
713   if (getLhs() == getRhs())
714     return getRhs();
715 
716   APInt intValue;
717   // minsi(x,MIN_INT) -> MIN_INT
718   if (matchPattern(getRhs(), m_ConstantInt(&intValue)) &&
719       intValue.isMinSignedValue())
720     return getRhs();
721 
722   // minsi(x, MAX_INT) -> x
723   if (matchPattern(getRhs(), m_ConstantInt(&intValue)) &&
724       intValue.isMaxSignedValue())
725     return getLhs();
726 
727   return constFoldBinaryOp<IntegerAttr>(operands,
728                                         [](const APInt &a, const APInt &b) {
729                                           return llvm::APIntOps::smin(a, b);
730                                         });
731 }
732 
733 //===----------------------------------------------------------------------===//
734 // MinUIOp
735 //===----------------------------------------------------------------------===//
736 
737 OpFoldResult MinUIOp::fold(ArrayRef<Attribute> operands) {
738   assert(operands.size() == 2 && "binary operation takes two operands");
739 
740   // minui(x,x) -> x
741   if (getLhs() == getRhs())
742     return getRhs();
743 
744   APInt intValue;
745   // minui(x,MIN_INT) -> MIN_INT
746   if (matchPattern(getRhs(), m_ConstantInt(&intValue)) && intValue.isMinValue())
747     return getRhs();
748 
749   // minui(x, MAX_INT) -> x
750   if (matchPattern(getRhs(), m_ConstantInt(&intValue)) && intValue.isMaxValue())
751     return getLhs();
752 
753   return constFoldBinaryOp<IntegerAttr>(operands,
754                                         [](const APInt &a, const APInt &b) {
755                                           return llvm::APIntOps::umin(a, b);
756                                         });
757 }
758 
759 //===----------------------------------------------------------------------===//
760 // MulFOp
761 //===----------------------------------------------------------------------===//
762 
763 OpFoldResult arith::MulFOp::fold(ArrayRef<Attribute> operands) {
764   // mulf(x, 1) -> x
765   if (matchPattern(getRhs(), m_OneFloat()))
766     return getLhs();
767 
768   return constFoldBinaryOp<FloatAttr>(
769       operands, [](const APFloat &a, const APFloat &b) { return a * b; });
770 }
771 
772 //===----------------------------------------------------------------------===//
773 // DivFOp
774 //===----------------------------------------------------------------------===//
775 
776 OpFoldResult arith::DivFOp::fold(ArrayRef<Attribute> operands) {
777   // divf(x, 1) -> x
778   if (matchPattern(getRhs(), m_OneFloat()))
779     return getLhs();
780 
781   return constFoldBinaryOp<FloatAttr>(
782       operands, [](const APFloat &a, const APFloat &b) { return a / b; });
783 }
784 
785 //===----------------------------------------------------------------------===//
786 // Utility functions for verifying cast ops
787 //===----------------------------------------------------------------------===//
788 
/// Tag type used to pass a pack of types as an ordinary function argument (a
/// null pointer to a std::tuple of those types). It lets a function such as
/// getUnderlyingType take two independent type packs.
template <typename... Types>
using type_list = std::tuple<Types...> *;
791 
792 /// Returns a non-null type only if the provided type is one of the allowed
793 /// types or one of the allowed shaped types of the allowed types. Returns the
794 /// element type if a valid shaped type is provided.
795 template <typename... ShapedTypes, typename... ElementTypes>
796 static Type getUnderlyingType(Type type, type_list<ShapedTypes...>,
797                               type_list<ElementTypes...>) {
798   if (type.isa<ShapedType>() && !type.isa<ShapedTypes...>())
799     return {};
800 
801   auto underlyingType = getElementTypeOrSelf(type);
802   if (!underlyingType.isa<ElementTypes...>())
803     return {};
804 
805   return underlyingType;
806 }
807 
808 /// Get allowed underlying types for vectors and tensors.
809 template <typename... ElementTypes>
810 static Type getTypeIfLike(Type type) {
811   return getUnderlyingType(type, type_list<VectorType, TensorType>(),
812                            type_list<ElementTypes...>());
813 }
814 
815 /// Get allowed underlying types for vectors, tensors, and memrefs.
816 template <typename... ElementTypes>
817 static Type getTypeIfLikeOrMemRef(Type type) {
818   return getUnderlyingType(type,
819                            type_list<VectorType, TensorType, MemRefType>(),
820                            type_list<ElementTypes...>());
821 }
822 
823 static bool areValidCastInputsAndOutputs(TypeRange inputs, TypeRange outputs) {
824   return inputs.size() == 1 && outputs.size() == 1 &&
825          succeeded(verifyCompatibleShapes(inputs.front(), outputs.front()));
826 }
827 
828 //===----------------------------------------------------------------------===//
829 // Verifiers for integer and floating point extension/truncation ops
830 //===----------------------------------------------------------------------===//
831 
832 // Extend ops can only extend to a wider type.
833 template <typename ValType, typename Op>
834 static LogicalResult verifyExtOp(Op op) {
835   Type srcType = getElementTypeOrSelf(op.getIn().getType());
836   Type dstType = getElementTypeOrSelf(op.getType());
837 
838   if (srcType.cast<ValType>().getWidth() >= dstType.cast<ValType>().getWidth())
839     return op.emitError("result type ")
840            << dstType << " must be wider than operand type " << srcType;
841 
842   return success();
843 }
844 
845 // Truncate ops can only truncate to a shorter type.
846 template <typename ValType, typename Op>
847 static LogicalResult verifyTruncateOp(Op op) {
848   Type srcType = getElementTypeOrSelf(op.getIn().getType());
849   Type dstType = getElementTypeOrSelf(op.getType());
850 
851   if (srcType.cast<ValType>().getWidth() <= dstType.cast<ValType>().getWidth())
852     return op.emitError("result type ")
853            << dstType << " must be shorter than operand type " << srcType;
854 
855   return success();
856 }
857 
858 /// Validate a cast that changes the width of a type.
859 template <template <typename> class WidthComparator, typename... ElementTypes>
860 static bool checkWidthChangeCast(TypeRange inputs, TypeRange outputs) {
861   if (!areValidCastInputsAndOutputs(inputs, outputs))
862     return false;
863 
864   auto srcType = getTypeIfLike<ElementTypes...>(inputs.front());
865   auto dstType = getTypeIfLike<ElementTypes...>(outputs.front());
866   if (!srcType || !dstType)
867     return false;
868 
869   return WidthComparator<unsigned>()(dstType.getIntOrFloatBitWidth(),
870                                      srcType.getIntOrFloatBitWidth());
871 }
872 
873 //===----------------------------------------------------------------------===//
874 // ExtUIOp
875 //===----------------------------------------------------------------------===//
876 
877 OpFoldResult arith::ExtUIOp::fold(ArrayRef<Attribute> operands) {
878   if (auto lhs = getIn().getDefiningOp<ExtUIOp>()) {
879     getInMutable().assign(lhs.getIn());
880     return getResult();
881   }
882   Type resType = getType();
883   unsigned bitWidth;
884   if (auto shapedType = resType.dyn_cast<ShapedType>())
885     bitWidth = shapedType.getElementTypeBitWidth();
886   else
887     bitWidth = resType.getIntOrFloatBitWidth();
888   return constFoldCastOp<IntegerAttr, IntegerAttr>(
889       operands, getType(), [bitWidth](const APInt &a, bool &castStatus) {
890         return a.zext(bitWidth);
891       });
892 }
893 
894 bool arith::ExtUIOp::areCastCompatible(TypeRange inputs, TypeRange outputs) {
895   return checkWidthChangeCast<std::greater, IntegerType>(inputs, outputs);
896 }
897 
898 LogicalResult arith::ExtUIOp::verify() {
899   return verifyExtOp<IntegerType>(*this);
900 }
901 
902 //===----------------------------------------------------------------------===//
903 // ExtSIOp
904 //===----------------------------------------------------------------------===//
905 
906 OpFoldResult arith::ExtSIOp::fold(ArrayRef<Attribute> operands) {
907   if (auto lhs = getIn().getDefiningOp<ExtSIOp>()) {
908     getInMutable().assign(lhs.getIn());
909     return getResult();
910   }
911   Type resType = getType();
912   unsigned bitWidth;
913   if (auto shapedType = resType.dyn_cast<ShapedType>())
914     bitWidth = shapedType.getElementTypeBitWidth();
915   else
916     bitWidth = resType.getIntOrFloatBitWidth();
917   return constFoldCastOp<IntegerAttr, IntegerAttr>(
918       operands, getType(), [bitWidth](const APInt &a, bool &castStatus) {
919         return a.sext(bitWidth);
920       });
921 }
922 
923 bool arith::ExtSIOp::areCastCompatible(TypeRange inputs, TypeRange outputs) {
924   return checkWidthChangeCast<std::greater, IntegerType>(inputs, outputs);
925 }
926 
927 void arith::ExtSIOp::getCanonicalizationPatterns(
928     RewritePatternSet &patterns, MLIRContext *context) {
929   patterns.add<ExtSIOfExtUI>(context);
930 }
931 
932 LogicalResult arith::ExtSIOp::verify() {
933   return verifyExtOp<IntegerType>(*this);
934 }
935 
936 //===----------------------------------------------------------------------===//
937 // ExtFOp
938 //===----------------------------------------------------------------------===//
939 
940 bool arith::ExtFOp::areCastCompatible(TypeRange inputs, TypeRange outputs) {
941   return checkWidthChangeCast<std::greater, FloatType>(inputs, outputs);
942 }
943 
944 LogicalResult arith::ExtFOp::verify() { return verifyExtOp<FloatType>(*this); }
945 
946 //===----------------------------------------------------------------------===//
947 // TruncIOp
948 //===----------------------------------------------------------------------===//
949 
950 OpFoldResult arith::TruncIOp::fold(ArrayRef<Attribute> operands) {
951   assert(operands.size() == 1 && "unary operation takes one operand");
952 
953   // trunci(zexti(a)) -> a
954   // trunci(sexti(a)) -> a
955   if (matchPattern(getOperand(), m_Op<arith::ExtUIOp>()) ||
956       matchPattern(getOperand(), m_Op<arith::ExtSIOp>()))
957     return getOperand().getDefiningOp()->getOperand(0);
958 
959   // trunci(trunci(a)) -> trunci(a))
960   if (matchPattern(getOperand(), m_Op<arith::TruncIOp>())) {
961     setOperand(getOperand().getDefiningOp()->getOperand(0));
962     return getResult();
963   }
964 
965   Type resType = getType();
966   unsigned bitWidth;
967   if (auto shapedType = resType.dyn_cast<ShapedType>())
968     bitWidth = shapedType.getElementTypeBitWidth();
969   else
970     bitWidth = resType.getIntOrFloatBitWidth();
971 
972   return constFoldCastOp<IntegerAttr, IntegerAttr>(
973       operands, getType(), [bitWidth](const APInt &a, bool &castStatus) {
974         return a.trunc(bitWidth);
975       });
976 }
977 
978 bool arith::TruncIOp::areCastCompatible(TypeRange inputs, TypeRange outputs) {
979   return checkWidthChangeCast<std::less, IntegerType>(inputs, outputs);
980 }
981 
982 LogicalResult arith::TruncIOp::verify() {
983   return verifyTruncateOp<IntegerType>(*this);
984 }
985 
986 //===----------------------------------------------------------------------===//
987 // TruncFOp
988 //===----------------------------------------------------------------------===//
989 
990 /// Perform safe const propagation for truncf, i.e. only propagate if FP value
991 /// can be represented without precision loss or rounding.
992 OpFoldResult arith::TruncFOp::fold(ArrayRef<Attribute> operands) {
993   assert(operands.size() == 1 && "unary operation takes one operand");
994 
995   auto constOperand = operands.front();
996   if (!constOperand || !constOperand.isa<FloatAttr>())
997     return {};
998 
999   // Convert to target type via 'double'.
1000   double sourceValue =
1001       constOperand.dyn_cast<FloatAttr>().getValue().convertToDouble();
1002   auto targetAttr = FloatAttr::get(getType(), sourceValue);
1003 
1004   // Propagate if constant's value does not change after truncation.
1005   if (sourceValue == targetAttr.getValue().convertToDouble())
1006     return targetAttr;
1007 
1008   return {};
1009 }
1010 
1011 bool arith::TruncFOp::areCastCompatible(TypeRange inputs, TypeRange outputs) {
1012   return checkWidthChangeCast<std::less, FloatType>(inputs, outputs);
1013 }
1014 
1015 LogicalResult arith::TruncFOp::verify() {
1016   return verifyTruncateOp<FloatType>(*this);
1017 }
1018 
1019 //===----------------------------------------------------------------------===//
1020 // AndIOp
1021 //===----------------------------------------------------------------------===//
1022 
1023 void arith::AndIOp::getCanonicalizationPatterns(
1024     RewritePatternSet &patterns, MLIRContext *context) {
1025   patterns.add<AndOfExtUI, AndOfExtSI>(context);
1026 }
1027 
1028 //===----------------------------------------------------------------------===//
1029 // OrIOp
1030 //===----------------------------------------------------------------------===//
1031 
1032 void arith::OrIOp::getCanonicalizationPatterns(
1033     RewritePatternSet &patterns, MLIRContext *context) {
1034   patterns.add<OrOfExtUI, OrOfExtSI>(context);
1035 }
1036 
1037 //===----------------------------------------------------------------------===//
1038 // Verifiers for casts between integers and floats.
1039 //===----------------------------------------------------------------------===//
1040 
1041 template <typename From, typename To>
1042 static bool checkIntFloatCast(TypeRange inputs, TypeRange outputs) {
1043   if (!areValidCastInputsAndOutputs(inputs, outputs))
1044     return false;
1045 
1046   auto srcType = getTypeIfLike<From>(inputs.front());
1047   auto dstType = getTypeIfLike<To>(outputs.back());
1048 
1049   return srcType && dstType;
1050 }
1051 
1052 //===----------------------------------------------------------------------===//
1053 // UIToFPOp
1054 //===----------------------------------------------------------------------===//
1055 
1056 bool arith::UIToFPOp::areCastCompatible(TypeRange inputs, TypeRange outputs) {
1057   return checkIntFloatCast<IntegerType, FloatType>(inputs, outputs);
1058 }
1059 
1060 OpFoldResult arith::UIToFPOp::fold(ArrayRef<Attribute> operands) {
1061   Type resType = getType();
1062   Type resEleType;
1063   if (auto shapedType = resType.dyn_cast<ShapedType>())
1064     resEleType = shapedType.getElementType();
1065   else
1066     resEleType = resType;
1067   return constFoldCastOp<IntegerAttr, FloatAttr>(
1068       operands, getType(), [&resEleType](const APInt &a, bool &castStatus) {
1069         FloatType floatTy = resEleType.cast<FloatType>();
1070         APFloat apf(floatTy.getFloatSemantics(),
1071                     APInt::getZero(floatTy.getWidth()));
1072         apf.convertFromAPInt(a, /*IsSigned=*/false,
1073                              APFloat::rmNearestTiesToEven);
1074         return apf;
1075       });
1076 }
1077 
1078 //===----------------------------------------------------------------------===//
1079 // SIToFPOp
1080 //===----------------------------------------------------------------------===//
1081 
1082 bool arith::SIToFPOp::areCastCompatible(TypeRange inputs, TypeRange outputs) {
1083   return checkIntFloatCast<IntegerType, FloatType>(inputs, outputs);
1084 }
1085 
1086 OpFoldResult arith::SIToFPOp::fold(ArrayRef<Attribute> operands) {
1087   Type resType = getType();
1088   Type resEleType;
1089   if (auto shapedType = resType.dyn_cast<ShapedType>())
1090     resEleType = shapedType.getElementType();
1091   else
1092     resEleType = resType;
1093   return constFoldCastOp<IntegerAttr, FloatAttr>(
1094       operands, getType(), [&resEleType](const APInt &a, bool &castStatus) {
1095         FloatType floatTy = resEleType.cast<FloatType>();
1096         APFloat apf(floatTy.getFloatSemantics(),
1097                     APInt::getZero(floatTy.getWidth()));
1098         apf.convertFromAPInt(a, /*IsSigned=*/true,
1099                              APFloat::rmNearestTiesToEven);
1100         return apf;
1101       });
1102 }
1103 //===----------------------------------------------------------------------===//
1104 // FPToUIOp
1105 //===----------------------------------------------------------------------===//
1106 
1107 bool arith::FPToUIOp::areCastCompatible(TypeRange inputs, TypeRange outputs) {
1108   return checkIntFloatCast<FloatType, IntegerType>(inputs, outputs);
1109 }
1110 
1111 OpFoldResult arith::FPToUIOp::fold(ArrayRef<Attribute> operands) {
1112   Type resType = getType();
1113   Type resEleType;
1114   if (auto shapedType = resType.dyn_cast<ShapedType>())
1115     resEleType = shapedType.getElementType();
1116   else
1117     resEleType = resType;
1118   return constFoldCastOp<FloatAttr, IntegerAttr>(
1119       operands, getType(), [&resEleType](const APFloat &a, bool &castStatus) {
1120         IntegerType intTy = resEleType.cast<IntegerType>();
1121         bool ignored;
1122         APSInt api(intTy.getWidth(), /*isUnsigned=*/true);
1123         castStatus = APFloat::opInvalidOp !=
1124                      a.convertToInteger(api, APFloat::rmTowardZero, &ignored);
1125         return api;
1126       });
1127 }
1128 
1129 //===----------------------------------------------------------------------===//
1130 // FPToSIOp
1131 //===----------------------------------------------------------------------===//
1132 
1133 bool arith::FPToSIOp::areCastCompatible(TypeRange inputs, TypeRange outputs) {
1134   return checkIntFloatCast<FloatType, IntegerType>(inputs, outputs);
1135 }
1136 
1137 OpFoldResult arith::FPToSIOp::fold(ArrayRef<Attribute> operands) {
1138   Type resType = getType();
1139   Type resEleType;
1140   if (auto shapedType = resType.dyn_cast<ShapedType>())
1141     resEleType = shapedType.getElementType();
1142   else
1143     resEleType = resType;
1144   return constFoldCastOp<FloatAttr, IntegerAttr>(
1145       operands, getType(), [&resEleType](const APFloat &a, bool &castStatus) {
1146         IntegerType intTy = resEleType.cast<IntegerType>();
1147         bool ignored;
1148         APSInt api(intTy.getWidth(), /*isUnsigned=*/false);
1149         castStatus = APFloat::opInvalidOp !=
1150                      a.convertToInteger(api, APFloat::rmTowardZero, &ignored);
1151         return api;
1152       });
1153 }
1154 
1155 //===----------------------------------------------------------------------===//
1156 // IndexCastOp
1157 //===----------------------------------------------------------------------===//
1158 
1159 bool arith::IndexCastOp::areCastCompatible(TypeRange inputs,
1160                                            TypeRange outputs) {
1161   if (!areValidCastInputsAndOutputs(inputs, outputs))
1162     return false;
1163 
1164   auto srcType = getTypeIfLikeOrMemRef<IntegerType, IndexType>(inputs.front());
1165   auto dstType = getTypeIfLikeOrMemRef<IntegerType, IndexType>(outputs.front());
1166   if (!srcType || !dstType)
1167     return false;
1168 
1169   return (srcType.isIndex() && dstType.isSignlessInteger()) ||
1170          (srcType.isSignlessInteger() && dstType.isIndex());
1171 }
1172 
1173 OpFoldResult arith::IndexCastOp::fold(ArrayRef<Attribute> operands) {
1174   // index_cast(constant) -> constant
1175   // A little hack because we go through int. Otherwise, the size of the
1176   // constant might need to change.
1177   if (auto value = operands[0].dyn_cast_or_null<IntegerAttr>())
1178     return IntegerAttr::get(getType(), value.getInt());
1179 
1180   return {};
1181 }
1182 
1183 void arith::IndexCastOp::getCanonicalizationPatterns(
1184     RewritePatternSet &patterns, MLIRContext *context) {
1185   patterns.add<IndexCastOfIndexCast, IndexCastOfExtSI>(context);
1186 }
1187 
1188 //===----------------------------------------------------------------------===//
1189 // BitcastOp
1190 //===----------------------------------------------------------------------===//
1191 
1192 bool arith::BitcastOp::areCastCompatible(TypeRange inputs, TypeRange outputs) {
1193   if (!areValidCastInputsAndOutputs(inputs, outputs))
1194     return false;
1195 
1196   auto srcType =
1197       getTypeIfLikeOrMemRef<IntegerType, IndexType, FloatType>(inputs.front());
1198   auto dstType =
1199       getTypeIfLikeOrMemRef<IntegerType, IndexType, FloatType>(outputs.front());
1200   if (!srcType || !dstType)
1201     return false;
1202 
1203   return srcType.getIntOrFloatBitWidth() == dstType.getIntOrFloatBitWidth();
1204 }
1205 
1206 OpFoldResult arith::BitcastOp::fold(ArrayRef<Attribute> operands) {
1207   assert(operands.size() == 1 && "bitcast op expects 1 operand");
1208 
1209   auto resType = getType();
1210   auto operand = operands[0];
1211   if (!operand)
1212     return {};
1213 
1214   /// Bitcast dense elements.
1215   if (auto denseAttr = operand.dyn_cast_or_null<DenseElementsAttr>())
1216     return denseAttr.bitcast(resType.cast<ShapedType>().getElementType());
1217   /// Other shaped types unhandled.
1218   if (resType.isa<ShapedType>())
1219     return {};
1220 
1221   /// Bitcast integer or float to integer or float.
1222   APInt bits = operand.isa<FloatAttr>()
1223                    ? operand.cast<FloatAttr>().getValue().bitcastToAPInt()
1224                    : operand.cast<IntegerAttr>().getValue();
1225 
1226   if (auto resFloatType = resType.dyn_cast<FloatType>())
1227     return FloatAttr::get(resType,
1228                           APFloat(resFloatType.getFloatSemantics(), bits));
1229   return IntegerAttr::get(resType, bits);
1230 }
1231 
1232 void arith::BitcastOp::getCanonicalizationPatterns(
1233     RewritePatternSet &patterns, MLIRContext *context) {
1234   patterns.add<BitcastOfBitcast>(context);
1235 }
1236 
1237 //===----------------------------------------------------------------------===//
1238 // Helpers for compare ops
1239 //===----------------------------------------------------------------------===//
1240 
1241 /// Return the type of the same shape (scalar, vector or tensor) containing i1.
1242 static Type getI1SameShape(Type type) {
1243   auto i1Type = IntegerType::get(type.getContext(), 1);
1244   if (auto tensorType = type.dyn_cast<RankedTensorType>())
1245     return RankedTensorType::get(tensorType.getShape(), i1Type);
1246   if (type.isa<UnrankedTensorType>())
1247     return UnrankedTensorType::get(i1Type);
1248   if (auto vectorType = type.dyn_cast<VectorType>())
1249     return VectorType::get(vectorType.getShape(), i1Type,
1250                            vectorType.getNumScalableDims());
1251   return i1Type;
1252 }
1253 
1254 //===----------------------------------------------------------------------===//
1255 // CmpIOp
1256 //===----------------------------------------------------------------------===//
1257 
1258 /// Compute `lhs` `pred` `rhs`, where `pred` is one of the known integer
1259 /// comparison predicates.
1260 bool mlir::arith::applyCmpPredicate(arith::CmpIPredicate predicate,
1261                                     const APInt &lhs, const APInt &rhs) {
1262   switch (predicate) {
1263   case arith::CmpIPredicate::eq:
1264     return lhs.eq(rhs);
1265   case arith::CmpIPredicate::ne:
1266     return lhs.ne(rhs);
1267   case arith::CmpIPredicate::slt:
1268     return lhs.slt(rhs);
1269   case arith::CmpIPredicate::sle:
1270     return lhs.sle(rhs);
1271   case arith::CmpIPredicate::sgt:
1272     return lhs.sgt(rhs);
1273   case arith::CmpIPredicate::sge:
1274     return lhs.sge(rhs);
1275   case arith::CmpIPredicate::ult:
1276     return lhs.ult(rhs);
1277   case arith::CmpIPredicate::ule:
1278     return lhs.ule(rhs);
1279   case arith::CmpIPredicate::ugt:
1280     return lhs.ugt(rhs);
1281   case arith::CmpIPredicate::uge:
1282     return lhs.uge(rhs);
1283   }
1284   llvm_unreachable("unknown cmpi predicate kind");
1285 }
1286 
1287 /// Returns true if the predicate is true for two equal operands.
1288 static bool applyCmpPredicateToEqualOperands(arith::CmpIPredicate predicate) {
1289   switch (predicate) {
1290   case arith::CmpIPredicate::eq:
1291   case arith::CmpIPredicate::sle:
1292   case arith::CmpIPredicate::sge:
1293   case arith::CmpIPredicate::ule:
1294   case arith::CmpIPredicate::uge:
1295     return true;
1296   case arith::CmpIPredicate::ne:
1297   case arith::CmpIPredicate::slt:
1298   case arith::CmpIPredicate::sgt:
1299   case arith::CmpIPredicate::ult:
1300   case arith::CmpIPredicate::ugt:
1301     return false;
1302   }
1303   llvm_unreachable("unknown cmpi predicate kind");
1304 }
1305 
1306 static Attribute getBoolAttribute(Type type, MLIRContext *ctx, bool value) {
1307   auto boolAttr = BoolAttr::get(ctx, value);
1308   ShapedType shapedType = type.dyn_cast_or_null<ShapedType>();
1309   if (!shapedType)
1310     return boolAttr;
1311   return DenseElementsAttr::get(shapedType, boolAttr);
1312 }
1313 
1314 OpFoldResult arith::CmpIOp::fold(ArrayRef<Attribute> operands) {
1315   assert(operands.size() == 2 && "cmpi takes two operands");
1316 
1317   // cmpi(pred, x, x)
1318   if (getLhs() == getRhs()) {
1319     auto val = applyCmpPredicateToEqualOperands(getPredicate());
1320     return getBoolAttribute(getType(), getContext(), val);
1321   }
1322 
1323   if (matchPattern(getRhs(), m_Zero())) {
1324     if (auto extOp = getLhs().getDefiningOp<ExtSIOp>()) {
1325       if (extOp.getOperand().getType().cast<IntegerType>().getWidth() == 1) {
1326         // extsi(%x : i1 -> iN) != 0  ->  %x
1327         if (getPredicate() == arith::CmpIPredicate::ne) {
1328           return extOp.getOperand();
1329         }
1330       }
1331     }
1332     if (auto extOp = getLhs().getDefiningOp<ExtUIOp>()) {
1333       if (extOp.getOperand().getType().cast<IntegerType>().getWidth() == 1) {
1334         // extui(%x : i1 -> iN) != 0  ->  %x
1335         if (getPredicate() == arith::CmpIPredicate::ne) {
1336           return extOp.getOperand();
1337         }
1338       }
1339     }
1340   }
1341 
1342   auto lhs = operands.front().dyn_cast_or_null<IntegerAttr>();
1343   auto rhs = operands.back().dyn_cast_or_null<IntegerAttr>();
1344   if (!lhs || !rhs)
1345     return {};
1346 
1347   auto val = applyCmpPredicate(getPredicate(), lhs.getValue(), rhs.getValue());
1348   return BoolAttr::get(getContext(), val);
1349 }
1350 
1351 void arith::CmpIOp::getCanonicalizationPatterns(RewritePatternSet &patterns,
1352                                                 MLIRContext *context) {
1353   patterns.insert<CmpIExtSI, CmpIExtUI>(context);
1354 }
1355 
1356 //===----------------------------------------------------------------------===//
1357 // CmpFOp
1358 //===----------------------------------------------------------------------===//
1359 
1360 /// Compute `lhs` `pred` `rhs`, where `pred` is one of the known floating point
1361 /// comparison predicates.
1362 bool mlir::arith::applyCmpPredicate(arith::CmpFPredicate predicate,
1363                                     const APFloat &lhs, const APFloat &rhs) {
1364   auto cmpResult = lhs.compare(rhs);
1365   switch (predicate) {
1366   case arith::CmpFPredicate::AlwaysFalse:
1367     return false;
1368   case arith::CmpFPredicate::OEQ:
1369     return cmpResult == APFloat::cmpEqual;
1370   case arith::CmpFPredicate::OGT:
1371     return cmpResult == APFloat::cmpGreaterThan;
1372   case arith::CmpFPredicate::OGE:
1373     return cmpResult == APFloat::cmpGreaterThan ||
1374            cmpResult == APFloat::cmpEqual;
1375   case arith::CmpFPredicate::OLT:
1376     return cmpResult == APFloat::cmpLessThan;
1377   case arith::CmpFPredicate::OLE:
1378     return cmpResult == APFloat::cmpLessThan || cmpResult == APFloat::cmpEqual;
1379   case arith::CmpFPredicate::ONE:
1380     return cmpResult != APFloat::cmpUnordered && cmpResult != APFloat::cmpEqual;
1381   case arith::CmpFPredicate::ORD:
1382     return cmpResult != APFloat::cmpUnordered;
1383   case arith::CmpFPredicate::UEQ:
1384     return cmpResult == APFloat::cmpUnordered || cmpResult == APFloat::cmpEqual;
1385   case arith::CmpFPredicate::UGT:
1386     return cmpResult == APFloat::cmpUnordered ||
1387            cmpResult == APFloat::cmpGreaterThan;
1388   case arith::CmpFPredicate::UGE:
1389     return cmpResult == APFloat::cmpUnordered ||
1390            cmpResult == APFloat::cmpGreaterThan ||
1391            cmpResult == APFloat::cmpEqual;
1392   case arith::CmpFPredicate::ULT:
1393     return cmpResult == APFloat::cmpUnordered ||
1394            cmpResult == APFloat::cmpLessThan;
1395   case arith::CmpFPredicate::ULE:
1396     return cmpResult == APFloat::cmpUnordered ||
1397            cmpResult == APFloat::cmpLessThan || cmpResult == APFloat::cmpEqual;
1398   case arith::CmpFPredicate::UNE:
1399     return cmpResult != APFloat::cmpEqual;
1400   case arith::CmpFPredicate::UNO:
1401     return cmpResult == APFloat::cmpUnordered;
1402   case arith::CmpFPredicate::AlwaysTrue:
1403     return true;
1404   }
1405   llvm_unreachable("unknown cmpf predicate kind");
1406 }
1407 
1408 OpFoldResult arith::CmpFOp::fold(ArrayRef<Attribute> operands) {
1409   assert(operands.size() == 2 && "cmpf takes two operands");
1410 
1411   auto lhs = operands.front().dyn_cast_or_null<FloatAttr>();
1412   auto rhs = operands.back().dyn_cast_or_null<FloatAttr>();
1413 
1414   // If one operand is NaN, making them both NaN does not change the result.
1415   if (lhs && lhs.getValue().isNaN())
1416     rhs = lhs;
1417   if (rhs && rhs.getValue().isNaN())
1418     lhs = rhs;
1419 
1420   if (!lhs || !rhs)
1421     return {};
1422 
1423   auto val = applyCmpPredicate(getPredicate(), lhs.getValue(), rhs.getValue());
1424   return BoolAttr::get(getContext(), val);
1425 }
1426 
1427 class CmpFIntToFPConst final : public OpRewritePattern<CmpFOp> {
1428 public:
1429   using OpRewritePattern<CmpFOp>::OpRewritePattern;
1430 
1431   static CmpIPredicate convertToIntegerPredicate(CmpFPredicate pred,
1432                                                  bool isUnsigned) {
1433     using namespace arith;
1434     switch (pred) {
1435     case CmpFPredicate::UEQ:
1436     case CmpFPredicate::OEQ:
1437       return CmpIPredicate::eq;
1438     case CmpFPredicate::UGT:
1439     case CmpFPredicate::OGT:
1440       return isUnsigned ? CmpIPredicate::ugt : CmpIPredicate::sgt;
1441     case CmpFPredicate::UGE:
1442     case CmpFPredicate::OGE:
1443       return isUnsigned ? CmpIPredicate::uge : CmpIPredicate::sge;
1444     case CmpFPredicate::ULT:
1445     case CmpFPredicate::OLT:
1446       return isUnsigned ? CmpIPredicate::ult : CmpIPredicate::slt;
1447     case CmpFPredicate::ULE:
1448     case CmpFPredicate::OLE:
1449       return isUnsigned ? CmpIPredicate::ule : CmpIPredicate::sle;
1450     case CmpFPredicate::UNE:
1451     case CmpFPredicate::ONE:
1452       return CmpIPredicate::ne;
1453     default:
1454       llvm_unreachable("Unexpected predicate!");
1455     }
1456   }
1457 
1458   LogicalResult matchAndRewrite(CmpFOp op,
1459                                 PatternRewriter &rewriter) const override {
1460     FloatAttr flt;
1461     if (!matchPattern(op.getRhs(), m_Constant(&flt)))
1462       return failure();
1463 
1464     const APFloat &rhs = flt.getValue();
1465 
1466     // Don't attempt to fold a nan.
1467     if (rhs.isNaN())
1468       return failure();
1469 
1470     // Get the width of the mantissa.  We don't want to hack on conversions that
1471     // might lose information from the integer, e.g. "i64 -> float"
1472     FloatType floatTy = op.getRhs().getType().cast<FloatType>();
1473     int mantissaWidth = floatTy.getFPMantissaWidth();
1474     if (mantissaWidth <= 0)
1475       return failure();
1476 
1477     bool isUnsigned;
1478     Value intVal;
1479 
1480     if (auto si = op.getLhs().getDefiningOp<SIToFPOp>()) {
1481       isUnsigned = false;
1482       intVal = si.getIn();
1483     } else if (auto ui = op.getLhs().getDefiningOp<UIToFPOp>()) {
1484       isUnsigned = true;
1485       intVal = ui.getIn();
1486     } else {
1487       return failure();
1488     }
1489 
1490     // Check to see that the input is converted from an integer type that is
1491     // small enough that preserves all bits.
1492     auto intTy = intVal.getType().cast<IntegerType>();
1493     auto intWidth = intTy.getWidth();
1494 
1495     // Number of bits representing values, as opposed to the sign
1496     auto valueBits = isUnsigned ? intWidth : (intWidth - 1);
1497 
1498     // Following test does NOT adjust intWidth downwards for signed inputs,
1499     // because the most negative value still requires all the mantissa bits
1500     // to distinguish it from one less than that value.
1501     if ((int)intWidth > mantissaWidth) {
1502       // Conversion would lose accuracy. Check if loss can impact comparison.
1503       int exponent = ilogb(rhs);
1504       if (exponent == APFloat::IEK_Inf) {
1505         int maxExponent = ilogb(APFloat::getLargest(rhs.getSemantics()));
1506         if (maxExponent < (int)valueBits) {
1507           // Conversion could create infinity.
1508           return failure();
1509         }
1510       } else {
1511         // Note that if rhs is zero or NaN, then Exp is negative
1512         // and first condition is trivially false.
1513         if (mantissaWidth <= exponent && exponent <= (int)valueBits) {
1514           // Conversion could affect comparison.
1515           return failure();
1516         }
1517       }
1518     }
1519 
1520     // Convert to equivalent cmpi predicate
1521     CmpIPredicate pred;
1522     switch (op.getPredicate()) {
1523     case CmpFPredicate::ORD:
1524       // Int to fp conversion doesn't create a nan (ord checks neither is a nan)
1525       rewriter.replaceOpWithNewOp<ConstantIntOp>(op, /*value=*/true,
1526                                                  /*width=*/1);
1527       return success();
1528     case CmpFPredicate::UNO:
1529       // Int to fp conversion doesn't create a nan (uno checks either is a nan)
1530       rewriter.replaceOpWithNewOp<ConstantIntOp>(op, /*value=*/false,
1531                                                  /*width=*/1);
1532       return success();
1533     default:
1534       pred = convertToIntegerPredicate(op.getPredicate(), isUnsigned);
1535       break;
1536     }
1537 
1538     if (!isUnsigned) {
1539       // If the rhs value is > SignedMax, fold the comparison.  This handles
1540       // +INF and large values.
1541       APFloat signedMax(rhs.getSemantics());
1542       signedMax.convertFromAPInt(APInt::getSignedMaxValue(intWidth), true,
1543                                  APFloat::rmNearestTiesToEven);
1544       if (signedMax < rhs) { // smax < 13123.0
1545         if (pred == CmpIPredicate::ne || pred == CmpIPredicate::slt ||
1546             pred == CmpIPredicate::sle)
1547           rewriter.replaceOpWithNewOp<ConstantIntOp>(op, /*value=*/true,
1548                                                      /*width=*/1);
1549         else
1550           rewriter.replaceOpWithNewOp<ConstantIntOp>(op, /*value=*/false,
1551                                                      /*width=*/1);
1552         return success();
1553       }
1554     } else {
1555       // If the rhs value is > UnsignedMax, fold the comparison. This handles
1556       // +INF and large values.
1557       APFloat unsignedMax(rhs.getSemantics());
1558       unsignedMax.convertFromAPInt(APInt::getMaxValue(intWidth), false,
1559                                    APFloat::rmNearestTiesToEven);
1560       if (unsignedMax < rhs) { // umax < 13123.0
1561         if (pred == CmpIPredicate::ne || pred == CmpIPredicate::ult ||
1562             pred == CmpIPredicate::ule)
1563           rewriter.replaceOpWithNewOp<ConstantIntOp>(op, /*value=*/true,
1564                                                      /*width=*/1);
1565         else
1566           rewriter.replaceOpWithNewOp<ConstantIntOp>(op, /*value=*/false,
1567                                                      /*width=*/1);
1568         return success();
1569       }
1570     }
1571 
1572     if (!isUnsigned) {
1573       // See if the rhs value is < SignedMin.
1574       APFloat signedMin(rhs.getSemantics());
1575       signedMin.convertFromAPInt(APInt::getSignedMinValue(intWidth), true,
1576                                  APFloat::rmNearestTiesToEven);
1577       if (signedMin > rhs) { // smin > 12312.0
1578         if (pred == CmpIPredicate::ne || pred == CmpIPredicate::sgt ||
1579             pred == CmpIPredicate::sge)
1580           rewriter.replaceOpWithNewOp<ConstantIntOp>(op, /*value=*/true,
1581                                                      /*width=*/1);
1582         else
1583           rewriter.replaceOpWithNewOp<ConstantIntOp>(op, /*value=*/false,
1584                                                      /*width=*/1);
1585         return success();
1586       }
1587     } else {
1588       // See if the rhs value is < UnsignedMin.
1589       APFloat unsignedMin(rhs.getSemantics());
1590       unsignedMin.convertFromAPInt(APInt::getMinValue(intWidth), false,
1591                                    APFloat::rmNearestTiesToEven);
1592       if (unsignedMin > rhs) { // umin > 12312.0
1593         if (pred == CmpIPredicate::ne || pred == CmpIPredicate::ugt ||
1594             pred == CmpIPredicate::uge)
1595           rewriter.replaceOpWithNewOp<ConstantIntOp>(op, /*value=*/true,
1596                                                      /*width=*/1);
1597         else
1598           rewriter.replaceOpWithNewOp<ConstantIntOp>(op, /*value=*/false,
1599                                                      /*width=*/1);
1600         return success();
1601       }
1602     }
1603 
1604     // Okay, now we know that the FP constant fits in the range [SMIN, SMAX] or
1605     // [0, UMAX], but it may still be fractional.  See if it is fractional by
1606     // casting the FP value to the integer value and back, checking for
1607     // equality. Don't do this for zero, because -0.0 is not fractional.
1608     bool ignored;
1609     APSInt rhsInt(intWidth, isUnsigned);
1610     if (APFloat::opInvalidOp ==
1611         rhs.convertToInteger(rhsInt, APFloat::rmTowardZero, &ignored)) {
1612       // Undefined behavior invoked - the destination type can't represent
1613       // the input constant.
1614       return failure();
1615     }
1616 
1617     if (!rhs.isZero()) {
1618       APFloat apf(floatTy.getFloatSemantics(),
1619                   APInt::getZero(floatTy.getWidth()));
1620       apf.convertFromAPInt(rhsInt, !isUnsigned, APFloat::rmNearestTiesToEven);
1621 
1622       bool equal = apf == rhs;
1623       if (!equal) {
1624         // If we had a comparison against a fractional value, we have to adjust
1625         // the compare predicate and sometimes the value.  rhsInt is rounded
1626         // towards zero at this point.
1627         switch (pred) {
1628         case CmpIPredicate::ne: // (float)int != 4.4   --> true
1629           rewriter.replaceOpWithNewOp<ConstantIntOp>(op, /*value=*/true,
1630                                                      /*width=*/1);
1631           return success();
1632         case CmpIPredicate::eq: // (float)int == 4.4   --> false
1633           rewriter.replaceOpWithNewOp<ConstantIntOp>(op, /*value=*/false,
1634                                                      /*width=*/1);
1635           return success();
1636         case CmpIPredicate::ule:
1637           // (float)int <= 4.4   --> int <= 4
1638           // (float)int <= -4.4  --> false
1639           if (rhs.isNegative()) {
1640             rewriter.replaceOpWithNewOp<ConstantIntOp>(op, /*value=*/false,
1641                                                        /*width=*/1);
1642             return success();
1643           }
1644           break;
1645         case CmpIPredicate::sle:
1646           // (float)int <= 4.4   --> int <= 4
1647           // (float)int <= -4.4  --> int < -4
1648           if (rhs.isNegative())
1649             pred = CmpIPredicate::slt;
1650           break;
1651         case CmpIPredicate::ult:
1652           // (float)int < -4.4   --> false
1653           // (float)int < 4.4    --> int <= 4
1654           if (rhs.isNegative()) {
1655             rewriter.replaceOpWithNewOp<ConstantIntOp>(op, /*value=*/false,
1656                                                        /*width=*/1);
1657             return success();
1658           }
1659           pred = CmpIPredicate::ule;
1660           break;
1661         case CmpIPredicate::slt:
1662           // (float)int < -4.4   --> int < -4
1663           // (float)int < 4.4    --> int <= 4
1664           if (!rhs.isNegative())
1665             pred = CmpIPredicate::sle;
1666           break;
1667         case CmpIPredicate::ugt:
1668           // (float)int > 4.4    --> int > 4
1669           // (float)int > -4.4   --> true
1670           if (rhs.isNegative()) {
1671             rewriter.replaceOpWithNewOp<ConstantIntOp>(op, /*value=*/true,
1672                                                        /*width=*/1);
1673             return success();
1674           }
1675           break;
1676         case CmpIPredicate::sgt:
1677           // (float)int > 4.4    --> int > 4
1678           // (float)int > -4.4   --> int >= -4
1679           if (rhs.isNegative())
1680             pred = CmpIPredicate::sge;
1681           break;
1682         case CmpIPredicate::uge:
1683           // (float)int >= -4.4   --> true
1684           // (float)int >= 4.4    --> int > 4
1685           if (rhs.isNegative()) {
1686             rewriter.replaceOpWithNewOp<ConstantIntOp>(op, /*value=*/true,
1687                                                        /*width=*/1);
1688             return success();
1689           }
1690           pred = CmpIPredicate::ugt;
1691           break;
1692         case CmpIPredicate::sge:
1693           // (float)int >= -4.4   --> int >= -4
1694           // (float)int >= 4.4    --> int > 4
1695           if (!rhs.isNegative())
1696             pred = CmpIPredicate::sgt;
1697           break;
1698         }
1699       }
1700     }
1701 
1702     // Lower this FP comparison into an appropriate integer version of the
1703     // comparison.
1704     rewriter.replaceOpWithNewOp<CmpIOp>(
1705         op, pred, intVal,
1706         rewriter.create<ConstantOp>(
1707             op.getLoc(), intVal.getType(),
1708             rewriter.getIntegerAttr(intVal.getType(), rhsInt)));
1709     return success();
1710   }
1711 };
1712 
1713 void arith::CmpFOp::getCanonicalizationPatterns(RewritePatternSet &patterns,
1714                                                 MLIRContext *context) {
1715   patterns.insert<CmpFIntToFPConst>(context);
1716 }
1717 
1718 //===----------------------------------------------------------------------===//
1719 // SelectOp
1720 //===----------------------------------------------------------------------===//
1721 
1722 // Transforms a select of a boolean to arithmetic operations
1723 //
1724 //  arith.select %arg, %x, %y : i1
1725 //
1726 //  becomes
1727 //
1728 //  and(%arg, %x) or and(!%arg, %y)
1729 struct SelectI1Simplify : public OpRewritePattern<arith::SelectOp> {
1730   using OpRewritePattern<arith::SelectOp>::OpRewritePattern;
1731 
1732   LogicalResult matchAndRewrite(arith::SelectOp op,
1733                                 PatternRewriter &rewriter) const override {
1734     if (!op.getType().isInteger(1))
1735       return failure();
1736 
1737     Value falseConstant =
1738         rewriter.create<arith::ConstantIntOp>(op.getLoc(), true, 1);
1739     Value notCondition = rewriter.create<arith::XOrIOp>(
1740         op.getLoc(), op.getCondition(), falseConstant);
1741 
1742     Value trueVal = rewriter.create<arith::AndIOp>(
1743         op.getLoc(), op.getCondition(), op.getTrueValue());
1744     Value falseVal = rewriter.create<arith::AndIOp>(op.getLoc(), notCondition,
1745                                                     op.getFalseValue());
1746     rewriter.replaceOpWithNewOp<arith::OrIOp>(op, trueVal, falseVal);
1747     return success();
1748   }
1749 };
1750 
1751 //  select %arg, %c1, %c0 => extui %arg
1752 struct SelectToExtUI : public OpRewritePattern<arith::SelectOp> {
1753   using OpRewritePattern<arith::SelectOp>::OpRewritePattern;
1754 
1755   LogicalResult matchAndRewrite(arith::SelectOp op,
1756                                 PatternRewriter &rewriter) const override {
1757     // Cannot extui i1 to i1, or i1 to f32
1758     if (!op.getType().isa<IntegerType>() || op.getType().isInteger(1))
1759       return failure();
1760 
1761     // select %x, c1, %c0 => extui %arg
1762     if (matchPattern(op.getTrueValue(), m_One()))
1763       if (matchPattern(op.getFalseValue(), m_Zero())) {
1764         rewriter.replaceOpWithNewOp<arith::ExtUIOp>(op, op.getType(),
1765                                                     op.getCondition());
1766         return success();
1767       }
1768 
1769     // select %x, c0, %c1 => extui (xor %arg, true)
1770     if (matchPattern(op.getTrueValue(), m_Zero()))
1771       if (matchPattern(op.getFalseValue(), m_One())) {
1772         rewriter.replaceOpWithNewOp<arith::ExtUIOp>(
1773             op, op.getType(),
1774             rewriter.create<arith::XOrIOp>(
1775                 op.getLoc(), op.getCondition(),
1776                 rewriter.create<arith::ConstantIntOp>(
1777                     op.getLoc(), 1, op.getCondition().getType())));
1778         return success();
1779       }
1780 
1781     return failure();
1782   }
1783 };
1784 
1785 void arith::SelectOp::getCanonicalizationPatterns(RewritePatternSet &results,
1786                                                   MLIRContext *context) {
1787   results.add<SelectI1Simplify, SelectToExtUI>(context);
1788 }
1789 
1790 OpFoldResult arith::SelectOp::fold(ArrayRef<Attribute> operands) {
1791   Value trueVal = getTrueValue();
1792   Value falseVal = getFalseValue();
1793   if (trueVal == falseVal)
1794     return trueVal;
1795 
1796   Value condition = getCondition();
1797 
1798   // select true, %0, %1 => %0
1799   if (matchPattern(condition, m_One()))
1800     return trueVal;
1801 
1802   // select false, %0, %1 => %1
1803   if (matchPattern(condition, m_Zero()))
1804     return falseVal;
1805 
1806   // select %x, true, false => %x
1807   if (getType().isInteger(1))
1808     if (matchPattern(getTrueValue(), m_One()))
1809       if (matchPattern(getFalseValue(), m_Zero()))
1810         return condition;
1811 
1812   if (auto cmp = dyn_cast_or_null<arith::CmpIOp>(condition.getDefiningOp())) {
1813     auto pred = cmp.getPredicate();
1814     if (pred == arith::CmpIPredicate::eq || pred == arith::CmpIPredicate::ne) {
1815       auto cmpLhs = cmp.getLhs();
1816       auto cmpRhs = cmp.getRhs();
1817 
1818       // %0 = arith.cmpi eq, %arg0, %arg1
1819       // %1 = arith.select %0, %arg0, %arg1 => %arg1
1820 
1821       // %0 = arith.cmpi ne, %arg0, %arg1
1822       // %1 = arith.select %0, %arg0, %arg1 => %arg0
1823 
1824       if ((cmpLhs == trueVal && cmpRhs == falseVal) ||
1825           (cmpRhs == trueVal && cmpLhs == falseVal))
1826         return pred == arith::CmpIPredicate::ne ? trueVal : falseVal;
1827     }
1828   }
1829   return nullptr;
1830 }
1831 
1832 ParseResult SelectOp::parse(OpAsmParser &parser, OperationState &result) {
1833   Type conditionType, resultType;
1834   SmallVector<OpAsmParser::UnresolvedOperand, 3> operands;
1835   if (parser.parseOperandList(operands, /*requiredOperandCount=*/3) ||
1836       parser.parseOptionalAttrDict(result.attributes) ||
1837       parser.parseColonType(resultType))
1838     return failure();
1839 
1840   // Check for the explicit condition type if this is a masked tensor or vector.
1841   if (succeeded(parser.parseOptionalComma())) {
1842     conditionType = resultType;
1843     if (parser.parseType(resultType))
1844       return failure();
1845   } else {
1846     conditionType = parser.getBuilder().getI1Type();
1847   }
1848 
1849   result.addTypes(resultType);
1850   return parser.resolveOperands(operands,
1851                                 {conditionType, resultType, resultType},
1852                                 parser.getNameLoc(), result.operands);
1853 }
1854 
1855 void arith::SelectOp::print(OpAsmPrinter &p) {
1856   p << " " << getOperands();
1857   p.printOptionalAttrDict((*this)->getAttrs());
1858   p << " : ";
1859   if (ShapedType condType = getCondition().getType().dyn_cast<ShapedType>())
1860     p << condType << ", ";
1861   p << getType();
1862 }
1863 
1864 LogicalResult arith::SelectOp::verify() {
1865   Type conditionType = getCondition().getType();
1866   if (conditionType.isSignlessInteger(1))
1867     return success();
1868 
1869   // If the result type is a vector or tensor, the type can be a mask with the
1870   // same elements.
1871   Type resultType = getType();
1872   if (!resultType.isa<TensorType, VectorType>())
1873     return emitOpError() << "expected condition to be a signless i1, but got "
1874                          << conditionType;
1875   Type shapedConditionType = getI1SameShape(resultType);
1876   if (conditionType != shapedConditionType) {
1877     return emitOpError() << "expected condition type to have the same shape "
1878                             "as the result type, expected "
1879                          << shapedConditionType << ", but got "
1880                          << conditionType;
1881   }
1882   return success();
1883 }
1884 //===----------------------------------------------------------------------===//
1885 // ShLIOp
1886 //===----------------------------------------------------------------------===//
1887 
1888 OpFoldResult arith::ShLIOp::fold(ArrayRef<Attribute> operands) {
1889   // Don't fold if shifting more than the bit width.
1890   bool bounded = false;
1891   auto result = constFoldBinaryOp<IntegerAttr>(
1892       operands, [&](const APInt &a, const APInt &b) {
1893         bounded = b.ule(b.getBitWidth());
1894         return a.shl(b);
1895       });
1896   return bounded ? result : Attribute();
1897 }
1898 
1899 //===----------------------------------------------------------------------===//
1900 // ShRUIOp
1901 //===----------------------------------------------------------------------===//
1902 
1903 OpFoldResult arith::ShRUIOp::fold(ArrayRef<Attribute> operands) {
1904   // Don't fold if shifting more than the bit width.
1905   bool bounded = false;
1906   auto result = constFoldBinaryOp<IntegerAttr>(
1907       operands, [&](const APInt &a, const APInt &b) {
1908         bounded = b.ule(b.getBitWidth());
1909         return a.lshr(b);
1910       });
1911   return bounded ? result : Attribute();
1912 }
1913 
1914 //===----------------------------------------------------------------------===//
1915 // ShRSIOp
1916 //===----------------------------------------------------------------------===//
1917 
1918 OpFoldResult arith::ShRSIOp::fold(ArrayRef<Attribute> operands) {
1919   // Don't fold if shifting more than the bit width.
1920   bool bounded = false;
1921   auto result = constFoldBinaryOp<IntegerAttr>(
1922       operands, [&](const APInt &a, const APInt &b) {
1923         bounded = b.ule(b.getBitWidth());
1924         return a.ashr(b);
1925       });
1926   return bounded ? result : Attribute();
1927 }
1928 
1929 //===----------------------------------------------------------------------===//
1930 // Atomic Enum
1931 //===----------------------------------------------------------------------===//
1932 
1933 /// Returns the identity value attribute associated with an AtomicRMWKind op.
1934 Attribute mlir::arith::getIdentityValueAttr(AtomicRMWKind kind, Type resultType,
1935                                             OpBuilder &builder, Location loc) {
1936   switch (kind) {
1937   case AtomicRMWKind::maxf:
1938     return builder.getFloatAttr(
1939         resultType,
1940         APFloat::getInf(resultType.cast<FloatType>().getFloatSemantics(),
1941                         /*Negative=*/true));
1942   case AtomicRMWKind::addf:
1943   case AtomicRMWKind::addi:
1944   case AtomicRMWKind::maxu:
1945   case AtomicRMWKind::ori:
1946     return builder.getZeroAttr(resultType);
1947   case AtomicRMWKind::andi:
1948     return builder.getIntegerAttr(
1949         resultType,
1950         APInt::getAllOnes(resultType.cast<IntegerType>().getWidth()));
1951   case AtomicRMWKind::maxs:
1952     return builder.getIntegerAttr(
1953         resultType,
1954         APInt::getSignedMinValue(resultType.cast<IntegerType>().getWidth()));
1955   case AtomicRMWKind::minf:
1956     return builder.getFloatAttr(
1957         resultType,
1958         APFloat::getInf(resultType.cast<FloatType>().getFloatSemantics(),
1959                         /*Negative=*/false));
1960   case AtomicRMWKind::mins:
1961     return builder.getIntegerAttr(
1962         resultType,
1963         APInt::getSignedMaxValue(resultType.cast<IntegerType>().getWidth()));
1964   case AtomicRMWKind::minu:
1965     return builder.getIntegerAttr(
1966         resultType,
1967         APInt::getMaxValue(resultType.cast<IntegerType>().getWidth()));
1968   case AtomicRMWKind::muli:
1969     return builder.getIntegerAttr(resultType, 1);
1970   case AtomicRMWKind::mulf:
1971     return builder.getFloatAttr(resultType, 1);
1972   // TODO: Add remaining reduction operations.
1973   default:
1974     (void)emitOptionalError(loc, "Reduction operation type not supported");
1975     break;
1976   }
1977   return nullptr;
1978 }
1979 
1980 /// Returns the identity value associated with an AtomicRMWKind op.
1981 Value mlir::arith::getIdentityValue(AtomicRMWKind op, Type resultType,
1982                                     OpBuilder &builder, Location loc) {
1983   Attribute attr = getIdentityValueAttr(op, resultType, builder, loc);
1984   return builder.create<arith::ConstantOp>(loc, attr);
1985 }
1986 
1987 /// Return the value obtained by applying the reduction operation kind
1988 /// associated with a binary AtomicRMWKind op to `lhs` and `rhs`.
1989 Value mlir::arith::getReductionOp(AtomicRMWKind op, OpBuilder &builder,
1990                                   Location loc, Value lhs, Value rhs) {
1991   switch (op) {
1992   case AtomicRMWKind::addf:
1993     return builder.create<arith::AddFOp>(loc, lhs, rhs);
1994   case AtomicRMWKind::addi:
1995     return builder.create<arith::AddIOp>(loc, lhs, rhs);
1996   case AtomicRMWKind::mulf:
1997     return builder.create<arith::MulFOp>(loc, lhs, rhs);
1998   case AtomicRMWKind::muli:
1999     return builder.create<arith::MulIOp>(loc, lhs, rhs);
2000   case AtomicRMWKind::maxf:
2001     return builder.create<arith::MaxFOp>(loc, lhs, rhs);
2002   case AtomicRMWKind::minf:
2003     return builder.create<arith::MinFOp>(loc, lhs, rhs);
2004   case AtomicRMWKind::maxs:
2005     return builder.create<arith::MaxSIOp>(loc, lhs, rhs);
2006   case AtomicRMWKind::mins:
2007     return builder.create<arith::MinSIOp>(loc, lhs, rhs);
2008   case AtomicRMWKind::maxu:
2009     return builder.create<arith::MaxUIOp>(loc, lhs, rhs);
2010   case AtomicRMWKind::minu:
2011     return builder.create<arith::MinUIOp>(loc, lhs, rhs);
2012   case AtomicRMWKind::ori:
2013     return builder.create<arith::OrIOp>(loc, lhs, rhs);
2014   case AtomicRMWKind::andi:
2015     return builder.create<arith::AndIOp>(loc, lhs, rhs);
2016   // TODO: Add remaining reduction operations.
2017   default:
2018     (void)emitOptionalError(loc, "Reduction operation type not supported");
2019     break;
2020   }
2021   return nullptr;
2022 }
2023 
2024 //===----------------------------------------------------------------------===//
2025 // TableGen'd op method definitions
2026 //===----------------------------------------------------------------------===//
2027 
2028 #define GET_OP_CLASSES
2029 #include "mlir/Dialect/Arithmetic/IR/ArithmeticOps.cpp.inc"
2030 
2031 //===----------------------------------------------------------------------===//
2032 // TableGen'd enum attribute definitions
2033 //===----------------------------------------------------------------------===//
2034 
2035 #include "mlir/Dialect/Arithmetic/IR/ArithmeticOpsEnums.cpp.inc"
2036