1 //===- ArithmeticOps.cpp - MLIR Arithmetic dialect ops implementation -----===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 
9 #include "mlir/Dialect/Arithmetic/IR/Arithmetic.h"
10 #include "mlir/Dialect/CommonFolders.h"
11 #include "mlir/IR/Builders.h"
12 #include "mlir/IR/Matchers.h"
13 #include "mlir/IR/OpImplementation.h"
14 #include "mlir/IR/PatternMatch.h"
15 #include "mlir/IR/TypeUtilities.h"
16 
17 using namespace mlir;
18 using namespace mlir::arith;
19 
20 //===----------------------------------------------------------------------===//
21 // Pattern helpers
22 //===----------------------------------------------------------------------===//
23 
24 static IntegerAttr addIntegerAttrs(PatternRewriter &builder, Value res,
25                                    Attribute lhs, Attribute rhs) {
26   return builder.getIntegerAttr(res.getType(),
27                                 lhs.cast<IntegerAttr>().getInt() +
28                                     rhs.cast<IntegerAttr>().getInt());
29 }
30 
31 static IntegerAttr subIntegerAttrs(PatternRewriter &builder, Value res,
32                                    Attribute lhs, Attribute rhs) {
33   return builder.getIntegerAttr(res.getType(),
34                                 lhs.cast<IntegerAttr>().getInt() -
35                                     rhs.cast<IntegerAttr>().getInt());
36 }
37 
38 /// Invert an integer comparison predicate.
39 static arith::CmpIPredicate invertPredicate(arith::CmpIPredicate pred) {
40   switch (pred) {
41   case arith::CmpIPredicate::eq:
42     return arith::CmpIPredicate::ne;
43   case arith::CmpIPredicate::ne:
44     return arith::CmpIPredicate::eq;
45   case arith::CmpIPredicate::slt:
46     return arith::CmpIPredicate::sge;
47   case arith::CmpIPredicate::sle:
48     return arith::CmpIPredicate::sgt;
49   case arith::CmpIPredicate::sgt:
50     return arith::CmpIPredicate::sle;
51   case arith::CmpIPredicate::sge:
52     return arith::CmpIPredicate::slt;
53   case arith::CmpIPredicate::ult:
54     return arith::CmpIPredicate::uge;
55   case arith::CmpIPredicate::ule:
56     return arith::CmpIPredicate::ugt;
57   case arith::CmpIPredicate::ugt:
58     return arith::CmpIPredicate::ule;
59   case arith::CmpIPredicate::uge:
60     return arith::CmpIPredicate::ult;
61   }
62   llvm_unreachable("unknown cmpi predicate kind");
63 }
64 
65 static arith::CmpIPredicateAttr invertPredicate(arith::CmpIPredicateAttr pred) {
66   return arith::CmpIPredicateAttr::get(pred.getContext(),
67                                        invertPredicate(pred.getValue()));
68 }
69 
70 //===----------------------------------------------------------------------===//
71 // TableGen'd canonicalization patterns
72 //===----------------------------------------------------------------------===//
73 
74 namespace {
75 #include "ArithmeticCanonicalization.inc"
76 } // end anonymous namespace
77 
78 //===----------------------------------------------------------------------===//
79 // ConstantOp
80 //===----------------------------------------------------------------------===//
81 
82 void arith::ConstantOp::getAsmResultNames(
83     function_ref<void(Value, StringRef)> setNameFn) {
84   auto type = getType();
85   if (auto intCst = getValue().dyn_cast<IntegerAttr>()) {
86     auto intType = type.dyn_cast<IntegerType>();
87 
88     // Sugar i1 constants with 'true' and 'false'.
89     if (intType && intType.getWidth() == 1)
90       return setNameFn(getResult(), (intCst.getInt() ? "true" : "false"));
91 
92     // Otherwise, build a compex name with the value and type.
93     SmallString<32> specialNameBuffer;
94     llvm::raw_svector_ostream specialName(specialNameBuffer);
95     specialName << 'c' << intCst.getInt();
96     if (intType)
97       specialName << '_' << type;
98     setNameFn(getResult(), specialName.str());
99   } else {
100     setNameFn(getResult(), "cst");
101   }
102 }
103 
104 /// TODO: disallow arith.constant to return anything other than signless integer
105 /// or float like.
106 static LogicalResult verify(arith::ConstantOp op) {
107   auto type = op.getType();
108   // The value's type must match the return type.
109   if (op.getValue().getType() != type) {
110     return op.emitOpError() << "value type " << op.getValue().getType()
111                             << " must match return type: " << type;
112   }
113   // Integer values must be signless.
114   if (type.isa<IntegerType>() && !type.cast<IntegerType>().isSignless())
115     return op.emitOpError("integer return type must be signless");
116   // Any float or elements attribute are acceptable.
117   if (!op.getValue().isa<IntegerAttr, FloatAttr, ElementsAttr>()) {
118     return op.emitOpError(
119         "value must be an integer, float, or elements attribute");
120   }
121   return success();
122 }
123 
124 bool arith::ConstantOp::isBuildableWith(Attribute value, Type type) {
125   // The value's type must be the same as the provided type.
126   if (value.getType() != type)
127     return false;
128   // Integer values must be signless.
129   if (type.isa<IntegerType>() && !type.cast<IntegerType>().isSignless())
130     return false;
131   // Integer, float, and element attributes are buildable.
132   return value.isa<IntegerAttr, FloatAttr, ElementsAttr>();
133 }
134 
135 OpFoldResult arith::ConstantOp::fold(ArrayRef<Attribute> operands) {
136   return getValue();
137 }
138 
139 void arith::ConstantIntOp::build(OpBuilder &builder, OperationState &result,
140                                  int64_t value, unsigned width) {
141   auto type = builder.getIntegerType(width);
142   arith::ConstantOp::build(builder, result, type,
143                            builder.getIntegerAttr(type, value));
144 }
145 
146 void arith::ConstantIntOp::build(OpBuilder &builder, OperationState &result,
147                                  int64_t value, Type type) {
148   assert(type.isSignlessInteger() &&
149          "ConstantIntOp can only have signless integer type values");
150   arith::ConstantOp::build(builder, result, type,
151                            builder.getIntegerAttr(type, value));
152 }
153 
154 bool arith::ConstantIntOp::classof(Operation *op) {
155   if (auto constOp = dyn_cast_or_null<arith::ConstantOp>(op))
156     return constOp.getType().isSignlessInteger();
157   return false;
158 }
159 
160 void arith::ConstantFloatOp::build(OpBuilder &builder, OperationState &result,
161                                    const APFloat &value, FloatType type) {
162   arith::ConstantOp::build(builder, result, type,
163                            builder.getFloatAttr(type, value));
164 }
165 
166 bool arith::ConstantFloatOp::classof(Operation *op) {
167   if (auto constOp = dyn_cast_or_null<arith::ConstantOp>(op))
168     return constOp.getType().isa<FloatType>();
169   return false;
170 }
171 
172 void arith::ConstantIndexOp::build(OpBuilder &builder, OperationState &result,
173                                    int64_t value) {
174   arith::ConstantOp::build(builder, result, builder.getIndexType(),
175                            builder.getIndexAttr(value));
176 }
177 
178 bool arith::ConstantIndexOp::classof(Operation *op) {
179   if (auto constOp = dyn_cast_or_null<arith::ConstantOp>(op))
180     return constOp.getType().isIndex();
181   return false;
182 }
183 
184 //===----------------------------------------------------------------------===//
185 // AddIOp
186 //===----------------------------------------------------------------------===//
187 
188 OpFoldResult arith::AddIOp::fold(ArrayRef<Attribute> operands) {
189   // addi(x, 0) -> x
190   if (matchPattern(getRhs(), m_Zero()))
191     return getLhs();
192 
193   return constFoldBinaryOp<IntegerAttr>(operands,
194                                         [](APInt a, APInt b) { return a + b; });
195 }
196 
197 void arith::AddIOp::getCanonicalizationPatterns(
198     OwningRewritePatternList &patterns, MLIRContext *context) {
199   patterns.insert<AddIAddConstant, AddISubConstantRHS, AddISubConstantLHS>(
200       context);
201 }
202 
203 //===----------------------------------------------------------------------===//
204 // SubIOp
205 //===----------------------------------------------------------------------===//
206 
207 OpFoldResult arith::SubIOp::fold(ArrayRef<Attribute> operands) {
208   // subi(x,x) -> 0
209   if (getOperand(0) == getOperand(1))
210     return Builder(getContext()).getZeroAttr(getType());
211   // subi(x,0) -> x
212   if (matchPattern(getRhs(), m_Zero()))
213     return getLhs();
214 
215   return constFoldBinaryOp<IntegerAttr>(operands,
216                                         [](APInt a, APInt b) { return a - b; });
217 }
218 
219 void arith::SubIOp::getCanonicalizationPatterns(
220     OwningRewritePatternList &patterns, MLIRContext *context) {
221   patterns.insert<SubIRHSAddConstant, SubILHSAddConstant, SubIRHSSubConstantRHS,
222                   SubIRHSSubConstantLHS, SubILHSSubConstantRHS,
223                   SubILHSSubConstantLHS>(context);
224 }
225 
226 //===----------------------------------------------------------------------===//
227 // MulIOp
228 //===----------------------------------------------------------------------===//
229 
230 OpFoldResult arith::MulIOp::fold(ArrayRef<Attribute> operands) {
231   // muli(x, 0) -> 0
232   if (matchPattern(getRhs(), m_Zero()))
233     return getRhs();
234   // muli(x, 1) -> x
235   if (matchPattern(getRhs(), m_One()))
236     return getOperand(0);
237   // TODO: Handle the overflow case.
238 
239   // default folder
240   return constFoldBinaryOp<IntegerAttr>(operands,
241                                         [](APInt a, APInt b) { return a * b; });
242 }
243 
244 //===----------------------------------------------------------------------===//
245 // DivUIOp
246 //===----------------------------------------------------------------------===//
247 
248 OpFoldResult arith::DivUIOp::fold(ArrayRef<Attribute> operands) {
249   // Don't fold if it would require a division by zero.
250   bool div0 = false;
251   auto result = constFoldBinaryOp<IntegerAttr>(operands, [&](APInt a, APInt b) {
252     if (div0 || !b) {
253       div0 = true;
254       return a;
255     }
256     return a.udiv(b);
257   });
258 
259   // Fold out division by one. Assumes all tensors of all ones are splats.
260   if (auto rhs = operands[1].dyn_cast_or_null<IntegerAttr>()) {
261     if (rhs.getValue() == 1)
262       return getLhs();
263   } else if (auto rhs = operands[1].dyn_cast_or_null<SplatElementsAttr>()) {
264     if (rhs.getSplatValue<IntegerAttr>().getValue() == 1)
265       return getLhs();
266   }
267 
268   return div0 ? Attribute() : result;
269 }
270 
271 //===----------------------------------------------------------------------===//
272 // DivSIOp
273 //===----------------------------------------------------------------------===//
274 
275 OpFoldResult arith::DivSIOp::fold(ArrayRef<Attribute> operands) {
276   // Don't fold if it would overflow or if it requires a division by zero.
277   bool overflowOrDiv0 = false;
278   auto result = constFoldBinaryOp<IntegerAttr>(operands, [&](APInt a, APInt b) {
279     if (overflowOrDiv0 || !b) {
280       overflowOrDiv0 = true;
281       return a;
282     }
283     return a.sdiv_ov(b, overflowOrDiv0);
284   });
285 
286   // Fold out division by one. Assumes all tensors of all ones are splats.
287   if (auto rhs = operands[1].dyn_cast_or_null<IntegerAttr>()) {
288     if (rhs.getValue() == 1)
289       return getLhs();
290   } else if (auto rhs = operands[1].dyn_cast_or_null<SplatElementsAttr>()) {
291     if (rhs.getSplatValue<IntegerAttr>().getValue() == 1)
292       return getLhs();
293   }
294 
295   return overflowOrDiv0 ? Attribute() : result;
296 }
297 
298 //===----------------------------------------------------------------------===//
299 // Ceil and floor division folding helpers
300 //===----------------------------------------------------------------------===//
301 
302 static APInt signedCeilNonnegInputs(APInt a, APInt b, bool &overflow) {
303   // Returns (a-1)/b + 1
304   APInt one(a.getBitWidth(), 1, true); // Signed value 1.
305   APInt val = a.ssub_ov(one, overflow).sdiv_ov(b, overflow);
306   return val.sadd_ov(one, overflow);
307 }
308 
309 //===----------------------------------------------------------------------===//
310 // CeilDivUIOp
311 //===----------------------------------------------------------------------===//
312 
313 OpFoldResult arith::CeilDivUIOp::fold(ArrayRef<Attribute> operands) {
314   bool overflowOrDiv0 = false;
315   auto result = constFoldBinaryOp<IntegerAttr>(operands, [&](APInt a, APInt b) {
316     if (overflowOrDiv0 || !b) {
317       overflowOrDiv0 = true;
318       return a;
319     }
320     APInt quotient = a.udiv(b);
321     if (!a.urem(b))
322       return quotient;
323     APInt one(a.getBitWidth(), 1, true);
324     return quotient.uadd_ov(one, overflowOrDiv0);
325   });
326   // Fold out ceil division by one. Assumes all tensors of all ones are
327   // splats.
328   if (auto rhs = operands[1].dyn_cast_or_null<IntegerAttr>()) {
329     if (rhs.getValue() == 1)
330       return getLhs();
331   } else if (auto rhs = operands[1].dyn_cast_or_null<SplatElementsAttr>()) {
332     if (rhs.getSplatValue<IntegerAttr>().getValue() == 1)
333       return getLhs();
334   }
335 
336   return overflowOrDiv0 ? Attribute() : result;
337 }
338 
339 //===----------------------------------------------------------------------===//
340 // CeilDivSIOp
341 //===----------------------------------------------------------------------===//
342 
343 OpFoldResult arith::CeilDivSIOp::fold(ArrayRef<Attribute> operands) {
344   // Don't fold if it would overflow or if it requires a division by zero.
345   bool overflowOrDiv0 = false;
346   auto result = constFoldBinaryOp<IntegerAttr>(operands, [&](APInt a, APInt b) {
347     if (overflowOrDiv0 || !b) {
348       overflowOrDiv0 = true;
349       return a;
350     }
351     unsigned bits = a.getBitWidth();
352     APInt zero = APInt::getZero(bits);
353     if (a.sgt(zero) && b.sgt(zero)) {
354       // Both positive, return ceil(a, b).
355       return signedCeilNonnegInputs(a, b, overflowOrDiv0);
356     }
357     if (a.slt(zero) && b.slt(zero)) {
358       // Both negative, return ceil(-a, -b).
359       APInt posA = zero.ssub_ov(a, overflowOrDiv0);
360       APInt posB = zero.ssub_ov(b, overflowOrDiv0);
361       return signedCeilNonnegInputs(posA, posB, overflowOrDiv0);
362     }
363     if (a.slt(zero) && b.sgt(zero)) {
364       // A is negative, b is positive, return - ( -a / b).
365       APInt posA = zero.ssub_ov(a, overflowOrDiv0);
366       APInt div = posA.sdiv_ov(b, overflowOrDiv0);
367       return zero.ssub_ov(div, overflowOrDiv0);
368     }
369     // A is positive (or zero), b is negative, return - (a / -b).
370     APInt posB = zero.ssub_ov(b, overflowOrDiv0);
371     APInt div = a.sdiv_ov(posB, overflowOrDiv0);
372     return zero.ssub_ov(div, overflowOrDiv0);
373   });
374 
375   // Fold out ceil division by one. Assumes all tensors of all ones are
376   // splats.
377   if (auto rhs = operands[1].dyn_cast_or_null<IntegerAttr>()) {
378     if (rhs.getValue() == 1)
379       return getLhs();
380   } else if (auto rhs = operands[1].dyn_cast_or_null<SplatElementsAttr>()) {
381     if (rhs.getSplatValue<IntegerAttr>().getValue() == 1)
382       return getLhs();
383   }
384 
385   return overflowOrDiv0 ? Attribute() : result;
386 }
387 
388 //===----------------------------------------------------------------------===//
389 // FloorDivSIOp
390 //===----------------------------------------------------------------------===//
391 
392 OpFoldResult arith::FloorDivSIOp::fold(ArrayRef<Attribute> operands) {
393   // Don't fold if it would overflow or if it requires a division by zero.
394   bool overflowOrDiv0 = false;
395   auto result = constFoldBinaryOp<IntegerAttr>(operands, [&](APInt a, APInt b) {
396     if (overflowOrDiv0 || !b) {
397       overflowOrDiv0 = true;
398       return a;
399     }
400     unsigned bits = a.getBitWidth();
401     APInt zero = APInt::getZero(bits);
402     if (a.sge(zero) && b.sgt(zero)) {
403       // Both positive (or a is zero), return a / b.
404       return a.sdiv_ov(b, overflowOrDiv0);
405     }
406     if (a.sle(zero) && b.slt(zero)) {
407       // Both negative (or a is zero), return -a / -b.
408       APInt posA = zero.ssub_ov(a, overflowOrDiv0);
409       APInt posB = zero.ssub_ov(b, overflowOrDiv0);
410       return posA.sdiv_ov(posB, overflowOrDiv0);
411     }
412     if (a.slt(zero) && b.sgt(zero)) {
413       // A is negative, b is positive, return - ceil(-a, b).
414       APInt posA = zero.ssub_ov(a, overflowOrDiv0);
415       APInt ceil = signedCeilNonnegInputs(posA, b, overflowOrDiv0);
416       return zero.ssub_ov(ceil, overflowOrDiv0);
417     }
418     // A is positive, b is negative, return - ceil(a, -b).
419     APInt posB = zero.ssub_ov(b, overflowOrDiv0);
420     APInt ceil = signedCeilNonnegInputs(a, posB, overflowOrDiv0);
421     return zero.ssub_ov(ceil, overflowOrDiv0);
422   });
423 
424   // Fold out floor division by one. Assumes all tensors of all ones are
425   // splats.
426   if (auto rhs = operands[1].dyn_cast_or_null<IntegerAttr>()) {
427     if (rhs.getValue() == 1)
428       return getLhs();
429   } else if (auto rhs = operands[1].dyn_cast_or_null<SplatElementsAttr>()) {
430     if (rhs.getSplatValue<IntegerAttr>().getValue() == 1)
431       return getLhs();
432   }
433 
434   return overflowOrDiv0 ? Attribute() : result;
435 }
436 
437 //===----------------------------------------------------------------------===//
438 // RemUIOp
439 //===----------------------------------------------------------------------===//
440 
441 OpFoldResult arith::RemUIOp::fold(ArrayRef<Attribute> operands) {
442   auto rhs = operands.back().dyn_cast_or_null<IntegerAttr>();
443   if (!rhs)
444     return {};
445   auto rhsValue = rhs.getValue();
446 
447   // x % 1 = 0
448   if (rhsValue.isOneValue())
449     return IntegerAttr::get(rhs.getType(), APInt(rhsValue.getBitWidth(), 0));
450 
451   // Don't fold if it requires division by zero.
452   if (rhsValue.isNullValue())
453     return {};
454 
455   auto lhs = operands.front().dyn_cast_or_null<IntegerAttr>();
456   if (!lhs)
457     return {};
458   return IntegerAttr::get(lhs.getType(), lhs.getValue().urem(rhsValue));
459 }
460 
461 //===----------------------------------------------------------------------===//
462 // RemSIOp
463 //===----------------------------------------------------------------------===//
464 
465 OpFoldResult arith::RemSIOp::fold(ArrayRef<Attribute> operands) {
466   auto rhs = operands.back().dyn_cast_or_null<IntegerAttr>();
467   if (!rhs)
468     return {};
469   auto rhsValue = rhs.getValue();
470 
471   // x % 1 = 0
472   if (rhsValue.isOneValue())
473     return IntegerAttr::get(rhs.getType(), APInt(rhsValue.getBitWidth(), 0));
474 
475   // Don't fold if it requires division by zero.
476   if (rhsValue.isNullValue())
477     return {};
478 
479   auto lhs = operands.front().dyn_cast_or_null<IntegerAttr>();
480   if (!lhs)
481     return {};
482   return IntegerAttr::get(lhs.getType(), lhs.getValue().srem(rhsValue));
483 }
484 
485 //===----------------------------------------------------------------------===//
486 // AndIOp
487 //===----------------------------------------------------------------------===//
488 
489 OpFoldResult arith::AndIOp::fold(ArrayRef<Attribute> operands) {
490   /// and(x, 0) -> 0
491   if (matchPattern(getRhs(), m_Zero()))
492     return getRhs();
493   /// and(x, allOnes) -> x
494   APInt intValue;
495   if (matchPattern(getRhs(), m_ConstantInt(&intValue)) && intValue.isAllOnes())
496     return getLhs();
497   /// and(x, x) -> x
498   if (getLhs() == getRhs())
499     return getRhs();
500 
501   return constFoldBinaryOp<IntegerAttr>(operands,
502                                         [](APInt a, APInt b) { return a & b; });
503 }
504 
505 //===----------------------------------------------------------------------===//
506 // OrIOp
507 //===----------------------------------------------------------------------===//
508 
509 OpFoldResult arith::OrIOp::fold(ArrayRef<Attribute> operands) {
510   /// or(x, 0) -> x
511   if (matchPattern(getRhs(), m_Zero()))
512     return getLhs();
513   /// or(x, x) -> x
514   if (getLhs() == getRhs())
515     return getRhs();
516   /// or(x, <all ones>) -> <all ones>
517   if (auto rhsAttr = operands[1].dyn_cast_or_null<IntegerAttr>())
518     if (rhsAttr.getValue().isAllOnes())
519       return rhsAttr;
520 
521   return constFoldBinaryOp<IntegerAttr>(operands,
522                                         [](APInt a, APInt b) { return a | b; });
523 }
524 
525 //===----------------------------------------------------------------------===//
526 // XOrIOp
527 //===----------------------------------------------------------------------===//
528 
529 OpFoldResult arith::XOrIOp::fold(ArrayRef<Attribute> operands) {
530   /// xor(x, 0) -> x
531   if (matchPattern(getRhs(), m_Zero()))
532     return getLhs();
533   /// xor(x, x) -> 0
534   if (getLhs() == getRhs())
535     return Builder(getContext()).getZeroAttr(getType());
536 
537   return constFoldBinaryOp<IntegerAttr>(operands,
538                                         [](APInt a, APInt b) { return a ^ b; });
539 }
540 
541 void arith::XOrIOp::getCanonicalizationPatterns(
542     OwningRewritePatternList &patterns, MLIRContext *context) {
543   patterns.insert<XOrINotCmpI>(context);
544 }
545 
546 //===----------------------------------------------------------------------===//
547 // AddFOp
548 //===----------------------------------------------------------------------===//
549 
550 OpFoldResult arith::AddFOp::fold(ArrayRef<Attribute> operands) {
551   return constFoldBinaryOp<FloatAttr>(
552       operands, [](APFloat a, APFloat b) { return a + b; });
553 }
554 
555 //===----------------------------------------------------------------------===//
556 // SubFOp
557 //===----------------------------------------------------------------------===//
558 
559 OpFoldResult arith::SubFOp::fold(ArrayRef<Attribute> operands) {
560   return constFoldBinaryOp<FloatAttr>(
561       operands, [](APFloat a, APFloat b) { return a - b; });
562 }
563 
564 //===----------------------------------------------------------------------===//
565 // MaxSIOp
566 //===----------------------------------------------------------------------===//
567 
568 OpFoldResult MaxSIOp::fold(ArrayRef<Attribute> operands) {
569   assert(operands.size() == 2 && "binary operation takes two operands");
570 
571   // maxsi(x,x) -> x
572   if (getLhs() == getRhs())
573     return getRhs();
574 
575   APInt intValue;
576   // maxsi(x,MAX_INT) -> MAX_INT
577   if (matchPattern(getRhs(), m_ConstantInt(&intValue)) &&
578       intValue.isMaxSignedValue())
579     return getRhs();
580 
581   // maxsi(x, MIN_INT) -> x
582   if (matchPattern(getRhs(), m_ConstantInt(&intValue)) &&
583       intValue.isMinSignedValue())
584     return getLhs();
585 
586   return constFoldBinaryOp<IntegerAttr>(
587       operands, [](APInt a, APInt b) { return llvm::APIntOps::smax(a, b); });
588 }
589 
590 //===----------------------------------------------------------------------===//
591 // MaxUIOp
592 //===----------------------------------------------------------------------===//
593 
594 OpFoldResult MaxUIOp::fold(ArrayRef<Attribute> operands) {
595   assert(operands.size() == 2 && "binary operation takes two operands");
596 
597   // maxui(x,x) -> x
598   if (getLhs() == getRhs())
599     return getRhs();
600 
601   APInt intValue;
602   // maxui(x,MAX_INT) -> MAX_INT
603   if (matchPattern(getRhs(), m_ConstantInt(&intValue)) && intValue.isMaxValue())
604     return getRhs();
605 
606   // maxui(x, MIN_INT) -> x
607   if (matchPattern(getRhs(), m_ConstantInt(&intValue)) && intValue.isMinValue())
608     return getLhs();
609 
610   return constFoldBinaryOp<IntegerAttr>(
611       operands, [](APInt a, APInt b) { return llvm::APIntOps::umax(a, b); });
612 }
613 
614 //===----------------------------------------------------------------------===//
615 // MinSIOp
616 //===----------------------------------------------------------------------===//
617 
618 OpFoldResult MinSIOp::fold(ArrayRef<Attribute> operands) {
619   assert(operands.size() == 2 && "binary operation takes two operands");
620 
621   // minsi(x,x) -> x
622   if (getLhs() == getRhs())
623     return getRhs();
624 
625   APInt intValue;
626   // minsi(x,MIN_INT) -> MIN_INT
627   if (matchPattern(getRhs(), m_ConstantInt(&intValue)) &&
628       intValue.isMinSignedValue())
629     return getRhs();
630 
631   // minsi(x, MAX_INT) -> x
632   if (matchPattern(getRhs(), m_ConstantInt(&intValue)) &&
633       intValue.isMaxSignedValue())
634     return getLhs();
635 
636   return constFoldBinaryOp<IntegerAttr>(
637       operands, [](APInt a, APInt b) { return llvm::APIntOps::smin(a, b); });
638 }
639 
640 //===----------------------------------------------------------------------===//
641 // MinUIOp
642 //===----------------------------------------------------------------------===//
643 
644 OpFoldResult MinUIOp::fold(ArrayRef<Attribute> operands) {
645   assert(operands.size() == 2 && "binary operation takes two operands");
646 
647   // minui(x,x) -> x
648   if (getLhs() == getRhs())
649     return getRhs();
650 
651   APInt intValue;
652   // minui(x,MIN_INT) -> MIN_INT
653   if (matchPattern(getRhs(), m_ConstantInt(&intValue)) && intValue.isMinValue())
654     return getRhs();
655 
656   // minui(x, MAX_INT) -> x
657   if (matchPattern(getRhs(), m_ConstantInt(&intValue)) && intValue.isMaxValue())
658     return getLhs();
659 
660   return constFoldBinaryOp<IntegerAttr>(
661       operands, [](APInt a, APInt b) { return llvm::APIntOps::umin(a, b); });
662 }
663 
664 //===----------------------------------------------------------------------===//
665 // MulFOp
666 //===----------------------------------------------------------------------===//
667 
668 OpFoldResult arith::MulFOp::fold(ArrayRef<Attribute> operands) {
669   return constFoldBinaryOp<FloatAttr>(
670       operands, [](APFloat a, APFloat b) { return a * b; });
671 }
672 
673 //===----------------------------------------------------------------------===//
674 // DivFOp
675 //===----------------------------------------------------------------------===//
676 
677 OpFoldResult arith::DivFOp::fold(ArrayRef<Attribute> operands) {
678   return constFoldBinaryOp<FloatAttr>(
679       operands, [](APFloat a, APFloat b) { return a / b; });
680 }
681 
682 //===----------------------------------------------------------------------===//
683 // Utility functions for verifying cast ops
684 //===----------------------------------------------------------------------===//
685 
/// Type-level list used only as a tag argument to pass a parameter pack (see
/// getUnderlyingType). Its values are null pointers that are never read.
template <typename... Ts>
using type_list = std::tuple<Ts...> *;
688 
689 /// Returns a non-null type only if the provided type is one of the allowed
690 /// types or one of the allowed shaped types of the allowed types. Returns the
691 /// element type if a valid shaped type is provided.
692 template <typename... ShapedTypes, typename... ElementTypes>
693 static Type getUnderlyingType(Type type, type_list<ShapedTypes...>,
694                               type_list<ElementTypes...>) {
695   if (type.isa<ShapedType>() && !type.isa<ShapedTypes...>())
696     return {};
697 
698   auto underlyingType = getElementTypeOrSelf(type);
699   if (!underlyingType.isa<ElementTypes...>())
700     return {};
701 
702   return underlyingType;
703 }
704 
705 /// Get allowed underlying types for vectors and tensors.
706 template <typename... ElementTypes>
707 static Type getTypeIfLike(Type type) {
708   return getUnderlyingType(type, type_list<VectorType, TensorType>(),
709                            type_list<ElementTypes...>());
710 }
711 
712 /// Get allowed underlying types for vectors, tensors, and memrefs.
713 template <typename... ElementTypes>
714 static Type getTypeIfLikeOrMemRef(Type type) {
715   return getUnderlyingType(type,
716                            type_list<VectorType, TensorType, MemRefType>(),
717                            type_list<ElementTypes...>());
718 }
719 
720 static bool areValidCastInputsAndOutputs(TypeRange inputs, TypeRange outputs) {
721   return inputs.size() == 1 && outputs.size() == 1 &&
722          succeeded(verifyCompatibleShapes(inputs.front(), outputs.front()));
723 }
724 
725 //===----------------------------------------------------------------------===//
726 // Verifiers for integer and floating point extension/truncation ops
727 //===----------------------------------------------------------------------===//
728 
729 // Extend ops can only extend to a wider type.
730 template <typename ValType, typename Op>
731 static LogicalResult verifyExtOp(Op op) {
732   Type srcType = getElementTypeOrSelf(op.getIn().getType());
733   Type dstType = getElementTypeOrSelf(op.getType());
734 
735   if (srcType.cast<ValType>().getWidth() >= dstType.cast<ValType>().getWidth())
736     return op.emitError("result type ")
737            << dstType << " must be wider than operand type " << srcType;
738 
739   return success();
740 }
741 
742 // Truncate ops can only truncate to a shorter type.
743 template <typename ValType, typename Op>
744 static LogicalResult verifyTruncateOp(Op op) {
745   Type srcType = getElementTypeOrSelf(op.getIn().getType());
746   Type dstType = getElementTypeOrSelf(op.getType());
747 
748   if (srcType.cast<ValType>().getWidth() <= dstType.cast<ValType>().getWidth())
749     return op.emitError("result type ")
750            << dstType << " must be shorter than operand type " << srcType;
751 
752   return success();
753 }
754 
755 /// Validate a cast that changes the width of a type.
756 template <template <typename> class WidthComparator, typename... ElementTypes>
757 static bool checkWidthChangeCast(TypeRange inputs, TypeRange outputs) {
758   if (!areValidCastInputsAndOutputs(inputs, outputs))
759     return false;
760 
761   auto srcType = getTypeIfLike<ElementTypes...>(inputs.front());
762   auto dstType = getTypeIfLike<ElementTypes...>(outputs.front());
763   if (!srcType || !dstType)
764     return false;
765 
766   return WidthComparator<unsigned>()(dstType.getIntOrFloatBitWidth(),
767                                      srcType.getIntOrFloatBitWidth());
768 }
769 
770 //===----------------------------------------------------------------------===//
771 // ExtUIOp
772 //===----------------------------------------------------------------------===//
773 
774 OpFoldResult arith::ExtUIOp::fold(ArrayRef<Attribute> operands) {
775   if (auto lhs = operands[0].dyn_cast_or_null<IntegerAttr>())
776     return IntegerAttr::get(
777         getType(), lhs.getValue().zext(getType().getIntOrFloatBitWidth()));
778 
779   return {};
780 }
781 
782 bool arith::ExtUIOp::areCastCompatible(TypeRange inputs, TypeRange outputs) {
783   return checkWidthChangeCast<std::greater, IntegerType>(inputs, outputs);
784 }
785 
786 //===----------------------------------------------------------------------===//
787 // ExtSIOp
788 //===----------------------------------------------------------------------===//
789 
790 OpFoldResult arith::ExtSIOp::fold(ArrayRef<Attribute> operands) {
791   if (auto lhs = operands[0].dyn_cast_or_null<IntegerAttr>())
792     return IntegerAttr::get(
793         getType(), lhs.getValue().sext(getType().getIntOrFloatBitWidth()));
794 
795   return {};
796 }
797 
798 bool arith::ExtSIOp::areCastCompatible(TypeRange inputs, TypeRange outputs) {
799   return checkWidthChangeCast<std::greater, IntegerType>(inputs, outputs);
800 }
801 
802 //===----------------------------------------------------------------------===//
803 // ExtFOp
804 //===----------------------------------------------------------------------===//
805 
806 bool arith::ExtFOp::areCastCompatible(TypeRange inputs, TypeRange outputs) {
807   return checkWidthChangeCast<std::greater, FloatType>(inputs, outputs);
808 }
809 
810 //===----------------------------------------------------------------------===//
811 // TruncIOp
812 //===----------------------------------------------------------------------===//
813 
814 OpFoldResult arith::TruncIOp::fold(ArrayRef<Attribute> operands) {
815   // trunci(zexti(a)) -> a
816   // trunci(sexti(a)) -> a
817   if (matchPattern(getOperand(), m_Op<arith::ExtUIOp>()) ||
818       matchPattern(getOperand(), m_Op<arith::ExtSIOp>()))
819     return getOperand().getDefiningOp()->getOperand(0);
820 
821   assert(operands.size() == 1 && "unary operation takes one operand");
822 
823   if (!operands[0])
824     return {};
825 
826   if (auto lhs = operands[0].dyn_cast<IntegerAttr>()) {
827     return IntegerAttr::get(
828         getType(), lhs.getValue().trunc(getType().getIntOrFloatBitWidth()));
829   }
830 
831   return {};
832 }
833 
834 bool arith::TruncIOp::areCastCompatible(TypeRange inputs, TypeRange outputs) {
835   return checkWidthChangeCast<std::less, IntegerType>(inputs, outputs);
836 }
837 
838 //===----------------------------------------------------------------------===//
839 // TruncFOp
840 //===----------------------------------------------------------------------===//
841 
842 /// Perform safe const propagation for truncf, i.e. only propagate if FP value
843 /// can be represented without precision loss or rounding.
844 OpFoldResult arith::TruncFOp::fold(ArrayRef<Attribute> operands) {
845   assert(operands.size() == 1 && "unary operation takes one operand");
846 
847   auto constOperand = operands.front();
848   if (!constOperand || !constOperand.isa<FloatAttr>())
849     return {};
850 
851   // Convert to target type via 'double'.
852   double sourceValue =
853       constOperand.dyn_cast<FloatAttr>().getValue().convertToDouble();
854   auto targetAttr = FloatAttr::get(getType(), sourceValue);
855 
856   // Propagate if constant's value does not change after truncation.
857   if (sourceValue == targetAttr.getValue().convertToDouble())
858     return targetAttr;
859 
860   return {};
861 }
862 
863 bool arith::TruncFOp::areCastCompatible(TypeRange inputs, TypeRange outputs) {
864   return checkWidthChangeCast<std::less, FloatType>(inputs, outputs);
865 }
866 
867 //===----------------------------------------------------------------------===//
868 // Verifiers for casts between integers and floats.
869 //===----------------------------------------------------------------------===//
870 
871 template <typename From, typename To>
872 static bool checkIntFloatCast(TypeRange inputs, TypeRange outputs) {
873   if (!areValidCastInputsAndOutputs(inputs, outputs))
874     return false;
875 
876   auto srcType = getTypeIfLike<From>(inputs.front());
877   auto dstType = getTypeIfLike<To>(outputs.back());
878 
879   return srcType && dstType;
880 }
881 
882 //===----------------------------------------------------------------------===//
883 // UIToFPOp
884 //===----------------------------------------------------------------------===//
885 
886 bool arith::UIToFPOp::areCastCompatible(TypeRange inputs, TypeRange outputs) {
887   return checkIntFloatCast<IntegerType, FloatType>(inputs, outputs);
888 }
889 
890 //===----------------------------------------------------------------------===//
891 // SIToFPOp
892 //===----------------------------------------------------------------------===//
893 
894 bool arith::SIToFPOp::areCastCompatible(TypeRange inputs, TypeRange outputs) {
895   return checkIntFloatCast<IntegerType, FloatType>(inputs, outputs);
896 }
897 
898 //===----------------------------------------------------------------------===//
899 // FPToUIOp
900 //===----------------------------------------------------------------------===//
901 
902 bool arith::FPToUIOp::areCastCompatible(TypeRange inputs, TypeRange outputs) {
903   return checkIntFloatCast<FloatType, IntegerType>(inputs, outputs);
904 }
905 
906 //===----------------------------------------------------------------------===//
907 // FPToSIOp
908 //===----------------------------------------------------------------------===//
909 
910 bool arith::FPToSIOp::areCastCompatible(TypeRange inputs, TypeRange outputs) {
911   return checkIntFloatCast<FloatType, IntegerType>(inputs, outputs);
912 }
913 
914 //===----------------------------------------------------------------------===//
915 // IndexCastOp
916 //===----------------------------------------------------------------------===//
917 
918 bool arith::IndexCastOp::areCastCompatible(TypeRange inputs,
919                                            TypeRange outputs) {
920   if (!areValidCastInputsAndOutputs(inputs, outputs))
921     return false;
922 
923   auto srcType = getTypeIfLikeOrMemRef<IntegerType, IndexType>(inputs.front());
924   auto dstType = getTypeIfLikeOrMemRef<IntegerType, IndexType>(outputs.front());
925   if (!srcType || !dstType)
926     return false;
927 
928   return (srcType.isIndex() && dstType.isSignlessInteger()) ||
929          (srcType.isSignlessInteger() && dstType.isIndex());
930 }
931 
932 OpFoldResult arith::IndexCastOp::fold(ArrayRef<Attribute> operands) {
933   // index_cast(constant) -> constant
934   // A little hack because we go through int. Otherwise, the size of the
935   // constant might need to change.
936   if (auto value = operands[0].dyn_cast_or_null<IntegerAttr>())
937     return IntegerAttr::get(getType(), value.getInt());
938 
939   return {};
940 }
941 
942 void arith::IndexCastOp::getCanonicalizationPatterns(
943     OwningRewritePatternList &patterns, MLIRContext *context) {
944   patterns.insert<IndexCastOfIndexCast, IndexCastOfExtSI>(context);
945 }
946 
947 //===----------------------------------------------------------------------===//
948 // BitcastOp
949 //===----------------------------------------------------------------------===//
950 
951 bool arith::BitcastOp::areCastCompatible(TypeRange inputs, TypeRange outputs) {
952   if (!areValidCastInputsAndOutputs(inputs, outputs))
953     return false;
954 
955   auto srcType =
956       getTypeIfLikeOrMemRef<IntegerType, IndexType, FloatType>(inputs.front());
957   auto dstType =
958       getTypeIfLikeOrMemRef<IntegerType, IndexType, FloatType>(outputs.front());
959   if (!srcType || !dstType)
960     return false;
961 
962   return srcType.getIntOrFloatBitWidth() == dstType.getIntOrFloatBitWidth();
963 }
964 
965 OpFoldResult arith::BitcastOp::fold(ArrayRef<Attribute> operands) {
966   assert(operands.size() == 1 && "bitcast op expects 1 operand");
967 
968   auto resType = getType();
969   auto operand = operands[0];
970   if (!operand)
971     return {};
972 
973   /// Bitcast dense elements.
974   if (auto denseAttr = operand.dyn_cast_or_null<DenseElementsAttr>())
975     return denseAttr.bitcast(resType.cast<ShapedType>().getElementType());
976   /// Other shaped types unhandled.
977   if (resType.isa<ShapedType>())
978     return {};
979 
980   /// Bitcast integer or float to integer or float.
981   APInt bits = operand.isa<FloatAttr>()
982                    ? operand.cast<FloatAttr>().getValue().bitcastToAPInt()
983                    : operand.cast<IntegerAttr>().getValue();
984 
985   if (auto resFloatType = resType.dyn_cast<FloatType>())
986     return FloatAttr::get(resType,
987                           APFloat(resFloatType.getFloatSemantics(), bits));
988   return IntegerAttr::get(resType, bits);
989 }
990 
/// Registers the BitcastOfBitcast canonicalization pattern (defined outside
/// this section) for arith.bitcast.
void arith::BitcastOp::getCanonicalizationPatterns(
    OwningRewritePatternList &patterns, MLIRContext *context) {
  patterns.insert<BitcastOfBitcast>(context);
}
995 
996 //===----------------------------------------------------------------------===//
997 // Helpers for compare ops
998 //===----------------------------------------------------------------------===//
999 
1000 /// Return the type of the same shape (scalar, vector or tensor) containing i1.
1001 static Type getI1SameShape(Type type) {
1002   auto i1Type = IntegerType::get(type.getContext(), 1);
1003   if (auto tensorType = type.dyn_cast<RankedTensorType>())
1004     return RankedTensorType::get(tensorType.getShape(), i1Type);
1005   if (type.isa<UnrankedTensorType>())
1006     return UnrankedTensorType::get(i1Type);
1007   if (auto vectorType = type.dyn_cast<VectorType>())
1008     return VectorType::get(vectorType.getShape(), i1Type);
1009   return i1Type;
1010 }
1011 
1012 //===----------------------------------------------------------------------===//
1013 // CmpIOp
1014 //===----------------------------------------------------------------------===//
1015 
1016 /// Compute `lhs` `pred` `rhs`, where `pred` is one of the known integer
1017 /// comparison predicates.
1018 bool mlir::arith::applyCmpPredicate(arith::CmpIPredicate predicate,
1019                                     const APInt &lhs, const APInt &rhs) {
1020   switch (predicate) {
1021   case arith::CmpIPredicate::eq:
1022     return lhs.eq(rhs);
1023   case arith::CmpIPredicate::ne:
1024     return lhs.ne(rhs);
1025   case arith::CmpIPredicate::slt:
1026     return lhs.slt(rhs);
1027   case arith::CmpIPredicate::sle:
1028     return lhs.sle(rhs);
1029   case arith::CmpIPredicate::sgt:
1030     return lhs.sgt(rhs);
1031   case arith::CmpIPredicate::sge:
1032     return lhs.sge(rhs);
1033   case arith::CmpIPredicate::ult:
1034     return lhs.ult(rhs);
1035   case arith::CmpIPredicate::ule:
1036     return lhs.ule(rhs);
1037   case arith::CmpIPredicate::ugt:
1038     return lhs.ugt(rhs);
1039   case arith::CmpIPredicate::uge:
1040     return lhs.uge(rhs);
1041   }
1042   llvm_unreachable("unknown cmpi predicate kind");
1043 }
1044 
1045 /// Returns true if the predicate is true for two equal operands.
1046 static bool applyCmpPredicateToEqualOperands(arith::CmpIPredicate predicate) {
1047   switch (predicate) {
1048   case arith::CmpIPredicate::eq:
1049   case arith::CmpIPredicate::sle:
1050   case arith::CmpIPredicate::sge:
1051   case arith::CmpIPredicate::ule:
1052   case arith::CmpIPredicate::uge:
1053     return true;
1054   case arith::CmpIPredicate::ne:
1055   case arith::CmpIPredicate::slt:
1056   case arith::CmpIPredicate::sgt:
1057   case arith::CmpIPredicate::ult:
1058   case arith::CmpIPredicate::ugt:
1059     return false;
1060   }
1061   llvm_unreachable("unknown cmpi predicate kind");
1062 }
1063 
1064 OpFoldResult arith::CmpIOp::fold(ArrayRef<Attribute> operands) {
1065   assert(operands.size() == 2 && "cmpi takes two operands");
1066 
1067   // cmpi(pred, x, x)
1068   if (getLhs() == getRhs()) {
1069     auto val = applyCmpPredicateToEqualOperands(getPredicate());
1070     return BoolAttr::get(getContext(), val);
1071   }
1072 
1073   auto lhs = operands.front().dyn_cast_or_null<IntegerAttr>();
1074   auto rhs = operands.back().dyn_cast_or_null<IntegerAttr>();
1075   if (!lhs || !rhs)
1076     return {};
1077 
1078   auto val = applyCmpPredicate(getPredicate(), lhs.getValue(), rhs.getValue());
1079   return BoolAttr::get(getContext(), val);
1080 }
1081 
1082 //===----------------------------------------------------------------------===//
1083 // CmpFOp
1084 //===----------------------------------------------------------------------===//
1085 
1086 /// Compute `lhs` `pred` `rhs`, where `pred` is one of the known floating point
1087 /// comparison predicates.
1088 bool mlir::arith::applyCmpPredicate(arith::CmpFPredicate predicate,
1089                                     const APFloat &lhs, const APFloat &rhs) {
1090   auto cmpResult = lhs.compare(rhs);
1091   switch (predicate) {
1092   case arith::CmpFPredicate::AlwaysFalse:
1093     return false;
1094   case arith::CmpFPredicate::OEQ:
1095     return cmpResult == APFloat::cmpEqual;
1096   case arith::CmpFPredicate::OGT:
1097     return cmpResult == APFloat::cmpGreaterThan;
1098   case arith::CmpFPredicate::OGE:
1099     return cmpResult == APFloat::cmpGreaterThan ||
1100            cmpResult == APFloat::cmpEqual;
1101   case arith::CmpFPredicate::OLT:
1102     return cmpResult == APFloat::cmpLessThan;
1103   case arith::CmpFPredicate::OLE:
1104     return cmpResult == APFloat::cmpLessThan || cmpResult == APFloat::cmpEqual;
1105   case arith::CmpFPredicate::ONE:
1106     return cmpResult != APFloat::cmpUnordered && cmpResult != APFloat::cmpEqual;
1107   case arith::CmpFPredicate::ORD:
1108     return cmpResult != APFloat::cmpUnordered;
1109   case arith::CmpFPredicate::UEQ:
1110     return cmpResult == APFloat::cmpUnordered || cmpResult == APFloat::cmpEqual;
1111   case arith::CmpFPredicate::UGT:
1112     return cmpResult == APFloat::cmpUnordered ||
1113            cmpResult == APFloat::cmpGreaterThan;
1114   case arith::CmpFPredicate::UGE:
1115     return cmpResult == APFloat::cmpUnordered ||
1116            cmpResult == APFloat::cmpGreaterThan ||
1117            cmpResult == APFloat::cmpEqual;
1118   case arith::CmpFPredicate::ULT:
1119     return cmpResult == APFloat::cmpUnordered ||
1120            cmpResult == APFloat::cmpLessThan;
1121   case arith::CmpFPredicate::ULE:
1122     return cmpResult == APFloat::cmpUnordered ||
1123            cmpResult == APFloat::cmpLessThan || cmpResult == APFloat::cmpEqual;
1124   case arith::CmpFPredicate::UNE:
1125     return cmpResult != APFloat::cmpEqual;
1126   case arith::CmpFPredicate::UNO:
1127     return cmpResult == APFloat::cmpUnordered;
1128   case arith::CmpFPredicate::AlwaysTrue:
1129     return true;
1130   }
1131   llvm_unreachable("unknown cmpf predicate kind");
1132 }
1133 
1134 OpFoldResult arith::CmpFOp::fold(ArrayRef<Attribute> operands) {
1135   assert(operands.size() == 2 && "cmpf takes two operands");
1136 
1137   auto lhs = operands.front().dyn_cast_or_null<FloatAttr>();
1138   auto rhs = operands.back().dyn_cast_or_null<FloatAttr>();
1139 
1140   if (!lhs || !rhs)
1141     return {};
1142 
1143   auto val = applyCmpPredicate(getPredicate(), lhs.getValue(), rhs.getValue());
1144   return BoolAttr::get(getContext(), val);
1145 }
1146 
1147 //===----------------------------------------------------------------------===//
1148 // TableGen'd op method definitions
1149 //===----------------------------------------------------------------------===//
1150 
1151 #define GET_OP_CLASSES
1152 #include "mlir/Dialect/Arithmetic/IR/ArithmeticOps.cpp.inc"
1153 
1154 //===----------------------------------------------------------------------===//
1155 // TableGen'd enum attribute definitions
1156 //===----------------------------------------------------------------------===//
1157 
1158 #include "mlir/Dialect/Arithmetic/IR/ArithmeticOpsEnums.cpp.inc"
1159