1 //===- TestTransformDialectInterpreter.cpp --------------------------------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 //
9 // This file defines a test pass that interprets Transform dialect operations in
10 // the module.
11 //
12 //===----------------------------------------------------------------------===//
13 
14 #include "mlir/Dialect/Transform/IR/TransformInterfaces.h"
15 #include "mlir/IR/BuiltinOps.h"
16 #include "mlir/Pass/Pass.h"
17 
18 using namespace mlir;
19 
20 namespace {
21 /// Simple pass that applies transform dialect ops directly contained in a
22 /// module.
23 class TestTransformDialectInterpreterPass
24     : public PassWrapper<TestTransformDialectInterpreterPass,
25                          OperationPass<ModuleOp>> {
26 public:
27   MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(
28       TestTransformDialectInterpreterPass)
29 
30   TestTransformDialectInterpreterPass() = default;
TestTransformDialectInterpreterPass(const TestTransformDialectInterpreterPass &)31   TestTransformDialectInterpreterPass(
32       const TestTransformDialectInterpreterPass &) {}
33 
getArgument() const34   StringRef getArgument() const override {
35     return "test-transform-dialect-interpreter";
36   }
37 
getDescription() const38   StringRef getDescription() const override {
39     return "apply transform dialect operations one by one";
40   }
41 
runOnOperation()42   void runOnOperation() override {
43     ModuleOp module = getOperation();
44     transform::TransformState state(
45         module.getBodyRegion(), module,
46         transform::TransformOptions().enableExpensiveChecks(
47             enableExpensiveChecks));
48     for (auto op :
49          module.getBody()->getOps<transform::TransformOpInterface>()) {
50       if (failed(state.applyTransform(op).checkAndReport()))
51         return signalPassFailure();
52     }
53   }
54 
55   Option<bool> enableExpensiveChecks{
56       *this, "enable-expensive-checks", llvm::cl::init(false),
57       llvm::cl::desc("perform expensive checks to better report errors in the "
58                      "transform IR")};
59 };
60 } // namespace
61 
62 namespace mlir {
63 namespace test {
64 /// Registers the test pass for applying transform dialect ops.
registerTestTransformDialectInterpreterPass()65 void registerTestTransformDialectInterpreterPass() {
66   PassRegistration<TestTransformDialectInterpreterPass> reg;
67 }
68 } // namespace test
69 } // namespace mlir
70