//===- BufferizableOpInterface.cpp - Bufferizable Ops --------------------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//

#include "mlir/Dialect/Bufferization/IR/BufferizableOpInterface.h"
#include "mlir/Dialect/Bufferization/IR/Bufferization.h"
#include "mlir/Dialect/MemRef/IR/MemRef.h"
#include "mlir/IR/AsmState.h"
#include "mlir/IR/BlockAndValueMapping.h"
#include "mlir/IR/BuiltinOps.h"
#include "mlir/IR/Operation.h"
#include "mlir/IR/TypeUtilities.h"
#include "mlir/IR/Value.h"
#include "llvm/Support/Debug.h"

namespace mlir {
namespace bufferization {

#include "mlir/Dialect/Bufferization/IR/BufferizableOpInterface.cpp.inc"

} // namespace bufferization
} // namespace mlir

#define DEBUG_TYPE "bufferizable-op-interface"
#define DBGS() (llvm::dbgs() << '[' << DEBUG_TYPE << "] ")
#define LDBG(X) LLVM_DEBUG(DBGS() << (X))

using namespace mlir;
using namespace bufferization;

//===----------------------------------------------------------------------===//
// BufferizationOptions
//===----------------------------------------------------------------------===//

// Default constructor for BufferizationOptions.
BufferizationOptions::BufferizationOptions() {}

BufferizableOpInterface
BufferizationOptions::dynCastBufferizableOp(Operation *op) const {
  if (isOpAllowed(op))
    return dyn_cast<BufferizableOpInterface>(op);
  return nullptr;
}

BufferizableOpInterface
BufferizationOptions::dynCastBufferizableOp(Value value) const {
  if (auto bufferizableOp = value.getDefiningOp<BufferizableOpInterface>())
    if (isOpAllowed(bufferizableOp.getOperation()))
      return bufferizableOp;
  return nullptr;
}

//===----------------------------------------------------------------------===//
// Helper functions for BufferizableOpInterface
//===----------------------------------------------------------------------===//

static void setInsertionPointAfter(OpBuilder &b, Value value) {
  if (auto bbArg = value.dyn_cast<BlockArgument>()) {
    b.setInsertionPointToStart(bbArg.getOwner());
  } else {
    b.setInsertionPointAfter(value.getDefiningOp());
  }
}

/// Determine which OpOperands alias with `result` if the op is bufferized in
/// place. Return an empty vector if the op is not bufferizable.
SmallVector<OpOperand *>
BufferizationState::getAliasingOpOperand(OpResult result) const {
  if (Operation *op = result.getDefiningOp())
    if (auto bufferizableOp = dyn_cast<BufferizableOpInterface>(op))
      return bufferizableOp.getAliasingOpOperand(result, *this);
  return {};
}

/// Determine which OpResult aliases with `opOperand` if the op is bufferized
/// in place. Return an empty OpResult if the op is not bufferizable.
OpResult BufferizationState::getAliasingOpResult(OpOperand &opOperand) const {
  if (auto bufferizableOp =
          dyn_cast<BufferizableOpInterface>(opOperand.getOwner()))
    return bufferizableOp.getAliasingOpResult(opOperand, *this);
  return OpResult();
}
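// Example (illustrative only, not part of the interface contract): for a
// tensor.insert_slice op, the destination tensor operand and the op result
// typically alias when bufferizing in place, so getAliasingOpOperand(result)
// would return the destination OpOperand and getAliasingOpResult(destOperand)
// the op result. The exact aliasing behavior is defined by each op's
// BufferizableOpInterface implementation.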
/// Return true if `opOperand` bufferizes to a memory read. Return `true` if
/// the op is not bufferizable.
bool BufferizationState::bufferizesToMemoryRead(OpOperand &opOperand) const {
  if (auto bufferizableOp =
          dyn_cast<BufferizableOpInterface>(opOperand.getOwner()))
    return bufferizableOp.bufferizesToMemoryRead(opOperand, *this);

  // Unknown op that returns a tensor. The in-place analysis does not support
  // it. Conservatively return true.
  return true;
}

/// Return true if `opOperand` bufferizes to a memory write. Return `true` if
/// the op is not bufferizable.
bool BufferizationState::bufferizesToMemoryWrite(OpOperand &opOperand) const {
  if (auto bufferizableOp =
          dyn_cast<BufferizableOpInterface>(opOperand.getOwner()))
    return bufferizableOp.bufferizesToMemoryWrite(opOperand, *this);

  // Unknown op that returns a tensor. The in-place analysis does not support
  // it. Conservatively return true.
  return true;
}

/// Return true if `opOperand` neither reads nor writes, but bufferizes to an
/// alias. Return false if the op is not bufferizable.
bool BufferizationState::bufferizesToAliasOnly(OpOperand &opOperand) const {
  if (auto bufferizableOp =
          dyn_cast<BufferizableOpInterface>(opOperand.getOwner()))
    return bufferizableOp.bufferizesToAliasOnly(opOperand, *this);

  // Unknown op that returns a tensor. The in-place analysis does not support
  // it. Conservatively return false.
  return false;
}

/// Return true if the given value is read by an op that bufferizes to a memory
/// read. Also takes into account ops that create an alias but do not read by
/// themselves (e.g., ExtractSliceOp).
bool BufferizationState::isValueRead(Value value) const {
  assert(value.getType().isa<TensorType>() && "expected TensorType");
  SmallVector<OpOperand *> workingSet;
  for (OpOperand &use : value.getUses())
    workingSet.push_back(&use);

  while (!workingSet.empty()) {
    OpOperand *uMaybeReading = workingSet.pop_back_val();
    // Skip over all ops that neither read nor write (but create an alias).
    if (bufferizesToAliasOnly(*uMaybeReading))
      for (OpOperand &use : getAliasingOpResult(*uMaybeReading).getUses())
        workingSet.push_back(&use);
    if (bufferizesToMemoryRead(*uMaybeReading))
      return true;
  }

  return false;
}

/// Starting from `value`, follow the use-def chain in reverse, always
/// selecting the aliasing OpOperands. Find and return Values for which
/// `condition` evaluates to true. OpOperands of such matching Values are not
/// traversed any further.
llvm::SetVector<Value> BufferizationState::findValueInReverseUseDefChain(
    Value value, llvm::function_ref<bool(Value)> condition) const {
  llvm::SetVector<Value> result, workingSet;
  workingSet.insert(value);

  while (!workingSet.empty()) {
    Value value = workingSet.pop_back_val();
    if (condition(value) || value.isa<BlockArgument>()) {
      result.insert(value);
      continue;
    }

    OpResult opResult = value.cast<OpResult>();
    SmallVector<OpOperand *> opOperands = getAliasingOpOperand(opResult);
    if (opOperands.empty() || !options.isOpAllowed(value.getDefiningOp())) {
      result.insert(value);
      continue;
    }

    for (OpOperand *o : opOperands)
      workingSet.insert(o->get());
  }

  return result;
}
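// Example (illustrative only; the op names below are placeholders, not real
// dialect ops): given the chain
//   %0 = "writing_op"(...) : tensor<?xf32>
//   %1 = "alias_op"(%0) : tensor<?xf32>   // aliases %0, neither reads nor writes
//   %2 = "alias_op"(%1) : tensor<?xf32>
// findValueInReverseUseDefChain(%2, isMemoryWrite) steps backwards through the
// aliasing OpOperands of %2 and %1 and returns {%0}.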
/// Find the Values of the last preceding write of a given Value.
llvm::SetVector<Value>
BufferizationState::findLastPrecedingWrite(Value value) const {
  return findValueInReverseUseDefChain(value, [&](Value value) {
    Operation *op = value.getDefiningOp();
    if (!op)
      return true;
    auto bufferizableOp = options.dynCastBufferizableOp(op);
    if (!bufferizableOp)
      return true;
    return bufferizableOp.isMemoryWrite(value.cast<OpResult>(), *this);
  });
}

BufferizationState::BufferizationState(const BufferizationOptions &options)
    : options(options) {}

// bufferization.to_memref is not allowed to change the rank.
static void ensureToMemrefOpIsValid(Value tensor, Type memrefType) {
#ifndef NDEBUG
  auto rankedTensorType = tensor.getType().dyn_cast<RankedTensorType>();
  assert((!rankedTensorType || memrefType.cast<MemRefType>().getRank() ==
                                   rankedTensorType.getRank()) &&
         "to_memref would be invalid: mismatching ranks");
#endif
}

static Value lookupBuffer(RewriterBase &rewriter, Value tensor) {
  assert(tensor.getType().isa<TensorType>() && "unexpected non-tensor type");

  // Replace "%t = to_tensor %m" with %m.
  if (auto toTensorOp = tensor.getDefiningOp<bufferization::ToTensorOp>())
    return toTensorOp.memref();

  // Insert a to_memref op.
  OpBuilder::InsertionGuard g(rewriter);
  setInsertionPointAfter(rewriter, tensor);
  Type memrefType;
  if (auto rankedTensorType = tensor.getType().dyn_cast<RankedTensorType>()) {
    memrefType = getDynamicMemRefType(rankedTensorType);
  } else {
    memrefType = getUnrankedMemRefType(
        tensor.getType().cast<TensorType>().getElementType());
  }
  ensureToMemrefOpIsValid(tensor, memrefType);
  return rewriter.create<bufferization::ToMemrefOp>(tensor.getLoc(),
                                                    memrefType, tensor);
}
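// Example (illustrative only): lookupBuffer(%t) returns %m directly when %t
// was produced by "%t = bufferization.to_tensor %m". Otherwise, it
// materializes "%m = bufferization.to_memref %t" with a fully dynamic layout
// (or an unranked memref type for unranked tensors), so that any buffer
// layout can later be bound to it.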
/// Return the buffer (memref) for a given OpOperand (tensor). Allocate
/// a new buffer and copy over data from the existing buffer if out-of-place
/// bufferization is necessary.
FailureOr<Value> BufferizationState::getBuffer(
    RewriterBase &rewriter, OpOperand &opOperand, bool forceInPlace,
    Optional<Operation *> customCopyInsertionPoint) const {
  OpBuilder::InsertionGuard guard(rewriter);
  Operation *op = opOperand.getOwner();
  Location loc = op->getLoc();
  Value operand = opOperand.get();
  Value operandBuffer = lookupBuffer(rewriter, operand);

  if (forceInPlace || isInPlace(opOperand))
    return operandBuffer;

  // Bufferizing out-of-place: Allocate a new buffer.
  // Move the insertion point right after `operandBuffer`. That is where the
  // allocation should be inserted (in the absence of allocation hoisting).
  setInsertionPointAfter(rewriter, operandBuffer);
  // Allocate the result buffer.
  FailureOr<Value> resultBuffer = createAlloc(rewriter, loc, operandBuffer,
                                              options.createDeallocs, options);
  if (failed(resultBuffer))
    return failure();
  // Do not copy if the last preceding writes of `operand` are ops that do
  // not write (skipping ops that merely create aliases). E.g., InitTensorOp.
  // Note: If `findLastPrecedingWrite` reaches the end of the reverse SSA
  // use-def chain, it returns that value, regardless of whether it is a
  // memory write or not.
  SetVector<Value> lastWrites = findLastPrecedingWrite(operand);
  if (llvm::none_of(lastWrites, [&](Value lastWrite) {
        if (auto bufferizableOp = options.dynCastBufferizableOp(lastWrite))
          return bufferizableOp.isMemoryWrite(lastWrite.cast<OpResult>(),
                                              *this);
        return true;
      }))
    return resultBuffer;
  // Do not copy if the copied data is never read.
  OpResult aliasingOpResult = getAliasingOpResult(opOperand);
  if (aliasingOpResult && !bufferizesToMemoryRead(opOperand) &&
      !isValueRead(aliasingOpResult))
    return resultBuffer;
  // Do not copy if this op does not read the data, but writes it.
  if (bufferizesToMemoryWrite(opOperand) && !bufferizesToMemoryRead(opOperand))
    return resultBuffer;

  if (customCopyInsertionPoint) {
    rewriter.setInsertionPoint(*customCopyInsertionPoint);
  } else {
    // The copy happens right before the op that is bufferized.
    rewriter.setInsertionPoint(op);
  }
  if (failed(
          createMemCpy(rewriter, loc, operandBuffer, *resultBuffer, options)))
    return failure();

  return resultBuffer;
}

void bufferization::replaceOpWithBufferizedValues(RewriterBase &rewriter,
                                                  Operation *op,
                                                  ValueRange values) {
  OpBuilder::InsertionGuard g(rewriter);

  // Replace all OpResults with the given values.
  for (OpResult opResult : op->getOpResults()) {
    // Skip this OpResult if it has no uses.
    if (opResult.getUses().empty())
      continue;

    Value replacement = values[opResult.getResultNumber()];
    if (opResult.getType().isa<TensorType>()) {
      // The OpResult is a tensor. Such values are replaced with memrefs during
      // bufferization.
      assert((replacement.getType().isa<MemRefType>() ||
              replacement.getType().isa<UnrankedMemRefType>()) &&
             "tensor op result should be replaced with a memref value");
      // The existing uses of the OpResult still expect a tensor. Insert a
      // ToTensorOp. Throughout bufferization, this ToTensorOp will gradually
      // lose all of its users and eventually DCE away.
      setInsertionPointAfter(rewriter, replacement);
      replacement = rewriter.create<bufferization::ToTensorOp>(
          replacement.getLoc(), replacement);
    }
    opResult.replaceAllUsesWith(replacement);
  }

  rewriter.eraseOp(op);
}
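// Example (illustrative only): a BufferizableOpInterface::bufferize
// implementation for a tensor op typically computes the new memref values and
// finishes with
//   replaceOpWithBufferizedValues(rewriter, op, newMemrefs);
// Remaining tensor users of the op's results then read through the inserted
// to_tensor ops until those users are bufferized themselves.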
//===----------------------------------------------------------------------===//
// Bufferization-specific scoped alloc/dealloc insertion support.
//===----------------------------------------------------------------------===//

/// Move the insertion point of the given builder to the beginning of a
/// surrounding block as much as possible, while not crossing any allocation
/// hoisting barriers.
static void moveInsertionPointToAllocationHoistingBarrier(OpBuilder &b) {
  Operation *op = b.getInsertionBlock()->getParentOp();
  while (op) {
    if (auto bufferizableOp = dyn_cast<BufferizableOpInterface>(op))
      if (bufferizableOp.isAllocationHoistingBarrier())
        break;
    op = op->getParentOp();
  }

  if (!op) {
    // No allocation hoisting barrier found. Hoist to the enclosing FuncOp.
    op = b.getInsertionBlock()->getParentOp();
    if (!isa<FuncOp>(op))
      op = op->getParentOfType<FuncOp>();
    assert(op && "could not find enclosing FuncOp");
  }

  // TODO: Handle cases where the allocation hoisting barrier has more than
  // one region or block.
  assert(op->getNumRegions() == 1 &&
         "allocation hoisting barriers with >1 region not supported");
  assert(op->getRegion(0).getBlocks().size() == 1 &&
         "allocation hoisting barriers with >1 block not supported");
  b.setInsertionPointToStart(&(op->getRegion(0).front()));
}

/// Compute the type of the memref to use for allocating the buffer for
/// `shapedValue`. Also return (by reference in `dynShape`) the values for the
/// dynamic dimensions in the returned memref type. The function may also set
/// the insertion point to an earlier location, where the allocation should
/// happen ("allocation hoisting").
static MemRefType getAllocationTypeAndShape(OpBuilder &b, Location loc,
                                            Value shapedValue,
                                            SmallVectorImpl<Value> &dynShape) {
  MemRefType allocMemRefType =
      getContiguousMemRefType(shapedValue.getType().cast<ShapedType>());

  // Compute the dynamic part of the shape.
  bool reifiedShapes = false;
  if (auto rankedOp = dyn_cast_or_null<ReifyRankedShapedTypeOpInterface>(
          shapedValue.getDefiningOp())) {
    ReifiedRankedShapedTypeDims resultDims;
    if (succeeded(rankedOp.reifyResultShapes(b, resultDims))) {
      reifiedShapes = true;
      OpResult resultValue = shapedValue.dyn_cast<OpResult>();
      auto &shape = resultDims[resultValue.getResultNumber()];
      for (const auto &dim : enumerate(allocMemRefType.getShape()))
        if (ShapedType::isDynamic(dim.value()))
          dynShape.push_back(shape[dim.index()]);
    }
  }

  if (!reifiedShapes) {
    for (const auto &dim : enumerate(allocMemRefType.getShape()))
      if (ShapedType::isDynamic(dim.value())) {
        assert((shapedValue.getType().isa<UnrankedMemRefType>() ||
                shapedValue.getType().isa<MemRefType>()) &&
               "expected MemRef type");
        dynShape.push_back(
            b.create<memref::DimOp>(loc, shapedValue, dim.index()));
      }
  }

  // If the buffer is statically shaped, try to hoist it to the first
  // enclosing parallel region.
  // TODO: Also hoist in the dynamic case. For now this relies on subsequent
  // calls to LICM and buffer hoisting, which will most likely not succeed.
  // TODO: When packing, allocate a static bounding box, which will enable
  // more hoisting.
  if (dynShape.empty())
    moveInsertionPointToAllocationHoistingBarrier(b);

  return allocMemRefType;
}
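// Example (illustrative only): for a statically shaped buffer requested
// inside an scf.for body, getAllocationTypeAndShape moves the insertion point
// upwards so that the allocation is emitted once, at the closest allocation
// hoisting barrier (or in the entry block of the enclosing FuncOp), instead
// of once per loop iteration.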
/// Create an AllocOp/DeallocOp pair, where the AllocOp is after
/// `shapedValue.getDefiningOp` (or at the top of the block in case of a
/// bbArg) and the DeallocOp is at the end of the block.
FailureOr<Value>
bufferization::createAlloc(OpBuilder &b, Location loc, Value shapedValue,
                           bool deallocMemref,
                           const BufferizationOptions &options) {
  // Take a guard before anything else.
  OpBuilder::InsertionGuard g(b);

  // 1. Create the memory allocation.
  assert(shapedValue.getType().isa<ShapedType>());
  MemRefType memRefType = shapedValue.getType().dyn_cast<MemRefType>();
  SmallVector<Value> dynShape;
  // Note: getAllocationTypeAndShape also sets the insertion point.
  MemRefType allocMemRefType =
      getAllocationTypeAndShape(b, loc, shapedValue, dynShape);
  FailureOr<Value> allocated =
      createAlloc(b, loc, allocMemRefType, dynShape, options);
  if (failed(allocated))
    return failure();
  Value casted = allocated.getValue();
  if (memRefType && memRefType != allocMemRefType) {
    assert(memref::CastOp::areCastCompatible(allocated.getValue().getType(),
                                             memRefType) &&
           "createAlloc: cast incompatible");
    casted = b.create<memref::CastOp>(loc, memRefType, allocated.getValue());
  }

  if (deallocMemref) {
    // 2. Create the memory deallocation.
    b.setInsertionPoint(
        allocated.getValue().getParentBlock()->getTerminator());
    if (failed(createDealloc(b, loc, allocated.getValue(), options)))
      return failure();
  }

  return casted;
}

/// Create a memref allocation.
FailureOr<Value>
bufferization::createAlloc(OpBuilder &b, Location loc, MemRefType type,
                           ArrayRef<Value> dynShape,
                           const BufferizationOptions &options) {
  if (options.allocationFn)
    return (*options.allocationFn)(b, loc, type, dynShape);

  // Default buffer allocation via AllocOp.
  Value allocated = b.create<memref::AllocOp>(
      loc, type, dynShape, b.getI64IntegerAttr(kBufferAlignments));
  return allocated;
}

/// Create a memref deallocation.
LogicalResult
bufferization::createDealloc(OpBuilder &b, Location loc, Value allocatedBuffer,
                             const BufferizationOptions &options) {
  if (options.deallocationFn)
    return (*options.deallocationFn)(b, loc, allocatedBuffer);

  // Default buffer deallocation via DeallocOp.
  b.create<memref::DeallocOp>(loc, allocatedBuffer);
  return success();
}

/// Create a memory copy between two memref buffers.
LogicalResult bufferization::createMemCpy(OpBuilder &b, Location loc,
                                          Value from, Value to,
                                          const BufferizationOptions &options) {
  if (options.memCpyFn)
    return (*options.memCpyFn)(b, loc, from, to);

  b.create<memref::CopyOp>(loc, from, to);
  return success();
}
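// Example (illustrative only, a hypothetical client-side sketch): the default
// memref.alloc / memref.dealloc / memref.copy lowering above can be replaced
// by populating the corresponding callbacks in BufferizationOptions, e.g.:
//   options.allocationFn = [](OpBuilder &b, Location loc, MemRefType type,
//                             ArrayRef<Value> dynShape) -> FailureOr<Value> {
//     // Allocate on the stack instead of the heap.
//     return b.create<memref::AllocaOp>(loc, type, dynShape).getResult();
//   };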
//===----------------------------------------------------------------------===//
// Bufferization-specific IR helpers.
//===----------------------------------------------------------------------===//

bool bufferization::isFunctionArgument(Value value) {
  auto bbArg = value.dyn_cast<BlockArgument>();
  if (!bbArg)
    return false;
  return isa<FuncOp>(bbArg.getOwner()->getParentOp());
}

MemRefType
bufferization::getContiguousMemRefType(ShapedType shapedType,
                                       MemRefLayoutAttrInterface layout,
                                       Attribute memorySpace) {
  return MemRefType::get(shapedType.getShape(), shapedType.getElementType(),
                         layout, memorySpace);
}

UnrankedMemRefType bufferization::getUnrankedMemRefType(Type elementType,
                                                        Attribute memorySpace) {
  return UnrankedMemRefType::get(elementType, memorySpace);
}

MemRefType bufferization::getDynamicMemRefType(RankedTensorType tensorType,
                                               unsigned addressSpace) {
  // TODO: Address space decisions to connect with the actual alloc.
  int64_t dynamicOffset = ShapedType::kDynamicStrideOrOffset;
  SmallVector<int64_t> dynamicStrides(tensorType.getRank(),
                                      ShapedType::kDynamicStrideOrOffset);
  AffineMap stridedLayout = makeStridedLinearLayoutMap(
      dynamicStrides, dynamicOffset, tensorType.getContext());
  return MemRefType::get(tensorType.getShape(), tensorType.getElementType(),
                         stridedLayout, addressSpace);
}
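// Example (illustrative only): getDynamicMemRefType applied to a
// tensor<4x?xf32> yields a memref<4x?xf32> with a strided layout map whose
// offset and strides are all dynamic, so that buffers of any conforming
// layout can be bound to it at runtime.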