//===- TestDenseDataFlowAnalysis.cpp - Test dense data flow analysis ------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8
9 #include "mlir/Analysis/DataFlow/ConstantPropagationAnalysis.h"
10 #include "mlir/Analysis/DataFlow/DeadCodeAnalysis.h"
11 #include "mlir/Analysis/DataFlow/DenseAnalysis.h"
12 #include "mlir/Analysis/DataFlow/SparseAnalysis.h"
13 #include "mlir/Interfaces/SideEffectInterfaces.h"
14 #include "mlir/Pass/Pass.h"
15
16 using namespace mlir;
17 using namespace mlir::dataflow;
18
19 namespace {
20 /// This lattice represents a single underlying value for an SSA value.
21 class UnderlyingValue {
22 public:
23 /// The pessimistic underlying value of a value is itself.
getPessimisticValueState(Value value)24 static UnderlyingValue getPessimisticValueState(Value value) {
25 return {value};
26 }
27
28 /// Create an underlying value state with a known underlying value.
UnderlyingValue(Value underlyingValue={})29 UnderlyingValue(Value underlyingValue = {})
30 : underlyingValue(underlyingValue) {}
31
32 /// Returns the underlying value.
getUnderlyingValue() const33 Value getUnderlyingValue() const { return underlyingValue; }
34
35 /// Join two underlying values. If there are conflicting underlying values,
36 /// go to the pessimistic value.
join(const UnderlyingValue & lhs,const UnderlyingValue & rhs)37 static UnderlyingValue join(const UnderlyingValue &lhs,
38 const UnderlyingValue &rhs) {
39 return lhs.underlyingValue == rhs.underlyingValue ? lhs : UnderlyingValue();
40 }
41
42 /// Compare underlying values.
operator ==(const UnderlyingValue & rhs) const43 bool operator==(const UnderlyingValue &rhs) const {
44 return underlyingValue == rhs.underlyingValue;
45 }
46
print(raw_ostream & os) const47 void print(raw_ostream &os) const { os << underlyingValue; }
48
49 private:
50 Value underlyingValue;
51 };
52
53 /// This lattice represents, for a given memory resource, the potential last
54 /// operations that modified the resource.
55 class LastModification : public AbstractDenseLattice {
56 public:
57 MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(LastModification)
58
59 using AbstractDenseLattice::AbstractDenseLattice;
60
61 /// The lattice is always initialized.
isUninitialized() const62 bool isUninitialized() const override { return false; }
63
64 /// Initialize the lattice. Does nothing.
defaultInitialize()65 ChangeResult defaultInitialize() override { return ChangeResult::NoChange; }
66
67 /// Mark the lattice as having reached its pessimistic fixpoint. That is, the
68 /// last modifications of all memory resources are unknown.
reset()69 ChangeResult reset() override {
70 if (lastMods.empty())
71 return ChangeResult::NoChange;
72 lastMods.clear();
73 return ChangeResult::Change;
74 }
75
76 /// The lattice is never at a fixpoint.
isAtFixpoint() const77 bool isAtFixpoint() const override { return false; }
78
79 /// Join the last modifications.
join(const AbstractDenseLattice & lattice)80 ChangeResult join(const AbstractDenseLattice &lattice) override {
81 const auto &rhs = static_cast<const LastModification &>(lattice);
82 ChangeResult result = ChangeResult::NoChange;
83 for (const auto &mod : rhs.lastMods) {
84 auto &lhsMod = lastMods[mod.first];
85 if (lhsMod != mod.second) {
86 lhsMod.insert(mod.second.begin(), mod.second.end());
87 result |= ChangeResult::Change;
88 }
89 }
90 return result;
91 }
92
93 /// Set the last modification of a value.
set(Value value,Operation * op)94 ChangeResult set(Value value, Operation *op) {
95 auto &lastMod = lastMods[value];
96 ChangeResult result = ChangeResult::NoChange;
97 if (lastMod.size() != 1 || *lastMod.begin() != op) {
98 result = ChangeResult::Change;
99 lastMod.clear();
100 lastMod.insert(op);
101 }
102 return result;
103 }
104
105 /// Get the last modifications of a value. Returns none if the last
106 /// modifications are not known.
getLastModifiers(Value value) const107 Optional<ArrayRef<Operation *>> getLastModifiers(Value value) const {
108 auto it = lastMods.find(value);
109 if (it == lastMods.end())
110 return {};
111 return it->second.getArrayRef();
112 }
113
print(raw_ostream & os) const114 void print(raw_ostream &os) const override {
115 for (const auto &lastMod : lastMods) {
116 os << lastMod.first << ":\n";
117 for (Operation *op : lastMod.second)
118 os << " " << *op << "\n";
119 }
120 }
121
122 private:
123 /// The potential last modifications of a memory resource. Use a set vector to
124 /// keep the results deterministic.
125 DenseMap<Value, SetVector<Operation *, SmallVector<Operation *, 2>,
126 SmallPtrSet<Operation *, 2>>>
127 lastMods;
128 };
129
130 class LastModifiedAnalysis : public DenseDataFlowAnalysis<LastModification> {
131 public:
132 using DenseDataFlowAnalysis::DenseDataFlowAnalysis;
133
134 /// Visit an operation. If the operation has no memory effects, then the state
135 /// is propagated with no change. If the operation allocates a resource, then
136 /// its reaching definitions is set to empty. If the operation writes to a
137 /// resource, then its reaching definition is set to the written value.
138 void visitOperation(Operation *op, const LastModification &before,
139 LastModification *after) override;
140 };
141
142 /// Define the lattice class explicitly to provide a type ID.
143 struct UnderlyingValueLattice : public Lattice<UnderlyingValue> {
144 MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(UnderlyingValueLattice)
145 using Lattice::Lattice;
146 };
147
148 /// An analysis that uses forwarding of values along control-flow and callgraph
149 /// edges to determine single underlying values for block arguments. This
150 /// analysis exists so that the test analysis and pass can test the behaviour of
151 /// the dense data-flow analysis on the callgraph.
152 class UnderlyingValueAnalysis
153 : public SparseDataFlowAnalysis<UnderlyingValueLattice> {
154 public:
155 using SparseDataFlowAnalysis::SparseDataFlowAnalysis;
156
157 /// The underlying value of the results of an operation are not known.
visitOperation(Operation * op,ArrayRef<const UnderlyingValueLattice * > operands,ArrayRef<UnderlyingValueLattice * > results)158 void visitOperation(Operation *op,
159 ArrayRef<const UnderlyingValueLattice *> operands,
160 ArrayRef<UnderlyingValueLattice *> results) override {
161 markAllPessimisticFixpoint(results);
162 }
163 };
164 } // end anonymous namespace
165
166 /// Look for the most underlying value of a value.
getMostUnderlyingValue(Value value,function_ref<const UnderlyingValueLattice * (Value)> getUnderlyingValueFn)167 static Value getMostUnderlyingValue(
168 Value value,
169 function_ref<const UnderlyingValueLattice *(Value)> getUnderlyingValueFn) {
170 const UnderlyingValueLattice *underlying;
171 do {
172 underlying = getUnderlyingValueFn(value);
173 if (!underlying || underlying->isUninitialized())
174 return {};
175 Value underlyingValue = underlying->getValue().getUnderlyingValue();
176 if (underlyingValue == value)
177 break;
178 value = underlyingValue;
179 } while (true);
180 return value;
181 }
182
visitOperation(Operation * op,const LastModification & before,LastModification * after)183 void LastModifiedAnalysis::visitOperation(Operation *op,
184 const LastModification &before,
185 LastModification *after) {
186 auto memory = dyn_cast<MemoryEffectOpInterface>(op);
187 // If we can't reason about the memory effects, then conservatively assume we
188 // can't deduce anything about the last modifications.
189 if (!memory)
190 return reset(after);
191
192 SmallVector<MemoryEffects::EffectInstance> effects;
193 memory.getEffects(effects);
194
195 ChangeResult result = after->join(before);
196 for (const auto &effect : effects) {
197 Value value = effect.getValue();
198
199 // If we see an effect on anything other than a value, assume we can't
200 // deduce anything about the last modifications.
201 if (!value)
202 return reset(after);
203
204 value = getMostUnderlyingValue(value, [&](Value value) {
205 return getOrCreateFor<UnderlyingValueLattice>(op, value);
206 });
207 if (!value)
208 return;
209
210 // Nothing to do for reads.
211 if (isa<MemoryEffects::Read>(effect.getEffect()))
212 continue;
213
214 result |= after->set(value, op);
215 }
216 propagateIfChanged(after, result);
217 }
218
219 namespace {
220 struct TestLastModifiedPass
221 : public PassWrapper<TestLastModifiedPass, OperationPass<>> {
MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID__anon39d6bac10311::TestLastModifiedPass222 MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(TestLastModifiedPass)
223
224 StringRef getArgument() const override { return "test-last-modified"; }
225
runOnOperation__anon39d6bac10311::TestLastModifiedPass226 void runOnOperation() override {
227 Operation *op = getOperation();
228
229 DataFlowSolver solver;
230 solver.load<DeadCodeAnalysis>();
231 solver.load<SparseConstantPropagation>();
232 solver.load<LastModifiedAnalysis>();
233 solver.load<UnderlyingValueAnalysis>();
234 if (failed(solver.initializeAndRun(op)))
235 return signalPassFailure();
236
237 raw_ostream &os = llvm::errs();
238
239 op->walk([&](Operation *op) {
240 auto tag = op->getAttrOfType<StringAttr>("tag");
241 if (!tag)
242 return;
243 os << "test_tag: " << tag.getValue() << ":\n";
244 const LastModification *lastMods =
245 solver.lookupState<LastModification>(op);
246 assert(lastMods && "expected a dense lattice");
247 for (auto &it : llvm::enumerate(op->getOperands())) {
248 os << " operand #" << it.index() << "\n";
249 Value value = getMostUnderlyingValue(it.value(), [&](Value value) {
250 return solver.lookupState<UnderlyingValueLattice>(value);
251 });
252 assert(value && "expected an underlying value");
253 if (Optional<ArrayRef<Operation *>> lastMod =
254 lastMods->getLastModifiers(value)) {
255 for (Operation *lastModifier : *lastMod) {
256 if (auto tagName =
257 lastModifier->getAttrOfType<StringAttr>("tag_name")) {
258 os << " - " << tagName.getValue() << "\n";
259 } else {
260 os << " - " << lastModifier->getName() << "\n";
261 }
262 }
263 } else {
264 os << " - <unknown>\n";
265 }
266 }
267 });
268 }
269 };
270 } // end anonymous namespace
271
272 namespace mlir {
273 namespace test {
registerTestLastModifiedPass()274 void registerTestLastModifiedPass() {
275 PassRegistration<TestLastModifiedPass>();
276 }
277 } // end namespace test
278 } // end namespace mlir
279