1 //===- TestDeadCodeAnalysis.cpp - Test dead code analysis -----------------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 
9 #include "mlir/Analysis/DataFlow/ConstantPropagationAnalysis.h"
10 #include "mlir/Analysis/DataFlow/DeadCodeAnalysis.h"
11 #include "mlir/IR/Matchers.h"
12 #include "mlir/Pass/Pass.h"
13 
14 using namespace mlir;
15 using namespace mlir::dataflow;
16 
17 /// Print the liveness of every block, control-flow edge, and the predecessors
18 /// of all regions, callables, and calls.
printAnalysisResults(DataFlowSolver & solver,Operation * op,raw_ostream & os)19 static void printAnalysisResults(DataFlowSolver &solver, Operation *op,
20                                  raw_ostream &os) {
21   op->walk([&](Operation *op) {
22     auto tag = op->getAttrOfType<StringAttr>("tag");
23     if (!tag)
24       return;
25     os << tag.getValue() << ":\n";
26     for (Region &region : op->getRegions()) {
27       os << " region #" << region.getRegionNumber() << "\n";
28       for (Block &block : region) {
29         os << "  ";
30         block.printAsOperand(os);
31         os << " = ";
32         auto *live = solver.lookupState<Executable>(&block);
33         if (live)
34           os << *live;
35         else
36           os << "dead";
37         os << "\n";
38         for (Block *pred : block.getPredecessors()) {
39           os << "   from ";
40           pred->printAsOperand(os);
41           os << " = ";
42           auto *live = solver.lookupState<Executable>(
43               solver.getProgramPoint<CFGEdge>(pred, &block));
44           if (live)
45             os << *live;
46           else
47             os << "dead";
48           os << "\n";
49         }
50       }
51       if (!region.empty()) {
52         auto *preds = solver.lookupState<PredecessorState>(&region.front());
53         if (preds)
54           os << "region_preds: " << *preds << "\n";
55       }
56     }
57     auto *preds = solver.lookupState<PredecessorState>(op);
58     if (preds)
59       os << "op_preds: " << *preds << "\n";
60   });
61 }
62 
63 namespace {
64 /// This is a simple analysis that implements a transfer function for constant
65 /// operations.
66 struct ConstantAnalysis : public DataFlowAnalysis {
67   using DataFlowAnalysis::DataFlowAnalysis;
68 
initialize__anon68c9b7e90211::ConstantAnalysis69   LogicalResult initialize(Operation *top) override {
70     WalkResult result = top->walk([&](Operation *op) {
71       if (failed(visit(op)))
72         return WalkResult::interrupt();
73       return WalkResult::advance();
74     });
75     return success(!result.wasInterrupted());
76   }
77 
visit__anon68c9b7e90211::ConstantAnalysis78   LogicalResult visit(ProgramPoint point) override {
79     Operation *op = point.get<Operation *>();
80     Attribute value;
81     if (matchPattern(op, m_Constant(&value))) {
82       auto *constant = getOrCreate<Lattice<ConstantValue>>(op->getResult(0));
83       propagateIfChanged(
84           constant, constant->join(ConstantValue(value, op->getDialect())));
85       return success();
86     }
87     markAllPessimisticFixpoint(op->getResults());
88     for (Region &region : op->getRegions())
89       markAllPessimisticFixpoint(region.getArguments());
90     return success();
91   }
92 
93   /// Mark the constant values of all given values as having reached a
94   /// pessimistic fixpoint.
markAllPessimisticFixpoint__anon68c9b7e90211::ConstantAnalysis95   void markAllPessimisticFixpoint(ValueRange values) {
96     for (Value value : values) {
97       auto *constantValue = getOrCreate<Lattice<ConstantValue>>(value);
98       propagateIfChanged(constantValue,
99                          constantValue->markPessimisticFixpoint());
100     }
101   }
102 };
103 
104 /// This is a simple pass that runs dead code analysis with a constant value
105 /// provider that only understands constant operations.
106 struct TestDeadCodeAnalysisPass
107     : public PassWrapper<TestDeadCodeAnalysisPass, OperationPass<>> {
MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID__anon68c9b7e90211::TestDeadCodeAnalysisPass108   MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(TestDeadCodeAnalysisPass)
109 
110   StringRef getArgument() const override { return "test-dead-code-analysis"; }
111 
runOnOperation__anon68c9b7e90211::TestDeadCodeAnalysisPass112   void runOnOperation() override {
113     Operation *op = getOperation();
114 
115     DataFlowSolver solver;
116     solver.load<DeadCodeAnalysis>();
117     solver.load<ConstantAnalysis>();
118     if (failed(solver.initializeAndRun(op)))
119       return signalPassFailure();
120     printAnalysisResults(solver, op, llvm::errs());
121   }
122 };
123 } // end anonymous namespace
124 
125 namespace mlir {
126 namespace test {
registerTestDeadCodeAnalysisPass()127 void registerTestDeadCodeAnalysisPass() {
128   PassRegistration<TestDeadCodeAnalysisPass>();
129 }
130 } // end namespace test
131 } // end namespace mlir
132