1 //===- PassGenTest.cpp - TableGen PassGen Tests ---------------------------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 
9 #include "mlir/Pass/Pass.h"
10 
11 #include "gmock/gmock.h"
12 
13 std::unique_ptr<mlir::Pass> createTestPass(int v = 0);
14 
15 #define GEN_PASS_REGISTRATION
16 #include "PassGenTest.h.inc"
17 
18 #define GEN_PASS_CLASSES
19 #include "PassGenTest.h.inc"
20 
21 struct TestPass : public TestPassBase<TestPass> {
TestPassTestPass22   explicit TestPass(int v) : extraVal(v) {}
23 
runOnOperationTestPass24   void runOnOperation() override {}
25 
cloneTestPass26   std::unique_ptr<mlir::Pass> clone() const {
27     return TestPassBase<TestPass>::clone();
28   }
29 
30   int extraVal;
31 };
32 
createTestPass(int v)33 std::unique_ptr<mlir::Pass> createTestPass(int v) {
34   return std::make_unique<TestPass>(v);
35 }
36 
/// Cloning a generated pass must yield a distinct pass of the same type that
/// carries over state held only by the derived class (`extraVal`).
TEST(PassGenTest, PassClone) {
  // Downcast helper. The static_cast is only valid while the pointee really is
  // a TestPass; the TypeID check below guards this for the clone.
  const auto unwrap = [](const std::unique_ptr<mlir::Pass> &pass) {
    return static_cast<const TestPass *>(pass.get());
  };

  const auto origPass = createTestPass(10);
  ASSERT_NE(origPass.get(), nullptr);

  const auto clonePass = unwrap(origPass)->clone();
  // Stop before any dereference if cloning produced nothing.
  ASSERT_NE(clonePass.get(), nullptr);

  // A clone is a new object, not an alias of the original...
  EXPECT_NE(origPass.get(), clonePass.get());
  // ...and has the same pass type, so downcasting it to TestPass is sound.
  ASSERT_EQ(origPass->getTypeID(), clonePass->getTypeID());

  // Pin the expected value so that both sides being wrong in the same way
  // cannot make the comparison below pass vacuously.
  EXPECT_EQ(unwrap(origPass)->extraVal, 10);
  EXPECT_EQ(unwrap(clonePass)->extraVal, unwrap(origPass)->extraVal);
}
49