1 //===- TestAttributes.cpp - MLIR Test Dialect Attributes --------*- C++ -*-===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 //
9 // This file contains attributes defined by the TestDialect for testing various
10 // features of MLIR.
11 //
12 //===----------------------------------------------------------------------===//
13 
14 #include "TestAttributes.h"
15 #include "TestDialect.h"
16 #include "mlir/IR/Builders.h"
17 #include "mlir/IR/DialectImplementation.h"
18 #include "mlir/IR/Types.h"
19 #include "mlir/Support/LogicalResult.h"
20 #include "llvm/ADT/Hashing.h"
21 #include "llvm/ADT/SetVector.h"
22 #include "llvm/ADT/TypeSwitch.h"
23 #include "llvm/ADT/bit.h"
24 
25 using namespace mlir;
26 using namespace test;
27 
28 //===----------------------------------------------------------------------===//
29 // AttrWithSelfTypeParamAttr
30 //===----------------------------------------------------------------------===//
31 
32 Attribute AttrWithSelfTypeParamAttr::parse(AsmParser &parser, Type type) {
33   Type selfType;
34   if (parser.parseType(selfType))
35     return Attribute();
36   return get(parser.getContext(), selfType);
37 }
38 
39 void AttrWithSelfTypeParamAttr::print(AsmPrinter &printer) const {
40   printer << " " << getType();
41 }
42 
43 //===----------------------------------------------------------------------===//
44 // AttrWithTypeBuilderAttr
45 //===----------------------------------------------------------------------===//
46 
47 Attribute AttrWithTypeBuilderAttr::parse(AsmParser &parser, Type type) {
48   IntegerAttr element;
49   if (parser.parseAttribute(element))
50     return Attribute();
51   return get(parser.getContext(), element);
52 }
53 
54 void AttrWithTypeBuilderAttr::print(AsmPrinter &printer) const {
55   printer << " " << getAttr();
56 }
57 
58 //===----------------------------------------------------------------------===//
59 // CompoundAAttr
60 //===----------------------------------------------------------------------===//
61 
62 Attribute CompoundAAttr::parse(AsmParser &parser, Type type) {
63   int widthOfSomething;
64   Type oneType;
65   SmallVector<int, 4> arrayOfInts;
66   if (parser.parseLess() || parser.parseInteger(widthOfSomething) ||
67       parser.parseComma() || parser.parseType(oneType) || parser.parseComma() ||
68       parser.parseLSquare())
69     return Attribute();
70 
71   int intVal;
72   while (!*parser.parseOptionalInteger(intVal)) {
73     arrayOfInts.push_back(intVal);
74     if (parser.parseOptionalComma())
75       break;
76   }
77 
78   if (parser.parseRSquare() || parser.parseGreater())
79     return Attribute();
80   return get(parser.getContext(), widthOfSomething, oneType, arrayOfInts);
81 }
82 
83 void CompoundAAttr::print(AsmPrinter &printer) const {
84   printer << "<" << getWidthOfSomething() << ", " << getOneType() << ", [";
85   llvm::interleaveComma(getArrayOfInts(), printer);
86   printer << "]>";
87 }
88 
89 //===----------------------------------------------------------------------===//
// TestI64ElementsAttr
91 //===----------------------------------------------------------------------===//
92 
93 Attribute TestI64ElementsAttr::parse(AsmParser &parser, Type type) {
94   SmallVector<uint64_t> elements;
95   if (parser.parseLess() || parser.parseLSquare())
96     return Attribute();
97   uint64_t intVal;
98   while (succeeded(*parser.parseOptionalInteger(intVal))) {
99     elements.push_back(intVal);
100     if (parser.parseOptionalComma())
101       break;
102   }
103 
104   if (parser.parseRSquare() || parser.parseGreater())
105     return Attribute();
106   return parser.getChecked<TestI64ElementsAttr>(
107       parser.getContext(), type.cast<ShapedType>(), elements);
108 }
109 
110 void TestI64ElementsAttr::print(AsmPrinter &printer) const {
111   printer << "<[";
112   llvm::interleaveComma(getElements(), printer);
113   printer << "] : " << getType() << ">";
114 }
115 
116 LogicalResult
117 TestI64ElementsAttr::verify(function_ref<InFlightDiagnostic()> emitError,
118                             ShapedType type, ArrayRef<uint64_t> elements) {
119   if (type.getNumElements() != static_cast<int64_t>(elements.size())) {
120     return emitError()
121            << "number of elements does not match the provided shape type, got: "
122            << elements.size() << ", but expected: " << type.getNumElements();
123   }
124   if (type.getRank() != 1 || !type.getElementType().isSignlessInteger(64))
125     return emitError() << "expected single rank 64-bit shape type, but got: "
126                        << type;
127   return success();
128 }
129 
130 LogicalResult
131 TestAttrWithFormatAttr::verify(function_ref<InFlightDiagnostic()> emitError,
132                                int64_t one, std::string two, IntegerAttr three,
133                                ArrayRef<int> four) {
134   if (four.size() != static_cast<unsigned>(one))
135     return emitError() << "expected 'one' to equal 'four.size()'";
136   return success();
137 }
138 
139 //===----------------------------------------------------------------------===//
140 // Utility Functions for Generated Attributes
141 //===----------------------------------------------------------------------===//
142 
143 static FailureOr<SmallVector<int>> parseIntArray(AsmParser &parser) {
144   SmallVector<int> ints;
145   if (parser.parseLSquare() || parser.parseCommaSeparatedList([&]() {
146         ints.push_back(0);
147         return parser.parseInteger(ints.back());
148       }) ||
149       parser.parseRSquare())
150     return failure();
151   return ints;
152 }
153 
154 static void printIntArray(AsmPrinter &printer, ArrayRef<int> ints) {
155   printer << '[';
156   llvm::interleaveComma(ints, printer);
157   printer << ']';
158 }
159 
160 //===----------------------------------------------------------------------===//
161 // TestSubElementsAccessAttr
162 //===----------------------------------------------------------------------===//
163 
164 Attribute TestSubElementsAccessAttr::parse(::mlir::AsmParser &parser,
165                                            ::mlir::Type type) {
166   Attribute first, second, third;
167   if (parser.parseLess() || parser.parseAttribute(first) ||
168       parser.parseComma() || parser.parseAttribute(second) ||
169       parser.parseComma() || parser.parseAttribute(third) ||
170       parser.parseGreater()) {
171     return {};
172   }
173   return get(parser.getContext(), first, second, third);
174 }
175 
176 void TestSubElementsAccessAttr::print(::mlir::AsmPrinter &printer) const {
177   printer << "<" << getFirst() << ", " << getSecond() << ", " << getThird()
178           << ">";
179 }
180 
181 void TestSubElementsAccessAttr::walkImmediateSubElements(
182     llvm::function_ref<void(mlir::Attribute)> walkAttrsFn,
183     llvm::function_ref<void(mlir::Type)> walkTypesFn) const {
184   walkAttrsFn(getFirst());
185   walkAttrsFn(getSecond());
186   walkAttrsFn(getThird());
187 }
188 
189 SubElementAttrInterface TestSubElementsAccessAttr::replaceImmediateSubAttribute(
190     ArrayRef<std::pair<size_t, Attribute>> replacements) const {
191   Attribute first = getFirst();
192   Attribute second = getSecond();
193   Attribute third = getThird();
194   for (auto &it : replacements) {
195     switch (it.first) {
196     case 0:
197       first = it.second;
198       break;
199     case 1:
200       second = it.second;
201       break;
202     case 2:
203       third = it.second;
204       break;
205     }
206   }
207   return get(getContext(), first, second, third);
208 }
209 
210 //===----------------------------------------------------------------------===//
211 // Tablegen Generated Definitions
212 //===----------------------------------------------------------------------===//
213 
214 #include "TestAttrInterfaces.cpp.inc"
215 
216 #define GET_ATTRDEF_CLASSES
217 #include "TestAttrDefs.cpp.inc"
218 
219 //===----------------------------------------------------------------------===//
220 // TestDialect
221 //===----------------------------------------------------------------------===//
222 
/// Registers with the test dialect every attribute class listed in the
/// GET_ATTRDEF_LIST section of the generated TestAttrDefs.cpp.inc.
void TestDialect::registerAttributes() {
  // The included file expands to the comma-separated list of attribute classes
  // that forms the template argument list of addAttributes.
  addAttributes<
#define GET_ATTRDEF_LIST
#include "TestAttrDefs.cpp.inc"
      >();
}
229