1ab701975SMogball //===- IntegerRangeAnalysis.cpp - Integer range analysis --------*- C++ -*-===//
2ab701975SMogball //
3ab701975SMogball // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4ab701975SMogball // See https://llvm.org/LICENSE.txt for license information.
5ab701975SMogball // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6ab701975SMogball //
7ab701975SMogball //===----------------------------------------------------------------------===//
8ab701975SMogball //
// This file defines the dataflow analysis class for integer range inference,
// which is used in transformations over the `arith` dialect such as
// branch elimination or signed->unsigned rewriting.
12ab701975SMogball //
13ab701975SMogball //===----------------------------------------------------------------------===//
14ab701975SMogball
15ab701975SMogball #include "mlir/Analysis/DataFlow/IntegerRangeAnalysis.h"
16ab701975SMogball #include "mlir/Analysis/DataFlow/ConstantPropagationAnalysis.h"
17ab701975SMogball #include "mlir/Interfaces/InferIntRangeInterface.h"
18ab701975SMogball #include "mlir/Interfaces/LoopLikeInterface.h"
19ab701975SMogball #include "llvm/Support/Debug.h"
20ab701975SMogball
21ab701975SMogball #define DEBUG_TYPE "int-range-analysis"
22ab701975SMogball
23ab701975SMogball using namespace mlir;
24ab701975SMogball using namespace mlir::dataflow;
25ab701975SMogball
getPessimisticValueState(Value value)26ab701975SMogball IntegerValueRange IntegerValueRange::getPessimisticValueState(Value value) {
27ab701975SMogball unsigned width = ConstantIntRanges::getStorageBitwidth(value.getType());
28ab701975SMogball APInt umin = APInt::getMinValue(width);
29ab701975SMogball APInt umax = APInt::getMaxValue(width);
30ab701975SMogball APInt smin = width != 0 ? APInt::getSignedMinValue(width) : umin;
31ab701975SMogball APInt smax = width != 0 ? APInt::getSignedMaxValue(width) : umax;
32ab701975SMogball return {{umin, umax, smin, smax}};
33ab701975SMogball }
34ab701975SMogball
onUpdate(DataFlowSolver * solver) const35ab701975SMogball void IntegerValueRangeLattice::onUpdate(DataFlowSolver *solver) const {
36ab701975SMogball Lattice::onUpdate(solver);
37ab701975SMogball
38ab701975SMogball // If the integer range can be narrowed to a constant, update the constant
39ab701975SMogball // value of the SSA value.
40ab701975SMogball Optional<APInt> constant = getValue().getValue().getConstantValue();
41ab701975SMogball auto value = point.get<Value>();
42ab701975SMogball auto *cv = solver->getOrCreateState<Lattice<ConstantValue>>(value);
43ab701975SMogball if (!constant)
44ab701975SMogball return solver->propagateIfChanged(cv, cv->markPessimisticFixpoint());
45ab701975SMogball
46ab701975SMogball Dialect *dialect;
47ab701975SMogball if (auto *parent = value.getDefiningOp())
48ab701975SMogball dialect = parent->getDialect();
49ab701975SMogball else
50ab701975SMogball dialect = value.getParentBlock()->getParentOp()->getDialect();
51ab701975SMogball solver->propagateIfChanged(
52ab701975SMogball cv, cv->join(ConstantValue(IntegerAttr::get(value.getType(), *constant),
53ab701975SMogball dialect)));
54ab701975SMogball }
55ab701975SMogball
visitOperation(Operation * op,ArrayRef<const IntegerValueRangeLattice * > operands,ArrayRef<IntegerValueRangeLattice * > results)56ab701975SMogball void IntegerRangeAnalysis::visitOperation(
57ab701975SMogball Operation *op, ArrayRef<const IntegerValueRangeLattice *> operands,
58ab701975SMogball ArrayRef<IntegerValueRangeLattice *> results) {
59ab701975SMogball // Ignore non-integer outputs - return early if the op has no scalar
60ab701975SMogball // integer results
61ab701975SMogball bool hasIntegerResult = false;
62ab701975SMogball for (auto it : llvm::zip(results, op->getResults())) {
63ab701975SMogball if (std::get<1>(it).getType().isIntOrIndex()) {
64ab701975SMogball hasIntegerResult = true;
65ab701975SMogball } else {
66ab701975SMogball propagateIfChanged(std::get<0>(it),
67ab701975SMogball std::get<0>(it)->markPessimisticFixpoint());
68ab701975SMogball }
69ab701975SMogball }
70ab701975SMogball if (!hasIntegerResult)
71ab701975SMogball return;
72ab701975SMogball
73ab701975SMogball auto inferrable = dyn_cast<InferIntRangeInterface>(op);
74ab701975SMogball if (!inferrable)
75ab701975SMogball return markAllPessimisticFixpoint(results);
76ab701975SMogball
77ab701975SMogball LLVM_DEBUG(llvm::dbgs() << "Inferring ranges for " << *op << "\n");
78ab701975SMogball SmallVector<ConstantIntRanges> argRanges(
79ab701975SMogball llvm::map_range(operands, [](const IntegerValueRangeLattice *val) {
80ab701975SMogball return val->getValue().getValue();
81ab701975SMogball }));
82ab701975SMogball
83ab701975SMogball auto joinCallback = [&](Value v, const ConstantIntRanges &attrs) {
84ab701975SMogball auto result = v.dyn_cast<OpResult>();
85ab701975SMogball if (!result)
86ab701975SMogball return;
87*360c1111SKazu Hirata assert(llvm::is_contained(op->getResults(), result));
88ab701975SMogball
89ab701975SMogball LLVM_DEBUG(llvm::dbgs() << "Inferred range " << attrs << "\n");
90ab701975SMogball IntegerValueRangeLattice *lattice = results[result.getResultNumber()];
91ab701975SMogball Optional<IntegerValueRange> oldRange;
92ab701975SMogball if (!lattice->isUninitialized())
93ab701975SMogball oldRange = lattice->getValue();
94ab701975SMogball
95ab701975SMogball ChangeResult changed = lattice->join(attrs);
96ab701975SMogball
97ab701975SMogball // Catch loop results with loop variant bounds and conservatively make
98ab701975SMogball // them [-inf, inf] so we don't circle around infinitely often (because
99ab701975SMogball // the dataflow analysis in MLIR doesn't attempt to work out trip counts
100ab701975SMogball // and often can't).
101ab701975SMogball bool isYieldedResult = llvm::any_of(v.getUsers(), [](Operation *op) {
102ab701975SMogball return op->hasTrait<OpTrait::IsTerminator>();
103ab701975SMogball });
104491d2701SKazu Hirata if (isYieldedResult && oldRange.has_value() &&
105ab701975SMogball !(lattice->getValue() == *oldRange)) {
106ab701975SMogball LLVM_DEBUG(llvm::dbgs() << "Loop variant loop result detected\n");
107ab701975SMogball changed |= lattice->markPessimisticFixpoint();
108ab701975SMogball }
109ab701975SMogball propagateIfChanged(lattice, changed);
110ab701975SMogball };
111ab701975SMogball
112ab701975SMogball inferrable.inferResultRanges(argRanges, joinCallback);
113ab701975SMogball }
114ab701975SMogball
visitNonControlFlowArguments(Operation * op,const RegionSuccessor & successor,ArrayRef<IntegerValueRangeLattice * > argLattices,unsigned firstIndex)115ab701975SMogball void IntegerRangeAnalysis::visitNonControlFlowArguments(
116ab701975SMogball Operation *op, const RegionSuccessor &successor,
117ab701975SMogball ArrayRef<IntegerValueRangeLattice *> argLattices, unsigned firstIndex) {
118ab701975SMogball if (auto inferrable = dyn_cast<InferIntRangeInterface>(op)) {
119ab701975SMogball LLVM_DEBUG(llvm::dbgs() << "Inferring ranges for " << *op << "\n");
120ab701975SMogball SmallVector<ConstantIntRanges> argRanges(
121ab701975SMogball llvm::map_range(op->getOperands(), [&](Value value) {
122ab701975SMogball return getLatticeElementFor(op, value)->getValue().getValue();
123ab701975SMogball }));
124ab701975SMogball
125ab701975SMogball auto joinCallback = [&](Value v, const ConstantIntRanges &attrs) {
126ab701975SMogball auto arg = v.dyn_cast<BlockArgument>();
127ab701975SMogball if (!arg)
128ab701975SMogball return;
129*360c1111SKazu Hirata if (!llvm::is_contained(successor.getSuccessor()->getArguments(), arg))
130ab701975SMogball return;
131ab701975SMogball
132ab701975SMogball LLVM_DEBUG(llvm::dbgs() << "Inferred range " << attrs << "\n");
133ab701975SMogball IntegerValueRangeLattice *lattice = argLattices[arg.getArgNumber()];
134ab701975SMogball Optional<IntegerValueRange> oldRange;
135ab701975SMogball if (!lattice->isUninitialized())
136ab701975SMogball oldRange = lattice->getValue();
137ab701975SMogball
138ab701975SMogball ChangeResult changed = lattice->join(attrs);
139ab701975SMogball
140ab701975SMogball // Catch loop results with loop variant bounds and conservatively make
141ab701975SMogball // them [-inf, inf] so we don't circle around infinitely often (because
142ab701975SMogball // the dataflow analysis in MLIR doesn't attempt to work out trip counts
143ab701975SMogball // and often can't).
144ab701975SMogball bool isYieldedValue = llvm::any_of(v.getUsers(), [](Operation *op) {
145ab701975SMogball return op->hasTrait<OpTrait::IsTerminator>();
146ab701975SMogball });
147ab701975SMogball if (isYieldedValue && oldRange && !(lattice->getValue() == *oldRange)) {
148ab701975SMogball LLVM_DEBUG(llvm::dbgs() << "Loop variant loop result detected\n");
149ab701975SMogball changed |= lattice->markPessimisticFixpoint();
150ab701975SMogball }
151ab701975SMogball propagateIfChanged(lattice, changed);
152ab701975SMogball };
153ab701975SMogball
154ab701975SMogball inferrable.inferResultRanges(argRanges, joinCallback);
155ab701975SMogball return;
156ab701975SMogball }
157ab701975SMogball
158ab701975SMogball /// Given the results of getConstant{Lower,Upper}Bound() or getConstantStep()
159ab701975SMogball /// on a LoopLikeInterface return the lower/upper bound for that result if
160ab701975SMogball /// possible.
161ab701975SMogball auto getLoopBoundFromFold = [&](Optional<OpFoldResult> loopBound,
162ab701975SMogball Type boundType, bool getUpper) {
163ab701975SMogball unsigned int width = ConstantIntRanges::getStorageBitwidth(boundType);
164491d2701SKazu Hirata if (loopBound.has_value()) {
165ab701975SMogball if (loopBound->is<Attribute>()) {
166ab701975SMogball if (auto bound =
167ab701975SMogball loopBound->get<Attribute>().dyn_cast_or_null<IntegerAttr>())
168ab701975SMogball return bound.getValue();
169ab701975SMogball } else if (auto value = loopBound->dyn_cast<Value>()) {
170ab701975SMogball const IntegerValueRangeLattice *lattice =
171ab701975SMogball getLatticeElementFor(op, value);
172ab701975SMogball if (lattice != nullptr)
173ab701975SMogball return getUpper ? lattice->getValue().getValue().smax()
174ab701975SMogball : lattice->getValue().getValue().smin();
175ab701975SMogball }
176ab701975SMogball }
177ab701975SMogball // Given the results of getConstant{Lower,Upper}Bound()
178ab701975SMogball // or getConstantStep() on a LoopLikeInterface return the lower/upper
179ab701975SMogball // bound
180ab701975SMogball return getUpper ? APInt::getSignedMaxValue(width)
181ab701975SMogball : APInt::getSignedMinValue(width);
182ab701975SMogball };
183ab701975SMogball
184ab701975SMogball // Infer bounds for loop arguments that have static bounds
185ab701975SMogball if (auto loop = dyn_cast<LoopLikeOpInterface>(op)) {
186ab701975SMogball Optional<Value> iv = loop.getSingleInductionVar();
187ab701975SMogball if (!iv) {
188ab701975SMogball return SparseDataFlowAnalysis ::visitNonControlFlowArguments(
189ab701975SMogball op, successor, argLattices, firstIndex);
190ab701975SMogball }
191ab701975SMogball Optional<OpFoldResult> lowerBound = loop.getSingleLowerBound();
192ab701975SMogball Optional<OpFoldResult> upperBound = loop.getSingleUpperBound();
193ab701975SMogball Optional<OpFoldResult> step = loop.getSingleStep();
194ab701975SMogball APInt min = getLoopBoundFromFold(lowerBound, iv->getType(),
195ab701975SMogball /*getUpper=*/false);
196ab701975SMogball APInt max = getLoopBoundFromFold(upperBound, iv->getType(),
197ab701975SMogball /*getUpper=*/true);
198ab701975SMogball // Assume positivity for uniscoverable steps by way of getUpper = true.
199ab701975SMogball APInt stepVal =
200ab701975SMogball getLoopBoundFromFold(step, iv->getType(), /*getUpper=*/true);
201ab701975SMogball
202ab701975SMogball if (stepVal.isNegative()) {
203ab701975SMogball std::swap(min, max);
204ab701975SMogball } else {
205ab701975SMogball // Correct the upper bound by subtracting 1 so that it becomes a <=
206ab701975SMogball // bound, because loops do not generally include their upper bound.
207ab701975SMogball max -= 1;
208ab701975SMogball }
209ab701975SMogball
210ab701975SMogball IntegerValueRangeLattice *ivEntry = getLatticeElement(*iv);
211ab701975SMogball auto ivRange = ConstantIntRanges::fromSigned(min, max);
212ab701975SMogball propagateIfChanged(ivEntry, ivEntry->join(ivRange));
213ab701975SMogball return;
214ab701975SMogball }
215ab701975SMogball
216ab701975SMogball return SparseDataFlowAnalysis::visitNonControlFlowArguments(
217ab701975SMogball op, successor, argLattices, firstIndex);
218ab701975SMogball }
219