//===- OneShotAnalysis.cpp - One-Shot (Single Pass) Analysis -------------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
//
// One-Shot Analysis analyzes function bodies. Function boundaries (FuncOp
// bbArgs, CallOps, ReturnOps) are treated as "unknown" ops.
// ModuleBufferization.cpp is an extension of One-Shot Analysis for simple
// call graphs.
//
// One-Shot Bufferize consists of two phases.
//
// 1. Analyze ops to decide which OpResults can bufferize inplace, i.e., without
//    inserting buffer copies. The analysis queries op bufferization semantics
//    via `BufferizableOpInterface`.
// 2. Bufferize ops by calling `BufferizableOpInterface::bufferize`. This
//    function does not generate buffer copies for OpResults that were decided
//    to bufferize inplace during the analysis phase.
//
// This file contains only the analysis. The actual bufferization is implemented
// via `bufferizeOp` (Bufferize.h). For convenience, this file also contains a
// helper function `runOneShotBufferize` that analyzes an op (and its nested
// ops) and then bufferizes it.
//
// Inplace bufferization decisions are passed from the analysis to the
// bufferization phase via `AnalysisState` and `BufferizationAliasInfo`.
// They can be printed for debugging purposes with `testAnalysisOnly`.
//
// Ops that do not implement `BufferizableOpInterface` can be analyzed but are
// treated conservatively. E.g., the analysis has to assume that their tensor
// OpOperands bufferize to memory writes. While such ops can be analyzed, they
// are not bufferized and remain in the IR. to_tensor and to_memref ops are
// inserted at the bufferization boundary.
//
// This analysis caters to high-performance codegen where buffer reuse is deemed
// critical: the analysis should fail if the bufferized form of the function
// needs to return a buffer, unless `allowReturnAllocs` is enabled.

#include "mlir/Dialect/Bufferization/Transforms/OneShotAnalysis.h"

#include <random>

#include "mlir/Dialect/Bufferization/IR/BufferizableOpInterface.h"
#include "mlir/Dialect/Bufferization/IR/Bufferization.h"
#include "mlir/Dialect/Bufferization/Transforms/Bufferize.h"
#include "mlir/Dialect/MemRef/IR/MemRef.h"
#include "mlir/IR/AsmState.h"
#include "mlir/IR/Dominance.h"
#include "mlir/IR/Operation.h"
#include "mlir/IR/TypeUtilities.h"
#include "mlir/Interfaces/ControlFlowInterfaces.h"
#include "llvm/ADT/DenseSet.h"
#include "llvm/ADT/SetVector.h"

using namespace mlir;
using namespace mlir::bufferization;

static bool isaTensor(Type t) { return t.isa<TensorType>(); }

//===----------------------------------------------------------------------===//
// Bufferization-specific attribute manipulation.
// These are for testing and debugging only. Bufferization information is
// stored in BufferizationAliasInfo. When run with `testAnalysisOnly`, the IR
// is annotated with the results of the analysis (copied from
// BufferizationAliasInfo), so that they can be checked in tests.
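//
// For illustration (hypothetical IR, not taken from a specific test): with
// `testAnalysisOnly`, an analyzed op is annotated with one entry per operand,
// where "none" marks non-tensor operands:
//
//   %0 = tensor.insert %f into %t[%idx]
//       {__inplace_operands_attr__ = ["none", "true", "none"]}
//       : tensor<?xf32>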
//===----------------------------------------------------------------------===//

/// Attribute marker to specify op operands that can be bufferized inPlace.
constexpr StringLiteral kInPlaceResultsAttrName = "__inplace_operands_attr__";

/// Mark whether OpOperand will be bufferized inplace.
static void setInPlaceOpOperand(OpOperand &opOperand, bool inPlace) {
  Operation *op = opOperand.getOwner();
  auto attr =
      op->getAttr(kInPlaceResultsAttrName).dyn_cast_or_null<ArrayAttr>();
  SmallVector<StringRef> inPlaceVector;
  if (attr) {
    inPlaceVector = SmallVector<StringRef>(
        llvm::to_vector<4>(attr.getAsValueRange<StringAttr>()));
  } else {
    inPlaceVector = SmallVector<StringRef>(op->getNumOperands(), "none");
    for (OpOperand &opOperand : op->getOpOperands())
      if (opOperand.get().getType().isa<TensorType>())
        inPlaceVector[opOperand.getOperandNumber()] = "false";
  }

  inPlaceVector[opOperand.getOperandNumber()] = inPlace ? "true" : "false";
  op->setAttr(kInPlaceResultsAttrName,
              OpBuilder(op).getStrArrayAttr(inPlaceVector));
}

//===----------------------------------------------------------------------===//
// BufferizationAliasInfo
//===----------------------------------------------------------------------===//

BufferizationAliasInfo::BufferizationAliasInfo(Operation *rootOp) {
  rootOp->walk([&](Operation *op) {
    for (Value v : op->getResults())
      if (v.getType().isa<TensorType>())
        createAliasInfoEntry(v);
    for (Region &r : op->getRegions())
      for (Block &b : r.getBlocks())
        for (auto bbArg : b.getArguments())
          if (bbArg.getType().isa<TensorType>())
            createAliasInfoEntry(bbArg);
  });
}

/// Add a new entry for `v` in the `aliasInfo` and `equivalentInfo`. In the
/// beginning the alias and equivalence sets only contain `v` itself.
void BufferizationAliasInfo::createAliasInfoEntry(Value v) {
  aliasInfo.insert(v);
  equivalentInfo.insert(v);
}

/// Insert an info entry for `newValue` and merge its alias set with that of
/// `alias`.
void BufferizationAliasInfo::insertNewBufferAlias(Value newValue, Value alias) {
  createAliasInfoEntry(newValue);
  aliasInfo.unionSets(newValue, alias);
}

/// Insert an info entry for `newValue` and merge its alias set with that of
/// `alias`. Additionally, merge their equivalence classes.
void BufferizationAliasInfo::insertNewBufferEquivalence(Value newValue,
                                                        Value alias) {
  insertNewBufferAlias(newValue, alias);
  equivalentInfo.unionSets(newValue, alias);
}

/// Return `true` if a value was marked as in-place bufferized.
bool BufferizationAliasInfo::isInPlace(OpOperand &operand) const {
  return inplaceBufferized.contains(&operand);
}

/// Set the inPlace bufferization spec to true.
void BufferizationAliasInfo::bufferizeInPlace(OpOperand &operand,
                                              AnalysisState &state) {
  markInPlace(operand);
  for (OpResult result : state.getAliasingOpResult(operand))
    aliasInfo.unionSets(result, operand.get());
}

/// Set the inPlace bufferization spec to false.
void BufferizationAliasInfo::bufferizeOutOfPlace(OpOperand &operand) {
  assert(!inplaceBufferized.contains(&operand) &&
         "OpOperand was already decided to bufferize inplace");
}

/// Apply `fun` to all the members of the equivalence class of `v`.
void BufferizationAliasInfo::applyOnEquivalenceClass(
    Value v, function_ref<void(Value)> fun) const {
  auto leaderIt = equivalentInfo.findLeader(v);
  for (auto mit = leaderIt, meit = equivalentInfo.member_end(); mit != meit;
       ++mit) {
    fun(*mit);
  }
}

/// Apply `fun` to all aliases of `v`.
void BufferizationAliasInfo::applyOnAliases(
    Value v, function_ref<void(Value)> fun) const {
  auto leaderIt = aliasInfo.findLeader(v);
  for (auto mit = leaderIt, meit = aliasInfo.member_end(); mit != meit; ++mit) {
    fun(*mit);
  }
}

BufferizationAliasInfo::EquivalenceClassRangeType
BufferizationAliasInfo::getAliases(Value v) const {
  auto it = aliasInfo.findValue(aliasInfo.getLeaderValue(v));
  return BufferizationAliasInfo::EquivalenceClassRangeType(
      aliasInfo.member_begin(it), aliasInfo.member_end());
}

//===----------------------------------------------------------------------===//
// OneShotAnalysisState
//===----------------------------------------------------------------------===//

OneShotAnalysisState::OneShotAnalysisState(
    Operation *op, const OneShotBufferizationOptions &options)
    : AnalysisState(options), aliasInfo(op) {
  // Set up alias sets for OpResults that must bufferize in-place. This should
  // be done before making any other bufferization decisions.
  op->walk([&](BufferizableOpInterface bufferizableOp) {
    if (!options.isOpAllowed(bufferizableOp))
      return WalkResult::skip();
    for (OpOperand &opOperand : bufferizableOp->getOpOperands()) {
      if (opOperand.get().getType().isa<TensorType>())
        if (bufferizableOp.mustBufferizeInPlace(opOperand, *this)) {
          for (OpResult opResult :
               bufferizableOp.getAliasingOpResult(opOperand, *this))
            aliasInfo.unionAliasSets(opOperand.get(), opResult);
          aliasInfo.markInPlace(opOperand);
        }
    }
    return WalkResult::advance();
  });
}

bool OneShotAnalysisState::isInPlace(OpOperand &opOperand) const {
  return aliasInfo.isInPlace(opOperand);
}

bool OneShotAnalysisState::areEquivalentBufferizedValues(Value v1,
                                                         Value v2) const {
  return aliasInfo.areEquivalentBufferizedValues(v1, v2);
}

// Gather yielded tensors in `yieldedTensors` by querying all aliases. This is
// to ensure that such information is available during bufferization time.
// Alias information can no longer be queried through BufferizationAliasInfo
// once we have started modifying the IR.
void OneShotAnalysisState::gatherYieldedTensors(Operation *op) {
  op->walk([&](Operation *returnOp) {
    if (!isRegionReturnLike(returnOp) || !getOptions().isOpAllowed(returnOp))
      return WalkResult::advance();

    for (OpOperand &returnValOperand : returnOp->getOpOperands()) {
      Value returnVal = returnValOperand.get();
      // Skip non-tensor values.
      if (!returnVal.getType().isa<TensorType>())
        continue;

      // Add all aliases of the returned value. But only the ones that are in
      // the same block.
      aliasInfo.applyOnAliases(returnVal, [&](Value v) {
        if (auto bbArg = v.dyn_cast<BlockArgument>()) {
          if (bbArg.getOwner()->getParentOp() == returnOp->getParentOp())
            yieldedTensors.insert(bbArg);
          return;
        }
        Operation *definingOp = v.getDefiningOp();
        if (definingOp->getParentOp() == returnOp->getParentOp())
          yieldedTensors.insert(v);
      });
    }

    return WalkResult::advance();
  });
}

bool OneShotAnalysisState::isTensorYielded(Value tensor) const {
  return yieldedTensors.contains(tensor);
}

//===----------------------------------------------------------------------===//
// Bufferization-specific alias analysis.
//===----------------------------------------------------------------------===//

/// Return true if `opOperand` bufferizes to a memory write and has been
/// decided to bufferize in-place.
static bool isInplaceMemoryWrite(OpOperand &opOperand,
                                 const BufferizationAliasInfo &aliasInfo,
                                 AnalysisState &state) {
  // OpOperands that do not bufferize to a memory write do not write in-place.
  if (!state.bufferizesToMemoryWrite(opOperand))
    return false;
  // Check current bufferization decisions.
  return aliasInfo.isInPlace(opOperand);
}

/// Return true if, under current bufferization decisions, `value` aliases a
/// buffer that is not writable.
static bool aliasesNonWritableBuffer(Value value,
                                     const BufferizationAliasInfo &aliasInfo,
                                     AnalysisState &state) {
  bool foundNonWritableBuffer = false;
  aliasInfo.applyOnAliases(value, [&](Value v) {
    // Query BufferizableOpInterface to see if the value is writable.
    // TODO: Out-of-place bufferized value could be considered writable.
    if (auto bufferizableOp = state.getOptions().dynCastBufferizableOp(v))
      if (bufferizableOp && bufferizableOp.isWritable(v, state))
        return;

    // Query BufferizableOpInterface to see if the BlockArgument is writable.
    if (auto bbArg = v.dyn_cast<BlockArgument>())
      if (auto bufferizableOp = state.getOptions().dynCastBufferizableOp(
              bbArg.getOwner()->getParentOp()))
        if (bufferizableOp.isWritable(bbArg, state))
          return;

    foundNonWritableBuffer = true;
  });

  return foundNonWritableBuffer;
}

/// Return true if, under current bufferization decisions, `value` aliases a
/// buffer that is written to in-place by some use.
static bool aliasesInPlaceWrite(Value value,
                                const BufferizationAliasInfo &aliasInfo,
                                AnalysisState &state) {
  bool foundInplaceWrite = false;
  aliasInfo.applyOnAliases(value, [&](Value v) {
    for (auto &use : v.getUses()) {
      if (isInplaceMemoryWrite(use, aliasInfo, state)) {
        foundInplaceWrite = true;
        return;
      }
    }
  });
  return foundInplaceWrite;
}

/// Return true if `a` happens before `b`, i.e., `a` or one of its ancestors
/// properly dominates `b` and `b` is not inside `a`.
static bool happensBefore(Operation *a, Operation *b,
                          const DominanceInfo &domInfo) {
  do {
    // TODO: Instead of isProperAncestor + properlyDominates, we should use
    // properlyDominatesImpl(a, b, /*enclosingOpOk=*/false)
    if (a->isProperAncestor(b))
      return false;
    if (domInfo.properlyDominates(a, b))
      return true;
  } while ((a = a->getParentOp()));
  return false;
}

/// For each given value, find the closest enclosing repetitive region. If this
/// is the same region for each value, return it. Otherwise return None.
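/// E.g. (illustrative): for a value defined inside the body of an scf.for, the
/// closest enclosing repetitive region is the loop body region, because that
/// region may be executed more than once per execution of its parent op.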
/// Note: If there is no enclosing repetitive region, return nullptr.
static Optional<Region *>
getCommonEnclosingRepetitiveRegion(ArrayRef<Value> values) {
  if (values.empty())
    return None;
  Region *r = getEnclosingRepetitiveRegion(values.front());
  for (Value value : values.drop_front())
    if (getEnclosingRepetitiveRegion(value) != r)
      return None;
  return r;
}

/// Annotate IR with details about the detected RaW conflict.
static void annotateConflict(OpOperand *uRead, OpOperand *uConflictingWrite,
                             Value lastWrite) {
  static uint64_t counter = 0;
  Operation *readingOp = uRead->getOwner();
  Operation *conflictingWritingOp = uConflictingWrite->getOwner();

  OpBuilder b(conflictingWritingOp->getContext());
  std::string id = "C_" + std::to_string(counter++);

  std::string conflictingWriteAttr =
      id +
      "[CONFL-WRITE: " + std::to_string(uConflictingWrite->getOperandNumber()) +
      "]";
  conflictingWritingOp->setAttr(conflictingWriteAttr, b.getUnitAttr());

  std::string readAttr =
      id + "[READ: " + std::to_string(uRead->getOperandNumber()) + "]";
  readingOp->setAttr(readAttr, b.getUnitAttr());

  if (auto opResult = lastWrite.dyn_cast<OpResult>()) {
    std::string lastWriteAttr = id + "[LAST-WRITE: result " +
                                std::to_string(opResult.getResultNumber()) +
                                "]";
    opResult.getDefiningOp()->setAttr(lastWriteAttr, b.getUnitAttr());
  } else {
    auto bbArg = lastWrite.cast<BlockArgument>();
    std::string lastWriteAttr =
        id + "[LAST-WRITE: bbArg " + std::to_string(bbArg.getArgNumber()) + "]";
    bbArg.getOwner()->getParentOp()->setAttr(lastWriteAttr, b.getUnitAttr());
  }
}

/// Given sets of uses and writes, return true if there is a RaW conflict under
/// the assumption that all given reads/writes alias the same buffer and that
/// all given writes bufferize inplace.
///
/// A conflict is: According to SSA use-def chains, a read R is supposed to read
/// the result of a write W1. But because of bufferization decisions, R actually
/// reads another write W2.
static bool hasReadAfterWriteInterference(
    const DenseSet<OpOperand *> &usesRead,
    const DenseSet<OpOperand *> &usesWrite, const DominanceInfo &domInfo,
    AnalysisState &state, const BufferizationAliasInfo &aliasInfo) {
  const BufferizationOptions &options = state.getOptions();

  // Gather all written aliases.
  SmallVector<Value> writtenAliases;
  for (OpOperand *uWrite : usesWrite)
    writtenAliases.push_back(uWrite->get());
  // Find the inner-most enclosing repetitive region of each alias. If this is
  // the same region for every alias, save it in `repetitiveRegionOfWrites`.
  Optional<Region *> repetitiveRegionOfWrites =
      getCommonEnclosingRepetitiveRegion(writtenAliases);

  for (OpOperand *uRead : usesRead) {
    Operation *readingOp = uRead->getOwner();

    // Find most recent writes of uRead by following the SSA use-def chain.
    // E.g.:
    //
    // %0 = "writing_op"(%t) : tensor<?xf32> -> tensor<?xf32>
    // %1 = "aliasing_op"(%0) : tensor<?xf32> -> tensor<?xf32>
    // %2 = "reading_op"(%1) : tensor<?xf32> -> not_a_tensor_type
    //
    // In the above example, if uRead is the OpOperand of reading_op, lastWrite
    // is %0. Note that operations that create an alias but do not write (such
    // as ExtractSliceOp) are skipped.
    SetVector<Value> lastWrites = state.findLastPrecedingWrite(uRead->get());

    // Look for conflicting memory writes. Potential conflicts are writes to an
    // alias that have been decided to bufferize inplace.
    for (OpOperand *uConflictingWrite : usesWrite) {
      // Throughout this loop, check for multiple requirements that have to be
      // met for uConflictingWrite to be an actual conflict.
      Operation *conflictingWritingOp = uConflictingWrite->getOwner();

      // Check if conflictingWritingOp is in the same repetitive region as all
      // written aliases. If this is not the case, there is no meaningful
      // `happensBefore` relationship because conflictingWritingOp may be
      // executed multiple times. E.g.:
      //
      // %0 = ... : tensor<?xf32>
      // scf.for ... {
      //   "reading_op"(%0) : tensor<?xf32>
      //   %1 = "writing_op"(%0) : tensor<?xf32> -> tensor<?xf32>
      //   ...
      // }
      //
      // In the above example, reading_op happens before writing_op according
      // to op dominance. However, both ops may happen multiple times; in
      // particular, the second execution of reading_op happens after the first
      // execution of writing_op. This is problematic if the tensor they
      // operate on (%0) is defined outside of the loop.
      //
      // Counter example:
      //
      // scf.for ... {
      //   %0 = ... : tensor<?xf32>
      //   "reading_op"(%0) : tensor<?xf32>
      //   %1 = "writing_op"(%0) : tensor<?xf32> -> tensor<?xf32>
      //   ...
      // }
      //
      // In this example, %0 is in the same repetitive region as
      // conflictingWritingOp, so op dominance can be used to compute the
      // `happensBefore` relationship.
      //
      // Note: iter_args of loops are not aliases of their respective block
      // arguments, so op dominance can be used when analyzing ops that operate
      // on them.
      bool canUseOpDominance =
          repetitiveRegionOfWrites ==
          getEnclosingRepetitiveRegion(conflictingWritingOp);

      // No conflict if the readingOp dominates conflictingWritingOp, i.e., the
      // write is not visible when reading.
      //
      // Note: If ops are executed multiple times (e.g., because they are inside
      // a loop), there may be no meaningful `happensBefore` relationship.
      if (canUseOpDominance &&
          happensBefore(readingOp, conflictingWritingOp, domInfo))
        continue;

      // No conflict if the reading use equals the use of the conflicting write.
      // A use cannot conflict with itself.
      //
      // Note: Just being the same op is not enough. It has to be the same use.
      // Note: If the op is executed multiple times (e.g., because it is inside
      // a loop), it may be conflicting with itself.
      if (canUseOpDominance && uConflictingWrite == uRead)
        continue;

      // No conflict if the op interface says so.
      if (auto bufferizableOp = options.dynCastBufferizableOp(readingOp))
        if (bufferizableOp.isNotConflicting(uRead, uConflictingWrite, state))
          continue;

      if (conflictingWritingOp != readingOp)
        if (auto bufferizableOp =
                options.dynCastBufferizableOp(conflictingWritingOp))
          if (bufferizableOp.isNotConflicting(uRead, uConflictingWrite, state))
            continue;

      // Ops are not conflicting if they are in mutually exclusive regions.
      //
      // Note: If ops are executed multiple times (e.g., because they are inside
      // a loop), mutually exclusive regions may be executed multiple
      // times.
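      //
      // E.g. (illustrative): ops nested in the "then" and "else" regions of
      // the same scf.if are in mutually exclusive regions.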
      if (canUseOpDominance &&
          insideMutuallyExclusiveRegions(readingOp, conflictingWritingOp))
        continue;

      // Check all possible last writes.
      for (Value lastWrite : lastWrites) {
        // No conflict if the conflicting write happens before the last
        // write.
        if (Operation *writingOp = lastWrite.getDefiningOp()) {
          if (happensBefore(conflictingWritingOp, writingOp, domInfo))
            // conflictingWritingOp happens before writingOp. No conflict.
            continue;
          // No conflict if conflictingWritingOp is contained in writingOp.
          if (writingOp->isProperAncestor(conflictingWritingOp))
            continue;
        } else {
          auto bbArg = lastWrite.cast<BlockArgument>();
          Block *block = bbArg.getOwner();
          if (!block->findAncestorOpInBlock(*conflictingWritingOp))
            // conflictingWritingOp happens outside of the block. No
            // conflict.
            continue;
        }

        // No conflict if the conflicting write and the last write are the same
        // use.
        SmallVector<OpResult> aliasingOpResult =
            state.getAliasingOpResult(*uConflictingWrite);
        if (aliasingOpResult.size() == 1 && aliasingOpResult[0] == lastWrite)
          continue;

        // All requirements are met. Conflict found!

        if (options.printConflicts)
          annotateConflict(uRead, uConflictingWrite, lastWrite);

        return true;
      }
    }
  }

  return false;
}

/// Return true if bufferizing `operand` inplace would create a conflict. A read
/// R and a write W of the same alias set constitute a conflict if inplace
/// bufferization of W changes the value read by R to a value different from the
/// one that would be expected by tracing back R's origin through SSA use-def
/// chains. A conflict can only be introduced by a new alias and/or an inplace
/// bufferization decision.
///
/// Example:
/// %0 = tensor.extract_slice %t[...][...][1, 1] {inplace?}
/// %1 = vector.transfer_write %v1, %t {inplace} : vector<5xf32>, tensor<?xf32>
/// %e = tensor.extract_slice %1
/// %2 = vector.transfer_write %v2, %0 {inplace} : vector<6xf32>, tensor<?xf32>
/// %3 = vector.transfer_read %e, %cst : tensor<?xf32>, vector<7xf32>
///
/// In the above example, the two TransferWriteOps have already been decided to
/// bufferize inplace. Bufferizing the ExtractSliceOp inplace would create a
/// conflict because:
/// * According to SSA use-def chains, we expect to read the result of %1.
/// * However, adding an alias {%0, %t} would mean that the second
///   TransferWriteOp overwrites the first one. Therefore, the TransferReadOp
///   would no longer be reading the result of %1.
///
/// If `checkConsistencyOnly` is true, this function checks if there is a
/// read-after-write conflict without bufferizing `operand` inplace. This would
/// indicate a problem with the current inplace bufferization decisions.
///
/// Note: If `checkConsistencyOnly`, this function may be called with a null
/// OpResult. In that case, only the consistency of bufferization decisions
/// involving aliases of the given OpOperand are checked.
static bool wouldCreateReadAfterWriteInterference(
    OpOperand &operand, const DominanceInfo &domInfo, AnalysisState &state,
    const BufferizationAliasInfo &aliasInfo,
    bool checkConsistencyOnly = false) {
  // Helper function to iterate on aliases of `root` and capture the reads.
  auto getAliasingReads = [&](DenseSet<OpOperand *> &res, Value root) {
    aliasInfo.applyOnAliases(root, [&](Value alias) {
      for (auto &use : alias.getUses())
        // Read of a value that aliases root.
        if (state.bufferizesToMemoryRead(use))
          res.insert(&use);
    });
  };

  // Helper function to iterate on aliases of `root` and capture the writes.
  auto getAliasingInplaceWrites = [&](DenseSet<OpOperand *> &res, Value root) {
    aliasInfo.applyOnAliases(root, [&](Value alias) {
      for (auto &use : alias.getUses())
        // Inplace write to a value that aliases root.
        if (isInplaceMemoryWrite(use, aliasInfo, state))
          res.insert(&use);
    });
  };

  // Collect reads and writes of all aliases of OpOperand and OpResult.
  DenseSet<OpOperand *> usesRead, usesWrite;
  getAliasingReads(usesRead, operand.get());
  getAliasingInplaceWrites(usesWrite, operand.get());
  for (OpResult result : state.getAliasingOpResult(operand)) {
    getAliasingReads(usesRead, result);
    getAliasingInplaceWrites(usesWrite, result);
  }
  if (!checkConsistencyOnly && state.bufferizesToMemoryWrite(operand))
    usesWrite.insert(&operand);

  return hasReadAfterWriteInterference(usesRead, usesWrite, domInfo, state,
                                       aliasInfo);
}

/// Return true if bufferizing `opOperand` inplace would create a write to a
/// non-writable buffer.
static bool
wouldCreateWriteToNonWritableBuffer(OpOperand &opOperand,
                                    const BufferizationAliasInfo &aliasInfo,
                                    AnalysisState &state) {
  // Certain buffers are not writable:
  // 1. A function bbArg that is not inplaceable or
  // 2. A constant op.
  bool nonWritable =
      aliasesNonWritableBuffer(opOperand.get(), aliasInfo, state);
  if (!nonWritable)
    return false;

  // This is a problem only if the buffer is written to via some alias.
  bool hasWrite = aliasesInPlaceWrite(opOperand.get(), aliasInfo, state) ||
                  state.bufferizesToMemoryWrite(opOperand);

  for (OpResult opResult : state.getAliasingOpResult(opOperand))
    hasWrite |= aliasesInPlaceWrite(opResult, aliasInfo, state);

  return hasWrite;
}

//===----------------------------------------------------------------------===//
// Bufferization analyses.
//===----------------------------------------------------------------------===//

/// Determine if `operand` can be bufferized in-place.
static LogicalResult bufferizableInPlaceAnalysisImpl(
    OpOperand &operand, BufferizationAliasInfo &aliasInfo, AnalysisState &state,
    const DominanceInfo &domInfo) {
  bool foundInterference =
      wouldCreateWriteToNonWritableBuffer(operand, aliasInfo, state) ||
      wouldCreateReadAfterWriteInterference(operand, domInfo, state, aliasInfo);

  if (foundInterference)
    aliasInfo.bufferizeOutOfPlace(operand);
  else
    aliasInfo.bufferizeInPlace(operand, state);

  return success();
}

/// Analyze the `ops` to determine which OpOperands are inplaceable. Walk ops
/// in reverse and bufferize ops greedily. This is a good starter heuristic.
///
/// Even if an op does not read or write, it may still create an alias when
/// bufferized in-place. An example of such ops is tensor.extract_slice.
///
/// Rationale for bufferizing `%1 = tensor.extract_slice %0[...]` inplace:
///
/// When bufferized out of place, an ExtractSliceOp lowers to alloc + copy.
This 656 /// cannot change the flow of information for either the source or the 657 /// result buffers. 658 /// 659 /// When bufferized inplace, an ExtractSliceOp does not by itself create any 660 /// read or write from memory. Instead, it has the effect of merging the alias 661 /// sets of the source and the result buffers. 662 /// 663 /// An analysis is required to ensure inplace bufferization would not result in 664 /// RaW dependence violations. 665 static LogicalResult inPlaceAnalysis(SmallVector<Operation *> &ops, 666 BufferizationAliasInfo &aliasInfo, 667 AnalysisState &state, 668 const DominanceInfo &domInfo, 669 unsigned analysisFuzzerSeed = 0) { 670 if (analysisFuzzerSeed) { 671 // This is a fuzzer. For testing purposes only. Randomize the order in which 672 // operations are analyzed. The bufferization quality is likely worse, but 673 // we want to make sure that no assertions are triggered anywhere. 674 std::mt19937 g(analysisFuzzerSeed); 675 llvm::shuffle(ops.begin(), ops.end(), g); 676 } 677 678 // Walk ops in reverse for better interference analysis. 679 for (Operation *op : reverse(ops)) 680 for (OpOperand &opOperand : op->getOpOperands()) 681 if (opOperand.get().getType().isa<TensorType>()) 682 if (auto bufferizableOp = state.getOptions().dynCastBufferizableOp(op)) 683 if (failed(bufferizableInPlaceAnalysisImpl(opOperand, aliasInfo, 684 state, domInfo))) 685 return failure(); 686 687 return success(); 688 } 689 690 /// Return true if the given op has a tensor result or a tensor operand. 691 static bool hasTensorSemantics(Operation *op) { 692 bool hasTensorResult = any_of(op->getResultTypes(), isaTensor); 693 bool hasTensorOperand = any_of(op->getOperandTypes(), isaTensor); 694 return hasTensorResult || hasTensorOperand; 695 } 696 697 /// Analyze all ops that are contained in `op`. 698 static LogicalResult inPlaceAnalysis(Operation *op, 699 BufferizationAliasInfo &aliasInfo, 700 AnalysisState &state, 701 const DominanceInfo &domInfo, 702 unsigned analysisFuzzerSeed = 0) { 703 // Collect ops so we can build our own reverse traversal. 704 SmallVector<Operation *> ops; 705 op->walk([&](Operation *op) { 706 // No tensors => no buffers. 707 if (!hasTensorSemantics(op)) 708 return; 709 ops.push_back(op); 710 }); 711 712 return inPlaceAnalysis(ops, aliasInfo, state, domInfo, analysisFuzzerSeed); 713 } 714 715 /// Analyze equivalence of tied OpResult/OpOperand pairs of the given ops. 716 static void equivalenceAnalysis(SmallVector<Operation *> &ops, 717 BufferizationAliasInfo &aliasInfo, 718 AnalysisState &state) { 719 for (Operation *op : ops) 720 if (auto bufferizableOp = state.getOptions().dynCastBufferizableOp(op)) 721 for (OpResult opResult : op->getOpResults()) 722 if (opResult.getType().isa<TensorType>()) 723 for (OpOperand *opOperand : 724 bufferizableOp.getAliasingOpOperand(opResult, state)) 725 if (state.isInPlace(*opOperand)) 726 if (bufferizableOp.bufferRelation(opResult, state) == 727 BufferRelation::Equivalent) 728 aliasInfo.unionEquivalenceClasses(opResult, opOperand->get()); 729 } 730 731 /// Analyze equivalence of tied OpResult/OpOperand pairs of all ops contained 732 /// in `op`. 733 static void equivalenceAnalysis(Operation *op, 734 BufferizationAliasInfo &aliasInfo, 735 AnalysisState &state) { 736 // Traverse ops in PostOrder: Nested ops first, then enclosing ops. 737 SmallVector<Operation *> ops; 738 op->walk<WalkOrder::PostOrder>([&](Operation *op) { 739 // No tensors => no buffers. 
    if (none_of(op->getResultTypes(), isaTensor))
      return;
    ops.push_back(op);
  });

  equivalenceAnalysis(ops, aliasInfo, state);
}

/// Assert that the current bufferization decisions are consistent.
static LogicalResult
checkAliasInfoConsistency(Operation *op, const DominanceInfo &domInfo,
                          AnalysisState &state,
                          const BufferizationAliasInfo &aliasInfo) {
  const BufferizationOptions &options = state.getOptions();
  Operation *inconsistentOp = nullptr;
  WalkResult walkResult = op->walk([&](Operation *op) {
    if (auto bufferizableOp = options.dynCastBufferizableOp(op))
      for (OpOperand &opOperand : op->getOpOperands())
        if (opOperand.get().getType().isa<TensorType>()) {
          if (wouldCreateReadAfterWriteInterference(
                  opOperand, domInfo, state, aliasInfo,
                  /*checkConsistencyOnly=*/true)) {
            // This error can happen if certain "mustBufferizeInPlace"
            // interface methods are implemented incorrectly, such that the IR
            // already has a RaW conflict before making any bufferization
            // decisions.
            inconsistentOp = op;
            return WalkResult::interrupt();
          }
        }
    return WalkResult::advance();
  });

  if (walkResult.wasInterrupted())
    return inconsistentOp->emitError("input IR has RaW conflict");
  return success();
}

/// Annotate the IR with the result of the analysis. For testing/debugging
/// only.
static void
annotateOpsWithBufferizationMarkers(Operation *op,
                                    const BufferizationAliasInfo &aliasInfo,
                                    AnalysisState &state) {
  op->walk([&](Operation *op) {
    if (auto bufferizableOp = state.getOptions().dynCastBufferizableOp(op))
      for (OpOperand &opOperand : op->getOpOperands())
        if (opOperand.get().getType().isa<TensorType>())
          setInPlaceOpOperand(opOperand, aliasInfo.isInPlace(opOperand));
  });
}

/// Assert that IR is in destination-passing style. I.e., every value that is
/// returned or yielded from a block is:
/// * aliasing a bbArg of that block or a parent block, or
/// * aliasing an OpResult of an op in a parent block.
///
/// Example:
/// ```
/// %0 = "some_op" : tensor<?xf32>
/// %1 = scf.if %c -> (tensor<?xf32>) {
///   scf.yield %0 : tensor<?xf32>
/// } else {
///   %t = linalg.init_tensor : tensor<?xf32>
///   scf.yield %t : tensor<?xf32>
/// }
/// ```
/// In the above example, the first scf.yield op satisfies destination-passing
/// style because the yielded value %0 is defined in the parent block. The
/// second scf.yield op does not satisfy destination-passing style because the
/// yielded value %t is defined in the same block as the scf.yield op.
// TODO: The current implementation checks for equivalent values instead of
// aliasing values, which is stricter than needed. We can currently not check
// for aliasing values because the analysis is a maybe-alias analysis and we
// need a must-alias analysis here.
static LogicalResult
assertDestinationPassingStyle(Operation *op, AnalysisState &state,
                              BufferizationAliasInfo &aliasInfo,
                              SmallVector<Operation *> &newOps) {
  LogicalResult status = success();
  DominanceInfo domInfo(op);
  op->walk([&](Operation *returnOp) {
    if (!isRegionReturnLike(returnOp) ||
        !state.getOptions().isOpAllowed(returnOp))
      return WalkResult::advance();

    for (OpOperand &returnValOperand : returnOp->getOpOperands()) {
      Value returnVal = returnValOperand.get();
      // Skip non-tensor values.
      if (!returnVal.getType().isa<TensorType>())
        continue;

      bool foundEquivValue = false;
      aliasInfo.applyOnEquivalenceClass(returnVal, [&](Value equivVal) {
        if (auto bbArg = equivVal.dyn_cast<BlockArgument>()) {
          Operation *definingOp = bbArg.getOwner()->getParentOp();
          if (definingOp->isProperAncestor(returnOp))
            foundEquivValue = true;
          return;
        }

        Operation *definingOp = equivVal.getDefiningOp();
        if (definingOp->getBlock()->findAncestorOpInBlock(
                *returnOp->getParentOp()))
          // Skip ops that happen after `returnOp` and parent ops.
          if (happensBefore(definingOp, returnOp, domInfo))
            foundEquivValue = true;
      });

      if (!foundEquivValue)
        status =
            returnOp->emitError()
            << "operand #" << returnValOperand.getOperandNumber()
            << " of ReturnLike op does not satisfy destination passing style";
    }

    return WalkResult::advance();
  });

  return status;
}

LogicalResult bufferization::analyzeOp(Operation *op,
                                       OneShotAnalysisState &state) {
  DominanceInfo domInfo(op);
  BufferizationAliasInfo &aliasInfo = state.getAliasInfo();
  const auto &options =
      static_cast<const OneShotBufferizationOptions &>(state.getOptions());

  if (failed(checkAliasInfoConsistency(op, domInfo, state, aliasInfo)))
    return failure();

  // If the analysis fails, just return.
  if (failed(inPlaceAnalysis(op, aliasInfo, state, domInfo,
                             options.analysisFuzzerSeed)))
    return failure();
  equivalenceAnalysis(op, aliasInfo, state);

  for (const PostAnalysisStepFn &fn : options.postAnalysisSteps) {
    SmallVector<Operation *> newOps;
    if (failed(fn(op, state, aliasInfo, newOps)))
      return failure();
    // Analyze ops that were created by the PostAnalysisStepFn.
    if (failed(inPlaceAnalysis(newOps, aliasInfo, state, domInfo)))
      return failure();
    equivalenceAnalysis(newOps, aliasInfo, state);
  }

  bool failedAnalysis = false;
  if (!options.allowReturnAllocs) {
    SmallVector<Operation *> newOps;
    failedAnalysis |=
        failed(assertDestinationPassingStyle(op, state, aliasInfo, newOps));
  }

  // Gather all yielded tensors.
  state.gatherYieldedTensors(op);

  // Analysis verification: After setting up alias/equivalence sets, each op
  // can check for expected invariants/limitations and fail the analysis if
  // necessary.
  op->walk([&](Operation *op) {
    if (BufferizableOpInterface bufferizableOp =
            options.dynCastBufferizableOp(op))
      failedAnalysis |= failed(bufferizableOp.verifyAnalysis(state));
  });

  // Annotate operations if we only want to report the analysis.
  if (options.testAnalysisOnly)
    annotateOpsWithBufferizationMarkers(op, aliasInfo, state);

  return success(!failedAnalysis);
}

LogicalResult
bufferization::runOneShotBufferize(Operation *op,
                                   const OneShotBufferizationOptions &options) {
  OneShotAnalysisState state(op, options);
  if (failed(analyzeOp(op, state)))
    return failure();
  if (options.testAnalysisOnly)
    return success();
  return bufferizeOp(op, state);
}
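
// Usage sketch (illustrative only; the surrounding pass and chosen option
// values are assumptions, not prescribed by this file): a pass could drive
// One-Shot Bufferize roughly as follows.
//
//   OneShotBufferizationOptions options;
//   options.allowReturnAllocs = true;
//   if (failed(runOneShotBufferize(op, options)))
//     return signalPassFailure();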