1 //===- TestAttributes.cpp - MLIR Test Dialect Attributes --------*- C++ -*-===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 //
9 // This file contains attributes defined by the TestDialect for testing various
10 // features of MLIR.
11 //
12 //===----------------------------------------------------------------------===//
13 
14 #include "TestAttributes.h"
15 #include "TestDialect.h"
16 #include "mlir/IR/Builders.h"
17 #include "mlir/IR/DialectImplementation.h"
18 #include "mlir/IR/Types.h"
19 #include "mlir/Support/LogicalResult.h"
20 #include "llvm/ADT/Hashing.h"
21 #include "llvm/ADT/SetVector.h"
22 #include "llvm/ADT/TypeSwitch.h"
23 #include "llvm/ADT/bit.h"
24 
25 using namespace mlir;
26 using namespace test;
27 
28 //===----------------------------------------------------------------------===//
29 // AttrWithSelfTypeParamAttr
30 //===----------------------------------------------------------------------===//
31 
32 Attribute AttrWithSelfTypeParamAttr::parse(DialectAsmParser &parser,
33                                            Type type) {
34   Type selfType;
35   if (parser.parseType(selfType))
36     return Attribute();
37   return get(parser.getContext(), selfType);
38 }
39 
40 void AttrWithSelfTypeParamAttr::print(DialectAsmPrinter &printer) const {
41   printer << " " << getType();
42 }
43 
44 //===----------------------------------------------------------------------===//
45 // AttrWithTypeBuilderAttr
46 //===----------------------------------------------------------------------===//
47 
48 Attribute AttrWithTypeBuilderAttr::parse(DialectAsmParser &parser, Type type) {
49   IntegerAttr element;
50   if (parser.parseAttribute(element))
51     return Attribute();
52   return get(parser.getContext(), element);
53 }
54 
55 void AttrWithTypeBuilderAttr::print(DialectAsmPrinter &printer) const {
56   printer << " " << getAttr();
57 }
58 
59 //===----------------------------------------------------------------------===//
60 // CompoundAAttr
61 //===----------------------------------------------------------------------===//
62 
63 Attribute CompoundAAttr::parse(DialectAsmParser &parser, Type type) {
64   int widthOfSomething;
65   Type oneType;
66   SmallVector<int, 4> arrayOfInts;
67   if (parser.parseLess() || parser.parseInteger(widthOfSomething) ||
68       parser.parseComma() || parser.parseType(oneType) || parser.parseComma() ||
69       parser.parseLSquare())
70     return Attribute();
71 
72   int intVal;
73   while (!*parser.parseOptionalInteger(intVal)) {
74     arrayOfInts.push_back(intVal);
75     if (parser.parseOptionalComma())
76       break;
77   }
78 
79   if (parser.parseRSquare() || parser.parseGreater())
80     return Attribute();
81   return get(parser.getContext(), widthOfSomething, oneType, arrayOfInts);
82 }
83 
84 void CompoundAAttr::print(DialectAsmPrinter &printer) const {
85   printer << "<" << getWidthOfSomething() << ", " << getOneType() << ", [";
86   llvm::interleaveComma(getArrayOfInts(), printer);
87   printer << "]>";
88 }
89 
90 //===----------------------------------------------------------------------===//
// TestI64ElementsAttr
92 //===----------------------------------------------------------------------===//
93 
94 Attribute TestI64ElementsAttr::parse(DialectAsmParser &parser, Type type) {
95   SmallVector<uint64_t> elements;
96   if (parser.parseLess() || parser.parseLSquare())
97     return Attribute();
98   uint64_t intVal;
99   while (succeeded(*parser.parseOptionalInteger(intVal))) {
100     elements.push_back(intVal);
101     if (parser.parseOptionalComma())
102       break;
103   }
104 
105   if (parser.parseRSquare() || parser.parseGreater())
106     return Attribute();
107   return parser.getChecked<TestI64ElementsAttr>(
108       parser.getContext(), type.cast<ShapedType>(), elements);
109 }
110 
111 void TestI64ElementsAttr::print(DialectAsmPrinter &printer) const {
112   printer << "<[";
113   llvm::interleaveComma(getElements(), printer);
114   printer << "] : " << getType() << ">";
115 }
116 
117 LogicalResult
118 TestI64ElementsAttr::verify(function_ref<InFlightDiagnostic()> emitError,
119                             ShapedType type, ArrayRef<uint64_t> elements) {
120   if (type.getNumElements() != static_cast<int64_t>(elements.size())) {
121     return emitError()
122            << "number of elements does not match the provided shape type, got: "
123            << elements.size() << ", but expected: " << type.getNumElements();
124   }
125   if (type.getRank() != 1 || !type.getElementType().isSignlessInteger(64))
126     return emitError() << "expected single rank 64-bit shape type, but got: "
127                        << type;
128   return success();
129 }
130 
131 LogicalResult
132 TestAttrWithFormatAttr::verify(function_ref<InFlightDiagnostic()> emitError,
133                                int64_t one, std::string two, IntegerAttr three,
134                                ArrayRef<int> four) {
135   if (four.size() != static_cast<unsigned>(one))
136     return emitError() << "expected 'one' to equal 'four.size()'";
137   return success();
138 }
139 
140 //===----------------------------------------------------------------------===//
141 // Utility Functions for Generated Attributes
142 //===----------------------------------------------------------------------===//
143 
144 static FailureOr<SmallVector<int>> parseIntArray(DialectAsmParser &parser) {
145   SmallVector<int> ints;
146   if (parser.parseLSquare() || parser.parseCommaSeparatedList([&]() {
147         ints.push_back(0);
148         return parser.parseInteger(ints.back());
149       }) ||
150       parser.parseRSquare())
151     return failure();
152   return ints;
153 }
154 
155 static void printIntArray(DialectAsmPrinter &printer, ArrayRef<int> ints) {
156   printer << '[';
157   llvm::interleaveComma(ints, printer);
158   printer << ']';
159 }
160 
161 //===----------------------------------------------------------------------===//
162 // TestSubElementsAccessAttr
163 //===----------------------------------------------------------------------===//
164 
165 Attribute TestSubElementsAccessAttr::parse(::mlir::DialectAsmParser &parser,
166                                            ::mlir::Type type) {
167   Attribute first, second, third;
168   if (parser.parseLess() || parser.parseAttribute(first) ||
169       parser.parseComma() || parser.parseAttribute(second) ||
170       parser.parseComma() || parser.parseAttribute(third) ||
171       parser.parseGreater()) {
172     return {};
173   }
174   return get(parser.getContext(), first, second, third);
175 }
176 
177 void TestSubElementsAccessAttr::print(
178     ::mlir::DialectAsmPrinter &printer) const {
179   printer << "<" << getFirst() << ", " << getSecond() << ", " << getThird()
180           << ">";
181 }
182 
183 void TestSubElementsAccessAttr::walkImmediateSubElements(
184     llvm::function_ref<void(mlir::Attribute)> walkAttrsFn,
185     llvm::function_ref<void(mlir::Type)> walkTypesFn) const {
186   walkAttrsFn(getFirst());
187   walkAttrsFn(getSecond());
188   walkAttrsFn(getThird());
189 }
190 
191 SubElementAttrInterface TestSubElementsAccessAttr::replaceImmediateSubAttribute(
192     ArrayRef<std::pair<size_t, Attribute>> replacements) const {
193   Attribute first = getFirst();
194   Attribute second = getSecond();
195   Attribute third = getThird();
196   for (auto &it : replacements) {
197     switch (it.first) {
198     case 0:
199       first = it.second;
200       break;
201     case 1:
202       second = it.second;
203       break;
204     case 2:
205       third = it.second;
206       break;
207     }
208   }
209   return get(getContext(), first, second, third);
210 }
211 
212 //===----------------------------------------------------------------------===//
213 // Tablegen Generated Definitions
214 //===----------------------------------------------------------------------===//
215 
216 #include "TestAttrInterfaces.cpp.inc"
217 
218 #define GET_ATTRDEF_CLASSES
219 #include "TestAttrDefs.cpp.inc"
220 
221 //===----------------------------------------------------------------------===//
222 // TestDialect
223 //===----------------------------------------------------------------------===//
224 
/// Registers the TableGen-generated test attribute classes with the dialect.
/// With `GET_ATTRDEF_LIST` defined, TestAttrDefs.cpp.inc expands to the list of
/// generated attribute classes, which becomes the template argument list of
/// `addAttributes`.
void TestDialect::registerAttributes() {
  addAttributes<
#define GET_ATTRDEF_LIST
#include "TestAttrDefs.cpp.inc"
      >();
}
231