1 //===- LSPServer.cpp - MLIR Language Server -------------------------------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 
9 #include "LSPServer.h"
10 #include "../lsp-server-support/Logging.h"
11 #include "../lsp-server-support/Protocol.h"
12 #include "../lsp-server-support/Transport.h"
13 #include "MLIRServer.h"
14 #include "llvm/ADT/FunctionExtras.h"
15 #include "llvm/ADT/StringMap.h"
16 
17 #define DEBUG_TYPE "mlir-lsp-server"
18 
19 using namespace mlir;
20 using namespace mlir::lsp;
21 
22 //===----------------------------------------------------------------------===//
23 // LSPServer::Impl
24 //===----------------------------------------------------------------------===//
25 
26 struct LSPServer::Impl {
27   Impl(MLIRServer &server, JSONTransport &transport)
28       : server(server), transport(transport) {}
29 
30   //===--------------------------------------------------------------------===//
31   // Initialization
32 
33   void onInitialize(const InitializeParams &params,
34                     Callback<llvm::json::Value> reply);
35   void onInitialized(const InitializedParams &params);
36   void onShutdown(const NoParams &params, Callback<std::nullptr_t> reply);
37 
38   //===--------------------------------------------------------------------===//
39   // Document Change
40 
41   void onDocumentDidOpen(const DidOpenTextDocumentParams &params);
42   void onDocumentDidClose(const DidCloseTextDocumentParams &params);
43   void onDocumentDidChange(const DidChangeTextDocumentParams &params);
44 
45   //===--------------------------------------------------------------------===//
46   // Definitions and References
47 
48   void onGoToDefinition(const TextDocumentPositionParams &params,
49                         Callback<std::vector<Location>> reply);
50   void onReference(const ReferenceParams &params,
51                    Callback<std::vector<Location>> reply);
52 
53   //===--------------------------------------------------------------------===//
54   // Hover
55 
56   void onHover(const TextDocumentPositionParams &params,
57                Callback<Optional<Hover>> reply);
58 
59   //===--------------------------------------------------------------------===//
60   // Document Symbols
61 
62   void onDocumentSymbol(const DocumentSymbolParams &params,
63                         Callback<std::vector<DocumentSymbol>> reply);
64 
65   //===--------------------------------------------------------------------===//
66   // Code Completion
67 
68   void onCompletion(const CompletionParams &params,
69                     Callback<CompletionList> reply);
70 
71   //===--------------------------------------------------------------------===//
72   // Fields
73   //===--------------------------------------------------------------------===//
74 
75   MLIRServer &server;
76   JSONTransport &transport;
77 
78   /// An outgoing notification used to send diagnostics to the client when they
79   /// are ready to be processed.
80   OutgoingNotification<PublishDiagnosticsParams> publishDiagnostics;
81 
82   /// Used to indicate that the 'shutdown' request was received from the
83   /// Language Server client.
84   bool shutdownRequestReceived = false;
85 };
86 
87 //===----------------------------------------------------------------------===//
88 // Initialization
89 
90 void LSPServer::Impl::onInitialize(const InitializeParams &params,
91                                    Callback<llvm::json::Value> reply) {
92   // Send a response with the capabilities of this server.
93   llvm::json::Object serverCaps{
94       {"textDocumentSync",
95        llvm::json::Object{
96            {"openClose", true},
97            {"change", (int)TextDocumentSyncKind::Full},
98            {"save", true},
99        }},
100       {"completionProvider",
101        llvm::json::Object{
102            {"allCommitCharacters",
103             {"\t", "(", ")", "[", "]", "<", ">", ";", ",", "+", "-", "/", "*",
104              "&", "?", ".", "=", "|"}},
105            {"resolveProvider", false},
106            {"triggerCharacters",
107             {".", "%", "^", "!", "#", "(", ",", "<", ":", "[", " ", "\"", "/"}},
108        }},
109       {"definitionProvider", true},
110       {"referencesProvider", true},
111       {"hoverProvider", true},
112 
113       // For now we only support documenting symbols when the client supports
114       // hierarchical symbols.
115       {"documentSymbolProvider",
116        params.capabilities.hierarchicalDocumentSymbol},
117   };
118 
119   llvm::json::Object result{
120       {{"serverInfo",
121         llvm::json::Object{{"name", "mlir-lsp-server"}, {"version", "0.0.0"}}},
122        {"capabilities", std::move(serverCaps)}}};
123   reply(std::move(result));
124 }
125 void LSPServer::Impl::onInitialized(const InitializedParams &) {}
126 void LSPServer::Impl::onShutdown(const NoParams &,
127                                  Callback<std::nullptr_t> reply) {
128   shutdownRequestReceived = true;
129   reply(nullptr);
130 }
131 
132 //===----------------------------------------------------------------------===//
133 // Document Change
134 
135 void LSPServer::Impl::onDocumentDidOpen(
136     const DidOpenTextDocumentParams &params) {
137   PublishDiagnosticsParams diagParams(params.textDocument.uri,
138                                       params.textDocument.version);
139   server.addOrUpdateDocument(params.textDocument.uri, params.textDocument.text,
140                              params.textDocument.version,
141                              diagParams.diagnostics);
142 
143   // Publish any recorded diagnostics.
144   publishDiagnostics(diagParams);
145 }
146 void LSPServer::Impl::onDocumentDidClose(
147     const DidCloseTextDocumentParams &params) {
148   Optional<int64_t> version = server.removeDocument(params.textDocument.uri);
149   if (!version)
150     return;
151 
152   // Empty out the diagnostics shown for this document. This will clear out
153   // anything currently displayed by the client for this document (e.g. in the
154   // "Problems" pane of VSCode).
155   publishDiagnostics(
156       PublishDiagnosticsParams(params.textDocument.uri, *version));
157 }
158 void LSPServer::Impl::onDocumentDidChange(
159     const DidChangeTextDocumentParams &params) {
160   // TODO: We currently only support full document updates, we should refactor
161   // to avoid this.
162   if (params.contentChanges.size() != 1)
163     return;
164   PublishDiagnosticsParams diagParams(params.textDocument.uri,
165                                       params.textDocument.version);
166   server.addOrUpdateDocument(
167       params.textDocument.uri, params.contentChanges.front().text,
168       params.textDocument.version, diagParams.diagnostics);
169 
170   // Publish any recorded diagnostics.
171   publishDiagnostics(diagParams);
172 }
173 
174 //===----------------------------------------------------------------------===//
175 // Definitions and References
176 
177 void LSPServer::Impl::onGoToDefinition(const TextDocumentPositionParams &params,
178                                        Callback<std::vector<Location>> reply) {
179   std::vector<Location> locations;
180   server.getLocationsOf(params.textDocument.uri, params.position, locations);
181   reply(std::move(locations));
182 }
183 
184 void LSPServer::Impl::onReference(const ReferenceParams &params,
185                                   Callback<std::vector<Location>> reply) {
186   std::vector<Location> locations;
187   server.findReferencesOf(params.textDocument.uri, params.position, locations);
188   reply(std::move(locations));
189 }
190 
191 //===----------------------------------------------------------------------===//
192 // Hover
193 
194 void LSPServer::Impl::onHover(const TextDocumentPositionParams &params,
195                               Callback<Optional<Hover>> reply) {
196   reply(server.findHover(params.textDocument.uri, params.position));
197 }
198 
199 //===----------------------------------------------------------------------===//
200 // Document Symbols
201 
202 void LSPServer::Impl::onDocumentSymbol(
203     const DocumentSymbolParams &params,
204     Callback<std::vector<DocumentSymbol>> reply) {
205   std::vector<DocumentSymbol> symbols;
206   server.findDocumentSymbols(params.textDocument.uri, symbols);
207   reply(std::move(symbols));
208 }
209 
210 //===----------------------------------------------------------------------===//
211 // Code Completion
212 
213 void LSPServer::Impl::onCompletion(const CompletionParams &params,
214                                    Callback<CompletionList> reply) {
215   reply(server.getCodeCompletion(params.textDocument.uri, params.position));
216 }
217 
218 //===----------------------------------------------------------------------===//
219 // LSPServer
220 //===----------------------------------------------------------------------===//
221 
222 LSPServer::LSPServer(MLIRServer &server, JSONTransport &transport)
223     : impl(std::make_unique<Impl>(server, transport)) {}
224 LSPServer::~LSPServer() = default;
225 
226 LogicalResult LSPServer::run() {
227   MessageHandler messageHandler(impl->transport);
228 
229   // Initialization
230   messageHandler.method("initialize", impl.get(), &Impl::onInitialize);
231   messageHandler.notification("initialized", impl.get(), &Impl::onInitialized);
232   messageHandler.method("shutdown", impl.get(), &Impl::onShutdown);
233 
234   // Document Changes
235   messageHandler.notification("textDocument/didOpen", impl.get(),
236                               &Impl::onDocumentDidOpen);
237   messageHandler.notification("textDocument/didClose", impl.get(),
238                               &Impl::onDocumentDidClose);
239   messageHandler.notification("textDocument/didChange", impl.get(),
240                               &Impl::onDocumentDidChange);
241 
242   // Definitions and References
243   messageHandler.method("textDocument/definition", impl.get(),
244                         &Impl::onGoToDefinition);
245   messageHandler.method("textDocument/references", impl.get(),
246                         &Impl::onReference);
247 
248   // Hover
249   messageHandler.method("textDocument/hover", impl.get(), &Impl::onHover);
250 
251   // Document Symbols
252   messageHandler.method("textDocument/documentSymbol", impl.get(),
253                         &Impl::onDocumentSymbol);
254 
255   // Code Completion
256   messageHandler.method("textDocument/completion", impl.get(),
257                         &Impl::onCompletion);
258 
259   // Diagnostics
260   impl->publishDiagnostics =
261       messageHandler.outgoingNotification<PublishDiagnosticsParams>(
262           "textDocument/publishDiagnostics");
263 
264   // Run the main loop of the transport.
265   LogicalResult result = success();
266   if (llvm::Error error = impl->transport.run(messageHandler)) {
267     Logger::error("Transport error: {0}", error);
268     llvm::consumeError(std::move(error));
269     result = failure();
270   } else {
271     result = success(impl->shutdownRequestReceived);
272   }
273   return result;
274 }
275