1 //===- DialectConversion.cpp - Dialect conversion unit tests --------------===// 2 // 3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. 4 // See https://llvm.org/LICENSE.txt for license information. 5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception 6 // 7 //===----------------------------------------------------------------------===// 8 9 #include "mlir/Transforms/DialectConversion.h" 10 #include "gtest/gtest.h" 11 12 using namespace mlir; 13 14 static Operation *createOp(MLIRContext *context) { 15 context->allowUnregisteredDialects(); 16 return Operation::create(UnknownLoc::get(context), 17 OperationName("foo.bar", context), llvm::None, 18 llvm::None, llvm::None, llvm::None, 0); 19 } 20 21 namespace { 22 struct DummyOp { 23 static StringRef getOperationName() { return "foo.bar"; } 24 }; 25 26 TEST(DialectConversionTest, DynamicallyLegalOpCallbackOrder) { 27 MLIRContext context; 28 ConversionTarget target(context); 29 30 int index = 0; 31 int callbackCalled1 = 0; 32 target.addDynamicallyLegalOp<DummyOp>([&](Operation *) { 33 callbackCalled1 = ++index; 34 return true; 35 }); 36 37 int callbackCalled2 = 0; 38 target.addDynamicallyLegalOp<DummyOp>([&](Operation *) -> Optional<bool> { 39 callbackCalled2 = ++index; 40 return llvm::None; 41 }); 42 43 auto *op = createOp(&context); 44 EXPECT_TRUE(target.isLegal(op)); 45 EXPECT_EQ(2, callbackCalled1); 46 EXPECT_EQ(1, callbackCalled2); 47 op->destroy(); 48 } 49 50 TEST(DialectConversionTest, DynamicallyLegalOpCallbackSkip) { 51 MLIRContext context; 52 ConversionTarget target(context); 53 54 int index = 0; 55 int callbackCalled = 0; 56 target.addDynamicallyLegalOp<DummyOp>([&](Operation *) -> Optional<bool> { 57 callbackCalled = ++index; 58 return llvm::None; 59 }); 60 61 auto *op = createOp(&context); 62 EXPECT_FALSE(target.isLegal(op)); 63 EXPECT_EQ(1, callbackCalled); 64 op->destroy(); 65 } 66 67 TEST(DialectConversionTest, DynamicallyLegalUnknownOpCallbackOrder) { 68 MLIRContext context; 69 ConversionTarget target(context); 70 71 int index = 0; 72 int callbackCalled1 = 0; 73 target.markUnknownOpDynamicallyLegal([&](Operation *) { 74 callbackCalled1 = ++index; 75 return true; 76 }); 77 78 int callbackCalled2 = 0; 79 target.markUnknownOpDynamicallyLegal([&](Operation *) -> Optional<bool> { 80 callbackCalled2 = ++index; 81 return llvm::None; 82 }); 83 84 auto *op = createOp(&context); 85 EXPECT_TRUE(target.isLegal(op)); 86 EXPECT_EQ(2, callbackCalled1); 87 EXPECT_EQ(1, callbackCalled2); 88 op->destroy(); 89 } 90 } // namespace 91