1 //===- TestTransformDialectInterpreter.cpp --------------------------------===// 2 // 3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. 4 // See https://llvm.org/LICENSE.txt for license information. 5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception 6 // 7 //===----------------------------------------------------------------------===// 8 // 9 // This file defines a test pass that interprets Transform dialect operations in 10 // the module. 11 // 12 //===----------------------------------------------------------------------===// 13 14 #include "mlir/Dialect/Transform/IR/TransformInterfaces.h" 15 #include "mlir/IR/BuiltinOps.h" 16 #include "mlir/Pass/Pass.h" 17 18 using namespace mlir; 19 20 namespace { 21 /// Simple pass that applies transform dialect ops directly contained in a 22 /// module. 23 class TestTransformDialectInterpreterPass 24 : public PassWrapper<TestTransformDialectInterpreterPass, 25 OperationPass<ModuleOp>> { 26 public: 27 MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID( 28 TestTransformDialectInterpreterPass) 29 30 TestTransformDialectInterpreterPass() = default; TestTransformDialectInterpreterPass(const TestTransformDialectInterpreterPass &)31 TestTransformDialectInterpreterPass( 32 const TestTransformDialectInterpreterPass &) {} 33 getArgument() const34 StringRef getArgument() const override { 35 return "test-transform-dialect-interpreter"; 36 } 37 getDescription() const38 StringRef getDescription() const override { 39 return "apply transform dialect operations one by one"; 40 } 41 runOnOperation()42 void runOnOperation() override { 43 ModuleOp module = getOperation(); 44 transform::TransformState state( 45 module.getBodyRegion(), module, 46 transform::TransformOptions().enableExpensiveChecks( 47 enableExpensiveChecks)); 48 for (auto op : 49 module.getBody()->getOps<transform::TransformOpInterface>()) { 50 if (failed(state.applyTransform(op).checkAndReport())) 51 return signalPassFailure(); 52 } 53 } 54 55 Option<bool> enableExpensiveChecks{ 56 *this, "enable-expensive-checks", llvm::cl::init(false), 57 llvm::cl::desc("perform expensive checks to better report errors in the " 58 "transform IR")}; 59 }; 60 } // namespace 61 62 namespace mlir { 63 namespace test { 64 /// Registers the test pass for applying transform dialect ops. registerTestTransformDialectInterpreterPass()65void registerTestTransformDialectInterpreterPass() { 66 PassRegistration<TestTransformDialectInterpreterPass> reg; 67 } 68 } // namespace test 69 } // namespace mlir 70