1 //===- Class.cpp - Helper classes for Op C++ code emission --------------===// 2 // 3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. 4 // See https://llvm.org/LICENSE.txt for license information. 5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception 6 // 7 //===----------------------------------------------------------------------===// 8 9 #include "mlir/TableGen/Class.h" 10 11 #include "mlir/TableGen/Format.h" 12 #include "llvm/ADT/Sequence.h" 13 #include "llvm/ADT/Twine.h" 14 #include "llvm/Support/Debug.h" 15 #include "llvm/Support/raw_ostream.h" 16 #include <unordered_set> 17 18 #define DEBUG_TYPE "mlir-tblgen-opclass" 19 20 using namespace mlir; 21 using namespace mlir::tblgen; 22 23 // Returns space to be emitted after the given C++ `type`. return "" if the 24 // ends with '&' or '*', or is empty, else returns " ". 25 static StringRef getSpaceAfterType(StringRef type) { 26 return (type.empty() || type.endswith("&") || type.endswith("*")) ? "" : " "; 27 } 28 29 //===----------------------------------------------------------------------===// 30 // MethodParameter definitions 31 //===----------------------------------------------------------------------===// 32 33 void MethodParameter::writeTo(raw_ostream &os, bool emitDefault) const { 34 if (optional) 35 os << "/*optional*/"; 36 os << type << getSpaceAfterType(type) << name; 37 if (emitDefault && hasDefaultValue()) 38 os << " = " << defaultValue; 39 } 40 41 //===----------------------------------------------------------------------===// 42 // MethodParameters definitions 43 //===----------------------------------------------------------------------===// 44 45 void MethodParameters::writeDeclTo(raw_ostream &os) const { 46 llvm::interleaveComma(parameters, os, 47 [&os](auto ¶m) { param.writeDeclTo(os); }); 48 } 49 void MethodParameters::writeDefTo(raw_ostream &os) const { 50 llvm::interleaveComma(parameters, os, 51 [&os](auto ¶m) { param.writeDefTo(os); }); 52 } 53 54 bool MethodParameters::subsumes(const MethodParameters &other) const { 55 // These parameters do not subsume the others if there are fewer parameters 56 // or their types do not match. 57 if (parameters.size() < other.parameters.size()) 58 return false; 59 if (!std::equal( 60 other.parameters.begin(), other.parameters.end(), parameters.begin(), 61 [](auto &lhs, auto &rhs) { return lhs.getType() == rhs.getType(); })) 62 return false; 63 64 // If all the common parameters have the same type, we can elide the other 65 // method if this method has the same number of parameters as other or if the 66 // first paramater after the common parameters has a default value (and, as 67 // required by C++, subsequent parameters will have default values too). 68 return parameters.size() == other.parameters.size() || 69 parameters[other.parameters.size()].hasDefaultValue(); 70 } 71 72 //===----------------------------------------------------------------------===// 73 // MethodSignature definitions 74 //===----------------------------------------------------------------------===// 75 76 bool MethodSignature::makesRedundant(const MethodSignature &other) const { 77 return methodName == other.methodName && 78 parameters.subsumes(other.parameters); 79 } 80 81 void MethodSignature::writeDeclTo(raw_ostream &os) const { 82 os << returnType << getSpaceAfterType(returnType) << methodName << "("; 83 parameters.writeDeclTo(os); 84 os << ")"; 85 } 86 87 void MethodSignature::writeDefTo(raw_ostream &os, StringRef namePrefix) const { 88 os << returnType << getSpaceAfterType(returnType) << namePrefix 89 << (namePrefix.empty() ? "" : "::") << methodName << "("; 90 parameters.writeDefTo(os); 91 os << ")"; 92 } 93 94 //===----------------------------------------------------------------------===// 95 // MethodBody definitions 96 //===----------------------------------------------------------------------===// 97 98 MethodBody::MethodBody(bool declOnly) : isEffective(!declOnly) {} 99 100 MethodBody &MethodBody::operator<<(Twine content) { 101 if (isEffective) 102 body.append(content.str()); 103 return *this; 104 } 105 106 MethodBody &MethodBody::operator<<(int content) { 107 if (isEffective) 108 body.append(std::to_string(content)); 109 return *this; 110 } 111 112 MethodBody &MethodBody::operator<<(const FmtObjectBase &content) { 113 if (isEffective) 114 body.append(content.str()); 115 return *this; 116 } 117 118 void MethodBody::writeTo(raw_ostream &os) const { 119 auto bodyRef = StringRef(body).drop_while([](char c) { return c == '\n'; }); 120 os << bodyRef; 121 if (bodyRef.empty() || bodyRef.back() != '\n') 122 os << "\n"; 123 } 124 125 //===----------------------------------------------------------------------===// 126 // Method definitions 127 //===----------------------------------------------------------------------===// 128 129 void Method::writeDeclTo(raw_ostream &os) const { 130 os.indent(2); 131 if (isStatic()) 132 os << "static "; 133 if ((properties & MP_Constexpr) == MP_Constexpr) 134 os << "constexpr "; 135 methodSignature.writeDeclTo(os); 136 if (!isInline()) { 137 os << ";"; 138 } else { 139 os << " {\n"; 140 methodBody.writeTo(os.indent(2)); 141 os.indent(2) << "}"; 142 } 143 } 144 145 void Method::writeDefTo(raw_ostream &os, StringRef namePrefix) const { 146 // Do not write definition if the method is decl only. 147 if (properties & MP_Declaration) 148 return; 149 // Do not generate separate definition for inline method 150 if (isInline()) 151 return; 152 methodSignature.writeDefTo(os, namePrefix); 153 os << " {\n"; 154 methodBody.writeTo(os); 155 os << "}"; 156 } 157 158 //===----------------------------------------------------------------------===// 159 // Constructor definitions 160 //===----------------------------------------------------------------------===// 161 162 void Constructor::addMemberInitializer(StringRef name, StringRef value) { 163 memberInitializers.append(std::string(llvm::formatv( 164 "{0}{1}({2})", memberInitializers.empty() ? " : " : ", ", name, value))); 165 } 166 167 void Constructor::writeDefTo(raw_ostream &os, StringRef namePrefix) const { 168 // Do not write definition if the method is decl only. 169 if (properties & MP_Declaration) 170 return; 171 172 methodSignature.writeDefTo(os, namePrefix); 173 os << " " << memberInitializers << " {\n"; 174 methodBody.writeTo(os); 175 os << "}\n"; 176 } 177 178 //===----------------------------------------------------------------------===// 179 // Class definitions 180 //===----------------------------------------------------------------------===// 181 182 Class::Class(StringRef name) : className(name) {} 183 184 void Class::newField(StringRef type, StringRef name, StringRef defaultValue) { 185 std::string varName = formatv("{0} {1}", type, name).str(); 186 std::string field = defaultValue.empty() 187 ? varName 188 : formatv("{0} = {1}", varName, defaultValue).str(); 189 fields.push_back(std::move(field)); 190 } 191 192 void Class::writeDeclTo(raw_ostream &os) const { 193 bool hasPrivateMethod = false; 194 os << "class " << className << " {\n"; 195 os << "public:\n"; 196 197 forAllMethods([&](const Method &method) { 198 if (!method.isPrivate()) { 199 method.writeDeclTo(os); 200 os << '\n'; 201 } else { 202 hasPrivateMethod = true; 203 } 204 }); 205 206 os << '\n'; 207 os << "private:\n"; 208 if (hasPrivateMethod) { 209 forAllMethods([&](const Method &method) { 210 if (method.isPrivate()) { 211 method.writeDeclTo(os); 212 os << '\n'; 213 } 214 }); 215 os << '\n'; 216 } 217 218 for (const auto &field : fields) 219 os.indent(2) << field << ";\n"; 220 os << "};\n"; 221 } 222 223 void Class::writeDefTo(raw_ostream &os) const { 224 forAllMethods([&](const Method &method) { 225 method.writeDefTo(os, className); 226 os << "\n"; 227 }); 228 } 229 230 // Insert a new method into a list of methods, if it would not be pruned, and 231 // prune and existing methods. 232 template <typename ContainerT, typename MethodT> 233 MethodT *insertAndPrune(ContainerT &methods, MethodT newMethod) { 234 if (llvm::any_of(methods, [&](auto &method) { 235 return method.makesRedundant(newMethod); 236 })) 237 return nullptr; 238 239 llvm::erase_if( 240 methods, [&](auto &method) { return newMethod.makesRedundant(method); }); 241 methods.push_back(std::move(newMethod)); 242 return &methods.back(); 243 } 244 245 Method *Class::addMethodAndPrune(Method &&newMethod) { 246 return insertAndPrune(methods, std::move(newMethod)); 247 } 248 249 Constructor *Class::addConstructorAndPrune(Constructor &&newCtor) { 250 return insertAndPrune(constructors, std::move(newCtor)); 251 } 252 253 //===----------------------------------------------------------------------===// 254 // OpClass definitions 255 //===----------------------------------------------------------------------===// 256 257 OpClass::OpClass(StringRef name, StringRef extraClassDeclaration) 258 : Class(name), extraClassDeclaration(extraClassDeclaration) {} 259 260 void OpClass::addTrait(Twine trait) { traits.insert(trait.str()); } 261 262 void OpClass::writeDeclTo(raw_ostream &os) const { 263 os << "class " << className << " : public ::mlir::Op<" << className; 264 for (const auto &trait : traits) 265 os << ", " << trait; 266 os << "> {\npublic:\n" 267 << " using Op::Op;\n" 268 << " using Op::print;\n" 269 << " using Adaptor = " << className << "Adaptor;\n"; 270 271 bool hasPrivateMethod = false; 272 forAllMethods([&](const Method &method) { 273 if (!method.isPrivate()) { 274 method.writeDeclTo(os); 275 os << "\n"; 276 } else { 277 hasPrivateMethod = true; 278 } 279 }); 280 281 // TODO: Add line control markers to make errors easier to debug. 282 if (!extraClassDeclaration.empty()) 283 os << extraClassDeclaration << "\n"; 284 285 if (hasPrivateMethod) { 286 os << "\nprivate:\n"; 287 forAllMethods([&](const Method &method) { 288 if (method.isPrivate()) { 289 method.writeDeclTo(os); 290 os << "\n"; 291 } 292 }); 293 } 294 295 os << "};\n"; 296 } 297