1 //===- Lexer.cpp ----------------------------------------------------------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 
9 #include "Lexer.h"
10 #include "mlir/Support/LogicalResult.h"
11 #include "mlir/Tools/PDLL/AST/Diagnostic.h"
12 #include "llvm/ADT/StringExtras.h"
13 #include "llvm/ADT/StringSwitch.h"
14 #include "llvm/Support/SourceMgr.h"
15 
16 using namespace mlir;
17 using namespace mlir::pdll;
18 
19 //===----------------------------------------------------------------------===//
20 // Token
21 //===----------------------------------------------------------------------===//
22 
23 std::string Token::getStringValue() const {
24   assert(getKind() == string || getKind() == string_block);
25 
26   // Start by dropping the quotes.
27   StringRef bytes = getSpelling().drop_front().drop_back();
28   if (is(string_block)) bytes = bytes.drop_front().drop_back();
29 
30   std::string result;
31   result.reserve(bytes.size());
32   for (unsigned i = 0, e = bytes.size(); i != e;) {
33     auto c = bytes[i++];
34     if (c != '\\') {
35       result.push_back(c);
36       continue;
37     }
38 
39     assert(i + 1 <= e && "invalid string should be caught by lexer");
40     auto c1 = bytes[i++];
41     switch (c1) {
42       case '"':
43       case '\\':
44         result.push_back(c1);
45         continue;
46       case 'n':
47         result.push_back('\n');
48         continue;
49       case 't':
50         result.push_back('\t');
51         continue;
52       default:
53         break;
54     }
55 
56     assert(i + 1 <= e && "invalid string should be caught by lexer");
57     auto c2 = bytes[i++];
58 
59     assert(llvm::isHexDigit(c1) && llvm::isHexDigit(c2) && "invalid escape");
60     result.push_back((llvm::hexDigitValue(c1) << 4) | llvm::hexDigitValue(c2));
61   }
62 
63   return result;
64 }
65 
66 //===----------------------------------------------------------------------===//
67 // Lexer
68 //===----------------------------------------------------------------------===//
69 
70 Lexer::Lexer(llvm::SourceMgr &mgr, ast::DiagnosticEngine &diagEngine)
71     : srcMgr(mgr), diagEngine(diagEngine), addedHandlerToDiagEngine(false) {
72   curBufferID = mgr.getMainFileID();
73   curBuffer = srcMgr.getMemoryBuffer(curBufferID)->getBuffer();
74   curPtr = curBuffer.begin();
75 
76   // If the diag engine has no handler, add a default that emits to the
77   // SourceMgr.
78   if (!diagEngine.getHandlerFn()) {
79     diagEngine.setHandlerFn([&](const ast::Diagnostic &diag) {
80       srcMgr.PrintMessage(diag.getLocation().Start, diag.getSeverity(),
81                           diag.getMessage());
82       for (const ast::Diagnostic &note : diag.getNotes())
83         srcMgr.PrintMessage(note.getLocation().Start, note.getSeverity(),
84                             note.getMessage());
85     });
86     addedHandlerToDiagEngine = true;
87   }
88 }
89 
90 Lexer::~Lexer() {
91   if (addedHandlerToDiagEngine) diagEngine.setHandlerFn(nullptr);
92 }
93 
94 LogicalResult Lexer::pushInclude(StringRef filename) {
95   std::string includedFile;
96   int bufferID = srcMgr.AddIncludeFile(
97       filename.str(), SMLoc::getFromPointer(curPtr), includedFile);
98   if (!bufferID) return failure();
99 
100   curBufferID = bufferID;
101   curBuffer = srcMgr.getMemoryBuffer(curBufferID)->getBuffer();
102   curPtr = curBuffer.begin();
103   return success();
104 }
105 
106 Token Lexer::emitError(SMRange loc, const Twine &msg) {
107   diagEngine.emitError(loc, msg);
108   return formToken(Token::error, loc.Start.getPointer());
109 }
110 Token Lexer::emitErrorAndNote(SMRange loc, const Twine &msg,
111                               SMRange noteLoc, const Twine &note) {
112   diagEngine.emitError(loc, msg)->attachNote(note, noteLoc);
113   return formToken(Token::error, loc.Start.getPointer());
114 }
115 Token Lexer::emitError(const char *loc, const Twine &msg) {
116   return emitError(SMRange(SMLoc::getFromPointer(loc),
117                                  SMLoc::getFromPointer(loc + 1)),
118                    msg);
119 }
120 
121 int Lexer::getNextChar() {
122   char curChar = *curPtr++;
123   switch (curChar) {
124     default:
125       return static_cast<unsigned char>(curChar);
126     case 0: {
127       // A nul character in the stream is either the end of the current buffer
128       // or a random nul in the file. Disambiguate that here.
129       if (curPtr - 1 != curBuffer.end()) return 0;
130 
131       // Otherwise, return end of file.
132       --curPtr;
133       return EOF;
134     }
135     case '\n':
136     case '\r':
137       // Handle the newline character by ignoring it and incrementing the line
138       // count. However, be careful about 'dos style' files with \n\r in them.
139       // Only treat a \n\r or \r\n as a single line.
140       if ((*curPtr == '\n' || (*curPtr == '\r')) && *curPtr != curChar)
141         ++curPtr;
142       return '\n';
143   }
144 }
145 
146 Token Lexer::lexToken() {
147   while (true) {
148     const char *tokStart = curPtr;
149 
150     // This always consumes at least one character.
151     int curChar = getNextChar();
152     switch (curChar) {
153       default:
154         // Handle identifiers: [a-zA-Z_]
155         if (isalpha(curChar) || curChar == '_') return lexIdentifier(tokStart);
156 
157         // Unknown character, emit an error.
158         return emitError(tokStart, "unexpected character");
159       case EOF: {
160         // Return EOF denoting the end of lexing.
161         Token eof = formToken(Token::eof, tokStart);
162 
163         // Check to see if we are in an included file.
164         SMLoc parentIncludeLoc = srcMgr.getParentIncludeLoc(curBufferID);
165         if (parentIncludeLoc.isValid()) {
166           curBufferID = srcMgr.FindBufferContainingLoc(parentIncludeLoc);
167           curBuffer = srcMgr.getMemoryBuffer(curBufferID)->getBuffer();
168           curPtr = parentIncludeLoc.getPointer();
169         }
170 
171         return eof;
172       }
173 
174       // Lex punctuation.
175       case '-':
176         if (*curPtr == '>') {
177           ++curPtr;
178           return formToken(Token::arrow, tokStart);
179         }
180         return emitError(tokStart, "unexpected character");
181       case ':':
182         return formToken(Token::colon, tokStart);
183       case ',':
184         return formToken(Token::comma, tokStart);
185       case '.':
186         return formToken(Token::dot, tokStart);
187       case '=':
188         if (*curPtr == '>') {
189           ++curPtr;
190           return formToken(Token::equal_arrow, tokStart);
191         }
192         return formToken(Token::equal, tokStart);
193       case ';':
194         return formToken(Token::semicolon, tokStart);
195       case '[':
196         if (*curPtr == '{') {
197           ++curPtr;
198           return lexString(tokStart, /*isStringBlock=*/true);
199         }
200         return formToken(Token::l_square, tokStart);
201       case ']':
202         return formToken(Token::r_square, tokStart);
203 
204       case '<':
205         return formToken(Token::less, tokStart);
206       case '>':
207         return formToken(Token::greater, tokStart);
208       case '{':
209         return formToken(Token::l_brace, tokStart);
210       case '}':
211         return formToken(Token::r_brace, tokStart);
212       case '(':
213         return formToken(Token::l_paren, tokStart);
214       case ')':
215         return formToken(Token::r_paren, tokStart);
216       case '/':
217         if (*curPtr == '/') {
218           lexComment();
219           continue;
220         }
221         return emitError(tokStart, "unexpected character");
222 
223       // Ignore whitespace characters.
224       case 0:
225       case ' ':
226       case '\t':
227       case '\n':
228         return lexToken();
229 
230       case '#':
231         return lexDirective(tokStart);
232       case '"':
233         return lexString(tokStart, /*isStringBlock=*/false);
234 
235       case '0':
236       case '1':
237       case '2':
238       case '3':
239       case '4':
240       case '5':
241       case '6':
242       case '7':
243       case '8':
244       case '9':
245         return lexNumber(tokStart);
246     }
247   }
248 }
249 
250 /// Skip a comment line, starting with a '//'.
251 void Lexer::lexComment() {
252   // Advance over the second '/' in a '//' comment.
253   assert(*curPtr == '/');
254   ++curPtr;
255 
256   while (true) {
257     switch (*curPtr++) {
258       case '\n':
259       case '\r':
260         // Newline is end of comment.
261         return;
262       case 0:
263         // If this is the end of the buffer, end the comment.
264         if (curPtr - 1 == curBuffer.end()) {
265           --curPtr;
266           return;
267         }
268         LLVM_FALLTHROUGH;
269       default:
270         // Skip over other characters.
271         break;
272     }
273   }
274 }
275 
276 Token Lexer::lexDirective(const char *tokStart) {
277   // Match the rest with an identifier regex: [0-9a-zA-Z_]*
278   while (isalnum(*curPtr) || *curPtr == '_') ++curPtr;
279 
280   StringRef str(tokStart, curPtr - tokStart);
281   return Token(Token::directive, str);
282 }
283 
284 Token Lexer::lexIdentifier(const char *tokStart) {
285   // Match the rest of the identifier regex: [0-9a-zA-Z_]*
286   while (isalnum(*curPtr) || *curPtr == '_') ++curPtr;
287 
288   // Check to see if this identifier is a keyword.
289   StringRef str(tokStart, curPtr - tokStart);
290   Token::Kind kind = StringSwitch<Token::Kind>(str)
291                          .Case("attr", Token::kw_attr)
292                          .Case("Attr", Token::kw_Attr)
293                          .Case("erase", Token::kw_erase)
294                          .Case("let", Token::kw_let)
295                          .Case("Constraint", Token::kw_Constraint)
296                          .Case("op", Token::kw_op)
297                          .Case("Op", Token::kw_Op)
298                          .Case("OpName", Token::kw_OpName)
299                          .Case("Pattern", Token::kw_Pattern)
300                          .Case("replace", Token::kw_replace)
301                          .Case("return", Token::kw_return)
302                          .Case("rewrite", Token::kw_rewrite)
303                          .Case("Rewrite", Token::kw_Rewrite)
304                          .Case("type", Token::kw_type)
305                          .Case("Type", Token::kw_Type)
306                          .Case("TypeRange", Token::kw_TypeRange)
307                          .Case("Value", Token::kw_Value)
308                          .Case("ValueRange", Token::kw_ValueRange)
309                          .Case("with", Token::kw_with)
310                          .Case("_", Token::underscore)
311                          .Default(Token::identifier);
312   return Token(kind, str);
313 }
314 
315 Token Lexer::lexNumber(const char *tokStart) {
316   assert(isdigit(curPtr[-1]));
317 
318   // Handle the normal decimal case.
319   while (isdigit(*curPtr)) ++curPtr;
320 
321   return formToken(Token::integer, tokStart);
322 }
323 
324 Token Lexer::lexString(const char *tokStart, bool isStringBlock) {
325   while (true) {
326     switch (*curPtr++) {
327       case '"':
328         // If this is a string block, we only end the string when we encounter a
329         // `}]`.
330         if (!isStringBlock) return formToken(Token::string, tokStart);
331         continue;
332       case '}':
333         // If this is a string block, we only end the string when we encounter a
334         // `}]`.
335         if (!isStringBlock || *curPtr != ']') continue;
336         ++curPtr;
337         return formToken(Token::string_block, tokStart);
338       case 0:
339         // If this is a random nul character in the middle of a string, just
340         // include it.  If it is the end of file, then it is an error.
341         if (curPtr - 1 != curBuffer.end()) continue;
342         LLVM_FALLTHROUGH;
343       case '\n':
344       case '\v':
345       case '\f':
346         // String blocks allow multiple lines.
347         if (!isStringBlock)
348           return emitError(curPtr - 1, "expected '\"' in string literal");
349         continue;
350 
351       case '\\':
352         // Handle explicitly a few escapes.
353         if (*curPtr == '"' || *curPtr == '\\' || *curPtr == 'n' ||
354             *curPtr == 't') {
355           ++curPtr;
356         } else if (llvm::isHexDigit(*curPtr) && llvm::isHexDigit(curPtr[1])) {
357           // Support \xx for two hex digits.
358           curPtr += 2;
359         } else {
360           return emitError(curPtr - 1, "unknown escape in string literal");
361         }
362         continue;
363 
364       default:
365         continue;
366     }
367   }
368 }
369