1 //===- TosaLayerwiseConstantFoldPass.cpp ----------------------------------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 //
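// This file implements constant folding transformations on TOSA operations by
// greedily applying the TOSA constant-transpose folding patterns together with
// the canonicalization patterns of the TOSA ops.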
10 //
11 //===----------------------------------------------------------------------===//
12 
13 #include "mlir/Dialect/Tosa/IR/TosaOps.h"
14 #include "mlir/Dialect/Tosa/Transforms/PassDetail.h"
15 #include "mlir/Dialect/Tosa/Transforms/Passes.h"
16 #include "mlir/Pass/Pass.h"
17 #include "mlir/Transforms/GreedyPatternRewriteDriver.h"
18 
19 using namespace mlir;
20 using namespace mlir::tosa;
21 
22 namespace {
23 
24 template <typename... Args>
addOpsCanonicalizations(MLIRContext * ctx,RewritePatternSet & patterns)25 void addOpsCanonicalizations(MLIRContext *ctx, RewritePatternSet &patterns) {
26   (void)std::initializer_list<int>{
27       0, (Args::getCanonicalizationPatterns(patterns, ctx), 0)...};
28 }
29 
populateTosaOpsCanonicalizationPatterns(MLIRContext * ctx,RewritePatternSet & patterns)30 void populateTosaOpsCanonicalizationPatterns(MLIRContext *ctx,
31                                              RewritePatternSet &patterns) {
32   addOpsCanonicalizations<
33 #define GET_OP_LIST
34 #include "mlir/Dialect/Tosa/IR/TosaOps.cpp.inc"
35       >(ctx, patterns);
36 }
37 
38 struct TosaLayerwiseConstantFoldPass
39     : public TosaLayerwiseConstantFoldPassBase<TosaLayerwiseConstantFoldPass> {
runOnOperation__anon208c7db40111::TosaLayerwiseConstantFoldPass40   void runOnOperation() override {
41     auto *ctx = &getContext();
42     RewritePatternSet patterns(ctx);
43     auto func = getOperation();
44 
45     mlir::tosa::populateTosaFoldConstantTransposePatterns(ctx, patterns);
46     populateTosaOpsCanonicalizationPatterns(ctx, patterns);
47 
48     if (applyPatternsAndFoldGreedily(func, std::move(patterns)).failed())
49       signalPassFailure();
50   }
51 };
52 
53 } // namespace
54 
createTosaLayerwiseConstantFoldPass()55 std::unique_ptr<Pass> mlir::tosa::createTosaLayerwiseConstantFoldPass() {
56   return std::make_unique<TosaLayerwiseConstantFoldPass>();
57 }
58