1 //===- IntegerRangeAnalysis.cpp - Integer range analysis --------*- C++ -*-===// 2 // 3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. 4 // See https://llvm.org/LICENSE.txt for license information. 5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception 6 // 7 //===----------------------------------------------------------------------===// 8 // 9 // This file defines the dataflow analysis class for integer range inference 10 // which is used in transformations over the `arith` dialect such as 11 // branch elimination or signed->unsigned rewriting 12 // 13 //===----------------------------------------------------------------------===// 14 15 #include "mlir/Analysis/DataFlow/IntegerRangeAnalysis.h" 16 #include "mlir/Analysis/DataFlow/ConstantPropagationAnalysis.h" 17 #include "mlir/Interfaces/InferIntRangeInterface.h" 18 #include "mlir/Interfaces/LoopLikeInterface.h" 19 #include "llvm/Support/Debug.h" 20 21 #define DEBUG_TYPE "int-range-analysis" 22 23 using namespace mlir; 24 using namespace mlir::dataflow; 25 26 IntegerValueRange IntegerValueRange::getPessimisticValueState(Value value) { 27 unsigned width = ConstantIntRanges::getStorageBitwidth(value.getType()); 28 APInt umin = APInt::getMinValue(width); 29 APInt umax = APInt::getMaxValue(width); 30 APInt smin = width != 0 ? APInt::getSignedMinValue(width) : umin; 31 APInt smax = width != 0 ? APInt::getSignedMaxValue(width) : umax; 32 return {{umin, umax, smin, smax}}; 33 } 34 35 void IntegerValueRangeLattice::onUpdate(DataFlowSolver *solver) const { 36 Lattice::onUpdate(solver); 37 38 // If the integer range can be narrowed to a constant, update the constant 39 // value of the SSA value. 40 Optional<APInt> constant = getValue().getValue().getConstantValue(); 41 auto value = point.get<Value>(); 42 auto *cv = solver->getOrCreateState<Lattice<ConstantValue>>(value); 43 if (!constant) 44 return solver->propagateIfChanged(cv, cv->markPessimisticFixpoint()); 45 46 Dialect *dialect; 47 if (auto *parent = value.getDefiningOp()) 48 dialect = parent->getDialect(); 49 else 50 dialect = value.getParentBlock()->getParentOp()->getDialect(); 51 solver->propagateIfChanged( 52 cv, cv->join(ConstantValue(IntegerAttr::get(value.getType(), *constant), 53 dialect))); 54 } 55 56 void IntegerRangeAnalysis::visitOperation( 57 Operation *op, ArrayRef<const IntegerValueRangeLattice *> operands, 58 ArrayRef<IntegerValueRangeLattice *> results) { 59 // Ignore non-integer outputs - return early if the op has no scalar 60 // integer results 61 bool hasIntegerResult = false; 62 for (auto it : llvm::zip(results, op->getResults())) { 63 if (std::get<1>(it).getType().isIntOrIndex()) { 64 hasIntegerResult = true; 65 } else { 66 propagateIfChanged(std::get<0>(it), 67 std::get<0>(it)->markPessimisticFixpoint()); 68 } 69 } 70 if (!hasIntegerResult) 71 return; 72 73 auto inferrable = dyn_cast<InferIntRangeInterface>(op); 74 if (!inferrable) 75 return markAllPessimisticFixpoint(results); 76 77 LLVM_DEBUG(llvm::dbgs() << "Inferring ranges for " << *op << "\n"); 78 SmallVector<ConstantIntRanges> argRanges( 79 llvm::map_range(operands, [](const IntegerValueRangeLattice *val) { 80 return val->getValue().getValue(); 81 })); 82 83 auto joinCallback = [&](Value v, const ConstantIntRanges &attrs) { 84 auto result = v.dyn_cast<OpResult>(); 85 if (!result) 86 return; 87 assert(llvm::find(op->getResults(), result) != op->result_end()); 88 89 LLVM_DEBUG(llvm::dbgs() << "Inferred range " << attrs << "\n"); 90 IntegerValueRangeLattice *lattice = results[result.getResultNumber()]; 91 Optional<IntegerValueRange> oldRange; 92 if (!lattice->isUninitialized()) 93 oldRange = lattice->getValue(); 94 95 ChangeResult changed = lattice->join(attrs); 96 97 // Catch loop results with loop variant bounds and conservatively make 98 // them [-inf, inf] so we don't circle around infinitely often (because 99 // the dataflow analysis in MLIR doesn't attempt to work out trip counts 100 // and often can't). 101 bool isYieldedResult = llvm::any_of(v.getUsers(), [](Operation *op) { 102 return op->hasTrait<OpTrait::IsTerminator>(); 103 }); 104 if (isYieldedResult && oldRange.has_value() && 105 !(lattice->getValue() == *oldRange)) { 106 LLVM_DEBUG(llvm::dbgs() << "Loop variant loop result detected\n"); 107 changed |= lattice->markPessimisticFixpoint(); 108 } 109 propagateIfChanged(lattice, changed); 110 }; 111 112 inferrable.inferResultRanges(argRanges, joinCallback); 113 } 114 115 void IntegerRangeAnalysis::visitNonControlFlowArguments( 116 Operation *op, const RegionSuccessor &successor, 117 ArrayRef<IntegerValueRangeLattice *> argLattices, unsigned firstIndex) { 118 if (auto inferrable = dyn_cast<InferIntRangeInterface>(op)) { 119 LLVM_DEBUG(llvm::dbgs() << "Inferring ranges for " << *op << "\n"); 120 SmallVector<ConstantIntRanges> argRanges( 121 llvm::map_range(op->getOperands(), [&](Value value) { 122 return getLatticeElementFor(op, value)->getValue().getValue(); 123 })); 124 125 auto joinCallback = [&](Value v, const ConstantIntRanges &attrs) { 126 auto arg = v.dyn_cast<BlockArgument>(); 127 if (!arg) 128 return; 129 if (llvm::find(successor.getSuccessor()->getArguments(), arg) == 130 successor.getSuccessor()->args_end()) 131 return; 132 133 LLVM_DEBUG(llvm::dbgs() << "Inferred range " << attrs << "\n"); 134 IntegerValueRangeLattice *lattice = argLattices[arg.getArgNumber()]; 135 Optional<IntegerValueRange> oldRange; 136 if (!lattice->isUninitialized()) 137 oldRange = lattice->getValue(); 138 139 ChangeResult changed = lattice->join(attrs); 140 141 // Catch loop results with loop variant bounds and conservatively make 142 // them [-inf, inf] so we don't circle around infinitely often (because 143 // the dataflow analysis in MLIR doesn't attempt to work out trip counts 144 // and often can't). 145 bool isYieldedValue = llvm::any_of(v.getUsers(), [](Operation *op) { 146 return op->hasTrait<OpTrait::IsTerminator>(); 147 }); 148 if (isYieldedValue && oldRange && !(lattice->getValue() == *oldRange)) { 149 LLVM_DEBUG(llvm::dbgs() << "Loop variant loop result detected\n"); 150 changed |= lattice->markPessimisticFixpoint(); 151 } 152 propagateIfChanged(lattice, changed); 153 }; 154 155 inferrable.inferResultRanges(argRanges, joinCallback); 156 return; 157 } 158 159 /// Given the results of getConstant{Lower,Upper}Bound() or getConstantStep() 160 /// on a LoopLikeInterface return the lower/upper bound for that result if 161 /// possible. 162 auto getLoopBoundFromFold = [&](Optional<OpFoldResult> loopBound, 163 Type boundType, bool getUpper) { 164 unsigned int width = ConstantIntRanges::getStorageBitwidth(boundType); 165 if (loopBound.has_value()) { 166 if (loopBound->is<Attribute>()) { 167 if (auto bound = 168 loopBound->get<Attribute>().dyn_cast_or_null<IntegerAttr>()) 169 return bound.getValue(); 170 } else if (auto value = loopBound->dyn_cast<Value>()) { 171 const IntegerValueRangeLattice *lattice = 172 getLatticeElementFor(op, value); 173 if (lattice != nullptr) 174 return getUpper ? lattice->getValue().getValue().smax() 175 : lattice->getValue().getValue().smin(); 176 } 177 } 178 // Given the results of getConstant{Lower,Upper}Bound() 179 // or getConstantStep() on a LoopLikeInterface return the lower/upper 180 // bound 181 return getUpper ? APInt::getSignedMaxValue(width) 182 : APInt::getSignedMinValue(width); 183 }; 184 185 // Infer bounds for loop arguments that have static bounds 186 if (auto loop = dyn_cast<LoopLikeOpInterface>(op)) { 187 Optional<Value> iv = loop.getSingleInductionVar(); 188 if (!iv) { 189 return SparseDataFlowAnalysis ::visitNonControlFlowArguments( 190 op, successor, argLattices, firstIndex); 191 } 192 Optional<OpFoldResult> lowerBound = loop.getSingleLowerBound(); 193 Optional<OpFoldResult> upperBound = loop.getSingleUpperBound(); 194 Optional<OpFoldResult> step = loop.getSingleStep(); 195 APInt min = getLoopBoundFromFold(lowerBound, iv->getType(), 196 /*getUpper=*/false); 197 APInt max = getLoopBoundFromFold(upperBound, iv->getType(), 198 /*getUpper=*/true); 199 // Assume positivity for uniscoverable steps by way of getUpper = true. 200 APInt stepVal = 201 getLoopBoundFromFold(step, iv->getType(), /*getUpper=*/true); 202 203 if (stepVal.isNegative()) { 204 std::swap(min, max); 205 } else { 206 // Correct the upper bound by subtracting 1 so that it becomes a <= 207 // bound, because loops do not generally include their upper bound. 208 max -= 1; 209 } 210 211 IntegerValueRangeLattice *ivEntry = getLatticeElement(*iv); 212 auto ivRange = ConstantIntRanges::fromSigned(min, max); 213 propagateIfChanged(ivEntry, ivEntry->join(ivRange)); 214 return; 215 } 216 217 return SparseDataFlowAnalysis::visitNonControlFlowArguments( 218 op, successor, argLattices, firstIndex); 219 } 220