1 //===- IntegerRangeAnalysis.cpp - Integer range analysis --------*- C++ -*-===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 //
9 // This file defines the dataflow analysis class for integer range inference
10 // which is used in transformations over the `arith` dialect such as
11 // branch elimination or signed->unsigned rewriting
12 //
13 //===----------------------------------------------------------------------===//
14 
15 #include "mlir/Analysis/DataFlow/IntegerRangeAnalysis.h"
16 #include "mlir/Analysis/DataFlow/ConstantPropagationAnalysis.h"
17 #include "mlir/Interfaces/InferIntRangeInterface.h"
18 #include "mlir/Interfaces/LoopLikeInterface.h"
19 #include "llvm/Support/Debug.h"
20 
21 #define DEBUG_TYPE "int-range-analysis"
22 
23 using namespace mlir;
24 using namespace mlir::dataflow;
25 
26 IntegerValueRange IntegerValueRange::getPessimisticValueState(Value value) {
27   unsigned width = ConstantIntRanges::getStorageBitwidth(value.getType());
28   APInt umin = APInt::getMinValue(width);
29   APInt umax = APInt::getMaxValue(width);
30   APInt smin = width != 0 ? APInt::getSignedMinValue(width) : umin;
31   APInt smax = width != 0 ? APInt::getSignedMaxValue(width) : umax;
32   return {{umin, umax, smin, smax}};
33 }
34 
35 void IntegerValueRangeLattice::onUpdate(DataFlowSolver *solver) const {
36   Lattice::onUpdate(solver);
37 
38   // If the integer range can be narrowed to a constant, update the constant
39   // value of the SSA value.
40   Optional<APInt> constant = getValue().getValue().getConstantValue();
41   auto value = point.get<Value>();
42   auto *cv = solver->getOrCreateState<Lattice<ConstantValue>>(value);
43   if (!constant)
44     return solver->propagateIfChanged(cv, cv->markPessimisticFixpoint());
45 
46   Dialect *dialect;
47   if (auto *parent = value.getDefiningOp())
48     dialect = parent->getDialect();
49   else
50     dialect = value.getParentBlock()->getParentOp()->getDialect();
51   solver->propagateIfChanged(
52       cv, cv->join(ConstantValue(IntegerAttr::get(value.getType(), *constant),
53                                  dialect)));
54 }
55 
56 void IntegerRangeAnalysis::visitOperation(
57     Operation *op, ArrayRef<const IntegerValueRangeLattice *> operands,
58     ArrayRef<IntegerValueRangeLattice *> results) {
59   // Ignore non-integer outputs - return early if the op has no scalar
60   // integer results
61   bool hasIntegerResult = false;
62   for (auto it : llvm::zip(results, op->getResults())) {
63     if (std::get<1>(it).getType().isIntOrIndex()) {
64       hasIntegerResult = true;
65     } else {
66       propagateIfChanged(std::get<0>(it),
67                          std::get<0>(it)->markPessimisticFixpoint());
68     }
69   }
70   if (!hasIntegerResult)
71     return;
72 
73   auto inferrable = dyn_cast<InferIntRangeInterface>(op);
74   if (!inferrable)
75     return markAllPessimisticFixpoint(results);
76 
77   LLVM_DEBUG(llvm::dbgs() << "Inferring ranges for " << *op << "\n");
78   SmallVector<ConstantIntRanges> argRanges(
79       llvm::map_range(operands, [](const IntegerValueRangeLattice *val) {
80         return val->getValue().getValue();
81       }));
82 
83   auto joinCallback = [&](Value v, const ConstantIntRanges &attrs) {
84     auto result = v.dyn_cast<OpResult>();
85     if (!result)
86       return;
87     assert(llvm::find(op->getResults(), result) != op->result_end());
88 
89     LLVM_DEBUG(llvm::dbgs() << "Inferred range " << attrs << "\n");
90     IntegerValueRangeLattice *lattice = results[result.getResultNumber()];
91     Optional<IntegerValueRange> oldRange;
92     if (!lattice->isUninitialized())
93       oldRange = lattice->getValue();
94 
95     ChangeResult changed = lattice->join(attrs);
96 
97     // Catch loop results with loop variant bounds and conservatively make
98     // them [-inf, inf] so we don't circle around infinitely often (because
99     // the dataflow analysis in MLIR doesn't attempt to work out trip counts
100     // and often can't).
101     bool isYieldedResult = llvm::any_of(v.getUsers(), [](Operation *op) {
102       return op->hasTrait<OpTrait::IsTerminator>();
103     });
104     if (isYieldedResult && oldRange.has_value() &&
105         !(lattice->getValue() == *oldRange)) {
106       LLVM_DEBUG(llvm::dbgs() << "Loop variant loop result detected\n");
107       changed |= lattice->markPessimisticFixpoint();
108     }
109     propagateIfChanged(lattice, changed);
110   };
111 
112   inferrable.inferResultRanges(argRanges, joinCallback);
113 }
114 
115 void IntegerRangeAnalysis::visitNonControlFlowArguments(
116     Operation *op, const RegionSuccessor &successor,
117     ArrayRef<IntegerValueRangeLattice *> argLattices, unsigned firstIndex) {
118   if (auto inferrable = dyn_cast<InferIntRangeInterface>(op)) {
119     LLVM_DEBUG(llvm::dbgs() << "Inferring ranges for " << *op << "\n");
120     SmallVector<ConstantIntRanges> argRanges(
121         llvm::map_range(op->getOperands(), [&](Value value) {
122           return getLatticeElementFor(op, value)->getValue().getValue();
123         }));
124 
125     auto joinCallback = [&](Value v, const ConstantIntRanges &attrs) {
126       auto arg = v.dyn_cast<BlockArgument>();
127       if (!arg)
128         return;
129       if (llvm::find(successor.getSuccessor()->getArguments(), arg) ==
130           successor.getSuccessor()->args_end())
131         return;
132 
133       LLVM_DEBUG(llvm::dbgs() << "Inferred range " << attrs << "\n");
134       IntegerValueRangeLattice *lattice = argLattices[arg.getArgNumber()];
135       Optional<IntegerValueRange> oldRange;
136       if (!lattice->isUninitialized())
137         oldRange = lattice->getValue();
138 
139       ChangeResult changed = lattice->join(attrs);
140 
141       // Catch loop results with loop variant bounds and conservatively make
142       // them [-inf, inf] so we don't circle around infinitely often (because
143       // the dataflow analysis in MLIR doesn't attempt to work out trip counts
144       // and often can't).
145       bool isYieldedValue = llvm::any_of(v.getUsers(), [](Operation *op) {
146         return op->hasTrait<OpTrait::IsTerminator>();
147       });
148       if (isYieldedValue && oldRange && !(lattice->getValue() == *oldRange)) {
149         LLVM_DEBUG(llvm::dbgs() << "Loop variant loop result detected\n");
150         changed |= lattice->markPessimisticFixpoint();
151       }
152       propagateIfChanged(lattice, changed);
153     };
154 
155     inferrable.inferResultRanges(argRanges, joinCallback);
156     return;
157   }
158 
159   /// Given the results of getConstant{Lower,Upper}Bound() or getConstantStep()
160   /// on a LoopLikeInterface return the lower/upper bound for that result if
161   /// possible.
162   auto getLoopBoundFromFold = [&](Optional<OpFoldResult> loopBound,
163                                   Type boundType, bool getUpper) {
164     unsigned int width = ConstantIntRanges::getStorageBitwidth(boundType);
165     if (loopBound.has_value()) {
166       if (loopBound->is<Attribute>()) {
167         if (auto bound =
168                 loopBound->get<Attribute>().dyn_cast_or_null<IntegerAttr>())
169           return bound.getValue();
170       } else if (auto value = loopBound->dyn_cast<Value>()) {
171         const IntegerValueRangeLattice *lattice =
172             getLatticeElementFor(op, value);
173         if (lattice != nullptr)
174           return getUpper ? lattice->getValue().getValue().smax()
175                           : lattice->getValue().getValue().smin();
176       }
177     }
178     // Given the results of getConstant{Lower,Upper}Bound()
179     // or getConstantStep() on a LoopLikeInterface return the lower/upper
180     // bound
181     return getUpper ? APInt::getSignedMaxValue(width)
182                     : APInt::getSignedMinValue(width);
183   };
184 
185   // Infer bounds for loop arguments that have static bounds
186   if (auto loop = dyn_cast<LoopLikeOpInterface>(op)) {
187     Optional<Value> iv = loop.getSingleInductionVar();
188     if (!iv) {
189       return SparseDataFlowAnalysis ::visitNonControlFlowArguments(
190           op, successor, argLattices, firstIndex);
191     }
192     Optional<OpFoldResult> lowerBound = loop.getSingleLowerBound();
193     Optional<OpFoldResult> upperBound = loop.getSingleUpperBound();
194     Optional<OpFoldResult> step = loop.getSingleStep();
195     APInt min = getLoopBoundFromFold(lowerBound, iv->getType(),
196                                      /*getUpper=*/false);
197     APInt max = getLoopBoundFromFold(upperBound, iv->getType(),
198                                      /*getUpper=*/true);
199     // Assume positivity for uniscoverable steps by way of getUpper = true.
200     APInt stepVal =
201         getLoopBoundFromFold(step, iv->getType(), /*getUpper=*/true);
202 
203     if (stepVal.isNegative()) {
204       std::swap(min, max);
205     } else {
206       // Correct the upper bound by subtracting 1 so that it becomes a <=
207       // bound, because loops do not generally include their upper bound.
208       max -= 1;
209     }
210 
211     IntegerValueRangeLattice *ivEntry = getLatticeElement(*iv);
212     auto ivRange = ConstantIntRanges::fromSigned(min, max);
213     propagateIfChanged(ivEntry, ivEntry->join(ivRange));
214     return;
215   }
216 
217   return SparseDataFlowAnalysis::visitNonControlFlowArguments(
218       op, successor, argLattices, firstIndex);
219 }
220