1 //===- ArithmeticOps.cpp - MLIR Arithmetic dialect ops implementation -----===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 
9 #include <utility>
10 
11 #include "mlir/Dialect/Arithmetic/IR/Arithmetic.h"
12 #include "mlir/Dialect/CommonFolders.h"
13 #include "mlir/Dialect/Utils/StaticValueUtils.h"
14 #include "mlir/IR/Builders.h"
15 #include "mlir/IR/Matchers.h"
16 #include "mlir/IR/OpImplementation.h"
17 #include "mlir/IR/PatternMatch.h"
18 #include "mlir/IR/TypeUtilities.h"
19 #include "llvm/ADT/SmallString.h"
20 
21 #include "llvm/ADT/APSInt.h"
22 
23 using namespace mlir;
24 using namespace mlir::arith;
25 
26 //===----------------------------------------------------------------------===//
27 // Pattern helpers
28 //===----------------------------------------------------------------------===//
29 
30 static IntegerAttr addIntegerAttrs(PatternRewriter &builder, Value res,
31                                    Attribute lhs, Attribute rhs) {
32   return builder.getIntegerAttr(res.getType(),
33                                 lhs.cast<IntegerAttr>().getInt() +
34                                     rhs.cast<IntegerAttr>().getInt());
35 }
36 
37 static IntegerAttr subIntegerAttrs(PatternRewriter &builder, Value res,
38                                    Attribute lhs, Attribute rhs) {
39   return builder.getIntegerAttr(res.getType(),
40                                 lhs.cast<IntegerAttr>().getInt() -
41                                     rhs.cast<IntegerAttr>().getInt());
42 }
43 
44 /// Invert an integer comparison predicate.
45 arith::CmpIPredicate arith::invertPredicate(arith::CmpIPredicate pred) {
46   switch (pred) {
47   case arith::CmpIPredicate::eq:
48     return arith::CmpIPredicate::ne;
49   case arith::CmpIPredicate::ne:
50     return arith::CmpIPredicate::eq;
51   case arith::CmpIPredicate::slt:
52     return arith::CmpIPredicate::sge;
53   case arith::CmpIPredicate::sle:
54     return arith::CmpIPredicate::sgt;
55   case arith::CmpIPredicate::sgt:
56     return arith::CmpIPredicate::sle;
57   case arith::CmpIPredicate::sge:
58     return arith::CmpIPredicate::slt;
59   case arith::CmpIPredicate::ult:
60     return arith::CmpIPredicate::uge;
61   case arith::CmpIPredicate::ule:
62     return arith::CmpIPredicate::ugt;
63   case arith::CmpIPredicate::ugt:
64     return arith::CmpIPredicate::ule;
65   case arith::CmpIPredicate::uge:
66     return arith::CmpIPredicate::ult;
67   }
68   llvm_unreachable("unknown cmpi predicate kind");
69 }
70 
71 static arith::CmpIPredicateAttr invertPredicate(arith::CmpIPredicateAttr pred) {
72   return arith::CmpIPredicateAttr::get(pred.getContext(),
73                                        invertPredicate(pred.getValue()));
74 }
75 
76 //===----------------------------------------------------------------------===//
77 // TableGen'd canonicalization patterns
78 //===----------------------------------------------------------------------===//
79 
namespace {
// Textually includes the TableGen-generated canonicalization patterns. These
// presumably define pattern classes such as AddIAddConstant or XOrINotCmpI,
// which the getCanonicalizationPatterns hooks below register; they may use the
// static helpers above. The anonymous namespace keeps them file-local.
#include "ArithmeticCanonicalization.inc"
} // namespace
83 
84 //===----------------------------------------------------------------------===//
85 // ConstantOp
86 //===----------------------------------------------------------------------===//
87 
88 void arith::ConstantOp::getAsmResultNames(
89     function_ref<void(Value, StringRef)> setNameFn) {
90   auto type = getType();
91   if (auto intCst = getValue().dyn_cast<IntegerAttr>()) {
92     auto intType = type.dyn_cast<IntegerType>();
93 
94     // Sugar i1 constants with 'true' and 'false'.
95     if (intType && intType.getWidth() == 1)
96       return setNameFn(getResult(), (intCst.getInt() ? "true" : "false"));
97 
98     // Otherwise, build a compex name with the value and type.
99     SmallString<32> specialNameBuffer;
100     llvm::raw_svector_ostream specialName(specialNameBuffer);
101     specialName << 'c' << intCst.getInt();
102     if (intType)
103       specialName << '_' << type;
104     setNameFn(getResult(), specialName.str());
105   } else {
106     setNameFn(getResult(), "cst");
107   }
108 }
109 
110 /// TODO: disallow arith.constant to return anything other than signless integer
111 /// or float like.
112 LogicalResult arith::ConstantOp::verify() {
113   auto type = getType();
114   // The value's type must match the return type.
115   if (getValue().getType() != type) {
116     return emitOpError() << "value type " << getValue().getType()
117                          << " must match return type: " << type;
118   }
119   // Integer values must be signless.
120   if (type.isa<IntegerType>() && !type.cast<IntegerType>().isSignless())
121     return emitOpError("integer return type must be signless");
122   // Any float or elements attribute are acceptable.
123   if (!getValue().isa<IntegerAttr, FloatAttr, ElementsAttr>()) {
124     return emitOpError(
125         "value must be an integer, float, or elements attribute");
126   }
127   return success();
128 }
129 
130 bool arith::ConstantOp::isBuildableWith(Attribute value, Type type) {
131   // The value's type must be the same as the provided type.
132   if (value.getType() != type)
133     return false;
134   // Integer values must be signless.
135   if (type.isa<IntegerType>() && !type.cast<IntegerType>().isSignless())
136     return false;
137   // Integer, float, and element attributes are buildable.
138   return value.isa<IntegerAttr, FloatAttr, ElementsAttr>();
139 }
140 
141 OpFoldResult arith::ConstantOp::fold(ArrayRef<Attribute> operands) {
142   return getValue();
143 }
144 
145 void arith::ConstantIntOp::build(OpBuilder &builder, OperationState &result,
146                                  int64_t value, unsigned width) {
147   auto type = builder.getIntegerType(width);
148   arith::ConstantOp::build(builder, result, type,
149                            builder.getIntegerAttr(type, value));
150 }
151 
152 void arith::ConstantIntOp::build(OpBuilder &builder, OperationState &result,
153                                  int64_t value, Type type) {
154   assert(type.isSignlessInteger() &&
155          "ConstantIntOp can only have signless integer type values");
156   arith::ConstantOp::build(builder, result, type,
157                            builder.getIntegerAttr(type, value));
158 }
159 
160 bool arith::ConstantIntOp::classof(Operation *op) {
161   if (auto constOp = dyn_cast_or_null<arith::ConstantOp>(op))
162     return constOp.getType().isSignlessInteger();
163   return false;
164 }
165 
166 void arith::ConstantFloatOp::build(OpBuilder &builder, OperationState &result,
167                                    const APFloat &value, FloatType type) {
168   arith::ConstantOp::build(builder, result, type,
169                            builder.getFloatAttr(type, value));
170 }
171 
172 bool arith::ConstantFloatOp::classof(Operation *op) {
173   if (auto constOp = dyn_cast_or_null<arith::ConstantOp>(op))
174     return constOp.getType().isa<FloatType>();
175   return false;
176 }
177 
178 void arith::ConstantIndexOp::build(OpBuilder &builder, OperationState &result,
179                                    int64_t value) {
180   arith::ConstantOp::build(builder, result, builder.getIndexType(),
181                            builder.getIndexAttr(value));
182 }
183 
184 bool arith::ConstantIndexOp::classof(Operation *op) {
185   if (auto constOp = dyn_cast_or_null<arith::ConstantOp>(op))
186     return constOp.getType().isIndex();
187   return false;
188 }
189 
190 //===----------------------------------------------------------------------===//
191 // AddIOp
192 //===----------------------------------------------------------------------===//
193 
194 OpFoldResult arith::AddIOp::fold(ArrayRef<Attribute> operands) {
195   // addi(x, 0) -> x
196   if (matchPattern(getRhs(), m_Zero()))
197     return getLhs();
198 
199   // addi(subi(a, b), b) -> a
200   if (auto sub = getLhs().getDefiningOp<SubIOp>())
201     if (getRhs() == sub.getRhs())
202       return sub.getLhs();
203 
204   // addi(b, subi(a, b)) -> a
205   if (auto sub = getRhs().getDefiningOp<SubIOp>())
206     if (getLhs() == sub.getRhs())
207       return sub.getLhs();
208 
209   return constFoldBinaryOp<IntegerAttr>(
210       operands, [](APInt a, const APInt &b) { return std::move(a) + b; });
211 }
212 
213 void arith::AddIOp::getCanonicalizationPatterns(RewritePatternSet &patterns,
214                                                 MLIRContext *context) {
215   patterns.add<AddIAddConstant, AddISubConstantRHS, AddISubConstantLHS>(
216       context);
217 }
218 
219 //===----------------------------------------------------------------------===//
220 // SubIOp
221 //===----------------------------------------------------------------------===//
222 
223 OpFoldResult arith::SubIOp::fold(ArrayRef<Attribute> operands) {
224   // subi(x,x) -> 0
225   if (getOperand(0) == getOperand(1))
226     return Builder(getContext()).getZeroAttr(getType());
227   // subi(x,0) -> x
228   if (matchPattern(getRhs(), m_Zero()))
229     return getLhs();
230 
231   return constFoldBinaryOp<IntegerAttr>(
232       operands, [](APInt a, const APInt &b) { return std::move(a) - b; });
233 }
234 
235 void arith::SubIOp::getCanonicalizationPatterns(RewritePatternSet &patterns,
236                                                 MLIRContext *context) {
237   patterns
238       .add<SubIRHSAddConstant, SubILHSAddConstant, SubIRHSSubConstantRHS,
239            SubIRHSSubConstantLHS, SubILHSSubConstantRHS, SubILHSSubConstantLHS>(
240           context);
241 }
242 
243 //===----------------------------------------------------------------------===//
244 // MulIOp
245 //===----------------------------------------------------------------------===//
246 
247 OpFoldResult arith::MulIOp::fold(ArrayRef<Attribute> operands) {
248   // muli(x, 0) -> 0
249   if (matchPattern(getRhs(), m_Zero()))
250     return getRhs();
251   // muli(x, 1) -> x
252   if (matchPattern(getRhs(), m_One()))
253     return getOperand(0);
254   // TODO: Handle the overflow case.
255 
256   // default folder
257   return constFoldBinaryOp<IntegerAttr>(
258       operands, [](const APInt &a, const APInt &b) { return a * b; });
259 }
260 
261 //===----------------------------------------------------------------------===//
262 // DivUIOp
263 //===----------------------------------------------------------------------===//
264 
265 OpFoldResult arith::DivUIOp::fold(ArrayRef<Attribute> operands) {
266   // divui (x, 1) -> x.
267   if (matchPattern(getRhs(), m_One()))
268     return getLhs();
269 
270   // Don't fold if it would require a division by zero.
271   bool div0 = false;
272   auto result =
273       constFoldBinaryOp<IntegerAttr>(operands, [&](APInt a, const APInt &b) {
274         if (div0 || !b) {
275           div0 = true;
276           return a;
277         }
278         return a.udiv(b);
279       });
280 
281   return div0 ? Attribute() : result;
282 }
283 
284 //===----------------------------------------------------------------------===//
285 // DivSIOp
286 //===----------------------------------------------------------------------===//
287 
288 OpFoldResult arith::DivSIOp::fold(ArrayRef<Attribute> operands) {
289   // divsi (x, 1) -> x.
290   if (matchPattern(getRhs(), m_One()))
291     return getLhs();
292 
293   // Don't fold if it would overflow or if it requires a division by zero.
294   bool overflowOrDiv0 = false;
295   auto result =
296       constFoldBinaryOp<IntegerAttr>(operands, [&](APInt a, const APInt &b) {
297         if (overflowOrDiv0 || !b) {
298           overflowOrDiv0 = true;
299           return a;
300         }
301         return a.sdiv_ov(b, overflowOrDiv0);
302       });
303 
304   return overflowOrDiv0 ? Attribute() : result;
305 }
306 
307 //===----------------------------------------------------------------------===//
308 // Ceil and floor division folding helpers
309 //===----------------------------------------------------------------------===//
310 
311 static APInt signedCeilNonnegInputs(const APInt &a, const APInt &b,
312                                     bool &overflow) {
313   // Returns (a-1)/b + 1
314   APInt one(a.getBitWidth(), 1, true); // Signed value 1.
315   APInt val = a.ssub_ov(one, overflow).sdiv_ov(b, overflow);
316   return val.sadd_ov(one, overflow);
317 }
318 
319 //===----------------------------------------------------------------------===//
320 // CeilDivUIOp
321 //===----------------------------------------------------------------------===//
322 
323 OpFoldResult arith::CeilDivUIOp::fold(ArrayRef<Attribute> operands) {
324   // ceildivui (x, 1) -> x.
325   if (matchPattern(getRhs(), m_One()))
326     return getLhs();
327 
328   bool overflowOrDiv0 = false;
329   auto result =
330       constFoldBinaryOp<IntegerAttr>(operands, [&](APInt a, const APInt &b) {
331         if (overflowOrDiv0 || !b) {
332           overflowOrDiv0 = true;
333           return a;
334         }
335         APInt quotient = a.udiv(b);
336         if (!a.urem(b))
337           return quotient;
338         APInt one(a.getBitWidth(), 1, true);
339         return quotient.uadd_ov(one, overflowOrDiv0);
340       });
341 
342   return overflowOrDiv0 ? Attribute() : result;
343 }
344 
345 //===----------------------------------------------------------------------===//
346 // CeilDivSIOp
347 //===----------------------------------------------------------------------===//
348 
349 OpFoldResult arith::CeilDivSIOp::fold(ArrayRef<Attribute> operands) {
350   // ceildivsi (x, 1) -> x.
351   if (matchPattern(getRhs(), m_One()))
352     return getLhs();
353 
354   // Don't fold if it would overflow or if it requires a division by zero.
355   bool overflowOrDiv0 = false;
356   auto result =
357       constFoldBinaryOp<IntegerAttr>(operands, [&](APInt a, const APInt &b) {
358         if (overflowOrDiv0 || !b) {
359           overflowOrDiv0 = true;
360           return a;
361         }
362         if (!a)
363           return a;
364         // After this point we know that neither a or b are zero.
365         unsigned bits = a.getBitWidth();
366         APInt zero = APInt::getZero(bits);
367         bool aGtZero = a.sgt(zero);
368         bool bGtZero = b.sgt(zero);
369         if (aGtZero && bGtZero) {
370           // Both positive, return ceil(a, b).
371           return signedCeilNonnegInputs(a, b, overflowOrDiv0);
372         }
373         if (!aGtZero && !bGtZero) {
374           // Both negative, return ceil(-a, -b).
375           APInt posA = zero.ssub_ov(a, overflowOrDiv0);
376           APInt posB = zero.ssub_ov(b, overflowOrDiv0);
377           return signedCeilNonnegInputs(posA, posB, overflowOrDiv0);
378         }
379         if (!aGtZero && bGtZero) {
380           // A is negative, b is positive, return - ( -a / b).
381           APInt posA = zero.ssub_ov(a, overflowOrDiv0);
382           APInt div = posA.sdiv_ov(b, overflowOrDiv0);
383           return zero.ssub_ov(div, overflowOrDiv0);
384         }
385         // A is positive, b is negative, return - (a / -b).
386         APInt posB = zero.ssub_ov(b, overflowOrDiv0);
387         APInt div = a.sdiv_ov(posB, overflowOrDiv0);
388         return zero.ssub_ov(div, overflowOrDiv0);
389       });
390 
391   return overflowOrDiv0 ? Attribute() : result;
392 }
393 
394 //===----------------------------------------------------------------------===//
395 // FloorDivSIOp
396 //===----------------------------------------------------------------------===//
397 
398 OpFoldResult arith::FloorDivSIOp::fold(ArrayRef<Attribute> operands) {
399   // floordivsi (x, 1) -> x.
400   if (matchPattern(getRhs(), m_One()))
401     return getLhs();
402 
403   // Don't fold if it would overflow or if it requires a division by zero.
404   bool overflowOrDiv0 = false;
405   auto result =
406       constFoldBinaryOp<IntegerAttr>(operands, [&](APInt a, const APInt &b) {
407         if (overflowOrDiv0 || !b) {
408           overflowOrDiv0 = true;
409           return a;
410         }
411         if (!a)
412           return a;
413         // After this point we know that neither a or b are zero.
414         unsigned bits = a.getBitWidth();
415         APInt zero = APInt::getZero(bits);
416         bool aGtZero = a.sgt(zero);
417         bool bGtZero = b.sgt(zero);
418         if (aGtZero && bGtZero) {
419           // Both positive, return a / b.
420           return a.sdiv_ov(b, overflowOrDiv0);
421         }
422         if (!aGtZero && !bGtZero) {
423           // Both negative, return -a / -b.
424           APInt posA = zero.ssub_ov(a, overflowOrDiv0);
425           APInt posB = zero.ssub_ov(b, overflowOrDiv0);
426           return posA.sdiv_ov(posB, overflowOrDiv0);
427         }
428         if (!aGtZero && bGtZero) {
429           // A is negative, b is positive, return - ceil(-a, b).
430           APInt posA = zero.ssub_ov(a, overflowOrDiv0);
431           APInt ceil = signedCeilNonnegInputs(posA, b, overflowOrDiv0);
432           return zero.ssub_ov(ceil, overflowOrDiv0);
433         }
434         // A is positive, b is negative, return - ceil(a, -b).
435         APInt posB = zero.ssub_ov(b, overflowOrDiv0);
436         APInt ceil = signedCeilNonnegInputs(a, posB, overflowOrDiv0);
437         return zero.ssub_ov(ceil, overflowOrDiv0);
438       });
439 
440   return overflowOrDiv0 ? Attribute() : result;
441 }
442 
443 //===----------------------------------------------------------------------===//
444 // RemUIOp
445 //===----------------------------------------------------------------------===//
446 
447 OpFoldResult arith::RemUIOp::fold(ArrayRef<Attribute> operands) {
448   // remui (x, 1) -> 0.
449   if (matchPattern(getRhs(), m_One()))
450     return Builder(getContext()).getZeroAttr(getType());
451 
452   // Don't fold if it would require a division by zero.
453   bool div0 = false;
454   auto result =
455       constFoldBinaryOp<IntegerAttr>(operands, [&](APInt a, const APInt &b) {
456         if (div0 || b.isNullValue()) {
457           div0 = true;
458           return a;
459         }
460         return a.urem(b);
461       });
462 
463   return div0 ? Attribute() : result;
464 }
465 
466 //===----------------------------------------------------------------------===//
467 // RemSIOp
468 //===----------------------------------------------------------------------===//
469 
470 OpFoldResult arith::RemSIOp::fold(ArrayRef<Attribute> operands) {
471   // remsi (x, 1) -> 0.
472   if (matchPattern(getRhs(), m_One()))
473     return Builder(getContext()).getZeroAttr(getType());
474 
475   // Don't fold if it would require a division by zero.
476   bool div0 = false;
477   auto result =
478       constFoldBinaryOp<IntegerAttr>(operands, [&](APInt a, const APInt &b) {
479         if (div0 || b.isNullValue()) {
480           div0 = true;
481           return a;
482         }
483         return a.srem(b);
484       });
485 
486   return div0 ? Attribute() : result;
487 }
488 
489 //===----------------------------------------------------------------------===//
490 // AndIOp
491 //===----------------------------------------------------------------------===//
492 
493 OpFoldResult arith::AndIOp::fold(ArrayRef<Attribute> operands) {
494   /// and(x, 0) -> 0
495   if (matchPattern(getRhs(), m_Zero()))
496     return getRhs();
497   /// and(x, allOnes) -> x
498   APInt intValue;
499   if (matchPattern(getRhs(), m_ConstantInt(&intValue)) && intValue.isAllOnes())
500     return getLhs();
501 
502   return constFoldBinaryOp<IntegerAttr>(
503       operands, [](APInt a, const APInt &b) { return std::move(a) & b; });
504 }
505 
506 //===----------------------------------------------------------------------===//
507 // OrIOp
508 //===----------------------------------------------------------------------===//
509 
510 OpFoldResult arith::OrIOp::fold(ArrayRef<Attribute> operands) {
511   /// or(x, 0) -> x
512   if (matchPattern(getRhs(), m_Zero()))
513     return getLhs();
514   /// or(x, <all ones>) -> <all ones>
515   if (auto rhsAttr = operands[1].dyn_cast_or_null<IntegerAttr>())
516     if (rhsAttr.getValue().isAllOnes())
517       return rhsAttr;
518 
519   return constFoldBinaryOp<IntegerAttr>(
520       operands, [](APInt a, const APInt &b) { return std::move(a) | b; });
521 }
522 
523 //===----------------------------------------------------------------------===//
524 // XOrIOp
525 //===----------------------------------------------------------------------===//
526 
527 OpFoldResult arith::XOrIOp::fold(ArrayRef<Attribute> operands) {
528   /// xor(x, 0) -> x
529   if (matchPattern(getRhs(), m_Zero()))
530     return getLhs();
531   /// xor(x, x) -> 0
532   if (getLhs() == getRhs())
533     return Builder(getContext()).getZeroAttr(getType());
534   /// xor(xor(x, a), a) -> x
535   if (arith::XOrIOp prev = getLhs().getDefiningOp<arith::XOrIOp>())
536     if (prev.getRhs() == getRhs())
537       return prev.getLhs();
538 
539   return constFoldBinaryOp<IntegerAttr>(
540       operands, [](APInt a, const APInt &b) { return std::move(a) ^ b; });
541 }
542 
/// Registers the XOrINotCmpI rewrite, which is not defined in this file and
/// presumably comes from the generated ArithmeticCanonicalization.inc.
void arith::XOrIOp::getCanonicalizationPatterns(RewritePatternSet &patterns,
                                                MLIRContext *context) {
  patterns.add<XOrINotCmpI>(context);
}
547 
548 //===----------------------------------------------------------------------===//
549 // NegFOp
550 //===----------------------------------------------------------------------===//
551 
552 OpFoldResult arith::NegFOp::fold(ArrayRef<Attribute> operands) {
553   /// negf(negf(x)) -> x
554   if (auto op = this->getOperand().getDefiningOp<arith::NegFOp>())
555     return op.getOperand();
556   return constFoldUnaryOp<FloatAttr>(operands,
557                                      [](const APFloat &a) { return -a; });
558 }
559 
560 //===----------------------------------------------------------------------===//
561 // AddFOp
562 //===----------------------------------------------------------------------===//
563 
564 OpFoldResult arith::AddFOp::fold(ArrayRef<Attribute> operands) {
565   // addf(x, -0) -> x
566   if (matchPattern(getRhs(), m_NegZeroFloat()))
567     return getLhs();
568 
569   return constFoldBinaryOp<FloatAttr>(
570       operands, [](const APFloat &a, const APFloat &b) { return a + b; });
571 }
572 
573 //===----------------------------------------------------------------------===//
574 // SubFOp
575 //===----------------------------------------------------------------------===//
576 
577 OpFoldResult arith::SubFOp::fold(ArrayRef<Attribute> operands) {
578   // subf(x, +0) -> x
579   if (matchPattern(getRhs(), m_PosZeroFloat()))
580     return getLhs();
581 
582   return constFoldBinaryOp<FloatAttr>(
583       operands, [](const APFloat &a, const APFloat &b) { return a - b; });
584 }
585 
586 //===----------------------------------------------------------------------===//
587 // MaxFOp
588 //===----------------------------------------------------------------------===//
589 
590 OpFoldResult arith::MaxFOp::fold(ArrayRef<Attribute> operands) {
591   assert(operands.size() == 2 && "maxf takes two operands");
592 
593   // maxf(x,x) -> x
594   if (getLhs() == getRhs())
595     return getRhs();
596 
597   // maxf(x, -inf) -> x
598   if (matchPattern(getRhs(), m_NegInfFloat()))
599     return getLhs();
600 
601   return constFoldBinaryOp<FloatAttr>(
602       operands,
603       [](const APFloat &a, const APFloat &b) { return llvm::maximum(a, b); });
604 }
605 
606 //===----------------------------------------------------------------------===//
607 // MaxSIOp
608 //===----------------------------------------------------------------------===//
609 
610 OpFoldResult MaxSIOp::fold(ArrayRef<Attribute> operands) {
611   assert(operands.size() == 2 && "binary operation takes two operands");
612 
613   // maxsi(x,x) -> x
614   if (getLhs() == getRhs())
615     return getRhs();
616 
617   APInt intValue;
618   // maxsi(x,MAX_INT) -> MAX_INT
619   if (matchPattern(getRhs(), m_ConstantInt(&intValue)) &&
620       intValue.isMaxSignedValue())
621     return getRhs();
622 
623   // maxsi(x, MIN_INT) -> x
624   if (matchPattern(getRhs(), m_ConstantInt(&intValue)) &&
625       intValue.isMinSignedValue())
626     return getLhs();
627 
628   return constFoldBinaryOp<IntegerAttr>(operands,
629                                         [](const APInt &a, const APInt &b) {
630                                           return llvm::APIntOps::smax(a, b);
631                                         });
632 }
633 
634 //===----------------------------------------------------------------------===//
635 // MaxUIOp
636 //===----------------------------------------------------------------------===//
637 
638 OpFoldResult MaxUIOp::fold(ArrayRef<Attribute> operands) {
639   assert(operands.size() == 2 && "binary operation takes two operands");
640 
641   // maxui(x,x) -> x
642   if (getLhs() == getRhs())
643     return getRhs();
644 
645   APInt intValue;
646   // maxui(x,MAX_INT) -> MAX_INT
647   if (matchPattern(getRhs(), m_ConstantInt(&intValue)) && intValue.isMaxValue())
648     return getRhs();
649 
650   // maxui(x, MIN_INT) -> x
651   if (matchPattern(getRhs(), m_ConstantInt(&intValue)) && intValue.isMinValue())
652     return getLhs();
653 
654   return constFoldBinaryOp<IntegerAttr>(operands,
655                                         [](const APInt &a, const APInt &b) {
656                                           return llvm::APIntOps::umax(a, b);
657                                         });
658 }
659 
660 //===----------------------------------------------------------------------===//
661 // MinFOp
662 //===----------------------------------------------------------------------===//
663 
664 OpFoldResult arith::MinFOp::fold(ArrayRef<Attribute> operands) {
665   assert(operands.size() == 2 && "minf takes two operands");
666 
667   // minf(x,x) -> x
668   if (getLhs() == getRhs())
669     return getRhs();
670 
671   // minf(x, +inf) -> x
672   if (matchPattern(getRhs(), m_PosInfFloat()))
673     return getLhs();
674 
675   return constFoldBinaryOp<FloatAttr>(
676       operands,
677       [](const APFloat &a, const APFloat &b) { return llvm::minimum(a, b); });
678 }
679 
680 //===----------------------------------------------------------------------===//
681 // MinSIOp
682 //===----------------------------------------------------------------------===//
683 
684 OpFoldResult MinSIOp::fold(ArrayRef<Attribute> operands) {
685   assert(operands.size() == 2 && "binary operation takes two operands");
686 
687   // minsi(x,x) -> x
688   if (getLhs() == getRhs())
689     return getRhs();
690 
691   APInt intValue;
692   // minsi(x,MIN_INT) -> MIN_INT
693   if (matchPattern(getRhs(), m_ConstantInt(&intValue)) &&
694       intValue.isMinSignedValue())
695     return getRhs();
696 
697   // minsi(x, MAX_INT) -> x
698   if (matchPattern(getRhs(), m_ConstantInt(&intValue)) &&
699       intValue.isMaxSignedValue())
700     return getLhs();
701 
702   return constFoldBinaryOp<IntegerAttr>(operands,
703                                         [](const APInt &a, const APInt &b) {
704                                           return llvm::APIntOps::smin(a, b);
705                                         });
706 }
707 
708 //===----------------------------------------------------------------------===//
709 // MinUIOp
710 //===----------------------------------------------------------------------===//
711 
712 OpFoldResult MinUIOp::fold(ArrayRef<Attribute> operands) {
713   assert(operands.size() == 2 && "binary operation takes two operands");
714 
715   // minui(x,x) -> x
716   if (getLhs() == getRhs())
717     return getRhs();
718 
719   APInt intValue;
720   // minui(x,MIN_INT) -> MIN_INT
721   if (matchPattern(getRhs(), m_ConstantInt(&intValue)) && intValue.isMinValue())
722     return getRhs();
723 
724   // minui(x, MAX_INT) -> x
725   if (matchPattern(getRhs(), m_ConstantInt(&intValue)) && intValue.isMaxValue())
726     return getLhs();
727 
728   return constFoldBinaryOp<IntegerAttr>(operands,
729                                         [](const APInt &a, const APInt &b) {
730                                           return llvm::APIntOps::umin(a, b);
731                                         });
732 }
733 
734 //===----------------------------------------------------------------------===//
735 // MulFOp
736 //===----------------------------------------------------------------------===//
737 
738 OpFoldResult arith::MulFOp::fold(ArrayRef<Attribute> operands) {
739   // mulf(x, 1) -> x
740   if (matchPattern(getRhs(), m_OneFloat()))
741     return getLhs();
742 
743   return constFoldBinaryOp<FloatAttr>(
744       operands, [](const APFloat &a, const APFloat &b) { return a * b; });
745 }
746 
/// Registers the MulFOfNegF rewrite, which is not defined in this file and
/// presumably comes from the generated ArithmeticCanonicalization.inc.
void arith::MulFOp::getCanonicalizationPatterns(RewritePatternSet &patterns,
                                                MLIRContext *context) {
  patterns.add<MulFOfNegF>(context);
}
751 
752 //===----------------------------------------------------------------------===//
753 // DivFOp
754 //===----------------------------------------------------------------------===//
755 
756 OpFoldResult arith::DivFOp::fold(ArrayRef<Attribute> operands) {
757   // divf(x, 1) -> x
758   if (matchPattern(getRhs(), m_OneFloat()))
759     return getLhs();
760 
761   return constFoldBinaryOp<FloatAttr>(
762       operands, [](const APFloat &a, const APFloat &b) { return a / b; });
763 }
764 
/// Registers the DivFOfNegF rewrite, which is not defined in this file and
/// presumably comes from the generated ArithmeticCanonicalization.inc.
void arith::DivFOp::getCanonicalizationPatterns(RewritePatternSet &patterns,
                                                MLIRContext *context) {
  patterns.add<DivFOfNegF>(context);
}
769 
770 //===----------------------------------------------------------------------===//
771 // RemFOp
772 //===----------------------------------------------------------------------===//
773 
774 OpFoldResult arith::RemFOp::fold(ArrayRef<Attribute> operands) {
775   return constFoldBinaryOp<FloatAttr>(operands,
776                                       [](const APFloat &a, const APFloat &b) {
777                                         APFloat result(a);
778                                         (void)result.remainder(b);
779                                         return result;
780                                       });
781 }
782 
783 //===----------------------------------------------------------------------===//
784 // Utility functions for verifying cast ops
785 //===----------------------------------------------------------------------===//
786 
/// Tag type for passing a pack of types as an ordinary function argument.
/// The alias is a pointer to `std::tuple<Types...>`, so `type_list<A, B>()` is
/// a null pointer whose static type carries `A, B` for template argument
/// deduction (see getUnderlyingType below).
template <typename... Types>
using type_list = std::tuple<Types...> *;
789 
790 /// Returns a non-null type only if the provided type is one of the allowed
791 /// types or one of the allowed shaped types of the allowed types. Returns the
792 /// element type if a valid shaped type is provided.
793 template <typename... ShapedTypes, typename... ElementTypes>
794 static Type getUnderlyingType(Type type, type_list<ShapedTypes...>,
795                               type_list<ElementTypes...>) {
796   if (type.isa<ShapedType>() && !type.isa<ShapedTypes...>())
797     return {};
798 
799   auto underlyingType = getElementTypeOrSelf(type);
800   if (!underlyingType.isa<ElementTypes...>())
801     return {};
802 
803   return underlyingType;
804 }
805 
806 /// Get allowed underlying types for vectors and tensors.
807 template <typename... ElementTypes>
808 static Type getTypeIfLike(Type type) {
809   return getUnderlyingType(type, type_list<VectorType, TensorType>(),
810                            type_list<ElementTypes...>());
811 }
812 
813 /// Get allowed underlying types for vectors, tensors, and memrefs.
814 template <typename... ElementTypes>
815 static Type getTypeIfLikeOrMemRef(Type type) {
816   return getUnderlyingType(type,
817                            type_list<VectorType, TensorType, MemRefType>(),
818                            type_list<ElementTypes...>());
819 }
820 
821 static bool areValidCastInputsAndOutputs(TypeRange inputs, TypeRange outputs) {
822   return inputs.size() == 1 && outputs.size() == 1 &&
823          succeeded(verifyCompatibleShapes(inputs.front(), outputs.front()));
824 }
825 
826 //===----------------------------------------------------------------------===//
827 // Verifiers for integer and floating point extension/truncation ops
828 //===----------------------------------------------------------------------===//
829 
830 // Extend ops can only extend to a wider type.
831 template <typename ValType, typename Op>
832 static LogicalResult verifyExtOp(Op op) {
833   Type srcType = getElementTypeOrSelf(op.getIn().getType());
834   Type dstType = getElementTypeOrSelf(op.getType());
835 
836   if (srcType.cast<ValType>().getWidth() >= dstType.cast<ValType>().getWidth())
837     return op.emitError("result type ")
838            << dstType << " must be wider than operand type " << srcType;
839 
840   return success();
841 }
842 
843 // Truncate ops can only truncate to a shorter type.
844 template <typename ValType, typename Op>
845 static LogicalResult verifyTruncateOp(Op op) {
846   Type srcType = getElementTypeOrSelf(op.getIn().getType());
847   Type dstType = getElementTypeOrSelf(op.getType());
848 
849   if (srcType.cast<ValType>().getWidth() <= dstType.cast<ValType>().getWidth())
850     return op.emitError("result type ")
851            << dstType << " must be shorter than operand type " << srcType;
852 
853   return success();
854 }
855 
856 /// Validate a cast that changes the width of a type.
857 template <template <typename> class WidthComparator, typename... ElementTypes>
858 static bool checkWidthChangeCast(TypeRange inputs, TypeRange outputs) {
859   if (!areValidCastInputsAndOutputs(inputs, outputs))
860     return false;
861 
862   auto srcType = getTypeIfLike<ElementTypes...>(inputs.front());
863   auto dstType = getTypeIfLike<ElementTypes...>(outputs.front());
864   if (!srcType || !dstType)
865     return false;
866 
867   return WidthComparator<unsigned>()(dstType.getIntOrFloatBitWidth(),
868                                      srcType.getIntOrFloatBitWidth());
869 }
870 
871 //===----------------------------------------------------------------------===//
872 // ExtUIOp
873 //===----------------------------------------------------------------------===//
874 
875 OpFoldResult arith::ExtUIOp::fold(ArrayRef<Attribute> operands) {
876   if (auto lhs = getIn().getDefiningOp<ExtUIOp>()) {
877     getInMutable().assign(lhs.getIn());
878     return getResult();
879   }
880   Type resType = getType();
881   unsigned bitWidth;
882   if (auto shapedType = resType.dyn_cast<ShapedType>())
883     bitWidth = shapedType.getElementTypeBitWidth();
884   else
885     bitWidth = resType.getIntOrFloatBitWidth();
886   return constFoldCastOp<IntegerAttr, IntegerAttr>(
887       operands, getType(), [bitWidth](const APInt &a, bool &castStatus) {
888         return a.zext(bitWidth);
889       });
890 }
891 
892 bool arith::ExtUIOp::areCastCompatible(TypeRange inputs, TypeRange outputs) {
893   return checkWidthChangeCast<std::greater, IntegerType>(inputs, outputs);
894 }
895 
896 LogicalResult arith::ExtUIOp::verify() {
897   return verifyExtOp<IntegerType>(*this);
898 }
899 
900 //===----------------------------------------------------------------------===//
901 // ExtSIOp
902 //===----------------------------------------------------------------------===//
903 
904 OpFoldResult arith::ExtSIOp::fold(ArrayRef<Attribute> operands) {
905   if (auto lhs = getIn().getDefiningOp<ExtSIOp>()) {
906     getInMutable().assign(lhs.getIn());
907     return getResult();
908   }
909   Type resType = getType();
910   unsigned bitWidth;
911   if (auto shapedType = resType.dyn_cast<ShapedType>())
912     bitWidth = shapedType.getElementTypeBitWidth();
913   else
914     bitWidth = resType.getIntOrFloatBitWidth();
915   return constFoldCastOp<IntegerAttr, IntegerAttr>(
916       operands, getType(), [bitWidth](const APInt &a, bool &castStatus) {
917         return a.sext(bitWidth);
918       });
919 }
920 
921 bool arith::ExtSIOp::areCastCompatible(TypeRange inputs, TypeRange outputs) {
922   return checkWidthChangeCast<std::greater, IntegerType>(inputs, outputs);
923 }
924 
925 void arith::ExtSIOp::getCanonicalizationPatterns(RewritePatternSet &patterns,
926                                                  MLIRContext *context) {
927   patterns.add<ExtSIOfExtUI>(context);
928 }
929 
930 LogicalResult arith::ExtSIOp::verify() {
931   return verifyExtOp<IntegerType>(*this);
932 }
933 
934 //===----------------------------------------------------------------------===//
935 // ExtFOp
936 //===----------------------------------------------------------------------===//
937 
938 bool arith::ExtFOp::areCastCompatible(TypeRange inputs, TypeRange outputs) {
939   return checkWidthChangeCast<std::greater, FloatType>(inputs, outputs);
940 }
941 
942 LogicalResult arith::ExtFOp::verify() { return verifyExtOp<FloatType>(*this); }
943 
944 //===----------------------------------------------------------------------===//
945 // TruncIOp
946 //===----------------------------------------------------------------------===//
947 
948 OpFoldResult arith::TruncIOp::fold(ArrayRef<Attribute> operands) {
949   assert(operands.size() == 1 && "unary operation takes one operand");
950 
951   // trunci(zexti(a)) -> a
952   // trunci(sexti(a)) -> a
953   if (matchPattern(getOperand(), m_Op<arith::ExtUIOp>()) ||
954       matchPattern(getOperand(), m_Op<arith::ExtSIOp>()))
955     return getOperand().getDefiningOp()->getOperand(0);
956 
957   // trunci(trunci(a)) -> trunci(a))
958   if (matchPattern(getOperand(), m_Op<arith::TruncIOp>())) {
959     setOperand(getOperand().getDefiningOp()->getOperand(0));
960     return getResult();
961   }
962 
963   Type resType = getType();
964   unsigned bitWidth;
965   if (auto shapedType = resType.dyn_cast<ShapedType>())
966     bitWidth = shapedType.getElementTypeBitWidth();
967   else
968     bitWidth = resType.getIntOrFloatBitWidth();
969 
970   return constFoldCastOp<IntegerAttr, IntegerAttr>(
971       operands, getType(), [bitWidth](const APInt &a, bool &castStatus) {
972         return a.trunc(bitWidth);
973       });
974 }
975 
976 bool arith::TruncIOp::areCastCompatible(TypeRange inputs, TypeRange outputs) {
977   return checkWidthChangeCast<std::less, IntegerType>(inputs, outputs);
978 }
979 
980 LogicalResult arith::TruncIOp::verify() {
981   return verifyTruncateOp<IntegerType>(*this);
982 }
983 
984 //===----------------------------------------------------------------------===//
985 // TruncFOp
986 //===----------------------------------------------------------------------===//
987 
988 /// Perform safe const propagation for truncf, i.e. only propagate if FP value
989 /// can be represented without precision loss or rounding.
990 OpFoldResult arith::TruncFOp::fold(ArrayRef<Attribute> operands) {
991   assert(operands.size() == 1 && "unary operation takes one operand");
992 
993   auto constOperand = operands.front();
994   if (!constOperand || !constOperand.isa<FloatAttr>())
995     return {};
996 
997   // Convert to target type via 'double'.
998   double sourceValue =
999       constOperand.dyn_cast<FloatAttr>().getValue().convertToDouble();
1000   auto targetAttr = FloatAttr::get(getType(), sourceValue);
1001 
1002   // Propagate if constant's value does not change after truncation.
1003   if (sourceValue == targetAttr.getValue().convertToDouble())
1004     return targetAttr;
1005 
1006   return {};
1007 }
1008 
1009 bool arith::TruncFOp::areCastCompatible(TypeRange inputs, TypeRange outputs) {
1010   return checkWidthChangeCast<std::less, FloatType>(inputs, outputs);
1011 }
1012 
1013 LogicalResult arith::TruncFOp::verify() {
1014   return verifyTruncateOp<FloatType>(*this);
1015 }
1016 
1017 //===----------------------------------------------------------------------===//
1018 // AndIOp
1019 //===----------------------------------------------------------------------===//
1020 
1021 void arith::AndIOp::getCanonicalizationPatterns(RewritePatternSet &patterns,
1022                                                 MLIRContext *context) {
1023   patterns.add<AndOfExtUI, AndOfExtSI>(context);
1024 }
1025 
1026 //===----------------------------------------------------------------------===//
1027 // OrIOp
1028 //===----------------------------------------------------------------------===//
1029 
1030 void arith::OrIOp::getCanonicalizationPatterns(RewritePatternSet &patterns,
1031                                                MLIRContext *context) {
1032   patterns.add<OrOfExtUI, OrOfExtSI>(context);
1033 }
1034 
1035 //===----------------------------------------------------------------------===//
1036 // Verifiers for casts between integers and floats.
1037 //===----------------------------------------------------------------------===//
1038 
1039 template <typename From, typename To>
1040 static bool checkIntFloatCast(TypeRange inputs, TypeRange outputs) {
1041   if (!areValidCastInputsAndOutputs(inputs, outputs))
1042     return false;
1043 
1044   auto srcType = getTypeIfLike<From>(inputs.front());
1045   auto dstType = getTypeIfLike<To>(outputs.back());
1046 
1047   return srcType && dstType;
1048 }
1049 
1050 //===----------------------------------------------------------------------===//
1051 // UIToFPOp
1052 //===----------------------------------------------------------------------===//
1053 
1054 bool arith::UIToFPOp::areCastCompatible(TypeRange inputs, TypeRange outputs) {
1055   return checkIntFloatCast<IntegerType, FloatType>(inputs, outputs);
1056 }
1057 
1058 OpFoldResult arith::UIToFPOp::fold(ArrayRef<Attribute> operands) {
1059   Type resType = getType();
1060   Type resEleType;
1061   if (auto shapedType = resType.dyn_cast<ShapedType>())
1062     resEleType = shapedType.getElementType();
1063   else
1064     resEleType = resType;
1065   return constFoldCastOp<IntegerAttr, FloatAttr>(
1066       operands, getType(), [&resEleType](const APInt &a, bool &castStatus) {
1067         FloatType floatTy = resEleType.cast<FloatType>();
1068         APFloat apf(floatTy.getFloatSemantics(),
1069                     APInt::getZero(floatTy.getWidth()));
1070         apf.convertFromAPInt(a, /*IsSigned=*/false,
1071                              APFloat::rmNearestTiesToEven);
1072         return apf;
1073       });
1074 }
1075 
1076 //===----------------------------------------------------------------------===//
1077 // SIToFPOp
1078 //===----------------------------------------------------------------------===//
1079 
1080 bool arith::SIToFPOp::areCastCompatible(TypeRange inputs, TypeRange outputs) {
1081   return checkIntFloatCast<IntegerType, FloatType>(inputs, outputs);
1082 }
1083 
1084 OpFoldResult arith::SIToFPOp::fold(ArrayRef<Attribute> operands) {
1085   Type resType = getType();
1086   Type resEleType;
1087   if (auto shapedType = resType.dyn_cast<ShapedType>())
1088     resEleType = shapedType.getElementType();
1089   else
1090     resEleType = resType;
1091   return constFoldCastOp<IntegerAttr, FloatAttr>(
1092       operands, getType(), [&resEleType](const APInt &a, bool &castStatus) {
1093         FloatType floatTy = resEleType.cast<FloatType>();
1094         APFloat apf(floatTy.getFloatSemantics(),
1095                     APInt::getZero(floatTy.getWidth()));
1096         apf.convertFromAPInt(a, /*IsSigned=*/true,
1097                              APFloat::rmNearestTiesToEven);
1098         return apf;
1099       });
1100 }
1101 //===----------------------------------------------------------------------===//
1102 // FPToUIOp
1103 //===----------------------------------------------------------------------===//
1104 
1105 bool arith::FPToUIOp::areCastCompatible(TypeRange inputs, TypeRange outputs) {
1106   return checkIntFloatCast<FloatType, IntegerType>(inputs, outputs);
1107 }
1108 
1109 OpFoldResult arith::FPToUIOp::fold(ArrayRef<Attribute> operands) {
1110   Type resType = getType();
1111   Type resEleType;
1112   if (auto shapedType = resType.dyn_cast<ShapedType>())
1113     resEleType = shapedType.getElementType();
1114   else
1115     resEleType = resType;
1116   return constFoldCastOp<FloatAttr, IntegerAttr>(
1117       operands, getType(), [&resEleType](const APFloat &a, bool &castStatus) {
1118         IntegerType intTy = resEleType.cast<IntegerType>();
1119         bool ignored;
1120         APSInt api(intTy.getWidth(), /*isUnsigned=*/true);
1121         castStatus = APFloat::opInvalidOp !=
1122                      a.convertToInteger(api, APFloat::rmTowardZero, &ignored);
1123         return api;
1124       });
1125 }
1126 
1127 //===----------------------------------------------------------------------===//
1128 // FPToSIOp
1129 //===----------------------------------------------------------------------===//
1130 
1131 bool arith::FPToSIOp::areCastCompatible(TypeRange inputs, TypeRange outputs) {
1132   return checkIntFloatCast<FloatType, IntegerType>(inputs, outputs);
1133 }
1134 
1135 OpFoldResult arith::FPToSIOp::fold(ArrayRef<Attribute> operands) {
1136   Type resType = getType();
1137   Type resEleType;
1138   if (auto shapedType = resType.dyn_cast<ShapedType>())
1139     resEleType = shapedType.getElementType();
1140   else
1141     resEleType = resType;
1142   return constFoldCastOp<FloatAttr, IntegerAttr>(
1143       operands, getType(), [&resEleType](const APFloat &a, bool &castStatus) {
1144         IntegerType intTy = resEleType.cast<IntegerType>();
1145         bool ignored;
1146         APSInt api(intTy.getWidth(), /*isUnsigned=*/false);
1147         castStatus = APFloat::opInvalidOp !=
1148                      a.convertToInteger(api, APFloat::rmTowardZero, &ignored);
1149         return api;
1150       });
1151 }
1152 
1153 //===----------------------------------------------------------------------===//
1154 // IndexCastOp
1155 //===----------------------------------------------------------------------===//
1156 
1157 bool arith::IndexCastOp::areCastCompatible(TypeRange inputs,
1158                                            TypeRange outputs) {
1159   if (!areValidCastInputsAndOutputs(inputs, outputs))
1160     return false;
1161 
1162   auto srcType = getTypeIfLikeOrMemRef<IntegerType, IndexType>(inputs.front());
1163   auto dstType = getTypeIfLikeOrMemRef<IntegerType, IndexType>(outputs.front());
1164   if (!srcType || !dstType)
1165     return false;
1166 
1167   return (srcType.isIndex() && dstType.isSignlessInteger()) ||
1168          (srcType.isSignlessInteger() && dstType.isIndex());
1169 }
1170 
1171 OpFoldResult arith::IndexCastOp::fold(ArrayRef<Attribute> operands) {
1172   // index_cast(constant) -> constant
1173   // A little hack because we go through int. Otherwise, the size of the
1174   // constant might need to change.
1175   if (auto value = operands[0].dyn_cast_or_null<IntegerAttr>())
1176     return IntegerAttr::get(getType(), value.getInt());
1177 
1178   return {};
1179 }
1180 
1181 void arith::IndexCastOp::getCanonicalizationPatterns(
1182     RewritePatternSet &patterns, MLIRContext *context) {
1183   patterns.add<IndexCastOfIndexCast, IndexCastOfExtSI>(context);
1184 }
1185 
1186 //===----------------------------------------------------------------------===//
1187 // BitcastOp
1188 //===----------------------------------------------------------------------===//
1189 
1190 bool arith::BitcastOp::areCastCompatible(TypeRange inputs, TypeRange outputs) {
1191   if (!areValidCastInputsAndOutputs(inputs, outputs))
1192     return false;
1193 
1194   auto srcType =
1195       getTypeIfLikeOrMemRef<IntegerType, IndexType, FloatType>(inputs.front());
1196   auto dstType =
1197       getTypeIfLikeOrMemRef<IntegerType, IndexType, FloatType>(outputs.front());
1198   if (!srcType || !dstType)
1199     return false;
1200 
1201   return srcType.getIntOrFloatBitWidth() == dstType.getIntOrFloatBitWidth();
1202 }
1203 
1204 OpFoldResult arith::BitcastOp::fold(ArrayRef<Attribute> operands) {
1205   assert(operands.size() == 1 && "bitcast op expects 1 operand");
1206 
1207   auto resType = getType();
1208   auto operand = operands[0];
1209   if (!operand)
1210     return {};
1211 
1212   /// Bitcast dense elements.
1213   if (auto denseAttr = operand.dyn_cast_or_null<DenseElementsAttr>())
1214     return denseAttr.bitcast(resType.cast<ShapedType>().getElementType());
1215   /// Other shaped types unhandled.
1216   if (resType.isa<ShapedType>())
1217     return {};
1218 
1219   /// Bitcast integer or float to integer or float.
1220   APInt bits = operand.isa<FloatAttr>()
1221                    ? operand.cast<FloatAttr>().getValue().bitcastToAPInt()
1222                    : operand.cast<IntegerAttr>().getValue();
1223 
1224   if (auto resFloatType = resType.dyn_cast<FloatType>())
1225     return FloatAttr::get(resType,
1226                           APFloat(resFloatType.getFloatSemantics(), bits));
1227   return IntegerAttr::get(resType, bits);
1228 }
1229 
1230 void arith::BitcastOp::getCanonicalizationPatterns(RewritePatternSet &patterns,
1231                                                    MLIRContext *context) {
1232   patterns.add<BitcastOfBitcast>(context);
1233 }
1234 
1235 //===----------------------------------------------------------------------===//
1236 // Helpers for compare ops
1237 //===----------------------------------------------------------------------===//
1238 
1239 /// Return the type of the same shape (scalar, vector or tensor) containing i1.
1240 static Type getI1SameShape(Type type) {
1241   auto i1Type = IntegerType::get(type.getContext(), 1);
1242   if (auto tensorType = type.dyn_cast<RankedTensorType>())
1243     return RankedTensorType::get(tensorType.getShape(), i1Type);
1244   if (type.isa<UnrankedTensorType>())
1245     return UnrankedTensorType::get(i1Type);
1246   if (auto vectorType = type.dyn_cast<VectorType>())
1247     return VectorType::get(vectorType.getShape(), i1Type,
1248                            vectorType.getNumScalableDims());
1249   return i1Type;
1250 }
1251 
1252 //===----------------------------------------------------------------------===//
1253 // CmpIOp
1254 //===----------------------------------------------------------------------===//
1255 
1256 /// Compute `lhs` `pred` `rhs`, where `pred` is one of the known integer
1257 /// comparison predicates.
1258 bool mlir::arith::applyCmpPredicate(arith::CmpIPredicate predicate,
1259                                     const APInt &lhs, const APInt &rhs) {
1260   switch (predicate) {
1261   case arith::CmpIPredicate::eq:
1262     return lhs.eq(rhs);
1263   case arith::CmpIPredicate::ne:
1264     return lhs.ne(rhs);
1265   case arith::CmpIPredicate::slt:
1266     return lhs.slt(rhs);
1267   case arith::CmpIPredicate::sle:
1268     return lhs.sle(rhs);
1269   case arith::CmpIPredicate::sgt:
1270     return lhs.sgt(rhs);
1271   case arith::CmpIPredicate::sge:
1272     return lhs.sge(rhs);
1273   case arith::CmpIPredicate::ult:
1274     return lhs.ult(rhs);
1275   case arith::CmpIPredicate::ule:
1276     return lhs.ule(rhs);
1277   case arith::CmpIPredicate::ugt:
1278     return lhs.ugt(rhs);
1279   case arith::CmpIPredicate::uge:
1280     return lhs.uge(rhs);
1281   }
1282   llvm_unreachable("unknown cmpi predicate kind");
1283 }
1284 
1285 /// Returns true if the predicate is true for two equal operands.
1286 static bool applyCmpPredicateToEqualOperands(arith::CmpIPredicate predicate) {
1287   switch (predicate) {
1288   case arith::CmpIPredicate::eq:
1289   case arith::CmpIPredicate::sle:
1290   case arith::CmpIPredicate::sge:
1291   case arith::CmpIPredicate::ule:
1292   case arith::CmpIPredicate::uge:
1293     return true;
1294   case arith::CmpIPredicate::ne:
1295   case arith::CmpIPredicate::slt:
1296   case arith::CmpIPredicate::sgt:
1297   case arith::CmpIPredicate::ult:
1298   case arith::CmpIPredicate::ugt:
1299     return false;
1300   }
1301   llvm_unreachable("unknown cmpi predicate kind");
1302 }
1303 
1304 static Attribute getBoolAttribute(Type type, MLIRContext *ctx, bool value) {
1305   auto boolAttr = BoolAttr::get(ctx, value);
1306   ShapedType shapedType = type.dyn_cast_or_null<ShapedType>();
1307   if (!shapedType)
1308     return boolAttr;
1309   return DenseElementsAttr::get(shapedType, boolAttr);
1310 }
1311 
1312 OpFoldResult arith::CmpIOp::fold(ArrayRef<Attribute> operands) {
1313   assert(operands.size() == 2 && "cmpi takes two operands");
1314 
1315   // cmpi(pred, x, x)
1316   if (getLhs() == getRhs()) {
1317     auto val = applyCmpPredicateToEqualOperands(getPredicate());
1318     return getBoolAttribute(getType(), getContext(), val);
1319   }
1320 
1321   if (matchPattern(getRhs(), m_Zero())) {
1322     if (auto extOp = getLhs().getDefiningOp<ExtSIOp>()) {
1323       // extsi(%x : i1 -> iN) != 0  ->  %x
1324       if (extOp.getOperand().getType().cast<IntegerType>().getWidth() == 1 &&
1325           getPredicate() == arith::CmpIPredicate::ne)
1326         return extOp.getOperand();
1327     }
1328     if (auto extOp = getLhs().getDefiningOp<ExtUIOp>()) {
1329       // extui(%x : i1 -> iN) != 0  ->  %x
1330       if (extOp.getOperand().getType().cast<IntegerType>().getWidth() == 1 &&
1331           getPredicate() == arith::CmpIPredicate::ne)
1332         return extOp.getOperand();
1333     }
1334   }
1335 
1336   // Move constant to the right side.
1337   if (operands[0] && !operands[1]) {
1338     // Do not use invertPredicate, as it will change eq to ne and vice versa.
1339     using Pred = CmpIPredicate;
1340     const std::pair<Pred, Pred> invPreds[] = {
1341         {Pred::slt, Pred::sgt}, {Pred::sgt, Pred::slt}, {Pred::sle, Pred::sge},
1342         {Pred::sge, Pred::sle}, {Pred::ult, Pred::ugt}, {Pred::ugt, Pred::ult},
1343         {Pred::ule, Pred::uge}, {Pred::uge, Pred::ule}, {Pred::eq, Pred::eq},
1344         {Pred::ne, Pred::ne},
1345     };
1346     Pred origPred = getPredicate();
1347     for (auto pred : invPreds) {
1348       if (origPred == pred.first) {
1349         setPredicateAttr(CmpIPredicateAttr::get(getContext(), pred.second));
1350         Value lhs = getLhs();
1351         Value rhs = getRhs();
1352         getLhsMutable().assign(rhs);
1353         getRhsMutable().assign(lhs);
1354         return getResult();
1355       }
1356     }
1357     llvm_unreachable("unknown cmpi predicate kind");
1358   }
1359 
1360   auto lhs = operands.front().dyn_cast_or_null<IntegerAttr>();
1361   if (!lhs)
1362     return {};
1363 
1364   // We are moving constants to the right side; So if lhs is constant rhs is
1365   // guaranteed to be a constant.
1366   auto rhs = operands.back().cast<IntegerAttr>();
1367 
1368   auto val = applyCmpPredicate(getPredicate(), lhs.getValue(), rhs.getValue());
1369   return BoolAttr::get(getContext(), val);
1370 }
1371 
1372 void arith::CmpIOp::getCanonicalizationPatterns(RewritePatternSet &patterns,
1373                                                 MLIRContext *context) {
1374   patterns.insert<CmpIExtSI, CmpIExtUI>(context);
1375 }
1376 
1377 //===----------------------------------------------------------------------===//
1378 // CmpFOp
1379 //===----------------------------------------------------------------------===//
1380 
1381 /// Compute `lhs` `pred` `rhs`, where `pred` is one of the known floating point
1382 /// comparison predicates.
1383 bool mlir::arith::applyCmpPredicate(arith::CmpFPredicate predicate,
1384                                     const APFloat &lhs, const APFloat &rhs) {
1385   auto cmpResult = lhs.compare(rhs);
1386   switch (predicate) {
1387   case arith::CmpFPredicate::AlwaysFalse:
1388     return false;
1389   case arith::CmpFPredicate::OEQ:
1390     return cmpResult == APFloat::cmpEqual;
1391   case arith::CmpFPredicate::OGT:
1392     return cmpResult == APFloat::cmpGreaterThan;
1393   case arith::CmpFPredicate::OGE:
1394     return cmpResult == APFloat::cmpGreaterThan ||
1395            cmpResult == APFloat::cmpEqual;
1396   case arith::CmpFPredicate::OLT:
1397     return cmpResult == APFloat::cmpLessThan;
1398   case arith::CmpFPredicate::OLE:
1399     return cmpResult == APFloat::cmpLessThan || cmpResult == APFloat::cmpEqual;
1400   case arith::CmpFPredicate::ONE:
1401     return cmpResult != APFloat::cmpUnordered && cmpResult != APFloat::cmpEqual;
1402   case arith::CmpFPredicate::ORD:
1403     return cmpResult != APFloat::cmpUnordered;
1404   case arith::CmpFPredicate::UEQ:
1405     return cmpResult == APFloat::cmpUnordered || cmpResult == APFloat::cmpEqual;
1406   case arith::CmpFPredicate::UGT:
1407     return cmpResult == APFloat::cmpUnordered ||
1408            cmpResult == APFloat::cmpGreaterThan;
1409   case arith::CmpFPredicate::UGE:
1410     return cmpResult == APFloat::cmpUnordered ||
1411            cmpResult == APFloat::cmpGreaterThan ||
1412            cmpResult == APFloat::cmpEqual;
1413   case arith::CmpFPredicate::ULT:
1414     return cmpResult == APFloat::cmpUnordered ||
1415            cmpResult == APFloat::cmpLessThan;
1416   case arith::CmpFPredicate::ULE:
1417     return cmpResult == APFloat::cmpUnordered ||
1418            cmpResult == APFloat::cmpLessThan || cmpResult == APFloat::cmpEqual;
1419   case arith::CmpFPredicate::UNE:
1420     return cmpResult != APFloat::cmpEqual;
1421   case arith::CmpFPredicate::UNO:
1422     return cmpResult == APFloat::cmpUnordered;
1423   case arith::CmpFPredicate::AlwaysTrue:
1424     return true;
1425   }
1426   llvm_unreachable("unknown cmpf predicate kind");
1427 }
1428 
1429 OpFoldResult arith::CmpFOp::fold(ArrayRef<Attribute> operands) {
1430   assert(operands.size() == 2 && "cmpf takes two operands");
1431 
1432   auto lhs = operands.front().dyn_cast_or_null<FloatAttr>();
1433   auto rhs = operands.back().dyn_cast_or_null<FloatAttr>();
1434 
1435   // If one operand is NaN, making them both NaN does not change the result.
1436   if (lhs && lhs.getValue().isNaN())
1437     rhs = lhs;
1438   if (rhs && rhs.getValue().isNaN())
1439     lhs = rhs;
1440 
1441   if (!lhs || !rhs)
1442     return {};
1443 
1444   auto val = applyCmpPredicate(getPredicate(), lhs.getValue(), rhs.getValue());
1445   return BoolAttr::get(getContext(), val);
1446 }
1447 
1448 class CmpFIntToFPConst final : public OpRewritePattern<CmpFOp> {
1449 public:
1450   using OpRewritePattern<CmpFOp>::OpRewritePattern;
1451 
1452   static CmpIPredicate convertToIntegerPredicate(CmpFPredicate pred,
1453                                                  bool isUnsigned) {
1454     using namespace arith;
1455     switch (pred) {
1456     case CmpFPredicate::UEQ:
1457     case CmpFPredicate::OEQ:
1458       return CmpIPredicate::eq;
1459     case CmpFPredicate::UGT:
1460     case CmpFPredicate::OGT:
1461       return isUnsigned ? CmpIPredicate::ugt : CmpIPredicate::sgt;
1462     case CmpFPredicate::UGE:
1463     case CmpFPredicate::OGE:
1464       return isUnsigned ? CmpIPredicate::uge : CmpIPredicate::sge;
1465     case CmpFPredicate::ULT:
1466     case CmpFPredicate::OLT:
1467       return isUnsigned ? CmpIPredicate::ult : CmpIPredicate::slt;
1468     case CmpFPredicate::ULE:
1469     case CmpFPredicate::OLE:
1470       return isUnsigned ? CmpIPredicate::ule : CmpIPredicate::sle;
1471     case CmpFPredicate::UNE:
1472     case CmpFPredicate::ONE:
1473       return CmpIPredicate::ne;
1474     default:
1475       llvm_unreachable("Unexpected predicate!");
1476     }
1477   }
1478 
1479   LogicalResult matchAndRewrite(CmpFOp op,
1480                                 PatternRewriter &rewriter) const override {
1481     FloatAttr flt;
1482     if (!matchPattern(op.getRhs(), m_Constant(&flt)))
1483       return failure();
1484 
1485     const APFloat &rhs = flt.getValue();
1486 
1487     // Don't attempt to fold a nan.
1488     if (rhs.isNaN())
1489       return failure();
1490 
1491     // Get the width of the mantissa.  We don't want to hack on conversions that
1492     // might lose information from the integer, e.g. "i64 -> float"
1493     FloatType floatTy = op.getRhs().getType().cast<FloatType>();
1494     int mantissaWidth = floatTy.getFPMantissaWidth();
1495     if (mantissaWidth <= 0)
1496       return failure();
1497 
1498     bool isUnsigned;
1499     Value intVal;
1500 
1501     if (auto si = op.getLhs().getDefiningOp<SIToFPOp>()) {
1502       isUnsigned = false;
1503       intVal = si.getIn();
1504     } else if (auto ui = op.getLhs().getDefiningOp<UIToFPOp>()) {
1505       isUnsigned = true;
1506       intVal = ui.getIn();
1507     } else {
1508       return failure();
1509     }
1510 
1511     // Check to see that the input is converted from an integer type that is
1512     // small enough that preserves all bits.
1513     auto intTy = intVal.getType().cast<IntegerType>();
1514     auto intWidth = intTy.getWidth();
1515 
1516     // Number of bits representing values, as opposed to the sign
1517     auto valueBits = isUnsigned ? intWidth : (intWidth - 1);
1518 
1519     // Following test does NOT adjust intWidth downwards for signed inputs,
1520     // because the most negative value still requires all the mantissa bits
1521     // to distinguish it from one less than that value.
1522     if ((int)intWidth > mantissaWidth) {
1523       // Conversion would lose accuracy. Check if loss can impact comparison.
1524       int exponent = ilogb(rhs);
1525       if (exponent == APFloat::IEK_Inf) {
1526         int maxExponent = ilogb(APFloat::getLargest(rhs.getSemantics()));
1527         if (maxExponent < (int)valueBits) {
1528           // Conversion could create infinity.
1529           return failure();
1530         }
1531       } else {
1532         // Note that if rhs is zero or NaN, then Exp is negative
1533         // and first condition is trivially false.
1534         if (mantissaWidth <= exponent && exponent <= (int)valueBits) {
1535           // Conversion could affect comparison.
1536           return failure();
1537         }
1538       }
1539     }
1540 
1541     // Convert to equivalent cmpi predicate
1542     CmpIPredicate pred;
1543     switch (op.getPredicate()) {
1544     case CmpFPredicate::ORD:
1545       // Int to fp conversion doesn't create a nan (ord checks neither is a nan)
1546       rewriter.replaceOpWithNewOp<ConstantIntOp>(op, /*value=*/true,
1547                                                  /*width=*/1);
1548       return success();
1549     case CmpFPredicate::UNO:
1550       // Int to fp conversion doesn't create a nan (uno checks either is a nan)
1551       rewriter.replaceOpWithNewOp<ConstantIntOp>(op, /*value=*/false,
1552                                                  /*width=*/1);
1553       return success();
1554     default:
1555       pred = convertToIntegerPredicate(op.getPredicate(), isUnsigned);
1556       break;
1557     }
1558 
1559     if (!isUnsigned) {
1560       // If the rhs value is > SignedMax, fold the comparison.  This handles
1561       // +INF and large values.
1562       APFloat signedMax(rhs.getSemantics());
1563       signedMax.convertFromAPInt(APInt::getSignedMaxValue(intWidth), true,
1564                                  APFloat::rmNearestTiesToEven);
1565       if (signedMax < rhs) { // smax < 13123.0
1566         if (pred == CmpIPredicate::ne || pred == CmpIPredicate::slt ||
1567             pred == CmpIPredicate::sle)
1568           rewriter.replaceOpWithNewOp<ConstantIntOp>(op, /*value=*/true,
1569                                                      /*width=*/1);
1570         else
1571           rewriter.replaceOpWithNewOp<ConstantIntOp>(op, /*value=*/false,
1572                                                      /*width=*/1);
1573         return success();
1574       }
1575     } else {
1576       // If the rhs value is > UnsignedMax, fold the comparison. This handles
1577       // +INF and large values.
1578       APFloat unsignedMax(rhs.getSemantics());
1579       unsignedMax.convertFromAPInt(APInt::getMaxValue(intWidth), false,
1580                                    APFloat::rmNearestTiesToEven);
1581       if (unsignedMax < rhs) { // umax < 13123.0
1582         if (pred == CmpIPredicate::ne || pred == CmpIPredicate::ult ||
1583             pred == CmpIPredicate::ule)
1584           rewriter.replaceOpWithNewOp<ConstantIntOp>(op, /*value=*/true,
1585                                                      /*width=*/1);
1586         else
1587           rewriter.replaceOpWithNewOp<ConstantIntOp>(op, /*value=*/false,
1588                                                      /*width=*/1);
1589         return success();
1590       }
1591     }
1592 
1593     if (!isUnsigned) {
1594       // See if the rhs value is < SignedMin.
1595       APFloat signedMin(rhs.getSemantics());
1596       signedMin.convertFromAPInt(APInt::getSignedMinValue(intWidth), true,
1597                                  APFloat::rmNearestTiesToEven);
1598       if (signedMin > rhs) { // smin > 12312.0
1599         if (pred == CmpIPredicate::ne || pred == CmpIPredicate::sgt ||
1600             pred == CmpIPredicate::sge)
1601           rewriter.replaceOpWithNewOp<ConstantIntOp>(op, /*value=*/true,
1602                                                      /*width=*/1);
1603         else
1604           rewriter.replaceOpWithNewOp<ConstantIntOp>(op, /*value=*/false,
1605                                                      /*width=*/1);
1606         return success();
1607       }
1608     } else {
1609       // See if the rhs value is < UnsignedMin.
1610       APFloat unsignedMin(rhs.getSemantics());
1611       unsignedMin.convertFromAPInt(APInt::getMinValue(intWidth), false,
1612                                    APFloat::rmNearestTiesToEven);
1613       if (unsignedMin > rhs) { // umin > 12312.0
1614         if (pred == CmpIPredicate::ne || pred == CmpIPredicate::ugt ||
1615             pred == CmpIPredicate::uge)
1616           rewriter.replaceOpWithNewOp<ConstantIntOp>(op, /*value=*/true,
1617                                                      /*width=*/1);
1618         else
1619           rewriter.replaceOpWithNewOp<ConstantIntOp>(op, /*value=*/false,
1620                                                      /*width=*/1);
1621         return success();
1622       }
1623     }
1624 
1625     // Okay, now we know that the FP constant fits in the range [SMIN, SMAX] or
1626     // [0, UMAX], but it may still be fractional.  See if it is fractional by
1627     // casting the FP value to the integer value and back, checking for
1628     // equality. Don't do this for zero, because -0.0 is not fractional.
1629     bool ignored;
1630     APSInt rhsInt(intWidth, isUnsigned);
1631     if (APFloat::opInvalidOp ==
1632         rhs.convertToInteger(rhsInt, APFloat::rmTowardZero, &ignored)) {
1633       // Undefined behavior invoked - the destination type can't represent
1634       // the input constant.
1635       return failure();
1636     }
1637 
1638     if (!rhs.isZero()) {
1639       APFloat apf(floatTy.getFloatSemantics(),
1640                   APInt::getZero(floatTy.getWidth()));
1641       apf.convertFromAPInt(rhsInt, !isUnsigned, APFloat::rmNearestTiesToEven);
1642 
1643       bool equal = apf == rhs;
1644       if (!equal) {
1645         // If we had a comparison against a fractional value, we have to adjust
1646         // the compare predicate and sometimes the value.  rhsInt is rounded
1647         // towards zero at this point.
1648         switch (pred) {
1649         case CmpIPredicate::ne: // (float)int != 4.4   --> true
1650           rewriter.replaceOpWithNewOp<ConstantIntOp>(op, /*value=*/true,
1651                                                      /*width=*/1);
1652           return success();
1653         case CmpIPredicate::eq: // (float)int == 4.4   --> false
1654           rewriter.replaceOpWithNewOp<ConstantIntOp>(op, /*value=*/false,
1655                                                      /*width=*/1);
1656           return success();
1657         case CmpIPredicate::ule:
1658           // (float)int <= 4.4   --> int <= 4
1659           // (float)int <= -4.4  --> false
1660           if (rhs.isNegative()) {
1661             rewriter.replaceOpWithNewOp<ConstantIntOp>(op, /*value=*/false,
1662                                                        /*width=*/1);
1663             return success();
1664           }
1665           break;
1666         case CmpIPredicate::sle:
1667           // (float)int <= 4.4   --> int <= 4
1668           // (float)int <= -4.4  --> int < -4
1669           if (rhs.isNegative())
1670             pred = CmpIPredicate::slt;
1671           break;
1672         case CmpIPredicate::ult:
1673           // (float)int < -4.4   --> false
1674           // (float)int < 4.4    --> int <= 4
1675           if (rhs.isNegative()) {
1676             rewriter.replaceOpWithNewOp<ConstantIntOp>(op, /*value=*/false,
1677                                                        /*width=*/1);
1678             return success();
1679           }
1680           pred = CmpIPredicate::ule;
1681           break;
1682         case CmpIPredicate::slt:
1683           // (float)int < -4.4   --> int < -4
1684           // (float)int < 4.4    --> int <= 4
1685           if (!rhs.isNegative())
1686             pred = CmpIPredicate::sle;
1687           break;
1688         case CmpIPredicate::ugt:
1689           // (float)int > 4.4    --> int > 4
1690           // (float)int > -4.4   --> true
1691           if (rhs.isNegative()) {
1692             rewriter.replaceOpWithNewOp<ConstantIntOp>(op, /*value=*/true,
1693                                                        /*width=*/1);
1694             return success();
1695           }
1696           break;
1697         case CmpIPredicate::sgt:
1698           // (float)int > 4.4    --> int > 4
1699           // (float)int > -4.4   --> int >= -4
1700           if (rhs.isNegative())
1701             pred = CmpIPredicate::sge;
1702           break;
1703         case CmpIPredicate::uge:
1704           // (float)int >= -4.4   --> true
1705           // (float)int >= 4.4    --> int > 4
1706           if (rhs.isNegative()) {
1707             rewriter.replaceOpWithNewOp<ConstantIntOp>(op, /*value=*/true,
1708                                                        /*width=*/1);
1709             return success();
1710           }
1711           pred = CmpIPredicate::ugt;
1712           break;
1713         case CmpIPredicate::sge:
1714           // (float)int >= -4.4   --> int >= -4
1715           // (float)int >= 4.4    --> int > 4
1716           if (!rhs.isNegative())
1717             pred = CmpIPredicate::sgt;
1718           break;
1719         }
1720       }
1721     }
1722 
1723     // Lower this FP comparison into an appropriate integer version of the
1724     // comparison.
1725     rewriter.replaceOpWithNewOp<CmpIOp>(
1726         op, pred, intVal,
1727         rewriter.create<ConstantOp>(
1728             op.getLoc(), intVal.getType(),
1729             rewriter.getIntegerAttr(intVal.getType(), rhsInt)));
1730     return success();
1731   }
1732 };
1733 
1734 void arith::CmpFOp::getCanonicalizationPatterns(RewritePatternSet &patterns,
1735                                                 MLIRContext *context) {
1736   patterns.insert<CmpFIntToFPConst>(context);
1737 }
1738 
1739 //===----------------------------------------------------------------------===//
1740 // SelectOp
1741 //===----------------------------------------------------------------------===//
1742 
1743 // Transforms a select of a boolean to arithmetic operations
1744 //
1745 //  arith.select %arg, %x, %y : i1
1746 //
1747 //  becomes
1748 //
1749 //  and(%arg, %x) or and(!%arg, %y)
1750 struct SelectI1Simplify : public OpRewritePattern<arith::SelectOp> {
1751   using OpRewritePattern<arith::SelectOp>::OpRewritePattern;
1752 
1753   LogicalResult matchAndRewrite(arith::SelectOp op,
1754                                 PatternRewriter &rewriter) const override {
1755     if (!op.getType().isInteger(1))
1756       return failure();
1757 
1758     Value falseConstant =
1759         rewriter.create<arith::ConstantIntOp>(op.getLoc(), true, 1);
1760     Value notCondition = rewriter.create<arith::XOrIOp>(
1761         op.getLoc(), op.getCondition(), falseConstant);
1762 
1763     Value trueVal = rewriter.create<arith::AndIOp>(
1764         op.getLoc(), op.getCondition(), op.getTrueValue());
1765     Value falseVal = rewriter.create<arith::AndIOp>(op.getLoc(), notCondition,
1766                                                     op.getFalseValue());
1767     rewriter.replaceOpWithNewOp<arith::OrIOp>(op, trueVal, falseVal);
1768     return success();
1769   }
1770 };
1771 
1772 //  select %arg, %c1, %c0 => extui %arg
1773 struct SelectToExtUI : public OpRewritePattern<arith::SelectOp> {
1774   using OpRewritePattern<arith::SelectOp>::OpRewritePattern;
1775 
1776   LogicalResult matchAndRewrite(arith::SelectOp op,
1777                                 PatternRewriter &rewriter) const override {
1778     // Cannot extui i1 to i1, or i1 to f32
1779     if (!op.getType().isa<IntegerType>() || op.getType().isInteger(1))
1780       return failure();
1781 
1782     // select %x, c1, %c0 => extui %arg
1783     if (matchPattern(op.getTrueValue(), m_One()) &&
1784         matchPattern(op.getFalseValue(), m_Zero())) {
1785       rewriter.replaceOpWithNewOp<arith::ExtUIOp>(op, op.getType(),
1786                                                   op.getCondition());
1787       return success();
1788     }
1789 
1790     // select %x, c0, %c1 => extui (xor %arg, true)
1791     if (matchPattern(op.getTrueValue(), m_Zero()) &&
1792         matchPattern(op.getFalseValue(), m_One())) {
1793       rewriter.replaceOpWithNewOp<arith::ExtUIOp>(
1794           op, op.getType(),
1795           rewriter.create<arith::XOrIOp>(
1796               op.getLoc(), op.getCondition(),
1797               rewriter.create<arith::ConstantIntOp>(
1798                   op.getLoc(), 1, op.getCondition().getType())));
1799       return success();
1800     }
1801 
1802     return failure();
1803   }
1804 };
1805 
1806 void arith::SelectOp::getCanonicalizationPatterns(RewritePatternSet &results,
1807                                                   MLIRContext *context) {
1808   results.add<SelectI1Simplify, SelectToExtUI>(context);
1809 }
1810 
1811 OpFoldResult arith::SelectOp::fold(ArrayRef<Attribute> operands) {
1812   Value trueVal = getTrueValue();
1813   Value falseVal = getFalseValue();
1814   if (trueVal == falseVal)
1815     return trueVal;
1816 
1817   Value condition = getCondition();
1818 
1819   // select true, %0, %1 => %0
1820   if (matchPattern(condition, m_One()))
1821     return trueVal;
1822 
1823   // select false, %0, %1 => %1
1824   if (matchPattern(condition, m_Zero()))
1825     return falseVal;
1826 
1827   // select %x, true, false => %x
1828   if (getType().isInteger(1) && matchPattern(getTrueValue(), m_One()) &&
1829       matchPattern(getFalseValue(), m_Zero()))
1830     return condition;
1831 
1832   if (auto cmp = dyn_cast_or_null<arith::CmpIOp>(condition.getDefiningOp())) {
1833     auto pred = cmp.getPredicate();
1834     if (pred == arith::CmpIPredicate::eq || pred == arith::CmpIPredicate::ne) {
1835       auto cmpLhs = cmp.getLhs();
1836       auto cmpRhs = cmp.getRhs();
1837 
1838       // %0 = arith.cmpi eq, %arg0, %arg1
1839       // %1 = arith.select %0, %arg0, %arg1 => %arg1
1840 
1841       // %0 = arith.cmpi ne, %arg0, %arg1
1842       // %1 = arith.select %0, %arg0, %arg1 => %arg0
1843 
1844       if ((cmpLhs == trueVal && cmpRhs == falseVal) ||
1845           (cmpRhs == trueVal && cmpLhs == falseVal))
1846         return pred == arith::CmpIPredicate::ne ? trueVal : falseVal;
1847     }
1848   }
1849   return nullptr;
1850 }
1851 
1852 ParseResult SelectOp::parse(OpAsmParser &parser, OperationState &result) {
1853   Type conditionType, resultType;
1854   SmallVector<OpAsmParser::UnresolvedOperand, 3> operands;
1855   if (parser.parseOperandList(operands, /*requiredOperandCount=*/3) ||
1856       parser.parseOptionalAttrDict(result.attributes) ||
1857       parser.parseColonType(resultType))
1858     return failure();
1859 
1860   // Check for the explicit condition type if this is a masked tensor or vector.
1861   if (succeeded(parser.parseOptionalComma())) {
1862     conditionType = resultType;
1863     if (parser.parseType(resultType))
1864       return failure();
1865   } else {
1866     conditionType = parser.getBuilder().getI1Type();
1867   }
1868 
1869   result.addTypes(resultType);
1870   return parser.resolveOperands(operands,
1871                                 {conditionType, resultType, resultType},
1872                                 parser.getNameLoc(), result.operands);
1873 }
1874 
1875 void arith::SelectOp::print(OpAsmPrinter &p) {
1876   p << " " << getOperands();
1877   p.printOptionalAttrDict((*this)->getAttrs());
1878   p << " : ";
1879   if (ShapedType condType = getCondition().getType().dyn_cast<ShapedType>())
1880     p << condType << ", ";
1881   p << getType();
1882 }
1883 
1884 LogicalResult arith::SelectOp::verify() {
1885   Type conditionType = getCondition().getType();
1886   if (conditionType.isSignlessInteger(1))
1887     return success();
1888 
1889   // If the result type is a vector or tensor, the type can be a mask with the
1890   // same elements.
1891   Type resultType = getType();
1892   if (!resultType.isa<TensorType, VectorType>())
1893     return emitOpError() << "expected condition to be a signless i1, but got "
1894                          << conditionType;
1895   Type shapedConditionType = getI1SameShape(resultType);
1896   if (conditionType != shapedConditionType) {
1897     return emitOpError() << "expected condition type to have the same shape "
1898                             "as the result type, expected "
1899                          << shapedConditionType << ", but got "
1900                          << conditionType;
1901   }
1902   return success();
1903 }
1904 //===----------------------------------------------------------------------===//
1905 // ShLIOp
1906 //===----------------------------------------------------------------------===//
1907 
1908 OpFoldResult arith::ShLIOp::fold(ArrayRef<Attribute> operands) {
1909   // Don't fold if shifting more than the bit width.
1910   bool bounded = false;
1911   auto result = constFoldBinaryOp<IntegerAttr>(
1912       operands, [&](const APInt &a, const APInt &b) {
1913         bounded = b.ule(b.getBitWidth());
1914         return a.shl(b);
1915       });
1916   return bounded ? result : Attribute();
1917 }
1918 
1919 //===----------------------------------------------------------------------===//
1920 // ShRUIOp
1921 //===----------------------------------------------------------------------===//
1922 
1923 OpFoldResult arith::ShRUIOp::fold(ArrayRef<Attribute> operands) {
1924   // Don't fold if shifting more than the bit width.
1925   bool bounded = false;
1926   auto result = constFoldBinaryOp<IntegerAttr>(
1927       operands, [&](const APInt &a, const APInt &b) {
1928         bounded = b.ule(b.getBitWidth());
1929         return a.lshr(b);
1930       });
1931   return bounded ? result : Attribute();
1932 }
1933 
1934 //===----------------------------------------------------------------------===//
1935 // ShRSIOp
1936 //===----------------------------------------------------------------------===//
1937 
1938 OpFoldResult arith::ShRSIOp::fold(ArrayRef<Attribute> operands) {
1939   // Don't fold if shifting more than the bit width.
1940   bool bounded = false;
1941   auto result = constFoldBinaryOp<IntegerAttr>(
1942       operands, [&](const APInt &a, const APInt &b) {
1943         bounded = b.ule(b.getBitWidth());
1944         return a.ashr(b);
1945       });
1946   return bounded ? result : Attribute();
1947 }
1948 
1949 //===----------------------------------------------------------------------===//
1950 // Atomic Enum
1951 //===----------------------------------------------------------------------===//
1952 
1953 /// Returns the identity value attribute associated with an AtomicRMWKind op.
1954 Attribute mlir::arith::getIdentityValueAttr(AtomicRMWKind kind, Type resultType,
1955                                             OpBuilder &builder, Location loc) {
1956   switch (kind) {
1957   case AtomicRMWKind::maxf:
1958     return builder.getFloatAttr(
1959         resultType,
1960         APFloat::getInf(resultType.cast<FloatType>().getFloatSemantics(),
1961                         /*Negative=*/true));
1962   case AtomicRMWKind::addf:
1963   case AtomicRMWKind::addi:
1964   case AtomicRMWKind::maxu:
1965   case AtomicRMWKind::ori:
1966     return builder.getZeroAttr(resultType);
1967   case AtomicRMWKind::andi:
1968     return builder.getIntegerAttr(
1969         resultType,
1970         APInt::getAllOnes(resultType.cast<IntegerType>().getWidth()));
1971   case AtomicRMWKind::maxs:
1972     return builder.getIntegerAttr(
1973         resultType,
1974         APInt::getSignedMinValue(resultType.cast<IntegerType>().getWidth()));
1975   case AtomicRMWKind::minf:
1976     return builder.getFloatAttr(
1977         resultType,
1978         APFloat::getInf(resultType.cast<FloatType>().getFloatSemantics(),
1979                         /*Negative=*/false));
1980   case AtomicRMWKind::mins:
1981     return builder.getIntegerAttr(
1982         resultType,
1983         APInt::getSignedMaxValue(resultType.cast<IntegerType>().getWidth()));
1984   case AtomicRMWKind::minu:
1985     return builder.getIntegerAttr(
1986         resultType,
1987         APInt::getMaxValue(resultType.cast<IntegerType>().getWidth()));
1988   case AtomicRMWKind::muli:
1989     return builder.getIntegerAttr(resultType, 1);
1990   case AtomicRMWKind::mulf:
1991     return builder.getFloatAttr(resultType, 1);
1992   // TODO: Add remaining reduction operations.
1993   default:
1994     (void)emitOptionalError(loc, "Reduction operation type not supported");
1995     break;
1996   }
1997   return nullptr;
1998 }
1999 
2000 /// Returns the identity value associated with an AtomicRMWKind op.
2001 Value mlir::arith::getIdentityValue(AtomicRMWKind op, Type resultType,
2002                                     OpBuilder &builder, Location loc) {
2003   Attribute attr = getIdentityValueAttr(op, resultType, builder, loc);
2004   return builder.create<arith::ConstantOp>(loc, attr);
2005 }
2006 
2007 /// Return the value obtained by applying the reduction operation kind
2008 /// associated with a binary AtomicRMWKind op to `lhs` and `rhs`.
2009 Value mlir::arith::getReductionOp(AtomicRMWKind op, OpBuilder &builder,
2010                                   Location loc, Value lhs, Value rhs) {
2011   switch (op) {
2012   case AtomicRMWKind::addf:
2013     return builder.create<arith::AddFOp>(loc, lhs, rhs);
2014   case AtomicRMWKind::addi:
2015     return builder.create<arith::AddIOp>(loc, lhs, rhs);
2016   case AtomicRMWKind::mulf:
2017     return builder.create<arith::MulFOp>(loc, lhs, rhs);
2018   case AtomicRMWKind::muli:
2019     return builder.create<arith::MulIOp>(loc, lhs, rhs);
2020   case AtomicRMWKind::maxf:
2021     return builder.create<arith::MaxFOp>(loc, lhs, rhs);
2022   case AtomicRMWKind::minf:
2023     return builder.create<arith::MinFOp>(loc, lhs, rhs);
2024   case AtomicRMWKind::maxs:
2025     return builder.create<arith::MaxSIOp>(loc, lhs, rhs);
2026   case AtomicRMWKind::mins:
2027     return builder.create<arith::MinSIOp>(loc, lhs, rhs);
2028   case AtomicRMWKind::maxu:
2029     return builder.create<arith::MaxUIOp>(loc, lhs, rhs);
2030   case AtomicRMWKind::minu:
2031     return builder.create<arith::MinUIOp>(loc, lhs, rhs);
2032   case AtomicRMWKind::ori:
2033     return builder.create<arith::OrIOp>(loc, lhs, rhs);
2034   case AtomicRMWKind::andi:
2035     return builder.create<arith::AndIOp>(loc, lhs, rhs);
2036   // TODO: Add remaining reduction operations.
2037   default:
2038     (void)emitOptionalError(loc, "Reduction operation type not supported");
2039     break;
2040   }
2041   return nullptr;
2042 }
2043 
2044 //===----------------------------------------------------------------------===//
2045 // DelinearizeIndexOp
2046 //===----------------------------------------------------------------------===//
2047 
2048 void arith::DelinearizeIndexOp::build(OpBuilder &builder,
2049                                       OperationState &result,
2050                                       Value linear_index,
2051                                       ArrayRef<OpFoldResult> basis) {
2052   result.addTypes(SmallVector<Type>(basis.size(), builder.getIndexType()));
2053   result.addOperands(linear_index);
2054   SmallVector<Value> basisValues =
2055       llvm::to_vector(llvm::map_range(basis, [&](OpFoldResult ofr) -> Value {
2056         Optional<int64_t> staticDim = getConstantIntValue(ofr);
2057         if (staticDim.has_value())
2058           return builder.create<arith::ConstantIndexOp>(result.location,
2059                                                         *staticDim);
2060         return ofr.dyn_cast<Value>();
2061       }));
2062   result.addOperands(basisValues);
2063 }
2064 
2065 LogicalResult arith::DelinearizeIndexOp::verify() {
2066   if (getBasis().empty())
2067     return emitOpError("basis should not be empty");
2068   if (getNumResults() != getBasis().size())
2069     return emitOpError("should return an index for each basis element");
2070   return success();
2071 }
2072 
2073 //===----------------------------------------------------------------------===//
2074 // TableGen'd op method definitions
2075 //===----------------------------------------------------------------------===//
2076 
2077 #define GET_OP_CLASSES
2078 #include "mlir/Dialect/Arithmetic/IR/ArithmeticOps.cpp.inc"
2079 
2080 //===----------------------------------------------------------------------===//
2081 // TableGen'd enum attribute definitions
2082 //===----------------------------------------------------------------------===//
2083 
2084 #include "mlir/Dialect/Arithmetic/IR/ArithmeticOpsEnums.cpp.inc"
2085