1 //===- TestTransformStateExtension.h - Test Utility -------------*- C++ -*-===// 2 // 3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. 4 // See https://llvm.org/LICENSE.txt for license information. 5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception 6 // 7 //===----------------------------------------------------------------------===// 8 // 9 // This file defines an TransformState extension for the purpose of testing the 10 // relevant APIs. 11 // 12 //===----------------------------------------------------------------------===// 13 14 #ifndef MLIR_TEST_LIB_DIALECT_TRANSFORM_TESTTRANSFORMSTATEEXTENSION_H 15 #define MLIR_TEST_LIB_DIALECT_TRANSFORM_TESTTRANSFORMSTATEEXTENSION_H 16 17 #include "mlir/Dialect/Transform/IR/TransformInterfaces.h" 18 19 using namespace mlir; 20 21 namespace mlir { 22 namespace test { 23 class TestTransformStateExtension 24 : public transform::TransformState::Extension { 25 public: TestTransformStateExtension(transform::TransformState & state,StringAttr message)26 TestTransformStateExtension(transform::TransformState &state, 27 StringAttr message) 28 : Extension(state), message(message) {} 29 getMessage()30 StringRef getMessage() const { return message.getValue(); } 31 updateMapping(Operation * previous,Operation * updated)32 LogicalResult updateMapping(Operation *previous, Operation *updated) { 33 return replacePayloadOp(previous, updated); 34 } 35 36 private: 37 StringAttr message; 38 }; 39 } // namespace test 40 } // namespace mlir 41 42 #endif // MLIR_TEST_LIB_DIALECT_TRANSFORM_TESTTRANSFORMSTATEEXTENSION_H 43