//===- LSPServer.cpp - MLIR Language Server -------------------------------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//

9 #include "LSPServer.h"
10 #include "../lsp-server-support/Logging.h"
11 #include "../lsp-server-support/Protocol.h"
12 #include "../lsp-server-support/Transport.h"
13 #include "MLIRServer.h"
14 #include "llvm/ADT/FunctionExtras.h"
15 #include "llvm/ADT/StringMap.h"
16
17 #define DEBUG_TYPE "mlir-lsp-server"
18
19 using namespace mlir;
20 using namespace mlir::lsp;

//===----------------------------------------------------------------------===//
// LSPServer
//===----------------------------------------------------------------------===//

26 namespace {
27 struct LSPServer {
LSPServer__anon2358c2420111::LSPServer28 LSPServer(MLIRServer &server) : server(server) {}
29
30 //===--------------------------------------------------------------------===//
31 // Initialization
32
33 void onInitialize(const InitializeParams ¶ms,
34 Callback<llvm::json::Value> reply);
35 void onInitialized(const InitializedParams ¶ms);
36 void onShutdown(const NoParams ¶ms, Callback<std::nullptr_t> reply);
37
38 //===--------------------------------------------------------------------===//
39 // Document Change
40
41 void onDocumentDidOpen(const DidOpenTextDocumentParams ¶ms);
42 void onDocumentDidClose(const DidCloseTextDocumentParams ¶ms);
43 void onDocumentDidChange(const DidChangeTextDocumentParams ¶ms);
44
45 //===--------------------------------------------------------------------===//
46 // Definitions and References
47
48 void onGoToDefinition(const TextDocumentPositionParams ¶ms,
49 Callback<std::vector<Location>> reply);
50 void onReference(const ReferenceParams ¶ms,
51 Callback<std::vector<Location>> reply);
52
53 //===--------------------------------------------------------------------===//
54 // Hover
55
56 void onHover(const TextDocumentPositionParams ¶ms,
57 Callback<Optional<Hover>> reply);
58
59 //===--------------------------------------------------------------------===//
60 // Document Symbols
61
62 void onDocumentSymbol(const DocumentSymbolParams ¶ms,
63 Callback<std::vector<DocumentSymbol>> reply);
64
65 //===--------------------------------------------------------------------===//
66 // Code Completion
67
68 void onCompletion(const CompletionParams ¶ms,
69 Callback<CompletionList> reply);
70
71 //===--------------------------------------------------------------------===//
72 // Code Action
73
74 void onCodeAction(const CodeActionParams ¶ms,
75 Callback<llvm::json::Value> reply);
76
77 //===--------------------------------------------------------------------===//
78 // Fields
79 //===--------------------------------------------------------------------===//
80
81 MLIRServer &server;
82
83 /// An outgoing notification used to send diagnostics to the client when they
84 /// are ready to be processed.
85 OutgoingNotification<PublishDiagnosticsParams> publishDiagnostics;
86
87 /// Used to indicate that the 'shutdown' request was received from the
88 /// Language Server client.
89 bool shutdownRequestReceived = false;
90 };
91 } // namespace

//===----------------------------------------------------------------------===//
// Initialization

onInitialize(const InitializeParams & params,Callback<llvm::json::Value> reply)96 void LSPServer::onInitialize(const InitializeParams ¶ms,
97 Callback<llvm::json::Value> reply) {
98 // Send a response with the capabilities of this server.
99 llvm::json::Object serverCaps{
100 {"textDocumentSync",
101 llvm::json::Object{
102 {"openClose", true},
103 {"change", (int)TextDocumentSyncKind::Full},
104 {"save", true},
105 }},
106 {"completionProvider",
107 llvm::json::Object{
108 {"allCommitCharacters",
109 {
110 "\t",
111 ";",
112 ",",
113 ".",
114 "=",
115 }},
116 {"resolveProvider", false},
117 {"triggerCharacters",
118 {".", "%", "^", "!", "#", "(", ",", "<", ":", "[", " ", "\"", "/"}},
119 }},
120 {"definitionProvider", true},
121 {"referencesProvider", true},
122 {"hoverProvider", true},
123
124 // For now we only support documenting symbols when the client supports
125 // hierarchical symbols.
126 {"documentSymbolProvider",
127 params.capabilities.hierarchicalDocumentSymbol},
128 };
129
130 // Per LSP, codeActionProvider can be either boolean or CodeActionOptions.
131 // CodeActionOptions is only valid if the client supports action literal
132 // via textDocument.codeAction.codeActionLiteralSupport.
133 serverCaps["codeActionProvider"] =
134 params.capabilities.codeActionStructure
135 ? llvm::json::Object{{"codeActionKinds",
136 {CodeAction::kQuickFix, CodeAction::kRefactor,
137 CodeAction::kInfo}}}
138 : llvm::json::Value(true);
139
140 llvm::json::Object result{
141 {{"serverInfo",
142 llvm::json::Object{{"name", "mlir-lsp-server"}, {"version", "0.0.0"}}},
143 {"capabilities", std::move(serverCaps)}}};
144 reply(std::move(result));
145 }
onInitialized(const InitializedParams &)146 void LSPServer::onInitialized(const InitializedParams &) {}
onShutdown(const NoParams &,Callback<std::nullptr_t> reply)147 void LSPServer::onShutdown(const NoParams &, Callback<std::nullptr_t> reply) {
148 shutdownRequestReceived = true;
149 reply(nullptr);
150 }

//===----------------------------------------------------------------------===//
// Document Change

onDocumentDidOpen(const DidOpenTextDocumentParams & params)155 void LSPServer::onDocumentDidOpen(const DidOpenTextDocumentParams ¶ms) {
156 PublishDiagnosticsParams diagParams(params.textDocument.uri,
157 params.textDocument.version);
158 server.addOrUpdateDocument(params.textDocument.uri, params.textDocument.text,
159 params.textDocument.version,
160 diagParams.diagnostics);
161
162 // Publish any recorded diagnostics.
163 publishDiagnostics(diagParams);
164 }
onDocumentDidClose(const DidCloseTextDocumentParams & params)165 void LSPServer::onDocumentDidClose(const DidCloseTextDocumentParams ¶ms) {
166 Optional<int64_t> version = server.removeDocument(params.textDocument.uri);
167 if (!version)
168 return;
169
170 // Empty out the diagnostics shown for this document. This will clear out
171 // anything currently displayed by the client for this document (e.g. in the
172 // "Problems" pane of VSCode).
173 publishDiagnostics(
174 PublishDiagnosticsParams(params.textDocument.uri, *version));
175 }
onDocumentDidChange(const DidChangeTextDocumentParams & params)176 void LSPServer::onDocumentDidChange(const DidChangeTextDocumentParams ¶ms) {
177 // TODO: We currently only support full document updates, we should refactor
178 // to avoid this.
179 if (params.contentChanges.size() != 1)
180 return;
181 PublishDiagnosticsParams diagParams(params.textDocument.uri,
182 params.textDocument.version);
183 server.addOrUpdateDocument(
184 params.textDocument.uri, params.contentChanges.front().text,
185 params.textDocument.version, diagParams.diagnostics);
186
187 // Publish any recorded diagnostics.
188 publishDiagnostics(diagParams);
189 }

//===----------------------------------------------------------------------===//
// Definitions and References

onGoToDefinition(const TextDocumentPositionParams & params,Callback<std::vector<Location>> reply)194 void LSPServer::onGoToDefinition(const TextDocumentPositionParams ¶ms,
195 Callback<std::vector<Location>> reply) {
196 std::vector<Location> locations;
197 server.getLocationsOf(params.textDocument.uri, params.position, locations);
198 reply(std::move(locations));
199 }
onReference(const ReferenceParams & params,Callback<std::vector<Location>> reply)201 void LSPServer::onReference(const ReferenceParams ¶ms,
202 Callback<std::vector<Location>> reply) {
203 std::vector<Location> locations;
204 server.findReferencesOf(params.textDocument.uri, params.position, locations);
205 reply(std::move(locations));
206 }

//===----------------------------------------------------------------------===//
// Hover

onHover(const TextDocumentPositionParams & params,Callback<Optional<Hover>> reply)211 void LSPServer::onHover(const TextDocumentPositionParams ¶ms,
212 Callback<Optional<Hover>> reply) {
213 reply(server.findHover(params.textDocument.uri, params.position));
214 }

//===----------------------------------------------------------------------===//
// Document Symbols

onDocumentSymbol(const DocumentSymbolParams & params,Callback<std::vector<DocumentSymbol>> reply)219 void LSPServer::onDocumentSymbol(const DocumentSymbolParams ¶ms,
220 Callback<std::vector<DocumentSymbol>> reply) {
221 std::vector<DocumentSymbol> symbols;
222 server.findDocumentSymbols(params.textDocument.uri, symbols);
223 reply(std::move(symbols));
224 }

//===----------------------------------------------------------------------===//
// Code Completion

onCompletion(const CompletionParams & params,Callback<CompletionList> reply)229 void LSPServer::onCompletion(const CompletionParams ¶ms,
230 Callback<CompletionList> reply) {
231 reply(server.getCodeCompletion(params.textDocument.uri, params.position));
232 }

//===----------------------------------------------------------------------===//
// Code Action

onCodeAction(const CodeActionParams & params,Callback<llvm::json::Value> reply)237 void LSPServer::onCodeAction(const CodeActionParams ¶ms,
238 Callback<llvm::json::Value> reply) {
239 URIForFile uri = params.textDocument.uri;
240
241 // Check whether a particular CodeActionKind is included in the response.
242 auto isKindAllowed = [only(params.context.only)](StringRef kind) {
243 if (only.empty())
244 return true;
245 return llvm::any_of(only, [&](StringRef base) {
246 return kind.consume_front(base) && (kind.empty() || kind.startswith("."));
247 });
248 };
249
250 // We provide a code action for fixes on the specified diagnostics.
251 std::vector<CodeAction> actions;
252 if (isKindAllowed(CodeAction::kQuickFix))
253 server.getCodeActions(uri, params.range.start, params.context, actions);
254 reply(std::move(actions));
255 }

//===----------------------------------------------------------------------===//
// Entry point
//===----------------------------------------------------------------------===//

runMlirLSPServer(MLIRServer & server,JSONTransport & transport)261 LogicalResult lsp::runMlirLSPServer(MLIRServer &server,
262 JSONTransport &transport) {
263 LSPServer lspServer(server);
264 MessageHandler messageHandler(transport);
265
266 // Initialization
267 messageHandler.method("initialize", &lspServer, &LSPServer::onInitialize);
268 messageHandler.notification("initialized", &lspServer,
269 &LSPServer::onInitialized);
270 messageHandler.method("shutdown", &lspServer, &LSPServer::onShutdown);
271
272 // Document Changes
273 messageHandler.notification("textDocument/didOpen", &lspServer,
274 &LSPServer::onDocumentDidOpen);
275 messageHandler.notification("textDocument/didClose", &lspServer,
276 &LSPServer::onDocumentDidClose);
277 messageHandler.notification("textDocument/didChange", &lspServer,
278 &LSPServer::onDocumentDidChange);
279
280 // Definitions and References
281 messageHandler.method("textDocument/definition", &lspServer,
282 &LSPServer::onGoToDefinition);
283 messageHandler.method("textDocument/references", &lspServer,
284 &LSPServer::onReference);
285
286 // Hover
287 messageHandler.method("textDocument/hover", &lspServer, &LSPServer::onHover);
288
289 // Document Symbols
290 messageHandler.method("textDocument/documentSymbol", &lspServer,
291 &LSPServer::onDocumentSymbol);
292
293 // Code Completion
294 messageHandler.method("textDocument/completion", &lspServer,
295 &LSPServer::onCompletion);
296
297 // Code Action
298 messageHandler.method("textDocument/codeAction", &lspServer,
299 &LSPServer::onCodeAction);
300
301 // Diagnostics
302 lspServer.publishDiagnostics =
303 messageHandler.outgoingNotification<PublishDiagnosticsParams>(
304 "textDocument/publishDiagnostics");
305
306 // Run the main loop of the transport.
307 LogicalResult result = success();
308 if (llvm::Error error = transport.run(messageHandler)) {
309 Logger::error("Transport error: {0}", error);
310 llvm::consumeError(std::move(error));
311 result = failure();
312 } else {
313 result = success(lspServer.shutdownRequestReceived);
314 }
315 return result;
316 }