//===- BufferizableOpInterface.cpp - Bufferizable Ops ---=----------------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//

#include "mlir/Dialect/Bufferization/IR/BufferizableOpInterface.h"
#include "mlir/Dialect/Bufferization/IR/Bufferization.h"
#include "mlir/Dialect/Func/IR/FuncOps.h"
#include "mlir/Dialect/MemRef/IR/MemRef.h"
#include "mlir/IR/AsmState.h"
#include "mlir/IR/BlockAndValueMapping.h"
#include "mlir/IR/BuiltinOps.h"
#include "mlir/IR/Operation.h"
#include "mlir/IR/TypeUtilities.h"
#include "mlir/IR/Value.h"
#include "llvm/Support/Debug.h"

namespace mlir {
namespace bufferization {

#include "mlir/Dialect/Bufferization/IR/BufferizableOpInterface.cpp.inc"

} // namespace bufferization
} // namespace mlir

#define DEBUG_TYPE "bufferizable-op-interface"
#define DBGS() (llvm::dbgs() << '[' << DEBUG_TYPE << "] ")
#define LDBG(X) LLVM_DEBUG(DBGS() << (X))

using namespace mlir;
using namespace bufferization;

/// Attribute name used to mark region arguments that can be bufferized
/// in-place during linalg comprehensive bufferization.
constexpr const ::llvm::StringLiteral
    bufferization::BufferizableOpInterface::kInplaceableAttrName;

/// Attribute name used to mark allocs that are created by the bufferization.
static const char *kBufferAllocationAttr = "bufferization.allocation";

/// Attribute name used to mark allocs that should not be deallocated.
static const char *kSkipDeallocAttr = "bufferization.skip_dealloc";

//===----------------------------------------------------------------------===//
// BufferizationOptions
//===----------------------------------------------------------------------===//

// Default constructor for BufferizationOptions.
BufferizationOptions::BufferizationOptions() = default;

bool BufferizationOptions::isOpAllowed(Operation *op) const {
  // Special case: If function boundary bufferization is deactivated, do not
  // allow ops that belong to the `func` dialect.
  bool isFuncBoundaryOp = isa_and_nonnull<func::FuncDialect>(op->getDialect());
  if (!bufferizeFunctionBoundaries && isFuncBoundaryOp)
    return false;

  // All other ops: Allow/disallow according to the filter. If the filter has
  // at least one ALLOW rule, ops are disallowed by default and must be
  // explicitly allowed.
  bool isAllowed = !filterHasAllowRule();
  for (const OpFilterEntry &entry : opFilter) {
    bool filterResult = entry.fn(op);
    switch (entry.type) {
    case OpFilterEntry::ALLOW:
      isAllowed |= filterResult;
      break;
    case OpFilterEntry::DENY:
      if (filterResult)
        // DENY filter matches. This op is not allowed. (Even if other ALLOW
        // filters may match.)
        return false;
    }
  }
  return isAllowed;
}

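// Note (illustrative example of the filter semantics implemented in
// isOpAllowed above; the rules themselves are hypothetical): suppose the
// filter contains an ALLOW entry matching all ops of the `tensor` dialect and
// a DENY entry matching `tensor.extract`. Then `tensor.insert` is allowed (an
// ALLOW rule exists and matches it), `tensor.extract` is not allowed (DENY
// rules win over ALLOW rules), and `arith.addf` is not allowed (an ALLOW rule
// exists but none matches it).
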
BufferizableOpInterface
BufferizationOptions::dynCastBufferizableOp(Operation *op) const {
  if (isOpAllowed(op))
    return dyn_cast<BufferizableOpInterface>(op);
  return nullptr;
}

BufferizableOpInterface
BufferizationOptions::dynCastBufferizableOp(Value value) const {
  if (auto bufferizableOp = value.getDefiningOp<BufferizableOpInterface>())
    if (isOpAllowed(bufferizableOp.getOperation()))
      return bufferizableOp;
  return nullptr;
}

void BufferizationOptions::addDialectStateInitializer(
    StringRef name, const DialectStateInitFn &fn) {
  stateInitializers.push_back(
      [=](AnalysisState &state) { state.insertDialectState(name, fn()); });
}

//===----------------------------------------------------------------------===//
// Helper functions for BufferizableOpInterface
//===----------------------------------------------------------------------===//

/// Set the insertion point of `b` right after the definition of `value`: after
/// the defining op, or at the beginning of the block in case of a bbArg.
static void setInsertionPointAfter(OpBuilder &b, Value value) {
  if (auto bbArg = value.dyn_cast<BlockArgument>()) {
    b.setInsertionPointToStart(bbArg.getOwner());
  } else {
    b.setInsertionPointAfter(value.getDefiningOp());
  }
}

/// Determine which OpOperand* will alias with `result` if the op is bufferized
/// in place. Return an empty vector if the op is not bufferizable.
SmallVector<OpOperand *>
AnalysisState::getAliasingOpOperand(OpResult result) const {
  if (Operation *op = result.getDefiningOp())
    if (auto bufferizableOp = dyn_cast<BufferizableOpInterface>(op))
      return bufferizableOp.getAliasingOpOperand(result, *this);
  return {};
}

/// Determine which OpResults will alias with `opOperand` if the op is
/// bufferized in place. Return an empty vector if the op is not bufferizable.
SmallVector<OpResult>
AnalysisState::getAliasingOpResult(OpOperand &opOperand) const {
  if (auto bufferizableOp =
          dyn_cast<BufferizableOpInterface>(opOperand.getOwner()))
    return bufferizableOp.getAliasingOpResult(opOperand, *this);
  return {};
}

/// Return true if `opOperand` bufferizes to a memory read. Return `true` if
/// the op is not bufferizable.
bool AnalysisState::bufferizesToMemoryRead(OpOperand &opOperand) const {
  if (auto bufferizableOp =
          dyn_cast<BufferizableOpInterface>(opOperand.getOwner()))
    return bufferizableOp.bufferizesToMemoryRead(opOperand, *this);

  // Unknown op that returns a tensor. The inplace analysis does not support
  // it. Conservatively return true.
  return true;
}

/// Return true if `opOperand` bufferizes to a memory write. Return `true` if
/// the op is not bufferizable.
bool AnalysisState::bufferizesToMemoryWrite(OpOperand &opOperand) const {
  if (auto bufferizableOp =
          dyn_cast<BufferizableOpInterface>(opOperand.getOwner()))
    return bufferizableOp.bufferizesToMemoryWrite(opOperand, *this);

  // Unknown op that returns a tensor. The inplace analysis does not support
  // it. Conservatively return true.
  return true;
}

/// Return true if `opOperand` neither reads nor writes but bufferizes to an
/// alias. Return false if the op is not bufferizable.
bool AnalysisState::bufferizesToAliasOnly(OpOperand &opOperand) const {
  if (auto bufferizableOp =
          dyn_cast<BufferizableOpInterface>(opOperand.getOwner()))
    return bufferizableOp.bufferizesToAliasOnly(opOperand, *this);

  // Unknown op that returns a tensor. The inplace analysis does not support
  // it. Conservatively return false.
  return false;
}

/// Return true if the given value is read by an op that bufferizes to a memory
/// read. Also takes into account ops that create an alias but do not read by
/// themselves (e.g., ExtractSliceOp).
bool AnalysisState::isValueRead(Value value) const {
  assert(value.getType().isa<TensorType>() && "expected TensorType");
  SmallVector<OpOperand *> workingSet;
  for (OpOperand &use : value.getUses())
    workingSet.push_back(&use);

  while (!workingSet.empty()) {
    OpOperand *uMaybeReading = workingSet.pop_back_val();
    // Skip over all ops that neither read nor write (but create an alias).
    if (bufferizesToAliasOnly(*uMaybeReading))
      for (OpResult opResult : getAliasingOpResult(*uMaybeReading))
        for (OpOperand &use : opResult.getUses())
          workingSet.push_back(&use);
    if (bufferizesToMemoryRead(*uMaybeReading))
      return true;
  }

  return false;
}

// Starting from `value`, follow the use-def chain in reverse, always selecting
// the aliasing OpOperands. Find and return Values for which `condition`
// evaluates to true. OpOperands of such matching Values are not traversed any
// further.
llvm::SetVector<Value> AnalysisState::findValueInReverseUseDefChain(
    Value value, llvm::function_ref<bool(Value)> condition) const {
  llvm::SetVector<Value> result, workingSet;
  workingSet.insert(value);

  while (!workingSet.empty()) {
    Value value = workingSet.pop_back_val();
    if (condition(value) || value.isa<BlockArgument>()) {
      result.insert(value);
      continue;
    }

    OpResult opResult = value.cast<OpResult>();
    SmallVector<OpOperand *> opOperands = getAliasingOpOperand(opResult);
    if (opOperands.empty() || !options.isOpAllowed(value.getDefiningOp())) {
      result.insert(value);
      continue;
    }

    for (OpOperand *o : opOperands)
      workingSet.insert(o->get());
  }

  return result;
}

// Find the Values of the last preceding write of a given Value.
llvm::SetVector<Value>
AnalysisState::findLastPrecedingWrite(Value value) const {
  return findValueInReverseUseDefChain(value, [&](Value value) {
    Operation *op = value.getDefiningOp();
    if (!op)
      return true;
    auto bufferizableOp = options.dynCastBufferizableOp(op);
    if (!bufferizableOp)
      return true;
    return bufferizableOp.isMemoryWrite(value.cast<OpResult>(), *this);
  });
}

AnalysisState::AnalysisState(const BufferizationOptions &options)
    : options(options) {
  for (const BufferizationOptions::AnalysisStateInitFn &fn :
       options.stateInitializers)
    fn(*this);
}

// bufferization.to_memref is not allowed to change the rank.
static void ensureToMemrefOpIsValid(Value tensor, Type memrefType) {
#ifndef NDEBUG
  auto rankedTensorType = tensor.getType().dyn_cast<RankedTensorType>();
  assert((!rankedTensorType || memrefType.cast<MemRefType>().getRank() ==
                                   rankedTensorType.getRank()) &&
         "to_memref would be invalid: mismatching ranks");
#endif
}

Value mlir::bufferization::lookupBuffer(RewriterBase &rewriter, Value tensor,
                                        const BufferizationOptions &options) {
  auto tensorType = tensor.getType().dyn_cast<TensorType>();
  assert(tensorType && "unexpected non-tensor type");

  // Replace "%t = to_tensor %m" with %m.
  if (auto toTensorOp = tensor.getDefiningOp<bufferization::ToTensorOp>())
    return toTensorOp.memref();

  // Insert to_memref op.
  OpBuilder::InsertionGuard g(rewriter);
  setInsertionPointAfter(rewriter, tensor);
  Type memrefType = getMemRefType(tensorType, options);
  ensureToMemrefOpIsValid(tensor, memrefType);
  return rewriter.create<bufferization::ToMemrefOp>(tensor.getLoc(),
                                                    memrefType, tensor);
}

/// Return the buffer (memref) for a given OpOperand (tensor). Allocate
/// a new buffer and copy over data from the existing buffer if out-of-place
/// bufferization was decided.
FailureOr<Value>
BufferizationState::getBuffer(RewriterBase &rewriter, OpOperand &opOperand,
                              Optional<ForceInPlacability> overrideInPlace,
                              Optional<Operation *> customCopyInsertionPoint) {
  const BufferizationOptions &options = analysisState.getOptions();
  OpBuilder::InsertionGuard guard(rewriter);
  Operation *op = opOperand.getOwner();
  Location loc = op->getLoc();
  SmallVector<OpResult> aliasingOpResults =
      analysisState.getAliasingOpResult(opOperand);
  Value operand = opOperand.get();
  Value operandBuffer = lookupBuffer(rewriter, operand, options);

  // Can `operandBuffer` be used directly or do we need a copy?
  bool inplace =
      overrideInPlace != FORCE_OUT_OF_PLACE &&
      (overrideInPlace == FORCE_INPLACE || analysisState.isInPlace(opOperand));
  if (inplace)
    return operandBuffer;

  // Bufferizing out-of-place: Allocate a new buffer.
  // Move insertion point right after `operandBuffer`. That is where the
  // allocation should be inserted (in the absence of allocation hoisting).
  setInsertionPointAfter(rewriter, operandBuffer);
  // Allocate the result buffer. The buffer should be deallocated if the tensor
  // is not yielded and deallocs are enabled in general.
  bool dealloc = llvm::none_of(aliasingOpResults, [&](Value v) {
    return getAnalysisState().isTensorYielded(v);
  });
  FailureOr<Value> resultBuffer = createAlloc(
      rewriter, loc, operandBuffer, dealloc && getOptions().createDeallocs);
  if (failed(resultBuffer))
    return failure();
  // Do not copy if the last preceding writes of `operand` are ops that do
  // not write (skipping ops that merely create aliases). E.g., InitTensorOp.
  // Note: If `findLastPrecedingWrite` reaches the end of the reverse SSA
  // use-def chain, it returns that value, regardless of whether it is a
  // memory write or not.
  SetVector<Value> lastWrites = analysisState.findLastPrecedingWrite(operand);
  if (llvm::none_of(lastWrites, [&](Value lastWrite) {
        if (auto bufferizableOp = options.dynCastBufferizableOp(lastWrite))
          return bufferizableOp.isMemoryWrite(lastWrite.cast<OpResult>(),
                                              analysisState);
        return true;
      }))
    return resultBuffer;
  // Do not copy if the copied data is never read.
  if (!aliasingOpResults.empty() &&
      !analysisState.bufferizesToMemoryRead(opOperand) &&
      llvm::none_of(aliasingOpResults, [&](OpResult opResult) {
        return analysisState.isValueRead(opResult);
      }))
    return resultBuffer;
  // Do not copy if this op does not read the data, but writes it.
  if (analysisState.bufferizesToMemoryWrite(opOperand) &&
      !analysisState.bufferizesToMemoryRead(opOperand))
    return resultBuffer;

  if (customCopyInsertionPoint) {
    rewriter.setInsertionPoint(*customCopyInsertionPoint);
  } else {
    // The copy happens right before the op that is bufferized.
    rewriter.setInsertionPoint(op);
  }
  if (failed(
          createMemCpy(rewriter, loc, operandBuffer, *resultBuffer, options)))
    return failure();

  return resultBuffer;
}

/// Return the buffer type for a given OpOperand (tensor) after bufferization.
BaseMemRefType BufferizationState::getBufferType(OpOperand &opOperand) const {
  Value tensor = opOperand.get();
  auto tensorType = tensor.getType().dyn_cast<TensorType>();
  assert(tensorType && "unexpected non-tensor type");

  if (auto toTensorOp = tensor.getDefiningOp<bufferization::ToTensorOp>())
    return toTensorOp.memref().getType().cast<BaseMemRefType>();

  return getMemRefType(tensorType, getOptions());
}

void bufferization::replaceOpWithBufferizedValues(RewriterBase &rewriter,
                                                  Operation *op,
                                                  ValueRange values) {
  assert(values.size() == op->getNumResults() &&
         "expected one value per OpResult");
  OpBuilder::InsertionGuard g(rewriter);

  // Replace all OpResults with the given values.
  SmallVector<Value> replacements;
  for (OpResult opResult : op->getOpResults()) {
    Value replacement = values[opResult.getResultNumber()];
    if (opResult.getType().isa<TensorType>()) {
      // The OpResult is a tensor. Such values are replaced with memrefs during
      // bufferization.
      assert((replacement.getType().isa<MemRefType>() ||
              replacement.getType().isa<UnrankedMemRefType>()) &&
             "tensor op result should be replaced with a memref value");
      // The existing uses of the OpResult still expect a tensor. Insert a
      // ToTensorOp. Throughout bufferization, this ToTensorOp will gradually
      // lose all of its users and eventually DCE away.
      rewriter.setInsertionPointAfter(op);
      replacement = rewriter.create<bufferization::ToTensorOp>(
          replacement.getLoc(), replacement);
    }
    replacements.push_back(replacement);
  }

  rewriter.replaceOp(op, replacements);
}

AlwaysCopyAnalysisState::AlwaysCopyAnalysisState(
    const BufferizationOptions &options)
    : AnalysisState(options) {
  // Note: Allocations must be deallocated with a subsequent run of the buffer
  // deallocation pass.
  assert(!options.createDeallocs &&
         "cannot create deallocs with AlwaysCopyBufferizationState");
}

/// Return `true` if the given OpOperand has been decided to bufferize inplace.
bool AlwaysCopyAnalysisState::isInPlace(OpOperand &opOperand) const {
  // OpOperands that bufferize to a memory write are out-of-place, i.e., an
  // alloc and copy is inserted.
  return !bufferizesToMemoryWrite(opOperand);
}

/// Return true if `v1` and `v2` bufferize to equivalent buffers.
bool AlwaysCopyAnalysisState::areEquivalentBufferizedValues(Value v1,
                                                            Value v2) const {
  // There is no analysis, so we do not know if the values are equivalent. The
  // conservative answer is "false".
  return false;
}

/// Return true if the given tensor (or an aliasing tensor) is yielded from
/// the containing block. Also include all aliasing tensors in the same block.
bool AlwaysCopyAnalysisState::isTensorYielded(Value tensor) const {
  // There is no analysis, so conservatively answer "true".
  return true;
}

//===----------------------------------------------------------------------===//
// Bufferization-specific scoped alloc/dealloc insertion support.
//===----------------------------------------------------------------------===//

/// Create a memref allocation with the given type and dynamic extents.
static FailureOr<Value> createAlloc(OpBuilder &b, Location loc,
                                    MemRefType type, ValueRange dynShape,
                                    const BufferizationOptions &options) {
  if (options.allocationFn)
    return (*options.allocationFn)(b, loc, type, dynShape,
                                   options.bufferAlignment);

  // Default buffer allocation via AllocOp.
  Value allocated = b.create<memref::AllocOp>(
      loc, type, dynShape, b.getI64IntegerAttr(options.bufferAlignment));
  return allocated;
}

/// Creates a memref deallocation. The given memref buffer must have been
/// allocated using `createAlloc`.
LogicalResult
bufferization::createDealloc(OpBuilder &b, Location loc, Value allocatedBuffer,
                             const BufferizationOptions &options) {
  if (options.deallocationFn)
    return (*options.deallocationFn)(b, loc, allocatedBuffer);

  // Default buffer deallocation via DeallocOp.
  b.create<memref::DeallocOp>(loc, allocatedBuffer);
  return success();
}

/// Compute the type of the `memref` to use for allocating the buffer for
/// `shapedValue`. Also return (by reference in `dynShape`) the values for the
/// dynamic dimensions in the returned `memref` type.
static MemRefType getAllocationTypeAndShape(OpBuilder &b, Location loc,
                                            Value shapedValue,
                                            SmallVectorImpl<Value> &dynShape) {
  MemRefType allocMemRefType =
      getContiguousMemRefType(shapedValue.getType().cast<ShapedType>());

  // Compute the dynamic part of the shape.
  bool reifiedShapes = false;
  if (auto rankedOp = dyn_cast_or_null<ReifyRankedShapedTypeOpInterface>(
          shapedValue.getDefiningOp())) {
    ReifiedRankedShapedTypeDims resultDims;
    if (succeeded(rankedOp.reifyResultShapes(b, resultDims))) {
      reifiedShapes = true;
      OpResult resultValue = shapedValue.dyn_cast<OpResult>();
      auto &shape = resultDims[resultValue.getResultNumber()];
      for (const auto &dim : enumerate(allocMemRefType.getShape()))
        if (ShapedType::isDynamic(dim.value()))
          dynShape.push_back(shape[dim.index()]);
    }
  }

  if (!reifiedShapes) {
    for (const auto &dim : enumerate(allocMemRefType.getShape()))
      if (ShapedType::isDynamic(dim.value())) {
        assert((shapedValue.getType().isa<UnrankedMemRefType>() ||
                shapedValue.getType().isa<MemRefType>()) &&
               "expected MemRef type");
        dynShape.push_back(
            b.create<memref::DimOp>(loc, shapedValue, dim.index()));
      }
  }

  return allocMemRefType;
}

/// Create a memref.alloca op that is tagged as a bufferization-created
/// allocation. Such allocations are rewritten into the actual alloc/dealloc
/// ops by `createAllocDeallocOps`.
static Value createBufferAllocation(OpBuilder &b, Location loc,
                                    MemRefType type, ValueRange dynShape,
                                    bool skipDealloc) {
  auto allocaOp = b.create<memref::AllocaOp>(loc, type, dynShape);
  allocaOp->setAttr(kBufferAllocationAttr, b.getUnitAttr());
  if (skipDealloc)
    allocaOp->setAttr(kSkipDeallocAttr, b.getUnitAttr());
  return allocaOp.getResult();
}

/// Create an allocation after `shapedValue.getDefiningOp` (or at the top of
/// the block in case of a bbArg).
FailureOr<Value> BufferizationState::createAlloc(OpBuilder &b, Location loc,
                                                 Value shapedValue,
                                                 Optional<bool> dealloc) {
  // Take a guard before anything else.
  OpBuilder::InsertionGuard g(b);

  // Compute the allocation memref type.
  assert(shapedValue.getType().isa<ShapedType>());
  SmallVector<Value> dynShape;
  MemRefType allocMemRefType =
      getAllocationTypeAndShape(b, loc, shapedValue, dynShape);

  // Should the buffer be deallocated again or should we let it leak?
  bool skipDealloc;
  if (dealloc) {
    skipDealloc = !dealloc.getValue();
  } else {
    assert(shapedValue.getType().isa<TensorType>() &&
           "must specify `dealloc` if non-tensor value is passed");
    // The buffer should not be deallocated if deallocs are generally
    // deactivated or if the tensor is yielded from a block.
    skipDealloc = !getOptions().createDeallocs ||
                  getAnalysisState().isTensorYielded(shapedValue);
  }

  // Create the buffer allocation.
  return createBufferAllocation(b, loc, allocMemRefType, dynShape,
                                skipDealloc);
}

/// Create a memory copy between two memref buffers.
LogicalResult bufferization::createMemCpy(OpBuilder &b, Location loc,
                                          Value from, Value to,
                                          const BufferizationOptions &options) {
  if (options.memCpyFn)
    return (*options.memCpyFn)(b, loc, from, to);

  b.create<memref::CopyOp>(loc, from, to);
  return success();
}

LogicalResult
bufferization::createAllocDeallocOps(Operation *op,
                                     const BufferizationOptions &options,
                                     bool onlyLeakingAllocs, bool *changed) {
  IRRewriter rewriter(op->getContext());
  if (changed)
    *changed = false;

  // Bufferization creates memref.alloca ops. After bufferization, these must
  // be rewritten to alloc/dealloc ops as specified in the bufferization
  // options.
  WalkResult status = op->walk([&](memref::AllocaOp allocaOp) {
    // Ignore memref.alloca ops that were not created by the bufferization.
    if (!allocaOp->hasAttr(kBufferAllocationAttr))
      return WalkResult::skip();
    // If `onlyLeakingAllocs`, process only ops that are marked as
    // "skip dealloc".
    bool skipDealloc = allocaOp->hasAttr(kSkipDeallocAttr);
    if (onlyLeakingAllocs && !skipDealloc)
      return WalkResult::skip();

    // Create alloc.
    Block *block = allocaOp->getBlock();
    rewriter.setInsertionPoint(allocaOp);
    FailureOr<Value> alloc =
        createAlloc(rewriter, allocaOp->getLoc(), allocaOp.getType(),
                    allocaOp.dynamicSizes(), options);
    if (failed(alloc))
      return WalkResult::interrupt();
    rewriter.replaceOp(allocaOp, *alloc);
    if (changed)
      *changed = true;

    // Stop here if the buffer should not be deallocated.
    if (skipDealloc)
      return WalkResult::advance();

    // Create dealloc.
    rewriter.setInsertionPoint(block->getTerminator());
    if (failed(createDealloc(rewriter, alloc->getLoc(), *alloc, options)))
      return WalkResult::interrupt();

    return WalkResult::advance();
  });

  return success(!status.wasInterrupted());
}

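// Illustrative effect of createAllocDeallocOps above (an IR sketch for
// exposition only; exact attributes depend on the BufferizationOptions in
// use):
//
//   %0 = memref.alloca(%sz) {bufferization.allocation} : memref<?xf32>
//   ...
//   "some.terminator"() : () -> ()
//
// becomes, with the default allocation/deallocation functions:
//
//   %0 = memref.alloc(%sz) {alignment = ... : i64} : memref<?xf32>
//   ...
//   memref.dealloc %0 : memref<?xf32>
//   "some.terminator"() : () -> ()
//
// Allocations additionally tagged with `bufferization.skip_dealloc` keep the
// alloc but do not receive a dealloc.
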
/// Try to hoist all new buffer allocations until the next hoisting barrier.
// TODO: Consolidate this function with the existing buffer hoisting pass.
LogicalResult
bufferization::hoistBufferAllocations(Operation *op,
                                      const BufferizationOptions &options) {
  // Nothing to do if allocation hoisting is deactivated.
  if (!options.hoistAllocations)
    return success();

  // Gather all buffer allocations that were created by the bufferization.
  SmallVector<Operation *> allocaOps;
  op->walk([&](memref::AllocaOp allocaOp) {
    if (allocaOp->hasAttr(kBufferAllocationAttr))
      allocaOps.push_back(allocaOp);
  });

  for (Operation *allocaOp : allocaOps) {
    // TODO: Hoisting of allocs with dynamic shape not implemented.
    if (!allocaOp->getOpOperands().empty())
      continue;

    Operation *op = allocaOp->getParentOp();
    while (op) {
      if (auto bufferizableOp = dyn_cast<BufferizableOpInterface>(op)) {
        if (bufferizableOp.isAllocationHoistingBarrier()) {
          break;
        }
      } else {
        // Op is not bufferizable: It may not be safe to hoist across this op.
        break;
      }
      op = op->getParentOp();
    }

    // FuncOp is an allocation hoisting barrier, so the walk above should never
    // run off the end of the parent chain.
    assert(op && "allocation hoisting barrier not found");

    // Nothing to do if the insertion point is in the same block.
    if (op == allocaOp->getParentOp())
      continue;

    // `op` may have multiple blocks. Make sure that we insert into the right
    // one.
    SmallVector<Block *> blocks;
    for (Region &r : op->getRegions())
      for (Block &b : r.getBlocks())
        blocks.push_back(&b);
    auto *insertionBlock = llvm::find_if(
        blocks, [&](Block *b) { return b->findAncestorOpInBlock(*allocaOp); });
    assert(insertionBlock != blocks.end() && "owning block not found");

    // Move to the beginning of the block.
    allocaOp->moveBefore(&(*insertionBlock)->front());
  }

  return success();
}

//===----------------------------------------------------------------------===//
// Bufferization-specific BlockAndValueMapping support with debugging.
//===----------------------------------------------------------------------===//

bool bufferization::isFunctionArgument(Value value) {
  auto bbArg = value.dyn_cast<BlockArgument>();
  if (!bbArg)
    return false;
  return isa<func::FuncOp>(bbArg.getOwner()->getParentOp());
}

MemRefType bufferization::getContiguousMemRefType(ShapedType shapedType,
                                                  Attribute memorySpace) {
  MemRefLayoutAttrInterface layout = {};
  return MemRefType::get(shapedType.getShape(), shapedType.getElementType(),
                         layout, memorySpace);
}

BaseMemRefType bufferization::getMemRefType(TensorType tensorType,
                                            const BufferizationOptions &options,
                                            MemRefLayoutAttrInterface layout,
                                            Attribute memorySpace) {
  // Case 1: Unranked memref type.
  if (auto unrankedTensorType = tensorType.dyn_cast<UnrankedTensorType>()) {
    assert(!layout && "UnrankedTensorType cannot have a layout map");
    return UnrankedMemRefType::get(unrankedTensorType.getElementType(),
                                   memorySpace);
  }

  // Case 2: Ranked memref type with specified layout. If fully dynamic layout
  // maps are not requested, generate a type with `layout`, which is empty (no
  // layout map) by default.
  auto rankedTensorType = tensorType.cast<RankedTensorType>();
  if (layout || !options.fullyDynamicLayoutMaps) {
    return MemRefType::get(rankedTensorType.getShape(),
                           rankedTensorType.getElementType(), layout,
                           memorySpace);
  }

  // Case 3: Ranked memref type with unspecified layout. Choose the most
  // dynamic one.
  // TODO: address space decisions to connect with the actual alloc.
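  // Illustrative example (a sketch; the exact symbol ordering is whatever
  // makeStridedLinearLayoutMap produces): for a `tensor<4x?xf32>`, this case
  // yields roughly
  //   memref<4x?xf32,
  //          affine_map<(d0, d1)[s0, s1, s2] -> (d0 * s1 + d1 * s2 + s0)>>
  // i.e., the offset and all strides are symbolic, so any strided buffer of
  // that shape can be passed without a cast.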
  int64_t dynamicOffset = ShapedType::kDynamicStrideOrOffset;
  SmallVector<int64_t> dynamicStrides(rankedTensorType.getRank(),
                                      ShapedType::kDynamicStrideOrOffset);
  AffineMap stridedLayout = makeStridedLinearLayoutMap(
      dynamicStrides, dynamicOffset, rankedTensorType.getContext());
  return MemRefType::get(rankedTensorType.getShape(),
                         rankedTensorType.getElementType(), stridedLayout,
                         memorySpace);
}