1 //===- ArithmeticOps.cpp - MLIR Arithmetic dialect ops implementation -----===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 
9 #include <utility>
10 
11 #include "mlir/Dialect/Arithmetic/IR/Arithmetic.h"
12 #include "mlir/Dialect/CommonFolders.h"
13 #include "mlir/IR/Builders.h"
14 #include "mlir/IR/Matchers.h"
15 #include "mlir/IR/OpImplementation.h"
16 #include "mlir/IR/PatternMatch.h"
17 #include "mlir/IR/TypeUtilities.h"
18 
19 #include "llvm/ADT/APSInt.h"
20 
21 using namespace mlir;
22 using namespace mlir::arith;
23 
24 //===----------------------------------------------------------------------===//
25 // Pattern helpers
26 //===----------------------------------------------------------------------===//
27 
28 static IntegerAttr addIntegerAttrs(PatternRewriter &builder, Value res,
29                                    Attribute lhs, Attribute rhs) {
30   return builder.getIntegerAttr(res.getType(),
31                                 lhs.cast<IntegerAttr>().getInt() +
32                                     rhs.cast<IntegerAttr>().getInt());
33 }
34 
35 static IntegerAttr subIntegerAttrs(PatternRewriter &builder, Value res,
36                                    Attribute lhs, Attribute rhs) {
37   return builder.getIntegerAttr(res.getType(),
38                                 lhs.cast<IntegerAttr>().getInt() -
39                                     rhs.cast<IntegerAttr>().getInt());
40 }
41 
42 /// Invert an integer comparison predicate.
43 arith::CmpIPredicate arith::invertPredicate(arith::CmpIPredicate pred) {
44   switch (pred) {
45   case arith::CmpIPredicate::eq:
46     return arith::CmpIPredicate::ne;
47   case arith::CmpIPredicate::ne:
48     return arith::CmpIPredicate::eq;
49   case arith::CmpIPredicate::slt:
50     return arith::CmpIPredicate::sge;
51   case arith::CmpIPredicate::sle:
52     return arith::CmpIPredicate::sgt;
53   case arith::CmpIPredicate::sgt:
54     return arith::CmpIPredicate::sle;
55   case arith::CmpIPredicate::sge:
56     return arith::CmpIPredicate::slt;
57   case arith::CmpIPredicate::ult:
58     return arith::CmpIPredicate::uge;
59   case arith::CmpIPredicate::ule:
60     return arith::CmpIPredicate::ugt;
61   case arith::CmpIPredicate::ugt:
62     return arith::CmpIPredicate::ule;
63   case arith::CmpIPredicate::uge:
64     return arith::CmpIPredicate::ult;
65   }
66   llvm_unreachable("unknown cmpi predicate kind");
67 }
68 
69 static arith::CmpIPredicateAttr invertPredicate(arith::CmpIPredicateAttr pred) {
70   return arith::CmpIPredicateAttr::get(pred.getContext(),
71                                        invertPredicate(pred.getValue()));
72 }
73 
74 //===----------------------------------------------------------------------===//
75 // TableGen'd canonicalization patterns
76 //===----------------------------------------------------------------------===//
77 
78 namespace {
79 #include "ArithmeticCanonicalization.inc"
80 } // namespace
81 
82 //===----------------------------------------------------------------------===//
83 // ConstantOp
84 //===----------------------------------------------------------------------===//
85 
86 void arith::ConstantOp::getAsmResultNames(
87     function_ref<void(Value, StringRef)> setNameFn) {
88   auto type = getType();
89   if (auto intCst = getValue().dyn_cast<IntegerAttr>()) {
90     auto intType = type.dyn_cast<IntegerType>();
91 
92     // Sugar i1 constants with 'true' and 'false'.
93     if (intType && intType.getWidth() == 1)
94       return setNameFn(getResult(), (intCst.getInt() ? "true" : "false"));
95 
96     // Otherwise, build a compex name with the value and type.
97     SmallString<32> specialNameBuffer;
98     llvm::raw_svector_ostream specialName(specialNameBuffer);
99     specialName << 'c' << intCst.getInt();
100     if (intType)
101       specialName << '_' << type;
102     setNameFn(getResult(), specialName.str());
103   } else {
104     setNameFn(getResult(), "cst");
105   }
106 }
107 
108 /// TODO: disallow arith.constant to return anything other than signless integer
109 /// or float like.
110 static LogicalResult verify(arith::ConstantOp op) {
111   auto type = op.getType();
112   // The value's type must match the return type.
113   if (op.getValue().getType() != type) {
114     return op.emitOpError() << "value type " << op.getValue().getType()
115                             << " must match return type: " << type;
116   }
117   // Integer values must be signless.
118   if (type.isa<IntegerType>() && !type.cast<IntegerType>().isSignless())
119     return op.emitOpError("integer return type must be signless");
120   // Any float or elements attribute are acceptable.
121   if (!op.getValue().isa<IntegerAttr, FloatAttr, ElementsAttr>()) {
122     return op.emitOpError(
123         "value must be an integer, float, or elements attribute");
124   }
125   return success();
126 }
127 
128 bool arith::ConstantOp::isBuildableWith(Attribute value, Type type) {
129   // The value's type must be the same as the provided type.
130   if (value.getType() != type)
131     return false;
132   // Integer values must be signless.
133   if (type.isa<IntegerType>() && !type.cast<IntegerType>().isSignless())
134     return false;
135   // Integer, float, and element attributes are buildable.
136   return value.isa<IntegerAttr, FloatAttr, ElementsAttr>();
137 }
138 
139 OpFoldResult arith::ConstantOp::fold(ArrayRef<Attribute> operands) {
140   return getValue();
141 }
142 
143 void arith::ConstantIntOp::build(OpBuilder &builder, OperationState &result,
144                                  int64_t value, unsigned width) {
145   auto type = builder.getIntegerType(width);
146   arith::ConstantOp::build(builder, result, type,
147                            builder.getIntegerAttr(type, value));
148 }
149 
150 void arith::ConstantIntOp::build(OpBuilder &builder, OperationState &result,
151                                  int64_t value, Type type) {
152   assert(type.isSignlessInteger() &&
153          "ConstantIntOp can only have signless integer type values");
154   arith::ConstantOp::build(builder, result, type,
155                            builder.getIntegerAttr(type, value));
156 }
157 
158 bool arith::ConstantIntOp::classof(Operation *op) {
159   if (auto constOp = dyn_cast_or_null<arith::ConstantOp>(op))
160     return constOp.getType().isSignlessInteger();
161   return false;
162 }
163 
164 void arith::ConstantFloatOp::build(OpBuilder &builder, OperationState &result,
165                                    const APFloat &value, FloatType type) {
166   arith::ConstantOp::build(builder, result, type,
167                            builder.getFloatAttr(type, value));
168 }
169 
170 bool arith::ConstantFloatOp::classof(Operation *op) {
171   if (auto constOp = dyn_cast_or_null<arith::ConstantOp>(op))
172     return constOp.getType().isa<FloatType>();
173   return false;
174 }
175 
176 void arith::ConstantIndexOp::build(OpBuilder &builder, OperationState &result,
177                                    int64_t value) {
178   arith::ConstantOp::build(builder, result, builder.getIndexType(),
179                            builder.getIndexAttr(value));
180 }
181 
182 bool arith::ConstantIndexOp::classof(Operation *op) {
183   if (auto constOp = dyn_cast_or_null<arith::ConstantOp>(op))
184     return constOp.getType().isIndex();
185   return false;
186 }
187 
188 //===----------------------------------------------------------------------===//
189 // AddIOp
190 //===----------------------------------------------------------------------===//
191 
192 OpFoldResult arith::AddIOp::fold(ArrayRef<Attribute> operands) {
193   // addi(x, 0) -> x
194   if (matchPattern(getRhs(), m_Zero()))
195     return getLhs();
196 
197   // add(sub(a, b), b) -> a
198   if (auto sub = getLhs().getDefiningOp<SubIOp>())
199     if (getRhs() == sub.getRhs())
200       return sub.getLhs();
201 
202   // add(b, sub(a, b)) -> a
203   if (auto sub = getRhs().getDefiningOp<SubIOp>())
204     if (getLhs() == sub.getRhs())
205       return sub.getLhs();
206 
207   return constFoldBinaryOp<IntegerAttr>(
208       operands, [](APInt a, const APInt &b) { return std::move(a) + b; });
209 }
210 
211 void arith::AddIOp::getCanonicalizationPatterns(
212     OwningRewritePatternList &patterns, MLIRContext *context) {
213   patterns.insert<AddIAddConstant, AddISubConstantRHS, AddISubConstantLHS>(
214       context);
215 }
216 
217 //===----------------------------------------------------------------------===//
218 // SubIOp
219 //===----------------------------------------------------------------------===//
220 
221 OpFoldResult arith::SubIOp::fold(ArrayRef<Attribute> operands) {
222   // subi(x,x) -> 0
223   if (getOperand(0) == getOperand(1))
224     return Builder(getContext()).getZeroAttr(getType());
225   // subi(x,0) -> x
226   if (matchPattern(getRhs(), m_Zero()))
227     return getLhs();
228 
229   return constFoldBinaryOp<IntegerAttr>(
230       operands, [](APInt a, const APInt &b) { return std::move(a) - b; });
231 }
232 
233 void arith::SubIOp::getCanonicalizationPatterns(
234     OwningRewritePatternList &patterns, MLIRContext *context) {
235   patterns.insert<SubIRHSAddConstant, SubILHSAddConstant, SubIRHSSubConstantRHS,
236                   SubIRHSSubConstantLHS, SubILHSSubConstantRHS,
237                   SubILHSSubConstantLHS>(context);
238 }
239 
240 //===----------------------------------------------------------------------===//
241 // MulIOp
242 //===----------------------------------------------------------------------===//
243 
244 OpFoldResult arith::MulIOp::fold(ArrayRef<Attribute> operands) {
245   // muli(x, 0) -> 0
246   if (matchPattern(getRhs(), m_Zero()))
247     return getRhs();
248   // muli(x, 1) -> x
249   if (matchPattern(getRhs(), m_One()))
250     return getOperand(0);
251   // TODO: Handle the overflow case.
252 
253   // default folder
254   return constFoldBinaryOp<IntegerAttr>(
255       operands, [](const APInt &a, const APInt &b) { return a * b; });
256 }
257 
258 //===----------------------------------------------------------------------===//
259 // DivUIOp
260 //===----------------------------------------------------------------------===//
261 
262 OpFoldResult arith::DivUIOp::fold(ArrayRef<Attribute> operands) {
263   // Don't fold if it would require a division by zero.
264   bool div0 = false;
265   auto result =
266       constFoldBinaryOp<IntegerAttr>(operands, [&](APInt a, const APInt &b) {
267         if (div0 || !b) {
268           div0 = true;
269           return a;
270         }
271         return a.udiv(b);
272       });
273 
274   // Fold out division by one. Assumes all tensors of all ones are splats.
275   if (auto rhs = operands[1].dyn_cast_or_null<IntegerAttr>()) {
276     if (rhs.getValue() == 1)
277       return getLhs();
278   } else if (auto rhs = operands[1].dyn_cast_or_null<SplatElementsAttr>()) {
279     if (rhs.getSplatValue<IntegerAttr>().getValue() == 1)
280       return getLhs();
281   }
282 
283   return div0 ? Attribute() : result;
284 }
285 
286 //===----------------------------------------------------------------------===//
287 // DivSIOp
288 //===----------------------------------------------------------------------===//
289 
290 OpFoldResult arith::DivSIOp::fold(ArrayRef<Attribute> operands) {
291   // Don't fold if it would overflow or if it requires a division by zero.
292   bool overflowOrDiv0 = false;
293   auto result =
294       constFoldBinaryOp<IntegerAttr>(operands, [&](APInt a, const APInt &b) {
295         if (overflowOrDiv0 || !b) {
296           overflowOrDiv0 = true;
297           return a;
298         }
299         return a.sdiv_ov(b, overflowOrDiv0);
300       });
301 
302   // Fold out division by one. Assumes all tensors of all ones are splats.
303   if (auto rhs = operands[1].dyn_cast_or_null<IntegerAttr>()) {
304     if (rhs.getValue() == 1)
305       return getLhs();
306   } else if (auto rhs = operands[1].dyn_cast_or_null<SplatElementsAttr>()) {
307     if (rhs.getSplatValue<IntegerAttr>().getValue() == 1)
308       return getLhs();
309   }
310 
311   return overflowOrDiv0 ? Attribute() : result;
312 }
313 
314 //===----------------------------------------------------------------------===//
315 // Ceil and floor division folding helpers
316 //===----------------------------------------------------------------------===//
317 
318 static APInt signedCeilNonnegInputs(const APInt &a, const APInt &b,
319                                     bool &overflow) {
320   // Returns (a-1)/b + 1
321   APInt one(a.getBitWidth(), 1, true); // Signed value 1.
322   APInt val = a.ssub_ov(one, overflow).sdiv_ov(b, overflow);
323   return val.sadd_ov(one, overflow);
324 }
325 
326 //===----------------------------------------------------------------------===//
327 // CeilDivUIOp
328 //===----------------------------------------------------------------------===//
329 
330 OpFoldResult arith::CeilDivUIOp::fold(ArrayRef<Attribute> operands) {
331   bool overflowOrDiv0 = false;
332   auto result =
333       constFoldBinaryOp<IntegerAttr>(operands, [&](APInt a, const APInt &b) {
334         if (overflowOrDiv0 || !b) {
335           overflowOrDiv0 = true;
336           return a;
337         }
338         APInt quotient = a.udiv(b);
339         if (!a.urem(b))
340           return quotient;
341         APInt one(a.getBitWidth(), 1, true);
342         return quotient.uadd_ov(one, overflowOrDiv0);
343       });
344   // Fold out ceil division by one. Assumes all tensors of all ones are
345   // splats.
346   if (auto rhs = operands[1].dyn_cast_or_null<IntegerAttr>()) {
347     if (rhs.getValue() == 1)
348       return getLhs();
349   } else if (auto rhs = operands[1].dyn_cast_or_null<SplatElementsAttr>()) {
350     if (rhs.getSplatValue<IntegerAttr>().getValue() == 1)
351       return getLhs();
352   }
353 
354   return overflowOrDiv0 ? Attribute() : result;
355 }
356 
357 //===----------------------------------------------------------------------===//
358 // CeilDivSIOp
359 //===----------------------------------------------------------------------===//
360 
361 OpFoldResult arith::CeilDivSIOp::fold(ArrayRef<Attribute> operands) {
362   // Don't fold if it would overflow or if it requires a division by zero.
363   bool overflowOrDiv0 = false;
364   auto result =
365       constFoldBinaryOp<IntegerAttr>(operands, [&](APInt a, const APInt &b) {
366         if (overflowOrDiv0 || !b) {
367           overflowOrDiv0 = true;
368           return a;
369         }
370         if (!a)
371           return a;
372         // After this point we know that neither a or b are zero.
373         unsigned bits = a.getBitWidth();
374         APInt zero = APInt::getZero(bits);
375         bool aGtZero = a.sgt(zero);
376         bool bGtZero = b.sgt(zero);
377         if (aGtZero && bGtZero) {
378           // Both positive, return ceil(a, b).
379           return signedCeilNonnegInputs(a, b, overflowOrDiv0);
380         }
381         if (!aGtZero && !bGtZero) {
382           // Both negative, return ceil(-a, -b).
383           APInt posA = zero.ssub_ov(a, overflowOrDiv0);
384           APInt posB = zero.ssub_ov(b, overflowOrDiv0);
385           return signedCeilNonnegInputs(posA, posB, overflowOrDiv0);
386         }
387         if (!aGtZero && bGtZero) {
388           // A is negative, b is positive, return - ( -a / b).
389           APInt posA = zero.ssub_ov(a, overflowOrDiv0);
390           APInt div = posA.sdiv_ov(b, overflowOrDiv0);
391           return zero.ssub_ov(div, overflowOrDiv0);
392         }
393         // A is positive, b is negative, return - (a / -b).
394         APInt posB = zero.ssub_ov(b, overflowOrDiv0);
395         APInt div = a.sdiv_ov(posB, overflowOrDiv0);
396         return zero.ssub_ov(div, overflowOrDiv0);
397       });
398 
399   // Fold out ceil division by one. Assumes all tensors of all ones are
400   // splats.
401   if (auto rhs = operands[1].dyn_cast_or_null<IntegerAttr>()) {
402     if (rhs.getValue() == 1)
403       return getLhs();
404   } else if (auto rhs = operands[1].dyn_cast_or_null<SplatElementsAttr>()) {
405     if (rhs.getSplatValue<IntegerAttr>().getValue() == 1)
406       return getLhs();
407   }
408 
409   return overflowOrDiv0 ? Attribute() : result;
410 }
411 
412 //===----------------------------------------------------------------------===//
413 // FloorDivSIOp
414 //===----------------------------------------------------------------------===//
415 
416 OpFoldResult arith::FloorDivSIOp::fold(ArrayRef<Attribute> operands) {
417   // Don't fold if it would overflow or if it requires a division by zero.
418   bool overflowOrDiv0 = false;
419   auto result =
420       constFoldBinaryOp<IntegerAttr>(operands, [&](APInt a, const APInt &b) {
421         if (overflowOrDiv0 || !b) {
422           overflowOrDiv0 = true;
423           return a;
424         }
425         if (!a)
426           return a;
427         // After this point we know that neither a or b are zero.
428         unsigned bits = a.getBitWidth();
429         APInt zero = APInt::getZero(bits);
430         bool aGtZero = a.sgt(zero);
431         bool bGtZero = b.sgt(zero);
432         if (aGtZero && bGtZero) {
433           // Both positive, return a / b.
434           return a.sdiv_ov(b, overflowOrDiv0);
435         }
436         if (!aGtZero && !bGtZero) {
437           // Both negative, return -a / -b.
438           APInt posA = zero.ssub_ov(a, overflowOrDiv0);
439           APInt posB = zero.ssub_ov(b, overflowOrDiv0);
440           return posA.sdiv_ov(posB, overflowOrDiv0);
441         }
442         if (!aGtZero && bGtZero) {
443           // A is negative, b is positive, return - ceil(-a, b).
444           APInt posA = zero.ssub_ov(a, overflowOrDiv0);
445           APInt ceil = signedCeilNonnegInputs(posA, b, overflowOrDiv0);
446           return zero.ssub_ov(ceil, overflowOrDiv0);
447         }
448         // A is positive, b is negative, return - ceil(a, -b).
449         APInt posB = zero.ssub_ov(b, overflowOrDiv0);
450         APInt ceil = signedCeilNonnegInputs(a, posB, overflowOrDiv0);
451         return zero.ssub_ov(ceil, overflowOrDiv0);
452       });
453 
454   // Fold out floor division by one. Assumes all tensors of all ones are
455   // splats.
456   if (auto rhs = operands[1].dyn_cast_or_null<IntegerAttr>()) {
457     if (rhs.getValue() == 1)
458       return getLhs();
459   } else if (auto rhs = operands[1].dyn_cast_or_null<SplatElementsAttr>()) {
460     if (rhs.getSplatValue<IntegerAttr>().getValue() == 1)
461       return getLhs();
462   }
463 
464   return overflowOrDiv0 ? Attribute() : result;
465 }
466 
467 //===----------------------------------------------------------------------===//
468 // RemUIOp
469 //===----------------------------------------------------------------------===//
470 
471 OpFoldResult arith::RemUIOp::fold(ArrayRef<Attribute> operands) {
472   auto rhs = operands.back().dyn_cast_or_null<IntegerAttr>();
473   if (!rhs)
474     return {};
475   auto rhsValue = rhs.getValue();
476 
477   // x % 1 = 0
478   if (rhsValue.isOneValue())
479     return IntegerAttr::get(rhs.getType(), APInt(rhsValue.getBitWidth(), 0));
480 
481   // Don't fold if it requires division by zero.
482   if (rhsValue.isNullValue())
483     return {};
484 
485   auto lhs = operands.front().dyn_cast_or_null<IntegerAttr>();
486   if (!lhs)
487     return {};
488   return IntegerAttr::get(lhs.getType(), lhs.getValue().urem(rhsValue));
489 }
490 
491 //===----------------------------------------------------------------------===//
492 // RemSIOp
493 //===----------------------------------------------------------------------===//
494 
495 OpFoldResult arith::RemSIOp::fold(ArrayRef<Attribute> operands) {
496   auto rhs = operands.back().dyn_cast_or_null<IntegerAttr>();
497   if (!rhs)
498     return {};
499   auto rhsValue = rhs.getValue();
500 
501   // x % 1 = 0
502   if (rhsValue.isOneValue())
503     return IntegerAttr::get(rhs.getType(), APInt(rhsValue.getBitWidth(), 0));
504 
505   // Don't fold if it requires division by zero.
506   if (rhsValue.isNullValue())
507     return {};
508 
509   auto lhs = operands.front().dyn_cast_or_null<IntegerAttr>();
510   if (!lhs)
511     return {};
512   return IntegerAttr::get(lhs.getType(), lhs.getValue().srem(rhsValue));
513 }
514 
515 //===----------------------------------------------------------------------===//
516 // AndIOp
517 //===----------------------------------------------------------------------===//
518 
519 OpFoldResult arith::AndIOp::fold(ArrayRef<Attribute> operands) {
520   /// and(x, 0) -> 0
521   if (matchPattern(getRhs(), m_Zero()))
522     return getRhs();
523   /// and(x, allOnes) -> x
524   APInt intValue;
525   if (matchPattern(getRhs(), m_ConstantInt(&intValue)) && intValue.isAllOnes())
526     return getLhs();
527 
528   return constFoldBinaryOp<IntegerAttr>(
529       operands, [](APInt a, const APInt &b) { return std::move(a) & b; });
530 }
531 
532 //===----------------------------------------------------------------------===//
533 // OrIOp
534 //===----------------------------------------------------------------------===//
535 
536 OpFoldResult arith::OrIOp::fold(ArrayRef<Attribute> operands) {
537   /// or(x, 0) -> x
538   if (matchPattern(getRhs(), m_Zero()))
539     return getLhs();
540   /// or(x, <all ones>) -> <all ones>
541   if (auto rhsAttr = operands[1].dyn_cast_or_null<IntegerAttr>())
542     if (rhsAttr.getValue().isAllOnes())
543       return rhsAttr;
544 
545   return constFoldBinaryOp<IntegerAttr>(
546       operands, [](APInt a, const APInt &b) { return std::move(a) | b; });
547 }
548 
549 //===----------------------------------------------------------------------===//
550 // XOrIOp
551 //===----------------------------------------------------------------------===//
552 
553 OpFoldResult arith::XOrIOp::fold(ArrayRef<Attribute> operands) {
554   /// xor(x, 0) -> x
555   if (matchPattern(getRhs(), m_Zero()))
556     return getLhs();
557   /// xor(x, x) -> 0
558   if (getLhs() == getRhs())
559     return Builder(getContext()).getZeroAttr(getType());
560   /// xor(xor(x, a), a) -> x
561   if (arith::XOrIOp prev = getLhs().getDefiningOp<arith::XOrIOp>())
562     if (prev.getRhs() == getRhs())
563       return prev.getLhs();
564 
565   return constFoldBinaryOp<IntegerAttr>(
566       operands, [](APInt a, const APInt &b) { return std::move(a) ^ b; });
567 }
568 
569 void arith::XOrIOp::getCanonicalizationPatterns(
570     OwningRewritePatternList &patterns, MLIRContext *context) {
571   patterns.insert<XOrINotCmpI>(context);
572 }
573 
574 //===----------------------------------------------------------------------===//
575 // AddFOp
576 //===----------------------------------------------------------------------===//
577 
578 OpFoldResult arith::AddFOp::fold(ArrayRef<Attribute> operands) {
579   return constFoldBinaryOp<FloatAttr>(
580       operands, [](const APFloat &a, const APFloat &b) { return a + b; });
581 }
582 
583 //===----------------------------------------------------------------------===//
584 // SubFOp
585 //===----------------------------------------------------------------------===//
586 
587 OpFoldResult arith::SubFOp::fold(ArrayRef<Attribute> operands) {
588   return constFoldBinaryOp<FloatAttr>(
589       operands, [](const APFloat &a, const APFloat &b) { return a - b; });
590 }
591 
592 //===----------------------------------------------------------------------===//
593 // MaxSIOp
594 //===----------------------------------------------------------------------===//
595 
596 OpFoldResult MaxSIOp::fold(ArrayRef<Attribute> operands) {
597   assert(operands.size() == 2 && "binary operation takes two operands");
598 
599   // maxsi(x,x) -> x
600   if (getLhs() == getRhs())
601     return getRhs();
602 
603   APInt intValue;
604   // maxsi(x,MAX_INT) -> MAX_INT
605   if (matchPattern(getRhs(), m_ConstantInt(&intValue)) &&
606       intValue.isMaxSignedValue())
607     return getRhs();
608 
609   // maxsi(x, MIN_INT) -> x
610   if (matchPattern(getRhs(), m_ConstantInt(&intValue)) &&
611       intValue.isMinSignedValue())
612     return getLhs();
613 
614   return constFoldBinaryOp<IntegerAttr>(operands,
615                                         [](const APInt &a, const APInt &b) {
616                                           return llvm::APIntOps::smax(a, b);
617                                         });
618 }
619 
620 //===----------------------------------------------------------------------===//
621 // MaxUIOp
622 //===----------------------------------------------------------------------===//
623 
624 OpFoldResult MaxUIOp::fold(ArrayRef<Attribute> operands) {
625   assert(operands.size() == 2 && "binary operation takes two operands");
626 
627   // maxui(x,x) -> x
628   if (getLhs() == getRhs())
629     return getRhs();
630 
631   APInt intValue;
632   // maxui(x,MAX_INT) -> MAX_INT
633   if (matchPattern(getRhs(), m_ConstantInt(&intValue)) && intValue.isMaxValue())
634     return getRhs();
635 
636   // maxui(x, MIN_INT) -> x
637   if (matchPattern(getRhs(), m_ConstantInt(&intValue)) && intValue.isMinValue())
638     return getLhs();
639 
640   return constFoldBinaryOp<IntegerAttr>(operands,
641                                         [](const APInt &a, const APInt &b) {
642                                           return llvm::APIntOps::umax(a, b);
643                                         });
644 }
645 
646 //===----------------------------------------------------------------------===//
647 // MinSIOp
648 //===----------------------------------------------------------------------===//
649 
650 OpFoldResult MinSIOp::fold(ArrayRef<Attribute> operands) {
651   assert(operands.size() == 2 && "binary operation takes two operands");
652 
653   // minsi(x,x) -> x
654   if (getLhs() == getRhs())
655     return getRhs();
656 
657   APInt intValue;
658   // minsi(x,MIN_INT) -> MIN_INT
659   if (matchPattern(getRhs(), m_ConstantInt(&intValue)) &&
660       intValue.isMinSignedValue())
661     return getRhs();
662 
663   // minsi(x, MAX_INT) -> x
664   if (matchPattern(getRhs(), m_ConstantInt(&intValue)) &&
665       intValue.isMaxSignedValue())
666     return getLhs();
667 
668   return constFoldBinaryOp<IntegerAttr>(operands,
669                                         [](const APInt &a, const APInt &b) {
670                                           return llvm::APIntOps::smin(a, b);
671                                         });
672 }
673 
674 //===----------------------------------------------------------------------===//
675 // MinUIOp
676 //===----------------------------------------------------------------------===//
677 
678 OpFoldResult MinUIOp::fold(ArrayRef<Attribute> operands) {
679   assert(operands.size() == 2 && "binary operation takes two operands");
680 
681   // minui(x,x) -> x
682   if (getLhs() == getRhs())
683     return getRhs();
684 
685   APInt intValue;
686   // minui(x,MIN_INT) -> MIN_INT
687   if (matchPattern(getRhs(), m_ConstantInt(&intValue)) && intValue.isMinValue())
688     return getRhs();
689 
690   // minui(x, MAX_INT) -> x
691   if (matchPattern(getRhs(), m_ConstantInt(&intValue)) && intValue.isMaxValue())
692     return getLhs();
693 
694   return constFoldBinaryOp<IntegerAttr>(operands,
695                                         [](const APInt &a, const APInt &b) {
696                                           return llvm::APIntOps::umin(a, b);
697                                         });
698 }
699 
700 //===----------------------------------------------------------------------===//
701 // MulFOp
702 //===----------------------------------------------------------------------===//
703 
704 OpFoldResult arith::MulFOp::fold(ArrayRef<Attribute> operands) {
705   return constFoldBinaryOp<FloatAttr>(
706       operands, [](const APFloat &a, const APFloat &b) { return a * b; });
707 }
708 
709 //===----------------------------------------------------------------------===//
710 // DivFOp
711 //===----------------------------------------------------------------------===//
712 
713 OpFoldResult arith::DivFOp::fold(ArrayRef<Attribute> operands) {
714   return constFoldBinaryOp<FloatAttr>(
715       operands, [](const APFloat &a, const APFloat &b) { return a / b; });
716 }
717 
718 //===----------------------------------------------------------------------===//
719 // Utility functions for verifying cast ops
720 //===----------------------------------------------------------------------===//
721 
/// Tag type for passing a pack of types as an ordinary (null pointer) function
/// argument; only the type is used, never the pointer value.
template <typename... Ts>
using type_list = std::tuple<Ts...> *;
724 
725 /// Returns a non-null type only if the provided type is one of the allowed
726 /// types or one of the allowed shaped types of the allowed types. Returns the
727 /// element type if a valid shaped type is provided.
728 template <typename... ShapedTypes, typename... ElementTypes>
729 static Type getUnderlyingType(Type type, type_list<ShapedTypes...>,
730                               type_list<ElementTypes...>) {
731   if (type.isa<ShapedType>() && !type.isa<ShapedTypes...>())
732     return {};
733 
734   auto underlyingType = getElementTypeOrSelf(type);
735   if (!underlyingType.isa<ElementTypes...>())
736     return {};
737 
738   return underlyingType;
739 }
740 
741 /// Get allowed underlying types for vectors and tensors.
742 template <typename... ElementTypes>
743 static Type getTypeIfLike(Type type) {
744   return getUnderlyingType(type, type_list<VectorType, TensorType>(),
745                            type_list<ElementTypes...>());
746 }
747 
748 /// Get allowed underlying types for vectors, tensors, and memrefs.
749 template <typename... ElementTypes>
750 static Type getTypeIfLikeOrMemRef(Type type) {
751   return getUnderlyingType(type,
752                            type_list<VectorType, TensorType, MemRefType>(),
753                            type_list<ElementTypes...>());
754 }
755 
756 static bool areValidCastInputsAndOutputs(TypeRange inputs, TypeRange outputs) {
757   return inputs.size() == 1 && outputs.size() == 1 &&
758          succeeded(verifyCompatibleShapes(inputs.front(), outputs.front()));
759 }
760 
761 //===----------------------------------------------------------------------===//
762 // Verifiers for integer and floating point extension/truncation ops
763 //===----------------------------------------------------------------------===//
764 
765 // Extend ops can only extend to a wider type.
766 template <typename ValType, typename Op>
767 static LogicalResult verifyExtOp(Op op) {
768   Type srcType = getElementTypeOrSelf(op.getIn().getType());
769   Type dstType = getElementTypeOrSelf(op.getType());
770 
771   if (srcType.cast<ValType>().getWidth() >= dstType.cast<ValType>().getWidth())
772     return op.emitError("result type ")
773            << dstType << " must be wider than operand type " << srcType;
774 
775   return success();
776 }
777 
778 // Truncate ops can only truncate to a shorter type.
779 template <typename ValType, typename Op>
780 static LogicalResult verifyTruncateOp(Op op) {
781   Type srcType = getElementTypeOrSelf(op.getIn().getType());
782   Type dstType = getElementTypeOrSelf(op.getType());
783 
784   if (srcType.cast<ValType>().getWidth() <= dstType.cast<ValType>().getWidth())
785     return op.emitError("result type ")
786            << dstType << " must be shorter than operand type " << srcType;
787 
788   return success();
789 }
790 
791 /// Validate a cast that changes the width of a type.
792 template <template <typename> class WidthComparator, typename... ElementTypes>
793 static bool checkWidthChangeCast(TypeRange inputs, TypeRange outputs) {
794   if (!areValidCastInputsAndOutputs(inputs, outputs))
795     return false;
796 
797   auto srcType = getTypeIfLike<ElementTypes...>(inputs.front());
798   auto dstType = getTypeIfLike<ElementTypes...>(outputs.front());
799   if (!srcType || !dstType)
800     return false;
801 
802   return WidthComparator<unsigned>()(dstType.getIntOrFloatBitWidth(),
803                                      srcType.getIntOrFloatBitWidth());
804 }
805 
806 //===----------------------------------------------------------------------===//
807 // ExtUIOp
808 //===----------------------------------------------------------------------===//
809 
810 OpFoldResult arith::ExtUIOp::fold(ArrayRef<Attribute> operands) {
811   if (auto lhs = operands[0].dyn_cast_or_null<IntegerAttr>())
812     return IntegerAttr::get(
813         getType(), lhs.getValue().zext(getType().getIntOrFloatBitWidth()));
814 
815   if (auto lhs = getIn().getDefiningOp<ExtUIOp>()) {
816     getInMutable().assign(lhs.getIn());
817     return getResult();
818   }
819 
820   return {};
821 }
822 
823 bool arith::ExtUIOp::areCastCompatible(TypeRange inputs, TypeRange outputs) {
824   return checkWidthChangeCast<std::greater, IntegerType>(inputs, outputs);
825 }
826 
827 //===----------------------------------------------------------------------===//
828 // ExtSIOp
829 //===----------------------------------------------------------------------===//
830 
831 OpFoldResult arith::ExtSIOp::fold(ArrayRef<Attribute> operands) {
832   if (auto lhs = operands[0].dyn_cast_or_null<IntegerAttr>())
833     return IntegerAttr::get(
834         getType(), lhs.getValue().sext(getType().getIntOrFloatBitWidth()));
835 
836   if (auto lhs = getIn().getDefiningOp<ExtSIOp>()) {
837     getInMutable().assign(lhs.getIn());
838     return getResult();
839   }
840 
841   return {};
842 }
843 
844 bool arith::ExtSIOp::areCastCompatible(TypeRange inputs, TypeRange outputs) {
845   return checkWidthChangeCast<std::greater, IntegerType>(inputs, outputs);
846 }
847 
848 void arith::ExtSIOp::getCanonicalizationPatterns(
849     OwningRewritePatternList &patterns, MLIRContext *context) {
850   patterns.insert<ExtSIOfExtUI>(context);
851 }
852 
853 //===----------------------------------------------------------------------===//
854 // ExtFOp
855 //===----------------------------------------------------------------------===//
856 
857 bool arith::ExtFOp::areCastCompatible(TypeRange inputs, TypeRange outputs) {
858   return checkWidthChangeCast<std::greater, FloatType>(inputs, outputs);
859 }
860 
861 //===----------------------------------------------------------------------===//
862 // TruncIOp
863 //===----------------------------------------------------------------------===//
864 
865 OpFoldResult arith::TruncIOp::fold(ArrayRef<Attribute> operands) {
866   assert(operands.size() == 1 && "unary operation takes one operand");
867 
868   // trunci(zexti(a)) -> a
869   // trunci(sexti(a)) -> a
870   if (matchPattern(getOperand(), m_Op<arith::ExtUIOp>()) ||
871       matchPattern(getOperand(), m_Op<arith::ExtSIOp>()))
872     return getOperand().getDefiningOp()->getOperand(0);
873 
874   // trunci(trunci(a)) -> trunci(a))
875   if (matchPattern(getOperand(), m_Op<arith::TruncIOp>())) {
876     setOperand(getOperand().getDefiningOp()->getOperand(0));
877     return getResult();
878   }
879 
880   if (!operands[0])
881     return {};
882 
883   if (auto lhs = operands[0].dyn_cast<IntegerAttr>()) {
884     return IntegerAttr::get(
885         getType(), lhs.getValue().trunc(getType().getIntOrFloatBitWidth()));
886   }
887 
888   return {};
889 }
890 
891 bool arith::TruncIOp::areCastCompatible(TypeRange inputs, TypeRange outputs) {
892   return checkWidthChangeCast<std::less, IntegerType>(inputs, outputs);
893 }
894 
895 //===----------------------------------------------------------------------===//
896 // TruncFOp
897 //===----------------------------------------------------------------------===//
898 
899 /// Perform safe const propagation for truncf, i.e. only propagate if FP value
900 /// can be represented without precision loss or rounding.
901 OpFoldResult arith::TruncFOp::fold(ArrayRef<Attribute> operands) {
902   assert(operands.size() == 1 && "unary operation takes one operand");
903 
904   auto constOperand = operands.front();
905   if (!constOperand || !constOperand.isa<FloatAttr>())
906     return {};
907 
908   // Convert to target type via 'double'.
909   double sourceValue =
910       constOperand.dyn_cast<FloatAttr>().getValue().convertToDouble();
911   auto targetAttr = FloatAttr::get(getType(), sourceValue);
912 
913   // Propagate if constant's value does not change after truncation.
914   if (sourceValue == targetAttr.getValue().convertToDouble())
915     return targetAttr;
916 
917   return {};
918 }
919 
920 bool arith::TruncFOp::areCastCompatible(TypeRange inputs, TypeRange outputs) {
921   return checkWidthChangeCast<std::less, FloatType>(inputs, outputs);
922 }
923 
924 //===----------------------------------------------------------------------===//
925 // AndIOp
926 //===----------------------------------------------------------------------===//
927 
928 void arith::AndIOp::getCanonicalizationPatterns(
929     OwningRewritePatternList &patterns, MLIRContext *context) {
930   patterns.insert<AndOfExtUI, AndOfExtSI>(context);
931 }
932 
933 //===----------------------------------------------------------------------===//
934 // OrIOp
935 //===----------------------------------------------------------------------===//
936 
937 void arith::OrIOp::getCanonicalizationPatterns(
938     OwningRewritePatternList &patterns, MLIRContext *context) {
939   patterns.insert<OrOfExtUI, OrOfExtSI>(context);
940 }
941 
942 //===----------------------------------------------------------------------===//
943 // Verifiers for casts between integers and floats.
944 //===----------------------------------------------------------------------===//
945 
946 template <typename From, typename To>
947 static bool checkIntFloatCast(TypeRange inputs, TypeRange outputs) {
948   if (!areValidCastInputsAndOutputs(inputs, outputs))
949     return false;
950 
951   auto srcType = getTypeIfLike<From>(inputs.front());
952   auto dstType = getTypeIfLike<To>(outputs.back());
953 
954   return srcType && dstType;
955 }
956 
957 //===----------------------------------------------------------------------===//
958 // UIToFPOp
959 //===----------------------------------------------------------------------===//
960 
961 bool arith::UIToFPOp::areCastCompatible(TypeRange inputs, TypeRange outputs) {
962   return checkIntFloatCast<IntegerType, FloatType>(inputs, outputs);
963 }
964 
965 OpFoldResult arith::UIToFPOp::fold(ArrayRef<Attribute> operands) {
966   if (auto lhs = operands[0].dyn_cast_or_null<IntegerAttr>()) {
967     const APInt &api = lhs.getValue();
968     FloatType floatTy = getType().cast<FloatType>();
969     APFloat apf(floatTy.getFloatSemantics(),
970                 APInt::getZero(floatTy.getWidth()));
971     apf.convertFromAPInt(api, /*IsSigned=*/false, APFloat::rmNearestTiesToEven);
972     return FloatAttr::get(floatTy, apf);
973   }
974   return {};
975 }
976 
977 //===----------------------------------------------------------------------===//
978 // SIToFPOp
979 //===----------------------------------------------------------------------===//
980 
981 bool arith::SIToFPOp::areCastCompatible(TypeRange inputs, TypeRange outputs) {
982   return checkIntFloatCast<IntegerType, FloatType>(inputs, outputs);
983 }
984 
985 OpFoldResult arith::SIToFPOp::fold(ArrayRef<Attribute> operands) {
986   if (auto lhs = operands[0].dyn_cast_or_null<IntegerAttr>()) {
987     const APInt &api = lhs.getValue();
988     FloatType floatTy = getType().cast<FloatType>();
989     APFloat apf(floatTy.getFloatSemantics(),
990                 APInt::getZero(floatTy.getWidth()));
991     apf.convertFromAPInt(api, /*IsSigned=*/true, APFloat::rmNearestTiesToEven);
992     return FloatAttr::get(floatTy, apf);
993   }
994   return {};
995 }
996 //===----------------------------------------------------------------------===//
997 // FPToUIOp
998 //===----------------------------------------------------------------------===//
999 
1000 bool arith::FPToUIOp::areCastCompatible(TypeRange inputs, TypeRange outputs) {
1001   return checkIntFloatCast<FloatType, IntegerType>(inputs, outputs);
1002 }
1003 
1004 OpFoldResult arith::FPToUIOp::fold(ArrayRef<Attribute> operands) {
1005   if (auto lhs = operands[0].dyn_cast_or_null<FloatAttr>()) {
1006     const APFloat &apf = lhs.getValue();
1007     IntegerType intTy = getType().cast<IntegerType>();
1008     bool ignored;
1009     APSInt api(intTy.getWidth(), /*isUnsigned=*/true);
1010     if (APFloat::opInvalidOp ==
1011         apf.convertToInteger(api, APFloat::rmTowardZero, &ignored)) {
1012       // Undefined behavior invoked - the destination type can't represent
1013       // the input constant.
1014       return {};
1015     }
1016     return IntegerAttr::get(getType(), api);
1017   }
1018 
1019   return {};
1020 }
1021 
1022 //===----------------------------------------------------------------------===//
1023 // FPToSIOp
1024 //===----------------------------------------------------------------------===//
1025 
1026 bool arith::FPToSIOp::areCastCompatible(TypeRange inputs, TypeRange outputs) {
1027   return checkIntFloatCast<FloatType, IntegerType>(inputs, outputs);
1028 }
1029 
1030 OpFoldResult arith::FPToSIOp::fold(ArrayRef<Attribute> operands) {
1031   if (auto lhs = operands[0].dyn_cast_or_null<FloatAttr>()) {
1032     const APFloat &apf = lhs.getValue();
1033     IntegerType intTy = getType().cast<IntegerType>();
1034     bool ignored;
1035     APSInt api(intTy.getWidth(), /*isUnsigned=*/false);
1036     if (APFloat::opInvalidOp ==
1037         apf.convertToInteger(api, APFloat::rmTowardZero, &ignored)) {
1038       // Undefined behavior invoked - the destination type can't represent
1039       // the input constant.
1040       return {};
1041     }
1042     return IntegerAttr::get(getType(), api);
1043   }
1044 
1045   return {};
1046 }
1047 
1048 //===----------------------------------------------------------------------===//
1049 // IndexCastOp
1050 //===----------------------------------------------------------------------===//
1051 
1052 bool arith::IndexCastOp::areCastCompatible(TypeRange inputs,
1053                                            TypeRange outputs) {
1054   if (!areValidCastInputsAndOutputs(inputs, outputs))
1055     return false;
1056 
1057   auto srcType = getTypeIfLikeOrMemRef<IntegerType, IndexType>(inputs.front());
1058   auto dstType = getTypeIfLikeOrMemRef<IntegerType, IndexType>(outputs.front());
1059   if (!srcType || !dstType)
1060     return false;
1061 
1062   return (srcType.isIndex() && dstType.isSignlessInteger()) ||
1063          (srcType.isSignlessInteger() && dstType.isIndex());
1064 }
1065 
1066 OpFoldResult arith::IndexCastOp::fold(ArrayRef<Attribute> operands) {
1067   // index_cast(constant) -> constant
1068   // A little hack because we go through int. Otherwise, the size of the
1069   // constant might need to change.
1070   if (auto value = operands[0].dyn_cast_or_null<IntegerAttr>())
1071     return IntegerAttr::get(getType(), value.getInt());
1072 
1073   return {};
1074 }
1075 
1076 void arith::IndexCastOp::getCanonicalizationPatterns(
1077     OwningRewritePatternList &patterns, MLIRContext *context) {
1078   patterns.insert<IndexCastOfIndexCast, IndexCastOfExtSI>(context);
1079 }
1080 
1081 //===----------------------------------------------------------------------===//
1082 // BitcastOp
1083 //===----------------------------------------------------------------------===//
1084 
1085 bool arith::BitcastOp::areCastCompatible(TypeRange inputs, TypeRange outputs) {
1086   if (!areValidCastInputsAndOutputs(inputs, outputs))
1087     return false;
1088 
1089   auto srcType =
1090       getTypeIfLikeOrMemRef<IntegerType, IndexType, FloatType>(inputs.front());
1091   auto dstType =
1092       getTypeIfLikeOrMemRef<IntegerType, IndexType, FloatType>(outputs.front());
1093   if (!srcType || !dstType)
1094     return false;
1095 
1096   return srcType.getIntOrFloatBitWidth() == dstType.getIntOrFloatBitWidth();
1097 }
1098 
1099 OpFoldResult arith::BitcastOp::fold(ArrayRef<Attribute> operands) {
1100   assert(operands.size() == 1 && "bitcast op expects 1 operand");
1101 
1102   auto resType = getType();
1103   auto operand = operands[0];
1104   if (!operand)
1105     return {};
1106 
1107   /// Bitcast dense elements.
1108   if (auto denseAttr = operand.dyn_cast_or_null<DenseElementsAttr>())
1109     return denseAttr.bitcast(resType.cast<ShapedType>().getElementType());
1110   /// Other shaped types unhandled.
1111   if (resType.isa<ShapedType>())
1112     return {};
1113 
1114   /// Bitcast integer or float to integer or float.
1115   APInt bits = operand.isa<FloatAttr>()
1116                    ? operand.cast<FloatAttr>().getValue().bitcastToAPInt()
1117                    : operand.cast<IntegerAttr>().getValue();
1118 
1119   if (auto resFloatType = resType.dyn_cast<FloatType>())
1120     return FloatAttr::get(resType,
1121                           APFloat(resFloatType.getFloatSemantics(), bits));
1122   return IntegerAttr::get(resType, bits);
1123 }
1124 
1125 void arith::BitcastOp::getCanonicalizationPatterns(
1126     OwningRewritePatternList &patterns, MLIRContext *context) {
1127   patterns.insert<BitcastOfBitcast>(context);
1128 }
1129 
1130 //===----------------------------------------------------------------------===//
1131 // Helpers for compare ops
1132 //===----------------------------------------------------------------------===//
1133 
1134 /// Return the type of the same shape (scalar, vector or tensor) containing i1.
1135 static Type getI1SameShape(Type type) {
1136   auto i1Type = IntegerType::get(type.getContext(), 1);
1137   if (auto tensorType = type.dyn_cast<RankedTensorType>())
1138     return RankedTensorType::get(tensorType.getShape(), i1Type);
1139   if (type.isa<UnrankedTensorType>())
1140     return UnrankedTensorType::get(i1Type);
1141   if (auto vectorType = type.dyn_cast<VectorType>())
1142     return VectorType::get(vectorType.getShape(), i1Type,
1143                            vectorType.getNumScalableDims());
1144   return i1Type;
1145 }
1146 
1147 //===----------------------------------------------------------------------===//
1148 // CmpIOp
1149 //===----------------------------------------------------------------------===//
1150 
1151 /// Compute `lhs` `pred` `rhs`, where `pred` is one of the known integer
1152 /// comparison predicates.
1153 bool mlir::arith::applyCmpPredicate(arith::CmpIPredicate predicate,
1154                                     const APInt &lhs, const APInt &rhs) {
1155   switch (predicate) {
1156   case arith::CmpIPredicate::eq:
1157     return lhs.eq(rhs);
1158   case arith::CmpIPredicate::ne:
1159     return lhs.ne(rhs);
1160   case arith::CmpIPredicate::slt:
1161     return lhs.slt(rhs);
1162   case arith::CmpIPredicate::sle:
1163     return lhs.sle(rhs);
1164   case arith::CmpIPredicate::sgt:
1165     return lhs.sgt(rhs);
1166   case arith::CmpIPredicate::sge:
1167     return lhs.sge(rhs);
1168   case arith::CmpIPredicate::ult:
1169     return lhs.ult(rhs);
1170   case arith::CmpIPredicate::ule:
1171     return lhs.ule(rhs);
1172   case arith::CmpIPredicate::ugt:
1173     return lhs.ugt(rhs);
1174   case arith::CmpIPredicate::uge:
1175     return lhs.uge(rhs);
1176   }
1177   llvm_unreachable("unknown cmpi predicate kind");
1178 }
1179 
1180 /// Returns true if the predicate is true for two equal operands.
1181 static bool applyCmpPredicateToEqualOperands(arith::CmpIPredicate predicate) {
1182   switch (predicate) {
1183   case arith::CmpIPredicate::eq:
1184   case arith::CmpIPredicate::sle:
1185   case arith::CmpIPredicate::sge:
1186   case arith::CmpIPredicate::ule:
1187   case arith::CmpIPredicate::uge:
1188     return true;
1189   case arith::CmpIPredicate::ne:
1190   case arith::CmpIPredicate::slt:
1191   case arith::CmpIPredicate::sgt:
1192   case arith::CmpIPredicate::ult:
1193   case arith::CmpIPredicate::ugt:
1194     return false;
1195   }
1196   llvm_unreachable("unknown cmpi predicate kind");
1197 }
1198 
1199 static Attribute getBoolAttribute(Type type, MLIRContext *ctx, bool value) {
1200   auto boolAttr = BoolAttr::get(ctx, value);
1201   ShapedType shapedType = type.dyn_cast_or_null<ShapedType>();
1202   if (!shapedType)
1203     return boolAttr;
1204   return DenseElementsAttr::get(shapedType, boolAttr);
1205 }
1206 
1207 OpFoldResult arith::CmpIOp::fold(ArrayRef<Attribute> operands) {
1208   assert(operands.size() == 2 && "cmpi takes two operands");
1209 
1210   // cmpi(pred, x, x)
1211   if (getLhs() == getRhs()) {
1212     auto val = applyCmpPredicateToEqualOperands(getPredicate());
1213     return getBoolAttribute(getType(), getContext(), val);
1214   }
1215 
1216   if (matchPattern(getRhs(), m_Zero())) {
1217     if (auto extOp = getLhs().getDefiningOp<ExtSIOp>()) {
1218       if (extOp.getOperand().getType().cast<IntegerType>().getWidth() == 1) {
1219         // extsi(%x : i1 -> iN) != 0  ->  %x
1220         if (getPredicate() == arith::CmpIPredicate::ne) {
1221           return extOp.getOperand();
1222         }
1223       }
1224     }
1225     if (auto extOp = getLhs().getDefiningOp<ExtUIOp>()) {
1226       if (extOp.getOperand().getType().cast<IntegerType>().getWidth() == 1) {
1227         // extui(%x : i1 -> iN) != 0  ->  %x
1228         if (getPredicate() == arith::CmpIPredicate::ne) {
1229           return extOp.getOperand();
1230         }
1231       }
1232     }
1233   }
1234 
1235   auto lhs = operands.front().dyn_cast_or_null<IntegerAttr>();
1236   auto rhs = operands.back().dyn_cast_or_null<IntegerAttr>();
1237   if (!lhs || !rhs)
1238     return {};
1239 
1240   auto val = applyCmpPredicate(getPredicate(), lhs.getValue(), rhs.getValue());
1241   return BoolAttr::get(getContext(), val);
1242 }
1243 
1244 //===----------------------------------------------------------------------===//
1245 // CmpFOp
1246 //===----------------------------------------------------------------------===//
1247 
1248 /// Compute `lhs` `pred` `rhs`, where `pred` is one of the known floating point
1249 /// comparison predicates.
1250 bool mlir::arith::applyCmpPredicate(arith::CmpFPredicate predicate,
1251                                     const APFloat &lhs, const APFloat &rhs) {
1252   auto cmpResult = lhs.compare(rhs);
1253   switch (predicate) {
1254   case arith::CmpFPredicate::AlwaysFalse:
1255     return false;
1256   case arith::CmpFPredicate::OEQ:
1257     return cmpResult == APFloat::cmpEqual;
1258   case arith::CmpFPredicate::OGT:
1259     return cmpResult == APFloat::cmpGreaterThan;
1260   case arith::CmpFPredicate::OGE:
1261     return cmpResult == APFloat::cmpGreaterThan ||
1262            cmpResult == APFloat::cmpEqual;
1263   case arith::CmpFPredicate::OLT:
1264     return cmpResult == APFloat::cmpLessThan;
1265   case arith::CmpFPredicate::OLE:
1266     return cmpResult == APFloat::cmpLessThan || cmpResult == APFloat::cmpEqual;
1267   case arith::CmpFPredicate::ONE:
1268     return cmpResult != APFloat::cmpUnordered && cmpResult != APFloat::cmpEqual;
1269   case arith::CmpFPredicate::ORD:
1270     return cmpResult != APFloat::cmpUnordered;
1271   case arith::CmpFPredicate::UEQ:
1272     return cmpResult == APFloat::cmpUnordered || cmpResult == APFloat::cmpEqual;
1273   case arith::CmpFPredicate::UGT:
1274     return cmpResult == APFloat::cmpUnordered ||
1275            cmpResult == APFloat::cmpGreaterThan;
1276   case arith::CmpFPredicate::UGE:
1277     return cmpResult == APFloat::cmpUnordered ||
1278            cmpResult == APFloat::cmpGreaterThan ||
1279            cmpResult == APFloat::cmpEqual;
1280   case arith::CmpFPredicate::ULT:
1281     return cmpResult == APFloat::cmpUnordered ||
1282            cmpResult == APFloat::cmpLessThan;
1283   case arith::CmpFPredicate::ULE:
1284     return cmpResult == APFloat::cmpUnordered ||
1285            cmpResult == APFloat::cmpLessThan || cmpResult == APFloat::cmpEqual;
1286   case arith::CmpFPredicate::UNE:
1287     return cmpResult != APFloat::cmpEqual;
1288   case arith::CmpFPredicate::UNO:
1289     return cmpResult == APFloat::cmpUnordered;
1290   case arith::CmpFPredicate::AlwaysTrue:
1291     return true;
1292   }
1293   llvm_unreachable("unknown cmpf predicate kind");
1294 }
1295 
1296 OpFoldResult arith::CmpFOp::fold(ArrayRef<Attribute> operands) {
1297   assert(operands.size() == 2 && "cmpf takes two operands");
1298 
1299   auto lhs = operands.front().dyn_cast_or_null<FloatAttr>();
1300   auto rhs = operands.back().dyn_cast_or_null<FloatAttr>();
1301 
1302   // If one operand is NaN, making them both NaN does not change the result.
1303   if (lhs && lhs.getValue().isNaN())
1304     rhs = lhs;
1305   if (rhs && rhs.getValue().isNaN())
1306     lhs = rhs;
1307 
1308   if (!lhs || !rhs)
1309     return {};
1310 
1311   auto val = applyCmpPredicate(getPredicate(), lhs.getValue(), rhs.getValue());
1312   return BoolAttr::get(getContext(), val);
1313 }
1314 
1315 //===----------------------------------------------------------------------===//
1316 // Atomic Enum
1317 //===----------------------------------------------------------------------===//
1318 
1319 /// Returns the identity value attribute associated with an AtomicRMWKind op.
1320 Attribute mlir::arith::getIdentityValueAttr(AtomicRMWKind kind, Type resultType,
1321                                             OpBuilder &builder, Location loc) {
1322   switch (kind) {
1323   case AtomicRMWKind::maxf:
1324     return builder.getFloatAttr(
1325         resultType,
1326         APFloat::getInf(resultType.cast<FloatType>().getFloatSemantics(),
1327                         /*Negative=*/true));
1328   case AtomicRMWKind::addf:
1329   case AtomicRMWKind::addi:
1330   case AtomicRMWKind::maxu:
1331   case AtomicRMWKind::ori:
1332     return builder.getZeroAttr(resultType);
1333   case AtomicRMWKind::andi:
1334     return builder.getIntegerAttr(
1335         resultType,
1336         APInt::getAllOnes(resultType.cast<IntegerType>().getWidth()));
1337   case AtomicRMWKind::maxs:
1338     return builder.getIntegerAttr(
1339         resultType,
1340         APInt::getSignedMinValue(resultType.cast<IntegerType>().getWidth()));
1341   case AtomicRMWKind::minf:
1342     return builder.getFloatAttr(
1343         resultType,
1344         APFloat::getInf(resultType.cast<FloatType>().getFloatSemantics(),
1345                         /*Negative=*/false));
1346   case AtomicRMWKind::mins:
1347     return builder.getIntegerAttr(
1348         resultType,
1349         APInt::getSignedMaxValue(resultType.cast<IntegerType>().getWidth()));
1350   case AtomicRMWKind::minu:
1351     return builder.getIntegerAttr(
1352         resultType,
1353         APInt::getMaxValue(resultType.cast<IntegerType>().getWidth()));
1354   case AtomicRMWKind::muli:
1355     return builder.getIntegerAttr(resultType, 1);
1356   case AtomicRMWKind::mulf:
1357     return builder.getFloatAttr(resultType, 1);
1358   // TODO: Add remaining reduction operations.
1359   default:
1360     (void)emitOptionalError(loc, "Reduction operation type not supported");
1361     break;
1362   }
1363   return nullptr;
1364 }
1365 
1366 /// Returns the identity value associated with an AtomicRMWKind op.
1367 Value mlir::arith::getIdentityValue(AtomicRMWKind op, Type resultType,
1368                                     OpBuilder &builder, Location loc) {
1369   Attribute attr = getIdentityValueAttr(op, resultType, builder, loc);
1370   return builder.create<arith::ConstantOp>(loc, attr);
1371 }
1372 
1373 /// Return the value obtained by applying the reduction operation kind
1374 /// associated with a binary AtomicRMWKind op to `lhs` and `rhs`.
1375 Value mlir::arith::getReductionOp(AtomicRMWKind op, OpBuilder &builder,
1376                                   Location loc, Value lhs, Value rhs) {
1377   switch (op) {
1378   case AtomicRMWKind::addf:
1379     return builder.create<arith::AddFOp>(loc, lhs, rhs);
1380   case AtomicRMWKind::addi:
1381     return builder.create<arith::AddIOp>(loc, lhs, rhs);
1382   case AtomicRMWKind::mulf:
1383     return builder.create<arith::MulFOp>(loc, lhs, rhs);
1384   case AtomicRMWKind::muli:
1385     return builder.create<arith::MulIOp>(loc, lhs, rhs);
1386   case AtomicRMWKind::maxf:
1387     return builder.create<arith::MaxFOp>(loc, lhs, rhs);
1388   case AtomicRMWKind::minf:
1389     return builder.create<arith::MinFOp>(loc, lhs, rhs);
1390   case AtomicRMWKind::maxs:
1391     return builder.create<arith::MaxSIOp>(loc, lhs, rhs);
1392   case AtomicRMWKind::mins:
1393     return builder.create<arith::MinSIOp>(loc, lhs, rhs);
1394   case AtomicRMWKind::maxu:
1395     return builder.create<arith::MaxUIOp>(loc, lhs, rhs);
1396   case AtomicRMWKind::minu:
1397     return builder.create<arith::MinUIOp>(loc, lhs, rhs);
1398   case AtomicRMWKind::ori:
1399     return builder.create<arith::OrIOp>(loc, lhs, rhs);
1400   case AtomicRMWKind::andi:
1401     return builder.create<arith::AndIOp>(loc, lhs, rhs);
1402   // TODO: Add remaining reduction operations.
1403   default:
1404     (void)emitOptionalError(loc, "Reduction operation type not supported");
1405     break;
1406   }
1407   return nullptr;
1408 }
1409 
1410 //===----------------------------------------------------------------------===//
1411 // TableGen'd op method definitions
1412 //===----------------------------------------------------------------------===//
1413 
1414 #define GET_OP_CLASSES
1415 #include "mlir/Dialect/Arithmetic/IR/ArithmeticOps.cpp.inc"
1416 
1417 //===----------------------------------------------------------------------===//
1418 // TableGen'd enum attribute definitions
1419 //===----------------------------------------------------------------------===//
1420 
1421 #include "mlir/Dialect/Arithmetic/IR/ArithmeticOpsEnums.cpp.inc"
1422