1 //===- TosaOptionalDecompositions.cpp -------------------------------------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 //
9 // Pass to apply the Tosa operations decompositions
10 // exposed as populate functions in
11 // include/mlir/Dialect/Tosa/Transforms/Passes.h
12 //
13 //===----------------------------------------------------------------------===//
14 
15 #include "mlir/Dialect/Tosa/IR/TosaOps.h"
16 #include "mlir/Dialect/Tosa/Transforms/PassDetail.h"
17 #include "mlir/Dialect/Tosa/Transforms/Passes.h"
18 #include "mlir/Pass/Pass.h"
19 #include "mlir/Transforms/GreedyPatternRewriteDriver.h"
20 
21 using namespace mlir;
22 
23 namespace {
24 
25 struct TosaOptionalDecompositions
26     : public TosaOptionalDecompositionsBase<TosaOptionalDecompositions> {
runOnOperation__anon1c4fdb2f0111::TosaOptionalDecompositions27   void runOnOperation() override {
28     auto *ctx = &getContext();
29     RewritePatternSet patterns(ctx);
30     auto func = getOperation();
31 
32     mlir::tosa::populateTosaDecomposeConv2D(ctx, patterns);
33     mlir::tosa::populateTosaDecomposeTransposeConv(ctx, patterns);
34     mlir::tosa::populateTosaDecomposeDepthwise(ctx, patterns);
35 
36     if (applyPatternsAndFoldGreedily(func, std::move(patterns)).failed())
37       signalPassFailure();
38   }
39 };
40 
41 } // namespace
42 
createTosaOptionalDecompositions()43 std::unique_ptr<Pass> mlir::tosa::createTosaOptionalDecompositions() {
44   return std::make_unique<TosaOptionalDecompositions>();
45 }
46