//===- BufferizableOpInterface.cpp - Bufferizable Ops --------------------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//

#include "mlir/Dialect/Bufferization/IR/BufferizableOpInterface.h"
#include "mlir/Dialect/Bufferization/IR/Bufferization.h"
#include "mlir/Dialect/Func/IR/FuncOps.h"
#include "mlir/Dialect/MemRef/IR/MemRef.h"
#include "mlir/Dialect/Tensor/IR/Tensor.h"
#include "mlir/IR/AsmState.h"
#include "mlir/IR/BlockAndValueMapping.h"
#include "mlir/IR/BuiltinOps.h"
#include "mlir/IR/Operation.h"
#include "mlir/IR/TypeUtilities.h"
#include "mlir/IR/Value.h"
#include "llvm/Support/Debug.h"

//===----------------------------------------------------------------------===//
// BufferizableOpInterface
//===----------------------------------------------------------------------===//

namespace mlir {
namespace bufferization {

#include "mlir/Dialect/Bufferization/IR/BufferizableOpInterface.cpp.inc"

} // namespace bufferization
} // namespace mlir

#define DEBUG_TYPE "bufferizable-op-interface"
#define DBGS() (llvm::dbgs() << '[' << DEBUG_TYPE << "] ")
#define LDBG(X) LLVM_DEBUG(DBGS() << (X))

using namespace mlir;
using namespace bufferization;

/// Attribute name used to mark region arguments that can be bufferized
/// in-place during linalg comprehensive bufferization.
constexpr const ::llvm::StringLiteral
    bufferization::BufferizableOpInterface::kInplaceableAttrName;

/// Create an AllocTensorOp for the given shaped value. If `copy` is set, the
/// shaped value is copied. Otherwise, a tensor with undefined contents is
/// allocated.
FailureOr<Value> bufferization::allocateTensorForShapedValue(
    OpBuilder &b, Location loc, Value shapedValue, bool escape,
    const BufferizationOptions &options, bool copy) {
  Value tensor;
  if (shapedValue.getType().isa<RankedTensorType>()) {
    tensor = shapedValue;
  } else if (shapedValue.getType().isa<MemRefType>()) {
    tensor = b.create<ToTensorOp>(loc, shapedValue);
  } else {
    llvm_unreachable("expected RankedTensorType or MemRefType");
  }
  RankedTensorType tensorType = tensor.getType().cast<RankedTensorType>();
  SmallVector<Value> dynamicSizes;
  if (!copy) {
    // Compute the dynamic part of the shape.
    // First try to query the shape via ReifyRankedShapedTypeOpInterface.
    bool reifiedShapes = false;
    if (shapedValue.getType().isa<RankedTensorType>() &&
        shapedValue.isa<OpResult>()) {
      if (auto rankedOp = dyn_cast_or_null<ReifyRankedShapedTypeOpInterface>(
              shapedValue.getDefiningOp())) {
        ReifiedRankedShapedTypeDims resultDims;
        if (succeeded(rankedOp.reifyResultShapes(b, resultDims))) {
          reifiedShapes = true;
          auto &shape =
              resultDims[shapedValue.cast<OpResult>().getResultNumber()];
          for (const auto &dim : enumerate(tensorType.getShape()))
            if (ShapedType::isDynamic(dim.value()))
              dynamicSizes.push_back(shape[dim.index()]);
        }
      }
    }

    // If the shape could not be reified, create DimOps.
    if (!reifiedShapes)
      populateDynamicDimSizes(b, loc, tensor, dynamicSizes);
  }

  auto allocTensorOp = b.create<AllocTensorOp>(loc, tensorType, dynamicSizes,
                                               copy ? tensor : Value());
  allocTensorOp->setAttr(BufferizationDialect::kEscapeAttrName,
                         b.getBoolArrayAttr({escape}));
  return allocTensorOp.getResult();
}

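/// Resolve all out-of-place tensor OpOperands of this op (as determined by the
/// given analysis state) by inserting explicit copies in the form of
/// AllocTensorOps. If the op does not write and creates only a single alias,
/// the aliasing OpResult is copied instead of the OpOperand, which can result
/// in a smaller copy (e.g., for tensor.extract_slice).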
LogicalResult BufferizableOpInterface::resolveTensorOpOperandConflicts(
    RewriterBase &rewriter, const AnalysisState &state) {
  OpBuilder::InsertionGuard g(rewriter);
  Operation *op = getOperation();
  SmallVector<OpOperand *> outOfPlaceOpOperands;
  DenseSet<OpOperand *> copiedOpOperands;
  DenseSet<OpOperand *> escapingOpOperandCopies;
  SmallVector<OpResult> outOfPlaceOpResults;
  DenseSet<OpResult> copiedOpResults;
  DenseSet<OpResult> escapingOpResultCopies;

  // Find all out-of-place OpOperands.
  for (OpOperand &opOperand : op->getOpOperands()) {
    Type operandType = opOperand.get().getType();
    if (!operandType.isa<TensorType>())
      continue;
    if (state.isInPlace(opOperand))
      continue;
    if (operandType.isa<UnrankedTensorType>())
      return op->emitError("copies of unranked tensors are not supported");

    SmallVector<OpResult> aliasingOpResults =
        state.getAliasingOpResult(opOperand);
    // Is the result yielded from a block? Or are deallocations turned off
    // entirely? In either case, mark the allocation as "escaping", so that it
    // will not be deallocated.
    bool escape = !state.getOptions().createDeallocs ||
                  llvm::any_of(aliasingOpResults, [&](Value v) {
                    return state.isTensorYielded(v);
                  });

    if (aliasingOpResults.size() == 1 &&
        !state.bufferizesToMemoryWrite(opOperand) &&
        state.getAliasingOpOperand(aliasingOpResults.front()).size() == 1) {
      // The op itself does not write but may create exactly one alias. Instead
      // of copying the OpOperand, copy the OpResult. The OpResult can
      // sometimes be smaller than the OpOperand (e.g., in the case of an
      // extract_slice, where the result is usually a smaller part of the
      // source).
      outOfPlaceOpResults.push_back(aliasingOpResults.front());
      if (!state.canOmitTensorCopy(opOperand))
        copiedOpResults.insert(aliasingOpResults.front());
      if (escape)
        escapingOpResultCopies.insert(aliasingOpResults.front());
    } else {
      // In all other cases, make a copy of the OpOperand.
      outOfPlaceOpOperands.push_back(&opOperand);
      if (!state.canOmitTensorCopy(opOperand))
        copiedOpOperands.insert(&opOperand);
      if (escape)
        escapingOpOperandCopies.insert(&opOperand);
    }
  }

  // Insert copies of OpOperands.
  rewriter.setInsertionPoint(op);
  for (OpOperand *opOperand : outOfPlaceOpOperands) {
    FailureOr<Value> copy = allocateTensorForShapedValue(
        rewriter, op->getLoc(), opOperand->get(),
        escapingOpOperandCopies.contains(opOperand), state.getOptions(),
        copiedOpOperands.contains(opOperand));
    if (failed(copy))
      return failure();
    rewriter.updateRootInPlace(op, [&]() { opOperand->set(*copy); });
  }

  // Insert copies of OpResults.
  rewriter.setInsertionPointAfter(op);
  for (OpResult opResult : outOfPlaceOpResults) {
    FailureOr<Value> copy = allocateTensorForShapedValue(
        rewriter, op->getLoc(), opResult,
        escapingOpResultCopies.contains(opResult), state.getOptions(),
        copiedOpResults.contains(opResult));
    if (failed(copy))
      return failure();
    SmallVector<OpOperand *> uses = llvm::to_vector(llvm::map_range(
        opResult.getUses(), [](OpOperand &use) { return &use; }));
    for (OpOperand *use : uses) {
      // Do not update the alloc_tensor op that we just created.
      if (use->getOwner() != copy->getDefiningOp())
        rewriter.updateRootInPlace(use->getOwner(),
                                   [&]() { use->set(*copy); });
    }
  }

  return success();
}

//===----------------------------------------------------------------------===//
// OpFilter
//===----------------------------------------------------------------------===//

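/// Return `true` if the given op is allowed by this filter: an op is allowed
/// if it matches at least one ALLOW entry (or if there are no ALLOW entries at
/// all) and does not match any DENY entry. DENY entries take precedence over
/// ALLOW entries.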
bool OpFilter::isOpAllowed(Operation *op) const {
  // Ops are allowed by default unless at least one ALLOW rule is specified.
  bool isAllowed = !hasAllowRule();
  for (const Entry &entry : entries) {
    bool filterResult = entry.fn(op);
    switch (entry.type) {
    case Entry::ALLOW:
      isAllowed |= filterResult;
      break;
    case Entry::DENY:
      if (filterResult)
        // DENY filter matches. This op is not allowed. (Even if other ALLOW
        // filters may match.)
        return false;
    }
  }
  return isAllowed;
}

//===----------------------------------------------------------------------===//
// BufferizationOptions
//===----------------------------------------------------------------------===//

// Default constructor for BufferizationOptions.
BufferizationOptions::BufferizationOptions() = default;

bool BufferizationOptions::isOpAllowed(Operation *op) const {
  // Special case: If function boundary bufferization is deactivated, do not
  // allow ops that belong to the `func` dialect.
  bool isFuncBoundaryOp = isa_and_nonnull<func::FuncDialect>(op->getDialect());
  if (!bufferizeFunctionBoundaries && isFuncBoundaryOp)
    return false;

  return opFilter.isOpAllowed(op);
}

BufferizableOpInterface
BufferizationOptions::dynCastBufferizableOp(Operation *op) const {
  auto bufferizableOp = dyn_cast<BufferizableOpInterface>(op);
  if (!bufferizableOp)
    return nullptr;
  if (!isOpAllowed(op))
    return nullptr;
  return bufferizableOp;
}

BufferizableOpInterface
BufferizationOptions::dynCastBufferizableOp(Value value) const {
  if (auto bufferizableOp = value.getDefiningOp<BufferizableOpInterface>())
    if (isOpAllowed(bufferizableOp.getOperation()))
      return bufferizableOp;
  return nullptr;
}

void BufferizationOptions::addDialectStateInitializer(
    StringRef name, const DialectStateInitFn &fn) {
  stateInitializers.push_back(
      [=](AnalysisState &state) { state.insertDialectState(name, fn()); });
}

//===----------------------------------------------------------------------===//
// Helper functions for BufferizableOpInterface
//===----------------------------------------------------------------------===//

static void setInsertionPointAfter(OpBuilder &b, Value value) {
  if (auto bbArg = value.dyn_cast<BlockArgument>()) {
    b.setInsertionPointToStart(bbArg.getOwner());
  } else {
    b.setInsertionPointAfter(value.getDefiningOp());
  }
}

/// Determine which OpOperand* will alias with `result` if the op is bufferized
/// in place. Return an empty vector if the op is not bufferizable.
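///
/// Example (illustrative): for %r = tensor.insert_slice %src into %dest, the
/// OpOperand aliasing with %r is the destination operand %dest; %src is only
/// read and does not alias with the result.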
SmallVector<OpOperand *>
AnalysisState::getAliasingOpOperand(OpResult result) const {
  if (Operation *op = result.getDefiningOp())
    if (auto bufferizableOp = getOptions().dynCastBufferizableOp(op))
      return bufferizableOp.getAliasingOpOperand(result, *this);
  return {};
}

/// Determine which OpResult will alias with `opOperand` if the op is
/// bufferized in place. Return an empty vector if the op is not bufferizable.
SmallVector<OpResult>
AnalysisState::getAliasingOpResult(OpOperand &opOperand) const {
  if (auto bufferizableOp =
          getOptions().dynCastBufferizableOp(opOperand.getOwner()))
    return bufferizableOp.getAliasingOpResult(opOperand, *this);
  return {};
}

/// Return true if `opOperand` bufferizes to a memory read. Return `true` if
/// the op is not bufferizable.
bool AnalysisState::bufferizesToMemoryRead(OpOperand &opOperand) const {
  if (auto bufferizableOp =
          getOptions().dynCastBufferizableOp(opOperand.getOwner()))
    return bufferizableOp.bufferizesToMemoryRead(opOperand, *this);

  // Unknown op that returns a tensor. The inplace analysis does not support
  // it. Conservatively return true.
  return true;
}

/// Return true if `opOperand` bufferizes to a memory write. Return `true` if
/// the op is not bufferizable.
bool AnalysisState::bufferizesToMemoryWrite(OpOperand &opOperand) const {
  if (auto bufferizableOp =
          getOptions().dynCastBufferizableOp(opOperand.getOwner()))
    return bufferizableOp.bufferizesToMemoryWrite(opOperand, *this);

  // Unknown op that returns a tensor. The inplace analysis does not support
  // it. Conservatively return true.
  return true;
}

/// Return true if `opOperand` does neither read nor write but bufferizes to an
/// alias. Return false if the op is not bufferizable.
bool AnalysisState::bufferizesToAliasOnly(OpOperand &opOperand) const {
  if (auto bufferizableOp =
          getOptions().dynCastBufferizableOp(opOperand.getOwner()))
    return bufferizableOp.bufferizesToAliasOnly(opOperand, *this);

  // Unknown op that returns a tensor. The inplace analysis does not support
  // it. Conservatively return false.
  return false;
}

/// Return true if the given value is read by an op that bufferizes to a memory
/// read. Also takes into account ops that create an alias but do not read by
/// themselves (e.g., ExtractSliceOp).
bool AnalysisState::isValueRead(Value value) const {
  assert(value.getType().isa<TensorType>() && "expected TensorType");
  SmallVector<OpOperand *> workingSet;
  for (OpOperand &use : value.getUses())
    workingSet.push_back(&use);

  while (!workingSet.empty()) {
    OpOperand *uMaybeReading = workingSet.pop_back_val();
    // Skip over all ops that neither read nor write (but create an alias).
    if (bufferizesToAliasOnly(*uMaybeReading))
      for (OpResult opResult : getAliasingOpResult(*uMaybeReading))
        for (OpOperand &use : opResult.getUses())
          workingSet.push_back(&use);
    if (bufferizesToMemoryRead(*uMaybeReading))
      return true;
  }

  return false;
}

// Starting from `value`, follow the use-def chain in reverse, always selecting
// the aliasing OpOperands. Find and return Values for which `condition`
// evaluates to true. OpOperands of such matching Values are not traversed any
// further.
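//
// Example (illustrative): given
//   %1 = tensor.extract_slice %0 ...
//   %2 = tensor.insert %v into %1 ...
// a traversal starting at %2 visits %1 (the aliasing destination operand of
// the tensor.insert) and then %0 (the source of the extract_slice), stopping
// at the first value for which `condition` holds or at a block argument.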
llvm::SetVector<Value> AnalysisState::findValueInReverseUseDefChain(
    Value value, llvm::function_ref<bool(Value)> condition) const {
  llvm::SetVector<Value> result, workingSet;
  workingSet.insert(value);

  while (!workingSet.empty()) {
    Value value = workingSet.pop_back_val();
    if (condition(value) || value.isa<BlockArgument>()) {
      result.insert(value);
      continue;
    }

    OpResult opResult = value.cast<OpResult>();
    SmallVector<OpOperand *> opOperands = getAliasingOpOperand(opResult);
    if (opOperands.empty() || !options.isOpAllowed(value.getDefiningOp())) {
      result.insert(value);
      continue;
    }

    for (OpOperand *o : opOperands)
      workingSet.insert(o->get());
  }

  return result;
}

// Find the Values of the last preceding write of a given Value.
llvm::SetVector<Value>
AnalysisState::findLastPrecedingWrite(Value value) const {
  return findValueInReverseUseDefChain(value, [&](Value value) {
    Operation *op = value.getDefiningOp();
    if (!op)
      return true;
    auto bufferizableOp = options.dynCastBufferizableOp(op);
    if (!bufferizableOp)
      return true;
    return bufferizableOp.isMemoryWrite(value.cast<OpResult>(), *this);
  });
}

AnalysisState::AnalysisState(const BufferizationOptions &options)
    : options(options) {
  for (const BufferizationOptions::AnalysisStateInitFn &fn :
       options.stateInitializers)
    fn(*this);
}

bool AnalysisState::canOmitTensorCopy(OpOperand &opOperand) const {
  // Do not copy if the tensor has undefined contents.
  if (hasUndefinedContents(&opOperand))
    return true;

  // Do not copy if the buffer of the tensor is entirely overwritten (with
  // values that do not depend on the old tensor).
  if (bufferizesToMemoryWrite(opOperand) && !bufferizesToMemoryRead(opOperand))
    return true;

  // Do not copy if the tensor is never read.
  SmallVector<OpResult> aliasingOpResults = getAliasingOpResult(opOperand);
  if (!bufferizesToMemoryRead(opOperand) &&
      llvm::none_of(aliasingOpResults,
                    [&](OpResult opResult) { return isValueRead(opResult); }))
    return true;

  // Default: Cannot omit the copy.
  return false;
}

bool AnalysisState::isInPlace(OpOperand &opOperand) const {
  // ToMemrefOps are always in-place.
  if (isa<ToMemrefOp>(opOperand.getOwner()))
    return true;

  // In the absence of analysis information, OpOperands that bufferize to a
  // memory write are out-of-place, i.e., an alloc and copy is inserted.
  return !bufferizesToMemoryWrite(opOperand);
}

bool AnalysisState::areEquivalentBufferizedValues(Value v1, Value v2) const {
  // In the absence of analysis information, we do not know if the values are
  // equivalent. The conservative answer is "false".
  return false;
}

bool AnalysisState::areAliasingBufferizedValues(Value v1, Value v2) const {
  // In the absence of analysis information, we do not know if the values may
  // be aliasing. The conservative answer is "true".
  return true;
}

bool AnalysisState::hasUndefinedContents(OpOperand *opOperand) const {
  // In the absence of analysis information, the conservative answer is
  // "false".
  return false;
}

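/// Return `true` if the given tensor value may be yielded or returned from a
/// block, in which case the underlying buffer allocation must not be
/// deallocated. This is conservative in the absence of analysis information;
/// only AllocTensorOp results are analyzed more precisely below.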
bool AnalysisState::isTensorYielded(Value tensor) const {
  // In the absence of analysis information, the conservative answer is "true".
  if (!tensor.getDefiningOp<AllocTensorOp>())
    return true;

  // For AllocTensorOp results, we can do better: They do not alias with any
  // preceding value, so we can follow SSA use-def chains and do a simple
  // analysis.
  SmallVector<OpOperand *> worklist;
  for (OpOperand &use : tensor.getUses())
    worklist.push_back(&use);

  while (!worklist.empty()) {
    OpOperand *operand = worklist.pop_back_val();
    Operation *op = operand->getOwner();

    // If the op is not bufferizable, we can safely assume that the value is
    // not yielded. (When bufferizing that op, it must handle such cases.)
    if (!options.dynCastBufferizableOp(op))
      continue;

    // We cannot analyze through ToMemrefOps, so we have to conservatively
    // assume that the value is yielded.
    if (isa<ToMemrefOp>(op))
      return true;

    // Check if the op is returning/yielding.
    if (isRegionReturnLike(op))
      return true;

    // Add all aliasing OpResults to the worklist.
    // Note: In the absence of detailed analysis information (e.g., there may
    // be no function call analysis information), this `getAliasingOpResult`
    // is conservative and may report additional OpResults as potentially
    // aliasing.
    for (OpResult opResult : getAliasingOpResult(*operand))
      for (OpOperand &use : opResult.getUses())
        worklist.push_back(&use);
  }

  // No ReturnLike op found: The value is not yielded.
  return false;
}

// bufferization.to_memref is not allowed to change the rank.
static void ensureToMemrefOpIsValid(Value tensor, Type memrefType) {
#ifndef NDEBUG
  auto rankedTensorType = tensor.getType().dyn_cast<RankedTensorType>();
  assert((!rankedTensorType || memrefType.cast<MemRefType>().getRank() ==
                                   rankedTensorType.getRank()) &&
         "to_memref would be invalid: mismatching ranks");
#endif
}

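/// Lookup the buffer for the given tensor value. If the value was produced by
/// a ToTensorOp, return that op's memref operand. Otherwise, insert a new
/// ToMemrefOp at a suitable insertion point, e.g. (illustrative):
///   %m = bufferization.to_memref %t : memref<?xf32>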
FailureOr<Value> bufferization::getBuffer(RewriterBase &rewriter, Value value,
                                          const BufferizationOptions &options) {
#ifndef NDEBUG
  auto tensorType = value.getType().dyn_cast<TensorType>();
  assert(tensorType && "unexpected non-tensor type");
#endif // NDEBUG

  // Replace "%t = to_tensor %m" with %m.
  if (auto toTensorOp = value.getDefiningOp<bufferization::ToTensorOp>())
    return toTensorOp.getMemref();

  // Insert to_memref op.
  OpBuilder::InsertionGuard g(rewriter);
  setInsertionPointAfter(rewriter, value);
  FailureOr<BaseMemRefType> memrefType = getBufferType(value, options);
  if (failed(memrefType))
    return failure();
  ensureToMemrefOpIsValid(value, *memrefType);
  return rewriter
      .create<bufferization::ToMemrefOp>(value.getLoc(), *memrefType, value)
      .getResult();
}

/// Return the buffer type for a given Value (tensor) after bufferization.
FailureOr<BaseMemRefType>
bufferization::getBufferType(Value value, const BufferizationOptions &options) {
  auto tensorType = value.getType().dyn_cast<TensorType>();
  assert(tensorType && "unexpected non-tensor type");

  if (auto toTensorOp = value.getDefiningOp<bufferization::ToTensorOp>())
    return toTensorOp.getMemref().getType().cast<BaseMemRefType>();

  if (auto bbArg = value.dyn_cast<BlockArgument>())
    if (auto bufferizableOp =
            options.dynCastBufferizableOp(bbArg.getOwner()->getParentOp()))
      return bufferizableOp.getBufferType(bbArg, options);

  return getMemRefType(tensorType, options);
}

void bufferization::replaceOpWithBufferizedValues(RewriterBase &rewriter,
                                                  Operation *op,
                                                  ValueRange values) {
  assert(values.size() == op->getNumResults() &&
         "expected one value per OpResult");
  OpBuilder::InsertionGuard g(rewriter);

  // Replace all OpResults with the given values.
  SmallVector<Value> replacements;
  for (OpResult opResult : op->getOpResults()) {
    Value replacement = values[opResult.getResultNumber()];
    if (opResult.getType().isa<TensorType>()) {
      // The OpResult is a tensor. Such values are replaced with memrefs during
      // bufferization.
      assert((replacement.getType().isa<MemRefType>() ||
              replacement.getType().isa<UnrankedMemRefType>()) &&
             "tensor op result should be replaced with a memref value");
      // The existing uses of the OpResult still expect a tensor. Insert a
      // ToTensorOp. Throughout bufferization, this ToTensorOp will gradually
      // lose all of its users and eventually DCE away.
      rewriter.setInsertionPointAfter(op);
      replacement = rewriter.create<bufferization::ToTensorOp>(
          replacement.getLoc(), replacement);
    }
    replacements.push_back(replacement);
  }

  rewriter.replaceOp(op, replacements);
}

//===----------------------------------------------------------------------===//
// Bufferization-specific scoped alloc/dealloc insertion support.
//===----------------------------------------------------------------------===//

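// Note: Clients can override the default allocation/deallocation/copy
// behavior by setting `allocationFn`, `deallocationFn` and `memCpyFn` in
// BufferizationOptions. Illustrative sketch only (the lambda signature mirrors
// the call sites below; the alignment parameter is ignored here):
//
//   options.allocationFn = [](OpBuilder &b, Location loc, MemRefType type,
//                             ValueRange dynShape,
//                             unsigned /*alignment*/) -> FailureOr<Value> {
//     return b.create<memref::AllocaOp>(loc, type, dynShape).getResult();
//   };
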
/// Create a memref allocation with the given type and dynamic extents.
FailureOr<Value> BufferizationOptions::createAlloc(OpBuilder &b, Location loc,
                                                   MemRefType type,
                                                   ValueRange dynShape) const {
  if (allocationFn)
    return (*allocationFn)(b, loc, type, dynShape, bufferAlignment);

  // Default buffer allocation via AllocOp.
  if (bufferAlignment != 0)
    return b
        .create<memref::AllocOp>(loc, type, dynShape,
                                 b.getI64IntegerAttr(bufferAlignment))
        .getResult();
  return b.create<memref::AllocOp>(loc, type, dynShape).getResult();
}

/// Creates a memref deallocation. The given memref buffer must have been
/// allocated using `createAlloc`.
LogicalResult BufferizationOptions::createDealloc(OpBuilder &b, Location loc,
                                                  Value allocatedBuffer) const {
  if (deallocationFn)
    return (*deallocationFn)(b, loc, allocatedBuffer);

  // Default buffer deallocation via DeallocOp.
  b.create<memref::DeallocOp>(loc, allocatedBuffer);
  return success();
}

/// Create a memory copy between two memref buffers.
LogicalResult BufferizationOptions::createMemCpy(OpBuilder &b, Location loc,
                                                 Value from, Value to) const {
  if (memCpyFn)
    return (*memCpyFn)(b, loc, from, to);

  b.create<memref::CopyOp>(loc, from, to);
  return success();
}

//===----------------------------------------------------------------------===//
// Bufferization-specific BlockAndValueMapping support with debugging.
//===----------------------------------------------------------------------===//

bool bufferization::isFunctionArgument(Value value) {
  auto bbArg = value.dyn_cast<BlockArgument>();
  if (!bbArg)
    return false;
  return isa<func::FuncOp>(bbArg.getOwner()->getParentOp());
}

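/// Return a MemRef type for the given tensor type, taking the given layout,
/// memory space and conversion options into account. Examples (illustrative):
///   - tensor<*xf32> -> memref<*xf32>
///   - tensor<4xf32> with an explicit `layout` -> memref<4xf32, layout>
///   - otherwise, a fully dynamic or static identity layout, depending on
///     `options.unknownTypeConversion`.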
BaseMemRefType bufferization::getMemRefType(TensorType tensorType,
                                            const BufferizationOptions &options,
                                            MemRefLayoutAttrInterface layout,
                                            unsigned memorySpace) {
  auto memorySpaceAttr = IntegerAttr::get(
      IntegerType::get(tensorType.getContext(), 64), memorySpace);

  // Case 1: Unranked memref type.
  if (auto unrankedTensorType = tensorType.dyn_cast<UnrankedTensorType>()) {
    assert(!layout && "UnrankedTensorType cannot have a layout map");
    return UnrankedMemRefType::get(unrankedTensorType.getElementType(),
                                   memorySpaceAttr);
  }

  // Case 2: Ranked memref type with specified layout.
  auto rankedTensorType = tensorType.cast<RankedTensorType>();
  if (layout) {
    return MemRefType::get(rankedTensorType.getShape(),
                           rankedTensorType.getElementType(), layout,
                           memorySpaceAttr);
  }

  // Case 3: Configured with "fully dynamic layout maps".
  if (options.unknownTypeConversion ==
      BufferizationOptions::LayoutMapOption::FullyDynamicLayoutMap)
    return getMemRefTypeWithFullyDynamicLayout(tensorType, memorySpace);

  // Case 4: Configured with "static identity layout maps".
  if (options.unknownTypeConversion ==
      BufferizationOptions::LayoutMapOption::IdentityLayoutMap)
    return getMemRefTypeWithStaticIdentityLayout(tensorType, memorySpace);

  llvm_unreachable("InferLayoutMap is an invalid option");
}

BaseMemRefType
bufferization::getMemRefTypeWithFullyDynamicLayout(TensorType tensorType,
                                                   unsigned memorySpace) {
  // Case 1: Unranked memref type.
  if (auto unrankedTensorType = tensorType.dyn_cast<UnrankedTensorType>()) {
    return UnrankedMemRefType::get(unrankedTensorType.getElementType(),
                                   memorySpace);
  }

  // Case 2: Ranked memref type.
  auto memorySpaceAttr = IntegerAttr::get(
      IntegerType::get(tensorType.getContext(), 64), memorySpace);
  auto rankedTensorType = tensorType.cast<RankedTensorType>();
  int64_t dynamicOffset = ShapedType::kDynamicStrideOrOffset;
  SmallVector<int64_t> dynamicStrides(rankedTensorType.getRank(),
                                      ShapedType::kDynamicStrideOrOffset);
  AffineMap stridedLayout = makeStridedLinearLayoutMap(
      dynamicStrides, dynamicOffset, rankedTensorType.getContext());
  return MemRefType::get(rankedTensorType.getShape(),
                         rankedTensorType.getElementType(), stridedLayout,
                         memorySpaceAttr);
}

/// Return a MemRef type with a static identity layout (i.e., no layout map).
/// If the given tensor type is unranked, return an unranked MemRef type.
BaseMemRefType
bufferization::getMemRefTypeWithStaticIdentityLayout(TensorType tensorType,
                                                     unsigned memorySpace) {
  // Case 1: Unranked memref type.
  if (auto unrankedTensorType = tensorType.dyn_cast<UnrankedTensorType>()) {
    return UnrankedMemRefType::get(unrankedTensorType.getElementType(),
                                   memorySpace);
  }

  // Case 2: Ranked memref type.
  auto rankedTensorType = tensorType.cast<RankedTensorType>();
  auto memorySpaceAttr = IntegerAttr::get(
      IntegerType::get(tensorType.getContext(), 64), memorySpace);
  MemRefLayoutAttrInterface layout = {};
  return MemRefType::get(rankedTensorType.getShape(),
                         rankedTensorType.getElementType(), layout,
                         memorySpaceAttr);
}