1 //===- TestTypes.cpp - MLIR Test Dialect Types ------------------*- C++ -*-===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 //
9 // This file contains types defined by the TestDialect for testing various
10 // features of MLIR.
11 //
12 //===----------------------------------------------------------------------===//
13 
#include "TestTypes.h"
#include "TestDialect.h"
#include "mlir/Dialect/LLVMIR/LLVMTypes.h"
#include "mlir/IR/Builders.h"
#include "mlir/IR/DialectImplementation.h"
#include "mlir/IR/ExtensibleDialect.h"
#include "mlir/IR/Types.h"
#include "llvm/ADT/Hashing.h"
#include "llvm/ADT/SetVector.h"
#include "llvm/ADT/StringExtras.h"
#include "llvm/ADT/TypeSwitch.h"
24 
25 using namespace mlir;
26 using namespace test;
27 
28 // Custom parser for SignednessSemantics.
29 static ParseResult
parseSignedness(AsmParser & parser,TestIntegerType::SignednessSemantics & result)30 parseSignedness(AsmParser &parser,
31                 TestIntegerType::SignednessSemantics &result) {
32   StringRef signStr;
33   auto loc = parser.getCurrentLocation();
34   if (parser.parseKeyword(&signStr))
35     return failure();
36   if (signStr.equals_insensitive("u") || signStr.equals_insensitive("unsigned"))
37     result = TestIntegerType::SignednessSemantics::Unsigned;
38   else if (signStr.equals_insensitive("s") ||
39            signStr.equals_insensitive("signed"))
40     result = TestIntegerType::SignednessSemantics::Signed;
41   else if (signStr.equals_insensitive("n") ||
42            signStr.equals_insensitive("none"))
43     result = TestIntegerType::SignednessSemantics::Signless;
44   else
45     return parser.emitError(loc, "expected signed, unsigned, or none");
46   return success();
47 }
48 
49 // Custom printer for SignednessSemantics.
printSignedness(AsmPrinter & printer,const TestIntegerType::SignednessSemantics & ss)50 static void printSignedness(AsmPrinter &printer,
51                             const TestIntegerType::SignednessSemantics &ss) {
52   switch (ss) {
53   case TestIntegerType::SignednessSemantics::Unsigned:
54     printer << "unsigned";
55     break;
56   case TestIntegerType::SignednessSemantics::Signed:
57     printer << "signed";
58     break;
59   case TestIntegerType::SignednessSemantics::Signless:
60     printer << "none";
61     break;
62   }
63 }
64 
// The functions don't need to be in the header file, but need to be in the
// test namespace. Declare them here, then define them immediately below.
// Separating the declaration and definition adheres to the LLVM coding
// standards.
68 namespace test {
69 // FieldInfo is used as part of a parameter, so equality comparison is
70 // compulsory.
71 static bool operator==(const FieldInfo &a, const FieldInfo &b);
72 // FieldInfo is used as part of a parameter, so a hash will be computed.
73 static llvm::hash_code hash_value(const FieldInfo &fi); // NOLINT
74 } // namespace test
75 
76 // FieldInfo is used as part of a parameter, so equality comparison is
77 // compulsory.
operator ==(const FieldInfo & a,const FieldInfo & b)78 static bool test::operator==(const FieldInfo &a, const FieldInfo &b) {
79   return a.name == b.name && a.type == b.type;
80 }
81 
82 // FieldInfo is used as part of a parameter, so a hash will be computed.
hash_value(const FieldInfo & fi)83 static llvm::hash_code test::hash_value(const FieldInfo &fi) { // NOLINT
84   return llvm::hash_combine(fi.name, fi.type);
85 }
86 
87 //===----------------------------------------------------------------------===//
88 // TestCustomType
89 //===----------------------------------------------------------------------===//
90 
parseCustomTypeA(AsmParser & parser,FailureOr<int> & aResult)91 static LogicalResult parseCustomTypeA(AsmParser &parser,
92                                       FailureOr<int> &aResult) {
93   aResult.emplace();
94   return parser.parseInteger(*aResult);
95 }
96 
printCustomTypeA(AsmPrinter & printer,int a)97 static void printCustomTypeA(AsmPrinter &printer, int a) { printer << a; }
98 
parseCustomTypeB(AsmParser & parser,int a,FailureOr<Optional<int>> & bResult)99 static LogicalResult parseCustomTypeB(AsmParser &parser, int a,
100                                       FailureOr<Optional<int>> &bResult) {
101   if (a < 0)
102     return success();
103   for (int i : llvm::seq(0, a))
104     if (failed(parser.parseInteger(i)))
105       return failure();
106   bResult.emplace(0);
107   return parser.parseInteger(**bResult);
108 }
109 
printCustomTypeB(AsmPrinter & printer,int a,Optional<int> b)110 static void printCustomTypeB(AsmPrinter &printer, int a, Optional<int> b) {
111   if (a < 0)
112     return;
113   printer << ' ';
114   for (int i : llvm::seq(0, a))
115     printer << i << ' ';
116   printer << *b;
117 }
118 
parseFooString(AsmParser & parser,FailureOr<std::string> & foo)119 static LogicalResult parseFooString(AsmParser &parser,
120                                     FailureOr<std::string> &foo) {
121   std::string result;
122   if (parser.parseString(&result))
123     return failure();
124   foo = std::move(result);
125   return success();
126 }
127 
printFooString(AsmPrinter & printer,StringRef foo)128 static void printFooString(AsmPrinter &printer, StringRef foo) {
129   printer << '"' << foo << '"';
130 }
131 
parseBarString(AsmParser & parser,StringRef foo)132 static LogicalResult parseBarString(AsmParser &parser, StringRef foo) {
133   return parser.parseKeyword(foo);
134 }
135 
printBarString(AsmPrinter & printer,StringRef foo)136 static void printBarString(AsmPrinter &printer, StringRef foo) {
137   printer << ' ' << foo;
138 }
139 //===----------------------------------------------------------------------===//
140 // Tablegen Generated Definitions
141 //===----------------------------------------------------------------------===//
142 
143 #define GET_TYPEDEF_CLASSES
144 #include "TestTypeDefs.cpp.inc"
145 
146 //===----------------------------------------------------------------------===//
147 // CompoundAType
148 //===----------------------------------------------------------------------===//
149 
parse(AsmParser & parser)150 Type CompoundAType::parse(AsmParser &parser) {
151   int widthOfSomething;
152   Type oneType;
153   SmallVector<int, 4> arrayOfInts;
154   if (parser.parseLess() || parser.parseInteger(widthOfSomething) ||
155       parser.parseComma() || parser.parseType(oneType) || parser.parseComma() ||
156       parser.parseLSquare())
157     return Type();
158 
159   int i;
160   while (!*parser.parseOptionalInteger(i)) {
161     arrayOfInts.push_back(i);
162     if (parser.parseOptionalComma())
163       break;
164   }
165 
166   if (parser.parseRSquare() || parser.parseGreater())
167     return Type();
168 
169   return get(parser.getContext(), widthOfSomething, oneType, arrayOfInts);
170 }
print(AsmPrinter & printer) const171 void CompoundAType::print(AsmPrinter &printer) const {
172   printer << "<" << getWidthOfSomething() << ", " << getOneType() << ", [";
173   auto intArray = getArrayOfInts();
174   llvm::interleaveComma(intArray, printer);
175   printer << "]>";
176 }
177 
178 //===----------------------------------------------------------------------===//
179 // TestIntegerType
180 //===----------------------------------------------------------------------===//
181 
182 // Example type validity checker.
183 LogicalResult
verify(function_ref<InFlightDiagnostic ()> emitError,unsigned width,TestIntegerType::SignednessSemantics ss)184 TestIntegerType::verify(function_ref<InFlightDiagnostic()> emitError,
185                         unsigned width,
186                         TestIntegerType::SignednessSemantics ss) {
187   if (width > 8)
188     return failure();
189   return success();
190 }
191 
parse(AsmParser & parser)192 Type TestIntegerType::parse(AsmParser &parser) {
193   SignednessSemantics signedness;
194   int width;
195   if (parser.parseLess() || parseSignedness(parser, signedness) ||
196       parser.parseComma() || parser.parseInteger(width) ||
197       parser.parseGreater())
198     return Type();
199   Location loc = parser.getEncodedSourceLoc(parser.getNameLoc());
200   return getChecked(loc, loc.getContext(), width, signedness);
201 }
202 
print(AsmPrinter & p) const203 void TestIntegerType::print(AsmPrinter &p) const {
204   p << "<";
205   printSignedness(p, getSignedness());
206   p << ", " << getWidth() << ">";
207 }
208 
209 //===----------------------------------------------------------------------===//
210 // TestStructType
211 //===----------------------------------------------------------------------===//
212 
parse(AsmParser & p)213 Type StructType::parse(AsmParser &p) {
214   SmallVector<FieldInfo, 4> parameters;
215   if (p.parseLess())
216     return Type();
217   while (succeeded(p.parseOptionalLBrace())) {
218     Type type;
219     StringRef name;
220     if (p.parseKeyword(&name) || p.parseComma() || p.parseType(type) ||
221         p.parseRBrace())
222       return Type();
223     parameters.push_back(FieldInfo{name, type});
224     if (p.parseOptionalComma())
225       break;
226   }
227   if (p.parseGreater())
228     return Type();
229   return get(p.getContext(), parameters);
230 }
231 
print(AsmPrinter & p) const232 void StructType::print(AsmPrinter &p) const {
233   p << "<";
234   llvm::interleaveComma(getFields(), p, [&](const FieldInfo &field) {
235     p << "{" << field.name << "," << field.type << "}";
236   });
237   p << ">";
238 }
239 
240 //===----------------------------------------------------------------------===//
241 // TestType
242 //===----------------------------------------------------------------------===//
243 
printTypeC(Location loc) const244 void TestType::printTypeC(Location loc) const {
245   emitRemark(loc) << *this << " - TestC";
246 }
247 
248 //===----------------------------------------------------------------------===//
249 // TestTypeWithLayout
250 //===----------------------------------------------------------------------===//
251 
parse(AsmParser & parser)252 Type TestTypeWithLayoutType::parse(AsmParser &parser) {
253   unsigned val;
254   if (parser.parseLess() || parser.parseInteger(val) || parser.parseGreater())
255     return Type();
256   return TestTypeWithLayoutType::get(parser.getContext(), val);
257 }
258 
print(AsmPrinter & printer) const259 void TestTypeWithLayoutType::print(AsmPrinter &printer) const {
260   printer << "<" << getKey() << ">";
261 }
262 
263 unsigned
getTypeSizeInBits(const DataLayout & dataLayout,DataLayoutEntryListRef params) const264 TestTypeWithLayoutType::getTypeSizeInBits(const DataLayout &dataLayout,
265                                           DataLayoutEntryListRef params) const {
266   return extractKind(params, "size");
267 }
268 
269 unsigned
getABIAlignment(const DataLayout & dataLayout,DataLayoutEntryListRef params) const270 TestTypeWithLayoutType::getABIAlignment(const DataLayout &dataLayout,
271                                         DataLayoutEntryListRef params) const {
272   return extractKind(params, "alignment");
273 }
274 
getPreferredAlignment(const DataLayout & dataLayout,DataLayoutEntryListRef params) const275 unsigned TestTypeWithLayoutType::getPreferredAlignment(
276     const DataLayout &dataLayout, DataLayoutEntryListRef params) const {
277   return extractKind(params, "preferred");
278 }
279 
areCompatible(DataLayoutEntryListRef oldLayout,DataLayoutEntryListRef newLayout) const280 bool TestTypeWithLayoutType::areCompatible(
281     DataLayoutEntryListRef oldLayout, DataLayoutEntryListRef newLayout) const {
282   unsigned old = extractKind(oldLayout, "alignment");
283   return old == 1 || extractKind(newLayout, "alignment") <= old;
284 }
285 
286 LogicalResult
verifyEntries(DataLayoutEntryListRef params,Location loc) const287 TestTypeWithLayoutType::verifyEntries(DataLayoutEntryListRef params,
288                                       Location loc) const {
289   for (DataLayoutEntryInterface entry : params) {
290     // This is for testing purposes only, so assert well-formedness.
291     assert(entry.isTypeEntry() && "unexpected identifier entry");
292     assert(entry.getKey().get<Type>().isa<TestTypeWithLayoutType>() &&
293            "wrong type passed in");
294     auto array = entry.getValue().dyn_cast<ArrayAttr>();
295     assert(array && array.getValue().size() == 2 &&
296            "expected array of two elements");
297     auto kind = array.getValue().front().dyn_cast<StringAttr>();
298     (void)kind;
299     assert(kind &&
300            (kind.getValue() == "size" || kind.getValue() == "alignment" ||
301             kind.getValue() == "preferred") &&
302            "unexpected kind");
303     assert(array.getValue().back().isa<IntegerAttr>());
304   }
305   return success();
306 }
307 
extractKind(DataLayoutEntryListRef params,StringRef expectedKind) const308 unsigned TestTypeWithLayoutType::extractKind(DataLayoutEntryListRef params,
309                                              StringRef expectedKind) const {
310   for (DataLayoutEntryInterface entry : params) {
311     ArrayRef<Attribute> pair = entry.getValue().cast<ArrayAttr>().getValue();
312     StringRef kind = pair.front().cast<StringAttr>().getValue();
313     if (kind == expectedKind)
314       return pair.back().cast<IntegerAttr>().getValue().getZExtValue();
315   }
316   return 1;
317 }
318 
319 //===----------------------------------------------------------------------===//
320 // Dynamic Types
321 //===----------------------------------------------------------------------===//
322 
323 /// Define a singleton dynamic type.
324 static std::unique_ptr<DynamicTypeDefinition>
getSingletonDynamicType(TestDialect * testDialect)325 getSingletonDynamicType(TestDialect *testDialect) {
326   return DynamicTypeDefinition::get(
327       "dynamic_singleton", testDialect,
328       [](function_ref<InFlightDiagnostic()> emitError,
329          ArrayRef<Attribute> args) {
330         if (!args.empty()) {
331           emitError() << "expected 0 type arguments, but had " << args.size();
332           return failure();
333         }
334         return success();
335       });
336 }
337 
338 /// Define a dynamic type representing a pair.
339 static std::unique_ptr<DynamicTypeDefinition>
getPairDynamicType(TestDialect * testDialect)340 getPairDynamicType(TestDialect *testDialect) {
341   return DynamicTypeDefinition::get(
342       "dynamic_pair", testDialect,
343       [](function_ref<InFlightDiagnostic()> emitError,
344          ArrayRef<Attribute> args) {
345         if (args.size() != 2) {
346           emitError() << "expected 2 type arguments, but had " << args.size();
347           return failure();
348         }
349         return success();
350       });
351 }
352 
353 static std::unique_ptr<DynamicTypeDefinition>
getCustomAssemblyFormatDynamicType(TestDialect * testDialect)354 getCustomAssemblyFormatDynamicType(TestDialect *testDialect) {
355   auto verifier = [](function_ref<InFlightDiagnostic()> emitError,
356                      ArrayRef<Attribute> args) {
357     if (args.size() != 2) {
358       emitError() << "expected 2 type arguments, but had " << args.size();
359       return failure();
360     }
361     return success();
362   };
363 
364   auto parser = [](AsmParser &parser,
365                    llvm::SmallVectorImpl<Attribute> &parsedParams) {
366     Attribute leftAttr, rightAttr;
367     if (parser.parseLess() || parser.parseAttribute(leftAttr) ||
368         parser.parseColon() || parser.parseAttribute(rightAttr) ||
369         parser.parseGreater())
370       return failure();
371     parsedParams.push_back(leftAttr);
372     parsedParams.push_back(rightAttr);
373     return success();
374   };
375 
376   auto printer = [](AsmPrinter &printer, ArrayRef<Attribute> params) {
377     printer << "<" << params[0] << ":" << params[1] << ">";
378   };
379 
380   return DynamicTypeDefinition::get("dynamic_custom_assembly_format",
381                                     testDialect, std::move(verifier),
382                                     std::move(parser), std::move(printer));
383 }
384 
385 //===----------------------------------------------------------------------===//
386 // TestDialect
387 //===----------------------------------------------------------------------===//
388 
389 namespace {
390 
391 struct PtrElementModel
392     : public LLVM::PointerElementTypeInterface::ExternalModel<PtrElementModel,
393                                                               SimpleAType> {};
394 } // namespace
395 
registerTypes()396 void TestDialect::registerTypes() {
397   addTypes<TestRecursiveType,
398 #define GET_TYPEDEF_LIST
399 #include "TestTypeDefs.cpp.inc"
400            >();
401   SimpleAType::attachInterface<PtrElementModel>(*getContext());
402 
403   registerDynamicType(getSingletonDynamicType(this));
404   registerDynamicType(getPairDynamicType(this));
405   registerDynamicType(getCustomAssemblyFormatDynamicType(this));
406 }
407 
parseTestType(AsmParser & parser,SetVector<Type> & stack) const408 Type TestDialect::parseTestType(AsmParser &parser,
409                                 SetVector<Type> &stack) const {
410   StringRef typeTag;
411   {
412     Type genType;
413     auto parseResult = generatedTypeParser(parser, &typeTag, genType);
414     if (parseResult.hasValue())
415       return genType;
416   }
417 
418   {
419     Type dynType;
420     auto parseResult = parseOptionalDynamicType(typeTag, parser, dynType);
421     if (parseResult.hasValue()) {
422       if (succeeded(parseResult.getValue()))
423         return dynType;
424       return Type();
425     }
426   }
427 
428   if (typeTag != "test_rec") {
429     parser.emitError(parser.getNameLoc()) << "unknown type!";
430     return Type();
431   }
432 
433   StringRef name;
434   if (parser.parseLess() || parser.parseKeyword(&name))
435     return Type();
436   auto rec = TestRecursiveType::get(parser.getContext(), name);
437 
438   // If this type already has been parsed above in the stack, expect just the
439   // name.
440   if (stack.contains(rec)) {
441     if (failed(parser.parseGreater()))
442       return Type();
443     return rec;
444   }
445 
446   // Otherwise, parse the body and update the type.
447   if (failed(parser.parseComma()))
448     return Type();
449   stack.insert(rec);
450   Type subtype = parseTestType(parser, stack);
451   stack.pop_back();
452   if (!subtype || failed(parser.parseGreater()) || failed(rec.setBody(subtype)))
453     return Type();
454 
455   return rec;
456 }
457 
parseType(DialectAsmParser & parser) const458 Type TestDialect::parseType(DialectAsmParser &parser) const {
459   SetVector<Type> stack;
460   return parseTestType(parser, stack);
461 }
462 
printTestType(Type type,AsmPrinter & printer,SetVector<Type> & stack) const463 void TestDialect::printTestType(Type type, AsmPrinter &printer,
464                                 SetVector<Type> &stack) const {
465   if (succeeded(generatedTypePrinter(type, printer)))
466     return;
467 
468   if (succeeded(printIfDynamicType(type, printer)))
469     return;
470 
471   auto rec = type.cast<TestRecursiveType>();
472   printer << "test_rec<" << rec.getName();
473   if (!stack.contains(rec)) {
474     printer << ", ";
475     stack.insert(rec);
476     printTestType(rec.getBody(), printer, stack);
477     stack.pop_back();
478   }
479   printer << ">";
480 }
481 
printType(Type type,DialectAsmPrinter & printer) const482 void TestDialect::printType(Type type, DialectAsmPrinter &printer) const {
483   SetVector<Type> stack;
484   printTestType(type, printer, stack);
485 }
486