1 //===- ArithmeticOps.cpp - MLIR Arithmetic dialect ops implementation -----===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 
9 #include <utility>
10 
11 #include "mlir/Dialect/Arithmetic/IR/Arithmetic.h"
12 #include "mlir/Dialect/CommonFolders.h"
13 #include "mlir/IR/Builders.h"
14 #include "mlir/IR/Matchers.h"
15 #include "mlir/IR/OpImplementation.h"
16 #include "mlir/IR/PatternMatch.h"
17 #include "mlir/IR/TypeUtilities.h"
18 #include "llvm/ADT/SmallString.h"
19 
20 #include "llvm/ADT/APSInt.h"
21 
22 using namespace mlir;
23 using namespace mlir::arith;
24 
25 //===----------------------------------------------------------------------===//
26 // Pattern helpers
27 //===----------------------------------------------------------------------===//
28 
29 static IntegerAttr addIntegerAttrs(PatternRewriter &builder, Value res,
30                                    Attribute lhs, Attribute rhs) {
31   return builder.getIntegerAttr(res.getType(),
32                                 lhs.cast<IntegerAttr>().getInt() +
33                                     rhs.cast<IntegerAttr>().getInt());
34 }
35 
36 static IntegerAttr subIntegerAttrs(PatternRewriter &builder, Value res,
37                                    Attribute lhs, Attribute rhs) {
38   return builder.getIntegerAttr(res.getType(),
39                                 lhs.cast<IntegerAttr>().getInt() -
40                                     rhs.cast<IntegerAttr>().getInt());
41 }
42 
43 /// Invert an integer comparison predicate.
44 arith::CmpIPredicate arith::invertPredicate(arith::CmpIPredicate pred) {
45   switch (pred) {
46   case arith::CmpIPredicate::eq:
47     return arith::CmpIPredicate::ne;
48   case arith::CmpIPredicate::ne:
49     return arith::CmpIPredicate::eq;
50   case arith::CmpIPredicate::slt:
51     return arith::CmpIPredicate::sge;
52   case arith::CmpIPredicate::sle:
53     return arith::CmpIPredicate::sgt;
54   case arith::CmpIPredicate::sgt:
55     return arith::CmpIPredicate::sle;
56   case arith::CmpIPredicate::sge:
57     return arith::CmpIPredicate::slt;
58   case arith::CmpIPredicate::ult:
59     return arith::CmpIPredicate::uge;
60   case arith::CmpIPredicate::ule:
61     return arith::CmpIPredicate::ugt;
62   case arith::CmpIPredicate::ugt:
63     return arith::CmpIPredicate::ule;
64   case arith::CmpIPredicate::uge:
65     return arith::CmpIPredicate::ult;
66   }
67   llvm_unreachable("unknown cmpi predicate kind");
68 }
69 
70 static arith::CmpIPredicateAttr invertPredicate(arith::CmpIPredicateAttr pred) {
71   return arith::CmpIPredicateAttr::get(pred.getContext(),
72                                        invertPredicate(pred.getValue()));
73 }
74 
75 //===----------------------------------------------------------------------===//
76 // TableGen'd canonicalization patterns
77 //===----------------------------------------------------------------------===//
78 
79 namespace {
80 #include "ArithmeticCanonicalization.inc"
81 } // namespace
82 
83 //===----------------------------------------------------------------------===//
84 // ConstantOp
85 //===----------------------------------------------------------------------===//
86 
87 void arith::ConstantOp::getAsmResultNames(
88     function_ref<void(Value, StringRef)> setNameFn) {
89   auto type = getType();
90   if (auto intCst = getValue().dyn_cast<IntegerAttr>()) {
91     auto intType = type.dyn_cast<IntegerType>();
92 
93     // Sugar i1 constants with 'true' and 'false'.
94     if (intType && intType.getWidth() == 1)
95       return setNameFn(getResult(), (intCst.getInt() ? "true" : "false"));
96 
97     // Otherwise, build a compex name with the value and type.
98     SmallString<32> specialNameBuffer;
99     llvm::raw_svector_ostream specialName(specialNameBuffer);
100     specialName << 'c' << intCst.getInt();
101     if (intType)
102       specialName << '_' << type;
103     setNameFn(getResult(), specialName.str());
104   } else {
105     setNameFn(getResult(), "cst");
106   }
107 }
108 
109 /// TODO: disallow arith.constant to return anything other than signless integer
110 /// or float like.
111 LogicalResult arith::ConstantOp::verify() {
112   auto type = getType();
113   // The value's type must match the return type.
114   if (getValue().getType() != type) {
115     return emitOpError() << "value type " << getValue().getType()
116                          << " must match return type: " << type;
117   }
118   // Integer values must be signless.
119   if (type.isa<IntegerType>() && !type.cast<IntegerType>().isSignless())
120     return emitOpError("integer return type must be signless");
121   // Any float or elements attribute are acceptable.
122   if (!getValue().isa<IntegerAttr, FloatAttr, ElementsAttr>()) {
123     return emitOpError(
124         "value must be an integer, float, or elements attribute");
125   }
126   return success();
127 }
128 
129 bool arith::ConstantOp::isBuildableWith(Attribute value, Type type) {
130   // The value's type must be the same as the provided type.
131   if (value.getType() != type)
132     return false;
133   // Integer values must be signless.
134   if (type.isa<IntegerType>() && !type.cast<IntegerType>().isSignless())
135     return false;
136   // Integer, float, and element attributes are buildable.
137   return value.isa<IntegerAttr, FloatAttr, ElementsAttr>();
138 }
139 
140 OpFoldResult arith::ConstantOp::fold(ArrayRef<Attribute> operands) {
141   return getValue();
142 }
143 
144 void arith::ConstantIntOp::build(OpBuilder &builder, OperationState &result,
145                                  int64_t value, unsigned width) {
146   auto type = builder.getIntegerType(width);
147   arith::ConstantOp::build(builder, result, type,
148                            builder.getIntegerAttr(type, value));
149 }
150 
151 void arith::ConstantIntOp::build(OpBuilder &builder, OperationState &result,
152                                  int64_t value, Type type) {
153   assert(type.isSignlessInteger() &&
154          "ConstantIntOp can only have signless integer type values");
155   arith::ConstantOp::build(builder, result, type,
156                            builder.getIntegerAttr(type, value));
157 }
158 
159 bool arith::ConstantIntOp::classof(Operation *op) {
160   if (auto constOp = dyn_cast_or_null<arith::ConstantOp>(op))
161     return constOp.getType().isSignlessInteger();
162   return false;
163 }
164 
165 void arith::ConstantFloatOp::build(OpBuilder &builder, OperationState &result,
166                                    const APFloat &value, FloatType type) {
167   arith::ConstantOp::build(builder, result, type,
168                            builder.getFloatAttr(type, value));
169 }
170 
171 bool arith::ConstantFloatOp::classof(Operation *op) {
172   if (auto constOp = dyn_cast_or_null<arith::ConstantOp>(op))
173     return constOp.getType().isa<FloatType>();
174   return false;
175 }
176 
177 void arith::ConstantIndexOp::build(OpBuilder &builder, OperationState &result,
178                                    int64_t value) {
179   arith::ConstantOp::build(builder, result, builder.getIndexType(),
180                            builder.getIndexAttr(value));
181 }
182 
183 bool arith::ConstantIndexOp::classof(Operation *op) {
184   if (auto constOp = dyn_cast_or_null<arith::ConstantOp>(op))
185     return constOp.getType().isIndex();
186   return false;
187 }
188 
189 //===----------------------------------------------------------------------===//
190 // AddIOp
191 //===----------------------------------------------------------------------===//
192 
193 OpFoldResult arith::AddIOp::fold(ArrayRef<Attribute> operands) {
194   // addi(x, 0) -> x
195   if (matchPattern(getRhs(), m_Zero()))
196     return getLhs();
197 
198   // addi(subi(a, b), b) -> a
199   if (auto sub = getLhs().getDefiningOp<SubIOp>())
200     if (getRhs() == sub.getRhs())
201       return sub.getLhs();
202 
203   // addi(b, subi(a, b)) -> a
204   if (auto sub = getRhs().getDefiningOp<SubIOp>())
205     if (getLhs() == sub.getRhs())
206       return sub.getLhs();
207 
208   return constFoldBinaryOp<IntegerAttr>(
209       operands, [](APInt a, const APInt &b) { return std::move(a) + b; });
210 }
211 
212 void arith::AddIOp::getCanonicalizationPatterns(
213     RewritePatternSet &patterns, MLIRContext *context) {
214   patterns.add<AddIAddConstant, AddISubConstantRHS, AddISubConstantLHS>(
215       context);
216 }
217 
218 //===----------------------------------------------------------------------===//
219 // SubIOp
220 //===----------------------------------------------------------------------===//
221 
222 OpFoldResult arith::SubIOp::fold(ArrayRef<Attribute> operands) {
223   // subi(x,x) -> 0
224   if (getOperand(0) == getOperand(1))
225     return Builder(getContext()).getZeroAttr(getType());
226   // subi(x,0) -> x
227   if (matchPattern(getRhs(), m_Zero()))
228     return getLhs();
229 
230   return constFoldBinaryOp<IntegerAttr>(
231       operands, [](APInt a, const APInt &b) { return std::move(a) - b; });
232 }
233 
234 void arith::SubIOp::getCanonicalizationPatterns(
235     RewritePatternSet &patterns, MLIRContext *context) {
236   patterns
237       .add<SubIRHSAddConstant, SubILHSAddConstant, SubIRHSSubConstantRHS,
238            SubIRHSSubConstantLHS, SubILHSSubConstantRHS, SubILHSSubConstantLHS>(
239           context);
240 }
241 
242 //===----------------------------------------------------------------------===//
243 // MulIOp
244 //===----------------------------------------------------------------------===//
245 
246 OpFoldResult arith::MulIOp::fold(ArrayRef<Attribute> operands) {
247   // muli(x, 0) -> 0
248   if (matchPattern(getRhs(), m_Zero()))
249     return getRhs();
250   // muli(x, 1) -> x
251   if (matchPattern(getRhs(), m_One()))
252     return getOperand(0);
253   // TODO: Handle the overflow case.
254 
255   // default folder
256   return constFoldBinaryOp<IntegerAttr>(
257       operands, [](const APInt &a, const APInt &b) { return a * b; });
258 }
259 
260 //===----------------------------------------------------------------------===//
261 // DivUIOp
262 //===----------------------------------------------------------------------===//
263 
264 OpFoldResult arith::DivUIOp::fold(ArrayRef<Attribute> operands) {
265   // Don't fold if it would require a division by zero.
266   bool div0 = false;
267   auto result =
268       constFoldBinaryOp<IntegerAttr>(operands, [&](APInt a, const APInt &b) {
269         if (div0 || !b) {
270           div0 = true;
271           return a;
272         }
273         return a.udiv(b);
274       });
275 
276   // Fold out division by one. Assumes all tensors of all ones are splats.
277   if (auto rhs = operands[1].dyn_cast_or_null<IntegerAttr>()) {
278     if (rhs.getValue() == 1)
279       return getLhs();
280   } else if (auto rhs = operands[1].dyn_cast_or_null<SplatElementsAttr>()) {
281     if (rhs.getSplatValue<IntegerAttr>().getValue() == 1)
282       return getLhs();
283   }
284 
285   return div0 ? Attribute() : result;
286 }
287 
288 //===----------------------------------------------------------------------===//
289 // DivSIOp
290 //===----------------------------------------------------------------------===//
291 
292 OpFoldResult arith::DivSIOp::fold(ArrayRef<Attribute> operands) {
293   // Don't fold if it would overflow or if it requires a division by zero.
294   bool overflowOrDiv0 = false;
295   auto result =
296       constFoldBinaryOp<IntegerAttr>(operands, [&](APInt a, const APInt &b) {
297         if (overflowOrDiv0 || !b) {
298           overflowOrDiv0 = true;
299           return a;
300         }
301         return a.sdiv_ov(b, overflowOrDiv0);
302       });
303 
304   // Fold out division by one. Assumes all tensors of all ones are splats.
305   if (auto rhs = operands[1].dyn_cast_or_null<IntegerAttr>()) {
306     if (rhs.getValue() == 1)
307       return getLhs();
308   } else if (auto rhs = operands[1].dyn_cast_or_null<SplatElementsAttr>()) {
309     if (rhs.getSplatValue<IntegerAttr>().getValue() == 1)
310       return getLhs();
311   }
312 
313   return overflowOrDiv0 ? Attribute() : result;
314 }
315 
316 //===----------------------------------------------------------------------===//
317 // Ceil and floor division folding helpers
318 //===----------------------------------------------------------------------===//
319 
320 static APInt signedCeilNonnegInputs(const APInt &a, const APInt &b,
321                                     bool &overflow) {
322   // Returns (a-1)/b + 1
323   APInt one(a.getBitWidth(), 1, true); // Signed value 1.
324   APInt val = a.ssub_ov(one, overflow).sdiv_ov(b, overflow);
325   return val.sadd_ov(one, overflow);
326 }
327 
328 //===----------------------------------------------------------------------===//
329 // CeilDivUIOp
330 //===----------------------------------------------------------------------===//
331 
332 OpFoldResult arith::CeilDivUIOp::fold(ArrayRef<Attribute> operands) {
333   bool overflowOrDiv0 = false;
334   auto result =
335       constFoldBinaryOp<IntegerAttr>(operands, [&](APInt a, const APInt &b) {
336         if (overflowOrDiv0 || !b) {
337           overflowOrDiv0 = true;
338           return a;
339         }
340         APInt quotient = a.udiv(b);
341         if (!a.urem(b))
342           return quotient;
343         APInt one(a.getBitWidth(), 1, true);
344         return quotient.uadd_ov(one, overflowOrDiv0);
345       });
346   // Fold out ceil division by one. Assumes all tensors of all ones are
347   // splats.
348   if (auto rhs = operands[1].dyn_cast_or_null<IntegerAttr>()) {
349     if (rhs.getValue() == 1)
350       return getLhs();
351   } else if (auto rhs = operands[1].dyn_cast_or_null<SplatElementsAttr>()) {
352     if (rhs.getSplatValue<IntegerAttr>().getValue() == 1)
353       return getLhs();
354   }
355 
356   return overflowOrDiv0 ? Attribute() : result;
357 }
358 
359 //===----------------------------------------------------------------------===//
360 // CeilDivSIOp
361 //===----------------------------------------------------------------------===//
362 
363 OpFoldResult arith::CeilDivSIOp::fold(ArrayRef<Attribute> operands) {
364   // Don't fold if it would overflow or if it requires a division by zero.
365   bool overflowOrDiv0 = false;
366   auto result =
367       constFoldBinaryOp<IntegerAttr>(operands, [&](APInt a, const APInt &b) {
368         if (overflowOrDiv0 || !b) {
369           overflowOrDiv0 = true;
370           return a;
371         }
372         if (!a)
373           return a;
374         // After this point we know that neither a or b are zero.
375         unsigned bits = a.getBitWidth();
376         APInt zero = APInt::getZero(bits);
377         bool aGtZero = a.sgt(zero);
378         bool bGtZero = b.sgt(zero);
379         if (aGtZero && bGtZero) {
380           // Both positive, return ceil(a, b).
381           return signedCeilNonnegInputs(a, b, overflowOrDiv0);
382         }
383         if (!aGtZero && !bGtZero) {
384           // Both negative, return ceil(-a, -b).
385           APInt posA = zero.ssub_ov(a, overflowOrDiv0);
386           APInt posB = zero.ssub_ov(b, overflowOrDiv0);
387           return signedCeilNonnegInputs(posA, posB, overflowOrDiv0);
388         }
389         if (!aGtZero && bGtZero) {
390           // A is negative, b is positive, return - ( -a / b).
391           APInt posA = zero.ssub_ov(a, overflowOrDiv0);
392           APInt div = posA.sdiv_ov(b, overflowOrDiv0);
393           return zero.ssub_ov(div, overflowOrDiv0);
394         }
395         // A is positive, b is negative, return - (a / -b).
396         APInt posB = zero.ssub_ov(b, overflowOrDiv0);
397         APInt div = a.sdiv_ov(posB, overflowOrDiv0);
398         return zero.ssub_ov(div, overflowOrDiv0);
399       });
400 
401   // Fold out ceil division by one. Assumes all tensors of all ones are
402   // splats.
403   if (auto rhs = operands[1].dyn_cast_or_null<IntegerAttr>()) {
404     if (rhs.getValue() == 1)
405       return getLhs();
406   } else if (auto rhs = operands[1].dyn_cast_or_null<SplatElementsAttr>()) {
407     if (rhs.getSplatValue<IntegerAttr>().getValue() == 1)
408       return getLhs();
409   }
410 
411   return overflowOrDiv0 ? Attribute() : result;
412 }
413 
414 //===----------------------------------------------------------------------===//
415 // FloorDivSIOp
416 //===----------------------------------------------------------------------===//
417 
418 OpFoldResult arith::FloorDivSIOp::fold(ArrayRef<Attribute> operands) {
419   // Don't fold if it would overflow or if it requires a division by zero.
420   bool overflowOrDiv0 = false;
421   auto result =
422       constFoldBinaryOp<IntegerAttr>(operands, [&](APInt a, const APInt &b) {
423         if (overflowOrDiv0 || !b) {
424           overflowOrDiv0 = true;
425           return a;
426         }
427         if (!a)
428           return a;
429         // After this point we know that neither a or b are zero.
430         unsigned bits = a.getBitWidth();
431         APInt zero = APInt::getZero(bits);
432         bool aGtZero = a.sgt(zero);
433         bool bGtZero = b.sgt(zero);
434         if (aGtZero && bGtZero) {
435           // Both positive, return a / b.
436           return a.sdiv_ov(b, overflowOrDiv0);
437         }
438         if (!aGtZero && !bGtZero) {
439           // Both negative, return -a / -b.
440           APInt posA = zero.ssub_ov(a, overflowOrDiv0);
441           APInt posB = zero.ssub_ov(b, overflowOrDiv0);
442           return posA.sdiv_ov(posB, overflowOrDiv0);
443         }
444         if (!aGtZero && bGtZero) {
445           // A is negative, b is positive, return - ceil(-a, b).
446           APInt posA = zero.ssub_ov(a, overflowOrDiv0);
447           APInt ceil = signedCeilNonnegInputs(posA, b, overflowOrDiv0);
448           return zero.ssub_ov(ceil, overflowOrDiv0);
449         }
450         // A is positive, b is negative, return - ceil(a, -b).
451         APInt posB = zero.ssub_ov(b, overflowOrDiv0);
452         APInt ceil = signedCeilNonnegInputs(a, posB, overflowOrDiv0);
453         return zero.ssub_ov(ceil, overflowOrDiv0);
454       });
455 
456   // Fold out floor division by one. Assumes all tensors of all ones are
457   // splats.
458   if (auto rhs = operands[1].dyn_cast_or_null<IntegerAttr>()) {
459     if (rhs.getValue() == 1)
460       return getLhs();
461   } else if (auto rhs = operands[1].dyn_cast_or_null<SplatElementsAttr>()) {
462     if (rhs.getSplatValue<IntegerAttr>().getValue() == 1)
463       return getLhs();
464   }
465 
466   return overflowOrDiv0 ? Attribute() : result;
467 }
468 
469 //===----------------------------------------------------------------------===//
470 // RemUIOp
471 //===----------------------------------------------------------------------===//
472 
473 OpFoldResult arith::RemUIOp::fold(ArrayRef<Attribute> operands) {
474   auto rhs = operands.back().dyn_cast_or_null<IntegerAttr>();
475   if (!rhs)
476     return {};
477   auto rhsValue = rhs.getValue();
478 
479   // x % 1 = 0
480   if (rhsValue.isOneValue())
481     return IntegerAttr::get(rhs.getType(), APInt(rhsValue.getBitWidth(), 0));
482 
483   // Don't fold if it requires division by zero.
484   if (rhsValue.isNullValue())
485     return {};
486 
487   auto lhs = operands.front().dyn_cast_or_null<IntegerAttr>();
488   if (!lhs)
489     return {};
490   return IntegerAttr::get(lhs.getType(), lhs.getValue().urem(rhsValue));
491 }
492 
493 //===----------------------------------------------------------------------===//
494 // RemSIOp
495 //===----------------------------------------------------------------------===//
496 
497 OpFoldResult arith::RemSIOp::fold(ArrayRef<Attribute> operands) {
498   auto rhs = operands.back().dyn_cast_or_null<IntegerAttr>();
499   if (!rhs)
500     return {};
501   auto rhsValue = rhs.getValue();
502 
503   // x % 1 = 0
504   if (rhsValue.isOneValue())
505     return IntegerAttr::get(rhs.getType(), APInt(rhsValue.getBitWidth(), 0));
506 
507   // Don't fold if it requires division by zero.
508   if (rhsValue.isNullValue())
509     return {};
510 
511   auto lhs = operands.front().dyn_cast_or_null<IntegerAttr>();
512   if (!lhs)
513     return {};
514   return IntegerAttr::get(lhs.getType(), lhs.getValue().srem(rhsValue));
515 }
516 
517 //===----------------------------------------------------------------------===//
518 // AndIOp
519 //===----------------------------------------------------------------------===//
520 
521 OpFoldResult arith::AndIOp::fold(ArrayRef<Attribute> operands) {
522   /// and(x, 0) -> 0
523   if (matchPattern(getRhs(), m_Zero()))
524     return getRhs();
525   /// and(x, allOnes) -> x
526   APInt intValue;
527   if (matchPattern(getRhs(), m_ConstantInt(&intValue)) && intValue.isAllOnes())
528     return getLhs();
529 
530   return constFoldBinaryOp<IntegerAttr>(
531       operands, [](APInt a, const APInt &b) { return std::move(a) & b; });
532 }
533 
534 //===----------------------------------------------------------------------===//
535 // OrIOp
536 //===----------------------------------------------------------------------===//
537 
538 OpFoldResult arith::OrIOp::fold(ArrayRef<Attribute> operands) {
539   /// or(x, 0) -> x
540   if (matchPattern(getRhs(), m_Zero()))
541     return getLhs();
542   /// or(x, <all ones>) -> <all ones>
543   if (auto rhsAttr = operands[1].dyn_cast_or_null<IntegerAttr>())
544     if (rhsAttr.getValue().isAllOnes())
545       return rhsAttr;
546 
547   return constFoldBinaryOp<IntegerAttr>(
548       operands, [](APInt a, const APInt &b) { return std::move(a) | b; });
549 }
550 
551 //===----------------------------------------------------------------------===//
552 // XOrIOp
553 //===----------------------------------------------------------------------===//
554 
555 OpFoldResult arith::XOrIOp::fold(ArrayRef<Attribute> operands) {
556   /// xor(x, 0) -> x
557   if (matchPattern(getRhs(), m_Zero()))
558     return getLhs();
559   /// xor(x, x) -> 0
560   if (getLhs() == getRhs())
561     return Builder(getContext()).getZeroAttr(getType());
562   /// xor(xor(x, a), a) -> x
563   if (arith::XOrIOp prev = getLhs().getDefiningOp<arith::XOrIOp>())
564     if (prev.getRhs() == getRhs())
565       return prev.getLhs();
566 
567   return constFoldBinaryOp<IntegerAttr>(
568       operands, [](APInt a, const APInt &b) { return std::move(a) ^ b; });
569 }
570 
571 void arith::XOrIOp::getCanonicalizationPatterns(
572     RewritePatternSet &patterns, MLIRContext *context) {
573   patterns.add<XOrINotCmpI>(context);
574 }
575 
576 //===----------------------------------------------------------------------===//
577 // NegFOp
578 //===----------------------------------------------------------------------===//
579 
580 OpFoldResult arith::NegFOp::fold(ArrayRef<Attribute> operands) {
581   return constFoldUnaryOp<FloatAttr>(operands,
582                                      [](const APFloat &a) { return -a; });
583 }
584 
585 //===----------------------------------------------------------------------===//
586 // AddFOp
587 //===----------------------------------------------------------------------===//
588 
589 OpFoldResult arith::AddFOp::fold(ArrayRef<Attribute> operands) {
590   // addf(x, -0) -> x
591   if (matchPattern(getRhs(), m_NegZeroFloat()))
592     return getLhs();
593 
594   return constFoldBinaryOp<FloatAttr>(
595       operands, [](const APFloat &a, const APFloat &b) { return a + b; });
596 }
597 
598 //===----------------------------------------------------------------------===//
599 // SubFOp
600 //===----------------------------------------------------------------------===//
601 
602 OpFoldResult arith::SubFOp::fold(ArrayRef<Attribute> operands) {
603   // subf(x, +0) -> x
604   if (matchPattern(getRhs(), m_PosZeroFloat()))
605     return getLhs();
606 
607   return constFoldBinaryOp<FloatAttr>(
608       operands, [](const APFloat &a, const APFloat &b) { return a - b; });
609 }
610 
611 //===----------------------------------------------------------------------===//
612 // MaxFOp
613 //===----------------------------------------------------------------------===//
614 
615 OpFoldResult arith::MaxFOp::fold(ArrayRef<Attribute> operands) {
616   assert(operands.size() == 2 && "maxf takes two operands");
617 
618   // maxf(x,x) -> x
619   if (getLhs() == getRhs())
620     return getRhs();
621 
622   // maxf(x, -inf) -> x
623   if (matchPattern(getRhs(), m_NegInfFloat()))
624     return getLhs();
625 
626   return constFoldBinaryOp<FloatAttr>(
627       operands,
628       [](const APFloat &a, const APFloat &b) { return llvm::maximum(a, b); });
629 }
630 
631 //===----------------------------------------------------------------------===//
632 // MaxSIOp
633 //===----------------------------------------------------------------------===//
634 
635 OpFoldResult MaxSIOp::fold(ArrayRef<Attribute> operands) {
636   assert(operands.size() == 2 && "binary operation takes two operands");
637 
638   // maxsi(x,x) -> x
639   if (getLhs() == getRhs())
640     return getRhs();
641 
642   APInt intValue;
643   // maxsi(x,MAX_INT) -> MAX_INT
644   if (matchPattern(getRhs(), m_ConstantInt(&intValue)) &&
645       intValue.isMaxSignedValue())
646     return getRhs();
647 
648   // maxsi(x, MIN_INT) -> x
649   if (matchPattern(getRhs(), m_ConstantInt(&intValue)) &&
650       intValue.isMinSignedValue())
651     return getLhs();
652 
653   return constFoldBinaryOp<IntegerAttr>(operands,
654                                         [](const APInt &a, const APInt &b) {
655                                           return llvm::APIntOps::smax(a, b);
656                                         });
657 }
658 
659 //===----------------------------------------------------------------------===//
660 // MaxUIOp
661 //===----------------------------------------------------------------------===//
662 
663 OpFoldResult MaxUIOp::fold(ArrayRef<Attribute> operands) {
664   assert(operands.size() == 2 && "binary operation takes two operands");
665 
666   // maxui(x,x) -> x
667   if (getLhs() == getRhs())
668     return getRhs();
669 
670   APInt intValue;
671   // maxui(x,MAX_INT) -> MAX_INT
672   if (matchPattern(getRhs(), m_ConstantInt(&intValue)) && intValue.isMaxValue())
673     return getRhs();
674 
675   // maxui(x, MIN_INT) -> x
676   if (matchPattern(getRhs(), m_ConstantInt(&intValue)) && intValue.isMinValue())
677     return getLhs();
678 
679   return constFoldBinaryOp<IntegerAttr>(operands,
680                                         [](const APInt &a, const APInt &b) {
681                                           return llvm::APIntOps::umax(a, b);
682                                         });
683 }
684 
685 //===----------------------------------------------------------------------===//
686 // MinFOp
687 //===----------------------------------------------------------------------===//
688 
689 OpFoldResult arith::MinFOp::fold(ArrayRef<Attribute> operands) {
690   assert(operands.size() == 2 && "minf takes two operands");
691 
692   // minf(x,x) -> x
693   if (getLhs() == getRhs())
694     return getRhs();
695 
696   // minf(x, +inf) -> x
697   if (matchPattern(getRhs(), m_PosInfFloat()))
698     return getLhs();
699 
700   return constFoldBinaryOp<FloatAttr>(
701       operands,
702       [](const APFloat &a, const APFloat &b) { return llvm::minimum(a, b); });
703 }
704 
705 //===----------------------------------------------------------------------===//
706 // MinSIOp
707 //===----------------------------------------------------------------------===//
708 
709 OpFoldResult MinSIOp::fold(ArrayRef<Attribute> operands) {
710   assert(operands.size() == 2 && "binary operation takes two operands");
711 
712   // minsi(x,x) -> x
713   if (getLhs() == getRhs())
714     return getRhs();
715 
716   APInt intValue;
717   // minsi(x,MIN_INT) -> MIN_INT
718   if (matchPattern(getRhs(), m_ConstantInt(&intValue)) &&
719       intValue.isMinSignedValue())
720     return getRhs();
721 
722   // minsi(x, MAX_INT) -> x
723   if (matchPattern(getRhs(), m_ConstantInt(&intValue)) &&
724       intValue.isMaxSignedValue())
725     return getLhs();
726 
727   return constFoldBinaryOp<IntegerAttr>(operands,
728                                         [](const APInt &a, const APInt &b) {
729                                           return llvm::APIntOps::smin(a, b);
730                                         });
731 }
732 
733 //===----------------------------------------------------------------------===//
734 // MinUIOp
735 //===----------------------------------------------------------------------===//
736 
737 OpFoldResult MinUIOp::fold(ArrayRef<Attribute> operands) {
738   assert(operands.size() == 2 && "binary operation takes two operands");
739 
740   // minui(x,x) -> x
741   if (getLhs() == getRhs())
742     return getRhs();
743 
744   APInt intValue;
745   // minui(x,MIN_INT) -> MIN_INT
746   if (matchPattern(getRhs(), m_ConstantInt(&intValue)) && intValue.isMinValue())
747     return getRhs();
748 
749   // minui(x, MAX_INT) -> x
750   if (matchPattern(getRhs(), m_ConstantInt(&intValue)) && intValue.isMaxValue())
751     return getLhs();
752 
753   return constFoldBinaryOp<IntegerAttr>(operands,
754                                         [](const APInt &a, const APInt &b) {
755                                           return llvm::APIntOps::umin(a, b);
756                                         });
757 }
758 
759 //===----------------------------------------------------------------------===//
760 // MulFOp
761 //===----------------------------------------------------------------------===//
762 
763 OpFoldResult arith::MulFOp::fold(ArrayRef<Attribute> operands) {
764   APFloat floatValue(0.0f), inverseValue(0.0f);
765   // mulf(x, 1) -> x
766   if (matchPattern(getRhs(), m_OneFloat()))
767     return getLhs();
768 
769   // mulf(1, x) -> x
770   if (matchPattern(getLhs(), m_OneFloat()))
771     return getRhs();
772 
773   return constFoldBinaryOp<FloatAttr>(
774       operands, [](const APFloat &a, const APFloat &b) { return a * b; });
775 }
776 
777 //===----------------------------------------------------------------------===//
778 // DivFOp
779 //===----------------------------------------------------------------------===//
780 
781 OpFoldResult arith::DivFOp::fold(ArrayRef<Attribute> operands) {
782   APFloat floatValue(0.0f), inverseValue(0.0f);
783   // divf(x, 1) -> x
784   if (matchPattern(getRhs(), m_OneFloat()))
785     return getLhs();
786 
787   return constFoldBinaryOp<FloatAttr>(
788       operands, [](const APFloat &a, const APFloat &b) { return a / b; });
789 }
790 
791 //===----------------------------------------------------------------------===//
792 // Utility functions for verifying cast ops
793 //===----------------------------------------------------------------------===//
794 
795 template <typename... Types>
796 using type_list = std::tuple<Types...> *;
797 
798 /// Returns a non-null type only if the provided type is one of the allowed
799 /// types or one of the allowed shaped types of the allowed types. Returns the
800 /// element type if a valid shaped type is provided.
801 template <typename... ShapedTypes, typename... ElementTypes>
802 static Type getUnderlyingType(Type type, type_list<ShapedTypes...>,
803                               type_list<ElementTypes...>) {
804   if (type.isa<ShapedType>() && !type.isa<ShapedTypes...>())
805     return {};
806 
807   auto underlyingType = getElementTypeOrSelf(type);
808   if (!underlyingType.isa<ElementTypes...>())
809     return {};
810 
811   return underlyingType;
812 }
813 
814 /// Get allowed underlying types for vectors and tensors.
815 template <typename... ElementTypes>
816 static Type getTypeIfLike(Type type) {
817   return getUnderlyingType(type, type_list<VectorType, TensorType>(),
818                            type_list<ElementTypes...>());
819 }
820 
821 /// Get allowed underlying types for vectors, tensors, and memrefs.
822 template <typename... ElementTypes>
823 static Type getTypeIfLikeOrMemRef(Type type) {
824   return getUnderlyingType(type,
825                            type_list<VectorType, TensorType, MemRefType>(),
826                            type_list<ElementTypes...>());
827 }
828 
829 static bool areValidCastInputsAndOutputs(TypeRange inputs, TypeRange outputs) {
830   return inputs.size() == 1 && outputs.size() == 1 &&
831          succeeded(verifyCompatibleShapes(inputs.front(), outputs.front()));
832 }
833 
834 //===----------------------------------------------------------------------===//
835 // Verifiers for integer and floating point extension/truncation ops
836 //===----------------------------------------------------------------------===//
837 
838 // Extend ops can only extend to a wider type.
839 template <typename ValType, typename Op>
840 static LogicalResult verifyExtOp(Op op) {
841   Type srcType = getElementTypeOrSelf(op.getIn().getType());
842   Type dstType = getElementTypeOrSelf(op.getType());
843 
844   if (srcType.cast<ValType>().getWidth() >= dstType.cast<ValType>().getWidth())
845     return op.emitError("result type ")
846            << dstType << " must be wider than operand type " << srcType;
847 
848   return success();
849 }
850 
851 // Truncate ops can only truncate to a shorter type.
852 template <typename ValType, typename Op>
853 static LogicalResult verifyTruncateOp(Op op) {
854   Type srcType = getElementTypeOrSelf(op.getIn().getType());
855   Type dstType = getElementTypeOrSelf(op.getType());
856 
857   if (srcType.cast<ValType>().getWidth() <= dstType.cast<ValType>().getWidth())
858     return op.emitError("result type ")
859            << dstType << " must be shorter than operand type " << srcType;
860 
861   return success();
862 }
863 
864 /// Validate a cast that changes the width of a type.
865 template <template <typename> class WidthComparator, typename... ElementTypes>
866 static bool checkWidthChangeCast(TypeRange inputs, TypeRange outputs) {
867   if (!areValidCastInputsAndOutputs(inputs, outputs))
868     return false;
869 
870   auto srcType = getTypeIfLike<ElementTypes...>(inputs.front());
871   auto dstType = getTypeIfLike<ElementTypes...>(outputs.front());
872   if (!srcType || !dstType)
873     return false;
874 
875   return WidthComparator<unsigned>()(dstType.getIntOrFloatBitWidth(),
876                                      srcType.getIntOrFloatBitWidth());
877 }
878 
879 //===----------------------------------------------------------------------===//
880 // ExtUIOp
881 //===----------------------------------------------------------------------===//
882 
883 OpFoldResult arith::ExtUIOp::fold(ArrayRef<Attribute> operands) {
884   if (auto lhs = operands[0].dyn_cast_or_null<IntegerAttr>())
885     return IntegerAttr::get(
886         getType(), lhs.getValue().zext(getType().getIntOrFloatBitWidth()));
887 
888   if (auto lhs = getIn().getDefiningOp<ExtUIOp>()) {
889     getInMutable().assign(lhs.getIn());
890     return getResult();
891   }
892 
893   return {};
894 }
895 
896 bool arith::ExtUIOp::areCastCompatible(TypeRange inputs, TypeRange outputs) {
897   return checkWidthChangeCast<std::greater, IntegerType>(inputs, outputs);
898 }
899 
900 LogicalResult arith::ExtUIOp::verify() {
901   return verifyExtOp<IntegerType>(*this);
902 }
903 
904 //===----------------------------------------------------------------------===//
905 // ExtSIOp
906 //===----------------------------------------------------------------------===//
907 
908 OpFoldResult arith::ExtSIOp::fold(ArrayRef<Attribute> operands) {
909   if (auto lhs = operands[0].dyn_cast_or_null<IntegerAttr>())
910     return IntegerAttr::get(
911         getType(), lhs.getValue().sext(getType().getIntOrFloatBitWidth()));
912 
913   if (auto lhs = getIn().getDefiningOp<ExtSIOp>()) {
914     getInMutable().assign(lhs.getIn());
915     return getResult();
916   }
917 
918   return {};
919 }
920 
921 bool arith::ExtSIOp::areCastCompatible(TypeRange inputs, TypeRange outputs) {
922   return checkWidthChangeCast<std::greater, IntegerType>(inputs, outputs);
923 }
924 
925 void arith::ExtSIOp::getCanonicalizationPatterns(
926     RewritePatternSet &patterns, MLIRContext *context) {
927   patterns.add<ExtSIOfExtUI>(context);
928 }
929 
930 LogicalResult arith::ExtSIOp::verify() {
931   return verifyExtOp<IntegerType>(*this);
932 }
933 
934 //===----------------------------------------------------------------------===//
935 // ExtFOp
936 //===----------------------------------------------------------------------===//
937 
938 bool arith::ExtFOp::areCastCompatible(TypeRange inputs, TypeRange outputs) {
939   return checkWidthChangeCast<std::greater, FloatType>(inputs, outputs);
940 }
941 
942 LogicalResult arith::ExtFOp::verify() { return verifyExtOp<FloatType>(*this); }
943 
944 //===----------------------------------------------------------------------===//
945 // TruncIOp
946 //===----------------------------------------------------------------------===//
947 
948 OpFoldResult arith::TruncIOp::fold(ArrayRef<Attribute> operands) {
949   assert(operands.size() == 1 && "unary operation takes one operand");
950 
951   // trunci(zexti(a)) -> a
952   // trunci(sexti(a)) -> a
953   if (matchPattern(getOperand(), m_Op<arith::ExtUIOp>()) ||
954       matchPattern(getOperand(), m_Op<arith::ExtSIOp>()))
955     return getOperand().getDefiningOp()->getOperand(0);
956 
957   // trunci(trunci(a)) -> trunci(a))
958   if (matchPattern(getOperand(), m_Op<arith::TruncIOp>())) {
959     setOperand(getOperand().getDefiningOp()->getOperand(0));
960     return getResult();
961   }
962 
963   if (!operands[0])
964     return {};
965 
966   if (auto lhs = operands[0].dyn_cast<IntegerAttr>()) {
967     return IntegerAttr::get(
968         getType(), lhs.getValue().trunc(getType().getIntOrFloatBitWidth()));
969   }
970 
971   return {};
972 }
973 
974 bool arith::TruncIOp::areCastCompatible(TypeRange inputs, TypeRange outputs) {
975   return checkWidthChangeCast<std::less, IntegerType>(inputs, outputs);
976 }
977 
978 LogicalResult arith::TruncIOp::verify() {
979   return verifyTruncateOp<IntegerType>(*this);
980 }
981 
982 //===----------------------------------------------------------------------===//
983 // TruncFOp
984 //===----------------------------------------------------------------------===//
985 
986 /// Perform safe const propagation for truncf, i.e. only propagate if FP value
987 /// can be represented without precision loss or rounding.
988 OpFoldResult arith::TruncFOp::fold(ArrayRef<Attribute> operands) {
989   assert(operands.size() == 1 && "unary operation takes one operand");
990 
991   auto constOperand = operands.front();
992   if (!constOperand || !constOperand.isa<FloatAttr>())
993     return {};
994 
995   // Convert to target type via 'double'.
996   double sourceValue =
997       constOperand.dyn_cast<FloatAttr>().getValue().convertToDouble();
998   auto targetAttr = FloatAttr::get(getType(), sourceValue);
999 
1000   // Propagate if constant's value does not change after truncation.
1001   if (sourceValue == targetAttr.getValue().convertToDouble())
1002     return targetAttr;
1003 
1004   return {};
1005 }
1006 
1007 bool arith::TruncFOp::areCastCompatible(TypeRange inputs, TypeRange outputs) {
1008   return checkWidthChangeCast<std::less, FloatType>(inputs, outputs);
1009 }
1010 
1011 LogicalResult arith::TruncFOp::verify() {
1012   return verifyTruncateOp<FloatType>(*this);
1013 }
1014 
1015 //===----------------------------------------------------------------------===//
1016 // AndIOp
1017 //===----------------------------------------------------------------------===//
1018 
1019 void arith::AndIOp::getCanonicalizationPatterns(
1020     RewritePatternSet &patterns, MLIRContext *context) {
1021   patterns.add<AndOfExtUI, AndOfExtSI>(context);
1022 }
1023 
1024 //===----------------------------------------------------------------------===//
1025 // OrIOp
1026 //===----------------------------------------------------------------------===//
1027 
1028 void arith::OrIOp::getCanonicalizationPatterns(
1029     RewritePatternSet &patterns, MLIRContext *context) {
1030   patterns.add<OrOfExtUI, OrOfExtSI>(context);
1031 }
1032 
1033 //===----------------------------------------------------------------------===//
1034 // Verifiers for casts between integers and floats.
1035 //===----------------------------------------------------------------------===//
1036 
1037 template <typename From, typename To>
1038 static bool checkIntFloatCast(TypeRange inputs, TypeRange outputs) {
1039   if (!areValidCastInputsAndOutputs(inputs, outputs))
1040     return false;
1041 
1042   auto srcType = getTypeIfLike<From>(inputs.front());
1043   auto dstType = getTypeIfLike<To>(outputs.back());
1044 
1045   return srcType && dstType;
1046 }
1047 
1048 //===----------------------------------------------------------------------===//
1049 // UIToFPOp
1050 //===----------------------------------------------------------------------===//
1051 
1052 bool arith::UIToFPOp::areCastCompatible(TypeRange inputs, TypeRange outputs) {
1053   return checkIntFloatCast<IntegerType, FloatType>(inputs, outputs);
1054 }
1055 
1056 OpFoldResult arith::UIToFPOp::fold(ArrayRef<Attribute> operands) {
1057   if (auto lhs = operands[0].dyn_cast_or_null<IntegerAttr>()) {
1058     const APInt &api = lhs.getValue();
1059     FloatType floatTy = getType().cast<FloatType>();
1060     APFloat apf(floatTy.getFloatSemantics(),
1061                 APInt::getZero(floatTy.getWidth()));
1062     apf.convertFromAPInt(api, /*IsSigned=*/false, APFloat::rmNearestTiesToEven);
1063     return FloatAttr::get(floatTy, apf);
1064   }
1065   return {};
1066 }
1067 
1068 //===----------------------------------------------------------------------===//
1069 // SIToFPOp
1070 //===----------------------------------------------------------------------===//
1071 
1072 bool arith::SIToFPOp::areCastCompatible(TypeRange inputs, TypeRange outputs) {
1073   return checkIntFloatCast<IntegerType, FloatType>(inputs, outputs);
1074 }
1075 
1076 OpFoldResult arith::SIToFPOp::fold(ArrayRef<Attribute> operands) {
1077   if (auto lhs = operands[0].dyn_cast_or_null<IntegerAttr>()) {
1078     const APInt &api = lhs.getValue();
1079     FloatType floatTy = getType().cast<FloatType>();
1080     APFloat apf(floatTy.getFloatSemantics(),
1081                 APInt::getZero(floatTy.getWidth()));
1082     apf.convertFromAPInt(api, /*IsSigned=*/true, APFloat::rmNearestTiesToEven);
1083     return FloatAttr::get(floatTy, apf);
1084   }
1085   return {};
1086 }
1087 //===----------------------------------------------------------------------===//
1088 // FPToUIOp
1089 //===----------------------------------------------------------------------===//
1090 
1091 bool arith::FPToUIOp::areCastCompatible(TypeRange inputs, TypeRange outputs) {
1092   return checkIntFloatCast<FloatType, IntegerType>(inputs, outputs);
1093 }
1094 
1095 OpFoldResult arith::FPToUIOp::fold(ArrayRef<Attribute> operands) {
1096   if (auto lhs = operands[0].dyn_cast_or_null<FloatAttr>()) {
1097     const APFloat &apf = lhs.getValue();
1098     IntegerType intTy = getType().cast<IntegerType>();
1099     bool ignored;
1100     APSInt api(intTy.getWidth(), /*isUnsigned=*/true);
1101     if (APFloat::opInvalidOp ==
1102         apf.convertToInteger(api, APFloat::rmTowardZero, &ignored)) {
1103       // Undefined behavior invoked - the destination type can't represent
1104       // the input constant.
1105       return {};
1106     }
1107     return IntegerAttr::get(getType(), api);
1108   }
1109 
1110   return {};
1111 }
1112 
1113 //===----------------------------------------------------------------------===//
1114 // FPToSIOp
1115 //===----------------------------------------------------------------------===//
1116 
1117 bool arith::FPToSIOp::areCastCompatible(TypeRange inputs, TypeRange outputs) {
1118   return checkIntFloatCast<FloatType, IntegerType>(inputs, outputs);
1119 }
1120 
1121 OpFoldResult arith::FPToSIOp::fold(ArrayRef<Attribute> operands) {
1122   if (auto lhs = operands[0].dyn_cast_or_null<FloatAttr>()) {
1123     const APFloat &apf = lhs.getValue();
1124     IntegerType intTy = getType().cast<IntegerType>();
1125     bool ignored;
1126     APSInt api(intTy.getWidth(), /*isUnsigned=*/false);
1127     if (APFloat::opInvalidOp ==
1128         apf.convertToInteger(api, APFloat::rmTowardZero, &ignored)) {
1129       // Undefined behavior invoked - the destination type can't represent
1130       // the input constant.
1131       return {};
1132     }
1133     return IntegerAttr::get(getType(), api);
1134   }
1135 
1136   return {};
1137 }
1138 
1139 //===----------------------------------------------------------------------===//
1140 // IndexCastOp
1141 //===----------------------------------------------------------------------===//
1142 
1143 bool arith::IndexCastOp::areCastCompatible(TypeRange inputs,
1144                                            TypeRange outputs) {
1145   if (!areValidCastInputsAndOutputs(inputs, outputs))
1146     return false;
1147 
1148   auto srcType = getTypeIfLikeOrMemRef<IntegerType, IndexType>(inputs.front());
1149   auto dstType = getTypeIfLikeOrMemRef<IntegerType, IndexType>(outputs.front());
1150   if (!srcType || !dstType)
1151     return false;
1152 
1153   return (srcType.isIndex() && dstType.isSignlessInteger()) ||
1154          (srcType.isSignlessInteger() && dstType.isIndex());
1155 }
1156 
1157 OpFoldResult arith::IndexCastOp::fold(ArrayRef<Attribute> operands) {
1158   // index_cast(constant) -> constant
1159   // A little hack because we go through int. Otherwise, the size of the
1160   // constant might need to change.
1161   if (auto value = operands[0].dyn_cast_or_null<IntegerAttr>())
1162     return IntegerAttr::get(getType(), value.getInt());
1163 
1164   return {};
1165 }
1166 
1167 void arith::IndexCastOp::getCanonicalizationPatterns(
1168     RewritePatternSet &patterns, MLIRContext *context) {
1169   patterns.add<IndexCastOfIndexCast, IndexCastOfExtSI>(context);
1170 }
1171 
1172 //===----------------------------------------------------------------------===//
1173 // BitcastOp
1174 //===----------------------------------------------------------------------===//
1175 
1176 bool arith::BitcastOp::areCastCompatible(TypeRange inputs, TypeRange outputs) {
1177   if (!areValidCastInputsAndOutputs(inputs, outputs))
1178     return false;
1179 
1180   auto srcType =
1181       getTypeIfLikeOrMemRef<IntegerType, IndexType, FloatType>(inputs.front());
1182   auto dstType =
1183       getTypeIfLikeOrMemRef<IntegerType, IndexType, FloatType>(outputs.front());
1184   if (!srcType || !dstType)
1185     return false;
1186 
1187   return srcType.getIntOrFloatBitWidth() == dstType.getIntOrFloatBitWidth();
1188 }
1189 
1190 OpFoldResult arith::BitcastOp::fold(ArrayRef<Attribute> operands) {
1191   assert(operands.size() == 1 && "bitcast op expects 1 operand");
1192 
1193   auto resType = getType();
1194   auto operand = operands[0];
1195   if (!operand)
1196     return {};
1197 
1198   /// Bitcast dense elements.
1199   if (auto denseAttr = operand.dyn_cast_or_null<DenseElementsAttr>())
1200     return denseAttr.bitcast(resType.cast<ShapedType>().getElementType());
1201   /// Other shaped types unhandled.
1202   if (resType.isa<ShapedType>())
1203     return {};
1204 
1205   /// Bitcast integer or float to integer or float.
1206   APInt bits = operand.isa<FloatAttr>()
1207                    ? operand.cast<FloatAttr>().getValue().bitcastToAPInt()
1208                    : operand.cast<IntegerAttr>().getValue();
1209 
1210   if (auto resFloatType = resType.dyn_cast<FloatType>())
1211     return FloatAttr::get(resType,
1212                           APFloat(resFloatType.getFloatSemantics(), bits));
1213   return IntegerAttr::get(resType, bits);
1214 }
1215 
1216 void arith::BitcastOp::getCanonicalizationPatterns(
1217     RewritePatternSet &patterns, MLIRContext *context) {
1218   patterns.add<BitcastOfBitcast>(context);
1219 }
1220 
1221 //===----------------------------------------------------------------------===//
1222 // Helpers for compare ops
1223 //===----------------------------------------------------------------------===//
1224 
1225 /// Return the type of the same shape (scalar, vector or tensor) containing i1.
1226 static Type getI1SameShape(Type type) {
1227   auto i1Type = IntegerType::get(type.getContext(), 1);
1228   if (auto tensorType = type.dyn_cast<RankedTensorType>())
1229     return RankedTensorType::get(tensorType.getShape(), i1Type);
1230   if (type.isa<UnrankedTensorType>())
1231     return UnrankedTensorType::get(i1Type);
1232   if (auto vectorType = type.dyn_cast<VectorType>())
1233     return VectorType::get(vectorType.getShape(), i1Type,
1234                            vectorType.getNumScalableDims());
1235   return i1Type;
1236 }
1237 
1238 //===----------------------------------------------------------------------===//
1239 // CmpIOp
1240 //===----------------------------------------------------------------------===//
1241 
1242 /// Compute `lhs` `pred` `rhs`, where `pred` is one of the known integer
1243 /// comparison predicates.
1244 bool mlir::arith::applyCmpPredicate(arith::CmpIPredicate predicate,
1245                                     const APInt &lhs, const APInt &rhs) {
1246   switch (predicate) {
1247   case arith::CmpIPredicate::eq:
1248     return lhs.eq(rhs);
1249   case arith::CmpIPredicate::ne:
1250     return lhs.ne(rhs);
1251   case arith::CmpIPredicate::slt:
1252     return lhs.slt(rhs);
1253   case arith::CmpIPredicate::sle:
1254     return lhs.sle(rhs);
1255   case arith::CmpIPredicate::sgt:
1256     return lhs.sgt(rhs);
1257   case arith::CmpIPredicate::sge:
1258     return lhs.sge(rhs);
1259   case arith::CmpIPredicate::ult:
1260     return lhs.ult(rhs);
1261   case arith::CmpIPredicate::ule:
1262     return lhs.ule(rhs);
1263   case arith::CmpIPredicate::ugt:
1264     return lhs.ugt(rhs);
1265   case arith::CmpIPredicate::uge:
1266     return lhs.uge(rhs);
1267   }
1268   llvm_unreachable("unknown cmpi predicate kind");
1269 }
1270 
1271 /// Returns true if the predicate is true for two equal operands.
1272 static bool applyCmpPredicateToEqualOperands(arith::CmpIPredicate predicate) {
1273   switch (predicate) {
1274   case arith::CmpIPredicate::eq:
1275   case arith::CmpIPredicate::sle:
1276   case arith::CmpIPredicate::sge:
1277   case arith::CmpIPredicate::ule:
1278   case arith::CmpIPredicate::uge:
1279     return true;
1280   case arith::CmpIPredicate::ne:
1281   case arith::CmpIPredicate::slt:
1282   case arith::CmpIPredicate::sgt:
1283   case arith::CmpIPredicate::ult:
1284   case arith::CmpIPredicate::ugt:
1285     return false;
1286   }
1287   llvm_unreachable("unknown cmpi predicate kind");
1288 }
1289 
1290 static Attribute getBoolAttribute(Type type, MLIRContext *ctx, bool value) {
1291   auto boolAttr = BoolAttr::get(ctx, value);
1292   ShapedType shapedType = type.dyn_cast_or_null<ShapedType>();
1293   if (!shapedType)
1294     return boolAttr;
1295   return DenseElementsAttr::get(shapedType, boolAttr);
1296 }
1297 
1298 OpFoldResult arith::CmpIOp::fold(ArrayRef<Attribute> operands) {
1299   assert(operands.size() == 2 && "cmpi takes two operands");
1300 
1301   // cmpi(pred, x, x)
1302   if (getLhs() == getRhs()) {
1303     auto val = applyCmpPredicateToEqualOperands(getPredicate());
1304     return getBoolAttribute(getType(), getContext(), val);
1305   }
1306 
1307   if (matchPattern(getRhs(), m_Zero())) {
1308     if (auto extOp = getLhs().getDefiningOp<ExtSIOp>()) {
1309       if (extOp.getOperand().getType().cast<IntegerType>().getWidth() == 1) {
1310         // extsi(%x : i1 -> iN) != 0  ->  %x
1311         if (getPredicate() == arith::CmpIPredicate::ne) {
1312           return extOp.getOperand();
1313         }
1314       }
1315     }
1316     if (auto extOp = getLhs().getDefiningOp<ExtUIOp>()) {
1317       if (extOp.getOperand().getType().cast<IntegerType>().getWidth() == 1) {
1318         // extui(%x : i1 -> iN) != 0  ->  %x
1319         if (getPredicate() == arith::CmpIPredicate::ne) {
1320           return extOp.getOperand();
1321         }
1322       }
1323     }
1324   }
1325 
1326   auto lhs = operands.front().dyn_cast_or_null<IntegerAttr>();
1327   auto rhs = operands.back().dyn_cast_or_null<IntegerAttr>();
1328   if (!lhs || !rhs)
1329     return {};
1330 
1331   auto val = applyCmpPredicate(getPredicate(), lhs.getValue(), rhs.getValue());
1332   return BoolAttr::get(getContext(), val);
1333 }
1334 
1335 void arith::CmpIOp::getCanonicalizationPatterns(RewritePatternSet &patterns,
1336                                                 MLIRContext *context) {
1337   patterns.insert<CmpIExtSI, CmpIExtUI>(context);
1338 }
1339 
1340 //===----------------------------------------------------------------------===//
1341 // CmpFOp
1342 //===----------------------------------------------------------------------===//
1343 
1344 /// Compute `lhs` `pred` `rhs`, where `pred` is one of the known floating point
1345 /// comparison predicates.
1346 bool mlir::arith::applyCmpPredicate(arith::CmpFPredicate predicate,
1347                                     const APFloat &lhs, const APFloat &rhs) {
1348   auto cmpResult = lhs.compare(rhs);
1349   switch (predicate) {
1350   case arith::CmpFPredicate::AlwaysFalse:
1351     return false;
1352   case arith::CmpFPredicate::OEQ:
1353     return cmpResult == APFloat::cmpEqual;
1354   case arith::CmpFPredicate::OGT:
1355     return cmpResult == APFloat::cmpGreaterThan;
1356   case arith::CmpFPredicate::OGE:
1357     return cmpResult == APFloat::cmpGreaterThan ||
1358            cmpResult == APFloat::cmpEqual;
1359   case arith::CmpFPredicate::OLT:
1360     return cmpResult == APFloat::cmpLessThan;
1361   case arith::CmpFPredicate::OLE:
1362     return cmpResult == APFloat::cmpLessThan || cmpResult == APFloat::cmpEqual;
1363   case arith::CmpFPredicate::ONE:
1364     return cmpResult != APFloat::cmpUnordered && cmpResult != APFloat::cmpEqual;
1365   case arith::CmpFPredicate::ORD:
1366     return cmpResult != APFloat::cmpUnordered;
1367   case arith::CmpFPredicate::UEQ:
1368     return cmpResult == APFloat::cmpUnordered || cmpResult == APFloat::cmpEqual;
1369   case arith::CmpFPredicate::UGT:
1370     return cmpResult == APFloat::cmpUnordered ||
1371            cmpResult == APFloat::cmpGreaterThan;
1372   case arith::CmpFPredicate::UGE:
1373     return cmpResult == APFloat::cmpUnordered ||
1374            cmpResult == APFloat::cmpGreaterThan ||
1375            cmpResult == APFloat::cmpEqual;
1376   case arith::CmpFPredicate::ULT:
1377     return cmpResult == APFloat::cmpUnordered ||
1378            cmpResult == APFloat::cmpLessThan;
1379   case arith::CmpFPredicate::ULE:
1380     return cmpResult == APFloat::cmpUnordered ||
1381            cmpResult == APFloat::cmpLessThan || cmpResult == APFloat::cmpEqual;
1382   case arith::CmpFPredicate::UNE:
1383     return cmpResult != APFloat::cmpEqual;
1384   case arith::CmpFPredicate::UNO:
1385     return cmpResult == APFloat::cmpUnordered;
1386   case arith::CmpFPredicate::AlwaysTrue:
1387     return true;
1388   }
1389   llvm_unreachable("unknown cmpf predicate kind");
1390 }
1391 
1392 OpFoldResult arith::CmpFOp::fold(ArrayRef<Attribute> operands) {
1393   assert(operands.size() == 2 && "cmpf takes two operands");
1394 
1395   auto lhs = operands.front().dyn_cast_or_null<FloatAttr>();
1396   auto rhs = operands.back().dyn_cast_or_null<FloatAttr>();
1397 
1398   // If one operand is NaN, making them both NaN does not change the result.
1399   if (lhs && lhs.getValue().isNaN())
1400     rhs = lhs;
1401   if (rhs && rhs.getValue().isNaN())
1402     lhs = rhs;
1403 
1404   if (!lhs || !rhs)
1405     return {};
1406 
1407   auto val = applyCmpPredicate(getPredicate(), lhs.getValue(), rhs.getValue());
1408   return BoolAttr::get(getContext(), val);
1409 }
1410 
1411 class CmpFIntToFPConst final : public OpRewritePattern<CmpFOp> {
1412 public:
1413   using OpRewritePattern<CmpFOp>::OpRewritePattern;
1414 
1415   static CmpIPredicate convertToIntegerPredicate(CmpFPredicate pred,
1416                                                  bool isUnsigned) {
1417     using namespace arith;
1418     switch (pred) {
1419     case CmpFPredicate::UEQ:
1420     case CmpFPredicate::OEQ:
1421       return CmpIPredicate::eq;
1422     case CmpFPredicate::UGT:
1423     case CmpFPredicate::OGT:
1424       return isUnsigned ? CmpIPredicate::ugt : CmpIPredicate::sgt;
1425     case CmpFPredicate::UGE:
1426     case CmpFPredicate::OGE:
1427       return isUnsigned ? CmpIPredicate::uge : CmpIPredicate::sge;
1428     case CmpFPredicate::ULT:
1429     case CmpFPredicate::OLT:
1430       return isUnsigned ? CmpIPredicate::ult : CmpIPredicate::slt;
1431     case CmpFPredicate::ULE:
1432     case CmpFPredicate::OLE:
1433       return isUnsigned ? CmpIPredicate::ule : CmpIPredicate::sle;
1434     case CmpFPredicate::UNE:
1435     case CmpFPredicate::ONE:
1436       return CmpIPredicate::ne;
1437     default:
1438       llvm_unreachable("Unexpected predicate!");
1439     }
1440   }
1441 
1442   LogicalResult matchAndRewrite(CmpFOp op,
1443                                 PatternRewriter &rewriter) const override {
1444     FloatAttr flt;
1445     if (!matchPattern(op.getRhs(), m_Constant(&flt)))
1446       return failure();
1447 
1448     const APFloat &rhs = flt.getValue();
1449 
1450     // Don't attempt to fold a nan.
1451     if (rhs.isNaN())
1452       return failure();
1453 
1454     // Get the width of the mantissa.  We don't want to hack on conversions that
1455     // might lose information from the integer, e.g. "i64 -> float"
1456     FloatType floatTy = op.getRhs().getType().cast<FloatType>();
1457     int mantissaWidth = floatTy.getFPMantissaWidth();
1458     if (mantissaWidth <= 0)
1459       return failure();
1460 
1461     bool isUnsigned;
1462     Value intVal;
1463 
1464     if (auto si = op.getLhs().getDefiningOp<SIToFPOp>()) {
1465       isUnsigned = false;
1466       intVal = si.getIn();
1467     } else if (auto ui = op.getLhs().getDefiningOp<UIToFPOp>()) {
1468       isUnsigned = true;
1469       intVal = ui.getIn();
1470     } else {
1471       return failure();
1472     }
1473 
1474     // Check to see that the input is converted from an integer type that is
1475     // small enough that preserves all bits.
1476     auto intTy = intVal.getType().cast<IntegerType>();
1477     auto intWidth = intTy.getWidth();
1478 
1479     // Number of bits representing values, as opposed to the sign
1480     auto valueBits = isUnsigned ? intWidth : (intWidth - 1);
1481 
1482     // Following test does NOT adjust intWidth downwards for signed inputs,
1483     // because the most negative value still requires all the mantissa bits
1484     // to distinguish it from one less than that value.
1485     if ((int)intWidth > mantissaWidth) {
1486       // Conversion would lose accuracy. Check if loss can impact comparison.
1487       int exponent = ilogb(rhs);
1488       if (exponent == APFloat::IEK_Inf) {
1489         int maxExponent = ilogb(APFloat::getLargest(rhs.getSemantics()));
1490         if (maxExponent < (int)valueBits) {
1491           // Conversion could create infinity.
1492           return failure();
1493         }
1494       } else {
1495         // Note that if rhs is zero or NaN, then Exp is negative
1496         // and first condition is trivially false.
1497         if (mantissaWidth <= exponent && exponent <= (int)valueBits) {
1498           // Conversion could affect comparison.
1499           return failure();
1500         }
1501       }
1502     }
1503 
1504     // Convert to equivalent cmpi predicate
1505     CmpIPredicate pred;
1506     switch (op.getPredicate()) {
1507     case CmpFPredicate::ORD:
1508       // Int to fp conversion doesn't create a nan (ord checks neither is a nan)
1509       rewriter.replaceOpWithNewOp<ConstantIntOp>(op, /*value=*/true,
1510                                                  /*width=*/1);
1511       return success();
1512     case CmpFPredicate::UNO:
1513       // Int to fp conversion doesn't create a nan (uno checks either is a nan)
1514       rewriter.replaceOpWithNewOp<ConstantIntOp>(op, /*value=*/false,
1515                                                  /*width=*/1);
1516       return success();
1517     default:
1518       pred = convertToIntegerPredicate(op.getPredicate(), isUnsigned);
1519       break;
1520     }
1521 
1522     if (!isUnsigned) {
1523       // If the rhs value is > SignedMax, fold the comparison.  This handles
1524       // +INF and large values.
1525       APFloat signedMax(rhs.getSemantics());
1526       signedMax.convertFromAPInt(APInt::getSignedMaxValue(intWidth), true,
1527                                  APFloat::rmNearestTiesToEven);
1528       if (signedMax < rhs) { // smax < 13123.0
1529         if (pred == CmpIPredicate::ne || pred == CmpIPredicate::slt ||
1530             pred == CmpIPredicate::sle)
1531           rewriter.replaceOpWithNewOp<ConstantIntOp>(op, /*value=*/true,
1532                                                      /*width=*/1);
1533         else
1534           rewriter.replaceOpWithNewOp<ConstantIntOp>(op, /*value=*/false,
1535                                                      /*width=*/1);
1536         return success();
1537       }
1538     } else {
1539       // If the rhs value is > UnsignedMax, fold the comparison. This handles
1540       // +INF and large values.
1541       APFloat unsignedMax(rhs.getSemantics());
1542       unsignedMax.convertFromAPInt(APInt::getMaxValue(intWidth), false,
1543                                    APFloat::rmNearestTiesToEven);
1544       if (unsignedMax < rhs) { // umax < 13123.0
1545         if (pred == CmpIPredicate::ne || pred == CmpIPredicate::ult ||
1546             pred == CmpIPredicate::ule)
1547           rewriter.replaceOpWithNewOp<ConstantIntOp>(op, /*value=*/true,
1548                                                      /*width=*/1);
1549         else
1550           rewriter.replaceOpWithNewOp<ConstantIntOp>(op, /*value=*/false,
1551                                                      /*width=*/1);
1552         return success();
1553       }
1554     }
1555 
1556     if (!isUnsigned) {
1557       // See if the rhs value is < SignedMin.
1558       APFloat signedMin(rhs.getSemantics());
1559       signedMin.convertFromAPInt(APInt::getSignedMinValue(intWidth), true,
1560                                  APFloat::rmNearestTiesToEven);
1561       if (signedMin > rhs) { // smin > 12312.0
1562         if (pred == CmpIPredicate::ne || pred == CmpIPredicate::sgt ||
1563             pred == CmpIPredicate::sge)
1564           rewriter.replaceOpWithNewOp<ConstantIntOp>(op, /*value=*/true,
1565                                                      /*width=*/1);
1566         else
1567           rewriter.replaceOpWithNewOp<ConstantIntOp>(op, /*value=*/false,
1568                                                      /*width=*/1);
1569         return success();
1570       }
1571     } else {
1572       // See if the rhs value is < UnsignedMin.
1573       APFloat unsignedMin(rhs.getSemantics());
1574       unsignedMin.convertFromAPInt(APInt::getMinValue(intWidth), false,
1575                                    APFloat::rmNearestTiesToEven);
1576       if (unsignedMin > rhs) { // umin > 12312.0
1577         if (pred == CmpIPredicate::ne || pred == CmpIPredicate::ugt ||
1578             pred == CmpIPredicate::uge)
1579           rewriter.replaceOpWithNewOp<ConstantIntOp>(op, /*value=*/true,
1580                                                      /*width=*/1);
1581         else
1582           rewriter.replaceOpWithNewOp<ConstantIntOp>(op, /*value=*/false,
1583                                                      /*width=*/1);
1584         return success();
1585       }
1586     }
1587 
1588     // Okay, now we know that the FP constant fits in the range [SMIN, SMAX] or
1589     // [0, UMAX], but it may still be fractional.  See if it is fractional by
1590     // casting the FP value to the integer value and back, checking for
1591     // equality. Don't do this for zero, because -0.0 is not fractional.
1592     bool ignored;
1593     APSInt rhsInt(intWidth, isUnsigned);
1594     if (APFloat::opInvalidOp ==
1595         rhs.convertToInteger(rhsInt, APFloat::rmTowardZero, &ignored)) {
1596       // Undefined behavior invoked - the destination type can't represent
1597       // the input constant.
1598       return failure();
1599     }
1600 
1601     if (!rhs.isZero()) {
1602       APFloat apf(floatTy.getFloatSemantics(),
1603                   APInt::getZero(floatTy.getWidth()));
1604       apf.convertFromAPInt(rhsInt, !isUnsigned, APFloat::rmNearestTiesToEven);
1605 
1606       bool equal = apf == rhs;
1607       if (!equal) {
1608         // If we had a comparison against a fractional value, we have to adjust
1609         // the compare predicate and sometimes the value.  rhsInt is rounded
1610         // towards zero at this point.
1611         switch (pred) {
1612         case CmpIPredicate::ne: // (float)int != 4.4   --> true
1613           rewriter.replaceOpWithNewOp<ConstantIntOp>(op, /*value=*/true,
1614                                                      /*width=*/1);
1615           return success();
1616         case CmpIPredicate::eq: // (float)int == 4.4   --> false
1617           rewriter.replaceOpWithNewOp<ConstantIntOp>(op, /*value=*/false,
1618                                                      /*width=*/1);
1619           return success();
1620         case CmpIPredicate::ule:
1621           // (float)int <= 4.4   --> int <= 4
1622           // (float)int <= -4.4  --> false
1623           if (rhs.isNegative()) {
1624             rewriter.replaceOpWithNewOp<ConstantIntOp>(op, /*value=*/false,
1625                                                        /*width=*/1);
1626             return success();
1627           }
1628           break;
1629         case CmpIPredicate::sle:
1630           // (float)int <= 4.4   --> int <= 4
1631           // (float)int <= -4.4  --> int < -4
1632           if (rhs.isNegative())
1633             pred = CmpIPredicate::slt;
1634           break;
1635         case CmpIPredicate::ult:
1636           // (float)int < -4.4   --> false
1637           // (float)int < 4.4    --> int <= 4
1638           if (rhs.isNegative()) {
1639             rewriter.replaceOpWithNewOp<ConstantIntOp>(op, /*value=*/false,
1640                                                        /*width=*/1);
1641             return success();
1642           }
1643           pred = CmpIPredicate::ule;
1644           break;
1645         case CmpIPredicate::slt:
1646           // (float)int < -4.4   --> int < -4
1647           // (float)int < 4.4    --> int <= 4
1648           if (!rhs.isNegative())
1649             pred = CmpIPredicate::sle;
1650           break;
1651         case CmpIPredicate::ugt:
1652           // (float)int > 4.4    --> int > 4
1653           // (float)int > -4.4   --> true
1654           if (rhs.isNegative()) {
1655             rewriter.replaceOpWithNewOp<ConstantIntOp>(op, /*value=*/true,
1656                                                        /*width=*/1);
1657             return success();
1658           }
1659           break;
1660         case CmpIPredicate::sgt:
1661           // (float)int > 4.4    --> int > 4
1662           // (float)int > -4.4   --> int >= -4
1663           if (rhs.isNegative())
1664             pred = CmpIPredicate::sge;
1665           break;
1666         case CmpIPredicate::uge:
1667           // (float)int >= -4.4   --> true
1668           // (float)int >= 4.4    --> int > 4
1669           if (rhs.isNegative()) {
1670             rewriter.replaceOpWithNewOp<ConstantIntOp>(op, /*value=*/true,
1671                                                        /*width=*/1);
1672             return success();
1673           }
1674           pred = CmpIPredicate::ugt;
1675           break;
1676         case CmpIPredicate::sge:
1677           // (float)int >= -4.4   --> int >= -4
1678           // (float)int >= 4.4    --> int > 4
1679           if (!rhs.isNegative())
1680             pred = CmpIPredicate::sgt;
1681           break;
1682         }
1683       }
1684     }
1685 
1686     // Lower this FP comparison into an appropriate integer version of the
1687     // comparison.
1688     rewriter.replaceOpWithNewOp<CmpIOp>(
1689         op, pred, intVal,
1690         rewriter.create<ConstantOp>(
1691             op.getLoc(), intVal.getType(),
1692             rewriter.getIntegerAttr(intVal.getType(), rhsInt)));
1693     return success();
1694   }
1695 };
1696 
1697 void arith::CmpFOp::getCanonicalizationPatterns(RewritePatternSet &patterns,
1698                                                 MLIRContext *context) {
1699   patterns.insert<CmpFIntToFPConst>(context);
1700 }
1701 
1702 //===----------------------------------------------------------------------===//
1703 // SelectOp
1704 //===----------------------------------------------------------------------===//
1705 
1706 // Transforms a select of a boolean to arithmetic operations
1707 //
1708 //  arith.select %arg, %x, %y : i1
1709 //
1710 //  becomes
1711 //
1712 //  and(%arg, %x) or and(!%arg, %y)
1713 struct SelectI1Simplify : public OpRewritePattern<arith::SelectOp> {
1714   using OpRewritePattern<arith::SelectOp>::OpRewritePattern;
1715 
1716   LogicalResult matchAndRewrite(arith::SelectOp op,
1717                                 PatternRewriter &rewriter) const override {
1718     if (!op.getType().isInteger(1))
1719       return failure();
1720 
1721     Value falseConstant =
1722         rewriter.create<arith::ConstantIntOp>(op.getLoc(), true, 1);
1723     Value notCondition = rewriter.create<arith::XOrIOp>(
1724         op.getLoc(), op.getCondition(), falseConstant);
1725 
1726     Value trueVal = rewriter.create<arith::AndIOp>(
1727         op.getLoc(), op.getCondition(), op.getTrueValue());
1728     Value falseVal = rewriter.create<arith::AndIOp>(op.getLoc(), notCondition,
1729                                                     op.getFalseValue());
1730     rewriter.replaceOpWithNewOp<arith::OrIOp>(op, trueVal, falseVal);
1731     return success();
1732   }
1733 };
1734 
1735 //  select %arg, %c1, %c0 => extui %arg
1736 struct SelectToExtUI : public OpRewritePattern<arith::SelectOp> {
1737   using OpRewritePattern<arith::SelectOp>::OpRewritePattern;
1738 
1739   LogicalResult matchAndRewrite(arith::SelectOp op,
1740                                 PatternRewriter &rewriter) const override {
1741     // Cannot extui i1 to i1, or i1 to f32
1742     if (!op.getType().isa<IntegerType>() || op.getType().isInteger(1))
1743       return failure();
1744 
1745     // select %x, c1, %c0 => extui %arg
1746     if (matchPattern(op.getTrueValue(), m_One()))
1747       if (matchPattern(op.getFalseValue(), m_Zero())) {
1748         rewriter.replaceOpWithNewOp<arith::ExtUIOp>(op, op.getType(),
1749                                                     op.getCondition());
1750         return success();
1751       }
1752 
1753     // select %x, c0, %c1 => extui (xor %arg, true)
1754     if (matchPattern(op.getTrueValue(), m_Zero()))
1755       if (matchPattern(op.getFalseValue(), m_One())) {
1756         rewriter.replaceOpWithNewOp<arith::ExtUIOp>(
1757             op, op.getType(),
1758             rewriter.create<arith::XOrIOp>(
1759                 op.getLoc(), op.getCondition(),
1760                 rewriter.create<arith::ConstantIntOp>(
1761                     op.getLoc(), 1, op.getCondition().getType())));
1762         return success();
1763       }
1764 
1765     return failure();
1766   }
1767 };
1768 
1769 void arith::SelectOp::getCanonicalizationPatterns(RewritePatternSet &results,
1770                                                   MLIRContext *context) {
1771   results.add<SelectI1Simplify, SelectToExtUI>(context);
1772 }
1773 
1774 OpFoldResult arith::SelectOp::fold(ArrayRef<Attribute> operands) {
1775   Value trueVal = getTrueValue();
1776   Value falseVal = getFalseValue();
1777   if (trueVal == falseVal)
1778     return trueVal;
1779 
1780   Value condition = getCondition();
1781 
1782   // select true, %0, %1 => %0
1783   if (matchPattern(condition, m_One()))
1784     return trueVal;
1785 
1786   // select false, %0, %1 => %1
1787   if (matchPattern(condition, m_Zero()))
1788     return falseVal;
1789 
1790   // select %x, true, false => %x
1791   if (getType().isInteger(1))
1792     if (matchPattern(getTrueValue(), m_One()))
1793       if (matchPattern(getFalseValue(), m_Zero()))
1794         return condition;
1795 
1796   if (auto cmp = dyn_cast_or_null<arith::CmpIOp>(condition.getDefiningOp())) {
1797     auto pred = cmp.getPredicate();
1798     if (pred == arith::CmpIPredicate::eq || pred == arith::CmpIPredicate::ne) {
1799       auto cmpLhs = cmp.getLhs();
1800       auto cmpRhs = cmp.getRhs();
1801 
1802       // %0 = arith.cmpi eq, %arg0, %arg1
1803       // %1 = arith.select %0, %arg0, %arg1 => %arg1
1804 
1805       // %0 = arith.cmpi ne, %arg0, %arg1
1806       // %1 = arith.select %0, %arg0, %arg1 => %arg0
1807 
1808       if ((cmpLhs == trueVal && cmpRhs == falseVal) ||
1809           (cmpRhs == trueVal && cmpLhs == falseVal))
1810         return pred == arith::CmpIPredicate::ne ? trueVal : falseVal;
1811     }
1812   }
1813   return nullptr;
1814 }
1815 
1816 ParseResult SelectOp::parse(OpAsmParser &parser, OperationState &result) {
1817   Type conditionType, resultType;
1818   SmallVector<OpAsmParser::UnresolvedOperand, 3> operands;
1819   if (parser.parseOperandList(operands, /*requiredOperandCount=*/3) ||
1820       parser.parseOptionalAttrDict(result.attributes) ||
1821       parser.parseColonType(resultType))
1822     return failure();
1823 
1824   // Check for the explicit condition type if this is a masked tensor or vector.
1825   if (succeeded(parser.parseOptionalComma())) {
1826     conditionType = resultType;
1827     if (parser.parseType(resultType))
1828       return failure();
1829   } else {
1830     conditionType = parser.getBuilder().getI1Type();
1831   }
1832 
1833   result.addTypes(resultType);
1834   return parser.resolveOperands(operands,
1835                                 {conditionType, resultType, resultType},
1836                                 parser.getNameLoc(), result.operands);
1837 }
1838 
1839 void arith::SelectOp::print(OpAsmPrinter &p) {
1840   p << " " << getOperands();
1841   p.printOptionalAttrDict((*this)->getAttrs());
1842   p << " : ";
1843   if (ShapedType condType = getCondition().getType().dyn_cast<ShapedType>())
1844     p << condType << ", ";
1845   p << getType();
1846 }
1847 
1848 LogicalResult arith::SelectOp::verify() {
1849   Type conditionType = getCondition().getType();
1850   if (conditionType.isSignlessInteger(1))
1851     return success();
1852 
1853   // If the result type is a vector or tensor, the type can be a mask with the
1854   // same elements.
1855   Type resultType = getType();
1856   if (!resultType.isa<TensorType, VectorType>())
1857     return emitOpError() << "expected condition to be a signless i1, but got "
1858                          << conditionType;
1859   Type shapedConditionType = getI1SameShape(resultType);
1860   if (conditionType != shapedConditionType) {
1861     return emitOpError() << "expected condition type to have the same shape "
1862                             "as the result type, expected "
1863                          << shapedConditionType << ", but got "
1864                          << conditionType;
1865   }
1866   return success();
1867 }
1868 //===----------------------------------------------------------------------===//
1869 // ShLIOp
1870 //===----------------------------------------------------------------------===//
1871 
1872 OpFoldResult arith::ShLIOp::fold(ArrayRef<Attribute> operands) {
1873   // Don't fold if shifting more than the bit width.
1874   bool bounded = false;
1875   auto result = constFoldBinaryOp<IntegerAttr>(
1876       operands, [&](const APInt &a, const APInt &b) {
1877         bounded = b.ule(b.getBitWidth());
1878         return std::move(a).shl(b);
1879       });
1880   return bounded ? result : Attribute();
1881 }
1882 
1883 //===----------------------------------------------------------------------===//
1884 // ShRUIOp
1885 //===----------------------------------------------------------------------===//
1886 
1887 OpFoldResult arith::ShRUIOp::fold(ArrayRef<Attribute> operands) {
1888   // Don't fold if shifting more than the bit width.
1889   bool bounded = false;
1890   auto result = constFoldBinaryOp<IntegerAttr>(
1891       operands, [&](const APInt &a, const APInt &b) {
1892         bounded = b.ule(b.getBitWidth());
1893         return std::move(a).lshr(b);
1894       });
1895   return bounded ? result : Attribute();
1896 }
1897 
1898 //===----------------------------------------------------------------------===//
1899 // ShRSIOp
1900 //===----------------------------------------------------------------------===//
1901 
1902 OpFoldResult arith::ShRSIOp::fold(ArrayRef<Attribute> operands) {
1903   // Don't fold if shifting more than the bit width.
1904   bool bounded = false;
1905   auto result = constFoldBinaryOp<IntegerAttr>(
1906       operands, [&](const APInt &a, const APInt &b) {
1907         bounded = b.ule(b.getBitWidth());
1908         return std::move(a).ashr(b);
1909       });
1910   return bounded ? result : Attribute();
1911 }
1912 
1913 //===----------------------------------------------------------------------===//
1914 // Atomic Enum
1915 //===----------------------------------------------------------------------===//
1916 
1917 /// Returns the identity value attribute associated with an AtomicRMWKind op.
1918 Attribute mlir::arith::getIdentityValueAttr(AtomicRMWKind kind, Type resultType,
1919                                             OpBuilder &builder, Location loc) {
1920   switch (kind) {
1921   case AtomicRMWKind::maxf:
1922     return builder.getFloatAttr(
1923         resultType,
1924         APFloat::getInf(resultType.cast<FloatType>().getFloatSemantics(),
1925                         /*Negative=*/true));
1926   case AtomicRMWKind::addf:
1927   case AtomicRMWKind::addi:
1928   case AtomicRMWKind::maxu:
1929   case AtomicRMWKind::ori:
1930     return builder.getZeroAttr(resultType);
1931   case AtomicRMWKind::andi:
1932     return builder.getIntegerAttr(
1933         resultType,
1934         APInt::getAllOnes(resultType.cast<IntegerType>().getWidth()));
1935   case AtomicRMWKind::maxs:
1936     return builder.getIntegerAttr(
1937         resultType,
1938         APInt::getSignedMinValue(resultType.cast<IntegerType>().getWidth()));
1939   case AtomicRMWKind::minf:
1940     return builder.getFloatAttr(
1941         resultType,
1942         APFloat::getInf(resultType.cast<FloatType>().getFloatSemantics(),
1943                         /*Negative=*/false));
1944   case AtomicRMWKind::mins:
1945     return builder.getIntegerAttr(
1946         resultType,
1947         APInt::getSignedMaxValue(resultType.cast<IntegerType>().getWidth()));
1948   case AtomicRMWKind::minu:
1949     return builder.getIntegerAttr(
1950         resultType,
1951         APInt::getMaxValue(resultType.cast<IntegerType>().getWidth()));
1952   case AtomicRMWKind::muli:
1953     return builder.getIntegerAttr(resultType, 1);
1954   case AtomicRMWKind::mulf:
1955     return builder.getFloatAttr(resultType, 1);
1956   // TODO: Add remaining reduction operations.
1957   default:
1958     (void)emitOptionalError(loc, "Reduction operation type not supported");
1959     break;
1960   }
1961   return nullptr;
1962 }
1963 
1964 /// Returns the identity value associated with an AtomicRMWKind op.
1965 Value mlir::arith::getIdentityValue(AtomicRMWKind op, Type resultType,
1966                                     OpBuilder &builder, Location loc) {
1967   Attribute attr = getIdentityValueAttr(op, resultType, builder, loc);
1968   return builder.create<arith::ConstantOp>(loc, attr);
1969 }
1970 
1971 /// Return the value obtained by applying the reduction operation kind
1972 /// associated with a binary AtomicRMWKind op to `lhs` and `rhs`.
1973 Value mlir::arith::getReductionOp(AtomicRMWKind op, OpBuilder &builder,
1974                                   Location loc, Value lhs, Value rhs) {
1975   switch (op) {
1976   case AtomicRMWKind::addf:
1977     return builder.create<arith::AddFOp>(loc, lhs, rhs);
1978   case AtomicRMWKind::addi:
1979     return builder.create<arith::AddIOp>(loc, lhs, rhs);
1980   case AtomicRMWKind::mulf:
1981     return builder.create<arith::MulFOp>(loc, lhs, rhs);
1982   case AtomicRMWKind::muli:
1983     return builder.create<arith::MulIOp>(loc, lhs, rhs);
1984   case AtomicRMWKind::maxf:
1985     return builder.create<arith::MaxFOp>(loc, lhs, rhs);
1986   case AtomicRMWKind::minf:
1987     return builder.create<arith::MinFOp>(loc, lhs, rhs);
1988   case AtomicRMWKind::maxs:
1989     return builder.create<arith::MaxSIOp>(loc, lhs, rhs);
1990   case AtomicRMWKind::mins:
1991     return builder.create<arith::MinSIOp>(loc, lhs, rhs);
1992   case AtomicRMWKind::maxu:
1993     return builder.create<arith::MaxUIOp>(loc, lhs, rhs);
1994   case AtomicRMWKind::minu:
1995     return builder.create<arith::MinUIOp>(loc, lhs, rhs);
1996   case AtomicRMWKind::ori:
1997     return builder.create<arith::OrIOp>(loc, lhs, rhs);
1998   case AtomicRMWKind::andi:
1999     return builder.create<arith::AndIOp>(loc, lhs, rhs);
2000   // TODO: Add remaining reduction operations.
2001   default:
2002     (void)emitOptionalError(loc, "Reduction operation type not supported");
2003     break;
2004   }
2005   return nullptr;
2006 }
2007 
2008 //===----------------------------------------------------------------------===//
2009 // TableGen'd op method definitions
2010 //===----------------------------------------------------------------------===//
2011 
2012 #define GET_OP_CLASSES
2013 #include "mlir/Dialect/Arithmetic/IR/ArithmeticOps.cpp.inc"
2014 
2015 //===----------------------------------------------------------------------===//
2016 // TableGen'd enum attribute definitions
2017 //===----------------------------------------------------------------------===//
2018 
2019 #include "mlir/Dialect/Arithmetic/IR/ArithmeticOpsEnums.cpp.inc"
2020