1 //===- DependenceAnalysis.cpp - Dependence analysis on SSA views ----------===// 2 // 3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. 4 // See https://llvm.org/LICENSE.txt for license information. 5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception 6 // 7 //===----------------------------------------------------------------------===// 8 // 9 // This file implements view-based alias and dependence analyses. 10 // 11 //===----------------------------------------------------------------------===// 12 13 #include "mlir/Dialect/Linalg/Analysis/DependenceAnalysis.h" 14 #include "mlir/Dialect/Linalg/IR/LinalgOps.h" 15 #include "mlir/Dialect/StandardOps/IR/Ops.h" 16 #include "mlir/IR/BuiltinOps.h" 17 18 #include "llvm/Support/CommandLine.h" 19 #include "llvm/Support/Debug.h" 20 21 #define DEBUG_TYPE "linalg-dependence-analysis" 22 23 using namespace mlir; 24 using namespace mlir::linalg; 25 26 using llvm::dbgs; 27 28 Value Aliases::find(Value v) { 29 if (v.isa<BlockArgument>()) 30 return v; 31 32 auto it = aliases.find(v); 33 if (it != aliases.end()) { 34 assert(it->getSecond().getType().isa<BaseMemRefType>() && 35 "Memref expected"); 36 return it->getSecond(); 37 } 38 39 while (true) { 40 if (v.isa<BlockArgument>()) 41 return v; 42 43 Operation *defOp = v.getDefiningOp(); 44 if (!defOp) 45 return v; 46 47 // Treat RegionBranchOpInterfaces like an allocate and don't try to follow 48 // the aliasing further. 49 if (isa<RegionBranchOpInterface>(defOp)) 50 return v; 51 if (isa<TensorToMemrefOp>(defOp)) 52 return v; 53 54 if (auto memEffect = dyn_cast<MemoryEffectOpInterface>(defOp)) { 55 // Collect all memory effects on `v`. 56 SmallVector<MemoryEffects::EffectInstance, 1> effects; 57 memEffect.getEffectsOnValue(v, effects); 58 59 // If we have the 'Allocate' memory effect on `v`, then `v` should be the 60 // original buffer. 61 if (llvm::any_of( 62 effects, [](const MemoryEffects::EffectInstance &instance) { 63 return isa<MemoryEffects::Allocate>(instance.getEffect()); 64 })) 65 return v; 66 } 67 68 if (auto viewLikeOp = dyn_cast<ViewLikeOpInterface>(defOp)) { 69 auto it = 70 aliases.insert(std::make_pair(v, find(viewLikeOp.getViewSource()))); 71 return it.first->second; 72 } 73 74 llvm::errs() << "View alias analysis reduces to: " << v << "\n"; 75 llvm_unreachable("unsupported view alias case"); 76 } 77 } 78 79 StringRef LinalgDependenceGraph::getDependenceTypeStr(DependenceType depType) { 80 switch (depType) { 81 case LinalgDependenceGraph::DependenceType::RAW: 82 return "RAW"; 83 case LinalgDependenceGraph::DependenceType::RAR: 84 return "RAR"; 85 case LinalgDependenceGraph::DependenceType::WAR: 86 return "WAR"; 87 case LinalgDependenceGraph::DependenceType::WAW: 88 return "WAW"; 89 default: 90 break; 91 } 92 llvm_unreachable("Unexpected DependenceType"); 93 } 94 95 LinalgDependenceGraph 96 LinalgDependenceGraph::buildDependenceGraph(Aliases &aliases, FuncOp f) { 97 SmallVector<LinalgOp, 8> linalgOps; 98 f.walk([&](LinalgOp op) { linalgOps.push_back(op); }); 99 return LinalgDependenceGraph(aliases, linalgOps); 100 } 101 102 LinalgDependenceGraph::LinalgDependenceGraph(Aliases &aliases, 103 ArrayRef<LinalgOp> ops) 104 : aliases(aliases), linalgOps(ops.begin(), ops.end()) { 105 for (auto en : llvm::enumerate(linalgOps)) { 106 linalgOpPositions.insert( 107 std::make_pair(en.value().getOperation(), en.index())); 108 } 109 for (unsigned i = 0, e = ops.size(); i < e; ++i) { 110 for (unsigned j = i + 1; j < e; ++j) { 111 addDependencesBetween(ops[i], ops[j]); 112 } 113 } 114 } 115 116 void LinalgDependenceGraph::addDependenceElem( 117 DependenceType dt, LinalgDependenceGraphElem::OpView indexingOpView, 118 LinalgDependenceGraphElem::OpView dependentOpView) { 119 LLVM_DEBUG(dbgs() << "\nAdd dep type " << getDependenceTypeStr(dt) << ":\t (" 120 << LinalgDependenceGraphElem::getValue(indexingOpView) 121 << " @) -> \n\t\t(" 122 << LinalgDependenceGraphElem::getValue(dependentOpView) 123 << " @)"); 124 dependencesFromGraphs[dt][LinalgDependenceGraphElem::getOwner(indexingOpView)] 125 .push_back( 126 LinalgDependenceGraphElem{dependentOpView, indexingOpView, dt}); 127 dependencesIntoGraphs[dt] 128 [LinalgDependenceGraphElem::getOwner(dependentOpView)] 129 .push_back(LinalgDependenceGraphElem{ 130 indexingOpView, dependentOpView, dt}); 131 } 132 133 LinalgDependenceGraph::dependence_range 134 LinalgDependenceGraph::getDependencesFrom( 135 LinalgOp src, LinalgDependenceGraph::DependenceType dt) const { 136 return getDependencesFrom(src.getOperation(), dt); 137 } 138 139 LinalgDependenceGraph::dependence_range 140 LinalgDependenceGraph::getDependencesFrom( 141 Operation *src, LinalgDependenceGraph::DependenceType dt) const { 142 auto iter = dependencesFromGraphs[dt].find(src); 143 if (iter == dependencesFromGraphs[dt].end()) 144 return llvm::make_range(nullptr, nullptr); 145 return llvm::make_range(iter->second.begin(), iter->second.end()); 146 } 147 148 LinalgDependenceGraph::dependence_range 149 LinalgDependenceGraph::getDependencesInto( 150 LinalgOp dst, LinalgDependenceGraph::DependenceType dt) const { 151 return getDependencesInto(dst.getOperation(), dt); 152 } 153 154 LinalgDependenceGraph::dependence_range 155 LinalgDependenceGraph::getDependencesInto( 156 Operation *dst, LinalgDependenceGraph::DependenceType dt) const { 157 auto iter = dependencesIntoGraphs[dt].find(dst); 158 if (iter == dependencesIntoGraphs[dt].end()) 159 return llvm::make_range(nullptr, nullptr); 160 return llvm::make_range(iter->second.begin(), iter->second.end()); 161 } 162 163 void LinalgDependenceGraph::addDependencesBetween(LinalgOp src, LinalgOp dst) { 164 if (src.hasTensorSemantics() && dst.hasTensorSemantics()) { 165 for (OpOperand &dstOpOperand : dst.getInputOpOperands()) { 166 // Check if the operand is defined by the src. 167 auto definingOp = dstOpOperand.get().getDefiningOp<LinalgOp>(); 168 if (definingOp && definingOp == src) 169 addDependenceElem(DependenceType::RAW, dstOpOperand.get(), 170 &dstOpOperand); 171 } 172 return; 173 } 174 assert(src.hasBufferSemantics() && dst.hasBufferSemantics() && 175 "unhandled dependence tracking for mixed buffer/tensor operations"); 176 for (OpOperand *srcOpOperand : src.getOutputBuffersOpOperands()) { // W 177 // RAW graph 178 for (OpOperand *dstOpOperand : dst.getInputBuffersOpOperands()) // R 179 if (aliases.alias(srcOpOperand->get(), dstOpOperand->get())) // RAW alias 180 addDependenceElem(DependenceType::RAW, srcOpOperand, dstOpOperand); 181 // WAW graph 182 for (OpOperand *dstOpOperand : dst.getOutputBuffersOpOperands()) // W 183 if (aliases.alias(srcOpOperand->get(), dstOpOperand->get())) // WAW alias 184 addDependenceElem(DependenceType::WAW, srcOpOperand, dstOpOperand); 185 } 186 for (OpOperand *srcOpOperand : src.getInputBuffersOpOperands()) { // R 187 // RAR graph 188 for (OpOperand *dstOpOperand : dst.getInputBuffersOpOperands()) // R 189 if (aliases.alias(srcOpOperand->get(), dstOpOperand->get())) // RAR alias 190 addDependenceElem(DependenceType::RAR, srcOpOperand, dstOpOperand); 191 // WAR graph 192 for (OpOperand *dstOpOperand : dst.getOutputBuffersOpOperands()) // W 193 if (aliases.alias(srcOpOperand->get(), dstOpOperand->get())) // WAR alias 194 addDependenceElem(DependenceType::WAR, srcOpOperand, dstOpOperand); 195 } 196 } 197 198 SmallVector<Operation *, 8> 199 LinalgDependenceGraph::findCoveringDependences(LinalgOp srcLinalgOp, 200 LinalgOp dstLinalgOp) const { 201 return findOperationsWithCoveringDependences( 202 srcLinalgOp, dstLinalgOp, nullptr, 203 {DependenceType::WAW, DependenceType::WAR, DependenceType::RAW}); 204 } 205 206 SmallVector<Operation *, 8> LinalgDependenceGraph::findCoveringWrites( 207 LinalgOp srcLinalgOp, LinalgOp dstLinalgOp, Value view) const { 208 return findOperationsWithCoveringDependences( 209 srcLinalgOp, dstLinalgOp, view, 210 {DependenceType::WAW, DependenceType::WAR}); 211 } 212 213 SmallVector<Operation *, 8> LinalgDependenceGraph::findCoveringReads( 214 LinalgOp srcLinalgOp, LinalgOp dstLinalgOp, Value view) const { 215 return findOperationsWithCoveringDependences( 216 srcLinalgOp, dstLinalgOp, view, 217 {DependenceType::RAR, DependenceType::RAW}); 218 } 219 220 SmallVector<Operation *, 8> 221 LinalgDependenceGraph::findOperationsWithCoveringDependences( 222 LinalgOp srcLinalgOp, LinalgOp dstLinalgOp, Value view, 223 ArrayRef<DependenceType> types) const { 224 auto *src = srcLinalgOp.getOperation(); 225 auto *dst = dstLinalgOp.getOperation(); 226 auto srcPos = linalgOpPositions.lookup(src); 227 auto dstPos = linalgOpPositions.lookup(dst); 228 assert(srcPos < dstPos && "expected dst after src in IR traversal order"); 229 230 SmallVector<Operation *, 8> res; 231 // Consider an intermediate interleaved `interim` op, look for any dependence 232 // to an aliasing view on a src -> op -> dst path. 233 // TODO: we are not considering paths yet, just interleaved positions. 234 for (auto dt : types) { 235 for (auto dependence : getDependencesFrom(src, dt)) { 236 auto interimPos = linalgOpPositions.lookup(dependence.getDependentOp()); 237 // Skip if not interleaved. 238 if (interimPos >= dstPos || interimPos <= srcPos) 239 continue; 240 Value consumerView = dependence.getIndexingValue(); 241 if (view && !aliases.alias(view, consumerView)) 242 continue; 243 auto *op = dependence.getDependentOp(); 244 LLVM_DEBUG(dbgs() << "\n***Found covering dependence of type " 245 << getDependenceTypeStr(dt) << ": " << *src << " -> " 246 << *op << " on " << consumerView); 247 res.push_back(op); 248 } 249 } 250 return res; 251 } 252 253 bool LinalgDependenceGraph::hasDependenceFrom( 254 LinalgOp srcLinalgOp, LinalgOp dstLinalgOp, 255 ArrayRef<LinalgDependenceGraph::DependenceType> depTypes) const { 256 for (auto dep : depTypes) 257 for (auto dependence : getDependencesInto(dstLinalgOp, dep)) 258 if (dependence.getDependentOp() == srcLinalgOp) 259 return true; 260 return false; 261 } 262 263 bool LinalgDependenceGraph::hasDependentOperationsFrom( 264 LinalgOp linalgOp, 265 ArrayRef<LinalgDependenceGraph::DependenceType> depTypes) const { 266 for (auto dep : depTypes) { 267 if (!getDependencesFrom(linalgOp, dep).empty()) 268 return true; 269 } 270 return false; 271 } 272 273 bool LinalgDependenceGraph::hasDependentOperationsInto( 274 LinalgOp linalgOp, 275 ArrayRef<LinalgDependenceGraph::DependenceType> depTypes) const { 276 for (auto dep : depTypes) { 277 if (!getDependencesInto(linalgOp, dep).empty()) 278 return true; 279 } 280 return false; 281 } 282 283 bool LinalgDependenceGraph::hasDependentOperations( 284 LinalgOp linalgOp, ArrayRef<DependenceType> depTypes) const { 285 return hasDependentOperationsInto(linalgOp, depTypes) || 286 hasDependentOperationsFrom(linalgOp, depTypes); 287 } 288 289 SmallVector<LinalgDependenceGraph::LinalgDependenceGraphElem, 2> 290 LinalgDependenceGraph::getDependentOperationsInto( 291 LinalgOp linalgOp, ArrayRef<DependenceType> depTypes) const { 292 SmallVector<LinalgDependenceGraph::LinalgDependenceGraphElem, 2> 293 dependentOperations; 294 for (auto dependenceType : depTypes) { 295 auto dependencies = getDependencesInto(linalgOp, dependenceType); 296 dependentOperations.append(dependencies.begin(), dependencies.end()); 297 } 298 return dependentOperations; 299 } 300 301 SmallVector<LinalgDependenceGraph::LinalgDependenceGraphElem, 2> 302 LinalgDependenceGraph::getDependentOperationsFrom( 303 LinalgOp linalgOp, ArrayRef<DependenceType> depTypes) const { 304 SmallVector<LinalgDependenceGraph::LinalgDependenceGraphElem, 2> 305 dependentOperations; 306 for (auto dependenceType : depTypes) { 307 auto dependencies = getDependencesFrom(linalgOp, dependenceType); 308 dependentOperations.append(dependencies.begin(), dependencies.end()); 309 } 310 return dependentOperations; 311 } 312 313 /// Returns all dependent operations (into and from) given `operation`. 314 SmallVector<LinalgDependenceGraph::LinalgDependenceGraphElem, 2> 315 LinalgDependenceGraph::getDependentOperations( 316 LinalgOp linalgOp, ArrayRef<DependenceType> depTypes) const { 317 SmallVector<LinalgDependenceGraphElem, 2> dependentOperations = 318 getDependentOperationsInto(linalgOp, depTypes); 319 SmallVector<LinalgDependenceGraphElem, 2> t = 320 getDependentOperationsFrom(linalgOp, depTypes); 321 dependentOperations.append(t.begin(), t.end()); 322 return dependentOperations; 323 } 324