1 //===- OneShotAnalysis.cpp - One-Shot (Single Pass) Analysis --------------===// 2 // 3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. 4 // See https://llvm.org/LICENSE.txt for license information. 5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception 6 // 7 //===----------------------------------------------------------------------===// 8 // 9 // One-Shot Analysis analyzes function bodies. Function boundaries (FuncOp 10 // bbArgs, CallOps, ReturnOps) are treated as "unknown" ops. 11 // ModuleBufferization.cpp is an extension of One-Shot Analysis for simple 12 // call graphs. 13 // 14 // One-Shot Bufferize consists of two phases. 15 // 16 // 1. Analyze ops to decide which OpResults can bufferize inplace, i.e., without 17 // inserting buffer copies. The analysis queries op bufferization semantics 18 // via `BufferizableOpInterface`. 19 // 2. Bufferize ops by calling `BufferizableOpInterface::bufferize`. This 20 // function does not generate buffer copies for OpResults that were decided 21 // to bufferize inplace during the analysis phase. 22 // 23 // This file contains only the analysis. The actual bufferization is implemented 24 // via `bufferizeOp` (Bufferize.h). For convenience, this file also contains a 25 // helper function `runOneShotBufferize` that analyzes an op (and its nested 26 // ops) and then bufferizes it. 27 // 28 // Inplace bufferization decisions are passed from the analysis to the 29 // bufferization phase via `AnalysisState` and `BufferizationAliasInfo`. 30 // They can be printed for debugging purposes with `testAnalysisOnly`. 31 // 32 // Ops that do not implement `BufferizableOpInterface` can be analyzed but are 33 // treated conservatively. E.g., the analysis has to assume that their tensor 34 // OpOperands bufferize to memory writes. While such ops can be analyzed, they 35 // are not bufferized and remain in the IR. to_tensor and to_memref ops are 36 // inserted at the bufferization boundary. 
//
// This analysis caters to high-performance codegen where buffer reuse is deemed
// critical: the analysis should fail if the bufferized form of the function
// needs to return a buffer, unless `allowReturnAllocs` is enabled.

#include "mlir/Dialect/Bufferization/Transforms/OneShotAnalysis.h"

#include <random>

#include "mlir/Dialect/Bufferization/IR/BufferizableOpInterface.h"
#include "mlir/Dialect/Bufferization/IR/Bufferization.h"
#include "mlir/Dialect/Bufferization/Transforms/Bufferize.h"
#include "mlir/Dialect/Bufferization/Transforms/TensorCopyInsertion.h"
#include "mlir/Dialect/Func/IR/FuncOps.h"
#include "mlir/Dialect/MemRef/IR/MemRef.h"
#include "mlir/IR/AsmState.h"
#include "mlir/IR/Dominance.h"
#include "mlir/IR/Operation.h"
#include "mlir/IR/TypeUtilities.h"
#include "mlir/Interfaces/ControlFlowInterfaces.h"
#include "llvm/ADT/DenseSet.h"
#include "llvm/ADT/SetVector.h"

using namespace mlir;
using namespace mlir::bufferization;

/// Return `true` if `t` is a tensor type.
static bool isaTensor(Type t) { return t.isa<TensorType>(); }

//===----------------------------------------------------------------------===//
// Bufferization-specific attribute manipulation.
// These are for testing and debugging only. Bufferization information is
// stored in BufferizationAliasInfo. When run with `testAnalysisOnly`, the IR
// is annotated with the results of the analysis (copied from
// BufferizationAliasInfo), so that they can be checked in tests.
//===----------------------------------------------------------------------===//

/// Attribute marker to specify op operands that bufferize in-place. (Despite
/// the variable name, the attribute records per-*operand* decisions; see
/// `setInPlaceOpOperand` below.)
constexpr StringLiteral kInPlaceResultsAttrName = "__inplace_operands_attr__";

/// Mark whether OpOperand will be bufferized inplace.
77 static void setInPlaceOpOperand(OpOperand &opOperand, bool inPlace) { 78 Operation *op = opOperand.getOwner(); 79 auto attr = 80 op->getAttr(kInPlaceResultsAttrName).dyn_cast_or_null<ArrayAttr>(); 81 SmallVector<StringRef> inPlaceVector; 82 if (attr) { 83 inPlaceVector = SmallVector<StringRef>( 84 llvm::to_vector<4>(attr.getAsValueRange<StringAttr>())); 85 } else { 86 inPlaceVector = SmallVector<StringRef>(op->getNumOperands(), "none"); 87 for (OpOperand &opOperand : op->getOpOperands()) 88 if (opOperand.get().getType().isa<TensorType>()) 89 inPlaceVector[opOperand.getOperandNumber()] = "false"; 90 } 91 92 inPlaceVector[opOperand.getOperandNumber()] = inPlace ? "true" : "false"; 93 op->setAttr(kInPlaceResultsAttrName, 94 OpBuilder(op).getStrArrayAttr(inPlaceVector)); 95 } 96 97 //===----------------------------------------------------------------------===// 98 // BufferizationAliasInfo 99 //===----------------------------------------------------------------------===// 100 101 BufferizationAliasInfo::BufferizationAliasInfo(Operation *rootOp) { 102 rootOp->walk([&](Operation *op) { 103 for (Value v : op->getResults()) 104 if (v.getType().isa<TensorType>()) 105 createAliasInfoEntry(v); 106 for (Region &r : op->getRegions()) 107 for (Block &b : r.getBlocks()) 108 for (auto bbArg : b.getArguments()) 109 if (bbArg.getType().isa<TensorType>()) 110 createAliasInfoEntry(bbArg); 111 }); 112 } 113 114 /// Add a new entry for `v` in the `aliasInfo` and `equivalentInfo`. In the 115 /// beginning the alias and equivalence sets only contain `v` itself. 116 void BufferizationAliasInfo::createAliasInfoEntry(Value v) { 117 aliasInfo.insert(v); 118 equivalentInfo.insert(v); 119 } 120 121 /// Insert an info entry for `newValue` and merge its alias set with that of 122 /// `alias`. 
123 void BufferizationAliasInfo::insertNewBufferAlias(Value newValue, Value alias) { 124 createAliasInfoEntry(newValue); 125 aliasInfo.unionSets(newValue, alias); 126 } 127 128 /// Insert an info entry for `newValue` and merge its alias set with that of 129 /// `alias`. Additionally, merge their equivalence classes. 130 void BufferizationAliasInfo::insertNewBufferEquivalence(Value newValue, 131 Value alias) { 132 insertNewBufferAlias(newValue, alias); 133 equivalentInfo.unionSets(newValue, alias); 134 } 135 136 /// Return `true` if a value was marked as in-place bufferized. 137 bool BufferizationAliasInfo::isInPlace(OpOperand &operand) const { 138 return inplaceBufferized.contains(&operand); 139 } 140 141 /// Set the inPlace bufferization spec to true. 142 void BufferizationAliasInfo::bufferizeInPlace(OpOperand &operand, 143 AnalysisState &state) { 144 markInPlace(operand); 145 for (OpResult result : state.getAliasingOpResult(operand)) 146 aliasInfo.unionSets(result, operand.get()); 147 } 148 149 /// Set the inPlace bufferization spec to false. 150 void BufferizationAliasInfo::bufferizeOutOfPlace(OpOperand &operand) { 151 assert(!inplaceBufferized.contains(&operand) && 152 "OpOperand was already decided to bufferize inplace"); 153 } 154 155 /// Apply `fun` to all the members of the equivalence class of `v`. 156 void BufferizationAliasInfo::applyOnEquivalenceClass( 157 Value v, function_ref<void(Value)> fun) const { 158 auto leaderIt = equivalentInfo.findLeader(v); 159 for (auto mit = leaderIt, meit = equivalentInfo.member_end(); mit != meit; 160 ++mit) { 161 fun(*mit); 162 } 163 } 164 165 /// Apply `fun` to all aliases of `v`. 
166 void BufferizationAliasInfo::applyOnAliases( 167 Value v, function_ref<void(Value)> fun) const { 168 auto leaderIt = aliasInfo.findLeader(v); 169 for (auto mit = leaderIt, meit = aliasInfo.member_end(); mit != meit; ++mit) { 170 fun(*mit); 171 } 172 } 173 174 BufferizationAliasInfo::EquivalenceClassRangeType 175 BufferizationAliasInfo::getAliases(Value v) const { 176 DenseSet<Value> res; 177 auto it = aliasInfo.findValue(aliasInfo.getLeaderValue(v)); 178 for (auto mit = aliasInfo.member_begin(it), meit = aliasInfo.member_end(); 179 mit != meit; ++mit) { 180 res.insert(static_cast<Value>(*mit)); 181 } 182 return BufferizationAliasInfo::EquivalenceClassRangeType( 183 aliasInfo.member_begin(it), aliasInfo.member_end()); 184 } 185 186 //===----------------------------------------------------------------------===// 187 // OneShotAnalysisState 188 //===----------------------------------------------------------------------===// 189 190 OneShotAnalysisState::OneShotAnalysisState( 191 Operation *op, const OneShotBufferizationOptions &options) 192 : AnalysisState(options), aliasInfo(op) { 193 // Set up alias sets for OpResults that must bufferize in-place. This should 194 // be done before making any other bufferization decisions. 
195 op->walk([&](BufferizableOpInterface bufferizableOp) { 196 if (!options.isOpAllowed(bufferizableOp)) 197 return WalkResult::skip(); 198 for (OpOperand &opOperand : bufferizableOp->getOpOperands()) { 199 if (opOperand.get().getType().isa<TensorType>()) 200 if (bufferizableOp.mustBufferizeInPlace(opOperand, *this)) { 201 for (OpResult opResult : 202 bufferizableOp.getAliasingOpResult(opOperand, *this)) 203 aliasInfo.unionAliasSets(opOperand.get(), opResult); 204 aliasInfo.markInPlace(opOperand); 205 } 206 } 207 return WalkResult::advance(); 208 }); 209 } 210 211 bool OneShotAnalysisState::isInPlace(OpOperand &opOperand) const { 212 return aliasInfo.isInPlace(opOperand); 213 } 214 215 bool OneShotAnalysisState::areEquivalentBufferizedValues(Value v1, 216 Value v2) const { 217 return aliasInfo.areEquivalentBufferizedValues(v1, v2); 218 } 219 220 bool OneShotAnalysisState::areAliasingBufferizedValues(Value v1, 221 Value v2) const { 222 return aliasInfo.areAliasingBufferizedValues(v1, v2); 223 } 224 225 // Gather yielded tensors in `yieldedTensors` by querying all aliases. This is 226 // to ensure that such information is available during bufferization time. 227 // Alias information can no longer be queried through BufferizationAliasInfo 228 // once we have started modifying the IR. 229 void OneShotAnalysisState::gatherYieldedTensors(Operation *op) { 230 op->walk([&](Operation *returnOp) { 231 if (!isRegionReturnLike(returnOp) || !getOptions().isOpAllowed(returnOp)) 232 return WalkResult::advance(); 233 234 for (OpOperand &returnValOperand : returnOp->getOpOperands()) { 235 Value returnVal = returnValOperand.get(); 236 // Skip non-tensor values. 237 if (!returnVal.getType().isa<TensorType>()) 238 continue; 239 240 // Add all aliases of the returned value. But only the ones that are in 241 // the same block. 
242 aliasInfo.applyOnAliases(returnVal, [&](Value v) { 243 if (auto bbArg = v.dyn_cast<BlockArgument>()) { 244 if (bbArg.getOwner()->getParentOp() == returnOp->getParentOp()) 245 yieldedTensors.insert(bbArg); 246 return; 247 } 248 Operation *definingOp = v.getDefiningOp(); 249 if (definingOp->getParentOp() == returnOp->getParentOp()) 250 yieldedTensors.insert(v); 251 }); 252 } 253 254 return WalkResult::advance(); 255 }); 256 } 257 258 void OneShotAnalysisState::gatherUndefinedTensorUses(Operation *op) { 259 op->walk([&](Operation *op) { 260 // Skip unknown ops. 261 auto bufferizableOp = getOptions().dynCastBufferizableOp(op); 262 if (!bufferizableOp) 263 return WalkResult::skip(); 264 265 // Check all tensor OpResults. 266 for (OpResult opResult : op->getOpResults()) { 267 if (!opResult.getType().isa<TensorType>()) 268 continue; 269 270 // If there is no preceding memory write, the tensor contents are 271 // undefined. 272 // Note: If `findLastPrecedingWrite` reaches the end of the reverse SSA 273 // use-def chain, it returns that value, regardless of whether it is a 274 // memory write or not. 
275 SetVector<Value> lastWrites = findLastPrecedingWrite(opResult); 276 bool isUndefined = llvm::none_of(lastWrites, [&](Value lastWrite) { 277 if (auto bufferizableOp = getOptions().dynCastBufferizableOp(lastWrite)) 278 return bufferizableOp.isMemoryWrite(lastWrite.cast<OpResult>(), 279 *this); 280 return true; 281 }); 282 if (isUndefined) 283 for (OpOperand &use : opResult.getUses()) 284 undefinedTensorUses.insert(&use); 285 } 286 287 return WalkResult::advance(); 288 }); 289 } 290 291 bool OneShotAnalysisState::hasUndefinedContents(OpOperand *opOperand) const { 292 return undefinedTensorUses.contains(opOperand); 293 } 294 295 bool OneShotAnalysisState::isTensorYielded(Value tensor) const { 296 return yieldedTensors.contains(tensor); 297 } 298 299 bool OneShotAnalysisState::isValueWritten(Value value) const { 300 bool isWritten = false; 301 aliasInfo.applyOnAliases(value, [&](Value val) { 302 for (OpOperand &use : val.getUses()) 303 if (isInPlace(use) && bufferizesToMemoryWrite(use)) 304 isWritten = true; 305 }); 306 return isWritten; 307 } 308 309 bool OneShotAnalysisState::isWritable(Value value) const { 310 // TODO: Out-of-place bufferized value could be considered writable. 311 if (auto bufferizableOp = getOptions().dynCastBufferizableOp(value)) 312 return bufferizableOp.isWritable(value, *this); 313 314 // Query BufferizableOpInterface to see if the BlockArgument is writable. 315 if (auto bbArg = value.dyn_cast<BlockArgument>()) 316 if (auto bufferizableOp = 317 getOptions().dynCastBufferizableOp(bbArg.getOwner()->getParentOp())) 318 return bufferizableOp.isWritable(bbArg, *this); 319 320 // Not a bufferizable op: The conservative answer is "not writable". 321 return false; 322 } 323 324 //===----------------------------------------------------------------------===// 325 // Bufferization-specific alias analysis. 
326 //===----------------------------------------------------------------------===// 327 328 /// Return true if opOperand has been decided to bufferize in-place. 329 static bool isInplaceMemoryWrite(OpOperand &opOperand, 330 const BufferizationAliasInfo &aliasInfo, 331 const AnalysisState &state) { 332 // OpOperands that do not bufferize to a memory write do not write in-place. 333 if (!state.bufferizesToMemoryWrite(opOperand)) 334 return false; 335 // Check current bufferization decisions. 336 return aliasInfo.isInPlace(opOperand); 337 } 338 339 /// Return true if `a` happens before `b`, i.e., `a` or one of its ancestors 340 /// properly dominates `b` and `b` is not inside `a`. 341 static bool happensBefore(Operation *a, Operation *b, 342 const DominanceInfo &domInfo) { 343 do { 344 // TODO: Instead of isProperAncestor + properlyDominates, we should use 345 // properlyDominatesImpl(a, b, /*enclosingOpOk=*/false) 346 if (a->isProperAncestor(b)) 347 return false; 348 if (domInfo.properlyDominates(a, b)) 349 return true; 350 } while ((a = a->getParentOp())); 351 return false; 352 } 353 354 /// For each given value, find the closest enclosing repetitive region. If this 355 /// is the same region for each value, return it. Otherwise return None. 356 /// Note: If there is no enclosing repetitive region, return nullptr. 357 static Optional<Region *> 358 getCommonEnclosingRepetitiveRegion(ArrayRef<Value> values) { 359 if (values.empty()) 360 return None; 361 Region *r = getEnclosingRepetitiveRegion(values.front()); 362 for (Value value : values.drop_front()) 363 if (getEnclosingRepetitiveRegion(value) != r) 364 return None; 365 return r; 366 } 367 368 /// Return `true` if the given tensor value is a memory write. Most values are 369 /// tensor writes, but ops that define a tensor SSA value without specifying its 370 /// contents (e.g., alloc_tensor) are not. 
static bool isMemoryWrite(Value value, const AnalysisState &state) {
  auto opResult = value.dyn_cast<OpResult>();
  // Block arguments are conservatively treated as memory writes.
  if (!opResult)
    return true;
  // Unknown (non-bufferizable) ops are conservatively treated as writes.
  auto bufferizableOp = state.getOptions().dynCastBufferizableOp(value);
  if (!bufferizableOp)
    return true;
  return bufferizableOp.isMemoryWrite(opResult, state);
}

/// Annotate IR with details about the detected RaW conflict.
/// Marks the reading op, the conflicting writing op, and the last write with
/// attributes that share a unique "C_<n>" prefix. Debugging only; the static
/// counter makes this not thread-safe.
static void annotateConflict(OpOperand *uRead, OpOperand *uConflictingWrite,
                             Value lastWrite) {
  static uint64_t counter = 0;
  Operation *readingOp = uRead->getOwner();
  Operation *conflictingWritingOp = uConflictingWrite->getOwner();

  OpBuilder b(conflictingWritingOp->getContext());
  std::string id = "C_" + std::to_string(counter++);

  std::string conflictingWriteAttr =
      id +
      "[CONFL-WRITE: " + std::to_string(uConflictingWrite->getOperandNumber()) +
      "]";
  conflictingWritingOp->setAttr(conflictingWriteAttr, b.getUnitAttr());

  std::string readAttr =
      id + "[READ: " + std::to_string(uRead->getOperandNumber()) + "]";
  readingOp->setAttr(readAttr, b.getUnitAttr());

  if (auto opResult = lastWrite.dyn_cast<OpResult>()) {
    std::string lastWriteAttr = id + "[LAST-WRITE: result " +
                                std::to_string(opResult.getResultNumber()) +
                                "]";
    opResult.getDefiningOp()->setAttr(lastWriteAttr, b.getUnitAttr());
  } else {
    auto bbArg = lastWrite.cast<BlockArgument>();
    std::string lastWriteAttr =
        id + "[LAST-WRITE: bbArg " + std::to_string(bbArg.getArgNumber()) + "]";
    bbArg.getOwner()->getParentOp()->setAttr(lastWriteAttr, b.getUnitAttr());
  }
}

/// Given sets of uses and writes, return true if there is a RaW conflict under
/// the assumption that all given reads/writes alias the same buffer and that
/// all given writes bufferize inplace.
///
/// A conflict is: According to SSA use-def chains, a read R is supposed to read
/// the result of a write W1. But because of bufferization decisions, R actually
/// reads another write W2.
static bool hasReadAfterWriteInterference(
    const DenseSet<OpOperand *> &usesRead,
    const DenseSet<OpOperand *> &usesWrite, const DominanceInfo &domInfo,
    AnalysisState &state, const BufferizationAliasInfo &aliasInfo) {
  const BufferizationOptions &options = state.getOptions();

  // Gather all written aliases. Skip over aliases that are not actual writes.
  SmallVector<Value> writtenAliases;
  for (OpOperand *uWrite : usesWrite)
    if (isMemoryWrite(uWrite->get(), state))
      writtenAliases.push_back(uWrite->get());
  // Find the inner-most enclosing repetitive region of each alias. If this is
  // the same region for every alias, save it in `repetitiveRegionOfWrites`.
  Optional<Region *> repetitiveRegionOfWrites =
      getCommonEnclosingRepetitiveRegion(writtenAliases);

  for (OpOperand *uRead : usesRead) {
    Operation *readingOp = uRead->getOwner();

    // Find most recent writes of uRead by following the SSA use-def chain.
    // E.g.:
    //
    // %0 = "writing_op"(%t) : tensor<?xf32> -> tensor<?xf32>
    // %1 = "aliasing_op"(%0) : tensor<?xf32> -> tensor<?xf32>
    // %2 = "reading_op"(%1) : tensor<?xf32> -> not_a_tensor_type
    //
    // In the above example, if uRead is the OpOperand of reading_op, lastWrite
    // is %0. Note that operations that create an alias but do not write (such
    // as ExtractSliceOp) are skipped.
    SetVector<Value> lastWrites = state.findLastPrecedingWrite(uRead->get());

    // Look for conflicting memory writes. Potential conflicts are writes to an
    // alias that have been decided to bufferize inplace.
    for (OpOperand *uConflictingWrite : usesWrite) {
      // Throughout this loop, check for multiple requirements that have to be
      // met for uConflictingWrite to be an actual conflict.
      Operation *conflictingWritingOp = uConflictingWrite->getOwner();

      // Check if conflictingWritingOp is in the same repetitive region as all
      // written aliases. If this is not the case, there is no meaningful
      // `happensBefore` relationship because conflictingWritingOp may be
      // executed multiple times. E.g.:
      //
      // %0 = ... : tensor<?xf32>
      // scf.for ... {
      //   "reading_op"(%0) : tensor<?xf32>
      //   %1 = "writing_op"(%0) : tensor<?xf32> -> tensor<?xf32>
      //   ...
      // }
      //
      // In the above example, reading_op happens before writing_op according to
      // op dominance. However, both ops may happen multiple times; in
      // particular, the second execution of reading_op happens after the first
      // execution of writing_op. This is problematic if the tensor they operate
      // on (%0) is defined outside of the loop.
      //
      // Counter example:
      //
      // scf.for ... {
      //   %0 = ... : tensor<?xf32>
      //   "reading_op"(%0) : tensor<?xf32>
      //   %1 = "writing_op"(%0) : tensor<?xf32> -> tensor<?xf32>
      //   ...
      // }
      //
      // In this example, %0 is in the same repetitive region as
      // conflictingWritingOp, so op dominance can be used to compute the
      // `happensBefore` relationship.
      //
      // Note: iter_args of loops are not aliases of their respective block
      // arguments, so op dominance can be used when analyzing ops that operate
      // on them.
      //
      // Note: If `writtenAliases` is empty, there are no memory writes outside
      // of the repetitive region of conflictingWritingOp, which means that all
      // relevant aliases are inside the same repetitive region.
      bool canUseOpDominance =
          writtenAliases.empty() ||
          repetitiveRegionOfWrites ==
              getEnclosingRepetitiveRegion(conflictingWritingOp);

      // No conflict if the readingOp dominates conflictingWritingOp, i.e., the
      // write is not visible when reading.
      //
      // Note: If ops are executed multiple times (e.g., because they are inside
      //       a loop), there may be no meaningful `happensBefore` relationship.
      if (canUseOpDominance &&
          happensBefore(readingOp, conflictingWritingOp, domInfo))
        continue;

      // No conflict if the reading use equals the use of the conflicting write.
      // A use cannot conflict with itself.
      //
      // Note: Just being the same op is not enough. It has to be the same use.
      // Note: If the op is executed multiple times (e.g., because it is inside
      //       a loop), it may be conflicting with itself.
      if (canUseOpDominance && uConflictingWrite == uRead)
        continue;

      // No conflict if the op interface says so.
      if (auto bufferizableOp = options.dynCastBufferizableOp(readingOp))
        if (bufferizableOp.isNotConflicting(uRead, uConflictingWrite, state))
          continue;

      // Also query the interface of the writing op (if different from the
      // reading op).
      if (conflictingWritingOp != readingOp)
        if (auto bufferizableOp =
                options.dynCastBufferizableOp(conflictingWritingOp))
          if (bufferizableOp.isNotConflicting(uRead, uConflictingWrite, state))
            continue;

      // Ops are not conflicting if they are in mutually exclusive regions.
      //
      // Note: If ops are executed multiple times (e.g., because they are inside
      //       a loop), mutually exclusive regions may be executed multiple
      //       times.
      if (canUseOpDominance &&
          insideMutuallyExclusiveRegions(readingOp, conflictingWritingOp))
        continue;

      // Check all possible last writes.
      for (Value lastWrite : lastWrites) {
        // No conflict if the conflicting write happens before the last
        // write.
        if (Operation *writingOp = lastWrite.getDefiningOp()) {
          if (happensBefore(conflictingWritingOp, writingOp, domInfo))
            // conflictingWritingOp happens before writingOp. No conflict.
            continue;
          // No conflict if conflictingWritingOp is contained in writingOp.
          if (writingOp->isProperAncestor(conflictingWritingOp))
            continue;
        } else {
          auto bbArg = lastWrite.cast<BlockArgument>();
          Block *block = bbArg.getOwner();
          if (!block->findAncestorOpInBlock(*conflictingWritingOp))
            // conflictingWritingOp happens outside of the block. No
            // conflict.
            continue;
        }

        // No conflict if the conflicting write and the last write are the same
        // use.
        SmallVector<OpResult> aliasingOpResult =
            state.getAliasingOpResult(*uConflictingWrite);
        if (aliasingOpResult.size() == 1 && aliasingOpResult[0] == lastWrite)
          continue;

        // All requirements are met. Conflict found!

        if (options.printConflicts)
          annotateConflict(uRead, uConflictingWrite, lastWrite);

        return true;
      }
    }
  }

  return false;
}

// Helper function to iterate on aliases of `root` and capture the writes.
// In-place memory writes to any alias of `root` are inserted into `res`.
static void getAliasingInplaceWrites(DenseSet<OpOperand *> &res, Value root,
                                     const BufferizationAliasInfo &aliasInfo,
                                     const AnalysisState &state) {
  aliasInfo.applyOnAliases(root, [&](Value alias) {
    for (auto &use : alias.getUses())
      // Inplace write to a value that aliases root.
      if (isInplaceMemoryWrite(use, aliasInfo, state))
        res.insert(&use);
  });
}

// Helper function to iterate on aliases of `root` and capture the reads.
// Memory reads of any alias of `root` are inserted into `res`.
static void getAliasingReads(DenseSet<OpOperand *> &res, Value root,
                             const BufferizationAliasInfo &aliasInfo,
                             const AnalysisState &state) {
  aliasInfo.applyOnAliases(root, [&](Value alias) {
    for (auto &use : alias.getUses())
      // Read to a value that aliases root.
      if (state.bufferizesToMemoryRead(use))
        res.insert(&use);
  });
}

/// Return true if bufferizing `operand` inplace would create a conflict.
/// A read
/// R and a write W of the same alias set is a conflict if inplace bufferization
/// of W changes the value read by R to a value different from the one that
/// would be expected by tracing back R's origin through SSA use-def chains.
/// A conflict can only be introduced by a new alias and/or an inplace
/// bufferization decision.
///
/// Example:
/// %0 = tensor.extract_slice %t[...][...][1, 1] {inplace?}
/// %1 = vector.transfer_write %v1, %t {inplace} : vector<5xf32>, tensor<?xf32>
/// %e = tensor.extract_slice %1
/// %2 = vector.transfer_write %v2, %0 {inplace} : vector<6xf32>, tensor<?xf32>
/// %3 = vector.transfer_read %e, %cst : tensor<?xf32>, vector<7xf32>
///
/// In the above example, the two TransferWriteOps have already been decided to
/// bufferize inplace. Bufferizing the ExtractSliceOp inplace would create a
/// conflict because:
/// * According to SSA use-def chains, we expect to read the result of %1.
/// * However, adding an alias {%0, %t} would mean that the second
///   TransferWriteOp overwrites the first one. Therefore, the TransferReadOp
///   would no longer be reading the result of %1.
///
/// If `checkConsistencyOnly` is true, this function checks if there is a
/// read-after-write conflict without bufferizing `operand` inplace. This would
/// indicate a problem with the current inplace bufferization decisions.
///
/// Note: If `checkConsistencyOnly`, this function may be called with a null
/// OpResult. In that case, only the consistency of bufferization decisions
/// involving aliases of the given OpOperand are checked.
static bool wouldCreateReadAfterWriteInterference(
    OpOperand &operand, const DominanceInfo &domInfo, AnalysisState &state,
    const BufferizationAliasInfo &aliasInfo,
    bool checkConsistencyOnly = false) {
  // Collect reads and writes of all aliases of OpOperand and OpResult.
  DenseSet<OpOperand *> usesRead, usesWrite;
  getAliasingReads(usesRead, operand.get(), aliasInfo, state);
  getAliasingInplaceWrites(usesWrite, operand.get(), aliasInfo, state);
  for (OpResult result : state.getAliasingOpResult(operand)) {
    getAliasingReads(usesRead, result, aliasInfo, state);
    getAliasingInplaceWrites(usesWrite, result, aliasInfo, state);
  }
  // Unless only checking consistency, also treat `operand` itself as a write
  // (this is the hypothetical in-place decision being evaluated).
  if (!checkConsistencyOnly && state.bufferizesToMemoryWrite(operand))
    usesWrite.insert(&operand);

  return hasReadAfterWriteInterference(usesRead, usesWrite, domInfo, state,
                                       aliasInfo);
}

/// Check the reverse SSA use-def chain (following aliasing OpOperands) for
/// non-writable tensor values. Stop searching when an out-of-place bufferized
/// OpOperand was found (or when the OpOperand was not bufferized yet).
/// `currentOpOperand` is assumed to be in-place, even if that decision was not
/// materialized in `aliasInfo` yet.
static bool
hasPrecedingAliasingNonWritableTensor(Value value, OpOperand *currentOpOperand,
                                      const BufferizationAliasInfo &aliasInfo,
                                      const OneShotAnalysisState &state) {
  SmallVector<Value> worklist;
  worklist.push_back(value);
  while (!worklist.empty()) {
    Value nextVal = worklist.pop_back_val();
    if (!state.isWritable(nextVal))
      return true;

    // If `nextVal` is not an OpResult (i.e., it is a BlockArgument): End of
    // the reverse use-def chain reached.
    auto opResult = nextVal.dyn_cast<OpResult>();
    if (!opResult)
      continue;

    // Follow reverse SSA use-def chain. Only follow operands that are (or are
    // assumed to be) bufferized in-place; out-of-place operands break the
    // aliasing chain because a copy is inserted.
    SmallVector<OpOperand *> aliasingOpOperands =
        state.getAliasingOpOperand(opResult);
    for (OpOperand *opOperand : aliasingOpOperands)
      if (aliasInfo.isInPlace(*opOperand) || currentOpOperand == opOperand)
        worklist.push_back(opOperand->get());
  }
  return false;
}

/// Return true if bufferizing `operand` inplace would create a write to a
/// non-writable buffer.
static bool wouldCreateWriteToNonWritableBuffer(
    OpOperand &operand, const BufferizationAliasInfo &aliasInfo,
    OneShotAnalysisState &state, bool checkConsistencyOnly = false) {
  // Collect writes of all aliases of OpOperand and OpResult.
  DenseSet<OpOperand *> usesWrite;
  getAliasingInplaceWrites(usesWrite, operand.get(), aliasInfo, state);
  for (OpResult result : state.getAliasingOpResult(operand)) {
    getAliasingInplaceWrites(usesWrite, result, aliasInfo, state);
  }
  if (!checkConsistencyOnly && state.bufferizesToMemoryWrite(operand))
    usesWrite.insert(&operand);

  // Assuming that `operand` bufferizes in-place: For each write (to each
  // alias), check if there is a non-writable tensor in the reverse SSA use-def
  // chain.
  for (OpOperand *uWrite : usesWrite)
    if (hasPrecedingAliasingNonWritableTensor(uWrite->get(), &operand,
                                              aliasInfo, state))
      return true;

  return false;
}

//===----------------------------------------------------------------------===//
// Bufferization analyses.
//===----------------------------------------------------------------------===//

/// Determine if `operand` can be bufferized in-place.
static LogicalResult bufferizableInPlaceAnalysisImpl(
    OpOperand &operand, BufferizationAliasInfo &aliasInfo,
    OneShotAnalysisState &state, const DominanceInfo &domInfo) {
  // An operand bufferizes in-place only if doing so introduces neither a write
  // to a non-writable buffer nor a read-after-write conflict.
  bool foundInterference =
      wouldCreateWriteToNonWritableBuffer(operand, aliasInfo, state) ||
      wouldCreateReadAfterWriteInterference(operand, domInfo, state, aliasInfo);

  if (foundInterference)
    aliasInfo.bufferizeOutOfPlace(operand);
  else
    aliasInfo.bufferizeInPlace(operand, state);

  return success();
}

/// Analyze the `ops` to determine which OpOperands are inplaceable. Walk ops in
/// reverse and bufferize ops greedily. This is a good starter heuristic.
730 /// 731 /// Even if an op does not read or write, it may still create an alias when 732 /// bufferized in-place. An example of such ops is tensor.extract_slice. 733 /// 734 /// Rationale for bufferizing `%1 = tensor.extract_slice %0[...]` inplace: 735 /// 736 /// When bufferized out of place, an ExtractSliceOp lowers to alloc + copy. This 737 /// cannot change the flow of information for either the source or the 738 /// result buffers. 739 /// 740 /// When bufferized inplace, an ExtractSliceOp does not by itself create any 741 /// read or write from memory. Instead, it has the effect of merging the alias 742 /// sets of the source and the result buffers. 743 /// 744 /// An analysis is required to ensure inplace bufferization would not result in 745 /// RaW dependence violations. 746 static LogicalResult inPlaceAnalysis(SmallVector<Operation *> &ops, 747 BufferizationAliasInfo &aliasInfo, 748 OneShotAnalysisState &state, 749 const DominanceInfo &domInfo, 750 unsigned analysisFuzzerSeed = 0) { 751 if (analysisFuzzerSeed) { 752 // This is a fuzzer. For testing purposes only. Randomize the order in which 753 // operations are analyzed. The bufferization quality is likely worse, but 754 // we want to make sure that no assertions are triggered anywhere. 755 std::mt19937 g(analysisFuzzerSeed); 756 llvm::shuffle(ops.begin(), ops.end(), g); 757 } 758 759 // Walk ops in reverse for better interference analysis. 760 for (Operation *op : reverse(ops)) 761 for (OpOperand &opOperand : op->getOpOperands()) 762 if (opOperand.get().getType().isa<TensorType>()) 763 if (auto bufferizableOp = state.getOptions().dynCastBufferizableOp(op)) 764 if (failed(bufferizableInPlaceAnalysisImpl(opOperand, aliasInfo, 765 state, domInfo))) 766 return failure(); 767 768 return success(); 769 } 770 771 /// Return true if the given op has a tensor result or a tensor operand. 
static bool hasTensorSemantics(Operation *op) {
  bool hasTensorResult = any_of(op->getResultTypes(), isaTensor);
  bool hasTensorOperand = any_of(op->getOperandTypes(), isaTensor);
  return hasTensorResult || hasTensorOperand;
}

/// Analyze all ops that are contained in `op`.
static LogicalResult inPlaceAnalysis(Operation *op,
                                     BufferizationAliasInfo &aliasInfo,
                                     OneShotAnalysisState &state,
                                     const DominanceInfo &domInfo,
                                     unsigned analysisFuzzerSeed = 0) {
  // Collect ops so we can build our own reverse traversal.
  SmallVector<Operation *> ops;
  op->walk([&](Operation *op) {
    // No tensors => no buffers.
    if (!hasTensorSemantics(op))
      return;
    ops.push_back(op);
  });

  // Delegate to the list-based overload, which walks `ops` in reverse.
  return inPlaceAnalysis(ops, aliasInfo, state, domInfo, analysisFuzzerSeed);
}

/// Analyze equivalence of tied OpResult/OpOperand pairs of the given ops.
///
/// For each OpResult that is tied to an in-place OpOperand with
/// BufferRelation::Equivalent, merge the equivalence classes of the result
/// and the operand value in `aliasInfo`.
static void equivalenceAnalysis(SmallVector<Operation *> &ops,
                                BufferizationAliasInfo &aliasInfo,
                                AnalysisState &state) {
  for (Operation *op : ops)
    if (auto bufferizableOp = state.getOptions().dynCastBufferizableOp(op))
      for (OpResult opResult : op->getOpResults())
        if (opResult.getType().isa<TensorType>())
          for (OpOperand *opOperand :
               bufferizableOp.getAliasingOpOperand(opResult, state))
            // Only operands that were decided to bufferize in-place can be
            // equivalent to the result.
            if (state.isInPlace(*opOperand))
              if (bufferizableOp.bufferRelation(opResult, state) ==
                  BufferRelation::Equivalent)
                aliasInfo.unionEquivalenceClasses(opResult, opOperand->get());
}

/// Analyze equivalence of tied OpResult/OpOperand pairs of all ops contained
/// in `op`.
static void equivalenceAnalysis(Operation *op,
                                BufferizationAliasInfo &aliasInfo,
                                AnalysisState &state) {
  // Traverse ops in PostOrder: Nested ops first, then enclosing ops.
  SmallVector<Operation *> ops;
  op->walk<WalkOrder::PostOrder>([&](Operation *op) {
    // No tensors => no buffers. (Only tensor results matter here: equivalence
    // is a property of OpResults.)
    if (none_of(op->getResultTypes(), isaTensor))
      return;
    ops.push_back(op);
  });

  equivalenceAnalysis(ops, aliasInfo, state);
}

/// Assert that the current bufferization decisions are consistent.
///
/// Runs the RaW interference check in consistency-only mode on every tensor
/// OpOperand of every bufferizable op. Emits an error on the first op whose
/// IR already contains a RaW conflict before any decisions were made.
static LogicalResult
checkAliasInfoConsistency(Operation *op, const DominanceInfo &domInfo,
                          AnalysisState &state,
                          const BufferizationAliasInfo &aliasInfo) {
  const BufferizationOptions &options = state.getOptions();
  // Remember the offending op so we can emit the error after the walk.
  Operation *inconsistentOp = nullptr;
  WalkResult walkResult = op->walk([&](Operation *op) {
    if (auto bufferizableOp = options.dynCastBufferizableOp(op))
      for (OpOperand &opOperand : op->getOpOperands())
        if (opOperand.get().getType().isa<TensorType>()) {
          if (wouldCreateReadAfterWriteInterference(
                  opOperand, domInfo, state, aliasInfo,
                  /*checkConsistencyOnly=*/true)) {
            // This error can happen if certain "mustBufferizeInPlace" interface
            // methods are implemented incorrectly, such that the IR already has
            // a RaW conflict before making any bufferization decisions.
            inconsistentOp = op;
            return WalkResult::interrupt();
          }
        }
    return WalkResult::advance();
  });

  if (walkResult.wasInterrupted())
    return inconsistentOp->emitError("input IR has RaW conflict");
  return success();
}

/// Annotate the IR with the result of the analysis. For testing/debugging only.
static void
annotateOpsWithBufferizationMarkers(Operation *op,
                                    const BufferizationAliasInfo &aliasInfo,
                                    AnalysisState &state) {
  op->walk([&](Operation *op) {
    if (auto bufferizableOp = state.getOptions().dynCastBufferizableOp(op))
      for (OpOperand &opOperand : op->getOpOperands())
        // One in-place/out-of-place marker per tensor operand.
        if (opOperand.get().getType().isa<TensorType>())
          setInPlaceOpOperand(opOperand, aliasInfo.isInPlace(opOperand));
  });
}

/// Assert that IR is in destination-passing style. I.e., every value that is
/// returned or yielded from a block is:
/// * aliasing a bbArg of that block or a parent block, or
/// * aliasing an OpResult of a op in a parent block.
///
/// Example:
/// ```
/// %0 = "some_op" : tensor<?xf32>
/// %1 = scf.if %c -> (tensor<?xf32>) {
///   scf.yield %0 : tensor<?xf32>
/// } else {
///   %t = linalg.alloc_tensor : tensor<?xf32>
///   scf.yield %t : tensor<?xf32>
/// }
/// ```
/// In the above example, the first scf.yield op satifies destination-passing
/// style because the yielded value %0 is defined in the parent block. The
/// second scf.yield op does not satisfy destination-passing style because the
/// yielded value %t is defined in the same block as the scf.yield op.
// TODO: The current implementation checks for equivalent values instead of
// aliasing values, which is stricter than needed. We can currently not check
// for aliasing values because the analysis is a maybe-alias analysis and we
// need a must-alias analysis here.
static LogicalResult
assertDestinationPassingStyle(Operation *op, AnalysisState &state,
                              BufferizationAliasInfo &aliasInfo,
                              SmallVector<Operation *> &newOps) {
  // `status` records the most recent violation; the walk continues so all
  // violations get an error emitted.
  LogicalResult status = success();
  DominanceInfo domInfo(op);
  op->walk([&](Operation *returnOp) {
    if (!isRegionReturnLike(returnOp) ||
        !state.getOptions().isOpAllowed(returnOp))
      return WalkResult::advance();

    for (OpOperand &returnValOperand : returnOp->getOpOperands()) {
      Value returnVal = returnValOperand.get();
      // Skip non-tensor values.
      if (!returnVal.getType().isa<TensorType>())
        continue;

      // Search the equivalence class of `returnVal` for a value that is
      // defined outside of the returning region.
      bool foundEquivValue = false;
      aliasInfo.applyOnEquivalenceClass(returnVal, [&](Value equivVal) {
        if (auto bbArg = equivVal.dyn_cast<BlockArgument>()) {
          Operation *definingOp = bbArg.getOwner()->getParentOp();
          // A bbArg of a proper ancestor op qualifies.
          if (definingOp->isProperAncestor(returnOp))
            foundEquivValue = true;
          return;
        }

        Operation *definingOp = equivVal.getDefiningOp();
        if (definingOp->getBlock()->findAncestorOpInBlock(
                *returnOp->getParentOp()))
          // Skip ops that happen after `returnOp` and parent ops.
          if (happensBefore(definingOp, returnOp, domInfo))
            foundEquivValue = true;
      });

      if (!foundEquivValue)
        status =
            returnOp->emitError()
            << "operand #" << returnValOperand.getOperandNumber()
            << " of ReturnLike op does not satisfy destination passing style";
    }

    return WalkResult::advance();
  });

  return status;
}

LogicalResult bufferization::analyzeOp(Operation *op,
                                       OneShotAnalysisState &state) {
  DominanceInfo domInfo(op);
  BufferizationAliasInfo &aliasInfo = state.getAliasInfo();
  const auto &options =
      static_cast<const OneShotBufferizationOptions &>(state.getOptions());

  // Catch incorrect API usage.
  assert((state.hasDialectState(func::FuncDialect::getDialectNamespace()) ||
          !options.bufferizeFunctionBoundaries) &&
         "must use ModuleBufferize to bufferize function boundaries");

  // Step 1: Make sure the input IR has no pre-existing RaW conflicts.
  if (failed(checkAliasInfoConsistency(op, domInfo, state, aliasInfo)))
    return failure();

  // Step 2: Run the in-place analysis. If the analysis fails, just return.
  if (failed(inPlaceAnalysis(op, aliasInfo, state, domInfo,
                             options.analysisFuzzerSeed)))
    return failure();
  // Step 3: Compute equivalence classes of tied OpResult/OpOperand pairs.
  equivalenceAnalysis(op, aliasInfo, state);

  bool failedAnalysis = false;
  if (!options.allowReturnAllocs) {
    SmallVector<Operation *> newOps;
    failedAnalysis |=
        failed(assertDestinationPassingStyle(op, state, aliasInfo, newOps));
  }

  // Gather some extra analysis data.
  state.gatherYieldedTensors(op);
  state.gatherUndefinedTensorUses(op);

  // Analysis verification: After setting up alias/equivalence sets, each op
  // can check for expected invariants/limitations and fail the analysis if
  // necessary.
  op->walk([&](Operation *op) {
    if (BufferizableOpInterface bufferizableOp =
            options.dynCastBufferizableOp(op))
      failedAnalysis |= failed(bufferizableOp.verifyAnalysis(state));
  });

  // Annotate operations if we only want to report the analysis.
  if (options.testAnalysisOnly)
    annotateOpsWithBufferizationMarkers(op, aliasInfo, state);

  return success(!failedAnalysis);
}

LogicalResult
bufferization::runOneShotBufferize(Operation *op,
                                   const OneShotBufferizationOptions &options) {
  OneShotAnalysisState state(op, options);
  // Insert copies for tensors that cannot bufferize in-place (this runs the
  // analysis internally).
  if (failed(insertTensorCopies(op, options)))
    return failure();
  // In test-analysis-only mode, stop before the actual bufferization.
  if (options.testAnalysisOnly)
    return success();
  return bufferizeOp(op, options, /*copyBeforeWrite=*/false);
}