1 //===- TestAttributes.cpp - MLIR Test Dialect Attributes --------*- C++ -*-===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 //
9 // This file contains attributes defined by the TestDialect for testing various
10 // features of MLIR.
11 //
12 //===----------------------------------------------------------------------===//
13 
14 #include "TestAttributes.h"
15 #include "TestDialect.h"
16 #include "mlir/IR/Builders.h"
17 #include "mlir/IR/DialectImplementation.h"
18 #include "mlir/IR/ExtensibleDialect.h"
19 #include "mlir/IR/Types.h"
20 #include "mlir/Support/LogicalResult.h"
21 #include "llvm/ADT/Hashing.h"
22 #include "llvm/ADT/SetVector.h"
23 #include "llvm/ADT/TypeSwitch.h"
24 #include "llvm/ADT/bit.h"
25 #include "llvm/Support/ErrorHandling.h"
26 
27 using namespace mlir;
28 using namespace test;
29 
30 //===----------------------------------------------------------------------===//
31 // AttrWithSelfTypeParamAttr
32 //===----------------------------------------------------------------------===//
33 
34 Attribute AttrWithSelfTypeParamAttr::parse(AsmParser &parser, Type type) {
35   Type selfType;
36   if (parser.parseType(selfType))
37     return Attribute();
38   return get(parser.getContext(), selfType);
39 }
40 
41 void AttrWithSelfTypeParamAttr::print(AsmPrinter &printer) const {
42   printer << " " << getType();
43 }
44 
45 //===----------------------------------------------------------------------===//
46 // AttrWithTypeBuilderAttr
47 //===----------------------------------------------------------------------===//
48 
49 Attribute AttrWithTypeBuilderAttr::parse(AsmParser &parser, Type type) {
50   IntegerAttr element;
51   if (parser.parseAttribute(element))
52     return Attribute();
53   return get(parser.getContext(), element);
54 }
55 
56 void AttrWithTypeBuilderAttr::print(AsmPrinter &printer) const {
57   printer << " " << getAttr();
58 }
59 
60 //===----------------------------------------------------------------------===//
61 // CompoundAAttr
62 //===----------------------------------------------------------------------===//
63 
64 Attribute CompoundAAttr::parse(AsmParser &parser, Type type) {
65   int widthOfSomething;
66   Type oneType;
67   SmallVector<int, 4> arrayOfInts;
68   if (parser.parseLess() || parser.parseInteger(widthOfSomething) ||
69       parser.parseComma() || parser.parseType(oneType) || parser.parseComma() ||
70       parser.parseLSquare())
71     return Attribute();
72 
73   int intVal;
74   while (!*parser.parseOptionalInteger(intVal)) {
75     arrayOfInts.push_back(intVal);
76     if (parser.parseOptionalComma())
77       break;
78   }
79 
80   if (parser.parseRSquare() || parser.parseGreater())
81     return Attribute();
82   return get(parser.getContext(), widthOfSomething, oneType, arrayOfInts);
83 }
84 
85 void CompoundAAttr::print(AsmPrinter &printer) const {
86   printer << "<" << getWidthOfSomething() << ", " << getOneType() << ", [";
87   llvm::interleaveComma(getArrayOfInts(), printer);
88   printer << "]>";
89 }
90 
91 //===----------------------------------------------------------------------===//
// TestI64ElementsAttr
93 //===----------------------------------------------------------------------===//
94 
95 Attribute TestI64ElementsAttr::parse(AsmParser &parser, Type type) {
96   SmallVector<uint64_t> elements;
97   if (parser.parseLess() || parser.parseLSquare())
98     return Attribute();
99   uint64_t intVal;
100   while (succeeded(*parser.parseOptionalInteger(intVal))) {
101     elements.push_back(intVal);
102     if (parser.parseOptionalComma())
103       break;
104   }
105 
106   if (parser.parseRSquare() || parser.parseGreater())
107     return Attribute();
108   return parser.getChecked<TestI64ElementsAttr>(
109       parser.getContext(), type.cast<ShapedType>(), elements);
110 }
111 
112 void TestI64ElementsAttr::print(AsmPrinter &printer) const {
113   printer << "<[";
114   llvm::interleaveComma(getElements(), printer);
115   printer << "] : " << getType() << ">";
116 }
117 
118 LogicalResult
119 TestI64ElementsAttr::verify(function_ref<InFlightDiagnostic()> emitError,
120                             ShapedType type, ArrayRef<uint64_t> elements) {
121   if (type.getNumElements() != static_cast<int64_t>(elements.size())) {
122     return emitError()
123            << "number of elements does not match the provided shape type, got: "
124            << elements.size() << ", but expected: " << type.getNumElements();
125   }
126   if (type.getRank() != 1 || !type.getElementType().isSignlessInteger(64))
127     return emitError() << "expected single rank 64-bit shape type, but got: "
128                        << type;
129   return success();
130 }
131 
132 LogicalResult
133 TestAttrWithFormatAttr::verify(function_ref<InFlightDiagnostic()> emitError,
134                                int64_t one, std::string two, IntegerAttr three,
135                                ArrayRef<int> four) {
136   if (four.size() != static_cast<unsigned>(one))
137     return emitError() << "expected 'one' to equal 'four.size()'";
138   return success();
139 }
140 
141 //===----------------------------------------------------------------------===//
142 // Utility Functions for Generated Attributes
143 //===----------------------------------------------------------------------===//
144 
145 static FailureOr<SmallVector<int>> parseIntArray(AsmParser &parser) {
146   SmallVector<int> ints;
147   if (parser.parseLSquare() || parser.parseCommaSeparatedList([&]() {
148         ints.push_back(0);
149         return parser.parseInteger(ints.back());
150       }) ||
151       parser.parseRSquare())
152     return failure();
153   return ints;
154 }
155 
156 static void printIntArray(AsmPrinter &printer, ArrayRef<int> ints) {
157   printer << '[';
158   llvm::interleaveComma(ints, printer);
159   printer << ']';
160 }
161 
162 //===----------------------------------------------------------------------===//
163 // TestSubElementsAccessAttr
164 //===----------------------------------------------------------------------===//
165 
166 Attribute TestSubElementsAccessAttr::parse(::mlir::AsmParser &parser,
167                                            ::mlir::Type type) {
168   Attribute first, second, third;
169   if (parser.parseLess() || parser.parseAttribute(first) ||
170       parser.parseComma() || parser.parseAttribute(second) ||
171       parser.parseComma() || parser.parseAttribute(third) ||
172       parser.parseGreater()) {
173     return {};
174   }
175   return get(parser.getContext(), first, second, third);
176 }
177 
178 void TestSubElementsAccessAttr::print(::mlir::AsmPrinter &printer) const {
179   printer << "<" << getFirst() << ", " << getSecond() << ", " << getThird()
180           << ">";
181 }
182 
183 void TestSubElementsAccessAttr::walkImmediateSubElements(
184     llvm::function_ref<void(mlir::Attribute)> walkAttrsFn,
185     llvm::function_ref<void(mlir::Type)> walkTypesFn) const {
186   walkAttrsFn(getFirst());
187   walkAttrsFn(getSecond());
188   walkAttrsFn(getThird());
189 }
190 
191 SubElementAttrInterface TestSubElementsAccessAttr::replaceImmediateSubAttribute(
192     ArrayRef<std::pair<size_t, Attribute>> replacements) const {
193   Attribute first = getFirst();
194   Attribute second = getSecond();
195   Attribute third = getThird();
196   for (auto &it : replacements) {
197     switch (it.first) {
198     case 0:
199       first = it.second;
200       break;
201     case 1:
202       second = it.second;
203       break;
204     case 2:
205       third = it.second;
206       break;
207     }
208   }
209   return get(getContext(), first, second, third);
210 }
211 
212 //===----------------------------------------------------------------------===//
213 // Tablegen Generated Definitions
214 //===----------------------------------------------------------------------===//
215 
216 #include "TestAttrInterfaces.cpp.inc"
217 
218 #define GET_ATTRDEF_CLASSES
219 #include "TestAttrDefs.cpp.inc"
220 
221 //===----------------------------------------------------------------------===//
222 // Dynamic Attributes
223 //===----------------------------------------------------------------------===//
224 
225 /// Define a singleton dynamic attribute.
226 static std::unique_ptr<DynamicAttrDefinition>
227 getDynamicSingletonAttr(TestDialect *testDialect) {
228   return DynamicAttrDefinition::get(
229       "dynamic_singleton", testDialect,
230       [](function_ref<InFlightDiagnostic()> emitError,
231          ArrayRef<Attribute> args) {
232         if (!args.empty()) {
233           emitError() << "expected 0 attribute arguments, but had "
234                       << args.size();
235           return failure();
236         }
237         return success();
238       });
239 }
240 
241 /// Define a dynamic attribute representing a pair or attributes.
242 static std::unique_ptr<DynamicAttrDefinition>
243 getDynamicPairAttr(TestDialect *testDialect) {
244   return DynamicAttrDefinition::get(
245       "dynamic_pair", testDialect,
246       [](function_ref<InFlightDiagnostic()> emitError,
247          ArrayRef<Attribute> args) {
248         if (args.size() != 2) {
249           emitError() << "expected 2 attribute arguments, but had "
250                       << args.size();
251           return failure();
252         }
253         return success();
254       });
255 }
256 
257 static std::unique_ptr<DynamicAttrDefinition>
258 getDynamicCustomAssemblyFormatAttr(TestDialect *testDialect) {
259   auto verifier = [](function_ref<InFlightDiagnostic()> emitError,
260                      ArrayRef<Attribute> args) {
261     if (args.size() != 2) {
262       emitError() << "expected 2 attribute arguments, but had " << args.size();
263       return failure();
264     }
265     return success();
266   };
267 
268   auto parser = [](AsmParser &parser,
269                    llvm::SmallVectorImpl<Attribute> &parsedParams) {
270     Attribute leftAttr, rightAttr;
271     if (parser.parseLess() || parser.parseAttribute(leftAttr) ||
272         parser.parseColon() || parser.parseAttribute(rightAttr) ||
273         parser.parseGreater())
274       return failure();
275     parsedParams.push_back(leftAttr);
276     parsedParams.push_back(rightAttr);
277     return success();
278   };
279 
280   auto printer = [](AsmPrinter &printer, ArrayRef<Attribute> params) {
281     printer << "<" << params[0] << ":" << params[1] << ">";
282   };
283 
284   return DynamicAttrDefinition::get("dynamic_custom_assembly_format",
285                                     testDialect, std::move(verifier),
286                                     std::move(parser), std::move(printer));
287 }
288 
289 //===----------------------------------------------------------------------===//
290 // TestDialect
291 //===----------------------------------------------------------------------===//
292 
/// Registers the test dialect's attributes. The attribute classes passed to
/// addAttributes come from TestAttrDefs.cpp.inc expanded with
/// GET_ATTRDEF_LIST, and the three dynamic attributes come from the
/// getDynamic*Attr helpers in this file.
void TestDialect::registerAttributes() {
  addAttributes<
#define GET_ATTRDEF_LIST
#include "TestAttrDefs.cpp.inc"
      >();
  // Dynamic attributes are constructed here at registration time from
  // DynamicAttrDefinition objects rather than from the TableGen list above.
  registerDynamicAttr(getDynamicSingletonAttr(this));
  registerDynamicAttr(getDynamicPairAttr(this));
  registerDynamicAttr(getDynamicCustomAssemblyFormatAttr(this));
}
302