//===- DialectGen.cpp - MLIR dialect definitions generator ----------------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
//
// DialectGen uses the description of dialects to generate C++ definitions.
//
//===----------------------------------------------------------------------===//

#include "mlir/TableGen/Class.h"
#include "mlir/TableGen/CodeGenHelpers.h"
#include "mlir/TableGen/Format.h"
#include "mlir/TableGen/GenInfo.h"
#include "mlir/TableGen/Interfaces.h"
#include "mlir/TableGen/Operator.h"
#include "mlir/TableGen/Trait.h"
#include "llvm/ADT/Optional.h"
#include "llvm/ADT/Sequence.h"
#include "llvm/ADT/StringExtras.h"
#include "llvm/Support/CommandLine.h"
#include "llvm/Support/Signals.h"
#include "llvm/TableGen/Error.h"
#include "llvm/TableGen/Record.h"
#include "llvm/TableGen/TableGenBackend.h"

// NOTE(review): the tag says "opdefgen" although this file is the dialect
// generator -- confirm whether that is intentional.
#define DEBUG_TYPE "mlir-tblgen-opdefgen"

using namespace mlir;
using namespace mlir::tblgen;

/// Command-line category that groups the options of the dialect generators.
static llvm::cl::OptionCategory dialectGenCat("Options for -gen-dialect-*");

/// Name of the dialect to generate code for, given via `-dialect`. It is
/// compared against the name of every dialect definition when the dialect to
/// emit is chosen.
/// NOTE(review): not `static`, presumably so that other generator sources can
/// refer to it -- confirm before changing the linkage.
llvm::cl::opt<std::string>
    selectedDialect("dialect", llvm::cl::desc("The dialect to gen for"),
                    llvm::cl::cat(dialectGenCat), llvm::cl::CommaSeparated);

/// Utility iterator used for filtering records for a specific dialect.
namespace {
// A filter_iterator over an ArrayRef of records; the predicate is stored as a
// std::function so that any callable taking a record can be used.
using DialectFilterIterator =
    llvm::filter_iterator<ArrayRef<llvm::Record *>::iterator,
                          std::function<bool(const llvm::Record *)>>;
} // namespace

/// Given a set of records for a T, filter the ones that correspond to
/// the given dialect.
48 template <typename T> 49 static iterator_range<DialectFilterIterator> 50 filterForDialect(ArrayRef<llvm::Record *> records, Dialect &dialect) { 51 auto filterFn = [&](const llvm::Record *record) { 52 return T(record).getDialect() == dialect; 53 }; 54 return {DialectFilterIterator(records.begin(), records.end(), filterFn), 55 DialectFilterIterator(records.end(), records.end(), filterFn)}; 56 } 57 58 static Optional<Dialect> 59 findSelectedDialect(ArrayRef<const llvm::Record *> dialectDefs) { 60 // Select the dialect to gen for. 61 if (dialectDefs.size() == 1 && selectedDialect.getNumOccurrences() == 0) { 62 return Dialect(dialectDefs.front()); 63 } 64 65 if (selectedDialect.getNumOccurrences() == 0) { 66 llvm::errs() << "when more than 1 dialect is present, one must be selected " 67 "via '-dialect'\n"; 68 return llvm::None; 69 } 70 71 const auto *dialectIt = 72 llvm::find_if(dialectDefs, [](const llvm::Record *def) { 73 return Dialect(def).getName() == selectedDialect; 74 }); 75 if (dialectIt == dialectDefs.end()) { 76 llvm::errs() << "selected dialect with '-dialect' does not exist\n"; 77 return llvm::None; 78 } 79 return Dialect(*dialectIt); 80 } 81 82 //===----------------------------------------------------------------------===// 83 // GEN: Dialect declarations 84 //===----------------------------------------------------------------------===// 85 86 /// The code block for the start of a dialect class declaration. 87 /// 88 /// {0}: The name of the dialect class. 89 /// {1}: The dialect namespace. 90 /// {2}: The dialect parent class. 
91 static const char *const dialectDeclBeginStr = R"( 92 class {0} : public ::mlir::{2} { 93 explicit {0}(::mlir::MLIRContext *context); 94 95 void initialize(); 96 friend class ::mlir::MLIRContext; 97 public: 98 ~{0}() override; 99 static constexpr ::llvm::StringLiteral getDialectNamespace() { 100 return ::llvm::StringLiteral("{1}"); 101 } 102 )"; 103 104 /// Registration for a single dependent dialect: to be inserted in the ctor 105 /// above for each dependent dialect. 106 const char *const dialectRegistrationTemplate = R"( 107 getContext()->getOrLoadDialect<{0}>(); 108 )"; 109 110 /// The code block for the attribute parser/printer hooks. 111 static const char *const attrParserDecl = R"( 112 /// Parse an attribute registered to this dialect. 113 ::mlir::Attribute parseAttribute(::mlir::DialectAsmParser &parser, 114 ::mlir::Type type) const override; 115 116 /// Print an attribute registered to this dialect. 117 void printAttribute(::mlir::Attribute attr, 118 ::mlir::DialectAsmPrinter &os) const override; 119 )"; 120 121 /// The code block for the type parser/printer hooks. 122 static const char *const typeParserDecl = R"( 123 /// Parse a type registered to this dialect. 124 ::mlir::Type parseType(::mlir::DialectAsmParser &parser) const override; 125 126 /// Print a type registered to this dialect. 127 void printType(::mlir::Type type, 128 ::mlir::DialectAsmPrinter &os) const override; 129 )"; 130 131 /// The code block for the canonicalization pattern registration hook. 132 static const char *const canonicalizerDecl = R"( 133 /// Register canonicalization patterns. 134 void getCanonicalizationPatterns( 135 ::mlir::RewritePatternSet &results) const override; 136 )"; 137 138 /// The code block for the constant materializer hook. 139 static const char *const constantMaterializerDecl = R"( 140 /// Materialize a single constant operation from a given attribute value with 141 /// the desired resultant type. 
142 ::mlir::Operation *materializeConstant(::mlir::OpBuilder &builder, 143 ::mlir::Attribute value, 144 ::mlir::Type type, 145 ::mlir::Location loc) override; 146 )"; 147 148 /// The code block for the operation attribute verifier hook. 149 static const char *const opAttrVerifierDecl = R"( 150 /// Provides a hook for verifying dialect attributes attached to the given 151 /// op. 152 ::mlir::LogicalResult verifyOperationAttribute( 153 ::mlir::Operation *op, ::mlir::NamedAttribute attribute) override; 154 )"; 155 156 /// The code block for the region argument attribute verifier hook. 157 static const char *const regionArgAttrVerifierDecl = R"( 158 /// Provides a hook for verifying dialect attributes attached to the given 159 /// op's region argument. 160 ::mlir::LogicalResult verifyRegionArgAttribute( 161 ::mlir::Operation *op, unsigned regionIndex, unsigned argIndex, 162 ::mlir::NamedAttribute attribute) override; 163 )"; 164 165 /// The code block for the region result attribute verifier hook. 166 static const char *const regionResultAttrVerifierDecl = R"( 167 /// Provides a hook for verifying dialect attributes attached to the given 168 /// op's region result. 169 ::mlir::LogicalResult verifyRegionResultAttribute( 170 ::mlir::Operation *op, unsigned regionIndex, unsigned resultIndex, 171 ::mlir::NamedAttribute attribute) override; 172 )"; 173 174 /// The code block for the op interface fallback hook. 175 static const char *const operationInterfaceFallbackDecl = R"( 176 /// Provides a hook for op interface. 177 void *getRegisteredInterfaceForOp(mlir::TypeID interfaceID, 178 mlir::OperationName opName) override; 179 )"; 180 181 /// Generate the declaration for the given dialect class. 182 static void 183 emitDialectDecl(Dialect &dialect, 184 const iterator_range<DialectFilterIterator> &dialectAttrs, 185 const iterator_range<DialectFilterIterator> &dialectTypes, 186 raw_ostream &os) { 187 // Emit all nested namespaces. 
188 { 189 NamespaceEmitter nsEmitter(os, dialect); 190 191 // Emit the start of the decl. 192 std::string cppName = dialect.getCppClassName(); 193 StringRef superClassName = 194 dialect.isExtensible() ? "ExtensibleDialect" : "Dialect"; 195 os << llvm::formatv(dialectDeclBeginStr, cppName, dialect.getName(), 196 superClassName); 197 198 // Check for any attributes/types registered to this dialect. If there are, 199 // add the hooks for parsing/printing. 200 if (!dialectAttrs.empty() && dialect.useDefaultAttributePrinterParser()) 201 os << attrParserDecl; 202 if (!dialectTypes.empty() && dialect.useDefaultTypePrinterParser()) 203 os << typeParserDecl; 204 205 // Add the decls for the various features of the dialect. 206 if (dialect.hasCanonicalizer()) 207 os << canonicalizerDecl; 208 if (dialect.hasConstantMaterializer()) 209 os << constantMaterializerDecl; 210 if (dialect.hasOperationAttrVerify()) 211 os << opAttrVerifierDecl; 212 if (dialect.hasRegionArgAttrVerify()) 213 os << regionArgAttrVerifierDecl; 214 if (dialect.hasRegionResultAttrVerify()) 215 os << regionResultAttrVerifierDecl; 216 if (dialect.hasOperationInterfaceFallback()) 217 os << operationInterfaceFallbackDecl; 218 if (llvm::Optional<StringRef> extraDecl = 219 dialect.getExtraClassDeclaration()) 220 os << *extraDecl; 221 222 // End the dialect decl. 
223 os << "};\n"; 224 } 225 if (!dialect.getCppNamespace().empty()) 226 os << "MLIR_DECLARE_EXPLICIT_TYPE_ID(" << dialect.getCppNamespace() 227 << "::" << dialect.getCppClassName() << ")\n"; 228 } 229 230 static bool emitDialectDecls(const llvm::RecordKeeper &recordKeeper, 231 raw_ostream &os) { 232 emitSourceFileHeader("Dialect Declarations", os); 233 234 auto dialectDefs = recordKeeper.getAllDerivedDefinitions("Dialect"); 235 if (dialectDefs.empty()) 236 return false; 237 238 Optional<Dialect> dialect = findSelectedDialect(dialectDefs); 239 if (!dialect) 240 return true; 241 auto attrDefs = recordKeeper.getAllDerivedDefinitions("DialectAttr"); 242 auto typeDefs = recordKeeper.getAllDerivedDefinitions("DialectType"); 243 emitDialectDecl(*dialect, filterForDialect<Attribute>(attrDefs, *dialect), 244 filterForDialect<Type>(typeDefs, *dialect), os); 245 return false; 246 } 247 248 //===----------------------------------------------------------------------===// 249 // GEN: Dialect definitions 250 //===----------------------------------------------------------------------===// 251 252 /// The code block to generate a dialect constructor definition. 253 /// 254 /// {0}: The name of the dialect class. 255 /// {1}: initialization code that is emitted in the ctor body before calling 256 /// initialize(). 257 /// {2}: The dialect parent class. 258 static const char *const dialectConstructorStr = R"( 259 {0}::{0}(::mlir::MLIRContext *context) 260 : ::mlir::{2}(getDialectNamespace(), context, ::mlir::TypeID::get<{0}>()) {{ 261 {1} 262 initialize(); 263 } 264 )"; 265 266 /// The code block to generate a default desturctor definition. 267 /// 268 /// {0}: The name of the dialect class. 
269 static const char *const dialectDestructorStr = R"( 270 {0}::~{0}() = default; 271 272 )"; 273 274 static void emitDialectDef(Dialect &dialect, raw_ostream &os) { 275 std::string cppClassName = dialect.getCppClassName(); 276 277 // Emit the TypeID explicit specializations to have a single symbol def. 278 if (!dialect.getCppNamespace().empty()) 279 os << "MLIR_DEFINE_EXPLICIT_TYPE_ID(" << dialect.getCppNamespace() 280 << "::" << cppClassName << ")\n"; 281 282 // Emit all nested namespaces. 283 NamespaceEmitter nsEmitter(os, dialect); 284 285 /// Build the list of dependent dialects. 286 std::string dependentDialectRegistrations; 287 { 288 llvm::raw_string_ostream dialectsOs(dependentDialectRegistrations); 289 for (StringRef dependentDialect : dialect.getDependentDialects()) 290 dialectsOs << llvm::formatv(dialectRegistrationTemplate, 291 dependentDialect); 292 } 293 294 // Emit the constructor and destructor. 295 StringRef superClassName = 296 dialect.isExtensible() ? "ExtensibleDialect" : "Dialect"; 297 os << llvm::formatv(dialectConstructorStr, cppClassName, 298 dependentDialectRegistrations, superClassName); 299 if (!dialect.hasNonDefaultDestructor()) 300 os << llvm::formatv(dialectDestructorStr, cppClassName); 301 } 302 303 static bool emitDialectDefs(const llvm::RecordKeeper &recordKeeper, 304 raw_ostream &os) { 305 emitSourceFileHeader("Dialect Definitions", os); 306 307 auto dialectDefs = recordKeeper.getAllDerivedDefinitions("Dialect"); 308 if (dialectDefs.empty()) 309 return false; 310 311 Optional<Dialect> dialect = findSelectedDialect(dialectDefs); 312 if (!dialect) 313 return true; 314 emitDialectDef(*dialect, os); 315 return false; 316 } 317 318 //===----------------------------------------------------------------------===// 319 // GEN: Dialect registration hooks 320 //===----------------------------------------------------------------------===// 321 322 static mlir::GenRegistration 323 genDialectDecls("gen-dialect-decls", "Generate dialect 
declarations", 324 [](const llvm::RecordKeeper &records, raw_ostream &os) { 325 return emitDialectDecls(records, os); 326 }); 327 328 static mlir::GenRegistration 329 genDialectDefs("gen-dialect-defs", "Generate dialect definitions", 330 [](const llvm::RecordKeeper &records, raw_ostream &os) { 331 return emitDialectDefs(records, os); 332 }); 333