1 //===- Attribute.cpp - Attribute wrapper class ----------------------------===//
2 //
3 // Copyright 2019 The MLIR Authors.
4 //
5 // Licensed under the Apache License, Version 2.0 (the "License");
6 // you may not use this file except in compliance with the License.
7 // You may obtain a copy of the License at
8 //
9 //   http://www.apache.org/licenses/LICENSE-2.0
10 //
11 // Unless required by applicable law or agreed to in writing, software
12 // distributed under the License is distributed on an "AS IS" BASIS,
13 // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14 // See the License for the specific language governing permissions and
15 // limitations under the License.
16 // =============================================================================
17 //
18 // Attribute wrapper to simplify using TableGen Record defining a MLIR
19 // Attribute.
20 //
21 //===----------------------------------------------------------------------===//
22 
23 #include "mlir/TableGen/Format.h"
24 #include "mlir/TableGen/Operator.h"
25 #include "llvm/TableGen/Record.h"
26 
27 using namespace mlir;
28 
29 using llvm::CodeInit;
30 using llvm::DefInit;
31 using llvm::Init;
32 using llvm::Record;
33 using llvm::StringInit;
34 
35 // Returns the initializer's value as string if the given TableGen initializer
36 // is a code or string initializer. Returns the empty StringRef otherwise.
37 static StringRef getValueAsString(const Init *init) {
38   if (const auto *code = dyn_cast<CodeInit>(init))
39     return code->getValue().trim();
40   else if (const auto *str = dyn_cast<StringInit>(init))
41     return str->getValue().trim();
42   return {};
43 }
44 
45 tblgen::AttrConstraint::AttrConstraint(const Record *record)
46     : Constraint(Constraint::CK_Attr, record) {
47   assert(def->isSubClassOf("AttrConstraint") &&
48          "must be subclass of TableGen 'AttrConstraint' class");
49 }
50 
51 tblgen::Attribute::Attribute(const Record *record) : AttrConstraint(record) {
52   assert(record->isSubClassOf("Attr") &&
53          "must be subclass of TableGen 'Attr' class");
54 }
55 
56 tblgen::Attribute::Attribute(const DefInit *init) : Attribute(init->getDef()) {}
57 
58 bool tblgen::Attribute::isDerivedAttr() const {
59   return def->isSubClassOf("DerivedAttr");
60 }
61 
62 bool tblgen::Attribute::isTypeAttr() const {
63   return def->isSubClassOf("TypeAttrBase");
64 }
65 
66 bool tblgen::Attribute::isEnumAttr() const {
67   return def->isSubClassOf("EnumAttrInfo");
68 }
69 
70 StringRef tblgen::Attribute::getStorageType() const {
71   const auto *init = def->getValueInit("storageType");
72   auto type = getValueAsString(init);
73   if (type.empty())
74     return "Attribute";
75   return type;
76 }
77 
78 StringRef tblgen::Attribute::getReturnType() const {
79   const auto *init = def->getValueInit("returnType");
80   return getValueAsString(init);
81 }
82 
83 StringRef tblgen::Attribute::getConvertFromStorageCall() const {
84   const auto *init = def->getValueInit("convertFromStorage");
85   return getValueAsString(init);
86 }
87 
88 bool tblgen::Attribute::isConstBuildable() const {
89   const auto *init = def->getValueInit("constBuilderCall");
90   return !getValueAsString(init).empty();
91 }
92 
93 StringRef tblgen::Attribute::getConstBuilderTemplate() const {
94   const auto *init = def->getValueInit("constBuilderCall");
95   return getValueAsString(init);
96 }
97 
98 tblgen::Attribute tblgen::Attribute::getBaseAttr() const {
99   if (const auto *defInit =
100           llvm::dyn_cast<llvm::DefInit>(def->getValueInit("baseAttr"))) {
101     return Attribute(defInit).getBaseAttr();
102   }
103   return *this;
104 }
105 
106 bool tblgen::Attribute::hasDefaultValueInitializer() const {
107   const auto *init = def->getValueInit("defaultValue");
108   return !getValueAsString(init).empty();
109 }
110 
111 StringRef tblgen::Attribute::getDefaultValueInitializer() const {
112   const auto *init = def->getValueInit("defaultValue");
113   return getValueAsString(init);
114 }
115 
116 bool tblgen::Attribute::isOptional() const {
117   return def->getValueAsBit("isOptional");
118 }
119 
120 StringRef tblgen::Attribute::getAttrDefName() const {
121   if (def->isAnonymous()) {
122     return getBaseAttr().def->getName();
123   }
124   return def->getName();
125 }
126 
127 StringRef tblgen::Attribute::getDerivedCodeBody() const {
128   assert(isDerivedAttr() && "only derived attribute has 'body' field");
129   return def->getValueAsString("body");
130 }
131 
132 tblgen::ConstantAttr::ConstantAttr(const DefInit *init) : def(init->getDef()) {
133   assert(def->isSubClassOf("ConstantAttr") &&
134          "must be subclass of TableGen 'ConstantAttr' class");
135 }
136 
137 tblgen::Attribute tblgen::ConstantAttr::getAttribute() const {
138   return Attribute(def->getValueAsDef("attr"));
139 }
140 
141 StringRef tblgen::ConstantAttr::getConstantValue() const {
142   return def->getValueAsString("value");
143 }
144 
145 tblgen::EnumAttrCase::EnumAttrCase(const llvm::DefInit *init)
146     : Attribute(init) {
147   assert(def->isSubClassOf("EnumAttrCaseInfo") &&
148          "must be subclass of TableGen 'EnumAttrInfo' class");
149 }
150 
151 bool tblgen::EnumAttrCase::isStrCase() const {
152   return def->isSubClassOf("StrEnumAttrCase");
153 }
154 
155 StringRef tblgen::EnumAttrCase::getSymbol() const {
156   return def->getValueAsString("symbol");
157 }
158 
159 int64_t tblgen::EnumAttrCase::getValue() const {
160   return def->getValueAsInt("value");
161 }
162 
163 tblgen::EnumAttr::EnumAttr(const llvm::Record *record) : Attribute(record) {
164   assert(def->isSubClassOf("EnumAttrInfo") &&
165          "must be subclass of TableGen 'EnumAttr' class");
166 }
167 
168 tblgen::EnumAttr::EnumAttr(const llvm::Record &record) : Attribute(&record) {}
169 
170 tblgen::EnumAttr::EnumAttr(const llvm::DefInit *init)
171     : EnumAttr(init->getDef()) {}
172 
173 bool tblgen::EnumAttr::skipAutoGen() const {
174   return def->getValueAsBit("skipAutoGen");
175 }
176 
177 StringRef tblgen::EnumAttr::getEnumClassName() const {
178   return def->getValueAsString("className");
179 }
180 
181 StringRef tblgen::EnumAttr::getCppNamespace() const {
182   return def->getValueAsString("cppNamespace");
183 }
184 
185 StringRef tblgen::EnumAttr::getUnderlyingType() const {
186   return def->getValueAsString("underlyingType");
187 }
188 
189 StringRef tblgen::EnumAttr::getUnderlyingToSymbolFnName() const {
190   return def->getValueAsString("underlyingToSymbolFnName");
191 }
192 
193 StringRef tblgen::EnumAttr::getStringToSymbolFnName() const {
194   return def->getValueAsString("stringToSymbolFnName");
195 }
196 
197 StringRef tblgen::EnumAttr::getSymbolToStringFnName() const {
198   return def->getValueAsString("symbolToStringFnName");
199 }
200 
201 StringRef tblgen::EnumAttr::getSymbolToStringFnRetType() const {
202   return def->getValueAsString("symbolToStringFnRetType");
203 }
204 
205 StringRef tblgen::EnumAttr::getMaxEnumValFnName() const {
206   return def->getValueAsString("maxEnumValFnName");
207 }
208 
209 std::vector<tblgen::EnumAttrCase> tblgen::EnumAttr::getAllCases() const {
210   const auto *inits = def->getValueAsListInit("enumerants");
211 
212   std::vector<tblgen::EnumAttrCase> cases;
213   cases.reserve(inits->size());
214 
215   for (const llvm::Init *init : *inits) {
216     cases.push_back(tblgen::EnumAttrCase(cast<llvm::DefInit>(init)));
217   }
218 
219   return cases;
220 }
221 
222 tblgen::StructFieldAttr::StructFieldAttr(const llvm::Record *record)
223     : def(record) {
224   assert(def->isSubClassOf("StructFieldAttr") &&
225          "must be subclass of TableGen 'StructFieldAttr' class");
226 }
227 
228 tblgen::StructFieldAttr::StructFieldAttr(const llvm::Record &record)
229     : StructFieldAttr(&record) {}
230 
231 tblgen::StructFieldAttr::StructFieldAttr(const llvm::DefInit *init)
232     : StructFieldAttr(init->getDef()) {}
233 
234 StringRef tblgen::StructFieldAttr::getName() const {
235   return def->getValueAsString("name");
236 }
237 
238 tblgen::Attribute tblgen::StructFieldAttr::getType() const {
239   auto init = def->getValueInit("type");
240   return tblgen::Attribute(cast<llvm::DefInit>(init));
241 }
242 
243 tblgen::StructAttr::StructAttr(const llvm::Record *record) : Attribute(record) {
244   assert(def->isSubClassOf("StructAttr") &&
245          "must be subclass of TableGen 'StructAttr' class");
246 }
247 
248 tblgen::StructAttr::StructAttr(const llvm::DefInit *init)
249     : StructAttr(init->getDef()) {}
250 
251 StringRef tblgen::StructAttr::getStructClassName() const {
252   return def->getValueAsString("className");
253 }
254 
255 StringRef tblgen::StructAttr::getCppNamespace() const {
256   Dialect dialect(def->getValueAsDef("structDialect"));
257   return dialect.getCppNamespace();
258 }
259 
260 std::vector<mlir::tblgen::StructFieldAttr>
261 tblgen::StructAttr::getAllFields() const {
262   std::vector<mlir::tblgen::StructFieldAttr> attributes;
263 
264   const auto *inits = def->getValueAsListInit("fields");
265   attributes.reserve(inits->size());
266 
267   for (const llvm::Init *init : *inits) {
268     attributes.emplace_back(cast<llvm::DefInit>(init));
269   }
270 
271   return attributes;
272 }
273