1 //===- InferIntRangeInterface.cpp -  Integer range inference interface ---===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 
9 #include "mlir/Interfaces/InferIntRangeInterface.h"
10 #include "mlir/IR/BuiltinTypes.h"
11 #include "mlir/Interfaces/InferIntRangeInterface.cpp.inc"
12 
13 using namespace mlir;
14 
operator ==(const ConstantIntRanges & other) const15 bool ConstantIntRanges::operator==(const ConstantIntRanges &other) const {
16   return umin().getBitWidth() == other.umin().getBitWidth() &&
17          umin() == other.umin() && umax() == other.umax() &&
18          smin() == other.smin() && smax() == other.smax();
19 }
20 
umin() const21 const APInt &ConstantIntRanges::umin() const { return uminVal; }
22 
umax() const23 const APInt &ConstantIntRanges::umax() const { return umaxVal; }
24 
smin() const25 const APInt &ConstantIntRanges::smin() const { return sminVal; }
26 
smax() const27 const APInt &ConstantIntRanges::smax() const { return smaxVal; }
28 
getStorageBitwidth(Type type)29 unsigned ConstantIntRanges::getStorageBitwidth(Type type) {
30   if (type.isIndex())
31     return IndexType::kInternalStorageBitWidth;
32   if (auto integerType = type.dyn_cast<IntegerType>())
33     return integerType.getWidth();
34   // Non-integer types have their bounds stored in width 0 `APInt`s.
35   return 0;
36 }
37 
maxRange(unsigned bitwidth)38 ConstantIntRanges ConstantIntRanges::maxRange(unsigned bitwidth) {
39   return fromUnsigned(APInt::getZero(bitwidth), APInt::getMaxValue(bitwidth));
40 }
41 
constant(const APInt & value)42 ConstantIntRanges ConstantIntRanges::constant(const APInt &value) {
43   return {value, value, value, value};
44 }
45 
range(const APInt & min,const APInt & max,bool isSigned)46 ConstantIntRanges ConstantIntRanges::range(const APInt &min, const APInt &max,
47                                            bool isSigned) {
48   if (isSigned)
49     return fromSigned(min, max);
50   return fromUnsigned(min, max);
51 }
52 
fromSigned(const APInt & smin,const APInt & smax)53 ConstantIntRanges ConstantIntRanges::fromSigned(const APInt &smin,
54                                                 const APInt &smax) {
55   unsigned int width = smin.getBitWidth();
56   APInt umin, umax;
57   if (smin.isNonNegative() == smax.isNonNegative()) {
58     umin = smin.ult(smax) ? smin : smax;
59     umax = smin.ugt(smax) ? smin : smax;
60   } else {
61     umin = APInt::getMinValue(width);
62     umax = APInt::getMaxValue(width);
63   }
64   return {umin, umax, smin, smax};
65 }
66 
fromUnsigned(const APInt & umin,const APInt & umax)67 ConstantIntRanges ConstantIntRanges::fromUnsigned(const APInt &umin,
68                                                   const APInt &umax) {
69   unsigned int width = umin.getBitWidth();
70   APInt smin, smax;
71   if (umin.isNonNegative() == umax.isNonNegative()) {
72     smin = umin.slt(umax) ? umin : umax;
73     smax = umin.sgt(umax) ? umin : umax;
74   } else {
75     smin = APInt::getSignedMinValue(width);
76     smax = APInt::getSignedMaxValue(width);
77   }
78   return {umin, umax, smin, smax};
79 }
80 
81 ConstantIntRanges
rangeUnion(const ConstantIntRanges & other) const82 ConstantIntRanges::rangeUnion(const ConstantIntRanges &other) const {
83   // "Not an integer" poisons everything and also cannot be fed to comparison
84   // operators.
85   if (umin().getBitWidth() == 0)
86     return *this;
87   if (other.umin().getBitWidth() == 0)
88     return other;
89 
90   const APInt &uminUnion = umin().ult(other.umin()) ? umin() : other.umin();
91   const APInt &umaxUnion = umax().ugt(other.umax()) ? umax() : other.umax();
92   const APInt &sminUnion = smin().slt(other.smin()) ? smin() : other.smin();
93   const APInt &smaxUnion = smax().sgt(other.smax()) ? smax() : other.smax();
94 
95   return {uminUnion, umaxUnion, sminUnion, smaxUnion};
96 }
97 
98 ConstantIntRanges
intersection(const ConstantIntRanges & other) const99 ConstantIntRanges::intersection(const ConstantIntRanges &other) const {
100   // "Not an integer" poisons everything and also cannot be fed to comparison
101   // operators.
102   if (umin().getBitWidth() == 0)
103     return *this;
104   if (other.umin().getBitWidth() == 0)
105     return other;
106 
107   const APInt &uminIntersect = umin().ugt(other.umin()) ? umin() : other.umin();
108   const APInt &umaxIntersect = umax().ult(other.umax()) ? umax() : other.umax();
109   const APInt &sminIntersect = smin().sgt(other.smin()) ? smin() : other.smin();
110   const APInt &smaxIntersect = smax().slt(other.smax()) ? smax() : other.smax();
111 
112   return {uminIntersect, umaxIntersect, sminIntersect, smaxIntersect};
113 }
114 
getConstantValue() const115 Optional<APInt> ConstantIntRanges::getConstantValue() const {
116   // Note: we need to exclude the trivially-equal width 0 values here.
117   if (umin() == umax() && umin().getBitWidth() != 0)
118     return umin();
119   if (smin() == smax() && smin().getBitWidth() != 0)
120     return smin();
121   return None;
122 }
123 
operator <<(raw_ostream & os,const ConstantIntRanges & range)124 raw_ostream &mlir::operator<<(raw_ostream &os, const ConstantIntRanges &range) {
125   return os << "unsigned : [" << range.umin() << ", " << range.umax()
126             << "] signed : [" << range.smin() << ", " << range.smax() << "]";
127 }
128