1 //===- Lexer.cpp ----------------------------------------------------------===// 2 // 3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. 4 // See https://llvm.org/LICENSE.txt for license information. 5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception 6 // 7 //===----------------------------------------------------------------------===// 8 9 #include "Lexer.h" 10 #include "mlir/Support/LogicalResult.h" 11 #include "mlir/Tools/PDLL/AST/Diagnostic.h" 12 #include "mlir/Tools/PDLL/Parser/CodeComplete.h" 13 #include "llvm/ADT/StringExtras.h" 14 #include "llvm/ADT/StringSwitch.h" 15 #include "llvm/Support/SourceMgr.h" 16 17 using namespace mlir; 18 using namespace mlir::pdll; 19 20 //===----------------------------------------------------------------------===// 21 // Token 22 //===----------------------------------------------------------------------===// 23 24 std::string Token::getStringValue() const { 25 assert(getKind() == string || getKind() == string_block || 26 getKind() == code_complete_string); 27 28 // Start by dropping the quotes. 29 StringRef bytes = getSpelling(); 30 if (is(string)) 31 bytes = bytes.drop_front().drop_back(); 32 else if (is(string_block)) 33 bytes = bytes.drop_front(2).drop_back(2); 34 35 std::string result; 36 result.reserve(bytes.size()); 37 for (unsigned i = 0, e = bytes.size(); i != e;) { 38 auto c = bytes[i++]; 39 if (c != '\\') { 40 result.push_back(c); 41 continue; 42 } 43 44 assert(i + 1 <= e && "invalid string should be caught by lexer"); 45 auto c1 = bytes[i++]; 46 switch (c1) { 47 case '"': 48 case '\\': 49 result.push_back(c1); 50 continue; 51 case 'n': 52 result.push_back('\n'); 53 continue; 54 case 't': 55 result.push_back('\t'); 56 continue; 57 default: 58 break; 59 } 60 61 assert(i + 1 <= e && "invalid string should be caught by lexer"); 62 auto c2 = bytes[i++]; 63 64 assert(llvm::isHexDigit(c1) && llvm::isHexDigit(c2) && "invalid escape"); 65 result.push_back((llvm::hexDigitValue(c1) << 4) | llvm::hexDigitValue(c2)); 66 } 67 68 return result; 69 } 70 71 //===----------------------------------------------------------------------===// 72 // Lexer 73 //===----------------------------------------------------------------------===// 74 75 Lexer::Lexer(llvm::SourceMgr &mgr, ast::DiagnosticEngine &diagEngine, 76 CodeCompleteContext *codeCompleteContext) 77 : srcMgr(mgr), diagEngine(diagEngine), addedHandlerToDiagEngine(false), 78 codeCompletionLocation(nullptr) { 79 curBufferID = mgr.getMainFileID(); 80 curBuffer = srcMgr.getMemoryBuffer(curBufferID)->getBuffer(); 81 curPtr = curBuffer.begin(); 82 83 // Set the code completion location if necessary. 84 if (codeCompleteContext) { 85 codeCompletionLocation = 86 codeCompleteContext->getCodeCompleteLoc().getPointer(); 87 } 88 89 // If the diag engine has no handler, add a default that emits to the 90 // SourceMgr. 91 if (!diagEngine.getHandlerFn()) { 92 diagEngine.setHandlerFn([&](const ast::Diagnostic &diag) { 93 srcMgr.PrintMessage(diag.getLocation().Start, diag.getSeverity(), 94 diag.getMessage()); 95 for (const ast::Diagnostic ¬e : diag.getNotes()) 96 srcMgr.PrintMessage(note.getLocation().Start, note.getSeverity(), 97 note.getMessage()); 98 }); 99 addedHandlerToDiagEngine = true; 100 } 101 } 102 103 Lexer::~Lexer() { 104 if (addedHandlerToDiagEngine) 105 diagEngine.setHandlerFn(nullptr); 106 } 107 108 LogicalResult Lexer::pushInclude(StringRef filename, SMRange includeLoc) { 109 std::string includedFile; 110 int bufferID = 111 srcMgr.AddIncludeFile(filename.str(), includeLoc.End, includedFile); 112 if (!bufferID) 113 return failure(); 114 115 curBufferID = bufferID; 116 curBuffer = srcMgr.getMemoryBuffer(curBufferID)->getBuffer(); 117 curPtr = curBuffer.begin(); 118 return success(); 119 } 120 121 Token Lexer::emitError(SMRange loc, const Twine &msg) { 122 diagEngine.emitError(loc, msg); 123 return formToken(Token::error, loc.Start.getPointer()); 124 } 125 Token Lexer::emitErrorAndNote(SMRange loc, const Twine &msg, SMRange noteLoc, 126 const Twine ¬e) { 127 diagEngine.emitError(loc, msg)->attachNote(note, noteLoc); 128 return formToken(Token::error, loc.Start.getPointer()); 129 } 130 Token Lexer::emitError(const char *loc, const Twine &msg) { 131 return emitError( 132 SMRange(SMLoc::getFromPointer(loc), SMLoc::getFromPointer(loc + 1)), msg); 133 } 134 135 int Lexer::getNextChar() { 136 char curChar = *curPtr++; 137 switch (curChar) { 138 default: 139 return static_cast<unsigned char>(curChar); 140 case 0: { 141 // A nul character in the stream is either the end of the current buffer 142 // or a random nul in the file. Disambiguate that here. 143 if (curPtr - 1 != curBuffer.end()) 144 return 0; 145 146 // Otherwise, return end of file. 147 --curPtr; 148 return EOF; 149 } 150 case '\n': 151 case '\r': 152 // Handle the newline character by ignoring it and incrementing the line 153 // count. However, be careful about 'dos style' files with \n\r in them. 154 // Only treat a \n\r or \r\n as a single line. 155 if ((*curPtr == '\n' || (*curPtr == '\r')) && *curPtr != curChar) 156 ++curPtr; 157 return '\n'; 158 } 159 } 160 161 Token Lexer::lexToken() { 162 while (true) { 163 const char *tokStart = curPtr; 164 165 // Check to see if this token is at the code completion location. 166 if (tokStart == codeCompletionLocation) 167 return formToken(Token::code_complete, tokStart); 168 169 // This always consumes at least one character. 170 int curChar = getNextChar(); 171 switch (curChar) { 172 default: 173 // Handle identifiers: [a-zA-Z_] 174 if (isalpha(curChar) || curChar == '_') 175 return lexIdentifier(tokStart); 176 177 // Unknown character, emit an error. 178 return emitError(tokStart, "unexpected character"); 179 case EOF: { 180 // Return EOF denoting the end of lexing. 181 Token eof = formToken(Token::eof, tokStart); 182 183 // Check to see if we are in an included file. 184 SMLoc parentIncludeLoc = srcMgr.getParentIncludeLoc(curBufferID); 185 if (parentIncludeLoc.isValid()) { 186 curBufferID = srcMgr.FindBufferContainingLoc(parentIncludeLoc); 187 curBuffer = srcMgr.getMemoryBuffer(curBufferID)->getBuffer(); 188 curPtr = parentIncludeLoc.getPointer(); 189 } 190 191 return eof; 192 } 193 194 // Lex punctuation. 195 case '-': 196 if (*curPtr == '>') { 197 ++curPtr; 198 return formToken(Token::arrow, tokStart); 199 } 200 return emitError(tokStart, "unexpected character"); 201 case ':': 202 return formToken(Token::colon, tokStart); 203 case ',': 204 return formToken(Token::comma, tokStart); 205 case '.': 206 return formToken(Token::dot, tokStart); 207 case '=': 208 if (*curPtr == '>') { 209 ++curPtr; 210 return formToken(Token::equal_arrow, tokStart); 211 } 212 return formToken(Token::equal, tokStart); 213 case ';': 214 return formToken(Token::semicolon, tokStart); 215 case '[': 216 if (*curPtr == '{') { 217 ++curPtr; 218 return lexString(tokStart, /*isStringBlock=*/true); 219 } 220 return formToken(Token::l_square, tokStart); 221 case ']': 222 return formToken(Token::r_square, tokStart); 223 224 case '<': 225 return formToken(Token::less, tokStart); 226 case '>': 227 return formToken(Token::greater, tokStart); 228 case '{': 229 return formToken(Token::l_brace, tokStart); 230 case '}': 231 return formToken(Token::r_brace, tokStart); 232 case '(': 233 return formToken(Token::l_paren, tokStart); 234 case ')': 235 return formToken(Token::r_paren, tokStart); 236 case '/': 237 if (*curPtr == '/') { 238 lexComment(); 239 continue; 240 } 241 return emitError(tokStart, "unexpected character"); 242 243 // Ignore whitespace characters. 244 case 0: 245 case ' ': 246 case '\t': 247 case '\n': 248 return lexToken(); 249 250 case '#': 251 return lexDirective(tokStart); 252 case '"': 253 return lexString(tokStart, /*isStringBlock=*/false); 254 255 case '0': 256 case '1': 257 case '2': 258 case '3': 259 case '4': 260 case '5': 261 case '6': 262 case '7': 263 case '8': 264 case '9': 265 return lexNumber(tokStart); 266 } 267 } 268 } 269 270 /// Skip a comment line, starting with a '//'. 271 void Lexer::lexComment() { 272 // Advance over the second '/' in a '//' comment. 273 assert(*curPtr == '/'); 274 ++curPtr; 275 276 while (true) { 277 switch (*curPtr++) { 278 case '\n': 279 case '\r': 280 // Newline is end of comment. 281 return; 282 case 0: 283 // If this is the end of the buffer, end the comment. 284 if (curPtr - 1 == curBuffer.end()) { 285 --curPtr; 286 return; 287 } 288 LLVM_FALLTHROUGH; 289 default: 290 // Skip over other characters. 291 break; 292 } 293 } 294 } 295 296 Token Lexer::lexDirective(const char *tokStart) { 297 // Match the rest with an identifier regex: [0-9a-zA-Z_]* 298 while (isalnum(*curPtr) || *curPtr == '_') 299 ++curPtr; 300 301 StringRef str(tokStart, curPtr - tokStart); 302 return Token(Token::directive, str); 303 } 304 305 Token Lexer::lexIdentifier(const char *tokStart) { 306 // Match the rest of the identifier regex: [0-9a-zA-Z_]* 307 while (isalnum(*curPtr) || *curPtr == '_') 308 ++curPtr; 309 310 // Check to see if this identifier is a keyword. 311 StringRef str(tokStart, curPtr - tokStart); 312 Token::Kind kind = StringSwitch<Token::Kind>(str) 313 .Case("attr", Token::kw_attr) 314 .Case("Attr", Token::kw_Attr) 315 .Case("erase", Token::kw_erase) 316 .Case("let", Token::kw_let) 317 .Case("Constraint", Token::kw_Constraint) 318 .Case("op", Token::kw_op) 319 .Case("Op", Token::kw_Op) 320 .Case("OpName", Token::kw_OpName) 321 .Case("Pattern", Token::kw_Pattern) 322 .Case("replace", Token::kw_replace) 323 .Case("return", Token::kw_return) 324 .Case("rewrite", Token::kw_rewrite) 325 .Case("Rewrite", Token::kw_Rewrite) 326 .Case("type", Token::kw_type) 327 .Case("Type", Token::kw_Type) 328 .Case("TypeRange", Token::kw_TypeRange) 329 .Case("Value", Token::kw_Value) 330 .Case("ValueRange", Token::kw_ValueRange) 331 .Case("with", Token::kw_with) 332 .Case("_", Token::underscore) 333 .Default(Token::identifier); 334 return Token(kind, str); 335 } 336 337 Token Lexer::lexNumber(const char *tokStart) { 338 assert(isdigit(curPtr[-1])); 339 340 // Handle the normal decimal case. 341 while (isdigit(*curPtr)) 342 ++curPtr; 343 344 return formToken(Token::integer, tokStart); 345 } 346 347 Token Lexer::lexString(const char *tokStart, bool isStringBlock) { 348 while (true) { 349 // Check to see if there is a code completion location within the string. In 350 // these cases we generate a completion location and place the currently 351 // lexed string within the token (without the quotes). This allows for the 352 // parser to use the partially lexed string when computing the completion 353 // results. 354 if (curPtr == codeCompletionLocation) { 355 return formToken(Token::code_complete_string, 356 tokStart + (isStringBlock ? 2 : 1)); 357 } 358 359 switch (*curPtr++) { 360 case '"': 361 // If this is a string block, we only end the string when we encounter a 362 // `}]`. 363 if (!isStringBlock) 364 return formToken(Token::string, tokStart); 365 continue; 366 case '}': 367 // If this is a string block, we only end the string when we encounter a 368 // `}]`. 369 if (!isStringBlock || *curPtr != ']') 370 continue; 371 ++curPtr; 372 return formToken(Token::string_block, tokStart); 373 case 0: { 374 // If this is a random nul character in the middle of a string, just 375 // include it. If it is the end of file, then it is an error. 376 if (curPtr - 1 != curBuffer.end()) 377 continue; 378 --curPtr; 379 380 StringRef expectedEndStr = isStringBlock ? "}]" : "\""; 381 return emitError(curPtr - 1, 382 "expected '" + expectedEndStr + "' in string literal"); 383 } 384 385 case '\n': 386 case '\v': 387 case '\f': 388 // String blocks allow multiple lines. 389 if (!isStringBlock) 390 return emitError(curPtr - 1, "expected '\"' in string literal"); 391 continue; 392 393 case '\\': 394 // Handle explicitly a few escapes. 395 if (*curPtr == '"' || *curPtr == '\\' || *curPtr == 'n' || 396 *curPtr == 't') { 397 ++curPtr; 398 } else if (llvm::isHexDigit(*curPtr) && llvm::isHexDigit(curPtr[1])) { 399 // Support \xx for two hex digits. 400 curPtr += 2; 401 } else { 402 return emitError(curPtr - 1, "unknown escape in string literal"); 403 } 404 continue; 405 406 default: 407 continue; 408 } 409 } 410 } 411