1 //===- MLIRServer.cpp - MLIR Generic Language Server ----------------------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 
9 #include "MLIRServer.h"
10 #include "../lsp-server-support/Logging.h"
11 #include "../lsp-server-support/Protocol.h"
12 #include "../lsp-server-support/SourceMgrUtils.h"
13 #include "mlir/IR/FunctionInterfaces.h"
14 #include "mlir/IR/Operation.h"
15 #include "mlir/Parser/AsmParserState.h"
16 #include "mlir/Parser/Parser.h"
17 #include "llvm/Support/SourceMgr.h"
18 
19 using namespace mlir;
20 
21 /// Returns a language server location from the given MLIR file location.
22 static Optional<lsp::Location> getLocationFromLoc(FileLineColLoc loc) {
23   llvm::Expected<lsp::URIForFile> sourceURI =
24       lsp::URIForFile::fromFile(loc.getFilename());
25   if (!sourceURI) {
26     lsp::Logger::error("Failed to create URI for file `{0}`: {1}",
27                        loc.getFilename(),
28                        llvm::toString(sourceURI.takeError()));
29     return llvm::None;
30   }
31 
32   lsp::Position position;
33   position.line = loc.getLine() - 1;
34   position.character = loc.getColumn() ? loc.getColumn() - 1 : 0;
35   return lsp::Location{*sourceURI, lsp::Range(position)};
36 }
37 
38 /// Returns a language server location from the given MLIR location, or None if
39 /// one couldn't be created. `uri` is an optional additional filter that, when
40 /// present, is used to filter sub locations that do not share the same uri.
41 static Optional<lsp::Location>
42 getLocationFromLoc(llvm::SourceMgr &sourceMgr, Location loc,
43                    const lsp::URIForFile *uri = nullptr) {
44   Optional<lsp::Location> location;
45   loc->walk([&](Location nestedLoc) {
46     FileLineColLoc fileLoc = nestedLoc.dyn_cast<FileLineColLoc>();
47     if (!fileLoc)
48       return WalkResult::advance();
49 
50     Optional<lsp::Location> sourceLoc = getLocationFromLoc(fileLoc);
51     if (sourceLoc && (!uri || sourceLoc->uri == *uri)) {
52       location = *sourceLoc;
53       SMLoc loc = sourceMgr.FindLocForLineAndColumn(
54           sourceMgr.getMainFileID(), fileLoc.getLine(), fileLoc.getColumn());
55 
56       // Use range of potential identifier starting at location, else length 1
57       // range.
58       location->range.end.character += 1;
59       if (Optional<SMRange> range = lsp::convertTokenLocToRange(loc)) {
60         auto lineCol = sourceMgr.getLineAndColumn(range->End);
61         location->range.end.character =
62             std::max(fileLoc.getColumn() + 1, lineCol.second - 1);
63       }
64       return WalkResult::interrupt();
65     }
66     return WalkResult::advance();
67   });
68   return location;
69 }
70 
71 /// Collect all of the locations from the given MLIR location that are not
72 /// contained within the given URI.
73 static void collectLocationsFromLoc(Location loc,
74                                     std::vector<lsp::Location> &locations,
75                                     const lsp::URIForFile &uri) {
76   SetVector<Location> visitedLocs;
77   loc->walk([&](Location nestedLoc) {
78     FileLineColLoc fileLoc = nestedLoc.dyn_cast<FileLineColLoc>();
79     if (!fileLoc || !visitedLocs.insert(nestedLoc))
80       return WalkResult::advance();
81 
82     Optional<lsp::Location> sourceLoc = getLocationFromLoc(fileLoc);
83     if (sourceLoc && sourceLoc->uri != uri)
84       locations.push_back(*sourceLoc);
85     return WalkResult::advance();
86   });
87 }
88 
89 /// Returns true if the given range contains the given source location. Note
90 /// that this has slightly different behavior than SMRange because it is
91 /// inclusive of the end location.
92 static bool contains(SMRange range, SMLoc loc) {
93   return range.Start.getPointer() <= loc.getPointer() &&
94          loc.getPointer() <= range.End.getPointer();
95 }
96 
97 /// Returns true if the given location is contained by the definition or one of
98 /// the uses of the given SMDefinition. If provided, `overlappedRange` is set to
99 /// the range within `def` that the provided `loc` overlapped with.
100 static bool isDefOrUse(const AsmParserState::SMDefinition &def, SMLoc loc,
101                        SMRange *overlappedRange = nullptr) {
102   // Check the main definition.
103   if (contains(def.loc, loc)) {
104     if (overlappedRange)
105       *overlappedRange = def.loc;
106     return true;
107   }
108 
109   // Check the uses.
110   const auto *useIt = llvm::find_if(
111       def.uses, [&](const SMRange &range) { return contains(range, loc); });
112   if (useIt != def.uses.end()) {
113     if (overlappedRange)
114       *overlappedRange = *useIt;
115     return true;
116   }
117   return false;
118 }
119 
120 /// Given a location pointing to a result, return the result number it refers
121 /// to or None if it refers to all of the results.
122 static Optional<unsigned> getResultNumberFromLoc(SMLoc loc) {
123   // Skip all of the identifier characters.
124   auto isIdentifierChar = [](char c) {
125     return isalnum(c) || c == '%' || c == '$' || c == '.' || c == '_' ||
126            c == '-';
127   };
128   const char *curPtr = loc.getPointer();
129   while (isIdentifierChar(*curPtr))
130     ++curPtr;
131 
132   // Check to see if this location indexes into the result group, via `#`. If it
133   // doesn't, we can't extract a sub result number.
134   if (*curPtr != '#')
135     return llvm::None;
136 
137   // Compute the sub result number from the remaining portion of the string.
138   const char *numberStart = ++curPtr;
139   while (llvm::isDigit(*curPtr))
140     ++curPtr;
141   StringRef numberStr(numberStart, curPtr - numberStart);
142   unsigned resultNumber = 0;
143   return numberStr.consumeInteger(10, resultNumber) ? Optional<unsigned>()
144                                                     : resultNumber;
145 }
146 
147 /// Given a source location range, return the text covered by the given range.
148 /// If the range is invalid, returns None.
149 static Optional<StringRef> getTextFromRange(SMRange range) {
150   if (!range.isValid())
151     return None;
152   const char *startPtr = range.Start.getPointer();
153   return StringRef(startPtr, range.End.getPointer() - startPtr);
154 }
155 
156 /// Given a block, return its position in its parent region.
157 static unsigned getBlockNumber(Block *block) {
158   return std::distance(block->getParent()->begin(), block->getIterator());
159 }
160 
161 /// Given a block and source location, print the source name of the block to the
162 /// given output stream.
163 static void printDefBlockName(raw_ostream &os, Block *block, SMRange loc = {}) {
164   // Try to extract a name from the source location.
165   Optional<StringRef> text = getTextFromRange(loc);
166   if (text && text->startswith("^")) {
167     os << *text;
168     return;
169   }
170 
171   // Otherwise, we don't have a name so print the block number.
172   os << "<Block #" << getBlockNumber(block) << ">";
173 }
174 static void printDefBlockName(raw_ostream &os,
175                               const AsmParserState::BlockDefinition &def) {
176   printDefBlockName(os, def.block, def.definition.loc);
177 }
178 
179 /// Convert the given MLIR diagnostic to the LSP form.
180 static lsp::Diagnostic getLspDiagnoticFromDiag(llvm::SourceMgr &sourceMgr,
181                                                Diagnostic &diag,
182                                                const lsp::URIForFile &uri) {
183   lsp::Diagnostic lspDiag;
184   lspDiag.source = "mlir";
185 
186   // Note: Right now all of the diagnostics are treated as parser issues, but
187   // some are parser and some are verifier.
188   lspDiag.category = "Parse Error";
189 
190   // Try to grab a file location for this diagnostic.
191   // TODO: For simplicity, we just grab the first one. It may be likely that we
192   // will need a more interesting heuristic here.'
193   Optional<lsp::Location> lspLocation =
194       getLocationFromLoc(sourceMgr, diag.getLocation(), &uri);
195   if (lspLocation)
196     lspDiag.range = lspLocation->range;
197 
198   // Convert the severity for the diagnostic.
199   switch (diag.getSeverity()) {
200   case DiagnosticSeverity::Note:
201     llvm_unreachable("expected notes to be handled separately");
202   case DiagnosticSeverity::Warning:
203     lspDiag.severity = lsp::DiagnosticSeverity::Warning;
204     break;
205   case DiagnosticSeverity::Error:
206     lspDiag.severity = lsp::DiagnosticSeverity::Error;
207     break;
208   case DiagnosticSeverity::Remark:
209     lspDiag.severity = lsp::DiagnosticSeverity::Information;
210     break;
211   }
212   lspDiag.message = diag.str();
213 
214   // Attach any notes to the main diagnostic as related information.
215   std::vector<lsp::DiagnosticRelatedInformation> relatedDiags;
216   for (Diagnostic &note : diag.getNotes()) {
217     lsp::Location noteLoc;
218     if (Optional<lsp::Location> loc =
219             getLocationFromLoc(sourceMgr, note.getLocation()))
220       noteLoc = *loc;
221     else
222       noteLoc.uri = uri;
223     relatedDiags.emplace_back(noteLoc, note.str());
224   }
225   if (!relatedDiags.empty())
226     lspDiag.relatedInformation = std::move(relatedDiags);
227 
228   return lspDiag;
229 }
230 
231 //===----------------------------------------------------------------------===//
232 // MLIRDocument
233 //===----------------------------------------------------------------------===//
234 
235 namespace {
236 /// This class represents all of the information pertaining to a specific MLIR
237 /// document.
238 struct MLIRDocument {
239   MLIRDocument(MLIRContext &context, const lsp::URIForFile &uri,
240                StringRef contents, std::vector<lsp::Diagnostic> &diagnostics);
241   MLIRDocument(const MLIRDocument &) = delete;
242   MLIRDocument &operator=(const MLIRDocument &) = delete;
243 
244   //===--------------------------------------------------------------------===//
245   // Definitions and References
246   //===--------------------------------------------------------------------===//
247 
248   void getLocationsOf(const lsp::URIForFile &uri, const lsp::Position &defPos,
249                       std::vector<lsp::Location> &locations);
250   void findReferencesOf(const lsp::URIForFile &uri, const lsp::Position &pos,
251                         std::vector<lsp::Location> &references);
252 
253   //===--------------------------------------------------------------------===//
254   // Hover
255   //===--------------------------------------------------------------------===//
256 
257   Optional<lsp::Hover> findHover(const lsp::URIForFile &uri,
258                                  const lsp::Position &hoverPos);
259   Optional<lsp::Hover>
260   buildHoverForOperation(SMRange hoverRange,
261                          const AsmParserState::OperationDefinition &op);
262   lsp::Hover buildHoverForOperationResult(SMRange hoverRange, Operation *op,
263                                           unsigned resultStart,
264                                           unsigned resultEnd, SMLoc posLoc);
265   lsp::Hover buildHoverForBlock(SMRange hoverRange,
266                                 const AsmParserState::BlockDefinition &block);
267   lsp::Hover
268   buildHoverForBlockArgument(SMRange hoverRange, BlockArgument arg,
269                              const AsmParserState::BlockDefinition &block);
270 
271   //===--------------------------------------------------------------------===//
272   // Document Symbols
273   //===--------------------------------------------------------------------===//
274 
275   void findDocumentSymbols(std::vector<lsp::DocumentSymbol> &symbols);
276   void findDocumentSymbols(Operation *op,
277                            std::vector<lsp::DocumentSymbol> &symbols);
278 
279   //===--------------------------------------------------------------------===//
280   // Fields
281   //===--------------------------------------------------------------------===//
282 
283   /// The high level parser state used to find definitions and references within
284   /// the source file.
285   AsmParserState asmState;
286 
287   /// The container for the IR parsed from the input file.
288   Block parsedIR;
289 
290   /// The source manager containing the contents of the input file.
291   llvm::SourceMgr sourceMgr;
292 };
293 } // namespace
294 
295 MLIRDocument::MLIRDocument(MLIRContext &context, const lsp::URIForFile &uri,
296                            StringRef contents,
297                            std::vector<lsp::Diagnostic> &diagnostics) {
298   ScopedDiagnosticHandler handler(&context, [&](Diagnostic &diag) {
299     diagnostics.push_back(getLspDiagnoticFromDiag(sourceMgr, diag, uri));
300   });
301 
302   // Try to parsed the given IR string.
303   auto memBuffer = llvm::MemoryBuffer::getMemBufferCopy(contents, uri.file());
304   if (!memBuffer) {
305     lsp::Logger::error("Failed to create memory buffer for file", uri.file());
306     return;
307   }
308 
309   sourceMgr.AddNewSourceBuffer(std::move(memBuffer), SMLoc());
310   if (failed(parseSourceFile(sourceMgr, &parsedIR, &context, nullptr,
311                              &asmState))) {
312     // If parsing failed, clear out any of the current state.
313     parsedIR.clear();
314     asmState = AsmParserState();
315     return;
316   }
317 }
318 
319 //===----------------------------------------------------------------------===//
320 // MLIRDocument: Definitions and References
321 //===----------------------------------------------------------------------===//
322 
323 void MLIRDocument::getLocationsOf(const lsp::URIForFile &uri,
324                                   const lsp::Position &defPos,
325                                   std::vector<lsp::Location> &locations) {
326   SMLoc posLoc = defPos.getAsSMLoc(sourceMgr);
327 
328   // Functor used to check if an SM definition contains the position.
329   auto containsPosition = [&](const AsmParserState::SMDefinition &def) {
330     if (!isDefOrUse(def, posLoc))
331       return false;
332     locations.emplace_back(uri, sourceMgr, def.loc);
333     return true;
334   };
335 
336   // Check all definitions related to operations.
337   for (const AsmParserState::OperationDefinition &op : asmState.getOpDefs()) {
338     if (contains(op.loc, posLoc))
339       return collectLocationsFromLoc(op.op->getLoc(), locations, uri);
340     for (const auto &result : op.resultGroups)
341       if (containsPosition(result.definition))
342         return collectLocationsFromLoc(op.op->getLoc(), locations, uri);
343     for (const auto &symUse : op.symbolUses) {
344       if (contains(symUse, posLoc)) {
345         locations.emplace_back(uri, sourceMgr, op.loc);
346         return collectLocationsFromLoc(op.op->getLoc(), locations, uri);
347       }
348     }
349   }
350 
351   // Check all definitions related to blocks.
352   for (const AsmParserState::BlockDefinition &block : asmState.getBlockDefs()) {
353     if (containsPosition(block.definition))
354       return;
355     for (const AsmParserState::SMDefinition &arg : block.arguments)
356       if (containsPosition(arg))
357         return;
358   }
359 }
360 
361 void MLIRDocument::findReferencesOf(const lsp::URIForFile &uri,
362                                     const lsp::Position &pos,
363                                     std::vector<lsp::Location> &references) {
364   // Functor used to append all of the definitions/uses of the given SM
365   // definition to the reference list.
366   auto appendSMDef = [&](const AsmParserState::SMDefinition &def) {
367     references.emplace_back(uri, sourceMgr, def.loc);
368     for (const SMRange &use : def.uses)
369       references.emplace_back(uri, sourceMgr, use);
370   };
371 
372   SMLoc posLoc = pos.getAsSMLoc(sourceMgr);
373 
374   // Check all definitions related to operations.
375   for (const AsmParserState::OperationDefinition &op : asmState.getOpDefs()) {
376     if (contains(op.loc, posLoc)) {
377       for (const auto &result : op.resultGroups)
378         appendSMDef(result.definition);
379       for (const auto &symUse : op.symbolUses)
380         if (contains(symUse, posLoc))
381           references.emplace_back(uri, sourceMgr, symUse);
382       return;
383     }
384     for (const auto &result : op.resultGroups)
385       if (isDefOrUse(result.definition, posLoc))
386         return appendSMDef(result.definition);
387     for (const auto &symUse : op.symbolUses) {
388       if (!contains(symUse, posLoc))
389         continue;
390       for (const auto &symUse : op.symbolUses)
391         references.emplace_back(uri, sourceMgr, symUse);
392       return;
393     }
394   }
395 
396   // Check all definitions related to blocks.
397   for (const AsmParserState::BlockDefinition &block : asmState.getBlockDefs()) {
398     if (isDefOrUse(block.definition, posLoc))
399       return appendSMDef(block.definition);
400 
401     for (const AsmParserState::SMDefinition &arg : block.arguments)
402       if (isDefOrUse(arg, posLoc))
403         return appendSMDef(arg);
404   }
405 }
406 
407 //===----------------------------------------------------------------------===//
408 // MLIRDocument: Hover
409 //===----------------------------------------------------------------------===//
410 
411 Optional<lsp::Hover> MLIRDocument::findHover(const lsp::URIForFile &uri,
412                                              const lsp::Position &hoverPos) {
413   SMLoc posLoc = hoverPos.getAsSMLoc(sourceMgr);
414   SMRange hoverRange;
415 
416   // Check for Hovers on operations and results.
417   for (const AsmParserState::OperationDefinition &op : asmState.getOpDefs()) {
418     // Check if the position points at this operation.
419     if (contains(op.loc, posLoc))
420       return buildHoverForOperation(op.loc, op);
421 
422     // Check if the position points at the symbol name.
423     for (auto &use : op.symbolUses)
424       if (contains(use, posLoc))
425         return buildHoverForOperation(use, op);
426 
427     // Check if the position points at a result group.
428     for (unsigned i = 0, e = op.resultGroups.size(); i < e; ++i) {
429       const auto &result = op.resultGroups[i];
430       if (!isDefOrUse(result.definition, posLoc, &hoverRange))
431         continue;
432 
433       // Get the range of results covered by the over position.
434       unsigned resultStart = result.startIndex;
435       unsigned resultEnd = (i == e - 1) ? op.op->getNumResults()
436                                         : op.resultGroups[i + 1].startIndex;
437       return buildHoverForOperationResult(hoverRange, op.op, resultStart,
438                                           resultEnd, posLoc);
439     }
440   }
441 
442   // Check to see if the hover is over a block argument.
443   for (const AsmParserState::BlockDefinition &block : asmState.getBlockDefs()) {
444     if (isDefOrUse(block.definition, posLoc, &hoverRange))
445       return buildHoverForBlock(hoverRange, block);
446 
447     for (const auto &arg : llvm::enumerate(block.arguments)) {
448       if (!isDefOrUse(arg.value(), posLoc, &hoverRange))
449         continue;
450 
451       return buildHoverForBlockArgument(
452           hoverRange, block.block->getArgument(arg.index()), block);
453     }
454   }
455   return llvm::None;
456 }
457 
458 Optional<lsp::Hover> MLIRDocument::buildHoverForOperation(
459     SMRange hoverRange, const AsmParserState::OperationDefinition &op) {
460   lsp::Hover hover(lsp::Range(sourceMgr, hoverRange));
461   llvm::raw_string_ostream os(hover.contents.value);
462 
463   // Add the operation name to the hover.
464   os << "\"" << op.op->getName() << "\"";
465   if (SymbolOpInterface symbol = dyn_cast<SymbolOpInterface>(op.op))
466     os << " : " << symbol.getVisibility() << " @" << symbol.getName() << "";
467   os << "\n\n";
468 
469   os << "Generic Form:\n\n```mlir\n";
470 
471   // Temporary drop the regions of this operation so that they don't get
472   // printed in the output. This helps keeps the size of the output hover
473   // small.
474   SmallVector<std::unique_ptr<Region>> regions;
475   for (Region &region : op.op->getRegions()) {
476     regions.emplace_back(std::make_unique<Region>());
477     regions.back()->takeBody(region);
478   }
479 
480   op.op->print(
481       os, OpPrintingFlags().printGenericOpForm().elideLargeElementsAttrs());
482   os << "\n```\n";
483 
484   // Move the regions back to the current operation.
485   for (Region &region : op.op->getRegions())
486     region.takeBody(*regions.back());
487 
488   return hover;
489 }
490 
491 lsp::Hover MLIRDocument::buildHoverForOperationResult(SMRange hoverRange,
492                                                       Operation *op,
493                                                       unsigned resultStart,
494                                                       unsigned resultEnd,
495                                                       SMLoc posLoc) {
496   lsp::Hover hover(lsp::Range(sourceMgr, hoverRange));
497   llvm::raw_string_ostream os(hover.contents.value);
498 
499   // Add the parent operation name to the hover.
500   os << "Operation: \"" << op->getName() << "\"\n\n";
501 
502   // Check to see if the location points to a specific result within the
503   // group.
504   if (Optional<unsigned> resultNumber = getResultNumberFromLoc(posLoc)) {
505     if ((resultStart + *resultNumber) < resultEnd) {
506       resultStart += *resultNumber;
507       resultEnd = resultStart + 1;
508     }
509   }
510 
511   // Add the range of results and their types to the hover info.
512   if ((resultStart + 1) == resultEnd) {
513     os << "Result #" << resultStart << "\n\n"
514        << "Type: `" << op->getResult(resultStart).getType() << "`\n\n";
515   } else {
516     os << "Result #[" << resultStart << ", " << (resultEnd - 1) << "]\n\n"
517        << "Types: ";
518     llvm::interleaveComma(
519         op->getResults().slice(resultStart, resultEnd), os,
520         [&](Value result) { os << "`" << result.getType() << "`"; });
521   }
522 
523   return hover;
524 }
525 
526 lsp::Hover
527 MLIRDocument::buildHoverForBlock(SMRange hoverRange,
528                                  const AsmParserState::BlockDefinition &block) {
529   lsp::Hover hover(lsp::Range(sourceMgr, hoverRange));
530   llvm::raw_string_ostream os(hover.contents.value);
531 
532   // Print the given block to the hover output stream.
533   auto printBlockToHover = [&](Block *newBlock) {
534     if (const auto *def = asmState.getBlockDef(newBlock))
535       printDefBlockName(os, *def);
536     else
537       printDefBlockName(os, newBlock);
538   };
539 
540   // Display the parent operation, block number, predecessors, and successors.
541   os << "Operation: \"" << block.block->getParentOp()->getName() << "\"\n\n"
542      << "Block #" << getBlockNumber(block.block) << "\n\n";
543   if (!block.block->hasNoPredecessors()) {
544     os << "Predecessors: ";
545     llvm::interleaveComma(block.block->getPredecessors(), os,
546                           printBlockToHover);
547     os << "\n\n";
548   }
549   if (!block.block->hasNoSuccessors()) {
550     os << "Successors: ";
551     llvm::interleaveComma(block.block->getSuccessors(), os, printBlockToHover);
552     os << "\n\n";
553   }
554 
555   return hover;
556 }
557 
558 lsp::Hover MLIRDocument::buildHoverForBlockArgument(
559     SMRange hoverRange, BlockArgument arg,
560     const AsmParserState::BlockDefinition &block) {
561   lsp::Hover hover(lsp::Range(sourceMgr, hoverRange));
562   llvm::raw_string_ostream os(hover.contents.value);
563 
564   // Display the parent operation, block, the argument number, and the type.
565   os << "Operation: \"" << block.block->getParentOp()->getName() << "\"\n\n"
566      << "Block: ";
567   printDefBlockName(os, block);
568   os << "\n\nArgument #" << arg.getArgNumber() << "\n\n"
569      << "Type: `" << arg.getType() << "`\n\n";
570 
571   return hover;
572 }
573 
574 //===----------------------------------------------------------------------===//
575 // MLIRDocument: Document Symbols
576 //===----------------------------------------------------------------------===//
577 
578 void MLIRDocument::findDocumentSymbols(
579     std::vector<lsp::DocumentSymbol> &symbols) {
580   for (Operation &op : parsedIR)
581     findDocumentSymbols(&op, symbols);
582 }
583 
584 void MLIRDocument::findDocumentSymbols(
585     Operation *op, std::vector<lsp::DocumentSymbol> &symbols) {
586   std::vector<lsp::DocumentSymbol> *childSymbols = &symbols;
587 
588   // Check for the source information of this operation.
589   if (const AsmParserState::OperationDefinition *def = asmState.getOpDef(op)) {
590     // If this operation defines a symbol, record it.
591     if (SymbolOpInterface symbol = dyn_cast<SymbolOpInterface>(op)) {
592       symbols.emplace_back(symbol.getName(),
593                            isa<FunctionOpInterface>(op)
594                                ? lsp::SymbolKind::Function
595                                : lsp::SymbolKind::Class,
596                            lsp::Range(sourceMgr, def->scopeLoc),
597                            lsp::Range(sourceMgr, def->loc));
598       childSymbols = &symbols.back().children;
599 
600     } else if (op->hasTrait<OpTrait::SymbolTable>()) {
601       // Otherwise, if this is a symbol table push an anonymous document symbol.
602       symbols.emplace_back("<" + op->getName().getStringRef() + ">",
603                            lsp::SymbolKind::Namespace,
604                            lsp::Range(sourceMgr, def->scopeLoc),
605                            lsp::Range(sourceMgr, def->loc));
606       childSymbols = &symbols.back().children;
607     }
608   }
609 
610   // Recurse into the regions of this operation.
611   if (!op->getNumRegions())
612     return;
613   for (Region &region : op->getRegions())
614     for (Operation &childOp : region.getOps())
615       findDocumentSymbols(&childOp, *childSymbols);
616 }
617 
618 //===----------------------------------------------------------------------===//
619 // MLIRTextFileChunk
620 //===----------------------------------------------------------------------===//
621 
622 namespace {
623 /// This class represents a single chunk of an MLIR text file.
624 struct MLIRTextFileChunk {
625   MLIRTextFileChunk(MLIRContext &context, uint64_t lineOffset,
626                     const lsp::URIForFile &uri, StringRef contents,
627                     std::vector<lsp::Diagnostic> &diagnostics)
628       : lineOffset(lineOffset), document(context, uri, contents, diagnostics) {}
629 
630   /// Adjust the line number of the given range to anchor at the beginning of
631   /// the file, instead of the beginning of this chunk.
632   void adjustLocForChunkOffset(lsp::Range &range) {
633     adjustLocForChunkOffset(range.start);
634     adjustLocForChunkOffset(range.end);
635   }
636   /// Adjust the line number of the given position to anchor at the beginning of
637   /// the file, instead of the beginning of this chunk.
638   void adjustLocForChunkOffset(lsp::Position &pos) { pos.line += lineOffset; }
639 
640   /// The line offset of this chunk from the beginning of the file.
641   uint64_t lineOffset;
642   /// The document referred to by this chunk.
643   MLIRDocument document;
644 };
645 } // namespace
646 
647 //===----------------------------------------------------------------------===//
648 // MLIRTextFile
649 //===----------------------------------------------------------------------===//
650 
651 namespace {
652 /// This class represents a text file containing one or more MLIR documents.
653 class MLIRTextFile {
654 public:
655   MLIRTextFile(const lsp::URIForFile &uri, StringRef fileContents,
656                int64_t version, DialectRegistry &registry,
657                std::vector<lsp::Diagnostic> &diagnostics);
658 
659   /// Return the current version of this text file.
660   int64_t getVersion() const { return version; }
661 
662   //===--------------------------------------------------------------------===//
663   // LSP Queries
664   //===--------------------------------------------------------------------===//
665 
666   void getLocationsOf(const lsp::URIForFile &uri, lsp::Position defPos,
667                       std::vector<lsp::Location> &locations);
668   void findReferencesOf(const lsp::URIForFile &uri, lsp::Position pos,
669                         std::vector<lsp::Location> &references);
670   Optional<lsp::Hover> findHover(const lsp::URIForFile &uri,
671                                  lsp::Position hoverPos);
672   void findDocumentSymbols(std::vector<lsp::DocumentSymbol> &symbols);
673 
674 private:
675   /// Find the MLIR document that contains the given position, and update the
676   /// position to be anchored at the start of the found chunk instead of the
677   /// beginning of the file.
678   MLIRTextFileChunk &getChunkFor(lsp::Position &pos);
679 
680   /// The context used to hold the state contained by the parsed document.
681   MLIRContext context;
682 
683   /// The full string contents of the file.
684   std::string contents;
685 
686   /// The version of this file.
687   int64_t version;
688 
689   /// The number of lines in the file.
690   int64_t totalNumLines = 0;
691 
692   /// The chunks of this file. The order of these chunks is the order in which
693   /// they appear in the text file.
694   std::vector<std::unique_ptr<MLIRTextFileChunk>> chunks;
695 };
696 } // namespace
697 
698 MLIRTextFile::MLIRTextFile(const lsp::URIForFile &uri, StringRef fileContents,
699                            int64_t version, DialectRegistry &registry,
700                            std::vector<lsp::Diagnostic> &diagnostics)
701     : context(registry, MLIRContext::Threading::DISABLED),
702       contents(fileContents.str()), version(version) {
703   context.allowUnregisteredDialects();
704 
705   // Split the file into separate MLIR documents.
706   // TODO: Find a way to share the split file marker with other tools. We don't
707   // want to use `splitAndProcessBuffer` here, but we do want to make sure this
708   // marker doesn't go out of sync.
709   SmallVector<StringRef, 8> subContents;
710   StringRef(contents).split(subContents, "// -----");
711   chunks.emplace_back(std::make_unique<MLIRTextFileChunk>(
712       context, /*lineOffset=*/0, uri, subContents.front(), diagnostics));
713 
714   uint64_t lineOffset = subContents.front().count('\n');
715   for (StringRef docContents : llvm::drop_begin(subContents)) {
716     unsigned currentNumDiags = diagnostics.size();
717     auto chunk = std::make_unique<MLIRTextFileChunk>(context, lineOffset, uri,
718                                                      docContents, diagnostics);
719     lineOffset += docContents.count('\n');
720 
721     // Adjust locations used in diagnostics to account for the offset from the
722     // beginning of the file.
723     for (lsp::Diagnostic &diag :
724          llvm::drop_begin(diagnostics, currentNumDiags)) {
725       chunk->adjustLocForChunkOffset(diag.range);
726 
727       if (!diag.relatedInformation)
728         continue;
729       for (auto &it : *diag.relatedInformation)
730         if (it.location.uri == uri)
731           chunk->adjustLocForChunkOffset(it.location.range);
732     }
733     chunks.emplace_back(std::move(chunk));
734   }
735   totalNumLines = lineOffset;
736 }
737 
738 void MLIRTextFile::getLocationsOf(const lsp::URIForFile &uri,
739                                   lsp::Position defPos,
740                                   std::vector<lsp::Location> &locations) {
741   MLIRTextFileChunk &chunk = getChunkFor(defPos);
742   chunk.document.getLocationsOf(uri, defPos, locations);
743 
744   // Adjust any locations within this file for the offset of this chunk.
745   if (chunk.lineOffset == 0)
746     return;
747   for (lsp::Location &loc : locations)
748     if (loc.uri == uri)
749       chunk.adjustLocForChunkOffset(loc.range);
750 }
751 
752 void MLIRTextFile::findReferencesOf(const lsp::URIForFile &uri,
753                                     lsp::Position pos,
754                                     std::vector<lsp::Location> &references) {
755   MLIRTextFileChunk &chunk = getChunkFor(pos);
756   chunk.document.findReferencesOf(uri, pos, references);
757 
758   // Adjust any locations within this file for the offset of this chunk.
759   if (chunk.lineOffset == 0)
760     return;
761   for (lsp::Location &loc : references)
762     if (loc.uri == uri)
763       chunk.adjustLocForChunkOffset(loc.range);
764 }
765 
766 Optional<lsp::Hover> MLIRTextFile::findHover(const lsp::URIForFile &uri,
767                                              lsp::Position hoverPos) {
768   MLIRTextFileChunk &chunk = getChunkFor(hoverPos);
769   Optional<lsp::Hover> hoverInfo = chunk.document.findHover(uri, hoverPos);
770 
771   // Adjust any locations within this file for the offset of this chunk.
772   if (chunk.lineOffset != 0 && hoverInfo && hoverInfo->range)
773     chunk.adjustLocForChunkOffset(*hoverInfo->range);
774   return hoverInfo;
775 }
776 
777 void MLIRTextFile::findDocumentSymbols(
778     std::vector<lsp::DocumentSymbol> &symbols) {
779   if (chunks.size() == 1)
780     return chunks.front()->document.findDocumentSymbols(symbols);
781 
782   // If there are multiple chunks in this file, we create top-level symbols for
783   // each chunk.
784   for (unsigned i = 0, e = chunks.size(); i < e; ++i) {
785     MLIRTextFileChunk &chunk = *chunks[i];
786     lsp::Position startPos(chunk.lineOffset);
787     lsp::Position endPos((i == e - 1) ? totalNumLines - 1
788                                       : chunks[i + 1]->lineOffset);
789     lsp::DocumentSymbol symbol("<file-split-" + Twine(i) + ">",
790                                lsp::SymbolKind::Namespace,
791                                /*range=*/lsp::Range(startPos, endPos),
792                                /*selectionRange=*/lsp::Range(startPos));
793     chunk.document.findDocumentSymbols(symbol.children);
794 
795     // Fixup the locations of document symbols within this chunk.
796     if (i != 0) {
797       SmallVector<lsp::DocumentSymbol *> symbolsToFix;
798       for (lsp::DocumentSymbol &childSymbol : symbol.children)
799         symbolsToFix.push_back(&childSymbol);
800 
801       while (!symbolsToFix.empty()) {
802         lsp::DocumentSymbol *symbol = symbolsToFix.pop_back_val();
803         chunk.adjustLocForChunkOffset(symbol->range);
804         chunk.adjustLocForChunkOffset(symbol->selectionRange);
805 
806         for (lsp::DocumentSymbol &childSymbol : symbol->children)
807           symbolsToFix.push_back(&childSymbol);
808       }
809     }
810 
811     // Push the symbol for this chunk.
812     symbols.emplace_back(std::move(symbol));
813   }
814 }
815 
816 MLIRTextFileChunk &MLIRTextFile::getChunkFor(lsp::Position &pos) {
817   if (chunks.size() == 1)
818     return *chunks.front();
819 
820   // Search for the first chunk with a greater line offset, the previous chunk
821   // is the one that contains `pos`.
822   auto it = llvm::upper_bound(
823       chunks, pos, [](const lsp::Position &pos, const auto &chunk) {
824         return static_cast<uint64_t>(pos.line) < chunk->lineOffset;
825       });
826   MLIRTextFileChunk &chunk = it == chunks.end() ? *chunks.back() : **(--it);
827   pos.line -= chunk.lineOffset;
828   return chunk;
829 }
830 
831 //===----------------------------------------------------------------------===//
832 // MLIRServer::Impl
833 //===----------------------------------------------------------------------===//
834 
835 struct lsp::MLIRServer::Impl {
836   Impl(DialectRegistry &registry) : registry(registry) {}
837 
838   /// The registry containing dialects that can be recognized in parsed .mlir
839   /// files.
840   DialectRegistry &registry;
841 
842   /// The files held by the server, mapped by their URI file name.
843   llvm::StringMap<std::unique_ptr<MLIRTextFile>> files;
844 };
845 
846 //===----------------------------------------------------------------------===//
847 // MLIRServer
848 //===----------------------------------------------------------------------===//
849 
850 lsp::MLIRServer::MLIRServer(DialectRegistry &registry)
851     : impl(std::make_unique<Impl>(registry)) {}
852 lsp::MLIRServer::~MLIRServer() = default;
853 
854 void lsp::MLIRServer::addOrUpdateDocument(
855     const URIForFile &uri, StringRef contents, int64_t version,
856     std::vector<Diagnostic> &diagnostics) {
857   impl->files[uri.file()] = std::make_unique<MLIRTextFile>(
858       uri, contents, version, impl->registry, diagnostics);
859 }
860 
861 Optional<int64_t> lsp::MLIRServer::removeDocument(const URIForFile &uri) {
862   auto it = impl->files.find(uri.file());
863   if (it == impl->files.end())
864     return llvm::None;
865 
866   int64_t version = it->second->getVersion();
867   impl->files.erase(it);
868   return version;
869 }
870 
871 void lsp::MLIRServer::getLocationsOf(const URIForFile &uri,
872                                      const Position &defPos,
873                                      std::vector<Location> &locations) {
874   auto fileIt = impl->files.find(uri.file());
875   if (fileIt != impl->files.end())
876     fileIt->second->getLocationsOf(uri, defPos, locations);
877 }
878 
879 void lsp::MLIRServer::findReferencesOf(const URIForFile &uri,
880                                        const Position &pos,
881                                        std::vector<Location> &references) {
882   auto fileIt = impl->files.find(uri.file());
883   if (fileIt != impl->files.end())
884     fileIt->second->findReferencesOf(uri, pos, references);
885 }
886 
887 Optional<lsp::Hover> lsp::MLIRServer::findHover(const URIForFile &uri,
888                                                 const Position &hoverPos) {
889   auto fileIt = impl->files.find(uri.file());
890   if (fileIt != impl->files.end())
891     return fileIt->second->findHover(uri, hoverPos);
892   return llvm::None;
893 }
894 
895 void lsp::MLIRServer::findDocumentSymbols(
896     const URIForFile &uri, std::vector<DocumentSymbol> &symbols) {
897   auto fileIt = impl->files.find(uri.file());
898   if (fileIt != impl->files.end())
899     fileIt->second->findDocumentSymbols(symbols);
900 }
901