1 //===- DependenceAnalysis.cpp - Dependence analysis on SSA views ----------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 //
9 // This file implements view-based alias and dependence analyses.
10 //
11 //===----------------------------------------------------------------------===//
12
13 #include "mlir/Dialect/Linalg/Analysis/DependenceAnalysis.h"
14 #include "mlir/Dialect/Bufferization/IR/Bufferization.h"
15 #include "mlir/Dialect/Func/IR/FuncOps.h"
16 #include "mlir/Dialect/Linalg/IR/Linalg.h"
17 #include "mlir/IR/BuiltinOps.h"
18
19 #include "llvm/Support/CommandLine.h"
20 #include "llvm/Support/Debug.h"
21
22 #define DEBUG_TYPE "linalg-dependence-analysis"
23
24 using namespace mlir;
25 using namespace mlir::linalg;
26
27 using llvm::dbgs;
28
find(Value v)29 Value Aliases::find(Value v) {
30 if (v.isa<BlockArgument>())
31 return v;
32
33 auto it = aliases.find(v);
34 if (it != aliases.end()) {
35 assert(it->getSecond().getType().isa<BaseMemRefType>() &&
36 "Memref expected");
37 return it->getSecond();
38 }
39
40 while (true) {
41 if (v.isa<BlockArgument>())
42 return v;
43
44 Operation *defOp = v.getDefiningOp();
45 if (!defOp)
46 return v;
47
48 // Treat RegionBranchOpInterfaces like an allocate and don't try to follow
49 // the aliasing further.
50 if (isa<RegionBranchOpInterface>(defOp))
51 return v;
52 if (isa<bufferization::ToMemrefOp>(defOp))
53 return v;
54
55 if (auto memEffect = dyn_cast<MemoryEffectOpInterface>(defOp)) {
56 // Collect all memory effects on `v`.
57 SmallVector<MemoryEffects::EffectInstance, 1> effects;
58 memEffect.getEffectsOnValue(v, effects);
59
60 // If we have the 'Allocate' memory effect on `v`, then `v` should be the
61 // original buffer.
62 if (llvm::any_of(
63 effects, [](const MemoryEffects::EffectInstance &instance) {
64 return isa<MemoryEffects::Allocate>(instance.getEffect());
65 }))
66 return v;
67 }
68
69 if (auto viewLikeOp = dyn_cast<ViewLikeOpInterface>(defOp)) {
70 auto it =
71 aliases.insert(std::make_pair(v, find(viewLikeOp.getViewSource())));
72 return it.first->second;
73 }
74
75 llvm::errs() << "View alias analysis reduces to: " << v << "\n";
76 llvm_unreachable("unsupported view alias case");
77 }
78 }
79
getDependenceTypeStr(DependenceType depType)80 StringRef LinalgDependenceGraph::getDependenceTypeStr(DependenceType depType) {
81 switch (depType) {
82 case LinalgDependenceGraph::DependenceType::RAW:
83 return "RAW";
84 case LinalgDependenceGraph::DependenceType::RAR:
85 return "RAR";
86 case LinalgDependenceGraph::DependenceType::WAR:
87 return "WAR";
88 case LinalgDependenceGraph::DependenceType::WAW:
89 return "WAW";
90 default:
91 break;
92 }
93 llvm_unreachable("Unexpected DependenceType");
94 }
95
96 LinalgDependenceGraph
buildDependenceGraph(Aliases & aliases,func::FuncOp f)97 LinalgDependenceGraph::buildDependenceGraph(Aliases &aliases, func::FuncOp f) {
98 SmallVector<LinalgOp, 8> linalgOps;
99 f.walk([&](LinalgOp op) { linalgOps.push_back(op); });
100 return LinalgDependenceGraph(aliases, linalgOps);
101 }
102
LinalgDependenceGraph(Aliases & aliases,ArrayRef<LinalgOp> ops)103 LinalgDependenceGraph::LinalgDependenceGraph(Aliases &aliases,
104 ArrayRef<LinalgOp> ops)
105 : aliases(aliases), linalgOps(ops.begin(), ops.end()) {
106 for (const auto &en : llvm::enumerate(linalgOps)) {
107 linalgOpPositions.insert(
108 std::make_pair(en.value().getOperation(), en.index()));
109 }
110 for (unsigned i = 0, e = ops.size(); i < e; ++i) {
111 for (unsigned j = i + 1; j < e; ++j) {
112 addDependencesBetween(ops[i], ops[j]);
113 }
114 }
115 }
116
addDependenceElem(DependenceType dt,LinalgDependenceGraphElem::OpView indexingOpView,LinalgDependenceGraphElem::OpView dependentOpView)117 void LinalgDependenceGraph::addDependenceElem(
118 DependenceType dt, LinalgDependenceGraphElem::OpView indexingOpView,
119 LinalgDependenceGraphElem::OpView dependentOpView) {
120 LLVM_DEBUG(dbgs() << "\nAdd dep type " << getDependenceTypeStr(dt) << ":\t ("
121 << LinalgDependenceGraphElem::getValue(indexingOpView)
122 << " @) -> \n\t\t("
123 << LinalgDependenceGraphElem::getValue(dependentOpView)
124 << " @)");
125 dependencesFromGraphs[dt][LinalgDependenceGraphElem::getOwner(indexingOpView)]
126 .push_back(
127 LinalgDependenceGraphElem{dependentOpView, indexingOpView, dt});
128 dependencesIntoGraphs[dt]
129 [LinalgDependenceGraphElem::getOwner(dependentOpView)]
130 .push_back(LinalgDependenceGraphElem{
131 indexingOpView, dependentOpView, dt});
132 }
133
134 LinalgDependenceGraph::dependence_range
getDependencesFrom(LinalgOp src,LinalgDependenceGraph::DependenceType dt) const135 LinalgDependenceGraph::getDependencesFrom(
136 LinalgOp src, LinalgDependenceGraph::DependenceType dt) const {
137 return getDependencesFrom(src.getOperation(), dt);
138 }
139
140 LinalgDependenceGraph::dependence_range
getDependencesFrom(Operation * src,LinalgDependenceGraph::DependenceType dt) const141 LinalgDependenceGraph::getDependencesFrom(
142 Operation *src, LinalgDependenceGraph::DependenceType dt) const {
143 auto iter = dependencesFromGraphs[dt].find(src);
144 if (iter == dependencesFromGraphs[dt].end())
145 return llvm::make_range(nullptr, nullptr);
146 return llvm::make_range(iter->second.begin(), iter->second.end());
147 }
148
149 LinalgDependenceGraph::dependence_range
getDependencesInto(LinalgOp dst,LinalgDependenceGraph::DependenceType dt) const150 LinalgDependenceGraph::getDependencesInto(
151 LinalgOp dst, LinalgDependenceGraph::DependenceType dt) const {
152 return getDependencesInto(dst.getOperation(), dt);
153 }
154
155 LinalgDependenceGraph::dependence_range
getDependencesInto(Operation * dst,LinalgDependenceGraph::DependenceType dt) const156 LinalgDependenceGraph::getDependencesInto(
157 Operation *dst, LinalgDependenceGraph::DependenceType dt) const {
158 auto iter = dependencesIntoGraphs[dt].find(dst);
159 if (iter == dependencesIntoGraphs[dt].end())
160 return llvm::make_range(nullptr, nullptr);
161 return llvm::make_range(iter->second.begin(), iter->second.end());
162 }
163
addDependencesBetween(LinalgOp src,LinalgOp dst)164 void LinalgDependenceGraph::addDependencesBetween(LinalgOp src, LinalgOp dst) {
165 LLVM_DEBUG(dbgs() << "addDependencesBetween " << *src.getOperation()
166 << " and " << *dst.getOperation() << "\n");
167 if (src.hasTensorSemantics() && dst.hasTensorSemantics()) {
168 for (OpOperand *dstOpOperand : dst.getInputOperands()) {
169 // Check if the operand is defined by the src.
170 auto definingOp = dstOpOperand->get().getDefiningOp<LinalgOp>();
171 if (definingOp && definingOp == src)
172 addDependenceElem(DependenceType::RAW, dstOpOperand->get(),
173 dstOpOperand);
174 }
175 for (OpOperand *dstOpOperand : dst.getOutputOperands()) {
176 // Check if the operand is defined by the src.
177 auto definingOp = dstOpOperand->get().getDefiningOp<LinalgOp>();
178 if (definingOp && definingOp == src) {
179 if (dst.isInitTensor(dstOpOperand)) {
180 addDependenceElem(DependenceType::RAW, dstOpOperand->get(),
181 dstOpOperand);
182 }
183 addDependenceElem(DependenceType::WAW, dstOpOperand->get(),
184 dstOpOperand);
185 }
186 }
187 return;
188 }
189 assert(src.hasBufferSemantics() && dst.hasBufferSemantics() &&
190 "unhandled dependence tracking for mixed buffer/tensor operations");
191 for (OpOperand *srcOpOperand : src.getOutputBufferOperands()) { // W
192 // RAW graph
193 for (OpOperand *dstOpOperand : dst.getInputBufferOperands()) // R
194 if (aliases.alias(srcOpOperand->get(), dstOpOperand->get())) // RAW alias
195 addDependenceElem(DependenceType::RAW, srcOpOperand, dstOpOperand);
196 // WAW graph
197 for (OpOperand *dstOpOperand : dst.getOutputBufferOperands()) // W
198 if (aliases.alias(srcOpOperand->get(), dstOpOperand->get())) // WAW alias
199 addDependenceElem(DependenceType::WAW, srcOpOperand, dstOpOperand);
200 }
201 for (OpOperand *srcOpOperand : src.getInputBufferOperands()) { // R
202 // RAR graph
203 for (OpOperand *dstOpOperand : dst.getInputBufferOperands()) // R
204 if (aliases.alias(srcOpOperand->get(), dstOpOperand->get())) // RAR alias
205 addDependenceElem(DependenceType::RAR, srcOpOperand, dstOpOperand);
206 // WAR graph
207 for (OpOperand *dstOpOperand : dst.getOutputBufferOperands()) // W
208 if (aliases.alias(srcOpOperand->get(), dstOpOperand->get())) // WAR alias
209 addDependenceElem(DependenceType::WAR, srcOpOperand, dstOpOperand);
210 }
211 }
212
213 SmallVector<Operation *, 8>
findCoveringDependences(LinalgOp srcLinalgOp,LinalgOp dstLinalgOp) const214 LinalgDependenceGraph::findCoveringDependences(LinalgOp srcLinalgOp,
215 LinalgOp dstLinalgOp) const {
216 return findOperationsWithCoveringDependences(
217 srcLinalgOp, dstLinalgOp, nullptr,
218 {DependenceType::WAW, DependenceType::WAR, DependenceType::RAW});
219 }
220
findCoveringWrites(LinalgOp srcLinalgOp,LinalgOp dstLinalgOp,Value view) const221 SmallVector<Operation *, 8> LinalgDependenceGraph::findCoveringWrites(
222 LinalgOp srcLinalgOp, LinalgOp dstLinalgOp, Value view) const {
223 return findOperationsWithCoveringDependences(
224 srcLinalgOp, dstLinalgOp, view,
225 {DependenceType::WAW, DependenceType::WAR});
226 }
227
findCoveringReads(LinalgOp srcLinalgOp,LinalgOp dstLinalgOp,Value view) const228 SmallVector<Operation *, 8> LinalgDependenceGraph::findCoveringReads(
229 LinalgOp srcLinalgOp, LinalgOp dstLinalgOp, Value view) const {
230 return findOperationsWithCoveringDependences(
231 srcLinalgOp, dstLinalgOp, view,
232 {DependenceType::RAR, DependenceType::RAW});
233 }
234
235 SmallVector<Operation *, 8>
findOperationsWithCoveringDependences(LinalgOp srcLinalgOp,LinalgOp dstLinalgOp,Value view,ArrayRef<DependenceType> types) const236 LinalgDependenceGraph::findOperationsWithCoveringDependences(
237 LinalgOp srcLinalgOp, LinalgOp dstLinalgOp, Value view,
238 ArrayRef<DependenceType> types) const {
239 auto *src = srcLinalgOp.getOperation();
240 auto *dst = dstLinalgOp.getOperation();
241 auto srcPos = linalgOpPositions.lookup(src);
242 auto dstPos = linalgOpPositions.lookup(dst);
243 assert(srcPos < dstPos && "expected dst after src in IR traversal order");
244
245 SmallVector<Operation *, 8> res;
246 // Consider an intermediate interleaved `interim` op, look for any dependence
247 // to an aliasing view on a src -> op -> dst path.
248 // TODO: we are not considering paths yet, just interleaved positions.
249 for (auto dt : types) {
250 for (auto dependence : getDependencesFrom(src, dt)) {
251 auto interimPos = linalgOpPositions.lookup(dependence.getDependentOp());
252 // Skip if not interleaved.
253 if (interimPos >= dstPos || interimPos <= srcPos)
254 continue;
255 Value consumerView = dependence.getIndexingValue();
256 if (view && !aliases.alias(view, consumerView))
257 continue;
258 auto *op = dependence.getDependentOp();
259 LLVM_DEBUG(dbgs() << "\n***Found covering dependence of type "
260 << getDependenceTypeStr(dt) << ": " << *src << " -> "
261 << *op << " on " << consumerView);
262 res.push_back(op);
263 }
264 }
265 return res;
266 }
267
hasDependenceFrom(LinalgOp srcLinalgOp,LinalgOp dstLinalgOp,ArrayRef<LinalgDependenceGraph::DependenceType> depTypes) const268 bool LinalgDependenceGraph::hasDependenceFrom(
269 LinalgOp srcLinalgOp, LinalgOp dstLinalgOp,
270 ArrayRef<LinalgDependenceGraph::DependenceType> depTypes) const {
271 for (auto dep : depTypes)
272 for (auto dependence : getDependencesInto(dstLinalgOp, dep))
273 if (dependence.getDependentOp() == srcLinalgOp)
274 return true;
275 return false;
276 }
277
hasDependentOperationsFrom(LinalgOp linalgOp,ArrayRef<LinalgDependenceGraph::DependenceType> depTypes) const278 bool LinalgDependenceGraph::hasDependentOperationsFrom(
279 LinalgOp linalgOp,
280 ArrayRef<LinalgDependenceGraph::DependenceType> depTypes) const {
281 for (auto dep : depTypes) {
282 if (!getDependencesFrom(linalgOp, dep).empty())
283 return true;
284 }
285 return false;
286 }
287
hasDependentOperationsInto(LinalgOp linalgOp,ArrayRef<LinalgDependenceGraph::DependenceType> depTypes) const288 bool LinalgDependenceGraph::hasDependentOperationsInto(
289 LinalgOp linalgOp,
290 ArrayRef<LinalgDependenceGraph::DependenceType> depTypes) const {
291 for (auto dep : depTypes) {
292 if (!getDependencesInto(linalgOp, dep).empty())
293 return true;
294 }
295 return false;
296 }
297
hasDependentOperations(LinalgOp linalgOp,ArrayRef<DependenceType> depTypes) const298 bool LinalgDependenceGraph::hasDependentOperations(
299 LinalgOp linalgOp, ArrayRef<DependenceType> depTypes) const {
300 return hasDependentOperationsInto(linalgOp, depTypes) ||
301 hasDependentOperationsFrom(linalgOp, depTypes);
302 }
303
304 SmallVector<LinalgDependenceGraph::LinalgDependenceGraphElem, 2>
getDependentOperationsInto(LinalgOp linalgOp,ArrayRef<DependenceType> depTypes) const305 LinalgDependenceGraph::getDependentOperationsInto(
306 LinalgOp linalgOp, ArrayRef<DependenceType> depTypes) const {
307 SmallVector<LinalgDependenceGraph::LinalgDependenceGraphElem, 2>
308 dependentOperations;
309 for (auto dependenceType : depTypes) {
310 auto dependencies = getDependencesInto(linalgOp, dependenceType);
311 dependentOperations.append(dependencies.begin(), dependencies.end());
312 }
313 return dependentOperations;
314 }
315
316 SmallVector<LinalgDependenceGraph::LinalgDependenceGraphElem, 2>
getDependentOperationsFrom(LinalgOp linalgOp,ArrayRef<DependenceType> depTypes) const317 LinalgDependenceGraph::getDependentOperationsFrom(
318 LinalgOp linalgOp, ArrayRef<DependenceType> depTypes) const {
319 SmallVector<LinalgDependenceGraph::LinalgDependenceGraphElem, 2>
320 dependentOperations;
321 for (auto dependenceType : depTypes) {
322 auto dependencies = getDependencesFrom(linalgOp, dependenceType);
323 dependentOperations.append(dependencies.begin(), dependencies.end());
324 }
325 return dependentOperations;
326 }
327
328 /// Returns all dependent operations (into and from) given `operation`.
329 SmallVector<LinalgDependenceGraph::LinalgDependenceGraphElem, 2>
getDependentOperations(LinalgOp linalgOp,ArrayRef<DependenceType> depTypes) const330 LinalgDependenceGraph::getDependentOperations(
331 LinalgOp linalgOp, ArrayRef<DependenceType> depTypes) const {
332 SmallVector<LinalgDependenceGraphElem, 2> dependentOperations =
333 getDependentOperationsInto(linalgOp, depTypes);
334 SmallVector<LinalgDependenceGraphElem, 2> t =
335 getDependentOperationsFrom(linalgOp, depTypes);
336 dependentOperations.append(t.begin(), t.end());
337 return dependentOperations;
338 }
339
print(raw_ostream & os) const340 void LinalgDependenceGraph::print(raw_ostream &os) const {
341 for (auto dt : {
342 LinalgDependenceGraph::DependenceType::RAW,
343 LinalgDependenceGraph::DependenceType::WAW,
344 }) {
345 const auto &fromGraph = dependencesFromGraphs[dt];
346 for (const auto &it : fromGraph) {
347 os << "[LinalgDependenceGraph] DT " << dt << " from: " << *it.first
348 << ":\n";
349 for (const auto &dep : it.second) {
350 os << "\tDT " << dt << " " << *dep.getDependentOp() << ":\n";
351 }
352 }
353 }
354 }
355
dump() const356 void LinalgDependenceGraph::dump() const { print(llvm::errs()); }
357