1ab701975SMogball //===- IntegerRangeAnalysis.cpp - Integer range analysis --------*- C++ -*-===//
2ab701975SMogball //
3ab701975SMogball // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4ab701975SMogball // See https://llvm.org/LICENSE.txt for license information.
5ab701975SMogball // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6ab701975SMogball //
7ab701975SMogball //===----------------------------------------------------------------------===//
8ab701975SMogball //
9ab701975SMogball // This file defines the dataflow analysis class for integer range inference
10ab701975SMogball // which is used in transformations over the `arith` dialect such as
11ab701975SMogball // branch elimination or signed->unsigned rewriting
12ab701975SMogball //
13ab701975SMogball //===----------------------------------------------------------------------===//
14ab701975SMogball 
15ab701975SMogball #include "mlir/Analysis/DataFlow/IntegerRangeAnalysis.h"
16ab701975SMogball #include "mlir/Analysis/DataFlow/ConstantPropagationAnalysis.h"
17ab701975SMogball #include "mlir/Interfaces/InferIntRangeInterface.h"
18ab701975SMogball #include "mlir/Interfaces/LoopLikeInterface.h"
19ab701975SMogball #include "llvm/Support/Debug.h"
20ab701975SMogball 
21ab701975SMogball #define DEBUG_TYPE "int-range-analysis"
22ab701975SMogball 
23ab701975SMogball using namespace mlir;
24ab701975SMogball using namespace mlir::dataflow;
25ab701975SMogball 
getPessimisticValueState(Value value)26ab701975SMogball IntegerValueRange IntegerValueRange::getPessimisticValueState(Value value) {
27ab701975SMogball   unsigned width = ConstantIntRanges::getStorageBitwidth(value.getType());
28ab701975SMogball   APInt umin = APInt::getMinValue(width);
29ab701975SMogball   APInt umax = APInt::getMaxValue(width);
30ab701975SMogball   APInt smin = width != 0 ? APInt::getSignedMinValue(width) : umin;
31ab701975SMogball   APInt smax = width != 0 ? APInt::getSignedMaxValue(width) : umax;
32ab701975SMogball   return {{umin, umax, smin, smax}};
33ab701975SMogball }
34ab701975SMogball 
onUpdate(DataFlowSolver * solver) const35ab701975SMogball void IntegerValueRangeLattice::onUpdate(DataFlowSolver *solver) const {
36ab701975SMogball   Lattice::onUpdate(solver);
37ab701975SMogball 
38ab701975SMogball   // If the integer range can be narrowed to a constant, update the constant
39ab701975SMogball   // value of the SSA value.
40ab701975SMogball   Optional<APInt> constant = getValue().getValue().getConstantValue();
41ab701975SMogball   auto value = point.get<Value>();
42ab701975SMogball   auto *cv = solver->getOrCreateState<Lattice<ConstantValue>>(value);
43ab701975SMogball   if (!constant)
44ab701975SMogball     return solver->propagateIfChanged(cv, cv->markPessimisticFixpoint());
45ab701975SMogball 
46ab701975SMogball   Dialect *dialect;
47ab701975SMogball   if (auto *parent = value.getDefiningOp())
48ab701975SMogball     dialect = parent->getDialect();
49ab701975SMogball   else
50ab701975SMogball     dialect = value.getParentBlock()->getParentOp()->getDialect();
51ab701975SMogball   solver->propagateIfChanged(
52ab701975SMogball       cv, cv->join(ConstantValue(IntegerAttr::get(value.getType(), *constant),
53ab701975SMogball                                  dialect)));
54ab701975SMogball }
55ab701975SMogball 
visitOperation(Operation * op,ArrayRef<const IntegerValueRangeLattice * > operands,ArrayRef<IntegerValueRangeLattice * > results)56ab701975SMogball void IntegerRangeAnalysis::visitOperation(
57ab701975SMogball     Operation *op, ArrayRef<const IntegerValueRangeLattice *> operands,
58ab701975SMogball     ArrayRef<IntegerValueRangeLattice *> results) {
59ab701975SMogball   // Ignore non-integer outputs - return early if the op has no scalar
60ab701975SMogball   // integer results
61ab701975SMogball   bool hasIntegerResult = false;
62ab701975SMogball   for (auto it : llvm::zip(results, op->getResults())) {
63ab701975SMogball     if (std::get<1>(it).getType().isIntOrIndex()) {
64ab701975SMogball       hasIntegerResult = true;
65ab701975SMogball     } else {
66ab701975SMogball       propagateIfChanged(std::get<0>(it),
67ab701975SMogball                          std::get<0>(it)->markPessimisticFixpoint());
68ab701975SMogball     }
69ab701975SMogball   }
70ab701975SMogball   if (!hasIntegerResult)
71ab701975SMogball     return;
72ab701975SMogball 
73ab701975SMogball   auto inferrable = dyn_cast<InferIntRangeInterface>(op);
74ab701975SMogball   if (!inferrable)
75ab701975SMogball     return markAllPessimisticFixpoint(results);
76ab701975SMogball 
77ab701975SMogball   LLVM_DEBUG(llvm::dbgs() << "Inferring ranges for " << *op << "\n");
78ab701975SMogball   SmallVector<ConstantIntRanges> argRanges(
79ab701975SMogball       llvm::map_range(operands, [](const IntegerValueRangeLattice *val) {
80ab701975SMogball         return val->getValue().getValue();
81ab701975SMogball       }));
82ab701975SMogball 
83ab701975SMogball   auto joinCallback = [&](Value v, const ConstantIntRanges &attrs) {
84ab701975SMogball     auto result = v.dyn_cast<OpResult>();
85ab701975SMogball     if (!result)
86ab701975SMogball       return;
87*360c1111SKazu Hirata     assert(llvm::is_contained(op->getResults(), result));
88ab701975SMogball 
89ab701975SMogball     LLVM_DEBUG(llvm::dbgs() << "Inferred range " << attrs << "\n");
90ab701975SMogball     IntegerValueRangeLattice *lattice = results[result.getResultNumber()];
91ab701975SMogball     Optional<IntegerValueRange> oldRange;
92ab701975SMogball     if (!lattice->isUninitialized())
93ab701975SMogball       oldRange = lattice->getValue();
94ab701975SMogball 
95ab701975SMogball     ChangeResult changed = lattice->join(attrs);
96ab701975SMogball 
97ab701975SMogball     // Catch loop results with loop variant bounds and conservatively make
98ab701975SMogball     // them [-inf, inf] so we don't circle around infinitely often (because
99ab701975SMogball     // the dataflow analysis in MLIR doesn't attempt to work out trip counts
100ab701975SMogball     // and often can't).
101ab701975SMogball     bool isYieldedResult = llvm::any_of(v.getUsers(), [](Operation *op) {
102ab701975SMogball       return op->hasTrait<OpTrait::IsTerminator>();
103ab701975SMogball     });
104491d2701SKazu Hirata     if (isYieldedResult && oldRange.has_value() &&
105ab701975SMogball         !(lattice->getValue() == *oldRange)) {
106ab701975SMogball       LLVM_DEBUG(llvm::dbgs() << "Loop variant loop result detected\n");
107ab701975SMogball       changed |= lattice->markPessimisticFixpoint();
108ab701975SMogball     }
109ab701975SMogball     propagateIfChanged(lattice, changed);
110ab701975SMogball   };
111ab701975SMogball 
112ab701975SMogball   inferrable.inferResultRanges(argRanges, joinCallback);
113ab701975SMogball }
114ab701975SMogball 
visitNonControlFlowArguments(Operation * op,const RegionSuccessor & successor,ArrayRef<IntegerValueRangeLattice * > argLattices,unsigned firstIndex)115ab701975SMogball void IntegerRangeAnalysis::visitNonControlFlowArguments(
116ab701975SMogball     Operation *op, const RegionSuccessor &successor,
117ab701975SMogball     ArrayRef<IntegerValueRangeLattice *> argLattices, unsigned firstIndex) {
118ab701975SMogball   if (auto inferrable = dyn_cast<InferIntRangeInterface>(op)) {
119ab701975SMogball     LLVM_DEBUG(llvm::dbgs() << "Inferring ranges for " << *op << "\n");
120ab701975SMogball     SmallVector<ConstantIntRanges> argRanges(
121ab701975SMogball         llvm::map_range(op->getOperands(), [&](Value value) {
122ab701975SMogball           return getLatticeElementFor(op, value)->getValue().getValue();
123ab701975SMogball         }));
124ab701975SMogball 
125ab701975SMogball     auto joinCallback = [&](Value v, const ConstantIntRanges &attrs) {
126ab701975SMogball       auto arg = v.dyn_cast<BlockArgument>();
127ab701975SMogball       if (!arg)
128ab701975SMogball         return;
129*360c1111SKazu Hirata       if (!llvm::is_contained(successor.getSuccessor()->getArguments(), arg))
130ab701975SMogball         return;
131ab701975SMogball 
132ab701975SMogball       LLVM_DEBUG(llvm::dbgs() << "Inferred range " << attrs << "\n");
133ab701975SMogball       IntegerValueRangeLattice *lattice = argLattices[arg.getArgNumber()];
134ab701975SMogball       Optional<IntegerValueRange> oldRange;
135ab701975SMogball       if (!lattice->isUninitialized())
136ab701975SMogball         oldRange = lattice->getValue();
137ab701975SMogball 
138ab701975SMogball       ChangeResult changed = lattice->join(attrs);
139ab701975SMogball 
140ab701975SMogball       // Catch loop results with loop variant bounds and conservatively make
141ab701975SMogball       // them [-inf, inf] so we don't circle around infinitely often (because
142ab701975SMogball       // the dataflow analysis in MLIR doesn't attempt to work out trip counts
143ab701975SMogball       // and often can't).
144ab701975SMogball       bool isYieldedValue = llvm::any_of(v.getUsers(), [](Operation *op) {
145ab701975SMogball         return op->hasTrait<OpTrait::IsTerminator>();
146ab701975SMogball       });
147ab701975SMogball       if (isYieldedValue && oldRange && !(lattice->getValue() == *oldRange)) {
148ab701975SMogball         LLVM_DEBUG(llvm::dbgs() << "Loop variant loop result detected\n");
149ab701975SMogball         changed |= lattice->markPessimisticFixpoint();
150ab701975SMogball       }
151ab701975SMogball       propagateIfChanged(lattice, changed);
152ab701975SMogball     };
153ab701975SMogball 
154ab701975SMogball     inferrable.inferResultRanges(argRanges, joinCallback);
155ab701975SMogball     return;
156ab701975SMogball   }
157ab701975SMogball 
158ab701975SMogball   /// Given the results of getConstant{Lower,Upper}Bound() or getConstantStep()
159ab701975SMogball   /// on a LoopLikeInterface return the lower/upper bound for that result if
160ab701975SMogball   /// possible.
161ab701975SMogball   auto getLoopBoundFromFold = [&](Optional<OpFoldResult> loopBound,
162ab701975SMogball                                   Type boundType, bool getUpper) {
163ab701975SMogball     unsigned int width = ConstantIntRanges::getStorageBitwidth(boundType);
164491d2701SKazu Hirata     if (loopBound.has_value()) {
165ab701975SMogball       if (loopBound->is<Attribute>()) {
166ab701975SMogball         if (auto bound =
167ab701975SMogball                 loopBound->get<Attribute>().dyn_cast_or_null<IntegerAttr>())
168ab701975SMogball           return bound.getValue();
169ab701975SMogball       } else if (auto value = loopBound->dyn_cast<Value>()) {
170ab701975SMogball         const IntegerValueRangeLattice *lattice =
171ab701975SMogball             getLatticeElementFor(op, value);
172ab701975SMogball         if (lattice != nullptr)
173ab701975SMogball           return getUpper ? lattice->getValue().getValue().smax()
174ab701975SMogball                           : lattice->getValue().getValue().smin();
175ab701975SMogball       }
176ab701975SMogball     }
177ab701975SMogball     // Given the results of getConstant{Lower,Upper}Bound()
178ab701975SMogball     // or getConstantStep() on a LoopLikeInterface return the lower/upper
179ab701975SMogball     // bound
180ab701975SMogball     return getUpper ? APInt::getSignedMaxValue(width)
181ab701975SMogball                     : APInt::getSignedMinValue(width);
182ab701975SMogball   };
183ab701975SMogball 
184ab701975SMogball   // Infer bounds for loop arguments that have static bounds
185ab701975SMogball   if (auto loop = dyn_cast<LoopLikeOpInterface>(op)) {
186ab701975SMogball     Optional<Value> iv = loop.getSingleInductionVar();
187ab701975SMogball     if (!iv) {
188ab701975SMogball       return SparseDataFlowAnalysis ::visitNonControlFlowArguments(
189ab701975SMogball           op, successor, argLattices, firstIndex);
190ab701975SMogball     }
191ab701975SMogball     Optional<OpFoldResult> lowerBound = loop.getSingleLowerBound();
192ab701975SMogball     Optional<OpFoldResult> upperBound = loop.getSingleUpperBound();
193ab701975SMogball     Optional<OpFoldResult> step = loop.getSingleStep();
194ab701975SMogball     APInt min = getLoopBoundFromFold(lowerBound, iv->getType(),
195ab701975SMogball                                      /*getUpper=*/false);
196ab701975SMogball     APInt max = getLoopBoundFromFold(upperBound, iv->getType(),
197ab701975SMogball                                      /*getUpper=*/true);
198ab701975SMogball     // Assume positivity for uniscoverable steps by way of getUpper = true.
199ab701975SMogball     APInt stepVal =
200ab701975SMogball         getLoopBoundFromFold(step, iv->getType(), /*getUpper=*/true);
201ab701975SMogball 
202ab701975SMogball     if (stepVal.isNegative()) {
203ab701975SMogball       std::swap(min, max);
204ab701975SMogball     } else {
205ab701975SMogball       // Correct the upper bound by subtracting 1 so that it becomes a <=
206ab701975SMogball       // bound, because loops do not generally include their upper bound.
207ab701975SMogball       max -= 1;
208ab701975SMogball     }
209ab701975SMogball 
210ab701975SMogball     IntegerValueRangeLattice *ivEntry = getLatticeElement(*iv);
211ab701975SMogball     auto ivRange = ConstantIntRanges::fromSigned(min, max);
212ab701975SMogball     propagateIfChanged(ivEntry, ivEntry->join(ivRange));
213ab701975SMogball     return;
214ab701975SMogball   }
215ab701975SMogball 
216ab701975SMogball   return SparseDataFlowAnalysis::visitNonControlFlowArguments(
217ab701975SMogball       op, successor, argLattices, firstIndex);
218ab701975SMogball }
219