1 //====----- Bufferize.cpp - Bufferization of shape ops ---------*- C++-*--===// 2 // 3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. 4 // See https://llvm.org/LICENSE.txt for license information. 5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception 6 // 7 //===----------------------------------------------------------------------===// 8 9 #include "mlir/Dialect/Bufferization/Transforms/Bufferize.h" 10 #include "PassDetail.h" 11 #include "mlir/Dialect/Bufferization/IR/BufferizableOpInterface.h" 12 #include "mlir/Dialect/Bufferization/IR/Bufferization.h" 13 #include "mlir/Dialect/MemRef/IR/MemRef.h" 14 #include "mlir/Dialect/Shape/IR/Shape.h" 15 #include "mlir/Dialect/Shape/Transforms/BufferizableOpInterfaceImpl.h" 16 #include "mlir/Dialect/Shape/Transforms/Passes.h" 17 #include "mlir/Pass/Pass.h" 18 19 using namespace mlir; 20 using namespace bufferization; 21 22 namespace { 23 struct ShapeBufferizePass : public ShapeBufferizeBase<ShapeBufferizePass> { runOnOperation__anon10b1a01e0111::ShapeBufferizePass24 void runOnOperation() override { 25 BufferizationOptions options = getPartialBufferizationOptions(); 26 options.opFilter.allowDialect<shape::ShapeDialect>(); 27 28 if (failed(bufferizeOp(getOperation(), options))) 29 signalPassFailure(); 30 } 31 getDependentDialects__anon10b1a01e0111::ShapeBufferizePass32 void getDependentDialects(DialectRegistry ®istry) const override { 33 registry.insert<bufferization::BufferizationDialect, memref::MemRefDialect, 34 shape::ShapeDialect>(); 35 shape::registerBufferizableOpInterfaceExternalModels(registry); 36 } 37 }; 38 } // namespace 39 createShapeBufferizePass()40std::unique_ptr<OperationPass<func::FuncOp>> mlir::createShapeBufferizePass() { 41 return std::make_unique<ShapeBufferizePass>(); 42 } 43