1 //===- DialectConversion.cpp - Dialect conversion unit tests --------------===// 2 // 3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. 4 // See https://llvm.org/LICENSE.txt for license information. 5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception 6 // 7 //===----------------------------------------------------------------------===// 8 9 #include "mlir/Transforms/DialectConversion.h" 10 #include "gtest/gtest.h" 11 12 using namespace mlir; 13 14 static Operation *createOp(MLIRContext *context) { 15 context->allowUnregisteredDialects(); 16 return Operation::create(UnknownLoc::get(context), 17 OperationName("foo.bar", context), llvm::None, 18 llvm::None, llvm::None, llvm::None, 0); 19 } 20 21 namespace { 22 struct DummyOp { 23 MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(DummyOp) 24 25 static StringRef getOperationName() { return "foo.bar"; } 26 }; 27 28 TEST(DialectConversionTest, DynamicallyLegalOpCallbackOrder) { 29 MLIRContext context; 30 ConversionTarget target(context); 31 32 int index = 0; 33 int callbackCalled1 = 0; 34 target.addDynamicallyLegalOp<DummyOp>([&](Operation *) { 35 callbackCalled1 = ++index; 36 return true; 37 }); 38 39 int callbackCalled2 = 0; 40 target.addDynamicallyLegalOp<DummyOp>([&](Operation *) -> Optional<bool> { 41 callbackCalled2 = ++index; 42 return llvm::None; 43 }); 44 45 auto *op = createOp(&context); 46 EXPECT_TRUE(target.isLegal(op)); 47 EXPECT_EQ(2, callbackCalled1); 48 EXPECT_EQ(1, callbackCalled2); 49 EXPECT_FALSE(target.isIllegal(op)); 50 EXPECT_EQ(4, callbackCalled1); 51 EXPECT_EQ(3, callbackCalled2); 52 op->destroy(); 53 } 54 55 TEST(DialectConversionTest, DynamicallyLegalOpCallbackSkip) { 56 MLIRContext context; 57 ConversionTarget target(context); 58 59 int index = 0; 60 int callbackCalled = 0; 61 target.addDynamicallyLegalOp<DummyOp>([&](Operation *) -> Optional<bool> { 62 callbackCalled = ++index; 63 return llvm::None; 64 }); 65 66 auto *op = createOp(&context); 67 EXPECT_FALSE(target.isLegal(op)); 68 EXPECT_EQ(1, callbackCalled); 69 EXPECT_FALSE(target.isIllegal(op)); 70 EXPECT_EQ(2, callbackCalled); 71 op->destroy(); 72 } 73 74 TEST(DialectConversionTest, DynamicallyLegalUnknownOpCallbackOrder) { 75 MLIRContext context; 76 ConversionTarget target(context); 77 78 int index = 0; 79 int callbackCalled1 = 0; 80 target.markUnknownOpDynamicallyLegal([&](Operation *) { 81 callbackCalled1 = ++index; 82 return true; 83 }); 84 85 int callbackCalled2 = 0; 86 target.markUnknownOpDynamicallyLegal([&](Operation *) -> Optional<bool> { 87 callbackCalled2 = ++index; 88 return llvm::None; 89 }); 90 91 auto *op = createOp(&context); 92 EXPECT_TRUE(target.isLegal(op)); 93 EXPECT_EQ(2, callbackCalled1); 94 EXPECT_EQ(1, callbackCalled2); 95 EXPECT_FALSE(target.isIllegal(op)); 96 EXPECT_EQ(4, callbackCalled1); 97 EXPECT_EQ(3, callbackCalled2); 98 op->destroy(); 99 } 100 101 TEST(DialectConversionTest, DynamicallyLegalReturnNone) { 102 MLIRContext context; 103 ConversionTarget target(context); 104 105 target.addDynamicallyLegalOp<DummyOp>( 106 [&](Operation *) -> Optional<bool> { return llvm::None; }); 107 108 auto *op = createOp(&context); 109 EXPECT_FALSE(target.isLegal(op)); 110 EXPECT_FALSE(target.isIllegal(op)); 111 112 EXPECT_TRUE(succeeded(applyPartialConversion(op, target, {}))); 113 EXPECT_TRUE(failed(applyFullConversion(op, target, {}))); 114 115 op->destroy(); 116 } 117 118 TEST(DialectConversionTest, DynamicallyLegalUnknownReturnNone) { 119 MLIRContext context; 120 ConversionTarget target(context); 121 122 target.markUnknownOpDynamicallyLegal( 123 [&](Operation *) -> Optional<bool> { return llvm::None; }); 124 125 auto *op = createOp(&context); 126 EXPECT_FALSE(target.isLegal(op)); 127 EXPECT_FALSE(target.isIllegal(op)); 128 129 EXPECT_TRUE(succeeded(applyPartialConversion(op, target, {}))); 130 EXPECT_TRUE(failed(applyFullConversion(op, target, {}))); 131 132 op->destroy(); 133 } 134 } // namespace 135