1 //===- ArithmeticOps.cpp - MLIR Arithmetic dialect ops implementation -----===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 
9 #include <utility>
10 
11 #include "mlir/Dialect/Arithmetic/IR/Arithmetic.h"
12 #include "mlir/Dialect/CommonFolders.h"
13 #include "mlir/IR/Builders.h"
14 #include "mlir/IR/Matchers.h"
15 #include "mlir/IR/OpImplementation.h"
16 #include "mlir/IR/PatternMatch.h"
17 #include "mlir/IR/TypeUtilities.h"
18 
19 #include "llvm/ADT/APSInt.h"
20 
21 using namespace mlir;
22 using namespace mlir::arith;
23 
24 //===----------------------------------------------------------------------===//
25 // Pattern helpers
26 //===----------------------------------------------------------------------===//
27 
28 static IntegerAttr addIntegerAttrs(PatternRewriter &builder, Value res,
29                                    Attribute lhs, Attribute rhs) {
30   return builder.getIntegerAttr(res.getType(),
31                                 lhs.cast<IntegerAttr>().getInt() +
32                                     rhs.cast<IntegerAttr>().getInt());
33 }
34 
35 static IntegerAttr subIntegerAttrs(PatternRewriter &builder, Value res,
36                                    Attribute lhs, Attribute rhs) {
37   return builder.getIntegerAttr(res.getType(),
38                                 lhs.cast<IntegerAttr>().getInt() -
39                                     rhs.cast<IntegerAttr>().getInt());
40 }
41 
42 /// Invert an integer comparison predicate.
43 arith::CmpIPredicate arith::invertPredicate(arith::CmpIPredicate pred) {
44   switch (pred) {
45   case arith::CmpIPredicate::eq:
46     return arith::CmpIPredicate::ne;
47   case arith::CmpIPredicate::ne:
48     return arith::CmpIPredicate::eq;
49   case arith::CmpIPredicate::slt:
50     return arith::CmpIPredicate::sge;
51   case arith::CmpIPredicate::sle:
52     return arith::CmpIPredicate::sgt;
53   case arith::CmpIPredicate::sgt:
54     return arith::CmpIPredicate::sle;
55   case arith::CmpIPredicate::sge:
56     return arith::CmpIPredicate::slt;
57   case arith::CmpIPredicate::ult:
58     return arith::CmpIPredicate::uge;
59   case arith::CmpIPredicate::ule:
60     return arith::CmpIPredicate::ugt;
61   case arith::CmpIPredicate::ugt:
62     return arith::CmpIPredicate::ule;
63   case arith::CmpIPredicate::uge:
64     return arith::CmpIPredicate::ult;
65   }
66   llvm_unreachable("unknown cmpi predicate kind");
67 }
68 
69 static arith::CmpIPredicateAttr invertPredicate(arith::CmpIPredicateAttr pred) {
70   return arith::CmpIPredicateAttr::get(pred.getContext(),
71                                        invertPredicate(pred.getValue()));
72 }
73 
74 //===----------------------------------------------------------------------===//
75 // TableGen'd canonicalization patterns
76 //===----------------------------------------------------------------------===//
77 
78 namespace {
79 #include "ArithmeticCanonicalization.inc"
80 } // namespace
81 
82 //===----------------------------------------------------------------------===//
83 // ConstantOp
84 //===----------------------------------------------------------------------===//
85 
86 void arith::ConstantOp::getAsmResultNames(
87     function_ref<void(Value, StringRef)> setNameFn) {
88   auto type = getType();
89   if (auto intCst = getValue().dyn_cast<IntegerAttr>()) {
90     auto intType = type.dyn_cast<IntegerType>();
91 
92     // Sugar i1 constants with 'true' and 'false'.
93     if (intType && intType.getWidth() == 1)
94       return setNameFn(getResult(), (intCst.getInt() ? "true" : "false"));
95 
96     // Otherwise, build a compex name with the value and type.
97     SmallString<32> specialNameBuffer;
98     llvm::raw_svector_ostream specialName(specialNameBuffer);
99     specialName << 'c' << intCst.getInt();
100     if (intType)
101       specialName << '_' << type;
102     setNameFn(getResult(), specialName.str());
103   } else {
104     setNameFn(getResult(), "cst");
105   }
106 }
107 
108 /// TODO: disallow arith.constant to return anything other than signless integer
109 /// or float like.
110 LogicalResult arith::ConstantOp::verify() {
111   auto type = getType();
112   // The value's type must match the return type.
113   if (getValue().getType() != type) {
114     return emitOpError() << "value type " << getValue().getType()
115                          << " must match return type: " << type;
116   }
117   // Integer values must be signless.
118   if (type.isa<IntegerType>() && !type.cast<IntegerType>().isSignless())
119     return emitOpError("integer return type must be signless");
120   // Any float or elements attribute are acceptable.
121   if (!getValue().isa<IntegerAttr, FloatAttr, ElementsAttr>()) {
122     return emitOpError(
123         "value must be an integer, float, or elements attribute");
124   }
125   return success();
126 }
127 
128 bool arith::ConstantOp::isBuildableWith(Attribute value, Type type) {
129   // The value's type must be the same as the provided type.
130   if (value.getType() != type)
131     return false;
132   // Integer values must be signless.
133   if (type.isa<IntegerType>() && !type.cast<IntegerType>().isSignless())
134     return false;
135   // Integer, float, and element attributes are buildable.
136   return value.isa<IntegerAttr, FloatAttr, ElementsAttr>();
137 }
138 
139 OpFoldResult arith::ConstantOp::fold(ArrayRef<Attribute> operands) {
140   return getValue();
141 }
142 
143 void arith::ConstantIntOp::build(OpBuilder &builder, OperationState &result,
144                                  int64_t value, unsigned width) {
145   auto type = builder.getIntegerType(width);
146   arith::ConstantOp::build(builder, result, type,
147                            builder.getIntegerAttr(type, value));
148 }
149 
/// Builds an integer constant holding `value` of the given signless integer
/// `type` (asserted).
void arith::ConstantIntOp::build(OpBuilder &builder, OperationState &result,
                                 int64_t value, Type type) {
  assert(type.isSignlessInteger() &&
         "ConstantIntOp can only have signless integer type values");
  IntegerAttr valueAttr = builder.getIntegerAttr(type, value);
  arith::ConstantOp::build(builder, result, type, valueAttr);
}
157 
158 bool arith::ConstantIntOp::classof(Operation *op) {
159   if (auto constOp = dyn_cast_or_null<arith::ConstantOp>(op))
160     return constOp.getType().isSignlessInteger();
161   return false;
162 }
163 
164 void arith::ConstantFloatOp::build(OpBuilder &builder, OperationState &result,
165                                    const APFloat &value, FloatType type) {
166   arith::ConstantOp::build(builder, result, type,
167                            builder.getFloatAttr(type, value));
168 }
169 
170 bool arith::ConstantFloatOp::classof(Operation *op) {
171   if (auto constOp = dyn_cast_or_null<arith::ConstantOp>(op))
172     return constOp.getType().isa<FloatType>();
173   return false;
174 }
175 
176 void arith::ConstantIndexOp::build(OpBuilder &builder, OperationState &result,
177                                    int64_t value) {
178   arith::ConstantOp::build(builder, result, builder.getIndexType(),
179                            builder.getIndexAttr(value));
180 }
181 
182 bool arith::ConstantIndexOp::classof(Operation *op) {
183   if (auto constOp = dyn_cast_or_null<arith::ConstantOp>(op))
184     return constOp.getType().isIndex();
185   return false;
186 }
187 
188 //===----------------------------------------------------------------------===//
189 // AddIOp
190 //===----------------------------------------------------------------------===//
191 
192 OpFoldResult arith::AddIOp::fold(ArrayRef<Attribute> operands) {
193   // addi(x, 0) -> x
194   if (matchPattern(getRhs(), m_Zero()))
195     return getLhs();
196 
197   // addi(subi(a, b), b) -> a
198   if (auto sub = getLhs().getDefiningOp<SubIOp>())
199     if (getRhs() == sub.getRhs())
200       return sub.getLhs();
201 
202   // addi(b, subi(a, b)) -> a
203   if (auto sub = getRhs().getDefiningOp<SubIOp>())
204     if (getLhs() == sub.getRhs())
205       return sub.getLhs();
206 
207   return constFoldBinaryOp<IntegerAttr>(
208       operands, [](APInt a, const APInt &b) { return std::move(a) + b; });
209 }
210 
211 void arith::AddIOp::getCanonicalizationPatterns(
212     RewritePatternSet &patterns, MLIRContext *context) {
213   patterns.insert<AddIAddConstant, AddISubConstantRHS, AddISubConstantLHS>(
214       context);
215 }
216 
217 //===----------------------------------------------------------------------===//
218 // SubIOp
219 //===----------------------------------------------------------------------===//
220 
221 OpFoldResult arith::SubIOp::fold(ArrayRef<Attribute> operands) {
222   // subi(x,x) -> 0
223   if (getOperand(0) == getOperand(1))
224     return Builder(getContext()).getZeroAttr(getType());
225   // subi(x,0) -> x
226   if (matchPattern(getRhs(), m_Zero()))
227     return getLhs();
228 
229   return constFoldBinaryOp<IntegerAttr>(
230       operands, [](APInt a, const APInt &b) { return std::move(a) - b; });
231 }
232 
233 void arith::SubIOp::getCanonicalizationPatterns(
234     RewritePatternSet &patterns, MLIRContext *context) {
235   patterns.insert<SubIRHSAddConstant, SubILHSAddConstant, SubIRHSSubConstantRHS,
236                   SubIRHSSubConstantLHS, SubILHSSubConstantRHS,
237                   SubILHSSubConstantLHS>(context);
238 }
239 
240 //===----------------------------------------------------------------------===//
241 // MulIOp
242 //===----------------------------------------------------------------------===//
243 
244 OpFoldResult arith::MulIOp::fold(ArrayRef<Attribute> operands) {
245   // muli(x, 0) -> 0
246   if (matchPattern(getRhs(), m_Zero()))
247     return getRhs();
248   // muli(x, 1) -> x
249   if (matchPattern(getRhs(), m_One()))
250     return getOperand(0);
251   // TODO: Handle the overflow case.
252 
253   // default folder
254   return constFoldBinaryOp<IntegerAttr>(
255       operands, [](const APInt &a, const APInt &b) { return a * b; });
256 }
257 
258 //===----------------------------------------------------------------------===//
259 // DivUIOp
260 //===----------------------------------------------------------------------===//
261 
262 OpFoldResult arith::DivUIOp::fold(ArrayRef<Attribute> operands) {
263   // Don't fold if it would require a division by zero.
264   bool div0 = false;
265   auto result =
266       constFoldBinaryOp<IntegerAttr>(operands, [&](APInt a, const APInt &b) {
267         if (div0 || !b) {
268           div0 = true;
269           return a;
270         }
271         return a.udiv(b);
272       });
273 
274   // Fold out division by one. Assumes all tensors of all ones are splats.
275   if (auto rhs = operands[1].dyn_cast_or_null<IntegerAttr>()) {
276     if (rhs.getValue() == 1)
277       return getLhs();
278   } else if (auto rhs = operands[1].dyn_cast_or_null<SplatElementsAttr>()) {
279     if (rhs.getSplatValue<IntegerAttr>().getValue() == 1)
280       return getLhs();
281   }
282 
283   return div0 ? Attribute() : result;
284 }
285 
286 //===----------------------------------------------------------------------===//
287 // DivSIOp
288 //===----------------------------------------------------------------------===//
289 
290 OpFoldResult arith::DivSIOp::fold(ArrayRef<Attribute> operands) {
291   // Don't fold if it would overflow or if it requires a division by zero.
292   bool overflowOrDiv0 = false;
293   auto result =
294       constFoldBinaryOp<IntegerAttr>(operands, [&](APInt a, const APInt &b) {
295         if (overflowOrDiv0 || !b) {
296           overflowOrDiv0 = true;
297           return a;
298         }
299         return a.sdiv_ov(b, overflowOrDiv0);
300       });
301 
302   // Fold out division by one. Assumes all tensors of all ones are splats.
303   if (auto rhs = operands[1].dyn_cast_or_null<IntegerAttr>()) {
304     if (rhs.getValue() == 1)
305       return getLhs();
306   } else if (auto rhs = operands[1].dyn_cast_or_null<SplatElementsAttr>()) {
307     if (rhs.getSplatValue<IntegerAttr>().getValue() == 1)
308       return getLhs();
309   }
310 
311   return overflowOrDiv0 ? Attribute() : result;
312 }
313 
314 //===----------------------------------------------------------------------===//
315 // Ceil and floor division folding helpers
316 //===----------------------------------------------------------------------===//
317 
318 static APInt signedCeilNonnegInputs(const APInt &a, const APInt &b,
319                                     bool &overflow) {
320   // Returns (a-1)/b + 1
321   APInt one(a.getBitWidth(), 1, true); // Signed value 1.
322   APInt val = a.ssub_ov(one, overflow).sdiv_ov(b, overflow);
323   return val.sadd_ov(one, overflow);
324 }
325 
326 //===----------------------------------------------------------------------===//
327 // CeilDivUIOp
328 //===----------------------------------------------------------------------===//
329 
330 OpFoldResult arith::CeilDivUIOp::fold(ArrayRef<Attribute> operands) {
331   bool overflowOrDiv0 = false;
332   auto result =
333       constFoldBinaryOp<IntegerAttr>(operands, [&](APInt a, const APInt &b) {
334         if (overflowOrDiv0 || !b) {
335           overflowOrDiv0 = true;
336           return a;
337         }
338         APInt quotient = a.udiv(b);
339         if (!a.urem(b))
340           return quotient;
341         APInt one(a.getBitWidth(), 1, true);
342         return quotient.uadd_ov(one, overflowOrDiv0);
343       });
344   // Fold out ceil division by one. Assumes all tensors of all ones are
345   // splats.
346   if (auto rhs = operands[1].dyn_cast_or_null<IntegerAttr>()) {
347     if (rhs.getValue() == 1)
348       return getLhs();
349   } else if (auto rhs = operands[1].dyn_cast_or_null<SplatElementsAttr>()) {
350     if (rhs.getSplatValue<IntegerAttr>().getValue() == 1)
351       return getLhs();
352   }
353 
354   return overflowOrDiv0 ? Attribute() : result;
355 }
356 
357 //===----------------------------------------------------------------------===//
358 // CeilDivSIOp
359 //===----------------------------------------------------------------------===//
360 
361 OpFoldResult arith::CeilDivSIOp::fold(ArrayRef<Attribute> operands) {
362   // Don't fold if it would overflow or if it requires a division by zero.
363   bool overflowOrDiv0 = false;
364   auto result =
365       constFoldBinaryOp<IntegerAttr>(operands, [&](APInt a, const APInt &b) {
366         if (overflowOrDiv0 || !b) {
367           overflowOrDiv0 = true;
368           return a;
369         }
370         if (!a)
371           return a;
372         // After this point we know that neither a or b are zero.
373         unsigned bits = a.getBitWidth();
374         APInt zero = APInt::getZero(bits);
375         bool aGtZero = a.sgt(zero);
376         bool bGtZero = b.sgt(zero);
377         if (aGtZero && bGtZero) {
378           // Both positive, return ceil(a, b).
379           return signedCeilNonnegInputs(a, b, overflowOrDiv0);
380         }
381         if (!aGtZero && !bGtZero) {
382           // Both negative, return ceil(-a, -b).
383           APInt posA = zero.ssub_ov(a, overflowOrDiv0);
384           APInt posB = zero.ssub_ov(b, overflowOrDiv0);
385           return signedCeilNonnegInputs(posA, posB, overflowOrDiv0);
386         }
387         if (!aGtZero && bGtZero) {
388           // A is negative, b is positive, return - ( -a / b).
389           APInt posA = zero.ssub_ov(a, overflowOrDiv0);
390           APInt div = posA.sdiv_ov(b, overflowOrDiv0);
391           return zero.ssub_ov(div, overflowOrDiv0);
392         }
393         // A is positive, b is negative, return - (a / -b).
394         APInt posB = zero.ssub_ov(b, overflowOrDiv0);
395         APInt div = a.sdiv_ov(posB, overflowOrDiv0);
396         return zero.ssub_ov(div, overflowOrDiv0);
397       });
398 
399   // Fold out ceil division by one. Assumes all tensors of all ones are
400   // splats.
401   if (auto rhs = operands[1].dyn_cast_or_null<IntegerAttr>()) {
402     if (rhs.getValue() == 1)
403       return getLhs();
404   } else if (auto rhs = operands[1].dyn_cast_or_null<SplatElementsAttr>()) {
405     if (rhs.getSplatValue<IntegerAttr>().getValue() == 1)
406       return getLhs();
407   }
408 
409   return overflowOrDiv0 ? Attribute() : result;
410 }
411 
412 //===----------------------------------------------------------------------===//
413 // FloorDivSIOp
414 //===----------------------------------------------------------------------===//
415 
416 OpFoldResult arith::FloorDivSIOp::fold(ArrayRef<Attribute> operands) {
417   // Don't fold if it would overflow or if it requires a division by zero.
418   bool overflowOrDiv0 = false;
419   auto result =
420       constFoldBinaryOp<IntegerAttr>(operands, [&](APInt a, const APInt &b) {
421         if (overflowOrDiv0 || !b) {
422           overflowOrDiv0 = true;
423           return a;
424         }
425         if (!a)
426           return a;
427         // After this point we know that neither a or b are zero.
428         unsigned bits = a.getBitWidth();
429         APInt zero = APInt::getZero(bits);
430         bool aGtZero = a.sgt(zero);
431         bool bGtZero = b.sgt(zero);
432         if (aGtZero && bGtZero) {
433           // Both positive, return a / b.
434           return a.sdiv_ov(b, overflowOrDiv0);
435         }
436         if (!aGtZero && !bGtZero) {
437           // Both negative, return -a / -b.
438           APInt posA = zero.ssub_ov(a, overflowOrDiv0);
439           APInt posB = zero.ssub_ov(b, overflowOrDiv0);
440           return posA.sdiv_ov(posB, overflowOrDiv0);
441         }
442         if (!aGtZero && bGtZero) {
443           // A is negative, b is positive, return - ceil(-a, b).
444           APInt posA = zero.ssub_ov(a, overflowOrDiv0);
445           APInt ceil = signedCeilNonnegInputs(posA, b, overflowOrDiv0);
446           return zero.ssub_ov(ceil, overflowOrDiv0);
447         }
448         // A is positive, b is negative, return - ceil(a, -b).
449         APInt posB = zero.ssub_ov(b, overflowOrDiv0);
450         APInt ceil = signedCeilNonnegInputs(a, posB, overflowOrDiv0);
451         return zero.ssub_ov(ceil, overflowOrDiv0);
452       });
453 
454   // Fold out floor division by one. Assumes all tensors of all ones are
455   // splats.
456   if (auto rhs = operands[1].dyn_cast_or_null<IntegerAttr>()) {
457     if (rhs.getValue() == 1)
458       return getLhs();
459   } else if (auto rhs = operands[1].dyn_cast_or_null<SplatElementsAttr>()) {
460     if (rhs.getSplatValue<IntegerAttr>().getValue() == 1)
461       return getLhs();
462   }
463 
464   return overflowOrDiv0 ? Attribute() : result;
465 }
466 
467 //===----------------------------------------------------------------------===//
468 // RemUIOp
469 //===----------------------------------------------------------------------===//
470 
471 OpFoldResult arith::RemUIOp::fold(ArrayRef<Attribute> operands) {
472   auto rhs = operands.back().dyn_cast_or_null<IntegerAttr>();
473   if (!rhs)
474     return {};
475   auto rhsValue = rhs.getValue();
476 
477   // x % 1 = 0
478   if (rhsValue.isOneValue())
479     return IntegerAttr::get(rhs.getType(), APInt(rhsValue.getBitWidth(), 0));
480 
481   // Don't fold if it requires division by zero.
482   if (rhsValue.isNullValue())
483     return {};
484 
485   auto lhs = operands.front().dyn_cast_or_null<IntegerAttr>();
486   if (!lhs)
487     return {};
488   return IntegerAttr::get(lhs.getType(), lhs.getValue().urem(rhsValue));
489 }
490 
491 //===----------------------------------------------------------------------===//
492 // RemSIOp
493 //===----------------------------------------------------------------------===//
494 
495 OpFoldResult arith::RemSIOp::fold(ArrayRef<Attribute> operands) {
496   auto rhs = operands.back().dyn_cast_or_null<IntegerAttr>();
497   if (!rhs)
498     return {};
499   auto rhsValue = rhs.getValue();
500 
501   // x % 1 = 0
502   if (rhsValue.isOneValue())
503     return IntegerAttr::get(rhs.getType(), APInt(rhsValue.getBitWidth(), 0));
504 
505   // Don't fold if it requires division by zero.
506   if (rhsValue.isNullValue())
507     return {};
508 
509   auto lhs = operands.front().dyn_cast_or_null<IntegerAttr>();
510   if (!lhs)
511     return {};
512   return IntegerAttr::get(lhs.getType(), lhs.getValue().srem(rhsValue));
513 }
514 
515 //===----------------------------------------------------------------------===//
516 // AndIOp
517 //===----------------------------------------------------------------------===//
518 
519 OpFoldResult arith::AndIOp::fold(ArrayRef<Attribute> operands) {
520   /// and(x, 0) -> 0
521   if (matchPattern(getRhs(), m_Zero()))
522     return getRhs();
523   /// and(x, allOnes) -> x
524   APInt intValue;
525   if (matchPattern(getRhs(), m_ConstantInt(&intValue)) && intValue.isAllOnes())
526     return getLhs();
527 
528   return constFoldBinaryOp<IntegerAttr>(
529       operands, [](APInt a, const APInt &b) { return std::move(a) & b; });
530 }
531 
532 //===----------------------------------------------------------------------===//
533 // OrIOp
534 //===----------------------------------------------------------------------===//
535 
536 OpFoldResult arith::OrIOp::fold(ArrayRef<Attribute> operands) {
537   /// or(x, 0) -> x
538   if (matchPattern(getRhs(), m_Zero()))
539     return getLhs();
540   /// or(x, <all ones>) -> <all ones>
541   if (auto rhsAttr = operands[1].dyn_cast_or_null<IntegerAttr>())
542     if (rhsAttr.getValue().isAllOnes())
543       return rhsAttr;
544 
545   return constFoldBinaryOp<IntegerAttr>(
546       operands, [](APInt a, const APInt &b) { return std::move(a) | b; });
547 }
548 
549 //===----------------------------------------------------------------------===//
550 // XOrIOp
551 //===----------------------------------------------------------------------===//
552 
553 OpFoldResult arith::XOrIOp::fold(ArrayRef<Attribute> operands) {
554   /// xor(x, 0) -> x
555   if (matchPattern(getRhs(), m_Zero()))
556     return getLhs();
557   /// xor(x, x) -> 0
558   if (getLhs() == getRhs())
559     return Builder(getContext()).getZeroAttr(getType());
560   /// xor(xor(x, a), a) -> x
561   if (arith::XOrIOp prev = getLhs().getDefiningOp<arith::XOrIOp>())
562     if (prev.getRhs() == getRhs())
563       return prev.getLhs();
564 
565   return constFoldBinaryOp<IntegerAttr>(
566       operands, [](APInt a, const APInt &b) { return std::move(a) ^ b; });
567 }
568 
569 void arith::XOrIOp::getCanonicalizationPatterns(
570     RewritePatternSet &patterns, MLIRContext *context) {
571   patterns.insert<XOrINotCmpI>(context);
572 }
573 
574 //===----------------------------------------------------------------------===//
575 // AddFOp
576 //===----------------------------------------------------------------------===//
577 
578 OpFoldResult arith::AddFOp::fold(ArrayRef<Attribute> operands) {
579   // addf(x, -0) -> x
580   if (matchPattern(getRhs(), m_NegZeroFloat()))
581     return getLhs();
582 
583   return constFoldBinaryOp<FloatAttr>(
584       operands, [](const APFloat &a, const APFloat &b) { return a + b; });
585 }
586 
587 //===----------------------------------------------------------------------===//
588 // SubFOp
589 //===----------------------------------------------------------------------===//
590 
591 OpFoldResult arith::SubFOp::fold(ArrayRef<Attribute> operands) {
592   // subf(x, +0) -> x
593   if (matchPattern(getRhs(), m_PosZeroFloat()))
594     return getLhs();
595 
596   return constFoldBinaryOp<FloatAttr>(
597       operands, [](const APFloat &a, const APFloat &b) { return a - b; });
598 }
599 
600 //===----------------------------------------------------------------------===//
601 // MaxFOp
602 //===----------------------------------------------------------------------===//
603 
/// Folds maxf: identical operands, a -inf right operand, and constant operands
/// (via llvm::maximum).
OpFoldResult arith::MaxFOp::fold(ArrayRef<Attribute> operands) {
  assert(operands.size() == 2 && "maxf takes two operands");
  Value lhs = getLhs();
  Value rhs = getRhs();

  // maxf(x, x) -> x
  if (lhs == rhs)
    return rhs;

  // maxf(x, -inf) -> x
  if (matchPattern(rhs, m_NegInfFloat()))
    return lhs;

  return constFoldBinaryOp<FloatAttr>(
      operands,
      [](const APFloat &x, const APFloat &y) { return llvm::maximum(x, y); });
}
619 
620 //===----------------------------------------------------------------------===//
621 // MaxSIOp
622 //===----------------------------------------------------------------------===//
623 
/// Folds maxsi: identical operands, a constant right operand equal to the
/// signed maximum or minimum, and constant operands.
OpFoldResult MaxSIOp::fold(ArrayRef<Attribute> operands) {
  assert(operands.size() == 2 && "binary operation takes two operands");
  Value lhs = getLhs();
  Value rhs = getRhs();

  // maxsi(x, x) -> x
  if (lhs == rhs)
    return rhs;

  APInt rhsCst;
  if (matchPattern(rhs, m_ConstantInt(&rhsCst))) {
    // maxsi(x, MAX_INT) -> MAX_INT
    if (rhsCst.isMaxSignedValue())
      return rhs;
    // maxsi(x, MIN_INT) -> x
    if (rhsCst.isMinSignedValue())
      return lhs;
  }

  return constFoldBinaryOp<IntegerAttr>(
      operands,
      [](const APInt &x, const APInt &y) { return llvm::APIntOps::smax(x, y); });
}
647 
648 //===----------------------------------------------------------------------===//
649 // MaxUIOp
650 //===----------------------------------------------------------------------===//
651 
/// Folds maxui: identical operands, a constant right operand equal to the
/// unsigned maximum or minimum, and constant operands.
OpFoldResult MaxUIOp::fold(ArrayRef<Attribute> operands) {
  assert(operands.size() == 2 && "binary operation takes two operands");
  Value lhs = getLhs();
  Value rhs = getRhs();

  // maxui(x, x) -> x
  if (lhs == rhs)
    return rhs;

  APInt rhsCst;
  if (matchPattern(rhs, m_ConstantInt(&rhsCst))) {
    // maxui(x, MAX_INT) -> MAX_INT
    if (rhsCst.isMaxValue())
      return rhs;
    // maxui(x, MIN_INT) -> x
    if (rhsCst.isMinValue())
      return lhs;
  }

  return constFoldBinaryOp<IntegerAttr>(
      operands,
      [](const APInt &x, const APInt &y) { return llvm::APIntOps::umax(x, y); });
}
673 
674 //===----------------------------------------------------------------------===//
675 // MinFOp
676 //===----------------------------------------------------------------------===//
677 
/// Folds minf: identical operands, a +inf right operand, and constant operands
/// (via llvm::minimum).
OpFoldResult arith::MinFOp::fold(ArrayRef<Attribute> operands) {
  assert(operands.size() == 2 && "minf takes two operands");
  Value lhs = getLhs();
  Value rhs = getRhs();

  // minf(x, x) -> x
  if (lhs == rhs)
    return rhs;

  // minf(x, +inf) -> x
  if (matchPattern(rhs, m_PosInfFloat()))
    return lhs;

  return constFoldBinaryOp<FloatAttr>(
      operands,
      [](const APFloat &x, const APFloat &y) { return llvm::minimum(x, y); });
}
693 
694 //===----------------------------------------------------------------------===//
695 // MinSIOp
696 //===----------------------------------------------------------------------===//
697 
/// Folds minsi: identical operands, a constant right operand equal to the
/// signed minimum or maximum, and constant operands.
OpFoldResult MinSIOp::fold(ArrayRef<Attribute> operands) {
  assert(operands.size() == 2 && "binary operation takes two operands");
  Value lhs = getLhs();
  Value rhs = getRhs();

  // minsi(x, x) -> x
  if (lhs == rhs)
    return rhs;

  APInt rhsCst;
  if (matchPattern(rhs, m_ConstantInt(&rhsCst))) {
    // minsi(x, MIN_INT) -> MIN_INT
    if (rhsCst.isMinSignedValue())
      return rhs;
    // minsi(x, MAX_INT) -> x
    if (rhsCst.isMaxSignedValue())
      return lhs;
  }

  return constFoldBinaryOp<IntegerAttr>(
      operands,
      [](const APInt &x, const APInt &y) { return llvm::APIntOps::smin(x, y); });
}
721 
722 //===----------------------------------------------------------------------===//
723 // MinUIOp
724 //===----------------------------------------------------------------------===//
725 
/// Folds minui: identical operands, a constant right operand equal to the
/// unsigned minimum or maximum, and constant operands.
OpFoldResult MinUIOp::fold(ArrayRef<Attribute> operands) {
  assert(operands.size() == 2 && "binary operation takes two operands");
  Value lhs = getLhs();
  Value rhs = getRhs();

  // minui(x, x) -> x
  if (lhs == rhs)
    return rhs;

  APInt rhsCst;
  if (matchPattern(rhs, m_ConstantInt(&rhsCst))) {
    // minui(x, MIN_INT) -> MIN_INT
    if (rhsCst.isMinValue())
      return rhs;
    // minui(x, MAX_INT) -> x
    if (rhsCst.isMaxValue())
      return lhs;
  }

  return constFoldBinaryOp<IntegerAttr>(
      operands,
      [](const APInt &x, const APInt &y) { return llvm::APIntOps::umin(x, y); });
}
747 
748 //===----------------------------------------------------------------------===//
749 // MulFOp
750 //===----------------------------------------------------------------------===//
751 
752 OpFoldResult arith::MulFOp::fold(ArrayRef<Attribute> operands) {
753   APFloat floatValue(0.0f), inverseValue(0.0f);
754   // mulf(x, 1) -> x
755   if (matchPattern(getRhs(), m_OneFloat()))
756     return getLhs();
757 
758   // mulf(1, x) -> x
759   if (matchPattern(getLhs(), m_OneFloat()))
760     return getRhs();
761 
762   return constFoldBinaryOp<FloatAttr>(
763       operands, [](const APFloat &a, const APFloat &b) { return a * b; });
764 }
765 
766 //===----------------------------------------------------------------------===//
767 // DivFOp
768 //===----------------------------------------------------------------------===//
769 
770 OpFoldResult arith::DivFOp::fold(ArrayRef<Attribute> operands) {
771   APFloat floatValue(0.0f), inverseValue(0.0f);
772   // divf(x, 1) -> x
773   if (matchPattern(getRhs(), m_OneFloat()))
774     return getLhs();
775 
776   return constFoldBinaryOp<FloatAttr>(
777       operands, [](const APFloat &a, const APFloat &b) { return a / b; });
778 }
779 
780 //===----------------------------------------------------------------------===//
781 // Utility functions for verifying cast ops
782 //===----------------------------------------------------------------------===//
783 
/// Compile-time list of types, used only as a tag argument so a function
/// template can deduce a pack of types from a (null) pointer value, e.g.
/// `type_list<VectorType, TensorType>()`.
template <typename... Types>
using type_list = std::tuple<Types...> *;
786 
/// Returns the "underlying" element type of `type`, or a null Type if it is not
/// allowed.
///
/// A shaped `type` is accepted only if it is one of `ShapedTypes`; its element
/// type is then checked. A non-shaped `type` is checked directly. The checked
/// type must be one of `ElementTypes`.
template <typename... ShapedTypes, typename... ElementTypes>
static Type getUnderlyingType(Type type, type_list<ShapedTypes...>,
                              type_list<ElementTypes...>) {
  bool disallowedShape = type.isa<ShapedType>() && !type.isa<ShapedTypes...>();
  if (disallowedShape)
    return {};

  Type elementType = getElementTypeOrSelf(type);
  return elementType.isa<ElementTypes...>() ? elementType : Type();
}
802 
/// Returns the underlying element type of `type` if it is one of
/// `ElementTypes`, either directly or inside a vector/tensor; otherwise a null
/// Type.
template <typename... ElementTypes>
static Type getTypeIfLike(Type type) {
  using AllowedShapes = type_list<VectorType, TensorType>;
  using AllowedElements = type_list<ElementTypes...>;
  return getUnderlyingType(type, AllowedShapes(), AllowedElements());
}
809 
/// Returns the underlying element type of `type` if it is one of
/// `ElementTypes`, either directly or inside a vector/tensor/memref; otherwise
/// a null Type.
template <typename... ElementTypes>
static Type getTypeIfLikeOrMemRef(Type type) {
  using AllowedShapes = type_list<VectorType, TensorType, MemRefType>;
  using AllowedElements = type_list<ElementTypes...>;
  return getUnderlyingType(type, AllowedShapes(), AllowedElements());
}
817 
818 static bool areValidCastInputsAndOutputs(TypeRange inputs, TypeRange outputs) {
819   return inputs.size() == 1 && outputs.size() == 1 &&
820          succeeded(verifyCompatibleShapes(inputs.front(), outputs.front()));
821 }
822 
823 //===----------------------------------------------------------------------===//
824 // Verifiers for integer and floating point extension/truncation ops
825 //===----------------------------------------------------------------------===//
826 
827 // Extend ops can only extend to a wider type.
828 template <typename ValType, typename Op>
829 static LogicalResult verifyExtOp(Op op) {
830   Type srcType = getElementTypeOrSelf(op.getIn().getType());
831   Type dstType = getElementTypeOrSelf(op.getType());
832 
833   if (srcType.cast<ValType>().getWidth() >= dstType.cast<ValType>().getWidth())
834     return op.emitError("result type ")
835            << dstType << " must be wider than operand type " << srcType;
836 
837   return success();
838 }
839 
840 // Truncate ops can only truncate to a shorter type.
841 template <typename ValType, typename Op>
842 static LogicalResult verifyTruncateOp(Op op) {
843   Type srcType = getElementTypeOrSelf(op.getIn().getType());
844   Type dstType = getElementTypeOrSelf(op.getType());
845 
846   if (srcType.cast<ValType>().getWidth() <= dstType.cast<ValType>().getWidth())
847     return op.emitError("result type ")
848            << dstType << " must be shorter than operand type " << srcType;
849 
850   return success();
851 }
852 
853 /// Validate a cast that changes the width of a type.
854 template <template <typename> class WidthComparator, typename... ElementTypes>
855 static bool checkWidthChangeCast(TypeRange inputs, TypeRange outputs) {
856   if (!areValidCastInputsAndOutputs(inputs, outputs))
857     return false;
858 
859   auto srcType = getTypeIfLike<ElementTypes...>(inputs.front());
860   auto dstType = getTypeIfLike<ElementTypes...>(outputs.front());
861   if (!srcType || !dstType)
862     return false;
863 
864   return WidthComparator<unsigned>()(dstType.getIntOrFloatBitWidth(),
865                                      srcType.getIntOrFloatBitWidth());
866 }
867 
868 //===----------------------------------------------------------------------===//
869 // ExtUIOp
870 //===----------------------------------------------------------------------===//
871 
872 OpFoldResult arith::ExtUIOp::fold(ArrayRef<Attribute> operands) {
873   if (auto lhs = operands[0].dyn_cast_or_null<IntegerAttr>())
874     return IntegerAttr::get(
875         getType(), lhs.getValue().zext(getType().getIntOrFloatBitWidth()));
876 
877   if (auto lhs = getIn().getDefiningOp<ExtUIOp>()) {
878     getInMutable().assign(lhs.getIn());
879     return getResult();
880   }
881 
882   return {};
883 }
884 
885 bool arith::ExtUIOp::areCastCompatible(TypeRange inputs, TypeRange outputs) {
886   return checkWidthChangeCast<std::greater, IntegerType>(inputs, outputs);
887 }
888 
889 LogicalResult arith::ExtUIOp::verify() {
890   return verifyExtOp<IntegerType>(*this);
891 }
892 
893 //===----------------------------------------------------------------------===//
894 // ExtSIOp
895 //===----------------------------------------------------------------------===//
896 
897 OpFoldResult arith::ExtSIOp::fold(ArrayRef<Attribute> operands) {
898   if (auto lhs = operands[0].dyn_cast_or_null<IntegerAttr>())
899     return IntegerAttr::get(
900         getType(), lhs.getValue().sext(getType().getIntOrFloatBitWidth()));
901 
902   if (auto lhs = getIn().getDefiningOp<ExtSIOp>()) {
903     getInMutable().assign(lhs.getIn());
904     return getResult();
905   }
906 
907   return {};
908 }
909 
910 bool arith::ExtSIOp::areCastCompatible(TypeRange inputs, TypeRange outputs) {
911   return checkWidthChangeCast<std::greater, IntegerType>(inputs, outputs);
912 }
913 
914 void arith::ExtSIOp::getCanonicalizationPatterns(
915     RewritePatternSet &patterns, MLIRContext *context) {
916   patterns.insert<ExtSIOfExtUI>(context);
917 }
918 
919 LogicalResult arith::ExtSIOp::verify() {
920   return verifyExtOp<IntegerType>(*this);
921 }
922 
923 //===----------------------------------------------------------------------===//
924 // ExtFOp
925 //===----------------------------------------------------------------------===//
926 
927 bool arith::ExtFOp::areCastCompatible(TypeRange inputs, TypeRange outputs) {
928   return checkWidthChangeCast<std::greater, FloatType>(inputs, outputs);
929 }
930 
931 LogicalResult arith::ExtFOp::verify() { return verifyExtOp<FloatType>(*this); }
932 
933 //===----------------------------------------------------------------------===//
934 // TruncIOp
935 //===----------------------------------------------------------------------===//
936 
937 OpFoldResult arith::TruncIOp::fold(ArrayRef<Attribute> operands) {
938   assert(operands.size() == 1 && "unary operation takes one operand");
939 
940   // trunci(zexti(a)) -> a
941   // trunci(sexti(a)) -> a
942   if (matchPattern(getOperand(), m_Op<arith::ExtUIOp>()) ||
943       matchPattern(getOperand(), m_Op<arith::ExtSIOp>()))
944     return getOperand().getDefiningOp()->getOperand(0);
945 
946   // trunci(trunci(a)) -> trunci(a))
947   if (matchPattern(getOperand(), m_Op<arith::TruncIOp>())) {
948     setOperand(getOperand().getDefiningOp()->getOperand(0));
949     return getResult();
950   }
951 
952   if (!operands[0])
953     return {};
954 
955   if (auto lhs = operands[0].dyn_cast<IntegerAttr>()) {
956     return IntegerAttr::get(
957         getType(), lhs.getValue().trunc(getType().getIntOrFloatBitWidth()));
958   }
959 
960   return {};
961 }
962 
963 bool arith::TruncIOp::areCastCompatible(TypeRange inputs, TypeRange outputs) {
964   return checkWidthChangeCast<std::less, IntegerType>(inputs, outputs);
965 }
966 
967 LogicalResult arith::TruncIOp::verify() {
968   return verifyTruncateOp<IntegerType>(*this);
969 }
970 
971 //===----------------------------------------------------------------------===//
972 // TruncFOp
973 //===----------------------------------------------------------------------===//
974 
975 /// Perform safe const propagation for truncf, i.e. only propagate if FP value
976 /// can be represented without precision loss or rounding.
977 OpFoldResult arith::TruncFOp::fold(ArrayRef<Attribute> operands) {
978   assert(operands.size() == 1 && "unary operation takes one operand");
979 
980   auto constOperand = operands.front();
981   if (!constOperand || !constOperand.isa<FloatAttr>())
982     return {};
983 
984   // Convert to target type via 'double'.
985   double sourceValue =
986       constOperand.dyn_cast<FloatAttr>().getValue().convertToDouble();
987   auto targetAttr = FloatAttr::get(getType(), sourceValue);
988 
989   // Propagate if constant's value does not change after truncation.
990   if (sourceValue == targetAttr.getValue().convertToDouble())
991     return targetAttr;
992 
993   return {};
994 }
995 
996 bool arith::TruncFOp::areCastCompatible(TypeRange inputs, TypeRange outputs) {
997   return checkWidthChangeCast<std::less, FloatType>(inputs, outputs);
998 }
999 
1000 LogicalResult arith::TruncFOp::verify() {
1001   return verifyTruncateOp<FloatType>(*this);
1002 }
1003 
1004 //===----------------------------------------------------------------------===//
1005 // AndIOp
1006 //===----------------------------------------------------------------------===//
1007 
1008 void arith::AndIOp::getCanonicalizationPatterns(
1009     RewritePatternSet &patterns, MLIRContext *context) {
1010   patterns.insert<AndOfExtUI, AndOfExtSI>(context);
1011 }
1012 
1013 //===----------------------------------------------------------------------===//
1014 // OrIOp
1015 //===----------------------------------------------------------------------===//
1016 
1017 void arith::OrIOp::getCanonicalizationPatterns(
1018     RewritePatternSet &patterns, MLIRContext *context) {
1019   patterns.insert<OrOfExtUI, OrOfExtSI>(context);
1020 }
1021 
1022 //===----------------------------------------------------------------------===//
1023 // Verifiers for casts between integers and floats.
1024 //===----------------------------------------------------------------------===//
1025 
1026 template <typename From, typename To>
1027 static bool checkIntFloatCast(TypeRange inputs, TypeRange outputs) {
1028   if (!areValidCastInputsAndOutputs(inputs, outputs))
1029     return false;
1030 
1031   auto srcType = getTypeIfLike<From>(inputs.front());
1032   auto dstType = getTypeIfLike<To>(outputs.back());
1033 
1034   return srcType && dstType;
1035 }
1036 
1037 //===----------------------------------------------------------------------===//
1038 // UIToFPOp
1039 //===----------------------------------------------------------------------===//
1040 
1041 bool arith::UIToFPOp::areCastCompatible(TypeRange inputs, TypeRange outputs) {
1042   return checkIntFloatCast<IntegerType, FloatType>(inputs, outputs);
1043 }
1044 
1045 OpFoldResult arith::UIToFPOp::fold(ArrayRef<Attribute> operands) {
1046   if (auto lhs = operands[0].dyn_cast_or_null<IntegerAttr>()) {
1047     const APInt &api = lhs.getValue();
1048     FloatType floatTy = getType().cast<FloatType>();
1049     APFloat apf(floatTy.getFloatSemantics(),
1050                 APInt::getZero(floatTy.getWidth()));
1051     apf.convertFromAPInt(api, /*IsSigned=*/false, APFloat::rmNearestTiesToEven);
1052     return FloatAttr::get(floatTy, apf);
1053   }
1054   return {};
1055 }
1056 
1057 //===----------------------------------------------------------------------===//
1058 // SIToFPOp
1059 //===----------------------------------------------------------------------===//
1060 
1061 bool arith::SIToFPOp::areCastCompatible(TypeRange inputs, TypeRange outputs) {
1062   return checkIntFloatCast<IntegerType, FloatType>(inputs, outputs);
1063 }
1064 
1065 OpFoldResult arith::SIToFPOp::fold(ArrayRef<Attribute> operands) {
1066   if (auto lhs = operands[0].dyn_cast_or_null<IntegerAttr>()) {
1067     const APInt &api = lhs.getValue();
1068     FloatType floatTy = getType().cast<FloatType>();
1069     APFloat apf(floatTy.getFloatSemantics(),
1070                 APInt::getZero(floatTy.getWidth()));
1071     apf.convertFromAPInt(api, /*IsSigned=*/true, APFloat::rmNearestTiesToEven);
1072     return FloatAttr::get(floatTy, apf);
1073   }
1074   return {};
1075 }
1076 //===----------------------------------------------------------------------===//
1077 // FPToUIOp
1078 //===----------------------------------------------------------------------===//
1079 
1080 bool arith::FPToUIOp::areCastCompatible(TypeRange inputs, TypeRange outputs) {
1081   return checkIntFloatCast<FloatType, IntegerType>(inputs, outputs);
1082 }
1083 
1084 OpFoldResult arith::FPToUIOp::fold(ArrayRef<Attribute> operands) {
1085   if (auto lhs = operands[0].dyn_cast_or_null<FloatAttr>()) {
1086     const APFloat &apf = lhs.getValue();
1087     IntegerType intTy = getType().cast<IntegerType>();
1088     bool ignored;
1089     APSInt api(intTy.getWidth(), /*isUnsigned=*/true);
1090     if (APFloat::opInvalidOp ==
1091         apf.convertToInteger(api, APFloat::rmTowardZero, &ignored)) {
1092       // Undefined behavior invoked - the destination type can't represent
1093       // the input constant.
1094       return {};
1095     }
1096     return IntegerAttr::get(getType(), api);
1097   }
1098 
1099   return {};
1100 }
1101 
1102 //===----------------------------------------------------------------------===//
1103 // FPToSIOp
1104 //===----------------------------------------------------------------------===//
1105 
1106 bool arith::FPToSIOp::areCastCompatible(TypeRange inputs, TypeRange outputs) {
1107   return checkIntFloatCast<FloatType, IntegerType>(inputs, outputs);
1108 }
1109 
1110 OpFoldResult arith::FPToSIOp::fold(ArrayRef<Attribute> operands) {
1111   if (auto lhs = operands[0].dyn_cast_or_null<FloatAttr>()) {
1112     const APFloat &apf = lhs.getValue();
1113     IntegerType intTy = getType().cast<IntegerType>();
1114     bool ignored;
1115     APSInt api(intTy.getWidth(), /*isUnsigned=*/false);
1116     if (APFloat::opInvalidOp ==
1117         apf.convertToInteger(api, APFloat::rmTowardZero, &ignored)) {
1118       // Undefined behavior invoked - the destination type can't represent
1119       // the input constant.
1120       return {};
1121     }
1122     return IntegerAttr::get(getType(), api);
1123   }
1124 
1125   return {};
1126 }
1127 
1128 //===----------------------------------------------------------------------===//
1129 // IndexCastOp
1130 //===----------------------------------------------------------------------===//
1131 
1132 bool arith::IndexCastOp::areCastCompatible(TypeRange inputs,
1133                                            TypeRange outputs) {
1134   if (!areValidCastInputsAndOutputs(inputs, outputs))
1135     return false;
1136 
1137   auto srcType = getTypeIfLikeOrMemRef<IntegerType, IndexType>(inputs.front());
1138   auto dstType = getTypeIfLikeOrMemRef<IntegerType, IndexType>(outputs.front());
1139   if (!srcType || !dstType)
1140     return false;
1141 
1142   return (srcType.isIndex() && dstType.isSignlessInteger()) ||
1143          (srcType.isSignlessInteger() && dstType.isIndex());
1144 }
1145 
1146 OpFoldResult arith::IndexCastOp::fold(ArrayRef<Attribute> operands) {
1147   // index_cast(constant) -> constant
1148   // A little hack because we go through int. Otherwise, the size of the
1149   // constant might need to change.
1150   if (auto value = operands[0].dyn_cast_or_null<IntegerAttr>())
1151     return IntegerAttr::get(getType(), value.getInt());
1152 
1153   return {};
1154 }
1155 
1156 void arith::IndexCastOp::getCanonicalizationPatterns(
1157     RewritePatternSet &patterns, MLIRContext *context) {
1158   patterns.insert<IndexCastOfIndexCast, IndexCastOfExtSI>(context);
1159 }
1160 
1161 //===----------------------------------------------------------------------===//
1162 // BitcastOp
1163 //===----------------------------------------------------------------------===//
1164 
1165 bool arith::BitcastOp::areCastCompatible(TypeRange inputs, TypeRange outputs) {
1166   if (!areValidCastInputsAndOutputs(inputs, outputs))
1167     return false;
1168 
1169   auto srcType =
1170       getTypeIfLikeOrMemRef<IntegerType, IndexType, FloatType>(inputs.front());
1171   auto dstType =
1172       getTypeIfLikeOrMemRef<IntegerType, IndexType, FloatType>(outputs.front());
1173   if (!srcType || !dstType)
1174     return false;
1175 
1176   return srcType.getIntOrFloatBitWidth() == dstType.getIntOrFloatBitWidth();
1177 }
1178 
1179 OpFoldResult arith::BitcastOp::fold(ArrayRef<Attribute> operands) {
1180   assert(operands.size() == 1 && "bitcast op expects 1 operand");
1181 
1182   auto resType = getType();
1183   auto operand = operands[0];
1184   if (!operand)
1185     return {};
1186 
1187   /// Bitcast dense elements.
1188   if (auto denseAttr = operand.dyn_cast_or_null<DenseElementsAttr>())
1189     return denseAttr.bitcast(resType.cast<ShapedType>().getElementType());
1190   /// Other shaped types unhandled.
1191   if (resType.isa<ShapedType>())
1192     return {};
1193 
1194   /// Bitcast integer or float to integer or float.
1195   APInt bits = operand.isa<FloatAttr>()
1196                    ? operand.cast<FloatAttr>().getValue().bitcastToAPInt()
1197                    : operand.cast<IntegerAttr>().getValue();
1198 
1199   if (auto resFloatType = resType.dyn_cast<FloatType>())
1200     return FloatAttr::get(resType,
1201                           APFloat(resFloatType.getFloatSemantics(), bits));
1202   return IntegerAttr::get(resType, bits);
1203 }
1204 
1205 void arith::BitcastOp::getCanonicalizationPatterns(
1206     RewritePatternSet &patterns, MLIRContext *context) {
1207   patterns.insert<BitcastOfBitcast>(context);
1208 }
1209 
1210 //===----------------------------------------------------------------------===//
1211 // Helpers for compare ops
1212 //===----------------------------------------------------------------------===//
1213 
1214 /// Return the type of the same shape (scalar, vector or tensor) containing i1.
1215 static Type getI1SameShape(Type type) {
1216   auto i1Type = IntegerType::get(type.getContext(), 1);
1217   if (auto tensorType = type.dyn_cast<RankedTensorType>())
1218     return RankedTensorType::get(tensorType.getShape(), i1Type);
1219   if (type.isa<UnrankedTensorType>())
1220     return UnrankedTensorType::get(i1Type);
1221   if (auto vectorType = type.dyn_cast<VectorType>())
1222     return VectorType::get(vectorType.getShape(), i1Type,
1223                            vectorType.getNumScalableDims());
1224   return i1Type;
1225 }
1226 
1227 //===----------------------------------------------------------------------===//
1228 // CmpIOp
1229 //===----------------------------------------------------------------------===//
1230 
1231 /// Compute `lhs` `pred` `rhs`, where `pred` is one of the known integer
1232 /// comparison predicates.
1233 bool mlir::arith::applyCmpPredicate(arith::CmpIPredicate predicate,
1234                                     const APInt &lhs, const APInt &rhs) {
1235   switch (predicate) {
1236   case arith::CmpIPredicate::eq:
1237     return lhs.eq(rhs);
1238   case arith::CmpIPredicate::ne:
1239     return lhs.ne(rhs);
1240   case arith::CmpIPredicate::slt:
1241     return lhs.slt(rhs);
1242   case arith::CmpIPredicate::sle:
1243     return lhs.sle(rhs);
1244   case arith::CmpIPredicate::sgt:
1245     return lhs.sgt(rhs);
1246   case arith::CmpIPredicate::sge:
1247     return lhs.sge(rhs);
1248   case arith::CmpIPredicate::ult:
1249     return lhs.ult(rhs);
1250   case arith::CmpIPredicate::ule:
1251     return lhs.ule(rhs);
1252   case arith::CmpIPredicate::ugt:
1253     return lhs.ugt(rhs);
1254   case arith::CmpIPredicate::uge:
1255     return lhs.uge(rhs);
1256   }
1257   llvm_unreachable("unknown cmpi predicate kind");
1258 }
1259 
1260 /// Returns true if the predicate is true for two equal operands.
1261 static bool applyCmpPredicateToEqualOperands(arith::CmpIPredicate predicate) {
1262   switch (predicate) {
1263   case arith::CmpIPredicate::eq:
1264   case arith::CmpIPredicate::sle:
1265   case arith::CmpIPredicate::sge:
1266   case arith::CmpIPredicate::ule:
1267   case arith::CmpIPredicate::uge:
1268     return true;
1269   case arith::CmpIPredicate::ne:
1270   case arith::CmpIPredicate::slt:
1271   case arith::CmpIPredicate::sgt:
1272   case arith::CmpIPredicate::ult:
1273   case arith::CmpIPredicate::ugt:
1274     return false;
1275   }
1276   llvm_unreachable("unknown cmpi predicate kind");
1277 }
1278 
1279 static Attribute getBoolAttribute(Type type, MLIRContext *ctx, bool value) {
1280   auto boolAttr = BoolAttr::get(ctx, value);
1281   ShapedType shapedType = type.dyn_cast_or_null<ShapedType>();
1282   if (!shapedType)
1283     return boolAttr;
1284   return DenseElementsAttr::get(shapedType, boolAttr);
1285 }
1286 
1287 OpFoldResult arith::CmpIOp::fold(ArrayRef<Attribute> operands) {
1288   assert(operands.size() == 2 && "cmpi takes two operands");
1289 
1290   // cmpi(pred, x, x)
1291   if (getLhs() == getRhs()) {
1292     auto val = applyCmpPredicateToEqualOperands(getPredicate());
1293     return getBoolAttribute(getType(), getContext(), val);
1294   }
1295 
1296   if (matchPattern(getRhs(), m_Zero())) {
1297     if (auto extOp = getLhs().getDefiningOp<ExtSIOp>()) {
1298       if (extOp.getOperand().getType().cast<IntegerType>().getWidth() == 1) {
1299         // extsi(%x : i1 -> iN) != 0  ->  %x
1300         if (getPredicate() == arith::CmpIPredicate::ne) {
1301           return extOp.getOperand();
1302         }
1303       }
1304     }
1305     if (auto extOp = getLhs().getDefiningOp<ExtUIOp>()) {
1306       if (extOp.getOperand().getType().cast<IntegerType>().getWidth() == 1) {
1307         // extui(%x : i1 -> iN) != 0  ->  %x
1308         if (getPredicate() == arith::CmpIPredicate::ne) {
1309           return extOp.getOperand();
1310         }
1311       }
1312     }
1313   }
1314 
1315   auto lhs = operands.front().dyn_cast_or_null<IntegerAttr>();
1316   auto rhs = operands.back().dyn_cast_or_null<IntegerAttr>();
1317   if (!lhs || !rhs)
1318     return {};
1319 
1320   auto val = applyCmpPredicate(getPredicate(), lhs.getValue(), rhs.getValue());
1321   return BoolAttr::get(getContext(), val);
1322 }
1323 
1324 //===----------------------------------------------------------------------===//
1325 // CmpFOp
1326 //===----------------------------------------------------------------------===//
1327 
1328 /// Compute `lhs` `pred` `rhs`, where `pred` is one of the known floating point
1329 /// comparison predicates.
1330 bool mlir::arith::applyCmpPredicate(arith::CmpFPredicate predicate,
1331                                     const APFloat &lhs, const APFloat &rhs) {
1332   auto cmpResult = lhs.compare(rhs);
1333   switch (predicate) {
1334   case arith::CmpFPredicate::AlwaysFalse:
1335     return false;
1336   case arith::CmpFPredicate::OEQ:
1337     return cmpResult == APFloat::cmpEqual;
1338   case arith::CmpFPredicate::OGT:
1339     return cmpResult == APFloat::cmpGreaterThan;
1340   case arith::CmpFPredicate::OGE:
1341     return cmpResult == APFloat::cmpGreaterThan ||
1342            cmpResult == APFloat::cmpEqual;
1343   case arith::CmpFPredicate::OLT:
1344     return cmpResult == APFloat::cmpLessThan;
1345   case arith::CmpFPredicate::OLE:
1346     return cmpResult == APFloat::cmpLessThan || cmpResult == APFloat::cmpEqual;
1347   case arith::CmpFPredicate::ONE:
1348     return cmpResult != APFloat::cmpUnordered && cmpResult != APFloat::cmpEqual;
1349   case arith::CmpFPredicate::ORD:
1350     return cmpResult != APFloat::cmpUnordered;
1351   case arith::CmpFPredicate::UEQ:
1352     return cmpResult == APFloat::cmpUnordered || cmpResult == APFloat::cmpEqual;
1353   case arith::CmpFPredicate::UGT:
1354     return cmpResult == APFloat::cmpUnordered ||
1355            cmpResult == APFloat::cmpGreaterThan;
1356   case arith::CmpFPredicate::UGE:
1357     return cmpResult == APFloat::cmpUnordered ||
1358            cmpResult == APFloat::cmpGreaterThan ||
1359            cmpResult == APFloat::cmpEqual;
1360   case arith::CmpFPredicate::ULT:
1361     return cmpResult == APFloat::cmpUnordered ||
1362            cmpResult == APFloat::cmpLessThan;
1363   case arith::CmpFPredicate::ULE:
1364     return cmpResult == APFloat::cmpUnordered ||
1365            cmpResult == APFloat::cmpLessThan || cmpResult == APFloat::cmpEqual;
1366   case arith::CmpFPredicate::UNE:
1367     return cmpResult != APFloat::cmpEqual;
1368   case arith::CmpFPredicate::UNO:
1369     return cmpResult == APFloat::cmpUnordered;
1370   case arith::CmpFPredicate::AlwaysTrue:
1371     return true;
1372   }
1373   llvm_unreachable("unknown cmpf predicate kind");
1374 }
1375 
1376 OpFoldResult arith::CmpFOp::fold(ArrayRef<Attribute> operands) {
1377   assert(operands.size() == 2 && "cmpf takes two operands");
1378 
1379   auto lhs = operands.front().dyn_cast_or_null<FloatAttr>();
1380   auto rhs = operands.back().dyn_cast_or_null<FloatAttr>();
1381 
1382   // If one operand is NaN, making them both NaN does not change the result.
1383   if (lhs && lhs.getValue().isNaN())
1384     rhs = lhs;
1385   if (rhs && rhs.getValue().isNaN())
1386     lhs = rhs;
1387 
1388   if (!lhs || !rhs)
1389     return {};
1390 
1391   auto val = applyCmpPredicate(getPredicate(), lhs.getValue(), rhs.getValue());
1392   return BoolAttr::get(getContext(), val);
1393 }
1394 
1395 //===----------------------------------------------------------------------===//
1396 // SelectOp
1397 //===----------------------------------------------------------------------===//
1398 
1399 // Transforms a select of a boolean to arithmetic operations
1400 //
1401 //  arith.select %arg, %x, %y : i1
1402 //
1403 //  becomes
1404 //
1405 //  and(%arg, %x) or and(!%arg, %y)
1406 struct SelectI1Simplify : public OpRewritePattern<arith::SelectOp> {
1407   using OpRewritePattern<arith::SelectOp>::OpRewritePattern;
1408 
1409   LogicalResult matchAndRewrite(arith::SelectOp op,
1410                                 PatternRewriter &rewriter) const override {
1411     if (!op.getType().isInteger(1))
1412       return failure();
1413 
1414     Value falseConstant =
1415         rewriter.create<arith::ConstantIntOp>(op.getLoc(), true, 1);
1416     Value notCondition = rewriter.create<arith::XOrIOp>(
1417         op.getLoc(), op.getCondition(), falseConstant);
1418 
1419     Value trueVal = rewriter.create<arith::AndIOp>(
1420         op.getLoc(), op.getCondition(), op.getTrueValue());
1421     Value falseVal = rewriter.create<arith::AndIOp>(op.getLoc(), notCondition,
1422                                                     op.getFalseValue());
1423     rewriter.replaceOpWithNewOp<arith::OrIOp>(op, trueVal, falseVal);
1424     return success();
1425   }
1426 };
1427 
1428 //  select %arg, %c1, %c0 => extui %arg
1429 struct SelectToExtUI : public OpRewritePattern<arith::SelectOp> {
1430   using OpRewritePattern<arith::SelectOp>::OpRewritePattern;
1431 
1432   LogicalResult matchAndRewrite(arith::SelectOp op,
1433                                 PatternRewriter &rewriter) const override {
1434     // Cannot extui i1 to i1, or i1 to f32
1435     if (!op.getType().isa<IntegerType>() || op.getType().isInteger(1))
1436       return failure();
1437 
1438     // select %x, c1, %c0 => extui %arg
1439     if (matchPattern(op.getTrueValue(), m_One()))
1440       if (matchPattern(op.getFalseValue(), m_Zero())) {
1441         rewriter.replaceOpWithNewOp<arith::ExtUIOp>(op, op.getType(),
1442                                                     op.getCondition());
1443         return success();
1444       }
1445 
1446     // select %x, c0, %c1 => extui (xor %arg, true)
1447     if (matchPattern(op.getTrueValue(), m_Zero()))
1448       if (matchPattern(op.getFalseValue(), m_One())) {
1449         rewriter.replaceOpWithNewOp<arith::ExtUIOp>(
1450             op, op.getType(),
1451             rewriter.create<arith::XOrIOp>(
1452                 op.getLoc(), op.getCondition(),
1453                 rewriter.create<arith::ConstantIntOp>(
1454                     op.getLoc(), 1, op.getCondition().getType())));
1455         return success();
1456       }
1457 
1458     return failure();
1459   }
1460 };
1461 
1462 void arith::SelectOp::getCanonicalizationPatterns(RewritePatternSet &results,
1463                                                   MLIRContext *context) {
1464   results.insert<SelectI1Simplify, SelectToExtUI>(context);
1465 }
1466 
1467 OpFoldResult arith::SelectOp::fold(ArrayRef<Attribute> operands) {
1468   Value trueVal = getTrueValue();
1469   Value falseVal = getFalseValue();
1470   if (trueVal == falseVal)
1471     return trueVal;
1472 
1473   Value condition = getCondition();
1474 
1475   // select true, %0, %1 => %0
1476   if (matchPattern(condition, m_One()))
1477     return trueVal;
1478 
1479   // select false, %0, %1 => %1
1480   if (matchPattern(condition, m_Zero()))
1481     return falseVal;
1482 
1483   // select %x, true, false => %x
1484   if (getType().isInteger(1))
1485     if (matchPattern(getTrueValue(), m_One()))
1486       if (matchPattern(getFalseValue(), m_Zero()))
1487         return condition;
1488 
1489   if (auto cmp = dyn_cast_or_null<arith::CmpIOp>(condition.getDefiningOp())) {
1490     auto pred = cmp.getPredicate();
1491     if (pred == arith::CmpIPredicate::eq || pred == arith::CmpIPredicate::ne) {
1492       auto cmpLhs = cmp.getLhs();
1493       auto cmpRhs = cmp.getRhs();
1494 
1495       // %0 = arith.cmpi eq, %arg0, %arg1
1496       // %1 = arith.select %0, %arg0, %arg1 => %arg1
1497 
1498       // %0 = arith.cmpi ne, %arg0, %arg1
1499       // %1 = arith.select %0, %arg0, %arg1 => %arg0
1500 
1501       if ((cmpLhs == trueVal && cmpRhs == falseVal) ||
1502           (cmpRhs == trueVal && cmpLhs == falseVal))
1503         return pred == arith::CmpIPredicate::ne ? trueVal : falseVal;
1504     }
1505   }
1506   return nullptr;
1507 }
1508 
1509 ParseResult SelectOp::parse(OpAsmParser &parser, OperationState &result) {
1510   Type conditionType, resultType;
1511   SmallVector<OpAsmParser::OperandType, 3> operands;
1512   if (parser.parseOperandList(operands, /*requiredOperandCount=*/3) ||
1513       parser.parseOptionalAttrDict(result.attributes) ||
1514       parser.parseColonType(resultType))
1515     return failure();
1516 
1517   // Check for the explicit condition type if this is a masked tensor or vector.
1518   if (succeeded(parser.parseOptionalComma())) {
1519     conditionType = resultType;
1520     if (parser.parseType(resultType))
1521       return failure();
1522   } else {
1523     conditionType = parser.getBuilder().getI1Type();
1524   }
1525 
1526   result.addTypes(resultType);
1527   return parser.resolveOperands(operands,
1528                                 {conditionType, resultType, resultType},
1529                                 parser.getNameLoc(), result.operands);
1530 }
1531 
1532 void arith::SelectOp::print(OpAsmPrinter &p) {
1533   p << " " << getOperands();
1534   p.printOptionalAttrDict((*this)->getAttrs());
1535   p << " : ";
1536   if (ShapedType condType = getCondition().getType().dyn_cast<ShapedType>())
1537     p << condType << ", ";
1538   p << getType();
1539 }
1540 
1541 LogicalResult arith::SelectOp::verify() {
1542   Type conditionType = getCondition().getType();
1543   if (conditionType.isSignlessInteger(1))
1544     return success();
1545 
1546   // If the result type is a vector or tensor, the type can be a mask with the
1547   // same elements.
1548   Type resultType = getType();
1549   if (!resultType.isa<TensorType, VectorType>())
1550     return emitOpError() << "expected condition to be a signless i1, but got "
1551                          << conditionType;
1552   Type shapedConditionType = getI1SameShape(resultType);
1553   if (conditionType != shapedConditionType) {
1554     return emitOpError() << "expected condition type to have the same shape "
1555                             "as the result type, expected "
1556                          << shapedConditionType << ", but got "
1557                          << conditionType;
1558   }
1559   return success();
1560 }
1561 
1562 //===----------------------------------------------------------------------===//
1563 // Atomic Enum
1564 //===----------------------------------------------------------------------===//
1565 
1566 /// Returns the identity value attribute associated with an AtomicRMWKind op.
1567 Attribute mlir::arith::getIdentityValueAttr(AtomicRMWKind kind, Type resultType,
1568                                             OpBuilder &builder, Location loc) {
1569   switch (kind) {
1570   case AtomicRMWKind::maxf:
1571     return builder.getFloatAttr(
1572         resultType,
1573         APFloat::getInf(resultType.cast<FloatType>().getFloatSemantics(),
1574                         /*Negative=*/true));
1575   case AtomicRMWKind::addf:
1576   case AtomicRMWKind::addi:
1577   case AtomicRMWKind::maxu:
1578   case AtomicRMWKind::ori:
1579     return builder.getZeroAttr(resultType);
1580   case AtomicRMWKind::andi:
1581     return builder.getIntegerAttr(
1582         resultType,
1583         APInt::getAllOnes(resultType.cast<IntegerType>().getWidth()));
1584   case AtomicRMWKind::maxs:
1585     return builder.getIntegerAttr(
1586         resultType,
1587         APInt::getSignedMinValue(resultType.cast<IntegerType>().getWidth()));
1588   case AtomicRMWKind::minf:
1589     return builder.getFloatAttr(
1590         resultType,
1591         APFloat::getInf(resultType.cast<FloatType>().getFloatSemantics(),
1592                         /*Negative=*/false));
1593   case AtomicRMWKind::mins:
1594     return builder.getIntegerAttr(
1595         resultType,
1596         APInt::getSignedMaxValue(resultType.cast<IntegerType>().getWidth()));
1597   case AtomicRMWKind::minu:
1598     return builder.getIntegerAttr(
1599         resultType,
1600         APInt::getMaxValue(resultType.cast<IntegerType>().getWidth()));
1601   case AtomicRMWKind::muli:
1602     return builder.getIntegerAttr(resultType, 1);
1603   case AtomicRMWKind::mulf:
1604     return builder.getFloatAttr(resultType, 1);
1605   // TODO: Add remaining reduction operations.
1606   default:
1607     (void)emitOptionalError(loc, "Reduction operation type not supported");
1608     break;
1609   }
1610   return nullptr;
1611 }
1612 
1613 /// Returns the identity value associated with an AtomicRMWKind op.
1614 Value mlir::arith::getIdentityValue(AtomicRMWKind op, Type resultType,
1615                                     OpBuilder &builder, Location loc) {
1616   Attribute attr = getIdentityValueAttr(op, resultType, builder, loc);
1617   return builder.create<arith::ConstantOp>(loc, attr);
1618 }
1619 
1620 /// Return the value obtained by applying the reduction operation kind
1621 /// associated with a binary AtomicRMWKind op to `lhs` and `rhs`.
1622 Value mlir::arith::getReductionOp(AtomicRMWKind op, OpBuilder &builder,
1623                                   Location loc, Value lhs, Value rhs) {
1624   switch (op) {
1625   case AtomicRMWKind::addf:
1626     return builder.create<arith::AddFOp>(loc, lhs, rhs);
1627   case AtomicRMWKind::addi:
1628     return builder.create<arith::AddIOp>(loc, lhs, rhs);
1629   case AtomicRMWKind::mulf:
1630     return builder.create<arith::MulFOp>(loc, lhs, rhs);
1631   case AtomicRMWKind::muli:
1632     return builder.create<arith::MulIOp>(loc, lhs, rhs);
1633   case AtomicRMWKind::maxf:
1634     return builder.create<arith::MaxFOp>(loc, lhs, rhs);
1635   case AtomicRMWKind::minf:
1636     return builder.create<arith::MinFOp>(loc, lhs, rhs);
1637   case AtomicRMWKind::maxs:
1638     return builder.create<arith::MaxSIOp>(loc, lhs, rhs);
1639   case AtomicRMWKind::mins:
1640     return builder.create<arith::MinSIOp>(loc, lhs, rhs);
1641   case AtomicRMWKind::maxu:
1642     return builder.create<arith::MaxUIOp>(loc, lhs, rhs);
1643   case AtomicRMWKind::minu:
1644     return builder.create<arith::MinUIOp>(loc, lhs, rhs);
1645   case AtomicRMWKind::ori:
1646     return builder.create<arith::OrIOp>(loc, lhs, rhs);
1647   case AtomicRMWKind::andi:
1648     return builder.create<arith::AndIOp>(loc, lhs, rhs);
1649   // TODO: Add remaining reduction operations.
1650   default:
1651     (void)emitOptionalError(loc, "Reduction operation type not supported");
1652     break;
1653   }
1654   return nullptr;
1655 }
1656 
1657 //===----------------------------------------------------------------------===//
1658 // TableGen'd op method definitions
1659 //===----------------------------------------------------------------------===//
1660 
1661 #define GET_OP_CLASSES
1662 #include "mlir/Dialect/Arithmetic/IR/ArithmeticOps.cpp.inc"
1663 
1664 //===----------------------------------------------------------------------===//
1665 // TableGen'd enum attribute definitions
1666 //===----------------------------------------------------------------------===//
1667 
1668 #include "mlir/Dialect/Arithmetic/IR/ArithmeticOpsEnums.cpp.inc"
1669