1 //===- DialectConversion.cpp - Dialect conversion unit tests --------------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 
9 #include "mlir/Transforms/DialectConversion.h"
10 #include "gtest/gtest.h"
11 
12 using namespace mlir;
13 
createOp(MLIRContext * context)14 static Operation *createOp(MLIRContext *context) {
15   context->allowUnregisteredDialects();
16   return Operation::create(UnknownLoc::get(context),
17                            OperationName("foo.bar", context), llvm::None,
18                            llvm::None, llvm::None, llvm::None, 0);
19 }
20 
21 namespace {
22 struct DummyOp {
MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID__anon19825acf0111::DummyOp23   MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(DummyOp)
24 
25   static StringRef getOperationName() { return "foo.bar"; }
26 };
27 
/// Registers two legality callbacks for the same op and checks the order in
/// which they run. The expectations encode that the callback registered last
/// runs first and that, since it returns llvm::None, the callback registered
/// before it runs afterwards and supplies the verdict.
TEST(DialectConversionTest, DynamicallyLegalOpCallbackOrder) {
  MLIRContext context;
  ConversionTarget target(context);

  // `callCounter` ticks once per callback invocation; each `*Tick` variable
  // records the tick at which its callback most recently ran.
  int callCounter = 0;
  int firstTick = 0;
  int secondTick = 0;

  target.addDynamicallyLegalOp<DummyOp>([&](Operation *) {
    ++callCounter;
    firstTick = callCounter;
    return true;
  });
  target.addDynamicallyLegalOp<DummyOp>([&](Operation *) -> Optional<bool> {
    ++callCounter;
    secondTick = callCounter;
    return llvm::None;
  });

  Operation *op = createOp(&context);

  // isLegal(): the second callback runs at tick 1, the first at tick 2.
  EXPECT_TRUE(target.isLegal(op));
  EXPECT_EQ(2, firstTick);
  EXPECT_EQ(1, secondTick);

  // isIllegal(): the same sequence repeats at ticks 3 and 4.
  EXPECT_FALSE(target.isIllegal(op));
  EXPECT_EQ(4, firstTick);
  EXPECT_EQ(3, secondTick);

  op->destroy();
}
54 
/// Registers a single legality callback that always returns llvm::None and
/// checks that the op is then reported as neither legal nor illegal, with the
/// callback invoked once per query.
TEST(DialectConversionTest, DynamicallyLegalOpCallbackSkip) {
  MLIRContext context;
  ConversionTarget target(context);

  // `callCounter` ticks once per invocation; `lastTick` stores the latest tick.
  int callCounter = 0;
  int lastTick = 0;
  target.addDynamicallyLegalOp<DummyOp>([&](Operation *) -> Optional<bool> {
    ++callCounter;
    lastTick = callCounter;
    return llvm::None;
  });

  Operation *op = createOp(&context);

  // First query: one invocation so far.
  EXPECT_FALSE(target.isLegal(op));
  EXPECT_EQ(1, lastTick);

  // Second query: the callback is consulted again.
  EXPECT_FALSE(target.isIllegal(op));
  EXPECT_EQ(2, lastTick);

  op->destroy();
}
73 
/// Same ordering check as for per-op callbacks, but with callbacks registered
/// through `markUnknownOpDynamicallyLegal`. The expectations encode that the
/// callback registered last runs first and that, since it returns llvm::None,
/// the callback registered before it runs afterwards and supplies the verdict.
TEST(DialectConversionTest, DynamicallyLegalUnknownOpCallbackOrder) {
  MLIRContext context;
  ConversionTarget target(context);

  // `callCounter` ticks once per callback invocation; each `*Tick` variable
  // records the tick at which its callback most recently ran.
  int callCounter = 0;
  int firstTick = 0;
  int secondTick = 0;

  target.markUnknownOpDynamicallyLegal([&](Operation *) {
    ++callCounter;
    firstTick = callCounter;
    return true;
  });
  target.markUnknownOpDynamicallyLegal([&](Operation *) -> Optional<bool> {
    ++callCounter;
    secondTick = callCounter;
    return llvm::None;
  });

  Operation *op = createOp(&context);

  // isLegal(): the second callback runs at tick 1, the first at tick 2.
  EXPECT_TRUE(target.isLegal(op));
  EXPECT_EQ(2, firstTick);
  EXPECT_EQ(1, secondTick);

  // isIllegal(): the same sequence repeats at ticks 3 and 4.
  EXPECT_FALSE(target.isIllegal(op));
  EXPECT_EQ(4, firstTick);
  EXPECT_EQ(3, secondTick);

  op->destroy();
}
100 
/// With a per-op callback that always returns llvm::None, the op is expected
/// to be neither legal nor illegal; a partial conversion with no patterns is
/// expected to succeed while a full conversion is expected to fail.
TEST(DialectConversionTest, DynamicallyLegalReturnNone) {
  MLIRContext context;
  ConversionTarget target(context);

  target.addDynamicallyLegalOp<DummyOp>(
      [&](Operation *) -> Optional<bool> { return llvm::None; });

  Operation *op = createOp(&context);

  // Legality queries: no verdict in either direction.
  const bool reportedLegal = target.isLegal(op);
  EXPECT_FALSE(reportedLegal);
  const bool reportedIllegal = target.isIllegal(op);
  EXPECT_FALSE(reportedIllegal);

  // Conversion drivers, run with an empty pattern set.
  const bool partialSucceeded =
      succeeded(applyPartialConversion(op, target, {}));
  EXPECT_TRUE(partialSucceeded);
  const bool fullFailed = failed(applyFullConversion(op, target, {}));
  EXPECT_TRUE(fullFailed);

  op->destroy();
}
117 
/// With an unknown-op callback that always returns llvm::None, the op is
/// expected to be neither legal nor illegal; a partial conversion with no
/// patterns is expected to succeed while a full conversion is expected to
/// fail.
TEST(DialectConversionTest, DynamicallyLegalUnknownReturnNone) {
  MLIRContext context;
  ConversionTarget target(context);

  target.markUnknownOpDynamicallyLegal(
      [&](Operation *) -> Optional<bool> { return llvm::None; });

  Operation *op = createOp(&context);

  // Legality queries: no verdict in either direction.
  const bool reportedLegal = target.isLegal(op);
  EXPECT_FALSE(reportedLegal);
  const bool reportedIllegal = target.isIllegal(op);
  EXPECT_FALSE(reportedIllegal);

  // Conversion drivers, run with an empty pattern set.
  const bool partialSucceeded =
      succeeded(applyPartialConversion(op, target, {}));
  EXPECT_TRUE(partialSucceeded);
  const bool fullFailed = failed(applyFullConversion(op, target, {}));
  EXPECT_TRUE(fullFailed);

  op->destroy();
}
134 } // namespace
135