//===- BufferizableOpInterface.cpp - Bufferizable Ops --------------------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//

#include "mlir/Dialect/Bufferization/IR/BufferizableOpInterface.h"
#include "mlir/Dialect/Bufferization/IR/Bufferization.h"
#include "mlir/Dialect/Func/IR/FuncOps.h"
#include "mlir/Dialect/MemRef/IR/MemRef.h"
#include "mlir/Dialect/Tensor/IR/Tensor.h"
#include "mlir/IR/AsmState.h"
#include "mlir/IR/BlockAndValueMapping.h"
#include "mlir/IR/BuiltinOps.h"
#include "mlir/IR/Operation.h"
#include "mlir/IR/TypeUtilities.h"
#include "mlir/IR/Value.h"
#include "llvm/Support/Debug.h"

//===----------------------------------------------------------------------===//
// BufferizableOpInterface
//===----------------------------------------------------------------------===//

namespace mlir {
namespace bufferization {

#include "mlir/Dialect/Bufferization/IR/BufferizableOpInterface.cpp.inc"

} // namespace bufferization
} // namespace mlir

#define DEBUG_TYPE "bufferizable-op-interface"
#define DBGS() (llvm::dbgs() << '[' << DEBUG_TYPE << "] ")
#define LDBG(X) LLVM_DEBUG(DBGS() << (X))

using namespace mlir;
using namespace bufferization;

/// Attribute name used to mark region arguments that can be bufferized
/// in-place during linalg comprehensive bufferization.
constexpr const ::llvm::StringLiteral
    bufferization::BufferizableOpInterface::kInplaceableAttrName;

/// Return the owner of the given value.
static Operation *getOwnerOfValue(Value value) {
  if (auto opResult = value.dyn_cast<OpResult>())
    return opResult.getDefiningOp();
  return value.cast<BlockArgument>().getOwner()->getParentOp();
}

/// Create an AllocTensorOp for the given shaped value. If `copy` is set, the
/// shaped value is copied. Otherwise, a tensor with undefined contents is
/// allocated.
FailureOr<Value> bufferization::allocateTensorForShapedValue(
    OpBuilder &b, Location loc, Value shapedValue, bool escape,
    const BufferizationOptions &options, bool copy) {
  Value tensor;
  if (shapedValue.getType().isa<RankedTensorType>()) {
    tensor = shapedValue;
  } else if (shapedValue.getType().isa<MemRefType>()) {
    tensor = b.create<ToTensorOp>(loc, shapedValue);
  } else {
    llvm_unreachable("expected RankedTensorType or MemRefType");
  }
  RankedTensorType tensorType = tensor.getType().cast<RankedTensorType>();
  SmallVector<Value> dynamicSizes;
  if (!copy) {
    // Compute the dynamic part of the shape.
    // First try to query the shape via ReifyRankedShapedTypeOpInterface.
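    // For example, if `shapedValue` is the result of a `tensor.extract_slice`
    // with a dynamic size %sz, reification typically recovers %sz directly
    // from the defining op instead of materializing `tensor.dim` ops.
    // (Illustrative example; the reified values depend on the defining op.)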
    bool reifiedShapes = false;
    if (shapedValue.getType().isa<RankedTensorType>() &&
        shapedValue.isa<OpResult>()) {
      if (auto rankedOp = dyn_cast_or_null<ReifyRankedShapedTypeOpInterface>(
              shapedValue.getDefiningOp())) {
        ReifiedRankedShapedTypeDims resultDims;
        if (succeeded(rankedOp.reifyResultShapes(b, resultDims))) {
          reifiedShapes = true;
          auto &shape =
              resultDims[shapedValue.cast<OpResult>().getResultNumber()];
          for (const auto &dim : enumerate(tensorType.getShape()))
            if (ShapedType::isDynamic(dim.value()))
              dynamicSizes.push_back(shape[dim.index()]);
        }
      }
    }

    // If the shape could not be reified, create DimOps.
    if (!reifiedShapes)
      populateDynamicDimSizes(b, loc, tensor, dynamicSizes);
  }

  // Create AllocTensorOp.
  auto allocTensorOp = b.create<AllocTensorOp>(loc, tensorType, dynamicSizes,
                                               copy ? tensor : Value());
  allocTensorOp->setAttr(BufferizationDialect::kEscapeAttrName,
                         b.getBoolArrayAttr({escape}));

  // Add 'memory_space' attribute. Not needed if 'copy' operand is specified.
  if (copy)
    return allocTensorOp.getResult();
  FailureOr<BaseMemRefType> copyBufferType = getBufferType(tensor, options);
  if (failed(copyBufferType))
    return failure();
  allocTensorOp.setMemorySpaceAttr(
      b.getIntegerAttr(b.getIntegerType(64, /*isSigned=*/false),
                       copyBufferType->getMemorySpaceAsInt()));
  return allocTensorOp.getResult();
}

LogicalResult BufferizableOpInterface::resolveTensorOpOperandConflicts(
    RewriterBase &rewriter, const AnalysisState &state) {
  OpBuilder::InsertionGuard g(rewriter);
  Operation *op = getOperation();
  SmallVector<OpOperand *> outOfPlaceOpOperands;
  DenseSet<OpOperand *> copiedOpOperands;
  DenseSet<OpOperand *> escapingOpOperandCopies;
  SmallVector<OpResult> outOfPlaceOpResults;
  DenseSet<OpResult> copiedOpResults;
  DenseSet<OpResult> escapingOpResultCopies;

  // Find all out-of-place OpOperands.
  for (OpOperand &opOperand : op->getOpOperands()) {
    Type operandType = opOperand.get().getType();
    if (!operandType.isa<TensorType>())
      continue;
    if (state.isInPlace(opOperand))
      continue;
    if (operandType.isa<UnrankedTensorType>())
      return op->emitError("copies of unranked tensors are not supported");

    SmallVector<OpResult> aliasingOpResults =
        state.getAliasingOpResult(opOperand);
    // Is the result yielded from a block? Or are deallocations turned off
    // entirely? In either case, mark the allocation as "escaping", so that it
    // will not be deallocated.
    bool escape = !state.getOptions().createDeallocs ||
                  llvm::any_of(aliasingOpResults, [&](Value v) {
                    return state.isTensorYielded(v);
                  });

    if (aliasingOpResults.size() == 1 &&
        !state.bufferizesToMemoryWrite(opOperand) &&
        state.getAliasingOpOperand(aliasingOpResults.front()).size() == 1) {
      // The op itself does not write but may create exactly one alias.
      // Instead of copying the OpOperand, copy the OpResult. The OpResult can
      // sometimes be smaller than the OpOperand (e.g., in the case of an
      // extract_slice, where the result is usually a smaller part of the
      // source).
      outOfPlaceOpResults.push_back(aliasingOpResults.front());
      if (!state.canOmitTensorCopy(opOperand))
        copiedOpResults.insert(aliasingOpResults.front());
      if (escape)
        escapingOpResultCopies.insert(aliasingOpResults.front());
    } else {
      // In all other cases, make a copy of the OpOperand.
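      // For example, an operand that is written in-place (such as the
      // destination of a `tensor.insert`) must be copied here if its old
      // contents are still needed by other users. (Illustrative example.)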
      outOfPlaceOpOperands.push_back(&opOperand);
      if (!state.canOmitTensorCopy(opOperand))
        copiedOpOperands.insert(&opOperand);
      if (escape)
        escapingOpOperandCopies.insert(&opOperand);
    }
  }

  // Insert copies of OpOperands.
  rewriter.setInsertionPoint(op);
  for (OpOperand *opOperand : outOfPlaceOpOperands) {
    FailureOr<Value> copy = allocateTensorForShapedValue(
        rewriter, op->getLoc(), opOperand->get(),
        escapingOpOperandCopies.contains(opOperand), state.getOptions(),
        copiedOpOperands.contains(opOperand));
    if (failed(copy))
      return failure();
    rewriter.updateRootInPlace(op, [&]() { opOperand->set(*copy); });
  }

  // Insert copies of OpResults.
  rewriter.setInsertionPointAfter(op);
  for (OpResult opResult : outOfPlaceOpResults) {
    FailureOr<Value> copy = allocateTensorForShapedValue(
        rewriter, op->getLoc(), opResult,
        escapingOpResultCopies.contains(opResult), state.getOptions(),
        copiedOpResults.count(opResult));
    if (failed(copy))
      return failure();
    SmallVector<OpOperand *> uses = llvm::to_vector(llvm::map_range(
        opResult.getUses(), [](OpOperand &use) { return &use; }));
    for (OpOperand *use : uses) {
      // Do not update the alloc_tensor op that we just created.
      if (use->getOwner() != copy->getDefiningOp())
        rewriter.updateRootInPlace(use->getOwner(), [&]() { use->set(*copy); });
    }
  }

  return success();
}

//===----------------------------------------------------------------------===//
// OpFilter
//===----------------------------------------------------------------------===//

bool OpFilter::isOpAllowed(Operation *op) const {
  // Allow/disallow the op according to the registered filter entries.
  bool isAllowed = !hasAllowRule();
  for (const Entry &entry : entries) {
    bool filterResult = entry.fn(op);
    switch (entry.type) {
    case Entry::ALLOW:
      isAllowed |= filterResult;
      break;
    case Entry::DENY:
      if (filterResult)
        // DENY filter matches. This op is not allowed. (Even if other ALLOW
        // filters may match.)
        return false;
      break;
    }
  }
  return isAllowed;
}

//===----------------------------------------------------------------------===//
// BufferizationOptions
//===----------------------------------------------------------------------===//

// Default constructor for BufferizationOptions.
BufferizationOptions::BufferizationOptions() = default;

bool BufferizationOptions::isOpAllowed(Operation *op) const {
  // Special case: If function boundary bufferization is deactivated, do not
  // allow ops that belong to the `func` dialect.
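  // For example, `func.call` and `func.return` are then skipped here and are
  // expected to be handled by a separate function-boundary bufferization.
  // (Illustrative note; the exact pipeline depends on the configuration.)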
  bool isFuncBoundaryOp = isa_and_nonnull<func::FuncDialect>(op->getDialect());
  if (!bufferizeFunctionBoundaries && isFuncBoundaryOp)
    return false;

  return opFilter.isOpAllowed(op);
}

BufferizableOpInterface
BufferizationOptions::dynCastBufferizableOp(Operation *op) const {
  auto bufferizableOp = dyn_cast<BufferizableOpInterface>(op);
  if (!bufferizableOp)
    return nullptr;
  if (!isOpAllowed(op))
    return nullptr;
  return bufferizableOp;
}

BufferizableOpInterface
BufferizationOptions::dynCastBufferizableOp(Value value) const {
  if (auto bufferizableOp = value.getDefiningOp<BufferizableOpInterface>())
    if (isOpAllowed(bufferizableOp.getOperation()))
      return bufferizableOp;
  return nullptr;
}

void BufferizationOptions::addDialectStateInitializer(
    StringRef name, const DialectStateInitFn &fn) {
  stateInitializers.push_back(
      [=](AnalysisState &state) { state.insertDialectState(name, fn()); });
}

//===----------------------------------------------------------------------===//
// Helper functions for BufferizableOpInterface
//===----------------------------------------------------------------------===//

static void setInsertionPointAfter(OpBuilder &b, Value value) {
  if (auto bbArg = value.dyn_cast<BlockArgument>()) {
    b.setInsertionPointToStart(bbArg.getOwner());
  } else {
    b.setInsertionPointAfter(value.getDefiningOp());
  }
}

/// Determine which OpOperands will alias with `result` if the op is bufferized
/// in place. Return an empty vector if the op is not bufferizable.
SmallVector<OpOperand *>
AnalysisState::getAliasingOpOperand(OpResult result) const {
  if (Operation *op = result.getDefiningOp())
    if (auto bufferizableOp = getOptions().dynCastBufferizableOp(op))
      return bufferizableOp.getAliasingOpOperand(result, *this);
  return {};
}

/// Determine which OpResults will alias with `opOperand` if the op is
/// bufferized in place. Return an empty vector if the op is not bufferizable.
SmallVector<OpResult>
AnalysisState::getAliasingOpResult(OpOperand &opOperand) const {
  if (auto bufferizableOp =
          getOptions().dynCastBufferizableOp(opOperand.getOwner()))
    return bufferizableOp.getAliasingOpResult(opOperand, *this);
  return {};
}

/// Return true if `opOperand` bufferizes to a memory read. Also return `true`
/// if the op is not bufferizable.
bool AnalysisState::bufferizesToMemoryRead(OpOperand &opOperand) const {
  if (auto bufferizableOp =
          getOptions().dynCastBufferizableOp(opOperand.getOwner()))
    return bufferizableOp.bufferizesToMemoryRead(opOperand, *this);

  // Unknown op that returns a tensor. The inplace analysis does not support
  // it. Conservatively return true.
  return true;
}

/// Return true if `opOperand` bufferizes to a memory write. Also return `true`
/// if the op is not bufferizable.
bool AnalysisState::bufferizesToMemoryWrite(OpOperand &opOperand) const {
  if (auto bufferizableOp =
          getOptions().dynCastBufferizableOp(opOperand.getOwner()))
    return bufferizableOp.bufferizesToMemoryWrite(opOperand, *this);

  // Unknown op that returns a tensor. The inplace analysis does not support
  // it. Conservatively return true.
  return true;
}

/// Return true if `opOperand` does neither read nor write but bufferizes to an
/// alias. Return false if the op is not bufferizable.
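/// For example, the source operand of a `tensor.extract_slice` is neither read
/// nor written by the op itself; the op only creates an aliasing view of the
/// source buffer. (Illustrative example.)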
bool AnalysisState::bufferizesToAliasOnly(OpOperand &opOperand) const {
  if (auto bufferizableOp =
          getOptions().dynCastBufferizableOp(opOperand.getOwner()))
    return bufferizableOp.bufferizesToAliasOnly(opOperand, *this);

  // Unknown op that returns a tensor. The inplace analysis does not support
  // it. Conservatively return false.
  return false;
}

/// Return true if the given value is read by an op that bufferizes to a memory
/// read. Also takes into account ops that create an alias but do not read by
/// themselves (e.g., ExtractSliceOp).
bool AnalysisState::isValueRead(Value value) const {
  assert(value.getType().isa<TensorType>() && "expected TensorType");
  SmallVector<OpOperand *> workingSet;
  for (OpOperand &use : value.getUses())
    workingSet.push_back(&use);

  while (!workingSet.empty()) {
    OpOperand *uMaybeReading = workingSet.pop_back_val();
    // Skip over all ops that neither read nor write (but create an alias).
    if (bufferizesToAliasOnly(*uMaybeReading))
      for (OpResult opResult : getAliasingOpResult(*uMaybeReading))
        for (OpOperand &use : opResult.getUses())
          workingSet.push_back(&use);
    if (bufferizesToMemoryRead(*uMaybeReading))
      return true;
  }

  return false;
}

// Starting from `value`, follow the use-def chain in reverse, always selecting
// the aliasing OpOperands. Find and return Values for which `condition`
// evaluates to true. OpOperands of such matching Values are not traversed any
// further.
llvm::SetVector<Value> AnalysisState::findValueInReverseUseDefChain(
    Value value, llvm::function_ref<bool(Value)> condition) const {
  llvm::SetVector<Value> result, workingSet;
  workingSet.insert(value);

  while (!workingSet.empty()) {
    Value value = workingSet.pop_back_val();
    if (condition(value) || value.isa<BlockArgument>()) {
      result.insert(value);
      continue;
    }

    OpResult opResult = value.cast<OpResult>();
    SmallVector<OpOperand *> opOperands = getAliasingOpOperand(opResult);
    if (opOperands.empty() || !options.isOpAllowed(value.getDefiningOp())) {
      result.insert(value);
      continue;
    }

    for (OpOperand *o : opOperands)
      workingSet.insert(o->get());
  }

  return result;
}

// Find the Values of the last preceding write of a given Value.
llvm::SetVector<Value>
AnalysisState::findLastPrecedingWrite(Value value) const {
  return findValueInReverseUseDefChain(value, [&](Value value) {
    Operation *op = value.getDefiningOp();
    if (!op)
      return true;
    auto bufferizableOp = options.dynCastBufferizableOp(op);
    if (!bufferizableOp)
      return true;
    return bufferizableOp.isMemoryWrite(value.cast<OpResult>(), *this);
  });
}

AnalysisState::AnalysisState(const BufferizationOptions &options)
    : options(options) {
  for (const BufferizationOptions::AnalysisStateInitFn &fn :
       options.stateInitializers)
    fn(*this);
}

bool AnalysisState::canOmitTensorCopy(OpOperand &opOperand) const {
  // Do not copy if the tensor has undefined contents.
  if (hasUndefinedContents(&opOperand))
    return true;

  // Do not copy if the buffer of the tensor is entirely overwritten (with
  // values that do not depend on the old tensor).
  if (bufferizesToMemoryWrite(opOperand) && !bufferizesToMemoryRead(opOperand))
    return true;

  // Do not copy if the tensor is never read.
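  // I.e., neither this operand nor any value aliasing with its op results is
  // ever read, so the data that a copy would preserve is never inspected.
  // (Explanatory note for the check below.)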
  SmallVector<OpResult> aliasingOpResults = getAliasingOpResult(opOperand);
  if (!bufferizesToMemoryRead(opOperand) &&
      llvm::none_of(aliasingOpResults,
                    [&](OpResult opResult) { return isValueRead(opResult); }))
    return true;

  // Default: Cannot omit the copy.
  return false;
}

bool AnalysisState::isInPlace(OpOperand &opOperand) const {
  // ToMemrefOps are always in-place.
  if (isa<ToMemrefOp>(opOperand.getOwner()))
    return true;

  // In the absence of analysis information, OpOperands that bufferize to a
  // memory write are out-of-place, i.e., an alloc and copy is inserted.
  return !bufferizesToMemoryWrite(opOperand);
}

bool AnalysisState::areEquivalentBufferizedValues(Value v1, Value v2) const {
  // In the absence of analysis information, we do not know if the values are
  // equivalent. The conservative answer is "false".
  return false;
}

bool AnalysisState::areAliasingBufferizedValues(Value v1, Value v2) const {
  // In the absence of analysis information, we do not know if the values may
  // be aliasing. The conservative answer is "true".
  return true;
}

bool AnalysisState::hasUndefinedContents(OpOperand *opOperand) const {
  // In the absence of analysis information, the conservative answer is
  // "false".
  return false;
}

bool AnalysisState::isTensorYielded(Value tensor) const {
  // In the absence of analysis information, the conservative answer is "true".
  if (!tensor.getDefiningOp<AllocTensorOp>())
    return true;

  // For AllocTensorOp results, we can do better: They do not alias with any
  // preceding value, so we can follow SSA use-def chains and do a simple
  // analysis.
  SmallVector<OpOperand *> worklist;
  for (OpOperand &use : tensor.getUses())
    worklist.push_back(&use);

  while (!worklist.empty()) {
    OpOperand *operand = worklist.pop_back_val();
    Operation *op = operand->getOwner();

    // If the op is not bufferizable, we can safely assume that the value is
    // not yielded. (When bufferizing that op, it must handle such cases.)
    if (!options.dynCastBufferizableOp(op))
      continue;

    // We cannot analyze through ToMemrefOps, so we have to conservatively
    // assume that the value is yielded.
    if (isa<ToMemrefOp>(op))
      return true;

    // Check if the op is returning/yielding.
    if (isRegionReturnLike(op))
      return true;

    // Add all aliasing OpResults to the worklist.
    // Note: In the absence of detailed analysis information (e.g., there may
    // be no function call analysis information), this `getAliasingOpResult`
    // is conservative and may report additional OpResults as potentially
    // aliasing.
    for (OpResult opResult : getAliasingOpResult(*operand))
      for (OpOperand &use : opResult.getUses())
        worklist.push_back(&use);
  }

  // No ReturnLike op found: The value is not yielded.
  return false;
}

// bufferization.to_memref is not allowed to change the rank.
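// For example, converting a `tensor<?xf32>` to a `memref<?x?xf32>` would be
// rejected by the assertion below. (Illustrative example.)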
static void ensureToMemrefOpIsValid(Value tensor, Type memrefType) {
#ifndef NDEBUG
  auto rankedTensorType = tensor.getType().dyn_cast<RankedTensorType>();
  assert((!rankedTensorType || memrefType.cast<MemRefType>().getRank() ==
                                   rankedTensorType.getRank()) &&
         "to_memref would be invalid: mismatching ranks");
#endif
}

FailureOr<Value> bufferization::getBuffer(RewriterBase &rewriter, Value value,
                                          const BufferizationOptions &options) {
#ifndef NDEBUG
  auto tensorType = value.getType().dyn_cast<TensorType>();
  assert(tensorType && "unexpected non-tensor type");
#endif // NDEBUG

  // Replace "%t = to_tensor %m" with %m.
  if (auto toTensorOp = value.getDefiningOp<bufferization::ToTensorOp>())
    return toTensorOp.getMemref();

  // Insert a to_memref op.
  OpBuilder::InsertionGuard g(rewriter);
  setInsertionPointAfter(rewriter, value);
  FailureOr<BaseMemRefType> memrefType = getBufferType(value, options);
  if (failed(memrefType))
    return failure();
  ensureToMemrefOpIsValid(value, *memrefType);
  return rewriter
      .create<bufferization::ToMemrefOp>(value.getLoc(), *memrefType, value)
      .getResult();
}

/// Return the buffer type for a given Value (tensor) after bufferization.
FailureOr<BaseMemRefType>
bufferization::getBufferType(Value value, const BufferizationOptions &options) {
  auto tensorType = value.getType().dyn_cast<TensorType>();
  assert(tensorType && "unexpected non-tensor type");
  Operation *op = getOwnerOfValue(value);

  // ToTensorOp: Take the buffer type directly from the op.
  if (auto toTensorOp = value.getDefiningOp<bufferization::ToTensorOp>())
    return toTensorOp.getMemref().getType().cast<BaseMemRefType>();

  // If the value is a bbArg of a bufferizable op: query the op interface.
  if (auto bbArg = value.dyn_cast<BlockArgument>())
    if (auto bufferizableOp =
            options.dynCastBufferizableOp(bbArg.getOwner()->getParentOp()))
      return bufferizableOp.getBufferType(bbArg, options);

  // Check if the value is a new buffer allocation with a memory space
  // attribute. In that case we can at least infer the memory space.
  Optional<unsigned> memorySpace = None;
  if (auto opResult = value.dyn_cast<OpResult>()) {
    if (auto bufferizableOp =
            options.dynCastBufferizableOp(opResult.getDefiningOp())) {
      if (bufferizableOp.bufferizesToAllocation(opResult)) {
        FailureOr<unsigned> queriedMemorySpace =
            bufferizableOp.getMemorySpace(opResult);
        if (!failed(queriedMemorySpace))
          memorySpace = *queriedMemorySpace;
      }
    }
  }

  // If we still do not know the memory space, use the default memory space
  // (if any).
  if (!memorySpace.hasValue())
    memorySpace = options.defaultMemorySpace;

  // If we still do not know the memory space, report a failure.
  if (!memorySpace.hasValue())
    return op->emitError("could not infer memory space");

  return getMemRefType(tensorType, options, /*layout=*/{}, *memorySpace);
}

void bufferization::replaceOpWithBufferizedValues(RewriterBase &rewriter,
                                                  Operation *op,
                                                  ValueRange values) {
  assert(values.size() == op->getNumResults() &&
         "expected one value per OpResult");
  OpBuilder::InsertionGuard g(rewriter);

  // Replace all OpResults with the given values.
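  // For example, a `memref<10xf32>` replacement for a `tensor<10xf32>` result
  // is wrapped in a `bufferization.to_tensor` op so that the remaining tensor
  // users continue to type-check until they are bufferized themselves.
  // (Illustrative example.)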
  SmallVector<Value> replacements;
  for (OpResult opResult : op->getOpResults()) {
    Value replacement = values[opResult.getResultNumber()];
    if (opResult.getType().isa<TensorType>()) {
      // The OpResult is a tensor. Such values are replaced with memrefs during
      // bufferization.
      assert((replacement.getType().isa<MemRefType>() ||
              replacement.getType().isa<UnrankedMemRefType>()) &&
             "tensor op result should be replaced with a memref value");
      // The existing uses of the OpResult still expect a tensor. Insert a
      // ToTensorOp. Throughout bufferization, this ToTensorOp will gradually
      // lose all of its users and eventually DCE away.
      rewriter.setInsertionPointAfter(op);
      replacement = rewriter.create<bufferization::ToTensorOp>(
          replacement.getLoc(), replacement);
    }
    replacements.push_back(replacement);
  }

  rewriter.replaceOp(op, replacements);
}

//===----------------------------------------------------------------------===//
// Bufferization-specific scoped alloc/dealloc insertion support.
//===----------------------------------------------------------------------===//

/// Create a memref allocation with the given type and dynamic extents.
FailureOr<Value> BufferizationOptions::createAlloc(OpBuilder &b, Location loc,
                                                   MemRefType type,
                                                   ValueRange dynShape) const {
  if (allocationFn)
    return (*allocationFn)(b, loc, type, dynShape, bufferAlignment);

  // Default buffer allocation via AllocOp.
  if (bufferAlignment != 0)
    return b
        .create<memref::AllocOp>(loc, type, dynShape,
                                 b.getI64IntegerAttr(bufferAlignment))
        .getResult();
  return b.create<memref::AllocOp>(loc, type, dynShape).getResult();
}

/// Creates a memref deallocation. The given memref buffer must have been
/// allocated using `createAlloc`.
LogicalResult BufferizationOptions::createDealloc(OpBuilder &b, Location loc,
                                                  Value allocatedBuffer) const {
  if (deallocationFn)
    return (*deallocationFn)(b, loc, allocatedBuffer);

  // Default buffer deallocation via DeallocOp.
  b.create<memref::DeallocOp>(loc, allocatedBuffer);
  return success();
}

/// Create a memory copy between two memref buffers.
LogicalResult BufferizationOptions::createMemCpy(OpBuilder &b, Location loc,
                                                 Value from, Value to) const {
  if (memCpyFn)
    return (*memCpyFn)(b, loc, from, to);

  b.create<memref::CopyOp>(loc, from, to);
  return success();
}

//===----------------------------------------------------------------------===//
// Bufferization-specific BlockAndValueMapping support with debugging.
//===----------------------------------------------------------------------===//

bool bufferization::isFunctionArgument(Value value) {
  auto bbArg = value.dyn_cast<BlockArgument>();
  if (!bbArg)
    return false;
  return isa<func::FuncOp>(bbArg.getOwner()->getParentOp());
}

BaseMemRefType bufferization::getMemRefType(TensorType tensorType,
                                            const BufferizationOptions &options,
                                            MemRefLayoutAttrInterface layout,
                                            unsigned memorySpace) {
  auto memorySpaceAttr = IntegerAttr::get(
      IntegerType::get(tensorType.getContext(), 64), memorySpace);

  // Case 1: Unranked memref type.
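  // For example, `tensor<*xf32>` maps to `memref<*xf32>` (with the requested
  // memory space); unranked memref types do not carry a layout attribute.
  // (Illustrative example.)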
  if (auto unrankedTensorType = tensorType.dyn_cast<UnrankedTensorType>()) {
    assert(!layout && "UnrankedTensorType cannot have a layout map");
    return UnrankedMemRefType::get(unrankedTensorType.getElementType(),
                                   memorySpaceAttr);
  }

  // Case 2: Ranked memref type with specified layout.
  auto rankedTensorType = tensorType.cast<RankedTensorType>();
  if (layout) {
    return MemRefType::get(rankedTensorType.getShape(),
                           rankedTensorType.getElementType(), layout,
                           memorySpaceAttr);
  }

  // Case 3: Configured with "fully dynamic layout maps".
  if (options.unknownTypeConversion ==
      BufferizationOptions::LayoutMapOption::FullyDynamicLayoutMap)
    return getMemRefTypeWithFullyDynamicLayout(tensorType, memorySpace);

  // Case 4: Configured with "static identity layout maps".
  if (options.unknownTypeConversion ==
      BufferizationOptions::LayoutMapOption::IdentityLayoutMap)
    return getMemRefTypeWithStaticIdentityLayout(tensorType, memorySpace);

  llvm_unreachable("InferLayoutMap is an invalid option");
}

BaseMemRefType
bufferization::getMemRefTypeWithFullyDynamicLayout(TensorType tensorType,
                                                   unsigned memorySpace) {
  // Case 1: Unranked memref type.
  if (auto unrankedTensorType = tensorType.dyn_cast<UnrankedTensorType>()) {
    return UnrankedMemRefType::get(unrankedTensorType.getElementType(),
                                   memorySpace);
  }

  // Case 2: Ranked memref type.
  auto memorySpaceAttr = IntegerAttr::get(
      IntegerType::get(tensorType.getContext(), 64), memorySpace);
  auto rankedTensorType = tensorType.cast<RankedTensorType>();
  int64_t dynamicOffset = ShapedType::kDynamicStrideOrOffset;
  SmallVector<int64_t> dynamicStrides(rankedTensorType.getRank(),
                                      ShapedType::kDynamicStrideOrOffset);
  AffineMap stridedLayout = makeStridedLinearLayoutMap(
      dynamicStrides, dynamicOffset, rankedTensorType.getContext());
  return MemRefType::get(rankedTensorType.getShape(),
                         rankedTensorType.getElementType(), stridedLayout,
                         memorySpaceAttr);
}

/// Return a MemRef type with a static identity layout (i.e., no layout map).
/// If the given tensor type is unranked, return an unranked MemRef type.
BaseMemRefType
bufferization::getMemRefTypeWithStaticIdentityLayout(TensorType tensorType,
                                                     unsigned memorySpace) {
  // Case 1: Unranked memref type.
  if (auto unrankedTensorType = tensorType.dyn_cast<UnrankedTensorType>()) {
    return UnrankedMemRefType::get(unrankedTensorType.getElementType(),
                                   memorySpace);
  }

  // Case 2: Ranked memref type.
  auto rankedTensorType = tensorType.cast<RankedTensorType>();
  auto memorySpaceAttr = IntegerAttr::get(
      IntegerType::get(tensorType.getContext(), 64), memorySpace);
  MemRefLayoutAttrInterface layout = {};
  return MemRefType::get(rankedTensorType.getShape(),
                         rankedTensorType.getElementType(), layout,
                         memorySpaceAttr);
}