1 //===- TestAliasAnalysis.cpp - Test alias analysis results ----------------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 //
9 // This file contains test passes for constructing and testing alias analysis
10 // results.
11 //
12 //===----------------------------------------------------------------------===//
13 
14 #include "mlir/Analysis/AliasAnalysis.h"
15 #include "mlir/Pass/Pass.h"
16 
17 using namespace mlir;
18 
19 /// Print a value that is used as an operand of an alias query.
20 static void printAliasOperand(Operation *op) {
21   llvm::errs() << op->getAttrOfType<StringAttr>("test.ptr").getValue();
22 }
23 static void printAliasOperand(Value value) {
24   if (BlockArgument arg = value.dyn_cast<BlockArgument>()) {
25     Region *region = arg.getParentRegion();
26     unsigned parentBlockNumber =
27         std::distance(region->begin(), arg.getOwner()->getIterator());
28     llvm::errs() << region->getParentOp()
29                         ->getAttrOfType<StringAttr>("test.ptr")
30                         .getValue()
31                  << ".region" << region->getRegionNumber();
32     if (parentBlockNumber != 0)
33       llvm::errs() << ".block" << parentBlockNumber;
34     llvm::errs() << "#" << arg.getArgNumber();
35     return;
36   }
37   OpResult result = value.cast<OpResult>();
38   printAliasOperand(result.getOwner());
39   llvm::errs() << "#" << result.getResultNumber();
40 }
41 
42 //===----------------------------------------------------------------------===//
43 // Testing AliasResult
44 //===----------------------------------------------------------------------===//
45 
46 namespace {
47 struct TestAliasAnalysisPass
48     : public PassWrapper<TestAliasAnalysisPass, OperationPass<>> {
49   void runOnOperation() override {
50     llvm::errs() << "Testing : " << getOperation()->getAttr("sym_name") << "\n";
51 
52     // Collect all of the values to check for aliasing behavior.
53     AliasAnalysis &aliasAnalysis = getAnalysis<AliasAnalysis>();
54     SmallVector<Value, 32> valsToCheck;
55     getOperation()->walk([&](Operation *op) {
56       if (!op->getAttr("test.ptr"))
57         return;
58       valsToCheck.append(op->result_begin(), op->result_end());
59       for (Region &region : op->getRegions())
60         for (Block &block : region)
61           valsToCheck.append(block.args_begin(), block.args_end());
62     });
63 
64     // Check for aliasing behavior between each of the values.
65     for (auto it = valsToCheck.begin(), e = valsToCheck.end(); it != e; ++it)
66       for (auto innerIt = valsToCheck.begin(); innerIt != it; ++innerIt)
67         printAliasResult(aliasAnalysis.alias(*innerIt, *it), *innerIt, *it);
68   }
69 
70   /// Print the result of an alias query.
71   void printAliasResult(AliasResult result, Value lhs, Value rhs) {
72     printAliasOperand(lhs);
73     llvm::errs() << " <-> ";
74     printAliasOperand(rhs);
75     llvm::errs() << ": " << result << "\n";
76   }
77 };
78 } // end anonymous namespace
79 
80 //===----------------------------------------------------------------------===//
81 // Testing ModRefResult
82 //===----------------------------------------------------------------------===//
83 
84 namespace {
85 struct TestAliasAnalysisModRefPass
86     : public PassWrapper<TestAliasAnalysisModRefPass, OperationPass<>> {
87   void runOnOperation() override {
88     llvm::errs() << "Testing : " << getOperation()->getAttr("sym_name") << "\n";
89 
90     // Collect all of the values to check for aliasing behavior.
91     AliasAnalysis &aliasAnalysis = getAnalysis<AliasAnalysis>();
92     SmallVector<Value, 32> valsToCheck;
93     getOperation()->walk([&](Operation *op) {
94       if (!op->getAttr("test.ptr"))
95         return;
96       valsToCheck.append(op->result_begin(), op->result_end());
97       for (Region &region : op->getRegions())
98         for (Block &block : region)
99           valsToCheck.append(block.args_begin(), block.args_end());
100     });
101 
102     // Check for aliasing behavior between each of the values.
103     for (auto it = valsToCheck.begin(), e = valsToCheck.end(); it != e; ++it) {
104       getOperation()->walk([&](Operation *op) {
105         if (!op->getAttr("test.ptr"))
106           return;
107         printModRefResult(aliasAnalysis.getModRef(op, *it), op, *it);
108       });
109     }
110   }
111 
112   /// Print the result of an alias query.
113   void printModRefResult(ModRefResult result, Operation *op, Value location) {
114     printAliasOperand(op);
115     llvm::errs() << " -> ";
116     printAliasOperand(location);
117     llvm::errs() << ": " << result << "\n";
118   }
119 };
120 } // end anonymous namespace
121 
122 //===----------------------------------------------------------------------===//
123 // Pass Registration
124 //===----------------------------------------------------------------------===//
125 
126 namespace mlir {
127 namespace test {
128 void registerTestAliasAnalysisPass() {
129   PassRegistration<TestAliasAnalysisPass> aliasPass(
130       "test-alias-analysis", "Test alias analysis results.");
131   PassRegistration<TestAliasAnalysisModRefPass> modRefPass(
132       "test-alias-analysis-modref", "Test alias analysis ModRef results.");
133 }
134 } // namespace test
135 } // namespace mlir
136