1 //===- OneShotAnalysis.cpp - One-Shot (Single Pass) Analysis --------------===// 2 // 3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. 4 // See https://llvm.org/LICENSE.txt for license information. 5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception 6 // 7 //===----------------------------------------------------------------------===// 8 // 9 // One-Shot Analysis analyzes function bodies. Function boundaries (FuncOp 10 // bbArgs, CallOps, ReturnOps) are treated as "unknown" ops. 11 // ModuleBufferization.cpp is an extension of One-Shot Analysis for simple 12 // call graphs. 13 // 14 // One-Shot Bufferize consists of two phases. 15 // 16 // 1. Analyze ops to decide which OpResults can bufferize inplace, i.e., without 17 // inserting buffer copies. The analysis queries op bufferization semantics 18 // via `BufferizableOpInterface`. 19 // 2. Bufferize ops by calling `BufferizableOpInterface::bufferize`. This 20 // function does not generate buffer copies for OpResults that were decided 21 // to bufferize inplace during the analysis phase. 22 // 23 // This file contains only the analysis. The actual bufferization is implemented 24 // via `bufferizeOp` (Bufferize.h). For convenience, this file also contains a 25 // helper function `runOneShotBufferize` that analyzes an op (and its nested 26 // ops) and then bufferizes it. 27 // 28 // Inplace bufferization decisions are passed from the analysis to the 29 // bufferization phase via `AnalysisState` and `BufferizationAliasInfo`. 30 // They can be printed for debugging purposes with `testAnalysisOnly`. 31 // 32 // Ops that do not implement `BufferizableOpInterface` can be analyzed but are 33 // treated conservatively. E.g., the analysis has to assume that their tensor 34 // OpOperands bufferize to memory writes. While such ops can be analyzed, they 35 // are not bufferized and remain in the IR. to_tensor and to_memref ops are 36 // inserted at the bufferization boundary. 
//
// This analysis caters to high-performance codegen where buffer reuse is deemed
// critical: the analysis should fail if the bufferized form of the function
// needs to return a buffer, unless `allowReturnAllocs` is enabled.

#include "mlir/Dialect/Bufferization/Transforms/OneShotAnalysis.h"

#include <random>

#include "mlir/Dialect/Bufferization/IR/BufferizableOpInterface.h"
#include "mlir/Dialect/Bufferization/IR/Bufferization.h"
#include "mlir/Dialect/Bufferization/Transforms/Bufferize.h"
#include "mlir/Dialect/Func/IR/FuncOps.h"
#include "mlir/Dialect/MemRef/IR/MemRef.h"
#include "mlir/IR/AsmState.h"
#include "mlir/IR/Dominance.h"
#include "mlir/IR/Operation.h"
#include "mlir/IR/TypeUtilities.h"
#include "mlir/Interfaces/ControlFlowInterfaces.h"
#include "llvm/ADT/DenseSet.h"
#include "llvm/ADT/SetVector.h"

using namespace mlir;
using namespace mlir::bufferization;

/// Return `true` if `t` is a tensor type (ranked or unranked).
static bool isaTensor(Type t) { return t.isa<TensorType>(); }

//===----------------------------------------------------------------------===//
// Bufferization-specific attribute manipulation.
// These are for testing and debugging only. Bufferization information is
// stored in BufferizationAliasInfo. When run with `testAnalysisOnly`, the IR
// is annotated with the results of the analysis (copied from
// BufferizationAliasInfo), so that they can be checked in tests.
//===----------------------------------------------------------------------===//

/// Attribute marker to specify op operands that are bufferized in-place.
/// (Despite the variable name, the attribute is keyed by operand, not result;
/// see `setInPlaceOpOperand`.)
constexpr StringLiteral kInPlaceResultsAttrName = "__inplace_operands_attr__";

/// Mark whether OpOperand will be bufferized inplace.
76 static void setInPlaceOpOperand(OpOperand &opOperand, bool inPlace) { 77 Operation *op = opOperand.getOwner(); 78 auto attr = 79 op->getAttr(kInPlaceResultsAttrName).dyn_cast_or_null<ArrayAttr>(); 80 SmallVector<StringRef> inPlaceVector; 81 if (attr) { 82 inPlaceVector = SmallVector<StringRef>( 83 llvm::to_vector<4>(attr.getAsValueRange<StringAttr>())); 84 } else { 85 inPlaceVector = SmallVector<StringRef>(op->getNumOperands(), "none"); 86 for (OpOperand &opOperand : op->getOpOperands()) 87 if (opOperand.get().getType().isa<TensorType>()) 88 inPlaceVector[opOperand.getOperandNumber()] = "false"; 89 } 90 91 inPlaceVector[opOperand.getOperandNumber()] = inPlace ? "true" : "false"; 92 op->setAttr(kInPlaceResultsAttrName, 93 OpBuilder(op).getStrArrayAttr(inPlaceVector)); 94 } 95 96 //===----------------------------------------------------------------------===// 97 // BufferizationAliasInfo 98 //===----------------------------------------------------------------------===// 99 100 BufferizationAliasInfo::BufferizationAliasInfo(Operation *rootOp) { 101 rootOp->walk([&](Operation *op) { 102 for (Value v : op->getResults()) 103 if (v.getType().isa<TensorType>()) 104 createAliasInfoEntry(v); 105 for (Region &r : op->getRegions()) 106 for (Block &b : r.getBlocks()) 107 for (auto bbArg : b.getArguments()) 108 if (bbArg.getType().isa<TensorType>()) 109 createAliasInfoEntry(bbArg); 110 }); 111 } 112 113 /// Add a new entry for `v` in the `aliasInfo` and `equivalentInfo`. In the 114 /// beginning the alias and equivalence sets only contain `v` itself. 115 void BufferizationAliasInfo::createAliasInfoEntry(Value v) { 116 aliasInfo.insert(v); 117 equivalentInfo.insert(v); 118 } 119 120 /// Insert an info entry for `newValue` and merge its alias set with that of 121 /// `alias`. 
122 void BufferizationAliasInfo::insertNewBufferAlias(Value newValue, Value alias) { 123 createAliasInfoEntry(newValue); 124 aliasInfo.unionSets(newValue, alias); 125 } 126 127 /// Insert an info entry for `newValue` and merge its alias set with that of 128 /// `alias`. Additionally, merge their equivalence classes. 129 void BufferizationAliasInfo::insertNewBufferEquivalence(Value newValue, 130 Value alias) { 131 insertNewBufferAlias(newValue, alias); 132 equivalentInfo.unionSets(newValue, alias); 133 } 134 135 /// Return `true` if a value was marked as in-place bufferized. 136 bool BufferizationAliasInfo::isInPlace(OpOperand &operand) const { 137 return inplaceBufferized.contains(&operand); 138 } 139 140 /// Set the inPlace bufferization spec to true. 141 void BufferizationAliasInfo::bufferizeInPlace(OpOperand &operand, 142 AnalysisState &state) { 143 markInPlace(operand); 144 for (OpResult result : state.getAliasingOpResult(operand)) 145 aliasInfo.unionSets(result, operand.get()); 146 } 147 148 /// Set the inPlace bufferization spec to false. 149 void BufferizationAliasInfo::bufferizeOutOfPlace(OpOperand &operand) { 150 assert(!inplaceBufferized.contains(&operand) && 151 "OpOperand was already decided to bufferize inplace"); 152 } 153 154 /// Apply `fun` to all the members of the equivalence class of `v`. 155 void BufferizationAliasInfo::applyOnEquivalenceClass( 156 Value v, function_ref<void(Value)> fun) const { 157 auto leaderIt = equivalentInfo.findLeader(v); 158 for (auto mit = leaderIt, meit = equivalentInfo.member_end(); mit != meit; 159 ++mit) { 160 fun(*mit); 161 } 162 } 163 164 /// Apply `fun` to all aliases of `v`. 
165 void BufferizationAliasInfo::applyOnAliases( 166 Value v, function_ref<void(Value)> fun) const { 167 auto leaderIt = aliasInfo.findLeader(v); 168 for (auto mit = leaderIt, meit = aliasInfo.member_end(); mit != meit; ++mit) { 169 fun(*mit); 170 } 171 } 172 173 BufferizationAliasInfo::EquivalenceClassRangeType 174 BufferizationAliasInfo::getAliases(Value v) const { 175 DenseSet<Value> res; 176 auto it = aliasInfo.findValue(aliasInfo.getLeaderValue(v)); 177 for (auto mit = aliasInfo.member_begin(it), meit = aliasInfo.member_end(); 178 mit != meit; ++mit) { 179 res.insert(static_cast<Value>(*mit)); 180 } 181 return BufferizationAliasInfo::EquivalenceClassRangeType( 182 aliasInfo.member_begin(it), aliasInfo.member_end()); 183 } 184 185 //===----------------------------------------------------------------------===// 186 // OneShotAnalysisState 187 //===----------------------------------------------------------------------===// 188 189 OneShotAnalysisState::OneShotAnalysisState( 190 Operation *op, const OneShotBufferizationOptions &options) 191 : AnalysisState(options), aliasInfo(op) { 192 // Set up alias sets for OpResults that must bufferize in-place. This should 193 // be done before making any other bufferization decisions. 
194 op->walk([&](BufferizableOpInterface bufferizableOp) { 195 if (!options.isOpAllowed(bufferizableOp)) 196 return WalkResult::skip(); 197 for (OpOperand &opOperand : bufferizableOp->getOpOperands()) { 198 if (opOperand.get().getType().isa<TensorType>()) 199 if (bufferizableOp.mustBufferizeInPlace(opOperand, *this)) { 200 for (OpResult opResult : 201 bufferizableOp.getAliasingOpResult(opOperand, *this)) 202 aliasInfo.unionAliasSets(opOperand.get(), opResult); 203 aliasInfo.markInPlace(opOperand); 204 } 205 } 206 return WalkResult::advance(); 207 }); 208 } 209 210 bool OneShotAnalysisState::isInPlace(OpOperand &opOperand) const { 211 return aliasInfo.isInPlace(opOperand); 212 } 213 214 bool OneShotAnalysisState::areEquivalentBufferizedValues(Value v1, 215 Value v2) const { 216 return aliasInfo.areEquivalentBufferizedValues(v1, v2); 217 } 218 219 bool OneShotAnalysisState::areAliasingBufferizedValues(Value v1, 220 Value v2) const { 221 return aliasInfo.areAliasingBufferizedValues(v1, v2); 222 } 223 224 // Gather yielded tensors in `yieldedTensors` by querying all aliases. This is 225 // to ensure that such information is available during bufferization time. 226 // Alias information can no longer be queried through BufferizationAliasInfo 227 // once we have started modifying the IR. 228 void OneShotAnalysisState::gatherYieldedTensors(Operation *op) { 229 op->walk([&](Operation *returnOp) { 230 if (!isRegionReturnLike(returnOp) || !getOptions().isOpAllowed(returnOp)) 231 return WalkResult::advance(); 232 233 for (OpOperand &returnValOperand : returnOp->getOpOperands()) { 234 Value returnVal = returnValOperand.get(); 235 // Skip non-tensor values. 236 if (!returnVal.getType().isa<TensorType>()) 237 continue; 238 239 // Add all aliases of the returned value. But only the ones that are in 240 // the same block. 
241 aliasInfo.applyOnAliases(returnVal, [&](Value v) { 242 if (auto bbArg = v.dyn_cast<BlockArgument>()) { 243 if (bbArg.getOwner()->getParentOp() == returnOp->getParentOp()) 244 yieldedTensors.insert(bbArg); 245 return; 246 } 247 Operation *definingOp = v.getDefiningOp(); 248 if (definingOp->getParentOp() == returnOp->getParentOp()) 249 yieldedTensors.insert(v); 250 }); 251 } 252 253 return WalkResult::advance(); 254 }); 255 } 256 257 void OneShotAnalysisState::gatherUndefinedTensorUses(Operation *op) { 258 op->walk([&](Operation *op) { 259 // Skip unknown ops. 260 auto bufferizableOp = getOptions().dynCastBufferizableOp(op); 261 if (!bufferizableOp) 262 return WalkResult::skip(); 263 264 // Check all tensor OpResults. 265 for (OpResult opResult : op->getOpResults()) { 266 if (!opResult.getType().isa<TensorType>()) 267 continue; 268 269 // If there is no preceding memory write, the tensor contents are 270 // undefined. 271 // Note: If `findLastPrecedingWrite` reaches the end of the reverse SSA 272 // use-def chain, it returns that value, regardless of whether it is a 273 // memory write or not. 
274 SetVector<Value> lastWrites = findLastPrecedingWrite(opResult); 275 bool isUndefined = llvm::none_of(lastWrites, [&](Value lastWrite) { 276 if (auto bufferizableOp = getOptions().dynCastBufferizableOp(lastWrite)) 277 return bufferizableOp.isMemoryWrite(lastWrite.cast<OpResult>(), 278 *this); 279 return true; 280 }); 281 if (isUndefined) 282 for (OpOperand &use : opResult.getUses()) 283 undefinedTensorUses.insert(&use); 284 } 285 286 return WalkResult::advance(); 287 }); 288 } 289 290 bool OneShotAnalysisState::hasUndefinedContents(OpOperand *opOperand) const { 291 return undefinedTensorUses.contains(opOperand); 292 } 293 294 bool OneShotAnalysisState::isTensorYielded(Value tensor) const { 295 return yieldedTensors.contains(tensor); 296 } 297 298 bool OneShotAnalysisState::isValueWritten(Value value) const { 299 bool isWritten = false; 300 aliasInfo.applyOnAliases(value, [&](Value val) { 301 for (OpOperand &use : val.getUses()) 302 if (isInPlace(use) && bufferizesToMemoryWrite(use)) 303 isWritten = true; 304 }); 305 return isWritten; 306 } 307 308 bool OneShotAnalysisState::isWritable(Value value) const { 309 // TODO: Out-of-place bufferized value could be considered writable. 310 if (auto bufferizableOp = getOptions().dynCastBufferizableOp(value)) 311 return bufferizableOp.isWritable(value, *this); 312 313 // Query BufferizableOpInterface to see if the BlockArgument is writable. 314 if (auto bbArg = value.dyn_cast<BlockArgument>()) 315 if (auto bufferizableOp = 316 getOptions().dynCastBufferizableOp(bbArg.getOwner()->getParentOp())) 317 return bufferizableOp.isWritable(bbArg, *this); 318 319 // Not a bufferizable op: The conservative answer is "not writable". 320 return false; 321 } 322 323 //===----------------------------------------------------------------------===// 324 // Bufferization-specific alias analysis. 
325 //===----------------------------------------------------------------------===// 326 327 /// Return true if opOperand has been decided to bufferize in-place. 328 static bool isInplaceMemoryWrite(OpOperand &opOperand, 329 const BufferizationAliasInfo &aliasInfo, 330 const AnalysisState &state) { 331 // OpOperands that do not bufferize to a memory write do not write in-place. 332 if (!state.bufferizesToMemoryWrite(opOperand)) 333 return false; 334 // Check current bufferization decisions. 335 return aliasInfo.isInPlace(opOperand); 336 } 337 338 /// Return true if `a` happens before `b`, i.e., `a` or one of its ancestors 339 /// properly dominates `b` and `b` is not inside `a`. 340 static bool happensBefore(Operation *a, Operation *b, 341 const DominanceInfo &domInfo) { 342 do { 343 // TODO: Instead of isProperAncestor + properlyDominates, we should use 344 // properlyDominatesImpl(a, b, /*enclosingOpOk=*/false) 345 if (a->isProperAncestor(b)) 346 return false; 347 if (domInfo.properlyDominates(a, b)) 348 return true; 349 } while ((a = a->getParentOp())); 350 return false; 351 } 352 353 /// For each given value, find the closest enclosing repetitive region. If this 354 /// is the same region for each value, return it. Otherwise return None. 355 /// Note: If there is no enclosing repetitive region, return nullptr. 356 static Optional<Region *> 357 getCommonEnclosingRepetitiveRegion(ArrayRef<Value> values) { 358 if (values.empty()) 359 return None; 360 Region *r = getEnclosingRepetitiveRegion(values.front()); 361 for (Value value : values.drop_front()) 362 if (getEnclosingRepetitiveRegion(value) != r) 363 return None; 364 return r; 365 } 366 367 /// Return `true` if the given tensor value is a memory write. Most values are 368 /// tensor writes, but ops that define a tensor SSA value without specifying its 369 /// contents (e.g., alloc_tensor) are not. 
370 static bool isMemoryWrite(Value value, const AnalysisState &state) { 371 auto opResult = value.dyn_cast<OpResult>(); 372 if (!opResult) 373 return true; 374 auto bufferizableOp = state.getOptions().dynCastBufferizableOp(value); 375 if (!bufferizableOp) 376 return true; 377 return bufferizableOp.isMemoryWrite(opResult, state); 378 } 379 380 /// Annotate IR with details about the detected RaW conflict. 381 static void annotateConflict(OpOperand *uRead, OpOperand *uConflictingWrite, 382 Value lastWrite) { 383 static uint64_t counter = 0; 384 Operation *readingOp = uRead->getOwner(); 385 Operation *conflictingWritingOp = uConflictingWrite->getOwner(); 386 387 OpBuilder b(conflictingWritingOp->getContext()); 388 std::string id = "C_" + std::to_string(counter++); 389 390 std::string conflictingWriteAttr = 391 id + 392 "[CONFL-WRITE: " + std::to_string(uConflictingWrite->getOperandNumber()) + 393 "]"; 394 conflictingWritingOp->setAttr(conflictingWriteAttr, b.getUnitAttr()); 395 396 std::string readAttr = 397 id + "[READ: " + std::to_string(uRead->getOperandNumber()) + "]"; 398 readingOp->setAttr(readAttr, b.getUnitAttr()); 399 400 if (auto opResult = lastWrite.dyn_cast<OpResult>()) { 401 std::string lastWriteAttr = id + "[LAST-WRITE: result " + 402 std::to_string(opResult.getResultNumber()) + 403 "]"; 404 opResult.getDefiningOp()->setAttr(lastWriteAttr, b.getUnitAttr()); 405 } else { 406 auto bbArg = lastWrite.cast<BlockArgument>(); 407 std::string lastWriteAttr = 408 id + "[LAST-WRITE: bbArg " + std::to_string(bbArg.getArgNumber()) + "]"; 409 bbArg.getOwner()->getParentOp()->setAttr(lastWriteAttr, b.getUnitAttr()); 410 } 411 } 412 413 /// Given sets of uses and writes, return true if there is a RaW conflict under 414 /// the assumption that all given reads/writes alias the same buffer and that 415 /// all given writes bufferize inplace. 416 /// 417 /// A conflict is: According to SSA use-def chains, a read R is supposed to read 418 /// the result of a write W1. 
/// But because of bufferization decisions, R actually reads another write W2.
///
/// `usesRead`/`usesWrite` are the reading and (inplace) writing uses under
/// consideration; `domInfo` is used for happens-before queries.
static bool hasReadAfterWriteInterference(
    const DenseSet<OpOperand *> &usesRead,
    const DenseSet<OpOperand *> &usesWrite, const DominanceInfo &domInfo,
    AnalysisState &state, const BufferizationAliasInfo &aliasInfo) {
  const BufferizationOptions &options = state.getOptions();

  // Gather all written aliases. Skip over aliases that are not actual writes.
  SmallVector<Value> writtenAliases;
  for (OpOperand *uWrite : usesWrite)
    if (isMemoryWrite(uWrite->get(), state))
      writtenAliases.push_back(uWrite->get());
  // Find the inner-most enclosing repetitive region of each alias. If this is
  // the same region for every alias, save it in `repetitiveRegionOfWrites`.
  Optional<Region *> repetitiveRegionOfWrites =
      getCommonEnclosingRepetitiveRegion(writtenAliases);

  for (OpOperand *uRead : usesRead) {
    Operation *readingOp = uRead->getOwner();

    // Find most recent writes of uRead by following the SSA use-def chain.
    // E.g.:
    //
    // %0 = "writing_op"(%t) : tensor<?xf32> -> tensor<?xf32>
    // %1 = "aliasing_op"(%0) : tensor<?xf32> -> tensor<?xf32>
    // %2 = "reading_op"(%1) : tensor<?xf32> -> not_a_tensor_type
    //
    // In the above example, if uRead is the OpOperand of reading_op, lastWrite
    // is %0. Note that operations that create an alias but do not write (such
    // as ExtractSliceOp) are skipped.
    SetVector<Value> lastWrites = state.findLastPrecedingWrite(uRead->get());

    // Look for conflicting memory writes. Potential conflicts are writes to an
    // alias that have been decided to bufferize inplace.
    for (OpOperand *uConflictingWrite : usesWrite) {
      // Throughout this loop, check for multiple requirements that have to be
      // met for uConflictingWrite to be an actual conflict.
      Operation *conflictingWritingOp = uConflictingWrite->getOwner();

      // Check if conflictingWritingOp is in the same repetitive region as all
      // written aliases. If this is not the case, there is no meaningful
      // `happensBefore` relationship because conflictingWritingOp may be
      // executed multiple times. E.g.:
      //
      // %0 = ... : tensor<?xf32>
      // scf.for ... {
      //   "reading_op"(%0) : tensor<?xf32>
      //   %1 = "writing_op"(%0) : tensor<?xf32> -> tensor<?xf32>
      //   ...
      // }
      //
      // In the above example, reading_op happens before writing_op according to
      // op dominance. However, both ops may happen multiple times; in
      // particular, the second execution of reading_op happens after the first
      // execution of writing_op. This is problematic if the tensor they operate
      // on (%0) is defined outside of the loop.
      //
      // Counter example:
      //
      // scf.for ... {
      //   %0 = ... : tensor<?xf32>
      //   "reading_op"(%0) : tensor<?xf32>
      //   %1 = "writing_op"(%0) : tensor<?xf32> -> tensor<?xf32>
      //   ...
      // }
      //
      // In this example, %0 is in the same repetitive region as
      // conflictingWritingOp, so op dominance can be used to compute the
      // `happensBefore` relationship.
      //
      // Note: iter_args of loops are not aliases of their respective block
      // arguments, so op dominance can be used when analyzing ops that operate
      // on them.
      //
      // Note: If `writtenAliases` is empty, there are no memory writes outside
      // of the repetitive region of conflictingWritingOp, which means that all
      // relevant aliases are inside the same repetitive region.
      bool canUseOpDominance =
          writtenAliases.empty() ||
          repetitiveRegionOfWrites ==
              getEnclosingRepetitiveRegion(conflictingWritingOp);

      // No conflict if the readingOp dominates conflictingWritingOp, i.e., the
      // write is not visible when reading.
      //
      // Note: If ops are executed multiple times (e.g., because they are inside
      // a loop), there may be no meaningful `happensBefore` relationship.
      if (canUseOpDominance &&
          happensBefore(readingOp, conflictingWritingOp, domInfo))
        continue;

      // No conflict if the reading use equals the use of the conflicting write.
      // A use cannot conflict with itself.
      //
      // Note: Just being the same op is not enough. It has to be the same use.
      // Note: If the op is executed multiple times (e.g., because it is inside
      // a loop), it may be conflicting with itself.
      if (canUseOpDominance && uConflictingWrite == uRead)
        continue;

      // No conflict if the op interface says so.
      if (auto bufferizableOp = options.dynCastBufferizableOp(readingOp))
        if (bufferizableOp.isNotConflicting(uRead, uConflictingWrite, state))
          continue;

      // Also give the writing op's interface a chance to declare the pair
      // non-conflicting (unless it is the same op as the reading op, which was
      // already queried above).
      if (conflictingWritingOp != readingOp)
        if (auto bufferizableOp =
                options.dynCastBufferizableOp(conflictingWritingOp))
          if (bufferizableOp.isNotConflicting(uRead, uConflictingWrite, state))
            continue;

      // Ops are not conflicting if they are in mutually exclusive regions.
      //
      // Note: If ops are executed multiple times (e.g., because they are inside
      //       a loop), mutually exclusive regions may be executed multiple
      //       times.
      if (canUseOpDominance &&
          insideMutuallyExclusiveRegions(readingOp, conflictingWritingOp))
        continue;

      // Check all possible last writes.
      for (Value lastWrite : lastWrites) {
        // No conflict if the conflicting write happens before the last
        // write.
        if (Operation *writingOp = lastWrite.getDefiningOp()) {
          if (happensBefore(conflictingWritingOp, writingOp, domInfo))
            // conflictingWritingOp happens before writingOp. No conflict.
            continue;
          // No conflict if conflictingWritingOp is contained in writingOp.
          if (writingOp->isProperAncestor(conflictingWritingOp))
            continue;
        } else {
          auto bbArg = lastWrite.cast<BlockArgument>();
          Block *block = bbArg.getOwner();
          if (!block->findAncestorOpInBlock(*conflictingWritingOp))
            // conflictingWritingOp happens outside of the block. No
            // conflict.
            continue;
        }

        // No conflict if the conflicting write and the last write are the same
        // use.
        SmallVector<OpResult> aliasingOpResult =
            state.getAliasingOpResult(*uConflictingWrite);
        if (aliasingOpResult.size() == 1 && aliasingOpResult[0] == lastWrite)
          continue;

        // All requirements are met. Conflict found!

        if (options.printConflicts)
          annotateConflict(uRead, uConflictingWrite, lastWrite);

        return true;
      }
    }
  }

  return false;
}

// Helper function to iterate on aliases of `root` and capture the writes.
static void getAliasingInplaceWrites(DenseSet<OpOperand *> &res, Value root,
                                     const BufferizationAliasInfo &aliasInfo,
                                     const AnalysisState &state) {
  aliasInfo.applyOnAliases(root, [&](Value alias) {
    for (auto &use : alias.getUses())
      // Inplace write to a value that aliases root.
      if (isInplaceMemoryWrite(use, aliasInfo, state))
        res.insert(&use);
  });
}

// Helper function to iterate on aliases of `root` and capture the reads.
static void getAliasingReads(DenseSet<OpOperand *> &res, Value root,
                             const BufferizationAliasInfo &aliasInfo,
                             const AnalysisState &state) {
  aliasInfo.applyOnAliases(root, [&](Value alias) {
    for (auto &use : alias.getUses())
      // Read of a value that aliases root.
      if (state.bufferizesToMemoryRead(use))
        res.insert(&use);
  });
}

/// Return true if bufferizing `operand` inplace would create a conflict.
A read 604 /// R and a write W of the same alias set is a conflict if inplace bufferization 605 /// of W changes the value read by R to a value different from the one that 606 /// would be expected by tracing back R's origin through SSA use-def chains. 607 /// A conflict can only be introduced by a new alias and/or an inplace 608 /// bufferization decision. 609 /// 610 /// Example: 611 /// %0 = tensor.extract_slice %t[...][...][1, 1] {inplace?} 612 /// %1 = vector.transfer_write %v1, %t {inplace} : vector<5xf32>, tensor<?xf32> 613 /// %e = tensor.extract_slice %1 614 /// %2 = vector.transfer_write %v2, %0 {inplace} : vector<6xf32>, tensor<?xf32> 615 /// %3 = vector.transfer_read %e, %cst : tensor<?xf32>, vector<7xf32> 616 /// 617 /// In the above example, the two TransferWriteOps have already been decided to 618 /// bufferize inplace. Bufferizing the ExtractSliceOp inplace would create a 619 /// conflict because: 620 /// * According to SSA use-def chains, we expect to read the result of %1. 621 /// * However, adding an alias {%0, %t} would mean that the second 622 /// TransferWriteOp overwrites the first one. Therefore, the TransferReadOp 623 /// would no longer be reading the result of %1. 624 /// 625 /// If `checkConsistencyOnly` is true, this function checks if there is a 626 /// read-after-write conflict without bufferizing `operand` inplace. This would 627 /// indicate a problem with the current inplace bufferization decisions. 628 /// 629 /// Note: If `checkConsistencyOnly`, this function may be called with a null 630 /// OpResult. In that case, only the consistency of bufferization decisions 631 /// involving aliases of the given OpOperand are checked. 632 static bool wouldCreateReadAfterWriteInterference( 633 OpOperand &operand, const DominanceInfo &domInfo, AnalysisState &state, 634 const BufferizationAliasInfo &aliasInfo, 635 bool checkConsistencyOnly = false) { 636 // Collect reads and writes of all aliases of OpOperand and OpResult. 
637 DenseSet<OpOperand *> usesRead, usesWrite; 638 getAliasingReads(usesRead, operand.get(), aliasInfo, state); 639 getAliasingInplaceWrites(usesWrite, operand.get(), aliasInfo, state); 640 for (OpResult result : state.getAliasingOpResult(operand)) { 641 getAliasingReads(usesRead, result, aliasInfo, state); 642 getAliasingInplaceWrites(usesWrite, result, aliasInfo, state); 643 } 644 if (!checkConsistencyOnly && state.bufferizesToMemoryWrite(operand)) 645 usesWrite.insert(&operand); 646 647 return hasReadAfterWriteInterference(usesRead, usesWrite, domInfo, state, 648 aliasInfo); 649 } 650 651 /// Check the reverse SSA use-def chain (following aliasing OpOperands) for 652 /// non-writable tensor values. Stop searching when an out-of-place bufferized 653 /// OpOperand was found (or when the OpOperand was not bufferized yet). 654 /// `currentOpOperand` is assumed to be in-place, even if that decision was not 655 /// materialized in `aliasInfo` yet. 656 static bool 657 hasPrecedingAliasingNonWritableTensor(Value value, OpOperand *currentOpOperand, 658 const BufferizationAliasInfo &aliasInfo, 659 const OneShotAnalysisState &state) { 660 SmallVector<Value> worklist; 661 worklist.push_back(value); 662 while (!worklist.empty()) { 663 Value nextVal = worklist.pop_back_val(); 664 if (!state.isWritable(nextVal)) 665 return true; 666 667 // If `nextVal` is not a BlockArgument: End of use-def chain reached. 668 auto opResult = nextVal.dyn_cast<OpResult>(); 669 if (!opResult) 670 continue; 671 672 // Follow reverse SSA use-def chain. 673 SmallVector<OpOperand *> aliasingOpOperands = 674 state.getAliasingOpOperand(opResult); 675 for (OpOperand *opOperand : aliasingOpOperands) 676 if (aliasInfo.isInPlace(*opOperand) || currentOpOperand == opOperand) 677 worklist.push_back(opOperand->get()); 678 } 679 return false; 680 } 681 682 /// Return true if bufferizing `operand` inplace would create a write to a 683 /// non-writable buffer. 
684 static bool wouldCreateWriteToNonWritableBuffer( 685 OpOperand &operand, const BufferizationAliasInfo &aliasInfo, 686 OneShotAnalysisState &state, bool checkConsistencyOnly = false) { 687 // Collect writes of all aliases of OpOperand and OpResult. 688 DenseSet<OpOperand *> usesWrite; 689 getAliasingInplaceWrites(usesWrite, operand.get(), aliasInfo, state); 690 for (OpResult result : state.getAliasingOpResult(operand)) { 691 getAliasingInplaceWrites(usesWrite, result, aliasInfo, state); 692 } 693 if (!checkConsistencyOnly && state.bufferizesToMemoryWrite(operand)) 694 usesWrite.insert(&operand); 695 696 // Assuming that `operand` bufferizes in-place: For each write (to each 697 // alias), check if there is a non-writable tensor in the reverse SSA use-def 698 // chain. 699 for (OpOperand *uWrite : usesWrite) 700 if (hasPrecedingAliasingNonWritableTensor(uWrite->get(), &operand, 701 aliasInfo, state)) 702 return true; 703 704 return false; 705 } 706 707 //===----------------------------------------------------------------------===// 708 // Bufferization analyses. 709 //===----------------------------------------------------------------------===// 710 711 /// Determine if `operand` can be bufferized in-place. 712 static LogicalResult bufferizableInPlaceAnalysisImpl( 713 OpOperand &operand, BufferizationAliasInfo &aliasInfo, 714 OneShotAnalysisState &state, const DominanceInfo &domInfo) { 715 bool foundInterference = 716 wouldCreateWriteToNonWritableBuffer(operand, aliasInfo, state) || 717 wouldCreateReadAfterWriteInterference(operand, domInfo, state, aliasInfo); 718 719 if (foundInterference) 720 aliasInfo.bufferizeOutOfPlace(operand); 721 else 722 aliasInfo.bufferizeInPlace(operand, state); 723 724 return success(); 725 } 726 727 /// Analyze the `ops` to determine which OpOperands are inplaceable. Walk ops in 728 /// reverse and bufferize ops greedily. This is a good starter heuristic. 
729 /// 730 /// Even if an op does not read or write, it may still create an alias when 731 /// bufferized in-place. An example of such ops is tensor.extract_slice. 732 /// 733 /// Rationale for bufferizing `%1 = tensor.extract_slice %0[...]` inplace: 734 /// 735 /// When bufferized out of place, an ExtractSliceOp lowers to alloc + copy. This 736 /// cannot change the flow of information for either the source or the 737 /// result buffers. 738 /// 739 /// When bufferized inplace, an ExtractSliceOp does not by itself create any 740 /// read or write from memory. Instead, it has the effect of merging the alias 741 /// sets of the source and the result buffers. 742 /// 743 /// An analysis is required to ensure inplace bufferization would not result in 744 /// RaW dependence violations. 745 static LogicalResult inPlaceAnalysis(SmallVector<Operation *> &ops, 746 BufferizationAliasInfo &aliasInfo, 747 OneShotAnalysisState &state, 748 const DominanceInfo &domInfo, 749 unsigned analysisFuzzerSeed = 0) { 750 if (analysisFuzzerSeed) { 751 // This is a fuzzer. For testing purposes only. Randomize the order in which 752 // operations are analyzed. The bufferization quality is likely worse, but 753 // we want to make sure that no assertions are triggered anywhere. 754 std::mt19937 g(analysisFuzzerSeed); 755 llvm::shuffle(ops.begin(), ops.end(), g); 756 } 757 758 // Walk ops in reverse for better interference analysis. 759 for (Operation *op : reverse(ops)) 760 for (OpOperand &opOperand : op->getOpOperands()) 761 if (opOperand.get().getType().isa<TensorType>()) 762 if (auto bufferizableOp = state.getOptions().dynCastBufferizableOp(op)) 763 if (failed(bufferizableInPlaceAnalysisImpl(opOperand, aliasInfo, 764 state, domInfo))) 765 return failure(); 766 767 return success(); 768 } 769 770 /// Return true if the given op has a tensor result or a tensor operand. 
static bool hasTensorSemantics(Operation *op) {
  bool hasTensorResult = any_of(op->getResultTypes(), isaTensor);
  bool hasTensorOperand = any_of(op->getOperandTypes(), isaTensor);
  return hasTensorResult || hasTensorOperand;
}

/// Analyze all ops that are contained in `op`: collect all ops with tensor
/// semantics and run the in-place analysis on them (in reverse order).
static LogicalResult inPlaceAnalysis(Operation *op,
                                     BufferizationAliasInfo &aliasInfo,
                                     OneShotAnalysisState &state,
                                     const DominanceInfo &domInfo,
                                     unsigned analysisFuzzerSeed = 0) {
  // Collect ops so we can build our own reverse traversal. (The walk itself
  // visits ops in forward order.)
  SmallVector<Operation *> ops;
  op->walk([&](Operation *op) {
    // No tensors => no buffers. Such ops play no role in the analysis.
    if (!hasTensorSemantics(op))
      return;
    ops.push_back(op);
  });

  return inPlaceAnalysis(ops, aliasInfo, state, domInfo, analysisFuzzerSeed);
}

/// Analyze equivalence of tied OpResult/OpOperand pairs of the given ops.
/// For every in-place OpOperand whose tied OpResult has an `Equivalent`
/// buffer relation, merge the equivalence classes of the result and the
/// operand value in `aliasInfo`.
static void equivalenceAnalysis(SmallVector<Operation *> &ops,
                                BufferizationAliasInfo &aliasInfo,
                                AnalysisState &state) {
  for (Operation *op : ops)
    if (auto bufferizableOp = state.getOptions().dynCastBufferizableOp(op))
      for (OpResult opResult : op->getOpResults())
        if (opResult.getType().isa<TensorType>())
          for (OpOperand *opOperand :
               bufferizableOp.getAliasingOpOperand(opResult, state))
            // Only in-place operands share a buffer with their tied result.
            if (state.isInPlace(*opOperand))
              if (bufferizableOp.bufferRelation(opResult, state) ==
                  BufferRelation::Equivalent)
                aliasInfo.unionEquivalenceClasses(opResult, opOperand->get());
}

/// Analyze equivalence of tied OpResult/OpOperand pairs of all ops contained
/// in `op`.
static void equivalenceAnalysis(Operation *op,
                                BufferizationAliasInfo &aliasInfo,
                                AnalysisState &state) {
  // Traverse ops in PostOrder: Nested ops first, then enclosing ops.
  SmallVector<Operation *> ops;
  op->walk<WalkOrder::PostOrder>([&](Operation *op) {
    // No tensors => no buffers. Only ops with at least one tensor result can
    // have tied OpResult/OpOperand pairs worth analyzing.
    if (none_of(op->getResultTypes(), isaTensor))
      return;
    ops.push_back(op);
  });

  equivalenceAnalysis(ops, aliasInfo, state);
}

/// Assert that the current bufferization decisions are consistent, i.e., that
/// no tensor OpOperand already exhibits a RaW interference before any
/// bufferization decision has been made. Emits an error on the offending op.
static LogicalResult
checkAliasInfoConsistency(Operation *op, const DominanceInfo &domInfo,
                          AnalysisState &state,
                          const BufferizationAliasInfo &aliasInfo) {
  const BufferizationOptions &options = state.getOptions();
  // Remember the op that caused the walk to be interrupted, so it can be
  // reported after the walk.
  Operation *inconsistentOp = nullptr;
  WalkResult walkResult = op->walk([&](Operation *op) {
    if (auto bufferizableOp = options.dynCastBufferizableOp(op))
      for (OpOperand &opOperand : op->getOpOperands())
        if (opOperand.get().getType().isa<TensorType>()) {
          if (wouldCreateReadAfterWriteInterference(
                  opOperand, domInfo, state, aliasInfo,
                  /*checkConsistencyOnly=*/true)) {
            // This error can happen if certain "mustBufferizeInPlace" interface
            // methods are implemented incorrectly, such that the IR already has
            // a RaW conflict before making any bufferization decisions.
            inconsistentOp = op;
            return WalkResult::interrupt();
          }
        }
    return WalkResult::advance();
  });

  if (walkResult.wasInterrupted())
    return inconsistentOp->emitError("input IR has RaW conflict");
  return success();
}

/// Annotate the IR with the result of the analysis. For testing/debugging only.
/// Each tensor OpOperand of a bufferizable op is tagged with its in-place /
/// out-of-place decision (via `setInPlaceOpOperand`).
static void
annotateOpsWithBufferizationMarkers(Operation *op,
                                    const BufferizationAliasInfo &aliasInfo,
                                    AnalysisState &state) {
  op->walk([&](Operation *op) {
    if (auto bufferizableOp = state.getOptions().dynCastBufferizableOp(op))
      for (OpOperand &opOperand : op->getOpOperands())
        if (opOperand.get().getType().isa<TensorType>())
          setInPlaceOpOperand(opOperand, aliasInfo.isInPlace(opOperand));
  });
}

/// Assert that IR is in destination-passing style. I.e., every value that is
/// returned or yielded from a block is:
/// * aliasing a bbArg of that block or a parent block, or
/// * aliasing an OpResult of an op in a parent block.
///
/// Example:
/// ```
/// %0 = "some_op" : tensor<?xf32>
/// %1 = scf.if %c -> (tensor<?xf32>) {
///   scf.yield %0 : tensor<?xf32>
/// } else {
///   %t = linalg.alloc_tensor : tensor<?xf32>
///   scf.yield %t : tensor<?xf32>
/// }
/// ```
/// In the above example, the first scf.yield op satisfies destination-passing
/// style because the yielded value %0 is defined in the parent block. The
/// second scf.yield op does not satisfy destination-passing style because the
/// yielded value %t is defined in the same block as the scf.yield op.
// TODO: The current implementation checks for equivalent values instead of
// aliasing values, which is stricter than needed. We can currently not check
// for aliasing values because the analysis is a maybe-alias analysis and we
// need a must-alias analysis here.
//
// Note: `newOps` is currently unused in this function.
static LogicalResult
assertDestinationPassingStyle(Operation *op, AnalysisState &state,
                              BufferizationAliasInfo &aliasInfo,
                              SmallVector<Operation *> &newOps) {
  LogicalResult status = success();
  DominanceInfo domInfo(op);
  op->walk([&](Operation *returnOp) {
    // Only inspect region-terminating return-like ops that are not filtered
    // out by the bufferization options.
    if (!isRegionReturnLike(returnOp) ||
        !state.getOptions().isOpAllowed(returnOp))
      return WalkResult::advance();

    for (OpOperand &returnValOperand : returnOp->getOpOperands()) {
      Value returnVal = returnValOperand.get();
      // Skip non-tensor values.
      if (!returnVal.getType().isa<TensorType>())
        continue;

      // Search the equivalence class of the returned value for a value that
      // is defined outside of the returning region.
      bool foundEquivValue = false;
      aliasInfo.applyOnEquivalenceClass(returnVal, [&](Value equivVal) {
        if (auto bbArg = equivVal.dyn_cast<BlockArgument>()) {
          // A bbArg qualifies if its owning op strictly encloses `returnOp`.
          Operation *definingOp = bbArg.getOwner()->getParentOp();
          if (definingOp->isProperAncestor(returnOp))
            foundEquivValue = true;
          return;
        }

        // An OpResult qualifies if its defining op is in (an ancestor block
        // of) the parent block and executes before `returnOp`.
        Operation *definingOp = equivVal.getDefiningOp();
        if (definingOp->getBlock()->findAncestorOpInBlock(
                *returnOp->getParentOp()))
          // Skip ops that happen after `returnOp` and parent ops.
          if (happensBefore(definingOp, returnOp, domInfo))
            foundEquivValue = true;
      });

      if (!foundEquivValue)
        status =
            returnOp->emitError()
            << "operand #" << returnValOperand.getOperandNumber()
            << " of ReturnLike op does not satisfy destination passing style";
    }

    return WalkResult::advance();
  });

  return status;
}

LogicalResult bufferization::analyzeOp(Operation *op,
                                       OneShotAnalysisState &state) {
  DominanceInfo domInfo(op);
  BufferizationAliasInfo &aliasInfo = state.getAliasInfo();
  const auto &options =
      static_cast<const OneShotBufferizationOptions &>(state.getOptions());

  // Catch incorrect API usage.
  assert((state.hasDialectState(func::FuncDialect::getDialectNamespace()) ||
          !options.bufferizeFunctionBoundaries) &&
         "must use ModuleBufferize to bufferize function boundaries");

  // Bail out early if the input IR is already inconsistent.
  if (failed(checkAliasInfoConsistency(op, domInfo, state, aliasInfo)))
    return failure();

  // If the analysis fails, just return.
  if (failed(inPlaceAnalysis(op, aliasInfo, state, domInfo,
                             options.analysisFuzzerSeed)))
    return failure();
  equivalenceAnalysis(op, aliasInfo, state);

  // The remaining checks accumulate failures instead of returning early, so
  // that extra analysis data is still gathered and (in test mode) the IR is
  // still annotated.
  bool failedAnalysis = false;
  if (!options.allowReturnAllocs) {
    SmallVector<Operation *> newOps;
    failedAnalysis |=
        failed(assertDestinationPassingStyle(op, state, aliasInfo, newOps));
  }

  // Gather some extra analysis data.
  state.gatherYieldedTensors(op);
  state.gatherUndefinedTensorUses(op);

  // Analysis verification: After setting up alias/equivalence sets, each op
  // can check for expected invariants/limitations and fail the analysis if
  // necessary.
  op->walk([&](Operation *op) {
    if (BufferizableOpInterface bufferizableOp =
            options.dynCastBufferizableOp(op))
      failedAnalysis |= failed(bufferizableOp.verifyAnalysis(state));
  });

  // Annotate operations if we only want to report the analysis.
  if (options.testAnalysisOnly)
    annotateOpsWithBufferizationMarkers(op, aliasInfo, state);

  return success(!failedAnalysis);
}

LogicalResult
bufferization::runOneShotBufferize(Operation *op,
                                   const OneShotBufferizationOptions &options) {
  // Phase 1: analysis. Decides which OpOperands bufferize in-place.
  OneShotAnalysisState state(op, options);
  if (failed(analyzeOp(op, state)))
    return failure();
  if (options.testAnalysisOnly)
    return success();
  // Phase 2: actual bufferization, driven by the analysis results in `state`.
  return bufferizeOp(op, state);
}