1 //===- Lexer.cpp ----------------------------------------------------------===// 2 // 3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. 4 // See https://llvm.org/LICENSE.txt for license information. 5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception 6 // 7 //===----------------------------------------------------------------------===// 8 9 #include "Lexer.h" 10 #include "mlir/Support/LogicalResult.h" 11 #include "mlir/Tools/PDLL/AST/Diagnostic.h" 12 #include "mlir/Tools/PDLL/Parser/CodeComplete.h" 13 #include "llvm/ADT/StringExtras.h" 14 #include "llvm/ADT/StringSwitch.h" 15 #include "llvm/Support/SourceMgr.h" 16 17 using namespace mlir; 18 using namespace mlir::pdll; 19 20 //===----------------------------------------------------------------------===// 21 // Token 22 //===----------------------------------------------------------------------===// 23 24 std::string Token::getStringValue() const { 25 assert(getKind() == string || getKind() == string_block); 26 27 // Start by dropping the quotes. 28 StringRef bytes = getSpelling().drop_front().drop_back(); 29 if (is(string_block)) bytes = bytes.drop_front().drop_back(); 30 31 std::string result; 32 result.reserve(bytes.size()); 33 for (unsigned i = 0, e = bytes.size(); i != e;) { 34 auto c = bytes[i++]; 35 if (c != '\\') { 36 result.push_back(c); 37 continue; 38 } 39 40 assert(i + 1 <= e && "invalid string should be caught by lexer"); 41 auto c1 = bytes[i++]; 42 switch (c1) { 43 case '"': 44 case '\\': 45 result.push_back(c1); 46 continue; 47 case 'n': 48 result.push_back('\n'); 49 continue; 50 case 't': 51 result.push_back('\t'); 52 continue; 53 default: 54 break; 55 } 56 57 assert(i + 1 <= e && "invalid string should be caught by lexer"); 58 auto c2 = bytes[i++]; 59 60 assert(llvm::isHexDigit(c1) && llvm::isHexDigit(c2) && "invalid escape"); 61 result.push_back((llvm::hexDigitValue(c1) << 4) | llvm::hexDigitValue(c2)); 62 } 63 64 return result; 65 } 66 67 //===----------------------------------------------------------------------===// 68 // Lexer 69 //===----------------------------------------------------------------------===// 70 71 Lexer::Lexer(llvm::SourceMgr &mgr, ast::DiagnosticEngine &diagEngine, 72 CodeCompleteContext *codeCompleteContext) 73 : srcMgr(mgr), diagEngine(diagEngine), addedHandlerToDiagEngine(false), 74 codeCompletionLocation(nullptr) { 75 curBufferID = mgr.getMainFileID(); 76 curBuffer = srcMgr.getMemoryBuffer(curBufferID)->getBuffer(); 77 curPtr = curBuffer.begin(); 78 79 // Set the code completion location if necessary. 80 if (codeCompleteContext) { 81 codeCompletionLocation = 82 codeCompleteContext->getCodeCompleteLoc().getPointer(); 83 } 84 85 // If the diag engine has no handler, add a default that emits to the 86 // SourceMgr. 87 if (!diagEngine.getHandlerFn()) { 88 diagEngine.setHandlerFn([&](const ast::Diagnostic &diag) { 89 srcMgr.PrintMessage(diag.getLocation().Start, diag.getSeverity(), 90 diag.getMessage()); 91 for (const ast::Diagnostic ¬e : diag.getNotes()) 92 srcMgr.PrintMessage(note.getLocation().Start, note.getSeverity(), 93 note.getMessage()); 94 }); 95 addedHandlerToDiagEngine = true; 96 } 97 } 98 99 Lexer::~Lexer() { 100 if (addedHandlerToDiagEngine) diagEngine.setHandlerFn(nullptr); 101 } 102 103 LogicalResult Lexer::pushInclude(StringRef filename) { 104 std::string includedFile; 105 int bufferID = srcMgr.AddIncludeFile( 106 filename.str(), SMLoc::getFromPointer(curPtr), includedFile); 107 if (!bufferID) return failure(); 108 109 curBufferID = bufferID; 110 curBuffer = srcMgr.getMemoryBuffer(curBufferID)->getBuffer(); 111 curPtr = curBuffer.begin(); 112 return success(); 113 } 114 115 Token Lexer::emitError(SMRange loc, const Twine &msg) { 116 diagEngine.emitError(loc, msg); 117 return formToken(Token::error, loc.Start.getPointer()); 118 } 119 Token Lexer::emitErrorAndNote(SMRange loc, const Twine &msg, 120 SMRange noteLoc, const Twine ¬e) { 121 diagEngine.emitError(loc, msg)->attachNote(note, noteLoc); 122 return formToken(Token::error, loc.Start.getPointer()); 123 } 124 Token Lexer::emitError(const char *loc, const Twine &msg) { 125 return emitError(SMRange(SMLoc::getFromPointer(loc), 126 SMLoc::getFromPointer(loc + 1)), 127 msg); 128 } 129 130 int Lexer::getNextChar() { 131 char curChar = *curPtr++; 132 switch (curChar) { 133 default: 134 return static_cast<unsigned char>(curChar); 135 case 0: { 136 // A nul character in the stream is either the end of the current buffer 137 // or a random nul in the file. Disambiguate that here. 138 if (curPtr - 1 != curBuffer.end()) return 0; 139 140 // Otherwise, return end of file. 141 --curPtr; 142 return EOF; 143 } 144 case '\n': 145 case '\r': 146 // Handle the newline character by ignoring it and incrementing the line 147 // count. However, be careful about 'dos style' files with \n\r in them. 148 // Only treat a \n\r or \r\n as a single line. 149 if ((*curPtr == '\n' || (*curPtr == '\r')) && *curPtr != curChar) 150 ++curPtr; 151 return '\n'; 152 } 153 } 154 155 Token Lexer::lexToken() { 156 while (true) { 157 const char *tokStart = curPtr; 158 159 // Check to see if this token is at the code completion location. 160 if (tokStart == codeCompletionLocation) 161 return formToken(Token::code_complete, tokStart); 162 163 // This always consumes at least one character. 164 int curChar = getNextChar(); 165 switch (curChar) { 166 default: 167 // Handle identifiers: [a-zA-Z_] 168 if (isalpha(curChar) || curChar == '_') return lexIdentifier(tokStart); 169 170 // Unknown character, emit an error. 171 return emitError(tokStart, "unexpected character"); 172 case EOF: { 173 // Return EOF denoting the end of lexing. 174 Token eof = formToken(Token::eof, tokStart); 175 176 // Check to see if we are in an included file. 177 SMLoc parentIncludeLoc = srcMgr.getParentIncludeLoc(curBufferID); 178 if (parentIncludeLoc.isValid()) { 179 curBufferID = srcMgr.FindBufferContainingLoc(parentIncludeLoc); 180 curBuffer = srcMgr.getMemoryBuffer(curBufferID)->getBuffer(); 181 curPtr = parentIncludeLoc.getPointer(); 182 } 183 184 return eof; 185 } 186 187 // Lex punctuation. 188 case '-': 189 if (*curPtr == '>') { 190 ++curPtr; 191 return formToken(Token::arrow, tokStart); 192 } 193 return emitError(tokStart, "unexpected character"); 194 case ':': 195 return formToken(Token::colon, tokStart); 196 case ',': 197 return formToken(Token::comma, tokStart); 198 case '.': 199 return formToken(Token::dot, tokStart); 200 case '=': 201 if (*curPtr == '>') { 202 ++curPtr; 203 return formToken(Token::equal_arrow, tokStart); 204 } 205 return formToken(Token::equal, tokStart); 206 case ';': 207 return formToken(Token::semicolon, tokStart); 208 case '[': 209 if (*curPtr == '{') { 210 ++curPtr; 211 return lexString(tokStart, /*isStringBlock=*/true); 212 } 213 return formToken(Token::l_square, tokStart); 214 case ']': 215 return formToken(Token::r_square, tokStart); 216 217 case '<': 218 return formToken(Token::less, tokStart); 219 case '>': 220 return formToken(Token::greater, tokStart); 221 case '{': 222 return formToken(Token::l_brace, tokStart); 223 case '}': 224 return formToken(Token::r_brace, tokStart); 225 case '(': 226 return formToken(Token::l_paren, tokStart); 227 case ')': 228 return formToken(Token::r_paren, tokStart); 229 case '/': 230 if (*curPtr == '/') { 231 lexComment(); 232 continue; 233 } 234 return emitError(tokStart, "unexpected character"); 235 236 // Ignore whitespace characters. 237 case 0: 238 case ' ': 239 case '\t': 240 case '\n': 241 return lexToken(); 242 243 case '#': 244 return lexDirective(tokStart); 245 case '"': 246 return lexString(tokStart, /*isStringBlock=*/false); 247 248 case '0': 249 case '1': 250 case '2': 251 case '3': 252 case '4': 253 case '5': 254 case '6': 255 case '7': 256 case '8': 257 case '9': 258 return lexNumber(tokStart); 259 } 260 } 261 } 262 263 /// Skip a comment line, starting with a '//'. 264 void Lexer::lexComment() { 265 // Advance over the second '/' in a '//' comment. 266 assert(*curPtr == '/'); 267 ++curPtr; 268 269 while (true) { 270 switch (*curPtr++) { 271 case '\n': 272 case '\r': 273 // Newline is end of comment. 274 return; 275 case 0: 276 // If this is the end of the buffer, end the comment. 277 if (curPtr - 1 == curBuffer.end()) { 278 --curPtr; 279 return; 280 } 281 LLVM_FALLTHROUGH; 282 default: 283 // Skip over other characters. 284 break; 285 } 286 } 287 } 288 289 Token Lexer::lexDirective(const char *tokStart) { 290 // Match the rest with an identifier regex: [0-9a-zA-Z_]* 291 while (isalnum(*curPtr) || *curPtr == '_') ++curPtr; 292 293 StringRef str(tokStart, curPtr - tokStart); 294 return Token(Token::directive, str); 295 } 296 297 Token Lexer::lexIdentifier(const char *tokStart) { 298 // Match the rest of the identifier regex: [0-9a-zA-Z_]* 299 while (isalnum(*curPtr) || *curPtr == '_') ++curPtr; 300 301 // Check to see if this identifier is a keyword. 302 StringRef str(tokStart, curPtr - tokStart); 303 Token::Kind kind = StringSwitch<Token::Kind>(str) 304 .Case("attr", Token::kw_attr) 305 .Case("Attr", Token::kw_Attr) 306 .Case("erase", Token::kw_erase) 307 .Case("let", Token::kw_let) 308 .Case("Constraint", Token::kw_Constraint) 309 .Case("op", Token::kw_op) 310 .Case("Op", Token::kw_Op) 311 .Case("OpName", Token::kw_OpName) 312 .Case("Pattern", Token::kw_Pattern) 313 .Case("replace", Token::kw_replace) 314 .Case("return", Token::kw_return) 315 .Case("rewrite", Token::kw_rewrite) 316 .Case("Rewrite", Token::kw_Rewrite) 317 .Case("type", Token::kw_type) 318 .Case("Type", Token::kw_Type) 319 .Case("TypeRange", Token::kw_TypeRange) 320 .Case("Value", Token::kw_Value) 321 .Case("ValueRange", Token::kw_ValueRange) 322 .Case("with", Token::kw_with) 323 .Case("_", Token::underscore) 324 .Default(Token::identifier); 325 return Token(kind, str); 326 } 327 328 Token Lexer::lexNumber(const char *tokStart) { 329 assert(isdigit(curPtr[-1])); 330 331 // Handle the normal decimal case. 332 while (isdigit(*curPtr)) ++curPtr; 333 334 return formToken(Token::integer, tokStart); 335 } 336 337 Token Lexer::lexString(const char *tokStart, bool isStringBlock) { 338 while (true) { 339 switch (*curPtr++) { 340 case '"': 341 // If this is a string block, we only end the string when we encounter a 342 // `}]`. 343 if (!isStringBlock) return formToken(Token::string, tokStart); 344 continue; 345 case '}': 346 // If this is a string block, we only end the string when we encounter a 347 // `}]`. 348 if (!isStringBlock || *curPtr != ']') continue; 349 ++curPtr; 350 return formToken(Token::string_block, tokStart); 351 case 0: 352 // If this is a random nul character in the middle of a string, just 353 // include it. If it is the end of file, then it is an error. 354 if (curPtr - 1 != curBuffer.end()) continue; 355 LLVM_FALLTHROUGH; 356 case '\n': 357 case '\v': 358 case '\f': 359 // String blocks allow multiple lines. 360 if (!isStringBlock) 361 return emitError(curPtr - 1, "expected '\"' in string literal"); 362 continue; 363 364 case '\\': 365 // Handle explicitly a few escapes. 366 if (*curPtr == '"' || *curPtr == '\\' || *curPtr == 'n' || 367 *curPtr == 't') { 368 ++curPtr; 369 } else if (llvm::isHexDigit(*curPtr) && llvm::isHexDigit(curPtr[1])) { 370 // Support \xx for two hex digits. 371 curPtr += 2; 372 } else { 373 return emitError(curPtr - 1, "unknown escape in string literal"); 374 } 375 continue; 376 377 default: 378 continue; 379 } 380 } 381 } 382