1 //===- InferIntRangeInterface.cpp -  Integer range inference interface ---===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 
9 #include "mlir/Interfaces/InferIntRangeInterface.h"
10 #include "mlir/IR/BuiltinTypes.h"
11 #include "mlir/Interfaces/InferIntRangeInterface.cpp.inc"
12 
13 using namespace mlir;
14 
15 bool ConstantIntRanges::operator==(const ConstantIntRanges &other) const {
16   return umin().getBitWidth() == other.umin().getBitWidth() &&
17          umin() == other.umin() && umax() == other.umax() &&
18          smin() == other.smin() && smax() == other.smax();
19 }
20 
21 const APInt &ConstantIntRanges::umin() const { return uminVal; }
22 
23 const APInt &ConstantIntRanges::umax() const { return umaxVal; }
24 
25 const APInt &ConstantIntRanges::smin() const { return sminVal; }
26 
27 const APInt &ConstantIntRanges::smax() const { return smaxVal; }
28 
29 unsigned ConstantIntRanges::getStorageBitwidth(Type type) {
30   if (type.isIndex())
31     return IndexType::kInternalStorageBitWidth;
32   if (auto integerType = type.dyn_cast<IntegerType>())
33     return integerType.getWidth();
34   // Non-integer types have their bounds stored in width 0 `APInt`s.
35   return 0;
36 }
37 
38 ConstantIntRanges ConstantIntRanges::range(const APInt &min, const APInt &max) {
39   return {min, max, min, max};
40 }
41 
42 ConstantIntRanges ConstantIntRanges::fromSigned(const APInt &smin,
43                                                 const APInt &smax) {
44   unsigned int width = smin.getBitWidth();
45   APInt umin, umax;
46   if (smin.isNonNegative() == smax.isNonNegative()) {
47     umin = smin.ult(smax) ? smin : smax;
48     umax = smin.ugt(smax) ? smin : smax;
49   } else {
50     umin = APInt::getMinValue(width);
51     umax = APInt::getMaxValue(width);
52   }
53   return {umin, umax, smin, smax};
54 }
55 
56 ConstantIntRanges ConstantIntRanges::fromUnsigned(const APInt &umin,
57                                                   const APInt &umax) {
58   unsigned int width = umin.getBitWidth();
59   APInt smin, smax;
60   if (umin.isNonNegative() == umax.isNonNegative()) {
61     smin = umin.slt(umax) ? umin : umax;
62     smax = umin.sgt(umax) ? umin : umax;
63   } else {
64     smin = APInt::getSignedMinValue(width);
65     smax = APInt::getSignedMaxValue(width);
66   }
67   return {umin, umax, smin, smax};
68 }
69 
70 ConstantIntRanges
71 ConstantIntRanges::rangeUnion(const ConstantIntRanges &other) const {
72   // "Not an integer" poisons everything and also cannot be fed to comparison
73   // operators.
74   if (umin().getBitWidth() == 0)
75     return *this;
76   if (other.umin().getBitWidth() == 0)
77     return other;
78 
79   const APInt &uminUnion = umin().ult(other.umin()) ? umin() : other.umin();
80   const APInt &umaxUnion = umax().ugt(other.umax()) ? umax() : other.umax();
81   const APInt &sminUnion = smin().slt(other.smin()) ? smin() : other.smin();
82   const APInt &smaxUnion = smax().sgt(other.smax()) ? smax() : other.smax();
83 
84   return {uminUnion, umaxUnion, sminUnion, smaxUnion};
85 }
86 
87 Optional<APInt> ConstantIntRanges::getConstantValue() const {
88   // Note: we need to exclude the trivially-equal width 0 values here.
89   if (umin() == umax() && umin().getBitWidth() != 0)
90     return umin();
91   if (smin() == smax() && smin().getBitWidth() != 0)
92     return smin();
93   return None;
94 }
95 
96 raw_ostream &mlir::operator<<(raw_ostream &os, const ConstantIntRanges &range) {
97   return os << "unsigned : [" << range.umin() << ", " << range.umax()
98             << "] signed : [" << range.smin() << ", " << range.smax() << "]";
99 }
100