1 //===- BufferizableOpInterfaceImpl.cpp - Impl. of BufferizableOpInterface -===// 2 // 3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. 4 // See https://llvm.org/LICENSE.txt for license information. 5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception 6 // 7 //===----------------------------------------------------------------------===// 8 9 #include "mlir/Dialect/Linalg/Transforms/BufferizableOpInterfaceImpl.h" 10 #include "mlir/Dialect/Bufferization/IR/BufferizableOpInterface.h" 11 #include "mlir/Dialect/Bufferization/IR/Bufferization.h" 12 #include "mlir/Dialect/Linalg/IR/Linalg.h" 13 #include "mlir/Dialect/Tensor/IR/Tensor.h" 14 #include "mlir/IR/Dialect.h" 15 #include "mlir/IR/Operation.h" 16 17 using namespace mlir; 18 using namespace linalg; 19 using namespace mlir::bufferization; 20 21 namespace { 22 23 /// Generic conversion for any LinalgOp on tensors. 24 static LogicalResult bufferizeLinalgOp(RewriterBase &rewriter, LinalgOp op, 25 const BufferizationOptions &options) { 26 // Take a guard before anything else. 27 OpBuilder::InsertionGuard g(rewriter); 28 rewriter.setInsertionPoint(op); 29 30 // Nothing to do. This op is already bufferized. 31 if (op.hasBufferSemantics()) 32 return success(); 33 34 // Ensure op has only tensors. Allow mixed tensor-buffer mode on a per-need 35 // basis. 36 if (!op.hasTensorSemantics()) 37 return op->emitError() << "op does not have tensor semantics"; 38 39 // New input operands for the cloned op. 40 SmallVector<Value> newInputBuffers; 41 newInputBuffers.reserve(op.getNumInputs()); 42 for (OpOperand *opOperand : op.getInputOperands()) { 43 if (op.isScalar(opOperand)) { 44 newInputBuffers.push_back(opOperand->get()); 45 continue; 46 } 47 newInputBuffers.push_back(getBuffer(rewriter, opOperand->get(), options)); 48 } 49 50 // New output operands for the cloned op. 51 SmallVector<Value> newOutputBuffers; 52 for (OpResult opResult : op->getOpResults()) { 53 OpOperand *opOperand = op.getOutputOperand(opResult.getResultNumber()); 54 Value resultBuffer = getBuffer(rewriter, opOperand->get(), options); 55 newOutputBuffers.push_back(resultBuffer); 56 } 57 58 // Merge input/output operands. 59 SmallVector<Value> newOperands = newInputBuffers; 60 newOperands.append(newOutputBuffers.begin(), newOutputBuffers.end()); 61 62 // Set insertion point now that potential alloc/dealloc are introduced. 63 rewriter.setInsertionPoint(op); 64 // Clone the op, but use the new operands. Move the existing block into the 65 // new op. Since the new op does not have any tensor results, it does not 66 // return anything. 67 assert(op->getNumRegions() == 1 && "expected that op has 1 region"); 68 auto newOp = cast<LinalgOp>(op.cloneWithoutRegions( 69 rewriter, op.getLoc(), /*resultTypes=*/TypeRange{}, newOperands)); 70 rewriter.inlineRegionBefore(op->getRegion(0), newOp->getRegion(0), 71 newOp->getRegion(0).begin()); 72 73 // Replace the results of the old op with the new output buffers. 74 replaceOpWithBufferizedValues(rewriter, op, newOutputBuffers); 75 76 return success(); 77 } 78 79 /// Bufferization of linalg.generic. Replace with a new linalg.generic that 80 /// operates entirely on memrefs. 81 template <typename OpTy> 82 struct LinalgOpInterface 83 : public BufferizableOpInterface::ExternalModel<LinalgOpInterface<OpTy>, 84 OpTy> { 85 bool bufferizesToMemoryRead(Operation *op, OpOperand &opOperand, 86 const AnalysisState &state) const { 87 // Operand is read if it is used in the computation. 88 auto genericOp = cast<linalg::LinalgOp>(op); 89 return genericOp.payloadUsesValueFromOperand(&opOperand); 90 } 91 92 bool bufferizesToMemoryWrite(Operation *op, OpOperand &opOperand, 93 const AnalysisState &state) const { 94 // Operand is written to if it has an aliasing OpResult. 95 auto bufferizableOp = cast<BufferizableOpInterface>(op); 96 return !bufferizableOp.getAliasingOpResult(opOperand, state).empty(); 97 } 98 99 SmallVector<OpOperand *> 100 getAliasingOpOperand(Operation *op, OpResult opResult, 101 const AnalysisState &state) const { 102 auto genericOp = cast<linalg::LinalgOp>(op); 103 104 // The i-th OpResult may alias with the i-th "out" tensor. 105 return {genericOp.getOutputOperand(opResult.getResultNumber())}; 106 } 107 108 SmallVector<OpResult> getAliasingOpResult(Operation *op, OpOperand &opOperand, 109 const AnalysisState &state) const { 110 auto genericOp = cast<linalg::LinalgOp>(op); 111 112 // The i-th "out" tensor may alias with the i-th OpResult. 113 if (genericOp.isOutputTensor(&opOperand)) 114 return {genericOp.getTiedOpResult(&opOperand)}; 115 return {}; 116 } 117 118 BufferRelation bufferRelation(Operation *op, OpResult opResult, 119 const AnalysisState &state) const { 120 return BufferRelation::Equivalent; 121 } 122 123 LogicalResult bufferize(Operation *op, RewriterBase &rewriter, 124 const BufferizationOptions &options) const { 125 return bufferizeLinalgOp(rewriter, cast<LinalgOp>(op), options); 126 } 127 }; 128 129 /// Helper structure that iterates over all LinalgOps in `OpTys` and registers 130 /// the `BufferizableOpInterface` with each of them. 131 template <typename... Ops> 132 struct LinalgOpInterfaceHelper { 133 static void registerOpInterface(MLIRContext *ctx) { 134 (void)std::initializer_list<int>{ 135 0, (Ops::template attachInterface<LinalgOpInterface<Ops>>(*ctx), 0)...}; 136 } 137 }; 138 } // namespace 139 140 void mlir::linalg::registerBufferizableOpInterfaceExternalModels( 141 DialectRegistry ®istry) { 142 registry.addExtension(+[](MLIRContext *ctx, linalg::LinalgDialect *dialect) { 143 // Register all Linalg structured ops. `LinalgOp` is an interface and it is 144 // not possible to attach an external interface to an existing interface. 145 // Therefore, attach the `BufferizableOpInterface` to all ops one-by-one. 146 LinalgOpInterfaceHelper< 147 #define GET_OP_LIST 148 #include "mlir/Dialect/Linalg/IR/LinalgStructuredOps.cpp.inc" 149 >::registerOpInterface(ctx); 150 }); 151 } 152