1 //===- ArithmeticOps.cpp - MLIR Arithmetic dialect ops implementation -----===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 
9 #include <utility>
10 
11 #include "mlir/Dialect/Arithmetic/IR/Arithmetic.h"
12 #include "mlir/Dialect/CommonFolders.h"
13 #include "mlir/IR/Builders.h"
14 #include "mlir/IR/Matchers.h"
15 #include "mlir/IR/OpImplementation.h"
16 #include "mlir/IR/PatternMatch.h"
17 #include "mlir/IR/TypeUtilities.h"
18 
19 #include "llvm/ADT/APSInt.h"
20 
21 using namespace mlir;
22 using namespace mlir::arith;
23 
24 //===----------------------------------------------------------------------===//
25 // Pattern helpers
26 //===----------------------------------------------------------------------===//
27 
28 static IntegerAttr addIntegerAttrs(PatternRewriter &builder, Value res,
29                                    Attribute lhs, Attribute rhs) {
30   return builder.getIntegerAttr(res.getType(),
31                                 lhs.cast<IntegerAttr>().getInt() +
32                                     rhs.cast<IntegerAttr>().getInt());
33 }
34 
35 static IntegerAttr subIntegerAttrs(PatternRewriter &builder, Value res,
36                                    Attribute lhs, Attribute rhs) {
37   return builder.getIntegerAttr(res.getType(),
38                                 lhs.cast<IntegerAttr>().getInt() -
39                                     rhs.cast<IntegerAttr>().getInt());
40 }
41 
42 /// Invert an integer comparison predicate.
43 static arith::CmpIPredicate invertPredicate(arith::CmpIPredicate pred) {
44   switch (pred) {
45   case arith::CmpIPredicate::eq:
46     return arith::CmpIPredicate::ne;
47   case arith::CmpIPredicate::ne:
48     return arith::CmpIPredicate::eq;
49   case arith::CmpIPredicate::slt:
50     return arith::CmpIPredicate::sge;
51   case arith::CmpIPredicate::sle:
52     return arith::CmpIPredicate::sgt;
53   case arith::CmpIPredicate::sgt:
54     return arith::CmpIPredicate::sle;
55   case arith::CmpIPredicate::sge:
56     return arith::CmpIPredicate::slt;
57   case arith::CmpIPredicate::ult:
58     return arith::CmpIPredicate::uge;
59   case arith::CmpIPredicate::ule:
60     return arith::CmpIPredicate::ugt;
61   case arith::CmpIPredicate::ugt:
62     return arith::CmpIPredicate::ule;
63   case arith::CmpIPredicate::uge:
64     return arith::CmpIPredicate::ult;
65   }
66   llvm_unreachable("unknown cmpi predicate kind");
67 }
68 
69 static arith::CmpIPredicateAttr invertPredicate(arith::CmpIPredicateAttr pred) {
70   return arith::CmpIPredicateAttr::get(pred.getContext(),
71                                        invertPredicate(pred.getValue()));
72 }
73 
74 //===----------------------------------------------------------------------===//
75 // TableGen'd canonicalization patterns
76 //===----------------------------------------------------------------------===//
77 
// Pull in the TableGen-generated canonicalization patterns. Presumably this
// defines the pattern classes (AddIAddConstant, XOrINotCmpI, ExtSIOfExtUI, ...)
// registered by the getCanonicalizationPatterns hooks below. Wrapping the
// include in an anonymous namespace keeps the generated definitions local to
// this translation unit.
namespace {
#include "ArithmeticCanonicalization.inc"
} // namespace
81 
82 //===----------------------------------------------------------------------===//
83 // ConstantOp
84 //===----------------------------------------------------------------------===//
85 
86 void arith::ConstantOp::getAsmResultNames(
87     function_ref<void(Value, StringRef)> setNameFn) {
88   auto type = getType();
89   if (auto intCst = getValue().dyn_cast<IntegerAttr>()) {
90     auto intType = type.dyn_cast<IntegerType>();
91 
92     // Sugar i1 constants with 'true' and 'false'.
93     if (intType && intType.getWidth() == 1)
94       return setNameFn(getResult(), (intCst.getInt() ? "true" : "false"));
95 
96     // Otherwise, build a compex name with the value and type.
97     SmallString<32> specialNameBuffer;
98     llvm::raw_svector_ostream specialName(specialNameBuffer);
99     specialName << 'c' << intCst.getInt();
100     if (intType)
101       specialName << '_' << type;
102     setNameFn(getResult(), specialName.str());
103   } else {
104     setNameFn(getResult(), "cst");
105   }
106 }
107 
108 /// TODO: disallow arith.constant to return anything other than signless integer
109 /// or float like.
110 static LogicalResult verify(arith::ConstantOp op) {
111   auto type = op.getType();
112   // The value's type must match the return type.
113   if (op.getValue().getType() != type) {
114     return op.emitOpError() << "value type " << op.getValue().getType()
115                             << " must match return type: " << type;
116   }
117   // Integer values must be signless.
118   if (type.isa<IntegerType>() && !type.cast<IntegerType>().isSignless())
119     return op.emitOpError("integer return type must be signless");
120   // Any float or elements attribute are acceptable.
121   if (!op.getValue().isa<IntegerAttr, FloatAttr, ElementsAttr>()) {
122     return op.emitOpError(
123         "value must be an integer, float, or elements attribute");
124   }
125   return success();
126 }
127 
128 bool arith::ConstantOp::isBuildableWith(Attribute value, Type type) {
129   // The value's type must be the same as the provided type.
130   if (value.getType() != type)
131     return false;
132   // Integer values must be signless.
133   if (type.isa<IntegerType>() && !type.cast<IntegerType>().isSignless())
134     return false;
135   // Integer, float, and element attributes are buildable.
136   return value.isa<IntegerAttr, FloatAttr, ElementsAttr>();
137 }
138 
139 OpFoldResult arith::ConstantOp::fold(ArrayRef<Attribute> operands) {
140   return getValue();
141 }
142 
143 void arith::ConstantIntOp::build(OpBuilder &builder, OperationState &result,
144                                  int64_t value, unsigned width) {
145   auto type = builder.getIntegerType(width);
146   arith::ConstantOp::build(builder, result, type,
147                            builder.getIntegerAttr(type, value));
148 }
149 
150 void arith::ConstantIntOp::build(OpBuilder &builder, OperationState &result,
151                                  int64_t value, Type type) {
152   assert(type.isSignlessInteger() &&
153          "ConstantIntOp can only have signless integer type values");
154   arith::ConstantOp::build(builder, result, type,
155                            builder.getIntegerAttr(type, value));
156 }
157 
158 bool arith::ConstantIntOp::classof(Operation *op) {
159   if (auto constOp = dyn_cast_or_null<arith::ConstantOp>(op))
160     return constOp.getType().isSignlessInteger();
161   return false;
162 }
163 
164 void arith::ConstantFloatOp::build(OpBuilder &builder, OperationState &result,
165                                    const APFloat &value, FloatType type) {
166   arith::ConstantOp::build(builder, result, type,
167                            builder.getFloatAttr(type, value));
168 }
169 
170 bool arith::ConstantFloatOp::classof(Operation *op) {
171   if (auto constOp = dyn_cast_or_null<arith::ConstantOp>(op))
172     return constOp.getType().isa<FloatType>();
173   return false;
174 }
175 
176 void arith::ConstantIndexOp::build(OpBuilder &builder, OperationState &result,
177                                    int64_t value) {
178   arith::ConstantOp::build(builder, result, builder.getIndexType(),
179                            builder.getIndexAttr(value));
180 }
181 
182 bool arith::ConstantIndexOp::classof(Operation *op) {
183   if (auto constOp = dyn_cast_or_null<arith::ConstantOp>(op))
184     return constOp.getType().isIndex();
185   return false;
186 }
187 
188 //===----------------------------------------------------------------------===//
189 // AddIOp
190 //===----------------------------------------------------------------------===//
191 
192 OpFoldResult arith::AddIOp::fold(ArrayRef<Attribute> operands) {
193   // addi(x, 0) -> x
194   if (matchPattern(getRhs(), m_Zero()))
195     return getLhs();
196 
197   // add(sub(a, b), b) -> a
198   if (auto sub = getLhs().getDefiningOp<SubIOp>())
199     if (getRhs() == sub.getRhs())
200       return sub.getLhs();
201 
202   // add(b, sub(a, b)) -> a
203   if (auto sub = getRhs().getDefiningOp<SubIOp>())
204     if (getLhs() == sub.getRhs())
205       return sub.getLhs();
206 
207   return constFoldBinaryOp<IntegerAttr>(
208       operands, [](APInt a, const APInt &b) { return std::move(a) + b; });
209 }
210 
211 void arith::AddIOp::getCanonicalizationPatterns(
212     OwningRewritePatternList &patterns, MLIRContext *context) {
213   patterns.insert<AddIAddConstant, AddISubConstantRHS, AddISubConstantLHS>(
214       context);
215 }
216 
217 //===----------------------------------------------------------------------===//
218 // SubIOp
219 //===----------------------------------------------------------------------===//
220 
221 OpFoldResult arith::SubIOp::fold(ArrayRef<Attribute> operands) {
222   // subi(x,x) -> 0
223   if (getOperand(0) == getOperand(1))
224     return Builder(getContext()).getZeroAttr(getType());
225   // subi(x,0) -> x
226   if (matchPattern(getRhs(), m_Zero()))
227     return getLhs();
228 
229   return constFoldBinaryOp<IntegerAttr>(
230       operands, [](APInt a, const APInt &b) { return std::move(a) - b; });
231 }
232 
233 void arith::SubIOp::getCanonicalizationPatterns(
234     OwningRewritePatternList &patterns, MLIRContext *context) {
235   patterns.insert<SubIRHSAddConstant, SubILHSAddConstant, SubIRHSSubConstantRHS,
236                   SubIRHSSubConstantLHS, SubILHSSubConstantRHS,
237                   SubILHSSubConstantLHS>(context);
238 }
239 
240 //===----------------------------------------------------------------------===//
241 // MulIOp
242 //===----------------------------------------------------------------------===//
243 
244 OpFoldResult arith::MulIOp::fold(ArrayRef<Attribute> operands) {
245   // muli(x, 0) -> 0
246   if (matchPattern(getRhs(), m_Zero()))
247     return getRhs();
248   // muli(x, 1) -> x
249   if (matchPattern(getRhs(), m_One()))
250     return getOperand(0);
251   // TODO: Handle the overflow case.
252 
253   // default folder
254   return constFoldBinaryOp<IntegerAttr>(
255       operands, [](const APInt &a, const APInt &b) { return a * b; });
256 }
257 
258 //===----------------------------------------------------------------------===//
259 // DivUIOp
260 //===----------------------------------------------------------------------===//
261 
262 OpFoldResult arith::DivUIOp::fold(ArrayRef<Attribute> operands) {
263   // Don't fold if it would require a division by zero.
264   bool div0 = false;
265   auto result =
266       constFoldBinaryOp<IntegerAttr>(operands, [&](APInt a, const APInt &b) {
267         if (div0 || !b) {
268           div0 = true;
269           return a;
270         }
271         return a.udiv(b);
272       });
273 
274   // Fold out division by one. Assumes all tensors of all ones are splats.
275   if (auto rhs = operands[1].dyn_cast_or_null<IntegerAttr>()) {
276     if (rhs.getValue() == 1)
277       return getLhs();
278   } else if (auto rhs = operands[1].dyn_cast_or_null<SplatElementsAttr>()) {
279     if (rhs.getSplatValue<IntegerAttr>().getValue() == 1)
280       return getLhs();
281   }
282 
283   return div0 ? Attribute() : result;
284 }
285 
286 //===----------------------------------------------------------------------===//
287 // DivSIOp
288 //===----------------------------------------------------------------------===//
289 
290 OpFoldResult arith::DivSIOp::fold(ArrayRef<Attribute> operands) {
291   // Don't fold if it would overflow or if it requires a division by zero.
292   bool overflowOrDiv0 = false;
293   auto result =
294       constFoldBinaryOp<IntegerAttr>(operands, [&](APInt a, const APInt &b) {
295         if (overflowOrDiv0 || !b) {
296           overflowOrDiv0 = true;
297           return a;
298         }
299         return a.sdiv_ov(b, overflowOrDiv0);
300       });
301 
302   // Fold out division by one. Assumes all tensors of all ones are splats.
303   if (auto rhs = operands[1].dyn_cast_or_null<IntegerAttr>()) {
304     if (rhs.getValue() == 1)
305       return getLhs();
306   } else if (auto rhs = operands[1].dyn_cast_or_null<SplatElementsAttr>()) {
307     if (rhs.getSplatValue<IntegerAttr>().getValue() == 1)
308       return getLhs();
309   }
310 
311   return overflowOrDiv0 ? Attribute() : result;
312 }
313 
314 //===----------------------------------------------------------------------===//
315 // Ceil and floor division folding helpers
316 //===----------------------------------------------------------------------===//
317 
318 static APInt signedCeilNonnegInputs(const APInt &a, const APInt &b,
319                                     bool &overflow) {
320   // Returns (a-1)/b + 1
321   APInt one(a.getBitWidth(), 1, true); // Signed value 1.
322   APInt val = a.ssub_ov(one, overflow).sdiv_ov(b, overflow);
323   return val.sadd_ov(one, overflow);
324 }
325 
326 //===----------------------------------------------------------------------===//
327 // CeilDivUIOp
328 //===----------------------------------------------------------------------===//
329 
330 OpFoldResult arith::CeilDivUIOp::fold(ArrayRef<Attribute> operands) {
331   bool overflowOrDiv0 = false;
332   auto result =
333       constFoldBinaryOp<IntegerAttr>(operands, [&](APInt a, const APInt &b) {
334         if (overflowOrDiv0 || !b) {
335           overflowOrDiv0 = true;
336           return a;
337         }
338         APInt quotient = a.udiv(b);
339         if (!a.urem(b))
340           return quotient;
341         APInt one(a.getBitWidth(), 1, true);
342         return quotient.uadd_ov(one, overflowOrDiv0);
343       });
344   // Fold out ceil division by one. Assumes all tensors of all ones are
345   // splats.
346   if (auto rhs = operands[1].dyn_cast_or_null<IntegerAttr>()) {
347     if (rhs.getValue() == 1)
348       return getLhs();
349   } else if (auto rhs = operands[1].dyn_cast_or_null<SplatElementsAttr>()) {
350     if (rhs.getSplatValue<IntegerAttr>().getValue() == 1)
351       return getLhs();
352   }
353 
354   return overflowOrDiv0 ? Attribute() : result;
355 }
356 
357 //===----------------------------------------------------------------------===//
358 // CeilDivSIOp
359 //===----------------------------------------------------------------------===//
360 
361 OpFoldResult arith::CeilDivSIOp::fold(ArrayRef<Attribute> operands) {
362   // Don't fold if it would overflow or if it requires a division by zero.
363   bool overflowOrDiv0 = false;
364   auto result =
365       constFoldBinaryOp<IntegerAttr>(operands, [&](APInt a, const APInt &b) {
366         if (overflowOrDiv0 || !b) {
367           overflowOrDiv0 = true;
368           return a;
369         }
370         if (!a)
371           return a;
372         // After this point we know that neither a or b are zero.
373         unsigned bits = a.getBitWidth();
374         APInt zero = APInt::getZero(bits);
375         bool aGtZero = a.sgt(zero);
376         bool bGtZero = b.sgt(zero);
377         if (aGtZero && bGtZero) {
378           // Both positive, return ceil(a, b).
379           return signedCeilNonnegInputs(a, b, overflowOrDiv0);
380         }
381         if (!aGtZero && !bGtZero) {
382           // Both negative, return ceil(-a, -b).
383           APInt posA = zero.ssub_ov(a, overflowOrDiv0);
384           APInt posB = zero.ssub_ov(b, overflowOrDiv0);
385           return signedCeilNonnegInputs(posA, posB, overflowOrDiv0);
386         }
387         if (!aGtZero && bGtZero) {
388           // A is negative, b is positive, return - ( -a / b).
389           APInt posA = zero.ssub_ov(a, overflowOrDiv0);
390           APInt div = posA.sdiv_ov(b, overflowOrDiv0);
391           return zero.ssub_ov(div, overflowOrDiv0);
392         }
393         // A is positive, b is negative, return - (a / -b).
394         APInt posB = zero.ssub_ov(b, overflowOrDiv0);
395         APInt div = a.sdiv_ov(posB, overflowOrDiv0);
396         return zero.ssub_ov(div, overflowOrDiv0);
397       });
398 
399   // Fold out ceil division by one. Assumes all tensors of all ones are
400   // splats.
401   if (auto rhs = operands[1].dyn_cast_or_null<IntegerAttr>()) {
402     if (rhs.getValue() == 1)
403       return getLhs();
404   } else if (auto rhs = operands[1].dyn_cast_or_null<SplatElementsAttr>()) {
405     if (rhs.getSplatValue<IntegerAttr>().getValue() == 1)
406       return getLhs();
407   }
408 
409   return overflowOrDiv0 ? Attribute() : result;
410 }
411 
412 //===----------------------------------------------------------------------===//
413 // FloorDivSIOp
414 //===----------------------------------------------------------------------===//
415 
416 OpFoldResult arith::FloorDivSIOp::fold(ArrayRef<Attribute> operands) {
417   // Don't fold if it would overflow or if it requires a division by zero.
418   bool overflowOrDiv0 = false;
419   auto result =
420       constFoldBinaryOp<IntegerAttr>(operands, [&](APInt a, const APInt &b) {
421         if (overflowOrDiv0 || !b) {
422           overflowOrDiv0 = true;
423           return a;
424         }
425         if (!a)
426           return a;
427         // After this point we know that neither a or b are zero.
428         unsigned bits = a.getBitWidth();
429         APInt zero = APInt::getZero(bits);
430         bool aGtZero = a.sgt(zero);
431         bool bGtZero = b.sgt(zero);
432         if (aGtZero && bGtZero) {
433           // Both positive, return a / b.
434           return a.sdiv_ov(b, overflowOrDiv0);
435         }
436         if (!aGtZero && !bGtZero) {
437           // Both negative, return -a / -b.
438           APInt posA = zero.ssub_ov(a, overflowOrDiv0);
439           APInt posB = zero.ssub_ov(b, overflowOrDiv0);
440           return posA.sdiv_ov(posB, overflowOrDiv0);
441         }
442         if (!aGtZero && bGtZero) {
443           // A is negative, b is positive, return - ceil(-a, b).
444           APInt posA = zero.ssub_ov(a, overflowOrDiv0);
445           APInt ceil = signedCeilNonnegInputs(posA, b, overflowOrDiv0);
446           return zero.ssub_ov(ceil, overflowOrDiv0);
447         }
448         // A is positive, b is negative, return - ceil(a, -b).
449         APInt posB = zero.ssub_ov(b, overflowOrDiv0);
450         APInt ceil = signedCeilNonnegInputs(a, posB, overflowOrDiv0);
451         return zero.ssub_ov(ceil, overflowOrDiv0);
452       });
453 
454   // Fold out floor division by one. Assumes all tensors of all ones are
455   // splats.
456   if (auto rhs = operands[1].dyn_cast_or_null<IntegerAttr>()) {
457     if (rhs.getValue() == 1)
458       return getLhs();
459   } else if (auto rhs = operands[1].dyn_cast_or_null<SplatElementsAttr>()) {
460     if (rhs.getSplatValue<IntegerAttr>().getValue() == 1)
461       return getLhs();
462   }
463 
464   return overflowOrDiv0 ? Attribute() : result;
465 }
466 
467 //===----------------------------------------------------------------------===//
468 // RemUIOp
469 //===----------------------------------------------------------------------===//
470 
471 OpFoldResult arith::RemUIOp::fold(ArrayRef<Attribute> operands) {
472   auto rhs = operands.back().dyn_cast_or_null<IntegerAttr>();
473   if (!rhs)
474     return {};
475   auto rhsValue = rhs.getValue();
476 
477   // x % 1 = 0
478   if (rhsValue.isOneValue())
479     return IntegerAttr::get(rhs.getType(), APInt(rhsValue.getBitWidth(), 0));
480 
481   // Don't fold if it requires division by zero.
482   if (rhsValue.isNullValue())
483     return {};
484 
485   auto lhs = operands.front().dyn_cast_or_null<IntegerAttr>();
486   if (!lhs)
487     return {};
488   return IntegerAttr::get(lhs.getType(), lhs.getValue().urem(rhsValue));
489 }
490 
491 //===----------------------------------------------------------------------===//
492 // RemSIOp
493 //===----------------------------------------------------------------------===//
494 
495 OpFoldResult arith::RemSIOp::fold(ArrayRef<Attribute> operands) {
496   auto rhs = operands.back().dyn_cast_or_null<IntegerAttr>();
497   if (!rhs)
498     return {};
499   auto rhsValue = rhs.getValue();
500 
501   // x % 1 = 0
502   if (rhsValue.isOneValue())
503     return IntegerAttr::get(rhs.getType(), APInt(rhsValue.getBitWidth(), 0));
504 
505   // Don't fold if it requires division by zero.
506   if (rhsValue.isNullValue())
507     return {};
508 
509   auto lhs = operands.front().dyn_cast_or_null<IntegerAttr>();
510   if (!lhs)
511     return {};
512   return IntegerAttr::get(lhs.getType(), lhs.getValue().srem(rhsValue));
513 }
514 
515 //===----------------------------------------------------------------------===//
516 // AndIOp
517 //===----------------------------------------------------------------------===//
518 
519 OpFoldResult arith::AndIOp::fold(ArrayRef<Attribute> operands) {
520   /// and(x, 0) -> 0
521   if (matchPattern(getRhs(), m_Zero()))
522     return getRhs();
523   /// and(x, allOnes) -> x
524   APInt intValue;
525   if (matchPattern(getRhs(), m_ConstantInt(&intValue)) && intValue.isAllOnes())
526     return getLhs();
527 
528   return constFoldBinaryOp<IntegerAttr>(
529       operands, [](APInt a, const APInt &b) { return std::move(a) & b; });
530 }
531 
532 //===----------------------------------------------------------------------===//
533 // OrIOp
534 //===----------------------------------------------------------------------===//
535 
536 OpFoldResult arith::OrIOp::fold(ArrayRef<Attribute> operands) {
537   /// or(x, 0) -> x
538   if (matchPattern(getRhs(), m_Zero()))
539     return getLhs();
540   /// or(x, <all ones>) -> <all ones>
541   if (auto rhsAttr = operands[1].dyn_cast_or_null<IntegerAttr>())
542     if (rhsAttr.getValue().isAllOnes())
543       return rhsAttr;
544 
545   return constFoldBinaryOp<IntegerAttr>(
546       operands, [](APInt a, const APInt &b) { return std::move(a) | b; });
547 }
548 
549 //===----------------------------------------------------------------------===//
550 // XOrIOp
551 //===----------------------------------------------------------------------===//
552 
553 OpFoldResult arith::XOrIOp::fold(ArrayRef<Attribute> operands) {
554   /// xor(x, 0) -> x
555   if (matchPattern(getRhs(), m_Zero()))
556     return getLhs();
557   /// xor(x, x) -> 0
558   if (getLhs() == getRhs())
559     return Builder(getContext()).getZeroAttr(getType());
560 
561   return constFoldBinaryOp<IntegerAttr>(
562       operands, [](APInt a, const APInt &b) { return std::move(a) ^ b; });
563 }
564 
565 void arith::XOrIOp::getCanonicalizationPatterns(
566     OwningRewritePatternList &patterns, MLIRContext *context) {
567   patterns.insert<XOrINotCmpI>(context);
568 }
569 
570 //===----------------------------------------------------------------------===//
571 // AddFOp
572 //===----------------------------------------------------------------------===//
573 
574 OpFoldResult arith::AddFOp::fold(ArrayRef<Attribute> operands) {
575   return constFoldBinaryOp<FloatAttr>(
576       operands, [](const APFloat &a, const APFloat &b) { return a + b; });
577 }
578 
579 //===----------------------------------------------------------------------===//
580 // SubFOp
581 //===----------------------------------------------------------------------===//
582 
583 OpFoldResult arith::SubFOp::fold(ArrayRef<Attribute> operands) {
584   return constFoldBinaryOp<FloatAttr>(
585       operands, [](const APFloat &a, const APFloat &b) { return a - b; });
586 }
587 
588 //===----------------------------------------------------------------------===//
589 // MaxSIOp
590 //===----------------------------------------------------------------------===//
591 
592 OpFoldResult MaxSIOp::fold(ArrayRef<Attribute> operands) {
593   assert(operands.size() == 2 && "binary operation takes two operands");
594 
595   // maxsi(x,x) -> x
596   if (getLhs() == getRhs())
597     return getRhs();
598 
599   APInt intValue;
600   // maxsi(x,MAX_INT) -> MAX_INT
601   if (matchPattern(getRhs(), m_ConstantInt(&intValue)) &&
602       intValue.isMaxSignedValue())
603     return getRhs();
604 
605   // maxsi(x, MIN_INT) -> x
606   if (matchPattern(getRhs(), m_ConstantInt(&intValue)) &&
607       intValue.isMinSignedValue())
608     return getLhs();
609 
610   return constFoldBinaryOp<IntegerAttr>(operands,
611                                         [](const APInt &a, const APInt &b) {
612                                           return llvm::APIntOps::smax(a, b);
613                                         });
614 }
615 
616 //===----------------------------------------------------------------------===//
617 // MaxUIOp
618 //===----------------------------------------------------------------------===//
619 
620 OpFoldResult MaxUIOp::fold(ArrayRef<Attribute> operands) {
621   assert(operands.size() == 2 && "binary operation takes two operands");
622 
623   // maxui(x,x) -> x
624   if (getLhs() == getRhs())
625     return getRhs();
626 
627   APInt intValue;
628   // maxui(x,MAX_INT) -> MAX_INT
629   if (matchPattern(getRhs(), m_ConstantInt(&intValue)) && intValue.isMaxValue())
630     return getRhs();
631 
632   // maxui(x, MIN_INT) -> x
633   if (matchPattern(getRhs(), m_ConstantInt(&intValue)) && intValue.isMinValue())
634     return getLhs();
635 
636   return constFoldBinaryOp<IntegerAttr>(operands,
637                                         [](const APInt &a, const APInt &b) {
638                                           return llvm::APIntOps::umax(a, b);
639                                         });
640 }
641 
642 //===----------------------------------------------------------------------===//
643 // MinSIOp
644 //===----------------------------------------------------------------------===//
645 
646 OpFoldResult MinSIOp::fold(ArrayRef<Attribute> operands) {
647   assert(operands.size() == 2 && "binary operation takes two operands");
648 
649   // minsi(x,x) -> x
650   if (getLhs() == getRhs())
651     return getRhs();
652 
653   APInt intValue;
654   // minsi(x,MIN_INT) -> MIN_INT
655   if (matchPattern(getRhs(), m_ConstantInt(&intValue)) &&
656       intValue.isMinSignedValue())
657     return getRhs();
658 
659   // minsi(x, MAX_INT) -> x
660   if (matchPattern(getRhs(), m_ConstantInt(&intValue)) &&
661       intValue.isMaxSignedValue())
662     return getLhs();
663 
664   return constFoldBinaryOp<IntegerAttr>(operands,
665                                         [](const APInt &a, const APInt &b) {
666                                           return llvm::APIntOps::smin(a, b);
667                                         });
668 }
669 
670 //===----------------------------------------------------------------------===//
671 // MinUIOp
672 //===----------------------------------------------------------------------===//
673 
674 OpFoldResult MinUIOp::fold(ArrayRef<Attribute> operands) {
675   assert(operands.size() == 2 && "binary operation takes two operands");
676 
677   // minui(x,x) -> x
678   if (getLhs() == getRhs())
679     return getRhs();
680 
681   APInt intValue;
682   // minui(x,MIN_INT) -> MIN_INT
683   if (matchPattern(getRhs(), m_ConstantInt(&intValue)) && intValue.isMinValue())
684     return getRhs();
685 
686   // minui(x, MAX_INT) -> x
687   if (matchPattern(getRhs(), m_ConstantInt(&intValue)) && intValue.isMaxValue())
688     return getLhs();
689 
690   return constFoldBinaryOp<IntegerAttr>(operands,
691                                         [](const APInt &a, const APInt &b) {
692                                           return llvm::APIntOps::umin(a, b);
693                                         });
694 }
695 
696 //===----------------------------------------------------------------------===//
697 // MulFOp
698 //===----------------------------------------------------------------------===//
699 
700 OpFoldResult arith::MulFOp::fold(ArrayRef<Attribute> operands) {
701   return constFoldBinaryOp<FloatAttr>(
702       operands, [](const APFloat &a, const APFloat &b) { return a * b; });
703 }
704 
705 //===----------------------------------------------------------------------===//
706 // DivFOp
707 //===----------------------------------------------------------------------===//
708 
709 OpFoldResult arith::DivFOp::fold(ArrayRef<Attribute> operands) {
710   return constFoldBinaryOp<FloatAttr>(
711       operands, [](const APFloat &a, const APFloat &b) { return a / b; });
712 }
713 
714 //===----------------------------------------------------------------------===//
715 // Utility functions for verifying cast ops
716 //===----------------------------------------------------------------------===//
717 
// A value-less tag carrying a pack of types: a (null) pointer to a tuple, used
// to pass several type lists to one function template.
template <class... Ts>
using type_list = std::tuple<Ts...> *;
720 
721 /// Returns a non-null type only if the provided type is one of the allowed
722 /// types or one of the allowed shaped types of the allowed types. Returns the
723 /// element type if a valid shaped type is provided.
724 template <typename... ShapedTypes, typename... ElementTypes>
725 static Type getUnderlyingType(Type type, type_list<ShapedTypes...>,
726                               type_list<ElementTypes...>) {
727   if (type.isa<ShapedType>() && !type.isa<ShapedTypes...>())
728     return {};
729 
730   auto underlyingType = getElementTypeOrSelf(type);
731   if (!underlyingType.isa<ElementTypes...>())
732     return {};
733 
734   return underlyingType;
735 }
736 
737 /// Get allowed underlying types for vectors and tensors.
738 template <typename... ElementTypes>
739 static Type getTypeIfLike(Type type) {
740   return getUnderlyingType(type, type_list<VectorType, TensorType>(),
741                            type_list<ElementTypes...>());
742 }
743 
744 /// Get allowed underlying types for vectors, tensors, and memrefs.
745 template <typename... ElementTypes>
746 static Type getTypeIfLikeOrMemRef(Type type) {
747   return getUnderlyingType(type,
748                            type_list<VectorType, TensorType, MemRefType>(),
749                            type_list<ElementTypes...>());
750 }
751 
752 static bool areValidCastInputsAndOutputs(TypeRange inputs, TypeRange outputs) {
753   return inputs.size() == 1 && outputs.size() == 1 &&
754          succeeded(verifyCompatibleShapes(inputs.front(), outputs.front()));
755 }
756 
757 //===----------------------------------------------------------------------===//
758 // Verifiers for integer and floating point extension/truncation ops
759 //===----------------------------------------------------------------------===//
760 
761 // Extend ops can only extend to a wider type.
762 template <typename ValType, typename Op>
763 static LogicalResult verifyExtOp(Op op) {
764   Type srcType = getElementTypeOrSelf(op.getIn().getType());
765   Type dstType = getElementTypeOrSelf(op.getType());
766 
767   if (srcType.cast<ValType>().getWidth() >= dstType.cast<ValType>().getWidth())
768     return op.emitError("result type ")
769            << dstType << " must be wider than operand type " << srcType;
770 
771   return success();
772 }
773 
774 // Truncate ops can only truncate to a shorter type.
775 template <typename ValType, typename Op>
776 static LogicalResult verifyTruncateOp(Op op) {
777   Type srcType = getElementTypeOrSelf(op.getIn().getType());
778   Type dstType = getElementTypeOrSelf(op.getType());
779 
780   if (srcType.cast<ValType>().getWidth() <= dstType.cast<ValType>().getWidth())
781     return op.emitError("result type ")
782            << dstType << " must be shorter than operand type " << srcType;
783 
784   return success();
785 }
786 
787 /// Validate a cast that changes the width of a type.
788 template <template <typename> class WidthComparator, typename... ElementTypes>
789 static bool checkWidthChangeCast(TypeRange inputs, TypeRange outputs) {
790   if (!areValidCastInputsAndOutputs(inputs, outputs))
791     return false;
792 
793   auto srcType = getTypeIfLike<ElementTypes...>(inputs.front());
794   auto dstType = getTypeIfLike<ElementTypes...>(outputs.front());
795   if (!srcType || !dstType)
796     return false;
797 
798   return WidthComparator<unsigned>()(dstType.getIntOrFloatBitWidth(),
799                                      srcType.getIntOrFloatBitWidth());
800 }
801 
802 //===----------------------------------------------------------------------===//
803 // ExtUIOp
804 //===----------------------------------------------------------------------===//
805 
806 OpFoldResult arith::ExtUIOp::fold(ArrayRef<Attribute> operands) {
807   if (auto lhs = operands[0].dyn_cast_or_null<IntegerAttr>())
808     return IntegerAttr::get(
809         getType(), lhs.getValue().zext(getType().getIntOrFloatBitWidth()));
810 
811   if (auto lhs = getIn().getDefiningOp<ExtUIOp>()) {
812     getInMutable().assign(lhs.getIn());
813     return getResult();
814   }
815 
816   return {};
817 }
818 
819 bool arith::ExtUIOp::areCastCompatible(TypeRange inputs, TypeRange outputs) {
820   return checkWidthChangeCast<std::greater, IntegerType>(inputs, outputs);
821 }
822 
823 //===----------------------------------------------------------------------===//
824 // ExtSIOp
825 //===----------------------------------------------------------------------===//
826 
827 OpFoldResult arith::ExtSIOp::fold(ArrayRef<Attribute> operands) {
828   if (auto lhs = operands[0].dyn_cast_or_null<IntegerAttr>())
829     return IntegerAttr::get(
830         getType(), lhs.getValue().sext(getType().getIntOrFloatBitWidth()));
831 
832   if (auto lhs = getIn().getDefiningOp<ExtSIOp>()) {
833     getInMutable().assign(lhs.getIn());
834     return getResult();
835   }
836 
837   return {};
838 }
839 
840 bool arith::ExtSIOp::areCastCompatible(TypeRange inputs, TypeRange outputs) {
841   return checkWidthChangeCast<std::greater, IntegerType>(inputs, outputs);
842 }
843 
844 void arith::ExtSIOp::getCanonicalizationPatterns(
845     OwningRewritePatternList &patterns, MLIRContext *context) {
846   patterns.insert<ExtSIOfExtUI>(context);
847 }
848 
849 //===----------------------------------------------------------------------===//
850 // ExtFOp
851 //===----------------------------------------------------------------------===//
852 
853 bool arith::ExtFOp::areCastCompatible(TypeRange inputs, TypeRange outputs) {
854   return checkWidthChangeCast<std::greater, FloatType>(inputs, outputs);
855 }
856 
857 //===----------------------------------------------------------------------===//
858 // TruncIOp
859 //===----------------------------------------------------------------------===//
860 
861 OpFoldResult arith::TruncIOp::fold(ArrayRef<Attribute> operands) {
862   // trunci(zexti(a)) -> a
863   // trunci(sexti(a)) -> a
864   if (matchPattern(getOperand(), m_Op<arith::ExtUIOp>()) ||
865       matchPattern(getOperand(), m_Op<arith::ExtSIOp>()))
866     return getOperand().getDefiningOp()->getOperand(0);
867 
868   assert(operands.size() == 1 && "unary operation takes one operand");
869 
870   if (!operands[0])
871     return {};
872 
873   if (auto lhs = operands[0].dyn_cast<IntegerAttr>()) {
874     return IntegerAttr::get(
875         getType(), lhs.getValue().trunc(getType().getIntOrFloatBitWidth()));
876   }
877 
878   return {};
879 }
880 
881 bool arith::TruncIOp::areCastCompatible(TypeRange inputs, TypeRange outputs) {
882   return checkWidthChangeCast<std::less, IntegerType>(inputs, outputs);
883 }
884 
885 //===----------------------------------------------------------------------===//
886 // TruncFOp
887 //===----------------------------------------------------------------------===//
888 
889 /// Perform safe const propagation for truncf, i.e. only propagate if FP value
890 /// can be represented without precision loss or rounding.
891 OpFoldResult arith::TruncFOp::fold(ArrayRef<Attribute> operands) {
892   assert(operands.size() == 1 && "unary operation takes one operand");
893 
894   auto constOperand = operands.front();
895   if (!constOperand || !constOperand.isa<FloatAttr>())
896     return {};
897 
898   // Convert to target type via 'double'.
899   double sourceValue =
900       constOperand.dyn_cast<FloatAttr>().getValue().convertToDouble();
901   auto targetAttr = FloatAttr::get(getType(), sourceValue);
902 
903   // Propagate if constant's value does not change after truncation.
904   if (sourceValue == targetAttr.getValue().convertToDouble())
905     return targetAttr;
906 
907   return {};
908 }
909 
910 bool arith::TruncFOp::areCastCompatible(TypeRange inputs, TypeRange outputs) {
911   return checkWidthChangeCast<std::less, FloatType>(inputs, outputs);
912 }
913 
914 //===----------------------------------------------------------------------===//
915 // AndIOp
916 //===----------------------------------------------------------------------===//
917 
918 void arith::AndIOp::getCanonicalizationPatterns(
919     OwningRewritePatternList &patterns, MLIRContext *context) {
920   patterns.insert<AndOfExtUI, AndOfExtSI>(context);
921 }
922 
923 //===----------------------------------------------------------------------===//
924 // OrIOp
925 //===----------------------------------------------------------------------===//
926 
927 void arith::OrIOp::getCanonicalizationPatterns(
928     OwningRewritePatternList &patterns, MLIRContext *context) {
929   patterns.insert<OrOfExtUI, OrOfExtSI>(context);
930 }
931 
932 //===----------------------------------------------------------------------===//
933 // Verifiers for casts between integers and floats.
934 //===----------------------------------------------------------------------===//
935 
936 template <typename From, typename To>
937 static bool checkIntFloatCast(TypeRange inputs, TypeRange outputs) {
938   if (!areValidCastInputsAndOutputs(inputs, outputs))
939     return false;
940 
941   auto srcType = getTypeIfLike<From>(inputs.front());
942   auto dstType = getTypeIfLike<To>(outputs.back());
943 
944   return srcType && dstType;
945 }
946 
947 //===----------------------------------------------------------------------===//
948 // UIToFPOp
949 //===----------------------------------------------------------------------===//
950 
951 bool arith::UIToFPOp::areCastCompatible(TypeRange inputs, TypeRange outputs) {
952   return checkIntFloatCast<IntegerType, FloatType>(inputs, outputs);
953 }
954 
955 OpFoldResult arith::UIToFPOp::fold(ArrayRef<Attribute> operands) {
956   if (auto lhs = operands[0].dyn_cast_or_null<IntegerAttr>()) {
957     const APInt &api = lhs.getValue();
958     FloatType floatTy = getType().cast<FloatType>();
959     APFloat apf(floatTy.getFloatSemantics(),
960                 APInt::getZero(floatTy.getWidth()));
961     apf.convertFromAPInt(api, /*IsSigned=*/false, APFloat::rmNearestTiesToEven);
962     return FloatAttr::get(floatTy, apf);
963   }
964   return {};
965 }
966 
967 //===----------------------------------------------------------------------===//
968 // SIToFPOp
969 //===----------------------------------------------------------------------===//
970 
971 bool arith::SIToFPOp::areCastCompatible(TypeRange inputs, TypeRange outputs) {
972   return checkIntFloatCast<IntegerType, FloatType>(inputs, outputs);
973 }
974 
975 OpFoldResult arith::SIToFPOp::fold(ArrayRef<Attribute> operands) {
976   if (auto lhs = operands[0].dyn_cast_or_null<IntegerAttr>()) {
977     const APInt &api = lhs.getValue();
978     FloatType floatTy = getType().cast<FloatType>();
979     APFloat apf(floatTy.getFloatSemantics(),
980                 APInt::getZero(floatTy.getWidth()));
981     apf.convertFromAPInt(api, /*IsSigned=*/true, APFloat::rmNearestTiesToEven);
982     return FloatAttr::get(floatTy, apf);
983   }
984   return {};
985 }
986 //===----------------------------------------------------------------------===//
987 // FPToUIOp
988 //===----------------------------------------------------------------------===//
989 
990 bool arith::FPToUIOp::areCastCompatible(TypeRange inputs, TypeRange outputs) {
991   return checkIntFloatCast<FloatType, IntegerType>(inputs, outputs);
992 }
993 
994 OpFoldResult arith::FPToUIOp::fold(ArrayRef<Attribute> operands) {
995   if (auto lhs = operands[0].dyn_cast_or_null<FloatAttr>()) {
996     const APFloat &apf = lhs.getValue();
997     IntegerType intTy = getType().cast<IntegerType>();
998     bool ignored;
999     APSInt api(intTy.getWidth(), /*isUnsigned=*/true);
1000     if (APFloat::opInvalidOp ==
1001         apf.convertToInteger(api, APFloat::rmTowardZero, &ignored)) {
1002       // Undefined behavior invoked - the destination type can't represent
1003       // the input constant.
1004       return {};
1005     }
1006     return IntegerAttr::get(getType(), api);
1007   }
1008 
1009   return {};
1010 }
1011 
1012 //===----------------------------------------------------------------------===//
1013 // FPToSIOp
1014 //===----------------------------------------------------------------------===//
1015 
1016 bool arith::FPToSIOp::areCastCompatible(TypeRange inputs, TypeRange outputs) {
1017   return checkIntFloatCast<FloatType, IntegerType>(inputs, outputs);
1018 }
1019 
1020 OpFoldResult arith::FPToSIOp::fold(ArrayRef<Attribute> operands) {
1021   if (auto lhs = operands[0].dyn_cast_or_null<FloatAttr>()) {
1022     const APFloat &apf = lhs.getValue();
1023     IntegerType intTy = getType().cast<IntegerType>();
1024     bool ignored;
1025     APSInt api(intTy.getWidth(), /*isUnsigned=*/false);
1026     if (APFloat::opInvalidOp ==
1027         apf.convertToInteger(api, APFloat::rmTowardZero, &ignored)) {
1028       // Undefined behavior invoked - the destination type can't represent
1029       // the input constant.
1030       return {};
1031     }
1032     return IntegerAttr::get(getType(), api);
1033   }
1034 
1035   return {};
1036 }
1037 
1038 //===----------------------------------------------------------------------===//
1039 // IndexCastOp
1040 //===----------------------------------------------------------------------===//
1041 
1042 bool arith::IndexCastOp::areCastCompatible(TypeRange inputs,
1043                                            TypeRange outputs) {
1044   if (!areValidCastInputsAndOutputs(inputs, outputs))
1045     return false;
1046 
1047   auto srcType = getTypeIfLikeOrMemRef<IntegerType, IndexType>(inputs.front());
1048   auto dstType = getTypeIfLikeOrMemRef<IntegerType, IndexType>(outputs.front());
1049   if (!srcType || !dstType)
1050     return false;
1051 
1052   return (srcType.isIndex() && dstType.isSignlessInteger()) ||
1053          (srcType.isSignlessInteger() && dstType.isIndex());
1054 }
1055 
1056 OpFoldResult arith::IndexCastOp::fold(ArrayRef<Attribute> operands) {
1057   // index_cast(constant) -> constant
1058   // A little hack because we go through int. Otherwise, the size of the
1059   // constant might need to change.
1060   if (auto value = operands[0].dyn_cast_or_null<IntegerAttr>())
1061     return IntegerAttr::get(getType(), value.getInt());
1062 
1063   return {};
1064 }
1065 
1066 void arith::IndexCastOp::getCanonicalizationPatterns(
1067     OwningRewritePatternList &patterns, MLIRContext *context) {
1068   patterns.insert<IndexCastOfIndexCast, IndexCastOfExtSI>(context);
1069 }
1070 
1071 //===----------------------------------------------------------------------===//
1072 // BitcastOp
1073 //===----------------------------------------------------------------------===//
1074 
1075 bool arith::BitcastOp::areCastCompatible(TypeRange inputs, TypeRange outputs) {
1076   if (!areValidCastInputsAndOutputs(inputs, outputs))
1077     return false;
1078 
1079   auto srcType =
1080       getTypeIfLikeOrMemRef<IntegerType, IndexType, FloatType>(inputs.front());
1081   auto dstType =
1082       getTypeIfLikeOrMemRef<IntegerType, IndexType, FloatType>(outputs.front());
1083   if (!srcType || !dstType)
1084     return false;
1085 
1086   return srcType.getIntOrFloatBitWidth() == dstType.getIntOrFloatBitWidth();
1087 }
1088 
1089 OpFoldResult arith::BitcastOp::fold(ArrayRef<Attribute> operands) {
1090   assert(operands.size() == 1 && "bitcast op expects 1 operand");
1091 
1092   auto resType = getType();
1093   auto operand = operands[0];
1094   if (!operand)
1095     return {};
1096 
1097   /// Bitcast dense elements.
1098   if (auto denseAttr = operand.dyn_cast_or_null<DenseElementsAttr>())
1099     return denseAttr.bitcast(resType.cast<ShapedType>().getElementType());
1100   /// Other shaped types unhandled.
1101   if (resType.isa<ShapedType>())
1102     return {};
1103 
1104   /// Bitcast integer or float to integer or float.
1105   APInt bits = operand.isa<FloatAttr>()
1106                    ? operand.cast<FloatAttr>().getValue().bitcastToAPInt()
1107                    : operand.cast<IntegerAttr>().getValue();
1108 
1109   if (auto resFloatType = resType.dyn_cast<FloatType>())
1110     return FloatAttr::get(resType,
1111                           APFloat(resFloatType.getFloatSemantics(), bits));
1112   return IntegerAttr::get(resType, bits);
1113 }
1114 
1115 void arith::BitcastOp::getCanonicalizationPatterns(
1116     OwningRewritePatternList &patterns, MLIRContext *context) {
1117   patterns.insert<BitcastOfBitcast>(context);
1118 }
1119 
1120 //===----------------------------------------------------------------------===//
1121 // Helpers for compare ops
1122 //===----------------------------------------------------------------------===//
1123 
1124 /// Return the type of the same shape (scalar, vector or tensor) containing i1.
1125 static Type getI1SameShape(Type type) {
1126   auto i1Type = IntegerType::get(type.getContext(), 1);
1127   if (auto tensorType = type.dyn_cast<RankedTensorType>())
1128     return RankedTensorType::get(tensorType.getShape(), i1Type);
1129   if (type.isa<UnrankedTensorType>())
1130     return UnrankedTensorType::get(i1Type);
1131   if (auto vectorType = type.dyn_cast<VectorType>())
1132     return VectorType::get(vectorType.getShape(), i1Type,
1133                            vectorType.getNumScalableDims());
1134   return i1Type;
1135 }
1136 
1137 //===----------------------------------------------------------------------===//
1138 // CmpIOp
1139 //===----------------------------------------------------------------------===//
1140 
1141 /// Compute `lhs` `pred` `rhs`, where `pred` is one of the known integer
1142 /// comparison predicates.
1143 bool mlir::arith::applyCmpPredicate(arith::CmpIPredicate predicate,
1144                                     const APInt &lhs, const APInt &rhs) {
1145   switch (predicate) {
1146   case arith::CmpIPredicate::eq:
1147     return lhs.eq(rhs);
1148   case arith::CmpIPredicate::ne:
1149     return lhs.ne(rhs);
1150   case arith::CmpIPredicate::slt:
1151     return lhs.slt(rhs);
1152   case arith::CmpIPredicate::sle:
1153     return lhs.sle(rhs);
1154   case arith::CmpIPredicate::sgt:
1155     return lhs.sgt(rhs);
1156   case arith::CmpIPredicate::sge:
1157     return lhs.sge(rhs);
1158   case arith::CmpIPredicate::ult:
1159     return lhs.ult(rhs);
1160   case arith::CmpIPredicate::ule:
1161     return lhs.ule(rhs);
1162   case arith::CmpIPredicate::ugt:
1163     return lhs.ugt(rhs);
1164   case arith::CmpIPredicate::uge:
1165     return lhs.uge(rhs);
1166   }
1167   llvm_unreachable("unknown cmpi predicate kind");
1168 }
1169 
1170 /// Returns true if the predicate is true for two equal operands.
1171 static bool applyCmpPredicateToEqualOperands(arith::CmpIPredicate predicate) {
1172   switch (predicate) {
1173   case arith::CmpIPredicate::eq:
1174   case arith::CmpIPredicate::sle:
1175   case arith::CmpIPredicate::sge:
1176   case arith::CmpIPredicate::ule:
1177   case arith::CmpIPredicate::uge:
1178     return true;
1179   case arith::CmpIPredicate::ne:
1180   case arith::CmpIPredicate::slt:
1181   case arith::CmpIPredicate::sgt:
1182   case arith::CmpIPredicate::ult:
1183   case arith::CmpIPredicate::ugt:
1184     return false;
1185   }
1186   llvm_unreachable("unknown cmpi predicate kind");
1187 }
1188 
1189 static Attribute getBoolAttribute(Type type, MLIRContext *ctx, bool value) {
1190   auto boolAttr = BoolAttr::get(ctx, value);
1191   ShapedType shapedType = type.dyn_cast_or_null<ShapedType>();
1192   if (!shapedType)
1193     return boolAttr;
1194   return DenseElementsAttr::get(shapedType, boolAttr);
1195 }
1196 
1197 OpFoldResult arith::CmpIOp::fold(ArrayRef<Attribute> operands) {
1198   assert(operands.size() == 2 && "cmpi takes two operands");
1199 
1200   // cmpi(pred, x, x)
1201   if (getLhs() == getRhs()) {
1202     auto val = applyCmpPredicateToEqualOperands(getPredicate());
1203     return getBoolAttribute(getType(), getContext(), val);
1204   }
1205 
1206   if (matchPattern(getRhs(), m_Zero())) {
1207     if (auto extOp = getLhs().getDefiningOp<ExtSIOp>()) {
1208       if (extOp.getOperand().getType().cast<IntegerType>().getWidth() == 1) {
1209         // extsi(%x : i1 -> iN) != 0  ->  %x
1210         if (getPredicate() == arith::CmpIPredicate::ne) {
1211           return extOp.getOperand();
1212         }
1213       }
1214     }
1215     if (auto extOp = getLhs().getDefiningOp<ExtUIOp>()) {
1216       if (extOp.getOperand().getType().cast<IntegerType>().getWidth() == 1) {
1217         // extui(%x : i1 -> iN) != 0  ->  %x
1218         if (getPredicate() == arith::CmpIPredicate::ne) {
1219           return extOp.getOperand();
1220         }
1221       }
1222     }
1223   }
1224 
1225   auto lhs = operands.front().dyn_cast_or_null<IntegerAttr>();
1226   auto rhs = operands.back().dyn_cast_or_null<IntegerAttr>();
1227   if (!lhs || !rhs)
1228     return {};
1229 
1230   auto val = applyCmpPredicate(getPredicate(), lhs.getValue(), rhs.getValue());
1231   return BoolAttr::get(getContext(), val);
1232 }
1233 
1234 //===----------------------------------------------------------------------===//
1235 // CmpFOp
1236 //===----------------------------------------------------------------------===//
1237 
1238 /// Compute `lhs` `pred` `rhs`, where `pred` is one of the known floating point
1239 /// comparison predicates.
1240 bool mlir::arith::applyCmpPredicate(arith::CmpFPredicate predicate,
1241                                     const APFloat &lhs, const APFloat &rhs) {
1242   auto cmpResult = lhs.compare(rhs);
1243   switch (predicate) {
1244   case arith::CmpFPredicate::AlwaysFalse:
1245     return false;
1246   case arith::CmpFPredicate::OEQ:
1247     return cmpResult == APFloat::cmpEqual;
1248   case arith::CmpFPredicate::OGT:
1249     return cmpResult == APFloat::cmpGreaterThan;
1250   case arith::CmpFPredicate::OGE:
1251     return cmpResult == APFloat::cmpGreaterThan ||
1252            cmpResult == APFloat::cmpEqual;
1253   case arith::CmpFPredicate::OLT:
1254     return cmpResult == APFloat::cmpLessThan;
1255   case arith::CmpFPredicate::OLE:
1256     return cmpResult == APFloat::cmpLessThan || cmpResult == APFloat::cmpEqual;
1257   case arith::CmpFPredicate::ONE:
1258     return cmpResult != APFloat::cmpUnordered && cmpResult != APFloat::cmpEqual;
1259   case arith::CmpFPredicate::ORD:
1260     return cmpResult != APFloat::cmpUnordered;
1261   case arith::CmpFPredicate::UEQ:
1262     return cmpResult == APFloat::cmpUnordered || cmpResult == APFloat::cmpEqual;
1263   case arith::CmpFPredicate::UGT:
1264     return cmpResult == APFloat::cmpUnordered ||
1265            cmpResult == APFloat::cmpGreaterThan;
1266   case arith::CmpFPredicate::UGE:
1267     return cmpResult == APFloat::cmpUnordered ||
1268            cmpResult == APFloat::cmpGreaterThan ||
1269            cmpResult == APFloat::cmpEqual;
1270   case arith::CmpFPredicate::ULT:
1271     return cmpResult == APFloat::cmpUnordered ||
1272            cmpResult == APFloat::cmpLessThan;
1273   case arith::CmpFPredicate::ULE:
1274     return cmpResult == APFloat::cmpUnordered ||
1275            cmpResult == APFloat::cmpLessThan || cmpResult == APFloat::cmpEqual;
1276   case arith::CmpFPredicate::UNE:
1277     return cmpResult != APFloat::cmpEqual;
1278   case arith::CmpFPredicate::UNO:
1279     return cmpResult == APFloat::cmpUnordered;
1280   case arith::CmpFPredicate::AlwaysTrue:
1281     return true;
1282   }
1283   llvm_unreachable("unknown cmpf predicate kind");
1284 }
1285 
1286 OpFoldResult arith::CmpFOp::fold(ArrayRef<Attribute> operands) {
1287   assert(operands.size() == 2 && "cmpf takes two operands");
1288 
1289   auto lhs = operands.front().dyn_cast_or_null<FloatAttr>();
1290   auto rhs = operands.back().dyn_cast_or_null<FloatAttr>();
1291 
1292   if (!lhs || !rhs)
1293     return {};
1294 
1295   auto val = applyCmpPredicate(getPredicate(), lhs.getValue(), rhs.getValue());
1296   return BoolAttr::get(getContext(), val);
1297 }
1298 
1299 //===----------------------------------------------------------------------===//
1300 // Atomic Enum
1301 //===----------------------------------------------------------------------===//
1302 
1303 /// Returns the identity value attribute associated with an AtomicRMWKind op.
1304 Attribute mlir::arith::getIdentityValueAttr(AtomicRMWKind kind, Type resultType,
1305                                             OpBuilder &builder, Location loc) {
1306   switch (kind) {
1307   case AtomicRMWKind::maxf:
1308     return builder.getFloatAttr(
1309         resultType,
1310         APFloat::getInf(resultType.cast<FloatType>().getFloatSemantics(),
1311                         /*Negative=*/true));
1312   case AtomicRMWKind::addf:
1313   case AtomicRMWKind::addi:
1314   case AtomicRMWKind::maxu:
1315   case AtomicRMWKind::ori:
1316     return builder.getZeroAttr(resultType);
1317   case AtomicRMWKind::andi:
1318     return builder.getIntegerAttr(
1319         resultType,
1320         APInt::getAllOnes(resultType.cast<IntegerType>().getWidth()));
1321   case AtomicRMWKind::maxs:
1322     return builder.getIntegerAttr(
1323         resultType,
1324         APInt::getSignedMinValue(resultType.cast<IntegerType>().getWidth()));
1325   case AtomicRMWKind::minf:
1326     return builder.getFloatAttr(
1327         resultType,
1328         APFloat::getInf(resultType.cast<FloatType>().getFloatSemantics(),
1329                         /*Negative=*/false));
1330   case AtomicRMWKind::mins:
1331     return builder.getIntegerAttr(
1332         resultType,
1333         APInt::getSignedMaxValue(resultType.cast<IntegerType>().getWidth()));
1334   case AtomicRMWKind::minu:
1335     return builder.getIntegerAttr(
1336         resultType,
1337         APInt::getMaxValue(resultType.cast<IntegerType>().getWidth()));
1338   case AtomicRMWKind::muli:
1339     return builder.getIntegerAttr(resultType, 1);
1340   case AtomicRMWKind::mulf:
1341     return builder.getFloatAttr(resultType, 1);
1342   // TODO: Add remaining reduction operations.
1343   default:
1344     (void)emitOptionalError(loc, "Reduction operation type not supported");
1345     break;
1346   }
1347   return nullptr;
1348 }
1349 
1350 /// Returns the identity value associated with an AtomicRMWKind op.
1351 Value mlir::arith::getIdentityValue(AtomicRMWKind op, Type resultType,
1352                                     OpBuilder &builder, Location loc) {
1353   Attribute attr = getIdentityValueAttr(op, resultType, builder, loc);
1354   return builder.create<arith::ConstantOp>(loc, attr);
1355 }
1356 
1357 /// Return the value obtained by applying the reduction operation kind
1358 /// associated with a binary AtomicRMWKind op to `lhs` and `rhs`.
1359 Value mlir::arith::getReductionOp(AtomicRMWKind op, OpBuilder &builder,
1360                                   Location loc, Value lhs, Value rhs) {
1361   switch (op) {
1362   case AtomicRMWKind::addf:
1363     return builder.create<arith::AddFOp>(loc, lhs, rhs);
1364   case AtomicRMWKind::addi:
1365     return builder.create<arith::AddIOp>(loc, lhs, rhs);
1366   case AtomicRMWKind::mulf:
1367     return builder.create<arith::MulFOp>(loc, lhs, rhs);
1368   case AtomicRMWKind::muli:
1369     return builder.create<arith::MulIOp>(loc, lhs, rhs);
1370   case AtomicRMWKind::maxf:
1371     return builder.create<arith::MaxFOp>(loc, lhs, rhs);
1372   case AtomicRMWKind::minf:
1373     return builder.create<arith::MinFOp>(loc, lhs, rhs);
1374   case AtomicRMWKind::maxs:
1375     return builder.create<arith::MaxSIOp>(loc, lhs, rhs);
1376   case AtomicRMWKind::mins:
1377     return builder.create<arith::MinSIOp>(loc, lhs, rhs);
1378   case AtomicRMWKind::maxu:
1379     return builder.create<arith::MaxUIOp>(loc, lhs, rhs);
1380   case AtomicRMWKind::minu:
1381     return builder.create<arith::MinUIOp>(loc, lhs, rhs);
1382   case AtomicRMWKind::ori:
1383     return builder.create<arith::OrIOp>(loc, lhs, rhs);
1384   case AtomicRMWKind::andi:
1385     return builder.create<arith::AndIOp>(loc, lhs, rhs);
1386   // TODO: Add remaining reduction operations.
1387   default:
1388     (void)emitOptionalError(loc, "Reduction operation type not supported");
1389     break;
1390   }
1391   return nullptr;
1392 }
1393 
1394 //===----------------------------------------------------------------------===//
1395 // TableGen'd op method definitions
1396 //===----------------------------------------------------------------------===//
1397 
1398 #define GET_OP_CLASSES
1399 #include "mlir/Dialect/Arithmetic/IR/ArithmeticOps.cpp.inc"
1400 
1401 //===----------------------------------------------------------------------===//
1402 // TableGen'd enum attribute definitions
1403 //===----------------------------------------------------------------------===//
1404 
1405 #include "mlir/Dialect/Arithmetic/IR/ArithmeticOpsEnums.cpp.inc"
1406