1 //===- TestAttributes.cpp - MLIR Test Dialect Attributes --------*- C++ -*-===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 //
9 // This file contains attributes defined by the TestDialect for testing various
10 // features of MLIR.
11 //
12 //===----------------------------------------------------------------------===//
13 
14 #include "TestAttributes.h"
15 #include "TestDialect.h"
16 #include "mlir/IR/Builders.h"
17 #include "mlir/IR/DialectImplementation.h"
18 #include "mlir/IR/ExtensibleDialect.h"
19 #include "mlir/IR/Types.h"
20 #include "mlir/Support/LogicalResult.h"
21 #include "llvm/ADT/Hashing.h"
22 #include "llvm/ADT/SetVector.h"
23 #include "llvm/ADT/TypeSwitch.h"
24 #include "llvm/ADT/bit.h"
25 #include "llvm/Support/ErrorHandling.h"
26 
27 using namespace mlir;
28 using namespace test;
29 
30 //===----------------------------------------------------------------------===//
31 // AttrWithTypeBuilderAttr
32 //===----------------------------------------------------------------------===//
33 
34 Attribute AttrWithTypeBuilderAttr::parse(AsmParser &parser, Type type) {
35   IntegerAttr element;
36   if (parser.parseAttribute(element))
37     return Attribute();
38   return get(parser.getContext(), element);
39 }
40 
41 void AttrWithTypeBuilderAttr::print(AsmPrinter &printer) const {
42   printer << " " << getAttr();
43 }
44 
45 //===----------------------------------------------------------------------===//
46 // CompoundAAttr
47 //===----------------------------------------------------------------------===//
48 
49 Attribute CompoundAAttr::parse(AsmParser &parser, Type type) {
50   int widthOfSomething;
51   Type oneType;
52   SmallVector<int, 4> arrayOfInts;
53   if (parser.parseLess() || parser.parseInteger(widthOfSomething) ||
54       parser.parseComma() || parser.parseType(oneType) || parser.parseComma() ||
55       parser.parseLSquare())
56     return Attribute();
57 
58   int intVal;
59   while (!*parser.parseOptionalInteger(intVal)) {
60     arrayOfInts.push_back(intVal);
61     if (parser.parseOptionalComma())
62       break;
63   }
64 
65   if (parser.parseRSquare() || parser.parseGreater())
66     return Attribute();
67   return get(parser.getContext(), widthOfSomething, oneType, arrayOfInts);
68 }
69 
70 void CompoundAAttr::print(AsmPrinter &printer) const {
71   printer << "<" << getWidthOfSomething() << ", " << getOneType() << ", [";
72   llvm::interleaveComma(getArrayOfInts(), printer);
73   printer << "]>";
74 }
75 
76 //===----------------------------------------------------------------------===//
// TestI64ElementsAttr and TestAttrWithFormatAttr
78 //===----------------------------------------------------------------------===//
79 
80 Attribute TestI64ElementsAttr::parse(AsmParser &parser, Type type) {
81   SmallVector<uint64_t> elements;
82   if (parser.parseLess() || parser.parseLSquare())
83     return Attribute();
84   uint64_t intVal;
85   while (succeeded(*parser.parseOptionalInteger(intVal))) {
86     elements.push_back(intVal);
87     if (parser.parseOptionalComma())
88       break;
89   }
90 
91   if (parser.parseRSquare() || parser.parseGreater())
92     return Attribute();
93   return parser.getChecked<TestI64ElementsAttr>(
94       parser.getContext(), type.cast<ShapedType>(), elements);
95 }
96 
97 void TestI64ElementsAttr::print(AsmPrinter &printer) const {
98   printer << "<[";
99   llvm::interleaveComma(getElements(), printer);
100   printer << "] : " << getType() << ">";
101 }
102 
103 LogicalResult
104 TestI64ElementsAttr::verify(function_ref<InFlightDiagnostic()> emitError,
105                             ShapedType type, ArrayRef<uint64_t> elements) {
106   if (type.getNumElements() != static_cast<int64_t>(elements.size())) {
107     return emitError()
108            << "number of elements does not match the provided shape type, got: "
109            << elements.size() << ", but expected: " << type.getNumElements();
110   }
111   if (type.getRank() != 1 || !type.getElementType().isSignlessInteger(64))
112     return emitError() << "expected single rank 64-bit shape type, but got: "
113                        << type;
114   return success();
115 }
116 
117 LogicalResult TestAttrWithFormatAttr::verify(
118     function_ref<InFlightDiagnostic()> emitError, int64_t one, std::string two,
119     IntegerAttr three, ArrayRef<int> four,
120     ArrayRef<AttrWithTypeBuilderAttr> arrayOfAttrWithTypeBuilderAttr) {
121   if (four.size() != static_cast<unsigned>(one))
122     return emitError() << "expected 'one' to equal 'four.size()'";
123   return success();
124 }
125 
126 //===----------------------------------------------------------------------===//
127 // Utility Functions for Generated Attributes
128 //===----------------------------------------------------------------------===//
129 
130 static FailureOr<SmallVector<int>> parseIntArray(AsmParser &parser) {
131   SmallVector<int> ints;
132   if (parser.parseLSquare() || parser.parseCommaSeparatedList([&]() {
133         ints.push_back(0);
134         return parser.parseInteger(ints.back());
135       }) ||
136       parser.parseRSquare())
137     return failure();
138   return ints;
139 }
140 
141 static void printIntArray(AsmPrinter &printer, ArrayRef<int> ints) {
142   printer << '[';
143   llvm::interleaveComma(ints, printer);
144   printer << ']';
145 }
146 
147 //===----------------------------------------------------------------------===//
148 // TestSubElementsAccessAttr
149 //===----------------------------------------------------------------------===//
150 
151 Attribute TestSubElementsAccessAttr::parse(::mlir::AsmParser &parser,
152                                            ::mlir::Type type) {
153   Attribute first, second, third;
154   if (parser.parseLess() || parser.parseAttribute(first) ||
155       parser.parseComma() || parser.parseAttribute(second) ||
156       parser.parseComma() || parser.parseAttribute(third) ||
157       parser.parseGreater()) {
158     return {};
159   }
160   return get(parser.getContext(), first, second, third);
161 }
162 
163 void TestSubElementsAccessAttr::print(::mlir::AsmPrinter &printer) const {
164   printer << "<" << getFirst() << ", " << getSecond() << ", " << getThird()
165           << ">";
166 }
167 
168 void TestSubElementsAccessAttr::walkImmediateSubElements(
169     llvm::function_ref<void(mlir::Attribute)> walkAttrsFn,
170     llvm::function_ref<void(mlir::Type)> walkTypesFn) const {
171   walkAttrsFn(getFirst());
172   walkAttrsFn(getSecond());
173   walkAttrsFn(getThird());
174 }
175 
176 SubElementAttrInterface TestSubElementsAccessAttr::replaceImmediateSubAttribute(
177     ArrayRef<std::pair<size_t, Attribute>> replacements) const {
178   Attribute first = getFirst();
179   Attribute second = getSecond();
180   Attribute third = getThird();
181   for (auto &it : replacements) {
182     switch (it.first) {
183     case 0:
184       first = it.second;
185       break;
186     case 1:
187       second = it.second;
188       break;
189     case 2:
190       third = it.second;
191       break;
192     }
193   }
194   return get(getContext(), first, second, third);
195 }
196 
197 //===----------------------------------------------------------------------===//
198 // TestExtern1DI64ElementsAttr
199 //===----------------------------------------------------------------------===//
200 
/// Returns the element values stored in the data referenced by this
/// attribute's handle.
/// NOTE(review): the result of getHandle().getData() is dereferenced with no
/// check; this presumably assumes the referenced data is always present when
/// the attribute is queried. Confirm against the handle's contract.
ArrayRef<uint64_t> TestExtern1DI64ElementsAttr::getElements() const {
  return getHandle().getData()->getData();
}
204 
205 //===----------------------------------------------------------------------===//
206 // Tablegen Generated Definitions
207 //===----------------------------------------------------------------------===//
208 
209 #include "TestAttrInterfaces.cpp.inc"
210 
211 #define GET_ATTRDEF_CLASSES
212 #include "TestAttrDefs.cpp.inc"
213 
214 //===----------------------------------------------------------------------===//
215 // Dynamic Attributes
216 //===----------------------------------------------------------------------===//
217 
218 /// Define a singleton dynamic attribute.
219 static std::unique_ptr<DynamicAttrDefinition>
220 getDynamicSingletonAttr(TestDialect *testDialect) {
221   return DynamicAttrDefinition::get(
222       "dynamic_singleton", testDialect,
223       [](function_ref<InFlightDiagnostic()> emitError,
224          ArrayRef<Attribute> args) {
225         if (!args.empty()) {
226           emitError() << "expected 0 attribute arguments, but had "
227                       << args.size();
228           return failure();
229         }
230         return success();
231       });
232 }
233 
234 /// Define a dynamic attribute representing a pair or attributes.
235 static std::unique_ptr<DynamicAttrDefinition>
236 getDynamicPairAttr(TestDialect *testDialect) {
237   return DynamicAttrDefinition::get(
238       "dynamic_pair", testDialect,
239       [](function_ref<InFlightDiagnostic()> emitError,
240          ArrayRef<Attribute> args) {
241         if (args.size() != 2) {
242           emitError() << "expected 2 attribute arguments, but had "
243                       << args.size();
244           return failure();
245         }
246         return success();
247       });
248 }
249 
250 static std::unique_ptr<DynamicAttrDefinition>
251 getDynamicCustomAssemblyFormatAttr(TestDialect *testDialect) {
252   auto verifier = [](function_ref<InFlightDiagnostic()> emitError,
253                      ArrayRef<Attribute> args) {
254     if (args.size() != 2) {
255       emitError() << "expected 2 attribute arguments, but had " << args.size();
256       return failure();
257     }
258     return success();
259   };
260 
261   auto parser = [](AsmParser &parser,
262                    llvm::SmallVectorImpl<Attribute> &parsedParams) {
263     Attribute leftAttr, rightAttr;
264     if (parser.parseLess() || parser.parseAttribute(leftAttr) ||
265         parser.parseColon() || parser.parseAttribute(rightAttr) ||
266         parser.parseGreater())
267       return failure();
268     parsedParams.push_back(leftAttr);
269     parsedParams.push_back(rightAttr);
270     return success();
271   };
272 
273   auto printer = [](AsmPrinter &printer, ArrayRef<Attribute> params) {
274     printer << "<" << params[0] << ":" << params[1] << ">";
275   };
276 
277   return DynamicAttrDefinition::get("dynamic_custom_assembly_format",
278                                     testDialect, std::move(verifier),
279                                     std::move(parser), std::move(printer));
280 }
281 
282 //===----------------------------------------------------------------------===//
283 // TestDialect
284 //===----------------------------------------------------------------------===//
285 
/// Registers every attribute of the test dialect with this dialect instance:
/// first the tablegen-generated attribute classes listed by
/// TestAttrDefs.cpp.inc under GET_ATTRDEF_LIST, then the three dynamic
/// attribute definitions built by the helper functions in this file.
void TestDialect::registerAttributes() {
  // The included file expands to the comma-separated list of generated
  // attribute class names used as template arguments.
  addAttributes<
#define GET_ATTRDEF_LIST
#include "TestAttrDefs.cpp.inc"
      >();
  registerDynamicAttr(getDynamicSingletonAttr(this));
  registerDynamicAttr(getDynamicPairAttr(this));
  registerDynamicAttr(getDynamicCustomAssemblyFormatAttr(this));
}
295