//===- BufferizableOpInterfaceImpl.cpp - Impl. of BufferizableOpInterface -===// // // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. // See https://llvm.org/LICENSE.txt for license information. // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception // //===----------------------------------------------------------------------===// #include "mlir/Dialect/Vector/Transforms/BufferizableOpInterfaceImpl.h" #include "mlir/Dialect/Bufferization/IR/BufferizableOpInterface.h" #include "mlir/Dialect/Vector/IR/VectorOps.h" #include "mlir/IR/Dialect.h" #include "mlir/IR/Operation.h" using namespace mlir; using namespace mlir::bufferization; using namespace mlir::vector; namespace mlir { namespace vector { namespace { /// Bufferization of vector.transfer_read. Replaced with a new /// vector.transfer_read that operates on a memref. struct TransferReadOpInterface : public BufferizableOpInterface::ExternalModel { bool bufferizesToMemoryRead(Operation *op, OpOperand &opOperand, const AnalysisState &state) const { assert(opOperand.get().getType().isa() && "only tensor types expected"); return true; } bool bufferizesToMemoryWrite(Operation *op, OpOperand &opOperand, const AnalysisState &state) const { assert(opOperand.get().getType().isa() && "only tensor types expected"); return false; } SmallVector getAliasingOpResult(Operation *op, OpOperand &opOperand, const AnalysisState &state) const { return {}; } LogicalResult bufferize(Operation *op, RewriterBase &rewriter, const BufferizationOptions &options) const { auto readOp = cast(op); assert(readOp.getShapedType().isa() && "only tensor types expected"); FailureOr buffer = getBuffer(rewriter, readOp.getSource(), options); if (failed(buffer)) return failure(); replaceOpWithNewBufferizedOp( rewriter, readOp, readOp.getVectorType(), *buffer, readOp.getIndices(), readOp.getPermutationMap(), readOp.getPadding(), readOp.getMask(), readOp.getInBoundsAttr()); return success(); } }; /// Bufferization of vector.transfer_write. Replace with a new /// vector.transfer_write that operates on a memref. struct TransferWriteOpInterface : public BufferizableOpInterface::ExternalModel { bool bufferizesToMemoryRead(Operation *op, OpOperand &opOperand, const AnalysisState &state) const { assert(opOperand.get().getType().isa() && "only tensor types expected"); return true; } bool bufferizesToMemoryWrite(Operation *op, OpOperand &opOperand, const AnalysisState &state) const { assert(opOperand.get().getType().isa() && "only tensor types expected"); return true; } SmallVector getAliasingOpResult(Operation *op, OpOperand &opOperand, const AnalysisState &state) const { assert(opOperand.get().getType().isa() && "only tensor types expected"); return {op->getOpResult(0)}; } BufferRelation bufferRelation(Operation *op, OpResult opResult, const AnalysisState &state) const { return BufferRelation::Equivalent; } LogicalResult bufferize(Operation *op, RewriterBase &rewriter, const BufferizationOptions &options) const { auto writeOp = cast(op); assert(writeOp.getShapedType().isa() && "only tensor types expected"); // Create a new transfer_write on buffer that doesn't have a return value. FailureOr resultBuffer = getBuffer(rewriter, writeOp.getSource(), options); if (failed(resultBuffer)) return failure(); rewriter.create( writeOp.getLoc(), writeOp.getVector(), *resultBuffer, writeOp.getIndices(), writeOp.getPermutationMapAttr(), writeOp.getMask(), writeOp.getInBoundsAttr()); replaceOpWithBufferizedValues(rewriter, op, *resultBuffer); return success(); } }; } // namespace } // namespace vector } // namespace mlir void mlir::vector::registerBufferizableOpInterfaceExternalModels( DialectRegistry ®istry) { registry.addExtension(+[](MLIRContext *ctx, vector::VectorDialect *dialect) { TransferReadOp::attachInterface(*ctx); TransferWriteOp::attachInterface(*ctx); }); }