1 //===- Attribute.cpp - Attribute wrapper class ----------------------------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 //
9 // Attribute wrapper to simplify using TableGen Record defining a MLIR
10 // Attribute.
11 //
12 //===----------------------------------------------------------------------===//
13 
14 #include "mlir/TableGen/Format.h"
15 #include "mlir/TableGen/Operator.h"
16 #include "llvm/TableGen/Record.h"
17 
18 using namespace mlir;
19 using namespace mlir::tblgen;
20 
21 using llvm::DefInit;
22 using llvm::Init;
23 using llvm::Record;
24 using llvm::StringInit;
25 
26 // Returns the initializer's value as string if the given TableGen initializer
27 // is a code or string initializer. Returns the empty StringRef otherwise.
28 static StringRef getValueAsString(const Init *init) {
29   if (const auto *str = dyn_cast<StringInit>(init))
30     return str->getValue().trim();
31   return {};
32 }
33 
34 bool AttrConstraint::isSubClassOf(StringRef className) const {
35   return def->isSubClassOf(className);
36 }
37 
38 Attribute::Attribute(const Record *record) : AttrConstraint(record) {
39   assert(record->isSubClassOf("Attr") &&
40          "must be subclass of TableGen 'Attr' class");
41 }
42 
43 Attribute::Attribute(const DefInit *init) : Attribute(init->getDef()) {}
44 
45 bool Attribute::isDerivedAttr() const { return isSubClassOf("DerivedAttr"); }
46 
47 bool Attribute::isTypeAttr() const { return isSubClassOf("TypeAttrBase"); }
48 
49 bool Attribute::isSymbolRefAttr() const {
50   StringRef defName = def->getName();
51   if (defName == "SymbolRefAttr" || defName == "FlatSymbolRefAttr")
52     return true;
53   return isSubClassOf("SymbolRefAttr") || isSubClassOf("FlatSymbolRefAttr");
54 }
55 
56 bool Attribute::isEnumAttr() const { return isSubClassOf("EnumAttrInfo"); }
57 
58 StringRef Attribute::getStorageType() const {
59   const auto *init = def->getValueInit("storageType");
60   auto type = getValueAsString(init);
61   if (type.empty())
62     return "::mlir::Attribute";
63   return type;
64 }
65 
66 StringRef Attribute::getReturnType() const {
67   const auto *init = def->getValueInit("returnType");
68   return getValueAsString(init);
69 }
70 
71 // Return the type constraint corresponding to the type of this attribute, or
72 // None if this is not a TypedAttr.
73 llvm::Optional<Type> Attribute::getValueType() const {
74   if (auto *defInit = dyn_cast<llvm::DefInit>(def->getValueInit("valueType")))
75     return Type(defInit->getDef());
76   return llvm::None;
77 }
78 
79 StringRef Attribute::getConvertFromStorageCall() const {
80   const auto *init = def->getValueInit("convertFromStorage");
81   return getValueAsString(init);
82 }
83 
84 bool Attribute::isConstBuildable() const {
85   const auto *init = def->getValueInit("constBuilderCall");
86   return !getValueAsString(init).empty();
87 }
88 
89 StringRef Attribute::getConstBuilderTemplate() const {
90   const auto *init = def->getValueInit("constBuilderCall");
91   return getValueAsString(init);
92 }
93 
94 Attribute Attribute::getBaseAttr() const {
95   if (const auto *defInit =
96           llvm::dyn_cast<llvm::DefInit>(def->getValueInit("baseAttr"))) {
97     return Attribute(defInit).getBaseAttr();
98   }
99   return *this;
100 }
101 
102 bool Attribute::hasDefaultValue() const {
103   const auto *init = def->getValueInit("defaultValue");
104   return !getValueAsString(init).empty();
105 }
106 
107 StringRef Attribute::getDefaultValue() const {
108   const auto *init = def->getValueInit("defaultValue");
109   return getValueAsString(init);
110 }
111 
112 bool Attribute::isOptional() const { return def->getValueAsBit("isOptional"); }
113 
114 StringRef Attribute::getAttrDefName() const {
115   if (def->isAnonymous()) {
116     return getBaseAttr().def->getName();
117   }
118   return def->getName();
119 }
120 
121 StringRef Attribute::getDerivedCodeBody() const {
122   assert(isDerivedAttr() && "only derived attribute has 'body' field");
123   return def->getValueAsString("body");
124 }
125 
126 Dialect Attribute::getDialect() const {
127   const llvm::RecordVal *record = def->getValue("dialect");
128   if (record && record->getValue()) {
129     if (DefInit *init = dyn_cast<DefInit>(record->getValue()))
130       return Dialect(init->getDef());
131   }
132   return Dialect(nullptr);
133 }
134 
135 StringRef Attribute::getDescription() const {
136   return def->getValueAsString("description");
137 }
138 
139 ConstantAttr::ConstantAttr(const DefInit *init) : def(init->getDef()) {
140   assert(def->isSubClassOf("ConstantAttr") &&
141          "must be subclass of TableGen 'ConstantAttr' class");
142 }
143 
144 Attribute ConstantAttr::getAttribute() const {
145   return Attribute(def->getValueAsDef("attr"));
146 }
147 
148 StringRef ConstantAttr::getConstantValue() const {
149   return def->getValueAsString("value");
150 }
151 
152 EnumAttrCase::EnumAttrCase(const llvm::Record *record) : Attribute(record) {
153   assert(isSubClassOf("EnumAttrCaseInfo") &&
154          "must be subclass of TableGen 'EnumAttrInfo' class");
155 }
156 
157 EnumAttrCase::EnumAttrCase(const llvm::DefInit *init)
158     : EnumAttrCase(init->getDef()) {}
159 
160 bool EnumAttrCase::isStrCase() const { return isSubClassOf("StrEnumAttrCase"); }
161 
162 StringRef EnumAttrCase::getSymbol() const {
163   return def->getValueAsString("symbol");
164 }
165 
166 StringRef EnumAttrCase::getStr() const { return def->getValueAsString("str"); }
167 
168 int64_t EnumAttrCase::getValue() const { return def->getValueAsInt("value"); }
169 
170 const llvm::Record &EnumAttrCase::getDef() const { return *def; }
171 
172 EnumAttr::EnumAttr(const llvm::Record *record) : Attribute(record) {
173   assert(isSubClassOf("EnumAttrInfo") &&
174          "must be subclass of TableGen 'EnumAttr' class");
175 }
176 
177 EnumAttr::EnumAttr(const llvm::Record &record) : Attribute(&record) {}
178 
179 EnumAttr::EnumAttr(const llvm::DefInit *init) : EnumAttr(init->getDef()) {}
180 
181 bool EnumAttr::classof(const Attribute *attr) {
182   return attr->isSubClassOf("EnumAttrInfo");
183 }
184 
185 bool EnumAttr::isBitEnum() const { return isSubClassOf("BitEnumAttr"); }
186 
187 StringRef EnumAttr::getEnumClassName() const {
188   return def->getValueAsString("className");
189 }
190 
191 StringRef EnumAttr::getCppNamespace() const {
192   return def->getValueAsString("cppNamespace");
193 }
194 
195 StringRef EnumAttr::getUnderlyingType() const {
196   return def->getValueAsString("underlyingType");
197 }
198 
199 StringRef EnumAttr::getUnderlyingToSymbolFnName() const {
200   return def->getValueAsString("underlyingToSymbolFnName");
201 }
202 
203 StringRef EnumAttr::getStringToSymbolFnName() const {
204   return def->getValueAsString("stringToSymbolFnName");
205 }
206 
207 StringRef EnumAttr::getSymbolToStringFnName() const {
208   return def->getValueAsString("symbolToStringFnName");
209 }
210 
211 StringRef EnumAttr::getSymbolToStringFnRetType() const {
212   return def->getValueAsString("symbolToStringFnRetType");
213 }
214 
215 StringRef EnumAttr::getMaxEnumValFnName() const {
216   return def->getValueAsString("maxEnumValFnName");
217 }
218 
219 std::vector<EnumAttrCase> EnumAttr::getAllCases() const {
220   const auto *inits = def->getValueAsListInit("enumerants");
221 
222   std::vector<EnumAttrCase> cases;
223   cases.reserve(inits->size());
224 
225   for (const llvm::Init *init : *inits) {
226     cases.emplace_back(cast<llvm::DefInit>(init));
227   }
228 
229   return cases;
230 }
231 
232 bool EnumAttr::genSpecializedAttr() const {
233   return def->getValueAsBit("genSpecializedAttr");
234 }
235 
236 llvm::Record *EnumAttr::getBaseAttrClass() const {
237   return def->getValueAsDef("baseAttrClass");
238 }
239 
240 StringRef EnumAttr::getSpecializedAttrClassName() const {
241   return def->getValueAsString("specializedAttrClassName");
242 }
243 
244 StructFieldAttr::StructFieldAttr(const llvm::Record *record) : def(record) {
245   assert(def->isSubClassOf("StructFieldAttr") &&
246          "must be subclass of TableGen 'StructFieldAttr' class");
247 }
248 
249 StructFieldAttr::StructFieldAttr(const llvm::Record &record)
250     : StructFieldAttr(&record) {}
251 
252 StructFieldAttr::StructFieldAttr(const llvm::DefInit *init)
253     : StructFieldAttr(init->getDef()) {}
254 
255 StringRef StructFieldAttr::getName() const {
256   return def->getValueAsString("name");
257 }
258 
259 Attribute StructFieldAttr::getType() const {
260   auto *init = def->getValueInit("type");
261   return Attribute(cast<llvm::DefInit>(init));
262 }
263 
264 StructAttr::StructAttr(const llvm::Record *record) : Attribute(record) {
265   assert(isSubClassOf("StructAttr") &&
266          "must be subclass of TableGen 'StructAttr' class");
267 }
268 
269 StructAttr::StructAttr(const llvm::DefInit *init)
270     : StructAttr(init->getDef()) {}
271 
272 StringRef StructAttr::getStructClassName() const {
273   return def->getValueAsString("className");
274 }
275 
276 StringRef StructAttr::getCppNamespace() const {
277   Dialect dialect(def->getValueAsDef("dialect"));
278   return dialect.getCppNamespace();
279 }
280 
281 std::vector<StructFieldAttr> StructAttr::getAllFields() const {
282   std::vector<StructFieldAttr> attributes;
283 
284   const auto *inits = def->getValueAsListInit("fields");
285   attributes.reserve(inits->size());
286 
287   for (const llvm::Init *init : *inits) {
288     attributes.emplace_back(cast<llvm::DefInit>(init));
289   }
290 
291   return attributes;
292 }
293 
294 const char * ::mlir::tblgen::inferTypeOpInterface = "InferTypeOpInterface";
295