1 //===- BufferizableOpInterface.cpp - Bufferizable Ops ---=----------------===// 2 // 3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. 4 // See https://llvm.org/LICENSE.txt for license information. 5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception 6 // 7 //===----------------------------------------------------------------------===// 8 9 #include "mlir/Dialect/Bufferization/IR/BufferizableOpInterface.h" 10 #include "mlir/Dialect/Bufferization/IR/Bufferization.h" 11 #include "mlir/Dialect/Func/IR/FuncOps.h" 12 #include "mlir/Dialect/MemRef/IR/MemRef.h" 13 #include "mlir/IR/AsmState.h" 14 #include "mlir/IR/BlockAndValueMapping.h" 15 #include "mlir/IR/BuiltinOps.h" 16 #include "mlir/IR/Operation.h" 17 #include "mlir/IR/TypeUtilities.h" 18 #include "mlir/IR/Value.h" 19 #include "llvm/Support/Debug.h" 20 21 //===----------------------------------------------------------------------===// 22 // BufferizableOpInterface 23 //===----------------------------------------------------------------------===// 24 25 namespace mlir { 26 namespace bufferization { 27 28 #include "mlir/Dialect/Bufferization/IR/BufferizableOpInterface.cpp.inc" 29 30 } // namespace bufferization 31 } // namespace mlir 32 33 #define DEBUG_TYPE "bufferizable-op-interface" 34 #define DBGS() (llvm::dbgs() << '[' << DEBUG_TYPE << "] ") 35 #define LDBG(X) LLVM_DEBUG(DBGS() << (X)) 36 37 using namespace mlir; 38 using namespace bufferization; 39 40 /// Attribute name used to mark region arguments that can be bufferized 41 /// in-place during linalg comprehensive bufferization. 42 constexpr const ::llvm::StringLiteral 43 bufferization::BufferizableOpInterface::kInplaceableAttrName; 44 45 LogicalResult BufferizableOpInterface::resolveTensorOpOperandConflicts( 46 RewriterBase &rewriter, const AnalysisState &state) { 47 OpBuilder::InsertionGuard g(rewriter); 48 Operation *op = getOperation(); 49 SmallVector<OpOperand *> outOfPlaceOpOperands; 50 SmallVector<OpResult> outOfPlaceOpResults; 51 52 // Find all out-of-place OpOperands. 53 for (OpOperand &opOperand : op->getOpOperands()) { 54 Type operandType = opOperand.get().getType(); 55 if (!operandType.isa<TensorType>()) 56 continue; 57 if (state.isInPlace(opOperand)) 58 continue; 59 if (operandType.isa<UnrankedTensorType>()) 60 return op->emitError("copies of unranked tensors are not supported"); 61 62 SmallVector<OpResult> aliasingOpResults = 63 state.getAliasingOpResult(opOperand); 64 if (aliasingOpResults.size() == 1 && 65 !state.bufferizesToMemoryWrite(opOperand) && 66 state.getAliasingOpOperand(aliasingOpResults.front()).size() == 1) { 67 // The op itself does not write but may create exactly one alias. Instead 68 // of copying the OpOperand, copy the OpResult. The OpResult can sometimes 69 // be smaller than the OpOperand (e.g., in the case of an extract_slice, 70 // where the result is usually a smaller part of the source). 71 outOfPlaceOpResults.push_back(aliasingOpResults.front()); 72 } else { 73 // In all other cases, make a copy of the OpOperand. 74 outOfPlaceOpOperands.push_back(&opOperand); 75 } 76 } 77 78 // Insert copies of OpOperands. 79 rewriter.setInsertionPoint(op); 80 for (OpOperand *opOperand : outOfPlaceOpOperands) { 81 auto tensorType = opOperand->get().getType().cast<RankedTensorType>(); 82 SmallVector<OpResult> aliasingOpResults = 83 state.getAliasingOpResult(*opOperand); 84 bool escape = llvm::any_of( 85 aliasingOpResults, [&](Value v) { return state.isTensorYielded(v); }); 86 Value copy = rewriter.create<AllocTensorOp>( 87 op->getLoc(), tensorType, ValueRange(), opOperand->get(), escape); 88 rewriter.updateRootInPlace(op, [&]() { opOperand->set(copy); }); 89 } 90 91 // Insert copies of OpResults. 92 rewriter.setInsertionPointAfter(op); 93 for (OpResult opResult : outOfPlaceOpResults) { 94 auto tensorType = opResult.getType().cast<RankedTensorType>(); 95 bool escape = state.isTensorYielded(opResult); 96 Value copy = rewriter.create<AllocTensorOp>(op->getLoc(), tensorType, 97 ValueRange(), opResult, escape); 98 SmallVector<OpOperand *> uses = llvm::to_vector(llvm::map_range( 99 opResult.getUses(), [](OpOperand &use) { return &use; })); 100 for (OpOperand *use : uses) { 101 // Do not update the alloc_tensor op that we just created. 102 if (use->getOwner() != copy.getDefiningOp()) 103 rewriter.updateRootInPlace(use->getOwner(), [&]() { use->set(copy); }); 104 } 105 } 106 107 return success(); 108 } 109 110 //===----------------------------------------------------------------------===// 111 // OpFilter 112 //===----------------------------------------------------------------------===// 113 114 bool OpFilter::isOpAllowed(Operation *op) const { 115 // All other ops: Allow/disallow according to filter. 116 bool isAllowed = !hasAllowRule(); 117 for (const Entry &entry : entries) { 118 bool filterResult = entry.fn(op); 119 switch (entry.type) { 120 case Entry::ALLOW: 121 isAllowed |= filterResult; 122 break; 123 case Entry::DENY: 124 if (filterResult) 125 // DENY filter matches. This op is no allowed. (Even if other ALLOW 126 // filters may match.) 127 return false; 128 }; 129 } 130 return isAllowed; 131 } 132 133 //===----------------------------------------------------------------------===// 134 // BufferizationOptions 135 //===----------------------------------------------------------------------===// 136 137 // Default constructor for BufferizationOptions. 138 BufferizationOptions::BufferizationOptions() = default; 139 140 bool BufferizationOptions::isOpAllowed(Operation *op) const { 141 // Special case: If function boundary bufferization is deactivated, do not 142 // allow ops that belong to the `func` dialect. 143 bool isFuncBoundaryOp = isa_and_nonnull<func::FuncDialect>(op->getDialect()); 144 if (!bufferizeFunctionBoundaries && isFuncBoundaryOp) 145 return false; 146 147 return opFilter.isOpAllowed(op); 148 } 149 150 BufferizableOpInterface 151 BufferizationOptions::dynCastBufferizableOp(Operation *op) const { 152 auto bufferizableOp = dyn_cast<BufferizableOpInterface>(op); 153 if (!bufferizableOp) 154 return nullptr; 155 if (!isOpAllowed(op)) 156 return nullptr; 157 return bufferizableOp; 158 } 159 160 BufferizableOpInterface 161 BufferizationOptions::dynCastBufferizableOp(Value value) const { 162 if (auto bufferizableOp = value.getDefiningOp<BufferizableOpInterface>()) 163 if (isOpAllowed(bufferizableOp.getOperation())) 164 return bufferizableOp; 165 return nullptr; 166 } 167 168 void BufferizationOptions::addDialectStateInitializer( 169 StringRef name, const DialectStateInitFn &fn) { 170 stateInitializers.push_back( 171 [=](AnalysisState &state) { state.insertDialectState(name, fn()); }); 172 } 173 174 //===----------------------------------------------------------------------===// 175 // Helper functions for BufferizableOpInterface 176 //===----------------------------------------------------------------------===// 177 178 static void setInsertionPointAfter(OpBuilder &b, Value value) { 179 if (auto bbArg = value.dyn_cast<BlockArgument>()) { 180 b.setInsertionPointToStart(bbArg.getOwner()); 181 } else { 182 b.setInsertionPointAfter(value.getDefiningOp()); 183 } 184 } 185 186 /// Determine which OpOperand* will alias with `result` if the op is bufferized 187 /// in place. Return an empty vector if the op is not bufferizable. 188 SmallVector<OpOperand *> 189 AnalysisState::getAliasingOpOperand(OpResult result) const { 190 if (Operation *op = result.getDefiningOp()) 191 if (auto bufferizableOp = getOptions().dynCastBufferizableOp(op)) 192 return bufferizableOp.getAliasingOpOperand(result, *this); 193 return {}; 194 } 195 196 /// Determine which OpResult will alias with `opOperand` if the op is bufferized 197 /// in place. Return an empty vector if the op is not bufferizable. 198 SmallVector<OpResult> 199 AnalysisState::getAliasingOpResult(OpOperand &opOperand) const { 200 if (auto bufferizableOp = 201 getOptions().dynCastBufferizableOp(opOperand.getOwner())) 202 return bufferizableOp.getAliasingOpResult(opOperand, *this); 203 return {}; 204 } 205 206 /// Return true if `opOperand` bufferizes to a memory read. Return `true` if the 207 /// op is not bufferizable. 208 bool AnalysisState::bufferizesToMemoryRead(OpOperand &opOperand) const { 209 if (auto bufferizableOp = 210 getOptions().dynCastBufferizableOp(opOperand.getOwner())) 211 return bufferizableOp.bufferizesToMemoryRead(opOperand, *this); 212 213 // Unknown op that returns a tensor. The inplace analysis does not support it. 214 // Conservatively return true. 215 return true; 216 } 217 218 /// Return true if `opOperand` bufferizes to a memory write. Return 219 /// `true` if the op is not bufferizable. 220 bool AnalysisState::bufferizesToMemoryWrite(OpOperand &opOperand) const { 221 if (auto bufferizableOp = 222 getOptions().dynCastBufferizableOp(opOperand.getOwner())) 223 return bufferizableOp.bufferizesToMemoryWrite(opOperand, *this); 224 225 // Unknown op that returns a tensor. The inplace analysis does not support it. 226 // Conservatively return true. 227 return true; 228 } 229 230 /// Return true if `opOperand` does neither read nor write but bufferizes to an 231 /// alias. Return false if the op is not bufferizable. 232 bool AnalysisState::bufferizesToAliasOnly(OpOperand &opOperand) const { 233 if (auto bufferizableOp = 234 getOptions().dynCastBufferizableOp(opOperand.getOwner())) 235 return bufferizableOp.bufferizesToAliasOnly(opOperand, *this); 236 237 // Unknown op that returns a tensor. The inplace analysis does not support it. 238 // Conservatively return false. 239 return false; 240 } 241 242 /// Return true if the given value is read by an op that bufferizes to a memory 243 /// read. Also takes into account ops that create an alias but do not read by 244 /// themselves (e.g., ExtractSliceOp). 245 bool AnalysisState::isValueRead(Value value) const { 246 assert(value.getType().isa<TensorType>() && "expected TensorType"); 247 SmallVector<OpOperand *> workingSet; 248 for (OpOperand &use : value.getUses()) 249 workingSet.push_back(&use); 250 251 while (!workingSet.empty()) { 252 OpOperand *uMaybeReading = workingSet.pop_back_val(); 253 // Skip over all ops that neither read nor write (but create an alias). 254 if (bufferizesToAliasOnly(*uMaybeReading)) 255 for (OpResult opResult : getAliasingOpResult(*uMaybeReading)) 256 for (OpOperand &use : opResult.getUses()) 257 workingSet.push_back(&use); 258 if (bufferizesToMemoryRead(*uMaybeReading)) 259 return true; 260 } 261 262 return false; 263 } 264 265 // Starting from `value`, follow the use-def chain in reverse, always selecting 266 // the aliasing OpOperands. Find and return Values for which `condition` 267 // evaluates to true. OpOperands of such matching Values are not traversed any 268 // further. 269 llvm::SetVector<Value> AnalysisState::findValueInReverseUseDefChain( 270 Value value, llvm::function_ref<bool(Value)> condition) const { 271 llvm::SetVector<Value> result, workingSet; 272 workingSet.insert(value); 273 274 while (!workingSet.empty()) { 275 Value value = workingSet.pop_back_val(); 276 if (condition(value) || value.isa<BlockArgument>()) { 277 result.insert(value); 278 continue; 279 } 280 281 OpResult opResult = value.cast<OpResult>(); 282 SmallVector<OpOperand *> opOperands = getAliasingOpOperand(opResult); 283 if (opOperands.empty() || !options.isOpAllowed(value.getDefiningOp())) { 284 result.insert(value); 285 continue; 286 } 287 288 for (OpOperand *o : opOperands) 289 workingSet.insert(o->get()); 290 } 291 292 return result; 293 } 294 295 // Find the Values of the last preceding write of a given Value. 296 llvm::SetVector<Value> 297 AnalysisState::findLastPrecedingWrite(Value value) const { 298 return findValueInReverseUseDefChain(value, [&](Value value) { 299 Operation *op = value.getDefiningOp(); 300 if (!op) 301 return true; 302 auto bufferizableOp = options.dynCastBufferizableOp(op); 303 if (!bufferizableOp) 304 return true; 305 return bufferizableOp.isMemoryWrite(value.cast<OpResult>(), *this); 306 }); 307 } 308 309 AnalysisState::AnalysisState(const BufferizationOptions &options) 310 : options(options) { 311 for (const BufferizationOptions::AnalysisStateInitFn &fn : 312 options.stateInitializers) 313 fn(*this); 314 } 315 316 // bufferization.to_memref is not allowed to change the rank. 317 static void ensureToMemrefOpIsValid(Value tensor, Type memrefType) { 318 #ifndef NDEBUG 319 auto rankedTensorType = tensor.getType().dyn_cast<RankedTensorType>(); 320 assert((!rankedTensorType || memrefType.cast<MemRefType>().getRank() == 321 rankedTensorType.getRank()) && 322 "to_memref would be invalid: mismatching ranks"); 323 #endif 324 } 325 326 Value mlir::bufferization::lookupBuffer(RewriterBase &rewriter, Value tensor, 327 const BufferizationOptions &options) { 328 auto tensorType = tensor.getType().dyn_cast<TensorType>(); 329 assert(tensorType && "unexpected non-tensor type"); 330 331 // Replace "%t = to_tensor %m" with %m. 332 if (auto toTensorOp = tensor.getDefiningOp<bufferization::ToTensorOp>()) 333 return toTensorOp.memref(); 334 335 // Insert to_memref op. 336 OpBuilder::InsertionGuard g(rewriter); 337 setInsertionPointAfter(rewriter, tensor); 338 Type memrefType = getMemRefType(tensorType, options); 339 ensureToMemrefOpIsValid(tensor, memrefType); 340 return rewriter.create<bufferization::ToMemrefOp>(tensor.getLoc(), memrefType, 341 tensor); 342 } 343 344 /// Return the buffer (memref) for a given OpOperand (tensor). Allocate 345 /// a new buffer and copy over data from the existing buffer if out-of-place 346 /// bufferization was decided. 347 FailureOr<Value> 348 BufferizationState::getBuffer(RewriterBase &rewriter, OpOperand &opOperand, 349 Optional<ForceInPlacability> overrideInPlace, 350 Optional<Operation *> customCopyInsertionPoint) { 351 const BufferizationOptions &options = analysisState.getOptions(); 352 OpBuilder::InsertionGuard guard(rewriter); 353 Operation *op = opOperand.getOwner(); 354 Location loc = op->getLoc(); 355 SmallVector<OpResult> aliasingOpResults = 356 analysisState.getAliasingOpResult(opOperand); 357 Value operand = opOperand.get(); 358 Value operandBuffer = lookupBuffer(rewriter, operand, options); 359 360 // Can `operandBuffer` be used directly or do we need a copy? 361 bool inplace = 362 overrideInPlace != FORCE_OUT_OF_PLACE && 363 (overrideInPlace == FORCE_INPLACE || analysisState.isInPlace(opOperand)); 364 if (inplace) 365 return operandBuffer; 366 367 // Bufferizing out-of-place: Allocate a new buffer. 368 // Move insertion point right after `operandBuffer`. That is where the 369 // allocation should be inserted (in the absence of allocation hoisting). 370 setInsertionPointAfter(rewriter, operandBuffer); 371 // Allocate the result buffer. The buffer should be deallocated if the tensor 372 // is not yielded and deallocs are enabled in general. 373 bool dealloc = llvm::none_of(aliasingOpResults, [&](Value v) { 374 return getAnalysisState().isTensorYielded(v); 375 }); 376 FailureOr<Value> resultBuffer = createAlloc( 377 rewriter, loc, operandBuffer, dealloc && getOptions().createDeallocs); 378 if (failed(resultBuffer)) 379 return failure(); 380 // Do not copy the buffer if its contents are undefined. 381 if (analysisState.hasUndefinedContents(&opOperand)) 382 return resultBuffer; 383 // Do not copy if the copied data is never read. 384 if (!aliasingOpResults.empty() && 385 !analysisState.bufferizesToMemoryRead(opOperand) && 386 llvm::none_of(aliasingOpResults, [&](OpResult opResult) { 387 return analysisState.isValueRead(opResult); 388 })) 389 return resultBuffer; 390 // Do not copy if this op does not read the data, but writes it. 391 if (analysisState.bufferizesToMemoryWrite(opOperand) && 392 !analysisState.bufferizesToMemoryRead(opOperand)) 393 return resultBuffer; 394 395 if (customCopyInsertionPoint) { 396 rewriter.setInsertionPoint(*customCopyInsertionPoint); 397 } else { 398 // The copy happens right before the op that is bufferized. 399 rewriter.setInsertionPoint(op); 400 } 401 if (failed(options.createMemCpy(rewriter, loc, operandBuffer, *resultBuffer))) 402 return failure(); 403 404 return resultBuffer; 405 } 406 407 /// Return the buffer type for a given Value (tensor) after bufferization. 408 BaseMemRefType BufferizationState::getBufferType(Value value) const { 409 auto tensorType = value.getType().dyn_cast<TensorType>(); 410 assert(tensorType && "unexpected non-tensor type"); 411 412 if (auto toTensorOp = value.getDefiningOp<bufferization::ToTensorOp>()) 413 return toTensorOp.memref().getType().cast<BaseMemRefType>(); 414 415 return getMemRefType(tensorType, getOptions()); 416 } 417 418 void bufferization::replaceOpWithBufferizedValues(RewriterBase &rewriter, 419 Operation *op, 420 ValueRange values) { 421 assert(values.size() == op->getNumResults() && 422 "expected one value per OpResult"); 423 OpBuilder::InsertionGuard g(rewriter); 424 425 // Replace all OpResults with the given values. 426 SmallVector<Value> replacements; 427 for (OpResult opResult : op->getOpResults()) { 428 Value replacement = values[opResult.getResultNumber()]; 429 if (opResult.getType().isa<TensorType>()) { 430 // The OpResult is a tensor. Such values are replaced with memrefs during 431 // bufferization. 432 assert((replacement.getType().isa<MemRefType>() || 433 replacement.getType().isa<UnrankedMemRefType>()) && 434 "tensor op result should be replaced with a memref value"); 435 // The existing uses of the OpResult still expect a tensor. Insert a 436 // ToTensorOp. Throughout bufferization, this ToTensorOp will gradually 437 // loose all of its users and eventually DCE away. 438 rewriter.setInsertionPointAfter(op); 439 replacement = rewriter.create<bufferization::ToTensorOp>( 440 replacement.getLoc(), replacement); 441 } 442 replacements.push_back(replacement); 443 } 444 445 rewriter.replaceOp(op, replacements); 446 } 447 448 //===----------------------------------------------------------------------===// 449 // Bufferization-specific scoped alloc/dealloc insertion support. 450 //===----------------------------------------------------------------------===// 451 452 /// Create a memref allocation with the given type and dynamic extents. 453 FailureOr<Value> BufferizationOptions::createAlloc(OpBuilder &b, Location loc, 454 MemRefType type, 455 ValueRange dynShape) const { 456 if (allocationFn) 457 return (*allocationFn)(b, loc, type, dynShape, bufferAlignment); 458 459 // Default bufferallocation via AllocOp. 460 Value allocated = b.create<memref::AllocOp>( 461 loc, type, dynShape, b.getI64IntegerAttr(bufferAlignment)); 462 return allocated; 463 } 464 465 /// Creates a memref deallocation. The given memref buffer must have been 466 /// allocated using `createAlloc`. 467 LogicalResult BufferizationOptions::createDealloc(OpBuilder &b, Location loc, 468 Value allocatedBuffer) const { 469 if (deallocationFn) 470 return (*deallocationFn)(b, loc, allocatedBuffer); 471 472 // Default buffer deallocation via DeallocOp. 473 b.create<memref::DeallocOp>(loc, allocatedBuffer); 474 return success(); 475 } 476 477 static MemRefType getContiguousMemRefType(ShapedType shapedType, 478 Attribute memorySpace = {}) { 479 MemRefLayoutAttrInterface layout = {}; 480 return MemRefType::get(shapedType.getShape(), shapedType.getElementType(), 481 layout, memorySpace); 482 } 483 484 /// Compute the type of the `memref` to use for allocating the buffer for 485 /// `shapedValue`. Also returns (by reference in `dynShape`), the value for the 486 /// dynamic dimensions in the returned `memref` type. 487 static MemRefType getAllocationTypeAndShape(OpBuilder &b, Location loc, 488 Value shapedValue, 489 SmallVectorImpl<Value> &dynShape) { 490 MemRefType allocMemRefType = 491 getContiguousMemRefType(shapedValue.getType().cast<ShapedType>()); 492 493 // Compute the dynamic part of the shape. 494 bool reifiedShapes = false; 495 if (auto rankedOp = dyn_cast_or_null<ReifyRankedShapedTypeOpInterface>( 496 shapedValue.getDefiningOp())) { 497 ReifiedRankedShapedTypeDims resultDims; 498 if (succeeded(rankedOp.reifyResultShapes(b, resultDims))) { 499 reifiedShapes = true; 500 OpResult resultValue = shapedValue.dyn_cast<OpResult>(); 501 auto &shape = resultDims[resultValue.getResultNumber()]; 502 for (const auto &dim : enumerate(allocMemRefType.getShape())) 503 if (ShapedType::isDynamic(dim.value())) 504 dynShape.push_back(shape[dim.index()]); 505 } 506 } 507 508 if (!reifiedShapes) { 509 for (const auto &dim : enumerate(allocMemRefType.getShape())) 510 if (ShapedType::isDynamic(dim.value())) { 511 assert((shapedValue.getType().isa<UnrankedMemRefType>() || 512 shapedValue.getType().isa<MemRefType>()) && 513 "expected MemRef type"); 514 dynShape.push_back( 515 b.create<memref::DimOp>(loc, shapedValue, dim.index())); 516 } 517 } 518 519 return allocMemRefType; 520 } 521 522 /// Create an allocation after `shapedValue.getDefiningOp` (or at the top of the 523 /// block in case of a bbArg). 524 FailureOr<Value> BufferizationState::createAlloc(OpBuilder &b, Location loc, 525 Value shapedValue, 526 Optional<bool> dealloc) { 527 // Take a guard before anything else. 528 OpBuilder::InsertionGuard g(b); 529 530 // Compute allocation memref type. 531 assert(shapedValue.getType().isa<ShapedType>()); 532 SmallVector<Value> dynShape; 533 MemRefType allocMemRefType = 534 getAllocationTypeAndShape(b, loc, shapedValue, dynShape); 535 536 // Create the buffer allocation. 537 FailureOr<Value> buffer = 538 getOptions().createAlloc(b, loc, allocMemRefType, dynShape); 539 if (failed(buffer)) 540 return failure(); 541 542 // Should be the buffer be deallocated again or should we let it leak? 543 if (dealloc) { 544 if (!dealloc.getValue()) 545 return *buffer; 546 } else { 547 assert(shapedValue.getType().isa<TensorType>() && 548 "must specify `dealloc` if non-tensor value is passed"); 549 // Buffer should be not be deallocated if deallocs are generally deactivated 550 // or if the tensor is yielded from a block. 551 if (!getOptions().createDeallocs || 552 getAnalysisState().isTensorYielded(shapedValue)) 553 return *buffer; 554 } 555 556 // Create buffer deallocation. 557 b.setInsertionPoint(b.getInsertionBlock()->getTerminator()); 558 if (failed(getOptions().createDealloc(b, loc, *buffer))) 559 return failure(); 560 561 return *buffer; 562 } 563 564 /// Create a memory copy between two memref buffers. 565 LogicalResult BufferizationOptions::createMemCpy(OpBuilder &b, Location loc, 566 Value from, Value to) const { 567 if (memCpyFn) 568 return (*memCpyFn)(b, loc, from, to); 569 570 b.create<memref::CopyOp>(loc, from, to); 571 return success(); 572 } 573 574 //===----------------------------------------------------------------------===// 575 // Bufferization-specific BlockAndValueMapping support with debugging. 576 //===----------------------------------------------------------------------===// 577 578 bool bufferization::isFunctionArgument(Value value) { 579 auto bbArg = value.dyn_cast<BlockArgument>(); 580 if (!bbArg) 581 return false; 582 return isa<func::FuncOp>(bbArg.getOwner()->getParentOp()); 583 } 584 585 BaseMemRefType bufferization::getMemRefType(TensorType tensorType, 586 const BufferizationOptions &options, 587 MemRefLayoutAttrInterface layout, 588 Attribute memorySpace) { 589 // Case 1: Unranked memref type. 590 if (auto unrankedTensorType = tensorType.dyn_cast<UnrankedTensorType>()) { 591 assert(!layout && "UnrankedTensorType cannot have a layout map"); 592 return UnrankedMemRefType::get(unrankedTensorType.getElementType(), 593 memorySpace); 594 } 595 596 // Case 2: Ranked memref type with specified layout. 597 auto rankedTensorType = tensorType.cast<RankedTensorType>(); 598 if (layout) { 599 return MemRefType::get(rankedTensorType.getShape(), 600 rankedTensorType.getElementType(), layout, 601 memorySpace); 602 } 603 604 // Case 3: Configured with "fully dynamic layout maps". 605 if (options.unknownTypeConversion == 606 BufferizationOptions::LayoutMapOption::FullyDynamicLayoutMap) 607 return getMemRefTypeWithFullyDynamicLayout(tensorType, memorySpace); 608 609 // Case 4: Configured with "static identity layout maps". 610 if (options.unknownTypeConversion == 611 BufferizationOptions::LayoutMapOption::IdentityLayoutMap) 612 return getMemRefTypeWithStaticIdentityLayout(tensorType, memorySpace); 613 614 llvm_unreachable("InferLayoutMap is an invalid option"); 615 } 616 617 BaseMemRefType 618 bufferization::getMemRefTypeWithFullyDynamicLayout(TensorType tensorType, 619 Attribute memorySpace) { 620 // Case 1: Unranked memref type. 621 if (auto unrankedTensorType = tensorType.dyn_cast<UnrankedTensorType>()) { 622 return UnrankedMemRefType::get(unrankedTensorType.getElementType(), 623 memorySpace); 624 } 625 626 // Case 2: Ranked memref type. 627 auto rankedTensorType = tensorType.cast<RankedTensorType>(); 628 int64_t dynamicOffset = ShapedType::kDynamicStrideOrOffset; 629 SmallVector<int64_t> dynamicStrides(rankedTensorType.getRank(), 630 ShapedType::kDynamicStrideOrOffset); 631 AffineMap stridedLayout = makeStridedLinearLayoutMap( 632 dynamicStrides, dynamicOffset, rankedTensorType.getContext()); 633 return MemRefType::get(rankedTensorType.getShape(), 634 rankedTensorType.getElementType(), stridedLayout, 635 memorySpace); 636 } 637 638 /// Return a MemRef type with a static identity layout (i.e., no layout map). If 639 /// the given tensor type is unranked, return an unranked MemRef type. 640 BaseMemRefType 641 bufferization::getMemRefTypeWithStaticIdentityLayout(TensorType tensorType, 642 Attribute memorySpace) { 643 // Case 1: Unranked memref type. 644 if (auto unrankedTensorType = tensorType.dyn_cast<UnrankedTensorType>()) { 645 return UnrankedMemRefType::get(unrankedTensorType.getElementType(), 646 memorySpace); 647 } 648 649 // Case 2: Ranked memref type. 650 auto rankedTensorType = tensorType.cast<RankedTensorType>(); 651 MemRefLayoutAttrInterface layout = {}; 652 return MemRefType::get(rankedTensorType.getShape(), 653 rankedTensorType.getElementType(), layout, 654 memorySpace); 655 } 656