1 //===- MLIRServer.cpp - MLIR Generic Language Server ----------------------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 
9 #include "MLIRServer.h"
10 #include "lsp/Logging.h"
11 #include "lsp/Protocol.h"
12 #include "mlir/IR/Operation.h"
13 #include "mlir/Parser.h"
14 #include "mlir/Parser/AsmParserState.h"
15 #include "llvm/Support/SourceMgr.h"
16 
17 using namespace mlir;
18 
19 /// Returns a language server position for the given source location.
20 static lsp::Position getPosFromLoc(llvm::SourceMgr &mgr, llvm::SMLoc loc) {
21   std::pair<unsigned, unsigned> lineAndCol = mgr.getLineAndColumn(loc);
22   lsp::Position pos;
23   pos.line = lineAndCol.first - 1;
24   pos.character = lineAndCol.second - 1;
25   return pos;
26 }
27 
28 /// Returns a source location from the given language server position.
29 static llvm::SMLoc getPosFromLoc(llvm::SourceMgr &mgr, lsp::Position pos) {
30   return mgr.FindLocForLineAndColumn(mgr.getMainFileID(), pos.line + 1,
31                                      pos.character);
32 }
33 
34 /// Returns a language server range for the given source range.
35 static lsp::Range getRangeFromLoc(llvm::SourceMgr &mgr, llvm::SMRange range) {
36   return {getPosFromLoc(mgr, range.Start), getPosFromLoc(mgr, range.End)};
37 }
38 
39 /// Returns a language server location from the given source range.
40 static lsp::Location getLocationFromLoc(llvm::SourceMgr &mgr,
41                                         llvm::SMRange range,
42                                         const lsp::URIForFile &uri) {
43   return lsp::Location{uri, getRangeFromLoc(mgr, range)};
44 }
45 
46 /// Returns a language server location from the given MLIR file location.
47 static Optional<lsp::Location> getLocationFromLoc(FileLineColLoc loc) {
48   llvm::Expected<lsp::URIForFile> sourceURI =
49       lsp::URIForFile::fromFile(loc.getFilename());
50   if (!sourceURI) {
51     lsp::Logger::error("Failed to create URI for file `{0}`: {1}",
52                        loc.getFilename(),
53                        llvm::toString(sourceURI.takeError()));
54     return llvm::None;
55   }
56 
57   lsp::Position position;
58   position.line = loc.getLine() - 1;
59   position.character = loc.getColumn();
60   return lsp::Location{*sourceURI, lsp::Range(position)};
61 }
62 
63 /// Returns a language server location from the given MLIR location, or None if
64 /// one couldn't be created. `uri` is an optional additional filter that, when
65 /// present, is used to filter sub locations that do not share the same uri.
66 static Optional<lsp::Location>
67 getLocationFromLoc(llvm::SourceMgr &sourceMgr, Location loc,
68                    const lsp::URIForFile *uri = nullptr) {
69   Optional<lsp::Location> location;
70   loc->walk([&](Location nestedLoc) {
71     FileLineColLoc fileLoc = nestedLoc.dyn_cast<FileLineColLoc>();
72     if (!fileLoc)
73       return WalkResult::advance();
74 
75     Optional<lsp::Location> sourceLoc = getLocationFromLoc(fileLoc);
76     if (sourceLoc && (!uri || sourceLoc->uri == *uri)) {
77       location = *sourceLoc;
78       llvm::SMLoc loc = sourceMgr.FindLocForLineAndColumn(
79           sourceMgr.getMainFileID(), fileLoc.getLine(), fileLoc.getColumn());
80 
81       // Use range of potential identifier starting at location, else length 1
82       // range.
83       location->range.end.character += 1;
84       if (Optional<llvm::SMRange> range =
85               AsmParserState::convertIdLocToRange(loc)) {
86         auto lineCol = sourceMgr.getLineAndColumn(range->End);
87         location->range.end.character =
88             std::max(fileLoc.getColumn() + 1, lineCol.second - 1);
89       }
90       return WalkResult::interrupt();
91     }
92     return WalkResult::advance();
93   });
94   return location;
95 }
96 
97 /// Collect all of the locations from the given MLIR location that are not
98 /// contained within the given URI.
99 static void collectLocationsFromLoc(Location loc,
100                                     std::vector<lsp::Location> &locations,
101                                     const lsp::URIForFile &uri) {
102   SetVector<Location> visitedLocs;
103   loc->walk([&](Location nestedLoc) {
104     FileLineColLoc fileLoc = nestedLoc.dyn_cast<FileLineColLoc>();
105     if (!fileLoc || !visitedLocs.insert(nestedLoc))
106       return WalkResult::advance();
107 
108     Optional<lsp::Location> sourceLoc = getLocationFromLoc(fileLoc);
109     if (sourceLoc && sourceLoc->uri != uri)
110       locations.push_back(*sourceLoc);
111     return WalkResult::advance();
112   });
113 }
114 
115 /// Returns true if the given range contains the given source location. Note
116 /// that this has slightly different behavior than SMRange because it is
117 /// inclusive of the end location.
118 static bool contains(llvm::SMRange range, llvm::SMLoc loc) {
119   return range.Start.getPointer() <= loc.getPointer() &&
120          loc.getPointer() <= range.End.getPointer();
121 }
122 
123 /// Returns true if the given location is contained by the definition or one of
124 /// the uses of the given SMDefinition. If provided, `overlappedRange` is set to
125 /// the range within `def` that the provided `loc` overlapped with.
126 static bool isDefOrUse(const AsmParserState::SMDefinition &def, llvm::SMLoc loc,
127                        llvm::SMRange *overlappedRange = nullptr) {
128   // Check the main definition.
129   if (contains(def.loc, loc)) {
130     if (overlappedRange)
131       *overlappedRange = def.loc;
132     return true;
133   }
134 
135   // Check the uses.
136   const auto *useIt = llvm::find_if(def.uses, [&](const llvm::SMRange &range) {
137     return contains(range, loc);
138   });
139   if (useIt != def.uses.end()) {
140     if (overlappedRange)
141       *overlappedRange = *useIt;
142     return true;
143   }
144   return false;
145 }
146 
147 /// Given a location pointing to a result, return the result number it refers
148 /// to or None if it refers to all of the results.
149 static Optional<unsigned> getResultNumberFromLoc(llvm::SMLoc loc) {
150   // Skip all of the identifier characters.
151   auto isIdentifierChar = [](char c) {
152     return isalnum(c) || c == '%' || c == '$' || c == '.' || c == '_' ||
153            c == '-';
154   };
155   const char *curPtr = loc.getPointer();
156   while (isIdentifierChar(*curPtr))
157     ++curPtr;
158 
159   // Check to see if this location indexes into the result group, via `#`. If it
160   // doesn't, we can't extract a sub result number.
161   if (*curPtr != '#')
162     return llvm::None;
163 
164   // Compute the sub result number from the remaining portion of the string.
165   const char *numberStart = ++curPtr;
166   while (llvm::isDigit(*curPtr))
167     ++curPtr;
168   StringRef numberStr(numberStart, curPtr - numberStart);
169   unsigned resultNumber = 0;
170   return numberStr.consumeInteger(10, resultNumber) ? Optional<unsigned>()
171                                                     : resultNumber;
172 }
173 
174 /// Given a source location range, return the text covered by the given range.
175 /// If the range is invalid, returns None.
176 static Optional<StringRef> getTextFromRange(llvm::SMRange range) {
177   if (!range.isValid())
178     return None;
179   const char *startPtr = range.Start.getPointer();
180   return StringRef(startPtr, range.End.getPointer() - startPtr);
181 }
182 
183 /// Given a block, return its position in its parent region.
184 static unsigned getBlockNumber(Block *block) {
185   return std::distance(block->getParent()->begin(), block->getIterator());
186 }
187 
188 /// Given a block and source location, print the source name of the block to the
189 /// given output stream.
190 static void printDefBlockName(raw_ostream &os, Block *block,
191                               llvm::SMRange loc = {}) {
192   // Try to extract a name from the source location.
193   Optional<StringRef> text = getTextFromRange(loc);
194   if (text && text->startswith("^")) {
195     os << *text;
196     return;
197   }
198 
199   // Otherwise, we don't have a name so print the block number.
200   os << "<Block #" << getBlockNumber(block) << ">";
201 }
202 static void printDefBlockName(raw_ostream &os,
203                               const AsmParserState::BlockDefinition &def) {
204   printDefBlockName(os, def.block, def.definition.loc);
205 }
206 
207 /// Convert the given MLIR diagnostic to the LSP form.
208 static lsp::Diagnostic getLspDiagnoticFromDiag(llvm::SourceMgr &sourceMgr,
209                                                Diagnostic &diag,
210                                                const lsp::URIForFile &uri) {
211   lsp::Diagnostic lspDiag;
212   lspDiag.source = "mlir";
213 
214   // Note: Right now all of the diagnostics are treated as parser issues, but
215   // some are parser and some are verifier.
216   lspDiag.category = "Parse Error";
217 
218   // Try to grab a file location for this diagnostic.
219   // TODO: For simplicity, we just grab the first one. It may be likely that we
220   // will need a more interesting heuristic here.'
221   Optional<lsp::Location> lspLocation =
222       getLocationFromLoc(sourceMgr, diag.getLocation(), &uri);
223   if (lspLocation)
224     lspDiag.range = lspLocation->range;
225 
226   // Convert the severity for the diagnostic.
227   switch (diag.getSeverity()) {
228   case DiagnosticSeverity::Note:
229     llvm_unreachable("expected notes to be handled separately");
230   case DiagnosticSeverity::Warning:
231     lspDiag.severity = lsp::DiagnosticSeverity::Warning;
232     break;
233   case DiagnosticSeverity::Error:
234     lspDiag.severity = lsp::DiagnosticSeverity::Error;
235     break;
236   case DiagnosticSeverity::Remark:
237     lspDiag.severity = lsp::DiagnosticSeverity::Information;
238     break;
239   }
240   lspDiag.message = diag.str();
241 
242   // Attach any notes to the main diagnostic as related information.
243   std::vector<lsp::DiagnosticRelatedInformation> relatedDiags;
244   for (Diagnostic &note : diag.getNotes()) {
245     lsp::Location noteLoc;
246     if (Optional<lsp::Location> loc =
247             getLocationFromLoc(sourceMgr, note.getLocation()))
248       noteLoc = *loc;
249     else
250       noteLoc.uri = uri;
251     relatedDiags.emplace_back(noteLoc, note.str());
252   }
253   if (!relatedDiags.empty())
254     lspDiag.relatedInformation = std::move(relatedDiags);
255 
256   return lspDiag;
257 }
258 
259 //===----------------------------------------------------------------------===//
260 // MLIRDocument
261 //===----------------------------------------------------------------------===//
262 
263 namespace {
/// This class represents all of the information pertaining to a specific MLIR
/// document: the IR parsed from its text, the parser state used to map source
/// positions to definitions/uses, and the source manager owning the text.
struct MLIRDocument {
  /// Parses `contents` as the text of the document identified by `uri`.
  /// Diagnostics emitted on `context` during construction are converted to
  /// their language server form and appended to `diagnostics`. If parsing
  /// fails, the parsed IR and the parser state are cleared.
  MLIRDocument(MLIRContext &context, const lsp::URIForFile &uri,
               StringRef contents, std::vector<lsp::Diagnostic> &diagnostics);
  MLIRDocument(const MLIRDocument &) = delete;
  MLIRDocument &operator=(const MLIRDocument &) = delete;

  //===--------------------------------------------------------------------===//
  // Definitions and References
  //===--------------------------------------------------------------------===//

  /// Appends to `locations` the locations associated with the operation,
  /// result, symbol use, block, or block argument found at `defPos`. Nothing
  /// is appended if no known entity covers the position.
  void getLocationsOf(const lsp::URIForFile &uri, const lsp::Position &defPos,
                      std::vector<lsp::Location> &locations);

  /// Appends to `references` the definition and use locations recorded for
  /// the entity found at `pos`. Nothing is appended if no known entity covers
  /// the position.
  void findReferencesOf(const lsp::URIForFile &uri, const lsp::Position &pos,
                        std::vector<lsp::Location> &references);

  //===--------------------------------------------------------------------===//
  // Hover
  //===--------------------------------------------------------------------===//

  /// Returns hover information for the entity found at `hoverPos`, or None if
  /// no known entity covers the position.
  Optional<lsp::Hover> findHover(const lsp::URIForFile &uri,
                                 const lsp::Position &hoverPos);

  /// Builds the hover for an operation: its name, symbol information (if it
  /// defines a symbol), and its generic form printed without region bodies.
  Optional<lsp::Hover>
  buildHoverForOperation(llvm::SMRange hoverRange,
                         const AsmParserState::OperationDefinition &op);

  /// Builds the hover for the results `[resultStart, resultEnd)` of `op`.
  /// `posLoc` is the hovered location, used to narrow the hover to a single
  /// result when the source text indexes into the group (via `#`).
  lsp::Hover buildHoverForOperationResult(llvm::SMRange hoverRange,
                                          Operation *op, unsigned resultStart,
                                          unsigned resultEnd,
                                          llvm::SMLoc posLoc);

  /// Builds the hover for a block: its parent operation, block number,
  /// predecessors, and successors.
  lsp::Hover buildHoverForBlock(llvm::SMRange hoverRange,
                                const AsmParserState::BlockDefinition &block);

  /// Builds the hover for a block argument: its parent operation, owning
  /// block, argument number, and type.
  lsp::Hover
  buildHoverForBlockArgument(llvm::SMRange hoverRange, BlockArgument arg,
                             const AsmParserState::BlockDefinition &block);

  //===--------------------------------------------------------------------===//
  // Document Symbols
  //===--------------------------------------------------------------------===//

  /// Appends the document symbols of all of the top-level operations of the
  /// parsed IR to `symbols`.
  void findDocumentSymbols(std::vector<lsp::DocumentSymbol> &symbols);

  /// Appends the document symbols of `op`, and of the operations nested within
  /// it, to `symbols`. Symbols found within an operation that itself produced
  /// a symbol are attached as children of that symbol.
  void findDocumentSymbols(Operation *op,
                           std::vector<lsp::DocumentSymbol> &symbols);

  //===--------------------------------------------------------------------===//
  // Fields
  //===--------------------------------------------------------------------===//

  // NOTE(review): the declaration order of the fields below determines their
  // construction/destruction order; confirm before reordering them.

  /// The high level parser state used to find definitions and references within
  /// the source file.
  AsmParserState asmState;

  /// The container for the IR parsed from the input file.
  Block parsedIR;

  /// The source manager containing the contents of the input file.
  llvm::SourceMgr sourceMgr;
};
322 } // namespace
323 
324 MLIRDocument::MLIRDocument(MLIRContext &context, const lsp::URIForFile &uri,
325                            StringRef contents,
326                            std::vector<lsp::Diagnostic> &diagnostics) {
327   ScopedDiagnosticHandler handler(&context, [&](Diagnostic &diag) {
328     diagnostics.push_back(getLspDiagnoticFromDiag(sourceMgr, diag, uri));
329   });
330 
331   // Try to parsed the given IR string.
332   auto memBuffer = llvm::MemoryBuffer::getMemBufferCopy(contents, uri.file());
333   if (!memBuffer) {
334     lsp::Logger::error("Failed to create memory buffer for file", uri.file());
335     return;
336   }
337 
338   sourceMgr.AddNewSourceBuffer(std::move(memBuffer), llvm::SMLoc());
339   if (failed(parseSourceFile(sourceMgr, &parsedIR, &context, nullptr,
340                              &asmState))) {
341     // If parsing failed, clear out any of the current state.
342     parsedIR.clear();
343     asmState = AsmParserState();
344     return;
345   }
346 }
347 
348 //===----------------------------------------------------------------------===//
349 // MLIRDocument: Definitions and References
350 //===----------------------------------------------------------------------===//
351 
352 void MLIRDocument::getLocationsOf(const lsp::URIForFile &uri,
353                                   const lsp::Position &defPos,
354                                   std::vector<lsp::Location> &locations) {
355   llvm::SMLoc posLoc = getPosFromLoc(sourceMgr, defPos);
356 
357   // Functor used to check if an SM definition contains the position.
358   auto containsPosition = [&](const AsmParserState::SMDefinition &def) {
359     if (!isDefOrUse(def, posLoc))
360       return false;
361     locations.push_back(getLocationFromLoc(sourceMgr, def.loc, uri));
362     return true;
363   };
364 
365   // Check all definitions related to operations.
366   for (const AsmParserState::OperationDefinition &op : asmState.getOpDefs()) {
367     if (contains(op.loc, posLoc))
368       return collectLocationsFromLoc(op.op->getLoc(), locations, uri);
369     for (const auto &result : op.resultGroups)
370       if (containsPosition(result.definition))
371         return collectLocationsFromLoc(op.op->getLoc(), locations, uri);
372     for (const auto &symUse : op.symbolUses) {
373       if (contains(symUse, posLoc)) {
374         locations.push_back(getLocationFromLoc(sourceMgr, op.loc, uri));
375         return collectLocationsFromLoc(op.op->getLoc(), locations, uri);
376       }
377     }
378   }
379 
380   // Check all definitions related to blocks.
381   for (const AsmParserState::BlockDefinition &block : asmState.getBlockDefs()) {
382     if (containsPosition(block.definition))
383       return;
384     for (const AsmParserState::SMDefinition &arg : block.arguments)
385       if (containsPosition(arg))
386         return;
387   }
388 }
389 
390 void MLIRDocument::findReferencesOf(const lsp::URIForFile &uri,
391                                     const lsp::Position &pos,
392                                     std::vector<lsp::Location> &references) {
393   // Functor used to append all of the definitions/uses of the given SM
394   // definition to the reference list.
395   auto appendSMDef = [&](const AsmParserState::SMDefinition &def) {
396     references.push_back(getLocationFromLoc(sourceMgr, def.loc, uri));
397     for (const llvm::SMRange &use : def.uses)
398       references.push_back(getLocationFromLoc(sourceMgr, use, uri));
399   };
400 
401   llvm::SMLoc posLoc = getPosFromLoc(sourceMgr, pos);
402 
403   // Check all definitions related to operations.
404   for (const AsmParserState::OperationDefinition &op : asmState.getOpDefs()) {
405     if (contains(op.loc, posLoc)) {
406       for (const auto &result : op.resultGroups)
407         appendSMDef(result.definition);
408       for (const auto &symUse : op.symbolUses)
409         if (contains(symUse, posLoc))
410           references.push_back(getLocationFromLoc(sourceMgr, symUse, uri));
411       return;
412     }
413     for (const auto &result : op.resultGroups)
414       if (isDefOrUse(result.definition, posLoc))
415         return appendSMDef(result.definition);
416     for (const auto &symUse : op.symbolUses) {
417       if (!contains(symUse, posLoc))
418         continue;
419       for (const auto &symUse : op.symbolUses)
420         references.push_back(getLocationFromLoc(sourceMgr, symUse, uri));
421       return;
422     }
423   }
424 
425   // Check all definitions related to blocks.
426   for (const AsmParserState::BlockDefinition &block : asmState.getBlockDefs()) {
427     if (isDefOrUse(block.definition, posLoc))
428       return appendSMDef(block.definition);
429 
430     for (const AsmParserState::SMDefinition &arg : block.arguments)
431       if (isDefOrUse(arg, posLoc))
432         return appendSMDef(arg);
433   }
434 }
435 
436 //===----------------------------------------------------------------------===//
437 // MLIRDocument: Hover
438 //===----------------------------------------------------------------------===//
439 
440 Optional<lsp::Hover> MLIRDocument::findHover(const lsp::URIForFile &uri,
441                                              const lsp::Position &hoverPos) {
442   llvm::SMLoc posLoc = getPosFromLoc(sourceMgr, hoverPos);
443   llvm::SMRange hoverRange;
444 
445   // Check for Hovers on operations and results.
446   for (const AsmParserState::OperationDefinition &op : asmState.getOpDefs()) {
447     // Check if the position points at this operation.
448     if (contains(op.loc, posLoc))
449       return buildHoverForOperation(op.loc, op);
450 
451     // Check if the position points at the symbol name.
452     for (auto &use : op.symbolUses)
453       if (contains(use, posLoc))
454         return buildHoverForOperation(use, op);
455 
456     // Check if the position points at a result group.
457     for (unsigned i = 0, e = op.resultGroups.size(); i < e; ++i) {
458       const auto &result = op.resultGroups[i];
459       if (!isDefOrUse(result.definition, posLoc, &hoverRange))
460         continue;
461 
462       // Get the range of results covered by the over position.
463       unsigned resultStart = result.startIndex;
464       unsigned resultEnd = (i == e - 1) ? op.op->getNumResults()
465                                         : op.resultGroups[i + 1].startIndex;
466       return buildHoverForOperationResult(hoverRange, op.op, resultStart,
467                                           resultEnd, posLoc);
468     }
469   }
470 
471   // Check to see if the hover is over a block argument.
472   for (const AsmParserState::BlockDefinition &block : asmState.getBlockDefs()) {
473     if (isDefOrUse(block.definition, posLoc, &hoverRange))
474       return buildHoverForBlock(hoverRange, block);
475 
476     for (const auto &arg : llvm::enumerate(block.arguments)) {
477       if (!isDefOrUse(arg.value(), posLoc, &hoverRange))
478         continue;
479 
480       return buildHoverForBlockArgument(
481           hoverRange, block.block->getArgument(arg.index()), block);
482     }
483   }
484   return llvm::None;
485 }
486 
487 Optional<lsp::Hover> MLIRDocument::buildHoverForOperation(
488     llvm::SMRange hoverRange, const AsmParserState::OperationDefinition &op) {
489   lsp::Hover hover(getRangeFromLoc(sourceMgr, hoverRange));
490   llvm::raw_string_ostream os(hover.contents.value);
491 
492   // Add the operation name to the hover.
493   os << "\"" << op.op->getName() << "\"";
494   if (SymbolOpInterface symbol = dyn_cast<SymbolOpInterface>(op.op))
495     os << " : " << symbol.getVisibility() << " @" << symbol.getName() << "";
496   os << "\n\n";
497 
498   os << "Generic Form:\n\n```mlir\n";
499 
500   // Temporary drop the regions of this operation so that they don't get
501   // printed in the output. This helps keeps the size of the output hover
502   // small.
503   SmallVector<std::unique_ptr<Region>> regions;
504   for (Region &region : op.op->getRegions()) {
505     regions.emplace_back(std::make_unique<Region>());
506     regions.back()->takeBody(region);
507   }
508 
509   op.op->print(
510       os, OpPrintingFlags().printGenericOpForm().elideLargeElementsAttrs());
511   os << "\n```\n";
512 
513   // Move the regions back to the current operation.
514   for (Region &region : op.op->getRegions())
515     region.takeBody(*regions.back());
516 
517   return hover;
518 }
519 
520 lsp::Hover MLIRDocument::buildHoverForOperationResult(llvm::SMRange hoverRange,
521                                                       Operation *op,
522                                                       unsigned resultStart,
523                                                       unsigned resultEnd,
524                                                       llvm::SMLoc posLoc) {
525   lsp::Hover hover(getRangeFromLoc(sourceMgr, hoverRange));
526   llvm::raw_string_ostream os(hover.contents.value);
527 
528   // Add the parent operation name to the hover.
529   os << "Operation: \"" << op->getName() << "\"\n\n";
530 
531   // Check to see if the location points to a specific result within the
532   // group.
533   if (Optional<unsigned> resultNumber = getResultNumberFromLoc(posLoc)) {
534     if ((resultStart + *resultNumber) < resultEnd) {
535       resultStart += *resultNumber;
536       resultEnd = resultStart + 1;
537     }
538   }
539 
540   // Add the range of results and their types to the hover info.
541   if ((resultStart + 1) == resultEnd) {
542     os << "Result #" << resultStart << "\n\n"
543        << "Type: `" << op->getResult(resultStart).getType() << "`\n\n";
544   } else {
545     os << "Result #[" << resultStart << ", " << (resultEnd - 1) << "]\n\n"
546        << "Types: ";
547     llvm::interleaveComma(
548         op->getResults().slice(resultStart, resultEnd), os,
549         [&](Value result) { os << "`" << result.getType() << "`"; });
550   }
551 
552   return hover;
553 }
554 
555 lsp::Hover
556 MLIRDocument::buildHoverForBlock(llvm::SMRange hoverRange,
557                                  const AsmParserState::BlockDefinition &block) {
558   lsp::Hover hover(getRangeFromLoc(sourceMgr, hoverRange));
559   llvm::raw_string_ostream os(hover.contents.value);
560 
561   // Print the given block to the hover output stream.
562   auto printBlockToHover = [&](Block *newBlock) {
563     if (const auto *def = asmState.getBlockDef(newBlock))
564       printDefBlockName(os, *def);
565     else
566       printDefBlockName(os, newBlock);
567   };
568 
569   // Display the parent operation, block number, predecessors, and successors.
570   os << "Operation: \"" << block.block->getParentOp()->getName() << "\"\n\n"
571      << "Block #" << getBlockNumber(block.block) << "\n\n";
572   if (!block.block->hasNoPredecessors()) {
573     os << "Predecessors: ";
574     llvm::interleaveComma(block.block->getPredecessors(), os,
575                           printBlockToHover);
576     os << "\n\n";
577   }
578   if (!block.block->hasNoSuccessors()) {
579     os << "Successors: ";
580     llvm::interleaveComma(block.block->getSuccessors(), os, printBlockToHover);
581     os << "\n\n";
582   }
583 
584   return hover;
585 }
586 
587 lsp::Hover MLIRDocument::buildHoverForBlockArgument(
588     llvm::SMRange hoverRange, BlockArgument arg,
589     const AsmParserState::BlockDefinition &block) {
590   lsp::Hover hover(getRangeFromLoc(sourceMgr, hoverRange));
591   llvm::raw_string_ostream os(hover.contents.value);
592 
593   // Display the parent operation, block, the argument number, and the type.
594   os << "Operation: \"" << block.block->getParentOp()->getName() << "\"\n\n"
595      << "Block: ";
596   printDefBlockName(os, block);
597   os << "\n\nArgument #" << arg.getArgNumber() << "\n\n"
598      << "Type: `" << arg.getType() << "`\n\n";
599 
600   return hover;
601 }
602 
603 //===----------------------------------------------------------------------===//
604 // MLIRDocument: Document Symbols
605 //===----------------------------------------------------------------------===//
606 
607 void MLIRDocument::findDocumentSymbols(
608     std::vector<lsp::DocumentSymbol> &symbols) {
609   for (Operation &op : parsedIR)
610     findDocumentSymbols(&op, symbols);
611 }
612 
613 void MLIRDocument::findDocumentSymbols(
614     Operation *op, std::vector<lsp::DocumentSymbol> &symbols) {
615   std::vector<lsp::DocumentSymbol> *childSymbols = &symbols;
616 
617   // Check for the source information of this operation.
618   if (const AsmParserState::OperationDefinition *def = asmState.getOpDef(op)) {
619     // If this operation defines a symbol, record it.
620     if (SymbolOpInterface symbol = dyn_cast<SymbolOpInterface>(op)) {
621       symbols.emplace_back(symbol.getName(),
622                            op->hasTrait<OpTrait::FunctionLike>()
623                                ? lsp::SymbolKind::Function
624                                : lsp::SymbolKind::Class,
625                            getRangeFromLoc(sourceMgr, def->scopeLoc),
626                            getRangeFromLoc(sourceMgr, def->loc));
627       childSymbols = &symbols.back().children;
628 
629     } else if (op->hasTrait<OpTrait::SymbolTable>()) {
630       // Otherwise, if this is a symbol table push an anonymous document symbol.
631       symbols.emplace_back("<" + op->getName().getStringRef() + ">",
632                            lsp::SymbolKind::Namespace,
633                            getRangeFromLoc(sourceMgr, def->scopeLoc),
634                            getRangeFromLoc(sourceMgr, def->loc));
635       childSymbols = &symbols.back().children;
636     }
637   }
638 
639   // Recurse into the regions of this operation.
640   if (!op->getNumRegions())
641     return;
642   for (Region &region : op->getRegions())
643     for (Operation &childOp : region.getOps())
644       findDocumentSymbols(&childOp, *childSymbols);
645 }
646 
647 //===----------------------------------------------------------------------===//
648 // MLIRTextFileChunk
649 //===----------------------------------------------------------------------===//
650 
651 namespace {
652 /// This class represents a single chunk of an MLIR text file.
653 struct MLIRTextFileChunk {
654   MLIRTextFileChunk(MLIRContext &context, uint64_t lineOffset,
655                     const lsp::URIForFile &uri, StringRef contents,
656                     std::vector<lsp::Diagnostic> &diagnostics)
657       : lineOffset(lineOffset), document(context, uri, contents, diagnostics) {}
658 
659   /// Adjust the line number of the given range to anchor at the beginning of
660   /// the file, instead of the beginning of this chunk.
661   void adjustLocForChunkOffset(lsp::Range &range) {
662     adjustLocForChunkOffset(range.start);
663     adjustLocForChunkOffset(range.end);
664   }
665   /// Adjust the line number of the given position to anchor at the beginning of
666   /// the file, instead of the beginning of this chunk.
667   void adjustLocForChunkOffset(lsp::Position &pos) { pos.line += lineOffset; }
668 
669   /// The line offset of this chunk from the beginning of the file.
670   uint64_t lineOffset;
671   /// The document referred to by this chunk.
672   MLIRDocument document;
673 };
674 } // namespace
675 
676 //===----------------------------------------------------------------------===//
677 // MLIRTextFile
678 //===----------------------------------------------------------------------===//
679 
680 namespace {
/// This class represents a text file containing one or more MLIR documents.
/// The file contents are split into chunks on the `// -----` marker, and each
/// chunk is parsed as its own MLIR document.
class MLIRTextFile {
public:
  /// Splits `fileContents` into chunks and parses each of them. Diagnostics
  /// produced while parsing are appended to `diagnostics`, with their lines
  /// adjusted to be relative to the beginning of the file.
  MLIRTextFile(const lsp::URIForFile &uri, StringRef fileContents,
               int64_t version, DialectRegistry &registry,
               std::vector<lsp::Diagnostic> &diagnostics);

  /// Return the current version of this text file.
  int64_t getVersion() const { return version; }

  //===--------------------------------------------------------------------===//
  // LSP Queries
  //===--------------------------------------------------------------------===//

  // Note: the query positions below are taken by value. `getLocationsOf`
  // passes its position to `getChunkFor`, which rewrites it to be relative to
  // the chunk that is found; presumably the other position-based queries do
  // the same (their definitions are not shown here).

  /// Appends to `locations` the locations of the entity found at `defPos`,
  /// as computed by the document of the chunk containing the position.
  /// Locations within `uri` are adjusted to be relative to the file.
  void getLocationsOf(const lsp::URIForFile &uri, lsp::Position defPos,
                      std::vector<lsp::Location> &locations);
  void findReferencesOf(const lsp::URIForFile &uri, lsp::Position pos,
                        std::vector<lsp::Location> &references);
  Optional<lsp::Hover> findHover(const lsp::URIForFile &uri,
                                 lsp::Position hoverPos);
  void findDocumentSymbols(std::vector<lsp::DocumentSymbol> &symbols);

private:
  /// Find the MLIR document that contains the given position, and update the
  /// position to be anchored at the start of the found chunk instead of the
  /// beginning of the file.
  MLIRTextFileChunk &getChunkFor(lsp::Position &pos);

  /// The context used to hold the state contained by the parsed document.
  /// It is created with threading disabled and allows unregistered dialects.
  MLIRContext context;

  /// The full string contents of the file. The chunks are split from this
  /// string.
  std::string contents;

  /// The version of this file.
  int64_t version;

  /// The number of lines in the file, computed as the number of newline
  /// characters found across all of the chunks.
  int64_t totalNumLines;

  /// The chunks of this file. The order of these chunks is the order in which
  /// they appear in the text file.
  std::vector<std::unique_ptr<MLIRTextFileChunk>> chunks;
};
725 } // namespace
726 
727 MLIRTextFile::MLIRTextFile(const lsp::URIForFile &uri, StringRef fileContents,
728                            int64_t version, DialectRegistry &registry,
729                            std::vector<lsp::Diagnostic> &diagnostics)
730     : context(registry, MLIRContext::Threading::DISABLED),
731       contents(fileContents.str()), version(version), totalNumLines(0) {
732   context.allowUnregisteredDialects();
733 
734   // Split the file into separate MLIR documents.
735   // TODO: Find a way to share the split file marker with other tools. We don't
736   // want to use `splitAndProcessBuffer` here, but we do want to make sure this
737   // marker doesn't go out of sync.
738   SmallVector<StringRef, 8> subContents;
739   StringRef(contents).split(subContents, "// -----");
740   chunks.emplace_back(std::make_unique<MLIRTextFileChunk>(
741       context, /*lineOffset=*/0, uri, subContents.front(), diagnostics));
742 
743   uint64_t lineOffset = subContents.front().count('\n');
744   for (StringRef docContents : llvm::drop_begin(subContents)) {
745     unsigned currentNumDiags = diagnostics.size();
746     auto chunk = std::make_unique<MLIRTextFileChunk>(context, lineOffset, uri,
747                                                      docContents, diagnostics);
748     lineOffset += docContents.count('\n');
749 
750     // Adjust locations used in diagnostics to account for the offset from the
751     // beginning of the file.
752     for (lsp::Diagnostic &diag :
753          llvm::drop_begin(diagnostics, currentNumDiags)) {
754       chunk->adjustLocForChunkOffset(diag.range);
755 
756       if (!diag.relatedInformation)
757         continue;
758       for (auto &it : *diag.relatedInformation)
759         if (it.location.uri == uri)
760           chunk->adjustLocForChunkOffset(it.location.range);
761     }
762     chunks.emplace_back(std::move(chunk));
763   }
764   totalNumLines = lineOffset;
765 }
766 
767 void MLIRTextFile::getLocationsOf(const lsp::URIForFile &uri,
768                                   lsp::Position defPos,
769                                   std::vector<lsp::Location> &locations) {
770   MLIRTextFileChunk &chunk = getChunkFor(defPos);
771   chunk.document.getLocationsOf(uri, defPos, locations);
772 
773   // Adjust any locations within this file for the offset of this chunk.
774   if (chunk.lineOffset == 0)
775     return;
776   for (lsp::Location &loc : locations)
777     if (loc.uri == uri)
778       chunk.adjustLocForChunkOffset(loc.range);
779 }
780 
781 void MLIRTextFile::findReferencesOf(const lsp::URIForFile &uri,
782                                     lsp::Position pos,
783                                     std::vector<lsp::Location> &references) {
784   MLIRTextFileChunk &chunk = getChunkFor(pos);
785   chunk.document.findReferencesOf(uri, pos, references);
786 
787   // Adjust any locations within this file for the offset of this chunk.
788   if (chunk.lineOffset == 0)
789     return;
790   for (lsp::Location &loc : references)
791     if (loc.uri == uri)
792       chunk.adjustLocForChunkOffset(loc.range);
793 }
794 
795 Optional<lsp::Hover> MLIRTextFile::findHover(const lsp::URIForFile &uri,
796                                              lsp::Position hoverPos) {
797   MLIRTextFileChunk &chunk = getChunkFor(hoverPos);
798   Optional<lsp::Hover> hoverInfo = chunk.document.findHover(uri, hoverPos);
799 
800   // Adjust any locations within this file for the offset of this chunk.
801   if (chunk.lineOffset != 0 && hoverInfo && hoverInfo->range)
802     chunk.adjustLocForChunkOffset(*hoverInfo->range);
803   return hoverInfo;
804 }
805 
806 void MLIRTextFile::findDocumentSymbols(
807     std::vector<lsp::DocumentSymbol> &symbols) {
808   if (chunks.size() == 1)
809     return chunks.front()->document.findDocumentSymbols(symbols);
810 
811   // If there are multiple chunks in this file, we create top-level symbols for
812   // each chunk.
813   for (unsigned i = 0, e = chunks.size(); i < e; ++i) {
814     MLIRTextFileChunk &chunk = *chunks[i];
815     lsp::Position startPos(chunk.lineOffset);
816     lsp::Position endPos((i == e - 1) ? totalNumLines - 1
817                                       : chunks[i + 1]->lineOffset);
818     lsp::DocumentSymbol symbol("<file-split-" + Twine(i) + ">",
819                                lsp::SymbolKind::Namespace,
820                                /*range=*/lsp::Range(startPos, endPos),
821                                /*selectionRange=*/lsp::Range(startPos));
822     chunk.document.findDocumentSymbols(symbol.children);
823 
824     // Fixup the locations of document symbols within this chunk.
825     if (i != 0) {
826       SmallVector<lsp::DocumentSymbol *> symbolsToFix;
827       for (lsp::DocumentSymbol &childSymbol : symbol.children)
828         symbolsToFix.push_back(&childSymbol);
829 
830       while (!symbolsToFix.empty()) {
831         lsp::DocumentSymbol *symbol = symbolsToFix.pop_back_val();
832         chunk.adjustLocForChunkOffset(symbol->range);
833         chunk.adjustLocForChunkOffset(symbol->selectionRange);
834 
835         for (lsp::DocumentSymbol &childSymbol : symbol->children)
836           symbolsToFix.push_back(&childSymbol);
837       }
838     }
839 
840     // Push the symbol for this chunk.
841     symbols.emplace_back(std::move(symbol));
842   }
843 }
844 
845 MLIRTextFileChunk &MLIRTextFile::getChunkFor(lsp::Position &pos) {
846   if (chunks.size() == 1)
847     return *chunks.front();
848 
849   // Search for the first chunk with a greater line offset, the previous chunk
850   // is the one that contains `pos`.
851   auto it = llvm::upper_bound(
852       chunks, pos, [](const lsp::Position &pos, const auto &chunk) {
853         return static_cast<uint64_t>(pos.line) < chunk->lineOffset;
854       });
855   MLIRTextFileChunk &chunk = it == chunks.end() ? *chunks.back() : **(--it);
856   pos.line -= chunk.lineOffset;
857   return chunk;
858 }
859 
860 //===----------------------------------------------------------------------===//
861 // MLIRServer::Impl
862 //===----------------------------------------------------------------------===//
863 
864 struct lsp::MLIRServer::Impl {
865   Impl(DialectRegistry &registry) : registry(registry) {}
866 
867   /// The registry containing dialects that can be recognized in parsed .mlir
868   /// files.
869   DialectRegistry &registry;
870 
871   /// The files held by the server, mapped by their URI file name.
872   llvm::StringMap<std::unique_ptr<MLIRTextFile>> files;
873 };
874 
875 //===----------------------------------------------------------------------===//
876 // MLIRServer
877 //===----------------------------------------------------------------------===//
878 
/// Creates a server whose implementation state is bound to `registry`; the
/// registry is later handed to every text file the server builds.
lsp::MLIRServer::MLIRServer(DialectRegistry &registry)
    : impl(std::make_unique<Impl>(registry)) {}
/// Defaulted here, after the definition of `Impl`.
lsp::MLIRServer::~MLIRServer() = default;
882 
883 void lsp::MLIRServer::addOrUpdateDocument(
884     const URIForFile &uri, StringRef contents, int64_t version,
885     std::vector<Diagnostic> &diagnostics) {
886   impl->files[uri.file()] = std::make_unique<MLIRTextFile>(
887       uri, contents, version, impl->registry, diagnostics);
888 }
889 
890 Optional<int64_t> lsp::MLIRServer::removeDocument(const URIForFile &uri) {
891   auto it = impl->files.find(uri.file());
892   if (it == impl->files.end())
893     return llvm::None;
894 
895   int64_t version = it->second->getVersion();
896   impl->files.erase(it);
897   return version;
898 }
899 
900 void lsp::MLIRServer::getLocationsOf(const URIForFile &uri,
901                                      const Position &defPos,
902                                      std::vector<Location> &locations) {
903   auto fileIt = impl->files.find(uri.file());
904   if (fileIt != impl->files.end())
905     fileIt->second->getLocationsOf(uri, defPos, locations);
906 }
907 
908 void lsp::MLIRServer::findReferencesOf(const URIForFile &uri,
909                                        const Position &pos,
910                                        std::vector<Location> &references) {
911   auto fileIt = impl->files.find(uri.file());
912   if (fileIt != impl->files.end())
913     fileIt->second->findReferencesOf(uri, pos, references);
914 }
915 
916 Optional<lsp::Hover> lsp::MLIRServer::findHover(const URIForFile &uri,
917                                                 const Position &hoverPos) {
918   auto fileIt = impl->files.find(uri.file());
919   if (fileIt != impl->files.end())
920     return fileIt->second->findHover(uri, hoverPos);
921   return llvm::None;
922 }
923 
924 void lsp::MLIRServer::findDocumentSymbols(
925     const URIForFile &uri, std::vector<DocumentSymbol> &symbols) {
926   auto fileIt = impl->files.find(uri.file());
927   if (fileIt != impl->files.end())
928     fileIt->second->findDocumentSymbols(symbols);
929 }
930