1 //===- Lexer.cpp ----------------------------------------------------------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 
9 #include "Lexer.h"
10 #include "mlir/Support/LogicalResult.h"
11 #include "mlir/Tools/PDLL/AST/Diagnostic.h"
12 #include "mlir/Tools/PDLL/Parser/CodeComplete.h"
13 #include "llvm/ADT/StringExtras.h"
14 #include "llvm/ADT/StringSwitch.h"
15 #include "llvm/Support/SourceMgr.h"
16 
17 using namespace mlir;
18 using namespace mlir::pdll;
19 
20 //===----------------------------------------------------------------------===//
21 // Token
22 //===----------------------------------------------------------------------===//
23 
getStringValue() const24 std::string Token::getStringValue() const {
25   assert(getKind() == string || getKind() == string_block ||
26          getKind() == code_complete_string);
27 
28   // Start by dropping the quotes.
29   StringRef bytes = getSpelling();
30   if (is(string))
31     bytes = bytes.drop_front().drop_back();
32   else if (is(string_block))
33     bytes = bytes.drop_front(2).drop_back(2);
34 
35   std::string result;
36   result.reserve(bytes.size());
37   for (unsigned i = 0, e = bytes.size(); i != e;) {
38     auto c = bytes[i++];
39     if (c != '\\') {
40       result.push_back(c);
41       continue;
42     }
43 
44     assert(i + 1 <= e && "invalid string should be caught by lexer");
45     auto c1 = bytes[i++];
46     switch (c1) {
47     case '"':
48     case '\\':
49       result.push_back(c1);
50       continue;
51     case 'n':
52       result.push_back('\n');
53       continue;
54     case 't':
55       result.push_back('\t');
56       continue;
57     default:
58       break;
59     }
60 
61     assert(i + 1 <= e && "invalid string should be caught by lexer");
62     auto c2 = bytes[i++];
63 
64     assert(llvm::isHexDigit(c1) && llvm::isHexDigit(c2) && "invalid escape");
65     result.push_back((llvm::hexDigitValue(c1) << 4) | llvm::hexDigitValue(c2));
66   }
67 
68   return result;
69 }
70 
71 //===----------------------------------------------------------------------===//
72 // Lexer
73 //===----------------------------------------------------------------------===//
74 
Lexer(llvm::SourceMgr & mgr,ast::DiagnosticEngine & diagEngine,CodeCompleteContext * codeCompleteContext)75 Lexer::Lexer(llvm::SourceMgr &mgr, ast::DiagnosticEngine &diagEngine,
76              CodeCompleteContext *codeCompleteContext)
77     : srcMgr(mgr), diagEngine(diagEngine), addedHandlerToDiagEngine(false),
78       codeCompletionLocation(nullptr) {
79   curBufferID = mgr.getMainFileID();
80   curBuffer = srcMgr.getMemoryBuffer(curBufferID)->getBuffer();
81   curPtr = curBuffer.begin();
82 
83   // Set the code completion location if necessary.
84   if (codeCompleteContext) {
85     codeCompletionLocation =
86         codeCompleteContext->getCodeCompleteLoc().getPointer();
87   }
88 
89   // If the diag engine has no handler, add a default that emits to the
90   // SourceMgr.
91   if (!diagEngine.getHandlerFn()) {
92     diagEngine.setHandlerFn([&](const ast::Diagnostic &diag) {
93       srcMgr.PrintMessage(diag.getLocation().Start, diag.getSeverity(),
94                           diag.getMessage());
95       for (const ast::Diagnostic &note : diag.getNotes())
96         srcMgr.PrintMessage(note.getLocation().Start, note.getSeverity(),
97                             note.getMessage());
98     });
99     addedHandlerToDiagEngine = true;
100   }
101 }
102 
~Lexer()103 Lexer::~Lexer() {
104   if (addedHandlerToDiagEngine)
105     diagEngine.setHandlerFn(nullptr);
106 }
107 
pushInclude(StringRef filename,SMRange includeLoc)108 LogicalResult Lexer::pushInclude(StringRef filename, SMRange includeLoc) {
109   std::string includedFile;
110   int bufferID =
111       srcMgr.AddIncludeFile(filename.str(), includeLoc.End, includedFile);
112   if (!bufferID)
113     return failure();
114 
115   curBufferID = bufferID;
116   curBuffer = srcMgr.getMemoryBuffer(curBufferID)->getBuffer();
117   curPtr = curBuffer.begin();
118   return success();
119 }
120 
emitError(SMRange loc,const Twine & msg)121 Token Lexer::emitError(SMRange loc, const Twine &msg) {
122   diagEngine.emitError(loc, msg);
123   return formToken(Token::error, loc.Start.getPointer());
124 }
emitErrorAndNote(SMRange loc,const Twine & msg,SMRange noteLoc,const Twine & note)125 Token Lexer::emitErrorAndNote(SMRange loc, const Twine &msg, SMRange noteLoc,
126                               const Twine &note) {
127   diagEngine.emitError(loc, msg)->attachNote(note, noteLoc);
128   return formToken(Token::error, loc.Start.getPointer());
129 }
emitError(const char * loc,const Twine & msg)130 Token Lexer::emitError(const char *loc, const Twine &msg) {
131   return emitError(
132       SMRange(SMLoc::getFromPointer(loc), SMLoc::getFromPointer(loc + 1)), msg);
133 }
134 
getNextChar()135 int Lexer::getNextChar() {
136   char curChar = *curPtr++;
137   switch (curChar) {
138   default:
139     return static_cast<unsigned char>(curChar);
140   case 0: {
141     // A nul character in the stream is either the end of the current buffer
142     // or a random nul in the file. Disambiguate that here.
143     if (curPtr - 1 != curBuffer.end())
144       return 0;
145 
146     // Otherwise, return end of file.
147     --curPtr;
148     return EOF;
149   }
150   case '\n':
151   case '\r':
152     // Handle the newline character by ignoring it and incrementing the line
153     // count. However, be careful about 'dos style' files with \n\r in them.
154     // Only treat a \n\r or \r\n as a single line.
155     if ((*curPtr == '\n' || (*curPtr == '\r')) && *curPtr != curChar)
156       ++curPtr;
157     return '\n';
158   }
159 }
160 
lexToken()161 Token Lexer::lexToken() {
162   while (true) {
163     const char *tokStart = curPtr;
164 
165     // Check to see if this token is at the code completion location.
166     if (tokStart == codeCompletionLocation)
167       return formToken(Token::code_complete, tokStart);
168 
169     // This always consumes at least one character.
170     int curChar = getNextChar();
171     switch (curChar) {
172     default:
173       // Handle identifiers: [a-zA-Z_]
174       if (isalpha(curChar) || curChar == '_')
175         return lexIdentifier(tokStart);
176 
177       // Unknown character, emit an error.
178       return emitError(tokStart, "unexpected character");
179     case EOF: {
180       // Return EOF denoting the end of lexing.
181       Token eof = formToken(Token::eof, tokStart);
182 
183       // Check to see if we are in an included file.
184       SMLoc parentIncludeLoc = srcMgr.getParentIncludeLoc(curBufferID);
185       if (parentIncludeLoc.isValid()) {
186         curBufferID = srcMgr.FindBufferContainingLoc(parentIncludeLoc);
187         curBuffer = srcMgr.getMemoryBuffer(curBufferID)->getBuffer();
188         curPtr = parentIncludeLoc.getPointer();
189       }
190 
191       return eof;
192     }
193 
194     // Lex punctuation.
195     case '-':
196       if (*curPtr == '>') {
197         ++curPtr;
198         return formToken(Token::arrow, tokStart);
199       }
200       return emitError(tokStart, "unexpected character");
201     case ':':
202       return formToken(Token::colon, tokStart);
203     case ',':
204       return formToken(Token::comma, tokStart);
205     case '.':
206       return formToken(Token::dot, tokStart);
207     case '=':
208       if (*curPtr == '>') {
209         ++curPtr;
210         return formToken(Token::equal_arrow, tokStart);
211       }
212       return formToken(Token::equal, tokStart);
213     case ';':
214       return formToken(Token::semicolon, tokStart);
215     case '[':
216       if (*curPtr == '{') {
217         ++curPtr;
218         return lexString(tokStart, /*isStringBlock=*/true);
219       }
220       return formToken(Token::l_square, tokStart);
221     case ']':
222       return formToken(Token::r_square, tokStart);
223 
224     case '<':
225       return formToken(Token::less, tokStart);
226     case '>':
227       return formToken(Token::greater, tokStart);
228     case '{':
229       return formToken(Token::l_brace, tokStart);
230     case '}':
231       return formToken(Token::r_brace, tokStart);
232     case '(':
233       return formToken(Token::l_paren, tokStart);
234     case ')':
235       return formToken(Token::r_paren, tokStart);
236     case '/':
237       if (*curPtr == '/') {
238         lexComment();
239         continue;
240       }
241       return emitError(tokStart, "unexpected character");
242 
243     // Ignore whitespace characters.
244     case 0:
245     case ' ':
246     case '\t':
247     case '\n':
248       return lexToken();
249 
250     case '#':
251       return lexDirective(tokStart);
252     case '"':
253       return lexString(tokStart, /*isStringBlock=*/false);
254 
255     case '0':
256     case '1':
257     case '2':
258     case '3':
259     case '4':
260     case '5':
261     case '6':
262     case '7':
263     case '8':
264     case '9':
265       return lexNumber(tokStart);
266     }
267   }
268 }
269 
270 /// Skip a comment line, starting with a '//'.
lexComment()271 void Lexer::lexComment() {
272   // Advance over the second '/' in a '//' comment.
273   assert(*curPtr == '/');
274   ++curPtr;
275 
276   while (true) {
277     switch (*curPtr++) {
278     case '\n':
279     case '\r':
280       // Newline is end of comment.
281       return;
282     case 0:
283       // If this is the end of the buffer, end the comment.
284       if (curPtr - 1 == curBuffer.end()) {
285         --curPtr;
286         return;
287       }
288       LLVM_FALLTHROUGH;
289     default:
290       // Skip over other characters.
291       break;
292     }
293   }
294 }
295 
lexDirective(const char * tokStart)296 Token Lexer::lexDirective(const char *tokStart) {
297   // Match the rest with an identifier regex: [0-9a-zA-Z_]*
298   while (isalnum(*curPtr) || *curPtr == '_')
299     ++curPtr;
300 
301   StringRef str(tokStart, curPtr - tokStart);
302   return Token(Token::directive, str);
303 }
304 
lexIdentifier(const char * tokStart)305 Token Lexer::lexIdentifier(const char *tokStart) {
306   // Match the rest of the identifier regex: [0-9a-zA-Z_]*
307   while (isalnum(*curPtr) || *curPtr == '_')
308     ++curPtr;
309 
310   // Check to see if this identifier is a keyword.
311   StringRef str(tokStart, curPtr - tokStart);
312   Token::Kind kind = StringSwitch<Token::Kind>(str)
313                          .Case("attr", Token::kw_attr)
314                          .Case("Attr", Token::kw_Attr)
315                          .Case("erase", Token::kw_erase)
316                          .Case("let", Token::kw_let)
317                          .Case("Constraint", Token::kw_Constraint)
318                          .Case("op", Token::kw_op)
319                          .Case("Op", Token::kw_Op)
320                          .Case("OpName", Token::kw_OpName)
321                          .Case("Pattern", Token::kw_Pattern)
322                          .Case("replace", Token::kw_replace)
323                          .Case("return", Token::kw_return)
324                          .Case("rewrite", Token::kw_rewrite)
325                          .Case("Rewrite", Token::kw_Rewrite)
326                          .Case("type", Token::kw_type)
327                          .Case("Type", Token::kw_Type)
328                          .Case("TypeRange", Token::kw_TypeRange)
329                          .Case("Value", Token::kw_Value)
330                          .Case("ValueRange", Token::kw_ValueRange)
331                          .Case("with", Token::kw_with)
332                          .Case("_", Token::underscore)
333                          .Default(Token::identifier);
334   return Token(kind, str);
335 }
336 
lexNumber(const char * tokStart)337 Token Lexer::lexNumber(const char *tokStart) {
338   assert(isdigit(curPtr[-1]));
339 
340   // Handle the normal decimal case.
341   while (isdigit(*curPtr))
342     ++curPtr;
343 
344   return formToken(Token::integer, tokStart);
345 }
346 
lexString(const char * tokStart,bool isStringBlock)347 Token Lexer::lexString(const char *tokStart, bool isStringBlock) {
348   while (true) {
349     // Check to see if there is a code completion location within the string. In
350     // these cases we generate a completion location and place the currently
351     // lexed string within the token (without the quotes). This allows for the
352     // parser to use the partially lexed string when computing the completion
353     // results.
354     if (curPtr == codeCompletionLocation) {
355       return formToken(Token::code_complete_string,
356                        tokStart + (isStringBlock ? 2 : 1));
357     }
358 
359     switch (*curPtr++) {
360     case '"':
361       // If this is a string block, we only end the string when we encounter a
362       // `}]`.
363       if (!isStringBlock)
364         return formToken(Token::string, tokStart);
365       continue;
366     case '}':
367       // If this is a string block, we only end the string when we encounter a
368       // `}]`.
369       if (!isStringBlock || *curPtr != ']')
370         continue;
371       ++curPtr;
372       return formToken(Token::string_block, tokStart);
373     case 0: {
374       // If this is a random nul character in the middle of a string, just
375       // include it. If it is the end of file, then it is an error.
376       if (curPtr - 1 != curBuffer.end())
377         continue;
378       --curPtr;
379 
380       StringRef expectedEndStr = isStringBlock ? "}]" : "\"";
381       return emitError(curPtr - 1,
382                        "expected '" + expectedEndStr + "' in string literal");
383     }
384 
385     case '\n':
386     case '\v':
387     case '\f':
388       // String blocks allow multiple lines.
389       if (!isStringBlock)
390         return emitError(curPtr - 1, "expected '\"' in string literal");
391       continue;
392 
393     case '\\':
394       // Handle explicitly a few escapes.
395       if (*curPtr == '"' || *curPtr == '\\' || *curPtr == 'n' ||
396           *curPtr == 't') {
397         ++curPtr;
398       } else if (llvm::isHexDigit(*curPtr) && llvm::isHexDigit(curPtr[1])) {
399         // Support \xx for two hex digits.
400         curPtr += 2;
401       } else {
402         return emitError(curPtr - 1, "unknown escape in string literal");
403       }
404       continue;
405 
406     default:
407       continue;
408     }
409   }
410 }
411