1 //===- DataLayoutInterfacesTest.cpp - Unit Tests for Data Layouts ---------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 
9 #include "mlir/Interfaces/DataLayoutInterfaces.h"
10 #include "mlir/Dialect/DLTI/DLTI.h"
11 #include "mlir/IR/Builders.h"
12 #include "mlir/IR/BuiltinOps.h"
13 #include "mlir/IR/Dialect.h"
14 #include "mlir/IR/DialectImplementation.h"
15 #include "mlir/IR/OpDefinition.h"
16 #include "mlir/IR/OpImplementation.h"
17 #include "mlir/Parser.h"
18 
19 #include <gtest/gtest.h>
20 
21 using namespace mlir;
22 
23 namespace {
24 constexpr static llvm::StringLiteral kAttrName = "dltest.layout";
25 
26 /// Trivial array storage for the custom data layout spec attribute, just a list
27 /// of entries.
28 class DataLayoutSpecStorage : public AttributeStorage {
29 public:
30   using KeyTy = ArrayRef<DataLayoutEntryInterface>;
31 
32   DataLayoutSpecStorage(ArrayRef<DataLayoutEntryInterface> entries)
33       : entries(entries) {}
34 
35   bool operator==(const KeyTy &key) const { return key == entries; }
36 
37   static DataLayoutSpecStorage *construct(AttributeStorageAllocator &allocator,
38                                           const KeyTy &key) {
39     return new (allocator.allocate<DataLayoutSpecStorage>())
40         DataLayoutSpecStorage(allocator.copyInto(key));
41   }
42 
43   ArrayRef<DataLayoutEntryInterface> entries;
44 };
45 
46 /// Simple data layout spec containing a list of entries that always verifies
47 /// as valid.
48 struct CustomDataLayoutSpec
49     : public Attribute::AttrBase<CustomDataLayoutSpec, Attribute,
50                                  DataLayoutSpecStorage,
51                                  DataLayoutSpecInterface::Trait> {
52   using Base::Base;
53   static CustomDataLayoutSpec get(MLIRContext *ctx,
54                                   ArrayRef<DataLayoutEntryInterface> entries) {
55     return Base::get(ctx, entries);
56   }
57   CustomDataLayoutSpec
58   combineWith(ArrayRef<DataLayoutSpecInterface> specs) const {
59     return *this;
60   }
61   DataLayoutEntryListRef getEntries() const { return getImpl()->entries; }
62   LogicalResult verifySpec(Location loc) { return success(); }
63 };
64 
65 /// A type subject to data layout that exits the program if it is queried more
66 /// than once. Handy to check if the cache works.
67 struct SingleQueryType
68     : public Type::TypeBase<SingleQueryType, Type, TypeStorage,
69                             DataLayoutTypeInterface::Trait> {
70   using Base::Base;
71 
72   static SingleQueryType get(MLIRContext *ctx) { return Base::get(ctx); }
73 
74   unsigned getTypeSizeInBits(const DataLayout &layout,
75                              DataLayoutEntryListRef params) const {
76     static bool executed = false;
77     if (executed)
78       llvm::report_fatal_error("repeated call");
79 
80     executed = true;
81     return 1;
82   }
83 
84   unsigned getABIAlignment(const DataLayout &layout,
85                            DataLayoutEntryListRef params) {
86     static bool executed = false;
87     if (executed)
88       llvm::report_fatal_error("repeated call");
89 
90     executed = true;
91     return 2;
92   }
93 
94   unsigned getPreferredAlignment(const DataLayout &layout,
95                                  DataLayoutEntryListRef params) {
96     static bool executed = false;
97     if (executed)
98       llvm::report_fatal_error("repeated call");
99 
100     executed = true;
101     return 4;
102   }
103 };
104 
105 /// A types that is not subject to data layout.
106 struct TypeNoLayout : public Type::TypeBase<TypeNoLayout, Type, TypeStorage> {
107   using Base::Base;
108 
109   static TypeNoLayout get(MLIRContext *ctx) { return Base::get(ctx); }
110 };
111 
112 /// An op that serves as scope for data layout queries with the relevant
113 /// attribute attached. This can handle data layout requests for the built-in
114 /// types itself.
115 struct OpWithLayout : public Op<OpWithLayout, DataLayoutOpInterface::Trait> {
116   using Op::Op;
117 
118   static StringRef getOperationName() { return "dltest.op_with_layout"; }
119 
120   DataLayoutSpecInterface getDataLayoutSpec() {
121     return getOperation()->getAttrOfType<DataLayoutSpecInterface>(kAttrName);
122   }
123 
124   static unsigned getTypeSizeInBits(Type type, const DataLayout &dataLayout,
125                                     DataLayoutEntryListRef params) {
126     // Make a recursive query.
127     if (type.isa<FloatType>())
128       return dataLayout.getTypeSizeInBits(
129           IntegerType::get(type.getContext(), type.getIntOrFloatBitWidth()));
130 
131     // Handle built-in types that are not handled by the default process.
132     if (auto iType = type.dyn_cast<IntegerType>()) {
133       for (DataLayoutEntryInterface entry : params)
134         if (entry.getKey().dyn_cast<Type>() == type)
135           return 8 *
136                  entry.getValue().cast<IntegerAttr>().getValue().getZExtValue();
137       return 8 * iType.getIntOrFloatBitWidth();
138     }
139 
140     // Use the default process for everything else.
141     return detail::getDefaultTypeSize(type, dataLayout, params);
142   }
143 
144   static unsigned getTypeABIAlignment(Type type, const DataLayout &dataLayout,
145                                       DataLayoutEntryListRef params) {
146     return llvm::PowerOf2Ceil(getTypeSize(type, dataLayout, params));
147   }
148 
149   static unsigned getTypePreferredAlignment(Type type,
150                                             const DataLayout &dataLayout,
151                                             DataLayoutEntryListRef params) {
152     return 2 * getTypeABIAlignment(type, dataLayout, params);
153   }
154 };
155 
156 struct OpWith7BitByte
157     : public Op<OpWith7BitByte, DataLayoutOpInterface::Trait> {
158   using Op::Op;
159 
160   static StringRef getOperationName() { return "dltest.op_with_7bit_byte"; }
161 
162   DataLayoutSpecInterface getDataLayoutSpec() {
163     return getOperation()->getAttrOfType<DataLayoutSpecInterface>(kAttrName);
164   }
165 
166   // Bytes are assumed to be 7-bit here.
167   static unsigned getTypeSize(Type type, const DataLayout &dataLayout,
168                               DataLayoutEntryListRef params) {
169     return llvm::divideCeil(dataLayout.getTypeSizeInBits(type), 7);
170   }
171 };
172 
173 /// A dialect putting all the above together.
174 struct DLTestDialect : Dialect {
175   explicit DLTestDialect(MLIRContext *ctx)
176       : Dialect(getDialectNamespace(), ctx, TypeID::get<DLTestDialect>()) {
177     ctx->getOrLoadDialect<DLTIDialect>();
178     addAttributes<CustomDataLayoutSpec>();
179     addOperations<OpWithLayout, OpWith7BitByte>();
180     addTypes<SingleQueryType, TypeNoLayout>();
181   }
182   static StringRef getDialectNamespace() { return "dltest"; }
183 
184   void printAttribute(Attribute attr,
185                       DialectAsmPrinter &printer) const override {
186     printer << "spec<";
187     llvm::interleaveComma(attr.cast<CustomDataLayoutSpec>().getEntries(),
188                           printer);
189     printer << ">";
190   }
191 
192   Attribute parseAttribute(DialectAsmParser &parser, Type type) const override {
193     bool ok =
194         succeeded(parser.parseKeyword("spec")) && succeeded(parser.parseLess());
195     (void)ok;
196     assert(ok);
197     if (succeeded(parser.parseOptionalGreater()))
198       return CustomDataLayoutSpec::get(parser.getBuilder().getContext(), {});
199 
200     SmallVector<DataLayoutEntryInterface> entries;
201     do {
202       entries.emplace_back();
203       ok = succeeded(parser.parseAttribute(entries.back()));
204       assert(ok);
205     } while (succeeded(parser.parseOptionalComma()));
206     ok = succeeded(parser.parseGreater());
207     assert(ok);
208     return CustomDataLayoutSpec::get(parser.getBuilder().getContext(), entries);
209   }
210 
211   void printType(Type type, DialectAsmPrinter &printer) const override {
212     if (type.isa<SingleQueryType>())
213       printer << "single_query";
214     else
215       printer << "no_layout";
216   }
217 
218   Type parseType(DialectAsmParser &parser) const override {
219     bool ok = succeeded(parser.parseKeyword("single_query"));
220     (void)ok;
221     assert(ok);
222     return SingleQueryType::get(parser.getBuilder().getContext());
223   }
224 };
225 
226 } // end namespace
227 
228 TEST(DataLayout, FallbackDefault) {
229   const char *ir = R"MLIR(
230 module {}
231   )MLIR";
232 
233   DialectRegistry registry;
234   registry.insert<DLTIDialect, DLTestDialect>();
235   MLIRContext ctx(registry);
236 
237   OwningModuleRef module = parseSourceString(ir, &ctx);
238   DataLayout layout(module.get());
239   EXPECT_EQ(layout.getTypeSize(IntegerType::get(&ctx, 42)), 6u);
240   EXPECT_EQ(layout.getTypeSize(Float16Type::get(&ctx)), 2u);
241   EXPECT_EQ(layout.getTypeSizeInBits(IntegerType::get(&ctx, 42)), 42u);
242   EXPECT_EQ(layout.getTypeSizeInBits(Float16Type::get(&ctx)), 16u);
243   EXPECT_EQ(layout.getTypeABIAlignment(IntegerType::get(&ctx, 42)), 8u);
244   EXPECT_EQ(layout.getTypeABIAlignment(Float16Type::get(&ctx)), 2u);
245   EXPECT_EQ(layout.getTypePreferredAlignment(IntegerType::get(&ctx, 42)), 8u);
246   EXPECT_EQ(layout.getTypePreferredAlignment(Float16Type::get(&ctx)), 2u);
247 }
248 
249 TEST(DataLayout, NullSpec) {
250   const char *ir = R"MLIR(
251 "dltest.op_with_layout"() : () -> ()
252   )MLIR";
253 
254   DialectRegistry registry;
255   registry.insert<DLTIDialect, DLTestDialect>();
256   MLIRContext ctx(registry);
257 
258   OwningModuleRef module = parseSourceString(ir, &ctx);
259   auto op =
260       cast<DataLayoutOpInterface>(module->getBody()->getOperations().front());
261   DataLayout layout(op);
262   EXPECT_EQ(layout.getTypeSize(IntegerType::get(&ctx, 42)), 42u);
263   EXPECT_EQ(layout.getTypeSize(Float16Type::get(&ctx)), 16u);
264   EXPECT_EQ(layout.getTypeSizeInBits(IntegerType::get(&ctx, 42)), 8u * 42u);
265   EXPECT_EQ(layout.getTypeSizeInBits(Float16Type::get(&ctx)), 8u * 16u);
266   EXPECT_EQ(layout.getTypeABIAlignment(IntegerType::get(&ctx, 42)), 64u);
267   EXPECT_EQ(layout.getTypeABIAlignment(Float16Type::get(&ctx)), 16u);
268   EXPECT_EQ(layout.getTypePreferredAlignment(IntegerType::get(&ctx, 42)), 128u);
269   EXPECT_EQ(layout.getTypePreferredAlignment(Float16Type::get(&ctx)), 32u);
270 }
271 
272 TEST(DataLayout, EmptySpec) {
273   const char *ir = R"MLIR(
274 "dltest.op_with_layout"() { dltest.layout = #dltest.spec< > } : () -> ()
275   )MLIR";
276 
277   DialectRegistry registry;
278   registry.insert<DLTIDialect, DLTestDialect>();
279   MLIRContext ctx(registry);
280 
281   OwningModuleRef module = parseSourceString(ir, &ctx);
282   auto op =
283       cast<DataLayoutOpInterface>(module->getBody()->getOperations().front());
284   DataLayout layout(op);
285   EXPECT_EQ(layout.getTypeSize(IntegerType::get(&ctx, 42)), 42u);
286   EXPECT_EQ(layout.getTypeSize(Float16Type::get(&ctx)), 16u);
287   EXPECT_EQ(layout.getTypeSizeInBits(IntegerType::get(&ctx, 42)), 8u * 42u);
288   EXPECT_EQ(layout.getTypeSizeInBits(Float16Type::get(&ctx)), 8u * 16u);
289   EXPECT_EQ(layout.getTypeABIAlignment(IntegerType::get(&ctx, 42)), 64u);
290   EXPECT_EQ(layout.getTypeABIAlignment(Float16Type::get(&ctx)), 16u);
291   EXPECT_EQ(layout.getTypePreferredAlignment(IntegerType::get(&ctx, 42)), 128u);
292   EXPECT_EQ(layout.getTypePreferredAlignment(Float16Type::get(&ctx)), 32u);
293 }
294 
295 TEST(DataLayout, SpecWithEntries) {
296   const char *ir = R"MLIR(
297 "dltest.op_with_layout"() { dltest.layout = #dltest.spec<
298   #dlti.dl_entry<i42, 5>,
299   #dlti.dl_entry<i16, 6>
300 > } : () -> ()
301   )MLIR";
302 
303   DialectRegistry registry;
304   registry.insert<DLTIDialect, DLTestDialect>();
305   MLIRContext ctx(registry);
306 
307   OwningModuleRef module = parseSourceString(ir, &ctx);
308   auto op =
309       cast<DataLayoutOpInterface>(module->getBody()->getOperations().front());
310   DataLayout layout(op);
311   EXPECT_EQ(layout.getTypeSize(IntegerType::get(&ctx, 42)), 5u);
312   EXPECT_EQ(layout.getTypeSize(Float16Type::get(&ctx)), 6u);
313   EXPECT_EQ(layout.getTypeSizeInBits(IntegerType::get(&ctx, 42)), 40u);
314   EXPECT_EQ(layout.getTypeSizeInBits(Float16Type::get(&ctx)), 48u);
315   EXPECT_EQ(layout.getTypeABIAlignment(IntegerType::get(&ctx, 42)), 8u);
316   EXPECT_EQ(layout.getTypeABIAlignment(Float16Type::get(&ctx)), 8u);
317   EXPECT_EQ(layout.getTypePreferredAlignment(IntegerType::get(&ctx, 42)), 16u);
318   EXPECT_EQ(layout.getTypePreferredAlignment(Float16Type::get(&ctx)), 16u);
319 
320   EXPECT_EQ(layout.getTypeSize(IntegerType::get(&ctx, 32)), 32u);
321   EXPECT_EQ(layout.getTypeSize(Float32Type::get(&ctx)), 32u);
322   EXPECT_EQ(layout.getTypeSizeInBits(IntegerType::get(&ctx, 32)), 256u);
323   EXPECT_EQ(layout.getTypeSizeInBits(Float32Type::get(&ctx)), 256u);
324   EXPECT_EQ(layout.getTypeABIAlignment(IntegerType::get(&ctx, 32)), 32u);
325   EXPECT_EQ(layout.getTypeABIAlignment(Float32Type::get(&ctx)), 32u);
326   EXPECT_EQ(layout.getTypePreferredAlignment(IntegerType::get(&ctx, 32)), 64u);
327   EXPECT_EQ(layout.getTypePreferredAlignment(Float32Type::get(&ctx)), 64u);
328 }
329 
330 TEST(DataLayout, Caching) {
331   const char *ir = R"MLIR(
332 "dltest.op_with_layout"() { dltest.layout = #dltest.spec<> } : () -> ()
333   )MLIR";
334 
335   DialectRegistry registry;
336   registry.insert<DLTIDialect, DLTestDialect>();
337   MLIRContext ctx(registry);
338 
339   OwningModuleRef module = parseSourceString(ir, &ctx);
340   auto op =
341       cast<DataLayoutOpInterface>(module->getBody()->getOperations().front());
342   DataLayout layout(op);
343 
344   unsigned sum = 0;
345   sum += layout.getTypeSize(SingleQueryType::get(&ctx));
346   // The second call should hit the cache. If it does not, the function in
347   // SingleQueryType will be called and will abort the process.
348   sum += layout.getTypeSize(SingleQueryType::get(&ctx));
349   // Make sure the complier doesn't optimize away the query code.
350   EXPECT_EQ(sum, 2u);
351 
352   // A fresh data layout has a new cache, so the call to it should be dispatched
353   // down to the type and abort the proces.
354   DataLayout second(op);
355   ASSERT_DEATH(second.getTypeSize(SingleQueryType::get(&ctx)), "repeated call");
356 }
357 
358 TEST(DataLayout, CacheInvalidation) {
359   const char *ir = R"MLIR(
360 "dltest.op_with_layout"() { dltest.layout = #dltest.spec<
361   #dlti.dl_entry<i42, 5>,
362   #dlti.dl_entry<i16, 6>
363 > } : () -> ()
364   )MLIR";
365 
366   DialectRegistry registry;
367   registry.insert<DLTIDialect, DLTestDialect>();
368   MLIRContext ctx(registry);
369 
370   OwningModuleRef module = parseSourceString(ir, &ctx);
371   auto op =
372       cast<DataLayoutOpInterface>(module->getBody()->getOperations().front());
373   DataLayout layout(op);
374 
375   // Normal query is fine.
376   EXPECT_EQ(layout.getTypeSize(Float16Type::get(&ctx)), 6u);
377 
378   // Replace the data layout spec with a new, empty spec.
379   op->setAttr(kAttrName, CustomDataLayoutSpec::get(&ctx, {}));
380 
381   // Data layout is no longer valid and should trigger assertion when queried.
382 #ifndef NDEBUG
383   ASSERT_DEATH(layout.getTypeSize(Float16Type::get(&ctx)), "no longer valid");
384 #endif
385 }
386 
387 TEST(DataLayout, UnimplementedTypeInterface) {
388   const char *ir = R"MLIR(
389 "dltest.op_with_layout"() { dltest.layout = #dltest.spec<> } : () -> ()
390   )MLIR";
391 
392   DialectRegistry registry;
393   registry.insert<DLTIDialect, DLTestDialect>();
394   MLIRContext ctx(registry);
395 
396   OwningModuleRef module = parseSourceString(ir, &ctx);
397   auto op =
398       cast<DataLayoutOpInterface>(module->getBody()->getOperations().front());
399   DataLayout layout(op);
400 
401   ASSERT_DEATH(layout.getTypeSize(TypeNoLayout::get(&ctx)),
402                "neither the scoping op nor the type class provide data layout "
403                "information");
404 }
405 
406 TEST(DataLayout, SevenBitByte) {
407   const char *ir = R"MLIR(
408 "dltest.op_with_7bit_byte"() { dltest.layout = #dltest.spec<> } : () -> ()
409   )MLIR";
410 
411   DialectRegistry registry;
412   registry.insert<DLTIDialect, DLTestDialect>();
413   MLIRContext ctx(registry);
414 
415   OwningModuleRef module = parseSourceString(ir, &ctx);
416   auto op =
417       cast<DataLayoutOpInterface>(module->getBody()->getOperations().front());
418   DataLayout layout(op);
419 
420   EXPECT_EQ(layout.getTypeSizeInBits(IntegerType::get(&ctx, 42)), 42u);
421   EXPECT_EQ(layout.getTypeSizeInBits(IntegerType::get(&ctx, 32)), 32u);
422   EXPECT_EQ(layout.getTypeSize(IntegerType::get(&ctx, 42)), 6u);
423   EXPECT_EQ(layout.getTypeSize(IntegerType::get(&ctx, 32)), 5u);
424 }
425