1 //===- BufferizableOpInterface.cpp - Bufferizable Ops ---=----------------===// 2 // 3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. 4 // See https://llvm.org/LICENSE.txt for license information. 5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception 6 // 7 //===----------------------------------------------------------------------===// 8 9 #include "mlir/Dialect/Bufferization/IR/BufferizableOpInterface.h" 10 #include "mlir/Dialect/Bufferization/IR/Bufferization.h" 11 #include "mlir/Dialect/Func/IR/FuncOps.h" 12 #include "mlir/Dialect/MemRef/IR/MemRef.h" 13 #include "mlir/IR/AsmState.h" 14 #include "mlir/IR/BlockAndValueMapping.h" 15 #include "mlir/IR/BuiltinOps.h" 16 #include "mlir/IR/Operation.h" 17 #include "mlir/IR/TypeUtilities.h" 18 #include "mlir/IR/Value.h" 19 #include "llvm/Support/Debug.h" 20 21 //===----------------------------------------------------------------------===// 22 // BufferizableOpInterface 23 //===----------------------------------------------------------------------===// 24 25 namespace mlir { 26 namespace bufferization { 27 28 #include "mlir/Dialect/Bufferization/IR/BufferizableOpInterface.cpp.inc" 29 30 } // namespace bufferization 31 } // namespace mlir 32 33 #define DEBUG_TYPE "bufferizable-op-interface" 34 #define DBGS() (llvm::dbgs() << '[' << DEBUG_TYPE << "] ") 35 #define LDBG(X) LLVM_DEBUG(DBGS() << (X)) 36 37 using namespace mlir; 38 using namespace bufferization; 39 40 /// Attribute name used to mark region arguments that can be bufferized 41 /// in-place during linalg comprehensive bufferization. 42 constexpr const ::llvm::StringLiteral 43 bufferization::BufferizableOpInterface::kInplaceableAttrName; 44 45 LogicalResult BufferizableOpInterface::resolveTensorOpOperandConflicts( 46 RewriterBase &rewriter, const AnalysisState &state) { 47 Operation *op = getOperation(); 48 for (OpOperand &opOperand : op->getOpOperands()) { 49 Type operandType = opOperand.get().getType(); 50 if (!operandType.isa<TensorType>()) 51 continue; 52 if (state.isInPlace(opOperand)) 53 continue; 54 if (operandType.isa<UnrankedTensorType>()) 55 return op->emitError("copies of unranked tensors are not supported"); 56 auto tensorType = operandType.dyn_cast<RankedTensorType>(); 57 if (!tensorType) 58 continue; 59 SmallVector<OpResult> aliasingOpResults = 60 state.getAliasingOpResult(opOperand); 61 bool escape = llvm::any_of( 62 aliasingOpResults, [&](Value v) { return state.isTensorYielded(v); }); 63 Value copy = rewriter.create<AllocTensorOp>( 64 op->getLoc(), tensorType, ValueRange(), opOperand.get(), escape); 65 rewriter.updateRootInPlace(op, [&]() { opOperand.set(copy); }); 66 } 67 return success(); 68 } 69 70 //===----------------------------------------------------------------------===// 71 // OpFilter 72 //===----------------------------------------------------------------------===// 73 74 bool OpFilter::isOpAllowed(Operation *op) const { 75 // All other ops: Allow/disallow according to filter. 76 bool isAllowed = !hasAllowRule(); 77 for (const Entry &entry : entries) { 78 bool filterResult = entry.fn(op); 79 switch (entry.type) { 80 case Entry::ALLOW: 81 isAllowed |= filterResult; 82 break; 83 case Entry::DENY: 84 if (filterResult) 85 // DENY filter matches. This op is no allowed. (Even if other ALLOW 86 // filters may match.) 87 return false; 88 }; 89 } 90 return isAllowed; 91 } 92 93 //===----------------------------------------------------------------------===// 94 // BufferizationOptions 95 //===----------------------------------------------------------------------===// 96 97 // Default constructor for BufferizationOptions. 98 BufferizationOptions::BufferizationOptions() = default; 99 100 bool BufferizationOptions::isOpAllowed(Operation *op) const { 101 // Special case: If function boundary bufferization is deactivated, do not 102 // allow ops that belong to the `func` dialect. 103 bool isFuncBoundaryOp = isa_and_nonnull<func::FuncDialect>(op->getDialect()); 104 if (!bufferizeFunctionBoundaries && isFuncBoundaryOp) 105 return false; 106 107 return opFilter.isOpAllowed(op); 108 } 109 110 BufferizableOpInterface 111 BufferizationOptions::dynCastBufferizableOp(Operation *op) const { 112 auto bufferizableOp = dyn_cast<BufferizableOpInterface>(op); 113 if (!bufferizableOp) 114 return nullptr; 115 if (!isOpAllowed(op)) 116 return nullptr; 117 return bufferizableOp; 118 } 119 120 BufferizableOpInterface 121 BufferizationOptions::dynCastBufferizableOp(Value value) const { 122 if (auto bufferizableOp = value.getDefiningOp<BufferizableOpInterface>()) 123 if (isOpAllowed(bufferizableOp.getOperation())) 124 return bufferizableOp; 125 return nullptr; 126 } 127 128 void BufferizationOptions::addDialectStateInitializer( 129 StringRef name, const DialectStateInitFn &fn) { 130 stateInitializers.push_back( 131 [=](AnalysisState &state) { state.insertDialectState(name, fn()); }); 132 } 133 134 //===----------------------------------------------------------------------===// 135 // Helper functions for BufferizableOpInterface 136 //===----------------------------------------------------------------------===// 137 138 static void setInsertionPointAfter(OpBuilder &b, Value value) { 139 if (auto bbArg = value.dyn_cast<BlockArgument>()) { 140 b.setInsertionPointToStart(bbArg.getOwner()); 141 } else { 142 b.setInsertionPointAfter(value.getDefiningOp()); 143 } 144 } 145 146 /// Determine which OpOperand* will alias with `result` if the op is bufferized 147 /// in place. Return an empty vector if the op is not bufferizable. 148 SmallVector<OpOperand *> 149 AnalysisState::getAliasingOpOperand(OpResult result) const { 150 if (Operation *op = result.getDefiningOp()) 151 if (auto bufferizableOp = getOptions().dynCastBufferizableOp(op)) 152 return bufferizableOp.getAliasingOpOperand(result, *this); 153 return {}; 154 } 155 156 /// Determine which OpResult will alias with `opOperand` if the op is bufferized 157 /// in place. Return an empty vector if the op is not bufferizable. 158 SmallVector<OpResult> 159 AnalysisState::getAliasingOpResult(OpOperand &opOperand) const { 160 if (auto bufferizableOp = 161 getOptions().dynCastBufferizableOp(opOperand.getOwner())) 162 return bufferizableOp.getAliasingOpResult(opOperand, *this); 163 return {}; 164 } 165 166 /// Return true if `opOperand` bufferizes to a memory read. Return `true` if the 167 /// op is not bufferizable. 168 bool AnalysisState::bufferizesToMemoryRead(OpOperand &opOperand) const { 169 if (auto bufferizableOp = 170 getOptions().dynCastBufferizableOp(opOperand.getOwner())) 171 return bufferizableOp.bufferizesToMemoryRead(opOperand, *this); 172 173 // Unknown op that returns a tensor. The inplace analysis does not support it. 174 // Conservatively return true. 175 return true; 176 } 177 178 /// Return true if `opOperand` bufferizes to a memory write. Return 179 /// `true` if the op is not bufferizable. 180 bool AnalysisState::bufferizesToMemoryWrite(OpOperand &opOperand) const { 181 if (auto bufferizableOp = 182 getOptions().dynCastBufferizableOp(opOperand.getOwner())) 183 return bufferizableOp.bufferizesToMemoryWrite(opOperand, *this); 184 185 // Unknown op that returns a tensor. The inplace analysis does not support it. 186 // Conservatively return true. 187 return true; 188 } 189 190 /// Return true if `opOperand` does neither read nor write but bufferizes to an 191 /// alias. Return false if the op is not bufferizable. 192 bool AnalysisState::bufferizesToAliasOnly(OpOperand &opOperand) const { 193 if (auto bufferizableOp = 194 getOptions().dynCastBufferizableOp(opOperand.getOwner())) 195 return bufferizableOp.bufferizesToAliasOnly(opOperand, *this); 196 197 // Unknown op that returns a tensor. The inplace analysis does not support it. 198 // Conservatively return false. 199 return false; 200 } 201 202 /// Return true if the given value is read by an op that bufferizes to a memory 203 /// read. Also takes into account ops that create an alias but do not read by 204 /// themselves (e.g., ExtractSliceOp). 205 bool AnalysisState::isValueRead(Value value) const { 206 assert(value.getType().isa<TensorType>() && "expected TensorType"); 207 SmallVector<OpOperand *> workingSet; 208 for (OpOperand &use : value.getUses()) 209 workingSet.push_back(&use); 210 211 while (!workingSet.empty()) { 212 OpOperand *uMaybeReading = workingSet.pop_back_val(); 213 // Skip over all ops that neither read nor write (but create an alias). 214 if (bufferizesToAliasOnly(*uMaybeReading)) 215 for (OpResult opResult : getAliasingOpResult(*uMaybeReading)) 216 for (OpOperand &use : opResult.getUses()) 217 workingSet.push_back(&use); 218 if (bufferizesToMemoryRead(*uMaybeReading)) 219 return true; 220 } 221 222 return false; 223 } 224 225 // Starting from `value`, follow the use-def chain in reverse, always selecting 226 // the aliasing OpOperands. Find and return Values for which `condition` 227 // evaluates to true. OpOperands of such matching Values are not traversed any 228 // further. 229 llvm::SetVector<Value> AnalysisState::findValueInReverseUseDefChain( 230 Value value, llvm::function_ref<bool(Value)> condition) const { 231 llvm::SetVector<Value> result, workingSet; 232 workingSet.insert(value); 233 234 while (!workingSet.empty()) { 235 Value value = workingSet.pop_back_val(); 236 if (condition(value) || value.isa<BlockArgument>()) { 237 result.insert(value); 238 continue; 239 } 240 241 OpResult opResult = value.cast<OpResult>(); 242 SmallVector<OpOperand *> opOperands = getAliasingOpOperand(opResult); 243 if (opOperands.empty() || !options.isOpAllowed(value.getDefiningOp())) { 244 result.insert(value); 245 continue; 246 } 247 248 for (OpOperand *o : opOperands) 249 workingSet.insert(o->get()); 250 } 251 252 return result; 253 } 254 255 // Find the Values of the last preceding write of a given Value. 256 llvm::SetVector<Value> 257 AnalysisState::findLastPrecedingWrite(Value value) const { 258 return findValueInReverseUseDefChain(value, [&](Value value) { 259 Operation *op = value.getDefiningOp(); 260 if (!op) 261 return true; 262 auto bufferizableOp = options.dynCastBufferizableOp(op); 263 if (!bufferizableOp) 264 return true; 265 return bufferizableOp.isMemoryWrite(value.cast<OpResult>(), *this); 266 }); 267 } 268 269 AnalysisState::AnalysisState(const BufferizationOptions &options) 270 : options(options) { 271 for (const BufferizationOptions::AnalysisStateInitFn &fn : 272 options.stateInitializers) 273 fn(*this); 274 } 275 276 // bufferization.to_memref is not allowed to change the rank. 277 static void ensureToMemrefOpIsValid(Value tensor, Type memrefType) { 278 #ifndef NDEBUG 279 auto rankedTensorType = tensor.getType().dyn_cast<RankedTensorType>(); 280 assert((!rankedTensorType || memrefType.cast<MemRefType>().getRank() == 281 rankedTensorType.getRank()) && 282 "to_memref would be invalid: mismatching ranks"); 283 #endif 284 } 285 286 Value mlir::bufferization::lookupBuffer(RewriterBase &rewriter, Value tensor, 287 const BufferizationOptions &options) { 288 auto tensorType = tensor.getType().dyn_cast<TensorType>(); 289 assert(tensorType && "unexpected non-tensor type"); 290 291 // Replace "%t = to_tensor %m" with %m. 292 if (auto toTensorOp = tensor.getDefiningOp<bufferization::ToTensorOp>()) 293 return toTensorOp.memref(); 294 295 // Insert to_memref op. 296 OpBuilder::InsertionGuard g(rewriter); 297 setInsertionPointAfter(rewriter, tensor); 298 Type memrefType = getMemRefType(tensorType, options); 299 ensureToMemrefOpIsValid(tensor, memrefType); 300 return rewriter.create<bufferization::ToMemrefOp>(tensor.getLoc(), memrefType, 301 tensor); 302 } 303 304 /// Return the buffer (memref) for a given OpOperand (tensor). Allocate 305 /// a new buffer and copy over data from the existing buffer if out-of-place 306 /// bufferization was decided. 307 FailureOr<Value> 308 BufferizationState::getBuffer(RewriterBase &rewriter, OpOperand &opOperand, 309 Optional<ForceInPlacability> overrideInPlace, 310 Optional<Operation *> customCopyInsertionPoint) { 311 const BufferizationOptions &options = analysisState.getOptions(); 312 OpBuilder::InsertionGuard guard(rewriter); 313 Operation *op = opOperand.getOwner(); 314 Location loc = op->getLoc(); 315 SmallVector<OpResult> aliasingOpResults = 316 analysisState.getAliasingOpResult(opOperand); 317 Value operand = opOperand.get(); 318 Value operandBuffer = lookupBuffer(rewriter, operand, options); 319 320 // Can `operandBuffer` be used directly or do we need a copy? 321 bool inplace = 322 overrideInPlace != FORCE_OUT_OF_PLACE && 323 (overrideInPlace == FORCE_INPLACE || analysisState.isInPlace(opOperand)); 324 if (inplace) 325 return operandBuffer; 326 327 // Bufferizing out-of-place: Allocate a new buffer. 328 // Move insertion point right after `operandBuffer`. That is where the 329 // allocation should be inserted (in the absence of allocation hoisting). 330 setInsertionPointAfter(rewriter, operandBuffer); 331 // Allocate the result buffer. The buffer should be deallocated if the tensor 332 // is not yielded and deallocs are enabled in general. 333 bool dealloc = llvm::none_of(aliasingOpResults, [&](Value v) { 334 return getAnalysisState().isTensorYielded(v); 335 }); 336 FailureOr<Value> resultBuffer = createAlloc( 337 rewriter, loc, operandBuffer, dealloc && getOptions().createDeallocs); 338 if (failed(resultBuffer)) 339 return failure(); 340 // Do not copy the buffer if its contents are undefined. 341 if (analysisState.hasUndefinedContents(&opOperand)) 342 return resultBuffer; 343 // Do not copy if the copied data is never read. 344 if (!aliasingOpResults.empty() && 345 !analysisState.bufferizesToMemoryRead(opOperand) && 346 llvm::none_of(aliasingOpResults, [&](OpResult opResult) { 347 return analysisState.isValueRead(opResult); 348 })) 349 return resultBuffer; 350 // Do not copy if this op does not read the data, but writes it. 351 if (analysisState.bufferizesToMemoryWrite(opOperand) && 352 !analysisState.bufferizesToMemoryRead(opOperand)) 353 return resultBuffer; 354 355 if (customCopyInsertionPoint) { 356 rewriter.setInsertionPoint(*customCopyInsertionPoint); 357 } else { 358 // The copy happens right before the op that is bufferized. 359 rewriter.setInsertionPoint(op); 360 } 361 if (failed(options.createMemCpy(rewriter, loc, operandBuffer, *resultBuffer))) 362 return failure(); 363 364 return resultBuffer; 365 } 366 367 /// Return the buffer type for a given Value (tensor) after bufferization. 368 BaseMemRefType BufferizationState::getBufferType(Value value) const { 369 auto tensorType = value.getType().dyn_cast<TensorType>(); 370 assert(tensorType && "unexpected non-tensor type"); 371 372 if (auto toTensorOp = value.getDefiningOp<bufferization::ToTensorOp>()) 373 return toTensorOp.memref().getType().cast<BaseMemRefType>(); 374 375 return getMemRefType(tensorType, getOptions()); 376 } 377 378 void bufferization::replaceOpWithBufferizedValues(RewriterBase &rewriter, 379 Operation *op, 380 ValueRange values) { 381 assert(values.size() == op->getNumResults() && 382 "expected one value per OpResult"); 383 OpBuilder::InsertionGuard g(rewriter); 384 385 // Replace all OpResults with the given values. 386 SmallVector<Value> replacements; 387 for (OpResult opResult : op->getOpResults()) { 388 Value replacement = values[opResult.getResultNumber()]; 389 if (opResult.getType().isa<TensorType>()) { 390 // The OpResult is a tensor. Such values are replaced with memrefs during 391 // bufferization. 392 assert((replacement.getType().isa<MemRefType>() || 393 replacement.getType().isa<UnrankedMemRefType>()) && 394 "tensor op result should be replaced with a memref value"); 395 // The existing uses of the OpResult still expect a tensor. Insert a 396 // ToTensorOp. Throughout bufferization, this ToTensorOp will gradually 397 // loose all of its users and eventually DCE away. 398 rewriter.setInsertionPointAfter(op); 399 replacement = rewriter.create<bufferization::ToTensorOp>( 400 replacement.getLoc(), replacement); 401 } 402 replacements.push_back(replacement); 403 } 404 405 rewriter.replaceOp(op, replacements); 406 } 407 408 //===----------------------------------------------------------------------===// 409 // Bufferization-specific scoped alloc/dealloc insertion support. 410 //===----------------------------------------------------------------------===// 411 412 /// Create a memref allocation with the given type and dynamic extents. 413 FailureOr<Value> BufferizationOptions::createAlloc(OpBuilder &b, Location loc, 414 MemRefType type, 415 ValueRange dynShape) const { 416 if (allocationFn) 417 return (*allocationFn)(b, loc, type, dynShape, bufferAlignment); 418 419 // Default bufferallocation via AllocOp. 420 Value allocated = b.create<memref::AllocOp>( 421 loc, type, dynShape, b.getI64IntegerAttr(bufferAlignment)); 422 return allocated; 423 } 424 425 /// Creates a memref deallocation. The given memref buffer must have been 426 /// allocated using `createAlloc`. 427 LogicalResult BufferizationOptions::createDealloc(OpBuilder &b, Location loc, 428 Value allocatedBuffer) const { 429 if (deallocationFn) 430 return (*deallocationFn)(b, loc, allocatedBuffer); 431 432 // Default buffer deallocation via DeallocOp. 433 b.create<memref::DeallocOp>(loc, allocatedBuffer); 434 return success(); 435 } 436 437 static MemRefType getContiguousMemRefType(ShapedType shapedType, 438 Attribute memorySpace = {}) { 439 MemRefLayoutAttrInterface layout = {}; 440 return MemRefType::get(shapedType.getShape(), shapedType.getElementType(), 441 layout, memorySpace); 442 } 443 444 /// Compute the type of the `memref` to use for allocating the buffer for 445 /// `shapedValue`. Also returns (by reference in `dynShape`), the value for the 446 /// dynamic dimensions in the returned `memref` type. 447 static MemRefType getAllocationTypeAndShape(OpBuilder &b, Location loc, 448 Value shapedValue, 449 SmallVectorImpl<Value> &dynShape) { 450 MemRefType allocMemRefType = 451 getContiguousMemRefType(shapedValue.getType().cast<ShapedType>()); 452 453 // Compute the dynamic part of the shape. 454 bool reifiedShapes = false; 455 if (auto rankedOp = dyn_cast_or_null<ReifyRankedShapedTypeOpInterface>( 456 shapedValue.getDefiningOp())) { 457 ReifiedRankedShapedTypeDims resultDims; 458 if (succeeded(rankedOp.reifyResultShapes(b, resultDims))) { 459 reifiedShapes = true; 460 OpResult resultValue = shapedValue.dyn_cast<OpResult>(); 461 auto &shape = resultDims[resultValue.getResultNumber()]; 462 for (const auto &dim : enumerate(allocMemRefType.getShape())) 463 if (ShapedType::isDynamic(dim.value())) 464 dynShape.push_back(shape[dim.index()]); 465 } 466 } 467 468 if (!reifiedShapes) { 469 for (const auto &dim : enumerate(allocMemRefType.getShape())) 470 if (ShapedType::isDynamic(dim.value())) { 471 assert((shapedValue.getType().isa<UnrankedMemRefType>() || 472 shapedValue.getType().isa<MemRefType>()) && 473 "expected MemRef type"); 474 dynShape.push_back( 475 b.create<memref::DimOp>(loc, shapedValue, dim.index())); 476 } 477 } 478 479 return allocMemRefType; 480 } 481 482 /// Create an allocation after `shapedValue.getDefiningOp` (or at the top of the 483 /// block in case of a bbArg). 484 FailureOr<Value> BufferizationState::createAlloc(OpBuilder &b, Location loc, 485 Value shapedValue, 486 Optional<bool> dealloc) { 487 // Take a guard before anything else. 488 OpBuilder::InsertionGuard g(b); 489 490 // Compute allocation memref type. 491 assert(shapedValue.getType().isa<ShapedType>()); 492 SmallVector<Value> dynShape; 493 MemRefType allocMemRefType = 494 getAllocationTypeAndShape(b, loc, shapedValue, dynShape); 495 496 // Create the buffer allocation. 497 FailureOr<Value> buffer = 498 getOptions().createAlloc(b, loc, allocMemRefType, dynShape); 499 if (failed(buffer)) 500 return failure(); 501 502 // Should be the buffer be deallocated again or should we let it leak? 503 if (dealloc) { 504 if (!dealloc.getValue()) 505 return *buffer; 506 } else { 507 assert(shapedValue.getType().isa<TensorType>() && 508 "must specify `dealloc` if non-tensor value is passed"); 509 // Buffer should be not be deallocated if deallocs are generally deactivated 510 // or if the tensor is yielded from a block. 511 if (!getOptions().createDeallocs || 512 getAnalysisState().isTensorYielded(shapedValue)) 513 return *buffer; 514 } 515 516 // Create buffer deallocation. 517 b.setInsertionPoint(b.getInsertionBlock()->getTerminator()); 518 if (failed(getOptions().createDealloc(b, loc, *buffer))) 519 return failure(); 520 521 return *buffer; 522 } 523 524 /// Create a memory copy between two memref buffers. 525 LogicalResult BufferizationOptions::createMemCpy(OpBuilder &b, Location loc, 526 Value from, Value to) const { 527 if (memCpyFn) 528 return (*memCpyFn)(b, loc, from, to); 529 530 b.create<memref::CopyOp>(loc, from, to); 531 return success(); 532 } 533 534 //===----------------------------------------------------------------------===// 535 // Bufferization-specific BlockAndValueMapping support with debugging. 536 //===----------------------------------------------------------------------===// 537 538 bool bufferization::isFunctionArgument(Value value) { 539 auto bbArg = value.dyn_cast<BlockArgument>(); 540 if (!bbArg) 541 return false; 542 return isa<func::FuncOp>(bbArg.getOwner()->getParentOp()); 543 } 544 545 BaseMemRefType bufferization::getMemRefType(TensorType tensorType, 546 const BufferizationOptions &options, 547 MemRefLayoutAttrInterface layout, 548 Attribute memorySpace) { 549 // Case 1: Unranked memref type. 550 if (auto unrankedTensorType = tensorType.dyn_cast<UnrankedTensorType>()) { 551 assert(!layout && "UnrankedTensorType cannot have a layout map"); 552 return UnrankedMemRefType::get(unrankedTensorType.getElementType(), 553 memorySpace); 554 } 555 556 // Case 2: Ranked memref type with specified layout. 557 auto rankedTensorType = tensorType.cast<RankedTensorType>(); 558 if (layout) { 559 return MemRefType::get(rankedTensorType.getShape(), 560 rankedTensorType.getElementType(), layout, 561 memorySpace); 562 } 563 564 // Case 3: Configured with "fully dynamic layout maps". 565 if (options.unknownTypeConversion == 566 BufferizationOptions::LayoutMapOption::FullyDynamicLayoutMap) 567 return getMemRefTypeWithFullyDynamicLayout(tensorType, memorySpace); 568 569 // Case 4: Configured with "static identity layout maps". 570 if (options.unknownTypeConversion == 571 BufferizationOptions::LayoutMapOption::IdentityLayoutMap) 572 return getMemRefTypeWithStaticIdentityLayout(tensorType, memorySpace); 573 574 llvm_unreachable("InferLayoutMap is an invalid option"); 575 } 576 577 BaseMemRefType 578 bufferization::getMemRefTypeWithFullyDynamicLayout(TensorType tensorType, 579 Attribute memorySpace) { 580 // Case 1: Unranked memref type. 581 if (auto unrankedTensorType = tensorType.dyn_cast<UnrankedTensorType>()) { 582 return UnrankedMemRefType::get(unrankedTensorType.getElementType(), 583 memorySpace); 584 } 585 586 // Case 2: Ranked memref type. 587 auto rankedTensorType = tensorType.cast<RankedTensorType>(); 588 int64_t dynamicOffset = ShapedType::kDynamicStrideOrOffset; 589 SmallVector<int64_t> dynamicStrides(rankedTensorType.getRank(), 590 ShapedType::kDynamicStrideOrOffset); 591 AffineMap stridedLayout = makeStridedLinearLayoutMap( 592 dynamicStrides, dynamicOffset, rankedTensorType.getContext()); 593 return MemRefType::get(rankedTensorType.getShape(), 594 rankedTensorType.getElementType(), stridedLayout, 595 memorySpace); 596 } 597 598 /// Return a MemRef type with a static identity layout (i.e., no layout map). If 599 /// the given tensor type is unranked, return an unranked MemRef type. 600 BaseMemRefType 601 bufferization::getMemRefTypeWithStaticIdentityLayout(TensorType tensorType, 602 Attribute memorySpace) { 603 // Case 1: Unranked memref type. 604 if (auto unrankedTensorType = tensorType.dyn_cast<UnrankedTensorType>()) { 605 return UnrankedMemRefType::get(unrankedTensorType.getElementType(), 606 memorySpace); 607 } 608 609 // Case 2: Ranked memref type. 610 auto rankedTensorType = tensorType.cast<RankedTensorType>(); 611 MemRefLayoutAttrInterface layout = {}; 612 return MemRefType::get(rankedTensorType.getShape(), 613 rankedTensorType.getElementType(), layout, 614 memorySpace); 615 } 616