//===- BufferizableOpInterface.cpp - Bufferizable Ops ---=----------------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//

#include "mlir/Dialect/Bufferization/IR/BufferizableOpInterface.h"
#include "mlir/Dialect/Bufferization/IR/Bufferization.h"
#include "mlir/Dialect/Func/IR/FuncOps.h"
#include "mlir/Dialect/MemRef/IR/MemRef.h"
#include "mlir/Dialect/Tensor/IR/Tensor.h"
#include "mlir/IR/AsmState.h"
#include "mlir/IR/BlockAndValueMapping.h"
#include "mlir/IR/BuiltinOps.h"
#include "mlir/IR/Operation.h"
#include "mlir/IR/TypeUtilities.h"
#include "mlir/IR/Value.h"
#include "llvm/Support/Debug.h"

//===----------------------------------------------------------------------===//
// BufferizableOpInterface
//===----------------------------------------------------------------------===//

namespace mlir {
namespace bufferization {

#include "mlir/Dialect/Bufferization/IR/BufferizableOpInterface.cpp.inc"

} // namespace bufferization
} // namespace mlir

#define DEBUG_TYPE "bufferizable-op-interface"
#define DBGS() (llvm::dbgs() << '[' << DEBUG_TYPE << "] ")
#define LDBG(X) LLVM_DEBUG(DBGS() << (X))

using namespace mlir;
using namespace bufferization;

/// Attribute name used to mark region arguments that can be bufferized
/// in-place during linalg comprehensive bufferization.
constexpr const ::llvm::StringLiteral
    bufferization::BufferizableOpInterface::kInplaceableAttrName;

/// Create an AllocTensorOp for the given shaped value. Only ranked tensors are
/// supported at the moment. If `copy` is set, the shaped value is copied.
/// Otherwise, a tensor with undefined contents is allocated.
static Value allocateTensorForShapedValue(OpBuilder &b, Location loc,
                                          Value shapedValue, bool escape,
                                          bool copy = true) {
  auto tensorType = shapedValue.getType().dyn_cast<RankedTensorType>();
  assert(tensorType && "only RankedTensorType supported at the moment");
  Value alloc;
  if (!copy) {
    // No copy needed: Just allocate.
    SmallVector<Value> dynamicSizes;
    for (int64_t i = 0; i < tensorType.getRank(); ++i)
      if (tensorType.isDynamicDim(i))
        dynamicSizes.push_back(b.create<tensor::DimOp>(loc, shapedValue, i));
    alloc = b.create<AllocTensorOp>(loc, tensorType, dynamicSizes,
                                    /*copy=*/Value(), escape);
  } else {
    // Allocate and copy.
    alloc = b.create<AllocTensorOp>(loc, tensorType,
                                    /*dynamicSizes=*/ValueRange(), shapedValue,
                                    escape);
  }
  return alloc;
}

LogicalResult BufferizableOpInterface::resolveTensorOpOperandConflicts(
    RewriterBase &rewriter, const AnalysisState &state) {
  OpBuilder::InsertionGuard g(rewriter);
  Operation *op = getOperation();
  SmallVector<OpOperand *> outOfPlaceOpOperands;
  DenseSet<OpOperand *> copiedOpOperands;
  DenseSet<OpOperand *> escapingOpOperandCopies;
  SmallVector<OpResult> outOfPlaceOpResults;
  DenseSet<OpResult> copiedOpResults;
  DenseSet<OpResult> escapingOpResultCopies;

  // Find all out-of-place OpOperands.
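  // For each out-of-place tensor OpOperand, record whether the copy should be
  // made of the OpOperand itself or of its single aliasing OpResult (see the
  // comments below); the actual alloc_tensor copies are inserted afterwards.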
  for (OpOperand &opOperand : op->getOpOperands()) {
    Type operandType = opOperand.get().getType();
    if (!operandType.isa<TensorType>())
      continue;
    if (state.isInPlace(opOperand))
      continue;
    if (operandType.isa<UnrankedTensorType>())
      return op->emitError("copies of unranked tensors are not supported");

    SmallVector<OpResult> aliasingOpResults =
        state.getAliasingOpResult(opOperand);
    // Is the result yielded from a block? Or are deallocations turned off
    // entirely? In either case, mark the allocation as "escaping", so that it
    // will not be deallocated.
    bool escape = !state.getOptions().createDeallocs ||
                  llvm::any_of(aliasingOpResults, [&](Value v) {
                    return state.isTensorYielded(v);
                  });

    if (aliasingOpResults.size() == 1 &&
        !state.bufferizesToMemoryWrite(opOperand) &&
        state.getAliasingOpOperand(aliasingOpResults.front()).size() == 1) {
      // The op itself does not write but may create exactly one alias. Instead
      // of copying the OpOperand, copy the OpResult. The OpResult can
      // sometimes be smaller than the OpOperand (e.g., in the case of an
      // extract_slice, where the result is usually a smaller part of the
      // source).
      outOfPlaceOpResults.push_back(aliasingOpResults.front());
      if (!state.canOmitTensorCopy(opOperand))
        copiedOpResults.insert(aliasingOpResults.front());
      if (escape)
        escapingOpResultCopies.insert(aliasingOpResults.front());
    } else {
      // In all other cases, make a copy of the OpOperand.
      outOfPlaceOpOperands.push_back(&opOperand);
      if (!state.canOmitTensorCopy(opOperand))
        copiedOpOperands.insert(&opOperand);
      if (escape)
        escapingOpOperandCopies.insert(&opOperand);
    }
  }

  // Insert copies of OpOperands.
  rewriter.setInsertionPoint(op);
  for (OpOperand *opOperand : outOfPlaceOpOperands) {
    Value copy = allocateTensorForShapedValue(
        rewriter, op->getLoc(), opOperand->get(),
        escapingOpOperandCopies.contains(opOperand),
        copiedOpOperands.contains(opOperand));
    rewriter.updateRootInPlace(op, [&]() { opOperand->set(copy); });
  }

  // Insert copies of OpResults.
  rewriter.setInsertionPointAfter(op);
  for (OpResult opResult : outOfPlaceOpResults) {
    Value copy =
        allocateTensorForShapedValue(rewriter, op->getLoc(), opResult,
                                     escapingOpResultCopies.contains(opResult),
                                     copiedOpResults.count(opResult));
    SmallVector<OpOperand *> uses = llvm::to_vector(llvm::map_range(
        opResult.getUses(), [](OpOperand &use) { return &use; }));
    for (OpOperand *use : uses) {
      // Do not update the alloc_tensor op that we just created.
      if (use->getOwner() != copy.getDefiningOp())
        rewriter.updateRootInPlace(use->getOwner(), [&]() { use->set(copy); });
    }
  }

  return success();
}

//===----------------------------------------------------------------------===//
// OpFilter
//===----------------------------------------------------------------------===//

bool OpFilter::isOpAllowed(Operation *op) const {
  // All other ops: Allow/disallow according to filter.
  bool isAllowed = !hasAllowRule();
  for (const Entry &entry : entries) {
    bool filterResult = entry.fn(op);
    switch (entry.type) {
    case Entry::ALLOW:
      isAllowed |= filterResult;
      break;
    case Entry::DENY:
      if (filterResult)
        // DENY filter matches. This op is not allowed. (Even if other ALLOW
        // filters may match.)
        return false;
    };
  }
  return isAllowed;
}

//===----------------------------------------------------------------------===//
// BufferizationOptions
//===----------------------------------------------------------------------===//

// Default constructor for BufferizationOptions.
BufferizationOptions::BufferizationOptions() = default;

bool BufferizationOptions::isOpAllowed(Operation *op) const {
  // Special case: If function boundary bufferization is deactivated, do not
  // allow ops that belong to the `func` dialect.
  bool isFuncBoundaryOp = isa_and_nonnull<func::FuncDialect>(op->getDialect());
  if (!bufferizeFunctionBoundaries && isFuncBoundaryOp)
    return false;

  return opFilter.isOpAllowed(op);
}

BufferizableOpInterface
BufferizationOptions::dynCastBufferizableOp(Operation *op) const {
  auto bufferizableOp = dyn_cast<BufferizableOpInterface>(op);
  if (!bufferizableOp)
    return nullptr;
  if (!isOpAllowed(op))
    return nullptr;
  return bufferizableOp;
}

BufferizableOpInterface
BufferizationOptions::dynCastBufferizableOp(Value value) const {
  if (auto bufferizableOp = value.getDefiningOp<BufferizableOpInterface>())
    if (isOpAllowed(bufferizableOp.getOperation()))
      return bufferizableOp;
  return nullptr;
}

void BufferizationOptions::addDialectStateInitializer(
    StringRef name, const DialectStateInitFn &fn) {
  stateInitializers.push_back(
      [=](AnalysisState &state) { state.insertDialectState(name, fn()); });
}

//===----------------------------------------------------------------------===//
// Helper functions for BufferizableOpInterface
//===----------------------------------------------------------------------===//

static void setInsertionPointAfter(OpBuilder &b, Value value) {
  if (auto bbArg = value.dyn_cast<BlockArgument>()) {
    b.setInsertionPointToStart(bbArg.getOwner());
  } else {
    b.setInsertionPointAfter(value.getDefiningOp());
  }
}

/// Determine which OpOperand* will alias with `result` if the op is bufferized
/// in place. Return an empty vector if the op is not bufferizable.
SmallVector<OpOperand *>
AnalysisState::getAliasingOpOperand(OpResult result) const {
  if (Operation *op = result.getDefiningOp())
    if (auto bufferizableOp = getOptions().dynCastBufferizableOp(op))
      return bufferizableOp.getAliasingOpOperand(result, *this);
  return {};
}

/// Determine which OpResult will alias with `opOperand` if the op is
/// bufferized in place. Return an empty vector if the op is not bufferizable.
SmallVector<OpResult>
AnalysisState::getAliasingOpResult(OpOperand &opOperand) const {
  if (auto bufferizableOp =
          getOptions().dynCastBufferizableOp(opOperand.getOwner()))
    return bufferizableOp.getAliasingOpResult(opOperand, *this);
  return {};
}

/// Return true if `opOperand` bufferizes to a memory read. Return `true` if
/// the op is not bufferizable.
bool AnalysisState::bufferizesToMemoryRead(OpOperand &opOperand) const {
  if (auto bufferizableOp =
          getOptions().dynCastBufferizableOp(opOperand.getOwner()))
    return bufferizableOp.bufferizesToMemoryRead(opOperand, *this);

  // Unknown op that returns a tensor. The inplace analysis does not support
  // it. Conservatively return true.
  return true;
}

/// Return true if `opOperand` bufferizes to a memory write. Return
/// `true` if the op is not bufferizable.
bool AnalysisState::bufferizesToMemoryWrite(OpOperand &opOperand) const {
  if (auto bufferizableOp =
          getOptions().dynCastBufferizableOp(opOperand.getOwner()))
    return bufferizableOp.bufferizesToMemoryWrite(opOperand, *this);

  // Unknown op that returns a tensor. The inplace analysis does not support
  // it. Conservatively return true.
  return true;
}

/// Return true if `opOperand` neither reads nor writes but bufferizes to an
/// alias. Return false if the op is not bufferizable.
bool AnalysisState::bufferizesToAliasOnly(OpOperand &opOperand) const {
  if (auto bufferizableOp =
          getOptions().dynCastBufferizableOp(opOperand.getOwner()))
    return bufferizableOp.bufferizesToAliasOnly(opOperand, *this);

  // Unknown op that returns a tensor. The inplace analysis does not support
  // it. Conservatively return false.
  return false;
}

/// Return true if the given value is read by an op that bufferizes to a memory
/// read. Also takes into account ops that create an alias but do not read by
/// themselves (e.g., ExtractSliceOp).
bool AnalysisState::isValueRead(Value value) const {
  assert(value.getType().isa<TensorType>() && "expected TensorType");
  SmallVector<OpOperand *> workingSet;
  for (OpOperand &use : value.getUses())
    workingSet.push_back(&use);

  while (!workingSet.empty()) {
    OpOperand *uMaybeReading = workingSet.pop_back_val();
    // Skip over all ops that neither read nor write (but create an alias).
    if (bufferizesToAliasOnly(*uMaybeReading))
      for (OpResult opResult : getAliasingOpResult(*uMaybeReading))
        for (OpOperand &use : opResult.getUses())
          workingSet.push_back(&use);
    if (bufferizesToMemoryRead(*uMaybeReading))
      return true;
  }

  return false;
}

// Starting from `value`, follow the use-def chain in reverse, always selecting
// the aliasing OpOperands. Find and return Values for which `condition`
// evaluates to true. OpOperands of such matching Values are not traversed any
// further.
llvm::SetVector<Value> AnalysisState::findValueInReverseUseDefChain(
    Value value, llvm::function_ref<bool(Value)> condition) const {
  llvm::SetVector<Value> result, workingSet;
  workingSet.insert(value);

  while (!workingSet.empty()) {
    Value value = workingSet.pop_back_val();
    if (condition(value) || value.isa<BlockArgument>()) {
      result.insert(value);
      continue;
    }

    OpResult opResult = value.cast<OpResult>();
    SmallVector<OpOperand *> opOperands = getAliasingOpOperand(opResult);
    if (opOperands.empty() || !options.isOpAllowed(value.getDefiningOp())) {
      result.insert(value);
      continue;
    }

    for (OpOperand *o : opOperands)
      workingSet.insert(o->get());
  }

  return result;
}

// Find the Values of the last preceding write of a given Value.
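// The traversal stops at (and returns) values that have no defining op, values
// whose defining op is not bufferizable (or not allowed by the options), and
// OpResults that bufferize to a memory write.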
llvm::SetVector<Value>
AnalysisState::findLastPrecedingWrite(Value value) const {
  return findValueInReverseUseDefChain(value, [&](Value value) {
    Operation *op = value.getDefiningOp();
    if (!op)
      return true;
    auto bufferizableOp = options.dynCastBufferizableOp(op);
    if (!bufferizableOp)
      return true;
    return bufferizableOp.isMemoryWrite(value.cast<OpResult>(), *this);
  });
}

AnalysisState::AnalysisState(const BufferizationOptions &options)
    : options(options) {
  for (const BufferizationOptions::AnalysisStateInitFn &fn :
       options.stateInitializers)
    fn(*this);
}

bool AnalysisState::canOmitTensorCopy(OpOperand &opOperand) const {
  // Do not copy if the tensor has undefined contents.
  if (hasUndefinedContents(&opOperand))
    return true;

  // Do not copy if the buffer of the tensor is entirely overwritten (with
  // values that do not depend on the old tensor).
  if (bufferizesToMemoryWrite(opOperand) && !bufferizesToMemoryRead(opOperand))
    return true;

  // Do not copy if the tensor is never read.
  SmallVector<OpResult> aliasingOpResults = getAliasingOpResult(opOperand);
  if (!bufferizesToMemoryRead(opOperand) &&
      llvm::none_of(aliasingOpResults,
                    [&](OpResult opResult) { return isValueRead(opResult); }))
    return true;

  // Default: Cannot omit the copy.
  return false;
}

bool AnalysisState::isInPlace(OpOperand &opOperand) const {
  // In the absence of analysis information, OpOperands that bufferize to a
  // memory write are out-of-place, i.e., an alloc and copy is inserted.
  return !bufferizesToMemoryWrite(opOperand);
}

bool AnalysisState::areEquivalentBufferizedValues(Value v1, Value v2) const {
  // In the absence of analysis information, we do not know if the values are
  // equivalent. The conservative answer is "false".
  return false;
}

bool AnalysisState::areAliasingBufferizedValues(Value v1, Value v2) const {
  // In the absence of analysis information, we do not know if the values may
  // be aliasing. The conservative answer is "true".
  return true;
}

bool AnalysisState::hasUndefinedContents(OpOperand *opOperand) const {
  // In the absence of analysis information, the conservative answer is
  // "false".
  return false;
}

bool AnalysisState::isTensorYielded(Value tensor) const {
  // In the absence of analysis information, the conservative answer is "true".
  if (!tensor.getDefiningOp<AllocTensorOp>())
    return true;

  // For AllocTensorOp results, we can do better: They do not alias with any
  // preceding value, so we can follow SSA use-def chains and do a simple
  // analysis.
  SmallVector<OpOperand *> worklist;
  for (OpOperand &use : tensor.getUses())
    worklist.push_back(&use);

  while (!worklist.empty()) {
    OpOperand *operand = worklist.pop_back_val();
    Operation *op = operand->getOwner();

    // If the op is not bufferizable, we can safely assume that the value is
    // not yielded. (When bufferizing that op, it must handle such cases.)
    if (!options.dynCastBufferizableOp(op))
      continue;

    // We cannot analyze through ToMemrefOps, so we have to conservatively
    // assume that the value is yielded.
    if (isa<ToMemrefOp>(op))
      return true;

    // Check if the op is returning/yielding.
    if (isRegionReturnLike(op))
      return true;

    // Add all aliasing OpResults to the worklist.
    // Note: In the absence of detailed analysis information (e.g., there may
    // be no function call analysis information), this `getAliasingOpResult`
    // is conservative and may report additional OpResults as potentially
    // aliasing.
    for (OpResult opResult : getAliasingOpResult(*operand))
      for (OpOperand &use : opResult.getUses())
        worklist.push_back(&use);
  }

  // No ReturnLike op found: The value is not yielded.
  return false;
}

// bufferization.to_memref is not allowed to change the rank.
static void ensureToMemrefOpIsValid(Value tensor, Type memrefType) {
#ifndef NDEBUG
  auto rankedTensorType = tensor.getType().dyn_cast<RankedTensorType>();
  assert((!rankedTensorType || memrefType.cast<MemRefType>().getRank() ==
                                   rankedTensorType.getRank()) &&
         "to_memref would be invalid: mismatching ranks");
#endif
}

Value mlir::bufferization::lookupBuffer(RewriterBase &rewriter, Value tensor,
                                        const BufferizationOptions &options) {
  auto tensorType = tensor.getType().dyn_cast<TensorType>();
  assert(tensorType && "unexpected non-tensor type");

  // Replace "%t = to_tensor %m" with %m.
  if (auto toTensorOp = tensor.getDefiningOp<bufferization::ToTensorOp>())
    return toTensorOp.memref();

  // Insert to_memref op.
  OpBuilder::InsertionGuard g(rewriter);
  setInsertionPointAfter(rewriter, tensor);
  Type memrefType = getMemRefType(tensorType, options);
  ensureToMemrefOpIsValid(tensor, memrefType);
  return rewriter.create<bufferization::ToMemrefOp>(tensor.getLoc(),
                                                    memrefType, tensor);
}

/// Return the buffer (memref) for a given OpOperand (tensor). Allocate
/// a new buffer and copy over data from the existing buffer if out-of-place
/// bufferization was decided.
FailureOr<Value>
BufferizationState::getBuffer(RewriterBase &rewriter, OpOperand &opOperand,
                              Optional<ForceInPlacability> overrideInPlace,
                              Optional<Operation *> customCopyInsertionPoint) {
  const BufferizationOptions &options = analysisState.getOptions();
  OpBuilder::InsertionGuard guard(rewriter);
  Operation *op = opOperand.getOwner();
  Location loc = op->getLoc();
  SmallVector<OpResult> aliasingOpResults =
      analysisState.getAliasingOpResult(opOperand);
  Value operand = opOperand.get();
  Value operandBuffer = lookupBuffer(rewriter, operand, options);

  // Can `operandBuffer` be used directly or do we need a copy?
  bool inplace =
      overrideInPlace != FORCE_OUT_OF_PLACE &&
      (overrideInPlace == FORCE_INPLACE || analysisState.isInPlace(opOperand));
  if (inplace)
    return operandBuffer;

  // Bufferizing out-of-place: Allocate a new buffer.
  // Move insertion point right after `operandBuffer`. That is where the
  // allocation should be inserted (in the absence of allocation hoisting).
  setInsertionPointAfter(rewriter, operandBuffer);
  // Allocate the result buffer. The buffer should be deallocated if the tensor
  // is not yielded and deallocs are enabled in general.
  bool dealloc = llvm::none_of(aliasingOpResults, [&](Value v) {
    return getAnalysisState().isTensorYielded(v);
  });
  FailureOr<Value> resultBuffer = createAlloc(
      rewriter, loc, operandBuffer, dealloc && getOptions().createDeallocs);
  if (failed(resultBuffer))
    return failure();
  // Do not copy the buffer if its contents are undefined.
  if (analysisState.hasUndefinedContents(&opOperand))
    return resultBuffer;
  // Do not copy if the copied data is never read.
  if (!aliasingOpResults.empty() &&
      !analysisState.bufferizesToMemoryRead(opOperand) &&
      llvm::none_of(aliasingOpResults, [&](OpResult opResult) {
        return analysisState.isValueRead(opResult);
      }))
    return resultBuffer;
  // Do not copy if this op does not read the data, but writes it.
  if (analysisState.bufferizesToMemoryWrite(opOperand) &&
      !analysisState.bufferizesToMemoryRead(opOperand))
    return resultBuffer;

  if (customCopyInsertionPoint) {
    rewriter.setInsertionPoint(*customCopyInsertionPoint);
  } else {
    // The copy happens right before the op that is bufferized.
    rewriter.setInsertionPoint(op);
  }
  if (failed(options.createMemCpy(rewriter, loc, operandBuffer, *resultBuffer)))
    return failure();

  return resultBuffer;
}

/// Return the buffer type for a given Value (tensor) after bufferization.
BaseMemRefType BufferizationState::getBufferType(Value value) const {
  auto tensorType = value.getType().dyn_cast<TensorType>();
  assert(tensorType && "unexpected non-tensor type");

  if (auto toTensorOp = value.getDefiningOp<bufferization::ToTensorOp>())
    return toTensorOp.memref().getType().cast<BaseMemRefType>();

  return getMemRefType(tensorType, getOptions());
}

void bufferization::replaceOpWithBufferizedValues(RewriterBase &rewriter,
                                                  Operation *op,
                                                  ValueRange values) {
  assert(values.size() == op->getNumResults() &&
         "expected one value per OpResult");
  OpBuilder::InsertionGuard g(rewriter);

  // Replace all OpResults with the given values.
  SmallVector<Value> replacements;
  for (OpResult opResult : op->getOpResults()) {
    Value replacement = values[opResult.getResultNumber()];
    if (opResult.getType().isa<TensorType>()) {
      // The OpResult is a tensor. Such values are replaced with memrefs during
      // bufferization.
      assert((replacement.getType().isa<MemRefType>() ||
              replacement.getType().isa<UnrankedMemRefType>()) &&
             "tensor op result should be replaced with a memref value");
      // The existing uses of the OpResult still expect a tensor. Insert a
      // ToTensorOp. Throughout bufferization, this ToTensorOp will gradually
      // lose all of its users and eventually DCE away.
      rewriter.setInsertionPointAfter(op);
      replacement = rewriter.create<bufferization::ToTensorOp>(
          replacement.getLoc(), replacement);
    }
    replacements.push_back(replacement);
  }

  rewriter.replaceOp(op, replacements);
}

//===----------------------------------------------------------------------===//
// Bufferization-specific scoped alloc/dealloc insertion support.
//===----------------------------------------------------------------------===//

/// Create a memref allocation with the given type and dynamic extents.
FailureOr<Value> BufferizationOptions::createAlloc(OpBuilder &b, Location loc,
                                                   MemRefType type,
                                                   ValueRange dynShape) const {
  if (allocationFn)
    return (*allocationFn)(b, loc, type, dynShape, bufferAlignment);

  // Default buffer allocation via AllocOp.
  Value allocated = b.create<memref::AllocOp>(
      loc, type, dynShape, b.getI64IntegerAttr(bufferAlignment));
  return allocated;
}

/// Creates a memref deallocation. The given memref buffer must have been
/// allocated using `createAlloc`.
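/// If a custom `deallocationFn` is set in the options, it is used; otherwise a
/// memref.dealloc op is created.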
LogicalResult BufferizationOptions::createDealloc(OpBuilder &b, Location loc,
                                                  Value allocatedBuffer) const {
  if (deallocationFn)
    return (*deallocationFn)(b, loc, allocatedBuffer);

  // Default buffer deallocation via DeallocOp.
  b.create<memref::DeallocOp>(loc, allocatedBuffer);
  return success();
}

static MemRefType getContiguousMemRefType(ShapedType shapedType,
                                          Attribute memorySpace = {}) {
  MemRefLayoutAttrInterface layout = {};
  return MemRefType::get(shapedType.getShape(), shapedType.getElementType(),
                         layout, memorySpace);
}

/// Compute the type of the `memref` to use for allocating the buffer for
/// `shapedValue`. Also returns, by reference in `dynShape`, the values for the
/// dynamic dimensions in the returned `memref` type.
static MemRefType getAllocationTypeAndShape(OpBuilder &b, Location loc,
                                            Value shapedValue,
                                            SmallVectorImpl<Value> &dynShape) {
  MemRefType allocMemRefType =
      getContiguousMemRefType(shapedValue.getType().cast<ShapedType>());

  // Compute the dynamic part of the shape.
  bool reifiedShapes = false;
  if (auto rankedOp = dyn_cast_or_null<ReifyRankedShapedTypeOpInterface>(
          shapedValue.getDefiningOp())) {
    ReifiedRankedShapedTypeDims resultDims;
    if (succeeded(rankedOp.reifyResultShapes(b, resultDims))) {
      reifiedShapes = true;
      OpResult resultValue = shapedValue.dyn_cast<OpResult>();
      auto &shape = resultDims[resultValue.getResultNumber()];
      for (const auto &dim : enumerate(allocMemRefType.getShape()))
        if (ShapedType::isDynamic(dim.value()))
          dynShape.push_back(shape[dim.index()]);
    }
  }

  if (!reifiedShapes) {
    for (const auto &dim : enumerate(allocMemRefType.getShape()))
      if (ShapedType::isDynamic(dim.value())) {
        assert((shapedValue.getType().isa<UnrankedMemRefType>() ||
                shapedValue.getType().isa<MemRefType>()) &&
               "expected MemRef type");
        dynShape.push_back(
            b.create<memref::DimOp>(loc, shapedValue, dim.index()));
      }
  }

  return allocMemRefType;
}

/// Create an allocation after `shapedValue.getDefiningOp` (or at the top of
/// the block in case of a bbArg).
FailureOr<Value> BufferizationState::createAlloc(OpBuilder &b, Location loc,
                                                 Value shapedValue,
                                                 Optional<bool> dealloc) {
  // Take a guard before anything else.
  OpBuilder::InsertionGuard g(b);

  // Compute allocation memref type.
  assert(shapedValue.getType().isa<ShapedType>());
  SmallVector<Value> dynShape;
  MemRefType allocMemRefType =
      getAllocationTypeAndShape(b, loc, shapedValue, dynShape);

  // Create the buffer allocation.
  FailureOr<Value> buffer =
      getOptions().createAlloc(b, loc, allocMemRefType, dynShape);
  if (failed(buffer))
    return failure();

  // Should the buffer be deallocated again or should we let it leak?
  if (dealloc) {
    if (!dealloc.getValue())
      return *buffer;
  } else {
    assert(shapedValue.getType().isa<TensorType>() &&
           "must specify `dealloc` if non-tensor value is passed");
    // The buffer should not be deallocated if deallocs are generally
    // deactivated or if the tensor is yielded from a block.
    if (!getOptions().createDeallocs ||
        getAnalysisState().isTensorYielded(shapedValue))
      return *buffer;
  }

  // Create buffer deallocation.
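  // The dealloc is placed right before the terminator of the insertion block,
  // so the buffer stays live for the remainder of the block.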
  b.setInsertionPoint(b.getInsertionBlock()->getTerminator());
  if (failed(getOptions().createDealloc(b, loc, *buffer)))
    return failure();

  return *buffer;
}

/// Create a memory copy between two memref buffers.
LogicalResult BufferizationOptions::createMemCpy(OpBuilder &b, Location loc,
                                                 Value from, Value to) const {
  if (memCpyFn)
    return (*memCpyFn)(b, loc, from, to);

  b.create<memref::CopyOp>(loc, from, to);
  return success();
}

//===----------------------------------------------------------------------===//
// Bufferization-specific BlockAndValueMapping support with debugging.
//===----------------------------------------------------------------------===//

bool bufferization::isFunctionArgument(Value value) {
  auto bbArg = value.dyn_cast<BlockArgument>();
  if (!bbArg)
    return false;
  return isa<func::FuncOp>(bbArg.getOwner()->getParentOp());
}

BaseMemRefType bufferization::getMemRefType(TensorType tensorType,
                                            const BufferizationOptions &options,
                                            MemRefLayoutAttrInterface layout,
                                            Attribute memorySpace) {
  // Case 1: Unranked memref type.
  if (auto unrankedTensorType = tensorType.dyn_cast<UnrankedTensorType>()) {
    assert(!layout && "UnrankedTensorType cannot have a layout map");
    return UnrankedMemRefType::get(unrankedTensorType.getElementType(),
                                   memorySpace);
  }

  // Case 2: Ranked memref type with specified layout.
  auto rankedTensorType = tensorType.cast<RankedTensorType>();
  if (layout) {
    return MemRefType::get(rankedTensorType.getShape(),
                           rankedTensorType.getElementType(), layout,
                           memorySpace);
  }

  // Case 3: Configured with "fully dynamic layout maps".
  if (options.unknownTypeConversion ==
      BufferizationOptions::LayoutMapOption::FullyDynamicLayoutMap)
    return getMemRefTypeWithFullyDynamicLayout(tensorType, memorySpace);

  // Case 4: Configured with "static identity layout maps".
  if (options.unknownTypeConversion ==
      BufferizationOptions::LayoutMapOption::IdentityLayoutMap)
    return getMemRefTypeWithStaticIdentityLayout(tensorType, memorySpace);

  llvm_unreachable("InferLayoutMap is an invalid option");
}

BaseMemRefType
bufferization::getMemRefTypeWithFullyDynamicLayout(TensorType tensorType,
                                                   Attribute memorySpace) {
  // Case 1: Unranked memref type.
  if (auto unrankedTensorType = tensorType.dyn_cast<UnrankedTensorType>()) {
    return UnrankedMemRefType::get(unrankedTensorType.getElementType(),
                                   memorySpace);
  }

  // Case 2: Ranked memref type.
  auto rankedTensorType = tensorType.cast<RankedTensorType>();
  int64_t dynamicOffset = ShapedType::kDynamicStrideOrOffset;
  SmallVector<int64_t> dynamicStrides(rankedTensorType.getRank(),
                                      ShapedType::kDynamicStrideOrOffset);
  AffineMap stridedLayout = makeStridedLinearLayoutMap(
      dynamicStrides, dynamicOffset, rankedTensorType.getContext());
  return MemRefType::get(rankedTensorType.getShape(),
                         rankedTensorType.getElementType(), stridedLayout,
                         memorySpace);
}

/// Return a MemRef type with a static identity layout (i.e., no layout map).
/// If the given tensor type is unranked, return an unranked MemRef type.
BaseMemRefType
bufferization::getMemRefTypeWithStaticIdentityLayout(TensorType tensorType,
                                                     Attribute memorySpace) {
  // Case 1: Unranked memref type.
  if (auto unrankedTensorType = tensorType.dyn_cast<UnrankedTensorType>()) {
    return UnrankedMemRefType::get(unrankedTensorType.getElementType(),
                                   memorySpace);
  }

  // Case 2: Ranked memref type.
  auto rankedTensorType = tensorType.cast<RankedTensorType>();
  MemRefLayoutAttrInterface layout = {};
  return MemRefType::get(rankedTensorType.getShape(),
                         rankedTensorType.getElementType(), layout,
                         memorySpace);
}
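
// Illustrative note (a sketch, not exhaustive): for a ranked tensor such as
// tensor<4x?xf32>, `getMemRefTypeWithStaticIdentityLayout` produces
// memref<4x?xf32>, whereas `getMemRefTypeWithFullyDynamicLayout` produces a
// memref<4x?xf32, ...> whose layout map has fully dynamic strides and a
// dynamic offset.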