//===- BufferizableOpInterface.cpp - Bufferizable Ops --------------------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//

#include "mlir/Dialect/Bufferization/IR/BufferizableOpInterface.h"
#include "mlir/Dialect/Bufferization/IR/Bufferization.h"
#include "mlir/Dialect/MemRef/IR/MemRef.h"
#include "mlir/IR/AsmState.h"
#include "mlir/IR/BlockAndValueMapping.h"
#include "mlir/IR/BuiltinOps.h"
#include "mlir/IR/Operation.h"
#include "mlir/IR/TypeUtilities.h"
#include "mlir/IR/Value.h"
#include "llvm/Support/Debug.h"

namespace mlir {
namespace bufferization {

#include "mlir/Dialect/Bufferization/IR/BufferizableOpInterface.cpp.inc"

} // namespace bufferization
} // namespace mlir

#define DEBUG_TYPE "bufferizable-op-interface"
#define DBGS() (llvm::dbgs() << '[' << DEBUG_TYPE << "] ")
#define LDBG(X) LLVM_DEBUG(DBGS() << (X))

using namespace mlir;
using namespace bufferization;

/// Attribute name used to mark the bufferization layout for region
/// arguments during linalg comprehensive bufferization.
constexpr const ::llvm::StringLiteral
    bufferization::BufferizableOpInterface::kBufferLayoutAttrName;

/// Attribute name used to mark region arguments that can be bufferized
/// in-place during linalg comprehensive bufferization.
constexpr const ::llvm::StringLiteral
    bufferization::BufferizableOpInterface::kInplaceableAttrName;

/// Attribute name used to mark allocs that are created by the bufferization.
static const char *kBufferAllocationAttr = "bufferization.allocation";

/// Attribute name used to mark allocs that should not be deallocated.
static const char *kSkipDeallocAttr = "bufferization.skip_dealloc";

//===----------------------------------------------------------------------===//
// BufferizationOptions
//===----------------------------------------------------------------------===//

// Default constructor for BufferizationOptions.
BufferizationOptions::BufferizationOptions() = default;

BufferizableOpInterface
BufferizationOptions::dynCastBufferizableOp(Operation *op) const {
  if (isOpAllowed(op))
    return dyn_cast<BufferizableOpInterface>(op);
  return nullptr;
}

BufferizableOpInterface
BufferizationOptions::dynCastBufferizableOp(Value value) const {
  if (auto bufferizableOp = value.getDefiningOp<BufferizableOpInterface>())
    if (isOpAllowed(bufferizableOp.getOperation()))
      return bufferizableOp;
  return nullptr;
}

void BufferizationOptions::addDialectStateInitializer(
    StringRef name, const DialectStateInitFn &fn) {
  stateInitializers.push_back(
      [=](AnalysisState &state) { state.insertDialectState(name, fn()); });
}

//===----------------------------------------------------------------------===//
// Helper functions for BufferizableOpInterface
//===----------------------------------------------------------------------===//

static void setInsertionPointAfter(OpBuilder &b, Value value) {
  if (auto bbArg = value.dyn_cast<BlockArgument>()) {
    b.setInsertionPointToStart(bbArg.getOwner());
  } else {
    b.setInsertionPointAfter(value.getDefiningOp());
  }
}

/// Determine which OpOperands will alias with `result` if the op is bufferized
/// in place. Return an empty vector if the op is not bufferizable.
SmallVector<OpOperand *>
AnalysisState::getAliasingOpOperand(OpResult result) const {
  if (Operation *op = result.getDefiningOp())
    if (auto bufferizableOp = dyn_cast<BufferizableOpInterface>(op))
      return bufferizableOp.getAliasingOpOperand(result, *this);
  return {};
}

/// Determine which OpResults will alias with `opOperand` if the op is
/// bufferized in place. Return an empty vector if the op is not bufferizable.
SmallVector<OpResult>
AnalysisState::getAliasingOpResult(OpOperand &opOperand) const {
  if (auto bufferizableOp =
          dyn_cast<BufferizableOpInterface>(opOperand.getOwner()))
    return bufferizableOp.getAliasingOpResult(opOperand, *this);
  return {};
}

/// Return true if `opOperand` bufferizes to a memory read. Return `true` if
/// the op is not bufferizable.
bool AnalysisState::bufferizesToMemoryRead(OpOperand &opOperand) const {
  if (auto bufferizableOp =
          dyn_cast<BufferizableOpInterface>(opOperand.getOwner()))
    return bufferizableOp.bufferizesToMemoryRead(opOperand, *this);

  // Unknown op that returns a tensor. The inplace analysis does not support
  // it. Conservatively return true.
  return true;
}

/// Return true if `opOperand` bufferizes to a memory write. Return `true` if
/// the op is not bufferizable.
bool AnalysisState::bufferizesToMemoryWrite(OpOperand &opOperand) const {
  if (auto bufferizableOp =
          dyn_cast<BufferizableOpInterface>(opOperand.getOwner()))
    return bufferizableOp.bufferizesToMemoryWrite(opOperand, *this);

  // Unknown op that returns a tensor. The inplace analysis does not support
  // it. Conservatively return true.
  return true;
}
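
// Illustrative example (hypothetical IR; the concrete behavior is implemented
// by the respective ops' BufferizableOpInterface, not in this file): for
//   %r = tensor.insert_slice %src into %dst[0] [4] [1]
//       : tensor<4xf32> into tensor<16xf32>
// the %src operand bufferizes to a memory read, the %dst operand bufferizes to
// a memory write, and the aliasing OpOperand of %r is the %dst operand, i.e.,
// %r reuses the buffer of %dst when the op is bufferized in place.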

/// Return true if `opOperand` does neither read nor write but bufferizes to an
/// alias. Return false if the op is not bufferizable.
bool AnalysisState::bufferizesToAliasOnly(OpOperand &opOperand) const {
  if (auto bufferizableOp =
          dyn_cast<BufferizableOpInterface>(opOperand.getOwner()))
    return bufferizableOp.bufferizesToAliasOnly(opOperand, *this);

  // Unknown op that returns a tensor. The inplace analysis does not support
  // it. Conservatively return false.
  return false;
}

/// Return true if the given value is read by an op that bufferizes to a memory
/// read. Also takes into account ops that create an alias but do not read by
/// themselves (e.g., ExtractSliceOp).
bool AnalysisState::isValueRead(Value value) const {
  assert(value.getType().isa<TensorType>() && "expected TensorType");
  SmallVector<OpOperand *> workingSet;
  for (OpOperand &use : value.getUses())
    workingSet.push_back(&use);

  while (!workingSet.empty()) {
    OpOperand *uMaybeReading = workingSet.pop_back_val();
    // Skip over all ops that neither read nor write (but create an alias).
    if (bufferizesToAliasOnly(*uMaybeReading))
      for (OpResult opResult : getAliasingOpResult(*uMaybeReading))
        for (OpOperand &use : opResult.getUses())
          workingSet.push_back(&use);
    if (bufferizesToMemoryRead(*uMaybeReading))
      return true;
  }

  return false;
}

// Starting from `value`, follow the use-def chain in reverse, always selecting
// the aliasing OpOperands. Find and return Values for which `condition`
// evaluates to true. OpOperands of such matching Values are not traversed any
// further.
llvm::SetVector<Value> AnalysisState::findValueInReverseUseDefChain(
    Value value, llvm::function_ref<bool(Value)> condition) const {
  llvm::SetVector<Value> result, workingSet;
  workingSet.insert(value);

  while (!workingSet.empty()) {
    Value value = workingSet.pop_back_val();
    if (condition(value) || value.isa<BlockArgument>()) {
      result.insert(value);
      continue;
    }

    OpResult opResult = value.cast<OpResult>();
    SmallVector<OpOperand *> opOperands = getAliasingOpOperand(opResult);
    if (opOperands.empty() || !options.isOpAllowed(value.getDefiningOp())) {
      result.insert(value);
      continue;
    }

    for (OpOperand *o : opOperands)
      workingSet.insert(o->get());
  }

  return result;
}

// Find the Values of the last preceding write of a given Value.
llvm::SetVector<Value>
AnalysisState::findLastPrecedingWrite(Value value) const {
  return findValueInReverseUseDefChain(value, [&](Value value) {
    Operation *op = value.getDefiningOp();
    if (!op)
      return true;
    auto bufferizableOp = options.dynCastBufferizableOp(op);
    if (!bufferizableOp)
      return true;
    return bufferizableOp.isMemoryWrite(value.cast<OpResult>(), *this);
  });
}

AnalysisState::AnalysisState(const BufferizationOptions &options)
    : options(options) {
  for (const BufferizationOptions::AnalysisStateInitFn &fn :
       options.stateInitializers)
    fn(*this);
}

// bufferization.to_memref is not allowed to change the rank.
static void ensureToMemrefOpIsValid(Value tensor, Type memrefType) {
#ifndef NDEBUG
  auto rankedTensorType = tensor.getType().dyn_cast<RankedTensorType>();
  assert((!rankedTensorType || memrefType.cast<MemRefType>().getRank() ==
                                   rankedTensorType.getRank()) &&
         "to_memref would be invalid: mismatching ranks");
#endif
}
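
// Illustrative example (hypothetical IR): `lookupBuffer` below folds
//   %t = bufferization.to_tensor %m : memref<?xf32>
// directly to %m. For any other tensor SSA value it materializes
//   %m = bufferization.to_memref %t : memref<?xf32>
// right after the definition of the tensor value.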

Value mlir::bufferization::lookupBuffer(RewriterBase &rewriter, Value tensor,
                                        const BufferizationOptions &options) {
  auto tensorType = tensor.getType().dyn_cast<TensorType>();
  assert(tensorType && "unexpected non-tensor type");

  // Replace "%t = to_tensor %m" with %m.
  if (auto toTensorOp = tensor.getDefiningOp<bufferization::ToTensorOp>())
    return toTensorOp.memref();

  // Insert to_memref op.
  OpBuilder::InsertionGuard g(rewriter);
  setInsertionPointAfter(rewriter, tensor);
  Type memrefType = getMemRefType(tensorType, options);
  ensureToMemrefOpIsValid(tensor, memrefType);
  return rewriter.create<bufferization::ToMemrefOp>(tensor.getLoc(),
                                                    memrefType, tensor);
}

/// Return the buffer (memref) for a given OpOperand (tensor). Allocate a new
/// buffer and copy over data from the existing buffer if out-of-place
/// bufferization is necessary.
FailureOr<Value>
BufferizationState::getBuffer(RewriterBase &rewriter, OpOperand &opOperand,
                              bool forceInPlace,
                              Optional<Operation *> customCopyInsertionPoint) {
  const BufferizationOptions &options = analysisState.getOptions();
  OpBuilder::InsertionGuard guard(rewriter);
  Operation *op = opOperand.getOwner();
  Location loc = op->getLoc();
  SmallVector<OpResult> aliasingOpResults =
      analysisState.getAliasingOpResult(opOperand);
  Value operand = opOperand.get();
  Value operandBuffer = lookupBuffer(rewriter, operand, options);

  if (forceInPlace || analysisState.isInPlace(opOperand))
    return operandBuffer;

  // Bufferizing out-of-place: Allocate a new buffer.
  // Move insertion point right after `operandBuffer`. That is where the
  // allocation should be inserted (in the absence of allocation hoisting).
  setInsertionPointAfter(rewriter, operandBuffer);
  // Allocate the result buffer. The buffer should be deallocated if the tensor
  // is not yielded and deallocs are enabled in general.
  bool dealloc = llvm::none_of(aliasingOpResults, [&](Value v) {
    return getAnalysisState().isTensorYielded(v);
  });
  FailureOr<Value> resultBuffer = createAlloc(
      rewriter, loc, operandBuffer, dealloc && getOptions().createDeallocs);
  if (failed(resultBuffer))
    return failure();
  // Do not copy if the last preceding writes of `operand` are ops that do
  // not write (skipping ops that merely create aliases). E.g., InitTensorOp.
  // Note: If `findLastPrecedingWrite` reaches the end of the reverse SSA
  // use-def chain, it returns that value, regardless of whether it is a
  // memory write or not.
  SetVector<Value> lastWrites = analysisState.findLastPrecedingWrite(operand);
  if (llvm::none_of(lastWrites, [&](Value lastWrite) {
        if (auto bufferizableOp = options.dynCastBufferizableOp(lastWrite))
          return bufferizableOp.isMemoryWrite(lastWrite.cast<OpResult>(),
                                              analysisState);
        return true;
      }))
    return resultBuffer;
  // Do not copy if the copied data is never read.
  if (!aliasingOpResults.empty() &&
      !analysisState.bufferizesToMemoryRead(opOperand) &&
      llvm::none_of(aliasingOpResults, [&](OpResult opResult) {
        return analysisState.isValueRead(opResult);
      }))
    return resultBuffer;
  // Do not copy if this op does not read the data, but writes it.
  if (analysisState.bufferizesToMemoryWrite(opOperand) &&
      !analysisState.bufferizesToMemoryRead(opOperand))
    return resultBuffer;

  if (customCopyInsertionPoint) {
    rewriter.setInsertionPoint(*customCopyInsertionPoint);
  } else {
    // The copy happens right before the op that is bufferized.
    rewriter.setInsertionPoint(op);
  }
  if (failed(
          createMemCpy(rewriter, loc, operandBuffer, *resultBuffer, options)))
    return failure();

  return resultBuffer;
}
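
// Illustrative example (hypothetical IR): when a tensor-producing op is
// replaced via `replaceOpWithBufferizedValues` below, a result such as
//   %0 = "some.tensor_op"() : () -> tensor<4xf32>
// is replaced by
//   %1 = bufferization.to_tensor %buffer : memref<4xf32>
// where %buffer is the bufferized value; the inserted to_tensor op loses its
// users as the remaining ops are bufferized and is eventually DCE'd.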

void bufferization::replaceOpWithBufferizedValues(RewriterBase &rewriter,
                                                  Operation *op,
                                                  ValueRange values) {
  assert(values.size() == op->getNumResults() &&
         "expected one value per OpResult");
  OpBuilder::InsertionGuard g(rewriter);

  // Replace all OpResults with the given values.
  SmallVector<Value> replacements;
  for (OpResult opResult : op->getOpResults()) {
    Value replacement = values[opResult.getResultNumber()];
    if (opResult.getType().isa<TensorType>()) {
      // The OpResult is a tensor. Such values are replaced with memrefs during
      // bufferization.
      assert((replacement.getType().isa<MemRefType>() ||
              replacement.getType().isa<UnrankedMemRefType>()) &&
             "tensor op result should be replaced with a memref value");
      // The existing uses of the OpResult still expect a tensor. Insert a
      // ToTensorOp. Throughout bufferization, this ToTensorOp will gradually
      // lose all of its users and eventually DCE away.
      rewriter.setInsertionPointAfter(op);
      replacement = rewriter.create<bufferization::ToTensorOp>(
          replacement.getLoc(), replacement);
    }
    replacements.push_back(replacement);
  }

  rewriter.replaceOp(op, replacements);
}

AlwaysCopyAnalysisState::AlwaysCopyAnalysisState(
    const BufferizationOptions &options)
    : AnalysisState(options) {
  // Note: Allocations must be deallocated with a subsequent run of the buffer
  // deallocation pass.
  assert(!options.createDeallocs &&
         "cannot create deallocs with AlwaysCopyBufferizationState");
}

/// Return `true` if the given OpOperand has been decided to bufferize inplace.
bool AlwaysCopyAnalysisState::isInPlace(OpOperand &opOperand) const {
  // OpOperands that bufferize to a memory write are out-of-place, i.e., an
  // alloc and copy is inserted.
  return !bufferizesToMemoryWrite(opOperand);
}

/// Return true if `v1` and `v2` bufferize to equivalent buffers.
bool AlwaysCopyAnalysisState::areEquivalentBufferizedValues(Value v1,
                                                            Value v2) const {
  // There is no analysis, so we do not know if the values are equivalent. The
  // conservative answer is "false".
  return false;
}

/// Return true if the given tensor (or an aliasing tensor) is yielded from
/// the containing block. Also include all aliasing tensors in the same block.
bool AlwaysCopyAnalysisState::isTensorYielded(Value tensor) const {
  // There is no analysis, so conservatively answer "true".
  return true;
}

//===----------------------------------------------------------------------===//
// Bufferization-specific scoped alloc/dealloc insertion support.
//===----------------------------------------------------------------------===//

/// Create a memref allocation with the given type and dynamic extents.
static FailureOr<Value> createAlloc(OpBuilder &b, Location loc,
                                    MemRefType type, ValueRange dynShape,
                                    const BufferizationOptions &options) {
  if (options.allocationFn)
    return (*options.allocationFn)(b, loc, type, dynShape,
                                   options.bufferAlignment);

  // Default buffer allocation via AllocOp.
  Value allocated = b.create<memref::AllocOp>(
      loc, type, dynShape, b.getI64IntegerAttr(options.bufferAlignment));
  return allocated;
}
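
// Minimal sketch (not part of this file; the callback signature is assumed
// from the call above): clients can plug in a custom allocation scheme by
// setting `options.allocationFn`, e.g. mirroring the default behavior:
//   options.allocationFn = [](OpBuilder &b, Location loc, MemRefType type,
//                             ValueRange dynShape,
//                             unsigned alignment) -> FailureOr<Value> {
//     return b
//         .create<memref::AllocOp>(loc, type, dynShape,
//                                  b.getI64IntegerAttr(alignment))
//         .getResult();
//   };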

/// Create a memref deallocation. The given memref buffer must have been
/// allocated using `createAlloc`.
LogicalResult
bufferization::createDealloc(OpBuilder &b, Location loc, Value allocatedBuffer,
                             const BufferizationOptions &options) {
  if (options.deallocationFn)
    return (*options.deallocationFn)(b, loc, allocatedBuffer);

  // Default buffer deallocation via DeallocOp.
  b.create<memref::DeallocOp>(loc, allocatedBuffer);
  return success();
}

/// Compute the type of the memref to use for allocating the buffer for
/// `shapedValue`. Also return (by reference in `dynShape`) the values for the
/// dynamic dimensions in the returned memref type.
static MemRefType getAllocationTypeAndShape(OpBuilder &b, Location loc,
                                            Value shapedValue,
                                            SmallVectorImpl<Value> &dynShape) {
  MemRefType allocMemRefType =
      getContiguousMemRefType(shapedValue.getType().cast<ShapedType>());

  // Compute the dynamic part of the shape.
  bool reifiedShapes = false;
  if (auto rankedOp = dyn_cast_or_null<ReifyRankedShapedTypeOpInterface>(
          shapedValue.getDefiningOp())) {
    ReifiedRankedShapedTypeDims resultDims;
    if (succeeded(rankedOp.reifyResultShapes(b, resultDims))) {
      reifiedShapes = true;
      OpResult resultValue = shapedValue.dyn_cast<OpResult>();
      auto &shape = resultDims[resultValue.getResultNumber()];
      for (const auto &dim : enumerate(allocMemRefType.getShape()))
        if (ShapedType::isDynamic(dim.value()))
          dynShape.push_back(shape[dim.index()]);
    }
  }

  if (!reifiedShapes) {
    for (const auto &dim : enumerate(allocMemRefType.getShape()))
      if (ShapedType::isDynamic(dim.value())) {
        assert((shapedValue.getType().isa<UnrankedMemRefType>() ||
                shapedValue.getType().isa<MemRefType>()) &&
               "expected MemRef type");
        dynShape.push_back(
            b.create<memref::DimOp>(loc, shapedValue, dim.index()));
      }
  }

  return allocMemRefType;
}

static Value createBufferAllocation(OpBuilder &b, Location loc,
                                    MemRefType type, ValueRange dynShape,
                                    bool skipDealloc) {
  auto allocaOp = b.create<memref::AllocaOp>(loc, type, dynShape);
  allocaOp->setAttr(kBufferAllocationAttr, b.getUnitAttr());
  if (skipDealloc)
    allocaOp->setAttr(kSkipDeallocAttr, b.getUnitAttr());
  return allocaOp.getResult();
}
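
// Illustrative example (hypothetical IR): `createBufferAllocation` emits a
// marked placeholder such as
//   %0 = memref.alloca(%d) {bufferization.allocation,
//                           bufferization.skip_dealloc} : memref<?xf32>
// The marker attributes are consumed later by `createAllocDeallocOps`, which
// decides whether the placeholder becomes an alloc/dealloc pair or a buffer
// that is intentionally not deallocated.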

/// Create an allocation after `shapedValue.getDefiningOp` (or at the top of
/// the block in case of a bbArg).
FailureOr<Value> BufferizationState::createAlloc(OpBuilder &b, Location loc,
                                                 Value shapedValue,
                                                 Optional<bool> dealloc) {
  // Take a guard before anything else.
  OpBuilder::InsertionGuard g(b);

  // Compute allocation memref type.
  assert(shapedValue.getType().isa<ShapedType>());
  MemRefType memRefType = shapedValue.getType().dyn_cast<MemRefType>();
  SmallVector<Value> dynShape;
  MemRefType allocMemRefType =
      getAllocationTypeAndShape(b, loc, shapedValue, dynShape);

  // Should the buffer be deallocated again or should we let it leak?
  bool skipDealloc;
  if (dealloc) {
    skipDealloc = !dealloc.getValue();
  } else {
    assert(shapedValue.getType().isa<TensorType>() &&
           "must specify `dealloc` if non-tensor value is passed");
    // The buffer should not be deallocated if deallocs are generally
    // deactivated or if the tensor is yielded from a block.
    skipDealloc = !getOptions().createDeallocs ||
                  getAnalysisState().isTensorYielded(shapedValue);
  }

  // Create the buffer allocation.
  Value alloc =
      createBufferAllocation(b, loc, allocMemRefType, dynShape, skipDealloc);

  // Insert a cast if a different type was requested.
  if (memRefType && memRefType != allocMemRefType) {
    assert(memref::CastOp::areCastCompatible(allocMemRefType, memRefType) &&
           "createAlloc: cast incompatible");
    alloc = b.create<memref::CastOp>(loc, memRefType, alloc);
  }

  return alloc;
}

/// Create a memory copy between two memref buffers.
LogicalResult bufferization::createMemCpy(OpBuilder &b, Location loc,
                                          Value from, Value to,
                                          const BufferizationOptions &options) {
  if (options.memCpyFn)
    return (*options.memCpyFn)(b, loc, from, to);

  b.create<memref::CopyOp>(loc, from, to);
  return success();
}

LogicalResult
bufferization::createAllocDeallocOps(Operation *op,
                                     const BufferizationOptions &options,
                                     bool onlyLeakingAllocs) {
  IRRewriter rewriter(op->getContext());

  // Bufferization creates memref.alloca ops. After bufferization, these must
  // be rewritten to alloc/dealloc ops as specified in the bufferization
  // options.
  WalkResult status = op->walk([&](memref::AllocaOp allocaOp) {
    // Ignore memref.alloca ops that were not created by the bufferization.
    if (!allocaOp->hasAttr(kBufferAllocationAttr))
      return WalkResult::skip();
    // If `onlyLeakingAllocs`, process only ops that are marked as
    // "skip dealloc".
    bool skipDealloc = allocaOp->hasAttr(kSkipDeallocAttr);
    if (onlyLeakingAllocs && !skipDealloc)
      return WalkResult::skip();

    // Create alloc.
    Block *block = allocaOp->getBlock();
    rewriter.setInsertionPoint(allocaOp);
    FailureOr<Value> alloc =
        createAlloc(rewriter, allocaOp->getLoc(), allocaOp.getType(),
                    allocaOp.dynamicSizes(), options);
    if (failed(alloc))
      return WalkResult::interrupt();
    rewriter.replaceOp(allocaOp, *alloc);

    // Stop here if the buffer should not be deallocated.
    if (skipDealloc)
      return WalkResult::advance();

    // Create dealloc.
    rewriter.setInsertionPoint(block->getTerminator());
    if (failed(createDealloc(rewriter, alloc->getLoc(), *alloc, options)))
      return WalkResult::interrupt();

    return WalkResult::advance();
  });

  return success(!status.wasInterrupted());
}
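
// Illustrative before/after (hypothetical IR, assuming the default alloc and
// dealloc callbacks and no skip_dealloc marker): `createAllocDeallocOps`
// rewrites
//   %0 = memref.alloca() {bufferization.allocation} : memref<4xf32>
// into
//   %0 = memref.alloc() {alignment = ...} : memref<4xf32>
//   ...
//   memref.dealloc %0 : memref<4xf32>  // inserted before the block terminator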

/// Try to hoist all new buffer allocations until the next hoisting barrier.
// TODO: Consolidate this function with the existing buffer hoisting pass.
LogicalResult
bufferization::hoistBufferAllocations(Operation *op,
                                      const BufferizationOptions &options) {
  // Nothing to do if allocation hoisting is deactivated.
  if (!options.hoistAllocations)
    return success();

  // Gather all buffer allocations that were created by the bufferization.
  SmallVector<Operation *> allocaOps;
  op->walk([&](memref::AllocaOp allocaOp) {
    if (allocaOp->hasAttr(kBufferAllocationAttr))
      allocaOps.push_back(allocaOp);
  });

  for (Operation *allocaOp : allocaOps) {
    // TODO: Hoisting of allocs with dynamic shape not implemented.
    if (!allocaOp->getOpOperands().empty())
      continue;

    Operation *op = allocaOp->getParentOp();
    while (op) {
      if (auto bufferizableOp = dyn_cast<BufferizableOpInterface>(op)) {
        if (bufferizableOp.isAllocationHoistingBarrier()) {
          break;
        }
      } else {
        // Op is not bufferizable: It may not be safe to hoist across this op.
        break;
      }
      op = op->getParentOp();
    }

    // FuncOp is an allocation hoisting barrier, so this should never happen.
    assert(op && "allocation hoisting barrier not found");

    // Nothing to do if the insertion point is in the same block.
    if (op == allocaOp->getParentOp())
      continue;

    // `op` may have multiple blocks. Make sure that we insert in the right
    // one.
    SmallVector<Block *> blocks;
    for (Region &r : op->getRegions())
      for (Block &b : r.getBlocks())
        blocks.push_back(&b);
    auto *insertionBlock = llvm::find_if(
        blocks, [&](Block *b) { return b->findAncestorOpInBlock(*allocaOp); });
    assert(insertionBlock != blocks.end() && "owning block not found");

    // Move to the beginning of the block.
    allocaOp->moveBefore(&(*insertionBlock)->front());
  }

  return success();
}

//===----------------------------------------------------------------------===//
// Bufferization-specific BlockAndValueMapping support with debugging.
//===----------------------------------------------------------------------===//

bool bufferization::isFunctionArgument(Value value) {
  auto bbArg = value.dyn_cast<BlockArgument>();
  if (!bbArg)
    return false;
  return isa<FuncOp>(bbArg.getOwner()->getParentOp());
}

MemRefType bufferization::getContiguousMemRefType(ShapedType shapedType,
                                                  Attribute memorySpace) {
  MemRefLayoutAttrInterface layout = {};
  return MemRefType::get(shapedType.getShape(), shapedType.getElementType(),
                         layout, memorySpace);
}
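
// Illustrative examples (hypothetical types) for `getMemRefType` below:
//   * tensor<*xf32>                                 -> memref<*xf32>
//   * tensor<4x?xf32> with an explicit `layout` or
//     `fullyDynamicLayoutMaps = false`              -> memref<4x?xf32>
//                                                      (with that layout)
//   * tensor<4x?xf32> otherwise                     -> memref<4x?xf32, #map>,
//     where #map is a fully dynamic strided layout (dynamic offset/strides).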

BaseMemRefType bufferization::getMemRefType(TensorType tensorType,
                                            const BufferizationOptions &options,
                                            MemRefLayoutAttrInterface layout,
                                            Attribute memorySpace) {
  // Case 1: Unranked memref type.
  if (auto unrankedTensorType = tensorType.dyn_cast<UnrankedTensorType>()) {
    assert(!layout && "UnrankedTensorType cannot have a layout map");
    return UnrankedMemRefType::get(unrankedTensorType.getElementType(),
                                   memorySpace);
  }

  // Case 2: Ranked memref type with specified layout. If fully dynamic layout
  // maps are not requested, generate a type with `layout`, which is empty (no
  // layout map) by default.
  auto rankedTensorType = tensorType.cast<RankedTensorType>();
  if (layout || !options.fullyDynamicLayoutMaps) {
    return MemRefType::get(rankedTensorType.getShape(),
                           rankedTensorType.getElementType(), layout,
                           memorySpace);
  }

  // Case 3: Ranked memref type with unspecified layout. Choose the most
  // dynamic one.
  // TODO: address space decisions to connect with the actual alloc.
  int64_t dynamicOffset = ShapedType::kDynamicStrideOrOffset;
  SmallVector<int64_t> dynamicStrides(rankedTensorType.getRank(),
                                      ShapedType::kDynamicStrideOrOffset);
  AffineMap stridedLayout = makeStridedLinearLayoutMap(
      dynamicStrides, dynamicOffset, rankedTensorType.getContext());
  return MemRefType::get(rankedTensorType.getShape(),
                         rankedTensorType.getElementType(), stridedLayout,
                         memorySpace);
}