1 //===- BufferizableOpInterface.cpp - Bufferizable Ops ---=----------------===// 2 // 3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. 4 // See https://llvm.org/LICENSE.txt for license information. 5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception 6 // 7 //===----------------------------------------------------------------------===// 8 9 #include "mlir/Dialect/Bufferization/IR/BufferizableOpInterface.h" 10 #include "mlir/Dialect/Bufferization/IR/Bufferization.h" 11 #include "mlir/Dialect/Func/IR/FuncOps.h" 12 #include "mlir/Dialect/MemRef/IR/MemRef.h" 13 #include "mlir/IR/AsmState.h" 14 #include "mlir/IR/BlockAndValueMapping.h" 15 #include "mlir/IR/BuiltinOps.h" 16 #include "mlir/IR/Operation.h" 17 #include "mlir/IR/TypeUtilities.h" 18 #include "mlir/IR/Value.h" 19 #include "llvm/Support/Debug.h" 20 21 namespace mlir { 22 namespace bufferization { 23 24 #include "mlir/Dialect/Bufferization/IR/BufferizableOpInterface.cpp.inc" 25 26 } // namespace bufferization 27 } // namespace mlir 28 29 #define DEBUG_TYPE "bufferizable-op-interface" 30 #define DBGS() (llvm::dbgs() << '[' << DEBUG_TYPE << "] ") 31 #define LDBG(X) LLVM_DEBUG(DBGS() << (X)) 32 33 using namespace mlir; 34 using namespace bufferization; 35 36 /// Attribute name used to mark region arguments that can be bufferized 37 /// in-place during linalg comprehensive bufferization. 38 constexpr const ::llvm::StringLiteral 39 bufferization::BufferizableOpInterface::kInplaceableAttrName; 40 41 /// Attribute name used to mark allocs that are created by the bufferization. 42 static const char *kBufferAllocationAttr = "bufferization.allocation"; 43 44 /// Attribute name used to mark allocs that should not be deallocated. 45 static const char *kSkipDeallocAttr = "bufferization.skip_dealloc"; 46 47 //===----------------------------------------------------------------------===// 48 // OpFilter 49 //===----------------------------------------------------------------------===// 50 51 bool OpFilter::isOpAllowed(Operation *op) const { 52 // All other ops: Allow/disallow according to filter. 53 bool isAllowed = !hasAllowRule(); 54 for (const Entry &entry : entries) { 55 bool filterResult = entry.fn(op); 56 switch (entry.type) { 57 case Entry::ALLOW: 58 isAllowed |= filterResult; 59 break; 60 case Entry::DENY: 61 if (filterResult) 62 // DENY filter matches. This op is no allowed. (Even if other ALLOW 63 // filters may match.) 64 return false; 65 }; 66 } 67 return isAllowed; 68 } 69 70 //===----------------------------------------------------------------------===// 71 // BufferizationOptions 72 //===----------------------------------------------------------------------===// 73 74 // Default constructor for BufferizationOptions. 75 BufferizationOptions::BufferizationOptions() = default; 76 77 bool BufferizationOptions::isOpAllowed(Operation *op) const { 78 // Special case: If function boundary bufferization is deactivated, do not 79 // allow ops that belong to the `func` dialect. 80 bool isFuncBoundaryOp = isa_and_nonnull<func::FuncDialect>(op->getDialect()); 81 if (!bufferizeFunctionBoundaries && isFuncBoundaryOp) 82 return false; 83 84 return opFilter.isOpAllowed(op); 85 } 86 87 BufferizableOpInterface 88 BufferizationOptions::dynCastBufferizableOp(Operation *op) const { 89 auto bufferizableOp = dyn_cast<BufferizableOpInterface>(op); 90 if (!bufferizableOp) 91 return nullptr; 92 if (!isOpAllowed(op)) 93 return nullptr; 94 return bufferizableOp; 95 } 96 97 BufferizableOpInterface 98 BufferizationOptions::dynCastBufferizableOp(Value value) const { 99 if (auto bufferizableOp = value.getDefiningOp<BufferizableOpInterface>()) 100 if (isOpAllowed(bufferizableOp.getOperation())) 101 return bufferizableOp; 102 return nullptr; 103 } 104 105 void BufferizationOptions::addDialectStateInitializer( 106 StringRef name, const DialectStateInitFn &fn) { 107 stateInitializers.push_back( 108 [=](AnalysisState &state) { state.insertDialectState(name, fn()); }); 109 } 110 111 //===----------------------------------------------------------------------===// 112 // Helper functions for BufferizableOpInterface 113 //===----------------------------------------------------------------------===// 114 115 static void setInsertionPointAfter(OpBuilder &b, Value value) { 116 if (auto bbArg = value.dyn_cast<BlockArgument>()) { 117 b.setInsertionPointToStart(bbArg.getOwner()); 118 } else { 119 b.setInsertionPointAfter(value.getDefiningOp()); 120 } 121 } 122 123 /// Determine which OpOperand* will alias with `result` if the op is bufferized 124 /// in place. Return an empty vector if the op is not bufferizable. 125 SmallVector<OpOperand *> 126 AnalysisState::getAliasingOpOperand(OpResult result) const { 127 if (Operation *op = result.getDefiningOp()) 128 if (auto bufferizableOp = getOptions().dynCastBufferizableOp(op)) 129 return bufferizableOp.getAliasingOpOperand(result, *this); 130 return {}; 131 } 132 133 /// Determine which OpResult will alias with `opOperand` if the op is bufferized 134 /// in place. Return an empty vector if the op is not bufferizable. 135 SmallVector<OpResult> 136 AnalysisState::getAliasingOpResult(OpOperand &opOperand) const { 137 if (auto bufferizableOp = 138 getOptions().dynCastBufferizableOp(opOperand.getOwner())) 139 return bufferizableOp.getAliasingOpResult(opOperand, *this); 140 return {}; 141 } 142 143 /// Return true if `opOperand` bufferizes to a memory read. Return `true` if the 144 /// op is not bufferizable. 145 bool AnalysisState::bufferizesToMemoryRead(OpOperand &opOperand) const { 146 if (auto bufferizableOp = 147 getOptions().dynCastBufferizableOp(opOperand.getOwner())) 148 return bufferizableOp.bufferizesToMemoryRead(opOperand, *this); 149 150 // Unknown op that returns a tensor. The inplace analysis does not support it. 151 // Conservatively return true. 152 return true; 153 } 154 155 /// Return true if `opOperand` bufferizes to a memory write. Return 156 /// `true` if the op is not bufferizable. 157 bool AnalysisState::bufferizesToMemoryWrite(OpOperand &opOperand) const { 158 if (auto bufferizableOp = 159 getOptions().dynCastBufferizableOp(opOperand.getOwner())) 160 return bufferizableOp.bufferizesToMemoryWrite(opOperand, *this); 161 162 // Unknown op that returns a tensor. The inplace analysis does not support it. 163 // Conservatively return true. 164 return true; 165 } 166 167 /// Return true if `opOperand` does neither read nor write but bufferizes to an 168 /// alias. Return false if the op is not bufferizable. 169 bool AnalysisState::bufferizesToAliasOnly(OpOperand &opOperand) const { 170 if (auto bufferizableOp = 171 getOptions().dynCastBufferizableOp(opOperand.getOwner())) 172 return bufferizableOp.bufferizesToAliasOnly(opOperand, *this); 173 174 // Unknown op that returns a tensor. The inplace analysis does not support it. 175 // Conservatively return false. 176 return false; 177 } 178 179 /// Return true if the given value is read by an op that bufferizes to a memory 180 /// read. Also takes into account ops that create an alias but do not read by 181 /// themselves (e.g., ExtractSliceOp). 182 bool AnalysisState::isValueRead(Value value) const { 183 assert(value.getType().isa<TensorType>() && "expected TensorType"); 184 SmallVector<OpOperand *> workingSet; 185 for (OpOperand &use : value.getUses()) 186 workingSet.push_back(&use); 187 188 while (!workingSet.empty()) { 189 OpOperand *uMaybeReading = workingSet.pop_back_val(); 190 // Skip over all ops that neither read nor write (but create an alias). 191 if (bufferizesToAliasOnly(*uMaybeReading)) 192 for (OpResult opResult : getAliasingOpResult(*uMaybeReading)) 193 for (OpOperand &use : opResult.getUses()) 194 workingSet.push_back(&use); 195 if (bufferizesToMemoryRead(*uMaybeReading)) 196 return true; 197 } 198 199 return false; 200 } 201 202 // Starting from `value`, follow the use-def chain in reverse, always selecting 203 // the aliasing OpOperands. Find and return Values for which `condition` 204 // evaluates to true. OpOperands of such matching Values are not traversed any 205 // further. 206 llvm::SetVector<Value> AnalysisState::findValueInReverseUseDefChain( 207 Value value, llvm::function_ref<bool(Value)> condition) const { 208 llvm::SetVector<Value> result, workingSet; 209 workingSet.insert(value); 210 211 while (!workingSet.empty()) { 212 Value value = workingSet.pop_back_val(); 213 if (condition(value) || value.isa<BlockArgument>()) { 214 result.insert(value); 215 continue; 216 } 217 218 OpResult opResult = value.cast<OpResult>(); 219 SmallVector<OpOperand *> opOperands = getAliasingOpOperand(opResult); 220 if (opOperands.empty() || !options.isOpAllowed(value.getDefiningOp())) { 221 result.insert(value); 222 continue; 223 } 224 225 for (OpOperand *o : opOperands) 226 workingSet.insert(o->get()); 227 } 228 229 return result; 230 } 231 232 // Find the Values of the last preceding write of a given Value. 233 llvm::SetVector<Value> 234 AnalysisState::findLastPrecedingWrite(Value value) const { 235 return findValueInReverseUseDefChain(value, [&](Value value) { 236 Operation *op = value.getDefiningOp(); 237 if (!op) 238 return true; 239 auto bufferizableOp = options.dynCastBufferizableOp(op); 240 if (!bufferizableOp) 241 return true; 242 return bufferizableOp.isMemoryWrite(value.cast<OpResult>(), *this); 243 }); 244 } 245 246 AnalysisState::AnalysisState(const BufferizationOptions &options) 247 : options(options) { 248 for (const BufferizationOptions::AnalysisStateInitFn &fn : 249 options.stateInitializers) 250 fn(*this); 251 } 252 253 // bufferization.to_memref is not allowed to change the rank. 254 static void ensureToMemrefOpIsValid(Value tensor, Type memrefType) { 255 #ifndef NDEBUG 256 auto rankedTensorType = tensor.getType().dyn_cast<RankedTensorType>(); 257 assert((!rankedTensorType || memrefType.cast<MemRefType>().getRank() == 258 rankedTensorType.getRank()) && 259 "to_memref would be invalid: mismatching ranks"); 260 #endif 261 } 262 263 Value mlir::bufferization::lookupBuffer(RewriterBase &rewriter, Value tensor, 264 const BufferizationOptions &options) { 265 auto tensorType = tensor.getType().dyn_cast<TensorType>(); 266 assert(tensorType && "unexpected non-tensor type"); 267 268 // Replace "%t = to_tensor %m" with %m. 269 if (auto toTensorOp = tensor.getDefiningOp<bufferization::ToTensorOp>()) 270 return toTensorOp.memref(); 271 272 // Insert to_memref op. 273 OpBuilder::InsertionGuard g(rewriter); 274 setInsertionPointAfter(rewriter, tensor); 275 Type memrefType = getMemRefType(tensorType, options); 276 ensureToMemrefOpIsValid(tensor, memrefType); 277 return rewriter.create<bufferization::ToMemrefOp>(tensor.getLoc(), memrefType, 278 tensor); 279 } 280 281 /// Return the buffer (memref) for a given OpOperand (tensor). Allocate 282 /// a new buffer and copy over data from the existing buffer if out-of-place 283 /// bufferization was decided. 284 FailureOr<Value> 285 BufferizationState::getBuffer(RewriterBase &rewriter, OpOperand &opOperand, 286 Optional<ForceInPlacability> overrideInPlace, 287 Optional<Operation *> customCopyInsertionPoint) { 288 const BufferizationOptions &options = analysisState.getOptions(); 289 OpBuilder::InsertionGuard guard(rewriter); 290 Operation *op = opOperand.getOwner(); 291 Location loc = op->getLoc(); 292 SmallVector<OpResult> aliasingOpResults = 293 analysisState.getAliasingOpResult(opOperand); 294 Value operand = opOperand.get(); 295 Value operandBuffer = lookupBuffer(rewriter, operand, options); 296 297 // Can `operandBuffer` be used directly or do we need a copy? 298 bool inplace = 299 overrideInPlace != FORCE_OUT_OF_PLACE && 300 (overrideInPlace == FORCE_INPLACE || analysisState.isInPlace(opOperand)); 301 if (inplace) 302 return operandBuffer; 303 304 // Bufferizing out-of-place: Allocate a new buffer. 305 // Move insertion point right after `operandBuffer`. That is where the 306 // allocation should be inserted (in the absence of allocation hoisting). 307 setInsertionPointAfter(rewriter, operandBuffer); 308 // Allocate the result buffer. The buffer should be deallocated if the tensor 309 // is not yielded and deallocs are enabled in general. 310 bool dealloc = llvm::none_of(aliasingOpResults, [&](Value v) { 311 return getAnalysisState().isTensorYielded(v); 312 }); 313 FailureOr<Value> resultBuffer = createAlloc( 314 rewriter, loc, operandBuffer, dealloc && getOptions().createDeallocs); 315 if (failed(resultBuffer)) 316 return failure(); 317 // Do not copy the buffer if its contents are undefined. 318 if (analysisState.hasUndefinedContents(&opOperand)) 319 return resultBuffer; 320 // Do not copy if the copied data is never read. 321 if (!aliasingOpResults.empty() && 322 !analysisState.bufferizesToMemoryRead(opOperand) && 323 llvm::none_of(aliasingOpResults, [&](OpResult opResult) { 324 return analysisState.isValueRead(opResult); 325 })) 326 return resultBuffer; 327 // Do not copy if this op does not read the data, but writes it. 328 if (analysisState.bufferizesToMemoryWrite(opOperand) && 329 !analysisState.bufferizesToMemoryRead(opOperand)) 330 return resultBuffer; 331 332 if (customCopyInsertionPoint) { 333 rewriter.setInsertionPoint(*customCopyInsertionPoint); 334 } else { 335 // The copy happens right before the op that is bufferized. 336 rewriter.setInsertionPoint(op); 337 } 338 if (failed(options.createMemCpy(rewriter, loc, operandBuffer, *resultBuffer))) 339 return failure(); 340 341 return resultBuffer; 342 } 343 344 /// Return the buffer type for a given Value (tensor) after bufferization. 345 BaseMemRefType BufferizationState::getBufferType(Value value) const { 346 auto tensorType = value.getType().dyn_cast<TensorType>(); 347 assert(tensorType && "unexpected non-tensor type"); 348 349 if (auto toTensorOp = value.getDefiningOp<bufferization::ToTensorOp>()) 350 return toTensorOp.memref().getType().cast<BaseMemRefType>(); 351 352 return getMemRefType(tensorType, getOptions()); 353 } 354 355 void bufferization::replaceOpWithBufferizedValues(RewriterBase &rewriter, 356 Operation *op, 357 ValueRange values) { 358 assert(values.size() == op->getNumResults() && 359 "expected one value per OpResult"); 360 OpBuilder::InsertionGuard g(rewriter); 361 362 // Replace all OpResults with the given values. 363 SmallVector<Value> replacements; 364 for (OpResult opResult : op->getOpResults()) { 365 Value replacement = values[opResult.getResultNumber()]; 366 if (opResult.getType().isa<TensorType>()) { 367 // The OpResult is a tensor. Such values are replaced with memrefs during 368 // bufferization. 369 assert((replacement.getType().isa<MemRefType>() || 370 replacement.getType().isa<UnrankedMemRefType>()) && 371 "tensor op result should be replaced with a memref value"); 372 // The existing uses of the OpResult still expect a tensor. Insert a 373 // ToTensorOp. Throughout bufferization, this ToTensorOp will gradually 374 // loose all of its users and eventually DCE away. 375 rewriter.setInsertionPointAfter(op); 376 replacement = rewriter.create<bufferization::ToTensorOp>( 377 replacement.getLoc(), replacement); 378 } 379 replacements.push_back(replacement); 380 } 381 382 rewriter.replaceOp(op, replacements); 383 } 384 385 //===----------------------------------------------------------------------===// 386 // Bufferization-specific scoped alloc/dealloc insertion support. 387 //===----------------------------------------------------------------------===// 388 389 /// Create a memref allocation with the given type and dynamic extents. 390 FailureOr<Value> BufferizationOptions::createAlloc(OpBuilder &b, Location loc, 391 MemRefType type, 392 ValueRange dynShape) const { 393 if (allocationFn) 394 return (*allocationFn)(b, loc, type, dynShape, bufferAlignment); 395 396 // Default bufferallocation via AllocOp. 397 Value allocated = b.create<memref::AllocOp>( 398 loc, type, dynShape, b.getI64IntegerAttr(bufferAlignment)); 399 return allocated; 400 } 401 402 /// Creates a memref deallocation. The given memref buffer must have been 403 /// allocated using `createAlloc`. 404 LogicalResult BufferizationOptions::createDealloc(OpBuilder &b, Location loc, 405 Value allocatedBuffer) const { 406 if (deallocationFn) 407 return (*deallocationFn)(b, loc, allocatedBuffer); 408 409 // Default buffer deallocation via DeallocOp. 410 b.create<memref::DeallocOp>(loc, allocatedBuffer); 411 return success(); 412 } 413 414 static MemRefType getContiguousMemRefType(ShapedType shapedType, 415 Attribute memorySpace = {}) { 416 MemRefLayoutAttrInterface layout = {}; 417 return MemRefType::get(shapedType.getShape(), shapedType.getElementType(), 418 layout, memorySpace); 419 } 420 421 /// Compute the type of the `memref` to use for allocating the buffer for 422 /// `shapedValue`. Also returns (by reference in `dynShape`), the value for the 423 /// dynamic dimensions in the returned `memref` type. 424 static MemRefType getAllocationTypeAndShape(OpBuilder &b, Location loc, 425 Value shapedValue, 426 SmallVectorImpl<Value> &dynShape) { 427 MemRefType allocMemRefType = 428 getContiguousMemRefType(shapedValue.getType().cast<ShapedType>()); 429 430 // Compute the dynamic part of the shape. 431 bool reifiedShapes = false; 432 if (auto rankedOp = dyn_cast_or_null<ReifyRankedShapedTypeOpInterface>( 433 shapedValue.getDefiningOp())) { 434 ReifiedRankedShapedTypeDims resultDims; 435 if (succeeded(rankedOp.reifyResultShapes(b, resultDims))) { 436 reifiedShapes = true; 437 OpResult resultValue = shapedValue.dyn_cast<OpResult>(); 438 auto &shape = resultDims[resultValue.getResultNumber()]; 439 for (const auto &dim : enumerate(allocMemRefType.getShape())) 440 if (ShapedType::isDynamic(dim.value())) 441 dynShape.push_back(shape[dim.index()]); 442 } 443 } 444 445 if (!reifiedShapes) { 446 for (const auto &dim : enumerate(allocMemRefType.getShape())) 447 if (ShapedType::isDynamic(dim.value())) { 448 assert((shapedValue.getType().isa<UnrankedMemRefType>() || 449 shapedValue.getType().isa<MemRefType>()) && 450 "expected MemRef type"); 451 dynShape.push_back( 452 b.create<memref::DimOp>(loc, shapedValue, dim.index())); 453 } 454 } 455 456 return allocMemRefType; 457 } 458 459 static Value createBufferAllocation(OpBuilder &b, Location loc, MemRefType type, 460 ValueRange dynShape, bool skipDealloc) { 461 auto allocaOp = b.create<memref::AllocaOp>(loc, type, dynShape); 462 allocaOp->setAttr(kBufferAllocationAttr, b.getUnitAttr()); 463 if (skipDealloc) 464 allocaOp->setAttr(kSkipDeallocAttr, b.getUnitAttr()); 465 return allocaOp.getResult(); 466 } 467 468 /// Create an allocation after `shapedValue.getDefiningOp` (or at the top of the 469 /// block in case of a bbArg). 470 FailureOr<Value> BufferizationState::createAlloc(OpBuilder &b, Location loc, 471 Value shapedValue, 472 Optional<bool> dealloc) { 473 // Take a guard before anything else. 474 OpBuilder::InsertionGuard g(b); 475 476 // Compute allocation memref type. 477 assert(shapedValue.getType().isa<ShapedType>()); 478 SmallVector<Value> dynShape; 479 MemRefType allocMemRefType = 480 getAllocationTypeAndShape(b, loc, shapedValue, dynShape); 481 482 // Should be the buffer be deallocated again or should we let it leak? 483 bool skipDealloc; 484 if (dealloc) { 485 skipDealloc = !dealloc.getValue(); 486 } else { 487 assert(shapedValue.getType().isa<TensorType>() && 488 "must specify `dealloc` if non-tensor value is passed"); 489 // Buffer should be not be deallocated if deallocs are generally deactivated 490 // or if the tensor is yielded from a block. 491 skipDealloc = !getOptions().createDeallocs || 492 getAnalysisState().isTensorYielded(shapedValue); 493 } 494 495 // Create the buffer allocation. 496 return createBufferAllocation(b, loc, allocMemRefType, dynShape, skipDealloc); 497 } 498 499 /// Create a memory copy between two memref buffers. 500 LogicalResult BufferizationOptions::createMemCpy(OpBuilder &b, Location loc, 501 Value from, Value to) const { 502 if (memCpyFn) 503 return (*memCpyFn)(b, loc, from, to); 504 505 b.create<memref::CopyOp>(loc, from, to); 506 return success(); 507 } 508 509 LogicalResult 510 bufferization::createAllocDeallocOps(Operation *op, 511 const BufferizationOptions &options, 512 bool onlyLeakingAllocs, bool *changed) { 513 IRRewriter rewriter(op->getContext()); 514 if (changed) 515 *changed = false; 516 517 // Bufferization creates memref.alloca ops. After bufferization, these must be 518 // rewritten to alloc/dealloc ops as specified in the bufferization options. 519 WalkResult status = op->walk([&](memref::AllocaOp allocaOp) { 520 // Ignore memref.alloca ops that were not created by the bufferization. 521 if (!allocaOp->hasAttr(kBufferAllocationAttr)) 522 return WalkResult::skip(); 523 // If `onlyLeakingAllocs`, process only ops that are marked as 524 // "skip dealloc". 525 bool skipDealloc = allocaOp->hasAttr(kSkipDeallocAttr); 526 if (onlyLeakingAllocs && !skipDealloc) 527 return WalkResult::skip(); 528 529 // Create alloc. 530 Block *block = allocaOp->getBlock(); 531 rewriter.setInsertionPoint(allocaOp); 532 FailureOr<Value> alloc = 533 options.createAlloc(rewriter, allocaOp->getLoc(), allocaOp.getType(), 534 allocaOp.dynamicSizes()); 535 if (failed(alloc)) 536 return WalkResult::interrupt(); 537 rewriter.replaceOp(allocaOp, *alloc); 538 if (changed) 539 *changed = true; 540 541 // Stop here if the buffer should not be deallocated. 542 if (skipDealloc) 543 return WalkResult::advance(); 544 545 // Create dealloc. 546 rewriter.setInsertionPoint(block->getTerminator()); 547 if (failed(options.createDealloc(rewriter, alloc->getLoc(), *alloc))) 548 return WalkResult::interrupt(); 549 550 return WalkResult::advance(); 551 }); 552 553 return success(!status.wasInterrupted()); 554 } 555 556 //===----------------------------------------------------------------------===// 557 // Bufferization-specific BlockAndValueMapping support with debugging. 558 //===----------------------------------------------------------------------===// 559 560 bool bufferization::isFunctionArgument(Value value) { 561 auto bbArg = value.dyn_cast<BlockArgument>(); 562 if (!bbArg) 563 return false; 564 return isa<func::FuncOp>(bbArg.getOwner()->getParentOp()); 565 } 566 567 BaseMemRefType bufferization::getMemRefType(TensorType tensorType, 568 const BufferizationOptions &options, 569 MemRefLayoutAttrInterface layout, 570 Attribute memorySpace) { 571 // Case 1: Unranked memref type. 572 if (auto unrankedTensorType = tensorType.dyn_cast<UnrankedTensorType>()) { 573 assert(!layout && "UnrankedTensorType cannot have a layout map"); 574 return UnrankedMemRefType::get(unrankedTensorType.getElementType(), 575 memorySpace); 576 } 577 578 // Case 2: Ranked memref type with specified layout. 579 auto rankedTensorType = tensorType.cast<RankedTensorType>(); 580 if (layout) { 581 return MemRefType::get(rankedTensorType.getShape(), 582 rankedTensorType.getElementType(), layout, 583 memorySpace); 584 } 585 586 // Case 3: Configured with "fully dynamic layout maps". 587 if (options.unknownTypeConversion == 588 BufferizationOptions::LayoutMapOption::FullyDynamicLayoutMap) 589 return getMemRefTypeWithFullyDynamicLayout(tensorType, memorySpace); 590 591 // Case 4: Configured with "static identity layout maps". 592 if (options.unknownTypeConversion == 593 BufferizationOptions::LayoutMapOption::IdentityLayoutMap) 594 return getMemRefTypeWithStaticIdentityLayout(tensorType, memorySpace); 595 596 llvm_unreachable("InferLayoutMap is an invalid option"); 597 } 598 599 BaseMemRefType 600 bufferization::getMemRefTypeWithFullyDynamicLayout(TensorType tensorType, 601 Attribute memorySpace) { 602 // Case 1: Unranked memref type. 603 if (auto unrankedTensorType = tensorType.dyn_cast<UnrankedTensorType>()) { 604 return UnrankedMemRefType::get(unrankedTensorType.getElementType(), 605 memorySpace); 606 } 607 608 // Case 2: Ranked memref type. 609 auto rankedTensorType = tensorType.cast<RankedTensorType>(); 610 int64_t dynamicOffset = ShapedType::kDynamicStrideOrOffset; 611 SmallVector<int64_t> dynamicStrides(rankedTensorType.getRank(), 612 ShapedType::kDynamicStrideOrOffset); 613 AffineMap stridedLayout = makeStridedLinearLayoutMap( 614 dynamicStrides, dynamicOffset, rankedTensorType.getContext()); 615 return MemRefType::get(rankedTensorType.getShape(), 616 rankedTensorType.getElementType(), stridedLayout, 617 memorySpace); 618 } 619 620 /// Return a MemRef type with a static identity layout (i.e., no layout map). If 621 /// the given tensor type is unranked, return an unranked MemRef type. 622 BaseMemRefType 623 bufferization::getMemRefTypeWithStaticIdentityLayout(TensorType tensorType, 624 Attribute memorySpace) { 625 // Case 1: Unranked memref type. 626 if (auto unrankedTensorType = tensorType.dyn_cast<UnrankedTensorType>()) { 627 return UnrankedMemRefType::get(unrankedTensorType.getElementType(), 628 memorySpace); 629 } 630 631 // Case 2: Ranked memref type. 632 auto rankedTensorType = tensorType.cast<RankedTensorType>(); 633 MemRefLayoutAttrInterface layout = {}; 634 return MemRefType::get(rankedTensorType.getShape(), 635 rankedTensorType.getElementType(), layout, 636 memorySpace); 637 } 638