//===- BufferizableOpInterfaceImpl.cpp - Impl. of BufferizableOpInterface -===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//

#include "mlir/Dialect/Linalg/Transforms/BufferizableOpInterfaceImpl.h"
#include "mlir/Dialect/Bufferization/IR/BufferizableOpInterface.h"
#include "mlir/Dialect/Bufferization/IR/Bufferization.h"
#include "mlir/Dialect/Linalg/IR/Linalg.h"
#include "mlir/Dialect/Tensor/IR/Tensor.h"
#include "mlir/IR/Dialect.h"
#include "mlir/IR/Dominance.h"
#include "mlir/IR/Operation.h"

using namespace mlir;
using namespace linalg;
using namespace mlir::bufferization;

namespace {

// TODO: Ops in the linalg dialect can directly implement this interface.

/// Generic conversion for any LinalgOp on tensors.
static LogicalResult bufferizeLinalgOp(RewriterBase &rewriter, LinalgOp op,
                                       BufferizationState &state) {
  // Take a guard before anything else.
  OpBuilder::InsertionGuard g(rewriter);
  rewriter.setInsertionPoint(op);

  // Nothing to do. This op is already bufferized.
  if (op.hasBufferSemantics())
    return success();

  // Ensure op has only tensors. Allow mixed tensor-buffer mode on a per-need
  // basis.
  if (!op.hasTensorSemantics())
    return op->emitError() << "op does not have tensor semantics";

  // New input operands for the cloned op.
  SmallVector<Value> newInputBuffers;
  newInputBuffers.reserve(op.getNumInputs());
  for (OpOperand *opOperand : op.getInputOperands()) {
    if (op.isScalar(opOperand)) {
      newInputBuffers.push_back(opOperand->get());
      continue;
    }
    // Input operands are never written to.
    newInputBuffers.push_back(
        *state.getBuffer(rewriter, *opOperand, /*forceInPlace=*/true));
  }

  // New output operands for the cloned op.
  SmallVector<Value> newOutputBuffers;
  for (OpResult opResult : op->getOpResults()) {
    SmallVector<OpOperand *> aliasingOpOperands =
        state.getAnalysisState().getAliasingOpOperand(opResult);
    assert(aliasingOpOperands.size() == 1 && "expected 1 OpOperand");
    FailureOr<Value> resultBuffer =
        state.getBuffer(rewriter, *aliasingOpOperands.front());
    if (failed(resultBuffer))
      return failure();
    newOutputBuffers.push_back(*resultBuffer);
  }

  // Merge input/output operands.
  SmallVector<Value> newOperands = newInputBuffers;
  newOperands.append(newOutputBuffers.begin(), newOutputBuffers.end());

  // Set insertion point now that potential alloc/dealloc are introduced.
  rewriter.setInsertionPoint(op);
  // Clone the op, but use the new operands. Move the existing block into the
  // new op. Since the new op does not have any tensor results, it does not
  // return anything.
  assert(op->getNumRegions() == 1 && "expected that op has 1 region");
  auto newOp = cast<LinalgOp>(op.cloneWithoutRegions(
      rewriter, op.getLoc(), /*resultTypes=*/TypeRange{}, newOperands));
  rewriter.inlineRegionBefore(op->getRegion(0), newOp->getRegion(0),
                              newOp->getRegion(0).begin());

  // Replace the results of the old op with the new output buffers.
  replaceOpWithBufferizedValues(rewriter, op, newOutputBuffers);

  return success();
}
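
// A hedged before/after sketch of the rewrite performed by `bufferizeLinalgOp`
// above (value names, types and the layout map are illustrative, not taken
// from an actual test case):
//
//   %r = linalg.generic {...}
//          ins(%A : tensor<?xf32>) outs(%O : tensor<?xf32>) {...}
//          -> tensor<?xf32>
//
// becomes, assuming the analysis lets the result bufferize in place into the
// buffer %o of %O:
//
//   linalg.generic {...}
//     ins(%a : memref<?xf32, #map>) outs(%o : memref<?xf32, #map>) {...}
//
// All uses of %r are then rewired to a tensor view of %o by
// `replaceOpWithBufferizedValues`.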

/// Linalg OpResults usually bufferize inplace with their tied (output)
/// OpOperands. However, if an output OpOperand is not used in the computation,
/// it is better to bufferize inplace with an actually used input OpOperand;
/// less memory will be touched that way.
///
/// Example:
/// O(i, j) = A(i, j) + B(j)  --> bufferizes inplace to: A(i, j) += B(j)
///
/// O(i, j) = A(j, i) + B(j)  --> cannot bufferize inplace with A because
///                               indexing maps are not identical
///
/// O(i, j) += A(i, j) + B(j) --> Output is used in computation.
///                               This could bufferize inplace with A:
///                               A(i, j) += O(i, j) + B(j)
/// However, we choose to bufferize inplace with O here, as there is no clear
/// benefit of choosing A. TODO: We may want to consider both options and make
/// an informed decision during analysis in the future.
static DenseMap<OpOperand *, OpResult> computeAliasingPairs(LinalgOp op) {
  DenseMap<OpOperand *, OpResult> mapping;
  for (OpResult opResult : op->getOpResults()) {
    OpOperand *tiedOperand =
        op.getOutputTensorOperands()[opResult.getResultNumber()];
    AffineMap outputIndexingMap = op.getTiedIndexingMap(tiedOperand);
    bool onlyParallelIterators = op.getNumParallelLoops() == op.getNumLoops();
    bool tiedOperandUsed = op.payloadUsesValueFromOperand(tiedOperand);

    // If the output arg is used in the computation or at least one iterator is
    // not parallel, try to bufferize inplace with the corresponding output
    // tensor.
    if (tiedOperandUsed || !onlyParallelIterators) {
      mapping[tiedOperand] = opResult;
      continue;
    }

    // Otherwise, try to bufferize inplace with one of the inputs.
    OpOperand *chosenOperand = nullptr;
    for (OpOperand *opOperand : op.getInputTensorOperands()) {
      if (opOperand->get().getType() != opResult.getType())
        continue;
      if (!op.payloadUsesValueFromOperand(opOperand))
        continue;
      if (op.getTiedIndexingMap(opOperand) != outputIndexingMap)
        continue;
      // No other OpResult may alias with this OpOperand.
      if (mapping.count(opOperand))
        continue;
      assert(op.getTiedIndexingMap(opOperand).isProjectedPermutation() &&
             "expected projected permutation");
      chosenOperand = opOperand;
      break;
    }

    // No suitable input tensor found. Use output tensor.
    // TODO: This operand could bufferize inplace with OpOperands that have the
    // correct type, even if they are not used inside the computation.
    if (!chosenOperand)
      chosenOperand = tiedOperand;

    mapping[chosenOperand] = opResult;
  }
  return mapping;
}
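
// A hedged illustration of the mapping computed above (names are made up):
// for a purely parallel `%r = linalg.generic ... ins(%A, %B) outs(%O)` whose
// payload reads the block arguments of %A and %B but never the one of %O, and
// where %A has the same tensor type and indexing map as %O, the function
// returns the single pair {OpOperand(%A) -> %r}; %O does not appear in the
// mapping at all.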

/// Bufferization of linalg.generic. Replace with a new linalg.generic that
/// operates entirely on memrefs.
template <typename OpTy>
struct LinalgOpInterface
    : public BufferizableOpInterface::ExternalModel<LinalgOpInterface<OpTy>,
                                                    OpTy> {
  bool bufferizesToMemoryRead(Operation *op, OpOperand &opOperand,
                              const AnalysisState &state) const {
    // Operand is read if it is used in the computation.
    auto genericOp = cast<linalg::LinalgOp>(op);
    return genericOp.payloadUsesValueFromOperand(&opOperand);
  }

  bool bufferizesToMemoryWrite(Operation *op, OpOperand &opOperand,
                               const AnalysisState &state) const {
    // Operand is written to if it has an aliasing OpResult.
    auto bufferizableOp = cast<BufferizableOpInterface>(op);
    return !bufferizableOp.getAliasingOpResult(opOperand, state).empty();
  }

  SmallVector<OpOperand *>
  getAliasingOpOperand(Operation *op, OpResult opResult,
                       const AnalysisState &state) const {
    auto genericOp = cast<linalg::LinalgOp>(op);

    // By default, the i-th OpResult may alias with the i-th "out" tensor.
    if (state.getOptions().alwaysAliasingWithDest)
      return {genericOp.getOutputOperand(opResult.getResultNumber())};

    // We can try to be smart and alias in-place with an "in" tensor if the
    // corresponding "out" tensor is not used in the computation.
    // Aliasing OpOperand/OpResult pairs are computed by `computeAliasingPairs`.
    DenseMap<OpOperand *, OpResult> pairs = computeAliasingPairs(genericOp);
    for (OpOperand *opOperand : genericOp.getInputAndOutputOperands())
      if (pairs[opOperand] == opResult)
        return {opOperand};
    return {};
  }

  SmallVector<OpResult> getAliasingOpResult(Operation *op, OpOperand &opOperand,
                                            const AnalysisState &state) const {
    auto genericOp = cast<linalg::LinalgOp>(op);

    // By default, the i-th "out" tensor may alias with the i-th OpResult.
    if (state.getOptions().alwaysAliasingWithDest) {
      if (genericOp.isOutputTensor(&opOperand))
        return {genericOp.getTiedOpResult(&opOperand)};
      return {};
    }

    // We can try to be smart. See comment in `getAliasingOpOperand`.
    // Aliasing OpOperand/OpResult pairs are computed by `computeAliasingPairs`.
    DenseMap<OpOperand *, OpResult> pairs = computeAliasingPairs(genericOp);
    if (!pairs.count(&opOperand))
      return {};
    return {pairs[&opOperand]};
  }

  BufferRelation bufferRelation(Operation *op, OpResult opResult,
                                const AnalysisState &state) const {
    return BufferRelation::Equivalent;
  }

  LogicalResult bufferize(Operation *op, RewriterBase &rewriter,
                          BufferizationState &state) const {
    return bufferizeLinalgOp(rewriter, cast<LinalgOp>(op), state);
  }
};

struct InitTensorOpInterface
    : public BufferizableOpInterface::ExternalModel<InitTensorOpInterface,
                                                    linalg::InitTensorOp> {
  bool isMemoryWrite(Operation *op, OpResult opResult,
                     const AnalysisState &state) const {
    // InitTensorOps allocate but do not write.
    return false;
  }

  LogicalResult bufferize(Operation *op, RewriterBase &rewriter,
                          BufferizationState &state) const {
    auto initTensorOp = cast<linalg::InitTensorOp>(op);

    // The InitTensorOp may have been eliminated.
    if (initTensorOp->getUses().empty())
      return success();

    FailureOr<Value> alloc = state.createAlloc(
        rewriter, initTensorOp->getLoc(), initTensorOp.result());
    if (failed(alloc))
      return failure();
    replaceOpWithBufferizedValues(rewriter, op, *alloc);
    return success();
  }
};
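
// A hedged sketch of the InitTensorOp bufferization above (shape and names are
// illustrative): an op such as
//
//   %0 = linalg.init_tensor [%d] : tensor<?xf32>
//
// is rewritten into a fresh allocation produced by `state.createAlloc`,
// conceptually
//
//   %m = memref.alloc(%d) : memref<?xf32>
//
// and uses of %0 are replaced with the bufferized value. No data is copied,
// since the op does not define the tensor's contents.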

/// Helper structure that iterates over all LinalgOps in `Ops` and registers
/// the `BufferizableOpInterface` with each of them.
template <typename... Ops>
struct LinalgOpInterfaceHelper {
  static void registerOpInterface(MLIRContext *ctx) {
    (void)std::initializer_list<int>{
        0, (Ops::template attachInterface<LinalgOpInterface<Ops>>(*ctx), 0)...};
  }
};
} // namespace

/// Return true if all `neededValues` are in scope at the given
/// `insertionPoint`.
static bool
neededValuesDominateInsertionPoint(const DominanceInfo &domInfo,
                                   Operation *insertionPoint,
                                   const SmallVector<Value> &neededValues) {
  for (Value val : neededValues) {
    if (auto bbArg = val.dyn_cast<BlockArgument>()) {
      Block *owner = bbArg.getOwner();
      if (!owner->findAncestorOpInBlock(*insertionPoint))
        return false;
    } else {
      auto opResult = val.cast<OpResult>();
      if (!domInfo.dominates(opResult.getOwner(), insertionPoint))
        return false;
    }
  }
  return true;
}

/// Return true if the given `insertionPoint` dominates all uses of
/// `initTensorOp`.
static bool insertionPointDominatesUses(const DominanceInfo &domInfo,
                                        Operation *insertionPoint,
                                        Operation *initTensorOp) {
  for (Operation *user : initTensorOp->getUsers())
    if (!domInfo.dominates(insertionPoint, user))
      return false;
  return true;
}

/// Find a valid insertion point for a replacement of `initTensorOp`, assuming
/// that the replacement may use any value from `neededValues`.
static Operation *
findValidInsertionPoint(Operation *initTensorOp,
                        const SmallVector<Value> &neededValues) {
  DominanceInfo domInfo;

  // Gather all possible insertion points: the location of `initTensorOp` and
  // right after the definition of each value in `neededValues`.
  SmallVector<Operation *> insertionPointCandidates;
  insertionPointCandidates.push_back(initTensorOp);
  for (Value val : neededValues) {
    // Note: The anchor op is using all of `neededValues`, so:
    // * in case of a block argument: There must be at least one op in the
    //   block (the anchor op or one of its parents).
    // * in case of an OpResult: There must be at least one op right after the
    //   defining op (the anchor op or one of its parents).
    if (auto bbArg = val.dyn_cast<BlockArgument>()) {
      insertionPointCandidates.push_back(
          &bbArg.getOwner()->getOperations().front());
    } else {
      insertionPointCandidates.push_back(val.getDefiningOp()->getNextNode());
    }
  }

  // Select first matching insertion point.
  for (Operation *insertionPoint : insertionPointCandidates) {
    // Check if all needed values are in scope.
    if (!neededValuesDominateInsertionPoint(domInfo, insertionPoint,
                                            neededValues))
      continue;
    // Check if the insertion point is before all uses.
    if (!insertionPointDominatesUses(domInfo, insertionPoint, initTensorOp))
      continue;
    return insertionPoint;
  }

  // No suitable insertion point was found.
  return nullptr;
}
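
// A hedged illustration of the insertion point search above (the IR is
// hypothetical):
//
//   %0 = linalg.init_tensor [5] : tensor<5xf32>
//   %off = "test.offset"() : () -> index
//   %1 = linalg.fill(%cst, %0) : f32, tensor<5xf32> -> tensor<5xf32>
//   %2 = tensor.insert_slice %1 into %t[%off] [5] [1] ...
//
// A replacement for %0 that needs %off cannot be built at %0, where %off is
// not yet defined. The candidate right after the definition of %off is chosen
// instead: all needed values are in scope there and it still dominates the
// only use of %0 (the fill).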

/// Try to eliminate InitTensorOps inside `op`. An InitTensorOp is replaced
/// with the result of `rewriteFunc` if it is anchored on a matching OpOperand.
/// "Anchored" means that there is a path on the reverse SSA use-def chain,
/// starting from the OpOperand and always following the aliasing OpOperand,
/// that eventually ends at a single InitTensorOp.
LogicalResult mlir::linalg::eliminateInitTensors(
    Operation *op, AnalysisState &state, BufferizationAliasInfo &aliasInfo,
    AnchorMatchFn anchorMatchFunc, RewriteFn rewriteFunc,
    SmallVector<Operation *> &newOps) {
  OpBuilder b(op->getContext());

  WalkResult status = op->walk([&](Operation *op) {
    for (OpOperand &operand : op->getOpOperands()) {
      // Skip operands that do not bufferize inplace.
      if (!aliasInfo.isInPlace(operand))
        continue;
      // All values that are needed to create the replacement op.
      SmallVector<Value> neededValues;
      // Is this a matching OpOperand?
      if (!anchorMatchFunc(operand, neededValues))
        continue;
      SetVector<Value> maybeInitTensor =
          state.findValueInReverseUseDefChain(operand.get(), [&](Value val) {
            // Continue traversal until this function returns true.
            OpResult opResult = val.dyn_cast<OpResult>();
            if (!opResult)
              return true;
            SmallVector<OpOperand *> opOperands =
                state.getAliasingOpOperand(opResult);
            if (!llvm::all_of(opOperands, [&](OpOperand *operand) {
                  return aliasInfo.isInPlace(*operand);
                }))
              return true;
            // Only equivalent tensors are supported at the moment.
            // TODO: Support cases such as extract_slice(init_tensor)
            return !llvm::all_of(opOperands, [&](OpOperand *operand) {
              return aliasInfo.areEquivalentBufferizedValues(operand->get(),
                                                             opResult);
            });
          });

      // Replace only if the reverse use-def chain ends at exactly one
      // InitTensorOp.
      if (maybeInitTensor.size() != 1 ||
          !maybeInitTensor.front().getDefiningOp<InitTensorOp>())
        return WalkResult::skip();
      Value initTensor = maybeInitTensor.front();

      // Find a suitable insertion point.
      Operation *insertionPoint =
          findValidInsertionPoint(initTensor.getDefiningOp(), neededValues);
      if (!insertionPoint)
        continue;

      // Create a replacement for the InitTensorOp.
      b.setInsertionPoint(insertionPoint);
      Value replacement = rewriteFunc(b, initTensor.getLoc(), operand);
      if (!replacement)
        continue;

      // Uses of the InitTensorOp are replaced here, but the op is not deleted.
      // InitTensorOps without uses are ignored by the bufferization.
      initTensor.replaceAllUsesWith(replacement);
      aliasInfo.createAliasInfoEntry(replacement);
      aliasInfo.unionAliasSets(initTensor, replacement);
      aliasInfo.unionEquivalenceClasses(initTensor, replacement);

      // Register replacement ops.
      if (Operation *newOp = replacement.getDefiningOp())
        newOps.push_back(newOp);
    }

    // Advance to the next operation.
    return WalkResult::advance();
  });

  return failure(status.wasInterrupted());
}

/// Try to eliminate InitTensorOps inside `op`. An InitTensorOp can be
/// eliminated if it is eventually inserted into another tensor (and some other
/// conditions are met).
///
/// E.g.:
/// %0 = linalg.init_tensor
/// %1 = linalg.fill(%cst, %0) {inplace = [true]}
/// %2 = tensor.insert_slice %1 into %t[10][20][1]
///
/// InitTensorOp elimination will try to fill %t inplace instead of filling a
/// new allocation %0 and inserting it into %t. This is done by replacing the
/// InitTensorOp with:
///
/// %0 = tensor.extract_slice %t[10][20][1]
///
/// The analysis looks for matching ExtractSliceOp/InsertSliceOp pairs and lets
/// those bufferize inplace in the absence of other conflicts.
///
/// Starting from an InsertSliceOp, an InitTensorOp at the end of the insert
/// source's reverse use-def chain is eliminated if:
/// * The InsertSliceOp was decided to bufferize inplace.
/// * On the reverse use-def chain path from the InsertSliceOp to the
///   InitTensorOp, all ops were decided to bufferize inplace and the buffer
///   relation is "equivalent" (TODO: can be relaxed if needed).
/// * The reverse use-def chain has exactly one end, which is the InitTensorOp.
///
/// Note that the newly inserted ExtractSliceOp may have to bufferize
/// out-of-place due to RaW conflicts.
LogicalResult mlir::linalg::insertSliceAnchoredInitTensorEliminationStep(
    Operation *op, AnalysisState &state, BufferizationAliasInfo &aliasInfo,
    SmallVector<Operation *> &newOps) {
  return eliminateInitTensors(
      op, state, aliasInfo,
      /*anchorMatchFunc=*/
      [&](OpOperand &operand, SmallVector<Value> &neededValues) {
        auto insertSliceOp =
            dyn_cast<tensor::InsertSliceOp>(operand.getOwner());
        if (!insertSliceOp)
          return false;
        // Only inplace bufferized InsertSliceOps are eligible.
        if (!aliasInfo.isInPlace(insertSliceOp->getOpOperand(1) /*dest*/))
          return false;
        if (&operand != &insertSliceOp->getOpOperand(0) /*source*/)
          return false;

        // Collect all values that are needed to construct the replacement op.
        neededValues.append(insertSliceOp.offsets().begin(),
                            insertSliceOp.offsets().end());
        neededValues.append(insertSliceOp.sizes().begin(),
                            insertSliceOp.sizes().end());
        neededValues.append(insertSliceOp.strides().begin(),
                            insertSliceOp.strides().end());
        neededValues.push_back(insertSliceOp.dest());

        return true;
      },
      /*rewriteFunc=*/
      [](OpBuilder &b, Location loc, OpOperand &operand) {
        auto insertOp = cast<tensor::InsertSliceOp>(operand.getOwner());
        // Expand offsets, sizes and strides to the full rank to handle the
        // rank-reducing case.
        SmallVector<OpFoldResult> mixedOffsets = insertOp.getMixedOffsets();
        SmallVector<OpFoldResult> mixedSizes = insertOp.getMixedSizes();
        SmallVector<OpFoldResult> mixedStrides = insertOp.getMixedStrides();
        OffsetSizeAndStrideOpInterface::expandToRank(
            insertOp.dest(), mixedOffsets, mixedSizes, mixedStrides,
            [&](Value target, int64_t dim) -> OpFoldResult {
              auto shapedType = target.getType().cast<ShapedType>();
              if (shapedType.isDynamicDim(dim))
                return b.create<tensor::DimOp>(loc, target, dim).result();
              return b.getIndexAttr(shapedType.getDimSize(dim));
            });
        auto t = tensor::ExtractSliceOp::inferRankReducedResultType(
            insertOp.getSourceType().getRank(),
            insertOp.dest().getType().cast<RankedTensorType>(), mixedOffsets,
            mixedSizes, mixedStrides);
        auto extractOp = b.create<tensor::ExtractSliceOp>(
            loc, t, insertOp.dest(), mixedOffsets, mixedSizes, mixedStrides);
        return extractOp.result();
      },
      newOps);
}

void mlir::linalg::registerBufferizableOpInterfaceExternalModels(
    DialectRegistry &registry) {
  registry.addExtension(+[](MLIRContext *ctx, linalg::LinalgDialect *dialect) {
    linalg::InitTensorOp::attachInterface<InitTensorOpInterface>(*ctx);

    // Register all Linalg structured ops. `LinalgOp` is an interface and it is
    // not possible to attach an external interface to an existing interface.
    // Therefore, attach the `BufferizableOpInterface` to all ops one-by-one.
    LinalgOpInterfaceHelper<
#define GET_OP_LIST
#include "mlir/Dialect/Linalg/IR/LinalgStructuredOps.cpp.inc"
        >::registerOpInterface(ctx);
  });
}
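
// A hedged usage sketch (assumed client code, not part of this file): a tool
// or pass pipeline that wants these external models available registers them
// before IR is created or parsed, for example:
//
//   DialectRegistry registry;
//   linalg::registerBufferizableOpInterfaceExternalModels(registry);
//   MLIRContext context(registry);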