1 //===- DialectConversion.cpp - Dialect conversion unit tests --------------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8
9 #include "mlir/Transforms/DialectConversion.h"
10 #include "gtest/gtest.h"
11
12 using namespace mlir;
13
createOp(MLIRContext * context)14 static Operation *createOp(MLIRContext *context) {
15 context->allowUnregisteredDialects();
16 return Operation::create(UnknownLoc::get(context),
17 OperationName("foo.bar", context), llvm::None,
18 llvm::None, llvm::None, llvm::None, 0);
19 }
20
21 namespace {
22 struct DummyOp {
MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID__anon19825acf0111::DummyOp23 MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(DummyOp)
24
25 static StringRef getOperationName() { return "foo.bar"; }
26 };
27
/// With two dynamic-legality callbacks registered for the same op, the most
/// recently added one is consulted first. Its `llvm::None` answer defers to
/// the earlier callback. Every query (isLegal, then isIllegal) walks the
/// chain again in the same order.
TEST(DialectConversionTest, DynamicallyLegalOpCallbackOrder) {
  MLIRContext context;
  ConversionTarget target(context);

  // Shared clock: each callback stamps the tick at which it last ran.
  int tick = 0;

  int firstStamp = 0;
  auto firstCallback = [&](Operation *) {
    firstStamp = ++tick;
    return true;
  };
  target.addDynamicallyLegalOp<DummyOp>(firstCallback);

  int secondStamp = 0;
  auto secondCallback = [&](Operation *) -> Optional<bool> {
    secondStamp = ++tick;
    return llvm::None;
  };
  target.addDynamicallyLegalOp<DummyOp>(secondCallback);

  Operation *op = createOp(&context);

  // The later callback runs first (tick 1) and defers; the earlier one runs
  // next (tick 2) and reports "legal".
  EXPECT_TRUE(target.isLegal(op));
  EXPECT_EQ(2, firstStamp);
  EXPECT_EQ(1, secondStamp);

  // A second query repeats the same order with fresh ticks.
  EXPECT_FALSE(target.isIllegal(op));
  EXPECT_EQ(4, firstStamp);
  EXPECT_EQ(3, secondStamp);

  op->destroy();
}
54
/// A single callback that always answers `llvm::None` leaves the op neither
/// legal nor illegal. The callback is invoked once per legality query.
TEST(DialectConversionTest, DynamicallyLegalOpCallbackSkip) {
  MLIRContext context;
  ConversionTarget target(context);

  int calls = 0;
  int lastCall = 0;
  target.addDynamicallyLegalOp<DummyOp>([&](Operation *) -> Optional<bool> {
    lastCall = ++calls;
    return llvm::None;
  });

  Operation *op = createOp(&context);

  // First query: callback runs once and leaves legality undecided.
  EXPECT_FALSE(target.isLegal(op));
  EXPECT_EQ(1, lastCall);

  // Second query: callback runs again, still undecided.
  EXPECT_FALSE(target.isIllegal(op));
  EXPECT_EQ(2, lastCall);

  op->destroy();
}
73
/// Same ordering contract as the op-specific test, but for callbacks
/// registered through `markUnknownOpDynamicallyLegal`. No op-specific rule
/// is added here, so these callbacks decide the legality of "foo.bar".
TEST(DialectConversionTest, DynamicallyLegalUnknownOpCallbackOrder) {
  MLIRContext context;
  ConversionTarget target(context);

  // Shared clock: each callback stamps the tick at which it last ran.
  int tick = 0;

  int olderStamp = 0;
  auto olderCallback = [&](Operation *) {
    olderStamp = ++tick;
    return true;
  };
  target.markUnknownOpDynamicallyLegal(olderCallback);

  int newerStamp = 0;
  auto newerCallback = [&](Operation *) -> Optional<bool> {
    newerStamp = ++tick;
    return llvm::None;
  };
  target.markUnknownOpDynamicallyLegal(newerCallback);

  Operation *op = createOp(&context);

  // Newer callback runs first (tick 1) and defers to the older one (tick 2).
  EXPECT_TRUE(target.isLegal(op));
  EXPECT_EQ(2, olderStamp);
  EXPECT_EQ(1, newerStamp);

  // The chain is walked again for the next query.
  EXPECT_FALSE(target.isIllegal(op));
  EXPECT_EQ(4, olderStamp);
  EXPECT_EQ(3, newerStamp);

  op->destroy();
}
100
/// An op whose only dynamic-legality callback answers `llvm::None` is neither
/// legal nor illegal. Partial conversion accepts that state; full conversion
/// rejects it.
TEST(DialectConversionTest, DynamicallyLegalReturnNone) {
  MLIRContext context;
  ConversionTarget target(context);

  auto undecided = [](Operation *) -> Optional<bool> { return llvm::None; };
  target.addDynamicallyLegalOp<DummyOp>(undecided);

  Operation *op = createOp(&context);
  EXPECT_FALSE(target.isLegal(op));
  EXPECT_FALSE(target.isIllegal(op));

  // `{}` supplies an empty pattern set, so only the legality answer above
  // determines each driver's result.
  EXPECT_TRUE(succeeded(applyPartialConversion(op, target, {})));
  EXPECT_TRUE(failed(applyFullConversion(op, target, {})));

  op->destroy();
}
117
/// Unknown-op variant of the test above: an unknown-op callback that answers
/// `llvm::None` leaves "foo.bar" neither legal nor illegal. Partial conversion
/// succeeds and full conversion fails.
TEST(DialectConversionTest, DynamicallyLegalUnknownReturnNone) {
  MLIRContext context;
  ConversionTarget target(context);

  auto undecided = [](Operation *) -> Optional<bool> { return llvm::None; };
  target.markUnknownOpDynamicallyLegal(undecided);

  Operation *op = createOp(&context);
  EXPECT_FALSE(target.isLegal(op));
  EXPECT_FALSE(target.isIllegal(op));

  // `{}` supplies an empty pattern set, so only the legality answer above
  // determines each driver's result.
  EXPECT_TRUE(succeeded(applyPartialConversion(op, target, {})));
  EXPECT_TRUE(failed(applyFullConversion(op, target, {})));

  op->destroy();
}
134 } // namespace
135