1 //===- DependenceAnalysis.cpp - Dependence analysis on SSA views ----------===// 2 // 3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. 4 // See https://llvm.org/LICENSE.txt for license information. 5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception 6 // 7 //===----------------------------------------------------------------------===// 8 // 9 // This file implements view-based alias and dependence analyses. 10 // 11 //===----------------------------------------------------------------------===// 12 13 #include "mlir/Dialect/Linalg/Analysis/DependenceAnalysis.h" 14 #include "mlir/Dialect/Bufferization/IR/Bufferization.h" 15 #include "mlir/Dialect/Func/IR/FuncOps.h" 16 #include "mlir/Dialect/Linalg/IR/Linalg.h" 17 #include "mlir/IR/BuiltinOps.h" 18 19 #include "llvm/Support/CommandLine.h" 20 #include "llvm/Support/Debug.h" 21 22 #define DEBUG_TYPE "linalg-dependence-analysis" 23 24 using namespace mlir; 25 using namespace mlir::linalg; 26 27 using llvm::dbgs; 28 29 Value Aliases::find(Value v) { 30 if (v.isa<BlockArgument>()) 31 return v; 32 33 auto it = aliases.find(v); 34 if (it != aliases.end()) { 35 assert(it->getSecond().getType().isa<BaseMemRefType>() && 36 "Memref expected"); 37 return it->getSecond(); 38 } 39 40 while (true) { 41 if (v.isa<BlockArgument>()) 42 return v; 43 44 Operation *defOp = v.getDefiningOp(); 45 if (!defOp) 46 return v; 47 48 // Treat RegionBranchOpInterfaces like an allocate and don't try to follow 49 // the aliasing further. 50 if (isa<RegionBranchOpInterface>(defOp)) 51 return v; 52 if (isa<bufferization::ToMemrefOp>(defOp)) 53 return v; 54 55 if (auto memEffect = dyn_cast<MemoryEffectOpInterface>(defOp)) { 56 // Collect all memory effects on `v`. 57 SmallVector<MemoryEffects::EffectInstance, 1> effects; 58 memEffect.getEffectsOnValue(v, effects); 59 60 // If we have the 'Allocate' memory effect on `v`, then `v` should be the 61 // original buffer. 62 if (llvm::any_of( 63 effects, [](const MemoryEffects::EffectInstance &instance) { 64 return isa<MemoryEffects::Allocate>(instance.getEffect()); 65 })) 66 return v; 67 } 68 69 if (auto viewLikeOp = dyn_cast<ViewLikeOpInterface>(defOp)) { 70 auto it = 71 aliases.insert(std::make_pair(v, find(viewLikeOp.getViewSource()))); 72 return it.first->second; 73 } 74 75 llvm::errs() << "View alias analysis reduces to: " << v << "\n"; 76 llvm_unreachable("unsupported view alias case"); 77 } 78 } 79 80 StringRef LinalgDependenceGraph::getDependenceTypeStr(DependenceType depType) { 81 switch (depType) { 82 case LinalgDependenceGraph::DependenceType::RAW: 83 return "RAW"; 84 case LinalgDependenceGraph::DependenceType::RAR: 85 return "RAR"; 86 case LinalgDependenceGraph::DependenceType::WAR: 87 return "WAR"; 88 case LinalgDependenceGraph::DependenceType::WAW: 89 return "WAW"; 90 default: 91 break; 92 } 93 llvm_unreachable("Unexpected DependenceType"); 94 } 95 96 LinalgDependenceGraph 97 LinalgDependenceGraph::buildDependenceGraph(Aliases &aliases, func::FuncOp f) { 98 SmallVector<LinalgOp, 8> linalgOps; 99 f.walk([&](LinalgOp op) { linalgOps.push_back(op); }); 100 return LinalgDependenceGraph(aliases, linalgOps); 101 } 102 103 LinalgDependenceGraph::LinalgDependenceGraph(Aliases &aliases, 104 ArrayRef<LinalgOp> ops) 105 : aliases(aliases), linalgOps(ops.begin(), ops.end()) { 106 for (const auto &en : llvm::enumerate(linalgOps)) { 107 linalgOpPositions.insert( 108 std::make_pair(en.value().getOperation(), en.index())); 109 } 110 for (unsigned i = 0, e = ops.size(); i < e; ++i) { 111 for (unsigned j = i + 1; j < e; ++j) { 112 addDependencesBetween(ops[i], ops[j]); 113 } 114 } 115 } 116 117 void LinalgDependenceGraph::addDependenceElem( 118 DependenceType dt, LinalgDependenceGraphElem::OpView indexingOpView, 119 LinalgDependenceGraphElem::OpView dependentOpView) { 120 LLVM_DEBUG(dbgs() << "\nAdd dep type " << getDependenceTypeStr(dt) << ":\t (" 121 << LinalgDependenceGraphElem::getValue(indexingOpView) 122 << " @) -> \n\t\t(" 123 << LinalgDependenceGraphElem::getValue(dependentOpView) 124 << " @)"); 125 dependencesFromGraphs[dt][LinalgDependenceGraphElem::getOwner(indexingOpView)] 126 .push_back( 127 LinalgDependenceGraphElem{dependentOpView, indexingOpView, dt}); 128 dependencesIntoGraphs[dt] 129 [LinalgDependenceGraphElem::getOwner(dependentOpView)] 130 .push_back(LinalgDependenceGraphElem{ 131 indexingOpView, dependentOpView, dt}); 132 } 133 134 LinalgDependenceGraph::dependence_range 135 LinalgDependenceGraph::getDependencesFrom( 136 LinalgOp src, LinalgDependenceGraph::DependenceType dt) const { 137 return getDependencesFrom(src.getOperation(), dt); 138 } 139 140 LinalgDependenceGraph::dependence_range 141 LinalgDependenceGraph::getDependencesFrom( 142 Operation *src, LinalgDependenceGraph::DependenceType dt) const { 143 auto iter = dependencesFromGraphs[dt].find(src); 144 if (iter == dependencesFromGraphs[dt].end()) 145 return llvm::make_range(nullptr, nullptr); 146 return llvm::make_range(iter->second.begin(), iter->second.end()); 147 } 148 149 LinalgDependenceGraph::dependence_range 150 LinalgDependenceGraph::getDependencesInto( 151 LinalgOp dst, LinalgDependenceGraph::DependenceType dt) const { 152 return getDependencesInto(dst.getOperation(), dt); 153 } 154 155 LinalgDependenceGraph::dependence_range 156 LinalgDependenceGraph::getDependencesInto( 157 Operation *dst, LinalgDependenceGraph::DependenceType dt) const { 158 auto iter = dependencesIntoGraphs[dt].find(dst); 159 if (iter == dependencesIntoGraphs[dt].end()) 160 return llvm::make_range(nullptr, nullptr); 161 return llvm::make_range(iter->second.begin(), iter->second.end()); 162 } 163 164 void LinalgDependenceGraph::addDependencesBetween(LinalgOp src, LinalgOp dst) { 165 LLVM_DEBUG(dbgs() << "addDependencesBetween " << *src.getOperation() 166 << " and " << *dst.getOperation() << "\n"); 167 if (src.hasTensorSemantics() && dst.hasTensorSemantics()) { 168 for (OpOperand *dstOpOperand : dst.getInputOperands()) { 169 // Check if the operand is defined by the src. 170 auto definingOp = dstOpOperand->get().getDefiningOp<LinalgOp>(); 171 if (definingOp && definingOp == src) 172 addDependenceElem(DependenceType::RAW, dstOpOperand->get(), 173 dstOpOperand); 174 } 175 for (OpOperand *dstOpOperand : dst.getOutputOperands()) { 176 // Check if the operand is defined by the src. 177 auto definingOp = dstOpOperand->get().getDefiningOp<LinalgOp>(); 178 if (definingOp && definingOp == src) { 179 if (dst.isInitTensor(dstOpOperand)) { 180 addDependenceElem(DependenceType::RAW, dstOpOperand->get(), 181 dstOpOperand); 182 } 183 addDependenceElem(DependenceType::WAW, dstOpOperand->get(), 184 dstOpOperand); 185 } 186 } 187 return; 188 } 189 assert(src.hasBufferSemantics() && dst.hasBufferSemantics() && 190 "unhandled dependence tracking for mixed buffer/tensor operations"); 191 for (OpOperand *srcOpOperand : src.getOutputBufferOperands()) { // W 192 // RAW graph 193 for (OpOperand *dstOpOperand : dst.getInputBufferOperands()) // R 194 if (aliases.alias(srcOpOperand->get(), dstOpOperand->get())) // RAW alias 195 addDependenceElem(DependenceType::RAW, srcOpOperand, dstOpOperand); 196 // WAW graph 197 for (OpOperand *dstOpOperand : dst.getOutputBufferOperands()) // W 198 if (aliases.alias(srcOpOperand->get(), dstOpOperand->get())) // WAW alias 199 addDependenceElem(DependenceType::WAW, srcOpOperand, dstOpOperand); 200 } 201 for (OpOperand *srcOpOperand : src.getInputBufferOperands()) { // R 202 // RAR graph 203 for (OpOperand *dstOpOperand : dst.getInputBufferOperands()) // R 204 if (aliases.alias(srcOpOperand->get(), dstOpOperand->get())) // RAR alias 205 addDependenceElem(DependenceType::RAR, srcOpOperand, dstOpOperand); 206 // WAR graph 207 for (OpOperand *dstOpOperand : dst.getOutputBufferOperands()) // W 208 if (aliases.alias(srcOpOperand->get(), dstOpOperand->get())) // WAR alias 209 addDependenceElem(DependenceType::WAR, srcOpOperand, dstOpOperand); 210 } 211 } 212 213 SmallVector<Operation *, 8> 214 LinalgDependenceGraph::findCoveringDependences(LinalgOp srcLinalgOp, 215 LinalgOp dstLinalgOp) const { 216 return findOperationsWithCoveringDependences( 217 srcLinalgOp, dstLinalgOp, nullptr, 218 {DependenceType::WAW, DependenceType::WAR, DependenceType::RAW}); 219 } 220 221 SmallVector<Operation *, 8> LinalgDependenceGraph::findCoveringWrites( 222 LinalgOp srcLinalgOp, LinalgOp dstLinalgOp, Value view) const { 223 return findOperationsWithCoveringDependences( 224 srcLinalgOp, dstLinalgOp, view, 225 {DependenceType::WAW, DependenceType::WAR}); 226 } 227 228 SmallVector<Operation *, 8> LinalgDependenceGraph::findCoveringReads( 229 LinalgOp srcLinalgOp, LinalgOp dstLinalgOp, Value view) const { 230 return findOperationsWithCoveringDependences( 231 srcLinalgOp, dstLinalgOp, view, 232 {DependenceType::RAR, DependenceType::RAW}); 233 } 234 235 SmallVector<Operation *, 8> 236 LinalgDependenceGraph::findOperationsWithCoveringDependences( 237 LinalgOp srcLinalgOp, LinalgOp dstLinalgOp, Value view, 238 ArrayRef<DependenceType> types) const { 239 auto *src = srcLinalgOp.getOperation(); 240 auto *dst = dstLinalgOp.getOperation(); 241 auto srcPos = linalgOpPositions.lookup(src); 242 auto dstPos = linalgOpPositions.lookup(dst); 243 assert(srcPos < dstPos && "expected dst after src in IR traversal order"); 244 245 SmallVector<Operation *, 8> res; 246 // Consider an intermediate interleaved `interim` op, look for any dependence 247 // to an aliasing view on a src -> op -> dst path. 248 // TODO: we are not considering paths yet, just interleaved positions. 249 for (auto dt : types) { 250 for (auto dependence : getDependencesFrom(src, dt)) { 251 auto interimPos = linalgOpPositions.lookup(dependence.getDependentOp()); 252 // Skip if not interleaved. 253 if (interimPos >= dstPos || interimPos <= srcPos) 254 continue; 255 Value consumerView = dependence.getIndexingValue(); 256 if (view && !aliases.alias(view, consumerView)) 257 continue; 258 auto *op = dependence.getDependentOp(); 259 LLVM_DEBUG(dbgs() << "\n***Found covering dependence of type " 260 << getDependenceTypeStr(dt) << ": " << *src << " -> " 261 << *op << " on " << consumerView); 262 res.push_back(op); 263 } 264 } 265 return res; 266 } 267 268 bool LinalgDependenceGraph::hasDependenceFrom( 269 LinalgOp srcLinalgOp, LinalgOp dstLinalgOp, 270 ArrayRef<LinalgDependenceGraph::DependenceType> depTypes) const { 271 for (auto dep : depTypes) 272 for (auto dependence : getDependencesInto(dstLinalgOp, dep)) 273 if (dependence.getDependentOp() == srcLinalgOp) 274 return true; 275 return false; 276 } 277 278 bool LinalgDependenceGraph::hasDependentOperationsFrom( 279 LinalgOp linalgOp, 280 ArrayRef<LinalgDependenceGraph::DependenceType> depTypes) const { 281 for (auto dep : depTypes) { 282 if (!getDependencesFrom(linalgOp, dep).empty()) 283 return true; 284 } 285 return false; 286 } 287 288 bool LinalgDependenceGraph::hasDependentOperationsInto( 289 LinalgOp linalgOp, 290 ArrayRef<LinalgDependenceGraph::DependenceType> depTypes) const { 291 for (auto dep : depTypes) { 292 if (!getDependencesInto(linalgOp, dep).empty()) 293 return true; 294 } 295 return false; 296 } 297 298 bool LinalgDependenceGraph::hasDependentOperations( 299 LinalgOp linalgOp, ArrayRef<DependenceType> depTypes) const { 300 return hasDependentOperationsInto(linalgOp, depTypes) || 301 hasDependentOperationsFrom(linalgOp, depTypes); 302 } 303 304 SmallVector<LinalgDependenceGraph::LinalgDependenceGraphElem, 2> 305 LinalgDependenceGraph::getDependentOperationsInto( 306 LinalgOp linalgOp, ArrayRef<DependenceType> depTypes) const { 307 SmallVector<LinalgDependenceGraph::LinalgDependenceGraphElem, 2> 308 dependentOperations; 309 for (auto dependenceType : depTypes) { 310 auto dependencies = getDependencesInto(linalgOp, dependenceType); 311 dependentOperations.append(dependencies.begin(), dependencies.end()); 312 } 313 return dependentOperations; 314 } 315 316 SmallVector<LinalgDependenceGraph::LinalgDependenceGraphElem, 2> 317 LinalgDependenceGraph::getDependentOperationsFrom( 318 LinalgOp linalgOp, ArrayRef<DependenceType> depTypes) const { 319 SmallVector<LinalgDependenceGraph::LinalgDependenceGraphElem, 2> 320 dependentOperations; 321 for (auto dependenceType : depTypes) { 322 auto dependencies = getDependencesFrom(linalgOp, dependenceType); 323 dependentOperations.append(dependencies.begin(), dependencies.end()); 324 } 325 return dependentOperations; 326 } 327 328 /// Returns all dependent operations (into and from) given `operation`. 329 SmallVector<LinalgDependenceGraph::LinalgDependenceGraphElem, 2> 330 LinalgDependenceGraph::getDependentOperations( 331 LinalgOp linalgOp, ArrayRef<DependenceType> depTypes) const { 332 SmallVector<LinalgDependenceGraphElem, 2> dependentOperations = 333 getDependentOperationsInto(linalgOp, depTypes); 334 SmallVector<LinalgDependenceGraphElem, 2> t = 335 getDependentOperationsFrom(linalgOp, depTypes); 336 dependentOperations.append(t.begin(), t.end()); 337 return dependentOperations; 338 } 339 340 void LinalgDependenceGraph::print(raw_ostream &os) const { 341 for (auto dt : { 342 LinalgDependenceGraph::DependenceType::RAW, 343 LinalgDependenceGraph::DependenceType::WAW, 344 }) { 345 const auto &fromGraph = dependencesFromGraphs[dt]; 346 for (const auto &it : fromGraph) { 347 os << "[LinalgDependenceGraph] DT " << dt << " from: " << *it.first 348 << ":\n"; 349 for (const auto &dep : it.second) { 350 os << "\tDT " << dt << " " << *dep.getDependentOp() << ":\n"; 351 } 352 } 353 } 354 } 355 356 void LinalgDependenceGraph::dump() const { print(llvm::errs()); } 357