1 //===- PDLLServer.cpp - PDLL Language Server ------------------------------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 
9 #include "PDLLServer.h"
10 
11 #include "../lsp-server-support/CompilationDatabase.h"
12 #include "../lsp-server-support/Logging.h"
13 #include "../lsp-server-support/SourceMgrUtils.h"
14 #include "Protocol.h"
15 #include "mlir/IR/BuiltinOps.h"
16 #include "mlir/Tools/PDLL/AST/Context.h"
17 #include "mlir/Tools/PDLL/AST/Nodes.h"
18 #include "mlir/Tools/PDLL/AST/Types.h"
19 #include "mlir/Tools/PDLL/CodeGen/CPPGen.h"
20 #include "mlir/Tools/PDLL/CodeGen/MLIRGen.h"
21 #include "mlir/Tools/PDLL/ODS/Constraint.h"
22 #include "mlir/Tools/PDLL/ODS/Context.h"
23 #include "mlir/Tools/PDLL/ODS/Dialect.h"
24 #include "mlir/Tools/PDLL/ODS/Operation.h"
25 #include "mlir/Tools/PDLL/Parser/CodeComplete.h"
26 #include "mlir/Tools/PDLL/Parser/Parser.h"
27 #include "llvm/ADT/IntervalMap.h"
28 #include "llvm/ADT/StringMap.h"
29 #include "llvm/ADT/StringSet.h"
30 #include "llvm/ADT/TypeSwitch.h"
31 #include "llvm/Support/FileSystem.h"
32 #include "llvm/Support/Path.h"
33 
34 using namespace mlir;
35 using namespace mlir::pdll;
36 
37 /// Returns a language server uri for the given source location. `mainFileURI`
38 /// corresponds to the uri for the main file of the source manager.
getURIFromLoc(llvm::SourceMgr & mgr,SMRange loc,const lsp::URIForFile & mainFileURI)39 static lsp::URIForFile getURIFromLoc(llvm::SourceMgr &mgr, SMRange loc,
40                                      const lsp::URIForFile &mainFileURI) {
41   int bufferId = mgr.FindBufferContainingLoc(loc.Start);
42   if (bufferId == 0 || bufferId == static_cast<int>(mgr.getMainFileID()))
43     return mainFileURI;
44   llvm::Expected<lsp::URIForFile> fileForLoc = lsp::URIForFile::fromFile(
45       mgr.getBufferInfo(bufferId).Buffer->getBufferIdentifier());
46   if (fileForLoc)
47     return *fileForLoc;
48   lsp::Logger::error("Failed to create URI for include file: {0}",
49                      llvm::toString(fileForLoc.takeError()));
50   return mainFileURI;
51 }
52 
53 /// Returns true if the given location is in the main file of the source
54 /// manager.
isMainFileLoc(llvm::SourceMgr & mgr,SMRange loc)55 static bool isMainFileLoc(llvm::SourceMgr &mgr, SMRange loc) {
56   return mgr.FindBufferContainingLoc(loc.Start) == mgr.getMainFileID();
57 }
58 
59 /// Returns a language server location from the given source range.
getLocationFromLoc(llvm::SourceMgr & mgr,SMRange range,const lsp::URIForFile & uri)60 static lsp::Location getLocationFromLoc(llvm::SourceMgr &mgr, SMRange range,
61                                         const lsp::URIForFile &uri) {
62   return lsp::Location(getURIFromLoc(mgr, range, uri), lsp::Range(mgr, range));
63 }
64 
65 /// Returns true if the given range contains the given source location. Note
66 /// that this has different behavior than SMRange because it is inclusive of the
67 /// end location.
contains(SMRange range,SMLoc loc)68 static bool contains(SMRange range, SMLoc loc) {
69   return range.Start.getPointer() <= loc.getPointer() &&
70          loc.getPointer() <= range.End.getPointer();
71 }
72 
73 /// Convert the given MLIR diagnostic to the LSP form.
74 static Optional<lsp::Diagnostic>
getLspDiagnoticFromDiag(llvm::SourceMgr & sourceMgr,const ast::Diagnostic & diag,const lsp::URIForFile & uri)75 getLspDiagnoticFromDiag(llvm::SourceMgr &sourceMgr, const ast::Diagnostic &diag,
76                         const lsp::URIForFile &uri) {
77   lsp::Diagnostic lspDiag;
78   lspDiag.source = "pdll";
79 
80   // FIXME: Right now all of the diagnostics are treated as parser issues, but
81   // some are parser and some are verifier.
82   lspDiag.category = "Parse Error";
83 
84   // Try to grab a file location for this diagnostic.
85   lsp::Location loc = getLocationFromLoc(sourceMgr, diag.getLocation(), uri);
86   lspDiag.range = loc.range;
87 
88   // Skip diagnostics that weren't emitted within the main file.
89   if (loc.uri != uri)
90     return llvm::None;
91 
92   // Convert the severity for the diagnostic.
93   switch (diag.getSeverity()) {
94   case ast::Diagnostic::Severity::DK_Note:
95     llvm_unreachable("expected notes to be handled separately");
96   case ast::Diagnostic::Severity::DK_Warning:
97     lspDiag.severity = lsp::DiagnosticSeverity::Warning;
98     break;
99   case ast::Diagnostic::Severity::DK_Error:
100     lspDiag.severity = lsp::DiagnosticSeverity::Error;
101     break;
102   case ast::Diagnostic::Severity::DK_Remark:
103     lspDiag.severity = lsp::DiagnosticSeverity::Information;
104     break;
105   }
106   lspDiag.message = diag.getMessage().str();
107 
108   // Attach any notes to the main diagnostic as related information.
109   std::vector<lsp::DiagnosticRelatedInformation> relatedDiags;
110   for (const ast::Diagnostic &note : diag.getNotes()) {
111     relatedDiags.emplace_back(
112         getLocationFromLoc(sourceMgr, note.getLocation(), uri),
113         note.getMessage().str());
114   }
115   if (!relatedDiags.empty())
116     lspDiag.relatedInformation = std::move(relatedDiags);
117 
118   return lspDiag;
119 }
120 
121 /// Get or extract the documentation for the given decl.
getDocumentationFor(llvm::SourceMgr & sourceMgr,const ast::Decl * decl)122 static Optional<std::string> getDocumentationFor(llvm::SourceMgr &sourceMgr,
123                                                  const ast::Decl *decl) {
124   // If the decl already had documentation set, use it.
125   if (Optional<StringRef> doc = decl->getDocComment())
126     return doc->str();
127 
128   // If the decl doesn't yet have documentation, try to extract it from the
129   // source file. This is a heuristic, and isn't intended to cover every case,
130   // but should cover the most common. We essentially look for a comment
131   // preceding the decl, and if we find one, use that as the documentation.
132   SMLoc startLoc = decl->getLoc().Start;
133   if (!startLoc.isValid())
134     return llvm::None;
135   int bufferId = sourceMgr.FindBufferContainingLoc(startLoc);
136   if (bufferId == 0)
137     return llvm::None;
138   const char *bufferStart =
139       sourceMgr.getMemoryBuffer(bufferId)->getBufferStart();
140   StringRef buffer(bufferStart, startLoc.getPointer() - bufferStart);
141 
142   // Pop the last line from the buffer string.
143   auto popLastLine = [&]() -> Optional<StringRef> {
144     size_t newlineOffset = buffer.find_last_of("\n");
145     if (newlineOffset == StringRef::npos)
146       return llvm::None;
147     StringRef lastLine = buffer.drop_front(newlineOffset).trim();
148     buffer = buffer.take_front(newlineOffset);
149     return lastLine;
150   };
151 
152   // Try to pop the current line, which contains the decl.
153   if (!popLastLine())
154     return llvm::None;
155 
156   // Try to parse a comment string from the source file.
157   SmallVector<StringRef> commentLines;
158   while (Optional<StringRef> line = popLastLine()) {
159     // Check for a comment at the beginning of the line.
160     if (!line->startswith("//"))
161       break;
162 
163     // Extract the document string from the comment.
164     commentLines.push_back(line->drop_while([](char c) { return c == '/'; }));
165   }
166 
167   if (commentLines.empty())
168     return llvm::None;
169   return llvm::join(llvm::reverse(commentLines), "\n");
170 }
171 
172 //===----------------------------------------------------------------------===//
173 // PDLIndex
174 //===----------------------------------------------------------------------===//
175 
176 namespace {
177 struct PDLIndexSymbol {
PDLIndexSymbol__anon9e2fdad80311::PDLIndexSymbol178   explicit PDLIndexSymbol(const ast::Decl *definition)
179       : definition(definition) {}
PDLIndexSymbol__anon9e2fdad80311::PDLIndexSymbol180   explicit PDLIndexSymbol(const ods::Operation *definition)
181       : definition(definition) {}
182 
183   /// Return the location of the definition of this symbol.
getDefLoc__anon9e2fdad80311::PDLIndexSymbol184   SMRange getDefLoc() const {
185     if (const ast::Decl *decl = definition.dyn_cast<const ast::Decl *>()) {
186       const ast::Name *declName = decl->getName();
187       return declName ? declName->getLoc() : decl->getLoc();
188     }
189     return definition.get<const ods::Operation *>()->getLoc();
190   }
191 
192   /// The main definition of the symbol.
193   PointerUnion<const ast::Decl *, const ods::Operation *> definition;
194   /// The set of references to the symbol.
195   std::vector<SMRange> references;
196 };
197 
198 /// This class provides an index for definitions/uses within a PDL document.
199 /// It provides efficient lookup of a definition given an input source range.
200 class PDLIndex {
201 public:
PDLIndex()202   PDLIndex() : intervalMap(allocator) {}
203 
204   /// Initialize the index with the given ast::Module.
205   void initialize(const ast::Module &module, const ods::Context &odsContext);
206 
207   /// Lookup a symbol for the given location. Returns nullptr if no symbol could
208   /// be found. If provided, `overlappedRange` is set to the range that the
209   /// provided `loc` overlapped with.
210   const PDLIndexSymbol *lookup(SMLoc loc,
211                                SMRange *overlappedRange = nullptr) const;
212 
213 private:
214   /// The type of interval map used to store source references. SMRange is
215   /// half-open, so we also need to use a half-open interval map.
216   using MapT =
217       llvm::IntervalMap<const char *, const PDLIndexSymbol *,
218                         llvm::IntervalMapImpl::NodeSizer<
219                             const char *, const PDLIndexSymbol *>::LeafSize,
220                         llvm::IntervalMapHalfOpenInfo<const char *>>;
221 
222   /// An allocator for the interval map.
223   MapT::Allocator allocator;
224 
225   /// An interval map containing a corresponding definition mapped to a source
226   /// interval.
227   MapT intervalMap;
228 
229   /// A mapping between definitions and their corresponding symbol.
230   DenseMap<const void *, std::unique_ptr<PDLIndexSymbol>> defToSymbol;
231 };
232 } // namespace
233 
initialize(const ast::Module & module,const ods::Context & odsContext)234 void PDLIndex::initialize(const ast::Module &module,
235                           const ods::Context &odsContext) {
236   auto getOrInsertDef = [&](const auto *def) -> PDLIndexSymbol * {
237     auto it = defToSymbol.try_emplace(def, nullptr);
238     if (it.second)
239       it.first->second = std::make_unique<PDLIndexSymbol>(def);
240     return &*it.first->second;
241   };
242   auto insertDeclRef = [&](PDLIndexSymbol *sym, SMRange refLoc,
243                            bool isDef = false) {
244     const char *startLoc = refLoc.Start.getPointer();
245     const char *endLoc = refLoc.End.getPointer();
246     if (!intervalMap.overlaps(startLoc, endLoc)) {
247       intervalMap.insert(startLoc, endLoc, sym);
248       if (!isDef)
249         sym->references.push_back(refLoc);
250     }
251   };
252   auto insertODSOpRef = [&](StringRef opName, SMRange refLoc) {
253     const ods::Operation *odsOp = odsContext.lookupOperation(opName);
254     if (!odsOp)
255       return;
256 
257     PDLIndexSymbol *symbol = getOrInsertDef(odsOp);
258     insertDeclRef(symbol, odsOp->getLoc(), /*isDef=*/true);
259     insertDeclRef(symbol, refLoc);
260   };
261 
262   module.walk([&](const ast::Node *node) {
263     // Handle references to PDL decls.
264     if (const auto *decl = dyn_cast<ast::OpNameDecl>(node)) {
265       if (Optional<StringRef> name = decl->getName())
266         insertODSOpRef(*name, decl->getLoc());
267     } else if (const ast::Decl *decl = dyn_cast<ast::Decl>(node)) {
268       const ast::Name *name = decl->getName();
269       if (!name)
270         return;
271       PDLIndexSymbol *declSym = getOrInsertDef(decl);
272       insertDeclRef(declSym, name->getLoc(), /*isDef=*/true);
273 
274       if (const auto *varDecl = dyn_cast<ast::VariableDecl>(decl)) {
275         // Record references to any constraints.
276         for (const auto &it : varDecl->getConstraints())
277           insertDeclRef(getOrInsertDef(it.constraint), it.referenceLoc);
278       }
279     } else if (const auto *expr = dyn_cast<ast::DeclRefExpr>(node)) {
280       insertDeclRef(getOrInsertDef(expr->getDecl()), expr->getLoc());
281     }
282   });
283 }
284 
lookup(SMLoc loc,SMRange * overlappedRange) const285 const PDLIndexSymbol *PDLIndex::lookup(SMLoc loc,
286                                        SMRange *overlappedRange) const {
287   auto it = intervalMap.find(loc.getPointer());
288   if (!it.valid() || loc.getPointer() < it.start())
289     return nullptr;
290 
291   if (overlappedRange) {
292     *overlappedRange = SMRange(SMLoc::getFromPointer(it.start()),
293                                SMLoc::getFromPointer(it.stop()));
294   }
295   return it.value();
296 }
297 
298 //===----------------------------------------------------------------------===//
299 // PDLDocument
300 //===----------------------------------------------------------------------===//
301 
302 namespace {
303 /// This class represents all of the information pertaining to a specific PDL
304 /// document.
305 struct PDLDocument {
306   PDLDocument(const lsp::URIForFile &uri, StringRef contents,
307               const std::vector<std::string> &extraDirs,
308               std::vector<lsp::Diagnostic> &diagnostics);
309   PDLDocument(const PDLDocument &) = delete;
310   PDLDocument &operator=(const PDLDocument &) = delete;
311 
312   //===--------------------------------------------------------------------===//
313   // Definitions and References
314   //===--------------------------------------------------------------------===//
315 
316   void getLocationsOf(const lsp::URIForFile &uri, const lsp::Position &defPos,
317                       std::vector<lsp::Location> &locations);
318   void findReferencesOf(const lsp::URIForFile &uri, const lsp::Position &pos,
319                         std::vector<lsp::Location> &references);
320 
321   //===--------------------------------------------------------------------===//
322   // Document Links
323   //===--------------------------------------------------------------------===//
324 
325   void getDocumentLinks(const lsp::URIForFile &uri,
326                         std::vector<lsp::DocumentLink> &links);
327 
328   //===--------------------------------------------------------------------===//
329   // Hover
330   //===--------------------------------------------------------------------===//
331 
332   Optional<lsp::Hover> findHover(const lsp::URIForFile &uri,
333                                  const lsp::Position &hoverPos);
334   Optional<lsp::Hover> findHover(const ast::Decl *decl,
335                                  const SMRange &hoverRange);
336   lsp::Hover buildHoverForOpName(const ods::Operation *op,
337                                  const SMRange &hoverRange);
338   lsp::Hover buildHoverForVariable(const ast::VariableDecl *varDecl,
339                                    const SMRange &hoverRange);
340   lsp::Hover buildHoverForPattern(const ast::PatternDecl *decl,
341                                   const SMRange &hoverRange);
342   lsp::Hover buildHoverForCoreConstraint(const ast::CoreConstraintDecl *decl,
343                                          const SMRange &hoverRange);
344   template <typename T>
345   lsp::Hover buildHoverForUserConstraintOrRewrite(StringRef typeName,
346                                                   const T *decl,
347                                                   const SMRange &hoverRange);
348 
349   //===--------------------------------------------------------------------===//
350   // Document Symbols
351   //===--------------------------------------------------------------------===//
352 
353   void findDocumentSymbols(std::vector<lsp::DocumentSymbol> &symbols);
354 
355   //===--------------------------------------------------------------------===//
356   // Code Completion
357   //===--------------------------------------------------------------------===//
358 
359   lsp::CompletionList getCodeCompletion(const lsp::URIForFile &uri,
360                                         const lsp::Position &completePos);
361 
362   //===--------------------------------------------------------------------===//
363   // Signature Help
364   //===--------------------------------------------------------------------===//
365 
366   lsp::SignatureHelp getSignatureHelp(const lsp::URIForFile &uri,
367                                       const lsp::Position &helpPos);
368 
369   //===--------------------------------------------------------------------===//
370   // Inlay Hints
371   //===--------------------------------------------------------------------===//
372 
373   void getInlayHints(const lsp::URIForFile &uri, const lsp::Range &range,
374                      std::vector<lsp::InlayHint> &inlayHints);
375   void getInlayHintsFor(const ast::VariableDecl *decl,
376                         const lsp::URIForFile &uri,
377                         std::vector<lsp::InlayHint> &inlayHints);
378   void getInlayHintsFor(const ast::CallExpr *expr, const lsp::URIForFile &uri,
379                         std::vector<lsp::InlayHint> &inlayHints);
380   void getInlayHintsFor(const ast::OperationExpr *expr,
381                         const lsp::URIForFile &uri,
382                         std::vector<lsp::InlayHint> &inlayHints);
383 
384   /// Add a parameter hint for the given expression using `label`.
385   void addParameterHintFor(std::vector<lsp::InlayHint> &inlayHints,
386                            const ast::Expr *expr, StringRef label);
387 
388   //===--------------------------------------------------------------------===//
389   // PDLL ViewOutput
390   //===--------------------------------------------------------------------===//
391 
392   void getPDLLViewOutput(raw_ostream &os, lsp::PDLLViewOutputKind kind);
393 
394   //===--------------------------------------------------------------------===//
395   // Fields
396   //===--------------------------------------------------------------------===//
397 
398   /// The include directories for this file.
399   std::vector<std::string> includeDirs;
400 
401   /// The source manager containing the contents of the input file.
402   llvm::SourceMgr sourceMgr;
403 
404   /// The ODS and AST contexts.
405   ods::Context odsContext;
406   ast::Context astContext;
407 
408   /// The parsed AST module, or failure if the file wasn't valid.
409   FailureOr<ast::Module *> astModule;
410 
411   /// The index of the parsed module.
412   PDLIndex index;
413 
414   /// The set of includes of the parsed module.
415   SmallVector<lsp::SourceMgrInclude> parsedIncludes;
416 };
417 } // namespace
418 
PDLDocument(const lsp::URIForFile & uri,StringRef contents,const std::vector<std::string> & extraDirs,std::vector<lsp::Diagnostic> & diagnostics)419 PDLDocument::PDLDocument(const lsp::URIForFile &uri, StringRef contents,
420                          const std::vector<std::string> &extraDirs,
421                          std::vector<lsp::Diagnostic> &diagnostics)
422     : astContext(odsContext) {
423   auto memBuffer = llvm::MemoryBuffer::getMemBufferCopy(contents, uri.file());
424   if (!memBuffer) {
425     lsp::Logger::error("Failed to create memory buffer for file", uri.file());
426     return;
427   }
428 
429   // Build the set of include directories for this file.
430   llvm::SmallString<32> uriDirectory(uri.file());
431   llvm::sys::path::remove_filename(uriDirectory);
432   includeDirs.push_back(uriDirectory.str().str());
433   includeDirs.insert(includeDirs.end(), extraDirs.begin(), extraDirs.end());
434 
435   sourceMgr.setIncludeDirs(includeDirs);
436   sourceMgr.AddNewSourceBuffer(std::move(memBuffer), SMLoc());
437 
438   astContext.getDiagEngine().setHandlerFn([&](const ast::Diagnostic &diag) {
439     if (auto lspDiag = getLspDiagnoticFromDiag(sourceMgr, diag, uri))
440       diagnostics.push_back(std::move(*lspDiag));
441   });
442   astModule = parsePDLLAST(astContext, sourceMgr, /*enableDocumentation=*/true);
443 
444   // Initialize the set of parsed includes.
445   lsp::gatherIncludeFiles(sourceMgr, parsedIncludes);
446 
447   // If we failed to parse the module, there is nothing left to initialize.
448   if (failed(astModule))
449     return;
450 
451   // Prepare the AST index with the parsed module.
452   index.initialize(**astModule, odsContext);
453 }
454 
455 //===----------------------------------------------------------------------===//
456 // PDLDocument: Definitions and References
457 //===----------------------------------------------------------------------===//
458 
getLocationsOf(const lsp::URIForFile & uri,const lsp::Position & defPos,std::vector<lsp::Location> & locations)459 void PDLDocument::getLocationsOf(const lsp::URIForFile &uri,
460                                  const lsp::Position &defPos,
461                                  std::vector<lsp::Location> &locations) {
462   SMLoc posLoc = defPos.getAsSMLoc(sourceMgr);
463   const PDLIndexSymbol *symbol = index.lookup(posLoc);
464   if (!symbol)
465     return;
466 
467   locations.push_back(getLocationFromLoc(sourceMgr, symbol->getDefLoc(), uri));
468 }
469 
findReferencesOf(const lsp::URIForFile & uri,const lsp::Position & pos,std::vector<lsp::Location> & references)470 void PDLDocument::findReferencesOf(const lsp::URIForFile &uri,
471                                    const lsp::Position &pos,
472                                    std::vector<lsp::Location> &references) {
473   SMLoc posLoc = pos.getAsSMLoc(sourceMgr);
474   const PDLIndexSymbol *symbol = index.lookup(posLoc);
475   if (!symbol)
476     return;
477 
478   references.push_back(getLocationFromLoc(sourceMgr, symbol->getDefLoc(), uri));
479   for (SMRange refLoc : symbol->references)
480     references.push_back(getLocationFromLoc(sourceMgr, refLoc, uri));
481 }
482 
483 //===--------------------------------------------------------------------===//
484 // PDLDocument: Document Links
485 //===--------------------------------------------------------------------===//
486 
getDocumentLinks(const lsp::URIForFile & uri,std::vector<lsp::DocumentLink> & links)487 void PDLDocument::getDocumentLinks(const lsp::URIForFile &uri,
488                                    std::vector<lsp::DocumentLink> &links) {
489   for (const lsp::SourceMgrInclude &include : parsedIncludes)
490     links.emplace_back(include.range, include.uri);
491 }
492 
493 //===----------------------------------------------------------------------===//
494 // PDLDocument: Hover
495 //===----------------------------------------------------------------------===//
496 
findHover(const lsp::URIForFile & uri,const lsp::Position & hoverPos)497 Optional<lsp::Hover> PDLDocument::findHover(const lsp::URIForFile &uri,
498                                             const lsp::Position &hoverPos) {
499   SMLoc posLoc = hoverPos.getAsSMLoc(sourceMgr);
500 
501   // Check for a reference to an include.
502   for (const lsp::SourceMgrInclude &include : parsedIncludes)
503     if (include.range.contains(hoverPos))
504       return include.buildHover();
505 
506   // Find the symbol at the given location.
507   SMRange hoverRange;
508   const PDLIndexSymbol *symbol = index.lookup(posLoc, &hoverRange);
509   if (!symbol)
510     return llvm::None;
511 
512   // Add hover for operation names.
513   if (const auto *op = symbol->definition.dyn_cast<const ods::Operation *>())
514     return buildHoverForOpName(op, hoverRange);
515   const auto *decl = symbol->definition.get<const ast::Decl *>();
516   return findHover(decl, hoverRange);
517 }
518 
findHover(const ast::Decl * decl,const SMRange & hoverRange)519 Optional<lsp::Hover> PDLDocument::findHover(const ast::Decl *decl,
520                                             const SMRange &hoverRange) {
521   // Add hover for variables.
522   if (const auto *varDecl = dyn_cast<ast::VariableDecl>(decl))
523     return buildHoverForVariable(varDecl, hoverRange);
524 
525   // Add hover for patterns.
526   if (const auto *patternDecl = dyn_cast<ast::PatternDecl>(decl))
527     return buildHoverForPattern(patternDecl, hoverRange);
528 
529   // Add hover for core constraints.
530   if (const auto *cst = dyn_cast<ast::CoreConstraintDecl>(decl))
531     return buildHoverForCoreConstraint(cst, hoverRange);
532 
533   // Add hover for user constraints.
534   if (const auto *cst = dyn_cast<ast::UserConstraintDecl>(decl))
535     return buildHoverForUserConstraintOrRewrite("Constraint", cst, hoverRange);
536 
537   // Add hover for user rewrites.
538   if (const auto *rewrite = dyn_cast<ast::UserRewriteDecl>(decl))
539     return buildHoverForUserConstraintOrRewrite("Rewrite", rewrite, hoverRange);
540 
541   return llvm::None;
542 }
543 
buildHoverForOpName(const ods::Operation * op,const SMRange & hoverRange)544 lsp::Hover PDLDocument::buildHoverForOpName(const ods::Operation *op,
545                                             const SMRange &hoverRange) {
546   lsp::Hover hover(lsp::Range(sourceMgr, hoverRange));
547   {
548     llvm::raw_string_ostream hoverOS(hover.contents.value);
549     hoverOS << "**OpName**: `" << op->getName() << "`\n***\n"
550             << op->getSummary() << "\n***\n"
551             << op->getDescription();
552   }
553   return hover;
554 }
555 
buildHoverForVariable(const ast::VariableDecl * varDecl,const SMRange & hoverRange)556 lsp::Hover PDLDocument::buildHoverForVariable(const ast::VariableDecl *varDecl,
557                                               const SMRange &hoverRange) {
558   lsp::Hover hover(lsp::Range(sourceMgr, hoverRange));
559   {
560     llvm::raw_string_ostream hoverOS(hover.contents.value);
561     hoverOS << "**Variable**: `" << varDecl->getName().getName() << "`\n***\n"
562             << "Type: `" << varDecl->getType() << "`\n";
563   }
564   return hover;
565 }
566 
buildHoverForPattern(const ast::PatternDecl * decl,const SMRange & hoverRange)567 lsp::Hover PDLDocument::buildHoverForPattern(const ast::PatternDecl *decl,
568                                              const SMRange &hoverRange) {
569   lsp::Hover hover(lsp::Range(sourceMgr, hoverRange));
570   {
571     llvm::raw_string_ostream hoverOS(hover.contents.value);
572     hoverOS << "**Pattern**";
573     if (const ast::Name *name = decl->getName())
574       hoverOS << ": `" << name->getName() << "`";
575     hoverOS << "\n***\n";
576     if (Optional<uint16_t> benefit = decl->getBenefit())
577       hoverOS << "Benefit: " << *benefit << "\n";
578     if (decl->hasBoundedRewriteRecursion())
579       hoverOS << "HasBoundedRewriteRecursion\n";
580     hoverOS << "RootOp: `"
581             << decl->getRootRewriteStmt()->getRootOpExpr()->getType() << "`\n";
582 
583     // Format the documentation for the decl.
584     if (Optional<std::string> doc = getDocumentationFor(sourceMgr, decl))
585       hoverOS << "\n" << *doc << "\n";
586   }
587   return hover;
588 }
589 
590 lsp::Hover
buildHoverForCoreConstraint(const ast::CoreConstraintDecl * decl,const SMRange & hoverRange)591 PDLDocument::buildHoverForCoreConstraint(const ast::CoreConstraintDecl *decl,
592                                          const SMRange &hoverRange) {
593   lsp::Hover hover(lsp::Range(sourceMgr, hoverRange));
594   {
595     llvm::raw_string_ostream hoverOS(hover.contents.value);
596     hoverOS << "**Constraint**: `";
597     TypeSwitch<const ast::Decl *>(decl)
598         .Case([&](const ast::AttrConstraintDecl *) { hoverOS << "Attr"; })
599         .Case([&](const ast::OpConstraintDecl *opCst) {
600           hoverOS << "Op";
601           if (Optional<StringRef> name = opCst->getName())
602             hoverOS << "<" << name << ">";
603         })
604         .Case([&](const ast::TypeConstraintDecl *) { hoverOS << "Type"; })
605         .Case([&](const ast::TypeRangeConstraintDecl *) {
606           hoverOS << "TypeRange";
607         })
608         .Case([&](const ast::ValueConstraintDecl *) { hoverOS << "Value"; })
609         .Case([&](const ast::ValueRangeConstraintDecl *) {
610           hoverOS << "ValueRange";
611         });
612     hoverOS << "`\n";
613   }
614   return hover;
615 }
616 
617 template <typename T>
buildHoverForUserConstraintOrRewrite(StringRef typeName,const T * decl,const SMRange & hoverRange)618 lsp::Hover PDLDocument::buildHoverForUserConstraintOrRewrite(
619     StringRef typeName, const T *decl, const SMRange &hoverRange) {
620   lsp::Hover hover(lsp::Range(sourceMgr, hoverRange));
621   {
622     llvm::raw_string_ostream hoverOS(hover.contents.value);
623     hoverOS << "**" << typeName << "**: `" << decl->getName().getName()
624             << "`\n***\n";
625     ArrayRef<ast::VariableDecl *> inputs = decl->getInputs();
626     if (!inputs.empty()) {
627       hoverOS << "Parameters:\n";
628       for (const ast::VariableDecl *input : inputs)
629         hoverOS << "* " << input->getName().getName() << ": `"
630                 << input->getType() << "`\n";
631       hoverOS << "***\n";
632     }
633     ast::Type resultType = decl->getResultType();
634     if (auto resultTupleTy = resultType.dyn_cast<ast::TupleType>()) {
635       if (!resultTupleTy.empty()) {
636         hoverOS << "Results:\n";
637         for (auto it : llvm::zip(resultTupleTy.getElementNames(),
638                                  resultTupleTy.getElementTypes())) {
639           StringRef name = std::get<0>(it);
640           hoverOS << "* " << (name.empty() ? "" : (name + ": ")) << "`"
641                   << std::get<1>(it) << "`\n";
642         }
643         hoverOS << "***\n";
644       }
645     } else {
646       hoverOS << "Results:\n* `" << resultType << "`\n";
647       hoverOS << "***\n";
648     }
649 
650     // Format the documentation for the decl.
651     if (Optional<std::string> doc = getDocumentationFor(sourceMgr, decl))
652       hoverOS << "\n" << *doc << "\n";
653   }
654   return hover;
655 }
656 
657 //===----------------------------------------------------------------------===//
658 // PDLDocument: Document Symbols
659 //===----------------------------------------------------------------------===//
660 
findDocumentSymbols(std::vector<lsp::DocumentSymbol> & symbols)661 void PDLDocument::findDocumentSymbols(
662     std::vector<lsp::DocumentSymbol> &symbols) {
663   if (failed(astModule))
664     return;
665 
666   for (const ast::Decl *decl : (*astModule)->getChildren()) {
667     if (!isMainFileLoc(sourceMgr, decl->getLoc()))
668       continue;
669 
670     if (const auto *patternDecl = dyn_cast<ast::PatternDecl>(decl)) {
671       const ast::Name *name = patternDecl->getName();
672 
673       SMRange nameLoc = name ? name->getLoc() : patternDecl->getLoc();
674       SMRange bodyLoc(nameLoc.Start, patternDecl->getBody()->getLoc().End);
675 
676       symbols.emplace_back(
677           name ? name->getName() : "<pattern>", lsp::SymbolKind::Class,
678           lsp::Range(sourceMgr, bodyLoc), lsp::Range(sourceMgr, nameLoc));
679     } else if (const auto *cDecl = dyn_cast<ast::UserConstraintDecl>(decl)) {
680       // TODO: Add source information for the code block body.
681       SMRange nameLoc = cDecl->getName().getLoc();
682       SMRange bodyLoc = nameLoc;
683 
684       symbols.emplace_back(
685           cDecl->getName().getName(), lsp::SymbolKind::Function,
686           lsp::Range(sourceMgr, bodyLoc), lsp::Range(sourceMgr, nameLoc));
687     } else if (const auto *cDecl = dyn_cast<ast::UserRewriteDecl>(decl)) {
688       // TODO: Add source information for the code block body.
689       SMRange nameLoc = cDecl->getName().getLoc();
690       SMRange bodyLoc = nameLoc;
691 
692       symbols.emplace_back(
693           cDecl->getName().getName(), lsp::SymbolKind::Function,
694           lsp::Range(sourceMgr, bodyLoc), lsp::Range(sourceMgr, nameLoc));
695     }
696   }
697 }
698 
699 //===----------------------------------------------------------------------===//
700 // PDLDocument: Code Completion
701 //===----------------------------------------------------------------------===//
702 
703 namespace {
704 class LSPCodeCompleteContext : public CodeCompleteContext {
705 public:
LSPCodeCompleteContext(SMLoc completeLoc,llvm::SourceMgr & sourceMgr,lsp::CompletionList & completionList,ods::Context & odsContext,ArrayRef<std::string> includeDirs)706   LSPCodeCompleteContext(SMLoc completeLoc, llvm::SourceMgr &sourceMgr,
707                          lsp::CompletionList &completionList,
708                          ods::Context &odsContext,
709                          ArrayRef<std::string> includeDirs)
710       : CodeCompleteContext(completeLoc), sourceMgr(sourceMgr),
711         completionList(completionList), odsContext(odsContext),
712         includeDirs(includeDirs) {}
713 
codeCompleteTupleMemberAccess(ast::TupleType tupleType)714   void codeCompleteTupleMemberAccess(ast::TupleType tupleType) final {
715     ArrayRef<ast::Type> elementTypes = tupleType.getElementTypes();
716     ArrayRef<StringRef> elementNames = tupleType.getElementNames();
717     for (unsigned i = 0, e = tupleType.size(); i < e; ++i) {
718       // Push back a completion item that uses the result index.
719       lsp::CompletionItem item;
720       item.label = llvm::formatv("{0} (field #{0})", i).str();
721       item.insertText = Twine(i).str();
722       item.filterText = item.sortText = item.insertText;
723       item.kind = lsp::CompletionItemKind::Field;
724       item.detail = llvm::formatv("{0}: {1}", i, elementTypes[i]);
725       item.insertTextFormat = lsp::InsertTextFormat::PlainText;
726       completionList.items.emplace_back(item);
727 
728       // If the element has a name, push back a completion item with that name.
729       if (!elementNames[i].empty()) {
730         item.label =
731             llvm::formatv("{1} (field #{0})", i, elementNames[i]).str();
732         item.filterText = item.label;
733         item.insertText = elementNames[i].str();
734         completionList.items.emplace_back(item);
735       }
736     }
737   }
738 
codeCompleteOperationMemberAccess(ast::OperationType opType)739   void codeCompleteOperationMemberAccess(ast::OperationType opType) final {
740     const ods::Operation *odsOp = opType.getODSOperation();
741     if (!odsOp)
742       return;
743 
744     ArrayRef<ods::OperandOrResult> results = odsOp->getResults();
745     for (const auto &it : llvm::enumerate(results)) {
746       const ods::OperandOrResult &result = it.value();
747       const ods::TypeConstraint &constraint = result.getConstraint();
748 
749       // Push back a completion item that uses the result index.
750       lsp::CompletionItem item;
751       item.label = llvm::formatv("{0} (field #{0})", it.index()).str();
752       item.insertText = Twine(it.index()).str();
753       item.filterText = item.sortText = item.insertText;
754       item.kind = lsp::CompletionItemKind::Field;
755       switch (result.getVariableLengthKind()) {
756       case ods::VariableLengthKind::Single:
757         item.detail = llvm::formatv("{0}: Value", it.index()).str();
758         break;
759       case ods::VariableLengthKind::Optional:
760         item.detail = llvm::formatv("{0}: Value?", it.index()).str();
761         break;
762       case ods::VariableLengthKind::Variadic:
763         item.detail = llvm::formatv("{0}: ValueRange", it.index()).str();
764         break;
765       }
766       item.documentation = lsp::MarkupContent{
767           lsp::MarkupKind::Markdown,
768           llvm::formatv("{0}\n\n```c++\n{1}\n```\n", constraint.getSummary(),
769                         constraint.getCppClass())
770               .str()};
771       item.insertTextFormat = lsp::InsertTextFormat::PlainText;
772       completionList.items.emplace_back(item);
773 
774       // If the result has a name, push back a completion item with the result
775       // name.
776       if (!result.getName().empty()) {
777         item.label =
778             llvm::formatv("{1} (field #{0})", it.index(), result.getName())
779                 .str();
780         item.filterText = item.label;
781         item.insertText = result.getName().str();
782         completionList.items.emplace_back(item);
783       }
784     }
785   }
786 
codeCompleteOperationAttributeName(StringRef opName)787   void codeCompleteOperationAttributeName(StringRef opName) final {
788     const ods::Operation *odsOp = odsContext.lookupOperation(opName);
789     if (!odsOp)
790       return;
791 
792     for (const ods::Attribute &attr : odsOp->getAttributes()) {
793       const ods::AttributeConstraint &constraint = attr.getConstraint();
794 
795       lsp::CompletionItem item;
796       item.label = attr.getName().str();
797       item.kind = lsp::CompletionItemKind::Field;
798       item.detail = attr.isOptional() ? "optional" : "";
799       item.documentation = lsp::MarkupContent{
800           lsp::MarkupKind::Markdown,
801           llvm::formatv("{0}\n\n```c++\n{1}\n```\n", constraint.getSummary(),
802                         constraint.getCppClass())
803               .str()};
804       item.insertTextFormat = lsp::InsertTextFormat::PlainText;
805       completionList.items.emplace_back(item);
806     }
807   }
808 
codeCompleteConstraintName(ast::Type currentType,bool allowNonCoreConstraints,bool allowInlineTypeConstraints,const ast::DeclScope * scope)809   void codeCompleteConstraintName(ast::Type currentType,
810                                   bool allowNonCoreConstraints,
811                                   bool allowInlineTypeConstraints,
812                                   const ast::DeclScope *scope) final {
813     auto addCoreConstraint = [&](StringRef constraint, StringRef mlirType,
814                                  StringRef snippetText = "") {
815       lsp::CompletionItem item;
816       item.label = constraint.str();
817       item.kind = lsp::CompletionItemKind::Class;
818       item.detail = (constraint + " constraint").str();
819       item.documentation = lsp::MarkupContent{
820           lsp::MarkupKind::Markdown,
821           ("A single entity core constraint of type `" + mlirType + "`").str()};
822       item.sortText = "0";
823       item.insertText = snippetText.str();
824       item.insertTextFormat = snippetText.empty()
825                                   ? lsp::InsertTextFormat::PlainText
826                                   : lsp::InsertTextFormat::Snippet;
827       completionList.items.emplace_back(item);
828     };
829 
830     // Insert completions for the core constraints. Some core constraints have
831     // additional characteristics, so we may add then even if a type has been
832     // inferred.
833     if (!currentType) {
834       addCoreConstraint("Attr", "mlir::Attribute");
835       addCoreConstraint("Op", "mlir::Operation *");
836       addCoreConstraint("Value", "mlir::Value");
837       addCoreConstraint("ValueRange", "mlir::ValueRange");
838       addCoreConstraint("Type", "mlir::Type");
839       addCoreConstraint("TypeRange", "mlir::TypeRange");
840     }
841     if (allowInlineTypeConstraints) {
842       /// Attr<Type>.
843       if (!currentType || currentType.isa<ast::AttributeType>())
844         addCoreConstraint("Attr<type>", "mlir::Attribute", "Attr<$1>");
845       /// Value<Type>.
846       if (!currentType || currentType.isa<ast::ValueType>())
847         addCoreConstraint("Value<type>", "mlir::Value", "Value<$1>");
848       /// ValueRange<TypeRange>.
849       if (!currentType || currentType.isa<ast::ValueRangeType>())
850         addCoreConstraint("ValueRange<type>", "mlir::ValueRange",
851                           "ValueRange<$1>");
852     }
853 
854     // If a scope was provided, check it for potential constraints.
855     while (scope) {
856       for (const ast::Decl *decl : scope->getDecls()) {
857         if (const auto *cst = dyn_cast<ast::UserConstraintDecl>(decl)) {
858           if (!allowNonCoreConstraints)
859             continue;
860 
861           lsp::CompletionItem item;
862           item.label = cst->getName().getName().str();
863           item.kind = lsp::CompletionItemKind::Interface;
864           item.sortText = "2_" + item.label;
865 
866           // Skip constraints that are not single-arg. We currently only
867           // complete variable constraints.
868           if (cst->getInputs().size() != 1)
869             continue;
870 
871           // Ensure the input type matched the given type.
872           ast::Type constraintType = cst->getInputs()[0]->getType();
873           if (currentType && !currentType.refineWith(constraintType))
874             continue;
875 
876           // Format the constraint signature.
877           {
878             llvm::raw_string_ostream strOS(item.detail);
879             strOS << "(";
880             llvm::interleaveComma(
881                 cst->getInputs(), strOS, [&](const ast::VariableDecl *var) {
882                   strOS << var->getName().getName() << ": " << var->getType();
883                 });
884             strOS << ") -> " << cst->getResultType();
885           }
886 
887           // Format the documentation for the constraint.
888           if (Optional<std::string> doc = getDocumentationFor(sourceMgr, cst)) {
889             item.documentation =
890                 lsp::MarkupContent{lsp::MarkupKind::Markdown, std::move(*doc)};
891           }
892 
893           completionList.items.emplace_back(item);
894         }
895       }
896 
897       scope = scope->getParentScope();
898     }
899   }
900 
codeCompleteDialectName()901   void codeCompleteDialectName() final {
902     // Code complete known dialects.
903     for (const ods::Dialect &dialect : odsContext.getDialects()) {
904       lsp::CompletionItem item;
905       item.label = dialect.getName().str();
906       item.kind = lsp::CompletionItemKind::Class;
907       item.insertTextFormat = lsp::InsertTextFormat::PlainText;
908       completionList.items.emplace_back(item);
909     }
910   }
911 
codeCompleteOperationName(StringRef dialectName)912   void codeCompleteOperationName(StringRef dialectName) final {
913     const ods::Dialect *dialect = odsContext.lookupDialect(dialectName);
914     if (!dialect)
915       return;
916 
917     for (const auto &it : dialect->getOperations()) {
918       const ods::Operation &op = *it.second;
919 
920       lsp::CompletionItem item;
921       item.label = op.getName().drop_front(dialectName.size() + 1).str();
922       item.kind = lsp::CompletionItemKind::Field;
923       item.insertTextFormat = lsp::InsertTextFormat::PlainText;
924       completionList.items.emplace_back(item);
925     }
926   }
927 
codeCompletePatternMetadata()928   void codeCompletePatternMetadata() final {
929     auto addSimpleConstraint = [&](StringRef constraint, StringRef desc,
930                                    StringRef snippetText = "") {
931       lsp::CompletionItem item;
932       item.label = constraint.str();
933       item.kind = lsp::CompletionItemKind::Class;
934       item.detail = "pattern metadata";
935       item.documentation =
936           lsp::MarkupContent{lsp::MarkupKind::Markdown, desc.str()};
937       item.insertText = snippetText.str();
938       item.insertTextFormat = snippetText.empty()
939                                   ? lsp::InsertTextFormat::PlainText
940                                   : lsp::InsertTextFormat::Snippet;
941       completionList.items.emplace_back(item);
942     };
943 
944     addSimpleConstraint("benefit", "The `benefit` of matching the pattern.",
945                         "benefit($1)");
946     addSimpleConstraint("recursion",
947                         "The pattern properly handles recursive application.");
948   }
949 
codeCompleteIncludeFilename(StringRef curPath)950   void codeCompleteIncludeFilename(StringRef curPath) final {
951     // Normalize the path to allow for interacting with the file system
952     // utilities.
953     SmallString<128> nativeRelDir(llvm::sys::path::convert_to_slash(curPath));
954     llvm::sys::path::native(nativeRelDir);
955 
956     // Set of already included completion paths.
957     StringSet<> seenResults;
958 
959     // Functor used to add a single include completion item.
960     auto addIncludeCompletion = [&](StringRef path, bool isDirectory) {
961       lsp::CompletionItem item;
962       item.label = path.str();
963       item.kind = isDirectory ? lsp::CompletionItemKind::Folder
964                               : lsp::CompletionItemKind::File;
965       if (seenResults.insert(item.label).second)
966         completionList.items.emplace_back(item);
967     };
968 
969     // Process the include directories for this file, adding any potential
970     // nested include files or directories.
971     for (StringRef includeDir : includeDirs) {
972       llvm::SmallString<128> dir = includeDir;
973       if (!nativeRelDir.empty())
974         llvm::sys::path::append(dir, nativeRelDir);
975 
976       std::error_code errorCode;
977       for (auto it = llvm::sys::fs::directory_iterator(dir, errorCode),
978                 e = llvm::sys::fs::directory_iterator();
979            !errorCode && it != e; it.increment(errorCode)) {
980         StringRef filename = llvm::sys::path::filename(it->path());
981 
982         // To know whether a symlink should be treated as file or a directory,
983         // we have to stat it. This should be cheap enough as there shouldn't be
984         // many symlinks.
985         llvm::sys::fs::file_type fileType = it->type();
986         if (fileType == llvm::sys::fs::file_type::symlink_file) {
987           if (auto fileStatus = it->status())
988             fileType = fileStatus->type();
989         }
990 
991         switch (fileType) {
992         case llvm::sys::fs::file_type::directory_file:
993           addIncludeCompletion(filename, /*isDirectory=*/true);
994           break;
995         case llvm::sys::fs::file_type::regular_file: {
996           // Only consider concrete files that can actually be included by PDLL.
997           if (filename.endswith(".pdll") || filename.endswith(".td"))
998             addIncludeCompletion(filename, /*isDirectory=*/false);
999           break;
1000         }
1001         default:
1002           break;
1003         }
1004       }
1005     }
1006 
1007     // Sort the completion results to make sure the output is deterministic in
1008     // the face of different iteration schemes for different platforms.
1009     llvm::sort(completionList.items, [](const lsp::CompletionItem &lhs,
1010                                         const lsp::CompletionItem &rhs) {
1011       return lhs.label < rhs.label;
1012     });
1013   }
1014 
1015 private:
1016   llvm::SourceMgr &sourceMgr;
1017   lsp::CompletionList &completionList;
1018   ods::Context &odsContext;
1019   ArrayRef<std::string> includeDirs;
1020 };
1021 } // namespace
1022 
1023 lsp::CompletionList
getCodeCompletion(const lsp::URIForFile & uri,const lsp::Position & completePos)1024 PDLDocument::getCodeCompletion(const lsp::URIForFile &uri,
1025                                const lsp::Position &completePos) {
1026   SMLoc posLoc = completePos.getAsSMLoc(sourceMgr);
1027   if (!posLoc.isValid())
1028     return lsp::CompletionList();
1029 
1030   // To perform code completion, we run another parse of the module with the
1031   // code completion context provided.
1032   ods::Context tmpODSContext;
1033   lsp::CompletionList completionList;
1034   LSPCodeCompleteContext lspCompleteContext(posLoc, sourceMgr, completionList,
1035                                             tmpODSContext,
1036                                             sourceMgr.getIncludeDirs());
1037 
1038   ast::Context tmpContext(tmpODSContext);
1039   (void)parsePDLLAST(tmpContext, sourceMgr, /*enableDocumentation=*/true,
1040                      &lspCompleteContext);
1041 
1042   return completionList;
1043 }
1044 
1045 //===----------------------------------------------------------------------===//
1046 // PDLDocument: Signature Help
1047 //===----------------------------------------------------------------------===//
1048 
1049 namespace {
1050 class LSPSignatureHelpContext : public CodeCompleteContext {
1051 public:
LSPSignatureHelpContext(SMLoc completeLoc,llvm::SourceMgr & sourceMgr,lsp::SignatureHelp & signatureHelp,ods::Context & odsContext)1052   LSPSignatureHelpContext(SMLoc completeLoc, llvm::SourceMgr &sourceMgr,
1053                           lsp::SignatureHelp &signatureHelp,
1054                           ods::Context &odsContext)
1055       : CodeCompleteContext(completeLoc), sourceMgr(sourceMgr),
1056         signatureHelp(signatureHelp), odsContext(odsContext) {}
1057 
codeCompleteCallSignature(const ast::CallableDecl * callable,unsigned currentNumArgs)1058   void codeCompleteCallSignature(const ast::CallableDecl *callable,
1059                                  unsigned currentNumArgs) final {
1060     signatureHelp.activeParameter = currentNumArgs;
1061 
1062     lsp::SignatureInformation signatureInfo;
1063     {
1064       llvm::raw_string_ostream strOS(signatureInfo.label);
1065       strOS << callable->getName()->getName() << "(";
1066       auto formatParamFn = [&](const ast::VariableDecl *var) {
1067         unsigned paramStart = strOS.str().size();
1068         strOS << var->getName().getName() << ": " << var->getType();
1069         unsigned paramEnd = strOS.str().size();
1070         signatureInfo.parameters.emplace_back(lsp::ParameterInformation{
1071             StringRef(strOS.str()).slice(paramStart, paramEnd).str(),
1072             std::make_pair(paramStart, paramEnd), /*paramDoc*/ std::string()});
1073       };
1074       llvm::interleaveComma(callable->getInputs(), strOS, formatParamFn);
1075       strOS << ") -> " << callable->getResultType();
1076     }
1077 
1078     // Format the documentation for the callable.
1079     if (Optional<std::string> doc = getDocumentationFor(sourceMgr, callable))
1080       signatureInfo.documentation = std::move(*doc);
1081 
1082     signatureHelp.signatures.emplace_back(std::move(signatureInfo));
1083   }
1084 
1085   void
codeCompleteOperationOperandsSignature(Optional<StringRef> opName,unsigned currentNumOperands)1086   codeCompleteOperationOperandsSignature(Optional<StringRef> opName,
1087                                          unsigned currentNumOperands) final {
1088     const ods::Operation *odsOp =
1089         opName ? odsContext.lookupOperation(*opName) : nullptr;
1090     codeCompleteOperationOperandOrResultSignature(
1091         opName, odsOp, odsOp ? odsOp->getOperands() : llvm::None,
1092         currentNumOperands, "operand", "Value");
1093   }
1094 
codeCompleteOperationResultsSignature(Optional<StringRef> opName,unsigned currentNumResults)1095   void codeCompleteOperationResultsSignature(Optional<StringRef> opName,
1096                                              unsigned currentNumResults) final {
1097     const ods::Operation *odsOp =
1098         opName ? odsContext.lookupOperation(*opName) : nullptr;
1099     codeCompleteOperationOperandOrResultSignature(
1100         opName, odsOp, odsOp ? odsOp->getResults() : llvm::None,
1101         currentNumResults, "result", "Type");
1102   }
1103 
codeCompleteOperationOperandOrResultSignature(Optional<StringRef> opName,const ods::Operation * odsOp,ArrayRef<ods::OperandOrResult> values,unsigned currentValue,StringRef label,StringRef dataType)1104   void codeCompleteOperationOperandOrResultSignature(
1105       Optional<StringRef> opName, const ods::Operation *odsOp,
1106       ArrayRef<ods::OperandOrResult> values, unsigned currentValue,
1107       StringRef label, StringRef dataType) {
1108     signatureHelp.activeParameter = currentValue;
1109 
1110     // If we have ODS information for the operation, add in the ODS signature
1111     // for the operation. We also verify that the current number of values is
1112     // not more than what is defined in ODS, as this will result in an error
1113     // anyways.
1114     if (odsOp && currentValue < values.size()) {
1115       lsp::SignatureInformation signatureInfo;
1116 
1117       // Build the signature label.
1118       {
1119         llvm::raw_string_ostream strOS(signatureInfo.label);
1120         strOS << "(";
1121         auto formatFn = [&](const ods::OperandOrResult &value) {
1122           unsigned paramStart = strOS.str().size();
1123 
1124           strOS << value.getName() << ": ";
1125 
1126           StringRef constraintDoc = value.getConstraint().getSummary();
1127           std::string paramDoc;
1128           switch (value.getVariableLengthKind()) {
1129           case ods::VariableLengthKind::Single:
1130             strOS << dataType;
1131             paramDoc = constraintDoc.str();
1132             break;
1133           case ods::VariableLengthKind::Optional:
1134             strOS << dataType << "?";
1135             paramDoc = ("optional: " + constraintDoc).str();
1136             break;
1137           case ods::VariableLengthKind::Variadic:
1138             strOS << dataType << "Range";
1139             paramDoc = ("variadic: " + constraintDoc).str();
1140             break;
1141           }
1142 
1143           unsigned paramEnd = strOS.str().size();
1144           signatureInfo.parameters.emplace_back(lsp::ParameterInformation{
1145               StringRef(strOS.str()).slice(paramStart, paramEnd).str(),
1146               std::make_pair(paramStart, paramEnd), paramDoc});
1147         };
1148         llvm::interleaveComma(values, strOS, formatFn);
1149         strOS << ")";
1150       }
1151       signatureInfo.documentation =
1152           llvm::formatv("`op<{0}>` ODS {1} specification", *opName, label)
1153               .str();
1154       signatureHelp.signatures.emplace_back(std::move(signatureInfo));
1155     }
1156 
1157     // If there aren't any arguments yet, we also add the generic signature.
1158     if (currentValue == 0 && (!odsOp || !values.empty())) {
1159       lsp::SignatureInformation signatureInfo;
1160       signatureInfo.label =
1161           llvm::formatv("(<{0}s>: {1}Range)", label, dataType).str();
1162       signatureInfo.documentation =
1163           ("Generic operation " + label + " specification").str();
1164       signatureInfo.parameters.emplace_back(lsp::ParameterInformation{
1165           StringRef(signatureInfo.label).drop_front().drop_back().str(),
1166           std::pair<unsigned, unsigned>(1, signatureInfo.label.size() - 1),
1167           ("All of the " + label + "s of the operation.").str()});
1168       signatureHelp.signatures.emplace_back(std::move(signatureInfo));
1169     }
1170   }
1171 
1172 private:
1173   llvm::SourceMgr &sourceMgr;
1174   lsp::SignatureHelp &signatureHelp;
1175   ods::Context &odsContext;
1176 };
1177 } // namespace
1178 
/// Compute signature help at `helpPos`. The document is re-parsed with an LSP
/// signature-help context attached, and the parser's signature callbacks fill
/// the returned result. An invalid position yields an empty result.
lsp::SignatureHelp PDLDocument::getSignatureHelp(const lsp::URIForFile &uri,
                                                 const lsp::Position &helpPos) {
  SMLoc helpLoc = helpPos.getAsSMLoc(sourceMgr);
  if (!helpLoc.isValid())
    return lsp::SignatureHelp();

  // To compute signature help, we run another parse of the module with the
  // signature help context provided. Throwaway ODS/AST contexts keep the
  // document's primary parse state untouched.
  ods::Context scratchODSContext;
  lsp::SignatureHelp help;
  LSPSignatureHelpContext helpCtx(helpLoc, sourceMgr, help, scratchODSContext);

  ast::Context scratchASTContext(scratchODSContext);
  // Only the callbacks matter; the parse result itself is discarded.
  (void)parsePDLLAST(scratchASTContext, sourceMgr,
                     /*enableDocumentation=*/true, &helpCtx);

  return help;
}
1198 
1199 //===----------------------------------------------------------------------===//
1200 // PDLDocument: Inlay Hints
1201 //===----------------------------------------------------------------------===//
1202 
1203 /// Returns true if the given name should be added as a hint for `expr`.
shouldAddHintFor(const ast::Expr * expr,StringRef name)1204 static bool shouldAddHintFor(const ast::Expr *expr, StringRef name) {
1205   if (name.empty())
1206     return false;
1207 
1208   // If the argument is a reference of the same name, don't add it as a hint.
1209   if (auto *ref = dyn_cast<ast::DeclRefExpr>(expr)) {
1210     const ast::Name *declName = ref->getDecl()->getName();
1211     if (declName && declName->getName() == name)
1212       return false;
1213   }
1214 
1215   return true;
1216 }
1217 
getInlayHints(const lsp::URIForFile & uri,const lsp::Range & range,std::vector<lsp::InlayHint> & inlayHints)1218 void PDLDocument::getInlayHints(const lsp::URIForFile &uri,
1219                                 const lsp::Range &range,
1220                                 std::vector<lsp::InlayHint> &inlayHints) {
1221   if (failed(astModule))
1222     return;
1223   SMRange rangeLoc = range.getAsSMRange(sourceMgr);
1224   if (!rangeLoc.isValid())
1225     return;
1226   (*astModule)->walk([&](const ast::Node *node) {
1227     SMRange loc = node->getLoc();
1228 
1229     // Check that the location of this node is within the input range.
1230     if (!contains(rangeLoc, loc.Start) && !contains(rangeLoc, loc.End))
1231       return;
1232 
1233     // Handle hints for various types of nodes.
1234     llvm::TypeSwitch<const ast::Node *>(node)
1235         .Case<ast::VariableDecl, ast::CallExpr, ast::OperationExpr>(
1236             [&](const auto *node) {
1237               this->getInlayHintsFor(node, uri, inlayHints);
1238             });
1239   });
1240 }
1241 
getInlayHintsFor(const ast::VariableDecl * decl,const lsp::URIForFile & uri,std::vector<lsp::InlayHint> & inlayHints)1242 void PDLDocument::getInlayHintsFor(const ast::VariableDecl *decl,
1243                                    const lsp::URIForFile &uri,
1244                                    std::vector<lsp::InlayHint> &inlayHints) {
1245   // Check to see if the variable has a constraint list, if it does we don't
1246   // provide initializer hints.
1247   if (!decl->getConstraints().empty())
1248     return;
1249 
1250   // Check to see if the variable has an initializer.
1251   if (const ast::Expr *expr = decl->getInitExpr()) {
1252     // Don't add hints for operation expression initialized variables given that
1253     // the type of the variable is easily inferred by the expression operation
1254     // name.
1255     if (isa<ast::OperationExpr>(expr))
1256       return;
1257   }
1258 
1259   lsp::InlayHint hint(lsp::InlayHintKind::Type,
1260                       lsp::Position(sourceMgr, decl->getLoc().End));
1261   {
1262     llvm::raw_string_ostream labelOS(hint.label);
1263     labelOS << ": " << decl->getType();
1264   }
1265 
1266   inlayHints.emplace_back(std::move(hint));
1267 }
1268 
getInlayHintsFor(const ast::CallExpr * expr,const lsp::URIForFile & uri,std::vector<lsp::InlayHint> & inlayHints)1269 void PDLDocument::getInlayHintsFor(const ast::CallExpr *expr,
1270                                    const lsp::URIForFile &uri,
1271                                    std::vector<lsp::InlayHint> &inlayHints) {
1272   // Try to extract the callable of this call.
1273   const auto *callableRef = dyn_cast<ast::DeclRefExpr>(expr->getCallableExpr());
1274   const auto *callable =
1275       callableRef ? dyn_cast<ast::CallableDecl>(callableRef->getDecl())
1276                   : nullptr;
1277   if (!callable)
1278     return;
1279 
1280   // Add hints for the arguments to the call.
1281   for (const auto &it : llvm::zip(expr->getArguments(), callable->getInputs()))
1282     addParameterHintFor(inlayHints, std::get<0>(it),
1283                         std::get<1>(it)->getName().getName());
1284 }
1285 
getInlayHintsFor(const ast::OperationExpr * expr,const lsp::URIForFile & uri,std::vector<lsp::InlayHint> & inlayHints)1286 void PDLDocument::getInlayHintsFor(const ast::OperationExpr *expr,
1287                                    const lsp::URIForFile &uri,
1288                                    std::vector<lsp::InlayHint> &inlayHints) {
1289   // Check for ODS information.
1290   ast::OperationType opType = expr->getType().dyn_cast<ast::OperationType>();
1291   const auto *odsOp = opType ? opType.getODSOperation() : nullptr;
1292 
1293   auto addOpHint = [&](const ast::Expr *valueExpr, StringRef label) {
1294     // If the value expression used the same location as the operation, don't
1295     // add a hint. This expression was materialized during parsing.
1296     if (expr->getLoc().Start == valueExpr->getLoc().Start)
1297       return;
1298     addParameterHintFor(inlayHints, valueExpr, label);
1299   };
1300 
1301   // Functor used to process hints for the operands and results of the
1302   // operation. They effectively have the same format, and thus can be processed
1303   // using the same logic.
1304   auto addOperandOrResultHints = [&](ArrayRef<ast::Expr *> values,
1305                                      ArrayRef<ods::OperandOrResult> odsValues,
1306                                      StringRef allValuesName) {
1307     if (values.empty())
1308       return;
1309 
1310     // The values should either map to a single range, or be equivalent to the
1311     // ODS values.
1312     if (values.size() != odsValues.size()) {
1313       // Handle the case of a single element that covers the full range.
1314       if (values.size() == 1)
1315         return addOpHint(values.front(), allValuesName);
1316       return;
1317     }
1318 
1319     for (const auto &it : llvm::zip(values, odsValues))
1320       addOpHint(std::get<0>(it), std::get<1>(it).getName());
1321   };
1322 
1323   // Add hints for the operands and results of the operation.
1324   addOperandOrResultHints(expr->getOperands(),
1325                           odsOp ? odsOp->getOperands()
1326                                 : ArrayRef<ods::OperandOrResult>(),
1327                           "operands");
1328   addOperandOrResultHints(expr->getResultTypes(),
1329                           odsOp ? odsOp->getResults()
1330                                 : ArrayRef<ods::OperandOrResult>(),
1331                           "results");
1332 }
1333 
addParameterHintFor(std::vector<lsp::InlayHint> & inlayHints,const ast::Expr * expr,StringRef label)1334 void PDLDocument::addParameterHintFor(std::vector<lsp::InlayHint> &inlayHints,
1335                                       const ast::Expr *expr, StringRef label) {
1336   if (!shouldAddHintFor(expr, label))
1337     return;
1338 
1339   lsp::InlayHint hint(lsp::InlayHintKind::Parameter,
1340                       lsp::Position(sourceMgr, expr->getLoc().Start));
1341   hint.label = (label + ":").str();
1342   hint.paddingRight = true;
1343   inlayHints.emplace_back(std::move(hint));
1344 }
1345 
1346 //===----------------------------------------------------------------------===//
1347 // PDLL ViewOutput
1348 //===----------------------------------------------------------------------===//
1349 
getPDLLViewOutput(raw_ostream & os,lsp::PDLLViewOutputKind kind)1350 void PDLDocument::getPDLLViewOutput(raw_ostream &os,
1351                                     lsp::PDLLViewOutputKind kind) {
1352   if (failed(astModule))
1353     return;
1354   if (kind == lsp::PDLLViewOutputKind::AST) {
1355     (*astModule)->print(os);
1356     return;
1357   }
1358 
1359   // Generate the MLIR for the ast module. We also capture diagnostics here to
1360   // show to the user, which may be useful if PDLL isn't capturing constraints
1361   // expected by PDL.
1362   MLIRContext mlirContext;
1363   SourceMgrDiagnosticHandler diagHandler(sourceMgr, &mlirContext, os);
1364   OwningOpRef<ModuleOp> pdlModule =
1365       codegenPDLLToMLIR(&mlirContext, astContext, sourceMgr, **astModule);
1366   if (!pdlModule)
1367     return;
1368   if (kind == lsp::PDLLViewOutputKind::MLIR) {
1369     pdlModule->print(os, OpPrintingFlags().enableDebugInfo());
1370     return;
1371   }
1372 
1373   // Otherwise, generate the output for C++.
1374   assert(kind == lsp::PDLLViewOutputKind::CPP &&
1375          "unexpected PDLLViewOutputKind");
1376   codegenPDLLToCPP(**astModule, *pdlModule, os);
1377 }
1378 
1379 //===----------------------------------------------------------------------===//
1380 // PDLTextFileChunk
1381 //===----------------------------------------------------------------------===//
1382 
1383 namespace {
1384 /// This class represents a single chunk of an PDL text file.
1385 struct PDLTextFileChunk {
PDLTextFileChunk__anon9e2fdad81d11::PDLTextFileChunk1386   PDLTextFileChunk(uint64_t lineOffset, const lsp::URIForFile &uri,
1387                    StringRef contents,
1388                    const std::vector<std::string> &extraDirs,
1389                    std::vector<lsp::Diagnostic> &diagnostics)
1390       : lineOffset(lineOffset),
1391         document(uri, contents, extraDirs, diagnostics) {}
1392 
1393   /// Adjust the line number of the given range to anchor at the beginning of
1394   /// the file, instead of the beginning of this chunk.
adjustLocForChunkOffset__anon9e2fdad81d11::PDLTextFileChunk1395   void adjustLocForChunkOffset(lsp::Range &range) {
1396     adjustLocForChunkOffset(range.start);
1397     adjustLocForChunkOffset(range.end);
1398   }
1399   /// Adjust the line number of the given position to anchor at the beginning of
1400   /// the file, instead of the beginning of this chunk.
adjustLocForChunkOffset__anon9e2fdad81d11::PDLTextFileChunk1401   void adjustLocForChunkOffset(lsp::Position &pos) { pos.line += lineOffset; }
1402 
1403   /// The line offset of this chunk from the beginning of the file.
1404   uint64_t lineOffset;
1405   /// The document referred to by this chunk.
1406   PDLDocument document;
1407 };
1408 } // namespace
1409 
1410 //===----------------------------------------------------------------------===//
1411 // PDLTextFile
1412 //===----------------------------------------------------------------------===//
1413 
1414 namespace {
1415 /// This class represents a text file containing one or more PDL documents.
1416 class PDLTextFile {
1417 public:
1418   PDLTextFile(const lsp::URIForFile &uri, StringRef fileContents,
1419               int64_t version, const std::vector<std::string> &extraDirs,
1420               std::vector<lsp::Diagnostic> &diagnostics);
1421 
1422   /// Return the current version of this text file.
getVersion() const1423   int64_t getVersion() const { return version; }
1424 
1425   /// Update the file to the new version using the provided set of content
1426   /// changes. Returns failure if the update was unsuccessful.
1427   LogicalResult update(const lsp::URIForFile &uri, int64_t newVersion,
1428                        ArrayRef<lsp::TextDocumentContentChangeEvent> changes,
1429                        std::vector<lsp::Diagnostic> &diagnostics);
1430 
1431   //===--------------------------------------------------------------------===//
1432   // LSP Queries
1433   //===--------------------------------------------------------------------===//
1434 
1435   void getLocationsOf(const lsp::URIForFile &uri, lsp::Position defPos,
1436                       std::vector<lsp::Location> &locations);
1437   void findReferencesOf(const lsp::URIForFile &uri, lsp::Position pos,
1438                         std::vector<lsp::Location> &references);
1439   void getDocumentLinks(const lsp::URIForFile &uri,
1440                         std::vector<lsp::DocumentLink> &links);
1441   Optional<lsp::Hover> findHover(const lsp::URIForFile &uri,
1442                                  lsp::Position hoverPos);
1443   void findDocumentSymbols(std::vector<lsp::DocumentSymbol> &symbols);
1444   lsp::CompletionList getCodeCompletion(const lsp::URIForFile &uri,
1445                                         lsp::Position completePos);
1446   lsp::SignatureHelp getSignatureHelp(const lsp::URIForFile &uri,
1447                                       lsp::Position helpPos);
1448   void getInlayHints(const lsp::URIForFile &uri, lsp::Range range,
1449                      std::vector<lsp::InlayHint> &inlayHints);
1450   lsp::PDLLViewOutputResult getPDLLViewOutput(lsp::PDLLViewOutputKind kind);
1451 
1452 private:
1453   using ChunkIterator = llvm::pointee_iterator<
1454       std::vector<std::unique_ptr<PDLTextFileChunk>>::iterator>;
1455 
1456   /// Initialize the text file from the given file contents.
1457   void initialize(const lsp::URIForFile &uri, int64_t newVersion,
1458                   std::vector<lsp::Diagnostic> &diagnostics);
1459 
1460   /// Find the PDL document that contains the given position, and update the
1461   /// position to be anchored at the start of the found chunk instead of the
1462   /// beginning of the file.
1463   ChunkIterator getChunkItFor(lsp::Position &pos);
getChunkFor(lsp::Position & pos)1464   PDLTextFileChunk &getChunkFor(lsp::Position &pos) {
1465     return *getChunkItFor(pos);
1466   }
1467 
1468   /// The full string contents of the file.
1469   std::string contents;
1470 
1471   /// The version of this file.
1472   int64_t version = 0;
1473 
1474   /// The number of lines in the file.
1475   int64_t totalNumLines = 0;
1476 
1477   /// The chunks of this file. The order of these chunks is the order in which
1478   /// they appear in the text file.
1479   std::vector<std::unique_ptr<PDLTextFileChunk>> chunks;
1480 
1481   /// The extra set of include directories for this file.
1482   std::vector<std::string> extraIncludeDirs;
1483 };
1484 } // namespace
1485 
PDLTextFile(const lsp::URIForFile & uri,StringRef fileContents,int64_t version,const std::vector<std::string> & extraDirs,std::vector<lsp::Diagnostic> & diagnostics)1486 PDLTextFile::PDLTextFile(const lsp::URIForFile &uri, StringRef fileContents,
1487                          int64_t version,
1488                          const std::vector<std::string> &extraDirs,
1489                          std::vector<lsp::Diagnostic> &diagnostics)
1490     : contents(fileContents.str()), extraIncludeDirs(extraDirs) {
1491   initialize(uri, version, diagnostics);
1492 }
1493 
1494 LogicalResult
update(const lsp::URIForFile & uri,int64_t newVersion,ArrayRef<lsp::TextDocumentContentChangeEvent> changes,std::vector<lsp::Diagnostic> & diagnostics)1495 PDLTextFile::update(const lsp::URIForFile &uri, int64_t newVersion,
1496                     ArrayRef<lsp::TextDocumentContentChangeEvent> changes,
1497                     std::vector<lsp::Diagnostic> &diagnostics) {
1498   if (failed(lsp::TextDocumentContentChangeEvent::applyTo(changes, contents))) {
1499     lsp::Logger::error("Failed to update contents of {0}", uri.file());
1500     return failure();
1501   }
1502 
1503   // If the file contents were properly changed, reinitialize the text file.
1504   initialize(uri, newVersion, diagnostics);
1505   return success();
1506 }
1507 
getLocationsOf(const lsp::URIForFile & uri,lsp::Position defPos,std::vector<lsp::Location> & locations)1508 void PDLTextFile::getLocationsOf(const lsp::URIForFile &uri,
1509                                  lsp::Position defPos,
1510                                  std::vector<lsp::Location> &locations) {
1511   PDLTextFileChunk &chunk = getChunkFor(defPos);
1512   chunk.document.getLocationsOf(uri, defPos, locations);
1513 
1514   // Adjust any locations within this file for the offset of this chunk.
1515   if (chunk.lineOffset == 0)
1516     return;
1517   for (lsp::Location &loc : locations)
1518     if (loc.uri == uri)
1519       chunk.adjustLocForChunkOffset(loc.range);
1520 }
1521 
findReferencesOf(const lsp::URIForFile & uri,lsp::Position pos,std::vector<lsp::Location> & references)1522 void PDLTextFile::findReferencesOf(const lsp::URIForFile &uri,
1523                                    lsp::Position pos,
1524                                    std::vector<lsp::Location> &references) {
1525   PDLTextFileChunk &chunk = getChunkFor(pos);
1526   chunk.document.findReferencesOf(uri, pos, references);
1527 
1528   // Adjust any locations within this file for the offset of this chunk.
1529   if (chunk.lineOffset == 0)
1530     return;
1531   for (lsp::Location &loc : references)
1532     if (loc.uri == uri)
1533       chunk.adjustLocForChunkOffset(loc.range);
1534 }
1535 
getDocumentLinks(const lsp::URIForFile & uri,std::vector<lsp::DocumentLink> & links)1536 void PDLTextFile::getDocumentLinks(const lsp::URIForFile &uri,
1537                                    std::vector<lsp::DocumentLink> &links) {
1538   chunks.front()->document.getDocumentLinks(uri, links);
1539   for (const auto &it : llvm::drop_begin(chunks)) {
1540     size_t currentNumLinks = links.size();
1541     it->document.getDocumentLinks(uri, links);
1542 
1543     // Adjust any links within this file to account for the offset of this
1544     // chunk.
1545     for (auto &link : llvm::drop_begin(links, currentNumLinks))
1546       it->adjustLocForChunkOffset(link.range);
1547   }
1548 }
1549 
findHover(const lsp::URIForFile & uri,lsp::Position hoverPos)1550 Optional<lsp::Hover> PDLTextFile::findHover(const lsp::URIForFile &uri,
1551                                             lsp::Position hoverPos) {
1552   PDLTextFileChunk &chunk = getChunkFor(hoverPos);
1553   Optional<lsp::Hover> hoverInfo = chunk.document.findHover(uri, hoverPos);
1554 
1555   // Adjust any locations within this file for the offset of this chunk.
1556   if (chunk.lineOffset != 0 && hoverInfo && hoverInfo->range)
1557     chunk.adjustLocForChunkOffset(*hoverInfo->range);
1558   return hoverInfo;
1559 }
1560 
findDocumentSymbols(std::vector<lsp::DocumentSymbol> & symbols)1561 void PDLTextFile::findDocumentSymbols(
1562     std::vector<lsp::DocumentSymbol> &symbols) {
1563   if (chunks.size() == 1)
1564     return chunks.front()->document.findDocumentSymbols(symbols);
1565 
1566   // If there are multiple chunks in this file, we create top-level symbols for
1567   // each chunk.
1568   for (unsigned i = 0, e = chunks.size(); i < e; ++i) {
1569     PDLTextFileChunk &chunk = *chunks[i];
1570     lsp::Position startPos(chunk.lineOffset);
1571     lsp::Position endPos((i == e - 1) ? totalNumLines - 1
1572                                       : chunks[i + 1]->lineOffset);
1573     lsp::DocumentSymbol symbol("<file-split-" + Twine(i) + ">",
1574                                lsp::SymbolKind::Namespace,
1575                                /*range=*/lsp::Range(startPos, endPos),
1576                                /*selectionRange=*/lsp::Range(startPos));
1577     chunk.document.findDocumentSymbols(symbol.children);
1578 
1579     // Fixup the locations of document symbols within this chunk.
1580     if (i != 0) {
1581       SmallVector<lsp::DocumentSymbol *> symbolsToFix;
1582       for (lsp::DocumentSymbol &childSymbol : symbol.children)
1583         symbolsToFix.push_back(&childSymbol);
1584 
1585       while (!symbolsToFix.empty()) {
1586         lsp::DocumentSymbol *symbol = symbolsToFix.pop_back_val();
1587         chunk.adjustLocForChunkOffset(symbol->range);
1588         chunk.adjustLocForChunkOffset(symbol->selectionRange);
1589 
1590         for (lsp::DocumentSymbol &childSymbol : symbol->children)
1591           symbolsToFix.push_back(&childSymbol);
1592       }
1593     }
1594 
1595     // Push the symbol for this chunk.
1596     symbols.emplace_back(std::move(symbol));
1597   }
1598 }
1599 
/// Produce completions at the whole-file position `completePos`. Every text
/// edit attached to a completion item is rebased to whole-file lines.
lsp::CompletionList PDLTextFile::getCodeCompletion(const lsp::URIForFile &uri,
                                                   lsp::Position completePos) {
  // After this call `completePos` is relative to `chunk`.
  PDLTextFileChunk &chunk = getChunkFor(completePos);
  lsp::CompletionList completions =
      chunk.document.getCodeCompletion(uri, completePos);

  for (lsp::CompletionItem &candidate : completions.items) {
    for (lsp::TextEdit &extraEdit : candidate.additionalTextEdits)
      chunk.adjustLocForChunkOffset(extraEdit.range);
    if (candidate.textEdit)
      chunk.adjustLocForChunkOffset(candidate.textEdit->range);
  }
  return completions;
}
1615 
/// Provide signature help at the whole-file position `helpPos`. The chunk's
/// result is returned without any position remapping.
lsp::SignatureHelp PDLTextFile::getSignatureHelp(const lsp::URIForFile &uri,
                                                 lsp::Position helpPos) {
  // After this call `helpPos` is relative to `chunk`.
  PDLTextFileChunk &chunk = getChunkFor(helpPos);
  return chunk.document.getSignatureHelp(uri, helpPos);
}
1620 
getInlayHints(const lsp::URIForFile & uri,lsp::Range range,std::vector<lsp::InlayHint> & inlayHints)1621 void PDLTextFile::getInlayHints(const lsp::URIForFile &uri, lsp::Range range,
1622                                 std::vector<lsp::InlayHint> &inlayHints) {
1623   auto startIt = getChunkItFor(range.start);
1624   auto endIt = getChunkItFor(range.end);
1625 
1626   // Functor used to get the chunks for a given file, and fixup any locations
1627   auto getHintsForChunk = [&](ChunkIterator chunkIt, lsp::Range range) {
1628     size_t currentNumHints = inlayHints.size();
1629     chunkIt->document.getInlayHints(uri, range, inlayHints);
1630 
1631     // If this isn't the first chunk, update any positions to account for line
1632     // number differences.
1633     if (&*chunkIt != &*chunks.front()) {
1634       for (auto &hint : llvm::drop_begin(inlayHints, currentNumHints))
1635         chunkIt->adjustLocForChunkOffset(hint.position);
1636     }
1637   };
1638   // Returns the number of lines held by a given chunk.
1639   auto getNumLines = [](ChunkIterator chunkIt) {
1640     return (chunkIt + 1)->lineOffset - chunkIt->lineOffset;
1641   };
1642 
1643   // Check if the range is fully within a single chunk.
1644   if (startIt == endIt)
1645     return getHintsForChunk(startIt, range);
1646 
1647   // Otherwise, the range is split between multiple chunks. The first chunk
1648   // has the correct range start, but covers the total document.
1649   getHintsForChunk(startIt, lsp::Range(range.start, getNumLines(startIt)));
1650 
1651   // Every chunk in between uses the full document.
1652   for (++startIt; startIt != endIt; ++startIt)
1653     getHintsForChunk(startIt, lsp::Range(0, getNumLines(startIt)));
1654 
1655   // The range for the last chunk starts at the beginning of the document, up
1656   // through the end of the input range.
1657   getHintsForChunk(startIt, lsp::Range(0, range.end));
1658 }
1659 
1660 lsp::PDLLViewOutputResult
getPDLLViewOutput(lsp::PDLLViewOutputKind kind)1661 PDLTextFile::getPDLLViewOutput(lsp::PDLLViewOutputKind kind) {
1662   lsp::PDLLViewOutputResult result;
1663   {
1664     llvm::raw_string_ostream outputOS(result.output);
1665     llvm::interleave(
1666         llvm::make_pointee_range(chunks),
1667         [&](PDLTextFileChunk &chunk) {
1668           chunk.document.getPDLLViewOutput(outputOS, kind);
1669         },
1670         [&] { outputOS << "\n// -----\n\n"; });
1671   }
1672   return result;
1673 }
1674 
initialize(const lsp::URIForFile & uri,int64_t newVersion,std::vector<lsp::Diagnostic> & diagnostics)1675 void PDLTextFile::initialize(const lsp::URIForFile &uri, int64_t newVersion,
1676                              std::vector<lsp::Diagnostic> &diagnostics) {
1677   version = newVersion;
1678   chunks.clear();
1679 
1680   // Split the file into separate PDL documents.
1681   // TODO: Find a way to share the split file marker with other tools. We don't
1682   // want to use `splitAndProcessBuffer` here, but we do want to make sure this
1683   // marker doesn't go out of sync.
1684   SmallVector<StringRef, 8> subContents;
1685   StringRef(contents).split(subContents, "// -----");
1686   chunks.emplace_back(std::make_unique<PDLTextFileChunk>(
1687       /*lineOffset=*/0, uri, subContents.front(), extraIncludeDirs,
1688       diagnostics));
1689 
1690   uint64_t lineOffset = subContents.front().count('\n');
1691   for (StringRef docContents : llvm::drop_begin(subContents)) {
1692     unsigned currentNumDiags = diagnostics.size();
1693     auto chunk = std::make_unique<PDLTextFileChunk>(
1694         lineOffset, uri, docContents, extraIncludeDirs, diagnostics);
1695     lineOffset += docContents.count('\n');
1696 
1697     // Adjust locations used in diagnostics to account for the offset from the
1698     // beginning of the file.
1699     for (lsp::Diagnostic &diag :
1700          llvm::drop_begin(diagnostics, currentNumDiags)) {
1701       chunk->adjustLocForChunkOffset(diag.range);
1702 
1703       if (!diag.relatedInformation)
1704         continue;
1705       for (auto &it : *diag.relatedInformation)
1706         if (it.location.uri == uri)
1707           chunk->adjustLocForChunkOffset(it.location.range);
1708     }
1709     chunks.emplace_back(std::move(chunk));
1710   }
1711   totalNumLines = lineOffset;
1712 }
1713 
getChunkItFor(lsp::Position & pos)1714 PDLTextFile::ChunkIterator PDLTextFile::getChunkItFor(lsp::Position &pos) {
1715   if (chunks.size() == 1)
1716     return chunks.begin();
1717 
1718   // Search for the first chunk with a greater line offset, the previous chunk
1719   // is the one that contains `pos`.
1720   auto it = llvm::upper_bound(
1721       chunks, pos, [](const lsp::Position &pos, const auto &chunk) {
1722         return static_cast<uint64_t>(pos.line) < chunk->lineOffset;
1723       });
1724   ChunkIterator chunkIt(it == chunks.end() ? (chunks.end() - 1) : --it);
1725   pos.line -= chunkIt->lineOffset;
1726   return chunkIt;
1727 }
1728 
1729 //===----------------------------------------------------------------------===//
1730 // PDLLServer::Impl
1731 //===----------------------------------------------------------------------===//
1732 
1733 struct lsp::PDLLServer::Impl {
Impllsp::PDLLServer::Impl1734   explicit Impl(const Options &options)
1735       : options(options), compilationDatabase(options.compilationDatabases) {}
1736 
1737   /// PDLL LSP options.
1738   const Options &options;
1739 
1740   /// The compilation database containing additional information for files
1741   /// passed to the server.
1742   lsp::CompilationDatabase compilationDatabase;
1743 
1744   /// The files held by the server, mapped by their URI file name.
1745   llvm::StringMap<std::unique_ptr<PDLTextFile>> files;
1746 };
1747 
1748 //===----------------------------------------------------------------------===//
1749 // PDLLServer
1750 //===----------------------------------------------------------------------===//
1751 
PDLLServer(const Options & options)1752 lsp::PDLLServer::PDLLServer(const Options &options)
1753     : impl(std::make_unique<Impl>(options)) {}
1754 lsp::PDLLServer::~PDLLServer() = default;
1755 
addDocument(const URIForFile & uri,StringRef contents,int64_t version,std::vector<Diagnostic> & diagnostics)1756 void lsp::PDLLServer::addDocument(const URIForFile &uri, StringRef contents,
1757                                   int64_t version,
1758                                   std::vector<Diagnostic> &diagnostics) {
1759   // Build the set of additional include directories.
1760   std::vector<std::string> additionalIncludeDirs = impl->options.extraDirs;
1761   const auto &fileInfo = impl->compilationDatabase.getFileInfo(uri.file());
1762   llvm::append_range(additionalIncludeDirs, fileInfo.includeDirs);
1763 
1764   impl->files[uri.file()] = std::make_unique<PDLTextFile>(
1765       uri, contents, version, additionalIncludeDirs, diagnostics);
1766 }
1767 
updateDocument(const URIForFile & uri,ArrayRef<TextDocumentContentChangeEvent> changes,int64_t version,std::vector<Diagnostic> & diagnostics)1768 void lsp::PDLLServer::updateDocument(
1769     const URIForFile &uri, ArrayRef<TextDocumentContentChangeEvent> changes,
1770     int64_t version, std::vector<Diagnostic> &diagnostics) {
1771   // Check that we actually have a document for this uri.
1772   auto it = impl->files.find(uri.file());
1773   if (it == impl->files.end())
1774     return;
1775 
1776   // Try to update the document. If we fail, erase the file from the server. A
1777   // failed updated generally means we've fallen out of sync somewhere.
1778   if (failed(it->second->update(uri, version, changes, diagnostics)))
1779     impl->files.erase(it);
1780 }
1781 
removeDocument(const URIForFile & uri)1782 Optional<int64_t> lsp::PDLLServer::removeDocument(const URIForFile &uri) {
1783   auto it = impl->files.find(uri.file());
1784   if (it == impl->files.end())
1785     return llvm::None;
1786 
1787   int64_t version = it->second->getVersion();
1788   impl->files.erase(it);
1789   return version;
1790 }
1791 
getLocationsOf(const URIForFile & uri,const Position & defPos,std::vector<Location> & locations)1792 void lsp::PDLLServer::getLocationsOf(const URIForFile &uri,
1793                                      const Position &defPos,
1794                                      std::vector<Location> &locations) {
1795   auto fileIt = impl->files.find(uri.file());
1796   if (fileIt != impl->files.end())
1797     fileIt->second->getLocationsOf(uri, defPos, locations);
1798 }
1799 
findReferencesOf(const URIForFile & uri,const Position & pos,std::vector<Location> & references)1800 void lsp::PDLLServer::findReferencesOf(const URIForFile &uri,
1801                                        const Position &pos,
1802                                        std::vector<Location> &references) {
1803   auto fileIt = impl->files.find(uri.file());
1804   if (fileIt != impl->files.end())
1805     fileIt->second->findReferencesOf(uri, pos, references);
1806 }
1807 
getDocumentLinks(const URIForFile & uri,std::vector<DocumentLink> & documentLinks)1808 void lsp::PDLLServer::getDocumentLinks(
1809     const URIForFile &uri, std::vector<DocumentLink> &documentLinks) {
1810   auto fileIt = impl->files.find(uri.file());
1811   if (fileIt != impl->files.end())
1812     return fileIt->second->getDocumentLinks(uri, documentLinks);
1813 }
1814 
findHover(const URIForFile & uri,const Position & hoverPos)1815 Optional<lsp::Hover> lsp::PDLLServer::findHover(const URIForFile &uri,
1816                                                 const Position &hoverPos) {
1817   auto fileIt = impl->files.find(uri.file());
1818   if (fileIt != impl->files.end())
1819     return fileIt->second->findHover(uri, hoverPos);
1820   return llvm::None;
1821 }
1822 
findDocumentSymbols(const URIForFile & uri,std::vector<DocumentSymbol> & symbols)1823 void lsp::PDLLServer::findDocumentSymbols(
1824     const URIForFile &uri, std::vector<DocumentSymbol> &symbols) {
1825   auto fileIt = impl->files.find(uri.file());
1826   if (fileIt != impl->files.end())
1827     fileIt->second->findDocumentSymbols(symbols);
1828 }
1829 
1830 lsp::CompletionList
getCodeCompletion(const URIForFile & uri,const Position & completePos)1831 lsp::PDLLServer::getCodeCompletion(const URIForFile &uri,
1832                                    const Position &completePos) {
1833   auto fileIt = impl->files.find(uri.file());
1834   if (fileIt != impl->files.end())
1835     return fileIt->second->getCodeCompletion(uri, completePos);
1836   return CompletionList();
1837 }
1838 
getSignatureHelp(const URIForFile & uri,const Position & helpPos)1839 lsp::SignatureHelp lsp::PDLLServer::getSignatureHelp(const URIForFile &uri,
1840                                                      const Position &helpPos) {
1841   auto fileIt = impl->files.find(uri.file());
1842   if (fileIt != impl->files.end())
1843     return fileIt->second->getSignatureHelp(uri, helpPos);
1844   return SignatureHelp();
1845 }
1846 
getInlayHints(const URIForFile & uri,const Range & range,std::vector<InlayHint> & inlayHints)1847 void lsp::PDLLServer::getInlayHints(const URIForFile &uri, const Range &range,
1848                                     std::vector<InlayHint> &inlayHints) {
1849   auto fileIt = impl->files.find(uri.file());
1850   if (fileIt == impl->files.end())
1851     return;
1852   fileIt->second->getInlayHints(uri, range, inlayHints);
1853 
1854   // Drop any duplicated hints that may have cropped up.
1855   llvm::sort(inlayHints);
1856   inlayHints.erase(std::unique(inlayHints.begin(), inlayHints.end()),
1857                    inlayHints.end());
1858 }
1859 
1860 Optional<lsp::PDLLViewOutputResult>
getPDLLViewOutput(const URIForFile & uri,PDLLViewOutputKind kind)1861 lsp::PDLLServer::getPDLLViewOutput(const URIForFile &uri,
1862                                    PDLLViewOutputKind kind) {
1863   auto fileIt = impl->files.find(uri.file());
1864   if (fileIt != impl->files.end())
1865     return fileIt->second->getPDLLViewOutput(kind);
1866   return llvm::None;
1867 }
1868