1 //===- PDLLServer.cpp - PDLL Language Server ------------------------------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8
9 #include "PDLLServer.h"
10
11 #include "../lsp-server-support/CompilationDatabase.h"
12 #include "../lsp-server-support/Logging.h"
13 #include "../lsp-server-support/SourceMgrUtils.h"
14 #include "Protocol.h"
15 #include "mlir/IR/BuiltinOps.h"
16 #include "mlir/Tools/PDLL/AST/Context.h"
17 #include "mlir/Tools/PDLL/AST/Nodes.h"
18 #include "mlir/Tools/PDLL/AST/Types.h"
19 #include "mlir/Tools/PDLL/CodeGen/CPPGen.h"
20 #include "mlir/Tools/PDLL/CodeGen/MLIRGen.h"
21 #include "mlir/Tools/PDLL/ODS/Constraint.h"
22 #include "mlir/Tools/PDLL/ODS/Context.h"
23 #include "mlir/Tools/PDLL/ODS/Dialect.h"
24 #include "mlir/Tools/PDLL/ODS/Operation.h"
25 #include "mlir/Tools/PDLL/Parser/CodeComplete.h"
26 #include "mlir/Tools/PDLL/Parser/Parser.h"
27 #include "llvm/ADT/IntervalMap.h"
28 #include "llvm/ADT/StringMap.h"
29 #include "llvm/ADT/StringSet.h"
30 #include "llvm/ADT/TypeSwitch.h"
31 #include "llvm/Support/FileSystem.h"
32 #include "llvm/Support/Path.h"
33
34 using namespace mlir;
35 using namespace mlir::pdll;
36
37 /// Returns a language server uri for the given source location. `mainFileURI`
38 /// corresponds to the uri for the main file of the source manager.
getURIFromLoc(llvm::SourceMgr & mgr,SMRange loc,const lsp::URIForFile & mainFileURI)39 static lsp::URIForFile getURIFromLoc(llvm::SourceMgr &mgr, SMRange loc,
40 const lsp::URIForFile &mainFileURI) {
41 int bufferId = mgr.FindBufferContainingLoc(loc.Start);
42 if (bufferId == 0 || bufferId == static_cast<int>(mgr.getMainFileID()))
43 return mainFileURI;
44 llvm::Expected<lsp::URIForFile> fileForLoc = lsp::URIForFile::fromFile(
45 mgr.getBufferInfo(bufferId).Buffer->getBufferIdentifier());
46 if (fileForLoc)
47 return *fileForLoc;
48 lsp::Logger::error("Failed to create URI for include file: {0}",
49 llvm::toString(fileForLoc.takeError()));
50 return mainFileURI;
51 }
52
53 /// Returns true if the given location is in the main file of the source
54 /// manager.
isMainFileLoc(llvm::SourceMgr & mgr,SMRange loc)55 static bool isMainFileLoc(llvm::SourceMgr &mgr, SMRange loc) {
56 return mgr.FindBufferContainingLoc(loc.Start) == mgr.getMainFileID();
57 }
58
59 /// Returns a language server location from the given source range.
getLocationFromLoc(llvm::SourceMgr & mgr,SMRange range,const lsp::URIForFile & uri)60 static lsp::Location getLocationFromLoc(llvm::SourceMgr &mgr, SMRange range,
61 const lsp::URIForFile &uri) {
62 return lsp::Location(getURIFromLoc(mgr, range, uri), lsp::Range(mgr, range));
63 }
64
65 /// Returns true if the given range contains the given source location. Note
66 /// that this has different behavior than SMRange because it is inclusive of the
67 /// end location.
contains(SMRange range,SMLoc loc)68 static bool contains(SMRange range, SMLoc loc) {
69 return range.Start.getPointer() <= loc.getPointer() &&
70 loc.getPointer() <= range.End.getPointer();
71 }
72
73 /// Convert the given MLIR diagnostic to the LSP form.
74 static Optional<lsp::Diagnostic>
getLspDiagnoticFromDiag(llvm::SourceMgr & sourceMgr,const ast::Diagnostic & diag,const lsp::URIForFile & uri)75 getLspDiagnoticFromDiag(llvm::SourceMgr &sourceMgr, const ast::Diagnostic &diag,
76 const lsp::URIForFile &uri) {
77 lsp::Diagnostic lspDiag;
78 lspDiag.source = "pdll";
79
80 // FIXME: Right now all of the diagnostics are treated as parser issues, but
81 // some are parser and some are verifier.
82 lspDiag.category = "Parse Error";
83
84 // Try to grab a file location for this diagnostic.
85 lsp::Location loc = getLocationFromLoc(sourceMgr, diag.getLocation(), uri);
86 lspDiag.range = loc.range;
87
88 // Skip diagnostics that weren't emitted within the main file.
89 if (loc.uri != uri)
90 return llvm::None;
91
92 // Convert the severity for the diagnostic.
93 switch (diag.getSeverity()) {
94 case ast::Diagnostic::Severity::DK_Note:
95 llvm_unreachable("expected notes to be handled separately");
96 case ast::Diagnostic::Severity::DK_Warning:
97 lspDiag.severity = lsp::DiagnosticSeverity::Warning;
98 break;
99 case ast::Diagnostic::Severity::DK_Error:
100 lspDiag.severity = lsp::DiagnosticSeverity::Error;
101 break;
102 case ast::Diagnostic::Severity::DK_Remark:
103 lspDiag.severity = lsp::DiagnosticSeverity::Information;
104 break;
105 }
106 lspDiag.message = diag.getMessage().str();
107
108 // Attach any notes to the main diagnostic as related information.
109 std::vector<lsp::DiagnosticRelatedInformation> relatedDiags;
110 for (const ast::Diagnostic ¬e : diag.getNotes()) {
111 relatedDiags.emplace_back(
112 getLocationFromLoc(sourceMgr, note.getLocation(), uri),
113 note.getMessage().str());
114 }
115 if (!relatedDiags.empty())
116 lspDiag.relatedInformation = std::move(relatedDiags);
117
118 return lspDiag;
119 }
120
121 /// Get or extract the documentation for the given decl.
getDocumentationFor(llvm::SourceMgr & sourceMgr,const ast::Decl * decl)122 static Optional<std::string> getDocumentationFor(llvm::SourceMgr &sourceMgr,
123 const ast::Decl *decl) {
124 // If the decl already had documentation set, use it.
125 if (Optional<StringRef> doc = decl->getDocComment())
126 return doc->str();
127
128 // If the decl doesn't yet have documentation, try to extract it from the
129 // source file. This is a heuristic, and isn't intended to cover every case,
130 // but should cover the most common. We essentially look for a comment
131 // preceding the decl, and if we find one, use that as the documentation.
132 SMLoc startLoc = decl->getLoc().Start;
133 if (!startLoc.isValid())
134 return llvm::None;
135 int bufferId = sourceMgr.FindBufferContainingLoc(startLoc);
136 if (bufferId == 0)
137 return llvm::None;
138 const char *bufferStart =
139 sourceMgr.getMemoryBuffer(bufferId)->getBufferStart();
140 StringRef buffer(bufferStart, startLoc.getPointer() - bufferStart);
141
142 // Pop the last line from the buffer string.
143 auto popLastLine = [&]() -> Optional<StringRef> {
144 size_t newlineOffset = buffer.find_last_of("\n");
145 if (newlineOffset == StringRef::npos)
146 return llvm::None;
147 StringRef lastLine = buffer.drop_front(newlineOffset).trim();
148 buffer = buffer.take_front(newlineOffset);
149 return lastLine;
150 };
151
152 // Try to pop the current line, which contains the decl.
153 if (!popLastLine())
154 return llvm::None;
155
156 // Try to parse a comment string from the source file.
157 SmallVector<StringRef> commentLines;
158 while (Optional<StringRef> line = popLastLine()) {
159 // Check for a comment at the beginning of the line.
160 if (!line->startswith("//"))
161 break;
162
163 // Extract the document string from the comment.
164 commentLines.push_back(line->drop_while([](char c) { return c == '/'; }));
165 }
166
167 if (commentLines.empty())
168 return llvm::None;
169 return llvm::join(llvm::reverse(commentLines), "\n");
170 }
171
172 //===----------------------------------------------------------------------===//
173 // PDLIndex
174 //===----------------------------------------------------------------------===//
175
176 namespace {
177 struct PDLIndexSymbol {
PDLIndexSymbol__anon9e2fdad80311::PDLIndexSymbol178 explicit PDLIndexSymbol(const ast::Decl *definition)
179 : definition(definition) {}
PDLIndexSymbol__anon9e2fdad80311::PDLIndexSymbol180 explicit PDLIndexSymbol(const ods::Operation *definition)
181 : definition(definition) {}
182
183 /// Return the location of the definition of this symbol.
getDefLoc__anon9e2fdad80311::PDLIndexSymbol184 SMRange getDefLoc() const {
185 if (const ast::Decl *decl = definition.dyn_cast<const ast::Decl *>()) {
186 const ast::Name *declName = decl->getName();
187 return declName ? declName->getLoc() : decl->getLoc();
188 }
189 return definition.get<const ods::Operation *>()->getLoc();
190 }
191
192 /// The main definition of the symbol.
193 PointerUnion<const ast::Decl *, const ods::Operation *> definition;
194 /// The set of references to the symbol.
195 std::vector<SMRange> references;
196 };
197
198 /// This class provides an index for definitions/uses within a PDL document.
199 /// It provides efficient lookup of a definition given an input source range.
200 class PDLIndex {
201 public:
PDLIndex()202 PDLIndex() : intervalMap(allocator) {}
203
204 /// Initialize the index with the given ast::Module.
205 void initialize(const ast::Module &module, const ods::Context &odsContext);
206
207 /// Lookup a symbol for the given location. Returns nullptr if no symbol could
208 /// be found. If provided, `overlappedRange` is set to the range that the
209 /// provided `loc` overlapped with.
210 const PDLIndexSymbol *lookup(SMLoc loc,
211 SMRange *overlappedRange = nullptr) const;
212
213 private:
214 /// The type of interval map used to store source references. SMRange is
215 /// half-open, so we also need to use a half-open interval map.
216 using MapT =
217 llvm::IntervalMap<const char *, const PDLIndexSymbol *,
218 llvm::IntervalMapImpl::NodeSizer<
219 const char *, const PDLIndexSymbol *>::LeafSize,
220 llvm::IntervalMapHalfOpenInfo<const char *>>;
221
222 /// An allocator for the interval map.
223 MapT::Allocator allocator;
224
225 /// An interval map containing a corresponding definition mapped to a source
226 /// interval.
227 MapT intervalMap;
228
229 /// A mapping between definitions and their corresponding symbol.
230 DenseMap<const void *, std::unique_ptr<PDLIndexSymbol>> defToSymbol;
231 };
232 } // namespace
233
initialize(const ast::Module & module,const ods::Context & odsContext)234 void PDLIndex::initialize(const ast::Module &module,
235 const ods::Context &odsContext) {
236 auto getOrInsertDef = [&](const auto *def) -> PDLIndexSymbol * {
237 auto it = defToSymbol.try_emplace(def, nullptr);
238 if (it.second)
239 it.first->second = std::make_unique<PDLIndexSymbol>(def);
240 return &*it.first->second;
241 };
242 auto insertDeclRef = [&](PDLIndexSymbol *sym, SMRange refLoc,
243 bool isDef = false) {
244 const char *startLoc = refLoc.Start.getPointer();
245 const char *endLoc = refLoc.End.getPointer();
246 if (!intervalMap.overlaps(startLoc, endLoc)) {
247 intervalMap.insert(startLoc, endLoc, sym);
248 if (!isDef)
249 sym->references.push_back(refLoc);
250 }
251 };
252 auto insertODSOpRef = [&](StringRef opName, SMRange refLoc) {
253 const ods::Operation *odsOp = odsContext.lookupOperation(opName);
254 if (!odsOp)
255 return;
256
257 PDLIndexSymbol *symbol = getOrInsertDef(odsOp);
258 insertDeclRef(symbol, odsOp->getLoc(), /*isDef=*/true);
259 insertDeclRef(symbol, refLoc);
260 };
261
262 module.walk([&](const ast::Node *node) {
263 // Handle references to PDL decls.
264 if (const auto *decl = dyn_cast<ast::OpNameDecl>(node)) {
265 if (Optional<StringRef> name = decl->getName())
266 insertODSOpRef(*name, decl->getLoc());
267 } else if (const ast::Decl *decl = dyn_cast<ast::Decl>(node)) {
268 const ast::Name *name = decl->getName();
269 if (!name)
270 return;
271 PDLIndexSymbol *declSym = getOrInsertDef(decl);
272 insertDeclRef(declSym, name->getLoc(), /*isDef=*/true);
273
274 if (const auto *varDecl = dyn_cast<ast::VariableDecl>(decl)) {
275 // Record references to any constraints.
276 for (const auto &it : varDecl->getConstraints())
277 insertDeclRef(getOrInsertDef(it.constraint), it.referenceLoc);
278 }
279 } else if (const auto *expr = dyn_cast<ast::DeclRefExpr>(node)) {
280 insertDeclRef(getOrInsertDef(expr->getDecl()), expr->getLoc());
281 }
282 });
283 }
284
lookup(SMLoc loc,SMRange * overlappedRange) const285 const PDLIndexSymbol *PDLIndex::lookup(SMLoc loc,
286 SMRange *overlappedRange) const {
287 auto it = intervalMap.find(loc.getPointer());
288 if (!it.valid() || loc.getPointer() < it.start())
289 return nullptr;
290
291 if (overlappedRange) {
292 *overlappedRange = SMRange(SMLoc::getFromPointer(it.start()),
293 SMLoc::getFromPointer(it.stop()));
294 }
295 return it.value();
296 }
297
298 //===----------------------------------------------------------------------===//
299 // PDLDocument
300 //===----------------------------------------------------------------------===//
301
302 namespace {
303 /// This class represents all of the information pertaining to a specific PDL
304 /// document.
305 struct PDLDocument {
306 PDLDocument(const lsp::URIForFile &uri, StringRef contents,
307 const std::vector<std::string> &extraDirs,
308 std::vector<lsp::Diagnostic> &diagnostics);
309 PDLDocument(const PDLDocument &) = delete;
310 PDLDocument &operator=(const PDLDocument &) = delete;
311
312 //===--------------------------------------------------------------------===//
313 // Definitions and References
314 //===--------------------------------------------------------------------===//
315
316 void getLocationsOf(const lsp::URIForFile &uri, const lsp::Position &defPos,
317 std::vector<lsp::Location> &locations);
318 void findReferencesOf(const lsp::URIForFile &uri, const lsp::Position &pos,
319 std::vector<lsp::Location> &references);
320
321 //===--------------------------------------------------------------------===//
322 // Document Links
323 //===--------------------------------------------------------------------===//
324
325 void getDocumentLinks(const lsp::URIForFile &uri,
326 std::vector<lsp::DocumentLink> &links);
327
328 //===--------------------------------------------------------------------===//
329 // Hover
330 //===--------------------------------------------------------------------===//
331
332 Optional<lsp::Hover> findHover(const lsp::URIForFile &uri,
333 const lsp::Position &hoverPos);
334 Optional<lsp::Hover> findHover(const ast::Decl *decl,
335 const SMRange &hoverRange);
336 lsp::Hover buildHoverForOpName(const ods::Operation *op,
337 const SMRange &hoverRange);
338 lsp::Hover buildHoverForVariable(const ast::VariableDecl *varDecl,
339 const SMRange &hoverRange);
340 lsp::Hover buildHoverForPattern(const ast::PatternDecl *decl,
341 const SMRange &hoverRange);
342 lsp::Hover buildHoverForCoreConstraint(const ast::CoreConstraintDecl *decl,
343 const SMRange &hoverRange);
344 template <typename T>
345 lsp::Hover buildHoverForUserConstraintOrRewrite(StringRef typeName,
346 const T *decl,
347 const SMRange &hoverRange);
348
349 //===--------------------------------------------------------------------===//
350 // Document Symbols
351 //===--------------------------------------------------------------------===//
352
353 void findDocumentSymbols(std::vector<lsp::DocumentSymbol> &symbols);
354
355 //===--------------------------------------------------------------------===//
356 // Code Completion
357 //===--------------------------------------------------------------------===//
358
359 lsp::CompletionList getCodeCompletion(const lsp::URIForFile &uri,
360 const lsp::Position &completePos);
361
362 //===--------------------------------------------------------------------===//
363 // Signature Help
364 //===--------------------------------------------------------------------===//
365
366 lsp::SignatureHelp getSignatureHelp(const lsp::URIForFile &uri,
367 const lsp::Position &helpPos);
368
369 //===--------------------------------------------------------------------===//
370 // Inlay Hints
371 //===--------------------------------------------------------------------===//
372
373 void getInlayHints(const lsp::URIForFile &uri, const lsp::Range &range,
374 std::vector<lsp::InlayHint> &inlayHints);
375 void getInlayHintsFor(const ast::VariableDecl *decl,
376 const lsp::URIForFile &uri,
377 std::vector<lsp::InlayHint> &inlayHints);
378 void getInlayHintsFor(const ast::CallExpr *expr, const lsp::URIForFile &uri,
379 std::vector<lsp::InlayHint> &inlayHints);
380 void getInlayHintsFor(const ast::OperationExpr *expr,
381 const lsp::URIForFile &uri,
382 std::vector<lsp::InlayHint> &inlayHints);
383
384 /// Add a parameter hint for the given expression using `label`.
385 void addParameterHintFor(std::vector<lsp::InlayHint> &inlayHints,
386 const ast::Expr *expr, StringRef label);
387
388 //===--------------------------------------------------------------------===//
389 // PDLL ViewOutput
390 //===--------------------------------------------------------------------===//
391
392 void getPDLLViewOutput(raw_ostream &os, lsp::PDLLViewOutputKind kind);
393
394 //===--------------------------------------------------------------------===//
395 // Fields
396 //===--------------------------------------------------------------------===//
397
398 /// The include directories for this file.
399 std::vector<std::string> includeDirs;
400
401 /// The source manager containing the contents of the input file.
402 llvm::SourceMgr sourceMgr;
403
404 /// The ODS and AST contexts.
405 ods::Context odsContext;
406 ast::Context astContext;
407
408 /// The parsed AST module, or failure if the file wasn't valid.
409 FailureOr<ast::Module *> astModule;
410
411 /// The index of the parsed module.
412 PDLIndex index;
413
414 /// The set of includes of the parsed module.
415 SmallVector<lsp::SourceMgrInclude> parsedIncludes;
416 };
417 } // namespace
418
PDLDocument(const lsp::URIForFile & uri,StringRef contents,const std::vector<std::string> & extraDirs,std::vector<lsp::Diagnostic> & diagnostics)419 PDLDocument::PDLDocument(const lsp::URIForFile &uri, StringRef contents,
420 const std::vector<std::string> &extraDirs,
421 std::vector<lsp::Diagnostic> &diagnostics)
422 : astContext(odsContext) {
423 auto memBuffer = llvm::MemoryBuffer::getMemBufferCopy(contents, uri.file());
424 if (!memBuffer) {
425 lsp::Logger::error("Failed to create memory buffer for file", uri.file());
426 return;
427 }
428
429 // Build the set of include directories for this file.
430 llvm::SmallString<32> uriDirectory(uri.file());
431 llvm::sys::path::remove_filename(uriDirectory);
432 includeDirs.push_back(uriDirectory.str().str());
433 includeDirs.insert(includeDirs.end(), extraDirs.begin(), extraDirs.end());
434
435 sourceMgr.setIncludeDirs(includeDirs);
436 sourceMgr.AddNewSourceBuffer(std::move(memBuffer), SMLoc());
437
438 astContext.getDiagEngine().setHandlerFn([&](const ast::Diagnostic &diag) {
439 if (auto lspDiag = getLspDiagnoticFromDiag(sourceMgr, diag, uri))
440 diagnostics.push_back(std::move(*lspDiag));
441 });
442 astModule = parsePDLLAST(astContext, sourceMgr, /*enableDocumentation=*/true);
443
444 // Initialize the set of parsed includes.
445 lsp::gatherIncludeFiles(sourceMgr, parsedIncludes);
446
447 // If we failed to parse the module, there is nothing left to initialize.
448 if (failed(astModule))
449 return;
450
451 // Prepare the AST index with the parsed module.
452 index.initialize(**astModule, odsContext);
453 }
454
455 //===----------------------------------------------------------------------===//
456 // PDLDocument: Definitions and References
457 //===----------------------------------------------------------------------===//
458
getLocationsOf(const lsp::URIForFile & uri,const lsp::Position & defPos,std::vector<lsp::Location> & locations)459 void PDLDocument::getLocationsOf(const lsp::URIForFile &uri,
460 const lsp::Position &defPos,
461 std::vector<lsp::Location> &locations) {
462 SMLoc posLoc = defPos.getAsSMLoc(sourceMgr);
463 const PDLIndexSymbol *symbol = index.lookup(posLoc);
464 if (!symbol)
465 return;
466
467 locations.push_back(getLocationFromLoc(sourceMgr, symbol->getDefLoc(), uri));
468 }
469
findReferencesOf(const lsp::URIForFile & uri,const lsp::Position & pos,std::vector<lsp::Location> & references)470 void PDLDocument::findReferencesOf(const lsp::URIForFile &uri,
471 const lsp::Position &pos,
472 std::vector<lsp::Location> &references) {
473 SMLoc posLoc = pos.getAsSMLoc(sourceMgr);
474 const PDLIndexSymbol *symbol = index.lookup(posLoc);
475 if (!symbol)
476 return;
477
478 references.push_back(getLocationFromLoc(sourceMgr, symbol->getDefLoc(), uri));
479 for (SMRange refLoc : symbol->references)
480 references.push_back(getLocationFromLoc(sourceMgr, refLoc, uri));
481 }
482
//===----------------------------------------------------------------------===//
// PDLDocument: Document Links
//===----------------------------------------------------------------------===//
486
getDocumentLinks(const lsp::URIForFile & uri,std::vector<lsp::DocumentLink> & links)487 void PDLDocument::getDocumentLinks(const lsp::URIForFile &uri,
488 std::vector<lsp::DocumentLink> &links) {
489 for (const lsp::SourceMgrInclude &include : parsedIncludes)
490 links.emplace_back(include.range, include.uri);
491 }
492
493 //===----------------------------------------------------------------------===//
494 // PDLDocument: Hover
495 //===----------------------------------------------------------------------===//
496
findHover(const lsp::URIForFile & uri,const lsp::Position & hoverPos)497 Optional<lsp::Hover> PDLDocument::findHover(const lsp::URIForFile &uri,
498 const lsp::Position &hoverPos) {
499 SMLoc posLoc = hoverPos.getAsSMLoc(sourceMgr);
500
501 // Check for a reference to an include.
502 for (const lsp::SourceMgrInclude &include : parsedIncludes)
503 if (include.range.contains(hoverPos))
504 return include.buildHover();
505
506 // Find the symbol at the given location.
507 SMRange hoverRange;
508 const PDLIndexSymbol *symbol = index.lookup(posLoc, &hoverRange);
509 if (!symbol)
510 return llvm::None;
511
512 // Add hover for operation names.
513 if (const auto *op = symbol->definition.dyn_cast<const ods::Operation *>())
514 return buildHoverForOpName(op, hoverRange);
515 const auto *decl = symbol->definition.get<const ast::Decl *>();
516 return findHover(decl, hoverRange);
517 }
518
findHover(const ast::Decl * decl,const SMRange & hoverRange)519 Optional<lsp::Hover> PDLDocument::findHover(const ast::Decl *decl,
520 const SMRange &hoverRange) {
521 // Add hover for variables.
522 if (const auto *varDecl = dyn_cast<ast::VariableDecl>(decl))
523 return buildHoverForVariable(varDecl, hoverRange);
524
525 // Add hover for patterns.
526 if (const auto *patternDecl = dyn_cast<ast::PatternDecl>(decl))
527 return buildHoverForPattern(patternDecl, hoverRange);
528
529 // Add hover for core constraints.
530 if (const auto *cst = dyn_cast<ast::CoreConstraintDecl>(decl))
531 return buildHoverForCoreConstraint(cst, hoverRange);
532
533 // Add hover for user constraints.
534 if (const auto *cst = dyn_cast<ast::UserConstraintDecl>(decl))
535 return buildHoverForUserConstraintOrRewrite("Constraint", cst, hoverRange);
536
537 // Add hover for user rewrites.
538 if (const auto *rewrite = dyn_cast<ast::UserRewriteDecl>(decl))
539 return buildHoverForUserConstraintOrRewrite("Rewrite", rewrite, hoverRange);
540
541 return llvm::None;
542 }
543
buildHoverForOpName(const ods::Operation * op,const SMRange & hoverRange)544 lsp::Hover PDLDocument::buildHoverForOpName(const ods::Operation *op,
545 const SMRange &hoverRange) {
546 lsp::Hover hover(lsp::Range(sourceMgr, hoverRange));
547 {
548 llvm::raw_string_ostream hoverOS(hover.contents.value);
549 hoverOS << "**OpName**: `" << op->getName() << "`\n***\n"
550 << op->getSummary() << "\n***\n"
551 << op->getDescription();
552 }
553 return hover;
554 }
555
buildHoverForVariable(const ast::VariableDecl * varDecl,const SMRange & hoverRange)556 lsp::Hover PDLDocument::buildHoverForVariable(const ast::VariableDecl *varDecl,
557 const SMRange &hoverRange) {
558 lsp::Hover hover(lsp::Range(sourceMgr, hoverRange));
559 {
560 llvm::raw_string_ostream hoverOS(hover.contents.value);
561 hoverOS << "**Variable**: `" << varDecl->getName().getName() << "`\n***\n"
562 << "Type: `" << varDecl->getType() << "`\n";
563 }
564 return hover;
565 }
566
buildHoverForPattern(const ast::PatternDecl * decl,const SMRange & hoverRange)567 lsp::Hover PDLDocument::buildHoverForPattern(const ast::PatternDecl *decl,
568 const SMRange &hoverRange) {
569 lsp::Hover hover(lsp::Range(sourceMgr, hoverRange));
570 {
571 llvm::raw_string_ostream hoverOS(hover.contents.value);
572 hoverOS << "**Pattern**";
573 if (const ast::Name *name = decl->getName())
574 hoverOS << ": `" << name->getName() << "`";
575 hoverOS << "\n***\n";
576 if (Optional<uint16_t> benefit = decl->getBenefit())
577 hoverOS << "Benefit: " << *benefit << "\n";
578 if (decl->hasBoundedRewriteRecursion())
579 hoverOS << "HasBoundedRewriteRecursion\n";
580 hoverOS << "RootOp: `"
581 << decl->getRootRewriteStmt()->getRootOpExpr()->getType() << "`\n";
582
583 // Format the documentation for the decl.
584 if (Optional<std::string> doc = getDocumentationFor(sourceMgr, decl))
585 hoverOS << "\n" << *doc << "\n";
586 }
587 return hover;
588 }
589
590 lsp::Hover
buildHoverForCoreConstraint(const ast::CoreConstraintDecl * decl,const SMRange & hoverRange)591 PDLDocument::buildHoverForCoreConstraint(const ast::CoreConstraintDecl *decl,
592 const SMRange &hoverRange) {
593 lsp::Hover hover(lsp::Range(sourceMgr, hoverRange));
594 {
595 llvm::raw_string_ostream hoverOS(hover.contents.value);
596 hoverOS << "**Constraint**: `";
597 TypeSwitch<const ast::Decl *>(decl)
598 .Case([&](const ast::AttrConstraintDecl *) { hoverOS << "Attr"; })
599 .Case([&](const ast::OpConstraintDecl *opCst) {
600 hoverOS << "Op";
601 if (Optional<StringRef> name = opCst->getName())
602 hoverOS << "<" << name << ">";
603 })
604 .Case([&](const ast::TypeConstraintDecl *) { hoverOS << "Type"; })
605 .Case([&](const ast::TypeRangeConstraintDecl *) {
606 hoverOS << "TypeRange";
607 })
608 .Case([&](const ast::ValueConstraintDecl *) { hoverOS << "Value"; })
609 .Case([&](const ast::ValueRangeConstraintDecl *) {
610 hoverOS << "ValueRange";
611 });
612 hoverOS << "`\n";
613 }
614 return hover;
615 }
616
617 template <typename T>
buildHoverForUserConstraintOrRewrite(StringRef typeName,const T * decl,const SMRange & hoverRange)618 lsp::Hover PDLDocument::buildHoverForUserConstraintOrRewrite(
619 StringRef typeName, const T *decl, const SMRange &hoverRange) {
620 lsp::Hover hover(lsp::Range(sourceMgr, hoverRange));
621 {
622 llvm::raw_string_ostream hoverOS(hover.contents.value);
623 hoverOS << "**" << typeName << "**: `" << decl->getName().getName()
624 << "`\n***\n";
625 ArrayRef<ast::VariableDecl *> inputs = decl->getInputs();
626 if (!inputs.empty()) {
627 hoverOS << "Parameters:\n";
628 for (const ast::VariableDecl *input : inputs)
629 hoverOS << "* " << input->getName().getName() << ": `"
630 << input->getType() << "`\n";
631 hoverOS << "***\n";
632 }
633 ast::Type resultType = decl->getResultType();
634 if (auto resultTupleTy = resultType.dyn_cast<ast::TupleType>()) {
635 if (!resultTupleTy.empty()) {
636 hoverOS << "Results:\n";
637 for (auto it : llvm::zip(resultTupleTy.getElementNames(),
638 resultTupleTy.getElementTypes())) {
639 StringRef name = std::get<0>(it);
640 hoverOS << "* " << (name.empty() ? "" : (name + ": ")) << "`"
641 << std::get<1>(it) << "`\n";
642 }
643 hoverOS << "***\n";
644 }
645 } else {
646 hoverOS << "Results:\n* `" << resultType << "`\n";
647 hoverOS << "***\n";
648 }
649
650 // Format the documentation for the decl.
651 if (Optional<std::string> doc = getDocumentationFor(sourceMgr, decl))
652 hoverOS << "\n" << *doc << "\n";
653 }
654 return hover;
655 }
656
657 //===----------------------------------------------------------------------===//
658 // PDLDocument: Document Symbols
659 //===----------------------------------------------------------------------===//
660
findDocumentSymbols(std::vector<lsp::DocumentSymbol> & symbols)661 void PDLDocument::findDocumentSymbols(
662 std::vector<lsp::DocumentSymbol> &symbols) {
663 if (failed(astModule))
664 return;
665
666 for (const ast::Decl *decl : (*astModule)->getChildren()) {
667 if (!isMainFileLoc(sourceMgr, decl->getLoc()))
668 continue;
669
670 if (const auto *patternDecl = dyn_cast<ast::PatternDecl>(decl)) {
671 const ast::Name *name = patternDecl->getName();
672
673 SMRange nameLoc = name ? name->getLoc() : patternDecl->getLoc();
674 SMRange bodyLoc(nameLoc.Start, patternDecl->getBody()->getLoc().End);
675
676 symbols.emplace_back(
677 name ? name->getName() : "<pattern>", lsp::SymbolKind::Class,
678 lsp::Range(sourceMgr, bodyLoc), lsp::Range(sourceMgr, nameLoc));
679 } else if (const auto *cDecl = dyn_cast<ast::UserConstraintDecl>(decl)) {
680 // TODO: Add source information for the code block body.
681 SMRange nameLoc = cDecl->getName().getLoc();
682 SMRange bodyLoc = nameLoc;
683
684 symbols.emplace_back(
685 cDecl->getName().getName(), lsp::SymbolKind::Function,
686 lsp::Range(sourceMgr, bodyLoc), lsp::Range(sourceMgr, nameLoc));
687 } else if (const auto *cDecl = dyn_cast<ast::UserRewriteDecl>(decl)) {
688 // TODO: Add source information for the code block body.
689 SMRange nameLoc = cDecl->getName().getLoc();
690 SMRange bodyLoc = nameLoc;
691
692 symbols.emplace_back(
693 cDecl->getName().getName(), lsp::SymbolKind::Function,
694 lsp::Range(sourceMgr, bodyLoc), lsp::Range(sourceMgr, nameLoc));
695 }
696 }
697 }
698
699 //===----------------------------------------------------------------------===//
700 // PDLDocument: Code Completion
701 //===----------------------------------------------------------------------===//
702
703 namespace {
704 class LSPCodeCompleteContext : public CodeCompleteContext {
705 public:
/// Create a completion context for the location `completeLoc`. Completion
/// results are appended to `completionList`; `sourceMgr`, `odsContext` and
/// `includeDirs` are stored for use by the completion callbacks.
LSPCodeCompleteContext(SMLoc completeLoc, llvm::SourceMgr &sourceMgr,
                       lsp::CompletionList &completionList,
                       ods::Context &odsContext,
                       ArrayRef<std::string> includeDirs)
    : CodeCompleteContext(completeLoc),
      sourceMgr(sourceMgr),
      completionList(completionList),
      odsContext(odsContext),
      includeDirs(includeDirs) {}
713
codeCompleteTupleMemberAccess(ast::TupleType tupleType)714 void codeCompleteTupleMemberAccess(ast::TupleType tupleType) final {
715 ArrayRef<ast::Type> elementTypes = tupleType.getElementTypes();
716 ArrayRef<StringRef> elementNames = tupleType.getElementNames();
717 for (unsigned i = 0, e = tupleType.size(); i < e; ++i) {
718 // Push back a completion item that uses the result index.
719 lsp::CompletionItem item;
720 item.label = llvm::formatv("{0} (field #{0})", i).str();
721 item.insertText = Twine(i).str();
722 item.filterText = item.sortText = item.insertText;
723 item.kind = lsp::CompletionItemKind::Field;
724 item.detail = llvm::formatv("{0}: {1}", i, elementTypes[i]);
725 item.insertTextFormat = lsp::InsertTextFormat::PlainText;
726 completionList.items.emplace_back(item);
727
728 // If the element has a name, push back a completion item with that name.
729 if (!elementNames[i].empty()) {
730 item.label =
731 llvm::formatv("{1} (field #{0})", i, elementNames[i]).str();
732 item.filterText = item.label;
733 item.insertText = elementNames[i].str();
734 completionList.items.emplace_back(item);
735 }
736 }
737 }
738
codeCompleteOperationMemberAccess(ast::OperationType opType)739 void codeCompleteOperationMemberAccess(ast::OperationType opType) final {
740 const ods::Operation *odsOp = opType.getODSOperation();
741 if (!odsOp)
742 return;
743
744 ArrayRef<ods::OperandOrResult> results = odsOp->getResults();
745 for (const auto &it : llvm::enumerate(results)) {
746 const ods::OperandOrResult &result = it.value();
747 const ods::TypeConstraint &constraint = result.getConstraint();
748
749 // Push back a completion item that uses the result index.
750 lsp::CompletionItem item;
751 item.label = llvm::formatv("{0} (field #{0})", it.index()).str();
752 item.insertText = Twine(it.index()).str();
753 item.filterText = item.sortText = item.insertText;
754 item.kind = lsp::CompletionItemKind::Field;
755 switch (result.getVariableLengthKind()) {
756 case ods::VariableLengthKind::Single:
757 item.detail = llvm::formatv("{0}: Value", it.index()).str();
758 break;
759 case ods::VariableLengthKind::Optional:
760 item.detail = llvm::formatv("{0}: Value?", it.index()).str();
761 break;
762 case ods::VariableLengthKind::Variadic:
763 item.detail = llvm::formatv("{0}: ValueRange", it.index()).str();
764 break;
765 }
766 item.documentation = lsp::MarkupContent{
767 lsp::MarkupKind::Markdown,
768 llvm::formatv("{0}\n\n```c++\n{1}\n```\n", constraint.getSummary(),
769 constraint.getCppClass())
770 .str()};
771 item.insertTextFormat = lsp::InsertTextFormat::PlainText;
772 completionList.items.emplace_back(item);
773
774 // If the result has a name, push back a completion item with the result
775 // name.
776 if (!result.getName().empty()) {
777 item.label =
778 llvm::formatv("{1} (field #{0})", it.index(), result.getName())
779 .str();
780 item.filterText = item.label;
781 item.insertText = result.getName().str();
782 completionList.items.emplace_back(item);
783 }
784 }
785 }
786
codeCompleteOperationAttributeName(StringRef opName)787 void codeCompleteOperationAttributeName(StringRef opName) final {
788 const ods::Operation *odsOp = odsContext.lookupOperation(opName);
789 if (!odsOp)
790 return;
791
792 for (const ods::Attribute &attr : odsOp->getAttributes()) {
793 const ods::AttributeConstraint &constraint = attr.getConstraint();
794
795 lsp::CompletionItem item;
796 item.label = attr.getName().str();
797 item.kind = lsp::CompletionItemKind::Field;
798 item.detail = attr.isOptional() ? "optional" : "";
799 item.documentation = lsp::MarkupContent{
800 lsp::MarkupKind::Markdown,
801 llvm::formatv("{0}\n\n```c++\n{1}\n```\n", constraint.getSummary(),
802 constraint.getCppClass())
803 .str()};
804 item.insertTextFormat = lsp::InsertTextFormat::PlainText;
805 completionList.items.emplace_back(item);
806 }
807 }
808
codeCompleteConstraintName(ast::Type currentType,bool allowNonCoreConstraints,bool allowInlineTypeConstraints,const ast::DeclScope * scope)809 void codeCompleteConstraintName(ast::Type currentType,
810 bool allowNonCoreConstraints,
811 bool allowInlineTypeConstraints,
812 const ast::DeclScope *scope) final {
813 auto addCoreConstraint = [&](StringRef constraint, StringRef mlirType,
814 StringRef snippetText = "") {
815 lsp::CompletionItem item;
816 item.label = constraint.str();
817 item.kind = lsp::CompletionItemKind::Class;
818 item.detail = (constraint + " constraint").str();
819 item.documentation = lsp::MarkupContent{
820 lsp::MarkupKind::Markdown,
821 ("A single entity core constraint of type `" + mlirType + "`").str()};
822 item.sortText = "0";
823 item.insertText = snippetText.str();
824 item.insertTextFormat = snippetText.empty()
825 ? lsp::InsertTextFormat::PlainText
826 : lsp::InsertTextFormat::Snippet;
827 completionList.items.emplace_back(item);
828 };
829
830 // Insert completions for the core constraints. Some core constraints have
831 // additional characteristics, so we may add then even if a type has been
832 // inferred.
833 if (!currentType) {
834 addCoreConstraint("Attr", "mlir::Attribute");
835 addCoreConstraint("Op", "mlir::Operation *");
836 addCoreConstraint("Value", "mlir::Value");
837 addCoreConstraint("ValueRange", "mlir::ValueRange");
838 addCoreConstraint("Type", "mlir::Type");
839 addCoreConstraint("TypeRange", "mlir::TypeRange");
840 }
841 if (allowInlineTypeConstraints) {
842 /// Attr<Type>.
843 if (!currentType || currentType.isa<ast::AttributeType>())
844 addCoreConstraint("Attr<type>", "mlir::Attribute", "Attr<$1>");
845 /// Value<Type>.
846 if (!currentType || currentType.isa<ast::ValueType>())
847 addCoreConstraint("Value<type>", "mlir::Value", "Value<$1>");
848 /// ValueRange<TypeRange>.
849 if (!currentType || currentType.isa<ast::ValueRangeType>())
850 addCoreConstraint("ValueRange<type>", "mlir::ValueRange",
851 "ValueRange<$1>");
852 }
853
854 // If a scope was provided, check it for potential constraints.
855 while (scope) {
856 for (const ast::Decl *decl : scope->getDecls()) {
857 if (const auto *cst = dyn_cast<ast::UserConstraintDecl>(decl)) {
858 if (!allowNonCoreConstraints)
859 continue;
860
861 lsp::CompletionItem item;
862 item.label = cst->getName().getName().str();
863 item.kind = lsp::CompletionItemKind::Interface;
864 item.sortText = "2_" + item.label;
865
866 // Skip constraints that are not single-arg. We currently only
867 // complete variable constraints.
868 if (cst->getInputs().size() != 1)
869 continue;
870
871 // Ensure the input type matched the given type.
872 ast::Type constraintType = cst->getInputs()[0]->getType();
873 if (currentType && !currentType.refineWith(constraintType))
874 continue;
875
876 // Format the constraint signature.
877 {
878 llvm::raw_string_ostream strOS(item.detail);
879 strOS << "(";
880 llvm::interleaveComma(
881 cst->getInputs(), strOS, [&](const ast::VariableDecl *var) {
882 strOS << var->getName().getName() << ": " << var->getType();
883 });
884 strOS << ") -> " << cst->getResultType();
885 }
886
887 // Format the documentation for the constraint.
888 if (Optional<std::string> doc = getDocumentationFor(sourceMgr, cst)) {
889 item.documentation =
890 lsp::MarkupContent{lsp::MarkupKind::Markdown, std::move(*doc)};
891 }
892
893 completionList.items.emplace_back(item);
894 }
895 }
896
897 scope = scope->getParentScope();
898 }
899 }
900
codeCompleteDialectName()901 void codeCompleteDialectName() final {
902 // Code complete known dialects.
903 for (const ods::Dialect &dialect : odsContext.getDialects()) {
904 lsp::CompletionItem item;
905 item.label = dialect.getName().str();
906 item.kind = lsp::CompletionItemKind::Class;
907 item.insertTextFormat = lsp::InsertTextFormat::PlainText;
908 completionList.items.emplace_back(item);
909 }
910 }
911
codeCompleteOperationName(StringRef dialectName)912 void codeCompleteOperationName(StringRef dialectName) final {
913 const ods::Dialect *dialect = odsContext.lookupDialect(dialectName);
914 if (!dialect)
915 return;
916
917 for (const auto &it : dialect->getOperations()) {
918 const ods::Operation &op = *it.second;
919
920 lsp::CompletionItem item;
921 item.label = op.getName().drop_front(dialectName.size() + 1).str();
922 item.kind = lsp::CompletionItemKind::Field;
923 item.insertTextFormat = lsp::InsertTextFormat::PlainText;
924 completionList.items.emplace_back(item);
925 }
926 }
927
codeCompletePatternMetadata()928 void codeCompletePatternMetadata() final {
929 auto addSimpleConstraint = [&](StringRef constraint, StringRef desc,
930 StringRef snippetText = "") {
931 lsp::CompletionItem item;
932 item.label = constraint.str();
933 item.kind = lsp::CompletionItemKind::Class;
934 item.detail = "pattern metadata";
935 item.documentation =
936 lsp::MarkupContent{lsp::MarkupKind::Markdown, desc.str()};
937 item.insertText = snippetText.str();
938 item.insertTextFormat = snippetText.empty()
939 ? lsp::InsertTextFormat::PlainText
940 : lsp::InsertTextFormat::Snippet;
941 completionList.items.emplace_back(item);
942 };
943
944 addSimpleConstraint("benefit", "The `benefit` of matching the pattern.",
945 "benefit($1)");
946 addSimpleConstraint("recursion",
947 "The pattern properly handles recursive application.");
948 }
949
codeCompleteIncludeFilename(StringRef curPath)950 void codeCompleteIncludeFilename(StringRef curPath) final {
951 // Normalize the path to allow for interacting with the file system
952 // utilities.
953 SmallString<128> nativeRelDir(llvm::sys::path::convert_to_slash(curPath));
954 llvm::sys::path::native(nativeRelDir);
955
956 // Set of already included completion paths.
957 StringSet<> seenResults;
958
959 // Functor used to add a single include completion item.
960 auto addIncludeCompletion = [&](StringRef path, bool isDirectory) {
961 lsp::CompletionItem item;
962 item.label = path.str();
963 item.kind = isDirectory ? lsp::CompletionItemKind::Folder
964 : lsp::CompletionItemKind::File;
965 if (seenResults.insert(item.label).second)
966 completionList.items.emplace_back(item);
967 };
968
969 // Process the include directories for this file, adding any potential
970 // nested include files or directories.
971 for (StringRef includeDir : includeDirs) {
972 llvm::SmallString<128> dir = includeDir;
973 if (!nativeRelDir.empty())
974 llvm::sys::path::append(dir, nativeRelDir);
975
976 std::error_code errorCode;
977 for (auto it = llvm::sys::fs::directory_iterator(dir, errorCode),
978 e = llvm::sys::fs::directory_iterator();
979 !errorCode && it != e; it.increment(errorCode)) {
980 StringRef filename = llvm::sys::path::filename(it->path());
981
982 // To know whether a symlink should be treated as file or a directory,
983 // we have to stat it. This should be cheap enough as there shouldn't be
984 // many symlinks.
985 llvm::sys::fs::file_type fileType = it->type();
986 if (fileType == llvm::sys::fs::file_type::symlink_file) {
987 if (auto fileStatus = it->status())
988 fileType = fileStatus->type();
989 }
990
991 switch (fileType) {
992 case llvm::sys::fs::file_type::directory_file:
993 addIncludeCompletion(filename, /*isDirectory=*/true);
994 break;
995 case llvm::sys::fs::file_type::regular_file: {
996 // Only consider concrete files that can actually be included by PDLL.
997 if (filename.endswith(".pdll") || filename.endswith(".td"))
998 addIncludeCompletion(filename, /*isDirectory=*/false);
999 break;
1000 }
1001 default:
1002 break;
1003 }
1004 }
1005 }
1006
1007 // Sort the completion results to make sure the output is deterministic in
1008 // the face of different iteration schemes for different platforms.
1009 llvm::sort(completionList.items, [](const lsp::CompletionItem &lhs,
1010 const lsp::CompletionItem &rhs) {
1011 return lhs.label < rhs.label;
1012 });
1013 }
1014
1015 private:
1016 llvm::SourceMgr &sourceMgr;
1017 lsp::CompletionList &completionList;
1018 ods::Context &odsContext;
1019 ArrayRef<std::string> includeDirs;
1020 };
1021 } // namespace
1022
1023 lsp::CompletionList
getCodeCompletion(const lsp::URIForFile & uri,const lsp::Position & completePos)1024 PDLDocument::getCodeCompletion(const lsp::URIForFile &uri,
1025 const lsp::Position &completePos) {
1026 SMLoc posLoc = completePos.getAsSMLoc(sourceMgr);
1027 if (!posLoc.isValid())
1028 return lsp::CompletionList();
1029
1030 // To perform code completion, we run another parse of the module with the
1031 // code completion context provided.
1032 ods::Context tmpODSContext;
1033 lsp::CompletionList completionList;
1034 LSPCodeCompleteContext lspCompleteContext(posLoc, sourceMgr, completionList,
1035 tmpODSContext,
1036 sourceMgr.getIncludeDirs());
1037
1038 ast::Context tmpContext(tmpODSContext);
1039 (void)parsePDLLAST(tmpContext, sourceMgr, /*enableDocumentation=*/true,
1040 &lspCompleteContext);
1041
1042 return completionList;
1043 }
1044
1045 //===----------------------------------------------------------------------===//
1046 // PDLDocument: Signature Help
1047 //===----------------------------------------------------------------------===//
1048
1049 namespace {
1050 class LSPSignatureHelpContext : public CodeCompleteContext {
1051 public:
LSPSignatureHelpContext(SMLoc completeLoc,llvm::SourceMgr & sourceMgr,lsp::SignatureHelp & signatureHelp,ods::Context & odsContext)1052 LSPSignatureHelpContext(SMLoc completeLoc, llvm::SourceMgr &sourceMgr,
1053 lsp::SignatureHelp &signatureHelp,
1054 ods::Context &odsContext)
1055 : CodeCompleteContext(completeLoc), sourceMgr(sourceMgr),
1056 signatureHelp(signatureHelp), odsContext(odsContext) {}
1057
codeCompleteCallSignature(const ast::CallableDecl * callable,unsigned currentNumArgs)1058 void codeCompleteCallSignature(const ast::CallableDecl *callable,
1059 unsigned currentNumArgs) final {
1060 signatureHelp.activeParameter = currentNumArgs;
1061
1062 lsp::SignatureInformation signatureInfo;
1063 {
1064 llvm::raw_string_ostream strOS(signatureInfo.label);
1065 strOS << callable->getName()->getName() << "(";
1066 auto formatParamFn = [&](const ast::VariableDecl *var) {
1067 unsigned paramStart = strOS.str().size();
1068 strOS << var->getName().getName() << ": " << var->getType();
1069 unsigned paramEnd = strOS.str().size();
1070 signatureInfo.parameters.emplace_back(lsp::ParameterInformation{
1071 StringRef(strOS.str()).slice(paramStart, paramEnd).str(),
1072 std::make_pair(paramStart, paramEnd), /*paramDoc*/ std::string()});
1073 };
1074 llvm::interleaveComma(callable->getInputs(), strOS, formatParamFn);
1075 strOS << ") -> " << callable->getResultType();
1076 }
1077
1078 // Format the documentation for the callable.
1079 if (Optional<std::string> doc = getDocumentationFor(sourceMgr, callable))
1080 signatureInfo.documentation = std::move(*doc);
1081
1082 signatureHelp.signatures.emplace_back(std::move(signatureInfo));
1083 }
1084
1085 void
codeCompleteOperationOperandsSignature(Optional<StringRef> opName,unsigned currentNumOperands)1086 codeCompleteOperationOperandsSignature(Optional<StringRef> opName,
1087 unsigned currentNumOperands) final {
1088 const ods::Operation *odsOp =
1089 opName ? odsContext.lookupOperation(*opName) : nullptr;
1090 codeCompleteOperationOperandOrResultSignature(
1091 opName, odsOp, odsOp ? odsOp->getOperands() : llvm::None,
1092 currentNumOperands, "operand", "Value");
1093 }
1094
codeCompleteOperationResultsSignature(Optional<StringRef> opName,unsigned currentNumResults)1095 void codeCompleteOperationResultsSignature(Optional<StringRef> opName,
1096 unsigned currentNumResults) final {
1097 const ods::Operation *odsOp =
1098 opName ? odsContext.lookupOperation(*opName) : nullptr;
1099 codeCompleteOperationOperandOrResultSignature(
1100 opName, odsOp, odsOp ? odsOp->getResults() : llvm::None,
1101 currentNumResults, "result", "Type");
1102 }
1103
codeCompleteOperationOperandOrResultSignature(Optional<StringRef> opName,const ods::Operation * odsOp,ArrayRef<ods::OperandOrResult> values,unsigned currentValue,StringRef label,StringRef dataType)1104 void codeCompleteOperationOperandOrResultSignature(
1105 Optional<StringRef> opName, const ods::Operation *odsOp,
1106 ArrayRef<ods::OperandOrResult> values, unsigned currentValue,
1107 StringRef label, StringRef dataType) {
1108 signatureHelp.activeParameter = currentValue;
1109
1110 // If we have ODS information for the operation, add in the ODS signature
1111 // for the operation. We also verify that the current number of values is
1112 // not more than what is defined in ODS, as this will result in an error
1113 // anyways.
1114 if (odsOp && currentValue < values.size()) {
1115 lsp::SignatureInformation signatureInfo;
1116
1117 // Build the signature label.
1118 {
1119 llvm::raw_string_ostream strOS(signatureInfo.label);
1120 strOS << "(";
1121 auto formatFn = [&](const ods::OperandOrResult &value) {
1122 unsigned paramStart = strOS.str().size();
1123
1124 strOS << value.getName() << ": ";
1125
1126 StringRef constraintDoc = value.getConstraint().getSummary();
1127 std::string paramDoc;
1128 switch (value.getVariableLengthKind()) {
1129 case ods::VariableLengthKind::Single:
1130 strOS << dataType;
1131 paramDoc = constraintDoc.str();
1132 break;
1133 case ods::VariableLengthKind::Optional:
1134 strOS << dataType << "?";
1135 paramDoc = ("optional: " + constraintDoc).str();
1136 break;
1137 case ods::VariableLengthKind::Variadic:
1138 strOS << dataType << "Range";
1139 paramDoc = ("variadic: " + constraintDoc).str();
1140 break;
1141 }
1142
1143 unsigned paramEnd = strOS.str().size();
1144 signatureInfo.parameters.emplace_back(lsp::ParameterInformation{
1145 StringRef(strOS.str()).slice(paramStart, paramEnd).str(),
1146 std::make_pair(paramStart, paramEnd), paramDoc});
1147 };
1148 llvm::interleaveComma(values, strOS, formatFn);
1149 strOS << ")";
1150 }
1151 signatureInfo.documentation =
1152 llvm::formatv("`op<{0}>` ODS {1} specification", *opName, label)
1153 .str();
1154 signatureHelp.signatures.emplace_back(std::move(signatureInfo));
1155 }
1156
1157 // If there aren't any arguments yet, we also add the generic signature.
1158 if (currentValue == 0 && (!odsOp || !values.empty())) {
1159 lsp::SignatureInformation signatureInfo;
1160 signatureInfo.label =
1161 llvm::formatv("(<{0}s>: {1}Range)", label, dataType).str();
1162 signatureInfo.documentation =
1163 ("Generic operation " + label + " specification").str();
1164 signatureInfo.parameters.emplace_back(lsp::ParameterInformation{
1165 StringRef(signatureInfo.label).drop_front().drop_back().str(),
1166 std::pair<unsigned, unsigned>(1, signatureInfo.label.size() - 1),
1167 ("All of the " + label + "s of the operation.").str()});
1168 signatureHelp.signatures.emplace_back(std::move(signatureInfo));
1169 }
1170 }
1171
1172 private:
1173 llvm::SourceMgr &sourceMgr;
1174 lsp::SignatureHelp &signatureHelp;
1175 ods::Context &odsContext;
1176 };
1177 } // namespace
1178
lsp::SignatureHelp PDLDocument::getSignatureHelp(const lsp::URIForFile &uri,
                                                 const lsp::Position &helpPos) {
  // An invalid position yields empty signature help.
  const SMLoc helpLoc = helpPos.getAsSMLoc(sourceMgr);
  if (!helpLoc.isValid())
    return lsp::SignatureHelp();

  // Signature help is computed by parsing the module again, this time with a
  // signature help context attached. The parse uses throwaway ODS and AST
  // contexts.
  ods::Context scratchODSContext;
  lsp::SignatureHelp signatureHelp;
  LSPSignatureHelpContext helpContext(helpLoc, sourceMgr, signatureHelp,
                                      scratchODSContext);

  // The parse result itself is not needed; the context fills in
  // `signatureHelp` as a side effect of parsing.
  ast::Context scratchASTContext(scratchODSContext);
  (void)parsePDLLAST(scratchASTContext, sourceMgr,
                     /*enableDocumentation=*/true, &helpContext);

  return signatureHelp;
}
1198
1199 //===----------------------------------------------------------------------===//
1200 // PDLDocument: Inlay Hints
1201 //===----------------------------------------------------------------------===//
1202
1203 /// Returns true if the given name should be added as a hint for `expr`.
shouldAddHintFor(const ast::Expr * expr,StringRef name)1204 static bool shouldAddHintFor(const ast::Expr *expr, StringRef name) {
1205 if (name.empty())
1206 return false;
1207
1208 // If the argument is a reference of the same name, don't add it as a hint.
1209 if (auto *ref = dyn_cast<ast::DeclRefExpr>(expr)) {
1210 const ast::Name *declName = ref->getDecl()->getName();
1211 if (declName && declName->getName() == name)
1212 return false;
1213 }
1214
1215 return true;
1216 }
1217
getInlayHints(const lsp::URIForFile & uri,const lsp::Range & range,std::vector<lsp::InlayHint> & inlayHints)1218 void PDLDocument::getInlayHints(const lsp::URIForFile &uri,
1219 const lsp::Range &range,
1220 std::vector<lsp::InlayHint> &inlayHints) {
1221 if (failed(astModule))
1222 return;
1223 SMRange rangeLoc = range.getAsSMRange(sourceMgr);
1224 if (!rangeLoc.isValid())
1225 return;
1226 (*astModule)->walk([&](const ast::Node *node) {
1227 SMRange loc = node->getLoc();
1228
1229 // Check that the location of this node is within the input range.
1230 if (!contains(rangeLoc, loc.Start) && !contains(rangeLoc, loc.End))
1231 return;
1232
1233 // Handle hints for various types of nodes.
1234 llvm::TypeSwitch<const ast::Node *>(node)
1235 .Case<ast::VariableDecl, ast::CallExpr, ast::OperationExpr>(
1236 [&](const auto *node) {
1237 this->getInlayHintsFor(node, uri, inlayHints);
1238 });
1239 });
1240 }
1241
getInlayHintsFor(const ast::VariableDecl * decl,const lsp::URIForFile & uri,std::vector<lsp::InlayHint> & inlayHints)1242 void PDLDocument::getInlayHintsFor(const ast::VariableDecl *decl,
1243 const lsp::URIForFile &uri,
1244 std::vector<lsp::InlayHint> &inlayHints) {
1245 // Check to see if the variable has a constraint list, if it does we don't
1246 // provide initializer hints.
1247 if (!decl->getConstraints().empty())
1248 return;
1249
1250 // Check to see if the variable has an initializer.
1251 if (const ast::Expr *expr = decl->getInitExpr()) {
1252 // Don't add hints for operation expression initialized variables given that
1253 // the type of the variable is easily inferred by the expression operation
1254 // name.
1255 if (isa<ast::OperationExpr>(expr))
1256 return;
1257 }
1258
1259 lsp::InlayHint hint(lsp::InlayHintKind::Type,
1260 lsp::Position(sourceMgr, decl->getLoc().End));
1261 {
1262 llvm::raw_string_ostream labelOS(hint.label);
1263 labelOS << ": " << decl->getType();
1264 }
1265
1266 inlayHints.emplace_back(std::move(hint));
1267 }
1268
getInlayHintsFor(const ast::CallExpr * expr,const lsp::URIForFile & uri,std::vector<lsp::InlayHint> & inlayHints)1269 void PDLDocument::getInlayHintsFor(const ast::CallExpr *expr,
1270 const lsp::URIForFile &uri,
1271 std::vector<lsp::InlayHint> &inlayHints) {
1272 // Try to extract the callable of this call.
1273 const auto *callableRef = dyn_cast<ast::DeclRefExpr>(expr->getCallableExpr());
1274 const auto *callable =
1275 callableRef ? dyn_cast<ast::CallableDecl>(callableRef->getDecl())
1276 : nullptr;
1277 if (!callable)
1278 return;
1279
1280 // Add hints for the arguments to the call.
1281 for (const auto &it : llvm::zip(expr->getArguments(), callable->getInputs()))
1282 addParameterHintFor(inlayHints, std::get<0>(it),
1283 std::get<1>(it)->getName().getName());
1284 }
1285
getInlayHintsFor(const ast::OperationExpr * expr,const lsp::URIForFile & uri,std::vector<lsp::InlayHint> & inlayHints)1286 void PDLDocument::getInlayHintsFor(const ast::OperationExpr *expr,
1287 const lsp::URIForFile &uri,
1288 std::vector<lsp::InlayHint> &inlayHints) {
1289 // Check for ODS information.
1290 ast::OperationType opType = expr->getType().dyn_cast<ast::OperationType>();
1291 const auto *odsOp = opType ? opType.getODSOperation() : nullptr;
1292
1293 auto addOpHint = [&](const ast::Expr *valueExpr, StringRef label) {
1294 // If the value expression used the same location as the operation, don't
1295 // add a hint. This expression was materialized during parsing.
1296 if (expr->getLoc().Start == valueExpr->getLoc().Start)
1297 return;
1298 addParameterHintFor(inlayHints, valueExpr, label);
1299 };
1300
1301 // Functor used to process hints for the operands and results of the
1302 // operation. They effectively have the same format, and thus can be processed
1303 // using the same logic.
1304 auto addOperandOrResultHints = [&](ArrayRef<ast::Expr *> values,
1305 ArrayRef<ods::OperandOrResult> odsValues,
1306 StringRef allValuesName) {
1307 if (values.empty())
1308 return;
1309
1310 // The values should either map to a single range, or be equivalent to the
1311 // ODS values.
1312 if (values.size() != odsValues.size()) {
1313 // Handle the case of a single element that covers the full range.
1314 if (values.size() == 1)
1315 return addOpHint(values.front(), allValuesName);
1316 return;
1317 }
1318
1319 for (const auto &it : llvm::zip(values, odsValues))
1320 addOpHint(std::get<0>(it), std::get<1>(it).getName());
1321 };
1322
1323 // Add hints for the operands and results of the operation.
1324 addOperandOrResultHints(expr->getOperands(),
1325 odsOp ? odsOp->getOperands()
1326 : ArrayRef<ods::OperandOrResult>(),
1327 "operands");
1328 addOperandOrResultHints(expr->getResultTypes(),
1329 odsOp ? odsOp->getResults()
1330 : ArrayRef<ods::OperandOrResult>(),
1331 "results");
1332 }
1333
addParameterHintFor(std::vector<lsp::InlayHint> & inlayHints,const ast::Expr * expr,StringRef label)1334 void PDLDocument::addParameterHintFor(std::vector<lsp::InlayHint> &inlayHints,
1335 const ast::Expr *expr, StringRef label) {
1336 if (!shouldAddHintFor(expr, label))
1337 return;
1338
1339 lsp::InlayHint hint(lsp::InlayHintKind::Parameter,
1340 lsp::Position(sourceMgr, expr->getLoc().Start));
1341 hint.label = (label + ":").str();
1342 hint.paddingRight = true;
1343 inlayHints.emplace_back(std::move(hint));
1344 }
1345
1346 //===----------------------------------------------------------------------===//
1347 // PDLL ViewOutput
1348 //===----------------------------------------------------------------------===//
1349
getPDLLViewOutput(raw_ostream & os,lsp::PDLLViewOutputKind kind)1350 void PDLDocument::getPDLLViewOutput(raw_ostream &os,
1351 lsp::PDLLViewOutputKind kind) {
1352 if (failed(astModule))
1353 return;
1354 if (kind == lsp::PDLLViewOutputKind::AST) {
1355 (*astModule)->print(os);
1356 return;
1357 }
1358
1359 // Generate the MLIR for the ast module. We also capture diagnostics here to
1360 // show to the user, which may be useful if PDLL isn't capturing constraints
1361 // expected by PDL.
1362 MLIRContext mlirContext;
1363 SourceMgrDiagnosticHandler diagHandler(sourceMgr, &mlirContext, os);
1364 OwningOpRef<ModuleOp> pdlModule =
1365 codegenPDLLToMLIR(&mlirContext, astContext, sourceMgr, **astModule);
1366 if (!pdlModule)
1367 return;
1368 if (kind == lsp::PDLLViewOutputKind::MLIR) {
1369 pdlModule->print(os, OpPrintingFlags().enableDebugInfo());
1370 return;
1371 }
1372
1373 // Otherwise, generate the output for C++.
1374 assert(kind == lsp::PDLLViewOutputKind::CPP &&
1375 "unexpected PDLLViewOutputKind");
1376 codegenPDLLToCPP(**astModule, *pdlModule, os);
1377 }
1378
1379 //===----------------------------------------------------------------------===//
1380 // PDLTextFileChunk
1381 //===----------------------------------------------------------------------===//
1382
1383 namespace {
1384 /// This class represents a single chunk of an PDL text file.
1385 struct PDLTextFileChunk {
PDLTextFileChunk__anon9e2fdad81d11::PDLTextFileChunk1386 PDLTextFileChunk(uint64_t lineOffset, const lsp::URIForFile &uri,
1387 StringRef contents,
1388 const std::vector<std::string> &extraDirs,
1389 std::vector<lsp::Diagnostic> &diagnostics)
1390 : lineOffset(lineOffset),
1391 document(uri, contents, extraDirs, diagnostics) {}
1392
1393 /// Adjust the line number of the given range to anchor at the beginning of
1394 /// the file, instead of the beginning of this chunk.
adjustLocForChunkOffset__anon9e2fdad81d11::PDLTextFileChunk1395 void adjustLocForChunkOffset(lsp::Range &range) {
1396 adjustLocForChunkOffset(range.start);
1397 adjustLocForChunkOffset(range.end);
1398 }
1399 /// Adjust the line number of the given position to anchor at the beginning of
1400 /// the file, instead of the beginning of this chunk.
adjustLocForChunkOffset__anon9e2fdad81d11::PDLTextFileChunk1401 void adjustLocForChunkOffset(lsp::Position &pos) { pos.line += lineOffset; }
1402
1403 /// The line offset of this chunk from the beginning of the file.
1404 uint64_t lineOffset;
1405 /// The document referred to by this chunk.
1406 PDLDocument document;
1407 };
1408 } // namespace
1409
1410 //===----------------------------------------------------------------------===//
1411 // PDLTextFile
1412 //===----------------------------------------------------------------------===//
1413
1414 namespace {
1415 /// This class represents a text file containing one or more PDL documents.
1416 class PDLTextFile {
1417 public:
1418 PDLTextFile(const lsp::URIForFile &uri, StringRef fileContents,
1419 int64_t version, const std::vector<std::string> &extraDirs,
1420 std::vector<lsp::Diagnostic> &diagnostics);
1421
1422 /// Return the current version of this text file.
getVersion() const1423 int64_t getVersion() const { return version; }
1424
1425 /// Update the file to the new version using the provided set of content
1426 /// changes. Returns failure if the update was unsuccessful.
1427 LogicalResult update(const lsp::URIForFile &uri, int64_t newVersion,
1428 ArrayRef<lsp::TextDocumentContentChangeEvent> changes,
1429 std::vector<lsp::Diagnostic> &diagnostics);
1430
1431 //===--------------------------------------------------------------------===//
1432 // LSP Queries
1433 //===--------------------------------------------------------------------===//
1434
1435 void getLocationsOf(const lsp::URIForFile &uri, lsp::Position defPos,
1436 std::vector<lsp::Location> &locations);
1437 void findReferencesOf(const lsp::URIForFile &uri, lsp::Position pos,
1438 std::vector<lsp::Location> &references);
1439 void getDocumentLinks(const lsp::URIForFile &uri,
1440 std::vector<lsp::DocumentLink> &links);
1441 Optional<lsp::Hover> findHover(const lsp::URIForFile &uri,
1442 lsp::Position hoverPos);
1443 void findDocumentSymbols(std::vector<lsp::DocumentSymbol> &symbols);
1444 lsp::CompletionList getCodeCompletion(const lsp::URIForFile &uri,
1445 lsp::Position completePos);
1446 lsp::SignatureHelp getSignatureHelp(const lsp::URIForFile &uri,
1447 lsp::Position helpPos);
1448 void getInlayHints(const lsp::URIForFile &uri, lsp::Range range,
1449 std::vector<lsp::InlayHint> &inlayHints);
1450 lsp::PDLLViewOutputResult getPDLLViewOutput(lsp::PDLLViewOutputKind kind);
1451
1452 private:
1453 using ChunkIterator = llvm::pointee_iterator<
1454 std::vector<std::unique_ptr<PDLTextFileChunk>>::iterator>;
1455
1456 /// Initialize the text file from the given file contents.
1457 void initialize(const lsp::URIForFile &uri, int64_t newVersion,
1458 std::vector<lsp::Diagnostic> &diagnostics);
1459
1460 /// Find the PDL document that contains the given position, and update the
1461 /// position to be anchored at the start of the found chunk instead of the
1462 /// beginning of the file.
1463 ChunkIterator getChunkItFor(lsp::Position &pos);
getChunkFor(lsp::Position & pos)1464 PDLTextFileChunk &getChunkFor(lsp::Position &pos) {
1465 return *getChunkItFor(pos);
1466 }
1467
1468 /// The full string contents of the file.
1469 std::string contents;
1470
1471 /// The version of this file.
1472 int64_t version = 0;
1473
1474 /// The number of lines in the file.
1475 int64_t totalNumLines = 0;
1476
1477 /// The chunks of this file. The order of these chunks is the order in which
1478 /// they appear in the text file.
1479 std::vector<std::unique_ptr<PDLTextFileChunk>> chunks;
1480
1481 /// The extra set of include directories for this file.
1482 std::vector<std::string> extraIncludeDirs;
1483 };
1484 } // namespace
1485
PDLTextFile(const lsp::URIForFile & uri,StringRef fileContents,int64_t version,const std::vector<std::string> & extraDirs,std::vector<lsp::Diagnostic> & diagnostics)1486 PDLTextFile::PDLTextFile(const lsp::URIForFile &uri, StringRef fileContents,
1487 int64_t version,
1488 const std::vector<std::string> &extraDirs,
1489 std::vector<lsp::Diagnostic> &diagnostics)
1490 : contents(fileContents.str()), extraIncludeDirs(extraDirs) {
1491 initialize(uri, version, diagnostics);
1492 }
1493
1494 LogicalResult
update(const lsp::URIForFile & uri,int64_t newVersion,ArrayRef<lsp::TextDocumentContentChangeEvent> changes,std::vector<lsp::Diagnostic> & diagnostics)1495 PDLTextFile::update(const lsp::URIForFile &uri, int64_t newVersion,
1496 ArrayRef<lsp::TextDocumentContentChangeEvent> changes,
1497 std::vector<lsp::Diagnostic> &diagnostics) {
1498 if (failed(lsp::TextDocumentContentChangeEvent::applyTo(changes, contents))) {
1499 lsp::Logger::error("Failed to update contents of {0}", uri.file());
1500 return failure();
1501 }
1502
1503 // If the file contents were properly changed, reinitialize the text file.
1504 initialize(uri, newVersion, diagnostics);
1505 return success();
1506 }
1507
getLocationsOf(const lsp::URIForFile & uri,lsp::Position defPos,std::vector<lsp::Location> & locations)1508 void PDLTextFile::getLocationsOf(const lsp::URIForFile &uri,
1509 lsp::Position defPos,
1510 std::vector<lsp::Location> &locations) {
1511 PDLTextFileChunk &chunk = getChunkFor(defPos);
1512 chunk.document.getLocationsOf(uri, defPos, locations);
1513
1514 // Adjust any locations within this file for the offset of this chunk.
1515 if (chunk.lineOffset == 0)
1516 return;
1517 for (lsp::Location &loc : locations)
1518 if (loc.uri == uri)
1519 chunk.adjustLocForChunkOffset(loc.range);
1520 }
1521
findReferencesOf(const lsp::URIForFile & uri,lsp::Position pos,std::vector<lsp::Location> & references)1522 void PDLTextFile::findReferencesOf(const lsp::URIForFile &uri,
1523 lsp::Position pos,
1524 std::vector<lsp::Location> &references) {
1525 PDLTextFileChunk &chunk = getChunkFor(pos);
1526 chunk.document.findReferencesOf(uri, pos, references);
1527
1528 // Adjust any locations within this file for the offset of this chunk.
1529 if (chunk.lineOffset == 0)
1530 return;
1531 for (lsp::Location &loc : references)
1532 if (loc.uri == uri)
1533 chunk.adjustLocForChunkOffset(loc.range);
1534 }
1535
getDocumentLinks(const lsp::URIForFile & uri,std::vector<lsp::DocumentLink> & links)1536 void PDLTextFile::getDocumentLinks(const lsp::URIForFile &uri,
1537 std::vector<lsp::DocumentLink> &links) {
1538 chunks.front()->document.getDocumentLinks(uri, links);
1539 for (const auto &it : llvm::drop_begin(chunks)) {
1540 size_t currentNumLinks = links.size();
1541 it->document.getDocumentLinks(uri, links);
1542
1543 // Adjust any links within this file to account for the offset of this
1544 // chunk.
1545 for (auto &link : llvm::drop_begin(links, currentNumLinks))
1546 it->adjustLocForChunkOffset(link.range);
1547 }
1548 }
1549
findHover(const lsp::URIForFile & uri,lsp::Position hoverPos)1550 Optional<lsp::Hover> PDLTextFile::findHover(const lsp::URIForFile &uri,
1551 lsp::Position hoverPos) {
1552 PDLTextFileChunk &chunk = getChunkFor(hoverPos);
1553 Optional<lsp::Hover> hoverInfo = chunk.document.findHover(uri, hoverPos);
1554
1555 // Adjust any locations within this file for the offset of this chunk.
1556 if (chunk.lineOffset != 0 && hoverInfo && hoverInfo->range)
1557 chunk.adjustLocForChunkOffset(*hoverInfo->range);
1558 return hoverInfo;
1559 }
1560
findDocumentSymbols(std::vector<lsp::DocumentSymbol> & symbols)1561 void PDLTextFile::findDocumentSymbols(
1562 std::vector<lsp::DocumentSymbol> &symbols) {
1563 if (chunks.size() == 1)
1564 return chunks.front()->document.findDocumentSymbols(symbols);
1565
1566 // If there are multiple chunks in this file, we create top-level symbols for
1567 // each chunk.
1568 for (unsigned i = 0, e = chunks.size(); i < e; ++i) {
1569 PDLTextFileChunk &chunk = *chunks[i];
1570 lsp::Position startPos(chunk.lineOffset);
1571 lsp::Position endPos((i == e - 1) ? totalNumLines - 1
1572 : chunks[i + 1]->lineOffset);
1573 lsp::DocumentSymbol symbol("<file-split-" + Twine(i) + ">",
1574 lsp::SymbolKind::Namespace,
1575 /*range=*/lsp::Range(startPos, endPos),
1576 /*selectionRange=*/lsp::Range(startPos));
1577 chunk.document.findDocumentSymbols(symbol.children);
1578
1579 // Fixup the locations of document symbols within this chunk.
1580 if (i != 0) {
1581 SmallVector<lsp::DocumentSymbol *> symbolsToFix;
1582 for (lsp::DocumentSymbol &childSymbol : symbol.children)
1583 symbolsToFix.push_back(&childSymbol);
1584
1585 while (!symbolsToFix.empty()) {
1586 lsp::DocumentSymbol *symbol = symbolsToFix.pop_back_val();
1587 chunk.adjustLocForChunkOffset(symbol->range);
1588 chunk.adjustLocForChunkOffset(symbol->selectionRange);
1589
1590 for (lsp::DocumentSymbol &childSymbol : symbol->children)
1591 symbolsToFix.push_back(&childSymbol);
1592 }
1593 }
1594
1595 // Push the symbol for this chunk.
1596 symbols.emplace_back(std::move(symbol));
1597 }
1598 }
1599
/// Returns the completion list computed by the chunk containing
/// `completePos`. The ranges of every item's primary text edit (when present)
/// and of all of its additional text edits are adjusted for the chunk offset.
lsp::CompletionList PDLTextFile::getCodeCompletion(const lsp::URIForFile &uri,
                                                   lsp::Position completePos) {
  PDLTextFileChunk &owner = getChunkFor(completePos);
  lsp::CompletionList result =
      owner.document.getCodeCompletion(uri, completePos);

  for (lsp::CompletionItem &candidate : result.items) {
    // The main edit is optional; additional edits are a plain list.
    if (candidate.textEdit)
      owner.adjustLocForChunkOffset(candidate.textEdit->range);
    for (lsp::TextEdit &extraEdit : candidate.additionalTextEdits)
      owner.adjustLocForChunkOffset(extraEdit.range);
  }
  return result;
}
1615
/// Returns the signature help computed by the chunk containing `helpPos`.
lsp::SignatureHelp PDLTextFile::getSignatureHelp(const lsp::URIForFile &uri,
                                                 lsp::Position helpPos) {
  // Resolve the chunk first: `getChunkFor` receives `helpPos` by reference
  // and the same variable is then forwarded to the chunk's document.
  PDLTextFileChunk &owner = getChunkFor(helpPos);
  return owner.document.getSignatureHelp(uri, helpPos);
}
1620
getInlayHints(const lsp::URIForFile & uri,lsp::Range range,std::vector<lsp::InlayHint> & inlayHints)1621 void PDLTextFile::getInlayHints(const lsp::URIForFile &uri, lsp::Range range,
1622 std::vector<lsp::InlayHint> &inlayHints) {
1623 auto startIt = getChunkItFor(range.start);
1624 auto endIt = getChunkItFor(range.end);
1625
1626 // Functor used to get the chunks for a given file, and fixup any locations
1627 auto getHintsForChunk = [&](ChunkIterator chunkIt, lsp::Range range) {
1628 size_t currentNumHints = inlayHints.size();
1629 chunkIt->document.getInlayHints(uri, range, inlayHints);
1630
1631 // If this isn't the first chunk, update any positions to account for line
1632 // number differences.
1633 if (&*chunkIt != &*chunks.front()) {
1634 for (auto &hint : llvm::drop_begin(inlayHints, currentNumHints))
1635 chunkIt->adjustLocForChunkOffset(hint.position);
1636 }
1637 };
1638 // Returns the number of lines held by a given chunk.
1639 auto getNumLines = [](ChunkIterator chunkIt) {
1640 return (chunkIt + 1)->lineOffset - chunkIt->lineOffset;
1641 };
1642
1643 // Check if the range is fully within a single chunk.
1644 if (startIt == endIt)
1645 return getHintsForChunk(startIt, range);
1646
1647 // Otherwise, the range is split between multiple chunks. The first chunk
1648 // has the correct range start, but covers the total document.
1649 getHintsForChunk(startIt, lsp::Range(range.start, getNumLines(startIt)));
1650
1651 // Every chunk in between uses the full document.
1652 for (++startIt; startIt != endIt; ++startIt)
1653 getHintsForChunk(startIt, lsp::Range(0, getNumLines(startIt)));
1654
1655 // The range for the last chunk starts at the beginning of the document, up
1656 // through the end of the input range.
1657 getHintsForChunk(startIt, lsp::Range(0, range.end));
1658 }
1659
1660 lsp::PDLLViewOutputResult
getPDLLViewOutput(lsp::PDLLViewOutputKind kind)1661 PDLTextFile::getPDLLViewOutput(lsp::PDLLViewOutputKind kind) {
1662 lsp::PDLLViewOutputResult result;
1663 {
1664 llvm::raw_string_ostream outputOS(result.output);
1665 llvm::interleave(
1666 llvm::make_pointee_range(chunks),
1667 [&](PDLTextFileChunk &chunk) {
1668 chunk.document.getPDLLViewOutput(outputOS, kind);
1669 },
1670 [&] { outputOS << "\n// -----\n\n"; });
1671 }
1672 return result;
1673 }
1674
initialize(const lsp::URIForFile & uri,int64_t newVersion,std::vector<lsp::Diagnostic> & diagnostics)1675 void PDLTextFile::initialize(const lsp::URIForFile &uri, int64_t newVersion,
1676 std::vector<lsp::Diagnostic> &diagnostics) {
1677 version = newVersion;
1678 chunks.clear();
1679
1680 // Split the file into separate PDL documents.
1681 // TODO: Find a way to share the split file marker with other tools. We don't
1682 // want to use `splitAndProcessBuffer` here, but we do want to make sure this
1683 // marker doesn't go out of sync.
1684 SmallVector<StringRef, 8> subContents;
1685 StringRef(contents).split(subContents, "// -----");
1686 chunks.emplace_back(std::make_unique<PDLTextFileChunk>(
1687 /*lineOffset=*/0, uri, subContents.front(), extraIncludeDirs,
1688 diagnostics));
1689
1690 uint64_t lineOffset = subContents.front().count('\n');
1691 for (StringRef docContents : llvm::drop_begin(subContents)) {
1692 unsigned currentNumDiags = diagnostics.size();
1693 auto chunk = std::make_unique<PDLTextFileChunk>(
1694 lineOffset, uri, docContents, extraIncludeDirs, diagnostics);
1695 lineOffset += docContents.count('\n');
1696
1697 // Adjust locations used in diagnostics to account for the offset from the
1698 // beginning of the file.
1699 for (lsp::Diagnostic &diag :
1700 llvm::drop_begin(diagnostics, currentNumDiags)) {
1701 chunk->adjustLocForChunkOffset(diag.range);
1702
1703 if (!diag.relatedInformation)
1704 continue;
1705 for (auto &it : *diag.relatedInformation)
1706 if (it.location.uri == uri)
1707 chunk->adjustLocForChunkOffset(it.location.range);
1708 }
1709 chunks.emplace_back(std::move(chunk));
1710 }
1711 totalNumLines = lineOffset;
1712 }
1713
getChunkItFor(lsp::Position & pos)1714 PDLTextFile::ChunkIterator PDLTextFile::getChunkItFor(lsp::Position &pos) {
1715 if (chunks.size() == 1)
1716 return chunks.begin();
1717
1718 // Search for the first chunk with a greater line offset, the previous chunk
1719 // is the one that contains `pos`.
1720 auto it = llvm::upper_bound(
1721 chunks, pos, [](const lsp::Position &pos, const auto &chunk) {
1722 return static_cast<uint64_t>(pos.line) < chunk->lineOffset;
1723 });
1724 ChunkIterator chunkIt(it == chunks.end() ? (chunks.end() - 1) : --it);
1725 pos.line -= chunkIt->lineOffset;
1726 return chunkIt;
1727 }
1728
1729 //===----------------------------------------------------------------------===//
1730 // PDLLServer::Impl
1731 //===----------------------------------------------------------------------===//
1732
1733 struct lsp::PDLLServer::Impl {
Impllsp::PDLLServer::Impl1734 explicit Impl(const Options &options)
1735 : options(options), compilationDatabase(options.compilationDatabases) {}
1736
1737 /// PDLL LSP options.
1738 const Options &options;
1739
1740 /// The compilation database containing additional information for files
1741 /// passed to the server.
1742 lsp::CompilationDatabase compilationDatabase;
1743
1744 /// The files held by the server, mapped by their URI file name.
1745 llvm::StringMap<std::unique_ptr<PDLTextFile>> files;
1746 };
1747
1748 //===----------------------------------------------------------------------===//
1749 // PDLLServer
1750 //===----------------------------------------------------------------------===//
1751
PDLLServer(const Options & options)1752 lsp::PDLLServer::PDLLServer(const Options &options)
1753 : impl(std::make_unique<Impl>(options)) {}
1754 lsp::PDLLServer::~PDLLServer() = default;
1755
addDocument(const URIForFile & uri,StringRef contents,int64_t version,std::vector<Diagnostic> & diagnostics)1756 void lsp::PDLLServer::addDocument(const URIForFile &uri, StringRef contents,
1757 int64_t version,
1758 std::vector<Diagnostic> &diagnostics) {
1759 // Build the set of additional include directories.
1760 std::vector<std::string> additionalIncludeDirs = impl->options.extraDirs;
1761 const auto &fileInfo = impl->compilationDatabase.getFileInfo(uri.file());
1762 llvm::append_range(additionalIncludeDirs, fileInfo.includeDirs);
1763
1764 impl->files[uri.file()] = std::make_unique<PDLTextFile>(
1765 uri, contents, version, additionalIncludeDirs, diagnostics);
1766 }
1767
updateDocument(const URIForFile & uri,ArrayRef<TextDocumentContentChangeEvent> changes,int64_t version,std::vector<Diagnostic> & diagnostics)1768 void lsp::PDLLServer::updateDocument(
1769 const URIForFile &uri, ArrayRef<TextDocumentContentChangeEvent> changes,
1770 int64_t version, std::vector<Diagnostic> &diagnostics) {
1771 // Check that we actually have a document for this uri.
1772 auto it = impl->files.find(uri.file());
1773 if (it == impl->files.end())
1774 return;
1775
1776 // Try to update the document. If we fail, erase the file from the server. A
1777 // failed updated generally means we've fallen out of sync somewhere.
1778 if (failed(it->second->update(uri, version, changes, diagnostics)))
1779 impl->files.erase(it);
1780 }
1781
removeDocument(const URIForFile & uri)1782 Optional<int64_t> lsp::PDLLServer::removeDocument(const URIForFile &uri) {
1783 auto it = impl->files.find(uri.file());
1784 if (it == impl->files.end())
1785 return llvm::None;
1786
1787 int64_t version = it->second->getVersion();
1788 impl->files.erase(it);
1789 return version;
1790 }
1791
getLocationsOf(const URIForFile & uri,const Position & defPos,std::vector<Location> & locations)1792 void lsp::PDLLServer::getLocationsOf(const URIForFile &uri,
1793 const Position &defPos,
1794 std::vector<Location> &locations) {
1795 auto fileIt = impl->files.find(uri.file());
1796 if (fileIt != impl->files.end())
1797 fileIt->second->getLocationsOf(uri, defPos, locations);
1798 }
1799
findReferencesOf(const URIForFile & uri,const Position & pos,std::vector<Location> & references)1800 void lsp::PDLLServer::findReferencesOf(const URIForFile &uri,
1801 const Position &pos,
1802 std::vector<Location> &references) {
1803 auto fileIt = impl->files.find(uri.file());
1804 if (fileIt != impl->files.end())
1805 fileIt->second->findReferencesOf(uri, pos, references);
1806 }
1807
getDocumentLinks(const URIForFile & uri,std::vector<DocumentLink> & documentLinks)1808 void lsp::PDLLServer::getDocumentLinks(
1809 const URIForFile &uri, std::vector<DocumentLink> &documentLinks) {
1810 auto fileIt = impl->files.find(uri.file());
1811 if (fileIt != impl->files.end())
1812 return fileIt->second->getDocumentLinks(uri, documentLinks);
1813 }
1814
findHover(const URIForFile & uri,const Position & hoverPos)1815 Optional<lsp::Hover> lsp::PDLLServer::findHover(const URIForFile &uri,
1816 const Position &hoverPos) {
1817 auto fileIt = impl->files.find(uri.file());
1818 if (fileIt != impl->files.end())
1819 return fileIt->second->findHover(uri, hoverPos);
1820 return llvm::None;
1821 }
1822
findDocumentSymbols(const URIForFile & uri,std::vector<DocumentSymbol> & symbols)1823 void lsp::PDLLServer::findDocumentSymbols(
1824 const URIForFile &uri, std::vector<DocumentSymbol> &symbols) {
1825 auto fileIt = impl->files.find(uri.file());
1826 if (fileIt != impl->files.end())
1827 fileIt->second->findDocumentSymbols(symbols);
1828 }
1829
1830 lsp::CompletionList
getCodeCompletion(const URIForFile & uri,const Position & completePos)1831 lsp::PDLLServer::getCodeCompletion(const URIForFile &uri,
1832 const Position &completePos) {
1833 auto fileIt = impl->files.find(uri.file());
1834 if (fileIt != impl->files.end())
1835 return fileIt->second->getCodeCompletion(uri, completePos);
1836 return CompletionList();
1837 }
1838
getSignatureHelp(const URIForFile & uri,const Position & helpPos)1839 lsp::SignatureHelp lsp::PDLLServer::getSignatureHelp(const URIForFile &uri,
1840 const Position &helpPos) {
1841 auto fileIt = impl->files.find(uri.file());
1842 if (fileIt != impl->files.end())
1843 return fileIt->second->getSignatureHelp(uri, helpPos);
1844 return SignatureHelp();
1845 }
1846
getInlayHints(const URIForFile & uri,const Range & range,std::vector<InlayHint> & inlayHints)1847 void lsp::PDLLServer::getInlayHints(const URIForFile &uri, const Range &range,
1848 std::vector<InlayHint> &inlayHints) {
1849 auto fileIt = impl->files.find(uri.file());
1850 if (fileIt == impl->files.end())
1851 return;
1852 fileIt->second->getInlayHints(uri, range, inlayHints);
1853
1854 // Drop any duplicated hints that may have cropped up.
1855 llvm::sort(inlayHints);
1856 inlayHints.erase(std::unique(inlayHints.begin(), inlayHints.end()),
1857 inlayHints.end());
1858 }
1859
1860 Optional<lsp::PDLLViewOutputResult>
getPDLLViewOutput(const URIForFile & uri,PDLLViewOutputKind kind)1861 lsp::PDLLServer::getPDLLViewOutput(const URIForFile &uri,
1862 PDLLViewOutputKind kind) {
1863 auto fileIt = impl->files.find(uri.file());
1864 if (fileIt != impl->files.end())
1865 return fileIt->second->getPDLLViewOutput(kind);
1866 return llvm::None;
1867 }
1868