//===- BufferizableOpInterfaceImpl.cpp - Impl. of BufferizableOpInterface -===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//

#include "mlir/Dialect/Linalg/Transforms/BufferizableOpInterfaceImpl.h"
#include "mlir/Dialect/Bufferization/IR/BufferizableOpInterface.h"
#include "mlir/Dialect/Bufferization/IR/Bufferization.h"
#include "mlir/Dialect/Linalg/IR/Linalg.h"
#include "mlir/Dialect/Tensor/IR/Tensor.h"
#include "mlir/IR/Dialect.h"
#include "mlir/IR/Dominance.h"
#include "mlir/IR/Operation.h"

using namespace mlir;
using namespace linalg;
using namespace mlir::bufferization;

namespace {

// TODO: Ops in the linalg dialect can directly implement this interface.

/// Generic conversion for any LinalgOp on tensors.
static LogicalResult bufferizeLinalgOp(RewriterBase &rewriter, LinalgOp op,
                                       const BufferizationState &state) {
  // Take a guard before anything else.
  OpBuilder::InsertionGuard g(rewriter);
  rewriter.setInsertionPoint(op);

  // Nothing to do. This op is already bufferized.
  if (op.hasBufferSemantics())
    return success();

  // Ensure op has only tensors. Allow mixed tensor-buffer mode on a per-need
  // basis.
  if (!op.hasTensorSemantics())
    return op->emitError() << "op does not have tensor semantics";

  // New input operands for the cloned op.
  SmallVector<Value> newInputBuffers;
  newInputBuffers.reserve(op.getNumInputs());
  for (OpOperand *opOperand : op.getInputOperands()) {
    if (op.isScalar(opOperand)) {
      newInputBuffers.push_back(opOperand->get());
      continue;
    }
    // Input operands are never written to.
    newInputBuffers.push_back(
        *state.getBuffer(rewriter, *opOperand, /*forceInPlace=*/true));
  }

  // New output operands for the cloned op.
  SmallVector<Value> newOutputBuffers;
  for (OpResult opResult : op->getOpResults()) {
    SmallVector<OpOperand *> aliasingOpOperands =
        state.getAliasingOpOperand(opResult);
    assert(aliasingOpOperands.size() == 1 && "expected 1 OpOperand");
    FailureOr<Value> resultBuffer =
        state.getBuffer(rewriter, *aliasingOpOperands.front());
    if (failed(resultBuffer))
      return failure();
    newOutputBuffers.push_back(*resultBuffer);
  }

  // Merge input/output operands.
  SmallVector<Value> newOperands = newInputBuffers;
  newOperands.append(newOutputBuffers.begin(), newOutputBuffers.end());

  // Set insertion point now that potential alloc/dealloc are introduced.
  rewriter.setInsertionPoint(op);
  // Clone the op, but use the new operands. Move the existing block into the
  // new op. Since the new op does not have any tensor results, it does not
  // return anything.
  assert(op->getNumRegions() == 1 && "expected that op has 1 region");
  auto newOp = cast<LinalgOp>(op.cloneWithoutRegions(
      rewriter, op.getLoc(), /*resultTypes=*/TypeRange{}, newOperands));
  rewriter.inlineRegionBefore(op->getRegion(0), newOp->getRegion(0),
                              newOp->getRegion(0).begin());

  // Replace the results of the old op with the new output buffers.
  replaceOpWithBufferizedValues(rewriter, op, newOutputBuffers);

  return success();
}

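// Example (illustrative only; indexing maps and attributes are abbreviated and
// the SSA names are made up): a LinalgOp with tensor semantics such as
//
//   %r = linalg.generic {indexing_maps = [...], iterator_types = ["parallel"]}
//            ins(%A : tensor<?xf32>) outs(%O : tensor<?xf32>) {
//     ^bb0(%a: f32, %o: f32):
//       %s = arith.addf %a, %a : f32
//       linalg.yield %s : f32
//   } -> tensor<?xf32>
//
// is rewritten by bufferizeLinalgOp into the same op on buffers, which has no
// results; %r is then replaced with the output buffer (wrapped in a
// bufferization.to_tensor op where tensor uses remain):
//
//   linalg.generic {indexing_maps = [...], iterator_types = ["parallel"]}
//       ins(%A_buf : memref<?xf32>) outs(%O_buf : memref<?xf32>) {
//     ^bb0(%a: f32, %o: f32):
//       %s = arith.addf %a, %a : f32
//       linalg.yield %s : f32
//   }
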
/// Linalg OpResults usually bufferize inplace with their tied (output)
/// OpOperands. However, if an output OpOperand is not used in the computation,
/// it is better to bufferize inplace with an actually used input OpOperand;
/// less memory will be touched that way.
///
/// Example:
/// O(i, j) = A(i, j) + B(j) --> bufferizes inplace to: A(i, j) += B(j)
///
/// O(i, j) = A(j, i) + B(j) --> cannot bufferize inplace with A because
///                              indexing maps are not identical
///
/// O(i, j) += A(i, j) + B(j) --> Output is used in computation.
///                               This could bufferize inplace with A:
///                               A(i, j) += O(i, j) + B(j)
/// However, we choose to bufferize inplace with O here, as there is no clear
/// benefit of choosing A. TODO: We may want to consider both options and make
/// an informed decision during analysis in the future.
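///
/// For the first example above (assuming that all iterators are parallel, that
/// A has the same tensor type and indexing map as O, and that O is not used in
/// the payload), the computed mapping would contain the single entry
/// {A -> OpResult #0}: the result is chosen to alias the input A instead of
/// the unused output O.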
static DenseMap<OpOperand *, OpResult> computeAliasingPairs(LinalgOp op) {
  DenseMap<OpOperand *, OpResult> mapping;
  for (OpResult opResult : op->getOpResults()) {
    OpOperand *tiedOperand =
        op.getOutputTensorOperands()[opResult.getResultNumber()];
    AffineMap outputIndexingMap = op.getTiedIndexingMap(tiedOperand);
    bool onlyParallelIterators = op.getNumParallelLoops() == op.getNumLoops();
    bool tiedOperandUsed = op.payloadUsesValueFromOperand(tiedOperand);

    // If the output arg is used in the computation or at least one iterator is
    // not parallel, try to bufferize inplace with the corresponding output
    // tensor.
    if (tiedOperandUsed || !onlyParallelIterators) {
      mapping[tiedOperand] = opResult;
      continue;
    }

    // Otherwise, try to bufferize inplace with one of the inputs.
    OpOperand *chosenOperand = nullptr;
    for (OpOperand *opOperand : op.getInputTensorOperands()) {
      if (opOperand->get().getType() != opResult.getType())
        continue;
      if (!op.payloadUsesValueFromOperand(opOperand))
        continue;
      if (op.getTiedIndexingMap(opOperand) != outputIndexingMap)
        continue;
      // Make sure that no other OpResult was already mapped to alias with this
      // OpOperand.
      if (mapping.count(opOperand))
        continue;
      assert(op.getTiedIndexingMap(opOperand).isProjectedPermutation() &&
             "expected projected permutation");
      chosenOperand = opOperand;
      break;
    }

    // No suitable input tensor found. Use output tensor.
    // TODO: This operand could bufferize inplace with OpOperands that have the
    // correct type, even if they are not used inside the computation.
    if (!chosenOperand)
      chosenOperand = tiedOperand;

    mapping[chosenOperand] = opResult;
  }
  return mapping;
}

/// Bufferization of linalg.generic. Replace with a new linalg.generic that
/// operates entirely on memrefs.
template <typename OpTy>
struct LinalgOpInterface
    : public BufferizableOpInterface::ExternalModel<LinalgOpInterface<OpTy>,
                                                    OpTy> {
  bool bufferizesToMemoryRead(Operation *op, OpOperand &opOperand,
                              const BufferizationState &state) const {
    // Operand is read if it is used in the computation.
    auto genericOp = cast<linalg::LinalgOp>(op);
    return genericOp.payloadUsesValueFromOperand(&opOperand);
  }

  bool bufferizesToMemoryWrite(Operation *op, OpOperand &opOperand,
                               const BufferizationState &state) const {
    // Operand is written to if it has an aliasing OpResult.
    auto bufferizableOp = cast<BufferizableOpInterface>(op);
    return !bufferizableOp.getAliasingOpResult(opOperand, state).empty();
  }

  SmallVector<OpOperand *>
  getAliasingOpOperand(Operation *op, OpResult opResult,
                       const BufferizationState &state) const {
    auto genericOp = cast<linalg::LinalgOp>(op);

    // By default, the i-th OpResult may alias with the i-th "out" tensor.
    if (state.getOptions().alwaysAliasingWithDest)
      return {genericOp.getOutputOperand(opResult.getResultNumber())};

    // We can try to be smart and alias in-place with an "in" tensor if the
    // corresponding "out" tensor is not used in the computation.
    // Aliasing OpOperand/OpResult pairs are computed by `computeAliasingPairs`.
    DenseMap<OpOperand *, OpResult> pairs = computeAliasingPairs(genericOp);
    for (OpOperand *opOperand : genericOp.getInputAndOutputOperands())
      if (pairs[opOperand] == opResult)
        return {opOperand};
    return {};
  }

  SmallVector<OpResult>
  getAliasingOpResult(Operation *op, OpOperand &opOperand,
                      const BufferizationState &state) const {
    auto genericOp = cast<linalg::LinalgOp>(op);

    // By default, the i-th "out" tensor may alias with the i-th OpResult.
    if (state.getOptions().alwaysAliasingWithDest) {
      if (genericOp.isOutputTensor(&opOperand))
        return {genericOp.getTiedOpResult(&opOperand)};
      return {};
    }

    // We can try to be smart. See comment in `getAliasingOpOperand`.
    // Aliasing OpOperand/OpResult pairs are computed by `computeAliasingPairs`.
    DenseMap<OpOperand *, OpResult> pairs = computeAliasingPairs(genericOp);
    if (!pairs.count(&opOperand))
      return {};
    return {pairs[&opOperand]};
  }

  BufferRelation bufferRelation(Operation *op, OpResult opResult,
                                const BufferizationState &state) const {
    return BufferRelation::Equivalent;
  }

  LogicalResult bufferize(Operation *op, RewriterBase &rewriter,
                          const BufferizationState &state) const {
    return bufferizeLinalgOp(rewriter, cast<LinalgOp>(op), state);
  }
};

struct InitTensorOpInterface
    : public BufferizableOpInterface::ExternalModel<InitTensorOpInterface,
                                                    linalg::InitTensorOp> {
  bool isMemoryWrite(Operation *op, OpResult opResult,
                     const BufferizationState &state) const {
    // InitTensorOps allocate but do not write.
    return false;
  }

  LogicalResult bufferize(Operation *op, RewriterBase &rewriter,
                          const BufferizationState &state) const {
    auto initTensorOp = cast<linalg::InitTensorOp>(op);

    // The InitTensorOp may have been eliminated.
    if (initTensorOp->getUses().empty())
      return success();

    FailureOr<Value> alloc =
        createAlloc(rewriter, initTensorOp->getLoc(), initTensorOp.result(),
                    state.getOptions().createDeallocs, state.getOptions());
    if (failed(alloc))
      return failure();
    replaceOpWithBufferizedValues(rewriter, op, *alloc);
    return success();
  }
};

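// Example (illustrative only; the exact allocation depends on the
// bufferization options, e.g. whether deallocs are created):
//
//   %0 = linalg.init_tensor [%sz] : tensor<?xf32>
//
// bufferizes to a new allocation of the same shape, roughly:
//
//   %0 = memref.alloc(%sz) : memref<?xf32>
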
/// Bufferization of linalg.tiled_loop. Replace with a new linalg.tiled_loop
/// that operates entirely on memrefs.
struct TiledLoopOpInterface
    : public BufferizableOpInterface::ExternalModel<TiledLoopOpInterface,
                                                    linalg::TiledLoopOp> {
  bool bufferizesToMemoryRead(Operation *op, OpOperand &opOperand,
                              const BufferizationState &state) const {
    auto tiledLoopOp = cast<linalg::TiledLoopOp>(op);

    // linalg.tiled_loop operands alone do not bufferize to a memory read, but
    // one of the uses of their matching bbArgs may.
    return state.isValueRead(tiledLoopOp.getTiedBlockArgument(opOperand));
  }

  bool bufferizesToMemoryWrite(Operation *op, OpOperand &opOperand,
                               const BufferizationState &state) const {
    auto bufferizableOp = cast<BufferizableOpInterface>(op);

    // Only operands with an aliasing OpResult (i.e., output operands)
    // bufferize to a memory write.
    return !bufferizableOp.getAliasingOpResult(opOperand, state).empty();
  }

  SmallVector<OpResult>
  getAliasingOpResult(Operation *op, OpOperand &opOperand,
                      const BufferizationState &state) const {
    auto tiledLoopOp = cast<linalg::TiledLoopOp>(op);

    // Output operands are tied to their corresponding OpResults.
    OpResult opResult = tiledLoopOp.getTiedOpResult(opOperand);
    if (!opResult)
      return {};
    return {opResult};
  }

  BufferRelation bufferRelation(Operation *op, OpResult opResult,
                                const BufferizationState &state) const {
    return BufferRelation::Equivalent;
  }

  bool isWritable(Operation *op, Value value,
                  const BufferizationState &state) const {
    // Interestingly, linalg::TiledLoopOp's bbArgs can **always** be viewed
    // inplace from the perspective of nested ops:
    // 1. Either the matching iter operand is not bufferized inplace and an
    //    alloc + optional copy makes the bbArg itself inplaceable.
    // 2. Or the matching iter operand is bufferized inplace and bbArg just
    //    bufferizes to that too.
    return true;
  }

  bool isAllocationHoistingBarrier(Operation *op) const { return true; }

  LogicalResult bufferize(Operation *op, RewriterBase &rewriter,
                          const BufferizationState &state) const {
    auto tiledLoopOp = cast<linalg::TiledLoopOp>(op);

    // Compute new inputs, outputs and results.
    SmallVector<Value> newInputs, newOutputs, newResults;
    for (unsigned i = tiledLoopOp.getNumControlOperands();
         i < tiledLoopOp->getNumOperands(); ++i) {
      OpOperand &operand = tiledLoopOp->getOpOperand(i);
      Value rewrittenValue = operand.get();
      if (rewrittenValue.getType().isa<TensorType>()) {
        FailureOr<Value> bufferOrFailure = state.getBuffer(rewriter, operand);
        if (failed(bufferOrFailure))
          return failure();
        rewrittenValue = *bufferOrFailure;
      }
      if (i <
          tiledLoopOp.getNumControlOperands() + tiledLoopOp.getNumInputs()) {
        newInputs.push_back(rewrittenValue);
      } else {
        newOutputs.push_back(rewrittenValue);
        if (operand.get().getType().isa<TensorType>())
          newResults.push_back(rewrittenValue);
      }
    }

    // Create new TiledLoopOp.
    auto newTiledLoopOp = rewriter.create<TiledLoopOp>(
        tiledLoopOp.getLoc(), tiledLoopOp.lowerBound(),
        tiledLoopOp.upperBound(), tiledLoopOp.step(), newInputs, newOutputs,
        tiledLoopOp.iterator_types(), tiledLoopOp.distribution_types());

    // Remove the auto-generated terminator of the new loop; the old body
    // (including its terminator) is merged in below.
    if (!newTiledLoopOp.getBody()->empty())
      rewriter.eraseOp(newTiledLoopOp.getBody()->getTerminator());

    // Compute new loop body arguments.
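    // For each region input/output argument of the old loop that has tensor
    // type, the matching new bbArg is a memref; wrap it in a
    // bufferization.to_tensor op so that the old tensor-based body can be
    // merged in unchanged. Induction variables and non-tensor arguments are
    // passed through directly.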
    SmallVector<Value> newBlockArgs, newRegionInOutArgs, oldRegionInOutArgs;
    ValueRange newInductionVars = newTiledLoopOp.getInductionVars();
    newBlockArgs.append(newInductionVars.begin(), newInductionVars.end());

    ValueRange newRegionInArgs = newTiledLoopOp.getRegionInputArgs();
    ValueRange newRegionOutArgs = newTiledLoopOp.getRegionOutputArgs();
    newRegionInOutArgs.append(newRegionInArgs.begin(), newRegionInArgs.end());
    newRegionInOutArgs.append(newRegionOutArgs.begin(), newRegionOutArgs.end());

    ValueRange oldRegionInArgs = tiledLoopOp.getRegionInputArgs();
    ValueRange oldRegionOutArgs = tiledLoopOp.getRegionOutputArgs();
    oldRegionInOutArgs.append(oldRegionInArgs.begin(), oldRegionInArgs.end());
    oldRegionInOutArgs.append(oldRegionOutArgs.begin(), oldRegionOutArgs.end());
    assert(newRegionInArgs.size() == oldRegionInArgs.size() &&
           "expected same number of input args");
    assert(newRegionOutArgs.size() == oldRegionOutArgs.size() &&
           "expected same number of output args");

    for (auto it : llvm::zip(oldRegionInOutArgs, newRegionInOutArgs)) {
      Value oldArg = std::get<0>(it);
      Value newArg = std::get<1>(it);
      rewriter.setInsertionPointToStart(newTiledLoopOp.getBody());
      if (oldArg.getType().isa<TensorType>()) {
        newBlockArgs.push_back(rewriter.create<bufferization::ToTensorOp>(
            oldArg.getLoc(), newArg));
      } else {
        newBlockArgs.push_back(newArg);
      }
    }

    // Move old body into new loop.
    rewriter.mergeBlocks(tiledLoopOp.getBody(), newTiledLoopOp.getBody(),
                         newBlockArgs);

    // Replace previous terminator with a new one that does not yield anything.
    auto oldTerminator =
        cast<linalg::YieldOp>(newTiledLoopOp.getBody()->getTerminator());
    rewriter.setInsertionPointToEnd(newTiledLoopOp.getBody());
    auto newTerminator =
        rewriter.create<linalg::YieldOp>(oldTerminator->getLoc());

    // Copy buffer of yielded tensor to output buffer. If everything bufferized
    // inplace, this copy will fold away.
    rewriter.setInsertionPoint(newTerminator);
    for (auto it : llvm::zip(oldTerminator.values(), newOutputs)) {
      Value output = std::get<1>(it);
      Value toMemrefOp = rewriter.create<bufferization::ToMemrefOp>(
          newTerminator.getLoc(), output.getType(), std::get<0>(it));
      if (failed(createMemCpy(rewriter, newTerminator.getLoc(), toMemrefOp,
                              output, state.getOptions())))
        return failure();
    }

    // Erase old terminator.
    rewriter.eraseOp(oldTerminator);

    // Replace results and delete old op.
    replaceOpWithBufferizedValues(rewriter, op, newResults);

    return success();
  }
};

/// Bufferization of linalg.yield. Bufferized as part of linalg.tiled_loop's
/// bufferization.
struct YieldOpInterface
    : public BufferizableOpInterface::ExternalModel<YieldOpInterface,
                                                    linalg::YieldOp> {
  bool bufferizesToMemoryRead(Operation *op, OpOperand &opOperand,
                              const BufferizationState &state) const {
    return true;
  }

  bool bufferizesToMemoryWrite(Operation *op, OpOperand &opOperand,
                               const BufferizationState &state) const {
    return false;
  }

  SmallVector<OpResult>
  getAliasingOpResult(Operation *op, OpOperand &opOperand,
                      const BufferizationState &state) const {
    return {};
  }

  bool mustBufferizeInPlace(Operation *op, OpOperand &opOperand,
                            const BufferizationState &state) const {
    // Yield operands always bufferize inplace. Otherwise, an alloc + copy
    // may be generated inside the block; we should avoid returning/yielding
    // allocations whenever possible.
    return true;
  }

  LogicalResult bufferize(Operation *op, RewriterBase &rewriter,
                          const BufferizationState &state) const {
    auto yieldOp = cast<linalg::YieldOp>(op);

    if (!yieldOp->getParentOfType<TiledLoopOp>())
      return yieldOp->emitError(
          "expected that linalg.yield terminates a tiled_loop");

    assert(yieldOp->getOpOperands().empty() &&
           "expected that linalg.yield was bufferized together with"
           " tiled_loop");
    return success();
  }
};

/// Helper structure that iterates over all LinalgOps in `OpTys` and registers
/// the `BufferizableOpInterface` with each of them.
template <typename... OpTys>
struct LinalgOpInterfaceHelper;

template <typename First, typename... Others>
struct LinalgOpInterfaceHelper<First, Others...> {
  static void registerOpInterface(DialectRegistry &registry) {
    registry.addOpInterface<First, LinalgOpInterface<First>>();
    LinalgOpInterfaceHelper<Others...>::registerOpInterface(registry);
  }
};

template <>
struct LinalgOpInterfaceHelper<> {
  static void registerOpInterface(DialectRegistry &registry) {}
};

} // namespace

/// Return true if all `neededValues` are in scope at the given
/// `insertionPoint`.
static bool
neededValuesDominateInsertionPoint(const DominanceInfo &domInfo,
                                   Operation *insertionPoint,
                                   const SmallVector<Value> &neededValues) {
  for (Value val : neededValues) {
    if (auto bbArg = val.dyn_cast<BlockArgument>()) {
      Block *owner = bbArg.getOwner();
      if (!owner->findAncestorOpInBlock(*insertionPoint))
        return false;
    } else {
      auto opResult = val.cast<OpResult>();
      if (!domInfo.dominates(opResult.getOwner(), insertionPoint))
        return false;
    }
  }
  return true;
}

/// Return true if the given `insertionPoint` dominates all uses of
/// `initTensorOp`.
static bool insertionPointDominatesUses(const DominanceInfo &domInfo,
                                        Operation *insertionPoint,
                                        Operation *initTensorOp) {
  for (Operation *user : initTensorOp->getUsers())
    if (!domInfo.dominates(insertionPoint, user))
      return false;
  return true;
}

/// Find a valid insertion point for a replacement of `initTensorOp`, assuming
/// that the replacement may use any value from `neededValues`.
static Operation *
findValidInsertionPoint(Operation *initTensorOp,
                        const SmallVector<Value> &neededValues) {
  DominanceInfo domInfo;

  // Gather all possible insertion points: the location of `initTensorOp` and
  // right after the definition of each value in `neededValues`.
  SmallVector<Operation *> insertionPointCandidates;
  insertionPointCandidates.push_back(initTensorOp);
  for (Value val : neededValues) {
    // Note: The anchor op is using all of `neededValues`, so:
    // * in case of a block argument: There must be at least one op in the
    //   block (the anchor op or one of its parents).
    // * in case of an OpResult: There must be at least one op right after the
    //   defining op (the anchor op or one of its parents).
    if (auto bbArg = val.dyn_cast<BlockArgument>()) {
      insertionPointCandidates.push_back(
          &bbArg.getOwner()->getOperations().front());
    } else {
      insertionPointCandidates.push_back(val.getDefiningOp()->getNextNode());
    }
  }

  // Select first matching insertion point.
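  // Candidates are visited in the order in which they were collected, so the
  // original location of the InitTensorOp is preferred when it is valid.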
  for (Operation *insertionPoint : insertionPointCandidates) {
    // Check if all needed values are in scope.
    if (!neededValuesDominateInsertionPoint(domInfo, insertionPoint,
                                            neededValues))
      continue;
    // Check if the insertion point is before all uses.
    if (!insertionPointDominatesUses(domInfo, insertionPoint, initTensorOp))
      continue;
    return insertionPoint;
  }

  // No suitable insertion point was found.
  return nullptr;
}

/// Try to eliminate InitTensorOps inside `op`. An InitTensorOp is replaced
/// with the result of `rewriteFunc` if it is anchored on a matching
/// OpOperand. "Anchored" means that there is a path on the reverse SSA use-def
/// chain, starting from the OpOperand and always following the aliasing
/// OpOperand, that eventually ends at a single InitTensorOp.
LogicalResult mlir::linalg::eliminateInitTensors(
    Operation *op, BufferizationState &state, BufferizationAliasInfo &aliasInfo,
    AnchorMatchFn anchorMatchFunc, RewriteFn rewriteFunc,
    SmallVector<Operation *> &newOps) {
  OpBuilder b(op->getContext());

  WalkResult status = op->walk([&](Operation *op) {
    for (OpOperand &operand : op->getOpOperands()) {
      // Skip operands that do not bufferize inplace.
      if (!aliasInfo.isInPlace(operand))
        continue;
      // All values that are needed to create the replacement op.
      SmallVector<Value> neededValues;
      // Is this a matching OpOperand?
      if (!anchorMatchFunc(operand, neededValues))
        continue;
      SetVector<Value> maybeInitTensor =
          state.findValueInReverseUseDefChain(operand.get(), [&](Value val) {
            // Continue traversal until this function returns true.
            OpResult opResult = val.dyn_cast<OpResult>();
            if (!opResult)
              return true;
            SmallVector<OpOperand *> opOperands =
                state.getAliasingOpOperand(opResult);
            if (!llvm::all_of(opOperands, [&](OpOperand *operand) {
                  return aliasInfo.isInPlace(*operand);
                }))
              return true;
            // Only equivalent tensors are supported at the moment.
            // TODO: Support cases such as extract_slice(init_tensor)
            return !llvm::all_of(opOperands, [&](OpOperand *operand) {
              return aliasInfo.areEquivalentBufferizedValues(operand->get(),
                                                             opResult);
            });
          });

      // Replace only if the reverse use-def chain ends at exactly one
      // InitTensorOp.
      if (maybeInitTensor.size() != 1 ||
          !maybeInitTensor.front().getDefiningOp<InitTensorOp>())
        return WalkResult::skip();
      Value initTensor = maybeInitTensor.front();

      // Find a suitable insertion point.
      Operation *insertionPoint =
          findValidInsertionPoint(initTensor.getDefiningOp(), neededValues);
      if (!insertionPoint)
        continue;

      // Create a replacement for the InitTensorOp.
      b.setInsertionPoint(insertionPoint);
      Value replacement = rewriteFunc(b, initTensor.getLoc(), operand);
      if (!replacement)
        continue;

      // Uses of the InitTensorOp are replaced here, but the op is not deleted.
      // InitTensorOps without uses are ignored by the bufferization.
      initTensor.replaceAllUsesWith(replacement);
      aliasInfo.createAliasInfoEntry(replacement);
      aliasInfo.unionAliasSets(initTensor, replacement);
      aliasInfo.unionEquivalenceClasses(initTensor, replacement);

      // Register replacement ops.
      if (Operation *newOp = replacement.getDefiningOp())
        newOps.push_back(newOp);
    }

    // Advance to the next operation.
    return WalkResult::advance();
  });

  return failure(status.wasInterrupted());
}

/// Try to eliminate InitTensorOps inside `op`. An InitTensorOp can be
/// eliminated if it is eventually inserted into another tensor (and some other
/// conditions are met).
///
/// E.g.:
/// %0 = linalg.init_tensor
/// %1 = linalg.fill(%cst, %0) {inplace = [true]}
/// %2 = tensor.insert_slice %1 into %t[10][20][1]
///
/// InitTensorOp elimination will try to fill %t inplace instead of filling a
/// new allocation %0 and inserting it into %t. This is done by replacing the
/// InitTensorOp with:
///
/// %0 = tensor.extract_slice %t[10][20][1]
///
/// The analysis looks for matching ExtractSliceOp/InsertSliceOp pairs and lets
/// those bufferize inplace in the absence of other conflicts.
///
/// Starting from an InsertSliceOp, an InitTensorOp at the end of the insert
/// source's reverse use-def chain is eliminated if:
/// * The InsertSliceOp was decided to bufferize inplace.
/// * On the reverse use-def chain path from the InsertSliceOp to the
///   InitTensorOp, all ops were decided to bufferize inplace and the buffer
///   relation is "equivalent" (TODO: can be relaxed if needed).
/// * The reverse use-def chain has exactly one end, which is the InitTensorOp.
///
/// Note that the newly inserted ExtractSliceOp may have to bufferize
/// out-of-place due to RaW conflicts.
LogicalResult mlir::linalg::insertSliceAnchoredInitTensorEliminationStep(
    Operation *op, BufferizationState &state, BufferizationAliasInfo &aliasInfo,
    SmallVector<Operation *> &newOps) {
  return eliminateInitTensors(
      op, state, aliasInfo,
      /*anchorMatchFunc=*/
      [&](OpOperand &operand, SmallVector<Value> &neededValues) {
        auto insertSliceOp =
            dyn_cast<tensor::InsertSliceOp>(operand.getOwner());
        if (!insertSliceOp)
          return false;
        // Only inplace bufferized InsertSliceOps are eligible.
        if (!aliasInfo.isInPlace(insertSliceOp->getOpOperand(1) /*dest*/))
          return false;
        if (&operand != &insertSliceOp->getOpOperand(0) /*source*/)
          return false;

        // Collect all values that are needed to construct the replacement op.
        neededValues.append(insertSliceOp.offsets().begin(),
                            insertSliceOp.offsets().end());
        neededValues.append(insertSliceOp.sizes().begin(),
                            insertSliceOp.sizes().end());
        neededValues.append(insertSliceOp.strides().begin(),
                            insertSliceOp.strides().end());
        neededValues.push_back(insertSliceOp.dest());

        return true;
      },
      /*rewriteFunc=*/
      [](OpBuilder &b, Location loc, OpOperand &operand) {
        auto insertOp = cast<tensor::InsertSliceOp>(operand.getOwner());
        // Expand offsets, sizes and strides to the full rank to handle the
        // rank-reducing case.
        SmallVector<OpFoldResult> mixedOffsets = insertOp.getMixedOffsets();
        SmallVector<OpFoldResult> mixedSizes = insertOp.getMixedSizes();
        SmallVector<OpFoldResult> mixedStrides = insertOp.getMixedStrides();
        OffsetSizeAndStrideOpInterface::expandToRank(
            insertOp.dest(), mixedOffsets, mixedSizes, mixedStrides,
            [&](Value target, int64_t dim) -> OpFoldResult {
              auto shapedType = target.getType().cast<ShapedType>();
              if (shapedType.isDynamicDim(dim))
                return b.create<tensor::DimOp>(loc, target, dim).result();
              return b.getIndexAttr(shapedType.getDimSize(dim));
            });
        auto t = tensor::ExtractSliceOp::inferRankReducedResultType(
            insertOp.getSourceType().getRank(),
            insertOp.dest().getType().cast<RankedTensorType>(), mixedOffsets,
            mixedSizes, mixedStrides);
        auto extractOp = b.create<tensor::ExtractSliceOp>(
            loc, t, insertOp.dest(), mixedOffsets, mixedSizes, mixedStrides);
        return extractOp.result();
      },
      newOps);
}

void mlir::linalg::registerBufferizableOpInterfaceExternalModels(
    DialectRegistry &registry) {
  registry.addOpInterface<linalg::InitTensorOp, InitTensorOpInterface>();
  registry.addOpInterface<linalg::TiledLoopOp, TiledLoopOpInterface>();
  registry.addOpInterface<linalg::YieldOp, YieldOpInterface>();

  // Register all Linalg structured ops. `LinalgOp` is an interface and it is
  // not possible to attach an external interface to an existing interface.
  // Therefore, attach the `BufferizableOpInterface` to all ops one-by-one.
  LinalgOpInterfaceHelper<
#define GET_OP_LIST
#include "mlir/Dialect/Linalg/IR/LinalgStructuredOps.cpp.inc"
      >::registerOpInterface(registry);
}