1 //===- InterfaceAttachmentTest.cpp - Test attaching interfaces ------------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 //
// This implements the tests for attaching interfaces to attributes, types, and
// operations without having to specify them on the attribute, type, or
// operation class directly.
11 //
12 //===----------------------------------------------------------------------===//
13
14 #include "mlir/IR/BuiltinAttributes.h"
15 #include "mlir/IR/BuiltinDialect.h"
16 #include "mlir/IR/BuiltinOps.h"
17 #include "mlir/IR/BuiltinTypes.h"
18 #include "gtest/gtest.h"
19
20 #include "../../test/lib/Dialect/Test/TestAttributes.h"
21 #include "../../test/lib/Dialect/Test/TestDialect.h"
22 #include "../../test/lib/Dialect/Test/TestTypes.h"
23 #include "mlir/IR/OwningOpRef.h"
24
25 using namespace mlir;
26 using namespace test;
27
28 namespace {
29
30 /// External interface model for the integer type. Only provides non-default
31 /// methods.
32 struct Model
33 : public TestExternalTypeInterface::ExternalModel<Model, IntegerType> {
getBitwidthPlusArg__anon49fb1b390111::Model34 unsigned getBitwidthPlusArg(Type type, unsigned arg) const {
35 return type.getIntOrFloatBitWidth() + arg;
36 }
37
staticGetSomeValuePlusArg__anon49fb1b390111::Model38 static unsigned staticGetSomeValuePlusArg(unsigned arg) { return 42 + arg; }
39 };
40
41 /// External interface model for the float type. Provides non-deafult and
42 /// overrides default methods.
43 struct OverridingModel
44 : public TestExternalTypeInterface::ExternalModel<OverridingModel,
45 FloatType> {
getBitwidthPlusArg__anon49fb1b390111::OverridingModel46 unsigned getBitwidthPlusArg(Type type, unsigned arg) const {
47 return type.getIntOrFloatBitWidth() + arg;
48 }
49
staticGetSomeValuePlusArg__anon49fb1b390111::OverridingModel50 static unsigned staticGetSomeValuePlusArg(unsigned arg) { return 42 + arg; }
51
getBitwidthPlusDoubleArgument__anon49fb1b390111::OverridingModel52 unsigned getBitwidthPlusDoubleArgument(Type type, unsigned arg) const {
53 return 128;
54 }
55
staticGetArgument__anon49fb1b390111::OverridingModel56 static unsigned staticGetArgument(unsigned arg) { return 420; }
57 };
58
TEST(InterfaceAttachment, Type) {
  MLIRContext context;

  // Check that the type has no interface.
  IntegerType i8 = IntegerType::get(&context, 8);
  ASSERT_FALSE(i8.isa<TestExternalTypeInterface>());

  // Attach an interface and check that the type now has the interface. Model
  // only implements the non-default methods, so the last two calls below are
  // served by the interface's default implementations.
  IntegerType::attachInterface<Model>(context);
  TestExternalTypeInterface iface = i8.dyn_cast<TestExternalTypeInterface>();
  ASSERT_TRUE(iface != nullptr);
  EXPECT_EQ(iface.getBitwidthPlusArg(10), 18u);
  EXPECT_EQ(iface.staticGetSomeValuePlusArg(0), 42u);
  EXPECT_EQ(iface.getBitwidthPlusDoubleArgument(2), 12u);
  EXPECT_EQ(iface.staticGetArgument(17), 17u);

  // Same, but with the default implementations overridden by OverridingModel,
  // which returns the constants 128 and 420.
  FloatType flt = Float32Type::get(&context);
  ASSERT_FALSE(flt.isa<TestExternalTypeInterface>());
  Float32Type::attachInterface<OverridingModel>(context);
  iface = flt.dyn_cast<TestExternalTypeInterface>();
  ASSERT_TRUE(iface != nullptr);
  EXPECT_EQ(iface.getBitwidthPlusArg(10), 42u);
  EXPECT_EQ(iface.staticGetSomeValuePlusArg(10), 52u);
  EXPECT_EQ(iface.getBitwidthPlusDoubleArgument(3), 128u);
  EXPECT_EQ(iface.staticGetArgument(17), 420u);

  // Other contexts shouldn't have the interface attached.
  MLIRContext other;
  IntegerType i8other = IntegerType::get(&other, 8);
  EXPECT_FALSE(i8other.isa<TestExternalTypeInterface>());
}
91
92 /// External interface model for the test type from the test dialect.
93 struct TestTypeModel
94 : public TestExternalTypeInterface::ExternalModel<TestTypeModel,
95 test::TestType> {
getBitwidthPlusArg__anon49fb1b390111::TestTypeModel96 unsigned getBitwidthPlusArg(Type type, unsigned arg) const { return arg; }
97
staticGetSomeValuePlusArg__anon49fb1b390111::TestTypeModel98 static unsigned staticGetSomeValuePlusArg(unsigned arg) { return 10 + arg; }
99 };
100
TEST(InterfaceAttachment, TypeDelayedContextConstruct) {
  // Register the test dialect along with an extension that attaches
  // TestTypeModel to test::TestType.
  DialectRegistry dialects;
  dialects.insert<test::TestDialect>();
  dialects.addExtension(+[](MLIRContext *ctx, test::TestDialect *) {
    test::TestType::attachInterface<TestTypeModel>(*ctx);
  });

  // A context constructed from that registry exposes the interface on the
  // type once the dialect is loaded.
  MLIRContext context(dialects);
  context.loadDialect<test::TestDialect>();
  auto model =
      test::TestType::get(&context).dyn_cast<TestExternalTypeInterface>();
  ASSERT_TRUE(model != nullptr);

  // TestTypeModel echoes the argument and adds 10, respectively.
  EXPECT_EQ(model.getBitwidthPlusArg(42), 42u);
  EXPECT_EQ(model.staticGetSomeValuePlusArg(10), 20u);
}
119
TEST(InterfaceAttachment, TypeDelayedContextAppend) {
  // Registry carrying the test dialect and an extension that attaches
  // TestTypeModel to test::TestType.
  DialectRegistry pending;
  pending.insert<test::TestDialect>();
  pending.addExtension(+[](MLIRContext *ctx, test::TestDialect *) {
    test::TestType::attachInterface<TestTypeModel>(*ctx);
  });

  // Load the dialect into a context that has not seen the registry yet: the
  // type exists but does not implement the interface.
  MLIRContext context;
  context.loadDialect<test::TestDialect>();
  auto testType = test::TestType::get(&context);
  EXPECT_FALSE(testType.isa<TestExternalTypeInterface>());

  // Appending the registry makes the interface available on objects of the
  // already-loaded dialect, including the existing type.
  context.appendDialectRegistry(pending);
  EXPECT_TRUE(testType.isa<TestExternalTypeInterface>());
}
137
TEST(InterfaceAttachment, RepeatedRegistration) {
  DialectRegistry registry;
  registry.addExtension(+[](MLIRContext *ctx, BuiltinDialect *dialect) {
    IntegerType::attachInterface<Model>(*ctx);
  });
  MLIRContext context(registry);

  // Shouldn't fail on repeated registration through the dialect registry.
  context.appendDialectRegistry(registry);

  // The interface must still be attached after the repeated registration.
  IntegerType i8 = IntegerType::get(&context, 8);
  EXPECT_TRUE(i8.isa<TestExternalTypeInterface>());
}
148
TEST(InterfaceAttachment, TypeBuiltinDelayed) {
  // The builtin dialect needs no explicit registration or loading, but delayed
  // interface registration through the registry must still work for it.
  DialectRegistry registry;
  registry.addExtension(+[](MLIRContext *ctx, BuiltinDialect *dialect) {
    IntegerType::attachInterface<Model>(*ctx);
  });

  // Constructing the context from the registry attaches the interface.
  MLIRContext context(registry);
  IntegerType i16 = IntegerType::get(&context, 16);
  EXPECT_TRUE(i16.isa<TestExternalTypeInterface>());

  // Appending the registry to an existing context attaches it as well.
  MLIRContext initiallyEmpty;
  IntegerType i32 = IntegerType::get(&initiallyEmpty, 32);
  EXPECT_FALSE(i32.isa<TestExternalTypeInterface>());
  initiallyEmpty.appendDialectRegistry(registry);
  EXPECT_TRUE(i32.isa<TestExternalTypeInterface>());
}
167
168 /// The interface provides a default implementation that expects
169 /// ConcreteType::getWidth to exist, which is the case for IntegerType. So this
170 /// just derives from the ExternalModel.
171 struct TestExternalFallbackTypeIntegerModel
172 : public TestExternalFallbackTypeInterface::ExternalModel<
173 TestExternalFallbackTypeIntegerModel, IntegerType> {};
174
175 /// The interface provides a default implementation that expects
176 /// ConcreteType::getWidth to exist, which is *not* the case for VectorType. Use
177 /// FallbackModel instead to override this and make sure the code still compiles
178 /// because we never instantiate the ExternalModel class template with a
179 /// template argument that would have led to compilation failures.
180 struct TestExternalFallbackTypeVectorModel
181 : public TestExternalFallbackTypeInterface::FallbackModel<
182 TestExternalFallbackTypeVectorModel> {
getBitwidth__anon49fb1b390111::TestExternalFallbackTypeVectorModel183 unsigned getBitwidth(Type type) const {
184 IntegerType elementType = type.cast<VectorType>()
185 .getElementType()
186 .dyn_cast_or_null<IntegerType>();
187 return elementType ? elementType.getWidth() : 0;
188 }
189 };
190
TEST(InterfaceAttachment, Fallback) {
  MLIRContext context;

  // Just check that we can attach the interface through an ExternalModel that
  // relies entirely on the default implementation.
  IntegerType i8 = IntegerType::get(&context, 8);
  ASSERT_FALSE(i8.isa<TestExternalFallbackTypeInterface>());
  IntegerType::attachInterface<TestExternalFallbackTypeIntegerModel>(context);
  ASSERT_TRUE(i8.isa<TestExternalFallbackTypeInterface>());

  // Attach a FallbackModel to VectorType and actually call the method. This
  // compiles and runs because the default implementation, which would need
  // VectorType::getWidth, is never instantiated for VectorType; the call is
  // served by TestExternalFallbackTypeVectorModel::getBitwidth instead.
  VectorType vec = VectorType::get({42}, i8);
  ASSERT_FALSE(vec.isa<TestExternalFallbackTypeInterface>());
  VectorType::attachInterface<TestExternalFallbackTypeVectorModel>(context);
  ASSERT_TRUE(vec.isa<TestExternalFallbackTypeInterface>());
  EXPECT_EQ(vec.cast<TestExternalFallbackTypeInterface>().getBitwidth(), 8u);
}
207
208 /// External model for attribute interfaces.
209 struct TestExternalIntegerAttrModel
210 : public TestExternalAttrInterface::ExternalModel<
211 TestExternalIntegerAttrModel, IntegerAttr> {
getDialectPtr__anon49fb1b390111::TestExternalIntegerAttrModel212 const Dialect *getDialectPtr(Attribute attr) const {
213 return &attr.cast<IntegerAttr>().getDialect();
214 }
215
getSomeNumber__anon49fb1b390111::TestExternalIntegerAttrModel216 static int getSomeNumber() { return 42; }
217 };
218
TEST(InterfaceAttachment, Attribute) {
  MLIRContext context;

  // Attribute interfaces share the attachment machinery with type interfaces,
  // so only the basic flow is exercised here.
  IntegerAttr fortyTwo = IntegerAttr::get(IntegerType::get(&context, 32), 42);
  ASSERT_FALSE(fortyTwo.isa<TestExternalAttrInterface>());

  IntegerAttr::attachInterface<TestExternalIntegerAttrModel>(context);
  auto attrIface = fortyTwo.dyn_cast<TestExternalAttrInterface>();
  ASSERT_TRUE(attrIface != nullptr);

  // The model reports the attribute's own dialect and a constant marker.
  EXPECT_EQ(attrIface.getDialectPtr(), &fortyTwo.getDialect());
  EXPECT_EQ(attrIface.getSomeNumber(), 42);
}
232
233 /// External model for an interface attachable to a non-builtin attribute.
234 struct TestExternalSimpleAAttrModel
235 : public TestExternalAttrInterface::ExternalModel<
236 TestExternalSimpleAAttrModel, test::SimpleAAttr> {
getDialectPtr__anon49fb1b390111::TestExternalSimpleAAttrModel237 const Dialect *getDialectPtr(Attribute attr) const {
238 return &attr.getDialect();
239 }
240
getSomeNumber__anon49fb1b390111::TestExternalSimpleAAttrModel241 static int getSomeNumber() { return 21; }
242 };
243
// Uses the same `InterfaceAttachment` test suite name as every other test in
// this file, so suite-level gtest filters select it too.
TEST(InterfaceAttachment, AttributeDelayed) {
  // Attribute interfaces use the exact same mechanism as types, so just check
  // that delayed registration works for attributes.
  DialectRegistry registry;
  registry.insert<test::TestDialect>();
  registry.addExtension(+[](MLIRContext *ctx, test::TestDialect *dialect) {
    test::SimpleAAttr::attachInterface<TestExternalSimpleAAttrModel>(*ctx);
  });

  // Constructing the context from the registry attaches the interface.
  MLIRContext context(registry);
  context.loadDialect<test::TestDialect>();
  auto attr = test::SimpleAAttr::get(&context);
  EXPECT_TRUE(attr.isa<TestExternalAttrInterface>());

  // Appending the registry to a context with the dialect already loaded
  // attaches it as well.
  MLIRContext initiallyEmpty;
  initiallyEmpty.loadDialect<test::TestDialect>();
  attr = test::SimpleAAttr::get(&initiallyEmpty);
  EXPECT_FALSE(attr.isa<TestExternalAttrInterface>());
  initiallyEmpty.appendDialectRegistry(registry);
  EXPECT_TRUE(attr.isa<TestExternalAttrInterface>());
}
265
266 /// External interface model for the module operation. Only provides non-default
267 /// methods.
268 struct TestExternalOpModel
269 : public TestExternalOpInterface::ExternalModel<TestExternalOpModel,
270 ModuleOp> {
getNameLengthPlusArg__anon49fb1b390111::TestExternalOpModel271 unsigned getNameLengthPlusArg(Operation *op, unsigned arg) const {
272 return op->getName().getStringRef().size() + arg;
273 }
274
getNameLengthPlusArgTwice__anon49fb1b390111::TestExternalOpModel275 static unsigned getNameLengthPlusArgTwice(unsigned arg) {
276 return ModuleOp::getOperationName().size() + 2 * arg;
277 }
278 };
279
280 /// External interface model for the func operation. Provides non-deafult and
281 /// overrides default methods.
282 struct TestExternalOpOverridingModel
283 : public TestExternalOpInterface::FallbackModel<
284 TestExternalOpOverridingModel> {
getNameLengthPlusArg__anon49fb1b390111::TestExternalOpOverridingModel285 unsigned getNameLengthPlusArg(Operation *op, unsigned arg) const {
286 return op->getName().getStringRef().size() + arg;
287 }
288
getNameLengthPlusArgTwice__anon49fb1b390111::TestExternalOpOverridingModel289 static unsigned getNameLengthPlusArgTwice(unsigned arg) {
290 return UnrealizedConversionCastOp::getOperationName().size() + 2 * arg;
291 }
292
getNameLengthTimesArg__anon49fb1b390111::TestExternalOpOverridingModel293 unsigned getNameLengthTimesArg(Operation *op, unsigned arg) const {
294 return 42;
295 }
296
getNameLengthMinusArg__anon49fb1b390111::TestExternalOpOverridingModel297 static unsigned getNameLengthMinusArg(unsigned arg) { return 21; }
298 };
299
TEST(InterfaceAttachment, Operation) {
  MLIRContext context;
  OpBuilder builder(&context);

  // Initially, the operation doesn't have the interface.
  OwningOpRef<ModuleOp> moduleOp =
      builder.create<ModuleOp>(UnknownLoc::get(&context));
  ASSERT_FALSE(isa<TestExternalOpInterface>(moduleOp->getOperation()));

  // We can attach an external interface and now the operation has it.
  // TestExternalOpModel does not implement getNameLengthTimesArg or
  // getNameLengthMinusArg, so those two calls use the interface defaults.
  ModuleOp::attachInterface<TestExternalOpModel>(context);
  auto iface = dyn_cast<TestExternalOpInterface>(moduleOp->getOperation());
  ASSERT_TRUE(iface != nullptr);
  EXPECT_EQ(iface.getNameLengthPlusArg(10), 24u);
  EXPECT_EQ(iface.getNameLengthTimesArg(3), 42u);
  EXPECT_EQ(iface.getNameLengthPlusArgTwice(18), 50u);
  EXPECT_EQ(iface.getNameLengthMinusArg(5), 9u);

  // Default implementations can be overridden; the overriding model returns
  // the constants 42 and 21 regardless of the argument.
  OwningOpRef<UnrealizedConversionCastOp> castOp =
      builder.create<UnrealizedConversionCastOp>(UnknownLoc::get(&context),
                                                 TypeRange(), ValueRange());
  ASSERT_FALSE(isa<TestExternalOpInterface>(castOp->getOperation()));
  UnrealizedConversionCastOp::attachInterface<TestExternalOpOverridingModel>(
      context);
  iface = dyn_cast<TestExternalOpInterface>(castOp->getOperation());
  ASSERT_TRUE(iface != nullptr);
  EXPECT_EQ(iface.getNameLengthPlusArg(10), 44u);
  EXPECT_EQ(iface.getNameLengthTimesArg(0), 42u);
  EXPECT_EQ(iface.getNameLengthPlusArgTwice(8), 50u);
  EXPECT_EQ(iface.getNameLengthMinusArg(1000), 21u);

  // Another context doesn't have the interfaces registered.
  MLIRContext other;
  OwningOpRef<ModuleOp> otherModuleOp =
      ModuleOp::create(UnknownLoc::get(&other));
  ASSERT_FALSE(isa<TestExternalOpInterface>(otherModuleOp->getOperation()));
}
338
339 template <class ConcreteOp>
340 struct TestExternalTestOpModel
341 : public TestExternalOpInterface::ExternalModel<
342 TestExternalTestOpModel<ConcreteOp>, ConcreteOp> {
getNameLengthPlusArg__anon49fb1b390111::TestExternalTestOpModel343 unsigned getNameLengthPlusArg(Operation *op, unsigned arg) const {
344 return op->getName().getStringRef().size() + arg;
345 }
346
getNameLengthPlusArgTwice__anon49fb1b390111::TestExternalTestOpModel347 static unsigned getNameLengthPlusArgTwice(unsigned arg) {
348 return ConcreteOp::getOperationName().size() + 2 * arg;
349 }
350 };
351
TEST(InterfaceAttachment, OperationDelayedContextConstruct) {
  // Extensions attach the interface to builtin.module and to test::OpJ and
  // test::OpH; test::OpI is deliberately left without it.
  DialectRegistry registry;
  registry.insert<test::TestDialect>();
  registry.addExtension(+[](MLIRContext *ctx, BuiltinDialect *) {
    ModuleOp::attachInterface<TestExternalOpModel>(*ctx);
  });
  registry.addExtension(+[](MLIRContext *ctx, test::TestDialect *) {
    test::OpJ::attachInterface<TestExternalTestOpModel<test::OpJ>>(*ctx);
    test::OpH::attachInterface<TestExternalTestOpModel<test::OpH>>(*ctx);
  });

  // When the context is built from the registry, the interfaces are available
  // on operations right away.
  MLIRContext context(registry);
  context.loadDialect<test::TestDialect>();

  OwningOpRef<ModuleOp> module = ModuleOp::create(UnknownLoc::get(&context));
  Block *body = module->getBody();
  OpBuilder builder(body, body->begin());
  Location loc = builder.getUnknownLoc();
  auto opJ = builder.create<test::OpJ>(loc, builder.getI32Type());
  auto jResult = opJ.getResult();
  auto opH = builder.create<test::OpH>(loc, jResult);
  auto opI = builder.create<test::OpI>(loc, jResult);

  auto hasIface = [](Operation *op) {
    return isa<TestExternalOpInterface>(op);
  };
  EXPECT_TRUE(hasIface(module->getOperation()));
  EXPECT_TRUE(hasIface(opJ.getOperation()));
  EXPECT_TRUE(hasIface(opH.getOperation()));
  EXPECT_FALSE(hasIface(opI.getOperation()));
}
382
TEST(InterfaceAttachment, OperationDelayedContextAppend) {
  // Extensions attach the interface to builtin.module and to test::OpJ and
  // test::OpH; test::OpI is deliberately left without it.
  DialectRegistry registry;
  registry.insert<test::TestDialect>();
  registry.addExtension(+[](MLIRContext *ctx, BuiltinDialect *) {
    ModuleOp::attachInterface<TestExternalOpModel>(*ctx);
  });
  registry.addExtension(+[](MLIRContext *ctx, test::TestDialect *) {
    test::OpJ::attachInterface<TestExternalTestOpModel<test::OpJ>>(*ctx);
    test::OpH::attachInterface<TestExternalTestOpModel<test::OpH>>(*ctx);
  });

  // Build the context and create the operations first; only afterwards is the
  // registry appended, which must make the interfaces available.
  MLIRContext context;
  context.loadDialect<test::TestDialect>();

  OwningOpRef<ModuleOp> module = ModuleOp::create(UnknownLoc::get(&context));
  Block *body = module->getBody();
  OpBuilder builder(body, body->begin());
  Location loc = builder.getUnknownLoc();
  auto opJ = builder.create<test::OpJ>(loc, builder.getI32Type());
  auto jResult = opJ.getResult();
  auto opH = builder.create<test::OpH>(loc, jResult);
  auto opI = builder.create<test::OpI>(loc, jResult);

  auto hasIface = [](Operation *op) {
    return isa<TestExternalOpInterface>(op);
  };

  // Nothing has the interface before the registry is appended.
  EXPECT_FALSE(hasIface(module->getOperation()));
  EXPECT_FALSE(hasIface(opJ.getOperation()));
  EXPECT_FALSE(hasIface(opH.getOperation()));
  EXPECT_FALSE(hasIface(opI.getOperation()));

  context.appendDialectRegistry(registry);

  // Afterwards, every op covered by an extension has it; OpI still does not.
  EXPECT_TRUE(hasIface(module->getOperation()));
  EXPECT_TRUE(hasIface(opJ.getOperation()));
  EXPECT_TRUE(hasIface(opH.getOperation()));
  EXPECT_FALSE(hasIface(opI.getOperation()));
}
420
421 } // namespace
422