1 //===- IntegerRangeAnalysis.cpp - Integer range analysis --------*- C++ -*-===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 //
// This file defines the dataflow analysis class for integer range inference,
// which is used in transformations over the `arith` dialect such as
// branch elimination or signed->unsigned rewriting.
12 //
13 //===----------------------------------------------------------------------===//
14 
15 #include "mlir/Analysis/DataFlow/IntegerRangeAnalysis.h"
16 #include "mlir/Analysis/DataFlow/ConstantPropagationAnalysis.h"
17 #include "mlir/Interfaces/InferIntRangeInterface.h"
18 #include "mlir/Interfaces/LoopLikeInterface.h"
19 #include "llvm/Support/Debug.h"
20 
21 #define DEBUG_TYPE "int-range-analysis"
22 
23 using namespace mlir;
24 using namespace mlir::dataflow;
25 
getPessimisticValueState(Value value)26 IntegerValueRange IntegerValueRange::getPessimisticValueState(Value value) {
27   unsigned width = ConstantIntRanges::getStorageBitwidth(value.getType());
28   APInt umin = APInt::getMinValue(width);
29   APInt umax = APInt::getMaxValue(width);
30   APInt smin = width != 0 ? APInt::getSignedMinValue(width) : umin;
31   APInt smax = width != 0 ? APInt::getSignedMaxValue(width) : umax;
32   return {{umin, umax, smin, smax}};
33 }
34 
onUpdate(DataFlowSolver * solver) const35 void IntegerValueRangeLattice::onUpdate(DataFlowSolver *solver) const {
36   Lattice::onUpdate(solver);
37 
38   // If the integer range can be narrowed to a constant, update the constant
39   // value of the SSA value.
40   Optional<APInt> constant = getValue().getValue().getConstantValue();
41   auto value = point.get<Value>();
42   auto *cv = solver->getOrCreateState<Lattice<ConstantValue>>(value);
43   if (!constant)
44     return solver->propagateIfChanged(cv, cv->markPessimisticFixpoint());
45 
46   Dialect *dialect;
47   if (auto *parent = value.getDefiningOp())
48     dialect = parent->getDialect();
49   else
50     dialect = value.getParentBlock()->getParentOp()->getDialect();
51   solver->propagateIfChanged(
52       cv, cv->join(ConstantValue(IntegerAttr::get(value.getType(), *constant),
53                                  dialect)));
54 }
55 
visitOperation(Operation * op,ArrayRef<const IntegerValueRangeLattice * > operands,ArrayRef<IntegerValueRangeLattice * > results)56 void IntegerRangeAnalysis::visitOperation(
57     Operation *op, ArrayRef<const IntegerValueRangeLattice *> operands,
58     ArrayRef<IntegerValueRangeLattice *> results) {
59   // Ignore non-integer outputs - return early if the op has no scalar
60   // integer results
61   bool hasIntegerResult = false;
62   for (auto it : llvm::zip(results, op->getResults())) {
63     if (std::get<1>(it).getType().isIntOrIndex()) {
64       hasIntegerResult = true;
65     } else {
66       propagateIfChanged(std::get<0>(it),
67                          std::get<0>(it)->markPessimisticFixpoint());
68     }
69   }
70   if (!hasIntegerResult)
71     return;
72 
73   auto inferrable = dyn_cast<InferIntRangeInterface>(op);
74   if (!inferrable)
75     return markAllPessimisticFixpoint(results);
76 
77   LLVM_DEBUG(llvm::dbgs() << "Inferring ranges for " << *op << "\n");
78   SmallVector<ConstantIntRanges> argRanges(
79       llvm::map_range(operands, [](const IntegerValueRangeLattice *val) {
80         return val->getValue().getValue();
81       }));
82 
83   auto joinCallback = [&](Value v, const ConstantIntRanges &attrs) {
84     auto result = v.dyn_cast<OpResult>();
85     if (!result)
86       return;
87     assert(llvm::is_contained(op->getResults(), result));
88 
89     LLVM_DEBUG(llvm::dbgs() << "Inferred range " << attrs << "\n");
90     IntegerValueRangeLattice *lattice = results[result.getResultNumber()];
91     Optional<IntegerValueRange> oldRange;
92     if (!lattice->isUninitialized())
93       oldRange = lattice->getValue();
94 
95     ChangeResult changed = lattice->join(attrs);
96 
97     // Catch loop results with loop variant bounds and conservatively make
98     // them [-inf, inf] so we don't circle around infinitely often (because
99     // the dataflow analysis in MLIR doesn't attempt to work out trip counts
100     // and often can't).
101     bool isYieldedResult = llvm::any_of(v.getUsers(), [](Operation *op) {
102       return op->hasTrait<OpTrait::IsTerminator>();
103     });
104     if (isYieldedResult && oldRange.has_value() &&
105         !(lattice->getValue() == *oldRange)) {
106       LLVM_DEBUG(llvm::dbgs() << "Loop variant loop result detected\n");
107       changed |= lattice->markPessimisticFixpoint();
108     }
109     propagateIfChanged(lattice, changed);
110   };
111 
112   inferrable.inferResultRanges(argRanges, joinCallback);
113 }
114 
visitNonControlFlowArguments(Operation * op,const RegionSuccessor & successor,ArrayRef<IntegerValueRangeLattice * > argLattices,unsigned firstIndex)115 void IntegerRangeAnalysis::visitNonControlFlowArguments(
116     Operation *op, const RegionSuccessor &successor,
117     ArrayRef<IntegerValueRangeLattice *> argLattices, unsigned firstIndex) {
118   if (auto inferrable = dyn_cast<InferIntRangeInterface>(op)) {
119     LLVM_DEBUG(llvm::dbgs() << "Inferring ranges for " << *op << "\n");
120     SmallVector<ConstantIntRanges> argRanges(
121         llvm::map_range(op->getOperands(), [&](Value value) {
122           return getLatticeElementFor(op, value)->getValue().getValue();
123         }));
124 
125     auto joinCallback = [&](Value v, const ConstantIntRanges &attrs) {
126       auto arg = v.dyn_cast<BlockArgument>();
127       if (!arg)
128         return;
129       if (!llvm::is_contained(successor.getSuccessor()->getArguments(), arg))
130         return;
131 
132       LLVM_DEBUG(llvm::dbgs() << "Inferred range " << attrs << "\n");
133       IntegerValueRangeLattice *lattice = argLattices[arg.getArgNumber()];
134       Optional<IntegerValueRange> oldRange;
135       if (!lattice->isUninitialized())
136         oldRange = lattice->getValue();
137 
138       ChangeResult changed = lattice->join(attrs);
139 
140       // Catch loop results with loop variant bounds and conservatively make
141       // them [-inf, inf] so we don't circle around infinitely often (because
142       // the dataflow analysis in MLIR doesn't attempt to work out trip counts
143       // and often can't).
144       bool isYieldedValue = llvm::any_of(v.getUsers(), [](Operation *op) {
145         return op->hasTrait<OpTrait::IsTerminator>();
146       });
147       if (isYieldedValue && oldRange && !(lattice->getValue() == *oldRange)) {
148         LLVM_DEBUG(llvm::dbgs() << "Loop variant loop result detected\n");
149         changed |= lattice->markPessimisticFixpoint();
150       }
151       propagateIfChanged(lattice, changed);
152     };
153 
154     inferrable.inferResultRanges(argRanges, joinCallback);
155     return;
156   }
157 
158   /// Given the results of getConstant{Lower,Upper}Bound() or getConstantStep()
159   /// on a LoopLikeInterface return the lower/upper bound for that result if
160   /// possible.
161   auto getLoopBoundFromFold = [&](Optional<OpFoldResult> loopBound,
162                                   Type boundType, bool getUpper) {
163     unsigned int width = ConstantIntRanges::getStorageBitwidth(boundType);
164     if (loopBound.has_value()) {
165       if (loopBound->is<Attribute>()) {
166         if (auto bound =
167                 loopBound->get<Attribute>().dyn_cast_or_null<IntegerAttr>())
168           return bound.getValue();
169       } else if (auto value = loopBound->dyn_cast<Value>()) {
170         const IntegerValueRangeLattice *lattice =
171             getLatticeElementFor(op, value);
172         if (lattice != nullptr)
173           return getUpper ? lattice->getValue().getValue().smax()
174                           : lattice->getValue().getValue().smin();
175       }
176     }
177     // Given the results of getConstant{Lower,Upper}Bound()
178     // or getConstantStep() on a LoopLikeInterface return the lower/upper
179     // bound
180     return getUpper ? APInt::getSignedMaxValue(width)
181                     : APInt::getSignedMinValue(width);
182   };
183 
184   // Infer bounds for loop arguments that have static bounds
185   if (auto loop = dyn_cast<LoopLikeOpInterface>(op)) {
186     Optional<Value> iv = loop.getSingleInductionVar();
187     if (!iv) {
188       return SparseDataFlowAnalysis ::visitNonControlFlowArguments(
189           op, successor, argLattices, firstIndex);
190     }
191     Optional<OpFoldResult> lowerBound = loop.getSingleLowerBound();
192     Optional<OpFoldResult> upperBound = loop.getSingleUpperBound();
193     Optional<OpFoldResult> step = loop.getSingleStep();
194     APInt min = getLoopBoundFromFold(lowerBound, iv->getType(),
195                                      /*getUpper=*/false);
196     APInt max = getLoopBoundFromFold(upperBound, iv->getType(),
197                                      /*getUpper=*/true);
198     // Assume positivity for uniscoverable steps by way of getUpper = true.
199     APInt stepVal =
200         getLoopBoundFromFold(step, iv->getType(), /*getUpper=*/true);
201 
202     if (stepVal.isNegative()) {
203       std::swap(min, max);
204     } else {
205       // Correct the upper bound by subtracting 1 so that it becomes a <=
206       // bound, because loops do not generally include their upper bound.
207       max -= 1;
208     }
209 
210     IntegerValueRangeLattice *ivEntry = getLatticeElement(*iv);
211     auto ivRange = ConstantIntRanges::fromSigned(min, max);
212     propagateIfChanged(ivEntry, ivEntry->join(ivRange));
213     return;
214   }
215 
216   return SparseDataFlowAnalysis::visitNonControlFlowArguments(
217       op, successor, argLattices, firstIndex);
218 }
219