1 //===- Attribute.cpp - Attribute wrapper class ----------------------------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 //
9 // Attribute wrapper to simplify using TableGen Record defining a MLIR
10 // Attribute.
11 //
12 //===----------------------------------------------------------------------===//
13 
14 #include "mlir/TableGen/Format.h"
15 #include "mlir/TableGen/Operator.h"
16 #include "llvm/TableGen/Record.h"
17 
18 using namespace mlir;
19 using namespace mlir::tblgen;
20 
21 using llvm::CodeInit;
22 using llvm::DefInit;
23 using llvm::Init;
24 using llvm::Record;
25 using llvm::StringInit;
26 
27 // Returns the initializer's value as string if the given TableGen initializer
28 // is a code or string initializer. Returns the empty StringRef otherwise.
29 static StringRef getValueAsString(const Init *init) {
30   if (const auto *code = dyn_cast<CodeInit>(init))
31     return code->getValue().trim();
32   if (const auto *str = dyn_cast<StringInit>(init))
33     return str->getValue().trim();
34   return {};
35 }
36 
37 AttrConstraint::AttrConstraint(const Record *record)
38     : Constraint(Constraint::CK_Attr, record) {
39   assert(isSubClassOf("AttrConstraint") &&
40          "must be subclass of TableGen 'AttrConstraint' class");
41 }
42 
43 bool AttrConstraint::isSubClassOf(StringRef className) const {
44   return def->isSubClassOf(className);
45 }
46 
47 Attribute::Attribute(const Record *record) : AttrConstraint(record) {
48   assert(record->isSubClassOf("Attr") &&
49          "must be subclass of TableGen 'Attr' class");
50 }
51 
52 Attribute::Attribute(const DefInit *init) : Attribute(init->getDef()) {}
53 
54 bool Attribute::isDerivedAttr() const { return isSubClassOf("DerivedAttr"); }
55 
56 bool Attribute::isTypeAttr() const { return isSubClassOf("TypeAttrBase"); }
57 
58 bool Attribute::isEnumAttr() const { return isSubClassOf("EnumAttrInfo"); }
59 
60 StringRef Attribute::getStorageType() const {
61   const auto *init = def->getValueInit("storageType");
62   auto type = getValueAsString(init);
63   if (type.empty())
64     return "Attribute";
65   return type;
66 }
67 
68 StringRef Attribute::getReturnType() const {
69   const auto *init = def->getValueInit("returnType");
70   return getValueAsString(init);
71 }
72 
73 // Return the type constraint corresponding to the type of this attribute, or
74 // None if this is not a TypedAttr.
75 llvm::Optional<Type> Attribute::getValueType() const {
76   if (auto *defInit = dyn_cast<llvm::DefInit>(def->getValueInit("valueType")))
77     return Type(defInit->getDef());
78   return llvm::None;
79 }
80 
81 StringRef Attribute::getConvertFromStorageCall() const {
82   const auto *init = def->getValueInit("convertFromStorage");
83   return getValueAsString(init);
84 }
85 
86 bool Attribute::isConstBuildable() const {
87   const auto *init = def->getValueInit("constBuilderCall");
88   return !getValueAsString(init).empty();
89 }
90 
91 StringRef Attribute::getConstBuilderTemplate() const {
92   const auto *init = def->getValueInit("constBuilderCall");
93   return getValueAsString(init);
94 }
95 
96 Attribute Attribute::getBaseAttr() const {
97   if (const auto *defInit =
98           llvm::dyn_cast<llvm::DefInit>(def->getValueInit("baseAttr"))) {
99     return Attribute(defInit).getBaseAttr();
100   }
101   return *this;
102 }
103 
104 bool Attribute::hasDefaultValue() const {
105   const auto *init = def->getValueInit("defaultValue");
106   return !getValueAsString(init).empty();
107 }
108 
109 StringRef Attribute::getDefaultValue() const {
110   const auto *init = def->getValueInit("defaultValue");
111   return getValueAsString(init);
112 }
113 
114 bool Attribute::isOptional() const { return def->getValueAsBit("isOptional"); }
115 
116 StringRef Attribute::getAttrDefName() const {
117   if (def->isAnonymous()) {
118     return getBaseAttr().def->getName();
119   }
120   return def->getName();
121 }
122 
123 StringRef Attribute::getDerivedCodeBody() const {
124   assert(isDerivedAttr() && "only derived attribute has 'body' field");
125   return def->getValueAsString("body");
126 }
127 
128 Dialect Attribute::getDialect() const {
129   const llvm::RecordVal *record = def->getValue("dialect");
130   if (record && record->getValue()) {
131     if (DefInit *init = dyn_cast<DefInit>(record->getValue()))
132       return Dialect(init->getDef());
133   }
134   return Dialect(nullptr);
135 }
136 
137 ConstantAttr::ConstantAttr(const DefInit *init) : def(init->getDef()) {
138   assert(def->isSubClassOf("ConstantAttr") &&
139          "must be subclass of TableGen 'ConstantAttr' class");
140 }
141 
142 Attribute ConstantAttr::getAttribute() const {
143   return Attribute(def->getValueAsDef("attr"));
144 }
145 
146 StringRef ConstantAttr::getConstantValue() const {
147   return def->getValueAsString("value");
148 }
149 
150 EnumAttrCase::EnumAttrCase(const llvm::Record *record) : Attribute(record) {
151   assert(isSubClassOf("EnumAttrCaseInfo") &&
152          "must be subclass of TableGen 'EnumAttrInfo' class");
153 }
154 
155 EnumAttrCase::EnumAttrCase(const llvm::DefInit *init)
156     : EnumAttrCase(init->getDef()) {}
157 
158 bool EnumAttrCase::isStrCase() const { return isSubClassOf("StrEnumAttrCase"); }
159 
160 StringRef EnumAttrCase::getSymbol() const {
161   return def->getValueAsString("symbol");
162 }
163 
164 StringRef EnumAttrCase::getStr() const { return def->getValueAsString("str"); }
165 
166 int64_t EnumAttrCase::getValue() const { return def->getValueAsInt("value"); }
167 
168 const llvm::Record &EnumAttrCase::getDef() const { return *def; }
169 
170 EnumAttr::EnumAttr(const llvm::Record *record) : Attribute(record) {
171   assert(isSubClassOf("EnumAttrInfo") &&
172          "must be subclass of TableGen 'EnumAttr' class");
173 }
174 
175 EnumAttr::EnumAttr(const llvm::Record &record) : Attribute(&record) {}
176 
177 EnumAttr::EnumAttr(const llvm::DefInit *init) : EnumAttr(init->getDef()) {}
178 
179 bool EnumAttr::classof(const Attribute *attr) {
180   return attr->isSubClassOf("EnumAttrInfo");
181 }
182 
183 bool EnumAttr::isBitEnum() const { return isSubClassOf("BitEnumAttr"); }
184 
185 StringRef EnumAttr::getEnumClassName() const {
186   return def->getValueAsString("className");
187 }
188 
189 StringRef EnumAttr::getCppNamespace() const {
190   return def->getValueAsString("cppNamespace");
191 }
192 
193 StringRef EnumAttr::getUnderlyingType() const {
194   return def->getValueAsString("underlyingType");
195 }
196 
197 StringRef EnumAttr::getUnderlyingToSymbolFnName() const {
198   return def->getValueAsString("underlyingToSymbolFnName");
199 }
200 
201 StringRef EnumAttr::getStringToSymbolFnName() const {
202   return def->getValueAsString("stringToSymbolFnName");
203 }
204 
205 StringRef EnumAttr::getSymbolToStringFnName() const {
206   return def->getValueAsString("symbolToStringFnName");
207 }
208 
209 StringRef EnumAttr::getSymbolToStringFnRetType() const {
210   return def->getValueAsString("symbolToStringFnRetType");
211 }
212 
213 StringRef EnumAttr::getMaxEnumValFnName() const {
214   return def->getValueAsString("maxEnumValFnName");
215 }
216 
217 std::vector<EnumAttrCase> EnumAttr::getAllCases() const {
218   const auto *inits = def->getValueAsListInit("enumerants");
219 
220   std::vector<EnumAttrCase> cases;
221   cases.reserve(inits->size());
222 
223   for (const llvm::Init *init : *inits) {
224     cases.push_back(EnumAttrCase(cast<llvm::DefInit>(init)));
225   }
226 
227   return cases;
228 }
229 
230 StructFieldAttr::StructFieldAttr(const llvm::Record *record) : def(record) {
231   assert(def->isSubClassOf("StructFieldAttr") &&
232          "must be subclass of TableGen 'StructFieldAttr' class");
233 }
234 
235 StructFieldAttr::StructFieldAttr(const llvm::Record &record)
236     : StructFieldAttr(&record) {}
237 
238 StructFieldAttr::StructFieldAttr(const llvm::DefInit *init)
239     : StructFieldAttr(init->getDef()) {}
240 
241 StringRef StructFieldAttr::getName() const {
242   return def->getValueAsString("name");
243 }
244 
245 Attribute StructFieldAttr::getType() const {
246   auto init = def->getValueInit("type");
247   return Attribute(cast<llvm::DefInit>(init));
248 }
249 
250 StructAttr::StructAttr(const llvm::Record *record) : Attribute(record) {
251   assert(isSubClassOf("StructAttr") &&
252          "must be subclass of TableGen 'StructAttr' class");
253 }
254 
255 StructAttr::StructAttr(const llvm::DefInit *init)
256     : StructAttr(init->getDef()) {}
257 
258 StringRef StructAttr::getStructClassName() const {
259   return def->getValueAsString("className");
260 }
261 
262 StringRef StructAttr::getCppNamespace() const {
263   Dialect dialect(def->getValueAsDef("dialect"));
264   return dialect.getCppNamespace();
265 }
266 
267 std::vector<StructFieldAttr> StructAttr::getAllFields() const {
268   std::vector<StructFieldAttr> attributes;
269 
270   const auto *inits = def->getValueAsListInit("fields");
271   attributes.reserve(inits->size());
272 
273   for (const llvm::Init *init : *inits) {
274     attributes.emplace_back(cast<llvm::DefInit>(init));
275   }
276 
277   return attributes;
278 }
279 
280 const char * ::mlir::tblgen::inferTypeOpInterface = "InferTypeOpInterface";
281