//===- BufferizableOpInterfaceImpl.cpp - Impl. of BufferizableOpInterface -===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//

#include "mlir/Dialect/Linalg/Transforms/BufferizableOpInterfaceImpl.h"
#include "mlir/Dialect/Bufferization/IR/BufferizableOpInterface.h"
#include "mlir/Dialect/Bufferization/IR/Bufferization.h"
#include "mlir/Dialect/Linalg/IR/Linalg.h"
#include "mlir/Dialect/Tensor/IR/Tensor.h"
#include "mlir/IR/Dialect.h"
#include "mlir/IR/Operation.h"

using namespace mlir;
using namespace linalg;
using namespace mlir::bufferization;

namespace {

// TODO: Ops in the linalg dialect can directly implement this interface.

/// Generic conversion for any LinalgOp on tensors.
static LogicalResult bufferizeLinalgOp(RewriterBase &rewriter, LinalgOp op,
                                       BufferizationState &state) {
  // Take a guard before anything else.
  OpBuilder::InsertionGuard g(rewriter);
  rewriter.setInsertionPoint(op);

  // Nothing to do. This op is already bufferized.
  if (op.hasBufferSemantics())
    return success();

  // Ensure op has only tensors. Allow mixed tensor-buffer mode on a per-need
  // basis.
  if (!op.hasTensorSemantics())
    return op->emitError() << "op does not have tensor semantics";

  // New input operands for the cloned op.
  SmallVector<Value> newInputBuffers;
  newInputBuffers.reserve(op.getNumInputs());
  for (OpOperand *opOperand : op.getInputOperands()) {
    if (op.isScalar(opOperand)) {
      newInputBuffers.push_back(opOperand->get());
      continue;
    }
    newInputBuffers.push_back(state.getBuffer(rewriter, opOperand->get()));
  }

  // New output operands for the cloned op.
  SmallVector<Value> newOutputBuffers;
  for (OpResult opResult : op->getOpResults()) {
    OpOperand *opOperand = op.getOutputOperand(opResult.getResultNumber());
    Value resultBuffer = state.getBuffer(rewriter, opOperand->get());
    newOutputBuffers.push_back(resultBuffer);
  }

  // Merge input/output operands.
  SmallVector<Value> newOperands = newInputBuffers;
  newOperands.append(newOutputBuffers.begin(), newOutputBuffers.end());

  // Set insertion point now that potential alloc/dealloc are introduced.
  rewriter.setInsertionPoint(op);
  // Clone the op, but use the new operands. Move the existing block into the
  // new op. Since the new op does not have any tensor results, it does not
  // return anything.
  assert(op->getNumRegions() == 1 && "expected that op has 1 region");
  auto newOp = cast<LinalgOp>(op.cloneWithoutRegions(
      rewriter, op.getLoc(), /*resultTypes=*/TypeRange{}, newOperands));
  rewriter.inlineRegionBefore(op->getRegion(0), newOp->getRegion(0),
                              newOp->getRegion(0).begin());

  // Replace the results of the old op with the new output buffers.
  replaceOpWithBufferizedValues(rewriter, op, newOutputBuffers);

  return success();
}
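
// Example (an illustrative sketch, not from the original file; op names and
// SSA values are made up): given a LinalgOp on tensors such as
//
//   %r = linalg.generic {indexing_maps = [...], iterator_types = [...]}
//          ins(%A : tensor<4xf32>) outs(%B : tensor<4xf32>) { ... }
//          -> tensor<4xf32>
//
// bufferizeLinalgOp clones it into a variant that operates on memrefs and
// returns no results:
//
//   linalg.generic {indexing_maps = [...], iterator_types = [...]}
//     ins(%A_buf : memref<4xf32>) outs(%B_buf : memref<4xf32>) { ... }
//
// All uses of the tensor result %r are then replaced with the output buffer
// %B_buf via replaceOpWithBufferizedValues.
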
/// Bufferization of linalg ops. Replace with a new linalg op that operates
/// entirely on memrefs.
template <typename OpTy>
struct LinalgOpInterface
    : public BufferizableOpInterface::ExternalModel<LinalgOpInterface<OpTy>,
                                                    OpTy> {
  bool bufferizesToMemoryRead(Operation *op, OpOperand &opOperand,
                              const AnalysisState &state) const {
    // Operand is read if it is used in the computation.
    auto genericOp = cast<linalg::LinalgOp>(op);
    return genericOp.payloadUsesValueFromOperand(&opOperand);
  }

  bool bufferizesToMemoryWrite(Operation *op, OpOperand &opOperand,
                               const AnalysisState &state) const {
    // Operand is written to if it has an aliasing OpResult.
    auto bufferizableOp = cast<BufferizableOpInterface>(op);
    return !bufferizableOp.getAliasingOpResult(opOperand, state).empty();
  }

  SmallVector<OpOperand *>
  getAliasingOpOperand(Operation *op, OpResult opResult,
                       const AnalysisState &state) const {
    auto genericOp = cast<linalg::LinalgOp>(op);

    // The i-th OpResult may alias with the i-th "out" tensor.
    return {genericOp.getOutputOperand(opResult.getResultNumber())};
  }

  SmallVector<OpResult> getAliasingOpResult(Operation *op, OpOperand &opOperand,
                                            const AnalysisState &state) const {
    auto genericOp = cast<linalg::LinalgOp>(op);

    // The i-th "out" tensor may alias with the i-th OpResult.
    if (genericOp.isOutputTensor(&opOperand))
      return {genericOp.getTiedOpResult(&opOperand)};
    return {};
  }

  BufferRelation bufferRelation(Operation *op, OpResult opResult,
                                const AnalysisState &state) const {
    return BufferRelation::Equivalent;
  }

  LogicalResult bufferize(Operation *op, RewriterBase &rewriter,
                          BufferizationState &state) const {
    return bufferizeLinalgOp(rewriter, cast<LinalgOp>(op), state);
  }
};

/// Helper structure that iterates over all LinalgOps in `Ops` and registers
/// the `BufferizableOpInterface` with each of them.
template <typename... Ops>
struct LinalgOpInterfaceHelper {
  static void registerOpInterface(MLIRContext *ctx) {
    (void)std::initializer_list<int>{
        0, (Ops::template attachInterface<LinalgOpInterface<Ops>>(*ctx), 0)...};
  }
};
} // namespace

void mlir::linalg::registerBufferizableOpInterfaceExternalModels(
    DialectRegistry &registry) {
  registry.addExtension(+[](MLIRContext *ctx, linalg::LinalgDialect *dialect) {
    // Register all Linalg structured ops. `LinalgOp` is an interface and it is
    // not possible to attach an external interface to an existing interface.
    // Therefore, attach the `BufferizableOpInterface` to all ops one-by-one.
    LinalgOpInterfaceHelper<
#define GET_OP_LIST
#include "mlir/Dialect/Linalg/IR/LinalgStructuredOps.cpp.inc"
        >::registerOpInterface(ctx);
  });
}
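
// Example usage (a minimal sketch of a typical MLIR tool setup, not part of
// the original file): external models must be registered on the registry
// before the context loads the Linalg dialect.
//
//   DialectRegistry registry;
//   registry.insert<linalg::LinalgDialect>();
//   linalg::registerBufferizableOpInterfaceExternalModels(registry);
//   MLIRContext ctx(registry);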