1 //===- Operator.cpp - Operator class --------------------------------------===//
2 //
3 // Copyright 2019 The MLIR Authors.
4 //
5 // Licensed under the Apache License, Version 2.0 (the "License");
6 // you may not use this file except in compliance with the License.
7 // You may obtain a copy of the License at
8 //
9 //   http://www.apache.org/licenses/LICENSE-2.0
10 //
11 // Unless required by applicable law or agreed to in writing, software
12 // distributed under the License is distributed on an "AS IS" BASIS,
13 // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14 // See the License for the specific language governing permissions and
15 // limitations under the License.
16 // =============================================================================
17 //
18 // Operator wrapper to simplify using TableGen Record defining a MLIR Op.
19 //
20 //===----------------------------------------------------------------------===//
21 
22 #include "mlir/TableGen/Operator.h"
23 #include "mlir/TableGen/OpTrait.h"
24 #include "mlir/TableGen/Predicate.h"
25 #include "mlir/TableGen/Type.h"
26 #include "llvm/Support/FormatVariadic.h"
27 #include "llvm/TableGen/Error.h"
28 #include "llvm/TableGen/Record.h"
29 
30 using namespace mlir;
31 
32 using llvm::DagInit;
33 using llvm::DefInit;
34 using llvm::Record;
35 
36 tblgen::Operator::Operator(const llvm::Record &def)
37     : dialect(def.getValueAsDef("opDialect")), def(def) {
38   // The first `_` in the op's TableGen def name is treated as separating the
39   // dialect prefix and the op class name. The dialect prefix will be ignored if
40   // not empty. Otherwise, if def name starts with a `_`, the `_` is considered
41   // as part of the class name.
42   StringRef prefix;
43   std::tie(prefix, cppClassName) = def.getName().split('_');
44   if (prefix.empty()) {
45     // Class name with a leading underscore and without dialect prefix
46     cppClassName = def.getName();
47   } else if (cppClassName.empty()) {
48     // Class name without dialect prefix
49     cppClassName = prefix;
50   }
51 
52   populateOpStructure();
53 }
54 
55 std::string tblgen::Operator::getOperationName() const {
56   auto prefix = dialect.getName();
57   auto opName = def.getValueAsString("opName");
58   if (prefix.empty())
59     return opName;
60   return llvm::formatv("{0}.{1}", prefix, opName);
61 }
62 
63 StringRef tblgen::Operator::getDialectName() const { return dialect.getName(); }
64 
65 StringRef tblgen::Operator::getCppClassName() const { return cppClassName; }
66 
67 std::string tblgen::Operator::getQualCppClassName() const {
68   auto prefix = dialect.getCppNamespace();
69   if (prefix.empty())
70     return cppClassName;
71   return llvm::formatv("{0}::{1}", prefix, cppClassName);
72 }
73 
74 int tblgen::Operator::getNumResults() const {
75   DagInit *results = def.getValueAsDag("results");
76   return results->getNumArgs();
77 }
78 
79 StringRef tblgen::Operator::getExtraClassDeclaration() const {
80   constexpr auto attr = "extraClassDeclaration";
81   if (def.isValueUnset(attr))
82     return {};
83   return def.getValueAsString(attr);
84 }
85 
86 const llvm::Record &tblgen::Operator::getDef() const { return def; }
87 
88 bool tblgen::Operator::isVariadic() const {
89   return getNumVariadicOperands() != 0 || getNumVariadicResults() != 0;
90 }
91 
92 bool tblgen::Operator::skipDefaultBuilders() const {
93   return def.getValueAsBit("skipDefaultBuilders");
94 }
95 
96 auto tblgen::Operator::result_begin() -> value_iterator {
97   return results.begin();
98 }
99 
100 auto tblgen::Operator::result_end() -> value_iterator { return results.end(); }
101 
102 auto tblgen::Operator::getResults() -> value_range {
103   return {result_begin(), result_end()};
104 }
105 
106 tblgen::TypeConstraint
107 tblgen::Operator::getResultTypeConstraint(int index) const {
108   DagInit *results = def.getValueAsDag("results");
109   return TypeConstraint(cast<DefInit>(results->getArg(index)));
110 }
111 
112 StringRef tblgen::Operator::getResultName(int index) const {
113   DagInit *results = def.getValueAsDag("results");
114   return results->getArgNameStr(index);
115 }
116 
117 unsigned tblgen::Operator::getNumVariadicResults() const {
118   return std::count_if(
119       results.begin(), results.end(),
120       [](const NamedTypeConstraint &c) { return c.constraint.isVariadic(); });
121 }
122 
123 unsigned tblgen::Operator::getNumVariadicOperands() const {
124   return std::count_if(
125       operands.begin(), operands.end(),
126       [](const NamedTypeConstraint &c) { return c.constraint.isVariadic(); });
127 }
128 
129 StringRef tblgen::Operator::getArgName(int index) const {
130   DagInit *argumentValues = def.getValueAsDag("arguments");
131   return argumentValues->getArgName(index)->getValue();
132 }
133 
134 bool tblgen::Operator::hasTrait(StringRef trait) const {
135   for (auto t : getTraits()) {
136     if (auto opTrait = dyn_cast<tblgen::NativeOpTrait>(&t)) {
137       if (opTrait->getTrait() == trait)
138         return true;
139     } else if (auto opTrait = dyn_cast<tblgen::InternalOpTrait>(&t)) {
140       if (opTrait->getTrait() == trait)
141         return true;
142     }
143   }
144   return false;
145 }
146 
147 unsigned tblgen::Operator::getNumRegions() const { return regions.size(); }
148 
149 const tblgen::NamedRegion &tblgen::Operator::getRegion(unsigned index) const {
150   return regions[index];
151 }
152 
153 auto tblgen::Operator::trait_begin() const -> const_trait_iterator {
154   return traits.begin();
155 }
156 auto tblgen::Operator::trait_end() const -> const_trait_iterator {
157   return traits.end();
158 }
159 auto tblgen::Operator::getTraits() const
160     -> llvm::iterator_range<const_trait_iterator> {
161   return {trait_begin(), trait_end()};
162 }
163 
164 auto tblgen::Operator::attribute_begin() const -> attribute_iterator {
165   return attributes.begin();
166 }
167 auto tblgen::Operator::attribute_end() const -> attribute_iterator {
168   return attributes.end();
169 }
170 auto tblgen::Operator::getAttributes() const
171     -> llvm::iterator_range<attribute_iterator> {
172   return {attribute_begin(), attribute_end()};
173 }
174 
175 auto tblgen::Operator::operand_begin() -> value_iterator {
176   return operands.begin();
177 }
178 auto tblgen::Operator::operand_end() -> value_iterator {
179   return operands.end();
180 }
181 auto tblgen::Operator::getOperands() -> value_range {
182   return {operand_begin(), operand_end()};
183 }
184 
185 auto tblgen::Operator::getArg(int index) const -> Argument {
186   return arguments[index];
187 }
188 
189 void tblgen::Operator::populateOpStructure() {
190   auto &recordKeeper = def.getRecords();
191   auto typeConstraintClass = recordKeeper.getClass("TypeConstraint");
192   auto attrClass = recordKeeper.getClass("Attr");
193   auto derivedAttrClass = recordKeeper.getClass("DerivedAttr");
194   numNativeAttributes = 0;
195 
196   // The argument ordering is operands, native attributes, derived
197   // attributes.
198   DagInit *argumentValues = def.getValueAsDag("arguments");
199   unsigned i = 0;
200   // Handle operands and native attributes.
201   for (unsigned e = argumentValues->getNumArgs(); i != e; ++i) {
202     auto arg = argumentValues->getArg(i);
203     auto givenName = argumentValues->getArgNameStr(i);
204     auto argDefInit = dyn_cast<DefInit>(arg);
205     if (!argDefInit)
206       PrintFatalError(def.getLoc(),
207                       Twine("undefined type for argument #") + Twine(i));
208     Record *argDef = argDefInit->getDef();
209 
210     if (argDef->isSubClassOf(typeConstraintClass)) {
211       operands.push_back(
212           NamedTypeConstraint{givenName, TypeConstraint(argDefInit)});
213       arguments.emplace_back(&operands.back());
214     } else if (argDef->isSubClassOf(attrClass)) {
215       if (givenName.empty())
216         PrintFatalError(argDef->getLoc(), "attributes must be named");
217       if (argDef->isSubClassOf(derivedAttrClass))
218         PrintFatalError(argDef->getLoc(),
219                         "derived attributes not allowed in argument list");
220       attributes.push_back({givenName, Attribute(argDef)});
221       arguments.emplace_back(&attributes.back());
222       ++numNativeAttributes;
223     } else {
224       PrintFatalError(def.getLoc(), "unexpected def type; only defs deriving "
225                                     "from TypeConstraint or Attr are allowed");
226     }
227   }
228 
229   // Handle derived attributes.
230   for (const auto &val : def.getValues()) {
231     if (auto *record = dyn_cast<llvm::RecordRecTy>(val.getType())) {
232       if (!record->isSubClassOf(attrClass))
233         continue;
234       if (!record->isSubClassOf(derivedAttrClass))
235         PrintFatalError(def.getLoc(),
236                         "unexpected Attr where only DerivedAttr is allowed");
237 
238       if (record->getClasses().size() != 1) {
239         PrintFatalError(
240             def.getLoc(),
241             "unsupported attribute modelling, only single class expected");
242       }
243       attributes.push_back(
244           {cast<llvm::StringInit>(val.getNameInit())->getValue(),
245            Attribute(cast<DefInit>(val.getValue()))});
246     }
247   }
248 
249   auto *resultsDag = def.getValueAsDag("results");
250   auto *outsOp = dyn_cast<DefInit>(resultsDag->getOperator());
251   if (!outsOp || outsOp->getDef()->getName() != "outs") {
252     PrintFatalError(def.getLoc(), "'results' must have 'outs' directive");
253   }
254 
255   // Handle results.
256   for (unsigned i = 0, e = resultsDag->getNumArgs(); i < e; ++i) {
257     auto name = resultsDag->getArgNameStr(i);
258     auto *resultDef = dyn_cast<DefInit>(resultsDag->getArg(i));
259     if (!resultDef) {
260       PrintFatalError(def.getLoc(),
261                       Twine("undefined type for result #") + Twine(i));
262     }
263     results.push_back({name, TypeConstraint(resultDef)});
264   }
265 
266   auto traitListInit = def.getValueAsListInit("traits");
267   if (!traitListInit)
268     return;
269   traits.reserve(traitListInit->size());
270   for (auto traitInit : *traitListInit)
271     traits.push_back(OpTrait::create(traitInit));
272 
273   // Handle regions
274   auto *regionsDag = def.getValueAsDag("regions");
275   auto *regionsOp = dyn_cast<DefInit>(regionsDag->getOperator());
276   if (!regionsOp || regionsOp->getDef()->getName() != "region") {
277     PrintFatalError(def.getLoc(), "'regions' must have 'region' directive");
278   }
279 
280   for (unsigned i = 0, e = regionsDag->getNumArgs(); i < e; ++i) {
281     auto name = regionsDag->getArgNameStr(i);
282     auto *regionInit = dyn_cast<DefInit>(regionsDag->getArg(i));
283     if (!regionInit) {
284       PrintFatalError(def.getLoc(),
285                       Twine("undefined kind for region #") + Twine(i));
286     }
287     regions.push_back({name, Region(regionInit->getDef())});
288   }
289 }
290 
291 ArrayRef<llvm::SMLoc> tblgen::Operator::getLoc() const { return def.getLoc(); }
292 
293 bool tblgen::Operator::hasDescription() const {
294   return def.getValue("description") != nullptr;
295 }
296 
297 StringRef tblgen::Operator::getDescription() const {
298   return def.getValueAsString("description");
299 }
300 
301 bool tblgen::Operator::hasSummary() const {
302   return def.getValue("summary") != nullptr;
303 }
304 
305 StringRef tblgen::Operator::getSummary() const {
306   return def.getValueAsString("summary");
307 }
308