1 //===- TosaLayerwiseConstantFoldPass.cpp ----------------------------------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 //
// This file implements a pass that applies constant-folding transformations
// and op canonicalizations to TOSA operations.
10 //
11 //===----------------------------------------------------------------------===//
12
13 #include "mlir/Dialect/Tosa/IR/TosaOps.h"
14 #include "mlir/Dialect/Tosa/Transforms/PassDetail.h"
15 #include "mlir/Dialect/Tosa/Transforms/Passes.h"
16 #include "mlir/Pass/Pass.h"
17 #include "mlir/Transforms/GreedyPatternRewriteDriver.h"
18
19 using namespace mlir;
20 using namespace mlir::tosa;
21
22 namespace {
23
24 template <typename... Args>
addOpsCanonicalizations(MLIRContext * ctx,RewritePatternSet & patterns)25 void addOpsCanonicalizations(MLIRContext *ctx, RewritePatternSet &patterns) {
26 (void)std::initializer_list<int>{
27 0, (Args::getCanonicalizationPatterns(patterns, ctx), 0)...};
28 }
29
populateTosaOpsCanonicalizationPatterns(MLIRContext * ctx,RewritePatternSet & patterns)30 void populateTosaOpsCanonicalizationPatterns(MLIRContext *ctx,
31 RewritePatternSet &patterns) {
32 addOpsCanonicalizations<
33 #define GET_OP_LIST
34 #include "mlir/Dialect/Tosa/IR/TosaOps.cpp.inc"
35 >(ctx, patterns);
36 }
37
38 struct TosaLayerwiseConstantFoldPass
39 : public TosaLayerwiseConstantFoldPassBase<TosaLayerwiseConstantFoldPass> {
runOnOperation__anon208c7db40111::TosaLayerwiseConstantFoldPass40 void runOnOperation() override {
41 auto *ctx = &getContext();
42 RewritePatternSet patterns(ctx);
43 auto func = getOperation();
44
45 mlir::tosa::populateTosaFoldConstantTransposePatterns(ctx, patterns);
46 populateTosaOpsCanonicalizationPatterns(ctx, patterns);
47
48 if (applyPatternsAndFoldGreedily(func, std::move(patterns)).failed())
49 signalPassFailure();
50 }
51 };
52
53 } // namespace
54
createTosaLayerwiseConstantFoldPass()55 std::unique_ptr<Pass> mlir::tosa::createTosaLayerwiseConstantFoldPass() {
56 return std::make_unique<TosaLayerwiseConstantFoldPass>();
57 }
58