1 //===--- JSONTransport.cpp - sending and receiving LSP messages over JSON -===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 
#include "Transport.h"
#include "Logging.h"
#include "Protocol.h"
#include "llvm/ADT/SmallString.h"
#include "llvm/Support/Errno.h"
#include "llvm/Support/Error.h"
#include <atomic>
#include <string>
#include <system_error>
#include <utility>
17 
18 using namespace mlir;
19 using namespace mlir::lsp;
20 
21 //===----------------------------------------------------------------------===//
22 // Reply
23 //===----------------------------------------------------------------------===//
24 
25 namespace {
26 /// Function object to reply to an LSP call.
27 /// Each instance must be called exactly once, otherwise:
28 ///  - if there was no reply, an error reply is sent
29 ///  - if there were multiple replies, only the first is sent
30 class Reply {
31 public:
32   Reply(const llvm::json::Value &id, StringRef method,
33         JSONTransport &transport);
34   Reply(Reply &&other);
35   Reply &operator=(Reply &&) = delete;
36   Reply(const Reply &) = delete;
37   Reply &operator=(const Reply &) = delete;
38 
39   void operator()(llvm::Expected<llvm::json::Value> reply);
40 
41 private:
42   StringRef method;
43   std::atomic<bool> replied = {false};
44   llvm::json::Value id;
45   JSONTransport *transport;
46 };
47 } // namespace
48 
Reply(const llvm::json::Value & id,llvm::StringRef method,JSONTransport & transport)49 Reply::Reply(const llvm::json::Value &id, llvm::StringRef method,
50              JSONTransport &transport)
51     : id(id), transport(&transport) {}
52 
Reply(Reply && other)53 Reply::Reply(Reply &&other)
54     : replied(other.replied.load()), id(std::move(other.id)),
55       transport(other.transport) {
56   other.transport = nullptr;
57 }
58 
operator ()(llvm::Expected<llvm::json::Value> reply)59 void Reply::operator()(llvm::Expected<llvm::json::Value> reply) {
60   if (replied.exchange(true)) {
61     Logger::error("Replied twice to message {0}({1})", method, id);
62     assert(false && "must reply to each call only once!");
63     return;
64   }
65   assert(transport && "expected valid transport to reply to");
66 
67   if (reply) {
68     Logger::info("--> reply:{0}({1})", method, id);
69     transport->reply(std::move(id), std::move(reply));
70   } else {
71     llvm::Error error = reply.takeError();
72     Logger::info("--> reply:{0}({1})", method, id, error);
73     transport->reply(std::move(id), std::move(error));
74   }
75 }
76 
77 //===----------------------------------------------------------------------===//
78 // MessageHandler
79 //===----------------------------------------------------------------------===//
80 
onNotify(llvm::StringRef method,llvm::json::Value value)81 bool MessageHandler::onNotify(llvm::StringRef method, llvm::json::Value value) {
82   Logger::info("--> {0}", method);
83 
84   if (method == "exit")
85     return false;
86   if (method == "$cancel") {
87     // TODO: Add support for cancelling requests.
88   } else {
89     auto it = notificationHandlers.find(method);
90     if (it != notificationHandlers.end())
91       it->second(std::move(value));
92   }
93   return true;
94 }
95 
onCall(llvm::StringRef method,llvm::json::Value params,llvm::json::Value id)96 bool MessageHandler::onCall(llvm::StringRef method, llvm::json::Value params,
97                             llvm::json::Value id) {
98   Logger::info("--> {0}({1})", method, id);
99 
100   Reply reply(id, method, transport);
101 
102   auto it = methodHandlers.find(method);
103   if (it != methodHandlers.end()) {
104     it->second(std::move(params), std::move(reply));
105   } else {
106     reply(llvm::make_error<LSPError>("method not found: " + method.str(),
107                                      ErrorCode::MethodNotFound));
108   }
109   return true;
110 }
111 
onReply(llvm::json::Value id,llvm::Expected<llvm::json::Value> result)112 bool MessageHandler::onReply(llvm::json::Value id,
113                              llvm::Expected<llvm::json::Value> result) {
114   // TODO: Add support for reply callbacks when support for outgoing messages is
115   // added. For now, we just log an error on any replies received.
116   Callback<llvm::json::Value> replyHandler =
117       [&id](llvm::Expected<llvm::json::Value> result) {
118         Logger::error(
119             "received a reply with ID {0}, but there was no such call", id);
120         if (!result)
121           llvm::consumeError(result.takeError());
122       };
123 
124   // Log and run the reply handler.
125   if (result)
126     replyHandler(std::move(result));
127   else
128     replyHandler(result.takeError());
129   return true;
130 }
131 
132 //===----------------------------------------------------------------------===//
133 // JSONTransport
134 //===----------------------------------------------------------------------===//
135 
136 /// Encode the given error as a JSON object.
encodeError(llvm::Error error)137 static llvm::json::Object encodeError(llvm::Error error) {
138   std::string message;
139   ErrorCode code = ErrorCode::UnknownErrorCode;
140   auto handlerFn = [&](const LSPError &lspError) -> llvm::Error {
141     message = lspError.message;
142     code = lspError.code;
143     return llvm::Error::success();
144   };
145   if (llvm::Error unhandled = llvm::handleErrors(std::move(error), handlerFn))
146     message = llvm::toString(std::move(unhandled));
147 
148   return llvm::json::Object{
149       {"message", std::move(message)},
150       {"code", int64_t(code)},
151   };
152 }
153 
154 /// Decode the given JSON object into an error.
decodeError(const llvm::json::Object & o)155 llvm::Error decodeError(const llvm::json::Object &o) {
156   StringRef msg = o.getString("message").value_or("Unspecified error");
157   if (Optional<int64_t> code = o.getInteger("code"))
158     return llvm::make_error<LSPError>(msg.str(), ErrorCode(*code));
159   return llvm::make_error<llvm::StringError>(llvm::inconvertibleErrorCode(),
160                                              msg.str());
161 }
162 
notify(StringRef method,llvm::json::Value params)163 void JSONTransport::notify(StringRef method, llvm::json::Value params) {
164   sendMessage(llvm::json::Object{
165       {"jsonrpc", "2.0"},
166       {"method", method},
167       {"params", std::move(params)},
168   });
169 }
call(StringRef method,llvm::json::Value params,llvm::json::Value id)170 void JSONTransport::call(StringRef method, llvm::json::Value params,
171                          llvm::json::Value id) {
172   sendMessage(llvm::json::Object{
173       {"jsonrpc", "2.0"},
174       {"id", std::move(id)},
175       {"method", method},
176       {"params", std::move(params)},
177   });
178 }
reply(llvm::json::Value id,llvm::Expected<llvm::json::Value> result)179 void JSONTransport::reply(llvm::json::Value id,
180                           llvm::Expected<llvm::json::Value> result) {
181   if (result) {
182     return sendMessage(llvm::json::Object{
183         {"jsonrpc", "2.0"},
184         {"id", std::move(id)},
185         {"result", std::move(*result)},
186     });
187   }
188 
189   sendMessage(llvm::json::Object{
190       {"jsonrpc", "2.0"},
191       {"id", std::move(id)},
192       {"error", encodeError(result.takeError())},
193   });
194 }
195 
run(MessageHandler & handler)196 llvm::Error JSONTransport::run(MessageHandler &handler) {
197   std::string json;
198   while (!feof(in)) {
199     if (ferror(in)) {
200       return llvm::errorCodeToError(
201           std::error_code(errno, std::system_category()));
202     }
203 
204     if (succeeded(readMessage(json))) {
205       if (llvm::Expected<llvm::json::Value> doc = llvm::json::parse(json)) {
206         if (!handleMessage(std::move(*doc), handler))
207           return llvm::Error::success();
208       } else {
209         Logger::error("JSON parse error: {0}", llvm::toString(doc.takeError()));
210       }
211     }
212   }
213   return llvm::errorCodeToError(std::make_error_code(std::errc::io_error));
214 }
215 
sendMessage(llvm::json::Value msg)216 void JSONTransport::sendMessage(llvm::json::Value msg) {
217   outputBuffer.clear();
218   llvm::raw_svector_ostream os(outputBuffer);
219   os << llvm::formatv(prettyOutput ? "{0:2}\n" : "{0}", msg);
220   out << "Content-Length: " << outputBuffer.size() << "\r\n\r\n"
221       << outputBuffer;
222   out.flush();
223   Logger::debug(">>> {0}\n", outputBuffer);
224 }
225 
handleMessage(llvm::json::Value msg,MessageHandler & handler)226 bool JSONTransport::handleMessage(llvm::json::Value msg,
227                                   MessageHandler &handler) {
228   // Message must be an object with "jsonrpc":"2.0".
229   llvm::json::Object *object = msg.getAsObject();
230   if (!object ||
231       object->getString("jsonrpc") != llvm::Optional<StringRef>("2.0"))
232     return false;
233 
234   // `id` may be any JSON value. If absent, this is a notification.
235   llvm::Optional<llvm::json::Value> id;
236   if (llvm::json::Value *i = object->get("id"))
237     id = std::move(*i);
238   Optional<StringRef> method = object->getString("method");
239 
240   // This is a response.
241   if (!method) {
242     if (!id)
243       return false;
244     if (auto *err = object->getObject("error"))
245       return handler.onReply(std::move(*id), decodeError(*err));
246     // result should be given, use null if not.
247     llvm::json::Value result = nullptr;
248     if (llvm::json::Value *r = object->get("result"))
249       result = std::move(*r);
250     return handler.onReply(std::move(*id), std::move(result));
251   }
252 
253   // Params should be given, use null if not.
254   llvm::json::Value params = nullptr;
255   if (llvm::json::Value *p = object->get("params"))
256     params = std::move(*p);
257 
258   if (id)
259     return handler.onCall(*method, std::move(params), std::move(*id));
260   return handler.onNotify(*method, std::move(params));
261 }
262 
263 /// Tries to read a line up to and including \n.
264 /// If failing, feof(), ferror(), or shutdownRequested() will be set.
readLine(std::FILE * in,SmallVectorImpl<char> & out)265 LogicalResult readLine(std::FILE *in, SmallVectorImpl<char> &out) {
266   // Big enough to hold any reasonable header line. May not fit content lines
267   // in delimited mode, but performance doesn't matter for that mode.
268   static constexpr int bufSize = 128;
269   size_t size = 0;
270   out.clear();
271   for (;;) {
272     out.resize_for_overwrite(size + bufSize);
273     if (!std::fgets(&out[size], bufSize, in))
274       return failure();
275 
276     clearerr(in);
277 
278     // If the line contained null bytes, anything after it (including \n) will
279     // be ignored. Fortunately this is not a legal header or JSON.
280     size_t read = std::strlen(&out[size]);
281     if (read > 0 && out[size + read - 1] == '\n') {
282       out.resize(size + read);
283       return success();
284     }
285     size += read;
286   }
287 }
288 
289 // Returns None when:
290 //  - ferror(), feof(), or shutdownRequested() are set.
291 //  - Content-Length is missing or empty (protocol error)
readStandardMessage(std::string & json)292 LogicalResult JSONTransport::readStandardMessage(std::string &json) {
293   // A Language Server Protocol message starts with a set of HTTP headers,
294   // delimited  by \r\n, and terminated by an empty line (\r\n).
295   unsigned long long contentLength = 0;
296   llvm::SmallString<128> line;
297   while (true) {
298     if (feof(in) || ferror(in) || failed(readLine(in, line)))
299       return failure();
300 
301     // Content-Length is a mandatory header, and the only one we handle.
302     StringRef lineRef = line;
303     if (lineRef.consume_front("Content-Length: ")) {
304       llvm::getAsUnsignedInteger(lineRef.trim(), 0, contentLength);
305     } else if (!lineRef.trim().empty()) {
306       // It's another header, ignore it.
307       continue;
308     } else {
309       // An empty line indicates the end of headers. Go ahead and read the JSON.
310       break;
311     }
312   }
313 
314   // The fuzzer likes crashing us by sending "Content-Length: 9999999999999999"
315   if (contentLength == 0 || contentLength > 1 << 30)
316     return failure();
317 
318   json.resize(contentLength);
319   for (size_t pos = 0, read; pos < contentLength; pos += read) {
320     read = std::fread(&json[pos], 1, contentLength - pos, in);
321     if (read == 0)
322       return failure();
323 
324     // If we're done, the error was transient. If we're not done, either it was
325     // transient or we'll see it again on retry.
326     clearerr(in);
327     pos += read;
328   }
329   return success();
330 }
331 
332 /// For lit tests we support a simplified syntax:
333 /// - messages are delimited by '// -----' on a line by itself
334 /// - lines starting with // are ignored.
335 /// This is a testing path, so favor simplicity over performance here.
336 /// When returning failure: feof(), ferror(), or shutdownRequested() will be
337 /// set.
readDelimitedMessage(std::string & json)338 LogicalResult JSONTransport::readDelimitedMessage(std::string &json) {
339   json.clear();
340   llvm::SmallString<128> line;
341   while (succeeded(readLine(in, line))) {
342     StringRef lineRef = line.str().trim();
343     if (lineRef.startswith("//")) {
344       // Found a delimiter for the message.
345       if (lineRef == "// -----")
346         break;
347       continue;
348     }
349 
350     json += line;
351   }
352 
353   return failure(ferror(in));
354 }
355