1 //===- AttrOrTypeFormatGen.cpp - MLIR attribute and type format generator -===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 
9 #include "AttrOrTypeFormatGen.h"
10 #include "FormatGen.h"
11 #include "mlir/Support/LLVM.h"
12 #include "mlir/Support/LogicalResult.h"
13 #include "mlir/TableGen/AttrOrTypeDef.h"
14 #include "mlir/TableGen/Format.h"
15 #include "mlir/TableGen/GenInfo.h"
16 #include "llvm/ADT/BitVector.h"
17 #include "llvm/ADT/StringExtras.h"
18 #include "llvm/ADT/StringSwitch.h"
19 #include "llvm/ADT/TypeSwitch.h"
20 #include "llvm/Support/MemoryBuffer.h"
21 #include "llvm/Support/SaveAndRestore.h"
22 #include "llvm/Support/SourceMgr.h"
23 #include "llvm/TableGen/Error.h"
24 #include "llvm/TableGen/TableGenBackend.h"
25 
26 using namespace mlir;
27 using namespace mlir::tblgen;
28 
29 using llvm::formatv;
30 
31 //===----------------------------------------------------------------------===//
32 // Element
33 //===----------------------------------------------------------------------===//
34 
35 namespace {
36 /// This class represents an instance of a variable element. A variable refers
37 /// to an attribute or type parameter.
38 class ParameterElement
39     : public VariableElementBase<VariableElement::Parameter> {
40 public:
ParameterElement(AttrOrTypeParameter param)41   ParameterElement(AttrOrTypeParameter param) : param(param) {}
42 
43   /// Get the parameter in the element.
getParam() const44   const AttrOrTypeParameter &getParam() const { return param; }
45 
46   /// Indicate if this variable is printed "qualified" (that is it is
47   /// prefixed with the `#dialect.mnemonic`).
shouldBeQualified()48   bool shouldBeQualified() { return shouldBeQualifiedFlag; }
setShouldBeQualified(bool qualified=true)49   void setShouldBeQualified(bool qualified = true) {
50     shouldBeQualifiedFlag = qualified;
51   }
52 
53   /// Returns true if the element contains an optional parameter.
isOptional() const54   bool isOptional() const { return param.isOptional(); }
55 
56   /// Returns the name of the parameter.
getName() const57   StringRef getName() const { return param.getName(); }
58 
59   /// Generate the code to check whether the parameter should be printed.
genPrintGuard(FmtContext & ctx,MethodBody & os) const60   MethodBody &genPrintGuard(FmtContext &ctx, MethodBody &os) const {
61     std::string self = param.getAccessorName() + "()";
62     ctx.withSelf(self);
63     os << tgfmt("($_self", &ctx);
64     if (llvm::Optional<StringRef> defaultValue = getParam().getDefaultValue()) {
65       // Use the `comparator` field if it exists, else the equality operator.
66       std::string valueStr = tgfmt(*defaultValue, &ctx).str();
67       ctx.addSubst("_lhs", self).addSubst("_rhs", valueStr);
68       os << " && !(" << tgfmt(getParam().getComparator(), &ctx) << ")";
69     }
70     return os << ")";
71   }
72 
73 private:
74   bool shouldBeQualifiedFlag = false;
75   AttrOrTypeParameter param;
76 };
77 
78 /// Shorthand functions that can be used with ranged-based conditions.
paramIsOptional(ParameterElement * el)79 static bool paramIsOptional(ParameterElement *el) { return el->isOptional(); }
paramNotOptional(ParameterElement * el)80 static bool paramNotOptional(ParameterElement *el) { return !el->isOptional(); }
81 
82 /// Base class for a directive that contains references to multiple variables.
83 template <DirectiveElement::Kind DirectiveKind>
84 class ParamsDirectiveBase : public DirectiveElementBase<DirectiveKind> {
85 public:
86   using Base = ParamsDirectiveBase<DirectiveKind>;
87 
ParamsDirectiveBase(std::vector<ParameterElement * > && params)88   ParamsDirectiveBase(std::vector<ParameterElement *> &&params)
89       : params(std::move(params)) {}
90 
91   /// Get the parameters contained in this directive.
getParams() const92   ArrayRef<ParameterElement *> getParams() const { return params; }
93 
94   /// Get the number of parameters.
getNumParams() const95   unsigned getNumParams() const { return params.size(); }
96 
97   /// Take all of the parameters from this directive.
takeParams()98   std::vector<ParameterElement *> takeParams() { return std::move(params); }
99 
100   /// Returns true if there are optional parameters present.
hasOptionalParams() const101   bool hasOptionalParams() const {
102     return llvm::any_of(getParams(), paramIsOptional);
103   }
104 
105 private:
106   /// The parameters captured by this directive.
107   std::vector<ParameterElement *> params;
108 };
109 
110 /// This class represents a `params` directive that refers to all parameters
111 /// of an attribute or type. When used as a top-level directive, it generates
112 /// a format of the form:
113 ///
114 ///   (param-value (`,` param-value)*)?
115 ///
116 /// When used as an argument to another directive that accepts variables,
117 /// `params` can be used in place of manually listing all parameters of an
118 /// attribute or type.
119 class ParamsDirective : public ParamsDirectiveBase<DirectiveElement::Params> {
120 public:
121   using Base::Base;
122 };
123 
124 /// This class represents a `struct` directive that generates a struct format
125 /// of the form:
126 ///
127 ///   `{` param-name `=` param-value (`,` param-name `=` param-value)* `}`
128 ///
129 class StructDirective : public ParamsDirectiveBase<DirectiveElement::Struct> {
130 public:
131   using Base::Base;
132 };
133 
134 } // namespace
135 
136 //===----------------------------------------------------------------------===//
137 // Format Strings
138 //===----------------------------------------------------------------------===//
139 
/// Default parser for attribute or type parameters.
///
/// $0: The C++ storage type of the parameter (see `genVariableParser`).
static const char *const defaultParameterParser =
    "::mlir::FieldParser<$0>::parse($_parser)";

/// Default printer for attribute or type parameters. `$_self` is bound to the
/// parameter accessor call (e.g. `getFoo()`) by the printer generators.
static const char *const defaultParameterPrinter =
    "$_printer.printStrippedAttrOrType($_self)";

/// Qualified printer for attribute or type parameters: it does not elide
/// dialect and mnemonic.
static const char *const qualifiedParameterPrinter = "$_printer << $_self";

/// Print an error when failing to parse an element.
///
/// This is only the opening of an `emitError` call at the current parser
/// location; it takes no positional arguments. Users append the message text
/// and the closing `)` themselves (see `variableParser` below).
static const char *const parserErrorStr =
    "$_parser.emitError($_parser.getCurrentLocation(), ";

/// Code format to parse a variable. Separate by lines because variable parsers
/// may be generated inside other directives, which requires indentation.
///
/// {0}: The parameter name.
/// {1}: The parse code for the parameter.
/// {2}: Code template for printing an error.
/// {3}: Name of the attribute or type.
/// {4}: C++ class of the parameter.
static const char *const variableParser = R"(
// Parse variable '{0}'
_result_{0} = {1};
if (::mlir::failed(_result_{0})) {{
  {2}"failed to parse {3} parameter '{0}' which is to be a `{4}`");
  return {{};
}
)";
174 
175 //===----------------------------------------------------------------------===//
176 // DefFormat
177 //===----------------------------------------------------------------------===//
178 
179 namespace {
/// Generates the bodies of the custom `parse` and `print` methods of an
/// attribute or type from its parsed assembly-format elements.
class DefFormat {
public:
  /// `elements` are the top-level format elements; `def` is the definition
  /// whose parameters they refer to.
  DefFormat(const AttrOrTypeDef &def, std::vector<FormatElement *> &&elements)
      : def(def), elements(std::move(elements)) {}

  /// Generate the attribute or type parser.
  void genParser(MethodBody &os);
  /// Generate the attribute or type printer.
  void genPrinter(MethodBody &os);

private:
  //===--------------------------------------------------------------------===//
  // Parser generation.

  /// Generate the parser code for a specific format element.
  void genElementParser(FormatElement *el, FmtContext &ctx, MethodBody &os);
  /// Generate the parser code for a literal. With `isOptional`, the emitted
  /// `if (...parseOptional...()` condition is left open for the caller.
  void genLiteralParser(StringRef value, FmtContext &ctx, MethodBody &os,
                        bool isOptional = false);
  /// Generate the parser code for a variable.
  void genVariableParser(ParameterElement *el, FmtContext &ctx, MethodBody &os);
  /// Generate the parser code for a `params` directive.
  void genParamsParser(ParamsDirective *el, FmtContext &ctx, MethodBody &os);
  /// Generate the parser code for a `struct` directive.
  void genStructParser(StructDirective *el, FmtContext &ctx, MethodBody &os);
  /// Generate the parser code for a `custom` directive.
  void genCustomParser(CustomDirective *el, FmtContext &ctx, MethodBody &os);
  /// Generate the parser code for an optional group.
  void genOptionalGroupParser(OptionalElement *el, FmtContext &ctx,
                              MethodBody &os);

  //===--------------------------------------------------------------------===//
  // Printer generation.

  /// Generate the printer code for a specific format element.
  void genElementPrinter(FormatElement *el, FmtContext &ctx, MethodBody &os);
  /// Generate the printer code for a literal.
  void genLiteralPrinter(StringRef value, FmtContext &ctx, MethodBody &os);
  /// Generate the printer code for a variable. With `skipGuard`, the printer
  /// is not wrapped in the optional-parameter presence check.
  void genVariablePrinter(ParameterElement *el, FmtContext &ctx, MethodBody &os,
                          bool skipGuard = false);
  /// Generate a printer for comma-separated parameters. `extra` is invoked
  /// for each parameter just before its value printer is emitted.
  void genCommaSeparatedPrinter(ArrayRef<ParameterElement *> params,
                                FmtContext &ctx, MethodBody &os,
                                function_ref<void(ParameterElement *)> extra);
  /// Generate the printer code for a `params` directive.
  void genParamsPrinter(ParamsDirective *el, FmtContext &ctx, MethodBody &os);
  /// Generate the printer code for a `struct` directive.
  void genStructPrinter(StructDirective *el, FmtContext &ctx, MethodBody &os);
  /// Generate the printer code for a `custom` directive.
  void genCustomPrinter(CustomDirective *el, FmtContext &ctx, MethodBody &os);
  /// Generate the printer code for an optional group.
  void genOptionalGroupPrinter(OptionalElement *el, FmtContext &ctx,
                               MethodBody &os);
  /// Generate a printer (or space eraser) for a whitespace element.
  void genWhitespacePrinter(WhitespaceElement *el, FmtContext &ctx,
                            MethodBody &os);

  /// The ODS definition of the attribute or type whose format is being used to
  /// generate a parser and printer.
  const AttrOrTypeDef &def;
  /// The list of top-level format elements returned by the assembly format
  /// parser.
  std::vector<FormatElement *> elements;

  /// Flags for printing spaces. `shouldEmitSpace` is cleared after printing
  /// a single opening bracket literal; `lastWasPunctuation` records whether
  /// the last printed literal was a non-keyword token.
  bool shouldEmitSpace = false;
  bool lastWasPunctuation = false;
};
243 } // namespace
244 
245 //===----------------------------------------------------------------------===//
246 // ParserGen
247 //===----------------------------------------------------------------------===//
248 
genParser(MethodBody & os)249 void DefFormat::genParser(MethodBody &os) {
250   FmtContext ctx;
251   ctx.addSubst("_parser", "odsParser");
252   ctx.addSubst("_ctxt", "odsParser.getContext()");
253   ctx.withBuilder("odsBuilder");
254   if (isa<AttrDef>(def))
255     ctx.addSubst("_type", "odsType");
256   os.indent();
257   os << "::mlir::Builder odsBuilder(odsParser.getContext());\n";
258 
259   // Declare variables to store all of the parameters. Allocated parameters
260   // such as `ArrayRef` and `StringRef` must provide a `storageType`. Store
261   // FailureOr<T> to defer type construction for parameters that are parsed in
262   // a loop (parsers return FailureOr anyways).
263   ArrayRef<AttrOrTypeParameter> params = def.getParameters();
264   for (const AttrOrTypeParameter &param : params) {
265     if (isa<AttributeSelfTypeParameter>(param))
266       continue;
267     os << formatv("::mlir::FailureOr<{0}> _result_{1};\n",
268                   param.getCppStorageType(), param.getName());
269   }
270 
271   // Store the initial location of the parser.
272   ctx.addSubst("_loc", "odsLoc");
273   os << tgfmt("::llvm::SMLoc $_loc = $_parser.getCurrentLocation();\n"
274               "(void) $_loc;\n",
275               &ctx);
276 
277   // Generate call to each parameter parser.
278   for (FormatElement *el : elements)
279     genElementParser(el, ctx, os);
280 
281   // Emit an assert for each mandatory parameter. Triggering an assert means
282   // the generated parser is incorrect (i.e. there is a bug in this code).
283   for (const AttrOrTypeParameter &param : params) {
284     if (param.isOptional() || isa<AttributeSelfTypeParameter>(param))
285       continue;
286     os << formatv("assert(::mlir::succeeded(_result_{0}));\n", param.getName());
287   }
288 
289   // Generate call to the attribute or type builder. Use the checked getter
290   // if one was generated.
291   if (def.genVerifyDecl()) {
292     os << tgfmt("return $_parser.getChecked<$0>($_loc, $_parser.getContext()",
293                 &ctx, def.getCppClassName());
294   } else {
295     os << tgfmt("return $0::get($_parser.getContext()", &ctx,
296                 def.getCppClassName());
297   }
298   for (const AttrOrTypeParameter &param : params) {
299     os << ",\n    ";
300     std::string paramSelfStr;
301     llvm::raw_string_ostream selfOs(paramSelfStr);
302     if (param.isOptional()) {
303       selfOs << formatv("(_result_{0}.value_or(", param.getName());
304       if (Optional<StringRef> defaultValue = param.getDefaultValue())
305         selfOs << tgfmt(*defaultValue, &ctx);
306       else
307         selfOs << param.getCppStorageType() << "()";
308       selfOs << "))";
309     } else if (isa<AttributeSelfTypeParameter>(param)) {
310       selfOs << tgfmt("$_type", &ctx);
311     } else {
312       selfOs << formatv("(*_result_{0})", param.getName());
313     }
314     os << param.getCppType() << "("
315        << tgfmt(param.getConvertFromStorage(), &ctx.withSelf(selfOs.str()))
316        << ")";
317   }
318   os << ");";
319 }
320 
genElementParser(FormatElement * el,FmtContext & ctx,MethodBody & os)321 void DefFormat::genElementParser(FormatElement *el, FmtContext &ctx,
322                                  MethodBody &os) {
323   if (auto *literal = dyn_cast<LiteralElement>(el))
324     return genLiteralParser(literal->getSpelling(), ctx, os);
325   if (auto *var = dyn_cast<ParameterElement>(el))
326     return genVariableParser(var, ctx, os);
327   if (auto *params = dyn_cast<ParamsDirective>(el))
328     return genParamsParser(params, ctx, os);
329   if (auto *strct = dyn_cast<StructDirective>(el))
330     return genStructParser(strct, ctx, os);
331   if (auto *custom = dyn_cast<CustomDirective>(el))
332     return genCustomParser(custom, ctx, os);
333   if (auto *optional = dyn_cast<OptionalElement>(el))
334     return genOptionalGroupParser(optional, ctx, os);
335   if (isa<WhitespaceElement>(el))
336     return;
337 
338   llvm_unreachable("unknown format element");
339 }
340 
genLiteralParser(StringRef value,FmtContext & ctx,MethodBody & os,bool isOptional)341 void DefFormat::genLiteralParser(StringRef value, FmtContext &ctx,
342                                  MethodBody &os, bool isOptional) {
343   os << "// Parse literal '" << value << "'\n";
344   os << tgfmt("if ($_parser.parse", &ctx);
345   if (isOptional)
346     os << "Optional";
347   if (value.front() == '_' || isalpha(value.front())) {
348     os << "Keyword(\"" << value << "\")";
349   } else {
350     os << StringSwitch<StringRef>(value)
351               .Case("->", "Arrow")
352               .Case(":", "Colon")
353               .Case(",", "Comma")
354               .Case("=", "Equal")
355               .Case("<", "Less")
356               .Case(">", "Greater")
357               .Case("{", "LBrace")
358               .Case("}", "RBrace")
359               .Case("(", "LParen")
360               .Case(")", "RParen")
361               .Case("[", "LSquare")
362               .Case("]", "RSquare")
363               .Case("?", "Question")
364               .Case("+", "Plus")
365               .Case("*", "Star")
366        << "()";
367   }
368   if (isOptional) {
369     // Leave the `if` unclosed to guard optional groups.
370     return;
371   }
372   // Parser will emit an error
373   os << ") return {};\n";
374 }
375 
genVariableParser(ParameterElement * el,FmtContext & ctx,MethodBody & os)376 void DefFormat::genVariableParser(ParameterElement *el, FmtContext &ctx,
377                                   MethodBody &os) {
378   // Check for a custom parser. Use the default attribute parser otherwise.
379   const AttrOrTypeParameter &param = el->getParam();
380   auto customParser = param.getParser();
381   auto parser =
382       customParser ? *customParser : StringRef(defaultParameterParser);
383   os << formatv(variableParser, param.getName(),
384                 tgfmt(parser, &ctx, param.getCppStorageType()),
385                 tgfmt(parserErrorStr, &ctx), def.getName(), param.getCppType());
386 }
387 
/// Emit the parser for a `params` directive: the directive's parameters in
/// order, separated by commas.
///
/// A comma is parsed as required up to the last required parameter and as
/// optional afterwards; a missing optional comma `break`s out of the wrapping
/// `do { ... } while(false)` loop. A comma following an optional parameter is
/// only parsed if that parameter was parsed with a value.
void DefFormat::genParamsParser(ParamsDirective *el, FmtContext &ctx,
                                MethodBody &os) {
  os << "// Parse parameter list\n";

  // If there are optional parameters, we need to switch to `parseOptionalComma`
  // if there are no more required parameters after a certain point.
  bool hasOptional = el->hasOptionalParams();
  if (hasOptional) {
    // Wrap everything in a do-while so that we can `break`.
    os << "do {\n";
    os.indent();
  }

  ArrayRef<ParameterElement *> params = el->getParams();
  using IteratorT = ParameterElement *const *;
  IteratorT it = params.begin();

  // Find the last required parameter. Commas become optional afterwards.
  // When there is no required parameter, `lastReqIt` is `begin()`, so every
  // comma (which always precedes an element past `begin()`) is optional.
  // NOTE(review): the original note said "IteratorT's copy assignment is
  // deleted", but `IteratorT` is a plain pointer type; the note looks stale.
  ParameterElement *lastReq = nullptr;
  for (ParameterElement *param : params)
    if (!param->isOptional())
      lastReq = param;
  IteratorT lastReqIt = lastReq ? llvm::find(params, lastReq) : params.begin();

  // Parse a single parameter.
  auto eachFn = [&](ParameterElement *el) { genVariableParser(el, ctx, os); };
  // Parse the comma preceding the element at `it`.
  auto betweenFn = [&](IteratorT it) {
    ParameterElement *el = *std::prev(it);
    // Parse a comma if the last optional parameter had a value.
    if (el->isOptional()) {
      os << formatv("if (::mlir::succeeded(_result_{0}) && *_result_{0}) {{\n",
                    el->getName());
      os.indent();
    }
    if (it <= lastReqIt) {
      genLiteralParser(",", ctx, os);
    } else {
      // Completes the open `if (parseOptionalComma()` condition.
      genLiteralParser(",", ctx, os, /*isOptional=*/true);
      os << ") break;\n";
    }
    if (el->isOptional())
      os.unindent() << "}\n";
  };

  // Hand-rolled equivalent of llvm::interleave, since `betweenFn` needs the
  // iterator of the following element rather than just the element.
  if (it != params.end()) {
    eachFn(*it++);
    for (IteratorT e = params.end(); it != e; ++it) {
      betweenFn(it);
      eachFn(*it);
    }
  }

  if (hasOptional)
    os.unindent() << "} while(false);\n";
}
444 
genStructParser(StructDirective * el,FmtContext & ctx,MethodBody & os)445 void DefFormat::genStructParser(StructDirective *el, FmtContext &ctx,
446                                 MethodBody &os) {
447   // Loop declaration for struct parser with only required parameters.
448   //
449   // $0: Number of expected parameters.
450   const char *const loopHeader = R"(
451   for (unsigned odsStructIndex = 0; odsStructIndex < $0; ++odsStructIndex) {
452 )";
453 
454   // Loop body start for struct parser.
455   const char *const loopStart = R"(
456     ::llvm::StringRef _paramKey;
457     if ($_parser.parseKeyword(&_paramKey)) {
458       $_parser.emitError($_parser.getCurrentLocation(),
459                          "expected a parameter name in struct");
460       return {};
461     }
462     if (!_loop_body(_paramKey)) return {};
463 )";
464 
465   // Struct parser loop end. Check for duplicate or unknown struct parameters.
466   //
467   // {0}: Code template for printing an error.
468   const char *const loopEnd = R"({{
469   {0}"duplicate or unknown struct parameter name: ") << _paramKey;
470   return {{};
471 }
472 )";
473 
474   // Struct parser loop terminator. Parse a comma except on the last element.
475   //
476   // {0}: Number of elements in the struct.
477   const char *const loopTerminator = R"(
478   if ((odsStructIndex != {0} - 1) && odsParser.parseComma())
479     return {{};
480 }
481 )";
482 
483   // Check that a mandatory parameter was parse.
484   //
485   // {0}: Name of the parameter.
486   const char *const checkParam = R"(
487     if (!_seen_{0}) {
488       {1}"struct is missing required parameter: ") << "{0}";
489       return {{};
490     }
491 )";
492 
493   // Optional parameters in a struct must be parsed successfully if the
494   // keyword is present.
495   //
496   // {0}: Name of the parameter.
497   // {1}: Emit error string
498   const char *const checkOptionalParam = R"(
499     if (::mlir::succeeded(_result_{0}) && !*_result_{0}) {{
500       {1}"expected a value for parameter '{0}'");
501       return {{};
502     }
503 )";
504 
505   // First iteration of the loop parsing an optional struct.
506   const char *const optionalStructFirst = R"(
507   ::llvm::StringRef _paramKey;
508   if (!$_parser.parseOptionalKeyword(&_paramKey)) {
509     if (!_loop_body(_paramKey)) return {};
510     while (!$_parser.parseOptionalComma()) {
511 )";
512 
513   os << "// Parse parameter struct\n";
514 
515   // Declare a "seen" variable for each key.
516   for (ParameterElement *param : el->getParams())
517     os << formatv("bool _seen_{0} = false;\n", param->getName());
518 
519   // Generate the body of the parsing loop inside a lambda.
520   os << "{\n";
521   os.indent()
522       << "const auto _loop_body = [&](::llvm::StringRef _paramKey) -> bool {\n";
523   genLiteralParser("=", ctx, os.indent());
524   for (ParameterElement *param : el->getParams()) {
525     os << formatv("if (!_seen_{0} && _paramKey == \"{0}\") {\n"
526                   "  _seen_{0} = true;\n",
527                   param->getName());
528     genVariableParser(param, ctx, os.indent());
529     if (param->isOptional()) {
530       os.getStream().printReindented(strfmt(checkOptionalParam,
531                                             param->getName(),
532                                             tgfmt(parserErrorStr, &ctx).str()));
533     }
534     os.unindent() << "} else ";
535     // Print the check for duplicate or unknown parameter.
536   }
537   os.getStream().printReindented(strfmt(loopEnd, tgfmt(parserErrorStr, &ctx)));
538   os << "return true;\n";
539   os.unindent() << "};\n";
540 
541   // Generate the parsing loop. If optional parameters are present, then the
542   // parse loop is guarded by commas.
543   unsigned numOptional = llvm::count_if(el->getParams(), paramIsOptional);
544   if (numOptional) {
545     // If the struct itself is optional, pull out the first iteration.
546     if (numOptional == el->getNumParams()) {
547       os.getStream().printReindented(tgfmt(optionalStructFirst, &ctx).str());
548       os.indent();
549     } else {
550       os << "do {\n";
551     }
552   } else {
553     os.getStream().printReindented(
554         tgfmt(loopHeader, &ctx, el->getNumParams()).str());
555   }
556   os.indent();
557   os.getStream().printReindented(tgfmt(loopStart, &ctx).str());
558   os.unindent();
559 
560   // Print the loop terminator. For optional parameters, we have to check that
561   // all mandatory parameters have been parsed.
562   // The whole struct is optional if all its parameters are optional.
563   if (numOptional) {
564     if (numOptional == el->getNumParams()) {
565       os << "}\n";
566       os.unindent() << "}\n";
567     } else {
568       os << tgfmt("} while(!$_parser.parseOptionalComma());\n", &ctx);
569       for (ParameterElement *param : el->getParams()) {
570         if (param->isOptional())
571           continue;
572         os.getStream().printReindented(
573             strfmt(checkParam, param->getName(), tgfmt(parserErrorStr, &ctx)));
574       }
575     }
576   } else {
577     // Because the loop loops N times and each non-failing iteration sets 1 of
578     // N flags, successfully exiting the loop means that all parameters have
579     // been seen. `parseOptionalComma` would cause issues with any formats that
580     // use "struct(...) `,`" beacuse structs aren't sounded by braces.
581     os.getStream().printReindented(strfmt(loopTerminator, el->getNumParams()));
582   }
583   os.unindent() << "}\n";
584 }
585 
genCustomParser(CustomDirective * el,FmtContext & ctx,MethodBody & os)586 void DefFormat::genCustomParser(CustomDirective *el, FmtContext &ctx,
587                                 MethodBody &os) {
588   os << "{\n";
589   os.indent();
590 
591   // Bound variables are passed directly to the parser as `FailureOr<T> &`.
592   // Referenced variables are passed as `T`. The custom parser fails if it
593   // returns failure or if any of the required parameters failed.
594   os << tgfmt("auto odsCustomLoc = $_parser.getCurrentLocation();\n", &ctx);
595   os << "(void)odsCustomLoc;\n";
596   os << tgfmt("auto odsCustomResult = parse$0($_parser", &ctx, el->getName());
597   os.indent();
598   for (FormatElement *arg : el->getArguments()) {
599     os << ",\n";
600     FormatElement *param;
601     if (auto *ref = dyn_cast<RefDirective>(arg)) {
602       os << "*";
603       param = ref->getArg();
604     } else {
605       param = arg;
606     }
607     os << "_result_" << cast<ParameterElement>(param)->getName();
608   }
609   os.unindent() << ");\n";
610   os << "if (::mlir::failed(odsCustomResult)) return {};\n";
611   for (FormatElement *arg : el->getArguments()) {
612     if (auto *param = dyn_cast<ParameterElement>(arg)) {
613       if (param->isOptional())
614         continue;
615       os << formatv("if (::mlir::failed(_result_{0})) {{\n", param->getName());
616       os.indent() << tgfmt("$_parser.emitError(odsCustomLoc, ", &ctx)
617                   << "\"custom parser failed to parse parameter '"
618                   << param->getName() << "'\");\n";
619       os << "return {};\n";
620       os.unindent() << "}\n";
621     }
622   }
623 
624   os.unindent() << "}\n";
625 }
626 
genOptionalGroupParser(OptionalElement * el,FmtContext & ctx,MethodBody & os)627 void DefFormat::genOptionalGroupParser(OptionalElement *el, FmtContext &ctx,
628                                        MethodBody &os) {
629   ArrayRef<FormatElement *> elements =
630       el->getThenElements().drop_front(el->getParseStart());
631 
632   FormatElement *first = elements.front();
633   const auto guardOn = [&](auto params) {
634     os << "if (!(";
635     llvm::interleave(
636         params, os,
637         [&](ParameterElement *el) {
638           os << formatv("(::mlir::succeeded(_result_{0}) && *_result_{0})",
639                         el->getName());
640         },
641         " || ");
642     os << ")) {\n";
643   };
644   if (auto *literal = dyn_cast<LiteralElement>(first)) {
645     genLiteralParser(literal->getSpelling(), ctx, os, /*isOptional=*/true);
646     os << ") {\n";
647   } else if (auto *param = dyn_cast<ParameterElement>(first)) {
648     genVariableParser(param, ctx, os);
649     guardOn(llvm::makeArrayRef(param));
650   } else if (auto *params = dyn_cast<ParamsDirective>(first)) {
651     genParamsParser(params, ctx, os);
652     guardOn(params->getParams());
653   } else {
654     auto *strct = cast<StructDirective>(first);
655     genStructParser(strct, ctx, os);
656     guardOn(params->getParams());
657   }
658   os.indent();
659 
660   // Generate the parsers for the rest of the elements.
661   for (FormatElement *element : el->getElseElements())
662     genElementParser(element, ctx, os);
663   os.unindent() << "} else {\n";
664   os.indent();
665   for (FormatElement *element : elements.drop_front())
666     genElementParser(element, ctx, os);
667   os.unindent() << "}\n";
668 }
669 
670 //===----------------------------------------------------------------------===//
671 // PrinterGen
672 //===----------------------------------------------------------------------===//
673 
genPrinter(MethodBody & os)674 void DefFormat::genPrinter(MethodBody &os) {
675   FmtContext ctx;
676   ctx.addSubst("_printer", "odsPrinter");
677   ctx.addSubst("_ctxt", "getContext()");
678   ctx.withBuilder("odsBuilder");
679   os.indent();
680   os << "::mlir::Builder odsBuilder(getContext());\n";
681 
682   // Generate printers.
683   shouldEmitSpace = true;
684   lastWasPunctuation = false;
685   for (FormatElement *el : elements)
686     genElementPrinter(el, ctx, os);
687 }
688 
genElementPrinter(FormatElement * el,FmtContext & ctx,MethodBody & os)689 void DefFormat::genElementPrinter(FormatElement *el, FmtContext &ctx,
690                                   MethodBody &os) {
691   if (auto *literal = dyn_cast<LiteralElement>(el))
692     return genLiteralPrinter(literal->getSpelling(), ctx, os);
693   if (auto *params = dyn_cast<ParamsDirective>(el))
694     return genParamsPrinter(params, ctx, os);
695   if (auto *strct = dyn_cast<StructDirective>(el))
696     return genStructPrinter(strct, ctx, os);
697   if (auto *custom = dyn_cast<CustomDirective>(el))
698     return genCustomPrinter(custom, ctx, os);
699   if (auto *var = dyn_cast<ParameterElement>(el))
700     return genVariablePrinter(var, ctx, os);
701   if (auto *optional = dyn_cast<OptionalElement>(el))
702     return genOptionalGroupPrinter(optional, ctx, os);
703   if (auto *whitespace = dyn_cast<WhitespaceElement>(el))
704     return genWhitespacePrinter(whitespace, ctx, os);
705 
706   llvm::PrintFatalError("unsupported format element");
707 }
708 
genLiteralPrinter(StringRef value,FmtContext & ctx,MethodBody & os)709 void DefFormat::genLiteralPrinter(StringRef value, FmtContext &ctx,
710                                   MethodBody &os) {
711   // Don't insert a space before certain punctuation.
712   bool needSpace =
713       shouldEmitSpace && shouldEmitSpaceBefore(value, lastWasPunctuation);
714   os << tgfmt("$_printer$0 << \"$1\";\n", &ctx, needSpace ? " << ' '" : "",
715               value);
716 
717   // Update the flags.
718   shouldEmitSpace =
719       value.size() != 1 || !StringRef("<({[").contains(value.front());
720   lastWasPunctuation = !(value.front() == '_' || isalpha(value.front()));
721 }
722 
genVariablePrinter(ParameterElement * el,FmtContext & ctx,MethodBody & os,bool skipGuard)723 void DefFormat::genVariablePrinter(ParameterElement *el, FmtContext &ctx,
724                                    MethodBody &os, bool skipGuard) {
725   const AttrOrTypeParameter &param = el->getParam();
726   ctx.withSelf(param.getAccessorName() + "()");
727 
728   // Guard the printer on the presence of optional parameters and that they
729   // aren't equal to their default values (if they have one).
730   if (el->isOptional() && !skipGuard) {
731     el->genPrintGuard(ctx, os << "if (") << ") {\n";
732     os.indent();
733   }
734 
735   // Insert a space before the next parameter, if necessary.
736   if (shouldEmitSpace || !lastWasPunctuation)
737     os << tgfmt("$_printer << ' ';\n", &ctx);
738   shouldEmitSpace = true;
739   lastWasPunctuation = false;
740 
741   if (el->shouldBeQualified())
742     os << tgfmt(qualifiedParameterPrinter, &ctx) << ";\n";
743   else if (auto printer = param.getPrinter())
744     os << tgfmt(*printer, &ctx) << ";\n";
745   else
746     os << tgfmt(defaultParameterPrinter, &ctx) << ";\n";
747 
748   if (el->isOptional() && !skipGuard)
749     os.unindent() << "}\n";
750 }
751 
752 /// Generate code to guard printing on the presence of any optional parameters.
753 template <typename ParameterRange>
guardOnAny(FmtContext & ctx,MethodBody & os,ParameterRange && params)754 static void guardOnAny(FmtContext &ctx, MethodBody &os,
755                        ParameterRange &&params) {
756   os << "if (";
757   llvm::interleave(
758       params, os,
759       [&](ParameterElement *param) { param->genPrintGuard(ctx, os); }, " || ");
760   os << ") {\n";
761   os.indent();
762 }
763 
genCommaSeparatedPrinter(ArrayRef<ParameterElement * > params,FmtContext & ctx,MethodBody & os,function_ref<void (ParameterElement *)> extra)764 void DefFormat::genCommaSeparatedPrinter(
765     ArrayRef<ParameterElement *> params, FmtContext &ctx, MethodBody &os,
766     function_ref<void(ParameterElement *)> extra) {
767   // Emit a space if necessary, but only if the struct is present.
768   if (shouldEmitSpace || !lastWasPunctuation) {
769     bool allOptional = llvm::all_of(params, paramIsOptional);
770     if (allOptional)
771       guardOnAny(ctx, os, params);
772     os << tgfmt("$_printer << ' ';\n", &ctx);
773     if (allOptional)
774       os.unindent() << "}\n";
775   }
776 
777   // The first printed element does not need to emit a comma.
778   os << "{\n";
779   os.indent() << "bool _firstPrinted = true;\n";
780   for (ParameterElement *param : params) {
781     if (param->isOptional()) {
782       param->genPrintGuard(ctx, os << "if (") << ") {\n";
783       os.indent();
784     }
785     os << tgfmt("if (!_firstPrinted) $_printer << \", \";\n", &ctx);
786     os << "_firstPrinted = false;\n";
787     extra(param);
788     shouldEmitSpace = false;
789     lastWasPunctuation = true;
790     genVariablePrinter(param, ctx, os);
791     if (param->isOptional())
792       os.unindent() << "}\n";
793   }
794   os.unindent() << "}\n";
795 }
796 
genParamsPrinter(ParamsDirective * el,FmtContext & ctx,MethodBody & os)797 void DefFormat::genParamsPrinter(ParamsDirective *el, FmtContext &ctx,
798                                  MethodBody &os) {
799   genCommaSeparatedPrinter(llvm::to_vector(el->getParams()), ctx, os,
800                            [&](ParameterElement *param) {});
801 }
802 
genStructPrinter(StructDirective * el,FmtContext & ctx,MethodBody & os)803 void DefFormat::genStructPrinter(StructDirective *el, FmtContext &ctx,
804                                  MethodBody &os) {
805   genCommaSeparatedPrinter(
806       llvm::to_vector(el->getParams()), ctx, os, [&](ParameterElement *param) {
807         os << tgfmt("$_printer << \"$0 = \";\n", &ctx, param->getName());
808       });
809 }
810 
genCustomPrinter(CustomDirective * el,FmtContext & ctx,MethodBody & os)811 void DefFormat::genCustomPrinter(CustomDirective *el, FmtContext &ctx,
812                                  MethodBody &os) {
813   os << tgfmt("print$0($_printer", &ctx, el->getName());
814   os.indent();
815   for (FormatElement *arg : el->getArguments()) {
816     FormatElement *param = arg;
817     if (auto *ref = dyn_cast<RefDirective>(arg))
818       param = ref->getArg();
819     os << ",\n"
820        << cast<ParameterElement>(param)->getParam().getAccessorName() << "()";
821   }
822   os.unindent() << ");\n";
823 }
824 
genOptionalGroupPrinter(OptionalElement * el,FmtContext & ctx,MethodBody & os)825 void DefFormat::genOptionalGroupPrinter(OptionalElement *el, FmtContext &ctx,
826                                         MethodBody &os) {
827   FormatElement *anchor = el->getAnchor();
828   if (auto *param = dyn_cast<ParameterElement>(anchor)) {
829     guardOnAny(ctx, os, llvm::makeArrayRef(param));
830   } else if (auto *params = dyn_cast<ParamsDirective>(anchor)) {
831     guardOnAny(ctx, os, params->getParams());
832   } else {
833     auto *strct = cast<StructDirective>(anchor);
834     guardOnAny(ctx, os, strct->getParams());
835   }
836   // Generate the printer for the contained elements.
837   {
838     llvm::SaveAndRestore<bool> shouldEmitSpaceFlag(shouldEmitSpace);
839     llvm::SaveAndRestore<bool> lastWasPunctuationFlag(lastWasPunctuation);
840     for (FormatElement *element : el->getThenElements())
841       genElementPrinter(element, ctx, os);
842   }
843   os.unindent() << "} else {\n";
844   os.indent();
845   for (FormatElement *element : el->getElseElements())
846     genElementPrinter(element, ctx, os);
847   os.unindent() << "}\n";
848 }
849 
genWhitespacePrinter(WhitespaceElement * el,FmtContext & ctx,MethodBody & os)850 void DefFormat::genWhitespacePrinter(WhitespaceElement *el, FmtContext &ctx,
851                                      MethodBody &os) {
852   if (el->getValue() == "\\n") {
853     // FIXME: The newline should be `printer.printNewLine()`, i.e., handled by
854     // the printer.
855     os << tgfmt("$_printer << '\\n';\n", &ctx);
856   } else if (!el->getValue().empty()) {
857     os << tgfmt("$_printer << \"$0\";\n", &ctx, el->getValue());
858   } else {
859     lastWasPunctuation = true;
860   }
861   shouldEmitSpace = false;
862 }
863 
864 //===----------------------------------------------------------------------===//
865 // DefFormatParser
866 //===----------------------------------------------------------------------===//
867 
868 namespace {
869 class DefFormatParser : public FormatParser {
870 public:
DefFormatParser(llvm::SourceMgr & mgr,const AttrOrTypeDef & def)871   DefFormatParser(llvm::SourceMgr &mgr, const AttrOrTypeDef &def)
872       : FormatParser(mgr, def.getLoc()[0]), def(def),
873         seenParams(def.getNumParameters()) {}
874 
875   /// Parse the attribute or type format and create the format elements.
876   FailureOr<DefFormat> parse();
877 
878 protected:
879   /// Verify the parsed elements.
880   LogicalResult verify(SMLoc loc, ArrayRef<FormatElement *> elements) override;
881   /// Verify the elements of a custom directive.
882   LogicalResult
883   verifyCustomDirectiveArguments(SMLoc loc,
884                                  ArrayRef<FormatElement *> arguments) override;
885   /// Verify the elements of an optional group.
886   LogicalResult
887   verifyOptionalGroupElements(SMLoc loc, ArrayRef<FormatElement *> elements,
888                               Optional<unsigned> anchorIndex) override;
889 
890   /// Parse an attribute or type variable.
891   FailureOr<FormatElement *> parseVariableImpl(SMLoc loc, StringRef name,
892                                                Context ctx) override;
893   /// Parse an attribute or type format directive.
894   FailureOr<FormatElement *>
895   parseDirectiveImpl(SMLoc loc, FormatToken::Kind kind, Context ctx) override;
896 
897 private:
898   /// Parse a `params` directive.
899   FailureOr<FormatElement *> parseParamsDirective(SMLoc loc, Context ctx);
900   /// Parse a `qualified` directive.
901   FailureOr<FormatElement *> parseQualifiedDirective(SMLoc loc, Context ctx);
902   /// Parse a `struct` directive.
903   FailureOr<FormatElement *> parseStructDirective(SMLoc loc, Context ctx);
904   /// Parse a `ref` directive.
905   FailureOr<FormatElement *> parseRefDirective(SMLoc loc, Context ctx);
906 
907   /// Attribute or type tablegen def.
908   const AttrOrTypeDef &def;
909 
910   /// Seen attribute or type parameters.
911   BitVector seenParams;
912 };
913 } // namespace
914 
verify(SMLoc loc,ArrayRef<FormatElement * > elements)915 LogicalResult DefFormatParser::verify(SMLoc loc,
916                                       ArrayRef<FormatElement *> elements) {
917   // Check that all parameters are referenced in the format.
918   for (auto &it : llvm::enumerate(def.getParameters())) {
919     if (it.value().isOptional())
920       continue;
921     if (!seenParams.test(it.index())) {
922       if (isa<AttributeSelfTypeParameter>(it.value()))
923         continue;
924       return emitError(loc, "format is missing reference to parameter: " +
925                                 it.value().getName());
926     }
927     if (isa<AttributeSelfTypeParameter>(it.value())) {
928       return emitError(loc,
929                        "unexpected self type parameter in assembly format");
930     }
931   }
932   if (elements.empty())
933     return success();
934   // A `struct` directive that contains optional parameters cannot be followed
935   // by a comma literal, which is ambiguous.
936   for (auto it : llvm::zip(elements.drop_back(), elements.drop_front())) {
937     auto *structEl = dyn_cast<StructDirective>(std::get<0>(it));
938     auto *literalEl = dyn_cast<LiteralElement>(std::get<1>(it));
939     if (!structEl || !literalEl)
940       continue;
941     if (literalEl->getSpelling() == "," && structEl->hasOptionalParams()) {
942       return emitError(loc, "`struct` directive with optional parameters "
943                             "cannot be followed by a comma literal");
944     }
945   }
946   return success();
947 }
948 
verifyCustomDirectiveArguments(SMLoc loc,ArrayRef<FormatElement * > arguments)949 LogicalResult DefFormatParser::verifyCustomDirectiveArguments(
950     SMLoc loc, ArrayRef<FormatElement *> arguments) {
951   // Arguments are fully verified by the parser context.
952   return success();
953 }
954 
955 LogicalResult
verifyOptionalGroupElements(llvm::SMLoc loc,ArrayRef<FormatElement * > elements,Optional<unsigned> anchorIndex)956 DefFormatParser::verifyOptionalGroupElements(llvm::SMLoc loc,
957                                              ArrayRef<FormatElement *> elements,
958                                              Optional<unsigned> anchorIndex) {
959   // `params` and `struct` directives are allowed only if all the contained
960   // parameters are optional.
961   for (FormatElement *el : elements) {
962     if (auto *param = dyn_cast<ParameterElement>(el)) {
963       if (!param->isOptional()) {
964         return emitError(loc,
965                          "parameters in an optional group must be optional");
966       }
967     } else if (auto *params = dyn_cast<ParamsDirective>(el)) {
968       if (llvm::any_of(params->getParams(), paramNotOptional)) {
969         return emitError(loc, "`params` directive allowed in optional group "
970                               "only if all parameters are optional");
971       }
972     } else if (auto *strct = dyn_cast<StructDirective>(el)) {
973       if (llvm::any_of(strct->getParams(), paramNotOptional)) {
974         return emitError(loc, "`struct` is only allowed in an optional group "
975                               "if all captured parameters are optional");
976       }
977     }
978   }
979   // The anchor must be a parameter or one of the aforementioned directives.
980   if (anchorIndex && !isa<ParameterElement, ParamsDirective, StructDirective>(
981                          elements[*anchorIndex])) {
982     return emitError(loc,
983                      "optional group anchor must be a parameter or directive");
984   }
985   return success();
986 }
987 
parse()988 FailureOr<DefFormat> DefFormatParser::parse() {
989   FailureOr<std::vector<FormatElement *>> elements = FormatParser::parse();
990   if (failed(elements))
991     return failure();
992   return DefFormat(def, std::move(*elements));
993 }
994 
995 FailureOr<FormatElement *>
parseVariableImpl(SMLoc loc,StringRef name,Context ctx)996 DefFormatParser::parseVariableImpl(SMLoc loc, StringRef name, Context ctx) {
997   // Lookup the parameter.
998   ArrayRef<AttrOrTypeParameter> params = def.getParameters();
999   auto *it = llvm::find_if(
1000       params, [&](auto &param) { return param.getName() == name; });
1001 
1002   // Check that the parameter reference is valid.
1003   if (it == params.end()) {
1004     return emitError(loc,
1005                      def.getName() + " has no parameter named '" + name + "'");
1006   }
1007   auto idx = std::distance(params.begin(), it);
1008 
1009   if (ctx != RefDirectiveContext) {
1010     // Check that the variable has not already been bound.
1011     if (seenParams.test(idx))
1012       return emitError(loc, "duplicate parameter '" + name + "'");
1013     seenParams.set(idx);
1014 
1015     // Otherwise, to be referenced, a variable must have been bound.
1016   } else if (!seenParams.test(idx)) {
1017     return emitError(loc, "parameter '" + name +
1018                               "' must be bound before it is referenced");
1019   }
1020 
1021   return create<ParameterElement>(*it);
1022 }
1023 
1024 FailureOr<FormatElement *>
parseDirectiveImpl(SMLoc loc,FormatToken::Kind kind,Context ctx)1025 DefFormatParser::parseDirectiveImpl(SMLoc loc, FormatToken::Kind kind,
1026                                     Context ctx) {
1027 
1028   switch (kind) {
1029   case FormatToken::kw_qualified:
1030     return parseQualifiedDirective(loc, ctx);
1031   case FormatToken::kw_params:
1032     return parseParamsDirective(loc, ctx);
1033   case FormatToken::kw_struct:
1034     return parseStructDirective(loc, ctx);
1035   case FormatToken::kw_ref:
1036     return parseRefDirective(loc, ctx);
1037   case FormatToken::kw_custom:
1038     return parseCustomDirective(loc, ctx);
1039 
1040   default:
1041     return emitError(loc, "unsupported directive kind");
1042   }
1043 }
1044 
1045 FailureOr<FormatElement *>
parseQualifiedDirective(SMLoc loc,Context ctx)1046 DefFormatParser::parseQualifiedDirective(SMLoc loc, Context ctx) {
1047   if (failed(parseToken(FormatToken::l_paren,
1048                         "expected '(' before argument list")))
1049     return failure();
1050   FailureOr<FormatElement *> var = parseElement(ctx);
1051   if (failed(var))
1052     return var;
1053   if (!isa<ParameterElement>(*var))
1054     return emitError(loc, "`qualified` argument list expected a variable");
1055   cast<ParameterElement>(*var)->setShouldBeQualified();
1056   if (failed(
1057           parseToken(FormatToken::r_paren, "expected ')' after argument list")))
1058     return failure();
1059   return var;
1060 }
1061 
parseParamsDirective(SMLoc loc,Context ctx)1062 FailureOr<FormatElement *> DefFormatParser::parseParamsDirective(SMLoc loc,
1063                                                                  Context ctx) {
1064   // It doesn't make sense to allow references to all parameters in a custom
1065   // directive because parameters are the only things that can be bound.
1066   if (ctx != TopLevelContext && ctx != StructDirectiveContext) {
1067     return emitError(loc, "`params` can only be used at the top-level context "
1068                           "or within a `struct` directive");
1069   }
1070 
1071   // Collect all of the attribute's or type's parameters and ensure that none of
1072   // the parameters have already been captured.
1073   std::vector<ParameterElement *> vars;
1074   for (const auto &it : llvm::enumerate(def.getParameters())) {
1075     if (seenParams.test(it.index())) {
1076       return emitError(loc, "`params` captures duplicate parameter: " +
1077                                 it.value().getName());
1078     }
1079     seenParams.set(it.index());
1080     vars.push_back(create<ParameterElement>(it.value()));
1081   }
1082   return create<ParamsDirective>(std::move(vars));
1083 }
1084 
parseStructDirective(SMLoc loc,Context ctx)1085 FailureOr<FormatElement *> DefFormatParser::parseStructDirective(SMLoc loc,
1086                                                                  Context ctx) {
1087   if (ctx != TopLevelContext)
1088     return emitError(loc, "`struct` can only be used at the top-level context");
1089 
1090   if (failed(parseToken(FormatToken::l_paren,
1091                         "expected '(' before `struct` argument list")))
1092     return failure();
1093 
1094   // Parse variables captured by `struct`.
1095   std::vector<ParameterElement *> vars;
1096 
1097   // Parse first captured parameter or a `params` directive.
1098   FailureOr<FormatElement *> var = parseElement(StructDirectiveContext);
1099   if (failed(var) || !isa<VariableElement, ParamsDirective>(*var)) {
1100     return emitError(loc,
1101                      "`struct` argument list expected a variable or directive");
1102   }
1103   if (isa<VariableElement>(*var)) {
1104     // Parse any other parameters.
1105     vars.push_back(cast<ParameterElement>(*var));
1106     while (peekToken().is(FormatToken::comma)) {
1107       consumeToken();
1108       var = parseElement(StructDirectiveContext);
1109       if (failed(var) || !isa<VariableElement>(*var))
1110         return emitError(loc, "expected a variable in `struct` argument list");
1111       vars.push_back(cast<ParameterElement>(*var));
1112     }
1113   } else {
1114     // `struct(params)` captures all parameters in the attribute or type.
1115     vars = cast<ParamsDirective>(*var)->takeParams();
1116   }
1117 
1118   if (failed(parseToken(FormatToken::r_paren,
1119                         "expected ')' at the end of an argument list")))
1120     return failure();
1121 
1122   return create<StructDirective>(std::move(vars));
1123 }
1124 
parseRefDirective(SMLoc loc,Context ctx)1125 FailureOr<FormatElement *> DefFormatParser::parseRefDirective(SMLoc loc,
1126                                                               Context ctx) {
1127   if (ctx != CustomDirectiveContext)
1128     return emitError(loc, "`ref` is only allowed inside custom directives");
1129 
1130   // Parse the child parameter element.
1131   FailureOr<FormatElement *> child;
1132   if (failed(parseToken(FormatToken::l_paren, "expected '('")) ||
1133       failed(child = parseElement(RefDirectiveContext)) ||
1134       failed(parseToken(FormatToken::r_paren, "expeced ')'")))
1135     return failure();
1136 
1137   // Only parameter elements are allowed to be parsed under a `ref` directive.
1138   return create<RefDirective>(*child);
1139 }
1140 
1141 //===----------------------------------------------------------------------===//
1142 // Interface
1143 //===----------------------------------------------------------------------===//
1144 
generateAttrOrTypeFormat(const AttrOrTypeDef & def,MethodBody & parser,MethodBody & printer)1145 void mlir::tblgen::generateAttrOrTypeFormat(const AttrOrTypeDef &def,
1146                                             MethodBody &parser,
1147                                             MethodBody &printer) {
1148   llvm::SourceMgr mgr;
1149   mgr.AddNewSourceBuffer(
1150       llvm::MemoryBuffer::getMemBuffer(*def.getAssemblyFormat()), SMLoc());
1151 
1152   // Parse the custom assembly format>
1153   DefFormatParser fmtParser(mgr, def);
1154   FailureOr<DefFormat> format = fmtParser.parse();
1155   if (failed(format)) {
1156     if (formatErrorIsFatal)
1157       PrintFatalError(def.getLoc(), "failed to parse assembly format");
1158     return;
1159   }
1160 
1161   // Generate the parser and printer.
1162   format->genParser(parser);
1163   format->genPrinter(printer);
1164 }
1165