//===- TestDenseDataFlowAnalysis.cpp - Test dense data-flow analysis ------===//
2*d80c271cSMogball //
3*d80c271cSMogball // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4*d80c271cSMogball // See https://llvm.org/LICENSE.txt for license information.
5*d80c271cSMogball // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6*d80c271cSMogball //
7*d80c271cSMogball //===----------------------------------------------------------------------===//
8*d80c271cSMogball 
9*d80c271cSMogball #include "mlir/Analysis/DataFlow/ConstantPropagationAnalysis.h"
10*d80c271cSMogball #include "mlir/Analysis/DataFlow/DeadCodeAnalysis.h"
11*d80c271cSMogball #include "mlir/Analysis/DataFlow/DenseAnalysis.h"
12*d80c271cSMogball #include "mlir/Analysis/DataFlow/SparseAnalysis.h"
13*d80c271cSMogball #include "mlir/Interfaces/SideEffectInterfaces.h"
14*d80c271cSMogball #include "mlir/Pass/Pass.h"
15*d80c271cSMogball 
16*d80c271cSMogball using namespace mlir;
17*d80c271cSMogball using namespace mlir::dataflow;
18*d80c271cSMogball 
19*d80c271cSMogball namespace {
20*d80c271cSMogball /// This lattice represents a single underlying value for an SSA value.
21*d80c271cSMogball class UnderlyingValue {
22*d80c271cSMogball public:
23*d80c271cSMogball   /// The pessimistic underlying value of a value is itself.
getPessimisticValueState(Value value)24*d80c271cSMogball   static UnderlyingValue getPessimisticValueState(Value value) {
25*d80c271cSMogball     return {value};
26*d80c271cSMogball   }
27*d80c271cSMogball 
28*d80c271cSMogball   /// Create an underlying value state with a known underlying value.
UnderlyingValue(Value underlyingValue={})29*d80c271cSMogball   UnderlyingValue(Value underlyingValue = {})
30*d80c271cSMogball       : underlyingValue(underlyingValue) {}
31*d80c271cSMogball 
32*d80c271cSMogball   /// Returns the underlying value.
getUnderlyingValue() const33*d80c271cSMogball   Value getUnderlyingValue() const { return underlyingValue; }
34*d80c271cSMogball 
35*d80c271cSMogball   /// Join two underlying values. If there are conflicting underlying values,
36*d80c271cSMogball   /// go to the pessimistic value.
join(const UnderlyingValue & lhs,const UnderlyingValue & rhs)37*d80c271cSMogball   static UnderlyingValue join(const UnderlyingValue &lhs,
38*d80c271cSMogball                               const UnderlyingValue &rhs) {
39*d80c271cSMogball     return lhs.underlyingValue == rhs.underlyingValue ? lhs : UnderlyingValue();
40*d80c271cSMogball   }
41*d80c271cSMogball 
42*d80c271cSMogball   /// Compare underlying values.
operator ==(const UnderlyingValue & rhs) const43*d80c271cSMogball   bool operator==(const UnderlyingValue &rhs) const {
44*d80c271cSMogball     return underlyingValue == rhs.underlyingValue;
45*d80c271cSMogball   }
46*d80c271cSMogball 
print(raw_ostream & os) const47*d80c271cSMogball   void print(raw_ostream &os) const { os << underlyingValue; }
48*d80c271cSMogball 
49*d80c271cSMogball private:
50*d80c271cSMogball   Value underlyingValue;
51*d80c271cSMogball };
52*d80c271cSMogball 
53*d80c271cSMogball /// This lattice represents, for a given memory resource, the potential last
54*d80c271cSMogball /// operations that modified the resource.
55*d80c271cSMogball class LastModification : public AbstractDenseLattice {
56*d80c271cSMogball public:
57*d80c271cSMogball   MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(LastModification)
58*d80c271cSMogball 
59*d80c271cSMogball   using AbstractDenseLattice::AbstractDenseLattice;
60*d80c271cSMogball 
61*d80c271cSMogball   /// The lattice is always initialized.
isUninitialized() const62*d80c271cSMogball   bool isUninitialized() const override { return false; }
63*d80c271cSMogball 
64*d80c271cSMogball   /// Initialize the lattice. Does nothing.
defaultInitialize()65*d80c271cSMogball   ChangeResult defaultInitialize() override { return ChangeResult::NoChange; }
66*d80c271cSMogball 
67*d80c271cSMogball   /// Mark the lattice as having reached its pessimistic fixpoint. That is, the
68*d80c271cSMogball   /// last modifications of all memory resources are unknown.
reset()69*d80c271cSMogball   ChangeResult reset() override {
70*d80c271cSMogball     if (lastMods.empty())
71*d80c271cSMogball       return ChangeResult::NoChange;
72*d80c271cSMogball     lastMods.clear();
73*d80c271cSMogball     return ChangeResult::Change;
74*d80c271cSMogball   }
75*d80c271cSMogball 
76*d80c271cSMogball   /// The lattice is never at a fixpoint.
isAtFixpoint() const77*d80c271cSMogball   bool isAtFixpoint() const override { return false; }
78*d80c271cSMogball 
79*d80c271cSMogball   /// Join the last modifications.
join(const AbstractDenseLattice & lattice)80*d80c271cSMogball   ChangeResult join(const AbstractDenseLattice &lattice) override {
81*d80c271cSMogball     const auto &rhs = static_cast<const LastModification &>(lattice);
82*d80c271cSMogball     ChangeResult result = ChangeResult::NoChange;
83*d80c271cSMogball     for (const auto &mod : rhs.lastMods) {
84*d80c271cSMogball       auto &lhsMod = lastMods[mod.first];
85*d80c271cSMogball       if (lhsMod != mod.second) {
86*d80c271cSMogball         lhsMod.insert(mod.second.begin(), mod.second.end());
87*d80c271cSMogball         result |= ChangeResult::Change;
88*d80c271cSMogball       }
89*d80c271cSMogball     }
90*d80c271cSMogball     return result;
91*d80c271cSMogball   }
92*d80c271cSMogball 
93*d80c271cSMogball   /// Set the last modification of a value.
set(Value value,Operation * op)94*d80c271cSMogball   ChangeResult set(Value value, Operation *op) {
95*d80c271cSMogball     auto &lastMod = lastMods[value];
96*d80c271cSMogball     ChangeResult result = ChangeResult::NoChange;
97*d80c271cSMogball     if (lastMod.size() != 1 || *lastMod.begin() != op) {
98*d80c271cSMogball       result = ChangeResult::Change;
99*d80c271cSMogball       lastMod.clear();
100*d80c271cSMogball       lastMod.insert(op);
101*d80c271cSMogball     }
102*d80c271cSMogball     return result;
103*d80c271cSMogball   }
104*d80c271cSMogball 
105*d80c271cSMogball   /// Get the last modifications of a value. Returns none if the last
106*d80c271cSMogball   /// modifications are not known.
getLastModifiers(Value value) const107*d80c271cSMogball   Optional<ArrayRef<Operation *>> getLastModifiers(Value value) const {
108*d80c271cSMogball     auto it = lastMods.find(value);
109*d80c271cSMogball     if (it == lastMods.end())
110*d80c271cSMogball       return {};
111*d80c271cSMogball     return it->second.getArrayRef();
112*d80c271cSMogball   }
113*d80c271cSMogball 
print(raw_ostream & os) const114*d80c271cSMogball   void print(raw_ostream &os) const override {
115*d80c271cSMogball     for (const auto &lastMod : lastMods) {
116*d80c271cSMogball       os << lastMod.first << ":\n";
117*d80c271cSMogball       for (Operation *op : lastMod.second)
118*d80c271cSMogball         os << "  " << *op << "\n";
119*d80c271cSMogball     }
120*d80c271cSMogball   }
121*d80c271cSMogball 
122*d80c271cSMogball private:
123*d80c271cSMogball   /// The potential last modifications of a memory resource. Use a set vector to
124*d80c271cSMogball   /// keep the results deterministic.
125*d80c271cSMogball   DenseMap<Value, SetVector<Operation *, SmallVector<Operation *, 2>,
126*d80c271cSMogball                             SmallPtrSet<Operation *, 2>>>
127*d80c271cSMogball       lastMods;
128*d80c271cSMogball };
129*d80c271cSMogball 
130*d80c271cSMogball class LastModifiedAnalysis : public DenseDataFlowAnalysis<LastModification> {
131*d80c271cSMogball public:
132*d80c271cSMogball   using DenseDataFlowAnalysis::DenseDataFlowAnalysis;
133*d80c271cSMogball 
134*d80c271cSMogball   /// Visit an operation. If the operation has no memory effects, then the state
135*d80c271cSMogball   /// is propagated with no change. If the operation allocates a resource, then
136*d80c271cSMogball   /// its reaching definitions is set to empty. If the operation writes to a
137*d80c271cSMogball   /// resource, then its reaching definition is set to the written value.
138*d80c271cSMogball   void visitOperation(Operation *op, const LastModification &before,
139*d80c271cSMogball                       LastModification *after) override;
140*d80c271cSMogball };
141*d80c271cSMogball 
142*d80c271cSMogball /// Define the lattice class explicitly to provide a type ID.
143*d80c271cSMogball struct UnderlyingValueLattice : public Lattice<UnderlyingValue> {
144*d80c271cSMogball   MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(UnderlyingValueLattice)
145*d80c271cSMogball   using Lattice::Lattice;
146*d80c271cSMogball };
147*d80c271cSMogball 
148*d80c271cSMogball /// An analysis that uses forwarding of values along control-flow and callgraph
149*d80c271cSMogball /// edges to determine single underlying values for block arguments. This
150*d80c271cSMogball /// analysis exists so that the test analysis and pass can test the behaviour of
151*d80c271cSMogball /// the dense data-flow analysis on the callgraph.
152*d80c271cSMogball class UnderlyingValueAnalysis
153*d80c271cSMogball     : public SparseDataFlowAnalysis<UnderlyingValueLattice> {
154*d80c271cSMogball public:
155*d80c271cSMogball   using SparseDataFlowAnalysis::SparseDataFlowAnalysis;
156*d80c271cSMogball 
157*d80c271cSMogball   /// The underlying value of the results of an operation are not known.
visitOperation(Operation * op,ArrayRef<const UnderlyingValueLattice * > operands,ArrayRef<UnderlyingValueLattice * > results)158*d80c271cSMogball   void visitOperation(Operation *op,
159*d80c271cSMogball                       ArrayRef<const UnderlyingValueLattice *> operands,
160*d80c271cSMogball                       ArrayRef<UnderlyingValueLattice *> results) override {
161*d80c271cSMogball     markAllPessimisticFixpoint(results);
162*d80c271cSMogball   }
163*d80c271cSMogball };
164*d80c271cSMogball } // end anonymous namespace
165*d80c271cSMogball 
166*d80c271cSMogball /// Look for the most underlying value of a value.
getMostUnderlyingValue(Value value,function_ref<const UnderlyingValueLattice * (Value)> getUnderlyingValueFn)167*d80c271cSMogball static Value getMostUnderlyingValue(
168*d80c271cSMogball     Value value,
169*d80c271cSMogball     function_ref<const UnderlyingValueLattice *(Value)> getUnderlyingValueFn) {
170*d80c271cSMogball   const UnderlyingValueLattice *underlying;
171*d80c271cSMogball   do {
172*d80c271cSMogball     underlying = getUnderlyingValueFn(value);
173*d80c271cSMogball     if (!underlying || underlying->isUninitialized())
174*d80c271cSMogball       return {};
175*d80c271cSMogball     Value underlyingValue = underlying->getValue().getUnderlyingValue();
176*d80c271cSMogball     if (underlyingValue == value)
177*d80c271cSMogball       break;
178*d80c271cSMogball     value = underlyingValue;
179*d80c271cSMogball   } while (true);
180*d80c271cSMogball   return value;
181*d80c271cSMogball }
182*d80c271cSMogball 
visitOperation(Operation * op,const LastModification & before,LastModification * after)183*d80c271cSMogball void LastModifiedAnalysis::visitOperation(Operation *op,
184*d80c271cSMogball                                           const LastModification &before,
185*d80c271cSMogball                                           LastModification *after) {
186*d80c271cSMogball   auto memory = dyn_cast<MemoryEffectOpInterface>(op);
187*d80c271cSMogball   // If we can't reason about the memory effects, then conservatively assume we
188*d80c271cSMogball   // can't deduce anything about the last modifications.
189*d80c271cSMogball   if (!memory)
190*d80c271cSMogball     return reset(after);
191*d80c271cSMogball 
192*d80c271cSMogball   SmallVector<MemoryEffects::EffectInstance> effects;
193*d80c271cSMogball   memory.getEffects(effects);
194*d80c271cSMogball 
195*d80c271cSMogball   ChangeResult result = after->join(before);
196*d80c271cSMogball   for (const auto &effect : effects) {
197*d80c271cSMogball     Value value = effect.getValue();
198*d80c271cSMogball 
199*d80c271cSMogball     // If we see an effect on anything other than a value, assume we can't
200*d80c271cSMogball     // deduce anything about the last modifications.
201*d80c271cSMogball     if (!value)
202*d80c271cSMogball       return reset(after);
203*d80c271cSMogball 
204*d80c271cSMogball     value = getMostUnderlyingValue(value, [&](Value value) {
205*d80c271cSMogball       return getOrCreateFor<UnderlyingValueLattice>(op, value);
206*d80c271cSMogball     });
207*d80c271cSMogball     if (!value)
208*d80c271cSMogball       return;
209*d80c271cSMogball 
210*d80c271cSMogball     // Nothing to do for reads.
211*d80c271cSMogball     if (isa<MemoryEffects::Read>(effect.getEffect()))
212*d80c271cSMogball       continue;
213*d80c271cSMogball 
214*d80c271cSMogball     result |= after->set(value, op);
215*d80c271cSMogball   }
216*d80c271cSMogball   propagateIfChanged(after, result);
217*d80c271cSMogball }
218*d80c271cSMogball 
219*d80c271cSMogball namespace {
220*d80c271cSMogball struct TestLastModifiedPass
221*d80c271cSMogball     : public PassWrapper<TestLastModifiedPass, OperationPass<>> {
MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID__anon39d6bac10311::TestLastModifiedPass222*d80c271cSMogball   MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(TestLastModifiedPass)
223*d80c271cSMogball 
224*d80c271cSMogball   StringRef getArgument() const override { return "test-last-modified"; }
225*d80c271cSMogball 
runOnOperation__anon39d6bac10311::TestLastModifiedPass226*d80c271cSMogball   void runOnOperation() override {
227*d80c271cSMogball     Operation *op = getOperation();
228*d80c271cSMogball 
229*d80c271cSMogball     DataFlowSolver solver;
230*d80c271cSMogball     solver.load<DeadCodeAnalysis>();
231*d80c271cSMogball     solver.load<SparseConstantPropagation>();
232*d80c271cSMogball     solver.load<LastModifiedAnalysis>();
233*d80c271cSMogball     solver.load<UnderlyingValueAnalysis>();
234*d80c271cSMogball     if (failed(solver.initializeAndRun(op)))
235*d80c271cSMogball       return signalPassFailure();
236*d80c271cSMogball 
237*d80c271cSMogball     raw_ostream &os = llvm::errs();
238*d80c271cSMogball 
239*d80c271cSMogball     op->walk([&](Operation *op) {
240*d80c271cSMogball       auto tag = op->getAttrOfType<StringAttr>("tag");
241*d80c271cSMogball       if (!tag)
242*d80c271cSMogball         return;
243*d80c271cSMogball       os << "test_tag: " << tag.getValue() << ":\n";
244*d80c271cSMogball       const LastModification *lastMods =
245*d80c271cSMogball           solver.lookupState<LastModification>(op);
246*d80c271cSMogball       assert(lastMods && "expected a dense lattice");
247*d80c271cSMogball       for (auto &it : llvm::enumerate(op->getOperands())) {
248*d80c271cSMogball         os << " operand #" << it.index() << "\n";
249*d80c271cSMogball         Value value = getMostUnderlyingValue(it.value(), [&](Value value) {
250*d80c271cSMogball           return solver.lookupState<UnderlyingValueLattice>(value);
251*d80c271cSMogball         });
252*d80c271cSMogball         assert(value && "expected an underlying value");
253*d80c271cSMogball         if (Optional<ArrayRef<Operation *>> lastMod =
254*d80c271cSMogball                 lastMods->getLastModifiers(value)) {
255*d80c271cSMogball           for (Operation *lastModifier : *lastMod) {
256*d80c271cSMogball             if (auto tagName =
257*d80c271cSMogball                     lastModifier->getAttrOfType<StringAttr>("tag_name")) {
258*d80c271cSMogball               os << "  - " << tagName.getValue() << "\n";
259*d80c271cSMogball             } else {
260*d80c271cSMogball               os << "  - " << lastModifier->getName() << "\n";
261*d80c271cSMogball             }
262*d80c271cSMogball           }
263*d80c271cSMogball         } else {
264*d80c271cSMogball           os << "  - <unknown>\n";
265*d80c271cSMogball         }
266*d80c271cSMogball       }
267*d80c271cSMogball     });
268*d80c271cSMogball   }
269*d80c271cSMogball };
270*d80c271cSMogball } // end anonymous namespace
271*d80c271cSMogball 
272*d80c271cSMogball namespace mlir {
273*d80c271cSMogball namespace test {
registerTestLastModifiedPass()274*d80c271cSMogball void registerTestLastModifiedPass() {
275*d80c271cSMogball   PassRegistration<TestLastModifiedPass>();
276*d80c271cSMogball }
277*d80c271cSMogball } // end namespace test
278*d80c271cSMogball } // end namespace mlir
279