1 //===- DialectConversion.cpp - Dialect conversion unit tests --------------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 
9 #include "mlir/Transforms/DialectConversion.h"
10 #include "gtest/gtest.h"
11 
12 using namespace mlir;
13 
14 static Operation *createOp(MLIRContext *context) {
15   context->allowUnregisteredDialects();
16   return Operation::create(UnknownLoc::get(context),
17                            OperationName("foo.bar", context), llvm::None,
18                            llvm::None, llvm::None, llvm::None, 0);
19 }
20 
21 namespace {
22 struct DummyOp {
23   static StringRef getOperationName() { return "foo.bar"; }
24 };
25 
26 TEST(DialectConversionTest, DynamicallyLegalOpCallbackOrder) {
27   MLIRContext context;
28   ConversionTarget target(context);
29 
30   int index = 0;
31   int callbackCalled1 = 0;
32   target.addDynamicallyLegalOp<DummyOp>([&](Operation *) {
33     callbackCalled1 = ++index;
34     return true;
35   });
36 
37   int callbackCalled2 = 0;
38   target.addDynamicallyLegalOp<DummyOp>([&](Operation *) -> Optional<bool> {
39     callbackCalled2 = ++index;
40     return llvm::None;
41   });
42 
43   auto *op = createOp(&context);
44   EXPECT_TRUE(target.isLegal(op));
45   EXPECT_EQ(2, callbackCalled1);
46   EXPECT_EQ(1, callbackCalled2);
47   EXPECT_FALSE(target.isIllegal(op));
48   EXPECT_EQ(4, callbackCalled1);
49   EXPECT_EQ(3, callbackCalled2);
50   op->destroy();
51 }
52 
53 TEST(DialectConversionTest, DynamicallyLegalOpCallbackSkip) {
54   MLIRContext context;
55   ConversionTarget target(context);
56 
57   int index = 0;
58   int callbackCalled = 0;
59   target.addDynamicallyLegalOp<DummyOp>([&](Operation *) -> Optional<bool> {
60     callbackCalled = ++index;
61     return llvm::None;
62   });
63 
64   auto *op = createOp(&context);
65   EXPECT_FALSE(target.isLegal(op));
66   EXPECT_EQ(1, callbackCalled);
67   EXPECT_FALSE(target.isIllegal(op));
68   EXPECT_EQ(2, callbackCalled);
69   op->destroy();
70 }
71 
72 TEST(DialectConversionTest, DynamicallyLegalUnknownOpCallbackOrder) {
73   MLIRContext context;
74   ConversionTarget target(context);
75 
76   int index = 0;
77   int callbackCalled1 = 0;
78   target.markUnknownOpDynamicallyLegal([&](Operation *) {
79     callbackCalled1 = ++index;
80     return true;
81   });
82 
83   int callbackCalled2 = 0;
84   target.markUnknownOpDynamicallyLegal([&](Operation *) -> Optional<bool> {
85     callbackCalled2 = ++index;
86     return llvm::None;
87   });
88 
89   auto *op = createOp(&context);
90   EXPECT_TRUE(target.isLegal(op));
91   EXPECT_EQ(2, callbackCalled1);
92   EXPECT_EQ(1, callbackCalled2);
93   EXPECT_FALSE(target.isIllegal(op));
94   EXPECT_EQ(4, callbackCalled1);
95   EXPECT_EQ(3, callbackCalled2);
96   op->destroy();
97 }
98 
99 TEST(DialectConversionTest, DynamicallyLegalReturnNone) {
100   MLIRContext context;
101   ConversionTarget target(context);
102 
103   target.addDynamicallyLegalOp<DummyOp>(
104       [&](Operation *) -> Optional<bool> { return llvm::None; });
105 
106   auto *op = createOp(&context);
107   EXPECT_FALSE(target.isLegal(op));
108   EXPECT_FALSE(target.isIllegal(op));
109 
110   EXPECT_TRUE(succeeded(applyPartialConversion(op, target, {})));
111   EXPECT_TRUE(failed(applyFullConversion(op, target, {})));
112 
113   op->destroy();
114 }
115 
116 TEST(DialectConversionTest, DynamicallyLegalUnknownReturnNone) {
117   MLIRContext context;
118   ConversionTarget target(context);
119 
120   target.markUnknownOpDynamicallyLegal(
121       [&](Operation *) -> Optional<bool> { return llvm::None; });
122 
123   auto *op = createOp(&context);
124   EXPECT_FALSE(target.isLegal(op));
125   EXPECT_FALSE(target.isIllegal(op));
126 
127   EXPECT_TRUE(succeeded(applyPartialConversion(op, target, {})));
128   EXPECT_TRUE(failed(applyFullConversion(op, target, {})));
129 
130   op->destroy();
131 }
132 } // namespace
133