//===- BufferizableOpInterface.cpp - Bufferizable Ops ---=----------------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//

#include "mlir/Dialect/Bufferization/IR/BufferizableOpInterface.h"
#include "mlir/Dialect/Bufferization/IR/Bufferization.h"
#include "mlir/Dialect/Func/IR/FuncOps.h"
#include "mlir/Dialect/MemRef/IR/MemRef.h"
#include "mlir/Dialect/Tensor/IR/Tensor.h"
#include "mlir/IR/AsmState.h"
#include "mlir/IR/BlockAndValueMapping.h"
#include "mlir/IR/BuiltinOps.h"
#include "mlir/IR/Operation.h"
#include "mlir/IR/TypeUtilities.h"
#include "mlir/IR/Value.h"
#include "llvm/Support/Debug.h"

//===----------------------------------------------------------------------===//
// BufferizableOpInterface
//===----------------------------------------------------------------------===//

namespace mlir {
namespace bufferization {

#include "mlir/Dialect/Bufferization/IR/BufferizableOpInterface.cpp.inc"

} // namespace bufferization
} // namespace mlir

#define DEBUG_TYPE "bufferizable-op-interface"
#define DBGS() (llvm::dbgs() << '[' << DEBUG_TYPE << "] ")
#define LDBG(X) LLVM_DEBUG(DBGS() << (X))

using namespace mlir;
using namespace bufferization;

/// Return the owner of the given value.
static Operation *getOwnerOfValue(Value value) {
  if (auto opResult = value.dyn_cast<OpResult>())
    return opResult.getDefiningOp();
  return value.cast<BlockArgument>().getOwner()->getParentOp();
}

bool bufferization::allocationDoesNotEscape(OpResult opResult) {
#ifndef NDEBUG
  auto bufferizableOp = opResult.getDefiningOp<BufferizableOpInterface>();
  assert(bufferizableOp && bufferizableOp.bufferizesToAllocation(opResult) &&
         "expected op that bufferizes to an allocation");
#endif // NDEBUG

  Operation *op = opResult.getDefiningOp();
  // If there is no 'escape' attribute, we cannot say for sure.
  if (!op->hasAttr(BufferizationDialect::kEscapeAttrName))
    return false;
  auto attr =
      op->getAttrOfType<ArrayAttr>(BufferizationDialect::kEscapeAttrName);
  return !attr[opResult.getResultNumber()].cast<BoolAttr>().getValue();
}

/// Create an AllocTensorOp for the given shaped value. If `copy` is set, the
/// shaped value is copied. Otherwise, a tensor with undefined contents is
/// allocated.
FailureOr<Value> bufferization::allocateTensorForShapedValue(
    OpBuilder &b, Location loc, Value shapedValue, bool escape,
    const BufferizationOptions &options, bool copy) {
  Value tensor;
  if (shapedValue.getType().isa<RankedTensorType>()) {
    tensor = shapedValue;
  } else if (shapedValue.getType().isa<MemRefType>()) {
    tensor = b.create<ToTensorOp>(loc, shapedValue);
  } else {
    llvm_unreachable("expected RankedTensorType or MemRefType");
  }
  RankedTensorType tensorType = tensor.getType().cast<RankedTensorType>();
  SmallVector<Value> dynamicSizes;
  if (!copy) {
    // Compute the dynamic part of the shape.
    // First try to query the shape via ReifyRankedShapedTypeOpInterface.
    bool reifiedShapes = false;
    if (shapedValue.getType().isa<RankedTensorType>() &&
        shapedValue.isa<OpResult>()) {
      if (auto rankedOp = dyn_cast_or_null<ReifyRankedShapedTypeOpInterface>(
              shapedValue.getDefiningOp())) {
        ReifiedRankedShapedTypeDims resultDims;
        if (succeeded(rankedOp.reifyResultShapes(b, resultDims))) {
          reifiedShapes = true;
          auto &shape =
              resultDims[shapedValue.cast<OpResult>().getResultNumber()];
          for (const auto &dim : enumerate(tensorType.getShape()))
            if (ShapedType::isDynamic(dim.value()))
              dynamicSizes.push_back(shape[dim.index()]);
        }
      }
    }

    // If the shape could not be reified, create DimOps.
    if (!reifiedShapes)
      populateDynamicDimSizes(b, loc, tensor, dynamicSizes);
  }

  // Create AllocTensorOp.
  auto allocTensorOp = b.create<AllocTensorOp>(loc, tensorType, dynamicSizes,
                                               copy ? tensor : Value());
  allocTensorOp->setAttr(BufferizationDialect::kEscapeAttrName,
                         b.getBoolArrayAttr({escape}));

  // Add 'memory_space' attribute. Not needed if 'copy' operand is specified.
  if (copy)
    return allocTensorOp.getResult();
  FailureOr<BaseMemRefType> copyBufferType = getBufferType(tensor, options);
  if (failed(copyBufferType))
    return failure();
  allocTensorOp.setMemorySpaceAttr(
      b.getIntegerAttr(b.getIntegerType(64, /*isSigned=*/false),
                       copyBufferType->getMemorySpaceAsInt()));
  return allocTensorOp.getResult();
}

LogicalResult BufferizableOpInterface::resolveTensorOpOperandConflicts(
    RewriterBase &rewriter, const AnalysisState &state) {
  OpBuilder::InsertionGuard g(rewriter);
  Operation *op = getOperation();
  SmallVector<OpOperand *> outOfPlaceOpOperands;
  DenseSet<OpOperand *> copiedOpOperands;
  DenseSet<OpOperand *> escapingOpOperandCopies;
  SmallVector<OpResult> outOfPlaceOpResults;
  DenseSet<OpResult> copiedOpResults;
  DenseSet<OpResult> escapingOpResultCopies;

  // Find all out-of-place OpOperands.
  for (OpOperand &opOperand : op->getOpOperands()) {
    Type operandType = opOperand.get().getType();
    if (!operandType.isa<TensorType>())
      continue;
    if (state.isInPlace(opOperand))
      continue;
    if (operandType.isa<UnrankedTensorType>())
      return op->emitError("copies of unranked tensors are not supported");

    SmallVector<OpResult> aliasingOpResults =
        state.getAliasingOpResult(opOperand);
    // Is the result yielded from a block? Or are deallocations turned off
    // entirely? In either case, mark the allocation as "escaping", so that it
    // will not be deallocated.
    bool escape = !state.getOptions().createDeallocs ||
                  llvm::any_of(aliasingOpResults, [&](Value v) {
                    return state.isTensorYielded(v);
                  });

    if (aliasingOpResults.size() == 1 &&
        !state.bufferizesToMemoryWrite(opOperand) &&
        state.getAliasingOpOperand(aliasingOpResults.front()).size() == 1) {
      // The op itself does not write but may create exactly one alias. Instead
      // of copying the OpOperand, copy the OpResult. The OpResult can
      // sometimes be smaller than the OpOperand (e.g., in the case of an
      // extract_slice, where the result is usually a smaller part of the
      // source).
      outOfPlaceOpResults.push_back(aliasingOpResults.front());
      if (!state.canOmitTensorCopy(opOperand))
        copiedOpResults.insert(aliasingOpResults.front());
      if (escape)
        escapingOpResultCopies.insert(aliasingOpResults.front());
    } else {
      // In all other cases, make a copy of the OpOperand.
      outOfPlaceOpOperands.push_back(&opOperand);
      if (!state.canOmitTensorCopy(opOperand))
        copiedOpOperands.insert(&opOperand);
      if (escape)
        escapingOpOperandCopies.insert(&opOperand);
    }
  }

  // Insert copies of OpOperands.
  rewriter.setInsertionPoint(op);
  for (OpOperand *opOperand : outOfPlaceOpOperands) {
    FailureOr<Value> copy = allocateTensorForShapedValue(
        rewriter, op->getLoc(), opOperand->get(),
        escapingOpOperandCopies.contains(opOperand), state.getOptions(),
        copiedOpOperands.contains(opOperand));
    if (failed(copy))
      return failure();
    rewriter.updateRootInPlace(op, [&]() { opOperand->set(*copy); });
  }

  // Insert copies of OpResults.
  rewriter.setInsertionPointAfter(op);
  for (OpResult opResult : outOfPlaceOpResults) {
    FailureOr<Value> copy = allocateTensorForShapedValue(
        rewriter, op->getLoc(), opResult,
        escapingOpResultCopies.contains(opResult), state.getOptions(),
        copiedOpResults.count(opResult));
    if (failed(copy))
      return failure();
    SmallVector<OpOperand *> uses = llvm::to_vector(llvm::map_range(
        opResult.getUses(), [](OpOperand &use) { return &use; }));
    for (OpOperand *use : uses) {
      // Do not update the alloc_tensor op that we just created.
      if (use->getOwner() != copy->getDefiningOp())
        rewriter.updateRootInPlace(use->getOwner(), [&]() { use->set(*copy); });
    }
  }

  return success();
}

bool bufferization::shouldDeallocateOpResult(
    OpResult opResult, const BufferizationOptions &options) {
  Operation *op = opResult.getOwner();
  assert(options.dynCastBufferizableOp(op).bufferizesToAllocation(opResult) &&
         "expected that op allocates");

  AnalysisState analysisState(options);
  if (op->hasAttr(BufferizationDialect::kEscapeAttrName)) {
    // AllocTensorOp has one result.
    ArrayAttr escapeAttr =
        op->getAttr(BufferizationDialect::kEscapeAttrName).cast<ArrayAttr>();
    return !escapeAttr[0].cast<BoolAttr>().getValue();
  }

  // No "escape" annotation found.
  if (options.createDeallocs) {
    // Perform an ad-hoc analysis.
    return !analysisState.isTensorYielded(opResult);
  }

  return false;
}

//===----------------------------------------------------------------------===//
// OpFilter
//===----------------------------------------------------------------------===//

bool OpFilter::isOpAllowed(Operation *op) const {
  // All other ops: Allow/disallow according to filter.
  bool isAllowed = !hasAllowRule();
  for (const Entry &entry : entries) {
    bool filterResult = entry.fn(op);
    switch (entry.type) {
    case Entry::ALLOW:
      isAllowed |= filterResult;
      break;
    case Entry::DENY:
      if (filterResult)
        // DENY filter matches. This op is not allowed. (Even if other ALLOW
        // filters may match.)
        return false;
    };
  }
  return isAllowed;
}

//===----------------------------------------------------------------------===//
// BufferizationOptions
//===----------------------------------------------------------------------===//

/// Default unknown type converter: Use a fully dynamic layout map.
static BaseMemRefType
defaultUnknownTypeConverter(Value value, unsigned memorySpace,
                            const BufferizationOptions &options) {
  return getMemRefTypeWithFullyDynamicLayout(
      value.getType().cast<TensorType>(), memorySpace);
}

// Default constructor for BufferizationOptions.
BufferizationOptions::BufferizationOptions()
    : unknownTypeConverterFn(defaultUnknownTypeConverter) {}

bool BufferizationOptions::isOpAllowed(Operation *op) const {
  // Special case: If function boundary bufferization is deactivated, do not
  // allow ops that belong to the `func` dialect.
  bool isFuncBoundaryOp = isa_and_nonnull<func::FuncDialect>(op->getDialect());
  if (!bufferizeFunctionBoundaries && isFuncBoundaryOp)
    return false;

  return opFilter.isOpAllowed(op);
}

BufferizableOpInterface
BufferizationOptions::dynCastBufferizableOp(Operation *op) const {
  auto bufferizableOp = dyn_cast<BufferizableOpInterface>(op);
  if (!bufferizableOp)
    return nullptr;
  if (!isOpAllowed(op))
    return nullptr;
  return bufferizableOp;
}

BufferizableOpInterface
BufferizationOptions::dynCastBufferizableOp(Value value) const {
  if (auto bufferizableOp = value.getDefiningOp<BufferizableOpInterface>())
    if (isOpAllowed(bufferizableOp.getOperation()))
      return bufferizableOp;
  return nullptr;
}

void BufferizationOptions::addDialectStateInitializer(
    StringRef name, const DialectStateInitFn &fn) {
  stateInitializers.push_back(
      [=](AnalysisState &state) { state.insertDialectState(name, fn()); });
}

//===----------------------------------------------------------------------===//
// Helper functions for BufferizableOpInterface
//===----------------------------------------------------------------------===//

/// Set the builder insertion point to right after `value` is defined: at the
/// beginning of the block for block arguments, otherwise directly after the
/// defining op.
static void setInsertionPointAfter(OpBuilder &b, Value value) {
  if (auto bbArg = value.dyn_cast<BlockArgument>()) {
    b.setInsertionPointToStart(bbArg.getOwner());
  } else {
    b.setInsertionPointAfter(value.getDefiningOp());
  }
}

/// Determine which OpOperand* will alias with `result` if the op is bufferized
/// in place. Return an empty vector if the op is not bufferizable.
SmallVector<OpOperand *>
AnalysisState::getAliasingOpOperand(OpResult result) const {
  if (Operation *op = result.getDefiningOp())
    if (auto bufferizableOp = getOptions().dynCastBufferizableOp(op))
      return bufferizableOp.getAliasingOpOperand(result, *this);
  return {};
}

/// Determine which OpResult will alias with `opOperand` if the op is
/// bufferized in place. Return an empty vector if the op is not bufferizable.
SmallVector<OpResult>
AnalysisState::getAliasingOpResult(OpOperand &opOperand) const {
  if (auto bufferizableOp =
          getOptions().dynCastBufferizableOp(opOperand.getOwner()))
    return bufferizableOp.getAliasingOpResult(opOperand, *this);
  return {};
}

/// Return true if `opOperand` bufferizes to a memory read. Return `true` if
/// the op is not bufferizable.
bool AnalysisState::bufferizesToMemoryRead(OpOperand &opOperand) const {
  if (auto bufferizableOp =
          getOptions().dynCastBufferizableOp(opOperand.getOwner()))
    return bufferizableOp.bufferizesToMemoryRead(opOperand, *this);

  // Unknown op that returns a tensor. The inplace analysis does not support
  // it. Conservatively return true.
  return true;
}

/// Return true if `opOperand` bufferizes to a memory write. Return `true` if
/// the op is not bufferizable.
bool AnalysisState::bufferizesToMemoryWrite(OpOperand &opOperand) const {
  if (auto bufferizableOp =
          getOptions().dynCastBufferizableOp(opOperand.getOwner()))
    return bufferizableOp.bufferizesToMemoryWrite(opOperand, *this);

  // Unknown op that returns a tensor. The inplace analysis does not support
  // it. Conservatively return true.
  return true;
}

/// Return true if `opOperand` neither reads nor writes but bufferizes to an
/// alias. Return false if the op is not bufferizable.
bool AnalysisState::bufferizesToAliasOnly(OpOperand &opOperand) const {
  if (auto bufferizableOp =
          getOptions().dynCastBufferizableOp(opOperand.getOwner()))
    return bufferizableOp.bufferizesToAliasOnly(opOperand, *this);

  // Unknown op that returns a tensor. The inplace analysis does not support
  // it. Conservatively return false.
  return false;
}

/// Return true if the given value is read by an op that bufferizes to a memory
/// read. Also takes into account ops that create an alias but do not read by
/// themselves (e.g., ExtractSliceOp).
bool AnalysisState::isValueRead(Value value) const {
  assert(value.getType().isa<TensorType>() && "expected TensorType");
  SmallVector<OpOperand *> workingSet;
  for (OpOperand &use : value.getUses())
    workingSet.push_back(&use);

  while (!workingSet.empty()) {
    OpOperand *uMaybeReading = workingSet.pop_back_val();
    // Skip over all ops that neither read nor write (but create an alias).
    if (bufferizesToAliasOnly(*uMaybeReading))
      for (OpResult opResult : getAliasingOpResult(*uMaybeReading))
        for (OpOperand &use : opResult.getUses())
          workingSet.push_back(&use);
    if (bufferizesToMemoryRead(*uMaybeReading))
      return true;
  }

  return false;
}

// Starting from `value`, follow the use-def chain in reverse, always selecting
// the aliasing OpOperands. Find and return Values for which `condition`
// evaluates to true. OpOperands of such matching Values are not traversed any
// further.
llvm::SetVector<Value> AnalysisState::findValueInReverseUseDefChain(
    Value value, llvm::function_ref<bool(Value)> condition) const {
  llvm::SetVector<Value> result, workingSet;
  workingSet.insert(value);

  while (!workingSet.empty()) {
    Value value = workingSet.pop_back_val();
    if (condition(value) || value.isa<BlockArgument>()) {
      result.insert(value);
      continue;
    }

    OpResult opResult = value.cast<OpResult>();
    SmallVector<OpOperand *> opOperands = getAliasingOpOperand(opResult);
    if (opOperands.empty() || !options.isOpAllowed(value.getDefiningOp())) {
      result.insert(value);
      continue;
    }

    for (OpOperand *o : opOperands)
      workingSet.insert(o->get());
  }

  return result;
}

// Find the Values of the last preceding write of a given Value.
llvm::SetVector<Value>
AnalysisState::findLastPrecedingWrite(Value value) const {
  return findValueInReverseUseDefChain(value, [&](Value value) {
    Operation *op = value.getDefiningOp();
    if (!op)
      return true;
    auto bufferizableOp = options.dynCastBufferizableOp(op);
    if (!bufferizableOp)
      return true;
    return bufferizableOp.isMemoryWrite(value.cast<OpResult>(), *this);
  });
}

AnalysisState::AnalysisState(const BufferizationOptions &options)
    : options(options) {
  for (const BufferizationOptions::AnalysisStateInitFn &fn :
       options.stateInitializers)
    fn(*this);
}

bool AnalysisState::canOmitTensorCopy(OpOperand &opOperand) const {
  // Do not copy if the tensor has undefined contents.
  if (hasUndefinedContents(&opOperand))
    return true;

  // Do not copy if the buffer of the tensor is entirely overwritten (with
  // values that do not depend on the old tensor).
  if (bufferizesToMemoryWrite(opOperand) && !bufferizesToMemoryRead(opOperand))
    return true;

  // Do not copy if the tensor is never read.
  SmallVector<OpResult> aliasingOpResults = getAliasingOpResult(opOperand);
  if (!bufferizesToMemoryRead(opOperand) &&
      llvm::none_of(aliasingOpResults,
                    [&](OpResult opResult) { return isValueRead(opResult); }))
    return true;

  // Default: Cannot omit the copy.
  return false;
}

bool AnalysisState::isInPlace(OpOperand &opOperand) const {
  // ToMemrefOps are always in-place.
  if (isa<ToMemrefOp>(opOperand.getOwner()))
    return true;

  // In the absence of analysis information, OpOperands that bufferize to a
  // memory write are out-of-place, i.e., an alloc and copy is inserted.
  return !bufferizesToMemoryWrite(opOperand);
}

bool AnalysisState::areEquivalentBufferizedValues(Value v1, Value v2) const {
  // In the absence of analysis information, we do not know if the values are
  // equivalent. The conservative answer is "false".
  return false;
}

bool AnalysisState::areAliasingBufferizedValues(Value v1, Value v2) const {
  // In the absence of analysis information, we do not know if the values may
  // be aliasing. The conservative answer is "true".
  return true;
}

bool AnalysisState::hasUndefinedContents(OpOperand *opOperand) const {
  // In the absence of analysis information, the conservative answer is
  // "false".
  return false;
}

bool AnalysisState::isTensorYielded(Value tensor) const {
  // In the absence of analysis information, the conservative answer is "true".
  if (!tensor.getDefiningOp<AllocTensorOp>())
    return true;

  // For AllocTensorOp results, we can do better: They do not alias with any
  // preceding value, so we can follow SSA use-def chains and do a simple
  // analysis.
  SmallVector<OpOperand *> worklist;
  for (OpOperand &use : tensor.getUses())
    worklist.push_back(&use);

  while (!worklist.empty()) {
    OpOperand *operand = worklist.pop_back_val();
    Operation *op = operand->getOwner();

    // If the op is not bufferizable, we can safely assume that the value is
    // not yielded. (When bufferizing that op, it must handle such cases.)
    if (!options.dynCastBufferizableOp(op))
      continue;

    // We cannot analyze through ToMemrefOps, so we have to conservatively
    // assume that the value is yielded.
    if (isa<ToMemrefOp>(op))
      return true;

    // Check if the op is returning/yielding.
    if (isRegionReturnLike(op))
      return true;

    // Add all aliasing OpResults to the worklist.
    // Note: In the absence of detailed analysis information (e.g., there may
    // be no function call analysis information), this `getAliasingOpResult`
    // is conservative and may report additional OpResults as potentially
    // aliasing.
    for (OpResult opResult : getAliasingOpResult(*operand))
      for (OpOperand &use : opResult.getUses())
        worklist.push_back(&use);
  }

  // No ReturnLike op found: The value is not yielded.
  return false;
}

// bufferization.to_memref is not allowed to change the rank.
static void ensureToMemrefOpIsValid(Value tensor, Type memrefType) {
#ifndef NDEBUG
  auto rankedTensorType = tensor.getType().dyn_cast<RankedTensorType>();
  assert((!rankedTensorType || memrefType.cast<MemRefType>().getRank() ==
                                   rankedTensorType.getRank()) &&
         "to_memref would be invalid: mismatching ranks");
#endif
}

FailureOr<Value> bufferization::getBuffer(RewriterBase &rewriter, Value value,
                                          const BufferizationOptions &options) {
#ifndef NDEBUG
  auto tensorType = value.getType().dyn_cast<TensorType>();
  assert(tensorType && "unexpected non-tensor type");
#endif // NDEBUG

  // Replace "%t = to_tensor %m" with %m.
  if (auto toTensorOp = value.getDefiningOp<bufferization::ToTensorOp>())
    return toTensorOp.getMemref();

  // Insert to_memref op.
  OpBuilder::InsertionGuard g(rewriter);
  setInsertionPointAfter(rewriter, value);
  FailureOr<BaseMemRefType> memrefType = getBufferType(value, options);
  if (failed(memrefType))
    return failure();
  ensureToMemrefOpIsValid(value, *memrefType);
  return rewriter
      .create<bufferization::ToMemrefOp>(value.getLoc(), *memrefType, value)
      .getResult();
}

/// Return the buffer type for a given Value (tensor) after bufferization.
FailureOr<BaseMemRefType>
bufferization::getBufferType(Value value, const BufferizationOptions &options) {
  assert(value.getType().isa<TensorType>() && "unexpected non-tensor type");
  Operation *op = getOwnerOfValue(value);

  // ToTensorOp: Take buffer type directly from the op.
  if (auto toTensorOp = value.getDefiningOp<bufferization::ToTensorOp>())
    return toTensorOp.getMemref().getType().cast<BaseMemRefType>();

  // If value is a bbArg of a bufferizable op: query op interface.
  if (auto bbArg = value.dyn_cast<BlockArgument>())
    if (auto bufferizableOp =
            options.dynCastBufferizableOp(bbArg.getOwner()->getParentOp()))
      return bufferizableOp.getBufferType(bbArg, options);

  // Check if the value is a new buffer allocation with a memory space
  // attribute. In that case, we can at least infer the memory space.
  Optional<unsigned> memorySpace = None;
  if (auto opResult = value.dyn_cast<OpResult>()) {
    if (auto bufferizableOp =
            options.dynCastBufferizableOp(opResult.getDefiningOp())) {
      if (bufferizableOp.bufferizesToAllocation(opResult)) {
        FailureOr<unsigned> queriedMemorySpace =
            bufferizableOp.getMemorySpace(opResult);
        if (succeeded(queriedMemorySpace))
          memorySpace = *queriedMemorySpace;
      }
    }
  }

  // If we still do not know the memory space, use the default memory space
  // (if any).
  if (!memorySpace.has_value())
    memorySpace = options.defaultMemorySpace;

  // If we still do not know the memory space, report a failure.
  if (!memorySpace.has_value())
    return op->emitError("could not infer memory space");

  return getMemRefType(value, options, /*layout=*/{}, *memorySpace);
}

void bufferization::replaceOpWithBufferizedValues(RewriterBase &rewriter,
                                                  Operation *op,
                                                  ValueRange values) {
  assert(values.size() == op->getNumResults() &&
         "expected one value per OpResult");
  OpBuilder::InsertionGuard g(rewriter);

  // Replace all OpResults with the given values.
  SmallVector<Value> replacements;
  for (OpResult opResult : op->getOpResults()) {
    Value replacement = values[opResult.getResultNumber()];
    if (opResult.getType().isa<TensorType>()) {
      // The OpResult is a tensor. Such values are replaced with memrefs during
      // bufferization.
      assert((replacement.getType().isa<MemRefType>() ||
              replacement.getType().isa<UnrankedMemRefType>()) &&
             "tensor op result should be replaced with a memref value");
      // The existing uses of the OpResult still expect a tensor. Insert a
      // ToTensorOp. Throughout bufferization, this ToTensorOp will gradually
      // lose all of its users and eventually DCE away.
      rewriter.setInsertionPointAfter(op);
      replacement = rewriter.create<bufferization::ToTensorOp>(
          replacement.getLoc(), replacement);
    }
    replacements.push_back(replacement);
  }

  rewriter.replaceOp(op, replacements);
}

//===----------------------------------------------------------------------===//
// Bufferization-specific scoped alloc/dealloc insertion support.
//===----------------------------------------------------------------------===//

/// Create a memref allocation with the given type and dynamic extents.
FailureOr<Value> BufferizationOptions::createAlloc(OpBuilder &b, Location loc,
                                                   MemRefType type,
                                                   ValueRange dynShape) const {
  if (allocationFn)
    return (*allocationFn)(b, loc, type, dynShape, bufferAlignment);

  // Default buffer allocation via AllocOp.
  if (bufferAlignment != 0)
    return b
        .create<memref::AllocOp>(loc, type, dynShape,
                                 b.getI64IntegerAttr(bufferAlignment))
        .getResult();
  return b.create<memref::AllocOp>(loc, type, dynShape).getResult();
}

/// Creates a memref deallocation. The given memref buffer must have been
/// allocated using `createAlloc`.
LogicalResult BufferizationOptions::createDealloc(OpBuilder &b, Location loc,
                                                  Value allocatedBuffer) const {
  if (deallocationFn)
    return (*deallocationFn)(b, loc, allocatedBuffer);

  // Default buffer deallocation via DeallocOp.
  b.create<memref::DeallocOp>(loc, allocatedBuffer);
  return success();
}

/// Create a memory copy between two memref buffers.
LogicalResult BufferizationOptions::createMemCpy(OpBuilder &b, Location loc,
                                                 Value from, Value to) const {
  if (memCpyFn)
    return (*memCpyFn)(b, loc, from, to);

  b.create<memref::CopyOp>(loc, from, to);
  return success();
}

//===----------------------------------------------------------------------===//
// Bufferization-specific BlockAndValueMapping support with debugging.
//===----------------------------------------------------------------------===//

bool bufferization::isFunctionArgument(Value value) {
  auto bbArg = value.dyn_cast<BlockArgument>();
  if (!bbArg)
    return false;
  return isa<func::FuncOp>(bbArg.getOwner()->getParentOp());
}

BaseMemRefType bufferization::getMemRefType(Value value,
                                            const BufferizationOptions &options,
                                            MemRefLayoutAttrInterface layout,
                                            unsigned memorySpace) {
  auto tensorType = value.getType().cast<TensorType>();
  auto memorySpaceAttr = IntegerAttr::get(
      IntegerType::get(tensorType.getContext(), 64), memorySpace);

  // Case 1: Unranked memref type.
  if (auto unrankedTensorType = tensorType.dyn_cast<UnrankedTensorType>()) {
    assert(!layout && "UnrankedTensorType cannot have a layout map");
    return UnrankedMemRefType::get(unrankedTensorType.getElementType(),
                                   memorySpaceAttr);
  }

  // Case 2: Ranked memref type with specified layout.
  auto rankedTensorType = tensorType.cast<RankedTensorType>();
  if (layout) {
    return MemRefType::get(rankedTensorType.getShape(),
                           rankedTensorType.getElementType(), layout,
                           memorySpaceAttr);
  }

  return options.unknownTypeConverterFn(value, memorySpace, options);
}

BaseMemRefType
bufferization::getMemRefTypeWithFullyDynamicLayout(TensorType tensorType,
                                                   unsigned memorySpace) {
  // Case 1: Unranked memref type.
  if (auto unrankedTensorType = tensorType.dyn_cast<UnrankedTensorType>()) {
    return UnrankedMemRefType::get(unrankedTensorType.getElementType(),
                                   memorySpace);
  }

  // Case 2: Ranked memref type.
  auto memorySpaceAttr = IntegerAttr::get(
      IntegerType::get(tensorType.getContext(), 64), memorySpace);
  auto rankedTensorType = tensorType.cast<RankedTensorType>();
  int64_t dynamicOffset = ShapedType::kDynamicStrideOrOffset;
  SmallVector<int64_t> dynamicStrides(rankedTensorType.getRank(),
                                      ShapedType::kDynamicStrideOrOffset);
  AffineMap stridedLayout = makeStridedLinearLayoutMap(
      dynamicStrides, dynamicOffset, rankedTensorType.getContext());
  return MemRefType::get(rankedTensorType.getShape(),
                         rankedTensorType.getElementType(), stridedLayout,
                         memorySpaceAttr);
}

/// Return a MemRef type with a static identity layout (i.e., no layout map).
/// If the given tensor type is unranked, return an unranked MemRef type.
BaseMemRefType
bufferization::getMemRefTypeWithStaticIdentityLayout(TensorType tensorType,
                                                     unsigned memorySpace) {
  // Case 1: Unranked memref type.
  if (auto unrankedTensorType = tensorType.dyn_cast<UnrankedTensorType>()) {
    return UnrankedMemRefType::get(unrankedTensorType.getElementType(),
                                   memorySpace);
  }

  // Case 2: Ranked memref type.
  auto rankedTensorType = tensorType.cast<RankedTensorType>();
  auto memorySpaceAttr = IntegerAttr::get(
      IntegerType::get(tensorType.getContext(), 64), memorySpace);
  MemRefLayoutAttrInterface layout = {};
  return MemRefType::get(rankedTensorType.getShape(),
                         rankedTensorType.getElementType(), layout,
                         memorySpaceAttr);
}