//===- BufferizableOpInterface.cpp - Bufferizable Ops ---=----------------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//

#include "mlir/Dialect/Bufferization/IR/BufferizableOpInterface.h"
#include "mlir/Dialect/Bufferization/IR/Bufferization.h"
#include "mlir/Dialect/Func/IR/FuncOps.h"
#include "mlir/Dialect/MemRef/IR/MemRef.h"
#include "mlir/Dialect/Tensor/IR/Tensor.h"
#include "mlir/IR/AsmState.h"
#include "mlir/IR/BlockAndValueMapping.h"
#include "mlir/IR/BuiltinOps.h"
#include "mlir/IR/Operation.h"
#include "mlir/IR/TypeUtilities.h"
#include "mlir/IR/Value.h"
#include "llvm/Support/Debug.h"

//===----------------------------------------------------------------------===//
// BufferizableOpInterface
//===----------------------------------------------------------------------===//

namespace mlir {
namespace bufferization {

// TableGen-generated interface definitions. Included inside the
// `mlir::bufferization` namespace so that the generated code is placed there.
#include "mlir/Dialect/Bufferization/IR/BufferizableOpInterface.cpp.inc"

} // namespace bufferization
} // namespace mlir

// Debug-output helpers: DBGS() prefixes a line with "[DEBUG_TYPE] ", LDBG(X)
// prints X under LLVM_DEBUG.
#define DEBUG_TYPE "bufferizable-op-interface"
#define DBGS() (llvm::dbgs() << '[' << DEBUG_TYPE << "] ")
#define LDBG(X) LLVM_DEBUG(DBGS() << (X))

using namespace mlir;
using namespace bufferization;

/// Attribute name used to mark region arguments that can be bufferized
/// in-place during linalg comprehensive bufferization.
/// (Out-of-line definition, without initializer, of the constexpr static data
/// member declared in the interface class.)
constexpr const ::llvm::StringLiteral
    bufferization::BufferizableOpInterface::kInplaceableAttrName;

/// Create an AllocTensorOp for the given shaped value, which must be a ranked
/// tensor or a memref (memrefs are first wrapped in a ToTensorOp). If `copy`
/// is set, the shaped value is copied. Otherwise, a tensor with undefined
/// contents is allocated; its dynamic sizes are obtained by reifying the
/// defining op's result shapes when possible, or by creating DimOps otherwise.
/// `escape` is forwarded unchanged to the created AllocTensorOp.
49 Value bufferization::allocateTensorForShapedValue(OpBuilder &b, Location loc, 50 Value shapedValue, 51 bool escape, bool copy) { 52 Value tensor; 53 if (shapedValue.getType().isa<RankedTensorType>()) { 54 tensor = shapedValue; 55 } else if (shapedValue.getType().isa<MemRefType>()) { 56 tensor = b.create<ToTensorOp>(loc, shapedValue); 57 } else { 58 llvm_unreachable("expected RankedTensorType or MemRefType"); 59 } 60 RankedTensorType tensorType = tensor.getType().cast<RankedTensorType>(); 61 SmallVector<Value> dynamicSizes; 62 if (!copy) { 63 // Compute the dynamic part of the shape. 64 // First try to query the shape via ReifyRankedShapedTypeOpInterface. 65 bool reifiedShapes = false; 66 if (shapedValue.getType().isa<RankedTensorType>() && 67 shapedValue.isa<OpResult>()) { 68 if (auto rankedOp = dyn_cast_or_null<ReifyRankedShapedTypeOpInterface>( 69 shapedValue.getDefiningOp())) { 70 ReifiedRankedShapedTypeDims resultDims; 71 if (succeeded(rankedOp.reifyResultShapes(b, resultDims))) { 72 reifiedShapes = true; 73 auto &shape = 74 resultDims[shapedValue.cast<OpResult>().getResultNumber()]; 75 for (const auto &dim : enumerate(tensorType.getShape())) 76 if (ShapedType::isDynamic(dim.value())) 77 dynamicSizes.push_back(shape[dim.index()]); 78 } 79 } 80 } 81 82 // If the shape could not be reified, create DimOps. 83 if (!reifiedShapes) 84 populateDynamicDimSizes(b, loc, tensor, dynamicSizes); 85 } 86 87 return b.create<AllocTensorOp>(loc, tensorType, dynamicSizes, 88 copy ? 
tensor : Value(), escape); 89 } 90 91 LogicalResult BufferizableOpInterface::resolveTensorOpOperandConflicts( 92 RewriterBase &rewriter, const AnalysisState &state) { 93 OpBuilder::InsertionGuard g(rewriter); 94 Operation *op = getOperation(); 95 SmallVector<OpOperand *> outOfPlaceOpOperands; 96 DenseSet<OpOperand *> copiedOpOperands; 97 DenseSet<OpOperand *> escapingOpOperandCopies; 98 SmallVector<OpResult> outOfPlaceOpResults; 99 DenseSet<OpResult> copiedOpResults; 100 DenseSet<OpResult> escapingOpResultCopies; 101 102 // Find all out-of-place OpOperands. 103 for (OpOperand &opOperand : op->getOpOperands()) { 104 Type operandType = opOperand.get().getType(); 105 if (!operandType.isa<TensorType>()) 106 continue; 107 if (state.isInPlace(opOperand)) 108 continue; 109 if (operandType.isa<UnrankedTensorType>()) 110 return op->emitError("copies of unranked tensors are not supported"); 111 112 SmallVector<OpResult> aliasingOpResults = 113 state.getAliasingOpResult(opOperand); 114 // Is the result yielded from a block? Or are deallocations turned off 115 // entirely? In either case, mark the allocation as "escaping", so that it 116 // will not be deallocated. 117 bool escape = !state.getOptions().createDeallocs || 118 llvm::any_of(aliasingOpResults, [&](Value v) { 119 return state.isTensorYielded(v); 120 }); 121 122 if (aliasingOpResults.size() == 1 && 123 !state.bufferizesToMemoryWrite(opOperand) && 124 state.getAliasingOpOperand(aliasingOpResults.front()).size() == 1) { 125 // The op itself does not write but may create exactly one alias. Instead 126 // of copying the OpOperand, copy the OpResult. The OpResult can sometimes 127 // be smaller than the OpOperand (e.g., in the case of an extract_slice, 128 // where the result is usually a smaller part of the source). 
129 outOfPlaceOpResults.push_back(aliasingOpResults.front()); 130 if (!state.canOmitTensorCopy(opOperand)) 131 copiedOpResults.insert(aliasingOpResults.front()); 132 if (escape) 133 escapingOpResultCopies.insert(aliasingOpResults.front()); 134 } else { 135 // In all other cases, make a copy of the OpOperand. 136 outOfPlaceOpOperands.push_back(&opOperand); 137 if (!state.canOmitTensorCopy(opOperand)) 138 copiedOpOperands.insert(&opOperand); 139 if (escape) 140 escapingOpOperandCopies.insert(&opOperand); 141 } 142 } 143 144 // Insert copies of OpOperands. 145 rewriter.setInsertionPoint(op); 146 for (OpOperand *opOperand : outOfPlaceOpOperands) { 147 Value copy = allocateTensorForShapedValue( 148 rewriter, op->getLoc(), opOperand->get(), 149 escapingOpOperandCopies.contains(opOperand), 150 copiedOpOperands.contains(opOperand)); 151 rewriter.updateRootInPlace(op, [&]() { opOperand->set(copy); }); 152 } 153 154 // Insert copies of OpResults. 155 rewriter.setInsertionPointAfter(op); 156 for (OpResult opResult : outOfPlaceOpResults) { 157 Value copy = 158 allocateTensorForShapedValue(rewriter, op->getLoc(), opResult, 159 escapingOpResultCopies.contains(opResult), 160 copiedOpResults.count(opResult)); 161 SmallVector<OpOperand *> uses = llvm::to_vector(llvm::map_range( 162 opResult.getUses(), [](OpOperand &use) { return &use; })); 163 for (OpOperand *use : uses) { 164 // Do not update the alloc_tensor op that we just created. 165 if (use->getOwner() != copy.getDefiningOp()) 166 rewriter.updateRootInPlace(use->getOwner(), [&]() { use->set(copy); }); 167 } 168 } 169 170 return success(); 171 } 172 173 //===----------------------------------------------------------------------===// 174 // OpFilter 175 //===----------------------------------------------------------------------===// 176 177 bool OpFilter::isOpAllowed(Operation *op) const { 178 // All other ops: Allow/disallow according to filter. 
179 bool isAllowed = !hasAllowRule(); 180 for (const Entry &entry : entries) { 181 bool filterResult = entry.fn(op); 182 switch (entry.type) { 183 case Entry::ALLOW: 184 isAllowed |= filterResult; 185 break; 186 case Entry::DENY: 187 if (filterResult) 188 // DENY filter matches. This op is no allowed. (Even if other ALLOW 189 // filters may match.) 190 return false; 191 }; 192 } 193 return isAllowed; 194 } 195 196 //===----------------------------------------------------------------------===// 197 // BufferizationOptions 198 //===----------------------------------------------------------------------===// 199 200 // Default constructor for BufferizationOptions. 201 BufferizationOptions::BufferizationOptions() = default; 202 203 bool BufferizationOptions::isOpAllowed(Operation *op) const { 204 // Special case: If function boundary bufferization is deactivated, do not 205 // allow ops that belong to the `func` dialect. 206 bool isFuncBoundaryOp = isa_and_nonnull<func::FuncDialect>(op->getDialect()); 207 if (!bufferizeFunctionBoundaries && isFuncBoundaryOp) 208 return false; 209 210 return opFilter.isOpAllowed(op); 211 } 212 213 BufferizableOpInterface 214 BufferizationOptions::dynCastBufferizableOp(Operation *op) const { 215 auto bufferizableOp = dyn_cast<BufferizableOpInterface>(op); 216 if (!bufferizableOp) 217 return nullptr; 218 if (!isOpAllowed(op)) 219 return nullptr; 220 return bufferizableOp; 221 } 222 223 BufferizableOpInterface 224 BufferizationOptions::dynCastBufferizableOp(Value value) const { 225 if (auto bufferizableOp = value.getDefiningOp<BufferizableOpInterface>()) 226 if (isOpAllowed(bufferizableOp.getOperation())) 227 return bufferizableOp; 228 return nullptr; 229 } 230 231 void BufferizationOptions::addDialectStateInitializer( 232 StringRef name, const DialectStateInitFn &fn) { 233 stateInitializers.push_back( 234 [=](AnalysisState &state) { state.insertDialectState(name, fn()); }); 235 } 236 237 
//===----------------------------------------------------------------------===// 238 // Helper functions for BufferizableOpInterface 239 //===----------------------------------------------------------------------===// 240 241 static void setInsertionPointAfter(OpBuilder &b, Value value) { 242 if (auto bbArg = value.dyn_cast<BlockArgument>()) { 243 b.setInsertionPointToStart(bbArg.getOwner()); 244 } else { 245 b.setInsertionPointAfter(value.getDefiningOp()); 246 } 247 } 248 249 /// Determine which OpOperand* will alias with `result` if the op is bufferized 250 /// in place. Return an empty vector if the op is not bufferizable. 251 SmallVector<OpOperand *> 252 AnalysisState::getAliasingOpOperand(OpResult result) const { 253 if (Operation *op = result.getDefiningOp()) 254 if (auto bufferizableOp = getOptions().dynCastBufferizableOp(op)) 255 return bufferizableOp.getAliasingOpOperand(result, *this); 256 return {}; 257 } 258 259 /// Determine which OpResult will alias with `opOperand` if the op is bufferized 260 /// in place. Return an empty vector if the op is not bufferizable. 261 SmallVector<OpResult> 262 AnalysisState::getAliasingOpResult(OpOperand &opOperand) const { 263 if (auto bufferizableOp = 264 getOptions().dynCastBufferizableOp(opOperand.getOwner())) 265 return bufferizableOp.getAliasingOpResult(opOperand, *this); 266 return {}; 267 } 268 269 /// Return true if `opOperand` bufferizes to a memory read. Return `true` if the 270 /// op is not bufferizable. 271 bool AnalysisState::bufferizesToMemoryRead(OpOperand &opOperand) const { 272 if (auto bufferizableOp = 273 getOptions().dynCastBufferizableOp(opOperand.getOwner())) 274 return bufferizableOp.bufferizesToMemoryRead(opOperand, *this); 275 276 // Unknown op that returns a tensor. The inplace analysis does not support it. 277 // Conservatively return true. 278 return true; 279 } 280 281 /// Return true if `opOperand` bufferizes to a memory write. Return 282 /// `true` if the op is not bufferizable. 
283 bool AnalysisState::bufferizesToMemoryWrite(OpOperand &opOperand) const { 284 if (auto bufferizableOp = 285 getOptions().dynCastBufferizableOp(opOperand.getOwner())) 286 return bufferizableOp.bufferizesToMemoryWrite(opOperand, *this); 287 288 // Unknown op that returns a tensor. The inplace analysis does not support it. 289 // Conservatively return true. 290 return true; 291 } 292 293 /// Return true if `opOperand` does neither read nor write but bufferizes to an 294 /// alias. Return false if the op is not bufferizable. 295 bool AnalysisState::bufferizesToAliasOnly(OpOperand &opOperand) const { 296 if (auto bufferizableOp = 297 getOptions().dynCastBufferizableOp(opOperand.getOwner())) 298 return bufferizableOp.bufferizesToAliasOnly(opOperand, *this); 299 300 // Unknown op that returns a tensor. The inplace analysis does not support it. 301 // Conservatively return false. 302 return false; 303 } 304 305 /// Return true if the given value is read by an op that bufferizes to a memory 306 /// read. Also takes into account ops that create an alias but do not read by 307 /// themselves (e.g., ExtractSliceOp). 308 bool AnalysisState::isValueRead(Value value) const { 309 assert(value.getType().isa<TensorType>() && "expected TensorType"); 310 SmallVector<OpOperand *> workingSet; 311 for (OpOperand &use : value.getUses()) 312 workingSet.push_back(&use); 313 314 while (!workingSet.empty()) { 315 OpOperand *uMaybeReading = workingSet.pop_back_val(); 316 // Skip over all ops that neither read nor write (but create an alias). 317 if (bufferizesToAliasOnly(*uMaybeReading)) 318 for (OpResult opResult : getAliasingOpResult(*uMaybeReading)) 319 for (OpOperand &use : opResult.getUses()) 320 workingSet.push_back(&use); 321 if (bufferizesToMemoryRead(*uMaybeReading)) 322 return true; 323 } 324 325 return false; 326 } 327 328 // Starting from `value`, follow the use-def chain in reverse, always selecting 329 // the aliasing OpOperands. 
Find and return Values for which `condition` 330 // evaluates to true. OpOperands of such matching Values are not traversed any 331 // further. 332 llvm::SetVector<Value> AnalysisState::findValueInReverseUseDefChain( 333 Value value, llvm::function_ref<bool(Value)> condition) const { 334 llvm::SetVector<Value> result, workingSet; 335 workingSet.insert(value); 336 337 while (!workingSet.empty()) { 338 Value value = workingSet.pop_back_val(); 339 if (condition(value) || value.isa<BlockArgument>()) { 340 result.insert(value); 341 continue; 342 } 343 344 OpResult opResult = value.cast<OpResult>(); 345 SmallVector<OpOperand *> opOperands = getAliasingOpOperand(opResult); 346 if (opOperands.empty() || !options.isOpAllowed(value.getDefiningOp())) { 347 result.insert(value); 348 continue; 349 } 350 351 for (OpOperand *o : opOperands) 352 workingSet.insert(o->get()); 353 } 354 355 return result; 356 } 357 358 // Find the Values of the last preceding write of a given Value. 359 llvm::SetVector<Value> 360 AnalysisState::findLastPrecedingWrite(Value value) const { 361 return findValueInReverseUseDefChain(value, [&](Value value) { 362 Operation *op = value.getDefiningOp(); 363 if (!op) 364 return true; 365 auto bufferizableOp = options.dynCastBufferizableOp(op); 366 if (!bufferizableOp) 367 return true; 368 return bufferizableOp.isMemoryWrite(value.cast<OpResult>(), *this); 369 }); 370 } 371 372 AnalysisState::AnalysisState(const BufferizationOptions &options) 373 : options(options) { 374 for (const BufferizationOptions::AnalysisStateInitFn &fn : 375 options.stateInitializers) 376 fn(*this); 377 } 378 379 bool AnalysisState::canOmitTensorCopy(OpOperand &opOperand) const { 380 // Do not copy if the tensor has undefined contents. 381 if (hasUndefinedContents(&opOperand)) 382 return true; 383 384 // Do not copy if the buffer of the tensor is entirely overwritten (with 385 // values that do not depend on the old tensor). 
386 if (bufferizesToMemoryWrite(opOperand) && !bufferizesToMemoryRead(opOperand)) 387 return true; 388 389 // Do not copy if the tensor is never read. 390 SmallVector<OpResult> aliasingOpResults = getAliasingOpResult(opOperand); 391 if (!bufferizesToMemoryRead(opOperand) && 392 llvm::none_of(aliasingOpResults, 393 [&](OpResult opResult) { return isValueRead(opResult); })) 394 return true; 395 396 // Default: Cannot omit the copy. 397 return false; 398 } 399 400 bool AnalysisState::isInPlace(OpOperand &opOperand) const { 401 // ToMemrefOps are always in-place. 402 if (isa<ToMemrefOp>(opOperand.getOwner())) 403 return true; 404 405 // In the absence of analysis information, OpOperands that bufferize to a 406 // memory write are out-of-place, i.e., an alloc and copy is inserted. 407 return !bufferizesToMemoryWrite(opOperand); 408 } 409 410 bool AnalysisState::areEquivalentBufferizedValues(Value v1, Value v2) const { 411 // In the absence of analysis information, we do not know if the values are 412 // equivalent. The conservative answer is "false". 413 return false; 414 } 415 416 bool AnalysisState::areAliasingBufferizedValues(Value v1, Value v2) const { 417 // In the absence of analysis information, we do not know if the values may be 418 // aliasing. The conservative answer is "true". 419 return true; 420 } 421 422 bool AnalysisState::hasUndefinedContents(OpOperand *opOperand) const { 423 // In the absence of analysis information, the conservative answer is "false". 424 return false; 425 } 426 427 bool AnalysisState::isTensorYielded(Value tensor) const { 428 // In the absence of analysis information, the conservative answer is "true". 429 if (!tensor.getDefiningOp<AllocTensorOp>()) 430 return true; 431 432 // For AllocTensorOp results, we can do better: They do not alias with any 433 // preceding value, so we can follow SSA use-def chains and do a simple 434 // analysis. 
435 SmallVector<OpOperand *> worklist; 436 for (OpOperand &use : tensor.getUses()) 437 worklist.push_back(&use); 438 439 while (!worklist.empty()) { 440 OpOperand *operand = worklist.pop_back_val(); 441 Operation *op = operand->getOwner(); 442 443 // If the op is not bufferizable, we can safely assume that the value is not 444 // yielded. (When bufferizing that op, it must handle such cases.) 445 if (!options.dynCastBufferizableOp(op)) 446 continue; 447 448 // We cannot analyze through ToMemrefOps, so we have to conservatively 449 // assume that the value is yielded. 450 if (isa<ToMemrefOp>(op)) 451 return true; 452 453 // Check if the op is returning/yielding. 454 if (isRegionReturnLike(op)) 455 return true; 456 457 // Add all aliasing OpResults to the worklist. 458 // Note: In the absence of detailed analysis information (e.g., there may be 459 // no function call analysis information), this `getAliasingOpResult` is 460 // conservative and may report additional OpResults as potentially aliasing. 461 for (OpResult opResult : getAliasingOpResult(*operand)) 462 for (OpOperand &use : opResult.getUses()) 463 worklist.push_back(&use); 464 } 465 466 // No ReturnLike op found: The value is not yielded. 467 return false; 468 } 469 470 // bufferization.to_memref is not allowed to change the rank. 471 static void ensureToMemrefOpIsValid(Value tensor, Type memrefType) { 472 #ifndef NDEBUG 473 auto rankedTensorType = tensor.getType().dyn_cast<RankedTensorType>(); 474 assert((!rankedTensorType || memrefType.cast<MemRefType>().getRank() == 475 rankedTensorType.getRank()) && 476 "to_memref would be invalid: mismatching ranks"); 477 #endif 478 } 479 480 Value bufferization::getBuffer(RewriterBase &rewriter, Value value, 481 const BufferizationOptions &options) { 482 auto tensorType = value.getType().dyn_cast<TensorType>(); 483 assert(tensorType && "unexpected non-tensor type"); 484 485 // Replace "%t = to_tensor %m" with %m. 
486 if (auto toTensorOp = value.getDefiningOp<bufferization::ToTensorOp>()) 487 return toTensorOp.getMemref(); 488 489 // Insert to_memref op. 490 OpBuilder::InsertionGuard g(rewriter); 491 setInsertionPointAfter(rewriter, value); 492 Type memrefType = getMemRefType(tensorType, options); 493 ensureToMemrefOpIsValid(value, memrefType); 494 return rewriter.create<bufferization::ToMemrefOp>(value.getLoc(), memrefType, 495 value); 496 } 497 498 /// Return the buffer type for a given Value (tensor) after bufferization. 499 BaseMemRefType 500 bufferization::getBufferType(Value value, const BufferizationOptions &options) { 501 auto tensorType = value.getType().dyn_cast<TensorType>(); 502 assert(tensorType && "unexpected non-tensor type"); 503 504 if (auto toTensorOp = value.getDefiningOp<bufferization::ToTensorOp>()) 505 return toTensorOp.getMemref().getType().cast<BaseMemRefType>(); 506 507 return getMemRefType(tensorType, options); 508 } 509 510 void bufferization::replaceOpWithBufferizedValues(RewriterBase &rewriter, 511 Operation *op, 512 ValueRange values) { 513 assert(values.size() == op->getNumResults() && 514 "expected one value per OpResult"); 515 OpBuilder::InsertionGuard g(rewriter); 516 517 // Replace all OpResults with the given values. 518 SmallVector<Value> replacements; 519 for (OpResult opResult : op->getOpResults()) { 520 Value replacement = values[opResult.getResultNumber()]; 521 if (opResult.getType().isa<TensorType>()) { 522 // The OpResult is a tensor. Such values are replaced with memrefs during 523 // bufferization. 524 assert((replacement.getType().isa<MemRefType>() || 525 replacement.getType().isa<UnrankedMemRefType>()) && 526 "tensor op result should be replaced with a memref value"); 527 // The existing uses of the OpResult still expect a tensor. Insert a 528 // ToTensorOp. Throughout bufferization, this ToTensorOp will gradually 529 // loose all of its users and eventually DCE away. 
530 rewriter.setInsertionPointAfter(op); 531 replacement = rewriter.create<bufferization::ToTensorOp>( 532 replacement.getLoc(), replacement); 533 } 534 replacements.push_back(replacement); 535 } 536 537 rewriter.replaceOp(op, replacements); 538 } 539 540 //===----------------------------------------------------------------------===// 541 // Bufferization-specific scoped alloc/dealloc insertion support. 542 //===----------------------------------------------------------------------===// 543 544 /// Create a memref allocation with the given type and dynamic extents. 545 FailureOr<Value> BufferizationOptions::createAlloc(OpBuilder &b, Location loc, 546 MemRefType type, 547 ValueRange dynShape) const { 548 if (allocationFn) 549 return (*allocationFn)(b, loc, type, dynShape, bufferAlignment); 550 551 // Default bufferallocation via AllocOp. 552 if (bufferAlignment != 0) 553 return b 554 .create<memref::AllocOp>(loc, type, dynShape, 555 b.getI64IntegerAttr(bufferAlignment)) 556 .getResult(); 557 return b.create<memref::AllocOp>(loc, type, dynShape).getResult(); 558 } 559 560 /// Creates a memref deallocation. The given memref buffer must have been 561 /// allocated using `createAlloc`. 562 LogicalResult BufferizationOptions::createDealloc(OpBuilder &b, Location loc, 563 Value allocatedBuffer) const { 564 if (deallocationFn) 565 return (*deallocationFn)(b, loc, allocatedBuffer); 566 567 // Default buffer deallocation via DeallocOp. 568 b.create<memref::DeallocOp>(loc, allocatedBuffer); 569 return success(); 570 } 571 572 /// Create a memory copy between two memref buffers. 
573 LogicalResult BufferizationOptions::createMemCpy(OpBuilder &b, Location loc, 574 Value from, Value to) const { 575 if (memCpyFn) 576 return (*memCpyFn)(b, loc, from, to); 577 578 b.create<memref::CopyOp>(loc, from, to); 579 return success(); 580 } 581 582 //===----------------------------------------------------------------------===// 583 // Bufferization-specific BlockAndValueMapping support with debugging. 584 //===----------------------------------------------------------------------===// 585 586 bool bufferization::isFunctionArgument(Value value) { 587 auto bbArg = value.dyn_cast<BlockArgument>(); 588 if (!bbArg) 589 return false; 590 return isa<func::FuncOp>(bbArg.getOwner()->getParentOp()); 591 } 592 593 BaseMemRefType bufferization::getMemRefType(TensorType tensorType, 594 const BufferizationOptions &options, 595 MemRefLayoutAttrInterface layout, 596 Attribute memorySpace) { 597 // Case 1: Unranked memref type. 598 if (auto unrankedTensorType = tensorType.dyn_cast<UnrankedTensorType>()) { 599 assert(!layout && "UnrankedTensorType cannot have a layout map"); 600 return UnrankedMemRefType::get(unrankedTensorType.getElementType(), 601 memorySpace); 602 } 603 604 // Case 2: Ranked memref type with specified layout. 605 auto rankedTensorType = tensorType.cast<RankedTensorType>(); 606 if (layout) { 607 return MemRefType::get(rankedTensorType.getShape(), 608 rankedTensorType.getElementType(), layout, 609 memorySpace); 610 } 611 612 // Case 3: Configured with "fully dynamic layout maps". 613 if (options.unknownTypeConversion == 614 BufferizationOptions::LayoutMapOption::FullyDynamicLayoutMap) 615 return getMemRefTypeWithFullyDynamicLayout(tensorType, memorySpace); 616 617 // Case 4: Configured with "static identity layout maps". 
618 if (options.unknownTypeConversion == 619 BufferizationOptions::LayoutMapOption::IdentityLayoutMap) 620 return getMemRefTypeWithStaticIdentityLayout(tensorType, memorySpace); 621 622 llvm_unreachable("InferLayoutMap is an invalid option"); 623 } 624 625 BaseMemRefType 626 bufferization::getMemRefTypeWithFullyDynamicLayout(TensorType tensorType, 627 Attribute memorySpace) { 628 // Case 1: Unranked memref type. 629 if (auto unrankedTensorType = tensorType.dyn_cast<UnrankedTensorType>()) { 630 return UnrankedMemRefType::get(unrankedTensorType.getElementType(), 631 memorySpace); 632 } 633 634 // Case 2: Ranked memref type. 635 auto rankedTensorType = tensorType.cast<RankedTensorType>(); 636 int64_t dynamicOffset = ShapedType::kDynamicStrideOrOffset; 637 SmallVector<int64_t> dynamicStrides(rankedTensorType.getRank(), 638 ShapedType::kDynamicStrideOrOffset); 639 AffineMap stridedLayout = makeStridedLinearLayoutMap( 640 dynamicStrides, dynamicOffset, rankedTensorType.getContext()); 641 return MemRefType::get(rankedTensorType.getShape(), 642 rankedTensorType.getElementType(), stridedLayout, 643 memorySpace); 644 } 645 646 /// Return a MemRef type with a static identity layout (i.e., no layout map). If 647 /// the given tensor type is unranked, return an unranked MemRef type. 648 BaseMemRefType 649 bufferization::getMemRefTypeWithStaticIdentityLayout(TensorType tensorType, 650 Attribute memorySpace) { 651 // Case 1: Unranked memref type. 652 if (auto unrankedTensorType = tensorType.dyn_cast<UnrankedTensorType>()) { 653 return UnrankedMemRefType::get(unrankedTensorType.getElementType(), 654 memorySpace); 655 } 656 657 // Case 2: Ranked memref type. 658 auto rankedTensorType = tensorType.cast<RankedTensorType>(); 659 MemRefLayoutAttrInterface layout = {}; 660 return MemRefType::get(rankedTensorType.getShape(), 661 rankedTensorType.getElementType(), layout, 662 memorySpace); 663 } 664