//===- BufferizableOpInterface.cpp - Bufferizable Ops --------------------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//

#include "mlir/Dialect/Bufferization/IR/BufferizableOpInterface.h"
#include "mlir/Dialect/Bufferization/IR/Bufferization.h"
#include "mlir/Dialect/Func/IR/FuncOps.h"
#include "mlir/Dialect/MemRef/IR/MemRef.h"
#include "mlir/IR/AsmState.h"
#include "mlir/IR/BlockAndValueMapping.h"
#include "mlir/IR/BuiltinOps.h"
#include "mlir/IR/Operation.h"
#include "mlir/IR/TypeUtilities.h"
#include "mlir/IR/Value.h"
#include "llvm/Support/Debug.h"

namespace mlir {
namespace bufferization {

#include "mlir/Dialect/Bufferization/IR/BufferizableOpInterface.cpp.inc"

} // namespace bufferization
} // namespace mlir

#define DEBUG_TYPE "bufferizable-op-interface"
#define DBGS() (llvm::dbgs() << '[' << DEBUG_TYPE << "] ")
#define LDBG(X) LLVM_DEBUG(DBGS() << (X))

using namespace mlir;
using namespace bufferization;

/// Attribute name used to mark region arguments that can be bufferized
/// in-place during linalg comprehensive bufferization.
constexpr const ::llvm::StringLiteral
    bufferization::BufferizableOpInterface::kInplaceableAttrName;

/// Attribute name used to mark allocs that are created by the bufferization.
static const char *kBufferAllocationAttr = "bufferization.allocation";

/// Attribute name used to mark allocs that should not be deallocated.
static const char *kSkipDeallocAttr = "bufferization.skip_dealloc";

//===----------------------------------------------------------------------===//
// BufferizationOptions
//===----------------------------------------------------------------------===//

// Default constructor for BufferizationOptions.
BufferizationOptions::BufferizationOptions() = default;

bool BufferizationOptions::isOpAllowed(Operation *op) const {
  // Special case: If function boundary bufferization is deactivated, do not
  // allow ops that belong to the `func` dialect.
  bool isFuncBoundaryOp = isa_and_nonnull<func::FuncDialect>(op->getDialect());
  if (!bufferizeFunctionBoundaries && isFuncBoundaryOp)
    return false;

  // All other ops: Allow/disallow according to filter.
  bool isAllowed = !filterHasAllowRule();
  for (const OpFilterEntry &entry : opFilter) {
    bool filterResult = entry.fn(op);
    switch (entry.type) {
    case OpFilterEntry::ALLOW:
      isAllowed |= filterResult;
      break;
    case OpFilterEntry::DENY:
      if (filterResult)
        // DENY filter matches. This op is not allowed. (Even if other ALLOW
        // filters may match.)
        return false;
    };
  }
  return isAllowed;
}
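
// Illustrative sketch (not part of this file's logic): one way a client could
// populate the filter, assuming `opFilter` and the `OpFilterEntry` struct used
// above are directly accessible; real code may instead go through dedicated
// helper methods on BufferizationOptions.
//
//   BufferizationOptions options;
//   BufferizationOptions::OpFilterEntry denyExtract;
//   denyExtract.fn = [](Operation *op) { return isa<tensor::ExtractOp>(op); };
//   denyExtract.type = BufferizationOptions::OpFilterEntry::DENY;
//   options.opFilter.push_back(denyExtract);
//
// With an additional ALLOW entry for the `tensor` dialect, `tensor.insert`
// would be allowed, `tensor.extract` denied (DENY always wins), and ops from
// other dialects denied, because the presence of an ALLOW rule flips the
// default to "deny unless allowed".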

BufferizableOpInterface
BufferizationOptions::dynCastBufferizableOp(Operation *op) const {
  auto bufferizableOp = dyn_cast<BufferizableOpInterface>(op);
  if (!bufferizableOp)
    return nullptr;
  if (!isOpAllowed(op))
    return nullptr;
  return bufferizableOp;
}

BufferizableOpInterface
BufferizationOptions::dynCastBufferizableOp(Value value) const {
  if (auto bufferizableOp = value.getDefiningOp<BufferizableOpInterface>())
    if (isOpAllowed(bufferizableOp.getOperation()))
      return bufferizableOp;
  return nullptr;
}

void BufferizationOptions::addDialectStateInitializer(
    StringRef name, const DialectStateInitFn &fn) {
  stateInitializers.push_back(
      [=](AnalysisState &state) { state.insertDialectState(name, fn()); });
}

//===----------------------------------------------------------------------===//
// Helper functions for BufferizableOpInterface
//===----------------------------------------------------------------------===//

static void setInsertionPointAfter(OpBuilder &b, Value value) {
  if (auto bbArg = value.dyn_cast<BlockArgument>()) {
    b.setInsertionPointToStart(bbArg.getOwner());
  } else {
    b.setInsertionPointAfter(value.getDefiningOp());
  }
}

/// Determine which OpOperand* will alias with `result` if the op is
/// bufferized in place. Return an empty vector if the op is not bufferizable.
SmallVector<OpOperand *>
AnalysisState::getAliasingOpOperand(OpResult result) const {
  if (Operation *op = result.getDefiningOp())
    if (auto bufferizableOp = getOptions().dynCastBufferizableOp(op))
      return bufferizableOp.getAliasingOpOperand(result, *this);
  return {};
}

/// Determine which OpResult will alias with `opOperand` if the op is
/// bufferized in place. Return an empty vector if the op is not bufferizable.
SmallVector<OpResult>
AnalysisState::getAliasingOpResult(OpOperand &opOperand) const {
  if (auto bufferizableOp =
          getOptions().dynCastBufferizableOp(opOperand.getOwner()))
    return bufferizableOp.getAliasingOpResult(opOperand, *this);
  return {};
}

/// Return true if `opOperand` bufferizes to a memory read. Return `true` if
/// the op is not bufferizable.
bool AnalysisState::bufferizesToMemoryRead(OpOperand &opOperand) const {
  if (auto bufferizableOp =
          getOptions().dynCastBufferizableOp(opOperand.getOwner()))
    return bufferizableOp.bufferizesToMemoryRead(opOperand, *this);

  // Unknown op that returns a tensor. The inplace analysis does not support
  // it. Conservatively return true.
  return true;
}

/// Return true if `opOperand` bufferizes to a memory write. Return `true` if
/// the op is not bufferizable.
bool AnalysisState::bufferizesToMemoryWrite(OpOperand &opOperand) const {
  if (auto bufferizableOp =
          getOptions().dynCastBufferizableOp(opOperand.getOwner()))
    return bufferizableOp.bufferizesToMemoryWrite(opOperand, *this);

  // Unknown op that returns a tensor. The inplace analysis does not support
  // it. Conservatively return true.
  return true;
}

/// Return true if `opOperand` does neither read nor write but bufferizes to
/// an alias. Return false if the op is not bufferizable.
bool AnalysisState::bufferizesToAliasOnly(OpOperand &opOperand) const {
  if (auto bufferizableOp =
          getOptions().dynCastBufferizableOp(opOperand.getOwner()))
    return bufferizableOp.bufferizesToAliasOnly(opOperand, *this);

  // Unknown op that returns a tensor. The inplace analysis does not support
  // it. Conservatively return false.
  return false;
}
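
// Illustrative example of the three categories above (assuming the usual
// interface implementations of these tensor ops):
//
//   %e = tensor.extract %t[%i] : tensor<8xf32>
//     -> the use of %t bufferizes to a memory read (no write).
//   %s = tensor.extract_slice %t[0] [4] [1] : tensor<8xf32> to tensor<4xf32>
//     -> the use of %t creates an alias only: it is neither read nor written
//        by the extract_slice itself.
//
// An op that overwrites its entire destination tensor (e.g., a linalg.fill
// writing to its output) is the typical case of an operand that bufferizes to
// a memory write without bufferizing to a memory read.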

/// Return true if the given value is read by an op that bufferizes to a memory
/// read. Also takes into account ops that create an alias but do not read by
/// themselves (e.g., ExtractSliceOp).
bool AnalysisState::isValueRead(Value value) const {
  assert(value.getType().isa<TensorType>() && "expected TensorType");
  SmallVector<OpOperand *> workingSet;
  for (OpOperand &use : value.getUses())
    workingSet.push_back(&use);

  while (!workingSet.empty()) {
    OpOperand *uMaybeReading = workingSet.pop_back_val();
    // Skip over all ops that neither read nor write (but create an alias).
    if (bufferizesToAliasOnly(*uMaybeReading))
      for (OpResult opResult : getAliasingOpResult(*uMaybeReading))
        for (OpOperand &use : opResult.getUses())
          workingSet.push_back(&use);
    if (bufferizesToMemoryRead(*uMaybeReading))
      return true;
  }

  return false;
}

// Starting from `value`, follow the use-def chain in reverse, always selecting
// the aliasing OpOperands. Find and return Values for which `condition`
// evaluates to true. OpOperands of such matching Values are not traversed any
// further.
llvm::SetVector<Value> AnalysisState::findValueInReverseUseDefChain(
    Value value, llvm::function_ref<bool(Value)> condition) const {
  llvm::SetVector<Value> result, workingSet;
  workingSet.insert(value);

  while (!workingSet.empty()) {
    Value value = workingSet.pop_back_val();
    if (condition(value) || value.isa<BlockArgument>()) {
      result.insert(value);
      continue;
    }

    OpResult opResult = value.cast<OpResult>();
    SmallVector<OpOperand *> opOperands = getAliasingOpOperand(opResult);
    if (opOperands.empty() || !options.isOpAllowed(value.getDefiningOp())) {
      result.insert(value);
      continue;
    }

    for (OpOperand *o : opOperands)
      workingSet.insert(o->get());
  }

  return result;
}

// Find the Values of the last preceding write of a given Value.
llvm::SetVector<Value>
AnalysisState::findLastPrecedingWrite(Value value) const {
  return findValueInReverseUseDefChain(value, [&](Value value) {
    Operation *op = value.getDefiningOp();
    if (!op)
      return true;
    auto bufferizableOp = options.dynCastBufferizableOp(op);
    if (!bufferizableOp)
      return true;
    return bufferizableOp.isMemoryWrite(value.cast<OpResult>(), *this);
  });
}
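
// Illustrative example (assuming the usual interface implementations): given
//
//   %0 = "test.write_to_tensor"(...) : (...) -> tensor<8xf32>   // a write
//   %1 = tensor.extract_slice %0[0] [4] [1]
//            : tensor<8xf32> to tensor<4xf32>                   // alias only
//   %2 = tensor.extract_slice %1[0] [2] [1]
//            : tensor<4xf32> to tensor<2xf32>                   // alias only
//
// findLastPrecedingWrite(%2) walks backwards through the aliasing operands of
// the two extract_slice ops and returns {%0}, the first value on that chain
// whose defining op is a memory write. "test.write_to_tensor" is a made-up op
// name used only for this sketch.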

AnalysisState::AnalysisState(const BufferizationOptions &options)
    : options(options) {
  for (const BufferizationOptions::AnalysisStateInitFn &fn :
       options.stateInitializers)
    fn(*this);
}

// bufferization.to_memref is not allowed to change the rank.
static void ensureToMemrefOpIsValid(Value tensor, Type memrefType) {
#ifndef NDEBUG
  auto rankedTensorType = tensor.getType().dyn_cast<RankedTensorType>();
  assert((!rankedTensorType || memrefType.cast<MemRefType>().getRank() ==
                                   rankedTensorType.getRank()) &&
         "to_memref would be invalid: mismatching ranks");
#endif
}

Value mlir::bufferization::lookupBuffer(RewriterBase &rewriter, Value tensor,
                                        const BufferizationOptions &options) {
  auto tensorType = tensor.getType().dyn_cast<TensorType>();
  assert(tensorType && "unexpected non-tensor type");

  // Replace "%t = to_tensor %m" with %m.
  if (auto toTensorOp = tensor.getDefiningOp<bufferization::ToTensorOp>())
    return toTensorOp.memref();

  // Insert to_memref op.
  OpBuilder::InsertionGuard g(rewriter);
  setInsertionPointAfter(rewriter, tensor);
  Type memrefType = getMemRefType(tensorType, options);
  ensureToMemrefOpIsValid(tensor, memrefType);
  return rewriter.create<bufferization::ToMemrefOp>(tensor.getLoc(),
                                                    memrefType, tensor);
}

/// Return the buffer (memref) for a given OpOperand (tensor). Allocate
/// a new buffer and copy over data from the existing buffer if out-of-place
/// bufferization was decided.
FailureOr<Value>
BufferizationState::getBuffer(RewriterBase &rewriter, OpOperand &opOperand,
                              Optional<ForceInPlacability> overrideInPlace,
                              Optional<Operation *> customCopyInsertionPoint) {
  const BufferizationOptions &options = analysisState.getOptions();
  OpBuilder::InsertionGuard guard(rewriter);
  Operation *op = opOperand.getOwner();
  Location loc = op->getLoc();
  SmallVector<OpResult> aliasingOpResults =
      analysisState.getAliasingOpResult(opOperand);
  Value operand = opOperand.get();
  Value operandBuffer = lookupBuffer(rewriter, operand, options);

  // Can `operandBuffer` be used directly or do we need a copy?
  bool inplace =
      overrideInPlace != FORCE_OUT_OF_PLACE &&
      (overrideInPlace == FORCE_INPLACE || analysisState.isInPlace(opOperand));
  if (inplace)
    return operandBuffer;

  // Bufferizing out-of-place: Allocate a new buffer.
  // Move insertion point right after `operandBuffer`. That is where the
  // allocation should be inserted (in the absence of allocation hoisting).
  setInsertionPointAfter(rewriter, operandBuffer);
  // Allocate the result buffer. The buffer should be deallocated if the tensor
  // is not yielded and deallocs are enabled in general.
  bool dealloc = llvm::none_of(aliasingOpResults, [&](Value v) {
    return getAnalysisState().isTensorYielded(v);
  });
  FailureOr<Value> resultBuffer = createAlloc(
      rewriter, loc, operandBuffer, dealloc && getOptions().createDeallocs);
  if (failed(resultBuffer))
    return failure();
  // Do not copy the buffer if its contents are undefined.
  if (analysisState.hasUndefinedContents(&opOperand))
    return resultBuffer;
  // Do not copy if the copied data is never read.
  if (!aliasingOpResults.empty() &&
      !analysisState.bufferizesToMemoryRead(opOperand) &&
      llvm::none_of(aliasingOpResults, [&](OpResult opResult) {
        return analysisState.isValueRead(opResult);
      }))
    return resultBuffer;
  // Do not copy if this op does not read the data, but writes it.
  if (analysisState.bufferizesToMemoryWrite(opOperand) &&
      !analysisState.bufferizesToMemoryRead(opOperand))
    return resultBuffer;

  if (customCopyInsertionPoint) {
    rewriter.setInsertionPoint(*customCopyInsertionPoint);
  } else {
    // The copy happens right before the op that is bufferized.
    rewriter.setInsertionPoint(op);
  }
  if (failed(options.createMemCpy(rewriter, loc, operandBuffer, *resultBuffer)))
    return failure();

  return resultBuffer;
}
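
// Minimal usage sketch (hedged; the exact interface method signature is
// defined in the interface's .td file, not here): an op's `bufferize`
// implementation typically obtains the buffers of its tensor operands via
// `getBuffer`, builds the memref-based ops, and then calls
// `replaceOpWithBufferizedValues` (defined further below) to replace the
// tensor results.
//
//   LogicalResult bufferize(Operation *op, RewriterBase &rewriter,
//                           BufferizationState &state) {
//     FailureOr<Value> srcBuffer =
//         state.getBuffer(rewriter, op->getOpOperand(0));
//     if (failed(srcBuffer))
//       return failure();
//     // ... create memref ops that use *srcBuffer ...
//     replaceOpWithBufferizedValues(rewriter, op, /*values=*/{...});
//     return success();
//   }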

/// Return the buffer type for a given OpOperand (tensor) after bufferization.
BaseMemRefType BufferizationState::getBufferType(OpOperand &opOperand) const {
  Value tensor = opOperand.get();
  auto tensorType = tensor.getType().dyn_cast<TensorType>();
  assert(tensorType && "unexpected non-tensor type");

  if (auto toTensorOp = tensor.getDefiningOp<bufferization::ToTensorOp>())
    return toTensorOp.memref().getType().cast<BaseMemRefType>();

  return getMemRefType(tensorType, getOptions());
}

void bufferization::replaceOpWithBufferizedValues(RewriterBase &rewriter,
                                                  Operation *op,
                                                  ValueRange values) {
  assert(values.size() == op->getNumResults() &&
         "expected one value per OpResult");
  OpBuilder::InsertionGuard g(rewriter);

  // Replace all OpResults with the given values.
  SmallVector<Value> replacements;
  for (OpResult opResult : op->getOpResults()) {
    Value replacement = values[opResult.getResultNumber()];
    if (opResult.getType().isa<TensorType>()) {
      // The OpResult is a tensor. Such values are replaced with memrefs during
      // bufferization.
      assert((replacement.getType().isa<MemRefType>() ||
              replacement.getType().isa<UnrankedMemRefType>()) &&
             "tensor op result should be replaced with a memref value");
      // The existing uses of the OpResult still expect a tensor. Insert a
      // ToTensorOp. Throughout bufferization, this ToTensorOp will gradually
      // lose all of its users and eventually DCE away.
      rewriter.setInsertionPointAfter(op);
      replacement = rewriter.create<bufferization::ToTensorOp>(
          replacement.getLoc(), replacement);
    }
    replacements.push_back(replacement);
  }

  rewriter.replaceOp(op, replacements);
}
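
// Illustrative before/after of `replaceOpWithBufferizedValues` (op name and
// types are made up for this sketch): given a memref replacement value %m,
//
//   %0 = "test.produce_tensor"() : () -> tensor<4xf32>
//
// becomes
//
//   %0 = bufferization.to_tensor %m : memref<4xf32>
//
// and the remaining uses of %0 keep seeing a tensor until they are bufferized
// themselves, at which point the to_tensor op loses its last user and folds
// away.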

AlwaysCopyAnalysisState::AlwaysCopyAnalysisState(
    const BufferizationOptions &options)
    : AnalysisState(options) {
  // Note: Allocations must be deallocated with a subsequent run of the buffer
  // deallocation pass.
  assert(!options.createDeallocs &&
         "cannot create deallocs with AlwaysCopyBufferizationState");
}

/// Return `true` if the given OpOperand has been decided to bufferize inplace.
bool AlwaysCopyAnalysisState::isInPlace(OpOperand &opOperand) const {
  // OpOperands that bufferize to a memory write are out-of-place, i.e., an
  // alloc and copy is inserted.
  return !bufferizesToMemoryWrite(opOperand);
}

/// Return true if `v1` and `v2` bufferize to equivalent buffers.
bool AlwaysCopyAnalysisState::areEquivalentBufferizedValues(Value v1,
                                                            Value v2) const {
  // There is no analysis, so we do not know if the values are equivalent. The
  // conservative answer is "false".
  return false;
}

/// Return `true` if the given tensor has undefined contents.
bool AlwaysCopyAnalysisState::hasUndefinedContents(OpOperand *opOperand) const {
  // There is no analysis, so the conservative answer is "false".
  return false;
}

/// Return true if the given tensor (or an aliasing tensor) is yielded from
/// the containing block. Also include all aliasing tensors in the same block.
bool AlwaysCopyAnalysisState::isTensorYielded(Value tensor) const {
  // There is no analysis, so conservatively answer "true".
  return true;
}

//===----------------------------------------------------------------------===//
// Bufferization-specific scoped alloc/dealloc insertion support.
//===----------------------------------------------------------------------===//

/// Create a memref allocation with the given type and dynamic extents.
FailureOr<Value> BufferizationOptions::createAlloc(OpBuilder &b, Location loc,
                                                   MemRefType type,
                                                   ValueRange dynShape) const {
  if (allocationFn)
    return (*allocationFn)(b, loc, type, dynShape, bufferAlignment);

  // Default buffer allocation via AllocOp.
  Value allocated = b.create<memref::AllocOp>(
      loc, type, dynShape, b.getI64IntegerAttr(bufferAlignment));
  return allocated;
}

/// Creates a memref deallocation. The given memref buffer must have been
/// allocated using `createAlloc`.
LogicalResult BufferizationOptions::createDealloc(OpBuilder &b, Location loc,
                                                  Value allocatedBuffer) const {
  if (deallocationFn)
    return (*deallocationFn)(b, loc, allocatedBuffer);

  // Default buffer deallocation via DeallocOp.
  b.create<memref::DeallocOp>(loc, allocatedBuffer);
  return success();
}
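
// Illustrative default output of the two hooks above (the alignment value is
// only an example, not the default of `bufferAlignment`): without custom
// allocation/deallocation functions, they emit roughly
//
//   %buf = memref.alloc(%d0) {alignment = 64 : i64} : memref<?x16xf32>
//   ...
//   memref.dealloc %buf : memref<?x16xf32>
//
// where %d0 provides the dynamic extent of the first dimension.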

static MemRefType getContiguousMemRefType(ShapedType shapedType,
                                          Attribute memorySpace = {}) {
  MemRefLayoutAttrInterface layout = {};
  return MemRefType::get(shapedType.getShape(), shapedType.getElementType(),
                         layout, memorySpace);
}

/// Compute the type of the `memref` to use for allocating the buffer for
/// `shapedValue`. Also returns (by reference in `dynShape`) the values for the
/// dynamic dimensions in the returned `memref` type.
static MemRefType getAllocationTypeAndShape(OpBuilder &b, Location loc,
                                            Value shapedValue,
                                            SmallVectorImpl<Value> &dynShape) {
  MemRefType allocMemRefType =
      getContiguousMemRefType(shapedValue.getType().cast<ShapedType>());

  // Compute the dynamic part of the shape.
  bool reifiedShapes = false;
  if (auto rankedOp = dyn_cast_or_null<ReifyRankedShapedTypeOpInterface>(
          shapedValue.getDefiningOp())) {
    ReifiedRankedShapedTypeDims resultDims;
    if (succeeded(rankedOp.reifyResultShapes(b, resultDims))) {
      reifiedShapes = true;
      OpResult resultValue = shapedValue.dyn_cast<OpResult>();
      auto &shape = resultDims[resultValue.getResultNumber()];
      for (const auto &dim : enumerate(allocMemRefType.getShape()))
        if (ShapedType::isDynamic(dim.value()))
          dynShape.push_back(shape[dim.index()]);
    }
  }

  if (!reifiedShapes) {
    for (const auto &dim : enumerate(allocMemRefType.getShape()))
      if (ShapedType::isDynamic(dim.value())) {
        assert((shapedValue.getType().isa<UnrankedMemRefType>() ||
                shapedValue.getType().isa<MemRefType>()) &&
               "expected MemRef type");
        dynShape.push_back(
            b.create<memref::DimOp>(loc, shapedValue, dim.index()));
      }
  }

  return allocMemRefType;
}

static Value createBufferAllocation(OpBuilder &b, Location loc,
                                    MemRefType type, ValueRange dynShape,
                                    bool skipDealloc) {
  auto allocaOp = b.create<memref::AllocaOp>(loc, type, dynShape);
  allocaOp->setAttr(kBufferAllocationAttr, b.getUnitAttr());
  if (skipDealloc)
    allocaOp->setAttr(kSkipDeallocAttr, b.getUnitAttr());
  return allocaOp.getResult();
}

/// Create an allocation after `shapedValue.getDefiningOp` (or at the top of
/// the block in case of a bbArg).
FailureOr<Value> BufferizationState::createAlloc(OpBuilder &b, Location loc,
                                                 Value shapedValue,
                                                 Optional<bool> dealloc) {
  // Take a guard before anything else.
  OpBuilder::InsertionGuard g(b);

  // Compute allocation memref type.
  assert(shapedValue.getType().isa<ShapedType>());
  SmallVector<Value> dynShape;
  MemRefType allocMemRefType =
      getAllocationTypeAndShape(b, loc, shapedValue, dynShape);

  // Should the buffer be deallocated again or should we let it leak?
  bool skipDealloc;
  if (dealloc) {
    skipDealloc = !dealloc.getValue();
  } else {
    assert(shapedValue.getType().isa<TensorType>() &&
           "must specify `dealloc` if non-tensor value is passed");
    // The buffer should not be deallocated if deallocs are generally
    // deactivated or if the tensor is yielded from a block.
    skipDealloc = !getOptions().createDeallocs ||
                  getAnalysisState().isTensorYielded(shapedValue);
  }

  // Create the buffer allocation.
  return createBufferAllocation(b, loc, allocMemRefType, dynShape, skipDealloc);
}

/// Create a memory copy between two memref buffers.
LogicalResult BufferizationOptions::createMemCpy(OpBuilder &b, Location loc,
                                                 Value from, Value to) const {
  if (memCpyFn)
    return (*memCpyFn)(b, loc, from, to);

  b.create<memref::CopyOp>(loc, from, to);
  return success();
}
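
// Minimal sketch of overriding the allocation, deallocation, and copy hooks
// with custom callbacks (signatures inferred from the call sites in this file;
// the exact typedef names in BufferizationOptions may differ):
//
//   BufferizationOptions options;
//   options.allocationFn = [](OpBuilder &b, Location loc, MemRefType type,
//                             ValueRange dynShape,
//                             unsigned alignment) -> FailureOr<Value> {
//     return b.create<memref::AllocOp>(loc, type, dynShape,
//                                      b.getI64IntegerAttr(alignment))
//         .getResult();
//   };
//   options.deallocationFn = [](OpBuilder &b, Location loc,
//                               Value buffer) -> LogicalResult {
//     b.create<memref::DeallocOp>(loc, buffer);
//     return success();
//   };
//   options.memCpyFn = [](OpBuilder &b, Location loc, Value from,
//                         Value to) -> LogicalResult {
//     b.create<memref::CopyOp>(loc, from, to);
//     return success();
//   };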

LogicalResult
bufferization::createAllocDeallocOps(Operation *op,
                                     const BufferizationOptions &options,
                                     bool onlyLeakingAllocs, bool *changed) {
  IRRewriter rewriter(op->getContext());
  if (changed)
    *changed = false;

  // Bufferization creates memref.alloca ops. After bufferization, these must
  // be rewritten to alloc/dealloc ops as specified in the bufferization
  // options.
  WalkResult status = op->walk([&](memref::AllocaOp allocaOp) {
    // Ignore memref.alloca ops that were not created by the bufferization.
    if (!allocaOp->hasAttr(kBufferAllocationAttr))
      return WalkResult::skip();
    // If `onlyLeakingAllocs`, process only ops that are marked as
    // "skip dealloc".
    bool skipDealloc = allocaOp->hasAttr(kSkipDeallocAttr);
    if (onlyLeakingAllocs && !skipDealloc)
      return WalkResult::skip();

    // Create alloc.
    Block *block = allocaOp->getBlock();
    rewriter.setInsertionPoint(allocaOp);
    FailureOr<Value> alloc =
        options.createAlloc(rewriter, allocaOp->getLoc(), allocaOp.getType(),
                            allocaOp.dynamicSizes());
    if (failed(alloc))
      return WalkResult::interrupt();
    rewriter.replaceOp(allocaOp, *alloc);
    if (changed)
      *changed = true;

    // Stop here if the buffer should not be deallocated.
    if (skipDealloc)
      return WalkResult::advance();

    // Create dealloc.
    rewriter.setInsertionPoint(block->getTerminator());
    if (failed(options.createDealloc(rewriter, alloc->getLoc(), *alloc)))
      return WalkResult::interrupt();

    return WalkResult::advance();
  });

  return success(!status.wasInterrupted());
}

/// Try to hoist all new buffer allocations until the next hoisting barrier.
// TODO: Consolidate this function with the existing buffer hoisting pass.
LogicalResult
bufferization::hoistBufferAllocations(Operation *op,
                                      const BufferizationOptions &options) {
  // Nothing to do if allocation hoisting is deactivated.
  if (!options.hoistAllocations)
    return success();

  // Gather all buffer allocations that were created by the bufferization.
  SmallVector<Operation *> allocaOps;
  op->walk([&](memref::AllocaOp allocaOp) {
    if (allocaOp->hasAttr(kBufferAllocationAttr))
      allocaOps.push_back(allocaOp);
  });

  for (Operation *allocaOp : allocaOps) {
    // TODO: Hoisting of allocs with dynamic shape not implemented.
    if (!allocaOp->getOpOperands().empty())
      continue;

    Operation *op = allocaOp->getParentOp();
    while (op) {
      if (auto bufferizableOp = dyn_cast<BufferizableOpInterface>(op)) {
        if (bufferizableOp.isAllocationHoistingBarrier()) {
          break;
        }
      } else {
        // Op is not bufferizable: It may not be safe to hoist across this op.
        break;
      }
      op = op->getParentOp();
    }

    // FuncOp is an allocation hoisting barrier, so this should never happen.
    assert(op && "allocation hoisting barrier not found");

    // Nothing to do if the insertion point is in the same block.
    if (op == allocaOp->getParentOp())
      continue;

    // `op` may have multiple blocks. Make sure that we insert in the right
    // one.
    SmallVector<Block *> blocks;
    for (Region &r : op->getRegions())
      for (Block &b : r.getBlocks())
        blocks.push_back(&b);
    auto *insertionBlock = llvm::find_if(
        blocks, [&](Block *b) { return b->findAncestorOpInBlock(*allocaOp); });
    assert(insertionBlock != blocks.end() && "owning block not found");

    // Move to the beginning of the block.
    allocaOp->moveBefore(&(*insertionBlock)->front());
  }

  return success();
}
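
// Illustrative summary of the two post-processing steps above (IR is
// schematic): bufferization first creates marked allocas such as
//
//   %a = memref.alloca() {bufferization.allocation} : memref<4xf32>
//
// `hoistBufferAllocations` may then move such allocas up to the closest
// allocation hoisting barrier (e.g., out of a loop body), and
// `createAllocDeallocOps` finally rewrites them into the configured
// allocation, plus a deallocation before the block terminator unless the
// alloca also carries the "bufferization.skip_dealloc" marker:
//
//   %a = memref.alloc() : memref<4xf32>
//   ...
//   memref.dealloc %a : memref<4xf32>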

//===----------------------------------------------------------------------===//
// Bufferization-specific BlockAndValueMapping support with debugging.
//===----------------------------------------------------------------------===//

bool bufferization::isFunctionArgument(Value value) {
  auto bbArg = value.dyn_cast<BlockArgument>();
  if (!bbArg)
    return false;
  return isa<func::FuncOp>(bbArg.getOwner()->getParentOp());
}

BaseMemRefType bufferization::getMemRefType(TensorType tensorType,
                                            const BufferizationOptions &options,
                                            MemRefLayoutAttrInterface layout,
                                            Attribute memorySpace) {
  // Case 1: Unranked memref type.
  if (auto unrankedTensorType = tensorType.dyn_cast<UnrankedTensorType>()) {
    assert(!layout && "UnrankedTensorType cannot have a layout map");
    return UnrankedMemRefType::get(unrankedTensorType.getElementType(),
                                   memorySpace);
  }

  // Case 2: Ranked memref type with specified layout.
  auto rankedTensorType = tensorType.cast<RankedTensorType>();
  if (layout) {
    return MemRefType::get(rankedTensorType.getShape(),
                           rankedTensorType.getElementType(), layout,
                           memorySpace);
  }

  // Case 3: Configured with "fully dynamic layout maps".
  if (options.unknownTypeConversion ==
      BufferizationOptions::LayoutMapOption::FullyDynamicLayoutMap)
    return getMemRefTypeWithFullyDynamicLayout(tensorType, memorySpace);

  // Case 4: Configured with "static identity layout maps".
  if (options.unknownTypeConversion ==
      BufferizationOptions::LayoutMapOption::IdentityLayoutMap)
    return getMemRefTypeWithStaticIdentityLayout(tensorType, memorySpace);

  llvm_unreachable("InferLayoutMap is an invalid option");
}

BaseMemRefType
bufferization::getMemRefTypeWithFullyDynamicLayout(TensorType tensorType,
                                                   Attribute memorySpace) {
  // Case 1: Unranked memref type.
  if (auto unrankedTensorType = tensorType.dyn_cast<UnrankedTensorType>()) {
    return UnrankedMemRefType::get(unrankedTensorType.getElementType(),
                                   memorySpace);
  }

  // Case 2: Ranked memref type.
  auto rankedTensorType = tensorType.cast<RankedTensorType>();
  int64_t dynamicOffset = ShapedType::kDynamicStrideOrOffset;
  SmallVector<int64_t> dynamicStrides(rankedTensorType.getRank(),
                                      ShapedType::kDynamicStrideOrOffset);
  AffineMap stridedLayout = makeStridedLinearLayoutMap(
      dynamicStrides, dynamicOffset, rankedTensorType.getContext());
  return MemRefType::get(rankedTensorType.getShape(),
                         rankedTensorType.getElementType(), stridedLayout,
                         memorySpace);
}

/// Return a MemRef type with a static identity layout (i.e., no layout map).
/// If the given tensor type is unranked, return an unranked MemRef type.
BaseMemRefType
bufferization::getMemRefTypeWithStaticIdentityLayout(TensorType tensorType,
                                                     Attribute memorySpace) {
  // Case 1: Unranked memref type.
  if (auto unrankedTensorType = tensorType.dyn_cast<UnrankedTensorType>()) {
    return UnrankedMemRefType::get(unrankedTensorType.getElementType(),
                                   memorySpace);
  }

  // Case 2: Ranked memref type.
  auto rankedTensorType = tensorType.cast<RankedTensorType>();
  MemRefLayoutAttrInterface layout = {};
  return MemRefType::get(rankedTensorType.getShape(),
                         rankedTensorType.getElementType(), layout,
                         memorySpace);
}
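
// Illustrative result types of the two helpers above (the exact affine map
// printing may vary slightly): for tensor<4x?xf32>,
//
//   getMemRefTypeWithFullyDynamicLayout:
//     memref<4x?xf32, affine_map<(d0, d1)[s0, s1, s2] ->
//                                   (d0 * s1 + s0 + d1 * s2)>>
//   getMemRefTypeWithStaticIdentityLayout:
//     memref<4x?xf32>
//
// The fully dynamic variant keeps strides and offset symbolic, so buffers with
// non-identity layouts (e.g., subviews) can be passed where such a memref is
// expected.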