1 //===- OpImplementation.h - Classes for implementing Op types ---*- C++ -*-===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 //
9 // This classes used by the implementation details of Op types.
10 //
11 //===----------------------------------------------------------------------===//
12 
13 #ifndef MLIR_IR_OPIMPLEMENTATION_H
14 #define MLIR_IR_OPIMPLEMENTATION_H
15 
16 #include "mlir/IR/BuiltinTypes.h"
17 #include "mlir/IR/DialectInterface.h"
18 #include "mlir/IR/OpDefinition.h"
19 #include "llvm/ADT/Twine.h"
20 #include "llvm/Support/SMLoc.h"
21 
22 namespace mlir {
23 class AsmParsedResourceEntry;
24 class AsmResourceBuilder;
25 class Builder;
26 
27 //===----------------------------------------------------------------------===//
28 // AsmDialectResourceHandle
29 //===----------------------------------------------------------------------===//
30 
31 /// This class represents an opaque handle to a dialect resource entry.
32 class AsmDialectResourceHandle {
33 public:
34   AsmDialectResourceHandle() = default;
AsmDialectResourceHandle(void * resource,TypeID resourceID,Dialect * dialect)35   AsmDialectResourceHandle(void *resource, TypeID resourceID, Dialect *dialect)
36       : resource(resource), opaqueID(resourceID), dialect(dialect) {}
37   bool operator==(const AsmDialectResourceHandle &other) const {
38     return resource == other.resource;
39   }
40 
41   /// Return an opaque pointer to the referenced resource.
getResource()42   void *getResource() const { return resource; }
43 
44   /// Return the type ID of the resource.
getTypeID()45   TypeID getTypeID() const { return opaqueID; }
46 
47   /// Return the dialect that owns the resource.
getDialect()48   Dialect *getDialect() const { return dialect; }
49 
50 private:
51   /// The opaque handle to the dialect resource.
52   void *resource = nullptr;
53   /// The type of the resource referenced.
54   TypeID opaqueID;
55   /// The dialect owning the given resource.
56   Dialect *dialect;
57 };
58 
59 /// This class represents a CRTP base class for dialect resource handles. It
60 /// abstracts away various utilities necessary for defined derived resource
61 /// handles.
62 template <typename DerivedT, typename ResourceT, typename DialectT>
63 class AsmDialectResourceHandleBase : public AsmDialectResourceHandle {
64 public:
65   using Dialect = DialectT;
66 
67   /// Construct a handle from a pointer to the resource. The given pointer
68   /// should be guaranteed to live beyond the life of this handle.
AsmDialectResourceHandleBase(ResourceT * resource,DialectT * dialect)69   AsmDialectResourceHandleBase(ResourceT *resource, DialectT *dialect)
70       : AsmDialectResourceHandle(resource, TypeID::get<DerivedT>(), dialect) {}
AsmDialectResourceHandleBase(AsmDialectResourceHandle handle)71   AsmDialectResourceHandleBase(AsmDialectResourceHandle handle)
72       : AsmDialectResourceHandle(handle) {
73     assert(handle.getTypeID() == TypeID::get<DerivedT>());
74   }
75 
76   /// Return the resource referenced by this handle.
getResource()77   ResourceT *getResource() {
78     return static_cast<ResourceT *>(AsmDialectResourceHandle::getResource());
79   }
getResource()80   const ResourceT *getResource() const {
81     return const_cast<AsmDialectResourceHandleBase *>(this)->getResource();
82   }
83 
84   /// Return the dialect that owns the resource.
getDialect()85   DialectT *getDialect() const {
86     return static_cast<DialectT *>(AsmDialectResourceHandle::getDialect());
87   }
88 
89   /// Support llvm style casting.
classof(const AsmDialectResourceHandle * handle)90   static bool classof(const AsmDialectResourceHandle *handle) {
91     return handle->getTypeID() == TypeID::get<DerivedT>();
92   }
93 };
94 
hash_value(const AsmDialectResourceHandle & param)95 inline llvm::hash_code hash_value(const AsmDialectResourceHandle &param) {
96   return llvm::hash_value(param.getResource());
97 }
98 
99 //===----------------------------------------------------------------------===//
100 // AsmPrinter
101 //===----------------------------------------------------------------------===//
102 
103 /// This base class exposes generic asm printer hooks, usable across the various
104 /// derived printers.
105 class AsmPrinter {
106 public:
107   /// This class contains the internal default implementation of the base
108   /// printer methods.
109   class Impl;
110 
111   /// Initialize the printer with the given internal implementation.
AsmPrinter(Impl & impl)112   AsmPrinter(Impl &impl) : impl(&impl) {}
113   virtual ~AsmPrinter();
114 
115   /// Return the raw output stream used by this printer.
116   virtual raw_ostream &getStream() const;
117 
118   /// Print the given floating point value in a stabilized form that can be
119   /// roundtripped through the IR. This is the companion to the 'parseFloat'
120   /// hook on the AsmParser.
121   virtual void printFloat(const APFloat &value);
122 
123   virtual void printType(Type type);
124   virtual void printAttribute(Attribute attr);
125 
126   /// Trait to check if `AttrType` provides a `print` method.
127   template <typename AttrOrType>
128   using has_print_method =
129       decltype(std::declval<AttrOrType>().print(std::declval<AsmPrinter &>()));
130   template <typename AttrOrType>
131   using detect_has_print_method =
132       llvm::is_detected<has_print_method, AttrOrType>;
133 
134   /// Print the provided attribute in the context of an operation custom
135   /// printer/parser: this will invoke directly the print method on the
136   /// attribute class and skip the `#dialect.mnemonic` prefix in most cases.
137   template <typename AttrOrType,
138             std::enable_if_t<detect_has_print_method<AttrOrType>::value>
139                 *sfinae = nullptr>
printStrippedAttrOrType(AttrOrType attrOrType)140   void printStrippedAttrOrType(AttrOrType attrOrType) {
141     if (succeeded(printAlias(attrOrType)))
142       return;
143     attrOrType.print(*this);
144   }
145 
146   /// Print the provided array of attributes or types in the context of an
147   /// operation custom printer/parser: this will invoke directly the print
148   /// method on the attribute class and skip the `#dialect.mnemonic` prefix in
149   /// most cases.
150   template <typename AttrOrType,
151             std::enable_if_t<detect_has_print_method<AttrOrType>::value>
152                 *sfinae = nullptr>
printStrippedAttrOrType(ArrayRef<AttrOrType> attrOrTypes)153   void printStrippedAttrOrType(ArrayRef<AttrOrType> attrOrTypes) {
154     llvm::interleaveComma(
155         attrOrTypes, getStream(),
156         [this](AttrOrType attrOrType) { printStrippedAttrOrType(attrOrType); });
157   }
158 
159   /// SFINAE for printing the provided attribute in the context of an operation
160   /// custom printer in the case where the attribute does not define a print
161   /// method.
162   template <typename AttrOrType,
163             std::enable_if_t<!detect_has_print_method<AttrOrType>::value>
164                 *sfinae = nullptr>
printStrippedAttrOrType(AttrOrType attrOrType)165   void printStrippedAttrOrType(AttrOrType attrOrType) {
166     *this << attrOrType;
167   }
168 
169   /// Print the given attribute without its type. The corresponding parser must
170   /// provide a valid type for the attribute.
171   virtual void printAttributeWithoutType(Attribute attr);
172 
173   /// Print the given string as a keyword, or a quoted and escaped string if it
174   /// has any special or non-printable characters in it.
175   virtual void printKeywordOrString(StringRef keyword);
176 
177   /// Print the given string as a symbol reference, i.e. a form representable by
178   /// a SymbolRefAttr. A symbol reference is represented as a string prefixed
179   /// with '@'. The reference is surrounded with ""'s and escaped if it has any
180   /// special or non-printable characters in it.
181   virtual void printSymbolName(StringRef symbolRef);
182 
183   /// Print a handle to the given dialect resource.
184   void printResourceHandle(const AsmDialectResourceHandle &resource);
185 
186   /// Print an optional arrow followed by a type list.
187   template <typename TypeRange>
printOptionalArrowTypeList(TypeRange && types)188   void printOptionalArrowTypeList(TypeRange &&types) {
189     if (types.begin() != types.end())
190       printArrowTypeList(types);
191   }
192   template <typename TypeRange>
printArrowTypeList(TypeRange && types)193   void printArrowTypeList(TypeRange &&types) {
194     auto &os = getStream() << " -> ";
195 
196     bool wrapped = !llvm::hasSingleElement(types) ||
197                    (*types.begin()).template isa<FunctionType>();
198     if (wrapped)
199       os << '(';
200     llvm::interleaveComma(types, *this);
201     if (wrapped)
202       os << ')';
203   }
204 
205   /// Print the two given type ranges in a functional form.
206   template <typename InputRangeT, typename ResultRangeT>
printFunctionalType(InputRangeT && inputs,ResultRangeT && results)207   void printFunctionalType(InputRangeT &&inputs, ResultRangeT &&results) {
208     auto &os = getStream();
209     os << '(';
210     llvm::interleaveComma(inputs, *this);
211     os << ')';
212     printArrowTypeList(results);
213   }
214 
215 protected:
216   /// Initialize the printer with no internal implementation. In this case, all
217   /// virtual methods of this class must be overriden.
AsmPrinter()218   AsmPrinter() {}
219 
220 private:
221   AsmPrinter(const AsmPrinter &) = delete;
222   void operator=(const AsmPrinter &) = delete;
223 
224   /// Print the alias for the given attribute, return failure if no alias could
225   /// be printed.
226   virtual LogicalResult printAlias(Attribute attr);
227 
228   /// Print the alias for the given type, return failure if no alias could
229   /// be printed.
230   virtual LogicalResult printAlias(Type type);
231 
232   /// The internal implementation of the printer.
233   Impl *impl{nullptr};
234 };
235 
236 template <typename AsmPrinterT>
237 inline std::enable_if_t<std::is_base_of<AsmPrinter, AsmPrinterT>::value,
238                         AsmPrinterT &>
239 operator<<(AsmPrinterT &p, Type type) {
240   p.printType(type);
241   return p;
242 }
243 
244 template <typename AsmPrinterT>
245 inline std::enable_if_t<std::is_base_of<AsmPrinter, AsmPrinterT>::value,
246                         AsmPrinterT &>
247 operator<<(AsmPrinterT &p, Attribute attr) {
248   p.printAttribute(attr);
249   return p;
250 }
251 
252 template <typename AsmPrinterT>
253 inline std::enable_if_t<std::is_base_of<AsmPrinter, AsmPrinterT>::value,
254                         AsmPrinterT &>
255 operator<<(AsmPrinterT &p, const APFloat &value) {
256   p.printFloat(value);
257   return p;
258 }
259 template <typename AsmPrinterT>
260 inline std::enable_if_t<std::is_base_of<AsmPrinter, AsmPrinterT>::value,
261                         AsmPrinterT &>
262 operator<<(AsmPrinterT &p, float value) {
263   return p << APFloat(value);
264 }
265 template <typename AsmPrinterT>
266 inline std::enable_if_t<std::is_base_of<AsmPrinter, AsmPrinterT>::value,
267                         AsmPrinterT &>
268 operator<<(AsmPrinterT &p, double value) {
269   return p << APFloat(value);
270 }
271 
272 // Support printing anything that isn't convertible to one of the other
273 // streamable types, even if it isn't exactly one of them. For example, we want
274 // to print FunctionType with the Type version above, not have it match this.
275 template <
276     typename AsmPrinterT, typename T,
277     typename std::enable_if<!std::is_convertible<T &, Value &>::value &&
278                                 !std::is_convertible<T &, Type &>::value &&
279                                 !std::is_convertible<T &, Attribute &>::value &&
280                                 !std::is_convertible<T &, ValueRange>::value &&
281                                 !std::is_convertible<T &, APFloat &>::value &&
282                                 !llvm::is_one_of<T, bool, float, double>::value,
283                             T>::type * = nullptr>
284 inline std::enable_if_t<std::is_base_of<AsmPrinter, AsmPrinterT>::value,
285                         AsmPrinterT &>
286 operator<<(AsmPrinterT &p, const T &other) {
287   p.getStream() << other;
288   return p;
289 }
290 
291 template <typename AsmPrinterT>
292 inline std::enable_if_t<std::is_base_of<AsmPrinter, AsmPrinterT>::value,
293                         AsmPrinterT &>
294 operator<<(AsmPrinterT &p, bool value) {
295   return p << (value ? StringRef("true") : "false");
296 }
297 
298 template <typename AsmPrinterT, typename ValueRangeT>
299 inline std::enable_if_t<std::is_base_of<AsmPrinter, AsmPrinterT>::value,
300                         AsmPrinterT &>
301 operator<<(AsmPrinterT &p, const ValueTypeRange<ValueRangeT> &types) {
302   llvm::interleaveComma(types, p);
303   return p;
304 }
305 template <typename AsmPrinterT>
306 inline std::enable_if_t<std::is_base_of<AsmPrinter, AsmPrinterT>::value,
307                         AsmPrinterT &>
308 operator<<(AsmPrinterT &p, const TypeRange &types) {
309   llvm::interleaveComma(types, p);
310   return p;
311 }
312 template <typename AsmPrinterT, typename ElementT>
313 inline std::enable_if_t<std::is_base_of<AsmPrinter, AsmPrinterT>::value,
314                         AsmPrinterT &>
315 operator<<(AsmPrinterT &p, ArrayRef<ElementT> types) {
316   llvm::interleaveComma(types, p);
317   return p;
318 }
319 
320 //===----------------------------------------------------------------------===//
321 // OpAsmPrinter
322 //===----------------------------------------------------------------------===//
323 
324 /// This is a pure-virtual base class that exposes the asmprinter hooks
325 /// necessary to implement a custom print() method.
326 class OpAsmPrinter : public AsmPrinter {
327 public:
328   using AsmPrinter::AsmPrinter;
329   ~OpAsmPrinter() override;
330 
331   /// Print a newline and indent the printer to the start of the current
332   /// operation.
333   virtual void printNewline() = 0;
334 
335   /// Print a block argument in the usual format of:
336   ///   %ssaName : type {attr1=42} loc("here")
337   /// where location printing is controlled by the standard internal option.
338   /// You may pass omitType=true to not print a type, and pass an empty
339   /// attribute list if you don't care for attributes.
340   virtual void printRegionArgument(BlockArgument arg,
341                                    ArrayRef<NamedAttribute> argAttrs = {},
342                                    bool omitType = false) = 0;
343 
344   /// Print implementations for various things an operation contains.
345   virtual void printOperand(Value value) = 0;
346   virtual void printOperand(Value value, raw_ostream &os) = 0;
347 
348   /// Print a comma separated list of operands.
349   template <typename ContainerType>
printOperands(const ContainerType & container)350   void printOperands(const ContainerType &container) {
351     printOperands(container.begin(), container.end());
352   }
353 
354   /// Print a comma separated list of operands.
355   template <typename IteratorType>
printOperands(IteratorType it,IteratorType end)356   void printOperands(IteratorType it, IteratorType end) {
357     if (it == end)
358       return;
359     printOperand(*it);
360     for (++it; it != end; ++it) {
361       getStream() << ", ";
362       printOperand(*it);
363     }
364   }
365 
366   /// Print the given successor.
367   virtual void printSuccessor(Block *successor) = 0;
368 
369   /// Print the successor and its operands.
370   virtual void printSuccessorAndUseList(Block *successor,
371                                         ValueRange succOperands) = 0;
372 
373   /// If the specified operation has attributes, print out an attribute
374   /// dictionary with their values.  elidedAttrs allows the client to ignore
375   /// specific well known attributes, commonly used if the attribute value is
376   /// printed some other way (like as a fixed operand).
377   virtual void printOptionalAttrDict(ArrayRef<NamedAttribute> attrs,
378                                      ArrayRef<StringRef> elidedAttrs = {}) = 0;
379 
380   /// If the specified operation has attributes, print out an attribute
381   /// dictionary prefixed with 'attributes'.
382   virtual void
383   printOptionalAttrDictWithKeyword(ArrayRef<NamedAttribute> attrs,
384                                    ArrayRef<StringRef> elidedAttrs = {}) = 0;
385 
386   /// Print the entire operation with the default generic assembly form.
387   /// If `printOpName` is true, then the operation name is printed (the default)
388   /// otherwise it is omitted and the print will start with the operand list.
389   virtual void printGenericOp(Operation *op, bool printOpName = true) = 0;
390 
391   /// Prints a region.
392   /// If 'printEntryBlockArgs' is false, the arguments of the
393   /// block are not printed. If 'printBlockTerminator' is false, the terminator
394   /// operation of the block is not printed. If printEmptyBlock is true, then
395   /// the block header is printed even if the block is empty.
396   virtual void printRegion(Region &blocks, bool printEntryBlockArgs = true,
397                            bool printBlockTerminators = true,
398                            bool printEmptyBlock = false) = 0;
399 
400   /// Renumber the arguments for the specified region to the same names as the
401   /// SSA values in namesToUse.  This may only be used for IsolatedFromAbove
402   /// operations.  If any entry in namesToUse is null, the corresponding
403   /// argument name is left alone.
404   virtual void shadowRegionArgs(Region &region, ValueRange namesToUse) = 0;
405 
406   /// Prints an affine map of SSA ids, where SSA id names are used in place
407   /// of dims/symbols.
408   /// Operand values must come from single-result sources, and be valid
409   /// dimensions/symbol identifiers according to mlir::isValidDim/Symbol.
410   virtual void printAffineMapOfSSAIds(AffineMapAttr mapAttr,
411                                       ValueRange operands) = 0;
412 
413   /// Prints an affine expression of SSA ids with SSA id names used instead of
414   /// dims and symbols.
415   /// Operand values must come from single-result sources, and be valid
416   /// dimensions/symbol identifiers according to mlir::isValidDim/Symbol.
417   virtual void printAffineExprOfSSAIds(AffineExpr expr, ValueRange dimOperands,
418                                        ValueRange symOperands) = 0;
419 
420   /// Print the complete type of an operation in functional form.
421   void printFunctionalType(Operation *op);
422   using AsmPrinter::printFunctionalType;
423 };
424 
425 // Make the implementations convenient to use.
426 inline OpAsmPrinter &operator<<(OpAsmPrinter &p, Value value) {
427   p.printOperand(value);
428   return p;
429 }
430 
431 template <typename T,
432           typename std::enable_if<std::is_convertible<T &, ValueRange>::value &&
433                                       !std::is_convertible<T &, Value &>::value,
434                                   T>::type * = nullptr>
435 inline OpAsmPrinter &operator<<(OpAsmPrinter &p, const T &values) {
436   p.printOperands(values);
437   return p;
438 }
439 
440 inline OpAsmPrinter &operator<<(OpAsmPrinter &p, Block *value) {
441   p.printSuccessor(value);
442   return p;
443 }
444 
445 //===----------------------------------------------------------------------===//
446 // AsmParser
447 //===----------------------------------------------------------------------===//
448 
449 /// This base class exposes generic asm parser hooks, usable across the various
450 /// derived parsers.
451 class AsmParser {
452 public:
453   AsmParser() = default;
454   virtual ~AsmParser();
455 
456   MLIRContext *getContext() const;
457 
458   /// Return the location of the original name token.
459   virtual SMLoc getNameLoc() const = 0;
460 
461   //===--------------------------------------------------------------------===//
462   // Utilities
463   //===--------------------------------------------------------------------===//
464 
465   /// Emit a diagnostic at the specified location and return failure.
466   virtual InFlightDiagnostic emitError(SMLoc loc,
467                                        const Twine &message = {}) = 0;
468 
469   /// Return a builder which provides useful access to MLIRContext, global
470   /// objects like types and attributes.
471   virtual Builder &getBuilder() const = 0;
472 
473   /// Get the location of the next token and store it into the argument.  This
474   /// always succeeds.
475   virtual SMLoc getCurrentLocation() = 0;
getCurrentLocation(SMLoc * loc)476   ParseResult getCurrentLocation(SMLoc *loc) {
477     *loc = getCurrentLocation();
478     return success();
479   }
480 
481   /// Re-encode the given source location as an MLIR location and return it.
482   /// Note: This method should only be used when a `Location` is necessary, as
483   /// the encoding process is not efficient.
484   virtual Location getEncodedSourceLoc(SMLoc loc) = 0;
485 
486   //===--------------------------------------------------------------------===//
487   // Token Parsing
488   //===--------------------------------------------------------------------===//
489 
490   /// Parse a '->' token.
491   virtual ParseResult parseArrow() = 0;
492 
493   /// Parse a '->' token if present
494   virtual ParseResult parseOptionalArrow() = 0;
495 
496   /// Parse a `{` token.
497   virtual ParseResult parseLBrace() = 0;
498 
499   /// Parse a `{` token if present.
500   virtual ParseResult parseOptionalLBrace() = 0;
501 
502   /// Parse a `}` token.
503   virtual ParseResult parseRBrace() = 0;
504 
505   /// Parse a `}` token if present.
506   virtual ParseResult parseOptionalRBrace() = 0;
507 
508   /// Parse a `:` token.
509   virtual ParseResult parseColon() = 0;
510 
511   /// Parse a `:` token if present.
512   virtual ParseResult parseOptionalColon() = 0;
513 
514   /// Parse a `,` token.
515   virtual ParseResult parseComma() = 0;
516 
517   /// Parse a `,` token if present.
518   virtual ParseResult parseOptionalComma() = 0;
519 
520   /// Parse a `=` token.
521   virtual ParseResult parseEqual() = 0;
522 
523   /// Parse a `=` token if present.
524   virtual ParseResult parseOptionalEqual() = 0;
525 
526   /// Parse a '<' token.
527   virtual ParseResult parseLess() = 0;
528 
529   /// Parse a '<' token if present.
530   virtual ParseResult parseOptionalLess() = 0;
531 
532   /// Parse a '>' token.
533   virtual ParseResult parseGreater() = 0;
534 
535   /// Parse a '>' token if present.
536   virtual ParseResult parseOptionalGreater() = 0;
537 
538   /// Parse a '?' token.
539   virtual ParseResult parseQuestion() = 0;
540 
541   /// Parse a '?' token if present.
542   virtual ParseResult parseOptionalQuestion() = 0;
543 
544   /// Parse a '+' token.
545   virtual ParseResult parsePlus() = 0;
546 
547   /// Parse a '+' token if present.
548   virtual ParseResult parseOptionalPlus() = 0;
549 
550   /// Parse a '*' token.
551   virtual ParseResult parseStar() = 0;
552 
553   /// Parse a '*' token if present.
554   virtual ParseResult parseOptionalStar() = 0;
555 
556   /// Parse a '|' token.
557   virtual ParseResult parseVerticalBar() = 0;
558 
559   /// Parse a '|' token if present.
560   virtual ParseResult parseOptionalVerticalBar() = 0;
561 
562   /// Parse a quoted string token.
parseString(std::string * string)563   ParseResult parseString(std::string *string) {
564     auto loc = getCurrentLocation();
565     if (parseOptionalString(string))
566       return emitError(loc, "expected string");
567     return success();
568   }
569 
570   /// Parse a quoted string token if present.
571   virtual ParseResult parseOptionalString(std::string *string) = 0;
572 
573   /// Parse a `(` token.
574   virtual ParseResult parseLParen() = 0;
575 
576   /// Parse a `(` token if present.
577   virtual ParseResult parseOptionalLParen() = 0;
578 
579   /// Parse a `)` token.
580   virtual ParseResult parseRParen() = 0;
581 
582   /// Parse a `)` token if present.
583   virtual ParseResult parseOptionalRParen() = 0;
584 
585   /// Parse a `[` token.
586   virtual ParseResult parseLSquare() = 0;
587 
588   /// Parse a `[` token if present.
589   virtual ParseResult parseOptionalLSquare() = 0;
590 
591   /// Parse a `]` token.
592   virtual ParseResult parseRSquare() = 0;
593 
594   /// Parse a `]` token if present.
595   virtual ParseResult parseOptionalRSquare() = 0;
596 
597   /// Parse a `...` token if present;
598   virtual ParseResult parseOptionalEllipsis() = 0;
599 
600   /// Parse a floating point value from the stream.
601   virtual ParseResult parseFloat(double &result) = 0;
602 
603   /// Parse an integer value from the stream.
604   template <typename IntT>
parseInteger(IntT & result)605   ParseResult parseInteger(IntT &result) {
606     auto loc = getCurrentLocation();
607     OptionalParseResult parseResult = parseOptionalInteger(result);
608     if (!parseResult.hasValue())
609       return emitError(loc, "expected integer value");
610     return *parseResult;
611   }
612 
613   /// Parse an optional integer value from the stream.
614   virtual OptionalParseResult parseOptionalInteger(APInt &result) = 0;
615 
616   template <typename IntT>
parseOptionalInteger(IntT & result)617   OptionalParseResult parseOptionalInteger(IntT &result) {
618     auto loc = getCurrentLocation();
619 
620     // Parse the unsigned variant.
621     APInt uintResult;
622     OptionalParseResult parseResult = parseOptionalInteger(uintResult);
623     if (!parseResult.hasValue() || failed(*parseResult))
624       return parseResult;
625 
626     // Try to convert to the provided integer type.  sextOrTrunc is correct even
627     // for unsigned types because parseOptionalInteger ensures the sign bit is
628     // zero for non-negated integers.
629     result =
630         (IntT)uintResult.sextOrTrunc(sizeof(IntT) * CHAR_BIT).getLimitedValue();
631     if (APInt(uintResult.getBitWidth(), result) != uintResult)
632       return emitError(loc, "integer value too large");
633     return success();
634   }
635 
636   /// These are the supported delimiters around operand lists and region
637   /// argument lists, used by parseOperandList.
638   enum class Delimiter {
639     /// Zero or more operands with no delimiters.
640     None,
641     /// Parens surrounding zero or more operands.
642     Paren,
643     /// Square brackets surrounding zero or more operands.
644     Square,
645     /// <> brackets surrounding zero or more operands.
646     LessGreater,
647     /// {} brackets surrounding zero or more operands.
648     Braces,
649     /// Parens supporting zero or more operands, or nothing.
650     OptionalParen,
651     /// Square brackets supporting zero or more ops, or nothing.
652     OptionalSquare,
653     /// <> brackets supporting zero or more ops, or nothing.
654     OptionalLessGreater,
655     /// {} brackets surrounding zero or more operands, or nothing.
656     OptionalBraces,
657   };
658 
659   /// Parse a list of comma-separated items with an optional delimiter.  If a
660   /// delimiter is provided, then an empty list is allowed.  If not, then at
661   /// least one element will be parsed.
662   ///
663   /// contextMessage is an optional message appended to "expected '('" sorts of
664   /// diagnostics when parsing the delimeters.
665   virtual ParseResult
666   parseCommaSeparatedList(Delimiter delimiter,
667                           function_ref<ParseResult()> parseElementFn,
668                           StringRef contextMessage = StringRef()) = 0;
669 
670   /// Parse a comma separated list of elements that must have at least one entry
671   /// in it.
672   ParseResult
parseCommaSeparatedList(function_ref<ParseResult ()> parseElementFn)673   parseCommaSeparatedList(function_ref<ParseResult()> parseElementFn) {
674     return parseCommaSeparatedList(Delimiter::None, parseElementFn);
675   }
676 
677   //===--------------------------------------------------------------------===//
678   // Keyword Parsing
679   //===--------------------------------------------------------------------===//
680 
681   /// This class represents a StringSwitch like class that is useful for parsing
682   /// expected keywords. On construction, it invokes `parseKeyword` and
683   /// processes each of the provided cases statements until a match is hit. The
684   /// provided `ResultT` must be assignable from `failure()`.
685   template <typename ResultT = ParseResult>
686   class KeywordSwitch {
687   public:
KeywordSwitch(AsmParser & parser)688     KeywordSwitch(AsmParser &parser)
689         : parser(parser), loc(parser.getCurrentLocation()) {
690       if (failed(parser.parseKeywordOrCompletion(&keyword)))
691         result = failure();
692     }
693 
694     /// Case that uses the provided value when true.
Case(StringLiteral str,ResultT value)695     KeywordSwitch &Case(StringLiteral str, ResultT value) {
696       return Case(str, [&](StringRef, SMLoc) { return std::move(value); });
697     }
Default(ResultT value)698     KeywordSwitch &Default(ResultT value) {
699       return Default([&](StringRef, SMLoc) { return std::move(value); });
700     }
701     /// Case that invokes the provided functor when true. The parameters passed
702     /// to the functor are the keyword, and the location of the keyword (in case
703     /// any errors need to be emitted).
704     template <typename FnT>
705     std::enable_if_t<!std::is_convertible<FnT, ResultT>::value, KeywordSwitch &>
Case(StringLiteral str,FnT && fn)706     Case(StringLiteral str, FnT &&fn) {
707       if (result)
708         return *this;
709 
710       // If the word was empty, record this as a completion.
711       if (keyword.empty())
712         parser.codeCompleteExpectedTokens(str);
713       else if (keyword == str)
714         result.emplace(std::move(fn(keyword, loc)));
715       return *this;
716     }
717     template <typename FnT>
718     std::enable_if_t<!std::is_convertible<FnT, ResultT>::value, KeywordSwitch &>
Default(FnT && fn)719     Default(FnT &&fn) {
720       if (!result)
721         result.emplace(fn(keyword, loc));
722       return *this;
723     }
724 
725     /// Returns true if this switch has a value yet.
hasValue()726     bool hasValue() const { return result.has_value(); }
727 
728     /// Return the result of the switch.
ResultT()729     LLVM_NODISCARD operator ResultT() {
730       if (!result)
731         return parser.emitError(loc, "unexpected keyword: ") << keyword;
732       return std::move(*result);
733     }
734 
735   private:
736     /// The parser used to construct this switch.
737     AsmParser &parser;
738 
739     /// The location of the keyword, used to emit errors as necessary.
740     SMLoc loc;
741 
742     /// The parsed keyword itself.
743     StringRef keyword;
744 
745     /// The result of the switch statement or none if currently unknown.
746     Optional<ResultT> result;
747   };
748 
749   /// Parse a given keyword.
parseKeyword(StringRef keyword)750   ParseResult parseKeyword(StringRef keyword) {
751     return parseKeyword(keyword, "");
752   }
753   virtual ParseResult parseKeyword(StringRef keyword, const Twine &msg) = 0;
754 
755   /// Parse a keyword into 'keyword'.
parseKeyword(StringRef * keyword)756   ParseResult parseKeyword(StringRef *keyword) {
757     auto loc = getCurrentLocation();
758     if (parseOptionalKeyword(keyword))
759       return emitError(loc, "expected valid keyword");
760     return success();
761   }
762 
763   /// Parse the given keyword if present.
764   virtual ParseResult parseOptionalKeyword(StringRef keyword) = 0;
765 
766   /// Parse a keyword, if present, into 'keyword'.
767   virtual ParseResult parseOptionalKeyword(StringRef *keyword) = 0;
768 
769   /// Parse a keyword, if present, and if one of the 'allowedValues',
770   /// into 'keyword'
771   virtual ParseResult
772   parseOptionalKeyword(StringRef *keyword,
773                        ArrayRef<StringRef> allowedValues) = 0;
774 
775   /// Parse a keyword or a quoted string.
parseKeywordOrString(std::string * result)776   ParseResult parseKeywordOrString(std::string *result) {
777     if (failed(parseOptionalKeywordOrString(result)))
778       return emitError(getCurrentLocation())
779              << "expected valid keyword or string";
780     return success();
781   }
782 
783   /// Parse an optional keyword or string.
784   virtual ParseResult parseOptionalKeywordOrString(std::string *result) = 0;
785 
786   //===--------------------------------------------------------------------===//
787   // Attribute/Type Parsing
788   //===--------------------------------------------------------------------===//
789 
790   /// Invoke the `getChecked` method of the given Attribute or Type class, using
791   /// the provided location to emit errors in the case of failure. Note that
792   /// unlike `OpBuilder::getType`, this method does not implicitly insert a
793   /// context parameter.
794   template <typename T, typename... ParamsT>
getChecked(SMLoc loc,ParamsT &&...params)795   auto getChecked(SMLoc loc, ParamsT &&...params) {
796     return T::getChecked([&] { return emitError(loc); },
797                          std::forward<ParamsT>(params)...);
798   }
799   /// A variant of `getChecked` that uses the result of `getNameLoc` to emit
800   /// errors.
801   template <typename T, typename... ParamsT>
getChecked(ParamsT &&...params)802   auto getChecked(ParamsT &&...params) {
803     return T::getChecked([&] { return emitError(getNameLoc()); },
804                          std::forward<ParamsT>(params)...);
805   }
806 
807   //===--------------------------------------------------------------------===//
808   // Attribute Parsing
809   //===--------------------------------------------------------------------===//
810 
811   /// Parse an arbitrary attribute of a given type and return it in result.
812   virtual ParseResult parseAttribute(Attribute &result, Type type = {}) = 0;
813 
814   /// Parse a custom attribute with the provided callback, unless the next
815   /// token is `#`, in which case the generic parser is invoked.
816   virtual ParseResult parseCustomAttributeWithFallback(
817       Attribute &result, Type type,
818       function_ref<ParseResult(Attribute &result, Type type)>
819           parseAttribute) = 0;
820 
821   /// Parse an attribute of a specific kind and type.
822   template <typename AttrType>
823   ParseResult parseAttribute(AttrType &result, Type type = {}) {
824     SMLoc loc = getCurrentLocation();
825 
826     // Parse any kind of attribute.
827     Attribute attr;
828     if (parseAttribute(attr, type))
829       return failure();
830 
831     // Check for the right kind of attribute.
832     if (!(result = attr.dyn_cast<AttrType>()))
833       return emitError(loc, "invalid kind of attribute specified");
834 
835     return success();
836   }
837 
838   /// Parse an arbitrary attribute and return it in result.  This also adds the
839   /// attribute to the specified attribute list with the specified name.
parseAttribute(Attribute & result,StringRef attrName,NamedAttrList & attrs)840   ParseResult parseAttribute(Attribute &result, StringRef attrName,
841                              NamedAttrList &attrs) {
842     return parseAttribute(result, Type(), attrName, attrs);
843   }
844 
845   /// Parse an attribute of a specific kind and type.
846   template <typename AttrType>
parseAttribute(AttrType & result,StringRef attrName,NamedAttrList & attrs)847   ParseResult parseAttribute(AttrType &result, StringRef attrName,
848                              NamedAttrList &attrs) {
849     return parseAttribute(result, Type(), attrName, attrs);
850   }
851 
852   /// Parse an arbitrary attribute of a given type and populate it in `result`.
853   /// This also adds the attribute to the specified attribute list with the
854   /// specified name.
855   template <typename AttrType>
parseAttribute(AttrType & result,Type type,StringRef attrName,NamedAttrList & attrs)856   ParseResult parseAttribute(AttrType &result, Type type, StringRef attrName,
857                              NamedAttrList &attrs) {
858     SMLoc loc = getCurrentLocation();
859 
860     // Parse any kind of attribute.
861     Attribute attr;
862     if (parseAttribute(attr, type))
863       return failure();
864 
865     // Check for the right kind of attribute.
866     result = attr.dyn_cast<AttrType>();
867     if (!result)
868       return emitError(loc, "invalid kind of attribute specified");
869 
870     attrs.append(attrName, result);
871     return success();
872   }
873 
874   /// Trait to check if `AttrType` provides a `parse` method.
875   template <typename AttrType>
876   using has_parse_method = decltype(AttrType::parse(std::declval<AsmParser &>(),
877                                                     std::declval<Type>()));
878   template <typename AttrType>
879   using detect_has_parse_method = llvm::is_detected<has_parse_method, AttrType>;
880 
881   /// Parse a custom attribute of a given type unless the next token is `#`, in
882   /// which case the generic parser is invoked. The parsed attribute is
883   /// populated in `result` and also added to the specified attribute list with
884   /// the specified name.
885   template <typename AttrType>
886   std::enable_if_t<detect_has_parse_method<AttrType>::value, ParseResult>
parseCustomAttributeWithFallback(AttrType & result,Type type,StringRef attrName,NamedAttrList & attrs)887   parseCustomAttributeWithFallback(AttrType &result, Type type,
888                                    StringRef attrName, NamedAttrList &attrs) {
889     SMLoc loc = getCurrentLocation();
890 
891     // Parse any kind of attribute.
892     Attribute attr;
893     if (parseCustomAttributeWithFallback(
894             attr, type, [&](Attribute &result, Type type) -> ParseResult {
895               result = AttrType::parse(*this, type);
896               if (!result)
897                 return failure();
898               return success();
899             }))
900       return failure();
901 
902     // Check for the right kind of attribute.
903     result = attr.dyn_cast<AttrType>();
904     if (!result)
905       return emitError(loc, "invalid kind of attribute specified");
906 
907     attrs.append(attrName, result);
908     return success();
909   }
910 
911   /// SFINAE parsing method for Attribute that don't implement a parse method.
912   template <typename AttrType>
913   std::enable_if_t<!detect_has_parse_method<AttrType>::value, ParseResult>
parseCustomAttributeWithFallback(AttrType & result,Type type,StringRef attrName,NamedAttrList & attrs)914   parseCustomAttributeWithFallback(AttrType &result, Type type,
915                                    StringRef attrName, NamedAttrList &attrs) {
916     return parseAttribute(result, type, attrName, attrs);
917   }
918 
919   /// Parse a custom attribute of a given type unless the next token is `#`, in
920   /// which case the generic parser is invoked. The parsed attribute is
921   /// populated in `result`.
922   template <typename AttrType>
923   std::enable_if_t<detect_has_parse_method<AttrType>::value, ParseResult>
parseCustomAttributeWithFallback(AttrType & result)924   parseCustomAttributeWithFallback(AttrType &result) {
925     SMLoc loc = getCurrentLocation();
926 
927     // Parse any kind of attribute.
928     Attribute attr;
929     if (parseCustomAttributeWithFallback(
930             attr, {}, [&](Attribute &result, Type type) -> ParseResult {
931               result = AttrType::parse(*this, type);
932               return success(!!result);
933             }))
934       return failure();
935 
936     // Check for the right kind of attribute.
937     result = attr.dyn_cast<AttrType>();
938     if (!result)
939       return emitError(loc, "invalid kind of attribute specified");
940     return success();
941   }
942 
943   /// SFINAE parsing method for Attribute that don't implement a parse method.
944   template <typename AttrType>
945   std::enable_if_t<!detect_has_parse_method<AttrType>::value, ParseResult>
parseCustomAttributeWithFallback(AttrType & result)946   parseCustomAttributeWithFallback(AttrType &result) {
947     return parseAttribute(result);
948   }
949 
950   /// Parse an arbitrary optional attribute of a given type and return it in
951   /// result.
952   virtual OptionalParseResult parseOptionalAttribute(Attribute &result,
953                                                      Type type = {}) = 0;
954 
955   /// Parse an optional array attribute and return it in result.
956   virtual OptionalParseResult parseOptionalAttribute(ArrayAttr &result,
957                                                      Type type = {}) = 0;
958 
959   /// Parse an optional string attribute and return it in result.
960   virtual OptionalParseResult parseOptionalAttribute(StringAttr &result,
961                                                      Type type = {}) = 0;
962 
963   /// Parse an optional attribute of a specific type and add it to the list with
964   /// the specified name.
965   template <typename AttrType>
parseOptionalAttribute(AttrType & result,StringRef attrName,NamedAttrList & attrs)966   OptionalParseResult parseOptionalAttribute(AttrType &result,
967                                              StringRef attrName,
968                                              NamedAttrList &attrs) {
969     return parseOptionalAttribute(result, Type(), attrName, attrs);
970   }
971 
972   /// Parse an optional attribute of a specific type and add it to the list with
973   /// the specified name.
974   template <typename AttrType>
parseOptionalAttribute(AttrType & result,Type type,StringRef attrName,NamedAttrList & attrs)975   OptionalParseResult parseOptionalAttribute(AttrType &result, Type type,
976                                              StringRef attrName,
977                                              NamedAttrList &attrs) {
978     OptionalParseResult parseResult = parseOptionalAttribute(result, type);
979     if (parseResult.hasValue() && succeeded(*parseResult))
980       attrs.append(attrName, result);
981     return parseResult;
982   }
983 
984   /// Parse a named dictionary into 'result' if it is present.
985   virtual ParseResult parseOptionalAttrDict(NamedAttrList &result) = 0;
986 
987   /// Parse a named dictionary into 'result' if the `attributes` keyword is
988   /// present.
989   virtual ParseResult
990   parseOptionalAttrDictWithKeyword(NamedAttrList &result) = 0;
991 
992   /// Parse an affine map instance into 'map'.
993   virtual ParseResult parseAffineMap(AffineMap &map) = 0;
994 
995   /// Parse an integer set instance into 'set'.
996   virtual ParseResult printIntegerSet(IntegerSet &set) = 0;
997 
998   //===--------------------------------------------------------------------===//
999   // Identifier Parsing
1000   //===--------------------------------------------------------------------===//
1001 
1002   /// Parse an @-identifier and store it (without the '@' symbol) in a string
1003   /// attribute named 'attrName'.
parseSymbolName(StringAttr & result,StringRef attrName,NamedAttrList & attrs)1004   ParseResult parseSymbolName(StringAttr &result, StringRef attrName,
1005                               NamedAttrList &attrs) {
1006     if (failed(parseOptionalSymbolName(result, attrName, attrs)))
1007       return emitError(getCurrentLocation())
1008              << "expected valid '@'-identifier for symbol name";
1009     return success();
1010   }
1011 
1012   /// Parse an optional @-identifier and store it (without the '@' symbol) in a
1013   /// string attribute named 'attrName'.
1014   virtual ParseResult parseOptionalSymbolName(StringAttr &result,
1015                                               StringRef attrName,
1016                                               NamedAttrList &attrs) = 0;
1017 
1018   //===--------------------------------------------------------------------===//
1019   // Resource Parsing
1020   //===--------------------------------------------------------------------===//
1021 
1022   /// Parse a handle to a resource within the assembly format.
1023   template <typename ResourceT>
parseResourceHandle()1024   FailureOr<ResourceT> parseResourceHandle() {
1025     SMLoc handleLoc = getCurrentLocation();
1026     FailureOr<AsmDialectResourceHandle> handle = parseResourceHandle(
1027         getContext()->getOrLoadDialect<typename ResourceT::Dialect>());
1028     if (failed(handle))
1029       return failure();
1030     if (auto *result = dyn_cast<ResourceT>(&*handle))
1031       return std::move(*result);
1032     return emitError(handleLoc) << "provided resource handle differs from the "
1033                                    "expected resource type";
1034   }
1035 
1036   //===--------------------------------------------------------------------===//
1037   // Type Parsing
1038   //===--------------------------------------------------------------------===//
1039 
1040   /// Parse a type.
1041   virtual ParseResult parseType(Type &result) = 0;
1042 
1043   /// Parse a custom type with the provided callback, unless the next
1044   /// token is `#`, in which case the generic parser is invoked.
1045   virtual ParseResult parseCustomTypeWithFallback(
1046       Type &result, function_ref<ParseResult(Type &result)> parseType) = 0;
1047 
1048   /// Parse an optional type.
1049   virtual OptionalParseResult parseOptionalType(Type &result) = 0;
1050 
1051   /// Parse a type of a specific type.
1052   template <typename TypeT>
parseType(TypeT & result)1053   ParseResult parseType(TypeT &result) {
1054     SMLoc loc = getCurrentLocation();
1055 
1056     // Parse any kind of type.
1057     Type type;
1058     if (parseType(type))
1059       return failure();
1060 
1061     // Check for the right kind of type.
1062     result = type.dyn_cast<TypeT>();
1063     if (!result)
1064       return emitError(loc, "invalid kind of type specified");
1065 
1066     return success();
1067   }
1068 
1069   /// Trait to check if `TypeT` provides a `parse` method.
1070   template <typename TypeT>
1071   using type_has_parse_method =
1072       decltype(TypeT::parse(std::declval<AsmParser &>()));
1073   template <typename TypeT>
1074   using detect_type_has_parse_method =
1075       llvm::is_detected<type_has_parse_method, TypeT>;
1076 
1077   /// Parse a custom Type of a given type unless the next token is `#`, in
1078   /// which case the generic parser is invoked. The parsed Type is
1079   /// populated in `result`.
1080   template <typename TypeT>
1081   std::enable_if_t<detect_type_has_parse_method<TypeT>::value, ParseResult>
parseCustomTypeWithFallback(TypeT & result)1082   parseCustomTypeWithFallback(TypeT &result) {
1083     SMLoc loc = getCurrentLocation();
1084 
1085     // Parse any kind of Type.
1086     Type type;
1087     if (parseCustomTypeWithFallback(type, [&](Type &result) -> ParseResult {
1088           result = TypeT::parse(*this);
1089           return success(!!result);
1090         }))
1091       return failure();
1092 
1093     // Check for the right kind of Type.
1094     result = type.dyn_cast<TypeT>();
1095     if (!result)
1096       return emitError(loc, "invalid kind of Type specified");
1097     return success();
1098   }
1099 
1100   /// SFINAE parsing method for Type that don't implement a parse method.
1101   template <typename TypeT>
1102   std::enable_if_t<!detect_type_has_parse_method<TypeT>::value, ParseResult>
parseCustomTypeWithFallback(TypeT & result)1103   parseCustomTypeWithFallback(TypeT &result) {
1104     return parseType(result);
1105   }
1106 
1107   /// Parse a type list.
parseTypeList(SmallVectorImpl<Type> & result)1108   ParseResult parseTypeList(SmallVectorImpl<Type> &result) {
1109     return parseCommaSeparatedList(
1110         [&]() { return parseType(result.emplace_back()); });
1111   }
1112 
1113   /// Parse an arrow followed by a type list.
1114   virtual ParseResult parseArrowTypeList(SmallVectorImpl<Type> &result) = 0;
1115 
1116   /// Parse an optional arrow followed by a type list.
1117   virtual ParseResult
1118   parseOptionalArrowTypeList(SmallVectorImpl<Type> &result) = 0;
1119 
1120   /// Parse a colon followed by a type.
1121   virtual ParseResult parseColonType(Type &result) = 0;
1122 
1123   /// Parse a colon followed by a type of a specific kind, e.g. a FunctionType.
1124   template <typename TypeType>
parseColonType(TypeType & result)1125   ParseResult parseColonType(TypeType &result) {
1126     SMLoc loc = getCurrentLocation();
1127 
1128     // Parse any kind of type.
1129     Type type;
1130     if (parseColonType(type))
1131       return failure();
1132 
1133     // Check for the right kind of type.
1134     result = type.dyn_cast<TypeType>();
1135     if (!result)
1136       return emitError(loc, "invalid kind of type specified");
1137 
1138     return success();
1139   }
1140 
1141   /// Parse a colon followed by a type list, which must have at least one type.
1142   virtual ParseResult parseColonTypeList(SmallVectorImpl<Type> &result) = 0;
1143 
1144   /// Parse an optional colon followed by a type list, which if present must
1145   /// have at least one type.
1146   virtual ParseResult
1147   parseOptionalColonTypeList(SmallVectorImpl<Type> &result) = 0;
1148 
1149   /// Parse a keyword followed by a type.
parseKeywordType(const char * keyword,Type & result)1150   ParseResult parseKeywordType(const char *keyword, Type &result) {
1151     return failure(parseKeyword(keyword) || parseType(result));
1152   }
1153 
1154   /// Add the specified type to the end of the specified type list and return
1155   /// success.  This is a helper designed to allow parse methods to be simple
1156   /// and chain through || operators.
addTypeToList(Type type,SmallVectorImpl<Type> & result)1157   ParseResult addTypeToList(Type type, SmallVectorImpl<Type> &result) {
1158     result.push_back(type);
1159     return success();
1160   }
1161 
1162   /// Add the specified types to the end of the specified type list and return
1163   /// success.  This is a helper designed to allow parse methods to be simple
1164   /// and chain through || operators.
addTypesToList(ArrayRef<Type> types,SmallVectorImpl<Type> & result)1165   ParseResult addTypesToList(ArrayRef<Type> types,
1166                              SmallVectorImpl<Type> &result) {
1167     result.append(types.begin(), types.end());
1168     return success();
1169   }
1170 
1171   /// Parse a dimension list of a tensor or memref type.  This populates the
1172   /// dimension list, using -1 for the `?` dimensions if `allowDynamic` is set
1173   /// and errors out on `?` otherwise. Parsing the trailing `x` is configurable.
1174   ///
1175   ///   dimension-list ::= eps | dimension (`x` dimension)*
1176   ///   dimension-list-with-trailing-x ::= (dimension `x`)*
1177   ///   dimension ::= `?` | decimal-literal
1178   ///
1179   /// When `allowDynamic` is not set, this is used to parse:
1180   ///
1181   ///   static-dimension-list ::= eps | decimal-literal (`x` decimal-literal)*
1182   ///   static-dimension-list-with-trailing-x ::= (dimension `x`)*
1183   virtual ParseResult parseDimensionList(SmallVectorImpl<int64_t> &dimensions,
1184                                          bool allowDynamic = true,
1185                                          bool withTrailingX = true) = 0;
1186 
1187   /// Parse an 'x' token in a dimension list, handling the case where the x is
1188   /// juxtaposed with an element type, as in "xf32", leaving the "f32" as the
1189   /// next token.
1190   virtual ParseResult parseXInDimensionList() = 0;
1191 
1192 protected:
1193   /// Parse a handle to a resource within the assembly format for the given
1194   /// dialect.
1195   virtual FailureOr<AsmDialectResourceHandle>
1196   parseResourceHandle(Dialect *dialect) = 0;
1197 
1198   //===--------------------------------------------------------------------===//
1199   // Code Completion
1200   //===--------------------------------------------------------------------===//
1201 
1202   /// Parse a keyword, or an empty string if the current location signals a code
1203   /// completion.
1204   virtual ParseResult parseKeywordOrCompletion(StringRef *keyword) = 0;
1205 
1206   /// Signal the code completion of a set of expected tokens.
1207   virtual void codeCompleteExpectedTokens(ArrayRef<StringRef> tokens) = 0;
1208 
1209 private:
1210   AsmParser(const AsmParser &) = delete;
1211   void operator=(const AsmParser &) = delete;
1212 };
1213 
1214 //===----------------------------------------------------------------------===//
1215 // OpAsmParser
1216 //===----------------------------------------------------------------------===//
1217 
1218 /// The OpAsmParser has methods for interacting with the asm parser: parsing
1219 /// things from it, emitting errors etc.  It has an intentionally high-level API
1220 /// that is designed to reduce/constrain syntax innovation in individual
1221 /// operations.
1222 ///
1223 /// For example, consider an op like this:
1224 ///
1225 ///    %x = load %p[%1, %2] : memref<...>
1226 ///
1227 /// The "%x = load" tokens are already parsed and therefore invisible to the
1228 /// custom op parser.  This can be supported by calling `parseOperandList` to
1229 /// parse the %p, then calling `parseOperandList` with a `SquareDelimiter` to
1230 /// parse the indices, then calling `parseColonTypeList` to parse the result
1231 /// type.
1232 ///
1233 class OpAsmParser : public AsmParser {
1234 public:
1235   using AsmParser::AsmParser;
1236   ~OpAsmParser() override;
1237 
1238   /// Parse a loc(...) specifier if present, filling in result if so.
1239   /// Location for BlockArgument and Operation may be deferred with an alias, in
1240   /// which case an OpaqueLoc is set and will be resolved when parsing
1241   /// completes.
1242   virtual ParseResult
1243   parseOptionalLocationSpecifier(Optional<Location> &result) = 0;
1244 
1245   /// Return the name of the specified result in the specified syntax, as well
1246   /// as the sub-element in the name.  It returns an empty string and ~0U for
1247   /// invalid result numbers.  For example, in this operation:
1248   ///
1249   ///  %x, %y:2, %z = foo.op
1250   ///
1251   ///    getResultName(0) == {"x", 0 }
1252   ///    getResultName(1) == {"y", 0 }
1253   ///    getResultName(2) == {"y", 1 }
1254   ///    getResultName(3) == {"z", 0 }
1255   ///    getResultName(4) == {"", ~0U }
1256   virtual std::pair<StringRef, unsigned>
1257   getResultName(unsigned resultNo) const = 0;
1258 
1259   /// Return the number of declared SSA results.  This returns 4 for the foo.op
1260   /// example in the comment for `getResultName`.
1261   virtual size_t getNumResults() const = 0;
1262 
1263   // These methods emit an error and return failure or success. This allows
1264   // these to be chained together into a linear sequence of || expressions in
1265   // many cases.
1266 
1267   /// Parse an operation in its generic form.
1268   /// The parsed operation is parsed in the current context and inserted in the
1269   /// provided block and insertion point. The results produced by this operation
1270   /// aren't mapped to any named value in the parser. Returns nullptr on
1271   /// failure.
1272   virtual Operation *parseGenericOperation(Block *insertBlock,
1273                                            Block::iterator insertPt) = 0;
1274 
1275   /// Parse the name of an operation, in the custom form. On success, return a
1276   /// an object of type 'OperationName'. Otherwise, failure is returned.
1277   virtual FailureOr<OperationName> parseCustomOperationName() = 0;
1278 
1279   //===--------------------------------------------------------------------===//
1280   // Operand Parsing
1281   //===--------------------------------------------------------------------===//
1282 
1283   /// This is the representation of an operand reference.
1284   struct UnresolvedOperand {
1285     SMLoc location;  // Location of the token.
1286     StringRef name;  // Value name, e.g. %42 or %abc
1287     unsigned number; // Number, e.g. 12 for an operand like %xyz#12
1288   };
1289 
1290   /// Parse different components, viz., use-info of operand(s), successor(s),
1291   /// region(s), attribute(s) and function-type, of the generic form of an
1292   /// operation instance and populate the input operation-state 'result' with
1293   /// those components. If any of the components is explicitly provided, then
1294   /// skip parsing that component.
1295   virtual ParseResult parseGenericOperationAfterOpName(
1296       OperationState &result,
1297       Optional<ArrayRef<UnresolvedOperand>> parsedOperandType = llvm::None,
1298       Optional<ArrayRef<Block *>> parsedSuccessors = llvm::None,
1299       Optional<MutableArrayRef<std::unique_ptr<Region>>> parsedRegions =
1300           llvm::None,
1301       Optional<ArrayRef<NamedAttribute>> parsedAttributes = llvm::None,
1302       Optional<FunctionType> parsedFnType = llvm::None) = 0;
1303 
1304   /// Parse a single SSA value operand name along with a result number if
1305   /// `allowResultNumber` is true.
1306   virtual ParseResult parseOperand(UnresolvedOperand &result,
1307                                    bool allowResultNumber = true) = 0;
1308 
1309   /// Parse a single operand if present.
1310   virtual OptionalParseResult
1311   parseOptionalOperand(UnresolvedOperand &result,
1312                        bool allowResultNumber = true) = 0;
1313 
1314   /// Parse zero or more SSA comma-separated operand references with a specified
1315   /// surrounding delimiter, and an optional required operand count.
1316   virtual ParseResult
1317   parseOperandList(SmallVectorImpl<UnresolvedOperand> &result,
1318                    Delimiter delimiter = Delimiter::None,
1319                    bool allowResultNumber = true,
1320                    int requiredOperandCount = -1) = 0;
1321 
1322   /// Parse a specified number of comma separated operands.
1323   ParseResult parseOperandList(SmallVectorImpl<UnresolvedOperand> &result,
1324                                int requiredOperandCount,
1325                                Delimiter delimiter = Delimiter::None) {
1326     return parseOperandList(result, delimiter,
1327                             /*allowResultNumber=*/true, requiredOperandCount);
1328   }
1329 
1330   /// Parse zero or more trailing SSA comma-separated trailing operand
1331   /// references with a specified surrounding delimiter, and an optional
1332   /// required operand count. A leading comma is expected before the
1333   /// operands.
1334   ParseResult
1335   parseTrailingOperandList(SmallVectorImpl<UnresolvedOperand> &result,
1336                            Delimiter delimiter = Delimiter::None) {
1337     if (failed(parseOptionalComma()))
1338       return success(); // The comma is optional.
1339     return parseOperandList(result, delimiter);
1340   }
1341 
1342   /// Resolve an operand to an SSA value, emitting an error on failure.
1343   virtual ParseResult resolveOperand(const UnresolvedOperand &operand,
1344                                      Type type,
1345                                      SmallVectorImpl<Value> &result) = 0;
1346 
1347   /// Resolve a list of operands to SSA values, emitting an error on failure, or
1348   /// appending the results to the list on success. This method should be used
1349   /// when all operands have the same type.
resolveOperands(ArrayRef<UnresolvedOperand> operands,Type type,SmallVectorImpl<Value> & result)1350   ParseResult resolveOperands(ArrayRef<UnresolvedOperand> operands, Type type,
1351                               SmallVectorImpl<Value> &result) {
1352     for (auto elt : operands)
1353       if (resolveOperand(elt, type, result))
1354         return failure();
1355     return success();
1356   }
1357 
1358   /// Resolve a list of operands and a list of operand types to SSA values,
1359   /// emitting an error and returning failure, or appending the results
1360   /// to the list on success.
resolveOperands(ArrayRef<UnresolvedOperand> operands,ArrayRef<Type> types,SMLoc loc,SmallVectorImpl<Value> & result)1361   ParseResult resolveOperands(ArrayRef<UnresolvedOperand> operands,
1362                               ArrayRef<Type> types, SMLoc loc,
1363                               SmallVectorImpl<Value> &result) {
1364     if (operands.size() != types.size())
1365       return emitError(loc)
1366              << operands.size() << " operands present, but expected "
1367              << types.size();
1368 
1369     for (unsigned i = 0, e = operands.size(); i != e; ++i)
1370       if (resolveOperand(operands[i], types[i], result))
1371         return failure();
1372     return success();
1373   }
1374   template <typename Operands>
resolveOperands(Operands && operands,Type type,SMLoc loc,SmallVectorImpl<Value> & result)1375   ParseResult resolveOperands(Operands &&operands, Type type, SMLoc loc,
1376                               SmallVectorImpl<Value> &result) {
1377     return resolveOperands(std::forward<Operands>(operands),
1378                            ArrayRef<Type>(type), loc, result);
1379   }
1380   template <typename Operands, typename Types>
1381   std::enable_if_t<!std::is_convertible<Types, Type>::value, ParseResult>
resolveOperands(Operands && operands,Types && types,SMLoc loc,SmallVectorImpl<Value> & result)1382   resolveOperands(Operands &&operands, Types &&types, SMLoc loc,
1383                   SmallVectorImpl<Value> &result) {
1384     size_t operandSize = std::distance(operands.begin(), operands.end());
1385     size_t typeSize = std::distance(types.begin(), types.end());
1386     if (operandSize != typeSize)
1387       return emitError(loc)
1388              << operandSize << " operands present, but expected " << typeSize;
1389 
1390     for (auto it : llvm::zip(operands, types))
1391       if (resolveOperand(std::get<0>(it), std::get<1>(it), result))
1392         return failure();
1393     return success();
1394   }
1395 
1396   /// Parses an affine map attribute where dims and symbols are SSA operands.
1397   /// Operand values must come from single-result sources, and be valid
1398   /// dimensions/symbol identifiers according to mlir::isValidDim/Symbol.
1399   virtual ParseResult
1400   parseAffineMapOfSSAIds(SmallVectorImpl<UnresolvedOperand> &operands,
1401                          Attribute &map, StringRef attrName,
1402                          NamedAttrList &attrs,
1403                          Delimiter delimiter = Delimiter::Square) = 0;
1404 
1405   /// Parses an affine expression where dims and symbols are SSA operands.
1406   /// Operand values must come from single-result sources, and be valid
1407   /// dimensions/symbol identifiers according to mlir::isValidDim/Symbol.
1408   virtual ParseResult
1409   parseAffineExprOfSSAIds(SmallVectorImpl<UnresolvedOperand> &dimOperands,
1410                           SmallVectorImpl<UnresolvedOperand> &symbOperands,
1411                           AffineExpr &expr) = 0;
1412 
1413   //===--------------------------------------------------------------------===//
1414   // Argument Parsing
1415   //===--------------------------------------------------------------------===//
1416 
1417   struct Argument {
1418     UnresolvedOperand ssaName;    // SourceLoc, SSA name, result #.
1419     Type type;                    // Type.
1420     DictionaryAttr attrs;         // Attributes if present.
1421     Optional<Location> sourceLoc; // Source location specifier if present.
1422   };
1423 
1424   /// Parse a single argument with the following syntax:
1425   ///
1426   ///   `%ssaName : !type { optionalAttrDict} loc(optionalSourceLoc)`
1427   ///
1428   /// If `allowType` is false or `allowAttrs` are false then the respective
1429   /// parts of the grammar are not parsed.
1430   virtual ParseResult parseArgument(Argument &result, bool allowType = false,
1431                                     bool allowAttrs = false) = 0;
1432 
1433   /// Parse a single argument if present.
1434   virtual OptionalParseResult
1435   parseOptionalArgument(Argument &result, bool allowType = false,
1436                         bool allowAttrs = false) = 0;
1437 
1438   /// Parse zero or more arguments with a specified surrounding delimiter.
1439   virtual ParseResult parseArgumentList(SmallVectorImpl<Argument> &result,
1440                                         Delimiter delimiter = Delimiter::None,
1441                                         bool allowType = false,
1442                                         bool allowAttrs = false) = 0;
1443 
1444   //===--------------------------------------------------------------------===//
1445   // Region Parsing
1446   //===--------------------------------------------------------------------===//
1447 
1448   /// Parses a region. Any parsed blocks are appended to 'region' and must be
1449   /// moved to the op regions after the op is created. The first block of the
1450   /// region takes 'arguments'.
1451   ///
1452   /// If 'enableNameShadowing' is set to true, the argument names are allowed to
1453   /// shadow the names of other existing SSA values defined above the region
1454   /// scope. 'enableNameShadowing' can only be set to true for regions attached
1455   /// to operations that are 'IsolatedFromAbove'.
1456   virtual ParseResult parseRegion(Region &region,
1457                                   ArrayRef<Argument> arguments = {},
1458                                   bool enableNameShadowing = false) = 0;
1459 
1460   /// Parses a region if present.
1461   virtual OptionalParseResult
1462   parseOptionalRegion(Region &region, ArrayRef<Argument> arguments = {},
1463                       bool enableNameShadowing = false) = 0;
1464 
1465   /// Parses a region if present. If the region is present, a new region is
1466   /// allocated and placed in `region`. If no region is present or on failure,
1467   /// `region` remains untouched.
1468   virtual OptionalParseResult
1469   parseOptionalRegion(std::unique_ptr<Region> &region,
1470                       ArrayRef<Argument> arguments = {},
1471                       bool enableNameShadowing = false) = 0;
1472 
1473   //===--------------------------------------------------------------------===//
1474   // Successor Parsing
1475   //===--------------------------------------------------------------------===//
1476 
1477   /// Parse a single operation successor.
1478   virtual ParseResult parseSuccessor(Block *&dest) = 0;
1479 
1480   /// Parse an optional operation successor.
1481   virtual OptionalParseResult parseOptionalSuccessor(Block *&dest) = 0;
1482 
1483   /// Parse a single operation successor and its operand list.
1484   virtual ParseResult
1485   parseSuccessorAndUseList(Block *&dest, SmallVectorImpl<Value> &operands) = 0;
1486 
1487   //===--------------------------------------------------------------------===//
1488   // Type Parsing
1489   //===--------------------------------------------------------------------===//
1490 
1491   /// Parse a list of assignments of the form
1492   ///   (%x1 = %y1, %x2 = %y2, ...)
parseAssignmentList(SmallVectorImpl<Argument> & lhs,SmallVectorImpl<UnresolvedOperand> & rhs)1493   ParseResult parseAssignmentList(SmallVectorImpl<Argument> &lhs,
1494                                   SmallVectorImpl<UnresolvedOperand> &rhs) {
1495     OptionalParseResult result = parseOptionalAssignmentList(lhs, rhs);
1496     if (!result.hasValue())
1497       return emitError(getCurrentLocation(), "expected '('");
1498     return result.getValue();
1499   }
1500 
1501   virtual OptionalParseResult
1502   parseOptionalAssignmentList(SmallVectorImpl<Argument> &lhs,
1503                               SmallVectorImpl<UnresolvedOperand> &rhs) = 0;
1504 };
1505 
1506 //===--------------------------------------------------------------------===//
1507 // Dialect OpAsm interface.
1508 //===--------------------------------------------------------------------===//
1509 
1510 /// A functor used to set the name of the start of a result group of an
1511 /// operation. See 'getAsmResultNames' below for more details.
1512 using OpAsmSetValueNameFn = function_ref<void(Value, StringRef)>;
1513 
1514 /// A functor used to set the name of blocks in regions directly nested under
1515 /// an operation.
1516 using OpAsmSetBlockNameFn = function_ref<void(Block *, StringRef)>;
1517 
1518 class OpAsmDialectInterface
1519     : public DialectInterface::Base<OpAsmDialectInterface> {
1520 public:
OpAsmDialectInterface(Dialect * dialect)1521   OpAsmDialectInterface(Dialect *dialect) : Base(dialect) {}
1522 
1523   //===------------------------------------------------------------------===//
1524   // Aliases
1525   //===------------------------------------------------------------------===//
1526 
1527   /// Holds the result of `getAlias` hook call.
1528   enum class AliasResult {
1529     /// The object (type or attribute) is not supported by the hook
1530     /// and an alias was not provided.
1531     NoAlias,
1532     /// An alias was provided, but it might be overriden by other hook.
1533     OverridableAlias,
1534     /// An alias was provided and it should be used
1535     /// (no other hooks will be checked).
1536     FinalAlias
1537   };
1538 
1539   /// Hooks for getting an alias identifier alias for a given symbol, that is
1540   /// not necessarily a part of this dialect. The identifier is used in place of
1541   /// the symbol when printing textual IR. These aliases must not contain `.` or
1542   /// end with a numeric digit([0-9]+).
getAlias(Attribute attr,raw_ostream & os)1543   virtual AliasResult getAlias(Attribute attr, raw_ostream &os) const {
1544     return AliasResult::NoAlias;
1545   }
getAlias(Type type,raw_ostream & os)1546   virtual AliasResult getAlias(Type type, raw_ostream &os) const {
1547     return AliasResult::NoAlias;
1548   }
1549 
1550   //===--------------------------------------------------------------------===//
1551   // Resources
1552   //===--------------------------------------------------------------------===//
1553 
1554   /// Declare a resource with the given key, returning a handle to use for any
1555   /// references of this resource key within the IR during parsing. The result
1556   /// of `getResourceKey` on the returned handle is permitted to be different
1557   /// than `key`.
1558   virtual FailureOr<AsmDialectResourceHandle>
declareResource(StringRef key)1559   declareResource(StringRef key) const {
1560     return failure();
1561   }
1562 
1563   /// Return a key to use for the given resource. This key should uniquely
1564   /// identify this resource within the dialect.
1565   virtual std::string
getResourceKey(const AsmDialectResourceHandle & handle)1566   getResourceKey(const AsmDialectResourceHandle &handle) const {
1567     llvm_unreachable(
1568         "Dialect must implement `getResourceKey` when defining resources");
1569   }
1570 
1571   /// Hook for parsing resource entries. Returns failure if the entry was not
1572   /// valid, or could otherwise not be processed correctly. Any necessary errors
1573   /// can be emitted via the provided entry.
1574   virtual LogicalResult parseResource(AsmParsedResourceEntry &entry) const;
1575 
1576   /// Hook for building resources to use during printing. The given `op` may be
1577   /// inspected to help determine what information to include.
1578   /// `referencedResources` contains all of the resources detected when printing
1579   /// 'op'.
1580   virtual void
buildResources(Operation * op,const SetVector<AsmDialectResourceHandle> & referencedResources,AsmResourceBuilder & builder)1581   buildResources(Operation *op,
1582                  const SetVector<AsmDialectResourceHandle> &referencedResources,
1583                  AsmResourceBuilder &builder) const {}
1584 };
1585 } // namespace mlir
1586 
1587 //===--------------------------------------------------------------------===//
1588 // Operation OpAsm interface.
1589 //===--------------------------------------------------------------------===//
1590 
1591 /// The OpAsmOpInterface, see OpAsmInterface.td for more details.
1592 #include "mlir/IR/OpAsmInterface.h.inc"
1593 
1594 namespace llvm {
1595 template <>
1596 struct DenseMapInfo<mlir::AsmDialectResourceHandle> {
1597   static inline mlir::AsmDialectResourceHandle getEmptyKey() {
1598     return {DenseMapInfo<void *>::getEmptyKey(),
1599             DenseMapInfo<mlir::TypeID>::getEmptyKey(), nullptr};
1600   }
1601   static inline mlir::AsmDialectResourceHandle getTombstoneKey() {
1602     return {DenseMapInfo<void *>::getTombstoneKey(),
1603             DenseMapInfo<mlir::TypeID>::getTombstoneKey(), nullptr};
1604   }
1605   static unsigned getHashValue(const mlir::AsmDialectResourceHandle &handle) {
1606     return DenseMapInfo<void *>::getHashValue(handle.getResource());
1607   }
1608   static bool isEqual(const mlir::AsmDialectResourceHandle &lhs,
1609                       const mlir::AsmDialectResourceHandle &rhs) {
1610     return lhs.getResource() == rhs.getResource();
1611   }
1612 };
1613 } // namespace llvm
1614 
1615 #endif
1616