1 //===- TestTypes.cpp - MLIR Test Dialect Types ------------------*- C++ -*-===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 //
9 // This file contains types defined by the TestDialect for testing various
10 // features of MLIR.
11 //
12 //===----------------------------------------------------------------------===//
13
#include "TestTypes.h"
#include "TestDialect.h"
#include "mlir/Dialect/LLVMIR/LLVMTypes.h"
#include "mlir/IR/Builders.h"
#include "mlir/IR/DialectImplementation.h"
#include "mlir/IR/ExtensibleDialect.h"
#include "mlir/IR/Types.h"
#include "llvm/ADT/Hashing.h"
#include "llvm/ADT/SetVector.h"
#include "llvm/ADT/StringExtras.h"
#include "llvm/ADT/TypeSwitch.h"
24
25 using namespace mlir;
26 using namespace test;
27
28 // Custom parser for SignednessSemantics.
29 static ParseResult
parseSignedness(AsmParser & parser,TestIntegerType::SignednessSemantics & result)30 parseSignedness(AsmParser &parser,
31 TestIntegerType::SignednessSemantics &result) {
32 StringRef signStr;
33 auto loc = parser.getCurrentLocation();
34 if (parser.parseKeyword(&signStr))
35 return failure();
36 if (signStr.equals_insensitive("u") || signStr.equals_insensitive("unsigned"))
37 result = TestIntegerType::SignednessSemantics::Unsigned;
38 else if (signStr.equals_insensitive("s") ||
39 signStr.equals_insensitive("signed"))
40 result = TestIntegerType::SignednessSemantics::Signed;
41 else if (signStr.equals_insensitive("n") ||
42 signStr.equals_insensitive("none"))
43 result = TestIntegerType::SignednessSemantics::Signless;
44 else
45 return parser.emitError(loc, "expected signed, unsigned, or none");
46 return success();
47 }
48
49 // Custom printer for SignednessSemantics.
printSignedness(AsmPrinter & printer,const TestIntegerType::SignednessSemantics & ss)50 static void printSignedness(AsmPrinter &printer,
51 const TestIntegerType::SignednessSemantics &ss) {
52 switch (ss) {
53 case TestIntegerType::SignednessSemantics::Unsigned:
54 printer << "unsigned";
55 break;
56 case TestIntegerType::SignednessSemantics::Signed:
57 printer << "signed";
58 break;
59 case TestIntegerType::SignednessSemantics::Signless:
60 printer << "none";
61 break;
62 }
63 }
64
// These functions don't need to be in the header file, but they do need to be
// in the `test` namespace. Declare them here, then define them immediately
// below. Separating the declaration and definition adheres to the LLVM coding
// standards.
namespace test {
// FieldInfo is used as part of a parameter, so equality comparison is
// compulsory.
static bool operator==(const FieldInfo &a, const FieldInfo &b);
// FieldInfo is used as part of a parameter, so a hash will be computed.
static llvm::hash_code hash_value(const FieldInfo &fi); // NOLINT
} // namespace test
75
76 // FieldInfo is used as part of a parameter, so equality comparison is
77 // compulsory.
operator ==(const FieldInfo & a,const FieldInfo & b)78 static bool test::operator==(const FieldInfo &a, const FieldInfo &b) {
79 return a.name == b.name && a.type == b.type;
80 }
81
82 // FieldInfo is used as part of a parameter, so a hash will be computed.
hash_value(const FieldInfo & fi)83 static llvm::hash_code test::hash_value(const FieldInfo &fi) { // NOLINT
84 return llvm::hash_combine(fi.name, fi.type);
85 }
86
87 //===----------------------------------------------------------------------===//
88 // TestCustomType
89 //===----------------------------------------------------------------------===//
90
parseCustomTypeA(AsmParser & parser,FailureOr<int> & aResult)91 static LogicalResult parseCustomTypeA(AsmParser &parser,
92 FailureOr<int> &aResult) {
93 aResult.emplace();
94 return parser.parseInteger(*aResult);
95 }
96
printCustomTypeA(AsmPrinter & printer,int a)97 static void printCustomTypeA(AsmPrinter &printer, int a) { printer << a; }
98
parseCustomTypeB(AsmParser & parser,int a,FailureOr<Optional<int>> & bResult)99 static LogicalResult parseCustomTypeB(AsmParser &parser, int a,
100 FailureOr<Optional<int>> &bResult) {
101 if (a < 0)
102 return success();
103 for (int i : llvm::seq(0, a))
104 if (failed(parser.parseInteger(i)))
105 return failure();
106 bResult.emplace(0);
107 return parser.parseInteger(**bResult);
108 }
109
printCustomTypeB(AsmPrinter & printer,int a,Optional<int> b)110 static void printCustomTypeB(AsmPrinter &printer, int a, Optional<int> b) {
111 if (a < 0)
112 return;
113 printer << ' ';
114 for (int i : llvm::seq(0, a))
115 printer << i << ' ';
116 printer << *b;
117 }
118
parseFooString(AsmParser & parser,FailureOr<std::string> & foo)119 static LogicalResult parseFooString(AsmParser &parser,
120 FailureOr<std::string> &foo) {
121 std::string result;
122 if (parser.parseString(&result))
123 return failure();
124 foo = std::move(result);
125 return success();
126 }
127
printFooString(AsmPrinter & printer,StringRef foo)128 static void printFooString(AsmPrinter &printer, StringRef foo) {
129 printer << '"' << foo << '"';
130 }
131
parseBarString(AsmParser & parser,StringRef foo)132 static LogicalResult parseBarString(AsmParser &parser, StringRef foo) {
133 return parser.parseKeyword(foo);
134 }
135
printBarString(AsmPrinter & printer,StringRef foo)136 static void printBarString(AsmPrinter &printer, StringRef foo) {
137 printer << ' ' << foo;
138 }
139 //===----------------------------------------------------------------------===//
140 // Tablegen Generated Definitions
141 //===----------------------------------------------------------------------===//
142
143 #define GET_TYPEDEF_CLASSES
144 #include "TestTypeDefs.cpp.inc"
145
146 //===----------------------------------------------------------------------===//
147 // CompoundAType
148 //===----------------------------------------------------------------------===//
149
parse(AsmParser & parser)150 Type CompoundAType::parse(AsmParser &parser) {
151 int widthOfSomething;
152 Type oneType;
153 SmallVector<int, 4> arrayOfInts;
154 if (parser.parseLess() || parser.parseInteger(widthOfSomething) ||
155 parser.parseComma() || parser.parseType(oneType) || parser.parseComma() ||
156 parser.parseLSquare())
157 return Type();
158
159 int i;
160 while (!*parser.parseOptionalInteger(i)) {
161 arrayOfInts.push_back(i);
162 if (parser.parseOptionalComma())
163 break;
164 }
165
166 if (parser.parseRSquare() || parser.parseGreater())
167 return Type();
168
169 return get(parser.getContext(), widthOfSomething, oneType, arrayOfInts);
170 }
print(AsmPrinter & printer) const171 void CompoundAType::print(AsmPrinter &printer) const {
172 printer << "<" << getWidthOfSomething() << ", " << getOneType() << ", [";
173 auto intArray = getArrayOfInts();
174 llvm::interleaveComma(intArray, printer);
175 printer << "]>";
176 }
177
178 //===----------------------------------------------------------------------===//
179 // TestIntegerType
180 //===----------------------------------------------------------------------===//
181
182 // Example type validity checker.
183 LogicalResult
verify(function_ref<InFlightDiagnostic ()> emitError,unsigned width,TestIntegerType::SignednessSemantics ss)184 TestIntegerType::verify(function_ref<InFlightDiagnostic()> emitError,
185 unsigned width,
186 TestIntegerType::SignednessSemantics ss) {
187 if (width > 8)
188 return failure();
189 return success();
190 }
191
parse(AsmParser & parser)192 Type TestIntegerType::parse(AsmParser &parser) {
193 SignednessSemantics signedness;
194 int width;
195 if (parser.parseLess() || parseSignedness(parser, signedness) ||
196 parser.parseComma() || parser.parseInteger(width) ||
197 parser.parseGreater())
198 return Type();
199 Location loc = parser.getEncodedSourceLoc(parser.getNameLoc());
200 return getChecked(loc, loc.getContext(), width, signedness);
201 }
202
print(AsmPrinter & p) const203 void TestIntegerType::print(AsmPrinter &p) const {
204 p << "<";
205 printSignedness(p, getSignedness());
206 p << ", " << getWidth() << ">";
207 }
208
209 //===----------------------------------------------------------------------===//
210 // TestStructType
211 //===----------------------------------------------------------------------===//
212
parse(AsmParser & p)213 Type StructType::parse(AsmParser &p) {
214 SmallVector<FieldInfo, 4> parameters;
215 if (p.parseLess())
216 return Type();
217 while (succeeded(p.parseOptionalLBrace())) {
218 Type type;
219 StringRef name;
220 if (p.parseKeyword(&name) || p.parseComma() || p.parseType(type) ||
221 p.parseRBrace())
222 return Type();
223 parameters.push_back(FieldInfo{name, type});
224 if (p.parseOptionalComma())
225 break;
226 }
227 if (p.parseGreater())
228 return Type();
229 return get(p.getContext(), parameters);
230 }
231
print(AsmPrinter & p) const232 void StructType::print(AsmPrinter &p) const {
233 p << "<";
234 llvm::interleaveComma(getFields(), p, [&](const FieldInfo &field) {
235 p << "{" << field.name << "," << field.type << "}";
236 });
237 p << ">";
238 }
239
240 //===----------------------------------------------------------------------===//
241 // TestType
242 //===----------------------------------------------------------------------===//
243
printTypeC(Location loc) const244 void TestType::printTypeC(Location loc) const {
245 emitRemark(loc) << *this << " - TestC";
246 }
247
248 //===----------------------------------------------------------------------===//
249 // TestTypeWithLayout
250 //===----------------------------------------------------------------------===//
251
parse(AsmParser & parser)252 Type TestTypeWithLayoutType::parse(AsmParser &parser) {
253 unsigned val;
254 if (parser.parseLess() || parser.parseInteger(val) || parser.parseGreater())
255 return Type();
256 return TestTypeWithLayoutType::get(parser.getContext(), val);
257 }
258
print(AsmPrinter & printer) const259 void TestTypeWithLayoutType::print(AsmPrinter &printer) const {
260 printer << "<" << getKey() << ">";
261 }
262
263 unsigned
getTypeSizeInBits(const DataLayout & dataLayout,DataLayoutEntryListRef params) const264 TestTypeWithLayoutType::getTypeSizeInBits(const DataLayout &dataLayout,
265 DataLayoutEntryListRef params) const {
266 return extractKind(params, "size");
267 }
268
269 unsigned
getABIAlignment(const DataLayout & dataLayout,DataLayoutEntryListRef params) const270 TestTypeWithLayoutType::getABIAlignment(const DataLayout &dataLayout,
271 DataLayoutEntryListRef params) const {
272 return extractKind(params, "alignment");
273 }
274
getPreferredAlignment(const DataLayout & dataLayout,DataLayoutEntryListRef params) const275 unsigned TestTypeWithLayoutType::getPreferredAlignment(
276 const DataLayout &dataLayout, DataLayoutEntryListRef params) const {
277 return extractKind(params, "preferred");
278 }
279
areCompatible(DataLayoutEntryListRef oldLayout,DataLayoutEntryListRef newLayout) const280 bool TestTypeWithLayoutType::areCompatible(
281 DataLayoutEntryListRef oldLayout, DataLayoutEntryListRef newLayout) const {
282 unsigned old = extractKind(oldLayout, "alignment");
283 return old == 1 || extractKind(newLayout, "alignment") <= old;
284 }
285
286 LogicalResult
verifyEntries(DataLayoutEntryListRef params,Location loc) const287 TestTypeWithLayoutType::verifyEntries(DataLayoutEntryListRef params,
288 Location loc) const {
289 for (DataLayoutEntryInterface entry : params) {
290 // This is for testing purposes only, so assert well-formedness.
291 assert(entry.isTypeEntry() && "unexpected identifier entry");
292 assert(entry.getKey().get<Type>().isa<TestTypeWithLayoutType>() &&
293 "wrong type passed in");
294 auto array = entry.getValue().dyn_cast<ArrayAttr>();
295 assert(array && array.getValue().size() == 2 &&
296 "expected array of two elements");
297 auto kind = array.getValue().front().dyn_cast<StringAttr>();
298 (void)kind;
299 assert(kind &&
300 (kind.getValue() == "size" || kind.getValue() == "alignment" ||
301 kind.getValue() == "preferred") &&
302 "unexpected kind");
303 assert(array.getValue().back().isa<IntegerAttr>());
304 }
305 return success();
306 }
307
extractKind(DataLayoutEntryListRef params,StringRef expectedKind) const308 unsigned TestTypeWithLayoutType::extractKind(DataLayoutEntryListRef params,
309 StringRef expectedKind) const {
310 for (DataLayoutEntryInterface entry : params) {
311 ArrayRef<Attribute> pair = entry.getValue().cast<ArrayAttr>().getValue();
312 StringRef kind = pair.front().cast<StringAttr>().getValue();
313 if (kind == expectedKind)
314 return pair.back().cast<IntegerAttr>().getValue().getZExtValue();
315 }
316 return 1;
317 }
318
319 //===----------------------------------------------------------------------===//
320 // Dynamic Types
321 //===----------------------------------------------------------------------===//
322
323 /// Define a singleton dynamic type.
324 static std::unique_ptr<DynamicTypeDefinition>
getSingletonDynamicType(TestDialect * testDialect)325 getSingletonDynamicType(TestDialect *testDialect) {
326 return DynamicTypeDefinition::get(
327 "dynamic_singleton", testDialect,
328 [](function_ref<InFlightDiagnostic()> emitError,
329 ArrayRef<Attribute> args) {
330 if (!args.empty()) {
331 emitError() << "expected 0 type arguments, but had " << args.size();
332 return failure();
333 }
334 return success();
335 });
336 }
337
338 /// Define a dynamic type representing a pair.
339 static std::unique_ptr<DynamicTypeDefinition>
getPairDynamicType(TestDialect * testDialect)340 getPairDynamicType(TestDialect *testDialect) {
341 return DynamicTypeDefinition::get(
342 "dynamic_pair", testDialect,
343 [](function_ref<InFlightDiagnostic()> emitError,
344 ArrayRef<Attribute> args) {
345 if (args.size() != 2) {
346 emitError() << "expected 2 type arguments, but had " << args.size();
347 return failure();
348 }
349 return success();
350 });
351 }
352
353 static std::unique_ptr<DynamicTypeDefinition>
getCustomAssemblyFormatDynamicType(TestDialect * testDialect)354 getCustomAssemblyFormatDynamicType(TestDialect *testDialect) {
355 auto verifier = [](function_ref<InFlightDiagnostic()> emitError,
356 ArrayRef<Attribute> args) {
357 if (args.size() != 2) {
358 emitError() << "expected 2 type arguments, but had " << args.size();
359 return failure();
360 }
361 return success();
362 };
363
364 auto parser = [](AsmParser &parser,
365 llvm::SmallVectorImpl<Attribute> &parsedParams) {
366 Attribute leftAttr, rightAttr;
367 if (parser.parseLess() || parser.parseAttribute(leftAttr) ||
368 parser.parseColon() || parser.parseAttribute(rightAttr) ||
369 parser.parseGreater())
370 return failure();
371 parsedParams.push_back(leftAttr);
372 parsedParams.push_back(rightAttr);
373 return success();
374 };
375
376 auto printer = [](AsmPrinter &printer, ArrayRef<Attribute> params) {
377 printer << "<" << params[0] << ":" << params[1] << ">";
378 };
379
380 return DynamicTypeDefinition::get("dynamic_custom_assembly_format",
381 testDialect, std::move(verifier),
382 std::move(parser), std::move(printer));
383 }
384
385 //===----------------------------------------------------------------------===//
386 // TestDialect
387 //===----------------------------------------------------------------------===//
388
389 namespace {
390
391 struct PtrElementModel
392 : public LLVM::PointerElementTypeInterface::ExternalModel<PtrElementModel,
393 SimpleAType> {};
394 } // namespace
395
/// Registers every type of the test dialect: the hand-written
/// TestRecursiveType, all TableGen-generated types, the external
/// pointer-element model for SimpleAType, and the dynamic types.
void TestDialect::registerTypes() {
  // The GET_TYPEDEF_LIST expansion of the generated .inc file is spliced
  // directly into the addTypes<> template argument list.
  addTypes<TestRecursiveType,
#define GET_TYPEDEF_LIST
#include "TestTypeDefs.cpp.inc"
           >();
  SimpleAType::attachInterface<PtrElementModel>(*getContext());

  // Types defined at runtime rather than through TableGen.
  registerDynamicType(getSingletonDynamicType(this));
  registerDynamicType(getPairDynamicType(this));
  registerDynamicType(getCustomAssemblyFormatDynamicType(this));
}
407
parseTestType(AsmParser & parser,SetVector<Type> & stack) const408 Type TestDialect::parseTestType(AsmParser &parser,
409 SetVector<Type> &stack) const {
410 StringRef typeTag;
411 {
412 Type genType;
413 auto parseResult = generatedTypeParser(parser, &typeTag, genType);
414 if (parseResult.hasValue())
415 return genType;
416 }
417
418 {
419 Type dynType;
420 auto parseResult = parseOptionalDynamicType(typeTag, parser, dynType);
421 if (parseResult.hasValue()) {
422 if (succeeded(parseResult.getValue()))
423 return dynType;
424 return Type();
425 }
426 }
427
428 if (typeTag != "test_rec") {
429 parser.emitError(parser.getNameLoc()) << "unknown type!";
430 return Type();
431 }
432
433 StringRef name;
434 if (parser.parseLess() || parser.parseKeyword(&name))
435 return Type();
436 auto rec = TestRecursiveType::get(parser.getContext(), name);
437
438 // If this type already has been parsed above in the stack, expect just the
439 // name.
440 if (stack.contains(rec)) {
441 if (failed(parser.parseGreater()))
442 return Type();
443 return rec;
444 }
445
446 // Otherwise, parse the body and update the type.
447 if (failed(parser.parseComma()))
448 return Type();
449 stack.insert(rec);
450 Type subtype = parseTestType(parser, stack);
451 stack.pop_back();
452 if (!subtype || failed(parser.parseGreater()) || failed(rec.setBody(subtype)))
453 return Type();
454
455 return rec;
456 }
457
parseType(DialectAsmParser & parser) const458 Type TestDialect::parseType(DialectAsmParser &parser) const {
459 SetVector<Type> stack;
460 return parseTestType(parser, stack);
461 }
462
printTestType(Type type,AsmPrinter & printer,SetVector<Type> & stack) const463 void TestDialect::printTestType(Type type, AsmPrinter &printer,
464 SetVector<Type> &stack) const {
465 if (succeeded(generatedTypePrinter(type, printer)))
466 return;
467
468 if (succeeded(printIfDynamicType(type, printer)))
469 return;
470
471 auto rec = type.cast<TestRecursiveType>();
472 printer << "test_rec<" << rec.getName();
473 if (!stack.contains(rec)) {
474 printer << ", ";
475 stack.insert(rec);
476 printTestType(rec.getBody(), printer, stack);
477 stack.pop_back();
478 }
479 printer << ">";
480 }
481
printType(Type type,DialectAsmPrinter & printer) const482 void TestDialect::printType(Type type, DialectAsmPrinter &printer) const {
483 SetVector<Type> stack;
484 printTestType(type, printer, stack);
485 }
486