1 //===- Lexer.cpp ----------------------------------------------------------===// 2 // 3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. 4 // See https://llvm.org/LICENSE.txt for license information. 5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception 6 // 7 //===----------------------------------------------------------------------===// 8 9 #include "Lexer.h" 10 #include "mlir/Support/LogicalResult.h" 11 #include "mlir/Tools/PDLL/AST/Diagnostic.h" 12 #include "llvm/ADT/StringExtras.h" 13 #include "llvm/ADT/StringSwitch.h" 14 #include "llvm/Support/SourceMgr.h" 15 16 using namespace mlir; 17 using namespace mlir::pdll; 18 19 //===----------------------------------------------------------------------===// 20 // Token 21 //===----------------------------------------------------------------------===// 22 23 std::string Token::getStringValue() const { 24 assert(getKind() == string || getKind() == string_block); 25 26 // Start by dropping the quotes. 27 StringRef bytes = getSpelling().drop_front().drop_back(); 28 if (is(string_block)) bytes = bytes.drop_front().drop_back(); 29 30 std::string result; 31 result.reserve(bytes.size()); 32 for (unsigned i = 0, e = bytes.size(); i != e;) { 33 auto c = bytes[i++]; 34 if (c != '\\') { 35 result.push_back(c); 36 continue; 37 } 38 39 assert(i + 1 <= e && "invalid string should be caught by lexer"); 40 auto c1 = bytes[i++]; 41 switch (c1) { 42 case '"': 43 case '\\': 44 result.push_back(c1); 45 continue; 46 case 'n': 47 result.push_back('\n'); 48 continue; 49 case 't': 50 result.push_back('\t'); 51 continue; 52 default: 53 break; 54 } 55 56 assert(i + 1 <= e && "invalid string should be caught by lexer"); 57 auto c2 = bytes[i++]; 58 59 assert(llvm::isHexDigit(c1) && llvm::isHexDigit(c2) && "invalid escape"); 60 result.push_back((llvm::hexDigitValue(c1) << 4) | llvm::hexDigitValue(c2)); 61 } 62 63 return result; 64 } 65 66 //===----------------------------------------------------------------------===// 67 // Lexer 68 //===----------------------------------------------------------------------===// 69 70 Lexer::Lexer(llvm::SourceMgr &mgr, ast::DiagnosticEngine &diagEngine) 71 : srcMgr(mgr), diagEngine(diagEngine), addedHandlerToDiagEngine(false) { 72 curBufferID = mgr.getMainFileID(); 73 curBuffer = srcMgr.getMemoryBuffer(curBufferID)->getBuffer(); 74 curPtr = curBuffer.begin(); 75 76 // If the diag engine has no handler, add a default that emits to the 77 // SourceMgr. 78 if (!diagEngine.getHandlerFn()) { 79 diagEngine.setHandlerFn([&](const ast::Diagnostic &diag) { 80 srcMgr.PrintMessage(diag.getLocation().Start, diag.getSeverity(), 81 diag.getMessage()); 82 for (const ast::Diagnostic ¬e : diag.getNotes()) 83 srcMgr.PrintMessage(note.getLocation().Start, note.getSeverity(), 84 note.getMessage()); 85 }); 86 addedHandlerToDiagEngine = true; 87 } 88 } 89 90 Lexer::~Lexer() { 91 if (addedHandlerToDiagEngine) diagEngine.setHandlerFn(nullptr); 92 } 93 94 LogicalResult Lexer::pushInclude(StringRef filename) { 95 std::string includedFile; 96 int bufferID = srcMgr.AddIncludeFile( 97 filename.str(), SMLoc::getFromPointer(curPtr), includedFile); 98 if (!bufferID) return failure(); 99 100 curBufferID = bufferID; 101 curBuffer = srcMgr.getMemoryBuffer(curBufferID)->getBuffer(); 102 curPtr = curBuffer.begin(); 103 return success(); 104 } 105 106 Token Lexer::emitError(SMRange loc, const Twine &msg) { 107 diagEngine.emitError(loc, msg); 108 return formToken(Token::error, loc.Start.getPointer()); 109 } 110 Token Lexer::emitErrorAndNote(SMRange loc, const Twine &msg, 111 SMRange noteLoc, const Twine ¬e) { 112 diagEngine.emitError(loc, msg)->attachNote(note, noteLoc); 113 return formToken(Token::error, loc.Start.getPointer()); 114 } 115 Token Lexer::emitError(const char *loc, const Twine &msg) { 116 return emitError(SMRange(SMLoc::getFromPointer(loc), 117 SMLoc::getFromPointer(loc + 1)), 118 msg); 119 } 120 121 int Lexer::getNextChar() { 122 char curChar = *curPtr++; 123 switch (curChar) { 124 default: 125 return static_cast<unsigned char>(curChar); 126 case 0: { 127 // A nul character in the stream is either the end of the current buffer 128 // or a random nul in the file. Disambiguate that here. 129 if (curPtr - 1 != curBuffer.end()) return 0; 130 131 // Otherwise, return end of file. 132 --curPtr; 133 return EOF; 134 } 135 case '\n': 136 case '\r': 137 // Handle the newline character by ignoring it and incrementing the line 138 // count. However, be careful about 'dos style' files with \n\r in them. 139 // Only treat a \n\r or \r\n as a single line. 140 if ((*curPtr == '\n' || (*curPtr == '\r')) && *curPtr != curChar) 141 ++curPtr; 142 return '\n'; 143 } 144 } 145 146 Token Lexer::lexToken() { 147 while (true) { 148 const char *tokStart = curPtr; 149 150 // This always consumes at least one character. 151 int curChar = getNextChar(); 152 switch (curChar) { 153 default: 154 // Handle identifiers: [a-zA-Z_] 155 if (isalpha(curChar) || curChar == '_') return lexIdentifier(tokStart); 156 157 // Unknown character, emit an error. 158 return emitError(tokStart, "unexpected character"); 159 case EOF: { 160 // Return EOF denoting the end of lexing. 161 Token eof = formToken(Token::eof, tokStart); 162 163 // Check to see if we are in an included file. 164 SMLoc parentIncludeLoc = srcMgr.getParentIncludeLoc(curBufferID); 165 if (parentIncludeLoc.isValid()) { 166 curBufferID = srcMgr.FindBufferContainingLoc(parentIncludeLoc); 167 curBuffer = srcMgr.getMemoryBuffer(curBufferID)->getBuffer(); 168 curPtr = parentIncludeLoc.getPointer(); 169 } 170 171 return eof; 172 } 173 174 // Lex punctuation. 175 case '-': 176 if (*curPtr == '>') { 177 ++curPtr; 178 return formToken(Token::arrow, tokStart); 179 } 180 return emitError(tokStart, "unexpected character"); 181 case ':': 182 return formToken(Token::colon, tokStart); 183 case ',': 184 return formToken(Token::comma, tokStart); 185 case '.': 186 return formToken(Token::dot, tokStart); 187 case '=': 188 if (*curPtr == '>') { 189 ++curPtr; 190 return formToken(Token::equal_arrow, tokStart); 191 } 192 return formToken(Token::equal, tokStart); 193 case ';': 194 return formToken(Token::semicolon, tokStart); 195 case '[': 196 if (*curPtr == '{') { 197 ++curPtr; 198 return lexString(tokStart, /*isStringBlock=*/true); 199 } 200 return formToken(Token::l_square, tokStart); 201 case ']': 202 return formToken(Token::r_square, tokStart); 203 204 case '<': 205 return formToken(Token::less, tokStart); 206 case '>': 207 return formToken(Token::greater, tokStart); 208 case '{': 209 return formToken(Token::l_brace, tokStart); 210 case '}': 211 return formToken(Token::r_brace, tokStart); 212 case '(': 213 return formToken(Token::l_paren, tokStart); 214 case ')': 215 return formToken(Token::r_paren, tokStart); 216 case '/': 217 if (*curPtr == '/') { 218 lexComment(); 219 continue; 220 } 221 return emitError(tokStart, "unexpected character"); 222 223 // Ignore whitespace characters. 224 case 0: 225 case ' ': 226 case '\t': 227 case '\n': 228 return lexToken(); 229 230 case '#': 231 return lexDirective(tokStart); 232 case '"': 233 return lexString(tokStart, /*isStringBlock=*/false); 234 235 case '0': 236 case '1': 237 case '2': 238 case '3': 239 case '4': 240 case '5': 241 case '6': 242 case '7': 243 case '8': 244 case '9': 245 return lexNumber(tokStart); 246 } 247 } 248 } 249 250 /// Skip a comment line, starting with a '//'. 251 void Lexer::lexComment() { 252 // Advance over the second '/' in a '//' comment. 253 assert(*curPtr == '/'); 254 ++curPtr; 255 256 while (true) { 257 switch (*curPtr++) { 258 case '\n': 259 case '\r': 260 // Newline is end of comment. 261 return; 262 case 0: 263 // If this is the end of the buffer, end the comment. 264 if (curPtr - 1 == curBuffer.end()) { 265 --curPtr; 266 return; 267 } 268 LLVM_FALLTHROUGH; 269 default: 270 // Skip over other characters. 271 break; 272 } 273 } 274 } 275 276 Token Lexer::lexDirective(const char *tokStart) { 277 // Match the rest with an identifier regex: [0-9a-zA-Z_]* 278 while (isalnum(*curPtr) || *curPtr == '_') ++curPtr; 279 280 StringRef str(tokStart, curPtr - tokStart); 281 return Token(Token::directive, str); 282 } 283 284 Token Lexer::lexIdentifier(const char *tokStart) { 285 // Match the rest of the identifier regex: [0-9a-zA-Z_]* 286 while (isalnum(*curPtr) || *curPtr == '_') ++curPtr; 287 288 // Check to see if this identifier is a keyword. 289 StringRef str(tokStart, curPtr - tokStart); 290 Token::Kind kind = StringSwitch<Token::Kind>(str) 291 .Case("attr", Token::kw_attr) 292 .Case("Attr", Token::kw_Attr) 293 .Case("erase", Token::kw_erase) 294 .Case("let", Token::kw_let) 295 .Case("Constraint", Token::kw_Constraint) 296 .Case("op", Token::kw_op) 297 .Case("Op", Token::kw_Op) 298 .Case("OpName", Token::kw_OpName) 299 .Case("Pattern", Token::kw_Pattern) 300 .Case("replace", Token::kw_replace) 301 .Case("return", Token::kw_return) 302 .Case("rewrite", Token::kw_rewrite) 303 .Case("Rewrite", Token::kw_Rewrite) 304 .Case("type", Token::kw_type) 305 .Case("Type", Token::kw_Type) 306 .Case("TypeRange", Token::kw_TypeRange) 307 .Case("Value", Token::kw_Value) 308 .Case("ValueRange", Token::kw_ValueRange) 309 .Case("with", Token::kw_with) 310 .Case("_", Token::underscore) 311 .Default(Token::identifier); 312 return Token(kind, str); 313 } 314 315 Token Lexer::lexNumber(const char *tokStart) { 316 assert(isdigit(curPtr[-1])); 317 318 // Handle the normal decimal case. 319 while (isdigit(*curPtr)) ++curPtr; 320 321 return formToken(Token::integer, tokStart); 322 } 323 324 Token Lexer::lexString(const char *tokStart, bool isStringBlock) { 325 while (true) { 326 switch (*curPtr++) { 327 case '"': 328 // If this is a string block, we only end the string when we encounter a 329 // `}]`. 330 if (!isStringBlock) return formToken(Token::string, tokStart); 331 continue; 332 case '}': 333 // If this is a string block, we only end the string when we encounter a 334 // `}]`. 335 if (!isStringBlock || *curPtr != ']') continue; 336 ++curPtr; 337 return formToken(Token::string_block, tokStart); 338 case 0: 339 // If this is a random nul character in the middle of a string, just 340 // include it. If it is the end of file, then it is an error. 341 if (curPtr - 1 != curBuffer.end()) continue; 342 LLVM_FALLTHROUGH; 343 case '\n': 344 case '\v': 345 case '\f': 346 // String blocks allow multiple lines. 347 if (!isStringBlock) 348 return emitError(curPtr - 1, "expected '\"' in string literal"); 349 continue; 350 351 case '\\': 352 // Handle explicitly a few escapes. 353 if (*curPtr == '"' || *curPtr == '\\' || *curPtr == 'n' || 354 *curPtr == 't') { 355 ++curPtr; 356 } else if (llvm::isHexDigit(*curPtr) && llvm::isHexDigit(curPtr[1])) { 357 // Support \xx for two hex digits. 358 curPtr += 2; 359 } else { 360 return emitError(curPtr - 1, "unknown escape in string literal"); 361 } 362 continue; 363 364 default: 365 continue; 366 } 367 } 368 } 369