1 //===- Lexer.cpp ----------------------------------------------------------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 
9 #include "Lexer.h"
10 #include "mlir/Support/LogicalResult.h"
11 #include "mlir/Tools/PDLL/AST/Diagnostic.h"
12 #include "mlir/Tools/PDLL/Parser/CodeComplete.h"
13 #include "llvm/ADT/StringExtras.h"
14 #include "llvm/ADT/StringSwitch.h"
15 #include "llvm/Support/SourceMgr.h"
16 
17 using namespace mlir;
18 using namespace mlir::pdll;
19 
20 //===----------------------------------------------------------------------===//
21 // Token
22 //===----------------------------------------------------------------------===//
23 
24 std::string Token::getStringValue() const {
25   assert(getKind() == string || getKind() == string_block);
26 
27   // Start by dropping the quotes.
28   StringRef bytes = getSpelling().drop_front().drop_back();
29   if (is(string_block)) bytes = bytes.drop_front().drop_back();
30 
31   std::string result;
32   result.reserve(bytes.size());
33   for (unsigned i = 0, e = bytes.size(); i != e;) {
34     auto c = bytes[i++];
35     if (c != '\\') {
36       result.push_back(c);
37       continue;
38     }
39 
40     assert(i + 1 <= e && "invalid string should be caught by lexer");
41     auto c1 = bytes[i++];
42     switch (c1) {
43       case '"':
44       case '\\':
45         result.push_back(c1);
46         continue;
47       case 'n':
48         result.push_back('\n');
49         continue;
50       case 't':
51         result.push_back('\t');
52         continue;
53       default:
54         break;
55     }
56 
57     assert(i + 1 <= e && "invalid string should be caught by lexer");
58     auto c2 = bytes[i++];
59 
60     assert(llvm::isHexDigit(c1) && llvm::isHexDigit(c2) && "invalid escape");
61     result.push_back((llvm::hexDigitValue(c1) << 4) | llvm::hexDigitValue(c2));
62   }
63 
64   return result;
65 }
66 
67 //===----------------------------------------------------------------------===//
68 // Lexer
69 //===----------------------------------------------------------------------===//
70 
71 Lexer::Lexer(llvm::SourceMgr &mgr, ast::DiagnosticEngine &diagEngine,
72              CodeCompleteContext *codeCompleteContext)
73     : srcMgr(mgr), diagEngine(diagEngine), addedHandlerToDiagEngine(false),
74       codeCompletionLocation(nullptr) {
75   curBufferID = mgr.getMainFileID();
76   curBuffer = srcMgr.getMemoryBuffer(curBufferID)->getBuffer();
77   curPtr = curBuffer.begin();
78 
79   // Set the code completion location if necessary.
80   if (codeCompleteContext) {
81     codeCompletionLocation =
82         codeCompleteContext->getCodeCompleteLoc().getPointer();
83   }
84 
85   // If the diag engine has no handler, add a default that emits to the
86   // SourceMgr.
87   if (!diagEngine.getHandlerFn()) {
88     diagEngine.setHandlerFn([&](const ast::Diagnostic &diag) {
89       srcMgr.PrintMessage(diag.getLocation().Start, diag.getSeverity(),
90                           diag.getMessage());
91       for (const ast::Diagnostic &note : diag.getNotes())
92         srcMgr.PrintMessage(note.getLocation().Start, note.getSeverity(),
93                             note.getMessage());
94     });
95     addedHandlerToDiagEngine = true;
96   }
97 }
98 
99 Lexer::~Lexer() {
100   if (addedHandlerToDiagEngine) diagEngine.setHandlerFn(nullptr);
101 }
102 
103 LogicalResult Lexer::pushInclude(StringRef filename) {
104   std::string includedFile;
105   int bufferID = srcMgr.AddIncludeFile(
106       filename.str(), SMLoc::getFromPointer(curPtr), includedFile);
107   if (!bufferID) return failure();
108 
109   curBufferID = bufferID;
110   curBuffer = srcMgr.getMemoryBuffer(curBufferID)->getBuffer();
111   curPtr = curBuffer.begin();
112   return success();
113 }
114 
115 Token Lexer::emitError(SMRange loc, const Twine &msg) {
116   diagEngine.emitError(loc, msg);
117   return formToken(Token::error, loc.Start.getPointer());
118 }
119 Token Lexer::emitErrorAndNote(SMRange loc, const Twine &msg,
120                               SMRange noteLoc, const Twine &note) {
121   diagEngine.emitError(loc, msg)->attachNote(note, noteLoc);
122   return formToken(Token::error, loc.Start.getPointer());
123 }
124 Token Lexer::emitError(const char *loc, const Twine &msg) {
125   return emitError(SMRange(SMLoc::getFromPointer(loc),
126                                  SMLoc::getFromPointer(loc + 1)),
127                    msg);
128 }
129 
130 int Lexer::getNextChar() {
131   char curChar = *curPtr++;
132   switch (curChar) {
133     default:
134       return static_cast<unsigned char>(curChar);
135     case 0: {
136       // A nul character in the stream is either the end of the current buffer
137       // or a random nul in the file. Disambiguate that here.
138       if (curPtr - 1 != curBuffer.end()) return 0;
139 
140       // Otherwise, return end of file.
141       --curPtr;
142       return EOF;
143     }
144     case '\n':
145     case '\r':
146       // Handle the newline character by ignoring it and incrementing the line
147       // count. However, be careful about 'dos style' files with \n\r in them.
148       // Only treat a \n\r or \r\n as a single line.
149       if ((*curPtr == '\n' || (*curPtr == '\r')) && *curPtr != curChar)
150         ++curPtr;
151       return '\n';
152   }
153 }
154 
155 Token Lexer::lexToken() {
156   while (true) {
157     const char *tokStart = curPtr;
158 
159     // Check to see if this token is at the code completion location.
160     if (tokStart == codeCompletionLocation)
161       return formToken(Token::code_complete, tokStart);
162 
163     // This always consumes at least one character.
164     int curChar = getNextChar();
165     switch (curChar) {
166       default:
167         // Handle identifiers: [a-zA-Z_]
168         if (isalpha(curChar) || curChar == '_') return lexIdentifier(tokStart);
169 
170         // Unknown character, emit an error.
171         return emitError(tokStart, "unexpected character");
172       case EOF: {
173         // Return EOF denoting the end of lexing.
174         Token eof = formToken(Token::eof, tokStart);
175 
176         // Check to see if we are in an included file.
177         SMLoc parentIncludeLoc = srcMgr.getParentIncludeLoc(curBufferID);
178         if (parentIncludeLoc.isValid()) {
179           curBufferID = srcMgr.FindBufferContainingLoc(parentIncludeLoc);
180           curBuffer = srcMgr.getMemoryBuffer(curBufferID)->getBuffer();
181           curPtr = parentIncludeLoc.getPointer();
182         }
183 
184         return eof;
185       }
186 
187       // Lex punctuation.
188       case '-':
189         if (*curPtr == '>') {
190           ++curPtr;
191           return formToken(Token::arrow, tokStart);
192         }
193         return emitError(tokStart, "unexpected character");
194       case ':':
195         return formToken(Token::colon, tokStart);
196       case ',':
197         return formToken(Token::comma, tokStart);
198       case '.':
199         return formToken(Token::dot, tokStart);
200       case '=':
201         if (*curPtr == '>') {
202           ++curPtr;
203           return formToken(Token::equal_arrow, tokStart);
204         }
205         return formToken(Token::equal, tokStart);
206       case ';':
207         return formToken(Token::semicolon, tokStart);
208       case '[':
209         if (*curPtr == '{') {
210           ++curPtr;
211           return lexString(tokStart, /*isStringBlock=*/true);
212         }
213         return formToken(Token::l_square, tokStart);
214       case ']':
215         return formToken(Token::r_square, tokStart);
216 
217       case '<':
218         return formToken(Token::less, tokStart);
219       case '>':
220         return formToken(Token::greater, tokStart);
221       case '{':
222         return formToken(Token::l_brace, tokStart);
223       case '}':
224         return formToken(Token::r_brace, tokStart);
225       case '(':
226         return formToken(Token::l_paren, tokStart);
227       case ')':
228         return formToken(Token::r_paren, tokStart);
229       case '/':
230         if (*curPtr == '/') {
231           lexComment();
232           continue;
233         }
234         return emitError(tokStart, "unexpected character");
235 
236       // Ignore whitespace characters.
237       case 0:
238       case ' ':
239       case '\t':
240       case '\n':
241         return lexToken();
242 
243       case '#':
244         return lexDirective(tokStart);
245       case '"':
246         return lexString(tokStart, /*isStringBlock=*/false);
247 
248       case '0':
249       case '1':
250       case '2':
251       case '3':
252       case '4':
253       case '5':
254       case '6':
255       case '7':
256       case '8':
257       case '9':
258         return lexNumber(tokStart);
259     }
260   }
261 }
262 
263 /// Skip a comment line, starting with a '//'.
264 void Lexer::lexComment() {
265   // Advance over the second '/' in a '//' comment.
266   assert(*curPtr == '/');
267   ++curPtr;
268 
269   while (true) {
270     switch (*curPtr++) {
271       case '\n':
272       case '\r':
273         // Newline is end of comment.
274         return;
275       case 0:
276         // If this is the end of the buffer, end the comment.
277         if (curPtr - 1 == curBuffer.end()) {
278           --curPtr;
279           return;
280         }
281         LLVM_FALLTHROUGH;
282       default:
283         // Skip over other characters.
284         break;
285     }
286   }
287 }
288 
289 Token Lexer::lexDirective(const char *tokStart) {
290   // Match the rest with an identifier regex: [0-9a-zA-Z_]*
291   while (isalnum(*curPtr) || *curPtr == '_') ++curPtr;
292 
293   StringRef str(tokStart, curPtr - tokStart);
294   return Token(Token::directive, str);
295 }
296 
297 Token Lexer::lexIdentifier(const char *tokStart) {
298   // Match the rest of the identifier regex: [0-9a-zA-Z_]*
299   while (isalnum(*curPtr) || *curPtr == '_') ++curPtr;
300 
301   // Check to see if this identifier is a keyword.
302   StringRef str(tokStart, curPtr - tokStart);
303   Token::Kind kind = StringSwitch<Token::Kind>(str)
304                          .Case("attr", Token::kw_attr)
305                          .Case("Attr", Token::kw_Attr)
306                          .Case("erase", Token::kw_erase)
307                          .Case("let", Token::kw_let)
308                          .Case("Constraint", Token::kw_Constraint)
309                          .Case("op", Token::kw_op)
310                          .Case("Op", Token::kw_Op)
311                          .Case("OpName", Token::kw_OpName)
312                          .Case("Pattern", Token::kw_Pattern)
313                          .Case("replace", Token::kw_replace)
314                          .Case("return", Token::kw_return)
315                          .Case("rewrite", Token::kw_rewrite)
316                          .Case("Rewrite", Token::kw_Rewrite)
317                          .Case("type", Token::kw_type)
318                          .Case("Type", Token::kw_Type)
319                          .Case("TypeRange", Token::kw_TypeRange)
320                          .Case("Value", Token::kw_Value)
321                          .Case("ValueRange", Token::kw_ValueRange)
322                          .Case("with", Token::kw_with)
323                          .Case("_", Token::underscore)
324                          .Default(Token::identifier);
325   return Token(kind, str);
326 }
327 
328 Token Lexer::lexNumber(const char *tokStart) {
329   assert(isdigit(curPtr[-1]));
330 
331   // Handle the normal decimal case.
332   while (isdigit(*curPtr)) ++curPtr;
333 
334   return formToken(Token::integer, tokStart);
335 }
336 
337 Token Lexer::lexString(const char *tokStart, bool isStringBlock) {
338   while (true) {
339     switch (*curPtr++) {
340       case '"':
341         // If this is a string block, we only end the string when we encounter a
342         // `}]`.
343         if (!isStringBlock) return formToken(Token::string, tokStart);
344         continue;
345       case '}':
346         // If this is a string block, we only end the string when we encounter a
347         // `}]`.
348         if (!isStringBlock || *curPtr != ']') continue;
349         ++curPtr;
350         return formToken(Token::string_block, tokStart);
351       case 0:
352         // If this is a random nul character in the middle of a string, just
353         // include it.  If it is the end of file, then it is an error.
354         if (curPtr - 1 != curBuffer.end()) continue;
355         LLVM_FALLTHROUGH;
356       case '\n':
357       case '\v':
358       case '\f':
359         // String blocks allow multiple lines.
360         if (!isStringBlock)
361           return emitError(curPtr - 1, "expected '\"' in string literal");
362         continue;
363 
364       case '\\':
365         // Handle explicitly a few escapes.
366         if (*curPtr == '"' || *curPtr == '\\' || *curPtr == 'n' ||
367             *curPtr == 't') {
368           ++curPtr;
369         } else if (llvm::isHexDigit(*curPtr) && llvm::isHexDigit(curPtr[1])) {
370           // Support \xx for two hex digits.
371           curPtr += 2;
372         } else {
373           return emitError(curPtr - 1, "unknown escape in string literal");
374         }
375         continue;
376 
377       default:
378         continue;
379     }
380   }
381 }
382