1 //===- TestAttributes.cpp - MLIR Test Dialect Attributes --------*- C++ -*-===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 //
9 // This file contains attributes defined by the TestDialect for testing various
10 // features of MLIR.
11 //
12 //===----------------------------------------------------------------------===//
13 
14 #include "TestAttributes.h"
15 #include "TestDialect.h"
16 #include "mlir/IR/Builders.h"
17 #include "mlir/IR/DialectImplementation.h"
18 #include "mlir/IR/Types.h"
19 #include "llvm/ADT/Hashing.h"
20 #include "llvm/ADT/SetVector.h"
21 #include "llvm/ADT/TypeSwitch.h"
22 
23 using namespace mlir;
24 using namespace test;
25 
26 //===----------------------------------------------------------------------===//
27 // AttrWithSelfTypeParamAttr
28 //===----------------------------------------------------------------------===//
29 
30 Attribute AttrWithSelfTypeParamAttr::parse(DialectAsmParser &parser,
31                                            Type type) {
32   Type selfType;
33   if (parser.parseType(selfType))
34     return Attribute();
35   return get(parser.getContext(), selfType);
36 }
37 
38 void AttrWithSelfTypeParamAttr::print(DialectAsmPrinter &printer) const {
39   printer << "attr_with_self_type_param " << getType();
40 }
41 
42 //===----------------------------------------------------------------------===//
43 // AttrWithTypeBuilderAttr
44 //===----------------------------------------------------------------------===//
45 
46 Attribute AttrWithTypeBuilderAttr::parse(DialectAsmParser &parser, Type type) {
47   IntegerAttr element;
48   if (parser.parseAttribute(element))
49     return Attribute();
50   return get(parser.getContext(), element);
51 }
52 
53 void AttrWithTypeBuilderAttr::print(DialectAsmPrinter &printer) const {
54   printer << "attr_with_type_builder " << getAttr();
55 }
56 
57 //===----------------------------------------------------------------------===//
58 // CompoundAAttr
59 //===----------------------------------------------------------------------===//
60 
61 Attribute CompoundAAttr::parse(DialectAsmParser &parser, Type type) {
62   int widthOfSomething;
63   Type oneType;
64   SmallVector<int, 4> arrayOfInts;
65   if (parser.parseLess() || parser.parseInteger(widthOfSomething) ||
66       parser.parseComma() || parser.parseType(oneType) || parser.parseComma() ||
67       parser.parseLSquare())
68     return Attribute();
69 
70   int intVal;
71   while (!*parser.parseOptionalInteger(intVal)) {
72     arrayOfInts.push_back(intVal);
73     if (parser.parseOptionalComma())
74       break;
75   }
76 
77   if (parser.parseRSquare() || parser.parseGreater())
78     return Attribute();
79   return get(parser.getContext(), widthOfSomething, oneType, arrayOfInts);
80 }
81 
82 void CompoundAAttr::print(DialectAsmPrinter &printer) const {
83   printer << "cmpnd_a<" << getWidthOfSomething() << ", " << getOneType()
84           << ", [";
85   llvm::interleaveComma(getArrayOfInts(), printer);
86   printer << "]>";
87 }
88 
89 //===----------------------------------------------------------------------===//
// TestI64ElementsAttr
91 //===----------------------------------------------------------------------===//
92 
93 Attribute TestI64ElementsAttr::parse(DialectAsmParser &parser, Type type) {
94   SmallVector<uint64_t> elements;
95   if (parser.parseLess() || parser.parseLSquare())
96     return Attribute();
97   uint64_t intVal;
98   while (succeeded(*parser.parseOptionalInteger(intVal))) {
99     elements.push_back(intVal);
100     if (parser.parseOptionalComma())
101       break;
102   }
103 
104   if (parser.parseRSquare() || parser.parseGreater())
105     return Attribute();
106   return parser.getChecked<TestI64ElementsAttr>(
107       parser.getContext(), type.cast<ShapedType>(), elements);
108 }
109 
110 void TestI64ElementsAttr::print(DialectAsmPrinter &printer) const {
111   printer << "i64_elements<[";
112   llvm::interleaveComma(getElements(), printer);
113   printer << "] : " << getType() << ">";
114 }
115 
116 LogicalResult
117 TestI64ElementsAttr::verify(function_ref<InFlightDiagnostic()> emitError,
118                             ShapedType type, ArrayRef<uint64_t> elements) {
119   if (type.getNumElements() != static_cast<int64_t>(elements.size())) {
120     return emitError()
121            << "number of elements does not match the provided shape type, got: "
122            << elements.size() << ", but expected: " << type.getNumElements();
123   }
124   if (type.getRank() != 1 || !type.getElementType().isSignlessInteger(64))
125     return emitError() << "expected single rank 64-bit shape type, but got: "
126                        << type;
127   return success();
128 }
129 
130 //===----------------------------------------------------------------------===//
131 // TestSubElementsAccessAttr
132 //===----------------------------------------------------------------------===//
133 
134 Attribute TestSubElementsAccessAttr::parse(::mlir::DialectAsmParser &parser,
135                                            ::mlir::Type type) {
136   Attribute first, second, third;
137   if (parser.parseLess() || parser.parseAttribute(first) ||
138       parser.parseComma() || parser.parseAttribute(second) ||
139       parser.parseComma() || parser.parseAttribute(third) ||
140       parser.parseGreater()) {
141     return {};
142   }
143   return get(parser.getContext(), first, second, third);
144 }
145 
146 void TestSubElementsAccessAttr::print(
147     ::mlir::DialectAsmPrinter &printer) const {
148   printer << getMnemonic() << "<" << getFirst() << ", " << getSecond() << ", "
149           << getThird() << ">";
150 }
151 
152 void TestSubElementsAccessAttr::walkImmediateSubElements(
153     llvm::function_ref<void(mlir::Attribute)> walkAttrsFn,
154     llvm::function_ref<void(mlir::Type)> walkTypesFn) const {
155   walkAttrsFn(getFirst());
156   walkAttrsFn(getSecond());
157   walkAttrsFn(getThird());
158 }
159 
160 SubElementAttrInterface TestSubElementsAccessAttr::replaceImmediateSubAttribute(
161     ArrayRef<std::pair<size_t, Attribute>> replacements) const {
162   Attribute first = getFirst();
163   Attribute second = getSecond();
164   Attribute third = getThird();
165   for (auto &it : replacements) {
166     switch (it.first) {
167     case 0:
168       first = it.second;
169       break;
170     case 1:
171       second = it.second;
172       break;
173     case 2:
174       third = it.second;
175       break;
176     }
177   }
178   return get(getContext(), first, second, third);
179 }
180 
181 //===----------------------------------------------------------------------===//
182 // Tablegen Generated Definitions
183 //===----------------------------------------------------------------------===//
184 
185 #include "TestAttrInterfaces.cpp.inc"
186 
187 #define GET_ATTRDEF_CLASSES
188 #include "TestAttrDefs.cpp.inc"
189 
190 //===----------------------------------------------------------------------===//
191 // TestDialect
192 //===----------------------------------------------------------------------===//
193 
/// Registers all tablegen-generated test attributes with the dialect.
/// With GET_ATTRDEF_LIST defined, TestAttrDefs.cpp.inc is included directly
/// inside the addAttributes<...> template argument list. It presumably
/// expands to the comma-separated generated attribute class names.
void TestDialect::registerAttributes() {
  addAttributes<
#define GET_ATTRDEF_LIST
#include "TestAttrDefs.cpp.inc"
      >();
}
200 
201 Attribute TestDialect::parseAttribute(DialectAsmParser &parser,
202                                       Type type) const {
203   StringRef attrTag;
204   if (failed(parser.parseKeyword(&attrTag)))
205     return Attribute();
206   {
207     Attribute attr;
208     auto parseResult = generatedAttributeParser(parser, attrTag, type, attr);
209     if (parseResult.hasValue())
210       return attr;
211   }
212   parser.emitError(parser.getNameLoc(), "unknown test attribute");
213   return Attribute();
214 }
215 
216 void TestDialect::printAttribute(Attribute attr,
217                                  DialectAsmPrinter &printer) const {
218   if (succeeded(generatedAttributePrinter(attr, printer)))
219     return;
220 }
221