1 //===- ArithmeticOps.cpp - MLIR Arithmetic dialect ops implementation -----===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 
9 #include <utility>
10 
11 #include "mlir/Dialect/Arithmetic/IR/Arithmetic.h"
12 #include "mlir/Dialect/CommonFolders.h"
13 #include "mlir/IR/Builders.h"
14 #include "mlir/IR/Matchers.h"
15 #include "mlir/IR/OpImplementation.h"
16 #include "mlir/IR/PatternMatch.h"
17 #include "mlir/IR/TypeUtilities.h"
18 
19 #include "llvm/ADT/APSInt.h"
20 
21 using namespace mlir;
22 using namespace mlir::arith;
23 
24 //===----------------------------------------------------------------------===//
25 // Pattern helpers
26 //===----------------------------------------------------------------------===//
27 
28 static IntegerAttr addIntegerAttrs(PatternRewriter &builder, Value res,
29                                    Attribute lhs, Attribute rhs) {
30   return builder.getIntegerAttr(res.getType(),
31                                 lhs.cast<IntegerAttr>().getInt() +
32                                     rhs.cast<IntegerAttr>().getInt());
33 }
34 
35 static IntegerAttr subIntegerAttrs(PatternRewriter &builder, Value res,
36                                    Attribute lhs, Attribute rhs) {
37   return builder.getIntegerAttr(res.getType(),
38                                 lhs.cast<IntegerAttr>().getInt() -
39                                     rhs.cast<IntegerAttr>().getInt());
40 }
41 
42 /// Invert an integer comparison predicate.
43 static arith::CmpIPredicate invertPredicate(arith::CmpIPredicate pred) {
44   switch (pred) {
45   case arith::CmpIPredicate::eq:
46     return arith::CmpIPredicate::ne;
47   case arith::CmpIPredicate::ne:
48     return arith::CmpIPredicate::eq;
49   case arith::CmpIPredicate::slt:
50     return arith::CmpIPredicate::sge;
51   case arith::CmpIPredicate::sle:
52     return arith::CmpIPredicate::sgt;
53   case arith::CmpIPredicate::sgt:
54     return arith::CmpIPredicate::sle;
55   case arith::CmpIPredicate::sge:
56     return arith::CmpIPredicate::slt;
57   case arith::CmpIPredicate::ult:
58     return arith::CmpIPredicate::uge;
59   case arith::CmpIPredicate::ule:
60     return arith::CmpIPredicate::ugt;
61   case arith::CmpIPredicate::ugt:
62     return arith::CmpIPredicate::ule;
63   case arith::CmpIPredicate::uge:
64     return arith::CmpIPredicate::ult;
65   }
66   llvm_unreachable("unknown cmpi predicate kind");
67 }
68 
69 static arith::CmpIPredicateAttr invertPredicate(arith::CmpIPredicateAttr pred) {
70   return arith::CmpIPredicateAttr::get(pred.getContext(),
71                                        invertPredicate(pred.getValue()));
72 }
73 
74 //===----------------------------------------------------------------------===//
75 // TableGen'd canonicalization patterns
76 //===----------------------------------------------------------------------===//
77 
78 namespace {
79 #include "ArithmeticCanonicalization.inc"
80 } // namespace
81 
82 //===----------------------------------------------------------------------===//
83 // ConstantOp
84 //===----------------------------------------------------------------------===//
85 
86 void arith::ConstantOp::getAsmResultNames(
87     function_ref<void(Value, StringRef)> setNameFn) {
88   auto type = getType();
89   if (auto intCst = getValue().dyn_cast<IntegerAttr>()) {
90     auto intType = type.dyn_cast<IntegerType>();
91 
92     // Sugar i1 constants with 'true' and 'false'.
93     if (intType && intType.getWidth() == 1)
94       return setNameFn(getResult(), (intCst.getInt() ? "true" : "false"));
95 
96     // Otherwise, build a compex name with the value and type.
97     SmallString<32> specialNameBuffer;
98     llvm::raw_svector_ostream specialName(specialNameBuffer);
99     specialName << 'c' << intCst.getInt();
100     if (intType)
101       specialName << '_' << type;
102     setNameFn(getResult(), specialName.str());
103   } else {
104     setNameFn(getResult(), "cst");
105   }
106 }
107 
108 /// TODO: disallow arith.constant to return anything other than signless integer
109 /// or float like.
110 static LogicalResult verify(arith::ConstantOp op) {
111   auto type = op.getType();
112   // The value's type must match the return type.
113   if (op.getValue().getType() != type) {
114     return op.emitOpError() << "value type " << op.getValue().getType()
115                             << " must match return type: " << type;
116   }
117   // Integer values must be signless.
118   if (type.isa<IntegerType>() && !type.cast<IntegerType>().isSignless())
119     return op.emitOpError("integer return type must be signless");
120   // Any float or elements attribute are acceptable.
121   if (!op.getValue().isa<IntegerAttr, FloatAttr, ElementsAttr>()) {
122     return op.emitOpError(
123         "value must be an integer, float, or elements attribute");
124   }
125   return success();
126 }
127 
128 bool arith::ConstantOp::isBuildableWith(Attribute value, Type type) {
129   // The value's type must be the same as the provided type.
130   if (value.getType() != type)
131     return false;
132   // Integer values must be signless.
133   if (type.isa<IntegerType>() && !type.cast<IntegerType>().isSignless())
134     return false;
135   // Integer, float, and element attributes are buildable.
136   return value.isa<IntegerAttr, FloatAttr, ElementsAttr>();
137 }
138 
139 OpFoldResult arith::ConstantOp::fold(ArrayRef<Attribute> operands) {
140   return getValue();
141 }
142 
143 void arith::ConstantIntOp::build(OpBuilder &builder, OperationState &result,
144                                  int64_t value, unsigned width) {
145   auto type = builder.getIntegerType(width);
146   arith::ConstantOp::build(builder, result, type,
147                            builder.getIntegerAttr(type, value));
148 }
149 
150 void arith::ConstantIntOp::build(OpBuilder &builder, OperationState &result,
151                                  int64_t value, Type type) {
152   assert(type.isSignlessInteger() &&
153          "ConstantIntOp can only have signless integer type values");
154   arith::ConstantOp::build(builder, result, type,
155                            builder.getIntegerAttr(type, value));
156 }
157 
158 bool arith::ConstantIntOp::classof(Operation *op) {
159   if (auto constOp = dyn_cast_or_null<arith::ConstantOp>(op))
160     return constOp.getType().isSignlessInteger();
161   return false;
162 }
163 
164 void arith::ConstantFloatOp::build(OpBuilder &builder, OperationState &result,
165                                    const APFloat &value, FloatType type) {
166   arith::ConstantOp::build(builder, result, type,
167                            builder.getFloatAttr(type, value));
168 }
169 
170 bool arith::ConstantFloatOp::classof(Operation *op) {
171   if (auto constOp = dyn_cast_or_null<arith::ConstantOp>(op))
172     return constOp.getType().isa<FloatType>();
173   return false;
174 }
175 
176 void arith::ConstantIndexOp::build(OpBuilder &builder, OperationState &result,
177                                    int64_t value) {
178   arith::ConstantOp::build(builder, result, builder.getIndexType(),
179                            builder.getIndexAttr(value));
180 }
181 
182 bool arith::ConstantIndexOp::classof(Operation *op) {
183   if (auto constOp = dyn_cast_or_null<arith::ConstantOp>(op))
184     return constOp.getType().isIndex();
185   return false;
186 }
187 
188 //===----------------------------------------------------------------------===//
189 // AddIOp
190 //===----------------------------------------------------------------------===//
191 
192 OpFoldResult arith::AddIOp::fold(ArrayRef<Attribute> operands) {
193   // addi(x, 0) -> x
194   if (matchPattern(getRhs(), m_Zero()))
195     return getLhs();
196 
197   return constFoldBinaryOp<IntegerAttr>(
198       operands, [](APInt a, const APInt &b) { return std::move(a) + b; });
199 }
200 
201 void arith::AddIOp::getCanonicalizationPatterns(
202     OwningRewritePatternList &patterns, MLIRContext *context) {
203   patterns.insert<AddIAddConstant, AddISubConstantRHS, AddISubConstantLHS>(
204       context);
205 }
206 
207 //===----------------------------------------------------------------------===//
208 // SubIOp
209 //===----------------------------------------------------------------------===//
210 
211 OpFoldResult arith::SubIOp::fold(ArrayRef<Attribute> operands) {
212   // subi(x,x) -> 0
213   if (getOperand(0) == getOperand(1))
214     return Builder(getContext()).getZeroAttr(getType());
215   // subi(x,0) -> x
216   if (matchPattern(getRhs(), m_Zero()))
217     return getLhs();
218 
219   return constFoldBinaryOp<IntegerAttr>(
220       operands, [](APInt a, const APInt &b) { return std::move(a) - b; });
221 }
222 
223 void arith::SubIOp::getCanonicalizationPatterns(
224     OwningRewritePatternList &patterns, MLIRContext *context) {
225   patterns.insert<SubIRHSAddConstant, SubILHSAddConstant, SubIRHSSubConstantRHS,
226                   SubIRHSSubConstantLHS, SubILHSSubConstantRHS,
227                   SubILHSSubConstantLHS>(context);
228 }
229 
230 //===----------------------------------------------------------------------===//
231 // MulIOp
232 //===----------------------------------------------------------------------===//
233 
234 OpFoldResult arith::MulIOp::fold(ArrayRef<Attribute> operands) {
235   // muli(x, 0) -> 0
236   if (matchPattern(getRhs(), m_Zero()))
237     return getRhs();
238   // muli(x, 1) -> x
239   if (matchPattern(getRhs(), m_One()))
240     return getOperand(0);
241   // TODO: Handle the overflow case.
242 
243   // default folder
244   return constFoldBinaryOp<IntegerAttr>(
245       operands, [](const APInt &a, const APInt &b) { return a * b; });
246 }
247 
248 //===----------------------------------------------------------------------===//
249 // DivUIOp
250 //===----------------------------------------------------------------------===//
251 
252 OpFoldResult arith::DivUIOp::fold(ArrayRef<Attribute> operands) {
253   // Don't fold if it would require a division by zero.
254   bool div0 = false;
255   auto result =
256       constFoldBinaryOp<IntegerAttr>(operands, [&](APInt a, const APInt &b) {
257         if (div0 || !b) {
258           div0 = true;
259           return a;
260         }
261         return a.udiv(b);
262       });
263 
264   // Fold out division by one. Assumes all tensors of all ones are splats.
265   if (auto rhs = operands[1].dyn_cast_or_null<IntegerAttr>()) {
266     if (rhs.getValue() == 1)
267       return getLhs();
268   } else if (auto rhs = operands[1].dyn_cast_or_null<SplatElementsAttr>()) {
269     if (rhs.getSplatValue<IntegerAttr>().getValue() == 1)
270       return getLhs();
271   }
272 
273   return div0 ? Attribute() : result;
274 }
275 
276 //===----------------------------------------------------------------------===//
277 // DivSIOp
278 //===----------------------------------------------------------------------===//
279 
280 OpFoldResult arith::DivSIOp::fold(ArrayRef<Attribute> operands) {
281   // Don't fold if it would overflow or if it requires a division by zero.
282   bool overflowOrDiv0 = false;
283   auto result =
284       constFoldBinaryOp<IntegerAttr>(operands, [&](APInt a, const APInt &b) {
285         if (overflowOrDiv0 || !b) {
286           overflowOrDiv0 = true;
287           return a;
288         }
289         return a.sdiv_ov(b, overflowOrDiv0);
290       });
291 
292   // Fold out division by one. Assumes all tensors of all ones are splats.
293   if (auto rhs = operands[1].dyn_cast_or_null<IntegerAttr>()) {
294     if (rhs.getValue() == 1)
295       return getLhs();
296   } else if (auto rhs = operands[1].dyn_cast_or_null<SplatElementsAttr>()) {
297     if (rhs.getSplatValue<IntegerAttr>().getValue() == 1)
298       return getLhs();
299   }
300 
301   return overflowOrDiv0 ? Attribute() : result;
302 }
303 
304 //===----------------------------------------------------------------------===//
305 // Ceil and floor division folding helpers
306 //===----------------------------------------------------------------------===//
307 
308 static APInt signedCeilNonnegInputs(const APInt &a, const APInt &b,
309                                     bool &overflow) {
310   // Returns (a-1)/b + 1
311   APInt one(a.getBitWidth(), 1, true); // Signed value 1.
312   APInt val = a.ssub_ov(one, overflow).sdiv_ov(b, overflow);
313   return val.sadd_ov(one, overflow);
314 }
315 
316 //===----------------------------------------------------------------------===//
317 // CeilDivUIOp
318 //===----------------------------------------------------------------------===//
319 
320 OpFoldResult arith::CeilDivUIOp::fold(ArrayRef<Attribute> operands) {
321   bool overflowOrDiv0 = false;
322   auto result =
323       constFoldBinaryOp<IntegerAttr>(operands, [&](APInt a, const APInt &b) {
324         if (overflowOrDiv0 || !b) {
325           overflowOrDiv0 = true;
326           return a;
327         }
328         APInt quotient = a.udiv(b);
329         if (!a.urem(b))
330           return quotient;
331         APInt one(a.getBitWidth(), 1, true);
332         return quotient.uadd_ov(one, overflowOrDiv0);
333       });
334   // Fold out ceil division by one. Assumes all tensors of all ones are
335   // splats.
336   if (auto rhs = operands[1].dyn_cast_or_null<IntegerAttr>()) {
337     if (rhs.getValue() == 1)
338       return getLhs();
339   } else if (auto rhs = operands[1].dyn_cast_or_null<SplatElementsAttr>()) {
340     if (rhs.getSplatValue<IntegerAttr>().getValue() == 1)
341       return getLhs();
342   }
343 
344   return overflowOrDiv0 ? Attribute() : result;
345 }
346 
347 //===----------------------------------------------------------------------===//
348 // CeilDivSIOp
349 //===----------------------------------------------------------------------===//
350 
351 OpFoldResult arith::CeilDivSIOp::fold(ArrayRef<Attribute> operands) {
352   // Don't fold if it would overflow or if it requires a division by zero.
353   bool overflowOrDiv0 = false;
354   auto result =
355       constFoldBinaryOp<IntegerAttr>(operands, [&](APInt a, const APInt &b) {
356         if (overflowOrDiv0 || !b) {
357           overflowOrDiv0 = true;
358           return a;
359         }
360         if (!a)
361           return a;
362         // After this point we know that neither a or b are zero.
363         unsigned bits = a.getBitWidth();
364         APInt zero = APInt::getZero(bits);
365         bool aGtZero = a.sgt(zero);
366         bool bGtZero = b.sgt(zero);
367         if (aGtZero && bGtZero) {
368           // Both positive, return ceil(a, b).
369           return signedCeilNonnegInputs(a, b, overflowOrDiv0);
370         }
371         if (!aGtZero && !bGtZero) {
372           // Both negative, return ceil(-a, -b).
373           APInt posA = zero.ssub_ov(a, overflowOrDiv0);
374           APInt posB = zero.ssub_ov(b, overflowOrDiv0);
375           return signedCeilNonnegInputs(posA, posB, overflowOrDiv0);
376         }
377         if (!aGtZero && bGtZero) {
378           // A is negative, b is positive, return - ( -a / b).
379           APInt posA = zero.ssub_ov(a, overflowOrDiv0);
380           APInt div = posA.sdiv_ov(b, overflowOrDiv0);
381           return zero.ssub_ov(div, overflowOrDiv0);
382         }
383         // A is positive, b is negative, return - (a / -b).
384         APInt posB = zero.ssub_ov(b, overflowOrDiv0);
385         APInt div = a.sdiv_ov(posB, overflowOrDiv0);
386         return zero.ssub_ov(div, overflowOrDiv0);
387       });
388 
389   // Fold out ceil division by one. Assumes all tensors of all ones are
390   // splats.
391   if (auto rhs = operands[1].dyn_cast_or_null<IntegerAttr>()) {
392     if (rhs.getValue() == 1)
393       return getLhs();
394   } else if (auto rhs = operands[1].dyn_cast_or_null<SplatElementsAttr>()) {
395     if (rhs.getSplatValue<IntegerAttr>().getValue() == 1)
396       return getLhs();
397   }
398 
399   return overflowOrDiv0 ? Attribute() : result;
400 }
401 
402 //===----------------------------------------------------------------------===//
403 // FloorDivSIOp
404 //===----------------------------------------------------------------------===//
405 
406 OpFoldResult arith::FloorDivSIOp::fold(ArrayRef<Attribute> operands) {
407   // Don't fold if it would overflow or if it requires a division by zero.
408   bool overflowOrDiv0 = false;
409   auto result =
410       constFoldBinaryOp<IntegerAttr>(operands, [&](APInt a, const APInt &b) {
411         if (overflowOrDiv0 || !b) {
412           overflowOrDiv0 = true;
413           return a;
414         }
415         if (!a)
416           return a;
417         // After this point we know that neither a or b are zero.
418         unsigned bits = a.getBitWidth();
419         APInt zero = APInt::getZero(bits);
420         bool aGtZero = a.sgt(zero);
421         bool bGtZero = b.sgt(zero);
422         if (aGtZero && bGtZero) {
423           // Both positive, return a / b.
424           return a.sdiv_ov(b, overflowOrDiv0);
425         }
426         if (!aGtZero && !bGtZero) {
427           // Both negative, return -a / -b.
428           APInt posA = zero.ssub_ov(a, overflowOrDiv0);
429           APInt posB = zero.ssub_ov(b, overflowOrDiv0);
430           return posA.sdiv_ov(posB, overflowOrDiv0);
431         }
432         if (!aGtZero && bGtZero) {
433           // A is negative, b is positive, return - ceil(-a, b).
434           APInt posA = zero.ssub_ov(a, overflowOrDiv0);
435           APInt ceil = signedCeilNonnegInputs(posA, b, overflowOrDiv0);
436           return zero.ssub_ov(ceil, overflowOrDiv0);
437         }
438         // A is positive, b is negative, return - ceil(a, -b).
439         APInt posB = zero.ssub_ov(b, overflowOrDiv0);
440         APInt ceil = signedCeilNonnegInputs(a, posB, overflowOrDiv0);
441         return zero.ssub_ov(ceil, overflowOrDiv0);
442       });
443 
444   // Fold out floor division by one. Assumes all tensors of all ones are
445   // splats.
446   if (auto rhs = operands[1].dyn_cast_or_null<IntegerAttr>()) {
447     if (rhs.getValue() == 1)
448       return getLhs();
449   } else if (auto rhs = operands[1].dyn_cast_or_null<SplatElementsAttr>()) {
450     if (rhs.getSplatValue<IntegerAttr>().getValue() == 1)
451       return getLhs();
452   }
453 
454   return overflowOrDiv0 ? Attribute() : result;
455 }
456 
457 //===----------------------------------------------------------------------===//
458 // RemUIOp
459 //===----------------------------------------------------------------------===//
460 
461 OpFoldResult arith::RemUIOp::fold(ArrayRef<Attribute> operands) {
462   auto rhs = operands.back().dyn_cast_or_null<IntegerAttr>();
463   if (!rhs)
464     return {};
465   auto rhsValue = rhs.getValue();
466 
467   // x % 1 = 0
468   if (rhsValue.isOneValue())
469     return IntegerAttr::get(rhs.getType(), APInt(rhsValue.getBitWidth(), 0));
470 
471   // Don't fold if it requires division by zero.
472   if (rhsValue.isNullValue())
473     return {};
474 
475   auto lhs = operands.front().dyn_cast_or_null<IntegerAttr>();
476   if (!lhs)
477     return {};
478   return IntegerAttr::get(lhs.getType(), lhs.getValue().urem(rhsValue));
479 }
480 
481 //===----------------------------------------------------------------------===//
482 // RemSIOp
483 //===----------------------------------------------------------------------===//
484 
485 OpFoldResult arith::RemSIOp::fold(ArrayRef<Attribute> operands) {
486   auto rhs = operands.back().dyn_cast_or_null<IntegerAttr>();
487   if (!rhs)
488     return {};
489   auto rhsValue = rhs.getValue();
490 
491   // x % 1 = 0
492   if (rhsValue.isOneValue())
493     return IntegerAttr::get(rhs.getType(), APInt(rhsValue.getBitWidth(), 0));
494 
495   // Don't fold if it requires division by zero.
496   if (rhsValue.isNullValue())
497     return {};
498 
499   auto lhs = operands.front().dyn_cast_or_null<IntegerAttr>();
500   if (!lhs)
501     return {};
502   return IntegerAttr::get(lhs.getType(), lhs.getValue().srem(rhsValue));
503 }
504 
505 //===----------------------------------------------------------------------===//
506 // AndIOp
507 //===----------------------------------------------------------------------===//
508 
509 OpFoldResult arith::AndIOp::fold(ArrayRef<Attribute> operands) {
510   /// and(x, 0) -> 0
511   if (matchPattern(getRhs(), m_Zero()))
512     return getRhs();
513   /// and(x, allOnes) -> x
514   APInt intValue;
515   if (matchPattern(getRhs(), m_ConstantInt(&intValue)) && intValue.isAllOnes())
516     return getLhs();
517 
518   return constFoldBinaryOp<IntegerAttr>(
519       operands, [](APInt a, const APInt &b) { return std::move(a) & b; });
520 }
521 
522 //===----------------------------------------------------------------------===//
523 // OrIOp
524 //===----------------------------------------------------------------------===//
525 
526 OpFoldResult arith::OrIOp::fold(ArrayRef<Attribute> operands) {
527   /// or(x, 0) -> x
528   if (matchPattern(getRhs(), m_Zero()))
529     return getLhs();
530   /// or(x, <all ones>) -> <all ones>
531   if (auto rhsAttr = operands[1].dyn_cast_or_null<IntegerAttr>())
532     if (rhsAttr.getValue().isAllOnes())
533       return rhsAttr;
534 
535   return constFoldBinaryOp<IntegerAttr>(
536       operands, [](APInt a, const APInt &b) { return std::move(a) | b; });
537 }
538 
539 //===----------------------------------------------------------------------===//
540 // XOrIOp
541 //===----------------------------------------------------------------------===//
542 
543 OpFoldResult arith::XOrIOp::fold(ArrayRef<Attribute> operands) {
544   /// xor(x, 0) -> x
545   if (matchPattern(getRhs(), m_Zero()))
546     return getLhs();
547   /// xor(x, x) -> 0
548   if (getLhs() == getRhs())
549     return Builder(getContext()).getZeroAttr(getType());
550 
551   return constFoldBinaryOp<IntegerAttr>(
552       operands, [](APInt a, const APInt &b) { return std::move(a) ^ b; });
553 }
554 
555 void arith::XOrIOp::getCanonicalizationPatterns(
556     OwningRewritePatternList &patterns, MLIRContext *context) {
557   patterns.insert<XOrINotCmpI>(context);
558 }
559 
560 //===----------------------------------------------------------------------===//
561 // AddFOp
562 //===----------------------------------------------------------------------===//
563 
564 OpFoldResult arith::AddFOp::fold(ArrayRef<Attribute> operands) {
565   return constFoldBinaryOp<FloatAttr>(
566       operands, [](const APFloat &a, const APFloat &b) { return a + b; });
567 }
568 
569 //===----------------------------------------------------------------------===//
570 // SubFOp
571 //===----------------------------------------------------------------------===//
572 
573 OpFoldResult arith::SubFOp::fold(ArrayRef<Attribute> operands) {
574   return constFoldBinaryOp<FloatAttr>(
575       operands, [](const APFloat &a, const APFloat &b) { return a - b; });
576 }
577 
578 //===----------------------------------------------------------------------===//
579 // MaxSIOp
580 //===----------------------------------------------------------------------===//
581 
582 OpFoldResult MaxSIOp::fold(ArrayRef<Attribute> operands) {
583   assert(operands.size() == 2 && "binary operation takes two operands");
584 
585   // maxsi(x,x) -> x
586   if (getLhs() == getRhs())
587     return getRhs();
588 
589   APInt intValue;
590   // maxsi(x,MAX_INT) -> MAX_INT
591   if (matchPattern(getRhs(), m_ConstantInt(&intValue)) &&
592       intValue.isMaxSignedValue())
593     return getRhs();
594 
595   // maxsi(x, MIN_INT) -> x
596   if (matchPattern(getRhs(), m_ConstantInt(&intValue)) &&
597       intValue.isMinSignedValue())
598     return getLhs();
599 
600   return constFoldBinaryOp<IntegerAttr>(operands,
601                                         [](const APInt &a, const APInt &b) {
602                                           return llvm::APIntOps::smax(a, b);
603                                         });
604 }
605 
606 //===----------------------------------------------------------------------===//
607 // MaxUIOp
608 //===----------------------------------------------------------------------===//
609 
610 OpFoldResult MaxUIOp::fold(ArrayRef<Attribute> operands) {
611   assert(operands.size() == 2 && "binary operation takes two operands");
612 
613   // maxui(x,x) -> x
614   if (getLhs() == getRhs())
615     return getRhs();
616 
617   APInt intValue;
618   // maxui(x,MAX_INT) -> MAX_INT
619   if (matchPattern(getRhs(), m_ConstantInt(&intValue)) && intValue.isMaxValue())
620     return getRhs();
621 
622   // maxui(x, MIN_INT) -> x
623   if (matchPattern(getRhs(), m_ConstantInt(&intValue)) && intValue.isMinValue())
624     return getLhs();
625 
626   return constFoldBinaryOp<IntegerAttr>(operands,
627                                         [](const APInt &a, const APInt &b) {
628                                           return llvm::APIntOps::umax(a, b);
629                                         });
630 }
631 
632 //===----------------------------------------------------------------------===//
633 // MinSIOp
634 //===----------------------------------------------------------------------===//
635 
636 OpFoldResult MinSIOp::fold(ArrayRef<Attribute> operands) {
637   assert(operands.size() == 2 && "binary operation takes two operands");
638 
639   // minsi(x,x) -> x
640   if (getLhs() == getRhs())
641     return getRhs();
642 
643   APInt intValue;
644   // minsi(x,MIN_INT) -> MIN_INT
645   if (matchPattern(getRhs(), m_ConstantInt(&intValue)) &&
646       intValue.isMinSignedValue())
647     return getRhs();
648 
649   // minsi(x, MAX_INT) -> x
650   if (matchPattern(getRhs(), m_ConstantInt(&intValue)) &&
651       intValue.isMaxSignedValue())
652     return getLhs();
653 
654   return constFoldBinaryOp<IntegerAttr>(operands,
655                                         [](const APInt &a, const APInt &b) {
656                                           return llvm::APIntOps::smin(a, b);
657                                         });
658 }
659 
660 //===----------------------------------------------------------------------===//
661 // MinUIOp
662 //===----------------------------------------------------------------------===//
663 
664 OpFoldResult MinUIOp::fold(ArrayRef<Attribute> operands) {
665   assert(operands.size() == 2 && "binary operation takes two operands");
666 
667   // minui(x,x) -> x
668   if (getLhs() == getRhs())
669     return getRhs();
670 
671   APInt intValue;
672   // minui(x,MIN_INT) -> MIN_INT
673   if (matchPattern(getRhs(), m_ConstantInt(&intValue)) && intValue.isMinValue())
674     return getRhs();
675 
676   // minui(x, MAX_INT) -> x
677   if (matchPattern(getRhs(), m_ConstantInt(&intValue)) && intValue.isMaxValue())
678     return getLhs();
679 
680   return constFoldBinaryOp<IntegerAttr>(operands,
681                                         [](const APInt &a, const APInt &b) {
682                                           return llvm::APIntOps::umin(a, b);
683                                         });
684 }
685 
686 //===----------------------------------------------------------------------===//
687 // MulFOp
688 //===----------------------------------------------------------------------===//
689 
690 OpFoldResult arith::MulFOp::fold(ArrayRef<Attribute> operands) {
691   return constFoldBinaryOp<FloatAttr>(
692       operands, [](const APFloat &a, const APFloat &b) { return a * b; });
693 }
694 
695 //===----------------------------------------------------------------------===//
696 // DivFOp
697 //===----------------------------------------------------------------------===//
698 
699 OpFoldResult arith::DivFOp::fold(ArrayRef<Attribute> operands) {
700   return constFoldBinaryOp<FloatAttr>(
701       operands, [](const APFloat &a, const APFloat &b) { return a / b; });
702 }
703 
704 //===----------------------------------------------------------------------===//
705 // Utility functions for verifying cast ops
706 //===----------------------------------------------------------------------===//
707 
708 template <typename... Types>
709 using type_list = std::tuple<Types...> *;
710 
711 /// Returns a non-null type only if the provided type is one of the allowed
712 /// types or one of the allowed shaped types of the allowed types. Returns the
713 /// element type if a valid shaped type is provided.
714 template <typename... ShapedTypes, typename... ElementTypes>
715 static Type getUnderlyingType(Type type, type_list<ShapedTypes...>,
716                               type_list<ElementTypes...>) {
717   if (type.isa<ShapedType>() && !type.isa<ShapedTypes...>())
718     return {};
719 
720   auto underlyingType = getElementTypeOrSelf(type);
721   if (!underlyingType.isa<ElementTypes...>())
722     return {};
723 
724   return underlyingType;
725 }
726 
727 /// Get allowed underlying types for vectors and tensors.
728 template <typename... ElementTypes>
729 static Type getTypeIfLike(Type type) {
730   return getUnderlyingType(type, type_list<VectorType, TensorType>(),
731                            type_list<ElementTypes...>());
732 }
733 
734 /// Get allowed underlying types for vectors, tensors, and memrefs.
735 template <typename... ElementTypes>
736 static Type getTypeIfLikeOrMemRef(Type type) {
737   return getUnderlyingType(type,
738                            type_list<VectorType, TensorType, MemRefType>(),
739                            type_list<ElementTypes...>());
740 }
741 
742 static bool areValidCastInputsAndOutputs(TypeRange inputs, TypeRange outputs) {
743   return inputs.size() == 1 && outputs.size() == 1 &&
744          succeeded(verifyCompatibleShapes(inputs.front(), outputs.front()));
745 }
746 
747 //===----------------------------------------------------------------------===//
748 // Verifiers for integer and floating point extension/truncation ops
749 //===----------------------------------------------------------------------===//
750 
751 // Extend ops can only extend to a wider type.
752 template <typename ValType, typename Op>
753 static LogicalResult verifyExtOp(Op op) {
754   Type srcType = getElementTypeOrSelf(op.getIn().getType());
755   Type dstType = getElementTypeOrSelf(op.getType());
756 
757   if (srcType.cast<ValType>().getWidth() >= dstType.cast<ValType>().getWidth())
758     return op.emitError("result type ")
759            << dstType << " must be wider than operand type " << srcType;
760 
761   return success();
762 }
763 
764 // Truncate ops can only truncate to a shorter type.
765 template <typename ValType, typename Op>
766 static LogicalResult verifyTruncateOp(Op op) {
767   Type srcType = getElementTypeOrSelf(op.getIn().getType());
768   Type dstType = getElementTypeOrSelf(op.getType());
769 
770   if (srcType.cast<ValType>().getWidth() <= dstType.cast<ValType>().getWidth())
771     return op.emitError("result type ")
772            << dstType << " must be shorter than operand type " << srcType;
773 
774   return success();
775 }
776 
777 /// Validate a cast that changes the width of a type.
778 template <template <typename> class WidthComparator, typename... ElementTypes>
779 static bool checkWidthChangeCast(TypeRange inputs, TypeRange outputs) {
780   if (!areValidCastInputsAndOutputs(inputs, outputs))
781     return false;
782 
783   auto srcType = getTypeIfLike<ElementTypes...>(inputs.front());
784   auto dstType = getTypeIfLike<ElementTypes...>(outputs.front());
785   if (!srcType || !dstType)
786     return false;
787 
788   return WidthComparator<unsigned>()(dstType.getIntOrFloatBitWidth(),
789                                      srcType.getIntOrFloatBitWidth());
790 }
791 
792 //===----------------------------------------------------------------------===//
793 // ExtUIOp
794 //===----------------------------------------------------------------------===//
795 
796 OpFoldResult arith::ExtUIOp::fold(ArrayRef<Attribute> operands) {
797   if (auto lhs = operands[0].dyn_cast_or_null<IntegerAttr>())
798     return IntegerAttr::get(
799         getType(), lhs.getValue().zext(getType().getIntOrFloatBitWidth()));
800 
801   if (auto lhs = getIn().getDefiningOp<ExtUIOp>()) {
802     getInMutable().assign(lhs.getIn());
803     return getResult();
804   }
805 
806   return {};
807 }
808 
809 bool arith::ExtUIOp::areCastCompatible(TypeRange inputs, TypeRange outputs) {
810   return checkWidthChangeCast<std::greater, IntegerType>(inputs, outputs);
811 }
812 
813 //===----------------------------------------------------------------------===//
814 // ExtSIOp
815 //===----------------------------------------------------------------------===//
816 
817 OpFoldResult arith::ExtSIOp::fold(ArrayRef<Attribute> operands) {
818   if (auto lhs = operands[0].dyn_cast_or_null<IntegerAttr>())
819     return IntegerAttr::get(
820         getType(), lhs.getValue().sext(getType().getIntOrFloatBitWidth()));
821 
822   if (auto lhs = getIn().getDefiningOp<ExtSIOp>()) {
823     getInMutable().assign(lhs.getIn());
824     return getResult();
825   }
826 
827   return {};
828 }
829 
830 bool arith::ExtSIOp::areCastCompatible(TypeRange inputs, TypeRange outputs) {
831   return checkWidthChangeCast<std::greater, IntegerType>(inputs, outputs);
832 }
833 
834 void arith::ExtSIOp::getCanonicalizationPatterns(
835     OwningRewritePatternList &patterns, MLIRContext *context) {
836   patterns.insert<ExtSIOfExtUI>(context);
837 }
838 
839 //===----------------------------------------------------------------------===//
840 // ExtFOp
841 //===----------------------------------------------------------------------===//
842 
843 bool arith::ExtFOp::areCastCompatible(TypeRange inputs, TypeRange outputs) {
844   return checkWidthChangeCast<std::greater, FloatType>(inputs, outputs);
845 }
846 
847 //===----------------------------------------------------------------------===//
848 // TruncIOp
849 //===----------------------------------------------------------------------===//
850 
851 OpFoldResult arith::TruncIOp::fold(ArrayRef<Attribute> operands) {
852   // trunci(zexti(a)) -> a
853   // trunci(sexti(a)) -> a
854   if (matchPattern(getOperand(), m_Op<arith::ExtUIOp>()) ||
855       matchPattern(getOperand(), m_Op<arith::ExtSIOp>()))
856     return getOperand().getDefiningOp()->getOperand(0);
857 
858   assert(operands.size() == 1 && "unary operation takes one operand");
859 
860   if (!operands[0])
861     return {};
862 
863   if (auto lhs = operands[0].dyn_cast<IntegerAttr>()) {
864     return IntegerAttr::get(
865         getType(), lhs.getValue().trunc(getType().getIntOrFloatBitWidth()));
866   }
867 
868   return {};
869 }
870 
871 bool arith::TruncIOp::areCastCompatible(TypeRange inputs, TypeRange outputs) {
872   return checkWidthChangeCast<std::less, IntegerType>(inputs, outputs);
873 }
874 
875 //===----------------------------------------------------------------------===//
876 // TruncFOp
877 //===----------------------------------------------------------------------===//
878 
879 /// Perform safe const propagation for truncf, i.e. only propagate if FP value
880 /// can be represented without precision loss or rounding.
881 OpFoldResult arith::TruncFOp::fold(ArrayRef<Attribute> operands) {
882   assert(operands.size() == 1 && "unary operation takes one operand");
883 
884   auto constOperand = operands.front();
885   if (!constOperand || !constOperand.isa<FloatAttr>())
886     return {};
887 
888   // Convert to target type via 'double'.
889   double sourceValue =
890       constOperand.dyn_cast<FloatAttr>().getValue().convertToDouble();
891   auto targetAttr = FloatAttr::get(getType(), sourceValue);
892 
893   // Propagate if constant's value does not change after truncation.
894   if (sourceValue == targetAttr.getValue().convertToDouble())
895     return targetAttr;
896 
897   return {};
898 }
899 
900 bool arith::TruncFOp::areCastCompatible(TypeRange inputs, TypeRange outputs) {
901   return checkWidthChangeCast<std::less, FloatType>(inputs, outputs);
902 }
903 
904 //===----------------------------------------------------------------------===//
905 // Verifiers for casts between integers and floats.
906 //===----------------------------------------------------------------------===//
907 
908 template <typename From, typename To>
909 static bool checkIntFloatCast(TypeRange inputs, TypeRange outputs) {
910   if (!areValidCastInputsAndOutputs(inputs, outputs))
911     return false;
912 
913   auto srcType = getTypeIfLike<From>(inputs.front());
914   auto dstType = getTypeIfLike<To>(outputs.back());
915 
916   return srcType && dstType;
917 }
918 
919 //===----------------------------------------------------------------------===//
920 // UIToFPOp
921 //===----------------------------------------------------------------------===//
922 
923 bool arith::UIToFPOp::areCastCompatible(TypeRange inputs, TypeRange outputs) {
924   return checkIntFloatCast<IntegerType, FloatType>(inputs, outputs);
925 }
926 
927 OpFoldResult arith::UIToFPOp::fold(ArrayRef<Attribute> operands) {
928   if (auto lhs = operands[0].dyn_cast_or_null<IntegerAttr>()) {
929     const APInt &api = lhs.getValue();
930     FloatType floatTy = getType().cast<FloatType>();
931     APFloat apf(floatTy.getFloatSemantics(),
932                 APInt::getZero(floatTy.getWidth()));
933     apf.convertFromAPInt(api, /*IsSigned=*/false, APFloat::rmNearestTiesToEven);
934     return FloatAttr::get(floatTy, apf);
935   }
936   return {};
937 }
938 
939 //===----------------------------------------------------------------------===//
940 // SIToFPOp
941 //===----------------------------------------------------------------------===//
942 
943 bool arith::SIToFPOp::areCastCompatible(TypeRange inputs, TypeRange outputs) {
944   return checkIntFloatCast<IntegerType, FloatType>(inputs, outputs);
945 }
946 
947 OpFoldResult arith::SIToFPOp::fold(ArrayRef<Attribute> operands) {
948   if (auto lhs = operands[0].dyn_cast_or_null<IntegerAttr>()) {
949     const APInt &api = lhs.getValue();
950     FloatType floatTy = getType().cast<FloatType>();
951     APFloat apf(floatTy.getFloatSemantics(),
952                 APInt::getZero(floatTy.getWidth()));
953     apf.convertFromAPInt(api, /*IsSigned=*/true, APFloat::rmNearestTiesToEven);
954     return FloatAttr::get(floatTy, apf);
955   }
956   return {};
957 }
958 //===----------------------------------------------------------------------===//
959 // FPToUIOp
960 //===----------------------------------------------------------------------===//
961 
962 bool arith::FPToUIOp::areCastCompatible(TypeRange inputs, TypeRange outputs) {
963   return checkIntFloatCast<FloatType, IntegerType>(inputs, outputs);
964 }
965 
966 OpFoldResult arith::FPToUIOp::fold(ArrayRef<Attribute> operands) {
967   if (auto lhs = operands[0].dyn_cast_or_null<FloatAttr>()) {
968     const APFloat &apf = lhs.getValue();
969     IntegerType intTy = getType().cast<IntegerType>();
970     bool ignored;
971     APSInt api(intTy.getWidth(), /*isUnsigned=*/true);
972     if (APFloat::opInvalidOp ==
973         apf.convertToInteger(api, APFloat::rmTowardZero, &ignored)) {
974       // Undefined behavior invoked - the destination type can't represent
975       // the input constant.
976       return {};
977     }
978     return IntegerAttr::get(getType(), api);
979   }
980 
981   return {};
982 }
983 
984 //===----------------------------------------------------------------------===//
985 // FPToSIOp
986 //===----------------------------------------------------------------------===//
987 
988 bool arith::FPToSIOp::areCastCompatible(TypeRange inputs, TypeRange outputs) {
989   return checkIntFloatCast<FloatType, IntegerType>(inputs, outputs);
990 }
991 
992 OpFoldResult arith::FPToSIOp::fold(ArrayRef<Attribute> operands) {
993   if (auto lhs = operands[0].dyn_cast_or_null<FloatAttr>()) {
994     const APFloat &apf = lhs.getValue();
995     IntegerType intTy = getType().cast<IntegerType>();
996     bool ignored;
997     APSInt api(intTy.getWidth(), /*isUnsigned=*/false);
998     if (APFloat::opInvalidOp ==
999         apf.convertToInteger(api, APFloat::rmTowardZero, &ignored)) {
1000       // Undefined behavior invoked - the destination type can't represent
1001       // the input constant.
1002       return {};
1003     }
1004     return IntegerAttr::get(getType(), api);
1005   }
1006 
1007   return {};
1008 }
1009 
1010 //===----------------------------------------------------------------------===//
1011 // IndexCastOp
1012 //===----------------------------------------------------------------------===//
1013 
1014 bool arith::IndexCastOp::areCastCompatible(TypeRange inputs,
1015                                            TypeRange outputs) {
1016   if (!areValidCastInputsAndOutputs(inputs, outputs))
1017     return false;
1018 
1019   auto srcType = getTypeIfLikeOrMemRef<IntegerType, IndexType>(inputs.front());
1020   auto dstType = getTypeIfLikeOrMemRef<IntegerType, IndexType>(outputs.front());
1021   if (!srcType || !dstType)
1022     return false;
1023 
1024   return (srcType.isIndex() && dstType.isSignlessInteger()) ||
1025          (srcType.isSignlessInteger() && dstType.isIndex());
1026 }
1027 
1028 OpFoldResult arith::IndexCastOp::fold(ArrayRef<Attribute> operands) {
1029   // index_cast(constant) -> constant
1030   // A little hack because we go through int. Otherwise, the size of the
1031   // constant might need to change.
1032   if (auto value = operands[0].dyn_cast_or_null<IntegerAttr>())
1033     return IntegerAttr::get(getType(), value.getInt());
1034 
1035   return {};
1036 }
1037 
1038 void arith::IndexCastOp::getCanonicalizationPatterns(
1039     OwningRewritePatternList &patterns, MLIRContext *context) {
1040   patterns.insert<IndexCastOfIndexCast, IndexCastOfExtSI>(context);
1041 }
1042 
1043 //===----------------------------------------------------------------------===//
1044 // BitcastOp
1045 //===----------------------------------------------------------------------===//
1046 
1047 bool arith::BitcastOp::areCastCompatible(TypeRange inputs, TypeRange outputs) {
1048   if (!areValidCastInputsAndOutputs(inputs, outputs))
1049     return false;
1050 
1051   auto srcType =
1052       getTypeIfLikeOrMemRef<IntegerType, IndexType, FloatType>(inputs.front());
1053   auto dstType =
1054       getTypeIfLikeOrMemRef<IntegerType, IndexType, FloatType>(outputs.front());
1055   if (!srcType || !dstType)
1056     return false;
1057 
1058   return srcType.getIntOrFloatBitWidth() == dstType.getIntOrFloatBitWidth();
1059 }
1060 
1061 OpFoldResult arith::BitcastOp::fold(ArrayRef<Attribute> operands) {
1062   assert(operands.size() == 1 && "bitcast op expects 1 operand");
1063 
1064   auto resType = getType();
1065   auto operand = operands[0];
1066   if (!operand)
1067     return {};
1068 
1069   /// Bitcast dense elements.
1070   if (auto denseAttr = operand.dyn_cast_or_null<DenseElementsAttr>())
1071     return denseAttr.bitcast(resType.cast<ShapedType>().getElementType());
1072   /// Other shaped types unhandled.
1073   if (resType.isa<ShapedType>())
1074     return {};
1075 
1076   /// Bitcast integer or float to integer or float.
1077   APInt bits = operand.isa<FloatAttr>()
1078                    ? operand.cast<FloatAttr>().getValue().bitcastToAPInt()
1079                    : operand.cast<IntegerAttr>().getValue();
1080 
1081   if (auto resFloatType = resType.dyn_cast<FloatType>())
1082     return FloatAttr::get(resType,
1083                           APFloat(resFloatType.getFloatSemantics(), bits));
1084   return IntegerAttr::get(resType, bits);
1085 }
1086 
1087 void arith::BitcastOp::getCanonicalizationPatterns(
1088     OwningRewritePatternList &patterns, MLIRContext *context) {
1089   patterns.insert<BitcastOfBitcast>(context);
1090 }
1091 
1092 //===----------------------------------------------------------------------===//
1093 // Helpers for compare ops
1094 //===----------------------------------------------------------------------===//
1095 
1096 /// Return the type of the same shape (scalar, vector or tensor) containing i1.
1097 static Type getI1SameShape(Type type) {
1098   auto i1Type = IntegerType::get(type.getContext(), 1);
1099   if (auto tensorType = type.dyn_cast<RankedTensorType>())
1100     return RankedTensorType::get(tensorType.getShape(), i1Type);
1101   if (type.isa<UnrankedTensorType>())
1102     return UnrankedTensorType::get(i1Type);
1103   if (auto vectorType = type.dyn_cast<VectorType>())
1104     return VectorType::get(vectorType.getShape(), i1Type,
1105                            vectorType.getNumScalableDims());
1106   return i1Type;
1107 }
1108 
1109 //===----------------------------------------------------------------------===//
1110 // CmpIOp
1111 //===----------------------------------------------------------------------===//
1112 
1113 /// Compute `lhs` `pred` `rhs`, where `pred` is one of the known integer
1114 /// comparison predicates.
1115 bool mlir::arith::applyCmpPredicate(arith::CmpIPredicate predicate,
1116                                     const APInt &lhs, const APInt &rhs) {
1117   switch (predicate) {
1118   case arith::CmpIPredicate::eq:
1119     return lhs.eq(rhs);
1120   case arith::CmpIPredicate::ne:
1121     return lhs.ne(rhs);
1122   case arith::CmpIPredicate::slt:
1123     return lhs.slt(rhs);
1124   case arith::CmpIPredicate::sle:
1125     return lhs.sle(rhs);
1126   case arith::CmpIPredicate::sgt:
1127     return lhs.sgt(rhs);
1128   case arith::CmpIPredicate::sge:
1129     return lhs.sge(rhs);
1130   case arith::CmpIPredicate::ult:
1131     return lhs.ult(rhs);
1132   case arith::CmpIPredicate::ule:
1133     return lhs.ule(rhs);
1134   case arith::CmpIPredicate::ugt:
1135     return lhs.ugt(rhs);
1136   case arith::CmpIPredicate::uge:
1137     return lhs.uge(rhs);
1138   }
1139   llvm_unreachable("unknown cmpi predicate kind");
1140 }
1141 
1142 /// Returns true if the predicate is true for two equal operands.
1143 static bool applyCmpPredicateToEqualOperands(arith::CmpIPredicate predicate) {
1144   switch (predicate) {
1145   case arith::CmpIPredicate::eq:
1146   case arith::CmpIPredicate::sle:
1147   case arith::CmpIPredicate::sge:
1148   case arith::CmpIPredicate::ule:
1149   case arith::CmpIPredicate::uge:
1150     return true;
1151   case arith::CmpIPredicate::ne:
1152   case arith::CmpIPredicate::slt:
1153   case arith::CmpIPredicate::sgt:
1154   case arith::CmpIPredicate::ult:
1155   case arith::CmpIPredicate::ugt:
1156     return false;
1157   }
1158   llvm_unreachable("unknown cmpi predicate kind");
1159 }
1160 
1161 static Attribute getBoolAttribute(Type type, MLIRContext *ctx, bool value) {
1162   auto boolAttr = BoolAttr::get(ctx, value);
1163   ShapedType shapedType = type.dyn_cast_or_null<ShapedType>();
1164   if (!shapedType)
1165     return boolAttr;
1166   return DenseElementsAttr::get(shapedType, boolAttr);
1167 }
1168 
1169 OpFoldResult arith::CmpIOp::fold(ArrayRef<Attribute> operands) {
1170   assert(operands.size() == 2 && "cmpi takes two operands");
1171 
1172   // cmpi(pred, x, x)
1173   if (getLhs() == getRhs()) {
1174     auto val = applyCmpPredicateToEqualOperands(getPredicate());
1175     return getBoolAttribute(getType(), getContext(), val);
1176   }
1177 
1178   if (matchPattern(getRhs(), m_Zero())) {
1179     if (auto extOp = getLhs().getDefiningOp<ExtSIOp>()) {
1180       if (extOp.getOperand().getType().cast<IntegerType>().getWidth() == 1) {
1181         // extsi(%x : i1 -> iN) != 0  ->  %x
1182         if (getPredicate() == arith::CmpIPredicate::ne) {
1183           return extOp.getOperand();
1184         }
1185       }
1186     }
1187     if (auto extOp = getLhs().getDefiningOp<ExtUIOp>()) {
1188       if (extOp.getOperand().getType().cast<IntegerType>().getWidth() == 1) {
1189         // extui(%x : i1 -> iN) != 0  ->  %x
1190         if (getPredicate() == arith::CmpIPredicate::ne) {
1191           return extOp.getOperand();
1192         }
1193       }
1194     }
1195   }
1196 
1197   auto lhs = operands.front().dyn_cast_or_null<IntegerAttr>();
1198   auto rhs = operands.back().dyn_cast_or_null<IntegerAttr>();
1199   if (!lhs || !rhs)
1200     return {};
1201 
1202   auto val = applyCmpPredicate(getPredicate(), lhs.getValue(), rhs.getValue());
1203   return BoolAttr::get(getContext(), val);
1204 }
1205 
1206 //===----------------------------------------------------------------------===//
1207 // CmpFOp
1208 //===----------------------------------------------------------------------===//
1209 
1210 /// Compute `lhs` `pred` `rhs`, where `pred` is one of the known floating point
1211 /// comparison predicates.
1212 bool mlir::arith::applyCmpPredicate(arith::CmpFPredicate predicate,
1213                                     const APFloat &lhs, const APFloat &rhs) {
1214   auto cmpResult = lhs.compare(rhs);
1215   switch (predicate) {
1216   case arith::CmpFPredicate::AlwaysFalse:
1217     return false;
1218   case arith::CmpFPredicate::OEQ:
1219     return cmpResult == APFloat::cmpEqual;
1220   case arith::CmpFPredicate::OGT:
1221     return cmpResult == APFloat::cmpGreaterThan;
1222   case arith::CmpFPredicate::OGE:
1223     return cmpResult == APFloat::cmpGreaterThan ||
1224            cmpResult == APFloat::cmpEqual;
1225   case arith::CmpFPredicate::OLT:
1226     return cmpResult == APFloat::cmpLessThan;
1227   case arith::CmpFPredicate::OLE:
1228     return cmpResult == APFloat::cmpLessThan || cmpResult == APFloat::cmpEqual;
1229   case arith::CmpFPredicate::ONE:
1230     return cmpResult != APFloat::cmpUnordered && cmpResult != APFloat::cmpEqual;
1231   case arith::CmpFPredicate::ORD:
1232     return cmpResult != APFloat::cmpUnordered;
1233   case arith::CmpFPredicate::UEQ:
1234     return cmpResult == APFloat::cmpUnordered || cmpResult == APFloat::cmpEqual;
1235   case arith::CmpFPredicate::UGT:
1236     return cmpResult == APFloat::cmpUnordered ||
1237            cmpResult == APFloat::cmpGreaterThan;
1238   case arith::CmpFPredicate::UGE:
1239     return cmpResult == APFloat::cmpUnordered ||
1240            cmpResult == APFloat::cmpGreaterThan ||
1241            cmpResult == APFloat::cmpEqual;
1242   case arith::CmpFPredicate::ULT:
1243     return cmpResult == APFloat::cmpUnordered ||
1244            cmpResult == APFloat::cmpLessThan;
1245   case arith::CmpFPredicate::ULE:
1246     return cmpResult == APFloat::cmpUnordered ||
1247            cmpResult == APFloat::cmpLessThan || cmpResult == APFloat::cmpEqual;
1248   case arith::CmpFPredicate::UNE:
1249     return cmpResult != APFloat::cmpEqual;
1250   case arith::CmpFPredicate::UNO:
1251     return cmpResult == APFloat::cmpUnordered;
1252   case arith::CmpFPredicate::AlwaysTrue:
1253     return true;
1254   }
1255   llvm_unreachable("unknown cmpf predicate kind");
1256 }
1257 
1258 OpFoldResult arith::CmpFOp::fold(ArrayRef<Attribute> operands) {
1259   assert(operands.size() == 2 && "cmpf takes two operands");
1260 
1261   auto lhs = operands.front().dyn_cast_or_null<FloatAttr>();
1262   auto rhs = operands.back().dyn_cast_or_null<FloatAttr>();
1263 
1264   if (!lhs || !rhs)
1265     return {};
1266 
1267   auto val = applyCmpPredicate(getPredicate(), lhs.getValue(), rhs.getValue());
1268   return BoolAttr::get(getContext(), val);
1269 }
1270 
1271 //===----------------------------------------------------------------------===//
1272 // Atomic Enum
1273 //===----------------------------------------------------------------------===//
1274 
1275 /// Returns the identity value attribute associated with an AtomicRMWKind op.
1276 Attribute mlir::arith::getIdentityValueAttr(AtomicRMWKind kind, Type resultType,
1277                                             OpBuilder &builder, Location loc) {
1278   switch (kind) {
1279   case AtomicRMWKind::maxf:
1280     return builder.getFloatAttr(
1281         resultType,
1282         APFloat::getInf(resultType.cast<FloatType>().getFloatSemantics(),
1283                         /*Negative=*/true));
1284   case AtomicRMWKind::addf:
1285   case AtomicRMWKind::addi:
1286   case AtomicRMWKind::maxu:
1287   case AtomicRMWKind::ori:
1288     return builder.getZeroAttr(resultType);
1289   case AtomicRMWKind::andi:
1290     return builder.getIntegerAttr(
1291         resultType,
1292         APInt::getAllOnes(resultType.cast<IntegerType>().getWidth()));
1293   case AtomicRMWKind::maxs:
1294     return builder.getIntegerAttr(
1295         resultType,
1296         APInt::getSignedMinValue(resultType.cast<IntegerType>().getWidth()));
1297   case AtomicRMWKind::minf:
1298     return builder.getFloatAttr(
1299         resultType,
1300         APFloat::getInf(resultType.cast<FloatType>().getFloatSemantics(),
1301                         /*Negative=*/false));
1302   case AtomicRMWKind::mins:
1303     return builder.getIntegerAttr(
1304         resultType,
1305         APInt::getSignedMaxValue(resultType.cast<IntegerType>().getWidth()));
1306   case AtomicRMWKind::minu:
1307     return builder.getIntegerAttr(
1308         resultType,
1309         APInt::getMaxValue(resultType.cast<IntegerType>().getWidth()));
1310   case AtomicRMWKind::muli:
1311     return builder.getIntegerAttr(resultType, 1);
1312   case AtomicRMWKind::mulf:
1313     return builder.getFloatAttr(resultType, 1);
1314   // TODO: Add remaining reduction operations.
1315   default:
1316     (void)emitOptionalError(loc, "Reduction operation type not supported");
1317     break;
1318   }
1319   return nullptr;
1320 }
1321 
1322 /// Returns the identity value associated with an AtomicRMWKind op.
1323 Value mlir::arith::getIdentityValue(AtomicRMWKind op, Type resultType,
1324                                     OpBuilder &builder, Location loc) {
1325   Attribute attr = getIdentityValueAttr(op, resultType, builder, loc);
1326   return builder.create<arith::ConstantOp>(loc, attr);
1327 }
1328 
1329 /// Return the value obtained by applying the reduction operation kind
1330 /// associated with a binary AtomicRMWKind op to `lhs` and `rhs`.
1331 Value mlir::arith::getReductionOp(AtomicRMWKind op, OpBuilder &builder,
1332                                   Location loc, Value lhs, Value rhs) {
1333   switch (op) {
1334   case AtomicRMWKind::addf:
1335     return builder.create<arith::AddFOp>(loc, lhs, rhs);
1336   case AtomicRMWKind::addi:
1337     return builder.create<arith::AddIOp>(loc, lhs, rhs);
1338   case AtomicRMWKind::mulf:
1339     return builder.create<arith::MulFOp>(loc, lhs, rhs);
1340   case AtomicRMWKind::muli:
1341     return builder.create<arith::MulIOp>(loc, lhs, rhs);
1342   case AtomicRMWKind::maxf:
1343     return builder.create<arith::MaxFOp>(loc, lhs, rhs);
1344   case AtomicRMWKind::minf:
1345     return builder.create<arith::MinFOp>(loc, lhs, rhs);
1346   case AtomicRMWKind::maxs:
1347     return builder.create<arith::MaxSIOp>(loc, lhs, rhs);
1348   case AtomicRMWKind::mins:
1349     return builder.create<arith::MinSIOp>(loc, lhs, rhs);
1350   case AtomicRMWKind::maxu:
1351     return builder.create<arith::MaxUIOp>(loc, lhs, rhs);
1352   case AtomicRMWKind::minu:
1353     return builder.create<arith::MinUIOp>(loc, lhs, rhs);
1354   case AtomicRMWKind::ori:
1355     return builder.create<arith::OrIOp>(loc, lhs, rhs);
1356   case AtomicRMWKind::andi:
1357     return builder.create<arith::AndIOp>(loc, lhs, rhs);
1358   // TODO: Add remaining reduction operations.
1359   default:
1360     (void)emitOptionalError(loc, "Reduction operation type not supported");
1361     break;
1362   }
1363   return nullptr;
1364 }
1365 
1366 //===----------------------------------------------------------------------===//
1367 // TableGen'd op method definitions
1368 //===----------------------------------------------------------------------===//
1369 
1370 #define GET_OP_CLASSES
1371 #include "mlir/Dialect/Arithmetic/IR/ArithmeticOps.cpp.inc"
1372 
1373 //===----------------------------------------------------------------------===//
1374 // TableGen'd enum attribute definitions
1375 //===----------------------------------------------------------------------===//
1376 
1377 #include "mlir/Dialect/Arithmetic/IR/ArithmeticOpsEnums.cpp.inc"
1378