1 //===- OpImplementation.h - Classes for implementing Op types ---*- C++ -*-===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 //
// This file defines the classes used by the implementation details of Op types.
10 //
11 //===----------------------------------------------------------------------===//
12
13 #ifndef MLIR_IR_OPIMPLEMENTATION_H
14 #define MLIR_IR_OPIMPLEMENTATION_H
15
16 #include "mlir/IR/BuiltinTypes.h"
17 #include "mlir/IR/DialectInterface.h"
18 #include "mlir/IR/OpDefinition.h"
19 #include "llvm/ADT/Twine.h"
20 #include "llvm/Support/SMLoc.h"
21
22 namespace mlir {
23 class AsmParsedResourceEntry;
24 class AsmResourceBuilder;
25 class Builder;
26
27 //===----------------------------------------------------------------------===//
28 // AsmDialectResourceHandle
29 //===----------------------------------------------------------------------===//
30
31 /// This class represents an opaque handle to a dialect resource entry.
32 class AsmDialectResourceHandle {
33 public:
34 AsmDialectResourceHandle() = default;
AsmDialectResourceHandle(void * resource,TypeID resourceID,Dialect * dialect)35 AsmDialectResourceHandle(void *resource, TypeID resourceID, Dialect *dialect)
36 : resource(resource), opaqueID(resourceID), dialect(dialect) {}
37 bool operator==(const AsmDialectResourceHandle &other) const {
38 return resource == other.resource;
39 }
40
41 /// Return an opaque pointer to the referenced resource.
getResource()42 void *getResource() const { return resource; }
43
44 /// Return the type ID of the resource.
getTypeID()45 TypeID getTypeID() const { return opaqueID; }
46
47 /// Return the dialect that owns the resource.
getDialect()48 Dialect *getDialect() const { return dialect; }
49
50 private:
51 /// The opaque handle to the dialect resource.
52 void *resource = nullptr;
53 /// The type of the resource referenced.
54 TypeID opaqueID;
55 /// The dialect owning the given resource.
56 Dialect *dialect;
57 };
58
59 /// This class represents a CRTP base class for dialect resource handles. It
60 /// abstracts away various utilities necessary for defined derived resource
61 /// handles.
62 template <typename DerivedT, typename ResourceT, typename DialectT>
63 class AsmDialectResourceHandleBase : public AsmDialectResourceHandle {
64 public:
65 using Dialect = DialectT;
66
67 /// Construct a handle from a pointer to the resource. The given pointer
68 /// should be guaranteed to live beyond the life of this handle.
AsmDialectResourceHandleBase(ResourceT * resource,DialectT * dialect)69 AsmDialectResourceHandleBase(ResourceT *resource, DialectT *dialect)
70 : AsmDialectResourceHandle(resource, TypeID::get<DerivedT>(), dialect) {}
AsmDialectResourceHandleBase(AsmDialectResourceHandle handle)71 AsmDialectResourceHandleBase(AsmDialectResourceHandle handle)
72 : AsmDialectResourceHandle(handle) {
73 assert(handle.getTypeID() == TypeID::get<DerivedT>());
74 }
75
76 /// Return the resource referenced by this handle.
getResource()77 ResourceT *getResource() {
78 return static_cast<ResourceT *>(AsmDialectResourceHandle::getResource());
79 }
getResource()80 const ResourceT *getResource() const {
81 return const_cast<AsmDialectResourceHandleBase *>(this)->getResource();
82 }
83
84 /// Return the dialect that owns the resource.
getDialect()85 DialectT *getDialect() const {
86 return static_cast<DialectT *>(AsmDialectResourceHandle::getDialect());
87 }
88
89 /// Support llvm style casting.
classof(const AsmDialectResourceHandle * handle)90 static bool classof(const AsmDialectResourceHandle *handle) {
91 return handle->getTypeID() == TypeID::get<DerivedT>();
92 }
93 };
94
hash_value(const AsmDialectResourceHandle & param)95 inline llvm::hash_code hash_value(const AsmDialectResourceHandle ¶m) {
96 return llvm::hash_value(param.getResource());
97 }
98
99 //===----------------------------------------------------------------------===//
100 // AsmPrinter
101 //===----------------------------------------------------------------------===//
102
103 /// This base class exposes generic asm printer hooks, usable across the various
104 /// derived printers.
105 class AsmPrinter {
106 public:
107 /// This class contains the internal default implementation of the base
108 /// printer methods.
109 class Impl;
110
111 /// Initialize the printer with the given internal implementation.
AsmPrinter(Impl & impl)112 AsmPrinter(Impl &impl) : impl(&impl) {}
113 virtual ~AsmPrinter();
114
115 /// Return the raw output stream used by this printer.
116 virtual raw_ostream &getStream() const;
117
118 /// Print the given floating point value in a stabilized form that can be
119 /// roundtripped through the IR. This is the companion to the 'parseFloat'
120 /// hook on the AsmParser.
121 virtual void printFloat(const APFloat &value);
122
123 virtual void printType(Type type);
124 virtual void printAttribute(Attribute attr);
125
126 /// Trait to check if `AttrType` provides a `print` method.
127 template <typename AttrOrType>
128 using has_print_method =
129 decltype(std::declval<AttrOrType>().print(std::declval<AsmPrinter &>()));
130 template <typename AttrOrType>
131 using detect_has_print_method =
132 llvm::is_detected<has_print_method, AttrOrType>;
133
134 /// Print the provided attribute in the context of an operation custom
135 /// printer/parser: this will invoke directly the print method on the
136 /// attribute class and skip the `#dialect.mnemonic` prefix in most cases.
137 template <typename AttrOrType,
138 std::enable_if_t<detect_has_print_method<AttrOrType>::value>
139 *sfinae = nullptr>
printStrippedAttrOrType(AttrOrType attrOrType)140 void printStrippedAttrOrType(AttrOrType attrOrType) {
141 if (succeeded(printAlias(attrOrType)))
142 return;
143 attrOrType.print(*this);
144 }
145
146 /// Print the provided array of attributes or types in the context of an
147 /// operation custom printer/parser: this will invoke directly the print
148 /// method on the attribute class and skip the `#dialect.mnemonic` prefix in
149 /// most cases.
150 template <typename AttrOrType,
151 std::enable_if_t<detect_has_print_method<AttrOrType>::value>
152 *sfinae = nullptr>
printStrippedAttrOrType(ArrayRef<AttrOrType> attrOrTypes)153 void printStrippedAttrOrType(ArrayRef<AttrOrType> attrOrTypes) {
154 llvm::interleaveComma(
155 attrOrTypes, getStream(),
156 [this](AttrOrType attrOrType) { printStrippedAttrOrType(attrOrType); });
157 }
158
159 /// SFINAE for printing the provided attribute in the context of an operation
160 /// custom printer in the case where the attribute does not define a print
161 /// method.
162 template <typename AttrOrType,
163 std::enable_if_t<!detect_has_print_method<AttrOrType>::value>
164 *sfinae = nullptr>
printStrippedAttrOrType(AttrOrType attrOrType)165 void printStrippedAttrOrType(AttrOrType attrOrType) {
166 *this << attrOrType;
167 }
168
169 /// Print the given attribute without its type. The corresponding parser must
170 /// provide a valid type for the attribute.
171 virtual void printAttributeWithoutType(Attribute attr);
172
173 /// Print the given string as a keyword, or a quoted and escaped string if it
174 /// has any special or non-printable characters in it.
175 virtual void printKeywordOrString(StringRef keyword);
176
177 /// Print the given string as a symbol reference, i.e. a form representable by
178 /// a SymbolRefAttr. A symbol reference is represented as a string prefixed
179 /// with '@'. The reference is surrounded with ""'s and escaped if it has any
180 /// special or non-printable characters in it.
181 virtual void printSymbolName(StringRef symbolRef);
182
183 /// Print a handle to the given dialect resource.
184 void printResourceHandle(const AsmDialectResourceHandle &resource);
185
186 /// Print an optional arrow followed by a type list.
187 template <typename TypeRange>
printOptionalArrowTypeList(TypeRange && types)188 void printOptionalArrowTypeList(TypeRange &&types) {
189 if (types.begin() != types.end())
190 printArrowTypeList(types);
191 }
192 template <typename TypeRange>
printArrowTypeList(TypeRange && types)193 void printArrowTypeList(TypeRange &&types) {
194 auto &os = getStream() << " -> ";
195
196 bool wrapped = !llvm::hasSingleElement(types) ||
197 (*types.begin()).template isa<FunctionType>();
198 if (wrapped)
199 os << '(';
200 llvm::interleaveComma(types, *this);
201 if (wrapped)
202 os << ')';
203 }
204
205 /// Print the two given type ranges in a functional form.
206 template <typename InputRangeT, typename ResultRangeT>
printFunctionalType(InputRangeT && inputs,ResultRangeT && results)207 void printFunctionalType(InputRangeT &&inputs, ResultRangeT &&results) {
208 auto &os = getStream();
209 os << '(';
210 llvm::interleaveComma(inputs, *this);
211 os << ')';
212 printArrowTypeList(results);
213 }
214
215 protected:
216 /// Initialize the printer with no internal implementation. In this case, all
217 /// virtual methods of this class must be overriden.
AsmPrinter()218 AsmPrinter() {}
219
220 private:
221 AsmPrinter(const AsmPrinter &) = delete;
222 void operator=(const AsmPrinter &) = delete;
223
224 /// Print the alias for the given attribute, return failure if no alias could
225 /// be printed.
226 virtual LogicalResult printAlias(Attribute attr);
227
228 /// Print the alias for the given type, return failure if no alias could
229 /// be printed.
230 virtual LogicalResult printAlias(Type type);
231
232 /// The internal implementation of the printer.
233 Impl *impl{nullptr};
234 };
235
236 template <typename AsmPrinterT>
237 inline std::enable_if_t<std::is_base_of<AsmPrinter, AsmPrinterT>::value,
238 AsmPrinterT &>
239 operator<<(AsmPrinterT &p, Type type) {
240 p.printType(type);
241 return p;
242 }
243
244 template <typename AsmPrinterT>
245 inline std::enable_if_t<std::is_base_of<AsmPrinter, AsmPrinterT>::value,
246 AsmPrinterT &>
247 operator<<(AsmPrinterT &p, Attribute attr) {
248 p.printAttribute(attr);
249 return p;
250 }
251
252 template <typename AsmPrinterT>
253 inline std::enable_if_t<std::is_base_of<AsmPrinter, AsmPrinterT>::value,
254 AsmPrinterT &>
255 operator<<(AsmPrinterT &p, const APFloat &value) {
256 p.printFloat(value);
257 return p;
258 }
259 template <typename AsmPrinterT>
260 inline std::enable_if_t<std::is_base_of<AsmPrinter, AsmPrinterT>::value,
261 AsmPrinterT &>
262 operator<<(AsmPrinterT &p, float value) {
263 return p << APFloat(value);
264 }
265 template <typename AsmPrinterT>
266 inline std::enable_if_t<std::is_base_of<AsmPrinter, AsmPrinterT>::value,
267 AsmPrinterT &>
268 operator<<(AsmPrinterT &p, double value) {
269 return p << APFloat(value);
270 }
271
272 // Support printing anything that isn't convertible to one of the other
273 // streamable types, even if it isn't exactly one of them. For example, we want
274 // to print FunctionType with the Type version above, not have it match this.
275 template <
276 typename AsmPrinterT, typename T,
277 typename std::enable_if<!std::is_convertible<T &, Value &>::value &&
278 !std::is_convertible<T &, Type &>::value &&
279 !std::is_convertible<T &, Attribute &>::value &&
280 !std::is_convertible<T &, ValueRange>::value &&
281 !std::is_convertible<T &, APFloat &>::value &&
282 !llvm::is_one_of<T, bool, float, double>::value,
283 T>::type * = nullptr>
284 inline std::enable_if_t<std::is_base_of<AsmPrinter, AsmPrinterT>::value,
285 AsmPrinterT &>
286 operator<<(AsmPrinterT &p, const T &other) {
287 p.getStream() << other;
288 return p;
289 }
290
291 template <typename AsmPrinterT>
292 inline std::enable_if_t<std::is_base_of<AsmPrinter, AsmPrinterT>::value,
293 AsmPrinterT &>
294 operator<<(AsmPrinterT &p, bool value) {
295 return p << (value ? StringRef("true") : "false");
296 }
297
298 template <typename AsmPrinterT, typename ValueRangeT>
299 inline std::enable_if_t<std::is_base_of<AsmPrinter, AsmPrinterT>::value,
300 AsmPrinterT &>
301 operator<<(AsmPrinterT &p, const ValueTypeRange<ValueRangeT> &types) {
302 llvm::interleaveComma(types, p);
303 return p;
304 }
305 template <typename AsmPrinterT>
306 inline std::enable_if_t<std::is_base_of<AsmPrinter, AsmPrinterT>::value,
307 AsmPrinterT &>
308 operator<<(AsmPrinterT &p, const TypeRange &types) {
309 llvm::interleaveComma(types, p);
310 return p;
311 }
312 template <typename AsmPrinterT, typename ElementT>
313 inline std::enable_if_t<std::is_base_of<AsmPrinter, AsmPrinterT>::value,
314 AsmPrinterT &>
315 operator<<(AsmPrinterT &p, ArrayRef<ElementT> types) {
316 llvm::interleaveComma(types, p);
317 return p;
318 }
319
320 //===----------------------------------------------------------------------===//
321 // OpAsmPrinter
322 //===----------------------------------------------------------------------===//
323
324 /// This is a pure-virtual base class that exposes the asmprinter hooks
325 /// necessary to implement a custom print() method.
326 class OpAsmPrinter : public AsmPrinter {
327 public:
328 using AsmPrinter::AsmPrinter;
329 ~OpAsmPrinter() override;
330
331 /// Print a newline and indent the printer to the start of the current
332 /// operation.
333 virtual void printNewline() = 0;
334
335 /// Print a block argument in the usual format of:
336 /// %ssaName : type {attr1=42} loc("here")
337 /// where location printing is controlled by the standard internal option.
338 /// You may pass omitType=true to not print a type, and pass an empty
339 /// attribute list if you don't care for attributes.
340 virtual void printRegionArgument(BlockArgument arg,
341 ArrayRef<NamedAttribute> argAttrs = {},
342 bool omitType = false) = 0;
343
344 /// Print implementations for various things an operation contains.
345 virtual void printOperand(Value value) = 0;
346 virtual void printOperand(Value value, raw_ostream &os) = 0;
347
348 /// Print a comma separated list of operands.
349 template <typename ContainerType>
printOperands(const ContainerType & container)350 void printOperands(const ContainerType &container) {
351 printOperands(container.begin(), container.end());
352 }
353
354 /// Print a comma separated list of operands.
355 template <typename IteratorType>
printOperands(IteratorType it,IteratorType end)356 void printOperands(IteratorType it, IteratorType end) {
357 if (it == end)
358 return;
359 printOperand(*it);
360 for (++it; it != end; ++it) {
361 getStream() << ", ";
362 printOperand(*it);
363 }
364 }
365
366 /// Print the given successor.
367 virtual void printSuccessor(Block *successor) = 0;
368
369 /// Print the successor and its operands.
370 virtual void printSuccessorAndUseList(Block *successor,
371 ValueRange succOperands) = 0;
372
373 /// If the specified operation has attributes, print out an attribute
374 /// dictionary with their values. elidedAttrs allows the client to ignore
375 /// specific well known attributes, commonly used if the attribute value is
376 /// printed some other way (like as a fixed operand).
377 virtual void printOptionalAttrDict(ArrayRef<NamedAttribute> attrs,
378 ArrayRef<StringRef> elidedAttrs = {}) = 0;
379
380 /// If the specified operation has attributes, print out an attribute
381 /// dictionary prefixed with 'attributes'.
382 virtual void
383 printOptionalAttrDictWithKeyword(ArrayRef<NamedAttribute> attrs,
384 ArrayRef<StringRef> elidedAttrs = {}) = 0;
385
386 /// Print the entire operation with the default generic assembly form.
387 /// If `printOpName` is true, then the operation name is printed (the default)
388 /// otherwise it is omitted and the print will start with the operand list.
389 virtual void printGenericOp(Operation *op, bool printOpName = true) = 0;
390
391 /// Prints a region.
392 /// If 'printEntryBlockArgs' is false, the arguments of the
393 /// block are not printed. If 'printBlockTerminator' is false, the terminator
394 /// operation of the block is not printed. If printEmptyBlock is true, then
395 /// the block header is printed even if the block is empty.
396 virtual void printRegion(Region &blocks, bool printEntryBlockArgs = true,
397 bool printBlockTerminators = true,
398 bool printEmptyBlock = false) = 0;
399
400 /// Renumber the arguments for the specified region to the same names as the
401 /// SSA values in namesToUse. This may only be used for IsolatedFromAbove
402 /// operations. If any entry in namesToUse is null, the corresponding
403 /// argument name is left alone.
404 virtual void shadowRegionArgs(Region ®ion, ValueRange namesToUse) = 0;
405
406 /// Prints an affine map of SSA ids, where SSA id names are used in place
407 /// of dims/symbols.
408 /// Operand values must come from single-result sources, and be valid
409 /// dimensions/symbol identifiers according to mlir::isValidDim/Symbol.
410 virtual void printAffineMapOfSSAIds(AffineMapAttr mapAttr,
411 ValueRange operands) = 0;
412
413 /// Prints an affine expression of SSA ids with SSA id names used instead of
414 /// dims and symbols.
415 /// Operand values must come from single-result sources, and be valid
416 /// dimensions/symbol identifiers according to mlir::isValidDim/Symbol.
417 virtual void printAffineExprOfSSAIds(AffineExpr expr, ValueRange dimOperands,
418 ValueRange symOperands) = 0;
419
420 /// Print the complete type of an operation in functional form.
421 void printFunctionalType(Operation *op);
422 using AsmPrinter::printFunctionalType;
423 };
424
425 // Make the implementations convenient to use.
426 inline OpAsmPrinter &operator<<(OpAsmPrinter &p, Value value) {
427 p.printOperand(value);
428 return p;
429 }
430
431 template <typename T,
432 typename std::enable_if<std::is_convertible<T &, ValueRange>::value &&
433 !std::is_convertible<T &, Value &>::value,
434 T>::type * = nullptr>
435 inline OpAsmPrinter &operator<<(OpAsmPrinter &p, const T &values) {
436 p.printOperands(values);
437 return p;
438 }
439
440 inline OpAsmPrinter &operator<<(OpAsmPrinter &p, Block *value) {
441 p.printSuccessor(value);
442 return p;
443 }
444
445 //===----------------------------------------------------------------------===//
446 // AsmParser
447 //===----------------------------------------------------------------------===//
448
449 /// This base class exposes generic asm parser hooks, usable across the various
450 /// derived parsers.
451 class AsmParser {
452 public:
453 AsmParser() = default;
454 virtual ~AsmParser();
455
456 MLIRContext *getContext() const;
457
458 /// Return the location of the original name token.
459 virtual SMLoc getNameLoc() const = 0;
460
461 //===--------------------------------------------------------------------===//
462 // Utilities
463 //===--------------------------------------------------------------------===//
464
465 /// Emit a diagnostic at the specified location and return failure.
466 virtual InFlightDiagnostic emitError(SMLoc loc,
467 const Twine &message = {}) = 0;
468
469 /// Return a builder which provides useful access to MLIRContext, global
470 /// objects like types and attributes.
471 virtual Builder &getBuilder() const = 0;
472
473 /// Get the location of the next token and store it into the argument. This
474 /// always succeeds.
475 virtual SMLoc getCurrentLocation() = 0;
getCurrentLocation(SMLoc * loc)476 ParseResult getCurrentLocation(SMLoc *loc) {
477 *loc = getCurrentLocation();
478 return success();
479 }
480
481 /// Re-encode the given source location as an MLIR location and return it.
482 /// Note: This method should only be used when a `Location` is necessary, as
483 /// the encoding process is not efficient.
484 virtual Location getEncodedSourceLoc(SMLoc loc) = 0;
485
486 //===--------------------------------------------------------------------===//
487 // Token Parsing
488 //===--------------------------------------------------------------------===//
489
490 /// Parse a '->' token.
491 virtual ParseResult parseArrow() = 0;
492
493 /// Parse a '->' token if present
494 virtual ParseResult parseOptionalArrow() = 0;
495
496 /// Parse a `{` token.
497 virtual ParseResult parseLBrace() = 0;
498
499 /// Parse a `{` token if present.
500 virtual ParseResult parseOptionalLBrace() = 0;
501
502 /// Parse a `}` token.
503 virtual ParseResult parseRBrace() = 0;
504
505 /// Parse a `}` token if present.
506 virtual ParseResult parseOptionalRBrace() = 0;
507
508 /// Parse a `:` token.
509 virtual ParseResult parseColon() = 0;
510
511 /// Parse a `:` token if present.
512 virtual ParseResult parseOptionalColon() = 0;
513
514 /// Parse a `,` token.
515 virtual ParseResult parseComma() = 0;
516
517 /// Parse a `,` token if present.
518 virtual ParseResult parseOptionalComma() = 0;
519
520 /// Parse a `=` token.
521 virtual ParseResult parseEqual() = 0;
522
523 /// Parse a `=` token if present.
524 virtual ParseResult parseOptionalEqual() = 0;
525
526 /// Parse a '<' token.
527 virtual ParseResult parseLess() = 0;
528
529 /// Parse a '<' token if present.
530 virtual ParseResult parseOptionalLess() = 0;
531
532 /// Parse a '>' token.
533 virtual ParseResult parseGreater() = 0;
534
535 /// Parse a '>' token if present.
536 virtual ParseResult parseOptionalGreater() = 0;
537
538 /// Parse a '?' token.
539 virtual ParseResult parseQuestion() = 0;
540
541 /// Parse a '?' token if present.
542 virtual ParseResult parseOptionalQuestion() = 0;
543
544 /// Parse a '+' token.
545 virtual ParseResult parsePlus() = 0;
546
547 /// Parse a '+' token if present.
548 virtual ParseResult parseOptionalPlus() = 0;
549
550 /// Parse a '*' token.
551 virtual ParseResult parseStar() = 0;
552
553 /// Parse a '*' token if present.
554 virtual ParseResult parseOptionalStar() = 0;
555
556 /// Parse a '|' token.
557 virtual ParseResult parseVerticalBar() = 0;
558
559 /// Parse a '|' token if present.
560 virtual ParseResult parseOptionalVerticalBar() = 0;
561
562 /// Parse a quoted string token.
parseString(std::string * string)563 ParseResult parseString(std::string *string) {
564 auto loc = getCurrentLocation();
565 if (parseOptionalString(string))
566 return emitError(loc, "expected string");
567 return success();
568 }
569
570 /// Parse a quoted string token if present.
571 virtual ParseResult parseOptionalString(std::string *string) = 0;
572
573 /// Parse a `(` token.
574 virtual ParseResult parseLParen() = 0;
575
576 /// Parse a `(` token if present.
577 virtual ParseResult parseOptionalLParen() = 0;
578
579 /// Parse a `)` token.
580 virtual ParseResult parseRParen() = 0;
581
582 /// Parse a `)` token if present.
583 virtual ParseResult parseOptionalRParen() = 0;
584
585 /// Parse a `[` token.
586 virtual ParseResult parseLSquare() = 0;
587
588 /// Parse a `[` token if present.
589 virtual ParseResult parseOptionalLSquare() = 0;
590
591 /// Parse a `]` token.
592 virtual ParseResult parseRSquare() = 0;
593
594 /// Parse a `]` token if present.
595 virtual ParseResult parseOptionalRSquare() = 0;
596
597 /// Parse a `...` token if present;
598 virtual ParseResult parseOptionalEllipsis() = 0;
599
600 /// Parse a floating point value from the stream.
601 virtual ParseResult parseFloat(double &result) = 0;
602
603 /// Parse an integer value from the stream.
604 template <typename IntT>
parseInteger(IntT & result)605 ParseResult parseInteger(IntT &result) {
606 auto loc = getCurrentLocation();
607 OptionalParseResult parseResult = parseOptionalInteger(result);
608 if (!parseResult.hasValue())
609 return emitError(loc, "expected integer value");
610 return *parseResult;
611 }
612
613 /// Parse an optional integer value from the stream.
614 virtual OptionalParseResult parseOptionalInteger(APInt &result) = 0;
615
616 template <typename IntT>
parseOptionalInteger(IntT & result)617 OptionalParseResult parseOptionalInteger(IntT &result) {
618 auto loc = getCurrentLocation();
619
620 // Parse the unsigned variant.
621 APInt uintResult;
622 OptionalParseResult parseResult = parseOptionalInteger(uintResult);
623 if (!parseResult.hasValue() || failed(*parseResult))
624 return parseResult;
625
626 // Try to convert to the provided integer type. sextOrTrunc is correct even
627 // for unsigned types because parseOptionalInteger ensures the sign bit is
628 // zero for non-negated integers.
629 result =
630 (IntT)uintResult.sextOrTrunc(sizeof(IntT) * CHAR_BIT).getLimitedValue();
631 if (APInt(uintResult.getBitWidth(), result) != uintResult)
632 return emitError(loc, "integer value too large");
633 return success();
634 }
635
636 /// These are the supported delimiters around operand lists and region
637 /// argument lists, used by parseOperandList.
638 enum class Delimiter {
639 /// Zero or more operands with no delimiters.
640 None,
641 /// Parens surrounding zero or more operands.
642 Paren,
643 /// Square brackets surrounding zero or more operands.
644 Square,
645 /// <> brackets surrounding zero or more operands.
646 LessGreater,
647 /// {} brackets surrounding zero or more operands.
648 Braces,
649 /// Parens supporting zero or more operands, or nothing.
650 OptionalParen,
651 /// Square brackets supporting zero or more ops, or nothing.
652 OptionalSquare,
653 /// <> brackets supporting zero or more ops, or nothing.
654 OptionalLessGreater,
655 /// {} brackets surrounding zero or more operands, or nothing.
656 OptionalBraces,
657 };
658
659 /// Parse a list of comma-separated items with an optional delimiter. If a
660 /// delimiter is provided, then an empty list is allowed. If not, then at
661 /// least one element will be parsed.
662 ///
663 /// contextMessage is an optional message appended to "expected '('" sorts of
664 /// diagnostics when parsing the delimeters.
665 virtual ParseResult
666 parseCommaSeparatedList(Delimiter delimiter,
667 function_ref<ParseResult()> parseElementFn,
668 StringRef contextMessage = StringRef()) = 0;
669
670 /// Parse a comma separated list of elements that must have at least one entry
671 /// in it.
672 ParseResult
parseCommaSeparatedList(function_ref<ParseResult ()> parseElementFn)673 parseCommaSeparatedList(function_ref<ParseResult()> parseElementFn) {
674 return parseCommaSeparatedList(Delimiter::None, parseElementFn);
675 }
676
677 //===--------------------------------------------------------------------===//
678 // Keyword Parsing
679 //===--------------------------------------------------------------------===//
680
681 /// This class represents a StringSwitch like class that is useful for parsing
682 /// expected keywords. On construction, it invokes `parseKeyword` and
683 /// processes each of the provided cases statements until a match is hit. The
684 /// provided `ResultT` must be assignable from `failure()`.
685 template <typename ResultT = ParseResult>
686 class KeywordSwitch {
687 public:
KeywordSwitch(AsmParser & parser)688 KeywordSwitch(AsmParser &parser)
689 : parser(parser), loc(parser.getCurrentLocation()) {
690 if (failed(parser.parseKeywordOrCompletion(&keyword)))
691 result = failure();
692 }
693
694 /// Case that uses the provided value when true.
Case(StringLiteral str,ResultT value)695 KeywordSwitch &Case(StringLiteral str, ResultT value) {
696 return Case(str, [&](StringRef, SMLoc) { return std::move(value); });
697 }
Default(ResultT value)698 KeywordSwitch &Default(ResultT value) {
699 return Default([&](StringRef, SMLoc) { return std::move(value); });
700 }
701 /// Case that invokes the provided functor when true. The parameters passed
702 /// to the functor are the keyword, and the location of the keyword (in case
703 /// any errors need to be emitted).
704 template <typename FnT>
705 std::enable_if_t<!std::is_convertible<FnT, ResultT>::value, KeywordSwitch &>
Case(StringLiteral str,FnT && fn)706 Case(StringLiteral str, FnT &&fn) {
707 if (result)
708 return *this;
709
710 // If the word was empty, record this as a completion.
711 if (keyword.empty())
712 parser.codeCompleteExpectedTokens(str);
713 else if (keyword == str)
714 result.emplace(std::move(fn(keyword, loc)));
715 return *this;
716 }
717 template <typename FnT>
718 std::enable_if_t<!std::is_convertible<FnT, ResultT>::value, KeywordSwitch &>
Default(FnT && fn)719 Default(FnT &&fn) {
720 if (!result)
721 result.emplace(fn(keyword, loc));
722 return *this;
723 }
724
725 /// Returns true if this switch has a value yet.
hasValue()726 bool hasValue() const { return result.has_value(); }
727
728 /// Return the result of the switch.
ResultT()729 LLVM_NODISCARD operator ResultT() {
730 if (!result)
731 return parser.emitError(loc, "unexpected keyword: ") << keyword;
732 return std::move(*result);
733 }
734
735 private:
736 /// The parser used to construct this switch.
737 AsmParser &parser;
738
739 /// The location of the keyword, used to emit errors as necessary.
740 SMLoc loc;
741
742 /// The parsed keyword itself.
743 StringRef keyword;
744
745 /// The result of the switch statement or none if currently unknown.
746 Optional<ResultT> result;
747 };
748
749 /// Parse a given keyword.
parseKeyword(StringRef keyword)750 ParseResult parseKeyword(StringRef keyword) {
751 return parseKeyword(keyword, "");
752 }
753 virtual ParseResult parseKeyword(StringRef keyword, const Twine &msg) = 0;
754
755 /// Parse a keyword into 'keyword'.
parseKeyword(StringRef * keyword)756 ParseResult parseKeyword(StringRef *keyword) {
757 auto loc = getCurrentLocation();
758 if (parseOptionalKeyword(keyword))
759 return emitError(loc, "expected valid keyword");
760 return success();
761 }
762
763 /// Parse the given keyword if present.
764 virtual ParseResult parseOptionalKeyword(StringRef keyword) = 0;
765
766 /// Parse a keyword, if present, into 'keyword'.
767 virtual ParseResult parseOptionalKeyword(StringRef *keyword) = 0;
768
769 /// Parse a keyword, if present, and if one of the 'allowedValues',
770 /// into 'keyword'
771 virtual ParseResult
772 parseOptionalKeyword(StringRef *keyword,
773 ArrayRef<StringRef> allowedValues) = 0;
774
775 /// Parse a keyword or a quoted string.
parseKeywordOrString(std::string * result)776 ParseResult parseKeywordOrString(std::string *result) {
777 if (failed(parseOptionalKeywordOrString(result)))
778 return emitError(getCurrentLocation())
779 << "expected valid keyword or string";
780 return success();
781 }
782
783 /// Parse an optional keyword or string.
784 virtual ParseResult parseOptionalKeywordOrString(std::string *result) = 0;
785
786 //===--------------------------------------------------------------------===//
787 // Attribute/Type Parsing
788 //===--------------------------------------------------------------------===//
789
790 /// Invoke the `getChecked` method of the given Attribute or Type class, using
791 /// the provided location to emit errors in the case of failure. Note that
792 /// unlike `OpBuilder::getType`, this method does not implicitly insert a
793 /// context parameter.
794 template <typename T, typename... ParamsT>
getChecked(SMLoc loc,ParamsT &&...params)795 auto getChecked(SMLoc loc, ParamsT &&...params) {
796 return T::getChecked([&] { return emitError(loc); },
797 std::forward<ParamsT>(params)...);
798 }
799 /// A variant of `getChecked` that uses the result of `getNameLoc` to emit
800 /// errors.
801 template <typename T, typename... ParamsT>
getChecked(ParamsT &&...params)802 auto getChecked(ParamsT &&...params) {
803 return T::getChecked([&] { return emitError(getNameLoc()); },
804 std::forward<ParamsT>(params)...);
805 }
806
807 //===--------------------------------------------------------------------===//
808 // Attribute Parsing
809 //===--------------------------------------------------------------------===//
810
811 /// Parse an arbitrary attribute of a given type and return it in result.
812 virtual ParseResult parseAttribute(Attribute &result, Type type = {}) = 0;
813
814 /// Parse a custom attribute with the provided callback, unless the next
815 /// token is `#`, in which case the generic parser is invoked.
816 virtual ParseResult parseCustomAttributeWithFallback(
817 Attribute &result, Type type,
818 function_ref<ParseResult(Attribute &result, Type type)>
819 parseAttribute) = 0;
820
821 /// Parse an attribute of a specific kind and type.
822 template <typename AttrType>
823 ParseResult parseAttribute(AttrType &result, Type type = {}) {
824 SMLoc loc = getCurrentLocation();
825
826 // Parse any kind of attribute.
827 Attribute attr;
828 if (parseAttribute(attr, type))
829 return failure();
830
831 // Check for the right kind of attribute.
832 if (!(result = attr.dyn_cast<AttrType>()))
833 return emitError(loc, "invalid kind of attribute specified");
834
835 return success();
836 }
837
838 /// Parse an arbitrary attribute and return it in result. This also adds the
839 /// attribute to the specified attribute list with the specified name.
parseAttribute(Attribute & result,StringRef attrName,NamedAttrList & attrs)840 ParseResult parseAttribute(Attribute &result, StringRef attrName,
841 NamedAttrList &attrs) {
842 return parseAttribute(result, Type(), attrName, attrs);
843 }
844
845 /// Parse an attribute of a specific kind and type.
846 template <typename AttrType>
parseAttribute(AttrType & result,StringRef attrName,NamedAttrList & attrs)847 ParseResult parseAttribute(AttrType &result, StringRef attrName,
848 NamedAttrList &attrs) {
849 return parseAttribute(result, Type(), attrName, attrs);
850 }
851
852 /// Parse an arbitrary attribute of a given type and populate it in `result`.
853 /// This also adds the attribute to the specified attribute list with the
854 /// specified name.
855 template <typename AttrType>
parseAttribute(AttrType & result,Type type,StringRef attrName,NamedAttrList & attrs)856 ParseResult parseAttribute(AttrType &result, Type type, StringRef attrName,
857 NamedAttrList &attrs) {
858 SMLoc loc = getCurrentLocation();
859
860 // Parse any kind of attribute.
861 Attribute attr;
862 if (parseAttribute(attr, type))
863 return failure();
864
865 // Check for the right kind of attribute.
866 result = attr.dyn_cast<AttrType>();
867 if (!result)
868 return emitError(loc, "invalid kind of attribute specified");
869
870 attrs.append(attrName, result);
871 return success();
872 }
873
874 /// Trait to check if `AttrType` provides a `parse` method.
875 template <typename AttrType>
876 using has_parse_method = decltype(AttrType::parse(std::declval<AsmParser &>(),
877 std::declval<Type>()));
878 template <typename AttrType>
879 using detect_has_parse_method = llvm::is_detected<has_parse_method, AttrType>;
880
881 /// Parse a custom attribute of a given type unless the next token is `#`, in
882 /// which case the generic parser is invoked. The parsed attribute is
883 /// populated in `result` and also added to the specified attribute list with
884 /// the specified name.
885 template <typename AttrType>
886 std::enable_if_t<detect_has_parse_method<AttrType>::value, ParseResult>
parseCustomAttributeWithFallback(AttrType & result,Type type,StringRef attrName,NamedAttrList & attrs)887 parseCustomAttributeWithFallback(AttrType &result, Type type,
888 StringRef attrName, NamedAttrList &attrs) {
889 SMLoc loc = getCurrentLocation();
890
891 // Parse any kind of attribute.
892 Attribute attr;
893 if (parseCustomAttributeWithFallback(
894 attr, type, [&](Attribute &result, Type type) -> ParseResult {
895 result = AttrType::parse(*this, type);
896 if (!result)
897 return failure();
898 return success();
899 }))
900 return failure();
901
902 // Check for the right kind of attribute.
903 result = attr.dyn_cast<AttrType>();
904 if (!result)
905 return emitError(loc, "invalid kind of attribute specified");
906
907 attrs.append(attrName, result);
908 return success();
909 }
910
911 /// SFINAE parsing method for Attribute that don't implement a parse method.
912 template <typename AttrType>
913 std::enable_if_t<!detect_has_parse_method<AttrType>::value, ParseResult>
parseCustomAttributeWithFallback(AttrType & result,Type type,StringRef attrName,NamedAttrList & attrs)914 parseCustomAttributeWithFallback(AttrType &result, Type type,
915 StringRef attrName, NamedAttrList &attrs) {
916 return parseAttribute(result, type, attrName, attrs);
917 }
918
919 /// Parse a custom attribute of a given type unless the next token is `#`, in
920 /// which case the generic parser is invoked. The parsed attribute is
921 /// populated in `result`.
922 template <typename AttrType>
923 std::enable_if_t<detect_has_parse_method<AttrType>::value, ParseResult>
parseCustomAttributeWithFallback(AttrType & result)924 parseCustomAttributeWithFallback(AttrType &result) {
925 SMLoc loc = getCurrentLocation();
926
927 // Parse any kind of attribute.
928 Attribute attr;
929 if (parseCustomAttributeWithFallback(
930 attr, {}, [&](Attribute &result, Type type) -> ParseResult {
931 result = AttrType::parse(*this, type);
932 return success(!!result);
933 }))
934 return failure();
935
936 // Check for the right kind of attribute.
937 result = attr.dyn_cast<AttrType>();
938 if (!result)
939 return emitError(loc, "invalid kind of attribute specified");
940 return success();
941 }
942
943 /// SFINAE parsing method for Attribute that don't implement a parse method.
944 template <typename AttrType>
945 std::enable_if_t<!detect_has_parse_method<AttrType>::value, ParseResult>
parseCustomAttributeWithFallback(AttrType & result)946 parseCustomAttributeWithFallback(AttrType &result) {
947 return parseAttribute(result);
948 }
949
950 /// Parse an arbitrary optional attribute of a given type and return it in
951 /// result.
952 virtual OptionalParseResult parseOptionalAttribute(Attribute &result,
953 Type type = {}) = 0;
954
955 /// Parse an optional array attribute and return it in result.
956 virtual OptionalParseResult parseOptionalAttribute(ArrayAttr &result,
957 Type type = {}) = 0;
958
959 /// Parse an optional string attribute and return it in result.
960 virtual OptionalParseResult parseOptionalAttribute(StringAttr &result,
961 Type type = {}) = 0;
962
963 /// Parse an optional attribute of a specific type and add it to the list with
964 /// the specified name.
965 template <typename AttrType>
parseOptionalAttribute(AttrType & result,StringRef attrName,NamedAttrList & attrs)966 OptionalParseResult parseOptionalAttribute(AttrType &result,
967 StringRef attrName,
968 NamedAttrList &attrs) {
969 return parseOptionalAttribute(result, Type(), attrName, attrs);
970 }
971
972 /// Parse an optional attribute of a specific type and add it to the list with
973 /// the specified name.
974 template <typename AttrType>
parseOptionalAttribute(AttrType & result,Type type,StringRef attrName,NamedAttrList & attrs)975 OptionalParseResult parseOptionalAttribute(AttrType &result, Type type,
976 StringRef attrName,
977 NamedAttrList &attrs) {
978 OptionalParseResult parseResult = parseOptionalAttribute(result, type);
979 if (parseResult.hasValue() && succeeded(*parseResult))
980 attrs.append(attrName, result);
981 return parseResult;
982 }
983
984 /// Parse a named dictionary into 'result' if it is present.
985 virtual ParseResult parseOptionalAttrDict(NamedAttrList &result) = 0;
986
987 /// Parse a named dictionary into 'result' if the `attributes` keyword is
988 /// present.
989 virtual ParseResult
990 parseOptionalAttrDictWithKeyword(NamedAttrList &result) = 0;
991
992 /// Parse an affine map instance into 'map'.
993 virtual ParseResult parseAffineMap(AffineMap &map) = 0;
994
995 /// Parse an integer set instance into 'set'.
996 virtual ParseResult printIntegerSet(IntegerSet &set) = 0;
997
998 //===--------------------------------------------------------------------===//
999 // Identifier Parsing
1000 //===--------------------------------------------------------------------===//
1001
1002 /// Parse an @-identifier and store it (without the '@' symbol) in a string
1003 /// attribute named 'attrName'.
parseSymbolName(StringAttr & result,StringRef attrName,NamedAttrList & attrs)1004 ParseResult parseSymbolName(StringAttr &result, StringRef attrName,
1005 NamedAttrList &attrs) {
1006 if (failed(parseOptionalSymbolName(result, attrName, attrs)))
1007 return emitError(getCurrentLocation())
1008 << "expected valid '@'-identifier for symbol name";
1009 return success();
1010 }
1011
1012 /// Parse an optional @-identifier and store it (without the '@' symbol) in a
1013 /// string attribute named 'attrName'.
1014 virtual ParseResult parseOptionalSymbolName(StringAttr &result,
1015 StringRef attrName,
1016 NamedAttrList &attrs) = 0;
1017
1018 //===--------------------------------------------------------------------===//
1019 // Resource Parsing
1020 //===--------------------------------------------------------------------===//
1021
1022 /// Parse a handle to a resource within the assembly format.
1023 template <typename ResourceT>
parseResourceHandle()1024 FailureOr<ResourceT> parseResourceHandle() {
1025 SMLoc handleLoc = getCurrentLocation();
1026 FailureOr<AsmDialectResourceHandle> handle = parseResourceHandle(
1027 getContext()->getOrLoadDialect<typename ResourceT::Dialect>());
1028 if (failed(handle))
1029 return failure();
1030 if (auto *result = dyn_cast<ResourceT>(&*handle))
1031 return std::move(*result);
1032 return emitError(handleLoc) << "provided resource handle differs from the "
1033 "expected resource type";
1034 }
1035
1036 //===--------------------------------------------------------------------===//
1037 // Type Parsing
1038 //===--------------------------------------------------------------------===//
1039
1040 /// Parse a type.
1041 virtual ParseResult parseType(Type &result) = 0;
1042
1043 /// Parse a custom type with the provided callback, unless the next
1044 /// token is `#`, in which case the generic parser is invoked.
1045 virtual ParseResult parseCustomTypeWithFallback(
1046 Type &result, function_ref<ParseResult(Type &result)> parseType) = 0;
1047
1048 /// Parse an optional type.
1049 virtual OptionalParseResult parseOptionalType(Type &result) = 0;
1050
1051 /// Parse a type of a specific type.
1052 template <typename TypeT>
parseType(TypeT & result)1053 ParseResult parseType(TypeT &result) {
1054 SMLoc loc = getCurrentLocation();
1055
1056 // Parse any kind of type.
1057 Type type;
1058 if (parseType(type))
1059 return failure();
1060
1061 // Check for the right kind of type.
1062 result = type.dyn_cast<TypeT>();
1063 if (!result)
1064 return emitError(loc, "invalid kind of type specified");
1065
1066 return success();
1067 }
1068
1069 /// Trait to check if `TypeT` provides a `parse` method.
1070 template <typename TypeT>
1071 using type_has_parse_method =
1072 decltype(TypeT::parse(std::declval<AsmParser &>()));
1073 template <typename TypeT>
1074 using detect_type_has_parse_method =
1075 llvm::is_detected<type_has_parse_method, TypeT>;
1076
1077 /// Parse a custom Type of a given type unless the next token is `#`, in
1078 /// which case the generic parser is invoked. The parsed Type is
1079 /// populated in `result`.
1080 template <typename TypeT>
1081 std::enable_if_t<detect_type_has_parse_method<TypeT>::value, ParseResult>
parseCustomTypeWithFallback(TypeT & result)1082 parseCustomTypeWithFallback(TypeT &result) {
1083 SMLoc loc = getCurrentLocation();
1084
1085 // Parse any kind of Type.
1086 Type type;
1087 if (parseCustomTypeWithFallback(type, [&](Type &result) -> ParseResult {
1088 result = TypeT::parse(*this);
1089 return success(!!result);
1090 }))
1091 return failure();
1092
1093 // Check for the right kind of Type.
1094 result = type.dyn_cast<TypeT>();
1095 if (!result)
1096 return emitError(loc, "invalid kind of Type specified");
1097 return success();
1098 }
1099
1100 /// SFINAE parsing method for Type that don't implement a parse method.
1101 template <typename TypeT>
1102 std::enable_if_t<!detect_type_has_parse_method<TypeT>::value, ParseResult>
parseCustomTypeWithFallback(TypeT & result)1103 parseCustomTypeWithFallback(TypeT &result) {
1104 return parseType(result);
1105 }
1106
1107 /// Parse a type list.
parseTypeList(SmallVectorImpl<Type> & result)1108 ParseResult parseTypeList(SmallVectorImpl<Type> &result) {
1109 return parseCommaSeparatedList(
1110 [&]() { return parseType(result.emplace_back()); });
1111 }
1112
1113 /// Parse an arrow followed by a type list.
1114 virtual ParseResult parseArrowTypeList(SmallVectorImpl<Type> &result) = 0;
1115
1116 /// Parse an optional arrow followed by a type list.
1117 virtual ParseResult
1118 parseOptionalArrowTypeList(SmallVectorImpl<Type> &result) = 0;
1119
1120 /// Parse a colon followed by a type.
1121 virtual ParseResult parseColonType(Type &result) = 0;
1122
1123 /// Parse a colon followed by a type of a specific kind, e.g. a FunctionType.
1124 template <typename TypeType>
parseColonType(TypeType & result)1125 ParseResult parseColonType(TypeType &result) {
1126 SMLoc loc = getCurrentLocation();
1127
1128 // Parse any kind of type.
1129 Type type;
1130 if (parseColonType(type))
1131 return failure();
1132
1133 // Check for the right kind of type.
1134 result = type.dyn_cast<TypeType>();
1135 if (!result)
1136 return emitError(loc, "invalid kind of type specified");
1137
1138 return success();
1139 }
1140
1141 /// Parse a colon followed by a type list, which must have at least one type.
1142 virtual ParseResult parseColonTypeList(SmallVectorImpl<Type> &result) = 0;
1143
1144 /// Parse an optional colon followed by a type list, which if present must
1145 /// have at least one type.
1146 virtual ParseResult
1147 parseOptionalColonTypeList(SmallVectorImpl<Type> &result) = 0;
1148
1149 /// Parse a keyword followed by a type.
parseKeywordType(const char * keyword,Type & result)1150 ParseResult parseKeywordType(const char *keyword, Type &result) {
1151 return failure(parseKeyword(keyword) || parseType(result));
1152 }
1153
1154 /// Add the specified type to the end of the specified type list and return
1155 /// success. This is a helper designed to allow parse methods to be simple
1156 /// and chain through || operators.
addTypeToList(Type type,SmallVectorImpl<Type> & result)1157 ParseResult addTypeToList(Type type, SmallVectorImpl<Type> &result) {
1158 result.push_back(type);
1159 return success();
1160 }
1161
1162 /// Add the specified types to the end of the specified type list and return
1163 /// success. This is a helper designed to allow parse methods to be simple
1164 /// and chain through || operators.
addTypesToList(ArrayRef<Type> types,SmallVectorImpl<Type> & result)1165 ParseResult addTypesToList(ArrayRef<Type> types,
1166 SmallVectorImpl<Type> &result) {
1167 result.append(types.begin(), types.end());
1168 return success();
1169 }
1170
1171 /// Parse a dimension list of a tensor or memref type. This populates the
1172 /// dimension list, using -1 for the `?` dimensions if `allowDynamic` is set
1173 /// and errors out on `?` otherwise. Parsing the trailing `x` is configurable.
1174 ///
1175 /// dimension-list ::= eps | dimension (`x` dimension)*
1176 /// dimension-list-with-trailing-x ::= (dimension `x`)*
1177 /// dimension ::= `?` | decimal-literal
1178 ///
1179 /// When `allowDynamic` is not set, this is used to parse:
1180 ///
1181 /// static-dimension-list ::= eps | decimal-literal (`x` decimal-literal)*
1182 /// static-dimension-list-with-trailing-x ::= (dimension `x`)*
1183 virtual ParseResult parseDimensionList(SmallVectorImpl<int64_t> &dimensions,
1184 bool allowDynamic = true,
1185 bool withTrailingX = true) = 0;
1186
1187 /// Parse an 'x' token in a dimension list, handling the case where the x is
1188 /// juxtaposed with an element type, as in "xf32", leaving the "f32" as the
1189 /// next token.
1190 virtual ParseResult parseXInDimensionList() = 0;
1191
1192 protected:
1193 /// Parse a handle to a resource within the assembly format for the given
1194 /// dialect.
1195 virtual FailureOr<AsmDialectResourceHandle>
1196 parseResourceHandle(Dialect *dialect) = 0;
1197
1198 //===--------------------------------------------------------------------===//
1199 // Code Completion
1200 //===--------------------------------------------------------------------===//
1201
1202 /// Parse a keyword, or an empty string if the current location signals a code
1203 /// completion.
1204 virtual ParseResult parseKeywordOrCompletion(StringRef *keyword) = 0;
1205
1206 /// Signal the code completion of a set of expected tokens.
1207 virtual void codeCompleteExpectedTokens(ArrayRef<StringRef> tokens) = 0;
1208
1209 private:
1210 AsmParser(const AsmParser &) = delete;
1211 void operator=(const AsmParser &) = delete;
1212 };
1213
1214 //===----------------------------------------------------------------------===//
1215 // OpAsmParser
1216 //===----------------------------------------------------------------------===//
1217
1218 /// The OpAsmParser has methods for interacting with the asm parser: parsing
1219 /// things from it, emitting errors etc. It has an intentionally high-level API
1220 /// that is designed to reduce/constrain syntax innovation in individual
1221 /// operations.
1222 ///
1223 /// For example, consider an op like this:
1224 ///
1225 /// %x = load %p[%1, %2] : memref<...>
1226 ///
1227 /// The "%x = load" tokens are already parsed and therefore invisible to the
1228 /// custom op parser. This can be supported by calling `parseOperandList` to
1229 /// parse the %p, then calling `parseOperandList` with a `SquareDelimiter` to
1230 /// parse the indices, then calling `parseColonTypeList` to parse the result
1231 /// type.
1232 ///
1233 class OpAsmParser : public AsmParser {
1234 public:
1235 using AsmParser::AsmParser;
1236 ~OpAsmParser() override;
1237
1238 /// Parse a loc(...) specifier if present, filling in result if so.
1239 /// Location for BlockArgument and Operation may be deferred with an alias, in
1240 /// which case an OpaqueLoc is set and will be resolved when parsing
1241 /// completes.
1242 virtual ParseResult
1243 parseOptionalLocationSpecifier(Optional<Location> &result) = 0;
1244
1245 /// Return the name of the specified result in the specified syntax, as well
1246 /// as the sub-element in the name. It returns an empty string and ~0U for
1247 /// invalid result numbers. For example, in this operation:
1248 ///
1249 /// %x, %y:2, %z = foo.op
1250 ///
1251 /// getResultName(0) == {"x", 0 }
1252 /// getResultName(1) == {"y", 0 }
1253 /// getResultName(2) == {"y", 1 }
1254 /// getResultName(3) == {"z", 0 }
1255 /// getResultName(4) == {"", ~0U }
1256 virtual std::pair<StringRef, unsigned>
1257 getResultName(unsigned resultNo) const = 0;
1258
1259 /// Return the number of declared SSA results. This returns 4 for the foo.op
1260 /// example in the comment for `getResultName`.
1261 virtual size_t getNumResults() const = 0;
1262
1263 // These methods emit an error and return failure or success. This allows
1264 // these to be chained together into a linear sequence of || expressions in
1265 // many cases.
1266
1267 /// Parse an operation in its generic form.
1268 /// The parsed operation is parsed in the current context and inserted in the
1269 /// provided block and insertion point. The results produced by this operation
1270 /// aren't mapped to any named value in the parser. Returns nullptr on
1271 /// failure.
1272 virtual Operation *parseGenericOperation(Block *insertBlock,
1273 Block::iterator insertPt) = 0;
1274
1275 /// Parse the name of an operation, in the custom form. On success, return a
1276 /// an object of type 'OperationName'. Otherwise, failure is returned.
1277 virtual FailureOr<OperationName> parseCustomOperationName() = 0;
1278
1279 //===--------------------------------------------------------------------===//
1280 // Operand Parsing
1281 //===--------------------------------------------------------------------===//
1282
1283 /// This is the representation of an operand reference.
1284 struct UnresolvedOperand {
1285 SMLoc location; // Location of the token.
1286 StringRef name; // Value name, e.g. %42 or %abc
1287 unsigned number; // Number, e.g. 12 for an operand like %xyz#12
1288 };
1289
1290 /// Parse different components, viz., use-info of operand(s), successor(s),
1291 /// region(s), attribute(s) and function-type, of the generic form of an
1292 /// operation instance and populate the input operation-state 'result' with
1293 /// those components. If any of the components is explicitly provided, then
1294 /// skip parsing that component.
1295 virtual ParseResult parseGenericOperationAfterOpName(
1296 OperationState &result,
1297 Optional<ArrayRef<UnresolvedOperand>> parsedOperandType = llvm::None,
1298 Optional<ArrayRef<Block *>> parsedSuccessors = llvm::None,
1299 Optional<MutableArrayRef<std::unique_ptr<Region>>> parsedRegions =
1300 llvm::None,
1301 Optional<ArrayRef<NamedAttribute>> parsedAttributes = llvm::None,
1302 Optional<FunctionType> parsedFnType = llvm::None) = 0;
1303
1304 /// Parse a single SSA value operand name along with a result number if
1305 /// `allowResultNumber` is true.
1306 virtual ParseResult parseOperand(UnresolvedOperand &result,
1307 bool allowResultNumber = true) = 0;
1308
1309 /// Parse a single operand if present.
1310 virtual OptionalParseResult
1311 parseOptionalOperand(UnresolvedOperand &result,
1312 bool allowResultNumber = true) = 0;
1313
1314 /// Parse zero or more SSA comma-separated operand references with a specified
1315 /// surrounding delimiter, and an optional required operand count.
1316 virtual ParseResult
1317 parseOperandList(SmallVectorImpl<UnresolvedOperand> &result,
1318 Delimiter delimiter = Delimiter::None,
1319 bool allowResultNumber = true,
1320 int requiredOperandCount = -1) = 0;
1321
1322 /// Parse a specified number of comma separated operands.
1323 ParseResult parseOperandList(SmallVectorImpl<UnresolvedOperand> &result,
1324 int requiredOperandCount,
1325 Delimiter delimiter = Delimiter::None) {
1326 return parseOperandList(result, delimiter,
1327 /*allowResultNumber=*/true, requiredOperandCount);
1328 }
1329
1330 /// Parse zero or more trailing SSA comma-separated trailing operand
1331 /// references with a specified surrounding delimiter, and an optional
1332 /// required operand count. A leading comma is expected before the
1333 /// operands.
1334 ParseResult
1335 parseTrailingOperandList(SmallVectorImpl<UnresolvedOperand> &result,
1336 Delimiter delimiter = Delimiter::None) {
1337 if (failed(parseOptionalComma()))
1338 return success(); // The comma is optional.
1339 return parseOperandList(result, delimiter);
1340 }
1341
1342 /// Resolve an operand to an SSA value, emitting an error on failure.
1343 virtual ParseResult resolveOperand(const UnresolvedOperand &operand,
1344 Type type,
1345 SmallVectorImpl<Value> &result) = 0;
1346
1347 /// Resolve a list of operands to SSA values, emitting an error on failure, or
1348 /// appending the results to the list on success. This method should be used
1349 /// when all operands have the same type.
resolveOperands(ArrayRef<UnresolvedOperand> operands,Type type,SmallVectorImpl<Value> & result)1350 ParseResult resolveOperands(ArrayRef<UnresolvedOperand> operands, Type type,
1351 SmallVectorImpl<Value> &result) {
1352 for (auto elt : operands)
1353 if (resolveOperand(elt, type, result))
1354 return failure();
1355 return success();
1356 }
1357
1358 /// Resolve a list of operands and a list of operand types to SSA values,
1359 /// emitting an error and returning failure, or appending the results
1360 /// to the list on success.
resolveOperands(ArrayRef<UnresolvedOperand> operands,ArrayRef<Type> types,SMLoc loc,SmallVectorImpl<Value> & result)1361 ParseResult resolveOperands(ArrayRef<UnresolvedOperand> operands,
1362 ArrayRef<Type> types, SMLoc loc,
1363 SmallVectorImpl<Value> &result) {
1364 if (operands.size() != types.size())
1365 return emitError(loc)
1366 << operands.size() << " operands present, but expected "
1367 << types.size();
1368
1369 for (unsigned i = 0, e = operands.size(); i != e; ++i)
1370 if (resolveOperand(operands[i], types[i], result))
1371 return failure();
1372 return success();
1373 }
1374 template <typename Operands>
resolveOperands(Operands && operands,Type type,SMLoc loc,SmallVectorImpl<Value> & result)1375 ParseResult resolveOperands(Operands &&operands, Type type, SMLoc loc,
1376 SmallVectorImpl<Value> &result) {
1377 return resolveOperands(std::forward<Operands>(operands),
1378 ArrayRef<Type>(type), loc, result);
1379 }
1380 template <typename Operands, typename Types>
1381 std::enable_if_t<!std::is_convertible<Types, Type>::value, ParseResult>
resolveOperands(Operands && operands,Types && types,SMLoc loc,SmallVectorImpl<Value> & result)1382 resolveOperands(Operands &&operands, Types &&types, SMLoc loc,
1383 SmallVectorImpl<Value> &result) {
1384 size_t operandSize = std::distance(operands.begin(), operands.end());
1385 size_t typeSize = std::distance(types.begin(), types.end());
1386 if (operandSize != typeSize)
1387 return emitError(loc)
1388 << operandSize << " operands present, but expected " << typeSize;
1389
1390 for (auto it : llvm::zip(operands, types))
1391 if (resolveOperand(std::get<0>(it), std::get<1>(it), result))
1392 return failure();
1393 return success();
1394 }
1395
1396 /// Parses an affine map attribute where dims and symbols are SSA operands.
1397 /// Operand values must come from single-result sources, and be valid
1398 /// dimensions/symbol identifiers according to mlir::isValidDim/Symbol.
1399 virtual ParseResult
1400 parseAffineMapOfSSAIds(SmallVectorImpl<UnresolvedOperand> &operands,
1401 Attribute &map, StringRef attrName,
1402 NamedAttrList &attrs,
1403 Delimiter delimiter = Delimiter::Square) = 0;
1404
1405 /// Parses an affine expression where dims and symbols are SSA operands.
1406 /// Operand values must come from single-result sources, and be valid
1407 /// dimensions/symbol identifiers according to mlir::isValidDim/Symbol.
1408 virtual ParseResult
1409 parseAffineExprOfSSAIds(SmallVectorImpl<UnresolvedOperand> &dimOperands,
1410 SmallVectorImpl<UnresolvedOperand> &symbOperands,
1411 AffineExpr &expr) = 0;
1412
1413 //===--------------------------------------------------------------------===//
1414 // Argument Parsing
1415 //===--------------------------------------------------------------------===//
1416
1417 struct Argument {
1418 UnresolvedOperand ssaName; // SourceLoc, SSA name, result #.
1419 Type type; // Type.
1420 DictionaryAttr attrs; // Attributes if present.
1421 Optional<Location> sourceLoc; // Source location specifier if present.
1422 };
1423
1424 /// Parse a single argument with the following syntax:
1425 ///
1426 /// `%ssaName : !type { optionalAttrDict} loc(optionalSourceLoc)`
1427 ///
1428 /// If `allowType` is false or `allowAttrs` are false then the respective
1429 /// parts of the grammar are not parsed.
1430 virtual ParseResult parseArgument(Argument &result, bool allowType = false,
1431 bool allowAttrs = false) = 0;
1432
1433 /// Parse a single argument if present.
1434 virtual OptionalParseResult
1435 parseOptionalArgument(Argument &result, bool allowType = false,
1436 bool allowAttrs = false) = 0;
1437
1438 /// Parse zero or more arguments with a specified surrounding delimiter.
1439 virtual ParseResult parseArgumentList(SmallVectorImpl<Argument> &result,
1440 Delimiter delimiter = Delimiter::None,
1441 bool allowType = false,
1442 bool allowAttrs = false) = 0;
1443
1444 //===--------------------------------------------------------------------===//
1445 // Region Parsing
1446 //===--------------------------------------------------------------------===//
1447
1448 /// Parses a region. Any parsed blocks are appended to 'region' and must be
1449 /// moved to the op regions after the op is created. The first block of the
1450 /// region takes 'arguments'.
1451 ///
1452 /// If 'enableNameShadowing' is set to true, the argument names are allowed to
1453 /// shadow the names of other existing SSA values defined above the region
1454 /// scope. 'enableNameShadowing' can only be set to true for regions attached
1455 /// to operations that are 'IsolatedFromAbove'.
1456 virtual ParseResult parseRegion(Region ®ion,
1457 ArrayRef<Argument> arguments = {},
1458 bool enableNameShadowing = false) = 0;
1459
1460 /// Parses a region if present.
1461 virtual OptionalParseResult
1462 parseOptionalRegion(Region ®ion, ArrayRef<Argument> arguments = {},
1463 bool enableNameShadowing = false) = 0;
1464
1465 /// Parses a region if present. If the region is present, a new region is
1466 /// allocated and placed in `region`. If no region is present or on failure,
1467 /// `region` remains untouched.
1468 virtual OptionalParseResult
1469 parseOptionalRegion(std::unique_ptr<Region> ®ion,
1470 ArrayRef<Argument> arguments = {},
1471 bool enableNameShadowing = false) = 0;
1472
1473 //===--------------------------------------------------------------------===//
1474 // Successor Parsing
1475 //===--------------------------------------------------------------------===//
1476
1477 /// Parse a single operation successor.
1478 virtual ParseResult parseSuccessor(Block *&dest) = 0;
1479
1480 /// Parse an optional operation successor.
1481 virtual OptionalParseResult parseOptionalSuccessor(Block *&dest) = 0;
1482
1483 /// Parse a single operation successor and its operand list.
1484 virtual ParseResult
1485 parseSuccessorAndUseList(Block *&dest, SmallVectorImpl<Value> &operands) = 0;
1486
1487 //===--------------------------------------------------------------------===//
1488 // Type Parsing
1489 //===--------------------------------------------------------------------===//
1490
1491 /// Parse a list of assignments of the form
1492 /// (%x1 = %y1, %x2 = %y2, ...)
parseAssignmentList(SmallVectorImpl<Argument> & lhs,SmallVectorImpl<UnresolvedOperand> & rhs)1493 ParseResult parseAssignmentList(SmallVectorImpl<Argument> &lhs,
1494 SmallVectorImpl<UnresolvedOperand> &rhs) {
1495 OptionalParseResult result = parseOptionalAssignmentList(lhs, rhs);
1496 if (!result.hasValue())
1497 return emitError(getCurrentLocation(), "expected '('");
1498 return result.getValue();
1499 }
1500
1501 virtual OptionalParseResult
1502 parseOptionalAssignmentList(SmallVectorImpl<Argument> &lhs,
1503 SmallVectorImpl<UnresolvedOperand> &rhs) = 0;
1504 };
1505
1506 //===--------------------------------------------------------------------===//
1507 // Dialect OpAsm interface.
1508 //===--------------------------------------------------------------------===//
1509
1510 /// A functor used to set the name of the start of a result group of an
1511 /// operation. See 'getAsmResultNames' below for more details.
1512 using OpAsmSetValueNameFn = function_ref<void(Value, StringRef)>;
1513
1514 /// A functor used to set the name of blocks in regions directly nested under
1515 /// an operation.
1516 using OpAsmSetBlockNameFn = function_ref<void(Block *, StringRef)>;
1517
1518 class OpAsmDialectInterface
1519 : public DialectInterface::Base<OpAsmDialectInterface> {
1520 public:
OpAsmDialectInterface(Dialect * dialect)1521 OpAsmDialectInterface(Dialect *dialect) : Base(dialect) {}
1522
1523 //===------------------------------------------------------------------===//
1524 // Aliases
1525 //===------------------------------------------------------------------===//
1526
1527 /// Holds the result of `getAlias` hook call.
1528 enum class AliasResult {
1529 /// The object (type or attribute) is not supported by the hook
1530 /// and an alias was not provided.
1531 NoAlias,
1532 /// An alias was provided, but it might be overriden by other hook.
1533 OverridableAlias,
1534 /// An alias was provided and it should be used
1535 /// (no other hooks will be checked).
1536 FinalAlias
1537 };
1538
1539 /// Hooks for getting an alias identifier alias for a given symbol, that is
1540 /// not necessarily a part of this dialect. The identifier is used in place of
1541 /// the symbol when printing textual IR. These aliases must not contain `.` or
1542 /// end with a numeric digit([0-9]+).
getAlias(Attribute attr,raw_ostream & os)1543 virtual AliasResult getAlias(Attribute attr, raw_ostream &os) const {
1544 return AliasResult::NoAlias;
1545 }
getAlias(Type type,raw_ostream & os)1546 virtual AliasResult getAlias(Type type, raw_ostream &os) const {
1547 return AliasResult::NoAlias;
1548 }
1549
1550 //===--------------------------------------------------------------------===//
1551 // Resources
1552 //===--------------------------------------------------------------------===//
1553
1554 /// Declare a resource with the given key, returning a handle to use for any
1555 /// references of this resource key within the IR during parsing. The result
1556 /// of `getResourceKey` on the returned handle is permitted to be different
1557 /// than `key`.
1558 virtual FailureOr<AsmDialectResourceHandle>
declareResource(StringRef key)1559 declareResource(StringRef key) const {
1560 return failure();
1561 }
1562
1563 /// Return a key to use for the given resource. This key should uniquely
1564 /// identify this resource within the dialect.
1565 virtual std::string
getResourceKey(const AsmDialectResourceHandle & handle)1566 getResourceKey(const AsmDialectResourceHandle &handle) const {
1567 llvm_unreachable(
1568 "Dialect must implement `getResourceKey` when defining resources");
1569 }
1570
1571 /// Hook for parsing resource entries. Returns failure if the entry was not
1572 /// valid, or could otherwise not be processed correctly. Any necessary errors
1573 /// can be emitted via the provided entry.
1574 virtual LogicalResult parseResource(AsmParsedResourceEntry &entry) const;
1575
1576 /// Hook for building resources to use during printing. The given `op` may be
1577 /// inspected to help determine what information to include.
1578 /// `referencedResources` contains all of the resources detected when printing
1579 /// 'op'.
1580 virtual void
buildResources(Operation * op,const SetVector<AsmDialectResourceHandle> & referencedResources,AsmResourceBuilder & builder)1581 buildResources(Operation *op,
1582 const SetVector<AsmDialectResourceHandle> &referencedResources,
1583 AsmResourceBuilder &builder) const {}
1584 };
1585 } // namespace mlir
1586
1587 //===--------------------------------------------------------------------===//
1588 // Operation OpAsm interface.
1589 //===--------------------------------------------------------------------===//
1590
1591 /// The OpAsmOpInterface, see OpAsmInterface.td for more details.
1592 #include "mlir/IR/OpAsmInterface.h.inc"
1593
1594 namespace llvm {
1595 template <>
1596 struct DenseMapInfo<mlir::AsmDialectResourceHandle> {
1597 static inline mlir::AsmDialectResourceHandle getEmptyKey() {
1598 return {DenseMapInfo<void *>::getEmptyKey(),
1599 DenseMapInfo<mlir::TypeID>::getEmptyKey(), nullptr};
1600 }
1601 static inline mlir::AsmDialectResourceHandle getTombstoneKey() {
1602 return {DenseMapInfo<void *>::getTombstoneKey(),
1603 DenseMapInfo<mlir::TypeID>::getTombstoneKey(), nullptr};
1604 }
1605 static unsigned getHashValue(const mlir::AsmDialectResourceHandle &handle) {
1606 return DenseMapInfo<void *>::getHashValue(handle.getResource());
1607 }
1608 static bool isEqual(const mlir::AsmDialectResourceHandle &lhs,
1609 const mlir::AsmDialectResourceHandle &rhs) {
1610 return lhs.getResource() == rhs.getResource();
1611 }
1612 };
1613 } // namespace llvm
1614
1615 #endif
1616