1 //===- LSPServer.cpp - MLIR Language Server -------------------------------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 
9 #include "LSPServer.h"
10 #include "../lsp-server-support/Logging.h"
11 #include "../lsp-server-support/Protocol.h"
12 #include "../lsp-server-support/Transport.h"
13 #include "MLIRServer.h"
14 #include "llvm/ADT/FunctionExtras.h"
15 #include "llvm/ADT/StringMap.h"
16 
17 #define DEBUG_TYPE "mlir-lsp-server"
18 
19 using namespace mlir;
20 using namespace mlir::lsp;
21 
22 //===----------------------------------------------------------------------===//
23 // LSPServer
24 //===----------------------------------------------------------------------===//
25 
26 namespace {
27 struct LSPServer {
LSPServer__anon2358c2420111::LSPServer28   LSPServer(MLIRServer &server) : server(server) {}
29 
30   //===--------------------------------------------------------------------===//
31   // Initialization
32 
33   void onInitialize(const InitializeParams &params,
34                     Callback<llvm::json::Value> reply);
35   void onInitialized(const InitializedParams &params);
36   void onShutdown(const NoParams &params, Callback<std::nullptr_t> reply);
37 
38   //===--------------------------------------------------------------------===//
39   // Document Change
40 
41   void onDocumentDidOpen(const DidOpenTextDocumentParams &params);
42   void onDocumentDidClose(const DidCloseTextDocumentParams &params);
43   void onDocumentDidChange(const DidChangeTextDocumentParams &params);
44 
45   //===--------------------------------------------------------------------===//
46   // Definitions and References
47 
48   void onGoToDefinition(const TextDocumentPositionParams &params,
49                         Callback<std::vector<Location>> reply);
50   void onReference(const ReferenceParams &params,
51                    Callback<std::vector<Location>> reply);
52 
53   //===--------------------------------------------------------------------===//
54   // Hover
55 
56   void onHover(const TextDocumentPositionParams &params,
57                Callback<Optional<Hover>> reply);
58 
59   //===--------------------------------------------------------------------===//
60   // Document Symbols
61 
62   void onDocumentSymbol(const DocumentSymbolParams &params,
63                         Callback<std::vector<DocumentSymbol>> reply);
64 
65   //===--------------------------------------------------------------------===//
66   // Code Completion
67 
68   void onCompletion(const CompletionParams &params,
69                     Callback<CompletionList> reply);
70 
71   //===--------------------------------------------------------------------===//
72   // Code Action
73 
74   void onCodeAction(const CodeActionParams &params,
75                     Callback<llvm::json::Value> reply);
76 
77   //===--------------------------------------------------------------------===//
78   // Fields
79   //===--------------------------------------------------------------------===//
80 
81   MLIRServer &server;
82 
83   /// An outgoing notification used to send diagnostics to the client when they
84   /// are ready to be processed.
85   OutgoingNotification<PublishDiagnosticsParams> publishDiagnostics;
86 
87   /// Used to indicate that the 'shutdown' request was received from the
88   /// Language Server client.
89   bool shutdownRequestReceived = false;
90 };
91 } // namespace
92 
93 //===----------------------------------------------------------------------===//
94 // Initialization
95 
onInitialize(const InitializeParams & params,Callback<llvm::json::Value> reply)96 void LSPServer::onInitialize(const InitializeParams &params,
97                              Callback<llvm::json::Value> reply) {
98   // Send a response with the capabilities of this server.
99   llvm::json::Object serverCaps{
100       {"textDocumentSync",
101        llvm::json::Object{
102            {"openClose", true},
103            {"change", (int)TextDocumentSyncKind::Full},
104            {"save", true},
105        }},
106       {"completionProvider",
107        llvm::json::Object{
108            {"allCommitCharacters",
109             {
110                 "\t",
111                 ";",
112                 ",",
113                 ".",
114                 "=",
115             }},
116            {"resolveProvider", false},
117            {"triggerCharacters",
118             {".", "%", "^", "!", "#", "(", ",", "<", ":", "[", " ", "\"", "/"}},
119        }},
120       {"definitionProvider", true},
121       {"referencesProvider", true},
122       {"hoverProvider", true},
123 
124       // For now we only support documenting symbols when the client supports
125       // hierarchical symbols.
126       {"documentSymbolProvider",
127        params.capabilities.hierarchicalDocumentSymbol},
128   };
129 
130   // Per LSP, codeActionProvider can be either boolean or CodeActionOptions.
131   // CodeActionOptions is only valid if the client supports action literal
132   // via textDocument.codeAction.codeActionLiteralSupport.
133   serverCaps["codeActionProvider"] =
134       params.capabilities.codeActionStructure
135           ? llvm::json::Object{{"codeActionKinds",
136                                 {CodeAction::kQuickFix, CodeAction::kRefactor,
137                                  CodeAction::kInfo}}}
138           : llvm::json::Value(true);
139 
140   llvm::json::Object result{
141       {{"serverInfo",
142         llvm::json::Object{{"name", "mlir-lsp-server"}, {"version", "0.0.0"}}},
143        {"capabilities", std::move(serverCaps)}}};
144   reply(std::move(result));
145 }
onInitialized(const InitializedParams &)146 void LSPServer::onInitialized(const InitializedParams &) {}
onShutdown(const NoParams &,Callback<std::nullptr_t> reply)147 void LSPServer::onShutdown(const NoParams &, Callback<std::nullptr_t> reply) {
148   shutdownRequestReceived = true;
149   reply(nullptr);
150 }
151 
152 //===----------------------------------------------------------------------===//
153 // Document Change
154 
onDocumentDidOpen(const DidOpenTextDocumentParams & params)155 void LSPServer::onDocumentDidOpen(const DidOpenTextDocumentParams &params) {
156   PublishDiagnosticsParams diagParams(params.textDocument.uri,
157                                       params.textDocument.version);
158   server.addOrUpdateDocument(params.textDocument.uri, params.textDocument.text,
159                              params.textDocument.version,
160                              diagParams.diagnostics);
161 
162   // Publish any recorded diagnostics.
163   publishDiagnostics(diagParams);
164 }
onDocumentDidClose(const DidCloseTextDocumentParams & params)165 void LSPServer::onDocumentDidClose(const DidCloseTextDocumentParams &params) {
166   Optional<int64_t> version = server.removeDocument(params.textDocument.uri);
167   if (!version)
168     return;
169 
170   // Empty out the diagnostics shown for this document. This will clear out
171   // anything currently displayed by the client for this document (e.g. in the
172   // "Problems" pane of VSCode).
173   publishDiagnostics(
174       PublishDiagnosticsParams(params.textDocument.uri, *version));
175 }
onDocumentDidChange(const DidChangeTextDocumentParams & params)176 void LSPServer::onDocumentDidChange(const DidChangeTextDocumentParams &params) {
177   // TODO: We currently only support full document updates, we should refactor
178   // to avoid this.
179   if (params.contentChanges.size() != 1)
180     return;
181   PublishDiagnosticsParams diagParams(params.textDocument.uri,
182                                       params.textDocument.version);
183   server.addOrUpdateDocument(
184       params.textDocument.uri, params.contentChanges.front().text,
185       params.textDocument.version, diagParams.diagnostics);
186 
187   // Publish any recorded diagnostics.
188   publishDiagnostics(diagParams);
189 }
190 
191 //===----------------------------------------------------------------------===//
192 // Definitions and References
193 
onGoToDefinition(const TextDocumentPositionParams & params,Callback<std::vector<Location>> reply)194 void LSPServer::onGoToDefinition(const TextDocumentPositionParams &params,
195                                  Callback<std::vector<Location>> reply) {
196   std::vector<Location> locations;
197   server.getLocationsOf(params.textDocument.uri, params.position, locations);
198   reply(std::move(locations));
199 }
200 
onReference(const ReferenceParams & params,Callback<std::vector<Location>> reply)201 void LSPServer::onReference(const ReferenceParams &params,
202                             Callback<std::vector<Location>> reply) {
203   std::vector<Location> locations;
204   server.findReferencesOf(params.textDocument.uri, params.position, locations);
205   reply(std::move(locations));
206 }
207 
208 //===----------------------------------------------------------------------===//
209 // Hover
210 
onHover(const TextDocumentPositionParams & params,Callback<Optional<Hover>> reply)211 void LSPServer::onHover(const TextDocumentPositionParams &params,
212                         Callback<Optional<Hover>> reply) {
213   reply(server.findHover(params.textDocument.uri, params.position));
214 }
215 
216 //===----------------------------------------------------------------------===//
217 // Document Symbols
218 
onDocumentSymbol(const DocumentSymbolParams & params,Callback<std::vector<DocumentSymbol>> reply)219 void LSPServer::onDocumentSymbol(const DocumentSymbolParams &params,
220                                  Callback<std::vector<DocumentSymbol>> reply) {
221   std::vector<DocumentSymbol> symbols;
222   server.findDocumentSymbols(params.textDocument.uri, symbols);
223   reply(std::move(symbols));
224 }
225 
226 //===----------------------------------------------------------------------===//
227 // Code Completion
228 
onCompletion(const CompletionParams & params,Callback<CompletionList> reply)229 void LSPServer::onCompletion(const CompletionParams &params,
230                              Callback<CompletionList> reply) {
231   reply(server.getCodeCompletion(params.textDocument.uri, params.position));
232 }
233 
234 //===----------------------------------------------------------------------===//
235 // Code Action
236 
onCodeAction(const CodeActionParams & params,Callback<llvm::json::Value> reply)237 void LSPServer::onCodeAction(const CodeActionParams &params,
238                              Callback<llvm::json::Value> reply) {
239   URIForFile uri = params.textDocument.uri;
240 
241   // Check whether a particular CodeActionKind is included in the response.
242   auto isKindAllowed = [only(params.context.only)](StringRef kind) {
243     if (only.empty())
244       return true;
245     return llvm::any_of(only, [&](StringRef base) {
246       return kind.consume_front(base) && (kind.empty() || kind.startswith("."));
247     });
248   };
249 
250   // We provide a code action for fixes on the specified diagnostics.
251   std::vector<CodeAction> actions;
252   if (isKindAllowed(CodeAction::kQuickFix))
253     server.getCodeActions(uri, params.range.start, params.context, actions);
254   reply(std::move(actions));
255 }
256 
257 //===----------------------------------------------------------------------===//
258 // Entry point
259 //===----------------------------------------------------------------------===//
260 
runMlirLSPServer(MLIRServer & server,JSONTransport & transport)261 LogicalResult lsp::runMlirLSPServer(MLIRServer &server,
262                                     JSONTransport &transport) {
263   LSPServer lspServer(server);
264   MessageHandler messageHandler(transport);
265 
266   // Initialization
267   messageHandler.method("initialize", &lspServer, &LSPServer::onInitialize);
268   messageHandler.notification("initialized", &lspServer,
269                               &LSPServer::onInitialized);
270   messageHandler.method("shutdown", &lspServer, &LSPServer::onShutdown);
271 
272   // Document Changes
273   messageHandler.notification("textDocument/didOpen", &lspServer,
274                               &LSPServer::onDocumentDidOpen);
275   messageHandler.notification("textDocument/didClose", &lspServer,
276                               &LSPServer::onDocumentDidClose);
277   messageHandler.notification("textDocument/didChange", &lspServer,
278                               &LSPServer::onDocumentDidChange);
279 
280   // Definitions and References
281   messageHandler.method("textDocument/definition", &lspServer,
282                         &LSPServer::onGoToDefinition);
283   messageHandler.method("textDocument/references", &lspServer,
284                         &LSPServer::onReference);
285 
286   // Hover
287   messageHandler.method("textDocument/hover", &lspServer, &LSPServer::onHover);
288 
289   // Document Symbols
290   messageHandler.method("textDocument/documentSymbol", &lspServer,
291                         &LSPServer::onDocumentSymbol);
292 
293   // Code Completion
294   messageHandler.method("textDocument/completion", &lspServer,
295                         &LSPServer::onCompletion);
296 
297   // Code Action
298   messageHandler.method("textDocument/codeAction", &lspServer,
299                         &LSPServer::onCodeAction);
300 
301   // Diagnostics
302   lspServer.publishDiagnostics =
303       messageHandler.outgoingNotification<PublishDiagnosticsParams>(
304           "textDocument/publishDiagnostics");
305 
306   // Run the main loop of the transport.
307   LogicalResult result = success();
308   if (llvm::Error error = transport.run(messageHandler)) {
309     Logger::error("Transport error: {0}", error);
310     llvm::consumeError(std::move(error));
311     result = failure();
312   } else {
313     result = success(lspServer.shutdownRequestReceived);
314   }
315   return result;
316 }
317