1 //===--- JSONTransport.cpp - sending and receiving LSP messages over JSON -===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8
9 #include "Transport.h"
10 #include "Logging.h"
11 #include "Protocol.h"
12 #include "llvm/ADT/SmallString.h"
13 #include "llvm/Support/Errno.h"
14 #include "llvm/Support/Error.h"
15 #include <system_error>
16 #include <utility>
17
18 using namespace mlir;
19 using namespace mlir::lsp;
20
21 //===----------------------------------------------------------------------===//
22 // Reply
23 //===----------------------------------------------------------------------===//
24
25 namespace {
26 /// Function object to reply to an LSP call.
27 /// Each instance must be called exactly once, otherwise:
28 /// - if there was no reply, an error reply is sent
29 /// - if there were multiple replies, only the first is sent
30 class Reply {
31 public:
32 Reply(const llvm::json::Value &id, StringRef method,
33 JSONTransport &transport);
34 Reply(Reply &&other);
35 Reply &operator=(Reply &&) = delete;
36 Reply(const Reply &) = delete;
37 Reply &operator=(const Reply &) = delete;
38
39 void operator()(llvm::Expected<llvm::json::Value> reply);
40
41 private:
42 StringRef method;
43 std::atomic<bool> replied = {false};
44 llvm::json::Value id;
45 JSONTransport *transport;
46 };
47 } // namespace
48
Reply(const llvm::json::Value & id,llvm::StringRef method,JSONTransport & transport)49 Reply::Reply(const llvm::json::Value &id, llvm::StringRef method,
50 JSONTransport &transport)
51 : id(id), transport(&transport) {}
52
Reply(Reply && other)53 Reply::Reply(Reply &&other)
54 : replied(other.replied.load()), id(std::move(other.id)),
55 transport(other.transport) {
56 other.transport = nullptr;
57 }
58
operator ()(llvm::Expected<llvm::json::Value> reply)59 void Reply::operator()(llvm::Expected<llvm::json::Value> reply) {
60 if (replied.exchange(true)) {
61 Logger::error("Replied twice to message {0}({1})", method, id);
62 assert(false && "must reply to each call only once!");
63 return;
64 }
65 assert(transport && "expected valid transport to reply to");
66
67 if (reply) {
68 Logger::info("--> reply:{0}({1})", method, id);
69 transport->reply(std::move(id), std::move(reply));
70 } else {
71 llvm::Error error = reply.takeError();
72 Logger::info("--> reply:{0}({1})", method, id, error);
73 transport->reply(std::move(id), std::move(error));
74 }
75 }
76
77 //===----------------------------------------------------------------------===//
78 // MessageHandler
79 //===----------------------------------------------------------------------===//
80
onNotify(llvm::StringRef method,llvm::json::Value value)81 bool MessageHandler::onNotify(llvm::StringRef method, llvm::json::Value value) {
82 Logger::info("--> {0}", method);
83
84 if (method == "exit")
85 return false;
86 if (method == "$cancel") {
87 // TODO: Add support for cancelling requests.
88 } else {
89 auto it = notificationHandlers.find(method);
90 if (it != notificationHandlers.end())
91 it->second(std::move(value));
92 }
93 return true;
94 }
95
onCall(llvm::StringRef method,llvm::json::Value params,llvm::json::Value id)96 bool MessageHandler::onCall(llvm::StringRef method, llvm::json::Value params,
97 llvm::json::Value id) {
98 Logger::info("--> {0}({1})", method, id);
99
100 Reply reply(id, method, transport);
101
102 auto it = methodHandlers.find(method);
103 if (it != methodHandlers.end()) {
104 it->second(std::move(params), std::move(reply));
105 } else {
106 reply(llvm::make_error<LSPError>("method not found: " + method.str(),
107 ErrorCode::MethodNotFound));
108 }
109 return true;
110 }
111
onReply(llvm::json::Value id,llvm::Expected<llvm::json::Value> result)112 bool MessageHandler::onReply(llvm::json::Value id,
113 llvm::Expected<llvm::json::Value> result) {
114 // TODO: Add support for reply callbacks when support for outgoing messages is
115 // added. For now, we just log an error on any replies received.
116 Callback<llvm::json::Value> replyHandler =
117 [&id](llvm::Expected<llvm::json::Value> result) {
118 Logger::error(
119 "received a reply with ID {0}, but there was no such call", id);
120 if (!result)
121 llvm::consumeError(result.takeError());
122 };
123
124 // Log and run the reply handler.
125 if (result)
126 replyHandler(std::move(result));
127 else
128 replyHandler(result.takeError());
129 return true;
130 }
131
132 //===----------------------------------------------------------------------===//
133 // JSONTransport
134 //===----------------------------------------------------------------------===//
135
136 /// Encode the given error as a JSON object.
encodeError(llvm::Error error)137 static llvm::json::Object encodeError(llvm::Error error) {
138 std::string message;
139 ErrorCode code = ErrorCode::UnknownErrorCode;
140 auto handlerFn = [&](const LSPError &lspError) -> llvm::Error {
141 message = lspError.message;
142 code = lspError.code;
143 return llvm::Error::success();
144 };
145 if (llvm::Error unhandled = llvm::handleErrors(std::move(error), handlerFn))
146 message = llvm::toString(std::move(unhandled));
147
148 return llvm::json::Object{
149 {"message", std::move(message)},
150 {"code", int64_t(code)},
151 };
152 }
153
154 /// Decode the given JSON object into an error.
decodeError(const llvm::json::Object & o)155 llvm::Error decodeError(const llvm::json::Object &o) {
156 StringRef msg = o.getString("message").value_or("Unspecified error");
157 if (Optional<int64_t> code = o.getInteger("code"))
158 return llvm::make_error<LSPError>(msg.str(), ErrorCode(*code));
159 return llvm::make_error<llvm::StringError>(llvm::inconvertibleErrorCode(),
160 msg.str());
161 }
162
notify(StringRef method,llvm::json::Value params)163 void JSONTransport::notify(StringRef method, llvm::json::Value params) {
164 sendMessage(llvm::json::Object{
165 {"jsonrpc", "2.0"},
166 {"method", method},
167 {"params", std::move(params)},
168 });
169 }
call(StringRef method,llvm::json::Value params,llvm::json::Value id)170 void JSONTransport::call(StringRef method, llvm::json::Value params,
171 llvm::json::Value id) {
172 sendMessage(llvm::json::Object{
173 {"jsonrpc", "2.0"},
174 {"id", std::move(id)},
175 {"method", method},
176 {"params", std::move(params)},
177 });
178 }
reply(llvm::json::Value id,llvm::Expected<llvm::json::Value> result)179 void JSONTransport::reply(llvm::json::Value id,
180 llvm::Expected<llvm::json::Value> result) {
181 if (result) {
182 return sendMessage(llvm::json::Object{
183 {"jsonrpc", "2.0"},
184 {"id", std::move(id)},
185 {"result", std::move(*result)},
186 });
187 }
188
189 sendMessage(llvm::json::Object{
190 {"jsonrpc", "2.0"},
191 {"id", std::move(id)},
192 {"error", encodeError(result.takeError())},
193 });
194 }
195
run(MessageHandler & handler)196 llvm::Error JSONTransport::run(MessageHandler &handler) {
197 std::string json;
198 while (!feof(in)) {
199 if (ferror(in)) {
200 return llvm::errorCodeToError(
201 std::error_code(errno, std::system_category()));
202 }
203
204 if (succeeded(readMessage(json))) {
205 if (llvm::Expected<llvm::json::Value> doc = llvm::json::parse(json)) {
206 if (!handleMessage(std::move(*doc), handler))
207 return llvm::Error::success();
208 } else {
209 Logger::error("JSON parse error: {0}", llvm::toString(doc.takeError()));
210 }
211 }
212 }
213 return llvm::errorCodeToError(std::make_error_code(std::errc::io_error));
214 }
215
sendMessage(llvm::json::Value msg)216 void JSONTransport::sendMessage(llvm::json::Value msg) {
217 outputBuffer.clear();
218 llvm::raw_svector_ostream os(outputBuffer);
219 os << llvm::formatv(prettyOutput ? "{0:2}\n" : "{0}", msg);
220 out << "Content-Length: " << outputBuffer.size() << "\r\n\r\n"
221 << outputBuffer;
222 out.flush();
223 Logger::debug(">>> {0}\n", outputBuffer);
224 }
225
handleMessage(llvm::json::Value msg,MessageHandler & handler)226 bool JSONTransport::handleMessage(llvm::json::Value msg,
227 MessageHandler &handler) {
228 // Message must be an object with "jsonrpc":"2.0".
229 llvm::json::Object *object = msg.getAsObject();
230 if (!object ||
231 object->getString("jsonrpc") != llvm::Optional<StringRef>("2.0"))
232 return false;
233
234 // `id` may be any JSON value. If absent, this is a notification.
235 llvm::Optional<llvm::json::Value> id;
236 if (llvm::json::Value *i = object->get("id"))
237 id = std::move(*i);
238 Optional<StringRef> method = object->getString("method");
239
240 // This is a response.
241 if (!method) {
242 if (!id)
243 return false;
244 if (auto *err = object->getObject("error"))
245 return handler.onReply(std::move(*id), decodeError(*err));
246 // result should be given, use null if not.
247 llvm::json::Value result = nullptr;
248 if (llvm::json::Value *r = object->get("result"))
249 result = std::move(*r);
250 return handler.onReply(std::move(*id), std::move(result));
251 }
252
253 // Params should be given, use null if not.
254 llvm::json::Value params = nullptr;
255 if (llvm::json::Value *p = object->get("params"))
256 params = std::move(*p);
257
258 if (id)
259 return handler.onCall(*method, std::move(params), std::move(*id));
260 return handler.onNotify(*method, std::move(params));
261 }
262
263 /// Tries to read a line up to and including \n.
264 /// If failing, feof(), ferror(), or shutdownRequested() will be set.
readLine(std::FILE * in,SmallVectorImpl<char> & out)265 LogicalResult readLine(std::FILE *in, SmallVectorImpl<char> &out) {
266 // Big enough to hold any reasonable header line. May not fit content lines
267 // in delimited mode, but performance doesn't matter for that mode.
268 static constexpr int bufSize = 128;
269 size_t size = 0;
270 out.clear();
271 for (;;) {
272 out.resize_for_overwrite(size + bufSize);
273 if (!std::fgets(&out[size], bufSize, in))
274 return failure();
275
276 clearerr(in);
277
278 // If the line contained null bytes, anything after it (including \n) will
279 // be ignored. Fortunately this is not a legal header or JSON.
280 size_t read = std::strlen(&out[size]);
281 if (read > 0 && out[size + read - 1] == '\n') {
282 out.resize(size + read);
283 return success();
284 }
285 size += read;
286 }
287 }
288
289 // Returns None when:
290 // - ferror(), feof(), or shutdownRequested() are set.
291 // - Content-Length is missing or empty (protocol error)
readStandardMessage(std::string & json)292 LogicalResult JSONTransport::readStandardMessage(std::string &json) {
293 // A Language Server Protocol message starts with a set of HTTP headers,
294 // delimited by \r\n, and terminated by an empty line (\r\n).
295 unsigned long long contentLength = 0;
296 llvm::SmallString<128> line;
297 while (true) {
298 if (feof(in) || ferror(in) || failed(readLine(in, line)))
299 return failure();
300
301 // Content-Length is a mandatory header, and the only one we handle.
302 StringRef lineRef = line;
303 if (lineRef.consume_front("Content-Length: ")) {
304 llvm::getAsUnsignedInteger(lineRef.trim(), 0, contentLength);
305 } else if (!lineRef.trim().empty()) {
306 // It's another header, ignore it.
307 continue;
308 } else {
309 // An empty line indicates the end of headers. Go ahead and read the JSON.
310 break;
311 }
312 }
313
314 // The fuzzer likes crashing us by sending "Content-Length: 9999999999999999"
315 if (contentLength == 0 || contentLength > 1 << 30)
316 return failure();
317
318 json.resize(contentLength);
319 for (size_t pos = 0, read; pos < contentLength; pos += read) {
320 read = std::fread(&json[pos], 1, contentLength - pos, in);
321 if (read == 0)
322 return failure();
323
324 // If we're done, the error was transient. If we're not done, either it was
325 // transient or we'll see it again on retry.
326 clearerr(in);
327 pos += read;
328 }
329 return success();
330 }
331
332 /// For lit tests we support a simplified syntax:
333 /// - messages are delimited by '// -----' on a line by itself
334 /// - lines starting with // are ignored.
335 /// This is a testing path, so favor simplicity over performance here.
336 /// When returning failure: feof(), ferror(), or shutdownRequested() will be
337 /// set.
readDelimitedMessage(std::string & json)338 LogicalResult JSONTransport::readDelimitedMessage(std::string &json) {
339 json.clear();
340 llvm::SmallString<128> line;
341 while (succeeded(readLine(in, line))) {
342 StringRef lineRef = line.str().trim();
343 if (lineRef.startswith("//")) {
344 // Found a delimiter for the message.
345 if (lineRef == "// -----")
346 break;
347 continue;
348 }
349
350 json += line;
351 }
352
353 return failure(ferror(in));
354 }
355