1 //===- OpFormatGen.cpp - MLIR operation asm format generator --------------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8
9 #include "OpFormatGen.h"
10 #include "FormatGen.h"
11 #include "OpClass.h"
12 #include "mlir/Support/LLVM.h"
13 #include "mlir/TableGen/Class.h"
14 #include "mlir/TableGen/Format.h"
15 #include "mlir/TableGen/Operator.h"
16 #include "mlir/TableGen/Trait.h"
17 #include "llvm/ADT/MapVector.h"
18 #include "llvm/ADT/Sequence.h"
19 #include "llvm/ADT/SetVector.h"
20 #include "llvm/ADT/SmallBitVector.h"
21 #include "llvm/ADT/StringExtras.h"
22 #include "llvm/ADT/TypeSwitch.h"
23 #include "llvm/Support/Signals.h"
24 #include "llvm/Support/SourceMgr.h"
25 #include "llvm/TableGen/Record.h"
26
27 #define DEBUG_TYPE "mlir-tblgen-opformatgen"
28
29 using namespace mlir;
30 using namespace mlir::tblgen;
31
32 //===----------------------------------------------------------------------===//
33 // VariableElement
34
35 namespace {
36 /// This class represents an instance of an op variable element. A variable
37 /// refers to something registered on the operation itself, e.g. an operand,
38 /// result, attribute, region, or successor.
39 template <typename VarT, VariableElement::Kind VariableKind>
40 class OpVariableElement : public VariableElementBase<VariableKind> {
41 public:
42 using Base = OpVariableElement<VarT, VariableKind>;
43
44 /// Create an op variable element with the variable value.
OpVariableElement(const VarT * var)45 OpVariableElement(const VarT *var) : var(var) {}
46
47 /// Get the variable.
getVar()48 const VarT *getVar() { return var; }
49
50 protected:
51 /// The op variable, e.g. a type or attribute constraint.
52 const VarT *var;
53 };
54
55 /// This class represents a variable that refers to an attribute argument.
56 struct AttributeVariable
57 : public OpVariableElement<NamedAttribute, VariableElement::Attribute> {
58 using Base::Base;
59
60 /// Return the constant builder call for the type of this attribute, or None
61 /// if it doesn't have one.
getTypeBuilder__anonf2a272a10111::AttributeVariable62 llvm::Optional<StringRef> getTypeBuilder() const {
63 llvm::Optional<Type> attrType = var->attr.getValueType();
64 return attrType ? attrType->getBuilderCall() : llvm::None;
65 }
66
67 /// Return if this attribute refers to a UnitAttr.
isUnitAttr__anonf2a272a10111::AttributeVariable68 bool isUnitAttr() const {
69 return var->attr.getBaseAttr().getAttrDefName() == "UnitAttr";
70 }
71
72 /// Indicate if this attribute is printed "qualified" (that is it is
73 /// prefixed with the `#dialect.mnemonic`).
shouldBeQualified__anonf2a272a10111::AttributeVariable74 bool shouldBeQualified() { return shouldBeQualifiedFlag; }
setShouldBeQualified__anonf2a272a10111::AttributeVariable75 void setShouldBeQualified(bool qualified = true) {
76 shouldBeQualifiedFlag = qualified;
77 }
78
79 private:
80 bool shouldBeQualifiedFlag = false;
81 };
82
/// This class represents a variable that refers to an operand argument.
/// Operands and results share the `NamedTypeConstraint` payload; they are
/// distinguished only by the variable element kind.
using OperandVariable =
    OpVariableElement<NamedTypeConstraint, VariableElement::Operand>;

/// This class represents a variable that refers to a result.
using ResultVariable =
    OpVariableElement<NamedTypeConstraint, VariableElement::Result>;

/// This class represents a variable that refers to a region.
using RegionVariable = OpVariableElement<NamedRegion, VariableElement::Region>;

/// This class represents a variable that refers to a successor.
using SuccessorVariable =
    OpVariableElement<NamedSuccessor, VariableElement::Successor>;
97 } // namespace
98
99 //===----------------------------------------------------------------------===//
100 // DirectiveElement
101
102 namespace {
/// This class represents the `operands` directive. This directive represents
/// all of the operands of an operation. Like the other directives below, it
/// carries no state beyond its kind, so it is a plain alias.
using OperandsDirective = DirectiveElementBase<DirectiveElement::Operands>;

/// This class represents the `results` directive. This directive represents
/// all of the results of an operation.
using ResultsDirective = DirectiveElementBase<DirectiveElement::Results>;

/// This class represents the `regions` directive. This directive represents
/// all of the regions of an operation.
using RegionsDirective = DirectiveElementBase<DirectiveElement::Regions>;

/// This class represents the `successors` directive. This directive represents
/// all of the successors of an operation.
using SuccessorsDirective = DirectiveElementBase<DirectiveElement::Successors>;
118
119 /// This class represents the `attr-dict` directive. This directive represents
120 /// the attribute dictionary of the operation.
121 class AttrDictDirective
122 : public DirectiveElementBase<DirectiveElement::AttrDict> {
123 public:
AttrDictDirective(bool withKeyword)124 explicit AttrDictDirective(bool withKeyword) : withKeyword(withKeyword) {}
125
126 /// Return whether the dictionary should be printed with the 'attributes'
127 /// keyword.
isWithKeyword() const128 bool isWithKeyword() const { return withKeyword; }
129
130 private:
131 /// If the dictionary should be printed with the 'attributes' keyword.
132 bool withKeyword;
133 };
134
135 /// This class represents the `functional-type` directive. This directive takes
136 /// two arguments and formats them, respectively, as the inputs and results of a
137 /// FunctionType.
138 class FunctionalTypeDirective
139 : public DirectiveElementBase<DirectiveElement::FunctionalType> {
140 public:
FunctionalTypeDirective(FormatElement * inputs,FormatElement * results)141 FunctionalTypeDirective(FormatElement *inputs, FormatElement *results)
142 : inputs(inputs), results(results) {}
143
getInputs() const144 FormatElement *getInputs() const { return inputs; }
getResults() const145 FormatElement *getResults() const { return results; }
146
147 private:
148 /// The input and result arguments.
149 FormatElement *inputs, *results;
150 };
151
152 /// This class represents the `type` directive.
153 class TypeDirective : public DirectiveElementBase<DirectiveElement::Type> {
154 public:
TypeDirective(FormatElement * arg)155 TypeDirective(FormatElement *arg) : arg(arg) {}
156
getArg() const157 FormatElement *getArg() const { return arg; }
158
159 /// Indicate if this type is printed "qualified" (that is it is
160 /// prefixed with the `!dialect.mnemonic`).
shouldBeQualified()161 bool shouldBeQualified() { return shouldBeQualifiedFlag; }
setShouldBeQualified(bool qualified=true)162 void setShouldBeQualified(bool qualified = true) {
163 shouldBeQualifiedFlag = qualified;
164 }
165
166 private:
167 /// The argument that is used to format the directive.
168 FormatElement *arg;
169
170 bool shouldBeQualifiedFlag = false;
171 };
172
173 /// This class represents a group of order-independent optional clauses. Each
174 /// clause starts with a literal element and has a coressponding parsing
175 /// element. A parsing element is a continous sequence of format elements.
176 /// Each clause can appear 0 or 1 time.
177 class OIListElement : public DirectiveElementBase<DirectiveElement::OIList> {
178 public:
OIListElement(std::vector<FormatElement * > && literalElements,std::vector<std::vector<FormatElement * >> && parsingElements)179 OIListElement(std::vector<FormatElement *> &&literalElements,
180 std::vector<std::vector<FormatElement *>> &&parsingElements)
181 : literalElements(std::move(literalElements)),
182 parsingElements(std::move(parsingElements)) {}
183
184 /// Returns a range to iterate over the LiteralElements.
getLiteralElements() const185 auto getLiteralElements() const {
186 function_ref<LiteralElement *(FormatElement * el)>
187 literalElementCastConverter =
188 [](FormatElement *el) { return cast<LiteralElement>(el); };
189 return llvm::map_range(literalElements, literalElementCastConverter);
190 }
191
192 /// Returns a range to iterate over the parsing elements corresponding to the
193 /// clauses.
getParsingElements() const194 ArrayRef<std::vector<FormatElement *>> getParsingElements() const {
195 return parsingElements;
196 }
197
198 /// Returns a range to iterate over tuples of parsing and literal elements.
getClauses() const199 auto getClauses() const {
200 return llvm::zip(getLiteralElements(), getParsingElements());
201 }
202
203 /// If the parsing element is a single UnitAttr element, then it returns the
204 /// attribute variable. Otherwise, returns nullptr.
205 AttributeVariable *
getUnitAttrParsingElement(ArrayRef<FormatElement * > pelement)206 getUnitAttrParsingElement(ArrayRef<FormatElement *> pelement) {
207 if (pelement.size() == 1) {
208 auto *attrElem = dyn_cast<AttributeVariable>(pelement[0]);
209 if (attrElem && attrElem->isUnitAttr())
210 return attrElem;
211 }
212 return nullptr;
213 }
214
215 private:
216 /// A vector of `LiteralElement` objects. Each element stores the keyword
217 /// for one case of oilist element. For example, an oilist element along with
218 /// the `literalElements` vector:
219 /// ```
220 /// oilist [ `keyword` `=` `(` $arg0 `)` | `otherKeyword` `<` $arg1 `>`]
221 /// literalElements = { `keyword`, `otherKeyword` }
222 /// ```
223 std::vector<FormatElement *> literalElements;
224
225 /// A vector of valid declarative assembly format vectors. Each object in
226 /// parsing elements is a vector of elements in assembly format syntax.
227 /// For example, an oilist element along with the parsingElements vector:
228 /// ```
229 /// oilist [ `keyword` `=` `(` $arg0 `)` | `otherKeyword` `<` $arg1 `>`]
230 /// parsingElements = {
231 /// { `=`, `(`, $arg0, `)` },
232 /// { `<`, $arg1, `>` }
233 /// }
234 /// ```
235 std::vector<std::vector<FormatElement *>> parsingElements;
236 };
237 } // namespace
238
239 //===----------------------------------------------------------------------===//
240 // OperationFormat
241 //===----------------------------------------------------------------------===//
242
243 namespace {
244
245 using ConstArgument =
246 llvm::PointerUnion<const NamedAttribute *, const NamedTypeConstraint *>;
247
248 struct OperationFormat {
249 /// This class represents a specific resolver for an operand or result type.
250 class TypeResolution {
251 public:
252 TypeResolution() = default;
253
254 /// Get the index into the buildable types for this type, or None.
getBuilderIdx() const255 Optional<int> getBuilderIdx() const { return builderIdx; }
setBuilderIdx(int idx)256 void setBuilderIdx(int idx) { builderIdx = idx; }
257
258 /// Get the variable this type is resolved to, or nullptr.
getVariable() const259 const NamedTypeConstraint *getVariable() const {
260 return resolver.dyn_cast<const NamedTypeConstraint *>();
261 }
262 /// Get the attribute this type is resolved to, or nullptr.
getAttribute() const263 const NamedAttribute *getAttribute() const {
264 return resolver.dyn_cast<const NamedAttribute *>();
265 }
266 /// Get the transformer for the type of the variable, or None.
getVarTransformer() const267 Optional<StringRef> getVarTransformer() const {
268 return variableTransformer;
269 }
setResolver(ConstArgument arg,Optional<StringRef> transformer)270 void setResolver(ConstArgument arg, Optional<StringRef> transformer) {
271 resolver = arg;
272 variableTransformer = transformer;
273 assert(getVariable() || getAttribute());
274 }
275
276 private:
277 /// If the type is resolved with a buildable type, this is the index into
278 /// 'buildableTypes' in the parent format.
279 Optional<int> builderIdx;
280 /// If the type is resolved based upon another operand or result, this is
281 /// the variable or the attribute that this type is resolved to.
282 ConstArgument resolver;
283 /// If the type is resolved based upon another operand or result, this is
284 /// a transformer to apply to the variable when resolving.
285 Optional<StringRef> variableTransformer;
286 };
287
288 /// The context in which an element is generated.
289 enum class GenContext {
290 /// The element is generated at the top-level or with the same behaviour.
291 Normal,
292 /// The element is generated inside an optional group.
293 Optional
294 };
295
OperationFormat__anonf2a272a10411::OperationFormat296 OperationFormat(const Operator &op)
297
298 {
299 operandTypes.resize(op.getNumOperands(), TypeResolution());
300 resultTypes.resize(op.getNumResults(), TypeResolution());
301
302 hasImplicitTermTrait = llvm::any_of(op.getTraits(), [](const Trait &trait) {
303 return trait.getDef().isSubClassOf("SingleBlockImplicitTerminator");
304 });
305
306 hasSingleBlockTrait =
307 hasImplicitTermTrait || op.getTrait("::mlir::OpTrait::SingleBlock");
308 }
309
310 /// Generate the operation parser from this format.
311 void genParser(Operator &op, OpClass &opClass);
312 /// Generate the parser code for a specific format element.
313 void genElementParser(FormatElement *element, MethodBody &body,
314 FmtContext &attrTypeCtx,
315 GenContext genCtx = GenContext::Normal);
316 /// Generate the C++ to resolve the types of operands and results during
317 /// parsing.
318 void genParserTypeResolution(Operator &op, MethodBody &body);
319 /// Generate the C++ to resolve the types of the operands during parsing.
320 void genParserOperandTypeResolution(
321 Operator &op, MethodBody &body,
322 function_ref<void(TypeResolution &, StringRef)> emitTypeResolver);
323 /// Generate the C++ to resolve regions during parsing.
324 void genParserRegionResolution(Operator &op, MethodBody &body);
325 /// Generate the C++ to resolve successors during parsing.
326 void genParserSuccessorResolution(Operator &op, MethodBody &body);
327 /// Generate the C++ to handling variadic segment size traits.
328 void genParserVariadicSegmentResolution(Operator &op, MethodBody &body);
329
330 /// Generate the operation printer from this format.
331 void genPrinter(Operator &op, OpClass &opClass);
332
333 /// Generate the printer code for a specific format element.
334 void genElementPrinter(FormatElement *element, MethodBody &body, Operator &op,
335 bool &shouldEmitSpace, bool &lastWasPunctuation);
336
337 /// The various elements in this format.
338 std::vector<FormatElement *> elements;
339
340 /// A flag indicating if all operand/result types were seen. If the format
341 /// contains these, it can not contain individual type resolvers.
342 bool allOperands = false, allOperandTypes = false, allResultTypes = false;
343
344 /// A flag indicating if this operation infers its result types
345 bool infersResultTypes = false;
346
347 /// A flag indicating if this operation has the SingleBlockImplicitTerminator
348 /// trait.
349 bool hasImplicitTermTrait;
350
351 /// A flag indicating if this operation has the SingleBlock trait.
352 bool hasSingleBlockTrait;
353
354 /// A map of buildable types to indices.
355 llvm::MapVector<StringRef, int, llvm::StringMap<int>> buildableTypes;
356
357 /// The index of the buildable type, if valid, for every operand and result.
358 std::vector<TypeResolution> operandTypes, resultTypes;
359
360 /// The set of attributes explicitly used within the format.
361 SmallVector<const NamedAttribute *, 8> usedAttributes;
362 llvm::StringSet<> inferredAttributes;
363 };
364 } // namespace
365
366 //===----------------------------------------------------------------------===//
367 // Parser Gen
368
369 /// Returns true if we can format the given attribute as an EnumAttr in the
370 /// parser format.
canFormatEnumAttr(const NamedAttribute * attr)371 static bool canFormatEnumAttr(const NamedAttribute *attr) {
372 Attribute baseAttr = attr->attr.getBaseAttr();
373 const EnumAttr *enumAttr = dyn_cast<EnumAttr>(&baseAttr);
374 if (!enumAttr)
375 return false;
376
377 // The attribute must have a valid underlying type and a constant builder.
378 return !enumAttr->getUnderlyingType().empty() &&
379 !enumAttr->getConstBuilderTemplate().empty();
380 }
381
382 /// Returns if we should format the given attribute as an SymbolNameAttr.
shouldFormatSymbolNameAttr(const NamedAttribute * attr)383 static bool shouldFormatSymbolNameAttr(const NamedAttribute *attr) {
384 return attr->attr.getBaseAttr().getAttrDefName() == "SymbolNameAttr";
385 }
386
/// The code snippet used to generate a parser call for an attribute, using the
/// attribute's custom parser with a generic fallback.
///
/// {0}: The name of the attribute.
/// {1}: The type for the attribute.
const char *const attrParserCode = R"(
  if (parser.parseCustomAttributeWithFallback({0}Attr, {1}, "{0}",
          result.attributes)) {{
    return ::mlir::failure();
  }
)";

/// The code snippet used to generate a generic (`parseAttribute`) parser call
/// for an attribute.
///
/// {0}: The name of the attribute.
/// {1}: The type for the attribute.
const char *const genericAttrParserCode = R"(
  if (parser.parseAttribute({0}Attr, {1}, "{0}", result.attributes))
    return ::mlir::failure();
)";

/// The code snippet used to generate a parser call for an optional attribute.
/// Failure is only reported if an attribute was present but malformed.
///
/// {0}: The name of the attribute.
/// {1}: The type for the attribute.
const char *const optionalAttrParserCode = R"(
  {
    ::mlir::OptionalParseResult parseResult =
      parser.parseOptionalAttribute({0}Attr, {1}, "{0}", result.attributes);
    if (parseResult.hasValue() && failed(*parseResult))
      return ::mlir::failure();
  }
)";

/// The code snippet used to generate a parser call for a symbol name attribute.
///
/// {0}: The name of the attribute.
const char *const symbolNameAttrParserCode = R"(
  if (parser.parseSymbolName({0}Attr, "{0}", result.attributes))
    return ::mlir::failure();
)";
/// The code snippet used to generate a parser call for an optional symbol name
/// attribute.
///
/// {0}: The name of the attribute.
const char *const optionalSymbolNameAttrParserCode = R"(
  // Parsing an optional symbol name doesn't fail, so no need to check the
  // result.
  (void)parser.parseOptionalSymbolName({0}Attr, "{0}", result.attributes);
)";

/// The code snippet used to generate a parser call for an enum attribute. The
/// enum case is accepted either as a bare keyword or as a string attribute.
///
/// {0}: The name of the attribute.
/// {1}: The c++ namespace for the enum symbolize functions.
/// {2}: The function to symbolize a string of the enum.
/// {3}: The constant builder call to create an attribute of the enum type.
/// {4}: The set of allowed enum keywords.
/// {5}: The error message on failure when the enum isn't present.
const char *const enumAttrParserCode = R"(
  {
    ::llvm::StringRef attrStr;
    ::mlir::NamedAttrList attrStorage;
    auto loc = parser.getCurrentLocation();
    if (parser.parseOptionalKeyword(&attrStr, {4})) {
      ::mlir::StringAttr attrVal;
      ::mlir::OptionalParseResult parseResult =
        parser.parseOptionalAttribute(attrVal,
                                      parser.getBuilder().getNoneType(),
                                      "{0}", attrStorage);
      if (parseResult.hasValue()) {{
        if (failed(*parseResult))
          return ::mlir::failure();
        attrStr = attrVal.getValue();
      } else {
        {5}
      }
    }
    if (!attrStr.empty()) {
      auto attrOptional = {1}::{2}(attrStr);
      if (!attrOptional)
        return parser.emitError(loc, "invalid ")
               << "{0} attribute specification: \"" << attrStr << '"';;

      {0}Attr = {3};
      result.addAttribute("{0}", {0}Attr);
    }
  }
)";
467
/// The code snippet used to generate a parser call for a variadic operand.
///
/// {0}: The name of the operand.
const char *const variadicOperandParserCode = R"(
  {0}OperandsLoc = parser.getCurrentLocation();
  if (parser.parseOperandList({0}Operands))
    return ::mlir::failure();
)";
/// The code snippet used to generate a parser call for an optional operand;
/// the operand is appended to `{0}Operands` only if one is present.
///
/// {0}: The name of the operand.
const char *const optionalOperandParserCode = R"(
  {
    {0}OperandsLoc = parser.getCurrentLocation();
    ::mlir::OpAsmParser::UnresolvedOperand operand;
    ::mlir::OptionalParseResult parseResult =
                                    parser.parseOptionalOperand(operand);
    if (parseResult.hasValue()) {
      if (failed(*parseResult))
        return ::mlir::failure();
      {0}Operands.push_back(operand);
    }
  }
)";
/// The code snippet used to generate a parser call for a single operand,
/// stored into `{0}RawOperands[0]`.
///
/// {0}: The name of the operand.
const char *const operandParserCode = R"(
  {0}OperandsLoc = parser.getCurrentLocation();
  if (parser.parseOperand({0}RawOperands[0]))
    return ::mlir::failure();
)";
/// The code snippet used to generate a parser call for a VariadicOfVariadic
/// operand. Each parenthesized group is appended to `{0}Operands` and its
/// size is recorded in `{0}OperandGroupSizes`.
///
/// {0}: The name of the operand.
/// {1}: The name of segment size attribute.
const char *const variadicOfVariadicOperandParserCode = R"(
  {
    {0}OperandsLoc = parser.getCurrentLocation();
    int32_t curSize = 0;
    do {
      if (parser.parseOptionalLParen())
        break;
      if (parser.parseOperandList({0}Operands) || parser.parseRParen())
        return ::mlir::failure();
      {0}OperandGroupSizes.push_back({0}Operands.size() - curSize);
      curSize = {0}Operands.size();
    } while (succeeded(parser.parseOptionalComma()));
  }
)";

/// The code snippet used to generate a parser call for a type list of a
/// VariadicOfVariadic argument (comma-separated parenthesized groups).
///
/// {0}: The name for the type list.
const char *const variadicOfVariadicTypeParserCode = R"(
  do {
    if (parser.parseOptionalLParen())
      break;
    if (parser.parseOptionalRParen() &&
        (parser.parseTypeList({0}Types) || parser.parseRParen()))
      return ::mlir::failure();
  } while (succeeded(parser.parseOptionalComma()));
)";
/// The code snippet used to generate a parser call for a variadic type list.
///
/// {0}: The name for the type list.
const char *const variadicTypeParserCode = R"(
  if (parser.parseTypeList({0}Types))
    return ::mlir::failure();
)";
/// The code snippet used to generate a parser call for an optional type; the
/// type is appended to `{0}Types` only if one is present.
///
/// {0}: The name for the type list.
const char *const optionalTypeParserCode = R"(
  {
    ::mlir::Type optionalType;
    ::mlir::OptionalParseResult parseResult =
                                    parser.parseOptionalType(optionalType);
    if (parseResult.hasValue()) {
      if (failed(*parseResult))
        return ::mlir::failure();
      {0}Types.push_back(optionalType);
    }
  }
)";
/// The code snippet used to generate a parser call for a single type using the
/// type's custom parser with a generic fallback.
///
/// {0}: The C++ type class used to declare the parsed type.
/// {1}: The name for the type list (stored into `{1}RawTypes[0]`).
const char *const typeParserCode = R"(
  {
    {0} type;
    if (parser.parseCustomTypeWithFallback(type))
      return ::mlir::failure();
    {1}RawTypes[0] = type;
  }
)";
/// The code snippet used to generate a parser call for a single, fully
/// qualified type. Only placeholder {1} (the name for the type list) is used.
const char *const qualifiedTypeParserCode = R"(
  if (parser.parseType({1}RawTypes[0]))
    return ::mlir::failure();
)";

/// The code snippet used to generate a parser call for a functional type.
///
/// {0}: The name for the input type list.
/// {1}: The name for the result type list.
const char *const functionalTypeParserCode = R"(
  ::mlir::FunctionType {0}__{1}_functionType;
  if (parser.parseType({0}__{1}_functionType))
    return ::mlir::failure();
  {0}Types = {0}__{1}_functionType.getInputs();
  {1}Types = {0}__{1}_functionType.getResults();
)";

/// The code snippet used to generate a parser call to infer return types.
///
/// {0}: The operation class name
const char *const inferReturnTypesParserCode = R"(
  ::llvm::SmallVector<::mlir::Type> inferredReturnTypes;
  if (::mlir::failed({0}::inferReturnTypes(parser.getContext(),
      result.location, result.operands,
      result.attributes.getDictionary(parser.getContext()),
      result.regions, inferredReturnTypes)))
    return ::mlir::failure();
  result.addTypes(inferredReturnTypes);
)";
579
// NOTE: the region/successor/oilist snippets below are `const char *const`,
// like the other parser snippets in this file. A plain `const char *` would be
// a mutable global with external linkage rather than an immutable,
// file-local constant.

/// The code snippet used to generate a parser call for a region list.
///
/// {0}: The name for the region list.
const char *const regionListParserCode = R"(
  {
    std::unique_ptr<::mlir::Region> region;
    auto firstRegionResult = parser.parseOptionalRegion(region);
    if (firstRegionResult.hasValue()) {
      if (failed(*firstRegionResult))
        return ::mlir::failure();
      {0}Regions.emplace_back(std::move(region));

      // Parse any trailing regions.
      while (succeeded(parser.parseOptionalComma())) {
        region = std::make_unique<::mlir::Region>();
        if (parser.parseRegion(*region))
          return ::mlir::failure();
        {0}Regions.emplace_back(std::move(region));
      }
    }
  }
)";

/// The code snippet used to ensure a list of regions have terminators.
///
/// {0}: The name of the region list.
const char *const regionListEnsureTerminatorParserCode = R"(
  for (auto &region : {0}Regions)
    ensureTerminator(*region, parser.getBuilder(), result.location);
)";

/// The code snippet used to ensure a list of regions have a block.
///
/// {0}: The name of the region list.
const char *const regionListEnsureSingleBlockParserCode = R"(
  for (auto &region : {0}Regions)
    if (region->empty()) region->emplaceBlock();
)";

/// The code snippet used to generate a parser call for an optional region.
///
/// {0}: The name of the region.
const char *const optionalRegionParserCode = R"(
  {
     auto parseResult = parser.parseOptionalRegion(*{0}Region);
     if (parseResult.hasValue() && failed(*parseResult))
       return ::mlir::failure();
  }
)";

/// The code snippet used to generate a parser call for a region.
///
/// {0}: The name of the region.
const char *const regionParserCode = R"(
  if (parser.parseRegion(*{0}Region))
    return ::mlir::failure();
)";

/// The code snippet used to ensure a region has a terminator.
///
/// {0}: The name of the region.
const char *const regionEnsureTerminatorParserCode = R"(
  ensureTerminator(*{0}Region, parser.getBuilder(), result.location);
)";

/// The code snippet used to ensure a region has a block.
///
/// {0}: The name of the region.
const char *const regionEnsureSingleBlockParserCode = R"(
  if ({0}Region->empty()) {0}Region->emplaceBlock();
)";

/// The code snippet used to generate a parser call for a successor list.
///
/// {0}: The name for the successor list.
const char *const successorListParserCode = R"(
  {
    ::mlir::Block *succ;
    auto firstSucc = parser.parseOptionalSuccessor(succ);
    if (firstSucc.hasValue()) {
      if (failed(*firstSucc))
        return ::mlir::failure();
      {0}Successors.emplace_back(succ);

      // Parse any trailing successors.
      while (succeeded(parser.parseOptionalComma())) {
        if (parser.parseSuccessor(succ))
          return ::mlir::failure();
        {0}Successors.emplace_back(succ);
      }
    }
  }
)";

/// The code snippet used to generate a parser call for a successor.
///
/// {0}: The name of the successor.
const char *const successorParserCode = R"(
  if (parser.parseSuccessor({0}Successor))
    return ::mlir::failure();
)";

/// The code snippet used to generate a parser for OIList. It rejects a clause
/// that was already seen and otherwise marks it as seen.
///
/// {0}: literal keyword corresponding to a case for oilist
const char *const oilistParserCode = R"(
  if ({0}Clause) {
    return parser.emitError(parser.getNameLoc())
          << "`{0}` clause can appear at most once in the expansion of the "
             "oilist directive";
  }
  {0}Clause = true;
)";
693
694 namespace {
/// The type of length for a given parse argument, i.e. how many elements the
/// argument may expand to when parsed.
enum class ArgumentLengthKind {
  /// The argument is a variadic of a variadic, and may contain 0->N range
  /// elements.
  VariadicOfVariadic,
  /// The argument is variadic, and may contain 0->N elements.
  Variadic,
  /// The argument is optional, and may contain 0 or 1 elements.
  Optional,
  /// The argument is a single element, i.e. always represents 1 element.
  Single
};
707 } // namespace
708
709 /// Get the length kind for the given constraint.
710 static ArgumentLengthKind
getArgumentLengthKind(const NamedTypeConstraint * var)711 getArgumentLengthKind(const NamedTypeConstraint *var) {
712 if (var->isOptional())
713 return ArgumentLengthKind::Optional;
714 if (var->isVariadicOfVariadic())
715 return ArgumentLengthKind::VariadicOfVariadic;
716 if (var->isVariadic())
717 return ArgumentLengthKind::Variadic;
718 return ArgumentLengthKind::Single;
719 }
720
721 /// Get the name used for the type list for the given type directive operand.
722 /// 'lengthKind' to the corresponding kind for the given argument.
723 static StringRef getTypeListName(FormatElement *arg,
724 ArgumentLengthKind &lengthKind) {
725 if (auto *operand = dyn_cast<OperandVariable>(arg)) {
726 lengthKind = getArgumentLengthKind(operand->getVar());
727 return operand->getVar()->name;
728 }
729 if (auto *result = dyn_cast<ResultVariable>(arg)) {
730 lengthKind = getArgumentLengthKind(result->getVar());
731 return result->getVar()->name;
732 }
733 lengthKind = ArgumentLengthKind::Variadic;
734 if (isa<OperandsDirective>(arg))
735 return "allOperand";
736 if (isa<ResultsDirective>(arg))
737 return "allResult";
738 llvm_unreachable("unknown 'type' directive argument");
739 }
740
741 /// Generate the parser for a literal value.
742 static void genLiteralParser(StringRef value, MethodBody &body) {
743 // Handle the case of a keyword/identifier.
744 if (value.front() == '_' || isalpha(value.front())) {
745 body << "Keyword(\"" << value << "\")";
746 return;
747 }
748 body << (StringRef)StringSwitch<StringRef>(value)
749 .Case("->", "Arrow()")
750 .Case(":", "Colon()")
751 .Case(",", "Comma()")
752 .Case("=", "Equal()")
753 .Case("<", "Less()")
754 .Case(">", "Greater()")
755 .Case("{", "LBrace()")
756 .Case("}", "RBrace()")
757 .Case("(", "LParen()")
758 .Case(")", "RParen()")
759 .Case("[", "LSquare()")
760 .Case("]", "RSquare()")
761 .Case("?", "Question()")
762 .Case("+", "Plus()")
763 .Case("*", "Star()");
764 }
765
766 /// Generate the storage code required for parsing the given element.
genElementParserStorage(FormatElement * element,const Operator & op,MethodBody & body)767 static void genElementParserStorage(FormatElement *element, const Operator &op,
768 MethodBody &body) {
769 if (auto *optional = dyn_cast<OptionalElement>(element)) {
770 ArrayRef<FormatElement *> elements = optional->getThenElements();
771
772 // If the anchor is a unit attribute, it won't be parsed directly so elide
773 // it.
774 auto *anchor = dyn_cast<AttributeVariable>(optional->getAnchor());
775 FormatElement *elidedAnchorElement = nullptr;
776 if (anchor && anchor != elements.front() && anchor->isUnitAttr())
777 elidedAnchorElement = anchor;
778 for (FormatElement *childElement : elements)
779 if (childElement != elidedAnchorElement)
780 genElementParserStorage(childElement, op, body);
781 for (FormatElement *childElement : optional->getElseElements())
782 genElementParserStorage(childElement, op, body);
783
784 } else if (auto *oilist = dyn_cast<OIListElement>(element)) {
785 for (ArrayRef<FormatElement *> pelement : oilist->getParsingElements()) {
786 if (!oilist->getUnitAttrParsingElement(pelement))
787 for (FormatElement *element : pelement)
788 genElementParserStorage(element, op, body);
789 }
790
791 } else if (auto *custom = dyn_cast<CustomDirective>(element)) {
792 for (FormatElement *paramElement : custom->getArguments())
793 genElementParserStorage(paramElement, op, body);
794
795 } else if (isa<OperandsDirective>(element)) {
796 body << " ::llvm::SmallVector<::mlir::OpAsmParser::UnresolvedOperand, 4> "
797 "allOperands;\n";
798
799 } else if (isa<RegionsDirective>(element)) {
800 body << " ::llvm::SmallVector<std::unique_ptr<::mlir::Region>, 2> "
801 "fullRegions;\n";
802
803 } else if (isa<SuccessorsDirective>(element)) {
804 body << " ::llvm::SmallVector<::mlir::Block *, 2> fullSuccessors;\n";
805
806 } else if (auto *attr = dyn_cast<AttributeVariable>(element)) {
807 const NamedAttribute *var = attr->getVar();
808 body << llvm::formatv(" {0} {1}Attr;\n", var->attr.getStorageType(),
809 var->name);
810
811 } else if (auto *operand = dyn_cast<OperandVariable>(element)) {
812 StringRef name = operand->getVar()->name;
813 if (operand->getVar()->isVariableLength()) {
814 body
815 << " ::llvm::SmallVector<::mlir::OpAsmParser::UnresolvedOperand, 4> "
816 << name << "Operands;\n";
817 if (operand->getVar()->isVariadicOfVariadic()) {
818 body << " llvm::SmallVector<int32_t> " << name
819 << "OperandGroupSizes;\n";
820 }
821 } else {
822 body << " ::mlir::OpAsmParser::UnresolvedOperand " << name
823 << "RawOperands[1];\n"
824 << " ::llvm::ArrayRef<::mlir::OpAsmParser::UnresolvedOperand> "
825 << name << "Operands(" << name << "RawOperands);";
826 }
827 body << llvm::formatv(" ::llvm::SMLoc {0}OperandsLoc;\n"
828 " (void){0}OperandsLoc;\n",
829 name);
830
831 } else if (auto *region = dyn_cast<RegionVariable>(element)) {
832 StringRef name = region->getVar()->name;
833 if (region->getVar()->isVariadic()) {
834 body << llvm::formatv(
835 " ::llvm::SmallVector<std::unique_ptr<::mlir::Region>, 2> "
836 "{0}Regions;\n",
837 name);
838 } else {
839 body << llvm::formatv(" std::unique_ptr<::mlir::Region> {0}Region = "
840 "std::make_unique<::mlir::Region>();\n",
841 name);
842 }
843
844 } else if (auto *successor = dyn_cast<SuccessorVariable>(element)) {
845 StringRef name = successor->getVar()->name;
846 if (successor->getVar()->isVariadic()) {
847 body << llvm::formatv(" ::llvm::SmallVector<::mlir::Block *, 2> "
848 "{0}Successors;\n",
849 name);
850 } else {
851 body << llvm::formatv(" ::mlir::Block *{0}Successor = nullptr;\n", name);
852 }
853
854 } else if (auto *dir = dyn_cast<TypeDirective>(element)) {
855 ArgumentLengthKind lengthKind;
856 StringRef name = getTypeListName(dir->getArg(), lengthKind);
857 if (lengthKind != ArgumentLengthKind::Single)
858 body << " ::llvm::SmallVector<::mlir::Type, 1> " << name << "Types;\n";
859 else
860 body << llvm::formatv(" ::mlir::Type {0}RawTypes[1];\n", name)
861 << llvm::formatv(
862 " ::llvm::ArrayRef<::mlir::Type> {0}Types({0}RawTypes);\n",
863 name);
864 } else if (auto *dir = dyn_cast<FunctionalTypeDirective>(element)) {
865 ArgumentLengthKind ignored;
866 body << " ::llvm::ArrayRef<::mlir::Type> "
867 << getTypeListName(dir->getInputs(), ignored) << "Types;\n";
868 body << " ::llvm::ArrayRef<::mlir::Type> "
869 << getTypeListName(dir->getResults(), ignored) << "Types;\n";
870 }
871 }
872
873 /// Generate the parser for a parameter to a custom directive.
genCustomParameterParser(FormatElement * param,MethodBody & body)874 static void genCustomParameterParser(FormatElement *param, MethodBody &body) {
875 if (auto *attr = dyn_cast<AttributeVariable>(param)) {
876 body << attr->getVar()->name << "Attr";
877 } else if (isa<AttrDictDirective>(param)) {
878 body << "result.attributes";
879 } else if (auto *operand = dyn_cast<OperandVariable>(param)) {
880 StringRef name = operand->getVar()->name;
881 ArgumentLengthKind lengthKind = getArgumentLengthKind(operand->getVar());
882 if (lengthKind == ArgumentLengthKind::VariadicOfVariadic)
883 body << llvm::formatv("{0}OperandGroups", name);
884 else if (lengthKind == ArgumentLengthKind::Variadic)
885 body << llvm::formatv("{0}Operands", name);
886 else if (lengthKind == ArgumentLengthKind::Optional)
887 body << llvm::formatv("{0}Operand", name);
888 else
889 body << formatv("{0}RawOperands[0]", name);
890
891 } else if (auto *region = dyn_cast<RegionVariable>(param)) {
892 StringRef name = region->getVar()->name;
893 if (region->getVar()->isVariadic())
894 body << llvm::formatv("{0}Regions", name);
895 else
896 body << llvm::formatv("*{0}Region", name);
897
898 } else if (auto *successor = dyn_cast<SuccessorVariable>(param)) {
899 StringRef name = successor->getVar()->name;
900 if (successor->getVar()->isVariadic())
901 body << llvm::formatv("{0}Successors", name);
902 else
903 body << llvm::formatv("{0}Successor", name);
904
905 } else if (auto *dir = dyn_cast<RefDirective>(param)) {
906 genCustomParameterParser(dir->getArg(), body);
907
908 } else if (auto *dir = dyn_cast<TypeDirective>(param)) {
909 ArgumentLengthKind lengthKind;
910 StringRef listName = getTypeListName(dir->getArg(), lengthKind);
911 if (lengthKind == ArgumentLengthKind::VariadicOfVariadic)
912 body << llvm::formatv("{0}TypeGroups", listName);
913 else if (lengthKind == ArgumentLengthKind::Variadic)
914 body << llvm::formatv("{0}Types", listName);
915 else if (lengthKind == ArgumentLengthKind::Optional)
916 body << llvm::formatv("{0}Type", listName);
917 else
918 body << formatv("{0}RawTypes[0]", listName);
919 } else {
920 llvm_unreachable("unknown custom directive parameter");
921 }
922 }
923
924 /// Generate the parser for a custom directive.
genCustomDirectiveParser(CustomDirective * dir,MethodBody & body)925 static void genCustomDirectiveParser(CustomDirective *dir, MethodBody &body) {
926 body << " {\n";
927
928 // Preprocess the directive variables.
929 // * Add a local variable for optional operands and types. This provides a
930 // better API to the user defined parser methods.
931 // * Set the location of operand variables.
932 for (FormatElement *param : dir->getArguments()) {
933 if (auto *operand = dyn_cast<OperandVariable>(param)) {
934 auto *var = operand->getVar();
935 body << " " << var->name
936 << "OperandsLoc = parser.getCurrentLocation();\n";
937 if (var->isOptional()) {
938 body << llvm::formatv(
939 " ::llvm::Optional<::mlir::OpAsmParser::UnresolvedOperand> "
940 "{0}Operand;\n",
941 var->name);
942 } else if (var->isVariadicOfVariadic()) {
943 body << llvm::formatv(" "
944 "::llvm::SmallVector<::llvm::SmallVector<::mlir::"
945 "OpAsmParser::UnresolvedOperand>> "
946 "{0}OperandGroups;\n",
947 var->name);
948 }
949 } else if (auto *dir = dyn_cast<TypeDirective>(param)) {
950 ArgumentLengthKind lengthKind;
951 StringRef listName = getTypeListName(dir->getArg(), lengthKind);
952 if (lengthKind == ArgumentLengthKind::Optional) {
953 body << llvm::formatv(" ::mlir::Type {0}Type;\n", listName);
954 } else if (lengthKind == ArgumentLengthKind::VariadicOfVariadic) {
955 body << llvm::formatv(
956 " ::llvm::SmallVector<llvm::SmallVector<::mlir::Type>> "
957 "{0}TypeGroups;\n",
958 listName);
959 }
960 } else if (auto *dir = dyn_cast<RefDirective>(param)) {
961 FormatElement *input = dir->getArg();
962 if (auto *operand = dyn_cast<OperandVariable>(input)) {
963 if (!operand->getVar()->isOptional())
964 continue;
965 body << llvm::formatv(
966 " {0} {1}Operand = {1}Operands.empty() ? {0}() : "
967 "{1}Operands[0];\n",
968 "::llvm::Optional<::mlir::OpAsmParser::UnresolvedOperand>",
969 operand->getVar()->name);
970
971 } else if (auto *type = dyn_cast<TypeDirective>(input)) {
972 ArgumentLengthKind lengthKind;
973 StringRef listName = getTypeListName(type->getArg(), lengthKind);
974 if (lengthKind == ArgumentLengthKind::Optional) {
975 body << llvm::formatv(" ::mlir::Type {0}Type = {0}Types.empty() ? "
976 "::mlir::Type() : {0}Types[0];\n",
977 listName);
978 }
979 }
980 }
981 }
982
983 body << " if (parse" << dir->getName() << "(parser";
984 for (FormatElement *param : dir->getArguments()) {
985 body << ", ";
986 genCustomParameterParser(param, body);
987 }
988
989 body << "))\n"
990 << " return ::mlir::failure();\n";
991
992 // After parsing, add handling for any of the optional constructs.
993 for (FormatElement *param : dir->getArguments()) {
994 if (auto *attr = dyn_cast<AttributeVariable>(param)) {
995 const NamedAttribute *var = attr->getVar();
996 if (var->attr.isOptional())
997 body << llvm::formatv(" if ({0}Attr)\n ", var->name);
998
999 body << llvm::formatv(" result.addAttribute(\"{0}\", {0}Attr);\n",
1000 var->name);
1001 } else if (auto *operand = dyn_cast<OperandVariable>(param)) {
1002 const NamedTypeConstraint *var = operand->getVar();
1003 if (var->isOptional()) {
1004 body << llvm::formatv(" if ({0}Operand.has_value())\n"
1005 " {0}Operands.push_back(*{0}Operand);\n",
1006 var->name);
1007 } else if (var->isVariadicOfVariadic()) {
1008 body << llvm::formatv(
1009 " for (const auto &subRange : {0}OperandGroups) {{\n"
1010 " {0}Operands.append(subRange.begin(), subRange.end());\n"
1011 " {0}OperandGroupSizes.push_back(subRange.size());\n"
1012 " }\n",
1013 var->name, var->constraint.getVariadicOfVariadicSegmentSizeAttr());
1014 }
1015 } else if (auto *dir = dyn_cast<TypeDirective>(param)) {
1016 ArgumentLengthKind lengthKind;
1017 StringRef listName = getTypeListName(dir->getArg(), lengthKind);
1018 if (lengthKind == ArgumentLengthKind::Optional) {
1019 body << llvm::formatv(" if ({0}Type)\n"
1020 " {0}Types.push_back({0}Type);\n",
1021 listName);
1022 } else if (lengthKind == ArgumentLengthKind::VariadicOfVariadic) {
1023 body << llvm::formatv(
1024 " for (const auto &subRange : {0}TypeGroups)\n"
1025 " {0}Types.append(subRange.begin(), subRange.end());\n",
1026 listName);
1027 }
1028 }
1029 }
1030
1031 body << " }\n";
1032 }
1033
1034 /// Generate the parser for a enum attribute.
genEnumAttrParser(const NamedAttribute * var,MethodBody & body,FmtContext & attrTypeCtx)1035 static void genEnumAttrParser(const NamedAttribute *var, MethodBody &body,
1036 FmtContext &attrTypeCtx) {
1037 Attribute baseAttr = var->attr.getBaseAttr();
1038 const EnumAttr &enumAttr = cast<EnumAttr>(baseAttr);
1039 std::vector<EnumAttrCase> cases = enumAttr.getAllCases();
1040
1041 // Generate the code for building an attribute for this enum.
1042 std::string attrBuilderStr;
1043 {
1044 llvm::raw_string_ostream os(attrBuilderStr);
1045 os << tgfmt(enumAttr.getConstBuilderTemplate(), &attrTypeCtx,
1046 "attrOptional.value()");
1047 }
1048
1049 // Build a string containing the cases that can be formatted as a keyword.
1050 std::string validCaseKeywordsStr = "{";
1051 llvm::raw_string_ostream validCaseKeywordsOS(validCaseKeywordsStr);
1052 for (const EnumAttrCase &attrCase : cases)
1053 if (canFormatStringAsKeyword(attrCase.getStr()))
1054 validCaseKeywordsOS << '"' << attrCase.getStr() << "\",";
1055 validCaseKeywordsOS.str().back() = '}';
1056
1057 // If the attribute is not optional, build an error message for the missing
1058 // attribute.
1059 std::string errorMessage;
1060 if (!var->attr.isOptional()) {
1061 llvm::raw_string_ostream errorMessageOS(errorMessage);
1062 errorMessageOS
1063 << "return parser.emitError(loc, \"expected string or "
1064 "keyword containing one of the following enum values for attribute '"
1065 << var->name << "' [";
1066 llvm::interleaveComma(cases, errorMessageOS, [&](const auto &attrCase) {
1067 errorMessageOS << attrCase.getStr();
1068 });
1069 errorMessageOS << "]\");";
1070 }
1071
1072 body << formatv(enumAttrParserCode, var->name, enumAttr.getCppNamespace(),
1073 enumAttr.getStringToSymbolFnName(), attrBuilderStr,
1074 validCaseKeywordsStr, errorMessage);
1075 }
1076
genParser(Operator & op,OpClass & opClass)1077 void OperationFormat::genParser(Operator &op, OpClass &opClass) {
1078 SmallVector<MethodParameter> paramList;
1079 paramList.emplace_back("::mlir::OpAsmParser &", "parser");
1080 paramList.emplace_back("::mlir::OperationState &", "result");
1081
1082 auto *method = opClass.addStaticMethod("::mlir::ParseResult", "parse",
1083 std::move(paramList));
1084 auto &body = method->body();
1085
1086 // Generate variables to store the operands and type within the format. This
1087 // allows for referencing these variables in the presence of optional
1088 // groupings.
1089 for (FormatElement *element : elements)
1090 genElementParserStorage(element, op, body);
1091
1092 // A format context used when parsing attributes with buildable types.
1093 FmtContext attrTypeCtx;
1094 attrTypeCtx.withBuilder("parser.getBuilder()");
1095
1096 // Generate parsers for each of the elements.
1097 for (FormatElement *element : elements)
1098 genElementParser(element, body, attrTypeCtx);
1099
1100 // Generate the code to resolve the operand/result types and successors now
1101 // that they have been parsed.
1102 genParserRegionResolution(op, body);
1103 genParserSuccessorResolution(op, body);
1104 genParserVariadicSegmentResolution(op, body);
1105 genParserTypeResolution(op, body);
1106
1107 body << " return ::mlir::success();\n";
1108 }
1109
/// Generate the parsing code for a single format element, dispatching on the
/// element's kind.
///
/// `genCtx` is GenContext::Optional when the element follows the guard of an
/// optional group. In that context an optional attribute is parsed with the
/// required-attribute parser instead of `optionalAttrParserCode`, because the
/// group's guard already established that the group is present.
void OperationFormat::genElementParser(FormatElement *element, MethodBody &body,
                                       FmtContext &attrTypeCtx,
                                       GenContext genCtx) {
  /// Optional Group.
  if (auto *optional = dyn_cast<OptionalElement>(element)) {
    // Parsing starts at the group's parse-start element; earlier elements are
    // dropped here.
    ArrayRef<FormatElement *> elements =
        optional->getThenElements().drop_front(optional->getParseStart());

    // Generate a special optional parser for the first element to gate the
    // parsing of the rest of the elements.
    // NOTE(review): only attribute, literal, operand and region first elements
    // open an `if (...) {` here, yet the closing `}` below is always emitted;
    // presumably the format verifier restricts the first element to these
    // kinds -- confirm.
    FormatElement *firstElement = elements.front();
    if (auto *attrVar = dyn_cast<AttributeVariable>(firstElement)) {
      genElementParser(attrVar, body, attrTypeCtx);
      body << " if (" << attrVar->getVar()->name << "Attr) {\n";
    } else if (auto *literal = dyn_cast<LiteralElement>(firstElement)) {
      body << " if (succeeded(parser.parseOptional";
      genLiteralParser(literal->getSpelling(), body);
      body << ")) {\n";
    } else if (auto *opVar = dyn_cast<OperandVariable>(firstElement)) {
      genElementParser(opVar, body, attrTypeCtx);
      body << " if (!" << opVar->getVar()->name << "Operands.empty()) {\n";
    } else if (auto *regionVar = dyn_cast<RegionVariable>(firstElement)) {
      const NamedRegion *region = regionVar->getVar();
      if (region->isVariadic()) {
        genElementParser(regionVar, body, attrTypeCtx);
        body << " if (!" << region->name << "Regions.empty()) {\n";
      } else {
        // A single region uses a dedicated optional-region parser; the group
        // is present iff the parsed region is non-empty.
        body << llvm::formatv(optionalRegionParserCode, region->name);
        body << " if (!" << region->name << "Region->empty()) {\n ";
        if (hasImplicitTermTrait)
          body << llvm::formatv(regionEnsureTerminatorParserCode, region->name);
        else if (hasSingleBlockTrait)
          body << llvm::formatv(regionEnsureSingleBlockParserCode,
                                region->name);
      }
    }

    // If the anchor is a unit attribute, we don't need to print it. When
    // parsing, we will add this attribute if this group is present.
    FormatElement *elidedAnchorElement = nullptr;
    auto *anchorAttr = dyn_cast<AttributeVariable>(optional->getAnchor());
    if (anchorAttr && anchorAttr != firstElement && anchorAttr->isUnitAttr()) {
      elidedAnchorElement = anchorAttr;

      // Add the anchor unit attribute to the operation state.
      body << " result.addAttribute(\"" << anchorAttr->getVar()->name
           << "\", parser.getBuilder().getUnitAttr());\n";
    }

    // Generate the rest of the elements inside an optional group. Elements in
    // an optional group after the guard are parsed as required.
    for (FormatElement *childElement : llvm::drop_begin(elements, 1))
      if (childElement != elidedAnchorElement)
        genElementParser(childElement, body, attrTypeCtx, GenContext::Optional);
    body << " }";

    // Generate the else elements.
    auto elseElements = optional->getElseElements();
    if (!elseElements.empty()) {
      body << " else {\n";
      for (FormatElement *childElement : elseElements)
        genElementParser(childElement, body, attrTypeCtx);
      body << " }";
    }
    body << "\n";

    /// OIList Directive
  } else if (OIListElement *oilist = dyn_cast<OIListElement>(element)) {
    // Declare one `<keyword>Clause` flag per clause keyword. The flags are
    // referenced by `oilistParserCode` (not visible here; presumably it
    // rejects repeated clauses).
    for (LiteralElement *le : oilist->getLiteralElements())
      body << " bool " << le->getSpelling() << "Clause = false;\n";

    // Generate the parsing loop
    // The generated loop tries each clause keyword in turn (an if/else-if
    // chain) and breaks once no keyword matches, so clauses may appear in any
    // order.
    body << " while(true) {\n";
    for (auto clause : oilist->getClauses()) {
      LiteralElement *lelement = std::get<0>(clause);
      ArrayRef<FormatElement *> pelement = std::get<1>(clause);
      body << "if (succeeded(parser.parseOptional";
      genLiteralParser(lelement->getSpelling(), body);
      body << ")) {\n";
      StringRef lelementName = lelement->getSpelling();
      body << formatv(oilistParserCode, lelementName);
      // A clause that only sets a unit attribute has nothing more to parse.
      if (AttributeVariable *unitAttrElem =
              oilist->getUnitAttrParsingElement(pelement)) {
        body << " result.addAttribute(\"" << unitAttrElem->getVar()->name
             << "\", UnitAttr::get(parser.getContext()));\n";
      } else {
        for (FormatElement *el : pelement)
          genElementParser(el, body, attrTypeCtx);
      }
      body << " } else ";
    }
    body << " {\n";
    body << " break;\n";
    body << " }\n";
    body << "}\n";

    /// Literals.
  } else if (LiteralElement *literal = dyn_cast<LiteralElement>(element)) {
    body << " if (parser.parse";
    genLiteralParser(literal->getSpelling(), body);
    body << ")\n return ::mlir::failure();\n";

    /// Whitespaces.
  } else if (isa<WhitespaceElement>(element)) {
    // Nothing to parse.

    /// Arguments.
  } else if (auto *attr = dyn_cast<AttributeVariable>(element)) {
    const NamedAttribute *var = attr->getVar();

    // Check to see if we can parse this as an enum attribute.
    if (canFormatEnumAttr(var))
      return genEnumAttrParser(var, body, attrTypeCtx);

    // Check to see if we should parse this as a symbol name attribute.
    if (shouldFormatSymbolNameAttr(var)) {
      body << formatv(var->attr.isOptional() ? optionalSymbolNameAttrParserCode
                                             : symbolNameAttrParserCode,
                      var->name);
      return;
    }

    // If this attribute has a buildable type, use that when parsing the
    // attribute.
    std::string attrTypeStr;
    if (Optional<StringRef> typeBuilder = attr->getTypeBuilder()) {
      llvm::raw_string_ostream os(attrTypeStr);
      os << tgfmt(*typeBuilder, &attrTypeCtx);
    } else {
      attrTypeStr = "::mlir::Type{}";
    }
    // Optional attributes are only parsed optionally outside of an optional
    // group; inside one (GenContext::Optional) they are required.
    if (genCtx == GenContext::Normal && var->attr.isOptional()) {
      body << formatv(optionalAttrParserCode, var->name, attrTypeStr);
    } else {
      // Qualified attributes, and attributes stored as a plain
      // ::mlir::Attribute, use the generic attribute parser.
      if (attr->shouldBeQualified() ||
          var->attr.getStorageType() == "::mlir::Attribute")
        body << formatv(genericAttrParserCode, var->name, attrTypeStr);
      else
        body << formatv(attrParserCode, var->name, attrTypeStr);
    }

  } else if (auto *operand = dyn_cast<OperandVariable>(element)) {
    ArgumentLengthKind lengthKind = getArgumentLengthKind(operand->getVar());
    StringRef name = operand->getVar()->name;
    if (lengthKind == ArgumentLengthKind::VariadicOfVariadic)
      body << llvm::formatv(
          variadicOfVariadicOperandParserCode, name,
          operand->getVar()->constraint.getVariadicOfVariadicSegmentSizeAttr());
    else if (lengthKind == ArgumentLengthKind::Variadic)
      body << llvm::formatv(variadicOperandParserCode, name);
    else if (lengthKind == ArgumentLengthKind::Optional)
      body << llvm::formatv(optionalOperandParserCode, name);
    else
      body << formatv(operandParserCode, name);

  } else if (auto *region = dyn_cast<RegionVariable>(element)) {
    bool isVariadic = region->getVar()->isVariadic();
    body << llvm::formatv(isVariadic ? regionListParserCode : regionParserCode,
                          region->getVar()->name);
    // Ops with an implicit terminator or the single-block trait get an extra
    // post-parse fixup/check on the parsed region(s).
    if (hasImplicitTermTrait)
      body << llvm::formatv(isVariadic ? regionListEnsureTerminatorParserCode
                                       : regionEnsureTerminatorParserCode,
                            region->getVar()->name);
    else if (hasSingleBlockTrait)
      body << llvm::formatv(isVariadic ? regionListEnsureSingleBlockParserCode
                                       : regionEnsureSingleBlockParserCode,
                            region->getVar()->name);

  } else if (auto *successor = dyn_cast<SuccessorVariable>(element)) {
    bool isVariadic = successor->getVar()->isVariadic();
    body << formatv(isVariadic ? successorListParserCode : successorParserCode,
                    successor->getVar()->name);

    /// Directives.
  } else if (auto *attrDict = dyn_cast<AttrDictDirective>(element)) {
    body << " if (parser.parseOptionalAttrDict"
         << (attrDict->isWithKeyword() ? "WithKeyword" : "")
         << "(result.attributes))\n"
         << " return ::mlir::failure();\n";
  } else if (auto *customDir = dyn_cast<CustomDirective>(element)) {
    genCustomDirectiveParser(customDir, body);

  } else if (isa<OperandsDirective>(element)) {
    // `allOperandLoc` is used later when resolving `allOperands`.
    body << " ::llvm::SMLoc allOperandLoc = parser.getCurrentLocation();\n"
         << " if (parser.parseOperandList(allOperands))\n"
         << " return ::mlir::failure();\n";

  } else if (isa<RegionsDirective>(element)) {
    // The `regions` directive parses into the `fullRegions` storage.
    body << llvm::formatv(regionListParserCode, "full");
    if (hasImplicitTermTrait)
      body << llvm::formatv(regionListEnsureTerminatorParserCode, "full");
    else if (hasSingleBlockTrait)
      body << llvm::formatv(regionListEnsureSingleBlockParserCode, "full");

  } else if (isa<SuccessorsDirective>(element)) {
    // The `successors` directive parses into the `fullSuccessors` storage.
    body << llvm::formatv(successorListParserCode, "full");

  } else if (auto *dir = dyn_cast<TypeDirective>(element)) {
    ArgumentLengthKind lengthKind;
    StringRef listName = getTypeListName(dir->getArg(), lengthKind);
    if (lengthKind == ArgumentLengthKind::VariadicOfVariadic) {
      body << llvm::formatv(variadicOfVariadicTypeParserCode, listName);
    } else if (lengthKind == ArgumentLengthKind::Variadic) {
      body << llvm::formatv(variadicTypeParserCode, listName);
    } else if (lengthKind == ArgumentLengthKind::Optional) {
      body << llvm::formatv(optionalTypeParserCode, listName);
    } else {
      // A single type is parsed as the operand/result's constraint C++ class
      // when known, otherwise as a plain ::mlir::Type.
      const char *parserCode =
          dir->shouldBeQualified() ? qualifiedTypeParserCode : typeParserCode;
      TypeSwitch<FormatElement *>(dir->getArg())
          .Case<OperandVariable, ResultVariable>([&](auto operand) {
            body << formatv(parserCode,
                            operand->getVar()->constraint.getCPPClassName(),
                            listName);
          })
          .Default([&](auto operand) {
            body << formatv(parserCode, "::mlir::Type", listName);
          });
    }
  } else if (auto *dir = dyn_cast<FunctionalTypeDirective>(element)) {
    ArgumentLengthKind ignored;
    body << formatv(functionalTypeParserCode,
                    getTypeListName(dir->getInputs(), ignored),
                    getTypeListName(dir->getResults(), ignored));
  } else {
    llvm_unreachable("unknown format element");
  }
}
1338
/// Generate the code that resolves parsed types: constraint checks for
/// variables used through type transformers, declarations of buildable types,
/// the result types (unless they are inferred), the operand type resolution,
/// and finally the return-type inference call when the op infers its results.
void OperationFormat::genParserTypeResolution(Operator &op, MethodBody &body) {
  // If any type resolution goes through a transformer applied to another
  // variable, emit a check that each parsed type of that source variable
  // satisfies its type constraint. Each source variable is checked once.
  SmallPtrSet<const NamedTypeConstraint *, 8> verifiedVariables;
  FmtContext verifierFCtx;
  for (TypeResolution &resolver :
       llvm::concat<TypeResolution>(resultTypes, operandTypes)) {
    Optional<StringRef> transformer = resolver.getVarTransformer();
    if (!transformer)
      continue;
    // Ensure that we don't verify the same variables twice.
    const NamedTypeConstraint *variable = resolver.getVariable();
    if (!variable || !verifiedVariables.insert(variable).second)
      continue;

    auto constraint = variable->constraint;
    body << " for (::mlir::Type type : " << variable->name << "Types) {\n"
         << " (void)type;\n"
         << " if (!("
         << tgfmt(constraint.getConditionTemplate(),
                  &verifierFCtx.withSelf("type"))
         << ")) {\n"
         << formatv(" return parser.emitError(parser.getNameLoc()) << "
                    "\"'{0}' must be {1}, but got \" << type;\n",
                    variable->name, constraint.getSummary())
         << " }\n"
         << " }\n";
  }

  // Initialize the set of buildable types.
  // Each entry becomes a local `odsBuildableType<N>` built with the parser's
  // builder.
  if (!buildableTypes.empty()) {
    FmtContext typeBuilderCtx;
    typeBuilderCtx.withBuilder("parser.getBuilder()");
    for (auto &it : buildableTypes)
      body << " ::mlir::Type odsBuildableType" << it.second << " = "
           << tgfmt(it.first, &typeBuilderCtx) << ";\n";
  }

  // Emit the code necessary for a type resolver.
  // Priority: a buildable type; else another variable's types (optionally via
  // a transformer); else an attribute's type (optionally via a transformer);
  // else the `<curVar>Types` list parsed for the variable itself.
  auto emitTypeResolver = [&](TypeResolution &resolver, StringRef curVar) {
    if (Optional<int> val = resolver.getBuilderIdx()) {
      body << "odsBuildableType" << *val;
    } else if (const NamedTypeConstraint *var = resolver.getVariable()) {
      if (Optional<StringRef> tform = resolver.getVarTransformer()) {
        FmtContext fmtContext;
        fmtContext.addSubst("_ctxt", "parser.getContext()");
        // Transformers see the whole list for variadic sources, otherwise the
        // single parsed type.
        if (var->isVariadic())
          fmtContext.withSelf(var->name + "Types");
        else
          fmtContext.withSelf(var->name + "Types[0]");
        body << tgfmt(*tform, &fmtContext);
      } else {
        body << var->name << "Types";
      }
    } else if (const NamedAttribute *attr = resolver.getAttribute()) {
      // NOTE(review): unlike the variable case above, this context does not
      // substitute `$_ctxt`; confirm whether attribute transformers may use it.
      if (Optional<StringRef> tform = resolver.getVarTransformer())
        body << tgfmt(*tform,
                      &FmtContext().withSelf(attr->name + "Attr.getType()"));
      else
        body << attr->name << "Attr.getType()";
    } else {
      body << curVar << "Types";
    }
  };

  // Resolve each of the result types.
  if (!infersResultTypes) {
    if (allResultTypes) {
      body << " result.addTypes(allResultTypes);\n";
    } else {
      for (unsigned i = 0, e = op.getNumResults(); i != e; ++i) {
        body << " result.addTypes(";
        emitTypeResolver(resultTypes[i], op.getResultName(i));
        body << ");\n";
      }
    }
  }

  // Emit the operand type resolutions.
  genParserOperandTypeResolution(op, body, emitTypeResolver);

  // Handle return type inference once all operands have been resolved
  if (infersResultTypes)
    body << formatv(inferReturnTypesParserCode, op.getCppClassName());
}
1424
genParserOperandTypeResolution(Operator & op,MethodBody & body,function_ref<void (TypeResolution &,StringRef)> emitTypeResolver)1425 void OperationFormat::genParserOperandTypeResolution(
1426 Operator &op, MethodBody &body,
1427 function_ref<void(TypeResolution &, StringRef)> emitTypeResolver) {
1428 // Early exit if there are no operands.
1429 if (op.getNumOperands() == 0)
1430 return;
1431
1432 // Handle the case where all operand types are grouped together with
1433 // "types(operands)".
1434 if (allOperandTypes) {
1435 // If `operands` was specified, use the full operand list directly.
1436 if (allOperands) {
1437 body << " if (parser.resolveOperands(allOperands, allOperandTypes, "
1438 "allOperandLoc, result.operands))\n"
1439 " return ::mlir::failure();\n";
1440 return;
1441 }
1442
1443 // Otherwise, use llvm::concat to merge the disjoint operand lists together.
1444 // llvm::concat does not allow the case of a single range, so guard it here.
1445 body << " if (parser.resolveOperands(";
1446 if (op.getNumOperands() > 1) {
1447 body << "::llvm::concat<const ::mlir::OpAsmParser::UnresolvedOperand>(";
1448 llvm::interleaveComma(op.getOperands(), body, [&](auto &operand) {
1449 body << operand.name << "Operands";
1450 });
1451 body << ")";
1452 } else {
1453 body << op.operand_begin()->name << "Operands";
1454 }
1455 body << ", allOperandTypes, parser.getNameLoc(), result.operands))\n"
1456 << " return ::mlir::failure();\n";
1457 return;
1458 }
1459
1460 // Handle the case where all operands are grouped together with "operands".
1461 if (allOperands) {
1462 body << " if (parser.resolveOperands(allOperands, ";
1463
1464 // Group all of the operand types together to perform the resolution all at
1465 // once. Use llvm::concat to perform the merge. llvm::concat does not allow
1466 // the case of a single range, so guard it here.
1467 if (op.getNumOperands() > 1) {
1468 body << "::llvm::concat<const ::mlir::Type>(";
1469 llvm::interleaveComma(
1470 llvm::seq<int>(0, op.getNumOperands()), body, [&](int i) {
1471 body << "::llvm::ArrayRef<::mlir::Type>(";
1472 emitTypeResolver(operandTypes[i], op.getOperand(i).name);
1473 body << ")";
1474 });
1475 body << ")";
1476 } else {
1477 emitTypeResolver(operandTypes.front(), op.getOperand(0).name);
1478 }
1479
1480 body << ", allOperandLoc, result.operands))\n"
1481 << " return ::mlir::failure();\n";
1482 return;
1483 }
1484
1485 // The final case is the one where each of the operands types are resolved
1486 // separately.
1487 for (unsigned i = 0, e = op.getNumOperands(); i != e; ++i) {
1488 NamedTypeConstraint &operand = op.getOperand(i);
1489 body << " if (parser.resolveOperands(" << operand.name << "Operands, ";
1490
1491 // Resolve the type of this operand.
1492 TypeResolution &operandType = operandTypes[i];
1493 emitTypeResolver(operandType, operand.name);
1494
1495 // If the type is resolved by a non-variadic variable, index into the
1496 // resolved type list. This allows for resolving the types of a variadic
1497 // operand list from a non-variadic variable.
1498 bool verifyOperandAndTypeSize = true;
1499 if (auto *resolverVar = operandType.getVariable()) {
1500 if (!resolverVar->isVariadic() && !operandType.getVarTransformer()) {
1501 body << "[0]";
1502 verifyOperandAndTypeSize = false;
1503 }
1504 } else {
1505 verifyOperandAndTypeSize = !operandType.getBuilderIdx();
1506 }
1507
1508 // Check to see if the sizes between the types and operands must match. If
1509 // they do, provide the operand location to select the proper resolution
1510 // overload.
1511 if (verifyOperandAndTypeSize)
1512 body << ", " << operand.name << "OperandsLoc";
1513 body << ", result.operands))\n return ::mlir::failure();\n";
1514 }
1515 }
1516
genParserRegionResolution(Operator & op,MethodBody & body)1517 void OperationFormat::genParserRegionResolution(Operator &op,
1518 MethodBody &body) {
1519 // Check for the case where all regions were parsed.
1520 bool hasAllRegions = llvm::any_of(
1521 elements, [](FormatElement *elt) { return isa<RegionsDirective>(elt); });
1522 if (hasAllRegions) {
1523 body << " result.addRegions(fullRegions);\n";
1524 return;
1525 }
1526
1527 // Otherwise, handle each region individually.
1528 for (const NamedRegion ®ion : op.getRegions()) {
1529 if (region.isVariadic())
1530 body << " result.addRegions(" << region.name << "Regions);\n";
1531 else
1532 body << " result.addRegion(std::move(" << region.name << "Region));\n";
1533 }
1534 }
1535
genParserSuccessorResolution(Operator & op,MethodBody & body)1536 void OperationFormat::genParserSuccessorResolution(Operator &op,
1537 MethodBody &body) {
1538 // Check for the case where all successors were parsed.
1539 bool hasAllSuccessors = llvm::any_of(elements, [](FormatElement *elt) {
1540 return isa<SuccessorsDirective>(elt);
1541 });
1542 if (hasAllSuccessors) {
1543 body << " result.addSuccessors(fullSuccessors);\n";
1544 return;
1545 }
1546
1547 // Otherwise, handle each successor individually.
1548 for (const NamedSuccessor &successor : op.getSuccessors()) {
1549 if (successor.isVariadic())
1550 body << " result.addSuccessors(" << successor.name << "Successors);\n";
1551 else
1552 body << " result.addSuccessors(" << successor.name << "Successor);\n";
1553 }
1554 }
1555
genParserVariadicSegmentResolution(Operator & op,MethodBody & body)1556 void OperationFormat::genParserVariadicSegmentResolution(Operator &op,
1557 MethodBody &body) {
1558 if (!allOperands) {
1559 if (op.getTrait("::mlir::OpTrait::AttrSizedOperandSegments")) {
1560 body << " result.addAttribute(\"operand_segment_sizes\", "
1561 << "parser.getBuilder().getI32VectorAttr({";
1562 auto interleaveFn = [&](const NamedTypeConstraint &operand) {
1563 // If the operand is variadic emit the parsed size.
1564 if (operand.isVariableLength())
1565 body << "static_cast<int32_t>(" << operand.name << "Operands.size())";
1566 else
1567 body << "1";
1568 };
1569 llvm::interleaveComma(op.getOperands(), body, interleaveFn);
1570 body << "}));\n";
1571 }
1572 for (const NamedTypeConstraint &operand : op.getOperands()) {
1573 if (!operand.isVariadicOfVariadic())
1574 continue;
1575 body << llvm::formatv(
1576 " result.addAttribute(\"{0}\", "
1577 "parser.getBuilder().getI32TensorAttr({1}OperandGroupSizes));\n",
1578 operand.constraint.getVariadicOfVariadicSegmentSizeAttr(),
1579 operand.name);
1580 }
1581 }
1582
1583 if (!allResultTypes &&
1584 op.getTrait("::mlir::OpTrait::AttrSizedResultSegments")) {
1585 body << " result.addAttribute(\"result_segment_sizes\", "
1586 << "parser.getBuilder().getI32VectorAttr({";
1587 auto interleaveFn = [&](const NamedTypeConstraint &result) {
1588 // If the result is variadic emit the parsed size.
1589 if (result.isVariableLength())
1590 body << "static_cast<int32_t>(" << result.name << "Types.size())";
1591 else
1592 body << "1";
1593 };
1594 llvm::interleaveComma(op.getResults(), body, interleaveFn);
1595 body << "}));\n";
1596 }
1597 }
1598
1599 //===----------------------------------------------------------------------===//
1600 // PrinterGen
1601
/// The code snippet used to generate a printer call for a region of an
/// operation that has the SingleBlockImplicitTerminator trait.
///
/// Entry block arguments are always printed. The terminator of the region's
/// first block is printed only if it has attributes, operands, or results.
/// Otherwise it is elided; presumably the parser re-creates it implicitly
/// (confirm against the region terminator parser code).
/// `{{` is the formatv escape for a literal `{`.
///
/// {0}: The region to print, substituted as a C++ expression (used as e.g.
///      `{0}.empty()` and `{0}.begin()`).
const char *regionSingleBlockImplicitTerminatorPrinterCode = R"(
{
bool printTerminator = true;
if (auto *term = {0}.empty() ? nullptr : {0}.begin()->getTerminator()) {{
printTerminator = !term->getAttrDictionary().empty() ||
term->getNumOperands() != 0 ||
term->getNumResults() != 0;
}
_odsPrinter.printRegion({0}, /*printEntryBlockArgs=*/true,
/*printBlockTerminators=*/printTerminator);
}
)";
1618
/// The code snippet used to open the printer for an enum attribute. It opens a
/// new scope, reads the attribute's case value and converts it to its string
/// form. The caller emits the code that prints `caseValueStr` and the closing
/// brace of the scope.
///
/// {0}: The getter of the enum attribute (prefixed with `*` when the attribute
///      is optional).
/// {1}: The name of the enum attribute's symbolToString function.
const char *enumAttrBeginPrinterCode = R"(
{
auto caseValue = {0}();
auto caseValueStr = {1}(caseValue);
)";
1629
1630 /// Generate the printer for the 'attr-dict' directive.
genAttrDictPrinter(OperationFormat & fmt,Operator & op,MethodBody & body,bool withKeyword)1631 static void genAttrDictPrinter(OperationFormat &fmt, Operator &op,
1632 MethodBody &body, bool withKeyword) {
1633 body << " _odsPrinter.printOptionalAttrDict"
1634 << (withKeyword ? "WithKeyword" : "")
1635 << "((*this)->getAttrs(), /*elidedAttrs=*/{";
1636 // Elide the variadic segment size attributes if necessary.
1637 if (!fmt.allOperands &&
1638 op.getTrait("::mlir::OpTrait::AttrSizedOperandSegments"))
1639 body << "\"operand_segment_sizes\", ";
1640 if (!fmt.allResultTypes &&
1641 op.getTrait("::mlir::OpTrait::AttrSizedResultSegments"))
1642 body << "\"result_segment_sizes\", ";
1643 if (!fmt.inferredAttributes.empty()) {
1644 for (const auto &attr : fmt.inferredAttributes)
1645 body << "\"" << attr.getKey() << "\", ";
1646 }
1647 llvm::interleaveComma(
1648 fmt.usedAttributes, body,
1649 [&](const NamedAttribute *attr) { body << "\"" << attr->name << "\""; });
1650 body << "});\n";
1651 }
1652
1653 /// Generate the printer for a literal value. `shouldEmitSpace` is true if a
1654 /// space should be emitted before this element. `lastWasPunctuation` is true if
1655 /// the previous element was a punctuation literal.
genLiteralPrinter(StringRef value,MethodBody & body,bool & shouldEmitSpace,bool & lastWasPunctuation)1656 static void genLiteralPrinter(StringRef value, MethodBody &body,
1657 bool &shouldEmitSpace, bool &lastWasPunctuation) {
1658 body << " _odsPrinter";
1659
1660 // Don't insert a space for certain punctuation.
1661 if (shouldEmitSpace && shouldEmitSpaceBefore(value, lastWasPunctuation))
1662 body << " << ' '";
1663 body << " << \"" << value << "\";\n";
1664
1665 // Insert a space after certain literals.
1666 shouldEmitSpace =
1667 value.size() != 1 || !StringRef("<({[").contains(value.front());
1668 lastWasPunctuation = !(value.front() == '_' || isalpha(value.front()));
1669 }
1670
1671 /// Generate the printer for a space. `shouldEmitSpace` and `lastWasPunctuation`
1672 /// are set to false.
genSpacePrinter(bool value,MethodBody & body,bool & shouldEmitSpace,bool & lastWasPunctuation)1673 static void genSpacePrinter(bool value, MethodBody &body, bool &shouldEmitSpace,
1674 bool &lastWasPunctuation) {
1675 if (value) {
1676 body << " _odsPrinter << ' ';\n";
1677 lastWasPunctuation = false;
1678 } else {
1679 lastWasPunctuation = true;
1680 }
1681 shouldEmitSpace = false;
1682 }
1683
1684 /// Generate the printer for a custom directive parameter.
genCustomDirectiveParameterPrinter(FormatElement * element,const Operator & op,MethodBody & body)1685 static void genCustomDirectiveParameterPrinter(FormatElement *element,
1686 const Operator &op,
1687 MethodBody &body) {
1688 if (auto *attr = dyn_cast<AttributeVariable>(element)) {
1689 body << op.getGetterName(attr->getVar()->name) << "Attr()";
1690
1691 } else if (isa<AttrDictDirective>(element)) {
1692 body << "getOperation()->getAttrDictionary()";
1693
1694 } else if (auto *operand = dyn_cast<OperandVariable>(element)) {
1695 body << op.getGetterName(operand->getVar()->name) << "()";
1696
1697 } else if (auto *region = dyn_cast<RegionVariable>(element)) {
1698 body << op.getGetterName(region->getVar()->name) << "()";
1699
1700 } else if (auto *successor = dyn_cast<SuccessorVariable>(element)) {
1701 body << op.getGetterName(successor->getVar()->name) << "()";
1702
1703 } else if (auto *dir = dyn_cast<RefDirective>(element)) {
1704 genCustomDirectiveParameterPrinter(dir->getArg(), op, body);
1705
1706 } else if (auto *dir = dyn_cast<TypeDirective>(element)) {
1707 auto *typeOperand = dir->getArg();
1708 auto *operand = dyn_cast<OperandVariable>(typeOperand);
1709 auto *var = operand ? operand->getVar()
1710 : cast<ResultVariable>(typeOperand)->getVar();
1711 std::string name = op.getGetterName(var->name);
1712 if (var->isVariadic())
1713 body << name << "().getTypes()";
1714 else if (var->isOptional())
1715 body << llvm::formatv("({0}() ? {0}().getType() : Type())", name);
1716 else
1717 body << name << "().getType()";
1718 } else {
1719 llvm_unreachable("unknown custom directive parameter");
1720 }
1721 }
1722
1723 /// Generate the printer for a custom directive.
genCustomDirectivePrinter(CustomDirective * customDir,const Operator & op,MethodBody & body)1724 static void genCustomDirectivePrinter(CustomDirective *customDir,
1725 const Operator &op, MethodBody &body) {
1726 body << " print" << customDir->getName() << "(_odsPrinter, *this";
1727 for (FormatElement *param : customDir->getArguments()) {
1728 body << ", ";
1729 genCustomDirectiveParameterPrinter(param, op, body);
1730 }
1731 body << ");\n";
1732 }
1733
1734 /// Generate the printer for a region with the given variable name.
genRegionPrinter(const Twine & regionName,MethodBody & body,bool hasImplicitTermTrait)1735 static void genRegionPrinter(const Twine ®ionName, MethodBody &body,
1736 bool hasImplicitTermTrait) {
1737 if (hasImplicitTermTrait)
1738 body << llvm::formatv(regionSingleBlockImplicitTerminatorPrinterCode,
1739 regionName);
1740 else
1741 body << " _odsPrinter.printRegion(" << regionName << ");\n";
1742 }
genVariadicRegionPrinter(const Twine & regionListName,MethodBody & body,bool hasImplicitTermTrait)1743 static void genVariadicRegionPrinter(const Twine ®ionListName,
1744 MethodBody &body,
1745 bool hasImplicitTermTrait) {
1746 body << " llvm::interleaveComma(" << regionListName
1747 << ", _odsPrinter, [&](::mlir::Region ®ion) {\n ";
1748 genRegionPrinter("region", body, hasImplicitTermTrait);
1749 body << " });\n";
1750 }
1751
1752 /// Generate the C++ for an operand to a (*-)type directive.
genTypeOperandPrinter(FormatElement * arg,const Operator & op,MethodBody & body,bool useArrayRef=true)1753 static MethodBody &genTypeOperandPrinter(FormatElement *arg, const Operator &op,
1754 MethodBody &body,
1755 bool useArrayRef = true) {
1756 if (isa<OperandsDirective>(arg))
1757 return body << "getOperation()->getOperandTypes()";
1758 if (isa<ResultsDirective>(arg))
1759 return body << "getOperation()->getResultTypes()";
1760 auto *operand = dyn_cast<OperandVariable>(arg);
1761 auto *var = operand ? operand->getVar() : cast<ResultVariable>(arg)->getVar();
1762 if (var->isVariadicOfVariadic())
1763 return body << llvm::formatv("{0}().join().getTypes()",
1764 op.getGetterName(var->name));
1765 if (var->isVariadic())
1766 return body << op.getGetterName(var->name) << "().getTypes()";
1767 if (var->isOptional())
1768 return body << llvm::formatv(
1769 "({0}() ? ::llvm::ArrayRef<::mlir::Type>({0}().getType()) : "
1770 "::llvm::ArrayRef<::mlir::Type>())",
1771 op.getGetterName(var->name));
1772 if (useArrayRef)
1773 return body << "::llvm::ArrayRef<::mlir::Type>("
1774 << op.getGetterName(var->name) << "().getType())";
1775 return body << op.getGetterName(var->name) << "().getType()";
1776 }
1777
1778 /// Generate the printer for an enum attribute.
genEnumAttrPrinter(const NamedAttribute * var,const Operator & op,MethodBody & body)1779 static void genEnumAttrPrinter(const NamedAttribute *var, const Operator &op,
1780 MethodBody &body) {
1781 Attribute baseAttr = var->attr.getBaseAttr();
1782 const EnumAttr &enumAttr = cast<EnumAttr>(baseAttr);
1783 std::vector<EnumAttrCase> cases = enumAttr.getAllCases();
1784
1785 body << llvm::formatv(enumAttrBeginPrinterCode,
1786 (var->attr.isOptional() ? "*" : "") +
1787 op.getGetterName(var->name),
1788 enumAttr.getSymbolToStringFnName());
1789
1790 // Get a string containing all of the cases that can't be represented with a
1791 // keyword.
1792 BitVector nonKeywordCases(cases.size());
1793 for (auto &it : llvm::enumerate(cases)) {
1794 if (!canFormatStringAsKeyword(it.value().getStr()))
1795 nonKeywordCases.set(it.index());
1796 }
1797
1798 // Otherwise if this is a bit enum attribute, don't allow cases that may
1799 // overlap with other cases. For simplicity sake, only allow cases with a
1800 // single bit value.
1801 if (enumAttr.isBitEnum()) {
1802 for (auto &it : llvm::enumerate(cases)) {
1803 int64_t value = it.value().getValue();
1804 if (value < 0 || !llvm::isPowerOf2_64(value))
1805 nonKeywordCases.set(it.index());
1806 }
1807 }
1808
1809 // If there are any cases that can't be used with a keyword, switch on the
1810 // case value to determine when to print in the string form.
1811 if (nonKeywordCases.any()) {
1812 body << " switch (caseValue) {\n";
1813 StringRef cppNamespace = enumAttr.getCppNamespace();
1814 StringRef enumName = enumAttr.getEnumClassName();
1815 for (auto &it : llvm::enumerate(cases)) {
1816 if (nonKeywordCases.test(it.index()))
1817 continue;
1818 StringRef symbol = it.value().getSymbol();
1819 body << llvm::formatv(" case {0}::{1}::{2}:\n", cppNamespace, enumName,
1820 llvm::isDigit(symbol.front()) ? ("_" + symbol)
1821 : symbol);
1822 }
1823 body << " _odsPrinter << caseValueStr;\n"
1824 " break;\n"
1825 " default:\n"
1826 " _odsPrinter << '\"' << caseValueStr << '\"';\n"
1827 " break;\n"
1828 " }\n"
1829 " }\n";
1830 return;
1831 }
1832
1833 body << " _odsPrinter << caseValueStr;\n"
1834 " }\n";
1835 }
1836
1837 /// Generate the check for the anchor of an optional group.
genOptionalGroupPrinterAnchor(FormatElement * anchor,const Operator & op,MethodBody & body)1838 static void genOptionalGroupPrinterAnchor(FormatElement *anchor,
1839 const Operator &op,
1840 MethodBody &body) {
1841 TypeSwitch<FormatElement *>(anchor)
1842 .Case<OperandVariable, ResultVariable>([&](auto *element) {
1843 const NamedTypeConstraint *var = element->getVar();
1844 std::string name = op.getGetterName(var->name);
1845 if (var->isOptional())
1846 body << " if (" << name << "()) {\n";
1847 else if (var->isVariadic())
1848 body << " if (!" << name << "().empty()) {\n";
1849 })
1850 .Case<RegionVariable>([&](RegionVariable *element) {
1851 const NamedRegion *var = element->getVar();
1852 std::string name = op.getGetterName(var->name);
1853 // TODO: Add a check for optional regions here when ODS supports it.
1854 body << " if (!" << name << "().empty()) {\n";
1855 })
1856 .Case<TypeDirective>([&](TypeDirective *element) {
1857 genOptionalGroupPrinterAnchor(element->getArg(), op, body);
1858 })
1859 .Case<FunctionalTypeDirective>([&](FunctionalTypeDirective *element) {
1860 genOptionalGroupPrinterAnchor(element->getInputs(), op, body);
1861 })
1862 .Case<AttributeVariable>([&](AttributeVariable *attr) {
1863 body << " if ((*this)->getAttr(\"" << attr->getVar()->name
1864 << "\")) {\n";
1865 });
1866 }
1867
collect(FormatElement * element,SmallVectorImpl<VariableElement * > & variables)1868 void collect(FormatElement *element,
1869 SmallVectorImpl<VariableElement *> &variables) {
1870 TypeSwitch<FormatElement *>(element)
1871 .Case([&](VariableElement *var) { variables.emplace_back(var); })
1872 .Case([&](CustomDirective *ele) {
1873 for (FormatElement *arg : ele->getArguments())
1874 collect(arg, variables);
1875 })
1876 .Case([&](OptionalElement *ele) {
1877 for (FormatElement *arg : ele->getThenElements())
1878 collect(arg, variables);
1879 for (FormatElement *arg : ele->getElseElements())
1880 collect(arg, variables);
1881 })
1882 .Case([&](FunctionalTypeDirective *funcType) {
1883 collect(funcType->getInputs(), variables);
1884 collect(funcType->getResults(), variables);
1885 })
1886 .Case([&](OIListElement *oilist) {
1887 for (ArrayRef<FormatElement *> arg : oilist->getParsingElements())
1888 for (FormatElement *arg : arg)
1889 collect(arg, variables);
1890 });
1891 }
1892
genElementPrinter(FormatElement * element,MethodBody & body,Operator & op,bool & shouldEmitSpace,bool & lastWasPunctuation)1893 void OperationFormat::genElementPrinter(FormatElement *element,
1894 MethodBody &body, Operator &op,
1895 bool &shouldEmitSpace,
1896 bool &lastWasPunctuation) {
1897 if (LiteralElement *literal = dyn_cast<LiteralElement>(element))
1898 return genLiteralPrinter(literal->getSpelling(), body, shouldEmitSpace,
1899 lastWasPunctuation);
1900
1901 // Emit a whitespace element.
1902 if (auto *space = dyn_cast<WhitespaceElement>(element)) {
1903 if (space->getValue() == "\\n") {
1904 body << " _odsPrinter.printNewline();\n";
1905 } else {
1906 genSpacePrinter(!space->getValue().empty(), body, shouldEmitSpace,
1907 lastWasPunctuation);
1908 }
1909 return;
1910 }
1911
1912 // Emit an optional group.
1913 if (OptionalElement *optional = dyn_cast<OptionalElement>(element)) {
1914 // Emit the check for the presence of the anchor element.
1915 FormatElement *anchor = optional->getAnchor();
1916 genOptionalGroupPrinterAnchor(anchor, op, body);
1917
1918 // If the anchor is a unit attribute, we don't need to print it. When
1919 // parsing, we will add this attribute if this group is present.
1920 auto elements = optional->getThenElements();
1921 FormatElement *elidedAnchorElement = nullptr;
1922 auto *anchorAttr = dyn_cast<AttributeVariable>(anchor);
1923 if (anchorAttr && anchorAttr != elements.front() &&
1924 anchorAttr->isUnitAttr()) {
1925 elidedAnchorElement = anchorAttr;
1926 }
1927
1928 // Emit each of the elements.
1929 for (FormatElement *childElement : elements) {
1930 if (childElement != elidedAnchorElement) {
1931 genElementPrinter(childElement, body, op, shouldEmitSpace,
1932 lastWasPunctuation);
1933 }
1934 }
1935 body << " }";
1936
1937 // Emit each of the else elements.
1938 auto elseElements = optional->getElseElements();
1939 if (!elseElements.empty()) {
1940 body << " else {\n";
1941 for (FormatElement *childElement : elseElements) {
1942 genElementPrinter(childElement, body, op, shouldEmitSpace,
1943 lastWasPunctuation);
1944 }
1945 body << " }";
1946 }
1947
1948 body << "\n";
1949 return;
1950 }
1951
1952 // Emit the OIList
1953 if (auto *oilist = dyn_cast<OIListElement>(element)) {
1954 genLiteralPrinter(" ", body, shouldEmitSpace, lastWasPunctuation);
1955 for (auto clause : oilist->getClauses()) {
1956 LiteralElement *lelement = std::get<0>(clause);
1957 ArrayRef<FormatElement *> pelement = std::get<1>(clause);
1958
1959 SmallVector<VariableElement *> vars;
1960 for (FormatElement *el : pelement)
1961 collect(el, vars);
1962 body << " if (false";
1963 for (VariableElement *var : vars) {
1964 TypeSwitch<FormatElement *>(var)
1965 .Case([&](AttributeVariable *attrEle) {
1966 body << " || " << op.getGetterName(attrEle->getVar()->name)
1967 << "Attr()";
1968 })
1969 .Case([&](OperandVariable *ele) {
1970 if (ele->getVar()->isVariadic()) {
1971 body << " || " << op.getGetterName(ele->getVar()->name)
1972 << "().size()";
1973 } else {
1974 body << " || " << op.getGetterName(ele->getVar()->name) << "()";
1975 }
1976 })
1977 .Case([&](ResultVariable *ele) {
1978 if (ele->getVar()->isVariadic()) {
1979 body << " || " << op.getGetterName(ele->getVar()->name)
1980 << "().size()";
1981 } else {
1982 body << " || " << op.getGetterName(ele->getVar()->name) << "()";
1983 }
1984 })
1985 .Case([&](RegionVariable *reg) {
1986 body << " || " << op.getGetterName(reg->getVar()->name) << "()";
1987 });
1988 }
1989
1990 body << ") {\n";
1991 genLiteralPrinter(lelement->getSpelling(), body, shouldEmitSpace,
1992 lastWasPunctuation);
1993 if (oilist->getUnitAttrParsingElement(pelement) == nullptr) {
1994 for (FormatElement *element : pelement)
1995 genElementPrinter(element, body, op, shouldEmitSpace,
1996 lastWasPunctuation);
1997 }
1998 body << " }\n";
1999 }
2000 return;
2001 }
2002
2003 // Emit the attribute dictionary.
2004 if (auto *attrDict = dyn_cast<AttrDictDirective>(element)) {
2005 genAttrDictPrinter(*this, op, body, attrDict->isWithKeyword());
2006 lastWasPunctuation = false;
2007 return;
2008 }
2009
2010 // Optionally insert a space before the next element. The AttrDict printer
2011 // already adds a space as necessary.
2012 if (shouldEmitSpace || !lastWasPunctuation)
2013 body << " _odsPrinter << ' ';\n";
2014 lastWasPunctuation = false;
2015 shouldEmitSpace = true;
2016
2017 if (auto *attr = dyn_cast<AttributeVariable>(element)) {
2018 const NamedAttribute *var = attr->getVar();
2019
2020 // If we are formatting as an enum, symbolize the attribute as a string.
2021 if (canFormatEnumAttr(var))
2022 return genEnumAttrPrinter(var, op, body);
2023
2024 // If we are formatting as a symbol name, handle it as a symbol name.
2025 if (shouldFormatSymbolNameAttr(var)) {
2026 body << " _odsPrinter.printSymbolName(" << op.getGetterName(var->name)
2027 << "Attr().getValue());\n";
2028 return;
2029 }
2030
2031 // Elide the attribute type if it is buildable.
2032 if (attr->getTypeBuilder())
2033 body << " _odsPrinter.printAttributeWithoutType("
2034 << op.getGetterName(var->name) << "Attr());\n";
2035 else if (attr->shouldBeQualified() ||
2036 var->attr.getStorageType() == "::mlir::Attribute")
2037 body << " _odsPrinter.printAttribute(" << op.getGetterName(var->name)
2038 << "Attr());\n";
2039 else
2040 body << "_odsPrinter.printStrippedAttrOrType("
2041 << op.getGetterName(var->name) << "Attr());\n";
2042 } else if (auto *operand = dyn_cast<OperandVariable>(element)) {
2043 if (operand->getVar()->isVariadicOfVariadic()) {
2044 body << " ::llvm::interleaveComma("
2045 << op.getGetterName(operand->getVar()->name)
2046 << "(), _odsPrinter, [&](const auto &operands) { _odsPrinter << "
2047 "\"(\" << operands << "
2048 "\")\"; });\n";
2049
2050 } else if (operand->getVar()->isOptional()) {
2051 body << " if (::mlir::Value value = "
2052 << op.getGetterName(operand->getVar()->name) << "())\n"
2053 << " _odsPrinter << value;\n";
2054 } else {
2055 body << " _odsPrinter << " << op.getGetterName(operand->getVar()->name)
2056 << "();\n";
2057 }
2058 } else if (auto *region = dyn_cast<RegionVariable>(element)) {
2059 const NamedRegion *var = region->getVar();
2060 std::string name = op.getGetterName(var->name);
2061 if (var->isVariadic()) {
2062 genVariadicRegionPrinter(name + "()", body, hasImplicitTermTrait);
2063 } else {
2064 genRegionPrinter(name + "()", body, hasImplicitTermTrait);
2065 }
2066 } else if (auto *successor = dyn_cast<SuccessorVariable>(element)) {
2067 const NamedSuccessor *var = successor->getVar();
2068 std::string name = op.getGetterName(var->name);
2069 if (var->isVariadic())
2070 body << " ::llvm::interleaveComma(" << name << "(), _odsPrinter);\n";
2071 else
2072 body << " _odsPrinter << " << name << "();\n";
2073 } else if (auto *dir = dyn_cast<CustomDirective>(element)) {
2074 genCustomDirectivePrinter(dir, op, body);
2075 } else if (isa<OperandsDirective>(element)) {
2076 body << " _odsPrinter << getOperation()->getOperands();\n";
2077 } else if (isa<RegionsDirective>(element)) {
2078 genVariadicRegionPrinter("getOperation()->getRegions()", body,
2079 hasImplicitTermTrait);
2080 } else if (isa<SuccessorsDirective>(element)) {
2081 body << " ::llvm::interleaveComma(getOperation()->getSuccessors(), "
2082 "_odsPrinter);\n";
2083 } else if (auto *dir = dyn_cast<TypeDirective>(element)) {
2084 if (auto *operand = dyn_cast<OperandVariable>(dir->getArg())) {
2085 if (operand->getVar()->isVariadicOfVariadic()) {
2086 body << llvm::formatv(
2087 " ::llvm::interleaveComma({0}().getTypes(), _odsPrinter, "
2088 "[&](::mlir::TypeRange types) {{ _odsPrinter << \"(\" << "
2089 "types << \")\"; });\n",
2090 op.getGetterName(operand->getVar()->name));
2091 return;
2092 }
2093 }
2094 const NamedTypeConstraint *var = nullptr;
2095 {
2096 if (auto *operand = dyn_cast<OperandVariable>(dir->getArg()))
2097 var = operand->getVar();
2098 else if (auto *operand = dyn_cast<ResultVariable>(dir->getArg()))
2099 var = operand->getVar();
2100 }
2101 if (var && !var->isVariadicOfVariadic() && !var->isVariadic() &&
2102 !var->isOptional()) {
2103 std::string cppClass = var->constraint.getCPPClassName();
2104 if (dir->shouldBeQualified()) {
2105 body << " _odsPrinter << " << op.getGetterName(var->name)
2106 << "().getType();\n";
2107 return;
2108 }
2109 body << " {\n"
2110 << " auto type = " << op.getGetterName(var->name)
2111 << "().getType();\n"
2112 << " if (auto validType = type.dyn_cast<" << cppClass << ">())\n"
2113 << " _odsPrinter.printStrippedAttrOrType(validType);\n"
2114 << " else\n"
2115 << " _odsPrinter << type;\n"
2116 << " }\n";
2117 return;
2118 }
2119 body << " _odsPrinter << ";
2120 genTypeOperandPrinter(dir->getArg(), op, body, /*useArrayRef=*/false)
2121 << ";\n";
2122 } else if (auto *dir = dyn_cast<FunctionalTypeDirective>(element)) {
2123 body << " _odsPrinter.printFunctionalType(";
2124 genTypeOperandPrinter(dir->getInputs(), op, body) << ", ";
2125 genTypeOperandPrinter(dir->getResults(), op, body) << ");\n";
2126 } else {
2127 llvm_unreachable("unknown format element");
2128 }
2129 }
2130
genPrinter(Operator & op,OpClass & opClass)2131 void OperationFormat::genPrinter(Operator &op, OpClass &opClass) {
2132 auto *method = opClass.addMethod(
2133 "void", "print",
2134 MethodParameter("::mlir::OpAsmPrinter &", "_odsPrinter"));
2135 auto &body = method->body();
2136
2137 // Flags for if we should emit a space, and if the last element was
2138 // punctuation.
2139 bool shouldEmitSpace = true, lastWasPunctuation = false;
2140 for (FormatElement *element : elements)
2141 genElementPrinter(element, body, op, shouldEmitSpace, lastWasPunctuation);
2142 }
2143
2144 //===----------------------------------------------------------------------===//
2145 // OpFormatParser
2146 //===----------------------------------------------------------------------===//
2147
2148 /// Function to find an element within the given range that has the same name as
2149 /// 'name'.
2150 template <typename RangeT>
findArg(RangeT && range,StringRef name)2151 static auto findArg(RangeT &&range, StringRef name) {
2152 auto it = llvm::find_if(range, [=](auto &arg) { return arg.name == name; });
2153 return it != range.end() ? &*it : nullptr;
2154 }
2155
namespace {
/// This class implements a parser for an instance of an operation assembly
/// format. While parsing, it records which parts of the operation the format
/// references, in the `has*` / `seen*` fields below. `verify` then checks the
/// format against the operation and fills in the OperationFormat (`fmt`).
class OpFormatParser : public FormatParser {
public:
  /// Create a parser for `format` of operation `op`. The op's first source
  /// location is handed to the FormatParser base class. The operand/result
  /// type bit vectors are sized to the op's operand/result counts.
  OpFormatParser(llvm::SourceMgr &mgr, OperationFormat &format, Operator &op)
      : FormatParser(mgr, op.getLoc()[0]), fmt(format), op(op),
        seenOperandTypes(op.getNumOperands()),
        seenResultTypes(op.getNumResults()) {}

protected:
  /// Verify the format elements.
  LogicalResult verify(SMLoc loc, ArrayRef<FormatElement *> elements) override;
  /// Verify the arguments to a custom directive.
  LogicalResult
  verifyCustomDirectiveArguments(SMLoc loc,
                                 ArrayRef<FormatElement *> arguments) override;
  /// Verify the elements of an optional group.
  LogicalResult
  verifyOptionalGroupElements(SMLoc loc, ArrayRef<FormatElement *> elements,
                              Optional<unsigned> anchorIndex) override;
  /// Verify a single element of an optional group. `isAnchor` indicates
  /// whether the element is the group's anchor.
  LogicalResult verifyOptionalGroupElement(SMLoc loc, FormatElement *element,
                                           bool isAnchor);

  /// Parse an operation variable.
  FailureOr<FormatElement *> parseVariableImpl(SMLoc loc, StringRef name,
                                               Context ctx) override;
  /// Parse an operation format directive.
  FailureOr<FormatElement *>
  parseDirectiveImpl(SMLoc loc, FormatToken::Kind kind, Context ctx) override;

private:
  /// This struct represents a type resolution instance. It includes a specific
  /// type as well as an optional transformer to apply to that type in order to
  /// properly resolve the type of a variable.
  struct TypeResolutionInstance {
    ConstArgument resolver;
    Optional<StringRef> transformer;
  };

  /// Verify the state of operation attributes within the format.
  LogicalResult verifyAttributes(SMLoc loc, ArrayRef<FormatElement *> elements);

  /// Verify that attributes elements aren't followed by colon literals.
  LogicalResult verifyAttributeColonType(SMLoc loc,
                                         ArrayRef<FormatElement *> elements);

  /// Verify the state of operation operands within the format.
  LogicalResult
  verifyOperands(SMLoc loc,
                 llvm::StringMap<TypeResolutionInstance> &variableTyResolver);

  /// Verify the state of operation regions within the format.
  LogicalResult verifyRegions(SMLoc loc);

  /// Verify the state of operation results within the format.
  LogicalResult
  verifyResults(SMLoc loc,
                llvm::StringMap<TypeResolutionInstance> &variableTyResolver);

  /// Verify the state of operation successors within the format.
  LogicalResult verifySuccessors(SMLoc loc);

  /// Verify the `oilist` directives among the given format elements.
  LogicalResult verifyOIListElements(SMLoc loc,
                                     ArrayRef<FormatElement *> elements);

  /// Given the values of an `AllTypesMatch` trait, check for inferable type
  /// resolution.
  void handleAllTypesMatchConstraint(
      ArrayRef<StringRef> values,
      llvm::StringMap<TypeResolutionInstance> &variableTyResolver);
  /// Check for inferable type resolution given all operands, and or results,
  /// have the same type. If 'includeResults' is true, the results also have the
  /// same type as all of the operands.
  void handleSameTypesConstraint(
      llvm::StringMap<TypeResolutionInstance> &variableTyResolver,
      bool includeResults);
  /// Check for inferable type resolution based on another operand, result, or
  /// attribute.
  void handleTypesMatchConstraint(
      llvm::StringMap<TypeResolutionInstance> &variableTyResolver,
      const llvm::Record &def);

  /// Returns an argument or attribute with the given name that has been seen
  /// within the format.
  ConstArgument findSeenArg(StringRef name);

  /// Parse the various different directives.
  FailureOr<FormatElement *> parseAttrDictDirective(SMLoc loc, Context context,
                                                    bool withKeyword);
  FailureOr<FormatElement *> parseFunctionalTypeDirective(SMLoc loc,
                                                          Context context);
  FailureOr<FormatElement *> parseOIListDirective(SMLoc loc, Context context);
  /// Verify a single element appearing inside a clause of an `oilist`
  /// directive.
  LogicalResult verifyOIListParsingElement(FormatElement *element, SMLoc loc);
  FailureOr<FormatElement *> parseOperandsDirective(SMLoc loc, Context context);
  FailureOr<FormatElement *> parseQualifiedDirective(SMLoc loc,
                                                     Context context);
  FailureOr<FormatElement *> parseReferenceDirective(SMLoc loc,
                                                     Context context);
  FailureOr<FormatElement *> parseRegionsDirective(SMLoc loc, Context context);
  FailureOr<FormatElement *> parseResultsDirective(SMLoc loc, Context context);
  FailureOr<FormatElement *> parseSuccessorsDirective(SMLoc loc,
                                                      Context context);
  FailureOr<FormatElement *> parseTypeDirective(SMLoc loc, Context context);
  FailureOr<FormatElement *> parseTypeDirectiveOperand(SMLoc loc,
                                                       bool isRefChild = false);

  //===--------------------------------------------------------------------===//
  // Fields
  //===--------------------------------------------------------------------===//

  /// The operation format being populated.
  OperationFormat &fmt;
  /// The operation whose format is being parsed.
  Operator &op;

  // The following are various bits of format state used for verification
  // during parsing.
  bool hasAttrDict = false;
  bool hasAllRegions = false, hasAllSuccessors = false;
  bool canInferResultTypes = false;
  // Presumably bit i is set once the type of operand/result i is referenced by
  // the format; sized in the constructor. NOTE(review): confirm in parser.
  llvm::SmallBitVector seenOperandTypes, seenResultTypes;
  // Attributes referenced by the format; `verify` moves them into
  // `fmt.usedAttributes`.
  llvm::SmallSetVector<const NamedAttribute *, 8> seenAttrs;
  llvm::DenseSet<const NamedTypeConstraint *> seenOperands;
  llvm::DenseSet<const NamedRegion *> seenRegions;
  llvm::DenseSet<const NamedSuccessor *> seenSuccessors;
};
} // namespace
2282
verify(SMLoc loc,ArrayRef<FormatElement * > elements)2283 LogicalResult OpFormatParser::verify(SMLoc loc,
2284 ArrayRef<FormatElement *> elements) {
2285 // Check that the attribute dictionary is in the format.
2286 if (!hasAttrDict)
2287 return emitError(loc, "'attr-dict' directive not found in "
2288 "custom assembly format");
2289
2290 // Check for any type traits that we can use for inferring types.
2291 llvm::StringMap<TypeResolutionInstance> variableTyResolver;
2292 for (const Trait &trait : op.getTraits()) {
2293 const llvm::Record &def = trait.getDef();
2294 if (def.isSubClassOf("AllTypesMatch")) {
2295 handleAllTypesMatchConstraint(def.getValueAsListOfStrings("values"),
2296 variableTyResolver);
2297 } else if (def.getName() == "SameTypeOperands") {
2298 handleSameTypesConstraint(variableTyResolver, /*includeResults=*/false);
2299 } else if (def.getName() == "SameOperandsAndResultType") {
2300 handleSameTypesConstraint(variableTyResolver, /*includeResults=*/true);
2301 } else if (def.isSubClassOf("TypesMatchWith")) {
2302 handleTypesMatchConstraint(variableTyResolver, def);
2303 } else if (!op.allResultTypesKnown()) {
2304 // This doesn't check the name directly to handle
2305 // DeclareOpInterfaceMethods<InferTypeOpInterface>
2306 // and the like.
2307 // TODO: Add hasCppInterface check.
2308 if (auto name = def.getValueAsOptionalString("cppInterfaceName")) {
2309 if (*name == "InferTypeOpInterface" &&
2310 def.getValueAsString("cppNamespace") == "::mlir")
2311 canInferResultTypes = true;
2312 }
2313 }
2314 }
2315
2316 // Verify the state of the various operation components.
2317 if (failed(verifyAttributes(loc, elements)) ||
2318 failed(verifyResults(loc, variableTyResolver)) ||
2319 failed(verifyOperands(loc, variableTyResolver)) ||
2320 failed(verifyRegions(loc)) || failed(verifySuccessors(loc)) ||
2321 failed(verifyOIListElements(loc, elements)))
2322 return failure();
2323
2324 // Collect the set of used attributes in the format.
2325 fmt.usedAttributes = seenAttrs.takeVector();
2326 return success();
2327 }
2328
2329 LogicalResult
verifyAttributes(SMLoc loc,ArrayRef<FormatElement * > elements)2330 OpFormatParser::verifyAttributes(SMLoc loc,
2331 ArrayRef<FormatElement *> elements) {
2332 // Check that there are no `:` literals after an attribute without a constant
2333 // type. The attribute grammar contains an optional trailing colon type, which
2334 // can lead to unexpected and generally unintended behavior. Given that, it is
2335 // better to just error out here instead.
2336 if (failed(verifyAttributeColonType(loc, elements)))
2337 return failure();
2338
2339 // Check for VariadicOfVariadic variables. The segment attribute of those
2340 // variables will be infered.
2341 for (const NamedTypeConstraint *var : seenOperands) {
2342 if (var->constraint.isVariadicOfVariadic()) {
2343 fmt.inferredAttributes.insert(
2344 var->constraint.getVariadicOfVariadicSegmentSizeAttr());
2345 }
2346 }
2347
2348 return success();
2349 }
2350
2351 /// Returns whether the single format element is optionally parsed.
isOptionallyParsed(FormatElement * el)2352 static bool isOptionallyParsed(FormatElement *el) {
2353 if (auto *attrVar = dyn_cast<AttributeVariable>(el)) {
2354 Attribute attr = attrVar->getVar()->attr;
2355 return attr.isOptional() || attr.hasDefaultValue();
2356 }
2357 if (auto *operandVar = dyn_cast<OperandVariable>(el)) {
2358 const NamedTypeConstraint *operand = operandVar->getVar();
2359 return operand->isOptional() || operand->isVariadic() ||
2360 operand->isVariadicOfVariadic();
2361 }
2362 if (auto *successorVar = dyn_cast<SuccessorVariable>(el))
2363 return successorVar->getVar()->isVariadic();
2364 if (auto *regionVar = dyn_cast<RegionVariable>(el))
2365 return regionVar->getVar()->isVariadic();
2366 return isa<WhitespaceElement, AttrDictDirective>(el);
2367 }
2368
2369 /// Scan the given range of elements from the start for a colon literal,
2370 /// skipping any optionally-parsed elements. If an optional group is
2371 /// encountered, this function recurses into the 'then' and 'else' elements to
2372 /// check if they are invalid. Returns `success` if the range is known to be
2373 /// valid or `None` if scanning reached the end.
2374 ///
2375 /// Since the guard element of an optional group is required, this function
2376 /// accepts an optional element pointer to mark it as required.
checkElementRangeForColon(function_ref<LogicalResult (const Twine &)> emitError,StringRef attrName,iterator_range<ArrayRef<FormatElement * >::iterator> elementRange,FormatElement * optionalGuard=nullptr)2377 static Optional<LogicalResult> checkElementRangeForColon(
2378 function_ref<LogicalResult(const Twine &)> emitError, StringRef attrName,
2379 iterator_range<ArrayRef<FormatElement *>::iterator> elementRange,
2380 FormatElement *optionalGuard = nullptr) {
2381 for (FormatElement *element : elementRange) {
2382 // Skip optionally parsed elements.
2383 if (element != optionalGuard && isOptionallyParsed(element))
2384 continue;
2385
2386 // Recurse on optional groups.
2387 if (auto *optional = dyn_cast<OptionalElement>(element)) {
2388 if (Optional<LogicalResult> result = checkElementRangeForColon(
2389 emitError, attrName, optional->getThenElements(),
2390 // The optional group guard is required for the group.
2391 optional->getThenElements().front()))
2392 if (failed(*result))
2393 return failure();
2394 if (Optional<LogicalResult> result = checkElementRangeForColon(
2395 emitError, attrName, optional->getElseElements()))
2396 if (failed(*result))
2397 return failure();
2398 // Skip the optional group.
2399 continue;
2400 }
2401
2402 // If we encounter anything other than `:`, this range is range.
2403 auto *literal = dyn_cast<LiteralElement>(element);
2404 if (!literal || literal->getSpelling() != ":")
2405 return success();
2406 // If we encounter `:`, the range is known to be invalid.
2407 return emitError(
2408 llvm::formatv("format ambiguity caused by `:` literal found after "
2409 "attribute `{0}` which does not have a buildable type",
2410 attrName));
2411 }
2412 // Return None to indicate that we reached the end.
2413 return llvm::None;
2414 }
2415
2416 /// For the given elements, check whether any attributes are followed by a colon
2417 /// literal, resulting in an ambiguous assembly format. Returns a non-null
2418 /// attribute if verification of said attribute reached the end of the range.
2419 /// Returns null if all attribute elements are verified.
2420 static FailureOr<AttributeVariable *>
verifyAttributeColon(function_ref<LogicalResult (const Twine &)> emitError,ArrayRef<FormatElement * > elements)2421 verifyAttributeColon(function_ref<LogicalResult(const Twine &)> emitError,
2422 ArrayRef<FormatElement *> elements) {
2423 for (auto *it = elements.begin(), *e = elements.end(); it != e; ++it) {
2424 // The current attribute being verified.
2425 AttributeVariable *attr = nullptr;
2426
2427 if ((attr = dyn_cast<AttributeVariable>(*it))) {
2428 // Check only attributes without type builders or that are known to call
2429 // the generic attribute parser.
2430 if (attr->getTypeBuilder() ||
2431 !(attr->shouldBeQualified() ||
2432 attr->getVar()->attr.getStorageType() == "::mlir::Attribute"))
2433 continue;
2434 } else if (auto *optional = dyn_cast<OptionalElement>(*it)) {
2435 // Recurse on optional groups.
2436 FailureOr<AttributeVariable *> thenResult =
2437 verifyAttributeColon(emitError, optional->getThenElements());
2438 if (failed(thenResult))
2439 return failure();
2440 FailureOr<AttributeVariable *> elseResult =
2441 verifyAttributeColon(emitError, optional->getElseElements());
2442 if (failed(elseResult))
2443 return failure();
2444 // If either optional group has an unverified attribute, save it.
2445 // Otherwise, move on to the next element.
2446 if (!(attr = *thenResult) && !(attr = *elseResult))
2447 continue;
2448 } else {
2449 continue;
2450 }
2451
2452 // Verify subsequent elements for potential ambiguities.
2453 if (Optional<LogicalResult> result = checkElementRangeForColon(
2454 emitError, attr->getVar()->name, {std::next(it), e})) {
2455 if (failed(*result))
2456 return failure();
2457 } else {
2458 // Since we reached the end, return the attribute as unverified.
2459 return attr;
2460 }
2461 }
2462 // All attribute elements are known to be verified.
2463 return nullptr;
2464 }
2465
2466 LogicalResult
verifyAttributeColonType(SMLoc loc,ArrayRef<FormatElement * > elements)2467 OpFormatParser::verifyAttributeColonType(SMLoc loc,
2468 ArrayRef<FormatElement *> elements) {
2469 return verifyAttributeColon(
2470 [&](const Twine &msg) { return emitError(loc, msg); }, elements);
2471 }
2472
verifyOperands(SMLoc loc,llvm::StringMap<TypeResolutionInstance> & variableTyResolver)2473 LogicalResult OpFormatParser::verifyOperands(
2474 SMLoc loc, llvm::StringMap<TypeResolutionInstance> &variableTyResolver) {
2475 // Check that all of the operands are within the format, and their types can
2476 // be inferred.
2477 auto &buildableTypes = fmt.buildableTypes;
2478 for (unsigned i = 0, e = op.getNumOperands(); i != e; ++i) {
2479 NamedTypeConstraint &operand = op.getOperand(i);
2480
2481 // Check that the operand itself is in the format.
2482 if (!fmt.allOperands && !seenOperands.count(&operand)) {
2483 return emitErrorAndNote(loc,
2484 "operand #" + Twine(i) + ", named '" +
2485 operand.name + "', not found",
2486 "suggest adding a '$" + operand.name +
2487 "' directive to the custom assembly format");
2488 }
2489
2490 // Check that the operand type is in the format, or that it can be inferred.
2491 if (fmt.allOperandTypes || seenOperandTypes.test(i))
2492 continue;
2493
2494 // Check to see if we can infer this type from another variable.
2495 auto varResolverIt = variableTyResolver.find(op.getOperand(i).name);
2496 if (varResolverIt != variableTyResolver.end()) {
2497 TypeResolutionInstance &resolver = varResolverIt->second;
2498 fmt.operandTypes[i].setResolver(resolver.resolver, resolver.transformer);
2499 continue;
2500 }
2501
2502 // Similarly to results, allow a custom builder for resolving the type if
2503 // we aren't using the 'operands' directive.
2504 Optional<StringRef> builder = operand.constraint.getBuilderCall();
2505 if (!builder || (fmt.allOperands && operand.isVariableLength())) {
2506 return emitErrorAndNote(
2507 loc,
2508 "type of operand #" + Twine(i) + ", named '" + operand.name +
2509 "', is not buildable and a buildable type cannot be inferred",
2510 "suggest adding a type constraint to the operation or adding a "
2511 "'type($" +
2512 operand.name + ")' directive to the " + "custom assembly format");
2513 }
2514 auto it = buildableTypes.insert({*builder, buildableTypes.size()});
2515 fmt.operandTypes[i].setBuilderIdx(it.first->second);
2516 }
2517 return success();
2518 }
2519
verifyRegions(SMLoc loc)2520 LogicalResult OpFormatParser::verifyRegions(SMLoc loc) {
2521 // Check that all of the regions are within the format.
2522 if (hasAllRegions)
2523 return success();
2524
2525 for (unsigned i = 0, e = op.getNumRegions(); i != e; ++i) {
2526 const NamedRegion ®ion = op.getRegion(i);
2527 if (!seenRegions.count(®ion)) {
2528 return emitErrorAndNote(loc,
2529 "region #" + Twine(i) + ", named '" +
2530 region.name + "', not found",
2531 "suggest adding a '$" + region.name +
2532 "' directive to the custom assembly format");
2533 }
2534 }
2535 return success();
2536 }
2537
verifyResults(SMLoc loc,llvm::StringMap<TypeResolutionInstance> & variableTyResolver)2538 LogicalResult OpFormatParser::verifyResults(
2539 SMLoc loc, llvm::StringMap<TypeResolutionInstance> &variableTyResolver) {
2540 // If we format all of the types together, there is nothing to check.
2541 if (fmt.allResultTypes)
2542 return success();
2543
2544 // If no result types are specified and we can infer them, infer all result
2545 // types
2546 if (op.getNumResults() > 0 && seenResultTypes.count() == 0 &&
2547 canInferResultTypes) {
2548 fmt.infersResultTypes = true;
2549 return success();
2550 }
2551
2552 // Check that all of the result types can be inferred.
2553 auto &buildableTypes = fmt.buildableTypes;
2554 for (unsigned i = 0, e = op.getNumResults(); i != e; ++i) {
2555 if (seenResultTypes.test(i))
2556 continue;
2557
2558 // Check to see if we can infer this type from another variable.
2559 auto varResolverIt = variableTyResolver.find(op.getResultName(i));
2560 if (varResolverIt != variableTyResolver.end()) {
2561 TypeResolutionInstance resolver = varResolverIt->second;
2562 fmt.resultTypes[i].setResolver(resolver.resolver, resolver.transformer);
2563 continue;
2564 }
2565
2566 // If the result is not variable length, allow for the case where the type
2567 // has a builder that we can use.
2568 NamedTypeConstraint &result = op.getResult(i);
2569 Optional<StringRef> builder = result.constraint.getBuilderCall();
2570 if (!builder || result.isVariableLength()) {
2571 return emitErrorAndNote(
2572 loc,
2573 "type of result #" + Twine(i) + ", named '" + result.name +
2574 "', is not buildable and a buildable type cannot be inferred",
2575 "suggest adding a type constraint to the operation or adding a "
2576 "'type($" +
2577 result.name + ")' directive to the " + "custom assembly format");
2578 }
2579 // Note in the format that this result uses the custom builder.
2580 auto it = buildableTypes.insert({*builder, buildableTypes.size()});
2581 fmt.resultTypes[i].setBuilderIdx(it.first->second);
2582 }
2583 return success();
2584 }
2585
verifySuccessors(SMLoc loc)2586 LogicalResult OpFormatParser::verifySuccessors(SMLoc loc) {
2587 // Check that all of the successors are within the format.
2588 if (hasAllSuccessors)
2589 return success();
2590
2591 for (unsigned i = 0, e = op.getNumSuccessors(); i != e; ++i) {
2592 const NamedSuccessor &successor = op.getSuccessor(i);
2593 if (!seenSuccessors.count(&successor)) {
2594 return emitErrorAndNote(loc,
2595 "successor #" + Twine(i) + ", named '" +
2596 successor.name + "', not found",
2597 "suggest adding a '$" + successor.name +
2598 "' directive to the custom assembly format");
2599 }
2600 }
2601 return success();
2602 }
2603
2604 LogicalResult
verifyOIListElements(SMLoc loc,ArrayRef<FormatElement * > elements)2605 OpFormatParser::verifyOIListElements(SMLoc loc,
2606 ArrayRef<FormatElement *> elements) {
2607 // Check that all of the successors are within the format.
2608 SmallVector<StringRef> prohibitedLiterals;
2609 for (FormatElement *it : elements) {
2610 if (auto *oilist = dyn_cast<OIListElement>(it)) {
2611 if (!prohibitedLiterals.empty()) {
2612 // We just saw an oilist element in last iteration. Literals should not
2613 // match.
2614 for (LiteralElement *literal : oilist->getLiteralElements()) {
2615 if (find(prohibitedLiterals, literal->getSpelling()) !=
2616 prohibitedLiterals.end()) {
2617 return emitError(
2618 loc, "format ambiguity because " + literal->getSpelling() +
2619 " is used in two adjacent oilist elements.");
2620 }
2621 }
2622 }
2623 for (LiteralElement *literal : oilist->getLiteralElements())
2624 prohibitedLiterals.push_back(literal->getSpelling());
2625 } else if (auto *literal = dyn_cast<LiteralElement>(it)) {
2626 if (find(prohibitedLiterals, literal->getSpelling()) !=
2627 prohibitedLiterals.end()) {
2628 return emitError(
2629 loc,
2630 "format ambiguity because " + literal->getSpelling() +
2631 " is used both in oilist element and the adjacent literal.");
2632 }
2633 prohibitedLiterals.clear();
2634 } else {
2635 prohibitedLiterals.clear();
2636 }
2637 }
2638 return success();
2639 }
2640
handleAllTypesMatchConstraint(ArrayRef<StringRef> values,llvm::StringMap<TypeResolutionInstance> & variableTyResolver)2641 void OpFormatParser::handleAllTypesMatchConstraint(
2642 ArrayRef<StringRef> values,
2643 llvm::StringMap<TypeResolutionInstance> &variableTyResolver) {
2644 for (unsigned i = 0, e = values.size(); i != e; ++i) {
2645 // Check to see if this value matches a resolved operand or result type.
2646 ConstArgument arg = findSeenArg(values[i]);
2647 if (!arg)
2648 continue;
2649
2650 // Mark this value as the type resolver for the other variables.
2651 for (unsigned j = 0; j != i; ++j)
2652 variableTyResolver[values[j]] = {arg, llvm::None};
2653 for (unsigned j = i + 1; j != e; ++j)
2654 variableTyResolver[values[j]] = {arg, llvm::None};
2655 }
2656 }
2657
handleSameTypesConstraint(llvm::StringMap<TypeResolutionInstance> & variableTyResolver,bool includeResults)2658 void OpFormatParser::handleSameTypesConstraint(
2659 llvm::StringMap<TypeResolutionInstance> &variableTyResolver,
2660 bool includeResults) {
2661 const NamedTypeConstraint *resolver = nullptr;
2662 int resolvedIt = -1;
2663
2664 // Check to see if there is an operand or result to use for the resolution.
2665 if ((resolvedIt = seenOperandTypes.find_first()) != -1)
2666 resolver = &op.getOperand(resolvedIt);
2667 else if (includeResults && (resolvedIt = seenResultTypes.find_first()) != -1)
2668 resolver = &op.getResult(resolvedIt);
2669 else
2670 return;
2671
2672 // Set the resolvers for each operand and result.
2673 for (unsigned i = 0, e = op.getNumOperands(); i != e; ++i)
2674 if (!seenOperandTypes.test(i) && !op.getOperand(i).name.empty())
2675 variableTyResolver[op.getOperand(i).name] = {resolver, llvm::None};
2676 if (includeResults) {
2677 for (unsigned i = 0, e = op.getNumResults(); i != e; ++i)
2678 if (!seenResultTypes.test(i) && !op.getResultName(i).empty())
2679 variableTyResolver[op.getResultName(i)] = {resolver, llvm::None};
2680 }
2681 }
2682
handleTypesMatchConstraint(llvm::StringMap<TypeResolutionInstance> & variableTyResolver,const llvm::Record & def)2683 void OpFormatParser::handleTypesMatchConstraint(
2684 llvm::StringMap<TypeResolutionInstance> &variableTyResolver,
2685 const llvm::Record &def) {
2686 StringRef lhsName = def.getValueAsString("lhs");
2687 StringRef rhsName = def.getValueAsString("rhs");
2688 StringRef transformer = def.getValueAsString("transformer");
2689 if (ConstArgument arg = findSeenArg(lhsName))
2690 variableTyResolver[rhsName] = {arg, transformer};
2691 }
2692
findSeenArg(StringRef name)2693 ConstArgument OpFormatParser::findSeenArg(StringRef name) {
2694 if (const NamedTypeConstraint *arg = findArg(op.getOperands(), name))
2695 return seenOperandTypes.test(arg - op.operand_begin()) ? arg : nullptr;
2696 if (const NamedTypeConstraint *arg = findArg(op.getResults(), name))
2697 return seenResultTypes.test(arg - op.result_begin()) ? arg : nullptr;
2698 if (const NamedAttribute *attr = findArg(op.getAttributes(), name))
2699 return seenAttrs.count(attr) ? attr : nullptr;
2700 return nullptr;
2701 }
2702
2703 FailureOr<FormatElement *>
parseVariableImpl(SMLoc loc,StringRef name,Context ctx)2704 OpFormatParser::parseVariableImpl(SMLoc loc, StringRef name, Context ctx) {
2705 // Check that the parsed argument is something actually registered on the op.
2706 // Attributes
2707 if (const NamedAttribute *attr = findArg(op.getAttributes(), name)) {
2708 if (ctx == TypeDirectiveContext)
2709 return emitError(
2710 loc, "attributes cannot be used as children to a `type` directive");
2711 if (ctx == RefDirectiveContext) {
2712 if (!seenAttrs.count(attr))
2713 return emitError(loc, "attribute '" + name +
2714 "' must be bound before it is referenced");
2715 } else if (!seenAttrs.insert(attr)) {
2716 return emitError(loc, "attribute '" + name + "' is already bound");
2717 }
2718
2719 return create<AttributeVariable>(attr);
2720 }
2721 // Operands
2722 if (const NamedTypeConstraint *operand = findArg(op.getOperands(), name)) {
2723 if (ctx == TopLevelContext || ctx == CustomDirectiveContext) {
2724 if (fmt.allOperands || !seenOperands.insert(operand).second)
2725 return emitError(loc, "operand '" + name + "' is already bound");
2726 } else if (ctx == RefDirectiveContext && !seenOperands.count(operand)) {
2727 return emitError(loc, "operand '" + name +
2728 "' must be bound before it is referenced");
2729 }
2730 return create<OperandVariable>(operand);
2731 }
2732 // Regions
2733 if (const NamedRegion *region = findArg(op.getRegions(), name)) {
2734 if (ctx == TopLevelContext || ctx == CustomDirectiveContext) {
2735 if (hasAllRegions || !seenRegions.insert(region).second)
2736 return emitError(loc, "region '" + name + "' is already bound");
2737 } else if (ctx == RefDirectiveContext && !seenRegions.count(region)) {
2738 return emitError(loc, "region '" + name +
2739 "' must be bound before it is referenced");
2740 } else {
2741 return emitError(loc, "regions can only be used at the top level");
2742 }
2743 return create<RegionVariable>(region);
2744 }
2745 // Results.
2746 if (const auto *result = findArg(op.getResults(), name)) {
2747 if (ctx != TypeDirectiveContext)
2748 return emitError(loc, "result variables can can only be used as a child "
2749 "to a 'type' directive");
2750 return create<ResultVariable>(result);
2751 }
2752 // Successors.
2753 if (const auto *successor = findArg(op.getSuccessors(), name)) {
2754 if (ctx == TopLevelContext || ctx == CustomDirectiveContext) {
2755 if (hasAllSuccessors || !seenSuccessors.insert(successor).second)
2756 return emitError(loc, "successor '" + name + "' is already bound");
2757 } else if (ctx == RefDirectiveContext && !seenSuccessors.count(successor)) {
2758 return emitError(loc, "successor '" + name +
2759 "' must be bound before it is referenced");
2760 } else {
2761 return emitError(loc, "successors can only be used at the top level");
2762 }
2763
2764 return create<SuccessorVariable>(successor);
2765 }
2766 return emitError(loc, "expected variable to refer to an argument, region, "
2767 "result, or successor");
2768 }
2769
2770 FailureOr<FormatElement *>
parseDirectiveImpl(SMLoc loc,FormatToken::Kind kind,Context ctx)2771 OpFormatParser::parseDirectiveImpl(SMLoc loc, FormatToken::Kind kind,
2772 Context ctx) {
2773 switch (kind) {
2774 case FormatToken::kw_attr_dict:
2775 return parseAttrDictDirective(loc, ctx,
2776 /*withKeyword=*/false);
2777 case FormatToken::kw_attr_dict_w_keyword:
2778 return parseAttrDictDirective(loc, ctx,
2779 /*withKeyword=*/true);
2780 case FormatToken::kw_functional_type:
2781 return parseFunctionalTypeDirective(loc, ctx);
2782 case FormatToken::kw_operands:
2783 return parseOperandsDirective(loc, ctx);
2784 case FormatToken::kw_qualified:
2785 return parseQualifiedDirective(loc, ctx);
2786 case FormatToken::kw_regions:
2787 return parseRegionsDirective(loc, ctx);
2788 case FormatToken::kw_results:
2789 return parseResultsDirective(loc, ctx);
2790 case FormatToken::kw_successors:
2791 return parseSuccessorsDirective(loc, ctx);
2792 case FormatToken::kw_ref:
2793 return parseReferenceDirective(loc, ctx);
2794 case FormatToken::kw_type:
2795 return parseTypeDirective(loc, ctx);
2796 case FormatToken::kw_oilist:
2797 return parseOIListDirective(loc, ctx);
2798
2799 default:
2800 return emitError(loc, "unsupported directive kind");
2801 }
2802 }
2803
2804 FailureOr<FormatElement *>
parseAttrDictDirective(SMLoc loc,Context context,bool withKeyword)2805 OpFormatParser::parseAttrDictDirective(SMLoc loc, Context context,
2806 bool withKeyword) {
2807 if (context == TypeDirectiveContext)
2808 return emitError(loc, "'attr-dict' directive can only be used as a "
2809 "top-level directive");
2810
2811 if (context == RefDirectiveContext) {
2812 if (!hasAttrDict)
2813 return emitError(loc, "'ref' of 'attr-dict' is not bound by a prior "
2814 "'attr-dict' directive");
2815
2816 // Otherwise, this is a top-level context.
2817 } else {
2818 if (hasAttrDict)
2819 return emitError(loc, "'attr-dict' directive has already been seen");
2820 hasAttrDict = true;
2821 }
2822
2823 return create<AttrDictDirective>(withKeyword);
2824 }
2825
verifyCustomDirectiveArguments(SMLoc loc,ArrayRef<FormatElement * > arguments)2826 LogicalResult OpFormatParser::verifyCustomDirectiveArguments(
2827 SMLoc loc, ArrayRef<FormatElement *> arguments) {
2828 for (FormatElement *argument : arguments) {
2829 if (!isa<RefDirective, TypeDirective, AttrDictDirective, AttributeVariable,
2830 OperandVariable, RegionVariable, SuccessorVariable>(argument)) {
2831 // TODO: FormatElement should have location info attached.
2832 return emitError(loc, "only variables and types may be used as "
2833 "parameters to a custom directive");
2834 }
2835 if (auto *type = dyn_cast<TypeDirective>(argument)) {
2836 if (!isa<OperandVariable, ResultVariable>(type->getArg())) {
2837 return emitError(loc, "type directives within a custom directive may "
2838 "only refer to variables");
2839 }
2840 }
2841 }
2842 return success();
2843 }
2844
2845 FailureOr<FormatElement *>
parseFunctionalTypeDirective(SMLoc loc,Context context)2846 OpFormatParser::parseFunctionalTypeDirective(SMLoc loc, Context context) {
2847 if (context != TopLevelContext)
2848 return emitError(
2849 loc, "'functional-type' is only valid as a top-level directive");
2850
2851 // Parse the main operand.
2852 FailureOr<FormatElement *> inputs, results;
2853 if (failed(parseToken(FormatToken::l_paren,
2854 "expected '(' before argument list")) ||
2855 failed(inputs = parseTypeDirectiveOperand(loc)) ||
2856 failed(parseToken(FormatToken::comma,
2857 "expected ',' after inputs argument")) ||
2858 failed(results = parseTypeDirectiveOperand(loc)) ||
2859 failed(
2860 parseToken(FormatToken::r_paren, "expected ')' after argument list")))
2861 return failure();
2862 return create<FunctionalTypeDirective>(*inputs, *results);
2863 }
2864
2865 FailureOr<FormatElement *>
parseOperandsDirective(SMLoc loc,Context context)2866 OpFormatParser::parseOperandsDirective(SMLoc loc, Context context) {
2867 if (context == RefDirectiveContext) {
2868 if (!fmt.allOperands)
2869 return emitError(loc, "'ref' of 'operands' is not bound by a prior "
2870 "'operands' directive");
2871
2872 } else if (context == TopLevelContext || context == CustomDirectiveContext) {
2873 if (fmt.allOperands || !seenOperands.empty())
2874 return emitError(loc, "'operands' directive creates overlap in format");
2875 fmt.allOperands = true;
2876 }
2877 return create<OperandsDirective>();
2878 }
2879
2880 FailureOr<FormatElement *>
parseReferenceDirective(SMLoc loc,Context context)2881 OpFormatParser::parseReferenceDirective(SMLoc loc, Context context) {
2882 if (context != CustomDirectiveContext)
2883 return emitError(loc, "'ref' is only valid within a `custom` directive");
2884
2885 FailureOr<FormatElement *> arg;
2886 if (failed(parseToken(FormatToken::l_paren,
2887 "expected '(' before argument list")) ||
2888 failed(arg = parseElement(RefDirectiveContext)) ||
2889 failed(
2890 parseToken(FormatToken::r_paren, "expected ')' after argument list")))
2891 return failure();
2892
2893 return create<RefDirective>(*arg);
2894 }
2895
2896 FailureOr<FormatElement *>
parseRegionsDirective(SMLoc loc,Context context)2897 OpFormatParser::parseRegionsDirective(SMLoc loc, Context context) {
2898 if (context == TypeDirectiveContext)
2899 return emitError(loc, "'regions' is only valid as a top-level directive");
2900 if (context == RefDirectiveContext) {
2901 if (!hasAllRegions)
2902 return emitError(loc, "'ref' of 'regions' is not bound by a prior "
2903 "'regions' directive");
2904
2905 // Otherwise, this is a TopLevel directive.
2906 } else {
2907 if (hasAllRegions || !seenRegions.empty())
2908 return emitError(loc, "'regions' directive creates overlap in format");
2909 hasAllRegions = true;
2910 }
2911 return create<RegionsDirective>();
2912 }
2913
2914 FailureOr<FormatElement *>
parseResultsDirective(SMLoc loc,Context context)2915 OpFormatParser::parseResultsDirective(SMLoc loc, Context context) {
2916 if (context != TypeDirectiveContext)
2917 return emitError(loc, "'results' directive can can only be used as a child "
2918 "to a 'type' directive");
2919 return create<ResultsDirective>();
2920 }
2921
2922 FailureOr<FormatElement *>
parseSuccessorsDirective(SMLoc loc,Context context)2923 OpFormatParser::parseSuccessorsDirective(SMLoc loc, Context context) {
2924 if (context == TypeDirectiveContext)
2925 return emitError(loc,
2926 "'successors' is only valid as a top-level directive");
2927 if (context == RefDirectiveContext) {
2928 if (!hasAllSuccessors)
2929 return emitError(loc, "'ref' of 'successors' is not bound by a prior "
2930 "'successors' directive");
2931
2932 // Otherwise, this is a TopLevel directive.
2933 } else {
2934 if (hasAllSuccessors || !seenSuccessors.empty())
2935 return emitError(loc, "'successors' directive creates overlap in format");
2936 hasAllSuccessors = true;
2937 }
2938 return create<SuccessorsDirective>();
2939 }
2940
2941 FailureOr<FormatElement *>
parseOIListDirective(SMLoc loc,Context context)2942 OpFormatParser::parseOIListDirective(SMLoc loc, Context context) {
2943 if (failed(parseToken(FormatToken::l_paren,
2944 "expected '(' before oilist argument list")))
2945 return failure();
2946 std::vector<FormatElement *> literalElements;
2947 std::vector<std::vector<FormatElement *>> parsingElements;
2948 do {
2949 FailureOr<FormatElement *> lelement = parseLiteral(context);
2950 if (failed(lelement))
2951 return failure();
2952 literalElements.push_back(*lelement);
2953 parsingElements.emplace_back();
2954 std::vector<FormatElement *> &currParsingElements = parsingElements.back();
2955 while (peekToken().getKind() != FormatToken::pipe &&
2956 peekToken().getKind() != FormatToken::r_paren) {
2957 FailureOr<FormatElement *> pelement = parseElement(context);
2958 if (failed(pelement) ||
2959 failed(verifyOIListParsingElement(*pelement, loc)))
2960 return failure();
2961 currParsingElements.push_back(*pelement);
2962 }
2963 if (peekToken().getKind() == FormatToken::pipe) {
2964 consumeToken();
2965 continue;
2966 }
2967 if (peekToken().getKind() == FormatToken::r_paren) {
2968 consumeToken();
2969 break;
2970 }
2971 } while (true);
2972
2973 return create<OIListElement>(std::move(literalElements),
2974 std::move(parsingElements));
2975 }
2976
verifyOIListParsingElement(FormatElement * element,SMLoc loc)2977 LogicalResult OpFormatParser::verifyOIListParsingElement(FormatElement *element,
2978 SMLoc loc) {
2979 SmallVector<VariableElement *> vars;
2980 collect(element, vars);
2981 for (VariableElement *elem : vars) {
2982 LogicalResult res =
2983 TypeSwitch<FormatElement *, LogicalResult>(elem)
2984 // Only optional attributes can be within an oilist parsing group.
2985 .Case([&](AttributeVariable *attrEle) {
2986 if (!attrEle->getVar()->attr.isOptional() &&
2987 !attrEle->getVar()->attr.hasDefaultValue())
2988 return emitError(loc, "only optional attributes can be used in "
2989 "an oilist parsing group");
2990 return success();
2991 })
2992 // Only optional-like(i.e. variadic) operands can be within an
2993 // oilist parsing group.
2994 .Case([&](OperandVariable *ele) {
2995 if (!ele->getVar()->isVariableLength())
2996 return emitError(loc, "only variable length operands can be "
2997 "used within an oilist parsing group");
2998 return success();
2999 })
3000 // Only optional-like(i.e. variadic) results can be within an oilist
3001 // parsing group.
3002 .Case([&](ResultVariable *ele) {
3003 if (!ele->getVar()->isVariableLength())
3004 return emitError(loc, "only variable length results can be "
3005 "used within an oilist parsing group");
3006 return success();
3007 })
3008 .Case([&](RegionVariable *) { return success(); })
3009 .Default([&](FormatElement *) {
3010 return emitError(loc,
3011 "only literals, types, and variables can be "
3012 "used within an oilist group");
3013 });
3014 if (failed(res))
3015 return failure();
3016 }
3017 return success();
3018 }
3019
parseTypeDirective(SMLoc loc,Context context)3020 FailureOr<FormatElement *> OpFormatParser::parseTypeDirective(SMLoc loc,
3021 Context context) {
3022 if (context == TypeDirectiveContext)
3023 return emitError(loc, "'type' cannot be used as a child of another `type`");
3024
3025 bool isRefChild = context == RefDirectiveContext;
3026 FailureOr<FormatElement *> operand;
3027 if (failed(parseToken(FormatToken::l_paren,
3028 "expected '(' before argument list")) ||
3029 failed(operand = parseTypeDirectiveOperand(loc, isRefChild)) ||
3030 failed(
3031 parseToken(FormatToken::r_paren, "expected ')' after argument list")))
3032 return failure();
3033
3034 return create<TypeDirective>(*operand);
3035 }
3036
3037 FailureOr<FormatElement *>
parseQualifiedDirective(SMLoc loc,Context context)3038 OpFormatParser::parseQualifiedDirective(SMLoc loc, Context context) {
3039 FailureOr<FormatElement *> element;
3040 if (failed(parseToken(FormatToken::l_paren,
3041 "expected '(' before argument list")) ||
3042 failed(element = parseElement(context)) ||
3043 failed(
3044 parseToken(FormatToken::r_paren, "expected ')' after argument list")))
3045 return failure();
3046 return TypeSwitch<FormatElement *, FailureOr<FormatElement *>>(*element)
3047 .Case<AttributeVariable, TypeDirective>([](auto *element) {
3048 element->setShouldBeQualified();
3049 return element;
3050 })
3051 .Default([&](auto *element) {
3052 return this->emitError(
3053 loc,
3054 "'qualified' directive expects an attribute or a `type` directive");
3055 });
3056 }
3057
3058 FailureOr<FormatElement *>
parseTypeDirectiveOperand(SMLoc loc,bool isRefChild)3059 OpFormatParser::parseTypeDirectiveOperand(SMLoc loc, bool isRefChild) {
3060 FailureOr<FormatElement *> result = parseElement(TypeDirectiveContext);
3061 if (failed(result))
3062 return failure();
3063
3064 FormatElement *element = *result;
3065 if (isa<LiteralElement>(element))
3066 return emitError(
3067 loc, "'type' directive operand expects variable or directive operand");
3068
3069 if (auto *var = dyn_cast<OperandVariable>(element)) {
3070 unsigned opIdx = var->getVar() - op.operand_begin();
3071 if (!isRefChild && (fmt.allOperandTypes || seenOperandTypes.test(opIdx)))
3072 return emitError(loc, "'type' of '" + var->getVar()->name +
3073 "' is already bound");
3074 if (isRefChild && !(fmt.allOperandTypes || seenOperandTypes.test(opIdx)))
3075 return emitError(loc, "'ref' of 'type($" + var->getVar()->name +
3076 ")' is not bound by a prior 'type' directive");
3077 seenOperandTypes.set(opIdx);
3078 } else if (auto *var = dyn_cast<ResultVariable>(element)) {
3079 unsigned resIdx = var->getVar() - op.result_begin();
3080 if (!isRefChild && (fmt.allResultTypes || seenResultTypes.test(resIdx)))
3081 return emitError(loc, "'type' of '" + var->getVar()->name +
3082 "' is already bound");
3083 if (isRefChild && !(fmt.allResultTypes || seenResultTypes.test(resIdx)))
3084 return emitError(loc, "'ref' of 'type($" + var->getVar()->name +
3085 ")' is not bound by a prior 'type' directive");
3086 seenResultTypes.set(resIdx);
3087 } else if (isa<OperandsDirective>(&*element)) {
3088 if (!isRefChild && (fmt.allOperandTypes || seenOperandTypes.any()))
3089 return emitError(loc, "'operands' 'type' is already bound");
3090 if (isRefChild && !fmt.allOperandTypes)
3091 return emitError(loc, "'ref' of 'type(operands)' is not bound by a prior "
3092 "'type' directive");
3093 fmt.allOperandTypes = true;
3094 } else if (isa<ResultsDirective>(&*element)) {
3095 if (!isRefChild && (fmt.allResultTypes || seenResultTypes.any()))
3096 return emitError(loc, "'results' 'type' is already bound");
3097 if (isRefChild && !fmt.allResultTypes)
3098 return emitError(loc, "'ref' of 'type(results)' is not bound by a prior "
3099 "'type' directive");
3100 fmt.allResultTypes = true;
3101 } else {
3102 return emitError(loc, "invalid argument to 'type' directive");
3103 }
3104 return element;
3105 }
3106
3107 LogicalResult
verifyOptionalGroupElements(SMLoc loc,ArrayRef<FormatElement * > elements,Optional<unsigned> anchorIndex)3108 OpFormatParser::verifyOptionalGroupElements(SMLoc loc,
3109 ArrayRef<FormatElement *> elements,
3110 Optional<unsigned> anchorIndex) {
3111 for (auto &it : llvm::enumerate(elements)) {
3112 if (failed(verifyOptionalGroupElement(
3113 loc, it.value(), anchorIndex && *anchorIndex == it.index())))
3114 return failure();
3115 }
3116 return success();
3117 }
3118
verifyOptionalGroupElement(SMLoc loc,FormatElement * element,bool isAnchor)3119 LogicalResult OpFormatParser::verifyOptionalGroupElement(SMLoc loc,
3120 FormatElement *element,
3121 bool isAnchor) {
3122 return TypeSwitch<FormatElement *, LogicalResult>(element)
3123 // All attributes can be within the optional group, but only optional
3124 // attributes can be the anchor.
3125 .Case([&](AttributeVariable *attrEle) {
3126 if (isAnchor && !attrEle->getVar()->attr.isOptional())
3127 return emitError(loc, "only optional attributes can be used to "
3128 "anchor an optional group");
3129 return success();
3130 })
3131 // Only optional-like(i.e. variadic) operands can be within an optional
3132 // group.
3133 .Case([&](OperandVariable *ele) {
3134 if (!ele->getVar()->isVariableLength())
3135 return emitError(loc, "only variable length operands can be used "
3136 "within an optional group");
3137 return success();
3138 })
3139 // Only optional-like(i.e. variadic) results can be within an optional
3140 // group.
3141 .Case([&](ResultVariable *ele) {
3142 if (!ele->getVar()->isVariableLength())
3143 return emitError(loc, "only variable length results can be used "
3144 "within an optional group");
3145 return success();
3146 })
3147 .Case([&](RegionVariable *) {
3148 // TODO: When ODS has proper support for marking "optional" regions, add
3149 // a check here.
3150 return success();
3151 })
3152 .Case([&](TypeDirective *ele) {
3153 return verifyOptionalGroupElement(loc, ele->getArg(),
3154 /*isAnchor=*/false);
3155 })
3156 .Case([&](FunctionalTypeDirective *ele) {
3157 if (failed(verifyOptionalGroupElement(loc, ele->getInputs(),
3158 /*isAnchor=*/false)))
3159 return failure();
3160 return verifyOptionalGroupElement(loc, ele->getResults(),
3161 /*isAnchor=*/false);
3162 })
3163 // Literals, whitespace, and custom directives may be used, but they can't
3164 // anchor the group.
3165 .Case<LiteralElement, WhitespaceElement, CustomDirective,
3166 FunctionalTypeDirective, OptionalElement>([&](FormatElement *) {
3167 if (isAnchor)
3168 return emitError(loc, "only variables and types can be used "
3169 "to anchor an optional group");
3170 return success();
3171 })
3172 .Default([&](FormatElement *) {
3173 return emitError(loc, "only literals, types, and variables can be "
3174 "used within an optional group");
3175 });
3176 }
3177
3178 //===----------------------------------------------------------------------===//
3179 // Interface
3180 //===----------------------------------------------------------------------===//
3181
generateOpFormat(const Operator & constOp,OpClass & opClass)3182 void mlir::tblgen::generateOpFormat(const Operator &constOp, OpClass &opClass) {
3183 // TODO: Operator doesn't expose all necessary functionality via
3184 // the const interface.
3185 Operator &op = const_cast<Operator &>(constOp);
3186 if (!op.hasAssemblyFormat())
3187 return;
3188
3189 // Parse the format description.
3190 llvm::SourceMgr mgr;
3191 mgr.AddNewSourceBuffer(
3192 llvm::MemoryBuffer::getMemBuffer(op.getAssemblyFormat()), SMLoc());
3193 OperationFormat format(op);
3194 OpFormatParser parser(mgr, format, op);
3195 FailureOr<std::vector<FormatElement *>> elements = parser.parse();
3196 if (failed(elements)) {
3197 // Exit the process if format errors are treated as fatal.
3198 if (formatErrorIsFatal) {
3199 // Invoke the interrupt handlers to run the file cleanup handlers.
3200 llvm::sys::RunInterruptHandlers();
3201 std::exit(1);
3202 }
3203 return;
3204 }
3205 format.elements = std::move(*elements);
3206
3207 // Generate the printer and parser based on the parsed format.
3208 format.genParser(op, opClass);
3209 format.genPrinter(op, opClass);
3210 }
3211