1 //===- ArithmeticOps.cpp - MLIR Arithmetic dialect ops implementation -----===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 
9 #include <utility>
10 
11 #include "mlir/Dialect/Arithmetic/IR/Arithmetic.h"
12 #include "mlir/Dialect/CommonFolders.h"
13 #include "mlir/IR/Builders.h"
14 #include "mlir/IR/Matchers.h"
15 #include "mlir/IR/OpImplementation.h"
16 #include "mlir/IR/PatternMatch.h"
17 #include "mlir/IR/TypeUtilities.h"
18 
19 #include "llvm/ADT/APSInt.h"
20 
21 using namespace mlir;
22 using namespace mlir::arith;
23 
24 //===----------------------------------------------------------------------===//
25 // Pattern helpers
26 //===----------------------------------------------------------------------===//
27 
28 static IntegerAttr addIntegerAttrs(PatternRewriter &builder, Value res,
29                                    Attribute lhs, Attribute rhs) {
30   return builder.getIntegerAttr(res.getType(),
31                                 lhs.cast<IntegerAttr>().getInt() +
32                                     rhs.cast<IntegerAttr>().getInt());
33 }
34 
35 static IntegerAttr subIntegerAttrs(PatternRewriter &builder, Value res,
36                                    Attribute lhs, Attribute rhs) {
37   return builder.getIntegerAttr(res.getType(),
38                                 lhs.cast<IntegerAttr>().getInt() -
39                                     rhs.cast<IntegerAttr>().getInt());
40 }
41 
42 /// Invert an integer comparison predicate.
43 arith::CmpIPredicate arith::invertPredicate(arith::CmpIPredicate pred) {
44   switch (pred) {
45   case arith::CmpIPredicate::eq:
46     return arith::CmpIPredicate::ne;
47   case arith::CmpIPredicate::ne:
48     return arith::CmpIPredicate::eq;
49   case arith::CmpIPredicate::slt:
50     return arith::CmpIPredicate::sge;
51   case arith::CmpIPredicate::sle:
52     return arith::CmpIPredicate::sgt;
53   case arith::CmpIPredicate::sgt:
54     return arith::CmpIPredicate::sle;
55   case arith::CmpIPredicate::sge:
56     return arith::CmpIPredicate::slt;
57   case arith::CmpIPredicate::ult:
58     return arith::CmpIPredicate::uge;
59   case arith::CmpIPredicate::ule:
60     return arith::CmpIPredicate::ugt;
61   case arith::CmpIPredicate::ugt:
62     return arith::CmpIPredicate::ule;
63   case arith::CmpIPredicate::uge:
64     return arith::CmpIPredicate::ult;
65   }
66   llvm_unreachable("unknown cmpi predicate kind");
67 }
68 
69 static arith::CmpIPredicateAttr invertPredicate(arith::CmpIPredicateAttr pred) {
70   return arith::CmpIPredicateAttr::get(pred.getContext(),
71                                        invertPredicate(pred.getValue()));
72 }
73 
74 //===----------------------------------------------------------------------===//
75 // TableGen'd canonicalization patterns
76 //===----------------------------------------------------------------------===//
77 
78 namespace {
79 #include "ArithmeticCanonicalization.inc"
80 } // namespace
81 
82 //===----------------------------------------------------------------------===//
83 // ConstantOp
84 //===----------------------------------------------------------------------===//
85 
86 void arith::ConstantOp::getAsmResultNames(
87     function_ref<void(Value, StringRef)> setNameFn) {
88   auto type = getType();
89   if (auto intCst = getValue().dyn_cast<IntegerAttr>()) {
90     auto intType = type.dyn_cast<IntegerType>();
91 
92     // Sugar i1 constants with 'true' and 'false'.
93     if (intType && intType.getWidth() == 1)
94       return setNameFn(getResult(), (intCst.getInt() ? "true" : "false"));
95 
96     // Otherwise, build a compex name with the value and type.
97     SmallString<32> specialNameBuffer;
98     llvm::raw_svector_ostream specialName(specialNameBuffer);
99     specialName << 'c' << intCst.getInt();
100     if (intType)
101       specialName << '_' << type;
102     setNameFn(getResult(), specialName.str());
103   } else {
104     setNameFn(getResult(), "cst");
105   }
106 }
107 
108 /// TODO: disallow arith.constant to return anything other than signless integer
109 /// or float like.
110 static LogicalResult verify(arith::ConstantOp op) {
111   auto type = op.getType();
112   // The value's type must match the return type.
113   if (op.getValue().getType() != type) {
114     return op.emitOpError() << "value type " << op.getValue().getType()
115                             << " must match return type: " << type;
116   }
117   // Integer values must be signless.
118   if (type.isa<IntegerType>() && !type.cast<IntegerType>().isSignless())
119     return op.emitOpError("integer return type must be signless");
120   // Any float or elements attribute are acceptable.
121   if (!op.getValue().isa<IntegerAttr, FloatAttr, ElementsAttr>()) {
122     return op.emitOpError(
123         "value must be an integer, float, or elements attribute");
124   }
125   return success();
126 }
127 
128 bool arith::ConstantOp::isBuildableWith(Attribute value, Type type) {
129   // The value's type must be the same as the provided type.
130   if (value.getType() != type)
131     return false;
132   // Integer values must be signless.
133   if (type.isa<IntegerType>() && !type.cast<IntegerType>().isSignless())
134     return false;
135   // Integer, float, and element attributes are buildable.
136   return value.isa<IntegerAttr, FloatAttr, ElementsAttr>();
137 }
138 
139 OpFoldResult arith::ConstantOp::fold(ArrayRef<Attribute> operands) {
140   return getValue();
141 }
142 
143 void arith::ConstantIntOp::build(OpBuilder &builder, OperationState &result,
144                                  int64_t value, unsigned width) {
145   auto type = builder.getIntegerType(width);
146   arith::ConstantOp::build(builder, result, type,
147                            builder.getIntegerAttr(type, value));
148 }
149 
150 void arith::ConstantIntOp::build(OpBuilder &builder, OperationState &result,
151                                  int64_t value, Type type) {
152   assert(type.isSignlessInteger() &&
153          "ConstantIntOp can only have signless integer type values");
154   arith::ConstantOp::build(builder, result, type,
155                            builder.getIntegerAttr(type, value));
156 }
157 
158 bool arith::ConstantIntOp::classof(Operation *op) {
159   if (auto constOp = dyn_cast_or_null<arith::ConstantOp>(op))
160     return constOp.getType().isSignlessInteger();
161   return false;
162 }
163 
164 void arith::ConstantFloatOp::build(OpBuilder &builder, OperationState &result,
165                                    const APFloat &value, FloatType type) {
166   arith::ConstantOp::build(builder, result, type,
167                            builder.getFloatAttr(type, value));
168 }
169 
170 bool arith::ConstantFloatOp::classof(Operation *op) {
171   if (auto constOp = dyn_cast_or_null<arith::ConstantOp>(op))
172     return constOp.getType().isa<FloatType>();
173   return false;
174 }
175 
176 void arith::ConstantIndexOp::build(OpBuilder &builder, OperationState &result,
177                                    int64_t value) {
178   arith::ConstantOp::build(builder, result, builder.getIndexType(),
179                            builder.getIndexAttr(value));
180 }
181 
182 bool arith::ConstantIndexOp::classof(Operation *op) {
183   if (auto constOp = dyn_cast_or_null<arith::ConstantOp>(op))
184     return constOp.getType().isIndex();
185   return false;
186 }
187 
188 //===----------------------------------------------------------------------===//
189 // AddIOp
190 //===----------------------------------------------------------------------===//
191 
192 OpFoldResult arith::AddIOp::fold(ArrayRef<Attribute> operands) {
193   // addi(x, 0) -> x
194   if (matchPattern(getRhs(), m_Zero()))
195     return getLhs();
196 
197   // addi(subi(a, b), b) -> a
198   if (auto sub = getLhs().getDefiningOp<SubIOp>())
199     if (getRhs() == sub.getRhs())
200       return sub.getLhs();
201 
202   // addi(b, subi(a, b)) -> a
203   if (auto sub = getRhs().getDefiningOp<SubIOp>())
204     if (getLhs() == sub.getRhs())
205       return sub.getLhs();
206 
207   return constFoldBinaryOp<IntegerAttr>(
208       operands, [](APInt a, const APInt &b) { return std::move(a) + b; });
209 }
210 
211 void arith::AddIOp::getCanonicalizationPatterns(
212     RewritePatternSet &patterns, MLIRContext *context) {
213   patterns.insert<AddIAddConstant, AddISubConstantRHS, AddISubConstantLHS>(
214       context);
215 }
216 
217 //===----------------------------------------------------------------------===//
218 // SubIOp
219 //===----------------------------------------------------------------------===//
220 
221 OpFoldResult arith::SubIOp::fold(ArrayRef<Attribute> operands) {
222   // subi(x,x) -> 0
223   if (getOperand(0) == getOperand(1))
224     return Builder(getContext()).getZeroAttr(getType());
225   // subi(x,0) -> x
226   if (matchPattern(getRhs(), m_Zero()))
227     return getLhs();
228 
229   return constFoldBinaryOp<IntegerAttr>(
230       operands, [](APInt a, const APInt &b) { return std::move(a) - b; });
231 }
232 
233 void arith::SubIOp::getCanonicalizationPatterns(
234     RewritePatternSet &patterns, MLIRContext *context) {
235   patterns.insert<SubIRHSAddConstant, SubILHSAddConstant, SubIRHSSubConstantRHS,
236                   SubIRHSSubConstantLHS, SubILHSSubConstantRHS,
237                   SubILHSSubConstantLHS>(context);
238 }
239 
240 //===----------------------------------------------------------------------===//
241 // MulIOp
242 //===----------------------------------------------------------------------===//
243 
244 OpFoldResult arith::MulIOp::fold(ArrayRef<Attribute> operands) {
245   // muli(x, 0) -> 0
246   if (matchPattern(getRhs(), m_Zero()))
247     return getRhs();
248   // muli(x, 1) -> x
249   if (matchPattern(getRhs(), m_One()))
250     return getOperand(0);
251   // TODO: Handle the overflow case.
252 
253   // default folder
254   return constFoldBinaryOp<IntegerAttr>(
255       operands, [](const APInt &a, const APInt &b) { return a * b; });
256 }
257 
258 //===----------------------------------------------------------------------===//
259 // DivUIOp
260 //===----------------------------------------------------------------------===//
261 
262 OpFoldResult arith::DivUIOp::fold(ArrayRef<Attribute> operands) {
263   // Don't fold if it would require a division by zero.
264   bool div0 = false;
265   auto result =
266       constFoldBinaryOp<IntegerAttr>(operands, [&](APInt a, const APInt &b) {
267         if (div0 || !b) {
268           div0 = true;
269           return a;
270         }
271         return a.udiv(b);
272       });
273 
274   // Fold out division by one. Assumes all tensors of all ones are splats.
275   if (auto rhs = operands[1].dyn_cast_or_null<IntegerAttr>()) {
276     if (rhs.getValue() == 1)
277       return getLhs();
278   } else if (auto rhs = operands[1].dyn_cast_or_null<SplatElementsAttr>()) {
279     if (rhs.getSplatValue<IntegerAttr>().getValue() == 1)
280       return getLhs();
281   }
282 
283   return div0 ? Attribute() : result;
284 }
285 
286 //===----------------------------------------------------------------------===//
287 // DivSIOp
288 //===----------------------------------------------------------------------===//
289 
290 OpFoldResult arith::DivSIOp::fold(ArrayRef<Attribute> operands) {
291   // Don't fold if it would overflow or if it requires a division by zero.
292   bool overflowOrDiv0 = false;
293   auto result =
294       constFoldBinaryOp<IntegerAttr>(operands, [&](APInt a, const APInt &b) {
295         if (overflowOrDiv0 || !b) {
296           overflowOrDiv0 = true;
297           return a;
298         }
299         return a.sdiv_ov(b, overflowOrDiv0);
300       });
301 
302   // Fold out division by one. Assumes all tensors of all ones are splats.
303   if (auto rhs = operands[1].dyn_cast_or_null<IntegerAttr>()) {
304     if (rhs.getValue() == 1)
305       return getLhs();
306   } else if (auto rhs = operands[1].dyn_cast_or_null<SplatElementsAttr>()) {
307     if (rhs.getSplatValue<IntegerAttr>().getValue() == 1)
308       return getLhs();
309   }
310 
311   return overflowOrDiv0 ? Attribute() : result;
312 }
313 
314 //===----------------------------------------------------------------------===//
315 // Ceil and floor division folding helpers
316 //===----------------------------------------------------------------------===//
317 
318 static APInt signedCeilNonnegInputs(const APInt &a, const APInt &b,
319                                     bool &overflow) {
320   // Returns (a-1)/b + 1
321   APInt one(a.getBitWidth(), 1, true); // Signed value 1.
322   APInt val = a.ssub_ov(one, overflow).sdiv_ov(b, overflow);
323   return val.sadd_ov(one, overflow);
324 }
325 
326 //===----------------------------------------------------------------------===//
327 // CeilDivUIOp
328 //===----------------------------------------------------------------------===//
329 
330 OpFoldResult arith::CeilDivUIOp::fold(ArrayRef<Attribute> operands) {
331   bool overflowOrDiv0 = false;
332   auto result =
333       constFoldBinaryOp<IntegerAttr>(operands, [&](APInt a, const APInt &b) {
334         if (overflowOrDiv0 || !b) {
335           overflowOrDiv0 = true;
336           return a;
337         }
338         APInt quotient = a.udiv(b);
339         if (!a.urem(b))
340           return quotient;
341         APInt one(a.getBitWidth(), 1, true);
342         return quotient.uadd_ov(one, overflowOrDiv0);
343       });
344   // Fold out ceil division by one. Assumes all tensors of all ones are
345   // splats.
346   if (auto rhs = operands[1].dyn_cast_or_null<IntegerAttr>()) {
347     if (rhs.getValue() == 1)
348       return getLhs();
349   } else if (auto rhs = operands[1].dyn_cast_or_null<SplatElementsAttr>()) {
350     if (rhs.getSplatValue<IntegerAttr>().getValue() == 1)
351       return getLhs();
352   }
353 
354   return overflowOrDiv0 ? Attribute() : result;
355 }
356 
357 //===----------------------------------------------------------------------===//
358 // CeilDivSIOp
359 //===----------------------------------------------------------------------===//
360 
361 OpFoldResult arith::CeilDivSIOp::fold(ArrayRef<Attribute> operands) {
362   // Don't fold if it would overflow or if it requires a division by zero.
363   bool overflowOrDiv0 = false;
364   auto result =
365       constFoldBinaryOp<IntegerAttr>(operands, [&](APInt a, const APInt &b) {
366         if (overflowOrDiv0 || !b) {
367           overflowOrDiv0 = true;
368           return a;
369         }
370         if (!a)
371           return a;
372         // After this point we know that neither a or b are zero.
373         unsigned bits = a.getBitWidth();
374         APInt zero = APInt::getZero(bits);
375         bool aGtZero = a.sgt(zero);
376         bool bGtZero = b.sgt(zero);
377         if (aGtZero && bGtZero) {
378           // Both positive, return ceil(a, b).
379           return signedCeilNonnegInputs(a, b, overflowOrDiv0);
380         }
381         if (!aGtZero && !bGtZero) {
382           // Both negative, return ceil(-a, -b).
383           APInt posA = zero.ssub_ov(a, overflowOrDiv0);
384           APInt posB = zero.ssub_ov(b, overflowOrDiv0);
385           return signedCeilNonnegInputs(posA, posB, overflowOrDiv0);
386         }
387         if (!aGtZero && bGtZero) {
388           // A is negative, b is positive, return - ( -a / b).
389           APInt posA = zero.ssub_ov(a, overflowOrDiv0);
390           APInt div = posA.sdiv_ov(b, overflowOrDiv0);
391           return zero.ssub_ov(div, overflowOrDiv0);
392         }
393         // A is positive, b is negative, return - (a / -b).
394         APInt posB = zero.ssub_ov(b, overflowOrDiv0);
395         APInt div = a.sdiv_ov(posB, overflowOrDiv0);
396         return zero.ssub_ov(div, overflowOrDiv0);
397       });
398 
399   // Fold out ceil division by one. Assumes all tensors of all ones are
400   // splats.
401   if (auto rhs = operands[1].dyn_cast_or_null<IntegerAttr>()) {
402     if (rhs.getValue() == 1)
403       return getLhs();
404   } else if (auto rhs = operands[1].dyn_cast_or_null<SplatElementsAttr>()) {
405     if (rhs.getSplatValue<IntegerAttr>().getValue() == 1)
406       return getLhs();
407   }
408 
409   return overflowOrDiv0 ? Attribute() : result;
410 }
411 
412 //===----------------------------------------------------------------------===//
413 // FloorDivSIOp
414 //===----------------------------------------------------------------------===//
415 
416 OpFoldResult arith::FloorDivSIOp::fold(ArrayRef<Attribute> operands) {
417   // Don't fold if it would overflow or if it requires a division by zero.
418   bool overflowOrDiv0 = false;
419   auto result =
420       constFoldBinaryOp<IntegerAttr>(operands, [&](APInt a, const APInt &b) {
421         if (overflowOrDiv0 || !b) {
422           overflowOrDiv0 = true;
423           return a;
424         }
425         if (!a)
426           return a;
427         // After this point we know that neither a or b are zero.
428         unsigned bits = a.getBitWidth();
429         APInt zero = APInt::getZero(bits);
430         bool aGtZero = a.sgt(zero);
431         bool bGtZero = b.sgt(zero);
432         if (aGtZero && bGtZero) {
433           // Both positive, return a / b.
434           return a.sdiv_ov(b, overflowOrDiv0);
435         }
436         if (!aGtZero && !bGtZero) {
437           // Both negative, return -a / -b.
438           APInt posA = zero.ssub_ov(a, overflowOrDiv0);
439           APInt posB = zero.ssub_ov(b, overflowOrDiv0);
440           return posA.sdiv_ov(posB, overflowOrDiv0);
441         }
442         if (!aGtZero && bGtZero) {
443           // A is negative, b is positive, return - ceil(-a, b).
444           APInt posA = zero.ssub_ov(a, overflowOrDiv0);
445           APInt ceil = signedCeilNonnegInputs(posA, b, overflowOrDiv0);
446           return zero.ssub_ov(ceil, overflowOrDiv0);
447         }
448         // A is positive, b is negative, return - ceil(a, -b).
449         APInt posB = zero.ssub_ov(b, overflowOrDiv0);
450         APInt ceil = signedCeilNonnegInputs(a, posB, overflowOrDiv0);
451         return zero.ssub_ov(ceil, overflowOrDiv0);
452       });
453 
454   // Fold out floor division by one. Assumes all tensors of all ones are
455   // splats.
456   if (auto rhs = operands[1].dyn_cast_or_null<IntegerAttr>()) {
457     if (rhs.getValue() == 1)
458       return getLhs();
459   } else if (auto rhs = operands[1].dyn_cast_or_null<SplatElementsAttr>()) {
460     if (rhs.getSplatValue<IntegerAttr>().getValue() == 1)
461       return getLhs();
462   }
463 
464   return overflowOrDiv0 ? Attribute() : result;
465 }
466 
467 //===----------------------------------------------------------------------===//
468 // RemUIOp
469 //===----------------------------------------------------------------------===//
470 
471 OpFoldResult arith::RemUIOp::fold(ArrayRef<Attribute> operands) {
472   auto rhs = operands.back().dyn_cast_or_null<IntegerAttr>();
473   if (!rhs)
474     return {};
475   auto rhsValue = rhs.getValue();
476 
477   // x % 1 = 0
478   if (rhsValue.isOneValue())
479     return IntegerAttr::get(rhs.getType(), APInt(rhsValue.getBitWidth(), 0));
480 
481   // Don't fold if it requires division by zero.
482   if (rhsValue.isNullValue())
483     return {};
484 
485   auto lhs = operands.front().dyn_cast_or_null<IntegerAttr>();
486   if (!lhs)
487     return {};
488   return IntegerAttr::get(lhs.getType(), lhs.getValue().urem(rhsValue));
489 }
490 
491 //===----------------------------------------------------------------------===//
492 // RemSIOp
493 //===----------------------------------------------------------------------===//
494 
495 OpFoldResult arith::RemSIOp::fold(ArrayRef<Attribute> operands) {
496   auto rhs = operands.back().dyn_cast_or_null<IntegerAttr>();
497   if (!rhs)
498     return {};
499   auto rhsValue = rhs.getValue();
500 
501   // x % 1 = 0
502   if (rhsValue.isOneValue())
503     return IntegerAttr::get(rhs.getType(), APInt(rhsValue.getBitWidth(), 0));
504 
505   // Don't fold if it requires division by zero.
506   if (rhsValue.isNullValue())
507     return {};
508 
509   auto lhs = operands.front().dyn_cast_or_null<IntegerAttr>();
510   if (!lhs)
511     return {};
512   return IntegerAttr::get(lhs.getType(), lhs.getValue().srem(rhsValue));
513 }
514 
515 //===----------------------------------------------------------------------===//
516 // AndIOp
517 //===----------------------------------------------------------------------===//
518 
519 OpFoldResult arith::AndIOp::fold(ArrayRef<Attribute> operands) {
520   /// and(x, 0) -> 0
521   if (matchPattern(getRhs(), m_Zero()))
522     return getRhs();
523   /// and(x, allOnes) -> x
524   APInt intValue;
525   if (matchPattern(getRhs(), m_ConstantInt(&intValue)) && intValue.isAllOnes())
526     return getLhs();
527 
528   return constFoldBinaryOp<IntegerAttr>(
529       operands, [](APInt a, const APInt &b) { return std::move(a) & b; });
530 }
531 
532 //===----------------------------------------------------------------------===//
533 // OrIOp
534 //===----------------------------------------------------------------------===//
535 
536 OpFoldResult arith::OrIOp::fold(ArrayRef<Attribute> operands) {
537   /// or(x, 0) -> x
538   if (matchPattern(getRhs(), m_Zero()))
539     return getLhs();
540   /// or(x, <all ones>) -> <all ones>
541   if (auto rhsAttr = operands[1].dyn_cast_or_null<IntegerAttr>())
542     if (rhsAttr.getValue().isAllOnes())
543       return rhsAttr;
544 
545   return constFoldBinaryOp<IntegerAttr>(
546       operands, [](APInt a, const APInt &b) { return std::move(a) | b; });
547 }
548 
549 //===----------------------------------------------------------------------===//
550 // XOrIOp
551 //===----------------------------------------------------------------------===//
552 
553 OpFoldResult arith::XOrIOp::fold(ArrayRef<Attribute> operands) {
554   /// xor(x, 0) -> x
555   if (matchPattern(getRhs(), m_Zero()))
556     return getLhs();
557   /// xor(x, x) -> 0
558   if (getLhs() == getRhs())
559     return Builder(getContext()).getZeroAttr(getType());
560   /// xor(xor(x, a), a) -> x
561   if (arith::XOrIOp prev = getLhs().getDefiningOp<arith::XOrIOp>())
562     if (prev.getRhs() == getRhs())
563       return prev.getLhs();
564 
565   return constFoldBinaryOp<IntegerAttr>(
566       operands, [](APInt a, const APInt &b) { return std::move(a) ^ b; });
567 }
568 
569 void arith::XOrIOp::getCanonicalizationPatterns(
570     RewritePatternSet &patterns, MLIRContext *context) {
571   patterns.insert<XOrINotCmpI>(context);
572 }
573 
574 //===----------------------------------------------------------------------===//
575 // AddFOp
576 //===----------------------------------------------------------------------===//
577 
578 OpFoldResult arith::AddFOp::fold(ArrayRef<Attribute> operands) {
579   // addf(x, -0) -> x
580   if (matchPattern(getRhs(), m_NegZeroFloat()))
581     return getLhs();
582 
583   return constFoldBinaryOp<FloatAttr>(
584       operands, [](const APFloat &a, const APFloat &b) { return a + b; });
585 }
586 
587 //===----------------------------------------------------------------------===//
588 // SubFOp
589 //===----------------------------------------------------------------------===//
590 
591 OpFoldResult arith::SubFOp::fold(ArrayRef<Attribute> operands) {
592   // subf(x, +0) -> x
593   if (matchPattern(getRhs(), m_PosZeroFloat()))
594     return getLhs();
595 
596   return constFoldBinaryOp<FloatAttr>(
597       operands, [](const APFloat &a, const APFloat &b) { return a - b; });
598 }
599 
600 //===----------------------------------------------------------------------===//
601 // MaxFOp
602 //===----------------------------------------------------------------------===//
603 
604 OpFoldResult arith::MaxFOp::fold(ArrayRef<Attribute> operands) {
605   assert(operands.size() == 2 && "maxf takes two operands");
606 
607   // maxf(x,x) -> x
608   if (getLhs() == getRhs())
609     return getRhs();
610 
611   // maxf(x, -inf) -> x
612   if (matchPattern(getRhs(), m_NegInfFloat()))
613     return getLhs();
614 
615   return constFoldBinaryOp<FloatAttr>(
616       operands,
617       [](const APFloat &a, const APFloat &b) { return llvm::maximum(a, b); });
618 }
619 
620 //===----------------------------------------------------------------------===//
621 // MaxSIOp
622 //===----------------------------------------------------------------------===//
623 
624 OpFoldResult MaxSIOp::fold(ArrayRef<Attribute> operands) {
625   assert(operands.size() == 2 && "binary operation takes two operands");
626 
627   // maxsi(x,x) -> x
628   if (getLhs() == getRhs())
629     return getRhs();
630 
631   APInt intValue;
632   // maxsi(x,MAX_INT) -> MAX_INT
633   if (matchPattern(getRhs(), m_ConstantInt(&intValue)) &&
634       intValue.isMaxSignedValue())
635     return getRhs();
636 
637   // maxsi(x, MIN_INT) -> x
638   if (matchPattern(getRhs(), m_ConstantInt(&intValue)) &&
639       intValue.isMinSignedValue())
640     return getLhs();
641 
642   return constFoldBinaryOp<IntegerAttr>(operands,
643                                         [](const APInt &a, const APInt &b) {
644                                           return llvm::APIntOps::smax(a, b);
645                                         });
646 }
647 
648 //===----------------------------------------------------------------------===//
649 // MaxUIOp
650 //===----------------------------------------------------------------------===//
651 
652 OpFoldResult MaxUIOp::fold(ArrayRef<Attribute> operands) {
653   assert(operands.size() == 2 && "binary operation takes two operands");
654 
655   // maxui(x,x) -> x
656   if (getLhs() == getRhs())
657     return getRhs();
658 
659   APInt intValue;
660   // maxui(x,MAX_INT) -> MAX_INT
661   if (matchPattern(getRhs(), m_ConstantInt(&intValue)) && intValue.isMaxValue())
662     return getRhs();
663 
664   // maxui(x, MIN_INT) -> x
665   if (matchPattern(getRhs(), m_ConstantInt(&intValue)) && intValue.isMinValue())
666     return getLhs();
667 
668   return constFoldBinaryOp<IntegerAttr>(operands,
669                                         [](const APInt &a, const APInt &b) {
670                                           return llvm::APIntOps::umax(a, b);
671                                         });
672 }
673 
674 //===----------------------------------------------------------------------===//
675 // MinFOp
676 //===----------------------------------------------------------------------===//
677 
678 OpFoldResult arith::MinFOp::fold(ArrayRef<Attribute> operands) {
679   assert(operands.size() == 2 && "minf takes two operands");
680 
681   // minf(x,x) -> x
682   if (getLhs() == getRhs())
683     return getRhs();
684 
685   // minf(x, +inf) -> x
686   if (matchPattern(getRhs(), m_PosInfFloat()))
687     return getLhs();
688 
689   return constFoldBinaryOp<FloatAttr>(
690       operands,
691       [](const APFloat &a, const APFloat &b) { return llvm::minimum(a, b); });
692 }
693 
694 //===----------------------------------------------------------------------===//
695 // MinSIOp
696 //===----------------------------------------------------------------------===//
697 
698 OpFoldResult MinSIOp::fold(ArrayRef<Attribute> operands) {
699   assert(operands.size() == 2 && "binary operation takes two operands");
700 
701   // minsi(x,x) -> x
702   if (getLhs() == getRhs())
703     return getRhs();
704 
705   APInt intValue;
706   // minsi(x,MIN_INT) -> MIN_INT
707   if (matchPattern(getRhs(), m_ConstantInt(&intValue)) &&
708       intValue.isMinSignedValue())
709     return getRhs();
710 
711   // minsi(x, MAX_INT) -> x
712   if (matchPattern(getRhs(), m_ConstantInt(&intValue)) &&
713       intValue.isMaxSignedValue())
714     return getLhs();
715 
716   return constFoldBinaryOp<IntegerAttr>(operands,
717                                         [](const APInt &a, const APInt &b) {
718                                           return llvm::APIntOps::smin(a, b);
719                                         });
720 }
721 
722 //===----------------------------------------------------------------------===//
723 // MinUIOp
724 //===----------------------------------------------------------------------===//
725 
726 OpFoldResult MinUIOp::fold(ArrayRef<Attribute> operands) {
727   assert(operands.size() == 2 && "binary operation takes two operands");
728 
729   // minui(x,x) -> x
730   if (getLhs() == getRhs())
731     return getRhs();
732 
733   APInt intValue;
734   // minui(x,MIN_INT) -> MIN_INT
735   if (matchPattern(getRhs(), m_ConstantInt(&intValue)) && intValue.isMinValue())
736     return getRhs();
737 
738   // minui(x, MAX_INT) -> x
739   if (matchPattern(getRhs(), m_ConstantInt(&intValue)) && intValue.isMaxValue())
740     return getLhs();
741 
742   return constFoldBinaryOp<IntegerAttr>(operands,
743                                         [](const APInt &a, const APInt &b) {
744                                           return llvm::APIntOps::umin(a, b);
745                                         });
746 }
747 
748 //===----------------------------------------------------------------------===//
749 // MulFOp
750 //===----------------------------------------------------------------------===//
751 
752 OpFoldResult arith::MulFOp::fold(ArrayRef<Attribute> operands) {
753   APFloat floatValue(0.0f), inverseValue(0.0f);
754   // mulf(x, 1) -> x
755   if (matchPattern(getRhs(), m_OneFloat()))
756     return getLhs();
757 
758   // mulf(1, x) -> x
759   if (matchPattern(getLhs(), m_OneFloat()))
760     return getRhs();
761 
762   return constFoldBinaryOp<FloatAttr>(
763       operands, [](const APFloat &a, const APFloat &b) { return a * b; });
764 }
765 
766 //===----------------------------------------------------------------------===//
767 // DivFOp
768 //===----------------------------------------------------------------------===//
769 
770 OpFoldResult arith::DivFOp::fold(ArrayRef<Attribute> operands) {
771   APFloat floatValue(0.0f), inverseValue(0.0f);
772   // divf(x, 1) -> x
773   if (matchPattern(getRhs(), m_OneFloat()))
774     return getLhs();
775 
776   return constFoldBinaryOp<FloatAttr>(
777       operands, [](const APFloat &a, const APFloat &b) { return a / b; });
778 }
779 
780 //===----------------------------------------------------------------------===//
781 // Utility functions for verifying cast ops
782 //===----------------------------------------------------------------------===//
783 
/// Tag type that carries a pack of types into a function call. Callers pass a
/// value-initialized (null) pointer, so no std::tuple object is ever created;
/// only the static type of the argument matters.
template <typename... Ts>
using type_list = std::tuple<Ts...> *;
786 
787 /// Returns a non-null type only if the provided type is one of the allowed
788 /// types or one of the allowed shaped types of the allowed types. Returns the
789 /// element type if a valid shaped type is provided.
790 template <typename... ShapedTypes, typename... ElementTypes>
791 static Type getUnderlyingType(Type type, type_list<ShapedTypes...>,
792                               type_list<ElementTypes...>) {
793   if (type.isa<ShapedType>() && !type.isa<ShapedTypes...>())
794     return {};
795 
796   auto underlyingType = getElementTypeOrSelf(type);
797   if (!underlyingType.isa<ElementTypes...>())
798     return {};
799 
800   return underlyingType;
801 }
802 
803 /// Get allowed underlying types for vectors and tensors.
804 template <typename... ElementTypes>
805 static Type getTypeIfLike(Type type) {
806   return getUnderlyingType(type, type_list<VectorType, TensorType>(),
807                            type_list<ElementTypes...>());
808 }
809 
810 /// Get allowed underlying types for vectors, tensors, and memrefs.
811 template <typename... ElementTypes>
812 static Type getTypeIfLikeOrMemRef(Type type) {
813   return getUnderlyingType(type,
814                            type_list<VectorType, TensorType, MemRefType>(),
815                            type_list<ElementTypes...>());
816 }
817 
818 static bool areValidCastInputsAndOutputs(TypeRange inputs, TypeRange outputs) {
819   return inputs.size() == 1 && outputs.size() == 1 &&
820          succeeded(verifyCompatibleShapes(inputs.front(), outputs.front()));
821 }
822 
823 //===----------------------------------------------------------------------===//
824 // Verifiers for integer and floating point extension/truncation ops
825 //===----------------------------------------------------------------------===//
826 
827 // Extend ops can only extend to a wider type.
828 template <typename ValType, typename Op>
829 static LogicalResult verifyExtOp(Op op) {
830   Type srcType = getElementTypeOrSelf(op.getIn().getType());
831   Type dstType = getElementTypeOrSelf(op.getType());
832 
833   if (srcType.cast<ValType>().getWidth() >= dstType.cast<ValType>().getWidth())
834     return op.emitError("result type ")
835            << dstType << " must be wider than operand type " << srcType;
836 
837   return success();
838 }
839 
840 // Truncate ops can only truncate to a shorter type.
841 template <typename ValType, typename Op>
842 static LogicalResult verifyTruncateOp(Op op) {
843   Type srcType = getElementTypeOrSelf(op.getIn().getType());
844   Type dstType = getElementTypeOrSelf(op.getType());
845 
846   if (srcType.cast<ValType>().getWidth() <= dstType.cast<ValType>().getWidth())
847     return op.emitError("result type ")
848            << dstType << " must be shorter than operand type " << srcType;
849 
850   return success();
851 }
852 
853 /// Validate a cast that changes the width of a type.
854 template <template <typename> class WidthComparator, typename... ElementTypes>
855 static bool checkWidthChangeCast(TypeRange inputs, TypeRange outputs) {
856   if (!areValidCastInputsAndOutputs(inputs, outputs))
857     return false;
858 
859   auto srcType = getTypeIfLike<ElementTypes...>(inputs.front());
860   auto dstType = getTypeIfLike<ElementTypes...>(outputs.front());
861   if (!srcType || !dstType)
862     return false;
863 
864   return WidthComparator<unsigned>()(dstType.getIntOrFloatBitWidth(),
865                                      srcType.getIntOrFloatBitWidth());
866 }
867 
868 //===----------------------------------------------------------------------===//
869 // ExtUIOp
870 //===----------------------------------------------------------------------===//
871 
872 OpFoldResult arith::ExtUIOp::fold(ArrayRef<Attribute> operands) {
873   if (auto lhs = operands[0].dyn_cast_or_null<IntegerAttr>())
874     return IntegerAttr::get(
875         getType(), lhs.getValue().zext(getType().getIntOrFloatBitWidth()));
876 
877   if (auto lhs = getIn().getDefiningOp<ExtUIOp>()) {
878     getInMutable().assign(lhs.getIn());
879     return getResult();
880   }
881 
882   return {};
883 }
884 
885 bool arith::ExtUIOp::areCastCompatible(TypeRange inputs, TypeRange outputs) {
886   return checkWidthChangeCast<std::greater, IntegerType>(inputs, outputs);
887 }
888 
889 //===----------------------------------------------------------------------===//
890 // ExtSIOp
891 //===----------------------------------------------------------------------===//
892 
893 OpFoldResult arith::ExtSIOp::fold(ArrayRef<Attribute> operands) {
894   if (auto lhs = operands[0].dyn_cast_or_null<IntegerAttr>())
895     return IntegerAttr::get(
896         getType(), lhs.getValue().sext(getType().getIntOrFloatBitWidth()));
897 
898   if (auto lhs = getIn().getDefiningOp<ExtSIOp>()) {
899     getInMutable().assign(lhs.getIn());
900     return getResult();
901   }
902 
903   return {};
904 }
905 
906 bool arith::ExtSIOp::areCastCompatible(TypeRange inputs, TypeRange outputs) {
907   return checkWidthChangeCast<std::greater, IntegerType>(inputs, outputs);
908 }
909 
910 void arith::ExtSIOp::getCanonicalizationPatterns(
911     RewritePatternSet &patterns, MLIRContext *context) {
912   patterns.insert<ExtSIOfExtUI>(context);
913 }
914 
915 //===----------------------------------------------------------------------===//
916 // ExtFOp
917 //===----------------------------------------------------------------------===//
918 
919 bool arith::ExtFOp::areCastCompatible(TypeRange inputs, TypeRange outputs) {
920   return checkWidthChangeCast<std::greater, FloatType>(inputs, outputs);
921 }
922 
923 //===----------------------------------------------------------------------===//
924 // TruncIOp
925 //===----------------------------------------------------------------------===//
926 
927 OpFoldResult arith::TruncIOp::fold(ArrayRef<Attribute> operands) {
928   assert(operands.size() == 1 && "unary operation takes one operand");
929 
930   // trunci(zexti(a)) -> a
931   // trunci(sexti(a)) -> a
932   if (matchPattern(getOperand(), m_Op<arith::ExtUIOp>()) ||
933       matchPattern(getOperand(), m_Op<arith::ExtSIOp>()))
934     return getOperand().getDefiningOp()->getOperand(0);
935 
936   // trunci(trunci(a)) -> trunci(a))
937   if (matchPattern(getOperand(), m_Op<arith::TruncIOp>())) {
938     setOperand(getOperand().getDefiningOp()->getOperand(0));
939     return getResult();
940   }
941 
942   if (!operands[0])
943     return {};
944 
945   if (auto lhs = operands[0].dyn_cast<IntegerAttr>()) {
946     return IntegerAttr::get(
947         getType(), lhs.getValue().trunc(getType().getIntOrFloatBitWidth()));
948   }
949 
950   return {};
951 }
952 
953 bool arith::TruncIOp::areCastCompatible(TypeRange inputs, TypeRange outputs) {
954   return checkWidthChangeCast<std::less, IntegerType>(inputs, outputs);
955 }
956 
957 //===----------------------------------------------------------------------===//
958 // TruncFOp
959 //===----------------------------------------------------------------------===//
960 
961 /// Perform safe const propagation for truncf, i.e. only propagate if FP value
962 /// can be represented without precision loss or rounding.
963 OpFoldResult arith::TruncFOp::fold(ArrayRef<Attribute> operands) {
964   assert(operands.size() == 1 && "unary operation takes one operand");
965 
966   auto constOperand = operands.front();
967   if (!constOperand || !constOperand.isa<FloatAttr>())
968     return {};
969 
970   // Convert to target type via 'double'.
971   double sourceValue =
972       constOperand.dyn_cast<FloatAttr>().getValue().convertToDouble();
973   auto targetAttr = FloatAttr::get(getType(), sourceValue);
974 
975   // Propagate if constant's value does not change after truncation.
976   if (sourceValue == targetAttr.getValue().convertToDouble())
977     return targetAttr;
978 
979   return {};
980 }
981 
982 bool arith::TruncFOp::areCastCompatible(TypeRange inputs, TypeRange outputs) {
983   return checkWidthChangeCast<std::less, FloatType>(inputs, outputs);
984 }
985 
986 //===----------------------------------------------------------------------===//
987 // AndIOp
988 //===----------------------------------------------------------------------===//
989 
990 void arith::AndIOp::getCanonicalizationPatterns(
991     RewritePatternSet &patterns, MLIRContext *context) {
992   patterns.insert<AndOfExtUI, AndOfExtSI>(context);
993 }
994 
995 //===----------------------------------------------------------------------===//
996 // OrIOp
997 //===----------------------------------------------------------------------===//
998 
999 void arith::OrIOp::getCanonicalizationPatterns(
1000     RewritePatternSet &patterns, MLIRContext *context) {
1001   patterns.insert<OrOfExtUI, OrOfExtSI>(context);
1002 }
1003 
1004 //===----------------------------------------------------------------------===//
1005 // Verifiers for casts between integers and floats.
1006 //===----------------------------------------------------------------------===//
1007 
1008 template <typename From, typename To>
1009 static bool checkIntFloatCast(TypeRange inputs, TypeRange outputs) {
1010   if (!areValidCastInputsAndOutputs(inputs, outputs))
1011     return false;
1012 
1013   auto srcType = getTypeIfLike<From>(inputs.front());
1014   auto dstType = getTypeIfLike<To>(outputs.back());
1015 
1016   return srcType && dstType;
1017 }
1018 
1019 //===----------------------------------------------------------------------===//
1020 // UIToFPOp
1021 //===----------------------------------------------------------------------===//
1022 
1023 bool arith::UIToFPOp::areCastCompatible(TypeRange inputs, TypeRange outputs) {
1024   return checkIntFloatCast<IntegerType, FloatType>(inputs, outputs);
1025 }
1026 
1027 OpFoldResult arith::UIToFPOp::fold(ArrayRef<Attribute> operands) {
1028   if (auto lhs = operands[0].dyn_cast_or_null<IntegerAttr>()) {
1029     const APInt &api = lhs.getValue();
1030     FloatType floatTy = getType().cast<FloatType>();
1031     APFloat apf(floatTy.getFloatSemantics(),
1032                 APInt::getZero(floatTy.getWidth()));
1033     apf.convertFromAPInt(api, /*IsSigned=*/false, APFloat::rmNearestTiesToEven);
1034     return FloatAttr::get(floatTy, apf);
1035   }
1036   return {};
1037 }
1038 
1039 //===----------------------------------------------------------------------===//
1040 // SIToFPOp
1041 //===----------------------------------------------------------------------===//
1042 
1043 bool arith::SIToFPOp::areCastCompatible(TypeRange inputs, TypeRange outputs) {
1044   return checkIntFloatCast<IntegerType, FloatType>(inputs, outputs);
1045 }
1046 
1047 OpFoldResult arith::SIToFPOp::fold(ArrayRef<Attribute> operands) {
1048   if (auto lhs = operands[0].dyn_cast_or_null<IntegerAttr>()) {
1049     const APInt &api = lhs.getValue();
1050     FloatType floatTy = getType().cast<FloatType>();
1051     APFloat apf(floatTy.getFloatSemantics(),
1052                 APInt::getZero(floatTy.getWidth()));
1053     apf.convertFromAPInt(api, /*IsSigned=*/true, APFloat::rmNearestTiesToEven);
1054     return FloatAttr::get(floatTy, apf);
1055   }
1056   return {};
1057 }
1058 //===----------------------------------------------------------------------===//
1059 // FPToUIOp
1060 //===----------------------------------------------------------------------===//
1061 
1062 bool arith::FPToUIOp::areCastCompatible(TypeRange inputs, TypeRange outputs) {
1063   return checkIntFloatCast<FloatType, IntegerType>(inputs, outputs);
1064 }
1065 
1066 OpFoldResult arith::FPToUIOp::fold(ArrayRef<Attribute> operands) {
1067   if (auto lhs = operands[0].dyn_cast_or_null<FloatAttr>()) {
1068     const APFloat &apf = lhs.getValue();
1069     IntegerType intTy = getType().cast<IntegerType>();
1070     bool ignored;
1071     APSInt api(intTy.getWidth(), /*isUnsigned=*/true);
1072     if (APFloat::opInvalidOp ==
1073         apf.convertToInteger(api, APFloat::rmTowardZero, &ignored)) {
1074       // Undefined behavior invoked - the destination type can't represent
1075       // the input constant.
1076       return {};
1077     }
1078     return IntegerAttr::get(getType(), api);
1079   }
1080 
1081   return {};
1082 }
1083 
1084 //===----------------------------------------------------------------------===//
1085 // FPToSIOp
1086 //===----------------------------------------------------------------------===//
1087 
1088 bool arith::FPToSIOp::areCastCompatible(TypeRange inputs, TypeRange outputs) {
1089   return checkIntFloatCast<FloatType, IntegerType>(inputs, outputs);
1090 }
1091 
1092 OpFoldResult arith::FPToSIOp::fold(ArrayRef<Attribute> operands) {
1093   if (auto lhs = operands[0].dyn_cast_or_null<FloatAttr>()) {
1094     const APFloat &apf = lhs.getValue();
1095     IntegerType intTy = getType().cast<IntegerType>();
1096     bool ignored;
1097     APSInt api(intTy.getWidth(), /*isUnsigned=*/false);
1098     if (APFloat::opInvalidOp ==
1099         apf.convertToInteger(api, APFloat::rmTowardZero, &ignored)) {
1100       // Undefined behavior invoked - the destination type can't represent
1101       // the input constant.
1102       return {};
1103     }
1104     return IntegerAttr::get(getType(), api);
1105   }
1106 
1107   return {};
1108 }
1109 
1110 //===----------------------------------------------------------------------===//
1111 // IndexCastOp
1112 //===----------------------------------------------------------------------===//
1113 
1114 bool arith::IndexCastOp::areCastCompatible(TypeRange inputs,
1115                                            TypeRange outputs) {
1116   if (!areValidCastInputsAndOutputs(inputs, outputs))
1117     return false;
1118 
1119   auto srcType = getTypeIfLikeOrMemRef<IntegerType, IndexType>(inputs.front());
1120   auto dstType = getTypeIfLikeOrMemRef<IntegerType, IndexType>(outputs.front());
1121   if (!srcType || !dstType)
1122     return false;
1123 
1124   return (srcType.isIndex() && dstType.isSignlessInteger()) ||
1125          (srcType.isSignlessInteger() && dstType.isIndex());
1126 }
1127 
1128 OpFoldResult arith::IndexCastOp::fold(ArrayRef<Attribute> operands) {
1129   // index_cast(constant) -> constant
1130   // A little hack because we go through int. Otherwise, the size of the
1131   // constant might need to change.
1132   if (auto value = operands[0].dyn_cast_or_null<IntegerAttr>())
1133     return IntegerAttr::get(getType(), value.getInt());
1134 
1135   return {};
1136 }
1137 
1138 void arith::IndexCastOp::getCanonicalizationPatterns(
1139     RewritePatternSet &patterns, MLIRContext *context) {
1140   patterns.insert<IndexCastOfIndexCast, IndexCastOfExtSI>(context);
1141 }
1142 
1143 //===----------------------------------------------------------------------===//
1144 // BitcastOp
1145 //===----------------------------------------------------------------------===//
1146 
1147 bool arith::BitcastOp::areCastCompatible(TypeRange inputs, TypeRange outputs) {
1148   if (!areValidCastInputsAndOutputs(inputs, outputs))
1149     return false;
1150 
1151   auto srcType =
1152       getTypeIfLikeOrMemRef<IntegerType, IndexType, FloatType>(inputs.front());
1153   auto dstType =
1154       getTypeIfLikeOrMemRef<IntegerType, IndexType, FloatType>(outputs.front());
1155   if (!srcType || !dstType)
1156     return false;
1157 
1158   return srcType.getIntOrFloatBitWidth() == dstType.getIntOrFloatBitWidth();
1159 }
1160 
1161 OpFoldResult arith::BitcastOp::fold(ArrayRef<Attribute> operands) {
1162   assert(operands.size() == 1 && "bitcast op expects 1 operand");
1163 
1164   auto resType = getType();
1165   auto operand = operands[0];
1166   if (!operand)
1167     return {};
1168 
1169   /// Bitcast dense elements.
1170   if (auto denseAttr = operand.dyn_cast_or_null<DenseElementsAttr>())
1171     return denseAttr.bitcast(resType.cast<ShapedType>().getElementType());
1172   /// Other shaped types unhandled.
1173   if (resType.isa<ShapedType>())
1174     return {};
1175 
1176   /// Bitcast integer or float to integer or float.
1177   APInt bits = operand.isa<FloatAttr>()
1178                    ? operand.cast<FloatAttr>().getValue().bitcastToAPInt()
1179                    : operand.cast<IntegerAttr>().getValue();
1180 
1181   if (auto resFloatType = resType.dyn_cast<FloatType>())
1182     return FloatAttr::get(resType,
1183                           APFloat(resFloatType.getFloatSemantics(), bits));
1184   return IntegerAttr::get(resType, bits);
1185 }
1186 
1187 void arith::BitcastOp::getCanonicalizationPatterns(
1188     RewritePatternSet &patterns, MLIRContext *context) {
1189   patterns.insert<BitcastOfBitcast>(context);
1190 }
1191 
1192 //===----------------------------------------------------------------------===//
1193 // Helpers for compare ops
1194 //===----------------------------------------------------------------------===//
1195 
1196 /// Return the type of the same shape (scalar, vector or tensor) containing i1.
1197 static Type getI1SameShape(Type type) {
1198   auto i1Type = IntegerType::get(type.getContext(), 1);
1199   if (auto tensorType = type.dyn_cast<RankedTensorType>())
1200     return RankedTensorType::get(tensorType.getShape(), i1Type);
1201   if (type.isa<UnrankedTensorType>())
1202     return UnrankedTensorType::get(i1Type);
1203   if (auto vectorType = type.dyn_cast<VectorType>())
1204     return VectorType::get(vectorType.getShape(), i1Type,
1205                            vectorType.getNumScalableDims());
1206   return i1Type;
1207 }
1208 
1209 //===----------------------------------------------------------------------===//
1210 // CmpIOp
1211 //===----------------------------------------------------------------------===//
1212 
1213 /// Compute `lhs` `pred` `rhs`, where `pred` is one of the known integer
1214 /// comparison predicates.
1215 bool mlir::arith::applyCmpPredicate(arith::CmpIPredicate predicate,
1216                                     const APInt &lhs, const APInt &rhs) {
1217   switch (predicate) {
1218   case arith::CmpIPredicate::eq:
1219     return lhs.eq(rhs);
1220   case arith::CmpIPredicate::ne:
1221     return lhs.ne(rhs);
1222   case arith::CmpIPredicate::slt:
1223     return lhs.slt(rhs);
1224   case arith::CmpIPredicate::sle:
1225     return lhs.sle(rhs);
1226   case arith::CmpIPredicate::sgt:
1227     return lhs.sgt(rhs);
1228   case arith::CmpIPredicate::sge:
1229     return lhs.sge(rhs);
1230   case arith::CmpIPredicate::ult:
1231     return lhs.ult(rhs);
1232   case arith::CmpIPredicate::ule:
1233     return lhs.ule(rhs);
1234   case arith::CmpIPredicate::ugt:
1235     return lhs.ugt(rhs);
1236   case arith::CmpIPredicate::uge:
1237     return lhs.uge(rhs);
1238   }
1239   llvm_unreachable("unknown cmpi predicate kind");
1240 }
1241 
1242 /// Returns true if the predicate is true for two equal operands.
1243 static bool applyCmpPredicateToEqualOperands(arith::CmpIPredicate predicate) {
1244   switch (predicate) {
1245   case arith::CmpIPredicate::eq:
1246   case arith::CmpIPredicate::sle:
1247   case arith::CmpIPredicate::sge:
1248   case arith::CmpIPredicate::ule:
1249   case arith::CmpIPredicate::uge:
1250     return true;
1251   case arith::CmpIPredicate::ne:
1252   case arith::CmpIPredicate::slt:
1253   case arith::CmpIPredicate::sgt:
1254   case arith::CmpIPredicate::ult:
1255   case arith::CmpIPredicate::ugt:
1256     return false;
1257   }
1258   llvm_unreachable("unknown cmpi predicate kind");
1259 }
1260 
1261 static Attribute getBoolAttribute(Type type, MLIRContext *ctx, bool value) {
1262   auto boolAttr = BoolAttr::get(ctx, value);
1263   ShapedType shapedType = type.dyn_cast_or_null<ShapedType>();
1264   if (!shapedType)
1265     return boolAttr;
1266   return DenseElementsAttr::get(shapedType, boolAttr);
1267 }
1268 
1269 OpFoldResult arith::CmpIOp::fold(ArrayRef<Attribute> operands) {
1270   assert(operands.size() == 2 && "cmpi takes two operands");
1271 
1272   // cmpi(pred, x, x)
1273   if (getLhs() == getRhs()) {
1274     auto val = applyCmpPredicateToEqualOperands(getPredicate());
1275     return getBoolAttribute(getType(), getContext(), val);
1276   }
1277 
1278   if (matchPattern(getRhs(), m_Zero())) {
1279     if (auto extOp = getLhs().getDefiningOp<ExtSIOp>()) {
1280       if (extOp.getOperand().getType().cast<IntegerType>().getWidth() == 1) {
1281         // extsi(%x : i1 -> iN) != 0  ->  %x
1282         if (getPredicate() == arith::CmpIPredicate::ne) {
1283           return extOp.getOperand();
1284         }
1285       }
1286     }
1287     if (auto extOp = getLhs().getDefiningOp<ExtUIOp>()) {
1288       if (extOp.getOperand().getType().cast<IntegerType>().getWidth() == 1) {
1289         // extui(%x : i1 -> iN) != 0  ->  %x
1290         if (getPredicate() == arith::CmpIPredicate::ne) {
1291           return extOp.getOperand();
1292         }
1293       }
1294     }
1295   }
1296 
1297   auto lhs = operands.front().dyn_cast_or_null<IntegerAttr>();
1298   auto rhs = operands.back().dyn_cast_or_null<IntegerAttr>();
1299   if (!lhs || !rhs)
1300     return {};
1301 
1302   auto val = applyCmpPredicate(getPredicate(), lhs.getValue(), rhs.getValue());
1303   return BoolAttr::get(getContext(), val);
1304 }
1305 
1306 //===----------------------------------------------------------------------===//
1307 // CmpFOp
1308 //===----------------------------------------------------------------------===//
1309 
1310 /// Compute `lhs` `pred` `rhs`, where `pred` is one of the known floating point
1311 /// comparison predicates.
1312 bool mlir::arith::applyCmpPredicate(arith::CmpFPredicate predicate,
1313                                     const APFloat &lhs, const APFloat &rhs) {
1314   auto cmpResult = lhs.compare(rhs);
1315   switch (predicate) {
1316   case arith::CmpFPredicate::AlwaysFalse:
1317     return false;
1318   case arith::CmpFPredicate::OEQ:
1319     return cmpResult == APFloat::cmpEqual;
1320   case arith::CmpFPredicate::OGT:
1321     return cmpResult == APFloat::cmpGreaterThan;
1322   case arith::CmpFPredicate::OGE:
1323     return cmpResult == APFloat::cmpGreaterThan ||
1324            cmpResult == APFloat::cmpEqual;
1325   case arith::CmpFPredicate::OLT:
1326     return cmpResult == APFloat::cmpLessThan;
1327   case arith::CmpFPredicate::OLE:
1328     return cmpResult == APFloat::cmpLessThan || cmpResult == APFloat::cmpEqual;
1329   case arith::CmpFPredicate::ONE:
1330     return cmpResult != APFloat::cmpUnordered && cmpResult != APFloat::cmpEqual;
1331   case arith::CmpFPredicate::ORD:
1332     return cmpResult != APFloat::cmpUnordered;
1333   case arith::CmpFPredicate::UEQ:
1334     return cmpResult == APFloat::cmpUnordered || cmpResult == APFloat::cmpEqual;
1335   case arith::CmpFPredicate::UGT:
1336     return cmpResult == APFloat::cmpUnordered ||
1337            cmpResult == APFloat::cmpGreaterThan;
1338   case arith::CmpFPredicate::UGE:
1339     return cmpResult == APFloat::cmpUnordered ||
1340            cmpResult == APFloat::cmpGreaterThan ||
1341            cmpResult == APFloat::cmpEqual;
1342   case arith::CmpFPredicate::ULT:
1343     return cmpResult == APFloat::cmpUnordered ||
1344            cmpResult == APFloat::cmpLessThan;
1345   case arith::CmpFPredicate::ULE:
1346     return cmpResult == APFloat::cmpUnordered ||
1347            cmpResult == APFloat::cmpLessThan || cmpResult == APFloat::cmpEqual;
1348   case arith::CmpFPredicate::UNE:
1349     return cmpResult != APFloat::cmpEqual;
1350   case arith::CmpFPredicate::UNO:
1351     return cmpResult == APFloat::cmpUnordered;
1352   case arith::CmpFPredicate::AlwaysTrue:
1353     return true;
1354   }
1355   llvm_unreachable("unknown cmpf predicate kind");
1356 }
1357 
1358 OpFoldResult arith::CmpFOp::fold(ArrayRef<Attribute> operands) {
1359   assert(operands.size() == 2 && "cmpf takes two operands");
1360 
1361   auto lhs = operands.front().dyn_cast_or_null<FloatAttr>();
1362   auto rhs = operands.back().dyn_cast_or_null<FloatAttr>();
1363 
1364   // If one operand is NaN, making them both NaN does not change the result.
1365   if (lhs && lhs.getValue().isNaN())
1366     rhs = lhs;
1367   if (rhs && rhs.getValue().isNaN())
1368     lhs = rhs;
1369 
1370   if (!lhs || !rhs)
1371     return {};
1372 
1373   auto val = applyCmpPredicate(getPredicate(), lhs.getValue(), rhs.getValue());
1374   return BoolAttr::get(getContext(), val);
1375 }
1376 
1377 //===----------------------------------------------------------------------===//
1378 // Atomic Enum
1379 //===----------------------------------------------------------------------===//
1380 
1381 /// Returns the identity value attribute associated with an AtomicRMWKind op.
1382 Attribute mlir::arith::getIdentityValueAttr(AtomicRMWKind kind, Type resultType,
1383                                             OpBuilder &builder, Location loc) {
1384   switch (kind) {
1385   case AtomicRMWKind::maxf:
1386     return builder.getFloatAttr(
1387         resultType,
1388         APFloat::getInf(resultType.cast<FloatType>().getFloatSemantics(),
1389                         /*Negative=*/true));
1390   case AtomicRMWKind::addf:
1391   case AtomicRMWKind::addi:
1392   case AtomicRMWKind::maxu:
1393   case AtomicRMWKind::ori:
1394     return builder.getZeroAttr(resultType);
1395   case AtomicRMWKind::andi:
1396     return builder.getIntegerAttr(
1397         resultType,
1398         APInt::getAllOnes(resultType.cast<IntegerType>().getWidth()));
1399   case AtomicRMWKind::maxs:
1400     return builder.getIntegerAttr(
1401         resultType,
1402         APInt::getSignedMinValue(resultType.cast<IntegerType>().getWidth()));
1403   case AtomicRMWKind::minf:
1404     return builder.getFloatAttr(
1405         resultType,
1406         APFloat::getInf(resultType.cast<FloatType>().getFloatSemantics(),
1407                         /*Negative=*/false));
1408   case AtomicRMWKind::mins:
1409     return builder.getIntegerAttr(
1410         resultType,
1411         APInt::getSignedMaxValue(resultType.cast<IntegerType>().getWidth()));
1412   case AtomicRMWKind::minu:
1413     return builder.getIntegerAttr(
1414         resultType,
1415         APInt::getMaxValue(resultType.cast<IntegerType>().getWidth()));
1416   case AtomicRMWKind::muli:
1417     return builder.getIntegerAttr(resultType, 1);
1418   case AtomicRMWKind::mulf:
1419     return builder.getFloatAttr(resultType, 1);
1420   // TODO: Add remaining reduction operations.
1421   default:
1422     (void)emitOptionalError(loc, "Reduction operation type not supported");
1423     break;
1424   }
1425   return nullptr;
1426 }
1427 
1428 /// Returns the identity value associated with an AtomicRMWKind op.
1429 Value mlir::arith::getIdentityValue(AtomicRMWKind op, Type resultType,
1430                                     OpBuilder &builder, Location loc) {
1431   Attribute attr = getIdentityValueAttr(op, resultType, builder, loc);
1432   return builder.create<arith::ConstantOp>(loc, attr);
1433 }
1434 
1435 /// Return the value obtained by applying the reduction operation kind
1436 /// associated with a binary AtomicRMWKind op to `lhs` and `rhs`.
1437 Value mlir::arith::getReductionOp(AtomicRMWKind op, OpBuilder &builder,
1438                                   Location loc, Value lhs, Value rhs) {
1439   switch (op) {
1440   case AtomicRMWKind::addf:
1441     return builder.create<arith::AddFOp>(loc, lhs, rhs);
1442   case AtomicRMWKind::addi:
1443     return builder.create<arith::AddIOp>(loc, lhs, rhs);
1444   case AtomicRMWKind::mulf:
1445     return builder.create<arith::MulFOp>(loc, lhs, rhs);
1446   case AtomicRMWKind::muli:
1447     return builder.create<arith::MulIOp>(loc, lhs, rhs);
1448   case AtomicRMWKind::maxf:
1449     return builder.create<arith::MaxFOp>(loc, lhs, rhs);
1450   case AtomicRMWKind::minf:
1451     return builder.create<arith::MinFOp>(loc, lhs, rhs);
1452   case AtomicRMWKind::maxs:
1453     return builder.create<arith::MaxSIOp>(loc, lhs, rhs);
1454   case AtomicRMWKind::mins:
1455     return builder.create<arith::MinSIOp>(loc, lhs, rhs);
1456   case AtomicRMWKind::maxu:
1457     return builder.create<arith::MaxUIOp>(loc, lhs, rhs);
1458   case AtomicRMWKind::minu:
1459     return builder.create<arith::MinUIOp>(loc, lhs, rhs);
1460   case AtomicRMWKind::ori:
1461     return builder.create<arith::OrIOp>(loc, lhs, rhs);
1462   case AtomicRMWKind::andi:
1463     return builder.create<arith::AndIOp>(loc, lhs, rhs);
1464   // TODO: Add remaining reduction operations.
1465   default:
1466     (void)emitOptionalError(loc, "Reduction operation type not supported");
1467     break;
1468   }
1469   return nullptr;
1470 }
1471 
1472 //===----------------------------------------------------------------------===//
1473 // TableGen'd op method definitions
1474 //===----------------------------------------------------------------------===//
1475 
1476 #define GET_OP_CLASSES
1477 #include "mlir/Dialect/Arithmetic/IR/ArithmeticOps.cpp.inc"
1478 
1479 //===----------------------------------------------------------------------===//
1480 // TableGen'd enum attribute definitions
1481 //===----------------------------------------------------------------------===//
1482 
1483 #include "mlir/Dialect/Arithmetic/IR/ArithmeticOpsEnums.cpp.inc"
1484