1 //===- DataLayoutInterfacesTest.cpp - Unit Tests for Data Layouts ---------===// 2 // 3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. 4 // See https://llvm.org/LICENSE.txt for license information. 5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception 6 // 7 //===----------------------------------------------------------------------===// 8 9 #include "mlir/Interfaces/DataLayoutInterfaces.h" 10 #include "mlir/Dialect/DLTI/DLTI.h" 11 #include "mlir/IR/Builders.h" 12 #include "mlir/IR/BuiltinOps.h" 13 #include "mlir/IR/Dialect.h" 14 #include "mlir/IR/DialectImplementation.h" 15 #include "mlir/IR/OpDefinition.h" 16 #include "mlir/IR/OpImplementation.h" 17 #include "mlir/Parser.h" 18 19 #include <gtest/gtest.h> 20 21 using namespace mlir; 22 23 namespace { 24 constexpr static llvm::StringLiteral kAttrName = "dltest.layout"; 25 26 /// Trivial array storage for the custom data layout spec attribute, just a list 27 /// of entries. 28 class DataLayoutSpecStorage : public AttributeStorage { 29 public: 30 using KeyTy = ArrayRef<DataLayoutEntryInterface>; 31 32 DataLayoutSpecStorage(ArrayRef<DataLayoutEntryInterface> entries) 33 : entries(entries) {} 34 35 bool operator==(const KeyTy &key) const { return key == entries; } 36 37 static DataLayoutSpecStorage *construct(AttributeStorageAllocator &allocator, 38 const KeyTy &key) { 39 return new (allocator.allocate<DataLayoutSpecStorage>()) 40 DataLayoutSpecStorage(allocator.copyInto(key)); 41 } 42 43 ArrayRef<DataLayoutEntryInterface> entries; 44 }; 45 46 /// Simple data layout spec containing a list of entries that always verifies 47 /// as valid. 48 struct CustomDataLayoutSpec 49 : public Attribute::AttrBase<CustomDataLayoutSpec, Attribute, 50 DataLayoutSpecStorage, 51 DataLayoutSpecInterface::Trait> { 52 using Base::Base; 53 static CustomDataLayoutSpec get(MLIRContext *ctx, 54 ArrayRef<DataLayoutEntryInterface> entries) { 55 return Base::get(ctx, entries); 56 } 57 CustomDataLayoutSpec 58 combineWith(ArrayRef<DataLayoutSpecInterface> specs) const { 59 return *this; 60 } 61 DataLayoutEntryListRef getEntries() const { return getImpl()->entries; } 62 LogicalResult verifySpec(Location loc) { return success(); } 63 }; 64 65 /// A type subject to data layout that exits the program if it is queried more 66 /// than once. Handy to check if the cache works. 67 struct SingleQueryType 68 : public Type::TypeBase<SingleQueryType, Type, TypeStorage, 69 DataLayoutTypeInterface::Trait> { 70 using Base::Base; 71 72 static SingleQueryType get(MLIRContext *ctx) { return Base::get(ctx); } 73 74 unsigned getTypeSizeInBits(const DataLayout &layout, 75 DataLayoutEntryListRef params) const { 76 static bool executed = false; 77 if (executed) 78 llvm::report_fatal_error("repeated call"); 79 80 executed = true; 81 return 1; 82 } 83 84 unsigned getABIAlignment(const DataLayout &layout, 85 DataLayoutEntryListRef params) { 86 static bool executed = false; 87 if (executed) 88 llvm::report_fatal_error("repeated call"); 89 90 executed = true; 91 return 2; 92 } 93 94 unsigned getPreferredAlignment(const DataLayout &layout, 95 DataLayoutEntryListRef params) { 96 static bool executed = false; 97 if (executed) 98 llvm::report_fatal_error("repeated call"); 99 100 executed = true; 101 return 4; 102 } 103 }; 104 105 /// A types that is not subject to data layout. 106 struct TypeNoLayout : public Type::TypeBase<TypeNoLayout, Type, TypeStorage> { 107 using Base::Base; 108 109 static TypeNoLayout get(MLIRContext *ctx) { return Base::get(ctx); } 110 }; 111 112 /// An op that serves as scope for data layout queries with the relevant 113 /// attribute attached. This can handle data layout requests for the built-in 114 /// types itself. 115 struct OpWithLayout : public Op<OpWithLayout, DataLayoutOpInterface::Trait> { 116 using Op::Op; 117 118 static StringRef getOperationName() { return "dltest.op_with_layout"; } 119 120 DataLayoutSpecInterface getDataLayoutSpec() { 121 return getOperation()->getAttrOfType<DataLayoutSpecInterface>(kAttrName); 122 } 123 124 static unsigned getTypeSizeInBits(Type type, const DataLayout &dataLayout, 125 DataLayoutEntryListRef params) { 126 // Make a recursive query. 127 if (type.isa<FloatType>()) 128 return dataLayout.getTypeSizeInBits( 129 IntegerType::get(type.getContext(), type.getIntOrFloatBitWidth())); 130 131 // Handle built-in types that are not handled by the default process. 132 if (auto iType = type.dyn_cast<IntegerType>()) { 133 for (DataLayoutEntryInterface entry : params) 134 if (entry.getKey().dyn_cast<Type>() == type) 135 return 8 * 136 entry.getValue().cast<IntegerAttr>().getValue().getZExtValue(); 137 return 8 * iType.getIntOrFloatBitWidth(); 138 } 139 140 // Use the default process for everything else. 141 return detail::getDefaultTypeSize(type, dataLayout, params); 142 } 143 144 static unsigned getTypeABIAlignment(Type type, const DataLayout &dataLayout, 145 DataLayoutEntryListRef params) { 146 return llvm::PowerOf2Ceil(getTypeSize(type, dataLayout, params)); 147 } 148 149 static unsigned getTypePreferredAlignment(Type type, 150 const DataLayout &dataLayout, 151 DataLayoutEntryListRef params) { 152 return 2 * getTypeABIAlignment(type, dataLayout, params); 153 } 154 }; 155 156 struct OpWith7BitByte 157 : public Op<OpWith7BitByte, DataLayoutOpInterface::Trait> { 158 using Op::Op; 159 160 static StringRef getOperationName() { return "dltest.op_with_7bit_byte"; } 161 162 DataLayoutSpecInterface getDataLayoutSpec() { 163 return getOperation()->getAttrOfType<DataLayoutSpecInterface>(kAttrName); 164 } 165 166 // Bytes are assumed to be 7-bit here. 167 static unsigned getTypeSize(Type type, const DataLayout &dataLayout, 168 DataLayoutEntryListRef params) { 169 return llvm::divideCeil(dataLayout.getTypeSizeInBits(type), 7); 170 } 171 }; 172 173 /// A dialect putting all the above together. 174 struct DLTestDialect : Dialect { 175 explicit DLTestDialect(MLIRContext *ctx) 176 : Dialect(getDialectNamespace(), ctx, TypeID::get<DLTestDialect>()) { 177 ctx->getOrLoadDialect<DLTIDialect>(); 178 addAttributes<CustomDataLayoutSpec>(); 179 addOperations<OpWithLayout, OpWith7BitByte>(); 180 addTypes<SingleQueryType, TypeNoLayout>(); 181 } 182 static StringRef getDialectNamespace() { return "dltest"; } 183 184 void printAttribute(Attribute attr, 185 DialectAsmPrinter &printer) const override { 186 printer << "spec<"; 187 llvm::interleaveComma(attr.cast<CustomDataLayoutSpec>().getEntries(), 188 printer); 189 printer << ">"; 190 } 191 192 Attribute parseAttribute(DialectAsmParser &parser, Type type) const override { 193 bool ok = 194 succeeded(parser.parseKeyword("spec")) && succeeded(parser.parseLess()); 195 (void)ok; 196 assert(ok); 197 if (succeeded(parser.parseOptionalGreater())) 198 return CustomDataLayoutSpec::get(parser.getBuilder().getContext(), {}); 199 200 SmallVector<DataLayoutEntryInterface> entries; 201 do { 202 entries.emplace_back(); 203 ok = succeeded(parser.parseAttribute(entries.back())); 204 assert(ok); 205 } while (succeeded(parser.parseOptionalComma())); 206 ok = succeeded(parser.parseGreater()); 207 assert(ok); 208 return CustomDataLayoutSpec::get(parser.getBuilder().getContext(), entries); 209 } 210 211 void printType(Type type, DialectAsmPrinter &printer) const override { 212 if (type.isa<SingleQueryType>()) 213 printer << "single_query"; 214 else 215 printer << "no_layout"; 216 } 217 218 Type parseType(DialectAsmParser &parser) const override { 219 bool ok = succeeded(parser.parseKeyword("single_query")); 220 (void)ok; 221 assert(ok); 222 return SingleQueryType::get(parser.getBuilder().getContext()); 223 } 224 }; 225 226 } // end namespace 227 228 TEST(DataLayout, FallbackDefault) { 229 const char *ir = R"MLIR( 230 "dltest.op_with_layout"() : () -> () 231 )MLIR"; 232 233 DialectRegistry registry; 234 registry.insert<DLTIDialect, DLTestDialect>(); 235 MLIRContext ctx(registry); 236 237 OwningModuleRef module = parseSourceString(ir, &ctx); 238 auto op = 239 cast<DataLayoutOpInterface>(module->getBody()->getOperations().front()); 240 DataLayout layout(op); 241 EXPECT_EQ(layout.getTypeSize(IntegerType::get(&ctx, 42)), 6u); 242 EXPECT_EQ(layout.getTypeSize(Float16Type::get(&ctx)), 2u); 243 EXPECT_EQ(layout.getTypeSizeInBits(IntegerType::get(&ctx, 42)), 42u); 244 EXPECT_EQ(layout.getTypeSizeInBits(Float16Type::get(&ctx)), 16u); 245 EXPECT_EQ(layout.getTypeABIAlignment(IntegerType::get(&ctx, 42)), 8u); 246 EXPECT_EQ(layout.getTypeABIAlignment(Float16Type::get(&ctx)), 2u); 247 EXPECT_EQ(layout.getTypePreferredAlignment(IntegerType::get(&ctx, 42)), 8u); 248 EXPECT_EQ(layout.getTypePreferredAlignment(Float16Type::get(&ctx)), 2u); 249 } 250 251 TEST(DataLayout, EmptySpec) { 252 const char *ir = R"MLIR( 253 "dltest.op_with_layout"() { dltest.layout = #dltest.spec< > } : () -> () 254 )MLIR"; 255 256 DialectRegistry registry; 257 registry.insert<DLTIDialect, DLTestDialect>(); 258 MLIRContext ctx(registry); 259 260 OwningModuleRef module = parseSourceString(ir, &ctx); 261 auto op = 262 cast<DataLayoutOpInterface>(module->getBody()->getOperations().front()); 263 DataLayout layout(op); 264 EXPECT_EQ(layout.getTypeSize(IntegerType::get(&ctx, 42)), 42u); 265 EXPECT_EQ(layout.getTypeSize(Float16Type::get(&ctx)), 16u); 266 EXPECT_EQ(layout.getTypeSizeInBits(IntegerType::get(&ctx, 42)), 8u * 42u); 267 EXPECT_EQ(layout.getTypeSizeInBits(Float16Type::get(&ctx)), 8u * 16u); 268 EXPECT_EQ(layout.getTypeABIAlignment(IntegerType::get(&ctx, 42)), 64u); 269 EXPECT_EQ(layout.getTypeABIAlignment(Float16Type::get(&ctx)), 16u); 270 EXPECT_EQ(layout.getTypePreferredAlignment(IntegerType::get(&ctx, 42)), 128u); 271 EXPECT_EQ(layout.getTypePreferredAlignment(Float16Type::get(&ctx)), 32u); 272 } 273 274 TEST(DataLayout, SpecWithEntries) { 275 const char *ir = R"MLIR( 276 "dltest.op_with_layout"() { dltest.layout = #dltest.spec< 277 #dlti.dl_entry<i42, 5>, 278 #dlti.dl_entry<i16, 6> 279 > } : () -> () 280 )MLIR"; 281 282 DialectRegistry registry; 283 registry.insert<DLTIDialect, DLTestDialect>(); 284 MLIRContext ctx(registry); 285 286 OwningModuleRef module = parseSourceString(ir, &ctx); 287 auto op = 288 cast<DataLayoutOpInterface>(module->getBody()->getOperations().front()); 289 DataLayout layout(op); 290 EXPECT_EQ(layout.getTypeSize(IntegerType::get(&ctx, 42)), 5u); 291 EXPECT_EQ(layout.getTypeSize(Float16Type::get(&ctx)), 6u); 292 EXPECT_EQ(layout.getTypeSizeInBits(IntegerType::get(&ctx, 42)), 40u); 293 EXPECT_EQ(layout.getTypeSizeInBits(Float16Type::get(&ctx)), 48u); 294 EXPECT_EQ(layout.getTypeABIAlignment(IntegerType::get(&ctx, 42)), 8u); 295 EXPECT_EQ(layout.getTypeABIAlignment(Float16Type::get(&ctx)), 8u); 296 EXPECT_EQ(layout.getTypePreferredAlignment(IntegerType::get(&ctx, 42)), 16u); 297 EXPECT_EQ(layout.getTypePreferredAlignment(Float16Type::get(&ctx)), 16u); 298 299 EXPECT_EQ(layout.getTypeSize(IntegerType::get(&ctx, 32)), 32u); 300 EXPECT_EQ(layout.getTypeSize(Float32Type::get(&ctx)), 32u); 301 EXPECT_EQ(layout.getTypeSizeInBits(IntegerType::get(&ctx, 32)), 256u); 302 EXPECT_EQ(layout.getTypeSizeInBits(Float32Type::get(&ctx)), 256u); 303 EXPECT_EQ(layout.getTypeABIAlignment(IntegerType::get(&ctx, 32)), 32u); 304 EXPECT_EQ(layout.getTypeABIAlignment(Float32Type::get(&ctx)), 32u); 305 EXPECT_EQ(layout.getTypePreferredAlignment(IntegerType::get(&ctx, 32)), 64u); 306 EXPECT_EQ(layout.getTypePreferredAlignment(Float32Type::get(&ctx)), 64u); 307 } 308 309 TEST(DataLayout, Caching) { 310 const char *ir = R"MLIR( 311 "dltest.op_with_layout"() { dltest.layout = #dltest.spec<> } : () -> () 312 )MLIR"; 313 314 DialectRegistry registry; 315 registry.insert<DLTIDialect, DLTestDialect>(); 316 MLIRContext ctx(registry); 317 318 OwningModuleRef module = parseSourceString(ir, &ctx); 319 auto op = 320 cast<DataLayoutOpInterface>(module->getBody()->getOperations().front()); 321 DataLayout layout(op); 322 323 unsigned sum = 0; 324 sum += layout.getTypeSize(SingleQueryType::get(&ctx)); 325 // The second call should hit the cache. If it does not, the function in 326 // SingleQueryType will be called and will abort the process. 327 sum += layout.getTypeSize(SingleQueryType::get(&ctx)); 328 // Make sure the complier doesn't optimize away the query code. 329 EXPECT_EQ(sum, 2u); 330 331 // A fresh data layout has a new cache, so the call to it should be dispatched 332 // down to the type and abort the proces. 333 DataLayout second(op); 334 ASSERT_DEATH(second.getTypeSize(SingleQueryType::get(&ctx)), "repeated call"); 335 } 336 337 TEST(DataLayout, CacheInvalidation) { 338 const char *ir = R"MLIR( 339 "dltest.op_with_layout"() { dltest.layout = #dltest.spec< 340 #dlti.dl_entry<i42, 5>, 341 #dlti.dl_entry<i16, 6> 342 > } : () -> () 343 )MLIR"; 344 345 DialectRegistry registry; 346 registry.insert<DLTIDialect, DLTestDialect>(); 347 MLIRContext ctx(registry); 348 349 OwningModuleRef module = parseSourceString(ir, &ctx); 350 auto op = 351 cast<DataLayoutOpInterface>(module->getBody()->getOperations().front()); 352 DataLayout layout(op); 353 354 // Normal query is fine. 355 EXPECT_EQ(layout.getTypeSize(Float16Type::get(&ctx)), 6u); 356 357 // Replace the data layout spec with a new, empty spec. 358 op->setAttr(kAttrName, CustomDataLayoutSpec::get(&ctx, {})); 359 360 // Data layout is no longer valid and should trigger assertion when queried. 361 #ifndef NDEBUG 362 ASSERT_DEATH(layout.getTypeSize(Float16Type::get(&ctx)), "no longer valid"); 363 #endif 364 } 365 366 TEST(DataLayout, UnimplementedTypeInterface) { 367 const char *ir = R"MLIR( 368 "dltest.op_with_layout"() { dltest.layout = #dltest.spec<> } : () -> () 369 )MLIR"; 370 371 DialectRegistry registry; 372 registry.insert<DLTIDialect, DLTestDialect>(); 373 MLIRContext ctx(registry); 374 375 OwningModuleRef module = parseSourceString(ir, &ctx); 376 auto op = 377 cast<DataLayoutOpInterface>(module->getBody()->getOperations().front()); 378 DataLayout layout(op); 379 380 ASSERT_DEATH(layout.getTypeSize(TypeNoLayout::get(&ctx)), 381 "neither the scoping op nor the type class provide data layout " 382 "information"); 383 } 384 385 TEST(DataLayout, SevenBitByte) { 386 const char *ir = R"MLIR( 387 "dltest.op_with_7bit_byte"() { dltest.layout = #dltest.spec<> } : () -> () 388 )MLIR"; 389 390 DialectRegistry registry; 391 registry.insert<DLTIDialect, DLTestDialect>(); 392 MLIRContext ctx(registry); 393 394 OwningModuleRef module = parseSourceString(ir, &ctx); 395 auto op = 396 cast<DataLayoutOpInterface>(module->getBody()->getOperations().front()); 397 DataLayout layout(op); 398 399 EXPECT_EQ(layout.getTypeSizeInBits(IntegerType::get(&ctx, 42)), 42u); 400 EXPECT_EQ(layout.getTypeSizeInBits(IntegerType::get(&ctx, 32)), 32u); 401 EXPECT_EQ(layout.getTypeSize(IntegerType::get(&ctx, 42)), 6u); 402 EXPECT_EQ(layout.getTypeSize(IntegerType::get(&ctx, 32)), 5u); 403 } 404