//===- TestDenseDataFlowAnalysis.cpp - Test dense data-flow analysis ------===//
2*d80c271cSMogball //
3*d80c271cSMogball // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4*d80c271cSMogball // See https://llvm.org/LICENSE.txt for license information.
5*d80c271cSMogball // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6*d80c271cSMogball //
7*d80c271cSMogball //===----------------------------------------------------------------------===//
8*d80c271cSMogball
9*d80c271cSMogball #include "mlir/Analysis/DataFlow/ConstantPropagationAnalysis.h"
10*d80c271cSMogball #include "mlir/Analysis/DataFlow/DeadCodeAnalysis.h"
11*d80c271cSMogball #include "mlir/Analysis/DataFlow/DenseAnalysis.h"
12*d80c271cSMogball #include "mlir/Analysis/DataFlow/SparseAnalysis.h"
13*d80c271cSMogball #include "mlir/Interfaces/SideEffectInterfaces.h"
14*d80c271cSMogball #include "mlir/Pass/Pass.h"
15*d80c271cSMogball
16*d80c271cSMogball using namespace mlir;
17*d80c271cSMogball using namespace mlir::dataflow;
18*d80c271cSMogball
19*d80c271cSMogball namespace {
20*d80c271cSMogball /// This lattice represents a single underlying value for an SSA value.
21*d80c271cSMogball class UnderlyingValue {
22*d80c271cSMogball public:
23*d80c271cSMogball /// The pessimistic underlying value of a value is itself.
getPessimisticValueState(Value value)24*d80c271cSMogball static UnderlyingValue getPessimisticValueState(Value value) {
25*d80c271cSMogball return {value};
26*d80c271cSMogball }
27*d80c271cSMogball
28*d80c271cSMogball /// Create an underlying value state with a known underlying value.
UnderlyingValue(Value underlyingValue={})29*d80c271cSMogball UnderlyingValue(Value underlyingValue = {})
30*d80c271cSMogball : underlyingValue(underlyingValue) {}
31*d80c271cSMogball
32*d80c271cSMogball /// Returns the underlying value.
getUnderlyingValue() const33*d80c271cSMogball Value getUnderlyingValue() const { return underlyingValue; }
34*d80c271cSMogball
35*d80c271cSMogball /// Join two underlying values. If there are conflicting underlying values,
36*d80c271cSMogball /// go to the pessimistic value.
join(const UnderlyingValue & lhs,const UnderlyingValue & rhs)37*d80c271cSMogball static UnderlyingValue join(const UnderlyingValue &lhs,
38*d80c271cSMogball const UnderlyingValue &rhs) {
39*d80c271cSMogball return lhs.underlyingValue == rhs.underlyingValue ? lhs : UnderlyingValue();
40*d80c271cSMogball }
41*d80c271cSMogball
42*d80c271cSMogball /// Compare underlying values.
operator ==(const UnderlyingValue & rhs) const43*d80c271cSMogball bool operator==(const UnderlyingValue &rhs) const {
44*d80c271cSMogball return underlyingValue == rhs.underlyingValue;
45*d80c271cSMogball }
46*d80c271cSMogball
print(raw_ostream & os) const47*d80c271cSMogball void print(raw_ostream &os) const { os << underlyingValue; }
48*d80c271cSMogball
49*d80c271cSMogball private:
50*d80c271cSMogball Value underlyingValue;
51*d80c271cSMogball };
52*d80c271cSMogball
53*d80c271cSMogball /// This lattice represents, for a given memory resource, the potential last
54*d80c271cSMogball /// operations that modified the resource.
55*d80c271cSMogball class LastModification : public AbstractDenseLattice {
56*d80c271cSMogball public:
57*d80c271cSMogball MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(LastModification)
58*d80c271cSMogball
59*d80c271cSMogball using AbstractDenseLattice::AbstractDenseLattice;
60*d80c271cSMogball
61*d80c271cSMogball /// The lattice is always initialized.
isUninitialized() const62*d80c271cSMogball bool isUninitialized() const override { return false; }
63*d80c271cSMogball
64*d80c271cSMogball /// Initialize the lattice. Does nothing.
defaultInitialize()65*d80c271cSMogball ChangeResult defaultInitialize() override { return ChangeResult::NoChange; }
66*d80c271cSMogball
67*d80c271cSMogball /// Mark the lattice as having reached its pessimistic fixpoint. That is, the
68*d80c271cSMogball /// last modifications of all memory resources are unknown.
reset()69*d80c271cSMogball ChangeResult reset() override {
70*d80c271cSMogball if (lastMods.empty())
71*d80c271cSMogball return ChangeResult::NoChange;
72*d80c271cSMogball lastMods.clear();
73*d80c271cSMogball return ChangeResult::Change;
74*d80c271cSMogball }
75*d80c271cSMogball
76*d80c271cSMogball /// The lattice is never at a fixpoint.
isAtFixpoint() const77*d80c271cSMogball bool isAtFixpoint() const override { return false; }
78*d80c271cSMogball
79*d80c271cSMogball /// Join the last modifications.
join(const AbstractDenseLattice & lattice)80*d80c271cSMogball ChangeResult join(const AbstractDenseLattice &lattice) override {
81*d80c271cSMogball const auto &rhs = static_cast<const LastModification &>(lattice);
82*d80c271cSMogball ChangeResult result = ChangeResult::NoChange;
83*d80c271cSMogball for (const auto &mod : rhs.lastMods) {
84*d80c271cSMogball auto &lhsMod = lastMods[mod.first];
85*d80c271cSMogball if (lhsMod != mod.second) {
86*d80c271cSMogball lhsMod.insert(mod.second.begin(), mod.second.end());
87*d80c271cSMogball result |= ChangeResult::Change;
88*d80c271cSMogball }
89*d80c271cSMogball }
90*d80c271cSMogball return result;
91*d80c271cSMogball }
92*d80c271cSMogball
93*d80c271cSMogball /// Set the last modification of a value.
set(Value value,Operation * op)94*d80c271cSMogball ChangeResult set(Value value, Operation *op) {
95*d80c271cSMogball auto &lastMod = lastMods[value];
96*d80c271cSMogball ChangeResult result = ChangeResult::NoChange;
97*d80c271cSMogball if (lastMod.size() != 1 || *lastMod.begin() != op) {
98*d80c271cSMogball result = ChangeResult::Change;
99*d80c271cSMogball lastMod.clear();
100*d80c271cSMogball lastMod.insert(op);
101*d80c271cSMogball }
102*d80c271cSMogball return result;
103*d80c271cSMogball }
104*d80c271cSMogball
105*d80c271cSMogball /// Get the last modifications of a value. Returns none if the last
106*d80c271cSMogball /// modifications are not known.
getLastModifiers(Value value) const107*d80c271cSMogball Optional<ArrayRef<Operation *>> getLastModifiers(Value value) const {
108*d80c271cSMogball auto it = lastMods.find(value);
109*d80c271cSMogball if (it == lastMods.end())
110*d80c271cSMogball return {};
111*d80c271cSMogball return it->second.getArrayRef();
112*d80c271cSMogball }
113*d80c271cSMogball
print(raw_ostream & os) const114*d80c271cSMogball void print(raw_ostream &os) const override {
115*d80c271cSMogball for (const auto &lastMod : lastMods) {
116*d80c271cSMogball os << lastMod.first << ":\n";
117*d80c271cSMogball for (Operation *op : lastMod.second)
118*d80c271cSMogball os << " " << *op << "\n";
119*d80c271cSMogball }
120*d80c271cSMogball }
121*d80c271cSMogball
122*d80c271cSMogball private:
123*d80c271cSMogball /// The potential last modifications of a memory resource. Use a set vector to
124*d80c271cSMogball /// keep the results deterministic.
125*d80c271cSMogball DenseMap<Value, SetVector<Operation *, SmallVector<Operation *, 2>,
126*d80c271cSMogball SmallPtrSet<Operation *, 2>>>
127*d80c271cSMogball lastMods;
128*d80c271cSMogball };
129*d80c271cSMogball
130*d80c271cSMogball class LastModifiedAnalysis : public DenseDataFlowAnalysis<LastModification> {
131*d80c271cSMogball public:
132*d80c271cSMogball using DenseDataFlowAnalysis::DenseDataFlowAnalysis;
133*d80c271cSMogball
134*d80c271cSMogball /// Visit an operation. If the operation has no memory effects, then the state
135*d80c271cSMogball /// is propagated with no change. If the operation allocates a resource, then
136*d80c271cSMogball /// its reaching definitions is set to empty. If the operation writes to a
137*d80c271cSMogball /// resource, then its reaching definition is set to the written value.
138*d80c271cSMogball void visitOperation(Operation *op, const LastModification &before,
139*d80c271cSMogball LastModification *after) override;
140*d80c271cSMogball };
141*d80c271cSMogball
142*d80c271cSMogball /// Define the lattice class explicitly to provide a type ID.
143*d80c271cSMogball struct UnderlyingValueLattice : public Lattice<UnderlyingValue> {
144*d80c271cSMogball MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(UnderlyingValueLattice)
145*d80c271cSMogball using Lattice::Lattice;
146*d80c271cSMogball };
147*d80c271cSMogball
148*d80c271cSMogball /// An analysis that uses forwarding of values along control-flow and callgraph
149*d80c271cSMogball /// edges to determine single underlying values for block arguments. This
150*d80c271cSMogball /// analysis exists so that the test analysis and pass can test the behaviour of
151*d80c271cSMogball /// the dense data-flow analysis on the callgraph.
152*d80c271cSMogball class UnderlyingValueAnalysis
153*d80c271cSMogball : public SparseDataFlowAnalysis<UnderlyingValueLattice> {
154*d80c271cSMogball public:
155*d80c271cSMogball using SparseDataFlowAnalysis::SparseDataFlowAnalysis;
156*d80c271cSMogball
157*d80c271cSMogball /// The underlying value of the results of an operation are not known.
visitOperation(Operation * op,ArrayRef<const UnderlyingValueLattice * > operands,ArrayRef<UnderlyingValueLattice * > results)158*d80c271cSMogball void visitOperation(Operation *op,
159*d80c271cSMogball ArrayRef<const UnderlyingValueLattice *> operands,
160*d80c271cSMogball ArrayRef<UnderlyingValueLattice *> results) override {
161*d80c271cSMogball markAllPessimisticFixpoint(results);
162*d80c271cSMogball }
163*d80c271cSMogball };
164*d80c271cSMogball } // end anonymous namespace
165*d80c271cSMogball
166*d80c271cSMogball /// Look for the most underlying value of a value.
getMostUnderlyingValue(Value value,function_ref<const UnderlyingValueLattice * (Value)> getUnderlyingValueFn)167*d80c271cSMogball static Value getMostUnderlyingValue(
168*d80c271cSMogball Value value,
169*d80c271cSMogball function_ref<const UnderlyingValueLattice *(Value)> getUnderlyingValueFn) {
170*d80c271cSMogball const UnderlyingValueLattice *underlying;
171*d80c271cSMogball do {
172*d80c271cSMogball underlying = getUnderlyingValueFn(value);
173*d80c271cSMogball if (!underlying || underlying->isUninitialized())
174*d80c271cSMogball return {};
175*d80c271cSMogball Value underlyingValue = underlying->getValue().getUnderlyingValue();
176*d80c271cSMogball if (underlyingValue == value)
177*d80c271cSMogball break;
178*d80c271cSMogball value = underlyingValue;
179*d80c271cSMogball } while (true);
180*d80c271cSMogball return value;
181*d80c271cSMogball }
182*d80c271cSMogball
visitOperation(Operation * op,const LastModification & before,LastModification * after)183*d80c271cSMogball void LastModifiedAnalysis::visitOperation(Operation *op,
184*d80c271cSMogball const LastModification &before,
185*d80c271cSMogball LastModification *after) {
186*d80c271cSMogball auto memory = dyn_cast<MemoryEffectOpInterface>(op);
187*d80c271cSMogball // If we can't reason about the memory effects, then conservatively assume we
188*d80c271cSMogball // can't deduce anything about the last modifications.
189*d80c271cSMogball if (!memory)
190*d80c271cSMogball return reset(after);
191*d80c271cSMogball
192*d80c271cSMogball SmallVector<MemoryEffects::EffectInstance> effects;
193*d80c271cSMogball memory.getEffects(effects);
194*d80c271cSMogball
195*d80c271cSMogball ChangeResult result = after->join(before);
196*d80c271cSMogball for (const auto &effect : effects) {
197*d80c271cSMogball Value value = effect.getValue();
198*d80c271cSMogball
199*d80c271cSMogball // If we see an effect on anything other than a value, assume we can't
200*d80c271cSMogball // deduce anything about the last modifications.
201*d80c271cSMogball if (!value)
202*d80c271cSMogball return reset(after);
203*d80c271cSMogball
204*d80c271cSMogball value = getMostUnderlyingValue(value, [&](Value value) {
205*d80c271cSMogball return getOrCreateFor<UnderlyingValueLattice>(op, value);
206*d80c271cSMogball });
207*d80c271cSMogball if (!value)
208*d80c271cSMogball return;
209*d80c271cSMogball
210*d80c271cSMogball // Nothing to do for reads.
211*d80c271cSMogball if (isa<MemoryEffects::Read>(effect.getEffect()))
212*d80c271cSMogball continue;
213*d80c271cSMogball
214*d80c271cSMogball result |= after->set(value, op);
215*d80c271cSMogball }
216*d80c271cSMogball propagateIfChanged(after, result);
217*d80c271cSMogball }
218*d80c271cSMogball
219*d80c271cSMogball namespace {
220*d80c271cSMogball struct TestLastModifiedPass
221*d80c271cSMogball : public PassWrapper<TestLastModifiedPass, OperationPass<>> {
MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID__anon39d6bac10311::TestLastModifiedPass222*d80c271cSMogball MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(TestLastModifiedPass)
223*d80c271cSMogball
224*d80c271cSMogball StringRef getArgument() const override { return "test-last-modified"; }
225*d80c271cSMogball
runOnOperation__anon39d6bac10311::TestLastModifiedPass226*d80c271cSMogball void runOnOperation() override {
227*d80c271cSMogball Operation *op = getOperation();
228*d80c271cSMogball
229*d80c271cSMogball DataFlowSolver solver;
230*d80c271cSMogball solver.load<DeadCodeAnalysis>();
231*d80c271cSMogball solver.load<SparseConstantPropagation>();
232*d80c271cSMogball solver.load<LastModifiedAnalysis>();
233*d80c271cSMogball solver.load<UnderlyingValueAnalysis>();
234*d80c271cSMogball if (failed(solver.initializeAndRun(op)))
235*d80c271cSMogball return signalPassFailure();
236*d80c271cSMogball
237*d80c271cSMogball raw_ostream &os = llvm::errs();
238*d80c271cSMogball
239*d80c271cSMogball op->walk([&](Operation *op) {
240*d80c271cSMogball auto tag = op->getAttrOfType<StringAttr>("tag");
241*d80c271cSMogball if (!tag)
242*d80c271cSMogball return;
243*d80c271cSMogball os << "test_tag: " << tag.getValue() << ":\n";
244*d80c271cSMogball const LastModification *lastMods =
245*d80c271cSMogball solver.lookupState<LastModification>(op);
246*d80c271cSMogball assert(lastMods && "expected a dense lattice");
247*d80c271cSMogball for (auto &it : llvm::enumerate(op->getOperands())) {
248*d80c271cSMogball os << " operand #" << it.index() << "\n";
249*d80c271cSMogball Value value = getMostUnderlyingValue(it.value(), [&](Value value) {
250*d80c271cSMogball return solver.lookupState<UnderlyingValueLattice>(value);
251*d80c271cSMogball });
252*d80c271cSMogball assert(value && "expected an underlying value");
253*d80c271cSMogball if (Optional<ArrayRef<Operation *>> lastMod =
254*d80c271cSMogball lastMods->getLastModifiers(value)) {
255*d80c271cSMogball for (Operation *lastModifier : *lastMod) {
256*d80c271cSMogball if (auto tagName =
257*d80c271cSMogball lastModifier->getAttrOfType<StringAttr>("tag_name")) {
258*d80c271cSMogball os << " - " << tagName.getValue() << "\n";
259*d80c271cSMogball } else {
260*d80c271cSMogball os << " - " << lastModifier->getName() << "\n";
261*d80c271cSMogball }
262*d80c271cSMogball }
263*d80c271cSMogball } else {
264*d80c271cSMogball os << " - <unknown>\n";
265*d80c271cSMogball }
266*d80c271cSMogball }
267*d80c271cSMogball });
268*d80c271cSMogball }
269*d80c271cSMogball };
270*d80c271cSMogball } // end anonymous namespace
271*d80c271cSMogball
272*d80c271cSMogball namespace mlir {
273*d80c271cSMogball namespace test {
registerTestLastModifiedPass()274*d80c271cSMogball void registerTestLastModifiedPass() {
275*d80c271cSMogball PassRegistration<TestLastModifiedPass>();
276*d80c271cSMogball }
277*d80c271cSMogball } // end namespace test
278*d80c271cSMogball } // end namespace mlir
279