1 //===- TestAliasAnalysis.cpp - Test alias analysis results ----------------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 //
9 // This file contains test passes for constructing and testing alias analysis
10 // results.
11 //
12 //===----------------------------------------------------------------------===//
13 
14 #include "mlir/Analysis/AliasAnalysis.h"
15 #include "mlir/Pass/Pass.h"
16 
17 using namespace mlir;
18 
19 namespace {
20 struct TestAliasAnalysisPass
21     : public PassWrapper<TestAliasAnalysisPass, OperationPass<>> {
22   void runOnOperation() override {
23     llvm::errs() << "Testing : ";
24     if (Attribute testName = getOperation()->getAttr("test.name"))
25       llvm::errs() << testName << "\n";
26     else
27       llvm::errs() << getOperation()->getAttr("sym_name") << "\n";
28 
29     // Collect all of the values to check for aliasing behavior.
30     AliasAnalysis &aliasAnalysis = getAnalysis<AliasAnalysis>();
31     SmallVector<Value, 32> valsToCheck;
32     getOperation()->walk([&](Operation *op) {
33       if (!op->getAttr("test.ptr"))
34         return;
35       valsToCheck.append(op->result_begin(), op->result_end());
36       for (Region &region : op->getRegions())
37         for (Block &block : region)
38           valsToCheck.append(block.args_begin(), block.args_end());
39     });
40 
41     // Check for aliasing behavior between each of the values.
42     for (auto it = valsToCheck.begin(), e = valsToCheck.end(); it != e; ++it)
43       for (auto innerIt = valsToCheck.begin(); innerIt != it; ++innerIt)
44         printAliasResult(aliasAnalysis.alias(*innerIt, *it), *innerIt, *it);
45   }
46 
47   /// Print the result of an alias query.
48   void printAliasResult(AliasResult result, Value lhs, Value rhs) {
49     printAliasOperand(lhs);
50     llvm::errs() << " <-> ";
51     printAliasOperand(rhs);
52     llvm::errs() << ": ";
53 
54     switch (result.getKind()) {
55     case AliasResult::NoAlias:
56       llvm::errs() << "NoAlias";
57       break;
58     case AliasResult::MayAlias:
59       llvm::errs() << "MayAlias";
60       break;
61     case AliasResult::PartialAlias:
62       llvm::errs() << "PartialAlias";
63       break;
64     case AliasResult::MustAlias:
65       llvm::errs() << "MustAlias";
66       break;
67     }
68     llvm::errs() << "\n";
69   }
70   /// Print a value that is used as an operand of an alias query.
71   void printAliasOperand(Value value) {
72     if (BlockArgument arg = value.dyn_cast<BlockArgument>()) {
73       Region *region = arg.getParentRegion();
74       unsigned parentBlockNumber =
75           std::distance(region->begin(), arg.getOwner()->getIterator());
76       llvm::errs() << region->getParentOp()
77                           ->getAttrOfType<StringAttr>("test.ptr")
78                           .getValue()
79                    << ".region" << region->getRegionNumber();
80       if (parentBlockNumber != 0)
81         llvm::errs() << ".block" << parentBlockNumber;
82       llvm::errs() << "#" << arg.getArgNumber();
83       return;
84     }
85     OpResult result = value.cast<OpResult>();
86     llvm::errs()
87         << result.getOwner()->getAttrOfType<StringAttr>("test.ptr").getValue()
88         << "#" << result.getResultNumber();
89   }
90 };
91 } // end anonymous namespace
92 
93 namespace mlir {
94 namespace test {
95 void registerTestAliasAnalysisPass() {
96   PassRegistration<TestAliasAnalysisPass> pass("test-alias-analysis",
97                                                "Test alias analysis results.");
98 }
99 } // namespace test
100 } // namespace mlir
101