//===- BufferizableOpInterface.cpp - Bufferizable Ops --------------------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//

#include "mlir/Dialect/Bufferization/IR/BufferizableOpInterface.h"
#include "mlir/Dialect/Bufferization/IR/Bufferization.h"
#include "mlir/Dialect/MemRef/IR/MemRef.h"
#include "mlir/IR/AsmState.h"
#include "mlir/IR/BlockAndValueMapping.h"
#include "mlir/IR/BuiltinOps.h"
#include "mlir/IR/Operation.h"
#include "mlir/IR/TypeUtilities.h"
#include "mlir/IR/Value.h"
#include "llvm/Support/Debug.h"

namespace mlir {
namespace bufferization {

#include "mlir/Dialect/Bufferization/IR/BufferizableOpInterface.cpp.inc"

} // namespace bufferization
} // namespace mlir

#define DEBUG_TYPE "bufferizable-op-interface"
#define DBGS() (llvm::dbgs() << '[' << DEBUG_TYPE << "] ")
#define LDBG(X) LLVM_DEBUG(DBGS() << (X))

using namespace mlir;
using namespace bufferization;

/// Attribute name used to mark the bufferization layout for region
/// arguments during linalg comprehensive bufferization.
constexpr const ::llvm::StringLiteral
    bufferization::BufferizableOpInterface::kBufferLayoutAttrName;

/// Attribute name used to mark region arguments that can be bufferized
/// in-place during linalg comprehensive bufferization.
constexpr const ::llvm::StringLiteral
    bufferization::BufferizableOpInterface::kInplaceableAttrName;

//===----------------------------------------------------------------------===//
// BufferizationOptions
//===----------------------------------------------------------------------===//

// Default constructor for BufferizationOptions.
BufferizationOptions::BufferizationOptions() = default;

BufferizableOpInterface
BufferizationOptions::dynCastBufferizableOp(Operation *op) const {
  if (isOpAllowed(op))
    return dyn_cast<BufferizableOpInterface>(op);
  return nullptr;
}

BufferizableOpInterface
BufferizationOptions::dynCastBufferizableOp(Value value) const {
  if (auto bufferizableOp = value.getDefiningOp<BufferizableOpInterface>())
    if (isOpAllowed(bufferizableOp.getOperation()))
      return bufferizableOp;
  return nullptr;
}

//===----------------------------------------------------------------------===//
// Helper functions for BufferizableOpInterface
//===----------------------------------------------------------------------===//

static void setInsertionPointAfter(OpBuilder &b, Value value) {
  if (auto bbArg = value.dyn_cast<BlockArgument>()) {
    b.setInsertionPointToStart(bbArg.getOwner());
  } else {
    b.setInsertionPointAfter(value.getDefiningOp());
  }
}

/// Determine which OpOperand* will alias with `result` if the op is
/// bufferized in place. Return an empty vector if the op is not bufferizable.
SmallVector<OpOperand *>
BufferizationState::getAliasingOpOperand(OpResult result) const {
  if (Operation *op = result.getDefiningOp())
    if (auto bufferizableOp = dyn_cast<BufferizableOpInterface>(op))
      return bufferizableOp.getAliasingOpOperand(result, *this);
  return {};
}
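
// Example (illustrative only): For an op such as tensor.insert_slice, the
// destination operand and the op result alias when bufferizing in place:
//
//   %r = tensor.insert_slice %src into %dest[0] [4] [1]
//       : tensor<4xf32> into tensor<10xf32>
//
// Here, getAliasingOpOperand(%r) returns the OpOperand holding %dest, and
// getAliasingOpResult on that OpOperand returns %r. The exact aliasing
// behavior is defined by each op's BufferizableOpInterface implementation.
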
/// Determine which OpResult will alias with `opOperand` if the op is
/// bufferized in place. Return an empty OpResult if the op is not
/// bufferizable.
OpResult BufferizationState::getAliasingOpResult(OpOperand &opOperand) const {
  if (auto bufferizableOp =
          dyn_cast<BufferizableOpInterface>(opOperand.getOwner()))
    return bufferizableOp.getAliasingOpResult(opOperand, *this);
  return OpResult();
}

/// Return `true` if `opOperand` bufferizes to a memory read. Return `true`
/// if the op is not bufferizable.
bool BufferizationState::bufferizesToMemoryRead(OpOperand &opOperand) const {
  if (auto bufferizableOp =
          dyn_cast<BufferizableOpInterface>(opOperand.getOwner()))
    return bufferizableOp.bufferizesToMemoryRead(opOperand, *this);

  // Unknown op that returns a tensor. The inplace analysis does not support
  // it. Conservatively return true.
  return true;
}

/// Return `true` if `opOperand` bufferizes to a memory write. Return `true`
/// if the op is not bufferizable.
bool BufferizationState::bufferizesToMemoryWrite(OpOperand &opOperand) const {
  if (auto bufferizableOp =
          dyn_cast<BufferizableOpInterface>(opOperand.getOwner()))
    return bufferizableOp.bufferizesToMemoryWrite(opOperand, *this);

  // Unknown op that returns a tensor. The inplace analysis does not support
  // it. Conservatively return true.
  return true;
}

/// Return `true` if `opOperand` neither reads nor writes but bufferizes to
/// an alias. Return `false` if the op is not bufferizable.
bool BufferizationState::bufferizesToAliasOnly(OpOperand &opOperand) const {
  if (auto bufferizableOp =
          dyn_cast<BufferizableOpInterface>(opOperand.getOwner()))
    return bufferizableOp.bufferizesToAliasOnly(opOperand, *this);

  // Unknown op that returns a tensor. The inplace analysis does not support
  // it. Conservatively return false.
  return false;
}

/// Return true if the given value is read by an op that bufferizes to a
/// memory read. Also takes into account ops that create an alias but do not
/// read by themselves (e.g., ExtractSliceOp).
bool BufferizationState::isValueRead(Value value) const {
  assert(value.getType().isa<TensorType>() && "expected TensorType");
  SmallVector<OpOperand *> workingSet;
  for (OpOperand &use : value.getUses())
    workingSet.push_back(&use);

  while (!workingSet.empty()) {
    OpOperand *uMaybeReading = workingSet.pop_back_val();
    // Skip over all ops that neither read nor write (but create an alias).
    if (bufferizesToAliasOnly(*uMaybeReading))
      for (OpOperand &use : getAliasingOpResult(*uMaybeReading).getUses())
        workingSet.push_back(&use);
    if (bufferizesToMemoryRead(*uMaybeReading))
      return true;
  }

  return false;
}

// Starting from `value`, follow the use-def chain in reverse, always
// selecting the aliasing OpOperands. Find and return Values for which
// `condition` evaluates to true. OpOperands of such matching Values are not
// traversed any further.
llvm::SetVector<Value> BufferizationState::findValueInReverseUseDefChain(
    Value value, llvm::function_ref<bool(Value)> condition) const {
  llvm::SetVector<Value> result, workingSet;
  workingSet.insert(value);

  while (!workingSet.empty()) {
    Value value = workingSet.pop_back_val();
    if (condition(value) || value.isa<BlockArgument>()) {
      result.insert(value);
      continue;
    }

    OpResult opResult = value.cast<OpResult>();
    SmallVector<OpOperand *> opOperands = getAliasingOpOperand(opResult);
    if (opOperands.empty() || !options.isOpAllowed(value.getDefiningOp())) {
      result.insert(value);
      continue;
    }

    for (OpOperand *o : opOperands)
      workingSet.insert(o->get());
  }

  return result;
}
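
// Example (illustrative only): Suppose `condition` returns true only for
// values that are known memory writes. Starting from %1, the traversal steps
// through the alias-only tensor.extract_slice and stops at %0:
//
//   %0 = "test.writing_op"() : () -> tensor<10xf32>
//   %1 = tensor.extract_slice %0[0] [5] [1]
//       : tensor<10xf32> to tensor<5xf32>
//
//   findValueInReverseUseDefChain(%1, condition) == {%0}
//
// The ops involved must implement BufferizableOpInterface for their aliasing
// OpOperands to be visible to the traversal.
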
// Find the Values of the last preceding write of a given Value.
llvm::SetVector<Value>
BufferizationState::findLastPrecedingWrite(Value value) const {
  return findValueInReverseUseDefChain(value, [&](Value value) {
    Operation *op = value.getDefiningOp();
    if (!op)
      return true;
    auto bufferizableOp = options.dynCastBufferizableOp(op);
    if (!bufferizableOp)
      return true;
    return bufferizableOp.isMemoryWrite(value.cast<OpResult>(), *this);
  });
}

BufferizationState::BufferizationState(const BufferizationOptions &options)
    : options(options) {}

// bufferization.to_memref is not allowed to change the rank.
static void ensureToMemrefOpIsValid(Value tensor, Type memrefType) {
#ifndef NDEBUG
  auto rankedTensorType = tensor.getType().dyn_cast<RankedTensorType>();
  assert((!rankedTensorType || memrefType.cast<MemRefType>().getRank() ==
                                   rankedTensorType.getRank()) &&
         "to_memref would be invalid: mismatching ranks");
#endif
}

static Value lookupBuffer(RewriterBase &rewriter, Value tensor,
                          const BufferizationOptions &options) {
  auto tensorType = tensor.getType().dyn_cast<TensorType>();
  assert(tensorType && "unexpected non-tensor type");

  // Replace "%t = to_tensor %m" with %m.
  if (auto toTensorOp = tensor.getDefiningOp<bufferization::ToTensorOp>())
    return toTensorOp.memref();

  // Insert to_memref op.
  OpBuilder::InsertionGuard g(rewriter);
  setInsertionPointAfter(rewriter, tensor);
  Type memrefType = getMemRefType(tensorType, options);
  ensureToMemrefOpIsValid(tensor, memrefType);
  return rewriter.create<bufferization::ToMemrefOp>(tensor.getLoc(),
                                                    memrefType, tensor);
}
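
// Example (illustrative only): lookupBuffer avoids round-trips through
// to_tensor/to_memref pairs. Given:
//
//   %t = bufferization.to_tensor %m : memref<4xf32>
//
// lookupBuffer(rewriter, %t, options) simply returns %m. For any other
// tensor value, a to_memref op is materialized right after the value's
// definition:
//
//   %b = bufferization.to_memref %t : memref<4xf32>
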
/// Return the buffer (memref) for a given OpOperand (tensor). Allocate
/// a new buffer and copy over data from the existing buffer if out-of-place
/// bufferization is necessary.
FailureOr<Value> BufferizationState::getBuffer(
    RewriterBase &rewriter, OpOperand &opOperand, bool forceInPlace,
    Optional<Operation *> customCopyInsertionPoint) const {
  OpBuilder::InsertionGuard guard(rewriter);
  Operation *op = opOperand.getOwner();
  Location loc = op->getLoc();
  Value operand = opOperand.get();
  Value operandBuffer = lookupBuffer(rewriter, operand, options);

  if (forceInPlace || isInPlace(opOperand))
    return operandBuffer;

  // Bufferizing out-of-place: Allocate a new buffer.
  // Move insertion point right after `operandBuffer`. That is where the
  // allocation should be inserted (in the absence of allocation hoisting).
  setInsertionPointAfter(rewriter, operandBuffer);
  // Allocate the result buffer.
  FailureOr<Value> resultBuffer = createAlloc(rewriter, loc, operandBuffer,
                                              options.createDeallocs, options);
  if (failed(resultBuffer))
    return failure();
  // Do not copy if the last preceding writes of `operand` are ops that do
  // not write (skipping ops that merely create aliases). E.g., InitTensorOp.
  // Note: If `findLastPrecedingWrite` reaches the end of the reverse SSA
  // use-def chain, it returns that value, regardless of whether it is a
  // memory write or not.
  SetVector<Value> lastWrites = findLastPrecedingWrite(operand);
  if (llvm::none_of(lastWrites, [&](Value lastWrite) {
        if (auto bufferizableOp = options.dynCastBufferizableOp(lastWrite))
          return bufferizableOp.isMemoryWrite(lastWrite.cast<OpResult>(),
                                              *this);
        return true;
      }))
    return resultBuffer;
  // Do not copy if the copied data is never read.
  OpResult aliasingOpResult = getAliasingOpResult(opOperand);
  if (aliasingOpResult && !bufferizesToMemoryRead(opOperand) &&
      !isValueRead(aliasingOpResult))
    return resultBuffer;
  // Do not copy if this op does not read the data, but writes it.
  if (bufferizesToMemoryWrite(opOperand) && !bufferizesToMemoryRead(opOperand))
    return resultBuffer;

  if (customCopyInsertionPoint) {
    rewriter.setInsertionPoint(*customCopyInsertionPoint);
  } else {
    // The copy happens right before the op that is bufferized.
    rewriter.setInsertionPoint(op);
  }
  if (failed(
          createMemCpy(rewriter, loc, operandBuffer, *resultBuffer, options)))
    return failure();

  return resultBuffer;
}

void bufferization::replaceOpWithBufferizedValues(RewriterBase &rewriter,
                                                  Operation *op,
                                                  ValueRange values) {
  OpBuilder::InsertionGuard g(rewriter);

  // Replace all OpResults with the given values.
  for (OpResult opResult : op->getOpResults()) {
    // Skip OpResult if it has no uses.
    if (opResult.getUses().empty())
      continue;

    Value replacement = values[opResult.getResultNumber()];
    if (opResult.getType().isa<TensorType>()) {
      // The OpResult is a tensor. Such values are replaced with memrefs
      // during bufferization.
      assert((replacement.getType().isa<MemRefType>() ||
              replacement.getType().isa<UnrankedMemRefType>()) &&
             "tensor op result should be replaced with a memref value");
      // The existing uses of the OpResult still expect a tensor. Insert a
      // ToTensorOp. Throughout bufferization, this ToTensorOp will gradually
      // lose all of its users and eventually DCE away.
      rewriter.setInsertionPointAfter(op);
      replacement = rewriter.create<bufferization::ToTensorOp>(
          replacement.getLoc(), replacement);
    }
    opResult.replaceAllUsesWith(replacement);
  }

  rewriter.eraseOp(op);
}
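
// Example (illustrative only): Replacing a tensor-producing op with a memref
// value inserts a to_tensor op for the remaining tensor uses. Given a
// hypothetical op
//
//   %0 = "test.produce_tensor"() : () -> tensor<4xf32>
//
// the call replaceOpWithBufferizedValues(rewriter, op, {%m}) with a memref
// value %m rewrites it to
//
//   %0 = bufferization.to_tensor %m : memref<4xf32>
//
// and the to_tensor op is expected to fold away as bufferization proceeds.
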
AlwaysCopyBufferizationState::AlwaysCopyBufferizationState(
    const BufferizationOptions &options)
    : BufferizationState(options) {}

/// Return `true` if the given OpOperand has been decided to bufferize
/// in-place.
bool AlwaysCopyBufferizationState::isInPlace(OpOperand &opOperand) const {
  // OpOperands that bufferize to a memory write are out-of-place, i.e., an
  // alloc and copy is inserted.
  return !bufferizesToMemoryWrite(opOperand);
}

/// Return true if `v1` and `v2` bufferize to equivalent buffers.
bool AlwaysCopyBufferizationState::areEquivalentBufferizedValues(
    Value v1, Value v2) const {
  // There is no analysis, so we do not know if the values are equivalent.
  // The conservative answer is "false".
  return false;
}

//===----------------------------------------------------------------------===//
// Bufferization-specific scoped alloc/dealloc insertion support.
//===----------------------------------------------------------------------===//

/// Move the insertion point of the given builder as far up in its enclosing
/// blocks as possible, without crossing any allocation hoisting barriers.
static void moveInsertionPointToAllocationHoistingBarrier(OpBuilder &b) {
  Operation *op = b.getInsertionBlock()->getParentOp();
  while (op) {
    if (auto bufferizableOp = dyn_cast<BufferizableOpInterface>(op))
      if (bufferizableOp.isAllocationHoistingBarrier())
        break;
    op = op->getParentOp();
  }

  if (!op) {
    // No allocation hoisting barrier found. Hoist to FuncOp.
    op = b.getInsertionBlock()->getParentOp();
    if (!isa<FuncOp>(op))
      op = op->getParentOfType<FuncOp>();
    assert(op && "could not find enclosing FuncOp");
  }

  // TODO: Handle cases where the allocation hoisting barrier has more than
  // one region or block.
  assert(op->getNumRegions() == 1 &&
         "allocation hoisting barriers with >1 regions not supported");
  assert(op->getRegion(0).getBlocks().size() == 1 &&
         "allocation hoisting barriers with >1 blocks not supported");
  b.setInsertionPointToStart(&(op->getRegion(0).front()));
}
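
// Example (illustrative only): For a statically shaped buffer requested
// inside a loop, the insertion point moves up to the entry block of the
// enclosing FuncOp (assuming no op on the way is an allocation hoisting
// barrier), so the allocation is created once instead of per iteration:
//
//   func @f() {
//     %alloc = memref.alloc() : memref<4xf32>  // insertion point moved here
//     scf.for %i = %lb to %ub step %step {
//       ...                                    // allocation requested here
//     }
//   }
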
/// Compute the type of the `memref` to use for allocating the buffer for
/// `shapedValue`. Also returns (by reference in `dynShape`) the values for
/// the dynamic dimensions in the returned `memref` type. The function may
/// also set the insertion point to an earlier location, where the allocation
/// should happen ("allocation hoisting").
static MemRefType getAllocationTypeAndShape(OpBuilder &b, Location loc,
                                            Value shapedValue,
                                            SmallVectorImpl<Value> &dynShape) {
  MemRefType allocMemRefType =
      getContiguousMemRefType(shapedValue.getType().cast<ShapedType>());

  // Compute the dynamic part of the shape.
  bool reifiedShapes = false;
  if (auto rankedOp = dyn_cast_or_null<ReifyRankedShapedTypeOpInterface>(
          shapedValue.getDefiningOp())) {
    ReifiedRankedShapedTypeDims resultDims;
    if (succeeded(rankedOp.reifyResultShapes(b, resultDims))) {
      reifiedShapes = true;
      OpResult resultValue = shapedValue.dyn_cast<OpResult>();
      auto &shape = resultDims[resultValue.getResultNumber()];
      for (const auto &dim : enumerate(allocMemRefType.getShape()))
        if (ShapedType::isDynamic(dim.value()))
          dynShape.push_back(shape[dim.index()]);
    }
  }

  if (!reifiedShapes) {
    for (const auto &dim : enumerate(allocMemRefType.getShape()))
      if (ShapedType::isDynamic(dim.value())) {
        assert((shapedValue.getType().isa<UnrankedMemRefType>() ||
                shapedValue.getType().isa<MemRefType>()) &&
               "expected MemRef type");
        dynShape.push_back(
            b.create<memref::DimOp>(loc, shapedValue, dim.index()));
      }
  }

  // If the buffer is statically shaped, try to hoist it to the first
  // enclosing parallel region.
  // TODO: also hoist in the dynamic case. For now this relies on subsequent
  // calls to LICM and buffer hoisting, which will most likely not succeed.
  // TODO: when packing, allocate a static bounding box which will enable more
  // hoisting.
  if (dynShape.empty())
    moveInsertionPointToAllocationHoistingBarrier(b);

  return allocMemRefType;
}

/// Create an AllocOp/DeallocOp pair, where the AllocOp is after
/// `shapedValue.getDefiningOp` (or at the top of the block in case of a
/// bbArg) and the DeallocOp is at the end of the block.
FailureOr<Value>
bufferization::createAlloc(OpBuilder &b, Location loc, Value shapedValue,
                           bool deallocMemref,
                           const BufferizationOptions &options) {
  // Take a guard before anything else.
  OpBuilder::InsertionGuard g(b);

  // 1. Create memory allocation.
  assert(shapedValue.getType().isa<ShapedType>());
  MemRefType memRefType = shapedValue.getType().dyn_cast<MemRefType>();
  SmallVector<Value> dynShape;
  // Note: getAllocationTypeAndShape also sets the insertion point.
  MemRefType allocMemRefType =
      getAllocationTypeAndShape(b, loc, shapedValue, dynShape);
  FailureOr<Value> allocated =
      createAlloc(b, loc, allocMemRefType, dynShape, options);
  if (failed(allocated))
    return failure();
  Value casted = allocated.getValue();
  if (memRefType && memRefType != allocMemRefType) {
    assert(memref::CastOp::areCastCompatible(allocated.getValue().getType(),
                                             memRefType) &&
           "createAlloc: cast incompatible");
    casted = b.create<memref::CastOp>(loc, memRefType, allocated.getValue());
  }

  if (deallocMemref) {
    // 2. Create memory deallocation.
    b.setInsertionPoint(allocated.getValue().getParentBlock()->getTerminator());
    if (failed(createDealloc(b, loc, allocated.getValue(), options)))
      return failure();
  }

  return casted;
}

/// Create a memref allocation with the given type and dynamic extents.
FailureOr<Value>
bufferization::createAlloc(OpBuilder &b, Location loc, MemRefType type,
                           ValueRange dynShape,
                           const BufferizationOptions &options) {
  if (options.allocationFn)
    return (*options.allocationFn)(b, loc, type, dynShape,
                                   options.bufferAlignment);

  // Default buffer allocation via AllocOp.
  Value allocated = b.create<memref::AllocOp>(
      loc, type, dynShape, b.getI64IntegerAttr(options.bufferAlignment));
  return allocated;
}

/// Create a memref allocation with the given type and dynamic extents. May
/// also deallocate the memref again.
FailureOr<Value>
bufferization::createAlloc(OpBuilder &b, Location loc, MemRefType type,
                           ValueRange dynShape, bool deallocMemref,
                           const BufferizationOptions &options) {
  OpBuilder::InsertionGuard g(b);

  FailureOr<Value> alloc = createAlloc(b, loc, type, dynShape, options);
  if (failed(alloc))
    return failure();

  if (deallocMemref) {
    // Dealloc at the end of the block.
    b.setInsertionPoint(alloc.getValue().getParentBlock()->getTerminator());
    if (failed(createDealloc(b, loc, *alloc, options)))
      return failure();
  }

  return alloc;
}

/// Create a memref deallocation.
LogicalResult
bufferization::createDealloc(OpBuilder &b, Location loc, Value allocatedBuffer,
                             const BufferizationOptions &options) {
  if (options.deallocationFn)
    return (*options.deallocationFn)(b, loc, allocatedBuffer);

  // Default buffer deallocation via DeallocOp.
  b.create<memref::DeallocOp>(loc, allocatedBuffer);
  return success();
}
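
// Example (illustrative only): The alloc/dealloc/copy behavior above can be
// overridden through BufferizationOptions. A hypothetical setup that ignores
// the requested alignment and always uses 128 bytes (the lambda's parameter
// list mirrors how `allocationFn` is invoked above):
//
//   BufferizationOptions options;
//   options.allocationFn = [](OpBuilder &b, Location loc, MemRefType type,
//                             ValueRange dynShape,
//                             unsigned bufferAlignment) -> FailureOr<Value> {
//     return b
//         .create<memref::AllocOp>(loc, type, dynShape,
//                                  b.getI64IntegerAttr(128))
//         .getResult();
//   };
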
/// Create a memory copy between two memref buffers.
LogicalResult bufferization::createMemCpy(OpBuilder &b, Location loc,
                                          Value from, Value to,
                                          const BufferizationOptions &options) {
  if (options.memCpyFn)
    return (*options.memCpyFn)(b, loc, from, to);

  b.create<memref::CopyOp>(loc, from, to);
  return success();
}

//===----------------------------------------------------------------------===//
// Bufferization-specific helper functions.
//===----------------------------------------------------------------------===//

bool bufferization::isFunctionArgument(Value value) {
  auto bbArg = value.dyn_cast<BlockArgument>();
  if (!bbArg)
    return false;
  return isa<FuncOp>(bbArg.getOwner()->getParentOp());
}

MemRefType bufferization::getContiguousMemRefType(ShapedType shapedType,
                                                  Attribute memorySpace) {
  MemRefLayoutAttrInterface layout = {};
  return MemRefType::get(shapedType.getShape(), shapedType.getElementType(),
                         layout, memorySpace);
}

BaseMemRefType bufferization::getMemRefType(TensorType tensorType,
                                            const BufferizationOptions &options,
                                            MemRefLayoutAttrInterface layout,
                                            Attribute memorySpace) {
  // Case 1: Unranked memref type.
  if (auto unrankedTensorType = tensorType.dyn_cast<UnrankedTensorType>()) {
    assert(!layout && "UnrankedTensorType cannot have a layout map");
    return UnrankedMemRefType::get(unrankedTensorType.getElementType(),
                                   memorySpace);
  }

  // Case 2: Ranked memref type with specified layout. If fully dynamic layout
  // maps are not requested, generate a type with `layout`, which is empty (no
  // layout map) by default.
  auto rankedTensorType = tensorType.cast<RankedTensorType>();
  if (layout || !options.fullyDynamicLayoutMaps) {
    return MemRefType::get(rankedTensorType.getShape(),
                           rankedTensorType.getElementType(), layout,
                           memorySpace);
  }

  // Case 3: Ranked memref type with unspecified layout. Choose the most
  // dynamic one.
  // TODO: address space decisions to connect with the actual alloc.
  int64_t dynamicOffset = ShapedType::kDynamicStrideOrOffset;
  SmallVector<int64_t> dynamicStrides(rankedTensorType.getRank(),
                                      ShapedType::kDynamicStrideOrOffset);
  AffineMap stridedLayout = makeStridedLinearLayoutMap(
      dynamicStrides, dynamicOffset, rankedTensorType.getContext());
  return MemRefType::get(rankedTensorType.getShape(),
                         rankedTensorType.getElementType(), stridedLayout,
                         memorySpace);
}
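
// Example (illustrative only): With fullyDynamicLayoutMaps enabled and no
// explicit layout, a ranked tensor type maps to a memref type whose offset
// and strides are fully dynamic, e.g.:
//
//   tensor<4x?xf32>
//     -> memref<4x?xf32, affine_map<(d0, d1)[s0, s1, s2]
//                            -> (d0 * s1 + d1 * s2 + s0)>>
//
// With fullyDynamicLayoutMaps disabled, the same tensor type maps to a
// contiguous memref<4x?xf32>.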