1 //===- ArithmeticOps.cpp - MLIR Arithmetic dialect ops implementation -----===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 
9 #include <utility>
10 
11 #include "mlir/Dialect/Arithmetic/IR/Arithmetic.h"
12 #include "mlir/Dialect/CommonFolders.h"
13 #include "mlir/IR/Builders.h"
14 #include "mlir/IR/Matchers.h"
15 #include "mlir/IR/OpImplementation.h"
16 #include "mlir/IR/PatternMatch.h"
17 #include "mlir/IR/TypeUtilities.h"
18 
19 #include "llvm/ADT/APSInt.h"
20 
21 using namespace mlir;
22 using namespace mlir::arith;
23 
24 //===----------------------------------------------------------------------===//
25 // Pattern helpers
26 //===----------------------------------------------------------------------===//
27 
28 static IntegerAttr addIntegerAttrs(PatternRewriter &builder, Value res,
29                                    Attribute lhs, Attribute rhs) {
30   return builder.getIntegerAttr(res.getType(),
31                                 lhs.cast<IntegerAttr>().getInt() +
32                                     rhs.cast<IntegerAttr>().getInt());
33 }
34 
35 static IntegerAttr subIntegerAttrs(PatternRewriter &builder, Value res,
36                                    Attribute lhs, Attribute rhs) {
37   return builder.getIntegerAttr(res.getType(),
38                                 lhs.cast<IntegerAttr>().getInt() -
39                                     rhs.cast<IntegerAttr>().getInt());
40 }
41 
42 /// Invert an integer comparison predicate.
43 arith::CmpIPredicate arith::invertPredicate(arith::CmpIPredicate pred) {
44   switch (pred) {
45   case arith::CmpIPredicate::eq:
46     return arith::CmpIPredicate::ne;
47   case arith::CmpIPredicate::ne:
48     return arith::CmpIPredicate::eq;
49   case arith::CmpIPredicate::slt:
50     return arith::CmpIPredicate::sge;
51   case arith::CmpIPredicate::sle:
52     return arith::CmpIPredicate::sgt;
53   case arith::CmpIPredicate::sgt:
54     return arith::CmpIPredicate::sle;
55   case arith::CmpIPredicate::sge:
56     return arith::CmpIPredicate::slt;
57   case arith::CmpIPredicate::ult:
58     return arith::CmpIPredicate::uge;
59   case arith::CmpIPredicate::ule:
60     return arith::CmpIPredicate::ugt;
61   case arith::CmpIPredicate::ugt:
62     return arith::CmpIPredicate::ule;
63   case arith::CmpIPredicate::uge:
64     return arith::CmpIPredicate::ult;
65   }
66   llvm_unreachable("unknown cmpi predicate kind");
67 }
68 
69 static arith::CmpIPredicateAttr invertPredicate(arith::CmpIPredicateAttr pred) {
70   return arith::CmpIPredicateAttr::get(pred.getContext(),
71                                        invertPredicate(pred.getValue()));
72 }
73 
74 //===----------------------------------------------------------------------===//
75 // TableGen'd canonicalization patterns
76 //===----------------------------------------------------------------------===//
77 
78 namespace {
79 #include "ArithmeticCanonicalization.inc"
80 } // namespace
81 
82 //===----------------------------------------------------------------------===//
83 // ConstantOp
84 //===----------------------------------------------------------------------===//
85 
86 void arith::ConstantOp::getAsmResultNames(
87     function_ref<void(Value, StringRef)> setNameFn) {
88   auto type = getType();
89   if (auto intCst = getValue().dyn_cast<IntegerAttr>()) {
90     auto intType = type.dyn_cast<IntegerType>();
91 
92     // Sugar i1 constants with 'true' and 'false'.
93     if (intType && intType.getWidth() == 1)
94       return setNameFn(getResult(), (intCst.getInt() ? "true" : "false"));
95 
96     // Otherwise, build a compex name with the value and type.
97     SmallString<32> specialNameBuffer;
98     llvm::raw_svector_ostream specialName(specialNameBuffer);
99     specialName << 'c' << intCst.getInt();
100     if (intType)
101       specialName << '_' << type;
102     setNameFn(getResult(), specialName.str());
103   } else {
104     setNameFn(getResult(), "cst");
105   }
106 }
107 
108 /// TODO: disallow arith.constant to return anything other than signless integer
109 /// or float like.
110 LogicalResult arith::ConstantOp::verify() {
111   auto type = getType();
112   // The value's type must match the return type.
113   if (getValue().getType() != type) {
114     return emitOpError() << "value type " << getValue().getType()
115                          << " must match return type: " << type;
116   }
117   // Integer values must be signless.
118   if (type.isa<IntegerType>() && !type.cast<IntegerType>().isSignless())
119     return emitOpError("integer return type must be signless");
120   // Any float or elements attribute are acceptable.
121   if (!getValue().isa<IntegerAttr, FloatAttr, ElementsAttr>()) {
122     return emitOpError(
123         "value must be an integer, float, or elements attribute");
124   }
125   return success();
126 }
127 
128 bool arith::ConstantOp::isBuildableWith(Attribute value, Type type) {
129   // The value's type must be the same as the provided type.
130   if (value.getType() != type)
131     return false;
132   // Integer values must be signless.
133   if (type.isa<IntegerType>() && !type.cast<IntegerType>().isSignless())
134     return false;
135   // Integer, float, and element attributes are buildable.
136   return value.isa<IntegerAttr, FloatAttr, ElementsAttr>();
137 }
138 
139 OpFoldResult arith::ConstantOp::fold(ArrayRef<Attribute> operands) {
140   return getValue();
141 }
142 
143 void arith::ConstantIntOp::build(OpBuilder &builder, OperationState &result,
144                                  int64_t value, unsigned width) {
145   auto type = builder.getIntegerType(width);
146   arith::ConstantOp::build(builder, result, type,
147                            builder.getIntegerAttr(type, value));
148 }
149 
150 void arith::ConstantIntOp::build(OpBuilder &builder, OperationState &result,
151                                  int64_t value, Type type) {
152   assert(type.isSignlessInteger() &&
153          "ConstantIntOp can only have signless integer type values");
154   arith::ConstantOp::build(builder, result, type,
155                            builder.getIntegerAttr(type, value));
156 }
157 
158 bool arith::ConstantIntOp::classof(Operation *op) {
159   if (auto constOp = dyn_cast_or_null<arith::ConstantOp>(op))
160     return constOp.getType().isSignlessInteger();
161   return false;
162 }
163 
164 void arith::ConstantFloatOp::build(OpBuilder &builder, OperationState &result,
165                                    const APFloat &value, FloatType type) {
166   arith::ConstantOp::build(builder, result, type,
167                            builder.getFloatAttr(type, value));
168 }
169 
170 bool arith::ConstantFloatOp::classof(Operation *op) {
171   if (auto constOp = dyn_cast_or_null<arith::ConstantOp>(op))
172     return constOp.getType().isa<FloatType>();
173   return false;
174 }
175 
176 void arith::ConstantIndexOp::build(OpBuilder &builder, OperationState &result,
177                                    int64_t value) {
178   arith::ConstantOp::build(builder, result, builder.getIndexType(),
179                            builder.getIndexAttr(value));
180 }
181 
182 bool arith::ConstantIndexOp::classof(Operation *op) {
183   if (auto constOp = dyn_cast_or_null<arith::ConstantOp>(op))
184     return constOp.getType().isIndex();
185   return false;
186 }
187 
188 //===----------------------------------------------------------------------===//
189 // AddIOp
190 //===----------------------------------------------------------------------===//
191 
192 OpFoldResult arith::AddIOp::fold(ArrayRef<Attribute> operands) {
193   // addi(x, 0) -> x
194   if (matchPattern(getRhs(), m_Zero()))
195     return getLhs();
196 
197   // addi(subi(a, b), b) -> a
198   if (auto sub = getLhs().getDefiningOp<SubIOp>())
199     if (getRhs() == sub.getRhs())
200       return sub.getLhs();
201 
202   // addi(b, subi(a, b)) -> a
203   if (auto sub = getRhs().getDefiningOp<SubIOp>())
204     if (getLhs() == sub.getRhs())
205       return sub.getLhs();
206 
207   return constFoldBinaryOp<IntegerAttr>(
208       operands, [](APInt a, const APInt &b) { return std::move(a) + b; });
209 }
210 
211 void arith::AddIOp::getCanonicalizationPatterns(
212     RewritePatternSet &patterns, MLIRContext *context) {
213   patterns.add<AddIAddConstant, AddISubConstantRHS, AddISubConstantLHS>(
214       context);
215 }
216 
217 //===----------------------------------------------------------------------===//
218 // SubIOp
219 //===----------------------------------------------------------------------===//
220 
221 OpFoldResult arith::SubIOp::fold(ArrayRef<Attribute> operands) {
222   // subi(x,x) -> 0
223   if (getOperand(0) == getOperand(1))
224     return Builder(getContext()).getZeroAttr(getType());
225   // subi(x,0) -> x
226   if (matchPattern(getRhs(), m_Zero()))
227     return getLhs();
228 
229   return constFoldBinaryOp<IntegerAttr>(
230       operands, [](APInt a, const APInt &b) { return std::move(a) - b; });
231 }
232 
233 void arith::SubIOp::getCanonicalizationPatterns(
234     RewritePatternSet &patterns, MLIRContext *context) {
235   patterns
236       .add<SubIRHSAddConstant, SubILHSAddConstant, SubIRHSSubConstantRHS,
237            SubIRHSSubConstantLHS, SubILHSSubConstantRHS, SubILHSSubConstantLHS>(
238           context);
239 }
240 
241 //===----------------------------------------------------------------------===//
242 // MulIOp
243 //===----------------------------------------------------------------------===//
244 
245 OpFoldResult arith::MulIOp::fold(ArrayRef<Attribute> operands) {
246   // muli(x, 0) -> 0
247   if (matchPattern(getRhs(), m_Zero()))
248     return getRhs();
249   // muli(x, 1) -> x
250   if (matchPattern(getRhs(), m_One()))
251     return getOperand(0);
252   // TODO: Handle the overflow case.
253 
254   // default folder
255   return constFoldBinaryOp<IntegerAttr>(
256       operands, [](const APInt &a, const APInt &b) { return a * b; });
257 }
258 
259 //===----------------------------------------------------------------------===//
260 // DivUIOp
261 //===----------------------------------------------------------------------===//
262 
263 OpFoldResult arith::DivUIOp::fold(ArrayRef<Attribute> operands) {
264   // Don't fold if it would require a division by zero.
265   bool div0 = false;
266   auto result =
267       constFoldBinaryOp<IntegerAttr>(operands, [&](APInt a, const APInt &b) {
268         if (div0 || !b) {
269           div0 = true;
270           return a;
271         }
272         return a.udiv(b);
273       });
274 
275   // Fold out division by one. Assumes all tensors of all ones are splats.
276   if (auto rhs = operands[1].dyn_cast_or_null<IntegerAttr>()) {
277     if (rhs.getValue() == 1)
278       return getLhs();
279   } else if (auto rhs = operands[1].dyn_cast_or_null<SplatElementsAttr>()) {
280     if (rhs.getSplatValue<IntegerAttr>().getValue() == 1)
281       return getLhs();
282   }
283 
284   return div0 ? Attribute() : result;
285 }
286 
287 //===----------------------------------------------------------------------===//
288 // DivSIOp
289 //===----------------------------------------------------------------------===//
290 
291 OpFoldResult arith::DivSIOp::fold(ArrayRef<Attribute> operands) {
292   // Don't fold if it would overflow or if it requires a division by zero.
293   bool overflowOrDiv0 = false;
294   auto result =
295       constFoldBinaryOp<IntegerAttr>(operands, [&](APInt a, const APInt &b) {
296         if (overflowOrDiv0 || !b) {
297           overflowOrDiv0 = true;
298           return a;
299         }
300         return a.sdiv_ov(b, overflowOrDiv0);
301       });
302 
303   // Fold out division by one. Assumes all tensors of all ones are splats.
304   if (auto rhs = operands[1].dyn_cast_or_null<IntegerAttr>()) {
305     if (rhs.getValue() == 1)
306       return getLhs();
307   } else if (auto rhs = operands[1].dyn_cast_or_null<SplatElementsAttr>()) {
308     if (rhs.getSplatValue<IntegerAttr>().getValue() == 1)
309       return getLhs();
310   }
311 
312   return overflowOrDiv0 ? Attribute() : result;
313 }
314 
315 //===----------------------------------------------------------------------===//
316 // Ceil and floor division folding helpers
317 //===----------------------------------------------------------------------===//
318 
319 static APInt signedCeilNonnegInputs(const APInt &a, const APInt &b,
320                                     bool &overflow) {
321   // Returns (a-1)/b + 1
322   APInt one(a.getBitWidth(), 1, true); // Signed value 1.
323   APInt val = a.ssub_ov(one, overflow).sdiv_ov(b, overflow);
324   return val.sadd_ov(one, overflow);
325 }
326 
327 //===----------------------------------------------------------------------===//
328 // CeilDivUIOp
329 //===----------------------------------------------------------------------===//
330 
331 OpFoldResult arith::CeilDivUIOp::fold(ArrayRef<Attribute> operands) {
332   bool overflowOrDiv0 = false;
333   auto result =
334       constFoldBinaryOp<IntegerAttr>(operands, [&](APInt a, const APInt &b) {
335         if (overflowOrDiv0 || !b) {
336           overflowOrDiv0 = true;
337           return a;
338         }
339         APInt quotient = a.udiv(b);
340         if (!a.urem(b))
341           return quotient;
342         APInt one(a.getBitWidth(), 1, true);
343         return quotient.uadd_ov(one, overflowOrDiv0);
344       });
345   // Fold out ceil division by one. Assumes all tensors of all ones are
346   // splats.
347   if (auto rhs = operands[1].dyn_cast_or_null<IntegerAttr>()) {
348     if (rhs.getValue() == 1)
349       return getLhs();
350   } else if (auto rhs = operands[1].dyn_cast_or_null<SplatElementsAttr>()) {
351     if (rhs.getSplatValue<IntegerAttr>().getValue() == 1)
352       return getLhs();
353   }
354 
355   return overflowOrDiv0 ? Attribute() : result;
356 }
357 
358 //===----------------------------------------------------------------------===//
359 // CeilDivSIOp
360 //===----------------------------------------------------------------------===//
361 
362 OpFoldResult arith::CeilDivSIOp::fold(ArrayRef<Attribute> operands) {
363   // Don't fold if it would overflow or if it requires a division by zero.
364   bool overflowOrDiv0 = false;
365   auto result =
366       constFoldBinaryOp<IntegerAttr>(operands, [&](APInt a, const APInt &b) {
367         if (overflowOrDiv0 || !b) {
368           overflowOrDiv0 = true;
369           return a;
370         }
371         if (!a)
372           return a;
373         // After this point we know that neither a or b are zero.
374         unsigned bits = a.getBitWidth();
375         APInt zero = APInt::getZero(bits);
376         bool aGtZero = a.sgt(zero);
377         bool bGtZero = b.sgt(zero);
378         if (aGtZero && bGtZero) {
379           // Both positive, return ceil(a, b).
380           return signedCeilNonnegInputs(a, b, overflowOrDiv0);
381         }
382         if (!aGtZero && !bGtZero) {
383           // Both negative, return ceil(-a, -b).
384           APInt posA = zero.ssub_ov(a, overflowOrDiv0);
385           APInt posB = zero.ssub_ov(b, overflowOrDiv0);
386           return signedCeilNonnegInputs(posA, posB, overflowOrDiv0);
387         }
388         if (!aGtZero && bGtZero) {
389           // A is negative, b is positive, return - ( -a / b).
390           APInt posA = zero.ssub_ov(a, overflowOrDiv0);
391           APInt div = posA.sdiv_ov(b, overflowOrDiv0);
392           return zero.ssub_ov(div, overflowOrDiv0);
393         }
394         // A is positive, b is negative, return - (a / -b).
395         APInt posB = zero.ssub_ov(b, overflowOrDiv0);
396         APInt div = a.sdiv_ov(posB, overflowOrDiv0);
397         return zero.ssub_ov(div, overflowOrDiv0);
398       });
399 
400   // Fold out ceil division by one. Assumes all tensors of all ones are
401   // splats.
402   if (auto rhs = operands[1].dyn_cast_or_null<IntegerAttr>()) {
403     if (rhs.getValue() == 1)
404       return getLhs();
405   } else if (auto rhs = operands[1].dyn_cast_or_null<SplatElementsAttr>()) {
406     if (rhs.getSplatValue<IntegerAttr>().getValue() == 1)
407       return getLhs();
408   }
409 
410   return overflowOrDiv0 ? Attribute() : result;
411 }
412 
413 //===----------------------------------------------------------------------===//
414 // FloorDivSIOp
415 //===----------------------------------------------------------------------===//
416 
417 OpFoldResult arith::FloorDivSIOp::fold(ArrayRef<Attribute> operands) {
418   // Don't fold if it would overflow or if it requires a division by zero.
419   bool overflowOrDiv0 = false;
420   auto result =
421       constFoldBinaryOp<IntegerAttr>(operands, [&](APInt a, const APInt &b) {
422         if (overflowOrDiv0 || !b) {
423           overflowOrDiv0 = true;
424           return a;
425         }
426         if (!a)
427           return a;
428         // After this point we know that neither a or b are zero.
429         unsigned bits = a.getBitWidth();
430         APInt zero = APInt::getZero(bits);
431         bool aGtZero = a.sgt(zero);
432         bool bGtZero = b.sgt(zero);
433         if (aGtZero && bGtZero) {
434           // Both positive, return a / b.
435           return a.sdiv_ov(b, overflowOrDiv0);
436         }
437         if (!aGtZero && !bGtZero) {
438           // Both negative, return -a / -b.
439           APInt posA = zero.ssub_ov(a, overflowOrDiv0);
440           APInt posB = zero.ssub_ov(b, overflowOrDiv0);
441           return posA.sdiv_ov(posB, overflowOrDiv0);
442         }
443         if (!aGtZero && bGtZero) {
444           // A is negative, b is positive, return - ceil(-a, b).
445           APInt posA = zero.ssub_ov(a, overflowOrDiv0);
446           APInt ceil = signedCeilNonnegInputs(posA, b, overflowOrDiv0);
447           return zero.ssub_ov(ceil, overflowOrDiv0);
448         }
449         // A is positive, b is negative, return - ceil(a, -b).
450         APInt posB = zero.ssub_ov(b, overflowOrDiv0);
451         APInt ceil = signedCeilNonnegInputs(a, posB, overflowOrDiv0);
452         return zero.ssub_ov(ceil, overflowOrDiv0);
453       });
454 
455   // Fold out floor division by one. Assumes all tensors of all ones are
456   // splats.
457   if (auto rhs = operands[1].dyn_cast_or_null<IntegerAttr>()) {
458     if (rhs.getValue() == 1)
459       return getLhs();
460   } else if (auto rhs = operands[1].dyn_cast_or_null<SplatElementsAttr>()) {
461     if (rhs.getSplatValue<IntegerAttr>().getValue() == 1)
462       return getLhs();
463   }
464 
465   return overflowOrDiv0 ? Attribute() : result;
466 }
467 
468 //===----------------------------------------------------------------------===//
469 // RemUIOp
470 //===----------------------------------------------------------------------===//
471 
472 OpFoldResult arith::RemUIOp::fold(ArrayRef<Attribute> operands) {
473   auto rhs = operands.back().dyn_cast_or_null<IntegerAttr>();
474   if (!rhs)
475     return {};
476   auto rhsValue = rhs.getValue();
477 
478   // x % 1 = 0
479   if (rhsValue.isOneValue())
480     return IntegerAttr::get(rhs.getType(), APInt(rhsValue.getBitWidth(), 0));
481 
482   // Don't fold if it requires division by zero.
483   if (rhsValue.isNullValue())
484     return {};
485 
486   auto lhs = operands.front().dyn_cast_or_null<IntegerAttr>();
487   if (!lhs)
488     return {};
489   return IntegerAttr::get(lhs.getType(), lhs.getValue().urem(rhsValue));
490 }
491 
492 //===----------------------------------------------------------------------===//
493 // RemSIOp
494 //===----------------------------------------------------------------------===//
495 
496 OpFoldResult arith::RemSIOp::fold(ArrayRef<Attribute> operands) {
497   auto rhs = operands.back().dyn_cast_or_null<IntegerAttr>();
498   if (!rhs)
499     return {};
500   auto rhsValue = rhs.getValue();
501 
502   // x % 1 = 0
503   if (rhsValue.isOneValue())
504     return IntegerAttr::get(rhs.getType(), APInt(rhsValue.getBitWidth(), 0));
505 
506   // Don't fold if it requires division by zero.
507   if (rhsValue.isNullValue())
508     return {};
509 
510   auto lhs = operands.front().dyn_cast_or_null<IntegerAttr>();
511   if (!lhs)
512     return {};
513   return IntegerAttr::get(lhs.getType(), lhs.getValue().srem(rhsValue));
514 }
515 
516 //===----------------------------------------------------------------------===//
517 // AndIOp
518 //===----------------------------------------------------------------------===//
519 
520 OpFoldResult arith::AndIOp::fold(ArrayRef<Attribute> operands) {
521   /// and(x, 0) -> 0
522   if (matchPattern(getRhs(), m_Zero()))
523     return getRhs();
524   /// and(x, allOnes) -> x
525   APInt intValue;
526   if (matchPattern(getRhs(), m_ConstantInt(&intValue)) && intValue.isAllOnes())
527     return getLhs();
528 
529   return constFoldBinaryOp<IntegerAttr>(
530       operands, [](APInt a, const APInt &b) { return std::move(a) & b; });
531 }
532 
533 //===----------------------------------------------------------------------===//
534 // OrIOp
535 //===----------------------------------------------------------------------===//
536 
537 OpFoldResult arith::OrIOp::fold(ArrayRef<Attribute> operands) {
538   /// or(x, 0) -> x
539   if (matchPattern(getRhs(), m_Zero()))
540     return getLhs();
541   /// or(x, <all ones>) -> <all ones>
542   if (auto rhsAttr = operands[1].dyn_cast_or_null<IntegerAttr>())
543     if (rhsAttr.getValue().isAllOnes())
544       return rhsAttr;
545 
546   return constFoldBinaryOp<IntegerAttr>(
547       operands, [](APInt a, const APInt &b) { return std::move(a) | b; });
548 }
549 
550 //===----------------------------------------------------------------------===//
551 // XOrIOp
552 //===----------------------------------------------------------------------===//
553 
554 OpFoldResult arith::XOrIOp::fold(ArrayRef<Attribute> operands) {
555   /// xor(x, 0) -> x
556   if (matchPattern(getRhs(), m_Zero()))
557     return getLhs();
558   /// xor(x, x) -> 0
559   if (getLhs() == getRhs())
560     return Builder(getContext()).getZeroAttr(getType());
561   /// xor(xor(x, a), a) -> x
562   if (arith::XOrIOp prev = getLhs().getDefiningOp<arith::XOrIOp>())
563     if (prev.getRhs() == getRhs())
564       return prev.getLhs();
565 
566   return constFoldBinaryOp<IntegerAttr>(
567       operands, [](APInt a, const APInt &b) { return std::move(a) ^ b; });
568 }
569 
570 void arith::XOrIOp::getCanonicalizationPatterns(
571     RewritePatternSet &patterns, MLIRContext *context) {
572   patterns.add<XOrINotCmpI>(context);
573 }
574 
575 //===----------------------------------------------------------------------===//
576 // AddFOp
577 //===----------------------------------------------------------------------===//
578 
579 OpFoldResult arith::AddFOp::fold(ArrayRef<Attribute> operands) {
580   // addf(x, -0) -> x
581   if (matchPattern(getRhs(), m_NegZeroFloat()))
582     return getLhs();
583 
584   return constFoldBinaryOp<FloatAttr>(
585       operands, [](const APFloat &a, const APFloat &b) { return a + b; });
586 }
587 
588 //===----------------------------------------------------------------------===//
589 // SubFOp
590 //===----------------------------------------------------------------------===//
591 
592 OpFoldResult arith::SubFOp::fold(ArrayRef<Attribute> operands) {
593   // subf(x, +0) -> x
594   if (matchPattern(getRhs(), m_PosZeroFloat()))
595     return getLhs();
596 
597   return constFoldBinaryOp<FloatAttr>(
598       operands, [](const APFloat &a, const APFloat &b) { return a - b; });
599 }
600 
601 //===----------------------------------------------------------------------===//
602 // MaxFOp
603 //===----------------------------------------------------------------------===//
604 
605 OpFoldResult arith::MaxFOp::fold(ArrayRef<Attribute> operands) {
606   assert(operands.size() == 2 && "maxf takes two operands");
607 
608   // maxf(x,x) -> x
609   if (getLhs() == getRhs())
610     return getRhs();
611 
612   // maxf(x, -inf) -> x
613   if (matchPattern(getRhs(), m_NegInfFloat()))
614     return getLhs();
615 
616   return constFoldBinaryOp<FloatAttr>(
617       operands,
618       [](const APFloat &a, const APFloat &b) { return llvm::maximum(a, b); });
619 }
620 
621 //===----------------------------------------------------------------------===//
622 // MaxSIOp
623 //===----------------------------------------------------------------------===//
624 
625 OpFoldResult MaxSIOp::fold(ArrayRef<Attribute> operands) {
626   assert(operands.size() == 2 && "binary operation takes two operands");
627 
628   // maxsi(x,x) -> x
629   if (getLhs() == getRhs())
630     return getRhs();
631 
632   APInt intValue;
633   // maxsi(x,MAX_INT) -> MAX_INT
634   if (matchPattern(getRhs(), m_ConstantInt(&intValue)) &&
635       intValue.isMaxSignedValue())
636     return getRhs();
637 
638   // maxsi(x, MIN_INT) -> x
639   if (matchPattern(getRhs(), m_ConstantInt(&intValue)) &&
640       intValue.isMinSignedValue())
641     return getLhs();
642 
643   return constFoldBinaryOp<IntegerAttr>(operands,
644                                         [](const APInt &a, const APInt &b) {
645                                           return llvm::APIntOps::smax(a, b);
646                                         });
647 }
648 
649 //===----------------------------------------------------------------------===//
650 // MaxUIOp
651 //===----------------------------------------------------------------------===//
652 
653 OpFoldResult MaxUIOp::fold(ArrayRef<Attribute> operands) {
654   assert(operands.size() == 2 && "binary operation takes two operands");
655 
656   // maxui(x,x) -> x
657   if (getLhs() == getRhs())
658     return getRhs();
659 
660   APInt intValue;
661   // maxui(x,MAX_INT) -> MAX_INT
662   if (matchPattern(getRhs(), m_ConstantInt(&intValue)) && intValue.isMaxValue())
663     return getRhs();
664 
665   // maxui(x, MIN_INT) -> x
666   if (matchPattern(getRhs(), m_ConstantInt(&intValue)) && intValue.isMinValue())
667     return getLhs();
668 
669   return constFoldBinaryOp<IntegerAttr>(operands,
670                                         [](const APInt &a, const APInt &b) {
671                                           return llvm::APIntOps::umax(a, b);
672                                         });
673 }
674 
675 //===----------------------------------------------------------------------===//
676 // MinFOp
677 //===----------------------------------------------------------------------===//
678 
679 OpFoldResult arith::MinFOp::fold(ArrayRef<Attribute> operands) {
680   assert(operands.size() == 2 && "minf takes two operands");
681 
682   // minf(x,x) -> x
683   if (getLhs() == getRhs())
684     return getRhs();
685 
686   // minf(x, +inf) -> x
687   if (matchPattern(getRhs(), m_PosInfFloat()))
688     return getLhs();
689 
690   return constFoldBinaryOp<FloatAttr>(
691       operands,
692       [](const APFloat &a, const APFloat &b) { return llvm::minimum(a, b); });
693 }
694 
695 //===----------------------------------------------------------------------===//
696 // MinSIOp
697 //===----------------------------------------------------------------------===//
698 
699 OpFoldResult MinSIOp::fold(ArrayRef<Attribute> operands) {
700   assert(operands.size() == 2 && "binary operation takes two operands");
701 
702   // minsi(x,x) -> x
703   if (getLhs() == getRhs())
704     return getRhs();
705 
706   APInt intValue;
707   // minsi(x,MIN_INT) -> MIN_INT
708   if (matchPattern(getRhs(), m_ConstantInt(&intValue)) &&
709       intValue.isMinSignedValue())
710     return getRhs();
711 
712   // minsi(x, MAX_INT) -> x
713   if (matchPattern(getRhs(), m_ConstantInt(&intValue)) &&
714       intValue.isMaxSignedValue())
715     return getLhs();
716 
717   return constFoldBinaryOp<IntegerAttr>(operands,
718                                         [](const APInt &a, const APInt &b) {
719                                           return llvm::APIntOps::smin(a, b);
720                                         });
721 }
722 
723 //===----------------------------------------------------------------------===//
724 // MinUIOp
725 //===----------------------------------------------------------------------===//
726 
727 OpFoldResult MinUIOp::fold(ArrayRef<Attribute> operands) {
728   assert(operands.size() == 2 && "binary operation takes two operands");
729 
730   // minui(x,x) -> x
731   if (getLhs() == getRhs())
732     return getRhs();
733 
734   APInt intValue;
735   // minui(x,MIN_INT) -> MIN_INT
736   if (matchPattern(getRhs(), m_ConstantInt(&intValue)) && intValue.isMinValue())
737     return getRhs();
738 
739   // minui(x, MAX_INT) -> x
740   if (matchPattern(getRhs(), m_ConstantInt(&intValue)) && intValue.isMaxValue())
741     return getLhs();
742 
743   return constFoldBinaryOp<IntegerAttr>(operands,
744                                         [](const APInt &a, const APInt &b) {
745                                           return llvm::APIntOps::umin(a, b);
746                                         });
747 }
748 
749 //===----------------------------------------------------------------------===//
750 // MulFOp
751 //===----------------------------------------------------------------------===//
752 
753 OpFoldResult arith::MulFOp::fold(ArrayRef<Attribute> operands) {
754   APFloat floatValue(0.0f), inverseValue(0.0f);
755   // mulf(x, 1) -> x
756   if (matchPattern(getRhs(), m_OneFloat()))
757     return getLhs();
758 
759   // mulf(1, x) -> x
760   if (matchPattern(getLhs(), m_OneFloat()))
761     return getRhs();
762 
763   return constFoldBinaryOp<FloatAttr>(
764       operands, [](const APFloat &a, const APFloat &b) { return a * b; });
765 }
766 
767 //===----------------------------------------------------------------------===//
768 // DivFOp
769 //===----------------------------------------------------------------------===//
770 
771 OpFoldResult arith::DivFOp::fold(ArrayRef<Attribute> operands) {
772   APFloat floatValue(0.0f), inverseValue(0.0f);
773   // divf(x, 1) -> x
774   if (matchPattern(getRhs(), m_OneFloat()))
775     return getLhs();
776 
777   return constFoldBinaryOp<FloatAttr>(
778       operands, [](const APFloat &a, const APFloat &b) { return a / b; });
779 }
780 
781 //===----------------------------------------------------------------------===//
782 // Utility functions for verifying cast ops
783 //===----------------------------------------------------------------------===//
784 
/// Tag type that carries a parameter pack of types through function-template
/// argument deduction (see getUnderlyingType). Only its type matters; the
/// visible uses pass value-initialized (null) pointers of this type.
template <typename... Types>
using type_list = std::tuple<Types...> *;
787 
788 /// Returns a non-null type only if the provided type is one of the allowed
789 /// types or one of the allowed shaped types of the allowed types. Returns the
790 /// element type if a valid shaped type is provided.
791 template <typename... ShapedTypes, typename... ElementTypes>
792 static Type getUnderlyingType(Type type, type_list<ShapedTypes...>,
793                               type_list<ElementTypes...>) {
794   if (type.isa<ShapedType>() && !type.isa<ShapedTypes...>())
795     return {};
796 
797   auto underlyingType = getElementTypeOrSelf(type);
798   if (!underlyingType.isa<ElementTypes...>())
799     return {};
800 
801   return underlyingType;
802 }
803 
804 /// Get allowed underlying types for vectors and tensors.
805 template <typename... ElementTypes>
806 static Type getTypeIfLike(Type type) {
807   return getUnderlyingType(type, type_list<VectorType, TensorType>(),
808                            type_list<ElementTypes...>());
809 }
810 
811 /// Get allowed underlying types for vectors, tensors, and memrefs.
812 template <typename... ElementTypes>
813 static Type getTypeIfLikeOrMemRef(Type type) {
814   return getUnderlyingType(type,
815                            type_list<VectorType, TensorType, MemRefType>(),
816                            type_list<ElementTypes...>());
817 }
818 
819 static bool areValidCastInputsAndOutputs(TypeRange inputs, TypeRange outputs) {
820   return inputs.size() == 1 && outputs.size() == 1 &&
821          succeeded(verifyCompatibleShapes(inputs.front(), outputs.front()));
822 }
823 
824 //===----------------------------------------------------------------------===//
825 // Verifiers for integer and floating point extension/truncation ops
826 //===----------------------------------------------------------------------===//
827 
828 // Extend ops can only extend to a wider type.
829 template <typename ValType, typename Op>
830 static LogicalResult verifyExtOp(Op op) {
831   Type srcType = getElementTypeOrSelf(op.getIn().getType());
832   Type dstType = getElementTypeOrSelf(op.getType());
833 
834   if (srcType.cast<ValType>().getWidth() >= dstType.cast<ValType>().getWidth())
835     return op.emitError("result type ")
836            << dstType << " must be wider than operand type " << srcType;
837 
838   return success();
839 }
840 
841 // Truncate ops can only truncate to a shorter type.
842 template <typename ValType, typename Op>
843 static LogicalResult verifyTruncateOp(Op op) {
844   Type srcType = getElementTypeOrSelf(op.getIn().getType());
845   Type dstType = getElementTypeOrSelf(op.getType());
846 
847   if (srcType.cast<ValType>().getWidth() <= dstType.cast<ValType>().getWidth())
848     return op.emitError("result type ")
849            << dstType << " must be shorter than operand type " << srcType;
850 
851   return success();
852 }
853 
854 /// Validate a cast that changes the width of a type.
855 template <template <typename> class WidthComparator, typename... ElementTypes>
856 static bool checkWidthChangeCast(TypeRange inputs, TypeRange outputs) {
857   if (!areValidCastInputsAndOutputs(inputs, outputs))
858     return false;
859 
860   auto srcType = getTypeIfLike<ElementTypes...>(inputs.front());
861   auto dstType = getTypeIfLike<ElementTypes...>(outputs.front());
862   if (!srcType || !dstType)
863     return false;
864 
865   return WidthComparator<unsigned>()(dstType.getIntOrFloatBitWidth(),
866                                      srcType.getIntOrFloatBitWidth());
867 }
868 
869 //===----------------------------------------------------------------------===//
870 // ExtUIOp
871 //===----------------------------------------------------------------------===//
872 
873 OpFoldResult arith::ExtUIOp::fold(ArrayRef<Attribute> operands) {
874   if (auto lhs = operands[0].dyn_cast_or_null<IntegerAttr>())
875     return IntegerAttr::get(
876         getType(), lhs.getValue().zext(getType().getIntOrFloatBitWidth()));
877 
878   if (auto lhs = getIn().getDefiningOp<ExtUIOp>()) {
879     getInMutable().assign(lhs.getIn());
880     return getResult();
881   }
882 
883   return {};
884 }
885 
886 bool arith::ExtUIOp::areCastCompatible(TypeRange inputs, TypeRange outputs) {
887   return checkWidthChangeCast<std::greater, IntegerType>(inputs, outputs);
888 }
889 
890 LogicalResult arith::ExtUIOp::verify() {
891   return verifyExtOp<IntegerType>(*this);
892 }
893 
894 //===----------------------------------------------------------------------===//
895 // ExtSIOp
896 //===----------------------------------------------------------------------===//
897 
898 OpFoldResult arith::ExtSIOp::fold(ArrayRef<Attribute> operands) {
899   if (auto lhs = operands[0].dyn_cast_or_null<IntegerAttr>())
900     return IntegerAttr::get(
901         getType(), lhs.getValue().sext(getType().getIntOrFloatBitWidth()));
902 
903   if (auto lhs = getIn().getDefiningOp<ExtSIOp>()) {
904     getInMutable().assign(lhs.getIn());
905     return getResult();
906   }
907 
908   return {};
909 }
910 
911 bool arith::ExtSIOp::areCastCompatible(TypeRange inputs, TypeRange outputs) {
912   return checkWidthChangeCast<std::greater, IntegerType>(inputs, outputs);
913 }
914 
915 void arith::ExtSIOp::getCanonicalizationPatterns(
916     RewritePatternSet &patterns, MLIRContext *context) {
917   patterns.add<ExtSIOfExtUI>(context);
918 }
919 
920 LogicalResult arith::ExtSIOp::verify() {
921   return verifyExtOp<IntegerType>(*this);
922 }
923 
924 //===----------------------------------------------------------------------===//
925 // ExtFOp
926 //===----------------------------------------------------------------------===//
927 
928 bool arith::ExtFOp::areCastCompatible(TypeRange inputs, TypeRange outputs) {
929   return checkWidthChangeCast<std::greater, FloatType>(inputs, outputs);
930 }
931 
932 LogicalResult arith::ExtFOp::verify() { return verifyExtOp<FloatType>(*this); }
933 
934 //===----------------------------------------------------------------------===//
935 // TruncIOp
936 //===----------------------------------------------------------------------===//
937 
938 OpFoldResult arith::TruncIOp::fold(ArrayRef<Attribute> operands) {
939   assert(operands.size() == 1 && "unary operation takes one operand");
940 
941   // trunci(zexti(a)) -> a
942   // trunci(sexti(a)) -> a
943   if (matchPattern(getOperand(), m_Op<arith::ExtUIOp>()) ||
944       matchPattern(getOperand(), m_Op<arith::ExtSIOp>()))
945     return getOperand().getDefiningOp()->getOperand(0);
946 
947   // trunci(trunci(a)) -> trunci(a))
948   if (matchPattern(getOperand(), m_Op<arith::TruncIOp>())) {
949     setOperand(getOperand().getDefiningOp()->getOperand(0));
950     return getResult();
951   }
952 
953   if (!operands[0])
954     return {};
955 
956   if (auto lhs = operands[0].dyn_cast<IntegerAttr>()) {
957     return IntegerAttr::get(
958         getType(), lhs.getValue().trunc(getType().getIntOrFloatBitWidth()));
959   }
960 
961   return {};
962 }
963 
964 bool arith::TruncIOp::areCastCompatible(TypeRange inputs, TypeRange outputs) {
965   return checkWidthChangeCast<std::less, IntegerType>(inputs, outputs);
966 }
967 
968 LogicalResult arith::TruncIOp::verify() {
969   return verifyTruncateOp<IntegerType>(*this);
970 }
971 
972 //===----------------------------------------------------------------------===//
973 // TruncFOp
974 //===----------------------------------------------------------------------===//
975 
976 /// Perform safe const propagation for truncf, i.e. only propagate if FP value
977 /// can be represented without precision loss or rounding.
978 OpFoldResult arith::TruncFOp::fold(ArrayRef<Attribute> operands) {
979   assert(operands.size() == 1 && "unary operation takes one operand");
980 
981   auto constOperand = operands.front();
982   if (!constOperand || !constOperand.isa<FloatAttr>())
983     return {};
984 
985   // Convert to target type via 'double'.
986   double sourceValue =
987       constOperand.dyn_cast<FloatAttr>().getValue().convertToDouble();
988   auto targetAttr = FloatAttr::get(getType(), sourceValue);
989 
990   // Propagate if constant's value does not change after truncation.
991   if (sourceValue == targetAttr.getValue().convertToDouble())
992     return targetAttr;
993 
994   return {};
995 }
996 
997 bool arith::TruncFOp::areCastCompatible(TypeRange inputs, TypeRange outputs) {
998   return checkWidthChangeCast<std::less, FloatType>(inputs, outputs);
999 }
1000 
1001 LogicalResult arith::TruncFOp::verify() {
1002   return verifyTruncateOp<FloatType>(*this);
1003 }
1004 
1005 //===----------------------------------------------------------------------===//
1006 // AndIOp
1007 //===----------------------------------------------------------------------===//
1008 
1009 void arith::AndIOp::getCanonicalizationPatterns(
1010     RewritePatternSet &patterns, MLIRContext *context) {
1011   patterns.add<AndOfExtUI, AndOfExtSI>(context);
1012 }
1013 
1014 //===----------------------------------------------------------------------===//
1015 // OrIOp
1016 //===----------------------------------------------------------------------===//
1017 
1018 void arith::OrIOp::getCanonicalizationPatterns(
1019     RewritePatternSet &patterns, MLIRContext *context) {
1020   patterns.add<OrOfExtUI, OrOfExtSI>(context);
1021 }
1022 
1023 //===----------------------------------------------------------------------===//
1024 // Verifiers for casts between integers and floats.
1025 //===----------------------------------------------------------------------===//
1026 
1027 template <typename From, typename To>
1028 static bool checkIntFloatCast(TypeRange inputs, TypeRange outputs) {
1029   if (!areValidCastInputsAndOutputs(inputs, outputs))
1030     return false;
1031 
1032   auto srcType = getTypeIfLike<From>(inputs.front());
1033   auto dstType = getTypeIfLike<To>(outputs.back());
1034 
1035   return srcType && dstType;
1036 }
1037 
1038 //===----------------------------------------------------------------------===//
1039 // UIToFPOp
1040 //===----------------------------------------------------------------------===//
1041 
1042 bool arith::UIToFPOp::areCastCompatible(TypeRange inputs, TypeRange outputs) {
1043   return checkIntFloatCast<IntegerType, FloatType>(inputs, outputs);
1044 }
1045 
1046 OpFoldResult arith::UIToFPOp::fold(ArrayRef<Attribute> operands) {
1047   if (auto lhs = operands[0].dyn_cast_or_null<IntegerAttr>()) {
1048     const APInt &api = lhs.getValue();
1049     FloatType floatTy = getType().cast<FloatType>();
1050     APFloat apf(floatTy.getFloatSemantics(),
1051                 APInt::getZero(floatTy.getWidth()));
1052     apf.convertFromAPInt(api, /*IsSigned=*/false, APFloat::rmNearestTiesToEven);
1053     return FloatAttr::get(floatTy, apf);
1054   }
1055   return {};
1056 }
1057 
1058 //===----------------------------------------------------------------------===//
1059 // SIToFPOp
1060 //===----------------------------------------------------------------------===//
1061 
1062 bool arith::SIToFPOp::areCastCompatible(TypeRange inputs, TypeRange outputs) {
1063   return checkIntFloatCast<IntegerType, FloatType>(inputs, outputs);
1064 }
1065 
1066 OpFoldResult arith::SIToFPOp::fold(ArrayRef<Attribute> operands) {
1067   if (auto lhs = operands[0].dyn_cast_or_null<IntegerAttr>()) {
1068     const APInt &api = lhs.getValue();
1069     FloatType floatTy = getType().cast<FloatType>();
1070     APFloat apf(floatTy.getFloatSemantics(),
1071                 APInt::getZero(floatTy.getWidth()));
1072     apf.convertFromAPInt(api, /*IsSigned=*/true, APFloat::rmNearestTiesToEven);
1073     return FloatAttr::get(floatTy, apf);
1074   }
1075   return {};
1076 }
1077 //===----------------------------------------------------------------------===//
1078 // FPToUIOp
1079 //===----------------------------------------------------------------------===//
1080 
1081 bool arith::FPToUIOp::areCastCompatible(TypeRange inputs, TypeRange outputs) {
1082   return checkIntFloatCast<FloatType, IntegerType>(inputs, outputs);
1083 }
1084 
1085 OpFoldResult arith::FPToUIOp::fold(ArrayRef<Attribute> operands) {
1086   if (auto lhs = operands[0].dyn_cast_or_null<FloatAttr>()) {
1087     const APFloat &apf = lhs.getValue();
1088     IntegerType intTy = getType().cast<IntegerType>();
1089     bool ignored;
1090     APSInt api(intTy.getWidth(), /*isUnsigned=*/true);
1091     if (APFloat::opInvalidOp ==
1092         apf.convertToInteger(api, APFloat::rmTowardZero, &ignored)) {
1093       // Undefined behavior invoked - the destination type can't represent
1094       // the input constant.
1095       return {};
1096     }
1097     return IntegerAttr::get(getType(), api);
1098   }
1099 
1100   return {};
1101 }
1102 
1103 //===----------------------------------------------------------------------===//
1104 // FPToSIOp
1105 //===----------------------------------------------------------------------===//
1106 
1107 bool arith::FPToSIOp::areCastCompatible(TypeRange inputs, TypeRange outputs) {
1108   return checkIntFloatCast<FloatType, IntegerType>(inputs, outputs);
1109 }
1110 
1111 OpFoldResult arith::FPToSIOp::fold(ArrayRef<Attribute> operands) {
1112   if (auto lhs = operands[0].dyn_cast_or_null<FloatAttr>()) {
1113     const APFloat &apf = lhs.getValue();
1114     IntegerType intTy = getType().cast<IntegerType>();
1115     bool ignored;
1116     APSInt api(intTy.getWidth(), /*isUnsigned=*/false);
1117     if (APFloat::opInvalidOp ==
1118         apf.convertToInteger(api, APFloat::rmTowardZero, &ignored)) {
1119       // Undefined behavior invoked - the destination type can't represent
1120       // the input constant.
1121       return {};
1122     }
1123     return IntegerAttr::get(getType(), api);
1124   }
1125 
1126   return {};
1127 }
1128 
1129 //===----------------------------------------------------------------------===//
1130 // IndexCastOp
1131 //===----------------------------------------------------------------------===//
1132 
1133 bool arith::IndexCastOp::areCastCompatible(TypeRange inputs,
1134                                            TypeRange outputs) {
1135   if (!areValidCastInputsAndOutputs(inputs, outputs))
1136     return false;
1137 
1138   auto srcType = getTypeIfLikeOrMemRef<IntegerType, IndexType>(inputs.front());
1139   auto dstType = getTypeIfLikeOrMemRef<IntegerType, IndexType>(outputs.front());
1140   if (!srcType || !dstType)
1141     return false;
1142 
1143   return (srcType.isIndex() && dstType.isSignlessInteger()) ||
1144          (srcType.isSignlessInteger() && dstType.isIndex());
1145 }
1146 
1147 OpFoldResult arith::IndexCastOp::fold(ArrayRef<Attribute> operands) {
1148   // index_cast(constant) -> constant
1149   // A little hack because we go through int. Otherwise, the size of the
1150   // constant might need to change.
1151   if (auto value = operands[0].dyn_cast_or_null<IntegerAttr>())
1152     return IntegerAttr::get(getType(), value.getInt());
1153 
1154   return {};
1155 }
1156 
1157 void arith::IndexCastOp::getCanonicalizationPatterns(
1158     RewritePatternSet &patterns, MLIRContext *context) {
1159   patterns.add<IndexCastOfIndexCast, IndexCastOfExtSI>(context);
1160 }
1161 
1162 //===----------------------------------------------------------------------===//
1163 // BitcastOp
1164 //===----------------------------------------------------------------------===//
1165 
1166 bool arith::BitcastOp::areCastCompatible(TypeRange inputs, TypeRange outputs) {
1167   if (!areValidCastInputsAndOutputs(inputs, outputs))
1168     return false;
1169 
1170   auto srcType =
1171       getTypeIfLikeOrMemRef<IntegerType, IndexType, FloatType>(inputs.front());
1172   auto dstType =
1173       getTypeIfLikeOrMemRef<IntegerType, IndexType, FloatType>(outputs.front());
1174   if (!srcType || !dstType)
1175     return false;
1176 
1177   return srcType.getIntOrFloatBitWidth() == dstType.getIntOrFloatBitWidth();
1178 }
1179 
1180 OpFoldResult arith::BitcastOp::fold(ArrayRef<Attribute> operands) {
1181   assert(operands.size() == 1 && "bitcast op expects 1 operand");
1182 
1183   auto resType = getType();
1184   auto operand = operands[0];
1185   if (!operand)
1186     return {};
1187 
1188   /// Bitcast dense elements.
1189   if (auto denseAttr = operand.dyn_cast_or_null<DenseElementsAttr>())
1190     return denseAttr.bitcast(resType.cast<ShapedType>().getElementType());
1191   /// Other shaped types unhandled.
1192   if (resType.isa<ShapedType>())
1193     return {};
1194 
1195   /// Bitcast integer or float to integer or float.
1196   APInt bits = operand.isa<FloatAttr>()
1197                    ? operand.cast<FloatAttr>().getValue().bitcastToAPInt()
1198                    : operand.cast<IntegerAttr>().getValue();
1199 
1200   if (auto resFloatType = resType.dyn_cast<FloatType>())
1201     return FloatAttr::get(resType,
1202                           APFloat(resFloatType.getFloatSemantics(), bits));
1203   return IntegerAttr::get(resType, bits);
1204 }
1205 
1206 void arith::BitcastOp::getCanonicalizationPatterns(
1207     RewritePatternSet &patterns, MLIRContext *context) {
1208   patterns.add<BitcastOfBitcast>(context);
1209 }
1210 
1211 //===----------------------------------------------------------------------===//
1212 // Helpers for compare ops
1213 //===----------------------------------------------------------------------===//
1214 
1215 /// Return the type of the same shape (scalar, vector or tensor) containing i1.
1216 static Type getI1SameShape(Type type) {
1217   auto i1Type = IntegerType::get(type.getContext(), 1);
1218   if (auto tensorType = type.dyn_cast<RankedTensorType>())
1219     return RankedTensorType::get(tensorType.getShape(), i1Type);
1220   if (type.isa<UnrankedTensorType>())
1221     return UnrankedTensorType::get(i1Type);
1222   if (auto vectorType = type.dyn_cast<VectorType>())
1223     return VectorType::get(vectorType.getShape(), i1Type,
1224                            vectorType.getNumScalableDims());
1225   return i1Type;
1226 }
1227 
1228 //===----------------------------------------------------------------------===//
1229 // CmpIOp
1230 //===----------------------------------------------------------------------===//
1231 
1232 /// Compute `lhs` `pred` `rhs`, where `pred` is one of the known integer
1233 /// comparison predicates.
1234 bool mlir::arith::applyCmpPredicate(arith::CmpIPredicate predicate,
1235                                     const APInt &lhs, const APInt &rhs) {
1236   switch (predicate) {
1237   case arith::CmpIPredicate::eq:
1238     return lhs.eq(rhs);
1239   case arith::CmpIPredicate::ne:
1240     return lhs.ne(rhs);
1241   case arith::CmpIPredicate::slt:
1242     return lhs.slt(rhs);
1243   case arith::CmpIPredicate::sle:
1244     return lhs.sle(rhs);
1245   case arith::CmpIPredicate::sgt:
1246     return lhs.sgt(rhs);
1247   case arith::CmpIPredicate::sge:
1248     return lhs.sge(rhs);
1249   case arith::CmpIPredicate::ult:
1250     return lhs.ult(rhs);
1251   case arith::CmpIPredicate::ule:
1252     return lhs.ule(rhs);
1253   case arith::CmpIPredicate::ugt:
1254     return lhs.ugt(rhs);
1255   case arith::CmpIPredicate::uge:
1256     return lhs.uge(rhs);
1257   }
1258   llvm_unreachable("unknown cmpi predicate kind");
1259 }
1260 
1261 /// Returns true if the predicate is true for two equal operands.
1262 static bool applyCmpPredicateToEqualOperands(arith::CmpIPredicate predicate) {
1263   switch (predicate) {
1264   case arith::CmpIPredicate::eq:
1265   case arith::CmpIPredicate::sle:
1266   case arith::CmpIPredicate::sge:
1267   case arith::CmpIPredicate::ule:
1268   case arith::CmpIPredicate::uge:
1269     return true;
1270   case arith::CmpIPredicate::ne:
1271   case arith::CmpIPredicate::slt:
1272   case arith::CmpIPredicate::sgt:
1273   case arith::CmpIPredicate::ult:
1274   case arith::CmpIPredicate::ugt:
1275     return false;
1276   }
1277   llvm_unreachable("unknown cmpi predicate kind");
1278 }
1279 
1280 static Attribute getBoolAttribute(Type type, MLIRContext *ctx, bool value) {
1281   auto boolAttr = BoolAttr::get(ctx, value);
1282   ShapedType shapedType = type.dyn_cast_or_null<ShapedType>();
1283   if (!shapedType)
1284     return boolAttr;
1285   return DenseElementsAttr::get(shapedType, boolAttr);
1286 }
1287 
1288 OpFoldResult arith::CmpIOp::fold(ArrayRef<Attribute> operands) {
1289   assert(operands.size() == 2 && "cmpi takes two operands");
1290 
1291   // cmpi(pred, x, x)
1292   if (getLhs() == getRhs()) {
1293     auto val = applyCmpPredicateToEqualOperands(getPredicate());
1294     return getBoolAttribute(getType(), getContext(), val);
1295   }
1296 
1297   if (matchPattern(getRhs(), m_Zero())) {
1298     if (auto extOp = getLhs().getDefiningOp<ExtSIOp>()) {
1299       if (extOp.getOperand().getType().cast<IntegerType>().getWidth() == 1) {
1300         // extsi(%x : i1 -> iN) != 0  ->  %x
1301         if (getPredicate() == arith::CmpIPredicate::ne) {
1302           return extOp.getOperand();
1303         }
1304       }
1305     }
1306     if (auto extOp = getLhs().getDefiningOp<ExtUIOp>()) {
1307       if (extOp.getOperand().getType().cast<IntegerType>().getWidth() == 1) {
1308         // extui(%x : i1 -> iN) != 0  ->  %x
1309         if (getPredicate() == arith::CmpIPredicate::ne) {
1310           return extOp.getOperand();
1311         }
1312       }
1313     }
1314   }
1315 
1316   auto lhs = operands.front().dyn_cast_or_null<IntegerAttr>();
1317   auto rhs = operands.back().dyn_cast_or_null<IntegerAttr>();
1318   if (!lhs || !rhs)
1319     return {};
1320 
1321   auto val = applyCmpPredicate(getPredicate(), lhs.getValue(), rhs.getValue());
1322   return BoolAttr::get(getContext(), val);
1323 }
1324 
1325 //===----------------------------------------------------------------------===//
1326 // CmpFOp
1327 //===----------------------------------------------------------------------===//
1328 
1329 /// Compute `lhs` `pred` `rhs`, where `pred` is one of the known floating point
1330 /// comparison predicates.
1331 bool mlir::arith::applyCmpPredicate(arith::CmpFPredicate predicate,
1332                                     const APFloat &lhs, const APFloat &rhs) {
1333   auto cmpResult = lhs.compare(rhs);
1334   switch (predicate) {
1335   case arith::CmpFPredicate::AlwaysFalse:
1336     return false;
1337   case arith::CmpFPredicate::OEQ:
1338     return cmpResult == APFloat::cmpEqual;
1339   case arith::CmpFPredicate::OGT:
1340     return cmpResult == APFloat::cmpGreaterThan;
1341   case arith::CmpFPredicate::OGE:
1342     return cmpResult == APFloat::cmpGreaterThan ||
1343            cmpResult == APFloat::cmpEqual;
1344   case arith::CmpFPredicate::OLT:
1345     return cmpResult == APFloat::cmpLessThan;
1346   case arith::CmpFPredicate::OLE:
1347     return cmpResult == APFloat::cmpLessThan || cmpResult == APFloat::cmpEqual;
1348   case arith::CmpFPredicate::ONE:
1349     return cmpResult != APFloat::cmpUnordered && cmpResult != APFloat::cmpEqual;
1350   case arith::CmpFPredicate::ORD:
1351     return cmpResult != APFloat::cmpUnordered;
1352   case arith::CmpFPredicate::UEQ:
1353     return cmpResult == APFloat::cmpUnordered || cmpResult == APFloat::cmpEqual;
1354   case arith::CmpFPredicate::UGT:
1355     return cmpResult == APFloat::cmpUnordered ||
1356            cmpResult == APFloat::cmpGreaterThan;
1357   case arith::CmpFPredicate::UGE:
1358     return cmpResult == APFloat::cmpUnordered ||
1359            cmpResult == APFloat::cmpGreaterThan ||
1360            cmpResult == APFloat::cmpEqual;
1361   case arith::CmpFPredicate::ULT:
1362     return cmpResult == APFloat::cmpUnordered ||
1363            cmpResult == APFloat::cmpLessThan;
1364   case arith::CmpFPredicate::ULE:
1365     return cmpResult == APFloat::cmpUnordered ||
1366            cmpResult == APFloat::cmpLessThan || cmpResult == APFloat::cmpEqual;
1367   case arith::CmpFPredicate::UNE:
1368     return cmpResult != APFloat::cmpEqual;
1369   case arith::CmpFPredicate::UNO:
1370     return cmpResult == APFloat::cmpUnordered;
1371   case arith::CmpFPredicate::AlwaysTrue:
1372     return true;
1373   }
1374   llvm_unreachable("unknown cmpf predicate kind");
1375 }
1376 
1377 OpFoldResult arith::CmpFOp::fold(ArrayRef<Attribute> operands) {
1378   assert(operands.size() == 2 && "cmpf takes two operands");
1379 
1380   auto lhs = operands.front().dyn_cast_or_null<FloatAttr>();
1381   auto rhs = operands.back().dyn_cast_or_null<FloatAttr>();
1382 
1383   // If one operand is NaN, making them both NaN does not change the result.
1384   if (lhs && lhs.getValue().isNaN())
1385     rhs = lhs;
1386   if (rhs && rhs.getValue().isNaN())
1387     lhs = rhs;
1388 
1389   if (!lhs || !rhs)
1390     return {};
1391 
1392   auto val = applyCmpPredicate(getPredicate(), lhs.getValue(), rhs.getValue());
1393   return BoolAttr::get(getContext(), val);
1394 }
1395 
1396 //===----------------------------------------------------------------------===//
1397 // SelectOp
1398 //===----------------------------------------------------------------------===//
1399 
1400 // Transforms a select of a boolean to arithmetic operations
1401 //
1402 //  arith.select %arg, %x, %y : i1
1403 //
1404 //  becomes
1405 //
1406 //  and(%arg, %x) or and(!%arg, %y)
1407 struct SelectI1Simplify : public OpRewritePattern<arith::SelectOp> {
1408   using OpRewritePattern<arith::SelectOp>::OpRewritePattern;
1409 
1410   LogicalResult matchAndRewrite(arith::SelectOp op,
1411                                 PatternRewriter &rewriter) const override {
1412     if (!op.getType().isInteger(1))
1413       return failure();
1414 
1415     Value falseConstant =
1416         rewriter.create<arith::ConstantIntOp>(op.getLoc(), true, 1);
1417     Value notCondition = rewriter.create<arith::XOrIOp>(
1418         op.getLoc(), op.getCondition(), falseConstant);
1419 
1420     Value trueVal = rewriter.create<arith::AndIOp>(
1421         op.getLoc(), op.getCondition(), op.getTrueValue());
1422     Value falseVal = rewriter.create<arith::AndIOp>(op.getLoc(), notCondition,
1423                                                     op.getFalseValue());
1424     rewriter.replaceOpWithNewOp<arith::OrIOp>(op, trueVal, falseVal);
1425     return success();
1426   }
1427 };
1428 
1429 //  select %arg, %c1, %c0 => extui %arg
1430 struct SelectToExtUI : public OpRewritePattern<arith::SelectOp> {
1431   using OpRewritePattern<arith::SelectOp>::OpRewritePattern;
1432 
1433   LogicalResult matchAndRewrite(arith::SelectOp op,
1434                                 PatternRewriter &rewriter) const override {
1435     // Cannot extui i1 to i1, or i1 to f32
1436     if (!op.getType().isa<IntegerType>() || op.getType().isInteger(1))
1437       return failure();
1438 
1439     // select %x, c1, %c0 => extui %arg
1440     if (matchPattern(op.getTrueValue(), m_One()))
1441       if (matchPattern(op.getFalseValue(), m_Zero())) {
1442         rewriter.replaceOpWithNewOp<arith::ExtUIOp>(op, op.getType(),
1443                                                     op.getCondition());
1444         return success();
1445       }
1446 
1447     // select %x, c0, %c1 => extui (xor %arg, true)
1448     if (matchPattern(op.getTrueValue(), m_Zero()))
1449       if (matchPattern(op.getFalseValue(), m_One())) {
1450         rewriter.replaceOpWithNewOp<arith::ExtUIOp>(
1451             op, op.getType(),
1452             rewriter.create<arith::XOrIOp>(
1453                 op.getLoc(), op.getCondition(),
1454                 rewriter.create<arith::ConstantIntOp>(
1455                     op.getLoc(), 1, op.getCondition().getType())));
1456         return success();
1457       }
1458 
1459     return failure();
1460   }
1461 };
1462 
1463 void arith::SelectOp::getCanonicalizationPatterns(RewritePatternSet &results,
1464                                                   MLIRContext *context) {
1465   results.add<SelectI1Simplify, SelectToExtUI>(context);
1466 }
1467 
1468 OpFoldResult arith::SelectOp::fold(ArrayRef<Attribute> operands) {
1469   Value trueVal = getTrueValue();
1470   Value falseVal = getFalseValue();
1471   if (trueVal == falseVal)
1472     return trueVal;
1473 
1474   Value condition = getCondition();
1475 
1476   // select true, %0, %1 => %0
1477   if (matchPattern(condition, m_One()))
1478     return trueVal;
1479 
1480   // select false, %0, %1 => %1
1481   if (matchPattern(condition, m_Zero()))
1482     return falseVal;
1483 
1484   // select %x, true, false => %x
1485   if (getType().isInteger(1))
1486     if (matchPattern(getTrueValue(), m_One()))
1487       if (matchPattern(getFalseValue(), m_Zero()))
1488         return condition;
1489 
1490   if (auto cmp = dyn_cast_or_null<arith::CmpIOp>(condition.getDefiningOp())) {
1491     auto pred = cmp.getPredicate();
1492     if (pred == arith::CmpIPredicate::eq || pred == arith::CmpIPredicate::ne) {
1493       auto cmpLhs = cmp.getLhs();
1494       auto cmpRhs = cmp.getRhs();
1495 
1496       // %0 = arith.cmpi eq, %arg0, %arg1
1497       // %1 = arith.select %0, %arg0, %arg1 => %arg1
1498 
1499       // %0 = arith.cmpi ne, %arg0, %arg1
1500       // %1 = arith.select %0, %arg0, %arg1 => %arg0
1501 
1502       if ((cmpLhs == trueVal && cmpRhs == falseVal) ||
1503           (cmpRhs == trueVal && cmpLhs == falseVal))
1504         return pred == arith::CmpIPredicate::ne ? trueVal : falseVal;
1505     }
1506   }
1507   return nullptr;
1508 }
1509 
1510 ParseResult SelectOp::parse(OpAsmParser &parser, OperationState &result) {
1511   Type conditionType, resultType;
1512   SmallVector<OpAsmParser::OperandType, 3> operands;
1513   if (parser.parseOperandList(operands, /*requiredOperandCount=*/3) ||
1514       parser.parseOptionalAttrDict(result.attributes) ||
1515       parser.parseColonType(resultType))
1516     return failure();
1517 
1518   // Check for the explicit condition type if this is a masked tensor or vector.
1519   if (succeeded(parser.parseOptionalComma())) {
1520     conditionType = resultType;
1521     if (parser.parseType(resultType))
1522       return failure();
1523   } else {
1524     conditionType = parser.getBuilder().getI1Type();
1525   }
1526 
1527   result.addTypes(resultType);
1528   return parser.resolveOperands(operands,
1529                                 {conditionType, resultType, resultType},
1530                                 parser.getNameLoc(), result.operands);
1531 }
1532 
1533 void arith::SelectOp::print(OpAsmPrinter &p) {
1534   p << " " << getOperands();
1535   p.printOptionalAttrDict((*this)->getAttrs());
1536   p << " : ";
1537   if (ShapedType condType = getCondition().getType().dyn_cast<ShapedType>())
1538     p << condType << ", ";
1539   p << getType();
1540 }
1541 
1542 LogicalResult arith::SelectOp::verify() {
1543   Type conditionType = getCondition().getType();
1544   if (conditionType.isSignlessInteger(1))
1545     return success();
1546 
1547   // If the result type is a vector or tensor, the type can be a mask with the
1548   // same elements.
1549   Type resultType = getType();
1550   if (!resultType.isa<TensorType, VectorType>())
1551     return emitOpError() << "expected condition to be a signless i1, but got "
1552                          << conditionType;
1553   Type shapedConditionType = getI1SameShape(resultType);
1554   if (conditionType != shapedConditionType) {
1555     return emitOpError() << "expected condition type to have the same shape "
1556                             "as the result type, expected "
1557                          << shapedConditionType << ", but got "
1558                          << conditionType;
1559   }
1560   return success();
1561 }
1562 
1563 //===----------------------------------------------------------------------===//
1564 // Atomic Enum
1565 //===----------------------------------------------------------------------===//
1566 
1567 /// Returns the identity value attribute associated with an AtomicRMWKind op.
1568 Attribute mlir::arith::getIdentityValueAttr(AtomicRMWKind kind, Type resultType,
1569                                             OpBuilder &builder, Location loc) {
1570   switch (kind) {
1571   case AtomicRMWKind::maxf:
1572     return builder.getFloatAttr(
1573         resultType,
1574         APFloat::getInf(resultType.cast<FloatType>().getFloatSemantics(),
1575                         /*Negative=*/true));
1576   case AtomicRMWKind::addf:
1577   case AtomicRMWKind::addi:
1578   case AtomicRMWKind::maxu:
1579   case AtomicRMWKind::ori:
1580     return builder.getZeroAttr(resultType);
1581   case AtomicRMWKind::andi:
1582     return builder.getIntegerAttr(
1583         resultType,
1584         APInt::getAllOnes(resultType.cast<IntegerType>().getWidth()));
1585   case AtomicRMWKind::maxs:
1586     return builder.getIntegerAttr(
1587         resultType,
1588         APInt::getSignedMinValue(resultType.cast<IntegerType>().getWidth()));
1589   case AtomicRMWKind::minf:
1590     return builder.getFloatAttr(
1591         resultType,
1592         APFloat::getInf(resultType.cast<FloatType>().getFloatSemantics(),
1593                         /*Negative=*/false));
1594   case AtomicRMWKind::mins:
1595     return builder.getIntegerAttr(
1596         resultType,
1597         APInt::getSignedMaxValue(resultType.cast<IntegerType>().getWidth()));
1598   case AtomicRMWKind::minu:
1599     return builder.getIntegerAttr(
1600         resultType,
1601         APInt::getMaxValue(resultType.cast<IntegerType>().getWidth()));
1602   case AtomicRMWKind::muli:
1603     return builder.getIntegerAttr(resultType, 1);
1604   case AtomicRMWKind::mulf:
1605     return builder.getFloatAttr(resultType, 1);
1606   // TODO: Add remaining reduction operations.
1607   default:
1608     (void)emitOptionalError(loc, "Reduction operation type not supported");
1609     break;
1610   }
1611   return nullptr;
1612 }
1613 
1614 /// Returns the identity value associated with an AtomicRMWKind op.
1615 Value mlir::arith::getIdentityValue(AtomicRMWKind op, Type resultType,
1616                                     OpBuilder &builder, Location loc) {
1617   Attribute attr = getIdentityValueAttr(op, resultType, builder, loc);
1618   return builder.create<arith::ConstantOp>(loc, attr);
1619 }
1620 
1621 /// Return the value obtained by applying the reduction operation kind
1622 /// associated with a binary AtomicRMWKind op to `lhs` and `rhs`.
1623 Value mlir::arith::getReductionOp(AtomicRMWKind op, OpBuilder &builder,
1624                                   Location loc, Value lhs, Value rhs) {
1625   switch (op) {
1626   case AtomicRMWKind::addf:
1627     return builder.create<arith::AddFOp>(loc, lhs, rhs);
1628   case AtomicRMWKind::addi:
1629     return builder.create<arith::AddIOp>(loc, lhs, rhs);
1630   case AtomicRMWKind::mulf:
1631     return builder.create<arith::MulFOp>(loc, lhs, rhs);
1632   case AtomicRMWKind::muli:
1633     return builder.create<arith::MulIOp>(loc, lhs, rhs);
1634   case AtomicRMWKind::maxf:
1635     return builder.create<arith::MaxFOp>(loc, lhs, rhs);
1636   case AtomicRMWKind::minf:
1637     return builder.create<arith::MinFOp>(loc, lhs, rhs);
1638   case AtomicRMWKind::maxs:
1639     return builder.create<arith::MaxSIOp>(loc, lhs, rhs);
1640   case AtomicRMWKind::mins:
1641     return builder.create<arith::MinSIOp>(loc, lhs, rhs);
1642   case AtomicRMWKind::maxu:
1643     return builder.create<arith::MaxUIOp>(loc, lhs, rhs);
1644   case AtomicRMWKind::minu:
1645     return builder.create<arith::MinUIOp>(loc, lhs, rhs);
1646   case AtomicRMWKind::ori:
1647     return builder.create<arith::OrIOp>(loc, lhs, rhs);
1648   case AtomicRMWKind::andi:
1649     return builder.create<arith::AndIOp>(loc, lhs, rhs);
1650   // TODO: Add remaining reduction operations.
1651   default:
1652     (void)emitOptionalError(loc, "Reduction operation type not supported");
1653     break;
1654   }
1655   return nullptr;
1656 }
1657 
1658 //===----------------------------------------------------------------------===//
1659 // TableGen'd op method definitions
1660 //===----------------------------------------------------------------------===//
1661 
1662 #define GET_OP_CLASSES
1663 #include "mlir/Dialect/Arithmetic/IR/ArithmeticOps.cpp.inc"
1664 
1665 //===----------------------------------------------------------------------===//
1666 // TableGen'd enum attribute definitions
1667 //===----------------------------------------------------------------------===//
1668 
1669 #include "mlir/Dialect/Arithmetic/IR/ArithmeticOpsEnums.cpp.inc"
1670