1 //===- TestTransformDialectInterpreter.cpp --------------------------------===// 2 // 3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. 4 // See https://llvm.org/LICENSE.txt for license information. 5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception 6 // 7 //===----------------------------------------------------------------------===// 8 // 9 // This file defines a test pass that interprets Transform dialect operations in 10 // the module. 11 // 12 //===----------------------------------------------------------------------===// 13 14 #include "mlir/Dialect/Transform/IR/TransformInterfaces.h" 15 #include "mlir/IR/BuiltinOps.h" 16 #include "mlir/Pass/Pass.h" 17 18 using namespace mlir; 19 20 namespace { 21 /// Simple pass that applies transform dialect ops directly contained in a 22 /// module. 23 class TestTransformDialectInterpreterPass 24 : public PassWrapper<TestTransformDialectInterpreterPass, 25 OperationPass<ModuleOp>> { 26 public: 27 MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID( 28 TestTransformDialectInterpreterPass) 29 30 StringRef getArgument() const override { 31 return "test-transform-dialect-interpreter"; 32 } 33 34 StringRef getDescription() const override { 35 return "apply transform dialect operations one by one"; 36 } 37 38 void runOnOperation() override { 39 ModuleOp module = getOperation(); 40 transform::TransformState state(module.getBodyRegion(), module); 41 for (auto op : 42 module.getBody()->getOps<transform::TransformOpInterface>()) { 43 if (failed(state.applyTransform(op))) 44 return signalPassFailure(); 45 } 46 } 47 }; 48 } // namespace 49 50 namespace mlir { 51 namespace test { 52 /// Registers the test pass for applying transform dialect ops. 53 void registerTestTransformDialectInterpreterPass() { 54 PassRegistration<TestTransformDialectInterpreterPass> reg; 55 } 56 } // namespace test 57 } // namespace mlir 58