//===- BufferizableOpInterface.cpp - Bufferizable Ops --------------------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//

#include "mlir/Dialect/Bufferization/IR/BufferizableOpInterface.h"
#include "mlir/Dialect/Bufferization/IR/Bufferization.h"
#include "mlir/Dialect/Func/IR/FuncOps.h"
#include "mlir/Dialect/MemRef/IR/MemRef.h"
#include "mlir/Dialect/Tensor/IR/Tensor.h"
#include "mlir/IR/AsmState.h"
#include "mlir/IR/BlockAndValueMapping.h"
#include "mlir/IR/BuiltinOps.h"
#include "mlir/IR/Operation.h"
#include "mlir/IR/TypeUtilities.h"
#include "mlir/IR/Value.h"
#include "llvm/Support/Debug.h"

//===----------------------------------------------------------------------===//
// BufferizableOpInterface
//===----------------------------------------------------------------------===//

namespace mlir {
namespace bufferization {

#include "mlir/Dialect/Bufferization/IR/BufferizableOpInterface.cpp.inc"

} // namespace bufferization
} // namespace mlir

#define DEBUG_TYPE "bufferizable-op-interface"
#define DBGS() (llvm::dbgs() << '[' << DEBUG_TYPE << "] ")
#define LDBG(X) LLVM_DEBUG(DBGS() << (X))

using namespace mlir;
using namespace bufferization;

/// Attribute name used to mark region arguments that can be bufferized
/// in-place during linalg comprehensive bufferization.
constexpr const ::llvm::StringLiteral
    bufferization::BufferizableOpInterface::kInplaceableAttrName;

/// Create an AllocTensorOp for the given shaped value. Only ranked tensors are
/// supported at the moment. If `copy` is set, the shaped value is copied.
/// Otherwise, a tensor with undefined contents is allocated.
static Value allocateTensorForShapedValue(OpBuilder &b, Location loc,
                                          Value shapedValue, bool escape,
                                          bool copy = true) {
  auto tensorType = shapedValue.getType().dyn_cast<RankedTensorType>();
  assert(tensorType && "only RankedTensorType supported at the moment");
  Value alloc;
  if (!copy) {
    // No copy needed: Just allocate.
    SmallVector<Value> dynamicSizes;
    for (int64_t i = 0; i < tensorType.getRank(); ++i)
      if (tensorType.isDynamicDim(i))
        dynamicSizes.push_back(b.create<tensor::DimOp>(loc, shapedValue, i));
    alloc = b.create<AllocTensorOp>(loc, tensorType, dynamicSizes,
                                    /*copy=*/Value(), escape);
  } else {
    // Allocate and copy.
    alloc = b.create<AllocTensorOp>(loc, tensorType,
                                    /*dynamicSizes=*/ValueRange(), shapedValue,
                                    escape);
  }
  return alloc;
}

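/// Illustrative example only (hypothetical IR, not verbatim rewriter output):
/// if the analysis decided that %t in
///   %r = tensor.insert_slice %s into %t[...]
/// must be bufferized out-of-place, the rewrite below produces
///   %copy = bufferization.alloc_tensor() copy(%t)
///   %r = tensor.insert_slice %s into %copy[...]
/// so that the original value of %t is never overwritten.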
LogicalResult BufferizableOpInterface::resolveTensorOpOperandConflicts(
    RewriterBase &rewriter, const AnalysisState &state) {
  OpBuilder::InsertionGuard g(rewriter);
  Operation *op = getOperation();
  SmallVector<OpOperand *> outOfPlaceOpOperands;
  DenseSet<OpOperand *> copiedOpOperands;
  SmallVector<OpResult> outOfPlaceOpResults;
  DenseSet<OpResult> copiedOpResults;

  // Find all out-of-place OpOperands.
  for (OpOperand &opOperand : op->getOpOperands()) {
    Type operandType = opOperand.get().getType();
    if (!operandType.isa<TensorType>())
      continue;
    if (state.isInPlace(opOperand))
      continue;
    if (operandType.isa<UnrankedTensorType>())
      return op->emitError("copies of unranked tensors are not supported");

    SmallVector<OpResult> aliasingOpResults =
        state.getAliasingOpResult(opOperand);
    if (aliasingOpResults.size() == 1 &&
        !state.bufferizesToMemoryWrite(opOperand) &&
        state.getAliasingOpOperand(aliasingOpResults.front()).size() == 1) {
      // The op itself does not write but may create exactly one alias. Instead
      // of copying the OpOperand, copy the OpResult. The OpResult can
      // sometimes be smaller than the OpOperand (e.g., in the case of an
      // extract_slice, where the result is usually a smaller part of the
      // source).
      outOfPlaceOpResults.push_back(aliasingOpResults.front());
      if (!state.canOmitTensorCopy(opOperand))
        copiedOpResults.insert(aliasingOpResults.front());
    } else {
      // In all other cases, make a copy of the OpOperand.
      outOfPlaceOpOperands.push_back(&opOperand);
      if (!state.canOmitTensorCopy(opOperand))
        copiedOpOperands.insert(&opOperand);
    }
  }

  // Insert copies of OpOperands.
  rewriter.setInsertionPoint(op);
  for (OpOperand *opOperand : outOfPlaceOpOperands) {
    SmallVector<OpResult> aliasingOpResults =
        state.getAliasingOpResult(*opOperand);
    bool escape = llvm::any_of(
        aliasingOpResults, [&](Value v) { return state.isTensorYielded(v); });
    Value copy = allocateTensorForShapedValue(
        rewriter, op->getLoc(), opOperand->get(), escape,
        copiedOpOperands.contains(opOperand));
    rewriter.updateRootInPlace(op, [&]() { opOperand->set(copy); });
  }

  // Insert copies of OpResults.
  rewriter.setInsertionPointAfter(op);
  for (OpResult opResult : outOfPlaceOpResults) {
    bool escape = state.isTensorYielded(opResult);
    Value copy =
        allocateTensorForShapedValue(rewriter, op->getLoc(), opResult, escape,
                                     copiedOpResults.count(opResult));
    SmallVector<OpOperand *> uses = llvm::to_vector(llvm::map_range(
        opResult.getUses(), [](OpOperand &use) { return &use; }));
    for (OpOperand *use : uses) {
      // Do not update the alloc_tensor op that we just created.
      if (use->getOwner() != copy.getDefiningOp())
        rewriter.updateRootInPlace(use->getOwner(), [&]() { use->set(copy); });
    }
  }

  return success();
}

//===----------------------------------------------------------------------===//
// OpFilter
//===----------------------------------------------------------------------===//

bool OpFilter::isOpAllowed(Operation *op) const {
  // Allow/disallow the op according to the filter rules.
  bool isAllowed = !hasAllowRule();
  for (const Entry &entry : entries) {
    bool filterResult = entry.fn(op);
    switch (entry.type) {
    case Entry::ALLOW:
      isAllowed |= filterResult;
      break;
    case Entry::DENY:
      if (filterResult)
        // DENY filter matches. This op is not allowed. (Even if other ALLOW
        // filters may match.)
        return false;
    }
  }
  return isAllowed;
}

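// Illustrative example only (assumes the allow/deny helpers declared in
// BufferizableOpInterface.h): a filter configured as
//   filter.allowDialect<tensor::TensorDialect>();
//   filter.denyOperation<tensor::ExtractOp>();
// allows tensor.insert (an ALLOW rule matches), rejects tensor.extract (a
// matching DENY rule always wins) and rejects ops from other dialects (once
// any ALLOW rule exists, unmatched ops are disallowed).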
//===----------------------------------------------------------------------===//
// BufferizationOptions
//===----------------------------------------------------------------------===//

// Default constructor for BufferizationOptions.
BufferizationOptions::BufferizationOptions() = default;

bool BufferizationOptions::isOpAllowed(Operation *op) const {
  // Special case: If function boundary bufferization is deactivated, do not
  // allow ops that belong to the `func` dialect.
  bool isFuncBoundaryOp = isa_and_nonnull<func::FuncDialect>(op->getDialect());
  if (!bufferizeFunctionBoundaries && isFuncBoundaryOp)
    return false;

  return opFilter.isOpAllowed(op);
}

BufferizableOpInterface
BufferizationOptions::dynCastBufferizableOp(Operation *op) const {
  auto bufferizableOp = dyn_cast<BufferizableOpInterface>(op);
  if (!bufferizableOp)
    return nullptr;
  if (!isOpAllowed(op))
    return nullptr;
  return bufferizableOp;
}

BufferizableOpInterface
BufferizationOptions::dynCastBufferizableOp(Value value) const {
  if (auto bufferizableOp = value.getDefiningOp<BufferizableOpInterface>())
    if (isOpAllowed(bufferizableOp.getOperation()))
      return bufferizableOp;
  return nullptr;
}

void BufferizationOptions::addDialectStateInitializer(
    StringRef name, const DialectStateInitFn &fn) {
  stateInitializers.push_back(
      [=](AnalysisState &state) { state.insertDialectState(name, fn()); });
}

//===----------------------------------------------------------------------===//
// Helper functions for BufferizableOpInterface
//===----------------------------------------------------------------------===//

static void setInsertionPointAfter(OpBuilder &b, Value value) {
  if (auto bbArg = value.dyn_cast<BlockArgument>()) {
    b.setInsertionPointToStart(bbArg.getOwner());
  } else {
    b.setInsertionPointAfter(value.getDefiningOp());
  }
}

/// Determine which OpOperands will alias with `result` if the op is bufferized
/// in place. Return an empty vector if the op is not bufferizable.
SmallVector<OpOperand *>
AnalysisState::getAliasingOpOperand(OpResult result) const {
  if (Operation *op = result.getDefiningOp())
    if (auto bufferizableOp = getOptions().dynCastBufferizableOp(op))
      return bufferizableOp.getAliasingOpOperand(result, *this);
  return {};
}

/// Determine which OpResults will alias with `opOperand` if the op is
/// bufferized in place. Return an empty vector if the op is not bufferizable.
SmallVector<OpResult>
AnalysisState::getAliasingOpResult(OpOperand &opOperand) const {
  if (auto bufferizableOp =
          getOptions().dynCastBufferizableOp(opOperand.getOwner()))
    return bufferizableOp.getAliasingOpResult(opOperand, *this);
  return {};
}

/// Return true if `opOperand` bufferizes to a memory read. Return `true` if
/// the op is not bufferizable.
bool AnalysisState::bufferizesToMemoryRead(OpOperand &opOperand) const {
  if (auto bufferizableOp =
          getOptions().dynCastBufferizableOp(opOperand.getOwner()))
    return bufferizableOp.bufferizesToMemoryRead(opOperand, *this);

  // Unknown op that returns a tensor. The inplace analysis does not support
  // it. Conservatively return true.
  return true;
}

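// For example (illustrative): the source operand of a tensor.extract
// bufferizes to a memory read but not to a memory write, whereas the
// destination operand of a tensor.insert_slice bufferizes to a memory write.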
/// Return true if `opOperand` bufferizes to a memory write. Return
/// `true` if the op is not bufferizable.
bool AnalysisState::bufferizesToMemoryWrite(OpOperand &opOperand) const {
  if (auto bufferizableOp =
          getOptions().dynCastBufferizableOp(opOperand.getOwner()))
    return bufferizableOp.bufferizesToMemoryWrite(opOperand, *this);

  // Unknown op that returns a tensor. The inplace analysis does not support
  // it. Conservatively return true.
  return true;
}

/// Return true if `opOperand` neither reads nor writes but bufferizes to an
/// alias. Return false if the op is not bufferizable.
bool AnalysisState::bufferizesToAliasOnly(OpOperand &opOperand) const {
  if (auto bufferizableOp =
          getOptions().dynCastBufferizableOp(opOperand.getOwner()))
    return bufferizableOp.bufferizesToAliasOnly(opOperand, *this);

  // Unknown op that returns a tensor. The inplace analysis does not support
  // it. Conservatively return false.
  return false;
}

/// Return true if the given value is read by an op that bufferizes to a memory
/// read. Also takes into account ops that create an alias but do not read by
/// themselves (e.g., ExtractSliceOp).
bool AnalysisState::isValueRead(Value value) const {
  assert(value.getType().isa<TensorType>() && "expected TensorType");
  SmallVector<OpOperand *> workingSet;
  for (OpOperand &use : value.getUses())
    workingSet.push_back(&use);

  while (!workingSet.empty()) {
    OpOperand *uMaybeReading = workingSet.pop_back_val();
    // Skip over all ops that neither read nor write (but create an alias).
    if (bufferizesToAliasOnly(*uMaybeReading))
      for (OpResult opResult : getAliasingOpResult(*uMaybeReading))
        for (OpOperand &use : opResult.getUses())
          workingSet.push_back(&use);
    if (bufferizesToMemoryRead(*uMaybeReading))
      return true;
  }

  return false;
}

// Starting from `value`, follow the use-def chain in reverse, always selecting
// the aliasing OpOperands. Find and return Values for which `condition`
// evaluates to true. OpOperands of such matching Values are not traversed any
// further.
llvm::SetVector<Value> AnalysisState::findValueInReverseUseDefChain(
    Value value, llvm::function_ref<bool(Value)> condition) const {
  llvm::SetVector<Value> result, workingSet;
  workingSet.insert(value);

  while (!workingSet.empty()) {
    Value value = workingSet.pop_back_val();
    if (condition(value) || value.isa<BlockArgument>()) {
      result.insert(value);
      continue;
    }

    OpResult opResult = value.cast<OpResult>();
    SmallVector<OpOperand *> opOperands = getAliasingOpOperand(opResult);
    if (opOperands.empty() || !options.isOpAllowed(value.getDefiningOp())) {
      result.insert(value);
      continue;
    }

    for (OpOperand *o : opOperands)
      workingSet.insert(o->get());
  }

  return result;
}

// Find the Values of the last preceding write of a given Value.
llvm::SetVector<Value>
AnalysisState::findLastPrecedingWrite(Value value) const {
  return findValueInReverseUseDefChain(value, [&](Value value) {
    Operation *op = value.getDefiningOp();
    if (!op)
      return true;
    auto bufferizableOp = options.dynCastBufferizableOp(op);
    if (!bufferizableOp)
      return true;
    return bufferizableOp.isMemoryWrite(value.cast<OpResult>(), *this);
  });
}

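// Illustrative example (hypothetical IR) for the reverse use-def chain walk
// above: given
//   %0 = linalg.fill ins(%c : f32) outs(%t : tensor<?xf32>) -> tensor<?xf32>
//   %1 = tensor.extract_slice %0[0] [4] [1] : tensor<?xf32> to tensor<4xf32>
// findLastPrecedingWrite(%1) skips the alias-only extract_slice and returns
// {%0}, because the fill result bufferizes to a memory write.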
AnalysisState::AnalysisState(const BufferizationOptions &options)
    : options(options) {
  for (const BufferizationOptions::AnalysisStateInitFn &fn :
       options.stateInitializers)
    fn(*this);
}

bool AnalysisState::canOmitTensorCopy(OpOperand &opOperand) const {
  // Do not copy if the tensor has undefined contents.
  if (hasUndefinedContents(&opOperand))
    return true;

  // Do not copy if the buffer of the tensor is entirely overwritten (with
  // values that do not depend on the old tensor).
  if (bufferizesToMemoryWrite(opOperand) && !bufferizesToMemoryRead(opOperand))
    return true;

  // Do not copy if the tensor is never read.
  SmallVector<OpResult> aliasingOpResults = getAliasingOpResult(opOperand);
  if (!bufferizesToMemoryRead(opOperand) &&
      llvm::none_of(aliasingOpResults,
                    [&](OpResult opResult) { return isValueRead(opResult); }))
    return true;

  // Default: Cannot omit the copy.
  return false;
}

// bufferization.to_memref is not allowed to change the rank.
static void ensureToMemrefOpIsValid(Value tensor, Type memrefType) {
#ifndef NDEBUG
  auto rankedTensorType = tensor.getType().dyn_cast<RankedTensorType>();
  assert((!rankedTensorType || memrefType.cast<MemRefType>().getRank() ==
                                   rankedTensorType.getRank()) &&
         "to_memref would be invalid: mismatching ranks");
#endif
}

Value mlir::bufferization::lookupBuffer(RewriterBase &rewriter, Value tensor,
                                        const BufferizationOptions &options) {
  auto tensorType = tensor.getType().dyn_cast<TensorType>();
  assert(tensorType && "unexpected non-tensor type");

  // Replace "%t = to_tensor %m" with %m.
  if (auto toTensorOp = tensor.getDefiningOp<bufferization::ToTensorOp>())
    return toTensorOp.memref();

  // Insert to_memref op.
  OpBuilder::InsertionGuard g(rewriter);
  setInsertionPointAfter(rewriter, tensor);
  Type memrefType = getMemRefType(tensorType, options);
  ensureToMemrefOpIsValid(tensor, memrefType);
  return rewriter.create<bufferization::ToMemrefOp>(tensor.getLoc(),
                                                    memrefType, tensor);
}

/// Return the buffer (memref) for a given OpOperand (tensor). Allocate
/// a new buffer and copy over data from the existing buffer if out-of-place
/// bufferization was decided.
FailureOr<Value>
BufferizationState::getBuffer(RewriterBase &rewriter, OpOperand &opOperand,
                              Optional<ForceInPlacability> overrideInPlace,
                              Optional<Operation *> customCopyInsertionPoint) {
  const BufferizationOptions &options = analysisState.getOptions();
  OpBuilder::InsertionGuard guard(rewriter);
  Operation *op = opOperand.getOwner();
  Location loc = op->getLoc();
  SmallVector<OpResult> aliasingOpResults =
      analysisState.getAliasingOpResult(opOperand);
  Value operand = opOperand.get();
  Value operandBuffer = lookupBuffer(rewriter, operand, options);

  // Can `operandBuffer` be used directly or do we need a copy?
  bool inplace =
      overrideInPlace != FORCE_OUT_OF_PLACE &&
      (overrideInPlace == FORCE_INPLACE || analysisState.isInPlace(opOperand));
  if (inplace)
    return operandBuffer;

  // Bufferizing out-of-place: Allocate a new buffer.
  // Move insertion point right after `operandBuffer`. That is where the
  // allocation should be inserted (in the absence of allocation hoisting).
  setInsertionPointAfter(rewriter, operandBuffer);
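  // Sketch of the out-of-place path (hypothetical IR, not verbatim output):
  //   %m = bufferization.to_memref %t : memref<?xf32>
  //   %alloc = memref.alloc(%d) : memref<?xf32>
  //   memref.copy %m, %alloc   // emitted below, only if the data is read
  // The op being bufferized then uses %alloc instead of %m.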
  // Allocate the result buffer. The buffer should be deallocated if the tensor
  // is not yielded and deallocs are enabled in general.
  bool dealloc = llvm::none_of(aliasingOpResults, [&](Value v) {
    return getAnalysisState().isTensorYielded(v);
  });
  FailureOr<Value> resultBuffer = createAlloc(
      rewriter, loc, operandBuffer, dealloc && getOptions().createDeallocs);
  if (failed(resultBuffer))
    return failure();
  // Do not copy the buffer if its contents are undefined.
  if (analysisState.hasUndefinedContents(&opOperand))
    return resultBuffer;
  // Do not copy if the copied data is never read.
  if (!aliasingOpResults.empty() &&
      !analysisState.bufferizesToMemoryRead(opOperand) &&
      llvm::none_of(aliasingOpResults, [&](OpResult opResult) {
        return analysisState.isValueRead(opResult);
      }))
    return resultBuffer;
  // Do not copy if this op does not read the data, but writes it.
  if (analysisState.bufferizesToMemoryWrite(opOperand) &&
      !analysisState.bufferizesToMemoryRead(opOperand))
    return resultBuffer;

  if (customCopyInsertionPoint) {
    rewriter.setInsertionPoint(*customCopyInsertionPoint);
  } else {
    // The copy happens right before the op that is bufferized.
    rewriter.setInsertionPoint(op);
  }
  if (failed(options.createMemCpy(rewriter, loc, operandBuffer, *resultBuffer)))
    return failure();

  return resultBuffer;
}

/// Return the buffer type for a given Value (tensor) after bufferization.
BaseMemRefType BufferizationState::getBufferType(Value value) const {
  auto tensorType = value.getType().dyn_cast<TensorType>();
  assert(tensorType && "unexpected non-tensor type");

  if (auto toTensorOp = value.getDefiningOp<bufferization::ToTensorOp>())
    return toTensorOp.memref().getType().cast<BaseMemRefType>();

  return getMemRefType(tensorType, getOptions());
}

void bufferization::replaceOpWithBufferizedValues(RewriterBase &rewriter,
                                                  Operation *op,
                                                  ValueRange values) {
  assert(values.size() == op->getNumResults() &&
         "expected one value per OpResult");
  OpBuilder::InsertionGuard g(rewriter);

  // Replace all OpResults with the given values.
  SmallVector<Value> replacements;
  for (OpResult opResult : op->getOpResults()) {
    Value replacement = values[opResult.getResultNumber()];
    if (opResult.getType().isa<TensorType>()) {
      // The OpResult is a tensor. Such values are replaced with memrefs during
      // bufferization.
      assert((replacement.getType().isa<MemRefType>() ||
              replacement.getType().isa<UnrankedMemRefType>()) &&
             "tensor op result should be replaced with a memref value");
      // The existing uses of the OpResult still expect a tensor. Insert a
      // ToTensorOp. Throughout bufferization, this ToTensorOp will gradually
      // lose all of its users and eventually DCE away.
      rewriter.setInsertionPointAfter(op);
      replacement = rewriter.create<bufferization::ToTensorOp>(
          replacement.getLoc(), replacement);
    }
    replacements.push_back(replacement);
  }

  rewriter.replaceOp(op, replacements);
}

//===----------------------------------------------------------------------===//
// Bufferization-specific scoped alloc/dealloc insertion support.
//===----------------------------------------------------------------------===//

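// Illustrative example (the lambda is hypothetical; its signature is inferred
// from the call to `allocationFn` below). Clients may override the default
// allocation, e.g.:
//   options.allocationFn = [](OpBuilder &b, Location loc, MemRefType type,
//                             ValueRange dynShape,
//                             unsigned alignment) -> FailureOr<Value> {
//     return b.create<memref::AllocaOp>(loc, type, dynShape).getResult();
//   };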
/// Create a memref allocation with the given type and dynamic extents.
FailureOr<Value> BufferizationOptions::createAlloc(OpBuilder &b, Location loc,
                                                   MemRefType type,
                                                   ValueRange dynShape) const {
  if (allocationFn)
    return (*allocationFn)(b, loc, type, dynShape, bufferAlignment);

  // Default buffer allocation via AllocOp.
  Value allocated = b.create<memref::AllocOp>(
      loc, type, dynShape, b.getI64IntegerAttr(bufferAlignment));
  return allocated;
}

/// Creates a memref deallocation. The given memref buffer must have been
/// allocated using `createAlloc`.
LogicalResult BufferizationOptions::createDealloc(OpBuilder &b, Location loc,
                                                  Value allocatedBuffer) const {
  if (deallocationFn)
    return (*deallocationFn)(b, loc, allocatedBuffer);

  // Default buffer deallocation via DeallocOp.
  b.create<memref::DeallocOp>(loc, allocatedBuffer);
  return success();
}

static MemRefType getContiguousMemRefType(ShapedType shapedType,
                                          Attribute memorySpace = {}) {
  MemRefLayoutAttrInterface layout = {};
  return MemRefType::get(shapedType.getShape(), shapedType.getElementType(),
                         layout, memorySpace);
}

/// Compute the type of the `memref` to use for allocating the buffer for
/// `shapedValue`. Also returns (by reference in `dynShape`) the values for the
/// dynamic dimensions in the returned `memref` type.
static MemRefType getAllocationTypeAndShape(OpBuilder &b, Location loc,
                                            Value shapedValue,
                                            SmallVectorImpl<Value> &dynShape) {
  MemRefType allocMemRefType =
      getContiguousMemRefType(shapedValue.getType().cast<ShapedType>());

  // Compute the dynamic part of the shape.
  bool reifiedShapes = false;
  if (auto rankedOp = dyn_cast_or_null<ReifyRankedShapedTypeOpInterface>(
          shapedValue.getDefiningOp())) {
    ReifiedRankedShapedTypeDims resultDims;
    if (succeeded(rankedOp.reifyResultShapes(b, resultDims))) {
      reifiedShapes = true;
      OpResult resultValue = shapedValue.dyn_cast<OpResult>();
      auto &shape = resultDims[resultValue.getResultNumber()];
      for (const auto &dim : enumerate(allocMemRefType.getShape()))
        if (ShapedType::isDynamic(dim.value()))
          dynShape.push_back(shape[dim.index()]);
    }
  }

  if (!reifiedShapes) {
    for (const auto &dim : enumerate(allocMemRefType.getShape()))
      if (ShapedType::isDynamic(dim.value())) {
        assert((shapedValue.getType().isa<UnrankedMemRefType>() ||
                shapedValue.getType().isa<MemRefType>()) &&
               "expected MemRef type");
        dynShape.push_back(
            b.create<memref::DimOp>(loc, shapedValue, dim.index()));
      }
  }

  return allocMemRefType;
}

/// Create an allocation after `shapedValue.getDefiningOp` (or at the top of
/// the block in case of a bbArg).
FailureOr<Value> BufferizationState::createAlloc(OpBuilder &b, Location loc,
                                                 Value shapedValue,
                                                 Optional<bool> dealloc) {
  // Take a guard before anything else.
  OpBuilder::InsertionGuard g(b);

  // Compute allocation memref type.
  assert(shapedValue.getType().isa<ShapedType>());
  SmallVector<Value> dynShape;
  MemRefType allocMemRefType =
      getAllocationTypeAndShape(b, loc, shapedValue, dynShape);

  // Create the buffer allocation.
  FailureOr<Value> buffer =
      getOptions().createAlloc(b, loc, allocMemRefType, dynShape);
  if (failed(buffer))
    return failure();

  // Should the buffer be deallocated again or should we let it leak?
  if (dealloc) {
    if (!dealloc.getValue())
      return *buffer;
  } else {
    assert(shapedValue.getType().isa<TensorType>() &&
           "must specify `dealloc` if non-tensor value is passed");
    // The buffer should not be deallocated if deallocs are generally
    // deactivated or if the tensor is yielded from a block.
    if (!getOptions().createDeallocs ||
        getAnalysisState().isTensorYielded(shapedValue))
      return *buffer;
  }

  // Create buffer deallocation.
  b.setInsertionPoint(b.getInsertionBlock()->getTerminator());
  if (failed(getOptions().createDealloc(b, loc, *buffer)))
    return failure();

  return *buffer;
}

/// Create a memory copy between two memref buffers.
LogicalResult BufferizationOptions::createMemCpy(OpBuilder &b, Location loc,
                                                 Value from, Value to) const {
  if (memCpyFn)
    return (*memCpyFn)(b, loc, from, to);

  b.create<memref::CopyOp>(loc, from, to);
  return success();
}

//===----------------------------------------------------------------------===//
// Bufferization-specific BlockAndValueMapping support with debugging.
//===----------------------------------------------------------------------===//

bool bufferization::isFunctionArgument(Value value) {
  auto bbArg = value.dyn_cast<BlockArgument>();
  if (!bbArg)
    return false;
  return isa<func::FuncOp>(bbArg.getOwner()->getParentOp());
}

BaseMemRefType bufferization::getMemRefType(TensorType tensorType,
                                            const BufferizationOptions &options,
                                            MemRefLayoutAttrInterface layout,
                                            Attribute memorySpace) {
  // Case 1: Unranked memref type.
  if (auto unrankedTensorType = tensorType.dyn_cast<UnrankedTensorType>()) {
    assert(!layout && "UnrankedTensorType cannot have a layout map");
    return UnrankedMemRefType::get(unrankedTensorType.getElementType(),
                                   memorySpace);
  }

  // Case 2: Ranked memref type with specified layout.
  auto rankedTensorType = tensorType.cast<RankedTensorType>();
  if (layout) {
    return MemRefType::get(rankedTensorType.getShape(),
                           rankedTensorType.getElementType(), layout,
                           memorySpace);
  }

  // Case 3: Configured with "fully dynamic layout maps".
  if (options.unknownTypeConversion ==
      BufferizationOptions::LayoutMapOption::FullyDynamicLayoutMap)
    return getMemRefTypeWithFullyDynamicLayout(tensorType, memorySpace);

  // Case 4: Configured with "static identity layout maps".
  if (options.unknownTypeConversion ==
      BufferizationOptions::LayoutMapOption::IdentityLayoutMap)
    return getMemRefTypeWithStaticIdentityLayout(tensorType, memorySpace);

  llvm_unreachable("InferLayoutMap is an invalid option");
}

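// For example (illustrative): with FullyDynamicLayoutMap, a tensor<4x8xf32>
// is converted to a memref<4x8xf32> with a fully dynamic strided layout
// (dynamic offset and strides, see below), whereas IdentityLayoutMap yields
// a plain memref<4x8xf32> with no layout map.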
BaseMemRefType
bufferization::getMemRefTypeWithFullyDynamicLayout(TensorType tensorType,
                                                   Attribute memorySpace) {
  // Case 1: Unranked memref type.
  if (auto unrankedTensorType = tensorType.dyn_cast<UnrankedTensorType>()) {
    return UnrankedMemRefType::get(unrankedTensorType.getElementType(),
                                   memorySpace);
  }

  // Case 2: Ranked memref type.
  auto rankedTensorType = tensorType.cast<RankedTensorType>();
  int64_t dynamicOffset = ShapedType::kDynamicStrideOrOffset;
  SmallVector<int64_t> dynamicStrides(rankedTensorType.getRank(),
                                      ShapedType::kDynamicStrideOrOffset);
  AffineMap stridedLayout = makeStridedLinearLayoutMap(
      dynamicStrides, dynamicOffset, rankedTensorType.getContext());
  return MemRefType::get(rankedTensorType.getShape(),
                         rankedTensorType.getElementType(), stridedLayout,
                         memorySpace);
}

/// Return a MemRef type with a static identity layout (i.e., no layout map).
/// If the given tensor type is unranked, return an unranked MemRef type.
BaseMemRefType
bufferization::getMemRefTypeWithStaticIdentityLayout(TensorType tensorType,
                                                      Attribute memorySpace) {
  // Case 1: Unranked memref type.
  if (auto unrankedTensorType = tensorType.dyn_cast<UnrankedTensorType>()) {
    return UnrankedMemRefType::get(unrankedTensorType.getElementType(),
                                   memorySpace);
  }

  // Case 2: Ranked memref type.
  auto rankedTensorType = tensorType.cast<RankedTensorType>();
  MemRefLayoutAttrInterface layout = {};
  return MemRefType::get(rankedTensorType.getShape(),
                         rankedTensorType.getElementType(), layout,
                         memorySpace);
}