1 //===- ArithmeticOps.cpp - MLIR Arithmetic dialect ops implementation -----===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 
9 #include "mlir/Dialect/Arithmetic/IR/Arithmetic.h"
10 #include "mlir/Dialect/CommonFolders.h"
11 #include "mlir/IR/Builders.h"
12 #include "mlir/IR/Matchers.h"
13 #include "mlir/IR/OpImplementation.h"
14 #include "mlir/IR/PatternMatch.h"
15 #include "mlir/IR/TypeUtilities.h"
16 
17 using namespace mlir;
18 using namespace mlir::arith;
19 
20 //===----------------------------------------------------------------------===//
21 // Pattern helpers
22 //===----------------------------------------------------------------------===//
23 
24 static IntegerAttr addIntegerAttrs(PatternRewriter &builder, Value res,
25                                    Attribute lhs, Attribute rhs) {
26   return builder.getIntegerAttr(res.getType(),
27                                 lhs.cast<IntegerAttr>().getInt() +
28                                     rhs.cast<IntegerAttr>().getInt());
29 }
30 
31 static IntegerAttr subIntegerAttrs(PatternRewriter &builder, Value res,
32                                    Attribute lhs, Attribute rhs) {
33   return builder.getIntegerAttr(res.getType(),
34                                 lhs.cast<IntegerAttr>().getInt() -
35                                     rhs.cast<IntegerAttr>().getInt());
36 }
37 
38 /// Invert an integer comparison predicate.
39 static arith::CmpIPredicate invertPredicate(arith::CmpIPredicate pred) {
40   switch (pred) {
41   case arith::CmpIPredicate::eq:
42     return arith::CmpIPredicate::ne;
43   case arith::CmpIPredicate::ne:
44     return arith::CmpIPredicate::eq;
45   case arith::CmpIPredicate::slt:
46     return arith::CmpIPredicate::sge;
47   case arith::CmpIPredicate::sle:
48     return arith::CmpIPredicate::sgt;
49   case arith::CmpIPredicate::sgt:
50     return arith::CmpIPredicate::sle;
51   case arith::CmpIPredicate::sge:
52     return arith::CmpIPredicate::slt;
53   case arith::CmpIPredicate::ult:
54     return arith::CmpIPredicate::uge;
55   case arith::CmpIPredicate::ule:
56     return arith::CmpIPredicate::ugt;
57   case arith::CmpIPredicate::ugt:
58     return arith::CmpIPredicate::ule;
59   case arith::CmpIPredicate::uge:
60     return arith::CmpIPredicate::ult;
61   }
62   llvm_unreachable("unknown cmpi predicate kind");
63 }
64 
65 static arith::CmpIPredicateAttr invertPredicate(arith::CmpIPredicateAttr pred) {
66   return arith::CmpIPredicateAttr::get(pred.getContext(),
67                                        invertPredicate(pred.getValue()));
68 }
69 
70 //===----------------------------------------------------------------------===//
71 // TableGen'd canonicalization patterns
72 //===----------------------------------------------------------------------===//
73 
// Pulls in the TableGen-generated canonicalization patterns (the classes such
// as AddIAddConstant, SubIRHSAddConstant, XOrINotCmpI, IndexCastOfIndexCast and
// BitcastOfBitcast that the getCanonicalizationPatterns hooks in this file
// register). The anonymous namespace keeps those generated classes local to
// this translation unit.
namespace {
#include "ArithmeticCanonicalization.inc"
} // end anonymous namespace
77 
78 //===----------------------------------------------------------------------===//
79 // AddIOp
80 //===----------------------------------------------------------------------===//
81 
82 OpFoldResult arith::AddIOp::fold(ArrayRef<Attribute> operands) {
83   // addi(x, 0) -> x
84   if (matchPattern(rhs(), m_Zero()))
85     return lhs();
86 
87   return constFoldBinaryOp<IntegerAttr>(operands,
88                                         [](APInt a, APInt b) { return a + b; });
89 }
90 
91 void arith::AddIOp::getCanonicalizationPatterns(
92     OwningRewritePatternList &patterns, MLIRContext *context) {
93   patterns.insert<AddIAddConstant, AddISubConstantRHS, AddISubConstantLHS>(
94       context);
95 }
96 
97 //===----------------------------------------------------------------------===//
98 // SubIOp
99 //===----------------------------------------------------------------------===//
100 
101 OpFoldResult arith::SubIOp::fold(ArrayRef<Attribute> operands) {
102   // subi(x,x) -> 0
103   if (getOperand(0) == getOperand(1))
104     return Builder(getContext()).getZeroAttr(getType());
105   // subi(x,0) -> x
106   if (matchPattern(rhs(), m_Zero()))
107     return lhs();
108 
109   return constFoldBinaryOp<IntegerAttr>(operands,
110                                         [](APInt a, APInt b) { return a - b; });
111 }
112 
113 void arith::SubIOp::getCanonicalizationPatterns(
114     OwningRewritePatternList &patterns, MLIRContext *context) {
115   patterns.insert<SubIRHSAddConstant, SubILHSAddConstant, SubIRHSSubConstantRHS,
116                   SubIRHSSubConstantLHS, SubILHSSubConstantRHS,
117                   SubILHSSubConstantLHS>(context);
118 }
119 
120 //===----------------------------------------------------------------------===//
121 // MulIOp
122 //===----------------------------------------------------------------------===//
123 
124 OpFoldResult arith::MulIOp::fold(ArrayRef<Attribute> operands) {
125   // muli(x, 0) -> 0
126   if (matchPattern(rhs(), m_Zero()))
127     return rhs();
128   // muli(x, 1) -> x
129   if (matchPattern(rhs(), m_One()))
130     return getOperand(0);
131   // TODO: Handle the overflow case.
132 
133   // default folder
134   return constFoldBinaryOp<IntegerAttr>(operands,
135                                         [](APInt a, APInt b) { return a * b; });
136 }
137 
138 //===----------------------------------------------------------------------===//
139 // DivUIOp
140 //===----------------------------------------------------------------------===//
141 
142 OpFoldResult arith::DivUIOp::fold(ArrayRef<Attribute> operands) {
143   // Don't fold if it would require a division by zero.
144   bool div0 = false;
145   auto result = constFoldBinaryOp<IntegerAttr>(operands, [&](APInt a, APInt b) {
146     if (div0 || !b) {
147       div0 = true;
148       return a;
149     }
150     return a.udiv(b);
151   });
152 
153   // Fold out division by one. Assumes all tensors of all ones are splats.
154   if (auto rhs = operands[1].dyn_cast_or_null<IntegerAttr>()) {
155     if (rhs.getValue() == 1)
156       return lhs();
157   } else if (auto rhs = operands[1].dyn_cast_or_null<SplatElementsAttr>()) {
158     if (rhs.getSplatValue<IntegerAttr>().getValue() == 1)
159       return lhs();
160   }
161 
162   return div0 ? Attribute() : result;
163 }
164 
165 //===----------------------------------------------------------------------===//
166 // DivSIOp
167 //===----------------------------------------------------------------------===//
168 
169 OpFoldResult arith::DivSIOp::fold(ArrayRef<Attribute> operands) {
170   // Don't fold if it would overflow or if it requires a division by zero.
171   bool overflowOrDiv0 = false;
172   auto result = constFoldBinaryOp<IntegerAttr>(operands, [&](APInt a, APInt b) {
173     if (overflowOrDiv0 || !b) {
174       overflowOrDiv0 = true;
175       return a;
176     }
177     return a.sdiv_ov(b, overflowOrDiv0);
178   });
179 
180   // Fold out division by one. Assumes all tensors of all ones are splats.
181   if (auto rhs = operands[1].dyn_cast_or_null<IntegerAttr>()) {
182     if (rhs.getValue() == 1)
183       return lhs();
184   } else if (auto rhs = operands[1].dyn_cast_or_null<SplatElementsAttr>()) {
185     if (rhs.getSplatValue<IntegerAttr>().getValue() == 1)
186       return lhs();
187   }
188 
189   return overflowOrDiv0 ? Attribute() : result;
190 }
191 
192 //===----------------------------------------------------------------------===//
193 // Ceil and floor division folding helpers
194 //===----------------------------------------------------------------------===//
195 
196 static APInt signedCeilNonnegInputs(APInt a, APInt b, bool &overflow) {
197   // Returns (a-1)/b + 1
198   APInt one(a.getBitWidth(), 1, true); // Signed value 1.
199   APInt val = a.ssub_ov(one, overflow).sdiv_ov(b, overflow);
200   return val.sadd_ov(one, overflow);
201 }
202 
203 //===----------------------------------------------------------------------===//
204 // CeilDivSIOp
205 //===----------------------------------------------------------------------===//
206 
207 OpFoldResult arith::CeilDivSIOp::fold(ArrayRef<Attribute> operands) {
208   // Don't fold if it would overflow or if it requires a division by zero.
209   bool overflowOrDiv0 = false;
210   auto result = constFoldBinaryOp<IntegerAttr>(operands, [&](APInt a, APInt b) {
211     if (overflowOrDiv0 || !b) {
212       overflowOrDiv0 = true;
213       return a;
214     }
215     unsigned bits = a.getBitWidth();
216     APInt zero = APInt::getZero(bits);
217     if (a.sgt(zero) && b.sgt(zero)) {
218       // Both positive, return ceil(a, b).
219       return signedCeilNonnegInputs(a, b, overflowOrDiv0);
220     }
221     if (a.slt(zero) && b.slt(zero)) {
222       // Both negative, return ceil(-a, -b).
223       APInt posA = zero.ssub_ov(a, overflowOrDiv0);
224       APInt posB = zero.ssub_ov(b, overflowOrDiv0);
225       return signedCeilNonnegInputs(posA, posB, overflowOrDiv0);
226     }
227     if (a.slt(zero) && b.sgt(zero)) {
228       // A is negative, b is positive, return - ( -a / b).
229       APInt posA = zero.ssub_ov(a, overflowOrDiv0);
230       APInt div = posA.sdiv_ov(b, overflowOrDiv0);
231       return zero.ssub_ov(div, overflowOrDiv0);
232     }
233     // A is positive (or zero), b is negative, return - (a / -b).
234     APInt posB = zero.ssub_ov(b, overflowOrDiv0);
235     APInt div = a.sdiv_ov(posB, overflowOrDiv0);
236     return zero.ssub_ov(div, overflowOrDiv0);
237   });
238 
239   // Fold out floor division by one. Assumes all tensors of all ones are
240   // splats.
241   if (auto rhs = operands[1].dyn_cast_or_null<IntegerAttr>()) {
242     if (rhs.getValue() == 1)
243       return lhs();
244   } else if (auto rhs = operands[1].dyn_cast_or_null<SplatElementsAttr>()) {
245     if (rhs.getSplatValue<IntegerAttr>().getValue() == 1)
246       return lhs();
247   }
248 
249   return overflowOrDiv0 ? Attribute() : result;
250 }
251 
252 //===----------------------------------------------------------------------===//
253 // FloorDivSIOp
254 //===----------------------------------------------------------------------===//
255 
256 OpFoldResult arith::FloorDivSIOp::fold(ArrayRef<Attribute> operands) {
257   // Don't fold if it would overflow or if it requires a division by zero.
258   bool overflowOrDiv0 = false;
259   auto result = constFoldBinaryOp<IntegerAttr>(operands, [&](APInt a, APInt b) {
260     if (overflowOrDiv0 || !b) {
261       overflowOrDiv0 = true;
262       return a;
263     }
264     unsigned bits = a.getBitWidth();
265     APInt zero = APInt::getZero(bits);
266     if (a.sge(zero) && b.sgt(zero)) {
267       // Both positive (or a is zero), return a / b.
268       return a.sdiv_ov(b, overflowOrDiv0);
269     }
270     if (a.sle(zero) && b.slt(zero)) {
271       // Both negative (or a is zero), return -a / -b.
272       APInt posA = zero.ssub_ov(a, overflowOrDiv0);
273       APInt posB = zero.ssub_ov(b, overflowOrDiv0);
274       return posA.sdiv_ov(posB, overflowOrDiv0);
275     }
276     if (a.slt(zero) && b.sgt(zero)) {
277       // A is negative, b is positive, return - ceil(-a, b).
278       APInt posA = zero.ssub_ov(a, overflowOrDiv0);
279       APInt ceil = signedCeilNonnegInputs(posA, b, overflowOrDiv0);
280       return zero.ssub_ov(ceil, overflowOrDiv0);
281     }
282     // A is positive, b is negative, return - ceil(a, -b).
283     APInt posB = zero.ssub_ov(b, overflowOrDiv0);
284     APInt ceil = signedCeilNonnegInputs(a, posB, overflowOrDiv0);
285     return zero.ssub_ov(ceil, overflowOrDiv0);
286   });
287 
288   // Fold out floor division by one. Assumes all tensors of all ones are
289   // splats.
290   if (auto rhs = operands[1].dyn_cast_or_null<IntegerAttr>()) {
291     if (rhs.getValue() == 1)
292       return lhs();
293   } else if (auto rhs = operands[1].dyn_cast_or_null<SplatElementsAttr>()) {
294     if (rhs.getSplatValue<IntegerAttr>().getValue() == 1)
295       return lhs();
296   }
297 
298   return overflowOrDiv0 ? Attribute() : result;
299 }
300 
301 //===----------------------------------------------------------------------===//
302 // RemUIOp
303 //===----------------------------------------------------------------------===//
304 
305 OpFoldResult arith::RemUIOp::fold(ArrayRef<Attribute> operands) {
306   auto rhs = operands.back().dyn_cast_or_null<IntegerAttr>();
307   if (!rhs)
308     return {};
309   auto rhsValue = rhs.getValue();
310 
311   // x % 1 = 0
312   if (rhsValue.isOneValue())
313     return IntegerAttr::get(rhs.getType(), APInt(rhsValue.getBitWidth(), 0));
314 
315   // Don't fold if it requires division by zero.
316   if (rhsValue.isNullValue())
317     return {};
318 
319   auto lhs = operands.front().dyn_cast_or_null<IntegerAttr>();
320   if (!lhs)
321     return {};
322   return IntegerAttr::get(lhs.getType(), lhs.getValue().urem(rhsValue));
323 }
324 
325 //===----------------------------------------------------------------------===//
326 // RemSIOp
327 //===----------------------------------------------------------------------===//
328 
329 OpFoldResult arith::RemSIOp::fold(ArrayRef<Attribute> operands) {
330   auto rhs = operands.back().dyn_cast_or_null<IntegerAttr>();
331   if (!rhs)
332     return {};
333   auto rhsValue = rhs.getValue();
334 
335   // x % 1 = 0
336   if (rhsValue.isOneValue())
337     return IntegerAttr::get(rhs.getType(), APInt(rhsValue.getBitWidth(), 0));
338 
339   // Don't fold if it requires division by zero.
340   if (rhsValue.isNullValue())
341     return {};
342 
343   auto lhs = operands.front().dyn_cast_or_null<IntegerAttr>();
344   if (!lhs)
345     return {};
346   return IntegerAttr::get(lhs.getType(), lhs.getValue().srem(rhsValue));
347 }
348 
349 //===----------------------------------------------------------------------===//
350 // AndIOp
351 //===----------------------------------------------------------------------===//
352 
353 OpFoldResult arith::AndIOp::fold(ArrayRef<Attribute> operands) {
354   /// and(x, 0) -> 0
355   if (matchPattern(rhs(), m_Zero()))
356     return rhs();
357   /// and(x, allOnes) -> x
358   APInt intValue;
359   if (matchPattern(rhs(), m_ConstantInt(&intValue)) && intValue.isAllOnes())
360     return lhs();
361   /// and(x, x) -> x
362   if (lhs() == rhs())
363     return rhs();
364 
365   return constFoldBinaryOp<IntegerAttr>(operands,
366                                         [](APInt a, APInt b) { return a & b; });
367 }
368 
369 //===----------------------------------------------------------------------===//
370 // OrIOp
371 //===----------------------------------------------------------------------===//
372 
373 OpFoldResult arith::OrIOp::fold(ArrayRef<Attribute> operands) {
374   /// or(x, 0) -> x
375   if (matchPattern(rhs(), m_Zero()))
376     return lhs();
377   /// or(x, x) -> x
378   if (lhs() == rhs())
379     return rhs();
380 
381   return constFoldBinaryOp<IntegerAttr>(operands,
382                                         [](APInt a, APInt b) { return a | b; });
383 }
384 
385 //===----------------------------------------------------------------------===//
386 // XOrIOp
387 //===----------------------------------------------------------------------===//
388 
389 OpFoldResult arith::XOrIOp::fold(ArrayRef<Attribute> operands) {
390   /// xor(x, 0) -> x
391   if (matchPattern(rhs(), m_Zero()))
392     return lhs();
393   /// xor(x, x) -> 0
394   if (lhs() == rhs())
395     return Builder(getContext()).getZeroAttr(getType());
396 
397   return constFoldBinaryOp<IntegerAttr>(operands,
398                                         [](APInt a, APInt b) { return a ^ b; });
399 }
400 
/// Registers the TableGen'd `xori` canonicalization pattern XOrINotCmpI
/// (generated into ArithmeticCanonicalization.inc).
void arith::XOrIOp::getCanonicalizationPatterns(
    OwningRewritePatternList &patterns, MLIRContext *context) {
  patterns.insert<XOrINotCmpI>(context);
}
405 
406 //===----------------------------------------------------------------------===//
407 // AddFOp
408 //===----------------------------------------------------------------------===//
409 
410 OpFoldResult arith::AddFOp::fold(ArrayRef<Attribute> operands) {
411   return constFoldBinaryOp<FloatAttr>(
412       operands, [](APFloat a, APFloat b) { return a + b; });
413 }
414 
415 //===----------------------------------------------------------------------===//
416 // SubFOp
417 //===----------------------------------------------------------------------===//
418 
419 OpFoldResult arith::SubFOp::fold(ArrayRef<Attribute> operands) {
420   return constFoldBinaryOp<FloatAttr>(
421       operands, [](APFloat a, APFloat b) { return a - b; });
422 }
423 
424 //===----------------------------------------------------------------------===//
425 // MulFOp
426 //===----------------------------------------------------------------------===//
427 
428 OpFoldResult arith::MulFOp::fold(ArrayRef<Attribute> operands) {
429   return constFoldBinaryOp<FloatAttr>(
430       operands, [](APFloat a, APFloat b) { return a * b; });
431 }
432 
433 //===----------------------------------------------------------------------===//
434 // DivFOp
435 //===----------------------------------------------------------------------===//
436 
437 OpFoldResult arith::DivFOp::fold(ArrayRef<Attribute> operands) {
438   return constFoldBinaryOp<FloatAttr>(
439       operands, [](APFloat a, APFloat b) { return a / b; });
440 }
441 
442 //===----------------------------------------------------------------------===//
443 // Verifiers for integer and floating point extension/truncation ops
444 //===----------------------------------------------------------------------===//
445 
446 // Extend ops can only extend to a wider type.
447 template <typename ValType, typename Op>
448 static LogicalResult verifyExtOp(Op op) {
449   Type srcType = getElementTypeOrSelf(op.in().getType());
450   Type dstType = getElementTypeOrSelf(op.getType());
451 
452   if (srcType.cast<ValType>().getWidth() >= dstType.cast<ValType>().getWidth())
453     return op.emitError("result type ")
454            << dstType << " must be wider than operand type " << srcType;
455 
456   return success();
457 }
458 
459 // Truncate ops can only truncate to a shorter type.
460 template <typename ValType, typename Op>
461 static LogicalResult verifyTruncateOp(Op op) {
462   Type srcType = getElementTypeOrSelf(op.in().getType());
463   Type dstType = getElementTypeOrSelf(op.getType());
464 
465   if (srcType.cast<ValType>().getWidth() <= dstType.cast<ValType>().getWidth())
466     return op.emitError("result type ")
467            << dstType << " must be shorter than operand type " << srcType;
468 
469   return success();
470 }
471 
472 //===----------------------------------------------------------------------===//
473 // ExtUIOp
474 //===----------------------------------------------------------------------===//
475 
476 OpFoldResult arith::ExtUIOp::fold(ArrayRef<Attribute> operands) {
477   if (auto lhs = operands[0].dyn_cast_or_null<IntegerAttr>())
478     return IntegerAttr::get(
479         getType(), lhs.getValue().zext(getType().getIntOrFloatBitWidth()));
480 
481   return {};
482 }
483 
484 //===----------------------------------------------------------------------===//
485 // ExtSIOp
486 //===----------------------------------------------------------------------===//
487 
488 OpFoldResult arith::ExtSIOp::fold(ArrayRef<Attribute> operands) {
489   if (auto lhs = operands[0].dyn_cast_or_null<IntegerAttr>())
490     return IntegerAttr::get(
491         getType(), lhs.getValue().sext(getType().getIntOrFloatBitWidth()));
492 
493   return {};
494 }
495 
496 //===----------------------------------------------------------------------===//
497 // IndexCastOp
498 //===----------------------------------------------------------------------===//
499 
500 bool arith::IndexCastOp::areCastCompatible(TypeRange inputs,
501                                            TypeRange outputs) {
502   assert(inputs.size() == 1 && outputs.size() == 1 &&
503          "index_cast op expects one result and one result");
504 
505   // Shape equivalence is guaranteed by op traits.
506   auto srcType = getElementTypeOrSelf(inputs.front());
507   auto dstType = getElementTypeOrSelf(outputs.front());
508 
509   return (srcType.isIndex() && dstType.isSignlessInteger()) ||
510          (srcType.isSignlessInteger() && dstType.isIndex());
511 }
512 
513 OpFoldResult arith::IndexCastOp::fold(ArrayRef<Attribute> operands) {
514   // index_cast(constant) -> constant
515   // A little hack because we go through int. Otherwise, the size of the
516   // constant might need to change.
517   if (auto value = operands[0].dyn_cast_or_null<IntegerAttr>())
518     return IntegerAttr::get(getType(), value.getInt());
519 
520   return {};
521 }
522 
/// Registers the TableGen'd index_cast canonicalization patterns
/// IndexCastOfIndexCast and IndexCastOfExtSI (generated into
/// ArithmeticCanonicalization.inc).
void arith::IndexCastOp::getCanonicalizationPatterns(
    OwningRewritePatternList &patterns, MLIRContext *context) {
  patterns.insert<IndexCastOfIndexCast, IndexCastOfExtSI>(context);
}
527 
528 //===----------------------------------------------------------------------===//
529 // BitcastOp
530 //===----------------------------------------------------------------------===//
531 
532 bool arith::BitcastOp::areCastCompatible(TypeRange inputs, TypeRange outputs) {
533   assert(inputs.size() == 1 && outputs.size() == 1 &&
534          "bitcast op expects one operand and one result");
535 
536   // Shape equivalence is guaranteed by op traits.
537   auto srcType = getElementTypeOrSelf(inputs.front());
538   auto dstType = getElementTypeOrSelf(outputs.front());
539 
540   // Types are guarnateed to be integers or floats by constraints.
541   return srcType.getIntOrFloatBitWidth() == dstType.getIntOrFloatBitWidth();
542 }
543 
544 OpFoldResult arith::BitcastOp::fold(ArrayRef<Attribute> operands) {
545   assert(operands.size() == 1 && "bitcast op expects 1 operand");
546 
547   auto resType = getType();
548   auto operand = operands[0];
549   if (!operand)
550     return {};
551 
552   /// Bitcast dense elements.
553   if (auto denseAttr = operand.dyn_cast_or_null<DenseElementsAttr>())
554     return denseAttr.bitcast(resType.cast<ShapedType>().getElementType());
555   /// Other shaped types unhandled.
556   if (resType.isa<ShapedType>())
557     return {};
558 
559   /// Bitcast integer or float to integer or float.
560   APInt bits = operand.isa<FloatAttr>()
561                    ? operand.cast<FloatAttr>().getValue().bitcastToAPInt()
562                    : operand.cast<IntegerAttr>().getValue();
563 
564   if (auto resFloatType = resType.dyn_cast<FloatType>())
565     return FloatAttr::get(resType,
566                           APFloat(resFloatType.getFloatSemantics(), bits));
567   return IntegerAttr::get(resType, bits);
568 }
569 
/// Registers the TableGen'd bitcast canonicalization pattern BitcastOfBitcast
/// (generated into ArithmeticCanonicalization.inc).
void arith::BitcastOp::getCanonicalizationPatterns(
    OwningRewritePatternList &patterns, MLIRContext *context) {
  patterns.insert<BitcastOfBitcast>(context);
}
574 
575 //===----------------------------------------------------------------------===//
576 // Helpers for compare ops
577 //===----------------------------------------------------------------------===//
578 
579 /// Return the type of the same shape (scalar, vector or tensor) containing i1.
580 static Type getI1SameShape(Type type) {
581   auto i1Type = IntegerType::get(type.getContext(), 1);
582   if (auto tensorType = type.dyn_cast<RankedTensorType>())
583     return RankedTensorType::get(tensorType.getShape(), i1Type);
584   if (type.isa<UnrankedTensorType>())
585     return UnrankedTensorType::get(i1Type);
586   if (auto vectorType = type.dyn_cast<VectorType>())
587     return VectorType::get(vectorType.getShape(), i1Type);
588   return i1Type;
589 }
590 
591 //===----------------------------------------------------------------------===//
592 // CmpIOp
593 //===----------------------------------------------------------------------===//
594 
595 /// Compute `lhs` `pred` `rhs`, where `pred` is one of the known integer
596 /// comparison predicates.
597 bool mlir::arith::applyCmpPredicate(arith::CmpIPredicate predicate,
598                                     const APInt &lhs, const APInt &rhs) {
599   switch (predicate) {
600   case arith::CmpIPredicate::eq:
601     return lhs.eq(rhs);
602   case arith::CmpIPredicate::ne:
603     return lhs.ne(rhs);
604   case arith::CmpIPredicate::slt:
605     return lhs.slt(rhs);
606   case arith::CmpIPredicate::sle:
607     return lhs.sle(rhs);
608   case arith::CmpIPredicate::sgt:
609     return lhs.sgt(rhs);
610   case arith::CmpIPredicate::sge:
611     return lhs.sge(rhs);
612   case arith::CmpIPredicate::ult:
613     return lhs.ult(rhs);
614   case arith::CmpIPredicate::ule:
615     return lhs.ule(rhs);
616   case arith::CmpIPredicate::ugt:
617     return lhs.ugt(rhs);
618   case arith::CmpIPredicate::uge:
619     return lhs.uge(rhs);
620   }
621   llvm_unreachable("unknown cmpi predicate kind");
622 }
623 
624 /// Returns true if the predicate is true for two equal operands.
625 static bool applyCmpPredicateToEqualOperands(arith::CmpIPredicate predicate) {
626   switch (predicate) {
627   case arith::CmpIPredicate::eq:
628   case arith::CmpIPredicate::sle:
629   case arith::CmpIPredicate::sge:
630   case arith::CmpIPredicate::ule:
631   case arith::CmpIPredicate::uge:
632     return true;
633   case arith::CmpIPredicate::ne:
634   case arith::CmpIPredicate::slt:
635   case arith::CmpIPredicate::sgt:
636   case arith::CmpIPredicate::ult:
637   case arith::CmpIPredicate::ugt:
638     return false;
639   }
640   llvm_unreachable("unknown cmpi predicate kind");
641 }
642 
643 OpFoldResult arith::CmpIOp::fold(ArrayRef<Attribute> operands) {
644   assert(operands.size() == 2 && "cmpi takes two operands");
645 
646   // cmpi(pred, x, x)
647   if (lhs() == rhs()) {
648     auto val = applyCmpPredicateToEqualOperands(getPredicate());
649     return BoolAttr::get(getContext(), val);
650   }
651 
652   auto lhs = operands.front().dyn_cast_or_null<IntegerAttr>();
653   auto rhs = operands.back().dyn_cast_or_null<IntegerAttr>();
654   if (!lhs || !rhs)
655     return {};
656 
657   auto val = applyCmpPredicate(getPredicate(), lhs.getValue(), rhs.getValue());
658   return BoolAttr::get(getContext(), val);
659 }
660 
661 //===----------------------------------------------------------------------===//
662 // CmpFOp
663 //===----------------------------------------------------------------------===//
664 
665 /// Compute `lhs` `pred` `rhs`, where `pred` is one of the known floating point
666 /// comparison predicates.
667 bool mlir::arith::applyCmpPredicate(arith::CmpFPredicate predicate,
668                                     const APFloat &lhs, const APFloat &rhs) {
669   auto cmpResult = lhs.compare(rhs);
670   switch (predicate) {
671   case arith::CmpFPredicate::AlwaysFalse:
672     return false;
673   case arith::CmpFPredicate::OEQ:
674     return cmpResult == APFloat::cmpEqual;
675   case arith::CmpFPredicate::OGT:
676     return cmpResult == APFloat::cmpGreaterThan;
677   case arith::CmpFPredicate::OGE:
678     return cmpResult == APFloat::cmpGreaterThan ||
679            cmpResult == APFloat::cmpEqual;
680   case arith::CmpFPredicate::OLT:
681     return cmpResult == APFloat::cmpLessThan;
682   case arith::CmpFPredicate::OLE:
683     return cmpResult == APFloat::cmpLessThan || cmpResult == APFloat::cmpEqual;
684   case arith::CmpFPredicate::ONE:
685     return cmpResult != APFloat::cmpUnordered && cmpResult != APFloat::cmpEqual;
686   case arith::CmpFPredicate::ORD:
687     return cmpResult != APFloat::cmpUnordered;
688   case arith::CmpFPredicate::UEQ:
689     return cmpResult == APFloat::cmpUnordered || cmpResult == APFloat::cmpEqual;
690   case arith::CmpFPredicate::UGT:
691     return cmpResult == APFloat::cmpUnordered ||
692            cmpResult == APFloat::cmpGreaterThan;
693   case arith::CmpFPredicate::UGE:
694     return cmpResult == APFloat::cmpUnordered ||
695            cmpResult == APFloat::cmpGreaterThan ||
696            cmpResult == APFloat::cmpEqual;
697   case arith::CmpFPredicate::ULT:
698     return cmpResult == APFloat::cmpUnordered ||
699            cmpResult == APFloat::cmpLessThan;
700   case arith::CmpFPredicate::ULE:
701     return cmpResult == APFloat::cmpUnordered ||
702            cmpResult == APFloat::cmpLessThan || cmpResult == APFloat::cmpEqual;
703   case arith::CmpFPredicate::UNE:
704     return cmpResult != APFloat::cmpEqual;
705   case arith::CmpFPredicate::UNO:
706     return cmpResult == APFloat::cmpUnordered;
707   case arith::CmpFPredicate::AlwaysTrue:
708     return true;
709   }
710   llvm_unreachable("unknown cmpf predicate kind");
711 }
712 
713 OpFoldResult arith::CmpFOp::fold(ArrayRef<Attribute> operands) {
714   assert(operands.size() == 2 && "cmpf takes two operands");
715 
716   auto lhs = operands.front().dyn_cast_or_null<FloatAttr>();
717   auto rhs = operands.back().dyn_cast_or_null<FloatAttr>();
718 
719   if (!lhs || !rhs)
720     return {};
721 
722   auto val = applyCmpPredicate(getPredicate(), lhs.getValue(), rhs.getValue());
723   return BoolAttr::get(getContext(), val);
724 }
725 
726 //===----------------------------------------------------------------------===//
727 // TableGen'd op method definitions
728 //===----------------------------------------------------------------------===//
729 
730 #define GET_OP_CLASSES
731 #include "mlir/Dialect/Arithmetic/IR/ArithmeticOps.cpp.inc"
732 
733 //===----------------------------------------------------------------------===//
734 // TableGen'd enum attribute definitions
735 //===----------------------------------------------------------------------===//
736 
737 #include "mlir/Dialect/Arithmetic/IR/ArithmeticOpsEnums.cpp.inc"
738