//===- BufferizableOpInterfaceImpl.cpp - Impl. of BufferizableOpInterface -===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//

#include "mlir/Dialect/Linalg/Transforms/BufferizableOpInterfaceImpl.h"
#include "mlir/Dialect/Bufferization/IR/BufferizableOpInterface.h"
#include "mlir/Dialect/Bufferization/IR/Bufferization.h"
#include "mlir/Dialect/Linalg/IR/Linalg.h"
#include "mlir/Dialect/Tensor/IR/Tensor.h"
#include "mlir/IR/Dialect.h"
#include "mlir/IR/Dominance.h"
#include "mlir/IR/Operation.h"

using namespace mlir;
using namespace linalg;
using namespace mlir::bufferization;

namespace {

// TODO: Ops in the linalg dialect can directly implement this interface.

/// Generic conversion for any LinalgOp on tensors.
static LogicalResult bufferizeLinalgOp(RewriterBase &rewriter, LinalgOp op,
                                       const BufferizationState &state) {
  // Take a guard before anything else.
  OpBuilder::InsertionGuard g(rewriter);
  rewriter.setInsertionPoint(op);

  // Nothing to do. This op is already bufferized.
  if (op.hasBufferSemantics())
    return success();

  // Ensure op has only tensors. Allow mixed tensor-buffer mode on a per-need
  // basis.
  if (!op.hasTensorSemantics())
    return op->emitError() << "op does not have tensor semantics";

  // New input operands for the cloned op.
  SmallVector<Value> newInputBuffers;
  newInputBuffers.reserve(op.getNumInputs());
  for (OpOperand *opOperand : op.getInputOperands()) {
    if (op.isScalar(opOperand)) {
      newInputBuffers.push_back(opOperand->get());
      continue;
    }
    // Input operands are never written to.
    newInputBuffers.push_back(
        *state.getBuffer(rewriter, *opOperand, /*forceInPlace=*/true));
  }

  // New output operands for the cloned op.
  SmallVector<Value> newOutputBuffers;
  for (OpResult opResult : op->getOpResults()) {
    SmallVector<OpOperand *> aliasingOpOperands =
        state.getAliasingOpOperand(opResult);
    assert(aliasingOpOperands.size() == 1 && "expected 1 OpOperand");
    FailureOr<Value> resultBuffer =
        state.getBuffer(rewriter, *aliasingOpOperands.front());
    if (failed(resultBuffer))
      return failure();
    newOutputBuffers.push_back(*resultBuffer);
  }

  // Merge input/output operands.
  SmallVector<Value> newOperands = newInputBuffers;
  newOperands.append(newOutputBuffers.begin(), newOutputBuffers.end());

  // Set insertion point now that potential alloc/dealloc are introduced.
  rewriter.setInsertionPoint(op);
  // Clone the op, but use the new operands. Move the existing block into the
  // new op. Since the new op does not have any tensor results, it does not
  // return anything.
  assert(op->getNumRegions() == 1 && "expected that op has 1 region");
  auto newOp = cast<LinalgOp>(op.cloneWithoutRegions(
      rewriter, op.getLoc(), /*resultTypes=*/TypeRange{}, newOperands));
  rewriter.inlineRegionBefore(op->getRegion(0), newOp->getRegion(0),
                              newOp->getRegion(0).begin());

  // Replace the results of the old op with the new output buffers.
  replaceOpWithBufferizedValues(rewriter, op, newOutputBuffers);

  return success();
}

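// Illustration (a sketch, not verbatim output; exact memref layout maps depend
// on the BufferizationOptions in use): a tensor-based op such as
//
//   %r = linalg.generic ... ins(%A : tensor<?xf32>)
//                           outs(%B : tensor<?xf32>) {...} -> tensor<?xf32>
//
// is rewritten by bufferizeLinalgOp into a memref-based op with no results:
//
//   linalg.generic ... ins(%a : memref<?xf32, #map>)
//                      outs(%b : memref<?xf32, #map>) {...}
//
// Remaining tensor uses of %r are then rewired to the output buffer (wrapped
// in bufferization.to_tensor where needed) by replaceOpWithBufferizedValues.
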
/// Linalg OpResults usually bufferize inplace with their tied (output)
/// OpOperands. However, if an output OpOperand is not used in the computation,
/// it is better to bufferize inplace with an actually used input OpOperand;
/// less memory will be touched that way.
///
/// Example:
/// O(i, j) = A(i, j) + B(j) --> bufferizes inplace to: A(i, j) += B(j)
///
/// O(i, j) = A(j, i) + B(j) --> cannot bufferize inplace with A because
/// indexing maps are not identical
///
/// O(i, j) += A(i, j) + B(j) --> Output is used in computation.
/// This could bufferize inplace with A:
/// A(i, j) += O(i, j) + B(j)
/// However, we choose to bufferize inplace with O here, as there is no clear
/// benefit of choosing A. TODO: We may want to consider both options and make
/// an informed decision during analysis in the future.
static DenseMap<OpOperand *, OpResult> computeAliasingPairs(LinalgOp op) {
  DenseMap<OpOperand *, OpResult> mapping;
  for (OpResult opResult : op->getOpResults()) {
    OpOperand *tiedOperand =
        op.getOutputTensorOperands()[opResult.getResultNumber()];
    AffineMap outputIndexingMap = op.getTiedIndexingMap(tiedOperand);
    bool onlyParallelIterators = op.getNumParallelLoops() == op.getNumLoops();
    bool tiedOperandUsed = op.payloadUsesValueFromOperand(tiedOperand);

    // If the output arg is used in the computation or at least one iterator is
    // not parallel, try to bufferize inplace with the corresponding output
    // tensor.
    if (tiedOperandUsed || !onlyParallelIterators) {
      mapping[tiedOperand] = opResult;
      continue;
    }

    // Otherwise, try to bufferize inplace with one of the inputs.
    OpOperand *chosenOperand = nullptr;
    for (OpOperand *opOperand : op.getInputTensorOperands()) {
      if (opOperand->get().getType() != opResult.getType())
        continue;
      if (!op.payloadUsesValueFromOperand(opOperand))
        continue;
      if (op.getTiedIndexingMap(opOperand) != outputIndexingMap)
        continue;
      // Skip if this OpOperand was already chosen to alias with another
      // OpResult.
      if (mapping.count(opOperand))
        continue;
      assert(op.getTiedIndexingMap(opOperand).isProjectedPermutation() &&
             "expected projected permutation");
      chosenOperand = opOperand;
      break;
    }

    // No suitable input tensor found. Use output tensor.
    // TODO: This operand could bufferize inplace with OpOperands that have the
    // correct type, even if they are not used inside the computation.
    if (!chosenOperand)
      chosenOperand = tiedOperand;

    mapping[chosenOperand] = opResult;
  }
  return mapping;
}

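// For instance, in the first example above ("O(i, j) = A(i, j) + B(j)"), the
// output tensor is unused in the payload and all iterators are parallel, so
// the returned map would pair A's OpOperand with the OpResult; the result then
// bufferizes in place with A's buffer rather than with the output tensor.
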
/// Bufferization of linalg ops (e.g., linalg.generic). Replace with a new op
/// that operates entirely on memrefs.
template <typename OpTy>
struct LinalgOpInterface
    : public BufferizableOpInterface::ExternalModel<LinalgOpInterface<OpTy>,
                                                    OpTy> {
  bool bufferizesToMemoryRead(Operation *op, OpOperand &opOperand,
                              const BufferizationState &state) const {
    // Operand is read if it is used in the computation.
    auto genericOp = cast<linalg::LinalgOp>(op);
    return genericOp.payloadUsesValueFromOperand(&opOperand);
  }

  bool bufferizesToMemoryWrite(Operation *op, OpOperand &opOperand,
                               const BufferizationState &state) const {
    // Operand is written to if it has an aliasing OpResult. For more details,
    // see `computeAliasingPairs`.
    auto bufferizableOp = cast<BufferizableOpInterface>(op);
    return !bufferizableOp.getAliasingOpResult(opOperand, state).empty();
  }

  SmallVector<OpOperand *>
  getAliasingOpOperand(Operation *op, OpResult opResult,
                       const BufferizationState &state) const {
    auto genericOp = cast<linalg::LinalgOp>(op);

    // Aliasing OpOperand/OpResult pairs are computed by `computeAliasingPairs`.
    DenseMap<OpOperand *, OpResult> pairs = computeAliasingPairs(genericOp);
    for (OpOperand *opOperand : genericOp.getInputAndOutputOperands())
      if (pairs[opOperand] == opResult)
        return {opOperand};
    return {};
  }

  SmallVector<OpResult>
  getAliasingOpResult(Operation *op, OpOperand &opOperand,
                      const BufferizationState &state) const {
    auto genericOp = cast<linalg::LinalgOp>(op);

    // Aliasing OpOperand/OpResult pairs are computed by `computeAliasingPairs`.
    DenseMap<OpOperand *, OpResult> pairs = computeAliasingPairs(genericOp);
    if (!pairs.count(&opOperand))
      return {};
    return {pairs[&opOperand]};
  }

  BufferRelation bufferRelation(Operation *op, OpResult opResult,
                                const BufferizationState &state) const {
    return BufferRelation::Equivalent;
  }

  LogicalResult bufferize(Operation *op, RewriterBase &rewriter,
                          const BufferizationState &state) const {
    return bufferizeLinalgOp(rewriter, cast<LinalgOp>(op), state);
  }
};

/// Bufferization of linalg.init_tensor. Replace with a buffer allocation.
struct InitTensorOpInterface
    : public BufferizableOpInterface::ExternalModel<InitTensorOpInterface,
                                                    linalg::InitTensorOp> {
  bool isMemoryWrite(Operation *op, OpResult opResult,
                     const BufferizationState &state) const {
    // InitTensorOps allocate but do not write.
    return false;
  }

  LogicalResult bufferize(Operation *op, RewriterBase &rewriter,
                          const BufferizationState &state) const {
    auto initTensorOp = cast<linalg::InitTensorOp>(op);

    // The InitTensorOp may have been eliminated.
    if (initTensorOp->getUses().empty())
      return success();

    FailureOr<Value> alloc =
        createAlloc(rewriter, initTensorOp->getLoc(), initTensorOp.result(),
                    state.getOptions().createDeallocs, state.getOptions());
    if (failed(alloc))
      return failure();
    replaceOpWithBufferizedValues(rewriter, op, *alloc);
    return success();
  }
};

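// Sketch of the effect (the exact allocation op and whether a dealloc is
// emitted depend on the BufferizationOptions):
//
//   %0 = linalg.init_tensor [%d] : tensor<?xf32>
//
// becomes, roughly,
//
//   %0 = memref.alloc(%d) : memref<?xf32>
//
// If InitTensorOp elimination (see below) already removed all uses, no
// allocation is created at all.
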
/// Bufferization of linalg.tiled_loop. Replace with a new linalg.tiled_loop
/// that operates entirely on memrefs.
struct TiledLoopOpInterface
    : public BufferizableOpInterface::ExternalModel<TiledLoopOpInterface,
                                                    linalg::TiledLoopOp> {
  bool bufferizesToMemoryRead(Operation *op, OpOperand &opOperand,
                              const BufferizationState &state) const {
    auto tiledLoopOp = cast<linalg::TiledLoopOp>(op);

    // linalg.tiled_loop operands alone do not bufferize to a memory read, but
    // one of the uses of their matching bbArgs may.
    return state.isValueRead(tiledLoopOp.getTiedBlockArgument(opOperand));
  }

  bool bufferizesToMemoryWrite(Operation *op, OpOperand &opOperand,
                               const BufferizationState &state) const {
    auto bufferizableOp = cast<BufferizableOpInterface>(op);

    // Only operands with an aliasing OpResult (i.e., output operands)
    // bufferize to a memory write.
    return !bufferizableOp.getAliasingOpResult(opOperand, state).empty();
  }

  SmallVector<OpResult>
  getAliasingOpResult(Operation *op, OpOperand &opOperand,
                      const BufferizationState &state) const {
    auto tiledLoopOp = cast<linalg::TiledLoopOp>(op);

    // Output operands are tied to their corresponding OpResults.
    OpResult opResult = tiledLoopOp.getTiedOpResult(opOperand);
    if (!opResult)
      return {};
    return {opResult};
  }

  BufferRelation bufferRelation(Operation *op, OpResult opResult,
                                const BufferizationState &state) const {
    return BufferRelation::Equivalent;
  }

  bool isWritable(Operation *op, Value value,
                  const BufferizationState &state) const {
    // Interestingly, linalg::TiledLoopOp's bbArgs can **always** be viewed
    // inplace from the perspective of nested ops:
    // 1. Either the matching iter operand is not bufferized inplace and an
    //    alloc + optional copy makes the bbArg itself inplaceable.
    // 2. Or the matching iter operand is bufferized inplace and bbArg just
    //    bufferizes to that too.
    return true;
  }

  bool isAllocationHoistingBarrier(Operation *op) const { return true; }

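  // High-level summary of the bufferize() implementation below (no additional
  // behavior, just an overview of the steps):
  //  1. Rewrite all non-control operands: tensor operands are replaced with
  //     their buffers; other operands are kept as-is.
  //  2. Create a new memref-based linalg.tiled_loop with the rewritten
  //     inputs/outputs.
  //  3. Rebuild the loop body: tensor bbArgs are reconstructed via
  //     bufferization.to_tensor so that the old body can be merged in
  //     unchanged.
  //  4. Replace the terminator: yielded tensors are copied into the
  //     corresponding output buffers (the copy folds away if everything
  //     bufferized in place) and the new linalg.yield yields nothing.
  //  5. Replace the old op's results with the output buffers.
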
  LogicalResult bufferize(Operation *op, RewriterBase &rewriter,
                          const BufferizationState &state) const {
    auto tiledLoopOp = cast<linalg::TiledLoopOp>(op);

    // Compute new inputs, outputs and results.
    SmallVector<Value> newInputs, newOutputs, newResults;
    for (unsigned i = tiledLoopOp.getNumControlOperands();
         i < tiledLoopOp->getNumOperands(); ++i) {
      OpOperand &operand = tiledLoopOp->getOpOperand(i);
      Value rewrittenValue = operand.get();
      if (rewrittenValue.getType().isa<TensorType>()) {
        FailureOr<Value> bufferOrFailure = state.getBuffer(rewriter, operand);
        if (failed(bufferOrFailure))
          return failure();
        rewrittenValue = *bufferOrFailure;
      }
      if (i <
          tiledLoopOp.getNumControlOperands() + tiledLoopOp.getNumInputs()) {
        newInputs.push_back(rewrittenValue);
      } else {
        newOutputs.push_back(rewrittenValue);
        if (operand.get().getType().isa<TensorType>())
          newResults.push_back(rewrittenValue);
      }
    }

    // Create new TiledLoopOp.
    auto newTiledLoopOp = rewriter.create<TiledLoopOp>(
        tiledLoopOp.getLoc(), tiledLoopOp.lowerBound(),
        tiledLoopOp.upperBound(), tiledLoopOp.step(), newInputs, newOutputs,
        tiledLoopOp.iterator_types(), tiledLoopOp.distribution_types());

    // Remove the terminator of the newly created op (if any). The old body is
    // merged in below and brings its own terminator.
    if (!newTiledLoopOp.getBody()->empty())
      rewriter.eraseOp(newTiledLoopOp.getBody()->getTerminator());

    // Compute new loop body arguments.
    SmallVector<Value> newBlockArgs, newRegionInOutArgs, oldRegionInOutArgs;
    ValueRange newInductionVars = newTiledLoopOp.getInductionVars();
    newBlockArgs.append(newInductionVars.begin(), newInductionVars.end());

    ValueRange newRegionInArgs = newTiledLoopOp.getRegionInputArgs();
    ValueRange newRegionOutArgs = newTiledLoopOp.getRegionOutputArgs();
    newRegionInOutArgs.append(newRegionInArgs.begin(), newRegionInArgs.end());
    newRegionInOutArgs.append(newRegionOutArgs.begin(), newRegionOutArgs.end());

    ValueRange oldRegionInArgs = tiledLoopOp.getRegionInputArgs();
    ValueRange oldRegionOutArgs = tiledLoopOp.getRegionOutputArgs();
    oldRegionInOutArgs.append(oldRegionInArgs.begin(), oldRegionInArgs.end());
    oldRegionInOutArgs.append(oldRegionOutArgs.begin(), oldRegionOutArgs.end());
    assert(newRegionInArgs.size() == oldRegionInArgs.size() &&
           "expected same number of input args");
    assert(newRegionOutArgs.size() == oldRegionOutArgs.size() &&
           "expected same number of output args");

    for (auto it : llvm::zip(oldRegionInOutArgs, newRegionInOutArgs)) {
      Value oldArg = std::get<0>(it);
      Value newArg = std::get<1>(it);
      rewriter.setInsertionPointToStart(newTiledLoopOp.getBody());
      if (oldArg.getType().isa<TensorType>()) {
        newBlockArgs.push_back(rewriter.create<bufferization::ToTensorOp>(
            oldArg.getLoc(), newArg));
      } else {
        newBlockArgs.push_back(newArg);
      }
    }

    // Move old body into new loop.
    rewriter.mergeBlocks(tiledLoopOp.getBody(), newTiledLoopOp.getBody(),
                         newBlockArgs);

    // Replace previous terminator with a new one that does not yield anything.
    auto oldTerminator =
        cast<linalg::YieldOp>(newTiledLoopOp.getBody()->getTerminator());
    rewriter.setInsertionPointToEnd(newTiledLoopOp.getBody());
    auto newTerminator =
        rewriter.create<linalg::YieldOp>(oldTerminator->getLoc());

    // Copy buffer of yielded tensor to output buffer. If everything bufferized
    // inplace, this copy will fold away.
    rewriter.setInsertionPoint(newTerminator);
    for (auto it : llvm::zip(oldTerminator.values(), newOutputs)) {
      Value output = std::get<1>(it);
      Value toMemrefOp = rewriter.create<bufferization::ToMemrefOp>(
          newTerminator.getLoc(), output.getType(), std::get<0>(it));
      if (failed(createMemCpy(rewriter, newTerminator.getLoc(), toMemrefOp,
                              output, state.getOptions())))
        return failure();
    }

    // Erase old terminator.
    rewriter.eraseOp(oldTerminator);

    // Replace results and delete old op.
    replaceOpWithBufferizedValues(rewriter, op, newResults);

    return success();
  }
};

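// Rough before/after sketch of the rewrite above (types, attributes and
// control operands abbreviated; the actual memref layouts and copy op depend
// on the bufferization options):
//
//   %r = linalg.tiled_loop ... outs(%t = %init : tensor<?xf32>) {
//     ...
//     linalg.yield %yielded : tensor<?xf32>
//   }
//
// becomes approximately
//
//   linalg.tiled_loop ... outs(%buf = %init_buf : memref<?xf32>) {
//     %t0 = bufferization.to_tensor %buf : memref<?xf32>
//     ...
//     %m = bufferization.to_memref %yielded : memref<?xf32>
//     <copy %m into %buf>   // folds away if everything bufferized in place
//     linalg.yield
//   }
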
/// Bufferization of linalg.yield. Bufferized as part of linalg.tiled_loop's
/// bufferization.
struct YieldOpInterface
    : public BufferizableOpInterface::ExternalModel<YieldOpInterface,
                                                    linalg::YieldOp> {
  bool bufferizesToMemoryRead(Operation *op, OpOperand &opOperand,
                              const BufferizationState &state) const {
    return true;
  }

  bool bufferizesToMemoryWrite(Operation *op, OpOperand &opOperand,
                               const BufferizationState &state) const {
    return false;
  }

  SmallVector<OpResult>
  getAliasingOpResult(Operation *op, OpOperand &opOperand,
                      const BufferizationState &state) const {
    return {};
  }

  bool mustBufferizeInPlace(Operation *op, OpOperand &opOperand,
                            const BufferizationState &state) const {
    // Yield operands always bufferize inplace. Otherwise, an alloc + copy
    // may be generated inside the block. We should not return/yield
    // allocations when possible.
    return true;
  }

  LogicalResult bufferize(Operation *op, RewriterBase &rewriter,
                          const BufferizationState &state) const {
    auto yieldOp = cast<linalg::YieldOp>(op);

    if (!yieldOp->getParentOfType<TiledLoopOp>())
      return yieldOp->emitError(
          "expected that linalg.yield terminates a tiled_loop");

    assert(yieldOp->getOpOperands().empty() &&
           "expected that linalg.yield was bufferized together with"
           " tiled_loop");
    return success();
  }
};

/// Helper structure that iterates over all LinalgOps in `OpTys` and registers
/// the `BufferizableOpInterface` with each of them.
template <typename... OpTys>
struct LinalgOpInterfaceHelper;

template <typename First, typename... Others>
struct LinalgOpInterfaceHelper<First, Others...> {
  static void registerOpInterface(DialectRegistry &registry) {
    registry.addOpInterface<First, LinalgOpInterface<First>>();
    LinalgOpInterfaceHelper<Others...>::registerOpInterface(registry);
  }
};

template <>
struct LinalgOpInterfaceHelper<> {
  static void registerOpInterface(DialectRegistry &registry) {}
};

} // namespace

/// Return true if all `neededValues` are in scope at the given
/// `insertionPoint`.
static bool
neededValuesDominateInsertionPoint(const DominanceInfo &domInfo,
                                   Operation *insertionPoint,
                                   const SmallVector<Value> &neededValues) {
  for (Value val : neededValues) {
    if (auto bbArg = val.dyn_cast<BlockArgument>()) {
      Block *owner = bbArg.getOwner();
      if (!owner->findAncestorOpInBlock(*insertionPoint))
        return false;
    } else {
      auto opResult = val.cast<OpResult>();
      if (!domInfo.dominates(opResult.getOwner(), insertionPoint))
        return false;
    }
  }
  return true;
}

/// Return true if the given `insertionPoint` dominates all uses of
/// `initTensorOp`.
static bool insertionPointDominatesUses(const DominanceInfo &domInfo,
                                        Operation *insertionPoint,
                                        Operation *initTensorOp) {
  for (Operation *user : initTensorOp->getUsers())
    if (!domInfo.dominates(insertionPoint, user))
      return false;
  return true;
}

/// Find a valid insertion point for a replacement of `initTensorOp`, assuming
/// that the replacement may use any value from `neededValues`.
static Operation *
findValidInsertionPoint(Operation *initTensorOp,
                        const SmallVector<Value> &neededValues) {
  DominanceInfo domInfo;

  // Gather all possible insertion points: the location of `initTensorOp` and
  // right after the definition of each value in `neededValues`.
  SmallVector<Operation *> insertionPointCandidates;
  insertionPointCandidates.push_back(initTensorOp);
  for (Value val : neededValues) {
    // Note: The anchor op is using all of `neededValues`, so:
    // * in case of a block argument: There must be at least one op in the
    //   block (the anchor op or one of its parents).
    // * in case of an OpResult: There must be at least one op right after the
    //   defining op (the anchor op or one of its parents).
    if (auto bbArg = val.dyn_cast<BlockArgument>()) {
      insertionPointCandidates.push_back(
          &bbArg.getOwner()->getOperations().front());
    } else {
      insertionPointCandidates.push_back(val.getDefiningOp()->getNextNode());
    }
  }

  // Select first matching insertion point.
  for (Operation *insertionPoint : insertionPointCandidates) {
    // Check if all needed values are in scope.
    if (!neededValuesDominateInsertionPoint(domInfo, insertionPoint,
                                            neededValues))
      continue;
    // Check if the insertion point is before all uses.
    if (!insertionPointDominatesUses(domInfo, insertionPoint, initTensorOp))
      continue;
    return insertionPoint;
  }

  // No suitable insertion point was found.
  return nullptr;
}

/// Try to eliminate InitTensorOps inside `op`. An InitTensorOp is replaced
/// with the result of `rewriteFunc` if it is anchored on a matching OpOperand.
/// "Anchored" means that there is a path on the reverse SSA use-def chain,
/// starting from the OpOperand and always following the aliasing OpOperand,
/// that eventually ends at a single InitTensorOp.
LogicalResult mlir::linalg::eliminateInitTensors(
    Operation *op, BufferizationState &state, BufferizationAliasInfo &aliasInfo,
    AnchorMatchFn anchorMatchFunc, RewriteFn rewriteFunc,
    SmallVector<Operation *> &newOps) {
  OpBuilder b(op->getContext());

  WalkResult status = op->walk([&](Operation *op) {
    for (OpOperand &operand : op->getOpOperands()) {
      // Skip operands that do not bufferize inplace.
      if (!aliasInfo.isInPlace(operand))
        continue;
      // All values that are needed to create the replacement op.
      SmallVector<Value> neededValues;
      // Is this a matching OpOperand?
      if (!anchorMatchFunc(operand, neededValues))
        continue;
      SetVector<Value> maybeInitTensor =
          state.findValueInReverseUseDefChain(operand.get(), [&](Value val) {
            // Continue traversal until this function returns true.
            OpResult opResult = val.dyn_cast<OpResult>();
            if (!opResult)
              return true;
            SmallVector<OpOperand *> opOperands =
                state.getAliasingOpOperand(opResult);
            if (!llvm::all_of(opOperands, [&](OpOperand *operand) {
                  return aliasInfo.isInPlace(*operand);
                }))
              return true;
            // Only equivalent tensors are supported at the moment.
            // TODO: Support cases such as extract_slice(init_tensor).
            return !llvm::all_of(opOperands, [&](OpOperand *operand) {
              return aliasInfo.areEquivalentBufferizedValues(operand->get(),
                                                             opResult);
            });
          });

      // Replace only if the reverse use-def chain ends at exactly one
      // InitTensorOp.
      if (maybeInitTensor.size() != 1 ||
          !maybeInitTensor.front().getDefiningOp<InitTensorOp>())
        return WalkResult::skip();
      Value initTensor = maybeInitTensor.front();

      // Find a suitable insertion point.
      Operation *insertionPoint =
          findValidInsertionPoint(initTensor.getDefiningOp(), neededValues);
      if (!insertionPoint)
        continue;

      // Create a replacement for the InitTensorOp.
      b.setInsertionPoint(insertionPoint);
      Value replacement = rewriteFunc(b, initTensor.getLoc(), operand);
      if (!replacement)
        continue;

      // Uses of the InitTensorOp are replaced here, but the op is not deleted.
      // InitTensorOps without uses are ignored by the bufferization.
      initTensor.replaceAllUsesWith(replacement);
      aliasInfo.createAliasInfoEntry(replacement);
      aliasInfo.unionAliasSets(initTensor, replacement);
      aliasInfo.unionEquivalenceClasses(initTensor, replacement);

      // Register replacement ops.
      if (Operation *newOp = replacement.getDefiningOp())
        newOps.push_back(newOp);
    }

    // Advance to the next operation.
    return WalkResult::advance();
  });

  return failure(status.wasInterrupted());
}

/// Try to eliminate InitTensorOps inside `op`. An InitTensorOp can be
/// eliminated if it is eventually inserted into another tensor (and some other
/// conditions are met).
///
/// E.g.:
/// %0 = linalg.init_tensor
/// %1 = linalg.fill(%cst, %0) {inplace = [true]}
/// %2 = tensor.insert_slice %1 into %t[10][20][1]
///
/// InitTensorOp elimination will try to fill %t inplace instead of filling a
/// new allocation %0 and inserting it into %t. This is done by replacing the
/// InitTensorOp with:
///
/// %0 = tensor.extract_slice %t[10][20][1]
///
/// The analysis looks for matching ExtractSliceOp/InsertSliceOp pairs and lets
/// those bufferize inplace in the absence of other conflicts.
///
/// Starting from an InsertSliceOp, an InitTensorOp at the end of the insert
/// source's reverse use-def chain is eliminated if:
/// * The InsertSliceOp was decided to bufferize inplace.
/// * On the reverse use-def chain path from the InsertSliceOp to the
///   InitTensorOp, all ops were decided to bufferize inplace and the buffer
///   relation is "equivalent" (TODO: can be relaxed if needed).
/// * The reverse use-def chain has exactly one end, which is the InitTensorOp.
///
/// Note that the newly inserted ExtractSliceOp may have to bufferize
/// out-of-place due to RaW conflicts.
LogicalResult mlir::linalg::insertSliceAnchoredInitTensorEliminationStep(
    Operation *op, BufferizationState &state, BufferizationAliasInfo &aliasInfo,
    SmallVector<Operation *> &newOps) {
  return eliminateInitTensors(
      op, state, aliasInfo,
      /*anchorMatchFunc=*/
      [&](OpOperand &operand, SmallVector<Value> &neededValues) {
        auto insertSliceOp =
            dyn_cast<tensor::InsertSliceOp>(operand.getOwner());
        if (!insertSliceOp)
          return false;
        // Only inplace bufferized InsertSliceOps are eligible.
        if (!aliasInfo.isInPlace(insertSliceOp->getOpOperand(1) /*dest*/))
          return false;
        if (&operand != &insertSliceOp->getOpOperand(0) /*source*/)
          return false;

        // Collect all values that are needed to construct the replacement op.
        neededValues.append(insertSliceOp.offsets().begin(),
                            insertSliceOp.offsets().end());
        neededValues.append(insertSliceOp.sizes().begin(),
                            insertSliceOp.sizes().end());
        neededValues.append(insertSliceOp.strides().begin(),
                            insertSliceOp.strides().end());
        neededValues.push_back(insertSliceOp.dest());

        return true;
      },
      /*rewriteFunc=*/
      [](OpBuilder &b, Location loc, OpOperand &operand) {
        auto insertOp = cast<tensor::InsertSliceOp>(operand.getOwner());
        // Expand offsets, sizes and strides to the full rank to handle the
        // rank-reducing case.
        SmallVector<OpFoldResult> mixedOffsets = insertOp.getMixedOffsets();
        SmallVector<OpFoldResult> mixedSizes = insertOp.getMixedSizes();
        SmallVector<OpFoldResult> mixedStrides = insertOp.getMixedStrides();
        OffsetSizeAndStrideOpInterface::expandToRank(
            insertOp.dest(), mixedOffsets, mixedSizes, mixedStrides,
            [&](Value target, int64_t dim) -> OpFoldResult {
              auto shapedType = target.getType().cast<ShapedType>();
              if (shapedType.isDynamicDim(dim))
                return b.create<tensor::DimOp>(loc, target, dim).result();
              return b.getIndexAttr(shapedType.getDimSize(dim));
            });
        auto t = tensor::ExtractSliceOp::inferRankReducedResultType(
            insertOp.getSourceType().getRank(),
            insertOp.dest().getType().cast<RankedTensorType>(), mixedOffsets,
            mixedSizes, mixedStrides);
        auto extractOp = b.create<tensor::ExtractSliceOp>(
            loc, t, insertOp.dest(), mixedOffsets, mixedSizes, mixedStrides);
        return extractOp.result();
      },
      newOps);
}

void mlir::linalg::registerBufferizableOpInterfaceExternalModels(
    DialectRegistry &registry) {
  registry.addOpInterface<linalg::InitTensorOp, InitTensorOpInterface>();
  registry.addOpInterface<linalg::TiledLoopOp, TiledLoopOpInterface>();
  registry.addOpInterface<linalg::YieldOp, YieldOpInterface>();

  // Register all Linalg structured ops. `LinalgOp` is an interface and it is
  // not possible to attach an external interface to an existing interface.
  // Therefore, attach the `BufferizableOpInterface` to all ops one-by-one.
  LinalgOpInterfaceHelper<
#define GET_OP_LIST
#include "mlir/Dialect/Linalg/IR/LinalgStructuredOps.cpp.inc"
      >::registerOpInterface(registry);
}
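
// Typical usage (a sketch; the surrounding pass/pipeline setup is assumed and
// not part of this file): attach the external models to a DialectRegistry
// before running bufferization, e.g.
//
//   DialectRegistry registry;
//   registry.insert<linalg::LinalgDialect>();
//   linalg::registerBufferizableOpInterfaceExternalModels(registry);
//   context.appendDialectRegistry(registry);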