1 //===- OpFormatGen.cpp - MLIR operation asm format generator --------------===// 2 // 3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. 4 // See https://llvm.org/LICENSE.txt for license information. 5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception 6 // 7 //===----------------------------------------------------------------------===// 8 9 #include "OpFormatGen.h" 10 #include "FormatGen.h" 11 #include "OpClass.h" 12 #include "mlir/Support/LogicalResult.h" 13 #include "mlir/TableGen/Class.h" 14 #include "mlir/TableGen/Format.h" 15 #include "mlir/TableGen/GenInfo.h" 16 #include "mlir/TableGen/Interfaces.h" 17 #include "mlir/TableGen/Operator.h" 18 #include "mlir/TableGen/Trait.h" 19 #include "llvm/ADT/MapVector.h" 20 #include "llvm/ADT/Sequence.h" 21 #include "llvm/ADT/SetVector.h" 22 #include "llvm/ADT/SmallBitVector.h" 23 #include "llvm/ADT/StringExtras.h" 24 #include "llvm/ADT/TypeSwitch.h" 25 #include "llvm/Support/Signals.h" 26 #include "llvm/TableGen/Error.h" 27 #include "llvm/TableGen/Record.h" 28 29 #define DEBUG_TYPE "mlir-tblgen-opformatgen" 30 31 using namespace mlir; 32 using namespace mlir::tblgen; 33 34 //===----------------------------------------------------------------------===// 35 // Element 36 //===----------------------------------------------------------------------===// 37 38 namespace { 39 /// This class represents a single format element. 40 class Element { 41 public: 42 enum class Kind { 43 /// This element is a directive. 44 AttrDictDirective, 45 CustomDirective, 46 FunctionalTypeDirective, 47 OperandsDirective, 48 RefDirective, 49 RegionsDirective, 50 ResultsDirective, 51 SuccessorsDirective, 52 TypeDirective, 53 54 /// This element is a literal. 55 Literal, 56 57 /// This element is a whitespace. 58 Newline, 59 Space, 60 61 /// This element is an variable value. 
62 AttributeVariable, 63 OperandVariable, 64 RegionVariable, 65 ResultVariable, 66 SuccessorVariable, 67 68 /// This element is an optional element. 69 Optional, 70 }; 71 Element(Kind kind) : kind(kind) {} 72 virtual ~Element() = default; 73 74 /// Return the kind of this element. 75 Kind getKind() const { return kind; } 76 77 private: 78 /// The kind of this element. 79 Kind kind; 80 }; 81 } // namespace 82 83 //===----------------------------------------------------------------------===// 84 // VariableElement 85 86 namespace { 87 /// This class represents an instance of an variable element. A variable refers 88 /// to something registered on the operation itself, e.g. an argument, result, 89 /// etc. 90 template <typename VarT, Element::Kind kindVal> 91 class VariableElement : public Element { 92 public: 93 VariableElement(const VarT *var) : Element(kindVal), var(var) {} 94 static bool classof(const Element *element) { 95 return element->getKind() == kindVal; 96 } 97 const VarT *getVar() { return var; } 98 99 protected: 100 const VarT *var; 101 }; 102 103 /// This class represents a variable that refers to an attribute argument. 104 struct AttributeVariable 105 : public VariableElement<NamedAttribute, Element::Kind::AttributeVariable> { 106 using VariableElement<NamedAttribute, 107 Element::Kind::AttributeVariable>::VariableElement; 108 109 /// Return the constant builder call for the type of this attribute, or None 110 /// if it doesn't have one. 111 Optional<StringRef> getTypeBuilder() const { 112 Optional<Type> attrType = var->attr.getValueType(); 113 return attrType ? attrType->getBuilderCall() : llvm::None; 114 } 115 116 /// Return if this attribute refers to a UnitAttr. 117 bool isUnitAttr() const { 118 return var->attr.getBaseAttr().getAttrDefName() == "UnitAttr"; 119 } 120 121 /// Indicate if this attribute is printed "qualified" (that is it is 122 /// prefixed with the `#dialect.mnemonic`). 
123 bool shouldBeQualified() { return shouldBeQualifiedFlag; } 124 void setShouldBeQualified(bool qualified = true) { 125 shouldBeQualifiedFlag = qualified; 126 } 127 128 private: 129 bool shouldBeQualifiedFlag = false; 130 }; 131 132 /// This class represents a variable that refers to an operand argument. 133 using OperandVariable = 134 VariableElement<NamedTypeConstraint, Element::Kind::OperandVariable>; 135 136 /// This class represents a variable that refers to a region. 137 using RegionVariable = 138 VariableElement<NamedRegion, Element::Kind::RegionVariable>; 139 140 /// This class represents a variable that refers to a result. 141 using ResultVariable = 142 VariableElement<NamedTypeConstraint, Element::Kind::ResultVariable>; 143 144 /// This class represents a variable that refers to a successor. 145 using SuccessorVariable = 146 VariableElement<NamedSuccessor, Element::Kind::SuccessorVariable>; 147 } // namespace 148 149 //===----------------------------------------------------------------------===// 150 // DirectiveElement 151 152 namespace { 153 /// This class implements single kind directives. 154 template <Element::Kind type> class DirectiveElement : public Element { 155 public: 156 DirectiveElement() : Element(type){}; 157 static bool classof(const Element *ele) { return ele->getKind() == type; } 158 }; 159 /// This class represents the `operands` directive. This directive represents 160 /// all of the operands of an operation. 161 using OperandsDirective = DirectiveElement<Element::Kind::OperandsDirective>; 162 163 /// This class represents the `regions` directive. This directive represents 164 /// all of the regions of an operation. 165 using RegionsDirective = DirectiveElement<Element::Kind::RegionsDirective>; 166 167 /// This class represents the `results` directive. This directive represents 168 /// all of the results of an operation. 
169 using ResultsDirective = DirectiveElement<Element::Kind::ResultsDirective>; 170 171 /// This class represents the `successors` directive. This directive represents 172 /// all of the successors of an operation. 173 using SuccessorsDirective = 174 DirectiveElement<Element::Kind::SuccessorsDirective>; 175 176 /// This class represents the `attr-dict` directive. This directive represents 177 /// the attribute dictionary of the operation. 178 class AttrDictDirective 179 : public DirectiveElement<Element::Kind::AttrDictDirective> { 180 public: 181 explicit AttrDictDirective(bool withKeyword) : withKeyword(withKeyword) {} 182 bool isWithKeyword() const { return withKeyword; } 183 184 private: 185 /// If the dictionary should be printed with the 'attributes' keyword. 186 bool withKeyword; 187 }; 188 189 /// This class represents a custom format directive that is implemented by the 190 /// user in C++. 191 class CustomDirective : public Element { 192 public: 193 CustomDirective(StringRef name, 194 std::vector<std::unique_ptr<Element>> &&arguments) 195 : Element{Kind::CustomDirective}, name(name), 196 arguments(std::move(arguments)) {} 197 198 static bool classof(const Element *element) { 199 return element->getKind() == Kind::CustomDirective; 200 } 201 202 /// Return the name of the custom directive. 203 StringRef getName() const { return name; } 204 205 /// Return the arguments to the custom directive. 206 auto getArguments() const { return llvm::make_pointee_range(arguments); } 207 208 private: 209 /// The user provided name of the directive. 210 StringRef name; 211 212 /// The arguments to the custom directive. 213 std::vector<std::unique_ptr<Element>> arguments; 214 }; 215 216 /// This class represents the `functional-type` directive. This directive takes 217 /// two arguments and formats them, respectively, as the inputs and results of a 218 /// FunctionType. 
219 class FunctionalTypeDirective 220 : public DirectiveElement<Element::Kind::FunctionalTypeDirective> { 221 public: 222 FunctionalTypeDirective(std::unique_ptr<Element> inputs, 223 std::unique_ptr<Element> results) 224 : inputs(std::move(inputs)), results(std::move(results)) {} 225 Element *getInputs() const { return inputs.get(); } 226 Element *getResults() const { return results.get(); } 227 228 private: 229 /// The input and result arguments. 230 std::unique_ptr<Element> inputs, results; 231 }; 232 233 /// This class represents the `ref` directive. 234 class RefDirective : public DirectiveElement<Element::Kind::RefDirective> { 235 public: 236 RefDirective(std::unique_ptr<Element> arg) : operand(std::move(arg)) {} 237 Element *getOperand() const { return operand.get(); } 238 239 private: 240 /// The operand that is used to format the directive. 241 std::unique_ptr<Element> operand; 242 }; 243 244 /// This class represents the `type` directive. 245 class TypeDirective : public DirectiveElement<Element::Kind::TypeDirective> { 246 public: 247 TypeDirective(std::unique_ptr<Element> arg) : operand(std::move(arg)) {} 248 Element *getOperand() const { return operand.get(); } 249 250 /// Indicate if this type is printed "qualified" (that is it is 251 /// prefixed with the `!dialect.mnemonic`). 252 bool shouldBeQualified() { return shouldBeQualifiedFlag; } 253 void setShouldBeQualified(bool qualified = true) { 254 shouldBeQualifiedFlag = qualified; 255 } 256 257 private: 258 /// The operand that is used to format the directive. 259 std::unique_ptr<Element> operand; 260 261 bool shouldBeQualifiedFlag = false; 262 }; 263 } // namespace 264 265 //===----------------------------------------------------------------------===// 266 // LiteralElement 267 268 namespace { 269 /// This class represents an instance of a literal element. 
270 class LiteralElement : public Element { 271 public: 272 LiteralElement(StringRef literal) 273 : Element{Kind::Literal}, literal(literal) {} 274 static bool classof(const Element *element) { 275 return element->getKind() == Kind::Literal; 276 } 277 278 /// Return the literal for this element. 279 StringRef getLiteral() const { return literal; } 280 281 private: 282 /// The spelling of the literal for this element. 283 StringRef literal; 284 }; 285 } // namespace 286 287 //===----------------------------------------------------------------------===// 288 // WhitespaceElement 289 290 namespace { 291 /// This class represents a whitespace element, e.g. newline or space. It's a 292 /// literal that is printed but never parsed. 293 class WhitespaceElement : public Element { 294 public: 295 WhitespaceElement(Kind kind) : Element{kind} {} 296 static bool classof(const Element *element) { 297 Kind kind = element->getKind(); 298 return kind == Kind::Newline || kind == Kind::Space; 299 } 300 }; 301 302 /// This class represents an instance of a newline element. It's a literal that 303 /// prints a newline. It is ignored by the parser. 304 class NewlineElement : public WhitespaceElement { 305 public: 306 NewlineElement() : WhitespaceElement(Kind::Newline) {} 307 static bool classof(const Element *element) { 308 return element->getKind() == Kind::Newline; 309 } 310 }; 311 312 /// This class represents an instance of a space element. It's a literal that 313 /// prints or omits printing a space. It is ignored by the parser. 314 class SpaceElement : public WhitespaceElement { 315 public: 316 SpaceElement(bool value) : WhitespaceElement(Kind::Space), value(value) {} 317 static bool classof(const Element *element) { 318 return element->getKind() == Kind::Space; 319 } 320 321 /// Returns true if this element should print as a space. Otherwise, the 322 /// element should omit printing a space between the surrounding elements. 
323 bool getValue() const { return value; } 324 325 private: 326 bool value; 327 }; 328 } // namespace 329 330 //===----------------------------------------------------------------------===// 331 // OptionalElement 332 333 namespace { 334 /// This class represents a group of elements that are optionally emitted based 335 /// upon an optional variable of the operation, and a group of elements that are 336 /// emotted when the anchor element is not present. 337 class OptionalElement : public Element { 338 public: 339 OptionalElement(std::vector<std::unique_ptr<Element>> &&thenElements, 340 std::vector<std::unique_ptr<Element>> &&elseElements, 341 unsigned anchor, unsigned parseStart) 342 : Element{Kind::Optional}, thenElements(std::move(thenElements)), 343 elseElements(std::move(elseElements)), anchor(anchor), 344 parseStart(parseStart) {} 345 static bool classof(const Element *element) { 346 return element->getKind() == Kind::Optional; 347 } 348 349 /// Return the `then` elements of this grouping. 350 auto getThenElements() const { 351 return llvm::make_pointee_range(thenElements); 352 } 353 354 /// Return the `else` elements of this grouping. 355 auto getElseElements() const { 356 return llvm::make_pointee_range(elseElements); 357 } 358 359 /// Return the anchor of this optional group. 360 Element *getAnchor() const { return thenElements[anchor].get(); } 361 362 /// Return the index of the first element that needs to be parsed. 363 unsigned getParseStart() const { return parseStart; } 364 365 private: 366 /// The child elements of `then` branch of this optional. 367 std::vector<std::unique_ptr<Element>> thenElements; 368 /// The child elements of `else` branch of this optional. 369 std::vector<std::unique_ptr<Element>> elseElements; 370 /// The index of the element that acts as the anchor for the optional group. 371 unsigned anchor; 372 /// The index of the first element that is parsed (is not a 373 /// WhitespaceElement). 
374 unsigned parseStart; 375 }; 376 } // namespace 377 378 //===----------------------------------------------------------------------===// 379 // OperationFormat 380 //===----------------------------------------------------------------------===// 381 382 namespace { 383 384 using ConstArgument = 385 llvm::PointerUnion<const NamedAttribute *, const NamedTypeConstraint *>; 386 387 struct OperationFormat { 388 /// This class represents a specific resolver for an operand or result type. 389 class TypeResolution { 390 public: 391 TypeResolution() = default; 392 393 /// Get the index into the buildable types for this type, or None. 394 Optional<int> getBuilderIdx() const { return builderIdx; } 395 void setBuilderIdx(int idx) { builderIdx = idx; } 396 397 /// Get the variable this type is resolved to, or nullptr. 398 const NamedTypeConstraint *getVariable() const { 399 return resolver.dyn_cast<const NamedTypeConstraint *>(); 400 } 401 /// Get the attribute this type is resolved to, or nullptr. 402 const NamedAttribute *getAttribute() const { 403 return resolver.dyn_cast<const NamedAttribute *>(); 404 } 405 /// Get the transformer for the type of the variable, or None. 406 Optional<StringRef> getVarTransformer() const { 407 return variableTransformer; 408 } 409 void setResolver(ConstArgument arg, Optional<StringRef> transformer) { 410 resolver = arg; 411 variableTransformer = transformer; 412 assert(getVariable() || getAttribute()); 413 } 414 415 private: 416 /// If the type is resolved with a buildable type, this is the index into 417 /// 'buildableTypes' in the parent format. 418 Optional<int> builderIdx; 419 /// If the type is resolved based upon another operand or result, this is 420 /// the variable or the attribute that this type is resolved to. 421 ConstArgument resolver; 422 /// If the type is resolved based upon another operand or result, this is 423 /// a transformer to apply to the variable when resolving. 
424 Optional<StringRef> variableTransformer; 425 }; 426 427 /// The context in which an element is generated. 428 enum class GenContext { 429 /// The element is generated at the top-level or with the same behaviour. 430 Normal, 431 /// The element is generated inside an optional group. 432 Optional 433 }; 434 435 OperationFormat(const Operator &op) 436 : allOperands(false), allOperandTypes(false), allResultTypes(false), 437 infersResultTypes(false) { 438 operandTypes.resize(op.getNumOperands(), TypeResolution()); 439 resultTypes.resize(op.getNumResults(), TypeResolution()); 440 441 hasImplicitTermTrait = llvm::any_of(op.getTraits(), [](const Trait &trait) { 442 return trait.getDef().isSubClassOf("SingleBlockImplicitTerminator"); 443 }); 444 445 hasSingleBlockTrait = 446 hasImplicitTermTrait || op.getTrait("::mlir::OpTrait::SingleBlock"); 447 } 448 449 /// Generate the operation parser from this format. 450 void genParser(Operator &op, OpClass &opClass); 451 /// Generate the parser code for a specific format element. 452 void genElementParser(Element *element, MethodBody &body, 453 FmtContext &attrTypeCtx, 454 GenContext genCtx = GenContext::Normal); 455 /// Generate the C++ to resolve the types of operands and results during 456 /// parsing. 457 void genParserTypeResolution(Operator &op, MethodBody &body); 458 /// Generate the C++ to resolve the types of the operands during parsing. 459 void genParserOperandTypeResolution( 460 Operator &op, MethodBody &body, 461 function_ref<void(TypeResolution &, StringRef)> emitTypeResolver); 462 /// Generate the C++ to resolve regions during parsing. 463 void genParserRegionResolution(Operator &op, MethodBody &body); 464 /// Generate the C++ to resolve successors during parsing. 465 void genParserSuccessorResolution(Operator &op, MethodBody &body); 466 /// Generate the C++ to handling variadic segment size traits. 
467 void genParserVariadicSegmentResolution(Operator &op, MethodBody &body); 468 469 /// Generate the operation printer from this format. 470 void genPrinter(Operator &op, OpClass &opClass); 471 472 /// Generate the printer code for a specific format element. 473 void genElementPrinter(Element *element, MethodBody &body, Operator &op, 474 bool &shouldEmitSpace, bool &lastWasPunctuation); 475 476 /// The various elements in this format. 477 std::vector<std::unique_ptr<Element>> elements; 478 479 /// A flag indicating if all operand/result types were seen. If the format 480 /// contains these, it can not contain individual type resolvers. 481 bool allOperands, allOperandTypes, allResultTypes; 482 483 /// A flag indicating if this operation infers its result types 484 bool infersResultTypes; 485 486 /// A flag indicating if this operation has the SingleBlockImplicitTerminator 487 /// trait. 488 bool hasImplicitTermTrait; 489 490 /// A flag indicating if this operation has the SingleBlock trait. 491 bool hasSingleBlockTrait; 492 493 /// A map of buildable types to indices. 494 llvm::MapVector<StringRef, int, llvm::StringMap<int>> buildableTypes; 495 496 /// The index of the buildable type, if valid, for every operand and result. 497 std::vector<TypeResolution> operandTypes, resultTypes; 498 499 /// The set of attributes explicitly used within the format. 500 SmallVector<const NamedAttribute *, 8> usedAttributes; 501 llvm::StringSet<> inferredAttributes; 502 }; 503 } // namespace 504 505 //===----------------------------------------------------------------------===// 506 // Parser Gen 507 508 /// Returns true if we can format the given attribute as an EnumAttr in the 509 /// parser format. 
510 static bool canFormatEnumAttr(const NamedAttribute *attr) { 511 Attribute baseAttr = attr->attr.getBaseAttr(); 512 const EnumAttr *enumAttr = dyn_cast<EnumAttr>(&baseAttr); 513 if (!enumAttr) 514 return false; 515 516 // The attribute must have a valid underlying type and a constant builder. 517 return !enumAttr->getUnderlyingType().empty() && 518 !enumAttr->getConstBuilderTemplate().empty(); 519 } 520 521 /// Returns if we should format the given attribute as an SymbolNameAttr. 522 static bool shouldFormatSymbolNameAttr(const NamedAttribute *attr) { 523 return attr->attr.getBaseAttr().getAttrDefName() == "SymbolNameAttr"; 524 } 525 526 /// The code snippet used to generate a parser call for an attribute. 527 /// 528 /// {0}: The name of the attribute. 529 /// {1}: The type for the attribute. 530 const char *const attrParserCode = R"( 531 if (parser.parseCustomAttributeWithFallback({0}Attr, {1}, "{0}", 532 result.attributes)) {{ 533 return ::mlir::failure(); 534 } 535 )"; 536 537 /// The code snippet used to generate a parser call for an attribute. 538 /// 539 /// {0}: The name of the attribute. 540 /// {1}: The type for the attribute. 541 const char *const genericAttrParserCode = R"( 542 if (parser.parseAttribute({0}Attr, {1}, "{0}", result.attributes)) 543 return ::mlir::failure(); 544 )"; 545 546 const char *const optionalAttrParserCode = R"( 547 { 548 ::mlir::OptionalParseResult parseResult = 549 parser.parseOptionalAttribute({0}Attr, {1}, "{0}", result.attributes); 550 if (parseResult.hasValue() && failed(*parseResult)) 551 return ::mlir::failure(); 552 } 553 )"; 554 555 /// The code snippet used to generate a parser call for a symbol name attribute. 556 /// 557 /// {0}: The name of the attribute. 
/// The code snippet used to generate a parser call for a symbol name attribute.
///
/// {0}: The name of the attribute.
const char *const symbolNameAttrParserCode = R"(
  if (parser.parseSymbolName({0}Attr, "{0}", result.attributes))
    return ::mlir::failure();
)";

/// The code snippet used to generate a parser call for an optional symbol name
/// attribute. The parse result is intentionally discarded.
///
/// {0}: The name of the attribute.
const char *const optionalSymbolNameAttrParserCode = R"(
  // Parsing an optional symbol name doesn't fail, so no need to check the
  // result.
  (void)parser.parseOptionalSymbolName({0}Attr, "{0}", result.attributes);
)";

/// The code snippet used to generate a parser call for an enum attribute. The
/// enum may be spelled as a bare keyword or, failing that, as a string
/// attribute.
///
/// {0}: The name of the attribute.
/// {1}: The c++ namespace for the enum symbolize functions.
/// {2}: The function to symbolize a string of the enum.
/// {3}: The constant builder call to create an attribute of the enum type.
/// {4}: The set of allowed enum keywords.
/// {5}: The error message on failure when the enum isn't present.
const char *const enumAttrParserCode = R"(
  {
    ::llvm::StringRef attrStr;
    ::mlir::NamedAttrList attrStorage;
    auto loc = parser.getCurrentLocation();
    if (parser.parseOptionalKeyword(&attrStr, {4})) {
      ::mlir::StringAttr attrVal;
      ::mlir::OptionalParseResult parseResult =
        parser.parseOptionalAttribute(attrVal,
                                      parser.getBuilder().getNoneType(),
                                      "{0}", attrStorage);
      if (parseResult.hasValue()) {{
        if (failed(*parseResult))
          return ::mlir::failure();
        attrStr = attrVal.getValue();
      } else {
        {5}
      }
    }
    if (!attrStr.empty()) {
      auto attrOptional = {1}::{2}(attrStr);
      if (!attrOptional)
        return parser.emitError(loc, "invalid ")
               << "{0} attribute specification: \"" << attrStr << '"';

      {0}Attr = {3};
      result.addAttribute("{0}", {0}Attr);
    }
  }
)";
/// The code snippet used to generate a parser call for a variadic operand
/// list. It records the list's start location in `{0}OperandsLoc`.
///
/// {0}: The name of the operand.
const char *const variadicOperandParserCode = R"(
  {0}OperandsLoc = parser.getCurrentLocation();
  if (parser.parseOperandList({0}Operands))
    return ::mlir::failure();
)";
/// The code snippet used to generate a parser call for an optional operand.
/// The operand is appended to `{0}Operands` only when one is present.
///
/// {0}: The name of the operand.
const char *const optionalOperandParserCode = R"(
  {
    {0}OperandsLoc = parser.getCurrentLocation();
    ::mlir::OpAsmParser::OperandType operand;
    ::mlir::OptionalParseResult parseResult =
                                    parser.parseOptionalOperand(operand);
    if (parseResult.hasValue()) {
      if (failed(*parseResult))
        return ::mlir::failure();
      {0}Operands.push_back(operand);
    }
  }
)";
/// The code snippet used to generate a parser call for a single (non-variadic)
/// operand, stored into `{0}RawOperands[0]`.
///
/// {0}: The name of the operand.
const char *const operandParserCode = R"(
  {0}OperandsLoc = parser.getCurrentLocation();
  if (parser.parseOperand({0}RawOperands[0]))
    return ::mlir::failure();
)";
/// The code snippet used to generate a parser call for a VariadicOfVariadic
/// operand. It parses a comma-separated sequence of parenthesized operand
/// lists, appending every operand to `{0}Operands` and the size of each group
/// to `{0}OperandGroupSizes`.
///
/// {0}: The name of the operand.
/// {1}: The name of segment size attribute.
/// NOTE(review): `{1}` does not appear in the snippet text below — confirm
/// whether callers still need to pass it.
const char *const variadicOfVariadicOperandParserCode = R"(
  {
    {0}OperandsLoc = parser.getCurrentLocation();
    int32_t curSize = 0;
    do {
      if (parser.parseOptionalLParen())
        break;
      if (parser.parseOperandList({0}Operands) || parser.parseRParen())
        return ::mlir::failure();
      {0}OperandGroupSizes.push_back({0}Operands.size() - curSize);
      curSize = {0}Operands.size();
    } while (succeeded(parser.parseOptionalComma()));
  }
)";
/// The code snippet used to generate a parser call for a VariadicOfVariadic
/// type list: a comma-separated sequence of parenthesized type lists (a group
/// may be empty, `()`). All parsed types are appended to `{0}Types`.
///
/// {0}: The name for the type list.
const char *const variadicOfVariadicTypeParserCode = R"(
  do {
    if (parser.parseOptionalLParen())
      break;
    if (parser.parseOptionalRParen() &&
        (parser.parseTypeList({0}Types) || parser.parseRParen()))
      return ::mlir::failure();
  } while (succeeded(parser.parseOptionalComma()));
)";
/// The code snippet used to generate a parser call for a variadic type list.
///
/// {0}: The name for the type list.
const char *const variadicTypeParserCode = R"(
  if (parser.parseTypeList({0}Types))
    return ::mlir::failure();
)";
/// The code snippet used to generate a parser call for an optional type. The
/// type is appended to `{0}Types` only when one is present.
///
/// {0}: The name for the type list.
const char *const optionalTypeParserCode = R"(
  {
    ::mlir::Type optionalType;
    ::mlir::OptionalParseResult parseResult =
                                    parser.parseOptionalType(optionalType);
    if (parseResult.hasValue()) {
      if (failed(*parseResult))
        return ::mlir::failure();
      {0}Types.push_back(optionalType);
    }
  }
)";
/// The code snippet used to generate a parser call for a single type via
/// `parseCustomTypeWithFallback`.
///
/// {0}: The C++ type of the local variable the type is parsed into.
/// {1}: The name prefix of the `RawTypes` storage that receives the type
///      (presumably the operand/result name — verify against callers).
const char *const typeParserCode = R"(
  {
    {0} type;
    if (parser.parseCustomTypeWithFallback(type))
      return ::mlir::failure();
    {1}RawTypes[0] = type;
  }
)";
/// The code snippet used to generate a parser call for a single type via the
/// generic `parseType`.
///
/// Only `{1}` is referenced (same meaning as in `typeParserCode`); `{0}` is
/// unused.
const char *const qualifiedTypeParserCode = R"(
  if (parser.parseType({1}RawTypes[0]))
    return ::mlir::failure();
)";

/// The code snippet used to generate a parser call for a functional type.
///
/// {0}: The name for the input type list.
/// {1}: The name for the result type list.
const char *const functionalTypeParserCode = R"(
  ::mlir::FunctionType {0}__{1}_functionType;
  if (parser.parseType({0}__{1}_functionType))
    return ::mlir::failure();
  {0}Types = {0}__{1}_functionType.getInputs();
  {1}Types = {0}__{1}_functionType.getResults();
)";
/// The code snippet used to generate a parser call to infer return types.
///
/// {0}: The operation class name
const char *const inferReturnTypesParserCode = R"(
  ::llvm::SmallVector<::mlir::Type> inferredReturnTypes;
  if (::mlir::failed({0}::inferReturnTypes(parser.getContext(),
      result.location, result.operands,
      result.attributes.getDictionary(parser.getContext()),
      result.regions, inferredReturnTypes)))
    return ::mlir::failure();
  result.addTypes(inferredReturnTypes);
)";

// The snippets below are declared `const char *const` (like every other
// snippet in this file) so the pointers cannot be reassigned and, being
// const-qualified at namespace scope, get internal linkage.

/// The code snippet used to generate a parser call for a region list. The list
/// may be empty; otherwise it is a comma-separated sequence of regions.
///
/// {0}: The name for the region list.
const char *const regionListParserCode = R"(
  {
    std::unique_ptr<::mlir::Region> region;
    auto firstRegionResult = parser.parseOptionalRegion(region);
    if (firstRegionResult.hasValue()) {
      if (failed(*firstRegionResult))
        return ::mlir::failure();
      {0}Regions.emplace_back(std::move(region));

      // Parse any trailing regions.
      while (succeeded(parser.parseOptionalComma())) {
        region = std::make_unique<::mlir::Region>();
        if (parser.parseRegion(*region))
          return ::mlir::failure();
        {0}Regions.emplace_back(std::move(region));
      }
    }
  }
)";

/// The code snippet used to ensure a list of regions have terminators.
///
/// {0}: The name of the region list.
const char *const regionListEnsureTerminatorParserCode = R"(
  for (auto &region : {0}Regions)
    ensureTerminator(*region, parser.getBuilder(), result.location);
)";

/// The code snippet used to ensure a list of regions have a block.
///
/// {0}: The name of the region list.
const char *const regionListEnsureSingleBlockParserCode = R"(
  for (auto &region : {0}Regions)
    if (region->empty()) region->emplaceBlock();
)";
// These snippets are declared `const char *const` (like the other snippets in
// this file) so the pointers cannot be reassigned and, being const-qualified at
// namespace scope, get internal linkage.

/// The code snippet used to generate a parser call for an optional region.
/// An absent region is accepted; only a present but malformed one fails.
///
/// {0}: The name of the region.
const char *const optionalRegionParserCode = R"(
  {
     auto parseResult = parser.parseOptionalRegion(*{0}Region);
     if (parseResult.hasValue() && failed(*parseResult))
       return ::mlir::failure();
  }
)";

/// The code snippet used to generate a parser call for a region.
///
/// {0}: The name of the region.
const char *const regionParserCode = R"(
  if (parser.parseRegion(*{0}Region))
    return ::mlir::failure();
)";

/// The code snippet used to ensure a region has a terminator.
///
/// {0}: The name of the region.
const char *const regionEnsureTerminatorParserCode = R"(
  ensureTerminator(*{0}Region, parser.getBuilder(), result.location);
)";

/// The code snippet used to ensure a region has a block.
///
/// {0}: The name of the region.
const char *const regionEnsureSingleBlockParserCode = R"(
  if ({0}Region->empty()) {0}Region->emplaceBlock();
)";

/// The code snippet used to generate a parser call for a successor list. The
/// list may be empty; otherwise it is a comma-separated sequence of successors.
///
/// {0}: The name for the successor list.
const char *const successorListParserCode = R"(
  {
    ::mlir::Block *succ;
    auto firstSucc = parser.parseOptionalSuccessor(succ);
    if (firstSucc.hasValue()) {
      if (failed(*firstSucc))
        return ::mlir::failure();
      {0}Successors.emplace_back(succ);

      // Parse any trailing successors.
      while (succeeded(parser.parseOptionalComma())) {
        if (parser.parseSuccessor(succ))
          return ::mlir::failure();
        {0}Successors.emplace_back(succ);
      }
    }
  }
)";

/// The code snippet used to generate a parser call for a successor.
///
/// {0}: The name of the successor.
const char *const successorParserCode = R"(
  if (parser.parseSuccessor({0}Successor))
    return ::mlir::failure();
)";
823 enum class ArgumentLengthKind { 824 /// The argument is a variadic of a variadic, and may contain 0->N range 825 /// elements. 826 VariadicOfVariadic, 827 /// The argument is variadic, and may contain 0->N elements. 828 Variadic, 829 /// The argument is optional, and may contain 0 or 1 elements. 830 Optional, 831 /// The argument is a single element, i.e. always represents 1 element. 832 Single 833 }; 834 } // namespace 835 836 /// Get the length kind for the given constraint. 837 static ArgumentLengthKind 838 getArgumentLengthKind(const NamedTypeConstraint *var) { 839 if (var->isOptional()) 840 return ArgumentLengthKind::Optional; 841 if (var->isVariadicOfVariadic()) 842 return ArgumentLengthKind::VariadicOfVariadic; 843 if (var->isVariadic()) 844 return ArgumentLengthKind::Variadic; 845 return ArgumentLengthKind::Single; 846 } 847 848 /// Get the name used for the type list for the given type directive operand. 849 /// 'lengthKind' to the corresponding kind for the given argument. 850 static StringRef getTypeListName(Element *arg, ArgumentLengthKind &lengthKind) { 851 if (auto *operand = dyn_cast<OperandVariable>(arg)) { 852 lengthKind = getArgumentLengthKind(operand->getVar()); 853 return operand->getVar()->name; 854 } 855 if (auto *result = dyn_cast<ResultVariable>(arg)) { 856 lengthKind = getArgumentLengthKind(result->getVar()); 857 return result->getVar()->name; 858 } 859 lengthKind = ArgumentLengthKind::Variadic; 860 if (isa<OperandsDirective>(arg)) 861 return "allOperand"; 862 if (isa<ResultsDirective>(arg)) 863 return "allResult"; 864 llvm_unreachable("unknown 'type' directive argument"); 865 } 866 867 /// Generate the parser for a literal value. 868 static void genLiteralParser(StringRef value, MethodBody &body) { 869 // Handle the case of a keyword/identifier. 
870 if (value.front() == '_' || isalpha(value.front())) { 871 body << "Keyword(\"" << value << "\")"; 872 return; 873 } 874 body << (StringRef)StringSwitch<StringRef>(value) 875 .Case("->", "Arrow()") 876 .Case(":", "Colon()") 877 .Case(",", "Comma()") 878 .Case("=", "Equal()") 879 .Case("<", "Less()") 880 .Case(">", "Greater()") 881 .Case("{", "LBrace()") 882 .Case("}", "RBrace()") 883 .Case("(", "LParen()") 884 .Case(")", "RParen()") 885 .Case("[", "LSquare()") 886 .Case("]", "RSquare()") 887 .Case("?", "Question()") 888 .Case("+", "Plus()") 889 .Case("*", "Star()"); 890 } 891 892 /// Generate the storage code required for parsing the given element. 893 static void genElementParserStorage(Element *element, const Operator &op, 894 MethodBody &body) { 895 if (auto *optional = dyn_cast<OptionalElement>(element)) { 896 auto elements = optional->getThenElements(); 897 898 // If the anchor is a unit attribute, it won't be parsed directly so elide 899 // it. 900 auto *anchor = dyn_cast<AttributeVariable>(optional->getAnchor()); 901 Element *elidedAnchorElement = nullptr; 902 if (anchor && anchor != &*elements.begin() && anchor->isUnitAttr()) 903 elidedAnchorElement = anchor; 904 for (auto &childElement : elements) 905 if (&childElement != elidedAnchorElement) 906 genElementParserStorage(&childElement, op, body); 907 for (auto &childElement : optional->getElseElements()) 908 genElementParserStorage(&childElement, op, body); 909 910 } else if (auto *custom = dyn_cast<CustomDirective>(element)) { 911 for (auto ¶mElement : custom->getArguments()) 912 genElementParserStorage(¶mElement, op, body); 913 914 } else if (isa<OperandsDirective>(element)) { 915 body << " ::mlir::SmallVector<::mlir::OpAsmParser::OperandType, 4> " 916 "allOperands;\n"; 917 918 } else if (isa<RegionsDirective>(element)) { 919 body << " ::llvm::SmallVector<std::unique_ptr<::mlir::Region>, 2> " 920 "fullRegions;\n"; 921 922 } else if (isa<SuccessorsDirective>(element)) { 923 body << " 
::llvm::SmallVector<::mlir::Block *, 2> fullSuccessors;\n"; 924 925 } else if (auto *attr = dyn_cast<AttributeVariable>(element)) { 926 const NamedAttribute *var = attr->getVar(); 927 body << llvm::formatv(" {0} {1}Attr;\n", var->attr.getStorageType(), 928 var->name); 929 930 } else if (auto *operand = dyn_cast<OperandVariable>(element)) { 931 StringRef name = operand->getVar()->name; 932 if (operand->getVar()->isVariableLength()) { 933 body << " ::mlir::SmallVector<::mlir::OpAsmParser::OperandType, 4> " 934 << name << "Operands;\n"; 935 if (operand->getVar()->isVariadicOfVariadic()) { 936 body << " llvm::SmallVector<int32_t> " << name 937 << "OperandGroupSizes;\n"; 938 } 939 } else { 940 body << " ::mlir::OpAsmParser::OperandType " << name 941 << "RawOperands[1];\n" 942 << " ::llvm::ArrayRef<::mlir::OpAsmParser::OperandType> " << name 943 << "Operands(" << name << "RawOperands);"; 944 } 945 body << llvm::formatv(" ::llvm::SMLoc {0}OperandsLoc;\n" 946 " (void){0}OperandsLoc;\n", 947 name); 948 949 } else if (auto *region = dyn_cast<RegionVariable>(element)) { 950 StringRef name = region->getVar()->name; 951 if (region->getVar()->isVariadic()) { 952 body << llvm::formatv( 953 " ::llvm::SmallVector<std::unique_ptr<::mlir::Region>, 2> " 954 "{0}Regions;\n", 955 name); 956 } else { 957 body << llvm::formatv(" std::unique_ptr<::mlir::Region> {0}Region = " 958 "std::make_unique<::mlir::Region>();\n", 959 name); 960 } 961 962 } else if (auto *successor = dyn_cast<SuccessorVariable>(element)) { 963 StringRef name = successor->getVar()->name; 964 if (successor->getVar()->isVariadic()) { 965 body << llvm::formatv(" ::llvm::SmallVector<::mlir::Block *, 2> " 966 "{0}Successors;\n", 967 name); 968 } else { 969 body << llvm::formatv(" ::mlir::Block *{0}Successor = nullptr;\n", name); 970 } 971 972 } else if (auto *dir = dyn_cast<TypeDirective>(element)) { 973 ArgumentLengthKind lengthKind; 974 StringRef name = getTypeListName(dir->getOperand(), lengthKind); 975 if (lengthKind != 
ArgumentLengthKind::Single) 976 body << " ::mlir::SmallVector<::mlir::Type, 1> " << name << "Types;\n"; 977 else 978 body << llvm::formatv(" ::mlir::Type {0}RawTypes[1];\n", name) 979 << llvm::formatv( 980 " ::llvm::ArrayRef<::mlir::Type> {0}Types({0}RawTypes);\n", 981 name); 982 } else if (auto *dir = dyn_cast<FunctionalTypeDirective>(element)) { 983 ArgumentLengthKind ignored; 984 body << " ::llvm::ArrayRef<::mlir::Type> " 985 << getTypeListName(dir->getInputs(), ignored) << "Types;\n"; 986 body << " ::llvm::ArrayRef<::mlir::Type> " 987 << getTypeListName(dir->getResults(), ignored) << "Types;\n"; 988 } 989 } 990 991 /// Generate the parser for a parameter to a custom directive. 992 static void genCustomParameterParser(Element ¶m, MethodBody &body) { 993 if (auto *attr = dyn_cast<AttributeVariable>(¶m)) { 994 body << attr->getVar()->name << "Attr"; 995 } else if (isa<AttrDictDirective>(¶m)) { 996 body << "result.attributes"; 997 } else if (auto *operand = dyn_cast<OperandVariable>(¶m)) { 998 StringRef name = operand->getVar()->name; 999 ArgumentLengthKind lengthKind = getArgumentLengthKind(operand->getVar()); 1000 if (lengthKind == ArgumentLengthKind::VariadicOfVariadic) 1001 body << llvm::formatv("{0}OperandGroups", name); 1002 else if (lengthKind == ArgumentLengthKind::Variadic) 1003 body << llvm::formatv("{0}Operands", name); 1004 else if (lengthKind == ArgumentLengthKind::Optional) 1005 body << llvm::formatv("{0}Operand", name); 1006 else 1007 body << formatv("{0}RawOperands[0]", name); 1008 1009 } else if (auto *region = dyn_cast<RegionVariable>(¶m)) { 1010 StringRef name = region->getVar()->name; 1011 if (region->getVar()->isVariadic()) 1012 body << llvm::formatv("{0}Regions", name); 1013 else 1014 body << llvm::formatv("*{0}Region", name); 1015 1016 } else if (auto *successor = dyn_cast<SuccessorVariable>(¶m)) { 1017 StringRef name = successor->getVar()->name; 1018 if (successor->getVar()->isVariadic()) 1019 body << llvm::formatv("{0}Successors", name); 
1020 else 1021 body << llvm::formatv("{0}Successor", name); 1022 1023 } else if (auto *dir = dyn_cast<RefDirective>(¶m)) { 1024 genCustomParameterParser(*dir->getOperand(), body); 1025 1026 } else if (auto *dir = dyn_cast<TypeDirective>(¶m)) { 1027 ArgumentLengthKind lengthKind; 1028 StringRef listName = getTypeListName(dir->getOperand(), lengthKind); 1029 if (lengthKind == ArgumentLengthKind::VariadicOfVariadic) 1030 body << llvm::formatv("{0}TypeGroups", listName); 1031 else if (lengthKind == ArgumentLengthKind::Variadic) 1032 body << llvm::formatv("{0}Types", listName); 1033 else if (lengthKind == ArgumentLengthKind::Optional) 1034 body << llvm::formatv("{0}Type", listName); 1035 else 1036 body << formatv("{0}RawTypes[0]", listName); 1037 } else { 1038 llvm_unreachable("unknown custom directive parameter"); 1039 } 1040 } 1041 1042 /// Generate the parser for a custom directive. 1043 static void genCustomDirectiveParser(CustomDirective *dir, MethodBody &body) { 1044 body << " {\n"; 1045 1046 // Preprocess the directive variables. 1047 // * Add a local variable for optional operands and types. This provides a 1048 // better API to the user defined parser methods. 1049 // * Set the location of operand variables. 
1050 for (Element ¶m : dir->getArguments()) { 1051 if (auto *operand = dyn_cast<OperandVariable>(¶m)) { 1052 auto *var = operand->getVar(); 1053 body << " " << var->name 1054 << "OperandsLoc = parser.getCurrentLocation();\n"; 1055 if (var->isOptional()) { 1056 body << llvm::formatv( 1057 " llvm::Optional<::mlir::OpAsmParser::OperandType> " 1058 "{0}Operand;\n", 1059 var->name); 1060 } else if (var->isVariadicOfVariadic()) { 1061 body << llvm::formatv(" " 1062 "llvm::SmallVector<llvm::SmallVector<::mlir::" 1063 "OpAsmParser::OperandType>> " 1064 "{0}OperandGroups;\n", 1065 var->name); 1066 } 1067 } else if (auto *dir = dyn_cast<TypeDirective>(¶m)) { 1068 ArgumentLengthKind lengthKind; 1069 StringRef listName = getTypeListName(dir->getOperand(), lengthKind); 1070 if (lengthKind == ArgumentLengthKind::Optional) { 1071 body << llvm::formatv(" ::mlir::Type {0}Type;\n", listName); 1072 } else if (lengthKind == ArgumentLengthKind::VariadicOfVariadic) { 1073 body << llvm::formatv( 1074 " llvm::SmallVector<llvm::SmallVector<::mlir::Type>> " 1075 "{0}TypeGroups;\n", 1076 listName); 1077 } 1078 } else if (auto *dir = dyn_cast<RefDirective>(¶m)) { 1079 Element *input = dir->getOperand(); 1080 if (auto *operand = dyn_cast<OperandVariable>(input)) { 1081 if (!operand->getVar()->isOptional()) 1082 continue; 1083 body << llvm::formatv( 1084 " {0} {1}Operand = {1}Operands.empty() ? {0}() : " 1085 "{1}Operands[0];\n", 1086 "llvm::Optional<::mlir::OpAsmParser::OperandType>", 1087 operand->getVar()->name); 1088 1089 } else if (auto *type = dyn_cast<TypeDirective>(input)) { 1090 ArgumentLengthKind lengthKind; 1091 StringRef listName = getTypeListName(type->getOperand(), lengthKind); 1092 if (lengthKind == ArgumentLengthKind::Optional) { 1093 body << llvm::formatv(" ::mlir::Type {0}Type = {0}Types.empty() ? 
" 1094 "::mlir::Type() : {0}Types[0];\n", 1095 listName); 1096 } 1097 } 1098 } 1099 } 1100 1101 body << " if (parse" << dir->getName() << "(parser"; 1102 for (Element ¶m : dir->getArguments()) { 1103 body << ", "; 1104 genCustomParameterParser(param, body); 1105 } 1106 1107 body << "))\n" 1108 << " return ::mlir::failure();\n"; 1109 1110 // After parsing, add handling for any of the optional constructs. 1111 for (Element ¶m : dir->getArguments()) { 1112 if (auto *attr = dyn_cast<AttributeVariable>(¶m)) { 1113 const NamedAttribute *var = attr->getVar(); 1114 if (var->attr.isOptional()) 1115 body << llvm::formatv(" if ({0}Attr)\n ", var->name); 1116 1117 body << llvm::formatv(" result.addAttribute(\"{0}\", {0}Attr);\n", 1118 var->name); 1119 } else if (auto *operand = dyn_cast<OperandVariable>(¶m)) { 1120 const NamedTypeConstraint *var = operand->getVar(); 1121 if (var->isOptional()) { 1122 body << llvm::formatv(" if ({0}Operand.hasValue())\n" 1123 " {0}Operands.push_back(*{0}Operand);\n", 1124 var->name); 1125 } else if (var->isVariadicOfVariadic()) { 1126 body << llvm::formatv( 1127 " for (const auto &subRange : {0}OperandGroups) {{\n" 1128 " {0}Operands.append(subRange.begin(), subRange.end());\n" 1129 " {0}OperandGroupSizes.push_back(subRange.size());\n" 1130 " }\n", 1131 var->name, var->constraint.getVariadicOfVariadicSegmentSizeAttr()); 1132 } 1133 } else if (auto *dir = dyn_cast<TypeDirective>(¶m)) { 1134 ArgumentLengthKind lengthKind; 1135 StringRef listName = getTypeListName(dir->getOperand(), lengthKind); 1136 if (lengthKind == ArgumentLengthKind::Optional) { 1137 body << llvm::formatv(" if ({0}Type)\n" 1138 " {0}Types.push_back({0}Type);\n", 1139 listName); 1140 } else if (lengthKind == ArgumentLengthKind::VariadicOfVariadic) { 1141 body << llvm::formatv( 1142 " for (const auto &subRange : {0}TypeGroups)\n" 1143 " {0}Types.append(subRange.begin(), subRange.end());\n", 1144 listName); 1145 } 1146 } 1147 } 1148 1149 body << " }\n"; 1150 } 1151 1152 /// 
Generate the parser for a enum attribute. 1153 static void genEnumAttrParser(const NamedAttribute *var, MethodBody &body, 1154 FmtContext &attrTypeCtx) { 1155 Attribute baseAttr = var->attr.getBaseAttr(); 1156 const EnumAttr &enumAttr = cast<EnumAttr>(baseAttr); 1157 std::vector<EnumAttrCase> cases = enumAttr.getAllCases(); 1158 1159 // Generate the code for building an attribute for this enum. 1160 std::string attrBuilderStr; 1161 { 1162 llvm::raw_string_ostream os(attrBuilderStr); 1163 os << tgfmt(enumAttr.getConstBuilderTemplate(), &attrTypeCtx, 1164 "attrOptional.getValue()"); 1165 } 1166 1167 // Build a string containing the cases that can be formatted as a keyword. 1168 std::string validCaseKeywordsStr = "{"; 1169 llvm::raw_string_ostream validCaseKeywordsOS(validCaseKeywordsStr); 1170 for (const EnumAttrCase &attrCase : cases) 1171 if (canFormatStringAsKeyword(attrCase.getStr())) 1172 validCaseKeywordsOS << '"' << attrCase.getStr() << "\","; 1173 validCaseKeywordsOS.str().back() = '}'; 1174 1175 // If the attribute is not optional, build an error message for the missing 1176 // attribute. 
1177 std::string errorMessage; 1178 if (!var->attr.isOptional()) { 1179 llvm::raw_string_ostream errorMessageOS(errorMessage); 1180 errorMessageOS 1181 << "return parser.emitError(loc, \"expected string or " 1182 "keyword containing one of the following enum values for attribute '" 1183 << var->name << "' ["; 1184 llvm::interleaveComma(cases, errorMessageOS, [&](const auto &attrCase) { 1185 errorMessageOS << attrCase.getStr(); 1186 }); 1187 errorMessageOS << "]\");"; 1188 } 1189 1190 body << formatv(enumAttrParserCode, var->name, enumAttr.getCppNamespace(), 1191 enumAttr.getStringToSymbolFnName(), attrBuilderStr, 1192 validCaseKeywordsStr, errorMessage); 1193 } 1194 1195 void OperationFormat::genParser(Operator &op, OpClass &opClass) { 1196 SmallVector<MethodParameter> paramList; 1197 paramList.emplace_back("::mlir::OpAsmParser &", "parser"); 1198 paramList.emplace_back("::mlir::OperationState &", "result"); 1199 1200 auto *method = opClass.addStaticMethod("::mlir::ParseResult", "parse", 1201 std::move(paramList)); 1202 auto &body = method->body(); 1203 1204 // Generate variables to store the operands and type within the format. This 1205 // allows for referencing these variables in the presence of optional 1206 // groupings. 1207 for (auto &element : elements) 1208 genElementParserStorage(&*element, op, body); 1209 1210 // A format context used when parsing attributes with buildable types. 1211 FmtContext attrTypeCtx; 1212 attrTypeCtx.withBuilder("parser.getBuilder()"); 1213 1214 // Generate parsers for each of the elements. 1215 for (auto &element : elements) 1216 genElementParser(element.get(), body, attrTypeCtx); 1217 1218 // Generate the code to resolve the operand/result types and successors now 1219 // that they have been parsed. 
1220 genParserRegionResolution(op, body); 1221 genParserSuccessorResolution(op, body); 1222 genParserVariadicSegmentResolution(op, body); 1223 genParserTypeResolution(op, body); 1224 1225 body << " return ::mlir::success();\n"; 1226 } 1227 1228 void OperationFormat::genElementParser(Element *element, MethodBody &body, 1229 FmtContext &attrTypeCtx, 1230 GenContext genCtx) { 1231 /// Optional Group. 1232 if (auto *optional = dyn_cast<OptionalElement>(element)) { 1233 auto elements = llvm::drop_begin(optional->getThenElements(), 1234 optional->getParseStart()); 1235 1236 // Generate a special optional parser for the first element to gate the 1237 // parsing of the rest of the elements. 1238 Element *firstElement = &*elements.begin(); 1239 if (auto *attrVar = dyn_cast<AttributeVariable>(firstElement)) { 1240 genElementParser(attrVar, body, attrTypeCtx); 1241 body << " if (" << attrVar->getVar()->name << "Attr) {\n"; 1242 } else if (auto *literal = dyn_cast<LiteralElement>(firstElement)) { 1243 body << " if (succeeded(parser.parseOptional"; 1244 genLiteralParser(literal->getLiteral(), body); 1245 body << ")) {\n"; 1246 } else if (auto *opVar = dyn_cast<OperandVariable>(firstElement)) { 1247 genElementParser(opVar, body, attrTypeCtx); 1248 body << " if (!" << opVar->getVar()->name << "Operands.empty()) {\n"; 1249 } else if (auto *regionVar = dyn_cast<RegionVariable>(firstElement)) { 1250 const NamedRegion *region = regionVar->getVar(); 1251 if (region->isVariadic()) { 1252 genElementParser(regionVar, body, attrTypeCtx); 1253 body << " if (!" << region->name << "Regions.empty()) {\n"; 1254 } else { 1255 body << llvm::formatv(optionalRegionParserCode, region->name); 1256 body << " if (!" 
<< region->name << "Region->empty()) {\n "; 1257 if (hasImplicitTermTrait) 1258 body << llvm::formatv(regionEnsureTerminatorParserCode, region->name); 1259 else if (hasSingleBlockTrait) 1260 body << llvm::formatv(regionEnsureSingleBlockParserCode, 1261 region->name); 1262 } 1263 } 1264 1265 // If the anchor is a unit attribute, we don't need to print it. When 1266 // parsing, we will add this attribute if this group is present. 1267 Element *elidedAnchorElement = nullptr; 1268 auto *anchorAttr = dyn_cast<AttributeVariable>(optional->getAnchor()); 1269 if (anchorAttr && anchorAttr != firstElement && anchorAttr->isUnitAttr()) { 1270 elidedAnchorElement = anchorAttr; 1271 1272 // Add the anchor unit attribute to the operation state. 1273 body << " result.addAttribute(\"" << anchorAttr->getVar()->name 1274 << "\", parser.getBuilder().getUnitAttr());\n"; 1275 } 1276 1277 // Generate the rest of the elements inside an optional group. Elements in 1278 // an optional group after the guard are parsed as required. 1279 for (Element &childElement : llvm::drop_begin(elements, 1)) { 1280 if (&childElement != elidedAnchorElement) { 1281 genElementParser(&childElement, body, attrTypeCtx, 1282 GenContext::Optional); 1283 } 1284 } 1285 body << " }"; 1286 1287 // Generate the else elements. 1288 auto elseElements = optional->getElseElements(); 1289 if (!elseElements.empty()) { 1290 body << " else {\n"; 1291 for (Element &childElement : elseElements) 1292 genElementParser(&childElement, body, attrTypeCtx); 1293 body << " }"; 1294 } 1295 body << "\n"; 1296 1297 /// Literals. 1298 } else if (LiteralElement *literal = dyn_cast<LiteralElement>(element)) { 1299 body << " if (parser.parse"; 1300 genLiteralParser(literal->getLiteral(), body); 1301 body << ")\n return ::mlir::failure();\n"; 1302 1303 /// Whitespaces. 1304 } else if (isa<WhitespaceElement>(element)) { 1305 // Nothing to parse. 1306 1307 /// Arguments. 
1308 } else if (auto *attr = dyn_cast<AttributeVariable>(element)) { 1309 const NamedAttribute *var = attr->getVar(); 1310 1311 // Check to see if we can parse this as an enum attribute. 1312 if (canFormatEnumAttr(var)) 1313 return genEnumAttrParser(var, body, attrTypeCtx); 1314 1315 // Check to see if we should parse this as a symbol name attribute. 1316 if (shouldFormatSymbolNameAttr(var)) { 1317 body << formatv(var->attr.isOptional() ? optionalSymbolNameAttrParserCode 1318 : symbolNameAttrParserCode, 1319 var->name); 1320 return; 1321 } 1322 1323 // If this attribute has a buildable type, use that when parsing the 1324 // attribute. 1325 std::string attrTypeStr; 1326 if (Optional<StringRef> typeBuilder = attr->getTypeBuilder()) { 1327 llvm::raw_string_ostream os(attrTypeStr); 1328 os << tgfmt(*typeBuilder, &attrTypeCtx); 1329 } else { 1330 attrTypeStr = "::mlir::Type{}"; 1331 } 1332 if (genCtx == GenContext::Normal && var->attr.isOptional()) { 1333 body << formatv(optionalAttrParserCode, var->name, attrTypeStr); 1334 } else { 1335 if (attr->shouldBeQualified() || 1336 var->attr.getStorageType() == "::mlir::Attribute") 1337 body << formatv(genericAttrParserCode, var->name, attrTypeStr); 1338 else 1339 body << formatv(attrParserCode, var->name, attrTypeStr); 1340 } 1341 1342 } else if (auto *operand = dyn_cast<OperandVariable>(element)) { 1343 ArgumentLengthKind lengthKind = getArgumentLengthKind(operand->getVar()); 1344 StringRef name = operand->getVar()->name; 1345 if (lengthKind == ArgumentLengthKind::VariadicOfVariadic) 1346 body << llvm::formatv( 1347 variadicOfVariadicOperandParserCode, name, 1348 operand->getVar()->constraint.getVariadicOfVariadicSegmentSizeAttr()); 1349 else if (lengthKind == ArgumentLengthKind::Variadic) 1350 body << llvm::formatv(variadicOperandParserCode, name); 1351 else if (lengthKind == ArgumentLengthKind::Optional) 1352 body << llvm::formatv(optionalOperandParserCode, name); 1353 else 1354 body << formatv(operandParserCode, name); 
1355 1356 } else if (auto *region = dyn_cast<RegionVariable>(element)) { 1357 bool isVariadic = region->getVar()->isVariadic(); 1358 body << llvm::formatv(isVariadic ? regionListParserCode : regionParserCode, 1359 region->getVar()->name); 1360 if (hasImplicitTermTrait) 1361 body << llvm::formatv(isVariadic ? regionListEnsureTerminatorParserCode 1362 : regionEnsureTerminatorParserCode, 1363 region->getVar()->name); 1364 else if (hasSingleBlockTrait) 1365 body << llvm::formatv(isVariadic ? regionListEnsureSingleBlockParserCode 1366 : regionEnsureSingleBlockParserCode, 1367 region->getVar()->name); 1368 1369 } else if (auto *successor = dyn_cast<SuccessorVariable>(element)) { 1370 bool isVariadic = successor->getVar()->isVariadic(); 1371 body << formatv(isVariadic ? successorListParserCode : successorParserCode, 1372 successor->getVar()->name); 1373 1374 /// Directives. 1375 } else if (auto *attrDict = dyn_cast<AttrDictDirective>(element)) { 1376 body << " if (parser.parseOptionalAttrDict" 1377 << (attrDict->isWithKeyword() ? 
"WithKeyword" : "") 1378 << "(result.attributes))\n" 1379 << " return ::mlir::failure();\n"; 1380 } else if (auto *customDir = dyn_cast<CustomDirective>(element)) { 1381 genCustomDirectiveParser(customDir, body); 1382 1383 } else if (isa<OperandsDirective>(element)) { 1384 body << " ::llvm::SMLoc allOperandLoc = parser.getCurrentLocation();\n" 1385 << " if (parser.parseOperandList(allOperands))\n" 1386 << " return ::mlir::failure();\n"; 1387 1388 } else if (isa<RegionsDirective>(element)) { 1389 body << llvm::formatv(regionListParserCode, "full"); 1390 if (hasImplicitTermTrait) 1391 body << llvm::formatv(regionListEnsureTerminatorParserCode, "full"); 1392 else if (hasSingleBlockTrait) 1393 body << llvm::formatv(regionListEnsureSingleBlockParserCode, "full"); 1394 1395 } else if (isa<SuccessorsDirective>(element)) { 1396 body << llvm::formatv(successorListParserCode, "full"); 1397 1398 } else if (auto *dir = dyn_cast<TypeDirective>(element)) { 1399 ArgumentLengthKind lengthKind; 1400 StringRef listName = getTypeListName(dir->getOperand(), lengthKind); 1401 if (lengthKind == ArgumentLengthKind::VariadicOfVariadic) { 1402 body << llvm::formatv(variadicOfVariadicTypeParserCode, listName); 1403 } else if (lengthKind == ArgumentLengthKind::Variadic) { 1404 body << llvm::formatv(variadicTypeParserCode, listName); 1405 } else if (lengthKind == ArgumentLengthKind::Optional) { 1406 body << llvm::formatv(optionalTypeParserCode, listName); 1407 } else { 1408 const char *parserCode = 1409 dir->shouldBeQualified() ? 
qualifiedTypeParserCode : typeParserCode; 1410 TypeSwitch<Element *>(dir->getOperand()) 1411 .Case<OperandVariable, ResultVariable>([&](auto operand) { 1412 body << formatv(parserCode, 1413 operand->getVar()->constraint.getCPPClassName(), 1414 listName); 1415 }) 1416 .Default([&](auto operand) { 1417 body << formatv(parserCode, "::mlir::Type", listName); 1418 }); 1419 } 1420 } else if (auto *dir = dyn_cast<FunctionalTypeDirective>(element)) { 1421 ArgumentLengthKind ignored; 1422 body << formatv(functionalTypeParserCode, 1423 getTypeListName(dir->getInputs(), ignored), 1424 getTypeListName(dir->getResults(), ignored)); 1425 } else { 1426 llvm_unreachable("unknown format element"); 1427 } 1428 } 1429 1430 void OperationFormat::genParserTypeResolution(Operator &op, MethodBody &body) { 1431 // If any of type resolutions use transformed variables, make sure that the 1432 // types of those variables are resolved. 1433 SmallPtrSet<const NamedTypeConstraint *, 8> verifiedVariables; 1434 FmtContext verifierFCtx; 1435 for (TypeResolution &resolver : 1436 llvm::concat<TypeResolution>(resultTypes, operandTypes)) { 1437 Optional<StringRef> transformer = resolver.getVarTransformer(); 1438 if (!transformer) 1439 continue; 1440 // Ensure that we don't verify the same variables twice. 1441 const NamedTypeConstraint *variable = resolver.getVariable(); 1442 if (!variable || !verifiedVariables.insert(variable).second) 1443 continue; 1444 1445 auto constraint = variable->constraint; 1446 body << " for (::mlir::Type type : " << variable->name << "Types) {\n" 1447 << " (void)type;\n" 1448 << " if (!(" 1449 << tgfmt(constraint.getConditionTemplate(), 1450 &verifierFCtx.withSelf("type")) 1451 << ")) {\n" 1452 << formatv(" return parser.emitError(parser.getNameLoc()) << " 1453 "\"'{0}' must be {1}, but got \" << type;\n", 1454 variable->name, constraint.getSummary()) 1455 << " }\n" 1456 << " }\n"; 1457 } 1458 1459 // Initialize the set of buildable types. 
1460 if (!buildableTypes.empty()) { 1461 FmtContext typeBuilderCtx; 1462 typeBuilderCtx.withBuilder("parser.getBuilder()"); 1463 for (auto &it : buildableTypes) 1464 body << " ::mlir::Type odsBuildableType" << it.second << " = " 1465 << tgfmt(it.first, &typeBuilderCtx) << ";\n"; 1466 } 1467 1468 // Emit the code necessary for a type resolver. 1469 auto emitTypeResolver = [&](TypeResolution &resolver, StringRef curVar) { 1470 if (Optional<int> val = resolver.getBuilderIdx()) { 1471 body << "odsBuildableType" << *val; 1472 } else if (const NamedTypeConstraint *var = resolver.getVariable()) { 1473 if (Optional<StringRef> tform = resolver.getVarTransformer()) { 1474 FmtContext fmtContext; 1475 fmtContext.addSubst("_ctxt", "parser.getContext()"); 1476 if (var->isVariadic()) 1477 fmtContext.withSelf(var->name + "Types"); 1478 else 1479 fmtContext.withSelf(var->name + "Types[0]"); 1480 body << tgfmt(*tform, &fmtContext); 1481 } else { 1482 body << var->name << "Types"; 1483 } 1484 } else if (const NamedAttribute *attr = resolver.getAttribute()) { 1485 if (Optional<StringRef> tform = resolver.getVarTransformer()) 1486 body << tgfmt(*tform, 1487 &FmtContext().withSelf(attr->name + "Attr.getType()")); 1488 else 1489 body << attr->name << "Attr.getType()"; 1490 } else { 1491 body << curVar << "Types"; 1492 } 1493 }; 1494 1495 // Resolve each of the result types. 1496 if (!infersResultTypes) { 1497 if (allResultTypes) { 1498 body << " result.addTypes(allResultTypes);\n"; 1499 } else { 1500 for (unsigned i = 0, e = op.getNumResults(); i != e; ++i) { 1501 body << " result.addTypes("; 1502 emitTypeResolver(resultTypes[i], op.getResultName(i)); 1503 body << ");\n"; 1504 } 1505 } 1506 } 1507 1508 // Emit the operand type resolutions. 
1509 genParserOperandTypeResolution(op, body, emitTypeResolver); 1510 1511 // Handle return type inference once all operands have been resolved 1512 if (infersResultTypes) 1513 body << formatv(inferReturnTypesParserCode, op.getCppClassName()); 1514 } 1515 1516 void OperationFormat::genParserOperandTypeResolution( 1517 Operator &op, MethodBody &body, 1518 function_ref<void(TypeResolution &, StringRef)> emitTypeResolver) { 1519 // Early exit if there are no operands. 1520 if (op.getNumOperands() == 0) 1521 return; 1522 1523 // Handle the case where all operand types are grouped together with 1524 // "types(operands)". 1525 if (allOperandTypes) { 1526 // If `operands` was specified, use the full operand list directly. 1527 if (allOperands) { 1528 body << " if (parser.resolveOperands(allOperands, allOperandTypes, " 1529 "allOperandLoc, result.operands))\n" 1530 " return ::mlir::failure();\n"; 1531 return; 1532 } 1533 1534 // Otherwise, use llvm::concat to merge the disjoint operand lists together. 1535 // llvm::concat does not allow the case of a single range, so guard it here. 1536 body << " if (parser.resolveOperands("; 1537 if (op.getNumOperands() > 1) { 1538 body << "::llvm::concat<const ::mlir::OpAsmParser::OperandType>("; 1539 llvm::interleaveComma(op.getOperands(), body, [&](auto &operand) { 1540 body << operand.name << "Operands"; 1541 }); 1542 body << ")"; 1543 } else { 1544 body << op.operand_begin()->name << "Operands"; 1545 } 1546 body << ", allOperandTypes, parser.getNameLoc(), result.operands))\n" 1547 << " return ::mlir::failure();\n"; 1548 return; 1549 } 1550 1551 // Handle the case where all operands are grouped together with "operands". 1552 if (allOperands) { 1553 body << " if (parser.resolveOperands(allOperands, "; 1554 1555 // Group all of the operand types together to perform the resolution all at 1556 // once. Use llvm::concat to perform the merge. llvm::concat does not allow 1557 // the case of a single range, so guard it here. 
1558 if (op.getNumOperands() > 1) { 1559 body << "::llvm::concat<const ::mlir::Type>("; 1560 llvm::interleaveComma( 1561 llvm::seq<int>(0, op.getNumOperands()), body, [&](int i) { 1562 body << "::llvm::ArrayRef<::mlir::Type>("; 1563 emitTypeResolver(operandTypes[i], op.getOperand(i).name); 1564 body << ")"; 1565 }); 1566 body << ")"; 1567 } else { 1568 emitTypeResolver(operandTypes.front(), op.getOperand(0).name); 1569 } 1570 1571 body << ", allOperandLoc, result.operands))\n" 1572 << " return ::mlir::failure();\n"; 1573 return; 1574 } 1575 1576 // The final case is the one where each of the operands types are resolved 1577 // separately. 1578 for (unsigned i = 0, e = op.getNumOperands(); i != e; ++i) { 1579 NamedTypeConstraint &operand = op.getOperand(i); 1580 body << " if (parser.resolveOperands(" << operand.name << "Operands, "; 1581 1582 // Resolve the type of this operand. 1583 TypeResolution &operandType = operandTypes[i]; 1584 emitTypeResolver(operandType, operand.name); 1585 1586 // If the type is resolved by a non-variadic variable, index into the 1587 // resolved type list. This allows for resolving the types of a variadic 1588 // operand list from a non-variadic variable. 1589 bool verifyOperandAndTypeSize = true; 1590 if (auto *resolverVar = operandType.getVariable()) { 1591 if (!resolverVar->isVariadic() && !operandType.getVarTransformer()) { 1592 body << "[0]"; 1593 verifyOperandAndTypeSize = false; 1594 } 1595 } else { 1596 verifyOperandAndTypeSize = !operandType.getBuilderIdx(); 1597 } 1598 1599 // Check to see if the sizes between the types and operands must match. If 1600 // they do, provide the operand location to select the proper resolution 1601 // overload. 
1602 if (verifyOperandAndTypeSize) 1603 body << ", " << operand.name << "OperandsLoc"; 1604 body << ", result.operands))\n return ::mlir::failure();\n"; 1605 } 1606 } 1607 1608 void OperationFormat::genParserRegionResolution(Operator &op, 1609 MethodBody &body) { 1610 // Check for the case where all regions were parsed. 1611 bool hasAllRegions = llvm::any_of( 1612 elements, [](auto &elt) { return isa<RegionsDirective>(elt.get()); }); 1613 if (hasAllRegions) { 1614 body << " result.addRegions(fullRegions);\n"; 1615 return; 1616 } 1617 1618 // Otherwise, handle each region individually. 1619 for (const NamedRegion ®ion : op.getRegions()) { 1620 if (region.isVariadic()) 1621 body << " result.addRegions(" << region.name << "Regions);\n"; 1622 else 1623 body << " result.addRegion(std::move(" << region.name << "Region));\n"; 1624 } 1625 } 1626 1627 void OperationFormat::genParserSuccessorResolution(Operator &op, 1628 MethodBody &body) { 1629 // Check for the case where all successors were parsed. 1630 bool hasAllSuccessors = llvm::any_of( 1631 elements, [](auto &elt) { return isa<SuccessorsDirective>(elt.get()); }); 1632 if (hasAllSuccessors) { 1633 body << " result.addSuccessors(fullSuccessors);\n"; 1634 return; 1635 } 1636 1637 // Otherwise, handle each successor individually. 
1638 for (const NamedSuccessor &successor : op.getSuccessors()) { 1639 if (successor.isVariadic()) 1640 body << " result.addSuccessors(" << successor.name << "Successors);\n"; 1641 else 1642 body << " result.addSuccessors(" << successor.name << "Successor);\n"; 1643 } 1644 } 1645 1646 void OperationFormat::genParserVariadicSegmentResolution(Operator &op, 1647 MethodBody &body) { 1648 if (!allOperands) { 1649 if (op.getTrait("::mlir::OpTrait::AttrSizedOperandSegments")) { 1650 body << " result.addAttribute(\"operand_segment_sizes\", " 1651 << "parser.getBuilder().getI32VectorAttr({"; 1652 auto interleaveFn = [&](const NamedTypeConstraint &operand) { 1653 // If the operand is variadic emit the parsed size. 1654 if (operand.isVariableLength()) 1655 body << "static_cast<int32_t>(" << operand.name << "Operands.size())"; 1656 else 1657 body << "1"; 1658 }; 1659 llvm::interleaveComma(op.getOperands(), body, interleaveFn); 1660 body << "}));\n"; 1661 } 1662 for (const NamedTypeConstraint &operand : op.getOperands()) { 1663 if (!operand.isVariadicOfVariadic()) 1664 continue; 1665 body << llvm::formatv( 1666 " result.addAttribute(\"{0}\", " 1667 "parser.getBuilder().getI32TensorAttr({1}OperandGroupSizes));\n", 1668 operand.constraint.getVariadicOfVariadicSegmentSizeAttr(), 1669 operand.name); 1670 } 1671 } 1672 1673 if (!allResultTypes && 1674 op.getTrait("::mlir::OpTrait::AttrSizedResultSegments")) { 1675 body << " result.addAttribute(\"result_segment_sizes\", " 1676 << "parser.getBuilder().getI32VectorAttr({"; 1677 auto interleaveFn = [&](const NamedTypeConstraint &result) { 1678 // If the result is variadic emit the parsed size. 
1679 if (result.isVariableLength()) 1680 body << "static_cast<int32_t>(" << result.name << "Types.size())"; 1681 else 1682 body << "1"; 1683 }; 1684 llvm::interleaveComma(op.getResults(), body, interleaveFn); 1685 body << "}));\n"; 1686 } 1687 } 1688 1689 //===----------------------------------------------------------------------===// 1690 // PrinterGen 1691 1692 /// The code snippet used to generate a printer call for a region of an 1693 // operation that has the SingleBlockImplicitTerminator trait. 1694 /// 1695 /// {0}: The name of the region. 1696 const char *regionSingleBlockImplicitTerminatorPrinterCode = R"( 1697 { 1698 bool printTerminator = true; 1699 if (auto *term = {0}.empty() ? nullptr : {0}.begin()->getTerminator()) {{ 1700 printTerminator = !term->getAttrDictionary().empty() || 1701 term->getNumOperands() != 0 || 1702 term->getNumResults() != 0; 1703 } 1704 _odsPrinter.printRegion({0}, /*printEntryBlockArgs=*/true, 1705 /*printBlockTerminators=*/printTerminator); 1706 } 1707 )"; 1708 1709 /// The code snippet used to generate a printer call for an enum that has cases 1710 /// that can't be represented with a keyword. 1711 /// 1712 /// {0}: The name of the enum attribute. 1713 /// {1}: The name of the enum attributes symbolToString function. 1714 const char *enumAttrBeginPrinterCode = R"( 1715 { 1716 auto caseValue = {0}(); 1717 auto caseValueStr = {1}(caseValue); 1718 )"; 1719 1720 /// Generate the printer for the 'attr-dict' directive. 1721 static void genAttrDictPrinter(OperationFormat &fmt, Operator &op, 1722 MethodBody &body, bool withKeyword) { 1723 body << " _odsPrinter.printOptionalAttrDict" 1724 << (withKeyword ? "WithKeyword" : "") 1725 << "((*this)->getAttrs(), /*elidedAttrs=*/{"; 1726 // Elide the variadic segment size attributes if necessary. 
1727 if (!fmt.allOperands && 1728 op.getTrait("::mlir::OpTrait::AttrSizedOperandSegments")) 1729 body << "\"operand_segment_sizes\", "; 1730 if (!fmt.allResultTypes && 1731 op.getTrait("::mlir::OpTrait::AttrSizedResultSegments")) 1732 body << "\"result_segment_sizes\", "; 1733 if (!fmt.inferredAttributes.empty()) { 1734 for (const auto &attr : fmt.inferredAttributes) 1735 body << "\"" << attr.getKey() << "\", "; 1736 } 1737 llvm::interleaveComma( 1738 fmt.usedAttributes, body, 1739 [&](const NamedAttribute *attr) { body << "\"" << attr->name << "\""; }); 1740 body << "});\n"; 1741 } 1742 1743 /// Generate the printer for a literal value. `shouldEmitSpace` is true if a 1744 /// space should be emitted before this element. `lastWasPunctuation` is true if 1745 /// the previous element was a punctuation literal. 1746 static void genLiteralPrinter(StringRef value, MethodBody &body, 1747 bool &shouldEmitSpace, bool &lastWasPunctuation) { 1748 body << " _odsPrinter"; 1749 1750 // Don't insert a space for certain punctuation. 1751 if (shouldEmitSpace && shouldEmitSpaceBefore(value, lastWasPunctuation)) 1752 body << " << ' '"; 1753 body << " << \"" << value << "\";\n"; 1754 1755 // Insert a space after certain literals. 1756 shouldEmitSpace = 1757 value.size() != 1 || !StringRef("<({[").contains(value.front()); 1758 lastWasPunctuation = !(value.front() == '_' || isalpha(value.front())); 1759 } 1760 1761 /// Generate the printer for a space. `shouldEmitSpace` and `lastWasPunctuation` 1762 /// are set to false. 1763 static void genSpacePrinter(bool value, MethodBody &body, bool &shouldEmitSpace, 1764 bool &lastWasPunctuation) { 1765 if (value) { 1766 body << " _odsPrinter << ' ';\n"; 1767 lastWasPunctuation = false; 1768 } else { 1769 lastWasPunctuation = true; 1770 } 1771 shouldEmitSpace = false; 1772 } 1773 1774 /// Generate the printer for a custom directive parameter. 
1775 static void genCustomDirectiveParameterPrinter(Element *element, 1776 const Operator &op, 1777 MethodBody &body) { 1778 if (auto *attr = dyn_cast<AttributeVariable>(element)) { 1779 body << op.getGetterName(attr->getVar()->name) << "Attr()"; 1780 1781 } else if (isa<AttrDictDirective>(element)) { 1782 body << "getOperation()->getAttrDictionary()"; 1783 1784 } else if (auto *operand = dyn_cast<OperandVariable>(element)) { 1785 body << op.getGetterName(operand->getVar()->name) << "()"; 1786 1787 } else if (auto *region = dyn_cast<RegionVariable>(element)) { 1788 body << op.getGetterName(region->getVar()->name) << "()"; 1789 1790 } else if (auto *successor = dyn_cast<SuccessorVariable>(element)) { 1791 body << op.getGetterName(successor->getVar()->name) << "()"; 1792 1793 } else if (auto *dir = dyn_cast<RefDirective>(element)) { 1794 genCustomDirectiveParameterPrinter(dir->getOperand(), op, body); 1795 1796 } else if (auto *dir = dyn_cast<TypeDirective>(element)) { 1797 auto *typeOperand = dir->getOperand(); 1798 auto *operand = dyn_cast<OperandVariable>(typeOperand); 1799 auto *var = operand ? operand->getVar() 1800 : cast<ResultVariable>(typeOperand)->getVar(); 1801 std::string name = op.getGetterName(var->name); 1802 if (var->isVariadic()) 1803 body << name << "().getTypes()"; 1804 else if (var->isOptional()) 1805 body << llvm::formatv("({0}() ? {0}().getType() : Type())", name); 1806 else 1807 body << name << "().getType()"; 1808 } else { 1809 llvm_unreachable("unknown custom directive parameter"); 1810 } 1811 } 1812 1813 /// Generate the printer for a custom directive. 
1814 static void genCustomDirectivePrinter(CustomDirective *customDir, 1815 const Operator &op, MethodBody &body) { 1816 body << " print" << customDir->getName() << "(_odsPrinter, *this"; 1817 for (Element ¶m : customDir->getArguments()) { 1818 body << ", "; 1819 genCustomDirectiveParameterPrinter(¶m, op, body); 1820 } 1821 body << ");\n"; 1822 } 1823 1824 /// Generate the printer for a region with the given variable name. 1825 static void genRegionPrinter(const Twine ®ionName, MethodBody &body, 1826 bool hasImplicitTermTrait) { 1827 if (hasImplicitTermTrait) 1828 body << llvm::formatv(regionSingleBlockImplicitTerminatorPrinterCode, 1829 regionName); 1830 else 1831 body << " _odsPrinter.printRegion(" << regionName << ");\n"; 1832 } 1833 static void genVariadicRegionPrinter(const Twine ®ionListName, 1834 MethodBody &body, 1835 bool hasImplicitTermTrait) { 1836 body << " llvm::interleaveComma(" << regionListName 1837 << ", _odsPrinter, [&](::mlir::Region ®ion) {\n "; 1838 genRegionPrinter("region", body, hasImplicitTermTrait); 1839 body << " });\n"; 1840 } 1841 1842 /// Generate the C++ for an operand to a (*-)type directive. 1843 static MethodBody &genTypeOperandPrinter(Element *arg, const Operator &op, 1844 MethodBody &body, 1845 bool useArrayRef = true) { 1846 if (isa<OperandsDirective>(arg)) 1847 return body << "getOperation()->getOperandTypes()"; 1848 if (isa<ResultsDirective>(arg)) 1849 return body << "getOperation()->getResultTypes()"; 1850 auto *operand = dyn_cast<OperandVariable>(arg); 1851 auto *var = operand ? operand->getVar() : cast<ResultVariable>(arg)->getVar(); 1852 if (var->isVariadicOfVariadic()) 1853 return body << llvm::formatv("{0}().join().getTypes()", 1854 op.getGetterName(var->name)); 1855 if (var->isVariadic()) 1856 return body << op.getGetterName(var->name) << "().getTypes()"; 1857 if (var->isOptional()) 1858 return body << llvm::formatv( 1859 "({0}() ? 
::llvm::ArrayRef<::mlir::Type>({0}().getType()) : " 1860 "::llvm::ArrayRef<::mlir::Type>())", 1861 op.getGetterName(var->name)); 1862 if (useArrayRef) 1863 return body << "::llvm::ArrayRef<::mlir::Type>(" 1864 << op.getGetterName(var->name) << "().getType())"; 1865 return body << op.getGetterName(var->name) << "().getType()"; 1866 } 1867 1868 /// Generate the printer for an enum attribute. 1869 static void genEnumAttrPrinter(const NamedAttribute *var, const Operator &op, 1870 MethodBody &body) { 1871 Attribute baseAttr = var->attr.getBaseAttr(); 1872 const EnumAttr &enumAttr = cast<EnumAttr>(baseAttr); 1873 std::vector<EnumAttrCase> cases = enumAttr.getAllCases(); 1874 1875 body << llvm::formatv(enumAttrBeginPrinterCode, 1876 (var->attr.isOptional() ? "*" : "") + 1877 op.getGetterName(var->name), 1878 enumAttr.getSymbolToStringFnName()); 1879 1880 // Get a string containing all of the cases that can't be represented with a 1881 // keyword. 1882 llvm::BitVector nonKeywordCases(cases.size()); 1883 bool hasStrCase = false; 1884 for (auto &it : llvm::enumerate(cases)) { 1885 hasStrCase = it.value().isStrCase(); 1886 if (!canFormatStringAsKeyword(it.value().getStr())) 1887 nonKeywordCases.set(it.index()); 1888 } 1889 1890 // If this is a string enum, use the case string to determine which cases 1891 // need to use the string form. 1892 if (hasStrCase) { 1893 if (nonKeywordCases.any()) { 1894 body << " if (llvm::is_contained(llvm::ArrayRef<llvm::StringRef>("; 1895 llvm::interleaveComma(nonKeywordCases.set_bits(), body, [&](unsigned it) { 1896 body << '"' << cases[it].getStr() << '"'; 1897 }); 1898 body << ")))\n" 1899 " _odsPrinter << '\"' << caseValueStr << '\"';\n" 1900 " else\n "; 1901 } 1902 body << " _odsPrinter << caseValueStr;\n" 1903 " }\n"; 1904 return; 1905 } 1906 1907 // Otherwise if this is a bit enum attribute, don't allow cases that may 1908 // overlap with other cases. For simplicity sake, only allow cases with a 1909 // single bit value. 
1910 if (enumAttr.isBitEnum()) { 1911 for (auto &it : llvm::enumerate(cases)) { 1912 int64_t value = it.value().getValue(); 1913 if (value < 0 || !llvm::isPowerOf2_64(value)) 1914 nonKeywordCases.set(it.index()); 1915 } 1916 } 1917 1918 // If there are any cases that can't be used with a keyword, switch on the 1919 // case value to determine when to print in the string form. 1920 if (nonKeywordCases.any()) { 1921 body << " switch (caseValue) {\n"; 1922 StringRef cppNamespace = enumAttr.getCppNamespace(); 1923 StringRef enumName = enumAttr.getEnumClassName(); 1924 for (auto &it : llvm::enumerate(cases)) { 1925 if (nonKeywordCases.test(it.index())) 1926 continue; 1927 StringRef symbol = it.value().getSymbol(); 1928 body << llvm::formatv(" case {0}::{1}::{2}:\n", cppNamespace, enumName, 1929 llvm::isDigit(symbol.front()) ? ("_" + symbol) 1930 : symbol); 1931 } 1932 body << " _odsPrinter << caseValueStr;\n" 1933 " break;\n" 1934 " default:\n" 1935 " _odsPrinter << '\"' << caseValueStr << '\"';\n" 1936 " break;\n" 1937 " }\n" 1938 " }\n"; 1939 return; 1940 } 1941 1942 body << " _odsPrinter << caseValueStr;\n" 1943 " }\n"; 1944 } 1945 1946 /// Generate the check for the anchor of an optional group. 1947 static void genOptionalGroupPrinterAnchor(Element *anchor, const Operator &op, 1948 MethodBody &body) { 1949 TypeSwitch<Element *>(anchor) 1950 .Case<OperandVariable, ResultVariable>([&](auto *element) { 1951 const NamedTypeConstraint *var = element->getVar(); 1952 std::string name = op.getGetterName(var->name); 1953 if (var->isOptional()) 1954 body << " if (" << name << "()) {\n"; 1955 else if (var->isVariadic()) 1956 body << " if (!" << name << "().empty()) {\n"; 1957 }) 1958 .Case<RegionVariable>([&](RegionVariable *element) { 1959 const NamedRegion *var = element->getVar(); 1960 std::string name = op.getGetterName(var->name); 1961 // TODO: Add a check for optional regions here when ODS supports it. 1962 body << " if (!" 
<< name << "().empty()) {\n"; 1963 }) 1964 .Case<TypeDirective>([&](TypeDirective *element) { 1965 genOptionalGroupPrinterAnchor(element->getOperand(), op, body); 1966 }) 1967 .Case<FunctionalTypeDirective>([&](FunctionalTypeDirective *element) { 1968 genOptionalGroupPrinterAnchor(element->getInputs(), op, body); 1969 }) 1970 .Case<AttributeVariable>([&](AttributeVariable *attr) { 1971 body << " if ((*this)->getAttr(\"" << attr->getVar()->name 1972 << "\")) {\n"; 1973 }); 1974 } 1975 1976 void OperationFormat::genElementPrinter(Element *element, MethodBody &body, 1977 Operator &op, bool &shouldEmitSpace, 1978 bool &lastWasPunctuation) { 1979 if (LiteralElement *literal = dyn_cast<LiteralElement>(element)) 1980 return genLiteralPrinter(literal->getLiteral(), body, shouldEmitSpace, 1981 lastWasPunctuation); 1982 1983 // Emit a whitespace element. 1984 if (isa<NewlineElement>(element)) { 1985 body << " _odsPrinter.printNewline();\n"; 1986 return; 1987 } 1988 if (SpaceElement *space = dyn_cast<SpaceElement>(element)) 1989 return genSpacePrinter(space->getValue(), body, shouldEmitSpace, 1990 lastWasPunctuation); 1991 1992 // Emit an optional group. 1993 if (OptionalElement *optional = dyn_cast<OptionalElement>(element)) { 1994 // Emit the check for the presence of the anchor element. 1995 Element *anchor = optional->getAnchor(); 1996 genOptionalGroupPrinterAnchor(anchor, op, body); 1997 1998 // If the anchor is a unit attribute, we don't need to print it. When 1999 // parsing, we will add this attribute if this group is present. 2000 auto elements = optional->getThenElements(); 2001 Element *elidedAnchorElement = nullptr; 2002 auto *anchorAttr = dyn_cast<AttributeVariable>(anchor); 2003 if (anchorAttr && anchorAttr != &*elements.begin() && 2004 anchorAttr->isUnitAttr()) { 2005 elidedAnchorElement = anchorAttr; 2006 } 2007 2008 // Emit each of the elements. 
2009 for (Element &childElement : elements) { 2010 if (&childElement != elidedAnchorElement) { 2011 genElementPrinter(&childElement, body, op, shouldEmitSpace, 2012 lastWasPunctuation); 2013 } 2014 } 2015 body << " }"; 2016 2017 // Emit each of the else elements. 2018 auto elseElements = optional->getElseElements(); 2019 if (!elseElements.empty()) { 2020 body << " else {\n"; 2021 for (Element &childElement : elseElements) { 2022 genElementPrinter(&childElement, body, op, shouldEmitSpace, 2023 lastWasPunctuation); 2024 } 2025 body << " }"; 2026 } 2027 2028 body << "\n"; 2029 return; 2030 } 2031 2032 // Emit the attribute dictionary. 2033 if (auto *attrDict = dyn_cast<AttrDictDirective>(element)) { 2034 genAttrDictPrinter(*this, op, body, attrDict->isWithKeyword()); 2035 lastWasPunctuation = false; 2036 return; 2037 } 2038 2039 // Optionally insert a space before the next element. The AttrDict printer 2040 // already adds a space as necessary. 2041 if (shouldEmitSpace || !lastWasPunctuation) 2042 body << " _odsPrinter << ' ';\n"; 2043 lastWasPunctuation = false; 2044 shouldEmitSpace = true; 2045 2046 if (auto *attr = dyn_cast<AttributeVariable>(element)) { 2047 const NamedAttribute *var = attr->getVar(); 2048 2049 // If we are formatting as an enum, symbolize the attribute as a string. 2050 if (canFormatEnumAttr(var)) 2051 return genEnumAttrPrinter(var, op, body); 2052 2053 // If we are formatting as a symbol name, handle it as a symbol name. 2054 if (shouldFormatSymbolNameAttr(var)) { 2055 body << " _odsPrinter.printSymbolName(" << op.getGetterName(var->name) 2056 << "Attr().getValue());\n"; 2057 return; 2058 } 2059 2060 // Elide the attribute type if it is buildable. 
2061 if (attr->getTypeBuilder()) 2062 body << " _odsPrinter.printAttributeWithoutType(" 2063 << op.getGetterName(var->name) << "Attr());\n"; 2064 else if (attr->shouldBeQualified() || 2065 var->attr.getStorageType() == "::mlir::Attribute") 2066 body << " _odsPrinter.printAttribute(" << op.getGetterName(var->name) 2067 << "Attr());\n"; 2068 else 2069 body << "_odsPrinter.printStrippedAttrOrType(" 2070 << op.getGetterName(var->name) << "Attr());\n"; 2071 } else if (auto *operand = dyn_cast<OperandVariable>(element)) { 2072 if (operand->getVar()->isVariadicOfVariadic()) { 2073 body << " ::llvm::interleaveComma(" 2074 << op.getGetterName(operand->getVar()->name) 2075 << "(), _odsPrinter, [&](const auto &operands) { _odsPrinter << " 2076 "\"(\" << operands << " 2077 "\")\"; });\n"; 2078 2079 } else if (operand->getVar()->isOptional()) { 2080 body << " if (::mlir::Value value = " 2081 << op.getGetterName(operand->getVar()->name) << "())\n" 2082 << " _odsPrinter << value;\n"; 2083 } else { 2084 body << " _odsPrinter << " << op.getGetterName(operand->getVar()->name) 2085 << "();\n"; 2086 } 2087 } else if (auto *region = dyn_cast<RegionVariable>(element)) { 2088 const NamedRegion *var = region->getVar(); 2089 std::string name = op.getGetterName(var->name); 2090 if (var->isVariadic()) { 2091 genVariadicRegionPrinter(name + "()", body, hasImplicitTermTrait); 2092 } else { 2093 genRegionPrinter(name + "()", body, hasImplicitTermTrait); 2094 } 2095 } else if (auto *successor = dyn_cast<SuccessorVariable>(element)) { 2096 const NamedSuccessor *var = successor->getVar(); 2097 std::string name = op.getGetterName(var->name); 2098 if (var->isVariadic()) 2099 body << " ::llvm::interleaveComma(" << name << "(), _odsPrinter);\n"; 2100 else 2101 body << " _odsPrinter << " << name << "();\n"; 2102 } else if (auto *dir = dyn_cast<CustomDirective>(element)) { 2103 genCustomDirectivePrinter(dir, op, body); 2104 } else if (isa<OperandsDirective>(element)) { 2105 body << " _odsPrinter << 
getOperation()->getOperands();\n"; 2106 } else if (isa<RegionsDirective>(element)) { 2107 genVariadicRegionPrinter("getOperation()->getRegions()", body, 2108 hasImplicitTermTrait); 2109 } else if (isa<SuccessorsDirective>(element)) { 2110 body << " ::llvm::interleaveComma(getOperation()->getSuccessors(), " 2111 "_odsPrinter);\n"; 2112 } else if (auto *dir = dyn_cast<TypeDirective>(element)) { 2113 if (auto *operand = dyn_cast<OperandVariable>(dir->getOperand())) { 2114 if (operand->getVar()->isVariadicOfVariadic()) { 2115 body << llvm::formatv( 2116 " ::llvm::interleaveComma({0}().getTypes(), _odsPrinter, " 2117 "[&](::mlir::TypeRange types) {{ _odsPrinter << \"(\" << " 2118 "types << \")\"; });\n", 2119 op.getGetterName(operand->getVar()->name)); 2120 return; 2121 } 2122 } 2123 const NamedTypeConstraint *var = nullptr; 2124 { 2125 if (auto *operand = dyn_cast<OperandVariable>(dir->getOperand())) 2126 var = operand->getVar(); 2127 else if (auto *operand = dyn_cast<ResultVariable>(dir->getOperand())) 2128 var = operand->getVar(); 2129 } 2130 if (var && !var->isVariadicOfVariadic() && !var->isVariadic() && 2131 !var->isOptional()) { 2132 std::string cppClass = var->constraint.getCPPClassName(); 2133 if (dir->shouldBeQualified()) { 2134 body << " _odsPrinter << " << op.getGetterName(var->name) 2135 << "().getType();\n"; 2136 return; 2137 } 2138 body << " {\n" 2139 << " auto type = " << op.getGetterName(var->name) 2140 << "().getType();\n" 2141 << " if (auto validType = type.dyn_cast<" << cppClass << ">())\n" 2142 << " _odsPrinter.printStrippedAttrOrType(validType);\n" 2143 << " else\n" 2144 << " _odsPrinter << type;\n" 2145 << " }\n"; 2146 return; 2147 } 2148 body << " _odsPrinter << "; 2149 genTypeOperandPrinter(dir->getOperand(), op, body, /*useArrayRef=*/false) 2150 << ";\n"; 2151 } else if (auto *dir = dyn_cast<FunctionalTypeDirective>(element)) { 2152 body << " _odsPrinter.printFunctionalType("; 2153 genTypeOperandPrinter(dir->getInputs(), op, body) << ", "; 2154 
genTypeOperandPrinter(dir->getResults(), op, body) << ");\n"; 2155 } else { 2156 llvm_unreachable("unknown format element"); 2157 } 2158 } 2159 2160 void OperationFormat::genPrinter(Operator &op, OpClass &opClass) { 2161 auto *method = opClass.addMethod( 2162 "void", "print", 2163 MethodParameter("::mlir::OpAsmPrinter &", "_odsPrinter")); 2164 auto &body = method->body(); 2165 2166 // Flags for if we should emit a space, and if the last element was 2167 // punctuation. 2168 bool shouldEmitSpace = true, lastWasPunctuation = false; 2169 for (auto &element : elements) 2170 genElementPrinter(element.get(), body, op, shouldEmitSpace, 2171 lastWasPunctuation); 2172 } 2173 2174 //===----------------------------------------------------------------------===// 2175 // FormatParser 2176 //===----------------------------------------------------------------------===// 2177 2178 /// Function to find an element within the given range that has the same name as 2179 /// 'name'. 2180 template <typename RangeT> static auto findArg(RangeT &&range, StringRef name) { 2181 auto it = llvm::find_if(range, [=](auto &arg) { return arg.name == name; }); 2182 return it != range.end() ? &*it : nullptr; 2183 } 2184 2185 namespace { 2186 /// This class implements a parser for an instance of an operation assembly 2187 /// format. 2188 class FormatParser { 2189 public: 2190 FormatParser(llvm::SourceMgr &mgr, OperationFormat &format, Operator &op) 2191 : lexer(mgr, op.getLoc()[0]), curToken(lexer.lexToken()), fmt(format), 2192 op(op), seenOperandTypes(op.getNumOperands()), 2193 seenResultTypes(op.getNumResults()) {} 2194 2195 /// Parse the operation assembly format. 2196 LogicalResult parse(); 2197 2198 private: 2199 /// The current context of the parser when parsing an element. 2200 enum ParserContext { 2201 /// The element is being parsed in a "top-level" context, i.e. at the top of 2202 /// the format or in an optional group. 
2203 TopLevelContext, 2204 /// The element is being parsed as a custom directive child. 2205 CustomDirectiveContext, 2206 /// The element is being parsed as a type directive child. 2207 TypeDirectiveContext, 2208 /// The element is being parsed as a reference directive child. 2209 RefDirectiveContext 2210 }; 2211 2212 /// This struct represents a type resolution instance. It includes a specific 2213 /// type as well as an optional transformer to apply to that type in order to 2214 /// properly resolve the type of a variable. 2215 struct TypeResolutionInstance { 2216 ConstArgument resolver; 2217 Optional<StringRef> transformer; 2218 }; 2219 2220 /// An iterator over the elements of a format group. 2221 using ElementsIterT = llvm::pointee_iterator< 2222 std::vector<std::unique_ptr<Element>>::const_iterator>; 2223 2224 /// Verify the state of operation attributes within the format. 2225 LogicalResult verifyAttributes(llvm::SMLoc loc); 2226 /// Verify the attribute elements at the back of the given stack of iterators. 2227 LogicalResult verifyAttributes( 2228 llvm::SMLoc loc, 2229 SmallVectorImpl<std::pair<ElementsIterT, ElementsIterT>> &iteratorStack); 2230 2231 /// Verify the state of operation operands within the format. 2232 LogicalResult 2233 verifyOperands(llvm::SMLoc loc, 2234 llvm::StringMap<TypeResolutionInstance> &variableTyResolver); 2235 2236 /// Verify the state of operation regions within the format. 2237 LogicalResult verifyRegions(llvm::SMLoc loc); 2238 2239 /// Verify the state of operation results within the format. 2240 LogicalResult 2241 verifyResults(llvm::SMLoc loc, 2242 llvm::StringMap<TypeResolutionInstance> &variableTyResolver); 2243 2244 /// Verify the state of operation successors within the format. 2245 LogicalResult verifySuccessors(llvm::SMLoc loc); 2246 2247 /// Given the values of an `AllTypesMatch` trait, check for inferable type 2248 /// resolution. 
2249 void handleAllTypesMatchConstraint( 2250 ArrayRef<StringRef> values, 2251 llvm::StringMap<TypeResolutionInstance> &variableTyResolver); 2252 /// Check for inferable type resolution given all operands, and or results, 2253 /// have the same type. If 'includeResults' is true, the results also have the 2254 /// same type as all of the operands. 2255 void handleSameTypesConstraint( 2256 llvm::StringMap<TypeResolutionInstance> &variableTyResolver, 2257 bool includeResults); 2258 /// Check for inferable type resolution based on another operand, result, or 2259 /// attribute. 2260 void handleTypesMatchConstraint( 2261 llvm::StringMap<TypeResolutionInstance> &variableTyResolver, 2262 const llvm::Record &def); 2263 2264 /// Returns an argument or attribute with the given name that has been seen 2265 /// within the format. 2266 ConstArgument findSeenArg(StringRef name); 2267 2268 /// Parse a specific element. 2269 LogicalResult parseElement(std::unique_ptr<Element> &element, 2270 ParserContext context); 2271 LogicalResult parseVariable(std::unique_ptr<Element> &element, 2272 ParserContext context); 2273 LogicalResult parseDirective(std::unique_ptr<Element> &element, 2274 ParserContext context); 2275 LogicalResult parseLiteral(std::unique_ptr<Element> &element, 2276 ParserContext context); 2277 LogicalResult parseOptional(std::unique_ptr<Element> &element, 2278 ParserContext context); 2279 LogicalResult parseOptionalChildElement( 2280 std::vector<std::unique_ptr<Element>> &childElements, 2281 Optional<unsigned> &anchorIdx); 2282 LogicalResult verifyOptionalChildElement(Element *element, 2283 llvm::SMLoc childLoc, bool isAnchor); 2284 2285 /// Parse the various different directives. 
2286 LogicalResult parseAttrDictDirective(std::unique_ptr<Element> &element, 2287 llvm::SMLoc loc, ParserContext context, 2288 bool withKeyword); 2289 LogicalResult parseCustomDirective(std::unique_ptr<Element> &element, 2290 llvm::SMLoc loc, ParserContext context); 2291 LogicalResult parseCustomDirectiveParameter( 2292 std::vector<std::unique_ptr<Element>> ¶meters); 2293 LogicalResult parseFunctionalTypeDirective(std::unique_ptr<Element> &element, 2294 FormatToken tok, 2295 ParserContext context); 2296 LogicalResult parseOperandsDirective(std::unique_ptr<Element> &element, 2297 llvm::SMLoc loc, ParserContext context); 2298 LogicalResult parseQualifiedDirective(std::unique_ptr<Element> &element, 2299 FormatToken tok, ParserContext context); 2300 LogicalResult parseReferenceDirective(std::unique_ptr<Element> &element, 2301 llvm::SMLoc loc, ParserContext context); 2302 LogicalResult parseRegionsDirective(std::unique_ptr<Element> &element, 2303 llvm::SMLoc loc, ParserContext context); 2304 LogicalResult parseResultsDirective(std::unique_ptr<Element> &element, 2305 llvm::SMLoc loc, ParserContext context); 2306 LogicalResult parseSuccessorsDirective(std::unique_ptr<Element> &element, 2307 llvm::SMLoc loc, 2308 ParserContext context); 2309 LogicalResult parseTypeDirective(std::unique_ptr<Element> &element, 2310 FormatToken tok, ParserContext context); 2311 LogicalResult parseTypeDirectiveOperand(std::unique_ptr<Element> &element, 2312 bool isRefChild = false); 2313 2314 //===--------------------------------------------------------------------===// 2315 // Lexer Utilities 2316 //===--------------------------------------------------------------------===// 2317 2318 /// Advance the current lexer onto the next token. 
2319 void consumeToken() { 2320 assert(curToken.getKind() != FormatToken::eof && 2321 curToken.getKind() != FormatToken::error && 2322 "shouldn't advance past EOF or errors"); 2323 curToken = lexer.lexToken(); 2324 } 2325 LogicalResult parseToken(FormatToken::Kind kind, const Twine &msg) { 2326 if (curToken.getKind() != kind) 2327 return emitError(curToken.getLoc(), msg); 2328 consumeToken(); 2329 return ::mlir::success(); 2330 } 2331 LogicalResult emitError(llvm::SMLoc loc, const Twine &msg) { 2332 lexer.emitError(loc, msg); 2333 return ::mlir::failure(); 2334 } 2335 LogicalResult emitErrorAndNote(llvm::SMLoc loc, const Twine &msg, 2336 const Twine ¬e) { 2337 lexer.emitErrorAndNote(loc, msg, note); 2338 return ::mlir::failure(); 2339 } 2340 2341 //===--------------------------------------------------------------------===// 2342 // Fields 2343 //===--------------------------------------------------------------------===// 2344 2345 FormatLexer lexer; 2346 FormatToken curToken; 2347 OperationFormat &fmt; 2348 Operator &op; 2349 2350 // The following are various bits of format state used for verification 2351 // during parsing. 2352 bool hasAttrDict = false; 2353 bool hasAllRegions = false, hasAllSuccessors = false; 2354 bool canInferResultTypes = false; 2355 llvm::SmallBitVector seenOperandTypes, seenResultTypes; 2356 llvm::SmallSetVector<const NamedAttribute *, 8> seenAttrs; 2357 llvm::DenseSet<const NamedTypeConstraint *> seenOperands; 2358 llvm::DenseSet<const NamedRegion *> seenRegions; 2359 llvm::DenseSet<const NamedSuccessor *> seenSuccessors; 2360 }; 2361 } // namespace 2362 2363 LogicalResult FormatParser::parse() { 2364 llvm::SMLoc loc = curToken.getLoc(); 2365 2366 // Parse each of the format elements into the main format. 
2367 while (curToken.getKind() != FormatToken::eof) { 2368 std::unique_ptr<Element> element; 2369 if (failed(parseElement(element, TopLevelContext))) 2370 return ::mlir::failure(); 2371 fmt.elements.push_back(std::move(element)); 2372 } 2373 2374 // Check that the attribute dictionary is in the format. 2375 if (!hasAttrDict) 2376 return emitError(loc, "'attr-dict' directive not found in " 2377 "custom assembly format"); 2378 2379 // Check for any type traits that we can use for inferring types. 2380 llvm::StringMap<TypeResolutionInstance> variableTyResolver; 2381 for (const Trait &trait : op.getTraits()) { 2382 const llvm::Record &def = trait.getDef(); 2383 if (def.isSubClassOf("AllTypesMatch")) { 2384 handleAllTypesMatchConstraint(def.getValueAsListOfStrings("values"), 2385 variableTyResolver); 2386 } else if (def.getName() == "SameTypeOperands") { 2387 handleSameTypesConstraint(variableTyResolver, /*includeResults=*/false); 2388 } else if (def.getName() == "SameOperandsAndResultType") { 2389 handleSameTypesConstraint(variableTyResolver, /*includeResults=*/true); 2390 } else if (def.isSubClassOf("TypesMatchWith")) { 2391 handleTypesMatchConstraint(variableTyResolver, def); 2392 } else if (!op.allResultTypesKnown()) { 2393 // This doesn't check the name directly to handle 2394 // DeclareOpInterfaceMethods<InferTypeOpInterface> 2395 // and the like. 2396 // TODO: Add hasCppInterface check. 2397 if (auto name = def.getValueAsOptionalString("cppClassName")) { 2398 if (*name == "InferTypeOpInterface" && 2399 def.getValueAsString("cppNamespace") == "::mlir") 2400 canInferResultTypes = true; 2401 } 2402 } 2403 } 2404 2405 // Verify the state of the various operation components. 
2406 if (failed(verifyAttributes(loc)) || 2407 failed(verifyResults(loc, variableTyResolver)) || 2408 failed(verifyOperands(loc, variableTyResolver)) || 2409 failed(verifyRegions(loc)) || failed(verifySuccessors(loc))) 2410 return ::mlir::failure(); 2411 2412 // Collect the set of used attributes in the format. 2413 fmt.usedAttributes = seenAttrs.takeVector(); 2414 return ::mlir::success(); 2415 } 2416 2417 LogicalResult FormatParser::verifyAttributes(llvm::SMLoc loc) { 2418 // Check that there are no `:` literals after an attribute without a constant 2419 // type. The attribute grammar contains an optional trailing colon type, which 2420 // can lead to unexpected and generally unintended behavior. Given that, it is 2421 // better to just error out here instead. 2422 using ElementsIterT = llvm::pointee_iterator< 2423 std::vector<std::unique_ptr<Element>>::const_iterator>; 2424 SmallVector<std::pair<ElementsIterT, ElementsIterT>, 1> iteratorStack; 2425 iteratorStack.emplace_back(fmt.elements.begin(), fmt.elements.end()); 2426 while (!iteratorStack.empty()) 2427 if (failed(verifyAttributes(loc, iteratorStack))) 2428 return ::mlir::failure(); 2429 2430 // Check for VariadicOfVariadic variables. The segment attribute of those 2431 // variables will be infered. 2432 for (const NamedTypeConstraint *var : seenOperands) { 2433 if (var->constraint.isVariadicOfVariadic()) { 2434 fmt.inferredAttributes.insert( 2435 var->constraint.getVariadicOfVariadicSegmentSizeAttr()); 2436 } 2437 } 2438 2439 return ::mlir::success(); 2440 } 2441 /// Verify the attribute elements at the back of the given stack of iterators. 2442 LogicalResult FormatParser::verifyAttributes( 2443 llvm::SMLoc loc, 2444 SmallVectorImpl<std::pair<ElementsIterT, ElementsIterT>> &iteratorStack) { 2445 auto &stackIt = iteratorStack.back(); 2446 ElementsIterT &it = stackIt.first, e = stackIt.second; 2447 while (it != e) { 2448 Element *element = &*(it++); 2449 2450 // Traverse into optional groups. 
2451 if (auto *optional = dyn_cast<OptionalElement>(element)) { 2452 auto thenElements = optional->getThenElements(); 2453 iteratorStack.emplace_back(thenElements.begin(), thenElements.end()); 2454 2455 auto elseElements = optional->getElseElements(); 2456 iteratorStack.emplace_back(elseElements.begin(), elseElements.end()); 2457 return ::mlir::success(); 2458 } 2459 2460 // We are checking for an attribute element followed by a `:`, so there is 2461 // no need to check the end. 2462 if (it == e && iteratorStack.size() == 1) 2463 break; 2464 2465 // Check for an attribute with a constant type builder, followed by a `:`. 2466 auto *prevAttr = dyn_cast<AttributeVariable>(element); 2467 if (!prevAttr || prevAttr->getTypeBuilder()) 2468 continue; 2469 2470 // Check the next iterator within the stack for literal elements. 2471 for (auto &nextItPair : iteratorStack) { 2472 ElementsIterT nextIt = nextItPair.first, nextE = nextItPair.second; 2473 for (; nextIt != nextE; ++nextIt) { 2474 // Skip any trailing whitespace, attribute dictionaries, or optional 2475 // groups. 2476 if (isa<WhitespaceElement>(*nextIt) || 2477 isa<AttrDictDirective>(*nextIt) || isa<OptionalElement>(*nextIt)) 2478 continue; 2479 2480 // We are only interested in `:` literals. 2481 auto *literal = dyn_cast<LiteralElement>(&*nextIt); 2482 if (!literal || literal->getLiteral() != ":") 2483 break; 2484 2485 // TODO: Use the location of the literal element itself. 2486 return emitError( 2487 loc, llvm::formatv("format ambiguity caused by `:` literal found " 2488 "after attribute `{0}` which does not have " 2489 "a buildable type", 2490 prevAttr->getVar()->name)); 2491 } 2492 } 2493 } 2494 iteratorStack.pop_back(); 2495 return ::mlir::success(); 2496 } 2497 2498 LogicalResult FormatParser::verifyOperands( 2499 llvm::SMLoc loc, 2500 llvm::StringMap<TypeResolutionInstance> &variableTyResolver) { 2501 // Check that all of the operands are within the format, and their types can 2502 // be inferred. 
2503 auto &buildableTypes = fmt.buildableTypes; 2504 for (unsigned i = 0, e = op.getNumOperands(); i != e; ++i) { 2505 NamedTypeConstraint &operand = op.getOperand(i); 2506 2507 // Check that the operand itself is in the format. 2508 if (!fmt.allOperands && !seenOperands.count(&operand)) { 2509 return emitErrorAndNote(loc, 2510 "operand #" + Twine(i) + ", named '" + 2511 operand.name + "', not found", 2512 "suggest adding a '$" + operand.name + 2513 "' directive to the custom assembly format"); 2514 } 2515 2516 // Check that the operand type is in the format, or that it can be inferred. 2517 if (fmt.allOperandTypes || seenOperandTypes.test(i)) 2518 continue; 2519 2520 // Check to see if we can infer this type from another variable. 2521 auto varResolverIt = variableTyResolver.find(op.getOperand(i).name); 2522 if (varResolverIt != variableTyResolver.end()) { 2523 TypeResolutionInstance &resolver = varResolverIt->second; 2524 fmt.operandTypes[i].setResolver(resolver.resolver, resolver.transformer); 2525 continue; 2526 } 2527 2528 // Similarly to results, allow a custom builder for resolving the type if 2529 // we aren't using the 'operands' directive. 2530 Optional<StringRef> builder = operand.constraint.getBuilderCall(); 2531 if (!builder || (fmt.allOperands && operand.isVariableLength())) { 2532 return emitErrorAndNote( 2533 loc, 2534 "type of operand #" + Twine(i) + ", named '" + operand.name + 2535 "', is not buildable and a buildable type cannot be inferred", 2536 "suggest adding a type constraint to the operation or adding a " 2537 "'type($" + 2538 operand.name + ")' directive to the " + "custom assembly format"); 2539 } 2540 auto it = buildableTypes.insert({*builder, buildableTypes.size()}); 2541 fmt.operandTypes[i].setBuilderIdx(it.first->second); 2542 } 2543 return ::mlir::success(); 2544 } 2545 2546 LogicalResult FormatParser::verifyRegions(llvm::SMLoc loc) { 2547 // Check that all of the regions are within the format. 
2548 if (hasAllRegions) 2549 return ::mlir::success(); 2550 2551 for (unsigned i = 0, e = op.getNumRegions(); i != e; ++i) { 2552 const NamedRegion ®ion = op.getRegion(i); 2553 if (!seenRegions.count(®ion)) { 2554 return emitErrorAndNote(loc, 2555 "region #" + Twine(i) + ", named '" + 2556 region.name + "', not found", 2557 "suggest adding a '$" + region.name + 2558 "' directive to the custom assembly format"); 2559 } 2560 } 2561 return ::mlir::success(); 2562 } 2563 2564 LogicalResult FormatParser::verifyResults( 2565 llvm::SMLoc loc, 2566 llvm::StringMap<TypeResolutionInstance> &variableTyResolver) { 2567 // If we format all of the types together, there is nothing to check. 2568 if (fmt.allResultTypes) 2569 return ::mlir::success(); 2570 2571 // If no result types are specified and we can infer them, infer all result 2572 // types 2573 if (op.getNumResults() > 0 && seenResultTypes.count() == 0 && 2574 canInferResultTypes) { 2575 fmt.infersResultTypes = true; 2576 return ::mlir::success(); 2577 } 2578 2579 // Check that all of the result types can be inferred. 2580 auto &buildableTypes = fmt.buildableTypes; 2581 for (unsigned i = 0, e = op.getNumResults(); i != e; ++i) { 2582 if (seenResultTypes.test(i)) 2583 continue; 2584 2585 // Check to see if we can infer this type from another variable. 2586 auto varResolverIt = variableTyResolver.find(op.getResultName(i)); 2587 if (varResolverIt != variableTyResolver.end()) { 2588 TypeResolutionInstance resolver = varResolverIt->second; 2589 fmt.resultTypes[i].setResolver(resolver.resolver, resolver.transformer); 2590 continue; 2591 } 2592 2593 // If the result is not variable length, allow for the case where the type 2594 // has a builder that we can use. 
2595 NamedTypeConstraint &result = op.getResult(i); 2596 Optional<StringRef> builder = result.constraint.getBuilderCall(); 2597 if (!builder || result.isVariableLength()) { 2598 return emitErrorAndNote( 2599 loc, 2600 "type of result #" + Twine(i) + ", named '" + result.name + 2601 "', is not buildable and a buildable type cannot be inferred", 2602 "suggest adding a type constraint to the operation or adding a " 2603 "'type($" + 2604 result.name + ")' directive to the " + "custom assembly format"); 2605 } 2606 // Note in the format that this result uses the custom builder. 2607 auto it = buildableTypes.insert({*builder, buildableTypes.size()}); 2608 fmt.resultTypes[i].setBuilderIdx(it.first->second); 2609 } 2610 return ::mlir::success(); 2611 } 2612 2613 LogicalResult FormatParser::verifySuccessors(llvm::SMLoc loc) { 2614 // Check that all of the successors are within the format. 2615 if (hasAllSuccessors) 2616 return ::mlir::success(); 2617 2618 for (unsigned i = 0, e = op.getNumSuccessors(); i != e; ++i) { 2619 const NamedSuccessor &successor = op.getSuccessor(i); 2620 if (!seenSuccessors.count(&successor)) { 2621 return emitErrorAndNote(loc, 2622 "successor #" + Twine(i) + ", named '" + 2623 successor.name + "', not found", 2624 "suggest adding a '$" + successor.name + 2625 "' directive to the custom assembly format"); 2626 } 2627 } 2628 return ::mlir::success(); 2629 } 2630 2631 void FormatParser::handleAllTypesMatchConstraint( 2632 ArrayRef<StringRef> values, 2633 llvm::StringMap<TypeResolutionInstance> &variableTyResolver) { 2634 for (unsigned i = 0, e = values.size(); i != e; ++i) { 2635 // Check to see if this value matches a resolved operand or result type. 2636 ConstArgument arg = findSeenArg(values[i]); 2637 if (!arg) 2638 continue; 2639 2640 // Mark this value as the type resolver for the other variables. 
2641 for (unsigned j = 0; j != i; ++j) 2642 variableTyResolver[values[j]] = {arg, llvm::None}; 2643 for (unsigned j = i + 1; j != e; ++j) 2644 variableTyResolver[values[j]] = {arg, llvm::None}; 2645 } 2646 } 2647 2648 void FormatParser::handleSameTypesConstraint( 2649 llvm::StringMap<TypeResolutionInstance> &variableTyResolver, 2650 bool includeResults) { 2651 const NamedTypeConstraint *resolver = nullptr; 2652 int resolvedIt = -1; 2653 2654 // Check to see if there is an operand or result to use for the resolution. 2655 if ((resolvedIt = seenOperandTypes.find_first()) != -1) 2656 resolver = &op.getOperand(resolvedIt); 2657 else if (includeResults && (resolvedIt = seenResultTypes.find_first()) != -1) 2658 resolver = &op.getResult(resolvedIt); 2659 else 2660 return; 2661 2662 // Set the resolvers for each operand and result. 2663 for (unsigned i = 0, e = op.getNumOperands(); i != e; ++i) 2664 if (!seenOperandTypes.test(i) && !op.getOperand(i).name.empty()) 2665 variableTyResolver[op.getOperand(i).name] = {resolver, llvm::None}; 2666 if (includeResults) { 2667 for (unsigned i = 0, e = op.getNumResults(); i != e; ++i) 2668 if (!seenResultTypes.test(i) && !op.getResultName(i).empty()) 2669 variableTyResolver[op.getResultName(i)] = {resolver, llvm::None}; 2670 } 2671 } 2672 2673 void FormatParser::handleTypesMatchConstraint( 2674 llvm::StringMap<TypeResolutionInstance> &variableTyResolver, 2675 const llvm::Record &def) { 2676 StringRef lhsName = def.getValueAsString("lhs"); 2677 StringRef rhsName = def.getValueAsString("rhs"); 2678 StringRef transformer = def.getValueAsString("transformer"); 2679 if (ConstArgument arg = findSeenArg(lhsName)) 2680 variableTyResolver[rhsName] = {arg, transformer}; 2681 } 2682 2683 ConstArgument FormatParser::findSeenArg(StringRef name) { 2684 if (const NamedTypeConstraint *arg = findArg(op.getOperands(), name)) 2685 return seenOperandTypes.test(arg - op.operand_begin()) ? 
arg : nullptr; 2686 if (const NamedTypeConstraint *arg = findArg(op.getResults(), name)) 2687 return seenResultTypes.test(arg - op.result_begin()) ? arg : nullptr; 2688 if (const NamedAttribute *attr = findArg(op.getAttributes(), name)) 2689 return seenAttrs.count(attr) ? attr : nullptr; 2690 return nullptr; 2691 } 2692 2693 LogicalResult FormatParser::parseElement(std::unique_ptr<Element> &element, 2694 ParserContext context) { 2695 // Directives. 2696 if (curToken.isKeyword()) 2697 return parseDirective(element, context); 2698 // Literals. 2699 if (curToken.getKind() == FormatToken::literal) 2700 return parseLiteral(element, context); 2701 // Optionals. 2702 if (curToken.getKind() == FormatToken::l_paren) 2703 return parseOptional(element, context); 2704 // Variables. 2705 if (curToken.getKind() == FormatToken::variable) 2706 return parseVariable(element, context); 2707 return emitError(curToken.getLoc(), 2708 "expected directive, literal, variable, or optional group"); 2709 } 2710 2711 LogicalResult FormatParser::parseVariable(std::unique_ptr<Element> &element, 2712 ParserContext context) { 2713 FormatToken varTok = curToken; 2714 consumeToken(); 2715 2716 StringRef name = varTok.getSpelling().drop_front(); 2717 llvm::SMLoc loc = varTok.getLoc(); 2718 2719 // Check that the parsed argument is something actually registered on the 2720 // op. 
2721 /// Attributes 2722 if (const NamedAttribute *attr = findArg(op.getAttributes(), name)) { 2723 if (context == TypeDirectiveContext) 2724 return emitError( 2725 loc, "attributes cannot be used as children to a `type` directive"); 2726 if (context == RefDirectiveContext) { 2727 if (!seenAttrs.count(attr)) 2728 return emitError(loc, "attribute '" + name + 2729 "' must be bound before it is referenced"); 2730 } else if (!seenAttrs.insert(attr)) { 2731 return emitError(loc, "attribute '" + name + "' is already bound"); 2732 } 2733 2734 element = std::make_unique<AttributeVariable>(attr); 2735 return ::mlir::success(); 2736 } 2737 /// Operands 2738 if (const NamedTypeConstraint *operand = findArg(op.getOperands(), name)) { 2739 if (context == TopLevelContext || context == CustomDirectiveContext) { 2740 if (fmt.allOperands || !seenOperands.insert(operand).second) 2741 return emitError(loc, "operand '" + name + "' is already bound"); 2742 } else if (context == RefDirectiveContext && !seenOperands.count(operand)) { 2743 return emitError(loc, "operand '" + name + 2744 "' must be bound before it is referenced"); 2745 } 2746 element = std::make_unique<OperandVariable>(operand); 2747 return ::mlir::success(); 2748 } 2749 /// Regions 2750 if (const NamedRegion *region = findArg(op.getRegions(), name)) { 2751 if (context == TopLevelContext || context == CustomDirectiveContext) { 2752 if (hasAllRegions || !seenRegions.insert(region).second) 2753 return emitError(loc, "region '" + name + "' is already bound"); 2754 } else if (context == RefDirectiveContext && !seenRegions.count(region)) { 2755 return emitError(loc, "region '" + name + 2756 "' must be bound before it is referenced"); 2757 } else { 2758 return emitError(loc, "regions can only be used at the top level"); 2759 } 2760 element = std::make_unique<RegionVariable>(region); 2761 return ::mlir::success(); 2762 } 2763 /// Results. 
2764 if (const auto *result = findArg(op.getResults(), name)) { 2765 if (context != TypeDirectiveContext) 2766 return emitError(loc, "result variables can can only be used as a child " 2767 "to a 'type' directive"); 2768 element = std::make_unique<ResultVariable>(result); 2769 return ::mlir::success(); 2770 } 2771 /// Successors. 2772 if (const auto *successor = findArg(op.getSuccessors(), name)) { 2773 if (context == TopLevelContext || context == CustomDirectiveContext) { 2774 if (hasAllSuccessors || !seenSuccessors.insert(successor).second) 2775 return emitError(loc, "successor '" + name + "' is already bound"); 2776 } else if (context == RefDirectiveContext && 2777 !seenSuccessors.count(successor)) { 2778 return emitError(loc, "successor '" + name + 2779 "' must be bound before it is referenced"); 2780 } else { 2781 return emitError(loc, "successors can only be used at the top level"); 2782 } 2783 2784 element = std::make_unique<SuccessorVariable>(successor); 2785 return ::mlir::success(); 2786 } 2787 return emitError(loc, "expected variable to refer to an argument, region, " 2788 "result, or successor"); 2789 } 2790 2791 LogicalResult FormatParser::parseDirective(std::unique_ptr<Element> &element, 2792 ParserContext context) { 2793 FormatToken dirTok = curToken; 2794 consumeToken(); 2795 2796 switch (dirTok.getKind()) { 2797 case FormatToken::kw_attr_dict: 2798 return parseAttrDictDirective(element, dirTok.getLoc(), context, 2799 /*withKeyword=*/false); 2800 case FormatToken::kw_attr_dict_w_keyword: 2801 return parseAttrDictDirective(element, dirTok.getLoc(), context, 2802 /*withKeyword=*/true); 2803 case FormatToken::kw_custom: 2804 return parseCustomDirective(element, dirTok.getLoc(), context); 2805 case FormatToken::kw_functional_type: 2806 return parseFunctionalTypeDirective(element, dirTok, context); 2807 case FormatToken::kw_operands: 2808 return parseOperandsDirective(element, dirTok.getLoc(), context); 2809 case FormatToken::kw_qualified: 2810 return 
parseQualifiedDirective(element, dirTok, context); 2811 case FormatToken::kw_regions: 2812 return parseRegionsDirective(element, dirTok.getLoc(), context); 2813 case FormatToken::kw_results: 2814 return parseResultsDirective(element, dirTok.getLoc(), context); 2815 case FormatToken::kw_successors: 2816 return parseSuccessorsDirective(element, dirTok.getLoc(), context); 2817 case FormatToken::kw_ref: 2818 return parseReferenceDirective(element, dirTok.getLoc(), context); 2819 case FormatToken::kw_type: 2820 return parseTypeDirective(element, dirTok, context); 2821 2822 default: 2823 llvm_unreachable("unknown directive token"); 2824 } 2825 } 2826 2827 LogicalResult FormatParser::parseLiteral(std::unique_ptr<Element> &element, 2828 ParserContext context) { 2829 FormatToken literalTok = curToken; 2830 if (context != TopLevelContext) { 2831 return emitError( 2832 literalTok.getLoc(), 2833 "literals may only be used in a top-level section of the format"); 2834 } 2835 consumeToken(); 2836 2837 StringRef value = literalTok.getSpelling().drop_front().drop_back(); 2838 2839 // The parsed literal is a space element (`` or ` `). 2840 if (value.empty() || (value.size() == 1 && value.front() == ' ')) { 2841 element = std::make_unique<SpaceElement>(!value.empty()); 2842 return ::mlir::success(); 2843 } 2844 // The parsed literal is a newline element. 2845 if (value == "\\n") { 2846 element = std::make_unique<NewlineElement>(); 2847 return ::mlir::success(); 2848 } 2849 2850 // Check that the parsed literal is valid. 
2851 if (!isValidLiteral(value, [&](Twine diag) { 2852 (void)emitError(literalTok.getLoc(), 2853 "expected valid literal but got '" + value + 2854 "': " + diag); 2855 })) 2856 return failure(); 2857 element = std::make_unique<LiteralElement>(value); 2858 return ::mlir::success(); 2859 } 2860 2861 LogicalResult FormatParser::parseOptional(std::unique_ptr<Element> &element, 2862 ParserContext context) { 2863 llvm::SMLoc curLoc = curToken.getLoc(); 2864 if (context != TopLevelContext) 2865 return emitError(curLoc, "optional groups can only be used as top-level " 2866 "elements"); 2867 consumeToken(); 2868 2869 // Parse the child elements for this optional group. 2870 std::vector<std::unique_ptr<Element>> thenElements, elseElements; 2871 Optional<unsigned> anchorIdx; 2872 do { 2873 if (failed(parseOptionalChildElement(thenElements, anchorIdx))) 2874 return ::mlir::failure(); 2875 } while (curToken.getKind() != FormatToken::r_paren); 2876 consumeToken(); 2877 2878 // Parse the `else` elements of this optional group. 2879 if (curToken.getKind() == FormatToken::colon) { 2880 consumeToken(); 2881 if (failed(parseToken(FormatToken::l_paren, 2882 "expected '(' to start else branch " 2883 "of optional group"))) 2884 return failure(); 2885 do { 2886 llvm::SMLoc childLoc = curToken.getLoc(); 2887 elseElements.push_back({}); 2888 if (failed(parseElement(elseElements.back(), TopLevelContext)) || 2889 failed(verifyOptionalChildElement(elseElements.back().get(), childLoc, 2890 /*isAnchor=*/false))) 2891 return failure(); 2892 } while (curToken.getKind() != FormatToken::r_paren); 2893 consumeToken(); 2894 } 2895 2896 if (failed(parseToken(FormatToken::question, 2897 "expected '?' after optional group"))) 2898 return ::mlir::failure(); 2899 2900 // The optional group is required to have an anchor. 
2901 if (!anchorIdx) 2902 return emitError(curLoc, "optional group specified no anchor element"); 2903 2904 // The first parsable element of the group must be able to be parsed in an 2905 // optional fashion. 2906 auto parseBegin = llvm::find_if_not(thenElements, [](auto &element) { 2907 return isa<WhitespaceElement>(element.get()); 2908 }); 2909 Element *firstElement = parseBegin->get(); 2910 if (!isa<AttributeVariable>(firstElement) && 2911 !isa<LiteralElement>(firstElement) && 2912 !isa<OperandVariable>(firstElement) && !isa<RegionVariable>(firstElement)) 2913 return emitError(curLoc, 2914 "first parsable element of an operand group must be " 2915 "an attribute, literal, operand, or region"); 2916 2917 auto parseStart = parseBegin - thenElements.begin(); 2918 element = std::make_unique<OptionalElement>( 2919 std::move(thenElements), std::move(elseElements), *anchorIdx, parseStart); 2920 return ::mlir::success(); 2921 } 2922 2923 LogicalResult FormatParser::parseOptionalChildElement( 2924 std::vector<std::unique_ptr<Element>> &childElements, 2925 Optional<unsigned> &anchorIdx) { 2926 llvm::SMLoc childLoc = curToken.getLoc(); 2927 childElements.push_back({}); 2928 if (failed(parseElement(childElements.back(), TopLevelContext))) 2929 return ::mlir::failure(); 2930 2931 // Check to see if this element is the anchor of the optional group. 
2932 bool isAnchor = curToken.getKind() == FormatToken::caret; 2933 if (isAnchor) { 2934 if (anchorIdx) 2935 return emitError(childLoc, "only one element can be marked as the anchor " 2936 "of an optional group"); 2937 anchorIdx = childElements.size() - 1; 2938 consumeToken(); 2939 } 2940 2941 return verifyOptionalChildElement(childElements.back().get(), childLoc, 2942 isAnchor); 2943 } 2944 2945 LogicalResult FormatParser::verifyOptionalChildElement(Element *element, 2946 llvm::SMLoc childLoc, 2947 bool isAnchor) { 2948 return TypeSwitch<Element *, LogicalResult>(element) 2949 // All attributes can be within the optional group, but only optional 2950 // attributes can be the anchor. 2951 .Case([&](AttributeVariable *attrEle) { 2952 if (isAnchor && !attrEle->getVar()->attr.isOptional()) 2953 return emitError(childLoc, "only optional attributes can be used to " 2954 "anchor an optional group"); 2955 return ::mlir::success(); 2956 }) 2957 // Only optional-like(i.e. variadic) operands can be within an optional 2958 // group. 2959 .Case([&](OperandVariable *ele) { 2960 if (!ele->getVar()->isVariableLength()) 2961 return emitError(childLoc, "only variable length operands can be " 2962 "used within an optional group"); 2963 return ::mlir::success(); 2964 }) 2965 // Only optional-like(i.e. variadic) results can be within an optional 2966 // group. 2967 .Case([&](ResultVariable *ele) { 2968 if (!ele->getVar()->isVariableLength()) 2969 return emitError(childLoc, "only variable length results can be " 2970 "used within an optional group"); 2971 return ::mlir::success(); 2972 }) 2973 .Case([&](RegionVariable *) { 2974 // TODO: When ODS has proper support for marking "optional" regions, add 2975 // a check here. 
2976 return ::mlir::success(); 2977 }) 2978 .Case([&](TypeDirective *ele) { 2979 return verifyOptionalChildElement(ele->getOperand(), childLoc, 2980 /*isAnchor=*/false); 2981 }) 2982 .Case([&](FunctionalTypeDirective *ele) { 2983 if (failed(verifyOptionalChildElement(ele->getInputs(), childLoc, 2984 /*isAnchor=*/false))) 2985 return failure(); 2986 return verifyOptionalChildElement(ele->getResults(), childLoc, 2987 /*isAnchor=*/false); 2988 }) 2989 // Literals, whitespace, and custom directives may be used, but they can't 2990 // anchor the group. 2991 .Case<LiteralElement, WhitespaceElement, CustomDirective, 2992 FunctionalTypeDirective, OptionalElement>([&](Element *) { 2993 if (isAnchor) 2994 return emitError(childLoc, "only variables and types can be used " 2995 "to anchor an optional group"); 2996 return ::mlir::success(); 2997 }) 2998 .Default([&](Element *) { 2999 return emitError(childLoc, "only literals, types, and variables can be " 3000 "used within an optional group"); 3001 }); 3002 } 3003 3004 LogicalResult 3005 FormatParser::parseAttrDictDirective(std::unique_ptr<Element> &element, 3006 llvm::SMLoc loc, ParserContext context, 3007 bool withKeyword) { 3008 if (context == TypeDirectiveContext) 3009 return emitError(loc, "'attr-dict' directive can only be used as a " 3010 "top-level directive"); 3011 3012 if (context == RefDirectiveContext) { 3013 if (!hasAttrDict) 3014 return emitError(loc, "'ref' of 'attr-dict' is not bound by a prior " 3015 "'attr-dict' directive"); 3016 3017 // Otherwise, this is a top-level context. 
3018 } else { 3019 if (hasAttrDict) 3020 return emitError(loc, "'attr-dict' directive has already been seen"); 3021 hasAttrDict = true; 3022 } 3023 3024 element = std::make_unique<AttrDictDirective>(withKeyword); 3025 return ::mlir::success(); 3026 } 3027 3028 LogicalResult 3029 FormatParser::parseCustomDirective(std::unique_ptr<Element> &element, 3030 llvm::SMLoc loc, ParserContext context) { 3031 llvm::SMLoc curLoc = curToken.getLoc(); 3032 if (context != TopLevelContext) 3033 return emitError(loc, "'custom' is only valid as a top-level directive"); 3034 3035 // Parse the custom directive name. 3036 if (failed(parseToken(FormatToken::less, 3037 "expected '<' before custom directive name"))) 3038 return ::mlir::failure(); 3039 3040 FormatToken nameTok = curToken; 3041 if (failed(parseToken(FormatToken::identifier, 3042 "expected custom directive name identifier")) || 3043 failed(parseToken(FormatToken::greater, 3044 "expected '>' after custom directive name")) || 3045 failed(parseToken(FormatToken::l_paren, 3046 "expected '(' before custom directive parameters"))) 3047 return ::mlir::failure(); 3048 3049 // Parse the child elements for this optional group.= 3050 std::vector<std::unique_ptr<Element>> elements; 3051 do { 3052 if (failed(parseCustomDirectiveParameter(elements))) 3053 return ::mlir::failure(); 3054 if (curToken.getKind() != FormatToken::comma) 3055 break; 3056 consumeToken(); 3057 } while (true); 3058 3059 if (failed(parseToken(FormatToken::r_paren, 3060 "expected ')' after custom directive parameters"))) 3061 return ::mlir::failure(); 3062 3063 // After parsing all of the elements, ensure that all type directives refer 3064 // only to variables. 
3065 for (auto &ele : elements) { 3066 if (auto *typeEle = dyn_cast<TypeDirective>(ele.get())) { 3067 if (!isa<OperandVariable, ResultVariable>(typeEle->getOperand())) { 3068 return emitError(curLoc, "type directives within a custom directive " 3069 "may only refer to variables"); 3070 } 3071 } 3072 } 3073 3074 element = std::make_unique<CustomDirective>(nameTok.getSpelling(), 3075 std::move(elements)); 3076 return ::mlir::success(); 3077 } 3078 3079 LogicalResult FormatParser::parseCustomDirectiveParameter( 3080 std::vector<std::unique_ptr<Element>> ¶meters) { 3081 llvm::SMLoc childLoc = curToken.getLoc(); 3082 parameters.push_back({}); 3083 if (failed(parseElement(parameters.back(), CustomDirectiveContext))) 3084 return ::mlir::failure(); 3085 3086 // Verify that the element can be placed within a custom directive. 3087 if (!isa<RefDirective, TypeDirective, AttrDictDirective, AttributeVariable, 3088 OperandVariable, RegionVariable, SuccessorVariable>( 3089 parameters.back().get())) { 3090 return emitError(childLoc, "only variables and types may be used as " 3091 "parameters to a custom directive"); 3092 } 3093 return ::mlir::success(); 3094 } 3095 3096 LogicalResult FormatParser::parseFunctionalTypeDirective( 3097 std::unique_ptr<Element> &element, FormatToken tok, ParserContext context) { 3098 llvm::SMLoc loc = tok.getLoc(); 3099 if (context != TopLevelContext) 3100 return emitError( 3101 loc, "'functional-type' is only valid as a top-level directive"); 3102 3103 // Parse the main operand. 
3104 std::unique_ptr<Element> inputs, results; 3105 if (failed(parseToken(FormatToken::l_paren, 3106 "expected '(' before argument list")) || 3107 failed(parseTypeDirectiveOperand(inputs)) || 3108 failed(parseToken(FormatToken::comma, 3109 "expected ',' after inputs argument")) || 3110 failed(parseTypeDirectiveOperand(results)) || 3111 failed( 3112 parseToken(FormatToken::r_paren, "expected ')' after argument list"))) 3113 return ::mlir::failure(); 3114 element = std::make_unique<FunctionalTypeDirective>(std::move(inputs), 3115 std::move(results)); 3116 return ::mlir::success(); 3117 } 3118 3119 LogicalResult 3120 FormatParser::parseOperandsDirective(std::unique_ptr<Element> &element, 3121 llvm::SMLoc loc, ParserContext context) { 3122 if (context == RefDirectiveContext) { 3123 if (!fmt.allOperands) 3124 return emitError(loc, "'ref' of 'operands' is not bound by a prior " 3125 "'operands' directive"); 3126 3127 } else if (context == TopLevelContext || context == CustomDirectiveContext) { 3128 if (fmt.allOperands || !seenOperands.empty()) 3129 return emitError(loc, "'operands' directive creates overlap in format"); 3130 fmt.allOperands = true; 3131 } 3132 element = std::make_unique<OperandsDirective>(); 3133 return ::mlir::success(); 3134 } 3135 3136 LogicalResult 3137 FormatParser::parseReferenceDirective(std::unique_ptr<Element> &element, 3138 llvm::SMLoc loc, ParserContext context) { 3139 if (context != CustomDirectiveContext) 3140 return emitError(loc, "'ref' is only valid within a `custom` directive"); 3141 3142 std::unique_ptr<Element> operand; 3143 if (failed(parseToken(FormatToken::l_paren, 3144 "expected '(' before argument list")) || 3145 failed(parseElement(operand, RefDirectiveContext)) || 3146 failed( 3147 parseToken(FormatToken::r_paren, "expected ')' after argument list"))) 3148 return ::mlir::failure(); 3149 3150 element = std::make_unique<RefDirective>(std::move(operand)); 3151 return ::mlir::success(); 3152 } 3153 3154 LogicalResult 3155 
FormatParser::parseRegionsDirective(std::unique_ptr<Element> &element, 3156 llvm::SMLoc loc, ParserContext context) { 3157 if (context == TypeDirectiveContext) 3158 return emitError(loc, "'regions' is only valid as a top-level directive"); 3159 if (context == RefDirectiveContext) { 3160 if (!hasAllRegions) 3161 return emitError(loc, "'ref' of 'regions' is not bound by a prior " 3162 "'regions' directive"); 3163 3164 // Otherwise, this is a TopLevel directive. 3165 } else { 3166 if (hasAllRegions || !seenRegions.empty()) 3167 return emitError(loc, "'regions' directive creates overlap in format"); 3168 hasAllRegions = true; 3169 } 3170 element = std::make_unique<RegionsDirective>(); 3171 return ::mlir::success(); 3172 } 3173 3174 LogicalResult 3175 FormatParser::parseResultsDirective(std::unique_ptr<Element> &element, 3176 llvm::SMLoc loc, ParserContext context) { 3177 if (context != TypeDirectiveContext) 3178 return emitError(loc, "'results' directive can can only be used as a child " 3179 "to a 'type' directive"); 3180 element = std::make_unique<ResultsDirective>(); 3181 return ::mlir::success(); 3182 } 3183 3184 LogicalResult 3185 FormatParser::parseSuccessorsDirective(std::unique_ptr<Element> &element, 3186 llvm::SMLoc loc, ParserContext context) { 3187 if (context == TypeDirectiveContext) 3188 return emitError(loc, 3189 "'successors' is only valid as a top-level directive"); 3190 if (context == RefDirectiveContext) { 3191 if (!hasAllSuccessors) 3192 return emitError(loc, "'ref' of 'successors' is not bound by a prior " 3193 "'successors' directive"); 3194 3195 // Otherwise, this is a TopLevel directive. 
3196 } else { 3197 if (hasAllSuccessors || !seenSuccessors.empty()) 3198 return emitError(loc, "'successors' directive creates overlap in format"); 3199 hasAllSuccessors = true; 3200 } 3201 element = std::make_unique<SuccessorsDirective>(); 3202 return ::mlir::success(); 3203 } 3204 3205 LogicalResult 3206 FormatParser::parseTypeDirective(std::unique_ptr<Element> &element, 3207 FormatToken tok, ParserContext context) { 3208 llvm::SMLoc loc = tok.getLoc(); 3209 if (context == TypeDirectiveContext) 3210 return emitError(loc, "'type' cannot be used as a child of another `type`"); 3211 3212 bool isRefChild = context == RefDirectiveContext; 3213 std::unique_ptr<Element> operand; 3214 if (failed(parseToken(FormatToken::l_paren, 3215 "expected '(' before argument list")) || 3216 failed(parseTypeDirectiveOperand(operand, isRefChild)) || 3217 failed( 3218 parseToken(FormatToken::r_paren, "expected ')' after argument list"))) 3219 return ::mlir::failure(); 3220 3221 element = std::make_unique<TypeDirective>(std::move(operand)); 3222 return ::mlir::success(); 3223 } 3224 3225 LogicalResult 3226 FormatParser::parseQualifiedDirective(std::unique_ptr<Element> &element, 3227 FormatToken tok, ParserContext context) { 3228 if (failed(parseToken(FormatToken::l_paren, 3229 "expected '(' before argument list")) || 3230 failed(parseElement(element, context)) || 3231 failed( 3232 parseToken(FormatToken::r_paren, "expected ')' after argument list"))) 3233 return failure(); 3234 if (auto *attr = dyn_cast<AttributeVariable>(element.get())) { 3235 attr->setShouldBeQualified(); 3236 } else if (auto *type = dyn_cast<TypeDirective>(element.get())) { 3237 type->setShouldBeQualified(); 3238 } else { 3239 return emitError( 3240 tok.getLoc(), 3241 "'qualified' directive expects an attribute or a `type` directive"); 3242 } 3243 return success(); 3244 } 3245 3246 LogicalResult 3247 FormatParser::parseTypeDirectiveOperand(std::unique_ptr<Element> &element, 3248 bool isRefChild) { 3249 llvm::SMLoc loc 
= curToken.getLoc(); 3250 if (failed(parseElement(element, TypeDirectiveContext))) 3251 return ::mlir::failure(); 3252 if (isa<LiteralElement>(element.get())) 3253 return emitError( 3254 loc, "'type' directive operand expects variable or directive operand"); 3255 3256 if (auto *var = dyn_cast<OperandVariable>(element.get())) { 3257 unsigned opIdx = var->getVar() - op.operand_begin(); 3258 if (!isRefChild && (fmt.allOperandTypes || seenOperandTypes.test(opIdx))) 3259 return emitError(loc, "'type' of '" + var->getVar()->name + 3260 "' is already bound"); 3261 if (isRefChild && !(fmt.allOperandTypes || seenOperandTypes.test(opIdx))) 3262 return emitError(loc, "'ref' of 'type($" + var->getVar()->name + 3263 ")' is not bound by a prior 'type' directive"); 3264 seenOperandTypes.set(opIdx); 3265 } else if (auto *var = dyn_cast<ResultVariable>(element.get())) { 3266 unsigned resIdx = var->getVar() - op.result_begin(); 3267 if (!isRefChild && (fmt.allResultTypes || seenResultTypes.test(resIdx))) 3268 return emitError(loc, "'type' of '" + var->getVar()->name + 3269 "' is already bound"); 3270 if (isRefChild && !(fmt.allResultTypes || seenResultTypes.test(resIdx))) 3271 return emitError(loc, "'ref' of 'type($" + var->getVar()->name + 3272 ")' is not bound by a prior 'type' directive"); 3273 seenResultTypes.set(resIdx); 3274 } else if (isa<OperandsDirective>(&*element)) { 3275 if (!isRefChild && (fmt.allOperandTypes || seenOperandTypes.any())) 3276 return emitError(loc, "'operands' 'type' is already bound"); 3277 if (isRefChild && !fmt.allOperandTypes) 3278 return emitError(loc, "'ref' of 'type(operands)' is not bound by a prior " 3279 "'type' directive"); 3280 fmt.allOperandTypes = true; 3281 } else if (isa<ResultsDirective>(&*element)) { 3282 if (!isRefChild && (fmt.allResultTypes || seenResultTypes.any())) 3283 return emitError(loc, "'results' 'type' is already bound"); 3284 if (isRefChild && !fmt.allResultTypes) 3285 return emitError(loc, "'ref' of 'type(results)' is not 
bound by a prior " 3286 "'type' directive"); 3287 fmt.allResultTypes = true; 3288 } else { 3289 return emitError(loc, "invalid argument to 'type' directive"); 3290 } 3291 return ::mlir::success(); 3292 } 3293 3294 //===----------------------------------------------------------------------===// 3295 // Interface 3296 //===----------------------------------------------------------------------===// 3297 3298 void mlir::tblgen::generateOpFormat(const Operator &constOp, OpClass &opClass) { 3299 // TODO: Operator doesn't expose all necessary functionality via 3300 // the const interface. 3301 Operator &op = const_cast<Operator &>(constOp); 3302 if (!op.hasAssemblyFormat()) 3303 return; 3304 3305 // Parse the format description. 3306 llvm::SourceMgr mgr; 3307 mgr.AddNewSourceBuffer( 3308 llvm::MemoryBuffer::getMemBuffer(op.getAssemblyFormat()), llvm::SMLoc()); 3309 OperationFormat format(op); 3310 if (failed(FormatParser(mgr, format, op).parse())) { 3311 // Exit the process if format errors are treated as fatal. 3312 if (formatErrorIsFatal) { 3313 // Invoke the interrupt handlers to run the file cleanup handlers. 3314 llvm::sys::RunInterruptHandlers(); 3315 std::exit(1); 3316 } 3317 return; 3318 } 3319 3320 // Generate the printer and parser based on the parsed format. 3321 format.genParser(op, opClass); 3322 format.genPrinter(op, opClass); 3323 } 3324