1 //===- Class.cpp - Helper classes for Op C++ code emission --------------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 
9 #include "mlir/TableGen/Class.h"
10 
11 #include "mlir/TableGen/Format.h"
12 #include "llvm/ADT/Sequence.h"
13 #include "llvm/ADT/Twine.h"
14 #include "llvm/Support/Debug.h"
15 #include "llvm/Support/raw_ostream.h"
16 #include <unordered_set>
17 
18 #define DEBUG_TYPE "mlir-tblgen-opclass"
19 
20 using namespace mlir;
21 using namespace mlir::tblgen;
22 
23 // Returns space to be emitted after the given C++ `type`. return "" if the
24 // ends with '&' or '*', or is empty, else returns " ".
25 static StringRef getSpaceAfterType(StringRef type) {
26   return (type.empty() || type.endswith("&") || type.endswith("*")) ? "" : " ";
27 }
28 
29 //===----------------------------------------------------------------------===//
30 // MethodParameter definitions
31 //===----------------------------------------------------------------------===//
32 
33 void MethodParameter::writeTo(raw_ostream &os, bool emitDefault) const {
34   if (optional)
35     os << "/*optional*/";
36   os << type << getSpaceAfterType(type) << name;
37   if (emitDefault && hasDefaultValue())
38     os << " = " << defaultValue;
39 }
40 
41 //===----------------------------------------------------------------------===//
42 // MethodParameters definitions
43 //===----------------------------------------------------------------------===//
44 
45 void MethodParameters::writeDeclTo(raw_ostream &os) const {
46   llvm::interleaveComma(parameters, os,
47                         [&os](auto &param) { param.writeDeclTo(os); });
48 }
49 void MethodParameters::writeDefTo(raw_ostream &os) const {
50   llvm::interleaveComma(parameters, os,
51                         [&os](auto &param) { param.writeDefTo(os); });
52 }
53 
54 bool MethodParameters::subsumes(const MethodParameters &other) const {
55   // These parameters do not subsume the others if there are fewer parameters
56   // or their types do not match.
57   if (parameters.size() < other.parameters.size())
58     return false;
59   if (!std::equal(
60           other.parameters.begin(), other.parameters.end(), parameters.begin(),
61           [](auto &lhs, auto &rhs) { return lhs.getType() == rhs.getType(); }))
62     return false;
63 
64   // If all the common parameters have the same type, we can elide the other
65   // method if this method has the same number of parameters as other or if the
66   // first paramater after the common parameters has a default value (and, as
67   // required by C++, subsequent parameters will have default values too).
68   return parameters.size() == other.parameters.size() ||
69          parameters[other.parameters.size()].hasDefaultValue();
70 }
71 
72 //===----------------------------------------------------------------------===//
73 // MethodSignature definitions
74 //===----------------------------------------------------------------------===//
75 
76 bool MethodSignature::makesRedundant(const MethodSignature &other) const {
77   return methodName == other.methodName &&
78          parameters.subsumes(other.parameters);
79 }
80 
81 void MethodSignature::writeDeclTo(raw_ostream &os) const {
82   os << returnType << getSpaceAfterType(returnType) << methodName << "(";
83   parameters.writeDeclTo(os);
84   os << ")";
85 }
86 
87 void MethodSignature::writeDefTo(raw_ostream &os, StringRef namePrefix) const {
88   os << returnType << getSpaceAfterType(returnType) << namePrefix
89      << (namePrefix.empty() ? "" : "::") << methodName << "(";
90   parameters.writeDefTo(os);
91   os << ")";
92 }
93 
94 //===----------------------------------------------------------------------===//
95 // MethodBody definitions
96 //===----------------------------------------------------------------------===//
97 
98 MethodBody::MethodBody(bool declOnly) : isEffective(!declOnly) {}
99 
100 MethodBody &MethodBody::operator<<(Twine content) {
101   if (isEffective)
102     body.append(content.str());
103   return *this;
104 }
105 
106 MethodBody &MethodBody::operator<<(int content) {
107   if (isEffective)
108     body.append(std::to_string(content));
109   return *this;
110 }
111 
112 MethodBody &MethodBody::operator<<(const FmtObjectBase &content) {
113   if (isEffective)
114     body.append(content.str());
115   return *this;
116 }
117 
118 void MethodBody::writeTo(raw_ostream &os) const {
119   auto bodyRef = StringRef(body).drop_while([](char c) { return c == '\n'; });
120   os << bodyRef;
121   if (bodyRef.empty() || bodyRef.back() != '\n')
122     os << "\n";
123 }
124 
125 //===----------------------------------------------------------------------===//
126 // Method definitions
127 //===----------------------------------------------------------------------===//
128 
129 void Method::writeDeclTo(raw_ostream &os) const {
130   os.indent(2);
131   if (isStatic())
132     os << "static ";
133   if ((properties & MP_Constexpr) == MP_Constexpr)
134     os << "constexpr ";
135   methodSignature.writeDeclTo(os);
136   if (!isInline()) {
137     os << ";";
138   } else {
139     os << " {\n";
140     methodBody.writeTo(os.indent(2));
141     os.indent(2) << "}";
142   }
143 }
144 
145 void Method::writeDefTo(raw_ostream &os, StringRef namePrefix) const {
146   // Do not write definition if the method is decl only.
147   if (properties & MP_Declaration)
148     return;
149   // Do not generate separate definition for inline method
150   if (isInline())
151     return;
152   methodSignature.writeDefTo(os, namePrefix);
153   os << " {\n";
154   methodBody.writeTo(os);
155   os << "}";
156 }
157 
158 //===----------------------------------------------------------------------===//
159 // Constructor definitions
160 //===----------------------------------------------------------------------===//
161 
162 void Constructor::addMemberInitializer(StringRef name, StringRef value) {
163   memberInitializers.append(std::string(llvm::formatv(
164       "{0}{1}({2})", memberInitializers.empty() ? " : " : ", ", name, value)));
165 }
166 
167 void Constructor::writeDefTo(raw_ostream &os, StringRef namePrefix) const {
168   // Do not write definition if the method is decl only.
169   if (properties & MP_Declaration)
170     return;
171 
172   methodSignature.writeDefTo(os, namePrefix);
173   os << " " << memberInitializers << " {\n";
174   methodBody.writeTo(os);
175   os << "}\n";
176 }
177 
178 //===----------------------------------------------------------------------===//
179 // Class definitions
180 //===----------------------------------------------------------------------===//
181 
182 Class::Class(StringRef name) : className(name) {}
183 
184 void Class::newField(StringRef type, StringRef name, StringRef defaultValue) {
185   std::string varName = formatv("{0} {1}", type, name).str();
186   std::string field = defaultValue.empty()
187                           ? varName
188                           : formatv("{0} = {1}", varName, defaultValue).str();
189   fields.push_back(std::move(field));
190 }
191 
192 void Class::writeDeclTo(raw_ostream &os) const {
193   bool hasPrivateMethod = false;
194   os << "class " << className << " {\n";
195   os << "public:\n";
196 
197   forAllMethods([&](const Method &method) {
198     if (!method.isPrivate()) {
199       method.writeDeclTo(os);
200       os << '\n';
201     } else {
202       hasPrivateMethod = true;
203     }
204   });
205 
206   os << '\n';
207   os << "private:\n";
208   if (hasPrivateMethod) {
209     forAllMethods([&](const Method &method) {
210       if (method.isPrivate()) {
211         method.writeDeclTo(os);
212         os << '\n';
213       }
214     });
215     os << '\n';
216   }
217 
218   for (const auto &field : fields)
219     os.indent(2) << field << ";\n";
220   os << "};\n";
221 }
222 
223 void Class::writeDefTo(raw_ostream &os) const {
224   forAllMethods([&](const Method &method) {
225     method.writeDefTo(os, className);
226     os << "\n";
227   });
228 }
229 
230 // Insert a new method into a list of methods, if it would not be pruned, and
231 // prune and existing methods.
232 template <typename ContainerT, typename MethodT>
233 MethodT *insertAndPrune(ContainerT &methods, MethodT newMethod) {
234   if (llvm::any_of(methods, [&](auto &method) {
235         return method.makesRedundant(newMethod);
236       }))
237     return nullptr;
238 
239   llvm::erase_if(
240       methods, [&](auto &method) { return newMethod.makesRedundant(method); });
241   methods.push_back(std::move(newMethod));
242   return &methods.back();
243 }
244 
245 Method *Class::addMethodAndPrune(Method &&newMethod) {
246   return insertAndPrune(methods, std::move(newMethod));
247 }
248 
249 Constructor *Class::addConstructorAndPrune(Constructor &&newCtor) {
250   return insertAndPrune(constructors, std::move(newCtor));
251 }
252 
253 //===----------------------------------------------------------------------===//
254 // OpClass definitions
255 //===----------------------------------------------------------------------===//
256 
257 OpClass::OpClass(StringRef name, StringRef extraClassDeclaration)
258     : Class(name), extraClassDeclaration(extraClassDeclaration) {}
259 
260 void OpClass::addTrait(Twine trait) { traits.insert(trait.str()); }
261 
262 void OpClass::writeDeclTo(raw_ostream &os) const {
263   os << "class " << className << " : public ::mlir::Op<" << className;
264   for (const auto &trait : traits)
265     os << ", " << trait;
266   os << "> {\npublic:\n"
267      << "  using Op::Op;\n"
268      << "  using Op::print;\n"
269      << "  using Adaptor = " << className << "Adaptor;\n";
270 
271   bool hasPrivateMethod = false;
272   forAllMethods([&](const Method &method) {
273     if (!method.isPrivate()) {
274       method.writeDeclTo(os);
275       os << "\n";
276     } else {
277       hasPrivateMethod = true;
278     }
279   });
280 
281   // TODO: Add line control markers to make errors easier to debug.
282   if (!extraClassDeclaration.empty())
283     os << extraClassDeclaration << "\n";
284 
285   if (hasPrivateMethod) {
286     os << "\nprivate:\n";
287     forAllMethods([&](const Method &method) {
288       if (method.isPrivate()) {
289         method.writeDeclTo(os);
290         os << "\n";
291       }
292     });
293   }
294 
295   os << "};\n";
296 }
297