1 //===- LSPServer.cpp - MLIR Language Server -------------------------------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 
9 #include "LSPServer.h"
10 #include "../lsp-server-support/Logging.h"
11 #include "../lsp-server-support/Protocol.h"
12 #include "../lsp-server-support/Transport.h"
13 #include "MLIRServer.h"
14 #include "llvm/ADT/FunctionExtras.h"
15 #include "llvm/ADT/StringMap.h"
16 
17 #define DEBUG_TYPE "mlir-lsp-server"
18 
19 using namespace mlir;
20 using namespace mlir::lsp;
21 
22 //===----------------------------------------------------------------------===//
23 // LSPServer::Impl
24 //===----------------------------------------------------------------------===//
25 
/// Implementation details of LSPServer. Its member functions are the handlers
/// that LSPServer::run() registers for incoming LSP requests and notifications;
/// they forward the work to the referenced MLIRServer and reply to the client.
struct LSPServer::Impl {
  /// `server` and `transport` are stored by reference, so both must outlive
  /// this object.
  Impl(MLIRServer &server, JSONTransport &transport)
      : server(server), transport(transport) {}

  //===--------------------------------------------------------------------===//
  // Initialization

  /// Handler for the 'initialize' request.
  void onInitialize(const InitializeParams &params,
                    Callback<llvm::json::Value> reply);
  /// Handler for the 'initialized' notification.
  void onInitialized(const InitializedParams &params);
  /// Handler for the 'shutdown' request.
  void onShutdown(const NoParams &params, Callback<std::nullptr_t> reply);

  //===--------------------------------------------------------------------===//
  // Document Change

  /// Handler for the 'textDocument/didOpen' notification.
  void onDocumentDidOpen(const DidOpenTextDocumentParams &params);
  /// Handler for the 'textDocument/didClose' notification.
  void onDocumentDidClose(const DidCloseTextDocumentParams &params);
  /// Handler for the 'textDocument/didChange' notification.
  void onDocumentDidChange(const DidChangeTextDocumentParams &params);

  //===--------------------------------------------------------------------===//
  // Definitions and References

  /// Handler for the 'textDocument/definition' request.
  void onGoToDefinition(const TextDocumentPositionParams &params,
                        Callback<std::vector<Location>> reply);
  /// Handler for the 'textDocument/references' request.
  void onReference(const ReferenceParams &params,
                   Callback<std::vector<Location>> reply);

  //===--------------------------------------------------------------------===//
  // Hover

  /// Handler for the 'textDocument/hover' request.
  void onHover(const TextDocumentPositionParams &params,
               Callback<Optional<Hover>> reply);

  //===--------------------------------------------------------------------===//
  // Document Symbols

  /// Handler for the 'textDocument/documentSymbol' request.
  void onDocumentSymbol(const DocumentSymbolParams &params,
                        Callback<std::vector<DocumentSymbol>> reply);

  //===--------------------------------------------------------------------===//
  // Code Completion

  /// Handler for the 'textDocument/completion' request.
  void onCompletion(const CompletionParams &params,
                    Callback<CompletionList> reply);

  //===--------------------------------------------------------------------===//
  // Fields
  //===--------------------------------------------------------------------===//

  /// The MLIR server that tracks open documents and answers the queries
  /// forwarded by the handlers above.
  MLIRServer &server;
  /// The transport connecting to the client; LSPServer::run() drives its main
  /// loop.
  JSONTransport &transport;

  /// An outgoing notification used to send diagnostics to the client when they
  /// are ready to be processed.
  OutgoingNotification<PublishDiagnosticsParams> publishDiagnostics;

  /// Used to indicate that the 'shutdown' request was received from the
  /// Language Server client.
  bool shutdownRequestReceived = false;
};
86 
87 //===----------------------------------------------------------------------===//
88 // Initialization
89 
90 void LSPServer::Impl::onInitialize(const InitializeParams &params,
91                                    Callback<llvm::json::Value> reply) {
92   // Send a response with the capabilities of this server.
93   llvm::json::Object serverCaps{
94       {"textDocumentSync",
95        llvm::json::Object{
96            {"openClose", true},
97            {"change", (int)TextDocumentSyncKind::Full},
98            {"save", true},
99        }},
100       {"completionProvider",
101        llvm::json::Object{
102            {"allCommitCharacters",
103             {"\t", ";", ",", "+", "-", "/", "*", "&", "?", ".", "=", "|"}},
104            {"resolveProvider", false},
105            {"triggerCharacters",
106             {".", "%", "^", "!", "#", "(", ",", "<", ":", "[", " ", "\"", "/"}},
107        }},
108       {"definitionProvider", true},
109       {"referencesProvider", true},
110       {"hoverProvider", true},
111 
112       // For now we only support documenting symbols when the client supports
113       // hierarchical symbols.
114       {"documentSymbolProvider",
115        params.capabilities.hierarchicalDocumentSymbol},
116   };
117 
118   llvm::json::Object result{
119       {{"serverInfo",
120         llvm::json::Object{{"name", "mlir-lsp-server"}, {"version", "0.0.0"}}},
121        {"capabilities", std::move(serverCaps)}}};
122   reply(std::move(result));
123 }
124 void LSPServer::Impl::onInitialized(const InitializedParams &) {}
125 void LSPServer::Impl::onShutdown(const NoParams &,
126                                  Callback<std::nullptr_t> reply) {
127   shutdownRequestReceived = true;
128   reply(nullptr);
129 }
130 
131 //===----------------------------------------------------------------------===//
132 // Document Change
133 
134 void LSPServer::Impl::onDocumentDidOpen(
135     const DidOpenTextDocumentParams &params) {
136   PublishDiagnosticsParams diagParams(params.textDocument.uri,
137                                       params.textDocument.version);
138   server.addOrUpdateDocument(params.textDocument.uri, params.textDocument.text,
139                              params.textDocument.version,
140                              diagParams.diagnostics);
141 
142   // Publish any recorded diagnostics.
143   publishDiagnostics(diagParams);
144 }
145 void LSPServer::Impl::onDocumentDidClose(
146     const DidCloseTextDocumentParams &params) {
147   Optional<int64_t> version = server.removeDocument(params.textDocument.uri);
148   if (!version)
149     return;
150 
151   // Empty out the diagnostics shown for this document. This will clear out
152   // anything currently displayed by the client for this document (e.g. in the
153   // "Problems" pane of VSCode).
154   publishDiagnostics(
155       PublishDiagnosticsParams(params.textDocument.uri, *version));
156 }
157 void LSPServer::Impl::onDocumentDidChange(
158     const DidChangeTextDocumentParams &params) {
159   // TODO: We currently only support full document updates, we should refactor
160   // to avoid this.
161   if (params.contentChanges.size() != 1)
162     return;
163   PublishDiagnosticsParams diagParams(params.textDocument.uri,
164                                       params.textDocument.version);
165   server.addOrUpdateDocument(
166       params.textDocument.uri, params.contentChanges.front().text,
167       params.textDocument.version, diagParams.diagnostics);
168 
169   // Publish any recorded diagnostics.
170   publishDiagnostics(diagParams);
171 }
172 
173 //===----------------------------------------------------------------------===//
174 // Definitions and References
175 
176 void LSPServer::Impl::onGoToDefinition(const TextDocumentPositionParams &params,
177                                        Callback<std::vector<Location>> reply) {
178   std::vector<Location> locations;
179   server.getLocationsOf(params.textDocument.uri, params.position, locations);
180   reply(std::move(locations));
181 }
182 
183 void LSPServer::Impl::onReference(const ReferenceParams &params,
184                                   Callback<std::vector<Location>> reply) {
185   std::vector<Location> locations;
186   server.findReferencesOf(params.textDocument.uri, params.position, locations);
187   reply(std::move(locations));
188 }
189 
190 //===----------------------------------------------------------------------===//
191 // Hover
192 
193 void LSPServer::Impl::onHover(const TextDocumentPositionParams &params,
194                               Callback<Optional<Hover>> reply) {
195   reply(server.findHover(params.textDocument.uri, params.position));
196 }
197 
198 //===----------------------------------------------------------------------===//
199 // Document Symbols
200 
201 void LSPServer::Impl::onDocumentSymbol(
202     const DocumentSymbolParams &params,
203     Callback<std::vector<DocumentSymbol>> reply) {
204   std::vector<DocumentSymbol> symbols;
205   server.findDocumentSymbols(params.textDocument.uri, symbols);
206   reply(std::move(symbols));
207 }
208 
209 //===----------------------------------------------------------------------===//
210 // Code Completion
211 
212 void LSPServer::Impl::onCompletion(const CompletionParams &params,
213                                    Callback<CompletionList> reply) {
214   reply(server.getCodeCompletion(params.textDocument.uri, params.position));
215 }
216 
217 //===----------------------------------------------------------------------===//
218 // LSPServer
219 //===----------------------------------------------------------------------===//
220 
221 LSPServer::LSPServer(MLIRServer &server, JSONTransport &transport)
222     : impl(std::make_unique<Impl>(server, transport)) {}
223 LSPServer::~LSPServer() = default;
224 
225 LogicalResult LSPServer::run() {
226   MessageHandler messageHandler(impl->transport);
227 
228   // Initialization
229   messageHandler.method("initialize", impl.get(), &Impl::onInitialize);
230   messageHandler.notification("initialized", impl.get(), &Impl::onInitialized);
231   messageHandler.method("shutdown", impl.get(), &Impl::onShutdown);
232 
233   // Document Changes
234   messageHandler.notification("textDocument/didOpen", impl.get(),
235                               &Impl::onDocumentDidOpen);
236   messageHandler.notification("textDocument/didClose", impl.get(),
237                               &Impl::onDocumentDidClose);
238   messageHandler.notification("textDocument/didChange", impl.get(),
239                               &Impl::onDocumentDidChange);
240 
241   // Definitions and References
242   messageHandler.method("textDocument/definition", impl.get(),
243                         &Impl::onGoToDefinition);
244   messageHandler.method("textDocument/references", impl.get(),
245                         &Impl::onReference);
246 
247   // Hover
248   messageHandler.method("textDocument/hover", impl.get(), &Impl::onHover);
249 
250   // Document Symbols
251   messageHandler.method("textDocument/documentSymbol", impl.get(),
252                         &Impl::onDocumentSymbol);
253 
254   // Code Completion
255   messageHandler.method("textDocument/completion", impl.get(),
256                         &Impl::onCompletion);
257 
258   // Diagnostics
259   impl->publishDiagnostics =
260       messageHandler.outgoingNotification<PublishDiagnosticsParams>(
261           "textDocument/publishDiagnostics");
262 
263   // Run the main loop of the transport.
264   LogicalResult result = success();
265   if (llvm::Error error = impl->transport.run(messageHandler)) {
266     Logger::error("Transport error: {0}", error);
267     llvm::consumeError(std::move(error));
268     result = failure();
269   } else {
270     result = success(impl->shutdownRequestReceived);
271   }
272   return result;
273 }
274