1 //===- TestAliasAnalysis.cpp - Test alias analysis results ----------------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 //
9 // This file contains test passes for constructing and testing alias analysis
10 // results.
11 //
12 //===----------------------------------------------------------------------===//
13 
14 #include "mlir/Analysis/AliasAnalysis.h"
15 #include "mlir/Pass/Pass.h"
16 
17 using namespace mlir;
18 
19 /// Print a value that is used as an operand of an alias query.
20 static void printAliasOperand(Operation *op) {
21   llvm::errs() << op->getAttrOfType<StringAttr>("test.ptr").getValue();
22 }
23 static void printAliasOperand(Value value) {
24   if (BlockArgument arg = value.dyn_cast<BlockArgument>()) {
25     Region *region = arg.getParentRegion();
26     unsigned parentBlockNumber =
27         std::distance(region->begin(), arg.getOwner()->getIterator());
28     llvm::errs() << region->getParentOp()
29                         ->getAttrOfType<StringAttr>("test.ptr")
30                         .getValue()
31                  << ".region" << region->getRegionNumber();
32     if (parentBlockNumber != 0)
33       llvm::errs() << ".block" << parentBlockNumber;
34     llvm::errs() << "#" << arg.getArgNumber();
35     return;
36   }
37   OpResult result = value.cast<OpResult>();
38   printAliasOperand(result.getOwner());
39   llvm::errs() << "#" << result.getResultNumber();
40 }
41 
42 //===----------------------------------------------------------------------===//
43 // Testing AliasResult
44 //===----------------------------------------------------------------------===//
45 
46 namespace {
47 struct TestAliasAnalysisPass
48     : public PassWrapper<TestAliasAnalysisPass, OperationPass<>> {
49   MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(TestAliasAnalysisPass)
50 
51   StringRef getArgument() const final { return "test-alias-analysis"; }
52   StringRef getDescription() const final {
53     return "Test alias analysis results.";
54   }
55   void runOnOperation() override {
56     llvm::errs() << "Testing : " << getOperation()->getAttr("sym_name") << "\n";
57 
58     // Collect all of the values to check for aliasing behavior.
59     AliasAnalysis &aliasAnalysis = getAnalysis<AliasAnalysis>();
60     SmallVector<Value, 32> valsToCheck;
61     getOperation()->walk([&](Operation *op) {
62       if (!op->getAttr("test.ptr"))
63         return;
64       valsToCheck.append(op->result_begin(), op->result_end());
65       for (Region &region : op->getRegions())
66         for (Block &block : region)
67           valsToCheck.append(block.args_begin(), block.args_end());
68     });
69 
70     // Check for aliasing behavior between each of the values.
71     for (auto it = valsToCheck.begin(), e = valsToCheck.end(); it != e; ++it)
72       for (auto *innerIt = valsToCheck.begin(); innerIt != it; ++innerIt)
73         printAliasResult(aliasAnalysis.alias(*innerIt, *it), *innerIt, *it);
74   }
75 
76   /// Print the result of an alias query.
77   void printAliasResult(AliasResult result, Value lhs, Value rhs) {
78     printAliasOperand(lhs);
79     llvm::errs() << " <-> ";
80     printAliasOperand(rhs);
81     llvm::errs() << ": " << result << "\n";
82   }
83 };
84 } // namespace
85 
86 //===----------------------------------------------------------------------===//
87 // Testing ModRefResult
88 //===----------------------------------------------------------------------===//
89 
90 namespace {
91 struct TestAliasAnalysisModRefPass
92     : public PassWrapper<TestAliasAnalysisModRefPass, OperationPass<>> {
93   MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(TestAliasAnalysisModRefPass)
94 
95   StringRef getArgument() const final { return "test-alias-analysis-modref"; }
96   StringRef getDescription() const final {
97     return "Test alias analysis ModRef results.";
98   }
99   void runOnOperation() override {
100     llvm::errs() << "Testing : " << getOperation()->getAttr("sym_name") << "\n";
101 
102     // Collect all of the values to check for aliasing behavior.
103     AliasAnalysis &aliasAnalysis = getAnalysis<AliasAnalysis>();
104     SmallVector<Value, 32> valsToCheck;
105     getOperation()->walk([&](Operation *op) {
106       if (!op->getAttr("test.ptr"))
107         return;
108       valsToCheck.append(op->result_begin(), op->result_end());
109       for (Region &region : op->getRegions())
110         for (Block &block : region)
111           valsToCheck.append(block.args_begin(), block.args_end());
112     });
113 
114     // Check for aliasing behavior between each of the values.
115     for (auto &it : valsToCheck) {
116       getOperation()->walk([&](Operation *op) {
117         if (!op->getAttr("test.ptr"))
118           return;
119         printModRefResult(aliasAnalysis.getModRef(op, it), op, it);
120       });
121     }
122   }
123 
124   /// Print the result of an alias query.
125   void printModRefResult(ModRefResult result, Operation *op, Value location) {
126     printAliasOperand(op);
127     llvm::errs() << " -> ";
128     printAliasOperand(location);
129     llvm::errs() << ": " << result << "\n";
130   }
131 };
132 } // namespace
133 
134 //===----------------------------------------------------------------------===//
135 // Pass Registration
136 //===----------------------------------------------------------------------===//
137 
138 namespace mlir {
139 namespace test {
140 void registerTestAliasAnalysisPass() {
141   PassRegistration<TestAliasAnalysisPass>();
142   PassRegistration<TestAliasAnalysisModRefPass>();
143 }
144 } // namespace test
145 } // namespace mlir
146