//===- TestDominance.cpp - Test dominance construction and information ----===//
3 //
4 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
5 // See https://llvm.org/LICENSE.txt for license information.
6 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
7 //
8 //===----------------------------------------------------------------------===//
9 //
10 // This file contains test passes for constructing and resolving dominance
11 // information.
12 //
13 //===----------------------------------------------------------------------===//
14 
15 #include "mlir/IR/Dominance.h"
16 #include "mlir/IR/SymbolTable.h"
17 #include "mlir/Pass/Pass.h"
18 
19 using namespace mlir;
20 
21 /// Overloaded helper to call the right function based on whether we are testing
22 /// dominance or post-dominance.
23 static bool dominatesOrPostDominates(DominanceInfo &dominanceInfo, Block *a,
24                                      Block *b) {
25   return dominanceInfo.dominates(a, b);
26 }
27 
28 static bool dominatesOrPostDominates(PostDominanceInfo &dominanceInfo, Block *a,
29                                      Block *b) {
30   return dominanceInfo.postDominates(a, b);
31 }
32 
33 namespace {
34 
35 /// Helper class to print dominance information.
36 class DominanceTest {
37 public:
38   /// Constructs a new test instance using the given operation.
39   DominanceTest(Operation *operation) : operation(operation) {
40     // Create unique ids for each block.
41     operation->walk([&](Operation *nested) {
42       if (blockIds.count(nested->getBlock()) > 0)
43         return;
44       blockIds.insert({nested->getBlock(), blockIds.size()});
45     });
46   }
47 
48   /// Prints dominance information of all blocks.
49   template <typename DominanceT>
50   void printDominance(DominanceT &dominanceInfo,
51                       bool printCommonDominatorInfo) {
52     DenseSet<Block *> parentVisited;
53     operation->walk([&](Operation *op) {
54       Block *block = op->getBlock();
55       if (!parentVisited.insert(block).second)
56         return;
57 
58       DenseSet<Block *> visited;
59       operation->walk([&](Operation *nested) {
60         Block *nestedBlock = nested->getBlock();
61         if (!visited.insert(nestedBlock).second)
62           return;
63         if (printCommonDominatorInfo) {
64           llvm::errs() << "Nearest(" << blockIds[block] << ", "
65                        << blockIds[nestedBlock] << ") = ";
66           Block *dom =
67               dominanceInfo.findNearestCommonDominator(block, nestedBlock);
68           if (dom)
69             llvm::errs() << blockIds[dom];
70           else
71             llvm::errs() << "<no dom>";
72           llvm::errs() << "\n";
73         } else {
74           if (std::is_same<DominanceInfo, DominanceT>::value)
75             llvm::errs() << "dominates(";
76           else
77             llvm::errs() << "postdominates(";
78           llvm::errs() << blockIds[block] << ", " << blockIds[nestedBlock]
79                        << ") = ";
80           if (dominatesOrPostDominates(dominanceInfo, block, nestedBlock))
81             llvm::errs() << "true\n";
82           else
83             llvm::errs() << "false\n";
84         }
85       });
86     });
87   }
88 
89 private:
90   Operation *operation;
91   DenseMap<Block *, size_t> blockIds;
92 };
93 
94 struct TestDominancePass
95     : public PassWrapper<TestDominancePass, InterfacePass<SymbolOpInterface>> {
96   MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(TestDominancePass)
97 
98   StringRef getArgument() const final { return "test-print-dominance"; }
99   StringRef getDescription() const final {
100     return "Print the dominance information for multiple regions.";
101   }
102 
103   void runOnOperation() override {
104     llvm::errs() << "Testing : " << getOperation().getName() << "\n";
105     DominanceTest dominanceTest(getOperation());
106 
107     // Print dominance information.
108     llvm::errs() << "--- DominanceInfo ---\n";
109     dominanceTest.printDominance(getAnalysis<DominanceInfo>(),
110                                  /*printCommonDominatorInfo=*/true);
111 
112     llvm::errs() << "--- PostDominanceInfo ---\n";
113     dominanceTest.printDominance(getAnalysis<PostDominanceInfo>(),
114                                  /*printCommonDominatorInfo=*/true);
115 
116     // Print dominance relationship between blocks.
117     llvm::errs() << "--- Block Dominance relationship ---\n";
118     dominanceTest.printDominance(getAnalysis<DominanceInfo>(),
119                                  /*printCommonDominatorInfo=*/false);
120 
121     llvm::errs() << "--- Block PostDominance relationship ---\n";
122     dominanceTest.printDominance(getAnalysis<PostDominanceInfo>(),
123                                  /*printCommonDominatorInfo=*/false);
124   }
125 };
126 
127 } // namespace
128 
129 namespace mlir {
130 namespace test {
131 void registerTestDominancePass() { PassRegistration<TestDominancePass>(); }
132 } // namespace test
133 } // namespace mlir
134