1 //====----- Bufferize.cpp - Bufferization of shape ops  ---------*- C++-*--===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 
9 #include "mlir/Dialect/Bufferization/Transforms/Bufferize.h"
10 #include "PassDetail.h"
11 #include "mlir/Dialect/Bufferization/IR/BufferizableOpInterface.h"
12 #include "mlir/Dialect/Bufferization/IR/Bufferization.h"
13 #include "mlir/Dialect/MemRef/IR/MemRef.h"
14 #include "mlir/Dialect/Shape/IR/Shape.h"
15 #include "mlir/Dialect/Shape/Transforms/BufferizableOpInterfaceImpl.h"
16 #include "mlir/Dialect/Shape/Transforms/Passes.h"
17 #include "mlir/Pass/Pass.h"
18 
19 using namespace mlir;
20 using namespace bufferization;
21 
22 namespace {
23 struct ShapeBufferizePass : public ShapeBufferizeBase<ShapeBufferizePass> {
runOnOperation__anon10b1a01e0111::ShapeBufferizePass24   void runOnOperation() override {
25     BufferizationOptions options = getPartialBufferizationOptions();
26     options.opFilter.allowDialect<shape::ShapeDialect>();
27 
28     if (failed(bufferizeOp(getOperation(), options)))
29       signalPassFailure();
30   }
31 
getDependentDialects__anon10b1a01e0111::ShapeBufferizePass32   void getDependentDialects(DialectRegistry &registry) const override {
33     registry.insert<bufferization::BufferizationDialect, memref::MemRefDialect,
34                     shape::ShapeDialect>();
35     shape::registerBufferizableOpInterfaceExternalModels(registry);
36   }
37 };
38 } // namespace
39 
createShapeBufferizePass()40 std::unique_ptr<OperationPass<func::FuncOp>> mlir::createShapeBufferizePass() {
41   return std::make_unique<ShapeBufferizePass>();
42 }
43