1 //===- PDLLServer.cpp - PDLL Language Server ------------------------------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 
9 #include "PDLLServer.h"
10 
11 #include "../lsp-server-support/Logging.h"
12 #include "../lsp-server-support/Protocol.h"
13 #include "CompilationDatabase.h"
14 #include "mlir/Tools/PDLL/AST/Context.h"
15 #include "mlir/Tools/PDLL/AST/Nodes.h"
16 #include "mlir/Tools/PDLL/AST/Types.h"
17 #include "mlir/Tools/PDLL/ODS/Constraint.h"
18 #include "mlir/Tools/PDLL/ODS/Context.h"
19 #include "mlir/Tools/PDLL/ODS/Dialect.h"
20 #include "mlir/Tools/PDLL/ODS/Operation.h"
21 #include "mlir/Tools/PDLL/Parser/CodeComplete.h"
22 #include "mlir/Tools/PDLL/Parser/Parser.h"
23 #include "llvm/ADT/IntervalMap.h"
24 #include "llvm/ADT/StringMap.h"
25 #include "llvm/ADT/StringSet.h"
26 #include "llvm/ADT/TypeSwitch.h"
27 #include "llvm/Support/FileSystem.h"
28 #include "llvm/Support/Path.h"
29 
30 using namespace mlir;
31 using namespace mlir::pdll;
32 
33 /// Returns a language server uri for the given source location. `mainFileURI`
34 /// corresponds to the uri for the main file of the source manager.
35 static lsp::URIForFile getURIFromLoc(llvm::SourceMgr &mgr, SMRange loc,
36                                      const lsp::URIForFile &mainFileURI) {
37   int bufferId = mgr.FindBufferContainingLoc(loc.Start);
38   if (bufferId == 0 || bufferId == static_cast<int>(mgr.getMainFileID()))
39     return mainFileURI;
40   llvm::Expected<lsp::URIForFile> fileForLoc = lsp::URIForFile::fromFile(
41       mgr.getBufferInfo(bufferId).Buffer->getBufferIdentifier());
42   if (fileForLoc)
43     return *fileForLoc;
44   lsp::Logger::error("Failed to create URI for include file: {0}",
45                      llvm::toString(fileForLoc.takeError()));
46   return mainFileURI;
47 }
48 
49 /// Returns true if the given location is in the main file of the source
50 /// manager.
51 static bool isMainFileLoc(llvm::SourceMgr &mgr, SMRange loc) {
52   return mgr.FindBufferContainingLoc(loc.Start) == mgr.getMainFileID();
53 }
54 
55 /// Returns a language server location from the given source range.
56 static lsp::Location getLocationFromLoc(llvm::SourceMgr &mgr, SMRange range,
57                                         const lsp::URIForFile &uri) {
58   return lsp::Location(getURIFromLoc(mgr, range, uri), lsp::Range(mgr, range));
59 }
60 
61 /// Convert the given MLIR diagnostic to the LSP form.
62 static Optional<lsp::Diagnostic>
63 getLspDiagnoticFromDiag(llvm::SourceMgr &sourceMgr, const ast::Diagnostic &diag,
64                         const lsp::URIForFile &uri) {
65   lsp::Diagnostic lspDiag;
66   lspDiag.source = "pdll";
67 
68   // FIXME: Right now all of the diagnostics are treated as parser issues, but
69   // some are parser and some are verifier.
70   lspDiag.category = "Parse Error";
71 
72   // Try to grab a file location for this diagnostic.
73   lsp::Location loc = getLocationFromLoc(sourceMgr, diag.getLocation(), uri);
74   lspDiag.range = loc.range;
75 
76   // Skip diagnostics that weren't emitted within the main file.
77   if (loc.uri != uri)
78     return llvm::None;
79 
80   // Convert the severity for the diagnostic.
81   switch (diag.getSeverity()) {
82   case ast::Diagnostic::Severity::DK_Note:
83     llvm_unreachable("expected notes to be handled separately");
84   case ast::Diagnostic::Severity::DK_Warning:
85     lspDiag.severity = lsp::DiagnosticSeverity::Warning;
86     break;
87   case ast::Diagnostic::Severity::DK_Error:
88     lspDiag.severity = lsp::DiagnosticSeverity::Error;
89     break;
90   case ast::Diagnostic::Severity::DK_Remark:
91     lspDiag.severity = lsp::DiagnosticSeverity::Information;
92     break;
93   }
94   lspDiag.message = diag.getMessage().str();
95 
96   // Attach any notes to the main diagnostic as related information.
97   std::vector<lsp::DiagnosticRelatedInformation> relatedDiags;
98   for (const ast::Diagnostic &note : diag.getNotes()) {
99     relatedDiags.emplace_back(
100         getLocationFromLoc(sourceMgr, note.getLocation(), uri),
101         note.getMessage().str());
102   }
103   if (!relatedDiags.empty())
104     lspDiag.relatedInformation = std::move(relatedDiags);
105 
106   return lspDiag;
107 }
108 
109 //===----------------------------------------------------------------------===//
110 // PDLLInclude
111 //===----------------------------------------------------------------------===//
112 
113 namespace {
114 /// This class represents a single include within a root file.
115 struct PDLLInclude {
116   PDLLInclude(const lsp::URIForFile &uri, const lsp::Range &range)
117       : uri(uri), range(range) {}
118 
119   /// The URI of the file that is included.
120   lsp::URIForFile uri;
121 
122   /// The range of the include directive.
123   lsp::Range range;
124 };
125 } // namespace
126 
127 //===----------------------------------------------------------------------===//
128 // PDLIndex
129 //===----------------------------------------------------------------------===//
130 
131 namespace {
132 struct PDLIndexSymbol {
133   explicit PDLIndexSymbol(const ast::Decl *definition)
134       : definition(definition) {}
135   explicit PDLIndexSymbol(const ods::Operation *definition)
136       : definition(definition) {}
137 
138   /// Return the location of the definition of this symbol.
139   SMRange getDefLoc() const {
140     if (const ast::Decl *decl = definition.dyn_cast<const ast::Decl *>()) {
141       const ast::Name *declName = decl->getName();
142       return declName ? declName->getLoc() : decl->getLoc();
143     }
144     return definition.get<const ods::Operation *>()->getLoc();
145   }
146 
147   /// The main definition of the symbol.
148   PointerUnion<const ast::Decl *, const ods::Operation *> definition;
149   /// The set of references to the symbol.
150   std::vector<SMRange> references;
151 };
152 
153 /// This class provides an index for definitions/uses within a PDL document.
154 /// It provides efficient lookup of a definition given an input source range.
155 class PDLIndex {
156 public:
157   PDLIndex() : intervalMap(allocator) {}
158 
159   /// Initialize the index with the given ast::Module.
160   void initialize(const ast::Module &module, const ods::Context &odsContext);
161 
162   /// Lookup a symbol for the given location. Returns nullptr if no symbol could
163   /// be found. If provided, `overlappedRange` is set to the range that the
164   /// provided `loc` overlapped with.
165   const PDLIndexSymbol *lookup(SMLoc loc,
166                                SMRange *overlappedRange = nullptr) const;
167 
168 private:
169   /// The type of interval map used to store source references. SMRange is
170   /// half-open, so we also need to use a half-open interval map.
171   using MapT =
172       llvm::IntervalMap<const char *, const PDLIndexSymbol *,
173                         llvm::IntervalMapImpl::NodeSizer<
174                             const char *, const PDLIndexSymbol *>::LeafSize,
175                         llvm::IntervalMapHalfOpenInfo<const char *>>;
176 
177   /// An allocator for the interval map.
178   MapT::Allocator allocator;
179 
180   /// An interval map containing a corresponding definition mapped to a source
181   /// interval.
182   MapT intervalMap;
183 
184   /// A mapping between definitions and their corresponding symbol.
185   DenseMap<const void *, std::unique_ptr<PDLIndexSymbol>> defToSymbol;
186 };
187 } // namespace
188 
189 void PDLIndex::initialize(const ast::Module &module,
190                           const ods::Context &odsContext) {
191   auto getOrInsertDef = [&](const auto *def) -> PDLIndexSymbol * {
192     auto it = defToSymbol.try_emplace(def, nullptr);
193     if (it.second)
194       it.first->second = std::make_unique<PDLIndexSymbol>(def);
195     return &*it.first->second;
196   };
197   auto insertDeclRef = [&](PDLIndexSymbol *sym, SMRange refLoc,
198                            bool isDef = false) {
199     const char *startLoc = refLoc.Start.getPointer();
200     const char *endLoc = refLoc.End.getPointer();
201     if (!intervalMap.overlaps(startLoc, endLoc)) {
202       intervalMap.insert(startLoc, endLoc, sym);
203       if (!isDef)
204         sym->references.push_back(refLoc);
205     }
206   };
207   auto insertODSOpRef = [&](StringRef opName, SMRange refLoc) {
208     const ods::Operation *odsOp = odsContext.lookupOperation(opName);
209     if (!odsOp)
210       return;
211 
212     PDLIndexSymbol *symbol = getOrInsertDef(odsOp);
213     insertDeclRef(symbol, odsOp->getLoc(), /*isDef=*/true);
214     insertDeclRef(symbol, refLoc);
215   };
216 
217   module.walk([&](const ast::Node *node) {
218     // Handle references to PDL decls.
219     if (const auto *decl = dyn_cast<ast::OpNameDecl>(node)) {
220       if (Optional<StringRef> name = decl->getName())
221         insertODSOpRef(*name, decl->getLoc());
222     } else if (const ast::Decl *decl = dyn_cast<ast::Decl>(node)) {
223       const ast::Name *name = decl->getName();
224       if (!name)
225         return;
226       PDLIndexSymbol *declSym = getOrInsertDef(decl);
227       insertDeclRef(declSym, name->getLoc(), /*isDef=*/true);
228 
229       if (const auto *varDecl = dyn_cast<ast::VariableDecl>(decl)) {
230         // Record references to any constraints.
231         for (const auto &it : varDecl->getConstraints())
232           insertDeclRef(getOrInsertDef(it.constraint), it.referenceLoc);
233       }
234     } else if (const auto *expr = dyn_cast<ast::DeclRefExpr>(node)) {
235       insertDeclRef(getOrInsertDef(expr->getDecl()), expr->getLoc());
236     }
237   });
238 }
239 
240 const PDLIndexSymbol *PDLIndex::lookup(SMLoc loc,
241                                        SMRange *overlappedRange) const {
242   auto it = intervalMap.find(loc.getPointer());
243   if (!it.valid() || loc.getPointer() < it.start())
244     return nullptr;
245 
246   if (overlappedRange) {
247     *overlappedRange = SMRange(SMLoc::getFromPointer(it.start()),
248                                SMLoc::getFromPointer(it.stop()));
249   }
250   return it.value();
251 }
252 
253 //===----------------------------------------------------------------------===//
254 // PDLDocument
255 //===----------------------------------------------------------------------===//
256 
namespace {
/// This class represents all of the information pertaining to a specific PDL
/// document. That covers the source manager and contexts it was parsed with,
/// the parsed AST (if parsing succeeded), an index of symbol definitions and
/// uses, and the includes that appear directly in the main file.
struct PDLDocument {
  /// Parse `contents` as the document identified by `uri`. Includes are
  /// resolved against the directory containing `uri`, then against
  /// `extraDirs`. Diagnostics located within the main file are appended to
  /// `diagnostics`.
  PDLDocument(const lsp::URIForFile &uri, StringRef contents,
              const std::vector<std::string> &extraDirs,
              std::vector<lsp::Diagnostic> &diagnostics);
  PDLDocument(const PDLDocument &) = delete;
  PDLDocument &operator=(const PDLDocument &) = delete;

  //===--------------------------------------------------------------------===//
  // Definitions and References
  //===--------------------------------------------------------------------===//

  /// Append the definition location of the symbol at `defPos` (if any) to
  /// `locations`.
  void getLocationsOf(const lsp::URIForFile &uri, const lsp::Position &defPos,
                      std::vector<lsp::Location> &locations);
  /// Append the definition location and every recorded reference of the
  /// symbol at `pos` (if any) to `references`.
  void findReferencesOf(const lsp::URIForFile &uri, const lsp::Position &pos,
                        std::vector<lsp::Location> &references);

  //===--------------------------------------------------------------------===//
  // Document Links
  //===--------------------------------------------------------------------===//

  /// Append one link for each include parsed from the main file.
  void getDocumentLinks(const lsp::URIForFile &uri,
                        std::vector<lsp::DocumentLink> &links);

  //===--------------------------------------------------------------------===//
  // Hover
  //===--------------------------------------------------------------------===//

  /// Return hover information for the include directive or indexed symbol at
  /// `hoverPos`, or None if there is nothing to describe.
  Optional<lsp::Hover> findHover(const lsp::URIForFile &uri,
                                 const lsp::Position &hoverPos);
  /// Return hover information for `decl` covering `hoverRange`, or None for
  /// declaration kinds without hover support.
  Optional<lsp::Hover> findHover(const ast::Decl *decl,
                                 const SMRange &hoverRange);
  /// Builders for the hover markdown of each supported kind of entity.
  lsp::Hover buildHoverForInclude(const PDLLInclude &include);
  lsp::Hover buildHoverForOpName(const ods::Operation *op,
                                 const SMRange &hoverRange);
  lsp::Hover buildHoverForVariable(const ast::VariableDecl *varDecl,
                                   const SMRange &hoverRange);
  lsp::Hover buildHoverForPattern(const ast::PatternDecl *patternDecl,
                                  const SMRange &hoverRange);
  lsp::Hover buildHoverForCoreConstraint(const ast::CoreConstraintDecl *decl,
                                         const SMRange &hoverRange);
  /// Shared hover builder for user constraints and user rewrites. `typeName`
  /// is the label printed for the declaration kind.
  template <typename T>
  lsp::Hover buildHoverForUserConstraintOrRewrite(StringRef typeName,
                                                  const T *decl,
                                                  const SMRange &hoverRange);

  //===--------------------------------------------------------------------===//
  // Document Symbols
  //===--------------------------------------------------------------------===//

  /// Append outline symbols for the patterns, user constraints, and user
  /// rewrites declared at the top level of the main file.
  void findDocumentSymbols(std::vector<lsp::DocumentSymbol> &symbols);

  //===--------------------------------------------------------------------===//
  // Code Completion
  //===--------------------------------------------------------------------===//

  /// Compute code completion results for `completePos`.
  lsp::CompletionList getCodeCompletion(const lsp::URIForFile &uri,
                                        const lsp::Position &completePos);

  //===--------------------------------------------------------------------===//
  // Signature Help
  //===--------------------------------------------------------------------===//

  /// Compute signature help for `helpPos`.
  lsp::SignatureHelp getSignatureHelp(const lsp::URIForFile &uri,
                                      const lsp::Position &helpPos);

  //===--------------------------------------------------------------------===//
  // Fields
  //===--------------------------------------------------------------------===//

  /// The include directories for this file: the directory containing the file
  /// first, followed by any extra directories supplied at construction.
  std::vector<std::string> includeDirs;

  /// The source manager containing the contents of the input file, plus any
  /// buffers added for included files.
  llvm::SourceMgr sourceMgr;

  /// The ODS and AST contexts. `astContext` is constructed from `odsContext`,
  /// so `odsContext` must stay declared (and therefore initialized) first.
  ods::Context odsContext;
  ast::Context astContext;

  /// The parsed AST module, or failure if the file wasn't valid.
  FailureOr<ast::Module *> astModule;

  /// The index of the parsed module. It is left empty if parsing failed.
  PDLIndex index;

  /// The set of includes that appear directly within the main file.
  std::vector<PDLLInclude> parsedIncludes;
};
} // namespace
349 
350 PDLDocument::PDLDocument(const lsp::URIForFile &uri, StringRef contents,
351                          const std::vector<std::string> &extraDirs,
352                          std::vector<lsp::Diagnostic> &diagnostics)
353     : astContext(odsContext) {
354   auto memBuffer = llvm::MemoryBuffer::getMemBufferCopy(contents, uri.file());
355   if (!memBuffer) {
356     lsp::Logger::error("Failed to create memory buffer for file", uri.file());
357     return;
358   }
359 
360   // Build the set of include directories for this file.
361   llvm::SmallString<32> uriDirectory(uri.file());
362   llvm::sys::path::remove_filename(uriDirectory);
363   includeDirs.push_back(uriDirectory.str().str());
364   includeDirs.insert(includeDirs.end(), extraDirs.begin(), extraDirs.end());
365 
366   sourceMgr.setIncludeDirs(includeDirs);
367   sourceMgr.AddNewSourceBuffer(std::move(memBuffer), SMLoc());
368 
369   astContext.getDiagEngine().setHandlerFn([&](const ast::Diagnostic &diag) {
370     if (auto lspDiag = getLspDiagnoticFromDiag(sourceMgr, diag, uri))
371       diagnostics.push_back(std::move(*lspDiag));
372   });
373   astModule = parsePDLAST(astContext, sourceMgr);
374 
375   // Initialize the set of parsed includes.
376   for (unsigned i = 1, e = sourceMgr.getNumBuffers(); i < e; ++i) {
377     // Check to see if this file was included by the main file.
378     SMLoc includeLoc = sourceMgr.getBufferInfo(i + 1).IncludeLoc;
379     if (!includeLoc.isValid() || sourceMgr.FindBufferContainingLoc(
380                                      includeLoc) != sourceMgr.getMainFileID())
381       continue;
382 
383     // Try to build a URI for this file path.
384     auto *buffer = sourceMgr.getMemoryBuffer(i + 1);
385     llvm::SmallString<256> path(buffer->getBufferIdentifier());
386     llvm::sys::path::remove_dots(path, /*remove_dot_dot=*/true);
387 
388     llvm::Expected<lsp::URIForFile> includedFileURI =
389         lsp::URIForFile::fromFile(path);
390     if (!includedFileURI)
391       continue;
392 
393     // Find the end of the include token.
394     const char *includeStart = includeLoc.getPointer() - 2;
395     while (*(--includeStart) != '\"')
396       continue;
397 
398     // Push this include.
399     SMRange includeRange(SMLoc::getFromPointer(includeStart), includeLoc);
400     parsedIncludes.emplace_back(*includedFileURI,
401                                 lsp::Range(sourceMgr, includeRange));
402   }
403 
404   // If we failed to parse the module, there is nothing left to initialize.
405   if (failed(astModule))
406     return;
407 
408   // Prepare the AST index with the parsed module.
409   index.initialize(**astModule, odsContext);
410 }
411 
412 //===----------------------------------------------------------------------===//
413 // PDLDocument: Definitions and References
414 //===----------------------------------------------------------------------===//
415 
416 void PDLDocument::getLocationsOf(const lsp::URIForFile &uri,
417                                  const lsp::Position &defPos,
418                                  std::vector<lsp::Location> &locations) {
419   SMLoc posLoc = defPos.getAsSMLoc(sourceMgr);
420   const PDLIndexSymbol *symbol = index.lookup(posLoc);
421   if (!symbol)
422     return;
423 
424   locations.push_back(getLocationFromLoc(sourceMgr, symbol->getDefLoc(), uri));
425 }
426 
427 void PDLDocument::findReferencesOf(const lsp::URIForFile &uri,
428                                    const lsp::Position &pos,
429                                    std::vector<lsp::Location> &references) {
430   SMLoc posLoc = pos.getAsSMLoc(sourceMgr);
431   const PDLIndexSymbol *symbol = index.lookup(posLoc);
432   if (!symbol)
433     return;
434 
435   references.push_back(getLocationFromLoc(sourceMgr, symbol->getDefLoc(), uri));
436   for (SMRange refLoc : symbol->references)
437     references.push_back(getLocationFromLoc(sourceMgr, refLoc, uri));
438 }
439 
440 //===--------------------------------------------------------------------===//
441 // PDLDocument: Document Links
442 //===--------------------------------------------------------------------===//
443 
444 void PDLDocument::getDocumentLinks(const lsp::URIForFile &uri,
445                                    std::vector<lsp::DocumentLink> &links) {
446   for (const PDLLInclude &include : parsedIncludes)
447     links.emplace_back(include.range, include.uri);
448 }
449 
450 //===----------------------------------------------------------------------===//
451 // PDLDocument: Hover
452 //===----------------------------------------------------------------------===//
453 
454 Optional<lsp::Hover> PDLDocument::findHover(const lsp::URIForFile &uri,
455                                             const lsp::Position &hoverPos) {
456   SMLoc posLoc = hoverPos.getAsSMLoc(sourceMgr);
457 
458   // Check for a reference to an include.
459   for (const PDLLInclude &include : parsedIncludes) {
460     if (include.range.contains(hoverPos))
461       return buildHoverForInclude(include);
462   }
463 
464   // Find the symbol at the given location.
465   SMRange hoverRange;
466   const PDLIndexSymbol *symbol = index.lookup(posLoc, &hoverRange);
467   if (!symbol)
468     return llvm::None;
469 
470   // Add hover for operation names.
471   if (const auto *op = symbol->definition.dyn_cast<const ods::Operation *>())
472     return buildHoverForOpName(op, hoverRange);
473   const auto *decl = symbol->definition.get<const ast::Decl *>();
474   return findHover(decl, hoverRange);
475 }
476 
477 Optional<lsp::Hover> PDLDocument::findHover(const ast::Decl *decl,
478                                             const SMRange &hoverRange) {
479   // Add hover for variables.
480   if (const auto *varDecl = dyn_cast<ast::VariableDecl>(decl))
481     return buildHoverForVariable(varDecl, hoverRange);
482 
483   // Add hover for patterns.
484   if (const auto *patternDecl = dyn_cast<ast::PatternDecl>(decl))
485     return buildHoverForPattern(patternDecl, hoverRange);
486 
487   // Add hover for core constraints.
488   if (const auto *cst = dyn_cast<ast::CoreConstraintDecl>(decl))
489     return buildHoverForCoreConstraint(cst, hoverRange);
490 
491   // Add hover for user constraints.
492   if (const auto *cst = dyn_cast<ast::UserConstraintDecl>(decl))
493     return buildHoverForUserConstraintOrRewrite("Constraint", cst, hoverRange);
494 
495   // Add hover for user rewrites.
496   if (const auto *rewrite = dyn_cast<ast::UserRewriteDecl>(decl))
497     return buildHoverForUserConstraintOrRewrite("Rewrite", rewrite, hoverRange);
498 
499   return llvm::None;
500 }
501 
502 lsp::Hover PDLDocument::buildHoverForInclude(const PDLLInclude &include) {
503   lsp::Hover hover(include.range);
504   {
505     llvm::raw_string_ostream hoverOS(hover.contents.value);
506     hoverOS << "`" << llvm::sys::path::filename(include.uri.file())
507             << "`\n***\n"
508             << include.uri.file();
509   }
510   return hover;
511 }
512 
513 lsp::Hover PDLDocument::buildHoverForOpName(const ods::Operation *op,
514                                             const SMRange &hoverRange) {
515   lsp::Hover hover(lsp::Range(sourceMgr, hoverRange));
516   {
517     llvm::raw_string_ostream hoverOS(hover.contents.value);
518     hoverOS << "**OpName**: `" << op->getName() << "`\n***\n"
519             << op->getSummary() << "\n***\n"
520             << op->getDescription();
521   }
522   return hover;
523 }
524 
525 lsp::Hover PDLDocument::buildHoverForVariable(const ast::VariableDecl *varDecl,
526                                               const SMRange &hoverRange) {
527   lsp::Hover hover(lsp::Range(sourceMgr, hoverRange));
528   {
529     llvm::raw_string_ostream hoverOS(hover.contents.value);
530     hoverOS << "**Variable**: `" << varDecl->getName().getName() << "`\n***\n"
531             << "Type: `" << varDecl->getType() << "`\n";
532   }
533   return hover;
534 }
535 
536 lsp::Hover
537 PDLDocument::buildHoverForPattern(const ast::PatternDecl *patternDecl,
538                                   const SMRange &hoverRange) {
539   lsp::Hover hover(lsp::Range(sourceMgr, hoverRange));
540   {
541     llvm::raw_string_ostream hoverOS(hover.contents.value);
542     hoverOS << "**Pattern**";
543     if (const ast::Name *name = patternDecl->getName())
544       hoverOS << ": `" << name->getName() << "`";
545     hoverOS << "\n***\n";
546     if (Optional<uint16_t> benefit = patternDecl->getBenefit())
547       hoverOS << "Benefit: " << *benefit << "\n";
548     if (patternDecl->hasBoundedRewriteRecursion())
549       hoverOS << "HasBoundedRewriteRecursion\n";
550     hoverOS << "RootOp: `"
551             << patternDecl->getRootRewriteStmt()->getRootOpExpr()->getType()
552             << "`\n";
553   }
554   return hover;
555 }
556 
557 lsp::Hover
558 PDLDocument::buildHoverForCoreConstraint(const ast::CoreConstraintDecl *decl,
559                                          const SMRange &hoverRange) {
560   lsp::Hover hover(lsp::Range(sourceMgr, hoverRange));
561   {
562     llvm::raw_string_ostream hoverOS(hover.contents.value);
563     hoverOS << "**Constraint**: `";
564     TypeSwitch<const ast::Decl *>(decl)
565         .Case([&](const ast::AttrConstraintDecl *) { hoverOS << "Attr"; })
566         .Case([&](const ast::OpConstraintDecl *opCst) {
567           hoverOS << "Op";
568           if (Optional<StringRef> name = opCst->getName())
569             hoverOS << "<" << name << ">";
570         })
571         .Case([&](const ast::TypeConstraintDecl *) { hoverOS << "Type"; })
572         .Case([&](const ast::TypeRangeConstraintDecl *) {
573           hoverOS << "TypeRange";
574         })
575         .Case([&](const ast::ValueConstraintDecl *) { hoverOS << "Value"; })
576         .Case([&](const ast::ValueRangeConstraintDecl *) {
577           hoverOS << "ValueRange";
578         });
579     hoverOS << "`\n";
580   }
581   return hover;
582 }
583 
584 template <typename T>
585 lsp::Hover PDLDocument::buildHoverForUserConstraintOrRewrite(
586     StringRef typeName, const T *decl, const SMRange &hoverRange) {
587   lsp::Hover hover(lsp::Range(sourceMgr, hoverRange));
588   {
589     llvm::raw_string_ostream hoverOS(hover.contents.value);
590     hoverOS << "**" << typeName << "**: `" << decl->getName().getName()
591             << "`\n***\n";
592     ArrayRef<ast::VariableDecl *> inputs = decl->getInputs();
593     if (!inputs.empty()) {
594       hoverOS << "Parameters:\n";
595       for (const ast::VariableDecl *input : inputs)
596         hoverOS << "* " << input->getName().getName() << ": `"
597                 << input->getType() << "`\n";
598       hoverOS << "***\n";
599     }
600     ast::Type resultType = decl->getResultType();
601     if (auto resultTupleTy = resultType.dyn_cast<ast::TupleType>()) {
602       if (resultTupleTy.empty())
603         return hover;
604 
605       hoverOS << "Results:\n";
606       for (auto it : llvm::zip(resultTupleTy.getElementNames(),
607                                resultTupleTy.getElementTypes())) {
608         StringRef name = std::get<0>(it);
609         hoverOS << "* " << (name.empty() ? "" : (name + ": ")) << "`"
610                 << std::get<1>(it) << "`\n";
611       }
612     } else {
613       hoverOS << "Results:\n* `" << resultType << "`\n";
614     }
615     hoverOS << "***\n";
616   }
617   return hover;
618 }
619 
620 //===----------------------------------------------------------------------===//
621 // PDLDocument: Document Symbols
622 //===----------------------------------------------------------------------===//
623 
624 void PDLDocument::findDocumentSymbols(
625     std::vector<lsp::DocumentSymbol> &symbols) {
626   if (failed(astModule))
627     return;
628 
629   for (const ast::Decl *decl : (*astModule)->getChildren()) {
630     if (!isMainFileLoc(sourceMgr, decl->getLoc()))
631       continue;
632 
633     if (const auto *patternDecl = dyn_cast<ast::PatternDecl>(decl)) {
634       const ast::Name *name = patternDecl->getName();
635 
636       SMRange nameLoc = name ? name->getLoc() : patternDecl->getLoc();
637       SMRange bodyLoc(nameLoc.Start, patternDecl->getBody()->getLoc().End);
638 
639       symbols.emplace_back(
640           name ? name->getName() : "<pattern>", lsp::SymbolKind::Class,
641           lsp::Range(sourceMgr, bodyLoc), lsp::Range(sourceMgr, nameLoc));
642     } else if (const auto *cDecl = dyn_cast<ast::UserConstraintDecl>(decl)) {
643       // TODO: Add source information for the code block body.
644       SMRange nameLoc = cDecl->getName().getLoc();
645       SMRange bodyLoc = nameLoc;
646 
647       symbols.emplace_back(
648           cDecl->getName().getName(), lsp::SymbolKind::Function,
649           lsp::Range(sourceMgr, bodyLoc), lsp::Range(sourceMgr, nameLoc));
650     } else if (const auto *cDecl = dyn_cast<ast::UserRewriteDecl>(decl)) {
651       // TODO: Add source information for the code block body.
652       SMRange nameLoc = cDecl->getName().getLoc();
653       SMRange bodyLoc = nameLoc;
654 
655       symbols.emplace_back(
656           cDecl->getName().getName(), lsp::SymbolKind::Function,
657           lsp::Range(sourceMgr, bodyLoc), lsp::Range(sourceMgr, nameLoc));
658     }
659   }
660 }
661 
662 //===----------------------------------------------------------------------===//
663 // PDLDocument: Code Completion
664 //===----------------------------------------------------------------------===//
665 
666 namespace {
667 class LSPCodeCompleteContext : public CodeCompleteContext {
668 public:
669   LSPCodeCompleteContext(SMLoc completeLoc, lsp::CompletionList &completionList,
670                          ods::Context &odsContext,
671                          ArrayRef<std::string> includeDirs)
672       : CodeCompleteContext(completeLoc), completionList(completionList),
673         odsContext(odsContext), includeDirs(includeDirs) {}
674 
675   void codeCompleteTupleMemberAccess(ast::TupleType tupleType) final {
676     ArrayRef<ast::Type> elementTypes = tupleType.getElementTypes();
677     ArrayRef<StringRef> elementNames = tupleType.getElementNames();
678     for (unsigned i = 0, e = tupleType.size(); i < e; ++i) {
679       // Push back a completion item that uses the result index.
680       lsp::CompletionItem item;
681       item.label = llvm::formatv("{0} (field #{0})", i).str();
682       item.insertText = Twine(i).str();
683       item.filterText = item.sortText = item.insertText;
684       item.kind = lsp::CompletionItemKind::Field;
685       item.detail = llvm::formatv("{0}: {1}", i, elementTypes[i]);
686       item.insertTextFormat = lsp::InsertTextFormat::PlainText;
687       completionList.items.emplace_back(item);
688 
689       // If the element has a name, push back a completion item with that name.
690       if (!elementNames[i].empty()) {
691         item.label =
692             llvm::formatv("{1} (field #{0})", i, elementNames[i]).str();
693         item.filterText = item.label;
694         item.insertText = elementNames[i].str();
695         completionList.items.emplace_back(item);
696       }
697     }
698   }
699 
700   void codeCompleteOperationMemberAccess(ast::OperationType opType) final {
701     Optional<StringRef> opName = opType.getName();
702     const ods::Operation *odsOp =
703         opName ? odsContext.lookupOperation(*opName) : nullptr;
704     if (!odsOp)
705       return;
706 
707     ArrayRef<ods::OperandOrResult> results = odsOp->getResults();
708     for (const auto &it : llvm::enumerate(results)) {
709       const ods::OperandOrResult &result = it.value();
710       const ods::TypeConstraint &constraint = result.getConstraint();
711 
712       // Push back a completion item that uses the result index.
713       lsp::CompletionItem item;
714       item.label = llvm::formatv("{0} (field #{0})", it.index()).str();
715       item.insertText = Twine(it.index()).str();
716       item.filterText = item.sortText = item.insertText;
717       item.kind = lsp::CompletionItemKind::Field;
718       switch (result.getVariableLengthKind()) {
719       case ods::VariableLengthKind::Single:
720         item.detail = llvm::formatv("{0}: Value", it.index()).str();
721         break;
722       case ods::VariableLengthKind::Optional:
723         item.detail = llvm::formatv("{0}: Value?", it.index()).str();
724         break;
725       case ods::VariableLengthKind::Variadic:
726         item.detail = llvm::formatv("{0}: ValueRange", it.index()).str();
727         break;
728       }
729       item.documentation = lsp::MarkupContent{
730           lsp::MarkupKind::Markdown,
731           llvm::formatv("{0}\n\n```c++\n{1}\n```\n", constraint.getSummary(),
732                         constraint.getCppClass())
733               .str()};
734       item.insertTextFormat = lsp::InsertTextFormat::PlainText;
735       completionList.items.emplace_back(item);
736 
737       // If the result has a name, push back a completion item with the result
738       // name.
739       if (!result.getName().empty()) {
740         item.label =
741             llvm::formatv("{1} (field #{0})", it.index(), result.getName())
742                 .str();
743         item.filterText = item.label;
744         item.insertText = result.getName().str();
745         completionList.items.emplace_back(item);
746       }
747     }
748   }
749 
750   void codeCompleteOperationAttributeName(StringRef opName) final {
751     const ods::Operation *odsOp = odsContext.lookupOperation(opName);
752     if (!odsOp)
753       return;
754 
755     for (const ods::Attribute &attr : odsOp->getAttributes()) {
756       const ods::AttributeConstraint &constraint = attr.getConstraint();
757 
758       lsp::CompletionItem item;
759       item.label = attr.getName().str();
760       item.kind = lsp::CompletionItemKind::Field;
761       item.detail = attr.isOptional() ? "optional" : "";
762       item.documentation = lsp::MarkupContent{
763           lsp::MarkupKind::Markdown,
764           llvm::formatv("{0}\n\n```c++\n{1}\n```\n", constraint.getSummary(),
765                         constraint.getCppClass())
766               .str()};
767       item.insertTextFormat = lsp::InsertTextFormat::PlainText;
768       completionList.items.emplace_back(item);
769     }
770   }
771 
772   void codeCompleteConstraintName(ast::Type currentType,
773                                   bool allowNonCoreConstraints,
774                                   bool allowInlineTypeConstraints,
775                                   const ast::DeclScope *scope) final {
776     auto addCoreConstraint = [&](StringRef constraint, StringRef mlirType,
777                                  StringRef snippetText = "") {
778       lsp::CompletionItem item;
779       item.label = constraint.str();
780       item.kind = lsp::CompletionItemKind::Class;
781       item.detail = (constraint + " constraint").str();
782       item.documentation = lsp::MarkupContent{
783           lsp::MarkupKind::Markdown,
784           ("A single entity core constraint of type `" + mlirType + "`").str()};
785       item.sortText = "0";
786       item.insertText = snippetText.str();
787       item.insertTextFormat = snippetText.empty()
788                                   ? lsp::InsertTextFormat::PlainText
789                                   : lsp::InsertTextFormat::Snippet;
790       completionList.items.emplace_back(item);
791     };
792 
793     // Insert completions for the core constraints. Some core constraints have
794     // additional characteristics, so we may add then even if a type has been
795     // inferred.
796     if (!currentType) {
797       addCoreConstraint("Attr", "mlir::Attribute");
798       addCoreConstraint("Op", "mlir::Operation *");
799       addCoreConstraint("Value", "mlir::Value");
800       addCoreConstraint("ValueRange", "mlir::ValueRange");
801       addCoreConstraint("Type", "mlir::Type");
802       addCoreConstraint("TypeRange", "mlir::TypeRange");
803     }
804     if (allowInlineTypeConstraints) {
805       /// Attr<Type>.
806       if (!currentType || currentType.isa<ast::AttributeType>())
807         addCoreConstraint("Attr<type>", "mlir::Attribute", "Attr<$1>");
808       /// Value<Type>.
809       if (!currentType || currentType.isa<ast::ValueType>())
810         addCoreConstraint("Value<type>", "mlir::Value", "Value<$1>");
811       /// ValueRange<TypeRange>.
812       if (!currentType || currentType.isa<ast::ValueRangeType>())
813         addCoreConstraint("ValueRange<type>", "mlir::ValueRange",
814                           "ValueRange<$1>");
815     }
816 
817     // If a scope was provided, check it for potential constraints.
818     while (scope) {
819       for (const ast::Decl *decl : scope->getDecls()) {
820         if (const auto *cst = dyn_cast<ast::UserConstraintDecl>(decl)) {
821           if (!allowNonCoreConstraints)
822             continue;
823 
824           lsp::CompletionItem item;
825           item.label = cst->getName().getName().str();
826           item.kind = lsp::CompletionItemKind::Interface;
827           item.sortText = "2_" + item.label;
828 
829           // Skip constraints that are not single-arg. We currently only
830           // complete variable constraints.
831           if (cst->getInputs().size() != 1)
832             continue;
833 
834           // Ensure the input type matched the given type.
835           ast::Type constraintType = cst->getInputs()[0]->getType();
836           if (currentType && !currentType.refineWith(constraintType))
837             continue;
838 
839           // Format the constraint signature.
840           {
841             llvm::raw_string_ostream strOS(item.detail);
842             strOS << "(";
843             llvm::interleaveComma(
844                 cst->getInputs(), strOS, [&](const ast::VariableDecl *var) {
845                   strOS << var->getName().getName() << ": " << var->getType();
846                 });
847             strOS << ") -> " << cst->getResultType();
848           }
849 
850           completionList.items.emplace_back(item);
851         }
852       }
853 
854       scope = scope->getParentScope();
855     }
856   }
857 
858   void codeCompleteDialectName() final {
859     // Code complete known dialects.
860     for (const ods::Dialect &dialect : odsContext.getDialects()) {
861       lsp::CompletionItem item;
862       item.label = dialect.getName().str();
863       item.kind = lsp::CompletionItemKind::Class;
864       item.insertTextFormat = lsp::InsertTextFormat::PlainText;
865       completionList.items.emplace_back(item);
866     }
867   }
868 
869   void codeCompleteOperationName(StringRef dialectName) final {
870     const ods::Dialect *dialect = odsContext.lookupDialect(dialectName);
871     if (!dialect)
872       return;
873 
874     for (const auto &it : dialect->getOperations()) {
875       const ods::Operation &op = *it.second;
876 
877       lsp::CompletionItem item;
878       item.label = op.getName().drop_front(dialectName.size() + 1).str();
879       item.kind = lsp::CompletionItemKind::Field;
880       item.insertTextFormat = lsp::InsertTextFormat::PlainText;
881       completionList.items.emplace_back(item);
882     }
883   }
884 
885   void codeCompletePatternMetadata() final {
886     auto addSimpleConstraint = [&](StringRef constraint, StringRef desc,
887                                    StringRef snippetText = "") {
888       lsp::CompletionItem item;
889       item.label = constraint.str();
890       item.kind = lsp::CompletionItemKind::Class;
891       item.detail = "pattern metadata";
892       item.documentation =
893           lsp::MarkupContent{lsp::MarkupKind::Markdown, desc.str()};
894       item.insertText = snippetText.str();
895       item.insertTextFormat = snippetText.empty()
896                                   ? lsp::InsertTextFormat::PlainText
897                                   : lsp::InsertTextFormat::Snippet;
898       completionList.items.emplace_back(item);
899     };
900 
901     addSimpleConstraint("benefit", "The `benefit` of matching the pattern.",
902                         "benefit($1)");
903     addSimpleConstraint("recursion",
904                         "The pattern properly handles recursive application.");
905   }
906 
907   void codeCompleteIncludeFilename(StringRef curPath) final {
908     // Normalize the path to allow for interacting with the file system
909     // utilities.
910     SmallString<128> nativeRelDir(llvm::sys::path::convert_to_slash(curPath));
911     llvm::sys::path::native(nativeRelDir);
912 
913     // Set of already included completion paths.
914     StringSet<> seenResults;
915 
916     // Functor used to add a single include completion item.
917     auto addIncludeCompletion = [&](StringRef path, bool isDirectory) {
918       lsp::CompletionItem item;
919       item.label = (path + (isDirectory ? "/" : "")).str();
920       item.kind = isDirectory ? lsp::CompletionItemKind::Folder
921                               : lsp::CompletionItemKind::File;
922       if (seenResults.insert(item.label).second)
923         completionList.items.emplace_back(item);
924     };
925 
926     // Process the include directories for this file, adding any potential
927     // nested include files or directories.
928     for (StringRef includeDir : includeDirs) {
929       llvm::SmallString<128> dir = includeDir;
930       if (!nativeRelDir.empty())
931         llvm::sys::path::append(dir, nativeRelDir);
932 
933       std::error_code errorCode;
934       for (auto it = llvm::sys::fs::directory_iterator(dir, errorCode),
935                 e = llvm::sys::fs::directory_iterator();
936            !errorCode && it != e; it.increment(errorCode)) {
937         StringRef filename = llvm::sys::path::filename(it->path());
938 
939         // To know whether a symlink should be treated as file or a directory,
940         // we have to stat it. This should be cheap enough as there shouldn't be
941         // many symlinks.
942         llvm::sys::fs::file_type fileType = it->type();
943         if (fileType == llvm::sys::fs::file_type::symlink_file) {
944           if (auto fileStatus = it->status())
945             fileType = fileStatus->type();
946         }
947 
948         switch (fileType) {
949         case llvm::sys::fs::file_type::directory_file:
950           addIncludeCompletion(filename, /*isDirectory=*/true);
951           break;
952         case llvm::sys::fs::file_type::regular_file: {
953           // Only consider concrete files that can actually be included by PDLL.
954           if (filename.endswith(".pdll") || filename.endswith(".td"))
955             addIncludeCompletion(filename, /*isDirectory=*/false);
956           break;
957         }
958         default:
959           break;
960         }
961       }
962     }
963 
964     // Sort the completion results to make sure the output is deterministic in
965     // the face of different iteration schemes for different platforms.
966     llvm::sort(completionList.items, [](const lsp::CompletionItem &lhs,
967                                         const lsp::CompletionItem &rhs) {
968       return lhs.label < rhs.label;
969     });
970   }
971 
972 private:
973   lsp::CompletionList &completionList;
974   ods::Context &odsContext;
975   ArrayRef<std::string> includeDirs;
976 };
977 } // namespace
978 
979 lsp::CompletionList
980 PDLDocument::getCodeCompletion(const lsp::URIForFile &uri,
981                                const lsp::Position &completePos) {
982   SMLoc posLoc = completePos.getAsSMLoc(sourceMgr);
983   if (!posLoc.isValid())
984     return lsp::CompletionList();
985 
986   // Adjust the position one further to after the completion trigger token.
987   posLoc = SMLoc::getFromPointer(posLoc.getPointer() + 1);
988 
989   // To perform code completion, we run another parse of the module with the
990   // code completion context provided.
991   ods::Context tmpODSContext;
992   lsp::CompletionList completionList;
993   LSPCodeCompleteContext lspCompleteContext(
994       posLoc, completionList, tmpODSContext, sourceMgr.getIncludeDirs());
995 
996   ast::Context tmpContext(tmpODSContext);
997   (void)parsePDLAST(tmpContext, sourceMgr, &lspCompleteContext);
998 
999   return completionList;
1000 }
1001 
1002 //===----------------------------------------------------------------------===//
1003 // PDLDocument: Signature Help
1004 //===----------------------------------------------------------------------===//
1005 
1006 namespace {
/// Code completion context that answers LSP signature help requests. Each
/// `codeComplete*Signature` override below sets the active parameter of, and
/// appends signature(s) to, the `lsp::SignatureHelp` passed to the constructor.
class LSPSignatureHelpContext : public CodeCompleteContext {
public:
  /// `signatureHelp` receives the results; `odsContext` is used to look up ODS
  /// operation definitions. Both are held by reference, so they must outlive
  /// this context.
  LSPSignatureHelpContext(SMLoc completeLoc, lsp::SignatureHelp &signatureHelp,
                          ods::Context &odsContext)
      : CodeCompleteContext(completeLoc), signatureHelp(signatureHelp),
        odsContext(odsContext) {}

  /// Add the signature of `callable`, formatted as
  /// `name(arg: Type, ...) -> ResultType`. `currentNumArgs` becomes the active
  /// parameter index.
  void codeCompleteCallSignature(const ast::CallableDecl *callable,
                                 unsigned currentNumArgs) final {
    signatureHelp.activeParameter = currentNumArgs;

    lsp::SignatureInformation signatureInfo;
    {
      llvm::raw_string_ostream strOS(signatureInfo.label);
      strOS << callable->getName()->getName() << "(";
      // Write one parameter and record its [start, end) character offsets
      // within the label. `str()` returns the label string after flushing any
      // buffered output, so its size is the current write position.
      auto formatParamFn = [&](const ast::VariableDecl *var) {
        unsigned paramStart = strOS.str().size();
        strOS << var->getName().getName() << ": " << var->getType();
        unsigned paramEnd = strOS.str().size();
        signatureInfo.parameters.emplace_back(lsp::ParameterInformation{
            StringRef(strOS.str()).slice(paramStart, paramEnd).str(),
            std::make_pair(paramStart, paramEnd), /*paramDoc*/ std::string()});
      };
      llvm::interleaveComma(callable->getInputs(), strOS, formatParamFn);
      strOS << ") -> " << callable->getResultType();
    }
    signatureHelp.signatures.emplace_back(std::move(signatureInfo));
  }

  /// Signature help for the operand list of an operation. `opName` is empty
  /// when the operation name is not known; ODS operands are used when the
  /// operation is found in the ODS context.
  void
  codeCompleteOperationOperandsSignature(Optional<StringRef> opName,
                                         unsigned currentNumOperands) final {
    const ods::Operation *odsOp =
        opName ? odsContext.lookupOperation(*opName) : nullptr;
    codeCompleteOperationOperandOrResultSignature(
        opName, odsOp, odsOp ? odsOp->getOperands() : llvm::None,
        currentNumOperands, "operand", "Value");
  }

  /// Signature help for the result type list of an operation; same handling
  /// as the operand case, but over ODS results and `Type`/`TypeRange`.
  void codeCompleteOperationResultsSignature(Optional<StringRef> opName,
                                             unsigned currentNumResults) final {
    const ods::Operation *odsOp =
        opName ? odsContext.lookupOperation(*opName) : nullptr;
    codeCompleteOperationOperandOrResultSignature(
        opName, odsOp, odsOp ? odsOp->getResults() : llvm::None,
        currentNumResults, "result", "Type");
  }

  /// Shared implementation for operand and result signature help.
  ///
  /// `values` are the ODS operands/results (empty without ODS info),
  /// `currentValue` is the zero-based index of the value being written,
  /// `label` is "operand"/"result", and `dataType` is the single-value type
  /// name ("Value"/"Type"); "?" and "Range" suffixes are appended for optional
  /// and variadic entries.
  void codeCompleteOperationOperandOrResultSignature(
      Optional<StringRef> opName, const ods::Operation *odsOp,
      ArrayRef<ods::OperandOrResult> values, unsigned currentValue,
      StringRef label, StringRef dataType) {
    signatureHelp.activeParameter = currentValue;

    // If we have ODS information for the operation, add in the ODS signature
    // for the operation. It is only added while the index being written is
    // one that ODS declares; writing past the declared values would result in
    // an error anyway.
    if (odsOp && currentValue < values.size()) {
      lsp::SignatureInformation signatureInfo;

      // Build the signature label, e.g. `(name: Value, other: ValueRange)`.
      {
        llvm::raw_string_ostream strOS(signatureInfo.label);
        strOS << "(";
        // Write one ODS value as `name: Kind` and record its label offsets
        // along with a doc string derived from its constraint summary.
        auto formatFn = [&](const ods::OperandOrResult &value) {
          unsigned paramStart = strOS.str().size();

          strOS << value.getName() << ": ";

          StringRef constraintDoc = value.getConstraint().getSummary();
          std::string paramDoc;
          switch (value.getVariableLengthKind()) {
          case ods::VariableLengthKind::Single:
            strOS << dataType;
            paramDoc = constraintDoc.str();
            break;
          case ods::VariableLengthKind::Optional:
            strOS << dataType << "?";
            paramDoc = ("optional: " + constraintDoc).str();
            break;
          case ods::VariableLengthKind::Variadic:
            strOS << dataType << "Range";
            paramDoc = ("variadic: " + constraintDoc).str();
            break;
          }

          unsigned paramEnd = strOS.str().size();
          signatureInfo.parameters.emplace_back(lsp::ParameterInformation{
              StringRef(strOS.str()).slice(paramStart, paramEnd).str(),
              std::make_pair(paramStart, paramEnd), paramDoc});
        };
        llvm::interleaveComma(values, strOS, formatFn);
        strOS << ")";
      }
      // `opName` is set here: `odsOp` is only non-null when a name was known.
      signatureInfo.documentation =
          llvm::formatv("`op<{0}>` ODS {1} specification", *opName, label)
              .str();
      signatureHelp.signatures.emplace_back(std::move(signatureInfo));
    }

    // If there aren't any arguments yet, we also add the generic signature
    // `(<labels>: <dataType>Range)`, unless ODS says the operation has no
    // values of this kind. Its single parameter spans everything between the
    // parentheses.
    if (currentValue == 0 && (!odsOp || !values.empty())) {
      lsp::SignatureInformation signatureInfo;
      signatureInfo.label =
          llvm::formatv("(<{0}s>: {1}Range)", label, dataType).str();
      signatureInfo.documentation =
          ("Generic operation " + label + " specification").str();
      signatureInfo.parameters.emplace_back(lsp::ParameterInformation{
          StringRef(signatureInfo.label).drop_front().drop_back().str(),
          std::pair<unsigned, unsigned>(1, signatureInfo.label.size() - 1),
          ("All of the " + label + "s of the operation.").str()});
      signatureHelp.signatures.emplace_back(std::move(signatureInfo));
    }
  }

private:
  /// The signature help result being populated (not owned).
  lsp::SignatureHelp &signatureHelp;
  /// ODS context used to resolve operation definitions (not owned).
  ods::Context &odsContext;
};
1127 } // namespace
1128 
1129 lsp::SignatureHelp PDLDocument::getSignatureHelp(const lsp::URIForFile &uri,
1130                                                  const lsp::Position &helpPos) {
1131   SMLoc posLoc = helpPos.getAsSMLoc(sourceMgr);
1132   if (!posLoc.isValid())
1133     return lsp::SignatureHelp();
1134 
1135   // Adjust the position one further to after the completion trigger token.
1136   posLoc = SMLoc::getFromPointer(posLoc.getPointer() + 1);
1137 
1138   // To perform code completion, we run another parse of the module with the
1139   // code completion context provided.
1140   ods::Context tmpODSContext;
1141   lsp::SignatureHelp signatureHelp;
1142   LSPSignatureHelpContext completeContext(posLoc, signatureHelp, tmpODSContext);
1143 
1144   ast::Context tmpContext(tmpODSContext);
1145   (void)parsePDLAST(tmpContext, sourceMgr, &completeContext);
1146 
1147   return signatureHelp;
1148 }
1149 
1150 //===----------------------------------------------------------------------===//
1151 // PDLTextFileChunk
1152 //===----------------------------------------------------------------------===//
1153 
1154 namespace {
1155 /// This class represents a single chunk of an PDL text file.
1156 struct PDLTextFileChunk {
1157   PDLTextFileChunk(uint64_t lineOffset, const lsp::URIForFile &uri,
1158                    StringRef contents,
1159                    const std::vector<std::string> &extraDirs,
1160                    std::vector<lsp::Diagnostic> &diagnostics)
1161       : lineOffset(lineOffset),
1162         document(uri, contents, extraDirs, diagnostics) {}
1163 
1164   /// Adjust the line number of the given range to anchor at the beginning of
1165   /// the file, instead of the beginning of this chunk.
1166   void adjustLocForChunkOffset(lsp::Range &range) {
1167     adjustLocForChunkOffset(range.start);
1168     adjustLocForChunkOffset(range.end);
1169   }
1170   /// Adjust the line number of the given position to anchor at the beginning of
1171   /// the file, instead of the beginning of this chunk.
1172   void adjustLocForChunkOffset(lsp::Position &pos) { pos.line += lineOffset; }
1173 
1174   /// The line offset of this chunk from the beginning of the file.
1175   uint64_t lineOffset;
1176   /// The document referred to by this chunk.
1177   PDLDocument document;
1178 };
1179 } // namespace
1180 
1181 //===----------------------------------------------------------------------===//
1182 // PDLTextFile
1183 //===----------------------------------------------------------------------===//
1184 
1185 namespace {
/// This class represents a text file containing one or more PDL documents.
/// The constructor splits the file on the `// -----` marker; each piece is
/// parsed as a separate PDLTextFileChunk, and each LSP query is routed to the
/// chunk that contains the requested position.
class PDLTextFile {
public:
  /// Split `fileContents` into chunks and build a document for each one.
  /// Diagnostics produced by every chunk are appended to `diagnostics`, with
  /// ranges in this file shifted to whole-file line numbers.
  PDLTextFile(const lsp::URIForFile &uri, StringRef fileContents,
              int64_t version, const std::vector<std::string> &extraDirs,
              std::vector<lsp::Diagnostic> &diagnostics);

  /// Return the current version of this text file.
  int64_t getVersion() const { return version; }

  //===--------------------------------------------------------------------===//
  // LSP Queries
  //===--------------------------------------------------------------------===//
  //
  // Positions passed to these queries are relative to the start of the whole
  // file. Returned ranges that refer to this file are shifted back from
  // chunk-relative to whole-file line numbers.

  void getLocationsOf(const lsp::URIForFile &uri, lsp::Position defPos,
                      std::vector<lsp::Location> &locations);
  void findReferencesOf(const lsp::URIForFile &uri, lsp::Position pos,
                        std::vector<lsp::Location> &references);
  void getDocumentLinks(const lsp::URIForFile &uri,
                        std::vector<lsp::DocumentLink> &links);
  Optional<lsp::Hover> findHover(const lsp::URIForFile &uri,
                                 lsp::Position hoverPos);
  void findDocumentSymbols(std::vector<lsp::DocumentSymbol> &symbols);
  lsp::CompletionList getCodeCompletion(const lsp::URIForFile &uri,
                                        lsp::Position completePos);
  lsp::SignatureHelp getSignatureHelp(const lsp::URIForFile &uri,
                                      lsp::Position helpPos);

private:
  /// Find the PDL document that contains the given position, and update the
  /// position to be anchored at the start of the found chunk instead of the
  /// beginning of the file.
  PDLTextFileChunk &getChunkFor(lsp::Position &pos);

  /// The full string contents of the file. The constructor splits this copy,
  /// not the caller's buffer, into chunks.
  std::string contents;

  /// The version of this file.
  int64_t version;

  /// The number of newline characters in the file, as counted by the
  /// constructor across all chunks.
  int64_t totalNumLines = 0;

  /// The chunks of this file. The order of these chunks is the order in which
  /// they appear in the text file, so their line offsets are non-decreasing.
  std::vector<std::unique_ptr<PDLTextFileChunk>> chunks;
};
1233 } // namespace
1234 
1235 PDLTextFile::PDLTextFile(const lsp::URIForFile &uri, StringRef fileContents,
1236                          int64_t version,
1237                          const std::vector<std::string> &extraDirs,
1238                          std::vector<lsp::Diagnostic> &diagnostics)
1239     : contents(fileContents.str()), version(version) {
1240   // Split the file into separate PDL documents.
1241   // TODO: Find a way to share the split file marker with other tools. We don't
1242   // want to use `splitAndProcessBuffer` here, but we do want to make sure this
1243   // marker doesn't go out of sync.
1244   SmallVector<StringRef, 8> subContents;
1245   StringRef(contents).split(subContents, "// -----");
1246   chunks.emplace_back(std::make_unique<PDLTextFileChunk>(
1247       /*lineOffset=*/0, uri, subContents.front(), extraDirs, diagnostics));
1248 
1249   uint64_t lineOffset = subContents.front().count('\n');
1250   for (StringRef docContents : llvm::drop_begin(subContents)) {
1251     unsigned currentNumDiags = diagnostics.size();
1252     auto chunk = std::make_unique<PDLTextFileChunk>(
1253         lineOffset, uri, docContents, extraDirs, diagnostics);
1254     lineOffset += docContents.count('\n');
1255 
1256     // Adjust locations used in diagnostics to account for the offset from the
1257     // beginning of the file.
1258     for (lsp::Diagnostic &diag :
1259          llvm::drop_begin(diagnostics, currentNumDiags)) {
1260       chunk->adjustLocForChunkOffset(diag.range);
1261 
1262       if (!diag.relatedInformation)
1263         continue;
1264       for (auto &it : *diag.relatedInformation)
1265         if (it.location.uri == uri)
1266           chunk->adjustLocForChunkOffset(it.location.range);
1267     }
1268     chunks.emplace_back(std::move(chunk));
1269   }
1270   totalNumLines = lineOffset;
1271 }
1272 
1273 void PDLTextFile::getLocationsOf(const lsp::URIForFile &uri,
1274                                  lsp::Position defPos,
1275                                  std::vector<lsp::Location> &locations) {
1276   PDLTextFileChunk &chunk = getChunkFor(defPos);
1277   chunk.document.getLocationsOf(uri, defPos, locations);
1278 
1279   // Adjust any locations within this file for the offset of this chunk.
1280   if (chunk.lineOffset == 0)
1281     return;
1282   for (lsp::Location &loc : locations)
1283     if (loc.uri == uri)
1284       chunk.adjustLocForChunkOffset(loc.range);
1285 }
1286 
1287 void PDLTextFile::findReferencesOf(const lsp::URIForFile &uri,
1288                                    lsp::Position pos,
1289                                    std::vector<lsp::Location> &references) {
1290   PDLTextFileChunk &chunk = getChunkFor(pos);
1291   chunk.document.findReferencesOf(uri, pos, references);
1292 
1293   // Adjust any locations within this file for the offset of this chunk.
1294   if (chunk.lineOffset == 0)
1295     return;
1296   for (lsp::Location &loc : references)
1297     if (loc.uri == uri)
1298       chunk.adjustLocForChunkOffset(loc.range);
1299 }
1300 
1301 void PDLTextFile::getDocumentLinks(const lsp::URIForFile &uri,
1302                                    std::vector<lsp::DocumentLink> &links) {
1303   chunks.front()->document.getDocumentLinks(uri, links);
1304   for (const auto &it : llvm::drop_begin(chunks)) {
1305     size_t currentNumLinks = links.size();
1306     it->document.getDocumentLinks(uri, links);
1307 
1308     // Adjust any links within this file to account for the offset of this
1309     // chunk.
1310     for (auto &link : llvm::drop_begin(links, currentNumLinks))
1311       it->adjustLocForChunkOffset(link.range);
1312   }
1313 }
1314 
1315 Optional<lsp::Hover> PDLTextFile::findHover(const lsp::URIForFile &uri,
1316                                             lsp::Position hoverPos) {
1317   PDLTextFileChunk &chunk = getChunkFor(hoverPos);
1318   Optional<lsp::Hover> hoverInfo = chunk.document.findHover(uri, hoverPos);
1319 
1320   // Adjust any locations within this file for the offset of this chunk.
1321   if (chunk.lineOffset != 0 && hoverInfo && hoverInfo->range)
1322     chunk.adjustLocForChunkOffset(*hoverInfo->range);
1323   return hoverInfo;
1324 }
1325 
1326 void PDLTextFile::findDocumentSymbols(
1327     std::vector<lsp::DocumentSymbol> &symbols) {
1328   if (chunks.size() == 1)
1329     return chunks.front()->document.findDocumentSymbols(symbols);
1330 
1331   // If there are multiple chunks in this file, we create top-level symbols for
1332   // each chunk.
1333   for (unsigned i = 0, e = chunks.size(); i < e; ++i) {
1334     PDLTextFileChunk &chunk = *chunks[i];
1335     lsp::Position startPos(chunk.lineOffset);
1336     lsp::Position endPos((i == e - 1) ? totalNumLines - 1
1337                                       : chunks[i + 1]->lineOffset);
1338     lsp::DocumentSymbol symbol("<file-split-" + Twine(i) + ">",
1339                                lsp::SymbolKind::Namespace,
1340                                /*range=*/lsp::Range(startPos, endPos),
1341                                /*selectionRange=*/lsp::Range(startPos));
1342     chunk.document.findDocumentSymbols(symbol.children);
1343 
1344     // Fixup the locations of document symbols within this chunk.
1345     if (i != 0) {
1346       SmallVector<lsp::DocumentSymbol *> symbolsToFix;
1347       for (lsp::DocumentSymbol &childSymbol : symbol.children)
1348         symbolsToFix.push_back(&childSymbol);
1349 
1350       while (!symbolsToFix.empty()) {
1351         lsp::DocumentSymbol *symbol = symbolsToFix.pop_back_val();
1352         chunk.adjustLocForChunkOffset(symbol->range);
1353         chunk.adjustLocForChunkOffset(symbol->selectionRange);
1354 
1355         for (lsp::DocumentSymbol &childSymbol : symbol->children)
1356           symbolsToFix.push_back(&childSymbol);
1357       }
1358     }
1359 
1360     // Push the symbol for this chunk.
1361     symbols.emplace_back(std::move(symbol));
1362   }
1363 }
1364 
1365 lsp::CompletionList PDLTextFile::getCodeCompletion(const lsp::URIForFile &uri,
1366                                                    lsp::Position completePos) {
1367   PDLTextFileChunk &chunk = getChunkFor(completePos);
1368   lsp::CompletionList completionList =
1369       chunk.document.getCodeCompletion(uri, completePos);
1370 
1371   // Adjust any completion locations.
1372   for (lsp::CompletionItem &item : completionList.items) {
1373     if (item.textEdit)
1374       chunk.adjustLocForChunkOffset(item.textEdit->range);
1375     for (lsp::TextEdit &edit : item.additionalTextEdits)
1376       chunk.adjustLocForChunkOffset(edit.range);
1377   }
1378   return completionList;
1379 }
1380 
1381 lsp::SignatureHelp PDLTextFile::getSignatureHelp(const lsp::URIForFile &uri,
1382                                                  lsp::Position helpPos) {
1383   return getChunkFor(helpPos).document.getSignatureHelp(uri, helpPos);
1384 }
1385 
1386 PDLTextFileChunk &PDLTextFile::getChunkFor(lsp::Position &pos) {
1387   if (chunks.size() == 1)
1388     return *chunks.front();
1389 
1390   // Search for the first chunk with a greater line offset, the previous chunk
1391   // is the one that contains `pos`.
1392   auto it = llvm::upper_bound(
1393       chunks, pos, [](const lsp::Position &pos, const auto &chunk) {
1394         return static_cast<uint64_t>(pos.line) < chunk->lineOffset;
1395       });
1396   PDLTextFileChunk &chunk = it == chunks.end() ? *chunks.back() : **(--it);
1397   pos.line -= chunk.lineOffset;
1398   return chunk;
1399 }
1400 
1401 //===----------------------------------------------------------------------===//
1402 // PDLLServer::Impl
1403 //===----------------------------------------------------------------------===//
1404 
/// Private state of the PDLL language server.
struct lsp::PDLLServer::Impl {
  /// Build the compilation database from `options.compilationDatabases`.
  explicit Impl(const Options &options)
      : options(options), compilationDatabase(options.compilationDatabases) {}

  /// PDLL LSP options. Held by reference, so the options object passed to the
  /// constructor must outlive this Impl.
  const Options &options;

  /// The compilation database containing additional information for files
  /// passed to the server (e.g. per-file include directories).
  lsp::CompilationDatabase compilationDatabase;

  /// The files held by the server, mapped by their URI file name.
  llvm::StringMap<std::unique_ptr<PDLTextFile>> files;
};
1419 
1420 //===----------------------------------------------------------------------===//
1421 // PDLLServer
1422 //===----------------------------------------------------------------------===//
1423 
// Construct the server; `options` is retained by reference inside Impl.
lsp::PDLLServer::PDLLServer(const Options &options)
    : impl(std::make_unique<Impl>(options)) {}
// Defaulted out of line, presumably so that `Impl` is a complete type where
// the `impl` owner is destroyed.
lsp::PDLLServer::~PDLLServer() = default;
1427 
1428 void lsp::PDLLServer::addOrUpdateDocument(
1429     const URIForFile &uri, StringRef contents, int64_t version,
1430     std::vector<Diagnostic> &diagnostics) {
1431   std::vector<std::string> additionalIncludeDirs = impl->options.extraDirs;
1432   if (auto *fileInfo = impl->compilationDatabase.getFileInfo(uri.file()))
1433     llvm::append_range(additionalIncludeDirs, fileInfo->includeDirs);
1434 
1435   impl->files[uri.file()] = std::make_unique<PDLTextFile>(
1436       uri, contents, version, additionalIncludeDirs, diagnostics);
1437 }
1438 
1439 Optional<int64_t> lsp::PDLLServer::removeDocument(const URIForFile &uri) {
1440   auto it = impl->files.find(uri.file());
1441   if (it == impl->files.end())
1442     return llvm::None;
1443 
1444   int64_t version = it->second->getVersion();
1445   impl->files.erase(it);
1446   return version;
1447 }
1448 
1449 void lsp::PDLLServer::getLocationsOf(const URIForFile &uri,
1450                                      const Position &defPos,
1451                                      std::vector<Location> &locations) {
1452   auto fileIt = impl->files.find(uri.file());
1453   if (fileIt != impl->files.end())
1454     fileIt->second->getLocationsOf(uri, defPos, locations);
1455 }
1456 
1457 void lsp::PDLLServer::findReferencesOf(const URIForFile &uri,
1458                                        const Position &pos,
1459                                        std::vector<Location> &references) {
1460   auto fileIt = impl->files.find(uri.file());
1461   if (fileIt != impl->files.end())
1462     fileIt->second->findReferencesOf(uri, pos, references);
1463 }
1464 
1465 void lsp::PDLLServer::getDocumentLinks(
1466     const URIForFile &uri, std::vector<DocumentLink> &documentLinks) {
1467   auto fileIt = impl->files.find(uri.file());
1468   if (fileIt != impl->files.end())
1469     return fileIt->second->getDocumentLinks(uri, documentLinks);
1470 }
1471 
1472 Optional<lsp::Hover> lsp::PDLLServer::findHover(const URIForFile &uri,
1473                                                 const Position &hoverPos) {
1474   auto fileIt = impl->files.find(uri.file());
1475   if (fileIt != impl->files.end())
1476     return fileIt->second->findHover(uri, hoverPos);
1477   return llvm::None;
1478 }
1479 
1480 void lsp::PDLLServer::findDocumentSymbols(
1481     const URIForFile &uri, std::vector<DocumentSymbol> &symbols) {
1482   auto fileIt = impl->files.find(uri.file());
1483   if (fileIt != impl->files.end())
1484     fileIt->second->findDocumentSymbols(symbols);
1485 }
1486 
1487 lsp::CompletionList
1488 lsp::PDLLServer::getCodeCompletion(const URIForFile &uri,
1489                                    const Position &completePos) {
1490   auto fileIt = impl->files.find(uri.file());
1491   if (fileIt != impl->files.end())
1492     return fileIt->second->getCodeCompletion(uri, completePos);
1493   return CompletionList();
1494 }
1495 
1496 lsp::SignatureHelp lsp::PDLLServer::getSignatureHelp(const URIForFile &uri,
1497                                                      const Position &helpPos) {
1498   auto fileIt = impl->files.find(uri.file());
1499   if (fileIt != impl->files.end())
1500     return fileIt->second->getSignatureHelp(uri, helpPos);
1501   return SignatureHelp();
1502 }
1503