1 //===- Attribute.cpp - Attribute wrapper class ----------------------------===//
2 //
3 // Copyright 2019 The MLIR Authors.
4 //
5 // Licensed under the Apache License, Version 2.0 (the "License");
6 // you may not use this file except in compliance with the License.
7 // You may obtain a copy of the License at
8 //
9 //   http://www.apache.org/licenses/LICENSE-2.0
10 //
11 // Unless required by applicable law or agreed to in writing, software
12 // distributed under the License is distributed on an "AS IS" BASIS,
13 // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14 // See the License for the specific language governing permissions and
15 // limitations under the License.
16 // =============================================================================
17 //
18 // Attribute wrapper to simplify using TableGen Record defining a MLIR
19 // Attribute.
20 //
21 //===----------------------------------------------------------------------===//
22 
23 #include "mlir/TableGen/Format.h"
24 #include "mlir/TableGen/Operator.h"
25 #include "llvm/TableGen/Record.h"
26 
27 using namespace mlir;
28 
29 using llvm::CodeInit;
30 using llvm::DefInit;
31 using llvm::Init;
32 using llvm::Record;
33 using llvm::StringInit;
34 
35 // Returns the initializer's value as string if the given TableGen initializer
36 // is a code or string initializer. Returns the empty StringRef otherwise.
37 static StringRef getValueAsString(const Init *init) {
38   if (const auto *code = dyn_cast<CodeInit>(init))
39     return code->getValue().trim();
40   else if (const auto *str = dyn_cast<StringInit>(init))
41     return str->getValue().trim();
42   return {};
43 }
44 
45 tblgen::AttrConstraint::AttrConstraint(const Record *record)
46     : Constraint(Constraint::CK_Attr, record) {
47   assert(isSubClassOf("AttrConstraint") &&
48          "must be subclass of TableGen 'AttrConstraint' class");
49 }
50 
51 bool tblgen::AttrConstraint::isSubClassOf(StringRef className) const {
52   return def->isSubClassOf(className);
53 }
54 
55 tblgen::Attribute::Attribute(const Record *record) : AttrConstraint(record) {
56   assert(record->isSubClassOf("Attr") &&
57          "must be subclass of TableGen 'Attr' class");
58 }
59 
60 tblgen::Attribute::Attribute(const DefInit *init) : Attribute(init->getDef()) {}
61 
62 bool tblgen::Attribute::isDerivedAttr() const {
63   return isSubClassOf("DerivedAttr");
64 }
65 
66 bool tblgen::Attribute::isTypeAttr() const {
67   return isSubClassOf("TypeAttrBase");
68 }
69 
70 bool tblgen::Attribute::isEnumAttr() const {
71   return isSubClassOf("EnumAttrInfo");
72 }
73 
74 StringRef tblgen::Attribute::getStorageType() const {
75   const auto *init = def->getValueInit("storageType");
76   auto type = getValueAsString(init);
77   if (type.empty())
78     return "Attribute";
79   return type;
80 }
81 
82 StringRef tblgen::Attribute::getReturnType() const {
83   const auto *init = def->getValueInit("returnType");
84   return getValueAsString(init);
85 }
86 
87 StringRef tblgen::Attribute::getConvertFromStorageCall() const {
88   const auto *init = def->getValueInit("convertFromStorage");
89   return getValueAsString(init);
90 }
91 
92 bool tblgen::Attribute::isConstBuildable() const {
93   const auto *init = def->getValueInit("constBuilderCall");
94   return !getValueAsString(init).empty();
95 }
96 
97 StringRef tblgen::Attribute::getConstBuilderTemplate() const {
98   const auto *init = def->getValueInit("constBuilderCall");
99   return getValueAsString(init);
100 }
101 
102 tblgen::Attribute tblgen::Attribute::getBaseAttr() const {
103   if (const auto *defInit =
104           llvm::dyn_cast<llvm::DefInit>(def->getValueInit("baseAttr"))) {
105     return Attribute(defInit).getBaseAttr();
106   }
107   return *this;
108 }
109 
110 bool tblgen::Attribute::hasDefaultValue() const {
111   const auto *init = def->getValueInit("defaultValue");
112   return !getValueAsString(init).empty();
113 }
114 
115 StringRef tblgen::Attribute::getDefaultValue() const {
116   const auto *init = def->getValueInit("defaultValue");
117   return getValueAsString(init);
118 }
119 
120 bool tblgen::Attribute::isOptional() const {
121   return def->getValueAsBit("isOptional");
122 }
123 
124 StringRef tblgen::Attribute::getAttrDefName() const {
125   if (def->isAnonymous()) {
126     return getBaseAttr().def->getName();
127   }
128   return def->getName();
129 }
130 
131 StringRef tblgen::Attribute::getDerivedCodeBody() const {
132   assert(isDerivedAttr() && "only derived attribute has 'body' field");
133   return def->getValueAsString("body");
134 }
135 
136 tblgen::ConstantAttr::ConstantAttr(const DefInit *init) : def(init->getDef()) {
137   assert(def->isSubClassOf("ConstantAttr") &&
138          "must be subclass of TableGen 'ConstantAttr' class");
139 }
140 
141 tblgen::Attribute tblgen::ConstantAttr::getAttribute() const {
142   return Attribute(def->getValueAsDef("attr"));
143 }
144 
145 StringRef tblgen::ConstantAttr::getConstantValue() const {
146   return def->getValueAsString("value");
147 }
148 
149 tblgen::EnumAttrCase::EnumAttrCase(const llvm::DefInit *init)
150     : Attribute(init) {
151   assert(isSubClassOf("EnumAttrCaseInfo") &&
152          "must be subclass of TableGen 'EnumAttrInfo' class");
153 }
154 
155 bool tblgen::EnumAttrCase::isStrCase() const {
156   return isSubClassOf("StrEnumAttrCase");
157 }
158 
159 StringRef tblgen::EnumAttrCase::getSymbol() const {
160   return def->getValueAsString("symbol");
161 }
162 
163 int64_t tblgen::EnumAttrCase::getValue() const {
164   return def->getValueAsInt("value");
165 }
166 
167 tblgen::EnumAttr::EnumAttr(const llvm::Record *record) : Attribute(record) {
168   assert(isSubClassOf("EnumAttrInfo") &&
169          "must be subclass of TableGen 'EnumAttr' class");
170 }
171 
172 tblgen::EnumAttr::EnumAttr(const llvm::Record &record) : Attribute(&record) {}
173 
174 tblgen::EnumAttr::EnumAttr(const llvm::DefInit *init)
175     : EnumAttr(init->getDef()) {}
176 
177 bool tblgen::EnumAttr::isBitEnum() const { return isSubClassOf("BitEnumAttr"); }
178 
179 StringRef tblgen::EnumAttr::getEnumClassName() const {
180   return def->getValueAsString("className");
181 }
182 
183 StringRef tblgen::EnumAttr::getCppNamespace() const {
184   return def->getValueAsString("cppNamespace");
185 }
186 
187 StringRef tblgen::EnumAttr::getUnderlyingType() const {
188   return def->getValueAsString("underlyingType");
189 }
190 
191 StringRef tblgen::EnumAttr::getUnderlyingToSymbolFnName() const {
192   return def->getValueAsString("underlyingToSymbolFnName");
193 }
194 
195 StringRef tblgen::EnumAttr::getStringToSymbolFnName() const {
196   return def->getValueAsString("stringToSymbolFnName");
197 }
198 
199 StringRef tblgen::EnumAttr::getSymbolToStringFnName() const {
200   return def->getValueAsString("symbolToStringFnName");
201 }
202 
203 StringRef tblgen::EnumAttr::getSymbolToStringFnRetType() const {
204   return def->getValueAsString("symbolToStringFnRetType");
205 }
206 
207 StringRef tblgen::EnumAttr::getMaxEnumValFnName() const {
208   return def->getValueAsString("maxEnumValFnName");
209 }
210 
211 std::vector<tblgen::EnumAttrCase> tblgen::EnumAttr::getAllCases() const {
212   const auto *inits = def->getValueAsListInit("enumerants");
213 
214   std::vector<tblgen::EnumAttrCase> cases;
215   cases.reserve(inits->size());
216 
217   for (const llvm::Init *init : *inits) {
218     cases.push_back(tblgen::EnumAttrCase(cast<llvm::DefInit>(init)));
219   }
220 
221   return cases;
222 }
223 
224 tblgen::StructFieldAttr::StructFieldAttr(const llvm::Record *record)
225     : def(record) {
226   assert(def->isSubClassOf("StructFieldAttr") &&
227          "must be subclass of TableGen 'StructFieldAttr' class");
228 }
229 
230 tblgen::StructFieldAttr::StructFieldAttr(const llvm::Record &record)
231     : StructFieldAttr(&record) {}
232 
233 tblgen::StructFieldAttr::StructFieldAttr(const llvm::DefInit *init)
234     : StructFieldAttr(init->getDef()) {}
235 
236 StringRef tblgen::StructFieldAttr::getName() const {
237   return def->getValueAsString("name");
238 }
239 
240 tblgen::Attribute tblgen::StructFieldAttr::getType() const {
241   auto init = def->getValueInit("type");
242   return tblgen::Attribute(cast<llvm::DefInit>(init));
243 }
244 
245 tblgen::StructAttr::StructAttr(const llvm::Record *record) : Attribute(record) {
246   assert(isSubClassOf("StructAttr") &&
247          "must be subclass of TableGen 'StructAttr' class");
248 }
249 
250 tblgen::StructAttr::StructAttr(const llvm::DefInit *init)
251     : StructAttr(init->getDef()) {}
252 
253 StringRef tblgen::StructAttr::getStructClassName() const {
254   return def->getValueAsString("className");
255 }
256 
257 StringRef tblgen::StructAttr::getCppNamespace() const {
258   Dialect dialect(def->getValueAsDef("structDialect"));
259   return dialect.getCppNamespace();
260 }
261 
262 std::vector<mlir::tblgen::StructFieldAttr>
263 tblgen::StructAttr::getAllFields() const {
264   std::vector<mlir::tblgen::StructFieldAttr> attributes;
265 
266   const auto *inits = def->getValueAsListInit("fields");
267   attributes.reserve(inits->size());
268 
269   for (const llvm::Init *init : *inits) {
270     attributes.emplace_back(cast<llvm::DefInit>(init));
271   }
272 
273   return attributes;
274 }
275