1 //===- Lexer.cpp - MLIR Lexer Implementation ------------------------------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 //
9 // This file implements the lexer for the MLIR textual form.
10 //
11 //===----------------------------------------------------------------------===//
12 
13 #include "Lexer.h"
14 #include "mlir/AsmParser/CodeComplete.h"
15 #include "mlir/IR/Diagnostics.h"
16 #include "mlir/IR/Location.h"
17 #include "mlir/IR/MLIRContext.h"
18 #include "llvm/ADT/StringExtras.h"
19 #include "llvm/ADT/StringSwitch.h"
20 #include "llvm/Support/SourceMgr.h"
21 
22 using namespace mlir;
23 
// Tells whether 'c' is one of the punctuation characters that may appear in
// an identifier suffix, i.e. one of: $ . _ -
static bool isPunct(char c) {
  switch (c) {
  case '$':
  case '.':
  case '_':
  case '-':
    return true;
  default:
    return false;
  }
}
29 
Lexer(const llvm::SourceMgr & sourceMgr,MLIRContext * context,AsmParserCodeCompleteContext * codeCompleteContext)30 Lexer::Lexer(const llvm::SourceMgr &sourceMgr, MLIRContext *context,
31              AsmParserCodeCompleteContext *codeCompleteContext)
32     : sourceMgr(sourceMgr), context(context), codeCompleteLoc(nullptr) {
33   auto bufferID = sourceMgr.getMainFileID();
34   curBuffer = sourceMgr.getMemoryBuffer(bufferID)->getBuffer();
35   curPtr = curBuffer.begin();
36 
37   // Set the code completion location if it was provided.
38   if (codeCompleteContext)
39     codeCompleteLoc = codeCompleteContext->getCodeCompleteLoc().getPointer();
40 }
41 
42 /// Encode the specified source location information into an attribute for
43 /// attachment to the IR.
getEncodedSourceLocation(SMLoc loc)44 Location Lexer::getEncodedSourceLocation(SMLoc loc) {
45   auto &sourceMgr = getSourceMgr();
46   unsigned mainFileID = sourceMgr.getMainFileID();
47 
48   // TODO: Fix performance issues in SourceMgr::getLineAndColumn so that we can
49   //       use it here.
50   auto &bufferInfo = sourceMgr.getBufferInfo(mainFileID);
51   unsigned lineNo = bufferInfo.getLineNumber(loc.getPointer());
52   unsigned column =
53       (loc.getPointer() - bufferInfo.getPointerForLineNumber(lineNo)) + 1;
54   auto *buffer = sourceMgr.getMemoryBuffer(mainFileID);
55 
56   return FileLineColLoc::get(context, buffer->getBufferIdentifier(), lineNo,
57                              column);
58 }
59 
60 /// emitError - Emit an error message and return an Token::error token.
emitError(const char * loc,const Twine & message)61 Token Lexer::emitError(const char *loc, const Twine &message) {
62   mlir::emitError(getEncodedSourceLocation(SMLoc::getFromPointer(loc)),
63                   message);
64   return formToken(Token::error, loc);
65 }
66 
lexToken()67 Token Lexer::lexToken() {
68   while (true) {
69     const char *tokStart = curPtr;
70 
71     // Check to see if the current token is at the code completion location.
72     if (tokStart == codeCompleteLoc)
73       return formToken(Token::code_complete, tokStart);
74 
75     // Lex the next token.
76     switch (*curPtr++) {
77     default:
78       // Handle bare identifiers.
79       if (isalpha(curPtr[-1]))
80         return lexBareIdentifierOrKeyword(tokStart);
81 
82       // Unknown character, emit an error.
83       return emitError(tokStart, "unexpected character");
84 
85     case ' ':
86     case '\t':
87     case '\n':
88     case '\r':
89       // Handle whitespace.
90       continue;
91 
92     case '_':
93       // Handle bare identifiers.
94       return lexBareIdentifierOrKeyword(tokStart);
95 
96     case 0:
97       // This may either be a nul character in the source file or may be the EOF
98       // marker that llvm::MemoryBuffer guarantees will be there.
99       if (curPtr - 1 == curBuffer.end())
100         return formToken(Token::eof, tokStart);
101       continue;
102 
103     case ':':
104       return formToken(Token::colon, tokStart);
105     case ',':
106       return formToken(Token::comma, tokStart);
107     case '.':
108       return lexEllipsis(tokStart);
109     case '(':
110       return formToken(Token::l_paren, tokStart);
111     case ')':
112       return formToken(Token::r_paren, tokStart);
113     case '{':
114       if (*curPtr == '-' && *(curPtr + 1) == '#') {
115         curPtr += 2;
116         return formToken(Token::file_metadata_begin, tokStart);
117       }
118       return formToken(Token::l_brace, tokStart);
119     case '}':
120       return formToken(Token::r_brace, tokStart);
121     case '[':
122       return formToken(Token::l_square, tokStart);
123     case ']':
124       return formToken(Token::r_square, tokStart);
125     case '<':
126       return formToken(Token::less, tokStart);
127     case '>':
128       return formToken(Token::greater, tokStart);
129     case '=':
130       return formToken(Token::equal, tokStart);
131 
132     case '+':
133       return formToken(Token::plus, tokStart);
134     case '*':
135       return formToken(Token::star, tokStart);
136     case '-':
137       if (*curPtr == '>') {
138         ++curPtr;
139         return formToken(Token::arrow, tokStart);
140       }
141       return formToken(Token::minus, tokStart);
142 
143     case '?':
144       return formToken(Token::question, tokStart);
145 
146     case '|':
147       return formToken(Token::vertical_bar, tokStart);
148 
149     case '/':
150       if (*curPtr == '/') {
151         skipComment();
152         continue;
153       }
154       return emitError(tokStart, "unexpected character");
155 
156     case '@':
157       return lexAtIdentifier(tokStart);
158 
159     case '#':
160       if (*curPtr == '-' && *(curPtr + 1) == '}') {
161         curPtr += 2;
162         return formToken(Token::file_metadata_end, tokStart);
163       }
164       LLVM_FALLTHROUGH;
165     case '!':
166     case '^':
167     case '%':
168       return lexPrefixedIdentifier(tokStart);
169     case '"':
170       return lexString(tokStart);
171 
172     case '0':
173     case '1':
174     case '2':
175     case '3':
176     case '4':
177     case '5':
178     case '6':
179     case '7':
180     case '8':
181     case '9':
182       return lexNumber(tokStart);
183     }
184   }
185 }
186 
187 /// Lex an '@foo' identifier.
188 ///
189 ///   symbol-ref-id ::= `@` (bare-id | string-literal)
190 ///
lexAtIdentifier(const char * tokStart)191 Token Lexer::lexAtIdentifier(const char *tokStart) {
192   char cur = *curPtr++;
193 
194   // Try to parse a string literal, if present.
195   if (cur == '"') {
196     Token stringIdentifier = lexString(curPtr);
197     if (stringIdentifier.is(Token::error))
198       return stringIdentifier;
199     return formToken(Token::at_identifier, tokStart);
200   }
201 
202   // Otherwise, these always start with a letter or underscore.
203   if (!isalpha(cur) && cur != '_')
204     return emitError(curPtr - 1,
205                      "@ identifier expected to start with letter or '_'");
206 
207   while (isalpha(*curPtr) || isdigit(*curPtr) || *curPtr == '_' ||
208          *curPtr == '$' || *curPtr == '.')
209     ++curPtr;
210   return formToken(Token::at_identifier, tokStart);
211 }
212 
213 /// Lex a bare identifier or keyword that starts with a letter.
214 ///
215 ///   bare-id ::= (letter|[_]) (letter|digit|[_$.])*
216 ///   integer-type ::= `[su]?i[1-9][0-9]*`
217 ///
lexBareIdentifierOrKeyword(const char * tokStart)218 Token Lexer::lexBareIdentifierOrKeyword(const char *tokStart) {
219   // Match the rest of the identifier regex: [0-9a-zA-Z_.$]*
220   while (isalpha(*curPtr) || isdigit(*curPtr) || *curPtr == '_' ||
221          *curPtr == '$' || *curPtr == '.')
222     ++curPtr;
223 
224   // Check to see if this identifier is a keyword.
225   StringRef spelling(tokStart, curPtr - tokStart);
226 
227   auto isAllDigit = [](StringRef str) {
228     return llvm::all_of(str, llvm::isDigit);
229   };
230 
231   // Check for i123, si456, ui789.
232   if ((spelling.size() > 1 && tokStart[0] == 'i' &&
233        isAllDigit(spelling.drop_front())) ||
234       ((spelling.size() > 2 && tokStart[1] == 'i' &&
235         (tokStart[0] == 's' || tokStart[0] == 'u')) &&
236        isAllDigit(spelling.drop_front(2))))
237     return Token(Token::inttype, spelling);
238 
239   Token::Kind kind = StringSwitch<Token::Kind>(spelling)
240 #define TOK_KEYWORD(SPELLING) .Case(#SPELLING, Token::kw_##SPELLING)
241 #include "TokenKinds.def"
242                          .Default(Token::bare_identifier);
243 
244   return Token(kind, spelling);
245 }
246 
247 /// Skip a comment line, starting with a '//'.
248 ///
249 ///   TODO: add a regex for comments here and to the spec.
250 ///
skipComment()251 void Lexer::skipComment() {
252   // Advance over the second '/' in a '//' comment.
253   assert(*curPtr == '/');
254   ++curPtr;
255 
256   while (true) {
257     switch (*curPtr++) {
258     case '\n':
259     case '\r':
260       // Newline is end of comment.
261       return;
262     case 0:
263       // If this is the end of the buffer, end the comment.
264       if (curPtr - 1 == curBuffer.end()) {
265         --curPtr;
266         return;
267       }
268       LLVM_FALLTHROUGH;
269     default:
270       // Skip over other characters.
271       break;
272     }
273   }
274 }
275 
276 /// Lex an ellipsis.
277 ///
278 ///   ellipsis ::= '...'
279 ///
lexEllipsis(const char * tokStart)280 Token Lexer::lexEllipsis(const char *tokStart) {
281   assert(curPtr[-1] == '.');
282 
283   if (curPtr == curBuffer.end() || *curPtr != '.' || *(curPtr + 1) != '.')
284     return emitError(curPtr, "expected three consecutive dots for an ellipsis");
285 
286   curPtr += 2;
287   return formToken(Token::ellipsis, tokStart);
288 }
289 
290 /// Lex a number literal.
291 ///
292 ///   integer-literal ::= digit+ | `0x` hex_digit+
293 ///   float-literal ::= [-+]?[0-9]+[.][0-9]*([eE][-+]?[0-9]+)?
294 ///
lexNumber(const char * tokStart)295 Token Lexer::lexNumber(const char *tokStart) {
296   assert(isdigit(curPtr[-1]));
297 
298   // Handle the hexadecimal case.
299   if (curPtr[-1] == '0' && *curPtr == 'x') {
300     // If we see stuff like 0xi32, this is a literal `0` followed by an
301     // identifier `xi32`, stop after `0`.
302     if (!isxdigit(curPtr[1]))
303       return formToken(Token::integer, tokStart);
304 
305     curPtr += 2;
306     while (isxdigit(*curPtr))
307       ++curPtr;
308 
309     return formToken(Token::integer, tokStart);
310   }
311 
312   // Handle the normal decimal case.
313   while (isdigit(*curPtr))
314     ++curPtr;
315 
316   if (*curPtr != '.')
317     return formToken(Token::integer, tokStart);
318   ++curPtr;
319 
320   // Skip over [0-9]*([eE][-+]?[0-9]+)?
321   while (isdigit(*curPtr))
322     ++curPtr;
323 
324   if (*curPtr == 'e' || *curPtr == 'E') {
325     if (isdigit(static_cast<unsigned char>(curPtr[1])) ||
326         ((curPtr[1] == '-' || curPtr[1] == '+') &&
327          isdigit(static_cast<unsigned char>(curPtr[2])))) {
328       curPtr += 2;
329       while (isdigit(*curPtr))
330         ++curPtr;
331     }
332   }
333   return formToken(Token::floatliteral, tokStart);
334 }
335 
336 /// Lex an identifier that starts with a prefix followed by suffix-id.
337 ///
338 ///   attribute-id  ::= `#` suffix-id
339 ///   ssa-id        ::= '%' suffix-id
340 ///   block-id      ::= '^' suffix-id
341 ///   type-id       ::= '!' suffix-id
342 ///   suffix-id     ::= digit+ | (letter|id-punct) (letter|id-punct|digit)*
343 ///   id-punct      ::= `$` | `.` | `_` | `-`
344 ///
lexPrefixedIdentifier(const char * tokStart)345 Token Lexer::lexPrefixedIdentifier(const char *tokStart) {
346   Token::Kind kind;
347   StringRef errorKind;
348   switch (*tokStart) {
349   case '#':
350     kind = Token::hash_identifier;
351     errorKind = "invalid attribute name";
352     break;
353   case '%':
354     kind = Token::percent_identifier;
355     errorKind = "invalid SSA name";
356     break;
357   case '^':
358     kind = Token::caret_identifier;
359     errorKind = "invalid block name";
360     break;
361   case '!':
362     kind = Token::exclamation_identifier;
363     errorKind = "invalid type identifier";
364     break;
365   default:
366     llvm_unreachable("invalid caller");
367   }
368 
369   // Parse suffix-id.
370   if (isdigit(*curPtr)) {
371     // If suffix-id starts with a digit, the rest must be digits.
372     while (isdigit(*curPtr))
373       ++curPtr;
374   } else if (isalpha(*curPtr) || isPunct(*curPtr)) {
375     do {
376       ++curPtr;
377     } while (isalpha(*curPtr) || isdigit(*curPtr) || isPunct(*curPtr));
378   } else if (curPtr == codeCompleteLoc) {
379     return formToken(Token::code_complete, tokStart);
380   } else {
381     return emitError(curPtr - 1, errorKind);
382   }
383 
384   // Check for a code completion within the identifier.
385   if (codeCompleteLoc && codeCompleteLoc >= tokStart &&
386       codeCompleteLoc <= curPtr) {
387     return Token(Token::code_complete,
388                  StringRef(tokStart, codeCompleteLoc - tokStart));
389   }
390 
391   return formToken(kind, tokStart);
392 }
393 
394 /// Lex a string literal.
395 ///
396 ///   string-literal ::= '"' [^"\n\f\v\r]* '"'
397 ///
398 /// TODO: define escaping rules.
lexString(const char * tokStart)399 Token Lexer::lexString(const char *tokStart) {
400   assert(curPtr[-1] == '"');
401 
402   while (true) {
403     // Check to see if there is a code completion location within the string. In
404     // these cases we generate a completion location and place the currently
405     // lexed string within the token. This allows for the parser to use the
406     // partially lexed string when computing the completion results.
407     if (curPtr == codeCompleteLoc)
408       return formToken(Token::code_complete, tokStart);
409 
410     switch (*curPtr++) {
411     case '"':
412       return formToken(Token::string, tokStart);
413     case 0:
414       // If this is a random nul character in the middle of a string, just
415       // include it.  If it is the end of file, then it is an error.
416       if (curPtr - 1 != curBuffer.end())
417         continue;
418       LLVM_FALLTHROUGH;
419     case '\n':
420     case '\v':
421     case '\f':
422       return emitError(curPtr - 1, "expected '\"' in string literal");
423     case '\\':
424       // Handle explicitly a few escapes.
425       if (*curPtr == '"' || *curPtr == '\\' || *curPtr == 'n' || *curPtr == 't')
426         ++curPtr;
427       else if (llvm::isHexDigit(*curPtr) && llvm::isHexDigit(curPtr[1]))
428         // Support \xx for two hex digits.
429         curPtr += 2;
430       else
431         return emitError(curPtr - 1, "unknown escape in string literal");
432       continue;
433 
434     default:
435       continue;
436     }
437   }
438 }
439