1 //===- DependenceAnalysis.cpp - Dependence analysis on SSA views ----------===// 2 // 3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. 4 // See https://llvm.org/LICENSE.txt for license information. 5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception 6 // 7 //===----------------------------------------------------------------------===// 8 // 9 // This file implements view-based alias and dependence analyses. 10 // 11 //===----------------------------------------------------------------------===// 12 13 #include "mlir/Dialect/Linalg/Analysis/DependenceAnalysis.h" 14 #include "mlir/Dialect/Bufferization/IR/Bufferization.h" 15 #include "mlir/Dialect/Linalg/IR/Linalg.h" 16 #include "mlir/IR/BuiltinOps.h" 17 18 #include "llvm/Support/CommandLine.h" 19 #include "llvm/Support/Debug.h" 20 21 #define DEBUG_TYPE "linalg-dependence-analysis" 22 23 using namespace mlir; 24 using namespace mlir::linalg; 25 26 using llvm::dbgs; 27 28 Value Aliases::find(Value v) { 29 if (v.isa<BlockArgument>()) 30 return v; 31 32 auto it = aliases.find(v); 33 if (it != aliases.end()) { 34 assert(it->getSecond().getType().isa<BaseMemRefType>() && 35 "Memref expected"); 36 return it->getSecond(); 37 } 38 39 while (true) { 40 if (v.isa<BlockArgument>()) 41 return v; 42 43 Operation *defOp = v.getDefiningOp(); 44 if (!defOp) 45 return v; 46 47 // Treat RegionBranchOpInterfaces like an allocate and don't try to follow 48 // the aliasing further. 49 if (isa<RegionBranchOpInterface>(defOp)) 50 return v; 51 if (isa<bufferization::ToMemrefOp>(defOp)) 52 return v; 53 54 if (auto memEffect = dyn_cast<MemoryEffectOpInterface>(defOp)) { 55 // Collect all memory effects on `v`. 56 SmallVector<MemoryEffects::EffectInstance, 1> effects; 57 memEffect.getEffectsOnValue(v, effects); 58 59 // If we have the 'Allocate' memory effect on `v`, then `v` should be the 60 // original buffer. 61 if (llvm::any_of( 62 effects, [](const MemoryEffects::EffectInstance &instance) { 63 return isa<MemoryEffects::Allocate>(instance.getEffect()); 64 })) 65 return v; 66 } 67 68 if (auto viewLikeOp = dyn_cast<ViewLikeOpInterface>(defOp)) { 69 auto it = 70 aliases.insert(std::make_pair(v, find(viewLikeOp.getViewSource()))); 71 return it.first->second; 72 } 73 74 llvm::errs() << "View alias analysis reduces to: " << v << "\n"; 75 llvm_unreachable("unsupported view alias case"); 76 } 77 } 78 79 StringRef LinalgDependenceGraph::getDependenceTypeStr(DependenceType depType) { 80 switch (depType) { 81 case LinalgDependenceGraph::DependenceType::RAW: 82 return "RAW"; 83 case LinalgDependenceGraph::DependenceType::RAR: 84 return "RAR"; 85 case LinalgDependenceGraph::DependenceType::WAR: 86 return "WAR"; 87 case LinalgDependenceGraph::DependenceType::WAW: 88 return "WAW"; 89 default: 90 break; 91 } 92 llvm_unreachable("Unexpected DependenceType"); 93 } 94 95 LinalgDependenceGraph 96 LinalgDependenceGraph::buildDependenceGraph(Aliases &aliases, FuncOp f) { 97 SmallVector<LinalgOp, 8> linalgOps; 98 f.walk([&](LinalgOp op) { linalgOps.push_back(op); }); 99 return LinalgDependenceGraph(aliases, linalgOps); 100 } 101 102 LinalgDependenceGraph::LinalgDependenceGraph(Aliases &aliases, 103 ArrayRef<LinalgOp> ops) 104 : aliases(aliases), linalgOps(ops.begin(), ops.end()) { 105 for (const auto &en : llvm::enumerate(linalgOps)) { 106 linalgOpPositions.insert( 107 std::make_pair(en.value().getOperation(), en.index())); 108 } 109 for (unsigned i = 0, e = ops.size(); i < e; ++i) { 110 for (unsigned j = i + 1; j < e; ++j) { 111 addDependencesBetween(ops[i], ops[j]); 112 } 113 } 114 } 115 116 void LinalgDependenceGraph::addDependenceElem( 117 DependenceType dt, LinalgDependenceGraphElem::OpView indexingOpView, 118 LinalgDependenceGraphElem::OpView dependentOpView) { 119 LLVM_DEBUG(dbgs() << "\nAdd dep type " << getDependenceTypeStr(dt) << ":\t (" 120 << LinalgDependenceGraphElem::getValue(indexingOpView) 121 << " @) -> \n\t\t(" 122 << LinalgDependenceGraphElem::getValue(dependentOpView) 123 << " @)"); 124 dependencesFromGraphs[dt][LinalgDependenceGraphElem::getOwner(indexingOpView)] 125 .push_back( 126 LinalgDependenceGraphElem{dependentOpView, indexingOpView, dt}); 127 dependencesIntoGraphs[dt] 128 [LinalgDependenceGraphElem::getOwner(dependentOpView)] 129 .push_back(LinalgDependenceGraphElem{ 130 indexingOpView, dependentOpView, dt}); 131 } 132 133 LinalgDependenceGraph::dependence_range 134 LinalgDependenceGraph::getDependencesFrom( 135 LinalgOp src, LinalgDependenceGraph::DependenceType dt) const { 136 return getDependencesFrom(src.getOperation(), dt); 137 } 138 139 LinalgDependenceGraph::dependence_range 140 LinalgDependenceGraph::getDependencesFrom( 141 Operation *src, LinalgDependenceGraph::DependenceType dt) const { 142 auto iter = dependencesFromGraphs[dt].find(src); 143 if (iter == dependencesFromGraphs[dt].end()) 144 return llvm::make_range(nullptr, nullptr); 145 return llvm::make_range(iter->second.begin(), iter->second.end()); 146 } 147 148 LinalgDependenceGraph::dependence_range 149 LinalgDependenceGraph::getDependencesInto( 150 LinalgOp dst, LinalgDependenceGraph::DependenceType dt) const { 151 return getDependencesInto(dst.getOperation(), dt); 152 } 153 154 LinalgDependenceGraph::dependence_range 155 LinalgDependenceGraph::getDependencesInto( 156 Operation *dst, LinalgDependenceGraph::DependenceType dt) const { 157 auto iter = dependencesIntoGraphs[dt].find(dst); 158 if (iter == dependencesIntoGraphs[dt].end()) 159 return llvm::make_range(nullptr, nullptr); 160 return llvm::make_range(iter->second.begin(), iter->second.end()); 161 } 162 163 void LinalgDependenceGraph::addDependencesBetween(LinalgOp src, LinalgOp dst) { 164 LLVM_DEBUG(dbgs() << "addDependencesBetween " << *src.getOperation() 165 << " and " << *dst.getOperation() << "\n"); 166 if (src.hasTensorSemantics() && dst.hasTensorSemantics()) { 167 for (OpOperand *dstOpOperand : dst.getInputOperands()) { 168 // Check if the operand is defined by the src. 169 auto definingOp = dstOpOperand->get().getDefiningOp<LinalgOp>(); 170 if (definingOp && definingOp == src) 171 addDependenceElem(DependenceType::RAW, dstOpOperand->get(), 172 dstOpOperand); 173 } 174 for (OpOperand *dstOpOperand : dst.getOutputOperands()) { 175 // Check if the operand is defined by the src. 176 auto definingOp = dstOpOperand->get().getDefiningOp<LinalgOp>(); 177 if (definingOp && definingOp == src) { 178 if (dst.isInitTensor(dstOpOperand)) { 179 addDependenceElem(DependenceType::RAW, dstOpOperand->get(), 180 dstOpOperand); 181 } 182 addDependenceElem(DependenceType::WAW, dstOpOperand->get(), 183 dstOpOperand); 184 } 185 } 186 return; 187 } 188 assert(src.hasBufferSemantics() && dst.hasBufferSemantics() && 189 "unhandled dependence tracking for mixed buffer/tensor operations"); 190 for (OpOperand *srcOpOperand : src.getOutputBufferOperands()) { // W 191 // RAW graph 192 for (OpOperand *dstOpOperand : dst.getInputBufferOperands()) // R 193 if (aliases.alias(srcOpOperand->get(), dstOpOperand->get())) // RAW alias 194 addDependenceElem(DependenceType::RAW, srcOpOperand, dstOpOperand); 195 // WAW graph 196 for (OpOperand *dstOpOperand : dst.getOutputBufferOperands()) // W 197 if (aliases.alias(srcOpOperand->get(), dstOpOperand->get())) // WAW alias 198 addDependenceElem(DependenceType::WAW, srcOpOperand, dstOpOperand); 199 } 200 for (OpOperand *srcOpOperand : src.getInputBufferOperands()) { // R 201 // RAR graph 202 for (OpOperand *dstOpOperand : dst.getInputBufferOperands()) // R 203 if (aliases.alias(srcOpOperand->get(), dstOpOperand->get())) // RAR alias 204 addDependenceElem(DependenceType::RAR, srcOpOperand, dstOpOperand); 205 // WAR graph 206 for (OpOperand *dstOpOperand : dst.getOutputBufferOperands()) // W 207 if (aliases.alias(srcOpOperand->get(), dstOpOperand->get())) // WAR alias 208 addDependenceElem(DependenceType::WAR, srcOpOperand, dstOpOperand); 209 } 210 } 211 212 SmallVector<Operation *, 8> 213 LinalgDependenceGraph::findCoveringDependences(LinalgOp srcLinalgOp, 214 LinalgOp dstLinalgOp) const { 215 return findOperationsWithCoveringDependences( 216 srcLinalgOp, dstLinalgOp, nullptr, 217 {DependenceType::WAW, DependenceType::WAR, DependenceType::RAW}); 218 } 219 220 SmallVector<Operation *, 8> LinalgDependenceGraph::findCoveringWrites( 221 LinalgOp srcLinalgOp, LinalgOp dstLinalgOp, Value view) const { 222 return findOperationsWithCoveringDependences( 223 srcLinalgOp, dstLinalgOp, view, 224 {DependenceType::WAW, DependenceType::WAR}); 225 } 226 227 SmallVector<Operation *, 8> LinalgDependenceGraph::findCoveringReads( 228 LinalgOp srcLinalgOp, LinalgOp dstLinalgOp, Value view) const { 229 return findOperationsWithCoveringDependences( 230 srcLinalgOp, dstLinalgOp, view, 231 {DependenceType::RAR, DependenceType::RAW}); 232 } 233 234 SmallVector<Operation *, 8> 235 LinalgDependenceGraph::findOperationsWithCoveringDependences( 236 LinalgOp srcLinalgOp, LinalgOp dstLinalgOp, Value view, 237 ArrayRef<DependenceType> types) const { 238 auto *src = srcLinalgOp.getOperation(); 239 auto *dst = dstLinalgOp.getOperation(); 240 auto srcPos = linalgOpPositions.lookup(src); 241 auto dstPos = linalgOpPositions.lookup(dst); 242 assert(srcPos < dstPos && "expected dst after src in IR traversal order"); 243 244 SmallVector<Operation *, 8> res; 245 // Consider an intermediate interleaved `interim` op, look for any dependence 246 // to an aliasing view on a src -> op -> dst path. 247 // TODO: we are not considering paths yet, just interleaved positions. 248 for (auto dt : types) { 249 for (auto dependence : getDependencesFrom(src, dt)) { 250 auto interimPos = linalgOpPositions.lookup(dependence.getDependentOp()); 251 // Skip if not interleaved. 252 if (interimPos >= dstPos || interimPos <= srcPos) 253 continue; 254 Value consumerView = dependence.getIndexingValue(); 255 if (view && !aliases.alias(view, consumerView)) 256 continue; 257 auto *op = dependence.getDependentOp(); 258 LLVM_DEBUG(dbgs() << "\n***Found covering dependence of type " 259 << getDependenceTypeStr(dt) << ": " << *src << " -> " 260 << *op << " on " << consumerView); 261 res.push_back(op); 262 } 263 } 264 return res; 265 } 266 267 bool LinalgDependenceGraph::hasDependenceFrom( 268 LinalgOp srcLinalgOp, LinalgOp dstLinalgOp, 269 ArrayRef<LinalgDependenceGraph::DependenceType> depTypes) const { 270 for (auto dep : depTypes) 271 for (auto dependence : getDependencesInto(dstLinalgOp, dep)) 272 if (dependence.getDependentOp() == srcLinalgOp) 273 return true; 274 return false; 275 } 276 277 bool LinalgDependenceGraph::hasDependentOperationsFrom( 278 LinalgOp linalgOp, 279 ArrayRef<LinalgDependenceGraph::DependenceType> depTypes) const { 280 for (auto dep : depTypes) { 281 if (!getDependencesFrom(linalgOp, dep).empty()) 282 return true; 283 } 284 return false; 285 } 286 287 bool LinalgDependenceGraph::hasDependentOperationsInto( 288 LinalgOp linalgOp, 289 ArrayRef<LinalgDependenceGraph::DependenceType> depTypes) const { 290 for (auto dep : depTypes) { 291 if (!getDependencesInto(linalgOp, dep).empty()) 292 return true; 293 } 294 return false; 295 } 296 297 bool LinalgDependenceGraph::hasDependentOperations( 298 LinalgOp linalgOp, ArrayRef<DependenceType> depTypes) const { 299 return hasDependentOperationsInto(linalgOp, depTypes) || 300 hasDependentOperationsFrom(linalgOp, depTypes); 301 } 302 303 SmallVector<LinalgDependenceGraph::LinalgDependenceGraphElem, 2> 304 LinalgDependenceGraph::getDependentOperationsInto( 305 LinalgOp linalgOp, ArrayRef<DependenceType> depTypes) const { 306 SmallVector<LinalgDependenceGraph::LinalgDependenceGraphElem, 2> 307 dependentOperations; 308 for (auto dependenceType : depTypes) { 309 auto dependencies = getDependencesInto(linalgOp, dependenceType); 310 dependentOperations.append(dependencies.begin(), dependencies.end()); 311 } 312 return dependentOperations; 313 } 314 315 SmallVector<LinalgDependenceGraph::LinalgDependenceGraphElem, 2> 316 LinalgDependenceGraph::getDependentOperationsFrom( 317 LinalgOp linalgOp, ArrayRef<DependenceType> depTypes) const { 318 SmallVector<LinalgDependenceGraph::LinalgDependenceGraphElem, 2> 319 dependentOperations; 320 for (auto dependenceType : depTypes) { 321 auto dependencies = getDependencesFrom(linalgOp, dependenceType); 322 dependentOperations.append(dependencies.begin(), dependencies.end()); 323 } 324 return dependentOperations; 325 } 326 327 /// Returns all dependent operations (into and from) given `operation`. 328 SmallVector<LinalgDependenceGraph::LinalgDependenceGraphElem, 2> 329 LinalgDependenceGraph::getDependentOperations( 330 LinalgOp linalgOp, ArrayRef<DependenceType> depTypes) const { 331 SmallVector<LinalgDependenceGraphElem, 2> dependentOperations = 332 getDependentOperationsInto(linalgOp, depTypes); 333 SmallVector<LinalgDependenceGraphElem, 2> t = 334 getDependentOperationsFrom(linalgOp, depTypes); 335 dependentOperations.append(t.begin(), t.end()); 336 return dependentOperations; 337 } 338 339 void LinalgDependenceGraph::print(raw_ostream &os) const { 340 for (auto dt : { 341 LinalgDependenceGraph::DependenceType::RAW, 342 LinalgDependenceGraph::DependenceType::WAW, 343 }) { 344 const auto &fromGraph = dependencesFromGraphs[dt]; 345 for (const auto &it : fromGraph) { 346 os << "[LinalgDependenceGraph] DT " << dt << " from: " << *it.first 347 << ":\n"; 348 for (const auto &dep : it.second) { 349 os << "\tDT " << dt << " " << *dep.getDependentOp() << ":\n"; 350 } 351 } 352 } 353 } 354 355 void LinalgDependenceGraph::dump() const { print(llvm::errs()); } 356