1d064c480SAlex Zinenko //===- TestTransformDialectInterpreter.cpp --------------------------------===// 2d064c480SAlex Zinenko // 3d064c480SAlex Zinenko // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. 4d064c480SAlex Zinenko // See https://llvm.org/LICENSE.txt for license information. 5d064c480SAlex Zinenko // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception 6d064c480SAlex Zinenko // 7d064c480SAlex Zinenko //===----------------------------------------------------------------------===// 8d064c480SAlex Zinenko // 9d064c480SAlex Zinenko // This file defines a test pass that interprets Transform dialect operations in 10d064c480SAlex Zinenko // the module. 11d064c480SAlex Zinenko // 12d064c480SAlex Zinenko //===----------------------------------------------------------------------===// 13d064c480SAlex Zinenko 14d064c480SAlex Zinenko #include "mlir/Dialect/Transform/IR/TransformInterfaces.h" 15d064c480SAlex Zinenko #include "mlir/IR/BuiltinOps.h" 16d064c480SAlex Zinenko #include "mlir/Pass/Pass.h" 17d064c480SAlex Zinenko 18d064c480SAlex Zinenko using namespace mlir; 19d064c480SAlex Zinenko 20d064c480SAlex Zinenko namespace { 21d064c480SAlex Zinenko /// Simple pass that applies transform dialect ops directly contained in a 22d064c480SAlex Zinenko /// module. 23d064c480SAlex Zinenko class TestTransformDialectInterpreterPass 24d064c480SAlex Zinenko : public PassWrapper<TestTransformDialectInterpreterPass, 25d064c480SAlex Zinenko OperationPass<ModuleOp>> { 26d064c480SAlex Zinenko public: 27d064c480SAlex Zinenko MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID( 28d064c480SAlex Zinenko TestTransformDialectInterpreterPass) 29d064c480SAlex Zinenko 306403e1b1SAlex Zinenko TestTransformDialectInterpreterPass() = default; TestTransformDialectInterpreterPass(const TestTransformDialectInterpreterPass &)316403e1b1SAlex Zinenko TestTransformDialectInterpreterPass( 326403e1b1SAlex Zinenko const TestTransformDialectInterpreterPass &) {} 336403e1b1SAlex Zinenko getArgument() const34d064c480SAlex Zinenko StringRef getArgument() const override { 35d064c480SAlex Zinenko return "test-transform-dialect-interpreter"; 36d064c480SAlex Zinenko } 37d064c480SAlex Zinenko getDescription() const38d064c480SAlex Zinenko StringRef getDescription() const override { 39d064c480SAlex Zinenko return "apply transform dialect operations one by one"; 40d064c480SAlex Zinenko } 41d064c480SAlex Zinenko runOnOperation()42d064c480SAlex Zinenko void runOnOperation() override { 43d064c480SAlex Zinenko ModuleOp module = getOperation(); 446403e1b1SAlex Zinenko transform::TransformState state( 456403e1b1SAlex Zinenko module.getBodyRegion(), module, 466403e1b1SAlex Zinenko transform::TransformOptions().enableExpensiveChecks( 476403e1b1SAlex Zinenko enableExpensiveChecks)); 48d064c480SAlex Zinenko for (auto op : 49d064c480SAlex Zinenko module.getBody()->getOps<transform::TransformOpInterface>()) { 50*e3890b7fSAlex Zinenko if (failed(state.applyTransform(op).checkAndReport())) 51d064c480SAlex Zinenko return signalPassFailure(); 52d064c480SAlex Zinenko } 53d064c480SAlex Zinenko } 546403e1b1SAlex Zinenko 556403e1b1SAlex Zinenko Option<bool> enableExpensiveChecks{ 566403e1b1SAlex Zinenko *this, "enable-expensive-checks", llvm::cl::init(false), 576403e1b1SAlex Zinenko llvm::cl::desc("perform expensive checks to better report errors in the " 586403e1b1SAlex Zinenko "transform IR")}; 59d064c480SAlex Zinenko }; 60d064c480SAlex Zinenko } // namespace 61d064c480SAlex Zinenko 62d064c480SAlex Zinenko namespace mlir { 63d064c480SAlex Zinenko namespace test { 64d064c480SAlex Zinenko /// Registers the test pass for applying transform dialect ops. registerTestTransformDialectInterpreterPass()65d064c480SAlex Zinenkovoid registerTestTransformDialectInterpreterPass() { 66d064c480SAlex Zinenko PassRegistration<TestTransformDialectInterpreterPass> reg; 67d064c480SAlex Zinenko } 68d064c480SAlex Zinenko } // namespace test 69d064c480SAlex Zinenko } // namespace mlir 70