1 //===- IntegerRangeAnalysis.cpp - Integer range analysis --------*- C++ -*-===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 //
9 // This file defines the dataflow analysis class for integer range inference
10 // which is used in transformations over the `arith` dialect such as
11 // branch elimination or signed->unsigned rewriting
12 //
13 //===----------------------------------------------------------------------===//
14
15 #include "mlir/Analysis/DataFlow/IntegerRangeAnalysis.h"
16 #include "mlir/Analysis/DataFlow/ConstantPropagationAnalysis.h"
17 #include "mlir/Interfaces/InferIntRangeInterface.h"
18 #include "mlir/Interfaces/LoopLikeInterface.h"
19 #include "llvm/Support/Debug.h"
20
21 #define DEBUG_TYPE "int-range-analysis"
22
23 using namespace mlir;
24 using namespace mlir::dataflow;
25
getPessimisticValueState(Value value)26 IntegerValueRange IntegerValueRange::getPessimisticValueState(Value value) {
27 unsigned width = ConstantIntRanges::getStorageBitwidth(value.getType());
28 APInt umin = APInt::getMinValue(width);
29 APInt umax = APInt::getMaxValue(width);
30 APInt smin = width != 0 ? APInt::getSignedMinValue(width) : umin;
31 APInt smax = width != 0 ? APInt::getSignedMaxValue(width) : umax;
32 return {{umin, umax, smin, smax}};
33 }
34
onUpdate(DataFlowSolver * solver) const35 void IntegerValueRangeLattice::onUpdate(DataFlowSolver *solver) const {
36 Lattice::onUpdate(solver);
37
38 // If the integer range can be narrowed to a constant, update the constant
39 // value of the SSA value.
40 Optional<APInt> constant = getValue().getValue().getConstantValue();
41 auto value = point.get<Value>();
42 auto *cv = solver->getOrCreateState<Lattice<ConstantValue>>(value);
43 if (!constant)
44 return solver->propagateIfChanged(cv, cv->markPessimisticFixpoint());
45
46 Dialect *dialect;
47 if (auto *parent = value.getDefiningOp())
48 dialect = parent->getDialect();
49 else
50 dialect = value.getParentBlock()->getParentOp()->getDialect();
51 solver->propagateIfChanged(
52 cv, cv->join(ConstantValue(IntegerAttr::get(value.getType(), *constant),
53 dialect)));
54 }
55
visitOperation(Operation * op,ArrayRef<const IntegerValueRangeLattice * > operands,ArrayRef<IntegerValueRangeLattice * > results)56 void IntegerRangeAnalysis::visitOperation(
57 Operation *op, ArrayRef<const IntegerValueRangeLattice *> operands,
58 ArrayRef<IntegerValueRangeLattice *> results) {
59 // Ignore non-integer outputs - return early if the op has no scalar
60 // integer results
61 bool hasIntegerResult = false;
62 for (auto it : llvm::zip(results, op->getResults())) {
63 if (std::get<1>(it).getType().isIntOrIndex()) {
64 hasIntegerResult = true;
65 } else {
66 propagateIfChanged(std::get<0>(it),
67 std::get<0>(it)->markPessimisticFixpoint());
68 }
69 }
70 if (!hasIntegerResult)
71 return;
72
73 auto inferrable = dyn_cast<InferIntRangeInterface>(op);
74 if (!inferrable)
75 return markAllPessimisticFixpoint(results);
76
77 LLVM_DEBUG(llvm::dbgs() << "Inferring ranges for " << *op << "\n");
78 SmallVector<ConstantIntRanges> argRanges(
79 llvm::map_range(operands, [](const IntegerValueRangeLattice *val) {
80 return val->getValue().getValue();
81 }));
82
83 auto joinCallback = [&](Value v, const ConstantIntRanges &attrs) {
84 auto result = v.dyn_cast<OpResult>();
85 if (!result)
86 return;
87 assert(llvm::is_contained(op->getResults(), result));
88
89 LLVM_DEBUG(llvm::dbgs() << "Inferred range " << attrs << "\n");
90 IntegerValueRangeLattice *lattice = results[result.getResultNumber()];
91 Optional<IntegerValueRange> oldRange;
92 if (!lattice->isUninitialized())
93 oldRange = lattice->getValue();
94
95 ChangeResult changed = lattice->join(attrs);
96
97 // Catch loop results with loop variant bounds and conservatively make
98 // them [-inf, inf] so we don't circle around infinitely often (because
99 // the dataflow analysis in MLIR doesn't attempt to work out trip counts
100 // and often can't).
101 bool isYieldedResult = llvm::any_of(v.getUsers(), [](Operation *op) {
102 return op->hasTrait<OpTrait::IsTerminator>();
103 });
104 if (isYieldedResult && oldRange.has_value() &&
105 !(lattice->getValue() == *oldRange)) {
106 LLVM_DEBUG(llvm::dbgs() << "Loop variant loop result detected\n");
107 changed |= lattice->markPessimisticFixpoint();
108 }
109 propagateIfChanged(lattice, changed);
110 };
111
112 inferrable.inferResultRanges(argRanges, joinCallback);
113 }
114
visitNonControlFlowArguments(Operation * op,const RegionSuccessor & successor,ArrayRef<IntegerValueRangeLattice * > argLattices,unsigned firstIndex)115 void IntegerRangeAnalysis::visitNonControlFlowArguments(
116 Operation *op, const RegionSuccessor &successor,
117 ArrayRef<IntegerValueRangeLattice *> argLattices, unsigned firstIndex) {
118 if (auto inferrable = dyn_cast<InferIntRangeInterface>(op)) {
119 LLVM_DEBUG(llvm::dbgs() << "Inferring ranges for " << *op << "\n");
120 SmallVector<ConstantIntRanges> argRanges(
121 llvm::map_range(op->getOperands(), [&](Value value) {
122 return getLatticeElementFor(op, value)->getValue().getValue();
123 }));
124
125 auto joinCallback = [&](Value v, const ConstantIntRanges &attrs) {
126 auto arg = v.dyn_cast<BlockArgument>();
127 if (!arg)
128 return;
129 if (!llvm::is_contained(successor.getSuccessor()->getArguments(), arg))
130 return;
131
132 LLVM_DEBUG(llvm::dbgs() << "Inferred range " << attrs << "\n");
133 IntegerValueRangeLattice *lattice = argLattices[arg.getArgNumber()];
134 Optional<IntegerValueRange> oldRange;
135 if (!lattice->isUninitialized())
136 oldRange = lattice->getValue();
137
138 ChangeResult changed = lattice->join(attrs);
139
140 // Catch loop results with loop variant bounds and conservatively make
141 // them [-inf, inf] so we don't circle around infinitely often (because
142 // the dataflow analysis in MLIR doesn't attempt to work out trip counts
143 // and often can't).
144 bool isYieldedValue = llvm::any_of(v.getUsers(), [](Operation *op) {
145 return op->hasTrait<OpTrait::IsTerminator>();
146 });
147 if (isYieldedValue && oldRange && !(lattice->getValue() == *oldRange)) {
148 LLVM_DEBUG(llvm::dbgs() << "Loop variant loop result detected\n");
149 changed |= lattice->markPessimisticFixpoint();
150 }
151 propagateIfChanged(lattice, changed);
152 };
153
154 inferrable.inferResultRanges(argRanges, joinCallback);
155 return;
156 }
157
158 /// Given the results of getConstant{Lower,Upper}Bound() or getConstantStep()
159 /// on a LoopLikeInterface return the lower/upper bound for that result if
160 /// possible.
161 auto getLoopBoundFromFold = [&](Optional<OpFoldResult> loopBound,
162 Type boundType, bool getUpper) {
163 unsigned int width = ConstantIntRanges::getStorageBitwidth(boundType);
164 if (loopBound.has_value()) {
165 if (loopBound->is<Attribute>()) {
166 if (auto bound =
167 loopBound->get<Attribute>().dyn_cast_or_null<IntegerAttr>())
168 return bound.getValue();
169 } else if (auto value = loopBound->dyn_cast<Value>()) {
170 const IntegerValueRangeLattice *lattice =
171 getLatticeElementFor(op, value);
172 if (lattice != nullptr)
173 return getUpper ? lattice->getValue().getValue().smax()
174 : lattice->getValue().getValue().smin();
175 }
176 }
177 // Given the results of getConstant{Lower,Upper}Bound()
178 // or getConstantStep() on a LoopLikeInterface return the lower/upper
179 // bound
180 return getUpper ? APInt::getSignedMaxValue(width)
181 : APInt::getSignedMinValue(width);
182 };
183
184 // Infer bounds for loop arguments that have static bounds
185 if (auto loop = dyn_cast<LoopLikeOpInterface>(op)) {
186 Optional<Value> iv = loop.getSingleInductionVar();
187 if (!iv) {
188 return SparseDataFlowAnalysis ::visitNonControlFlowArguments(
189 op, successor, argLattices, firstIndex);
190 }
191 Optional<OpFoldResult> lowerBound = loop.getSingleLowerBound();
192 Optional<OpFoldResult> upperBound = loop.getSingleUpperBound();
193 Optional<OpFoldResult> step = loop.getSingleStep();
194 APInt min = getLoopBoundFromFold(lowerBound, iv->getType(),
195 /*getUpper=*/false);
196 APInt max = getLoopBoundFromFold(upperBound, iv->getType(),
197 /*getUpper=*/true);
198 // Assume positivity for uniscoverable steps by way of getUpper = true.
199 APInt stepVal =
200 getLoopBoundFromFold(step, iv->getType(), /*getUpper=*/true);
201
202 if (stepVal.isNegative()) {
203 std::swap(min, max);
204 } else {
205 // Correct the upper bound by subtracting 1 so that it becomes a <=
206 // bound, because loops do not generally include their upper bound.
207 max -= 1;
208 }
209
210 IntegerValueRangeLattice *ivEntry = getLatticeElement(*iv);
211 auto ivRange = ConstantIntRanges::fromSigned(min, max);
212 propagateIfChanged(ivEntry, ivEntry->join(ivRange));
213 return;
214 }
215
216 return SparseDataFlowAnalysis::visitNonControlFlowArguments(
217 op, successor, argLattices, firstIndex);
218 }
219