1 //===- ArithmeticOps.cpp - MLIR Arithmetic dialect ops implementation -----===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 
9 #include <utility>
10 
11 #include "mlir/Dialect/Arithmetic/IR/Arithmetic.h"
12 #include "mlir/Dialect/CommonFolders.h"
13 #include "mlir/IR/Builders.h"
14 #include "mlir/IR/Matchers.h"
15 #include "mlir/IR/OpImplementation.h"
16 #include "mlir/IR/PatternMatch.h"
17 #include "mlir/IR/TypeUtilities.h"
18 #include "llvm/ADT/SmallString.h"
19 
20 #include "llvm/ADT/APSInt.h"
21 
22 using namespace mlir;
23 using namespace mlir::arith;
24 
25 //===----------------------------------------------------------------------===//
26 // Pattern helpers
27 //===----------------------------------------------------------------------===//
28 
29 static IntegerAttr addIntegerAttrs(PatternRewriter &builder, Value res,
30                                    Attribute lhs, Attribute rhs) {
31   return builder.getIntegerAttr(res.getType(),
32                                 lhs.cast<IntegerAttr>().getInt() +
33                                     rhs.cast<IntegerAttr>().getInt());
34 }
35 
36 static IntegerAttr subIntegerAttrs(PatternRewriter &builder, Value res,
37                                    Attribute lhs, Attribute rhs) {
38   return builder.getIntegerAttr(res.getType(),
39                                 lhs.cast<IntegerAttr>().getInt() -
40                                     rhs.cast<IntegerAttr>().getInt());
41 }
42 
43 /// Invert an integer comparison predicate.
44 arith::CmpIPredicate arith::invertPredicate(arith::CmpIPredicate pred) {
45   switch (pred) {
46   case arith::CmpIPredicate::eq:
47     return arith::CmpIPredicate::ne;
48   case arith::CmpIPredicate::ne:
49     return arith::CmpIPredicate::eq;
50   case arith::CmpIPredicate::slt:
51     return arith::CmpIPredicate::sge;
52   case arith::CmpIPredicate::sle:
53     return arith::CmpIPredicate::sgt;
54   case arith::CmpIPredicate::sgt:
55     return arith::CmpIPredicate::sle;
56   case arith::CmpIPredicate::sge:
57     return arith::CmpIPredicate::slt;
58   case arith::CmpIPredicate::ult:
59     return arith::CmpIPredicate::uge;
60   case arith::CmpIPredicate::ule:
61     return arith::CmpIPredicate::ugt;
62   case arith::CmpIPredicate::ugt:
63     return arith::CmpIPredicate::ule;
64   case arith::CmpIPredicate::uge:
65     return arith::CmpIPredicate::ult;
66   }
67   llvm_unreachable("unknown cmpi predicate kind");
68 }
69 
70 static arith::CmpIPredicateAttr invertPredicate(arith::CmpIPredicateAttr pred) {
71   return arith::CmpIPredicateAttr::get(pred.getContext(),
72                                        invertPredicate(pred.getValue()));
73 }
74 
75 //===----------------------------------------------------------------------===//
76 // TableGen'd canonicalization patterns
77 //===----------------------------------------------------------------------===//
78 
// The generated file is included inside an anonymous namespace so that
// everything it defines has internal linkage in this translation unit.
// NOTE(review): the pattern classes registered by the
// getCanonicalizationPatterns hooks in this file presumably come from this
// include - confirm against the TableGen source.
namespace {
#include "ArithmeticCanonicalization.inc"
} // namespace
82 
83 //===----------------------------------------------------------------------===//
84 // ConstantOp
85 //===----------------------------------------------------------------------===//
86 
87 void arith::ConstantOp::getAsmResultNames(
88     function_ref<void(Value, StringRef)> setNameFn) {
89   auto type = getType();
90   if (auto intCst = getValue().dyn_cast<IntegerAttr>()) {
91     auto intType = type.dyn_cast<IntegerType>();
92 
93     // Sugar i1 constants with 'true' and 'false'.
94     if (intType && intType.getWidth() == 1)
95       return setNameFn(getResult(), (intCst.getInt() ? "true" : "false"));
96 
97     // Otherwise, build a compex name with the value and type.
98     SmallString<32> specialNameBuffer;
99     llvm::raw_svector_ostream specialName(specialNameBuffer);
100     specialName << 'c' << intCst.getInt();
101     if (intType)
102       specialName << '_' << type;
103     setNameFn(getResult(), specialName.str());
104   } else {
105     setNameFn(getResult(), "cst");
106   }
107 }
108 
109 /// TODO: disallow arith.constant to return anything other than signless integer
110 /// or float like.
111 LogicalResult arith::ConstantOp::verify() {
112   auto type = getType();
113   // The value's type must match the return type.
114   if (getValue().getType() != type) {
115     return emitOpError() << "value type " << getValue().getType()
116                          << " must match return type: " << type;
117   }
118   // Integer values must be signless.
119   if (type.isa<IntegerType>() && !type.cast<IntegerType>().isSignless())
120     return emitOpError("integer return type must be signless");
121   // Any float or elements attribute are acceptable.
122   if (!getValue().isa<IntegerAttr, FloatAttr, ElementsAttr>()) {
123     return emitOpError(
124         "value must be an integer, float, or elements attribute");
125   }
126   return success();
127 }
128 
129 bool arith::ConstantOp::isBuildableWith(Attribute value, Type type) {
130   // The value's type must be the same as the provided type.
131   if (value.getType() != type)
132     return false;
133   // Integer values must be signless.
134   if (type.isa<IntegerType>() && !type.cast<IntegerType>().isSignless())
135     return false;
136   // Integer, float, and element attributes are buildable.
137   return value.isa<IntegerAttr, FloatAttr, ElementsAttr>();
138 }
139 
140 OpFoldResult arith::ConstantOp::fold(ArrayRef<Attribute> operands) {
141   return getValue();
142 }
143 
144 void arith::ConstantIntOp::build(OpBuilder &builder, OperationState &result,
145                                  int64_t value, unsigned width) {
146   auto type = builder.getIntegerType(width);
147   arith::ConstantOp::build(builder, result, type,
148                            builder.getIntegerAttr(type, value));
149 }
150 
151 void arith::ConstantIntOp::build(OpBuilder &builder, OperationState &result,
152                                  int64_t value, Type type) {
153   assert(type.isSignlessInteger() &&
154          "ConstantIntOp can only have signless integer type values");
155   arith::ConstantOp::build(builder, result, type,
156                            builder.getIntegerAttr(type, value));
157 }
158 
159 bool arith::ConstantIntOp::classof(Operation *op) {
160   if (auto constOp = dyn_cast_or_null<arith::ConstantOp>(op))
161     return constOp.getType().isSignlessInteger();
162   return false;
163 }
164 
165 void arith::ConstantFloatOp::build(OpBuilder &builder, OperationState &result,
166                                    const APFloat &value, FloatType type) {
167   arith::ConstantOp::build(builder, result, type,
168                            builder.getFloatAttr(type, value));
169 }
170 
171 bool arith::ConstantFloatOp::classof(Operation *op) {
172   if (auto constOp = dyn_cast_or_null<arith::ConstantOp>(op))
173     return constOp.getType().isa<FloatType>();
174   return false;
175 }
176 
177 void arith::ConstantIndexOp::build(OpBuilder &builder, OperationState &result,
178                                    int64_t value) {
179   arith::ConstantOp::build(builder, result, builder.getIndexType(),
180                            builder.getIndexAttr(value));
181 }
182 
183 bool arith::ConstantIndexOp::classof(Operation *op) {
184   if (auto constOp = dyn_cast_or_null<arith::ConstantOp>(op))
185     return constOp.getType().isIndex();
186   return false;
187 }
188 
189 //===----------------------------------------------------------------------===//
190 // AddIOp
191 //===----------------------------------------------------------------------===//
192 
193 OpFoldResult arith::AddIOp::fold(ArrayRef<Attribute> operands) {
194   // addi(x, 0) -> x
195   if (matchPattern(getRhs(), m_Zero()))
196     return getLhs();
197 
198   // addi(subi(a, b), b) -> a
199   if (auto sub = getLhs().getDefiningOp<SubIOp>())
200     if (getRhs() == sub.getRhs())
201       return sub.getLhs();
202 
203   // addi(b, subi(a, b)) -> a
204   if (auto sub = getRhs().getDefiningOp<SubIOp>())
205     if (getLhs() == sub.getRhs())
206       return sub.getLhs();
207 
208   return constFoldBinaryOp<IntegerAttr>(
209       operands, [](APInt a, const APInt &b) { return std::move(a) + b; });
210 }
211 
212 void arith::AddIOp::getCanonicalizationPatterns(
213     RewritePatternSet &patterns, MLIRContext *context) {
214   patterns.add<AddIAddConstant, AddISubConstantRHS, AddISubConstantLHS>(
215       context);
216 }
217 
218 //===----------------------------------------------------------------------===//
219 // SubIOp
220 //===----------------------------------------------------------------------===//
221 
222 OpFoldResult arith::SubIOp::fold(ArrayRef<Attribute> operands) {
223   // subi(x,x) -> 0
224   if (getOperand(0) == getOperand(1))
225     return Builder(getContext()).getZeroAttr(getType());
226   // subi(x,0) -> x
227   if (matchPattern(getRhs(), m_Zero()))
228     return getLhs();
229 
230   return constFoldBinaryOp<IntegerAttr>(
231       operands, [](APInt a, const APInt &b) { return std::move(a) - b; });
232 }
233 
234 void arith::SubIOp::getCanonicalizationPatterns(
235     RewritePatternSet &patterns, MLIRContext *context) {
236   patterns
237       .add<SubIRHSAddConstant, SubILHSAddConstant, SubIRHSSubConstantRHS,
238            SubIRHSSubConstantLHS, SubILHSSubConstantRHS, SubILHSSubConstantLHS>(
239           context);
240 }
241 
242 //===----------------------------------------------------------------------===//
243 // MulIOp
244 //===----------------------------------------------------------------------===//
245 
246 OpFoldResult arith::MulIOp::fold(ArrayRef<Attribute> operands) {
247   // muli(x, 0) -> 0
248   if (matchPattern(getRhs(), m_Zero()))
249     return getRhs();
250   // muli(x, 1) -> x
251   if (matchPattern(getRhs(), m_One()))
252     return getOperand(0);
253   // TODO: Handle the overflow case.
254 
255   // default folder
256   return constFoldBinaryOp<IntegerAttr>(
257       operands, [](const APInt &a, const APInt &b) { return a * b; });
258 }
259 
260 //===----------------------------------------------------------------------===//
261 // DivUIOp
262 //===----------------------------------------------------------------------===//
263 
264 OpFoldResult arith::DivUIOp::fold(ArrayRef<Attribute> operands) {
265   // divui (x, 1) -> x.
266   if (matchPattern(getRhs(), m_One()))
267     return getLhs();
268 
269   // Don't fold if it would require a division by zero.
270   bool div0 = false;
271   auto result =
272       constFoldBinaryOp<IntegerAttr>(operands, [&](APInt a, const APInt &b) {
273         if (div0 || !b) {
274           div0 = true;
275           return a;
276         }
277         return a.udiv(b);
278       });
279 
280   return div0 ? Attribute() : result;
281 }
282 
283 //===----------------------------------------------------------------------===//
284 // DivSIOp
285 //===----------------------------------------------------------------------===//
286 
287 OpFoldResult arith::DivSIOp::fold(ArrayRef<Attribute> operands) {
288   // divsi (x, 1) -> x.
289   if (matchPattern(getRhs(), m_One()))
290     return getLhs();
291 
292   // Don't fold if it would overflow or if it requires a division by zero.
293   bool overflowOrDiv0 = false;
294   auto result =
295       constFoldBinaryOp<IntegerAttr>(operands, [&](APInt a, const APInt &b) {
296         if (overflowOrDiv0 || !b) {
297           overflowOrDiv0 = true;
298           return a;
299         }
300         return a.sdiv_ov(b, overflowOrDiv0);
301       });
302 
303   return overflowOrDiv0 ? Attribute() : result;
304 }
305 
306 //===----------------------------------------------------------------------===//
307 // Ceil and floor division folding helpers
308 //===----------------------------------------------------------------------===//
309 
310 static APInt signedCeilNonnegInputs(const APInt &a, const APInt &b,
311                                     bool &overflow) {
312   // Returns (a-1)/b + 1
313   APInt one(a.getBitWidth(), 1, true); // Signed value 1.
314   APInt val = a.ssub_ov(one, overflow).sdiv_ov(b, overflow);
315   return val.sadd_ov(one, overflow);
316 }
317 
318 //===----------------------------------------------------------------------===//
319 // CeilDivUIOp
320 //===----------------------------------------------------------------------===//
321 
322 OpFoldResult arith::CeilDivUIOp::fold(ArrayRef<Attribute> operands) {
323   // ceildivui (x, 1) -> x.
324   if (matchPattern(getRhs(), m_One()))
325     return getLhs();
326 
327   bool overflowOrDiv0 = false;
328   auto result =
329       constFoldBinaryOp<IntegerAttr>(operands, [&](APInt a, const APInt &b) {
330         if (overflowOrDiv0 || !b) {
331           overflowOrDiv0 = true;
332           return a;
333         }
334         APInt quotient = a.udiv(b);
335         if (!a.urem(b))
336           return quotient;
337         APInt one(a.getBitWidth(), 1, true);
338         return quotient.uadd_ov(one, overflowOrDiv0);
339       });
340 
341   return overflowOrDiv0 ? Attribute() : result;
342 }
343 
344 //===----------------------------------------------------------------------===//
345 // CeilDivSIOp
346 //===----------------------------------------------------------------------===//
347 
348 OpFoldResult arith::CeilDivSIOp::fold(ArrayRef<Attribute> operands) {
349   // ceildivsi (x, 1) -> x.
350   if (matchPattern(getRhs(), m_One()))
351     return getLhs();
352 
353   // Don't fold if it would overflow or if it requires a division by zero.
354   bool overflowOrDiv0 = false;
355   auto result =
356       constFoldBinaryOp<IntegerAttr>(operands, [&](APInt a, const APInt &b) {
357         if (overflowOrDiv0 || !b) {
358           overflowOrDiv0 = true;
359           return a;
360         }
361         if (!a)
362           return a;
363         // After this point we know that neither a or b are zero.
364         unsigned bits = a.getBitWidth();
365         APInt zero = APInt::getZero(bits);
366         bool aGtZero = a.sgt(zero);
367         bool bGtZero = b.sgt(zero);
368         if (aGtZero && bGtZero) {
369           // Both positive, return ceil(a, b).
370           return signedCeilNonnegInputs(a, b, overflowOrDiv0);
371         }
372         if (!aGtZero && !bGtZero) {
373           // Both negative, return ceil(-a, -b).
374           APInt posA = zero.ssub_ov(a, overflowOrDiv0);
375           APInt posB = zero.ssub_ov(b, overflowOrDiv0);
376           return signedCeilNonnegInputs(posA, posB, overflowOrDiv0);
377         }
378         if (!aGtZero && bGtZero) {
379           // A is negative, b is positive, return - ( -a / b).
380           APInt posA = zero.ssub_ov(a, overflowOrDiv0);
381           APInt div = posA.sdiv_ov(b, overflowOrDiv0);
382           return zero.ssub_ov(div, overflowOrDiv0);
383         }
384         // A is positive, b is negative, return - (a / -b).
385         APInt posB = zero.ssub_ov(b, overflowOrDiv0);
386         APInt div = a.sdiv_ov(posB, overflowOrDiv0);
387         return zero.ssub_ov(div, overflowOrDiv0);
388       });
389 
390   return overflowOrDiv0 ? Attribute() : result;
391 }
392 
393 //===----------------------------------------------------------------------===//
394 // FloorDivSIOp
395 //===----------------------------------------------------------------------===//
396 
397 OpFoldResult arith::FloorDivSIOp::fold(ArrayRef<Attribute> operands) {
398   // floordivsi (x, 1) -> x.
399   if (matchPattern(getRhs(), m_One()))
400     return getLhs();
401 
402   // Don't fold if it would overflow or if it requires a division by zero.
403   bool overflowOrDiv0 = false;
404   auto result =
405       constFoldBinaryOp<IntegerAttr>(operands, [&](APInt a, const APInt &b) {
406         if (overflowOrDiv0 || !b) {
407           overflowOrDiv0 = true;
408           return a;
409         }
410         if (!a)
411           return a;
412         // After this point we know that neither a or b are zero.
413         unsigned bits = a.getBitWidth();
414         APInt zero = APInt::getZero(bits);
415         bool aGtZero = a.sgt(zero);
416         bool bGtZero = b.sgt(zero);
417         if (aGtZero && bGtZero) {
418           // Both positive, return a / b.
419           return a.sdiv_ov(b, overflowOrDiv0);
420         }
421         if (!aGtZero && !bGtZero) {
422           // Both negative, return -a / -b.
423           APInt posA = zero.ssub_ov(a, overflowOrDiv0);
424           APInt posB = zero.ssub_ov(b, overflowOrDiv0);
425           return posA.sdiv_ov(posB, overflowOrDiv0);
426         }
427         if (!aGtZero && bGtZero) {
428           // A is negative, b is positive, return - ceil(-a, b).
429           APInt posA = zero.ssub_ov(a, overflowOrDiv0);
430           APInt ceil = signedCeilNonnegInputs(posA, b, overflowOrDiv0);
431           return zero.ssub_ov(ceil, overflowOrDiv0);
432         }
433         // A is positive, b is negative, return - ceil(a, -b).
434         APInt posB = zero.ssub_ov(b, overflowOrDiv0);
435         APInt ceil = signedCeilNonnegInputs(a, posB, overflowOrDiv0);
436         return zero.ssub_ov(ceil, overflowOrDiv0);
437       });
438 
439   return overflowOrDiv0 ? Attribute() : result;
440 }
441 
442 //===----------------------------------------------------------------------===//
443 // RemUIOp
444 //===----------------------------------------------------------------------===//
445 
446 OpFoldResult arith::RemUIOp::fold(ArrayRef<Attribute> operands) {
447   // remui (x, 1) -> 0.
448   if (matchPattern(getRhs(), m_One()))
449     return Builder(getContext()).getZeroAttr(getType());
450 
451   // Don't fold if it would require a division by zero.
452   bool div0 = false;
453   auto result =
454       constFoldBinaryOp<IntegerAttr>(operands, [&](APInt a, const APInt &b) {
455         if (div0 || b.isNullValue()) {
456           div0 = true;
457           return a;
458         }
459         return a.urem(b);
460       });
461 
462   return div0 ? Attribute() : result;
463 }
464 
465 //===----------------------------------------------------------------------===//
466 // RemSIOp
467 //===----------------------------------------------------------------------===//
468 
469 OpFoldResult arith::RemSIOp::fold(ArrayRef<Attribute> operands) {
470   // remsi (x, 1) -> 0.
471   if (matchPattern(getRhs(), m_One()))
472     return Builder(getContext()).getZeroAttr(getType());
473 
474   // Don't fold if it would require a division by zero.
475   bool div0 = false;
476   auto result =
477       constFoldBinaryOp<IntegerAttr>(operands, [&](APInt a, const APInt &b) {
478         if (div0 || b.isNullValue()) {
479           div0 = true;
480           return a;
481         }
482         return a.srem(b);
483       });
484 
485   return div0 ? Attribute() : result;
486 }
487 
488 //===----------------------------------------------------------------------===//
489 // AndIOp
490 //===----------------------------------------------------------------------===//
491 
492 OpFoldResult arith::AndIOp::fold(ArrayRef<Attribute> operands) {
493   /// and(x, 0) -> 0
494   if (matchPattern(getRhs(), m_Zero()))
495     return getRhs();
496   /// and(x, allOnes) -> x
497   APInt intValue;
498   if (matchPattern(getRhs(), m_ConstantInt(&intValue)) && intValue.isAllOnes())
499     return getLhs();
500 
501   return constFoldBinaryOp<IntegerAttr>(
502       operands, [](APInt a, const APInt &b) { return std::move(a) & b; });
503 }
504 
505 //===----------------------------------------------------------------------===//
506 // OrIOp
507 //===----------------------------------------------------------------------===//
508 
509 OpFoldResult arith::OrIOp::fold(ArrayRef<Attribute> operands) {
510   /// or(x, 0) -> x
511   if (matchPattern(getRhs(), m_Zero()))
512     return getLhs();
513   /// or(x, <all ones>) -> <all ones>
514   if (auto rhsAttr = operands[1].dyn_cast_or_null<IntegerAttr>())
515     if (rhsAttr.getValue().isAllOnes())
516       return rhsAttr;
517 
518   return constFoldBinaryOp<IntegerAttr>(
519       operands, [](APInt a, const APInt &b) { return std::move(a) | b; });
520 }
521 
522 //===----------------------------------------------------------------------===//
523 // XOrIOp
524 //===----------------------------------------------------------------------===//
525 
526 OpFoldResult arith::XOrIOp::fold(ArrayRef<Attribute> operands) {
527   /// xor(x, 0) -> x
528   if (matchPattern(getRhs(), m_Zero()))
529     return getLhs();
530   /// xor(x, x) -> 0
531   if (getLhs() == getRhs())
532     return Builder(getContext()).getZeroAttr(getType());
533   /// xor(xor(x, a), a) -> x
534   if (arith::XOrIOp prev = getLhs().getDefiningOp<arith::XOrIOp>())
535     if (prev.getRhs() == getRhs())
536       return prev.getLhs();
537 
538   return constFoldBinaryOp<IntegerAttr>(
539       operands, [](APInt a, const APInt &b) { return std::move(a) ^ b; });
540 }
541 
542 void arith::XOrIOp::getCanonicalizationPatterns(
543     RewritePatternSet &patterns, MLIRContext *context) {
544   patterns.add<XOrINotCmpI>(context);
545 }
546 
547 //===----------------------------------------------------------------------===//
548 // NegFOp
549 //===----------------------------------------------------------------------===//
550 
551 OpFoldResult arith::NegFOp::fold(ArrayRef<Attribute> operands) {
552   return constFoldUnaryOp<FloatAttr>(operands,
553                                      [](const APFloat &a) { return -a; });
554 }
555 
556 //===----------------------------------------------------------------------===//
557 // AddFOp
558 //===----------------------------------------------------------------------===//
559 
560 OpFoldResult arith::AddFOp::fold(ArrayRef<Attribute> operands) {
561   // addf(x, -0) -> x
562   if (matchPattern(getRhs(), m_NegZeroFloat()))
563     return getLhs();
564 
565   return constFoldBinaryOp<FloatAttr>(
566       operands, [](const APFloat &a, const APFloat &b) { return a + b; });
567 }
568 
569 //===----------------------------------------------------------------------===//
570 // SubFOp
571 //===----------------------------------------------------------------------===//
572 
573 OpFoldResult arith::SubFOp::fold(ArrayRef<Attribute> operands) {
574   // subf(x, +0) -> x
575   if (matchPattern(getRhs(), m_PosZeroFloat()))
576     return getLhs();
577 
578   return constFoldBinaryOp<FloatAttr>(
579       operands, [](const APFloat &a, const APFloat &b) { return a - b; });
580 }
581 
582 //===----------------------------------------------------------------------===//
583 // MaxFOp
584 //===----------------------------------------------------------------------===//
585 
586 OpFoldResult arith::MaxFOp::fold(ArrayRef<Attribute> operands) {
587   assert(operands.size() == 2 && "maxf takes two operands");
588 
589   // maxf(x,x) -> x
590   if (getLhs() == getRhs())
591     return getRhs();
592 
593   // maxf(x, -inf) -> x
594   if (matchPattern(getRhs(), m_NegInfFloat()))
595     return getLhs();
596 
597   return constFoldBinaryOp<FloatAttr>(
598       operands,
599       [](const APFloat &a, const APFloat &b) { return llvm::maximum(a, b); });
600 }
601 
602 //===----------------------------------------------------------------------===//
603 // MaxSIOp
604 //===----------------------------------------------------------------------===//
605 
606 OpFoldResult MaxSIOp::fold(ArrayRef<Attribute> operands) {
607   assert(operands.size() == 2 && "binary operation takes two operands");
608 
609   // maxsi(x,x) -> x
610   if (getLhs() == getRhs())
611     return getRhs();
612 
613   APInt intValue;
614   // maxsi(x,MAX_INT) -> MAX_INT
615   if (matchPattern(getRhs(), m_ConstantInt(&intValue)) &&
616       intValue.isMaxSignedValue())
617     return getRhs();
618 
619   // maxsi(x, MIN_INT) -> x
620   if (matchPattern(getRhs(), m_ConstantInt(&intValue)) &&
621       intValue.isMinSignedValue())
622     return getLhs();
623 
624   return constFoldBinaryOp<IntegerAttr>(operands,
625                                         [](const APInt &a, const APInt &b) {
626                                           return llvm::APIntOps::smax(a, b);
627                                         });
628 }
629 
630 //===----------------------------------------------------------------------===//
631 // MaxUIOp
632 //===----------------------------------------------------------------------===//
633 
634 OpFoldResult MaxUIOp::fold(ArrayRef<Attribute> operands) {
635   assert(operands.size() == 2 && "binary operation takes two operands");
636 
637   // maxui(x,x) -> x
638   if (getLhs() == getRhs())
639     return getRhs();
640 
641   APInt intValue;
642   // maxui(x,MAX_INT) -> MAX_INT
643   if (matchPattern(getRhs(), m_ConstantInt(&intValue)) && intValue.isMaxValue())
644     return getRhs();
645 
646   // maxui(x, MIN_INT) -> x
647   if (matchPattern(getRhs(), m_ConstantInt(&intValue)) && intValue.isMinValue())
648     return getLhs();
649 
650   return constFoldBinaryOp<IntegerAttr>(operands,
651                                         [](const APInt &a, const APInt &b) {
652                                           return llvm::APIntOps::umax(a, b);
653                                         });
654 }
655 
656 //===----------------------------------------------------------------------===//
657 // MinFOp
658 //===----------------------------------------------------------------------===//
659 
660 OpFoldResult arith::MinFOp::fold(ArrayRef<Attribute> operands) {
661   assert(operands.size() == 2 && "minf takes two operands");
662 
663   // minf(x,x) -> x
664   if (getLhs() == getRhs())
665     return getRhs();
666 
667   // minf(x, +inf) -> x
668   if (matchPattern(getRhs(), m_PosInfFloat()))
669     return getLhs();
670 
671   return constFoldBinaryOp<FloatAttr>(
672       operands,
673       [](const APFloat &a, const APFloat &b) { return llvm::minimum(a, b); });
674 }
675 
676 //===----------------------------------------------------------------------===//
677 // MinSIOp
678 //===----------------------------------------------------------------------===//
679 
680 OpFoldResult MinSIOp::fold(ArrayRef<Attribute> operands) {
681   assert(operands.size() == 2 && "binary operation takes two operands");
682 
683   // minsi(x,x) -> x
684   if (getLhs() == getRhs())
685     return getRhs();
686 
687   APInt intValue;
688   // minsi(x,MIN_INT) -> MIN_INT
689   if (matchPattern(getRhs(), m_ConstantInt(&intValue)) &&
690       intValue.isMinSignedValue())
691     return getRhs();
692 
693   // minsi(x, MAX_INT) -> x
694   if (matchPattern(getRhs(), m_ConstantInt(&intValue)) &&
695       intValue.isMaxSignedValue())
696     return getLhs();
697 
698   return constFoldBinaryOp<IntegerAttr>(operands,
699                                         [](const APInt &a, const APInt &b) {
700                                           return llvm::APIntOps::smin(a, b);
701                                         });
702 }
703 
704 //===----------------------------------------------------------------------===//
705 // MinUIOp
706 //===----------------------------------------------------------------------===//
707 
708 OpFoldResult MinUIOp::fold(ArrayRef<Attribute> operands) {
709   assert(operands.size() == 2 && "binary operation takes two operands");
710 
711   // minui(x,x) -> x
712   if (getLhs() == getRhs())
713     return getRhs();
714 
715   APInt intValue;
716   // minui(x,MIN_INT) -> MIN_INT
717   if (matchPattern(getRhs(), m_ConstantInt(&intValue)) && intValue.isMinValue())
718     return getRhs();
719 
720   // minui(x, MAX_INT) -> x
721   if (matchPattern(getRhs(), m_ConstantInt(&intValue)) && intValue.isMaxValue())
722     return getLhs();
723 
724   return constFoldBinaryOp<IntegerAttr>(operands,
725                                         [](const APInt &a, const APInt &b) {
726                                           return llvm::APIntOps::umin(a, b);
727                                         });
728 }
729 
730 //===----------------------------------------------------------------------===//
731 // MulFOp
732 //===----------------------------------------------------------------------===//
733 
734 OpFoldResult arith::MulFOp::fold(ArrayRef<Attribute> operands) {
735   // mulf(x, 1) -> x
736   if (matchPattern(getRhs(), m_OneFloat()))
737     return getLhs();
738 
739   return constFoldBinaryOp<FloatAttr>(
740       operands, [](const APFloat &a, const APFloat &b) { return a * b; });
741 }
742 
743 //===----------------------------------------------------------------------===//
744 // DivFOp
745 //===----------------------------------------------------------------------===//
746 
747 OpFoldResult arith::DivFOp::fold(ArrayRef<Attribute> operands) {
748   // divf(x, 1) -> x
749   if (matchPattern(getRhs(), m_OneFloat()))
750     return getLhs();
751 
752   return constFoldBinaryOp<FloatAttr>(
753       operands, [](const APFloat &a, const APFloat &b) { return a / b; });
754 }
755 
756 //===----------------------------------------------------------------------===//
757 // Utility functions for verifying cast ops
758 //===----------------------------------------------------------------------===//
759 
/// Tag type used to pass a pack of types as an ordinary function argument so
/// that two independent packs can be deduced by one function template. It is
/// a pointer to a tuple of the types; the helpers in this file only ever pass
/// value-initialized (null) instances and take them as unnamed parameters.
template <typename... Types>
using type_list = std::tuple<Types...> *;
762 
763 /// Returns a non-null type only if the provided type is one of the allowed
764 /// types or one of the allowed shaped types of the allowed types. Returns the
765 /// element type if a valid shaped type is provided.
766 template <typename... ShapedTypes, typename... ElementTypes>
767 static Type getUnderlyingType(Type type, type_list<ShapedTypes...>,
768                               type_list<ElementTypes...>) {
769   if (type.isa<ShapedType>() && !type.isa<ShapedTypes...>())
770     return {};
771 
772   auto underlyingType = getElementTypeOrSelf(type);
773   if (!underlyingType.isa<ElementTypes...>())
774     return {};
775 
776   return underlyingType;
777 }
778 
779 /// Get allowed underlying types for vectors and tensors.
780 template <typename... ElementTypes>
781 static Type getTypeIfLike(Type type) {
782   return getUnderlyingType(type, type_list<VectorType, TensorType>(),
783                            type_list<ElementTypes...>());
784 }
785 
786 /// Get allowed underlying types for vectors, tensors, and memrefs.
787 template <typename... ElementTypes>
788 static Type getTypeIfLikeOrMemRef(Type type) {
789   return getUnderlyingType(type,
790                            type_list<VectorType, TensorType, MemRefType>(),
791                            type_list<ElementTypes...>());
792 }
793 
794 static bool areValidCastInputsAndOutputs(TypeRange inputs, TypeRange outputs) {
795   return inputs.size() == 1 && outputs.size() == 1 &&
796          succeeded(verifyCompatibleShapes(inputs.front(), outputs.front()));
797 }
798 
799 //===----------------------------------------------------------------------===//
800 // Verifiers for integer and floating point extension/truncation ops
801 //===----------------------------------------------------------------------===//
802 
803 // Extend ops can only extend to a wider type.
804 template <typename ValType, typename Op>
805 static LogicalResult verifyExtOp(Op op) {
806   Type srcType = getElementTypeOrSelf(op.getIn().getType());
807   Type dstType = getElementTypeOrSelf(op.getType());
808 
809   if (srcType.cast<ValType>().getWidth() >= dstType.cast<ValType>().getWidth())
810     return op.emitError("result type ")
811            << dstType << " must be wider than operand type " << srcType;
812 
813   return success();
814 }
815 
816 // Truncate ops can only truncate to a shorter type.
817 template <typename ValType, typename Op>
818 static LogicalResult verifyTruncateOp(Op op) {
819   Type srcType = getElementTypeOrSelf(op.getIn().getType());
820   Type dstType = getElementTypeOrSelf(op.getType());
821 
822   if (srcType.cast<ValType>().getWidth() <= dstType.cast<ValType>().getWidth())
823     return op.emitError("result type ")
824            << dstType << " must be shorter than operand type " << srcType;
825 
826   return success();
827 }
828 
829 /// Validate a cast that changes the width of a type.
830 template <template <typename> class WidthComparator, typename... ElementTypes>
831 static bool checkWidthChangeCast(TypeRange inputs, TypeRange outputs) {
832   if (!areValidCastInputsAndOutputs(inputs, outputs))
833     return false;
834 
835   auto srcType = getTypeIfLike<ElementTypes...>(inputs.front());
836   auto dstType = getTypeIfLike<ElementTypes...>(outputs.front());
837   if (!srcType || !dstType)
838     return false;
839 
840   return WidthComparator<unsigned>()(dstType.getIntOrFloatBitWidth(),
841                                      srcType.getIntOrFloatBitWidth());
842 }
843 
844 //===----------------------------------------------------------------------===//
845 // ExtUIOp
846 //===----------------------------------------------------------------------===//
847 
848 OpFoldResult arith::ExtUIOp::fold(ArrayRef<Attribute> operands) {
849   if (auto lhs = getIn().getDefiningOp<ExtUIOp>()) {
850     getInMutable().assign(lhs.getIn());
851     return getResult();
852   }
853   Type resType = getType();
854   unsigned bitWidth;
855   if (auto shapedType = resType.dyn_cast<ShapedType>())
856     bitWidth = shapedType.getElementTypeBitWidth();
857   else
858     bitWidth = resType.getIntOrFloatBitWidth();
859   return constFoldCastOp<IntegerAttr, IntegerAttr>(
860       operands, getType(), [bitWidth](const APInt &a, bool &castStatus) {
861         return a.zext(bitWidth);
862       });
863 }
864 
865 bool arith::ExtUIOp::areCastCompatible(TypeRange inputs, TypeRange outputs) {
866   return checkWidthChangeCast<std::greater, IntegerType>(inputs, outputs);
867 }
868 
869 LogicalResult arith::ExtUIOp::verify() {
870   return verifyExtOp<IntegerType>(*this);
871 }
872 
873 //===----------------------------------------------------------------------===//
874 // ExtSIOp
875 //===----------------------------------------------------------------------===//
876 
877 OpFoldResult arith::ExtSIOp::fold(ArrayRef<Attribute> operands) {
878   if (auto lhs = getIn().getDefiningOp<ExtSIOp>()) {
879     getInMutable().assign(lhs.getIn());
880     return getResult();
881   }
882   Type resType = getType();
883   unsigned bitWidth;
884   if (auto shapedType = resType.dyn_cast<ShapedType>())
885     bitWidth = shapedType.getElementTypeBitWidth();
886   else
887     bitWidth = resType.getIntOrFloatBitWidth();
888   return constFoldCastOp<IntegerAttr, IntegerAttr>(
889       operands, getType(), [bitWidth](const APInt &a, bool &castStatus) {
890         return a.sext(bitWidth);
891       });
892 }
893 
894 bool arith::ExtSIOp::areCastCompatible(TypeRange inputs, TypeRange outputs) {
895   return checkWidthChangeCast<std::greater, IntegerType>(inputs, outputs);
896 }
897 
/// Registers the canonicalization patterns for `extsi`; currently only
/// `ExtSIOfExtUI` (defined elsewhere) is added to `patterns`.
void arith::ExtSIOp::getCanonicalizationPatterns(
    RewritePatternSet &patterns, MLIRContext *context) {
  patterns.add<ExtSIOfExtUI>(context);
}
902 
903 LogicalResult arith::ExtSIOp::verify() {
904   return verifyExtOp<IntegerType>(*this);
905 }
906 
907 //===----------------------------------------------------------------------===//
908 // ExtFOp
909 //===----------------------------------------------------------------------===//
910 
911 bool arith::ExtFOp::areCastCompatible(TypeRange inputs, TypeRange outputs) {
912   return checkWidthChangeCast<std::greater, FloatType>(inputs, outputs);
913 }
914 
915 LogicalResult arith::ExtFOp::verify() { return verifyExtOp<FloatType>(*this); }
916 
917 //===----------------------------------------------------------------------===//
918 // TruncIOp
919 //===----------------------------------------------------------------------===//
920 
921 OpFoldResult arith::TruncIOp::fold(ArrayRef<Attribute> operands) {
922   assert(operands.size() == 1 && "unary operation takes one operand");
923 
924   // trunci(zexti(a)) -> a
925   // trunci(sexti(a)) -> a
926   if (matchPattern(getOperand(), m_Op<arith::ExtUIOp>()) ||
927       matchPattern(getOperand(), m_Op<arith::ExtSIOp>()))
928     return getOperand().getDefiningOp()->getOperand(0);
929 
930   // trunci(trunci(a)) -> trunci(a))
931   if (matchPattern(getOperand(), m_Op<arith::TruncIOp>())) {
932     setOperand(getOperand().getDefiningOp()->getOperand(0));
933     return getResult();
934   }
935 
936   Type resType = getType();
937   unsigned bitWidth;
938   if (auto shapedType = resType.dyn_cast<ShapedType>())
939     bitWidth = shapedType.getElementTypeBitWidth();
940   else
941     bitWidth = resType.getIntOrFloatBitWidth();
942 
943   return constFoldCastOp<IntegerAttr, IntegerAttr>(
944       operands, getType(), [bitWidth](const APInt &a, bool &castStatus) {
945         return a.trunc(bitWidth);
946       });
947 }
948 
949 bool arith::TruncIOp::areCastCompatible(TypeRange inputs, TypeRange outputs) {
950   return checkWidthChangeCast<std::less, IntegerType>(inputs, outputs);
951 }
952 
953 LogicalResult arith::TruncIOp::verify() {
954   return verifyTruncateOp<IntegerType>(*this);
955 }
956 
957 //===----------------------------------------------------------------------===//
958 // TruncFOp
959 //===----------------------------------------------------------------------===//
960 
961 /// Perform safe const propagation for truncf, i.e. only propagate if FP value
962 /// can be represented without precision loss or rounding.
963 OpFoldResult arith::TruncFOp::fold(ArrayRef<Attribute> operands) {
964   assert(operands.size() == 1 && "unary operation takes one operand");
965 
966   auto constOperand = operands.front();
967   if (!constOperand || !constOperand.isa<FloatAttr>())
968     return {};
969 
970   // Convert to target type via 'double'.
971   double sourceValue =
972       constOperand.dyn_cast<FloatAttr>().getValue().convertToDouble();
973   auto targetAttr = FloatAttr::get(getType(), sourceValue);
974 
975   // Propagate if constant's value does not change after truncation.
976   if (sourceValue == targetAttr.getValue().convertToDouble())
977     return targetAttr;
978 
979   return {};
980 }
981 
982 bool arith::TruncFOp::areCastCompatible(TypeRange inputs, TypeRange outputs) {
983   return checkWidthChangeCast<std::less, FloatType>(inputs, outputs);
984 }
985 
986 LogicalResult arith::TruncFOp::verify() {
987   return verifyTruncateOp<FloatType>(*this);
988 }
989 
990 //===----------------------------------------------------------------------===//
991 // AndIOp
992 //===----------------------------------------------------------------------===//
993 
994 void arith::AndIOp::getCanonicalizationPatterns(
995     RewritePatternSet &patterns, MLIRContext *context) {
996   patterns.add<AndOfExtUI, AndOfExtSI>(context);
997 }
998 
999 //===----------------------------------------------------------------------===//
1000 // OrIOp
1001 //===----------------------------------------------------------------------===//
1002 
1003 void arith::OrIOp::getCanonicalizationPatterns(
1004     RewritePatternSet &patterns, MLIRContext *context) {
1005   patterns.add<OrOfExtUI, OrOfExtSI>(context);
1006 }
1007 
1008 //===----------------------------------------------------------------------===//
1009 // Verifiers for casts between integers and floats.
1010 //===----------------------------------------------------------------------===//
1011 
1012 template <typename From, typename To>
1013 static bool checkIntFloatCast(TypeRange inputs, TypeRange outputs) {
1014   if (!areValidCastInputsAndOutputs(inputs, outputs))
1015     return false;
1016 
1017   auto srcType = getTypeIfLike<From>(inputs.front());
1018   auto dstType = getTypeIfLike<To>(outputs.back());
1019 
1020   return srcType && dstType;
1021 }
1022 
1023 //===----------------------------------------------------------------------===//
1024 // UIToFPOp
1025 //===----------------------------------------------------------------------===//
1026 
1027 bool arith::UIToFPOp::areCastCompatible(TypeRange inputs, TypeRange outputs) {
1028   return checkIntFloatCast<IntegerType, FloatType>(inputs, outputs);
1029 }
1030 
1031 OpFoldResult arith::UIToFPOp::fold(ArrayRef<Attribute> operands) {
1032   Type resType = getType();
1033   Type resEleType;
1034   if (auto shapedType = resType.dyn_cast<ShapedType>())
1035     resEleType = shapedType.getElementType();
1036   else
1037     resEleType = resType;
1038   return constFoldCastOp<IntegerAttr, FloatAttr>(
1039       operands, getType(), [&resEleType](const APInt &a, bool &castStatus) {
1040         FloatType floatTy = resEleType.cast<FloatType>();
1041         APFloat apf(floatTy.getFloatSemantics(),
1042                     APInt::getZero(floatTy.getWidth()));
1043         apf.convertFromAPInt(a, /*IsSigned=*/false,
1044                              APFloat::rmNearestTiesToEven);
1045         return apf;
1046       });
1047 }
1048 
1049 //===----------------------------------------------------------------------===//
1050 // SIToFPOp
1051 //===----------------------------------------------------------------------===//
1052 
1053 bool arith::SIToFPOp::areCastCompatible(TypeRange inputs, TypeRange outputs) {
1054   return checkIntFloatCast<IntegerType, FloatType>(inputs, outputs);
1055 }
1056 
1057 OpFoldResult arith::SIToFPOp::fold(ArrayRef<Attribute> operands) {
1058   Type resType = getType();
1059   Type resEleType;
1060   if (auto shapedType = resType.dyn_cast<ShapedType>())
1061     resEleType = shapedType.getElementType();
1062   else
1063     resEleType = resType;
1064   return constFoldCastOp<IntegerAttr, FloatAttr>(
1065       operands, getType(), [&resEleType](const APInt &a, bool &castStatus) {
1066         FloatType floatTy = resEleType.cast<FloatType>();
1067         APFloat apf(floatTy.getFloatSemantics(),
1068                     APInt::getZero(floatTy.getWidth()));
1069         apf.convertFromAPInt(a, /*IsSigned=*/true,
1070                              APFloat::rmNearestTiesToEven);
1071         return apf;
1072       });
1073 }
1074 //===----------------------------------------------------------------------===//
1075 // FPToUIOp
1076 //===----------------------------------------------------------------------===//
1077 
1078 bool arith::FPToUIOp::areCastCompatible(TypeRange inputs, TypeRange outputs) {
1079   return checkIntFloatCast<FloatType, IntegerType>(inputs, outputs);
1080 }
1081 
1082 OpFoldResult arith::FPToUIOp::fold(ArrayRef<Attribute> operands) {
1083   Type resType = getType();
1084   Type resEleType;
1085   if (auto shapedType = resType.dyn_cast<ShapedType>())
1086     resEleType = shapedType.getElementType();
1087   else
1088     resEleType = resType;
1089   return constFoldCastOp<FloatAttr, IntegerAttr>(
1090       operands, getType(), [&resEleType](const APFloat &a, bool &castStatus) {
1091         IntegerType intTy = resEleType.cast<IntegerType>();
1092         bool ignored;
1093         APSInt api(intTy.getWidth(), /*isUnsigned=*/true);
1094         castStatus = APFloat::opInvalidOp !=
1095                      a.convertToInteger(api, APFloat::rmTowardZero, &ignored);
1096         return api;
1097       });
1098 }
1099 
1100 //===----------------------------------------------------------------------===//
1101 // FPToSIOp
1102 //===----------------------------------------------------------------------===//
1103 
1104 bool arith::FPToSIOp::areCastCompatible(TypeRange inputs, TypeRange outputs) {
1105   return checkIntFloatCast<FloatType, IntegerType>(inputs, outputs);
1106 }
1107 
1108 OpFoldResult arith::FPToSIOp::fold(ArrayRef<Attribute> operands) {
1109   Type resType = getType();
1110   Type resEleType;
1111   if (auto shapedType = resType.dyn_cast<ShapedType>())
1112     resEleType = shapedType.getElementType();
1113   else
1114     resEleType = resType;
1115   return constFoldCastOp<FloatAttr, IntegerAttr>(
1116       operands, getType(), [&resEleType](const APFloat &a, bool &castStatus) {
1117         IntegerType intTy = resEleType.cast<IntegerType>();
1118         bool ignored;
1119         APSInt api(intTy.getWidth(), /*isUnsigned=*/false);
1120         castStatus = APFloat::opInvalidOp !=
1121                      a.convertToInteger(api, APFloat::rmTowardZero, &ignored);
1122         return api;
1123       });
1124 }
1125 
1126 //===----------------------------------------------------------------------===//
1127 // IndexCastOp
1128 //===----------------------------------------------------------------------===//
1129 
1130 bool arith::IndexCastOp::areCastCompatible(TypeRange inputs,
1131                                            TypeRange outputs) {
1132   if (!areValidCastInputsAndOutputs(inputs, outputs))
1133     return false;
1134 
1135   auto srcType = getTypeIfLikeOrMemRef<IntegerType, IndexType>(inputs.front());
1136   auto dstType = getTypeIfLikeOrMemRef<IntegerType, IndexType>(outputs.front());
1137   if (!srcType || !dstType)
1138     return false;
1139 
1140   return (srcType.isIndex() && dstType.isSignlessInteger()) ||
1141          (srcType.isSignlessInteger() && dstType.isIndex());
1142 }
1143 
1144 OpFoldResult arith::IndexCastOp::fold(ArrayRef<Attribute> operands) {
1145   // index_cast(constant) -> constant
1146   // A little hack because we go through int. Otherwise, the size of the
1147   // constant might need to change.
1148   if (auto value = operands[0].dyn_cast_or_null<IntegerAttr>())
1149     return IntegerAttr::get(getType(), value.getInt());
1150 
1151   return {};
1152 }
1153 
1154 void arith::IndexCastOp::getCanonicalizationPatterns(
1155     RewritePatternSet &patterns, MLIRContext *context) {
1156   patterns.add<IndexCastOfIndexCast, IndexCastOfExtSI>(context);
1157 }
1158 
1159 //===----------------------------------------------------------------------===//
1160 // BitcastOp
1161 //===----------------------------------------------------------------------===//
1162 
1163 bool arith::BitcastOp::areCastCompatible(TypeRange inputs, TypeRange outputs) {
1164   if (!areValidCastInputsAndOutputs(inputs, outputs))
1165     return false;
1166 
1167   auto srcType =
1168       getTypeIfLikeOrMemRef<IntegerType, IndexType, FloatType>(inputs.front());
1169   auto dstType =
1170       getTypeIfLikeOrMemRef<IntegerType, IndexType, FloatType>(outputs.front());
1171   if (!srcType || !dstType)
1172     return false;
1173 
1174   return srcType.getIntOrFloatBitWidth() == dstType.getIntOrFloatBitWidth();
1175 }
1176 
1177 OpFoldResult arith::BitcastOp::fold(ArrayRef<Attribute> operands) {
1178   assert(operands.size() == 1 && "bitcast op expects 1 operand");
1179 
1180   auto resType = getType();
1181   auto operand = operands[0];
1182   if (!operand)
1183     return {};
1184 
1185   /// Bitcast dense elements.
1186   if (auto denseAttr = operand.dyn_cast_or_null<DenseElementsAttr>())
1187     return denseAttr.bitcast(resType.cast<ShapedType>().getElementType());
1188   /// Other shaped types unhandled.
1189   if (resType.isa<ShapedType>())
1190     return {};
1191 
1192   /// Bitcast integer or float to integer or float.
1193   APInt bits = operand.isa<FloatAttr>()
1194                    ? operand.cast<FloatAttr>().getValue().bitcastToAPInt()
1195                    : operand.cast<IntegerAttr>().getValue();
1196 
1197   if (auto resFloatType = resType.dyn_cast<FloatType>())
1198     return FloatAttr::get(resType,
1199                           APFloat(resFloatType.getFloatSemantics(), bits));
1200   return IntegerAttr::get(resType, bits);
1201 }
1202 
/// Registers the canonicalization patterns for `bitcast`; currently only
/// `BitcastOfBitcast` (defined elsewhere) is added to `patterns`.
void arith::BitcastOp::getCanonicalizationPatterns(
    RewritePatternSet &patterns, MLIRContext *context) {
  patterns.add<BitcastOfBitcast>(context);
}
1207 
1208 //===----------------------------------------------------------------------===//
1209 // Helpers for compare ops
1210 //===----------------------------------------------------------------------===//
1211 
1212 /// Return the type of the same shape (scalar, vector or tensor) containing i1.
1213 static Type getI1SameShape(Type type) {
1214   auto i1Type = IntegerType::get(type.getContext(), 1);
1215   if (auto tensorType = type.dyn_cast<RankedTensorType>())
1216     return RankedTensorType::get(tensorType.getShape(), i1Type);
1217   if (type.isa<UnrankedTensorType>())
1218     return UnrankedTensorType::get(i1Type);
1219   if (auto vectorType = type.dyn_cast<VectorType>())
1220     return VectorType::get(vectorType.getShape(), i1Type,
1221                            vectorType.getNumScalableDims());
1222   return i1Type;
1223 }
1224 
1225 //===----------------------------------------------------------------------===//
1226 // CmpIOp
1227 //===----------------------------------------------------------------------===//
1228 
1229 /// Compute `lhs` `pred` `rhs`, where `pred` is one of the known integer
1230 /// comparison predicates.
1231 bool mlir::arith::applyCmpPredicate(arith::CmpIPredicate predicate,
1232                                     const APInt &lhs, const APInt &rhs) {
1233   switch (predicate) {
1234   case arith::CmpIPredicate::eq:
1235     return lhs.eq(rhs);
1236   case arith::CmpIPredicate::ne:
1237     return lhs.ne(rhs);
1238   case arith::CmpIPredicate::slt:
1239     return lhs.slt(rhs);
1240   case arith::CmpIPredicate::sle:
1241     return lhs.sle(rhs);
1242   case arith::CmpIPredicate::sgt:
1243     return lhs.sgt(rhs);
1244   case arith::CmpIPredicate::sge:
1245     return lhs.sge(rhs);
1246   case arith::CmpIPredicate::ult:
1247     return lhs.ult(rhs);
1248   case arith::CmpIPredicate::ule:
1249     return lhs.ule(rhs);
1250   case arith::CmpIPredicate::ugt:
1251     return lhs.ugt(rhs);
1252   case arith::CmpIPredicate::uge:
1253     return lhs.uge(rhs);
1254   }
1255   llvm_unreachable("unknown cmpi predicate kind");
1256 }
1257 
1258 /// Returns true if the predicate is true for two equal operands.
1259 static bool applyCmpPredicateToEqualOperands(arith::CmpIPredicate predicate) {
1260   switch (predicate) {
1261   case arith::CmpIPredicate::eq:
1262   case arith::CmpIPredicate::sle:
1263   case arith::CmpIPredicate::sge:
1264   case arith::CmpIPredicate::ule:
1265   case arith::CmpIPredicate::uge:
1266     return true;
1267   case arith::CmpIPredicate::ne:
1268   case arith::CmpIPredicate::slt:
1269   case arith::CmpIPredicate::sgt:
1270   case arith::CmpIPredicate::ult:
1271   case arith::CmpIPredicate::ugt:
1272     return false;
1273   }
1274   llvm_unreachable("unknown cmpi predicate kind");
1275 }
1276 
1277 static Attribute getBoolAttribute(Type type, MLIRContext *ctx, bool value) {
1278   auto boolAttr = BoolAttr::get(ctx, value);
1279   ShapedType shapedType = type.dyn_cast_or_null<ShapedType>();
1280   if (!shapedType)
1281     return boolAttr;
1282   return DenseElementsAttr::get(shapedType, boolAttr);
1283 }
1284 
1285 OpFoldResult arith::CmpIOp::fold(ArrayRef<Attribute> operands) {
1286   assert(operands.size() == 2 && "cmpi takes two operands");
1287 
1288   // cmpi(pred, x, x)
1289   if (getLhs() == getRhs()) {
1290     auto val = applyCmpPredicateToEqualOperands(getPredicate());
1291     return getBoolAttribute(getType(), getContext(), val);
1292   }
1293 
1294   if (matchPattern(getRhs(), m_Zero())) {
1295     if (auto extOp = getLhs().getDefiningOp<ExtSIOp>()) {
1296       if (extOp.getOperand().getType().cast<IntegerType>().getWidth() == 1) {
1297         // extsi(%x : i1 -> iN) != 0  ->  %x
1298         if (getPredicate() == arith::CmpIPredicate::ne) {
1299           return extOp.getOperand();
1300         }
1301       }
1302     }
1303     if (auto extOp = getLhs().getDefiningOp<ExtUIOp>()) {
1304       if (extOp.getOperand().getType().cast<IntegerType>().getWidth() == 1) {
1305         // extui(%x : i1 -> iN) != 0  ->  %x
1306         if (getPredicate() == arith::CmpIPredicate::ne) {
1307           return extOp.getOperand();
1308         }
1309       }
1310     }
1311   }
1312 
1313   auto lhs = operands.front().dyn_cast_or_null<IntegerAttr>();
1314   auto rhs = operands.back().dyn_cast_or_null<IntegerAttr>();
1315   if (!lhs || !rhs)
1316     return {};
1317 
1318   auto val = applyCmpPredicate(getPredicate(), lhs.getValue(), rhs.getValue());
1319   return BoolAttr::get(getContext(), val);
1320 }
1321 
1322 void arith::CmpIOp::getCanonicalizationPatterns(RewritePatternSet &patterns,
1323                                                 MLIRContext *context) {
1324   patterns.insert<CmpIExtSI, CmpIExtUI>(context);
1325 }
1326 
1327 //===----------------------------------------------------------------------===//
1328 // CmpFOp
1329 //===----------------------------------------------------------------------===//
1330 
1331 /// Compute `lhs` `pred` `rhs`, where `pred` is one of the known floating point
1332 /// comparison predicates.
1333 bool mlir::arith::applyCmpPredicate(arith::CmpFPredicate predicate,
1334                                     const APFloat &lhs, const APFloat &rhs) {
1335   auto cmpResult = lhs.compare(rhs);
1336   switch (predicate) {
1337   case arith::CmpFPredicate::AlwaysFalse:
1338     return false;
1339   case arith::CmpFPredicate::OEQ:
1340     return cmpResult == APFloat::cmpEqual;
1341   case arith::CmpFPredicate::OGT:
1342     return cmpResult == APFloat::cmpGreaterThan;
1343   case arith::CmpFPredicate::OGE:
1344     return cmpResult == APFloat::cmpGreaterThan ||
1345            cmpResult == APFloat::cmpEqual;
1346   case arith::CmpFPredicate::OLT:
1347     return cmpResult == APFloat::cmpLessThan;
1348   case arith::CmpFPredicate::OLE:
1349     return cmpResult == APFloat::cmpLessThan || cmpResult == APFloat::cmpEqual;
1350   case arith::CmpFPredicate::ONE:
1351     return cmpResult != APFloat::cmpUnordered && cmpResult != APFloat::cmpEqual;
1352   case arith::CmpFPredicate::ORD:
1353     return cmpResult != APFloat::cmpUnordered;
1354   case arith::CmpFPredicate::UEQ:
1355     return cmpResult == APFloat::cmpUnordered || cmpResult == APFloat::cmpEqual;
1356   case arith::CmpFPredicate::UGT:
1357     return cmpResult == APFloat::cmpUnordered ||
1358            cmpResult == APFloat::cmpGreaterThan;
1359   case arith::CmpFPredicate::UGE:
1360     return cmpResult == APFloat::cmpUnordered ||
1361            cmpResult == APFloat::cmpGreaterThan ||
1362            cmpResult == APFloat::cmpEqual;
1363   case arith::CmpFPredicate::ULT:
1364     return cmpResult == APFloat::cmpUnordered ||
1365            cmpResult == APFloat::cmpLessThan;
1366   case arith::CmpFPredicate::ULE:
1367     return cmpResult == APFloat::cmpUnordered ||
1368            cmpResult == APFloat::cmpLessThan || cmpResult == APFloat::cmpEqual;
1369   case arith::CmpFPredicate::UNE:
1370     return cmpResult != APFloat::cmpEqual;
1371   case arith::CmpFPredicate::UNO:
1372     return cmpResult == APFloat::cmpUnordered;
1373   case arith::CmpFPredicate::AlwaysTrue:
1374     return true;
1375   }
1376   llvm_unreachable("unknown cmpf predicate kind");
1377 }
1378 
1379 OpFoldResult arith::CmpFOp::fold(ArrayRef<Attribute> operands) {
1380   assert(operands.size() == 2 && "cmpf takes two operands");
1381 
1382   auto lhs = operands.front().dyn_cast_or_null<FloatAttr>();
1383   auto rhs = operands.back().dyn_cast_or_null<FloatAttr>();
1384 
1385   // If one operand is NaN, making them both NaN does not change the result.
1386   if (lhs && lhs.getValue().isNaN())
1387     rhs = lhs;
1388   if (rhs && rhs.getValue().isNaN())
1389     lhs = rhs;
1390 
1391   if (!lhs || !rhs)
1392     return {};
1393 
1394   auto val = applyCmpPredicate(getPredicate(), lhs.getValue(), rhs.getValue());
1395   return BoolAttr::get(getContext(), val);
1396 }
1397 
1398 class CmpFIntToFPConst final : public OpRewritePattern<CmpFOp> {
1399 public:
1400   using OpRewritePattern<CmpFOp>::OpRewritePattern;
1401 
1402   static CmpIPredicate convertToIntegerPredicate(CmpFPredicate pred,
1403                                                  bool isUnsigned) {
1404     using namespace arith;
1405     switch (pred) {
1406     case CmpFPredicate::UEQ:
1407     case CmpFPredicate::OEQ:
1408       return CmpIPredicate::eq;
1409     case CmpFPredicate::UGT:
1410     case CmpFPredicate::OGT:
1411       return isUnsigned ? CmpIPredicate::ugt : CmpIPredicate::sgt;
1412     case CmpFPredicate::UGE:
1413     case CmpFPredicate::OGE:
1414       return isUnsigned ? CmpIPredicate::uge : CmpIPredicate::sge;
1415     case CmpFPredicate::ULT:
1416     case CmpFPredicate::OLT:
1417       return isUnsigned ? CmpIPredicate::ult : CmpIPredicate::slt;
1418     case CmpFPredicate::ULE:
1419     case CmpFPredicate::OLE:
1420       return isUnsigned ? CmpIPredicate::ule : CmpIPredicate::sle;
1421     case CmpFPredicate::UNE:
1422     case CmpFPredicate::ONE:
1423       return CmpIPredicate::ne;
1424     default:
1425       llvm_unreachable("Unexpected predicate!");
1426     }
1427   }
1428 
1429   LogicalResult matchAndRewrite(CmpFOp op,
1430                                 PatternRewriter &rewriter) const override {
1431     FloatAttr flt;
1432     if (!matchPattern(op.getRhs(), m_Constant(&flt)))
1433       return failure();
1434 
1435     const APFloat &rhs = flt.getValue();
1436 
1437     // Don't attempt to fold a nan.
1438     if (rhs.isNaN())
1439       return failure();
1440 
1441     // Get the width of the mantissa.  We don't want to hack on conversions that
1442     // might lose information from the integer, e.g. "i64 -> float"
1443     FloatType floatTy = op.getRhs().getType().cast<FloatType>();
1444     int mantissaWidth = floatTy.getFPMantissaWidth();
1445     if (mantissaWidth <= 0)
1446       return failure();
1447 
1448     bool isUnsigned;
1449     Value intVal;
1450 
1451     if (auto si = op.getLhs().getDefiningOp<SIToFPOp>()) {
1452       isUnsigned = false;
1453       intVal = si.getIn();
1454     } else if (auto ui = op.getLhs().getDefiningOp<UIToFPOp>()) {
1455       isUnsigned = true;
1456       intVal = ui.getIn();
1457     } else {
1458       return failure();
1459     }
1460 
1461     // Check to see that the input is converted from an integer type that is
1462     // small enough that preserves all bits.
1463     auto intTy = intVal.getType().cast<IntegerType>();
1464     auto intWidth = intTy.getWidth();
1465 
1466     // Number of bits representing values, as opposed to the sign
1467     auto valueBits = isUnsigned ? intWidth : (intWidth - 1);
1468 
1469     // Following test does NOT adjust intWidth downwards for signed inputs,
1470     // because the most negative value still requires all the mantissa bits
1471     // to distinguish it from one less than that value.
1472     if ((int)intWidth > mantissaWidth) {
1473       // Conversion would lose accuracy. Check if loss can impact comparison.
1474       int exponent = ilogb(rhs);
1475       if (exponent == APFloat::IEK_Inf) {
1476         int maxExponent = ilogb(APFloat::getLargest(rhs.getSemantics()));
1477         if (maxExponent < (int)valueBits) {
1478           // Conversion could create infinity.
1479           return failure();
1480         }
1481       } else {
1482         // Note that if rhs is zero or NaN, then Exp is negative
1483         // and first condition is trivially false.
1484         if (mantissaWidth <= exponent && exponent <= (int)valueBits) {
1485           // Conversion could affect comparison.
1486           return failure();
1487         }
1488       }
1489     }
1490 
1491     // Convert to equivalent cmpi predicate
1492     CmpIPredicate pred;
1493     switch (op.getPredicate()) {
1494     case CmpFPredicate::ORD:
1495       // Int to fp conversion doesn't create a nan (ord checks neither is a nan)
1496       rewriter.replaceOpWithNewOp<ConstantIntOp>(op, /*value=*/true,
1497                                                  /*width=*/1);
1498       return success();
1499     case CmpFPredicate::UNO:
1500       // Int to fp conversion doesn't create a nan (uno checks either is a nan)
1501       rewriter.replaceOpWithNewOp<ConstantIntOp>(op, /*value=*/false,
1502                                                  /*width=*/1);
1503       return success();
1504     default:
1505       pred = convertToIntegerPredicate(op.getPredicate(), isUnsigned);
1506       break;
1507     }
1508 
1509     if (!isUnsigned) {
1510       // If the rhs value is > SignedMax, fold the comparison.  This handles
1511       // +INF and large values.
1512       APFloat signedMax(rhs.getSemantics());
1513       signedMax.convertFromAPInt(APInt::getSignedMaxValue(intWidth), true,
1514                                  APFloat::rmNearestTiesToEven);
1515       if (signedMax < rhs) { // smax < 13123.0
1516         if (pred == CmpIPredicate::ne || pred == CmpIPredicate::slt ||
1517             pred == CmpIPredicate::sle)
1518           rewriter.replaceOpWithNewOp<ConstantIntOp>(op, /*value=*/true,
1519                                                      /*width=*/1);
1520         else
1521           rewriter.replaceOpWithNewOp<ConstantIntOp>(op, /*value=*/false,
1522                                                      /*width=*/1);
1523         return success();
1524       }
1525     } else {
1526       // If the rhs value is > UnsignedMax, fold the comparison. This handles
1527       // +INF and large values.
1528       APFloat unsignedMax(rhs.getSemantics());
1529       unsignedMax.convertFromAPInt(APInt::getMaxValue(intWidth), false,
1530                                    APFloat::rmNearestTiesToEven);
1531       if (unsignedMax < rhs) { // umax < 13123.0
1532         if (pred == CmpIPredicate::ne || pred == CmpIPredicate::ult ||
1533             pred == CmpIPredicate::ule)
1534           rewriter.replaceOpWithNewOp<ConstantIntOp>(op, /*value=*/true,
1535                                                      /*width=*/1);
1536         else
1537           rewriter.replaceOpWithNewOp<ConstantIntOp>(op, /*value=*/false,
1538                                                      /*width=*/1);
1539         return success();
1540       }
1541     }
1542 
1543     if (!isUnsigned) {
1544       // See if the rhs value is < SignedMin.
1545       APFloat signedMin(rhs.getSemantics());
1546       signedMin.convertFromAPInt(APInt::getSignedMinValue(intWidth), true,
1547                                  APFloat::rmNearestTiesToEven);
1548       if (signedMin > rhs) { // smin > 12312.0
1549         if (pred == CmpIPredicate::ne || pred == CmpIPredicate::sgt ||
1550             pred == CmpIPredicate::sge)
1551           rewriter.replaceOpWithNewOp<ConstantIntOp>(op, /*value=*/true,
1552                                                      /*width=*/1);
1553         else
1554           rewriter.replaceOpWithNewOp<ConstantIntOp>(op, /*value=*/false,
1555                                                      /*width=*/1);
1556         return success();
1557       }
1558     } else {
1559       // See if the rhs value is < UnsignedMin.
1560       APFloat unsignedMin(rhs.getSemantics());
1561       unsignedMin.convertFromAPInt(APInt::getMinValue(intWidth), false,
1562                                    APFloat::rmNearestTiesToEven);
1563       if (unsignedMin > rhs) { // umin > 12312.0
1564         if (pred == CmpIPredicate::ne || pred == CmpIPredicate::ugt ||
1565             pred == CmpIPredicate::uge)
1566           rewriter.replaceOpWithNewOp<ConstantIntOp>(op, /*value=*/true,
1567                                                      /*width=*/1);
1568         else
1569           rewriter.replaceOpWithNewOp<ConstantIntOp>(op, /*value=*/false,
1570                                                      /*width=*/1);
1571         return success();
1572       }
1573     }
1574 
1575     // Okay, now we know that the FP constant fits in the range [SMIN, SMAX] or
1576     // [0, UMAX], but it may still be fractional.  See if it is fractional by
1577     // casting the FP value to the integer value and back, checking for
1578     // equality. Don't do this for zero, because -0.0 is not fractional.
1579     bool ignored;
1580     APSInt rhsInt(intWidth, isUnsigned);
1581     if (APFloat::opInvalidOp ==
1582         rhs.convertToInteger(rhsInt, APFloat::rmTowardZero, &ignored)) {
1583       // Undefined behavior invoked - the destination type can't represent
1584       // the input constant.
1585       return failure();
1586     }
1587 
1588     if (!rhs.isZero()) {
1589       APFloat apf(floatTy.getFloatSemantics(),
1590                   APInt::getZero(floatTy.getWidth()));
1591       apf.convertFromAPInt(rhsInt, !isUnsigned, APFloat::rmNearestTiesToEven);
1592 
1593       bool equal = apf == rhs;
1594       if (!equal) {
1595         // If we had a comparison against a fractional value, we have to adjust
1596         // the compare predicate and sometimes the value.  rhsInt is rounded
1597         // towards zero at this point.
1598         switch (pred) {
1599         case CmpIPredicate::ne: // (float)int != 4.4   --> true
1600           rewriter.replaceOpWithNewOp<ConstantIntOp>(op, /*value=*/true,
1601                                                      /*width=*/1);
1602           return success();
1603         case CmpIPredicate::eq: // (float)int == 4.4   --> false
1604           rewriter.replaceOpWithNewOp<ConstantIntOp>(op, /*value=*/false,
1605                                                      /*width=*/1);
1606           return success();
1607         case CmpIPredicate::ule:
1608           // (float)int <= 4.4   --> int <= 4
1609           // (float)int <= -4.4  --> false
1610           if (rhs.isNegative()) {
1611             rewriter.replaceOpWithNewOp<ConstantIntOp>(op, /*value=*/false,
1612                                                        /*width=*/1);
1613             return success();
1614           }
1615           break;
1616         case CmpIPredicate::sle:
1617           // (float)int <= 4.4   --> int <= 4
1618           // (float)int <= -4.4  --> int < -4
1619           if (rhs.isNegative())
1620             pred = CmpIPredicate::slt;
1621           break;
1622         case CmpIPredicate::ult:
1623           // (float)int < -4.4   --> false
1624           // (float)int < 4.4    --> int <= 4
1625           if (rhs.isNegative()) {
1626             rewriter.replaceOpWithNewOp<ConstantIntOp>(op, /*value=*/false,
1627                                                        /*width=*/1);
1628             return success();
1629           }
1630           pred = CmpIPredicate::ule;
1631           break;
1632         case CmpIPredicate::slt:
1633           // (float)int < -4.4   --> int < -4
1634           // (float)int < 4.4    --> int <= 4
1635           if (!rhs.isNegative())
1636             pred = CmpIPredicate::sle;
1637           break;
1638         case CmpIPredicate::ugt:
1639           // (float)int > 4.4    --> int > 4
1640           // (float)int > -4.4   --> true
1641           if (rhs.isNegative()) {
1642             rewriter.replaceOpWithNewOp<ConstantIntOp>(op, /*value=*/true,
1643                                                        /*width=*/1);
1644             return success();
1645           }
1646           break;
1647         case CmpIPredicate::sgt:
1648           // (float)int > 4.4    --> int > 4
1649           // (float)int > -4.4   --> int >= -4
1650           if (rhs.isNegative())
1651             pred = CmpIPredicate::sge;
1652           break;
1653         case CmpIPredicate::uge:
1654           // (float)int >= -4.4   --> true
1655           // (float)int >= 4.4    --> int > 4
1656           if (rhs.isNegative()) {
1657             rewriter.replaceOpWithNewOp<ConstantIntOp>(op, /*value=*/true,
1658                                                        /*width=*/1);
1659             return success();
1660           }
1661           pred = CmpIPredicate::ugt;
1662           break;
1663         case CmpIPredicate::sge:
1664           // (float)int >= -4.4   --> int >= -4
1665           // (float)int >= 4.4    --> int > 4
1666           if (!rhs.isNegative())
1667             pred = CmpIPredicate::sgt;
1668           break;
1669         }
1670       }
1671     }
1672 
1673     // Lower this FP comparison into an appropriate integer version of the
1674     // comparison.
1675     rewriter.replaceOpWithNewOp<CmpIOp>(
1676         op, pred, intVal,
1677         rewriter.create<ConstantOp>(
1678             op.getLoc(), intVal.getType(),
1679             rewriter.getIntegerAttr(intVal.getType(), rhsInt)));
1680     return success();
1681   }
1682 };
1683 
1684 void arith::CmpFOp::getCanonicalizationPatterns(RewritePatternSet &patterns,
1685                                                 MLIRContext *context) {
1686   patterns.insert<CmpFIntToFPConst>(context);
1687 }
1688 
1689 //===----------------------------------------------------------------------===//
1690 // SelectOp
1691 //===----------------------------------------------------------------------===//
1692 
1693 // Transforms a select of a boolean to arithmetic operations
1694 //
1695 //  arith.select %arg, %x, %y : i1
1696 //
1697 //  becomes
1698 //
1699 //  and(%arg, %x) or and(!%arg, %y)
1700 struct SelectI1Simplify : public OpRewritePattern<arith::SelectOp> {
1701   using OpRewritePattern<arith::SelectOp>::OpRewritePattern;
1702 
1703   LogicalResult matchAndRewrite(arith::SelectOp op,
1704                                 PatternRewriter &rewriter) const override {
1705     if (!op.getType().isInteger(1))
1706       return failure();
1707 
1708     Value falseConstant =
1709         rewriter.create<arith::ConstantIntOp>(op.getLoc(), true, 1);
1710     Value notCondition = rewriter.create<arith::XOrIOp>(
1711         op.getLoc(), op.getCondition(), falseConstant);
1712 
1713     Value trueVal = rewriter.create<arith::AndIOp>(
1714         op.getLoc(), op.getCondition(), op.getTrueValue());
1715     Value falseVal = rewriter.create<arith::AndIOp>(op.getLoc(), notCondition,
1716                                                     op.getFalseValue());
1717     rewriter.replaceOpWithNewOp<arith::OrIOp>(op, trueVal, falseVal);
1718     return success();
1719   }
1720 };
1721 
1722 //  select %arg, %c1, %c0 => extui %arg
1723 struct SelectToExtUI : public OpRewritePattern<arith::SelectOp> {
1724   using OpRewritePattern<arith::SelectOp>::OpRewritePattern;
1725 
1726   LogicalResult matchAndRewrite(arith::SelectOp op,
1727                                 PatternRewriter &rewriter) const override {
1728     // Cannot extui i1 to i1, or i1 to f32
1729     if (!op.getType().isa<IntegerType>() || op.getType().isInteger(1))
1730       return failure();
1731 
1732     // select %x, c1, %c0 => extui %arg
1733     if (matchPattern(op.getTrueValue(), m_One()))
1734       if (matchPattern(op.getFalseValue(), m_Zero())) {
1735         rewriter.replaceOpWithNewOp<arith::ExtUIOp>(op, op.getType(),
1736                                                     op.getCondition());
1737         return success();
1738       }
1739 
1740     // select %x, c0, %c1 => extui (xor %arg, true)
1741     if (matchPattern(op.getTrueValue(), m_Zero()))
1742       if (matchPattern(op.getFalseValue(), m_One())) {
1743         rewriter.replaceOpWithNewOp<arith::ExtUIOp>(
1744             op, op.getType(),
1745             rewriter.create<arith::XOrIOp>(
1746                 op.getLoc(), op.getCondition(),
1747                 rewriter.create<arith::ConstantIntOp>(
1748                     op.getLoc(), 1, op.getCondition().getType())));
1749         return success();
1750       }
1751 
1752     return failure();
1753   }
1754 };
1755 
1756 void arith::SelectOp::getCanonicalizationPatterns(RewritePatternSet &results,
1757                                                   MLIRContext *context) {
1758   results.add<SelectI1Simplify, SelectToExtUI>(context);
1759 }
1760 
1761 OpFoldResult arith::SelectOp::fold(ArrayRef<Attribute> operands) {
1762   Value trueVal = getTrueValue();
1763   Value falseVal = getFalseValue();
1764   if (trueVal == falseVal)
1765     return trueVal;
1766 
1767   Value condition = getCondition();
1768 
1769   // select true, %0, %1 => %0
1770   if (matchPattern(condition, m_One()))
1771     return trueVal;
1772 
1773   // select false, %0, %1 => %1
1774   if (matchPattern(condition, m_Zero()))
1775     return falseVal;
1776 
1777   // select %x, true, false => %x
1778   if (getType().isInteger(1))
1779     if (matchPattern(getTrueValue(), m_One()))
1780       if (matchPattern(getFalseValue(), m_Zero()))
1781         return condition;
1782 
1783   if (auto cmp = dyn_cast_or_null<arith::CmpIOp>(condition.getDefiningOp())) {
1784     auto pred = cmp.getPredicate();
1785     if (pred == arith::CmpIPredicate::eq || pred == arith::CmpIPredicate::ne) {
1786       auto cmpLhs = cmp.getLhs();
1787       auto cmpRhs = cmp.getRhs();
1788 
1789       // %0 = arith.cmpi eq, %arg0, %arg1
1790       // %1 = arith.select %0, %arg0, %arg1 => %arg1
1791 
1792       // %0 = arith.cmpi ne, %arg0, %arg1
1793       // %1 = arith.select %0, %arg0, %arg1 => %arg0
1794 
1795       if ((cmpLhs == trueVal && cmpRhs == falseVal) ||
1796           (cmpRhs == trueVal && cmpLhs == falseVal))
1797         return pred == arith::CmpIPredicate::ne ? trueVal : falseVal;
1798     }
1799   }
1800   return nullptr;
1801 }
1802 
1803 ParseResult SelectOp::parse(OpAsmParser &parser, OperationState &result) {
1804   Type conditionType, resultType;
1805   SmallVector<OpAsmParser::UnresolvedOperand, 3> operands;
1806   if (parser.parseOperandList(operands, /*requiredOperandCount=*/3) ||
1807       parser.parseOptionalAttrDict(result.attributes) ||
1808       parser.parseColonType(resultType))
1809     return failure();
1810 
1811   // Check for the explicit condition type if this is a masked tensor or vector.
1812   if (succeeded(parser.parseOptionalComma())) {
1813     conditionType = resultType;
1814     if (parser.parseType(resultType))
1815       return failure();
1816   } else {
1817     conditionType = parser.getBuilder().getI1Type();
1818   }
1819 
1820   result.addTypes(resultType);
1821   return parser.resolveOperands(operands,
1822                                 {conditionType, resultType, resultType},
1823                                 parser.getNameLoc(), result.operands);
1824 }
1825 
1826 void arith::SelectOp::print(OpAsmPrinter &p) {
1827   p << " " << getOperands();
1828   p.printOptionalAttrDict((*this)->getAttrs());
1829   p << " : ";
1830   if (ShapedType condType = getCondition().getType().dyn_cast<ShapedType>())
1831     p << condType << ", ";
1832   p << getType();
1833 }
1834 
1835 LogicalResult arith::SelectOp::verify() {
1836   Type conditionType = getCondition().getType();
1837   if (conditionType.isSignlessInteger(1))
1838     return success();
1839 
1840   // If the result type is a vector or tensor, the type can be a mask with the
1841   // same elements.
1842   Type resultType = getType();
1843   if (!resultType.isa<TensorType, VectorType>())
1844     return emitOpError() << "expected condition to be a signless i1, but got "
1845                          << conditionType;
1846   Type shapedConditionType = getI1SameShape(resultType);
1847   if (conditionType != shapedConditionType) {
1848     return emitOpError() << "expected condition type to have the same shape "
1849                             "as the result type, expected "
1850                          << shapedConditionType << ", but got "
1851                          << conditionType;
1852   }
1853   return success();
1854 }
1855 //===----------------------------------------------------------------------===//
1856 // ShLIOp
1857 //===----------------------------------------------------------------------===//
1858 
1859 OpFoldResult arith::ShLIOp::fold(ArrayRef<Attribute> operands) {
1860   // Don't fold if shifting more than the bit width.
1861   bool bounded = false;
1862   auto result = constFoldBinaryOp<IntegerAttr>(
1863       operands, [&](const APInt &a, const APInt &b) {
1864         bounded = b.ule(b.getBitWidth());
1865         return a.shl(b);
1866       });
1867   return bounded ? result : Attribute();
1868 }
1869 
1870 //===----------------------------------------------------------------------===//
1871 // ShRUIOp
1872 //===----------------------------------------------------------------------===//
1873 
1874 OpFoldResult arith::ShRUIOp::fold(ArrayRef<Attribute> operands) {
1875   // Don't fold if shifting more than the bit width.
1876   bool bounded = false;
1877   auto result = constFoldBinaryOp<IntegerAttr>(
1878       operands, [&](const APInt &a, const APInt &b) {
1879         bounded = b.ule(b.getBitWidth());
1880         return a.lshr(b);
1881       });
1882   return bounded ? result : Attribute();
1883 }
1884 
1885 //===----------------------------------------------------------------------===//
1886 // ShRSIOp
1887 //===----------------------------------------------------------------------===//
1888 
1889 OpFoldResult arith::ShRSIOp::fold(ArrayRef<Attribute> operands) {
1890   // Don't fold if shifting more than the bit width.
1891   bool bounded = false;
1892   auto result = constFoldBinaryOp<IntegerAttr>(
1893       operands, [&](const APInt &a, const APInt &b) {
1894         bounded = b.ule(b.getBitWidth());
1895         return a.ashr(b);
1896       });
1897   return bounded ? result : Attribute();
1898 }
1899 
1900 //===----------------------------------------------------------------------===//
1901 // Atomic Enum
1902 //===----------------------------------------------------------------------===//
1903 
1904 /// Returns the identity value attribute associated with an AtomicRMWKind op.
1905 Attribute mlir::arith::getIdentityValueAttr(AtomicRMWKind kind, Type resultType,
1906                                             OpBuilder &builder, Location loc) {
1907   switch (kind) {
1908   case AtomicRMWKind::maxf:
1909     return builder.getFloatAttr(
1910         resultType,
1911         APFloat::getInf(resultType.cast<FloatType>().getFloatSemantics(),
1912                         /*Negative=*/true));
1913   case AtomicRMWKind::addf:
1914   case AtomicRMWKind::addi:
1915   case AtomicRMWKind::maxu:
1916   case AtomicRMWKind::ori:
1917     return builder.getZeroAttr(resultType);
1918   case AtomicRMWKind::andi:
1919     return builder.getIntegerAttr(
1920         resultType,
1921         APInt::getAllOnes(resultType.cast<IntegerType>().getWidth()));
1922   case AtomicRMWKind::maxs:
1923     return builder.getIntegerAttr(
1924         resultType,
1925         APInt::getSignedMinValue(resultType.cast<IntegerType>().getWidth()));
1926   case AtomicRMWKind::minf:
1927     return builder.getFloatAttr(
1928         resultType,
1929         APFloat::getInf(resultType.cast<FloatType>().getFloatSemantics(),
1930                         /*Negative=*/false));
1931   case AtomicRMWKind::mins:
1932     return builder.getIntegerAttr(
1933         resultType,
1934         APInt::getSignedMaxValue(resultType.cast<IntegerType>().getWidth()));
1935   case AtomicRMWKind::minu:
1936     return builder.getIntegerAttr(
1937         resultType,
1938         APInt::getMaxValue(resultType.cast<IntegerType>().getWidth()));
1939   case AtomicRMWKind::muli:
1940     return builder.getIntegerAttr(resultType, 1);
1941   case AtomicRMWKind::mulf:
1942     return builder.getFloatAttr(resultType, 1);
1943   // TODO: Add remaining reduction operations.
1944   default:
1945     (void)emitOptionalError(loc, "Reduction operation type not supported");
1946     break;
1947   }
1948   return nullptr;
1949 }
1950 
1951 /// Returns the identity value associated with an AtomicRMWKind op.
1952 Value mlir::arith::getIdentityValue(AtomicRMWKind op, Type resultType,
1953                                     OpBuilder &builder, Location loc) {
1954   Attribute attr = getIdentityValueAttr(op, resultType, builder, loc);
1955   return builder.create<arith::ConstantOp>(loc, attr);
1956 }
1957 
1958 /// Return the value obtained by applying the reduction operation kind
1959 /// associated with a binary AtomicRMWKind op to `lhs` and `rhs`.
1960 Value mlir::arith::getReductionOp(AtomicRMWKind op, OpBuilder &builder,
1961                                   Location loc, Value lhs, Value rhs) {
1962   switch (op) {
1963   case AtomicRMWKind::addf:
1964     return builder.create<arith::AddFOp>(loc, lhs, rhs);
1965   case AtomicRMWKind::addi:
1966     return builder.create<arith::AddIOp>(loc, lhs, rhs);
1967   case AtomicRMWKind::mulf:
1968     return builder.create<arith::MulFOp>(loc, lhs, rhs);
1969   case AtomicRMWKind::muli:
1970     return builder.create<arith::MulIOp>(loc, lhs, rhs);
1971   case AtomicRMWKind::maxf:
1972     return builder.create<arith::MaxFOp>(loc, lhs, rhs);
1973   case AtomicRMWKind::minf:
1974     return builder.create<arith::MinFOp>(loc, lhs, rhs);
1975   case AtomicRMWKind::maxs:
1976     return builder.create<arith::MaxSIOp>(loc, lhs, rhs);
1977   case AtomicRMWKind::mins:
1978     return builder.create<arith::MinSIOp>(loc, lhs, rhs);
1979   case AtomicRMWKind::maxu:
1980     return builder.create<arith::MaxUIOp>(loc, lhs, rhs);
1981   case AtomicRMWKind::minu:
1982     return builder.create<arith::MinUIOp>(loc, lhs, rhs);
1983   case AtomicRMWKind::ori:
1984     return builder.create<arith::OrIOp>(loc, lhs, rhs);
1985   case AtomicRMWKind::andi:
1986     return builder.create<arith::AndIOp>(loc, lhs, rhs);
1987   // TODO: Add remaining reduction operations.
1988   default:
1989     (void)emitOptionalError(loc, "Reduction operation type not supported");
1990     break;
1991   }
1992   return nullptr;
1993 }
1994 
1995 //===----------------------------------------------------------------------===//
1996 // TableGen'd op method definitions
1997 //===----------------------------------------------------------------------===//
1998 
1999 #define GET_OP_CLASSES
2000 #include "mlir/Dialect/Arithmetic/IR/ArithmeticOps.cpp.inc"
2001 
2002 //===----------------------------------------------------------------------===//
2003 // TableGen'd enum attribute definitions
2004 //===----------------------------------------------------------------------===//
2005 
2006 #include "mlir/Dialect/Arithmetic/IR/ArithmeticOpsEnums.cpp.inc"
2007