1 //===- Lexer.cpp ----------------------------------------------------------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8
9 #include "Lexer.h"
10 #include "mlir/Support/LogicalResult.h"
11 #include "mlir/Tools/PDLL/AST/Diagnostic.h"
12 #include "mlir/Tools/PDLL/Parser/CodeComplete.h"
13 #include "llvm/ADT/StringExtras.h"
14 #include "llvm/ADT/StringSwitch.h"
15 #include "llvm/Support/SourceMgr.h"
16
17 using namespace mlir;
18 using namespace mlir::pdll;
19
20 //===----------------------------------------------------------------------===//
21 // Token
22 //===----------------------------------------------------------------------===//
23
getStringValue() const24 std::string Token::getStringValue() const {
25 assert(getKind() == string || getKind() == string_block ||
26 getKind() == code_complete_string);
27
28 // Start by dropping the quotes.
29 StringRef bytes = getSpelling();
30 if (is(string))
31 bytes = bytes.drop_front().drop_back();
32 else if (is(string_block))
33 bytes = bytes.drop_front(2).drop_back(2);
34
35 std::string result;
36 result.reserve(bytes.size());
37 for (unsigned i = 0, e = bytes.size(); i != e;) {
38 auto c = bytes[i++];
39 if (c != '\\') {
40 result.push_back(c);
41 continue;
42 }
43
44 assert(i + 1 <= e && "invalid string should be caught by lexer");
45 auto c1 = bytes[i++];
46 switch (c1) {
47 case '"':
48 case '\\':
49 result.push_back(c1);
50 continue;
51 case 'n':
52 result.push_back('\n');
53 continue;
54 case 't':
55 result.push_back('\t');
56 continue;
57 default:
58 break;
59 }
60
61 assert(i + 1 <= e && "invalid string should be caught by lexer");
62 auto c2 = bytes[i++];
63
64 assert(llvm::isHexDigit(c1) && llvm::isHexDigit(c2) && "invalid escape");
65 result.push_back((llvm::hexDigitValue(c1) << 4) | llvm::hexDigitValue(c2));
66 }
67
68 return result;
69 }
70
71 //===----------------------------------------------------------------------===//
72 // Lexer
73 //===----------------------------------------------------------------------===//
74
Lexer(llvm::SourceMgr & mgr,ast::DiagnosticEngine & diagEngine,CodeCompleteContext * codeCompleteContext)75 Lexer::Lexer(llvm::SourceMgr &mgr, ast::DiagnosticEngine &diagEngine,
76 CodeCompleteContext *codeCompleteContext)
77 : srcMgr(mgr), diagEngine(diagEngine), addedHandlerToDiagEngine(false),
78 codeCompletionLocation(nullptr) {
79 curBufferID = mgr.getMainFileID();
80 curBuffer = srcMgr.getMemoryBuffer(curBufferID)->getBuffer();
81 curPtr = curBuffer.begin();
82
83 // Set the code completion location if necessary.
84 if (codeCompleteContext) {
85 codeCompletionLocation =
86 codeCompleteContext->getCodeCompleteLoc().getPointer();
87 }
88
89 // If the diag engine has no handler, add a default that emits to the
90 // SourceMgr.
91 if (!diagEngine.getHandlerFn()) {
92 diagEngine.setHandlerFn([&](const ast::Diagnostic &diag) {
93 srcMgr.PrintMessage(diag.getLocation().Start, diag.getSeverity(),
94 diag.getMessage());
95 for (const ast::Diagnostic ¬e : diag.getNotes())
96 srcMgr.PrintMessage(note.getLocation().Start, note.getSeverity(),
97 note.getMessage());
98 });
99 addedHandlerToDiagEngine = true;
100 }
101 }
102
~Lexer()103 Lexer::~Lexer() {
104 if (addedHandlerToDiagEngine)
105 diagEngine.setHandlerFn(nullptr);
106 }
107
pushInclude(StringRef filename,SMRange includeLoc)108 LogicalResult Lexer::pushInclude(StringRef filename, SMRange includeLoc) {
109 std::string includedFile;
110 int bufferID =
111 srcMgr.AddIncludeFile(filename.str(), includeLoc.End, includedFile);
112 if (!bufferID)
113 return failure();
114
115 curBufferID = bufferID;
116 curBuffer = srcMgr.getMemoryBuffer(curBufferID)->getBuffer();
117 curPtr = curBuffer.begin();
118 return success();
119 }
120
emitError(SMRange loc,const Twine & msg)121 Token Lexer::emitError(SMRange loc, const Twine &msg) {
122 diagEngine.emitError(loc, msg);
123 return formToken(Token::error, loc.Start.getPointer());
124 }
emitErrorAndNote(SMRange loc,const Twine & msg,SMRange noteLoc,const Twine & note)125 Token Lexer::emitErrorAndNote(SMRange loc, const Twine &msg, SMRange noteLoc,
126 const Twine ¬e) {
127 diagEngine.emitError(loc, msg)->attachNote(note, noteLoc);
128 return formToken(Token::error, loc.Start.getPointer());
129 }
emitError(const char * loc,const Twine & msg)130 Token Lexer::emitError(const char *loc, const Twine &msg) {
131 return emitError(
132 SMRange(SMLoc::getFromPointer(loc), SMLoc::getFromPointer(loc + 1)), msg);
133 }
134
getNextChar()135 int Lexer::getNextChar() {
136 char curChar = *curPtr++;
137 switch (curChar) {
138 default:
139 return static_cast<unsigned char>(curChar);
140 case 0: {
141 // A nul character in the stream is either the end of the current buffer
142 // or a random nul in the file. Disambiguate that here.
143 if (curPtr - 1 != curBuffer.end())
144 return 0;
145
146 // Otherwise, return end of file.
147 --curPtr;
148 return EOF;
149 }
150 case '\n':
151 case '\r':
152 // Handle the newline character by ignoring it and incrementing the line
153 // count. However, be careful about 'dos style' files with \n\r in them.
154 // Only treat a \n\r or \r\n as a single line.
155 if ((*curPtr == '\n' || (*curPtr == '\r')) && *curPtr != curChar)
156 ++curPtr;
157 return '\n';
158 }
159 }
160
lexToken()161 Token Lexer::lexToken() {
162 while (true) {
163 const char *tokStart = curPtr;
164
165 // Check to see if this token is at the code completion location.
166 if (tokStart == codeCompletionLocation)
167 return formToken(Token::code_complete, tokStart);
168
169 // This always consumes at least one character.
170 int curChar = getNextChar();
171 switch (curChar) {
172 default:
173 // Handle identifiers: [a-zA-Z_]
174 if (isalpha(curChar) || curChar == '_')
175 return lexIdentifier(tokStart);
176
177 // Unknown character, emit an error.
178 return emitError(tokStart, "unexpected character");
179 case EOF: {
180 // Return EOF denoting the end of lexing.
181 Token eof = formToken(Token::eof, tokStart);
182
183 // Check to see if we are in an included file.
184 SMLoc parentIncludeLoc = srcMgr.getParentIncludeLoc(curBufferID);
185 if (parentIncludeLoc.isValid()) {
186 curBufferID = srcMgr.FindBufferContainingLoc(parentIncludeLoc);
187 curBuffer = srcMgr.getMemoryBuffer(curBufferID)->getBuffer();
188 curPtr = parentIncludeLoc.getPointer();
189 }
190
191 return eof;
192 }
193
194 // Lex punctuation.
195 case '-':
196 if (*curPtr == '>') {
197 ++curPtr;
198 return formToken(Token::arrow, tokStart);
199 }
200 return emitError(tokStart, "unexpected character");
201 case ':':
202 return formToken(Token::colon, tokStart);
203 case ',':
204 return formToken(Token::comma, tokStart);
205 case '.':
206 return formToken(Token::dot, tokStart);
207 case '=':
208 if (*curPtr == '>') {
209 ++curPtr;
210 return formToken(Token::equal_arrow, tokStart);
211 }
212 return formToken(Token::equal, tokStart);
213 case ';':
214 return formToken(Token::semicolon, tokStart);
215 case '[':
216 if (*curPtr == '{') {
217 ++curPtr;
218 return lexString(tokStart, /*isStringBlock=*/true);
219 }
220 return formToken(Token::l_square, tokStart);
221 case ']':
222 return formToken(Token::r_square, tokStart);
223
224 case '<':
225 return formToken(Token::less, tokStart);
226 case '>':
227 return formToken(Token::greater, tokStart);
228 case '{':
229 return formToken(Token::l_brace, tokStart);
230 case '}':
231 return formToken(Token::r_brace, tokStart);
232 case '(':
233 return formToken(Token::l_paren, tokStart);
234 case ')':
235 return formToken(Token::r_paren, tokStart);
236 case '/':
237 if (*curPtr == '/') {
238 lexComment();
239 continue;
240 }
241 return emitError(tokStart, "unexpected character");
242
243 // Ignore whitespace characters.
244 case 0:
245 case ' ':
246 case '\t':
247 case '\n':
248 return lexToken();
249
250 case '#':
251 return lexDirective(tokStart);
252 case '"':
253 return lexString(tokStart, /*isStringBlock=*/false);
254
255 case '0':
256 case '1':
257 case '2':
258 case '3':
259 case '4':
260 case '5':
261 case '6':
262 case '7':
263 case '8':
264 case '9':
265 return lexNumber(tokStart);
266 }
267 }
268 }
269
270 /// Skip a comment line, starting with a '//'.
lexComment()271 void Lexer::lexComment() {
272 // Advance over the second '/' in a '//' comment.
273 assert(*curPtr == '/');
274 ++curPtr;
275
276 while (true) {
277 switch (*curPtr++) {
278 case '\n':
279 case '\r':
280 // Newline is end of comment.
281 return;
282 case 0:
283 // If this is the end of the buffer, end the comment.
284 if (curPtr - 1 == curBuffer.end()) {
285 --curPtr;
286 return;
287 }
288 LLVM_FALLTHROUGH;
289 default:
290 // Skip over other characters.
291 break;
292 }
293 }
294 }
295
lexDirective(const char * tokStart)296 Token Lexer::lexDirective(const char *tokStart) {
297 // Match the rest with an identifier regex: [0-9a-zA-Z_]*
298 while (isalnum(*curPtr) || *curPtr == '_')
299 ++curPtr;
300
301 StringRef str(tokStart, curPtr - tokStart);
302 return Token(Token::directive, str);
303 }
304
lexIdentifier(const char * tokStart)305 Token Lexer::lexIdentifier(const char *tokStart) {
306 // Match the rest of the identifier regex: [0-9a-zA-Z_]*
307 while (isalnum(*curPtr) || *curPtr == '_')
308 ++curPtr;
309
310 // Check to see if this identifier is a keyword.
311 StringRef str(tokStart, curPtr - tokStart);
312 Token::Kind kind = StringSwitch<Token::Kind>(str)
313 .Case("attr", Token::kw_attr)
314 .Case("Attr", Token::kw_Attr)
315 .Case("erase", Token::kw_erase)
316 .Case("let", Token::kw_let)
317 .Case("Constraint", Token::kw_Constraint)
318 .Case("op", Token::kw_op)
319 .Case("Op", Token::kw_Op)
320 .Case("OpName", Token::kw_OpName)
321 .Case("Pattern", Token::kw_Pattern)
322 .Case("replace", Token::kw_replace)
323 .Case("return", Token::kw_return)
324 .Case("rewrite", Token::kw_rewrite)
325 .Case("Rewrite", Token::kw_Rewrite)
326 .Case("type", Token::kw_type)
327 .Case("Type", Token::kw_Type)
328 .Case("TypeRange", Token::kw_TypeRange)
329 .Case("Value", Token::kw_Value)
330 .Case("ValueRange", Token::kw_ValueRange)
331 .Case("with", Token::kw_with)
332 .Case("_", Token::underscore)
333 .Default(Token::identifier);
334 return Token(kind, str);
335 }
336
lexNumber(const char * tokStart)337 Token Lexer::lexNumber(const char *tokStart) {
338 assert(isdigit(curPtr[-1]));
339
340 // Handle the normal decimal case.
341 while (isdigit(*curPtr))
342 ++curPtr;
343
344 return formToken(Token::integer, tokStart);
345 }
346
lexString(const char * tokStart,bool isStringBlock)347 Token Lexer::lexString(const char *tokStart, bool isStringBlock) {
348 while (true) {
349 // Check to see if there is a code completion location within the string. In
350 // these cases we generate a completion location and place the currently
351 // lexed string within the token (without the quotes). This allows for the
352 // parser to use the partially lexed string when computing the completion
353 // results.
354 if (curPtr == codeCompletionLocation) {
355 return formToken(Token::code_complete_string,
356 tokStart + (isStringBlock ? 2 : 1));
357 }
358
359 switch (*curPtr++) {
360 case '"':
361 // If this is a string block, we only end the string when we encounter a
362 // `}]`.
363 if (!isStringBlock)
364 return formToken(Token::string, tokStart);
365 continue;
366 case '}':
367 // If this is a string block, we only end the string when we encounter a
368 // `}]`.
369 if (!isStringBlock || *curPtr != ']')
370 continue;
371 ++curPtr;
372 return formToken(Token::string_block, tokStart);
373 case 0: {
374 // If this is a random nul character in the middle of a string, just
375 // include it. If it is the end of file, then it is an error.
376 if (curPtr - 1 != curBuffer.end())
377 continue;
378 --curPtr;
379
380 StringRef expectedEndStr = isStringBlock ? "}]" : "\"";
381 return emitError(curPtr - 1,
382 "expected '" + expectedEndStr + "' in string literal");
383 }
384
385 case '\n':
386 case '\v':
387 case '\f':
388 // String blocks allow multiple lines.
389 if (!isStringBlock)
390 return emitError(curPtr - 1, "expected '\"' in string literal");
391 continue;
392
393 case '\\':
394 // Handle explicitly a few escapes.
395 if (*curPtr == '"' || *curPtr == '\\' || *curPtr == 'n' ||
396 *curPtr == 't') {
397 ++curPtr;
398 } else if (llvm::isHexDigit(*curPtr) && llvm::isHexDigit(curPtr[1])) {
399 // Support \xx for two hex digits.
400 curPtr += 2;
401 } else {
402 return emitError(curPtr - 1, "unknown escape in string literal");
403 }
404 continue;
405
406 default:
407 continue;
408 }
409 }
410 }
411