1 //===- DialectConversion.cpp - Dialect conversion unit tests --------------===// 2 // 3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. 4 // See https://llvm.org/LICENSE.txt for license information. 5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception 6 // 7 //===----------------------------------------------------------------------===// 8 9 #include "mlir/Transforms/DialectConversion.h" 10 #include "gtest/gtest.h" 11 12 using namespace mlir; 13 14 static Operation *createOp(MLIRContext *context) { 15 context->allowUnregisteredDialects(); 16 return Operation::create(UnknownLoc::get(context), 17 OperationName("foo.bar", context), llvm::None, 18 llvm::None, llvm::None, llvm::None, 0); 19 } 20 21 namespace { 22 struct DummyOp { 23 static StringRef getOperationName() { return "foo.bar"; } 24 }; 25 26 TEST(DialectConversionTest, DynamicallyLegalOpCallbackOrder) { 27 MLIRContext context; 28 ConversionTarget target(context); 29 30 int index = 0; 31 int callbackCalled1 = 0; 32 target.addDynamicallyLegalOp<DummyOp>([&](Operation *) { 33 callbackCalled1 = ++index; 34 return true; 35 }); 36 37 int callbackCalled2 = 0; 38 target.addDynamicallyLegalOp<DummyOp>([&](Operation *) -> Optional<bool> { 39 callbackCalled2 = ++index; 40 return llvm::None; 41 }); 42 43 auto *op = createOp(&context); 44 EXPECT_TRUE(target.isLegal(op)); 45 EXPECT_EQ(2, callbackCalled1); 46 EXPECT_EQ(1, callbackCalled2); 47 EXPECT_FALSE(target.isIllegal(op)); 48 EXPECT_EQ(4, callbackCalled1); 49 EXPECT_EQ(3, callbackCalled2); 50 op->destroy(); 51 } 52 53 TEST(DialectConversionTest, DynamicallyLegalOpCallbackSkip) { 54 MLIRContext context; 55 ConversionTarget target(context); 56 57 int index = 0; 58 int callbackCalled = 0; 59 target.addDynamicallyLegalOp<DummyOp>([&](Operation *) -> Optional<bool> { 60 callbackCalled = ++index; 61 return llvm::None; 62 }); 63 64 auto *op = createOp(&context); 65 EXPECT_FALSE(target.isLegal(op)); 66 EXPECT_EQ(1, callbackCalled); 67 EXPECT_FALSE(target.isIllegal(op)); 68 EXPECT_EQ(2, callbackCalled); 69 op->destroy(); 70 } 71 72 TEST(DialectConversionTest, DynamicallyLegalUnknownOpCallbackOrder) { 73 MLIRContext context; 74 ConversionTarget target(context); 75 76 int index = 0; 77 int callbackCalled1 = 0; 78 target.markUnknownOpDynamicallyLegal([&](Operation *) { 79 callbackCalled1 = ++index; 80 return true; 81 }); 82 83 int callbackCalled2 = 0; 84 target.markUnknownOpDynamicallyLegal([&](Operation *) -> Optional<bool> { 85 callbackCalled2 = ++index; 86 return llvm::None; 87 }); 88 89 auto *op = createOp(&context); 90 EXPECT_TRUE(target.isLegal(op)); 91 EXPECT_EQ(2, callbackCalled1); 92 EXPECT_EQ(1, callbackCalled2); 93 EXPECT_FALSE(target.isIllegal(op)); 94 EXPECT_EQ(4, callbackCalled1); 95 EXPECT_EQ(3, callbackCalled2); 96 op->destroy(); 97 } 98 99 TEST(DialectConversionTest, DynamicallyLegalReturnNone) { 100 MLIRContext context; 101 ConversionTarget target(context); 102 103 target.addDynamicallyLegalOp<DummyOp>( 104 [&](Operation *) -> Optional<bool> { return llvm::None; }); 105 106 auto *op = createOp(&context); 107 EXPECT_FALSE(target.isLegal(op)); 108 EXPECT_FALSE(target.isIllegal(op)); 109 110 EXPECT_TRUE(succeeded(applyPartialConversion(op, target, {}))); 111 EXPECT_TRUE(failed(applyFullConversion(op, target, {}))); 112 113 op->destroy(); 114 } 115 116 TEST(DialectConversionTest, DynamicallyLegalUnknownReturnNone) { 117 MLIRContext context; 118 ConversionTarget target(context); 119 120 target.markUnknownOpDynamicallyLegal( 121 [&](Operation *) -> Optional<bool> { return llvm::None; }); 122 123 auto *op = createOp(&context); 124 EXPECT_FALSE(target.isLegal(op)); 125 EXPECT_FALSE(target.isIllegal(op)); 126 127 EXPECT_TRUE(succeeded(applyPartialConversion(op, target, {}))); 128 EXPECT_TRUE(failed(applyFullConversion(op, target, {}))); 129 130 op->destroy(); 131 } 132 } // namespace 133