//===- BufferizableOpInterfaceImpl.cpp - Impl. of BufferizableOpInterface -===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//

#include "mlir/Dialect/Linalg/Transforms/BufferizableOpInterfaceImpl.h"
#include "mlir/Dialect/Bufferization/IR/BufferizableOpInterface.h"
#include "mlir/Dialect/Bufferization/IR/Bufferization.h"
#include "mlir/Dialect/Linalg/IR/Linalg.h"
#include "mlir/Dialect/Tensor/IR/Tensor.h"
#include "mlir/IR/Dialect.h"
#include "mlir/IR/Dominance.h"
#include "mlir/IR/Operation.h"

using namespace mlir;
using namespace linalg;
using namespace mlir::bufferization;

namespace {

// TODO: Ops in the linalg dialect can directly implement this interface.

/// Generic conversion for any LinalgOp on tensors.
static LogicalResult bufferizeLinalgOp(RewriterBase &rewriter, LinalgOp op,
                                       BufferizationState &state) {
  // Take a guard before anything else.
  OpBuilder::InsertionGuard g(rewriter);
  rewriter.setInsertionPoint(op);

  // Nothing to do. This op is already bufferized.
  if (op.hasBufferSemantics())
    return success();

  // Ensure op has only tensors. Allow mixed tensor-buffer mode on a per-need
  // basis.
  if (!op.hasTensorSemantics())
    return op->emitError() << "op does not have tensor semantics";

  // New input operands for the cloned op.
  SmallVector<Value> newInputBuffers;
  newInputBuffers.reserve(op.getNumInputs());
  for (OpOperand *opOperand : op.getInputOperands()) {
    if (op.isScalar(opOperand)) {
      newInputBuffers.push_back(opOperand->get());
      continue;
    }
    // Input operands are never written to.
    newInputBuffers.push_back(*state.getBuffer(
        rewriter, *opOperand,
        BufferizationState::ForceInPlacability::FORCE_INPLACE));
  }

  // New output operands for the cloned op.
  SmallVector<Value> newOutputBuffers;
  for (OpResult opResult : op->getOpResults()) {
    SmallVector<OpOperand *> aliasingOpOperands =
        state.getAnalysisState().getAliasingOpOperand(opResult);
    assert(aliasingOpOperands.size() == 1 && "expected 1 OpOperand");
    FailureOr<Value> resultBuffer =
        state.getBuffer(rewriter, *aliasingOpOperands.front());
    if (failed(resultBuffer))
      return failure();
    newOutputBuffers.push_back(*resultBuffer);
  }

  // Merge input/output operands.
  SmallVector<Value> newOperands = newInputBuffers;
  newOperands.append(newOutputBuffers.begin(), newOutputBuffers.end());

  // Set insertion point now that potential alloc/dealloc are introduced.
  rewriter.setInsertionPoint(op);
  // Clone the op, but use the new operands. Move the existing block into the
  // new op. Since the new op does not have any tensor results, it does not
  // return anything.
  assert(op->getNumRegions() == 1 && "expected that op has 1 region");
  auto newOp = cast<LinalgOp>(op.cloneWithoutRegions(
      rewriter, op.getLoc(), /*resultTypes=*/TypeRange{}, newOperands));
  rewriter.inlineRegionBefore(op->getRegion(0), newOp->getRegion(0),
                              newOp->getRegion(0).begin());

  // Replace the results of the old op with the new output buffers.
  replaceOpWithBufferizedValues(rewriter, op, newOutputBuffers);

  return success();
}

/// Linalg OpResults usually bufferize inplace with their tied (output)
/// OpOperands. However, if an output OpOperand is not used in the computation,
/// it is better to bufferize inplace with an actually used input OpOperand;
/// less memory will be touched that way.
///
/// Example:
/// O(i, j) = A(i, j) + B(j) --> bufferizes inplace to: A(i, j) += B(j)
///
/// O(i, j) = A(j, i) + B(j) --> cannot bufferize inplace with A because
/// indexing maps are not identical
///
/// O(i, j) += A(i, j) + B(j) --> Output is used in computation.
/// This could bufferize inplace with A:
/// A(i, j) += O(i, j) + B(j)
/// However, we choose to bufferize inplace with O here, as there is no clear
/// benefit of choosing A. TODO: We may want to consider both options and make
/// an informed decision during analysis in the future.
static DenseMap<OpOperand *, OpResult> computeAliasingPairs(LinalgOp op) {
  DenseMap<OpOperand *, OpResult> mapping;
  for (OpResult opResult : op->getOpResults()) {
    OpOperand *tiedOperand =
        op.getOutputTensorOperands()[opResult.getResultNumber()];
    AffineMap outputIndexingMap = op.getTiedIndexingMap(tiedOperand);
    bool onlyParallelIterators = op.getNumParallelLoops() == op.getNumLoops();
    bool tiedOperandUsed = op.payloadUsesValueFromOperand(tiedOperand);

    // If the output arg is used in the computation or at least one iterator is
    // not parallel, try to bufferize inplace with the corresponding output
    // tensor.
    if (tiedOperandUsed || !onlyParallelIterators) {
      mapping[tiedOperand] = opResult;
      continue;
    }

    // Otherwise, try to bufferize inplace with one of the inputs.
    OpOperand *chosenOperand = nullptr;
    for (OpOperand *opOperand : op.getInputTensorOperands()) {
      if (opOperand->get().getType() != opResult.getType())
        continue;
      if (!op.payloadUsesValueFromOperand(opOperand))
        continue;
      if (op.getTiedIndexingMap(opOperand) != outputIndexingMap)
        continue;
      // Make sure no other OpResult already aliases with this OpOperand.
      if (mapping.count(opOperand))
        continue;
      assert(op.getTiedIndexingMap(opOperand).isProjectedPermutation() &&
             "expected projected permutation");
      chosenOperand = opOperand;
      break;
    }

    // No suitable input tensor found. Use output tensor.
    // TODO: This operand could bufferize inplace with OpOperands that have the
    // correct type, even if they are not used inside the computation.
    if (!chosenOperand)
      chosenOperand = tiedOperand;

    mapping[chosenOperand] = opResult;
  }
  return mapping;
}

/// Bufferization of linalg.generic. Replace with a new linalg.generic that
/// operates entirely on memrefs.
template <typename OpTy>
struct LinalgOpInterface
    : public BufferizableOpInterface::ExternalModel<LinalgOpInterface<OpTy>,
                                                    OpTy> {
  bool bufferizesToMemoryRead(Operation *op, OpOperand &opOperand,
                              const AnalysisState &state) const {
    // Operand is read if it is used in the computation.
    auto genericOp = cast<linalg::LinalgOp>(op);
    return genericOp.payloadUsesValueFromOperand(&opOperand);
  }

  bool bufferizesToMemoryWrite(Operation *op, OpOperand &opOperand,
                               const AnalysisState &state) const {
    // Operand is written to if it has an aliasing OpResult.
    auto bufferizableOp = cast<BufferizableOpInterface>(op);
    return !bufferizableOp.getAliasingOpResult(opOperand, state).empty();
  }

  SmallVector<OpOperand *>
  getAliasingOpOperand(Operation *op, OpResult opResult,
                       const AnalysisState &state) const {
    auto genericOp = cast<linalg::LinalgOp>(op);

    // By default, the i-th OpResult may alias with the i-th "out" tensor.
    if (state.getOptions().alwaysAliasingWithDest)
      return {genericOp.getOutputOperand(opResult.getResultNumber())};

    // We can try to be smart and alias in-place with an "in" tensor if the
    // corresponding "out" tensor is not used in the computation.
    // Aliasing OpOperand/OpResult pairs are computed by `computeAliasingPairs`.
    DenseMap<OpOperand *, OpResult> pairs = computeAliasingPairs(genericOp);
    for (OpOperand *opOperand : genericOp.getInputAndOutputOperands())
      if (pairs[opOperand] == opResult)
        return {opOperand};
    return {};
  }

  SmallVector<OpResult> getAliasingOpResult(Operation *op, OpOperand &opOperand,
                                            const AnalysisState &state) const {
    auto genericOp = cast<linalg::LinalgOp>(op);

    // By default, the i-th "out" tensor may alias with the i-th OpResult.
    if (state.getOptions().alwaysAliasingWithDest) {
      if (genericOp.isOutputTensor(&opOperand))
        return {genericOp.getTiedOpResult(&opOperand)};
      return {};
    }

    // We can try to be smart. See comment in `getAliasingOpOperand`.
    // Aliasing OpOperand/OpResult pairs are computed by `computeAliasingPairs`.
    DenseMap<OpOperand *, OpResult> pairs = computeAliasingPairs(genericOp);
    if (!pairs.count(&opOperand))
      return {};
    return {pairs[&opOperand]};
  }

  BufferRelation bufferRelation(Operation *op, OpResult opResult,
                                const AnalysisState &state) const {
    return BufferRelation::Equivalent;
  }

  LogicalResult bufferize(Operation *op, RewriterBase &rewriter,
                          BufferizationState &state) const {
    return bufferizeLinalgOp(rewriter, cast<LinalgOp>(op), state);
  }
};

struct InitTensorOpInterface
    : public BufferizableOpInterface::ExternalModel<InitTensorOpInterface,
                                                    linalg::InitTensorOp> {
  bool isMemoryWrite(Operation *op, OpResult opResult,
                     const AnalysisState &state) const {
    // InitTensorOps allocate but do not write.
    return false;
  }

  LogicalResult bufferize(Operation *op, RewriterBase &rewriter,
                          BufferizationState &state) const {
    auto initTensorOp = cast<linalg::InitTensorOp>(op);

    // The InitTensorOp may have been eliminated.
    if (initTensorOp->getUses().empty())
      return success();

    FailureOr<Value> alloc = state.createAlloc(rewriter, initTensorOp->getLoc(),
                                               initTensorOp.result());
    if (failed(alloc))
      return failure();
    replaceOpWithBufferizedValues(rewriter, op, *alloc);
    return success();
  }
};

/// Helper structure that iterates over all LinalgOps in `Ops` and registers
/// the `BufferizableOpInterface` with each of them.
template <typename... Ops>
struct LinalgOpInterfaceHelper {
  static void registerOpInterface(MLIRContext *ctx) {
    (void)std::initializer_list<int>{
        0, (Ops::template attachInterface<LinalgOpInterface<Ops>>(*ctx), 0)...};
  }
};
} // namespace

/// Return true if all `neededValues` are in scope at the given
/// `insertionPoint`.
static bool
neededValuesDominateInsertionPoint(const DominanceInfo &domInfo,
                                   Operation *insertionPoint,
                                   const SmallVector<Value> &neededValues) {
  for (Value val : neededValues) {
    if (auto bbArg = val.dyn_cast<BlockArgument>()) {
      Block *owner = bbArg.getOwner();
      if (!owner->findAncestorOpInBlock(*insertionPoint))
        return false;
    } else {
      auto opResult = val.cast<OpResult>();
      if (!domInfo.dominates(opResult.getOwner(), insertionPoint))
        return false;
    }
  }
  return true;
}

/// Return true if the given `insertionPoint` dominates all uses of
/// `initTensorOp`.
static bool insertionPointDominatesUses(const DominanceInfo &domInfo,
                                        Operation *insertionPoint,
                                        Operation *initTensorOp) {
  for (Operation *user : initTensorOp->getUsers())
    if (!domInfo.dominates(insertionPoint, user))
      return false;
  return true;
}

/// Find a valid insertion point for a replacement of `initTensorOp`, assuming
/// that the replacement may use any value from `neededValues`.
static Operation *
findValidInsertionPoint(Operation *initTensorOp,
                        const SmallVector<Value> &neededValues) {
  DominanceInfo domInfo;

  // Gather all possible insertion points: the location of `initTensorOp` and
  // right after the definition of each value in `neededValues`.
  SmallVector<Operation *> insertionPointCandidates;
  insertionPointCandidates.push_back(initTensorOp);
  for (Value val : neededValues) {
    // Note: The anchor op is using all of `neededValues`, so:
    // * in case of a block argument: There must be at least one op in the
    //   block (the anchor op or one of its parents).
    // * in case of an OpResult: There must be at least one op right after the
    //   defining op (the anchor op or one of its parents).
    if (auto bbArg = val.dyn_cast<BlockArgument>()) {
      insertionPointCandidates.push_back(
          &bbArg.getOwner()->getOperations().front());
    } else {
      insertionPointCandidates.push_back(val.getDefiningOp()->getNextNode());
    }
  }

  // Select first matching insertion point.
  for (Operation *insertionPoint : insertionPointCandidates) {
    // Check if all needed values are in scope.
    if (!neededValuesDominateInsertionPoint(domInfo, insertionPoint,
                                            neededValues))
      continue;
    // Check if the insertion point is before all uses.
    if (!insertionPointDominatesUses(domInfo, insertionPoint, initTensorOp))
      continue;
    return insertionPoint;
  }

  // No suitable insertion point was found.
  return nullptr;
}

/// Try to eliminate InitTensorOps inside `op`. An InitTensorOp is replaced
/// with the result of `rewriteFunc` if it is anchored on a matching
/// OpOperand. "Anchored" means that there is a path on the reverse SSA use-def
/// chain, starting from the OpOperand and always following the aliasing
/// OpOperand, that eventually ends at a single InitTensorOp.
LogicalResult mlir::linalg::eliminateInitTensors(
    Operation *op, AnalysisState &state, BufferizationAliasInfo &aliasInfo,
    AnchorMatchFn anchorMatchFunc, RewriteFn rewriteFunc,
    SmallVector<Operation *> &newOps) {
  OpBuilder b(op->getContext());

  WalkResult status = op->walk([&](Operation *op) {
    for (OpOperand &operand : op->getOpOperands()) {
      // Skip operands that do not bufferize inplace.
      if (!aliasInfo.isInPlace(operand))
        continue;
      // All values that are needed to create the replacement op.
      SmallVector<Value> neededValues;
      // Is this a matching OpOperand?
      if (!anchorMatchFunc(operand, neededValues))
        continue;
      SetVector<Value> maybeInitTensor =
          state.findValueInReverseUseDefChain(operand.get(), [&](Value val) {
            // Continue traversal until this function returns true.
            OpResult opResult = val.dyn_cast<OpResult>();
            if (!opResult)
              return true;
            SmallVector<OpOperand *> opOperands =
                state.getAliasingOpOperand(opResult);
            if (!llvm::all_of(opOperands, [&](OpOperand *operand) {
                  return aliasInfo.isInPlace(*operand);
                }))
              return true;
            // Only equivalent tensors are supported at the moment.
            // TODO: Support cases such as extract_slice(init_tensor)
            return !llvm::all_of(opOperands, [&](OpOperand *operand) {
              return aliasInfo.areEquivalentBufferizedValues(operand->get(),
                                                             opResult);
            });
          });

      // Replace only if the reverse use-def chain ends at exactly one
      // InitTensorOp.
      if (maybeInitTensor.size() != 1 ||
          !maybeInitTensor.front().getDefiningOp<InitTensorOp>())
        return WalkResult::skip();
      Value initTensor = maybeInitTensor.front();

      // Find a suitable insertion point.
      Operation *insertionPoint =
          findValidInsertionPoint(initTensor.getDefiningOp(), neededValues);
      if (!insertionPoint)
        continue;

      // Create a replacement for the InitTensorOp.
      b.setInsertionPoint(insertionPoint);
      Value replacement = rewriteFunc(b, initTensor.getLoc(), operand);
      if (!replacement)
        continue;

      // Uses of the InitTensorOp are replaced here, but the op is not deleted.
      // InitTensorOps without uses are ignored by the bufferization.
      initTensor.replaceAllUsesWith(replacement);
      aliasInfo.createAliasInfoEntry(replacement);
      aliasInfo.unionAliasSets(initTensor, replacement);
      aliasInfo.unionEquivalenceClasses(initTensor, replacement);

      // Register replacement ops.
      if (Operation *newOp = replacement.getDefiningOp())
        newOps.push_back(newOp);
    }

    // Advance to the next operation.
    return WalkResult::advance();
  });

  return failure(status.wasInterrupted());
}

/// Try to eliminate InitTensorOps inside `op`. An InitTensorOp can be
/// eliminated if it is eventually inserted into another tensor (and some other
/// conditions are met).
///
/// E.g.:
/// %0 = linalg.init_tensor
/// %1 = linalg.fill(%cst, %0) {inplace = [true]}
/// %2 = tensor.insert_slice %1 into %t[10][20][1]
///
/// InitTensorOp elimination will try to fill %t inplace instead of filling a
/// new allocation %0 and inserting it into %t. This is done by replacing the
/// InitTensorOp with:
///
/// %0 = tensor.extract_slice %t[10][20][1]
///
/// The analysis looks for matching ExtractSliceOp/InsertSliceOp pairs and lets
/// those bufferize inplace in the absence of other conflicts.
///
/// Starting from an InsertSliceOp, an InitTensorOp at the end of the insert
/// source's reverse use-def chain is eliminated if:
/// * The InsertSliceOp was decided to bufferize inplace.
/// * On the reverse use-def chain path from the InsertSliceOp to the
///   InitTensorOp, all ops were decided to bufferize inplace and the buffer
///   relation is "equivalent" (TODO: can be relaxed if needed).
/// * The reverse use-def chain has exactly one end, which is the InitTensorOp.
///
/// Note that the newly inserted ExtractSliceOp may have to bufferize
/// out-of-place due to RaW conflicts.
LogicalResult mlir::linalg::insertSliceAnchoredInitTensorEliminationStep(
    Operation *op, AnalysisState &state, BufferizationAliasInfo &aliasInfo,
    SmallVector<Operation *> &newOps) {
  return eliminateInitTensors(
      op, state, aliasInfo,
      /*anchorMatchFunc=*/
      [&](OpOperand &operand, SmallVector<Value> &neededValues) {
        auto insertSliceOp =
            dyn_cast<tensor::InsertSliceOp>(operand.getOwner());
        if (!insertSliceOp)
          return false;
        // Only inplace bufferized InsertSliceOps are eligible.
        if (!aliasInfo.isInPlace(insertSliceOp->getOpOperand(1) /*dest*/))
          return false;
        if (&operand != &insertSliceOp->getOpOperand(0) /*source*/)
          return false;

        // Collect all values that are needed to construct the replacement op.
        neededValues.append(insertSliceOp.offsets().begin(),
                            insertSliceOp.offsets().end());
        neededValues.append(insertSliceOp.sizes().begin(),
                            insertSliceOp.sizes().end());
        neededValues.append(insertSliceOp.strides().begin(),
                            insertSliceOp.strides().end());
        neededValues.push_back(insertSliceOp.dest());

        return true;
      },
      /*rewriteFunc=*/
      [](OpBuilder &b, Location loc, OpOperand &operand) {
        auto insertOp = cast<tensor::InsertSliceOp>(operand.getOwner());
        // Expand offsets, sizes and strides to the full rank to handle the
        // rank-reducing case.
        SmallVector<OpFoldResult> mixedOffsets = insertOp.getMixedOffsets();
        SmallVector<OpFoldResult> mixedSizes = insertOp.getMixedSizes();
        SmallVector<OpFoldResult> mixedStrides = insertOp.getMixedStrides();
        OffsetSizeAndStrideOpInterface::expandToRank(
            insertOp.dest(), mixedOffsets, mixedSizes, mixedStrides,
            [&](Value target, int64_t dim) -> OpFoldResult {
              auto shapedType = target.getType().cast<ShapedType>();
              if (shapedType.isDynamicDim(dim))
                return b.create<tensor::DimOp>(loc, target, dim).result();
              return b.getIndexAttr(shapedType.getDimSize(dim));
            });
        auto t = tensor::ExtractSliceOp::inferRankReducedResultType(
            insertOp.getSourceType().getRank(),
            insertOp.dest().getType().cast<RankedTensorType>(), mixedOffsets,
            mixedSizes, mixedStrides);
        auto extractOp = b.create<tensor::ExtractSliceOp>(
            loc, t, insertOp.dest(), mixedOffsets, mixedSizes, mixedStrides);
        return extractOp.result();
      },
      newOps);
}

void mlir::linalg::registerBufferizableOpInterfaceExternalModels(
    DialectRegistry &registry) {
  registry.addExtension(+[](MLIRContext *ctx, linalg::LinalgDialect *dialect) {
    linalg::InitTensorOp::attachInterface<InitTensorOpInterface>(*ctx);

    // Register all Linalg structured ops. `LinalgOp` is an interface and it is
    // not possible to attach an external interface to an existing interface.
    // Therefore, attach the `BufferizableOpInterface` to all ops one-by-one.
    LinalgOpInterfaceHelper<
#define GET_OP_LIST
#include "mlir/Dialect/Linalg/IR/LinalgStructuredOps.cpp.inc"
        >::registerOpInterface(ctx);
  });
}
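
// Illustrative usage sketch (not part of the original file): a tool or pass
// pipeline that bufferizes linalg IR is expected to attach these external
// models to its DialectRegistry before building the MLIRContext. The exact
// set of dialects inserted below is an assumption for illustration only.
//
//   DialectRegistry registry;
//   registry.insert<linalg::LinalgDialect, tensor::TensorDialect,
//                   bufferization::BufferizationDialect>();
//   linalg::registerBufferizableOpInterfaceExternalModels(registry);
//   MLIRContext context(registry);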