1 //===- Lexer.cpp ----------------------------------------------------------===// 2 // 3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. 4 // See https://llvm.org/LICENSE.txt for license information. 5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception 6 // 7 //===----------------------------------------------------------------------===// 8 9 #include "Lexer.h" 10 #include "mlir/Support/LogicalResult.h" 11 #include "mlir/Tools/PDLL/AST/Diagnostic.h" 12 #include "mlir/Tools/PDLL/Parser/CodeComplete.h" 13 #include "llvm/ADT/StringExtras.h" 14 #include "llvm/ADT/StringSwitch.h" 15 #include "llvm/Support/SourceMgr.h" 16 17 using namespace mlir; 18 using namespace mlir::pdll; 19 20 //===----------------------------------------------------------------------===// 21 // Token 22 //===----------------------------------------------------------------------===// 23 24 std::string Token::getStringValue() const { 25 assert(getKind() == string || getKind() == string_block || 26 getKind() == code_complete_string); 27 28 // Start by dropping the quotes. 29 StringRef bytes = getSpelling(); 30 if (is(string)) 31 bytes = bytes.drop_front().drop_back(); 32 else if (is(string_block)) 33 bytes = bytes.drop_front(2).drop_back(2); 34 35 std::string result; 36 result.reserve(bytes.size()); 37 for (unsigned i = 0, e = bytes.size(); i != e;) { 38 auto c = bytes[i++]; 39 if (c != '\\') { 40 result.push_back(c); 41 continue; 42 } 43 44 assert(i + 1 <= e && "invalid string should be caught by lexer"); 45 auto c1 = bytes[i++]; 46 switch (c1) { 47 case '"': 48 case '\\': 49 result.push_back(c1); 50 continue; 51 case 'n': 52 result.push_back('\n'); 53 continue; 54 case 't': 55 result.push_back('\t'); 56 continue; 57 default: 58 break; 59 } 60 61 assert(i + 1 <= e && "invalid string should be caught by lexer"); 62 auto c2 = bytes[i++]; 63 64 assert(llvm::isHexDigit(c1) && llvm::isHexDigit(c2) && "invalid escape"); 65 result.push_back((llvm::hexDigitValue(c1) << 4) | llvm::hexDigitValue(c2)); 66 } 67 68 return result; 69 } 70 71 //===----------------------------------------------------------------------===// 72 // Lexer 73 //===----------------------------------------------------------------------===// 74 75 Lexer::Lexer(llvm::SourceMgr &mgr, ast::DiagnosticEngine &diagEngine, 76 CodeCompleteContext *codeCompleteContext) 77 : srcMgr(mgr), diagEngine(diagEngine), addedHandlerToDiagEngine(false), 78 codeCompletionLocation(nullptr) { 79 curBufferID = mgr.getMainFileID(); 80 curBuffer = srcMgr.getMemoryBuffer(curBufferID)->getBuffer(); 81 curPtr = curBuffer.begin(); 82 83 // Set the code completion location if necessary. 84 if (codeCompleteContext) { 85 codeCompletionLocation = 86 codeCompleteContext->getCodeCompleteLoc().getPointer(); 87 } 88 89 // If the diag engine has no handler, add a default that emits to the 90 // SourceMgr. 91 if (!diagEngine.getHandlerFn()) { 92 diagEngine.setHandlerFn([&](const ast::Diagnostic &diag) { 93 srcMgr.PrintMessage(diag.getLocation().Start, diag.getSeverity(), 94 diag.getMessage()); 95 for (const ast::Diagnostic ¬e : diag.getNotes()) 96 srcMgr.PrintMessage(note.getLocation().Start, note.getSeverity(), 97 note.getMessage()); 98 }); 99 addedHandlerToDiagEngine = true; 100 } 101 } 102 103 Lexer::~Lexer() { 104 if (addedHandlerToDiagEngine) diagEngine.setHandlerFn(nullptr); 105 } 106 107 LogicalResult Lexer::pushInclude(StringRef filename, SMRange includeLoc) { 108 std::string includedFile; 109 int bufferID = 110 srcMgr.AddIncludeFile(filename.str(), includeLoc.End, includedFile); 111 if (!bufferID) 112 return failure(); 113 114 curBufferID = bufferID; 115 curBuffer = srcMgr.getMemoryBuffer(curBufferID)->getBuffer(); 116 curPtr = curBuffer.begin(); 117 return success(); 118 } 119 120 Token Lexer::emitError(SMRange loc, const Twine &msg) { 121 diagEngine.emitError(loc, msg); 122 return formToken(Token::error, loc.Start.getPointer()); 123 } 124 Token Lexer::emitErrorAndNote(SMRange loc, const Twine &msg, 125 SMRange noteLoc, const Twine ¬e) { 126 diagEngine.emitError(loc, msg)->attachNote(note, noteLoc); 127 return formToken(Token::error, loc.Start.getPointer()); 128 } 129 Token Lexer::emitError(const char *loc, const Twine &msg) { 130 return emitError(SMRange(SMLoc::getFromPointer(loc), 131 SMLoc::getFromPointer(loc + 1)), 132 msg); 133 } 134 135 int Lexer::getNextChar() { 136 char curChar = *curPtr++; 137 switch (curChar) { 138 default: 139 return static_cast<unsigned char>(curChar); 140 case 0: { 141 // A nul character in the stream is either the end of the current buffer 142 // or a random nul in the file. Disambiguate that here. 143 if (curPtr - 1 != curBuffer.end()) return 0; 144 145 // Otherwise, return end of file. 146 --curPtr; 147 return EOF; 148 } 149 case '\n': 150 case '\r': 151 // Handle the newline character by ignoring it and incrementing the line 152 // count. However, be careful about 'dos style' files with \n\r in them. 153 // Only treat a \n\r or \r\n as a single line. 154 if ((*curPtr == '\n' || (*curPtr == '\r')) && *curPtr != curChar) 155 ++curPtr; 156 return '\n'; 157 } 158 } 159 160 Token Lexer::lexToken() { 161 while (true) { 162 const char *tokStart = curPtr; 163 164 // Check to see if this token is at the code completion location. 165 if (tokStart == codeCompletionLocation) 166 return formToken(Token::code_complete, tokStart); 167 168 // This always consumes at least one character. 169 int curChar = getNextChar(); 170 switch (curChar) { 171 default: 172 // Handle identifiers: [a-zA-Z_] 173 if (isalpha(curChar) || curChar == '_') return lexIdentifier(tokStart); 174 175 // Unknown character, emit an error. 176 return emitError(tokStart, "unexpected character"); 177 case EOF: { 178 // Return EOF denoting the end of lexing. 179 Token eof = formToken(Token::eof, tokStart); 180 181 // Check to see if we are in an included file. 182 SMLoc parentIncludeLoc = srcMgr.getParentIncludeLoc(curBufferID); 183 if (parentIncludeLoc.isValid()) { 184 curBufferID = srcMgr.FindBufferContainingLoc(parentIncludeLoc); 185 curBuffer = srcMgr.getMemoryBuffer(curBufferID)->getBuffer(); 186 curPtr = parentIncludeLoc.getPointer(); 187 } 188 189 return eof; 190 } 191 192 // Lex punctuation. 193 case '-': 194 if (*curPtr == '>') { 195 ++curPtr; 196 return formToken(Token::arrow, tokStart); 197 } 198 return emitError(tokStart, "unexpected character"); 199 case ':': 200 return formToken(Token::colon, tokStart); 201 case ',': 202 return formToken(Token::comma, tokStart); 203 case '.': 204 return formToken(Token::dot, tokStart); 205 case '=': 206 if (*curPtr == '>') { 207 ++curPtr; 208 return formToken(Token::equal_arrow, tokStart); 209 } 210 return formToken(Token::equal, tokStart); 211 case ';': 212 return formToken(Token::semicolon, tokStart); 213 case '[': 214 if (*curPtr == '{') { 215 ++curPtr; 216 return lexString(tokStart, /*isStringBlock=*/true); 217 } 218 return formToken(Token::l_square, tokStart); 219 case ']': 220 return formToken(Token::r_square, tokStart); 221 222 case '<': 223 return formToken(Token::less, tokStart); 224 case '>': 225 return formToken(Token::greater, tokStart); 226 case '{': 227 return formToken(Token::l_brace, tokStart); 228 case '}': 229 return formToken(Token::r_brace, tokStart); 230 case '(': 231 return formToken(Token::l_paren, tokStart); 232 case ')': 233 return formToken(Token::r_paren, tokStart); 234 case '/': 235 if (*curPtr == '/') { 236 lexComment(); 237 continue; 238 } 239 return emitError(tokStart, "unexpected character"); 240 241 // Ignore whitespace characters. 242 case 0: 243 case ' ': 244 case '\t': 245 case '\n': 246 return lexToken(); 247 248 case '#': 249 return lexDirective(tokStart); 250 case '"': 251 return lexString(tokStart, /*isStringBlock=*/false); 252 253 case '0': 254 case '1': 255 case '2': 256 case '3': 257 case '4': 258 case '5': 259 case '6': 260 case '7': 261 case '8': 262 case '9': 263 return lexNumber(tokStart); 264 } 265 } 266 } 267 268 /// Skip a comment line, starting with a '//'. 269 void Lexer::lexComment() { 270 // Advance over the second '/' in a '//' comment. 271 assert(*curPtr == '/'); 272 ++curPtr; 273 274 while (true) { 275 switch (*curPtr++) { 276 case '\n': 277 case '\r': 278 // Newline is end of comment. 279 return; 280 case 0: 281 // If this is the end of the buffer, end the comment. 282 if (curPtr - 1 == curBuffer.end()) { 283 --curPtr; 284 return; 285 } 286 LLVM_FALLTHROUGH; 287 default: 288 // Skip over other characters. 289 break; 290 } 291 } 292 } 293 294 Token Lexer::lexDirective(const char *tokStart) { 295 // Match the rest with an identifier regex: [0-9a-zA-Z_]* 296 while (isalnum(*curPtr) || *curPtr == '_') ++curPtr; 297 298 StringRef str(tokStart, curPtr - tokStart); 299 return Token(Token::directive, str); 300 } 301 302 Token Lexer::lexIdentifier(const char *tokStart) { 303 // Match the rest of the identifier regex: [0-9a-zA-Z_]* 304 while (isalnum(*curPtr) || *curPtr == '_') ++curPtr; 305 306 // Check to see if this identifier is a keyword. 307 StringRef str(tokStart, curPtr - tokStart); 308 Token::Kind kind = StringSwitch<Token::Kind>(str) 309 .Case("attr", Token::kw_attr) 310 .Case("Attr", Token::kw_Attr) 311 .Case("erase", Token::kw_erase) 312 .Case("let", Token::kw_let) 313 .Case("Constraint", Token::kw_Constraint) 314 .Case("op", Token::kw_op) 315 .Case("Op", Token::kw_Op) 316 .Case("OpName", Token::kw_OpName) 317 .Case("Pattern", Token::kw_Pattern) 318 .Case("replace", Token::kw_replace) 319 .Case("return", Token::kw_return) 320 .Case("rewrite", Token::kw_rewrite) 321 .Case("Rewrite", Token::kw_Rewrite) 322 .Case("type", Token::kw_type) 323 .Case("Type", Token::kw_Type) 324 .Case("TypeRange", Token::kw_TypeRange) 325 .Case("Value", Token::kw_Value) 326 .Case("ValueRange", Token::kw_ValueRange) 327 .Case("with", Token::kw_with) 328 .Case("_", Token::underscore) 329 .Default(Token::identifier); 330 return Token(kind, str); 331 } 332 333 Token Lexer::lexNumber(const char *tokStart) { 334 assert(isdigit(curPtr[-1])); 335 336 // Handle the normal decimal case. 337 while (isdigit(*curPtr)) ++curPtr; 338 339 return formToken(Token::integer, tokStart); 340 } 341 342 Token Lexer::lexString(const char *tokStart, bool isStringBlock) { 343 while (true) { 344 // Check to see if there is a code completion location within the string. In 345 // these cases we generate a completion location and place the currently 346 // lexed string within the token (without the quotes). This allows for the 347 // parser to use the partially lexed string when computing the completion 348 // results. 349 if (curPtr == codeCompletionLocation) { 350 return formToken(Token::code_complete_string, 351 tokStart + (isStringBlock ? 2 : 1)); 352 } 353 354 switch (*curPtr++) { 355 case '"': 356 // If this is a string block, we only end the string when we encounter a 357 // `}]`. 358 if (!isStringBlock) 359 return formToken(Token::string, tokStart); 360 continue; 361 case '}': 362 // If this is a string block, we only end the string when we encounter a 363 // `}]`. 364 if (!isStringBlock || *curPtr != ']') 365 continue; 366 ++curPtr; 367 return formToken(Token::string_block, tokStart); 368 case 0: { 369 // If this is a random nul character in the middle of a string, just 370 // include it. If it is the end of file, then it is an error. 371 if (curPtr - 1 != curBuffer.end()) 372 continue; 373 --curPtr; 374 375 StringRef expectedEndStr = isStringBlock ? "}]" : "\""; 376 return emitError(curPtr - 1, 377 "expected '" + expectedEndStr + "' in string literal"); 378 } 379 380 case '\n': 381 case '\v': 382 case '\f': 383 // String blocks allow multiple lines. 384 if (!isStringBlock) 385 return emitError(curPtr - 1, "expected '\"' in string literal"); 386 continue; 387 388 case '\\': 389 // Handle explicitly a few escapes. 390 if (*curPtr == '"' || *curPtr == '\\' || *curPtr == 'n' || 391 *curPtr == 't') { 392 ++curPtr; 393 } else if (llvm::isHexDigit(*curPtr) && llvm::isHexDigit(curPtr[1])) { 394 // Support \xx for two hex digits. 395 curPtr += 2; 396 } else { 397 return emitError(curPtr - 1, "unknown escape in string literal"); 398 } 399 continue; 400 401 default: 402 continue; 403 } 404 } 405 } 406