1 //===- Lexer.cpp ----------------------------------------------------------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 
9 #include "Lexer.h"
10 #include "mlir/Support/LogicalResult.h"
11 #include "mlir/Tools/PDLL/AST/Diagnostic.h"
12 #include "mlir/Tools/PDLL/Parser/CodeComplete.h"
13 #include "llvm/ADT/StringExtras.h"
14 #include "llvm/ADT/StringSwitch.h"
15 #include "llvm/Support/SourceMgr.h"
16 
17 using namespace mlir;
18 using namespace mlir::pdll;
19 
20 //===----------------------------------------------------------------------===//
21 // Token
22 //===----------------------------------------------------------------------===//
23 
24 std::string Token::getStringValue() const {
25   assert(getKind() == string || getKind() == string_block ||
26          getKind() == code_complete_string);
27 
28   // Start by dropping the quotes.
29   StringRef bytes = getSpelling();
30   if (is(string))
31     bytes = bytes.drop_front().drop_back();
32   else if (is(string_block))
33     bytes = bytes.drop_front(2).drop_back(2);
34 
35   std::string result;
36   result.reserve(bytes.size());
37   for (unsigned i = 0, e = bytes.size(); i != e;) {
38     auto c = bytes[i++];
39     if (c != '\\') {
40       result.push_back(c);
41       continue;
42     }
43 
44     assert(i + 1 <= e && "invalid string should be caught by lexer");
45     auto c1 = bytes[i++];
46     switch (c1) {
47       case '"':
48       case '\\':
49         result.push_back(c1);
50         continue;
51       case 'n':
52         result.push_back('\n');
53         continue;
54       case 't':
55         result.push_back('\t');
56         continue;
57       default:
58         break;
59     }
60 
61     assert(i + 1 <= e && "invalid string should be caught by lexer");
62     auto c2 = bytes[i++];
63 
64     assert(llvm::isHexDigit(c1) && llvm::isHexDigit(c2) && "invalid escape");
65     result.push_back((llvm::hexDigitValue(c1) << 4) | llvm::hexDigitValue(c2));
66   }
67 
68   return result;
69 }
70 
71 //===----------------------------------------------------------------------===//
72 // Lexer
73 //===----------------------------------------------------------------------===//
74 
75 Lexer::Lexer(llvm::SourceMgr &mgr, ast::DiagnosticEngine &diagEngine,
76              CodeCompleteContext *codeCompleteContext)
77     : srcMgr(mgr), diagEngine(diagEngine), addedHandlerToDiagEngine(false),
78       codeCompletionLocation(nullptr) {
79   curBufferID = mgr.getMainFileID();
80   curBuffer = srcMgr.getMemoryBuffer(curBufferID)->getBuffer();
81   curPtr = curBuffer.begin();
82 
83   // Set the code completion location if necessary.
84   if (codeCompleteContext) {
85     codeCompletionLocation =
86         codeCompleteContext->getCodeCompleteLoc().getPointer();
87   }
88 
89   // If the diag engine has no handler, add a default that emits to the
90   // SourceMgr.
91   if (!diagEngine.getHandlerFn()) {
92     diagEngine.setHandlerFn([&](const ast::Diagnostic &diag) {
93       srcMgr.PrintMessage(diag.getLocation().Start, diag.getSeverity(),
94                           diag.getMessage());
95       for (const ast::Diagnostic &note : diag.getNotes())
96         srcMgr.PrintMessage(note.getLocation().Start, note.getSeverity(),
97                             note.getMessage());
98     });
99     addedHandlerToDiagEngine = true;
100   }
101 }
102 
103 Lexer::~Lexer() {
104   if (addedHandlerToDiagEngine) diagEngine.setHandlerFn(nullptr);
105 }
106 
107 LogicalResult Lexer::pushInclude(StringRef filename, SMRange includeLoc) {
108   std::string includedFile;
109   int bufferID =
110       srcMgr.AddIncludeFile(filename.str(), includeLoc.End, includedFile);
111   if (!bufferID)
112     return failure();
113 
114   curBufferID = bufferID;
115   curBuffer = srcMgr.getMemoryBuffer(curBufferID)->getBuffer();
116   curPtr = curBuffer.begin();
117   return success();
118 }
119 
120 Token Lexer::emitError(SMRange loc, const Twine &msg) {
121   diagEngine.emitError(loc, msg);
122   return formToken(Token::error, loc.Start.getPointer());
123 }
124 Token Lexer::emitErrorAndNote(SMRange loc, const Twine &msg,
125                               SMRange noteLoc, const Twine &note) {
126   diagEngine.emitError(loc, msg)->attachNote(note, noteLoc);
127   return formToken(Token::error, loc.Start.getPointer());
128 }
129 Token Lexer::emitError(const char *loc, const Twine &msg) {
130   return emitError(SMRange(SMLoc::getFromPointer(loc),
131                                  SMLoc::getFromPointer(loc + 1)),
132                    msg);
133 }
134 
135 int Lexer::getNextChar() {
136   char curChar = *curPtr++;
137   switch (curChar) {
138     default:
139       return static_cast<unsigned char>(curChar);
140     case 0: {
141       // A nul character in the stream is either the end of the current buffer
142       // or a random nul in the file. Disambiguate that here.
143       if (curPtr - 1 != curBuffer.end()) return 0;
144 
145       // Otherwise, return end of file.
146       --curPtr;
147       return EOF;
148     }
149     case '\n':
150     case '\r':
151       // Handle the newline character by ignoring it and incrementing the line
152       // count. However, be careful about 'dos style' files with \n\r in them.
153       // Only treat a \n\r or \r\n as a single line.
154       if ((*curPtr == '\n' || (*curPtr == '\r')) && *curPtr != curChar)
155         ++curPtr;
156       return '\n';
157   }
158 }
159 
160 Token Lexer::lexToken() {
161   while (true) {
162     const char *tokStart = curPtr;
163 
164     // Check to see if this token is at the code completion location.
165     if (tokStart == codeCompletionLocation)
166       return formToken(Token::code_complete, tokStart);
167 
168     // This always consumes at least one character.
169     int curChar = getNextChar();
170     switch (curChar) {
171       default:
172         // Handle identifiers: [a-zA-Z_]
173         if (isalpha(curChar) || curChar == '_') return lexIdentifier(tokStart);
174 
175         // Unknown character, emit an error.
176         return emitError(tokStart, "unexpected character");
177       case EOF: {
178         // Return EOF denoting the end of lexing.
179         Token eof = formToken(Token::eof, tokStart);
180 
181         // Check to see if we are in an included file.
182         SMLoc parentIncludeLoc = srcMgr.getParentIncludeLoc(curBufferID);
183         if (parentIncludeLoc.isValid()) {
184           curBufferID = srcMgr.FindBufferContainingLoc(parentIncludeLoc);
185           curBuffer = srcMgr.getMemoryBuffer(curBufferID)->getBuffer();
186           curPtr = parentIncludeLoc.getPointer();
187         }
188 
189         return eof;
190       }
191 
192       // Lex punctuation.
193       case '-':
194         if (*curPtr == '>') {
195           ++curPtr;
196           return formToken(Token::arrow, tokStart);
197         }
198         return emitError(tokStart, "unexpected character");
199       case ':':
200         return formToken(Token::colon, tokStart);
201       case ',':
202         return formToken(Token::comma, tokStart);
203       case '.':
204         return formToken(Token::dot, tokStart);
205       case '=':
206         if (*curPtr == '>') {
207           ++curPtr;
208           return formToken(Token::equal_arrow, tokStart);
209         }
210         return formToken(Token::equal, tokStart);
211       case ';':
212         return formToken(Token::semicolon, tokStart);
213       case '[':
214         if (*curPtr == '{') {
215           ++curPtr;
216           return lexString(tokStart, /*isStringBlock=*/true);
217         }
218         return formToken(Token::l_square, tokStart);
219       case ']':
220         return formToken(Token::r_square, tokStart);
221 
222       case '<':
223         return formToken(Token::less, tokStart);
224       case '>':
225         return formToken(Token::greater, tokStart);
226       case '{':
227         return formToken(Token::l_brace, tokStart);
228       case '}':
229         return formToken(Token::r_brace, tokStart);
230       case '(':
231         return formToken(Token::l_paren, tokStart);
232       case ')':
233         return formToken(Token::r_paren, tokStart);
234       case '/':
235         if (*curPtr == '/') {
236           lexComment();
237           continue;
238         }
239         return emitError(tokStart, "unexpected character");
240 
241       // Ignore whitespace characters.
242       case 0:
243       case ' ':
244       case '\t':
245       case '\n':
246         return lexToken();
247 
248       case '#':
249         return lexDirective(tokStart);
250       case '"':
251         return lexString(tokStart, /*isStringBlock=*/false);
252 
253       case '0':
254       case '1':
255       case '2':
256       case '3':
257       case '4':
258       case '5':
259       case '6':
260       case '7':
261       case '8':
262       case '9':
263         return lexNumber(tokStart);
264     }
265   }
266 }
267 
268 /// Skip a comment line, starting with a '//'.
269 void Lexer::lexComment() {
270   // Advance over the second '/' in a '//' comment.
271   assert(*curPtr == '/');
272   ++curPtr;
273 
274   while (true) {
275     switch (*curPtr++) {
276       case '\n':
277       case '\r':
278         // Newline is end of comment.
279         return;
280       case 0:
281         // If this is the end of the buffer, end the comment.
282         if (curPtr - 1 == curBuffer.end()) {
283           --curPtr;
284           return;
285         }
286         LLVM_FALLTHROUGH;
287       default:
288         // Skip over other characters.
289         break;
290     }
291   }
292 }
293 
294 Token Lexer::lexDirective(const char *tokStart) {
295   // Match the rest with an identifier regex: [0-9a-zA-Z_]*
296   while (isalnum(*curPtr) || *curPtr == '_') ++curPtr;
297 
298   StringRef str(tokStart, curPtr - tokStart);
299   return Token(Token::directive, str);
300 }
301 
302 Token Lexer::lexIdentifier(const char *tokStart) {
303   // Match the rest of the identifier regex: [0-9a-zA-Z_]*
304   while (isalnum(*curPtr) || *curPtr == '_') ++curPtr;
305 
306   // Check to see if this identifier is a keyword.
307   StringRef str(tokStart, curPtr - tokStart);
308   Token::Kind kind = StringSwitch<Token::Kind>(str)
309                          .Case("attr", Token::kw_attr)
310                          .Case("Attr", Token::kw_Attr)
311                          .Case("erase", Token::kw_erase)
312                          .Case("let", Token::kw_let)
313                          .Case("Constraint", Token::kw_Constraint)
314                          .Case("op", Token::kw_op)
315                          .Case("Op", Token::kw_Op)
316                          .Case("OpName", Token::kw_OpName)
317                          .Case("Pattern", Token::kw_Pattern)
318                          .Case("replace", Token::kw_replace)
319                          .Case("return", Token::kw_return)
320                          .Case("rewrite", Token::kw_rewrite)
321                          .Case("Rewrite", Token::kw_Rewrite)
322                          .Case("type", Token::kw_type)
323                          .Case("Type", Token::kw_Type)
324                          .Case("TypeRange", Token::kw_TypeRange)
325                          .Case("Value", Token::kw_Value)
326                          .Case("ValueRange", Token::kw_ValueRange)
327                          .Case("with", Token::kw_with)
328                          .Case("_", Token::underscore)
329                          .Default(Token::identifier);
330   return Token(kind, str);
331 }
332 
333 Token Lexer::lexNumber(const char *tokStart) {
334   assert(isdigit(curPtr[-1]));
335 
336   // Handle the normal decimal case.
337   while (isdigit(*curPtr)) ++curPtr;
338 
339   return formToken(Token::integer, tokStart);
340 }
341 
342 Token Lexer::lexString(const char *tokStart, bool isStringBlock) {
343   while (true) {
344     // Check to see if there is a code completion location within the string. In
345     // these cases we generate a completion location and place the currently
346     // lexed string within the token (without the quotes). This allows for the
347     // parser to use the partially lexed string when computing the completion
348     // results.
349     if (curPtr == codeCompletionLocation) {
350       return formToken(Token::code_complete_string,
351                        tokStart + (isStringBlock ? 2 : 1));
352     }
353 
354     switch (*curPtr++) {
355       case '"':
356         // If this is a string block, we only end the string when we encounter a
357         // `}]`.
358         if (!isStringBlock)
359           return formToken(Token::string, tokStart);
360         continue;
361       case '}':
362         // If this is a string block, we only end the string when we encounter a
363         // `}]`.
364         if (!isStringBlock || *curPtr != ']')
365           continue;
366         ++curPtr;
367         return formToken(Token::string_block, tokStart);
368       case 0: {
369         // If this is a random nul character in the middle of a string, just
370         // include it. If it is the end of file, then it is an error.
371         if (curPtr - 1 != curBuffer.end())
372           continue;
373         --curPtr;
374 
375         StringRef expectedEndStr = isStringBlock ? "}]" : "\"";
376         return emitError(curPtr - 1,
377                          "expected '" + expectedEndStr + "' in string literal");
378       }
379 
380       case '\n':
381       case '\v':
382       case '\f':
383         // String blocks allow multiple lines.
384         if (!isStringBlock)
385           return emitError(curPtr - 1, "expected '\"' in string literal");
386         continue;
387 
388       case '\\':
389         // Handle explicitly a few escapes.
390         if (*curPtr == '"' || *curPtr == '\\' || *curPtr == 'n' ||
391             *curPtr == 't') {
392           ++curPtr;
393         } else if (llvm::isHexDigit(*curPtr) && llvm::isHexDigit(curPtr[1])) {
394           // Support \xx for two hex digits.
395           curPtr += 2;
396         } else {
397           return emitError(curPtr - 1, "unknown escape in string literal");
398         }
399         continue;
400 
401       default:
402         continue;
403     }
404   }
405 }
406