//===- TestDenseDataFlowAnalysis.cpp - Test dense data flow analysis ------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 
9 #include "mlir/Analysis/DataFlow/ConstantPropagationAnalysis.h"
10 #include "mlir/Analysis/DataFlow/DeadCodeAnalysis.h"
11 #include "mlir/Analysis/DataFlow/DenseAnalysis.h"
12 #include "mlir/Analysis/DataFlow/SparseAnalysis.h"
13 #include "mlir/Interfaces/SideEffectInterfaces.h"
14 #include "mlir/Pass/Pass.h"
15 
16 using namespace mlir;
17 using namespace mlir::dataflow;
18 
19 namespace {
20 /// This lattice represents a single underlying value for an SSA value.
21 class UnderlyingValue {
22 public:
23   /// The pessimistic underlying value of a value is itself.
getPessimisticValueState(Value value)24   static UnderlyingValue getPessimisticValueState(Value value) {
25     return {value};
26   }
27 
28   /// Create an underlying value state with a known underlying value.
UnderlyingValue(Value underlyingValue={})29   UnderlyingValue(Value underlyingValue = {})
30       : underlyingValue(underlyingValue) {}
31 
32   /// Returns the underlying value.
getUnderlyingValue() const33   Value getUnderlyingValue() const { return underlyingValue; }
34 
35   /// Join two underlying values. If there are conflicting underlying values,
36   /// go to the pessimistic value.
join(const UnderlyingValue & lhs,const UnderlyingValue & rhs)37   static UnderlyingValue join(const UnderlyingValue &lhs,
38                               const UnderlyingValue &rhs) {
39     return lhs.underlyingValue == rhs.underlyingValue ? lhs : UnderlyingValue();
40   }
41 
42   /// Compare underlying values.
operator ==(const UnderlyingValue & rhs) const43   bool operator==(const UnderlyingValue &rhs) const {
44     return underlyingValue == rhs.underlyingValue;
45   }
46 
print(raw_ostream & os) const47   void print(raw_ostream &os) const { os << underlyingValue; }
48 
49 private:
50   Value underlyingValue;
51 };
52 
53 /// This lattice represents, for a given memory resource, the potential last
54 /// operations that modified the resource.
55 class LastModification : public AbstractDenseLattice {
56 public:
57   MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(LastModification)
58 
59   using AbstractDenseLattice::AbstractDenseLattice;
60 
61   /// The lattice is always initialized.
isUninitialized() const62   bool isUninitialized() const override { return false; }
63 
64   /// Initialize the lattice. Does nothing.
defaultInitialize()65   ChangeResult defaultInitialize() override { return ChangeResult::NoChange; }
66 
67   /// Mark the lattice as having reached its pessimistic fixpoint. That is, the
68   /// last modifications of all memory resources are unknown.
reset()69   ChangeResult reset() override {
70     if (lastMods.empty())
71       return ChangeResult::NoChange;
72     lastMods.clear();
73     return ChangeResult::Change;
74   }
75 
76   /// The lattice is never at a fixpoint.
isAtFixpoint() const77   bool isAtFixpoint() const override { return false; }
78 
79   /// Join the last modifications.
join(const AbstractDenseLattice & lattice)80   ChangeResult join(const AbstractDenseLattice &lattice) override {
81     const auto &rhs = static_cast<const LastModification &>(lattice);
82     ChangeResult result = ChangeResult::NoChange;
83     for (const auto &mod : rhs.lastMods) {
84       auto &lhsMod = lastMods[mod.first];
85       if (lhsMod != mod.second) {
86         lhsMod.insert(mod.second.begin(), mod.second.end());
87         result |= ChangeResult::Change;
88       }
89     }
90     return result;
91   }
92 
93   /// Set the last modification of a value.
set(Value value,Operation * op)94   ChangeResult set(Value value, Operation *op) {
95     auto &lastMod = lastMods[value];
96     ChangeResult result = ChangeResult::NoChange;
97     if (lastMod.size() != 1 || *lastMod.begin() != op) {
98       result = ChangeResult::Change;
99       lastMod.clear();
100       lastMod.insert(op);
101     }
102     return result;
103   }
104 
105   /// Get the last modifications of a value. Returns none if the last
106   /// modifications are not known.
getLastModifiers(Value value) const107   Optional<ArrayRef<Operation *>> getLastModifiers(Value value) const {
108     auto it = lastMods.find(value);
109     if (it == lastMods.end())
110       return {};
111     return it->second.getArrayRef();
112   }
113 
print(raw_ostream & os) const114   void print(raw_ostream &os) const override {
115     for (const auto &lastMod : lastMods) {
116       os << lastMod.first << ":\n";
117       for (Operation *op : lastMod.second)
118         os << "  " << *op << "\n";
119     }
120   }
121 
122 private:
123   /// The potential last modifications of a memory resource. Use a set vector to
124   /// keep the results deterministic.
125   DenseMap<Value, SetVector<Operation *, SmallVector<Operation *, 2>,
126                             SmallPtrSet<Operation *, 2>>>
127       lastMods;
128 };
129 
130 class LastModifiedAnalysis : public DenseDataFlowAnalysis<LastModification> {
131 public:
132   using DenseDataFlowAnalysis::DenseDataFlowAnalysis;
133 
134   /// Visit an operation. If the operation has no memory effects, then the state
135   /// is propagated with no change. If the operation allocates a resource, then
136   /// its reaching definitions is set to empty. If the operation writes to a
137   /// resource, then its reaching definition is set to the written value.
138   void visitOperation(Operation *op, const LastModification &before,
139                       LastModification *after) override;
140 };
141 
142 /// Define the lattice class explicitly to provide a type ID.
143 struct UnderlyingValueLattice : public Lattice<UnderlyingValue> {
144   MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(UnderlyingValueLattice)
145   using Lattice::Lattice;
146 };
147 
148 /// An analysis that uses forwarding of values along control-flow and callgraph
149 /// edges to determine single underlying values for block arguments. This
150 /// analysis exists so that the test analysis and pass can test the behaviour of
151 /// the dense data-flow analysis on the callgraph.
152 class UnderlyingValueAnalysis
153     : public SparseDataFlowAnalysis<UnderlyingValueLattice> {
154 public:
155   using SparseDataFlowAnalysis::SparseDataFlowAnalysis;
156 
157   /// The underlying value of the results of an operation are not known.
visitOperation(Operation * op,ArrayRef<const UnderlyingValueLattice * > operands,ArrayRef<UnderlyingValueLattice * > results)158   void visitOperation(Operation *op,
159                       ArrayRef<const UnderlyingValueLattice *> operands,
160                       ArrayRef<UnderlyingValueLattice *> results) override {
161     markAllPessimisticFixpoint(results);
162   }
163 };
164 } // end anonymous namespace
165 
166 /// Look for the most underlying value of a value.
getMostUnderlyingValue(Value value,function_ref<const UnderlyingValueLattice * (Value)> getUnderlyingValueFn)167 static Value getMostUnderlyingValue(
168     Value value,
169     function_ref<const UnderlyingValueLattice *(Value)> getUnderlyingValueFn) {
170   const UnderlyingValueLattice *underlying;
171   do {
172     underlying = getUnderlyingValueFn(value);
173     if (!underlying || underlying->isUninitialized())
174       return {};
175     Value underlyingValue = underlying->getValue().getUnderlyingValue();
176     if (underlyingValue == value)
177       break;
178     value = underlyingValue;
179   } while (true);
180   return value;
181 }
182 
visitOperation(Operation * op,const LastModification & before,LastModification * after)183 void LastModifiedAnalysis::visitOperation(Operation *op,
184                                           const LastModification &before,
185                                           LastModification *after) {
186   auto memory = dyn_cast<MemoryEffectOpInterface>(op);
187   // If we can't reason about the memory effects, then conservatively assume we
188   // can't deduce anything about the last modifications.
189   if (!memory)
190     return reset(after);
191 
192   SmallVector<MemoryEffects::EffectInstance> effects;
193   memory.getEffects(effects);
194 
195   ChangeResult result = after->join(before);
196   for (const auto &effect : effects) {
197     Value value = effect.getValue();
198 
199     // If we see an effect on anything other than a value, assume we can't
200     // deduce anything about the last modifications.
201     if (!value)
202       return reset(after);
203 
204     value = getMostUnderlyingValue(value, [&](Value value) {
205       return getOrCreateFor<UnderlyingValueLattice>(op, value);
206     });
207     if (!value)
208       return;
209 
210     // Nothing to do for reads.
211     if (isa<MemoryEffects::Read>(effect.getEffect()))
212       continue;
213 
214     result |= after->set(value, op);
215   }
216   propagateIfChanged(after, result);
217 }
218 
219 namespace {
220 struct TestLastModifiedPass
221     : public PassWrapper<TestLastModifiedPass, OperationPass<>> {
MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID__anon39d6bac10311::TestLastModifiedPass222   MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(TestLastModifiedPass)
223 
224   StringRef getArgument() const override { return "test-last-modified"; }
225 
runOnOperation__anon39d6bac10311::TestLastModifiedPass226   void runOnOperation() override {
227     Operation *op = getOperation();
228 
229     DataFlowSolver solver;
230     solver.load<DeadCodeAnalysis>();
231     solver.load<SparseConstantPropagation>();
232     solver.load<LastModifiedAnalysis>();
233     solver.load<UnderlyingValueAnalysis>();
234     if (failed(solver.initializeAndRun(op)))
235       return signalPassFailure();
236 
237     raw_ostream &os = llvm::errs();
238 
239     op->walk([&](Operation *op) {
240       auto tag = op->getAttrOfType<StringAttr>("tag");
241       if (!tag)
242         return;
243       os << "test_tag: " << tag.getValue() << ":\n";
244       const LastModification *lastMods =
245           solver.lookupState<LastModification>(op);
246       assert(lastMods && "expected a dense lattice");
247       for (auto &it : llvm::enumerate(op->getOperands())) {
248         os << " operand #" << it.index() << "\n";
249         Value value = getMostUnderlyingValue(it.value(), [&](Value value) {
250           return solver.lookupState<UnderlyingValueLattice>(value);
251         });
252         assert(value && "expected an underlying value");
253         if (Optional<ArrayRef<Operation *>> lastMod =
254                 lastMods->getLastModifiers(value)) {
255           for (Operation *lastModifier : *lastMod) {
256             if (auto tagName =
257                     lastModifier->getAttrOfType<StringAttr>("tag_name")) {
258               os << "  - " << tagName.getValue() << "\n";
259             } else {
260               os << "  - " << lastModifier->getName() << "\n";
261             }
262           }
263         } else {
264           os << "  - <unknown>\n";
265         }
266       }
267     });
268   }
269 };
270 } // end anonymous namespace
271 
272 namespace mlir {
273 namespace test {
registerTestLastModifiedPass()274 void registerTestLastModifiedPass() {
275   PassRegistration<TestLastModifiedPass>();
276 }
277 } // end namespace test
278 } // end namespace mlir
279