1 //===- BufferizableOpInterfaceImpl.cpp - Impl. of BufferizableOpInterface -===// 2 // 3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. 4 // See https://llvm.org/LICENSE.txt for license information. 5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception 6 // 7 //===----------------------------------------------------------------------===// 8 9 #include "mlir/Dialect/Linalg/Transforms/BufferizableOpInterfaceImpl.h" 10 #include "mlir/Dialect/Bufferization/IR/BufferizableOpInterface.h" 11 #include "mlir/Dialect/Bufferization/IR/Bufferization.h" 12 #include "mlir/Dialect/Linalg/IR/Linalg.h" 13 #include "mlir/Dialect/Tensor/IR/Tensor.h" 14 #include "mlir/IR/Dialect.h" 15 #include "mlir/IR/Operation.h" 16 17 using namespace mlir; 18 using namespace linalg; 19 using namespace mlir::bufferization; 20 21 namespace { 22 23 // TODO: Ops in the linalg dialect can directly implement this interface. 24 25 /// Generic conversion for any LinalgOp on tensors. 26 static LogicalResult bufferizeLinalgOp(RewriterBase &rewriter, LinalgOp op, 27 BufferizationState &state) { 28 // Take a guard before anything else. 29 OpBuilder::InsertionGuard g(rewriter); 30 rewriter.setInsertionPoint(op); 31 32 // Nothing to do. This op is already bufferized. 33 if (op.hasBufferSemantics()) 34 return success(); 35 36 // Ensure op has only tensors. Allow mixed tensor-buffer mode on a per-need 37 // basis. 38 if (!op.hasTensorSemantics()) 39 return op->emitError() << "op does not have tensor semantics"; 40 41 // New input operands for the cloned op. 42 SmallVector<Value> newInputBuffers; 43 newInputBuffers.reserve(op.getNumInputs()); 44 for (OpOperand *opOperand : op.getInputOperands()) { 45 if (op.isScalar(opOperand)) { 46 newInputBuffers.push_back(opOperand->get()); 47 continue; 48 } 49 // Input operands are never written to. 50 newInputBuffers.push_back(*state.getBuffer( 51 rewriter, *opOperand, 52 BufferizationState::ForceInPlacability::FORCE_INPLACE)); 53 } 54 55 // New output operands for the cloned op. 56 SmallVector<Value> newOutputBuffers; 57 for (OpResult opResult : op->getOpResults()) { 58 SmallVector<OpOperand *> aliasingOpOperands = 59 state.getAnalysisState().getAliasingOpOperand(opResult); 60 assert(aliasingOpOperands.size() == 1 && "expected 1 OpOperand"); 61 FailureOr<Value> resultBuffer = 62 state.getBuffer(rewriter, *aliasingOpOperands.front()); 63 if (failed(resultBuffer)) 64 return failure(); 65 newOutputBuffers.push_back(*resultBuffer); 66 } 67 68 // Merge input/output operands. 69 SmallVector<Value> newOperands = newInputBuffers; 70 newOperands.append(newOutputBuffers.begin(), newOutputBuffers.end()); 71 72 // Set insertion point now that potential alloc/dealloc are introduced. 73 rewriter.setInsertionPoint(op); 74 // Clone the op, but use the new operands. Move the existing block into the 75 // new op. Since the new op does not have any tensor results, it does not 76 // return anything. 77 assert(op->getNumRegions() == 1 && "expected that op has 1 region"); 78 auto newOp = cast<LinalgOp>(op.cloneWithoutRegions( 79 rewriter, op.getLoc(), /*resultTypes=*/TypeRange{}, newOperands)); 80 rewriter.inlineRegionBefore(op->getRegion(0), newOp->getRegion(0), 81 newOp->getRegion(0).begin()); 82 83 // Replace the results of the old op with the new output buffers. 84 replaceOpWithBufferizedValues(rewriter, op, newOutputBuffers); 85 86 return success(); 87 } 88 89 /// Bufferization of linalg.generic. Replace with a new linalg.generic that 90 /// operates entirely on memrefs. 91 template <typename OpTy> 92 struct LinalgOpInterface 93 : public BufferizableOpInterface::ExternalModel<LinalgOpInterface<OpTy>, 94 OpTy> { 95 bool bufferizesToMemoryRead(Operation *op, OpOperand &opOperand, 96 const AnalysisState &state) const { 97 // Operand is read if it is used in the computation. 98 auto genericOp = cast<linalg::LinalgOp>(op); 99 return genericOp.payloadUsesValueFromOperand(&opOperand); 100 } 101 102 bool bufferizesToMemoryWrite(Operation *op, OpOperand &opOperand, 103 const AnalysisState &state) const { 104 // Operand is written to if it has an aliasing OpResult. 105 auto bufferizableOp = cast<BufferizableOpInterface>(op); 106 return !bufferizableOp.getAliasingOpResult(opOperand, state).empty(); 107 } 108 109 SmallVector<OpOperand *> 110 getAliasingOpOperand(Operation *op, OpResult opResult, 111 const AnalysisState &state) const { 112 auto genericOp = cast<linalg::LinalgOp>(op); 113 114 // The i-th OpResult may alias with the i-th "out" tensor. 115 return {genericOp.getOutputOperand(opResult.getResultNumber())}; 116 } 117 118 SmallVector<OpResult> getAliasingOpResult(Operation *op, OpOperand &opOperand, 119 const AnalysisState &state) const { 120 auto genericOp = cast<linalg::LinalgOp>(op); 121 122 // The i-th "out" tensor may alias with the i-th OpResult. 123 if (genericOp.isOutputTensor(&opOperand)) 124 return {genericOp.getTiedOpResult(&opOperand)}; 125 return {}; 126 } 127 128 BufferRelation bufferRelation(Operation *op, OpResult opResult, 129 const AnalysisState &state) const { 130 return BufferRelation::Equivalent; 131 } 132 133 LogicalResult bufferize(Operation *op, RewriterBase &rewriter, 134 BufferizationState &state) const { 135 return bufferizeLinalgOp(rewriter, cast<LinalgOp>(op), state); 136 } 137 }; 138 139 /// Helper structure that iterates over all LinalgOps in `OpTys` and registers 140 /// the `BufferizableOpInterface` with each of them. 141 template <typename... Ops> 142 struct LinalgOpInterfaceHelper { 143 static void registerOpInterface(MLIRContext *ctx) { 144 (void)std::initializer_list<int>{ 145 0, (Ops::template attachInterface<LinalgOpInterface<Ops>>(*ctx), 0)...}; 146 } 147 }; 148 } // namespace 149 150 void mlir::linalg::registerBufferizableOpInterfaceExternalModels( 151 DialectRegistry ®istry) { 152 registry.addExtension(+[](MLIRContext *ctx, linalg::LinalgDialect *dialect) { 153 // Register all Linalg structured ops. `LinalgOp` is an interface and it is 154 // not possible to attach an external interface to an existing interface. 155 // Therefore, attach the `BufferizableOpInterface` to all ops one-by-one. 156 LinalgOpInterfaceHelper< 157 #define GET_OP_LIST 158 #include "mlir/Dialect/Linalg/IR/LinalgStructuredOps.cpp.inc" 159 >::registerOpInterface(ctx); 160 }); 161 } 162