1 //===- OpFormatGen.cpp - MLIR operation asm format generator --------------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 
9 #include "OpFormatGen.h"
10 #include "FormatGen.h"
11 #include "OpClass.h"
12 #include "mlir/Support/LLVM.h"
13 #include "mlir/TableGen/Class.h"
14 #include "mlir/TableGen/Format.h"
15 #include "mlir/TableGen/Operator.h"
16 #include "mlir/TableGen/Trait.h"
17 #include "llvm/ADT/MapVector.h"
18 #include "llvm/ADT/Sequence.h"
19 #include "llvm/ADT/SetVector.h"
20 #include "llvm/ADT/SmallBitVector.h"
21 #include "llvm/ADT/StringExtras.h"
22 #include "llvm/ADT/TypeSwitch.h"
23 #include "llvm/Support/Signals.h"
24 #include "llvm/Support/SourceMgr.h"
25 #include "llvm/TableGen/Record.h"
26 
27 #define DEBUG_TYPE "mlir-tblgen-opformatgen"
28 
29 using namespace mlir;
30 using namespace mlir::tblgen;
31 
32 //===----------------------------------------------------------------------===//
33 // VariableElement
34 
35 namespace {
36 /// This class represents an instance of an op variable element. A variable
37 /// refers to something registered on the operation itself, e.g. an operand,
38 /// result, attribute, region, or successor.
39 template <typename VarT, VariableElement::Kind VariableKind>
40 class OpVariableElement : public VariableElementBase<VariableKind> {
41 public:
42   using Base = OpVariableElement<VarT, VariableKind>;
43 
44   /// Create an op variable element with the variable value.
OpVariableElement(const VarT * var)45   OpVariableElement(const VarT *var) : var(var) {}
46 
47   /// Get the variable.
getVar()48   const VarT *getVar() { return var; }
49 
50 protected:
51   /// The op variable, e.g. a type or attribute constraint.
52   const VarT *var;
53 };
54 
55 /// This class represents a variable that refers to an attribute argument.
56 struct AttributeVariable
57     : public OpVariableElement<NamedAttribute, VariableElement::Attribute> {
58   using Base::Base;
59 
60   /// Return the constant builder call for the type of this attribute, or None
61   /// if it doesn't have one.
getTypeBuilder__anonf2a272a10111::AttributeVariable62   llvm::Optional<StringRef> getTypeBuilder() const {
63     llvm::Optional<Type> attrType = var->attr.getValueType();
64     return attrType ? attrType->getBuilderCall() : llvm::None;
65   }
66 
67   /// Return if this attribute refers to a UnitAttr.
isUnitAttr__anonf2a272a10111::AttributeVariable68   bool isUnitAttr() const {
69     return var->attr.getBaseAttr().getAttrDefName() == "UnitAttr";
70   }
71 
72   /// Indicate if this attribute is printed "qualified" (that is it is
73   /// prefixed with the `#dialect.mnemonic`).
shouldBeQualified__anonf2a272a10111::AttributeVariable74   bool shouldBeQualified() { return shouldBeQualifiedFlag; }
setShouldBeQualified__anonf2a272a10111::AttributeVariable75   void setShouldBeQualified(bool qualified = true) {
76     shouldBeQualifiedFlag = qualified;
77   }
78 
79 private:
80   bool shouldBeQualifiedFlag = false;
81 };
82 
83 /// This class represents a variable that refers to an operand argument.
84 using OperandVariable =
85     OpVariableElement<NamedTypeConstraint, VariableElement::Operand>;
86 
87 /// This class represents a variable that refers to a result.
88 using ResultVariable =
89     OpVariableElement<NamedTypeConstraint, VariableElement::Result>;
90 
91 /// This class represents a variable that refers to a region.
92 using RegionVariable = OpVariableElement<NamedRegion, VariableElement::Region>;
93 
94 /// This class represents a variable that refers to a successor.
95 using SuccessorVariable =
96     OpVariableElement<NamedSuccessor, VariableElement::Successor>;
97 } // namespace
98 
99 //===----------------------------------------------------------------------===//
100 // DirectiveElement
101 
102 namespace {
103 /// This class represents the `operands` directive. This directive represents
104 /// all of the operands of an operation.
105 using OperandsDirective = DirectiveElementBase<DirectiveElement::Operands>;
106 
107 /// This class represents the `results` directive. This directive represents
108 /// all of the results of an operation.
109 using ResultsDirective = DirectiveElementBase<DirectiveElement::Results>;
110 
111 /// This class represents the `regions` directive. This directive represents
112 /// all of the regions of an operation.
113 using RegionsDirective = DirectiveElementBase<DirectiveElement::Regions>;
114 
115 /// This class represents the `successors` directive. This directive represents
116 /// all of the successors of an operation.
117 using SuccessorsDirective = DirectiveElementBase<DirectiveElement::Successors>;
118 
119 /// This class represents the `attr-dict` directive. This directive represents
120 /// the attribute dictionary of the operation.
121 class AttrDictDirective
122     : public DirectiveElementBase<DirectiveElement::AttrDict> {
123 public:
AttrDictDirective(bool withKeyword)124   explicit AttrDictDirective(bool withKeyword) : withKeyword(withKeyword) {}
125 
126   /// Return whether the dictionary should be printed with the 'attributes'
127   /// keyword.
isWithKeyword() const128   bool isWithKeyword() const { return withKeyword; }
129 
130 private:
131   /// If the dictionary should be printed with the 'attributes' keyword.
132   bool withKeyword;
133 };
134 
135 /// This class represents the `functional-type` directive. This directive takes
136 /// two arguments and formats them, respectively, as the inputs and results of a
137 /// FunctionType.
138 class FunctionalTypeDirective
139     : public DirectiveElementBase<DirectiveElement::FunctionalType> {
140 public:
FunctionalTypeDirective(FormatElement * inputs,FormatElement * results)141   FunctionalTypeDirective(FormatElement *inputs, FormatElement *results)
142       : inputs(inputs), results(results) {}
143 
getInputs() const144   FormatElement *getInputs() const { return inputs; }
getResults() const145   FormatElement *getResults() const { return results; }
146 
147 private:
148   /// The input and result arguments.
149   FormatElement *inputs, *results;
150 };
151 
152 /// This class represents the `type` directive.
153 class TypeDirective : public DirectiveElementBase<DirectiveElement::Type> {
154 public:
TypeDirective(FormatElement * arg)155   TypeDirective(FormatElement *arg) : arg(arg) {}
156 
getArg() const157   FormatElement *getArg() const { return arg; }
158 
159   /// Indicate if this type is printed "qualified" (that is it is
160   /// prefixed with the `!dialect.mnemonic`).
shouldBeQualified()161   bool shouldBeQualified() { return shouldBeQualifiedFlag; }
setShouldBeQualified(bool qualified=true)162   void setShouldBeQualified(bool qualified = true) {
163     shouldBeQualifiedFlag = qualified;
164   }
165 
166 private:
167   /// The argument that is used to format the directive.
168   FormatElement *arg;
169 
170   bool shouldBeQualifiedFlag = false;
171 };
172 
173 /// This class represents a group of order-independent optional clauses. Each
174 /// clause starts with a literal element and has a coressponding parsing
175 /// element. A parsing element is a continous sequence of format elements.
176 /// Each clause can appear 0 or 1 time.
177 class OIListElement : public DirectiveElementBase<DirectiveElement::OIList> {
178 public:
OIListElement(std::vector<FormatElement * > && literalElements,std::vector<std::vector<FormatElement * >> && parsingElements)179   OIListElement(std::vector<FormatElement *> &&literalElements,
180                 std::vector<std::vector<FormatElement *>> &&parsingElements)
181       : literalElements(std::move(literalElements)),
182         parsingElements(std::move(parsingElements)) {}
183 
184   /// Returns a range to iterate over the LiteralElements.
getLiteralElements() const185   auto getLiteralElements() const {
186     function_ref<LiteralElement *(FormatElement * el)>
187         literalElementCastConverter =
188             [](FormatElement *el) { return cast<LiteralElement>(el); };
189     return llvm::map_range(literalElements, literalElementCastConverter);
190   }
191 
192   /// Returns a range to iterate over the parsing elements corresponding to the
193   /// clauses.
getParsingElements() const194   ArrayRef<std::vector<FormatElement *>> getParsingElements() const {
195     return parsingElements;
196   }
197 
198   /// Returns a range to iterate over tuples of parsing and literal elements.
getClauses() const199   auto getClauses() const {
200     return llvm::zip(getLiteralElements(), getParsingElements());
201   }
202 
203   /// If the parsing element is a single UnitAttr element, then it returns the
204   /// attribute variable. Otherwise, returns nullptr.
205   AttributeVariable *
getUnitAttrParsingElement(ArrayRef<FormatElement * > pelement)206   getUnitAttrParsingElement(ArrayRef<FormatElement *> pelement) {
207     if (pelement.size() == 1) {
208       auto *attrElem = dyn_cast<AttributeVariable>(pelement[0]);
209       if (attrElem && attrElem->isUnitAttr())
210         return attrElem;
211     }
212     return nullptr;
213   }
214 
215 private:
216   /// A vector of `LiteralElement` objects. Each element stores the keyword
217   /// for one case of oilist element. For example, an oilist element along with
218   /// the `literalElements` vector:
219   /// ```
220   ///  oilist [ `keyword` `=` `(` $arg0 `)` | `otherKeyword` `<` $arg1 `>`]
221   ///  literalElements = { `keyword`, `otherKeyword` }
222   /// ```
223   std::vector<FormatElement *> literalElements;
224 
225   /// A vector of valid declarative assembly format vectors. Each object in
226   /// parsing elements is a vector of elements in assembly format syntax.
227   /// For example, an oilist element along with the parsingElements vector:
228   /// ```
229   ///  oilist [ `keyword` `=` `(` $arg0 `)` | `otherKeyword` `<` $arg1 `>`]
230   ///  parsingElements = {
231   ///    { `=`, `(`, $arg0, `)` },
232   ///    { `<`, $arg1, `>` }
233   ///  }
234   /// ```
235   std::vector<std::vector<FormatElement *>> parsingElements;
236 };
237 } // namespace
238 
239 //===----------------------------------------------------------------------===//
240 // OperationFormat
241 //===----------------------------------------------------------------------===//
242 
243 namespace {
244 
245 using ConstArgument =
246     llvm::PointerUnion<const NamedAttribute *, const NamedTypeConstraint *>;
247 
248 struct OperationFormat {
249   /// This class represents a specific resolver for an operand or result type.
250   class TypeResolution {
251   public:
252     TypeResolution() = default;
253 
254     /// Get the index into the buildable types for this type, or None.
getBuilderIdx() const255     Optional<int> getBuilderIdx() const { return builderIdx; }
setBuilderIdx(int idx)256     void setBuilderIdx(int idx) { builderIdx = idx; }
257 
258     /// Get the variable this type is resolved to, or nullptr.
getVariable() const259     const NamedTypeConstraint *getVariable() const {
260       return resolver.dyn_cast<const NamedTypeConstraint *>();
261     }
262     /// Get the attribute this type is resolved to, or nullptr.
getAttribute() const263     const NamedAttribute *getAttribute() const {
264       return resolver.dyn_cast<const NamedAttribute *>();
265     }
266     /// Get the transformer for the type of the variable, or None.
getVarTransformer() const267     Optional<StringRef> getVarTransformer() const {
268       return variableTransformer;
269     }
setResolver(ConstArgument arg,Optional<StringRef> transformer)270     void setResolver(ConstArgument arg, Optional<StringRef> transformer) {
271       resolver = arg;
272       variableTransformer = transformer;
273       assert(getVariable() || getAttribute());
274     }
275 
276   private:
277     /// If the type is resolved with a buildable type, this is the index into
278     /// 'buildableTypes' in the parent format.
279     Optional<int> builderIdx;
280     /// If the type is resolved based upon another operand or result, this is
281     /// the variable or the attribute that this type is resolved to.
282     ConstArgument resolver;
283     /// If the type is resolved based upon another operand or result, this is
284     /// a transformer to apply to the variable when resolving.
285     Optional<StringRef> variableTransformer;
286   };
287 
288   /// The context in which an element is generated.
289   enum class GenContext {
290     /// The element is generated at the top-level or with the same behaviour.
291     Normal,
292     /// The element is generated inside an optional group.
293     Optional
294   };
295 
OperationFormat__anonf2a272a10411::OperationFormat296   OperationFormat(const Operator &op)
297 
298   {
299     operandTypes.resize(op.getNumOperands(), TypeResolution());
300     resultTypes.resize(op.getNumResults(), TypeResolution());
301 
302     hasImplicitTermTrait = llvm::any_of(op.getTraits(), [](const Trait &trait) {
303       return trait.getDef().isSubClassOf("SingleBlockImplicitTerminator");
304     });
305 
306     hasSingleBlockTrait =
307         hasImplicitTermTrait || op.getTrait("::mlir::OpTrait::SingleBlock");
308   }
309 
310   /// Generate the operation parser from this format.
311   void genParser(Operator &op, OpClass &opClass);
312   /// Generate the parser code for a specific format element.
313   void genElementParser(FormatElement *element, MethodBody &body,
314                         FmtContext &attrTypeCtx,
315                         GenContext genCtx = GenContext::Normal);
316   /// Generate the C++ to resolve the types of operands and results during
317   /// parsing.
318   void genParserTypeResolution(Operator &op, MethodBody &body);
319   /// Generate the C++ to resolve the types of the operands during parsing.
320   void genParserOperandTypeResolution(
321       Operator &op, MethodBody &body,
322       function_ref<void(TypeResolution &, StringRef)> emitTypeResolver);
323   /// Generate the C++ to resolve regions during parsing.
324   void genParserRegionResolution(Operator &op, MethodBody &body);
325   /// Generate the C++ to resolve successors during parsing.
326   void genParserSuccessorResolution(Operator &op, MethodBody &body);
327   /// Generate the C++ to handling variadic segment size traits.
328   void genParserVariadicSegmentResolution(Operator &op, MethodBody &body);
329 
330   /// Generate the operation printer from this format.
331   void genPrinter(Operator &op, OpClass &opClass);
332 
333   /// Generate the printer code for a specific format element.
334   void genElementPrinter(FormatElement *element, MethodBody &body, Operator &op,
335                          bool &shouldEmitSpace, bool &lastWasPunctuation);
336 
337   /// The various elements in this format.
338   std::vector<FormatElement *> elements;
339 
340   /// A flag indicating if all operand/result types were seen. If the format
341   /// contains these, it can not contain individual type resolvers.
342   bool allOperands = false, allOperandTypes = false, allResultTypes = false;
343 
344   /// A flag indicating if this operation infers its result types
345   bool infersResultTypes = false;
346 
347   /// A flag indicating if this operation has the SingleBlockImplicitTerminator
348   /// trait.
349   bool hasImplicitTermTrait;
350 
351   /// A flag indicating if this operation has the SingleBlock trait.
352   bool hasSingleBlockTrait;
353 
354   /// A map of buildable types to indices.
355   llvm::MapVector<StringRef, int, llvm::StringMap<int>> buildableTypes;
356 
357   /// The index of the buildable type, if valid, for every operand and result.
358   std::vector<TypeResolution> operandTypes, resultTypes;
359 
360   /// The set of attributes explicitly used within the format.
361   SmallVector<const NamedAttribute *, 8> usedAttributes;
362   llvm::StringSet<> inferredAttributes;
363 };
364 } // namespace
365 
366 //===----------------------------------------------------------------------===//
367 // Parser Gen
368 
369 /// Returns true if we can format the given attribute as an EnumAttr in the
370 /// parser format.
canFormatEnumAttr(const NamedAttribute * attr)371 static bool canFormatEnumAttr(const NamedAttribute *attr) {
372   Attribute baseAttr = attr->attr.getBaseAttr();
373   const EnumAttr *enumAttr = dyn_cast<EnumAttr>(&baseAttr);
374   if (!enumAttr)
375     return false;
376 
377   // The attribute must have a valid underlying type and a constant builder.
378   return !enumAttr->getUnderlyingType().empty() &&
379          !enumAttr->getConstBuilderTemplate().empty();
380 }
381 
382 /// Returns if we should format the given attribute as an SymbolNameAttr.
shouldFormatSymbolNameAttr(const NamedAttribute * attr)383 static bool shouldFormatSymbolNameAttr(const NamedAttribute *attr) {
384   return attr->attr.getBaseAttr().getAttrDefName() == "SymbolNameAttr";
385 }
386 
/// Parser snippet emitting a `parseCustomAttributeWithFallback` call that
/// stores the result in `{0}Attr` and records it in `result.attributes`.
///
/// {0}: The name of the attribute.
/// {1}: The type for the attribute.
const char *const attrParserCode = R"(
  if (parser.parseCustomAttributeWithFallback({0}Attr, {1}, "{0}",
          result.attributes)) {{
    return ::mlir::failure();
  }
)";

/// Parser snippet emitting a plain `parseAttribute` call (no custom-syntax
/// hook) that stores the result in `{0}Attr`.
///
/// {0}: The name of the attribute.
/// {1}: The type for the attribute.
const char *const genericAttrParserCode = R"(
  if (parser.parseAttribute({0}Attr, {1}, "{0}", result.attributes))
    return ::mlir::failure();
)";

/// Parser snippet for an optional attribute. A missing attribute is accepted;
/// the generated code fails only if an attribute was present but could not be
/// parsed.
///
/// {0}: The name of the attribute.
/// {1}: The type for the attribute.
const char *const optionalAttrParserCode = R"(
  {
    ::mlir::OptionalParseResult parseResult =
      parser.parseOptionalAttribute({0}Attr, {1}, "{0}", result.attributes);
    if (parseResult.hasValue() && failed(*parseResult))
      return ::mlir::failure();
  }
)";

/// Parser snippet for a required symbol name attribute.
///
/// {0}: The name of the attribute.
const char *const symbolNameAttrParserCode = R"(
  if (parser.parseSymbolName({0}Attr, "{0}", result.attributes))
    return ::mlir::failure();
)";
/// Parser snippet for an optional symbol name attribute; its result is
/// deliberately discarded.
///
/// {0}: The name of the attribute.
const char *const optionalSymbolNameAttrParserCode = R"(
  // Parsing an optional symbol name doesn't fail, so no need to check the
  // result.
  (void)parser.parseOptionalSymbolName({0}Attr, "{0}", result.attributes);
)";
428 
/// The code snippet used to generate a parser call for an enum attribute.
/// The enum is accepted either as a bare keyword from the allowed set or as a
/// string attribute; its text is then symbolized and rebuilt as `{0}Attr`.
///
/// {0}: The name of the attribute.
/// {1}: The c++ namespace for the enum symbolize functions.
/// {2}: The function to symbolize a string of the enum.
/// {3}: The constant builder call to create an attribute of the enum type.
/// {4}: The set of allowed enum keywords.
/// {5}: The error message on failure when the enum isn't present.
///
/// The error-reporting statement ends with a single `;`. A stray `;;` here
/// would emit an empty statement into every generated parser.
const char *const enumAttrParserCode = R"(
  {
    ::llvm::StringRef attrStr;
    ::mlir::NamedAttrList attrStorage;
    auto loc = parser.getCurrentLocation();
    if (parser.parseOptionalKeyword(&attrStr, {4})) {
      ::mlir::StringAttr attrVal;
      ::mlir::OptionalParseResult parseResult =
        parser.parseOptionalAttribute(attrVal,
                                      parser.getBuilder().getNoneType(),
                                      "{0}", attrStorage);
      if (parseResult.hasValue()) {{
        if (failed(*parseResult))
          return ::mlir::failure();
        attrStr = attrVal.getValue();
      } else {
        {5}
      }
    }
    if (!attrStr.empty()) {
      auto attrOptional = {1}::{2}(attrStr);
      if (!attrOptional)
        return parser.emitError(loc, "invalid ")
               << "{0} attribute specification: \"" << attrStr << '"';

      {0}Attr = {3};
      result.addAttribute("{0}", {0}Attr);
    }
  }
)";
467 
/// Parser snippet for a variadic operand. It records the location in
/// `{0}OperandsLoc`, then parses an operand list into `{0}Operands`.
///
/// {0}: The name of the operand.
const char *const variadicOperandParserCode = R"(
  {0}OperandsLoc = parser.getCurrentLocation();
  if (parser.parseOperandList({0}Operands))
    return ::mlir::failure();
)";
/// Parser snippet for an optional operand. If `parseOptionalOperand` produced
/// a result, the generated code checks it for failure and appends the operand
/// to `{0}Operands`.
///
/// {0}: The name of the operand.
const char *const optionalOperandParserCode = R"(
  {
    {0}OperandsLoc = parser.getCurrentLocation();
    ::mlir::OpAsmParser::UnresolvedOperand operand;
    ::mlir::OptionalParseResult parseResult =
                                    parser.parseOptionalOperand(operand);
    if (parseResult.hasValue()) {
      if (failed(*parseResult))
        return ::mlir::failure();
      {0}Operands.push_back(operand);
    }
  }
)";
/// Parser snippet for a single operand, parsed into `{0}RawOperands[0]`.
///
/// {0}: The name of the operand.
const char *const operandParserCode = R"(
  {0}OperandsLoc = parser.getCurrentLocation();
  if (parser.parseOperand({0}RawOperands[0]))
    return ::mlir::failure();
)";
/// Parser snippet for a VariadicOfVariadic operand. It parses a sequence of
/// parenthesized operand lists separated by commas, and stops at the first
/// group that does not start with `(`. Every operand is appended to
/// `{0}Operands`, and the size of each group to `{0}OperandGroupSizes`.
///
/// {0}: The name of the operand.
/// Only {0} is referenced by this snippet.
const char *const variadicOfVariadicOperandParserCode = R"(
  {
    {0}OperandsLoc = parser.getCurrentLocation();
    int32_t curSize = 0;
    do {
      if (parser.parseOptionalLParen())
        break;
      if (parser.parseOperandList({0}Operands) || parser.parseRParen())
        return ::mlir::failure();
      {0}OperandGroupSizes.push_back({0}Operands.size() - curSize);
      curSize = {0}Operands.size();
    } while (succeeded(parser.parseOptionalComma()));
  }
)";
513 
/// Parser snippet for the types of a VariadicOfVariadic argument. It parses
/// comma-separated parenthesized groups (an empty `()` is allowed) and appends
/// every type to `{0}Types`.
///
/// {0}: The name for the type list.
const char *const variadicOfVariadicTypeParserCode = R"(
  do {
    if (parser.parseOptionalLParen())
      break;
    if (parser.parseOptionalRParen() &&
        (parser.parseTypeList({0}Types) || parser.parseRParen()))
      return ::mlir::failure();
  } while (succeeded(parser.parseOptionalComma()));
)";
/// Parser snippet for a variadic type list parsed into `{0}Types`.
///
/// {0}: The name for the type list.
const char *const variadicTypeParserCode = R"(
  if (parser.parseTypeList({0}Types))
    return ::mlir::failure();
)";
/// Parser snippet for an optional type. If one is present and parses
/// successfully, it is appended to `{0}Types`.
///
/// {0}: The name for the type list.
const char *const optionalTypeParserCode = R"(
  {
    ::mlir::Type optionalType;
    ::mlir::OptionalParseResult parseResult =
                                    parser.parseOptionalType(optionalType);
    if (parseResult.hasValue()) {
      if (failed(*parseResult))
        return ::mlir::failure();
      {0}Types.push_back(optionalType);
    }
  }
)";
/// Parser snippet for a single type, parsed via
/// `parseCustomTypeWithFallback`.
///
/// {0}: The C++ type class of the local the type is parsed into.
/// {1}: The name of the argument; the result is stored in `{1}RawTypes[0]`.
const char *const typeParserCode = R"(
  {
    {0} type;
    if (parser.parseCustomTypeWithFallback(type))
      return ::mlir::failure();
    {1}RawTypes[0] = type;
  }
)";
/// Parser snippet for a single type parsed with plain `parseType`.
///
/// {1}: The name of the argument; the result is stored in `{1}RawTypes[0]`.
/// ({0} is not referenced by this snippet.)
const char *const qualifiedTypeParserCode = R"(
  if (parser.parseType({1}RawTypes[0]))
    return ::mlir::failure();
)";
554 
/// Parser snippet for a `functional-type` directive. A single FunctionType is
/// parsed, and its inputs and results are then assigned to the two type lists.
///
/// {0}: The name for the input type list.
/// {1}: The name for the result type list.
const char *const functionalTypeParserCode = R"(
  ::mlir::FunctionType {0}__{1}_functionType;
  if (parser.parseType({0}__{1}_functionType))
    return ::mlir::failure();
  {0}Types = {0}__{1}_functionType.getInputs();
  {1}Types = {0}__{1}_functionType.getResults();
)";

/// Parser snippet that calls the op's static `inferReturnTypes` with the
/// already-parsed operands, attributes and regions, then adds the inferred
/// types to the result.
///
/// {0}: The operation class name
const char *const inferReturnTypesParserCode = R"(
  ::llvm::SmallVector<::mlir::Type> inferredReturnTypes;
  if (::mlir::failed({0}::inferReturnTypes(parser.getContext(),
      result.location, result.operands,
      result.attributes.getDictionary(parser.getContext()),
      result.regions, inferredReturnTypes)))
    return ::mlir::failure();
  result.addTypes(inferredReturnTypes);
)";
579 
// NOTE: All snippets below are `const char *const`, like every other snippet
// in this file. A plain `const char *` would be a mutable pointer with
// external linkage, exposing a writable global symbol for no reason.

/// The code snippet used to generate a parser call for a region list. The
/// first region is optional; further regions follow after commas.
///
/// {0}: The name for the region list.
const char *const regionListParserCode = R"(
  {
    std::unique_ptr<::mlir::Region> region;
    auto firstRegionResult = parser.parseOptionalRegion(region);
    if (firstRegionResult.hasValue()) {
      if (failed(*firstRegionResult))
        return ::mlir::failure();
      {0}Regions.emplace_back(std::move(region));

      // Parse any trailing regions.
      while (succeeded(parser.parseOptionalComma())) {
        region = std::make_unique<::mlir::Region>();
        if (parser.parseRegion(*region))
          return ::mlir::failure();
        {0}Regions.emplace_back(std::move(region));
      }
    }
  }
)";

/// The code snippet used to ensure a list of regions have terminators.
///
/// {0}: The name of the region list.
const char *const regionListEnsureTerminatorParserCode = R"(
  for (auto &region : {0}Regions)
    ensureTerminator(*region, parser.getBuilder(), result.location);
)";

/// The code snippet used to ensure a list of regions have a block.
///
/// {0}: The name of the region list.
const char *const regionListEnsureSingleBlockParserCode = R"(
  for (auto &region : {0}Regions)
    if (region->empty()) region->emplaceBlock();
)";

/// The code snippet used to generate a parser call for an optional region.
///
/// {0}: The name of the region.
const char *const optionalRegionParserCode = R"(
  {
     auto parseResult = parser.parseOptionalRegion(*{0}Region);
     if (parseResult.hasValue() && failed(*parseResult))
       return ::mlir::failure();
  }
)";

/// The code snippet used to generate a parser call for a region.
///
/// {0}: The name of the region.
const char *const regionParserCode = R"(
  if (parser.parseRegion(*{0}Region))
    return ::mlir::failure();
)";

/// The code snippet used to ensure a region has a terminator.
///
/// {0}: The name of the region.
const char *const regionEnsureTerminatorParserCode = R"(
  ensureTerminator(*{0}Region, parser.getBuilder(), result.location);
)";

/// The code snippet used to ensure a region has a block.
///
/// {0}: The name of the region.
const char *const regionEnsureSingleBlockParserCode = R"(
  if ({0}Region->empty()) {0}Region->emplaceBlock();
)";

/// The code snippet used to generate a parser call for a successor list. The
/// first successor is optional; further successors follow after commas.
///
/// {0}: The name for the successor list.
const char *const successorListParserCode = R"(
  {
    ::mlir::Block *succ;
    auto firstSucc = parser.parseOptionalSuccessor(succ);
    if (firstSucc.hasValue()) {
      if (failed(*firstSucc))
        return ::mlir::failure();
      {0}Successors.emplace_back(succ);

      // Parse any trailing successors.
      while (succeeded(parser.parseOptionalComma())) {
        if (parser.parseSuccessor(succ))
          return ::mlir::failure();
        {0}Successors.emplace_back(succ);
      }
    }
  }
)";

/// The code snippet used to generate a parser call for a successor.
///
/// {0}: The name of the successor.
const char *const successorParserCode = R"(
  if (parser.parseSuccessor({0}Successor))
    return ::mlir::failure();
)";

/// The code snippet used to generate a parser for OIList. It emits an error
/// if the clause was already seen, otherwise it marks the clause as seen.
///
/// {0}: literal keyword corresponding to a case for oilist
const char *const oilistParserCode = R"(
  if ({0}Clause) {
    return parser.emitError(parser.getNameLoc())
          << "`{0}` clause can appear at most once in the expansion of the "
             "oilist directive";
  }
  {0}Clause = true;
)";
693 
namespace {
/// Classifies how many values a parsed operand or result argument can bind.
enum class ArgumentLengthKind {
  /// Zero or more groups, each of which holds zero or more elements.
  VariadicOfVariadic,
  /// Zero or more elements.
  Variadic,
  /// Either no element or exactly one.
  Optional,
  /// Exactly one element, always.
  Single
};
} // namespace
708 
709 /// Get the length kind for the given constraint.
710 static ArgumentLengthKind
getArgumentLengthKind(const NamedTypeConstraint * var)711 getArgumentLengthKind(const NamedTypeConstraint *var) {
712   if (var->isOptional())
713     return ArgumentLengthKind::Optional;
714   if (var->isVariadicOfVariadic())
715     return ArgumentLengthKind::VariadicOfVariadic;
716   if (var->isVariadic())
717     return ArgumentLengthKind::Variadic;
718   return ArgumentLengthKind::Single;
719 }
720 
721 /// Get the name used for the type list for the given type directive operand.
722 /// 'lengthKind' to the corresponding kind for the given argument.
723 static StringRef getTypeListName(FormatElement *arg,
724                                  ArgumentLengthKind &lengthKind) {
725   if (auto *operand = dyn_cast<OperandVariable>(arg)) {
726     lengthKind = getArgumentLengthKind(operand->getVar());
727     return operand->getVar()->name;
728   }
729   if (auto *result = dyn_cast<ResultVariable>(arg)) {
730     lengthKind = getArgumentLengthKind(result->getVar());
731     return result->getVar()->name;
732   }
733   lengthKind = ArgumentLengthKind::Variadic;
734   if (isa<OperandsDirective>(arg))
735     return "allOperand";
736   if (isa<ResultsDirective>(arg))
737     return "allResult";
738   llvm_unreachable("unknown 'type' directive argument");
739 }
740 
741 /// Generate the parser for a literal value.
742 static void genLiteralParser(StringRef value, MethodBody &body) {
743   // Handle the case of a keyword/identifier.
744   if (value.front() == '_' || isalpha(value.front())) {
745     body << "Keyword(\"" << value << "\")";
746     return;
747   }
748   body << (StringRef)StringSwitch<StringRef>(value)
749               .Case("->", "Arrow()")
750               .Case(":", "Colon()")
751               .Case(",", "Comma()")
752               .Case("=", "Equal()")
753               .Case("<", "Less()")
754               .Case(">", "Greater()")
755               .Case("{", "LBrace()")
756               .Case("}", "RBrace()")
757               .Case("(", "LParen()")
758               .Case(")", "RParen()")
759               .Case("[", "LSquare()")
760               .Case("]", "RSquare()")
761               .Case("?", "Question()")
762               .Case("+", "Plus()")
763               .Case("*", "Star()");
764 }
765 
766 /// Generate the storage code required for parsing the given element.
genElementParserStorage(FormatElement * element,const Operator & op,MethodBody & body)767 static void genElementParserStorage(FormatElement *element, const Operator &op,
768                                     MethodBody &body) {
769   if (auto *optional = dyn_cast<OptionalElement>(element)) {
770     ArrayRef<FormatElement *> elements = optional->getThenElements();
771 
772     // If the anchor is a unit attribute, it won't be parsed directly so elide
773     // it.
774     auto *anchor = dyn_cast<AttributeVariable>(optional->getAnchor());
775     FormatElement *elidedAnchorElement = nullptr;
776     if (anchor && anchor != elements.front() && anchor->isUnitAttr())
777       elidedAnchorElement = anchor;
778     for (FormatElement *childElement : elements)
779       if (childElement != elidedAnchorElement)
780         genElementParserStorage(childElement, op, body);
781     for (FormatElement *childElement : optional->getElseElements())
782       genElementParserStorage(childElement, op, body);
783 
784   } else if (auto *oilist = dyn_cast<OIListElement>(element)) {
785     for (ArrayRef<FormatElement *> pelement : oilist->getParsingElements()) {
786       if (!oilist->getUnitAttrParsingElement(pelement))
787         for (FormatElement *element : pelement)
788           genElementParserStorage(element, op, body);
789     }
790 
791   } else if (auto *custom = dyn_cast<CustomDirective>(element)) {
792     for (FormatElement *paramElement : custom->getArguments())
793       genElementParserStorage(paramElement, op, body);
794 
795   } else if (isa<OperandsDirective>(element)) {
796     body << "  ::llvm::SmallVector<::mlir::OpAsmParser::UnresolvedOperand, 4> "
797             "allOperands;\n";
798 
799   } else if (isa<RegionsDirective>(element)) {
800     body << "  ::llvm::SmallVector<std::unique_ptr<::mlir::Region>, 2> "
801             "fullRegions;\n";
802 
803   } else if (isa<SuccessorsDirective>(element)) {
804     body << "  ::llvm::SmallVector<::mlir::Block *, 2> fullSuccessors;\n";
805 
806   } else if (auto *attr = dyn_cast<AttributeVariable>(element)) {
807     const NamedAttribute *var = attr->getVar();
808     body << llvm::formatv("  {0} {1}Attr;\n", var->attr.getStorageType(),
809                           var->name);
810 
811   } else if (auto *operand = dyn_cast<OperandVariable>(element)) {
812     StringRef name = operand->getVar()->name;
813     if (operand->getVar()->isVariableLength()) {
814       body
815           << "  ::llvm::SmallVector<::mlir::OpAsmParser::UnresolvedOperand, 4> "
816           << name << "Operands;\n";
817       if (operand->getVar()->isVariadicOfVariadic()) {
818         body << "    llvm::SmallVector<int32_t> " << name
819              << "OperandGroupSizes;\n";
820       }
821     } else {
822       body << "  ::mlir::OpAsmParser::UnresolvedOperand " << name
823            << "RawOperands[1];\n"
824            << "  ::llvm::ArrayRef<::mlir::OpAsmParser::UnresolvedOperand> "
825            << name << "Operands(" << name << "RawOperands);";
826     }
827     body << llvm::formatv("  ::llvm::SMLoc {0}OperandsLoc;\n"
828                           "  (void){0}OperandsLoc;\n",
829                           name);
830 
831   } else if (auto *region = dyn_cast<RegionVariable>(element)) {
832     StringRef name = region->getVar()->name;
833     if (region->getVar()->isVariadic()) {
834       body << llvm::formatv(
835           "  ::llvm::SmallVector<std::unique_ptr<::mlir::Region>, 2> "
836           "{0}Regions;\n",
837           name);
838     } else {
839       body << llvm::formatv("  std::unique_ptr<::mlir::Region> {0}Region = "
840                             "std::make_unique<::mlir::Region>();\n",
841                             name);
842     }
843 
844   } else if (auto *successor = dyn_cast<SuccessorVariable>(element)) {
845     StringRef name = successor->getVar()->name;
846     if (successor->getVar()->isVariadic()) {
847       body << llvm::formatv("  ::llvm::SmallVector<::mlir::Block *, 2> "
848                             "{0}Successors;\n",
849                             name);
850     } else {
851       body << llvm::formatv("  ::mlir::Block *{0}Successor = nullptr;\n", name);
852     }
853 
854   } else if (auto *dir = dyn_cast<TypeDirective>(element)) {
855     ArgumentLengthKind lengthKind;
856     StringRef name = getTypeListName(dir->getArg(), lengthKind);
857     if (lengthKind != ArgumentLengthKind::Single)
858       body << "  ::llvm::SmallVector<::mlir::Type, 1> " << name << "Types;\n";
859     else
860       body << llvm::formatv("  ::mlir::Type {0}RawTypes[1];\n", name)
861            << llvm::formatv(
862                   "  ::llvm::ArrayRef<::mlir::Type> {0}Types({0}RawTypes);\n",
863                   name);
864   } else if (auto *dir = dyn_cast<FunctionalTypeDirective>(element)) {
865     ArgumentLengthKind ignored;
866     body << "  ::llvm::ArrayRef<::mlir::Type> "
867          << getTypeListName(dir->getInputs(), ignored) << "Types;\n";
868     body << "  ::llvm::ArrayRef<::mlir::Type> "
869          << getTypeListName(dir->getResults(), ignored) << "Types;\n";
870   }
871 }
872 
873 /// Generate the parser for a parameter to a custom directive.
genCustomParameterParser(FormatElement * param,MethodBody & body)874 static void genCustomParameterParser(FormatElement *param, MethodBody &body) {
875   if (auto *attr = dyn_cast<AttributeVariable>(param)) {
876     body << attr->getVar()->name << "Attr";
877   } else if (isa<AttrDictDirective>(param)) {
878     body << "result.attributes";
879   } else if (auto *operand = dyn_cast<OperandVariable>(param)) {
880     StringRef name = operand->getVar()->name;
881     ArgumentLengthKind lengthKind = getArgumentLengthKind(operand->getVar());
882     if (lengthKind == ArgumentLengthKind::VariadicOfVariadic)
883       body << llvm::formatv("{0}OperandGroups", name);
884     else if (lengthKind == ArgumentLengthKind::Variadic)
885       body << llvm::formatv("{0}Operands", name);
886     else if (lengthKind == ArgumentLengthKind::Optional)
887       body << llvm::formatv("{0}Operand", name);
888     else
889       body << formatv("{0}RawOperands[0]", name);
890 
891   } else if (auto *region = dyn_cast<RegionVariable>(param)) {
892     StringRef name = region->getVar()->name;
893     if (region->getVar()->isVariadic())
894       body << llvm::formatv("{0}Regions", name);
895     else
896       body << llvm::formatv("*{0}Region", name);
897 
898   } else if (auto *successor = dyn_cast<SuccessorVariable>(param)) {
899     StringRef name = successor->getVar()->name;
900     if (successor->getVar()->isVariadic())
901       body << llvm::formatv("{0}Successors", name);
902     else
903       body << llvm::formatv("{0}Successor", name);
904 
905   } else if (auto *dir = dyn_cast<RefDirective>(param)) {
906     genCustomParameterParser(dir->getArg(), body);
907 
908   } else if (auto *dir = dyn_cast<TypeDirective>(param)) {
909     ArgumentLengthKind lengthKind;
910     StringRef listName = getTypeListName(dir->getArg(), lengthKind);
911     if (lengthKind == ArgumentLengthKind::VariadicOfVariadic)
912       body << llvm::formatv("{0}TypeGroups", listName);
913     else if (lengthKind == ArgumentLengthKind::Variadic)
914       body << llvm::formatv("{0}Types", listName);
915     else if (lengthKind == ArgumentLengthKind::Optional)
916       body << llvm::formatv("{0}Type", listName);
917     else
918       body << formatv("{0}RawTypes[0]", listName);
919   } else {
920     llvm_unreachable("unknown custom directive parameter");
921   }
922 }
923 
924 /// Generate the parser for a custom directive.
genCustomDirectiveParser(CustomDirective * dir,MethodBody & body)925 static void genCustomDirectiveParser(CustomDirective *dir, MethodBody &body) {
926   body << "  {\n";
927 
928   // Preprocess the directive variables.
929   // * Add a local variable for optional operands and types. This provides a
930   //   better API to the user defined parser methods.
931   // * Set the location of operand variables.
932   for (FormatElement *param : dir->getArguments()) {
933     if (auto *operand = dyn_cast<OperandVariable>(param)) {
934       auto *var = operand->getVar();
935       body << "    " << var->name
936            << "OperandsLoc = parser.getCurrentLocation();\n";
937       if (var->isOptional()) {
938         body << llvm::formatv(
939             "    ::llvm::Optional<::mlir::OpAsmParser::UnresolvedOperand> "
940             "{0}Operand;\n",
941             var->name);
942       } else if (var->isVariadicOfVariadic()) {
943         body << llvm::formatv("    "
944                               "::llvm::SmallVector<::llvm::SmallVector<::mlir::"
945                               "OpAsmParser::UnresolvedOperand>> "
946                               "{0}OperandGroups;\n",
947                               var->name);
948       }
949     } else if (auto *dir = dyn_cast<TypeDirective>(param)) {
950       ArgumentLengthKind lengthKind;
951       StringRef listName = getTypeListName(dir->getArg(), lengthKind);
952       if (lengthKind == ArgumentLengthKind::Optional) {
953         body << llvm::formatv("    ::mlir::Type {0}Type;\n", listName);
954       } else if (lengthKind == ArgumentLengthKind::VariadicOfVariadic) {
955         body << llvm::formatv(
956             "    ::llvm::SmallVector<llvm::SmallVector<::mlir::Type>> "
957             "{0}TypeGroups;\n",
958             listName);
959       }
960     } else if (auto *dir = dyn_cast<RefDirective>(param)) {
961       FormatElement *input = dir->getArg();
962       if (auto *operand = dyn_cast<OperandVariable>(input)) {
963         if (!operand->getVar()->isOptional())
964           continue;
965         body << llvm::formatv(
966             "    {0} {1}Operand = {1}Operands.empty() ? {0}() : "
967             "{1}Operands[0];\n",
968             "::llvm::Optional<::mlir::OpAsmParser::UnresolvedOperand>",
969             operand->getVar()->name);
970 
971       } else if (auto *type = dyn_cast<TypeDirective>(input)) {
972         ArgumentLengthKind lengthKind;
973         StringRef listName = getTypeListName(type->getArg(), lengthKind);
974         if (lengthKind == ArgumentLengthKind::Optional) {
975           body << llvm::formatv("    ::mlir::Type {0}Type = {0}Types.empty() ? "
976                                 "::mlir::Type() : {0}Types[0];\n",
977                                 listName);
978         }
979       }
980     }
981   }
982 
983   body << "    if (parse" << dir->getName() << "(parser";
984   for (FormatElement *param : dir->getArguments()) {
985     body << ", ";
986     genCustomParameterParser(param, body);
987   }
988 
989   body << "))\n"
990        << "      return ::mlir::failure();\n";
991 
992   // After parsing, add handling for any of the optional constructs.
993   for (FormatElement *param : dir->getArguments()) {
994     if (auto *attr = dyn_cast<AttributeVariable>(param)) {
995       const NamedAttribute *var = attr->getVar();
996       if (var->attr.isOptional())
997         body << llvm::formatv("    if ({0}Attr)\n  ", var->name);
998 
999       body << llvm::formatv("    result.addAttribute(\"{0}\", {0}Attr);\n",
1000                             var->name);
1001     } else if (auto *operand = dyn_cast<OperandVariable>(param)) {
1002       const NamedTypeConstraint *var = operand->getVar();
1003       if (var->isOptional()) {
1004         body << llvm::formatv("    if ({0}Operand.has_value())\n"
1005                               "      {0}Operands.push_back(*{0}Operand);\n",
1006                               var->name);
1007       } else if (var->isVariadicOfVariadic()) {
1008         body << llvm::formatv(
1009             "    for (const auto &subRange : {0}OperandGroups) {{\n"
1010             "      {0}Operands.append(subRange.begin(), subRange.end());\n"
1011             "      {0}OperandGroupSizes.push_back(subRange.size());\n"
1012             "    }\n",
1013             var->name, var->constraint.getVariadicOfVariadicSegmentSizeAttr());
1014       }
1015     } else if (auto *dir = dyn_cast<TypeDirective>(param)) {
1016       ArgumentLengthKind lengthKind;
1017       StringRef listName = getTypeListName(dir->getArg(), lengthKind);
1018       if (lengthKind == ArgumentLengthKind::Optional) {
1019         body << llvm::formatv("    if ({0}Type)\n"
1020                               "      {0}Types.push_back({0}Type);\n",
1021                               listName);
1022       } else if (lengthKind == ArgumentLengthKind::VariadicOfVariadic) {
1023         body << llvm::formatv(
1024             "    for (const auto &subRange : {0}TypeGroups)\n"
1025             "      {0}Types.append(subRange.begin(), subRange.end());\n",
1026             listName);
1027       }
1028     }
1029   }
1030 
1031   body << "  }\n";
1032 }
1033 
1034 /// Generate the parser for a enum attribute.
genEnumAttrParser(const NamedAttribute * var,MethodBody & body,FmtContext & attrTypeCtx)1035 static void genEnumAttrParser(const NamedAttribute *var, MethodBody &body,
1036                               FmtContext &attrTypeCtx) {
1037   Attribute baseAttr = var->attr.getBaseAttr();
1038   const EnumAttr &enumAttr = cast<EnumAttr>(baseAttr);
1039   std::vector<EnumAttrCase> cases = enumAttr.getAllCases();
1040 
1041   // Generate the code for building an attribute for this enum.
1042   std::string attrBuilderStr;
1043   {
1044     llvm::raw_string_ostream os(attrBuilderStr);
1045     os << tgfmt(enumAttr.getConstBuilderTemplate(), &attrTypeCtx,
1046                 "attrOptional.value()");
1047   }
1048 
1049   // Build a string containing the cases that can be formatted as a keyword.
1050   std::string validCaseKeywordsStr = "{";
1051   llvm::raw_string_ostream validCaseKeywordsOS(validCaseKeywordsStr);
1052   for (const EnumAttrCase &attrCase : cases)
1053     if (canFormatStringAsKeyword(attrCase.getStr()))
1054       validCaseKeywordsOS << '"' << attrCase.getStr() << "\",";
1055   validCaseKeywordsOS.str().back() = '}';
1056 
1057   // If the attribute is not optional, build an error message for the missing
1058   // attribute.
1059   std::string errorMessage;
1060   if (!var->attr.isOptional()) {
1061     llvm::raw_string_ostream errorMessageOS(errorMessage);
1062     errorMessageOS
1063         << "return parser.emitError(loc, \"expected string or "
1064            "keyword containing one of the following enum values for attribute '"
1065         << var->name << "' [";
1066     llvm::interleaveComma(cases, errorMessageOS, [&](const auto &attrCase) {
1067       errorMessageOS << attrCase.getStr();
1068     });
1069     errorMessageOS << "]\");";
1070   }
1071 
1072   body << formatv(enumAttrParserCode, var->name, enumAttr.getCppNamespace(),
1073                   enumAttr.getStringToSymbolFnName(), attrBuilderStr,
1074                   validCaseKeywordsStr, errorMessage);
1075 }
1076 
genParser(Operator & op,OpClass & opClass)1077 void OperationFormat::genParser(Operator &op, OpClass &opClass) {
1078   SmallVector<MethodParameter> paramList;
1079   paramList.emplace_back("::mlir::OpAsmParser &", "parser");
1080   paramList.emplace_back("::mlir::OperationState &", "result");
1081 
1082   auto *method = opClass.addStaticMethod("::mlir::ParseResult", "parse",
1083                                          std::move(paramList));
1084   auto &body = method->body();
1085 
1086   // Generate variables to store the operands and type within the format. This
1087   // allows for referencing these variables in the presence of optional
1088   // groupings.
1089   for (FormatElement *element : elements)
1090     genElementParserStorage(element, op, body);
1091 
1092   // A format context used when parsing attributes with buildable types.
1093   FmtContext attrTypeCtx;
1094   attrTypeCtx.withBuilder("parser.getBuilder()");
1095 
1096   // Generate parsers for each of the elements.
1097   for (FormatElement *element : elements)
1098     genElementParser(element, body, attrTypeCtx);
1099 
1100   // Generate the code to resolve the operand/result types and successors now
1101   // that they have been parsed.
1102   genParserRegionResolution(op, body);
1103   genParserSuccessorResolution(op, body);
1104   genParserVariadicSegmentResolution(op, body);
1105   genParserTypeResolution(op, body);
1106 
1107   body << "  return ::mlir::success();\n";
1108 }
1109 
genElementParser(FormatElement * element,MethodBody & body,FmtContext & attrTypeCtx,GenContext genCtx)1110 void OperationFormat::genElementParser(FormatElement *element, MethodBody &body,
1111                                        FmtContext &attrTypeCtx,
1112                                        GenContext genCtx) {
1113   /// Optional Group.
1114   if (auto *optional = dyn_cast<OptionalElement>(element)) {
1115     ArrayRef<FormatElement *> elements =
1116         optional->getThenElements().drop_front(optional->getParseStart());
1117 
1118     // Generate a special optional parser for the first element to gate the
1119     // parsing of the rest of the elements.
1120     FormatElement *firstElement = elements.front();
1121     if (auto *attrVar = dyn_cast<AttributeVariable>(firstElement)) {
1122       genElementParser(attrVar, body, attrTypeCtx);
1123       body << "  if (" << attrVar->getVar()->name << "Attr) {\n";
1124     } else if (auto *literal = dyn_cast<LiteralElement>(firstElement)) {
1125       body << "  if (succeeded(parser.parseOptional";
1126       genLiteralParser(literal->getSpelling(), body);
1127       body << ")) {\n";
1128     } else if (auto *opVar = dyn_cast<OperandVariable>(firstElement)) {
1129       genElementParser(opVar, body, attrTypeCtx);
1130       body << "  if (!" << opVar->getVar()->name << "Operands.empty()) {\n";
1131     } else if (auto *regionVar = dyn_cast<RegionVariable>(firstElement)) {
1132       const NamedRegion *region = regionVar->getVar();
1133       if (region->isVariadic()) {
1134         genElementParser(regionVar, body, attrTypeCtx);
1135         body << "  if (!" << region->name << "Regions.empty()) {\n";
1136       } else {
1137         body << llvm::formatv(optionalRegionParserCode, region->name);
1138         body << "  if (!" << region->name << "Region->empty()) {\n  ";
1139         if (hasImplicitTermTrait)
1140           body << llvm::formatv(regionEnsureTerminatorParserCode, region->name);
1141         else if (hasSingleBlockTrait)
1142           body << llvm::formatv(regionEnsureSingleBlockParserCode,
1143                                 region->name);
1144       }
1145     }
1146 
1147     // If the anchor is a unit attribute, we don't need to print it. When
1148     // parsing, we will add this attribute if this group is present.
1149     FormatElement *elidedAnchorElement = nullptr;
1150     auto *anchorAttr = dyn_cast<AttributeVariable>(optional->getAnchor());
1151     if (anchorAttr && anchorAttr != firstElement && anchorAttr->isUnitAttr()) {
1152       elidedAnchorElement = anchorAttr;
1153 
1154       // Add the anchor unit attribute to the operation state.
1155       body << "    result.addAttribute(\"" << anchorAttr->getVar()->name
1156            << "\", parser.getBuilder().getUnitAttr());\n";
1157     }
1158 
1159     // Generate the rest of the elements inside an optional group. Elements in
1160     // an optional group after the guard are parsed as required.
1161     for (FormatElement *childElement : llvm::drop_begin(elements, 1))
1162       if (childElement != elidedAnchorElement)
1163         genElementParser(childElement, body, attrTypeCtx, GenContext::Optional);
1164     body << "  }";
1165 
1166     // Generate the else elements.
1167     auto elseElements = optional->getElseElements();
1168     if (!elseElements.empty()) {
1169       body << " else {\n";
1170       for (FormatElement *childElement : elseElements)
1171         genElementParser(childElement, body, attrTypeCtx);
1172       body << "  }";
1173     }
1174     body << "\n";
1175 
1176     /// OIList Directive
1177   } else if (OIListElement *oilist = dyn_cast<OIListElement>(element)) {
1178     for (LiteralElement *le : oilist->getLiteralElements())
1179       body << "  bool " << le->getSpelling() << "Clause = false;\n";
1180 
1181     // Generate the parsing loop
1182     body << "  while(true) {\n";
1183     for (auto clause : oilist->getClauses()) {
1184       LiteralElement *lelement = std::get<0>(clause);
1185       ArrayRef<FormatElement *> pelement = std::get<1>(clause);
1186       body << "if (succeeded(parser.parseOptional";
1187       genLiteralParser(lelement->getSpelling(), body);
1188       body << ")) {\n";
1189       StringRef lelementName = lelement->getSpelling();
1190       body << formatv(oilistParserCode, lelementName);
1191       if (AttributeVariable *unitAttrElem =
1192               oilist->getUnitAttrParsingElement(pelement)) {
1193         body << "  result.addAttribute(\"" << unitAttrElem->getVar()->name
1194              << "\", UnitAttr::get(parser.getContext()));\n";
1195       } else {
1196         for (FormatElement *el : pelement)
1197           genElementParser(el, body, attrTypeCtx);
1198       }
1199       body << "    } else ";
1200     }
1201     body << " {\n";
1202     body << "    break;\n";
1203     body << "  }\n";
1204     body << "}\n";
1205 
1206     /// Literals.
1207   } else if (LiteralElement *literal = dyn_cast<LiteralElement>(element)) {
1208     body << "  if (parser.parse";
1209     genLiteralParser(literal->getSpelling(), body);
1210     body << ")\n    return ::mlir::failure();\n";
1211 
1212     /// Whitespaces.
1213   } else if (isa<WhitespaceElement>(element)) {
1214     // Nothing to parse.
1215 
1216     /// Arguments.
1217   } else if (auto *attr = dyn_cast<AttributeVariable>(element)) {
1218     const NamedAttribute *var = attr->getVar();
1219 
1220     // Check to see if we can parse this as an enum attribute.
1221     if (canFormatEnumAttr(var))
1222       return genEnumAttrParser(var, body, attrTypeCtx);
1223 
1224     // Check to see if we should parse this as a symbol name attribute.
1225     if (shouldFormatSymbolNameAttr(var)) {
1226       body << formatv(var->attr.isOptional() ? optionalSymbolNameAttrParserCode
1227                                              : symbolNameAttrParserCode,
1228                       var->name);
1229       return;
1230     }
1231 
1232     // If this attribute has a buildable type, use that when parsing the
1233     // attribute.
1234     std::string attrTypeStr;
1235     if (Optional<StringRef> typeBuilder = attr->getTypeBuilder()) {
1236       llvm::raw_string_ostream os(attrTypeStr);
1237       os << tgfmt(*typeBuilder, &attrTypeCtx);
1238     } else {
1239       attrTypeStr = "::mlir::Type{}";
1240     }
1241     if (genCtx == GenContext::Normal && var->attr.isOptional()) {
1242       body << formatv(optionalAttrParserCode, var->name, attrTypeStr);
1243     } else {
1244       if (attr->shouldBeQualified() ||
1245           var->attr.getStorageType() == "::mlir::Attribute")
1246         body << formatv(genericAttrParserCode, var->name, attrTypeStr);
1247       else
1248         body << formatv(attrParserCode, var->name, attrTypeStr);
1249     }
1250 
1251   } else if (auto *operand = dyn_cast<OperandVariable>(element)) {
1252     ArgumentLengthKind lengthKind = getArgumentLengthKind(operand->getVar());
1253     StringRef name = operand->getVar()->name;
1254     if (lengthKind == ArgumentLengthKind::VariadicOfVariadic)
1255       body << llvm::formatv(
1256           variadicOfVariadicOperandParserCode, name,
1257           operand->getVar()->constraint.getVariadicOfVariadicSegmentSizeAttr());
1258     else if (lengthKind == ArgumentLengthKind::Variadic)
1259       body << llvm::formatv(variadicOperandParserCode, name);
1260     else if (lengthKind == ArgumentLengthKind::Optional)
1261       body << llvm::formatv(optionalOperandParserCode, name);
1262     else
1263       body << formatv(operandParserCode, name);
1264 
1265   } else if (auto *region = dyn_cast<RegionVariable>(element)) {
1266     bool isVariadic = region->getVar()->isVariadic();
1267     body << llvm::formatv(isVariadic ? regionListParserCode : regionParserCode,
1268                           region->getVar()->name);
1269     if (hasImplicitTermTrait)
1270       body << llvm::formatv(isVariadic ? regionListEnsureTerminatorParserCode
1271                                        : regionEnsureTerminatorParserCode,
1272                             region->getVar()->name);
1273     else if (hasSingleBlockTrait)
1274       body << llvm::formatv(isVariadic ? regionListEnsureSingleBlockParserCode
1275                                        : regionEnsureSingleBlockParserCode,
1276                             region->getVar()->name);
1277 
1278   } else if (auto *successor = dyn_cast<SuccessorVariable>(element)) {
1279     bool isVariadic = successor->getVar()->isVariadic();
1280     body << formatv(isVariadic ? successorListParserCode : successorParserCode,
1281                     successor->getVar()->name);
1282 
1283     /// Directives.
1284   } else if (auto *attrDict = dyn_cast<AttrDictDirective>(element)) {
1285     body << "  if (parser.parseOptionalAttrDict"
1286          << (attrDict->isWithKeyword() ? "WithKeyword" : "")
1287          << "(result.attributes))\n"
1288          << "    return ::mlir::failure();\n";
1289   } else if (auto *customDir = dyn_cast<CustomDirective>(element)) {
1290     genCustomDirectiveParser(customDir, body);
1291 
1292   } else if (isa<OperandsDirective>(element)) {
1293     body << "  ::llvm::SMLoc allOperandLoc = parser.getCurrentLocation();\n"
1294          << "  if (parser.parseOperandList(allOperands))\n"
1295          << "    return ::mlir::failure();\n";
1296 
1297   } else if (isa<RegionsDirective>(element)) {
1298     body << llvm::formatv(regionListParserCode, "full");
1299     if (hasImplicitTermTrait)
1300       body << llvm::formatv(regionListEnsureTerminatorParserCode, "full");
1301     else if (hasSingleBlockTrait)
1302       body << llvm::formatv(regionListEnsureSingleBlockParserCode, "full");
1303 
1304   } else if (isa<SuccessorsDirective>(element)) {
1305     body << llvm::formatv(successorListParserCode, "full");
1306 
1307   } else if (auto *dir = dyn_cast<TypeDirective>(element)) {
1308     ArgumentLengthKind lengthKind;
1309     StringRef listName = getTypeListName(dir->getArg(), lengthKind);
1310     if (lengthKind == ArgumentLengthKind::VariadicOfVariadic) {
1311       body << llvm::formatv(variadicOfVariadicTypeParserCode, listName);
1312     } else if (lengthKind == ArgumentLengthKind::Variadic) {
1313       body << llvm::formatv(variadicTypeParserCode, listName);
1314     } else if (lengthKind == ArgumentLengthKind::Optional) {
1315       body << llvm::formatv(optionalTypeParserCode, listName);
1316     } else {
1317       const char *parserCode =
1318           dir->shouldBeQualified() ? qualifiedTypeParserCode : typeParserCode;
1319       TypeSwitch<FormatElement *>(dir->getArg())
1320           .Case<OperandVariable, ResultVariable>([&](auto operand) {
1321             body << formatv(parserCode,
1322                             operand->getVar()->constraint.getCPPClassName(),
1323                             listName);
1324           })
1325           .Default([&](auto operand) {
1326             body << formatv(parserCode, "::mlir::Type", listName);
1327           });
1328     }
1329   } else if (auto *dir = dyn_cast<FunctionalTypeDirective>(element)) {
1330     ArgumentLengthKind ignored;
1331     body << formatv(functionalTypeParserCode,
1332                     getTypeListName(dir->getInputs(), ignored),
1333                     getTypeListName(dir->getResults(), ignored));
1334   } else {
1335     llvm_unreachable("unknown format element");
1336   }
1337 }
1338 
genParserTypeResolution(Operator & op,MethodBody & body)1339 void OperationFormat::genParserTypeResolution(Operator &op, MethodBody &body) {
1340   // If any of type resolutions use transformed variables, make sure that the
1341   // types of those variables are resolved.
1342   SmallPtrSet<const NamedTypeConstraint *, 8> verifiedVariables;
1343   FmtContext verifierFCtx;
1344   for (TypeResolution &resolver :
1345        llvm::concat<TypeResolution>(resultTypes, operandTypes)) {
1346     Optional<StringRef> transformer = resolver.getVarTransformer();
1347     if (!transformer)
1348       continue;
1349     // Ensure that we don't verify the same variables twice.
1350     const NamedTypeConstraint *variable = resolver.getVariable();
1351     if (!variable || !verifiedVariables.insert(variable).second)
1352       continue;
1353 
1354     auto constraint = variable->constraint;
1355     body << "  for (::mlir::Type type : " << variable->name << "Types) {\n"
1356          << "    (void)type;\n"
1357          << "    if (!("
1358          << tgfmt(constraint.getConditionTemplate(),
1359                   &verifierFCtx.withSelf("type"))
1360          << ")) {\n"
1361          << formatv("      return parser.emitError(parser.getNameLoc()) << "
1362                     "\"'{0}' must be {1}, but got \" << type;\n",
1363                     variable->name, constraint.getSummary())
1364          << "    }\n"
1365          << "  }\n";
1366   }
1367 
1368   // Initialize the set of buildable types.
1369   if (!buildableTypes.empty()) {
1370     FmtContext typeBuilderCtx;
1371     typeBuilderCtx.withBuilder("parser.getBuilder()");
1372     for (auto &it : buildableTypes)
1373       body << "  ::mlir::Type odsBuildableType" << it.second << " = "
1374            << tgfmt(it.first, &typeBuilderCtx) << ";\n";
1375   }
1376 
1377   // Emit the code necessary for a type resolver.
1378   auto emitTypeResolver = [&](TypeResolution &resolver, StringRef curVar) {
1379     if (Optional<int> val = resolver.getBuilderIdx()) {
1380       body << "odsBuildableType" << *val;
1381     } else if (const NamedTypeConstraint *var = resolver.getVariable()) {
1382       if (Optional<StringRef> tform = resolver.getVarTransformer()) {
1383         FmtContext fmtContext;
1384         fmtContext.addSubst("_ctxt", "parser.getContext()");
1385         if (var->isVariadic())
1386           fmtContext.withSelf(var->name + "Types");
1387         else
1388           fmtContext.withSelf(var->name + "Types[0]");
1389         body << tgfmt(*tform, &fmtContext);
1390       } else {
1391         body << var->name << "Types";
1392       }
1393     } else if (const NamedAttribute *attr = resolver.getAttribute()) {
1394       if (Optional<StringRef> tform = resolver.getVarTransformer())
1395         body << tgfmt(*tform,
1396                       &FmtContext().withSelf(attr->name + "Attr.getType()"));
1397       else
1398         body << attr->name << "Attr.getType()";
1399     } else {
1400       body << curVar << "Types";
1401     }
1402   };
1403 
1404   // Resolve each of the result types.
1405   if (!infersResultTypes) {
1406     if (allResultTypes) {
1407       body << "  result.addTypes(allResultTypes);\n";
1408     } else {
1409       for (unsigned i = 0, e = op.getNumResults(); i != e; ++i) {
1410         body << "  result.addTypes(";
1411         emitTypeResolver(resultTypes[i], op.getResultName(i));
1412         body << ");\n";
1413       }
1414     }
1415   }
1416 
1417   // Emit the operand type resolutions.
1418   genParserOperandTypeResolution(op, body, emitTypeResolver);
1419 
1420   // Handle return type inference once all operands have been resolved
1421   if (infersResultTypes)
1422     body << formatv(inferReturnTypesParserCode, op.getCppClassName());
1423 }
1424 
genParserOperandTypeResolution(Operator & op,MethodBody & body,function_ref<void (TypeResolution &,StringRef)> emitTypeResolver)1425 void OperationFormat::genParserOperandTypeResolution(
1426     Operator &op, MethodBody &body,
1427     function_ref<void(TypeResolution &, StringRef)> emitTypeResolver) {
1428   // Early exit if there are no operands.
1429   if (op.getNumOperands() == 0)
1430     return;
1431 
1432   // Handle the case where all operand types are grouped together with
1433   // "types(operands)".
1434   if (allOperandTypes) {
1435     // If `operands` was specified, use the full operand list directly.
1436     if (allOperands) {
1437       body << "  if (parser.resolveOperands(allOperands, allOperandTypes, "
1438               "allOperandLoc, result.operands))\n"
1439               "    return ::mlir::failure();\n";
1440       return;
1441     }
1442 
1443     // Otherwise, use llvm::concat to merge the disjoint operand lists together.
1444     // llvm::concat does not allow the case of a single range, so guard it here.
1445     body << "  if (parser.resolveOperands(";
1446     if (op.getNumOperands() > 1) {
1447       body << "::llvm::concat<const ::mlir::OpAsmParser::UnresolvedOperand>(";
1448       llvm::interleaveComma(op.getOperands(), body, [&](auto &operand) {
1449         body << operand.name << "Operands";
1450       });
1451       body << ")";
1452     } else {
1453       body << op.operand_begin()->name << "Operands";
1454     }
1455     body << ", allOperandTypes, parser.getNameLoc(), result.operands))\n"
1456          << "    return ::mlir::failure();\n";
1457     return;
1458   }
1459 
1460   // Handle the case where all operands are grouped together with "operands".
1461   if (allOperands) {
1462     body << "  if (parser.resolveOperands(allOperands, ";
1463 
1464     // Group all of the operand types together to perform the resolution all at
1465     // once. Use llvm::concat to perform the merge. llvm::concat does not allow
1466     // the case of a single range, so guard it here.
1467     if (op.getNumOperands() > 1) {
1468       body << "::llvm::concat<const ::mlir::Type>(";
1469       llvm::interleaveComma(
1470           llvm::seq<int>(0, op.getNumOperands()), body, [&](int i) {
1471             body << "::llvm::ArrayRef<::mlir::Type>(";
1472             emitTypeResolver(operandTypes[i], op.getOperand(i).name);
1473             body << ")";
1474           });
1475       body << ")";
1476     } else {
1477       emitTypeResolver(operandTypes.front(), op.getOperand(0).name);
1478     }
1479 
1480     body << ", allOperandLoc, result.operands))\n"
1481          << "    return ::mlir::failure();\n";
1482     return;
1483   }
1484 
1485   // The final case is the one where each of the operands types are resolved
1486   // separately.
1487   for (unsigned i = 0, e = op.getNumOperands(); i != e; ++i) {
1488     NamedTypeConstraint &operand = op.getOperand(i);
1489     body << "  if (parser.resolveOperands(" << operand.name << "Operands, ";
1490 
1491     // Resolve the type of this operand.
1492     TypeResolution &operandType = operandTypes[i];
1493     emitTypeResolver(operandType, operand.name);
1494 
1495     // If the type is resolved by a non-variadic variable, index into the
1496     // resolved type list. This allows for resolving the types of a variadic
1497     // operand list from a non-variadic variable.
1498     bool verifyOperandAndTypeSize = true;
1499     if (auto *resolverVar = operandType.getVariable()) {
1500       if (!resolverVar->isVariadic() && !operandType.getVarTransformer()) {
1501         body << "[0]";
1502         verifyOperandAndTypeSize = false;
1503       }
1504     } else {
1505       verifyOperandAndTypeSize = !operandType.getBuilderIdx();
1506     }
1507 
1508     // Check to see if the sizes between the types and operands must match. If
1509     // they do, provide the operand location to select the proper resolution
1510     // overload.
1511     if (verifyOperandAndTypeSize)
1512       body << ", " << operand.name << "OperandsLoc";
1513     body << ", result.operands))\n    return ::mlir::failure();\n";
1514   }
1515 }
1516 
genParserRegionResolution(Operator & op,MethodBody & body)1517 void OperationFormat::genParserRegionResolution(Operator &op,
1518                                                 MethodBody &body) {
1519   // Check for the case where all regions were parsed.
1520   bool hasAllRegions = llvm::any_of(
1521       elements, [](FormatElement *elt) { return isa<RegionsDirective>(elt); });
1522   if (hasAllRegions) {
1523     body << "  result.addRegions(fullRegions);\n";
1524     return;
1525   }
1526 
1527   // Otherwise, handle each region individually.
1528   for (const NamedRegion &region : op.getRegions()) {
1529     if (region.isVariadic())
1530       body << "  result.addRegions(" << region.name << "Regions);\n";
1531     else
1532       body << "  result.addRegion(std::move(" << region.name << "Region));\n";
1533   }
1534 }
1535 
genParserSuccessorResolution(Operator & op,MethodBody & body)1536 void OperationFormat::genParserSuccessorResolution(Operator &op,
1537                                                    MethodBody &body) {
1538   // Check for the case where all successors were parsed.
1539   bool hasAllSuccessors = llvm::any_of(elements, [](FormatElement *elt) {
1540     return isa<SuccessorsDirective>(elt);
1541   });
1542   if (hasAllSuccessors) {
1543     body << "  result.addSuccessors(fullSuccessors);\n";
1544     return;
1545   }
1546 
1547   // Otherwise, handle each successor individually.
1548   for (const NamedSuccessor &successor : op.getSuccessors()) {
1549     if (successor.isVariadic())
1550       body << "  result.addSuccessors(" << successor.name << "Successors);\n";
1551     else
1552       body << "  result.addSuccessors(" << successor.name << "Successor);\n";
1553   }
1554 }
1555 
genParserVariadicSegmentResolution(Operator & op,MethodBody & body)1556 void OperationFormat::genParserVariadicSegmentResolution(Operator &op,
1557                                                          MethodBody &body) {
1558   if (!allOperands) {
1559     if (op.getTrait("::mlir::OpTrait::AttrSizedOperandSegments")) {
1560       body << "  result.addAttribute(\"operand_segment_sizes\", "
1561            << "parser.getBuilder().getI32VectorAttr({";
1562       auto interleaveFn = [&](const NamedTypeConstraint &operand) {
1563         // If the operand is variadic emit the parsed size.
1564         if (operand.isVariableLength())
1565           body << "static_cast<int32_t>(" << operand.name << "Operands.size())";
1566         else
1567           body << "1";
1568       };
1569       llvm::interleaveComma(op.getOperands(), body, interleaveFn);
1570       body << "}));\n";
1571     }
1572     for (const NamedTypeConstraint &operand : op.getOperands()) {
1573       if (!operand.isVariadicOfVariadic())
1574         continue;
1575       body << llvm::formatv(
1576           "  result.addAttribute(\"{0}\", "
1577           "parser.getBuilder().getI32TensorAttr({1}OperandGroupSizes));\n",
1578           operand.constraint.getVariadicOfVariadicSegmentSizeAttr(),
1579           operand.name);
1580     }
1581   }
1582 
1583   if (!allResultTypes &&
1584       op.getTrait("::mlir::OpTrait::AttrSizedResultSegments")) {
1585     body << "  result.addAttribute(\"result_segment_sizes\", "
1586          << "parser.getBuilder().getI32VectorAttr({";
1587     auto interleaveFn = [&](const NamedTypeConstraint &result) {
1588       // If the result is variadic emit the parsed size.
1589       if (result.isVariableLength())
1590         body << "static_cast<int32_t>(" << result.name << "Types.size())";
1591       else
1592         body << "1";
1593     };
1594     llvm::interleaveComma(op.getResults(), body, interleaveFn);
1595     body << "}));\n";
1596   }
1597 }
1598 
1599 //===----------------------------------------------------------------------===//
1600 // PrinterGen
1601 
/// The code snippet used to generate a printer call for a region of an
/// operation that has the SingleBlockImplicitTerminator trait.
///
/// The terminator of the region's first block is only printed when it carries
/// attributes, operands, or results; otherwise it is left implicit. This is a
/// `llvm::formatv` format string, so literal braces are escaped as `{{`.
///
/// {0}: The name of the region.
static const char *const regionSingleBlockImplicitTerminatorPrinterCode = R"(
  {
    bool printTerminator = true;
    if (auto *term = {0}.empty() ? nullptr : {0}.begin()->getTerminator()) {{
      printTerminator = !term->getAttrDictionary().empty() ||
                        term->getNumOperands() != 0 ||
                        term->getNumResults() != 0;
    }
    _odsPrinter.printRegion({0}, /*printEntryBlockArgs=*/true,
      /*printBlockTerminators=*/printTerminator);
  }
)";
1618 
/// The code snippet used to generate a printer call for an enum that has cases
/// that can't be represented with a keyword.
///
/// It opens a scope that the caller is responsible for closing. This is a
/// `llvm::formatv` format string.
///
/// {0}: The name of the enum attribute.
/// {1}: The name of the enum attribute's symbolToString function.
static const char *const enumAttrBeginPrinterCode = R"(
  {
    auto caseValue = {0}();
    auto caseValueStr = {1}(caseValue);
)";
1629 
1630 /// Generate the printer for the 'attr-dict' directive.
genAttrDictPrinter(OperationFormat & fmt,Operator & op,MethodBody & body,bool withKeyword)1631 static void genAttrDictPrinter(OperationFormat &fmt, Operator &op,
1632                                MethodBody &body, bool withKeyword) {
1633   body << "  _odsPrinter.printOptionalAttrDict"
1634        << (withKeyword ? "WithKeyword" : "")
1635        << "((*this)->getAttrs(), /*elidedAttrs=*/{";
1636   // Elide the variadic segment size attributes if necessary.
1637   if (!fmt.allOperands &&
1638       op.getTrait("::mlir::OpTrait::AttrSizedOperandSegments"))
1639     body << "\"operand_segment_sizes\", ";
1640   if (!fmt.allResultTypes &&
1641       op.getTrait("::mlir::OpTrait::AttrSizedResultSegments"))
1642     body << "\"result_segment_sizes\", ";
1643   if (!fmt.inferredAttributes.empty()) {
1644     for (const auto &attr : fmt.inferredAttributes)
1645       body << "\"" << attr.getKey() << "\", ";
1646   }
1647   llvm::interleaveComma(
1648       fmt.usedAttributes, body,
1649       [&](const NamedAttribute *attr) { body << "\"" << attr->name << "\""; });
1650   body << "});\n";
1651 }
1652 
1653 /// Generate the printer for a literal value. `shouldEmitSpace` is true if a
1654 /// space should be emitted before this element. `lastWasPunctuation` is true if
1655 /// the previous element was a punctuation literal.
genLiteralPrinter(StringRef value,MethodBody & body,bool & shouldEmitSpace,bool & lastWasPunctuation)1656 static void genLiteralPrinter(StringRef value, MethodBody &body,
1657                               bool &shouldEmitSpace, bool &lastWasPunctuation) {
1658   body << "  _odsPrinter";
1659 
1660   // Don't insert a space for certain punctuation.
1661   if (shouldEmitSpace && shouldEmitSpaceBefore(value, lastWasPunctuation))
1662     body << " << ' '";
1663   body << " << \"" << value << "\";\n";
1664 
1665   // Insert a space after certain literals.
1666   shouldEmitSpace =
1667       value.size() != 1 || !StringRef("<({[").contains(value.front());
1668   lastWasPunctuation = !(value.front() == '_' || isalpha(value.front()));
1669 }
1670 
1671 /// Generate the printer for a space. `shouldEmitSpace` and `lastWasPunctuation`
1672 /// are set to false.
genSpacePrinter(bool value,MethodBody & body,bool & shouldEmitSpace,bool & lastWasPunctuation)1673 static void genSpacePrinter(bool value, MethodBody &body, bool &shouldEmitSpace,
1674                             bool &lastWasPunctuation) {
1675   if (value) {
1676     body << "  _odsPrinter << ' ';\n";
1677     lastWasPunctuation = false;
1678   } else {
1679     lastWasPunctuation = true;
1680   }
1681   shouldEmitSpace = false;
1682 }
1683 
1684 /// Generate the printer for a custom directive parameter.
genCustomDirectiveParameterPrinter(FormatElement * element,const Operator & op,MethodBody & body)1685 static void genCustomDirectiveParameterPrinter(FormatElement *element,
1686                                                const Operator &op,
1687                                                MethodBody &body) {
1688   if (auto *attr = dyn_cast<AttributeVariable>(element)) {
1689     body << op.getGetterName(attr->getVar()->name) << "Attr()";
1690 
1691   } else if (isa<AttrDictDirective>(element)) {
1692     body << "getOperation()->getAttrDictionary()";
1693 
1694   } else if (auto *operand = dyn_cast<OperandVariable>(element)) {
1695     body << op.getGetterName(operand->getVar()->name) << "()";
1696 
1697   } else if (auto *region = dyn_cast<RegionVariable>(element)) {
1698     body << op.getGetterName(region->getVar()->name) << "()";
1699 
1700   } else if (auto *successor = dyn_cast<SuccessorVariable>(element)) {
1701     body << op.getGetterName(successor->getVar()->name) << "()";
1702 
1703   } else if (auto *dir = dyn_cast<RefDirective>(element)) {
1704     genCustomDirectiveParameterPrinter(dir->getArg(), op, body);
1705 
1706   } else if (auto *dir = dyn_cast<TypeDirective>(element)) {
1707     auto *typeOperand = dir->getArg();
1708     auto *operand = dyn_cast<OperandVariable>(typeOperand);
1709     auto *var = operand ? operand->getVar()
1710                         : cast<ResultVariable>(typeOperand)->getVar();
1711     std::string name = op.getGetterName(var->name);
1712     if (var->isVariadic())
1713       body << name << "().getTypes()";
1714     else if (var->isOptional())
1715       body << llvm::formatv("({0}() ? {0}().getType() : Type())", name);
1716     else
1717       body << name << "().getType()";
1718   } else {
1719     llvm_unreachable("unknown custom directive parameter");
1720   }
1721 }
1722 
1723 /// Generate the printer for a custom directive.
genCustomDirectivePrinter(CustomDirective * customDir,const Operator & op,MethodBody & body)1724 static void genCustomDirectivePrinter(CustomDirective *customDir,
1725                                       const Operator &op, MethodBody &body) {
1726   body << "  print" << customDir->getName() << "(_odsPrinter, *this";
1727   for (FormatElement *param : customDir->getArguments()) {
1728     body << ", ";
1729     genCustomDirectiveParameterPrinter(param, op, body);
1730   }
1731   body << ");\n";
1732 }
1733 
1734 /// Generate the printer for a region with the given variable name.
genRegionPrinter(const Twine & regionName,MethodBody & body,bool hasImplicitTermTrait)1735 static void genRegionPrinter(const Twine &regionName, MethodBody &body,
1736                              bool hasImplicitTermTrait) {
1737   if (hasImplicitTermTrait)
1738     body << llvm::formatv(regionSingleBlockImplicitTerminatorPrinterCode,
1739                           regionName);
1740   else
1741     body << "  _odsPrinter.printRegion(" << regionName << ");\n";
1742 }
genVariadicRegionPrinter(const Twine & regionListName,MethodBody & body,bool hasImplicitTermTrait)1743 static void genVariadicRegionPrinter(const Twine &regionListName,
1744                                      MethodBody &body,
1745                                      bool hasImplicitTermTrait) {
1746   body << "    llvm::interleaveComma(" << regionListName
1747        << ", _odsPrinter, [&](::mlir::Region &region) {\n      ";
1748   genRegionPrinter("region", body, hasImplicitTermTrait);
1749   body << "    });\n";
1750 }
1751 
1752 /// Generate the C++ for an operand to a (*-)type directive.
genTypeOperandPrinter(FormatElement * arg,const Operator & op,MethodBody & body,bool useArrayRef=true)1753 static MethodBody &genTypeOperandPrinter(FormatElement *arg, const Operator &op,
1754                                          MethodBody &body,
1755                                          bool useArrayRef = true) {
1756   if (isa<OperandsDirective>(arg))
1757     return body << "getOperation()->getOperandTypes()";
1758   if (isa<ResultsDirective>(arg))
1759     return body << "getOperation()->getResultTypes()";
1760   auto *operand = dyn_cast<OperandVariable>(arg);
1761   auto *var = operand ? operand->getVar() : cast<ResultVariable>(arg)->getVar();
1762   if (var->isVariadicOfVariadic())
1763     return body << llvm::formatv("{0}().join().getTypes()",
1764                                  op.getGetterName(var->name));
1765   if (var->isVariadic())
1766     return body << op.getGetterName(var->name) << "().getTypes()";
1767   if (var->isOptional())
1768     return body << llvm::formatv(
1769                "({0}() ? ::llvm::ArrayRef<::mlir::Type>({0}().getType()) : "
1770                "::llvm::ArrayRef<::mlir::Type>())",
1771                op.getGetterName(var->name));
1772   if (useArrayRef)
1773     return body << "::llvm::ArrayRef<::mlir::Type>("
1774                 << op.getGetterName(var->name) << "().getType())";
1775   return body << op.getGetterName(var->name) << "().getType()";
1776 }
1777 
1778 /// Generate the printer for an enum attribute.
genEnumAttrPrinter(const NamedAttribute * var,const Operator & op,MethodBody & body)1779 static void genEnumAttrPrinter(const NamedAttribute *var, const Operator &op,
1780                                MethodBody &body) {
1781   Attribute baseAttr = var->attr.getBaseAttr();
1782   const EnumAttr &enumAttr = cast<EnumAttr>(baseAttr);
1783   std::vector<EnumAttrCase> cases = enumAttr.getAllCases();
1784 
1785   body << llvm::formatv(enumAttrBeginPrinterCode,
1786                         (var->attr.isOptional() ? "*" : "") +
1787                             op.getGetterName(var->name),
1788                         enumAttr.getSymbolToStringFnName());
1789 
1790   // Get a string containing all of the cases that can't be represented with a
1791   // keyword.
1792   BitVector nonKeywordCases(cases.size());
1793   for (auto &it : llvm::enumerate(cases)) {
1794     if (!canFormatStringAsKeyword(it.value().getStr()))
1795       nonKeywordCases.set(it.index());
1796   }
1797 
1798   // Otherwise if this is a bit enum attribute, don't allow cases that may
1799   // overlap with other cases. For simplicity sake, only allow cases with a
1800   // single bit value.
1801   if (enumAttr.isBitEnum()) {
1802     for (auto &it : llvm::enumerate(cases)) {
1803       int64_t value = it.value().getValue();
1804       if (value < 0 || !llvm::isPowerOf2_64(value))
1805         nonKeywordCases.set(it.index());
1806     }
1807   }
1808 
1809   // If there are any cases that can't be used with a keyword, switch on the
1810   // case value to determine when to print in the string form.
1811   if (nonKeywordCases.any()) {
1812     body << "    switch (caseValue) {\n";
1813     StringRef cppNamespace = enumAttr.getCppNamespace();
1814     StringRef enumName = enumAttr.getEnumClassName();
1815     for (auto &it : llvm::enumerate(cases)) {
1816       if (nonKeywordCases.test(it.index()))
1817         continue;
1818       StringRef symbol = it.value().getSymbol();
1819       body << llvm::formatv("    case {0}::{1}::{2}:\n", cppNamespace, enumName,
1820                             llvm::isDigit(symbol.front()) ? ("_" + symbol)
1821                                                           : symbol);
1822     }
1823     body << "      _odsPrinter << caseValueStr;\n"
1824             "      break;\n"
1825             "    default:\n"
1826             "      _odsPrinter << '\"' << caseValueStr << '\"';\n"
1827             "      break;\n"
1828             "    }\n"
1829             "  }\n";
1830     return;
1831   }
1832 
1833   body << "    _odsPrinter << caseValueStr;\n"
1834           "  }\n";
1835 }
1836 
1837 /// Generate the check for the anchor of an optional group.
genOptionalGroupPrinterAnchor(FormatElement * anchor,const Operator & op,MethodBody & body)1838 static void genOptionalGroupPrinterAnchor(FormatElement *anchor,
1839                                           const Operator &op,
1840                                           MethodBody &body) {
1841   TypeSwitch<FormatElement *>(anchor)
1842       .Case<OperandVariable, ResultVariable>([&](auto *element) {
1843         const NamedTypeConstraint *var = element->getVar();
1844         std::string name = op.getGetterName(var->name);
1845         if (var->isOptional())
1846           body << "  if (" << name << "()) {\n";
1847         else if (var->isVariadic())
1848           body << "  if (!" << name << "().empty()) {\n";
1849       })
1850       .Case<RegionVariable>([&](RegionVariable *element) {
1851         const NamedRegion *var = element->getVar();
1852         std::string name = op.getGetterName(var->name);
1853         // TODO: Add a check for optional regions here when ODS supports it.
1854         body << "  if (!" << name << "().empty()) {\n";
1855       })
1856       .Case<TypeDirective>([&](TypeDirective *element) {
1857         genOptionalGroupPrinterAnchor(element->getArg(), op, body);
1858       })
1859       .Case<FunctionalTypeDirective>([&](FunctionalTypeDirective *element) {
1860         genOptionalGroupPrinterAnchor(element->getInputs(), op, body);
1861       })
1862       .Case<AttributeVariable>([&](AttributeVariable *attr) {
1863         body << "  if ((*this)->getAttr(\"" << attr->getVar()->name
1864              << "\")) {\n";
1865       });
1866 }
1867 
collect(FormatElement * element,SmallVectorImpl<VariableElement * > & variables)1868 void collect(FormatElement *element,
1869              SmallVectorImpl<VariableElement *> &variables) {
1870   TypeSwitch<FormatElement *>(element)
1871       .Case([&](VariableElement *var) { variables.emplace_back(var); })
1872       .Case([&](CustomDirective *ele) {
1873         for (FormatElement *arg : ele->getArguments())
1874           collect(arg, variables);
1875       })
1876       .Case([&](OptionalElement *ele) {
1877         for (FormatElement *arg : ele->getThenElements())
1878           collect(arg, variables);
1879         for (FormatElement *arg : ele->getElseElements())
1880           collect(arg, variables);
1881       })
1882       .Case([&](FunctionalTypeDirective *funcType) {
1883         collect(funcType->getInputs(), variables);
1884         collect(funcType->getResults(), variables);
1885       })
1886       .Case([&](OIListElement *oilist) {
1887         for (ArrayRef<FormatElement *> arg : oilist->getParsingElements())
1888           for (FormatElement *arg : arg)
1889             collect(arg, variables);
1890       });
1891 }
1892 
genElementPrinter(FormatElement * element,MethodBody & body,Operator & op,bool & shouldEmitSpace,bool & lastWasPunctuation)1893 void OperationFormat::genElementPrinter(FormatElement *element,
1894                                         MethodBody &body, Operator &op,
1895                                         bool &shouldEmitSpace,
1896                                         bool &lastWasPunctuation) {
1897   if (LiteralElement *literal = dyn_cast<LiteralElement>(element))
1898     return genLiteralPrinter(literal->getSpelling(), body, shouldEmitSpace,
1899                              lastWasPunctuation);
1900 
1901   // Emit a whitespace element.
1902   if (auto *space = dyn_cast<WhitespaceElement>(element)) {
1903     if (space->getValue() == "\\n") {
1904       body << "  _odsPrinter.printNewline();\n";
1905     } else {
1906       genSpacePrinter(!space->getValue().empty(), body, shouldEmitSpace,
1907                       lastWasPunctuation);
1908     }
1909     return;
1910   }
1911 
1912   // Emit an optional group.
1913   if (OptionalElement *optional = dyn_cast<OptionalElement>(element)) {
1914     // Emit the check for the presence of the anchor element.
1915     FormatElement *anchor = optional->getAnchor();
1916     genOptionalGroupPrinterAnchor(anchor, op, body);
1917 
1918     // If the anchor is a unit attribute, we don't need to print it. When
1919     // parsing, we will add this attribute if this group is present.
1920     auto elements = optional->getThenElements();
1921     FormatElement *elidedAnchorElement = nullptr;
1922     auto *anchorAttr = dyn_cast<AttributeVariable>(anchor);
1923     if (anchorAttr && anchorAttr != elements.front() &&
1924         anchorAttr->isUnitAttr()) {
1925       elidedAnchorElement = anchorAttr;
1926     }
1927 
1928     // Emit each of the elements.
1929     for (FormatElement *childElement : elements) {
1930       if (childElement != elidedAnchorElement) {
1931         genElementPrinter(childElement, body, op, shouldEmitSpace,
1932                           lastWasPunctuation);
1933       }
1934     }
1935     body << "  }";
1936 
1937     // Emit each of the else elements.
1938     auto elseElements = optional->getElseElements();
1939     if (!elseElements.empty()) {
1940       body << " else {\n";
1941       for (FormatElement *childElement : elseElements) {
1942         genElementPrinter(childElement, body, op, shouldEmitSpace,
1943                           lastWasPunctuation);
1944       }
1945       body << "  }";
1946     }
1947 
1948     body << "\n";
1949     return;
1950   }
1951 
1952   // Emit the OIList
1953   if (auto *oilist = dyn_cast<OIListElement>(element)) {
1954     genLiteralPrinter(" ", body, shouldEmitSpace, lastWasPunctuation);
1955     for (auto clause : oilist->getClauses()) {
1956       LiteralElement *lelement = std::get<0>(clause);
1957       ArrayRef<FormatElement *> pelement = std::get<1>(clause);
1958 
1959       SmallVector<VariableElement *> vars;
1960       for (FormatElement *el : pelement)
1961         collect(el, vars);
1962       body << "  if (false";
1963       for (VariableElement *var : vars) {
1964         TypeSwitch<FormatElement *>(var)
1965             .Case([&](AttributeVariable *attrEle) {
1966               body << " || " << op.getGetterName(attrEle->getVar()->name)
1967                    << "Attr()";
1968             })
1969             .Case([&](OperandVariable *ele) {
1970               if (ele->getVar()->isVariadic()) {
1971                 body << " || " << op.getGetterName(ele->getVar()->name)
1972                      << "().size()";
1973               } else {
1974                 body << " || " << op.getGetterName(ele->getVar()->name) << "()";
1975               }
1976             })
1977             .Case([&](ResultVariable *ele) {
1978               if (ele->getVar()->isVariadic()) {
1979                 body << " || " << op.getGetterName(ele->getVar()->name)
1980                      << "().size()";
1981               } else {
1982                 body << " || " << op.getGetterName(ele->getVar()->name) << "()";
1983               }
1984             })
1985             .Case([&](RegionVariable *reg) {
1986               body << " || " << op.getGetterName(reg->getVar()->name) << "()";
1987             });
1988       }
1989 
1990       body << ") {\n";
1991       genLiteralPrinter(lelement->getSpelling(), body, shouldEmitSpace,
1992                         lastWasPunctuation);
1993       if (oilist->getUnitAttrParsingElement(pelement) == nullptr) {
1994         for (FormatElement *element : pelement)
1995           genElementPrinter(element, body, op, shouldEmitSpace,
1996                             lastWasPunctuation);
1997       }
1998       body << "  }\n";
1999     }
2000     return;
2001   }
2002 
2003   // Emit the attribute dictionary.
2004   if (auto *attrDict = dyn_cast<AttrDictDirective>(element)) {
2005     genAttrDictPrinter(*this, op, body, attrDict->isWithKeyword());
2006     lastWasPunctuation = false;
2007     return;
2008   }
2009 
2010   // Optionally insert a space before the next element. The AttrDict printer
2011   // already adds a space as necessary.
2012   if (shouldEmitSpace || !lastWasPunctuation)
2013     body << "  _odsPrinter << ' ';\n";
2014   lastWasPunctuation = false;
2015   shouldEmitSpace = true;
2016 
2017   if (auto *attr = dyn_cast<AttributeVariable>(element)) {
2018     const NamedAttribute *var = attr->getVar();
2019 
2020     // If we are formatting as an enum, symbolize the attribute as a string.
2021     if (canFormatEnumAttr(var))
2022       return genEnumAttrPrinter(var, op, body);
2023 
2024     // If we are formatting as a symbol name, handle it as a symbol name.
2025     if (shouldFormatSymbolNameAttr(var)) {
2026       body << "  _odsPrinter.printSymbolName(" << op.getGetterName(var->name)
2027            << "Attr().getValue());\n";
2028       return;
2029     }
2030 
2031     // Elide the attribute type if it is buildable.
2032     if (attr->getTypeBuilder())
2033       body << "  _odsPrinter.printAttributeWithoutType("
2034            << op.getGetterName(var->name) << "Attr());\n";
2035     else if (attr->shouldBeQualified() ||
2036              var->attr.getStorageType() == "::mlir::Attribute")
2037       body << "  _odsPrinter.printAttribute(" << op.getGetterName(var->name)
2038            << "Attr());\n";
2039     else
2040       body << "_odsPrinter.printStrippedAttrOrType("
2041            << op.getGetterName(var->name) << "Attr());\n";
2042   } else if (auto *operand = dyn_cast<OperandVariable>(element)) {
2043     if (operand->getVar()->isVariadicOfVariadic()) {
2044       body << "  ::llvm::interleaveComma("
2045            << op.getGetterName(operand->getVar()->name)
2046            << "(), _odsPrinter, [&](const auto &operands) { _odsPrinter << "
2047               "\"(\" << operands << "
2048               "\")\"; });\n";
2049 
2050     } else if (operand->getVar()->isOptional()) {
2051       body << "  if (::mlir::Value value = "
2052            << op.getGetterName(operand->getVar()->name) << "())\n"
2053            << "    _odsPrinter << value;\n";
2054     } else {
2055       body << "  _odsPrinter << " << op.getGetterName(operand->getVar()->name)
2056            << "();\n";
2057     }
2058   } else if (auto *region = dyn_cast<RegionVariable>(element)) {
2059     const NamedRegion *var = region->getVar();
2060     std::string name = op.getGetterName(var->name);
2061     if (var->isVariadic()) {
2062       genVariadicRegionPrinter(name + "()", body, hasImplicitTermTrait);
2063     } else {
2064       genRegionPrinter(name + "()", body, hasImplicitTermTrait);
2065     }
2066   } else if (auto *successor = dyn_cast<SuccessorVariable>(element)) {
2067     const NamedSuccessor *var = successor->getVar();
2068     std::string name = op.getGetterName(var->name);
2069     if (var->isVariadic())
2070       body << "  ::llvm::interleaveComma(" << name << "(), _odsPrinter);\n";
2071     else
2072       body << "  _odsPrinter << " << name << "();\n";
2073   } else if (auto *dir = dyn_cast<CustomDirective>(element)) {
2074     genCustomDirectivePrinter(dir, op, body);
2075   } else if (isa<OperandsDirective>(element)) {
2076     body << "  _odsPrinter << getOperation()->getOperands();\n";
2077   } else if (isa<RegionsDirective>(element)) {
2078     genVariadicRegionPrinter("getOperation()->getRegions()", body,
2079                              hasImplicitTermTrait);
2080   } else if (isa<SuccessorsDirective>(element)) {
2081     body << "  ::llvm::interleaveComma(getOperation()->getSuccessors(), "
2082             "_odsPrinter);\n";
2083   } else if (auto *dir = dyn_cast<TypeDirective>(element)) {
2084     if (auto *operand = dyn_cast<OperandVariable>(dir->getArg())) {
2085       if (operand->getVar()->isVariadicOfVariadic()) {
2086         body << llvm::formatv(
2087             "  ::llvm::interleaveComma({0}().getTypes(), _odsPrinter, "
2088             "[&](::mlir::TypeRange types) {{ _odsPrinter << \"(\" << "
2089             "types << \")\"; });\n",
2090             op.getGetterName(operand->getVar()->name));
2091         return;
2092       }
2093     }
2094     const NamedTypeConstraint *var = nullptr;
2095     {
2096       if (auto *operand = dyn_cast<OperandVariable>(dir->getArg()))
2097         var = operand->getVar();
2098       else if (auto *operand = dyn_cast<ResultVariable>(dir->getArg()))
2099         var = operand->getVar();
2100     }
2101     if (var && !var->isVariadicOfVariadic() && !var->isVariadic() &&
2102         !var->isOptional()) {
2103       std::string cppClass = var->constraint.getCPPClassName();
2104       if (dir->shouldBeQualified()) {
2105         body << "   _odsPrinter << " << op.getGetterName(var->name)
2106              << "().getType();\n";
2107         return;
2108       }
2109       body << "  {\n"
2110            << "    auto type = " << op.getGetterName(var->name)
2111            << "().getType();\n"
2112            << "    if (auto validType = type.dyn_cast<" << cppClass << ">())\n"
2113            << "      _odsPrinter.printStrippedAttrOrType(validType);\n"
2114            << "   else\n"
2115            << "     _odsPrinter << type;\n"
2116            << "  }\n";
2117       return;
2118     }
2119     body << "  _odsPrinter << ";
2120     genTypeOperandPrinter(dir->getArg(), op, body, /*useArrayRef=*/false)
2121         << ";\n";
2122   } else if (auto *dir = dyn_cast<FunctionalTypeDirective>(element)) {
2123     body << "  _odsPrinter.printFunctionalType(";
2124     genTypeOperandPrinter(dir->getInputs(), op, body) << ", ";
2125     genTypeOperandPrinter(dir->getResults(), op, body) << ");\n";
2126   } else {
2127     llvm_unreachable("unknown format element");
2128   }
2129 }
2130 
genPrinter(Operator & op,OpClass & opClass)2131 void OperationFormat::genPrinter(Operator &op, OpClass &opClass) {
2132   auto *method = opClass.addMethod(
2133       "void", "print",
2134       MethodParameter("::mlir::OpAsmPrinter &", "_odsPrinter"));
2135   auto &body = method->body();
2136 
2137   // Flags for if we should emit a space, and if the last element was
2138   // punctuation.
2139   bool shouldEmitSpace = true, lastWasPunctuation = false;
2140   for (FormatElement *element : elements)
2141     genElementPrinter(element, body, op, shouldEmitSpace, lastWasPunctuation);
2142 }
2143 
2144 //===----------------------------------------------------------------------===//
2145 // OpFormatParser
2146 //===----------------------------------------------------------------------===//
2147 
2148 /// Function to find an element within the given range that has the same name as
2149 /// 'name'.
2150 template <typename RangeT>
findArg(RangeT && range,StringRef name)2151 static auto findArg(RangeT &&range, StringRef name) {
2152   auto it = llvm::find_if(range, [=](auto &arg) { return arg.name == name; });
2153   return it != range.end() ? &*it : nullptr;
2154 }
2155 
2156 namespace {
2157 /// This class implements a parser for an instance of an operation assembly
2158 /// format.
2159 class OpFormatParser : public FormatParser {
2160 public:
OpFormatParser(llvm::SourceMgr & mgr,OperationFormat & format,Operator & op)2161   OpFormatParser(llvm::SourceMgr &mgr, OperationFormat &format, Operator &op)
2162       : FormatParser(mgr, op.getLoc()[0]), fmt(format), op(op),
2163         seenOperandTypes(op.getNumOperands()),
2164         seenResultTypes(op.getNumResults()) {}
2165 
2166 protected:
2167   /// Verify the format elements.
2168   LogicalResult verify(SMLoc loc, ArrayRef<FormatElement *> elements) override;
2169   /// Verify the arguments to a custom directive.
2170   LogicalResult
2171   verifyCustomDirectiveArguments(SMLoc loc,
2172                                  ArrayRef<FormatElement *> arguments) override;
2173   /// Verify the elements of an optional group.
2174   LogicalResult
2175   verifyOptionalGroupElements(SMLoc loc, ArrayRef<FormatElement *> elements,
2176                               Optional<unsigned> anchorIndex) override;
2177   LogicalResult verifyOptionalGroupElement(SMLoc loc, FormatElement *element,
2178                                            bool isAnchor);
2179 
2180   /// Parse an operation variable.
2181   FailureOr<FormatElement *> parseVariableImpl(SMLoc loc, StringRef name,
2182                                                Context ctx) override;
2183   /// Parse an operation format directive.
2184   FailureOr<FormatElement *>
2185   parseDirectiveImpl(SMLoc loc, FormatToken::Kind kind, Context ctx) override;
2186 
2187 private:
2188   /// This struct represents a type resolution instance. It includes a specific
2189   /// type as well as an optional transformer to apply to that type in order to
2190   /// properly resolve the type of a variable.
2191   struct TypeResolutionInstance {
2192     ConstArgument resolver;
2193     Optional<StringRef> transformer;
2194   };
2195 
2196   /// Verify the state of operation attributes within the format.
2197   LogicalResult verifyAttributes(SMLoc loc, ArrayRef<FormatElement *> elements);
2198 
2199   /// Verify that attributes elements aren't followed by colon literals.
2200   LogicalResult verifyAttributeColonType(SMLoc loc,
2201                                          ArrayRef<FormatElement *> elements);
2202 
2203   /// Verify the state of operation operands within the format.
2204   LogicalResult
2205   verifyOperands(SMLoc loc,
2206                  llvm::StringMap<TypeResolutionInstance> &variableTyResolver);
2207 
2208   /// Verify the state of operation regions within the format.
2209   LogicalResult verifyRegions(SMLoc loc);
2210 
2211   /// Verify the state of operation results within the format.
2212   LogicalResult
2213   verifyResults(SMLoc loc,
2214                 llvm::StringMap<TypeResolutionInstance> &variableTyResolver);
2215 
2216   /// Verify the state of operation successors within the format.
2217   LogicalResult verifySuccessors(SMLoc loc);
2218 
2219   LogicalResult verifyOIListElements(SMLoc loc,
2220                                      ArrayRef<FormatElement *> elements);
2221 
2222   /// Given the values of an `AllTypesMatch` trait, check for inferable type
2223   /// resolution.
2224   void handleAllTypesMatchConstraint(
2225       ArrayRef<StringRef> values,
2226       llvm::StringMap<TypeResolutionInstance> &variableTyResolver);
2227   /// Check for inferable type resolution given all operands, and or results,
2228   /// have the same type. If 'includeResults' is true, the results also have the
2229   /// same type as all of the operands.
2230   void handleSameTypesConstraint(
2231       llvm::StringMap<TypeResolutionInstance> &variableTyResolver,
2232       bool includeResults);
2233   /// Check for inferable type resolution based on another operand, result, or
2234   /// attribute.
2235   void handleTypesMatchConstraint(
2236       llvm::StringMap<TypeResolutionInstance> &variableTyResolver,
2237       const llvm::Record &def);
2238 
2239   /// Returns an argument or attribute with the given name that has been seen
2240   /// within the format.
2241   ConstArgument findSeenArg(StringRef name);
2242 
2243   /// Parse the various different directives.
2244   FailureOr<FormatElement *> parseAttrDictDirective(SMLoc loc, Context context,
2245                                                     bool withKeyword);
2246   FailureOr<FormatElement *> parseFunctionalTypeDirective(SMLoc loc,
2247                                                           Context context);
2248   FailureOr<FormatElement *> parseOIListDirective(SMLoc loc, Context context);
2249   LogicalResult verifyOIListParsingElement(FormatElement *element, SMLoc loc);
2250   FailureOr<FormatElement *> parseOperandsDirective(SMLoc loc, Context context);
2251   FailureOr<FormatElement *> parseQualifiedDirective(SMLoc loc,
2252                                                      Context context);
2253   FailureOr<FormatElement *> parseReferenceDirective(SMLoc loc,
2254                                                      Context context);
2255   FailureOr<FormatElement *> parseRegionsDirective(SMLoc loc, Context context);
2256   FailureOr<FormatElement *> parseResultsDirective(SMLoc loc, Context context);
2257   FailureOr<FormatElement *> parseSuccessorsDirective(SMLoc loc,
2258                                                       Context context);
2259   FailureOr<FormatElement *> parseTypeDirective(SMLoc loc, Context context);
2260   FailureOr<FormatElement *> parseTypeDirectiveOperand(SMLoc loc,
2261                                                        bool isRefChild = false);
2262 
2263   //===--------------------------------------------------------------------===//
2264   // Fields
2265   //===--------------------------------------------------------------------===//
2266 
2267   OperationFormat &fmt;
2268   Operator &op;
2269 
2270   // The following are various bits of format state used for verification
2271   // during parsing.
2272   bool hasAttrDict = false;
2273   bool hasAllRegions = false, hasAllSuccessors = false;
2274   bool canInferResultTypes = false;
2275   llvm::SmallBitVector seenOperandTypes, seenResultTypes;
2276   llvm::SmallSetVector<const NamedAttribute *, 8> seenAttrs;
2277   llvm::DenseSet<const NamedTypeConstraint *> seenOperands;
2278   llvm::DenseSet<const NamedRegion *> seenRegions;
2279   llvm::DenseSet<const NamedSuccessor *> seenSuccessors;
2280 };
2281 } // namespace
2282 
verify(SMLoc loc,ArrayRef<FormatElement * > elements)2283 LogicalResult OpFormatParser::verify(SMLoc loc,
2284                                      ArrayRef<FormatElement *> elements) {
2285   // Check that the attribute dictionary is in the format.
2286   if (!hasAttrDict)
2287     return emitError(loc, "'attr-dict' directive not found in "
2288                           "custom assembly format");
2289 
2290   // Check for any type traits that we can use for inferring types.
2291   llvm::StringMap<TypeResolutionInstance> variableTyResolver;
2292   for (const Trait &trait : op.getTraits()) {
2293     const llvm::Record &def = trait.getDef();
2294     if (def.isSubClassOf("AllTypesMatch")) {
2295       handleAllTypesMatchConstraint(def.getValueAsListOfStrings("values"),
2296                                     variableTyResolver);
2297     } else if (def.getName() == "SameTypeOperands") {
2298       handleSameTypesConstraint(variableTyResolver, /*includeResults=*/false);
2299     } else if (def.getName() == "SameOperandsAndResultType") {
2300       handleSameTypesConstraint(variableTyResolver, /*includeResults=*/true);
2301     } else if (def.isSubClassOf("TypesMatchWith")) {
2302       handleTypesMatchConstraint(variableTyResolver, def);
2303     } else if (!op.allResultTypesKnown()) {
2304       // This doesn't check the name directly to handle
2305       //    DeclareOpInterfaceMethods<InferTypeOpInterface>
2306       // and the like.
2307       // TODO: Add hasCppInterface check.
2308       if (auto name = def.getValueAsOptionalString("cppInterfaceName")) {
2309         if (*name == "InferTypeOpInterface" &&
2310             def.getValueAsString("cppNamespace") == "::mlir")
2311           canInferResultTypes = true;
2312       }
2313     }
2314   }
2315 
2316   // Verify the state of the various operation components.
2317   if (failed(verifyAttributes(loc, elements)) ||
2318       failed(verifyResults(loc, variableTyResolver)) ||
2319       failed(verifyOperands(loc, variableTyResolver)) ||
2320       failed(verifyRegions(loc)) || failed(verifySuccessors(loc)) ||
2321       failed(verifyOIListElements(loc, elements)))
2322     return failure();
2323 
2324   // Collect the set of used attributes in the format.
2325   fmt.usedAttributes = seenAttrs.takeVector();
2326   return success();
2327 }
2328 
2329 LogicalResult
verifyAttributes(SMLoc loc,ArrayRef<FormatElement * > elements)2330 OpFormatParser::verifyAttributes(SMLoc loc,
2331                                  ArrayRef<FormatElement *> elements) {
2332   // Check that there are no `:` literals after an attribute without a constant
2333   // type. The attribute grammar contains an optional trailing colon type, which
2334   // can lead to unexpected and generally unintended behavior. Given that, it is
2335   // better to just error out here instead.
2336   if (failed(verifyAttributeColonType(loc, elements)))
2337     return failure();
2338 
2339   // Check for VariadicOfVariadic variables. The segment attribute of those
2340   // variables will be infered.
2341   for (const NamedTypeConstraint *var : seenOperands) {
2342     if (var->constraint.isVariadicOfVariadic()) {
2343       fmt.inferredAttributes.insert(
2344           var->constraint.getVariadicOfVariadicSegmentSizeAttr());
2345     }
2346   }
2347 
2348   return success();
2349 }
2350 
2351 /// Returns whether the single format element is optionally parsed.
isOptionallyParsed(FormatElement * el)2352 static bool isOptionallyParsed(FormatElement *el) {
2353   if (auto *attrVar = dyn_cast<AttributeVariable>(el)) {
2354     Attribute attr = attrVar->getVar()->attr;
2355     return attr.isOptional() || attr.hasDefaultValue();
2356   }
2357   if (auto *operandVar = dyn_cast<OperandVariable>(el)) {
2358     const NamedTypeConstraint *operand = operandVar->getVar();
2359     return operand->isOptional() || operand->isVariadic() ||
2360            operand->isVariadicOfVariadic();
2361   }
2362   if (auto *successorVar = dyn_cast<SuccessorVariable>(el))
2363     return successorVar->getVar()->isVariadic();
2364   if (auto *regionVar = dyn_cast<RegionVariable>(el))
2365     return regionVar->getVar()->isVariadic();
2366   return isa<WhitespaceElement, AttrDictDirective>(el);
2367 }
2368 
2369 /// Scan the given range of elements from the start for a colon literal,
2370 /// skipping any optionally-parsed elements. If an optional group is
2371 /// encountered, this function recurses into the 'then' and 'else' elements to
2372 /// check if they are invalid. Returns `success` if the range is known to be
2373 /// valid or `None` if scanning reached the end.
2374 ///
2375 /// Since the guard element of an optional group is required, this function
2376 /// accepts an optional element pointer to mark it as required.
checkElementRangeForColon(function_ref<LogicalResult (const Twine &)> emitError,StringRef attrName,iterator_range<ArrayRef<FormatElement * >::iterator> elementRange,FormatElement * optionalGuard=nullptr)2377 static Optional<LogicalResult> checkElementRangeForColon(
2378     function_ref<LogicalResult(const Twine &)> emitError, StringRef attrName,
2379     iterator_range<ArrayRef<FormatElement *>::iterator> elementRange,
2380     FormatElement *optionalGuard = nullptr) {
2381   for (FormatElement *element : elementRange) {
2382     // Skip optionally parsed elements.
2383     if (element != optionalGuard && isOptionallyParsed(element))
2384       continue;
2385 
2386     // Recurse on optional groups.
2387     if (auto *optional = dyn_cast<OptionalElement>(element)) {
2388       if (Optional<LogicalResult> result = checkElementRangeForColon(
2389               emitError, attrName, optional->getThenElements(),
2390               // The optional group guard is required for the group.
2391               optional->getThenElements().front()))
2392         if (failed(*result))
2393           return failure();
2394       if (Optional<LogicalResult> result = checkElementRangeForColon(
2395               emitError, attrName, optional->getElseElements()))
2396         if (failed(*result))
2397           return failure();
2398       // Skip the optional group.
2399       continue;
2400     }
2401 
2402     // If we encounter anything other than `:`, this range is range.
2403     auto *literal = dyn_cast<LiteralElement>(element);
2404     if (!literal || literal->getSpelling() != ":")
2405       return success();
2406     // If we encounter `:`, the range is known to be invalid.
2407     return emitError(
2408         llvm::formatv("format ambiguity caused by `:` literal found after "
2409                       "attribute `{0}` which does not have a buildable type",
2410                       attrName));
2411   }
2412   // Return None to indicate that we reached the end.
2413   return llvm::None;
2414 }
2415 
2416 /// For the given elements, check whether any attributes are followed by a colon
2417 /// literal, resulting in an ambiguous assembly format. Returns a non-null
2418 /// attribute if verification of said attribute reached the end of the range.
2419 /// Returns null if all attribute elements are verified.
2420 static FailureOr<AttributeVariable *>
verifyAttributeColon(function_ref<LogicalResult (const Twine &)> emitError,ArrayRef<FormatElement * > elements)2421 verifyAttributeColon(function_ref<LogicalResult(const Twine &)> emitError,
2422                      ArrayRef<FormatElement *> elements) {
2423   for (auto *it = elements.begin(), *e = elements.end(); it != e; ++it) {
2424     // The current attribute being verified.
2425     AttributeVariable *attr = nullptr;
2426 
2427     if ((attr = dyn_cast<AttributeVariable>(*it))) {
2428       // Check only attributes without type builders or that are known to call
2429       // the generic attribute parser.
2430       if (attr->getTypeBuilder() ||
2431           !(attr->shouldBeQualified() ||
2432             attr->getVar()->attr.getStorageType() == "::mlir::Attribute"))
2433         continue;
2434     } else if (auto *optional = dyn_cast<OptionalElement>(*it)) {
2435       // Recurse on optional groups.
2436       FailureOr<AttributeVariable *> thenResult =
2437           verifyAttributeColon(emitError, optional->getThenElements());
2438       if (failed(thenResult))
2439         return failure();
2440       FailureOr<AttributeVariable *> elseResult =
2441           verifyAttributeColon(emitError, optional->getElseElements());
2442       if (failed(elseResult))
2443         return failure();
2444       // If either optional group has an unverified attribute, save it.
2445       // Otherwise, move on to the next element.
2446       if (!(attr = *thenResult) && !(attr = *elseResult))
2447         continue;
2448     } else {
2449       continue;
2450     }
2451 
2452     // Verify subsequent elements for potential ambiguities.
2453     if (Optional<LogicalResult> result = checkElementRangeForColon(
2454             emitError, attr->getVar()->name, {std::next(it), e})) {
2455       if (failed(*result))
2456         return failure();
2457     } else {
2458       // Since we reached the end, return the attribute as unverified.
2459       return attr;
2460     }
2461   }
2462   // All attribute elements are known to be verified.
2463   return nullptr;
2464 }
2465 
2466 LogicalResult
verifyAttributeColonType(SMLoc loc,ArrayRef<FormatElement * > elements)2467 OpFormatParser::verifyAttributeColonType(SMLoc loc,
2468                                          ArrayRef<FormatElement *> elements) {
2469   return verifyAttributeColon(
2470       [&](const Twine &msg) { return emitError(loc, msg); }, elements);
2471 }
2472 
verifyOperands(SMLoc loc,llvm::StringMap<TypeResolutionInstance> & variableTyResolver)2473 LogicalResult OpFormatParser::verifyOperands(
2474     SMLoc loc, llvm::StringMap<TypeResolutionInstance> &variableTyResolver) {
2475   // Check that all of the operands are within the format, and their types can
2476   // be inferred.
2477   auto &buildableTypes = fmt.buildableTypes;
2478   for (unsigned i = 0, e = op.getNumOperands(); i != e; ++i) {
2479     NamedTypeConstraint &operand = op.getOperand(i);
2480 
2481     // Check that the operand itself is in the format.
2482     if (!fmt.allOperands && !seenOperands.count(&operand)) {
2483       return emitErrorAndNote(loc,
2484                               "operand #" + Twine(i) + ", named '" +
2485                                   operand.name + "', not found",
2486                               "suggest adding a '$" + operand.name +
2487                                   "' directive to the custom assembly format");
2488     }
2489 
2490     // Check that the operand type is in the format, or that it can be inferred.
2491     if (fmt.allOperandTypes || seenOperandTypes.test(i))
2492       continue;
2493 
2494     // Check to see if we can infer this type from another variable.
2495     auto varResolverIt = variableTyResolver.find(op.getOperand(i).name);
2496     if (varResolverIt != variableTyResolver.end()) {
2497       TypeResolutionInstance &resolver = varResolverIt->second;
2498       fmt.operandTypes[i].setResolver(resolver.resolver, resolver.transformer);
2499       continue;
2500     }
2501 
2502     // Similarly to results, allow a custom builder for resolving the type if
2503     // we aren't using the 'operands' directive.
2504     Optional<StringRef> builder = operand.constraint.getBuilderCall();
2505     if (!builder || (fmt.allOperands && operand.isVariableLength())) {
2506       return emitErrorAndNote(
2507           loc,
2508           "type of operand #" + Twine(i) + ", named '" + operand.name +
2509               "', is not buildable and a buildable type cannot be inferred",
2510           "suggest adding a type constraint to the operation or adding a "
2511           "'type($" +
2512               operand.name + ")' directive to the " + "custom assembly format");
2513     }
2514     auto it = buildableTypes.insert({*builder, buildableTypes.size()});
2515     fmt.operandTypes[i].setBuilderIdx(it.first->second);
2516   }
2517   return success();
2518 }
2519 
verifyRegions(SMLoc loc)2520 LogicalResult OpFormatParser::verifyRegions(SMLoc loc) {
2521   // Check that all of the regions are within the format.
2522   if (hasAllRegions)
2523     return success();
2524 
2525   for (unsigned i = 0, e = op.getNumRegions(); i != e; ++i) {
2526     const NamedRegion &region = op.getRegion(i);
2527     if (!seenRegions.count(&region)) {
2528       return emitErrorAndNote(loc,
2529                               "region #" + Twine(i) + ", named '" +
2530                                   region.name + "', not found",
2531                               "suggest adding a '$" + region.name +
2532                                   "' directive to the custom assembly format");
2533     }
2534   }
2535   return success();
2536 }
2537 
verifyResults(SMLoc loc,llvm::StringMap<TypeResolutionInstance> & variableTyResolver)2538 LogicalResult OpFormatParser::verifyResults(
2539     SMLoc loc, llvm::StringMap<TypeResolutionInstance> &variableTyResolver) {
2540   // If we format all of the types together, there is nothing to check.
2541   if (fmt.allResultTypes)
2542     return success();
2543 
2544   // If no result types are specified and we can infer them, infer all result
2545   // types
2546   if (op.getNumResults() > 0 && seenResultTypes.count() == 0 &&
2547       canInferResultTypes) {
2548     fmt.infersResultTypes = true;
2549     return success();
2550   }
2551 
2552   // Check that all of the result types can be inferred.
2553   auto &buildableTypes = fmt.buildableTypes;
2554   for (unsigned i = 0, e = op.getNumResults(); i != e; ++i) {
2555     if (seenResultTypes.test(i))
2556       continue;
2557 
2558     // Check to see if we can infer this type from another variable.
2559     auto varResolverIt = variableTyResolver.find(op.getResultName(i));
2560     if (varResolverIt != variableTyResolver.end()) {
2561       TypeResolutionInstance resolver = varResolverIt->second;
2562       fmt.resultTypes[i].setResolver(resolver.resolver, resolver.transformer);
2563       continue;
2564     }
2565 
2566     // If the result is not variable length, allow for the case where the type
2567     // has a builder that we can use.
2568     NamedTypeConstraint &result = op.getResult(i);
2569     Optional<StringRef> builder = result.constraint.getBuilderCall();
2570     if (!builder || result.isVariableLength()) {
2571       return emitErrorAndNote(
2572           loc,
2573           "type of result #" + Twine(i) + ", named '" + result.name +
2574               "', is not buildable and a buildable type cannot be inferred",
2575           "suggest adding a type constraint to the operation or adding a "
2576           "'type($" +
2577               result.name + ")' directive to the " + "custom assembly format");
2578     }
2579     // Note in the format that this result uses the custom builder.
2580     auto it = buildableTypes.insert({*builder, buildableTypes.size()});
2581     fmt.resultTypes[i].setBuilderIdx(it.first->second);
2582   }
2583   return success();
2584 }
2585 
verifySuccessors(SMLoc loc)2586 LogicalResult OpFormatParser::verifySuccessors(SMLoc loc) {
2587   // Check that all of the successors are within the format.
2588   if (hasAllSuccessors)
2589     return success();
2590 
2591   for (unsigned i = 0, e = op.getNumSuccessors(); i != e; ++i) {
2592     const NamedSuccessor &successor = op.getSuccessor(i);
2593     if (!seenSuccessors.count(&successor)) {
2594       return emitErrorAndNote(loc,
2595                               "successor #" + Twine(i) + ", named '" +
2596                                   successor.name + "', not found",
2597                               "suggest adding a '$" + successor.name +
2598                                   "' directive to the custom assembly format");
2599     }
2600   }
2601   return success();
2602 }
2603 
2604 LogicalResult
verifyOIListElements(SMLoc loc,ArrayRef<FormatElement * > elements)2605 OpFormatParser::verifyOIListElements(SMLoc loc,
2606                                      ArrayRef<FormatElement *> elements) {
2607   // Check that all of the successors are within the format.
2608   SmallVector<StringRef> prohibitedLiterals;
2609   for (FormatElement *it : elements) {
2610     if (auto *oilist = dyn_cast<OIListElement>(it)) {
2611       if (!prohibitedLiterals.empty()) {
2612         // We just saw an oilist element in last iteration. Literals should not
2613         // match.
2614         for (LiteralElement *literal : oilist->getLiteralElements()) {
2615           if (find(prohibitedLiterals, literal->getSpelling()) !=
2616               prohibitedLiterals.end()) {
2617             return emitError(
2618                 loc, "format ambiguity because " + literal->getSpelling() +
2619                          " is used in two adjacent oilist elements.");
2620           }
2621         }
2622       }
2623       for (LiteralElement *literal : oilist->getLiteralElements())
2624         prohibitedLiterals.push_back(literal->getSpelling());
2625     } else if (auto *literal = dyn_cast<LiteralElement>(it)) {
2626       if (find(prohibitedLiterals, literal->getSpelling()) !=
2627           prohibitedLiterals.end()) {
2628         return emitError(
2629             loc,
2630             "format ambiguity because " + literal->getSpelling() +
2631                 " is used both in oilist element and the adjacent literal.");
2632       }
2633       prohibitedLiterals.clear();
2634     } else {
2635       prohibitedLiterals.clear();
2636     }
2637   }
2638   return success();
2639 }
2640 
handleAllTypesMatchConstraint(ArrayRef<StringRef> values,llvm::StringMap<TypeResolutionInstance> & variableTyResolver)2641 void OpFormatParser::handleAllTypesMatchConstraint(
2642     ArrayRef<StringRef> values,
2643     llvm::StringMap<TypeResolutionInstance> &variableTyResolver) {
2644   for (unsigned i = 0, e = values.size(); i != e; ++i) {
2645     // Check to see if this value matches a resolved operand or result type.
2646     ConstArgument arg = findSeenArg(values[i]);
2647     if (!arg)
2648       continue;
2649 
2650     // Mark this value as the type resolver for the other variables.
2651     for (unsigned j = 0; j != i; ++j)
2652       variableTyResolver[values[j]] = {arg, llvm::None};
2653     for (unsigned j = i + 1; j != e; ++j)
2654       variableTyResolver[values[j]] = {arg, llvm::None};
2655   }
2656 }
2657 
handleSameTypesConstraint(llvm::StringMap<TypeResolutionInstance> & variableTyResolver,bool includeResults)2658 void OpFormatParser::handleSameTypesConstraint(
2659     llvm::StringMap<TypeResolutionInstance> &variableTyResolver,
2660     bool includeResults) {
2661   const NamedTypeConstraint *resolver = nullptr;
2662   int resolvedIt = -1;
2663 
2664   // Check to see if there is an operand or result to use for the resolution.
2665   if ((resolvedIt = seenOperandTypes.find_first()) != -1)
2666     resolver = &op.getOperand(resolvedIt);
2667   else if (includeResults && (resolvedIt = seenResultTypes.find_first()) != -1)
2668     resolver = &op.getResult(resolvedIt);
2669   else
2670     return;
2671 
2672   // Set the resolvers for each operand and result.
2673   for (unsigned i = 0, e = op.getNumOperands(); i != e; ++i)
2674     if (!seenOperandTypes.test(i) && !op.getOperand(i).name.empty())
2675       variableTyResolver[op.getOperand(i).name] = {resolver, llvm::None};
2676   if (includeResults) {
2677     for (unsigned i = 0, e = op.getNumResults(); i != e; ++i)
2678       if (!seenResultTypes.test(i) && !op.getResultName(i).empty())
2679         variableTyResolver[op.getResultName(i)] = {resolver, llvm::None};
2680   }
2681 }
2682 
handleTypesMatchConstraint(llvm::StringMap<TypeResolutionInstance> & variableTyResolver,const llvm::Record & def)2683 void OpFormatParser::handleTypesMatchConstraint(
2684     llvm::StringMap<TypeResolutionInstance> &variableTyResolver,
2685     const llvm::Record &def) {
2686   StringRef lhsName = def.getValueAsString("lhs");
2687   StringRef rhsName = def.getValueAsString("rhs");
2688   StringRef transformer = def.getValueAsString("transformer");
2689   if (ConstArgument arg = findSeenArg(lhsName))
2690     variableTyResolver[rhsName] = {arg, transformer};
2691 }
2692 
findSeenArg(StringRef name)2693 ConstArgument OpFormatParser::findSeenArg(StringRef name) {
2694   if (const NamedTypeConstraint *arg = findArg(op.getOperands(), name))
2695     return seenOperandTypes.test(arg - op.operand_begin()) ? arg : nullptr;
2696   if (const NamedTypeConstraint *arg = findArg(op.getResults(), name))
2697     return seenResultTypes.test(arg - op.result_begin()) ? arg : nullptr;
2698   if (const NamedAttribute *attr = findArg(op.getAttributes(), name))
2699     return seenAttrs.count(attr) ? attr : nullptr;
2700   return nullptr;
2701 }
2702 
2703 FailureOr<FormatElement *>
parseVariableImpl(SMLoc loc,StringRef name,Context ctx)2704 OpFormatParser::parseVariableImpl(SMLoc loc, StringRef name, Context ctx) {
2705   // Check that the parsed argument is something actually registered on the op.
2706   // Attributes
2707   if (const NamedAttribute *attr = findArg(op.getAttributes(), name)) {
2708     if (ctx == TypeDirectiveContext)
2709       return emitError(
2710           loc, "attributes cannot be used as children to a `type` directive");
2711     if (ctx == RefDirectiveContext) {
2712       if (!seenAttrs.count(attr))
2713         return emitError(loc, "attribute '" + name +
2714                                   "' must be bound before it is referenced");
2715     } else if (!seenAttrs.insert(attr)) {
2716       return emitError(loc, "attribute '" + name + "' is already bound");
2717     }
2718 
2719     return create<AttributeVariable>(attr);
2720   }
2721   // Operands
2722   if (const NamedTypeConstraint *operand = findArg(op.getOperands(), name)) {
2723     if (ctx == TopLevelContext || ctx == CustomDirectiveContext) {
2724       if (fmt.allOperands || !seenOperands.insert(operand).second)
2725         return emitError(loc, "operand '" + name + "' is already bound");
2726     } else if (ctx == RefDirectiveContext && !seenOperands.count(operand)) {
2727       return emitError(loc, "operand '" + name +
2728                                 "' must be bound before it is referenced");
2729     }
2730     return create<OperandVariable>(operand);
2731   }
2732   // Regions
2733   if (const NamedRegion *region = findArg(op.getRegions(), name)) {
2734     if (ctx == TopLevelContext || ctx == CustomDirectiveContext) {
2735       if (hasAllRegions || !seenRegions.insert(region).second)
2736         return emitError(loc, "region '" + name + "' is already bound");
2737     } else if (ctx == RefDirectiveContext && !seenRegions.count(region)) {
2738       return emitError(loc, "region '" + name +
2739                                 "' must be bound before it is referenced");
2740     } else {
2741       return emitError(loc, "regions can only be used at the top level");
2742     }
2743     return create<RegionVariable>(region);
2744   }
2745   // Results.
2746   if (const auto *result = findArg(op.getResults(), name)) {
2747     if (ctx != TypeDirectiveContext)
2748       return emitError(loc, "result variables can can only be used as a child "
2749                             "to a 'type' directive");
2750     return create<ResultVariable>(result);
2751   }
2752   // Successors.
2753   if (const auto *successor = findArg(op.getSuccessors(), name)) {
2754     if (ctx == TopLevelContext || ctx == CustomDirectiveContext) {
2755       if (hasAllSuccessors || !seenSuccessors.insert(successor).second)
2756         return emitError(loc, "successor '" + name + "' is already bound");
2757     } else if (ctx == RefDirectiveContext && !seenSuccessors.count(successor)) {
2758       return emitError(loc, "successor '" + name +
2759                                 "' must be bound before it is referenced");
2760     } else {
2761       return emitError(loc, "successors can only be used at the top level");
2762     }
2763 
2764     return create<SuccessorVariable>(successor);
2765   }
2766   return emitError(loc, "expected variable to refer to an argument, region, "
2767                         "result, or successor");
2768 }
2769 
2770 FailureOr<FormatElement *>
parseDirectiveImpl(SMLoc loc,FormatToken::Kind kind,Context ctx)2771 OpFormatParser::parseDirectiveImpl(SMLoc loc, FormatToken::Kind kind,
2772                                    Context ctx) {
2773   switch (kind) {
2774   case FormatToken::kw_attr_dict:
2775     return parseAttrDictDirective(loc, ctx,
2776                                   /*withKeyword=*/false);
2777   case FormatToken::kw_attr_dict_w_keyword:
2778     return parseAttrDictDirective(loc, ctx,
2779                                   /*withKeyword=*/true);
2780   case FormatToken::kw_functional_type:
2781     return parseFunctionalTypeDirective(loc, ctx);
2782   case FormatToken::kw_operands:
2783     return parseOperandsDirective(loc, ctx);
2784   case FormatToken::kw_qualified:
2785     return parseQualifiedDirective(loc, ctx);
2786   case FormatToken::kw_regions:
2787     return parseRegionsDirective(loc, ctx);
2788   case FormatToken::kw_results:
2789     return parseResultsDirective(loc, ctx);
2790   case FormatToken::kw_successors:
2791     return parseSuccessorsDirective(loc, ctx);
2792   case FormatToken::kw_ref:
2793     return parseReferenceDirective(loc, ctx);
2794   case FormatToken::kw_type:
2795     return parseTypeDirective(loc, ctx);
2796   case FormatToken::kw_oilist:
2797     return parseOIListDirective(loc, ctx);
2798 
2799   default:
2800     return emitError(loc, "unsupported directive kind");
2801   }
2802 }
2803 
2804 FailureOr<FormatElement *>
parseAttrDictDirective(SMLoc loc,Context context,bool withKeyword)2805 OpFormatParser::parseAttrDictDirective(SMLoc loc, Context context,
2806                                        bool withKeyword) {
2807   if (context == TypeDirectiveContext)
2808     return emitError(loc, "'attr-dict' directive can only be used as a "
2809                           "top-level directive");
2810 
2811   if (context == RefDirectiveContext) {
2812     if (!hasAttrDict)
2813       return emitError(loc, "'ref' of 'attr-dict' is not bound by a prior "
2814                             "'attr-dict' directive");
2815 
2816     // Otherwise, this is a top-level context.
2817   } else {
2818     if (hasAttrDict)
2819       return emitError(loc, "'attr-dict' directive has already been seen");
2820     hasAttrDict = true;
2821   }
2822 
2823   return create<AttrDictDirective>(withKeyword);
2824 }
2825 
verifyCustomDirectiveArguments(SMLoc loc,ArrayRef<FormatElement * > arguments)2826 LogicalResult OpFormatParser::verifyCustomDirectiveArguments(
2827     SMLoc loc, ArrayRef<FormatElement *> arguments) {
2828   for (FormatElement *argument : arguments) {
2829     if (!isa<RefDirective, TypeDirective, AttrDictDirective, AttributeVariable,
2830              OperandVariable, RegionVariable, SuccessorVariable>(argument)) {
2831       // TODO: FormatElement should have location info attached.
2832       return emitError(loc, "only variables and types may be used as "
2833                             "parameters to a custom directive");
2834     }
2835     if (auto *type = dyn_cast<TypeDirective>(argument)) {
2836       if (!isa<OperandVariable, ResultVariable>(type->getArg())) {
2837         return emitError(loc, "type directives within a custom directive may "
2838                               "only refer to variables");
2839       }
2840     }
2841   }
2842   return success();
2843 }
2844 
2845 FailureOr<FormatElement *>
parseFunctionalTypeDirective(SMLoc loc,Context context)2846 OpFormatParser::parseFunctionalTypeDirective(SMLoc loc, Context context) {
2847   if (context != TopLevelContext)
2848     return emitError(
2849         loc, "'functional-type' is only valid as a top-level directive");
2850 
2851   // Parse the main operand.
2852   FailureOr<FormatElement *> inputs, results;
2853   if (failed(parseToken(FormatToken::l_paren,
2854                         "expected '(' before argument list")) ||
2855       failed(inputs = parseTypeDirectiveOperand(loc)) ||
2856       failed(parseToken(FormatToken::comma,
2857                         "expected ',' after inputs argument")) ||
2858       failed(results = parseTypeDirectiveOperand(loc)) ||
2859       failed(
2860           parseToken(FormatToken::r_paren, "expected ')' after argument list")))
2861     return failure();
2862   return create<FunctionalTypeDirective>(*inputs, *results);
2863 }
2864 
2865 FailureOr<FormatElement *>
parseOperandsDirective(SMLoc loc,Context context)2866 OpFormatParser::parseOperandsDirective(SMLoc loc, Context context) {
2867   if (context == RefDirectiveContext) {
2868     if (!fmt.allOperands)
2869       return emitError(loc, "'ref' of 'operands' is not bound by a prior "
2870                             "'operands' directive");
2871 
2872   } else if (context == TopLevelContext || context == CustomDirectiveContext) {
2873     if (fmt.allOperands || !seenOperands.empty())
2874       return emitError(loc, "'operands' directive creates overlap in format");
2875     fmt.allOperands = true;
2876   }
2877   return create<OperandsDirective>();
2878 }
2879 
2880 FailureOr<FormatElement *>
parseReferenceDirective(SMLoc loc,Context context)2881 OpFormatParser::parseReferenceDirective(SMLoc loc, Context context) {
2882   if (context != CustomDirectiveContext)
2883     return emitError(loc, "'ref' is only valid within a `custom` directive");
2884 
2885   FailureOr<FormatElement *> arg;
2886   if (failed(parseToken(FormatToken::l_paren,
2887                         "expected '(' before argument list")) ||
2888       failed(arg = parseElement(RefDirectiveContext)) ||
2889       failed(
2890           parseToken(FormatToken::r_paren, "expected ')' after argument list")))
2891     return failure();
2892 
2893   return create<RefDirective>(*arg);
2894 }
2895 
2896 FailureOr<FormatElement *>
parseRegionsDirective(SMLoc loc,Context context)2897 OpFormatParser::parseRegionsDirective(SMLoc loc, Context context) {
2898   if (context == TypeDirectiveContext)
2899     return emitError(loc, "'regions' is only valid as a top-level directive");
2900   if (context == RefDirectiveContext) {
2901     if (!hasAllRegions)
2902       return emitError(loc, "'ref' of 'regions' is not bound by a prior "
2903                             "'regions' directive");
2904 
2905     // Otherwise, this is a TopLevel directive.
2906   } else {
2907     if (hasAllRegions || !seenRegions.empty())
2908       return emitError(loc, "'regions' directive creates overlap in format");
2909     hasAllRegions = true;
2910   }
2911   return create<RegionsDirective>();
2912 }
2913 
2914 FailureOr<FormatElement *>
parseResultsDirective(SMLoc loc,Context context)2915 OpFormatParser::parseResultsDirective(SMLoc loc, Context context) {
2916   if (context != TypeDirectiveContext)
2917     return emitError(loc, "'results' directive can can only be used as a child "
2918                           "to a 'type' directive");
2919   return create<ResultsDirective>();
2920 }
2921 
2922 FailureOr<FormatElement *>
parseSuccessorsDirective(SMLoc loc,Context context)2923 OpFormatParser::parseSuccessorsDirective(SMLoc loc, Context context) {
2924   if (context == TypeDirectiveContext)
2925     return emitError(loc,
2926                      "'successors' is only valid as a top-level directive");
2927   if (context == RefDirectiveContext) {
2928     if (!hasAllSuccessors)
2929       return emitError(loc, "'ref' of 'successors' is not bound by a prior "
2930                             "'successors' directive");
2931 
2932     // Otherwise, this is a TopLevel directive.
2933   } else {
2934     if (hasAllSuccessors || !seenSuccessors.empty())
2935       return emitError(loc, "'successors' directive creates overlap in format");
2936     hasAllSuccessors = true;
2937   }
2938   return create<SuccessorsDirective>();
2939 }
2940 
2941 FailureOr<FormatElement *>
parseOIListDirective(SMLoc loc,Context context)2942 OpFormatParser::parseOIListDirective(SMLoc loc, Context context) {
2943   if (failed(parseToken(FormatToken::l_paren,
2944                         "expected '(' before oilist argument list")))
2945     return failure();
2946   std::vector<FormatElement *> literalElements;
2947   std::vector<std::vector<FormatElement *>> parsingElements;
2948   do {
2949     FailureOr<FormatElement *> lelement = parseLiteral(context);
2950     if (failed(lelement))
2951       return failure();
2952     literalElements.push_back(*lelement);
2953     parsingElements.emplace_back();
2954     std::vector<FormatElement *> &currParsingElements = parsingElements.back();
2955     while (peekToken().getKind() != FormatToken::pipe &&
2956            peekToken().getKind() != FormatToken::r_paren) {
2957       FailureOr<FormatElement *> pelement = parseElement(context);
2958       if (failed(pelement) ||
2959           failed(verifyOIListParsingElement(*pelement, loc)))
2960         return failure();
2961       currParsingElements.push_back(*pelement);
2962     }
2963     if (peekToken().getKind() == FormatToken::pipe) {
2964       consumeToken();
2965       continue;
2966     }
2967     if (peekToken().getKind() == FormatToken::r_paren) {
2968       consumeToken();
2969       break;
2970     }
2971   } while (true);
2972 
2973   return create<OIListElement>(std::move(literalElements),
2974                                std::move(parsingElements));
2975 }
2976 
verifyOIListParsingElement(FormatElement * element,SMLoc loc)2977 LogicalResult OpFormatParser::verifyOIListParsingElement(FormatElement *element,
2978                                                          SMLoc loc) {
2979   SmallVector<VariableElement *> vars;
2980   collect(element, vars);
2981   for (VariableElement *elem : vars) {
2982     LogicalResult res =
2983         TypeSwitch<FormatElement *, LogicalResult>(elem)
2984             // Only optional attributes can be within an oilist parsing group.
2985             .Case([&](AttributeVariable *attrEle) {
2986               if (!attrEle->getVar()->attr.isOptional() &&
2987                   !attrEle->getVar()->attr.hasDefaultValue())
2988                 return emitError(loc, "only optional attributes can be used in "
2989                                       "an oilist parsing group");
2990               return success();
2991             })
2992             // Only optional-like(i.e. variadic) operands can be within an
2993             // oilist parsing group.
2994             .Case([&](OperandVariable *ele) {
2995               if (!ele->getVar()->isVariableLength())
2996                 return emitError(loc, "only variable length operands can be "
2997                                       "used within an oilist parsing group");
2998               return success();
2999             })
3000             // Only optional-like(i.e. variadic) results can be within an oilist
3001             // parsing group.
3002             .Case([&](ResultVariable *ele) {
3003               if (!ele->getVar()->isVariableLength())
3004                 return emitError(loc, "only variable length results can be "
3005                                       "used within an oilist parsing group");
3006               return success();
3007             })
3008             .Case([&](RegionVariable *) { return success(); })
3009             .Default([&](FormatElement *) {
3010               return emitError(loc,
3011                                "only literals, types, and variables can be "
3012                                "used within an oilist group");
3013             });
3014     if (failed(res))
3015       return failure();
3016   }
3017   return success();
3018 }
3019 
parseTypeDirective(SMLoc loc,Context context)3020 FailureOr<FormatElement *> OpFormatParser::parseTypeDirective(SMLoc loc,
3021                                                               Context context) {
3022   if (context == TypeDirectiveContext)
3023     return emitError(loc, "'type' cannot be used as a child of another `type`");
3024 
3025   bool isRefChild = context == RefDirectiveContext;
3026   FailureOr<FormatElement *> operand;
3027   if (failed(parseToken(FormatToken::l_paren,
3028                         "expected '(' before argument list")) ||
3029       failed(operand = parseTypeDirectiveOperand(loc, isRefChild)) ||
3030       failed(
3031           parseToken(FormatToken::r_paren, "expected ')' after argument list")))
3032     return failure();
3033 
3034   return create<TypeDirective>(*operand);
3035 }
3036 
3037 FailureOr<FormatElement *>
parseQualifiedDirective(SMLoc loc,Context context)3038 OpFormatParser::parseQualifiedDirective(SMLoc loc, Context context) {
3039   FailureOr<FormatElement *> element;
3040   if (failed(parseToken(FormatToken::l_paren,
3041                         "expected '(' before argument list")) ||
3042       failed(element = parseElement(context)) ||
3043       failed(
3044           parseToken(FormatToken::r_paren, "expected ')' after argument list")))
3045     return failure();
3046   return TypeSwitch<FormatElement *, FailureOr<FormatElement *>>(*element)
3047       .Case<AttributeVariable, TypeDirective>([](auto *element) {
3048         element->setShouldBeQualified();
3049         return element;
3050       })
3051       .Default([&](auto *element) {
3052         return this->emitError(
3053             loc,
3054             "'qualified' directive expects an attribute or a `type` directive");
3055       });
3056 }
3057 
3058 FailureOr<FormatElement *>
parseTypeDirectiveOperand(SMLoc loc,bool isRefChild)3059 OpFormatParser::parseTypeDirectiveOperand(SMLoc loc, bool isRefChild) {
3060   FailureOr<FormatElement *> result = parseElement(TypeDirectiveContext);
3061   if (failed(result))
3062     return failure();
3063 
3064   FormatElement *element = *result;
3065   if (isa<LiteralElement>(element))
3066     return emitError(
3067         loc, "'type' directive operand expects variable or directive operand");
3068 
3069   if (auto *var = dyn_cast<OperandVariable>(element)) {
3070     unsigned opIdx = var->getVar() - op.operand_begin();
3071     if (!isRefChild && (fmt.allOperandTypes || seenOperandTypes.test(opIdx)))
3072       return emitError(loc, "'type' of '" + var->getVar()->name +
3073                                 "' is already bound");
3074     if (isRefChild && !(fmt.allOperandTypes || seenOperandTypes.test(opIdx)))
3075       return emitError(loc, "'ref' of 'type($" + var->getVar()->name +
3076                                 ")' is not bound by a prior 'type' directive");
3077     seenOperandTypes.set(opIdx);
3078   } else if (auto *var = dyn_cast<ResultVariable>(element)) {
3079     unsigned resIdx = var->getVar() - op.result_begin();
3080     if (!isRefChild && (fmt.allResultTypes || seenResultTypes.test(resIdx)))
3081       return emitError(loc, "'type' of '" + var->getVar()->name +
3082                                 "' is already bound");
3083     if (isRefChild && !(fmt.allResultTypes || seenResultTypes.test(resIdx)))
3084       return emitError(loc, "'ref' of 'type($" + var->getVar()->name +
3085                                 ")' is not bound by a prior 'type' directive");
3086     seenResultTypes.set(resIdx);
3087   } else if (isa<OperandsDirective>(&*element)) {
3088     if (!isRefChild && (fmt.allOperandTypes || seenOperandTypes.any()))
3089       return emitError(loc, "'operands' 'type' is already bound");
3090     if (isRefChild && !fmt.allOperandTypes)
3091       return emitError(loc, "'ref' of 'type(operands)' is not bound by a prior "
3092                             "'type' directive");
3093     fmt.allOperandTypes = true;
3094   } else if (isa<ResultsDirective>(&*element)) {
3095     if (!isRefChild && (fmt.allResultTypes || seenResultTypes.any()))
3096       return emitError(loc, "'results' 'type' is already bound");
3097     if (isRefChild && !fmt.allResultTypes)
3098       return emitError(loc, "'ref' of 'type(results)' is not bound by a prior "
3099                             "'type' directive");
3100     fmt.allResultTypes = true;
3101   } else {
3102     return emitError(loc, "invalid argument to 'type' directive");
3103   }
3104   return element;
3105 }
3106 
3107 LogicalResult
verifyOptionalGroupElements(SMLoc loc,ArrayRef<FormatElement * > elements,Optional<unsigned> anchorIndex)3108 OpFormatParser::verifyOptionalGroupElements(SMLoc loc,
3109                                             ArrayRef<FormatElement *> elements,
3110                                             Optional<unsigned> anchorIndex) {
3111   for (auto &it : llvm::enumerate(elements)) {
3112     if (failed(verifyOptionalGroupElement(
3113             loc, it.value(), anchorIndex && *anchorIndex == it.index())))
3114       return failure();
3115   }
3116   return success();
3117 }
3118 
verifyOptionalGroupElement(SMLoc loc,FormatElement * element,bool isAnchor)3119 LogicalResult OpFormatParser::verifyOptionalGroupElement(SMLoc loc,
3120                                                          FormatElement *element,
3121                                                          bool isAnchor) {
3122   return TypeSwitch<FormatElement *, LogicalResult>(element)
3123       // All attributes can be within the optional group, but only optional
3124       // attributes can be the anchor.
3125       .Case([&](AttributeVariable *attrEle) {
3126         if (isAnchor && !attrEle->getVar()->attr.isOptional())
3127           return emitError(loc, "only optional attributes can be used to "
3128                                 "anchor an optional group");
3129         return success();
3130       })
3131       // Only optional-like(i.e. variadic) operands can be within an optional
3132       // group.
3133       .Case([&](OperandVariable *ele) {
3134         if (!ele->getVar()->isVariableLength())
3135           return emitError(loc, "only variable length operands can be used "
3136                                 "within an optional group");
3137         return success();
3138       })
3139       // Only optional-like(i.e. variadic) results can be within an optional
3140       // group.
3141       .Case([&](ResultVariable *ele) {
3142         if (!ele->getVar()->isVariableLength())
3143           return emitError(loc, "only variable length results can be used "
3144                                 "within an optional group");
3145         return success();
3146       })
3147       .Case([&](RegionVariable *) {
3148         // TODO: When ODS has proper support for marking "optional" regions, add
3149         // a check here.
3150         return success();
3151       })
3152       .Case([&](TypeDirective *ele) {
3153         return verifyOptionalGroupElement(loc, ele->getArg(),
3154                                           /*isAnchor=*/false);
3155       })
3156       .Case([&](FunctionalTypeDirective *ele) {
3157         if (failed(verifyOptionalGroupElement(loc, ele->getInputs(),
3158                                               /*isAnchor=*/false)))
3159           return failure();
3160         return verifyOptionalGroupElement(loc, ele->getResults(),
3161                                           /*isAnchor=*/false);
3162       })
3163       // Literals, whitespace, and custom directives may be used, but they can't
3164       // anchor the group.
3165       .Case<LiteralElement, WhitespaceElement, CustomDirective,
3166             FunctionalTypeDirective, OptionalElement>([&](FormatElement *) {
3167         if (isAnchor)
3168           return emitError(loc, "only variables and types can be used "
3169                                 "to anchor an optional group");
3170         return success();
3171       })
3172       .Default([&](FormatElement *) {
3173         return emitError(loc, "only literals, types, and variables can be "
3174                               "used within an optional group");
3175       });
3176 }
3177 
3178 //===----------------------------------------------------------------------===//
3179 // Interface
3180 //===----------------------------------------------------------------------===//
3181 
generateOpFormat(const Operator & constOp,OpClass & opClass)3182 void mlir::tblgen::generateOpFormat(const Operator &constOp, OpClass &opClass) {
3183   // TODO: Operator doesn't expose all necessary functionality via
3184   // the const interface.
3185   Operator &op = const_cast<Operator &>(constOp);
3186   if (!op.hasAssemblyFormat())
3187     return;
3188 
3189   // Parse the format description.
3190   llvm::SourceMgr mgr;
3191   mgr.AddNewSourceBuffer(
3192       llvm::MemoryBuffer::getMemBuffer(op.getAssemblyFormat()), SMLoc());
3193   OperationFormat format(op);
3194   OpFormatParser parser(mgr, format, op);
3195   FailureOr<std::vector<FormatElement *>> elements = parser.parse();
3196   if (failed(elements)) {
3197     // Exit the process if format errors are treated as fatal.
3198     if (formatErrorIsFatal) {
3199       // Invoke the interrupt handlers to run the file cleanup handlers.
3200       llvm::sys::RunInterruptHandlers();
3201       std::exit(1);
3202     }
3203     return;
3204   }
3205   format.elements = std::move(*elements);
3206 
3207   // Generate the printer and parser based on the parsed format.
3208   format.genParser(op, opClass);
3209   format.genPrinter(op, opClass);
3210 }
3211