1444822d7SSean Silva //===- Bufferize.cpp - Bufferization for `tensor` dialect ops -------------===// 2444822d7SSean Silva // 3444822d7SSean Silva // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. 4444822d7SSean Silva // See https://llvm.org/LICENSE.txt for license information. 5444822d7SSean Silva // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception 6444822d7SSean Silva // 7444822d7SSean Silva //===----------------------------------------------------------------------===// 8444822d7SSean Silva // 9444822d7SSean Silva // This file implements bufferization of `tensor` dialect ops 10444822d7SSean Silva // 11444822d7SSean Silva //===----------------------------------------------------------------------===// 12444822d7SSean Silva 13f89bb3c0SAlexander Belyaev #include "mlir/Dialect/Bufferization/Transforms/Bufferize.h" 14444822d7SSean Silva #include "PassDetail.h" 15a54f4eaeSMogball #include "mlir/Dialect/Arithmetic/IR/Arithmetic.h" 16daf18108SMatthias Springer #include "mlir/Dialect/Bufferization/IR/BufferizableOpInterface.h" 1757470abcSAlexander Belyaev #include "mlir/Dialect/Bufferization/IR/Bufferization.h" 18e2310704SJulian Gross #include "mlir/Dialect/MemRef/IR/MemRef.h" 19*8b68da2cSAlex Zinenko #include "mlir/Dialect/SCF/IR/SCF.h" 20444822d7SSean Silva #include "mlir/Dialect/Tensor/IR/Tensor.h" 21daf18108SMatthias Springer #include "mlir/Dialect/Tensor/Transforms/BufferizableOpInterfaceImpl.h" 22444822d7SSean Silva #include "mlir/Dialect/Tensor/Transforms/Passes.h" 23f77e9f87SAlexander Belyaev #include "mlir/IR/ImplicitLocOpBuilder.h" 24444822d7SSean Silva #include "mlir/Transforms/DialectConversion.h" 25444822d7SSean Silva 26444822d7SSean Silva using namespace mlir; 27daf18108SMatthias Springer using namespace bufferization; 28444822d7SSean Silva 29444822d7SSean Silva namespace { 30a82a19c1SAlexander Belyaev struct TensorBufferizePass : public TensorBufferizeBase<TensorBufferizePass> { runOnOperation__anone461fda80111::TensorBufferizePass3141574554SRiver Riddle void runOnOperation() override { 32cdb7675cSMatthias Springer BufferizationOptions options = getPartialBufferizationOptions(); 331534177fSMatthias Springer options.opFilter.allowDialect<tensor::TensorDialect>(); 34a82a19c1SAlexander Belyaev 35cdb7675cSMatthias Springer if (failed(bufferizeOp(getOperation(), options))) 36a82a19c1SAlexander Belyaev signalPassFailure(); 37a82a19c1SAlexander Belyaev } 38a82a19c1SAlexander Belyaev getDependentDialects__anone461fda80111::TensorBufferizePass39daf18108SMatthias Springer void getDependentDialects(DialectRegistry ®istry) const override { 40daf18108SMatthias Springer registry.insert<bufferization::BufferizationDialect, memref::MemRefDialect, 41daf18108SMatthias Springer tensor::TensorDialect, scf::SCFDialect, 42daf18108SMatthias Springer arith::ArithmeticDialect>(); 43daf18108SMatthias Springer tensor::registerBufferizableOpInterfaceExternalModels(registry); 44444822d7SSean Silva } 45daf18108SMatthias Springer }; 46daf18108SMatthias Springer } // namespace 47444822d7SSean Silva createTensorBufferizePass()48444822d7SSean Silvastd::unique_ptr<Pass> mlir::createTensorBufferizePass() { 49444822d7SSean Silva return std::make_unique<TensorBufferizePass>(); 50444822d7SSean Silva } 51