1 //===- ArithmeticOps.cpp - MLIR Arithmetic dialect ops implementation -----===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 
9 #include <utility>
10 
11 #include "mlir/Dialect/Arithmetic/IR/Arithmetic.h"
12 #include "mlir/Dialect/CommonFolders.h"
13 #include "mlir/IR/Builders.h"
14 #include "mlir/IR/Matchers.h"
15 #include "mlir/IR/OpImplementation.h"
16 #include "mlir/IR/PatternMatch.h"
17 #include "mlir/IR/TypeUtilities.h"
18 
19 #include "llvm/ADT/APSInt.h"
20 
21 using namespace mlir;
22 using namespace mlir::arith;
23 
24 //===----------------------------------------------------------------------===//
25 // Pattern helpers
26 //===----------------------------------------------------------------------===//
27 
28 static IntegerAttr addIntegerAttrs(PatternRewriter &builder, Value res,
29                                    Attribute lhs, Attribute rhs) {
30   return builder.getIntegerAttr(res.getType(),
31                                 lhs.cast<IntegerAttr>().getInt() +
32                                     rhs.cast<IntegerAttr>().getInt());
33 }
34 
35 static IntegerAttr subIntegerAttrs(PatternRewriter &builder, Value res,
36                                    Attribute lhs, Attribute rhs) {
37   return builder.getIntegerAttr(res.getType(),
38                                 lhs.cast<IntegerAttr>().getInt() -
39                                     rhs.cast<IntegerAttr>().getInt());
40 }
41 
42 /// Invert an integer comparison predicate.
43 arith::CmpIPredicate arith::invertPredicate(arith::CmpIPredicate pred) {
44   switch (pred) {
45   case arith::CmpIPredicate::eq:
46     return arith::CmpIPredicate::ne;
47   case arith::CmpIPredicate::ne:
48     return arith::CmpIPredicate::eq;
49   case arith::CmpIPredicate::slt:
50     return arith::CmpIPredicate::sge;
51   case arith::CmpIPredicate::sle:
52     return arith::CmpIPredicate::sgt;
53   case arith::CmpIPredicate::sgt:
54     return arith::CmpIPredicate::sle;
55   case arith::CmpIPredicate::sge:
56     return arith::CmpIPredicate::slt;
57   case arith::CmpIPredicate::ult:
58     return arith::CmpIPredicate::uge;
59   case arith::CmpIPredicate::ule:
60     return arith::CmpIPredicate::ugt;
61   case arith::CmpIPredicate::ugt:
62     return arith::CmpIPredicate::ule;
63   case arith::CmpIPredicate::uge:
64     return arith::CmpIPredicate::ult;
65   }
66   llvm_unreachable("unknown cmpi predicate kind");
67 }
68 
69 static arith::CmpIPredicateAttr invertPredicate(arith::CmpIPredicateAttr pred) {
70   return arith::CmpIPredicateAttr::get(pred.getContext(),
71                                        invertPredicate(pred.getValue()));
72 }
73 
74 //===----------------------------------------------------------------------===//
75 // TableGen'd canonicalization patterns
76 //===----------------------------------------------------------------------===//
77 
78 namespace {
79 #include "ArithmeticCanonicalization.inc"
80 } // namespace
81 
82 //===----------------------------------------------------------------------===//
83 // ConstantOp
84 //===----------------------------------------------------------------------===//
85 
86 void arith::ConstantOp::getAsmResultNames(
87     function_ref<void(Value, StringRef)> setNameFn) {
88   auto type = getType();
89   if (auto intCst = getValue().dyn_cast<IntegerAttr>()) {
90     auto intType = type.dyn_cast<IntegerType>();
91 
92     // Sugar i1 constants with 'true' and 'false'.
93     if (intType && intType.getWidth() == 1)
94       return setNameFn(getResult(), (intCst.getInt() ? "true" : "false"));
95 
96     // Otherwise, build a compex name with the value and type.
97     SmallString<32> specialNameBuffer;
98     llvm::raw_svector_ostream specialName(specialNameBuffer);
99     specialName << 'c' << intCst.getInt();
100     if (intType)
101       specialName << '_' << type;
102     setNameFn(getResult(), specialName.str());
103   } else {
104     setNameFn(getResult(), "cst");
105   }
106 }
107 
108 /// TODO: disallow arith.constant to return anything other than signless integer
109 /// or float like.
110 LogicalResult arith::ConstantOp::verify() {
111   auto type = getType();
112   // The value's type must match the return type.
113   if (getValue().getType() != type) {
114     return emitOpError() << "value type " << getValue().getType()
115                          << " must match return type: " << type;
116   }
117   // Integer values must be signless.
118   if (type.isa<IntegerType>() && !type.cast<IntegerType>().isSignless())
119     return emitOpError("integer return type must be signless");
120   // Any float or elements attribute are acceptable.
121   if (!getValue().isa<IntegerAttr, FloatAttr, ElementsAttr>()) {
122     return emitOpError(
123         "value must be an integer, float, or elements attribute");
124   }
125   return success();
126 }
127 
128 bool arith::ConstantOp::isBuildableWith(Attribute value, Type type) {
129   // The value's type must be the same as the provided type.
130   if (value.getType() != type)
131     return false;
132   // Integer values must be signless.
133   if (type.isa<IntegerType>() && !type.cast<IntegerType>().isSignless())
134     return false;
135   // Integer, float, and element attributes are buildable.
136   return value.isa<IntegerAttr, FloatAttr, ElementsAttr>();
137 }
138 
139 OpFoldResult arith::ConstantOp::fold(ArrayRef<Attribute> operands) {
140   return getValue();
141 }
142 
143 void arith::ConstantIntOp::build(OpBuilder &builder, OperationState &result,
144                                  int64_t value, unsigned width) {
145   auto type = builder.getIntegerType(width);
146   arith::ConstantOp::build(builder, result, type,
147                            builder.getIntegerAttr(type, value));
148 }
149 
150 void arith::ConstantIntOp::build(OpBuilder &builder, OperationState &result,
151                                  int64_t value, Type type) {
152   assert(type.isSignlessInteger() &&
153          "ConstantIntOp can only have signless integer type values");
154   arith::ConstantOp::build(builder, result, type,
155                            builder.getIntegerAttr(type, value));
156 }
157 
158 bool arith::ConstantIntOp::classof(Operation *op) {
159   if (auto constOp = dyn_cast_or_null<arith::ConstantOp>(op))
160     return constOp.getType().isSignlessInteger();
161   return false;
162 }
163 
164 void arith::ConstantFloatOp::build(OpBuilder &builder, OperationState &result,
165                                    const APFloat &value, FloatType type) {
166   arith::ConstantOp::build(builder, result, type,
167                            builder.getFloatAttr(type, value));
168 }
169 
170 bool arith::ConstantFloatOp::classof(Operation *op) {
171   if (auto constOp = dyn_cast_or_null<arith::ConstantOp>(op))
172     return constOp.getType().isa<FloatType>();
173   return false;
174 }
175 
176 void arith::ConstantIndexOp::build(OpBuilder &builder, OperationState &result,
177                                    int64_t value) {
178   arith::ConstantOp::build(builder, result, builder.getIndexType(),
179                            builder.getIndexAttr(value));
180 }
181 
182 bool arith::ConstantIndexOp::classof(Operation *op) {
183   if (auto constOp = dyn_cast_or_null<arith::ConstantOp>(op))
184     return constOp.getType().isIndex();
185   return false;
186 }
187 
188 //===----------------------------------------------------------------------===//
189 // AddIOp
190 //===----------------------------------------------------------------------===//
191 
192 OpFoldResult arith::AddIOp::fold(ArrayRef<Attribute> operands) {
193   // addi(x, 0) -> x
194   if (matchPattern(getRhs(), m_Zero()))
195     return getLhs();
196 
197   // addi(subi(a, b), b) -> a
198   if (auto sub = getLhs().getDefiningOp<SubIOp>())
199     if (getRhs() == sub.getRhs())
200       return sub.getLhs();
201 
202   // addi(b, subi(a, b)) -> a
203   if (auto sub = getRhs().getDefiningOp<SubIOp>())
204     if (getLhs() == sub.getRhs())
205       return sub.getLhs();
206 
207   return constFoldBinaryOp<IntegerAttr>(
208       operands, [](APInt a, const APInt &b) { return std::move(a) + b; });
209 }
210 
211 void arith::AddIOp::getCanonicalizationPatterns(
212     RewritePatternSet &patterns, MLIRContext *context) {
213   patterns.add<AddIAddConstant, AddISubConstantRHS, AddISubConstantLHS>(
214       context);
215 }
216 
217 //===----------------------------------------------------------------------===//
218 // SubIOp
219 //===----------------------------------------------------------------------===//
220 
221 OpFoldResult arith::SubIOp::fold(ArrayRef<Attribute> operands) {
222   // subi(x,x) -> 0
223   if (getOperand(0) == getOperand(1))
224     return Builder(getContext()).getZeroAttr(getType());
225   // subi(x,0) -> x
226   if (matchPattern(getRhs(), m_Zero()))
227     return getLhs();
228 
229   return constFoldBinaryOp<IntegerAttr>(
230       operands, [](APInt a, const APInt &b) { return std::move(a) - b; });
231 }
232 
233 void arith::SubIOp::getCanonicalizationPatterns(
234     RewritePatternSet &patterns, MLIRContext *context) {
235   patterns
236       .add<SubIRHSAddConstant, SubILHSAddConstant, SubIRHSSubConstantRHS,
237            SubIRHSSubConstantLHS, SubILHSSubConstantRHS, SubILHSSubConstantLHS>(
238           context);
239 }
240 
241 //===----------------------------------------------------------------------===//
242 // MulIOp
243 //===----------------------------------------------------------------------===//
244 
245 OpFoldResult arith::MulIOp::fold(ArrayRef<Attribute> operands) {
246   // muli(x, 0) -> 0
247   if (matchPattern(getRhs(), m_Zero()))
248     return getRhs();
249   // muli(x, 1) -> x
250   if (matchPattern(getRhs(), m_One()))
251     return getOperand(0);
252   // TODO: Handle the overflow case.
253 
254   // default folder
255   return constFoldBinaryOp<IntegerAttr>(
256       operands, [](const APInt &a, const APInt &b) { return a * b; });
257 }
258 
259 //===----------------------------------------------------------------------===//
260 // DivUIOp
261 //===----------------------------------------------------------------------===//
262 
263 OpFoldResult arith::DivUIOp::fold(ArrayRef<Attribute> operands) {
264   // Don't fold if it would require a division by zero.
265   bool div0 = false;
266   auto result =
267       constFoldBinaryOp<IntegerAttr>(operands, [&](APInt a, const APInt &b) {
268         if (div0 || !b) {
269           div0 = true;
270           return a;
271         }
272         return a.udiv(b);
273       });
274 
275   // Fold out division by one. Assumes all tensors of all ones are splats.
276   if (auto rhs = operands[1].dyn_cast_or_null<IntegerAttr>()) {
277     if (rhs.getValue() == 1)
278       return getLhs();
279   } else if (auto rhs = operands[1].dyn_cast_or_null<SplatElementsAttr>()) {
280     if (rhs.getSplatValue<IntegerAttr>().getValue() == 1)
281       return getLhs();
282   }
283 
284   return div0 ? Attribute() : result;
285 }
286 
287 //===----------------------------------------------------------------------===//
288 // DivSIOp
289 //===----------------------------------------------------------------------===//
290 
291 OpFoldResult arith::DivSIOp::fold(ArrayRef<Attribute> operands) {
292   // Don't fold if it would overflow or if it requires a division by zero.
293   bool overflowOrDiv0 = false;
294   auto result =
295       constFoldBinaryOp<IntegerAttr>(operands, [&](APInt a, const APInt &b) {
296         if (overflowOrDiv0 || !b) {
297           overflowOrDiv0 = true;
298           return a;
299         }
300         return a.sdiv_ov(b, overflowOrDiv0);
301       });
302 
303   // Fold out division by one. Assumes all tensors of all ones are splats.
304   if (auto rhs = operands[1].dyn_cast_or_null<IntegerAttr>()) {
305     if (rhs.getValue() == 1)
306       return getLhs();
307   } else if (auto rhs = operands[1].dyn_cast_or_null<SplatElementsAttr>()) {
308     if (rhs.getSplatValue<IntegerAttr>().getValue() == 1)
309       return getLhs();
310   }
311 
312   return overflowOrDiv0 ? Attribute() : result;
313 }
314 
315 //===----------------------------------------------------------------------===//
316 // Ceil and floor division folding helpers
317 //===----------------------------------------------------------------------===//
318 
319 static APInt signedCeilNonnegInputs(const APInt &a, const APInt &b,
320                                     bool &overflow) {
321   // Returns (a-1)/b + 1
322   APInt one(a.getBitWidth(), 1, true); // Signed value 1.
323   APInt val = a.ssub_ov(one, overflow).sdiv_ov(b, overflow);
324   return val.sadd_ov(one, overflow);
325 }
326 
327 //===----------------------------------------------------------------------===//
328 // CeilDivUIOp
329 //===----------------------------------------------------------------------===//
330 
331 OpFoldResult arith::CeilDivUIOp::fold(ArrayRef<Attribute> operands) {
332   bool overflowOrDiv0 = false;
333   auto result =
334       constFoldBinaryOp<IntegerAttr>(operands, [&](APInt a, const APInt &b) {
335         if (overflowOrDiv0 || !b) {
336           overflowOrDiv0 = true;
337           return a;
338         }
339         APInt quotient = a.udiv(b);
340         if (!a.urem(b))
341           return quotient;
342         APInt one(a.getBitWidth(), 1, true);
343         return quotient.uadd_ov(one, overflowOrDiv0);
344       });
345   // Fold out ceil division by one. Assumes all tensors of all ones are
346   // splats.
347   if (auto rhs = operands[1].dyn_cast_or_null<IntegerAttr>()) {
348     if (rhs.getValue() == 1)
349       return getLhs();
350   } else if (auto rhs = operands[1].dyn_cast_or_null<SplatElementsAttr>()) {
351     if (rhs.getSplatValue<IntegerAttr>().getValue() == 1)
352       return getLhs();
353   }
354 
355   return overflowOrDiv0 ? Attribute() : result;
356 }
357 
358 //===----------------------------------------------------------------------===//
359 // CeilDivSIOp
360 //===----------------------------------------------------------------------===//
361 
362 OpFoldResult arith::CeilDivSIOp::fold(ArrayRef<Attribute> operands) {
363   // Don't fold if it would overflow or if it requires a division by zero.
364   bool overflowOrDiv0 = false;
365   auto result =
366       constFoldBinaryOp<IntegerAttr>(operands, [&](APInt a, const APInt &b) {
367         if (overflowOrDiv0 || !b) {
368           overflowOrDiv0 = true;
369           return a;
370         }
371         if (!a)
372           return a;
373         // After this point we know that neither a or b are zero.
374         unsigned bits = a.getBitWidth();
375         APInt zero = APInt::getZero(bits);
376         bool aGtZero = a.sgt(zero);
377         bool bGtZero = b.sgt(zero);
378         if (aGtZero && bGtZero) {
379           // Both positive, return ceil(a, b).
380           return signedCeilNonnegInputs(a, b, overflowOrDiv0);
381         }
382         if (!aGtZero && !bGtZero) {
383           // Both negative, return ceil(-a, -b).
384           APInt posA = zero.ssub_ov(a, overflowOrDiv0);
385           APInt posB = zero.ssub_ov(b, overflowOrDiv0);
386           return signedCeilNonnegInputs(posA, posB, overflowOrDiv0);
387         }
388         if (!aGtZero && bGtZero) {
389           // A is negative, b is positive, return - ( -a / b).
390           APInt posA = zero.ssub_ov(a, overflowOrDiv0);
391           APInt div = posA.sdiv_ov(b, overflowOrDiv0);
392           return zero.ssub_ov(div, overflowOrDiv0);
393         }
394         // A is positive, b is negative, return - (a / -b).
395         APInt posB = zero.ssub_ov(b, overflowOrDiv0);
396         APInt div = a.sdiv_ov(posB, overflowOrDiv0);
397         return zero.ssub_ov(div, overflowOrDiv0);
398       });
399 
400   // Fold out ceil division by one. Assumes all tensors of all ones are
401   // splats.
402   if (auto rhs = operands[1].dyn_cast_or_null<IntegerAttr>()) {
403     if (rhs.getValue() == 1)
404       return getLhs();
405   } else if (auto rhs = operands[1].dyn_cast_or_null<SplatElementsAttr>()) {
406     if (rhs.getSplatValue<IntegerAttr>().getValue() == 1)
407       return getLhs();
408   }
409 
410   return overflowOrDiv0 ? Attribute() : result;
411 }
412 
413 //===----------------------------------------------------------------------===//
414 // FloorDivSIOp
415 //===----------------------------------------------------------------------===//
416 
417 OpFoldResult arith::FloorDivSIOp::fold(ArrayRef<Attribute> operands) {
418   // Don't fold if it would overflow or if it requires a division by zero.
419   bool overflowOrDiv0 = false;
420   auto result =
421       constFoldBinaryOp<IntegerAttr>(operands, [&](APInt a, const APInt &b) {
422         if (overflowOrDiv0 || !b) {
423           overflowOrDiv0 = true;
424           return a;
425         }
426         if (!a)
427           return a;
428         // After this point we know that neither a or b are zero.
429         unsigned bits = a.getBitWidth();
430         APInt zero = APInt::getZero(bits);
431         bool aGtZero = a.sgt(zero);
432         bool bGtZero = b.sgt(zero);
433         if (aGtZero && bGtZero) {
434           // Both positive, return a / b.
435           return a.sdiv_ov(b, overflowOrDiv0);
436         }
437         if (!aGtZero && !bGtZero) {
438           // Both negative, return -a / -b.
439           APInt posA = zero.ssub_ov(a, overflowOrDiv0);
440           APInt posB = zero.ssub_ov(b, overflowOrDiv0);
441           return posA.sdiv_ov(posB, overflowOrDiv0);
442         }
443         if (!aGtZero && bGtZero) {
444           // A is negative, b is positive, return - ceil(-a, b).
445           APInt posA = zero.ssub_ov(a, overflowOrDiv0);
446           APInt ceil = signedCeilNonnegInputs(posA, b, overflowOrDiv0);
447           return zero.ssub_ov(ceil, overflowOrDiv0);
448         }
449         // A is positive, b is negative, return - ceil(a, -b).
450         APInt posB = zero.ssub_ov(b, overflowOrDiv0);
451         APInt ceil = signedCeilNonnegInputs(a, posB, overflowOrDiv0);
452         return zero.ssub_ov(ceil, overflowOrDiv0);
453       });
454 
455   // Fold out floor division by one. Assumes all tensors of all ones are
456   // splats.
457   if (auto rhs = operands[1].dyn_cast_or_null<IntegerAttr>()) {
458     if (rhs.getValue() == 1)
459       return getLhs();
460   } else if (auto rhs = operands[1].dyn_cast_or_null<SplatElementsAttr>()) {
461     if (rhs.getSplatValue<IntegerAttr>().getValue() == 1)
462       return getLhs();
463   }
464 
465   return overflowOrDiv0 ? Attribute() : result;
466 }
467 
468 //===----------------------------------------------------------------------===//
469 // RemUIOp
470 //===----------------------------------------------------------------------===//
471 
472 OpFoldResult arith::RemUIOp::fold(ArrayRef<Attribute> operands) {
473   auto rhs = operands.back().dyn_cast_or_null<IntegerAttr>();
474   if (!rhs)
475     return {};
476   auto rhsValue = rhs.getValue();
477 
478   // x % 1 = 0
479   if (rhsValue.isOneValue())
480     return IntegerAttr::get(rhs.getType(), APInt(rhsValue.getBitWidth(), 0));
481 
482   // Don't fold if it requires division by zero.
483   if (rhsValue.isNullValue())
484     return {};
485 
486   auto lhs = operands.front().dyn_cast_or_null<IntegerAttr>();
487   if (!lhs)
488     return {};
489   return IntegerAttr::get(lhs.getType(), lhs.getValue().urem(rhsValue));
490 }
491 
492 //===----------------------------------------------------------------------===//
493 // RemSIOp
494 //===----------------------------------------------------------------------===//
495 
496 OpFoldResult arith::RemSIOp::fold(ArrayRef<Attribute> operands) {
497   auto rhs = operands.back().dyn_cast_or_null<IntegerAttr>();
498   if (!rhs)
499     return {};
500   auto rhsValue = rhs.getValue();
501 
502   // x % 1 = 0
503   if (rhsValue.isOneValue())
504     return IntegerAttr::get(rhs.getType(), APInt(rhsValue.getBitWidth(), 0));
505 
506   // Don't fold if it requires division by zero.
507   if (rhsValue.isNullValue())
508     return {};
509 
510   auto lhs = operands.front().dyn_cast_or_null<IntegerAttr>();
511   if (!lhs)
512     return {};
513   return IntegerAttr::get(lhs.getType(), lhs.getValue().srem(rhsValue));
514 }
515 
516 //===----------------------------------------------------------------------===//
517 // AndIOp
518 //===----------------------------------------------------------------------===//
519 
520 OpFoldResult arith::AndIOp::fold(ArrayRef<Attribute> operands) {
521   /// and(x, 0) -> 0
522   if (matchPattern(getRhs(), m_Zero()))
523     return getRhs();
524   /// and(x, allOnes) -> x
525   APInt intValue;
526   if (matchPattern(getRhs(), m_ConstantInt(&intValue)) && intValue.isAllOnes())
527     return getLhs();
528 
529   return constFoldBinaryOp<IntegerAttr>(
530       operands, [](APInt a, const APInt &b) { return std::move(a) & b; });
531 }
532 
533 //===----------------------------------------------------------------------===//
534 // OrIOp
535 //===----------------------------------------------------------------------===//
536 
537 OpFoldResult arith::OrIOp::fold(ArrayRef<Attribute> operands) {
538   /// or(x, 0) -> x
539   if (matchPattern(getRhs(), m_Zero()))
540     return getLhs();
541   /// or(x, <all ones>) -> <all ones>
542   if (auto rhsAttr = operands[1].dyn_cast_or_null<IntegerAttr>())
543     if (rhsAttr.getValue().isAllOnes())
544       return rhsAttr;
545 
546   return constFoldBinaryOp<IntegerAttr>(
547       operands, [](APInt a, const APInt &b) { return std::move(a) | b; });
548 }
549 
550 //===----------------------------------------------------------------------===//
551 // XOrIOp
552 //===----------------------------------------------------------------------===//
553 
554 OpFoldResult arith::XOrIOp::fold(ArrayRef<Attribute> operands) {
555   /// xor(x, 0) -> x
556   if (matchPattern(getRhs(), m_Zero()))
557     return getLhs();
558   /// xor(x, x) -> 0
559   if (getLhs() == getRhs())
560     return Builder(getContext()).getZeroAttr(getType());
561   /// xor(xor(x, a), a) -> x
562   if (arith::XOrIOp prev = getLhs().getDefiningOp<arith::XOrIOp>())
563     if (prev.getRhs() == getRhs())
564       return prev.getLhs();
565 
566   return constFoldBinaryOp<IntegerAttr>(
567       operands, [](APInt a, const APInt &b) { return std::move(a) ^ b; });
568 }
569 
/// Register the TableGen'd XOrIOp canonicalization pattern `XOrINotCmpI`.
/// NOTE(review): judging by its name and the invertPredicate helpers earlier in
/// this file, the pattern presumably rewrites an xori of a cmpi with `true`
/// into a cmpi with the inverted predicate. Confirm against the TableGen source.
void arith::XOrIOp::getCanonicalizationPatterns(
    RewritePatternSet &patterns, MLIRContext *context) {
  patterns.add<XOrINotCmpI>(context);
}
574 
575 //===----------------------------------------------------------------------===//
576 // AddFOp
577 //===----------------------------------------------------------------------===//
578 
579 OpFoldResult arith::AddFOp::fold(ArrayRef<Attribute> operands) {
580   // addf(x, -0) -> x
581   if (matchPattern(getRhs(), m_NegZeroFloat()))
582     return getLhs();
583 
584   return constFoldBinaryOp<FloatAttr>(
585       operands, [](const APFloat &a, const APFloat &b) { return a + b; });
586 }
587 
588 //===----------------------------------------------------------------------===//
589 // SubFOp
590 //===----------------------------------------------------------------------===//
591 
592 OpFoldResult arith::SubFOp::fold(ArrayRef<Attribute> operands) {
593   // subf(x, +0) -> x
594   if (matchPattern(getRhs(), m_PosZeroFloat()))
595     return getLhs();
596 
597   return constFoldBinaryOp<FloatAttr>(
598       operands, [](const APFloat &a, const APFloat &b) { return a - b; });
599 }
600 
601 //===----------------------------------------------------------------------===//
602 // MaxFOp
603 //===----------------------------------------------------------------------===//
604 
605 OpFoldResult arith::MaxFOp::fold(ArrayRef<Attribute> operands) {
606   assert(operands.size() == 2 && "maxf takes two operands");
607 
608   // maxf(x,x) -> x
609   if (getLhs() == getRhs())
610     return getRhs();
611 
612   // maxf(x, -inf) -> x
613   if (matchPattern(getRhs(), m_NegInfFloat()))
614     return getLhs();
615 
616   return constFoldBinaryOp<FloatAttr>(
617       operands,
618       [](const APFloat &a, const APFloat &b) { return llvm::maximum(a, b); });
619 }
620 
621 //===----------------------------------------------------------------------===//
622 // MaxSIOp
623 //===----------------------------------------------------------------------===//
624 
625 OpFoldResult MaxSIOp::fold(ArrayRef<Attribute> operands) {
626   assert(operands.size() == 2 && "binary operation takes two operands");
627 
628   // maxsi(x,x) -> x
629   if (getLhs() == getRhs())
630     return getRhs();
631 
632   APInt intValue;
633   // maxsi(x,MAX_INT) -> MAX_INT
634   if (matchPattern(getRhs(), m_ConstantInt(&intValue)) &&
635       intValue.isMaxSignedValue())
636     return getRhs();
637 
638   // maxsi(x, MIN_INT) -> x
639   if (matchPattern(getRhs(), m_ConstantInt(&intValue)) &&
640       intValue.isMinSignedValue())
641     return getLhs();
642 
643   return constFoldBinaryOp<IntegerAttr>(operands,
644                                         [](const APInt &a, const APInt &b) {
645                                           return llvm::APIntOps::smax(a, b);
646                                         });
647 }
648 
649 //===----------------------------------------------------------------------===//
650 // MaxUIOp
651 //===----------------------------------------------------------------------===//
652 
653 OpFoldResult MaxUIOp::fold(ArrayRef<Attribute> operands) {
654   assert(operands.size() == 2 && "binary operation takes two operands");
655 
656   // maxui(x,x) -> x
657   if (getLhs() == getRhs())
658     return getRhs();
659 
660   APInt intValue;
661   // maxui(x,MAX_INT) -> MAX_INT
662   if (matchPattern(getRhs(), m_ConstantInt(&intValue)) && intValue.isMaxValue())
663     return getRhs();
664 
665   // maxui(x, MIN_INT) -> x
666   if (matchPattern(getRhs(), m_ConstantInt(&intValue)) && intValue.isMinValue())
667     return getLhs();
668 
669   return constFoldBinaryOp<IntegerAttr>(operands,
670                                         [](const APInt &a, const APInt &b) {
671                                           return llvm::APIntOps::umax(a, b);
672                                         });
673 }
674 
675 //===----------------------------------------------------------------------===//
676 // MinFOp
677 //===----------------------------------------------------------------------===//
678 
679 OpFoldResult arith::MinFOp::fold(ArrayRef<Attribute> operands) {
680   assert(operands.size() == 2 && "minf takes two operands");
681 
682   // minf(x,x) -> x
683   if (getLhs() == getRhs())
684     return getRhs();
685 
686   // minf(x, +inf) -> x
687   if (matchPattern(getRhs(), m_PosInfFloat()))
688     return getLhs();
689 
690   return constFoldBinaryOp<FloatAttr>(
691       operands,
692       [](const APFloat &a, const APFloat &b) { return llvm::minimum(a, b); });
693 }
694 
695 //===----------------------------------------------------------------------===//
696 // MinSIOp
697 //===----------------------------------------------------------------------===//
698 
699 OpFoldResult MinSIOp::fold(ArrayRef<Attribute> operands) {
700   assert(operands.size() == 2 && "binary operation takes two operands");
701 
702   // minsi(x,x) -> x
703   if (getLhs() == getRhs())
704     return getRhs();
705 
706   APInt intValue;
707   // minsi(x,MIN_INT) -> MIN_INT
708   if (matchPattern(getRhs(), m_ConstantInt(&intValue)) &&
709       intValue.isMinSignedValue())
710     return getRhs();
711 
712   // minsi(x, MAX_INT) -> x
713   if (matchPattern(getRhs(), m_ConstantInt(&intValue)) &&
714       intValue.isMaxSignedValue())
715     return getLhs();
716 
717   return constFoldBinaryOp<IntegerAttr>(operands,
718                                         [](const APInt &a, const APInt &b) {
719                                           return llvm::APIntOps::smin(a, b);
720                                         });
721 }
722 
723 //===----------------------------------------------------------------------===//
724 // MinUIOp
725 //===----------------------------------------------------------------------===//
726 
727 OpFoldResult MinUIOp::fold(ArrayRef<Attribute> operands) {
728   assert(operands.size() == 2 && "binary operation takes two operands");
729 
730   // minui(x,x) -> x
731   if (getLhs() == getRhs())
732     return getRhs();
733 
734   APInt intValue;
735   // minui(x,MIN_INT) -> MIN_INT
736   if (matchPattern(getRhs(), m_ConstantInt(&intValue)) && intValue.isMinValue())
737     return getRhs();
738 
739   // minui(x, MAX_INT) -> x
740   if (matchPattern(getRhs(), m_ConstantInt(&intValue)) && intValue.isMaxValue())
741     return getLhs();
742 
743   return constFoldBinaryOp<IntegerAttr>(operands,
744                                         [](const APInt &a, const APInt &b) {
745                                           return llvm::APIntOps::umin(a, b);
746                                         });
747 }
748 
749 //===----------------------------------------------------------------------===//
750 // MulFOp
751 //===----------------------------------------------------------------------===//
752 
753 OpFoldResult arith::MulFOp::fold(ArrayRef<Attribute> operands) {
754   APFloat floatValue(0.0f), inverseValue(0.0f);
755   // mulf(x, 1) -> x
756   if (matchPattern(getRhs(), m_OneFloat()))
757     return getLhs();
758 
759   // mulf(1, x) -> x
760   if (matchPattern(getLhs(), m_OneFloat()))
761     return getRhs();
762 
763   return constFoldBinaryOp<FloatAttr>(
764       operands, [](const APFloat &a, const APFloat &b) { return a * b; });
765 }
766 
767 //===----------------------------------------------------------------------===//
768 // DivFOp
769 //===----------------------------------------------------------------------===//
770 
771 OpFoldResult arith::DivFOp::fold(ArrayRef<Attribute> operands) {
772   APFloat floatValue(0.0f), inverseValue(0.0f);
773   // divf(x, 1) -> x
774   if (matchPattern(getRhs(), m_OneFloat()))
775     return getLhs();
776 
777   return constFoldBinaryOp<FloatAttr>(
778       operands, [](const APFloat &a, const APFloat &b) { return a / b; });
779 }
780 
781 //===----------------------------------------------------------------------===//
782 // Utility functions for verifying cast ops
783 //===----------------------------------------------------------------------===//
784 
/// Tag type that carries a pack of types through an ordinary function
/// argument. This lets a function template (e.g. getUnderlyingType) accept two
/// independent type packs. Only the type matters: callers pass a
/// value-initialized (null) pointer such as `type_list<A, B>()`.
template <typename... Types>
using type_list = std::tuple<Types...> *;
787 
788 /// Returns a non-null type only if the provided type is one of the allowed
789 /// types or one of the allowed shaped types of the allowed types. Returns the
790 /// element type if a valid shaped type is provided.
791 template <typename... ShapedTypes, typename... ElementTypes>
792 static Type getUnderlyingType(Type type, type_list<ShapedTypes...>,
793                               type_list<ElementTypes...>) {
794   if (type.isa<ShapedType>() && !type.isa<ShapedTypes...>())
795     return {};
796 
797   auto underlyingType = getElementTypeOrSelf(type);
798   if (!underlyingType.isa<ElementTypes...>())
799     return {};
800 
801   return underlyingType;
802 }
803 
804 /// Get allowed underlying types for vectors and tensors.
805 template <typename... ElementTypes>
806 static Type getTypeIfLike(Type type) {
807   return getUnderlyingType(type, type_list<VectorType, TensorType>(),
808                            type_list<ElementTypes...>());
809 }
810 
811 /// Get allowed underlying types for vectors, tensors, and memrefs.
812 template <typename... ElementTypes>
813 static Type getTypeIfLikeOrMemRef(Type type) {
814   return getUnderlyingType(type,
815                            type_list<VectorType, TensorType, MemRefType>(),
816                            type_list<ElementTypes...>());
817 }
818 
819 static bool areValidCastInputsAndOutputs(TypeRange inputs, TypeRange outputs) {
820   return inputs.size() == 1 && outputs.size() == 1 &&
821          succeeded(verifyCompatibleShapes(inputs.front(), outputs.front()));
822 }
823 
824 //===----------------------------------------------------------------------===//
825 // Verifiers for integer and floating point extension/truncation ops
826 //===----------------------------------------------------------------------===//
827 
828 // Extend ops can only extend to a wider type.
829 template <typename ValType, typename Op>
830 static LogicalResult verifyExtOp(Op op) {
831   Type srcType = getElementTypeOrSelf(op.getIn().getType());
832   Type dstType = getElementTypeOrSelf(op.getType());
833 
834   if (srcType.cast<ValType>().getWidth() >= dstType.cast<ValType>().getWidth())
835     return op.emitError("result type ")
836            << dstType << " must be wider than operand type " << srcType;
837 
838   return success();
839 }
840 
841 // Truncate ops can only truncate to a shorter type.
842 template <typename ValType, typename Op>
843 static LogicalResult verifyTruncateOp(Op op) {
844   Type srcType = getElementTypeOrSelf(op.getIn().getType());
845   Type dstType = getElementTypeOrSelf(op.getType());
846 
847   if (srcType.cast<ValType>().getWidth() <= dstType.cast<ValType>().getWidth())
848     return op.emitError("result type ")
849            << dstType << " must be shorter than operand type " << srcType;
850 
851   return success();
852 }
853 
854 /// Validate a cast that changes the width of a type.
855 template <template <typename> class WidthComparator, typename... ElementTypes>
856 static bool checkWidthChangeCast(TypeRange inputs, TypeRange outputs) {
857   if (!areValidCastInputsAndOutputs(inputs, outputs))
858     return false;
859 
860   auto srcType = getTypeIfLike<ElementTypes...>(inputs.front());
861   auto dstType = getTypeIfLike<ElementTypes...>(outputs.front());
862   if (!srcType || !dstType)
863     return false;
864 
865   return WidthComparator<unsigned>()(dstType.getIntOrFloatBitWidth(),
866                                      srcType.getIntOrFloatBitWidth());
867 }
868 
869 //===----------------------------------------------------------------------===//
870 // ExtUIOp
871 //===----------------------------------------------------------------------===//
872 
873 OpFoldResult arith::ExtUIOp::fold(ArrayRef<Attribute> operands) {
874   if (auto lhs = operands[0].dyn_cast_or_null<IntegerAttr>())
875     return IntegerAttr::get(
876         getType(), lhs.getValue().zext(getType().getIntOrFloatBitWidth()));
877 
878   if (auto lhs = getIn().getDefiningOp<ExtUIOp>()) {
879     getInMutable().assign(lhs.getIn());
880     return getResult();
881   }
882 
883   return {};
884 }
885 
886 bool arith::ExtUIOp::areCastCompatible(TypeRange inputs, TypeRange outputs) {
887   return checkWidthChangeCast<std::greater, IntegerType>(inputs, outputs);
888 }
889 
890 LogicalResult arith::ExtUIOp::verify() {
891   return verifyExtOp<IntegerType>(*this);
892 }
893 
894 //===----------------------------------------------------------------------===//
895 // ExtSIOp
896 //===----------------------------------------------------------------------===//
897 
898 OpFoldResult arith::ExtSIOp::fold(ArrayRef<Attribute> operands) {
899   if (auto lhs = operands[0].dyn_cast_or_null<IntegerAttr>())
900     return IntegerAttr::get(
901         getType(), lhs.getValue().sext(getType().getIntOrFloatBitWidth()));
902 
903   if (auto lhs = getIn().getDefiningOp<ExtSIOp>()) {
904     getInMutable().assign(lhs.getIn());
905     return getResult();
906   }
907 
908   return {};
909 }
910 
911 bool arith::ExtSIOp::areCastCompatible(TypeRange inputs, TypeRange outputs) {
912   return checkWidthChangeCast<std::greater, IntegerType>(inputs, outputs);
913 }
914 
915 void arith::ExtSIOp::getCanonicalizationPatterns(
916     RewritePatternSet &patterns, MLIRContext *context) {
917   patterns.add<ExtSIOfExtUI>(context);
918 }
919 
920 LogicalResult arith::ExtSIOp::verify() {
921   return verifyExtOp<IntegerType>(*this);
922 }
923 
924 //===----------------------------------------------------------------------===//
925 // ExtFOp
926 //===----------------------------------------------------------------------===//
927 
928 bool arith::ExtFOp::areCastCompatible(TypeRange inputs, TypeRange outputs) {
929   return checkWidthChangeCast<std::greater, FloatType>(inputs, outputs);
930 }
931 
932 LogicalResult arith::ExtFOp::verify() { return verifyExtOp<FloatType>(*this); }
933 
934 //===----------------------------------------------------------------------===//
935 // TruncIOp
936 //===----------------------------------------------------------------------===//
937 
938 OpFoldResult arith::TruncIOp::fold(ArrayRef<Attribute> operands) {
939   assert(operands.size() == 1 && "unary operation takes one operand");
940 
941   // trunci(zexti(a)) -> a
942   // trunci(sexti(a)) -> a
943   if (matchPattern(getOperand(), m_Op<arith::ExtUIOp>()) ||
944       matchPattern(getOperand(), m_Op<arith::ExtSIOp>()))
945     return getOperand().getDefiningOp()->getOperand(0);
946 
947   // trunci(trunci(a)) -> trunci(a))
948   if (matchPattern(getOperand(), m_Op<arith::TruncIOp>())) {
949     setOperand(getOperand().getDefiningOp()->getOperand(0));
950     return getResult();
951   }
952 
953   if (!operands[0])
954     return {};
955 
956   if (auto lhs = operands[0].dyn_cast<IntegerAttr>()) {
957     return IntegerAttr::get(
958         getType(), lhs.getValue().trunc(getType().getIntOrFloatBitWidth()));
959   }
960 
961   return {};
962 }
963 
964 bool arith::TruncIOp::areCastCompatible(TypeRange inputs, TypeRange outputs) {
965   return checkWidthChangeCast<std::less, IntegerType>(inputs, outputs);
966 }
967 
968 LogicalResult arith::TruncIOp::verify() {
969   return verifyTruncateOp<IntegerType>(*this);
970 }
971 
972 //===----------------------------------------------------------------------===//
973 // TruncFOp
974 //===----------------------------------------------------------------------===//
975 
976 /// Perform safe const propagation for truncf, i.e. only propagate if FP value
977 /// can be represented without precision loss or rounding.
978 OpFoldResult arith::TruncFOp::fold(ArrayRef<Attribute> operands) {
979   assert(operands.size() == 1 && "unary operation takes one operand");
980 
981   auto constOperand = operands.front();
982   if (!constOperand || !constOperand.isa<FloatAttr>())
983     return {};
984 
985   // Convert to target type via 'double'.
986   double sourceValue =
987       constOperand.dyn_cast<FloatAttr>().getValue().convertToDouble();
988   auto targetAttr = FloatAttr::get(getType(), sourceValue);
989 
990   // Propagate if constant's value does not change after truncation.
991   if (sourceValue == targetAttr.getValue().convertToDouble())
992     return targetAttr;
993 
994   return {};
995 }
996 
997 bool arith::TruncFOp::areCastCompatible(TypeRange inputs, TypeRange outputs) {
998   return checkWidthChangeCast<std::less, FloatType>(inputs, outputs);
999 }
1000 
1001 LogicalResult arith::TruncFOp::verify() {
1002   return verifyTruncateOp<FloatType>(*this);
1003 }
1004 
1005 //===----------------------------------------------------------------------===//
1006 // AndIOp
1007 //===----------------------------------------------------------------------===//
1008 
1009 void arith::AndIOp::getCanonicalizationPatterns(
1010     RewritePatternSet &patterns, MLIRContext *context) {
1011   patterns.add<AndOfExtUI, AndOfExtSI>(context);
1012 }
1013 
1014 //===----------------------------------------------------------------------===//
1015 // OrIOp
1016 //===----------------------------------------------------------------------===//
1017 
1018 void arith::OrIOp::getCanonicalizationPatterns(
1019     RewritePatternSet &patterns, MLIRContext *context) {
1020   patterns.add<OrOfExtUI, OrOfExtSI>(context);
1021 }
1022 
1023 //===----------------------------------------------------------------------===//
1024 // Verifiers for casts between integers and floats.
1025 //===----------------------------------------------------------------------===//
1026 
1027 template <typename From, typename To>
1028 static bool checkIntFloatCast(TypeRange inputs, TypeRange outputs) {
1029   if (!areValidCastInputsAndOutputs(inputs, outputs))
1030     return false;
1031 
1032   auto srcType = getTypeIfLike<From>(inputs.front());
1033   auto dstType = getTypeIfLike<To>(outputs.back());
1034 
1035   return srcType && dstType;
1036 }
1037 
1038 //===----------------------------------------------------------------------===//
1039 // UIToFPOp
1040 //===----------------------------------------------------------------------===//
1041 
1042 bool arith::UIToFPOp::areCastCompatible(TypeRange inputs, TypeRange outputs) {
1043   return checkIntFloatCast<IntegerType, FloatType>(inputs, outputs);
1044 }
1045 
1046 OpFoldResult arith::UIToFPOp::fold(ArrayRef<Attribute> operands) {
1047   if (auto lhs = operands[0].dyn_cast_or_null<IntegerAttr>()) {
1048     const APInt &api = lhs.getValue();
1049     FloatType floatTy = getType().cast<FloatType>();
1050     APFloat apf(floatTy.getFloatSemantics(),
1051                 APInt::getZero(floatTy.getWidth()));
1052     apf.convertFromAPInt(api, /*IsSigned=*/false, APFloat::rmNearestTiesToEven);
1053     return FloatAttr::get(floatTy, apf);
1054   }
1055   return {};
1056 }
1057 
1058 //===----------------------------------------------------------------------===//
1059 // SIToFPOp
1060 //===----------------------------------------------------------------------===//
1061 
1062 bool arith::SIToFPOp::areCastCompatible(TypeRange inputs, TypeRange outputs) {
1063   return checkIntFloatCast<IntegerType, FloatType>(inputs, outputs);
1064 }
1065 
1066 OpFoldResult arith::SIToFPOp::fold(ArrayRef<Attribute> operands) {
1067   if (auto lhs = operands[0].dyn_cast_or_null<IntegerAttr>()) {
1068     const APInt &api = lhs.getValue();
1069     FloatType floatTy = getType().cast<FloatType>();
1070     APFloat apf(floatTy.getFloatSemantics(),
1071                 APInt::getZero(floatTy.getWidth()));
1072     apf.convertFromAPInt(api, /*IsSigned=*/true, APFloat::rmNearestTiesToEven);
1073     return FloatAttr::get(floatTy, apf);
1074   }
1075   return {};
1076 }
1077 //===----------------------------------------------------------------------===//
1078 // FPToUIOp
1079 //===----------------------------------------------------------------------===//
1080 
1081 bool arith::FPToUIOp::areCastCompatible(TypeRange inputs, TypeRange outputs) {
1082   return checkIntFloatCast<FloatType, IntegerType>(inputs, outputs);
1083 }
1084 
1085 OpFoldResult arith::FPToUIOp::fold(ArrayRef<Attribute> operands) {
1086   if (auto lhs = operands[0].dyn_cast_or_null<FloatAttr>()) {
1087     const APFloat &apf = lhs.getValue();
1088     IntegerType intTy = getType().cast<IntegerType>();
1089     bool ignored;
1090     APSInt api(intTy.getWidth(), /*isUnsigned=*/true);
1091     if (APFloat::opInvalidOp ==
1092         apf.convertToInteger(api, APFloat::rmTowardZero, &ignored)) {
1093       // Undefined behavior invoked - the destination type can't represent
1094       // the input constant.
1095       return {};
1096     }
1097     return IntegerAttr::get(getType(), api);
1098   }
1099 
1100   return {};
1101 }
1102 
1103 //===----------------------------------------------------------------------===//
1104 // FPToSIOp
1105 //===----------------------------------------------------------------------===//
1106 
1107 bool arith::FPToSIOp::areCastCompatible(TypeRange inputs, TypeRange outputs) {
1108   return checkIntFloatCast<FloatType, IntegerType>(inputs, outputs);
1109 }
1110 
1111 OpFoldResult arith::FPToSIOp::fold(ArrayRef<Attribute> operands) {
1112   if (auto lhs = operands[0].dyn_cast_or_null<FloatAttr>()) {
1113     const APFloat &apf = lhs.getValue();
1114     IntegerType intTy = getType().cast<IntegerType>();
1115     bool ignored;
1116     APSInt api(intTy.getWidth(), /*isUnsigned=*/false);
1117     if (APFloat::opInvalidOp ==
1118         apf.convertToInteger(api, APFloat::rmTowardZero, &ignored)) {
1119       // Undefined behavior invoked - the destination type can't represent
1120       // the input constant.
1121       return {};
1122     }
1123     return IntegerAttr::get(getType(), api);
1124   }
1125 
1126   return {};
1127 }
1128 
1129 //===----------------------------------------------------------------------===//
1130 // IndexCastOp
1131 //===----------------------------------------------------------------------===//
1132 
1133 bool arith::IndexCastOp::areCastCompatible(TypeRange inputs,
1134                                            TypeRange outputs) {
1135   if (!areValidCastInputsAndOutputs(inputs, outputs))
1136     return false;
1137 
1138   auto srcType = getTypeIfLikeOrMemRef<IntegerType, IndexType>(inputs.front());
1139   auto dstType = getTypeIfLikeOrMemRef<IntegerType, IndexType>(outputs.front());
1140   if (!srcType || !dstType)
1141     return false;
1142 
1143   return (srcType.isIndex() && dstType.isSignlessInteger()) ||
1144          (srcType.isSignlessInteger() && dstType.isIndex());
1145 }
1146 
1147 OpFoldResult arith::IndexCastOp::fold(ArrayRef<Attribute> operands) {
1148   // index_cast(constant) -> constant
1149   // A little hack because we go through int. Otherwise, the size of the
1150   // constant might need to change.
1151   if (auto value = operands[0].dyn_cast_or_null<IntegerAttr>())
1152     return IntegerAttr::get(getType(), value.getInt());
1153 
1154   return {};
1155 }
1156 
1157 void arith::IndexCastOp::getCanonicalizationPatterns(
1158     RewritePatternSet &patterns, MLIRContext *context) {
1159   patterns.add<IndexCastOfIndexCast, IndexCastOfExtSI>(context);
1160 }
1161 
1162 //===----------------------------------------------------------------------===//
1163 // BitcastOp
1164 //===----------------------------------------------------------------------===//
1165 
1166 bool arith::BitcastOp::areCastCompatible(TypeRange inputs, TypeRange outputs) {
1167   if (!areValidCastInputsAndOutputs(inputs, outputs))
1168     return false;
1169 
1170   auto srcType =
1171       getTypeIfLikeOrMemRef<IntegerType, IndexType, FloatType>(inputs.front());
1172   auto dstType =
1173       getTypeIfLikeOrMemRef<IntegerType, IndexType, FloatType>(outputs.front());
1174   if (!srcType || !dstType)
1175     return false;
1176 
1177   return srcType.getIntOrFloatBitWidth() == dstType.getIntOrFloatBitWidth();
1178 }
1179 
1180 OpFoldResult arith::BitcastOp::fold(ArrayRef<Attribute> operands) {
1181   assert(operands.size() == 1 && "bitcast op expects 1 operand");
1182 
1183   auto resType = getType();
1184   auto operand = operands[0];
1185   if (!operand)
1186     return {};
1187 
1188   /// Bitcast dense elements.
1189   if (auto denseAttr = operand.dyn_cast_or_null<DenseElementsAttr>())
1190     return denseAttr.bitcast(resType.cast<ShapedType>().getElementType());
1191   /// Other shaped types unhandled.
1192   if (resType.isa<ShapedType>())
1193     return {};
1194 
1195   /// Bitcast integer or float to integer or float.
1196   APInt bits = operand.isa<FloatAttr>()
1197                    ? operand.cast<FloatAttr>().getValue().bitcastToAPInt()
1198                    : operand.cast<IntegerAttr>().getValue();
1199 
1200   if (auto resFloatType = resType.dyn_cast<FloatType>())
1201     return FloatAttr::get(resType,
1202                           APFloat(resFloatType.getFloatSemantics(), bits));
1203   return IntegerAttr::get(resType, bits);
1204 }
1205 
1206 void arith::BitcastOp::getCanonicalizationPatterns(
1207     RewritePatternSet &patterns, MLIRContext *context) {
1208   patterns.add<BitcastOfBitcast>(context);
1209 }
1210 
1211 //===----------------------------------------------------------------------===//
1212 // Helpers for compare ops
1213 //===----------------------------------------------------------------------===//
1214 
1215 /// Return the type of the same shape (scalar, vector or tensor) containing i1.
1216 static Type getI1SameShape(Type type) {
1217   auto i1Type = IntegerType::get(type.getContext(), 1);
1218   if (auto tensorType = type.dyn_cast<RankedTensorType>())
1219     return RankedTensorType::get(tensorType.getShape(), i1Type);
1220   if (type.isa<UnrankedTensorType>())
1221     return UnrankedTensorType::get(i1Type);
1222   if (auto vectorType = type.dyn_cast<VectorType>())
1223     return VectorType::get(vectorType.getShape(), i1Type,
1224                            vectorType.getNumScalableDims());
1225   return i1Type;
1226 }
1227 
1228 //===----------------------------------------------------------------------===//
1229 // CmpIOp
1230 //===----------------------------------------------------------------------===//
1231 
1232 /// Compute `lhs` `pred` `rhs`, where `pred` is one of the known integer
1233 /// comparison predicates.
1234 bool mlir::arith::applyCmpPredicate(arith::CmpIPredicate predicate,
1235                                     const APInt &lhs, const APInt &rhs) {
1236   switch (predicate) {
1237   case arith::CmpIPredicate::eq:
1238     return lhs.eq(rhs);
1239   case arith::CmpIPredicate::ne:
1240     return lhs.ne(rhs);
1241   case arith::CmpIPredicate::slt:
1242     return lhs.slt(rhs);
1243   case arith::CmpIPredicate::sle:
1244     return lhs.sle(rhs);
1245   case arith::CmpIPredicate::sgt:
1246     return lhs.sgt(rhs);
1247   case arith::CmpIPredicate::sge:
1248     return lhs.sge(rhs);
1249   case arith::CmpIPredicate::ult:
1250     return lhs.ult(rhs);
1251   case arith::CmpIPredicate::ule:
1252     return lhs.ule(rhs);
1253   case arith::CmpIPredicate::ugt:
1254     return lhs.ugt(rhs);
1255   case arith::CmpIPredicate::uge:
1256     return lhs.uge(rhs);
1257   }
1258   llvm_unreachable("unknown cmpi predicate kind");
1259 }
1260 
1261 /// Returns true if the predicate is true for two equal operands.
1262 static bool applyCmpPredicateToEqualOperands(arith::CmpIPredicate predicate) {
1263   switch (predicate) {
1264   case arith::CmpIPredicate::eq:
1265   case arith::CmpIPredicate::sle:
1266   case arith::CmpIPredicate::sge:
1267   case arith::CmpIPredicate::ule:
1268   case arith::CmpIPredicate::uge:
1269     return true;
1270   case arith::CmpIPredicate::ne:
1271   case arith::CmpIPredicate::slt:
1272   case arith::CmpIPredicate::sgt:
1273   case arith::CmpIPredicate::ult:
1274   case arith::CmpIPredicate::ugt:
1275     return false;
1276   }
1277   llvm_unreachable("unknown cmpi predicate kind");
1278 }
1279 
1280 static Attribute getBoolAttribute(Type type, MLIRContext *ctx, bool value) {
1281   auto boolAttr = BoolAttr::get(ctx, value);
1282   ShapedType shapedType = type.dyn_cast_or_null<ShapedType>();
1283   if (!shapedType)
1284     return boolAttr;
1285   return DenseElementsAttr::get(shapedType, boolAttr);
1286 }
1287 
1288 OpFoldResult arith::CmpIOp::fold(ArrayRef<Attribute> operands) {
1289   assert(operands.size() == 2 && "cmpi takes two operands");
1290 
1291   // cmpi(pred, x, x)
1292   if (getLhs() == getRhs()) {
1293     auto val = applyCmpPredicateToEqualOperands(getPredicate());
1294     return getBoolAttribute(getType(), getContext(), val);
1295   }
1296 
1297   if (matchPattern(getRhs(), m_Zero())) {
1298     if (auto extOp = getLhs().getDefiningOp<ExtSIOp>()) {
1299       if (extOp.getOperand().getType().cast<IntegerType>().getWidth() == 1) {
1300         // extsi(%x : i1 -> iN) != 0  ->  %x
1301         if (getPredicate() == arith::CmpIPredicate::ne) {
1302           return extOp.getOperand();
1303         }
1304       }
1305     }
1306     if (auto extOp = getLhs().getDefiningOp<ExtUIOp>()) {
1307       if (extOp.getOperand().getType().cast<IntegerType>().getWidth() == 1) {
1308         // extui(%x : i1 -> iN) != 0  ->  %x
1309         if (getPredicate() == arith::CmpIPredicate::ne) {
1310           return extOp.getOperand();
1311         }
1312       }
1313     }
1314   }
1315 
1316   auto lhs = operands.front().dyn_cast_or_null<IntegerAttr>();
1317   auto rhs = operands.back().dyn_cast_or_null<IntegerAttr>();
1318   if (!lhs || !rhs)
1319     return {};
1320 
1321   auto val = applyCmpPredicate(getPredicate(), lhs.getValue(), rhs.getValue());
1322   return BoolAttr::get(getContext(), val);
1323 }
1324 
1325 //===----------------------------------------------------------------------===//
1326 // CmpFOp
1327 //===----------------------------------------------------------------------===//
1328 
1329 /// Compute `lhs` `pred` `rhs`, where `pred` is one of the known floating point
1330 /// comparison predicates.
1331 bool mlir::arith::applyCmpPredicate(arith::CmpFPredicate predicate,
1332                                     const APFloat &lhs, const APFloat &rhs) {
1333   auto cmpResult = lhs.compare(rhs);
1334   switch (predicate) {
1335   case arith::CmpFPredicate::AlwaysFalse:
1336     return false;
1337   case arith::CmpFPredicate::OEQ:
1338     return cmpResult == APFloat::cmpEqual;
1339   case arith::CmpFPredicate::OGT:
1340     return cmpResult == APFloat::cmpGreaterThan;
1341   case arith::CmpFPredicate::OGE:
1342     return cmpResult == APFloat::cmpGreaterThan ||
1343            cmpResult == APFloat::cmpEqual;
1344   case arith::CmpFPredicate::OLT:
1345     return cmpResult == APFloat::cmpLessThan;
1346   case arith::CmpFPredicate::OLE:
1347     return cmpResult == APFloat::cmpLessThan || cmpResult == APFloat::cmpEqual;
1348   case arith::CmpFPredicate::ONE:
1349     return cmpResult != APFloat::cmpUnordered && cmpResult != APFloat::cmpEqual;
1350   case arith::CmpFPredicate::ORD:
1351     return cmpResult != APFloat::cmpUnordered;
1352   case arith::CmpFPredicate::UEQ:
1353     return cmpResult == APFloat::cmpUnordered || cmpResult == APFloat::cmpEqual;
1354   case arith::CmpFPredicate::UGT:
1355     return cmpResult == APFloat::cmpUnordered ||
1356            cmpResult == APFloat::cmpGreaterThan;
1357   case arith::CmpFPredicate::UGE:
1358     return cmpResult == APFloat::cmpUnordered ||
1359            cmpResult == APFloat::cmpGreaterThan ||
1360            cmpResult == APFloat::cmpEqual;
1361   case arith::CmpFPredicate::ULT:
1362     return cmpResult == APFloat::cmpUnordered ||
1363            cmpResult == APFloat::cmpLessThan;
1364   case arith::CmpFPredicate::ULE:
1365     return cmpResult == APFloat::cmpUnordered ||
1366            cmpResult == APFloat::cmpLessThan || cmpResult == APFloat::cmpEqual;
1367   case arith::CmpFPredicate::UNE:
1368     return cmpResult != APFloat::cmpEqual;
1369   case arith::CmpFPredicate::UNO:
1370     return cmpResult == APFloat::cmpUnordered;
1371   case arith::CmpFPredicate::AlwaysTrue:
1372     return true;
1373   }
1374   llvm_unreachable("unknown cmpf predicate kind");
1375 }
1376 
1377 OpFoldResult arith::CmpFOp::fold(ArrayRef<Attribute> operands) {
1378   assert(operands.size() == 2 && "cmpf takes two operands");
1379 
1380   auto lhs = operands.front().dyn_cast_or_null<FloatAttr>();
1381   auto rhs = operands.back().dyn_cast_or_null<FloatAttr>();
1382 
1383   // If one operand is NaN, making them both NaN does not change the result.
1384   if (lhs && lhs.getValue().isNaN())
1385     rhs = lhs;
1386   if (rhs && rhs.getValue().isNaN())
1387     lhs = rhs;
1388 
1389   if (!lhs || !rhs)
1390     return {};
1391 
1392   auto val = applyCmpPredicate(getPredicate(), lhs.getValue(), rhs.getValue());
1393   return BoolAttr::get(getContext(), val);
1394 }
1395 
1396 class CmpFIntToFPConst final : public OpRewritePattern<CmpFOp> {
1397 public:
1398   using OpRewritePattern<CmpFOp>::OpRewritePattern;
1399 
1400   static CmpIPredicate convertToIntegerPredicate(CmpFPredicate pred,
1401                                                  bool isUnsigned) {
1402     using namespace arith;
1403     switch (pred) {
1404     case CmpFPredicate::UEQ:
1405     case CmpFPredicate::OEQ:
1406       return CmpIPredicate::eq;
1407     case CmpFPredicate::UGT:
1408     case CmpFPredicate::OGT:
1409       return isUnsigned ? CmpIPredicate::ugt : CmpIPredicate::sgt;
1410     case CmpFPredicate::UGE:
1411     case CmpFPredicate::OGE:
1412       return isUnsigned ? CmpIPredicate::uge : CmpIPredicate::sge;
1413     case CmpFPredicate::ULT:
1414     case CmpFPredicate::OLT:
1415       return isUnsigned ? CmpIPredicate::ult : CmpIPredicate::slt;
1416     case CmpFPredicate::ULE:
1417     case CmpFPredicate::OLE:
1418       return isUnsigned ? CmpIPredicate::ule : CmpIPredicate::sle;
1419     case CmpFPredicate::UNE:
1420     case CmpFPredicate::ONE:
1421       return CmpIPredicate::ne;
1422     default:
1423       llvm_unreachable("Unexpected predicate!");
1424     }
1425   }
1426 
1427   LogicalResult matchAndRewrite(CmpFOp op,
1428                                 PatternRewriter &rewriter) const override {
1429     FloatAttr flt;
1430     if (!matchPattern(op.getRhs(), m_Constant(&flt)))
1431       return failure();
1432 
1433     const APFloat &rhs = flt.getValue();
1434 
1435     // Don't attempt to fold a nan.
1436     if (rhs.isNaN())
1437       return failure();
1438 
1439     // Get the width of the mantissa.  We don't want to hack on conversions that
1440     // might lose information from the integer, e.g. "i64 -> float"
1441     FloatType floatTy = op.getRhs().getType().cast<FloatType>();
1442     int mantissaWidth = floatTy.getFPMantissaWidth();
1443     if (mantissaWidth <= 0)
1444       return failure();
1445 
1446     bool isUnsigned;
1447     Value intVal;
1448 
1449     if (auto si = op.getLhs().getDefiningOp<SIToFPOp>()) {
1450       isUnsigned = false;
1451       intVal = si.getIn();
1452     } else if (auto ui = op.getLhs().getDefiningOp<UIToFPOp>()) {
1453       isUnsigned = true;
1454       intVal = ui.getIn();
1455     } else {
1456       return failure();
1457     }
1458 
1459     // Check to see that the input is converted from an integer type that is
1460     // small enough that preserves all bits.
1461     auto intTy = intVal.getType().cast<IntegerType>();
1462     auto intWidth = intTy.getWidth();
1463 
1464     // Number of bits representing values, as opposed to the sign
1465     auto valueBits = isUnsigned ? intWidth : (intWidth - 1);
1466 
1467     // Following test does NOT adjust intWidth downwards for signed inputs,
1468     // because the most negative value still requires all the mantissa bits
1469     // to distinguish it from one less than that value.
1470     if ((int)intWidth > mantissaWidth) {
1471       // Conversion would lose accuracy. Check if loss can impact comparison.
1472       int exponent = ilogb(rhs);
1473       if (exponent == APFloat::IEK_Inf) {
1474         int maxExponent = ilogb(APFloat::getLargest(rhs.getSemantics()));
1475         if (maxExponent < (int)valueBits) {
1476           // Conversion could create infinity.
1477           return failure();
1478         }
1479       } else {
1480         // Note that if rhs is zero or NaN, then Exp is negative
1481         // and first condition is trivially false.
1482         if (mantissaWidth <= exponent && exponent <= (int)valueBits) {
1483           // Conversion could affect comparison.
1484           return failure();
1485         }
1486       }
1487     }
1488 
1489     // Convert to equivalent cmpi predicate
1490     CmpIPredicate pred;
1491     switch (op.getPredicate()) {
1492     case CmpFPredicate::ORD:
1493       // Int to fp conversion doesn't create a nan (ord checks neither is a nan)
1494       rewriter.replaceOpWithNewOp<ConstantIntOp>(op, /*value=*/true,
1495                                                  /*width=*/1);
1496       return success();
1497     case CmpFPredicate::UNO:
1498       // Int to fp conversion doesn't create a nan (uno checks either is a nan)
1499       rewriter.replaceOpWithNewOp<ConstantIntOp>(op, /*value=*/false,
1500                                                  /*width=*/1);
1501       return success();
1502     default:
1503       pred = convertToIntegerPredicate(op.getPredicate(), isUnsigned);
1504       break;
1505     }
1506 
1507     if (!isUnsigned) {
1508       // If the rhs value is > SignedMax, fold the comparison.  This handles
1509       // +INF and large values.
1510       APFloat signedMax(rhs.getSemantics());
1511       signedMax.convertFromAPInt(APInt::getSignedMaxValue(intWidth), true,
1512                                  APFloat::rmNearestTiesToEven);
1513       if (signedMax < rhs) { // smax < 13123.0
1514         if (pred == CmpIPredicate::ne || pred == CmpIPredicate::slt ||
1515             pred == CmpIPredicate::sle)
1516           rewriter.replaceOpWithNewOp<ConstantIntOp>(op, /*value=*/true,
1517                                                      /*width=*/1);
1518         else
1519           rewriter.replaceOpWithNewOp<ConstantIntOp>(op, /*value=*/false,
1520                                                      /*width=*/1);
1521         return success();
1522       }
1523     } else {
1524       // If the rhs value is > UnsignedMax, fold the comparison. This handles
1525       // +INF and large values.
1526       APFloat unsignedMax(rhs.getSemantics());
1527       unsignedMax.convertFromAPInt(APInt::getMaxValue(intWidth), false,
1528                                    APFloat::rmNearestTiesToEven);
1529       if (unsignedMax < rhs) { // umax < 13123.0
1530         if (pred == CmpIPredicate::ne || pred == CmpIPredicate::ult ||
1531             pred == CmpIPredicate::ule)
1532           rewriter.replaceOpWithNewOp<ConstantIntOp>(op, /*value=*/true,
1533                                                      /*width=*/1);
1534         else
1535           rewriter.replaceOpWithNewOp<ConstantIntOp>(op, /*value=*/false,
1536                                                      /*width=*/1);
1537         return success();
1538       }
1539     }
1540 
1541     if (!isUnsigned) {
1542       // See if the rhs value is < SignedMin.
1543       APFloat signedMin(rhs.getSemantics());
1544       signedMin.convertFromAPInt(APInt::getSignedMinValue(intWidth), true,
1545                                  APFloat::rmNearestTiesToEven);
1546       if (signedMin > rhs) { // smin > 12312.0
1547         if (pred == CmpIPredicate::ne || pred == CmpIPredicate::sgt ||
1548             pred == CmpIPredicate::sge)
1549           rewriter.replaceOpWithNewOp<ConstantIntOp>(op, /*value=*/true,
1550                                                      /*width=*/1);
1551         else
1552           rewriter.replaceOpWithNewOp<ConstantIntOp>(op, /*value=*/false,
1553                                                      /*width=*/1);
1554         return success();
1555       }
1556     } else {
1557       // See if the rhs value is < UnsignedMin.
1558       APFloat unsignedMin(rhs.getSemantics());
1559       unsignedMin.convertFromAPInt(APInt::getMinValue(intWidth), false,
1560                                    APFloat::rmNearestTiesToEven);
1561       if (unsignedMin > rhs) { // umin > 12312.0
1562         if (pred == CmpIPredicate::ne || pred == CmpIPredicate::ugt ||
1563             pred == CmpIPredicate::uge)
1564           rewriter.replaceOpWithNewOp<ConstantIntOp>(op, /*value=*/true,
1565                                                      /*width=*/1);
1566         else
1567           rewriter.replaceOpWithNewOp<ConstantIntOp>(op, /*value=*/false,
1568                                                      /*width=*/1);
1569         return success();
1570       }
1571     }
1572 
1573     // Okay, now we know that the FP constant fits in the range [SMIN, SMAX] or
1574     // [0, UMAX], but it may still be fractional.  See if it is fractional by
1575     // casting the FP value to the integer value and back, checking for
1576     // equality. Don't do this for zero, because -0.0 is not fractional.
1577     bool ignored;
1578     APSInt rhsInt(intWidth, isUnsigned);
1579     if (APFloat::opInvalidOp ==
1580         rhs.convertToInteger(rhsInt, APFloat::rmTowardZero, &ignored)) {
1581       // Undefined behavior invoked - the destination type can't represent
1582       // the input constant.
1583       return failure();
1584     }
1585 
1586     if (!rhs.isZero()) {
1587       APFloat apf(floatTy.getFloatSemantics(),
1588                   APInt::getZero(floatTy.getWidth()));
1589       apf.convertFromAPInt(rhsInt, !isUnsigned, APFloat::rmNearestTiesToEven);
1590 
1591       bool equal = apf == rhs;
1592       if (!equal) {
1593         // If we had a comparison against a fractional value, we have to adjust
1594         // the compare predicate and sometimes the value.  rhsInt is rounded
1595         // towards zero at this point.
1596         switch (pred) {
1597         case CmpIPredicate::ne: // (float)int != 4.4   --> true
1598           rewriter.replaceOpWithNewOp<ConstantIntOp>(op, /*value=*/true,
1599                                                      /*width=*/1);
1600           return success();
1601         case CmpIPredicate::eq: // (float)int == 4.4   --> false
1602           rewriter.replaceOpWithNewOp<ConstantIntOp>(op, /*value=*/false,
1603                                                      /*width=*/1);
1604           return success();
1605         case CmpIPredicate::ule:
1606           // (float)int <= 4.4   --> int <= 4
1607           // (float)int <= -4.4  --> false
1608           if (rhs.isNegative()) {
1609             rewriter.replaceOpWithNewOp<ConstantIntOp>(op, /*value=*/false,
1610                                                        /*width=*/1);
1611             return success();
1612           }
1613           break;
1614         case CmpIPredicate::sle:
1615           // (float)int <= 4.4   --> int <= 4
1616           // (float)int <= -4.4  --> int < -4
1617           if (rhs.isNegative())
1618             pred = CmpIPredicate::slt;
1619           break;
1620         case CmpIPredicate::ult:
1621           // (float)int < -4.4   --> false
1622           // (float)int < 4.4    --> int <= 4
1623           if (rhs.isNegative()) {
1624             rewriter.replaceOpWithNewOp<ConstantIntOp>(op, /*value=*/false,
1625                                                        /*width=*/1);
1626             return success();
1627           }
1628           pred = CmpIPredicate::ule;
1629           break;
1630         case CmpIPredicate::slt:
1631           // (float)int < -4.4   --> int < -4
1632           // (float)int < 4.4    --> int <= 4
1633           if (!rhs.isNegative())
1634             pred = CmpIPredicate::sle;
1635           break;
1636         case CmpIPredicate::ugt:
1637           // (float)int > 4.4    --> int > 4
1638           // (float)int > -4.4   --> true
1639           if (rhs.isNegative()) {
1640             rewriter.replaceOpWithNewOp<ConstantIntOp>(op, /*value=*/true,
1641                                                        /*width=*/1);
1642             return success();
1643           }
1644           break;
1645         case CmpIPredicate::sgt:
1646           // (float)int > 4.4    --> int > 4
1647           // (float)int > -4.4   --> int >= -4
1648           if (rhs.isNegative())
1649             pred = CmpIPredicate::sge;
1650           break;
1651         case CmpIPredicate::uge:
1652           // (float)int >= -4.4   --> true
1653           // (float)int >= 4.4    --> int > 4
1654           if (rhs.isNegative()) {
1655             rewriter.replaceOpWithNewOp<ConstantIntOp>(op, /*value=*/true,
1656                                                        /*width=*/1);
1657             return success();
1658           }
1659           pred = CmpIPredicate::ugt;
1660           break;
1661         case CmpIPredicate::sge:
1662           // (float)int >= -4.4   --> int >= -4
1663           // (float)int >= 4.4    --> int > 4
1664           if (!rhs.isNegative())
1665             pred = CmpIPredicate::sgt;
1666           break;
1667         }
1668       }
1669     }
1670 
1671     // Lower this FP comparison into an appropriate integer version of the
1672     // comparison.
1673     rewriter.replaceOpWithNewOp<CmpIOp>(
1674         op, pred, intVal,
1675         rewriter.create<ConstantOp>(
1676             op.getLoc(), intVal.getType(),
1677             rewriter.getIntegerAttr(intVal.getType(), rhsInt)));
1678     return success();
1679   }
1680 };
1681 
1682 void arith::CmpFOp::getCanonicalizationPatterns(RewritePatternSet &patterns,
1683                                                 MLIRContext *context) {
1684   patterns.insert<CmpFIntToFPConst>(context);
1685 }
1686 
1687 //===----------------------------------------------------------------------===//
1688 // SelectOp
1689 //===----------------------------------------------------------------------===//
1690 
1691 // Transforms a select of a boolean to arithmetic operations
1692 //
1693 //  arith.select %arg, %x, %y : i1
1694 //
1695 //  becomes
1696 //
1697 //  and(%arg, %x) or and(!%arg, %y)
1698 struct SelectI1Simplify : public OpRewritePattern<arith::SelectOp> {
1699   using OpRewritePattern<arith::SelectOp>::OpRewritePattern;
1700 
1701   LogicalResult matchAndRewrite(arith::SelectOp op,
1702                                 PatternRewriter &rewriter) const override {
1703     if (!op.getType().isInteger(1))
1704       return failure();
1705 
1706     Value falseConstant =
1707         rewriter.create<arith::ConstantIntOp>(op.getLoc(), true, 1);
1708     Value notCondition = rewriter.create<arith::XOrIOp>(
1709         op.getLoc(), op.getCondition(), falseConstant);
1710 
1711     Value trueVal = rewriter.create<arith::AndIOp>(
1712         op.getLoc(), op.getCondition(), op.getTrueValue());
1713     Value falseVal = rewriter.create<arith::AndIOp>(op.getLoc(), notCondition,
1714                                                     op.getFalseValue());
1715     rewriter.replaceOpWithNewOp<arith::OrIOp>(op, trueVal, falseVal);
1716     return success();
1717   }
1718 };
1719 
1720 //  select %arg, %c1, %c0 => extui %arg
1721 struct SelectToExtUI : public OpRewritePattern<arith::SelectOp> {
1722   using OpRewritePattern<arith::SelectOp>::OpRewritePattern;
1723 
1724   LogicalResult matchAndRewrite(arith::SelectOp op,
1725                                 PatternRewriter &rewriter) const override {
1726     // Cannot extui i1 to i1, or i1 to f32
1727     if (!op.getType().isa<IntegerType>() || op.getType().isInteger(1))
1728       return failure();
1729 
1730     // select %x, c1, %c0 => extui %arg
1731     if (matchPattern(op.getTrueValue(), m_One()))
1732       if (matchPattern(op.getFalseValue(), m_Zero())) {
1733         rewriter.replaceOpWithNewOp<arith::ExtUIOp>(op, op.getType(),
1734                                                     op.getCondition());
1735         return success();
1736       }
1737 
1738     // select %x, c0, %c1 => extui (xor %arg, true)
1739     if (matchPattern(op.getTrueValue(), m_Zero()))
1740       if (matchPattern(op.getFalseValue(), m_One())) {
1741         rewriter.replaceOpWithNewOp<arith::ExtUIOp>(
1742             op, op.getType(),
1743             rewriter.create<arith::XOrIOp>(
1744                 op.getLoc(), op.getCondition(),
1745                 rewriter.create<arith::ConstantIntOp>(
1746                     op.getLoc(), 1, op.getCondition().getType())));
1747         return success();
1748       }
1749 
1750     return failure();
1751   }
1752 };
1753 
1754 void arith::SelectOp::getCanonicalizationPatterns(RewritePatternSet &results,
1755                                                   MLIRContext *context) {
1756   results.add<SelectI1Simplify, SelectToExtUI>(context);
1757 }
1758 
1759 OpFoldResult arith::SelectOp::fold(ArrayRef<Attribute> operands) {
1760   Value trueVal = getTrueValue();
1761   Value falseVal = getFalseValue();
1762   if (trueVal == falseVal)
1763     return trueVal;
1764 
1765   Value condition = getCondition();
1766 
1767   // select true, %0, %1 => %0
1768   if (matchPattern(condition, m_One()))
1769     return trueVal;
1770 
1771   // select false, %0, %1 => %1
1772   if (matchPattern(condition, m_Zero()))
1773     return falseVal;
1774 
1775   // select %x, true, false => %x
1776   if (getType().isInteger(1))
1777     if (matchPattern(getTrueValue(), m_One()))
1778       if (matchPattern(getFalseValue(), m_Zero()))
1779         return condition;
1780 
1781   if (auto cmp = dyn_cast_or_null<arith::CmpIOp>(condition.getDefiningOp())) {
1782     auto pred = cmp.getPredicate();
1783     if (pred == arith::CmpIPredicate::eq || pred == arith::CmpIPredicate::ne) {
1784       auto cmpLhs = cmp.getLhs();
1785       auto cmpRhs = cmp.getRhs();
1786 
1787       // %0 = arith.cmpi eq, %arg0, %arg1
1788       // %1 = arith.select %0, %arg0, %arg1 => %arg1
1789 
1790       // %0 = arith.cmpi ne, %arg0, %arg1
1791       // %1 = arith.select %0, %arg0, %arg1 => %arg0
1792 
1793       if ((cmpLhs == trueVal && cmpRhs == falseVal) ||
1794           (cmpRhs == trueVal && cmpLhs == falseVal))
1795         return pred == arith::CmpIPredicate::ne ? trueVal : falseVal;
1796     }
1797   }
1798   return nullptr;
1799 }
1800 
1801 ParseResult SelectOp::parse(OpAsmParser &parser, OperationState &result) {
1802   Type conditionType, resultType;
1803   SmallVector<OpAsmParser::OperandType, 3> operands;
1804   if (parser.parseOperandList(operands, /*requiredOperandCount=*/3) ||
1805       parser.parseOptionalAttrDict(result.attributes) ||
1806       parser.parseColonType(resultType))
1807     return failure();
1808 
1809   // Check for the explicit condition type if this is a masked tensor or vector.
1810   if (succeeded(parser.parseOptionalComma())) {
1811     conditionType = resultType;
1812     if (parser.parseType(resultType))
1813       return failure();
1814   } else {
1815     conditionType = parser.getBuilder().getI1Type();
1816   }
1817 
1818   result.addTypes(resultType);
1819   return parser.resolveOperands(operands,
1820                                 {conditionType, resultType, resultType},
1821                                 parser.getNameLoc(), result.operands);
1822 }
1823 
1824 void arith::SelectOp::print(OpAsmPrinter &p) {
1825   p << " " << getOperands();
1826   p.printOptionalAttrDict((*this)->getAttrs());
1827   p << " : ";
1828   if (ShapedType condType = getCondition().getType().dyn_cast<ShapedType>())
1829     p << condType << ", ";
1830   p << getType();
1831 }
1832 
1833 LogicalResult arith::SelectOp::verify() {
1834   Type conditionType = getCondition().getType();
1835   if (conditionType.isSignlessInteger(1))
1836     return success();
1837 
1838   // If the result type is a vector or tensor, the type can be a mask with the
1839   // same elements.
1840   Type resultType = getType();
1841   if (!resultType.isa<TensorType, VectorType>())
1842     return emitOpError() << "expected condition to be a signless i1, but got "
1843                          << conditionType;
1844   Type shapedConditionType = getI1SameShape(resultType);
1845   if (conditionType != shapedConditionType) {
1846     return emitOpError() << "expected condition type to have the same shape "
1847                             "as the result type, expected "
1848                          << shapedConditionType << ", but got "
1849                          << conditionType;
1850   }
1851   return success();
1852 }
1853 
1854 //===----------------------------------------------------------------------===//
1855 // Atomic Enum
1856 //===----------------------------------------------------------------------===//
1857 
1858 /// Returns the identity value attribute associated with an AtomicRMWKind op.
1859 Attribute mlir::arith::getIdentityValueAttr(AtomicRMWKind kind, Type resultType,
1860                                             OpBuilder &builder, Location loc) {
1861   switch (kind) {
1862   case AtomicRMWKind::maxf:
1863     return builder.getFloatAttr(
1864         resultType,
1865         APFloat::getInf(resultType.cast<FloatType>().getFloatSemantics(),
1866                         /*Negative=*/true));
1867   case AtomicRMWKind::addf:
1868   case AtomicRMWKind::addi:
1869   case AtomicRMWKind::maxu:
1870   case AtomicRMWKind::ori:
1871     return builder.getZeroAttr(resultType);
1872   case AtomicRMWKind::andi:
1873     return builder.getIntegerAttr(
1874         resultType,
1875         APInt::getAllOnes(resultType.cast<IntegerType>().getWidth()));
1876   case AtomicRMWKind::maxs:
1877     return builder.getIntegerAttr(
1878         resultType,
1879         APInt::getSignedMinValue(resultType.cast<IntegerType>().getWidth()));
1880   case AtomicRMWKind::minf:
1881     return builder.getFloatAttr(
1882         resultType,
1883         APFloat::getInf(resultType.cast<FloatType>().getFloatSemantics(),
1884                         /*Negative=*/false));
1885   case AtomicRMWKind::mins:
1886     return builder.getIntegerAttr(
1887         resultType,
1888         APInt::getSignedMaxValue(resultType.cast<IntegerType>().getWidth()));
1889   case AtomicRMWKind::minu:
1890     return builder.getIntegerAttr(
1891         resultType,
1892         APInt::getMaxValue(resultType.cast<IntegerType>().getWidth()));
1893   case AtomicRMWKind::muli:
1894     return builder.getIntegerAttr(resultType, 1);
1895   case AtomicRMWKind::mulf:
1896     return builder.getFloatAttr(resultType, 1);
1897   // TODO: Add remaining reduction operations.
1898   default:
1899     (void)emitOptionalError(loc, "Reduction operation type not supported");
1900     break;
1901   }
1902   return nullptr;
1903 }
1904 
1905 /// Returns the identity value associated with an AtomicRMWKind op.
1906 Value mlir::arith::getIdentityValue(AtomicRMWKind op, Type resultType,
1907                                     OpBuilder &builder, Location loc) {
1908   Attribute attr = getIdentityValueAttr(op, resultType, builder, loc);
1909   return builder.create<arith::ConstantOp>(loc, attr);
1910 }
1911 
1912 /// Return the value obtained by applying the reduction operation kind
1913 /// associated with a binary AtomicRMWKind op to `lhs` and `rhs`.
1914 Value mlir::arith::getReductionOp(AtomicRMWKind op, OpBuilder &builder,
1915                                   Location loc, Value lhs, Value rhs) {
1916   switch (op) {
1917   case AtomicRMWKind::addf:
1918     return builder.create<arith::AddFOp>(loc, lhs, rhs);
1919   case AtomicRMWKind::addi:
1920     return builder.create<arith::AddIOp>(loc, lhs, rhs);
1921   case AtomicRMWKind::mulf:
1922     return builder.create<arith::MulFOp>(loc, lhs, rhs);
1923   case AtomicRMWKind::muli:
1924     return builder.create<arith::MulIOp>(loc, lhs, rhs);
1925   case AtomicRMWKind::maxf:
1926     return builder.create<arith::MaxFOp>(loc, lhs, rhs);
1927   case AtomicRMWKind::minf:
1928     return builder.create<arith::MinFOp>(loc, lhs, rhs);
1929   case AtomicRMWKind::maxs:
1930     return builder.create<arith::MaxSIOp>(loc, lhs, rhs);
1931   case AtomicRMWKind::mins:
1932     return builder.create<arith::MinSIOp>(loc, lhs, rhs);
1933   case AtomicRMWKind::maxu:
1934     return builder.create<arith::MaxUIOp>(loc, lhs, rhs);
1935   case AtomicRMWKind::minu:
1936     return builder.create<arith::MinUIOp>(loc, lhs, rhs);
1937   case AtomicRMWKind::ori:
1938     return builder.create<arith::OrIOp>(loc, lhs, rhs);
1939   case AtomicRMWKind::andi:
1940     return builder.create<arith::AndIOp>(loc, lhs, rhs);
1941   // TODO: Add remaining reduction operations.
1942   default:
1943     (void)emitOptionalError(loc, "Reduction operation type not supported");
1944     break;
1945   }
1946   return nullptr;
1947 }
1948 
1949 //===----------------------------------------------------------------------===//
1950 // TableGen'd op method definitions
1951 //===----------------------------------------------------------------------===//
1952 
1953 #define GET_OP_CLASSES
1954 #include "mlir/Dialect/Arithmetic/IR/ArithmeticOps.cpp.inc"
1955 
1956 //===----------------------------------------------------------------------===//
1957 // TableGen'd enum attribute definitions
1958 //===----------------------------------------------------------------------===//
1959 
1960 #include "mlir/Dialect/Arithmetic/IR/ArithmeticOpsEnums.cpp.inc"
1961