1 //===- ArithmeticOps.cpp - MLIR Arithmetic dialect ops implementation -----===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 
9 #include "mlir/Dialect/Arithmetic/IR/Arithmetic.h"
10 #include "mlir/Dialect/CommonFolders.h"
11 #include "mlir/IR/Builders.h"
12 #include "mlir/IR/Matchers.h"
13 #include "mlir/IR/OpImplementation.h"
14 #include "mlir/IR/PatternMatch.h"
15 #include "mlir/IR/TypeUtilities.h"
16 
17 #include "llvm/ADT/APSInt.h"
18 
19 using namespace mlir;
20 using namespace mlir::arith;
21 
22 //===----------------------------------------------------------------------===//
23 // Pattern helpers
24 //===----------------------------------------------------------------------===//
25 
26 static IntegerAttr addIntegerAttrs(PatternRewriter &builder, Value res,
27                                    Attribute lhs, Attribute rhs) {
28   return builder.getIntegerAttr(res.getType(),
29                                 lhs.cast<IntegerAttr>().getInt() +
30                                     rhs.cast<IntegerAttr>().getInt());
31 }
32 
33 static IntegerAttr subIntegerAttrs(PatternRewriter &builder, Value res,
34                                    Attribute lhs, Attribute rhs) {
35   return builder.getIntegerAttr(res.getType(),
36                                 lhs.cast<IntegerAttr>().getInt() -
37                                     rhs.cast<IntegerAttr>().getInt());
38 }
39 
40 /// Invert an integer comparison predicate.
41 static arith::CmpIPredicate invertPredicate(arith::CmpIPredicate pred) {
42   switch (pred) {
43   case arith::CmpIPredicate::eq:
44     return arith::CmpIPredicate::ne;
45   case arith::CmpIPredicate::ne:
46     return arith::CmpIPredicate::eq;
47   case arith::CmpIPredicate::slt:
48     return arith::CmpIPredicate::sge;
49   case arith::CmpIPredicate::sle:
50     return arith::CmpIPredicate::sgt;
51   case arith::CmpIPredicate::sgt:
52     return arith::CmpIPredicate::sle;
53   case arith::CmpIPredicate::sge:
54     return arith::CmpIPredicate::slt;
55   case arith::CmpIPredicate::ult:
56     return arith::CmpIPredicate::uge;
57   case arith::CmpIPredicate::ule:
58     return arith::CmpIPredicate::ugt;
59   case arith::CmpIPredicate::ugt:
60     return arith::CmpIPredicate::ule;
61   case arith::CmpIPredicate::uge:
62     return arith::CmpIPredicate::ult;
63   }
64   llvm_unreachable("unknown cmpi predicate kind");
65 }
66 
67 static arith::CmpIPredicateAttr invertPredicate(arith::CmpIPredicateAttr pred) {
68   return arith::CmpIPredicateAttr::get(pred.getContext(),
69                                        invertPredicate(pred.getValue()));
70 }
71 
72 //===----------------------------------------------------------------------===//
73 // TableGen'd canonicalization patterns
74 //===----------------------------------------------------------------------===//
75 
namespace {
// Canonicalization patterns generated by TableGen (e.g. AddIAddConstant,
// XOrINotCmpI), registered by the getCanonicalizationPatterns hooks below.
// The anonymous namespace keeps the generated classes local to this file.
#include "ArithmeticCanonicalization.inc"
} // namespace
79 
80 //===----------------------------------------------------------------------===//
81 // ConstantOp
82 //===----------------------------------------------------------------------===//
83 
84 void arith::ConstantOp::getAsmResultNames(
85     function_ref<void(Value, StringRef)> setNameFn) {
86   auto type = getType();
87   if (auto intCst = getValue().dyn_cast<IntegerAttr>()) {
88     auto intType = type.dyn_cast<IntegerType>();
89 
90     // Sugar i1 constants with 'true' and 'false'.
91     if (intType && intType.getWidth() == 1)
92       return setNameFn(getResult(), (intCst.getInt() ? "true" : "false"));
93 
94     // Otherwise, build a compex name with the value and type.
95     SmallString<32> specialNameBuffer;
96     llvm::raw_svector_ostream specialName(specialNameBuffer);
97     specialName << 'c' << intCst.getInt();
98     if (intType)
99       specialName << '_' << type;
100     setNameFn(getResult(), specialName.str());
101   } else {
102     setNameFn(getResult(), "cst");
103   }
104 }
105 
106 /// TODO: disallow arith.constant to return anything other than signless integer
107 /// or float like.
108 static LogicalResult verify(arith::ConstantOp op) {
109   auto type = op.getType();
110   // The value's type must match the return type.
111   if (op.getValue().getType() != type) {
112     return op.emitOpError() << "value type " << op.getValue().getType()
113                             << " must match return type: " << type;
114   }
115   // Integer values must be signless.
116   if (type.isa<IntegerType>() && !type.cast<IntegerType>().isSignless())
117     return op.emitOpError("integer return type must be signless");
118   // Any float or elements attribute are acceptable.
119   if (!op.getValue().isa<IntegerAttr, FloatAttr, ElementsAttr>()) {
120     return op.emitOpError(
121         "value must be an integer, float, or elements attribute");
122   }
123   return success();
124 }
125 
126 bool arith::ConstantOp::isBuildableWith(Attribute value, Type type) {
127   // The value's type must be the same as the provided type.
128   if (value.getType() != type)
129     return false;
130   // Integer values must be signless.
131   if (type.isa<IntegerType>() && !type.cast<IntegerType>().isSignless())
132     return false;
133   // Integer, float, and element attributes are buildable.
134   return value.isa<IntegerAttr, FloatAttr, ElementsAttr>();
135 }
136 
137 OpFoldResult arith::ConstantOp::fold(ArrayRef<Attribute> operands) {
138   return getValue();
139 }
140 
141 void arith::ConstantIntOp::build(OpBuilder &builder, OperationState &result,
142                                  int64_t value, unsigned width) {
143   auto type = builder.getIntegerType(width);
144   arith::ConstantOp::build(builder, result, type,
145                            builder.getIntegerAttr(type, value));
146 }
147 
148 void arith::ConstantIntOp::build(OpBuilder &builder, OperationState &result,
149                                  int64_t value, Type type) {
150   assert(type.isSignlessInteger() &&
151          "ConstantIntOp can only have signless integer type values");
152   arith::ConstantOp::build(builder, result, type,
153                            builder.getIntegerAttr(type, value));
154 }
155 
156 bool arith::ConstantIntOp::classof(Operation *op) {
157   if (auto constOp = dyn_cast_or_null<arith::ConstantOp>(op))
158     return constOp.getType().isSignlessInteger();
159   return false;
160 }
161 
162 void arith::ConstantFloatOp::build(OpBuilder &builder, OperationState &result,
163                                    const APFloat &value, FloatType type) {
164   arith::ConstantOp::build(builder, result, type,
165                            builder.getFloatAttr(type, value));
166 }
167 
168 bool arith::ConstantFloatOp::classof(Operation *op) {
169   if (auto constOp = dyn_cast_or_null<arith::ConstantOp>(op))
170     return constOp.getType().isa<FloatType>();
171   return false;
172 }
173 
174 void arith::ConstantIndexOp::build(OpBuilder &builder, OperationState &result,
175                                    int64_t value) {
176   arith::ConstantOp::build(builder, result, builder.getIndexType(),
177                            builder.getIndexAttr(value));
178 }
179 
180 bool arith::ConstantIndexOp::classof(Operation *op) {
181   if (auto constOp = dyn_cast_or_null<arith::ConstantOp>(op))
182     return constOp.getType().isIndex();
183   return false;
184 }
185 
186 //===----------------------------------------------------------------------===//
187 // AddIOp
188 //===----------------------------------------------------------------------===//
189 
190 OpFoldResult arith::AddIOp::fold(ArrayRef<Attribute> operands) {
191   // addi(x, 0) -> x
192   if (matchPattern(getRhs(), m_Zero()))
193     return getLhs();
194 
195   return constFoldBinaryOp<IntegerAttr>(operands,
196                                         [](APInt a, APInt b) { return a + b; });
197 }
198 
199 void arith::AddIOp::getCanonicalizationPatterns(
200     OwningRewritePatternList &patterns, MLIRContext *context) {
201   patterns.insert<AddIAddConstant, AddISubConstantRHS, AddISubConstantLHS>(
202       context);
203 }
204 
205 //===----------------------------------------------------------------------===//
206 // SubIOp
207 //===----------------------------------------------------------------------===//
208 
209 OpFoldResult arith::SubIOp::fold(ArrayRef<Attribute> operands) {
210   // subi(x,x) -> 0
211   if (getOperand(0) == getOperand(1))
212     return Builder(getContext()).getZeroAttr(getType());
213   // subi(x,0) -> x
214   if (matchPattern(getRhs(), m_Zero()))
215     return getLhs();
216 
217   return constFoldBinaryOp<IntegerAttr>(operands,
218                                         [](APInt a, APInt b) { return a - b; });
219 }
220 
221 void arith::SubIOp::getCanonicalizationPatterns(
222     OwningRewritePatternList &patterns, MLIRContext *context) {
223   patterns.insert<SubIRHSAddConstant, SubILHSAddConstant, SubIRHSSubConstantRHS,
224                   SubIRHSSubConstantLHS, SubILHSSubConstantRHS,
225                   SubILHSSubConstantLHS>(context);
226 }
227 
228 //===----------------------------------------------------------------------===//
229 // MulIOp
230 //===----------------------------------------------------------------------===//
231 
232 OpFoldResult arith::MulIOp::fold(ArrayRef<Attribute> operands) {
233   // muli(x, 0) -> 0
234   if (matchPattern(getRhs(), m_Zero()))
235     return getRhs();
236   // muli(x, 1) -> x
237   if (matchPattern(getRhs(), m_One()))
238     return getOperand(0);
239   // TODO: Handle the overflow case.
240 
241   // default folder
242   return constFoldBinaryOp<IntegerAttr>(operands,
243                                         [](APInt a, APInt b) { return a * b; });
244 }
245 
246 //===----------------------------------------------------------------------===//
247 // DivUIOp
248 //===----------------------------------------------------------------------===//
249 
250 OpFoldResult arith::DivUIOp::fold(ArrayRef<Attribute> operands) {
251   // Don't fold if it would require a division by zero.
252   bool div0 = false;
253   auto result = constFoldBinaryOp<IntegerAttr>(operands, [&](APInt a, APInt b) {
254     if (div0 || !b) {
255       div0 = true;
256       return a;
257     }
258     return a.udiv(b);
259   });
260 
261   // Fold out division by one. Assumes all tensors of all ones are splats.
262   if (auto rhs = operands[1].dyn_cast_or_null<IntegerAttr>()) {
263     if (rhs.getValue() == 1)
264       return getLhs();
265   } else if (auto rhs = operands[1].dyn_cast_or_null<SplatElementsAttr>()) {
266     if (rhs.getSplatValue<IntegerAttr>().getValue() == 1)
267       return getLhs();
268   }
269 
270   return div0 ? Attribute() : result;
271 }
272 
273 //===----------------------------------------------------------------------===//
274 // DivSIOp
275 //===----------------------------------------------------------------------===//
276 
277 OpFoldResult arith::DivSIOp::fold(ArrayRef<Attribute> operands) {
278   // Don't fold if it would overflow or if it requires a division by zero.
279   bool overflowOrDiv0 = false;
280   auto result = constFoldBinaryOp<IntegerAttr>(operands, [&](APInt a, APInt b) {
281     if (overflowOrDiv0 || !b) {
282       overflowOrDiv0 = true;
283       return a;
284     }
285     return a.sdiv_ov(b, overflowOrDiv0);
286   });
287 
288   // Fold out division by one. Assumes all tensors of all ones are splats.
289   if (auto rhs = operands[1].dyn_cast_or_null<IntegerAttr>()) {
290     if (rhs.getValue() == 1)
291       return getLhs();
292   } else if (auto rhs = operands[1].dyn_cast_or_null<SplatElementsAttr>()) {
293     if (rhs.getSplatValue<IntegerAttr>().getValue() == 1)
294       return getLhs();
295   }
296 
297   return overflowOrDiv0 ? Attribute() : result;
298 }
299 
300 //===----------------------------------------------------------------------===//
301 // Ceil and floor division folding helpers
302 //===----------------------------------------------------------------------===//
303 
304 static APInt signedCeilNonnegInputs(APInt a, APInt b, bool &overflow) {
305   // Returns (a-1)/b + 1
306   APInt one(a.getBitWidth(), 1, true); // Signed value 1.
307   APInt val = a.ssub_ov(one, overflow).sdiv_ov(b, overflow);
308   return val.sadd_ov(one, overflow);
309 }
310 
311 //===----------------------------------------------------------------------===//
312 // CeilDivUIOp
313 //===----------------------------------------------------------------------===//
314 
315 OpFoldResult arith::CeilDivUIOp::fold(ArrayRef<Attribute> operands) {
316   bool overflowOrDiv0 = false;
317   auto result = constFoldBinaryOp<IntegerAttr>(operands, [&](APInt a, APInt b) {
318     if (overflowOrDiv0 || !b) {
319       overflowOrDiv0 = true;
320       return a;
321     }
322     APInt quotient = a.udiv(b);
323     if (!a.urem(b))
324       return quotient;
325     APInt one(a.getBitWidth(), 1, true);
326     return quotient.uadd_ov(one, overflowOrDiv0);
327   });
328   // Fold out ceil division by one. Assumes all tensors of all ones are
329   // splats.
330   if (auto rhs = operands[1].dyn_cast_or_null<IntegerAttr>()) {
331     if (rhs.getValue() == 1)
332       return getLhs();
333   } else if (auto rhs = operands[1].dyn_cast_or_null<SplatElementsAttr>()) {
334     if (rhs.getSplatValue<IntegerAttr>().getValue() == 1)
335       return getLhs();
336   }
337 
338   return overflowOrDiv0 ? Attribute() : result;
339 }
340 
341 //===----------------------------------------------------------------------===//
342 // CeilDivSIOp
343 //===----------------------------------------------------------------------===//
344 
345 OpFoldResult arith::CeilDivSIOp::fold(ArrayRef<Attribute> operands) {
346   // Don't fold if it would overflow or if it requires a division by zero.
347   bool overflowOrDiv0 = false;
348   auto result = constFoldBinaryOp<IntegerAttr>(operands, [&](APInt a, APInt b) {
349     if (overflowOrDiv0 || !b) {
350       overflowOrDiv0 = true;
351       return a;
352     }
353     unsigned bits = a.getBitWidth();
354     APInt zero = APInt::getZero(bits);
355     if (a.sgt(zero) && b.sgt(zero)) {
356       // Both positive, return ceil(a, b).
357       return signedCeilNonnegInputs(a, b, overflowOrDiv0);
358     }
359     if (a.slt(zero) && b.slt(zero)) {
360       // Both negative, return ceil(-a, -b).
361       APInt posA = zero.ssub_ov(a, overflowOrDiv0);
362       APInt posB = zero.ssub_ov(b, overflowOrDiv0);
363       return signedCeilNonnegInputs(posA, posB, overflowOrDiv0);
364     }
365     if (a.slt(zero) && b.sgt(zero)) {
366       // A is negative, b is positive, return - ( -a / b).
367       APInt posA = zero.ssub_ov(a, overflowOrDiv0);
368       APInt div = posA.sdiv_ov(b, overflowOrDiv0);
369       return zero.ssub_ov(div, overflowOrDiv0);
370     }
371     // A is positive (or zero), b is negative, return - (a / -b).
372     APInt posB = zero.ssub_ov(b, overflowOrDiv0);
373     APInt div = a.sdiv_ov(posB, overflowOrDiv0);
374     return zero.ssub_ov(div, overflowOrDiv0);
375   });
376 
377   // Fold out ceil division by one. Assumes all tensors of all ones are
378   // splats.
379   if (auto rhs = operands[1].dyn_cast_or_null<IntegerAttr>()) {
380     if (rhs.getValue() == 1)
381       return getLhs();
382   } else if (auto rhs = operands[1].dyn_cast_or_null<SplatElementsAttr>()) {
383     if (rhs.getSplatValue<IntegerAttr>().getValue() == 1)
384       return getLhs();
385   }
386 
387   return overflowOrDiv0 ? Attribute() : result;
388 }
389 
390 //===----------------------------------------------------------------------===//
391 // FloorDivSIOp
392 //===----------------------------------------------------------------------===//
393 
394 OpFoldResult arith::FloorDivSIOp::fold(ArrayRef<Attribute> operands) {
395   // Don't fold if it would overflow or if it requires a division by zero.
396   bool overflowOrDiv0 = false;
397   auto result = constFoldBinaryOp<IntegerAttr>(operands, [&](APInt a, APInt b) {
398     if (overflowOrDiv0 || !b) {
399       overflowOrDiv0 = true;
400       return a;
401     }
402     unsigned bits = a.getBitWidth();
403     APInt zero = APInt::getZero(bits);
404     if (a.sge(zero) && b.sgt(zero)) {
405       // Both positive (or a is zero), return a / b.
406       return a.sdiv_ov(b, overflowOrDiv0);
407     }
408     if (a.sle(zero) && b.slt(zero)) {
409       // Both negative (or a is zero), return -a / -b.
410       APInt posA = zero.ssub_ov(a, overflowOrDiv0);
411       APInt posB = zero.ssub_ov(b, overflowOrDiv0);
412       return posA.sdiv_ov(posB, overflowOrDiv0);
413     }
414     if (a.slt(zero) && b.sgt(zero)) {
415       // A is negative, b is positive, return - ceil(-a, b).
416       APInt posA = zero.ssub_ov(a, overflowOrDiv0);
417       APInt ceil = signedCeilNonnegInputs(posA, b, overflowOrDiv0);
418       return zero.ssub_ov(ceil, overflowOrDiv0);
419     }
420     // A is positive, b is negative, return - ceil(a, -b).
421     APInt posB = zero.ssub_ov(b, overflowOrDiv0);
422     APInt ceil = signedCeilNonnegInputs(a, posB, overflowOrDiv0);
423     return zero.ssub_ov(ceil, overflowOrDiv0);
424   });
425 
426   // Fold out floor division by one. Assumes all tensors of all ones are
427   // splats.
428   if (auto rhs = operands[1].dyn_cast_or_null<IntegerAttr>()) {
429     if (rhs.getValue() == 1)
430       return getLhs();
431   } else if (auto rhs = operands[1].dyn_cast_or_null<SplatElementsAttr>()) {
432     if (rhs.getSplatValue<IntegerAttr>().getValue() == 1)
433       return getLhs();
434   }
435 
436   return overflowOrDiv0 ? Attribute() : result;
437 }
438 
439 //===----------------------------------------------------------------------===//
440 // RemUIOp
441 //===----------------------------------------------------------------------===//
442 
443 OpFoldResult arith::RemUIOp::fold(ArrayRef<Attribute> operands) {
444   auto rhs = operands.back().dyn_cast_or_null<IntegerAttr>();
445   if (!rhs)
446     return {};
447   auto rhsValue = rhs.getValue();
448 
449   // x % 1 = 0
450   if (rhsValue.isOneValue())
451     return IntegerAttr::get(rhs.getType(), APInt(rhsValue.getBitWidth(), 0));
452 
453   // Don't fold if it requires division by zero.
454   if (rhsValue.isNullValue())
455     return {};
456 
457   auto lhs = operands.front().dyn_cast_or_null<IntegerAttr>();
458   if (!lhs)
459     return {};
460   return IntegerAttr::get(lhs.getType(), lhs.getValue().urem(rhsValue));
461 }
462 
463 //===----------------------------------------------------------------------===//
464 // RemSIOp
465 //===----------------------------------------------------------------------===//
466 
467 OpFoldResult arith::RemSIOp::fold(ArrayRef<Attribute> operands) {
468   auto rhs = operands.back().dyn_cast_or_null<IntegerAttr>();
469   if (!rhs)
470     return {};
471   auto rhsValue = rhs.getValue();
472 
473   // x % 1 = 0
474   if (rhsValue.isOneValue())
475     return IntegerAttr::get(rhs.getType(), APInt(rhsValue.getBitWidth(), 0));
476 
477   // Don't fold if it requires division by zero.
478   if (rhsValue.isNullValue())
479     return {};
480 
481   auto lhs = operands.front().dyn_cast_or_null<IntegerAttr>();
482   if (!lhs)
483     return {};
484   return IntegerAttr::get(lhs.getType(), lhs.getValue().srem(rhsValue));
485 }
486 
487 //===----------------------------------------------------------------------===//
488 // AndIOp
489 //===----------------------------------------------------------------------===//
490 
491 OpFoldResult arith::AndIOp::fold(ArrayRef<Attribute> operands) {
492   /// and(x, 0) -> 0
493   if (matchPattern(getRhs(), m_Zero()))
494     return getRhs();
495   /// and(x, allOnes) -> x
496   APInt intValue;
497   if (matchPattern(getRhs(), m_ConstantInt(&intValue)) && intValue.isAllOnes())
498     return getLhs();
499 
500   return constFoldBinaryOp<IntegerAttr>(operands,
501                                         [](APInt a, APInt b) { return a & b; });
502 }
503 
504 //===----------------------------------------------------------------------===//
505 // OrIOp
506 //===----------------------------------------------------------------------===//
507 
508 OpFoldResult arith::OrIOp::fold(ArrayRef<Attribute> operands) {
509   /// or(x, 0) -> x
510   if (matchPattern(getRhs(), m_Zero()))
511     return getLhs();
512   /// or(x, <all ones>) -> <all ones>
513   if (auto rhsAttr = operands[1].dyn_cast_or_null<IntegerAttr>())
514     if (rhsAttr.getValue().isAllOnes())
515       return rhsAttr;
516 
517   return constFoldBinaryOp<IntegerAttr>(operands,
518                                         [](APInt a, APInt b) { return a | b; });
519 }
520 
521 //===----------------------------------------------------------------------===//
522 // XOrIOp
523 //===----------------------------------------------------------------------===//
524 
525 OpFoldResult arith::XOrIOp::fold(ArrayRef<Attribute> operands) {
526   /// xor(x, 0) -> x
527   if (matchPattern(getRhs(), m_Zero()))
528     return getLhs();
529   /// xor(x, x) -> 0
530   if (getLhs() == getRhs())
531     return Builder(getContext()).getZeroAttr(getType());
532 
533   return constFoldBinaryOp<IntegerAttr>(operands,
534                                         [](APInt a, APInt b) { return a ^ b; });
535 }
536 
537 void arith::XOrIOp::getCanonicalizationPatterns(
538     OwningRewritePatternList &patterns, MLIRContext *context) {
539   patterns.insert<XOrINotCmpI>(context);
540 }
541 
542 //===----------------------------------------------------------------------===//
543 // AddFOp
544 //===----------------------------------------------------------------------===//
545 
546 OpFoldResult arith::AddFOp::fold(ArrayRef<Attribute> operands) {
547   return constFoldBinaryOp<FloatAttr>(
548       operands, [](APFloat a, APFloat b) { return a + b; });
549 }
550 
551 //===----------------------------------------------------------------------===//
552 // SubFOp
553 //===----------------------------------------------------------------------===//
554 
555 OpFoldResult arith::SubFOp::fold(ArrayRef<Attribute> operands) {
556   return constFoldBinaryOp<FloatAttr>(
557       operands, [](APFloat a, APFloat b) { return a - b; });
558 }
559 
560 //===----------------------------------------------------------------------===//
561 // MaxSIOp
562 //===----------------------------------------------------------------------===//
563 
564 OpFoldResult MaxSIOp::fold(ArrayRef<Attribute> operands) {
565   assert(operands.size() == 2 && "binary operation takes two operands");
566 
567   // maxsi(x,x) -> x
568   if (getLhs() == getRhs())
569     return getRhs();
570 
571   APInt intValue;
572   // maxsi(x,MAX_INT) -> MAX_INT
573   if (matchPattern(getRhs(), m_ConstantInt(&intValue)) &&
574       intValue.isMaxSignedValue())
575     return getRhs();
576 
577   // maxsi(x, MIN_INT) -> x
578   if (matchPattern(getRhs(), m_ConstantInt(&intValue)) &&
579       intValue.isMinSignedValue())
580     return getLhs();
581 
582   return constFoldBinaryOp<IntegerAttr>(
583       operands, [](APInt a, APInt b) { return llvm::APIntOps::smax(a, b); });
584 }
585 
586 //===----------------------------------------------------------------------===//
587 // MaxUIOp
588 //===----------------------------------------------------------------------===//
589 
590 OpFoldResult MaxUIOp::fold(ArrayRef<Attribute> operands) {
591   assert(operands.size() == 2 && "binary operation takes two operands");
592 
593   // maxui(x,x) -> x
594   if (getLhs() == getRhs())
595     return getRhs();
596 
597   APInt intValue;
598   // maxui(x,MAX_INT) -> MAX_INT
599   if (matchPattern(getRhs(), m_ConstantInt(&intValue)) && intValue.isMaxValue())
600     return getRhs();
601 
602   // maxui(x, MIN_INT) -> x
603   if (matchPattern(getRhs(), m_ConstantInt(&intValue)) && intValue.isMinValue())
604     return getLhs();
605 
606   return constFoldBinaryOp<IntegerAttr>(
607       operands, [](APInt a, APInt b) { return llvm::APIntOps::umax(a, b); });
608 }
609 
610 //===----------------------------------------------------------------------===//
611 // MinSIOp
612 //===----------------------------------------------------------------------===//
613 
614 OpFoldResult MinSIOp::fold(ArrayRef<Attribute> operands) {
615   assert(operands.size() == 2 && "binary operation takes two operands");
616 
617   // minsi(x,x) -> x
618   if (getLhs() == getRhs())
619     return getRhs();
620 
621   APInt intValue;
622   // minsi(x,MIN_INT) -> MIN_INT
623   if (matchPattern(getRhs(), m_ConstantInt(&intValue)) &&
624       intValue.isMinSignedValue())
625     return getRhs();
626 
627   // minsi(x, MAX_INT) -> x
628   if (matchPattern(getRhs(), m_ConstantInt(&intValue)) &&
629       intValue.isMaxSignedValue())
630     return getLhs();
631 
632   return constFoldBinaryOp<IntegerAttr>(
633       operands, [](APInt a, APInt b) { return llvm::APIntOps::smin(a, b); });
634 }
635 
636 //===----------------------------------------------------------------------===//
637 // MinUIOp
638 //===----------------------------------------------------------------------===//
639 
640 OpFoldResult MinUIOp::fold(ArrayRef<Attribute> operands) {
641   assert(operands.size() == 2 && "binary operation takes two operands");
642 
643   // minui(x,x) -> x
644   if (getLhs() == getRhs())
645     return getRhs();
646 
647   APInt intValue;
648   // minui(x,MIN_INT) -> MIN_INT
649   if (matchPattern(getRhs(), m_ConstantInt(&intValue)) && intValue.isMinValue())
650     return getRhs();
651 
652   // minui(x, MAX_INT) -> x
653   if (matchPattern(getRhs(), m_ConstantInt(&intValue)) && intValue.isMaxValue())
654     return getLhs();
655 
656   return constFoldBinaryOp<IntegerAttr>(
657       operands, [](APInt a, APInt b) { return llvm::APIntOps::umin(a, b); });
658 }
659 
660 //===----------------------------------------------------------------------===//
661 // MulFOp
662 //===----------------------------------------------------------------------===//
663 
664 OpFoldResult arith::MulFOp::fold(ArrayRef<Attribute> operands) {
665   return constFoldBinaryOp<FloatAttr>(
666       operands, [](APFloat a, APFloat b) { return a * b; });
667 }
668 
669 //===----------------------------------------------------------------------===//
670 // DivFOp
671 //===----------------------------------------------------------------------===//
672 
673 OpFoldResult arith::DivFOp::fold(ArrayRef<Attribute> operands) {
674   return constFoldBinaryOp<FloatAttr>(
675       operands, [](APFloat a, APFloat b) { return a / b; });
676 }
677 
678 //===----------------------------------------------------------------------===//
679 // Utility functions for verifying cast ops
680 //===----------------------------------------------------------------------===//
681 
/// Tag type used to pass a pack of types through ordinary function arguments.
/// Values are created as `type_list<...>()` (a null pointer); only their
/// static type matters, as in getUnderlyingType below.
template <typename... Types>
using type_list = std::tuple<Types...> *;
684 
685 /// Returns a non-null type only if the provided type is one of the allowed
686 /// types or one of the allowed shaped types of the allowed types. Returns the
687 /// element type if a valid shaped type is provided.
688 template <typename... ShapedTypes, typename... ElementTypes>
689 static Type getUnderlyingType(Type type, type_list<ShapedTypes...>,
690                               type_list<ElementTypes...>) {
691   if (type.isa<ShapedType>() && !type.isa<ShapedTypes...>())
692     return {};
693 
694   auto underlyingType = getElementTypeOrSelf(type);
695   if (!underlyingType.isa<ElementTypes...>())
696     return {};
697 
698   return underlyingType;
699 }
700 
701 /// Get allowed underlying types for vectors and tensors.
702 template <typename... ElementTypes>
703 static Type getTypeIfLike(Type type) {
704   return getUnderlyingType(type, type_list<VectorType, TensorType>(),
705                            type_list<ElementTypes...>());
706 }
707 
708 /// Get allowed underlying types for vectors, tensors, and memrefs.
709 template <typename... ElementTypes>
710 static Type getTypeIfLikeOrMemRef(Type type) {
711   return getUnderlyingType(type,
712                            type_list<VectorType, TensorType, MemRefType>(),
713                            type_list<ElementTypes...>());
714 }
715 
716 static bool areValidCastInputsAndOutputs(TypeRange inputs, TypeRange outputs) {
717   return inputs.size() == 1 && outputs.size() == 1 &&
718          succeeded(verifyCompatibleShapes(inputs.front(), outputs.front()));
719 }
720 
721 //===----------------------------------------------------------------------===//
722 // Verifiers for integer and floating point extension/truncation ops
723 //===----------------------------------------------------------------------===//
724 
725 // Extend ops can only extend to a wider type.
726 template <typename ValType, typename Op>
727 static LogicalResult verifyExtOp(Op op) {
728   Type srcType = getElementTypeOrSelf(op.getIn().getType());
729   Type dstType = getElementTypeOrSelf(op.getType());
730 
731   if (srcType.cast<ValType>().getWidth() >= dstType.cast<ValType>().getWidth())
732     return op.emitError("result type ")
733            << dstType << " must be wider than operand type " << srcType;
734 
735   return success();
736 }
737 
738 // Truncate ops can only truncate to a shorter type.
739 template <typename ValType, typename Op>
740 static LogicalResult verifyTruncateOp(Op op) {
741   Type srcType = getElementTypeOrSelf(op.getIn().getType());
742   Type dstType = getElementTypeOrSelf(op.getType());
743 
744   if (srcType.cast<ValType>().getWidth() <= dstType.cast<ValType>().getWidth())
745     return op.emitError("result type ")
746            << dstType << " must be shorter than operand type " << srcType;
747 
748   return success();
749 }
750 
751 /// Validate a cast that changes the width of a type.
752 template <template <typename> class WidthComparator, typename... ElementTypes>
753 static bool checkWidthChangeCast(TypeRange inputs, TypeRange outputs) {
754   if (!areValidCastInputsAndOutputs(inputs, outputs))
755     return false;
756 
757   auto srcType = getTypeIfLike<ElementTypes...>(inputs.front());
758   auto dstType = getTypeIfLike<ElementTypes...>(outputs.front());
759   if (!srcType || !dstType)
760     return false;
761 
762   return WidthComparator<unsigned>()(dstType.getIntOrFloatBitWidth(),
763                                      srcType.getIntOrFloatBitWidth());
764 }
765 
766 //===----------------------------------------------------------------------===//
767 // ExtUIOp
768 //===----------------------------------------------------------------------===//
769 
770 OpFoldResult arith::ExtUIOp::fold(ArrayRef<Attribute> operands) {
771   if (auto lhs = operands[0].dyn_cast_or_null<IntegerAttr>())
772     return IntegerAttr::get(
773         getType(), lhs.getValue().zext(getType().getIntOrFloatBitWidth()));
774 
775   return {};
776 }
777 
778 bool arith::ExtUIOp::areCastCompatible(TypeRange inputs, TypeRange outputs) {
779   return checkWidthChangeCast<std::greater, IntegerType>(inputs, outputs);
780 }
781 
782 //===----------------------------------------------------------------------===//
783 // ExtSIOp
784 //===----------------------------------------------------------------------===//
785 
786 OpFoldResult arith::ExtSIOp::fold(ArrayRef<Attribute> operands) {
787   if (auto lhs = operands[0].dyn_cast_or_null<IntegerAttr>())
788     return IntegerAttr::get(
789         getType(), lhs.getValue().sext(getType().getIntOrFloatBitWidth()));
790 
791   return {};
792 }
793 
794 bool arith::ExtSIOp::areCastCompatible(TypeRange inputs, TypeRange outputs) {
795   return checkWidthChangeCast<std::greater, IntegerType>(inputs, outputs);
796 }
797 
798 //===----------------------------------------------------------------------===//
799 // ExtFOp
800 //===----------------------------------------------------------------------===//
801 
802 bool arith::ExtFOp::areCastCompatible(TypeRange inputs, TypeRange outputs) {
803   return checkWidthChangeCast<std::greater, FloatType>(inputs, outputs);
804 }
805 
806 //===----------------------------------------------------------------------===//
807 // TruncIOp
808 //===----------------------------------------------------------------------===//
809 
810 OpFoldResult arith::TruncIOp::fold(ArrayRef<Attribute> operands) {
811   // trunci(zexti(a)) -> a
812   // trunci(sexti(a)) -> a
813   if (matchPattern(getOperand(), m_Op<arith::ExtUIOp>()) ||
814       matchPattern(getOperand(), m_Op<arith::ExtSIOp>()))
815     return getOperand().getDefiningOp()->getOperand(0);
816 
817   assert(operands.size() == 1 && "unary operation takes one operand");
818 
819   if (!operands[0])
820     return {};
821 
822   if (auto lhs = operands[0].dyn_cast<IntegerAttr>()) {
823     return IntegerAttr::get(
824         getType(), lhs.getValue().trunc(getType().getIntOrFloatBitWidth()));
825   }
826 
827   return {};
828 }
829 
830 bool arith::TruncIOp::areCastCompatible(TypeRange inputs, TypeRange outputs) {
831   return checkWidthChangeCast<std::less, IntegerType>(inputs, outputs);
832 }
833 
834 //===----------------------------------------------------------------------===//
835 // TruncFOp
836 //===----------------------------------------------------------------------===//
837 
838 /// Perform safe const propagation for truncf, i.e. only propagate if FP value
839 /// can be represented without precision loss or rounding.
840 OpFoldResult arith::TruncFOp::fold(ArrayRef<Attribute> operands) {
841   assert(operands.size() == 1 && "unary operation takes one operand");
842 
843   auto constOperand = operands.front();
844   if (!constOperand || !constOperand.isa<FloatAttr>())
845     return {};
846 
847   // Convert to target type via 'double'.
848   double sourceValue =
849       constOperand.dyn_cast<FloatAttr>().getValue().convertToDouble();
850   auto targetAttr = FloatAttr::get(getType(), sourceValue);
851 
852   // Propagate if constant's value does not change after truncation.
853   if (sourceValue == targetAttr.getValue().convertToDouble())
854     return targetAttr;
855 
856   return {};
857 }
858 
859 bool arith::TruncFOp::areCastCompatible(TypeRange inputs, TypeRange outputs) {
860   return checkWidthChangeCast<std::less, FloatType>(inputs, outputs);
861 }
862 
863 //===----------------------------------------------------------------------===//
864 // Verifiers for casts between integers and floats.
865 //===----------------------------------------------------------------------===//
866 
867 template <typename From, typename To>
868 static bool checkIntFloatCast(TypeRange inputs, TypeRange outputs) {
869   if (!areValidCastInputsAndOutputs(inputs, outputs))
870     return false;
871 
872   auto srcType = getTypeIfLike<From>(inputs.front());
873   auto dstType = getTypeIfLike<To>(outputs.back());
874 
875   return srcType && dstType;
876 }
877 
878 //===----------------------------------------------------------------------===//
879 // UIToFPOp
880 //===----------------------------------------------------------------------===//
881 
882 bool arith::UIToFPOp::areCastCompatible(TypeRange inputs, TypeRange outputs) {
883   return checkIntFloatCast<IntegerType, FloatType>(inputs, outputs);
884 }
885 
886 OpFoldResult arith::UIToFPOp::fold(ArrayRef<Attribute> operands) {
887   if (auto lhs = operands[0].dyn_cast_or_null<IntegerAttr>()) {
888     const APInt &api = lhs.getValue();
889     FloatType floatTy = getType().cast<FloatType>();
890     APFloat apf(floatTy.getFloatSemantics(),
891                 APInt::getZero(floatTy.getWidth()));
892     apf.convertFromAPInt(api, /*signed=*/false, APFloat::rmNearestTiesToEven);
893     return FloatAttr::get(floatTy, apf);
894   }
895   return {};
896 }
897 
898 //===----------------------------------------------------------------------===//
899 // SIToFPOp
900 //===----------------------------------------------------------------------===//
901 
902 bool arith::SIToFPOp::areCastCompatible(TypeRange inputs, TypeRange outputs) {
903   return checkIntFloatCast<IntegerType, FloatType>(inputs, outputs);
904 }
905 
906 OpFoldResult arith::SIToFPOp::fold(ArrayRef<Attribute> operands) {
907   if (auto lhs = operands[0].dyn_cast_or_null<IntegerAttr>()) {
908     const APInt &api = lhs.getValue();
909     FloatType floatTy = getType().cast<FloatType>();
910     APFloat apf(floatTy.getFloatSemantics(),
911                 APInt::getZero(floatTy.getWidth()));
912     apf.convertFromAPInt(api, /*signed=*/true, APFloat::rmNearestTiesToEven);
913     return FloatAttr::get(floatTy, apf);
914   }
915   return {};
916 }
917 //===----------------------------------------------------------------------===//
918 // FPToUIOp
919 //===----------------------------------------------------------------------===//
920 
921 bool arith::FPToUIOp::areCastCompatible(TypeRange inputs, TypeRange outputs) {
922   return checkIntFloatCast<FloatType, IntegerType>(inputs, outputs);
923 }
924 
925 OpFoldResult arith::FPToUIOp::fold(ArrayRef<Attribute> operands) {
926   if (auto lhs = operands[0].dyn_cast_or_null<FloatAttr>()) {
927     const APFloat &apf = lhs.getValue();
928     IntegerType intTy = getType().cast<IntegerType>();
929     bool ignored;
930     APSInt api(intTy.getWidth(), /*unsigned=*/true);
931     if (APFloat::opInvalidOp ==
932         apf.convertToInteger(api, APFloat::rmTowardZero, &ignored)) {
933       // Undefined behavior invoked - the destination type can't represent
934       // the input constant.
935       return {};
936     }
937     return IntegerAttr::get(getType(), api);
938   }
939 
940   return {};
941 }
942 
943 //===----------------------------------------------------------------------===//
944 // FPToSIOp
945 //===----------------------------------------------------------------------===//
946 
947 bool arith::FPToSIOp::areCastCompatible(TypeRange inputs, TypeRange outputs) {
948   return checkIntFloatCast<FloatType, IntegerType>(inputs, outputs);
949 }
950 
951 OpFoldResult arith::FPToSIOp::fold(ArrayRef<Attribute> operands) {
952   if (auto lhs = operands[0].dyn_cast_or_null<FloatAttr>()) {
953     const APFloat &apf = lhs.getValue();
954     IntegerType intTy = getType().cast<IntegerType>();
955     bool ignored;
956     APSInt api(intTy.getWidth(), /*unsigned=*/false);
957     if (APFloat::opInvalidOp ==
958         apf.convertToInteger(api, APFloat::rmTowardZero, &ignored)) {
959       // Undefined behavior invoked - the destination type can't represent
960       // the input constant.
961       return {};
962     }
963     return IntegerAttr::get(getType(), api);
964   }
965 
966   return {};
967 }
968 
969 //===----------------------------------------------------------------------===//
970 // IndexCastOp
971 //===----------------------------------------------------------------------===//
972 
973 bool arith::IndexCastOp::areCastCompatible(TypeRange inputs,
974                                            TypeRange outputs) {
975   if (!areValidCastInputsAndOutputs(inputs, outputs))
976     return false;
977 
978   auto srcType = getTypeIfLikeOrMemRef<IntegerType, IndexType>(inputs.front());
979   auto dstType = getTypeIfLikeOrMemRef<IntegerType, IndexType>(outputs.front());
980   if (!srcType || !dstType)
981     return false;
982 
983   return (srcType.isIndex() && dstType.isSignlessInteger()) ||
984          (srcType.isSignlessInteger() && dstType.isIndex());
985 }
986 
987 OpFoldResult arith::IndexCastOp::fold(ArrayRef<Attribute> operands) {
988   // index_cast(constant) -> constant
989   // A little hack because we go through int. Otherwise, the size of the
990   // constant might need to change.
991   if (auto value = operands[0].dyn_cast_or_null<IntegerAttr>())
992     return IntegerAttr::get(getType(), value.getInt());
993 
994   return {};
995 }
996 
997 void arith::IndexCastOp::getCanonicalizationPatterns(
998     OwningRewritePatternList &patterns, MLIRContext *context) {
999   patterns.insert<IndexCastOfIndexCast, IndexCastOfExtSI>(context);
1000 }
1001 
1002 //===----------------------------------------------------------------------===//
1003 // BitcastOp
1004 //===----------------------------------------------------------------------===//
1005 
1006 bool arith::BitcastOp::areCastCompatible(TypeRange inputs, TypeRange outputs) {
1007   if (!areValidCastInputsAndOutputs(inputs, outputs))
1008     return false;
1009 
1010   auto srcType =
1011       getTypeIfLikeOrMemRef<IntegerType, IndexType, FloatType>(inputs.front());
1012   auto dstType =
1013       getTypeIfLikeOrMemRef<IntegerType, IndexType, FloatType>(outputs.front());
1014   if (!srcType || !dstType)
1015     return false;
1016 
1017   return srcType.getIntOrFloatBitWidth() == dstType.getIntOrFloatBitWidth();
1018 }
1019 
1020 OpFoldResult arith::BitcastOp::fold(ArrayRef<Attribute> operands) {
1021   assert(operands.size() == 1 && "bitcast op expects 1 operand");
1022 
1023   auto resType = getType();
1024   auto operand = operands[0];
1025   if (!operand)
1026     return {};
1027 
1028   /// Bitcast dense elements.
1029   if (auto denseAttr = operand.dyn_cast_or_null<DenseElementsAttr>())
1030     return denseAttr.bitcast(resType.cast<ShapedType>().getElementType());
1031   /// Other shaped types unhandled.
1032   if (resType.isa<ShapedType>())
1033     return {};
1034 
1035   /// Bitcast integer or float to integer or float.
1036   APInt bits = operand.isa<FloatAttr>()
1037                    ? operand.cast<FloatAttr>().getValue().bitcastToAPInt()
1038                    : operand.cast<IntegerAttr>().getValue();
1039 
1040   if (auto resFloatType = resType.dyn_cast<FloatType>())
1041     return FloatAttr::get(resType,
1042                           APFloat(resFloatType.getFloatSemantics(), bits));
1043   return IntegerAttr::get(resType, bits);
1044 }
1045 
1046 void arith::BitcastOp::getCanonicalizationPatterns(
1047     OwningRewritePatternList &patterns, MLIRContext *context) {
1048   patterns.insert<BitcastOfBitcast>(context);
1049 }
1050 
1051 //===----------------------------------------------------------------------===//
1052 // Helpers for compare ops
1053 //===----------------------------------------------------------------------===//
1054 
1055 /// Return the type of the same shape (scalar, vector or tensor) containing i1.
1056 static Type getI1SameShape(Type type) {
1057   auto i1Type = IntegerType::get(type.getContext(), 1);
1058   if (auto tensorType = type.dyn_cast<RankedTensorType>())
1059     return RankedTensorType::get(tensorType.getShape(), i1Type);
1060   if (type.isa<UnrankedTensorType>())
1061     return UnrankedTensorType::get(i1Type);
1062   if (auto vectorType = type.dyn_cast<VectorType>())
1063     return VectorType::get(vectorType.getShape(), i1Type,
1064                            vectorType.getNumScalableDims());
1065   return i1Type;
1066 }
1067 
1068 //===----------------------------------------------------------------------===//
1069 // CmpIOp
1070 //===----------------------------------------------------------------------===//
1071 
1072 /// Compute `lhs` `pred` `rhs`, where `pred` is one of the known integer
1073 /// comparison predicates.
1074 bool mlir::arith::applyCmpPredicate(arith::CmpIPredicate predicate,
1075                                     const APInt &lhs, const APInt &rhs) {
1076   switch (predicate) {
1077   case arith::CmpIPredicate::eq:
1078     return lhs.eq(rhs);
1079   case arith::CmpIPredicate::ne:
1080     return lhs.ne(rhs);
1081   case arith::CmpIPredicate::slt:
1082     return lhs.slt(rhs);
1083   case arith::CmpIPredicate::sle:
1084     return lhs.sle(rhs);
1085   case arith::CmpIPredicate::sgt:
1086     return lhs.sgt(rhs);
1087   case arith::CmpIPredicate::sge:
1088     return lhs.sge(rhs);
1089   case arith::CmpIPredicate::ult:
1090     return lhs.ult(rhs);
1091   case arith::CmpIPredicate::ule:
1092     return lhs.ule(rhs);
1093   case arith::CmpIPredicate::ugt:
1094     return lhs.ugt(rhs);
1095   case arith::CmpIPredicate::uge:
1096     return lhs.uge(rhs);
1097   }
1098   llvm_unreachable("unknown cmpi predicate kind");
1099 }
1100 
1101 /// Returns true if the predicate is true for two equal operands.
1102 static bool applyCmpPredicateToEqualOperands(arith::CmpIPredicate predicate) {
1103   switch (predicate) {
1104   case arith::CmpIPredicate::eq:
1105   case arith::CmpIPredicate::sle:
1106   case arith::CmpIPredicate::sge:
1107   case arith::CmpIPredicate::ule:
1108   case arith::CmpIPredicate::uge:
1109     return true;
1110   case arith::CmpIPredicate::ne:
1111   case arith::CmpIPredicate::slt:
1112   case arith::CmpIPredicate::sgt:
1113   case arith::CmpIPredicate::ult:
1114   case arith::CmpIPredicate::ugt:
1115     return false;
1116   }
1117   llvm_unreachable("unknown cmpi predicate kind");
1118 }
1119 
1120 static Attribute getBoolAttribute(Type type, MLIRContext *ctx, bool value) {
1121   auto boolAttr = BoolAttr::get(ctx, value);
1122   ShapedType shapedType = type.dyn_cast_or_null<ShapedType>();
1123   if (!shapedType)
1124     return boolAttr;
1125   return DenseElementsAttr::get(shapedType, boolAttr);
1126 }
1127 
1128 OpFoldResult arith::CmpIOp::fold(ArrayRef<Attribute> operands) {
1129   assert(operands.size() == 2 && "cmpi takes two operands");
1130 
1131   // cmpi(pred, x, x)
1132   if (getLhs() == getRhs()) {
1133     auto val = applyCmpPredicateToEqualOperands(getPredicate());
1134     return getBoolAttribute(getType(), getContext(), val);
1135   }
1136 
1137   auto lhs = operands.front().dyn_cast_or_null<IntegerAttr>();
1138   auto rhs = operands.back().dyn_cast_or_null<IntegerAttr>();
1139   if (!lhs || !rhs)
1140     return {};
1141 
1142   auto val = applyCmpPredicate(getPredicate(), lhs.getValue(), rhs.getValue());
1143   return BoolAttr::get(getContext(), val);
1144 }
1145 
1146 //===----------------------------------------------------------------------===//
1147 // CmpFOp
1148 //===----------------------------------------------------------------------===//
1149 
1150 /// Compute `lhs` `pred` `rhs`, where `pred` is one of the known floating point
1151 /// comparison predicates.
1152 bool mlir::arith::applyCmpPredicate(arith::CmpFPredicate predicate,
1153                                     const APFloat &lhs, const APFloat &rhs) {
1154   auto cmpResult = lhs.compare(rhs);
1155   switch (predicate) {
1156   case arith::CmpFPredicate::AlwaysFalse:
1157     return false;
1158   case arith::CmpFPredicate::OEQ:
1159     return cmpResult == APFloat::cmpEqual;
1160   case arith::CmpFPredicate::OGT:
1161     return cmpResult == APFloat::cmpGreaterThan;
1162   case arith::CmpFPredicate::OGE:
1163     return cmpResult == APFloat::cmpGreaterThan ||
1164            cmpResult == APFloat::cmpEqual;
1165   case arith::CmpFPredicate::OLT:
1166     return cmpResult == APFloat::cmpLessThan;
1167   case arith::CmpFPredicate::OLE:
1168     return cmpResult == APFloat::cmpLessThan || cmpResult == APFloat::cmpEqual;
1169   case arith::CmpFPredicate::ONE:
1170     return cmpResult != APFloat::cmpUnordered && cmpResult != APFloat::cmpEqual;
1171   case arith::CmpFPredicate::ORD:
1172     return cmpResult != APFloat::cmpUnordered;
1173   case arith::CmpFPredicate::UEQ:
1174     return cmpResult == APFloat::cmpUnordered || cmpResult == APFloat::cmpEqual;
1175   case arith::CmpFPredicate::UGT:
1176     return cmpResult == APFloat::cmpUnordered ||
1177            cmpResult == APFloat::cmpGreaterThan;
1178   case arith::CmpFPredicate::UGE:
1179     return cmpResult == APFloat::cmpUnordered ||
1180            cmpResult == APFloat::cmpGreaterThan ||
1181            cmpResult == APFloat::cmpEqual;
1182   case arith::CmpFPredicate::ULT:
1183     return cmpResult == APFloat::cmpUnordered ||
1184            cmpResult == APFloat::cmpLessThan;
1185   case arith::CmpFPredicate::ULE:
1186     return cmpResult == APFloat::cmpUnordered ||
1187            cmpResult == APFloat::cmpLessThan || cmpResult == APFloat::cmpEqual;
1188   case arith::CmpFPredicate::UNE:
1189     return cmpResult != APFloat::cmpEqual;
1190   case arith::CmpFPredicate::UNO:
1191     return cmpResult == APFloat::cmpUnordered;
1192   case arith::CmpFPredicate::AlwaysTrue:
1193     return true;
1194   }
1195   llvm_unreachable("unknown cmpf predicate kind");
1196 }
1197 
1198 OpFoldResult arith::CmpFOp::fold(ArrayRef<Attribute> operands) {
1199   assert(operands.size() == 2 && "cmpf takes two operands");
1200 
1201   auto lhs = operands.front().dyn_cast_or_null<FloatAttr>();
1202   auto rhs = operands.back().dyn_cast_or_null<FloatAttr>();
1203 
1204   if (!lhs || !rhs)
1205     return {};
1206 
1207   auto val = applyCmpPredicate(getPredicate(), lhs.getValue(), rhs.getValue());
1208   return BoolAttr::get(getContext(), val);
1209 }
1210 
1211 //===----------------------------------------------------------------------===//
1212 // Atomic Enum
1213 //===----------------------------------------------------------------------===//
1214 
1215 /// Returns the identity value attribute associated with an AtomicRMWKind op.
1216 Attribute mlir::arith::getIdentityValueAttr(AtomicRMWKind kind, Type resultType,
1217                                             OpBuilder &builder, Location loc) {
1218   switch (kind) {
1219   case AtomicRMWKind::maxf:
1220     return builder.getFloatAttr(
1221         resultType,
1222         APFloat::getInf(resultType.cast<FloatType>().getFloatSemantics(),
1223                         /*Negative=*/true));
1224   case AtomicRMWKind::addf:
1225   case AtomicRMWKind::addi:
1226   case AtomicRMWKind::maxu:
1227   case AtomicRMWKind::ori:
1228     return builder.getZeroAttr(resultType);
1229   case AtomicRMWKind::andi:
1230     return builder.getIntegerAttr(
1231         resultType,
1232         APInt::getAllOnes(resultType.cast<IntegerType>().getWidth()));
1233   case AtomicRMWKind::maxs:
1234     return builder.getIntegerAttr(
1235         resultType,
1236         APInt::getSignedMinValue(resultType.cast<IntegerType>().getWidth()));
1237   case AtomicRMWKind::minf:
1238     return builder.getFloatAttr(
1239         resultType,
1240         APFloat::getInf(resultType.cast<FloatType>().getFloatSemantics(),
1241                         /*Negative=*/false));
1242   case AtomicRMWKind::mins:
1243     return builder.getIntegerAttr(
1244         resultType,
1245         APInt::getSignedMaxValue(resultType.cast<IntegerType>().getWidth()));
1246   case AtomicRMWKind::minu:
1247     return builder.getIntegerAttr(
1248         resultType,
1249         APInt::getMaxValue(resultType.cast<IntegerType>().getWidth()));
1250   case AtomicRMWKind::muli:
1251     return builder.getIntegerAttr(resultType, 1);
1252   case AtomicRMWKind::mulf:
1253     return builder.getFloatAttr(resultType, 1);
1254   // TODO: Add remaining reduction operations.
1255   default:
1256     (void)emitOptionalError(loc, "Reduction operation type not supported");
1257     break;
1258   }
1259   return nullptr;
1260 }
1261 
1262 /// Returns the identity value associated with an AtomicRMWKind op.
1263 Value mlir::arith::getIdentityValue(AtomicRMWKind op, Type resultType,
1264                                     OpBuilder &builder, Location loc) {
1265   Attribute attr = getIdentityValueAttr(op, resultType, builder, loc);
1266   return builder.create<arith::ConstantOp>(loc, attr);
1267 }
1268 
1269 /// Return the value obtained by applying the reduction operation kind
1270 /// associated with a binary AtomicRMWKind op to `lhs` and `rhs`.
1271 Value mlir::arith::getReductionOp(AtomicRMWKind op, OpBuilder &builder,
1272                                   Location loc, Value lhs, Value rhs) {
1273   switch (op) {
1274   case AtomicRMWKind::addf:
1275     return builder.create<arith::AddFOp>(loc, lhs, rhs);
1276   case AtomicRMWKind::addi:
1277     return builder.create<arith::AddIOp>(loc, lhs, rhs);
1278   case AtomicRMWKind::mulf:
1279     return builder.create<arith::MulFOp>(loc, lhs, rhs);
1280   case AtomicRMWKind::muli:
1281     return builder.create<arith::MulIOp>(loc, lhs, rhs);
1282   case AtomicRMWKind::maxf:
1283     return builder.create<arith::MaxFOp>(loc, lhs, rhs);
1284   case AtomicRMWKind::minf:
1285     return builder.create<arith::MinFOp>(loc, lhs, rhs);
1286   case AtomicRMWKind::maxs:
1287     return builder.create<arith::MaxSIOp>(loc, lhs, rhs);
1288   case AtomicRMWKind::mins:
1289     return builder.create<arith::MinSIOp>(loc, lhs, rhs);
1290   case AtomicRMWKind::maxu:
1291     return builder.create<arith::MaxUIOp>(loc, lhs, rhs);
1292   case AtomicRMWKind::minu:
1293     return builder.create<arith::MinUIOp>(loc, lhs, rhs);
1294   case AtomicRMWKind::ori:
1295     return builder.create<arith::OrIOp>(loc, lhs, rhs);
1296   case AtomicRMWKind::andi:
1297     return builder.create<arith::AndIOp>(loc, lhs, rhs);
1298   // TODO: Add remaining reduction operations.
1299   default:
1300     (void)emitOptionalError(loc, "Reduction operation type not supported");
1301     break;
1302   }
1303   return nullptr;
1304 }
1305 
1306 //===----------------------------------------------------------------------===//
1307 // TableGen'd op method definitions
1308 //===----------------------------------------------------------------------===//
1309 
1310 #define GET_OP_CLASSES
1311 #include "mlir/Dialect/Arithmetic/IR/ArithmeticOps.cpp.inc"
1312 
1313 //===----------------------------------------------------------------------===//
1314 // TableGen'd enum attribute definitions
1315 //===----------------------------------------------------------------------===//
1316 
1317 #include "mlir/Dialect/Arithmetic/IR/ArithmeticOpsEnums.cpp.inc"
1318