1 //===- ArithmeticOps.cpp - MLIR Arithmetic dialect ops implementation -----===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 
9 #include <utility>
10 
11 #include "mlir/Dialect/Arithmetic/IR/Arithmetic.h"
12 #include "mlir/Dialect/CommonFolders.h"
13 #include "mlir/IR/Builders.h"
14 #include "mlir/IR/Matchers.h"
15 #include "mlir/IR/OpImplementation.h"
16 #include "mlir/IR/PatternMatch.h"
17 #include "mlir/IR/TypeUtilities.h"
18 #include "llvm/ADT/SmallString.h"
19 
20 #include "llvm/ADT/APSInt.h"
21 
22 using namespace mlir;
23 using namespace mlir::arith;
24 
25 //===----------------------------------------------------------------------===//
26 // Pattern helpers
27 //===----------------------------------------------------------------------===//
28 
29 static IntegerAttr addIntegerAttrs(PatternRewriter &builder, Value res,
30                                    Attribute lhs, Attribute rhs) {
31   return builder.getIntegerAttr(res.getType(),
32                                 lhs.cast<IntegerAttr>().getInt() +
33                                     rhs.cast<IntegerAttr>().getInt());
34 }
35 
36 static IntegerAttr subIntegerAttrs(PatternRewriter &builder, Value res,
37                                    Attribute lhs, Attribute rhs) {
38   return builder.getIntegerAttr(res.getType(),
39                                 lhs.cast<IntegerAttr>().getInt() -
40                                     rhs.cast<IntegerAttr>().getInt());
41 }
42 
43 /// Invert an integer comparison predicate.
44 arith::CmpIPredicate arith::invertPredicate(arith::CmpIPredicate pred) {
45   switch (pred) {
46   case arith::CmpIPredicate::eq:
47     return arith::CmpIPredicate::ne;
48   case arith::CmpIPredicate::ne:
49     return arith::CmpIPredicate::eq;
50   case arith::CmpIPredicate::slt:
51     return arith::CmpIPredicate::sge;
52   case arith::CmpIPredicate::sle:
53     return arith::CmpIPredicate::sgt;
54   case arith::CmpIPredicate::sgt:
55     return arith::CmpIPredicate::sle;
56   case arith::CmpIPredicate::sge:
57     return arith::CmpIPredicate::slt;
58   case arith::CmpIPredicate::ult:
59     return arith::CmpIPredicate::uge;
60   case arith::CmpIPredicate::ule:
61     return arith::CmpIPredicate::ugt;
62   case arith::CmpIPredicate::ugt:
63     return arith::CmpIPredicate::ule;
64   case arith::CmpIPredicate::uge:
65     return arith::CmpIPredicate::ult;
66   }
67   llvm_unreachable("unknown cmpi predicate kind");
68 }
69 
70 static arith::CmpIPredicateAttr invertPredicate(arith::CmpIPredicateAttr pred) {
71   return arith::CmpIPredicateAttr::get(pred.getContext(),
72                                        invertPredicate(pred.getValue()));
73 }
74 
75 //===----------------------------------------------------------------------===//
76 // TableGen'd canonicalization patterns
77 //===----------------------------------------------------------------------===//
78 
79 namespace {
80 #include "ArithmeticCanonicalization.inc"
81 } // namespace
82 
83 //===----------------------------------------------------------------------===//
84 // ConstantOp
85 //===----------------------------------------------------------------------===//
86 
87 void arith::ConstantOp::getAsmResultNames(
88     function_ref<void(Value, StringRef)> setNameFn) {
89   auto type = getType();
90   if (auto intCst = getValue().dyn_cast<IntegerAttr>()) {
91     auto intType = type.dyn_cast<IntegerType>();
92 
93     // Sugar i1 constants with 'true' and 'false'.
94     if (intType && intType.getWidth() == 1)
95       return setNameFn(getResult(), (intCst.getInt() ? "true" : "false"));
96 
97     // Otherwise, build a compex name with the value and type.
98     SmallString<32> specialNameBuffer;
99     llvm::raw_svector_ostream specialName(specialNameBuffer);
100     specialName << 'c' << intCst.getInt();
101     if (intType)
102       specialName << '_' << type;
103     setNameFn(getResult(), specialName.str());
104   } else {
105     setNameFn(getResult(), "cst");
106   }
107 }
108 
109 /// TODO: disallow arith.constant to return anything other than signless integer
110 /// or float like.
111 LogicalResult arith::ConstantOp::verify() {
112   auto type = getType();
113   // The value's type must match the return type.
114   if (getValue().getType() != type) {
115     return emitOpError() << "value type " << getValue().getType()
116                          << " must match return type: " << type;
117   }
118   // Integer values must be signless.
119   if (type.isa<IntegerType>() && !type.cast<IntegerType>().isSignless())
120     return emitOpError("integer return type must be signless");
121   // Any float or elements attribute are acceptable.
122   if (!getValue().isa<IntegerAttr, FloatAttr, ElementsAttr>()) {
123     return emitOpError(
124         "value must be an integer, float, or elements attribute");
125   }
126   return success();
127 }
128 
129 bool arith::ConstantOp::isBuildableWith(Attribute value, Type type) {
130   // The value's type must be the same as the provided type.
131   if (value.getType() != type)
132     return false;
133   // Integer values must be signless.
134   if (type.isa<IntegerType>() && !type.cast<IntegerType>().isSignless())
135     return false;
136   // Integer, float, and element attributes are buildable.
137   return value.isa<IntegerAttr, FloatAttr, ElementsAttr>();
138 }
139 
140 OpFoldResult arith::ConstantOp::fold(ArrayRef<Attribute> operands) {
141   return getValue();
142 }
143 
144 void arith::ConstantIntOp::build(OpBuilder &builder, OperationState &result,
145                                  int64_t value, unsigned width) {
146   auto type = builder.getIntegerType(width);
147   arith::ConstantOp::build(builder, result, type,
148                            builder.getIntegerAttr(type, value));
149 }
150 
151 void arith::ConstantIntOp::build(OpBuilder &builder, OperationState &result,
152                                  int64_t value, Type type) {
153   assert(type.isSignlessInteger() &&
154          "ConstantIntOp can only have signless integer type values");
155   arith::ConstantOp::build(builder, result, type,
156                            builder.getIntegerAttr(type, value));
157 }
158 
159 bool arith::ConstantIntOp::classof(Operation *op) {
160   if (auto constOp = dyn_cast_or_null<arith::ConstantOp>(op))
161     return constOp.getType().isSignlessInteger();
162   return false;
163 }
164 
165 void arith::ConstantFloatOp::build(OpBuilder &builder, OperationState &result,
166                                    const APFloat &value, FloatType type) {
167   arith::ConstantOp::build(builder, result, type,
168                            builder.getFloatAttr(type, value));
169 }
170 
171 bool arith::ConstantFloatOp::classof(Operation *op) {
172   if (auto constOp = dyn_cast_or_null<arith::ConstantOp>(op))
173     return constOp.getType().isa<FloatType>();
174   return false;
175 }
176 
177 void arith::ConstantIndexOp::build(OpBuilder &builder, OperationState &result,
178                                    int64_t value) {
179   arith::ConstantOp::build(builder, result, builder.getIndexType(),
180                            builder.getIndexAttr(value));
181 }
182 
183 bool arith::ConstantIndexOp::classof(Operation *op) {
184   if (auto constOp = dyn_cast_or_null<arith::ConstantOp>(op))
185     return constOp.getType().isIndex();
186   return false;
187 }
188 
189 //===----------------------------------------------------------------------===//
190 // AddIOp
191 //===----------------------------------------------------------------------===//
192 
193 OpFoldResult arith::AddIOp::fold(ArrayRef<Attribute> operands) {
194   // addi(x, 0) -> x
195   if (matchPattern(getRhs(), m_Zero()))
196     return getLhs();
197 
198   // addi(subi(a, b), b) -> a
199   if (auto sub = getLhs().getDefiningOp<SubIOp>())
200     if (getRhs() == sub.getRhs())
201       return sub.getLhs();
202 
203   // addi(b, subi(a, b)) -> a
204   if (auto sub = getRhs().getDefiningOp<SubIOp>())
205     if (getLhs() == sub.getRhs())
206       return sub.getLhs();
207 
208   return constFoldBinaryOp<IntegerAttr>(
209       operands, [](APInt a, const APInt &b) { return std::move(a) + b; });
210 }
211 
212 void arith::AddIOp::getCanonicalizationPatterns(
213     RewritePatternSet &patterns, MLIRContext *context) {
214   patterns.add<AddIAddConstant, AddISubConstantRHS, AddISubConstantLHS>(
215       context);
216 }
217 
218 //===----------------------------------------------------------------------===//
219 // SubIOp
220 //===----------------------------------------------------------------------===//
221 
222 OpFoldResult arith::SubIOp::fold(ArrayRef<Attribute> operands) {
223   // subi(x,x) -> 0
224   if (getOperand(0) == getOperand(1))
225     return Builder(getContext()).getZeroAttr(getType());
226   // subi(x,0) -> x
227   if (matchPattern(getRhs(), m_Zero()))
228     return getLhs();
229 
230   return constFoldBinaryOp<IntegerAttr>(
231       operands, [](APInt a, const APInt &b) { return std::move(a) - b; });
232 }
233 
234 void arith::SubIOp::getCanonicalizationPatterns(
235     RewritePatternSet &patterns, MLIRContext *context) {
236   patterns
237       .add<SubIRHSAddConstant, SubILHSAddConstant, SubIRHSSubConstantRHS,
238            SubIRHSSubConstantLHS, SubILHSSubConstantRHS, SubILHSSubConstantLHS>(
239           context);
240 }
241 
242 //===----------------------------------------------------------------------===//
243 // MulIOp
244 //===----------------------------------------------------------------------===//
245 
246 OpFoldResult arith::MulIOp::fold(ArrayRef<Attribute> operands) {
247   // muli(x, 0) -> 0
248   if (matchPattern(getRhs(), m_Zero()))
249     return getRhs();
250   // muli(x, 1) -> x
251   if (matchPattern(getRhs(), m_One()))
252     return getOperand(0);
253   // TODO: Handle the overflow case.
254 
255   // default folder
256   return constFoldBinaryOp<IntegerAttr>(
257       operands, [](const APInt &a, const APInt &b) { return a * b; });
258 }
259 
260 //===----------------------------------------------------------------------===//
261 // DivUIOp
262 //===----------------------------------------------------------------------===//
263 
264 OpFoldResult arith::DivUIOp::fold(ArrayRef<Attribute> operands) {
265   // Don't fold if it would require a division by zero.
266   bool div0 = false;
267   auto result =
268       constFoldBinaryOp<IntegerAttr>(operands, [&](APInt a, const APInt &b) {
269         if (div0 || !b) {
270           div0 = true;
271           return a;
272         }
273         return a.udiv(b);
274       });
275 
276   // Fold out division by one. Assumes all tensors of all ones are splats.
277   if (auto rhs = operands[1].dyn_cast_or_null<IntegerAttr>()) {
278     if (rhs.getValue() == 1)
279       return getLhs();
280   } else if (auto rhs = operands[1].dyn_cast_or_null<SplatElementsAttr>()) {
281     if (rhs.getSplatValue<IntegerAttr>().getValue() == 1)
282       return getLhs();
283   }
284 
285   return div0 ? Attribute() : result;
286 }
287 
288 //===----------------------------------------------------------------------===//
289 // DivSIOp
290 //===----------------------------------------------------------------------===//
291 
292 OpFoldResult arith::DivSIOp::fold(ArrayRef<Attribute> operands) {
293   // Don't fold if it would overflow or if it requires a division by zero.
294   bool overflowOrDiv0 = false;
295   auto result =
296       constFoldBinaryOp<IntegerAttr>(operands, [&](APInt a, const APInt &b) {
297         if (overflowOrDiv0 || !b) {
298           overflowOrDiv0 = true;
299           return a;
300         }
301         return a.sdiv_ov(b, overflowOrDiv0);
302       });
303 
304   // Fold out division by one. Assumes all tensors of all ones are splats.
305   if (auto rhs = operands[1].dyn_cast_or_null<IntegerAttr>()) {
306     if (rhs.getValue() == 1)
307       return getLhs();
308   } else if (auto rhs = operands[1].dyn_cast_or_null<SplatElementsAttr>()) {
309     if (rhs.getSplatValue<IntegerAttr>().getValue() == 1)
310       return getLhs();
311   }
312 
313   return overflowOrDiv0 ? Attribute() : result;
314 }
315 
316 //===----------------------------------------------------------------------===//
317 // Ceil and floor division folding helpers
318 //===----------------------------------------------------------------------===//
319 
320 static APInt signedCeilNonnegInputs(const APInt &a, const APInt &b,
321                                     bool &overflow) {
322   // Returns (a-1)/b + 1
323   APInt one(a.getBitWidth(), 1, true); // Signed value 1.
324   APInt val = a.ssub_ov(one, overflow).sdiv_ov(b, overflow);
325   return val.sadd_ov(one, overflow);
326 }
327 
328 //===----------------------------------------------------------------------===//
329 // CeilDivUIOp
330 //===----------------------------------------------------------------------===//
331 
332 OpFoldResult arith::CeilDivUIOp::fold(ArrayRef<Attribute> operands) {
333   bool overflowOrDiv0 = false;
334   auto result =
335       constFoldBinaryOp<IntegerAttr>(operands, [&](APInt a, const APInt &b) {
336         if (overflowOrDiv0 || !b) {
337           overflowOrDiv0 = true;
338           return a;
339         }
340         APInt quotient = a.udiv(b);
341         if (!a.urem(b))
342           return quotient;
343         APInt one(a.getBitWidth(), 1, true);
344         return quotient.uadd_ov(one, overflowOrDiv0);
345       });
346   // Fold out ceil division by one. Assumes all tensors of all ones are
347   // splats.
348   if (auto rhs = operands[1].dyn_cast_or_null<IntegerAttr>()) {
349     if (rhs.getValue() == 1)
350       return getLhs();
351   } else if (auto rhs = operands[1].dyn_cast_or_null<SplatElementsAttr>()) {
352     if (rhs.getSplatValue<IntegerAttr>().getValue() == 1)
353       return getLhs();
354   }
355 
356   return overflowOrDiv0 ? Attribute() : result;
357 }
358 
359 //===----------------------------------------------------------------------===//
360 // CeilDivSIOp
361 //===----------------------------------------------------------------------===//
362 
363 OpFoldResult arith::CeilDivSIOp::fold(ArrayRef<Attribute> operands) {
364   // Don't fold if it would overflow or if it requires a division by zero.
365   bool overflowOrDiv0 = false;
366   auto result =
367       constFoldBinaryOp<IntegerAttr>(operands, [&](APInt a, const APInt &b) {
368         if (overflowOrDiv0 || !b) {
369           overflowOrDiv0 = true;
370           return a;
371         }
372         if (!a)
373           return a;
374         // After this point we know that neither a or b are zero.
375         unsigned bits = a.getBitWidth();
376         APInt zero = APInt::getZero(bits);
377         bool aGtZero = a.sgt(zero);
378         bool bGtZero = b.sgt(zero);
379         if (aGtZero && bGtZero) {
380           // Both positive, return ceil(a, b).
381           return signedCeilNonnegInputs(a, b, overflowOrDiv0);
382         }
383         if (!aGtZero && !bGtZero) {
384           // Both negative, return ceil(-a, -b).
385           APInt posA = zero.ssub_ov(a, overflowOrDiv0);
386           APInt posB = zero.ssub_ov(b, overflowOrDiv0);
387           return signedCeilNonnegInputs(posA, posB, overflowOrDiv0);
388         }
389         if (!aGtZero && bGtZero) {
390           // A is negative, b is positive, return - ( -a / b).
391           APInt posA = zero.ssub_ov(a, overflowOrDiv0);
392           APInt div = posA.sdiv_ov(b, overflowOrDiv0);
393           return zero.ssub_ov(div, overflowOrDiv0);
394         }
395         // A is positive, b is negative, return - (a / -b).
396         APInt posB = zero.ssub_ov(b, overflowOrDiv0);
397         APInt div = a.sdiv_ov(posB, overflowOrDiv0);
398         return zero.ssub_ov(div, overflowOrDiv0);
399       });
400 
401   // Fold out ceil division by one. Assumes all tensors of all ones are
402   // splats.
403   if (auto rhs = operands[1].dyn_cast_or_null<IntegerAttr>()) {
404     if (rhs.getValue() == 1)
405       return getLhs();
406   } else if (auto rhs = operands[1].dyn_cast_or_null<SplatElementsAttr>()) {
407     if (rhs.getSplatValue<IntegerAttr>().getValue() == 1)
408       return getLhs();
409   }
410 
411   return overflowOrDiv0 ? Attribute() : result;
412 }
413 
414 //===----------------------------------------------------------------------===//
415 // FloorDivSIOp
416 //===----------------------------------------------------------------------===//
417 
418 OpFoldResult arith::FloorDivSIOp::fold(ArrayRef<Attribute> operands) {
419   // Don't fold if it would overflow or if it requires a division by zero.
420   bool overflowOrDiv0 = false;
421   auto result =
422       constFoldBinaryOp<IntegerAttr>(operands, [&](APInt a, const APInt &b) {
423         if (overflowOrDiv0 || !b) {
424           overflowOrDiv0 = true;
425           return a;
426         }
427         if (!a)
428           return a;
429         // After this point we know that neither a or b are zero.
430         unsigned bits = a.getBitWidth();
431         APInt zero = APInt::getZero(bits);
432         bool aGtZero = a.sgt(zero);
433         bool bGtZero = b.sgt(zero);
434         if (aGtZero && bGtZero) {
435           // Both positive, return a / b.
436           return a.sdiv_ov(b, overflowOrDiv0);
437         }
438         if (!aGtZero && !bGtZero) {
439           // Both negative, return -a / -b.
440           APInt posA = zero.ssub_ov(a, overflowOrDiv0);
441           APInt posB = zero.ssub_ov(b, overflowOrDiv0);
442           return posA.sdiv_ov(posB, overflowOrDiv0);
443         }
444         if (!aGtZero && bGtZero) {
445           // A is negative, b is positive, return - ceil(-a, b).
446           APInt posA = zero.ssub_ov(a, overflowOrDiv0);
447           APInt ceil = signedCeilNonnegInputs(posA, b, overflowOrDiv0);
448           return zero.ssub_ov(ceil, overflowOrDiv0);
449         }
450         // A is positive, b is negative, return - ceil(a, -b).
451         APInt posB = zero.ssub_ov(b, overflowOrDiv0);
452         APInt ceil = signedCeilNonnegInputs(a, posB, overflowOrDiv0);
453         return zero.ssub_ov(ceil, overflowOrDiv0);
454       });
455 
456   // Fold out floor division by one. Assumes all tensors of all ones are
457   // splats.
458   if (auto rhs = operands[1].dyn_cast_or_null<IntegerAttr>()) {
459     if (rhs.getValue() == 1)
460       return getLhs();
461   } else if (auto rhs = operands[1].dyn_cast_or_null<SplatElementsAttr>()) {
462     if (rhs.getSplatValue<IntegerAttr>().getValue() == 1)
463       return getLhs();
464   }
465 
466   return overflowOrDiv0 ? Attribute() : result;
467 }
468 
469 //===----------------------------------------------------------------------===//
470 // RemUIOp
471 //===----------------------------------------------------------------------===//
472 
473 OpFoldResult arith::RemUIOp::fold(ArrayRef<Attribute> operands) {
474   auto rhs = operands.back().dyn_cast_or_null<IntegerAttr>();
475   if (!rhs)
476     return {};
477   auto rhsValue = rhs.getValue();
478 
479   // x % 1 = 0
480   if (rhsValue.isOneValue())
481     return IntegerAttr::get(rhs.getType(), APInt(rhsValue.getBitWidth(), 0));
482 
483   // Don't fold if it requires division by zero.
484   if (rhsValue.isNullValue())
485     return {};
486 
487   auto lhs = operands.front().dyn_cast_or_null<IntegerAttr>();
488   if (!lhs)
489     return {};
490   return IntegerAttr::get(lhs.getType(), lhs.getValue().urem(rhsValue));
491 }
492 
493 //===----------------------------------------------------------------------===//
494 // RemSIOp
495 //===----------------------------------------------------------------------===//
496 
497 OpFoldResult arith::RemSIOp::fold(ArrayRef<Attribute> operands) {
498   auto rhs = operands.back().dyn_cast_or_null<IntegerAttr>();
499   if (!rhs)
500     return {};
501   auto rhsValue = rhs.getValue();
502 
503   // x % 1 = 0
504   if (rhsValue.isOneValue())
505     return IntegerAttr::get(rhs.getType(), APInt(rhsValue.getBitWidth(), 0));
506 
507   // Don't fold if it requires division by zero.
508   if (rhsValue.isNullValue())
509     return {};
510 
511   auto lhs = operands.front().dyn_cast_or_null<IntegerAttr>();
512   if (!lhs)
513     return {};
514   return IntegerAttr::get(lhs.getType(), lhs.getValue().srem(rhsValue));
515 }
516 
517 //===----------------------------------------------------------------------===//
518 // AndIOp
519 //===----------------------------------------------------------------------===//
520 
521 OpFoldResult arith::AndIOp::fold(ArrayRef<Attribute> operands) {
522   /// and(x, 0) -> 0
523   if (matchPattern(getRhs(), m_Zero()))
524     return getRhs();
525   /// and(x, allOnes) -> x
526   APInt intValue;
527   if (matchPattern(getRhs(), m_ConstantInt(&intValue)) && intValue.isAllOnes())
528     return getLhs();
529 
530   return constFoldBinaryOp<IntegerAttr>(
531       operands, [](APInt a, const APInt &b) { return std::move(a) & b; });
532 }
533 
534 //===----------------------------------------------------------------------===//
535 // OrIOp
536 //===----------------------------------------------------------------------===//
537 
538 OpFoldResult arith::OrIOp::fold(ArrayRef<Attribute> operands) {
539   /// or(x, 0) -> x
540   if (matchPattern(getRhs(), m_Zero()))
541     return getLhs();
542   /// or(x, <all ones>) -> <all ones>
543   if (auto rhsAttr = operands[1].dyn_cast_or_null<IntegerAttr>())
544     if (rhsAttr.getValue().isAllOnes())
545       return rhsAttr;
546 
547   return constFoldBinaryOp<IntegerAttr>(
548       operands, [](APInt a, const APInt &b) { return std::move(a) | b; });
549 }
550 
551 //===----------------------------------------------------------------------===//
552 // XOrIOp
553 //===----------------------------------------------------------------------===//
554 
555 OpFoldResult arith::XOrIOp::fold(ArrayRef<Attribute> operands) {
556   /// xor(x, 0) -> x
557   if (matchPattern(getRhs(), m_Zero()))
558     return getLhs();
559   /// xor(x, x) -> 0
560   if (getLhs() == getRhs())
561     return Builder(getContext()).getZeroAttr(getType());
562   /// xor(xor(x, a), a) -> x
563   if (arith::XOrIOp prev = getLhs().getDefiningOp<arith::XOrIOp>())
564     if (prev.getRhs() == getRhs())
565       return prev.getLhs();
566 
567   return constFoldBinaryOp<IntegerAttr>(
568       operands, [](APInt a, const APInt &b) { return std::move(a) ^ b; });
569 }
570 
/// Registers the TableGen'd canonicalization pattern for `arith.xori`.
/// NOTE(review): `XOrINotCmpI` comes from ArithmeticCanonicalization.inc. By
/// its name it presumably rewrites an `xori` of a `cmpi` result into a `cmpi`
/// with the inverted predicate; confirm against the .td source.
void arith::XOrIOp::getCanonicalizationPatterns(
    RewritePatternSet &patterns, MLIRContext *context) {
  patterns.add<XOrINotCmpI>(context);
}
575 
576 //===----------------------------------------------------------------------===//
577 // AddFOp
578 //===----------------------------------------------------------------------===//
579 
580 OpFoldResult arith::AddFOp::fold(ArrayRef<Attribute> operands) {
581   // addf(x, -0) -> x
582   if (matchPattern(getRhs(), m_NegZeroFloat()))
583     return getLhs();
584 
585   return constFoldBinaryOp<FloatAttr>(
586       operands, [](const APFloat &a, const APFloat &b) { return a + b; });
587 }
588 
589 //===----------------------------------------------------------------------===//
590 // SubFOp
591 //===----------------------------------------------------------------------===//
592 
593 OpFoldResult arith::SubFOp::fold(ArrayRef<Attribute> operands) {
594   // subf(x, +0) -> x
595   if (matchPattern(getRhs(), m_PosZeroFloat()))
596     return getLhs();
597 
598   return constFoldBinaryOp<FloatAttr>(
599       operands, [](const APFloat &a, const APFloat &b) { return a - b; });
600 }
601 
602 //===----------------------------------------------------------------------===//
603 // MaxFOp
604 //===----------------------------------------------------------------------===//
605 
606 OpFoldResult arith::MaxFOp::fold(ArrayRef<Attribute> operands) {
607   assert(operands.size() == 2 && "maxf takes two operands");
608 
609   // maxf(x,x) -> x
610   if (getLhs() == getRhs())
611     return getRhs();
612 
613   // maxf(x, -inf) -> x
614   if (matchPattern(getRhs(), m_NegInfFloat()))
615     return getLhs();
616 
617   return constFoldBinaryOp<FloatAttr>(
618       operands,
619       [](const APFloat &a, const APFloat &b) { return llvm::maximum(a, b); });
620 }
621 
622 //===----------------------------------------------------------------------===//
623 // MaxSIOp
624 //===----------------------------------------------------------------------===//
625 
626 OpFoldResult MaxSIOp::fold(ArrayRef<Attribute> operands) {
627   assert(operands.size() == 2 && "binary operation takes two operands");
628 
629   // maxsi(x,x) -> x
630   if (getLhs() == getRhs())
631     return getRhs();
632 
633   APInt intValue;
634   // maxsi(x,MAX_INT) -> MAX_INT
635   if (matchPattern(getRhs(), m_ConstantInt(&intValue)) &&
636       intValue.isMaxSignedValue())
637     return getRhs();
638 
639   // maxsi(x, MIN_INT) -> x
640   if (matchPattern(getRhs(), m_ConstantInt(&intValue)) &&
641       intValue.isMinSignedValue())
642     return getLhs();
643 
644   return constFoldBinaryOp<IntegerAttr>(operands,
645                                         [](const APInt &a, const APInt &b) {
646                                           return llvm::APIntOps::smax(a, b);
647                                         });
648 }
649 
650 //===----------------------------------------------------------------------===//
651 // MaxUIOp
652 //===----------------------------------------------------------------------===//
653 
654 OpFoldResult MaxUIOp::fold(ArrayRef<Attribute> operands) {
655   assert(operands.size() == 2 && "binary operation takes two operands");
656 
657   // maxui(x,x) -> x
658   if (getLhs() == getRhs())
659     return getRhs();
660 
661   APInt intValue;
662   // maxui(x,MAX_INT) -> MAX_INT
663   if (matchPattern(getRhs(), m_ConstantInt(&intValue)) && intValue.isMaxValue())
664     return getRhs();
665 
666   // maxui(x, MIN_INT) -> x
667   if (matchPattern(getRhs(), m_ConstantInt(&intValue)) && intValue.isMinValue())
668     return getLhs();
669 
670   return constFoldBinaryOp<IntegerAttr>(operands,
671                                         [](const APInt &a, const APInt &b) {
672                                           return llvm::APIntOps::umax(a, b);
673                                         });
674 }
675 
676 //===----------------------------------------------------------------------===//
677 // MinFOp
678 //===----------------------------------------------------------------------===//
679 
680 OpFoldResult arith::MinFOp::fold(ArrayRef<Attribute> operands) {
681   assert(operands.size() == 2 && "minf takes two operands");
682 
683   // minf(x,x) -> x
684   if (getLhs() == getRhs())
685     return getRhs();
686 
687   // minf(x, +inf) -> x
688   if (matchPattern(getRhs(), m_PosInfFloat()))
689     return getLhs();
690 
691   return constFoldBinaryOp<FloatAttr>(
692       operands,
693       [](const APFloat &a, const APFloat &b) { return llvm::minimum(a, b); });
694 }
695 
696 //===----------------------------------------------------------------------===//
697 // MinSIOp
698 //===----------------------------------------------------------------------===//
699 
700 OpFoldResult MinSIOp::fold(ArrayRef<Attribute> operands) {
701   assert(operands.size() == 2 && "binary operation takes two operands");
702 
703   // minsi(x,x) -> x
704   if (getLhs() == getRhs())
705     return getRhs();
706 
707   APInt intValue;
708   // minsi(x,MIN_INT) -> MIN_INT
709   if (matchPattern(getRhs(), m_ConstantInt(&intValue)) &&
710       intValue.isMinSignedValue())
711     return getRhs();
712 
713   // minsi(x, MAX_INT) -> x
714   if (matchPattern(getRhs(), m_ConstantInt(&intValue)) &&
715       intValue.isMaxSignedValue())
716     return getLhs();
717 
718   return constFoldBinaryOp<IntegerAttr>(operands,
719                                         [](const APInt &a, const APInt &b) {
720                                           return llvm::APIntOps::smin(a, b);
721                                         });
722 }
723 
724 //===----------------------------------------------------------------------===//
725 // MinUIOp
726 //===----------------------------------------------------------------------===//
727 
728 OpFoldResult MinUIOp::fold(ArrayRef<Attribute> operands) {
729   assert(operands.size() == 2 && "binary operation takes two operands");
730 
731   // minui(x,x) -> x
732   if (getLhs() == getRhs())
733     return getRhs();
734 
735   APInt intValue;
736   // minui(x,MIN_INT) -> MIN_INT
737   if (matchPattern(getRhs(), m_ConstantInt(&intValue)) && intValue.isMinValue())
738     return getRhs();
739 
740   // minui(x, MAX_INT) -> x
741   if (matchPattern(getRhs(), m_ConstantInt(&intValue)) && intValue.isMaxValue())
742     return getLhs();
743 
744   return constFoldBinaryOp<IntegerAttr>(operands,
745                                         [](const APInt &a, const APInt &b) {
746                                           return llvm::APIntOps::umin(a, b);
747                                         });
748 }
749 
750 //===----------------------------------------------------------------------===//
751 // MulFOp
752 //===----------------------------------------------------------------------===//
753 
754 OpFoldResult arith::MulFOp::fold(ArrayRef<Attribute> operands) {
755   APFloat floatValue(0.0f), inverseValue(0.0f);
756   // mulf(x, 1) -> x
757   if (matchPattern(getRhs(), m_OneFloat()))
758     return getLhs();
759 
760   // mulf(1, x) -> x
761   if (matchPattern(getLhs(), m_OneFloat()))
762     return getRhs();
763 
764   return constFoldBinaryOp<FloatAttr>(
765       operands, [](const APFloat &a, const APFloat &b) { return a * b; });
766 }
767 
768 //===----------------------------------------------------------------------===//
769 // DivFOp
770 //===----------------------------------------------------------------------===//
771 
772 OpFoldResult arith::DivFOp::fold(ArrayRef<Attribute> operands) {
773   APFloat floatValue(0.0f), inverseValue(0.0f);
774   // divf(x, 1) -> x
775   if (matchPattern(getRhs(), m_OneFloat()))
776     return getLhs();
777 
778   return constFoldBinaryOp<FloatAttr>(
779       operands, [](const APFloat &a, const APFloat &b) { return a / b; });
780 }
781 
782 //===----------------------------------------------------------------------===//
783 // Utility functions for verifying cast ops
784 //===----------------------------------------------------------------------===//
785 
/// Tag type that carries a pack of types. Values of this type (null pointers)
/// are passed only so that template argument deduction can recover `Types...`
/// at the call site, as `getUnderlyingType` below does for its two packs.
template <typename... Types>
using type_list = std::tuple<Types...> *;
788 
789 /// Returns a non-null type only if the provided type is one of the allowed
790 /// types or one of the allowed shaped types of the allowed types. Returns the
791 /// element type if a valid shaped type is provided.
792 template <typename... ShapedTypes, typename... ElementTypes>
793 static Type getUnderlyingType(Type type, type_list<ShapedTypes...>,
794                               type_list<ElementTypes...>) {
795   if (type.isa<ShapedType>() && !type.isa<ShapedTypes...>())
796     return {};
797 
798   auto underlyingType = getElementTypeOrSelf(type);
799   if (!underlyingType.isa<ElementTypes...>())
800     return {};
801 
802   return underlyingType;
803 }
804 
805 /// Get allowed underlying types for vectors and tensors.
806 template <typename... ElementTypes>
807 static Type getTypeIfLike(Type type) {
808   return getUnderlyingType(type, type_list<VectorType, TensorType>(),
809                            type_list<ElementTypes...>());
810 }
811 
812 /// Get allowed underlying types for vectors, tensors, and memrefs.
813 template <typename... ElementTypes>
814 static Type getTypeIfLikeOrMemRef(Type type) {
815   return getUnderlyingType(type,
816                            type_list<VectorType, TensorType, MemRefType>(),
817                            type_list<ElementTypes...>());
818 }
819 
820 static bool areValidCastInputsAndOutputs(TypeRange inputs, TypeRange outputs) {
821   return inputs.size() == 1 && outputs.size() == 1 &&
822          succeeded(verifyCompatibleShapes(inputs.front(), outputs.front()));
823 }
824 
825 //===----------------------------------------------------------------------===//
826 // Verifiers for integer and floating point extension/truncation ops
827 //===----------------------------------------------------------------------===//
828 
829 // Extend ops can only extend to a wider type.
830 template <typename ValType, typename Op>
831 static LogicalResult verifyExtOp(Op op) {
832   Type srcType = getElementTypeOrSelf(op.getIn().getType());
833   Type dstType = getElementTypeOrSelf(op.getType());
834 
835   if (srcType.cast<ValType>().getWidth() >= dstType.cast<ValType>().getWidth())
836     return op.emitError("result type ")
837            << dstType << " must be wider than operand type " << srcType;
838 
839   return success();
840 }
841 
842 // Truncate ops can only truncate to a shorter type.
843 template <typename ValType, typename Op>
844 static LogicalResult verifyTruncateOp(Op op) {
845   Type srcType = getElementTypeOrSelf(op.getIn().getType());
846   Type dstType = getElementTypeOrSelf(op.getType());
847 
848   if (srcType.cast<ValType>().getWidth() <= dstType.cast<ValType>().getWidth())
849     return op.emitError("result type ")
850            << dstType << " must be shorter than operand type " << srcType;
851 
852   return success();
853 }
854 
855 /// Validate a cast that changes the width of a type.
856 template <template <typename> class WidthComparator, typename... ElementTypes>
857 static bool checkWidthChangeCast(TypeRange inputs, TypeRange outputs) {
858   if (!areValidCastInputsAndOutputs(inputs, outputs))
859     return false;
860 
861   auto srcType = getTypeIfLike<ElementTypes...>(inputs.front());
862   auto dstType = getTypeIfLike<ElementTypes...>(outputs.front());
863   if (!srcType || !dstType)
864     return false;
865 
866   return WidthComparator<unsigned>()(dstType.getIntOrFloatBitWidth(),
867                                      srcType.getIntOrFloatBitWidth());
868 }
869 
870 //===----------------------------------------------------------------------===//
871 // ExtUIOp
872 //===----------------------------------------------------------------------===//
873 
874 OpFoldResult arith::ExtUIOp::fold(ArrayRef<Attribute> operands) {
875   if (auto lhs = operands[0].dyn_cast_or_null<IntegerAttr>())
876     return IntegerAttr::get(
877         getType(), lhs.getValue().zext(getType().getIntOrFloatBitWidth()));
878 
879   if (auto lhs = getIn().getDefiningOp<ExtUIOp>()) {
880     getInMutable().assign(lhs.getIn());
881     return getResult();
882   }
883 
884   return {};
885 }
886 
887 bool arith::ExtUIOp::areCastCompatible(TypeRange inputs, TypeRange outputs) {
888   return checkWidthChangeCast<std::greater, IntegerType>(inputs, outputs);
889 }
890 
891 LogicalResult arith::ExtUIOp::verify() {
892   return verifyExtOp<IntegerType>(*this);
893 }
894 
895 //===----------------------------------------------------------------------===//
896 // ExtSIOp
897 //===----------------------------------------------------------------------===//
898 
899 OpFoldResult arith::ExtSIOp::fold(ArrayRef<Attribute> operands) {
900   if (auto lhs = operands[0].dyn_cast_or_null<IntegerAttr>())
901     return IntegerAttr::get(
902         getType(), lhs.getValue().sext(getType().getIntOrFloatBitWidth()));
903 
904   if (auto lhs = getIn().getDefiningOp<ExtSIOp>()) {
905     getInMutable().assign(lhs.getIn());
906     return getResult();
907   }
908 
909   return {};
910 }
911 
912 bool arith::ExtSIOp::areCastCompatible(TypeRange inputs, TypeRange outputs) {
913   return checkWidthChangeCast<std::greater, IntegerType>(inputs, outputs);
914 }
915 
916 void arith::ExtSIOp::getCanonicalizationPatterns(
917     RewritePatternSet &patterns, MLIRContext *context) {
918   patterns.add<ExtSIOfExtUI>(context);
919 }
920 
921 LogicalResult arith::ExtSIOp::verify() {
922   return verifyExtOp<IntegerType>(*this);
923 }
924 
925 //===----------------------------------------------------------------------===//
926 // ExtFOp
927 //===----------------------------------------------------------------------===//
928 
929 bool arith::ExtFOp::areCastCompatible(TypeRange inputs, TypeRange outputs) {
930   return checkWidthChangeCast<std::greater, FloatType>(inputs, outputs);
931 }
932 
933 LogicalResult arith::ExtFOp::verify() { return verifyExtOp<FloatType>(*this); }
934 
935 //===----------------------------------------------------------------------===//
936 // TruncIOp
937 //===----------------------------------------------------------------------===//
938 
939 OpFoldResult arith::TruncIOp::fold(ArrayRef<Attribute> operands) {
940   assert(operands.size() == 1 && "unary operation takes one operand");
941 
942   // trunci(zexti(a)) -> a
943   // trunci(sexti(a)) -> a
944   if (matchPattern(getOperand(), m_Op<arith::ExtUIOp>()) ||
945       matchPattern(getOperand(), m_Op<arith::ExtSIOp>()))
946     return getOperand().getDefiningOp()->getOperand(0);
947 
948   // trunci(trunci(a)) -> trunci(a))
949   if (matchPattern(getOperand(), m_Op<arith::TruncIOp>())) {
950     setOperand(getOperand().getDefiningOp()->getOperand(0));
951     return getResult();
952   }
953 
954   if (!operands[0])
955     return {};
956 
957   if (auto lhs = operands[0].dyn_cast<IntegerAttr>()) {
958     return IntegerAttr::get(
959         getType(), lhs.getValue().trunc(getType().getIntOrFloatBitWidth()));
960   }
961 
962   return {};
963 }
964 
965 bool arith::TruncIOp::areCastCompatible(TypeRange inputs, TypeRange outputs) {
966   return checkWidthChangeCast<std::less, IntegerType>(inputs, outputs);
967 }
968 
969 LogicalResult arith::TruncIOp::verify() {
970   return verifyTruncateOp<IntegerType>(*this);
971 }
972 
973 //===----------------------------------------------------------------------===//
974 // TruncFOp
975 //===----------------------------------------------------------------------===//
976 
977 /// Perform safe const propagation for truncf, i.e. only propagate if FP value
978 /// can be represented without precision loss or rounding.
979 OpFoldResult arith::TruncFOp::fold(ArrayRef<Attribute> operands) {
980   assert(operands.size() == 1 && "unary operation takes one operand");
981 
982   auto constOperand = operands.front();
983   if (!constOperand || !constOperand.isa<FloatAttr>())
984     return {};
985 
986   // Convert to target type via 'double'.
987   double sourceValue =
988       constOperand.dyn_cast<FloatAttr>().getValue().convertToDouble();
989   auto targetAttr = FloatAttr::get(getType(), sourceValue);
990 
991   // Propagate if constant's value does not change after truncation.
992   if (sourceValue == targetAttr.getValue().convertToDouble())
993     return targetAttr;
994 
995   return {};
996 }
997 
998 bool arith::TruncFOp::areCastCompatible(TypeRange inputs, TypeRange outputs) {
999   return checkWidthChangeCast<std::less, FloatType>(inputs, outputs);
1000 }
1001 
1002 LogicalResult arith::TruncFOp::verify() {
1003   return verifyTruncateOp<FloatType>(*this);
1004 }
1005 
1006 //===----------------------------------------------------------------------===//
1007 // AndIOp
1008 //===----------------------------------------------------------------------===//
1009 
1010 void arith::AndIOp::getCanonicalizationPatterns(
1011     RewritePatternSet &patterns, MLIRContext *context) {
1012   patterns.add<AndOfExtUI, AndOfExtSI>(context);
1013 }
1014 
1015 //===----------------------------------------------------------------------===//
1016 // OrIOp
1017 //===----------------------------------------------------------------------===//
1018 
1019 void arith::OrIOp::getCanonicalizationPatterns(
1020     RewritePatternSet &patterns, MLIRContext *context) {
1021   patterns.add<OrOfExtUI, OrOfExtSI>(context);
1022 }
1023 
1024 //===----------------------------------------------------------------------===//
1025 // Verifiers for casts between integers and floats.
1026 //===----------------------------------------------------------------------===//
1027 
1028 template <typename From, typename To>
1029 static bool checkIntFloatCast(TypeRange inputs, TypeRange outputs) {
1030   if (!areValidCastInputsAndOutputs(inputs, outputs))
1031     return false;
1032 
1033   auto srcType = getTypeIfLike<From>(inputs.front());
1034   auto dstType = getTypeIfLike<To>(outputs.back());
1035 
1036   return srcType && dstType;
1037 }
1038 
1039 //===----------------------------------------------------------------------===//
1040 // UIToFPOp
1041 //===----------------------------------------------------------------------===//
1042 
1043 bool arith::UIToFPOp::areCastCompatible(TypeRange inputs, TypeRange outputs) {
1044   return checkIntFloatCast<IntegerType, FloatType>(inputs, outputs);
1045 }
1046 
1047 OpFoldResult arith::UIToFPOp::fold(ArrayRef<Attribute> operands) {
1048   if (auto lhs = operands[0].dyn_cast_or_null<IntegerAttr>()) {
1049     const APInt &api = lhs.getValue();
1050     FloatType floatTy = getType().cast<FloatType>();
1051     APFloat apf(floatTy.getFloatSemantics(),
1052                 APInt::getZero(floatTy.getWidth()));
1053     apf.convertFromAPInt(api, /*IsSigned=*/false, APFloat::rmNearestTiesToEven);
1054     return FloatAttr::get(floatTy, apf);
1055   }
1056   return {};
1057 }
1058 
1059 //===----------------------------------------------------------------------===//
1060 // SIToFPOp
1061 //===----------------------------------------------------------------------===//
1062 
1063 bool arith::SIToFPOp::areCastCompatible(TypeRange inputs, TypeRange outputs) {
1064   return checkIntFloatCast<IntegerType, FloatType>(inputs, outputs);
1065 }
1066 
1067 OpFoldResult arith::SIToFPOp::fold(ArrayRef<Attribute> operands) {
1068   if (auto lhs = operands[0].dyn_cast_or_null<IntegerAttr>()) {
1069     const APInt &api = lhs.getValue();
1070     FloatType floatTy = getType().cast<FloatType>();
1071     APFloat apf(floatTy.getFloatSemantics(),
1072                 APInt::getZero(floatTy.getWidth()));
1073     apf.convertFromAPInt(api, /*IsSigned=*/true, APFloat::rmNearestTiesToEven);
1074     return FloatAttr::get(floatTy, apf);
1075   }
1076   return {};
1077 }
1078 //===----------------------------------------------------------------------===//
1079 // FPToUIOp
1080 //===----------------------------------------------------------------------===//
1081 
1082 bool arith::FPToUIOp::areCastCompatible(TypeRange inputs, TypeRange outputs) {
1083   return checkIntFloatCast<FloatType, IntegerType>(inputs, outputs);
1084 }
1085 
1086 OpFoldResult arith::FPToUIOp::fold(ArrayRef<Attribute> operands) {
1087   if (auto lhs = operands[0].dyn_cast_or_null<FloatAttr>()) {
1088     const APFloat &apf = lhs.getValue();
1089     IntegerType intTy = getType().cast<IntegerType>();
1090     bool ignored;
1091     APSInt api(intTy.getWidth(), /*isUnsigned=*/true);
1092     if (APFloat::opInvalidOp ==
1093         apf.convertToInteger(api, APFloat::rmTowardZero, &ignored)) {
1094       // Undefined behavior invoked - the destination type can't represent
1095       // the input constant.
1096       return {};
1097     }
1098     return IntegerAttr::get(getType(), api);
1099   }
1100 
1101   return {};
1102 }
1103 
1104 //===----------------------------------------------------------------------===//
1105 // FPToSIOp
1106 //===----------------------------------------------------------------------===//
1107 
1108 bool arith::FPToSIOp::areCastCompatible(TypeRange inputs, TypeRange outputs) {
1109   return checkIntFloatCast<FloatType, IntegerType>(inputs, outputs);
1110 }
1111 
1112 OpFoldResult arith::FPToSIOp::fold(ArrayRef<Attribute> operands) {
1113   if (auto lhs = operands[0].dyn_cast_or_null<FloatAttr>()) {
1114     const APFloat &apf = lhs.getValue();
1115     IntegerType intTy = getType().cast<IntegerType>();
1116     bool ignored;
1117     APSInt api(intTy.getWidth(), /*isUnsigned=*/false);
1118     if (APFloat::opInvalidOp ==
1119         apf.convertToInteger(api, APFloat::rmTowardZero, &ignored)) {
1120       // Undefined behavior invoked - the destination type can't represent
1121       // the input constant.
1122       return {};
1123     }
1124     return IntegerAttr::get(getType(), api);
1125   }
1126 
1127   return {};
1128 }
1129 
1130 //===----------------------------------------------------------------------===//
1131 // IndexCastOp
1132 //===----------------------------------------------------------------------===//
1133 
1134 bool arith::IndexCastOp::areCastCompatible(TypeRange inputs,
1135                                            TypeRange outputs) {
1136   if (!areValidCastInputsAndOutputs(inputs, outputs))
1137     return false;
1138 
1139   auto srcType = getTypeIfLikeOrMemRef<IntegerType, IndexType>(inputs.front());
1140   auto dstType = getTypeIfLikeOrMemRef<IntegerType, IndexType>(outputs.front());
1141   if (!srcType || !dstType)
1142     return false;
1143 
1144   return (srcType.isIndex() && dstType.isSignlessInteger()) ||
1145          (srcType.isSignlessInteger() && dstType.isIndex());
1146 }
1147 
1148 OpFoldResult arith::IndexCastOp::fold(ArrayRef<Attribute> operands) {
1149   // index_cast(constant) -> constant
1150   // A little hack because we go through int. Otherwise, the size of the
1151   // constant might need to change.
1152   if (auto value = operands[0].dyn_cast_or_null<IntegerAttr>())
1153     return IntegerAttr::get(getType(), value.getInt());
1154 
1155   return {};
1156 }
1157 
1158 void arith::IndexCastOp::getCanonicalizationPatterns(
1159     RewritePatternSet &patterns, MLIRContext *context) {
1160   patterns.add<IndexCastOfIndexCast, IndexCastOfExtSI>(context);
1161 }
1162 
1163 //===----------------------------------------------------------------------===//
1164 // BitcastOp
1165 //===----------------------------------------------------------------------===//
1166 
1167 bool arith::BitcastOp::areCastCompatible(TypeRange inputs, TypeRange outputs) {
1168   if (!areValidCastInputsAndOutputs(inputs, outputs))
1169     return false;
1170 
1171   auto srcType =
1172       getTypeIfLikeOrMemRef<IntegerType, IndexType, FloatType>(inputs.front());
1173   auto dstType =
1174       getTypeIfLikeOrMemRef<IntegerType, IndexType, FloatType>(outputs.front());
1175   if (!srcType || !dstType)
1176     return false;
1177 
1178   return srcType.getIntOrFloatBitWidth() == dstType.getIntOrFloatBitWidth();
1179 }
1180 
1181 OpFoldResult arith::BitcastOp::fold(ArrayRef<Attribute> operands) {
1182   assert(operands.size() == 1 && "bitcast op expects 1 operand");
1183 
1184   auto resType = getType();
1185   auto operand = operands[0];
1186   if (!operand)
1187     return {};
1188 
1189   /// Bitcast dense elements.
1190   if (auto denseAttr = operand.dyn_cast_or_null<DenseElementsAttr>())
1191     return denseAttr.bitcast(resType.cast<ShapedType>().getElementType());
1192   /// Other shaped types unhandled.
1193   if (resType.isa<ShapedType>())
1194     return {};
1195 
1196   /// Bitcast integer or float to integer or float.
1197   APInt bits = operand.isa<FloatAttr>()
1198                    ? operand.cast<FloatAttr>().getValue().bitcastToAPInt()
1199                    : operand.cast<IntegerAttr>().getValue();
1200 
1201   if (auto resFloatType = resType.dyn_cast<FloatType>())
1202     return FloatAttr::get(resType,
1203                           APFloat(resFloatType.getFloatSemantics(), bits));
1204   return IntegerAttr::get(resType, bits);
1205 }
1206 
1207 void arith::BitcastOp::getCanonicalizationPatterns(
1208     RewritePatternSet &patterns, MLIRContext *context) {
1209   patterns.add<BitcastOfBitcast>(context);
1210 }
1211 
1212 //===----------------------------------------------------------------------===//
1213 // Helpers for compare ops
1214 //===----------------------------------------------------------------------===//
1215 
1216 /// Return the type of the same shape (scalar, vector or tensor) containing i1.
1217 static Type getI1SameShape(Type type) {
1218   auto i1Type = IntegerType::get(type.getContext(), 1);
1219   if (auto tensorType = type.dyn_cast<RankedTensorType>())
1220     return RankedTensorType::get(tensorType.getShape(), i1Type);
1221   if (type.isa<UnrankedTensorType>())
1222     return UnrankedTensorType::get(i1Type);
1223   if (auto vectorType = type.dyn_cast<VectorType>())
1224     return VectorType::get(vectorType.getShape(), i1Type,
1225                            vectorType.getNumScalableDims());
1226   return i1Type;
1227 }
1228 
1229 //===----------------------------------------------------------------------===//
1230 // CmpIOp
1231 //===----------------------------------------------------------------------===//
1232 
1233 /// Compute `lhs` `pred` `rhs`, where `pred` is one of the known integer
1234 /// comparison predicates.
1235 bool mlir::arith::applyCmpPredicate(arith::CmpIPredicate predicate,
1236                                     const APInt &lhs, const APInt &rhs) {
1237   switch (predicate) {
1238   case arith::CmpIPredicate::eq:
1239     return lhs.eq(rhs);
1240   case arith::CmpIPredicate::ne:
1241     return lhs.ne(rhs);
1242   case arith::CmpIPredicate::slt:
1243     return lhs.slt(rhs);
1244   case arith::CmpIPredicate::sle:
1245     return lhs.sle(rhs);
1246   case arith::CmpIPredicate::sgt:
1247     return lhs.sgt(rhs);
1248   case arith::CmpIPredicate::sge:
1249     return lhs.sge(rhs);
1250   case arith::CmpIPredicate::ult:
1251     return lhs.ult(rhs);
1252   case arith::CmpIPredicate::ule:
1253     return lhs.ule(rhs);
1254   case arith::CmpIPredicate::ugt:
1255     return lhs.ugt(rhs);
1256   case arith::CmpIPredicate::uge:
1257     return lhs.uge(rhs);
1258   }
1259   llvm_unreachable("unknown cmpi predicate kind");
1260 }
1261 
1262 /// Returns true if the predicate is true for two equal operands.
1263 static bool applyCmpPredicateToEqualOperands(arith::CmpIPredicate predicate) {
1264   switch (predicate) {
1265   case arith::CmpIPredicate::eq:
1266   case arith::CmpIPredicate::sle:
1267   case arith::CmpIPredicate::sge:
1268   case arith::CmpIPredicate::ule:
1269   case arith::CmpIPredicate::uge:
1270     return true;
1271   case arith::CmpIPredicate::ne:
1272   case arith::CmpIPredicate::slt:
1273   case arith::CmpIPredicate::sgt:
1274   case arith::CmpIPredicate::ult:
1275   case arith::CmpIPredicate::ugt:
1276     return false;
1277   }
1278   llvm_unreachable("unknown cmpi predicate kind");
1279 }
1280 
1281 static Attribute getBoolAttribute(Type type, MLIRContext *ctx, bool value) {
1282   auto boolAttr = BoolAttr::get(ctx, value);
1283   ShapedType shapedType = type.dyn_cast_or_null<ShapedType>();
1284   if (!shapedType)
1285     return boolAttr;
1286   return DenseElementsAttr::get(shapedType, boolAttr);
1287 }
1288 
1289 OpFoldResult arith::CmpIOp::fold(ArrayRef<Attribute> operands) {
1290   assert(operands.size() == 2 && "cmpi takes two operands");
1291 
1292   // cmpi(pred, x, x)
1293   if (getLhs() == getRhs()) {
1294     auto val = applyCmpPredicateToEqualOperands(getPredicate());
1295     return getBoolAttribute(getType(), getContext(), val);
1296   }
1297 
1298   if (matchPattern(getRhs(), m_Zero())) {
1299     if (auto extOp = getLhs().getDefiningOp<ExtSIOp>()) {
1300       if (extOp.getOperand().getType().cast<IntegerType>().getWidth() == 1) {
1301         // extsi(%x : i1 -> iN) != 0  ->  %x
1302         if (getPredicate() == arith::CmpIPredicate::ne) {
1303           return extOp.getOperand();
1304         }
1305       }
1306     }
1307     if (auto extOp = getLhs().getDefiningOp<ExtUIOp>()) {
1308       if (extOp.getOperand().getType().cast<IntegerType>().getWidth() == 1) {
1309         // extui(%x : i1 -> iN) != 0  ->  %x
1310         if (getPredicate() == arith::CmpIPredicate::ne) {
1311           return extOp.getOperand();
1312         }
1313       }
1314     }
1315   }
1316 
1317   auto lhs = operands.front().dyn_cast_or_null<IntegerAttr>();
1318   auto rhs = operands.back().dyn_cast_or_null<IntegerAttr>();
1319   if (!lhs || !rhs)
1320     return {};
1321 
1322   auto val = applyCmpPredicate(getPredicate(), lhs.getValue(), rhs.getValue());
1323   return BoolAttr::get(getContext(), val);
1324 }
1325 
1326 void arith::CmpIOp::getCanonicalizationPatterns(RewritePatternSet &patterns,
1327                                                 MLIRContext *context) {
1328   patterns.insert<CmpIExtSI, CmpIExtUI>(context);
1329 }
1330 
1331 //===----------------------------------------------------------------------===//
1332 // CmpFOp
1333 //===----------------------------------------------------------------------===//
1334 
1335 /// Compute `lhs` `pred` `rhs`, where `pred` is one of the known floating point
1336 /// comparison predicates.
1337 bool mlir::arith::applyCmpPredicate(arith::CmpFPredicate predicate,
1338                                     const APFloat &lhs, const APFloat &rhs) {
1339   auto cmpResult = lhs.compare(rhs);
1340   switch (predicate) {
1341   case arith::CmpFPredicate::AlwaysFalse:
1342     return false;
1343   case arith::CmpFPredicate::OEQ:
1344     return cmpResult == APFloat::cmpEqual;
1345   case arith::CmpFPredicate::OGT:
1346     return cmpResult == APFloat::cmpGreaterThan;
1347   case arith::CmpFPredicate::OGE:
1348     return cmpResult == APFloat::cmpGreaterThan ||
1349            cmpResult == APFloat::cmpEqual;
1350   case arith::CmpFPredicate::OLT:
1351     return cmpResult == APFloat::cmpLessThan;
1352   case arith::CmpFPredicate::OLE:
1353     return cmpResult == APFloat::cmpLessThan || cmpResult == APFloat::cmpEqual;
1354   case arith::CmpFPredicate::ONE:
1355     return cmpResult != APFloat::cmpUnordered && cmpResult != APFloat::cmpEqual;
1356   case arith::CmpFPredicate::ORD:
1357     return cmpResult != APFloat::cmpUnordered;
1358   case arith::CmpFPredicate::UEQ:
1359     return cmpResult == APFloat::cmpUnordered || cmpResult == APFloat::cmpEqual;
1360   case arith::CmpFPredicate::UGT:
1361     return cmpResult == APFloat::cmpUnordered ||
1362            cmpResult == APFloat::cmpGreaterThan;
1363   case arith::CmpFPredicate::UGE:
1364     return cmpResult == APFloat::cmpUnordered ||
1365            cmpResult == APFloat::cmpGreaterThan ||
1366            cmpResult == APFloat::cmpEqual;
1367   case arith::CmpFPredicate::ULT:
1368     return cmpResult == APFloat::cmpUnordered ||
1369            cmpResult == APFloat::cmpLessThan;
1370   case arith::CmpFPredicate::ULE:
1371     return cmpResult == APFloat::cmpUnordered ||
1372            cmpResult == APFloat::cmpLessThan || cmpResult == APFloat::cmpEqual;
1373   case arith::CmpFPredicate::UNE:
1374     return cmpResult != APFloat::cmpEqual;
1375   case arith::CmpFPredicate::UNO:
1376     return cmpResult == APFloat::cmpUnordered;
1377   case arith::CmpFPredicate::AlwaysTrue:
1378     return true;
1379   }
1380   llvm_unreachable("unknown cmpf predicate kind");
1381 }
1382 
1383 OpFoldResult arith::CmpFOp::fold(ArrayRef<Attribute> operands) {
1384   assert(operands.size() == 2 && "cmpf takes two operands");
1385 
1386   auto lhs = operands.front().dyn_cast_or_null<FloatAttr>();
1387   auto rhs = operands.back().dyn_cast_or_null<FloatAttr>();
1388 
1389   // If one operand is NaN, making them both NaN does not change the result.
1390   if (lhs && lhs.getValue().isNaN())
1391     rhs = lhs;
1392   if (rhs && rhs.getValue().isNaN())
1393     lhs = rhs;
1394 
1395   if (!lhs || !rhs)
1396     return {};
1397 
1398   auto val = applyCmpPredicate(getPredicate(), lhs.getValue(), rhs.getValue());
1399   return BoolAttr::get(getContext(), val);
1400 }
1401 
1402 class CmpFIntToFPConst final : public OpRewritePattern<CmpFOp> {
1403 public:
1404   using OpRewritePattern<CmpFOp>::OpRewritePattern;
1405 
1406   static CmpIPredicate convertToIntegerPredicate(CmpFPredicate pred,
1407                                                  bool isUnsigned) {
1408     using namespace arith;
1409     switch (pred) {
1410     case CmpFPredicate::UEQ:
1411     case CmpFPredicate::OEQ:
1412       return CmpIPredicate::eq;
1413     case CmpFPredicate::UGT:
1414     case CmpFPredicate::OGT:
1415       return isUnsigned ? CmpIPredicate::ugt : CmpIPredicate::sgt;
1416     case CmpFPredicate::UGE:
1417     case CmpFPredicate::OGE:
1418       return isUnsigned ? CmpIPredicate::uge : CmpIPredicate::sge;
1419     case CmpFPredicate::ULT:
1420     case CmpFPredicate::OLT:
1421       return isUnsigned ? CmpIPredicate::ult : CmpIPredicate::slt;
1422     case CmpFPredicate::ULE:
1423     case CmpFPredicate::OLE:
1424       return isUnsigned ? CmpIPredicate::ule : CmpIPredicate::sle;
1425     case CmpFPredicate::UNE:
1426     case CmpFPredicate::ONE:
1427       return CmpIPredicate::ne;
1428     default:
1429       llvm_unreachable("Unexpected predicate!");
1430     }
1431   }
1432 
1433   LogicalResult matchAndRewrite(CmpFOp op,
1434                                 PatternRewriter &rewriter) const override {
1435     FloatAttr flt;
1436     if (!matchPattern(op.getRhs(), m_Constant(&flt)))
1437       return failure();
1438 
1439     const APFloat &rhs = flt.getValue();
1440 
1441     // Don't attempt to fold a nan.
1442     if (rhs.isNaN())
1443       return failure();
1444 
1445     // Get the width of the mantissa.  We don't want to hack on conversions that
1446     // might lose information from the integer, e.g. "i64 -> float"
1447     FloatType floatTy = op.getRhs().getType().cast<FloatType>();
1448     int mantissaWidth = floatTy.getFPMantissaWidth();
1449     if (mantissaWidth <= 0)
1450       return failure();
1451 
1452     bool isUnsigned;
1453     Value intVal;
1454 
1455     if (auto si = op.getLhs().getDefiningOp<SIToFPOp>()) {
1456       isUnsigned = false;
1457       intVal = si.getIn();
1458     } else if (auto ui = op.getLhs().getDefiningOp<UIToFPOp>()) {
1459       isUnsigned = true;
1460       intVal = ui.getIn();
1461     } else {
1462       return failure();
1463     }
1464 
1465     // Check to see that the input is converted from an integer type that is
1466     // small enough that preserves all bits.
1467     auto intTy = intVal.getType().cast<IntegerType>();
1468     auto intWidth = intTy.getWidth();
1469 
1470     // Number of bits representing values, as opposed to the sign
1471     auto valueBits = isUnsigned ? intWidth : (intWidth - 1);
1472 
1473     // Following test does NOT adjust intWidth downwards for signed inputs,
1474     // because the most negative value still requires all the mantissa bits
1475     // to distinguish it from one less than that value.
1476     if ((int)intWidth > mantissaWidth) {
1477       // Conversion would lose accuracy. Check if loss can impact comparison.
1478       int exponent = ilogb(rhs);
1479       if (exponent == APFloat::IEK_Inf) {
1480         int maxExponent = ilogb(APFloat::getLargest(rhs.getSemantics()));
1481         if (maxExponent < (int)valueBits) {
1482           // Conversion could create infinity.
1483           return failure();
1484         }
1485       } else {
1486         // Note that if rhs is zero or NaN, then Exp is negative
1487         // and first condition is trivially false.
1488         if (mantissaWidth <= exponent && exponent <= (int)valueBits) {
1489           // Conversion could affect comparison.
1490           return failure();
1491         }
1492       }
1493     }
1494 
1495     // Convert to equivalent cmpi predicate
1496     CmpIPredicate pred;
1497     switch (op.getPredicate()) {
1498     case CmpFPredicate::ORD:
1499       // Int to fp conversion doesn't create a nan (ord checks neither is a nan)
1500       rewriter.replaceOpWithNewOp<ConstantIntOp>(op, /*value=*/true,
1501                                                  /*width=*/1);
1502       return success();
1503     case CmpFPredicate::UNO:
1504       // Int to fp conversion doesn't create a nan (uno checks either is a nan)
1505       rewriter.replaceOpWithNewOp<ConstantIntOp>(op, /*value=*/false,
1506                                                  /*width=*/1);
1507       return success();
1508     default:
1509       pred = convertToIntegerPredicate(op.getPredicate(), isUnsigned);
1510       break;
1511     }
1512 
1513     if (!isUnsigned) {
1514       // If the rhs value is > SignedMax, fold the comparison.  This handles
1515       // +INF and large values.
1516       APFloat signedMax(rhs.getSemantics());
1517       signedMax.convertFromAPInt(APInt::getSignedMaxValue(intWidth), true,
1518                                  APFloat::rmNearestTiesToEven);
1519       if (signedMax < rhs) { // smax < 13123.0
1520         if (pred == CmpIPredicate::ne || pred == CmpIPredicate::slt ||
1521             pred == CmpIPredicate::sle)
1522           rewriter.replaceOpWithNewOp<ConstantIntOp>(op, /*value=*/true,
1523                                                      /*width=*/1);
1524         else
1525           rewriter.replaceOpWithNewOp<ConstantIntOp>(op, /*value=*/false,
1526                                                      /*width=*/1);
1527         return success();
1528       }
1529     } else {
1530       // If the rhs value is > UnsignedMax, fold the comparison. This handles
1531       // +INF and large values.
1532       APFloat unsignedMax(rhs.getSemantics());
1533       unsignedMax.convertFromAPInt(APInt::getMaxValue(intWidth), false,
1534                                    APFloat::rmNearestTiesToEven);
1535       if (unsignedMax < rhs) { // umax < 13123.0
1536         if (pred == CmpIPredicate::ne || pred == CmpIPredicate::ult ||
1537             pred == CmpIPredicate::ule)
1538           rewriter.replaceOpWithNewOp<ConstantIntOp>(op, /*value=*/true,
1539                                                      /*width=*/1);
1540         else
1541           rewriter.replaceOpWithNewOp<ConstantIntOp>(op, /*value=*/false,
1542                                                      /*width=*/1);
1543         return success();
1544       }
1545     }
1546 
1547     if (!isUnsigned) {
1548       // See if the rhs value is < SignedMin.
1549       APFloat signedMin(rhs.getSemantics());
1550       signedMin.convertFromAPInt(APInt::getSignedMinValue(intWidth), true,
1551                                  APFloat::rmNearestTiesToEven);
1552       if (signedMin > rhs) { // smin > 12312.0
1553         if (pred == CmpIPredicate::ne || pred == CmpIPredicate::sgt ||
1554             pred == CmpIPredicate::sge)
1555           rewriter.replaceOpWithNewOp<ConstantIntOp>(op, /*value=*/true,
1556                                                      /*width=*/1);
1557         else
1558           rewriter.replaceOpWithNewOp<ConstantIntOp>(op, /*value=*/false,
1559                                                      /*width=*/1);
1560         return success();
1561       }
1562     } else {
1563       // See if the rhs value is < UnsignedMin.
1564       APFloat unsignedMin(rhs.getSemantics());
1565       unsignedMin.convertFromAPInt(APInt::getMinValue(intWidth), false,
1566                                    APFloat::rmNearestTiesToEven);
1567       if (unsignedMin > rhs) { // umin > 12312.0
1568         if (pred == CmpIPredicate::ne || pred == CmpIPredicate::ugt ||
1569             pred == CmpIPredicate::uge)
1570           rewriter.replaceOpWithNewOp<ConstantIntOp>(op, /*value=*/true,
1571                                                      /*width=*/1);
1572         else
1573           rewriter.replaceOpWithNewOp<ConstantIntOp>(op, /*value=*/false,
1574                                                      /*width=*/1);
1575         return success();
1576       }
1577     }
1578 
1579     // Okay, now we know that the FP constant fits in the range [SMIN, SMAX] or
1580     // [0, UMAX], but it may still be fractional.  See if it is fractional by
1581     // casting the FP value to the integer value and back, checking for
1582     // equality. Don't do this for zero, because -0.0 is not fractional.
1583     bool ignored;
1584     APSInt rhsInt(intWidth, isUnsigned);
1585     if (APFloat::opInvalidOp ==
1586         rhs.convertToInteger(rhsInt, APFloat::rmTowardZero, &ignored)) {
1587       // Undefined behavior invoked - the destination type can't represent
1588       // the input constant.
1589       return failure();
1590     }
1591 
1592     if (!rhs.isZero()) {
1593       APFloat apf(floatTy.getFloatSemantics(),
1594                   APInt::getZero(floatTy.getWidth()));
1595       apf.convertFromAPInt(rhsInt, !isUnsigned, APFloat::rmNearestTiesToEven);
1596 
1597       bool equal = apf == rhs;
1598       if (!equal) {
1599         // If we had a comparison against a fractional value, we have to adjust
1600         // the compare predicate and sometimes the value.  rhsInt is rounded
1601         // towards zero at this point.
1602         switch (pred) {
1603         case CmpIPredicate::ne: // (float)int != 4.4   --> true
1604           rewriter.replaceOpWithNewOp<ConstantIntOp>(op, /*value=*/true,
1605                                                      /*width=*/1);
1606           return success();
1607         case CmpIPredicate::eq: // (float)int == 4.4   --> false
1608           rewriter.replaceOpWithNewOp<ConstantIntOp>(op, /*value=*/false,
1609                                                      /*width=*/1);
1610           return success();
1611         case CmpIPredicate::ule:
1612           // (float)int <= 4.4   --> int <= 4
1613           // (float)int <= -4.4  --> false
1614           if (rhs.isNegative()) {
1615             rewriter.replaceOpWithNewOp<ConstantIntOp>(op, /*value=*/false,
1616                                                        /*width=*/1);
1617             return success();
1618           }
1619           break;
1620         case CmpIPredicate::sle:
1621           // (float)int <= 4.4   --> int <= 4
1622           // (float)int <= -4.4  --> int < -4
1623           if (rhs.isNegative())
1624             pred = CmpIPredicate::slt;
1625           break;
1626         case CmpIPredicate::ult:
1627           // (float)int < -4.4   --> false
1628           // (float)int < 4.4    --> int <= 4
1629           if (rhs.isNegative()) {
1630             rewriter.replaceOpWithNewOp<ConstantIntOp>(op, /*value=*/false,
1631                                                        /*width=*/1);
1632             return success();
1633           }
1634           pred = CmpIPredicate::ule;
1635           break;
1636         case CmpIPredicate::slt:
1637           // (float)int < -4.4   --> int < -4
1638           // (float)int < 4.4    --> int <= 4
1639           if (!rhs.isNegative())
1640             pred = CmpIPredicate::sle;
1641           break;
1642         case CmpIPredicate::ugt:
1643           // (float)int > 4.4    --> int > 4
1644           // (float)int > -4.4   --> true
1645           if (rhs.isNegative()) {
1646             rewriter.replaceOpWithNewOp<ConstantIntOp>(op, /*value=*/true,
1647                                                        /*width=*/1);
1648             return success();
1649           }
1650           break;
1651         case CmpIPredicate::sgt:
1652           // (float)int > 4.4    --> int > 4
1653           // (float)int > -4.4   --> int >= -4
1654           if (rhs.isNegative())
1655             pred = CmpIPredicate::sge;
1656           break;
1657         case CmpIPredicate::uge:
1658           // (float)int >= -4.4   --> true
1659           // (float)int >= 4.4    --> int > 4
1660           if (rhs.isNegative()) {
1661             rewriter.replaceOpWithNewOp<ConstantIntOp>(op, /*value=*/true,
1662                                                        /*width=*/1);
1663             return success();
1664           }
1665           pred = CmpIPredicate::ugt;
1666           break;
1667         case CmpIPredicate::sge:
1668           // (float)int >= -4.4   --> int >= -4
1669           // (float)int >= 4.4    --> int > 4
1670           if (!rhs.isNegative())
1671             pred = CmpIPredicate::sgt;
1672           break;
1673         }
1674       }
1675     }
1676 
1677     // Lower this FP comparison into an appropriate integer version of the
1678     // comparison.
1679     rewriter.replaceOpWithNewOp<CmpIOp>(
1680         op, pred, intVal,
1681         rewriter.create<ConstantOp>(
1682             op.getLoc(), intVal.getType(),
1683             rewriter.getIntegerAttr(intVal.getType(), rhsInt)));
1684     return success();
1685   }
1686 };
1687 
1688 void arith::CmpFOp::getCanonicalizationPatterns(RewritePatternSet &patterns,
1689                                                 MLIRContext *context) {
1690   patterns.insert<CmpFIntToFPConst>(context);
1691 }
1692 
1693 //===----------------------------------------------------------------------===//
1694 // SelectOp
1695 //===----------------------------------------------------------------------===//
1696 
1697 // Transforms a select of a boolean to arithmetic operations
1698 //
1699 //  arith.select %arg, %x, %y : i1
1700 //
1701 //  becomes
1702 //
1703 //  and(%arg, %x) or and(!%arg, %y)
1704 struct SelectI1Simplify : public OpRewritePattern<arith::SelectOp> {
1705   using OpRewritePattern<arith::SelectOp>::OpRewritePattern;
1706 
1707   LogicalResult matchAndRewrite(arith::SelectOp op,
1708                                 PatternRewriter &rewriter) const override {
1709     if (!op.getType().isInteger(1))
1710       return failure();
1711 
1712     Value falseConstant =
1713         rewriter.create<arith::ConstantIntOp>(op.getLoc(), true, 1);
1714     Value notCondition = rewriter.create<arith::XOrIOp>(
1715         op.getLoc(), op.getCondition(), falseConstant);
1716 
1717     Value trueVal = rewriter.create<arith::AndIOp>(
1718         op.getLoc(), op.getCondition(), op.getTrueValue());
1719     Value falseVal = rewriter.create<arith::AndIOp>(op.getLoc(), notCondition,
1720                                                     op.getFalseValue());
1721     rewriter.replaceOpWithNewOp<arith::OrIOp>(op, trueVal, falseVal);
1722     return success();
1723   }
1724 };
1725 
1726 //  select %arg, %c1, %c0 => extui %arg
1727 struct SelectToExtUI : public OpRewritePattern<arith::SelectOp> {
1728   using OpRewritePattern<arith::SelectOp>::OpRewritePattern;
1729 
1730   LogicalResult matchAndRewrite(arith::SelectOp op,
1731                                 PatternRewriter &rewriter) const override {
1732     // Cannot extui i1 to i1, or i1 to f32
1733     if (!op.getType().isa<IntegerType>() || op.getType().isInteger(1))
1734       return failure();
1735 
1736     // select %x, c1, %c0 => extui %arg
1737     if (matchPattern(op.getTrueValue(), m_One()))
1738       if (matchPattern(op.getFalseValue(), m_Zero())) {
1739         rewriter.replaceOpWithNewOp<arith::ExtUIOp>(op, op.getType(),
1740                                                     op.getCondition());
1741         return success();
1742       }
1743 
1744     // select %x, c0, %c1 => extui (xor %arg, true)
1745     if (matchPattern(op.getTrueValue(), m_Zero()))
1746       if (matchPattern(op.getFalseValue(), m_One())) {
1747         rewriter.replaceOpWithNewOp<arith::ExtUIOp>(
1748             op, op.getType(),
1749             rewriter.create<arith::XOrIOp>(
1750                 op.getLoc(), op.getCondition(),
1751                 rewriter.create<arith::ConstantIntOp>(
1752                     op.getLoc(), 1, op.getCondition().getType())));
1753         return success();
1754       }
1755 
1756     return failure();
1757   }
1758 };
1759 
1760 void arith::SelectOp::getCanonicalizationPatterns(RewritePatternSet &results,
1761                                                   MLIRContext *context) {
1762   results.add<SelectI1Simplify, SelectToExtUI>(context);
1763 }
1764 
1765 OpFoldResult arith::SelectOp::fold(ArrayRef<Attribute> operands) {
1766   Value trueVal = getTrueValue();
1767   Value falseVal = getFalseValue();
1768   if (trueVal == falseVal)
1769     return trueVal;
1770 
1771   Value condition = getCondition();
1772 
1773   // select true, %0, %1 => %0
1774   if (matchPattern(condition, m_One()))
1775     return trueVal;
1776 
1777   // select false, %0, %1 => %1
1778   if (matchPattern(condition, m_Zero()))
1779     return falseVal;
1780 
1781   // select %x, true, false => %x
1782   if (getType().isInteger(1))
1783     if (matchPattern(getTrueValue(), m_One()))
1784       if (matchPattern(getFalseValue(), m_Zero()))
1785         return condition;
1786 
1787   if (auto cmp = dyn_cast_or_null<arith::CmpIOp>(condition.getDefiningOp())) {
1788     auto pred = cmp.getPredicate();
1789     if (pred == arith::CmpIPredicate::eq || pred == arith::CmpIPredicate::ne) {
1790       auto cmpLhs = cmp.getLhs();
1791       auto cmpRhs = cmp.getRhs();
1792 
1793       // %0 = arith.cmpi eq, %arg0, %arg1
1794       // %1 = arith.select %0, %arg0, %arg1 => %arg1
1795 
1796       // %0 = arith.cmpi ne, %arg0, %arg1
1797       // %1 = arith.select %0, %arg0, %arg1 => %arg0
1798 
1799       if ((cmpLhs == trueVal && cmpRhs == falseVal) ||
1800           (cmpRhs == trueVal && cmpLhs == falseVal))
1801         return pred == arith::CmpIPredicate::ne ? trueVal : falseVal;
1802     }
1803   }
1804   return nullptr;
1805 }
1806 
1807 ParseResult SelectOp::parse(OpAsmParser &parser, OperationState &result) {
1808   Type conditionType, resultType;
1809   SmallVector<OpAsmParser::OperandType, 3> operands;
1810   if (parser.parseOperandList(operands, /*requiredOperandCount=*/3) ||
1811       parser.parseOptionalAttrDict(result.attributes) ||
1812       parser.parseColonType(resultType))
1813     return failure();
1814 
1815   // Check for the explicit condition type if this is a masked tensor or vector.
1816   if (succeeded(parser.parseOptionalComma())) {
1817     conditionType = resultType;
1818     if (parser.parseType(resultType))
1819       return failure();
1820   } else {
1821     conditionType = parser.getBuilder().getI1Type();
1822   }
1823 
1824   result.addTypes(resultType);
1825   return parser.resolveOperands(operands,
1826                                 {conditionType, resultType, resultType},
1827                                 parser.getNameLoc(), result.operands);
1828 }
1829 
1830 void arith::SelectOp::print(OpAsmPrinter &p) {
1831   p << " " << getOperands();
1832   p.printOptionalAttrDict((*this)->getAttrs());
1833   p << " : ";
1834   if (ShapedType condType = getCondition().getType().dyn_cast<ShapedType>())
1835     p << condType << ", ";
1836   p << getType();
1837 }
1838 
1839 LogicalResult arith::SelectOp::verify() {
1840   Type conditionType = getCondition().getType();
1841   if (conditionType.isSignlessInteger(1))
1842     return success();
1843 
1844   // If the result type is a vector or tensor, the type can be a mask with the
1845   // same elements.
1846   Type resultType = getType();
1847   if (!resultType.isa<TensorType, VectorType>())
1848     return emitOpError() << "expected condition to be a signless i1, but got "
1849                          << conditionType;
1850   Type shapedConditionType = getI1SameShape(resultType);
1851   if (conditionType != shapedConditionType) {
1852     return emitOpError() << "expected condition type to have the same shape "
1853                             "as the result type, expected "
1854                          << shapedConditionType << ", but got "
1855                          << conditionType;
1856   }
1857   return success();
1858 }
1859 //===----------------------------------------------------------------------===//
1860 // ShLIOp
1861 //===----------------------------------------------------------------------===//
1862 
1863 OpFoldResult arith::ShLIOp::fold(ArrayRef<Attribute> operands) {
1864   // Don't fold if shifting more than the bit width.
1865   bool bounded = false;
1866   auto result = constFoldBinaryOp<IntegerAttr>(
1867       operands, [&](const APInt &a, const APInt &b) {
1868         bounded = b.ule(b.getBitWidth());
1869         return std::move(a).shl(b);
1870       });
1871   return bounded ? result : Attribute();
1872 }
1873 
1874 //===----------------------------------------------------------------------===//
1875 // Atomic Enum
1876 //===----------------------------------------------------------------------===//
1877 
1878 /// Returns the identity value attribute associated with an AtomicRMWKind op.
1879 Attribute mlir::arith::getIdentityValueAttr(AtomicRMWKind kind, Type resultType,
1880                                             OpBuilder &builder, Location loc) {
1881   switch (kind) {
1882   case AtomicRMWKind::maxf:
1883     return builder.getFloatAttr(
1884         resultType,
1885         APFloat::getInf(resultType.cast<FloatType>().getFloatSemantics(),
1886                         /*Negative=*/true));
1887   case AtomicRMWKind::addf:
1888   case AtomicRMWKind::addi:
1889   case AtomicRMWKind::maxu:
1890   case AtomicRMWKind::ori:
1891     return builder.getZeroAttr(resultType);
1892   case AtomicRMWKind::andi:
1893     return builder.getIntegerAttr(
1894         resultType,
1895         APInt::getAllOnes(resultType.cast<IntegerType>().getWidth()));
1896   case AtomicRMWKind::maxs:
1897     return builder.getIntegerAttr(
1898         resultType,
1899         APInt::getSignedMinValue(resultType.cast<IntegerType>().getWidth()));
1900   case AtomicRMWKind::minf:
1901     return builder.getFloatAttr(
1902         resultType,
1903         APFloat::getInf(resultType.cast<FloatType>().getFloatSemantics(),
1904                         /*Negative=*/false));
1905   case AtomicRMWKind::mins:
1906     return builder.getIntegerAttr(
1907         resultType,
1908         APInt::getSignedMaxValue(resultType.cast<IntegerType>().getWidth()));
1909   case AtomicRMWKind::minu:
1910     return builder.getIntegerAttr(
1911         resultType,
1912         APInt::getMaxValue(resultType.cast<IntegerType>().getWidth()));
1913   case AtomicRMWKind::muli:
1914     return builder.getIntegerAttr(resultType, 1);
1915   case AtomicRMWKind::mulf:
1916     return builder.getFloatAttr(resultType, 1);
1917   // TODO: Add remaining reduction operations.
1918   default:
1919     (void)emitOptionalError(loc, "Reduction operation type not supported");
1920     break;
1921   }
1922   return nullptr;
1923 }
1924 
1925 /// Returns the identity value associated with an AtomicRMWKind op.
1926 Value mlir::arith::getIdentityValue(AtomicRMWKind op, Type resultType,
1927                                     OpBuilder &builder, Location loc) {
1928   Attribute attr = getIdentityValueAttr(op, resultType, builder, loc);
1929   return builder.create<arith::ConstantOp>(loc, attr);
1930 }
1931 
1932 /// Return the value obtained by applying the reduction operation kind
1933 /// associated with a binary AtomicRMWKind op to `lhs` and `rhs`.
1934 Value mlir::arith::getReductionOp(AtomicRMWKind op, OpBuilder &builder,
1935                                   Location loc, Value lhs, Value rhs) {
1936   switch (op) {
1937   case AtomicRMWKind::addf:
1938     return builder.create<arith::AddFOp>(loc, lhs, rhs);
1939   case AtomicRMWKind::addi:
1940     return builder.create<arith::AddIOp>(loc, lhs, rhs);
1941   case AtomicRMWKind::mulf:
1942     return builder.create<arith::MulFOp>(loc, lhs, rhs);
1943   case AtomicRMWKind::muli:
1944     return builder.create<arith::MulIOp>(loc, lhs, rhs);
1945   case AtomicRMWKind::maxf:
1946     return builder.create<arith::MaxFOp>(loc, lhs, rhs);
1947   case AtomicRMWKind::minf:
1948     return builder.create<arith::MinFOp>(loc, lhs, rhs);
1949   case AtomicRMWKind::maxs:
1950     return builder.create<arith::MaxSIOp>(loc, lhs, rhs);
1951   case AtomicRMWKind::mins:
1952     return builder.create<arith::MinSIOp>(loc, lhs, rhs);
1953   case AtomicRMWKind::maxu:
1954     return builder.create<arith::MaxUIOp>(loc, lhs, rhs);
1955   case AtomicRMWKind::minu:
1956     return builder.create<arith::MinUIOp>(loc, lhs, rhs);
1957   case AtomicRMWKind::ori:
1958     return builder.create<arith::OrIOp>(loc, lhs, rhs);
1959   case AtomicRMWKind::andi:
1960     return builder.create<arith::AndIOp>(loc, lhs, rhs);
1961   // TODO: Add remaining reduction operations.
1962   default:
1963     (void)emitOptionalError(loc, "Reduction operation type not supported");
1964     break;
1965   }
1966   return nullptr;
1967 }
1968 
1969 //===----------------------------------------------------------------------===//
1970 // TableGen'd op method definitions
1971 //===----------------------------------------------------------------------===//
1972 
1973 #define GET_OP_CLASSES
1974 #include "mlir/Dialect/Arithmetic/IR/ArithmeticOps.cpp.inc"
1975 
1976 //===----------------------------------------------------------------------===//
1977 // TableGen'd enum attribute definitions
1978 //===----------------------------------------------------------------------===//
1979 
1980 #include "mlir/Dialect/Arithmetic/IR/ArithmeticOpsEnums.cpp.inc"
1981