1 //===- TestLinalgDecomposeOps.cpp - Test Linalg decomposition  ------------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 //
9 // This file implements a pass for testing decomposition of Linalg ops.
10 //
11 //===----------------------------------------------------------------------===//
12 
13 #include "mlir/Dialect/Affine/IR/AffineOps.h"
14 #include "mlir/Dialect/Linalg/Transforms/Transforms.h"
15 #include "mlir/Pass/Pass.h"
16 #include "mlir/Transforms/GreedyPatternRewriteDriver.h"
17 
18 using namespace mlir;
19 
20 namespace {
21 struct TestLinalgDecomposeOps
22     : public PassWrapper<TestLinalgDecomposeOps, OperationPass<>> {
23   MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(TestLinalgDecomposeOps)
24 
25   TestLinalgDecomposeOps() = default;
TestLinalgDecomposeOps__anonf044407b0111::TestLinalgDecomposeOps26   TestLinalgDecomposeOps(const TestLinalgDecomposeOps &pass)
27       : PassWrapper(pass) {}
getDependentDialects__anonf044407b0111::TestLinalgDecomposeOps28   void getDependentDialects(DialectRegistry &registry) const override {
29     registry.insert<AffineDialect, linalg::LinalgDialect>();
30   }
getArgument__anonf044407b0111::TestLinalgDecomposeOps31   StringRef getArgument() const final { return "test-linalg-decompose-ops"; }
getDescription__anonf044407b0111::TestLinalgDecomposeOps32   StringRef getDescription() const final {
33     return "Test Linalg decomposition patterns";
34   }
35 
runOnOperation__anonf044407b0111::TestLinalgDecomposeOps36   void runOnOperation() override {
37     MLIRContext *context = &this->getContext();
38     RewritePatternSet decompositionPatterns(context);
39     linalg::populateDecomposeLinalgOpsPattern(decompositionPatterns);
40     if (failed(applyPatternsAndFoldGreedily(
41             getOperation(), std::move(decompositionPatterns)))) {
42       return signalPassFailure();
43     }
44   }
45 };
46 } // namespace
47 
48 namespace mlir {
49 namespace test {
registerTestLinalgDecomposeOps()50 void registerTestLinalgDecomposeOps() {
51   PassRegistration<TestLinalgDecomposeOps>();
52 }
53 } // namespace test
54 } // namespace mlir
55