//===- Lexer.cpp ----------------------------------------------------------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//

#include "Lexer.h"
#include "mlir/Support/LogicalResult.h"
#include "mlir/Tools/PDLL/AST/Diagnostic.h"
#include "mlir/Tools/PDLL/Parser/CodeComplete.h"
#include "llvm/ADT/StringExtras.h"
#include "llvm/ADT/StringSwitch.h"
#include "llvm/Support/SourceMgr.h"

using namespace mlir;
using namespace mlir::pdll;

//===----------------------------------------------------------------------===//
// Token
//===----------------------------------------------------------------------===//

std::string Token::getStringValue() const {
  assert(getKind() == string || getKind() == string_block ||
         getKind() == code_complete_string);

  // Start by dropping the quotes.
  StringRef bytes = getSpelling();
  if (is(string))
    bytes = bytes.drop_front().drop_back();
  else if (is(string_block))
    bytes = bytes.drop_front(2).drop_back(2);

  std::string result;
  result.reserve(bytes.size());
  for (unsigned i = 0, e = bytes.size(); i != e;) {
    auto c = bytes[i++];
    if (c != '\\') {
      result.push_back(c);
      continue;
    }

    assert(i + 1 <= e && "invalid string should be caught by lexer");
    auto c1 = bytes[i++];
    switch (c1) {
      case '"':
      case '\\':
        result.push_back(c1);
        continue;
      case 'n':
        result.push_back('\n');
        continue;
      case 't':
        result.push_back('\t');
        continue;
      default:
        break;
    }

    assert(i + 1 <= e && "invalid string should be caught by lexer");
    auto c2 = bytes[i++];

    assert(llvm::isHexDigit(c1) && llvm::isHexDigit(c2) && "invalid escape");
    result.push_back((llvm::hexDigitValue(c1) << 4) | llvm::hexDigitValue(c2));
  }

  return result;
}

//===----------------------------------------------------------------------===//
// Lexer
//===----------------------------------------------------------------------===//

Lexer::Lexer(llvm::SourceMgr &mgr, ast::DiagnosticEngine &diagEngine,
             CodeCompleteContext *codeCompleteContext)
    : srcMgr(mgr), diagEngine(diagEngine), addedHandlerToDiagEngine(false),
      codeCompletionLocation(nullptr) {
  curBufferID = mgr.getMainFileID();
  curBuffer = srcMgr.getMemoryBuffer(curBufferID)->getBuffer();
  curPtr = curBuffer.begin();

  // Set the code completion location if necessary.
  if (codeCompleteContext) {
    codeCompletionLocation =
        codeCompleteContext->getCodeCompleteLoc().getPointer();
  }

  // If the diag engine has no handler, add a default that emits to the
  // SourceMgr.
  if (!diagEngine.getHandlerFn()) {
    diagEngine.setHandlerFn([&](const ast::Diagnostic &diag) {
      srcMgr.PrintMessage(diag.getLocation().Start, diag.getSeverity(),
                          diag.getMessage());
      for (const ast::Diagnostic &note : diag.getNotes())
        srcMgr.PrintMessage(note.getLocation().Start, note.getSeverity(),
                            note.getMessage());
    });
    addedHandlerToDiagEngine = true;
  }
}

Lexer::~Lexer() {
  if (addedHandlerToDiagEngine) diagEngine.setHandlerFn(nullptr);
}

LogicalResult Lexer::pushInclude(StringRef filename, SMRange includeLoc) {
  std::string includedFile;
  int bufferID =
      srcMgr.AddIncludeFile(filename.str(), includeLoc.End, includedFile);
  if (!bufferID)
    return failure();

  curBufferID = bufferID;
  curBuffer = srcMgr.getMemoryBuffer(curBufferID)->getBuffer();
  curPtr = curBuffer.begin();
  return success();
}

Token Lexer::emitError(SMRange loc, const Twine &msg) {
  diagEngine.emitError(loc, msg);
  return formToken(Token::error, loc.Start.getPointer());
}
Token Lexer::emitErrorAndNote(SMRange loc, const Twine &msg,
                              SMRange noteLoc, const Twine &note) {
  diagEngine.emitError(loc, msg)->attachNote(note, noteLoc);
  return formToken(Token::error, loc.Start.getPointer());
}
Token Lexer::emitError(const char *loc, const Twine &msg) {
  return emitError(SMRange(SMLoc::getFromPointer(loc),
                                 SMLoc::getFromPointer(loc + 1)),
                   msg);
}

int Lexer::getNextChar() {
  char curChar = *curPtr++;
  switch (curChar) {
    default:
      return static_cast<unsigned char>(curChar);
    case 0: {
      // A nul character in the stream is either the end of the current buffer
      // or a random nul in the file. Disambiguate that here.
      if (curPtr - 1 != curBuffer.end()) return 0;

      // Otherwise, return end of file.
      --curPtr;
      return EOF;
    }
    case '\n':
    case '\r':
      // Handle the newline character by ignoring it and incrementing the line
      // count. However, be careful about 'dos style' files with \n\r in them.
      // Only treat a \n\r or \r\n as a single line.
      if ((*curPtr == '\n' || (*curPtr == '\r')) && *curPtr != curChar)
        ++curPtr;
      return '\n';
  }
}

Token Lexer::lexToken() {
  while (true) {
    const char *tokStart = curPtr;

    // Check to see if this token is at the code completion location.
    if (tokStart == codeCompletionLocation)
      return formToken(Token::code_complete, tokStart);

    // This always consumes at least one character.
    int curChar = getNextChar();
    switch (curChar) {
      default:
        // Handle identifiers: [a-zA-Z_]
        if (isalpha(curChar) || curChar == '_') return lexIdentifier(tokStart);

        // Unknown character, emit an error.
        return emitError(tokStart, "unexpected character");
      case EOF: {
        // Return EOF denoting the end of lexing.
        Token eof = formToken(Token::eof, tokStart);

        // Check to see if we are in an included file.
        SMLoc parentIncludeLoc = srcMgr.getParentIncludeLoc(curBufferID);
        if (parentIncludeLoc.isValid()) {
          curBufferID = srcMgr.FindBufferContainingLoc(parentIncludeLoc);
          curBuffer = srcMgr.getMemoryBuffer(curBufferID)->getBuffer();
          curPtr = parentIncludeLoc.getPointer();
        }

        return eof;
      }

      // Lex punctuation.
      case '-':
        if (*curPtr == '>') {
          ++curPtr;
          return formToken(Token::arrow, tokStart);
        }
        return emitError(tokStart, "unexpected character");
      case ':':
        return formToken(Token::colon, tokStart);
      case ',':
        return formToken(Token::comma, tokStart);
      case '.':
        return formToken(Token::dot, tokStart);
      case '=':
        if (*curPtr == '>') {
          ++curPtr;
          return formToken(Token::equal_arrow, tokStart);
        }
        return formToken(Token::equal, tokStart);
      case ';':
        return formToken(Token::semicolon, tokStart);
      case '[':
        if (*curPtr == '{') {
          ++curPtr;
          return lexString(tokStart, /*isStringBlock=*/true);
        }
        return formToken(Token::l_square, tokStart);
      case ']':
        return formToken(Token::r_square, tokStart);

      case '<':
        return formToken(Token::less, tokStart);
      case '>':
        return formToken(Token::greater, tokStart);
      case '{':
        return formToken(Token::l_brace, tokStart);
      case '}':
        return formToken(Token::r_brace, tokStart);
      case '(':
        return formToken(Token::l_paren, tokStart);
      case ')':
        return formToken(Token::r_paren, tokStart);
      case '/':
        if (*curPtr == '/') {
          lexComment();
          continue;
        }
        return emitError(tokStart, "unexpected character");

      // Ignore whitespace characters.
      case 0:
      case ' ':
      case '\t':
      case '\n':
        return lexToken();

      case '#':
        return lexDirective(tokStart);
      case '"':
        return lexString(tokStart, /*isStringBlock=*/false);

      case '0':
      case '1':
      case '2':
      case '3':
      case '4':
      case '5':
      case '6':
      case '7':
      case '8':
      case '9':
        return lexNumber(tokStart);
    }
  }
}

/// Skip a comment line, starting with a '//'.
void Lexer::lexComment() {
  // Advance over the second '/' in a '//' comment.
  assert(*curPtr == '/');
  ++curPtr;

  while (true) {
    switch (*curPtr++) {
      case '\n':
      case '\r':
        // Newline is end of comment.
        return;
      case 0:
        // If this is the end of the buffer, end the comment.
        if (curPtr - 1 == curBuffer.end()) {
          --curPtr;
          return;
        }
        LLVM_FALLTHROUGH;
      default:
        // Skip over other characters.
        break;
    }
  }
}

Token Lexer::lexDirective(const char *tokStart) {
  // Match the rest with an identifier regex: [0-9a-zA-Z_]*
  while (isalnum(*curPtr) || *curPtr == '_') ++curPtr;

  StringRef str(tokStart, curPtr - tokStart);
  return Token(Token::directive, str);
}

Token Lexer::lexIdentifier(const char *tokStart) {
  // Match the rest of the identifier regex: [0-9a-zA-Z_]*
  while (isalnum(*curPtr) || *curPtr == '_') ++curPtr;

  // Check to see if this identifier is a keyword.
  StringRef str(tokStart, curPtr - tokStart);
  Token::Kind kind = StringSwitch<Token::Kind>(str)
                         .Case("attr", Token::kw_attr)
                         .Case("Attr", Token::kw_Attr)
                         .Case("erase", Token::kw_erase)
                         .Case("let", Token::kw_let)
                         .Case("Constraint", Token::kw_Constraint)
                         .Case("op", Token::kw_op)
                         .Case("Op", Token::kw_Op)
                         .Case("OpName", Token::kw_OpName)
                         .Case("Pattern", Token::kw_Pattern)
                         .Case("replace", Token::kw_replace)
                         .Case("return", Token::kw_return)
                         .Case("rewrite", Token::kw_rewrite)
                         .Case("Rewrite", Token::kw_Rewrite)
                         .Case("type", Token::kw_type)
                         .Case("Type", Token::kw_Type)
                         .Case("TypeRange", Token::kw_TypeRange)
                         .Case("Value", Token::kw_Value)
                         .Case("ValueRange", Token::kw_ValueRange)
                         .Case("with", Token::kw_with)
                         .Case("_", Token::underscore)
                         .Default(Token::identifier);
  return Token(kind, str);
}

Token Lexer::lexNumber(const char *tokStart) {
  assert(isdigit(curPtr[-1]));

  // Handle the normal decimal case.
  while (isdigit(*curPtr)) ++curPtr;

  return formToken(Token::integer, tokStart);
}

Token Lexer::lexString(const char *tokStart, bool isStringBlock) {
  while (true) {
    // Check to see if there is a code completion location within the string. In
    // these cases we generate a completion location and place the currently
    // lexed string within the token (without the quotes). This allows for the
    // parser to use the partially lexed string when computing the completion
    // results.
    if (curPtr == codeCompletionLocation) {
      return formToken(Token::code_complete_string,
                       tokStart + (isStringBlock ? 2 : 1));
    }

    switch (*curPtr++) {
      case '"':
        // If this is a string block, we only end the string when we encounter a
        // `}]`.
        if (!isStringBlock)
          return formToken(Token::string, tokStart);
        continue;
      case '}':
        // If this is a string block, we only end the string when we encounter a
        // `}]`.
        if (!isStringBlock || *curPtr != ']')
          continue;
        ++curPtr;
        return formToken(Token::string_block, tokStart);
      case 0: {
        // If this is a random nul character in the middle of a string, just
        // include it. If it is the end of file, then it is an error.
        if (curPtr - 1 != curBuffer.end())
          continue;
        --curPtr;

        StringRef expectedEndStr = isStringBlock ? "}]" : "\"";
        return emitError(curPtr - 1,
                         "expected '" + expectedEndStr + "' in string literal");
      }

      case '\n':
      case '\v':
      case '\f':
        // String blocks allow multiple lines.
        if (!isStringBlock)
          return emitError(curPtr - 1, "expected '\"' in string literal");
        continue;

      case '\\':
        // Handle explicitly a few escapes.
        if (*curPtr == '"' || *curPtr == '\\' || *curPtr == 'n' ||
            *curPtr == 't') {
          ++curPtr;
        } else if (llvm::isHexDigit(*curPtr) && llvm::isHexDigit(curPtr[1])) {
          // Support \xx for two hex digits.
          curPtr += 2;
        } else {
          return emitError(curPtr - 1, "unknown escape in string literal");
        }
        continue;

      default:
        continue;
    }
  }
}
