1 //===- OpFormatGen.cpp - MLIR operation asm format generator --------------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 
9 #include "OpFormatGen.h"
10 #include "FormatGen.h"
11 #include "OpClass.h"
12 #include "mlir/Support/LogicalResult.h"
13 #include "mlir/TableGen/Class.h"
14 #include "mlir/TableGen/Format.h"
15 #include "mlir/TableGen/GenInfo.h"
16 #include "mlir/TableGen/Interfaces.h"
17 #include "mlir/TableGen/Operator.h"
18 #include "mlir/TableGen/Trait.h"
19 #include "llvm/ADT/MapVector.h"
20 #include "llvm/ADT/Sequence.h"
21 #include "llvm/ADT/SetVector.h"
22 #include "llvm/ADT/SmallBitVector.h"
23 #include "llvm/ADT/StringExtras.h"
24 #include "llvm/ADT/TypeSwitch.h"
25 #include "llvm/Support/Signals.h"
26 #include "llvm/TableGen/Error.h"
27 #include "llvm/TableGen/Record.h"
28 
29 #define DEBUG_TYPE "mlir-tblgen-opformatgen"
30 
31 using namespace mlir;
32 using namespace mlir::tblgen;
33 
34 //===----------------------------------------------------------------------===//
35 // Element
36 //===----------------------------------------------------------------------===//
37 
namespace {
/// Root of the format element hierarchy. Every concrete element records a
/// `Kind` discriminator at construction, which the subclasses' `classof`
/// hooks compare against to support LLVM-style `isa`/`dyn_cast`.
class Element {
public:
  /// Discriminator identifying the concrete element subclass.
  enum class Kind {
    // Directives.
    AttrDictDirective,
    CustomDirective,
    FunctionalTypeDirective,
    OperandsDirective,
    RefDirective,
    RegionsDirective,
    ResultsDirective,
    SuccessorsDirective,
    TypeDirective,

    // A literal token.
    Literal,

    // Whitespace.
    Newline,
    Space,

    // References to variables of the operation.
    AttributeVariable,
    OperandVariable,
    RegionVariable,
    ResultVariable,
    SuccessorVariable,

    // An optional group of elements.
    Optional,
  };

  Element(Kind k) : elementKind(k) {}
  virtual ~Element() = default;

  /// Returns the discriminator this element was constructed with.
  Kind getKind() const { return elementKind; }

private:
  /// Set once in the constructor and never modified afterwards.
  Kind elementKind;
};
} // namespace
82 
83 //===----------------------------------------------------------------------===//
84 // VariableElement
85 
86 namespace {
87 /// This class represents an instance of an variable element. A variable refers
88 /// to something registered on the operation itself, e.g. an argument, result,
89 /// etc.
90 template <typename VarT, Element::Kind kindVal>
91 class VariableElement : public Element {
92 public:
93   VariableElement(const VarT *var) : Element(kindVal), var(var) {}
94   static bool classof(const Element *element) {
95     return element->getKind() == kindVal;
96   }
97   const VarT *getVar() { return var; }
98 
99 protected:
100   const VarT *var;
101 };
102 
103 /// This class represents a variable that refers to an attribute argument.
104 struct AttributeVariable
105     : public VariableElement<NamedAttribute, Element::Kind::AttributeVariable> {
106   using VariableElement<NamedAttribute,
107                         Element::Kind::AttributeVariable>::VariableElement;
108 
109   /// Return the constant builder call for the type of this attribute, or None
110   /// if it doesn't have one.
111   Optional<StringRef> getTypeBuilder() const {
112     Optional<Type> attrType = var->attr.getValueType();
113     return attrType ? attrType->getBuilderCall() : llvm::None;
114   }
115 
116   /// Return if this attribute refers to a UnitAttr.
117   bool isUnitAttr() const {
118     return var->attr.getBaseAttr().getAttrDefName() == "UnitAttr";
119   }
120 
121   /// Indicate if this attribute is printed "qualified" (that is it is
122   /// prefixed with the `#dialect.mnemonic`).
123   bool shouldBeQualified() { return shouldBeQualifiedFlag; }
124   void setShouldBeQualified(bool qualified = true) {
125     shouldBeQualifiedFlag = qualified;
126   }
127 
128 private:
129   bool shouldBeQualifiedFlag = false;
130 };
131 
132 /// This class represents a variable that refers to an operand argument.
133 using OperandVariable =
134     VariableElement<NamedTypeConstraint, Element::Kind::OperandVariable>;
135 
136 /// This class represents a variable that refers to a region.
137 using RegionVariable =
138     VariableElement<NamedRegion, Element::Kind::RegionVariable>;
139 
140 /// This class represents a variable that refers to a result.
141 using ResultVariable =
142     VariableElement<NamedTypeConstraint, Element::Kind::ResultVariable>;
143 
144 /// This class represents a variable that refers to a successor.
145 using SuccessorVariable =
146     VariableElement<NamedSuccessor, Element::Kind::SuccessorVariable>;
147 } // namespace
148 
149 //===----------------------------------------------------------------------===//
150 // DirectiveElement
151 
152 namespace {
153 /// This class implements single kind directives.
154 template <Element::Kind type> class DirectiveElement : public Element {
155 public:
156   DirectiveElement() : Element(type){};
157   static bool classof(const Element *ele) { return ele->getKind() == type; }
158 };
159 /// This class represents the `operands` directive. This directive represents
160 /// all of the operands of an operation.
161 using OperandsDirective = DirectiveElement<Element::Kind::OperandsDirective>;
162 
163 /// This class represents the `regions` directive. This directive represents
164 /// all of the regions of an operation.
165 using RegionsDirective = DirectiveElement<Element::Kind::RegionsDirective>;
166 
167 /// This class represents the `results` directive. This directive represents
168 /// all of the results of an operation.
169 using ResultsDirective = DirectiveElement<Element::Kind::ResultsDirective>;
170 
171 /// This class represents the `successors` directive. This directive represents
172 /// all of the successors of an operation.
173 using SuccessorsDirective =
174     DirectiveElement<Element::Kind::SuccessorsDirective>;
175 
176 /// This class represents the `attr-dict` directive. This directive represents
177 /// the attribute dictionary of the operation.
178 class AttrDictDirective
179     : public DirectiveElement<Element::Kind::AttrDictDirective> {
180 public:
181   explicit AttrDictDirective(bool withKeyword) : withKeyword(withKeyword) {}
182   bool isWithKeyword() const { return withKeyword; }
183 
184 private:
185   /// If the dictionary should be printed with the 'attributes' keyword.
186   bool withKeyword;
187 };
188 
189 /// This class represents a custom format directive that is implemented by the
190 /// user in C++.
191 class CustomDirective : public Element {
192 public:
193   CustomDirective(StringRef name,
194                   std::vector<std::unique_ptr<Element>> &&arguments)
195       : Element{Kind::CustomDirective}, name(name),
196         arguments(std::move(arguments)) {}
197 
198   static bool classof(const Element *element) {
199     return element->getKind() == Kind::CustomDirective;
200   }
201 
202   /// Return the name of the custom directive.
203   StringRef getName() const { return name; }
204 
205   /// Return the arguments to the custom directive.
206   auto getArguments() const { return llvm::make_pointee_range(arguments); }
207 
208 private:
209   /// The user provided name of the directive.
210   StringRef name;
211 
212   /// The arguments to the custom directive.
213   std::vector<std::unique_ptr<Element>> arguments;
214 };
215 
216 /// This class represents the `functional-type` directive. This directive takes
217 /// two arguments and formats them, respectively, as the inputs and results of a
218 /// FunctionType.
219 class FunctionalTypeDirective
220     : public DirectiveElement<Element::Kind::FunctionalTypeDirective> {
221 public:
222   FunctionalTypeDirective(std::unique_ptr<Element> inputs,
223                           std::unique_ptr<Element> results)
224       : inputs(std::move(inputs)), results(std::move(results)) {}
225   Element *getInputs() const { return inputs.get(); }
226   Element *getResults() const { return results.get(); }
227 
228 private:
229   /// The input and result arguments.
230   std::unique_ptr<Element> inputs, results;
231 };
232 
233 /// This class represents the `ref` directive.
234 class RefDirective : public DirectiveElement<Element::Kind::RefDirective> {
235 public:
236   RefDirective(std::unique_ptr<Element> arg) : operand(std::move(arg)) {}
237   Element *getOperand() const { return operand.get(); }
238 
239 private:
240   /// The operand that is used to format the directive.
241   std::unique_ptr<Element> operand;
242 };
243 
244 /// This class represents the `type` directive.
245 class TypeDirective : public DirectiveElement<Element::Kind::TypeDirective> {
246 public:
247   TypeDirective(std::unique_ptr<Element> arg) : operand(std::move(arg)) {}
248   Element *getOperand() const { return operand.get(); }
249 
250   /// Indicate if this type is printed "qualified" (that is it is
251   /// prefixed with the `!dialect.mnemonic`).
252   bool shouldBeQualified() { return shouldBeQualifiedFlag; }
253   void setShouldBeQualified(bool qualified = true) {
254     shouldBeQualifiedFlag = qualified;
255   }
256 
257 private:
258   /// The operand that is used to format the directive.
259   std::unique_ptr<Element> operand;
260 
261   bool shouldBeQualifiedFlag = false;
262 };
263 } // namespace
264 
265 //===----------------------------------------------------------------------===//
266 // LiteralElement
267 
268 namespace {
269 /// This class represents an instance of a literal element.
270 class LiteralElement : public Element {
271 public:
272   LiteralElement(StringRef literal)
273       : Element{Kind::Literal}, literal(literal) {}
274   static bool classof(const Element *element) {
275     return element->getKind() == Kind::Literal;
276   }
277 
278   /// Return the literal for this element.
279   StringRef getLiteral() const { return literal; }
280 
281 private:
282   /// The spelling of the literal for this element.
283   StringRef literal;
284 };
285 } // namespace
286 
287 //===----------------------------------------------------------------------===//
288 // WhitespaceElement
289 
290 namespace {
291 /// This class represents a whitespace element, e.g. newline or space. It's a
292 /// literal that is printed but never parsed.
293 class WhitespaceElement : public Element {
294 public:
295   WhitespaceElement(Kind kind) : Element{kind} {}
296   static bool classof(const Element *element) {
297     Kind kind = element->getKind();
298     return kind == Kind::Newline || kind == Kind::Space;
299   }
300 };
301 
302 /// This class represents an instance of a newline element. It's a literal that
303 /// prints a newline. It is ignored by the parser.
304 class NewlineElement : public WhitespaceElement {
305 public:
306   NewlineElement() : WhitespaceElement(Kind::Newline) {}
307   static bool classof(const Element *element) {
308     return element->getKind() == Kind::Newline;
309   }
310 };
311 
312 /// This class represents an instance of a space element. It's a literal that
313 /// prints or omits printing a space. It is ignored by the parser.
314 class SpaceElement : public WhitespaceElement {
315 public:
316   SpaceElement(bool value) : WhitespaceElement(Kind::Space), value(value) {}
317   static bool classof(const Element *element) {
318     return element->getKind() == Kind::Space;
319   }
320 
321   /// Returns true if this element should print as a space. Otherwise, the
322   /// element should omit printing a space between the surrounding elements.
323   bool getValue() const { return value; }
324 
325 private:
326   bool value;
327 };
328 } // namespace
329 
330 //===----------------------------------------------------------------------===//
331 // OptionalElement
332 
333 namespace {
334 /// This class represents a group of elements that are optionally emitted based
335 /// upon an optional variable of the operation, and a group of elements that are
336 /// emotted when the anchor element is not present.
337 class OptionalElement : public Element {
338 public:
339   OptionalElement(std::vector<std::unique_ptr<Element>> &&thenElements,
340                   std::vector<std::unique_ptr<Element>> &&elseElements,
341                   unsigned anchor, unsigned parseStart)
342       : Element{Kind::Optional}, thenElements(std::move(thenElements)),
343         elseElements(std::move(elseElements)), anchor(anchor),
344         parseStart(parseStart) {}
345   static bool classof(const Element *element) {
346     return element->getKind() == Kind::Optional;
347   }
348 
349   /// Return the `then` elements of this grouping.
350   auto getThenElements() const {
351     return llvm::make_pointee_range(thenElements);
352   }
353 
354   /// Return the `else` elements of this grouping.
355   auto getElseElements() const {
356     return llvm::make_pointee_range(elseElements);
357   }
358 
359   /// Return the anchor of this optional group.
360   Element *getAnchor() const { return thenElements[anchor].get(); }
361 
362   /// Return the index of the first element that needs to be parsed.
363   unsigned getParseStart() const { return parseStart; }
364 
365 private:
366   /// The child elements of `then` branch of this optional.
367   std::vector<std::unique_ptr<Element>> thenElements;
368   /// The child elements of `else` branch of this optional.
369   std::vector<std::unique_ptr<Element>> elseElements;
370   /// The index of the element that acts as the anchor for the optional group.
371   unsigned anchor;
372   /// The index of the first element that is parsed (is not a
373   /// WhitespaceElement).
374   unsigned parseStart;
375 };
376 } // namespace
377 
378 //===----------------------------------------------------------------------===//
379 // OperationFormat
380 //===----------------------------------------------------------------------===//
381 
382 namespace {
383 
384 using ConstArgument =
385     llvm::PointerUnion<const NamedAttribute *, const NamedTypeConstraint *>;
386 
387 struct OperationFormat {
388   /// This class represents a specific resolver for an operand or result type.
389   class TypeResolution {
390   public:
391     TypeResolution() = default;
392 
393     /// Get the index into the buildable types for this type, or None.
394     Optional<int> getBuilderIdx() const { return builderIdx; }
395     void setBuilderIdx(int idx) { builderIdx = idx; }
396 
397     /// Get the variable this type is resolved to, or nullptr.
398     const NamedTypeConstraint *getVariable() const {
399       return resolver.dyn_cast<const NamedTypeConstraint *>();
400     }
401     /// Get the attribute this type is resolved to, or nullptr.
402     const NamedAttribute *getAttribute() const {
403       return resolver.dyn_cast<const NamedAttribute *>();
404     }
405     /// Get the transformer for the type of the variable, or None.
406     Optional<StringRef> getVarTransformer() const {
407       return variableTransformer;
408     }
409     void setResolver(ConstArgument arg, Optional<StringRef> transformer) {
410       resolver = arg;
411       variableTransformer = transformer;
412       assert(getVariable() || getAttribute());
413     }
414 
415   private:
416     /// If the type is resolved with a buildable type, this is the index into
417     /// 'buildableTypes' in the parent format.
418     Optional<int> builderIdx;
419     /// If the type is resolved based upon another operand or result, this is
420     /// the variable or the attribute that this type is resolved to.
421     ConstArgument resolver;
422     /// If the type is resolved based upon another operand or result, this is
423     /// a transformer to apply to the variable when resolving.
424     Optional<StringRef> variableTransformer;
425   };
426 
427   /// The context in which an element is generated.
428   enum class GenContext {
429     /// The element is generated at the top-level or with the same behaviour.
430     Normal,
431     /// The element is generated inside an optional group.
432     Optional
433   };
434 
435   OperationFormat(const Operator &op)
436       : allOperands(false), allOperandTypes(false), allResultTypes(false),
437         infersResultTypes(false) {
438     operandTypes.resize(op.getNumOperands(), TypeResolution());
439     resultTypes.resize(op.getNumResults(), TypeResolution());
440 
441     hasImplicitTermTrait = llvm::any_of(op.getTraits(), [](const Trait &trait) {
442       return trait.getDef().isSubClassOf("SingleBlockImplicitTerminator");
443     });
444 
445     hasSingleBlockTrait =
446         hasImplicitTermTrait || op.getTrait("::mlir::OpTrait::SingleBlock");
447   }
448 
449   /// Generate the operation parser from this format.
450   void genParser(Operator &op, OpClass &opClass);
451   /// Generate the parser code for a specific format element.
452   void genElementParser(Element *element, MethodBody &body,
453                         FmtContext &attrTypeCtx,
454                         GenContext genCtx = GenContext::Normal);
455   /// Generate the C++ to resolve the types of operands and results during
456   /// parsing.
457   void genParserTypeResolution(Operator &op, MethodBody &body);
458   /// Generate the C++ to resolve the types of the operands during parsing.
459   void genParserOperandTypeResolution(
460       Operator &op, MethodBody &body,
461       function_ref<void(TypeResolution &, StringRef)> emitTypeResolver);
462   /// Generate the C++ to resolve regions during parsing.
463   void genParserRegionResolution(Operator &op, MethodBody &body);
464   /// Generate the C++ to resolve successors during parsing.
465   void genParserSuccessorResolution(Operator &op, MethodBody &body);
466   /// Generate the C++ to handling variadic segment size traits.
467   void genParserVariadicSegmentResolution(Operator &op, MethodBody &body);
468 
469   /// Generate the operation printer from this format.
470   void genPrinter(Operator &op, OpClass &opClass);
471 
472   /// Generate the printer code for a specific format element.
473   void genElementPrinter(Element *element, MethodBody &body, Operator &op,
474                          bool &shouldEmitSpace, bool &lastWasPunctuation);
475 
476   /// The various elements in this format.
477   std::vector<std::unique_ptr<Element>> elements;
478 
479   /// A flag indicating if all operand/result types were seen. If the format
480   /// contains these, it can not contain individual type resolvers.
481   bool allOperands, allOperandTypes, allResultTypes;
482 
483   /// A flag indicating if this operation infers its result types
484   bool infersResultTypes;
485 
486   /// A flag indicating if this operation has the SingleBlockImplicitTerminator
487   /// trait.
488   bool hasImplicitTermTrait;
489 
490   /// A flag indicating if this operation has the SingleBlock trait.
491   bool hasSingleBlockTrait;
492 
493   /// A map of buildable types to indices.
494   llvm::MapVector<StringRef, int, llvm::StringMap<int>> buildableTypes;
495 
496   /// The index of the buildable type, if valid, for every operand and result.
497   std::vector<TypeResolution> operandTypes, resultTypes;
498 
499   /// The set of attributes explicitly used within the format.
500   SmallVector<const NamedAttribute *, 8> usedAttributes;
501   llvm::StringSet<> inferredAttributes;
502 };
503 } // namespace
504 
505 //===----------------------------------------------------------------------===//
506 // Parser Gen
507 
508 /// Returns true if we can format the given attribute as an EnumAttr in the
509 /// parser format.
510 static bool canFormatEnumAttr(const NamedAttribute *attr) {
511   Attribute baseAttr = attr->attr.getBaseAttr();
512   const EnumAttr *enumAttr = dyn_cast<EnumAttr>(&baseAttr);
513   if (!enumAttr)
514     return false;
515 
516   // The attribute must have a valid underlying type and a constant builder.
517   return !enumAttr->getUnderlyingType().empty() &&
518          !enumAttr->getConstBuilderTemplate().empty();
519 }
520 
521 /// Returns if we should format the given attribute as an SymbolNameAttr.
522 static bool shouldFormatSymbolNameAttr(const NamedAttribute *attr) {
523   return attr->attr.getBaseAttr().getAttrDefName() == "SymbolNameAttr";
524 }
525 
/// The code snippet used to generate a parser call for an attribute. The
/// attribute is parsed with `parseCustomAttributeWithFallback` and stored both
/// in `{0}Attr` and in `result.attributes` under the name `{0}`.
///
/// {0}: The name of the attribute.
/// {1}: The type for the attribute.
const char *const attrParserCode = R"(
  if (parser.parseCustomAttributeWithFallback({0}Attr, {1}, "{0}",
          result.attributes)) {{
    return ::mlir::failure();
  }
)";

/// The code snippet used to generate a parser call for an attribute that is
/// parsed with the generic `parseAttribute` (no custom-syntax fallback).
///
/// {0}: The name of the attribute.
/// {1}: The type for the attribute.
const char *const genericAttrParserCode = R"(
  if (parser.parseAttribute({0}Attr, {1}, "{0}", result.attributes))
    return ::mlir::failure();
)";

/// The code snippet used to generate a parser call for an optional attribute.
/// A missing attribute is not an error; only an attribute that is present but
/// fails to parse makes the generated parser fail.
///
/// {0}: The name of the attribute.
/// {1}: The type for the attribute.
const char *const optionalAttrParserCode = R"(
  {
    ::mlir::OptionalParseResult parseResult =
      parser.parseOptionalAttribute({0}Attr, {1}, "{0}", result.attributes);
    if (parseResult.hasValue() && failed(*parseResult))
      return ::mlir::failure();
  }
)";

/// The code snippet used to generate a parser call for a symbol name attribute.
///
/// {0}: The name of the attribute.
const char *const symbolNameAttrParserCode = R"(
  if (parser.parseSymbolName({0}Attr, "{0}", result.attributes))
    return ::mlir::failure();
)";
/// The code snippet used to generate a parser call for an optional symbol name
/// attribute. The result of the optional parse is deliberately discarded.
///
/// {0}: The name of the attribute.
const char *const optionalSymbolNameAttrParserCode = R"(
  // Parsing an optional symbol name doesn't fail, so no need to check the
  // result.
  (void)parser.parseOptionalSymbolName({0}Attr, "{0}", result.attributes);
)";
567 
/// The code snippet used to generate a parser call for an enum attribute. The
/// enum case is accepted either as a bare keyword from the allowed set or as a
/// string attribute. It is then converted with the symbolize function, and an
/// error is emitted if the string does not name a valid case.
///
/// {0}: The name of the attribute.
/// {1}: The c++ namespace for the enum symbolize functions.
/// {2}: The function to symbolize a string of the enum.
/// {3}: The constant builder call to create an attribute of the enum type.
/// {4}: The set of allowed enum keywords.
/// {5}: The error message on failure when the enum isn't present.
const char *const enumAttrParserCode = R"(
  {
    ::llvm::StringRef attrStr;
    ::mlir::NamedAttrList attrStorage;
    auto loc = parser.getCurrentLocation();
    if (parser.parseOptionalKeyword(&attrStr, {4})) {
      ::mlir::StringAttr attrVal;
      ::mlir::OptionalParseResult parseResult =
        parser.parseOptionalAttribute(attrVal,
                                      parser.getBuilder().getNoneType(),
                                      "{0}", attrStorage);
      if (parseResult.hasValue()) {{
        if (failed(*parseResult))
          return ::mlir::failure();
        attrStr = attrVal.getValue();
      } else {
        {5}
      }
    }
    if (!attrStr.empty()) {
      auto attrOptional = {1}::{2}(attrStr);
      if (!attrOptional)
        return parser.emitError(loc, "invalid ")
               << "{0} attribute specification: \"" << attrStr << '"';

      {0}Attr = {3};
      result.addAttribute("{0}", {0}Attr);
    }
  }
)";
606 
/// The code snippet used to generate a parser call for a variadic operand. It
/// records the current location in `{0}OperandsLoc` and parses an operand list
/// into `{0}Operands`.
///
/// {0}: The name of the operand.
const char *const variadicOperandParserCode = R"(
  {0}OperandsLoc = parser.getCurrentLocation();
  if (parser.parseOperandList({0}Operands))
    return ::mlir::failure();
)";
/// The code snippet used to generate a parser call for an optional operand. At
/// most one operand is appended to `{0}Operands`. A missing operand is not an
/// error; only a present-but-malformed one fails.
///
/// {0}: The name of the operand.
const char *const optionalOperandParserCode = R"(
  {
    {0}OperandsLoc = parser.getCurrentLocation();
    ::mlir::OpAsmParser::OperandType operand;
    ::mlir::OptionalParseResult parseResult =
                                    parser.parseOptionalOperand(operand);
    if (parseResult.hasValue()) {
      if (failed(*parseResult))
        return ::mlir::failure();
      {0}Operands.push_back(operand);
    }
  }
)";
/// The code snippet used to generate a parser call for a single (non-variadic,
/// non-optional) operand, which is parsed into `{0}RawOperands[0]`.
///
/// {0}: The name of the operand.
const char *const operandParserCode = R"(
  {0}OperandsLoc = parser.getCurrentLocation();
  if (parser.parseOperand({0}RawOperands[0]))
    return ::mlir::failure();
)";
/// The code snippet used to generate a parser call for a VariadicOfVariadic
/// operand. It parses a comma-separated sequence of parenthesized operand
/// lists. Every operand is appended to `{0}Operands`, and the size of each
/// group is pushed onto `{0}OperandGroupSizes`.
///
/// {0}: The name of the operand.
/// {1}: The name of segment size attribute. NOTE(review): `{1}` is not
///      referenced by the snippet text below.
const char *const variadicOfVariadicOperandParserCode = R"(
  {
    {0}OperandsLoc = parser.getCurrentLocation();
    int32_t curSize = 0;
    do {
      if (parser.parseOptionalLParen())
        break;
      if (parser.parseOperandList({0}Operands) || parser.parseRParen())
        return ::mlir::failure();
      {0}OperandGroupSizes.push_back({0}Operands.size() - curSize);
      curSize = {0}Operands.size();
    } while (succeeded(parser.parseOptionalComma()));
  }
)";
652 
/// The code snippet used to generate a parser call for a variadic-of-variadic
/// type list. It parses a comma-separated sequence of parenthesized type lists
/// into `{0}Types`, and an empty `()` group contributes no types. Unlike the
/// operand variant, this snippet records no group sizes.
///
/// {0}: The name for the type list.
const char *const variadicOfVariadicTypeParserCode = R"(
  do {
    if (parser.parseOptionalLParen())
      break;
    if (parser.parseOptionalRParen() &&
        (parser.parseTypeList({0}Types) || parser.parseRParen()))
      return ::mlir::failure();
  } while (succeeded(parser.parseOptionalComma()));
)";
/// The code snippet used to generate a parser call for a variadic type list,
/// which is parsed into `{0}Types`.
///
/// {0}: The name for the type list.
const char *const variadicTypeParserCode = R"(
  if (parser.parseTypeList({0}Types))
    return ::mlir::failure();
)";
/// The code snippet used to generate a parser call for an optional type. At
/// most one type is appended to `{0}Types`. A missing type is not an error.
///
/// {0}: The name for the type list.
const char *const optionalTypeParserCode = R"(
  {
    ::mlir::Type optionalType;
    ::mlir::OptionalParseResult parseResult =
                                    parser.parseOptionalType(optionalType);
    if (parseResult.hasValue()) {
      if (failed(*parseResult))
        return ::mlir::failure();
      {0}Types.push_back(optionalType);
    }
  }
)";
/// The code snippet used to generate a parser call for a single type using
/// the custom-syntax-with-fallback parser. The result is stored in
/// `{1}RawTypes[0]`.
///
/// {0}: The C++ type class of the parsed type (the snippet declares
///      `{0} type;`).
/// {1}: The name of the operand/result whose type is parsed.
const char *const typeParserCode = R"(
  {
    {0} type;
    if (parser.parseCustomTypeWithFallback(type))
      return ::mlir::failure();
    {1}RawTypes[0] = type;
  }
)";
/// The code snippet used to generate a parser call for a single type that is
/// parsed in its fully qualified form with `parseType`.
///
/// {1}: The name of the operand/result whose type is parsed. NOTE(review):
///      `{0}` is unused here, presumably so this snippet can be formatted
///      with the same arguments as `typeParserCode`.
const char *const qualifiedTypeParserCode = R"(
  if (parser.parseType({1}RawTypes[0]))
    return ::mlir::failure();
)";

/// The code snippet used to generate a parser call for a functional type.
/// The parsed FunctionType's inputs and results are assigned to the two
/// type lists.
///
/// {0}: The name for the input type list.
/// {1}: The name for the result type list.
const char *const functionalTypeParserCode = R"(
  ::mlir::FunctionType {0}__{1}_functionType;
  if (parser.parseType({0}__{1}_functionType))
    return ::mlir::failure();
  {0}Types = {0}__{1}_functionType.getInputs();
  {1}Types = {0}__{1}_functionType.getResults();
)";

/// The code snippet used to generate a parser call to infer return types. It
/// calls the op's `inferReturnTypes` on the parsed operands, attributes and
/// regions, and adds the inferred types to the result.
///
/// {0}: The operation class name
const char *const inferReturnTypesParserCode = R"(
  ::llvm::SmallVector<::mlir::Type> inferredReturnTypes;
  if (::mlir::failed({0}::inferReturnTypes(parser.getContext(),
      result.location, result.operands,
      result.attributes.getDictionary(parser.getContext()),
      result.regions, inferredReturnTypes)))
    return ::mlir::failure();
  result.addTypes(inferredReturnTypes);
)";
718 
// All snippets below are `const char *const`, like every other snippet in
// this file. A plain `const char *` is a mutable global with external
// linkage, which invites accidental reassignment and cross-TU symbol clashes.

/// The code snippet used to generate a parser call for a region list. It
/// parses zero or more comma-separated regions into `{0}Regions`; the list
/// stays empty if no leading region is present.
///
/// {0}: The name for the region list.
const char *const regionListParserCode = R"(
  {
    std::unique_ptr<::mlir::Region> region;
    auto firstRegionResult = parser.parseOptionalRegion(region);
    if (firstRegionResult.hasValue()) {
      if (failed(*firstRegionResult))
        return ::mlir::failure();
      {0}Regions.emplace_back(std::move(region));

      // Parse any trailing regions.
      while (succeeded(parser.parseOptionalComma())) {
        region = std::make_unique<::mlir::Region>();
        if (parser.parseRegion(*region))
          return ::mlir::failure();
        {0}Regions.emplace_back(std::move(region));
      }
    }
  }
)";

/// The code snippet used to ensure a list of regions have terminators.
///
/// {0}: The name of the region list.
const char *const regionListEnsureTerminatorParserCode = R"(
  for (auto &region : {0}Regions)
    ensureTerminator(*region, parser.getBuilder(), result.location);
)";

/// The code snippet used to ensure a list of regions have a block.
///
/// {0}: The name of the region list.
const char *const regionListEnsureSingleBlockParserCode = R"(
  for (auto &region : {0}Regions)
    if (region->empty()) region->emplaceBlock();
)";

/// The code snippet used to generate a parser call for an optional region. A
/// missing region is not an error; only a present-but-malformed one fails.
///
/// {0}: The name of the region.
const char *const optionalRegionParserCode = R"(
  {
    auto parseResult = parser.parseOptionalRegion(*{0}Region);
    if (parseResult.hasValue() && failed(*parseResult))
      return ::mlir::failure();
  }
)";

/// The code snippet used to generate a parser call for a region.
///
/// {0}: The name of the region.
const char *const regionParserCode = R"(
  if (parser.parseRegion(*{0}Region))
    return ::mlir::failure();
)";

/// The code snippet used to ensure a region has a terminator.
///
/// {0}: The name of the region.
const char *const regionEnsureTerminatorParserCode = R"(
  ensureTerminator(*{0}Region, parser.getBuilder(), result.location);
)";

/// The code snippet used to ensure a region has a block.
///
/// {0}: The name of the region.
const char *const regionEnsureSingleBlockParserCode = R"(
  if ({0}Region->empty()) {0}Region->emplaceBlock();
)";

/// The code snippet used to generate a parser call for a successor list. It
/// parses zero or more comma-separated successors into `{0}Successors`.
///
/// {0}: The name for the successor list.
const char *const successorListParserCode = R"(
  {
    ::mlir::Block *succ;
    auto firstSucc = parser.parseOptionalSuccessor(succ);
    if (firstSucc.hasValue()) {
      if (failed(*firstSucc))
        return ::mlir::failure();
      {0}Successors.emplace_back(succ);

      // Parse any trailing successors.
      while (succeeded(parser.parseOptionalComma())) {
        if (parser.parseSuccessor(succ))
          return ::mlir::failure();
        {0}Successors.emplace_back(succ);
      }
    }
  }
)";

/// The code snippet used to generate a parser call for a successor.
///
/// {0}: The name of the successor.
const char *const successorParserCode = R"(
  if (parser.parseSuccessor({0}Successor))
    return ::mlir::failure();
)";
820 
namespace {
/// How many values a single parsed operand/result argument may bind to.
enum class ArgumentLengthKind {
  /// A variadic of variadics: 0->N groups, each of which holds 0->N elements.
  VariadicOfVariadic = 0,
  /// A variadic: 0->N elements.
  Variadic = 1,
  /// An optional: 0 or 1 element.
  Optional = 2,
  /// Exactly one element.
  Single = 3,
};
} // namespace
835 
836 /// Get the length kind for the given constraint.
837 static ArgumentLengthKind
838 getArgumentLengthKind(const NamedTypeConstraint *var) {
839   if (var->isOptional())
840     return ArgumentLengthKind::Optional;
841   if (var->isVariadicOfVariadic())
842     return ArgumentLengthKind::VariadicOfVariadic;
843   if (var->isVariadic())
844     return ArgumentLengthKind::Variadic;
845   return ArgumentLengthKind::Single;
846 }
847 
848 /// Get the name used for the type list for the given type directive operand.
849 /// 'lengthKind' to the corresponding kind for the given argument.
850 static StringRef getTypeListName(Element *arg, ArgumentLengthKind &lengthKind) {
851   if (auto *operand = dyn_cast<OperandVariable>(arg)) {
852     lengthKind = getArgumentLengthKind(operand->getVar());
853     return operand->getVar()->name;
854   }
855   if (auto *result = dyn_cast<ResultVariable>(arg)) {
856     lengthKind = getArgumentLengthKind(result->getVar());
857     return result->getVar()->name;
858   }
859   lengthKind = ArgumentLengthKind::Variadic;
860   if (isa<OperandsDirective>(arg))
861     return "allOperand";
862   if (isa<ResultsDirective>(arg))
863     return "allResult";
864   llvm_unreachable("unknown 'type' directive argument");
865 }
866 
867 /// Generate the parser for a literal value.
868 static void genLiteralParser(StringRef value, MethodBody &body) {
869   // Handle the case of a keyword/identifier.
870   if (value.front() == '_' || isalpha(value.front())) {
871     body << "Keyword(\"" << value << "\")";
872     return;
873   }
874   body << (StringRef)StringSwitch<StringRef>(value)
875               .Case("->", "Arrow()")
876               .Case(":", "Colon()")
877               .Case(",", "Comma()")
878               .Case("=", "Equal()")
879               .Case("<", "Less()")
880               .Case(">", "Greater()")
881               .Case("{", "LBrace()")
882               .Case("}", "RBrace()")
883               .Case("(", "LParen()")
884               .Case(")", "RParen()")
885               .Case("[", "LSquare()")
886               .Case("]", "RSquare()")
887               .Case("?", "Question()")
888               .Case("+", "Plus()")
889               .Case("*", "Star()");
890 }
891 
892 /// Generate the storage code required for parsing the given element.
893 static void genElementParserStorage(Element *element, const Operator &op,
894                                     MethodBody &body) {
895   if (auto *optional = dyn_cast<OptionalElement>(element)) {
896     auto elements = optional->getThenElements();
897 
898     // If the anchor is a unit attribute, it won't be parsed directly so elide
899     // it.
900     auto *anchor = dyn_cast<AttributeVariable>(optional->getAnchor());
901     Element *elidedAnchorElement = nullptr;
902     if (anchor && anchor != &*elements.begin() && anchor->isUnitAttr())
903       elidedAnchorElement = anchor;
904     for (auto &childElement : elements)
905       if (&childElement != elidedAnchorElement)
906         genElementParserStorage(&childElement, op, body);
907     for (auto &childElement : optional->getElseElements())
908       genElementParserStorage(&childElement, op, body);
909 
910   } else if (auto *custom = dyn_cast<CustomDirective>(element)) {
911     for (auto &paramElement : custom->getArguments())
912       genElementParserStorage(&paramElement, op, body);
913 
914   } else if (isa<OperandsDirective>(element)) {
915     body << "  ::mlir::SmallVector<::mlir::OpAsmParser::OperandType, 4> "
916             "allOperands;\n";
917 
918   } else if (isa<RegionsDirective>(element)) {
919     body << "  ::llvm::SmallVector<std::unique_ptr<::mlir::Region>, 2> "
920             "fullRegions;\n";
921 
922   } else if (isa<SuccessorsDirective>(element)) {
923     body << "  ::llvm::SmallVector<::mlir::Block *, 2> fullSuccessors;\n";
924 
925   } else if (auto *attr = dyn_cast<AttributeVariable>(element)) {
926     const NamedAttribute *var = attr->getVar();
927     body << llvm::formatv("  {0} {1}Attr;\n", var->attr.getStorageType(),
928                           var->name);
929 
930   } else if (auto *operand = dyn_cast<OperandVariable>(element)) {
931     StringRef name = operand->getVar()->name;
932     if (operand->getVar()->isVariableLength()) {
933       body << "  ::mlir::SmallVector<::mlir::OpAsmParser::OperandType, 4> "
934            << name << "Operands;\n";
935       if (operand->getVar()->isVariadicOfVariadic()) {
936         body << "    llvm::SmallVector<int32_t> " << name
937              << "OperandGroupSizes;\n";
938       }
939     } else {
940       body << "  ::mlir::OpAsmParser::OperandType " << name
941            << "RawOperands[1];\n"
942            << "  ::llvm::ArrayRef<::mlir::OpAsmParser::OperandType> " << name
943            << "Operands(" << name << "RawOperands);";
944     }
945     body << llvm::formatv("  ::llvm::SMLoc {0}OperandsLoc;\n"
946                           "  (void){0}OperandsLoc;\n",
947                           name);
948 
949   } else if (auto *region = dyn_cast<RegionVariable>(element)) {
950     StringRef name = region->getVar()->name;
951     if (region->getVar()->isVariadic()) {
952       body << llvm::formatv(
953           "  ::llvm::SmallVector<std::unique_ptr<::mlir::Region>, 2> "
954           "{0}Regions;\n",
955           name);
956     } else {
957       body << llvm::formatv("  std::unique_ptr<::mlir::Region> {0}Region = "
958                             "std::make_unique<::mlir::Region>();\n",
959                             name);
960     }
961 
962   } else if (auto *successor = dyn_cast<SuccessorVariable>(element)) {
963     StringRef name = successor->getVar()->name;
964     if (successor->getVar()->isVariadic()) {
965       body << llvm::formatv("  ::llvm::SmallVector<::mlir::Block *, 2> "
966                             "{0}Successors;\n",
967                             name);
968     } else {
969       body << llvm::formatv("  ::mlir::Block *{0}Successor = nullptr;\n", name);
970     }
971 
972   } else if (auto *dir = dyn_cast<TypeDirective>(element)) {
973     ArgumentLengthKind lengthKind;
974     StringRef name = getTypeListName(dir->getOperand(), lengthKind);
975     if (lengthKind != ArgumentLengthKind::Single)
976       body << "  ::mlir::SmallVector<::mlir::Type, 1> " << name << "Types;\n";
977     else
978       body << llvm::formatv("  ::mlir::Type {0}RawTypes[1];\n", name)
979            << llvm::formatv(
980                   "  ::llvm::ArrayRef<::mlir::Type> {0}Types({0}RawTypes);\n",
981                   name);
982   } else if (auto *dir = dyn_cast<FunctionalTypeDirective>(element)) {
983     ArgumentLengthKind ignored;
984     body << "  ::llvm::ArrayRef<::mlir::Type> "
985          << getTypeListName(dir->getInputs(), ignored) << "Types;\n";
986     body << "  ::llvm::ArrayRef<::mlir::Type> "
987          << getTypeListName(dir->getResults(), ignored) << "Types;\n";
988   }
989 }
990 
991 /// Generate the parser for a parameter to a custom directive.
992 static void genCustomParameterParser(Element &param, MethodBody &body) {
993   if (auto *attr = dyn_cast<AttributeVariable>(&param)) {
994     body << attr->getVar()->name << "Attr";
995   } else if (isa<AttrDictDirective>(&param)) {
996     body << "result.attributes";
997   } else if (auto *operand = dyn_cast<OperandVariable>(&param)) {
998     StringRef name = operand->getVar()->name;
999     ArgumentLengthKind lengthKind = getArgumentLengthKind(operand->getVar());
1000     if (lengthKind == ArgumentLengthKind::VariadicOfVariadic)
1001       body << llvm::formatv("{0}OperandGroups", name);
1002     else if (lengthKind == ArgumentLengthKind::Variadic)
1003       body << llvm::formatv("{0}Operands", name);
1004     else if (lengthKind == ArgumentLengthKind::Optional)
1005       body << llvm::formatv("{0}Operand", name);
1006     else
1007       body << formatv("{0}RawOperands[0]", name);
1008 
1009   } else if (auto *region = dyn_cast<RegionVariable>(&param)) {
1010     StringRef name = region->getVar()->name;
1011     if (region->getVar()->isVariadic())
1012       body << llvm::formatv("{0}Regions", name);
1013     else
1014       body << llvm::formatv("*{0}Region", name);
1015 
1016   } else if (auto *successor = dyn_cast<SuccessorVariable>(&param)) {
1017     StringRef name = successor->getVar()->name;
1018     if (successor->getVar()->isVariadic())
1019       body << llvm::formatv("{0}Successors", name);
1020     else
1021       body << llvm::formatv("{0}Successor", name);
1022 
1023   } else if (auto *dir = dyn_cast<RefDirective>(&param)) {
1024     genCustomParameterParser(*dir->getOperand(), body);
1025 
1026   } else if (auto *dir = dyn_cast<TypeDirective>(&param)) {
1027     ArgumentLengthKind lengthKind;
1028     StringRef listName = getTypeListName(dir->getOperand(), lengthKind);
1029     if (lengthKind == ArgumentLengthKind::VariadicOfVariadic)
1030       body << llvm::formatv("{0}TypeGroups", listName);
1031     else if (lengthKind == ArgumentLengthKind::Variadic)
1032       body << llvm::formatv("{0}Types", listName);
1033     else if (lengthKind == ArgumentLengthKind::Optional)
1034       body << llvm::formatv("{0}Type", listName);
1035     else
1036       body << formatv("{0}RawTypes[0]", listName);
1037   } else {
1038     llvm_unreachable("unknown custom directive parameter");
1039   }
1040 }
1041 
1042 /// Generate the parser for a custom directive.
1043 static void genCustomDirectiveParser(CustomDirective *dir, MethodBody &body) {
1044   body << "  {\n";
1045 
1046   // Preprocess the directive variables.
1047   // * Add a local variable for optional operands and types. This provides a
1048   //   better API to the user defined parser methods.
1049   // * Set the location of operand variables.
1050   for (Element &param : dir->getArguments()) {
1051     if (auto *operand = dyn_cast<OperandVariable>(&param)) {
1052       auto *var = operand->getVar();
1053       body << "    " << var->name
1054            << "OperandsLoc = parser.getCurrentLocation();\n";
1055       if (var->isOptional()) {
1056         body << llvm::formatv(
1057             "    llvm::Optional<::mlir::OpAsmParser::OperandType> "
1058             "{0}Operand;\n",
1059             var->name);
1060       } else if (var->isVariadicOfVariadic()) {
1061         body << llvm::formatv("    "
1062                               "llvm::SmallVector<llvm::SmallVector<::mlir::"
1063                               "OpAsmParser::OperandType>> "
1064                               "{0}OperandGroups;\n",
1065                               var->name);
1066       }
1067     } else if (auto *dir = dyn_cast<TypeDirective>(&param)) {
1068       ArgumentLengthKind lengthKind;
1069       StringRef listName = getTypeListName(dir->getOperand(), lengthKind);
1070       if (lengthKind == ArgumentLengthKind::Optional) {
1071         body << llvm::formatv("    ::mlir::Type {0}Type;\n", listName);
1072       } else if (lengthKind == ArgumentLengthKind::VariadicOfVariadic) {
1073         body << llvm::formatv(
1074             "    llvm::SmallVector<llvm::SmallVector<::mlir::Type>> "
1075             "{0}TypeGroups;\n",
1076             listName);
1077       }
1078     } else if (auto *dir = dyn_cast<RefDirective>(&param)) {
1079       Element *input = dir->getOperand();
1080       if (auto *operand = dyn_cast<OperandVariable>(input)) {
1081         if (!operand->getVar()->isOptional())
1082           continue;
1083         body << llvm::formatv(
1084             "    {0} {1}Operand = {1}Operands.empty() ? {0}() : "
1085             "{1}Operands[0];\n",
1086             "llvm::Optional<::mlir::OpAsmParser::OperandType>",
1087             operand->getVar()->name);
1088 
1089       } else if (auto *type = dyn_cast<TypeDirective>(input)) {
1090         ArgumentLengthKind lengthKind;
1091         StringRef listName = getTypeListName(type->getOperand(), lengthKind);
1092         if (lengthKind == ArgumentLengthKind::Optional) {
1093           body << llvm::formatv("    ::mlir::Type {0}Type = {0}Types.empty() ? "
1094                                 "::mlir::Type() : {0}Types[0];\n",
1095                                 listName);
1096         }
1097       }
1098     }
1099   }
1100 
1101   body << "    if (parse" << dir->getName() << "(parser";
1102   for (Element &param : dir->getArguments()) {
1103     body << ", ";
1104     genCustomParameterParser(param, body);
1105   }
1106 
1107   body << "))\n"
1108        << "      return ::mlir::failure();\n";
1109 
1110   // After parsing, add handling for any of the optional constructs.
1111   for (Element &param : dir->getArguments()) {
1112     if (auto *attr = dyn_cast<AttributeVariable>(&param)) {
1113       const NamedAttribute *var = attr->getVar();
1114       if (var->attr.isOptional())
1115         body << llvm::formatv("    if ({0}Attr)\n  ", var->name);
1116 
1117       body << llvm::formatv("    result.addAttribute(\"{0}\", {0}Attr);\n",
1118                             var->name);
1119     } else if (auto *operand = dyn_cast<OperandVariable>(&param)) {
1120       const NamedTypeConstraint *var = operand->getVar();
1121       if (var->isOptional()) {
1122         body << llvm::formatv("    if ({0}Operand.hasValue())\n"
1123                               "      {0}Operands.push_back(*{0}Operand);\n",
1124                               var->name);
1125       } else if (var->isVariadicOfVariadic()) {
1126         body << llvm::formatv(
1127             "    for (const auto &subRange : {0}OperandGroups) {{\n"
1128             "      {0}Operands.append(subRange.begin(), subRange.end());\n"
1129             "      {0}OperandGroupSizes.push_back(subRange.size());\n"
1130             "    }\n",
1131             var->name, var->constraint.getVariadicOfVariadicSegmentSizeAttr());
1132       }
1133     } else if (auto *dir = dyn_cast<TypeDirective>(&param)) {
1134       ArgumentLengthKind lengthKind;
1135       StringRef listName = getTypeListName(dir->getOperand(), lengthKind);
1136       if (lengthKind == ArgumentLengthKind::Optional) {
1137         body << llvm::formatv("    if ({0}Type)\n"
1138                               "      {0}Types.push_back({0}Type);\n",
1139                               listName);
1140       } else if (lengthKind == ArgumentLengthKind::VariadicOfVariadic) {
1141         body << llvm::formatv(
1142             "    for (const auto &subRange : {0}TypeGroups)\n"
1143             "      {0}Types.append(subRange.begin(), subRange.end());\n",
1144             listName);
1145       }
1146     }
1147   }
1148 
1149   body << "  }\n";
1150 }
1151 
1152 /// Generate the parser for a enum attribute.
1153 static void genEnumAttrParser(const NamedAttribute *var, MethodBody &body,
1154                               FmtContext &attrTypeCtx) {
1155   Attribute baseAttr = var->attr.getBaseAttr();
1156   const EnumAttr &enumAttr = cast<EnumAttr>(baseAttr);
1157   std::vector<EnumAttrCase> cases = enumAttr.getAllCases();
1158 
1159   // Generate the code for building an attribute for this enum.
1160   std::string attrBuilderStr;
1161   {
1162     llvm::raw_string_ostream os(attrBuilderStr);
1163     os << tgfmt(enumAttr.getConstBuilderTemplate(), &attrTypeCtx,
1164                 "attrOptional.getValue()");
1165   }
1166 
1167   // Build a string containing the cases that can be formatted as a keyword.
1168   std::string validCaseKeywordsStr = "{";
1169   llvm::raw_string_ostream validCaseKeywordsOS(validCaseKeywordsStr);
1170   for (const EnumAttrCase &attrCase : cases)
1171     if (canFormatStringAsKeyword(attrCase.getStr()))
1172       validCaseKeywordsOS << '"' << attrCase.getStr() << "\",";
1173   validCaseKeywordsOS.str().back() = '}';
1174 
1175   // If the attribute is not optional, build an error message for the missing
1176   // attribute.
1177   std::string errorMessage;
1178   if (!var->attr.isOptional()) {
1179     llvm::raw_string_ostream errorMessageOS(errorMessage);
1180     errorMessageOS
1181         << "return parser.emitError(loc, \"expected string or "
1182            "keyword containing one of the following enum values for attribute '"
1183         << var->name << "' [";
1184     llvm::interleaveComma(cases, errorMessageOS, [&](const auto &attrCase) {
1185       errorMessageOS << attrCase.getStr();
1186     });
1187     errorMessageOS << "]\");";
1188   }
1189 
1190   body << formatv(enumAttrParserCode, var->name, enumAttr.getCppNamespace(),
1191                   enumAttr.getStringToSymbolFnName(), attrBuilderStr,
1192                   validCaseKeywordsStr, errorMessage);
1193 }
1194 
1195 void OperationFormat::genParser(Operator &op, OpClass &opClass) {
1196   SmallVector<MethodParameter> paramList;
1197   paramList.emplace_back("::mlir::OpAsmParser &", "parser");
1198   paramList.emplace_back("::mlir::OperationState &", "result");
1199 
1200   auto *method = opClass.addStaticMethod("::mlir::ParseResult", "parse",
1201                                          std::move(paramList));
1202   auto &body = method->body();
1203 
1204   // Generate variables to store the operands and type within the format. This
1205   // allows for referencing these variables in the presence of optional
1206   // groupings.
1207   for (auto &element : elements)
1208     genElementParserStorage(&*element, op, body);
1209 
1210   // A format context used when parsing attributes with buildable types.
1211   FmtContext attrTypeCtx;
1212   attrTypeCtx.withBuilder("parser.getBuilder()");
1213 
1214   // Generate parsers for each of the elements.
1215   for (auto &element : elements)
1216     genElementParser(element.get(), body, attrTypeCtx);
1217 
1218   // Generate the code to resolve the operand/result types and successors now
1219   // that they have been parsed.
1220   genParserRegionResolution(op, body);
1221   genParserSuccessorResolution(op, body);
1222   genParserVariadicSegmentResolution(op, body);
1223   genParserTypeResolution(op, body);
1224 
1225   body << "  return ::mlir::success();\n";
1226 }
1227 
/// Generate the parser code for a single format element and append it to
/// `body`.
///
/// `attrTypeCtx` is the format context used when emitting attribute type
/// builders. `genCtx` is `GenContext::Optional` for the elements that follow
/// the guard element of an optional group; in that context an optional
/// attribute is parsed with the required-attribute code path (see the
/// attribute branch below).
void OperationFormat::genElementParser(Element *element, MethodBody &body,
                                       FmtContext &attrTypeCtx,
                                       GenContext genCtx) {
  /// Optional Group.
  if (auto *optional = dyn_cast<OptionalElement>(element)) {
    // Skip the first `getParseStart()` then-elements; the first remaining
    // element acts as the guard. NOTE(review): presumably the skipped leading
    // elements need no parser code — confirm against OptionalElement.
    auto elements = llvm::drop_begin(optional->getThenElements(),
                                     optional->getParseStart());

    // Generate a special optional parser for the first element to gate the
    // parsing of the rest of the elements.
    Element *firstElement = &*elements.begin();
    if (auto *attrVar = dyn_cast<AttributeVariable>(firstElement)) {
      // Parsed in the normal context, so an optional attribute guard uses the
      // optional-attribute parser and the group is entered only if it is set.
      genElementParser(attrVar, body, attrTypeCtx);
      body << "  if (" << attrVar->getVar()->name << "Attr) {\n";
    } else if (auto *literal = dyn_cast<LiteralElement>(firstElement)) {
      body << "  if (succeeded(parser.parseOptional";
      genLiteralParser(literal->getLiteral(), body);
      body << ")) {\n";
    } else if (auto *opVar = dyn_cast<OperandVariable>(firstElement)) {
      genElementParser(opVar, body, attrTypeCtx);
      body << "  if (!" << opVar->getVar()->name << "Operands.empty()) {\n";
    } else if (auto *regionVar = dyn_cast<RegionVariable>(firstElement)) {
      const NamedRegion *region = regionVar->getVar();
      if (region->isVariadic()) {
        genElementParser(regionVar, body, attrTypeCtx);
        body << "  if (!" << region->name << "Regions.empty()) {\n";
      } else {
        // A single region guard is parsed optionally; the terminator /
        // single-block fixups only apply when the region is non-empty.
        body << llvm::formatv(optionalRegionParserCode, region->name);
        body << "  if (!" << region->name << "Region->empty()) {\n  ";
        if (hasImplicitTermTrait)
          body << llvm::formatv(regionEnsureTerminatorParserCode, region->name);
        else if (hasSingleBlockTrait)
          body << llvm::formatv(regionEnsureSingleBlockParserCode,
                                region->name);
      }
    }
    // NOTE(review): if the guard is none of the kinds above, no `if (` is
    // emitted while the closing `}` below still is; presumably the format
    // verifier rules such guards out — confirm.

    // If the anchor is a unit attribute, we don't need to print it. When
    // parsing, we will add this attribute if this group is present.
    Element *elidedAnchorElement = nullptr;
    auto *anchorAttr = dyn_cast<AttributeVariable>(optional->getAnchor());
    if (anchorAttr && anchorAttr != firstElement && anchorAttr->isUnitAttr()) {
      elidedAnchorElement = anchorAttr;

      // Add the anchor unit attribute to the operation state.
      body << "    result.addAttribute(\"" << anchorAttr->getVar()->name
           << "\", parser.getBuilder().getUnitAttr());\n";
    }

    // Generate the rest of the elements inside an optional group. Elements in
    // an optional group after the guard are parsed as required.
    for (Element &childElement : llvm::drop_begin(elements, 1)) {
      if (&childElement != elidedAnchorElement) {
        genElementParser(&childElement, body, attrTypeCtx,
                         GenContext::Optional);
      }
    }
    body << "  }";

    // Generate the else elements.
    auto elseElements = optional->getElseElements();
    if (!elseElements.empty()) {
      body << " else {\n";
      for (Element &childElement : elseElements)
        genElementParser(&childElement, body, attrTypeCtx);
      body << "  }";
    }
    body << "\n";

    /// Literals.
  } else if (LiteralElement *literal = dyn_cast<LiteralElement>(element)) {
    body << "  if (parser.parse";
    genLiteralParser(literal->getLiteral(), body);
    body << ")\n    return ::mlir::failure();\n";

    /// Whitespaces.
  } else if (isa<WhitespaceElement>(element)) {
    // Nothing to parse.

    /// Arguments.
  } else if (auto *attr = dyn_cast<AttributeVariable>(element)) {
    const NamedAttribute *var = attr->getVar();

    // Check to see if we can parse this as an enum attribute.
    if (canFormatEnumAttr(var))
      return genEnumAttrParser(var, body, attrTypeCtx);

    // Check to see if we should parse this as a symbol name attribute.
    if (shouldFormatSymbolNameAttr(var)) {
      body << formatv(var->attr.isOptional() ? optionalSymbolNameAttrParserCode
                                             : symbolNameAttrParserCode,
                      var->name);
      return;
    }

    // If this attribute has a buildable type, use that when parsing the
    // attribute.
    std::string attrTypeStr;
    if (Optional<StringRef> typeBuilder = attr->getTypeBuilder()) {
      llvm::raw_string_ostream os(attrTypeStr);
      os << tgfmt(*typeBuilder, &attrTypeCtx);
    } else {
      attrTypeStr = "::mlir::Type{}";
    }
    // Inside an optional group (after its guard) even optional attributes
    // take the required-attribute path below.
    if (genCtx == GenContext::Normal && var->attr.isOptional()) {
      body << formatv(optionalAttrParserCode, var->name, attrTypeStr);
    } else {
      if (attr->shouldBeQualified() ||
          var->attr.getStorageType() == "::mlir::Attribute")
        body << formatv(genericAttrParserCode, var->name, attrTypeStr);
      else
        body << formatv(attrParserCode, var->name, attrTypeStr);
    }

  } else if (auto *operand = dyn_cast<OperandVariable>(element)) {
    ArgumentLengthKind lengthKind = getArgumentLengthKind(operand->getVar());
    StringRef name = operand->getVar()->name;
    if (lengthKind == ArgumentLengthKind::VariadicOfVariadic)
      body << llvm::formatv(
          variadicOfVariadicOperandParserCode, name,
          operand->getVar()->constraint.getVariadicOfVariadicSegmentSizeAttr());
    else if (lengthKind == ArgumentLengthKind::Variadic)
      body << llvm::formatv(variadicOperandParserCode, name);
    else if (lengthKind == ArgumentLengthKind::Optional)
      body << llvm::formatv(optionalOperandParserCode, name);
    else
      body << formatv(operandParserCode, name);

  } else if (auto *region = dyn_cast<RegionVariable>(element)) {
    // Parse the region(s), then apply the implicit-terminator or single-block
    // fixup, if the op has the corresponding trait.
    bool isVariadic = region->getVar()->isVariadic();
    body << llvm::formatv(isVariadic ? regionListParserCode : regionParserCode,
                          region->getVar()->name);
    if (hasImplicitTermTrait)
      body << llvm::formatv(isVariadic ? regionListEnsureTerminatorParserCode
                                       : regionEnsureTerminatorParserCode,
                            region->getVar()->name);
    else if (hasSingleBlockTrait)
      body << llvm::formatv(isVariadic ? regionListEnsureSingleBlockParserCode
                                       : regionEnsureSingleBlockParserCode,
                            region->getVar()->name);

  } else if (auto *successor = dyn_cast<SuccessorVariable>(element)) {
    bool isVariadic = successor->getVar()->isVariadic();
    body << formatv(isVariadic ? successorListParserCode : successorParserCode,
                    successor->getVar()->name);

    /// Directives.
  } else if (auto *attrDict = dyn_cast<AttrDictDirective>(element)) {
    body << "  if (parser.parseOptionalAttrDict"
         << (attrDict->isWithKeyword() ? "WithKeyword" : "")
         << "(result.attributes))\n"
         << "    return ::mlir::failure();\n";
  } else if (auto *customDir = dyn_cast<CustomDirective>(element)) {
    genCustomDirectiveParser(customDir, body);

  } else if (isa<OperandsDirective>(element)) {
    // `allOperandLoc` is declared here (not in the storage section) and is
    // referenced later by the operand type resolution code.
    body << "  ::llvm::SMLoc allOperandLoc = parser.getCurrentLocation();\n"
         << "  if (parser.parseOperandList(allOperands))\n"
         << "    return ::mlir::failure();\n";

  } else if (isa<RegionsDirective>(element)) {
    body << llvm::formatv(regionListParserCode, "full");
    if (hasImplicitTermTrait)
      body << llvm::formatv(regionListEnsureTerminatorParserCode, "full");
    else if (hasSingleBlockTrait)
      body << llvm::formatv(regionListEnsureSingleBlockParserCode, "full");

  } else if (isa<SuccessorsDirective>(element)) {
    body << llvm::formatv(successorListParserCode, "full");

  } else if (auto *dir = dyn_cast<TypeDirective>(element)) {
    ArgumentLengthKind lengthKind;
    StringRef listName = getTypeListName(dir->getOperand(), lengthKind);
    if (lengthKind == ArgumentLengthKind::VariadicOfVariadic) {
      body << llvm::formatv(variadicOfVariadicTypeParserCode, listName);
    } else if (lengthKind == ArgumentLengthKind::Variadic) {
      body << llvm::formatv(variadicTypeParserCode, listName);
    } else if (lengthKind == ArgumentLengthKind::Optional) {
      body << llvm::formatv(optionalTypeParserCode, listName);
    } else {
      // Single types are parsed as the variable's constraint C++ class when
      // the operand is an operand/result variable, else as `::mlir::Type`.
      const char *parserCode =
          dir->shouldBeQualified() ? qualifiedTypeParserCode : typeParserCode;
      TypeSwitch<Element *>(dir->getOperand())
          .Case<OperandVariable, ResultVariable>([&](auto operand) {
            body << formatv(parserCode,
                            operand->getVar()->constraint.getCPPClassName(),
                            listName);
          })
          .Default([&](auto operand) {
            body << formatv(parserCode, "::mlir::Type", listName);
          });
    }
  } else if (auto *dir = dyn_cast<FunctionalTypeDirective>(element)) {
    ArgumentLengthKind ignored;
    body << formatv(functionalTypeParserCode,
                    getTypeListName(dir->getInputs(), ignored),
                    getTypeListName(dir->getResults(), ignored));
  } else {
    llvm_unreachable("unknown format element");
  }
}
1429 
1430 void OperationFormat::genParserTypeResolution(Operator &op, MethodBody &body) {
1431   // If any of type resolutions use transformed variables, make sure that the
1432   // types of those variables are resolved.
1433   SmallPtrSet<const NamedTypeConstraint *, 8> verifiedVariables;
1434   FmtContext verifierFCtx;
1435   for (TypeResolution &resolver :
1436        llvm::concat<TypeResolution>(resultTypes, operandTypes)) {
1437     Optional<StringRef> transformer = resolver.getVarTransformer();
1438     if (!transformer)
1439       continue;
1440     // Ensure that we don't verify the same variables twice.
1441     const NamedTypeConstraint *variable = resolver.getVariable();
1442     if (!variable || !verifiedVariables.insert(variable).second)
1443       continue;
1444 
1445     auto constraint = variable->constraint;
1446     body << "  for (::mlir::Type type : " << variable->name << "Types) {\n"
1447          << "    (void)type;\n"
1448          << "    if (!("
1449          << tgfmt(constraint.getConditionTemplate(),
1450                   &verifierFCtx.withSelf("type"))
1451          << ")) {\n"
1452          << formatv("      return parser.emitError(parser.getNameLoc()) << "
1453                     "\"'{0}' must be {1}, but got \" << type;\n",
1454                     variable->name, constraint.getSummary())
1455          << "    }\n"
1456          << "  }\n";
1457   }
1458 
1459   // Initialize the set of buildable types.
1460   if (!buildableTypes.empty()) {
1461     FmtContext typeBuilderCtx;
1462     typeBuilderCtx.withBuilder("parser.getBuilder()");
1463     for (auto &it : buildableTypes)
1464       body << "  ::mlir::Type odsBuildableType" << it.second << " = "
1465            << tgfmt(it.first, &typeBuilderCtx) << ";\n";
1466   }
1467 
1468   // Emit the code necessary for a type resolver.
1469   auto emitTypeResolver = [&](TypeResolution &resolver, StringRef curVar) {
1470     if (Optional<int> val = resolver.getBuilderIdx()) {
1471       body << "odsBuildableType" << *val;
1472     } else if (const NamedTypeConstraint *var = resolver.getVariable()) {
1473       if (Optional<StringRef> tform = resolver.getVarTransformer()) {
1474         FmtContext fmtContext;
1475         fmtContext.addSubst("_ctxt", "parser.getContext()");
1476         if (var->isVariadic())
1477           fmtContext.withSelf(var->name + "Types");
1478         else
1479           fmtContext.withSelf(var->name + "Types[0]");
1480         body << tgfmt(*tform, &fmtContext);
1481       } else {
1482         body << var->name << "Types";
1483       }
1484     } else if (const NamedAttribute *attr = resolver.getAttribute()) {
1485       if (Optional<StringRef> tform = resolver.getVarTransformer())
1486         body << tgfmt(*tform,
1487                       &FmtContext().withSelf(attr->name + "Attr.getType()"));
1488       else
1489         body << attr->name << "Attr.getType()";
1490     } else {
1491       body << curVar << "Types";
1492     }
1493   };
1494 
1495   // Resolve each of the result types.
1496   if (!infersResultTypes) {
1497     if (allResultTypes) {
1498       body << "  result.addTypes(allResultTypes);\n";
1499     } else {
1500       for (unsigned i = 0, e = op.getNumResults(); i != e; ++i) {
1501         body << "  result.addTypes(";
1502         emitTypeResolver(resultTypes[i], op.getResultName(i));
1503         body << ");\n";
1504       }
1505     }
1506   }
1507 
1508   // Emit the operand type resolutions.
1509   genParserOperandTypeResolution(op, body, emitTypeResolver);
1510 
1511   // Handle return type inference once all operands have been resolved
1512   if (infersResultTypes)
1513     body << formatv(inferReturnTypesParserCode, op.getCppClassName());
1514 }
1515 
1516 void OperationFormat::genParserOperandTypeResolution(
1517     Operator &op, MethodBody &body,
1518     function_ref<void(TypeResolution &, StringRef)> emitTypeResolver) {
1519   // Early exit if there are no operands.
1520   if (op.getNumOperands() == 0)
1521     return;
1522 
1523   // Handle the case where all operand types are grouped together with
1524   // "types(operands)".
1525   if (allOperandTypes) {
1526     // If `operands` was specified, use the full operand list directly.
1527     if (allOperands) {
1528       body << "  if (parser.resolveOperands(allOperands, allOperandTypes, "
1529               "allOperandLoc, result.operands))\n"
1530               "    return ::mlir::failure();\n";
1531       return;
1532     }
1533 
1534     // Otherwise, use llvm::concat to merge the disjoint operand lists together.
1535     // llvm::concat does not allow the case of a single range, so guard it here.
1536     body << "  if (parser.resolveOperands(";
1537     if (op.getNumOperands() > 1) {
1538       body << "::llvm::concat<const ::mlir::OpAsmParser::OperandType>(";
1539       llvm::interleaveComma(op.getOperands(), body, [&](auto &operand) {
1540         body << operand.name << "Operands";
1541       });
1542       body << ")";
1543     } else {
1544       body << op.operand_begin()->name << "Operands";
1545     }
1546     body << ", allOperandTypes, parser.getNameLoc(), result.operands))\n"
1547          << "    return ::mlir::failure();\n";
1548     return;
1549   }
1550 
1551   // Handle the case where all operands are grouped together with "operands".
1552   if (allOperands) {
1553     body << "  if (parser.resolveOperands(allOperands, ";
1554 
1555     // Group all of the operand types together to perform the resolution all at
1556     // once. Use llvm::concat to perform the merge. llvm::concat does not allow
1557     // the case of a single range, so guard it here.
1558     if (op.getNumOperands() > 1) {
1559       body << "::llvm::concat<const ::mlir::Type>(";
1560       llvm::interleaveComma(
1561           llvm::seq<int>(0, op.getNumOperands()), body, [&](int i) {
1562             body << "::llvm::ArrayRef<::mlir::Type>(";
1563             emitTypeResolver(operandTypes[i], op.getOperand(i).name);
1564             body << ")";
1565           });
1566       body << ")";
1567     } else {
1568       emitTypeResolver(operandTypes.front(), op.getOperand(0).name);
1569     }
1570 
1571     body << ", allOperandLoc, result.operands))\n"
1572          << "    return ::mlir::failure();\n";
1573     return;
1574   }
1575 
1576   // The final case is the one where each of the operands types are resolved
1577   // separately.
1578   for (unsigned i = 0, e = op.getNumOperands(); i != e; ++i) {
1579     NamedTypeConstraint &operand = op.getOperand(i);
1580     body << "  if (parser.resolveOperands(" << operand.name << "Operands, ";
1581 
1582     // Resolve the type of this operand.
1583     TypeResolution &operandType = operandTypes[i];
1584     emitTypeResolver(operandType, operand.name);
1585 
1586     // If the type is resolved by a non-variadic variable, index into the
1587     // resolved type list. This allows for resolving the types of a variadic
1588     // operand list from a non-variadic variable.
1589     bool verifyOperandAndTypeSize = true;
1590     if (auto *resolverVar = operandType.getVariable()) {
1591       if (!resolverVar->isVariadic() && !operandType.getVarTransformer()) {
1592         body << "[0]";
1593         verifyOperandAndTypeSize = false;
1594       }
1595     } else {
1596       verifyOperandAndTypeSize = !operandType.getBuilderIdx();
1597     }
1598 
1599     // Check to see if the sizes between the types and operands must match. If
1600     // they do, provide the operand location to select the proper resolution
1601     // overload.
1602     if (verifyOperandAndTypeSize)
1603       body << ", " << operand.name << "OperandsLoc";
1604     body << ", result.operands))\n    return ::mlir::failure();\n";
1605   }
1606 }
1607 
1608 void OperationFormat::genParserRegionResolution(Operator &op,
1609                                                 MethodBody &body) {
1610   // Check for the case where all regions were parsed.
1611   bool hasAllRegions = llvm::any_of(
1612       elements, [](auto &elt) { return isa<RegionsDirective>(elt.get()); });
1613   if (hasAllRegions) {
1614     body << "  result.addRegions(fullRegions);\n";
1615     return;
1616   }
1617 
1618   // Otherwise, handle each region individually.
1619   for (const NamedRegion &region : op.getRegions()) {
1620     if (region.isVariadic())
1621       body << "  result.addRegions(" << region.name << "Regions);\n";
1622     else
1623       body << "  result.addRegion(std::move(" << region.name << "Region));\n";
1624   }
1625 }
1626 
1627 void OperationFormat::genParserSuccessorResolution(Operator &op,
1628                                                    MethodBody &body) {
1629   // Check for the case where all successors were parsed.
1630   bool hasAllSuccessors = llvm::any_of(
1631       elements, [](auto &elt) { return isa<SuccessorsDirective>(elt.get()); });
1632   if (hasAllSuccessors) {
1633     body << "  result.addSuccessors(fullSuccessors);\n";
1634     return;
1635   }
1636 
1637   // Otherwise, handle each successor individually.
1638   for (const NamedSuccessor &successor : op.getSuccessors()) {
1639     if (successor.isVariadic())
1640       body << "  result.addSuccessors(" << successor.name << "Successors);\n";
1641     else
1642       body << "  result.addSuccessors(" << successor.name << "Successor);\n";
1643   }
1644 }
1645 
1646 void OperationFormat::genParserVariadicSegmentResolution(Operator &op,
1647                                                          MethodBody &body) {
1648   if (!allOperands) {
1649     if (op.getTrait("::mlir::OpTrait::AttrSizedOperandSegments")) {
1650       body << "  result.addAttribute(\"operand_segment_sizes\", "
1651            << "parser.getBuilder().getI32VectorAttr({";
1652       auto interleaveFn = [&](const NamedTypeConstraint &operand) {
1653         // If the operand is variadic emit the parsed size.
1654         if (operand.isVariableLength())
1655           body << "static_cast<int32_t>(" << operand.name << "Operands.size())";
1656         else
1657           body << "1";
1658       };
1659       llvm::interleaveComma(op.getOperands(), body, interleaveFn);
1660       body << "}));\n";
1661     }
1662     for (const NamedTypeConstraint &operand : op.getOperands()) {
1663       if (!operand.isVariadicOfVariadic())
1664         continue;
1665       body << llvm::formatv(
1666           "  result.addAttribute(\"{0}\", "
1667           "parser.getBuilder().getI32TensorAttr({1}OperandGroupSizes));\n",
1668           operand.constraint.getVariadicOfVariadicSegmentSizeAttr(),
1669           operand.name);
1670     }
1671   }
1672 
1673   if (!allResultTypes &&
1674       op.getTrait("::mlir::OpTrait::AttrSizedResultSegments")) {
1675     body << "  result.addAttribute(\"result_segment_sizes\", "
1676          << "parser.getBuilder().getI32VectorAttr({";
1677     auto interleaveFn = [&](const NamedTypeConstraint &result) {
1678       // If the result is variadic emit the parsed size.
1679       if (result.isVariableLength())
1680         body << "static_cast<int32_t>(" << result.name << "Types.size())";
1681       else
1682         body << "1";
1683     };
1684     llvm::interleaveComma(op.getResults(), body, interleaveFn);
1685     body << "}));\n";
1686   }
1687 }
1688 
1689 //===----------------------------------------------------------------------===//
1690 // PrinterGen
1691 
1692 /// The code snippet used to generate a printer call for a region of an
1693 // operation that has the SingleBlockImplicitTerminator trait.
1694 ///
1695 /// {0}: The name of the region.
1696 const char *regionSingleBlockImplicitTerminatorPrinterCode = R"(
1697   {
1698     bool printTerminator = true;
1699     if (auto *term = {0}.empty() ? nullptr : {0}.begin()->getTerminator()) {{
1700       printTerminator = !term->getAttrDictionary().empty() ||
1701                         term->getNumOperands() != 0 ||
1702                         term->getNumResults() != 0;
1703     }
1704     _odsPrinter.printRegion({0}, /*printEntryBlockArgs=*/true,
1705       /*printBlockTerminators=*/printTerminator);
1706   }
1707 )";
1708 
/// The code snippet that starts the generated printer for every enum
/// attribute (see genEnumAttrPrinter).
///
/// It opens a C++ scope and binds `caseValue` (the enum value) and
/// `caseValueStr` (its case string). genEnumAttrPrinter emits the code that
/// prints `caseValueStr` (quoted when it can't be represented with a keyword)
/// and closes the scope.
///
/// {0}: The accessor for the enum attribute value, prefixed with `*` when the
///      attribute is optional.
/// {1}: The name of the enum attribute's symbolToString function.
const char *enumAttrBeginPrinterCode = R"(
  {
    auto caseValue = {0}();
    auto caseValueStr = {1}(caseValue);
)";
1719 
1720 /// Generate the printer for the 'attr-dict' directive.
1721 static void genAttrDictPrinter(OperationFormat &fmt, Operator &op,
1722                                MethodBody &body, bool withKeyword) {
1723   body << "  _odsPrinter.printOptionalAttrDict"
1724        << (withKeyword ? "WithKeyword" : "")
1725        << "((*this)->getAttrs(), /*elidedAttrs=*/{";
1726   // Elide the variadic segment size attributes if necessary.
1727   if (!fmt.allOperands &&
1728       op.getTrait("::mlir::OpTrait::AttrSizedOperandSegments"))
1729     body << "\"operand_segment_sizes\", ";
1730   if (!fmt.allResultTypes &&
1731       op.getTrait("::mlir::OpTrait::AttrSizedResultSegments"))
1732     body << "\"result_segment_sizes\", ";
1733   if (!fmt.inferredAttributes.empty()) {
1734     for (const auto &attr : fmt.inferredAttributes)
1735       body << "\"" << attr.getKey() << "\", ";
1736   }
1737   llvm::interleaveComma(
1738       fmt.usedAttributes, body,
1739       [&](const NamedAttribute *attr) { body << "\"" << attr->name << "\""; });
1740   body << "});\n";
1741 }
1742 
1743 /// Generate the printer for a literal value. `shouldEmitSpace` is true if a
1744 /// space should be emitted before this element. `lastWasPunctuation` is true if
1745 /// the previous element was a punctuation literal.
1746 static void genLiteralPrinter(StringRef value, MethodBody &body,
1747                               bool &shouldEmitSpace, bool &lastWasPunctuation) {
1748   body << "  _odsPrinter";
1749 
1750   // Don't insert a space for certain punctuation.
1751   if (shouldEmitSpace && shouldEmitSpaceBefore(value, lastWasPunctuation))
1752     body << " << ' '";
1753   body << " << \"" << value << "\";\n";
1754 
1755   // Insert a space after certain literals.
1756   shouldEmitSpace =
1757       value.size() != 1 || !StringRef("<({[").contains(value.front());
1758   lastWasPunctuation = !(value.front() == '_' || isalpha(value.front()));
1759 }
1760 
1761 /// Generate the printer for a space. `shouldEmitSpace` and `lastWasPunctuation`
1762 /// are set to false.
1763 static void genSpacePrinter(bool value, MethodBody &body, bool &shouldEmitSpace,
1764                             bool &lastWasPunctuation) {
1765   if (value) {
1766     body << "  _odsPrinter << ' ';\n";
1767     lastWasPunctuation = false;
1768   } else {
1769     lastWasPunctuation = true;
1770   }
1771   shouldEmitSpace = false;
1772 }
1773 
1774 /// Generate the printer for a custom directive parameter.
1775 static void genCustomDirectiveParameterPrinter(Element *element,
1776                                                const Operator &op,
1777                                                MethodBody &body) {
1778   if (auto *attr = dyn_cast<AttributeVariable>(element)) {
1779     body << op.getGetterName(attr->getVar()->name) << "Attr()";
1780 
1781   } else if (isa<AttrDictDirective>(element)) {
1782     body << "getOperation()->getAttrDictionary()";
1783 
1784   } else if (auto *operand = dyn_cast<OperandVariable>(element)) {
1785     body << op.getGetterName(operand->getVar()->name) << "()";
1786 
1787   } else if (auto *region = dyn_cast<RegionVariable>(element)) {
1788     body << op.getGetterName(region->getVar()->name) << "()";
1789 
1790   } else if (auto *successor = dyn_cast<SuccessorVariable>(element)) {
1791     body << op.getGetterName(successor->getVar()->name) << "()";
1792 
1793   } else if (auto *dir = dyn_cast<RefDirective>(element)) {
1794     genCustomDirectiveParameterPrinter(dir->getOperand(), op, body);
1795 
1796   } else if (auto *dir = dyn_cast<TypeDirective>(element)) {
1797     auto *typeOperand = dir->getOperand();
1798     auto *operand = dyn_cast<OperandVariable>(typeOperand);
1799     auto *var = operand ? operand->getVar()
1800                         : cast<ResultVariable>(typeOperand)->getVar();
1801     std::string name = op.getGetterName(var->name);
1802     if (var->isVariadic())
1803       body << name << "().getTypes()";
1804     else if (var->isOptional())
1805       body << llvm::formatv("({0}() ? {0}().getType() : Type())", name);
1806     else
1807       body << name << "().getType()";
1808   } else {
1809     llvm_unreachable("unknown custom directive parameter");
1810   }
1811 }
1812 
1813 /// Generate the printer for a custom directive.
1814 static void genCustomDirectivePrinter(CustomDirective *customDir,
1815                                       const Operator &op, MethodBody &body) {
1816   body << "  print" << customDir->getName() << "(_odsPrinter, *this";
1817   for (Element &param : customDir->getArguments()) {
1818     body << ", ";
1819     genCustomDirectiveParameterPrinter(&param, op, body);
1820   }
1821   body << ");\n";
1822 }
1823 
1824 /// Generate the printer for a region with the given variable name.
1825 static void genRegionPrinter(const Twine &regionName, MethodBody &body,
1826                              bool hasImplicitTermTrait) {
1827   if (hasImplicitTermTrait)
1828     body << llvm::formatv(regionSingleBlockImplicitTerminatorPrinterCode,
1829                           regionName);
1830   else
1831     body << "  _odsPrinter.printRegion(" << regionName << ");\n";
1832 }
1833 static void genVariadicRegionPrinter(const Twine &regionListName,
1834                                      MethodBody &body,
1835                                      bool hasImplicitTermTrait) {
1836   body << "    llvm::interleaveComma(" << regionListName
1837        << ", _odsPrinter, [&](::mlir::Region &region) {\n      ";
1838   genRegionPrinter("region", body, hasImplicitTermTrait);
1839   body << "    });\n";
1840 }
1841 
1842 /// Generate the C++ for an operand to a (*-)type directive.
1843 static MethodBody &genTypeOperandPrinter(Element *arg, const Operator &op,
1844                                          MethodBody &body,
1845                                          bool useArrayRef = true) {
1846   if (isa<OperandsDirective>(arg))
1847     return body << "getOperation()->getOperandTypes()";
1848   if (isa<ResultsDirective>(arg))
1849     return body << "getOperation()->getResultTypes()";
1850   auto *operand = dyn_cast<OperandVariable>(arg);
1851   auto *var = operand ? operand->getVar() : cast<ResultVariable>(arg)->getVar();
1852   if (var->isVariadicOfVariadic())
1853     return body << llvm::formatv("{0}().join().getTypes()",
1854                                  op.getGetterName(var->name));
1855   if (var->isVariadic())
1856     return body << op.getGetterName(var->name) << "().getTypes()";
1857   if (var->isOptional())
1858     return body << llvm::formatv(
1859                "({0}() ? ::llvm::ArrayRef<::mlir::Type>({0}().getType()) : "
1860                "::llvm::ArrayRef<::mlir::Type>())",
1861                op.getGetterName(var->name));
1862   if (useArrayRef)
1863     return body << "::llvm::ArrayRef<::mlir::Type>("
1864                 << op.getGetterName(var->name) << "().getType())";
1865   return body << op.getGetterName(var->name) << "().getType()";
1866 }
1867 
1868 /// Generate the printer for an enum attribute.
1869 static void genEnumAttrPrinter(const NamedAttribute *var, const Operator &op,
1870                                MethodBody &body) {
1871   Attribute baseAttr = var->attr.getBaseAttr();
1872   const EnumAttr &enumAttr = cast<EnumAttr>(baseAttr);
1873   std::vector<EnumAttrCase> cases = enumAttr.getAllCases();
1874 
1875   body << llvm::formatv(enumAttrBeginPrinterCode,
1876                         (var->attr.isOptional() ? "*" : "") +
1877                             op.getGetterName(var->name),
1878                         enumAttr.getSymbolToStringFnName());
1879 
1880   // Get a string containing all of the cases that can't be represented with a
1881   // keyword.
1882   llvm::BitVector nonKeywordCases(cases.size());
1883   bool hasStrCase = false;
1884   for (auto &it : llvm::enumerate(cases)) {
1885     hasStrCase = it.value().isStrCase();
1886     if (!canFormatStringAsKeyword(it.value().getStr()))
1887       nonKeywordCases.set(it.index());
1888   }
1889 
1890   // If this is a string enum, use the case string to determine which cases
1891   // need to use the string form.
1892   if (hasStrCase) {
1893     if (nonKeywordCases.any()) {
1894       body << "    if (llvm::is_contained(llvm::ArrayRef<llvm::StringRef>(";
1895       llvm::interleaveComma(nonKeywordCases.set_bits(), body, [&](unsigned it) {
1896         body << '"' << cases[it].getStr() << '"';
1897       });
1898       body << ")))\n"
1899               "      _odsPrinter << '\"' << caseValueStr << '\"';\n"
1900               "    else\n  ";
1901     }
1902     body << "    _odsPrinter << caseValueStr;\n"
1903             "  }\n";
1904     return;
1905   }
1906 
1907   // Otherwise if this is a bit enum attribute, don't allow cases that may
1908   // overlap with other cases. For simplicity sake, only allow cases with a
1909   // single bit value.
1910   if (enumAttr.isBitEnum()) {
1911     for (auto &it : llvm::enumerate(cases)) {
1912       int64_t value = it.value().getValue();
1913       if (value < 0 || !llvm::isPowerOf2_64(value))
1914         nonKeywordCases.set(it.index());
1915     }
1916   }
1917 
1918   // If there are any cases that can't be used with a keyword, switch on the
1919   // case value to determine when to print in the string form.
1920   if (nonKeywordCases.any()) {
1921     body << "    switch (caseValue) {\n";
1922     StringRef cppNamespace = enumAttr.getCppNamespace();
1923     StringRef enumName = enumAttr.getEnumClassName();
1924     for (auto &it : llvm::enumerate(cases)) {
1925       if (nonKeywordCases.test(it.index()))
1926         continue;
1927       StringRef symbol = it.value().getSymbol();
1928       body << llvm::formatv("    case {0}::{1}::{2}:\n", cppNamespace, enumName,
1929                             llvm::isDigit(symbol.front()) ? ("_" + symbol)
1930                                                           : symbol);
1931     }
1932     body << "      _odsPrinter << caseValueStr;\n"
1933             "      break;\n"
1934             "    default:\n"
1935             "      _odsPrinter << '\"' << caseValueStr << '\"';\n"
1936             "      break;\n"
1937             "    }\n"
1938             "  }\n";
1939     return;
1940   }
1941 
1942   body << "    _odsPrinter << caseValueStr;\n"
1943           "  }\n";
1944 }
1945 
1946 /// Generate the check for the anchor of an optional group.
1947 static void genOptionalGroupPrinterAnchor(Element *anchor, const Operator &op,
1948                                           MethodBody &body) {
1949   TypeSwitch<Element *>(anchor)
1950       .Case<OperandVariable, ResultVariable>([&](auto *element) {
1951         const NamedTypeConstraint *var = element->getVar();
1952         std::string name = op.getGetterName(var->name);
1953         if (var->isOptional())
1954           body << "  if (" << name << "()) {\n";
1955         else if (var->isVariadic())
1956           body << "  if (!" << name << "().empty()) {\n";
1957       })
1958       .Case<RegionVariable>([&](RegionVariable *element) {
1959         const NamedRegion *var = element->getVar();
1960         std::string name = op.getGetterName(var->name);
1961         // TODO: Add a check for optional regions here when ODS supports it.
1962         body << "  if (!" << name << "().empty()) {\n";
1963       })
1964       .Case<TypeDirective>([&](TypeDirective *element) {
1965         genOptionalGroupPrinterAnchor(element->getOperand(), op, body);
1966       })
1967       .Case<FunctionalTypeDirective>([&](FunctionalTypeDirective *element) {
1968         genOptionalGroupPrinterAnchor(element->getInputs(), op, body);
1969       })
1970       .Case<AttributeVariable>([&](AttributeVariable *attr) {
1971         body << "  if ((*this)->getAttr(\"" << attr->getVar()->name
1972              << "\")) {\n";
1973       });
1974 }
1975 
1976 void OperationFormat::genElementPrinter(Element *element, MethodBody &body,
1977                                         Operator &op, bool &shouldEmitSpace,
1978                                         bool &lastWasPunctuation) {
1979   if (LiteralElement *literal = dyn_cast<LiteralElement>(element))
1980     return genLiteralPrinter(literal->getLiteral(), body, shouldEmitSpace,
1981                              lastWasPunctuation);
1982 
1983   // Emit a whitespace element.
1984   if (isa<NewlineElement>(element)) {
1985     body << "  _odsPrinter.printNewline();\n";
1986     return;
1987   }
1988   if (SpaceElement *space = dyn_cast<SpaceElement>(element))
1989     return genSpacePrinter(space->getValue(), body, shouldEmitSpace,
1990                            lastWasPunctuation);
1991 
1992   // Emit an optional group.
1993   if (OptionalElement *optional = dyn_cast<OptionalElement>(element)) {
1994     // Emit the check for the presence of the anchor element.
1995     Element *anchor = optional->getAnchor();
1996     genOptionalGroupPrinterAnchor(anchor, op, body);
1997 
1998     // If the anchor is a unit attribute, we don't need to print it. When
1999     // parsing, we will add this attribute if this group is present.
2000     auto elements = optional->getThenElements();
2001     Element *elidedAnchorElement = nullptr;
2002     auto *anchorAttr = dyn_cast<AttributeVariable>(anchor);
2003     if (anchorAttr && anchorAttr != &*elements.begin() &&
2004         anchorAttr->isUnitAttr()) {
2005       elidedAnchorElement = anchorAttr;
2006     }
2007 
2008     // Emit each of the elements.
2009     for (Element &childElement : elements) {
2010       if (&childElement != elidedAnchorElement) {
2011         genElementPrinter(&childElement, body, op, shouldEmitSpace,
2012                           lastWasPunctuation);
2013       }
2014     }
2015     body << "  }";
2016 
2017     // Emit each of the else elements.
2018     auto elseElements = optional->getElseElements();
2019     if (!elseElements.empty()) {
2020       body << " else {\n";
2021       for (Element &childElement : elseElements) {
2022         genElementPrinter(&childElement, body, op, shouldEmitSpace,
2023                           lastWasPunctuation);
2024       }
2025       body << "  }";
2026     }
2027 
2028     body << "\n";
2029     return;
2030   }
2031 
2032   // Emit the attribute dictionary.
2033   if (auto *attrDict = dyn_cast<AttrDictDirective>(element)) {
2034     genAttrDictPrinter(*this, op, body, attrDict->isWithKeyword());
2035     lastWasPunctuation = false;
2036     return;
2037   }
2038 
2039   // Optionally insert a space before the next element. The AttrDict printer
2040   // already adds a space as necessary.
2041   if (shouldEmitSpace || !lastWasPunctuation)
2042     body << "  _odsPrinter << ' ';\n";
2043   lastWasPunctuation = false;
2044   shouldEmitSpace = true;
2045 
2046   if (auto *attr = dyn_cast<AttributeVariable>(element)) {
2047     const NamedAttribute *var = attr->getVar();
2048 
2049     // If we are formatting as an enum, symbolize the attribute as a string.
2050     if (canFormatEnumAttr(var))
2051       return genEnumAttrPrinter(var, op, body);
2052 
2053     // If we are formatting as a symbol name, handle it as a symbol name.
2054     if (shouldFormatSymbolNameAttr(var)) {
2055       body << "  _odsPrinter.printSymbolName(" << op.getGetterName(var->name)
2056            << "Attr().getValue());\n";
2057       return;
2058     }
2059 
2060     // Elide the attribute type if it is buildable.
2061     if (attr->getTypeBuilder())
2062       body << "  _odsPrinter.printAttributeWithoutType("
2063            << op.getGetterName(var->name) << "Attr());\n";
2064     else if (attr->shouldBeQualified() ||
2065              var->attr.getStorageType() == "::mlir::Attribute")
2066       body << "  _odsPrinter.printAttribute(" << op.getGetterName(var->name)
2067            << "Attr());\n";
2068     else
2069       body << "_odsPrinter.printStrippedAttrOrType("
2070            << op.getGetterName(var->name) << "Attr());\n";
2071   } else if (auto *operand = dyn_cast<OperandVariable>(element)) {
2072     if (operand->getVar()->isVariadicOfVariadic()) {
2073       body << "  ::llvm::interleaveComma("
2074            << op.getGetterName(operand->getVar()->name)
2075            << "(), _odsPrinter, [&](const auto &operands) { _odsPrinter << "
2076               "\"(\" << operands << "
2077               "\")\"; });\n";
2078 
2079     } else if (operand->getVar()->isOptional()) {
2080       body << "  if (::mlir::Value value = "
2081            << op.getGetterName(operand->getVar()->name) << "())\n"
2082            << "    _odsPrinter << value;\n";
2083     } else {
2084       body << "  _odsPrinter << " << op.getGetterName(operand->getVar()->name)
2085            << "();\n";
2086     }
2087   } else if (auto *region = dyn_cast<RegionVariable>(element)) {
2088     const NamedRegion *var = region->getVar();
2089     std::string name = op.getGetterName(var->name);
2090     if (var->isVariadic()) {
2091       genVariadicRegionPrinter(name + "()", body, hasImplicitTermTrait);
2092     } else {
2093       genRegionPrinter(name + "()", body, hasImplicitTermTrait);
2094     }
2095   } else if (auto *successor = dyn_cast<SuccessorVariable>(element)) {
2096     const NamedSuccessor *var = successor->getVar();
2097     std::string name = op.getGetterName(var->name);
2098     if (var->isVariadic())
2099       body << "  ::llvm::interleaveComma(" << name << "(), _odsPrinter);\n";
2100     else
2101       body << "  _odsPrinter << " << name << "();\n";
2102   } else if (auto *dir = dyn_cast<CustomDirective>(element)) {
2103     genCustomDirectivePrinter(dir, op, body);
2104   } else if (isa<OperandsDirective>(element)) {
2105     body << "  _odsPrinter << getOperation()->getOperands();\n";
2106   } else if (isa<RegionsDirective>(element)) {
2107     genVariadicRegionPrinter("getOperation()->getRegions()", body,
2108                              hasImplicitTermTrait);
2109   } else if (isa<SuccessorsDirective>(element)) {
2110     body << "  ::llvm::interleaveComma(getOperation()->getSuccessors(), "
2111             "_odsPrinter);\n";
2112   } else if (auto *dir = dyn_cast<TypeDirective>(element)) {
2113     if (auto *operand = dyn_cast<OperandVariable>(dir->getOperand())) {
2114       if (operand->getVar()->isVariadicOfVariadic()) {
2115         body << llvm::formatv(
2116             "  ::llvm::interleaveComma({0}().getTypes(), _odsPrinter, "
2117             "[&](::mlir::TypeRange types) {{ _odsPrinter << \"(\" << "
2118             "types << \")\"; });\n",
2119             op.getGetterName(operand->getVar()->name));
2120         return;
2121       }
2122     }
2123     const NamedTypeConstraint *var = nullptr;
2124     {
2125       if (auto *operand = dyn_cast<OperandVariable>(dir->getOperand()))
2126         var = operand->getVar();
2127       else if (auto *operand = dyn_cast<ResultVariable>(dir->getOperand()))
2128         var = operand->getVar();
2129     }
2130     if (var && !var->isVariadicOfVariadic() && !var->isVariadic() &&
2131         !var->isOptional()) {
2132       std::string cppClass = var->constraint.getCPPClassName();
2133       if (dir->shouldBeQualified()) {
2134         body << "   _odsPrinter << " << op.getGetterName(var->name)
2135              << "().getType();\n";
2136         return;
2137       }
2138       body << "  {\n"
2139            << "    auto type = " << op.getGetterName(var->name)
2140            << "().getType();\n"
2141            << "    if (auto validType = type.dyn_cast<" << cppClass << ">())\n"
2142            << "      _odsPrinter.printStrippedAttrOrType(validType);\n"
2143            << "   else\n"
2144            << "     _odsPrinter << type;\n"
2145            << "  }\n";
2146       return;
2147     }
2148     body << "  _odsPrinter << ";
2149     genTypeOperandPrinter(dir->getOperand(), op, body, /*useArrayRef=*/false)
2150         << ";\n";
2151   } else if (auto *dir = dyn_cast<FunctionalTypeDirective>(element)) {
2152     body << "  _odsPrinter.printFunctionalType(";
2153     genTypeOperandPrinter(dir->getInputs(), op, body) << ", ";
2154     genTypeOperandPrinter(dir->getResults(), op, body) << ");\n";
2155   } else {
2156     llvm_unreachable("unknown format element");
2157   }
2158 }
2159 
2160 void OperationFormat::genPrinter(Operator &op, OpClass &opClass) {
2161   auto *method = opClass.addMethod(
2162       "void", "print",
2163       MethodParameter("::mlir::OpAsmPrinter &", "_odsPrinter"));
2164   auto &body = method->body();
2165 
2166   // Flags for if we should emit a space, and if the last element was
2167   // punctuation.
2168   bool shouldEmitSpace = true, lastWasPunctuation = false;
2169   for (auto &element : elements)
2170     genElementPrinter(element.get(), body, op, shouldEmitSpace,
2171                       lastWasPunctuation);
2172 }
2173 
2174 //===----------------------------------------------------------------------===//
2175 // FormatParser
2176 //===----------------------------------------------------------------------===//
2177 
2178 /// Function to find an element within the given range that has the same name as
2179 /// 'name'.
2180 template <typename RangeT> static auto findArg(RangeT &&range, StringRef name) {
2181   auto it = llvm::find_if(range, [=](auto &arg) { return arg.name == name; });
2182   return it != range.end() ? &*it : nullptr;
2183 }
2184 
2185 namespace {
2186 /// This class implements a parser for an instance of an operation assembly
2187 /// format.
2188 class FormatParser {
2189 public:
2190   FormatParser(llvm::SourceMgr &mgr, OperationFormat &format, Operator &op)
2191       : lexer(mgr, op.getLoc()[0]), curToken(lexer.lexToken()), fmt(format),
2192         op(op), seenOperandTypes(op.getNumOperands()),
2193         seenResultTypes(op.getNumResults()) {}
2194 
  /// Parse the operation assembly format into the OperationFormat given at
  /// construction, then verify the result: an `attr-dict` directive must be
  /// present, and attributes, results, operands, regions, and successors are
  /// each checked in turn. Returns failure after emitting a diagnostic.
  LogicalResult parse();

private:
  /// The current context of the parser when parsing an element. The context
  /// restricts which elements are legal, e.g. literals and optional groups are
  /// only accepted in TopLevelContext.
  enum ParserContext {
    /// The element is being parsed in a "top-level" context, i.e. at the top of
    /// the format or in an optional group.
    TopLevelContext,
    /// The element is being parsed as a custom directive child.
    CustomDirectiveContext,
    /// The element is being parsed as a type directive child.
    TypeDirectiveContext,
    /// The element is being parsed as a reference directive child.
    RefDirectiveContext
  };

  /// This struct represents a type resolution instance. It includes a specific
  /// type as well as an optional transformer to apply to that type in order to
  /// properly resolve the type of a variable.
  /// `resolver` is a seen operand, result, or attribute whose type is reused;
  /// `transformer` is set only for `TypesMatchWith` constraints.
  struct TypeResolutionInstance {
    ConstArgument resolver;
    Optional<StringRef> transformer;
  };

  /// An iterator over the elements of a format group. Dereferencing yields the
  /// `Element` itself rather than the owning `std::unique_ptr`.
  using ElementsIterT = llvm::pointee_iterator<
      std::vector<std::unique_ptr<Element>>::const_iterator>;

  /// Verify the state of operation attributes within the format: reject a `:`
  /// literal following an attribute without a buildable type, and record the
  /// inferred segment-size attributes of VariadicOfVariadic operands.
  LogicalResult verifyAttributes(llvm::SMLoc loc);
  /// Verify the attribute elements at the back of the given stack of iterators.
  /// Pushes new entries for optional groups and pops the back entry once it is
  /// exhausted.
  LogicalResult verifyAttributes(
      llvm::SMLoc loc,
      SmallVectorImpl<std::pair<ElementsIterT, ElementsIterT>> &iteratorStack);

  /// Verify the state of operation operands within the format: each operand
  /// must appear, and its type must be printed, resolvable through
  /// `variableTyResolver`, or buildable.
  LogicalResult
  verifyOperands(llvm::SMLoc loc,
                 llvm::StringMap<TypeResolutionInstance> &variableTyResolver);

  /// Verify the state of operation regions within the format.
  LogicalResult verifyRegions(llvm::SMLoc loc);

  /// Verify the state of operation results within the format: each result
  /// type must be printed, resolvable through `variableTyResolver`, or
  /// buildable, unless all result types are inferred.
  LogicalResult
  verifyResults(llvm::SMLoc loc,
                llvm::StringMap<TypeResolutionInstance> &variableTyResolver);

  /// Verify the state of operation successors within the format.
  LogicalResult verifySuccessors(llvm::SMLoc loc);

  /// Given the values of an `AllTypesMatch` trait, check for inferable type
  /// resolution.
  void handleAllTypesMatchConstraint(
      ArrayRef<StringRef> values,
      llvm::StringMap<TypeResolutionInstance> &variableTyResolver);
  /// Check for inferable type resolution given all operands, and or results,
  /// have the same type. If 'includeResults' is true, the results also have the
  /// same type as all of the operands.
  void handleSameTypesConstraint(
      llvm::StringMap<TypeResolutionInstance> &variableTyResolver,
      bool includeResults);
  /// Check for inferable type resolution based on another operand, result, or
  /// attribute.
  void handleTypesMatchConstraint(
      llvm::StringMap<TypeResolutionInstance> &variableTyResolver,
      const llvm::Record &def);

  /// Returns an argument or attribute with the given name that has been seen
  /// within the format. Returns null if the name is unknown, or if the
  /// operand/result type or the attribute has not been seen.
  ConstArgument findSeenArg(StringRef name);

  /// Parse a specific element, dispatching on the kind of the current token.
  LogicalResult parseElement(std::unique_ptr<Element> &element,
                             ParserContext context);
  /// Parse a variable reference naming an attribute, operand, region, result,
  /// or successor of the operation.
  LogicalResult parseVariable(std::unique_ptr<Element> &element,
                              ParserContext context);
  /// Parse a keyword directive; dispatches to the directive parsers below.
  LogicalResult parseDirective(std::unique_ptr<Element> &element,
                               ParserContext context);
  /// Parse a literal, including the whitespace literals `` and ` ` and the
  /// newline literal.
  LogicalResult parseLiteral(std::unique_ptr<Element> &element,
                             ParserContext context);
  /// Parse an optional group, including its optional `else` branch.
  LogicalResult parseOptional(std::unique_ptr<Element> &element,
                              ParserContext context);
  /// Parse one child of an optional group, recording it as the anchor in
  /// `anchorIdx` if it is followed by `^`.
  LogicalResult parseOptionalChildElement(
      std::vector<std::unique_ptr<Element>> &childElements,
      Optional<unsigned> &anchorIdx);
  /// Verify that `element` may appear in an optional group (and, if
  /// `isAnchor`, that it may anchor the group).
  LogicalResult verifyOptionalChildElement(Element *element,
                                           llvm::SMLoc childLoc, bool isAnchor);

  /// Parse the various different directives. These are invoked by
  /// parseDirective after the directive keyword token has been consumed.
  LogicalResult parseAttrDictDirective(std::unique_ptr<Element> &element,
                                       llvm::SMLoc loc, ParserContext context,
                                       bool withKeyword);
  LogicalResult parseCustomDirective(std::unique_ptr<Element> &element,
                                     llvm::SMLoc loc, ParserContext context);
  /// NOTE(review): presumably parses one parameter of a custom directive into
  /// `parameters`; confirm against its definition.
  LogicalResult parseCustomDirectiveParameter(
      std::vector<std::unique_ptr<Element>> &parameters);
  LogicalResult parseFunctionalTypeDirective(std::unique_ptr<Element> &element,
                                             FormatToken tok,
                                             ParserContext context);
  LogicalResult parseOperandsDirective(std::unique_ptr<Element> &element,
                                       llvm::SMLoc loc, ParserContext context);
  LogicalResult parseQualifiedDirective(std::unique_ptr<Element> &element,
                                        FormatToken tok, ParserContext context);
  LogicalResult parseReferenceDirective(std::unique_ptr<Element> &element,
                                        llvm::SMLoc loc, ParserContext context);
  LogicalResult parseRegionsDirective(std::unique_ptr<Element> &element,
                                      llvm::SMLoc loc, ParserContext context);
  LogicalResult parseResultsDirective(std::unique_ptr<Element> &element,
                                      llvm::SMLoc loc, ParserContext context);
  LogicalResult parseSuccessorsDirective(std::unique_ptr<Element> &element,
                                         llvm::SMLoc loc,
                                         ParserContext context);
  LogicalResult parseTypeDirective(std::unique_ptr<Element> &element,
                                   FormatToken tok, ParserContext context);
  /// NOTE(review): presumably parses the argument of a `type` directive;
  /// `isRefChild` appears to indicate the directive is inside a `ref`.
  LogicalResult parseTypeDirectiveOperand(std::unique_ptr<Element> &element,
                                          bool isRefChild = false);
2313 
2314   //===--------------------------------------------------------------------===//
2315   // Lexer Utilities
2316   //===--------------------------------------------------------------------===//
2317 
2318   /// Advance the current lexer onto the next token.
2319   void consumeToken() {
2320     assert(curToken.getKind() != FormatToken::eof &&
2321            curToken.getKind() != FormatToken::error &&
2322            "shouldn't advance past EOF or errors");
2323     curToken = lexer.lexToken();
2324   }
2325   LogicalResult parseToken(FormatToken::Kind kind, const Twine &msg) {
2326     if (curToken.getKind() != kind)
2327       return emitError(curToken.getLoc(), msg);
2328     consumeToken();
2329     return ::mlir::success();
2330   }
  /// Report `msg` as an error at `loc` through the lexer and return failure, so
  /// callers can simply write `return emitError(...)`.
  LogicalResult emitError(llvm::SMLoc loc, const Twine &msg) {
    lexer.emitError(loc, msg);
    return ::mlir::failure();
  }
  /// Report `msg` as an error at `loc` with an attached `note` (for example, a
  /// "suggest adding ..." hint) through the lexer, and return failure.
  LogicalResult emitErrorAndNote(llvm::SMLoc loc, const Twine &msg,
                                 const Twine &note) {
    lexer.emitErrorAndNote(loc, msg, note);
    return ::mlir::failure();
  }
2340 
2341   //===--------------------------------------------------------------------===//
2342   // Fields
2343   //===--------------------------------------------------------------------===//
2344 
  /// Lexer over the format text. Must stay declared before `curToken`, whose
  /// initializer in the constructor lexes from it.
  FormatLexer lexer;
  /// The single token of lookahead the parser is currently examining.
  FormatToken curToken;
  /// The format being populated by this parser.
  OperationFormat &fmt;
  /// The operation whose assembly format is being parsed.
  Operator &op;

  // The following are various bits of format state used for verification
  // during parsing.
  /// parse() rejects the format if this is still false once all elements have
  /// been parsed.
  bool hasAttrDict = false;
  /// When set, regions/successors are treated as already bound: individual
  /// `$name` references are rejected and the per-element checks are skipped.
  bool hasAllRegions = false, hasAllSuccessors = false;
  /// Set in parse() when the op declares ::mlir::InferTypeOpInterface and not
  /// all result types are known.
  bool canInferResultTypes = false;
  /// Bit i is set once the type of operand/result i appears in the format.
  llvm::SmallBitVector seenOperandTypes, seenResultTypes;
  /// Attributes bound so far, in insertion order; parse() moves this into
  /// `fmt.usedAttributes`.
  llvm::SmallSetVector<const NamedAttribute *, 8> seenAttrs;
  /// Operands, regions, and successors bound so far.
  llvm::DenseSet<const NamedTypeConstraint *> seenOperands;
  llvm::DenseSet<const NamedRegion *> seenRegions;
  llvm::DenseSet<const NamedSuccessor *> seenSuccessors;
2360 };
2361 } // namespace
2362 
2363 LogicalResult FormatParser::parse() {
2364   llvm::SMLoc loc = curToken.getLoc();
2365 
2366   // Parse each of the format elements into the main format.
2367   while (curToken.getKind() != FormatToken::eof) {
2368     std::unique_ptr<Element> element;
2369     if (failed(parseElement(element, TopLevelContext)))
2370       return ::mlir::failure();
2371     fmt.elements.push_back(std::move(element));
2372   }
2373 
2374   // Check that the attribute dictionary is in the format.
2375   if (!hasAttrDict)
2376     return emitError(loc, "'attr-dict' directive not found in "
2377                           "custom assembly format");
2378 
2379   // Check for any type traits that we can use for inferring types.
2380   llvm::StringMap<TypeResolutionInstance> variableTyResolver;
2381   for (const Trait &trait : op.getTraits()) {
2382     const llvm::Record &def = trait.getDef();
2383     if (def.isSubClassOf("AllTypesMatch")) {
2384       handleAllTypesMatchConstraint(def.getValueAsListOfStrings("values"),
2385                                     variableTyResolver);
2386     } else if (def.getName() == "SameTypeOperands") {
2387       handleSameTypesConstraint(variableTyResolver, /*includeResults=*/false);
2388     } else if (def.getName() == "SameOperandsAndResultType") {
2389       handleSameTypesConstraint(variableTyResolver, /*includeResults=*/true);
2390     } else if (def.isSubClassOf("TypesMatchWith")) {
2391       handleTypesMatchConstraint(variableTyResolver, def);
2392     } else if (!op.allResultTypesKnown()) {
2393       // This doesn't check the name directly to handle
2394       //    DeclareOpInterfaceMethods<InferTypeOpInterface>
2395       // and the like.
2396       // TODO: Add hasCppInterface check.
2397       if (auto name = def.getValueAsOptionalString("cppClassName")) {
2398         if (*name == "InferTypeOpInterface" &&
2399             def.getValueAsString("cppNamespace") == "::mlir")
2400           canInferResultTypes = true;
2401       }
2402     }
2403   }
2404 
2405   // Verify the state of the various operation components.
2406   if (failed(verifyAttributes(loc)) ||
2407       failed(verifyResults(loc, variableTyResolver)) ||
2408       failed(verifyOperands(loc, variableTyResolver)) ||
2409       failed(verifyRegions(loc)) || failed(verifySuccessors(loc)))
2410     return ::mlir::failure();
2411 
2412   // Collect the set of used attributes in the format.
2413   fmt.usedAttributes = seenAttrs.takeVector();
2414   return ::mlir::success();
2415 }
2416 
2417 LogicalResult FormatParser::verifyAttributes(llvm::SMLoc loc) {
2418   // Check that there are no `:` literals after an attribute without a constant
2419   // type. The attribute grammar contains an optional trailing colon type, which
2420   // can lead to unexpected and generally unintended behavior. Given that, it is
2421   // better to just error out here instead.
2422   using ElementsIterT = llvm::pointee_iterator<
2423       std::vector<std::unique_ptr<Element>>::const_iterator>;
2424   SmallVector<std::pair<ElementsIterT, ElementsIterT>, 1> iteratorStack;
2425   iteratorStack.emplace_back(fmt.elements.begin(), fmt.elements.end());
2426   while (!iteratorStack.empty())
2427     if (failed(verifyAttributes(loc, iteratorStack)))
2428       return ::mlir::failure();
2429 
2430   // Check for VariadicOfVariadic variables. The segment attribute of those
2431   // variables will be infered.
2432   for (const NamedTypeConstraint *var : seenOperands) {
2433     if (var->constraint.isVariadicOfVariadic()) {
2434       fmt.inferredAttributes.insert(
2435           var->constraint.getVariadicOfVariadicSegmentSizeAttr());
2436     }
2437   }
2438 
2439   return ::mlir::success();
2440 }
/// Verify the attribute elements at the back of the given stack of iterators.
///
/// Each stack entry is a [current, end) range over the elements of one format
/// group; the bottom entry is the top-level format. This call advances the
/// range at the back of the stack:
///  * On reaching an optional group, it pushes ranges for the group's `then`
///    and `else` elements and returns success. verifyAttributes(loc) then calls
///    it again on the new back entry, so the `else` range is visited first.
///  * Once the back range is exhausted, it pops that range.
///
/// Returns failure if an attribute without a type builder is followed by a `:`
/// literal. "Followed" means the first element of any range on the stack,
/// after skipping whitespace, `attr-dict`, and optional-group elements.
///
/// NOTE(review): every range on the stack is scanned, including ranges that do
/// not directly follow the attribute (e.g. the not-yet-visited `then` range
/// while inside an `else` range). This is conservative and may reject formats
/// that are not actually ambiguous.
LogicalResult FormatParser::verifyAttributes(
    llvm::SMLoc loc,
    SmallVectorImpl<std::pair<ElementsIterT, ElementsIterT>> &iteratorStack) {
  auto &stackIt = iteratorStack.back();
  // `it` aliases the iterator stored in the stack entry, so progress persists
  // across calls; `e` is a copy.
  ElementsIterT &it = stackIt.first, e = stackIt.second;
  while (it != e) {
    Element *element = &*(it++);

    // Traverse into optional groups.
    if (auto *optional = dyn_cast<OptionalElement>(element)) {
      auto thenElements = optional->getThenElements();
      iteratorStack.emplace_back(thenElements.begin(), thenElements.end());

      auto elseElements = optional->getElseElements();
      iteratorStack.emplace_back(elseElements.begin(), elseElements.end());
      // The emplace_back calls above may reallocate the stack and invalidate
      // `stackIt`/`it`. Return immediately so they are never used again; `it`
      // was already advanced past the group.
      return ::mlir::success();
    }

    // We are checking for an attribute element followed by a `:`, so there is
    // no need to check the end.
    if (it == e && iteratorStack.size() == 1)
      break;

    // Check for an attribute with a constant type builder, followed by a `:`.
    auto *prevAttr = dyn_cast<AttributeVariable>(element);
    if (!prevAttr || prevAttr->getTypeBuilder())
      continue;

    // Check the next iterator within the stack for literal elements.
    for (auto &nextItPair : iteratorStack) {
      ElementsIterT nextIt = nextItPair.first, nextE = nextItPair.second;
      for (; nextIt != nextE; ++nextIt) {
        // Skip any trailing whitespace, attribute dictionaries, or optional
        // groups.
        if (isa<WhitespaceElement>(*nextIt) ||
            isa<AttrDictDirective>(*nextIt) || isa<OptionalElement>(*nextIt))
          continue;

        // We are only interested in `:` literals.
        auto *literal = dyn_cast<LiteralElement>(&*nextIt);
        if (!literal || literal->getLiteral() != ":")
          break;

        // TODO: Use the location of the literal element itself.
        return emitError(
            loc, llvm::formatv("format ambiguity caused by `:` literal found "
                               "after attribute `{0}` which does not have "
                               "a buildable type",
                               prevAttr->getVar()->name));
      }
    }
  }
  // This range is fully verified; resume with the enclosing range.
  iteratorStack.pop_back();
  return ::mlir::success();
}
2497 
2498 LogicalResult FormatParser::verifyOperands(
2499     llvm::SMLoc loc,
2500     llvm::StringMap<TypeResolutionInstance> &variableTyResolver) {
2501   // Check that all of the operands are within the format, and their types can
2502   // be inferred.
2503   auto &buildableTypes = fmt.buildableTypes;
2504   for (unsigned i = 0, e = op.getNumOperands(); i != e; ++i) {
2505     NamedTypeConstraint &operand = op.getOperand(i);
2506 
2507     // Check that the operand itself is in the format.
2508     if (!fmt.allOperands && !seenOperands.count(&operand)) {
2509       return emitErrorAndNote(loc,
2510                               "operand #" + Twine(i) + ", named '" +
2511                                   operand.name + "', not found",
2512                               "suggest adding a '$" + operand.name +
2513                                   "' directive to the custom assembly format");
2514     }
2515 
2516     // Check that the operand type is in the format, or that it can be inferred.
2517     if (fmt.allOperandTypes || seenOperandTypes.test(i))
2518       continue;
2519 
2520     // Check to see if we can infer this type from another variable.
2521     auto varResolverIt = variableTyResolver.find(op.getOperand(i).name);
2522     if (varResolverIt != variableTyResolver.end()) {
2523       TypeResolutionInstance &resolver = varResolverIt->second;
2524       fmt.operandTypes[i].setResolver(resolver.resolver, resolver.transformer);
2525       continue;
2526     }
2527 
2528     // Similarly to results, allow a custom builder for resolving the type if
2529     // we aren't using the 'operands' directive.
2530     Optional<StringRef> builder = operand.constraint.getBuilderCall();
2531     if (!builder || (fmt.allOperands && operand.isVariableLength())) {
2532       return emitErrorAndNote(
2533           loc,
2534           "type of operand #" + Twine(i) + ", named '" + operand.name +
2535               "', is not buildable and a buildable type cannot be inferred",
2536           "suggest adding a type constraint to the operation or adding a "
2537           "'type($" +
2538               operand.name + ")' directive to the " + "custom assembly format");
2539     }
2540     auto it = buildableTypes.insert({*builder, buildableTypes.size()});
2541     fmt.operandTypes[i].setBuilderIdx(it.first->second);
2542   }
2543   return ::mlir::success();
2544 }
2545 
2546 LogicalResult FormatParser::verifyRegions(llvm::SMLoc loc) {
2547   // Check that all of the regions are within the format.
2548   if (hasAllRegions)
2549     return ::mlir::success();
2550 
2551   for (unsigned i = 0, e = op.getNumRegions(); i != e; ++i) {
2552     const NamedRegion &region = op.getRegion(i);
2553     if (!seenRegions.count(&region)) {
2554       return emitErrorAndNote(loc,
2555                               "region #" + Twine(i) + ", named '" +
2556                                   region.name + "', not found",
2557                               "suggest adding a '$" + region.name +
2558                                   "' directive to the custom assembly format");
2559     }
2560   }
2561   return ::mlir::success();
2562 }
2563 
2564 LogicalResult FormatParser::verifyResults(
2565     llvm::SMLoc loc,
2566     llvm::StringMap<TypeResolutionInstance> &variableTyResolver) {
2567   // If we format all of the types together, there is nothing to check.
2568   if (fmt.allResultTypes)
2569     return ::mlir::success();
2570 
2571   // If no result types are specified and we can infer them, infer all result
2572   // types
2573   if (op.getNumResults() > 0 && seenResultTypes.count() == 0 &&
2574       canInferResultTypes) {
2575     fmt.infersResultTypes = true;
2576     return ::mlir::success();
2577   }
2578 
2579   // Check that all of the result types can be inferred.
2580   auto &buildableTypes = fmt.buildableTypes;
2581   for (unsigned i = 0, e = op.getNumResults(); i != e; ++i) {
2582     if (seenResultTypes.test(i))
2583       continue;
2584 
2585     // Check to see if we can infer this type from another variable.
2586     auto varResolverIt = variableTyResolver.find(op.getResultName(i));
2587     if (varResolverIt != variableTyResolver.end()) {
2588       TypeResolutionInstance resolver = varResolverIt->second;
2589       fmt.resultTypes[i].setResolver(resolver.resolver, resolver.transformer);
2590       continue;
2591     }
2592 
2593     // If the result is not variable length, allow for the case where the type
2594     // has a builder that we can use.
2595     NamedTypeConstraint &result = op.getResult(i);
2596     Optional<StringRef> builder = result.constraint.getBuilderCall();
2597     if (!builder || result.isVariableLength()) {
2598       return emitErrorAndNote(
2599           loc,
2600           "type of result #" + Twine(i) + ", named '" + result.name +
2601               "', is not buildable and a buildable type cannot be inferred",
2602           "suggest adding a type constraint to the operation or adding a "
2603           "'type($" +
2604               result.name + ")' directive to the " + "custom assembly format");
2605     }
2606     // Note in the format that this result uses the custom builder.
2607     auto it = buildableTypes.insert({*builder, buildableTypes.size()});
2608     fmt.resultTypes[i].setBuilderIdx(it.first->second);
2609   }
2610   return ::mlir::success();
2611 }
2612 
2613 LogicalResult FormatParser::verifySuccessors(llvm::SMLoc loc) {
2614   // Check that all of the successors are within the format.
2615   if (hasAllSuccessors)
2616     return ::mlir::success();
2617 
2618   for (unsigned i = 0, e = op.getNumSuccessors(); i != e; ++i) {
2619     const NamedSuccessor &successor = op.getSuccessor(i);
2620     if (!seenSuccessors.count(&successor)) {
2621       return emitErrorAndNote(loc,
2622                               "successor #" + Twine(i) + ", named '" +
2623                                   successor.name + "', not found",
2624                               "suggest adding a '$" + successor.name +
2625                                   "' directive to the custom assembly format");
2626     }
2627   }
2628   return ::mlir::success();
2629 }
2630 
2631 void FormatParser::handleAllTypesMatchConstraint(
2632     ArrayRef<StringRef> values,
2633     llvm::StringMap<TypeResolutionInstance> &variableTyResolver) {
2634   for (unsigned i = 0, e = values.size(); i != e; ++i) {
2635     // Check to see if this value matches a resolved operand or result type.
2636     ConstArgument arg = findSeenArg(values[i]);
2637     if (!arg)
2638       continue;
2639 
2640     // Mark this value as the type resolver for the other variables.
2641     for (unsigned j = 0; j != i; ++j)
2642       variableTyResolver[values[j]] = {arg, llvm::None};
2643     for (unsigned j = i + 1; j != e; ++j)
2644       variableTyResolver[values[j]] = {arg, llvm::None};
2645   }
2646 }
2647 
2648 void FormatParser::handleSameTypesConstraint(
2649     llvm::StringMap<TypeResolutionInstance> &variableTyResolver,
2650     bool includeResults) {
2651   const NamedTypeConstraint *resolver = nullptr;
2652   int resolvedIt = -1;
2653 
2654   // Check to see if there is an operand or result to use for the resolution.
2655   if ((resolvedIt = seenOperandTypes.find_first()) != -1)
2656     resolver = &op.getOperand(resolvedIt);
2657   else if (includeResults && (resolvedIt = seenResultTypes.find_first()) != -1)
2658     resolver = &op.getResult(resolvedIt);
2659   else
2660     return;
2661 
2662   // Set the resolvers for each operand and result.
2663   for (unsigned i = 0, e = op.getNumOperands(); i != e; ++i)
2664     if (!seenOperandTypes.test(i) && !op.getOperand(i).name.empty())
2665       variableTyResolver[op.getOperand(i).name] = {resolver, llvm::None};
2666   if (includeResults) {
2667     for (unsigned i = 0, e = op.getNumResults(); i != e; ++i)
2668       if (!seenResultTypes.test(i) && !op.getResultName(i).empty())
2669         variableTyResolver[op.getResultName(i)] = {resolver, llvm::None};
2670   }
2671 }
2672 
2673 void FormatParser::handleTypesMatchConstraint(
2674     llvm::StringMap<TypeResolutionInstance> &variableTyResolver,
2675     const llvm::Record &def) {
2676   StringRef lhsName = def.getValueAsString("lhs");
2677   StringRef rhsName = def.getValueAsString("rhs");
2678   StringRef transformer = def.getValueAsString("transformer");
2679   if (ConstArgument arg = findSeenArg(lhsName))
2680     variableTyResolver[rhsName] = {arg, transformer};
2681 }
2682 
2683 ConstArgument FormatParser::findSeenArg(StringRef name) {
2684   if (const NamedTypeConstraint *arg = findArg(op.getOperands(), name))
2685     return seenOperandTypes.test(arg - op.operand_begin()) ? arg : nullptr;
2686   if (const NamedTypeConstraint *arg = findArg(op.getResults(), name))
2687     return seenResultTypes.test(arg - op.result_begin()) ? arg : nullptr;
2688   if (const NamedAttribute *attr = findArg(op.getAttributes(), name))
2689     return seenAttrs.count(attr) ? attr : nullptr;
2690   return nullptr;
2691 }
2692 
2693 LogicalResult FormatParser::parseElement(std::unique_ptr<Element> &element,
2694                                          ParserContext context) {
2695   // Directives.
2696   if (curToken.isKeyword())
2697     return parseDirective(element, context);
2698   // Literals.
2699   if (curToken.getKind() == FormatToken::literal)
2700     return parseLiteral(element, context);
2701   // Optionals.
2702   if (curToken.getKind() == FormatToken::l_paren)
2703     return parseOptional(element, context);
2704   // Variables.
2705   if (curToken.getKind() == FormatToken::variable)
2706     return parseVariable(element, context);
2707   return emitError(curToken.getLoc(),
2708                    "expected directive, literal, variable, or optional group");
2709 }
2710 
2711 LogicalResult FormatParser::parseVariable(std::unique_ptr<Element> &element,
2712                                           ParserContext context) {
2713   FormatToken varTok = curToken;
2714   consumeToken();
2715 
2716   StringRef name = varTok.getSpelling().drop_front();
2717   llvm::SMLoc loc = varTok.getLoc();
2718 
2719   // Check that the parsed argument is something actually registered on the
2720   // op.
2721   /// Attributes
2722   if (const NamedAttribute *attr = findArg(op.getAttributes(), name)) {
2723     if (context == TypeDirectiveContext)
2724       return emitError(
2725           loc, "attributes cannot be used as children to a `type` directive");
2726     if (context == RefDirectiveContext) {
2727       if (!seenAttrs.count(attr))
2728         return emitError(loc, "attribute '" + name +
2729                                   "' must be bound before it is referenced");
2730     } else if (!seenAttrs.insert(attr)) {
2731       return emitError(loc, "attribute '" + name + "' is already bound");
2732     }
2733 
2734     element = std::make_unique<AttributeVariable>(attr);
2735     return ::mlir::success();
2736   }
2737   /// Operands
2738   if (const NamedTypeConstraint *operand = findArg(op.getOperands(), name)) {
2739     if (context == TopLevelContext || context == CustomDirectiveContext) {
2740       if (fmt.allOperands || !seenOperands.insert(operand).second)
2741         return emitError(loc, "operand '" + name + "' is already bound");
2742     } else if (context == RefDirectiveContext && !seenOperands.count(operand)) {
2743       return emitError(loc, "operand '" + name +
2744                                 "' must be bound before it is referenced");
2745     }
2746     element = std::make_unique<OperandVariable>(operand);
2747     return ::mlir::success();
2748   }
2749   /// Regions
2750   if (const NamedRegion *region = findArg(op.getRegions(), name)) {
2751     if (context == TopLevelContext || context == CustomDirectiveContext) {
2752       if (hasAllRegions || !seenRegions.insert(region).second)
2753         return emitError(loc, "region '" + name + "' is already bound");
2754     } else if (context == RefDirectiveContext && !seenRegions.count(region)) {
2755       return emitError(loc, "region '" + name +
2756                                 "' must be bound before it is referenced");
2757     } else {
2758       return emitError(loc, "regions can only be used at the top level");
2759     }
2760     element = std::make_unique<RegionVariable>(region);
2761     return ::mlir::success();
2762   }
2763   /// Results.
2764   if (const auto *result = findArg(op.getResults(), name)) {
2765     if (context != TypeDirectiveContext)
2766       return emitError(loc, "result variables can can only be used as a child "
2767                             "to a 'type' directive");
2768     element = std::make_unique<ResultVariable>(result);
2769     return ::mlir::success();
2770   }
2771   /// Successors.
2772   if (const auto *successor = findArg(op.getSuccessors(), name)) {
2773     if (context == TopLevelContext || context == CustomDirectiveContext) {
2774       if (hasAllSuccessors || !seenSuccessors.insert(successor).second)
2775         return emitError(loc, "successor '" + name + "' is already bound");
2776     } else if (context == RefDirectiveContext &&
2777                !seenSuccessors.count(successor)) {
2778       return emitError(loc, "successor '" + name +
2779                                 "' must be bound before it is referenced");
2780     } else {
2781       return emitError(loc, "successors can only be used at the top level");
2782     }
2783 
2784     element = std::make_unique<SuccessorVariable>(successor);
2785     return ::mlir::success();
2786   }
2787   return emitError(loc, "expected variable to refer to an argument, region, "
2788                         "result, or successor");
2789 }
2790 
2791 LogicalResult FormatParser::parseDirective(std::unique_ptr<Element> &element,
2792                                            ParserContext context) {
2793   FormatToken dirTok = curToken;
2794   consumeToken();
2795 
2796   switch (dirTok.getKind()) {
2797   case FormatToken::kw_attr_dict:
2798     return parseAttrDictDirective(element, dirTok.getLoc(), context,
2799                                   /*withKeyword=*/false);
2800   case FormatToken::kw_attr_dict_w_keyword:
2801     return parseAttrDictDirective(element, dirTok.getLoc(), context,
2802                                   /*withKeyword=*/true);
2803   case FormatToken::kw_custom:
2804     return parseCustomDirective(element, dirTok.getLoc(), context);
2805   case FormatToken::kw_functional_type:
2806     return parseFunctionalTypeDirective(element, dirTok, context);
2807   case FormatToken::kw_operands:
2808     return parseOperandsDirective(element, dirTok.getLoc(), context);
2809   case FormatToken::kw_qualified:
2810     return parseQualifiedDirective(element, dirTok, context);
2811   case FormatToken::kw_regions:
2812     return parseRegionsDirective(element, dirTok.getLoc(), context);
2813   case FormatToken::kw_results:
2814     return parseResultsDirective(element, dirTok.getLoc(), context);
2815   case FormatToken::kw_successors:
2816     return parseSuccessorsDirective(element, dirTok.getLoc(), context);
2817   case FormatToken::kw_ref:
2818     return parseReferenceDirective(element, dirTok.getLoc(), context);
2819   case FormatToken::kw_type:
2820     return parseTypeDirective(element, dirTok, context);
2821 
2822   default:
2823     llvm_unreachable("unknown directive token");
2824   }
2825 }
2826 
2827 LogicalResult FormatParser::parseLiteral(std::unique_ptr<Element> &element,
2828                                          ParserContext context) {
2829   FormatToken literalTok = curToken;
2830   if (context != TopLevelContext) {
2831     return emitError(
2832         literalTok.getLoc(),
2833         "literals may only be used in a top-level section of the format");
2834   }
2835   consumeToken();
2836 
2837   StringRef value = literalTok.getSpelling().drop_front().drop_back();
2838 
2839   // The parsed literal is a space element (`` or ` `).
2840   if (value.empty() || (value.size() == 1 && value.front() == ' ')) {
2841     element = std::make_unique<SpaceElement>(!value.empty());
2842     return ::mlir::success();
2843   }
2844   // The parsed literal is a newline element.
2845   if (value == "\\n") {
2846     element = std::make_unique<NewlineElement>();
2847     return ::mlir::success();
2848   }
2849 
2850   // Check that the parsed literal is valid.
2851   if (!isValidLiteral(value, [&](Twine diag) {
2852         (void)emitError(literalTok.getLoc(),
2853                         "expected valid literal but got '" + value +
2854                             "': " + diag);
2855       }))
2856     return failure();
2857   element = std::make_unique<LiteralElement>(value);
2858   return ::mlir::success();
2859 }
2860 
2861 LogicalResult FormatParser::parseOptional(std::unique_ptr<Element> &element,
2862                                           ParserContext context) {
2863   llvm::SMLoc curLoc = curToken.getLoc();
2864   if (context != TopLevelContext)
2865     return emitError(curLoc, "optional groups can only be used as top-level "
2866                              "elements");
2867   consumeToken();
2868 
2869   // Parse the child elements for this optional group.
2870   std::vector<std::unique_ptr<Element>> thenElements, elseElements;
2871   Optional<unsigned> anchorIdx;
2872   do {
2873     if (failed(parseOptionalChildElement(thenElements, anchorIdx)))
2874       return ::mlir::failure();
2875   } while (curToken.getKind() != FormatToken::r_paren);
2876   consumeToken();
2877 
2878   // Parse the `else` elements of this optional group.
2879   if (curToken.getKind() == FormatToken::colon) {
2880     consumeToken();
2881     if (failed(parseToken(FormatToken::l_paren,
2882                           "expected '(' to start else branch "
2883                           "of optional group")))
2884       return failure();
2885     do {
2886       llvm::SMLoc childLoc = curToken.getLoc();
2887       elseElements.push_back({});
2888       if (failed(parseElement(elseElements.back(), TopLevelContext)) ||
2889           failed(verifyOptionalChildElement(elseElements.back().get(), childLoc,
2890                                             /*isAnchor=*/false)))
2891         return failure();
2892     } while (curToken.getKind() != FormatToken::r_paren);
2893     consumeToken();
2894   }
2895 
2896   if (failed(parseToken(FormatToken::question,
2897                         "expected '?' after optional group")))
2898     return ::mlir::failure();
2899 
2900   // The optional group is required to have an anchor.
2901   if (!anchorIdx)
2902     return emitError(curLoc, "optional group specified no anchor element");
2903 
2904   // The first parsable element of the group must be able to be parsed in an
2905   // optional fashion.
2906   auto parseBegin = llvm::find_if_not(thenElements, [](auto &element) {
2907     return isa<WhitespaceElement>(element.get());
2908   });
2909   Element *firstElement = parseBegin->get();
2910   if (!isa<AttributeVariable>(firstElement) &&
2911       !isa<LiteralElement>(firstElement) &&
2912       !isa<OperandVariable>(firstElement) && !isa<RegionVariable>(firstElement))
2913     return emitError(curLoc,
2914                      "first parsable element of an operand group must be "
2915                      "an attribute, literal, operand, or region");
2916 
2917   auto parseStart = parseBegin - thenElements.begin();
2918   element = std::make_unique<OptionalElement>(
2919       std::move(thenElements), std::move(elseElements), *anchorIdx, parseStart);
2920   return ::mlir::success();
2921 }
2922 
2923 LogicalResult FormatParser::parseOptionalChildElement(
2924     std::vector<std::unique_ptr<Element>> &childElements,
2925     Optional<unsigned> &anchorIdx) {
2926   llvm::SMLoc childLoc = curToken.getLoc();
2927   childElements.push_back({});
2928   if (failed(parseElement(childElements.back(), TopLevelContext)))
2929     return ::mlir::failure();
2930 
2931   // Check to see if this element is the anchor of the optional group.
2932   bool isAnchor = curToken.getKind() == FormatToken::caret;
2933   if (isAnchor) {
2934     if (anchorIdx)
2935       return emitError(childLoc, "only one element can be marked as the anchor "
2936                                  "of an optional group");
2937     anchorIdx = childElements.size() - 1;
2938     consumeToken();
2939   }
2940 
2941   return verifyOptionalChildElement(childElements.back().get(), childLoc,
2942                                     isAnchor);
2943 }
2944 
2945 LogicalResult FormatParser::verifyOptionalChildElement(Element *element,
2946                                                        llvm::SMLoc childLoc,
2947                                                        bool isAnchor) {
2948   return TypeSwitch<Element *, LogicalResult>(element)
2949       // All attributes can be within the optional group, but only optional
2950       // attributes can be the anchor.
2951       .Case([&](AttributeVariable *attrEle) {
2952         if (isAnchor && !attrEle->getVar()->attr.isOptional())
2953           return emitError(childLoc, "only optional attributes can be used to "
2954                                      "anchor an optional group");
2955         return ::mlir::success();
2956       })
2957       // Only optional-like(i.e. variadic) operands can be within an optional
2958       // group.
2959       .Case([&](OperandVariable *ele) {
2960         if (!ele->getVar()->isVariableLength())
2961           return emitError(childLoc, "only variable length operands can be "
2962                                      "used within an optional group");
2963         return ::mlir::success();
2964       })
2965       // Only optional-like(i.e. variadic) results can be within an optional
2966       // group.
2967       .Case([&](ResultVariable *ele) {
2968         if (!ele->getVar()->isVariableLength())
2969           return emitError(childLoc, "only variable length results can be "
2970                                      "used within an optional group");
2971         return ::mlir::success();
2972       })
2973       .Case([&](RegionVariable *) {
2974         // TODO: When ODS has proper support for marking "optional" regions, add
2975         // a check here.
2976         return ::mlir::success();
2977       })
2978       .Case([&](TypeDirective *ele) {
2979         return verifyOptionalChildElement(ele->getOperand(), childLoc,
2980                                           /*isAnchor=*/false);
2981       })
2982       .Case([&](FunctionalTypeDirective *ele) {
2983         if (failed(verifyOptionalChildElement(ele->getInputs(), childLoc,
2984                                               /*isAnchor=*/false)))
2985           return failure();
2986         return verifyOptionalChildElement(ele->getResults(), childLoc,
2987                                           /*isAnchor=*/false);
2988       })
2989       // Literals, whitespace, and custom directives may be used, but they can't
2990       // anchor the group.
2991       .Case<LiteralElement, WhitespaceElement, CustomDirective,
2992             FunctionalTypeDirective, OptionalElement>([&](Element *) {
2993         if (isAnchor)
2994           return emitError(childLoc, "only variables and types can be used "
2995                                      "to anchor an optional group");
2996         return ::mlir::success();
2997       })
2998       .Default([&](Element *) {
2999         return emitError(childLoc, "only literals, types, and variables can be "
3000                                    "used within an optional group");
3001       });
3002 }
3003 
3004 LogicalResult
3005 FormatParser::parseAttrDictDirective(std::unique_ptr<Element> &element,
3006                                      llvm::SMLoc loc, ParserContext context,
3007                                      bool withKeyword) {
3008   if (context == TypeDirectiveContext)
3009     return emitError(loc, "'attr-dict' directive can only be used as a "
3010                           "top-level directive");
3011 
3012   if (context == RefDirectiveContext) {
3013     if (!hasAttrDict)
3014       return emitError(loc, "'ref' of 'attr-dict' is not bound by a prior "
3015                             "'attr-dict' directive");
3016 
3017     // Otherwise, this is a top-level context.
3018   } else {
3019     if (hasAttrDict)
3020       return emitError(loc, "'attr-dict' directive has already been seen");
3021     hasAttrDict = true;
3022   }
3023 
3024   element = std::make_unique<AttrDictDirective>(withKeyword);
3025   return ::mlir::success();
3026 }
3027 
3028 LogicalResult
3029 FormatParser::parseCustomDirective(std::unique_ptr<Element> &element,
3030                                    llvm::SMLoc loc, ParserContext context) {
3031   llvm::SMLoc curLoc = curToken.getLoc();
3032   if (context != TopLevelContext)
3033     return emitError(loc, "'custom' is only valid as a top-level directive");
3034 
3035   // Parse the custom directive name.
3036   if (failed(parseToken(FormatToken::less,
3037                         "expected '<' before custom directive name")))
3038     return ::mlir::failure();
3039 
3040   FormatToken nameTok = curToken;
3041   if (failed(parseToken(FormatToken::identifier,
3042                         "expected custom directive name identifier")) ||
3043       failed(parseToken(FormatToken::greater,
3044                         "expected '>' after custom directive name")) ||
3045       failed(parseToken(FormatToken::l_paren,
3046                         "expected '(' before custom directive parameters")))
3047     return ::mlir::failure();
3048 
3049   // Parse the child elements for this optional group.=
3050   std::vector<std::unique_ptr<Element>> elements;
3051   do {
3052     if (failed(parseCustomDirectiveParameter(elements)))
3053       return ::mlir::failure();
3054     if (curToken.getKind() != FormatToken::comma)
3055       break;
3056     consumeToken();
3057   } while (true);
3058 
3059   if (failed(parseToken(FormatToken::r_paren,
3060                         "expected ')' after custom directive parameters")))
3061     return ::mlir::failure();
3062 
3063   // After parsing all of the elements, ensure that all type directives refer
3064   // only to variables.
3065   for (auto &ele : elements) {
3066     if (auto *typeEle = dyn_cast<TypeDirective>(ele.get())) {
3067       if (!isa<OperandVariable, ResultVariable>(typeEle->getOperand())) {
3068         return emitError(curLoc, "type directives within a custom directive "
3069                                  "may only refer to variables");
3070       }
3071     }
3072   }
3073 
3074   element = std::make_unique<CustomDirective>(nameTok.getSpelling(),
3075                                               std::move(elements));
3076   return ::mlir::success();
3077 }
3078 
3079 LogicalResult FormatParser::parseCustomDirectiveParameter(
3080     std::vector<std::unique_ptr<Element>> &parameters) {
3081   llvm::SMLoc childLoc = curToken.getLoc();
3082   parameters.push_back({});
3083   if (failed(parseElement(parameters.back(), CustomDirectiveContext)))
3084     return ::mlir::failure();
3085 
3086   // Verify that the element can be placed within a custom directive.
3087   if (!isa<RefDirective, TypeDirective, AttrDictDirective, AttributeVariable,
3088            OperandVariable, RegionVariable, SuccessorVariable>(
3089           parameters.back().get())) {
3090     return emitError(childLoc, "only variables and types may be used as "
3091                                "parameters to a custom directive");
3092   }
3093   return ::mlir::success();
3094 }
3095 
3096 LogicalResult FormatParser::parseFunctionalTypeDirective(
3097     std::unique_ptr<Element> &element, FormatToken tok, ParserContext context) {
3098   llvm::SMLoc loc = tok.getLoc();
3099   if (context != TopLevelContext)
3100     return emitError(
3101         loc, "'functional-type' is only valid as a top-level directive");
3102 
3103   // Parse the main operand.
3104   std::unique_ptr<Element> inputs, results;
3105   if (failed(parseToken(FormatToken::l_paren,
3106                         "expected '(' before argument list")) ||
3107       failed(parseTypeDirectiveOperand(inputs)) ||
3108       failed(parseToken(FormatToken::comma,
3109                         "expected ',' after inputs argument")) ||
3110       failed(parseTypeDirectiveOperand(results)) ||
3111       failed(
3112           parseToken(FormatToken::r_paren, "expected ')' after argument list")))
3113     return ::mlir::failure();
3114   element = std::make_unique<FunctionalTypeDirective>(std::move(inputs),
3115                                                       std::move(results));
3116   return ::mlir::success();
3117 }
3118 
3119 LogicalResult
3120 FormatParser::parseOperandsDirective(std::unique_ptr<Element> &element,
3121                                      llvm::SMLoc loc, ParserContext context) {
3122   if (context == RefDirectiveContext) {
3123     if (!fmt.allOperands)
3124       return emitError(loc, "'ref' of 'operands' is not bound by a prior "
3125                             "'operands' directive");
3126 
3127   } else if (context == TopLevelContext || context == CustomDirectiveContext) {
3128     if (fmt.allOperands || !seenOperands.empty())
3129       return emitError(loc, "'operands' directive creates overlap in format");
3130     fmt.allOperands = true;
3131   }
3132   element = std::make_unique<OperandsDirective>();
3133   return ::mlir::success();
3134 }
3135 
3136 LogicalResult
3137 FormatParser::parseReferenceDirective(std::unique_ptr<Element> &element,
3138                                       llvm::SMLoc loc, ParserContext context) {
3139   if (context != CustomDirectiveContext)
3140     return emitError(loc, "'ref' is only valid within a `custom` directive");
3141 
3142   std::unique_ptr<Element> operand;
3143   if (failed(parseToken(FormatToken::l_paren,
3144                         "expected '(' before argument list")) ||
3145       failed(parseElement(operand, RefDirectiveContext)) ||
3146       failed(
3147           parseToken(FormatToken::r_paren, "expected ')' after argument list")))
3148     return ::mlir::failure();
3149 
3150   element = std::make_unique<RefDirective>(std::move(operand));
3151   return ::mlir::success();
3152 }
3153 
3154 LogicalResult
3155 FormatParser::parseRegionsDirective(std::unique_ptr<Element> &element,
3156                                     llvm::SMLoc loc, ParserContext context) {
3157   if (context == TypeDirectiveContext)
3158     return emitError(loc, "'regions' is only valid as a top-level directive");
3159   if (context == RefDirectiveContext) {
3160     if (!hasAllRegions)
3161       return emitError(loc, "'ref' of 'regions' is not bound by a prior "
3162                             "'regions' directive");
3163 
3164     // Otherwise, this is a TopLevel directive.
3165   } else {
3166     if (hasAllRegions || !seenRegions.empty())
3167       return emitError(loc, "'regions' directive creates overlap in format");
3168     hasAllRegions = true;
3169   }
3170   element = std::make_unique<RegionsDirective>();
3171   return ::mlir::success();
3172 }
3173 
3174 LogicalResult
3175 FormatParser::parseResultsDirective(std::unique_ptr<Element> &element,
3176                                     llvm::SMLoc loc, ParserContext context) {
3177   if (context != TypeDirectiveContext)
3178     return emitError(loc, "'results' directive can can only be used as a child "
3179                           "to a 'type' directive");
3180   element = std::make_unique<ResultsDirective>();
3181   return ::mlir::success();
3182 }
3183 
3184 LogicalResult
3185 FormatParser::parseSuccessorsDirective(std::unique_ptr<Element> &element,
3186                                        llvm::SMLoc loc, ParserContext context) {
3187   if (context == TypeDirectiveContext)
3188     return emitError(loc,
3189                      "'successors' is only valid as a top-level directive");
3190   if (context == RefDirectiveContext) {
3191     if (!hasAllSuccessors)
3192       return emitError(loc, "'ref' of 'successors' is not bound by a prior "
3193                             "'successors' directive");
3194 
3195     // Otherwise, this is a TopLevel directive.
3196   } else {
3197     if (hasAllSuccessors || !seenSuccessors.empty())
3198       return emitError(loc, "'successors' directive creates overlap in format");
3199     hasAllSuccessors = true;
3200   }
3201   element = std::make_unique<SuccessorsDirective>();
3202   return ::mlir::success();
3203 }
3204 
3205 LogicalResult
3206 FormatParser::parseTypeDirective(std::unique_ptr<Element> &element,
3207                                  FormatToken tok, ParserContext context) {
3208   llvm::SMLoc loc = tok.getLoc();
3209   if (context == TypeDirectiveContext)
3210     return emitError(loc, "'type' cannot be used as a child of another `type`");
3211 
3212   bool isRefChild = context == RefDirectiveContext;
3213   std::unique_ptr<Element> operand;
3214   if (failed(parseToken(FormatToken::l_paren,
3215                         "expected '(' before argument list")) ||
3216       failed(parseTypeDirectiveOperand(operand, isRefChild)) ||
3217       failed(
3218           parseToken(FormatToken::r_paren, "expected ')' after argument list")))
3219     return ::mlir::failure();
3220 
3221   element = std::make_unique<TypeDirective>(std::move(operand));
3222   return ::mlir::success();
3223 }
3224 
3225 LogicalResult
3226 FormatParser::parseQualifiedDirective(std::unique_ptr<Element> &element,
3227                                       FormatToken tok, ParserContext context) {
3228   if (failed(parseToken(FormatToken::l_paren,
3229                         "expected '(' before argument list")) ||
3230       failed(parseElement(element, context)) ||
3231       failed(
3232           parseToken(FormatToken::r_paren, "expected ')' after argument list")))
3233     return failure();
3234   if (auto *attr = dyn_cast<AttributeVariable>(element.get())) {
3235     attr->setShouldBeQualified();
3236   } else if (auto *type = dyn_cast<TypeDirective>(element.get())) {
3237     type->setShouldBeQualified();
3238   } else {
3239     return emitError(
3240         tok.getLoc(),
3241         "'qualified' directive expects an attribute or a `type` directive");
3242   }
3243   return success();
3244 }
3245 
3246 LogicalResult
3247 FormatParser::parseTypeDirectiveOperand(std::unique_ptr<Element> &element,
3248                                         bool isRefChild) {
3249   llvm::SMLoc loc = curToken.getLoc();
3250   if (failed(parseElement(element, TypeDirectiveContext)))
3251     return ::mlir::failure();
3252   if (isa<LiteralElement>(element.get()))
3253     return emitError(
3254         loc, "'type' directive operand expects variable or directive operand");
3255 
3256   if (auto *var = dyn_cast<OperandVariable>(element.get())) {
3257     unsigned opIdx = var->getVar() - op.operand_begin();
3258     if (!isRefChild && (fmt.allOperandTypes || seenOperandTypes.test(opIdx)))
3259       return emitError(loc, "'type' of '" + var->getVar()->name +
3260                                 "' is already bound");
3261     if (isRefChild && !(fmt.allOperandTypes || seenOperandTypes.test(opIdx)))
3262       return emitError(loc, "'ref' of 'type($" + var->getVar()->name +
3263                                 ")' is not bound by a prior 'type' directive");
3264     seenOperandTypes.set(opIdx);
3265   } else if (auto *var = dyn_cast<ResultVariable>(element.get())) {
3266     unsigned resIdx = var->getVar() - op.result_begin();
3267     if (!isRefChild && (fmt.allResultTypes || seenResultTypes.test(resIdx)))
3268       return emitError(loc, "'type' of '" + var->getVar()->name +
3269                                 "' is already bound");
3270     if (isRefChild && !(fmt.allResultTypes || seenResultTypes.test(resIdx)))
3271       return emitError(loc, "'ref' of 'type($" + var->getVar()->name +
3272                                 ")' is not bound by a prior 'type' directive");
3273     seenResultTypes.set(resIdx);
3274   } else if (isa<OperandsDirective>(&*element)) {
3275     if (!isRefChild && (fmt.allOperandTypes || seenOperandTypes.any()))
3276       return emitError(loc, "'operands' 'type' is already bound");
3277     if (isRefChild && !fmt.allOperandTypes)
3278       return emitError(loc, "'ref' of 'type(operands)' is not bound by a prior "
3279                             "'type' directive");
3280     fmt.allOperandTypes = true;
3281   } else if (isa<ResultsDirective>(&*element)) {
3282     if (!isRefChild && (fmt.allResultTypes || seenResultTypes.any()))
3283       return emitError(loc, "'results' 'type' is already bound");
3284     if (isRefChild && !fmt.allResultTypes)
3285       return emitError(loc, "'ref' of 'type(results)' is not bound by a prior "
3286                             "'type' directive");
3287     fmt.allResultTypes = true;
3288   } else {
3289     return emitError(loc, "invalid argument to 'type' directive");
3290   }
3291   return ::mlir::success();
3292 }
3293 
3294 //===----------------------------------------------------------------------===//
3295 // Interface
3296 //===----------------------------------------------------------------------===//
3297 
3298 void mlir::tblgen::generateOpFormat(const Operator &constOp, OpClass &opClass) {
3299   // TODO: Operator doesn't expose all necessary functionality via
3300   // the const interface.
3301   Operator &op = const_cast<Operator &>(constOp);
3302   if (!op.hasAssemblyFormat())
3303     return;
3304 
3305   // Parse the format description.
3306   llvm::SourceMgr mgr;
3307   mgr.AddNewSourceBuffer(
3308       llvm::MemoryBuffer::getMemBuffer(op.getAssemblyFormat()), llvm::SMLoc());
3309   OperationFormat format(op);
3310   if (failed(FormatParser(mgr, format, op).parse())) {
3311     // Exit the process if format errors are treated as fatal.
3312     if (formatErrorIsFatal) {
3313       // Invoke the interrupt handlers to run the file cleanup handlers.
3314       llvm::sys::RunInterruptHandlers();
3315       std::exit(1);
3316     }
3317     return;
3318   }
3319 
3320   // Generate the printer and parser based on the parsed format.
3321   format.genParser(op, opClass);
3322   format.genPrinter(op, opClass);
3323 }
3324