1 //===- ArithmeticOps.cpp - MLIR Arithmetic dialect ops implementation -----===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 
9 #include "mlir/Dialect/Arithmetic/IR/Arithmetic.h"
10 #include "mlir/Dialect/CommonFolders.h"
11 #include "mlir/IR/Builders.h"
12 #include "mlir/IR/Matchers.h"
13 #include "mlir/IR/OpImplementation.h"
14 #include "mlir/IR/PatternMatch.h"
15 #include "mlir/IR/TypeUtilities.h"
16 
17 using namespace mlir;
18 using namespace mlir::arith;
19 
20 //===----------------------------------------------------------------------===//
21 // Pattern helpers
22 //===----------------------------------------------------------------------===//
23 
24 static IntegerAttr addIntegerAttrs(PatternRewriter &builder, Value res,
25                                    Attribute lhs, Attribute rhs) {
26   return builder.getIntegerAttr(res.getType(),
27                                 lhs.cast<IntegerAttr>().getInt() +
28                                     rhs.cast<IntegerAttr>().getInt());
29 }
30 
31 static IntegerAttr subIntegerAttrs(PatternRewriter &builder, Value res,
32                                    Attribute lhs, Attribute rhs) {
33   return builder.getIntegerAttr(res.getType(),
34                                 lhs.cast<IntegerAttr>().getInt() -
35                                     rhs.cast<IntegerAttr>().getInt());
36 }
37 
38 /// Invert an integer comparison predicate.
39 static arith::CmpIPredicate invertPredicate(arith::CmpIPredicate pred) {
40   switch (pred) {
41   case arith::CmpIPredicate::eq:
42     return arith::CmpIPredicate::ne;
43   case arith::CmpIPredicate::ne:
44     return arith::CmpIPredicate::eq;
45   case arith::CmpIPredicate::slt:
46     return arith::CmpIPredicate::sge;
47   case arith::CmpIPredicate::sle:
48     return arith::CmpIPredicate::sgt;
49   case arith::CmpIPredicate::sgt:
50     return arith::CmpIPredicate::sle;
51   case arith::CmpIPredicate::sge:
52     return arith::CmpIPredicate::slt;
53   case arith::CmpIPredicate::ult:
54     return arith::CmpIPredicate::uge;
55   case arith::CmpIPredicate::ule:
56     return arith::CmpIPredicate::ugt;
57   case arith::CmpIPredicate::ugt:
58     return arith::CmpIPredicate::ule;
59   case arith::CmpIPredicate::uge:
60     return arith::CmpIPredicate::ult;
61   }
62   llvm_unreachable("unknown cmpi predicate kind");
63 }
64 
65 static arith::CmpIPredicateAttr invertPredicate(arith::CmpIPredicateAttr pred) {
66   return arith::CmpIPredicateAttr::get(pred.getContext(),
67                                        invertPredicate(pred.getValue()));
68 }
69 
70 //===----------------------------------------------------------------------===//
71 // TableGen'd canonicalization patterns
72 //===----------------------------------------------------------------------===//
73 
74 namespace {
75 #include "ArithmeticCanonicalization.inc"
76 } // end anonymous namespace
77 
78 //===----------------------------------------------------------------------===//
79 // AddIOp
80 //===----------------------------------------------------------------------===//
81 
82 OpFoldResult arith::AddIOp::fold(ArrayRef<Attribute> operands) {
83   // addi(x, 0) -> x
84   if (matchPattern(rhs(), m_Zero()))
85     return lhs();
86 
87   return constFoldBinaryOp<IntegerAttr>(operands,
88                                         [](APInt a, APInt b) { return a + b; });
89 }
90 
91 void arith::AddIOp::getCanonicalizationPatterns(
92     OwningRewritePatternList &patterns, MLIRContext *context) {
93   patterns.insert<AddIAddConstant, AddISubConstantRHS, AddISubConstantLHS>(
94       context);
95 }
96 
97 //===----------------------------------------------------------------------===//
98 // SubIOp
99 //===----------------------------------------------------------------------===//
100 
101 OpFoldResult arith::SubIOp::fold(ArrayRef<Attribute> operands) {
102   // subi(x,x) -> 0
103   if (getOperand(0) == getOperand(1))
104     return Builder(getContext()).getZeroAttr(getType());
105   // subi(x,0) -> x
106   if (matchPattern(rhs(), m_Zero()))
107     return lhs();
108 
109   return constFoldBinaryOp<IntegerAttr>(operands,
110                                         [](APInt a, APInt b) { return a - b; });
111 }
112 
113 void arith::SubIOp::getCanonicalizationPatterns(
114     OwningRewritePatternList &patterns, MLIRContext *context) {
115   patterns.insert<SubIRHSAddConstant, SubILHSAddConstant, SubIRHSSubConstantRHS,
116                   SubIRHSSubConstantLHS, SubILHSSubConstantRHS,
117                   SubILHSSubConstantLHS>(context);
118 }
119 
120 //===----------------------------------------------------------------------===//
121 // MulIOp
122 //===----------------------------------------------------------------------===//
123 
124 OpFoldResult arith::MulIOp::fold(ArrayRef<Attribute> operands) {
125   // muli(x, 0) -> 0
126   if (matchPattern(rhs(), m_Zero()))
127     return rhs();
128   // muli(x, 1) -> x
129   if (matchPattern(rhs(), m_One()))
130     return getOperand(0);
131   // TODO: Handle the overflow case.
132 
133   // default folder
134   return constFoldBinaryOp<IntegerAttr>(operands,
135                                         [](APInt a, APInt b) { return a * b; });
136 }
137 
138 //===----------------------------------------------------------------------===//
139 // DivUIOp
140 //===----------------------------------------------------------------------===//
141 
142 OpFoldResult arith::DivUIOp::fold(ArrayRef<Attribute> operands) {
143   // Don't fold if it would require a division by zero.
144   bool div0 = false;
145   auto result = constFoldBinaryOp<IntegerAttr>(operands, [&](APInt a, APInt b) {
146     if (div0 || !b) {
147       div0 = true;
148       return a;
149     }
150     return a.udiv(b);
151   });
152 
153   // Fold out division by one. Assumes all tensors of all ones are splats.
154   if (auto rhs = operands[1].dyn_cast_or_null<IntegerAttr>()) {
155     if (rhs.getValue() == 1)
156       return lhs();
157   } else if (auto rhs = operands[1].dyn_cast_or_null<SplatElementsAttr>()) {
158     if (rhs.getSplatValue<IntegerAttr>().getValue() == 1)
159       return lhs();
160   }
161 
162   return div0 ? Attribute() : result;
163 }
164 
165 //===----------------------------------------------------------------------===//
166 // DivSIOp
167 //===----------------------------------------------------------------------===//
168 
169 OpFoldResult arith::DivSIOp::fold(ArrayRef<Attribute> operands) {
170   // Don't fold if it would overflow or if it requires a division by zero.
171   bool overflowOrDiv0 = false;
172   auto result = constFoldBinaryOp<IntegerAttr>(operands, [&](APInt a, APInt b) {
173     if (overflowOrDiv0 || !b) {
174       overflowOrDiv0 = true;
175       return a;
176     }
177     return a.sdiv_ov(b, overflowOrDiv0);
178   });
179 
180   // Fold out division by one. Assumes all tensors of all ones are splats.
181   if (auto rhs = operands[1].dyn_cast_or_null<IntegerAttr>()) {
182     if (rhs.getValue() == 1)
183       return lhs();
184   } else if (auto rhs = operands[1].dyn_cast_or_null<SplatElementsAttr>()) {
185     if (rhs.getSplatValue<IntegerAttr>().getValue() == 1)
186       return lhs();
187   }
188 
189   return overflowOrDiv0 ? Attribute() : result;
190 }
191 
192 //===----------------------------------------------------------------------===//
193 // Ceil and floor division folding helpers
194 //===----------------------------------------------------------------------===//
195 
196 static APInt signedCeilNonnegInputs(APInt a, APInt b, bool &overflow) {
197   // Returns (a-1)/b + 1
198   APInt one(a.getBitWidth(), 1, true); // Signed value 1.
199   APInt val = a.ssub_ov(one, overflow).sdiv_ov(b, overflow);
200   return val.sadd_ov(one, overflow);
201 }
202 
203 //===----------------------------------------------------------------------===//
204 // CeilDivSIOp
205 //===----------------------------------------------------------------------===//
206 
207 OpFoldResult arith::CeilDivSIOp::fold(ArrayRef<Attribute> operands) {
208   // Don't fold if it would overflow or if it requires a division by zero.
209   bool overflowOrDiv0 = false;
210   auto result = constFoldBinaryOp<IntegerAttr>(operands, [&](APInt a, APInt b) {
211     if (overflowOrDiv0 || !b) {
212       overflowOrDiv0 = true;
213       return a;
214     }
215     unsigned bits = a.getBitWidth();
216     APInt zero = APInt::getZero(bits);
217     if (a.sgt(zero) && b.sgt(zero)) {
218       // Both positive, return ceil(a, b).
219       return signedCeilNonnegInputs(a, b, overflowOrDiv0);
220     }
221     if (a.slt(zero) && b.slt(zero)) {
222       // Both negative, return ceil(-a, -b).
223       APInt posA = zero.ssub_ov(a, overflowOrDiv0);
224       APInt posB = zero.ssub_ov(b, overflowOrDiv0);
225       return signedCeilNonnegInputs(posA, posB, overflowOrDiv0);
226     }
227     if (a.slt(zero) && b.sgt(zero)) {
228       // A is negative, b is positive, return - ( -a / b).
229       APInt posA = zero.ssub_ov(a, overflowOrDiv0);
230       APInt div = posA.sdiv_ov(b, overflowOrDiv0);
231       return zero.ssub_ov(div, overflowOrDiv0);
232     }
233     // A is positive (or zero), b is negative, return - (a / -b).
234     APInt posB = zero.ssub_ov(b, overflowOrDiv0);
235     APInt div = a.sdiv_ov(posB, overflowOrDiv0);
236     return zero.ssub_ov(div, overflowOrDiv0);
237   });
238 
239   // Fold out floor division by one. Assumes all tensors of all ones are
240   // splats.
241   if (auto rhs = operands[1].dyn_cast_or_null<IntegerAttr>()) {
242     if (rhs.getValue() == 1)
243       return lhs();
244   } else if (auto rhs = operands[1].dyn_cast_or_null<SplatElementsAttr>()) {
245     if (rhs.getSplatValue<IntegerAttr>().getValue() == 1)
246       return lhs();
247   }
248 
249   return overflowOrDiv0 ? Attribute() : result;
250 }
251 
252 //===----------------------------------------------------------------------===//
253 // FloorDivSIOp
254 //===----------------------------------------------------------------------===//
255 
256 OpFoldResult arith::FloorDivSIOp::fold(ArrayRef<Attribute> operands) {
257   // Don't fold if it would overflow or if it requires a division by zero.
258   bool overflowOrDiv0 = false;
259   auto result = constFoldBinaryOp<IntegerAttr>(operands, [&](APInt a, APInt b) {
260     if (overflowOrDiv0 || !b) {
261       overflowOrDiv0 = true;
262       return a;
263     }
264     unsigned bits = a.getBitWidth();
265     APInt zero = APInt::getZero(bits);
266     if (a.sge(zero) && b.sgt(zero)) {
267       // Both positive (or a is zero), return a / b.
268       return a.sdiv_ov(b, overflowOrDiv0);
269     }
270     if (a.sle(zero) && b.slt(zero)) {
271       // Both negative (or a is zero), return -a / -b.
272       APInt posA = zero.ssub_ov(a, overflowOrDiv0);
273       APInt posB = zero.ssub_ov(b, overflowOrDiv0);
274       return posA.sdiv_ov(posB, overflowOrDiv0);
275     }
276     if (a.slt(zero) && b.sgt(zero)) {
277       // A is negative, b is positive, return - ceil(-a, b).
278       APInt posA = zero.ssub_ov(a, overflowOrDiv0);
279       APInt ceil = signedCeilNonnegInputs(posA, b, overflowOrDiv0);
280       return zero.ssub_ov(ceil, overflowOrDiv0);
281     }
282     // A is positive, b is negative, return - ceil(a, -b).
283     APInt posB = zero.ssub_ov(b, overflowOrDiv0);
284     APInt ceil = signedCeilNonnegInputs(a, posB, overflowOrDiv0);
285     return zero.ssub_ov(ceil, overflowOrDiv0);
286   });
287 
288   // Fold out floor division by one. Assumes all tensors of all ones are
289   // splats.
290   if (auto rhs = operands[1].dyn_cast_or_null<IntegerAttr>()) {
291     if (rhs.getValue() == 1)
292       return lhs();
293   } else if (auto rhs = operands[1].dyn_cast_or_null<SplatElementsAttr>()) {
294     if (rhs.getSplatValue<IntegerAttr>().getValue() == 1)
295       return lhs();
296   }
297 
298   return overflowOrDiv0 ? Attribute() : result;
299 }
300 
301 //===----------------------------------------------------------------------===//
302 // RemUIOp
303 //===----------------------------------------------------------------------===//
304 
305 OpFoldResult arith::RemUIOp::fold(ArrayRef<Attribute> operands) {
306   auto rhs = operands.back().dyn_cast_or_null<IntegerAttr>();
307   if (!rhs)
308     return {};
309   auto rhsValue = rhs.getValue();
310 
311   // x % 1 = 0
312   if (rhsValue.isOneValue())
313     return IntegerAttr::get(rhs.getType(), APInt(rhsValue.getBitWidth(), 0));
314 
315   // Don't fold if it requires division by zero.
316   if (rhsValue.isNullValue())
317     return {};
318 
319   auto lhs = operands.front().dyn_cast_or_null<IntegerAttr>();
320   if (!lhs)
321     return {};
322   return IntegerAttr::get(lhs.getType(), lhs.getValue().urem(rhsValue));
323 }
324 
325 //===----------------------------------------------------------------------===//
326 // RemSIOp
327 //===----------------------------------------------------------------------===//
328 
329 OpFoldResult arith::RemSIOp::fold(ArrayRef<Attribute> operands) {
330   auto rhs = operands.back().dyn_cast_or_null<IntegerAttr>();
331   if (!rhs)
332     return {};
333   auto rhsValue = rhs.getValue();
334 
335   // x % 1 = 0
336   if (rhsValue.isOneValue())
337     return IntegerAttr::get(rhs.getType(), APInt(rhsValue.getBitWidth(), 0));
338 
339   // Don't fold if it requires division by zero.
340   if (rhsValue.isNullValue())
341     return {};
342 
343   auto lhs = operands.front().dyn_cast_or_null<IntegerAttr>();
344   if (!lhs)
345     return {};
346   return IntegerAttr::get(lhs.getType(), lhs.getValue().srem(rhsValue));
347 }
348 
349 //===----------------------------------------------------------------------===//
350 // AndIOp
351 //===----------------------------------------------------------------------===//
352 
353 OpFoldResult arith::AndIOp::fold(ArrayRef<Attribute> operands) {
354   /// and(x, 0) -> 0
355   if (matchPattern(rhs(), m_Zero()))
356     return rhs();
357   /// and(x, allOnes) -> x
358   APInt intValue;
359   if (matchPattern(rhs(), m_ConstantInt(&intValue)) && intValue.isAllOnes())
360     return lhs();
361   /// and(x, x) -> x
362   if (lhs() == rhs())
363     return rhs();
364 
365   return constFoldBinaryOp<IntegerAttr>(operands,
366                                         [](APInt a, APInt b) { return a & b; });
367 }
368 
369 //===----------------------------------------------------------------------===//
370 // OrIOp
371 //===----------------------------------------------------------------------===//
372 
373 OpFoldResult arith::OrIOp::fold(ArrayRef<Attribute> operands) {
374   /// or(x, 0) -> x
375   if (matchPattern(rhs(), m_Zero()))
376     return lhs();
377   /// or(x, x) -> x
378   if (lhs() == rhs())
379     return rhs();
380 
381   return constFoldBinaryOp<IntegerAttr>(operands,
382                                         [](APInt a, APInt b) { return a | b; });
383 }
384 
385 //===----------------------------------------------------------------------===//
386 // XOrIOp
387 //===----------------------------------------------------------------------===//
388 
389 OpFoldResult arith::XOrIOp::fold(ArrayRef<Attribute> operands) {
390   /// xor(x, 0) -> x
391   if (matchPattern(rhs(), m_Zero()))
392     return lhs();
393   /// xor(x, x) -> 0
394   if (lhs() == rhs())
395     return Builder(getContext()).getZeroAttr(getType());
396 
397   return constFoldBinaryOp<IntegerAttr>(operands,
398                                         [](APInt a, APInt b) { return a ^ b; });
399 }
400 
401 void arith::XOrIOp::getCanonicalizationPatterns(
402     OwningRewritePatternList &patterns, MLIRContext *context) {
403   patterns.insert<XOrINotCmpI>(context);
404 }
405 
406 //===----------------------------------------------------------------------===//
407 // AddFOp
408 //===----------------------------------------------------------------------===//
409 
410 OpFoldResult arith::AddFOp::fold(ArrayRef<Attribute> operands) {
411   return constFoldBinaryOp<FloatAttr>(
412       operands, [](APFloat a, APFloat b) { return a + b; });
413 }
414 
415 //===----------------------------------------------------------------------===//
416 // SubFOp
417 //===----------------------------------------------------------------------===//
418 
419 OpFoldResult arith::SubFOp::fold(ArrayRef<Attribute> operands) {
420   return constFoldBinaryOp<FloatAttr>(
421       operands, [](APFloat a, APFloat b) { return a - b; });
422 }
423 
424 //===----------------------------------------------------------------------===//
425 // MulFOp
426 //===----------------------------------------------------------------------===//
427 
428 OpFoldResult arith::MulFOp::fold(ArrayRef<Attribute> operands) {
429   return constFoldBinaryOp<FloatAttr>(
430       operands, [](APFloat a, APFloat b) { return a * b; });
431 }
432 
433 //===----------------------------------------------------------------------===//
434 // DivFOp
435 //===----------------------------------------------------------------------===//
436 
437 OpFoldResult arith::DivFOp::fold(ArrayRef<Attribute> operands) {
438   return constFoldBinaryOp<FloatAttr>(
439       operands, [](APFloat a, APFloat b) { return a / b; });
440 }
441 
442 //===----------------------------------------------------------------------===//
443 // Verifiers for integer and floating point extension/truncation ops
444 //===----------------------------------------------------------------------===//
445 
446 // Extend ops can only extend to a wider type.
447 template <typename ValType, typename Op>
448 static LogicalResult verifyExtOp(Op op) {
449   Type srcType = getElementTypeOrSelf(op.in().getType());
450   Type dstType = getElementTypeOrSelf(op.getType());
451 
452   if (srcType.cast<ValType>().getWidth() >= dstType.cast<ValType>().getWidth())
453     return op.emitError("result type ")
454            << dstType << " must be wider than operand type " << srcType;
455 
456   return success();
457 }
458 
459 // Truncate ops can only truncate to a shorter type.
460 template <typename ValType, typename Op>
461 static LogicalResult verifyTruncateOp(Op op) {
462   Type srcType = getElementTypeOrSelf(op.in().getType());
463   Type dstType = getElementTypeOrSelf(op.getType());
464 
465   if (srcType.cast<ValType>().getWidth() <= dstType.cast<ValType>().getWidth())
466     return op.emitError("result type ")
467            << dstType << " must be shorter than operand type " << srcType;
468 
469   return success();
470 }
471 
472 //===----------------------------------------------------------------------===//
473 // ExtUIOp
474 //===----------------------------------------------------------------------===//
475 
476 OpFoldResult arith::ExtUIOp::fold(ArrayRef<Attribute> operands) {
477   if (auto lhs = operands[0].dyn_cast_or_null<IntegerAttr>())
478     return IntegerAttr::get(
479         getType(), lhs.getValue().zext(getType().getIntOrFloatBitWidth()));
480 
481   return {};
482 }
483 
484 //===----------------------------------------------------------------------===//
485 // ExtSIOp
486 //===----------------------------------------------------------------------===//
487 
488 OpFoldResult arith::ExtSIOp::fold(ArrayRef<Attribute> operands) {
489   if (auto lhs = operands[0].dyn_cast_or_null<IntegerAttr>())
490     return IntegerAttr::get(
491         getType(), lhs.getValue().sext(getType().getIntOrFloatBitWidth()));
492 
493   return {};
494 }
495 
496 // TODO temporary fixes until second patch is in
497 OpFoldResult arith::TruncFOp::fold(ArrayRef<Attribute> operands) {
498   return {};
499 }
500 
501 bool arith::TruncFOp::areCastCompatible(TypeRange inputs, TypeRange outputs) {
502   return true;
503 }
504 
505 OpFoldResult arith::TruncIOp::fold(ArrayRef<Attribute> operands) {
506   return {};
507 }
508 
509 bool arith::TruncIOp::areCastCompatible(TypeRange inputs, TypeRange outputs) {
510   return true;
511 }
512 
513 bool arith::ExtUIOp::areCastCompatible(TypeRange inputs, TypeRange outputs) {
514   return true;
515 }
516 
517 bool arith::ExtSIOp::areCastCompatible(TypeRange inputs, TypeRange outputs) {
518   return true;
519 }
520 
521 bool arith::ExtFOp::areCastCompatible(TypeRange inputs, TypeRange outputs) {
522   return true;
523 }
524 
525 OpFoldResult arith::ConstantOp::fold(ArrayRef<Attribute> operands) {
526   return {};
527 }
528 
529 bool arith::SIToFPOp::areCastCompatible(TypeRange inputs, TypeRange outputs) {
530   return true;
531 }
532 
533 bool arith::UIToFPOp::areCastCompatible(TypeRange inputs, TypeRange outputs) {
534   return true;
535 }
536 
537 bool arith::FPToSIOp::areCastCompatible(TypeRange inputs, TypeRange outputs) {
538   return true;
539 }
540 
541 bool arith::FPToUIOp::areCastCompatible(TypeRange inputs, TypeRange outputs) {
542   return true;
543 }
544 
545 //===----------------------------------------------------------------------===//
546 // IndexCastOp
547 //===----------------------------------------------------------------------===//
548 
549 bool arith::IndexCastOp::areCastCompatible(TypeRange inputs,
550                                            TypeRange outputs) {
551   assert(inputs.size() == 1 && outputs.size() == 1 &&
552          "index_cast op expects one result and one result");
553 
554   // Shape equivalence is guaranteed by op traits.
555   auto srcType = getElementTypeOrSelf(inputs.front());
556   auto dstType = getElementTypeOrSelf(outputs.front());
557 
558   return (srcType.isIndex() && dstType.isSignlessInteger()) ||
559          (srcType.isSignlessInteger() && dstType.isIndex());
560 }
561 
562 OpFoldResult arith::IndexCastOp::fold(ArrayRef<Attribute> operands) {
563   // index_cast(constant) -> constant
564   // A little hack because we go through int. Otherwise, the size of the
565   // constant might need to change.
566   if (auto value = operands[0].dyn_cast_or_null<IntegerAttr>())
567     return IntegerAttr::get(getType(), value.getInt());
568 
569   return {};
570 }
571 
572 void arith::IndexCastOp::getCanonicalizationPatterns(
573     OwningRewritePatternList &patterns, MLIRContext *context) {
574   patterns.insert<IndexCastOfIndexCast, IndexCastOfExtSI>(context);
575 }
576 
577 //===----------------------------------------------------------------------===//
578 // BitcastOp
579 //===----------------------------------------------------------------------===//
580 
581 bool arith::BitcastOp::areCastCompatible(TypeRange inputs, TypeRange outputs) {
582   assert(inputs.size() == 1 && outputs.size() == 1 &&
583          "bitcast op expects one operand and one result");
584 
585   // Shape equivalence is guaranteed by op traits.
586   auto srcType = getElementTypeOrSelf(inputs.front());
587   auto dstType = getElementTypeOrSelf(outputs.front());
588 
589   // Types are guarnateed to be integers or floats by constraints.
590   return srcType.getIntOrFloatBitWidth() == dstType.getIntOrFloatBitWidth();
591 }
592 
593 OpFoldResult arith::BitcastOp::fold(ArrayRef<Attribute> operands) {
594   assert(operands.size() == 1 && "bitcast op expects 1 operand");
595 
596   auto resType = getType();
597   auto operand = operands[0];
598   if (!operand)
599     return {};
600 
601   /// Bitcast dense elements.
602   if (auto denseAttr = operand.dyn_cast_or_null<DenseElementsAttr>())
603     return denseAttr.bitcast(resType.cast<ShapedType>().getElementType());
604   /// Other shaped types unhandled.
605   if (resType.isa<ShapedType>())
606     return {};
607 
608   /// Bitcast integer or float to integer or float.
609   APInt bits = operand.isa<FloatAttr>()
610                    ? operand.cast<FloatAttr>().getValue().bitcastToAPInt()
611                    : operand.cast<IntegerAttr>().getValue();
612 
613   if (auto resFloatType = resType.dyn_cast<FloatType>())
614     return FloatAttr::get(resType,
615                           APFloat(resFloatType.getFloatSemantics(), bits));
616   return IntegerAttr::get(resType, bits);
617 }
618 
619 void arith::BitcastOp::getCanonicalizationPatterns(
620     OwningRewritePatternList &patterns, MLIRContext *context) {
621   patterns.insert<BitcastOfBitcast>(context);
622 }
623 
624 //===----------------------------------------------------------------------===//
625 // Helpers for compare ops
626 //===----------------------------------------------------------------------===//
627 
628 /// Return the type of the same shape (scalar, vector or tensor) containing i1.
629 static Type getI1SameShape(Type type) {
630   auto i1Type = IntegerType::get(type.getContext(), 1);
631   if (auto tensorType = type.dyn_cast<RankedTensorType>())
632     return RankedTensorType::get(tensorType.getShape(), i1Type);
633   if (type.isa<UnrankedTensorType>())
634     return UnrankedTensorType::get(i1Type);
635   if (auto vectorType = type.dyn_cast<VectorType>())
636     return VectorType::get(vectorType.getShape(), i1Type);
637   return i1Type;
638 }
639 
640 //===----------------------------------------------------------------------===//
641 // CmpIOp
642 //===----------------------------------------------------------------------===//
643 
644 /// Compute `lhs` `pred` `rhs`, where `pred` is one of the known integer
645 /// comparison predicates.
646 bool mlir::arith::applyCmpPredicate(arith::CmpIPredicate predicate,
647                                     const APInt &lhs, const APInt &rhs) {
648   switch (predicate) {
649   case arith::CmpIPredicate::eq:
650     return lhs.eq(rhs);
651   case arith::CmpIPredicate::ne:
652     return lhs.ne(rhs);
653   case arith::CmpIPredicate::slt:
654     return lhs.slt(rhs);
655   case arith::CmpIPredicate::sle:
656     return lhs.sle(rhs);
657   case arith::CmpIPredicate::sgt:
658     return lhs.sgt(rhs);
659   case arith::CmpIPredicate::sge:
660     return lhs.sge(rhs);
661   case arith::CmpIPredicate::ult:
662     return lhs.ult(rhs);
663   case arith::CmpIPredicate::ule:
664     return lhs.ule(rhs);
665   case arith::CmpIPredicate::ugt:
666     return lhs.ugt(rhs);
667   case arith::CmpIPredicate::uge:
668     return lhs.uge(rhs);
669   }
670   llvm_unreachable("unknown cmpi predicate kind");
671 }
672 
673 /// Returns true if the predicate is true for two equal operands.
674 static bool applyCmpPredicateToEqualOperands(arith::CmpIPredicate predicate) {
675   switch (predicate) {
676   case arith::CmpIPredicate::eq:
677   case arith::CmpIPredicate::sle:
678   case arith::CmpIPredicate::sge:
679   case arith::CmpIPredicate::ule:
680   case arith::CmpIPredicate::uge:
681     return true;
682   case arith::CmpIPredicate::ne:
683   case arith::CmpIPredicate::slt:
684   case arith::CmpIPredicate::sgt:
685   case arith::CmpIPredicate::ult:
686   case arith::CmpIPredicate::ugt:
687     return false;
688   }
689   llvm_unreachable("unknown cmpi predicate kind");
690 }
691 
692 OpFoldResult arith::CmpIOp::fold(ArrayRef<Attribute> operands) {
693   assert(operands.size() == 2 && "cmpi takes two operands");
694 
695   // cmpi(pred, x, x)
696   if (lhs() == rhs()) {
697     auto val = applyCmpPredicateToEqualOperands(getPredicate());
698     return BoolAttr::get(getContext(), val);
699   }
700 
701   auto lhs = operands.front().dyn_cast_or_null<IntegerAttr>();
702   auto rhs = operands.back().dyn_cast_or_null<IntegerAttr>();
703   if (!lhs || !rhs)
704     return {};
705 
706   auto val = applyCmpPredicate(getPredicate(), lhs.getValue(), rhs.getValue());
707   return BoolAttr::get(getContext(), val);
708 }
709 
710 //===----------------------------------------------------------------------===//
711 // CmpFOp
712 //===----------------------------------------------------------------------===//
713 
714 /// Compute `lhs` `pred` `rhs`, where `pred` is one of the known floating point
715 /// comparison predicates.
716 bool mlir::arith::applyCmpPredicate(arith::CmpFPredicate predicate,
717                                     const APFloat &lhs, const APFloat &rhs) {
718   auto cmpResult = lhs.compare(rhs);
719   switch (predicate) {
720   case arith::CmpFPredicate::AlwaysFalse:
721     return false;
722   case arith::CmpFPredicate::OEQ:
723     return cmpResult == APFloat::cmpEqual;
724   case arith::CmpFPredicate::OGT:
725     return cmpResult == APFloat::cmpGreaterThan;
726   case arith::CmpFPredicate::OGE:
727     return cmpResult == APFloat::cmpGreaterThan ||
728            cmpResult == APFloat::cmpEqual;
729   case arith::CmpFPredicate::OLT:
730     return cmpResult == APFloat::cmpLessThan;
731   case arith::CmpFPredicate::OLE:
732     return cmpResult == APFloat::cmpLessThan || cmpResult == APFloat::cmpEqual;
733   case arith::CmpFPredicate::ONE:
734     return cmpResult != APFloat::cmpUnordered && cmpResult != APFloat::cmpEqual;
735   case arith::CmpFPredicate::ORD:
736     return cmpResult != APFloat::cmpUnordered;
737   case arith::CmpFPredicate::UEQ:
738     return cmpResult == APFloat::cmpUnordered || cmpResult == APFloat::cmpEqual;
739   case arith::CmpFPredicate::UGT:
740     return cmpResult == APFloat::cmpUnordered ||
741            cmpResult == APFloat::cmpGreaterThan;
742   case arith::CmpFPredicate::UGE:
743     return cmpResult == APFloat::cmpUnordered ||
744            cmpResult == APFloat::cmpGreaterThan ||
745            cmpResult == APFloat::cmpEqual;
746   case arith::CmpFPredicate::ULT:
747     return cmpResult == APFloat::cmpUnordered ||
748            cmpResult == APFloat::cmpLessThan;
749   case arith::CmpFPredicate::ULE:
750     return cmpResult == APFloat::cmpUnordered ||
751            cmpResult == APFloat::cmpLessThan || cmpResult == APFloat::cmpEqual;
752   case arith::CmpFPredicate::UNE:
753     return cmpResult != APFloat::cmpEqual;
754   case arith::CmpFPredicate::UNO:
755     return cmpResult == APFloat::cmpUnordered;
756   case arith::CmpFPredicate::AlwaysTrue:
757     return true;
758   }
759   llvm_unreachable("unknown cmpf predicate kind");
760 }
761 
762 OpFoldResult arith::CmpFOp::fold(ArrayRef<Attribute> operands) {
763   assert(operands.size() == 2 && "cmpf takes two operands");
764 
765   auto lhs = operands.front().dyn_cast_or_null<FloatAttr>();
766   auto rhs = operands.back().dyn_cast_or_null<FloatAttr>();
767 
768   if (!lhs || !rhs)
769     return {};
770 
771   auto val = applyCmpPredicate(getPredicate(), lhs.getValue(), rhs.getValue());
772   return BoolAttr::get(getContext(), val);
773 }
774 
775 //===----------------------------------------------------------------------===//
776 // TableGen'd op method definitions
777 //===----------------------------------------------------------------------===//
778 
779 #define GET_OP_CLASSES
780 #include "mlir/Dialect/Arithmetic/IR/ArithmeticOps.cpp.inc"
781 
782 //===----------------------------------------------------------------------===//
783 // TableGen'd enum attribute definitions
784 //===----------------------------------------------------------------------===//
785 
786 #include "mlir/Dialect/Arithmetic/IR/ArithmeticOpsEnums.cpp.inc"
787