1 //===- DialectSymbolParser.cpp - MLIR Dialect Symbol Parser  --------------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 //
9 // This file implements the parser for the dialect symbols, such as extended
10 // attributes and types.
11 //
12 //===----------------------------------------------------------------------===//
13 
14 #include "AsmParserImpl.h"
15 #include "mlir/IR/BuiltinTypes.h"
16 #include "mlir/IR/Dialect.h"
17 #include "mlir/IR/DialectImplementation.h"
18 #include "llvm/Support/SourceMgr.h"
19 
20 using namespace mlir;
21 using namespace mlir::detail;
22 using llvm::MemoryBuffer;
23 using llvm::SourceMgr;
24 
25 namespace {
26 /// This class provides the main implementation of the DialectAsmParser that
27 /// allows for dialects to parse attributes and types. This allows for dialect
28 /// hooking into the main MLIR parsing logic.
29 class CustomDialectAsmParser : public AsmParserImpl<DialectAsmParser> {
30 public:
CustomDialectAsmParser(StringRef fullSpec,Parser & parser)31   CustomDialectAsmParser(StringRef fullSpec, Parser &parser)
32       : AsmParserImpl<DialectAsmParser>(parser.getToken().getLoc(), parser),
33         fullSpec(fullSpec) {}
34   ~CustomDialectAsmParser() override = default;
35 
36   /// Returns the full specification of the symbol being parsed. This allows
37   /// for using a separate parser if necessary.
getFullSymbolSpec() const38   StringRef getFullSymbolSpec() const override { return fullSpec; }
39 
40 private:
41   /// The full symbol specification.
42   StringRef fullSpec;
43 };
44 } // namespace
45 
46 ///
47 ///   pretty-dialect-sym-body ::= '<' pretty-dialect-sym-contents+ '>'
48 ///   pretty-dialect-sym-contents ::= pretty-dialect-sym-body
49 ///                                  | '(' pretty-dialect-sym-contents+ ')'
50 ///                                  | '[' pretty-dialect-sym-contents+ ']'
51 ///                                  | '{' pretty-dialect-sym-contents+ '}'
52 ///                                  | '[^[<({>\])}\0]+'
53 ///
parseDialectSymbolBody(StringRef & body,bool & isCodeCompletion)54 ParseResult Parser::parseDialectSymbolBody(StringRef &body,
55                                            bool &isCodeCompletion) {
56   // Symbol bodies are a relatively unstructured format that contains a series
57   // of properly nested punctuation, with anything else in the middle. Scan
58   // ahead to find it and consume it if successful, otherwise emit an error.
59   const char *curPtr = getTokenSpelling().data();
60 
61   // Scan over the nested punctuation, bailing out on error and consuming until
62   // we find the end. We know that we're currently looking at the '<', so we can
63   // go until we find the matching '>' character.
64   assert(*curPtr == '<');
65   SmallVector<char, 8> nestedPunctuation;
66   const char *codeCompleteLoc = state.lex.getCodeCompleteLoc();
67   do {
68     // Handle code completions, which may appear in the middle of the symbol
69     // body.
70     if (curPtr == codeCompleteLoc) {
71       isCodeCompletion = true;
72       nestedPunctuation.clear();
73       break;
74     }
75 
76     char c = *curPtr++;
77     switch (c) {
78     case '\0':
79       // This also handles the EOF case.
80       if (!nestedPunctuation.empty()) {
81         return emitError() << "unbalanced '" << nestedPunctuation.back()
82                            << "' character in pretty dialect name";
83       }
84       return emitError("unexpected nul or EOF in pretty dialect name");
85     case '<':
86     case '[':
87     case '(':
88     case '{':
89       nestedPunctuation.push_back(c);
90       continue;
91 
92     case '-':
93       // The sequence `->` is treated as special token.
94       if (*curPtr == '>')
95         ++curPtr;
96       continue;
97 
98     case '>':
99       if (nestedPunctuation.pop_back_val() != '<')
100         return emitError("unbalanced '>' character in pretty dialect name");
101       break;
102     case ']':
103       if (nestedPunctuation.pop_back_val() != '[')
104         return emitError("unbalanced ']' character in pretty dialect name");
105       break;
106     case ')':
107       if (nestedPunctuation.pop_back_val() != '(')
108         return emitError("unbalanced ')' character in pretty dialect name");
109       break;
110     case '}':
111       if (nestedPunctuation.pop_back_val() != '{')
112         return emitError("unbalanced '}' character in pretty dialect name");
113       break;
114     case '"': {
115       // Dispatch to the lexer to lex past strings.
116       resetToken(curPtr - 1);
117       curPtr = state.curToken.getEndLoc().getPointer();
118 
119       // Handle code completions, which may appear in the middle of the symbol
120       // body.
121       if (state.curToken.isCodeCompletion()) {
122         isCodeCompletion = true;
123         nestedPunctuation.clear();
124         break;
125       }
126 
127       // Otherwise, ensure this token was actually a string.
128       if (state.curToken.isNot(Token::string))
129         return failure();
130       break;
131     }
132 
133     default:
134       continue;
135     }
136   } while (!nestedPunctuation.empty());
137 
138   // Ok, we succeeded, remember where we stopped, reset the lexer to know it is
139   // consuming all this stuff, and return.
140   resetToken(curPtr);
141 
142   unsigned length = curPtr - body.begin();
143   body = StringRef(body.data(), length);
144   return success();
145 }
146 
147 /// Parse an extended dialect symbol.
148 template <typename Symbol, typename SymbolAliasMap, typename CreateFn>
parseExtendedSymbol(Parser & p,SymbolAliasMap & aliases,CreateFn && createSymbol)149 static Symbol parseExtendedSymbol(Parser &p, SymbolAliasMap &aliases,
150                                   CreateFn &&createSymbol) {
151   Token tok = p.getToken();
152 
153   // Handle code completion of the extended symbol.
154   StringRef identifier = tok.getSpelling().drop_front();
155   if (tok.isCodeCompletion() && identifier.empty())
156     return p.codeCompleteDialectSymbol(aliases);
157 
158   // Parse the dialect namespace.
159   SMLoc loc = p.getToken().getLoc();
160   p.consumeToken();
161 
162   // Check to see if this is a pretty name.
163   StringRef dialectName;
164   StringRef symbolData;
165   std::tie(dialectName, symbolData) = identifier.split('.');
166   bool isPrettyName = !symbolData.empty() || identifier.back() == '.';
167 
168   // Check to see if the symbol has trailing data, i.e. has an immediately
169   // following '<'.
170   bool hasTrailingData =
171       p.getToken().is(Token::less) &&
172       identifier.bytes_end() == p.getTokenSpelling().bytes_begin();
173 
174   // If there is no '<' token following this, and if the typename contains no
175   // dot, then we are parsing a symbol alias.
176   if (!hasTrailingData && !isPrettyName) {
177     // Check for an alias for this type.
178     auto aliasIt = aliases.find(identifier);
179     if (aliasIt == aliases.end())
180       return (p.emitWrongTokenError("undefined symbol alias id '" + identifier +
181                                     "'"),
182               nullptr);
183     return aliasIt->second;
184   }
185 
186   // If this isn't an alias, we are parsing a dialect-specific symbol. If the
187   // name contains a dot, then this is the "pretty" form. If not, it is the
188   // verbose form that looks like <...>.
189   if (!isPrettyName) {
190     // Point the symbol data to the end of the dialect name to start.
191     symbolData = StringRef(dialectName.end(), 0);
192 
193     // Parse the body of the symbol.
194     bool isCodeCompletion = false;
195     if (p.parseDialectSymbolBody(symbolData, isCodeCompletion))
196       return nullptr;
197     symbolData = symbolData.drop_front();
198 
199     // If the body contained a code completion it won't have the trailing `>`
200     // token, so don't drop it.
201     if (!isCodeCompletion)
202       symbolData = symbolData.drop_back();
203   } else {
204     loc = SMLoc::getFromPointer(symbolData.data());
205 
206     // If the dialect's symbol is followed immediately by a <, then lex the body
207     // of it into prettyName.
208     if (hasTrailingData && p.parseDialectSymbolBody(symbolData))
209       return nullptr;
210   }
211 
212   return createSymbol(dialectName, symbolData, loc);
213 }
214 
215 /// Parse an extended attribute.
216 ///
217 ///   extended-attribute ::= (dialect-attribute | attribute-alias)
218 ///   dialect-attribute  ::= `#` dialect-namespace `<` `"` attr-data `"` `>`
219 ///   dialect-attribute  ::= `#` alias-name pretty-dialect-sym-body?
220 ///   attribute-alias    ::= `#` alias-name
221 ///
parseExtendedAttr(Type type)222 Attribute Parser::parseExtendedAttr(Type type) {
223   MLIRContext *ctx = getContext();
224   Attribute attr = parseExtendedSymbol<Attribute>(
225       *this, state.symbols.attributeAliasDefinitions,
226       [&](StringRef dialectName, StringRef symbolData, SMLoc loc) -> Attribute {
227         // Parse an optional trailing colon type.
228         Type attrType = type;
229         if (consumeIf(Token::colon) && !(attrType = parseType()))
230           return Attribute();
231 
232         // If we found a registered dialect, then ask it to parse the attribute.
233         if (Dialect *dialect =
234                 builder.getContext()->getOrLoadDialect(dialectName)) {
235           // Temporarily reset the lexer to let the dialect parse the attribute.
236           const char *curLexerPos = getToken().getLoc().getPointer();
237           resetToken(symbolData.data());
238 
239           // Parse the attribute.
240           CustomDialectAsmParser customParser(symbolData, *this);
241           Attribute attr = dialect->parseAttribute(customParser, attrType);
242           resetToken(curLexerPos);
243           return attr;
244         }
245 
246         // Otherwise, form a new opaque attribute.
247         return OpaqueAttr::getChecked(
248             [&] { return emitError(loc); }, StringAttr::get(ctx, dialectName),
249             symbolData, attrType ? attrType : NoneType::get(ctx));
250       });
251 
252   // Ensure that the attribute has the same type as requested.
253   if (attr && type && attr.getType() != type) {
254     emitError("attribute type different than expected: expected ")
255         << type << ", but got " << attr.getType();
256     return nullptr;
257   }
258   return attr;
259 }
260 
261 /// Parse an extended type.
262 ///
263 ///   extended-type ::= (dialect-type | type-alias)
264 ///   dialect-type  ::= `!` dialect-namespace `<` `"` type-data `"` `>`
265 ///   dialect-type  ::= `!` alias-name pretty-dialect-attribute-body?
266 ///   type-alias    ::= `!` alias-name
267 ///
parseExtendedType()268 Type Parser::parseExtendedType() {
269   MLIRContext *ctx = getContext();
270   return parseExtendedSymbol<Type>(
271       *this, state.symbols.typeAliasDefinitions,
272       [&](StringRef dialectName, StringRef symbolData, SMLoc loc) -> Type {
273         // If we found a registered dialect, then ask it to parse the type.
274         if (auto *dialect = ctx->getOrLoadDialect(dialectName)) {
275           // Temporarily reset the lexer to let the dialect parse the type.
276           const char *curLexerPos = getToken().getLoc().getPointer();
277           resetToken(symbolData.data());
278 
279           // Parse the type.
280           CustomDialectAsmParser customParser(symbolData, *this);
281           Type type = dialect->parseType(customParser);
282           resetToken(curLexerPos);
283           return type;
284         }
285 
286         // Otherwise, form a new opaque type.
287         return OpaqueType::getChecked([&] { return emitError(loc); },
288                                       StringAttr::get(ctx, dialectName),
289                                       symbolData);
290       });
291 }
292 
293 //===----------------------------------------------------------------------===//
294 // mlir::parseAttribute/parseType
295 //===----------------------------------------------------------------------===//
296 
297 /// Parses a symbol, of type 'T', and returns it if parsing was successful. If
298 /// parsing failed, nullptr is returned. The number of bytes read from the input
299 /// string is returned in 'numRead'.
300 template <typename T, typename ParserFn>
parseSymbol(StringRef inputStr,MLIRContext * context,size_t & numRead,ParserFn && parserFn)301 static T parseSymbol(StringRef inputStr, MLIRContext *context, size_t &numRead,
302                      ParserFn &&parserFn) {
303   SourceMgr sourceMgr;
304   auto memBuffer = MemoryBuffer::getMemBuffer(
305       inputStr, /*BufferName=*/"<mlir_parser_buffer>",
306       /*RequiresNullTerminator=*/false);
307   sourceMgr.AddNewSourceBuffer(std::move(memBuffer), SMLoc());
308   SymbolState aliasState;
309   ParserConfig config(context);
310   ParserState state(sourceMgr, config, aliasState, /*asmState=*/nullptr,
311                     /*codeCompleteContext=*/nullptr);
312   Parser parser(state);
313 
314   SourceMgrDiagnosticHandler handler(
315       const_cast<llvm::SourceMgr &>(parser.getSourceMgr()),
316       parser.getContext());
317   Token startTok = parser.getToken();
318   T symbol = parserFn(parser);
319   if (!symbol)
320     return T();
321 
322   // Provide the number of bytes that were read.
323   Token endTok = parser.getToken();
324   numRead = static_cast<size_t>(endTok.getLoc().getPointer() -
325                                 startTok.getLoc().getPointer());
326   return symbol;
327 }
328 
parseAttribute(StringRef attrStr,MLIRContext * context)329 Attribute mlir::parseAttribute(StringRef attrStr, MLIRContext *context) {
330   size_t numRead = 0;
331   return parseAttribute(attrStr, context, numRead);
332 }
parseAttribute(StringRef attrStr,Type type)333 Attribute mlir::parseAttribute(StringRef attrStr, Type type) {
334   size_t numRead = 0;
335   return parseAttribute(attrStr, type, numRead);
336 }
337 
parseAttribute(StringRef attrStr,MLIRContext * context,size_t & numRead)338 Attribute mlir::parseAttribute(StringRef attrStr, MLIRContext *context,
339                                size_t &numRead) {
340   return parseSymbol<Attribute>(attrStr, context, numRead, [](Parser &parser) {
341     return parser.parseAttribute();
342   });
343 }
parseAttribute(StringRef attrStr,Type type,size_t & numRead)344 Attribute mlir::parseAttribute(StringRef attrStr, Type type, size_t &numRead) {
345   return parseSymbol<Attribute>(
346       attrStr, type.getContext(), numRead,
347       [type](Parser &parser) { return parser.parseAttribute(type); });
348 }
349 
parseType(StringRef typeStr,MLIRContext * context)350 Type mlir::parseType(StringRef typeStr, MLIRContext *context) {
351   size_t numRead = 0;
352   return parseType(typeStr, context, numRead);
353 }
354 
parseType(StringRef typeStr,MLIRContext * context,size_t & numRead)355 Type mlir::parseType(StringRef typeStr, MLIRContext *context, size_t &numRead) {
356   return parseSymbol<Type>(typeStr, context, numRead,
357                            [](Parser &parser) { return parser.parseType(); });
358 }
359