1 //===- IntegerRangeAnalysis.cpp - Integer range analysis --------*- C++ -*-===// 2 // 3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. 4 // See https://llvm.org/LICENSE.txt for license information. 5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception 6 // 7 //===----------------------------------------------------------------------===// 8 // 9 // This file defines the dataflow analysis class for integer range inference 10 // which is used in transformations over the `arith` dialect such as 11 // branch elimination or signed->unsigned rewriting 12 // 13 //===----------------------------------------------------------------------===// 14 15 #include "mlir/Analysis/DataFlow/IntegerRangeAnalysis.h" 16 #include "mlir/Analysis/DataFlow/ConstantPropagationAnalysis.h" 17 #include "mlir/Interfaces/InferIntRangeInterface.h" 18 #include "mlir/Interfaces/LoopLikeInterface.h" 19 #include "llvm/Support/Debug.h" 20 21 #define DEBUG_TYPE "int-range-analysis" 22 23 using namespace mlir; 24 using namespace mlir::dataflow; 25 26 IntegerValueRange IntegerValueRange::getPessimisticValueState(Value value) { 27 unsigned width = ConstantIntRanges::getStorageBitwidth(value.getType()); 28 APInt umin = APInt::getMinValue(width); 29 APInt umax = APInt::getMaxValue(width); 30 APInt smin = width != 0 ? APInt::getSignedMinValue(width) : umin; 31 APInt smax = width != 0 ? APInt::getSignedMaxValue(width) : umax; 32 return {{umin, umax, smin, smax}}; 33 } 34 35 void IntegerValueRangeLattice::onUpdate(DataFlowSolver *solver) const { 36 Lattice::onUpdate(solver); 37 38 // If the integer range can be narrowed to a constant, update the constant 39 // value of the SSA value. 40 Optional<APInt> constant = getValue().getValue().getConstantValue(); 41 auto value = point.get<Value>(); 42 auto *cv = solver->getOrCreateState<Lattice<ConstantValue>>(value); 43 if (!constant) 44 return solver->propagateIfChanged(cv, cv->markPessimisticFixpoint()); 45 46 Dialect *dialect; 47 if (auto *parent = value.getDefiningOp()) 48 dialect = parent->getDialect(); 49 else 50 dialect = value.getParentBlock()->getParentOp()->getDialect(); 51 solver->propagateIfChanged( 52 cv, cv->join(ConstantValue(IntegerAttr::get(value.getType(), *constant), 53 dialect))); 54 } 55 56 void IntegerRangeAnalysis::visitOperation( 57 Operation *op, ArrayRef<const IntegerValueRangeLattice *> operands, 58 ArrayRef<IntegerValueRangeLattice *> results) { 59 // Ignore non-integer outputs - return early if the op has no scalar 60 // integer results 61 bool hasIntegerResult = false; 62 for (auto it : llvm::zip(results, op->getResults())) { 63 if (std::get<1>(it).getType().isIntOrIndex()) { 64 hasIntegerResult = true; 65 } else { 66 propagateIfChanged(std::get<0>(it), 67 std::get<0>(it)->markPessimisticFixpoint()); 68 } 69 } 70 if (!hasIntegerResult) 71 return; 72 73 auto inferrable = dyn_cast<InferIntRangeInterface>(op); 74 if (!inferrable) 75 return markAllPessimisticFixpoint(results); 76 77 LLVM_DEBUG(llvm::dbgs() << "Inferring ranges for " << *op << "\n"); 78 SmallVector<ConstantIntRanges> argRanges( 79 llvm::map_range(operands, [](const IntegerValueRangeLattice *val) { 80 return val->getValue().getValue(); 81 })); 82 83 auto joinCallback = [&](Value v, const ConstantIntRanges &attrs) { 84 auto result = v.dyn_cast<OpResult>(); 85 if (!result) 86 return; 87 assert(llvm::is_contained(op->getResults(), result)); 88 89 LLVM_DEBUG(llvm::dbgs() << "Inferred range " << attrs << "\n"); 90 IntegerValueRangeLattice *lattice = results[result.getResultNumber()]; 91 Optional<IntegerValueRange> oldRange; 92 if (!lattice->isUninitialized()) 93 oldRange = lattice->getValue(); 94 95 ChangeResult changed = lattice->join(attrs); 96 97 // Catch loop results with loop variant bounds and conservatively make 98 // them [-inf, inf] so we don't circle around infinitely often (because 99 // the dataflow analysis in MLIR doesn't attempt to work out trip counts 100 // and often can't). 101 bool isYieldedResult = llvm::any_of(v.getUsers(), [](Operation *op) { 102 return op->hasTrait<OpTrait::IsTerminator>(); 103 }); 104 if (isYieldedResult && oldRange.has_value() && 105 !(lattice->getValue() == *oldRange)) { 106 LLVM_DEBUG(llvm::dbgs() << "Loop variant loop result detected\n"); 107 changed |= lattice->markPessimisticFixpoint(); 108 } 109 propagateIfChanged(lattice, changed); 110 }; 111 112 inferrable.inferResultRanges(argRanges, joinCallback); 113 } 114 115 void IntegerRangeAnalysis::visitNonControlFlowArguments( 116 Operation *op, const RegionSuccessor &successor, 117 ArrayRef<IntegerValueRangeLattice *> argLattices, unsigned firstIndex) { 118 if (auto inferrable = dyn_cast<InferIntRangeInterface>(op)) { 119 LLVM_DEBUG(llvm::dbgs() << "Inferring ranges for " << *op << "\n"); 120 SmallVector<ConstantIntRanges> argRanges( 121 llvm::map_range(op->getOperands(), [&](Value value) { 122 return getLatticeElementFor(op, value)->getValue().getValue(); 123 })); 124 125 auto joinCallback = [&](Value v, const ConstantIntRanges &attrs) { 126 auto arg = v.dyn_cast<BlockArgument>(); 127 if (!arg) 128 return; 129 if (!llvm::is_contained(successor.getSuccessor()->getArguments(), arg)) 130 return; 131 132 LLVM_DEBUG(llvm::dbgs() << "Inferred range " << attrs << "\n"); 133 IntegerValueRangeLattice *lattice = argLattices[arg.getArgNumber()]; 134 Optional<IntegerValueRange> oldRange; 135 if (!lattice->isUninitialized()) 136 oldRange = lattice->getValue(); 137 138 ChangeResult changed = lattice->join(attrs); 139 140 // Catch loop results with loop variant bounds and conservatively make 141 // them [-inf, inf] so we don't circle around infinitely often (because 142 // the dataflow analysis in MLIR doesn't attempt to work out trip counts 143 // and often can't). 144 bool isYieldedValue = llvm::any_of(v.getUsers(), [](Operation *op) { 145 return op->hasTrait<OpTrait::IsTerminator>(); 146 }); 147 if (isYieldedValue && oldRange && !(lattice->getValue() == *oldRange)) { 148 LLVM_DEBUG(llvm::dbgs() << "Loop variant loop result detected\n"); 149 changed |= lattice->markPessimisticFixpoint(); 150 } 151 propagateIfChanged(lattice, changed); 152 }; 153 154 inferrable.inferResultRanges(argRanges, joinCallback); 155 return; 156 } 157 158 /// Given the results of getConstant{Lower,Upper}Bound() or getConstantStep() 159 /// on a LoopLikeInterface return the lower/upper bound for that result if 160 /// possible. 161 auto getLoopBoundFromFold = [&](Optional<OpFoldResult> loopBound, 162 Type boundType, bool getUpper) { 163 unsigned int width = ConstantIntRanges::getStorageBitwidth(boundType); 164 if (loopBound.has_value()) { 165 if (loopBound->is<Attribute>()) { 166 if (auto bound = 167 loopBound->get<Attribute>().dyn_cast_or_null<IntegerAttr>()) 168 return bound.getValue(); 169 } else if (auto value = loopBound->dyn_cast<Value>()) { 170 const IntegerValueRangeLattice *lattice = 171 getLatticeElementFor(op, value); 172 if (lattice != nullptr) 173 return getUpper ? lattice->getValue().getValue().smax() 174 : lattice->getValue().getValue().smin(); 175 } 176 } 177 // Given the results of getConstant{Lower,Upper}Bound() 178 // or getConstantStep() on a LoopLikeInterface return the lower/upper 179 // bound 180 return getUpper ? APInt::getSignedMaxValue(width) 181 : APInt::getSignedMinValue(width); 182 }; 183 184 // Infer bounds for loop arguments that have static bounds 185 if (auto loop = dyn_cast<LoopLikeOpInterface>(op)) { 186 Optional<Value> iv = loop.getSingleInductionVar(); 187 if (!iv) { 188 return SparseDataFlowAnalysis ::visitNonControlFlowArguments( 189 op, successor, argLattices, firstIndex); 190 } 191 Optional<OpFoldResult> lowerBound = loop.getSingleLowerBound(); 192 Optional<OpFoldResult> upperBound = loop.getSingleUpperBound(); 193 Optional<OpFoldResult> step = loop.getSingleStep(); 194 APInt min = getLoopBoundFromFold(lowerBound, iv->getType(), 195 /*getUpper=*/false); 196 APInt max = getLoopBoundFromFold(upperBound, iv->getType(), 197 /*getUpper=*/true); 198 // Assume positivity for uniscoverable steps by way of getUpper = true. 199 APInt stepVal = 200 getLoopBoundFromFold(step, iv->getType(), /*getUpper=*/true); 201 202 if (stepVal.isNegative()) { 203 std::swap(min, max); 204 } else { 205 // Correct the upper bound by subtracting 1 so that it becomes a <= 206 // bound, because loops do not generally include their upper bound. 207 max -= 1; 208 } 209 210 IntegerValueRangeLattice *ivEntry = getLatticeElement(*iv); 211 auto ivRange = ConstantIntRanges::fromSigned(min, max); 212 propagateIfChanged(ivEntry, ivEntry->join(ivRange)); 213 return; 214 } 215 216 return SparseDataFlowAnalysis::visitNonControlFlowArguments( 217 op, successor, argLattices, firstIndex); 218 } 219