1 //===- TestAttributes.cpp - MLIR Test Dialect Attributes --------*- C++ -*-===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 //
9 // This file contains attributes defined by the TestDialect for testing various
10 // features of MLIR.
11 //
12 //===----------------------------------------------------------------------===//
13 
14 #include "TestAttributes.h"
15 #include "TestDialect.h"
16 #include "mlir/IR/Builders.h"
17 #include "mlir/IR/DialectImplementation.h"
18 #include "mlir/IR/Types.h"
19 #include "mlir/Support/LogicalResult.h"
20 #include "llvm/ADT/Hashing.h"
21 #include "llvm/ADT/SetVector.h"
22 #include "llvm/ADT/TypeSwitch.h"
23 #include "llvm/ADT/bit.h"
24 
25 using namespace mlir;
26 using namespace test;
27 
28 //===----------------------------------------------------------------------===//
29 // AttrWithSelfTypeParamAttr
30 //===----------------------------------------------------------------------===//
31 
32 Attribute AttrWithSelfTypeParamAttr::parse(DialectAsmParser &parser,
33                                            Type type) {
34   Type selfType;
35   if (parser.parseType(selfType))
36     return Attribute();
37   return get(parser.getContext(), selfType);
38 }
39 
40 void AttrWithSelfTypeParamAttr::print(DialectAsmPrinter &printer) const {
41   printer << "attr_with_self_type_param " << getType();
42 }
43 
44 //===----------------------------------------------------------------------===//
45 // AttrWithTypeBuilderAttr
46 //===----------------------------------------------------------------------===//
47 
48 Attribute AttrWithTypeBuilderAttr::parse(DialectAsmParser &parser, Type type) {
49   IntegerAttr element;
50   if (parser.parseAttribute(element))
51     return Attribute();
52   return get(parser.getContext(), element);
53 }
54 
55 void AttrWithTypeBuilderAttr::print(DialectAsmPrinter &printer) const {
56   printer << "attr_with_type_builder " << getAttr();
57 }
58 
59 //===----------------------------------------------------------------------===//
60 // CompoundAAttr
61 //===----------------------------------------------------------------------===//
62 
63 Attribute CompoundAAttr::parse(DialectAsmParser &parser, Type type) {
64   int widthOfSomething;
65   Type oneType;
66   SmallVector<int, 4> arrayOfInts;
67   if (parser.parseLess() || parser.parseInteger(widthOfSomething) ||
68       parser.parseComma() || parser.parseType(oneType) || parser.parseComma() ||
69       parser.parseLSquare())
70     return Attribute();
71 
72   int intVal;
73   while (!*parser.parseOptionalInteger(intVal)) {
74     arrayOfInts.push_back(intVal);
75     if (parser.parseOptionalComma())
76       break;
77   }
78 
79   if (parser.parseRSquare() || parser.parseGreater())
80     return Attribute();
81   return get(parser.getContext(), widthOfSomething, oneType, arrayOfInts);
82 }
83 
84 void CompoundAAttr::print(DialectAsmPrinter &printer) const {
85   printer << "cmpnd_a<" << getWidthOfSomething() << ", " << getOneType()
86           << ", [";
87   llvm::interleaveComma(getArrayOfInts(), printer);
88   printer << "]>";
89 }
90 
91 //===----------------------------------------------------------------------===//
// TestI64ElementsAttr
93 //===----------------------------------------------------------------------===//
94 
95 Attribute TestI64ElementsAttr::parse(DialectAsmParser &parser, Type type) {
96   SmallVector<uint64_t> elements;
97   if (parser.parseLess() || parser.parseLSquare())
98     return Attribute();
99   uint64_t intVal;
100   while (succeeded(*parser.parseOptionalInteger(intVal))) {
101     elements.push_back(intVal);
102     if (parser.parseOptionalComma())
103       break;
104   }
105 
106   if (parser.parseRSquare() || parser.parseGreater())
107     return Attribute();
108   return parser.getChecked<TestI64ElementsAttr>(
109       parser.getContext(), type.cast<ShapedType>(), elements);
110 }
111 
112 void TestI64ElementsAttr::print(DialectAsmPrinter &printer) const {
113   printer << "i64_elements<[";
114   llvm::interleaveComma(getElements(), printer);
115   printer << "] : " << getType() << ">";
116 }
117 
118 LogicalResult
119 TestI64ElementsAttr::verify(function_ref<InFlightDiagnostic()> emitError,
120                             ShapedType type, ArrayRef<uint64_t> elements) {
121   if (type.getNumElements() != static_cast<int64_t>(elements.size())) {
122     return emitError()
123            << "number of elements does not match the provided shape type, got: "
124            << elements.size() << ", but expected: " << type.getNumElements();
125   }
126   if (type.getRank() != 1 || !type.getElementType().isSignlessInteger(64))
127     return emitError() << "expected single rank 64-bit shape type, but got: "
128                        << type;
129   return success();
130 }
131 
132 LogicalResult
133 TestAttrWithFormatAttr::verify(function_ref<InFlightDiagnostic()> emitError,
134                                int64_t one, std::string two, IntegerAttr three,
135                                ArrayRef<int> four) {
136   if (four.size() != static_cast<unsigned>(one))
137     return emitError() << "expected 'one' to equal 'four.size()'";
138   return success();
139 }
140 
141 //===----------------------------------------------------------------------===//
142 // Utility Functions for Generated Attributes
143 //===----------------------------------------------------------------------===//
144 
145 static FailureOr<SmallVector<int>> parseIntArray(DialectAsmParser &parser) {
146   SmallVector<int> ints;
147   if (parser.parseLSquare() || parser.parseCommaSeparatedList([&]() {
148         ints.push_back(0);
149         return parser.parseInteger(ints.back());
150       }) ||
151       parser.parseRSquare())
152     return failure();
153   return ints;
154 }
155 
156 static void printIntArray(DialectAsmPrinter &printer, ArrayRef<int> ints) {
157   printer << '[';
158   llvm::interleaveComma(ints, printer);
159   printer << ']';
160 }
161 
162 //===----------------------------------------------------------------------===//
163 // TestSubElementsAccessAttr
164 //===----------------------------------------------------------------------===//
165 
166 Attribute TestSubElementsAccessAttr::parse(::mlir::DialectAsmParser &parser,
167                                            ::mlir::Type type) {
168   Attribute first, second, third;
169   if (parser.parseLess() || parser.parseAttribute(first) ||
170       parser.parseComma() || parser.parseAttribute(second) ||
171       parser.parseComma() || parser.parseAttribute(third) ||
172       parser.parseGreater()) {
173     return {};
174   }
175   return get(parser.getContext(), first, second, third);
176 }
177 
178 void TestSubElementsAccessAttr::print(
179     ::mlir::DialectAsmPrinter &printer) const {
180   printer << getMnemonic() << "<" << getFirst() << ", " << getSecond() << ", "
181           << getThird() << ">";
182 }
183 
184 void TestSubElementsAccessAttr::walkImmediateSubElements(
185     llvm::function_ref<void(mlir::Attribute)> walkAttrsFn,
186     llvm::function_ref<void(mlir::Type)> walkTypesFn) const {
187   walkAttrsFn(getFirst());
188   walkAttrsFn(getSecond());
189   walkAttrsFn(getThird());
190 }
191 
192 SubElementAttrInterface TestSubElementsAccessAttr::replaceImmediateSubAttribute(
193     ArrayRef<std::pair<size_t, Attribute>> replacements) const {
194   Attribute first = getFirst();
195   Attribute second = getSecond();
196   Attribute third = getThird();
197   for (auto &it : replacements) {
198     switch (it.first) {
199     case 0:
200       first = it.second;
201       break;
202     case 1:
203       second = it.second;
204       break;
205     case 2:
206       third = it.second;
207       break;
208     }
209   }
210   return get(getContext(), first, second, third);
211 }
212 
213 //===----------------------------------------------------------------------===//
214 // Tablegen Generated Definitions
215 //===----------------------------------------------------------------------===//
216 
217 #include "TestAttrInterfaces.cpp.inc"
218 
219 #define GET_ATTRDEF_CLASSES
220 #include "TestAttrDefs.cpp.inc"
221 
222 //===----------------------------------------------------------------------===//
223 // TestDialect
224 //===----------------------------------------------------------------------===//
225 
/// Registers the test dialect's attribute classes. The template argument list
/// passed to addAttributes is expanded from the tablegen-generated
/// `GET_ATTRDEF_LIST` section of TestAttrDefs.cpp.inc.
void TestDialect::registerAttributes() {
  addAttributes<
#define GET_ATTRDEF_LIST
#include "TestAttrDefs.cpp.inc"
      >();
}
232 
233 Attribute TestDialect::parseAttribute(DialectAsmParser &parser,
234                                       Type type) const {
235   StringRef attrTag;
236   if (failed(parser.parseKeyword(&attrTag)))
237     return Attribute();
238   {
239     Attribute attr;
240     auto parseResult = generatedAttributeParser(parser, attrTag, type, attr);
241     if (parseResult.hasValue())
242       return attr;
243   }
244   parser.emitError(parser.getNameLoc(), "unknown test attribute");
245   return Attribute();
246 }
247 
248 void TestDialect::printAttribute(Attribute attr,
249                                  DialectAsmPrinter &printer) const {
250   if (succeeded(generatedAttributePrinter(attr, printer)))
251     return;
252 }
253