1 //===- InferIntRangeInterface.cpp - Integer range inference interface ---===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8
9 #include "mlir/Interfaces/InferIntRangeInterface.h"
10 #include "mlir/IR/BuiltinTypes.h"
11 #include "mlir/Interfaces/InferIntRangeInterface.cpp.inc"
12
13 using namespace mlir;
14
operator ==(const ConstantIntRanges & other) const15 bool ConstantIntRanges::operator==(const ConstantIntRanges &other) const {
16 return umin().getBitWidth() == other.umin().getBitWidth() &&
17 umin() == other.umin() && umax() == other.umax() &&
18 smin() == other.smin() && smax() == other.smax();
19 }
20
umin() const21 const APInt &ConstantIntRanges::umin() const { return uminVal; }
22
umax() const23 const APInt &ConstantIntRanges::umax() const { return umaxVal; }
24
smin() const25 const APInt &ConstantIntRanges::smin() const { return sminVal; }
26
smax() const27 const APInt &ConstantIntRanges::smax() const { return smaxVal; }
28
getStorageBitwidth(Type type)29 unsigned ConstantIntRanges::getStorageBitwidth(Type type) {
30 if (type.isIndex())
31 return IndexType::kInternalStorageBitWidth;
32 if (auto integerType = type.dyn_cast<IntegerType>())
33 return integerType.getWidth();
34 // Non-integer types have their bounds stored in width 0 `APInt`s.
35 return 0;
36 }
37
maxRange(unsigned bitwidth)38 ConstantIntRanges ConstantIntRanges::maxRange(unsigned bitwidth) {
39 return fromUnsigned(APInt::getZero(bitwidth), APInt::getMaxValue(bitwidth));
40 }
41
constant(const APInt & value)42 ConstantIntRanges ConstantIntRanges::constant(const APInt &value) {
43 return {value, value, value, value};
44 }
45
range(const APInt & min,const APInt & max,bool isSigned)46 ConstantIntRanges ConstantIntRanges::range(const APInt &min, const APInt &max,
47 bool isSigned) {
48 if (isSigned)
49 return fromSigned(min, max);
50 return fromUnsigned(min, max);
51 }
52
fromSigned(const APInt & smin,const APInt & smax)53 ConstantIntRanges ConstantIntRanges::fromSigned(const APInt &smin,
54 const APInt &smax) {
55 unsigned int width = smin.getBitWidth();
56 APInt umin, umax;
57 if (smin.isNonNegative() == smax.isNonNegative()) {
58 umin = smin.ult(smax) ? smin : smax;
59 umax = smin.ugt(smax) ? smin : smax;
60 } else {
61 umin = APInt::getMinValue(width);
62 umax = APInt::getMaxValue(width);
63 }
64 return {umin, umax, smin, smax};
65 }
66
fromUnsigned(const APInt & umin,const APInt & umax)67 ConstantIntRanges ConstantIntRanges::fromUnsigned(const APInt &umin,
68 const APInt &umax) {
69 unsigned int width = umin.getBitWidth();
70 APInt smin, smax;
71 if (umin.isNonNegative() == umax.isNonNegative()) {
72 smin = umin.slt(umax) ? umin : umax;
73 smax = umin.sgt(umax) ? umin : umax;
74 } else {
75 smin = APInt::getSignedMinValue(width);
76 smax = APInt::getSignedMaxValue(width);
77 }
78 return {umin, umax, smin, smax};
79 }
80
81 ConstantIntRanges
rangeUnion(const ConstantIntRanges & other) const82 ConstantIntRanges::rangeUnion(const ConstantIntRanges &other) const {
83 // "Not an integer" poisons everything and also cannot be fed to comparison
84 // operators.
85 if (umin().getBitWidth() == 0)
86 return *this;
87 if (other.umin().getBitWidth() == 0)
88 return other;
89
90 const APInt &uminUnion = umin().ult(other.umin()) ? umin() : other.umin();
91 const APInt &umaxUnion = umax().ugt(other.umax()) ? umax() : other.umax();
92 const APInt &sminUnion = smin().slt(other.smin()) ? smin() : other.smin();
93 const APInt &smaxUnion = smax().sgt(other.smax()) ? smax() : other.smax();
94
95 return {uminUnion, umaxUnion, sminUnion, smaxUnion};
96 }
97
98 ConstantIntRanges
intersection(const ConstantIntRanges & other) const99 ConstantIntRanges::intersection(const ConstantIntRanges &other) const {
100 // "Not an integer" poisons everything and also cannot be fed to comparison
101 // operators.
102 if (umin().getBitWidth() == 0)
103 return *this;
104 if (other.umin().getBitWidth() == 0)
105 return other;
106
107 const APInt &uminIntersect = umin().ugt(other.umin()) ? umin() : other.umin();
108 const APInt &umaxIntersect = umax().ult(other.umax()) ? umax() : other.umax();
109 const APInt &sminIntersect = smin().sgt(other.smin()) ? smin() : other.smin();
110 const APInt &smaxIntersect = smax().slt(other.smax()) ? smax() : other.smax();
111
112 return {uminIntersect, umaxIntersect, sminIntersect, smaxIntersect};
113 }
114
getConstantValue() const115 Optional<APInt> ConstantIntRanges::getConstantValue() const {
116 // Note: we need to exclude the trivially-equal width 0 values here.
117 if (umin() == umax() && umin().getBitWidth() != 0)
118 return umin();
119 if (smin() == smax() && smin().getBitWidth() != 0)
120 return smin();
121 return None;
122 }
123
operator <<(raw_ostream & os,const ConstantIntRanges & range)124 raw_ostream &mlir::operator<<(raw_ostream &os, const ConstantIntRanges &range) {
125 return os << "unsigned : [" << range.umin() << ", " << range.umax()
126 << "] signed : [" << range.smin() << ", " << range.smax() << "]";
127 }
128