//===- BufferizableOpInterface.cpp - Bufferizable Ops ---------------------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//

#include "mlir/Dialect/Bufferization/IR/BufferizableOpInterface.h"
#include "mlir/Dialect/Bufferization/IR/Bufferization.h"
#include "mlir/Dialect/Func/IR/FuncOps.h"
#include "mlir/Dialect/MemRef/IR/MemRef.h"
#include "mlir/Dialect/Tensor/IR/Tensor.h"
#include "mlir/IR/AsmState.h"
#include "mlir/IR/BlockAndValueMapping.h"
#include "mlir/IR/BuiltinOps.h"
#include "mlir/IR/Operation.h"
#include "mlir/IR/TypeUtilities.h"
#include "mlir/IR/Value.h"
#include "llvm/Support/Debug.h"

//===----------------------------------------------------------------------===//
// BufferizableOpInterface
//===----------------------------------------------------------------------===//

namespace mlir {
namespace bufferization {

#include "mlir/Dialect/Bufferization/IR/BufferizableOpInterface.cpp.inc"

} // namespace bufferization
} // namespace mlir

#define DEBUG_TYPE "bufferizable-op-interface"
#define DBGS() (llvm::dbgs() << '[' << DEBUG_TYPE << "] ")
#define LDBG(X) LLVM_DEBUG(DBGS() << (X))

using namespace mlir;
using namespace bufferization;

/// Return the owner of the given value.
static Operation *getOwnerOfValue(Value value) {
  if (auto opResult = value.dyn_cast<OpResult>())
    return opResult.getDefiningOp();
  return value.cast<BlockArgument>().getOwner()->getParentOp();
}

bool bufferization::allocationDoesNotEscape(OpResult opResult) {
#ifndef NDEBUG
  auto bufferizableOp = opResult.getDefiningOp<BufferizableOpInterface>();
  assert(bufferizableOp && bufferizableOp.bufferizesToAllocation(opResult) &&
         "expected op that bufferizes to an allocation");
#endif // NDEBUG

  Operation *op = opResult.getDefiningOp();
  // If there is no 'escape' attribute, we cannot say for sure.
  if (!op->hasAttr(BufferizationDialect::kEscapeAttrName))
    return false;
  auto attr =
      op->getAttrOfType<ArrayAttr>(BufferizationDialect::kEscapeAttrName);
  return !attr[opResult.getResultNumber()].cast<BoolAttr>().getValue();
}

/// Create an AllocTensorOp for the given shaped value. If `copy` is set, the
/// shaped value is copied. Otherwise, a tensor with undefined contents is
/// allocated.
FailureOr<Value> bufferization::allocateTensorForShapedValue(
    OpBuilder &b, Location loc, Value shapedValue, bool escape,
    const BufferizationOptions &options, bool copy) {
  Value tensor;
  if (shapedValue.getType().isa<RankedTensorType>()) {
    tensor = shapedValue;
  } else if (shapedValue.getType().isa<MemRefType>()) {
    tensor = b.create<ToTensorOp>(loc, shapedValue);
  } else {
    llvm_unreachable("expected RankedTensorType or MemRefType");
  }
  RankedTensorType tensorType = tensor.getType().cast<RankedTensorType>();
  SmallVector<Value> dynamicSizes;
  if (!copy) {
    // Compute the dynamic part of the shape.
    // First try to query the shape via ReifyRankedShapedTypeOpInterface.
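    // For example (an illustrative sketch): if `shapedValue` is the result of
    //   %0 = tensor.extract_slice %t[0] [%sz] [1]
    //          : tensor<?xf32> to tensor<?xf32>
    // reification returns %sz as the dynamic extent of %0 directly, so no
    // `tensor.dim` op on %0 needs to be created.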
    bool reifiedShapes = false;
    if (shapedValue.getType().isa<RankedTensorType>() &&
        shapedValue.isa<OpResult>()) {
      if (auto rankedOp = dyn_cast_or_null<ReifyRankedShapedTypeOpInterface>(
              shapedValue.getDefiningOp())) {
        ReifiedRankedShapedTypeDims resultDims;
        if (succeeded(rankedOp.reifyResultShapes(b, resultDims))) {
          reifiedShapes = true;
          auto &shape =
              resultDims[shapedValue.cast<OpResult>().getResultNumber()];
          for (const auto &dim : enumerate(tensorType.getShape()))
            if (ShapedType::isDynamic(dim.value()))
              dynamicSizes.push_back(shape[dim.index()]);
        }
      }
    }

    // If the shape could not be reified, create DimOps.
    if (!reifiedShapes)
      populateDynamicDimSizes(b, loc, tensor, dynamicSizes);
  }

  // Create AllocTensorOp.
  auto allocTensorOp = b.create<AllocTensorOp>(loc, tensorType, dynamicSizes,
                                               copy ? tensor : Value());
  allocTensorOp->setAttr(BufferizationDialect::kEscapeAttrName,
                         b.getBoolArrayAttr({escape}));

  // Add 'memory_space' attribute. Not needed if 'copy' operand is specified.
  if (copy)
    return allocTensorOp.getResult();
  FailureOr<BaseMemRefType> copyBufferType = getBufferType(tensor, options);
  if (failed(copyBufferType))
    return failure();
  allocTensorOp.setMemorySpaceAttr(
      b.getIntegerAttr(b.getIntegerType(64, /*isSigned=*/false),
                       copyBufferType->getMemorySpaceAsInt()));
  return allocTensorOp.getResult();
}

LogicalResult BufferizableOpInterface::resolveTensorOpOperandConflicts(
    RewriterBase &rewriter, const AnalysisState &state) {
  OpBuilder::InsertionGuard g(rewriter);
  Operation *op = getOperation();
  SmallVector<OpOperand *> outOfPlaceOpOperands;
  DenseSet<OpOperand *> copiedOpOperands;
  DenseSet<OpOperand *> escapingOpOperandCopies;
  SmallVector<OpResult> outOfPlaceOpResults;
  DenseSet<OpResult> copiedOpResults;
  DenseSet<OpResult> escapingOpResultCopies;

  // Find all out-of-place OpOperands.
  for (OpOperand &opOperand : op->getOpOperands()) {
    Type operandType = opOperand.get().getType();
    if (!operandType.isa<TensorType>())
      continue;
    if (state.isInPlace(opOperand))
      continue;
    if (operandType.isa<UnrankedTensorType>())
      return op->emitError("copies of unranked tensors are not supported");

    SmallVector<OpResult> aliasingOpResults =
        state.getAliasingOpResult(opOperand);
    // Is the result yielded from a block? Or are deallocations turned off
    // entirely? In either case, mark the allocation as "escaping", so that it
    // will not be deallocated.
    bool escape = !state.getOptions().createDeallocs ||
                  llvm::any_of(aliasingOpResults, [&](Value v) {
                    return state.isTensorYielded(v);
                  });

    if (aliasingOpResults.size() == 1 &&
        !state.bufferizesToMemoryWrite(opOperand) &&
        state.getAliasingOpOperand(aliasingOpResults.front()).size() == 1) {
      // The op itself does not write but may create exactly one alias. Instead
      // of copying the OpOperand, copy the OpResult. The OpResult can
      // sometimes be smaller than the OpOperand (e.g., in the case of an
      // extract_slice, where the result is usually a smaller part of the
      // source).
      outOfPlaceOpResults.push_back(aliasingOpResults.front());
      if (!state.canOmitTensorCopy(opOperand))
        copiedOpResults.insert(aliasingOpResults.front());
      if (escape)
        escapingOpResultCopies.insert(aliasingOpResults.front());
    } else {
      // In all other cases, make a copy of the OpOperand.
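      // E.g., an op that overwrites its operand in-place must be given a copy
      // of the tensor here if the analysis found the OpOperand to be
      // out-of-place; otherwise, another read of the original tensor could
      // observe the new contents.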
      outOfPlaceOpOperands.push_back(&opOperand);
      if (!state.canOmitTensorCopy(opOperand))
        copiedOpOperands.insert(&opOperand);
      if (escape)
        escapingOpOperandCopies.insert(&opOperand);
    }
  }

  // Insert copies of OpOperands.
  rewriter.setInsertionPoint(op);
  for (OpOperand *opOperand : outOfPlaceOpOperands) {
    FailureOr<Value> copy = allocateTensorForShapedValue(
        rewriter, op->getLoc(), opOperand->get(),
        escapingOpOperandCopies.contains(opOperand), state.getOptions(),
        copiedOpOperands.contains(opOperand));
    if (failed(copy))
      return failure();
    rewriter.updateRootInPlace(op, [&]() { opOperand->set(*copy); });
  }

  // Insert copies of OpResults.
  rewriter.setInsertionPointAfter(op);
  for (OpResult opResult : outOfPlaceOpResults) {
    FailureOr<Value> copy = allocateTensorForShapedValue(
        rewriter, op->getLoc(), opResult,
        escapingOpResultCopies.contains(opResult), state.getOptions(),
        copiedOpResults.contains(opResult));
    if (failed(copy))
      return failure();
    SmallVector<OpOperand *> uses = llvm::to_vector(llvm::map_range(
        opResult.getUses(), [](OpOperand &use) { return &use; }));
    for (OpOperand *use : uses) {
      // Do not update the alloc_tensor op that we just created.
      if (use->getOwner() != copy->getDefiningOp())
        rewriter.updateRootInPlace(use->getOwner(), [&]() { use->set(*copy); });
    }
  }

  return success();
}

//===----------------------------------------------------------------------===//
// OpFilter
//===----------------------------------------------------------------------===//

bool OpFilter::isOpAllowed(Operation *op) const {
  // Allow/disallow the op according to the filter entries.
  bool isAllowed = !hasAllowRule();
  for (const Entry &entry : entries) {
    bool filterResult = entry.fn(op);
    switch (entry.type) {
    case Entry::ALLOW:
      isAllowed |= filterResult;
      break;
    case Entry::DENY:
      if (filterResult)
        // DENY filter matches. This op is not allowed. (Even if other ALLOW
        // filters may match.)
        return false;
      break;
    }
  }
  return isAllowed;
}

//===----------------------------------------------------------------------===//
// BufferizationOptions
//===----------------------------------------------------------------------===//

/// Default unknown type converter: Use a fully dynamic layout map.
static BaseMemRefType
defaultUnknownTypeConverter(Value value, unsigned memorySpace,
                            const BufferizationOptions &options) {
  return getMemRefTypeWithFullyDynamicLayout(
      value.getType().cast<TensorType>(), memorySpace);
}

// Default constructor for BufferizationOptions.
BufferizationOptions::BufferizationOptions()
    : unknownTypeConverterFn(defaultUnknownTypeConverter) {}

bool BufferizationOptions::isOpAllowed(Operation *op) const {
  // Special case: If function boundary bufferization is deactivated, do not
  // allow ops that belong to the `func` dialect.
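  // (Examples of such ops: func.func, func.call, func.return.)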
  bool isFuncBoundaryOp = isa_and_nonnull<func::FuncDialect>(op->getDialect());
  if (!bufferizeFunctionBoundaries && isFuncBoundaryOp)
    return false;

  return opFilter.isOpAllowed(op);
}

BufferizableOpInterface
BufferizationOptions::dynCastBufferizableOp(Operation *op) const {
  auto bufferizableOp = dyn_cast<BufferizableOpInterface>(op);
  if (!bufferizableOp)
    return nullptr;
  if (!isOpAllowed(op))
    return nullptr;
  return bufferizableOp;
}

BufferizableOpInterface
BufferizationOptions::dynCastBufferizableOp(Value value) const {
  if (auto bufferizableOp = value.getDefiningOp<BufferizableOpInterface>())
    if (isOpAllowed(bufferizableOp.getOperation()))
      return bufferizableOp;
  return nullptr;
}

void BufferizationOptions::addDialectStateInitializer(
    StringRef name, const DialectStateInitFn &fn) {
  stateInitializers.push_back(
      [=](AnalysisState &state) { state.insertDialectState(name, fn()); });
}

//===----------------------------------------------------------------------===//
// Helper functions for BufferizableOpInterface
//===----------------------------------------------------------------------===//

static void setInsertionPointAfter(OpBuilder &b, Value value) {
  if (auto bbArg = value.dyn_cast<BlockArgument>()) {
    b.setInsertionPointToStart(bbArg.getOwner());
  } else {
    b.setInsertionPointAfter(value.getDefiningOp());
  }
}

/// Determine which OpOperand* will alias with `result` if the op is bufferized
/// in place. Return an empty vector if the op is not bufferizable.
SmallVector<OpOperand *>
AnalysisState::getAliasingOpOperand(OpResult result) const {
  if (Operation *op = result.getDefiningOp())
    if (auto bufferizableOp = getOptions().dynCastBufferizableOp(op))
      return bufferizableOp.getAliasingOpOperand(result, *this);
  return {};
}

/// Determine which OpResult will alias with `opOperand` if the op is
/// bufferized in place. Return an empty vector if the op is not bufferizable.
SmallVector<OpResult>
AnalysisState::getAliasingOpResult(OpOperand &opOperand) const {
  if (auto bufferizableOp =
          getOptions().dynCastBufferizableOp(opOperand.getOwner()))
    return bufferizableOp.getAliasingOpResult(opOperand, *this);
  return {};
}

/// Return true if `opOperand` bufferizes to a memory read. Return `true` if
/// the op is not bufferizable.
bool AnalysisState::bufferizesToMemoryRead(OpOperand &opOperand) const {
  if (auto bufferizableOp =
          getOptions().dynCastBufferizableOp(opOperand.getOwner()))
    return bufferizableOp.bufferizesToMemoryRead(opOperand, *this);

  // Unknown op that returns a tensor. The inplace analysis does not support
  // it. Conservatively return true.
  return true;
}

/// Return true if `opOperand` bufferizes to a memory write. Return `true` if
/// the op is not bufferizable.
bool AnalysisState::bufferizesToMemoryWrite(OpOperand &opOperand) const {
  if (auto bufferizableOp =
          getOptions().dynCastBufferizableOp(opOperand.getOwner()))
    return bufferizableOp.bufferizesToMemoryWrite(opOperand, *this);

  // Unknown op that returns a tensor. The inplace analysis does not support
  // it. Conservatively return true.
  return true;
}

/// Return true if `opOperand` does neither read nor write but bufferizes to an
/// alias. Return false if the op is not bufferizable.
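/// Example: The source operand of a tensor.extract_slice op is neither read
/// nor written; the op merely creates an aliasing (sub-)view of the source
/// buffer.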
bool AnalysisState::bufferizesToAliasOnly(OpOperand &opOperand) const {
  if (auto bufferizableOp =
          getOptions().dynCastBufferizableOp(opOperand.getOwner()))
    return bufferizableOp.bufferizesToAliasOnly(opOperand, *this);

  // Unknown op that returns a tensor. The inplace analysis does not support
  // it. Conservatively return false.
  return false;
}

/// Return true if the given value is read by an op that bufferizes to a memory
/// read. Also takes into account ops that create an alias but do not read by
/// themselves (e.g., ExtractSliceOp).
bool AnalysisState::isValueRead(Value value) const {
  assert(value.getType().isa<TensorType>() && "expected TensorType");
  SmallVector<OpOperand *> workingSet;
  for (OpOperand &use : value.getUses())
    workingSet.push_back(&use);

  while (!workingSet.empty()) {
    OpOperand *uMaybeReading = workingSet.pop_back_val();
    // Skip over all ops that neither read nor write (but create an alias).
    if (bufferizesToAliasOnly(*uMaybeReading))
      for (OpResult opResult : getAliasingOpResult(*uMaybeReading))
        for (OpOperand &use : opResult.getUses())
          workingSet.push_back(&use);
    if (bufferizesToMemoryRead(*uMaybeReading))
      return true;
  }

  return false;
}

/// Starting from `value`, follow the use-def chain in reverse, always
/// selecting the aliasing OpOperands. Find and return Values for which
/// `condition` evaluates to true. OpOperands of such matching Values are not
/// traversed any further.
llvm::SetVector<Value> AnalysisState::findValueInReverseUseDefChain(
    Value value, llvm::function_ref<bool(Value)> condition) const {
  llvm::SetVector<Value> result, workingSet;
  workingSet.insert(value);

  while (!workingSet.empty()) {
    Value value = workingSet.pop_back_val();
    if (condition(value) || value.isa<BlockArgument>()) {
      result.insert(value);
      continue;
    }

    OpResult opResult = value.cast<OpResult>();
    SmallVector<OpOperand *> opOperands = getAliasingOpOperand(opResult);
    if (opOperands.empty() || !options.isOpAllowed(value.getDefiningOp())) {
      result.insert(value);
      continue;
    }

    for (OpOperand *o : opOperands)
      workingSet.insert(o->get());
  }

  return result;
}

/// Find the Values of the last preceding write of a given Value.
llvm::SetVector<Value>
AnalysisState::findLastPrecedingWrite(Value value) const {
  return findValueInReverseUseDefChain(value, [&](Value value) {
    Operation *op = value.getDefiningOp();
    if (!op)
      return true;
    auto bufferizableOp = options.dynCastBufferizableOp(op);
    if (!bufferizableOp)
      return true;
    return bufferizableOp.isMemoryWrite(value.cast<OpResult>(), *this);
  });
}

AnalysisState::AnalysisState(const BufferizationOptions &options)
    : options(options) {
  for (const BufferizationOptions::AnalysisStateInitFn &fn :
       options.stateInitializers)
    fn(*this);
}

bool AnalysisState::canOmitTensorCopy(OpOperand &opOperand) const {
  // Do not copy if the tensor has undefined contents.
  if (hasUndefinedContents(&opOperand))
    return true;

  // Do not copy if the buffer of the tensor is entirely overwritten (with
  // values that do not depend on the old tensor).
  if (bufferizesToMemoryWrite(opOperand) && !bufferizesToMemoryRead(opOperand))
    return true;

  // Do not copy if the tensor is never read.
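  // I.e., neither the OpOperand itself nor any of its aliasing OpResults is
  // read further down the use-def chain.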
  SmallVector<OpResult> aliasingOpResults = getAliasingOpResult(opOperand);
  if (!bufferizesToMemoryRead(opOperand) &&
      llvm::none_of(aliasingOpResults,
                    [&](OpResult opResult) { return isValueRead(opResult); }))
    return true;

  // Default: Cannot omit the copy.
  return false;
}

bool AnalysisState::isInPlace(OpOperand &opOperand) const {
  // ToMemrefOps are always in-place.
  if (isa<ToMemrefOp>(opOperand.getOwner()))
    return true;

  // In the absence of analysis information, OpOperands that bufferize to a
  // memory write are out-of-place, i.e., an alloc and copy is inserted.
  return !bufferizesToMemoryWrite(opOperand);
}

bool AnalysisState::areEquivalentBufferizedValues(Value v1, Value v2) const {
  // In the absence of analysis information, we do not know if the values are
  // equivalent. The conservative answer is "false".
  return false;
}

bool AnalysisState::areAliasingBufferizedValues(Value v1, Value v2) const {
  // In the absence of analysis information, we do not know if the values may
  // be aliasing. The conservative answer is "true".
  return true;
}

bool AnalysisState::hasUndefinedContents(OpOperand *opOperand) const {
  // In the absence of analysis information, the conservative answer is
  // "false".
  return false;
}

bool AnalysisState::isTensorYielded(Value tensor) const {
  // In the absence of analysis information, the conservative answer is "true".
  if (!tensor.getDefiningOp<AllocTensorOp>())
    return true;

  // For AllocTensorOp results, we can do better: They do not alias with any
  // preceding value, so we can follow SSA use-def chains and do a simple
  // analysis.
  SmallVector<OpOperand *> worklist;
  for (OpOperand &use : tensor.getUses())
    worklist.push_back(&use);

  while (!worklist.empty()) {
    OpOperand *operand = worklist.pop_back_val();
    Operation *op = operand->getOwner();

    // If the op is not bufferizable, we can safely assume that the value is
    // not yielded. (When bufferizing that op, it must handle such cases.)
    if (!options.dynCastBufferizableOp(op))
      continue;

    // We cannot analyze through ToMemrefOps, so we have to conservatively
    // assume that the value is yielded.
    if (isa<ToMemrefOp>(op))
      return true;

    // Check if the op is returning/yielding.
    if (isRegionReturnLike(op))
      return true;

    // Add all aliasing OpResults to the worklist.
    // Note: In the absence of detailed analysis information (e.g., there may
    // be no function call analysis information), this `getAliasingOpResult`
    // is conservative and may report additional OpResults as potentially
    // aliasing.
    for (OpResult opResult : getAliasingOpResult(*operand))
      for (OpOperand &use : opResult.getUses())
        worklist.push_back(&use);
  }

  // No ReturnLike op found: The value is not yielded.
  return false;
}

// bufferization.to_memref is not allowed to change the rank.
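// E.g., converting a `tensor<4x4xf32>` into a `memref<16xf32>` (rank 2 to
// rank 1) would trigger the assertion below.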
static void ensureToMemrefOpIsValid(Value tensor, Type memrefType) {
#ifndef NDEBUG
  auto rankedTensorType = tensor.getType().dyn_cast<RankedTensorType>();
  assert((!rankedTensorType || memrefType.cast<MemRefType>().getRank() ==
                                   rankedTensorType.getRank()) &&
         "to_memref would be invalid: mismatching ranks");
#endif
}

FailureOr<Value> bufferization::getBuffer(RewriterBase &rewriter, Value value,
                                          const BufferizationOptions &options) {
#ifndef NDEBUG
  auto tensorType = value.getType().dyn_cast<TensorType>();
  assert(tensorType && "unexpected non-tensor type");
#endif // NDEBUG

  // Replace "%t = to_tensor %m" with %m.
  if (auto toTensorOp = value.getDefiningOp<bufferization::ToTensorOp>())
    return toTensorOp.getMemref();

  // Insert to_memref op.
  OpBuilder::InsertionGuard g(rewriter);
  setInsertionPointAfter(rewriter, value);
  FailureOr<BaseMemRefType> memrefType = getBufferType(value, options);
  if (failed(memrefType))
    return failure();
  ensureToMemrefOpIsValid(value, *memrefType);
  return rewriter
      .create<bufferization::ToMemrefOp>(value.getLoc(), *memrefType, value)
      .getResult();
}

/// Return the buffer type for a given Value (tensor) after bufferization.
FailureOr<BaseMemRefType>
bufferization::getBufferType(Value value,
                             const BufferizationOptions &options) {
  assert(value.getType().isa<TensorType>() && "unexpected non-tensor type");
  Operation *op = getOwnerOfValue(value);

  // ToTensorOp: Take the buffer type directly from the op.
  if (auto toTensorOp = value.getDefiningOp<bufferization::ToTensorOp>())
    return toTensorOp.getMemref().getType().cast<BaseMemRefType>();

  // If value is a bbArg of a bufferizable op: query the op interface.
  if (auto bbArg = value.dyn_cast<BlockArgument>())
    if (auto bufferizableOp =
            options.dynCastBufferizableOp(bbArg.getOwner()->getParentOp()))
      return bufferizableOp.getBufferType(bbArg, options);

  // Check if value is a new buffer allocation with a memory space attribute.
  // In that case we can at least infer the memory space.
  Optional<unsigned> memorySpace = None;
  if (auto opResult = value.dyn_cast<OpResult>()) {
    if (auto bufferizableOp =
            options.dynCastBufferizableOp(opResult.getDefiningOp())) {
      if (bufferizableOp.bufferizesToAllocation(opResult)) {
        FailureOr<unsigned> queriedMemorySpace =
            bufferizableOp.getMemorySpace(opResult);
        if (succeeded(queriedMemorySpace))
          memorySpace = *queriedMemorySpace;
      }
    }
  }

  // If we still do not know the memory space, use the default memory space
  // (if any).
  if (!memorySpace.has_value())
    memorySpace = options.defaultMemorySpace;

  // If we still do not know the memory space, report a failure.
  if (!memorySpace.has_value())
    return op->emitError("could not infer memory space");

  return getMemRefType(value, options, /*layout=*/{}, *memorySpace);
}

void bufferization::replaceOpWithBufferizedValues(RewriterBase &rewriter,
                                                  Operation *op,
                                                  ValueRange values) {
  assert(values.size() == op->getNumResults() &&
         "expected one value per OpResult");
  OpBuilder::InsertionGuard g(rewriter);

  // Replace all OpResults with the given values.
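  // Tensor results are wrapped in a bufferization.to_tensor op (see below),
  // so the remaining tensor users of the op stay valid.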
  SmallVector<Value> replacements;
  for (OpResult opResult : op->getOpResults()) {
    Value replacement = values[opResult.getResultNumber()];
    if (opResult.getType().isa<TensorType>()) {
      // The OpResult is a tensor. Such values are replaced with memrefs during
      // bufferization.
      assert((replacement.getType().isa<MemRefType>() ||
              replacement.getType().isa<UnrankedMemRefType>()) &&
             "tensor op result should be replaced with a memref value");
      // The existing uses of the OpResult still expect a tensor. Insert a
      // ToTensorOp. Throughout bufferization, this ToTensorOp will gradually
      // lose all of its users and eventually DCE away.
      rewriter.setInsertionPointAfter(op);
      replacement = rewriter.create<bufferization::ToTensorOp>(
          replacement.getLoc(), replacement);
    }
    replacements.push_back(replacement);
  }

  rewriter.replaceOp(op, replacements);
}

//===----------------------------------------------------------------------===//
// Bufferization-specific scoped alloc/dealloc insertion support.
//===----------------------------------------------------------------------===//

/// Create a memref allocation with the given type and dynamic extents.
FailureOr<Value> BufferizationOptions::createAlloc(OpBuilder &b, Location loc,
                                                   MemRefType type,
                                                   ValueRange dynShape) const {
  if (allocationFn)
    return (*allocationFn)(b, loc, type, dynShape, bufferAlignment);

  // Default buffer allocation via AllocOp.
  if (bufferAlignment != 0)
    return b
        .create<memref::AllocOp>(loc, type, dynShape,
                                 b.getI64IntegerAttr(bufferAlignment))
        .getResult();
  return b.create<memref::AllocOp>(loc, type, dynShape).getResult();
}

/// Create a memref deallocation. The given memref buffer must have been
/// allocated using `createAlloc`.
LogicalResult BufferizationOptions::createDealloc(OpBuilder &b, Location loc,
                                                  Value allocatedBuffer) const {
  if (deallocationFn)
    return (*deallocationFn)(b, loc, allocatedBuffer);

  // Default buffer deallocation via DeallocOp.
  b.create<memref::DeallocOp>(loc, allocatedBuffer);
  return success();
}

/// Create a memory copy between two memref buffers.
LogicalResult BufferizationOptions::createMemCpy(OpBuilder &b, Location loc,
                                                 Value from, Value to) const {
  if (memCpyFn)
    return (*memCpyFn)(b, loc, from, to);

  b.create<memref::CopyOp>(loc, from, to);
  return success();
}

//===----------------------------------------------------------------------===//
// Bufferization-specific BlockAndValueMapping support with debugging.
//===----------------------------------------------------------------------===//

bool bufferization::isFunctionArgument(Value value) {
  auto bbArg = value.dyn_cast<BlockArgument>();
  if (!bbArg)
    return false;
  return isa<func::FuncOp>(bbArg.getOwner()->getParentOp());
}

BaseMemRefType bufferization::getMemRefType(Value value,
                                            const BufferizationOptions &options,
                                            MemRefLayoutAttrInterface layout,
                                            unsigned memorySpace) {
  auto tensorType = value.getType().cast<TensorType>();
  auto memorySpaceAttr = IntegerAttr::get(
      IntegerType::get(tensorType.getContext(), 64), memorySpace);

  // Case 1: Unranked memref type.
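  // E.g., tensor<*xf32> maps to memref<*xf32>. Layouts do not apply to
  // unranked types.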
  if (auto unrankedTensorType = tensorType.dyn_cast<UnrankedTensorType>()) {
    assert(!layout && "UnrankedTensorType cannot have a layout map");
    return UnrankedMemRefType::get(unrankedTensorType.getElementType(),
                                   memorySpaceAttr);
  }

  // Case 2: Ranked memref type with specified layout.
  auto rankedTensorType = tensorType.cast<RankedTensorType>();
  if (layout) {
    return MemRefType::get(rankedTensorType.getShape(),
                           rankedTensorType.getElementType(), layout,
                           memorySpaceAttr);
  }

  return options.unknownTypeConverterFn(value, memorySpace, options);
}

BaseMemRefType
bufferization::getMemRefTypeWithFullyDynamicLayout(TensorType tensorType,
                                                   unsigned memorySpace) {
  // Case 1: Unranked memref type.
  if (auto unrankedTensorType = tensorType.dyn_cast<UnrankedTensorType>()) {
    return UnrankedMemRefType::get(unrankedTensorType.getElementType(),
                                   memorySpace);
  }

  // Case 2: Ranked memref type.
  auto memorySpaceAttr = IntegerAttr::get(
      IntegerType::get(tensorType.getContext(), 64), memorySpace);
  auto rankedTensorType = tensorType.cast<RankedTensorType>();
  int64_t dynamicOffset = ShapedType::kDynamicStrideOrOffset;
  SmallVector<int64_t> dynamicStrides(rankedTensorType.getRank(),
                                      ShapedType::kDynamicStrideOrOffset);
  AffineMap stridedLayout = makeStridedLinearLayoutMap(
      dynamicStrides, dynamicOffset, rankedTensorType.getContext());
  return MemRefType::get(rankedTensorType.getShape(),
                         rankedTensorType.getElementType(), stridedLayout,
                         memorySpaceAttr);
}

/// Return a MemRef type with a static identity layout (i.e., no layout map).
/// If the given tensor type is unranked, return an unranked MemRef type.
BaseMemRefType
bufferization::getMemRefTypeWithStaticIdentityLayout(TensorType tensorType,
                                                     unsigned memorySpace) {
  // Case 1: Unranked memref type.
  if (auto unrankedTensorType = tensorType.dyn_cast<UnrankedTensorType>()) {
    return UnrankedMemRefType::get(unrankedTensorType.getElementType(),
                                   memorySpace);
  }

  // Case 2: Ranked memref type.
  auto rankedTensorType = tensorType.cast<RankedTensorType>();
  auto memorySpaceAttr = IntegerAttr::get(
      IntegerType::get(tensorType.getContext(), 64), memorySpace);
  MemRefLayoutAttrInterface layout = {};
  return MemRefType::get(rankedTensorType.getShape(),
                         rankedTensorType.getElementType(), layout,
                         memorySpaceAttr);
}