1 //======- BufferViewFlowAnalysis.cpp - Buffer alias analysis -*- C++ -*-======//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 
9 #include "mlir/Analysis/BufferViewFlowAnalysis.h"
10 
11 #include "mlir/Interfaces/ControlFlowInterfaces.h"
12 #include "mlir/Interfaces/ViewLikeInterface.h"
13 #include "llvm/ADT/SetOperations.h"
14 
15 using namespace mlir;
16 
17 /// Constructs a new alias analysis using the op provided.
BufferViewFlowAnalysis(Operation * op)18 BufferViewFlowAnalysis::BufferViewFlowAnalysis(Operation *op) { build(op); }
19 
20 /// Find all immediate and indirect dependent buffers this value could
21 /// potentially have. Note that the resulting set will also contain the value
22 /// provided as it is a dependent alias of itself.
23 BufferViewFlowAnalysis::ValueSetT
resolve(Value rootValue) const24 BufferViewFlowAnalysis::resolve(Value rootValue) const {
25   ValueSetT result;
26   SmallVector<Value, 8> queue;
27   queue.push_back(rootValue);
28   while (!queue.empty()) {
29     Value currentValue = queue.pop_back_val();
30     if (result.insert(currentValue).second) {
31       auto it = dependencies.find(currentValue);
32       if (it != dependencies.end()) {
33         for (Value aliasValue : it->second)
34           queue.push_back(aliasValue);
35       }
36     }
37   }
38   return result;
39 }
40 
41 /// Removes the given values from all alias sets.
remove(const SmallPtrSetImpl<Value> & aliasValues)42 void BufferViewFlowAnalysis::remove(const SmallPtrSetImpl<Value> &aliasValues) {
43   for (auto &entry : dependencies)
44     llvm::set_subtract(entry.second, aliasValues);
45 }
46 
47 /// This function constructs a mapping from values to its immediate
48 /// dependencies. It iterates over all blocks, gets their predecessors,
49 /// determines the values that will be passed to the corresponding block
50 /// arguments and inserts them into the underlying map. Furthermore, it wires
51 /// successor regions and branch-like return operations from nested regions.
build(Operation * op)52 void BufferViewFlowAnalysis::build(Operation *op) {
53   // Registers all dependencies of the given values.
54   auto registerDependencies = [&](auto values, auto dependencies) {
55     for (auto entry : llvm::zip(values, dependencies))
56       this->dependencies[std::get<0>(entry)].insert(std::get<1>(entry));
57   };
58 
59   // Add additional dependencies created by view changes to the alias list.
60   op->walk([&](ViewLikeOpInterface viewInterface) {
61     dependencies[viewInterface.getViewSource()].insert(
62         viewInterface->getResult(0));
63   });
64 
65   // Query all branch interfaces to link block argument dependencies.
66   op->walk([&](BranchOpInterface branchInterface) {
67     Block *parentBlock = branchInterface->getBlock();
68     for (auto it = parentBlock->succ_begin(), e = parentBlock->succ_end();
69          it != e; ++it) {
70       // Query the branch op interface to get the successor operands.
71       auto successorOperands =
72           branchInterface.getSuccessorOperands(it.getIndex());
73       // Build the actual mapping of values to their immediate dependencies.
74       registerDependencies(successorOperands.getForwardedOperands(),
75                            (*it)->getArguments().drop_front(
76                                successorOperands.getProducedOperandCount()));
77     }
78   });
79 
80   // Query the RegionBranchOpInterface to find potential successor regions.
81   op->walk([&](RegionBranchOpInterface regionInterface) {
82     // Extract all entry regions and wire all initial entry successor inputs.
83     SmallVector<RegionSuccessor, 2> entrySuccessors;
84     regionInterface.getSuccessorRegions(/*index=*/llvm::None, entrySuccessors);
85     for (RegionSuccessor &entrySuccessor : entrySuccessors) {
86       // Wire the entry region's successor arguments with the initial
87       // successor inputs.
88       assert(entrySuccessor.getSuccessor() &&
89              "Invalid entry region without an attached successor region");
90       registerDependencies(
91           regionInterface.getSuccessorEntryOperands(
92               entrySuccessor.getSuccessor()->getRegionNumber()),
93           entrySuccessor.getSuccessorInputs());
94     }
95 
96     // Wire flow between regions and from region exits.
97     for (Region &region : regionInterface->getRegions()) {
98       // Iterate over all successor region entries that are reachable from the
99       // current region.
100       SmallVector<RegionSuccessor, 2> successorRegions;
101       regionInterface.getSuccessorRegions(region.getRegionNumber(),
102                                           successorRegions);
103       for (RegionSuccessor &successorRegion : successorRegions) {
104         // Determine the current region index (if any).
105         Optional<unsigned> regionIndex;
106         Region *regionSuccessor = successorRegion.getSuccessor();
107         if (regionSuccessor)
108           regionIndex = regionSuccessor->getRegionNumber();
109         // Iterate over all immediate terminator operations and wire the
110         // successor inputs with the successor operands of each terminator.
111         for (Block &block : region) {
112           auto successorOperands = getRegionBranchSuccessorOperands(
113               block.getTerminator(), regionIndex);
114           if (successorOperands) {
115             registerDependencies(*successorOperands,
116                                  successorRegion.getSuccessorInputs());
117           }
118         }
119       }
120     }
121   });
122 }
123