1 //===- BufferizableOpInterfaceImpl.cpp - Impl. of BufferizableOpInterface -===// 2 // 3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. 4 // See https://llvm.org/LICENSE.txt for license information. 5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception 6 // 7 //===----------------------------------------------------------------------===// 8 9 #include "mlir/Dialect/Linalg/Transforms/BufferizableOpInterfaceImpl.h" 10 #include "mlir/Dialect/Bufferization/IR/BufferizableOpInterface.h" 11 #include "mlir/Dialect/Bufferization/IR/Bufferization.h" 12 #include "mlir/Dialect/Linalg/IR/Linalg.h" 13 #include "mlir/Dialect/Tensor/IR/Tensor.h" 14 #include "mlir/IR/Dialect.h" 15 #include "mlir/IR/Operation.h" 16 17 using namespace mlir; 18 using namespace linalg; 19 using namespace mlir::bufferization; 20 21 namespace { 22 23 /// Generic conversion for any LinalgOp on tensors. 24 static LogicalResult bufferizeLinalgOp(RewriterBase &rewriter, LinalgOp op, 25 const BufferizationOptions &options) { 26 // Take a guard before anything else. 27 OpBuilder::InsertionGuard g(rewriter); 28 rewriter.setInsertionPoint(op); 29 30 // Nothing to do. This op is already bufferized. 31 if (op.hasBufferSemantics()) 32 return success(); 33 34 // Ensure op has only tensors. Allow mixed tensor-buffer mode on a per-need 35 // basis. 36 if (!op.hasTensorSemantics()) 37 return op->emitError() << "op does not have tensor semantics"; 38 39 // New input operands for the cloned op. 40 SmallVector<Value> newInputBuffers; 41 newInputBuffers.reserve(op.getNumInputs()); 42 for (OpOperand *opOperand : op.getInputOperands()) { 43 if (op.isScalar(opOperand)) { 44 newInputBuffers.push_back(opOperand->get()); 45 continue; 46 } 47 FailureOr<Value> buffer = getBuffer(rewriter, opOperand->get(), options); 48 if (failed(buffer)) 49 return failure(); 50 newInputBuffers.push_back(*buffer); 51 } 52 53 // New output operands for the cloned op. 54 SmallVector<Value> newOutputBuffers; 55 for (OpResult opResult : op->getOpResults()) { 56 OpOperand *opOperand = op.getOutputOperand(opResult.getResultNumber()); 57 FailureOr<Value> resultBuffer = 58 getBuffer(rewriter, opOperand->get(), options); 59 if (failed(resultBuffer)) 60 return failure(); 61 newOutputBuffers.push_back(*resultBuffer); 62 } 63 64 // Merge input/output operands. 65 SmallVector<Value> newOperands = newInputBuffers; 66 newOperands.append(newOutputBuffers.begin(), newOutputBuffers.end()); 67 68 // Set insertion point now that potential alloc/dealloc are introduced. 69 rewriter.setInsertionPoint(op); 70 // Clone the op, but use the new operands. Move the existing block into the 71 // new op. Since the new op does not have any tensor results, it does not 72 // return anything. 73 assert(op->getNumRegions() == 1 && "expected that op has 1 region"); 74 auto newOp = cast<LinalgOp>(op.cloneWithoutRegions( 75 rewriter, op.getLoc(), /*resultTypes=*/TypeRange{}, newOperands)); 76 rewriter.inlineRegionBefore(op->getRegion(0), newOp->getRegion(0), 77 newOp->getRegion(0).begin()); 78 79 // Replace the results of the old op with the new output buffers. 80 replaceOpWithBufferizedValues(rewriter, op, newOutputBuffers); 81 82 return success(); 83 } 84 85 /// Bufferization of linalg.generic. Replace with a new linalg.generic that 86 /// operates entirely on memrefs. 87 template <typename OpTy> 88 struct LinalgOpInterface 89 : public BufferizableOpInterface::ExternalModel<LinalgOpInterface<OpTy>, 90 OpTy> { 91 bool bufferizesToMemoryRead(Operation *op, OpOperand &opOperand, 92 const AnalysisState &state) const { 93 // Operand is read if it is used in the computation. 94 auto genericOp = cast<linalg::LinalgOp>(op); 95 return genericOp.payloadUsesValueFromOperand(&opOperand); 96 } 97 98 bool bufferizesToMemoryWrite(Operation *op, OpOperand &opOperand, 99 const AnalysisState &state) const { 100 // Operand is written to if it has an aliasing OpResult. 101 auto bufferizableOp = cast<BufferizableOpInterface>(op); 102 return !bufferizableOp.getAliasingOpResult(opOperand, state).empty(); 103 } 104 105 SmallVector<OpOperand *> 106 getAliasingOpOperand(Operation *op, OpResult opResult, 107 const AnalysisState &state) const { 108 auto genericOp = cast<linalg::LinalgOp>(op); 109 110 // The i-th OpResult may alias with the i-th "out" tensor. 111 return {genericOp.getOutputOperand(opResult.getResultNumber())}; 112 } 113 114 SmallVector<OpResult> getAliasingOpResult(Operation *op, OpOperand &opOperand, 115 const AnalysisState &state) const { 116 auto genericOp = cast<linalg::LinalgOp>(op); 117 118 // The i-th "out" tensor may alias with the i-th OpResult. 119 if (genericOp.isOutputTensor(&opOperand)) 120 return {genericOp.getTiedOpResult(&opOperand)}; 121 return {}; 122 } 123 124 BufferRelation bufferRelation(Operation *op, OpResult opResult, 125 const AnalysisState &state) const { 126 return BufferRelation::Equivalent; 127 } 128 129 LogicalResult bufferize(Operation *op, RewriterBase &rewriter, 130 const BufferizationOptions &options) const { 131 return bufferizeLinalgOp(rewriter, cast<LinalgOp>(op), options); 132 } 133 }; 134 135 /// Helper structure that iterates over all LinalgOps in `OpTys` and registers 136 /// the `BufferizableOpInterface` with each of them. 137 template <typename... Ops> 138 struct LinalgOpInterfaceHelper { 139 static void registerOpInterface(MLIRContext *ctx) { 140 (void)std::initializer_list<int>{ 141 0, (Ops::template attachInterface<LinalgOpInterface<Ops>>(*ctx), 0)...}; 142 } 143 }; 144 } // namespace 145 146 void mlir::linalg::registerBufferizableOpInterfaceExternalModels( 147 DialectRegistry ®istry) { 148 registry.addExtension(+[](MLIRContext *ctx, linalg::LinalgDialect *dialect) { 149 // Register all Linalg structured ops. `LinalgOp` is an interface and it is 150 // not possible to attach an external interface to an existing interface. 151 // Therefore, attach the `BufferizableOpInterface` to all ops one-by-one. 152 LinalgOpInterfaceHelper< 153 #define GET_OP_LIST 154 #include "mlir/Dialect/Linalg/IR/LinalgStructuredOps.cpp.inc" 155 >::registerOpInterface(ctx); 156 }); 157 } 158