//===- BufferizableOpInterface.cpp - Bufferizable Ops ---=----------------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//

#include "mlir/Dialect/Bufferization/IR/BufferizableOpInterface.h"
#include "mlir/Dialect/Bufferization/IR/Bufferization.h"
#include "mlir/Dialect/MemRef/IR/MemRef.h"
#include "mlir/IR/AsmState.h"
#include "mlir/IR/BlockAndValueMapping.h"
#include "mlir/IR/BuiltinOps.h"
#include "mlir/IR/Operation.h"
#include "mlir/IR/TypeUtilities.h"
#include "mlir/IR/Value.h"
#include "llvm/Support/Debug.h"

namespace mlir {
namespace bufferization {

// Pull in the interface's `.cpp.inc` definitions inside the
// `mlir::bufferization` namespace.
#include "mlir/Dialect/Bufferization/IR/BufferizableOpInterface.cpp.inc"

} // namespace bufferization
} // namespace mlir

// Debug-output helpers: `LDBG(X)` streams `X` to `llvm::dbgs()` prefixed with
// "[bufferizable-op-interface] ", wrapped in LLVM_DEBUG so it only prints when
// debug output for DEBUG_TYPE is enabled.
#define DEBUG_TYPE "bufferizable-op-interface"
#define DBGS() (llvm::dbgs() << '[' << DEBUG_TYPE << "] ")
#define LDBG(X) LLVM_DEBUG(DBGS() << (X))

using namespace mlir;
using namespace bufferization;

/// Attribute name used to mark the bufferization layout for region
/// arguments during linalg comprehensive bufferization.
// Out-of-line definition of this static data member of
// BufferizableOpInterface.
constexpr const ::llvm::StringLiteral
    bufferization::BufferizableOpInterface::kBufferLayoutAttrName;

/// Attribute name used to mark region arguments that can be bufferized
/// in-place during linalg comprehensive bufferization.
42 constexpr const ::llvm::StringLiteral 43 bufferization::BufferizableOpInterface::kInplaceableAttrName; 44 45 static const char *kBufferAllocationAttr = "bufferization.allocation"; 46 47 //===----------------------------------------------------------------------===// 48 // BufferizationOptions 49 //===----------------------------------------------------------------------===// 50 51 // Default constructor for BufferizationOptions. 52 BufferizationOptions::BufferizationOptions() = default; 53 54 BufferizableOpInterface 55 BufferizationOptions::dynCastBufferizableOp(Operation *op) const { 56 if (isOpAllowed(op)) 57 return dyn_cast<BufferizableOpInterface>(op); 58 return nullptr; 59 } 60 61 BufferizableOpInterface 62 BufferizationOptions::dynCastBufferizableOp(Value value) const { 63 if (auto bufferizableOp = value.getDefiningOp<BufferizableOpInterface>()) 64 if (isOpAllowed(bufferizableOp.getOperation())) 65 return bufferizableOp; 66 return nullptr; 67 } 68 69 void BufferizationOptions::addDialectStateInitializer( 70 StringRef name, const DialectStateInitFn &fn) { 71 stateInitializers.push_back( 72 [=](AnalysisState &state) { state.insertDialectState(name, fn()); }); 73 } 74 75 //===----------------------------------------------------------------------===// 76 // Helper functions for BufferizableOpInterface 77 //===----------------------------------------------------------------------===// 78 79 static void setInsertionPointAfter(OpBuilder &b, Value value) { 80 if (auto bbArg = value.dyn_cast<BlockArgument>()) { 81 b.setInsertionPointToStart(bbArg.getOwner()); 82 } else { 83 b.setInsertionPointAfter(value.getDefiningOp()); 84 } 85 } 86 87 /// Determine which OpOperand* will alias with `result` if the op is bufferized 88 /// in place. Return an empty vector if the op is not bufferizable. 
89 SmallVector<OpOperand *> 90 AnalysisState::getAliasingOpOperand(OpResult result) const { 91 if (Operation *op = result.getDefiningOp()) 92 if (auto bufferizableOp = dyn_cast<BufferizableOpInterface>(op)) 93 return bufferizableOp.getAliasingOpOperand(result, *this); 94 return {}; 95 } 96 97 /// Determine which OpResult will alias with `opOperand` if the op is bufferized 98 /// in place. Return an empty vector if the op is not bufferizable. 99 SmallVector<OpResult> 100 AnalysisState::getAliasingOpResult(OpOperand &opOperand) const { 101 if (auto bufferizableOp = 102 dyn_cast<BufferizableOpInterface>(opOperand.getOwner())) 103 return bufferizableOp.getAliasingOpResult(opOperand, *this); 104 return {}; 105 } 106 107 /// Return true if `opOperand` bufferizes to a memory read. Return `true` if the 108 /// op is not bufferizable. 109 bool AnalysisState::bufferizesToMemoryRead(OpOperand &opOperand) const { 110 if (auto bufferizableOp = 111 dyn_cast<BufferizableOpInterface>(opOperand.getOwner())) 112 return bufferizableOp.bufferizesToMemoryRead(opOperand, *this); 113 114 // Unknown op that returns a tensor. The inplace analysis does not support it. 115 // Conservatively return true. 116 return true; 117 } 118 119 /// Return true if `opOperand` bufferizes to a memory write. Return 120 /// `true` if the op is not bufferizable. 121 bool AnalysisState::bufferizesToMemoryWrite(OpOperand &opOperand) const { 122 if (auto bufferizableOp = 123 dyn_cast<BufferizableOpInterface>(opOperand.getOwner())) 124 return bufferizableOp.bufferizesToMemoryWrite(opOperand, *this); 125 126 // Unknown op that returns a tensor. The inplace analysis does not support it. 127 // Conservatively return true. 128 return true; 129 } 130 131 /// Return true if `opOperand` does neither read nor write but bufferizes to an 132 /// alias. Return false if the op is not bufferizable. 
133 bool AnalysisState::bufferizesToAliasOnly(OpOperand &opOperand) const { 134 if (auto bufferizableOp = 135 dyn_cast<BufferizableOpInterface>(opOperand.getOwner())) 136 return bufferizableOp.bufferizesToAliasOnly(opOperand, *this); 137 138 // Unknown op that returns a tensor. The inplace analysis does not support it. 139 // Conservatively return false. 140 return false; 141 } 142 143 /// Return true if the given value is read by an op that bufferizes to a memory 144 /// read. Also takes into account ops that create an alias but do not read by 145 /// themselves (e.g., ExtractSliceOp). 146 bool AnalysisState::isValueRead(Value value) const { 147 assert(value.getType().isa<TensorType>() && "expected TensorType"); 148 SmallVector<OpOperand *> workingSet; 149 for (OpOperand &use : value.getUses()) 150 workingSet.push_back(&use); 151 152 while (!workingSet.empty()) { 153 OpOperand *uMaybeReading = workingSet.pop_back_val(); 154 // Skip over all ops that neither read nor write (but create an alias). 155 if (bufferizesToAliasOnly(*uMaybeReading)) 156 for (OpResult opResult : getAliasingOpResult(*uMaybeReading)) 157 for (OpOperand &use : opResult.getUses()) 158 workingSet.push_back(&use); 159 if (bufferizesToMemoryRead(*uMaybeReading)) 160 return true; 161 } 162 163 return false; 164 } 165 166 // Starting from `value`, follow the use-def chain in reverse, always selecting 167 // the aliasing OpOperands. Find and return Values for which `condition` 168 // evaluates to true. OpOperands of such matching Values are not traversed any 169 // further. 
170 llvm::SetVector<Value> AnalysisState::findValueInReverseUseDefChain( 171 Value value, llvm::function_ref<bool(Value)> condition) const { 172 llvm::SetVector<Value> result, workingSet; 173 workingSet.insert(value); 174 175 while (!workingSet.empty()) { 176 Value value = workingSet.pop_back_val(); 177 if (condition(value) || value.isa<BlockArgument>()) { 178 result.insert(value); 179 continue; 180 } 181 182 OpResult opResult = value.cast<OpResult>(); 183 SmallVector<OpOperand *> opOperands = getAliasingOpOperand(opResult); 184 if (opOperands.empty() || !options.isOpAllowed(value.getDefiningOp())) { 185 result.insert(value); 186 continue; 187 } 188 189 for (OpOperand *o : opOperands) 190 workingSet.insert(o->get()); 191 } 192 193 return result; 194 } 195 196 // Find the Values of the last preceding write of a given Value. 197 llvm::SetVector<Value> 198 AnalysisState::findLastPrecedingWrite(Value value) const { 199 return findValueInReverseUseDefChain(value, [&](Value value) { 200 Operation *op = value.getDefiningOp(); 201 if (!op) 202 return true; 203 auto bufferizableOp = options.dynCastBufferizableOp(op); 204 if (!bufferizableOp) 205 return true; 206 return bufferizableOp.isMemoryWrite(value.cast<OpResult>(), *this); 207 }); 208 } 209 210 AnalysisState::AnalysisState(const BufferizationOptions &options) 211 : options(options) { 212 for (const BufferizationOptions::AnalysisStateInitFn &fn : 213 options.stateInitializers) 214 fn(*this); 215 } 216 217 // bufferization.to_memref is not allowed to change the rank. 
// Debug-only check (the body is compiled out under NDEBUG): for a ranked
// tensor, `memrefType` must be a MemRefType of the same rank. Unranked tensors
// are not checked.
static void ensureToMemrefOpIsValid(Value tensor, Type memrefType) {
#ifndef NDEBUG
  auto rankedTensorType = tensor.getType().dyn_cast<RankedTensorType>();
  assert((!rankedTensorType || memrefType.cast<MemRefType>().getRank() ==
                                   rankedTensorType.getRank()) &&
         "to_memref would be invalid: mismatching ranks");
#endif
}

/// Return a buffer (memref) for the tensor value `tensor`.
/// - If `tensor` is produced by a `bufferization.to_tensor`, its memref operand
///   is returned directly.
/// - Otherwise, a `bufferization.to_memref` of type `getMemRefType(...)` is
///   created right after the definition of `tensor` (or at the start of its
///   block for a block argument).
/// The insertion point of `rewriter` is restored on return.
Value mlir::bufferization::lookupBuffer(RewriterBase &rewriter, Value tensor,
                                        const BufferizationOptions &options) {
  auto tensorType = tensor.getType().dyn_cast<TensorType>();
  assert(tensorType && "unexpected non-tensor type");

  // Replace "%t = to_tensor %m" with %m.
  if (auto toTensorOp = tensor.getDefiningOp<bufferization::ToTensorOp>())
    return toTensorOp.memref();

  // Insert to_memref op.
  OpBuilder::InsertionGuard g(rewriter);
  setInsertionPointAfter(rewriter, tensor);
  Type memrefType = getMemRefType(tensorType, options);
  ensureToMemrefOpIsValid(tensor, memrefType);
  return rewriter.create<bufferization::ToMemrefOp>(tensor.getLoc(), memrefType,
                                                    tensor);
}

/// Return the buffer (memref) to use for the given OpOperand (tensor).
/// - If the OpOperand bufferizes in-place, or `forceInPlace` is set, this is
///   the buffer of the operand itself.
/// - Otherwise, a new buffer is allocated right after the operand's buffer.
///   The existing data is copied into it unless the copy is known to be
///   unnecessary (see the early returns below).
/// The copy is inserted right before the owning op, or at
/// `customCopyInsertionPoint` if one is given. Returns failure if the
/// allocation or the copy could not be created.
FailureOr<Value>
BufferizationState::getBuffer(RewriterBase &rewriter, OpOperand &opOperand,
                              bool forceInPlace,
                              Optional<Operation *> customCopyInsertionPoint) {
  const BufferizationOptions &options = analysisState.getOptions();
  OpBuilder::InsertionGuard guard(rewriter);
  Operation *op = opOperand.getOwner();
  Location loc = op->getLoc();
  Value operand = opOperand.get();
  Value operandBuffer = lookupBuffer(rewriter, operand, options);

  if (forceInPlace || analysisState.isInPlace(opOperand))
    return operandBuffer;

  // Bufferizing out-of-place: Allocate a new buffer.
  // Move insertion point right after `operandBuffer`. That is where the
  // allocation should be inserted (in the absence of allocation hoisting).
  setInsertionPointAfter(rewriter, operandBuffer);
  // Allocate the result buffer.
  FailureOr<Value> resultBuffer = createAlloc(rewriter, loc, operandBuffer);
  if (failed(resultBuffer))
    return failure();
  // Do not copy if the last preceding writes of `operand` are ops that do
  // not write (skipping ops that merely create aliases). E.g., InitTensorOp.
  // Note: If `findLastPrecedingWrite` reaches the end of the reverse SSA
  // use-def chain, it returns that value, regardless of whether it is a
  // memory write or not.
  SetVector<Value> lastWrites = analysisState.findLastPrecedingWrite(operand);
  if (llvm::none_of(lastWrites, [&](Value lastWrite) {
        if (auto bufferizableOp = options.dynCastBufferizableOp(lastWrite))
          return bufferizableOp.isMemoryWrite(lastWrite.cast<OpResult>(),
                                              analysisState);
        return true;
      }))
    return resultBuffer;
  // Do not copy if the copied data is never read.
  SmallVector<OpResult> aliasingOpResults =
      analysisState.getAliasingOpResult(opOperand);
  if (!aliasingOpResults.empty() &&
      !analysisState.bufferizesToMemoryRead(opOperand) &&
      llvm::none_of(aliasingOpResults, [&](OpResult opResult) {
        return analysisState.isValueRead(opResult);
      }))
    return resultBuffer;
  // Do not copy if this op does not read the data, but writes it.
  if (analysisState.bufferizesToMemoryWrite(opOperand) &&
      !analysisState.bufferizesToMemoryRead(opOperand))
    return resultBuffer;

  if (customCopyInsertionPoint) {
    rewriter.setInsertionPoint(*customCopyInsertionPoint);
  } else {
    // The copy happens right before the op that is bufferized.
    rewriter.setInsertionPoint(op);
  }
  if (failed(
          createMemCpy(rewriter, loc, operandBuffer, *resultBuffer, options)))
    return failure();

  return resultBuffer;
}

/// Replace `op` with `values`, which must hold one value per OpResult of `op`.
/// Tensor OpResults must be replaced with memref values. Each such memref is
/// wrapped in a `bufferization.to_tensor` inserted right after `op`, so that
/// the remaining tensor users of `op` stay type-correct.
void bufferization::replaceOpWithBufferizedValues(RewriterBase &rewriter,
                                                  Operation *op,
                                                  ValueRange values) {
  assert(values.size() == op->getNumResults() &&
         "expected one value per OpResult");
  OpBuilder::InsertionGuard g(rewriter);

  // Replace all OpResults with the given values.
  SmallVector<Value> replacements;
  for (OpResult opResult : op->getOpResults()) {
    Value replacement = values[opResult.getResultNumber()];
    if (opResult.getType().isa<TensorType>()) {
      // The OpResult is a tensor. Such values are replaced with memrefs during
      // bufferization.
      assert((replacement.getType().isa<MemRefType>() ||
              replacement.getType().isa<UnrankedMemRefType>()) &&
             "tensor op result should be replaced with a memref value");
      // The existing uses of the OpResult still expect a tensor. Insert a
      // ToTensorOp. Throughout bufferization, this ToTensorOp will gradually
      // lose all of its users and eventually DCE away.
      rewriter.setInsertionPointAfter(op);
      replacement = rewriter.create<bufferization::ToTensorOp>(
          replacement.getLoc(), replacement);
    }
    replacements.push_back(replacement);
  }

  rewriter.replaceOp(op, replacements);
}

AlwaysCopyAnalysisState::AlwaysCopyAnalysisState(
    const BufferizationOptions &options)
    : AnalysisState(options) {}

/// Return `true` if the given OpOperand has been decided to bufferize in-place.
/// Without an analysis, only OpOperands that do not bufferize to a memory
/// write are considered in-place.
bool AlwaysCopyAnalysisState::isInPlace(OpOperand &opOperand) const {
  // OpOperands that bufferize to a memory write are out-of-place, i.e., an
  // alloc (and, unless getBuffer can elide it, a copy) is inserted.
  return !bufferizesToMemoryWrite(opOperand);
}

/// Return true if `v1` and `v2` bufferize to equivalent buffers.
352 bool AlwaysCopyAnalysisState::areEquivalentBufferizedValues(Value v1, 353 Value v2) const { 354 // There is no analysis, so we do not know if the values are equivalent. The 355 // conservative answer is "false". 356 return false; 357 } 358 359 //===----------------------------------------------------------------------===// 360 // Bufferization-specific scoped alloc/dealloc insertion support. 361 //===----------------------------------------------------------------------===// 362 363 /// Create a memref allocation with the given type and dynamic extents. 364 static FailureOr<Value> createAlloc(OpBuilder &b, Location loc, MemRefType type, 365 ValueRange dynShape, 366 const BufferizationOptions &options) { 367 if (options.allocationFn) 368 return (*options.allocationFn)(b, loc, type, dynShape, 369 options.bufferAlignment); 370 371 // Default bufferallocation via AllocOp. 372 Value allocated = b.create<memref::AllocOp>( 373 loc, type, dynShape, b.getI64IntegerAttr(options.bufferAlignment)); 374 return allocated; 375 } 376 377 /// Creates a memref deallocation. The given memref buffer must have been 378 /// allocated using `createAlloc`. 379 LogicalResult 380 bufferization::createDealloc(OpBuilder &b, Location loc, Value allocatedBuffer, 381 const BufferizationOptions &options) { 382 if (options.deallocationFn) 383 return (*options.deallocationFn)(b, loc, allocatedBuffer); 384 385 // Default buffer deallocation via DeallocOp. 386 b.create<memref::DeallocOp>(loc, allocatedBuffer); 387 return success(); 388 } 389 390 /// Compute the type of the `memref` to use for allocating the buffer for 391 /// `shapedValue`. Also returns (by reference in `dynShape`), the value for the 392 /// dynamic dimensions in the returned `memref` type. 
static MemRefType getAllocationTypeAndShape(OpBuilder &b, Location loc,
                                            Value shapedValue,
                                            SmallVectorImpl<Value> &dynShape) {
  // The allocation always uses a contiguous (layout-free) memref type with the
  // same shape and element type as `shapedValue`.
  MemRefType allocMemRefType =
      getContiguousMemRefType(shapedValue.getType().cast<ShapedType>());

  // Compute the dynamic part of the shape.
  // First choice: let the defining op reify its result shapes. Any IR this
  // needs is presumably created at `b`'s current insertion point (the builder
  // is passed to reifyResultShapes).
  bool reifiedShapes = false;
  if (auto rankedOp = dyn_cast_or_null<ReifyRankedShapedTypeOpInterface>(
          shapedValue.getDefiningOp())) {
    ReifiedRankedShapedTypeDims resultDims;
    if (succeeded(rankedOp.reifyResultShapes(b, resultDims))) {
      reifiedShapes = true;
      OpResult resultValue = shapedValue.dyn_cast<OpResult>();
      auto &shape = resultDims[resultValue.getResultNumber()];
      for (const auto &dim : enumerate(allocMemRefType.getShape()))
        if (ShapedType::isDynamic(dim.value()))
          dynShape.push_back(shape[dim.index()]);
    }
  }

  // Fallback: query every dynamic dimension with memref.dim. This requires
  // `shapedValue` to be a memref.
  if (!reifiedShapes) {
    for (const auto &dim : enumerate(allocMemRefType.getShape()))
      if (ShapedType::isDynamic(dim.value())) {
        assert((shapedValue.getType().isa<UnrankedMemRefType>() ||
                shapedValue.getType().isa<MemRefType>()) &&
               "expected MemRef type");
        dynShape.push_back(
            b.create<memref::DimOp>(loc, shapedValue, dim.index()));
      }
  }

  return allocMemRefType;
}

/// Create a `memref.alloca` of `type` and tag it with kBufferAllocationAttr.
/// These tagged allocas are placeholders: finalizeBuffers later hoists them
/// and rewrites them into the allocations requested by the options.
static Value createBufferAllocation(OpBuilder &b, Location loc, MemRefType type,
                                    ValueRange dynShape) {
  auto allocaOp = b.create<memref::AllocaOp>(loc, type, dynShape);
  allocaOp->setAttr(kBufferAllocationAttr, b.getUnitAttr());
  return allocaOp.getResult();
}

/// Create an allocation for `shapedValue` at the current insertion point of
/// `b`. Callers (e.g. getBuffer) position `b` right after the definition of
/// the value (or at the top of the block in case of a bbArg). If `shapedValue`
/// is a MemRefType that differs from the contiguous allocation type, the
/// allocation is `memref.cast` to it.
FailureOr<Value> BufferizationState::createAlloc(OpBuilder &b, Location loc,
                                                 Value shapedValue) {
  // Take a guard before anything else.
  OpBuilder::InsertionGuard g(b);
  assert(shapedValue.getType().isa<ShapedType>());
  MemRefType memRefType = shapedValue.getType().dyn_cast<MemRefType>();
  SmallVector<Value> dynShape;
  MemRefType allocMemRefType =
      getAllocationTypeAndShape(b, loc, shapedValue, dynShape);
  Value alloc = createBufferAllocation(b, loc, allocMemRefType, dynShape);
  if (memRefType && memRefType != allocMemRefType) {
    assert(memref::CastOp::areCastCompatible(alloc.getType(), memRefType) &&
           "createAlloc: cast incompatible");
    alloc = b.create<memref::CastOp>(loc, memRefType, alloc);
  }
  return alloc;
}

/// Create a memref allocation with the given type and dynamic extents.
/// The result is a tagged placeholder alloca (see createBufferAllocation).
FailureOr<Value> BufferizationState::createAlloc(OpBuilder &b, Location loc,
                                                 MemRefType type,
                                                 ValueRange dynShape) {
  return createBufferAllocation(b, loc, type, dynShape);
}

/// Create a memory copy between two memref buffers.
/// - If `options.memCpyFn` is set, the copy is delegated to it.
/// - Otherwise, a `memref.copy` is created.
LogicalResult bufferization::createMemCpy(OpBuilder &b, Location loc,
                                          Value from, Value to,
                                          const BufferizationOptions &options) {
  if (options.memCpyFn)
    return (*options.memCpyFn)(b, loc, from, to);

  b.create<memref::CopyOp>(loc, from, to);
  return success();
}

/// Rewrite every memref.alloca under `op` that carries kBufferAllocationAttr
/// into an allocation created by the static `createAlloc` helper. If
/// `options.createDeallocs` is set, also insert a deallocation right before
/// the terminator of the block containing the alloca. Fails if an allocation
/// or deallocation could not be created.
static LogicalResult
createAllocDeallocOps(Operation *op, const BufferizationOptions &options) {
  IRRewriter rewriter(op->getContext());

  // Bufferization creates memref.alloca ops. After bufferization, these must be
  // rewritten to alloc/dealloc ops as specified in the bufferization options.
  WalkResult status = op->walk([&](memref::AllocaOp allocaOp) {
    // Ignore memref.alloca ops that were not created by the bufferization.
    if (!allocaOp->hasAttr(kBufferAllocationAttr))
      return WalkResult::skip();

    // Remember the block before `allocaOp` is replaced (and erased) below.
    Block *block = allocaOp->getBlock();
    rewriter.setInsertionPoint(allocaOp);
    FailureOr<Value> alloc =
        createAlloc(rewriter, allocaOp->getLoc(), allocaOp.getType(),
                    allocaOp.dynamicSizes(), options);
    if (failed(alloc))
      return WalkResult::interrupt();
    rewriter.replaceOp(allocaOp, *alloc);

    // Stop here if deallocations are deactivated.
    if (!options.createDeallocs)
      return WalkResult::advance();

    rewriter.setInsertionPoint(block->getTerminator());
    if (failed(createDealloc(rewriter, alloc->getLoc(), *alloc, options)))
      return WalkResult::interrupt();

    return WalkResult::advance();
  });

  return success(!status.wasInterrupted());
}

/// Try to hoist all new buffer allocations until the next hoisting barrier.
/// Only tagged allocas without operands (i.e., static shapes) are hoisted. Each
/// one is moved to the front of the block of the closest enclosing op that is
/// either an allocation hoisting barrier or not bufferizable.
// TODO: Consolidate this function with the existing buffer hoisting pass.
static LogicalResult
hoistBufferAllocations(Operation *op, const BufferizationOptions &options) {
  // Nothing to do if allocation hoisting is deactivated.
  if (!options.hoistAllocations)
    return success();

  // Gather all buffer allocations that were created by the bufferization.
  SmallVector<Operation *> allocaOps;
  op->walk([&](memref::AllocaOp allocaOp) {
    if (allocaOp->hasAttr(kBufferAllocationAttr))
      allocaOps.push_back(allocaOp);
  });

  for (Operation *allocaOp : allocaOps) {
    // TODO: Hoisting of allocs with dynamic shape not implemented.
    if (!allocaOp->getOpOperands().empty())
      continue;

    // Note: this `op` shadows the function parameter. It walks up the
    // ancestors of `allocaOp` until a hoisting barrier is found.
    Operation *op = allocaOp->getParentOp();
    while (op) {
      if (auto bufferizableOp = dyn_cast<BufferizableOpInterface>(op)) {
        if (bufferizableOp.isAllocationHoistingBarrier()) {
          break;
        }
      } else {
        // Op is not bufferizable: It may not be safe to hoist across this op.
        break;
      }
      op = op->getParentOp();
    }

    // FuncOp is an allocation hoisting barrier, so this should never happen.
    assert(op && "allocation hoisting barrier not found");

    // Nothing to do if the insertion point is in the same block.
    if (op == allocaOp->getParentOp())
      continue;

    // `op` may have multiple blocks. Make sure that we insert in the right one.
    SmallVector<Block *> blocks;
    for (Region &r : op->getRegions())
      for (Block &b : r.getBlocks())
        blocks.push_back(&b);
    auto *insertionBlock = llvm::find_if(
        blocks, [&](Block *b) { return b->findAncestorOpInBlock(*allocaOp); });
    assert(insertionBlock != blocks.end() && "owning block not found");

    // Move to the beginning of the block.
    allocaOp->moveBefore(&(*insertionBlock)->front());
  }

  return success();
}

/// Post-process the buffer allocations created while bufferizing `op`:
/// 1. Hoist them (if enabled in `options`).
/// 2. Materialize the alloc/dealloc ops requested by `options`.
LogicalResult
bufferization::finalizeBuffers(Operation *op,
                               const BufferizationOptions &options) {
  if (failed(hoistBufferAllocations(op, options)))
    return failure();
  if (failed(createAllocDeallocOps(op, options)))
    return failure();

  return success();
}

//===----------------------------------------------------------------------===//
// Bufferization-specific helpers for function arguments and memref types.
576 //===----------------------------------------------------------------------===// 577 578 bool bufferization::isFunctionArgument(Value value) { 579 auto bbArg = value.dyn_cast<BlockArgument>(); 580 if (!bbArg) 581 return false; 582 return isa<FuncOp>(bbArg.getOwner()->getParentOp()); 583 } 584 585 MemRefType bufferization::getContiguousMemRefType(ShapedType shapedType, 586 Attribute memorySpace) { 587 MemRefLayoutAttrInterface layout = {}; 588 return MemRefType::get(shapedType.getShape(), shapedType.getElementType(), 589 layout, memorySpace); 590 } 591 592 BaseMemRefType bufferization::getMemRefType(TensorType tensorType, 593 const BufferizationOptions &options, 594 MemRefLayoutAttrInterface layout, 595 Attribute memorySpace) { 596 // Case 1: Unranked memref type. 597 if (auto unrankedTensorType = tensorType.dyn_cast<UnrankedTensorType>()) { 598 assert(!layout && "UnrankedTensorType cannot have a layout map"); 599 return UnrankedMemRefType::get(unrankedTensorType.getElementType(), 600 memorySpace); 601 } 602 603 // Case 2: Ranked memref type with specified layout. If fully dynamic layout 604 // maps are not requested, generate a type with `layout`, which is empty (no 605 // layout map) by default. 606 auto rankedTensorType = tensorType.cast<RankedTensorType>(); 607 if (layout || !options.fullyDynamicLayoutMaps) { 608 return MemRefType::get(rankedTensorType.getShape(), 609 rankedTensorType.getElementType(), layout, 610 memorySpace); 611 } 612 613 // Case 3: Ranked memref type with unspecified layout. Choose the most dynamic 614 // one. 615 // TODO: address space decisions to connect with the actual alloc. 
616 int64_t dynamicOffset = ShapedType::kDynamicStrideOrOffset; 617 SmallVector<int64_t> dynamicStrides(rankedTensorType.getRank(), 618 ShapedType::kDynamicStrideOrOffset); 619 AffineMap stridedLayout = makeStridedLinearLayoutMap( 620 dynamicStrides, dynamicOffset, rankedTensorType.getContext()); 621 return MemRefType::get(rankedTensorType.getShape(), 622 rankedTensorType.getElementType(), stridedLayout, 623 memorySpace); 624 } 625