1 //===- TestAttributes.cpp - MLIR Test Dialect Attributes --------*- C++ -*-===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 //
9 // This file contains attributes defined by the TestDialect for testing various
10 // features of MLIR.
11 //
12 //===----------------------------------------------------------------------===//
13
14 #include "TestAttributes.h"
15 #include "TestDialect.h"
16 #include "mlir/IR/Builders.h"
17 #include "mlir/IR/DialectImplementation.h"
18 #include "mlir/IR/ExtensibleDialect.h"
19 #include "mlir/IR/Types.h"
20 #include "mlir/Support/LogicalResult.h"
21 #include "llvm/ADT/Hashing.h"
22 #include "llvm/ADT/SetVector.h"
23 #include "llvm/ADT/TypeSwitch.h"
24 #include "llvm/ADT/bit.h"
25 #include "llvm/Support/ErrorHandling.h"
26
27 using namespace mlir;
28 using namespace test;
29
30 //===----------------------------------------------------------------------===//
31 // AttrWithTypeBuilderAttr
32 //===----------------------------------------------------------------------===//
33
parse(AsmParser & parser,Type type)34 Attribute AttrWithTypeBuilderAttr::parse(AsmParser &parser, Type type) {
35 IntegerAttr element;
36 if (parser.parseAttribute(element))
37 return Attribute();
38 return get(parser.getContext(), element);
39 }
40
print(AsmPrinter & printer) const41 void AttrWithTypeBuilderAttr::print(AsmPrinter &printer) const {
42 printer << " " << getAttr();
43 }
44
45 //===----------------------------------------------------------------------===//
46 // CompoundAAttr
47 //===----------------------------------------------------------------------===//
48
parse(AsmParser & parser,Type type)49 Attribute CompoundAAttr::parse(AsmParser &parser, Type type) {
50 int widthOfSomething;
51 Type oneType;
52 SmallVector<int, 4> arrayOfInts;
53 if (parser.parseLess() || parser.parseInteger(widthOfSomething) ||
54 parser.parseComma() || parser.parseType(oneType) || parser.parseComma() ||
55 parser.parseLSquare())
56 return Attribute();
57
58 int intVal;
59 while (!*parser.parseOptionalInteger(intVal)) {
60 arrayOfInts.push_back(intVal);
61 if (parser.parseOptionalComma())
62 break;
63 }
64
65 if (parser.parseRSquare() || parser.parseGreater())
66 return Attribute();
67 return get(parser.getContext(), widthOfSomething, oneType, arrayOfInts);
68 }
69
print(AsmPrinter & printer) const70 void CompoundAAttr::print(AsmPrinter &printer) const {
71 printer << "<" << getWidthOfSomething() << ", " << getOneType() << ", [";
72 llvm::interleaveComma(getArrayOfInts(), printer);
73 printer << "]>";
74 }
75
76 //===----------------------------------------------------------------------===//
// TestI64ElementsAttr
78 //===----------------------------------------------------------------------===//
79
parse(AsmParser & parser,Type type)80 Attribute TestI64ElementsAttr::parse(AsmParser &parser, Type type) {
81 SmallVector<uint64_t> elements;
82 if (parser.parseLess() || parser.parseLSquare())
83 return Attribute();
84 uint64_t intVal;
85 while (succeeded(*parser.parseOptionalInteger(intVal))) {
86 elements.push_back(intVal);
87 if (parser.parseOptionalComma())
88 break;
89 }
90
91 if (parser.parseRSquare() || parser.parseGreater())
92 return Attribute();
93 return parser.getChecked<TestI64ElementsAttr>(
94 parser.getContext(), type.cast<ShapedType>(), elements);
95 }
96
print(AsmPrinter & printer) const97 void TestI64ElementsAttr::print(AsmPrinter &printer) const {
98 printer << "<[";
99 llvm::interleaveComma(getElements(), printer);
100 printer << "] : " << getType() << ">";
101 }
102
103 LogicalResult
verify(function_ref<InFlightDiagnostic ()> emitError,ShapedType type,ArrayRef<uint64_t> elements)104 TestI64ElementsAttr::verify(function_ref<InFlightDiagnostic()> emitError,
105 ShapedType type, ArrayRef<uint64_t> elements) {
106 if (type.getNumElements() != static_cast<int64_t>(elements.size())) {
107 return emitError()
108 << "number of elements does not match the provided shape type, got: "
109 << elements.size() << ", but expected: " << type.getNumElements();
110 }
111 if (type.getRank() != 1 || !type.getElementType().isSignlessInteger(64))
112 return emitError() << "expected single rank 64-bit shape type, but got: "
113 << type;
114 return success();
115 }
116
verify(function_ref<InFlightDiagnostic ()> emitError,int64_t one,std::string two,IntegerAttr three,ArrayRef<int> four,ArrayRef<AttrWithTypeBuilderAttr> arrayOfAttrWithTypeBuilderAttr)117 LogicalResult TestAttrWithFormatAttr::verify(
118 function_ref<InFlightDiagnostic()> emitError, int64_t one, std::string two,
119 IntegerAttr three, ArrayRef<int> four,
120 ArrayRef<AttrWithTypeBuilderAttr> arrayOfAttrWithTypeBuilderAttr) {
121 if (four.size() != static_cast<unsigned>(one))
122 return emitError() << "expected 'one' to equal 'four.size()'";
123 return success();
124 }
125
126 //===----------------------------------------------------------------------===//
127 // Utility Functions for Generated Attributes
128 //===----------------------------------------------------------------------===//
129
parseIntArray(AsmParser & parser)130 static FailureOr<SmallVector<int>> parseIntArray(AsmParser &parser) {
131 SmallVector<int> ints;
132 if (parser.parseLSquare() || parser.parseCommaSeparatedList([&]() {
133 ints.push_back(0);
134 return parser.parseInteger(ints.back());
135 }) ||
136 parser.parseRSquare())
137 return failure();
138 return ints;
139 }
140
printIntArray(AsmPrinter & printer,ArrayRef<int> ints)141 static void printIntArray(AsmPrinter &printer, ArrayRef<int> ints) {
142 printer << '[';
143 llvm::interleaveComma(ints, printer);
144 printer << ']';
145 }
146
147 //===----------------------------------------------------------------------===//
148 // TestSubElementsAccessAttr
149 //===----------------------------------------------------------------------===//
150
parse(::mlir::AsmParser & parser,::mlir::Type type)151 Attribute TestSubElementsAccessAttr::parse(::mlir::AsmParser &parser,
152 ::mlir::Type type) {
153 Attribute first, second, third;
154 if (parser.parseLess() || parser.parseAttribute(first) ||
155 parser.parseComma() || parser.parseAttribute(second) ||
156 parser.parseComma() || parser.parseAttribute(third) ||
157 parser.parseGreater()) {
158 return {};
159 }
160 return get(parser.getContext(), first, second, third);
161 }
162
print(::mlir::AsmPrinter & printer) const163 void TestSubElementsAccessAttr::print(::mlir::AsmPrinter &printer) const {
164 printer << "<" << getFirst() << ", " << getSecond() << ", " << getThird()
165 << ">";
166 }
167
walkImmediateSubElements(llvm::function_ref<void (mlir::Attribute)> walkAttrsFn,llvm::function_ref<void (mlir::Type)> walkTypesFn) const168 void TestSubElementsAccessAttr::walkImmediateSubElements(
169 llvm::function_ref<void(mlir::Attribute)> walkAttrsFn,
170 llvm::function_ref<void(mlir::Type)> walkTypesFn) const {
171 walkAttrsFn(getFirst());
172 walkAttrsFn(getSecond());
173 walkAttrsFn(getThird());
174 }
175
replaceImmediateSubElements(ArrayRef<Attribute> replAttrs,ArrayRef<Type> replTypes) const176 Attribute TestSubElementsAccessAttr::replaceImmediateSubElements(
177 ArrayRef<Attribute> replAttrs, ArrayRef<Type> replTypes) const {
178 assert(replAttrs.size() == 3 && "invalid number of replacement attributes");
179 return get(getContext(), replAttrs[0], replAttrs[1], replAttrs[2]);
180 }
181
182 //===----------------------------------------------------------------------===//
183 // TestExtern1DI64ElementsAttr
184 //===----------------------------------------------------------------------===//
185
getElements() const186 ArrayRef<uint64_t> TestExtern1DI64ElementsAttr::getElements() const {
187 return getHandle().getData()->getData();
188 }
189
190 //===----------------------------------------------------------------------===//
191 // Tablegen Generated Definitions
192 //===----------------------------------------------------------------------===//
193
194 #include "TestAttrInterfaces.cpp.inc"
195
196 #define GET_ATTRDEF_CLASSES
197 #include "TestAttrDefs.cpp.inc"
198
199 //===----------------------------------------------------------------------===//
200 // Dynamic Attributes
201 //===----------------------------------------------------------------------===//
202
203 /// Define a singleton dynamic attribute.
204 static std::unique_ptr<DynamicAttrDefinition>
getDynamicSingletonAttr(TestDialect * testDialect)205 getDynamicSingletonAttr(TestDialect *testDialect) {
206 return DynamicAttrDefinition::get(
207 "dynamic_singleton", testDialect,
208 [](function_ref<InFlightDiagnostic()> emitError,
209 ArrayRef<Attribute> args) {
210 if (!args.empty()) {
211 emitError() << "expected 0 attribute arguments, but had "
212 << args.size();
213 return failure();
214 }
215 return success();
216 });
217 }
218
219 /// Define a dynamic attribute representing a pair or attributes.
220 static std::unique_ptr<DynamicAttrDefinition>
getDynamicPairAttr(TestDialect * testDialect)221 getDynamicPairAttr(TestDialect *testDialect) {
222 return DynamicAttrDefinition::get(
223 "dynamic_pair", testDialect,
224 [](function_ref<InFlightDiagnostic()> emitError,
225 ArrayRef<Attribute> args) {
226 if (args.size() != 2) {
227 emitError() << "expected 2 attribute arguments, but had "
228 << args.size();
229 return failure();
230 }
231 return success();
232 });
233 }
234
235 static std::unique_ptr<DynamicAttrDefinition>
getDynamicCustomAssemblyFormatAttr(TestDialect * testDialect)236 getDynamicCustomAssemblyFormatAttr(TestDialect *testDialect) {
237 auto verifier = [](function_ref<InFlightDiagnostic()> emitError,
238 ArrayRef<Attribute> args) {
239 if (args.size() != 2) {
240 emitError() << "expected 2 attribute arguments, but had " << args.size();
241 return failure();
242 }
243 return success();
244 };
245
246 auto parser = [](AsmParser &parser,
247 llvm::SmallVectorImpl<Attribute> &parsedParams) {
248 Attribute leftAttr, rightAttr;
249 if (parser.parseLess() || parser.parseAttribute(leftAttr) ||
250 parser.parseColon() || parser.parseAttribute(rightAttr) ||
251 parser.parseGreater())
252 return failure();
253 parsedParams.push_back(leftAttr);
254 parsedParams.push_back(rightAttr);
255 return success();
256 };
257
258 auto printer = [](AsmPrinter &printer, ArrayRef<Attribute> params) {
259 printer << "<" << params[0] << ":" << params[1] << ">";
260 };
261
262 return DynamicAttrDefinition::get("dynamic_custom_assembly_format",
263 testDialect, std::move(verifier),
264 std::move(parser), std::move(printer));
265 }
266
267 //===----------------------------------------------------------------------===//
268 // TestDialect
269 //===----------------------------------------------------------------------===//
270
/// Registers every attribute of the test dialect.
///
/// First come the statically defined attributes, whose class list is spliced
/// in from the TableGen-generated `TestAttrDefs.cpp.inc` via
/// `GET_ATTRDEF_LIST`. Then come the three dynamic attribute definitions built
/// by the helper functions in this file.
void TestDialect::registerAttributes() {
  // The template argument list is produced by the generated include.
  addAttributes<
#define GET_ATTRDEF_LIST
#include "TestAttrDefs.cpp.inc"
      >();
  // Dynamic attributes are defined at runtime rather than by ODS.
  registerDynamicAttr(getDynamicSingletonAttr(this));
  registerDynamicAttr(getDynamicPairAttr(this));
  registerDynamicAttr(getDynamicCustomAssemblyFormatAttr(this));
}
280