1 //===- TestTransformDialectInterpreter.cpp --------------------------------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 //
9 // This file defines a test pass that interprets Transform dialect operations in
10 // the module.
11 //
12 //===----------------------------------------------------------------------===//
13 
14 #include "mlir/Dialect/Transform/IR/TransformInterfaces.h"
15 #include "mlir/IR/BuiltinOps.h"
16 #include "mlir/Pass/Pass.h"
17 
18 using namespace mlir;
19 
20 namespace {
21 /// Simple pass that applies transform dialect ops directly contained in a
22 /// module.
23 class TestTransformDialectInterpreterPass
24     : public PassWrapper<TestTransformDialectInterpreterPass,
25                          OperationPass<ModuleOp>> {
26 public:
27   MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(
28       TestTransformDialectInterpreterPass)
29 
30   StringRef getArgument() const override {
31     return "test-transform-dialect-interpreter";
32   }
33 
34   StringRef getDescription() const override {
35     return "apply transform dialect operations one by one";
36   }
37 
38   void runOnOperation() override {
39     ModuleOp module = getOperation();
40     transform::TransformState state(module.getBodyRegion(), module);
41     for (auto op :
42          module.getBody()->getOps<transform::TransformOpInterface>()) {
43       if (failed(state.applyTransform(op)))
44         return signalPassFailure();
45     }
46   }
47 };
48 } // namespace
49 
50 namespace mlir {
51 namespace test {
52 /// Registers the test pass for applying transform dialect ops.
53 void registerTestTransformDialectInterpreterPass() {
54   PassRegistration<TestTransformDialectInterpreterPass> reg;
55 }
56 } // namespace test
57 } // namespace mlir
58