1 //===- DependenceAnalysis.cpp - Dependence analysis on SSA views ----------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 //
9 // This file implements view-based alias and dependence analyses.
10 //
11 //===----------------------------------------------------------------------===//
12 
13 #include "mlir/Dialect/Linalg/Analysis/DependenceAnalysis.h"
14 #include "mlir/Dialect/Bufferization/IR/Bufferization.h"
15 #include "mlir/Dialect/Linalg/IR/Linalg.h"
16 #include "mlir/IR/BuiltinOps.h"
17 
18 #include "llvm/Support/CommandLine.h"
19 #include "llvm/Support/Debug.h"
20 
21 #define DEBUG_TYPE "linalg-dependence-analysis"
22 
23 using namespace mlir;
24 using namespace mlir::linalg;
25 
26 using llvm::dbgs;
27 
28 Value Aliases::find(Value v) {
29   if (v.isa<BlockArgument>())
30     return v;
31 
32   auto it = aliases.find(v);
33   if (it != aliases.end()) {
34     assert(it->getSecond().getType().isa<BaseMemRefType>() &&
35            "Memref expected");
36     return it->getSecond();
37   }
38 
39   while (true) {
40     if (v.isa<BlockArgument>())
41       return v;
42 
43     Operation *defOp = v.getDefiningOp();
44     if (!defOp)
45       return v;
46 
47     // Treat RegionBranchOpInterfaces like an allocate and don't try to follow
48     // the aliasing further.
49     if (isa<RegionBranchOpInterface>(defOp))
50       return v;
51     if (isa<bufferization::ToMemrefOp>(defOp))
52       return v;
53 
54     if (auto memEffect = dyn_cast<MemoryEffectOpInterface>(defOp)) {
55       // Collect all memory effects on `v`.
56       SmallVector<MemoryEffects::EffectInstance, 1> effects;
57       memEffect.getEffectsOnValue(v, effects);
58 
59       // If we have the 'Allocate' memory effect on `v`, then `v` should be the
60       // original buffer.
61       if (llvm::any_of(
62               effects, [](const MemoryEffects::EffectInstance &instance) {
63                 return isa<MemoryEffects::Allocate>(instance.getEffect());
64               }))
65         return v;
66     }
67 
68     if (auto viewLikeOp = dyn_cast<ViewLikeOpInterface>(defOp)) {
69       auto it =
70           aliases.insert(std::make_pair(v, find(viewLikeOp.getViewSource())));
71       return it.first->second;
72     }
73 
74     llvm::errs() << "View alias analysis reduces to: " << v << "\n";
75     llvm_unreachable("unsupported view alias case");
76   }
77 }
78 
79 StringRef LinalgDependenceGraph::getDependenceTypeStr(DependenceType depType) {
80   switch (depType) {
81   case LinalgDependenceGraph::DependenceType::RAW:
82     return "RAW";
83   case LinalgDependenceGraph::DependenceType::RAR:
84     return "RAR";
85   case LinalgDependenceGraph::DependenceType::WAR:
86     return "WAR";
87   case LinalgDependenceGraph::DependenceType::WAW:
88     return "WAW";
89   default:
90     break;
91   }
92   llvm_unreachable("Unexpected DependenceType");
93 }
94 
95 LinalgDependenceGraph
96 LinalgDependenceGraph::buildDependenceGraph(Aliases &aliases, FuncOp f) {
97   SmallVector<LinalgOp, 8> linalgOps;
98   f.walk([&](LinalgOp op) { linalgOps.push_back(op); });
99   return LinalgDependenceGraph(aliases, linalgOps);
100 }
101 
102 LinalgDependenceGraph::LinalgDependenceGraph(Aliases &aliases,
103                                              ArrayRef<LinalgOp> ops)
104     : aliases(aliases), linalgOps(ops.begin(), ops.end()) {
105   for (const auto &en : llvm::enumerate(linalgOps)) {
106     linalgOpPositions.insert(
107         std::make_pair(en.value().getOperation(), en.index()));
108   }
109   for (unsigned i = 0, e = ops.size(); i < e; ++i) {
110     for (unsigned j = i + 1; j < e; ++j) {
111       addDependencesBetween(ops[i], ops[j]);
112     }
113   }
114 }
115 
116 void LinalgDependenceGraph::addDependenceElem(
117     DependenceType dt, LinalgDependenceGraphElem::OpView indexingOpView,
118     LinalgDependenceGraphElem::OpView dependentOpView) {
119   LLVM_DEBUG(dbgs() << "\nAdd dep type " << getDependenceTypeStr(dt) << ":\t ("
120                     << LinalgDependenceGraphElem::getValue(indexingOpView)
121                     << " @) -> \n\t\t("
122                     << LinalgDependenceGraphElem::getValue(dependentOpView)
123                     << " @)");
124   dependencesFromGraphs[dt][LinalgDependenceGraphElem::getOwner(indexingOpView)]
125       .push_back(
126           LinalgDependenceGraphElem{dependentOpView, indexingOpView, dt});
127   dependencesIntoGraphs[dt]
128                        [LinalgDependenceGraphElem::getOwner(dependentOpView)]
129                            .push_back(LinalgDependenceGraphElem{
130                                indexingOpView, dependentOpView, dt});
131 }
132 
133 LinalgDependenceGraph::dependence_range
134 LinalgDependenceGraph::getDependencesFrom(
135     LinalgOp src, LinalgDependenceGraph::DependenceType dt) const {
136   return getDependencesFrom(src.getOperation(), dt);
137 }
138 
139 LinalgDependenceGraph::dependence_range
140 LinalgDependenceGraph::getDependencesFrom(
141     Operation *src, LinalgDependenceGraph::DependenceType dt) const {
142   auto iter = dependencesFromGraphs[dt].find(src);
143   if (iter == dependencesFromGraphs[dt].end())
144     return llvm::make_range(nullptr, nullptr);
145   return llvm::make_range(iter->second.begin(), iter->second.end());
146 }
147 
148 LinalgDependenceGraph::dependence_range
149 LinalgDependenceGraph::getDependencesInto(
150     LinalgOp dst, LinalgDependenceGraph::DependenceType dt) const {
151   return getDependencesInto(dst.getOperation(), dt);
152 }
153 
154 LinalgDependenceGraph::dependence_range
155 LinalgDependenceGraph::getDependencesInto(
156     Operation *dst, LinalgDependenceGraph::DependenceType dt) const {
157   auto iter = dependencesIntoGraphs[dt].find(dst);
158   if (iter == dependencesIntoGraphs[dt].end())
159     return llvm::make_range(nullptr, nullptr);
160   return llvm::make_range(iter->second.begin(), iter->second.end());
161 }
162 
163 void LinalgDependenceGraph::addDependencesBetween(LinalgOp src, LinalgOp dst) {
164   LLVM_DEBUG(dbgs() << "addDependencesBetween " << *src.getOperation()
165                     << " and " << *dst.getOperation() << "\n");
166   if (src.hasTensorSemantics() && dst.hasTensorSemantics()) {
167     for (OpOperand *dstOpOperand : dst.getInputOperands()) {
168       // Check if the operand is defined by the src.
169       auto definingOp = dstOpOperand->get().getDefiningOp<LinalgOp>();
170       if (definingOp && definingOp == src)
171         addDependenceElem(DependenceType::RAW, dstOpOperand->get(),
172                           dstOpOperand);
173     }
174     for (OpOperand *dstOpOperand : dst.getOutputOperands()) {
175       // Check if the operand is defined by the src.
176       auto definingOp = dstOpOperand->get().getDefiningOp<LinalgOp>();
177       if (definingOp && definingOp == src) {
178         if (dst.isInitTensor(dstOpOperand)) {
179           addDependenceElem(DependenceType::RAW, dstOpOperand->get(),
180                             dstOpOperand);
181         }
182         addDependenceElem(DependenceType::WAW, dstOpOperand->get(),
183                           dstOpOperand);
184       }
185     }
186     return;
187   }
188   assert(src.hasBufferSemantics() && dst.hasBufferSemantics() &&
189          "unhandled dependence tracking for mixed buffer/tensor operations");
190   for (OpOperand *srcOpOperand : src.getOutputBufferOperands()) { // W
191     // RAW graph
192     for (OpOperand *dstOpOperand : dst.getInputBufferOperands())   // R
193       if (aliases.alias(srcOpOperand->get(), dstOpOperand->get())) // RAW alias
194         addDependenceElem(DependenceType::RAW, srcOpOperand, dstOpOperand);
195     // WAW graph
196     for (OpOperand *dstOpOperand : dst.getOutputBufferOperands())  // W
197       if (aliases.alias(srcOpOperand->get(), dstOpOperand->get())) // WAW alias
198         addDependenceElem(DependenceType::WAW, srcOpOperand, dstOpOperand);
199   }
200   for (OpOperand *srcOpOperand : src.getInputBufferOperands()) { // R
201     // RAR graph
202     for (OpOperand *dstOpOperand : dst.getInputBufferOperands())   // R
203       if (aliases.alias(srcOpOperand->get(), dstOpOperand->get())) // RAR alias
204         addDependenceElem(DependenceType::RAR, srcOpOperand, dstOpOperand);
205     // WAR graph
206     for (OpOperand *dstOpOperand : dst.getOutputBufferOperands())  // W
207       if (aliases.alias(srcOpOperand->get(), dstOpOperand->get())) // WAR alias
208         addDependenceElem(DependenceType::WAR, srcOpOperand, dstOpOperand);
209   }
210 }
211 
212 SmallVector<Operation *, 8>
213 LinalgDependenceGraph::findCoveringDependences(LinalgOp srcLinalgOp,
214                                                LinalgOp dstLinalgOp) const {
215   return findOperationsWithCoveringDependences(
216       srcLinalgOp, dstLinalgOp, nullptr,
217       {DependenceType::WAW, DependenceType::WAR, DependenceType::RAW});
218 }
219 
220 SmallVector<Operation *, 8> LinalgDependenceGraph::findCoveringWrites(
221     LinalgOp srcLinalgOp, LinalgOp dstLinalgOp, Value view) const {
222   return findOperationsWithCoveringDependences(
223       srcLinalgOp, dstLinalgOp, view,
224       {DependenceType::WAW, DependenceType::WAR});
225 }
226 
227 SmallVector<Operation *, 8> LinalgDependenceGraph::findCoveringReads(
228     LinalgOp srcLinalgOp, LinalgOp dstLinalgOp, Value view) const {
229   return findOperationsWithCoveringDependences(
230       srcLinalgOp, dstLinalgOp, view,
231       {DependenceType::RAR, DependenceType::RAW});
232 }
233 
234 SmallVector<Operation *, 8>
235 LinalgDependenceGraph::findOperationsWithCoveringDependences(
236     LinalgOp srcLinalgOp, LinalgOp dstLinalgOp, Value view,
237     ArrayRef<DependenceType> types) const {
238   auto *src = srcLinalgOp.getOperation();
239   auto *dst = dstLinalgOp.getOperation();
240   auto srcPos = linalgOpPositions.lookup(src);
241   auto dstPos = linalgOpPositions.lookup(dst);
242   assert(srcPos < dstPos && "expected dst after src in IR traversal order");
243 
244   SmallVector<Operation *, 8> res;
245   // Consider an intermediate interleaved `interim` op, look for any dependence
246   // to an aliasing view on a src -> op -> dst path.
247   // TODO: we are not considering paths yet, just interleaved positions.
248   for (auto dt : types) {
249     for (auto dependence : getDependencesFrom(src, dt)) {
250       auto interimPos = linalgOpPositions.lookup(dependence.getDependentOp());
251       // Skip if not interleaved.
252       if (interimPos >= dstPos || interimPos <= srcPos)
253         continue;
254       Value consumerView = dependence.getIndexingValue();
255       if (view && !aliases.alias(view, consumerView))
256         continue;
257       auto *op = dependence.getDependentOp();
258       LLVM_DEBUG(dbgs() << "\n***Found covering dependence of type "
259                         << getDependenceTypeStr(dt) << ": " << *src << " -> "
260                         << *op << " on " << consumerView);
261       res.push_back(op);
262     }
263   }
264   return res;
265 }
266 
267 bool LinalgDependenceGraph::hasDependenceFrom(
268     LinalgOp srcLinalgOp, LinalgOp dstLinalgOp,
269     ArrayRef<LinalgDependenceGraph::DependenceType> depTypes) const {
270   for (auto dep : depTypes)
271     for (auto dependence : getDependencesInto(dstLinalgOp, dep))
272       if (dependence.getDependentOp() == srcLinalgOp)
273         return true;
274   return false;
275 }
276 
277 bool LinalgDependenceGraph::hasDependentOperationsFrom(
278     LinalgOp linalgOp,
279     ArrayRef<LinalgDependenceGraph::DependenceType> depTypes) const {
280   for (auto dep : depTypes) {
281     if (!getDependencesFrom(linalgOp, dep).empty())
282       return true;
283   }
284   return false;
285 }
286 
287 bool LinalgDependenceGraph::hasDependentOperationsInto(
288     LinalgOp linalgOp,
289     ArrayRef<LinalgDependenceGraph::DependenceType> depTypes) const {
290   for (auto dep : depTypes) {
291     if (!getDependencesInto(linalgOp, dep).empty())
292       return true;
293   }
294   return false;
295 }
296 
297 bool LinalgDependenceGraph::hasDependentOperations(
298     LinalgOp linalgOp, ArrayRef<DependenceType> depTypes) const {
299   return hasDependentOperationsInto(linalgOp, depTypes) ||
300          hasDependentOperationsFrom(linalgOp, depTypes);
301 }
302 
303 SmallVector<LinalgDependenceGraph::LinalgDependenceGraphElem, 2>
304 LinalgDependenceGraph::getDependentOperationsInto(
305     LinalgOp linalgOp, ArrayRef<DependenceType> depTypes) const {
306   SmallVector<LinalgDependenceGraph::LinalgDependenceGraphElem, 2>
307       dependentOperations;
308   for (auto dependenceType : depTypes) {
309     auto dependencies = getDependencesInto(linalgOp, dependenceType);
310     dependentOperations.append(dependencies.begin(), dependencies.end());
311   }
312   return dependentOperations;
313 }
314 
315 SmallVector<LinalgDependenceGraph::LinalgDependenceGraphElem, 2>
316 LinalgDependenceGraph::getDependentOperationsFrom(
317     LinalgOp linalgOp, ArrayRef<DependenceType> depTypes) const {
318   SmallVector<LinalgDependenceGraph::LinalgDependenceGraphElem, 2>
319       dependentOperations;
320   for (auto dependenceType : depTypes) {
321     auto dependencies = getDependencesFrom(linalgOp, dependenceType);
322     dependentOperations.append(dependencies.begin(), dependencies.end());
323   }
324   return dependentOperations;
325 }
326 
327 /// Returns all dependent operations (into and from) given `operation`.
328 SmallVector<LinalgDependenceGraph::LinalgDependenceGraphElem, 2>
329 LinalgDependenceGraph::getDependentOperations(
330     LinalgOp linalgOp, ArrayRef<DependenceType> depTypes) const {
331   SmallVector<LinalgDependenceGraphElem, 2> dependentOperations =
332       getDependentOperationsInto(linalgOp, depTypes);
333   SmallVector<LinalgDependenceGraphElem, 2> t =
334       getDependentOperationsFrom(linalgOp, depTypes);
335   dependentOperations.append(t.begin(), t.end());
336   return dependentOperations;
337 }
338 
339 void LinalgDependenceGraph::print(raw_ostream &os) const {
340   for (auto dt : {
341            LinalgDependenceGraph::DependenceType::RAW,
342            LinalgDependenceGraph::DependenceType::WAW,
343        }) {
344     const auto &fromGraph = dependencesFromGraphs[dt];
345     for (const auto &it : fromGraph) {
346       os << "[LinalgDependenceGraph] DT " << dt << " from: " << *it.first
347          << ":\n";
348       for (const auto &dep : it.second) {
349         os << "\tDT " << dt << " " << *dep.getDependentOp() << ":\n";
350       }
351     }
352   }
353 }
354 
355 void LinalgDependenceGraph::dump() const { print(llvm::errs()); }
356