1 //===- DialectConversion.cpp - Dialect conversion unit tests --------------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 
9 #include "mlir/Transforms/DialectConversion.h"
10 #include "gtest/gtest.h"
11 
12 using namespace mlir;
13 
14 static Operation *createOp(MLIRContext *context) {
15   context->allowUnregisteredDialects();
16   return Operation::create(UnknownLoc::get(context),
17                            OperationName("foo.bar", context), llvm::None,
18                            llvm::None, llvm::None, llvm::None, 0);
19 }
20 
21 namespace {
22 struct DummyOp {
23   static StringRef getOperationName() { return "foo.bar"; }
24 };
25 
26 TEST(DialectConversionTest, DynamicallyLegalOpCallbackOrder) {
27   MLIRContext context;
28   ConversionTarget target(context);
29 
30   int index = 0;
31   int callbackCalled1 = 0;
32   target.addDynamicallyLegalOp<DummyOp>([&](Operation *) {
33     callbackCalled1 = ++index;
34     return true;
35   });
36 
37   int callbackCalled2 = 0;
38   target.addDynamicallyLegalOp<DummyOp>([&](Operation *) -> Optional<bool> {
39     callbackCalled2 = ++index;
40     return llvm::None;
41   });
42 
43   auto *op = createOp(&context);
44   EXPECT_TRUE(target.isLegal(op));
45   EXPECT_EQ(2, callbackCalled1);
46   EXPECT_EQ(1, callbackCalled2);
47   op->destroy();
48 }
49 
50 TEST(DialectConversionTest, DynamicallyLegalOpCallbackSkip) {
51   MLIRContext context;
52   ConversionTarget target(context);
53 
54   int index = 0;
55   int callbackCalled = 0;
56   target.addDynamicallyLegalOp<DummyOp>([&](Operation *) -> Optional<bool> {
57     callbackCalled = ++index;
58     return llvm::None;
59   });
60 
61   auto *op = createOp(&context);
62   EXPECT_FALSE(target.isLegal(op));
63   EXPECT_EQ(1, callbackCalled);
64   op->destroy();
65 }
66 
67 TEST(DialectConversionTest, DynamicallyLegalUnknownOpCallbackOrder) {
68   MLIRContext context;
69   ConversionTarget target(context);
70 
71   int index = 0;
72   int callbackCalled1 = 0;
73   target.markUnknownOpDynamicallyLegal([&](Operation *) {
74     callbackCalled1 = ++index;
75     return true;
76   });
77 
78   int callbackCalled2 = 0;
79   target.markUnknownOpDynamicallyLegal([&](Operation *) -> Optional<bool> {
80     callbackCalled2 = ++index;
81     return llvm::None;
82   });
83 
84   auto *op = createOp(&context);
85   EXPECT_TRUE(target.isLegal(op));
86   EXPECT_EQ(2, callbackCalled1);
87   EXPECT_EQ(1, callbackCalled2);
88   op->destroy();
89 }
90 } // namespace
91