1 //===- DependenceAnalysis.h - Dependence analysis on SSA views --*- C++ -*-===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 
9 #ifndef MLIR_DIALECT_LINALG_ANALYSIS_DEPENDENCEANALYSIS_H_
10 #define MLIR_DIALECT_LINALG_ANALYSIS_DEPENDENCEANALYSIS_H_
11 
12 #include "mlir/Dialect/Linalg/IR/Linalg.h"
13 #include "mlir/IR/Builders.h"
14 #include "mlir/IR/OpDefinition.h"
15 
16 namespace mlir {
17 namespace func {
18 class FuncOp;
19 } // namespace func
20 
21 namespace linalg {
22 
23 class LinalgOp;
24 
25 /// A very primitive alias analysis which just records for each view, either:
26 ///   1. The base buffer, or
27 ///   2. The block argument view
28 /// that it indexes into.
29 /// This does not perform inter-block or inter-procedural analysis and assumes
30 /// that different block argument views do not alias.
31 class Aliases {
32 public:
33   /// Returns true if v1 and v2 alias.
alias(Value v1,Value v2)34   bool alias(Value v1, Value v2) { return find(v1) == find(v2); }
35 
36 private:
37   /// Returns the base buffer or block argument into which the view `v` aliases.
38   /// This lazily records the new aliases discovered while walking back the
39   /// use-def chain.
40   Value find(Value v);
41 
42   DenseMap<Value, Value> aliases;
43 };
44 
45 /// Data structure for holding a dependence graph that operates on LinalgOp and
46 /// views as SSA values.
47 class LinalgDependenceGraph {
48 public:
49   enum DependenceType { RAR = 0, RAW, WAR, WAW, NumTypes };
50   // TODO: OpOperand tracks dependencies on buffer operands. Tensor result will
51   // need an extension to use OpResult.
52   struct LinalgDependenceGraphElem {
53     using OpView = PointerUnion<OpOperand *, Value>;
54     // dependentOpView may be either:
55     //   1. src in the case of dependencesIntoGraphs.
56     //   2. dst in the case of dependencesFromDstGraphs.
57     OpView dependentOpView;
58     // View in the op that is used to index in the graph:
59     //   1. src in the case of dependencesFromDstGraphs.
60     //   2. dst in the case of dependencesIntoGraphs.
61     OpView indexingOpView;
62     // Type of the dependence.
63     DependenceType dependenceType;
64 
65     // Return the Operation that owns the operand or result represented in
66     // `opView`.
getOwnerLinalgDependenceGraphElem67     static Operation *getOwner(OpView opView) {
68       if (OpOperand *operand = opView.dyn_cast<OpOperand *>())
69         return operand->getOwner();
70       return opView.get<Value>().cast<OpResult>().getOwner();
71     }
72     // Return the operand or the result Value represented by the `opView`.
getValueLinalgDependenceGraphElem73     static Value getValue(OpView opView) {
74       if (OpOperand *operand = opView.dyn_cast<OpOperand *>())
75         return operand->get();
76       return opView.get<Value>();
77     }
78     // Return the indexing map of the operand/result in `opView` specified in
79     // the owning LinalgOp. If the owner is not a LinalgOp returns llvm::None.
getIndexingMapLinalgDependenceGraphElem80     static Optional<AffineMap> getIndexingMap(OpView opView) {
81       auto owner = dyn_cast<LinalgOp>(getOwner(opView));
82       if (!owner)
83         return llvm::None;
84       if (OpOperand *operand = opView.dyn_cast<OpOperand *>())
85         return owner.getTiedIndexingMap(operand);
86       return owner.getTiedIndexingMap(owner.getOutputOperand(
87           opView.get<Value>().cast<OpResult>().getResultNumber()));
88     }
89     // Return the operand number if the `opView` is an OpOperand *. Otherwise
90     // return llvm::None.
getOperandNumberLinalgDependenceGraphElem91     static Optional<unsigned> getOperandNumber(OpView opView) {
92       if (OpOperand *operand = opView.dyn_cast<OpOperand *>())
93         return operand->getOperandNumber();
94       return llvm::None;
95     }
96     // Return the result number if the `opView` is an OpResult. Otherwise return
97     // llvm::None.
getResultNumberLinalgDependenceGraphElem98     static Optional<unsigned> getResultNumber(OpView opView) {
99       if (OpResult result = opView.dyn_cast<Value>().cast<OpResult>())
100         return result.getResultNumber();
101       return llvm::None;
102     }
103 
104     // Return the owner of the dependent OpView.
getDependentOpLinalgDependenceGraphElem105     Operation *getDependentOp() const { return getOwner(dependentOpView); }
106 
107     // Return the owner of the indexing OpView.
getIndexingOpLinalgDependenceGraphElem108     Operation *getIndexingOp() const { return getOwner(indexingOpView); }
109 
110     // Return the operand or result stored in the dependentOpView.
getDependentValueLinalgDependenceGraphElem111     Value getDependentValue() const { return getValue(dependentOpView); }
112 
113     // Return the operand or result stored in the indexingOpView.
getIndexingValueLinalgDependenceGraphElem114     Value getIndexingValue() const { return getValue(indexingOpView); }
115 
116     // If the dependent OpView is an operand, return operand number. Return
117     // llvm::None otherwise.
getDependentOpViewOperandNumLinalgDependenceGraphElem118     Optional<unsigned> getDependentOpViewOperandNum() const {
119       return getOperandNumber(dependentOpView);
120     }
121 
122     // If the indexing OpView is an operand, return operand number. Return
123     // llvm::None otherwise.
getIndexingOpViewOperandNumLinalgDependenceGraphElem124     Optional<unsigned> getIndexingOpViewOperandNum() const {
125       return getOperandNumber(indexingOpView);
126     }
127 
128     // If the dependent OpView is a result value, return the result
129     // number. Return llvm::None otherwise.
getDependentOpViewResultNumLinalgDependenceGraphElem130     Optional<unsigned> getDependentOpViewResultNum() const {
131       return getResultNumber(dependentOpView);
132     }
133 
134     // If the dependent OpView is a result value, return the result
135     // number. Return llvm::None otherwise.
getIndexingOpViewResultNumLinalgDependenceGraphElem136     Optional<unsigned> getIndexingOpViewResultNum() const {
137       return getResultNumber(indexingOpView);
138     }
139 
140     // Return the indexing map of the operand/result in the dependent OpView as
141     // specified in the owner of the OpView.
getDependentOpViewIndexingMapLinalgDependenceGraphElem142     Optional<AffineMap> getDependentOpViewIndexingMap() const {
143       return getIndexingMap(dependentOpView);
144     }
145 
146     // Return the indexing map of the operand/result in the indexing OpView as
147     // specified in the owner of the OpView.
getIndexingOpViewIndexingMapLinalgDependenceGraphElem148     Optional<AffineMap> getIndexingOpViewIndexingMap() const {
149       return getIndexingMap(indexingOpView);
150     }
151   };
152   using LinalgDependences = SmallVector<LinalgDependenceGraphElem, 8>;
153   using DependenceGraph = DenseMap<Operation *, LinalgDependences>;
154   using dependence_iterator = LinalgDependences::const_iterator;
155   using dependence_range = iterator_range<dependence_iterator>;
156 
157   static StringRef getDependenceTypeStr(DependenceType depType);
158 
159   // Builds a linalg dependence graph for the ops of type LinalgOp under `f`.
160   static LinalgDependenceGraph buildDependenceGraph(Aliases &aliases,
161                                                     func::FuncOp f);
162   LinalgDependenceGraph(Aliases &aliases, ArrayRef<LinalgOp> ops);
163 
164   /// Returns the X such that op -> X is a dependence of type dt.
165   dependence_range getDependencesFrom(Operation *src, DependenceType dt) const;
166   dependence_range getDependencesFrom(LinalgOp src, DependenceType dt) const;
167 
168   /// Returns the X such that X -> op is a dependence of type dt.
169   dependence_range getDependencesInto(Operation *dst, DependenceType dt) const;
170   dependence_range getDependencesInto(LinalgOp dst, DependenceType dt) const;
171 
172   /// Returns the operations that are interleaved between `srcLinalgOp` and
173   /// `dstLinalgOp` and that are involved in any RAW, WAR or WAW dependence
174   /// relation with `srcLinalgOp`, on any view.
175   /// Any such operation prevents reordering.
176   SmallVector<Operation *, 8>
177   findCoveringDependences(LinalgOp srcLinalgOp, LinalgOp dstLinalgOp) const;
178 
179   /// Returns the operations that are interleaved between `srcLinalgOp` and
180   /// `dstLinalgOp` and that are involved in a RAR or RAW with `srcLinalgOp`.
181   /// Dependences are restricted to views aliasing `view`.
182   SmallVector<Operation *, 8> findCoveringReads(LinalgOp srcLinalgOp,
183                                                 LinalgOp dstLinalgOp,
184                                                 Value view) const;
185 
186   /// Returns the operations that are interleaved between `srcLinalgOp` and
187   /// `dstLinalgOp` and that are involved in a WAR or WAW with `srcLinalgOp`.
188   /// Dependences are restricted to views aliasing `view`.
189   SmallVector<Operation *, 8> findCoveringWrites(LinalgOp srcLinalgOp,
190                                                  LinalgOp dstLinalgOp,
191                                                  Value view) const;
192 
193   /// Returns true if the two operations have the specified dependence from
194   /// `srcLinalgOp` to `dstLinalgOp`.
195   bool hasDependenceFrom(LinalgOp srcLinalgOp, LinalgOp dstLinalgOp,
196                          ArrayRef<DependenceType> depTypes = {
197                              DependenceType::RAW, DependenceType::WAW}) const;
198 
199   /// Returns true if the `linalgOp` has dependences into it.
200   bool hasDependentOperationsInto(LinalgOp linalgOp,
201                                   ArrayRef<DependenceType> depTypes = {
202                                       DependenceType::RAW,
203                                       DependenceType::WAW}) const;
204 
205   /// Returns true if the `linalgOp` has dependences from it.
206   bool hasDependentOperationsFrom(LinalgOp linalgOp,
207                                   ArrayRef<DependenceType> depTypes = {
208                                       DependenceType::RAW,
209                                       DependenceType::WAW}) const;
210 
211   /// Returns true if the `linalgOp` has dependences into or from it.
212   bool hasDependentOperations(LinalgOp linalgOp,
213                               ArrayRef<DependenceType> depTypes = {
214                                   DependenceType::RAW,
215                                   DependenceType::WAW}) const;
216 
217   /// Returns all operations that have a dependence into `linalgOp` of types
218   /// listed in `depTypes`.
219   SmallVector<LinalgDependenceGraphElem, 2> getDependentOperationsInto(
220       LinalgOp linalgOp, ArrayRef<DependenceType> depTypes = {
221                              DependenceType::RAW, DependenceType::WAW}) const;
222 
223   /// Returns all operations that have a dependence from `linalgOp` of types
224   /// listed in `depTypes`.
225   SmallVector<LinalgDependenceGraphElem, 2> getDependentOperationsFrom(
226       LinalgOp linalgOp, ArrayRef<DependenceType> depTypes = {
227                              DependenceType::RAW, DependenceType::WAW}) const;
228 
229   /// Returns all dependent operations (into and from) given `operation`.
230   SmallVector<LinalgDependenceGraphElem, 2>
231   getDependentOperations(LinalgOp linalgOp,
232                          ArrayRef<DependenceType> depTypes = {
233                              DependenceType::RAW, DependenceType::WAW}) const;
234 
235   void print(raw_ostream &os) const;
236 
237   void dump() const;
238 
239 private:
240   // Keep dependences in both directions, this is not just a performance gain
241   // but it also reduces usage errors.
242   // Dependence information is stored as a map of:
243   //   (source operation -> LinalgDependenceGraphElem)
244   DependenceGraph dependencesFromGraphs[DependenceType::NumTypes];
245   // Reverse dependence information is stored as a map of:
246   //   (destination operation -> LinalgDependenceGraphElem)
247   DependenceGraph dependencesIntoGraphs[DependenceType::NumTypes];
248 
249   /// Analyses the aliasing views between `src` and `dst` and inserts the proper
250   /// dependences in the graph.
251   void addDependencesBetween(LinalgOp src, LinalgOp dst);
252 
253   // Adds an new dependence unit in the proper graph.
254   // Uses std::pair to keep operations and view together and avoid usage errors
255   // related to src/dst and producer/consumer terminology in the context of
256   // dependences.
257   void addDependenceElem(DependenceType dt,
258                          LinalgDependenceGraphElem::OpView indexingOpView,
259                          LinalgDependenceGraphElem::OpView dependentOpView);
260 
261   /// Implementation detail for findCoveringxxx.
262   SmallVector<Operation *, 8>
263   findOperationsWithCoveringDependences(LinalgOp srcLinalgOp,
264                                         LinalgOp dstLinalgOp, Value view,
265                                         ArrayRef<DependenceType> types) const;
266 
267   Aliases &aliases;
268   SmallVector<LinalgOp, 8> linalgOps;
269   DenseMap<Operation *, unsigned> linalgOpPositions;
270 };
271 } // namespace linalg
272 } // namespace mlir
273 
274 #endif // MLIR_DIALECT_LINALG_ANALYSIS_DEPENDENCEANALYSIS_H_
275