1 //===- BufferizableOpInterfaceImpl.cpp - Impl. of BufferizableOpInterface -===// 2 // 3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. 4 // See https://llvm.org/LICENSE.txt for license information. 5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception 6 // 7 //===----------------------------------------------------------------------===// 8 9 #include "mlir/Dialect/Linalg/Transforms/BufferizableOpInterfaceImpl.h" 10 #include "mlir/Dialect/Bufferization/IR/BufferizableOpInterface.h" 11 #include "mlir/Dialect/Bufferization/IR/Bufferization.h" 12 #include "mlir/Dialect/Linalg/IR/Linalg.h" 13 #include "mlir/Dialect/Tensor/IR/Tensor.h" 14 #include "mlir/IR/Dialect.h" 15 #include "mlir/IR/Operation.h" 16 17 using namespace mlir; 18 using namespace linalg; 19 using namespace mlir::bufferization; 20 21 namespace { 22 23 // TODO: Ops in the linalg dialect can directly implement this interface. 24 25 /// Generic conversion for any LinalgOp on tensors. 26 static LogicalResult bufferizeLinalgOp(RewriterBase &rewriter, LinalgOp op, 27 BufferizationState &state) { 28 // Take a guard before anything else. 29 OpBuilder::InsertionGuard g(rewriter); 30 rewriter.setInsertionPoint(op); 31 32 // Nothing to do. This op is already bufferized. 33 if (op.hasBufferSemantics()) 34 return success(); 35 36 // Ensure op has only tensors. Allow mixed tensor-buffer mode on a per-need 37 // basis. 38 if (!op.hasTensorSemantics()) 39 return op->emitError() << "op does not have tensor semantics"; 40 41 // New input operands for the cloned op. 42 SmallVector<Value> newInputBuffers; 43 newInputBuffers.reserve(op.getNumInputs()); 44 for (OpOperand *opOperand : op.getInputOperands()) { 45 if (op.isScalar(opOperand)) { 46 newInputBuffers.push_back(opOperand->get()); 47 continue; 48 } 49 // Input operands are never written to. 
50 newInputBuffers.push_back(*state.getBuffer( 51 rewriter, *opOperand, 52 BufferizationState::ForceInPlacability::FORCE_INPLACE)); 53 } 54 55 // New output operands for the cloned op. 56 SmallVector<Value> newOutputBuffers; 57 for (OpResult opResult : op->getOpResults()) { 58 SmallVector<OpOperand *> aliasingOpOperands = 59 state.getAnalysisState().getAliasingOpOperand(opResult); 60 assert(aliasingOpOperands.size() == 1 && "expected 1 OpOperand"); 61 FailureOr<Value> resultBuffer = 62 state.getBuffer(rewriter, *aliasingOpOperands.front()); 63 if (failed(resultBuffer)) 64 return failure(); 65 newOutputBuffers.push_back(*resultBuffer); 66 } 67 68 // Merge input/output operands. 69 SmallVector<Value> newOperands = newInputBuffers; 70 newOperands.append(newOutputBuffers.begin(), newOutputBuffers.end()); 71 72 // Set insertion point now that potential alloc/dealloc are introduced. 73 rewriter.setInsertionPoint(op); 74 // Clone the op, but use the new operands. Move the existing block into the 75 // new op. Since the new op does not have any tensor results, it does not 76 // return anything. 77 assert(op->getNumRegions() == 1 && "expected that op has 1 region"); 78 auto newOp = cast<LinalgOp>(op.cloneWithoutRegions( 79 rewriter, op.getLoc(), /*resultTypes=*/TypeRange{}, newOperands)); 80 rewriter.inlineRegionBefore(op->getRegion(0), newOp->getRegion(0), 81 newOp->getRegion(0).begin()); 82 83 // Replace the results of the old op with the new output buffers. 84 replaceOpWithBufferizedValues(rewriter, op, newOutputBuffers); 85 86 return success(); 87 } 88 89 /// Linalg OpResults usually bufferize inplace with their tied (output 90 /// OpOperands. However, if an output OpOperand is not used in the computation, 91 /// it is better to bufferize inplace with an actually used input OpOperand; 92 /// less memory will be touched that way. 
///
/// Example:
/// O(i, j) = A(i, j) + B(j) --> bufferizes inplace to: A(i, j) += B(j)
///
/// O(i, j) = A(j, i) + B(j) --> cannot bufferize inplace with A because
/// indexing maps are not identical
///
/// O(i, j) += A(i, j) + B(j) --> Output is used in computation.
/// This could bufferize inplace with A:
/// A(i, j) += O(i, j) + B(j)
/// However, we choose to bufferize inplace with O here, as there is no clear
/// benefit of choosing A. TODO: We may want to consider both options and make
/// an informed decision during analysis in the future.
static DenseMap<OpOperand *, OpResult> computeAliasingPairs(LinalgOp op) {
  DenseMap<OpOperand *, OpResult> mapping;
  for (OpResult opResult : op->getOpResults()) {
    // The i-th OpResult is tied to the i-th "out" tensor operand.
    OpOperand *tiedOperand =
        op.getOutputTensorOperands()[opResult.getResultNumber()];
    AffineMap outputIndexingMap = op.getTiedIndexingMap(tiedOperand);
    bool onlyParallelIterators = op.getNumParallelLoops() == op.getNumLoops();
    bool tiedOperandUsed = op.payloadUsesValueFromOperand(tiedOperand);

    // If the output arg is used in the computation or at least one iterator is
    // not parallel, try to bufferize inplace with the corresponding output
    // tensor.
    if (tiedOperandUsed || !onlyParallelIterators) {
      mapping[tiedOperand] = opResult;
      continue;
    }

    // Otherwise, try to bufferize inplace with one of the inputs. A candidate
    // must have the same type and indexing map as the output and actually be
    // read by the payload.
    OpOperand *chosenOperand = nullptr;
    for (OpOperand *opOperand : op.getInputTensorOperands()) {
      if (opOperand->get().getType() != opResult.getType())
        continue;
      if (!op.payloadUsesValueFromOperand(opOperand))
        continue;
      if (op.getTiedIndexingMap(opOperand) != outputIndexingMap)
        continue;
      // Skip if another OpResult already bufferizes aliases with this
      // OpOperand (each operand can back at most one result).
      if (mapping.count(opOperand))
        continue;
      assert(op.getTiedIndexingMap(opOperand).isProjectedPermutation() &&
             "expected projected permutation");
      chosenOperand = opOperand;
      break;
    }

    // No suitable input tensor found. Use output tensor.
    // TODO: This operand could bufferize inplace with OpOperands that have the
    // correct type, even if they are not used inside the computation.
    if (!chosenOperand)
      chosenOperand = tiedOperand;

    mapping[chosenOperand] = opResult;
  }
  return mapping;
}

/// Bufferization external model for Linalg structured ops (instantiated once
/// per concrete op type `OpTy`). Replaces the tensor op with a new op that
/// operates entirely on memrefs.
template <typename OpTy>
struct LinalgOpInterface
    : public BufferizableOpInterface::ExternalModel<LinalgOpInterface<OpTy>,
                                                    OpTy> {
  bool bufferizesToMemoryRead(Operation *op, OpOperand &opOperand,
                              const AnalysisState &state) const {
    // Operand is read if it is used in the computation.
    auto genericOp = cast<linalg::LinalgOp>(op);
    return genericOp.payloadUsesValueFromOperand(&opOperand);
  }

  bool bufferizesToMemoryWrite(Operation *op, OpOperand &opOperand,
                               const AnalysisState &state) const {
    // Operand is written to if it has an aliasing OpResult.
    auto bufferizableOp = cast<BufferizableOpInterface>(op);
    return !bufferizableOp.getAliasingOpResult(opOperand, state).empty();
  }

  SmallVector<OpOperand *>
  getAliasingOpOperand(Operation *op, OpResult opResult,
                       const AnalysisState &state) const {
    auto genericOp = cast<linalg::LinalgOp>(op);

    // By default, the i-th OpResult may alias with the i-th "out" tensor.
    if (state.getOptions().alwaysAliasingWithDest)
      return {genericOp.getOutputOperand(opResult.getResultNumber())};

    // We can try to be smart and alias in-place with an "in" tensor if the
    // corresponding "out" tensor is not used in the computation.
    // Aliasing OpOperand/OpResult pairs are computed by `computeAliasingPairs`.
    // NOTE(review): the pairs are recomputed on every query here and in
    // `getAliasingOpResult`; could be cached if this shows up in profiles.
    DenseMap<OpOperand *, OpResult> pairs = computeAliasingPairs(genericOp);
    for (OpOperand *opOperand : genericOp.getInputAndOutputOperands())
      if (pairs[opOperand] == opResult)
        return {opOperand};
    return {};
  }

  SmallVector<OpResult> getAliasingOpResult(Operation *op, OpOperand &opOperand,
                                            const AnalysisState &state) const {
    auto genericOp = cast<linalg::LinalgOp>(op);

    // By default, the i-th "out" tensor may alias with the i-th OpResult.
    if (state.getOptions().alwaysAliasingWithDest) {
      if (genericOp.isOutputTensor(&opOperand))
        return {genericOp.getTiedOpResult(&opOperand)};
      return {};
    }

    // We can try to be smart. See comment in `getAliasingOpOperand`.
    // Aliasing OpOperand/OpResult pairs are computed by `computeAliasingPairs`.
    DenseMap<OpOperand *, OpResult> pairs = computeAliasingPairs(genericOp);
    if (!pairs.count(&opOperand))
      return {};
    return {pairs[&opOperand]};
  }

  BufferRelation bufferRelation(Operation *op, OpResult opResult,
                                const AnalysisState &state) const {
    // Results reuse their aliasing operand's buffer exactly (no subview).
    return BufferRelation::Equivalent;
  }

  LogicalResult bufferize(Operation *op, RewriterBase &rewriter,
                          BufferizationState &state) const {
    return bufferizeLinalgOp(rewriter, cast<LinalgOp>(op), state);
  }
};

/// Helper structure that iterates over all LinalgOps in `Ops` and registers
/// the `BufferizableOpInterface` with each of them.
template <typename... Ops>
struct LinalgOpInterfaceHelper {
  static void registerOpInterface(MLIRContext *ctx) {
    // Expand the pack: attach one interface model per op type. The
    // initializer_list is only used to sequence the expansion; its value is
    // discarded.
    (void)std::initializer_list<int>{
        0, (Ops::template attachInterface<LinalgOpInterface<Ops>>(*ctx), 0)...};
  }
};
} // namespace

void mlir::linalg::registerBufferizableOpInterfaceExternalModels(
    DialectRegistry &registry) {
  registry.addExtension(+[](MLIRContext *ctx, linalg::LinalgDialect *dialect) {
    // Register all Linalg structured ops. `LinalgOp` is an interface and it is
    // not possible to attach an external interface to an existing interface.
    // Therefore, attach the `BufferizableOpInterface` to all ops one-by-one.
    LinalgOpInterfaceHelper<
#define GET_OP_LIST
#include "mlir/Dialect/Linalg/IR/LinalgStructuredOps.cpp.inc"
        >::registerOpInterface(ctx);
  });
}