//===- BufferizableOpInterfaceImpl.cpp - Impl. of BufferizableOpInterface -===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//

#include "mlir/Dialect/Linalg/Transforms/BufferizableOpInterfaceImpl.h"
#include "mlir/Dialect/Bufferization/IR/BufferizableOpInterface.h"
#include "mlir/Dialect/Bufferization/IR/Bufferization.h"
#include "mlir/Dialect/Linalg/IR/Linalg.h"
#include "mlir/Dialect/Tensor/IR/Tensor.h"
#include "mlir/IR/Dialect.h"
#include "mlir/IR/Dominance.h"
#include "mlir/IR/Operation.h"

using namespace mlir;
using namespace linalg;
using namespace mlir::bufferization;

namespace {

// TODO: Ops in the linalg dialect can directly implement this interface.

/// Generic conversion for any LinalgOp on tensors.
static LogicalResult bufferizeLinalgOp(RewriterBase &rewriter, LinalgOp op,
                                       const BufferizationState &state) {
  // Take a guard before anything else.
  OpBuilder::InsertionGuard g(rewriter);
  rewriter.setInsertionPoint(op);

  // Nothing to do. This op is already bufferized.
  if (op.hasBufferSemantics())
    return success();

  // Ensure op has only tensors. Allow mixed tensor-buffer mode on a per-need
  // basis.
  if (!op.hasTensorSemantics())
    return op->emitError() << "op does not have tensor semantics";

  // New input operands for the cloned op.
  SmallVector<Value> newInputBuffers;
  newInputBuffers.reserve(op.getNumInputs());
  for (OpOperand *opOperand : op.getInputOperands()) {
    if (op.isScalar(opOperand)) {
      newInputBuffers.push_back(opOperand->get());
      continue;
    }
    // Input operands are never written to.
    newInputBuffers.push_back(
        *state.getBuffer(rewriter, *opOperand, /*forceInPlace=*/true));
  }

  // New output operands for the cloned op.
  SmallVector<Value> newOutputBuffers;
  for (OpResult opResult : op->getOpResults()) {
    SmallVector<OpOperand *> aliasingOpOperands =
        state.getAliasingOpOperand(opResult);
    assert(aliasingOpOperands.size() == 1 && "expected 1 OpOperand");
    FailureOr<Value> resultBuffer =
        state.getBuffer(rewriter, *aliasingOpOperands.front());
    if (failed(resultBuffer))
      return failure();
    newOutputBuffers.push_back(*resultBuffer);
  }

  // Merge input/output operands.
  SmallVector<Value> newOperands = newInputBuffers;
  newOperands.append(newOutputBuffers.begin(), newOutputBuffers.end());

  // Set insertion point now that potential alloc/dealloc are introduced.
  rewriter.setInsertionPoint(op);
  // Clone the op, but use the new operands. Move the existing block into the
  // new op. Since the new op does not have any tensor results, it does not
  // return anything.
  assert(op->getNumRegions() == 1 && "expected that op has 1 region");
  auto newOp = cast<LinalgOp>(op.cloneWithoutRegions(
      rewriter, op.getLoc(), /*resultTypes=*/TypeRange{}, newOperands));
  rewriter.inlineRegionBefore(op->getRegion(0), newOp->getRegion(0),
                              newOp->getRegion(0).begin());

  // Replace the results of the old op with the new output buffers.
  replaceOpWithBufferizedValues(rewriter, op, newOutputBuffers);

  return success();
}

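// For intuition, a rough sketch of the rewrite performed above (the IR is
// illustrative only, with types abbreviated):
//
//   %r = linalg.generic ins(%A : tensor<?xf32>) outs(%B : tensor<?xf32>) ...
//
// becomes, after buffers %a and %b have been obtained for %A and %B:
//
//   linalg.generic ins(%a : memref<?xf32>) outs(%b : memref<?xf32>) ...
//
// and all uses of %r are replaced with (a tensor view of) %b.
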
/// Linalg OpResults usually bufferize inplace with their tied output
/// OpOperands. However, if an output OpOperand is not used in the computation,
/// it is better to bufferize inplace with an actually used input OpOperand;
/// less memory will be touched that way.
///
/// Example:
/// O(i, j) = A(i, j) + B(j) --> bufferizes inplace to: A(i, j) += B(j)
///
/// O(i, j) = A(j, i) + B(j) --> cannot bufferize inplace with A because
/// indexing maps are not identical
///
/// O(i, j) += A(i, j) + B(j) --> Output is used in computation.
/// This could bufferize inplace with A:
/// A(i, j) += O(i, j) + B(j)
/// However, we choose to bufferize inplace with O here, as there is no clear
/// benefit of choosing A. TODO: We may want to consider both options and make
/// an informed decision during analysis in the future.
static DenseMap<OpOperand *, OpResult> computeAliasingPairs(LinalgOp op) {
  DenseMap<OpOperand *, OpResult> mapping;
  for (OpResult opResult : op->getOpResults()) {
    OpOperand *tiedOperand =
        op.getOutputTensorOperands()[opResult.getResultNumber()];
    AffineMap outputIndexingMap = op.getTiedIndexingMap(tiedOperand);
    bool onlyParallelIterators = op.getNumParallelLoops() == op.getNumLoops();
    bool tiedOperandUsed = op.payloadUsesValueFromOperand(tiedOperand);

    // If the output arg is used in the computation or at least one iterator is
    // not parallel, try to bufferize inplace with the corresponding output
    // tensor.
    if (tiedOperandUsed || !onlyParallelIterators) {
      mapping[tiedOperand] = opResult;
      continue;
    }

    // Otherwise, try to bufferize inplace with one of the inputs.
    OpOperand *chosenOperand = nullptr;
    for (OpOperand *opOperand : op.getInputTensorOperands()) {
      if (opOperand->get().getType() != opResult.getType())
        continue;
      if (!op.payloadUsesValueFromOperand(opOperand))
        continue;
      if (op.getTiedIndexingMap(opOperand) != outputIndexingMap)
        continue;
      // Skip OpOperands that are already paired with another OpResult.
      if (mapping.count(opOperand))
        continue;
      assert(op.getTiedIndexingMap(opOperand).isProjectedPermutation() &&
             "expected projected permutation");
      chosenOperand = opOperand;
      break;
    }

    // No suitable input tensor found. Use output tensor.
    // TODO: This operand could bufferize inplace with OpOperands that have the
    // correct type, even if they are not used inside the computation.
    if (!chosenOperand)
      chosenOperand = tiedOperand;

    mapping[chosenOperand] = opResult;
  }
  return mapping;
}

/// Bufferization of linalg.generic. Replace with a new linalg.generic that
/// operates entirely on memrefs.
template <typename OpTy>
struct LinalgOpInterface
    : public BufferizableOpInterface::ExternalModel<LinalgOpInterface<OpTy>,
                                                    OpTy> {
  bool bufferizesToMemoryRead(Operation *op, OpOperand &opOperand,
                              const BufferizationState &state) const {
    // Operand is read if it is used in the computation.
    auto genericOp = cast<linalg::LinalgOp>(op);
    return genericOp.payloadUsesValueFromOperand(&opOperand);
  }

  bool bufferizesToMemoryWrite(Operation *op, OpOperand &opOperand,
                               const BufferizationState &state) const {
    // Operand is written to if it has an aliasing OpResult.
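    // Note: With the smart aliasing computed by `computeAliasingPairs`, the
    // aliasing OpResult may be tied to an "in" operand, so this query is
    // answered via `getAliasingOpResult` rather than by the operand kind.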
    auto bufferizableOp = cast<BufferizableOpInterface>(op);
    return !bufferizableOp.getAliasingOpResult(opOperand, state).empty();
  }

  SmallVector<OpOperand *>
  getAliasingOpOperand(Operation *op, OpResult opResult,
                       const BufferizationState &state) const {
    auto genericOp = cast<linalg::LinalgOp>(op);

    // By default, the i-th OpResult may alias with the i-th "out" tensor.
    if (state.getOptions().alwaysAliasingWithDest)
      return {genericOp.getOutputOperand(opResult.getResultNumber())};

    // We can try to be smart and alias in-place with an "in" tensor if the
    // corresponding "out" tensor is not used in the computation.
    // Aliasing OpOperand/OpResult pairs are computed by `computeAliasingPairs`.
    DenseMap<OpOperand *, OpResult> pairs = computeAliasingPairs(genericOp);
    for (OpOperand *opOperand : genericOp.getInputAndOutputOperands())
      if (pairs[opOperand] == opResult)
        return {opOperand};
    return {};
  }

  SmallVector<OpResult>
  getAliasingOpResult(Operation *op, OpOperand &opOperand,
                      const BufferizationState &state) const {
    auto genericOp = cast<linalg::LinalgOp>(op);

    // By default, the i-th "out" tensor may alias with the i-th OpResult.
    if (state.getOptions().alwaysAliasingWithDest) {
      if (genericOp.isOutputTensor(&opOperand))
        return {genericOp.getTiedOpResult(&opOperand)};
      return {};
    }

    // We can try to be smart. See comment in `getAliasingOpOperand`.
    // Aliasing OpOperand/OpResult pairs are computed by `computeAliasingPairs`.
    DenseMap<OpOperand *, OpResult> pairs = computeAliasingPairs(genericOp);
    if (!pairs.count(&opOperand))
      return {};
    return {pairs[&opOperand]};
  }

  BufferRelation bufferRelation(Operation *op, OpResult opResult,
                                const BufferizationState &state) const {
    return BufferRelation::Equivalent;
  }

  LogicalResult bufferize(Operation *op, RewriterBase &rewriter,
                          const BufferizationState &state) const {
    return bufferizeLinalgOp(rewriter, cast<LinalgOp>(op), state);
  }
};

struct InitTensorOpInterface
    : public BufferizableOpInterface::ExternalModel<InitTensorOpInterface,
                                                    linalg::InitTensorOp> {
  bool isMemoryWrite(Operation *op, OpResult opResult,
                     const BufferizationState &state) const {
    // InitTensorOps allocate but do not write.
    return false;
  }

  LogicalResult bufferize(Operation *op, RewriterBase &rewriter,
                          const BufferizationState &state) const {
    auto initTensorOp = cast<linalg::InitTensorOp>(op);

    // The InitTensorOp may have been eliminated.
    if (initTensorOp->getUses().empty())
      return success();

    FailureOr<Value> alloc =
        createAlloc(rewriter, initTensorOp->getLoc(), initTensorOp.result(),
                    state.getOptions().createDeallocs, state.getOptions());
    if (failed(alloc))
      return failure();
    replaceOpWithBufferizedValues(rewriter, op, *alloc);
    return success();
  }
};

/// Helper structure that iterates over all LinalgOps in `OpTys` and registers
/// the `BufferizableOpInterface` with each of them.
template <typename... OpTys>
struct LinalgOpInterfaceHelper;

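/// Recursive case: register the interface model for `First`, then recurse on
/// the remaining op types.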
template <typename First, typename... Others>
struct LinalgOpInterfaceHelper<First, Others...> {
  static void registerOpInterface(DialectRegistry &registry) {
    registry.addOpInterface<First, LinalgOpInterface<First>>();
    LinalgOpInterfaceHelper<Others...>::registerOpInterface(registry);
  }
};

template <>
struct LinalgOpInterfaceHelper<> {
  static void registerOpInterface(DialectRegistry &registry) {}
};

} // namespace

/// Return true if all `neededValues` are in scope at the given
/// `insertionPoint`.
static bool
neededValuesDominateInsertionPoint(const DominanceInfo &domInfo,
                                   Operation *insertionPoint,
                                   const SmallVector<Value> &neededValues) {
  for (Value val : neededValues) {
    if (auto bbArg = val.dyn_cast<BlockArgument>()) {
      Block *owner = bbArg.getOwner();
      if (!owner->findAncestorOpInBlock(*insertionPoint))
        return false;
    } else {
      auto opResult = val.cast<OpResult>();
      if (!domInfo.dominates(opResult.getOwner(), insertionPoint))
        return false;
    }
  }
  return true;
}

/// Return true if the given `insertionPoint` dominates all uses of
/// `initTensorOp`.
static bool insertionPointDominatesUses(const DominanceInfo &domInfo,
                                        Operation *insertionPoint,
                                        Operation *initTensorOp) {
  for (Operation *user : initTensorOp->getUsers())
    if (!domInfo.dominates(insertionPoint, user))
      return false;
  return true;
}

/// Find a valid insertion point for a replacement of `initTensorOp`, assuming
/// that the replacement may use any value from `neededValues`.
static Operation *
findValidInsertionPoint(Operation *initTensorOp,
                        const SmallVector<Value> &neededValues) {
  DominanceInfo domInfo;

  // Gather all possible insertion points: the location of `initTensorOp` and
  // right after the definition of each value in `neededValues`.
  SmallVector<Operation *> insertionPointCandidates;
  insertionPointCandidates.push_back(initTensorOp);
  for (Value val : neededValues) {
    // Note: The anchor op is using all of `neededValues`, so:
    // * in case of a block argument: There must be at least one op in the
    //   block (the anchor op or one of its parents).
    // * in case of an OpResult: There must be at least one op right after the
    //   defining op (the anchor op or one of its parents).
    if (auto bbArg = val.dyn_cast<BlockArgument>()) {
      insertionPointCandidates.push_back(
          &bbArg.getOwner()->getOperations().front());
    } else {
      insertionPointCandidates.push_back(val.getDefiningOp()->getNextNode());
    }
  }

  // Select first matching insertion point.
  for (Operation *insertionPoint : insertionPointCandidates) {
    // Check if all needed values are in scope.
    if (!neededValuesDominateInsertionPoint(domInfo, insertionPoint,
                                            neededValues))
      continue;
    // Check if the insertion point is before all uses.
    if (!insertionPointDominatesUses(domInfo, insertionPoint, initTensorOp))
      continue;
    return insertionPoint;
  }

  // No suitable insertion point was found.
  return nullptr;
}

/// Try to eliminate InitTensorOps inside `op`. An InitTensorOp is replaced
/// with the result of `rewriteFunc` if it is anchored on a matching
/// OpOperand. "Anchored" means that there is a path on the reverse SSA use-def
/// chain, starting from the OpOperand and always following the aliasing
/// OpOperand, that eventually ends at a single InitTensorOp.
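///
/// E.g., with an anchor on the source of a tensor.insert_slice (a sketch; the
/// actual anchor is determined by `anchorMatchFunc`):
///
/// %0 = linalg.init_tensor ...      <- end of the reverse use-def chain
/// %1 = linalg.fill(%cst, %0)       <- aliasing OpOperand: %0
/// %2 = tensor.insert_slice %1 ...  <- anchor OpOperand: %1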
LogicalResult mlir::linalg::eliminateInitTensors(
    Operation *op, BufferizationState &state, BufferizationAliasInfo &aliasInfo,
    AnchorMatchFn anchorMatchFunc, RewriteFn rewriteFunc,
    SmallVector<Operation *> &newOps) {
  OpBuilder b(op->getContext());

  WalkResult status = op->walk([&](Operation *op) {
    for (OpOperand &operand : op->getOpOperands()) {
      // Skip operands that do not bufferize inplace.
      if (!aliasInfo.isInPlace(operand))
        continue;
      // All values that are needed to create the replacement op.
      SmallVector<Value> neededValues;
      // Is this a matching OpOperand?
      if (!anchorMatchFunc(operand, neededValues))
        continue;
      SetVector<Value> maybeInitTensor =
          state.findValueInReverseUseDefChain(operand.get(), [&](Value val) {
            // Continue traversal until this function returns true.
            OpResult opResult = val.dyn_cast<OpResult>();
            if (!opResult)
              return true;
            SmallVector<OpOperand *> opOperands =
                state.getAliasingOpOperand(opResult);
            if (!llvm::all_of(opOperands, [&](OpOperand *operand) {
                  return aliasInfo.isInPlace(*operand);
                }))
              return true;
            // Only equivalent tensors are supported at the moment.
            // TODO: Support cases such as extract_slice(init_tensor)
            return !llvm::all_of(opOperands, [&](OpOperand *operand) {
              return aliasInfo.areEquivalentBufferizedValues(operand->get(),
                                                             opResult);
            });
          });

      // Replace only if the reverse use-def chain ends at exactly one
      // InitTensorOp.
      if (maybeInitTensor.size() != 1 ||
          !maybeInitTensor.front().getDefiningOp<InitTensorOp>())
        return WalkResult::skip();
      Value initTensor = maybeInitTensor.front();

      // Find a suitable insertion point.
      Operation *insertionPoint =
          findValidInsertionPoint(initTensor.getDefiningOp(), neededValues);
      if (!insertionPoint)
        continue;

      // Create a replacement for the InitTensorOp.
      b.setInsertionPoint(insertionPoint);
      Value replacement = rewriteFunc(b, initTensor.getLoc(), operand);
      if (!replacement)
        continue;

      // Uses of the InitTensorOp are replaced here, but the op is not deleted.
      // InitTensorOps without uses are ignored by the bufferization.
      initTensor.replaceAllUsesWith(replacement);
      aliasInfo.createAliasInfoEntry(replacement);
      aliasInfo.unionAliasSets(initTensor, replacement);
      aliasInfo.unionEquivalenceClasses(initTensor, replacement);

      // Register replacement ops.
      if (Operation *newOp = replacement.getDefiningOp())
        newOps.push_back(newOp);
    }

    // Advance to the next operation.
    return WalkResult::advance();
  });

  return failure(status.wasInterrupted());
}

/// Try to eliminate InitTensorOps inside `op`. An InitTensorOp can be
/// eliminated if it is eventually inserted into another tensor (and some other
/// conditions are met).
///
/// E.g.:
/// %0 = linalg.init_tensor
/// %1 = linalg.fill(%cst, %0) {inplace = [true]}
/// %2 = tensor.insert_slice %1 into %t[10][20][1]
///
/// InitTensorOp elimination will try to fill %t inplace instead of filling a
/// new allocation %0 and inserting it into %t. This is done by replacing the
/// InitTensorOp with:
///
/// %0 = tensor.extract_slice %t[10][20][1]
///
/// The analysis looks for matching ExtractSliceOp/InsertSliceOp pairs and lets
/// those bufferize inplace in the absence of other conflicts.
///
/// Starting from an InsertSliceOp, an InitTensorOp at the end of the insert
/// source's reverse use-def chain is eliminated if:
/// * The InsertSliceOp was decided to bufferize inplace.
/// * On the reverse use-def chain path from the InsertSliceOp to the
///   InitTensorOp, all ops were decided to bufferize inplace and the buffer
///   relation is "equivalent" (TODO: can be relaxed if needed).
/// * The reverse use-def chain has exactly one end, which is the InitTensorOp.
///
/// Note that the newly inserted ExtractSliceOp may have to bufferize
/// out-of-place due to RaW conflicts.
LogicalResult mlir::linalg::insertSliceAnchoredInitTensorEliminationStep(
    Operation *op, BufferizationState &state, BufferizationAliasInfo &aliasInfo,
    SmallVector<Operation *> &newOps) {
  return eliminateInitTensors(
      op, state, aliasInfo,
      /*anchorMatchFunc=*/
      [&](OpOperand &operand, SmallVector<Value> &neededValues) {
        auto insertSliceOp =
            dyn_cast<tensor::InsertSliceOp>(operand.getOwner());
        if (!insertSliceOp)
          return false;
        // Only inplace bufferized InsertSliceOps are eligible.
        if (!aliasInfo.isInPlace(insertSliceOp->getOpOperand(1) /*dest*/))
          return false;
        if (&operand != &insertSliceOp->getOpOperand(0) /*source*/)
          return false;

        // Collect all values that are needed to construct the replacement op.
        neededValues.append(insertSliceOp.offsets().begin(),
                            insertSliceOp.offsets().end());
        neededValues.append(insertSliceOp.sizes().begin(),
                            insertSliceOp.sizes().end());
        neededValues.append(insertSliceOp.strides().begin(),
                            insertSliceOp.strides().end());
        neededValues.push_back(insertSliceOp.dest());

        return true;
      },
      /*rewriteFunc=*/
      [](OpBuilder &b, Location loc, OpOperand &operand) {
        auto insertOp = cast<tensor::InsertSliceOp>(operand.getOwner());
        // Expand offsets, sizes and strides to the full rank to handle the
        // rank-reducing case.
        SmallVector<OpFoldResult> mixedOffsets = insertOp.getMixedOffsets();
        SmallVector<OpFoldResult> mixedSizes = insertOp.getMixedSizes();
        SmallVector<OpFoldResult> mixedStrides = insertOp.getMixedStrides();
        OffsetSizeAndStrideOpInterface::expandToRank(
            insertOp.dest(), mixedOffsets, mixedSizes, mixedStrides,
            [&](Value target, int64_t dim) -> OpFoldResult {
              auto shapedType = target.getType().cast<ShapedType>();
              if (shapedType.isDynamicDim(dim))
                return b.create<tensor::DimOp>(loc, target, dim).result();
              return b.getIndexAttr(shapedType.getDimSize(dim));
            });
        auto t = tensor::ExtractSliceOp::inferRankReducedResultType(
            insertOp.getSourceType().getRank(),
            insertOp.dest().getType().cast<RankedTensorType>(), mixedOffsets,
            mixedSizes, mixedStrides);
        auto extractOp = b.create<tensor::ExtractSliceOp>(
            loc, t, insertOp.dest(), mixedOffsets, mixedSizes, mixedStrides);
        return extractOp.result();
      },
      newOps);
}

void mlir::linalg::registerBufferizableOpInterfaceExternalModels(
    DialectRegistry &registry) {
  registry.addOpInterface<linalg::InitTensorOp, InitTensorOpInterface>();

  // Register all Linalg structured ops. `LinalgOp` is an interface and it is
  // not possible to attach an external interface to an existing interface.
  // Therefore, attach the `BufferizableOpInterface` to all ops one-by-one.
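  // The GET_OP_LIST include below expands to a comma-separated list of all
  // structured op classes, i.e. roughly (a sketch, not the full list):
  //   LinalgOpInterfaceHelper<linalg::GenericOp, linalg::MatmulOp,
  //                           ...>::registerOpInterface(registry);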
  LinalgOpInterfaceHelper<
#define GET_OP_LIST
#include "mlir/Dialect/Linalg/IR/LinalgStructuredOps.cpp.inc"
      >::registerOpInterface(registry);
}
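
// Typical client usage (a sketch; the exact setup depends on the embedding
// pass pipeline):
//
//   DialectRegistry registry;
//   linalg::registerBufferizableOpInterfaceExternalModels(registry);
//   context.appendDialectRegistry(registry);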