//===- BufferizableOpInterface.cpp - Bufferizable Ops --------------------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//

#include "mlir/Dialect/Bufferization/IR/BufferizableOpInterface.h"
#include "mlir/Dialect/Bufferization/IR/Bufferization.h"
#include "mlir/Dialect/Func/IR/FuncOps.h"
#include "mlir/Dialect/MemRef/IR/MemRef.h"
#include "mlir/IR/AsmState.h"
#include "mlir/IR/BlockAndValueMapping.h"
#include "mlir/IR/BuiltinOps.h"
#include "mlir/IR/Operation.h"
#include "mlir/IR/TypeUtilities.h"
#include "mlir/IR/Value.h"
#include "llvm/Support/Debug.h"

namespace mlir {
namespace bufferization {

#include "mlir/Dialect/Bufferization/IR/BufferizableOpInterface.cpp.inc"

} // namespace bufferization
} // namespace mlir

#define DEBUG_TYPE "bufferizable-op-interface"
#define DBGS() (llvm::dbgs() << '[' << DEBUG_TYPE << "] ")
#define LDBG(X) LLVM_DEBUG(DBGS() << (X))

using namespace mlir;
using namespace bufferization;

/// Attribute name used to mark the bufferization layout for region
/// arguments during linalg comprehensive bufferization.
constexpr const ::llvm::StringLiteral
    bufferization::BufferizableOpInterface::kBufferLayoutAttrName;

/// Attribute name used to mark region arguments that can be bufferized
/// in-place during linalg comprehensive bufferization.
constexpr const ::llvm::StringLiteral
    bufferization::BufferizableOpInterface::kInplaceableAttrName;

/// Attribute name used to mark allocs that are created by the bufferization.
static const char *kBufferAllocationAttr = "bufferization.allocation";

/// Attribute name used to mark allocs that should not be deallocated.
static const char *kSkipDeallocAttr = "bufferization.skip_dealloc";

//===----------------------------------------------------------------------===//
// BufferizationOptions
//===----------------------------------------------------------------------===//

// Default constructor for BufferizationOptions.
BufferizationOptions::BufferizationOptions() = default;

BufferizableOpInterface
BufferizationOptions::dynCastBufferizableOp(Operation *op) const {
  if (isOpAllowed(op))
    return dyn_cast<BufferizableOpInterface>(op);
  return nullptr;
}

BufferizableOpInterface
BufferizationOptions::dynCastBufferizableOp(Value value) const {
  if (auto bufferizableOp = value.getDefiningOp<BufferizableOpInterface>())
    if (isOpAllowed(bufferizableOp.getOperation()))
      return bufferizableOp;
  return nullptr;
}

void BufferizationOptions::addDialectStateInitializer(
    StringRef name, const DialectStateInitFn &fn) {
  stateInitializers.push_back(
      [=](AnalysisState &state) { state.insertDialectState(name, fn()); });
}

//===----------------------------------------------------------------------===//
// Helper functions for BufferizableOpInterface
//===----------------------------------------------------------------------===//

static void setInsertionPointAfter(OpBuilder &b, Value value) {
  if (auto bbArg = value.dyn_cast<BlockArgument>()) {
    b.setInsertionPointToStart(bbArg.getOwner());
  } else {
    b.setInsertionPointAfter(value.getDefiningOp());
  }
}

/// Determine which OpOperands will alias with `result` if the op is
/// bufferized in place. Return an empty vector if the op is not bufferizable.
SmallVector<OpOperand *>
AnalysisState::getAliasingOpOperand(OpResult result) const {
  if (Operation *op = result.getDefiningOp())
    if (auto bufferizableOp = dyn_cast<BufferizableOpInterface>(op))
      return bufferizableOp.getAliasingOpOperand(result, *this);
  return {};
}

/// Determine which OpResults will alias with `opOperand` if the op is
/// bufferized in place. Return an empty vector if the op is not bufferizable.
SmallVector<OpResult>
AnalysisState::getAliasingOpResult(OpOperand &opOperand) const {
  if (auto bufferizableOp =
          dyn_cast<BufferizableOpInterface>(opOperand.getOwner()))
    return bufferizableOp.getAliasingOpResult(opOperand, *this);
  return {};
}

/// Return true if `opOperand` bufferizes to a memory read. Return `true` if
/// the op is not bufferizable.
bool AnalysisState::bufferizesToMemoryRead(OpOperand &opOperand) const {
  if (auto bufferizableOp =
          dyn_cast<BufferizableOpInterface>(opOperand.getOwner()))
    return bufferizableOp.bufferizesToMemoryRead(opOperand, *this);

  // Unknown op that returns a tensor. The inplace analysis does not support
  // it. Conservatively return true.
  return true;
}

/// Return true if `opOperand` bufferizes to a memory write. Return `true` if
/// the op is not bufferizable.
bool AnalysisState::bufferizesToMemoryWrite(OpOperand &opOperand) const {
  if (auto bufferizableOp =
          dyn_cast<BufferizableOpInterface>(opOperand.getOwner()))
    return bufferizableOp.bufferizesToMemoryWrite(opOperand, *this);

  // Unknown op that returns a tensor. The inplace analysis does not support
  // it. Conservatively return true.
  return true;
}

/// Return true if `opOperand` does neither read nor write but bufferizes to
/// an alias. Return false if the op is not bufferizable.
bool AnalysisState::bufferizesToAliasOnly(OpOperand &opOperand) const {
  if (auto bufferizableOp =
          dyn_cast<BufferizableOpInterface>(opOperand.getOwner()))
    return bufferizableOp.bufferizesToAliasOnly(opOperand, *this);

  // Unknown op that returns a tensor. The inplace analysis does not support
  // it. Conservatively return false.
  return false;
}

/// Return true if the given value is read by an op that bufferizes to a
/// memory read. Also takes into account ops that create an alias but do not
/// read by themselves (e.g., ExtractSliceOp).
bool AnalysisState::isValueRead(Value value) const {
  assert(value.getType().isa<TensorType>() && "expected TensorType");
  SmallVector<OpOperand *> workingSet;
  for (OpOperand &use : value.getUses())
    workingSet.push_back(&use);

  while (!workingSet.empty()) {
    OpOperand *uMaybeReading = workingSet.pop_back_val();
    // Skip over all ops that neither read nor write (but create an alias).
    if (bufferizesToAliasOnly(*uMaybeReading))
      for (OpResult opResult : getAliasingOpResult(*uMaybeReading))
        for (OpOperand &use : opResult.getUses())
          workingSet.push_back(&use);
    if (bufferizesToMemoryRead(*uMaybeReading))
      return true;
  }

  return false;
}

// Starting from `value`, follow the use-def chain in reverse, always
// selecting the aliasing OpOperands. Find and return Values for which
// `condition` evaluates to true. OpOperands of such matching Values are not
// traversed any further.
llvm::SetVector<Value> AnalysisState::findValueInReverseUseDefChain(
    Value value, llvm::function_ref<bool(Value)> condition) const {
  llvm::SetVector<Value> result, workingSet;
  workingSet.insert(value);

  while (!workingSet.empty()) {
    Value value = workingSet.pop_back_val();
    if (condition(value) || value.isa<BlockArgument>()) {
      result.insert(value);
      continue;
    }

    OpResult opResult = value.cast<OpResult>();
    SmallVector<OpOperand *> opOperands = getAliasingOpOperand(opResult);
    if (opOperands.empty() || !options.isOpAllowed(value.getDefiningOp())) {
      result.insert(value);
      continue;
    }

    for (OpOperand *o : opOperands)
      workingSet.insert(o->get());
  }

  return result;
}

// Find the Values of the last preceding write of a given Value.
llvm::SetVector<Value>
AnalysisState::findLastPrecedingWrite(Value value) const {
  return findValueInReverseUseDefChain(value, [&](Value value) {
    Operation *op = value.getDefiningOp();
    if (!op)
      return true;
    auto bufferizableOp = options.dynCastBufferizableOp(op);
    if (!bufferizableOp)
      return true;
    return bufferizableOp.isMemoryWrite(value.cast<OpResult>(), *this);
  });
}

AnalysisState::AnalysisState(const BufferizationOptions &options)
    : options(options) {
  for (const BufferizationOptions::AnalysisStateInitFn &fn :
       options.stateInitializers)
    fn(*this);
}

// bufferization.to_memref is not allowed to change the rank.
static void ensureToMemrefOpIsValid(Value tensor, Type memrefType) {
#ifndef NDEBUG
  auto rankedTensorType = tensor.getType().dyn_cast<RankedTensorType>();
  assert((!rankedTensorType || memrefType.cast<MemRefType>().getRank() ==
                                   rankedTensorType.getRank()) &&
         "to_memref would be invalid: mismatching ranks");
#endif
}

Value mlir::bufferization::lookupBuffer(RewriterBase &rewriter, Value tensor,
                                        const BufferizationOptions &options) {
  auto tensorType = tensor.getType().dyn_cast<TensorType>();
  assert(tensorType && "unexpected non-tensor type");

  // Replace "%t = to_tensor %m" with %m.
  if (auto toTensorOp = tensor.getDefiningOp<bufferization::ToTensorOp>())
    return toTensorOp.memref();

  // Insert to_memref op.
  OpBuilder::InsertionGuard g(rewriter);
  setInsertionPointAfter(rewriter, tensor);
  Type memrefType = getMemRefType(tensorType, options);
  ensureToMemrefOpIsValid(tensor, memrefType);
  return rewriter.create<bufferization::ToMemrefOp>(tensor.getLoc(),
                                                    memrefType, tensor);
}

/// Return the buffer (memref) for a given OpOperand (tensor). Allocate
/// a new buffer and copy over data from the existing buffer if out-of-place
/// bufferization was decided.
FailureOr<Value>
BufferizationState::getBuffer(RewriterBase &rewriter, OpOperand &opOperand,
                              Optional<ForceInPlacability> overrideInPlace,
                              Optional<Operation *> customCopyInsertionPoint) {
  const BufferizationOptions &options = analysisState.getOptions();
  OpBuilder::InsertionGuard guard(rewriter);
  Operation *op = opOperand.getOwner();
  Location loc = op->getLoc();
  SmallVector<OpResult> aliasingOpResults =
      analysisState.getAliasingOpResult(opOperand);
  Value operand = opOperand.get();
  Value operandBuffer = lookupBuffer(rewriter, operand, options);

  // Can `operandBuffer` be used directly or do we need a copy?
  bool inplace =
      overrideInPlace != FORCE_OUT_OF_PLACE &&
      (overrideInPlace == FORCE_INPLACE || analysisState.isInPlace(opOperand));
  if (inplace)
    return operandBuffer;

  // Bufferizing out-of-place: Allocate a new buffer.
  // Move insertion point right after `operandBuffer`. That is where the
  // allocation should be inserted (in the absence of allocation hoisting).
  setInsertionPointAfter(rewriter, operandBuffer);
  // Allocate the result buffer. The buffer should be deallocated if the
  // tensor is not yielded and deallocs are enabled in general.
  bool dealloc = llvm::none_of(aliasingOpResults, [&](Value v) {
    return getAnalysisState().isTensorYielded(v);
  });
  FailureOr<Value> resultBuffer = createAlloc(
      rewriter, loc, operandBuffer, dealloc && getOptions().createDeallocs);
  if (failed(resultBuffer))
    return failure();
  // Do not copy if the last preceding writes of `operand` are ops that do
  // not write (skipping ops that merely create aliases). E.g., InitTensorOp.
  // Note: If `findLastPrecedingWrite` reaches the end of the reverse SSA
  // use-def chain, it returns that value, regardless of whether it is a
  // memory write or not.
  SetVector<Value> lastWrites = analysisState.findLastPrecedingWrite(operand);
  if (llvm::none_of(lastWrites, [&](Value lastWrite) {
        if (auto bufferizableOp = options.dynCastBufferizableOp(lastWrite))
          return bufferizableOp.isMemoryWrite(lastWrite.cast<OpResult>(),
                                              analysisState);
        return true;
      }))
    return resultBuffer;
  // Do not copy if the copied data is never read.
  if (!aliasingOpResults.empty() &&
      !analysisState.bufferizesToMemoryRead(opOperand) &&
      llvm::none_of(aliasingOpResults, [&](OpResult opResult) {
        return analysisState.isValueRead(opResult);
      }))
    return resultBuffer;
  // Do not copy if this op does not read the data, but writes it.
  if (analysisState.bufferizesToMemoryWrite(opOperand) &&
      !analysisState.bufferizesToMemoryRead(opOperand))
    return resultBuffer;

  if (customCopyInsertionPoint) {
    rewriter.setInsertionPoint(*customCopyInsertionPoint);
  } else {
    // The copy happens right before the op that is bufferized.
    rewriter.setInsertionPoint(op);
  }
  if (failed(
          createMemCpy(rewriter, loc, operandBuffer, *resultBuffer, options)))
    return failure();

  return resultBuffer;
}

/// Return the buffer type for a given OpOperand (tensor) after bufferization.
BaseMemRefType BufferizationState::getBufferType(OpOperand &opOperand) const {
  Value tensor = opOperand.get();
  auto tensorType = tensor.getType().dyn_cast<TensorType>();
  assert(tensorType && "unexpected non-tensor type");

  if (auto toTensorOp = tensor.getDefiningOp<bufferization::ToTensorOp>())
    return toTensorOp.memref().getType().cast<BaseMemRefType>();

  return getMemRefType(tensorType, getOptions());
}

void bufferization::replaceOpWithBufferizedValues(RewriterBase &rewriter,
                                                  Operation *op,
                                                  ValueRange values) {
  assert(values.size() == op->getNumResults() &&
         "expected one value per OpResult");
  OpBuilder::InsertionGuard g(rewriter);

  // Replace all OpResults with the given values.
  SmallVector<Value> replacements;
  for (OpResult opResult : op->getOpResults()) {
    Value replacement = values[opResult.getResultNumber()];
    if (opResult.getType().isa<TensorType>()) {
      // The OpResult is a tensor. Such values are replaced with memrefs
      // during bufferization.
      assert((replacement.getType().isa<MemRefType>() ||
              replacement.getType().isa<UnrankedMemRefType>()) &&
             "tensor op result should be replaced with a memref value");
      // The existing uses of the OpResult still expect a tensor. Insert a
      // ToTensorOp. Throughout bufferization, this ToTensorOp will gradually
      // lose all of its users and eventually DCE away.
      rewriter.setInsertionPointAfter(op);
      replacement = rewriter.create<bufferization::ToTensorOp>(
          replacement.getLoc(), replacement);
    }
    replacements.push_back(replacement);
  }

  rewriter.replaceOp(op, replacements);
}

AlwaysCopyAnalysisState::AlwaysCopyAnalysisState(
    const BufferizationOptions &options)
    : AnalysisState(options) {
  // Note: Allocations must be deallocated with a subsequent run of the buffer
  // deallocation pass.
  assert(!options.createDeallocs &&
         "cannot create deallocs with AlwaysCopyAnalysisState");
}

/// Return `true` if the given OpOperand has been decided to bufferize
/// in-place.
bool AlwaysCopyAnalysisState::isInPlace(OpOperand &opOperand) const {
  // OpOperands that bufferize to a memory write are out-of-place, i.e., an
  // alloc and copy is inserted.
  return !bufferizesToMemoryWrite(opOperand);
}

/// Return true if `v1` and `v2` bufferize to equivalent buffers.
bool AlwaysCopyAnalysisState::areEquivalentBufferizedValues(Value v1,
                                                            Value v2) const {
  // There is no analysis, so we do not know if the values are equivalent.
  // The conservative answer is "false".
  return false;
}

/// Return true if the given tensor (or an aliasing tensor) is yielded from
/// the containing block. Also include all aliasing tensors in the same block.
bool AlwaysCopyAnalysisState::isTensorYielded(Value tensor) const {
  // There is no analysis, so conservatively answer "true".
  return true;
}

//===----------------------------------------------------------------------===//
// Bufferization-specific scoped alloc/dealloc insertion support.
//===----------------------------------------------------------------------===//

/// Create a memref allocation with the given type and dynamic extents.
static FailureOr<Value> createAlloc(OpBuilder &b, Location loc,
                                    MemRefType type, ValueRange dynShape,
                                    const BufferizationOptions &options) {
  if (options.allocationFn)
    return (*options.allocationFn)(b, loc, type, dynShape,
                                   options.bufferAlignment);

  // Default buffer allocation via AllocOp.
  Value allocated = b.create<memref::AllocOp>(
      loc, type, dynShape, b.getI64IntegerAttr(options.bufferAlignment));
  return allocated;
}

/// Create a memref deallocation. The given memref buffer must have been
/// allocated using `createAlloc`.
LogicalResult
bufferization::createDealloc(OpBuilder &b, Location loc, Value allocatedBuffer,
                             const BufferizationOptions &options) {
  if (options.deallocationFn)
    return (*options.deallocationFn)(b, loc, allocatedBuffer);

  // Default buffer deallocation via DeallocOp.
  b.create<memref::DeallocOp>(loc, allocatedBuffer);
  return success();
}

/// Compute the type of the `memref` to use for allocating the buffer for
/// `shapedValue`. Also returns (by reference in `dynShape`) the values for
/// the dynamic dimensions in the returned `memref` type.
static MemRefType getAllocationTypeAndShape(OpBuilder &b, Location loc,
                                            Value shapedValue,
                                            SmallVectorImpl<Value> &dynShape) {
  MemRefType allocMemRefType =
      getContiguousMemRefType(shapedValue.getType().cast<ShapedType>());

  // Compute the dynamic part of the shape.
  bool reifiedShapes = false;
  if (auto rankedOp = dyn_cast_or_null<ReifyRankedShapedTypeOpInterface>(
          shapedValue.getDefiningOp())) {
    ReifiedRankedShapedTypeDims resultDims;
    if (succeeded(rankedOp.reifyResultShapes(b, resultDims))) {
      reifiedShapes = true;
      OpResult resultValue = shapedValue.dyn_cast<OpResult>();
      auto &shape = resultDims[resultValue.getResultNumber()];
      for (const auto &dim : enumerate(allocMemRefType.getShape()))
        if (ShapedType::isDynamic(dim.value()))
          dynShape.push_back(shape[dim.index()]);
    }
  }

  if (!reifiedShapes) {
    for (const auto &dim : enumerate(allocMemRefType.getShape()))
      if (ShapedType::isDynamic(dim.value())) {
        assert((shapedValue.getType().isa<UnrankedMemRefType>() ||
                shapedValue.getType().isa<MemRefType>()) &&
               "expected MemRef type");
        dynShape.push_back(
            b.create<memref::DimOp>(loc, shapedValue, dim.index()));
      }
  }

  return allocMemRefType;
}

static Value createBufferAllocation(OpBuilder &b, Location loc,
                                    MemRefType type, ValueRange dynShape,
                                    bool skipDealloc) {
  auto allocaOp = b.create<memref::AllocaOp>(loc, type, dynShape);
  allocaOp->setAttr(kBufferAllocationAttr, b.getUnitAttr());
  if (skipDealloc)
    allocaOp->setAttr(kSkipDeallocAttr, b.getUnitAttr());
  return allocaOp.getResult();
}

/// Create an allocation after `shapedValue.getDefiningOp` (or at the top of
/// the block in case of a bbArg).
FailureOr<Value> BufferizationState::createAlloc(OpBuilder &b, Location loc,
                                                 Value shapedValue,
                                                 Optional<bool> dealloc) {
  // Take a guard before anything else.
  OpBuilder::InsertionGuard g(b);

  // Compute allocation memref type.
  assert(shapedValue.getType().isa<ShapedType>());
  SmallVector<Value> dynShape;
  MemRefType allocMemRefType =
      getAllocationTypeAndShape(b, loc, shapedValue, dynShape);

  // Should the buffer be deallocated again or should we let it leak?
  bool skipDealloc;
  if (dealloc) {
    skipDealloc = !dealloc.getValue();
  } else {
    assert(shapedValue.getType().isa<TensorType>() &&
           "must specify `dealloc` if non-tensor value is passed");
    // The buffer should not be deallocated if deallocs are generally
    // deactivated or if the tensor is yielded from a block.
    skipDealloc = !getOptions().createDeallocs ||
                  getAnalysisState().isTensorYielded(shapedValue);
  }

  // Create the buffer allocation.
  return createBufferAllocation(b, loc, allocMemRefType, dynShape,
                                skipDealloc);
}

/// Create a memory copy between two memref buffers.
LogicalResult bufferization::createMemCpy(OpBuilder &b, Location loc,
                                          Value from, Value to,
                                          const BufferizationOptions &options) {
  if (options.memCpyFn)
    return (*options.memCpyFn)(b, loc, from, to);

  b.create<memref::CopyOp>(loc, from, to);
  return success();
}

LogicalResult
bufferization::createAllocDeallocOps(Operation *op,
                                     const BufferizationOptions &options,
                                     bool onlyLeakingAllocs, bool *changed) {
  IRRewriter rewriter(op->getContext());
  if (changed)
    *changed = false;

  // Bufferization creates memref.alloca ops. After bufferization, these must
  // be rewritten to alloc/dealloc ops as specified in the bufferization
  // options.
  WalkResult status = op->walk([&](memref::AllocaOp allocaOp) {
    // Ignore memref.alloca ops that were not created by the bufferization.
    if (!allocaOp->hasAttr(kBufferAllocationAttr))
      return WalkResult::skip();
    // If `onlyLeakingAllocs`, process only ops that are marked as
    // "skip dealloc".
    bool skipDealloc = allocaOp->hasAttr(kSkipDeallocAttr);
    if (onlyLeakingAllocs && !skipDealloc)
      return WalkResult::skip();

    // Create alloc.
    Block *block = allocaOp->getBlock();
    rewriter.setInsertionPoint(allocaOp);
    FailureOr<Value> alloc =
        createAlloc(rewriter, allocaOp->getLoc(), allocaOp.getType(),
                    allocaOp.dynamicSizes(), options);
    if (failed(alloc))
      return WalkResult::interrupt();
    rewriter.replaceOp(allocaOp, *alloc);
    if (changed)
      *changed = true;

    // Stop here if the buffer should not be deallocated.
    if (skipDealloc)
      return WalkResult::advance();

    // Create dealloc.
    rewriter.setInsertionPoint(block->getTerminator());
    if (failed(createDealloc(rewriter, alloc->getLoc(), *alloc, options)))
      return WalkResult::interrupt();

    return WalkResult::advance();
  });

  return success(!status.wasInterrupted());
}

/// Try to hoist all new buffer allocations until the next hoisting barrier.
// TODO: Consolidate this function with the existing buffer hoisting pass.
LogicalResult
bufferization::hoistBufferAllocations(Operation *op,
                                      const BufferizationOptions &options) {
  // Nothing to do if allocation hoisting is deactivated.
  if (!options.hoistAllocations)
    return success();

  // Gather all buffer allocations that were created by the bufferization.
  SmallVector<Operation *> allocaOps;
  op->walk([&](memref::AllocaOp allocaOp) {
    if (allocaOp->hasAttr(kBufferAllocationAttr))
      allocaOps.push_back(allocaOp);
  });

  for (Operation *allocaOp : allocaOps) {
    // TODO: Hoisting of allocs with dynamic shape not implemented.
    if (!allocaOp->getOpOperands().empty())
      continue;

    Operation *op = allocaOp->getParentOp();
    while (op) {
      if (auto bufferizableOp = dyn_cast<BufferizableOpInterface>(op)) {
        if (bufferizableOp.isAllocationHoistingBarrier()) {
          break;
        }
      } else {
        // Op is not bufferizable: It may not be safe to hoist across this op.
        break;
      }
      op = op->getParentOp();
    }

    // FuncOp is an allocation hoisting barrier, so this should never happen.
    assert(op && "allocation hoisting barrier not found");

    // Nothing to do if the insertion point is in the same block.
    if (op == allocaOp->getParentOp())
      continue;

    // `op` may have multiple blocks. Make sure that we insert in the right
    // one.
    SmallVector<Block *> blocks;
    for (Region &r : op->getRegions())
      for (Block &b : r.getBlocks())
        blocks.push_back(&b);
    auto *insertionBlock = llvm::find_if(
        blocks, [&](Block *b) { return b->findAncestorOpInBlock(*allocaOp); });
    assert(insertionBlock != blocks.end() && "owning block not found");

    // Move to the beginning of the block.
    allocaOp->moveBefore(&(*insertionBlock)->front());
  }

  return success();
}

//===----------------------------------------------------------------------===//
// Bufferization-specific BlockAndValueMapping support with debugging.
//===----------------------------------------------------------------------===//

bool bufferization::isFunctionArgument(Value value) {
  auto bbArg = value.dyn_cast<BlockArgument>();
  if (!bbArg)
    return false;
  return isa<func::FuncOp>(bbArg.getOwner()->getParentOp());
}

MemRefType bufferization::getContiguousMemRefType(ShapedType shapedType,
                                                  Attribute memorySpace) {
  MemRefLayoutAttrInterface layout = {};
  return MemRefType::get(shapedType.getShape(), shapedType.getElementType(),
                         layout, memorySpace);
}

BaseMemRefType bufferization::getMemRefType(TensorType tensorType,
                                            const BufferizationOptions &options,
                                            MemRefLayoutAttrInterface layout,
                                            Attribute memorySpace) {
  // Case 1: Unranked memref type.
  if (auto unrankedTensorType = tensorType.dyn_cast<UnrankedTensorType>()) {
    assert(!layout && "UnrankedTensorType cannot have a layout map");
    return UnrankedMemRefType::get(unrankedTensorType.getElementType(),
                                   memorySpace);
  }

  // Case 2: Ranked memref type with specified layout. If fully dynamic layout
  // maps are not requested, generate a type with `layout`, which is empty (no
  // layout map) by default.
  auto rankedTensorType = tensorType.cast<RankedTensorType>();
  if (layout || !options.fullyDynamicLayoutMaps) {
    return MemRefType::get(rankedTensorType.getShape(),
                           rankedTensorType.getElementType(), layout,
                           memorySpace);
  }

  // Case 3: Ranked memref type with unspecified layout. Choose the most
  // dynamic one.
  // TODO: address space decisions to connect with the actual alloc.
  int64_t dynamicOffset = ShapedType::kDynamicStrideOrOffset;
  SmallVector<int64_t> dynamicStrides(rankedTensorType.getRank(),
                                      ShapedType::kDynamicStrideOrOffset);
  AffineMap stridedLayout = makeStridedLinearLayoutMap(
      dynamicStrides, dynamicOffset, rankedTensorType.getContext());
  return MemRefType::get(rankedTensorType.getShape(),
                         rankedTensorType.getElementType(), stridedLayout,
                         memorySpace);
}