1 //===- DependenceAnalysis.cpp - Dependence analysis on SSA views ----------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 //
9 // This file implements view-based alias and dependence analyses.
10 //
11 //===----------------------------------------------------------------------===//
12 
13 #include "mlir/Dialect/Linalg/Analysis/DependenceAnalysis.h"
14 #include "mlir/Dialect/Linalg/IR/LinalgOps.h"
15 #include "mlir/Dialect/StandardOps/IR/Ops.h"
16 #include "mlir/IR/BuiltinOps.h"
17 
18 #include "llvm/Support/CommandLine.h"
19 #include "llvm/Support/Debug.h"
20 
21 #define DEBUG_TYPE "linalg-dependence-analysis"
22 
23 using namespace mlir;
24 using namespace mlir::linalg;
25 
26 using llvm::dbgs;
27 
28 Value Aliases::find(Value v) {
29   if (v.isa<BlockArgument>())
30     return v;
31 
32   auto it = aliases.find(v);
33   if (it != aliases.end()) {
34     assert(it->getSecond().getType().isa<BaseMemRefType>() &&
35            "Memref expected");
36     return it->getSecond();
37   }
38 
39   while (true) {
40     if (v.isa<BlockArgument>())
41       return v;
42 
43     Operation *defOp = v.getDefiningOp();
44     if (!defOp)
45       return v;
46 
47     // Treat RegionBranchOpInterfaces like an allocate and don't try to follow
48     // the aliasing further.
49     if (isa<RegionBranchOpInterface>(defOp))
50       return v;
51     if (isa<TensorToMemrefOp>(defOp))
52       return v;
53 
54     if (auto memEffect = dyn_cast<MemoryEffectOpInterface>(defOp)) {
55       // Collect all memory effects on `v`.
56       SmallVector<MemoryEffects::EffectInstance, 1> effects;
57       memEffect.getEffectsOnValue(v, effects);
58 
59       // If we have the 'Allocate' memory effect on `v`, then `v` should be the
60       // original buffer.
61       if (llvm::any_of(
62               effects, [](const MemoryEffects::EffectInstance &instance) {
63                 return isa<MemoryEffects::Allocate>(instance.getEffect());
64               }))
65         return v;
66     }
67 
68     if (auto viewLikeOp = dyn_cast<ViewLikeOpInterface>(defOp)) {
69       auto it =
70           aliases.insert(std::make_pair(v, find(viewLikeOp.getViewSource())));
71       return it.first->second;
72     }
73 
74     llvm::errs() << "View alias analysis reduces to: " << v << "\n";
75     llvm_unreachable("unsupported view alias case");
76   }
77 }
78 
79 StringRef LinalgDependenceGraph::getDependenceTypeStr(DependenceType depType) {
80   switch (depType) {
81   case LinalgDependenceGraph::DependenceType::RAW:
82     return "RAW";
83   case LinalgDependenceGraph::DependenceType::RAR:
84     return "RAR";
85   case LinalgDependenceGraph::DependenceType::WAR:
86     return "WAR";
87   case LinalgDependenceGraph::DependenceType::WAW:
88     return "WAW";
89   default:
90     break;
91   }
92   llvm_unreachable("Unexpected DependenceType");
93 }
94 
95 LinalgDependenceGraph
96 LinalgDependenceGraph::buildDependenceGraph(Aliases &aliases, FuncOp f) {
97   SmallVector<LinalgOp, 8> linalgOps;
98   f.walk([&](LinalgOp op) { linalgOps.push_back(op); });
99   return LinalgDependenceGraph(aliases, linalgOps);
100 }
101 
102 LinalgDependenceGraph::LinalgDependenceGraph(Aliases &aliases,
103                                              ArrayRef<LinalgOp> ops)
104     : aliases(aliases), linalgOps(ops.begin(), ops.end()) {
105   for (auto en : llvm::enumerate(linalgOps)) {
106     linalgOpPositions.insert(
107         std::make_pair(en.value().getOperation(), en.index()));
108   }
109   for (unsigned i = 0, e = ops.size(); i < e; ++i) {
110     for (unsigned j = i + 1; j < e; ++j) {
111       addDependencesBetween(ops[i], ops[j]);
112     }
113   }
114 }
115 
116 void LinalgDependenceGraph::addDependenceElem(
117     DependenceType dt, LinalgDependenceGraphElem::OpView indexingOpView,
118     LinalgDependenceGraphElem::OpView dependentOpView) {
119   LLVM_DEBUG(dbgs() << "\nAdd dep type " << getDependenceTypeStr(dt) << ":\t ("
120                     << LinalgDependenceGraphElem::getValue(indexingOpView)
121                     << " @) -> \n\t\t("
122                     << LinalgDependenceGraphElem::getValue(dependentOpView)
123                     << " @)");
124   dependencesFromGraphs[dt][LinalgDependenceGraphElem::getOwner(indexingOpView)]
125       .push_back(
126           LinalgDependenceGraphElem{dependentOpView, indexingOpView, dt});
127   dependencesIntoGraphs[dt]
128                        [LinalgDependenceGraphElem::getOwner(dependentOpView)]
129                            .push_back(LinalgDependenceGraphElem{
130                                indexingOpView, dependentOpView, dt});
131 }
132 
133 LinalgDependenceGraph::dependence_range
134 LinalgDependenceGraph::getDependencesFrom(
135     LinalgOp src, LinalgDependenceGraph::DependenceType dt) const {
136   return getDependencesFrom(src.getOperation(), dt);
137 }
138 
139 LinalgDependenceGraph::dependence_range
140 LinalgDependenceGraph::getDependencesFrom(
141     Operation *src, LinalgDependenceGraph::DependenceType dt) const {
142   auto iter = dependencesFromGraphs[dt].find(src);
143   if (iter == dependencesFromGraphs[dt].end())
144     return llvm::make_range(nullptr, nullptr);
145   return llvm::make_range(iter->second.begin(), iter->second.end());
146 }
147 
148 LinalgDependenceGraph::dependence_range
149 LinalgDependenceGraph::getDependencesInto(
150     LinalgOp dst, LinalgDependenceGraph::DependenceType dt) const {
151   return getDependencesInto(dst.getOperation(), dt);
152 }
153 
154 LinalgDependenceGraph::dependence_range
155 LinalgDependenceGraph::getDependencesInto(
156     Operation *dst, LinalgDependenceGraph::DependenceType dt) const {
157   auto iter = dependencesIntoGraphs[dt].find(dst);
158   if (iter == dependencesIntoGraphs[dt].end())
159     return llvm::make_range(nullptr, nullptr);
160   return llvm::make_range(iter->second.begin(), iter->second.end());
161 }
162 
163 void LinalgDependenceGraph::addDependencesBetween(LinalgOp src, LinalgOp dst) {
164   if (src.hasTensorSemantics() && dst.hasTensorSemantics()) {
165     for (OpOperand &dstOpOperand : dst.getInputOpOperands()) {
166       // Check if the operand is defined by the src.
167       auto definingOp = dstOpOperand.get().getDefiningOp<LinalgOp>();
168       if (definingOp && definingOp == src)
169         addDependenceElem(DependenceType::RAW, dstOpOperand.get(),
170                           &dstOpOperand);
171     }
172     return;
173   }
174   assert(src.hasBufferSemantics() && dst.hasBufferSemantics() &&
175          "unhandled dependence tracking for mixed buffer/tensor operations");
176   for (OpOperand *srcOpOperand : src.getOutputBuffersOpOperands()) { // W
177     // RAW graph
178     for (OpOperand *dstOpOperand : dst.getInputBuffersOpOperands()) // R
179       if (aliases.alias(srcOpOperand->get(), dstOpOperand->get()))  // RAW alias
180         addDependenceElem(DependenceType::RAW, srcOpOperand, dstOpOperand);
181     // WAW graph
182     for (OpOperand *dstOpOperand : dst.getOutputBuffersOpOperands()) // W
183       if (aliases.alias(srcOpOperand->get(), dstOpOperand->get())) // WAW alias
184         addDependenceElem(DependenceType::WAW, srcOpOperand, dstOpOperand);
185   }
186   for (OpOperand *srcOpOperand : src.getInputBuffersOpOperands()) { // R
187     // RAR graph
188     for (OpOperand *dstOpOperand : dst.getInputBuffersOpOperands()) // R
189       if (aliases.alias(srcOpOperand->get(), dstOpOperand->get()))  // RAR alias
190         addDependenceElem(DependenceType::RAR, srcOpOperand, dstOpOperand);
191     // WAR graph
192     for (OpOperand *dstOpOperand : dst.getOutputBuffersOpOperands()) // W
193       if (aliases.alias(srcOpOperand->get(), dstOpOperand->get())) // WAR alias
194         addDependenceElem(DependenceType::WAR, srcOpOperand, dstOpOperand);
195   }
196 }
197 
198 SmallVector<Operation *, 8>
199 LinalgDependenceGraph::findCoveringDependences(LinalgOp srcLinalgOp,
200                                                LinalgOp dstLinalgOp) const {
201   return findOperationsWithCoveringDependences(
202       srcLinalgOp, dstLinalgOp, nullptr,
203       {DependenceType::WAW, DependenceType::WAR, DependenceType::RAW});
204 }
205 
206 SmallVector<Operation *, 8> LinalgDependenceGraph::findCoveringWrites(
207     LinalgOp srcLinalgOp, LinalgOp dstLinalgOp, Value view) const {
208   return findOperationsWithCoveringDependences(
209       srcLinalgOp, dstLinalgOp, view,
210       {DependenceType::WAW, DependenceType::WAR});
211 }
212 
213 SmallVector<Operation *, 8> LinalgDependenceGraph::findCoveringReads(
214     LinalgOp srcLinalgOp, LinalgOp dstLinalgOp, Value view) const {
215   return findOperationsWithCoveringDependences(
216       srcLinalgOp, dstLinalgOp, view,
217       {DependenceType::RAR, DependenceType::RAW});
218 }
219 
220 SmallVector<Operation *, 8>
221 LinalgDependenceGraph::findOperationsWithCoveringDependences(
222     LinalgOp srcLinalgOp, LinalgOp dstLinalgOp, Value view,
223     ArrayRef<DependenceType> types) const {
224   auto *src = srcLinalgOp.getOperation();
225   auto *dst = dstLinalgOp.getOperation();
226   auto srcPos = linalgOpPositions.lookup(src);
227   auto dstPos = linalgOpPositions.lookup(dst);
228   assert(srcPos < dstPos && "expected dst after src in IR traversal order");
229 
230   SmallVector<Operation *, 8> res;
231   // Consider an intermediate interleaved `interim` op, look for any dependence
232   // to an aliasing view on a src -> op -> dst path.
233   // TODO: we are not considering paths yet, just interleaved positions.
234   for (auto dt : types) {
235     for (auto dependence : getDependencesFrom(src, dt)) {
236       auto interimPos = linalgOpPositions.lookup(dependence.getDependentOp());
237       // Skip if not interleaved.
238       if (interimPos >= dstPos || interimPos <= srcPos)
239         continue;
240       Value consumerView = dependence.getIndexingValue();
241       if (view && !aliases.alias(view, consumerView))
242         continue;
243       auto *op = dependence.getDependentOp();
244       LLVM_DEBUG(dbgs() << "\n***Found covering dependence of type "
245                         << getDependenceTypeStr(dt) << ": " << *src << " -> "
246                         << *op << " on " << consumerView);
247       res.push_back(op);
248     }
249   }
250   return res;
251 }
252 
253 bool LinalgDependenceGraph::hasDependenceFrom(
254     LinalgOp srcLinalgOp, LinalgOp dstLinalgOp,
255     ArrayRef<LinalgDependenceGraph::DependenceType> depTypes) const {
256   for (auto dep : depTypes)
257     for (auto dependence : getDependencesInto(dstLinalgOp, dep))
258       if (dependence.getDependentOp() == srcLinalgOp)
259         return true;
260   return false;
261 }
262 
263 bool LinalgDependenceGraph::hasDependentOperationsFrom(
264     LinalgOp linalgOp,
265     ArrayRef<LinalgDependenceGraph::DependenceType> depTypes) const {
266   for (auto dep : depTypes) {
267     if (!getDependencesFrom(linalgOp, dep).empty())
268       return true;
269   }
270   return false;
271 }
272 
273 bool LinalgDependenceGraph::hasDependentOperationsInto(
274     LinalgOp linalgOp,
275     ArrayRef<LinalgDependenceGraph::DependenceType> depTypes) const {
276   for (auto dep : depTypes) {
277     if (!getDependencesInto(linalgOp, dep).empty())
278       return true;
279   }
280   return false;
281 }
282 
283 bool LinalgDependenceGraph::hasDependentOperations(
284     LinalgOp linalgOp, ArrayRef<DependenceType> depTypes) const {
285   return hasDependentOperationsInto(linalgOp, depTypes) ||
286          hasDependentOperationsFrom(linalgOp, depTypes);
287 }
288 
289 SmallVector<LinalgDependenceGraph::LinalgDependenceGraphElem, 2>
290 LinalgDependenceGraph::getDependentOperationsInto(
291     LinalgOp linalgOp, ArrayRef<DependenceType> depTypes) const {
292   SmallVector<LinalgDependenceGraph::LinalgDependenceGraphElem, 2>
293       dependentOperations;
294   for (auto dependenceType : depTypes) {
295     auto dependencies = getDependencesInto(linalgOp, dependenceType);
296     dependentOperations.append(dependencies.begin(), dependencies.end());
297   }
298   return dependentOperations;
299 }
300 
301 SmallVector<LinalgDependenceGraph::LinalgDependenceGraphElem, 2>
302 LinalgDependenceGraph::getDependentOperationsFrom(
303     LinalgOp linalgOp, ArrayRef<DependenceType> depTypes) const {
304   SmallVector<LinalgDependenceGraph::LinalgDependenceGraphElem, 2>
305       dependentOperations;
306   for (auto dependenceType : depTypes) {
307     auto dependencies = getDependencesFrom(linalgOp, dependenceType);
308     dependentOperations.append(dependencies.begin(), dependencies.end());
309   }
310   return dependentOperations;
311 }
312 
313 /// Returns all dependent operations (into and from) given `operation`.
314 SmallVector<LinalgDependenceGraph::LinalgDependenceGraphElem, 2>
315 LinalgDependenceGraph::getDependentOperations(
316     LinalgOp linalgOp, ArrayRef<DependenceType> depTypes) const {
317   SmallVector<LinalgDependenceGraphElem, 2> dependentOperations =
318       getDependentOperationsInto(linalgOp, depTypes);
319   SmallVector<LinalgDependenceGraphElem, 2> t =
320       getDependentOperationsFrom(linalgOp, depTypes);
321   dependentOperations.append(t.begin(), t.end());
322   return dependentOperations;
323 }
324